From 1c2c647c7674dc65e56d1f853e50f9ec657af422 Mon Sep 17 00:00:00 2001 From: maocheng23 Date: Sat, 27 Jun 2026 22:10:23 -0700 Subject: [PATCH 1/2] feat: DSpark trainer (DFlash + Markov/confidence heads + L1 distillation) Port of TorchSpec PR #129 to SpecForge. Adds: - specforge/modeling/draft/dspark.py: DSparkConfig, VanillaMarkov, AcceptRatePredictor, DSparkDraftModel (subclass of DFlashDraftModel) - specforge/core/dspark.py: OnlineDSparkModel (subclass of OnlineDFlashModel) with Markov-biased logits + CE + L1 distribution distillation + confidence BCE and a pooled global-mean loss - scripts/train_dspark.py: training driver (clone of train_dflash.py) - configs/qwen3-8b-dspark.json, examples/run_qwen3_8b_dspark_online.sh - last_hidden_states surfaced from the DFlash target backends (HF + sglang) - tests/test_utils/test_dspark.py: 11 CPU unit tests Co-Authored-By: Claude Opus 4.8 --- configs/qwen3-8b-dspark.json | 49 ++ examples/run_qwen3_8b_dspark_online.sh | 51 ++ scripts/train_dspark.py | 719 ++++++++++++++++++ specforge/core/__init__.py | 2 + specforge/core/dspark.py | 334 ++++++++ specforge/modeling/draft/__init__.py | 12 + specforge/modeling/draft/dspark.py | 160 ++++ .../modeling/target/dflash_target_model.py | 39 +- tests/test_utils/test_dspark.py | 288 +++++++ 9 files changed, 1645 insertions(+), 9 deletions(-) create mode 100644 configs/qwen3-8b-dspark.json create mode 100644 examples/run_qwen3_8b_dspark_online.sh create mode 100644 scripts/train_dspark.py create mode 100644 specforge/core/dspark.py create mode 100644 specforge/modeling/draft/dspark.py create mode 100644 tests/test_utils/test_dspark.py diff --git a/configs/qwen3-8b-dspark.json b/configs/qwen3-8b-dspark.json new file mode 100644 index 000000000..4b757ea4e --- /dev/null +++ b/configs/qwen3-8b-dspark.json @@ -0,0 +1,49 @@ +{ + "architectures": [ + "DSparkDraftModel" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "auto_map": { + "AutoModel": "dspark.DSparkDraftModel" + }, + "block_size": 16, + "bos_token_id": 151643, + "dflash_config": { + "mask_token_id": 151669, + "target_layer_ids": [1, 9, 17, 25, 33] + }, + "dtype": "bfloat16", + "enable_confidence_head": true, + "confidence_head_with_markov": true, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 12288, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "markov_head_type": "vanilla", + "markov_rank": 256, + "max_position_embeddings": 40960, + "max_window_layers": 5, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 5, + "num_key_value_heads": 8, + "num_target_layers": 36, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": false, + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936 +} diff --git a/examples/run_qwen3_8b_dspark_online.sh b/examples/run_qwen3_8b_dspark_online.sh new file mode 100644 index 000000000..6d3936bea --- /dev/null +++ b/examples/run_qwen3_8b_dspark_online.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# DSpark online training for Qwen3-8B. +# +# DSpark = DFlash block-diffusion drafter + EAGLE-style Markov & confidence heads, +# trained with cross-entropy + L1 distribution distillation + confidence BCE. +# The L1 / confidence losses need the target model's FINAL hidden state, so the +# target backend must surface it. The 'hf' backend (default below) always does; +# the 'sglang' backend does when its runner returns both the captured aux stream +# and the final hidden state. To train CE-only (no target final hidden state), +# pass: --l1-loss-alpha 0 --no-confidence-head --ce-loss-alpha 1.0 + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels +export SPECFORGE_DATA_NUM_PROC=32 +NUM_GPUS=${1:-8} + +ATTENTION_BACKEND=${2:-flex_attention} +TARGET_BACKEND=${3:-hf} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_dspark.py \ + --target-model-path Qwen/Qwen3-8B \ + --draft-config-path $ROOT_DIR/configs/qwen3-8b-dspark.json \ + --train-data-path $ROOT_DIR/cache/dataset/perfectblend_qwen3-8b_regen.jsonl \ + --output-dir $ROOT_DIR/outputs/qwen3-8b-dspark \ + --num-epochs 6 \ + --batch-size 4 \ + --learning-rate 6e-4 \ + --warmup-ratio 0.04 \ + --max-grad-norm 1.0 \ + --max-length 3072 \ + --chat-template qwen \ + --attention-backend $ATTENTION_BACKEND \ + --loss-decay-gamma 4.0 \ + --log-interval 50 \ + --save-interval 1000 \ + --report-to wandb \ + --wandb-project specforge-qwen3-8b-dspark \ + --target-model-backend $TARGET_BACKEND \ + --block-size 16 \ + --num-anchors 512 \ + --markov-rank 256 \ + --enable-confidence-head \ + --confidence-head-with-markov \ + --ce-loss-alpha 0.1 \ + --l1-loss-alpha 0.9 \ + --confidence-head-alpha 1.0 \ + --wandb-name qwen3-8b-dspark-perfectblend diff --git a/scripts/train_dspark.py b/scripts/train_dspark.py new file mode 100644 index 000000000..ea67ca39d --- /dev/null +++ b/scripts/train_dspark.py @@ -0,0 +1,719 @@ +#!/usr/bin/env python3 +# coding=utf-8 +"""DSpark Training Script. + +DSpark = DFlash block-diffusion drafter + EAGLE-style Markov & confidence heads, +trained with cross-entropy + L1 distribution distillation + confidence BCE. The +L1 / confidence terms need the target model's FINAL hidden state, so the target +backend must surface it (HF always does; sglang does when it returns both the +captured aux stream and the final hidden state). Set ``--l1-loss-alpha 0`` and +``--no-confidence-head`` to train CE-only without the target final hidden state. + +Cloned from ``scripts/train_dflash.py`` and adapted: builds a DSparkDraftModel + +OnlineDSparkModel, plumbs ``last_hidden_states`` into the forward, and logs the +per-component (ce / l1 / confidence) losses. +""" + +import argparse +import functools +import logging +import math +import os +import shutil +import time +import warnings +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +from accelerate.utils import set_seed +from torch.distributed.fsdp import BackwardPrefetch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoConfig, AutoTokenizer + +from datasets import load_dataset +from specforge.args import SGLangBackendArgs, TrackerArgs +from specforge.core.dspark import OnlineDSparkModel +from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders +from specforge.distributed import destroy_distributed, get_dp_group, init_distributed +from specforge.modeling.draft.dspark import DSparkDraftModel +from specforge.modeling.target.dflash_target_model import ( + DFlashTargetModel, + get_dflash_target_model, +) +from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead +from specforge.optimizer import BF16Optimizer +from specforge.tracker import create_tracker +from specforge.utils import ( + get_last_checkpoint, + get_local_device, + print_on_rank0, + print_with_rank, +) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train DSpark Draft Model") + + model_group = parser.add_argument_group("model") + model_group.add_argument("--target-model-path", type=str, required=True) + model_group.add_argument( + "--target-model-backend", + type=str, + default="hf", + choices=["sglang", "hf"], + help="Backend for target model: 'sglang' (service) or 'hf' (local). " + "DSpark's L1/confidence losses need the target's final hidden state; " + "the 'hf' backend always surfaces it.", + ) + model_group.add_argument("--draft-config-path", type=str, default=None) + model_group.add_argument("--block-size", type=int, default=16) + model_group.add_argument("--num-draft-layers", type=int, default=1) + model_group.add_argument( + "--mask-token-id", + type=int, + default=None, + help="MASK token ID. If not provided, auto-detect from tokenizer.", + ) + model_group.add_argument( + "--attention-backend", + type=str, + default="flex_attention", + choices=["eager", "sdpa", "flex_attention"], + help="Attention backend for draft model.", + ) + model_group.add_argument( + "--trust-remote-code", action="store_true", help="Trust remote code" + ) + model_group.add_argument( + "--num-anchors", + type=int, + default=512, + help="Number of anchor positions per sequence", + ) + model_group.add_argument( + "--loss-decay-gamma", + type=float, + default=4.0, + help="Gamma for exponential within-block loss decay (exp(-k/gamma), " + "k = within-block slot index). None disables.", + ) + model_group.add_argument( + "--embedding-key", + type=str, + default=None, + help="Embedding weight key in the target model. " + "Default: 'model.embed_tokens.weight' for standard models.", + ) + model_group.add_argument( + "--lm-head-key", + type=str, + default=None, + help="LM head weight key in the target model. Default: 'lm_head.weight'.", + ) + + # DSpark-specific knobs + dspark_group = parser.add_argument_group("dspark") + dspark_group.add_argument( + "--markov-rank", + type=int, + default=256, + help="Rank of the low-rank Markov (bigram) bias head. 0 disables it.", + ) + dspark_group.add_argument( + "--markov-head-type", type=str, default="vanilla", choices=["vanilla"] + ) + dspark_group.add_argument( + "--enable-confidence-head", + action="store_true", + default=True, + help="Enable the per-position accept-rate (confidence) head.", + ) + dspark_group.add_argument( + "--no-confidence-head", + dest="enable_confidence_head", + action="store_false", + help="Disable the confidence head.", + ) + dspark_group.add_argument( + "--confidence-head-with-markov", + action="store_true", + default=True, + help="Fuse the Markov prev-token embedding into the confidence features.", + ) + dspark_group.add_argument( + "--ce-loss-alpha", type=float, default=0.1, help="Weight on cross-entropy." + ) + dspark_group.add_argument( + "--l1-loss-alpha", + type=float, + default=0.9, + help="Weight on L1 distribution distillation (needs target last hidden).", + ) + dspark_group.add_argument( + "--confidence-head-alpha", + type=float, + default=1.0, + help="Weight on the confidence-head BCE (needs target last hidden).", + ) + + dataset_group = parser.add_argument_group("dataset") + dataset_group.add_argument("--train-data-path", type=str, required=True) + dataset_group.add_argument("--eval-data-path", type=str, default=None) + dataset_group.add_argument("--chat-template", type=str, default="qwen") + dataset_group.add_argument("--is-preformatted", action="store_true") + dataset_group.add_argument("--dataloader-num-workers", type=int, default=8) + dataset_group.add_argument( + "--build-dataset-num-proc", + type=int, + default=int(os.environ.get("SPECFORGE_DATA_NUM_PROC", 8)), + ) + + training_group = parser.add_argument_group("training") + training_group.add_argument("--num-epochs", type=int, default=6) + training_group.add_argument("--batch-size", type=int, default=1) + training_group.add_argument("--learning-rate", type=float, default=6e-4) + training_group.add_argument("--max-length", type=int, default=3072) + training_group.add_argument("--warmup-ratio", type=float, default=0.04) + training_group.add_argument("--max-grad-norm", type=float, default=1.0) + training_group.add_argument("--accumulation-steps", type=int, default=1) + training_group.add_argument("--seed", type=int, default=42) + training_group.add_argument("--resume", action="store_true") + training_group.add_argument( + "--max-steps", + type=int, + default=None, + help="If set, stop after this many optimizer steps (smoke testing).", + ) + + output_group = parser.add_argument_group("output") + output_group.add_argument("--output-dir", type=str, required=True) + output_group.add_argument("--cache-dir", type=str, default="./cache") + output_group.add_argument("--log-interval", type=int, default=50) + output_group.add_argument("--eval-interval", type=int, default=1000) + output_group.add_argument("--save-interval", type=int, default=1000) + + optimization_group = parser.add_argument_group("optimization") + optimization_group.add_argument( + "--tp-size", + type=int, + default=1, + help="The size of the tensor parallel for the target model", + ) + + tracker_group = parser.add_argument_group("tracker") + TrackerArgs.add_args(tracker_group) + + dist_group = parser.add_argument_group("distributed") + dist_group.add_argument("--dist-timeout", type=int, default=30) + + # SGLang specific args + sglang_group = parser.add_argument_group("sglang backend") + SGLangBackendArgs.add_args(sglang_group) + + return parser.parse_args() + + +def _apply_dspark_config(draft_config, args) -> None: + """Set DSpark head fields on the draft config, preferring values already in + the config JSON and falling back to CLI args.""" + defaults = { + "markov_rank": args.markov_rank, + "markov_head_type": args.markov_head_type, + "enable_confidence_head": args.enable_confidence_head, + "confidence_head_with_markov": args.confidence_head_with_markov, + } + for key, value in defaults.items(): + if not hasattr(draft_config, key) or getattr(draft_config, key) is None: + setattr(draft_config, key, value) + + +def build_models(args, device) -> Tuple[DFlashTargetModel, DSparkDraftModel]: + """Build target model (backend wrapper) and DSpark draft model.""" + print_on_rank0( + f"Loading target model from {args.target_model_path} using " + f"{args.target_model_backend} backend" + ) + + target_model_kwargs = {} + if args.target_model_backend == "sglang": + target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs() + + device_type = device.type + + target_model = get_dflash_target_model( + pretrained_model_name_or_path=args.target_model_path, + backend=args.target_model_backend, + torch_dtype=torch.bfloat16, + device=device_type if args.target_model_backend == "hf" else None, + trust_remote_code=args.trust_remote_code, + **target_model_kwargs, + ) + + if args.draft_config_path: + draft_config = AutoConfig.from_pretrained(args.draft_config_path) + print_on_rank0(f"Loaded draft config from {args.draft_config_path}") + if ( + hasattr(draft_config, "block_size") + and draft_config.block_size != args.block_size + ): + print_on_rank0( + f"Warning: config block_size ({draft_config.block_size}) differs from " + f"command-line arg ({args.block_size}). Using config value." + ) + else: + target_config = AutoConfig.from_pretrained(args.target_model_path) + draft_config = AutoConfig.from_pretrained(args.target_model_path) + draft_config.num_hidden_layers = args.num_draft_layers + draft_config.block_size = args.block_size + draft_config.num_target_layers = target_config.num_hidden_layers + print_on_rank0("Auto-generated draft config from target model") + + if not hasattr(draft_config, "dflash_config") or draft_config.dflash_config is None: + draft_config.dflash_config = {} + + _apply_dspark_config(draft_config, args) + draft_config._attn_implementation = args.attention_backend + print_on_rank0(f"Using attention backend: {args.attention_backend}") + + draft_model = DSparkDraftModel(draft_config).to(device=device, dtype=torch.bfloat16) + + target_model.set_capture_layers(draft_model.target_layer_ids) + + print_on_rank0( + f"Draft config: block_size={draft_config.block_size}, " + f"num_hidden_layers={draft_config.num_hidden_layers}, " + f"num_target_layers={draft_config.num_target_layers}, " + f"markov_rank={getattr(draft_config, 'markov_rank', 0)}, " + f"enable_confidence_head={getattr(draft_config, 'enable_confidence_head', False)}" + ) + print_on_rank0( + f"Draft model parameters: {sum(p.numel() for p in draft_model.parameters()):,}" + ) + + return target_model, draft_model + + +def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]: + """Build train and eval dataloaders.""" + import hashlib + + cache_params_string = ( + f"{args.train_data_path}-" + f"{args.max_length}-" + f"{args.chat_template}-" + f"{args.target_model_path}" + ) + cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() + + train_dataset = load_dataset("json", data_files=args.train_data_path)["train"] + train_eagle3_dataset = build_eagle3_dataset( + dataset=train_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + is_preformatted=args.is_preformatted, + cache_dir=os.path.join(args.cache_dir, "processed_dataset"), + cache_key=cache_key, + num_proc=args.build_dataset_num_proc, + ) + + min_loss_tokens = 2 * args.block_size + original_size = len(train_eagle3_dataset) + train_eagle3_dataset = train_eagle3_dataset.filter( + lambda x: x["loss_mask"].sum() >= min_loss_tokens + ) + print_on_rank0( + f"Filtered train dataset: {original_size} -> {len(train_eagle3_dataset)} samples" + ) + + train_dataloader = prepare_dp_dataloaders( + train_eagle3_dataset, + args.batch_size, + num_workers=args.dataloader_num_workers, + shuffle=True, + process_group=get_dp_group(), + ) + + eval_dataloader = None + if args.eval_data_path: + eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"] + eval_eagle3_dataset = build_eagle3_dataset( + dataset=eval_dataset, + tokenizer=tokenizer, + chat_template=args.chat_template, + max_length=args.max_length, + is_preformatted=args.is_preformatted, + ) + eval_dataloader = prepare_dp_dataloaders( + eval_eagle3_dataset, + args.batch_size, + num_workers=args.dataloader_num_workers, + shuffle=False, + process_group=get_dp_group(), + ) + + return train_dataloader, eval_dataloader + + +def save_checkpoint(args, epoch, step, dspark_model, draft_model, optimizer): + """Save checkpoint.""" + save_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}") + if dist.get_rank() == 0: + os.makedirs(save_dir, exist_ok=True) + dist.barrier() + + with FSDP.state_dict_type(dspark_model, StateDictType.FULL_STATE_DICT): + state_dict = dspark_model.state_dict() + draft_state_dict = { + k.replace("draft_model.", ""): v + for k, v in state_dict.items() + if "draft_model." in k + } + + if dist.get_rank() == 0: + torch.save( + { + "epoch": epoch, + "global_step": step, + "args": args, + **optimizer.state_dict(), + }, + os.path.join(save_dir, "training_state.pt"), + ) + + draft_model.save_pretrained(save_dir, state_dict=draft_state_dict) + + # Copy the modeling files next to the checkpoint so auto_map can + # resolve DSparkDraftModel (which subclasses DFlashDraftModel) on + # reload with trust_remote_code. + modeling_dir = os.path.join( + os.path.dirname(__file__), "..", "specforge", "modeling", "draft" + ) + for fname in ("dspark.py", "dflash.py"): + src = os.path.join(modeling_dir, fname) + if os.path.exists(src): + shutil.copy(src, os.path.join(save_dir, fname)) + + print_on_rank0(f"Saved checkpoint to {save_dir}") + + dist.barrier() + + +def record_metrics( + args, + loss: float, + accuracy: float, + components: dict, + global_step: int, + tracker, + optimizer, + train_dataloader=None, + mode: str = "train", +) -> None: + logdict = {} + + if mode == "train" and optimizer is not None: + logdict["train/lr"] = optimizer.get_learning_rate() + + logdict[f"{mode}/loss"] = loss + logdict[f"{mode}/accuracy"] = accuracy + for key, value in components.items(): + logdict[f"{mode}/{key}"] = value + + comp_str = " ".join(f"{k}={v:.4f}" for k, v in components.items()) + print_on_rank0( + f"{mode.capitalize()} - Step {global_step}" + f"[{global_step}/{args.num_epochs * len(train_dataloader) // args.accumulation_steps}?]," + f" Loss: {loss:.4f}, Acc: {accuracy:.4f}, {comp_str}" + ) + + tracker.log(logdict, step=global_step) + + +def main(): + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logging.getLogger().setLevel(logging.INFO) + warnings.filterwarnings( + "ignore", + "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed", + ) + + args = parse_args() + set_seed(args.seed) + + init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size) + print_with_rank("Initialized distributed") + + device = get_local_device() + device_type = device.type + + needs_target_hidden = (args.l1_loss_alpha > 0) or ( + args.enable_confidence_head and args.confidence_head_alpha > 0 + ) + + draft_model_last_checkpoint = None + ckpt_info = (0, 0) + if args.resume and os.path.isdir(args.output_dir): + draft_model_last_checkpoint, ckpt_info = get_last_checkpoint(args.output_dir) + print(f"Last checkpoint detected: {draft_model_last_checkpoint}") + + if draft_model_last_checkpoint: + checkpoint_config_path = os.path.join( + draft_model_last_checkpoint, "config.json" + ) + if os.path.exists(checkpoint_config_path): + print(f"Loading draft config from checkpoint: {checkpoint_config_path}") + args.draft_config_path = checkpoint_config_path + + target_model, draft_model = build_models(args, device) + + resume_state = None + if draft_model_last_checkpoint: + loaded_model = DSparkDraftModel.from_pretrained( + draft_model_last_checkpoint, torch_dtype=torch.bfloat16 + ) + draft_model.load_state_dict(loaded_model.state_dict()) + del loaded_model + print("Loaded draft model weights from checkpoint") + + training_state_path = os.path.join( + draft_model_last_checkpoint, "training_state.pt" + ) + if os.path.exists(training_state_path): + resume_state = torch.load( + training_state_path, map_location="cpu", weights_only=False + ) + print( + f"Will resume from epoch {resume_state['epoch']}, " + f"step {resume_state['global_step']}" + ) + + tokenizer = AutoTokenizer.from_pretrained(args.target_model_path) + + if args.mask_token_id is not None: + mask_token_id = args.mask_token_id + elif tokenizer.mask_token_id is not None: + mask_token_id = tokenizer.mask_token_id + else: + tokenizer.add_special_tokens({"mask_token": "<|MASK|>"}) + mask_token_id = tokenizer.mask_token_id + print_on_rank0(f"Using mask_token_id: {mask_token_id}") + + draft_model.mask_token_id = mask_token_id + draft_model.config.dflash_config["mask_token_id"] = mask_token_id + draft_model.config.dflash_config["target_layer_ids"] = draft_model.target_layer_ids + print_on_rank0(f"dflash_config: {draft_model.config.dflash_config}") + + train_dataloader, eval_dataloader = build_dataloader(args, tokenizer) + + steps_per_epoch = math.ceil(len(train_dataloader) / args.accumulation_steps) + total_steps = args.num_epochs * steps_per_epoch + print_on_rank0(f"Total training steps: {total_steps}") + + print_on_rank0("Loading target embeddings and head...") + target_components = TargetEmbeddingsAndHead.from_pretrained( + args.target_model_path, + embed_key=args.embedding_key, + lm_head_key=args.lm_head_key, + device=device_type, + trust_remote_code=args.trust_remote_code, + ) + + dspark_model = OnlineDSparkModel( + draft_model=draft_model, + target_lm_head=target_components.lm_head, + target_embed_tokens=target_components.embed_tokens, + block_size=draft_model.block_size, + mask_token_id=mask_token_id, + attention_backend=args.attention_backend, + num_anchors=args.num_anchors, + loss_decay_gamma=args.loss_decay_gamma, + ce_loss_alpha=args.ce_loss_alpha, + l1_loss_alpha=args.l1_loss_alpha, + confidence_head_alpha=args.confidence_head_alpha, + ) + + # Wrap each transformer block as its own FSDP unit (compute/comm overlap). + fsdp_kwargs = dict( + use_orig_params=True, + forward_prefetch=True, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + limit_all_gathers=True, + mixed_precision=MixedPrecision( + param_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ), + sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, + ) + block_names = set(getattr(draft_model, "_no_split_modules", None) or []) + block_classes = { + type(m) for m in dspark_model.modules() if type(m).__name__ in block_names + } + if block_classes: + fsdp_kwargs["auto_wrap_policy"] = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=block_classes, + ) + else: + print_with_rank( + "No _no_split_modules on draft model; falling back to single-unit " + "FSDP wrap (no compute-comm overlap)." + ) + dspark_model = FSDP(dspark_model, **fsdp_kwargs) + print_with_rank("Initialized FSDP") + + start_epoch = ckpt_info[0] + global_step = ckpt_info[1] + + optimizer = BF16Optimizer( + draft_model, + lr=args.learning_rate, + max_grad_norm=args.max_grad_norm, + warmup_ratio=args.warmup_ratio, + total_steps=total_steps, + ) + + if resume_state is not None: + optimizer.load_state_dict(resume_state) + start_epoch = resume_state["epoch"] + global_step = resume_state["global_step"] + del resume_state + print_on_rank0( + f"Restored optimizer/scheduler state: " + f"epoch={start_epoch}, step={global_step}, " + f"lr={optimizer.get_learning_rate():.6f}" + ) + + skip_steps = global_step - start_epoch * len(train_dataloader) + + print_on_rank0(f"Initializing tracker (report_to={args.report_to})...") + tracker = create_tracker(args, args.output_dir) + print_on_rank0("Tracker initialized successfully.") + + last_time = time.time() + print_on_rank0(f"Starting training from epoch {start_epoch}, step {global_step}") + stop = False + + for epoch in range(start_epoch, args.num_epochs): + if stop: + break + train_dataloader.sampler.set_epoch(epoch) + draft_model.train() + + if dist.get_rank() == 0: + progress_bar = tqdm( + train_dataloader, desc=f"Training Epoch {epoch}", leave=True + ) + else: + progress_bar = train_dataloader + + for step_in_epoch, data in enumerate(progress_bar): + if epoch == start_epoch and step_in_epoch < skip_steps: + continue + global_step += 1 + + input_ids = data["input_ids"].to(device, non_blocking=True) + attention_mask = data["attention_mask"].to(device, non_blocking=True) + loss_mask = data["loss_mask"].to(device, non_blocking=True) + target_output = target_model.generate_dflash_data( + input_ids, attention_mask, loss_mask + ) + hidden_states = target_output.hidden_states.to(device, non_blocking=True) + + last_hidden_states = target_output.last_hidden_states + if last_hidden_states is not None: + last_hidden_states = last_hidden_states.to(device, non_blocking=True) + elif needs_target_hidden: + raise RuntimeError( + "DSpark L1/confidence losses are enabled but the target backend " + f"({args.target_model_backend}) did not surface last_hidden_states. " + "Use --target-model-backend hf, or run CE-only with " + "--l1-loss-alpha 0 --no-confidence-head." + ) + + ( + loss, + accuracy, + loss_per_position, + acc_per_position, + count_per_position, + loss_components, + ) = dspark_model( + input_ids=input_ids, + hidden_states=hidden_states, + loss_mask=loss_mask, + last_hidden_states=last_hidden_states, + ) + + (loss / args.accumulation_steps).backward() + + if global_step % args.accumulation_steps == 0: + optimizer.step() + + if global_step % args.log_interval == 0: + loss_log = loss.clone() + acc_log = accuracy.clone() + dist.all_reduce(loss_log) + dist.all_reduce(acc_log) + loss_log = loss_log / dist.get_world_size() + acc_log = acc_log / dist.get_world_size() + + comp_log = {} + for key, value in loss_components.items(): + v = value.clone().float() + dist.all_reduce(v) + comp_log[key] = (v / dist.get_world_size()).item() + + record_metrics( + args, + loss_log.item(), + acc_log.item(), + comp_log, + global_step, + tracker, + optimizer, + train_dataloader, + mode="train", + ) + + if dist.get_rank() == 0: + elapsed = time.time() - last_time + last_time = time.time() + progress_bar.set_postfix( + { + "loss": f"{loss.item():.4f}", + "acc": f"{accuracy.item():.4f}", + "iter_time": f"{elapsed:.2f}s", + } + ) + + if global_step % args.save_interval == 0: + save_checkpoint( + args, epoch, global_step, dspark_model, draft_model, optimizer + ) + + if args.max_steps is not None and global_step >= args.max_steps: + print_on_rank0(f"Reached max_steps={args.max_steps}; stopping.") + stop = True + break + + save_checkpoint( + args, args.num_epochs, global_step, dspark_model, draft_model, optimizer + ) + + tracker.close() + destroy_distributed() + + +if __name__ == "__main__": + main() diff --git a/specforge/core/__init__.py b/specforge/core/__init__.py index 4d5dcc644..d3b1de48e 100644 --- a/specforge/core/__init__.py +++ b/specforge/core/__init__.py @@ -1,10 +1,12 @@ from .dflash import OnlineDFlashModel from .domino import OnlineDominoModel +from .dspark import OnlineDSparkModel from .eagle3 import OnlineEagle3Model, QwenVLOnlineEagle3Model from .peagle import OnlinePEagleModel __all__ = [ "OnlineDFlashModel", + "OnlineDSparkModel", "OnlineDominoModel", "OnlineEagle3Model", "OnlinePEagleModel", diff --git a/specforge/core/dspark.py b/specforge/core/dspark.py new file mode 100644 index 000000000..d71571a9c --- /dev/null +++ b/specforge/core/dspark.py @@ -0,0 +1,334 @@ +# coding=utf-8 +"""DSpark online training wrapper: DFlash backbone + Markov / L1 / confidence losses. + +Ported from TorchSpec PR #129 (``torchspec/models/dspark.py``). Reuses SpecForge's +:class:`OnlineDFlashModel` anchor sampling, block-causal mask construction, and +MASK-token noise stream verbatim (via ``super()``), then layers on the DSpark +training objective: + + - Markov-biased draft logits (teacher-forced previous token). + - Cross-entropy against the ground-truth next tokens (hard labels). + - L1 distribution distillation: ``|softmax(draft) - softmax(target)|`` where the + target distribution is the frozen LM head applied to the *target's* final + hidden state at the aligned position (requires ``last_hidden_states``). + - Confidence head BCE against the empirical per-token accept rate. + +Combined: ``ce_alpha*ce + l1_alpha*l1 + confidence_alpha*confidence``. + +Loss formulation adapted from DeepSeek's DeepSpec (``deepspec/modeling/dspark/loss.py``, +MIT), including its pooled global-mean reduction: local numerators over a +cross-rank all-reduced denominator, scaled by world_size to cancel FSDP's mean +gradient reduction. + +Key SpecForge differences vs TorchSpec (see port notes in the PR): + - SpecForge's :class:`OnlineDFlashModel.forward` returns ``(loss, accuracy)``; + DSpark needs the per-component losses, so this forward returns a 6-tuple + ``(loss, accuracy, loss_per_position, acc_per_position, count_per_position, + loss_components)``. ``train_dspark.py`` consumes the extra elements. + - The target ``lm_head`` is a frozen ``nn.Linear`` module on the wrapper + (``self.lm_head``); the L1 path uses ``self.lm_head.weight`` for ``F.linear``. + - The fused multi-layer context feature (``hidden_states``) is produced upstream + by ``generate_dflash_data`` and fed straight to the draft as ``target_hidden``. +""" + +from typing import List, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from specforge.core.dflash import ( + FLEX_ATTENTION_AVAILABLE, + OnlineDFlashModel, + create_dflash_block_mask, + create_dflash_sdpa_mask, +) +from specforge.modeling.draft.dspark import DSparkDraftModel + + +class OnlineDSparkModel(OnlineDFlashModel): + """DSpark online training wrapper (DFlash backbone + Markov/L1/confidence heads).""" + + def __init__( + self, + draft_model: DSparkDraftModel, + target_lm_head: nn.Module, + target_embed_tokens: nn.Module, + mask_token_id: int, + block_size: int = 7, + attention_backend: str = "flex_attention", + num_anchors: int = 512, + loss_decay_gamma: Optional[float] = 4.0, + ce_loss_alpha: float = 0.1, + l1_loss_alpha: float = 0.9, + confidence_head_alpha: float = 1.0, + ): + # Reuse DFlash anchor/mask/noise machinery. loss_type="dflash" is only a + # placeholder to satisfy the parent validator — DSpark overrides forward() + # entirely and never dispatches on loss_type. + super().__init__( + draft_model=draft_model, + target_lm_head=target_lm_head, + target_embed_tokens=target_embed_tokens, + mask_token_id=mask_token_id, + block_size=block_size, + attention_backend=attention_backend, + num_anchors=num_anchors, + loss_decay_gamma=loss_decay_gamma, + loss_type="dflash", + ) + self.ce_loss_alpha = float(ce_loss_alpha) + self.l1_loss_alpha = float(l1_loss_alpha) + self.confidence_head_alpha = float(confidence_head_alpha) + + def _decay_weights(self, device: torch.device) -> torch.Tensor: + """exp(-k/gamma) over within-block position k (DeepSpec convention). + + Every slot 0..B-1 is a real prediction in DSpark (unlike DFlash, where + slot 0 is the masked anchor), so slot 0 (the first predicted token) gets + weight 1.0 and later slots decay. + """ + k = torch.arange(self.block_size, device=device).view(1, 1, -1) + if self.loss_decay_gamma is not None and self.loss_decay_gamma > 0: + return torch.exp(-k.float() / self.loss_decay_gamma) + return torch.ones_like(k, dtype=torch.float32) + + def forward( + self, + input_ids: torch.Tensor, + hidden_states: torch.Tensor, + loss_mask: torch.Tensor, + last_hidden_states: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + dict, + ]: + """DSpark training forward. + + ``hidden_states`` is the fused multi-layer context feature + ``[B, S, len(target_layer_ids)*hidden]`` (the draft model applies its + ``fc``/``hidden_norm`` internally). ``last_hidden_states`` is the target + model's final hidden state ``[B, S, hidden]`` (needed only for the L1 / + confidence objectives). + + Returns ``(loss, accuracy, loss_per_position, acc_per_position, + count_per_position, loss_components)``. ``loss`` is the combined + ce+l1+confidence objective; ``loss_components`` is a dict of detached + per-rank local-mean scalars (ce_loss / l1_loss / confidence_loss) for + logging. + """ + if self.attention_backend == "flex_attention" and not FLEX_ATTENTION_AVAILABLE: + raise ValueError( + "flex_attention is not available on this device; use sdpa/eager." + ) + bsz, seq_len = input_ids.shape + device = input_ids.device + + # ---- DFlash backbone (identical construction to OnlineDFlashModel.forward) ---- + anchor_positions, block_keep_mask = self._sample_anchor_positions( + seq_len, loss_mask, device + ) + n_blocks = anchor_positions.shape[1] + + noise_embedding = self._create_noise_embed( + input_ids, anchor_positions, block_keep_mask + ) + + context_position_ids = ( + torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1) + ) + draft_position_ids = self._create_position_ids(anchor_positions) + full_position_ids = torch.cat([context_position_ids, draft_position_ids], dim=1) + + if self.attention_backend == "flex_attention": + dflash_attn_mask = create_dflash_block_mask( + anchor_positions=anchor_positions, + block_keep_mask=block_keep_mask, + S=seq_len, + block_size=self.block_size, + device=device, + ) + else: + dflash_attn_mask = create_dflash_sdpa_mask( + anchor_positions=anchor_positions, + block_keep_mask=block_keep_mask, + S=seq_len, + block_size=self.block_size, + device=device, + ) + + draft_hidden = self.draft_model( + position_ids=full_position_ids, + noise_embedding=noise_embedding, + target_hidden=hidden_states, + attention_mask=dflash_attn_mask, + ) + hidden_4d = draft_hidden.view(bsz, n_blocks, self.block_size, -1) + + base_logits = self.lm_head(draft_hidden) + base_logits_4d = base_logits.view(bsz, n_blocks, self.block_size, -1) + vocab_size = base_logits_4d.size(-1) + + # ---- Labels + eval mask (DSpark / DeepSpec convention) ---- + # Slot j predicts the token at anchor+j+1 (the real anchor token seeds + # slot 0). All block_size slots are supervised — there is no masked anchor + # slot, unlike SpecForge DFlash which drops slot 0. + label_offsets = torch.arange(1, self.block_size + 1, device=device).view( + 1, 1, -1 + ) + label_indices = anchor_positions.unsqueeze(-1) + label_offsets # [B, nb, bs] + valid_label_mask = label_indices < seq_len + safe_label_indices = label_indices.clamp(max=seq_len - 1) + safe_label_indices = torch.where( + block_keep_mask.unsqueeze(-1), + safe_label_indices, + torch.zeros_like(safe_label_indices), + ) + + target_ids = torch.gather( + input_ids.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) # [B, nb, bs] + + # eval mask = contiguous supervised prefix per block (DeepSpec + # build_eval_mask): block kept, label in-bounds, target token supervised, + # then cumprod so a gap truncates the rest of the block. + target_loss_mask = torch.gather( + loss_mask.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) + eval_bool = ( + block_keep_mask.unsqueeze(-1) & valid_label_mask & (target_loss_mask > 0.5) + ) + eval_bool = eval_bool.to(torch.int32).cumprod(dim=-1).bool() + eval_mask = eval_bool.float() # [B, nb, bs] + + decay_weight_mask = eval_mask * self._decay_weights(device) + local_den = decay_weight_mask.sum() + + # ---- Markov-biased draft logits ---- + # prev token for slot j is the ground-truth token immediately before the + # one slot j predicts: slot 0's prev is the real anchor token, slot j's is + # target_ids[j-1]. Matches DeepSpec prev_token_ids. + anchor_token_ids = torch.gather(input_ids, 1, anchor_positions) # [B, nb] + prev_token_ids = torch.cat( + [anchor_token_ids.unsqueeze(-1), target_ids[:, :, :-1]], dim=-1 + ) + logits_4d = base_logits_4d + if self.draft_model.markov_head is not None: + logits_4d = self.draft_model.markov_head.apply_block_logits( + base_logits_4d, token_ids=prev_token_ids + ) + + # ---- Cross entropy (hard labels) ---- + flat_logits = logits_4d.reshape(-1, vocab_size) + flat_targets = target_ids.reshape(-1) + ce_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none").view( + bsz, n_blocks, self.block_size + ) + ce_num = (ce_per_token * decay_weight_mask).sum() + + # ---- L1 distribution distillation + accept rate ---- + l1_num = base_logits.new_zeros((), dtype=torch.float32) + accept_rate = None + need_target = (self.l1_loss_alpha > 0) or ( + self.draft_model.confidence_head is not None + and self.confidence_head_alpha > 0 + ) + if need_target: + if last_hidden_states is None: + raise ValueError( + "DSpark L1/confidence losses require target last_hidden_states; " + "ensure the target model surfaces its final hidden state." + ) + # target distribution for the token at label_indices = target LM head + # applied to the target hidden one position earlier (anchor+j). + tgt_idx = (safe_label_indices - 1).clamp(min=0) # [B, nb, bs] + hdim = last_hidden_states.size(-1) + gather_idx = tgt_idx.reshape(bsz, -1, 1).expand(-1, -1, hdim) + aligned_hidden = torch.gather(last_hidden_states, 1, gather_idx) + aligned_target_logits = F.linear( + aligned_hidden, self.lm_head.weight + ).view(bsz, n_blocks, self.block_size, vocab_size) + draft_probs = torch.softmax(logits_4d.float(), dim=-1) + target_probs = torch.softmax(aligned_target_logits.float(), dim=-1) + l1_per_token = (draft_probs - target_probs).abs().sum(dim=-1) # [B, nb, bs] + if self.l1_loss_alpha > 0: + l1_num = (l1_per_token * decay_weight_mask).sum() + accept_rate = (1.0 - 0.5 * l1_per_token).clamp(0.0, 1.0) + + # ---- Confidence head BCE ---- + conf_num = base_logits.new_zeros((), dtype=torch.float32) + if ( + self.draft_model.confidence_head is not None + and self.confidence_head_alpha > 0 + ): + if self.draft_model.confidence_head_with_markov: + prev_emb = self.draft_model.markov_head.get_prev_embeddings( + prev_token_ids + ).to(hidden_4d.dtype) + conf_features = torch.cat([hidden_4d, prev_emb], dim=-1) + else: + conf_features = hidden_4d + confidence_pred = self.draft_model.confidence_head(conf_features).float() + conf_bce = ( + F.binary_cross_entropy_with_logits( + confidence_pred, accept_rate.detach(), reduction="none" + ) + * decay_weight_mask + ) + conf_num = conf_bce.sum() + + # ---- Pooled global loss (DeepSpec _build_loss) ---- + # Local numerators over a cross-rank-summed denominator, x world_size to + # cancel FSDP's mean gradient reduction -> a true token-pooled global mean + # rather than a mean-of-per-rank-means. + # NOTE: uses the global training group size; correct for plain DP / ZeRO-2 + # (single shard group). With a multi-dim mesh (e.g. HSDP/USP) the FSDP + # shard group differs from world_size and this would need the shard group. + world_size = dist.get_world_size() if dist.is_initialized() else 1 + global_den = local_den.detach().clone() + if world_size > 1: + dist.all_reduce(global_den, op=dist.ReduceOp.SUM) + global_den = global_den + 1e-6 + loss = ( + self.ce_loss_alpha * ce_num / global_den + + self.l1_loss_alpha * l1_num / global_den + + self.confidence_head_alpha * conf_num / global_den + ) * world_size + + # Per-component loss values (per-rank local means) for logging only — lets + # you watch L1 fall while the greedy-CE proxy plateaus. + local_den_eps = local_den + 1e-6 + loss_components = { + "ce_loss": (ce_num / local_den_eps).detach(), + "l1_loss": (l1_num / local_den_eps).detach(), + "confidence_loss": (conf_num / local_den_eps).detach(), + } + + # ---- Metrics (cross-entropy based; all block_size slots are productive) ---- + with torch.no_grad(): + flat_binary = eval_mask.reshape(-1) + pred_ids = torch.argmax(flat_logits, dim=-1) + correct = (pred_ids == flat_targets) & (flat_binary > 0.5) + accuracy = correct.sum().float() / flat_binary.sum().clamp(min=1e-6) + + count_per_position = eval_mask.sum(dim=(0, 1)) + count_pp = count_per_position.clamp(min=1.0) + loss_per_position = (ce_per_token * eval_mask).sum(dim=(0, 1)) / count_pp + acc_per_position = ( + correct.view(bsz, n_blocks, self.block_size).float().sum(dim=(0, 1)) + / count_pp + ) + + return ( + loss, + accuracy, + loss_per_position, + acc_per_position, + count_per_position, + loss_components, + ) diff --git a/specforge/modeling/draft/__init__.py b/specforge/modeling/draft/__init__.py index 6130dcc63..23aa13afd 100644 --- a/specforge/modeling/draft/__init__.py +++ b/specforge/modeling/draft/__init__.py @@ -5,12 +5,24 @@ extract_context_feature, sample, ) +from .dspark import ( + AcceptRatePredictor, + DSparkConfig, + DSparkDraftModel, + VanillaMarkov, + build_markov_head, +) from .llama3_eagle import LlamaForCausalLMEagle3 from .peagle import PEagleDraftModel __all__ = [ "Eagle3DraftModel", "DFlashDraftModel", + "DSparkDraftModel", + "DSparkConfig", + "VanillaMarkov", + "AcceptRatePredictor", + "build_markov_head", "LlamaForCausalLMEagle3", "PEagleDraftModel", "build_target_layer_ids", diff --git a/specforge/modeling/draft/dspark.py b/specforge/modeling/draft/dspark.py new file mode 100644 index 000000000..3d7bda332 --- /dev/null +++ b/specforge/modeling/draft/dspark.py @@ -0,0 +1,160 @@ +# coding=utf-8 +"""DSpark draft model: DFlash backbone + EAGLE-style Markov and confidence heads. + +DSpark shares SpecForge's DFlash block-diffusion drafter (dual-source KV +injection via :class:`DFlashDraftModel`, anchor sampling, MASK-token noise +stream) and adds two heads on top: + + - Markov head: a low-rank learned bigram bias added to the draft logits, + conditioned on the (teacher-forced) previous token. Improves the per-token + distribution without touching the backbone. + - Confidence head (AcceptRatePredictor): predicts a per-draft-position + acceptance probability, trained against the empirical draft-vs-target + accept rate (used at inference time for adaptive block length). + +Ported from TorchSpec PR #129 (``torchspec/models/draft/dspark.py``). The Markov +/ confidence modeling code is adapted from DeepSeek's DeepSpec +(``deepspec/modeling/dspark/{markov_head,common}.py``, MIT License). + +SpecForge differences vs TorchSpec (load-bearing): + - There is no ``DFlashConfig``; SpecForge's :class:`DFlashDraftModel` uses a + plain ``Qwen3Config`` plus a ``config.dflash_config`` dict. So + :class:`DSparkConfig` subclasses ``Qwen3Config`` and declares the DSpark + fields as top-level attributes; DFlash-carried fields (``block_size``, + ``num_target_layers``, ``dflash_config``) stay as before. + - The draft model has no ``embed_tokens`` of its own (the embedding lives on + the target and is passed into the online wrapper), and the context + projection is ``self.fc`` (not ``context_proj``). The heads only depend on + ``config.hidden_size`` / ``config.vocab_size``, so this does not matter for + construction. +""" + +from typing import Optional + +import torch +import torch.nn as nn +from transformers.models.qwen3.modeling_qwen3 import Qwen3Config + +from specforge.modeling.draft.dflash import DFlashDraftModel + + +class DSparkConfig(Qwen3Config): + """Configuration for the DSpark draft model. + + Extends ``Qwen3Config`` (SpecForge's DFlash draft is config-light and reads a + plain ``Qwen3Config``). DSpark-specific fields are declared here; the + DFlash-carried fields (``block_size``, ``num_target_layers``, and the nested + ``dflash_config`` dict holding ``target_layer_ids`` / ``mask_token_id``) are + consumed by the :class:`DFlashDraftModel` base ``__init__`` and must be + present on the config object before constructing the model. + """ + + model_type = "dspark" + + def __init__( + self, + markov_rank: int = 256, + markov_head_type: str = "vanilla", + enable_confidence_head: bool = True, + confidence_head_with_markov: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.markov_rank = markov_rank + self.markov_head_type = markov_head_type + self.enable_confidence_head = enable_confidence_head + self.confidence_head_with_markov = confidence_head_with_markov + + +class VanillaMarkov(nn.Module): + """Low-rank learned bigram bias added to the draft logits. + + Adapted from DeepSpec's ``deepspec/modeling/dspark/markov_head.py``. + """ + + def __init__(self, *, vocab_size: int, markov_rank: int): + super().__init__() + self.vocab_size = int(vocab_size) + self.markov_rank = int(markov_rank) + self.markov_head_type = "vanilla" + assert ( + self.markov_rank > 0 + ), f"VanillaMarkov requires markov_rank > 0, got {self.markov_rank}." + self.markov_w1 = nn.Embedding(self.vocab_size, self.markov_rank) + self.markov_w2 = nn.Linear(self.markov_rank, self.vocab_size, bias=False) + + def get_prev_embeddings(self, token_ids: torch.Tensor) -> torch.Tensor: + return self.markov_w1(token_ids.long()) + + def project_bias(self, latent_states: torch.Tensor) -> torch.Tensor: + return self.markov_w2(latent_states) + + def compute_step_bias(self, token_ids: torch.Tensor) -> torch.Tensor: + return self.project_bias(self.get_prev_embeddings(token_ids)) + + def apply_block_logits( + self, + base_logits: torch.Tensor, + *, + token_ids: torch.Tensor, + ) -> torch.Tensor: + if base_logits.size(2) == 0: + return base_logits + return base_logits + self.compute_step_bias(token_ids) + + +class AcceptRatePredictor(nn.Module): + """Per-position acceptance-probability predictor (a single linear head). + + Adapted from DeepSpec's ``deepspec/modeling/dspark/common.py``. + """ + + def __init__(self, input_dim: int): + super().__init__() + self.proj = nn.Linear(int(input_dim), 1) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + return self.proj(features).squeeze(-1) + + +def build_markov_head(config) -> Optional[nn.Module]: + markov_rank = int(getattr(config, "markov_rank", 0)) + assert markov_rank >= 0, f"markov_rank must be >= 0, got {markov_rank}" + if markov_rank == 0: + return None + + markov_head_type = str(getattr(config, "markov_head_type", "vanilla")).lower() + if markov_head_type == "vanilla": + return VanillaMarkov(vocab_size=config.vocab_size, markov_rank=markov_rank) + raise NotImplementedError( + f"markov_head_type={markov_head_type!r} is not supported yet; only 'vanilla' " + "is implemented as it is recommended by the authors." + ) + + +class DSparkDraftModel(DFlashDraftModel): + """DSpark draft network: DFlash backbone + Markov / confidence heads.""" + + config_class = DSparkConfig + + def __init__(self, config) -> None: + super().__init__(config) + + self.markov_rank = int(getattr(config, "markov_rank", 0)) + self.confidence_head_with_markov = bool( + getattr(config, "confidence_head_with_markov", True) + ) + + self.markov_head = build_markov_head(config) + + self.confidence_head: Optional[nn.Module] = None + if getattr(config, "enable_confidence_head", False): + conf_input_dim = config.hidden_size + if self.confidence_head_with_markov: + if self.markov_head is None: + raise ValueError( + "confidence_head_with_markov=True requires a Markov head " + "(markov_rank > 0)." + ) + conf_input_dim += self.markov_rank + self.confidence_head = AcceptRatePredictor(conf_input_dim) diff --git a/specforge/modeling/target/dflash_target_model.py b/specforge/modeling/target/dflash_target_model.py index 0df938239..09186ad19 100644 --- a/specforge/modeling/target/dflash_target_model.py +++ b/specforge/modeling/target/dflash_target_model.py @@ -24,10 +24,16 @@ @dataclass class DFlashTargetOutput: - hidden_states: torch.Tensor # [batch, seq_len, hidden_size] + hidden_states: torch.Tensor # [batch, seq_len, n_capture*hidden_size] input_ids: torch.Tensor # [batch, seq_len] attention_mask: torch.Tensor # [batch, seq_len] loss_mask: torch.Tensor # [batch, seq_len] + # Target model's FINAL hidden state [batch, seq_len, hidden_size]. Optional: + # DFlash never reads it, but DSpark's L1 distribution-distillation and + # confidence-head losses need it (the frozen target LM head is applied to it + # to form the soft next-token distribution). None when the backend does not + # surface it (then DSpark must run CE-only). + last_hidden_states: Optional[torch.Tensor] = None class DFlashTargetModel(ABC): @@ -163,22 +169,26 @@ def _extend(self, reqs): output = output.logits_output input_lens = [len(req.origin_input_ids) for req in reqs] + # context = the captured (aux) mid-layer concat used by DFlash; final = the + # post-norm last-layer hidden, surfaced for DSpark's L1 / confidence losses + # (None if the runner only returned a single hidden stream). + final_list = None if ( hasattr(output, "aux_hidden_states") and output.aux_hidden_states is not None ): - hidden_states_list = torch.split( - output.aux_hidden_states, input_lens, dim=0 - ) + context_list = torch.split(output.aux_hidden_states, input_lens, dim=0) + if hasattr(output, "hidden_states") and output.hidden_states is not None: + final_list = torch.split(output.hidden_states, input_lens, dim=0) elif hasattr(output, "hidden_states") and output.hidden_states is not None: - hidden_states_list = torch.split(output.hidden_states, input_lens, dim=0) + context_list = torch.split(output.hidden_states, input_lens, dim=0) else: raise ValueError("SGLang output does not contain hidden states.") self.model_runner.req_to_token_pool.clear() self.model_runner.token_to_kv_pool_allocator.clear() - return hidden_states_list + return context_list, final_list @torch.no_grad() def generate_dflash_data( @@ -209,10 +219,15 @@ def generate_dflash_data( data_cache.append((curr_ids, curr_attn, curr_loss)) reqs.append(req) - hidden_states_list = self._extend(reqs) + context_list, final_list = self._extend(reqs) # Stack back to batch - hidden_states = torch.cat([h.unsqueeze(0) for h in hidden_states_list], dim=0) + hidden_states = torch.cat([h.unsqueeze(0) for h in context_list], dim=0) + last_hidden_states = None + if final_list is not None: + last_hidden_states = torch.cat( + [h.unsqueeze(0) for h in final_list], dim=0 + ) input_ids = torch.cat([d[0] for d in data_cache], dim=0) attention_mask = torch.cat([d[1] for d in data_cache], dim=0) loss_mask = torch.cat([d[2] for d in data_cache], dim=0) @@ -222,6 +237,7 @@ def generate_dflash_data( input_ids=input_ids, attention_mask=attention_mask, loss_mask=loss_mask, + last_hidden_states=last_hidden_states, ) @@ -269,7 +285,8 @@ def generate_dflash_data( use_cache=False, ) - # hidden_states[0] = embedding output; hidden_states[i+1] = layer i output + # hidden_states[0] = embedding output; hidden_states[i+1] = layer i output; + # hidden_states[-1] = final (post-norm) hidden, i.e. the LM-head input. offset = 1 selected = [] if self.capture_layer_ids is not None: @@ -279,11 +296,15 @@ def generate_dflash_data( else: hidden_states = outputs.hidden_states[-1] + # Final hidden state for DSpark's L1 / confidence losses (DFlash ignores it). + last_hidden_states = outputs.hidden_states[-1] + return DFlashTargetOutput( hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask, loss_mask=loss_mask, + last_hidden_states=last_hidden_states, ) diff --git a/tests/test_utils/test_dspark.py b/tests/test_utils/test_dspark.py new file mode 100644 index 000000000..11ba3ae4f --- /dev/null +++ b/tests/test_utils/test_dspark.py @@ -0,0 +1,288 @@ +# coding=utf-8 +"""Tests for DSpark (DFlash backbone + Markov / confidence heads + L1 distillation). + +Ported/adapted from TorchSpec PR #129 ``tests/test_dspark.py`` to SpecForge's +DFlash structure. Pins the DSpark wiring so future refactors can't silently break +the objective: + +1. DSparkConfig / DSparkDraftModel: head construction, base relationship. +2. forward returns the 6-tuple with detached per-component losses. +3. Loss-wiring invariants: + - internal identity: combined loss == ce_a*ce + l1_a*l1 + cf_a*conf (world 1) + - all-masked batch -> loss 0 + - gradients reach markov + confidence + backbone; target embedding stays frozen + - next-token convention: every within-block slot is supervised (B predictions) +4. Markov / confidence head unit math. +5. CE-only path runs without target last_hidden_states. + +Runs on CPU with ``attention_backend="sdpa"`` (flex_attention needs CUDA). The +real DSpark modules are loaded directly via importlib + stub parent packages so +the heavy ``specforge`` package ``__init__`` (sglang/TF target stack) is not +imported. Run with ``USE_TF=0`` on macOS to avoid the torch/TF OpenMP clash. +""" + +import importlib.util +import sys +import types +import unittest +from pathlib import Path + +import torch + +REPO = Path(__file__).resolve().parents[2] + + +def _stub_pkg(name: str, path: Path) -> types.ModuleType: + mod = types.ModuleType(name) + mod.__path__ = [str(path)] + sys.modules[name] = mod + return mod + + +def _load(modname: str, relpath: str) -> types.ModuleType: + spec = importlib.util.spec_from_file_location(modname, REPO / relpath) + mod = importlib.util.module_from_spec(spec) + sys.modules[modname] = mod + spec.loader.exec_module(mod) + return mod + + +# Build stub parent packages, then load the real leaf modules in dependency order. +_stub_pkg("specforge", REPO / "specforge") +_stub_pkg("specforge.core", REPO / "specforge" / "core") +_stub_pkg("specforge.modeling", REPO / "specforge" / "modeling") +_stub_pkg("specforge.modeling.draft", REPO / "specforge" / "modeling" / "draft") + +_load("specforge.modeling.draft.dflash", "specforge/modeling/draft/dflash.py") +_dspark_draft = _load( + "specforge.modeling.draft.dspark", "specforge/modeling/draft/dspark.py" +) +_load("specforge.core.dflash", "specforge/core/dflash.py") +_core_dspark = _load("specforge.core.dspark", "specforge/core/dspark.py") + +AcceptRatePredictor = _dspark_draft.AcceptRatePredictor +DSparkConfig = _dspark_draft.DSparkConfig +DSparkDraftModel = _dspark_draft.DSparkDraftModel +VanillaMarkov = _dspark_draft.VanillaMarkov +OnlineDSparkModel = _core_dspark.OnlineDSparkModel + +import torch.nn as nn # noqa: E402 + +CE_A, L1_A, CF_A = 0.1, 0.9, 1.0 + + +def _make_dspark_config( + H=64, + V=128, + num_target_layers=2, + num_hidden_layers=1, + markov_rank=16, + enable_confidence_head=True, + confidence_head_with_markov=True, +): + return DSparkConfig( + hidden_size=H, + intermediate_size=256, + num_hidden_layers=num_hidden_layers, + num_attention_heads=4, + num_key_value_heads=2, + vocab_size=V, + rms_norm_eps=1e-6, + max_position_embeddings=512, + rope_theta=10000.0, + sliding_window=None, + layer_types=["full_attention"] * num_hidden_layers, + attn_implementation="sdpa", + # DFlash-carried fields + block_size=4, + num_target_layers=num_target_layers, + dflash_config={"mask_token_id": V - 1}, + # DSpark fields + markov_rank=markov_rank, + markov_head_type="vanilla", + enable_confidence_head=enable_confidence_head, + confidence_head_with_markov=confidence_head_with_markov, + ) + + +def _make_dspark_model(block_size=4, num_anchors=6, H=64, V=128, **cfg_kw): + config = _make_dspark_config(H=H, V=V, **cfg_kw) + config.block_size = block_size + config._attn_implementation = "sdpa" + draft = DSparkDraftModel(config).to(dtype=torch.float32) + draft.mask_token_id = V - 1 + + target_embed = nn.Embedding(V, H) + target_lm_head = nn.Linear(H, V, bias=False) + target_embed.requires_grad_(False) + target_lm_head.requires_grad_(False) + + return OnlineDSparkModel( + draft_model=draft, + target_lm_head=target_lm_head, + target_embed_tokens=target_embed, + mask_token_id=V - 1, + block_size=block_size, + attention_backend="sdpa", + num_anchors=num_anchors, + loss_decay_gamma=4.0, + ce_loss_alpha=CE_A, + l1_loss_alpha=L1_A, + confidence_head_alpha=CF_A, + ).to(dtype=torch.float32) + + +def _batch(B=2, S=24, H=64, V=128, all_masked=False, seed=0): + g = torch.Generator().manual_seed(seed) + input_ids = torch.randint(0, V, (B, S), generator=g) + # SpecForge fused context feature: one captured layer -> width H. + hidden_states = torch.randn(B, S, H, generator=g) + loss_mask = torch.zeros(B, S) if all_masked else torch.ones(B, S) + if not all_masked: + loss_mask[:, :2] = 0 # prompt + last_hidden_states = torch.randn(B, S, H, generator=g) + return dict( + input_ids=input_ids, + hidden_states=hidden_states, + loss_mask=loss_mask, + last_hidden_states=last_hidden_states, + ) + + +class TestDSparkConfig(unittest.TestCase): + def test_subclasses_qwen3_and_attrs(self): + from transformers.models.qwen3.modeling_qwen3 import Qwen3Config + + cfg = _make_dspark_config(markov_rank=32) + self.assertIsInstance(cfg, Qwen3Config) + self.assertEqual(cfg.model_type, "dspark") + self.assertEqual(cfg.markov_rank, 32) + self.assertTrue(cfg.enable_confidence_head) + + def test_draft_model_heads(self): + cfg = _make_dspark_config(H=64, markov_rank=16) + m = DSparkDraftModel(cfg) + self.assertIsInstance(m.markov_head, VanillaMarkov) + self.assertIsInstance(m.confidence_head, AcceptRatePredictor) + # confidence input = hidden + markov_rank when fused + self.assertEqual(m.confidence_head.proj.in_features, 64 + 16) + + def test_no_heads(self): + cfg = _make_dspark_config( + markov_rank=0, + enable_confidence_head=False, + confidence_head_with_markov=False, + ) + m = DSparkDraftModel(cfg) + self.assertIsNone(m.markov_head) + self.assertIsNone(m.confidence_head) + + +class TestDSparkForward(unittest.TestCase): + def test_returns_six_tuple_with_detached_components(self): + m = _make_dspark_model() + out = m(**_batch()) + self.assertEqual(len(out), 6) + loss, acc, lpp, app, cpp, comps = out + self.assertEqual(set(comps), {"ce_loss", "l1_loss", "confidence_loss"}) + for v in comps.values(): + self.assertTrue(torch.isfinite(v).all()) + self.assertFalse(v.requires_grad) # detached for logging + self.assertTrue(torch.isfinite(loss)) + self.assertEqual(lpp.shape[0], m.block_size) + + def test_internal_loss_identity(self): + # At world_size==1 the combined loss must equal the alpha-weighted sum of + # the logged components (same denominator) — so the components are a + # faithful decomposition of what's actually optimized. + m = _make_dspark_model() + loss, _, _, _, _, comps = m(**_batch(seed=1)) + recomputed = ( + CE_A * comps["ce_loss"] + + L1_A * comps["l1_loss"] + + CF_A * comps["confidence_loss"] + ) + self.assertTrue( + torch.allclose(loss, recomputed, atol=1e-4), + f"{loss.item()} vs {recomputed.item()}", + ) + + def test_all_masked_raises_guard(self): + # SpecForge's anchor sampler refuses an all-masked sample (the dataloader + # filters these out upstream via min_loss_tokens = 2*block_size). This + # differs from TorchSpec, where an all-masked batch yields loss 0. The + # per-label masking ("masked tokens contribute zero") is still exercised + # by the prompt mask (loss_mask[:, :2]=0) in the other forward tests. + m = _make_dspark_model() + with self.assertRaises(ValueError): + m(**_batch(all_masked=True)) + + def test_next_token_convention_all_slots_supervised(self): + # Every within-block slot predicts a real token (B predictions), unlike + # DFlash where slot 0 is the masked anchor. With a long fully supervised + # sequence, every position should accumulate supervised tokens. + m = _make_dspark_model(block_size=4, num_anchors=8) + b = _batch(B=2, S=40) + b["loss_mask"] = torch.ones(2, 40) + _, _, _, _, count_per_position, _ = m(**b) + self.assertEqual(count_per_position.shape[0], 4) + self.assertTrue( + (count_per_position > 0).all(), + f"some slot unsupervised: {count_per_position.tolist()}", + ) + + def test_grad_flow_and_frozen_embedding(self): + m = _make_dspark_model() + loss, *_ = m(**_batch(seed=2)) + loss.backward() + draft = m.draft_model + self.assertIsNotNone(draft.markov_head.markov_w2.weight.grad) + self.assertGreater(draft.markov_head.markov_w2.weight.grad.abs().sum().item(), 0) + self.assertIsNotNone(draft.confidence_head.proj.weight.grad) + self.assertGreater( + draft.confidence_head.proj.weight.grad.abs().sum().item(), 0 + ) + # backbone context projection (SpecForge's `fc`) gets gradient + self.assertIsNotNone(draft.fc.weight.grad) + self.assertGreater(draft.fc.weight.grad.abs().sum().item(), 0) + # target embedding is frozen (lives on the wrapper, not the draft) + self.assertIsNone(m.embed_tokens.weight.grad) + + def test_ce_only_without_target(self): + # ce-only (l1=0, no confidence) must run without last_hidden_states. + m = _make_dspark_model( + markov_rank=16, + enable_confidence_head=False, + confidence_head_with_markov=False, + ) + m.l1_loss_alpha = 0.0 + m.ce_loss_alpha = 1.0 + m.confidence_head_alpha = 0.0 + b = _batch() + b["last_hidden_states"] = None + loss, *_ = m(**b) + self.assertTrue(torch.isfinite(loss)) + + +class TestHeadMath(unittest.TestCase): + def test_vanilla_markov_is_bigram_bias(self): + torch.manual_seed(0) + mk = VanillaMarkov(vocab_size=50, markov_rank=8) + base = torch.randn(2, 3, 4, 50) + prev = torch.randint(0, 50, (2, 3, 4)) + out = mk.apply_block_logits(base, token_ids=prev) + expected = base + mk.markov_w2(mk.markov_w1(prev)) + self.assertTrue(torch.allclose(out, expected, atol=1e-6)) + + def test_confidence_head_is_linear(self): + torch.manual_seed(0) + head = AcceptRatePredictor(20) + feats = torch.randn(2, 3, 4, 20) + out = head(feats) + expected = head.proj(feats).squeeze(-1) + self.assertTrue(torch.allclose(out, expected, atol=1e-6)) + self.assertEqual(out.shape, (2, 3, 4)) + + +if __name__ == "__main__": + unittest.main() From 3a4f4323f69df9d46a10ae42cbb7caeac97f8fd4 Mon Sep 17 00:00:00 2001 From: maocheng23 Date: Mon, 29 Jun 2026 12:32:17 -0700 Subject: [PATCH 2/2] =?UTF-8?q?style:=20fix=20lint=20=E2=80=94=20black=20+?= =?UTF-8?q?=20autoflake=20+=20executable=20shebang=20on=20DSpark=20files?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.8 --- examples/run_qwen3_8b_dspark_online.sh | 0 scripts/train_dspark.py | 0 specforge/core/dspark.py | 14 +++++++------- specforge/modeling/target/dflash_target_model.py | 4 +--- tests/test_utils/test_dspark.py | 6 +++--- 5 files changed, 11 insertions(+), 13 deletions(-) mode change 100644 => 100755 examples/run_qwen3_8b_dspark_online.sh mode change 100644 => 100755 scripts/train_dspark.py diff --git a/examples/run_qwen3_8b_dspark_online.sh b/examples/run_qwen3_8b_dspark_online.sh old mode 100644 new mode 100755 diff --git a/scripts/train_dspark.py b/scripts/train_dspark.py old mode 100644 new mode 100755 diff --git a/specforge/core/dspark.py b/specforge/core/dspark.py index d71571a9c..6d16e88ba 100644 --- a/specforge/core/dspark.py +++ b/specforge/core/dspark.py @@ -31,7 +31,7 @@ by ``generate_dflash_data`` and fed straight to the draft as ``target_hidden``. """ -from typing import List, Optional, Tuple +from typing import Optional, Tuple import torch import torch.distributed as dist @@ -226,9 +226,9 @@ def forward( # ---- Cross entropy (hard labels) ---- flat_logits = logits_4d.reshape(-1, vocab_size) flat_targets = target_ids.reshape(-1) - ce_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none").view( - bsz, n_blocks, self.block_size - ) + ce_per_token = F.cross_entropy( + flat_logits, flat_targets, reduction="none" + ).view(bsz, n_blocks, self.block_size) ce_num = (ce_per_token * decay_weight_mask).sum() # ---- L1 distribution distillation + accept rate ---- @@ -250,9 +250,9 @@ def forward( hdim = last_hidden_states.size(-1) gather_idx = tgt_idx.reshape(bsz, -1, 1).expand(-1, -1, hdim) aligned_hidden = torch.gather(last_hidden_states, 1, gather_idx) - aligned_target_logits = F.linear( - aligned_hidden, self.lm_head.weight - ).view(bsz, n_blocks, self.block_size, vocab_size) + aligned_target_logits = F.linear(aligned_hidden, self.lm_head.weight).view( + bsz, n_blocks, self.block_size, vocab_size + ) draft_probs = torch.softmax(logits_4d.float(), dim=-1) target_probs = torch.softmax(aligned_target_logits.float(), dim=-1) l1_per_token = (draft_probs - target_probs).abs().sum(dim=-1) # [B, nb, bs] diff --git a/specforge/modeling/target/dflash_target_model.py b/specforge/modeling/target/dflash_target_model.py index 09186ad19..27821d869 100644 --- a/specforge/modeling/target/dflash_target_model.py +++ b/specforge/modeling/target/dflash_target_model.py @@ -225,9 +225,7 @@ def generate_dflash_data( hidden_states = torch.cat([h.unsqueeze(0) for h in context_list], dim=0) last_hidden_states = None if final_list is not None: - last_hidden_states = torch.cat( - [h.unsqueeze(0) for h in final_list], dim=0 - ) + last_hidden_states = torch.cat([h.unsqueeze(0) for h in final_list], dim=0) input_ids = torch.cat([d[0] for d in data_cache], dim=0) attention_mask = torch.cat([d[1] for d in data_cache], dim=0) loss_mask = torch.cat([d[2] for d in data_cache], dim=0) diff --git a/tests/test_utils/test_dspark.py b/tests/test_utils/test_dspark.py index 11ba3ae4f..a43eb5a84 100644 --- a/tests/test_utils/test_dspark.py +++ b/tests/test_utils/test_dspark.py @@ -237,11 +237,11 @@ def test_grad_flow_and_frozen_embedding(self): loss.backward() draft = m.draft_model self.assertIsNotNone(draft.markov_head.markov_w2.weight.grad) - self.assertGreater(draft.markov_head.markov_w2.weight.grad.abs().sum().item(), 0) - self.assertIsNotNone(draft.confidence_head.proj.weight.grad) self.assertGreater( - draft.confidence_head.proj.weight.grad.abs().sum().item(), 0 + draft.markov_head.markov_w2.weight.grad.abs().sum().item(), 0 ) + self.assertIsNotNone(draft.confidence_head.proj.weight.grad) + self.assertGreater(draft.confidence_head.proj.weight.grad.abs().sum().item(), 0) # backbone context projection (SpecForge's `fc`) gets gradient self.assertIsNotNone(draft.fc.weight.grad) self.assertGreater(draft.fc.weight.grad.abs().sum().item(), 0)