Skip to content

feat: add TrainerRank API#740

Draft
bradhilton wants to merge 13 commits into
feat/megatron-gdn-tree-corefrom
feat/trainer-rank-api
Draft

feat: add TrainerRank API#740
bradhilton wants to merge 13 commits into
feat/megatron-gdn-tree-corefrom
feat/trainer-rank-api

Conversation

@bradhilton

Copy link
Copy Markdown
Collaborator

Summary

This PR depends on #739 and contains the user-facing TrainerRank layer moved into a new module, art.trainer_rank. The intent is to keep Austin-facing Megatron/CP/GDN changes separate from the higher-level API work.

Public API

  • ForwardInput / ForwardOutput are typed and support target logprobs, multi-target labels, top-k, logits, hidden states, and per-request adapter routing.
  • TrainerRank.dp_rank_forward(...) is the explicit API for already-DP-local inputs.
  • TrainerRank.forward_micro_batches(...) is the DP-aware adaptive training iterator. It slices only shallow top-level items, returns MicroBatch objects with inputs, outputs, indices, and planning stats, and raises proactively when the smallest safe batch is expected not to fit.
  • Slot APIs are included for TrainerRank-managed adapter/checkpoint routing: set_checkpoint, set_lora, push_checkpoint, push_lora, and pop_pushed_lora_or_checkpoint.
  • optim_step(...) supports the Megatron optimizer path plus ART-managed dynamic checkpoint slot optimizers.

Implementation highlights

  • The module is now src/art/trainer_rank, not src/art/megatron/trainer_rank.py.
  • Forward execution is hidden-first: packed hidden states are produced once, then lightweight heads compute target logprobs, multi-target logprobs, logits, top-k, and hidden-state outputs.
  • Adaptive microbatch planning estimates packed tokens/output bytes before materializing candidate packs, caches candidate/final plans, and updates empirical memory profiles after successful forwards.
  • Heterogeneous adapter/checkpoint requests are grouped by resolved slot so correctness is preserved while keeping the public API simple.
  • The Triton local top-k/logsumexp helper moved with the TrainerRank module as art.trainer_rank.topk.
  • Dev harnesses live under dev/ and are intentionally outside art.megatron.

Validation

Local on this branch:

  • uv run --no-sync prek run --all-files: passed.
  • uv run python dev/trainer_rank_fast_check.py: 67 passed, 8 skipped locally.

4x H200 SkyPilot (codex-trainer-rank-h200) on this branch:

  • uv run python dev/trainer_rank_fast_check.py: 85 passed.
  • Dynamic LoRA slot integration smoke: 1 passed.
  • Compact topology smoke with Qwen3-0.6B, depth=3:
    • world=1,tp=1,cp=1: mean_abs_pct=1.43e-7, max_abs_diff=1.91e-6
    • world=2,tp=1,cp=1: mean_abs_pct=1.27e-7, max_abs_diff=1.91e-6
    • world=2,tp=2,cp=1: mean_abs_pct=7.16e-8, max_abs_diff=1.91e-6
    • CP target-logprob smoke, world=2,tp=1,cp=2: mean_abs_pct=4.51e-8, max_abs_diff=1.91e-6
    • CP target-logprob smoke, world=4,tp=1,cp=4: mean_abs_pct=4.51e-8, max_abs_diff=1.91e-6
  • 35B/A3B quick full-step check, Qwen/Qwen3.5-35B-A3B, all 40 layers, CP=4, EP=4, 4x H200, Austin-style 198k packed-token workload (30 x (5k prefix + 16 x 100 completion), 2.448M logical tokens), warmup=1, repeat=1:
    • native labels= train step: 23.5k packed tok/s
    • TrainerRank target-logprob train step: 25.6k packed tok/s
    • peak allocated: 58.2GB
    • peak reserved: 113.3GB
    • peak process memory: 119.3GB

Known caveats

  • The mixed-output CP topology smoke is compile-heavy on the 0.6B test and was cancelled after several minutes in Inductor compilation. The smaller CP target-logprob smoke passed for CP=2 and CP=4; the previous full branch had broader topology/perf sweeps, but this split PR validation keeps the quick gate bounded.
  • Pipeline parallelism remains out of scope.

Review focus

For this PR, I mostly want review on the API shape, adaptive microbatch semantics, slot routing semantics, and whether the implementation is sufficiently small and maintainable given the behavior it supports.

@bradhilton bradhilton mentioned this pull request Jun 25, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant