Quick Start

Install, run, reproduce. Minimal steps for nano-PEARL.

Overview

Prerequisites and sensible defaults.

Python      ≥ 3.12
PyTorch     ≥ 2.4
CUDA        12.x
Draft TP    1
Target TP   1, 2, 4
Dtype       bfloat16

Future release: additional tensor-parallel sizes (e.g. 6, 7) will be supported.
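
If you are unsure whether your environment matches, a quick check like the following can help (a minimal sketch using only the standard library and PyTorch):

import sys
import torch

# Verify the interpreter and PyTorch/CUDA versions against the table above.
assert sys.version_info >= (3, 12), f"Python >= 3.12 required, found {sys.version.split()[0]}"
major, minor = (int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
assert (major, minor) >= (2, 4), f"PyTorch >= 2.4 required, found {torch.__version__}"
assert torch.cuda.is_available(), "CUDA is not visible to PyTorch"
print(f"torch {torch.__version__} | CUDA {torch.version.cuda} | {torch.cuda.device_count()} GPU(s)")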

1. Installation

nano-PEARL builds on nano-vllm. We recommend Python 3.12 and installing from source with uv (or pip).

1.1 Create environment

conda create -n nano-pearl python=3.12 -y
conda activate nano-pearl

1.2 Install package

From source (recommended)
git clone https://github.com/smart-lty/nano-PEARL.git
cd nano-PEARL
uv pip install -e .
From GitHub
pip install git+https://github.com/smart-lty/nano-PEARL.git
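
A quick way to confirm the install succeeded is to import the package's entry points. This is only a smoke test: it checks that nano-PEARL and its dependencies resolve, and does not load any model.

# Smoke test: imports nano-PEARL's public classes; no model weights are loaded.
from nano_pearl import PEARLConfig, PEARLEngine, SamplingParams

print("nano-PEARL import OK:", PEARLConfig.__name__, PEARLEngine.__name__, SamplingParams.__name__)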

2. Basic Usage

The API mirrors vLLM / nano-vllm. Provide both a draft model and a target model with their tensor-parallel sizes.

2.1 example.py

Minimal single-request example on 2 GPUs (TP: 1 + 1). Edit paths, then run:

python example.py

example.py on GitHub — quick single-turn Q&A.

from nano_pearl import PEARLConfig, PEARLEngine, SamplingParams, logger


def main():
    draft_model_path = "/path/to/draft/model"
    target_model_path = "/path/to/target/model"

    config = PEARLConfig(
        draft_model_path,
        target_model_path,
        draft_tensor_parallel_size=1,
        target_tensor_parallel_size=1,
        gpu_memory_utilization=0.9,
    )
    engine = PEARLEngine(config)

    prompt = "Explain quantum computing in simple terms"
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=256,
        ignore_eos=False,
    )
    engine.add_request(prompt, sampling_params)

    output_text, num_tokens, num_acc_tokens, elapsed_time = engine.generate()
    logger.info("Completion:", color="yellow")
    logger.info(output_text[0])

    throughput = num_tokens[0] / elapsed_time if elapsed_time else 0.0
    mat = sum(num_acc_tokens[0]) / len(num_acc_tokens[0])
    logger.info(
        "Tokens: %d, Time: %.2fs, Throughput: %.2f tok/s, MAT: %.2f",
        num_tokens[0],
        elapsed_time,
        throughput,
        mat,
    )


if __name__ == "__main__":
    main()
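
As used above, engine.generate() returns lists of completions, generated-token counts, and per-step accepted-token counts (one entry per request), plus the total wall-clock time. MAT (mean accepted tokens) is simply the average of the per-step accepted-token counts, so higher values indicate more effective speculation.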

2.2 bench.py

Batch benchmarking utility with warmup, an optional autoregressive (AR) baseline, and a simple CLI. You can use the default prompts under static/default_prompts.txt or the additional benchmark datasets under benchmark/data.

Common invocation:

python bench.py \
  --draft-model "/path/to/draft" \
  --target-model "/path/to/target" \
  --draft-tp 1 --target-tp 2 \
  --max-tokens 200 --temperature 0 \
  --run-ar-benchmark -v

Notes:

  • --ignore-eos: keep generating until max_tokens even when EOS is produced; if omitted, generation stops at EOS or max_tokens, whichever comes first.
  • -p/--custom-prompts: provide one or more prompts inline; otherwise defaults are used.
  • Warmup runs once with a short prompt to stabilize kernels and KV caches.

bench.py on GitHub — multi-batch/dataset benchmark.

import copy
import argparse
import os
from random import seed
from nano_pearl import PEARLConfig, PEARLEngine, SamplingParams, logger


def parse_args():
    parser = argparse.ArgumentParser(description="PEARL Benchmark Tool")

    parser.add_argument(
        "--draft-model",
        "-d",
        type=str,
        required=True,
        help="Draft model path (required)",
    )
    parser.add_argument(
        "--target-model",
        "-t",
        type=str,
        required=True,
        help="Target model path (required)",
    )
    parser.add_argument(
        "--draft-tp",
        type=int,
        default=1,
        help="Draft model tensor parallel size (default: 1)",
    )
    parser.add_argument(
        "--target-tp",
        type=int,
        default=2,
        help="Target model tensor parallel size (default: 2)",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.9,
        help="GPU memory utilization (default: 0.9)",
    )
    parser.add_argument(
        "--temperature",
        "-temp",
        type=float,
        default=0.0,
        help="Sampling temperature (default: 0.0)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=200,
        help="Maximum tokens to generate (default: 200)",
    )
    parser.add_argument(
        "--ignore-eos",
        "-noeos",
        action="store_true",
        help="Ignore EOS token (default: False)",
    )
    parser.add_argument(
        "--run-ar-benchmark",
        "-ar",
        action="store_true",
        help="Run AR (Autoregressive) benchmark (default: False)",
    )
    parser.add_argument(
        "--custom-prompts",
        "-p",
        type=str,
        nargs='+',
        help="Custom prompts for benchmark",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed (default: 0)",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Verbose output (default: False)",
    )

    return parser.parse_args()


def get_default_prompts():
    script_dir = os.path.dirname(os.path.abspath(__file__))
    prompts_file = os.path.join(script_dir, "static", "default_prompts.txt")
    with open(prompts_file, "r", encoding="utf-8") as f:
        prompts = [line.strip() for line in f.readlines() if line.strip()]
    return prompts


if __name__ == "__main__":
    args = parse_args()

    seed(args.seed)
    draft_model_path = args.draft_model
    target_model_path = args.target_model
    config = PEARLConfig(
        draft_model_path,
        target_model_path,
        draft_tensor_parallel_size=args.draft_tp,
        target_tensor_parallel_size=args.target_tp,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )
    engine = PEARLEngine(config)

    # warmup
    prompt = "Benchmark:"
    sampling_params = SamplingParams(
        temperature=0,
        ignore_eos=False,
        max_tokens=512,
    )
    engine.add_request(prompt, sampling_params)
    output_text, num_tokens, num_acc_tokens, elapsed_time = engine.generate()
    throughput = (
        sum(num_tokens) / elapsed_time if elapsed_time else 0.0
    )
    logger.info(
        "[Warmup] Total: %stok, Time: %.2fs, Throughput: %.2f tok/s, MAT: %s",
        sum(num_tokens),
        elapsed_time,
        throughput,
        [sum(n) / len(n) for n in num_acc_tokens],
    )

    prompts = args.custom_prompts if args.custom_prompts else get_default_prompts()
    sampling_params = SamplingParams(
        temperature=args.temperature,
        ignore_eos=args.ignore_eos,
        max_tokens=args.max_tokens,
    )
    for prompt in prompts:
        engine.add_request(prompt, copy.deepcopy(sampling_params))

    output_text, num_tokens, num_acc_tokens, elapsed_time = engine.bench_generate(
        num_pearl_steps=100
    )
    throughput = sum(num_tokens) / elapsed_time if elapsed_time else 0.0
    logger.info(
        "[Bench] Total: %stok, Time: %.2fs, Throughput: %.2f tok/s, MAT: %s",
        sum(num_tokens),
        elapsed_time,
        throughput,
        [sum(n) / len(n) for n in num_acc_tokens],
    )
Tip
Choose --target-tp according to your available GPUs and model size; nano-PEARL currently supports static TP per model group.
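
For example, since the draft and target groups occupy separate GPUs, a 1 + 2 split needs 3 GPUs in total and a 1 + 4 split needs 5. A minimal configuration sketch (placeholder paths):

from nano_pearl import PEARLConfig

# 3 GPUs: draft on 1 GPU, target sharded across 2.
config_3gpu = PEARLConfig(
    "/path/to/draft",
    "/path/to/target",
    draft_tensor_parallel_size=1,
    target_tensor_parallel_size=2,
    gpu_memory_utilization=0.9,
)

# 5 GPUs: draft on 1 GPU, target sharded across 4 (for larger target models).
config_5gpu = PEARLConfig(
    "/path/to/draft",
    "/path/to/target",
    draft_tensor_parallel_size=1,
    target_tensor_parallel_size=4,
    gpu_memory_utilization=0.9,
)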

3. Simple Benchmark

Use benchmark/bench_example.sh to reproduce the default evaluation settings for nano-PEARL. The script uses the same parameters shown on the Benchmark page (draft/target TP 1/2, temperature 0, max tokens 200, optional AR baseline).

  1. Download or point to local draft/target checkpoints.
  2. Run the helper script with your paths (dataset defaults to all):
bash benchmark/bench_example.sh \
  /models/meta-llama/Meta-Llama-3-8B-Instruct \
  /models/meta-llama/Meta-Llama-3-70B-Instruct

Customize quickly:

  • Pass a dataset (e.g. gsm8k) as the third argument.
  • Switch mode to random (fourth argument) to sample synthetic prompts.
bash benchmark/bench_example.sh \
  /path/to/draft \
  /path/to/target \
  gsm8k \
  random
Outputs & follow-up
The script writes intermediate logs to bench.log. Compare results and richer breakdowns on the interactive benchmark dashboard, or rerun without --run-ar-benchmark if you only need PEARL runs.

4. Troubleshooting

flash-attn build fails
Install a torch build matching your CUDA version first, then reinstall flash-attn. Prefer prebuilt wheels when available.
Slow first run
Run a short warmup prompt to stabilize kernels and KV caches before measuring throughput.
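
A compact sketch of the same warmup pattern used in bench.py (placeholder paths; discard the warmup result and time only later runs):

from nano_pearl import PEARLConfig, PEARLEngine, SamplingParams

# Build the engine as in example.py, then run one short request so CUDA kernels
# and the KV cache are initialized before any measurement.
config = PEARLConfig(
    "/path/to/draft",
    "/path/to/target",
    draft_tensor_parallel_size=1,
    target_tensor_parallel_size=1,
    gpu_memory_utilization=0.9,
)
engine = PEARLEngine(config)
engine.add_request("Benchmark:", SamplingParams(temperature=0, max_tokens=64))
engine.generate()  # warmup; discard this result and time only subsequent runs
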
Out-of-Memory (OOM)
Memory & Batching Parameters (pearl_config.py, pearl_engine.py)
  • gpu_memory_utilization: float — fraction of GPU memory reserved for KV cache. Key knob. Recommendation: 0.9.
Note: tune the following based on GPU memory and deployment scenario.
  • max_num_batched_tokens: int — max tokens in a single batch (batch_size * seq_len). Recommendation: 16384 on 80GB (H100/A100); 8192 on 40GB.
  • max_num_seqs: int — max concurrent sequences. Recommendation: 512 (80GB); 128–256 (40GB).
  • max_model_len: int — maximum context length (prompt + generated tokens).
  • set_gamma_batch_size: list[int] — batch sizes for auto-tuning when gamma = -1. Recommendation: default [1,2,4,8,16,32,64]; reduce (e.g., [1,2,4,8]) on constrained GPUs.
Practical tip: if OOM occurs, first lower max_num_batched_tokens or max_num_seqs, then lower gpu_memory_utilization slightly (e.g., to 0.85); see the sketch below.
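
As an illustration, a memory-constrained setup on 40GB GPUs could look like the sketch below. It assumes these fields are accepted as keyword arguments by PEARLConfig; check pearl_config.py for the exact names and defaults, and treat the max_model_len value as a placeholder.

from nano_pearl import PEARLConfig

# Hypothetical OOM-friendly settings for 40GB GPUs; tune to your hardware.
config = PEARLConfig(
    "/path/to/draft",
    "/path/to/target",
    draft_tensor_parallel_size=1,
    target_tensor_parallel_size=2,
    gpu_memory_utilization=0.85,        # back off slightly from 0.9 if OOM persists
    max_num_batched_tokens=8192,        # 40GB recommendation above
    max_num_seqs=128,                   # lower end of the 40GB recommendation
    max_model_len=4096,                 # placeholder: longest prompt + completion you expect
    set_gamma_batch_size=[1, 2, 4, 8],  # reduced auto-tuning sweep for gamma = -1
)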