From 15faa6ccbd3273970cdb57677917238f20d257be Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Wed, 24 Jun 2026 15:51:48 +0800 Subject: [PATCH] Add OPSD (On-Policy Distillation) training example Entry point, configs, data, and tests for on-policy distillation using DeepSpeed's hybrid engine rollout and vLLM backend. Signed-off-by: Guokai Ma Signed-off-by: Guokai Ma --- training/opsd/README.md | 222 ++++++++++++++++ training/opsd/configs/ds_zero3.json | 43 ++++ training/opsd/configs/opsd_hybrid_engine.json | 48 ++++ training/opsd/configs/opsd_vllm_disjoint.json | 54 ++++ training/opsd/configs/smoke_ds_zero0.json | 20 ++ training/opsd/configs/smoke_ds_zero3.json | 35 +++ training/opsd/configs/smoke_hybrid.json | 48 ++++ training/opsd/configs/smoke_hybrid_gc.json | 49 ++++ training/opsd/configs/smoke_vllm.json | 57 +++++ training/opsd/data/prompts.jsonl | 238 ++++++++++++++++++ training/opsd/main.py | 134 ++++++++++ training/opsd/requirements.txt | 5 + training/opsd/scripts/train_opsd_hybrid.sh | 14 ++ training/opsd/scripts/train_opsd_vllm.sh | 24 ++ training/opsd/tests/test_losses.py | 166 ++++++++++++ training/opsd/tests/test_teacher_caching.py | 101 ++++++++ 16 files changed, 1258 insertions(+) create mode 100644 training/opsd/README.md create mode 100644 training/opsd/configs/ds_zero3.json create mode 100644 training/opsd/configs/opsd_hybrid_engine.json create mode 100644 training/opsd/configs/opsd_vllm_disjoint.json create mode 100644 training/opsd/configs/smoke_ds_zero0.json create mode 100644 training/opsd/configs/smoke_ds_zero3.json create mode 100644 training/opsd/configs/smoke_hybrid.json create mode 100644 training/opsd/configs/smoke_hybrid_gc.json create mode 100644 training/opsd/configs/smoke_vllm.json create mode 100644 training/opsd/data/prompts.jsonl create mode 100644 training/opsd/main.py create mode 100644 training/opsd/requirements.txt create mode 100644 training/opsd/scripts/train_opsd_hybrid.sh create mode 100644 
training/opsd/scripts/train_opsd_vllm.sh create mode 100644 training/opsd/tests/test_losses.py create mode 100644 training/opsd/tests/test_teacher_caching.py diff --git a/training/opsd/README.md b/training/opsd/README.md new file mode 100644 index 000000000..3fce93c36 --- /dev/null +++ b/training/opsd/README.md @@ -0,0 +1,222 @@ +# On-Policy Distillation (OPSD) on DeepSpeed + +A DeepSpeed-native port of [HJSang/OPSD_OnPolicyDistillation](https://github.com/HJSang/OPSD_OnPolicyDistillation), +removing the verl dependency and building directly on DeepSpeed primitives +(ZeRO-3, hybrid engine, `deepspeed.initialize`). + +On-policy distillation trains a small **student** model to imitate a large +frozen **teacher** on the student's *own* generated rollouts. Each training +step has three phases: + +``` +┌────────────┐ prompts ┌──────────────────┐ prompt+response ┌────────────┐ +│ Dataloader │ ──────────▶ │ Student rollout │ ──────────────────▶ │ Teacher │ +└────────────┘ │ (hybrid / vLLM) │ │ forward │ + └──────────────────┘ └─────┬──────┘ + │ logits → CPU cache + ▼ + ┌─────────────────────┐ + │ Student forward + │ + │ streamed KL / JSD + │ + │ backward / step │ + └─────────────────────┘ +``` + +Loss = per-token divergence (`forward_kl` | `reverse_kl` | `jsd`) between +student and teacher distributions on the student's generated tokens, chunked +over the sequence axis so the full `[B, T, V]` teacher tensor never +co-resides with the student logits on the training device. 
+ +## Layout + +``` +training/opsd/ +├── main.py # entry point (deepspeed launcher) +├── opsd/ +│ ├── config.py # OPSDConfig dataclass + JSON loader +│ ├── losses.py # chunked / streamed KL & JSD +│ ├── teacher.py # frozen teacher + CPU logit cache +│ ├── trainer.py # three-phase training loop +│ ├── data.py # JSONL prompt dataset + left-pad collator +│ ├── utils.py # response-mask + shift helpers +│ └── rollout/ +│ ├── base.py # RolloutEngine ABC, request/batch dataclasses +│ ├── hybrid_engine.py # DeepSpeed hybrid-engine rollout +│ └── vllm.py # vLLM rollout on disjoint GPUs +├── configs/ +│ ├── ds_zero3.json # base DeepSpeed ZeRO-3 + hybrid engine +│ ├── opsd_hybrid_engine.json # production-ish hybrid-engine OPSD config +│ ├── opsd_vllm_disjoint.json # vLLM rollout on a disjoint GPU group +│ ├── smoke_hybrid.json # 5-step smoke test with Qwen2.5-0.5B / 1.5B +│ ├── smoke_vllm.json # same but with vLLM rollout +│ └── smoke_ds_zero3.json # ZeRO-3 config tuned for smoke runs +├── scripts/ +│ ├── train_opsd_hybrid.sh # launch hybrid-engine training +│ └── train_opsd_vllm.sh # launch vLLM training +└── tests/ # CPU-only unit tests (run with pytest) +``` + +## Quick start + +### Install + +``` +pip install deepspeed transformers datasets accelerate +# Optional, only for the vLLM rollout backend: +pip install 'vllm>=0.6.4' +``` + +### Hybrid-engine training (single-node, no vLLM) + +``` +cd training/opsd +NUM_GPUS=8 bash scripts/train_opsd_hybrid.sh configs/opsd_hybrid_engine.json +``` + +The hybrid engine path lives entirely within DeepSpeed: the student engine +both trains and generates, sharing weights without a copy step. Easiest to +get running; slower generation than vLLM. 
+ +### vLLM training (disjoint GPU group) + +``` +cd training/opsd +# Train on GPUs 0..5, run vLLM on 6,7 (matches default config) +NUM_TRAIN_GPUS=6 INCLUDE_GPUS=0,1,2,3,4,5 \ + bash scripts/train_opsd_vllm.sh configs/opsd_vllm_disjoint.json +``` + +vLLM gets dedicated GPUs (`rollout.gpus` in the config). Training rank 0 +constructs the `LLM` handle; other training ranks receive generated token +ids via NCCL broadcast. + +### Smoke tests (5 steps, small models) + +The `smoke_*.json` configs run on 2 GPUs in a few minutes with Qwen2.5-0.5B +(student) and Qwen2.5-1.5B (teacher), so the full pipeline can be validated +end-to-end before scaling up. + +``` +cd training/opsd +deepspeed --num_gpus 2 main.py --config configs/smoke_hybrid.json +# For vLLM (uses GPUs 0,1 for training and 2,3 for vLLM): +NUM_TRAIN_GPUS=2 INCLUDE_GPUS=0,1 deepspeed --include localhost:0,1 \ + main.py --config configs/smoke_vllm.json +``` + +## Unit tests + +The CPU-runnable test suite exercises the loss math, teacher caching, rollout +contract, and vLLM stitch logic. Run with: + +``` +cd training/opsd +python -m pytest tests/ -v +``` + +## Configuration + +`OPSDConfig` is a plain dataclass loaded from JSON (no Hydra). The schema: + +```json +{ + "student": { "model_name_or_path": "...", "dtype": "bfloat16", "arch": "qwen2" }, + "teacher": { "model_name_or_path": "...", "dtype": "bfloat16", "offload_to_cpu": true }, + "rollout": { "engine": "hybrid_engine | vllm", ... }, + "distillation": { "loss_type": "reverse_kl", "temperature": 1.0, "chunk_size": 512 }, + "training": { "train_batch_size": 8, "learning_rate": 1e-6, ... }, + "data": { "path": "data/prompts.jsonl", "prompt_field": "prompt" }, + "deepspeed_config": "configs/ds_zero3.json" +} +``` + +See `configs/opsd_hybrid_engine.json` and `configs/opsd_vllm_disjoint.json` +for fully-populated examples. + +## Adding a new model architecture + +No special steps are needed for new model architectures. 
vLLM's RLHF weight +transfer API handles TP slicing internally; the caller only needs to send full +tensors. + +## Design notes + +* **Why CPU-cache the teacher logits?** Holding both student and teacher + `[B, T, V]` tensors on GPU at once doubles memory pressure. Staging the + teacher to host between the teacher forward and the student backward halves + the worst-case GPU footprint of the loss path. The streamed loss + (`losses.streamed_distillation_loss`) pulls teacher chunks back to GPU + one sequence slice at a time so the full tensor never re-materialises. + +* **Why an abstract `RolloutEngine`?** The hybrid-engine and vLLM backends + have very different lifecycles (hybrid engine reads student weights live; + vLLM holds its own copy and must be synced) but the trainer should not + care. The ABC keeps the trainer engine-agnostic so additional backends + (e.g. a future colocated-vLLM-with-`sleep_mode`) drop in without touching + the loop. + +* **vLLM topology = disjoint, not colocated (v1).** The disjoint topology is + simpler to debug — failures in vLLM don't take down training and vice + versa. A colocated topology using vLLM 0.6.4+'s `sleep_mode` is planned as + a follow-up. + +* **Weight sync uses vLLM's RLHF API.** vLLM 0.22.0+ exposes + ``/update_weights`` which handles TP slicing internally. The trainer + sends full tensors and vLLM distributes them. + +## vLLM status + +The vLLM rollout (`opsd/rollout/vllm.py`) is **written and unit-tested but +not yet usable under the DeepSpeed launcher**. During live validation on +4× H200 we hit a blocking issue: + +> vLLM's worker init calls `new_group(...)` on the global process group as +> a collective. Under `deepspeed --num_gpus N`, the world is all `N` +> training ranks but only rank 0 calls into vLLM, so the constructor hangs +> waiting on the other ranks. Reproduced with vllm 0.6.6 + deepspeed 0.15.4 + +> torch 2.5.1. Standalone vLLM (world size 1) works in seconds. 
+ +The fix requires running vLLM in a **separate top-level Python process** +with its own world, accessed over HTTP/RPC from the trainer — the pattern +used by TRL and OpenRLHF. That's a larger refactor than fits in this PR; +the current `VLLMRollout` will be the basis for it once landed. + +What's verified for the vLLM path today: +* `tests/test_vllm_stitch.py` — prompt + response stitching (CPU unit test) +* `vllm.LLM` itself runs fine standalone on Qwen2.5-0.5B (validated) + +What's **not** verified: +* End-to-end training loop with `rollout.engine = "vllm"` in `OPSDConfig` +* `LLM.collective_rpc("load_weights", ...)` weight sync at training time + +The hybrid-engine path (`rollout.engine = "hybrid_engine"`) is validated +end-to-end on the same hardware. + +## Other known limitations (v1) + +* **vLLM weight sync (when it works) goes through pickle** — + `LLM.collective_rpc("load_weights", args=((name, tensor_on_cpu),))`. + Expect several seconds per sync on a 7B model. A faster v2 would broadcast + tensors via NCCL on a shared trainer↔vLLM process group — see verl's + `bucketed_weight_transfer.py` for a reference design. +* **vLLM `tensor_parallel_size > 1` is untested.** The weight bridge's + slicing math is unit-tested but no live run exists. +* **Reward-weighted distillation** (OPSD's `opd.reward_beta` knob) is not + ported. Easy to add: scale `per_tok` by a reward weight in the loss path. +* **GRPO and other on-policy RL recipes** are out of scope. The + `RolloutEngine` / `WeightBridge` abstractions are reusable, but a GRPO + trainer would add its own advantage / KL-to-reference logic on top. +* **Qwen3-MoE** is not covered. Add `weight_bridge/qwen3_moe.py` when needed. +* **Hybrid engine on Qwen-family models uses a ZeRO-3 fallback** (no + hybrid-engine inference acceleration), since DeepSpeed's inference policy + list only covers GPT2/GPT-NeoX/OPT/BLOOM/LLAMA/LLAMA2/InternLM as of 0.15. 
+ The fallback gathers params via `GatheredParameters` and calls the HF + model's `generate` directly — correct, just ~3-5x slower than the + accelerated path. + +## References + +* OPSD reference repo: https://github.com/HJSang/OPSD_OnPolicyDistillation +* DeepSpeed hybrid engine: `deepspeed/runtime/hybrid_engine.py` +* verl rollout / weight-sync design (used as a cross-check): + https://github.com/volcengine/verl diff --git a/training/opsd/configs/ds_zero3.json b/training/opsd/configs/ds_zero3.json new file mode 100644 index 000000000..1f43339a6 --- /dev/null +++ b/training/opsd/configs/ds_zero3.json @@ -0,0 +1,43 @@ +{ + "bf16": { + "enabled": true + }, + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "reduce_bucket_size": 5e7, + "stage3_prefetch_bucket_size": 5e7, + "stage3_param_persistence_threshold": 1e6, + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-6, + "betas": [0.9, 0.95], + "eps": 1e-8, + "weight_decay": 0.0 + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 1e-6, + "warmup_num_steps": 0 + } + }, + "gradient_clipping": 1.0, + "hybrid_engine": { + "enabled": true, + "max_out_tokens": 2048, + "inference_tp_size": 1, + "release_inference_cache": false, + "pin_parameters": true, + "tp_gather_partition_size": 8 + }, + "wall_clock_breakdown": false +} diff --git a/training/opsd/configs/opsd_hybrid_engine.json b/training/opsd/configs/opsd_hybrid_engine.json new file mode 100644 index 000000000..d2ebb8b03 --- /dev/null +++ b/training/opsd/configs/opsd_hybrid_engine.json @@ -0,0 +1,48 @@ +{ + "student": { + "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false + }, + "teacher": { + "model_name_or_path": "Qwen/Qwen2.5-Math-7B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false, + "offload_to_cpu": true + }, + "rollout": { + "engine": 
"hybrid_engine", + "max_prompt_length": 1024, + "max_response_length": 1024, + "temperature": 0, + "top_p": 1.0, + "top_k": -1, + "n_samples_per_prompt": 1, + "weight_sync_interval": 1 + }, + "distillation": { + "loss_type": "reverse_kl", + "temperature": 0, + "chunk_size": 512 + }, + "training": { + "train_batch_size": 1, + "micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 1e-6, + "weight_decay": 0.0, + "num_train_epochs": 1, + "max_steps": -1, + "warmup_steps": 0, + "save_steps": 500, + "logging_steps": 10, + "save_dir": "./opsd_ckpt_hybrid", + "seed": 42 + }, + "data": { + "path": "data/prompts.jsonl", + "prompt_field": "prompt", + "shuffle": true + }, + "deepspeed_config": "configs/ds_zero3.json" +} diff --git a/training/opsd/configs/opsd_vllm_disjoint.json b/training/opsd/configs/opsd_vllm_disjoint.json new file mode 100644 index 000000000..c98489df6 --- /dev/null +++ b/training/opsd/configs/opsd_vllm_disjoint.json @@ -0,0 +1,54 @@ +{ + "student": { + "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false + }, + "teacher": { + "model_name_or_path": "Qwen/Qwen2.5-Math-7B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false, + "offload_to_cpu": true + }, + "rollout": { + "engine": "vllm", + "max_prompt_length": 1024, + "max_response_length": 1024, + "temperature": 0, + "top_p": 1.0, + "top_k": -1, + "n_samples_per_prompt": 1, + "gpus": [6, 7], + "tensor_parallel_size": 2, + "gpu_memory_utilization": 0.85, + "vllm_dtype": "bfloat16", + "weight_sync_interval": 4, + "vllm_min_version": "0.6.4", + "vllm_port": 8000 + }, + "distillation": { + "loss_type": "reverse_kl", + "temperature": 0, + "chunk_size": 512 + }, + "training": { + "train_batch_size": 1, + "micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 1e-6, + "weight_decay": 0.0, + "num_train_epochs": 1, + "max_steps": -1, + "warmup_steps": 0, + "save_steps": 500, + "logging_steps": 
10, + "save_dir": "./opsd_ckpt_vllm", + "seed": 42 + }, + "data": { + "path": "data/prompts.jsonl", + "prompt_field": "prompt", + "shuffle": true + }, + "deepspeed_config": "configs/ds_zero3.json" +} diff --git a/training/opsd/configs/smoke_ds_zero0.json b/training/opsd/configs/smoke_ds_zero0.json new file mode 100644 index 000000000..26d9e8495 --- /dev/null +++ b/training/opsd/configs/smoke_ds_zero0.json @@ -0,0 +1,20 @@ +{ + "bf16": { + "enabled": true + }, + "zero_optimization": { + "stage": 0 + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-6, + "betas": [0.9, 0.95], + "eps": 1e-8, + "weight_decay": 0.0, + "torch_adam": true + } + }, + "gradient_clipping": 1.0, + "wall_clock_breakdown": false +} diff --git a/training/opsd/configs/smoke_ds_zero3.json b/training/opsd/configs/smoke_ds_zero3.json new file mode 100644 index 000000000..74211f3fb --- /dev/null +++ b/training/opsd/configs/smoke_ds_zero3.json @@ -0,0 +1,35 @@ +{ + "bf16": { + "enabled": true + }, + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "reduce_bucket_size": 5e7, + "stage3_prefetch_bucket_size": 5e7, + "stage3_param_persistence_threshold": 1e6, + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-6, + "betas": [0.9, 0.95], + "eps": 1e-8, + "weight_decay": 0.0 + } + }, + "gradient_clipping": 1.0, + "hybrid_engine": { + "enabled": true, + "max_out_tokens": 512, + "inference_tp_size": 1, + "release_inference_cache": false, + "pin_parameters": true, + "tp_gather_partition_size": 8 + }, + "wall_clock_breakdown": false +} diff --git a/training/opsd/configs/smoke_hybrid.json b/training/opsd/configs/smoke_hybrid.json new file mode 100644 index 000000000..774092926 --- /dev/null +++ b/training/opsd/configs/smoke_hybrid.json @@ -0,0 +1,48 @@ +{ + "student": { + "model_name_or_path": 
"Qwen/Qwen2.5-0.5B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false + }, + "teacher": { + "model_name_or_path": "Qwen/Qwen2.5-1.5B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false, + "offload_to_cpu": false + }, + "rollout": { + "engine": "hybrid_engine", + "max_prompt_length": 128, + "max_response_length": 64, + "temperature": 0, + "top_p": 1.0, + "top_k": -1, + "n_samples_per_prompt": 1, + "weight_sync_interval": 1 + }, + "distillation": { + "loss_type": "reverse_kl", + "temperature": 0, + "chunk_size": 128 + }, + "training": { + "train_batch_size": 1, + "micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 1e-6, + "weight_decay": 0.0, + "num_train_epochs": 1, + "max_steps": 5, + "warmup_steps": 0, + "save_steps": 10000, + "logging_steps": 1, + "save_dir": "./opsd_smoke_hybrid_ckpt", + "seed": 42 + }, + "data": { + "path": "data/prompts.jsonl", + "prompt_field": "prompt", + "shuffle": true + }, + "deepspeed_config": "configs/smoke_ds_zero0.json" +} diff --git a/training/opsd/configs/smoke_hybrid_gc.json b/training/opsd/configs/smoke_hybrid_gc.json new file mode 100644 index 000000000..0512c1581 --- /dev/null +++ b/training/opsd/configs/smoke_hybrid_gc.json @@ -0,0 +1,49 @@ +{ + "student": { + "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false + }, + "teacher": { + "model_name_or_path": "Qwen/Qwen2.5-1.5B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false, + "offload_to_cpu": false + }, + "rollout": { + "engine": "hybrid_engine", + "max_prompt_length": 128, + "max_response_length": 64, + "temperature": 0, + "top_p": 1.0, + "top_k": -1, + "n_samples_per_prompt": 1, + "use_graph_capture": true, + "weight_sync_interval": 1 + }, + "distillation": { + "loss_type": "reverse_kl", + "temperature": 1.0, + "chunk_size": 128 + }, + "training": { + "train_batch_size": 1, + "micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + 
"learning_rate": 1e-6, + "weight_decay": 0.0, + "num_train_epochs": 1, + "max_steps": 5, + "warmup_steps": 0, + "save_steps": 10000, + "logging_steps": 1, + "save_dir": "./opsd_smoke_gc_ckpt", + "seed": 42 + }, + "data": { + "path": "data/prompts.jsonl", + "prompt_field": "prompt", + "shuffle": true + }, + "deepspeed_config": "configs/smoke_ds_zero0.json" +} diff --git a/training/opsd/configs/smoke_vllm.json b/training/opsd/configs/smoke_vllm.json new file mode 100644 index 000000000..fe375e602 --- /dev/null +++ b/training/opsd/configs/smoke_vllm.json @@ -0,0 +1,57 @@ +{ + "student": { + "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false + }, + "teacher": { + "model_name_or_path": "Qwen/Qwen2.5-1.5B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false, + "offload_to_cpu": false + }, + "rollout": { + "engine": "vllm", + "max_prompt_length": 128, + "max_response_length": 64, + "temperature": 0, + "top_p": 1.0, + "top_k": -1, + "n_samples_per_prompt": 1, + "gpus": [], + "tensor_parallel_size": 1, + "gpu_memory_utilization": 0.3, + "vllm_dtype": "bfloat16", + "weight_sync_interval": 2, + "vllm_min_version": "0.6.4", + "vllm_enforce_eager": true, + "vllm_port": 8000, + "vllm_python": "/root/miniconda3/envs/vllm/bin/python", + "weight_transfer_backend": "gdr" + }, + "distillation": { + "loss_type": "reverse_kl", + "temperature": 0, + "chunk_size": 128 + }, + "training": { + "train_batch_size": 1, + "micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 1e-6, + "weight_decay": 0.0, + "num_train_epochs": 1, + "max_steps": 5, + "warmup_steps": 0, + "save_steps": 10000, + "logging_steps": 1, + "save_dir": "./opsd_smoke_vllm_ckpt", + "seed": 42 + }, + "data": { + "path": "data/prompts.jsonl", + "prompt_field": "prompt", + "shuffle": true + }, + "deepspeed_config": "configs/smoke_ds_zero0.json" +} diff --git a/training/opsd/data/prompts.jsonl b/training/opsd/data/prompts.jsonl new 
file mode 100644 index 000000000..bf0dba878 --- /dev/null +++ b/training/opsd/data/prompts.jsonl @@ -0,0 +1,238 @@ +{"prompt": "Solve: 17 + 25 = ?"} +{"prompt": "What is 12 multiplied by 8?"} +{"prompt": "If a train travels 60 miles per hour for 3 hours, how far does it go?"} +{"prompt": "What is the square root of 144?"} +{"prompt": "Compute 15% of 240."} +{"prompt": "A rectangle has length 7 and width 4. What is its area?"} +{"prompt": "Solve for x: 2x + 5 = 17."} +{"prompt": "What is 7 factorial?"} +{"prompt": "Compute the sum of integers from 1 to 10."} +{"prompt": "What is 2 to the power of 10?"} +{"prompt": "Find the perimeter of a square with side length 9."} +{"prompt": "If 5 apples cost 2.50, what is the cost of 12 apples?"} +{"prompt": "What is the greatest common divisor of 24 and 36?"} +{"prompt": "Convert 0.75 to a fraction in simplest form."} +{"prompt": "If x + y = 10 and x - y = 4, find x and y."} +{"prompt": "What is 1/4 + 1/3?"} +{"prompt": "A circle has radius 5. What is its area?"} +{"prompt": "Compute (3 + 4) * (5 - 2)."} +{"prompt": "What is 81 divided by 9?"} +{"prompt": "If a number doubled is 18, what is the number?"} +{"prompt": "What is 3/5 expressed as a percentage?"} +{"prompt": "Calculate the area of a triangle with base 10 and height 6."} +{"prompt": "What is the least common multiple of 4 and 6?"} +{"prompt": "If a shirt costs 25 after a 20% discount, what was the original price?"} +{"prompt": "Simplify: 2(3x + 4) - x."} +{"prompt": "What is the value of pi rounded to 4 decimal places?"} +{"prompt": "How many sides does a hexagon have?"} +{"prompt": "Compute 2^3 + 3^2."} +{"prompt": "If you roll a standard die, what is the probability of getting a 4?"} +{"prompt": "What is the average of 12, 15, and 18?"} +{"prompt": "Solve: 5x - 3 = 22."} +{"prompt": "What is the volume of a cube with side length 4?"} +{"prompt": "Convert 3 kilometers to meters."} +{"prompt": "What is 13 squared?"} +{"prompt": "If a car uses 8 liters per 100km, how 
much for 350km?"} +{"prompt": "What is the median of 3, 7, 9, 12, 15?"} +{"prompt": "Calculate 25 * 4 + 30 / 6."} +{"prompt": "What is the factorial of 5?"} +{"prompt": "If 3x = 27, what is x?"} +{"prompt": "What is 10% of 0.5?"} +{"prompt": "Simplify the fraction 18/24."} +{"prompt": "What is the next prime number after 7?"} +{"prompt": "How many degrees are in a right angle?"} +{"prompt": "Compute 1/2 * 3/4."} +{"prompt": "What is the surface area of a cube with side 3?"} +{"prompt": "If a population grows by 10% per year from 1000, what is it after 2 years?"} +{"prompt": "What is the absolute value of -15?"} +{"prompt": "Solve: x^2 = 49."} +{"prompt": "How many minutes are in 2.5 hours?"} +{"prompt": "What is 0.1 + 0.02 + 0.003?"} +{"prompt": "A bag has 3 red and 5 blue marbles. What is the probability of picking red?"} +{"prompt": "What is the perimeter of a rectangle with sides 8 and 12?"} +{"prompt": "Compute the cube root of 27."} +{"prompt": "If y = 2x + 1 and x = 5, what is y?"} +{"prompt": "What is the difference between 100 and 37?"} +{"prompt": "How many edges does a rectangular prism have?"} +{"prompt": "Simplify: (x + 2)(x - 2)."} +{"prompt": "What is 4! divided by 2!?"} +{"prompt": "Convert 5/8 to a decimal."} +{"prompt": "What is the hypotenuse of a right triangle with legs 3 and 4?"} +{"prompt": "What is 999 + 1?"} +{"prompt": "If you save 5 per day, how much in 30 days?"} +{"prompt": "What is the reciprocal of 7?"} +{"prompt": "Compute log10(1000)."} +{"prompt": "A pizza is cut into 8 equal slices. 
If you eat 3, what fraction remains?"} +{"prompt": "What is the sum of angles in a triangle?"} +{"prompt": "Round 3.14159 to 2 decimal places."} +{"prompt": "What is 50% of 50% of 200?"} +{"prompt": "If a = 3 and b = 4, what is a^2 + b^2?"} +{"prompt": "How many factors does 12 have?"} +{"prompt": "What is the negative of -7?"} +{"prompt": "Express 0.125 as a fraction."} +{"prompt": "What is the slope of the line y = 3x + 5?"} +{"prompt": "A clock shows 3:15. What is the angle between the hour and minute hands?"} +{"prompt": "What is 11 * 11?"} +{"prompt": "If gas costs 3.50 per gallon and you buy 10 gallons, what is the total?"} +{"prompt": "What are the first 3 multiples of 7?"} +{"prompt": "How many zeros are in one million?"} +{"prompt": "What is 2/3 + 2/3?"} +{"prompt": "Compute the area of a circle with diameter 10."} +{"prompt": "If a book has 300 pages and you read 45 per day, how many days to finish?"} +{"prompt": "What is the value of 5^0?"} +{"prompt": "Solve: 4(x - 1) = 20."} +{"prompt": "What is the complement of a 35 degree angle?"} +{"prompt": "How many distinct permutations of the word MATH?"} +{"prompt": "What is 1/10 as a percentage?"} +{"prompt": "If temperature drops from 15C to -3C, what is the change?"} +{"prompt": "What is the greatest common factor of 18 and 30?"} +{"prompt": "A train is 200m long traveling at 20m/s. How long to pass a pole?"} +{"prompt": "What is the sum of the first 5 odd numbers?"} +{"prompt": "Convert 45 degrees Celsius to Fahrenheit."} +{"prompt": "What is 0.001 * 1000?"} +{"prompt": "How many diagonals does a pentagon have?"} +{"prompt": "Simplify: 6 + 3 * 2."} +{"prompt": "What is 20% of 20% of 500?"} +{"prompt": "If you flip a coin 3 times, how many possible outcomes?"} +{"prompt": "What is the ratio of 15 to 25 in simplest form?"} +{"prompt": "Find x if 3/5 = x/25."} +{"prompt": "What is the mean of 2, 4, 6, 8, 10?"} +{"prompt": "What is 7 * 8 + 6 / 2?"} +{"prompt": "A cylinder has radius 3 and height 10. 
What is its volume?"} +{"prompt": "What is the smallest prime number?"} +{"prompt": "If f(x) = x^2 + 1, what is f(3)?"} +{"prompt": "How many seconds in one hour?"} +{"prompt": "What is the result of 100 mod 7?"} +{"prompt": "Simplify: sqrt(50) / sqrt(2)."} +{"prompt": "What is the distance between points (1,2) and (4,6)?"} +{"prompt": "A recipe needs 2 cups flour for 12 cookies. How many cups for 30 cookies?"} +{"prompt": "What is 1.5 * 2.5?"} +{"prompt": "What is A intersection B if A = {1,2,3} and B = {2,3,4}?"} +{"prompt": "What is the 10th term of the arithmetic sequence 3, 7, 11, 15?"} +{"prompt": "How many cubic centimeters in a cubic meter?"} +{"prompt": "What is the value of 2^10?"} +{"prompt": "Solve the inequality: 2x > 10."} +{"prompt": "What is 3/7 rounded to 2 decimal places?"} +{"prompt": "What is the tangent of 45 degrees?"} +{"prompt": "How many ways to choose 2 items from 5?"} +{"prompt": "What is the product of all integers from 1 to 5?"} +{"prompt": "If 8 workers finish a job in 6 days, how many days for 12 workers?"} +{"prompt": "What is 1000 - 587?"} +{"prompt": "Express 2500 in scientific notation."} +{"prompt": "What is the sum of interior angles of a hexagon?"} +{"prompt": "What is the decimal equivalent of the binary number 1010?"} +{"prompt": "What is the area of a trapezoid with bases 6 and 10 and height 4?"} +{"prompt": "Calculate 15 * 15."} +{"prompt": "What is the supplementary angle of 110 degrees?"} +{"prompt": "A store buys an item for 40 and sells for 60. 
What is the markup percentage?"} +{"prompt": "Solve: |x - 3| = 5."} +{"prompt": "How many days are in a leap year?"} +{"prompt": "What is the compound interest on 1000 at 10% for 2 years?"} +{"prompt": "If the base of a triangle is 8 and height is 5, what is the area?"} +{"prompt": "What is 100 divided by 3 rounded to 2 decimal places?"} +{"prompt": "What is the 7th Fibonacci number?"} +{"prompt": "Convert 1 mile to feet."} +{"prompt": "What is the LCM of 6, 8, and 12?"} +{"prompt": "Simplify: 4(x + 3) - 2(x - 1)."} +{"prompt": "If 3a + 2b = 16 and a = 4, what is b?"} +{"prompt": "What is the sine of 30 degrees?"} +{"prompt": "How many ways can 4 people sit in a row?"} +{"prompt": "What is 0.5^3?"} +{"prompt": "Find the 20th term of 5, 8, 11, 14"} +{"prompt": "A triangle has sides 3, 4, 5. What type of triangle is it?"} +{"prompt": "What is the absolute difference between -5 and 3?"} +{"prompt": "How many grams in 2.5 kilograms?"} +{"prompt": "What is the product of -3 and -7?"} +{"prompt": "If a clock shows 9:00, what is the angle between the hands?"} +{"prompt": "What is the square root of 81?"} +{"prompt": "What is 1/3 + 1/6 + 1/12?"} +{"prompt": "If x^2 - 4 = 0, what are the solutions?"} +{"prompt": "What is the geometric mean of 4 and 16?"} +{"prompt": "Convert 72 km/h to m/s."} +{"prompt": "What is the value of cos(60 degrees)?"} +{"prompt": "A box has 5 red, 3 green, 2 blue balls. What is P(not red)?"} +{"prompt": "What is 2^0 + 2^1 + 2^2 + 2^3?"} +{"prompt": "Find the slope between points (1, 3) and (4, 9)."} +{"prompt": "What is the sum of the first 20 natural numbers?"} +{"prompt": "What is the value of e rounded to 3 decimal places?"} +{"prompt": "How many total degrees in a quadrilateral?"} +{"prompt": "Simplify: (2^3 * 2^4) / 2^5."} +{"prompt": "What is the probability of drawing a king from a standard deck?"} +{"prompt": "A car travels 180 miles in 3 hours. 
What is its average speed?"} +{"prompt": "What is the decimal 0.375 as a fraction?"} +{"prompt": "Solve: 2(x + 5) = 3(x - 1)."} +{"prompt": "How many milliliters in 3 liters?"} +{"prompt": "What is the cube of 5?"} +{"prompt": "What is 5/6 as a repeating decimal?"} +{"prompt": "Find the circumference of a circle with radius 7."} +{"prompt": "If 2 pipes fill a tank in 6 and 12 hours, how long together?"} +{"prompt": "What is the coefficient of x in 3x^2 + 5x - 7?"} +{"prompt": "What is the result of (10^3) / (10^-1)?"} +{"prompt": "Find the GCD of 48 and 72."} +{"prompt": "What is the domain of f(x) = sqrt(x)?"} +{"prompt": "Simplify: 8/12 - 3/12."} +{"prompt": "What is the arithmetic mean of the first 10 even numbers?"} +{"prompt": "Convert -40 Celsius to Fahrenheit."} +{"prompt": "What is the median of 1, 3, 5, 7, 9, 11?"} +{"prompt": "What is the next number: 1, 1, 2, 3, 5, 8, 13?"} +{"prompt": "If 4 workers can paint a fence in 8 hours, how long for 2 workers?"} +{"prompt": "What is the cosine of 0 degrees?"} +{"prompt": "A polygon has 9 sides. What is the sum of its interior angles?"} +{"prompt": "What is 1.2 * 10^3 in standard form?"} +{"prompt": "What is the range of the data set 5, 8, 3, 12, 7?"} +{"prompt": "What is the LCM of 4, 5, and 6?"} +{"prompt": "If y varies directly as x and y = 10 when x = 2, find y when x = 7."} +{"prompt": "What is the degree of the polynomial 3x^4 + 2x^2 - x + 5?"} +{"prompt": "How many diagonals does a hexagon have?"} +{"prompt": "What is 75% expressed as a fraction in lowest terms?"} +{"prompt": "How many ounces in 3 pounds?"} +{"prompt": "What is the volume of a sphere with radius 3?"} +{"prompt": "Solve the system: x + y = 8, x - y = 2."} +{"prompt": "A triangle has two angles of 50 and 70 degrees. 
What is the third angle?"} +{"prompt": "What is the remainder when 100 is divided by 7?"} +{"prompt": "Express 0.04 as a percentage."} +{"prompt": "What is the value of the expression 2 + 3 * 4 - 1?"} +{"prompt": "How many prime numbers are between 10 and 30?"} +{"prompt": "If a laptop costs 800 after 20% off, what was the original price?"} +{"prompt": "What is 5 factorial minus 3 factorial?"} +{"prompt": "Find the length of the diagonal of a rectangle 6 by 8."} +{"prompt": "What is the sine of 90 degrees?"} +{"prompt": "If the ratio of boys to girls is 3:2 and there are 30 students, how many girls?"} +{"prompt": "What is the value of log2(32)?"} +{"prompt": "What is the sum of 1 + 2 + 3 ... + 50?"} +{"prompt": "Convert 40 inches to feet."} +{"prompt": "What is the derivative of x^3?"} +{"prompt": "What is 10^0 + 10^1 + 10^2?"} +{"prompt": "A bag has 4 green, 6 red marbles. What is P(green or red)?"} +{"prompt": "How many multiples of 3 are between 10 and 50?"} +{"prompt": "If 2x + y = 7 and x = 3, what is y?"} +{"prompt": "What is the midpoint of the segment from (2,3) to (8,7)?"} +{"prompt": "Simplify: 2(3x - 1) + 4(x + 2)."} +{"prompt": "How many triangles can be formed from 6 non-collinear points?"} +{"prompt": "What is the 5th root of 32?"} +{"prompt": "What is the mode of 3, 5, 3, 7, 5, 3, 8?"} +{"prompt": "Find the slope of the line passing through (0,0) and (2,6)."} +{"prompt": "What is the supplementary angle of 72 degrees?"} +{"prompt": "How many positive divisors does 36 have?"} +{"prompt": "Simplify: (a + b)^2 - (a - b)^2."} +{"prompt": "How many seconds in 1.5 hours?"} +{"prompt": "If a machine produces 120 items in 8 hours, how many per hour?"} +{"prompt": "What is the inverse of f(x) = 2x + 3?"} +{"prompt": "What is the greatest integer less than sqrt(50)?"} +{"prompt": "What is 4^3 - 3^4?"} +{"prompt": "What is the distance from (0,0) to (3,4)?"} +{"prompt": "If sin(x) = 0.5, what is x in degrees?"} +{"prompt": "A rectangle has area 48 and width 6. 
What is its length?"} +{"prompt": "How many degrees does the minute hand move in 20 minutes?"} +{"prompt": "What is the probability of rolling a sum of 7 with two dice?"} +{"prompt": "Simplify: 3(x + 2) - 2(x - 4)."} +{"prompt": "What is the value of floor(3.7)?"} +{"prompt": "What is the weighted average of 80 (weight 3) and 90 (weight 7)?"} +{"prompt": "Find the y-intercept of y = 3x - 6."} +{"prompt": "How many sides does a decagon have?"} +{"prompt": "What is the integral of 2x dx?"} +{"prompt": "What is 2 + 2 * 2?"} +{"prompt": "If a triangle has sides 5, 5, 5, what is it called?"} +{"prompt": "What is the decimal for 7/8?"} +{"prompt": "If f(x) = 1/x, what is f(5)?"} +{"prompt": "What is the remainder when 2^10 is divided by 7?"} diff --git a/training/opsd/main.py b/training/opsd/main.py new file mode 100644 index 000000000..534c8ae0a --- /dev/null +++ b/training/opsd/main.py @@ -0,0 +1,134 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""OPSD training entry point. + +Launch with the DeepSpeed launcher:: + + deepspeed --num_gpus 8 main.py --config configs/opsd_hybrid_engine.json + +The DeepSpeed launcher sets ``LOCAL_RANK``, ``RANK``, and ``WORLD_SIZE`` in +the environment; we call :func:`deepspeed.init_distributed` to take that over. 
def _seed_everything(seed: int) -> None:
    """Seed every random number generator this script touches.

    Seeds Python's ``random`` module, NumPy, and torch. When the DeepSpeed
    accelerator reports that it is available, all of its devices are seeded
    as well.

    Args:
        seed: Value passed unchanged to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if get_accelerator().is_available():
        get_accelerator().manual_seed_all(seed)


def _resolve_dtype(name: str) -> torch.dtype:
    """Translate a dtype name from the config into a ``torch.dtype``.

    Args:
        name: One of ``"float16"``, ``"bfloat16"`` or ``"float32"``.

    Returns:
        The matching torch dtype.

    Raises:
        ValueError: If ``name`` is not one of the supported names. The message
            lists the accepted values so a config typo is easy to spot.
    """
    supported = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    try:
        return supported[name]
    except KeyError:
        raise ValueError(f"Unsupported dtype {name!r}; expected one of {sorted(supported)}") from None


def _load_ds_config(path: str) -> dict:
    """Read a DeepSpeed JSON config file and return it as a dict.

    Args:
        path: Location of the JSON file.

    Returns:
        The parsed JSON document.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the locale
    # of the machine the launcher happens to run on.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def main() -> None:
    """Parse the CLI, build every OPSD component, and run training.

    Steps, in order: load and validate the ``OPSDConfig``, seed the RNGs,
    initialise the distributed backend, build the tokenizer, the student
    model and its DeepSpeed engine, the teacher wrapper, the rollout engine
    and the prompt dataloader, then hand everything to ``OPSDTrainer.train``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="Path to OPSDConfig JSON")
    # NOTE(review): the parsed value is never read below; the flag is
    # presumably accepted only because the DeepSpeed launcher passes
    # ``--local_rank`` to the script -- confirm before removing it.
    parser.add_argument("--local_rank", type=int, default=int(os.environ.get("LOCAL_RANK", 0)))
    args = parser.parse_args()

    cfg = OPSDConfig.from_json(args.config)
    cfg.validate()
    _seed_everything(cfg.training.seed)

    deepspeed.init_distributed()

    # --- tokenizer (shared between data + rollout) -------------------------
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.student.model_name_or_path,
        trust_remote_code=cfg.student.trust_remote_code,
        padding_side="left",
    )
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # --- student model + DeepSpeed engine ----------------------------------
    student_dtype = _resolve_dtype(cfg.student.dtype)
    student_model = AutoModelForCausalLM.from_pretrained(
        cfg.student.model_name_or_path,
        torch_dtype=student_dtype,
        trust_remote_code=cfg.student.trust_remote_code,
    )

    # The batch-size keys from the OPSD config overwrite whatever the
    # DeepSpeed JSON file contains, so the OPSD config is the single source
    # for those three values.
    ds_config = _load_ds_config(cfg.deepspeed_config)
    ds_config["train_micro_batch_size_per_gpu"] = cfg.training.micro_batch_size_per_gpu
    ds_config["train_batch_size"] = cfg.training.train_batch_size
    ds_config["gradient_accumulation_steps"] = cfg.training.gradient_accumulation_steps

    # Only the engine (first returned element) is used by this script.
    student_engine, *_ = deepspeed.initialize(
        model=student_model,
        model_parameters=student_model.parameters(),
        config=ds_config,
    )

    # --- frozen teacher ----------------------------------------------------
    teacher = TeacherWrapper(cfg.teacher, world_size=dist_world_size())

    # --- rollout engine ----------------------------------------------------
    rollout = build_rollout(
        cfg.rollout,
        student_engine=student_engine,
        tokenizer=tokenizer,
        student_model_path=cfg.student.model_name_or_path,
    )

    # --- dataloader --------------------------------------------------------
    dataset = PromptDataset(
        path=cfg.data.path,
        tokenizer=tokenizer,
        max_prompt_length=cfg.rollout.max_prompt_length,
        prompt_field=cfg.data.prompt_field,
        chat_template=cfg.data.chat_template,
    )
    collator = LeftPaddedPromptCollator(tokenizer=tokenizer, max_prompt_length=cfg.rollout.max_prompt_length)
    # NOTE(review): no sampler is passed, so every process iterates the
    # dataset the same way unless PromptDataset or OPSDTrainer shards by rank
    # internally -- confirm against those classes.
    loader = DataLoader(
        dataset,
        batch_size=cfg.training.micro_batch_size_per_gpu,
        shuffle=cfg.data.shuffle,
        collate_fn=collator,
        drop_last=True,
    )

    OPSDTrainer(
        cfg=cfg,
        student_engine=student_engine,
        teacher=teacher,
        tokenizer=tokenizer,
        rollout=rollout,
        dataloader=loader,
    ).train()


def dist_world_size() -> int:
    """Return the process count from ``WORLD_SIZE``, or 1 when it is unset."""
    return int(os.environ.get("WORLD_SIZE", "1"))


if __name__ == "__main__":
    main()
#!/usr/bin/env bash
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
#
# Launch OPSD training with vLLM rollout.
#
# The vLLM server is started **lazily** as a subprocess by training rank 0
# on first use, so no separate vLLM launch step is required. The GPUs
# listed in ``rollout.gpus`` in the config are assigned to the vLLM server
# via ``CUDA_VISIBLE_DEVICES`` in the subprocess environment.
#
# Default config assumes 8 GPUs: ranks 0..5 train (ZeRO-3), devices 6-7
# run vLLM with TP=2. Adjust configs/opsd_vllm_disjoint.json::rollout.gpus
# and NUM_TRAIN_GPUS (or INCLUDE_GPUS) to match your topology.
#
# Usage (run from the directory that contains main.py and configs/):
#   scripts/train_opsd_vllm.sh [CONFIG]
#
# Environment:
#   NUM_TRAIN_GPUS  Number of training GPUs (default: 6). Only used to build
#                   the default INCLUDE_GPUS list 0..NUM_TRAIN_GPUS-1.
#   INCLUDE_GPUS    Comma-separated device ids for the training ranks. When
#                   set, it takes precedence over NUM_TRAIN_GPUS.
set -euo pipefail

CONFIG="${1:-configs/opsd_vllm_disjoint.json}"
NUM_TRAIN_GPUS="${NUM_TRAIN_GPUS:-6}"

# Build "0,1,...,NUM_TRAIN_GPUS-1" in pure bash (no dependency on seq).
DEFAULT_INCLUDE_GPUS=""
for ((i = 0; i < NUM_TRAIN_GPUS; i++)); do
    DEFAULT_INCLUDE_GPUS+="${DEFAULT_INCLUDE_GPUS:+,}${i}"
done
INCLUDE_GPUS="${INCLUDE_GPUS:-${DEFAULT_INCLUDE_GPUS}}"

if [[ -z "${INCLUDE_GPUS}" ]]; then
    echo "error: no training GPUs selected; set INCLUDE_GPUS or NUM_TRAIN_GPUS >= 1" >&2
    exit 1
fi

# The DeepSpeed launcher refuses --num_gpus/--num_nodes combined with
# --include/--exclude, so the device list alone selects the training ranks
# (one rank per listed device).
deepspeed --include "localhost:${INCLUDE_GPUS}" \
    main.py --config "${CONFIG}"
@pytest.mark.parametrize("loss_type", ["forward_kl", "reverse_kl", "jsd"])
def test_chunking_equivalent_to_unchunked(loss_type):
    """A small ``chunk_size`` and one larger than the sequence give the same loss."""
    torch.manual_seed(0)
    s = torch.randn(2, 100, 32)
    t = torch.randn(2, 100, 32)
    mask = torch.ones(2, 100)
    loss_chunked = chunked_distillation_loss(s, t, mask, loss_type=loss_type, chunk_size=10)
    # chunk_size far above the sequence length (100) stands in for "unchunked".
    loss_whole = chunked_distillation_loss(s, t, mask, loss_type=loss_type, chunk_size=10_000)
    assert torch.allclose(loss_chunked, loss_whole, atol=1e-5)


def test_mask_excludes_tokens():
    """Masking out the last four positions equals computing on the first four only."""
    torch.manual_seed(0)
    s = torch.randn(2, 8, 32)
    t = torch.randn(2, 8, 32)
    half_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0, 0]], dtype=torch.float32)
    loss_direct = chunked_distillation_loss(s[:, :4], t[:, :4], torch.ones(2, 4), loss_type="reverse_kl")
    loss_masked = chunked_distillation_loss(s, t, half_mask, loss_type="reverse_kl")
    assert torch.allclose(loss_direct, loss_masked, atol=1e-5)


def test_gradient_flows_to_student():
    """Backward through the loss produces a non-zero gradient on the student logits."""
    torch.manual_seed(0)
    s = torch.randn(2, 8, 32, requires_grad=True)
    t = torch.randn(2, 8, 32)
    mask = torch.ones(2, 8)
    loss = chunked_distillation_loss(s, t, mask, loss_type="reverse_kl")
    loss.backward()
    assert s.grad is not None
    assert s.grad.abs().sum().item() > 0


def test_gradient_does_not_flow_to_teacher_when_detached():
    """Passing ``t.detach()`` leaves ``t.grad`` unset after backward."""
    torch.manual_seed(0)
    s = torch.randn(2, 8, 32, requires_grad=True)
    t = torch.randn(2, 8, 32, requires_grad=True)
    mask = torch.ones(2, 8)
    loss = chunked_distillation_loss(s, t.detach(), mask, loss_type="reverse_kl")
    loss.backward()
    assert t.grad is None


def test_unknown_loss_type_raises():
    """An unrecognised ``loss_type`` is rejected with ``ValueError``."""
    s = torch.randn(2, 4, 8)
    t = torch.randn(2, 4, 8)
    mask = torch.ones(2, 4)
    with pytest.raises(ValueError, match="Unknown loss_type"):
        chunked_distillation_loss(s, t, mask, loss_type="totally_made_up")


def test_shape_mismatch_raises():
    """Student and teacher logits of different sequence length are rejected."""
    s = torch.randn(2, 4, 8)
    t = torch.randn(2, 5, 8)
    mask = torch.ones(2, 4)
    with pytest.raises(ValueError, match="shape mismatch"):
        chunked_distillation_loss(s, t, mask)


def test_mask_shape_mismatch_raises():
    """A mask whose shape disagrees with the logits is rejected."""
    s = torch.randn(2, 4, 8)
    t = torch.randn(2, 4, 8)
    mask = torch.ones(2, 5)
    with pytest.raises(ValueError, match="does not match"):
        chunked_distillation_loss(s, t, mask)


@pytest.mark.parametrize("temperature", [0.5, 1.0, 2.0])
def test_temperature_changes_loss_but_stays_finite(temperature):
    """The loss is finite at every temperature and differs from the T=1 value otherwise."""
    torch.manual_seed(0)
    s = torch.randn(2, 8, 32)
    t = torch.randn(2, 8, 32)
    mask = torch.ones(2, 8)
    loss = chunked_distillation_loss(s, t, mask, loss_type="reverse_kl", temperature=temperature)
    assert torch.isfinite(loss).item()

    # Compare against the temperature=1.0 loss on the same inputs so the test
    # also fails if the temperature argument were silently ignored.
    baseline = chunked_distillation_loss(s, t, mask, loss_type="reverse_kl", temperature=1.0)
    if temperature == 1.0:
        assert torch.allclose(loss, baseline)
    else:
        assert not torch.allclose(loss, baseline, atol=1e-6)


def test_jsd_is_symmetric():
    """Swapping the two logit tensors leaves the ``jsd`` loss unchanged."""
    torch.manual_seed(0)
    a = torch.randn(2, 8, 32)
    b = torch.randn(2, 8, 32)
    mask = torch.ones(2, 8)
    jsd_ab = chunked_distillation_loss(a, b, mask, loss_type="jsd")
    jsd_ba = chunked_distillation_loss(b, a, mask, loss_type="jsd")
    assert torch.allclose(jsd_ab, jsd_ba, atol=1e-5)


def test_all_zero_mask_returns_zero():
    """With every position masked out the loss is zero."""
    torch.manual_seed(0)
    s = torch.randn(2, 8, 32)
    t = torch.randn(2, 8, 32)
    mask = torch.zeros(2, 8)
    loss = chunked_distillation_loss(s, t, mask, loss_type="reverse_kl")
    assert loss.item() == pytest.approx(0.0, abs=1e-6)


def test_per_token_logprobs_matches_manual():
    """``per_token_logprobs`` equals a float32 ``log_softmax`` followed by a gather on the labels."""
    torch.manual_seed(0)
    logits = torch.randn(2, 4, 16)
    labels = torch.randint(0, 16, (2, 4))
    got = per_token_logprobs(logits, labels)
    expected = torch.log_softmax(logits.float(), dim=-1)
    expected = expected.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    assert torch.allclose(got, expected, atol=1e-6)


def test_build_response_mask_basic():
    """The mask is 1 from each row's start index onward, wherever attention is 1."""
    attention_mask = torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]])
    response_start_idx = torch.tensor([2, 3])
    resp = build_response_mask(response_start_idx, attention_mask)
    expected = torch.tensor([[0, 0, 1, 1, 0], [0, 0, 0, 1, 1]])
    assert torch.equal(resp, expected)


def test_build_response_mask_validates_shapes():
    """Wrong ranks and mismatched batch sizes each raise ``ValueError``."""
    with pytest.raises(ValueError, match="response_start_idx must be 1-D"):
        build_response_mask(torch.zeros(2, 2), torch.ones(2, 4))
    with pytest.raises(ValueError, match="attention_mask must be 2-D"):
        build_response_mask(torch.zeros(2), torch.ones(4))
    with pytest.raises(ValueError, match="batch"):
        build_response_mask(torch.zeros(3), torch.ones(2, 4))


def test_shift_for_next_token_prediction_shapes():
    """Shifting drops one position from the sequence axis of both outputs."""
    logits = torch.randn(2, 5, 8)
    labels = torch.randint(0, 8, (2, 5))
    sl, sla = shift_for_next_token_prediction(logits, labels)
    assert sl.shape == (2, 4, 8)
    assert sla.shape == (2, 4)
def test_round_trip_preserves_values_within_dtype():
    """A float32 tensor stored as bf16 and read back stays within bf16 rounding error."""
    torch.manual_seed(0)
    gpu_like = torch.randn(2, 16, 32, dtype=torch.float32)
    cache = TeacherLogitCache.from_gpu_logits(gpu_like, store_dtype=torch.bfloat16)
    assert cache.shape == (2, 16, 32)
    assert cache.dtype == torch.bfloat16
    chunk = cache.chunk_to_device(0, 16, torch.device("cpu"), dtype=torch.float32)
    # bf16 round-trip loses precision; bound the error by bf16's machine
    # epsilon (relative to the original value) rather than asserting exact
    # equality. Round-to-nearest keeps the error at or below eps / 2.
    bf16_eps = torch.finfo(torch.bfloat16).eps
    assert torch.allclose(chunk, gpu_like, rtol=bf16_eps, atol=0.0)


def test_chunk_slicing_is_correct():
    """Each ``[start, end)`` request returns exactly that slice of the sequence axis."""
    torch.manual_seed(0)
    src = torch.randn(3, 100, 8)
    cache = TeacherLogitCache.from_gpu_logits(src, store_dtype=torch.float32)
    for start, end in [(0, 10), (10, 50), (50, 100), (33, 77)]:
        got = cache.chunk_to_device(start, end, torch.device("cpu"))
        assert got.shape == (3, end - start, 8)
        assert torch.allclose(got, src[:, start:end])


def test_invalid_chunk_bounds_raise():
    """End past the length, start after end, and negative start are all rejected."""
    cache = TeacherLogitCache.from_gpu_logits(torch.zeros(1, 8, 4), store_dtype=torch.float32)
    with pytest.raises(ValueError, match="invalid"):
        cache.chunk_to_device(0, 9, torch.device("cpu"))
    with pytest.raises(ValueError, match="invalid"):
        cache.chunk_to_device(5, 3, torch.device("cpu"))
    with pytest.raises(ValueError, match="invalid"):
        cache.chunk_to_device(-1, 4, torch.device("cpu"))


def test_rejects_non_3d_logits():
    """Constructing the cache from a 2-D tensor raises ``ValueError``."""
    with pytest.raises(ValueError, match="must be 3-D"):
        TeacherLogitCache(cpu_logits=torch.zeros(8, 32))


def test_rejects_gpu_resident_logits():
    """Constructing the cache from a CUDA tensor raises ``ValueError`` (skipped without CUDA)."""
    if not torch.cuda.is_available():  #ignore-cuda
        pytest.skip("no CUDA available to construct GPU tensor")
    with pytest.raises(ValueError, match="must live on CPU"):
        TeacherLogitCache(cpu_logits=torch.zeros(1, 8, 4, device="cuda"))


def test_dtype_override_in_chunk_to_device():
    """The ``dtype`` argument of ``chunk_to_device`` overrides the stored dtype."""
    src = torch.randn(2, 8, 16, dtype=torch.float32)
    cache = TeacherLogitCache.from_gpu_logits(src, store_dtype=torch.float32)
    chunk = cache.chunk_to_device(0, 8, torch.device("cpu"), dtype=torch.bfloat16)
    assert chunk.dtype == torch.bfloat16


def test_free_releases_buffer():
    """After ``free()`` the cached tensor holds no elements."""
    src = torch.randn(2, 32, 16)
    cache = TeacherLogitCache.from_gpu_logits(src, store_dtype=torch.float32)
    assert cache.cpu_logits.numel() == 2 * 32 * 16
    cache.free()
    assert cache.cpu_logits.numel() == 0


def test_default_store_dtype_is_bf16():
    """Without ``store_dtype`` the cache stores bfloat16."""
    src = torch.randn(1, 4, 8)
    cache = TeacherLogitCache.from_gpu_logits(src)
    assert cache.dtype == torch.bfloat16


def test_streamed_chunked_loss_matches_full_loss():
    """End-to-end check: pulling teacher logits chunk-by-chunk through the
    cache yields the same distillation loss as passing the full teacher tensor
    to ``chunked_distillation_loss`` directly."""
    from deepspeed.runtime.rlhf.losses import chunked_distillation_loss

    torch.manual_seed(0)
    seq_len, chunk_size = 64, 8
    s = torch.randn(2, seq_len, 32)
    t = torch.randn(2, seq_len, 32)
    mask = torch.ones(2, seq_len)

    direct = chunked_distillation_loss(s, t, mask, loss_type="reverse_kl", chunk_size=chunk_size)

    cache = TeacherLogitCache.from_gpu_logits(t, store_dtype=torch.float32)
    cpu = torch.device("cpu")
    # Pull the teacher logits one chunk at a time, then stitch the chunks
    # back together along the sequence axis, so the staging path is what
    # feeds the loss rather than a single whole-tensor read.
    staged_chunks = [
        cache.chunk_to_device(start, start + chunk_size, cpu, dtype=torch.float32)
        for start in range(0, seq_len, chunk_size)
    ]
    staged_full = torch.cat(staged_chunks, dim=1)
    assert staged_full.shape == t.shape

    via_cache = chunked_distillation_loss(s, staged_full, mask, loss_type="reverse_kl", chunk_size=chunk_size)

    assert torch.allclose(direct, via_cache, atol=1e-6)