Skip to content

[JAX] Expert Parallelism: JAX primitives + VJPs#3036

Open
phu0ngng wants to merge 16 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-jax
Open

[JAX] Expert Parallelism: JAX primitives + VJPs#3036
phu0ngng wants to merge 16 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-jax

Conversation

@phu0ngng

@phu0ngng phu0ngng commented May 22, 2026

Copy link
Copy Markdown
Collaborator

Summary

Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the nvte_ep_* C API, a Python wrapper with custom_vjp for autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCL ncclEpDispatch/ncclEpCombine are exposed as XLA primitives and work with CUDA-graph capture.

Implementation

Public Python API (transformer_engine/jax/ep.py)

from transformer_engine.jax.ep import (
    EpHandle,        # opaque (id, handle_mem) pair from ep_prepare
    ep_bootstrap,    # one-shot per-process: init NCCL comm + nvte_ep_initialize
    ep_dispatch,     # custom_vjp-wrapped dispatch 
    ep_combine,      # custom_vjp-wrapped combine

ep_dispatch / ep_combine are jax.custom_vjp functions: forward is the FFI primitive, backward calls the matching nvte_ep_*_bwd FFI primitive directly (no ep_prepare in the bwd — routing state is already cached in handle.mem). Note that ep_dispatch also calls ep_prepare in the forward path, which all-gathers and preprocesses routing maps.

XLA FFI bindings (transformer_engine/jax/csrc/extensions/ep.cpp)

Five XLA_FFI_DEFINE_HANDLER_SYMBOL entries — EpPrepareHandler, EpDispatchHandler, EpCombineHandler, EpDispatchBwdHandler, EpCombineBwdHandler — each calling the corresponding nvte_ep_* C entry point. All marked FFI_CudaGraph_Traits so they capture cleanly. handle_id is a static FFI attribute baked at jit trace time.

Primitives + Python layer (transformer_engine/jax/cpp_extensions/ep.py, +951 lines)

Standard TE primitive plumbing: abstract_eval (shape/dtype inference), lowering, impl, outer_primitive registration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).

Sharding (transformer_engine/jax/sharding.py, +12 lines)

Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.

Build wiring (build_tools/jax.py, +41 lines)

Threads NCCL EP linkage through the JAX transformer_engine_jax extension. No new top-level build flags; rides on the parent PR's NVTE_BUILD_WITH_NCCL_EP.

Tests & example

  • tests/jax/test_multi_process_ep.py (+690 lines): 13 tests covering bootstrap, ep_prepare shape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing), custom_vjp fwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).
  • tests/jax/multi_process_launch_ep.sh: 4-rank launcher; sets XLA_FLAGS to keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).
  • examples/jax/ep/ep_moe.py (+394 lines) + run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison --check that verifies fwd+bwd vs a single-process reference.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps

greptile-apps Bot commented May 22, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR lands the JAX Expert Parallelism bindings: five XLA FFI handlers over the nvte_ep_* C API, a Python wrapper with jax.custom_vjp for autograd, mesh-aware SPMD sharding rules, a 13-test multi-process suite, and an end-to-end MoE example. The implementation is substantial (~3,800 lines) and well-structured.

  • transformer_engine/jax/cpp_extensions/ep.py: ~960-line primitive layer registering EpPrepare, EpDispatch, EpCombine (fwd+bwd) as JAX primitives with abstract_eval, lowering, impl, batcher, partition, and Shardy rules.
  • transformer_engine/jax/ep.py: Public API — ep_bootstrap (one-shot NCCL comm init via ctypes UID exchange + XLA collective fallback), ep_dispatch / ep_combine as jax.custom_vjp functions with EP-mesh-aware residual handling.
  • transformer_engine/jax/csrc/extensions/ep.cpp: XLA FFI C++ handlers using a shared_ptr<EpResources> weak-ref lifetime model; all five handlers marked FFI_CudaGraph_Traits for CUDA-graph compatibility.

Confidence Score: 5/5

The new JAX EP layer is safe to merge; no new blocking defects were found beyond what has already been flagged in earlier review rounds.

The core forward and backward logic, custom_vjp residual handling, and SPMD partition rules are all correct. The only new findings are non-blocking: a broad exception swallow in _allgather_uid that masks bootstrap failures, EpCombineBwdPrimitive.partition omitting the sharding validation every other EP partition function performs, and _dispatch_bwd capturing out_partition_spec from mesh context at trace time rather than threading it as a nondiff arg.

The backward path in transformer_engine/jax/ep.py (_dispatch_bwd) and EpCombineBwdPrimitive.partition in transformer_engine/jax/cpp_extensions/ep.py deserve a second look for sharding-spec consistency.

Important Files Changed

Filename Overview
transformer_engine/jax/ep.py New file: high-level ep_bootstrap, ep_dispatch (custom_vjp), ep_combine (custom_vjp) APIs; backward's out_partition_spec is trace-time-captured rather than threaded as a nondiff arg (inconsistency with ep_combine); broad exception swallow in _allgather_uid.
transformer_engine/jax/cpp_extensions/ep.py New file: ~960-line primitive layer for EpPrepare/Dispatch/Combine (fwd+bwd); EpCombineBwdPrimitive.partition lacks input-sharding validation unlike every other EP partition function.
transformer_engine/jax/csrc/extensions/ep.cpp New file: XLA FFI handlers for 5 EP operations; double-lock pattern in SetEpBootstrapParams (race window on re-init already flagged); topk_weights always wrapped as float32 (acknowledged by developer as intentional).
build_tools/jax.py +3 lines: defaults NVTE_BUILD_WITH_NCCL_EP=1 (EP always on), but arch guard logic and hard-fail on missing submodule are inconsistent with setup.py's auto-disable behaviour (flagged in previous review cycles).
transformer_engine/jax/sharding.py Adds ep_axis_size() helper and expands MeshResource.ep_resource docstring; straightforward, no issues.
tests/jax/test_multi_process_ep.py New file: 742-line multi-process test suite with 13 cases covering bootstrap, primitive shapes, VJP correctness (closed-form), and HLO collective guard; good coverage.
transformer_engine/jax/csrc/extensions/pybind.cpp +31 lines: registers EP FFI handlers and pybind11 bindings under NVTE_WITH_NCCL_EP guard; correct use of instantiate+execute pair for stateful FFI.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant User as User code
    participant Bootstrap as ep_bootstrap
    participant Allgather as _allgather_uid
    participant NCCL as ncclGetUniqueId (ctypes)
    participant TE_JAX as transformer_engine_jax (C++)
    participant Dispatch as ep_dispatch (custom_vjp)
    participant Combine as ep_combine (custom_vjp)
    participant FFI as XLA FFI (ep.cpp)

    User->>Bootstrap: ep_bootstrap(world_size, rank, ...)
    Bootstrap->>NCCL: ncclGetUniqueId (color-root only)
    Bootstrap->>Allgather: _allgather_uid(uid_arr, world_size)
    Allgather-->>Bootstrap: all_uids[world_size, 128]
    Bootstrap->>TE_JAX: set_ep_bootstrap_params(uid_bytes, ep_size, ...)
    TE_JAX->>FFI: AcquireEpResources - ncclCommInitRank + nvte_ep_initialize
    Bootstrap->>Bootstrap: set_ep_config(EpConfig(...))

    User->>Dispatch: ep_dispatch(cfg, topk_idx, tokens, topk_weights, recv_cap)
    Dispatch->>FFI: EpPrepareFFI (nvte_ep_prepare - handle_mem, token_counts)
    Dispatch->>FFI: EpDispatchFFI (nvte_ep_dispatch - recv_tokens, recv_topk_weights)
    Dispatch-->>User: (recv_tokens, recv_topk_weights, handle_mem, token_counts)

    User->>Combine: ep_combine(cfg, handle_mem, token_counts, expert_out, ...)
    Combine->>FFI: EpCombineFFI (nvte_ep_combine - result)
    Combine-->>User: result [T, H]

    Note over Dispatch,FFI: Backward (custom_vjp)
    Dispatch->>FFI: EpDispatchBwdFFI (nvte_ep_dispatch_bwd)
    Combine->>FFI: EpCombineBwdFFI (nvte_ep_combine_bwd)
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant User as User code
    participant Bootstrap as ep_bootstrap
    participant Allgather as _allgather_uid
    participant NCCL as ncclGetUniqueId (ctypes)
    participant TE_JAX as transformer_engine_jax (C++)
    participant Dispatch as ep_dispatch (custom_vjp)
    participant Combine as ep_combine (custom_vjp)
    participant FFI as XLA FFI (ep.cpp)

    User->>Bootstrap: ep_bootstrap(world_size, rank, ...)
    Bootstrap->>NCCL: ncclGetUniqueId (color-root only)
    Bootstrap->>Allgather: _allgather_uid(uid_arr, world_size)
    Allgather-->>Bootstrap: all_uids[world_size, 128]
    Bootstrap->>TE_JAX: set_ep_bootstrap_params(uid_bytes, ep_size, ...)
    TE_JAX->>FFI: AcquireEpResources - ncclCommInitRank + nvte_ep_initialize
    Bootstrap->>Bootstrap: set_ep_config(EpConfig(...))

    User->>Dispatch: ep_dispatch(cfg, topk_idx, tokens, topk_weights, recv_cap)
    Dispatch->>FFI: EpPrepareFFI (nvte_ep_prepare - handle_mem, token_counts)
    Dispatch->>FFI: EpDispatchFFI (nvte_ep_dispatch - recv_tokens, recv_topk_weights)
    Dispatch-->>User: (recv_tokens, recv_topk_weights, handle_mem, token_counts)

    User->>Combine: ep_combine(cfg, handle_mem, token_counts, expert_out, ...)
    Combine->>FFI: EpCombineFFI (nvte_ep_combine - result)
    Combine-->>User: result [T, H]

    Note over Dispatch,FFI: Backward (custom_vjp)
    Dispatch->>FFI: EpDispatchBwdFFI (nvte_ep_dispatch_bwd)
    Combine->>FFI: EpCombineBwdFFI (nvte_ep_combine_bwd)
Loading

Reviews (19): Last reviewed commit: "jax/ep: resolve test path from launcher ..." | Re-trigger Greptile

Comment thread build_tools/jax.py Outdated
Comment thread build_tools/jax.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp Outdated
Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts,
Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) {
auto topk_dims = topk_idx.dimensions();
NVTE_CHECK(topk_dims.size() >= 2,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we return FFI InvalidArgument instead of a NVTE_CHECK for these inputs?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.

@phu0ngng phu0ngng requested a review from tdophung May 22, 2026 15:51
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

I would appreciate your help to review this PR @tdophung @jberchtold-nvidia!
Please focus on the changes in the JAX side, as the TE/Common ones will be discussed in #3034

Comment thread examples/jax/ep/ep_moe.py Outdated
Comment thread tests/jax/multi_process_launch_ep.sh Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread examples/jax/ep/ep_moe.py
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/ep.py Outdated

@jberchtold-nvidia jberchtold-nvidia left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending CI

Comment thread transformer_engine/jax/ep.py
Comment thread transformer_engine/jax/ep.py
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions.h Outdated
jberchtold-nvidia pushed a commit to jberchtold-nvidia/TransformerEngine that referenced this pull request Jun 5, 2026
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype
field. The C++ backend (ep_backend.cpp:349) enforces
    typeToSize(tok_dtype) <= typeToSize(max_token_dtype)
at every dispatch, and the field is also used at group create to size the
NCCL EP staging buffers (ep_backend.cpp:221-222).

PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written
before this field existed and never set it, so any JAX EP group landed with
the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from
JAX then failed immediately with:
    tokens dtype (6) wider than group max_token_dtype (0)

This commit threads max_token_dtype end-to-end:

  - transformer_engine/jax/csrc/extensions.h
    update SetEpBootstrapParams declaration to match the new arity.

  - transformer_engine/jax/csrc/extensions/ep.cpp
    add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams;
    forward it into NVTEEpGroupConfig in the EpResources ctor.

  - transformer_engine/jax/csrc/extensions/pybind.cpp
    add the matching pybind11::arg("max_token_dtype") = 0.

  - transformer_engine/jax/ep.py
    add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to
    NVTEDType int, forward to the C++ setter.

Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream.
See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp
@phu0ngng phu0ngng force-pushed the phuong/ep-3-jax branch 2 times, most recently from 06f8a13 to c34771d Compare June 10, 2026 15:24
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 10, 2026
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype
field. The C++ backend (ep_backend.cpp:349) enforces
    typeToSize(tok_dtype) <= typeToSize(max_token_dtype)
at every dispatch, and the field is also used at group create to size the
NCCL EP staging buffers (ep_backend.cpp:221-222).

PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written
before this field existed and never set it, so any JAX EP group landed with
the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from
JAX then failed immediately with:
    tokens dtype (6) wider than group max_token_dtype (0)

This commit threads max_token_dtype end-to-end:

  - transformer_engine/jax/csrc/extensions.h
    update SetEpBootstrapParams declaration to match the new arity.

  - transformer_engine/jax/csrc/extensions/ep.cpp
    add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams;
    forward it into NVTEEpGroupConfig in the EpResources ctor.

  - transformer_engine/jax/csrc/extensions/pybind.cpp
    add the matching pybind11::arg("max_token_dtype") = 0.

  - transformer_engine/jax/ep.py
    add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to
    NVTEDType int, forward to the C++ setter.

Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream.
See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 10, 2026
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with
EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied
the three deltas uniquely ours:

  * transformer_engine/jax/moe.py: replaces upstream's multi-backend
    MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted
    to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle
    (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed
    in place of handle, ep_prepare arg order swapped, top_k= dropped
    from ep_dispatch_bwd since it's now in cfg.
  * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with
    ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped
    (no longer supported; ep_size is derived from mesh axes and the
    handle_mem reloc gating is gone).
  * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept).

Pre-sync state preserved at branch
teddy/te_ep_integration.backup-pre-phuong-sync.
EOF
)
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L1

@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L1

Comment thread transformer_engine/jax/ep.py
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 11, 2026
After the upstream PR NVIDIA#3036 resync the moe() API surface lost
PermutationBackend (TE-EP is the only backend now), gate_inside_vjp
(always True), and the per-call quantizer_sets knob (quantization
flows through the standard TE autocast / with_quantizer_set context).
It also gained apply_topk_weights_early and renamed the wrapper's
private _align_size to the public align_size the test suite already
uses. The Flax _MoEBlock wrapper was still passing the old kwargs,
which broke every test that touched the wrapper.

Wrapper changes:
  * drop "from ..moe import PermutationBackend" plus the dataclass
    field, the isinstance(..., PermutationBackend) validation in
    __post_init__, and the pass-through to moe().
  * drop "from ..quantize import noop_quantizer_set" and the
    quantizer_sets=(noop, noop, noop) pass-through.
  * drop gate_inside_vjp=True.
  * rename _align_size: int = 0 -> align_size: int = 0 (matches
    what tests/jax/test_te_ep_moe.py already passes).
  * add apply_topk_weights_early: bool = False and pass it through
    to moe().
  * refresh class docstring: drop permutation_backend / _align_size
    / quantizer_sets descriptions, add apply_topk_weights_early /
    align_size, note that quantization currently flows only through
    fp8_autocast.

Signed-off-by: tdophung <tdophung@nvidia.com>
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 12, 2026
…and inline justifications)

Responds to jberchtold-nvidia's PR NVIDIA#3116 review threads on
``transformer_engine/jax/moe.py``. All changes are confined to a
single file because each review thread targets a localized region
and splitting mid-file would risk reordering bugs.

Per review thread:

1. "Why do we need _with_sharding_constraint_cast_bwd? I haven't
    seen something like this required for our other VJPs."
   -- Expand the helper's docstring to spell out exactly why MoE
   needs it: unlike LN+MLP, the MoE bwd composes a bf16 cotangent
   from ep_dispatch_bwd with an fp32 cotangent from
   fused_topk_with_score_function_bwd (which the fwd's
   logits_2d -> fp32 promotion forces). Without the cast, ``d_x``
   surfaces at fp32 even when ``x`` is bf16, doubling activation
   grad bandwidth and breaking any downstream LN bwd that pins a
   bf16 layout. (Review thread "Why do we need this utility
   function?".)

2. "Why is this dtype casting required? I don't recall us needing
    it for the non-MoE LNMLP block."
   -- Expand the comment above the bwd activation fp32 promotion
   to explain the MoE-specific math: LN+MLP's silu sits behind a
   downstream LN that absorbs the bf16 rounding error, while
   MoE's silu sits on the *expert* side of routing -- the bf16
   rounding rides directly into expert_outputs and is summed
   across topk experts by ep_combine. Bf16 silu alone drifts ~1%
   vs fp32 silu and compounds through wo->combine into the ~1.4%
   per-element parity gap we measured against the pure-JAX
   softmax reference. Mirroring the fwd's fp32 promotion in the
   bwd keeps silu' in lock-step with silu. (Review thread on
   "# Activation bwd. Mirror the fwd's fp32 promotion of
   silu+multiply".)

3. "Do we have a use-case for user-specified alignments beyond
    128 currently? ... it'd make sense to instead hardcode
    _ALIGN_SIZE = 128 as a constant at the top of the file for
    now to simplify this MoEBlock API. We can always expand the
    API to support a user-specified align size in the future."
   -- Implement the suggestion. Drop ``align_size`` from
   ``_moe_fwd_rule`` / ``_moe_bwd_rule`` / ``_moe`` / public
   ``moe()``; shift the ``custom_vjp`` ``nondiff_argnums`` from
   ``range(9, 27)`` -> ``range(9, 26)``; replace ``effective_align
   = max(int(align_size), 128)`` with the new module-level
   ``_ALIGN_SIZE = 128`` constant. Trim the ``moe()`` docstring
   accordingly. (Review thread on
   "natural_spe = num_ep * max_tokens_per_rank".)

4. "Which axis name inputs are physical mesh axes and why can be
    logical axes? ... No need to make any changes for now, I just
    want to assess which are which and then we can discuss if it
    makes sense to support logical on some/all or if some are
    required to be physical axes."
   -- Add an "Axis-name parameters" section to ``moe()``'s
   docstring listing which kwargs are physical mesh axes
   (``ep_axis``, ``data_parallelism_axes`` -- they index
   ``Mesh.shape`` directly to compute ``num_ep`` / ``dp_size``
   and to construct the ``P((dp..., ep), None, None)`` for
   ``jax.lax.with_sharding_constraint``) vs logical axes
   (``input_axes``, ``gate_kernel_axes``, ``wi_kernel_axes``,
   ``wo_kernel_axes`` -- resolved via the Flax logical-axis
   rules). Also document why ``ep_axis`` / ``data_parallelism_axes``
   are intentionally non-logical: the EP comm-group construction
   (``dp_color = rank // ep_size``) and the bootstrap signature
   check both require concrete integer sizes. (Review thread on
   "batch_pspec_axis = (*data_parallelism_axes, ep_axis)".)

5. "Is this NaN filtering a debugging artifact or something we
    need in the final version?"
   -- Strengthen the inline comment above
   ``sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, ...)``
   to explicitly call this out as a CORRECTNESS REQUIREMENT, not
   a debugging artifact: it covers the sigmoid+K>1 underflow
   path where top-K sigmoid scores all round to zero and the
   ``weights / (weights.sum + 1e-20)`` normalisation emits NaN.
   Observationally the filter is a no-op on the dense unit-test
   distributions, but it must stay in for sparse / production
   routing. (Review thread on
   "sparse_probs = jnp.where(jnp.isnan(sparse_probs), ...).")

Not addressed in this commit (intentional):

* Review thread on the ``align_size: int = 0`` placeholder in
  ``flax/moe.py`` ("Placeholder comment for me to fix this so
  align_size is inferred automatically based on the recipe and
  doesn't need to be specified by the user"). That's
  jberchtold's own follow-up.
* Review thread on the explicit ``tree_flatten`` /
  ``tree_unflatten`` on ``_Ctx`` ("better to use the
  ``@flax_struct.dataclass``"). Deferred to a separate, testable
  commit because changing a ``custom_vjp`` residual's pytree
  registration touches subtle ordering / None-handling semantics
  that warrant their own bisect surface.
* Review thread on ``use_bias`` / ``use_expert_bias`` renames --
  handled in the immediately preceding commit
  ``jax/flax,tests: rename use_bias/use_expert_bias for symmetry``.
* Review thread on the ``expert_bias`` fp32 init -- already
  resolved during the Phuong PR NVIDIA#3036 resync (the redundant
  ``jnp.float32`` second-dtype argument on ``self.param`` was
  dropped; ``expert_bias`` now lives at ``self.dtype``).

Signed-off-by: tdophung <tdophung@nvidia.com>
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 12, 2026
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with
EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied
the three deltas uniquely ours:

  * transformer_engine/jax/moe.py: replaces upstream's multi-backend
    MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted
    to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle
    (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed
    in place of handle, ep_prepare arg order swapped, top_k= dropped
    from ep_dispatch_bwd since it's now in cfg.
  * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with
    ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped
    (no longer supported; ep_size is derived from mesh axes and the
    handle_mem reloc gating is gone).
  * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept).

Pre-sync state preserved at branch
teddy/te_ep_integration.backup-pre-phuong-sync.
EOF
)

Signed-off-by: tdophung <tdophung@nvidia.com>
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 12, 2026
After the upstream PR NVIDIA#3036 resync the moe() API surface lost
PermutationBackend (TE-EP is the only backend now), gate_inside_vjp
(always True), and the per-call quantizer_sets knob (quantization
flows through the standard TE autocast / with_quantizer_set context).
It also gained apply_topk_weights_early and renamed the wrapper's
private _align_size to the public align_size the test suite already
uses. The Flax _MoEBlock wrapper was still passing the old kwargs,
which broke every test that touched the wrapper.

Wrapper changes:
  * drop "from ..moe import PermutationBackend" plus the dataclass
    field, the isinstance(..., PermutationBackend) validation in
    __post_init__, and the pass-through to moe().
  * drop "from ..quantize import noop_quantizer_set" and the
    quantizer_sets=(noop, noop, noop) pass-through.
  * drop gate_inside_vjp=True.
  * rename _align_size: int = 0 -> align_size: int = 0 (matches
    what tests/jax/test_te_ep_moe.py already passes).
  * add apply_topk_weights_early: bool = False and pass it through
    to moe().
  * refresh class docstring: drop permutation_backend / _align_size
    / quantizer_sets descriptions, add apply_topk_weights_early /
    align_size, note that quantization currently flows only through
    fp8_autocast.

Signed-off-by: tdophung <tdophung@nvidia.com>
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 12, 2026
…and inline justifications)

Responds to jberchtold-nvidia's PR NVIDIA#3116 review threads on
``transformer_engine/jax/moe.py``. All changes are confined to a
single file because each review thread targets a localized region
and splitting mid-file would risk reordering bugs.

Per review thread:

1. "Why do we need _with_sharding_constraint_cast_bwd? I haven't
    seen something like this required for our other VJPs."
   -- Expand the helper's docstring to spell out exactly why MoE
   needs it: unlike LN+MLP, the MoE bwd composes a bf16 cotangent
   from ep_dispatch_bwd with an fp32 cotangent from
   fused_topk_with_score_function_bwd (which the fwd's
   logits_2d -> fp32 promotion forces). Without the cast, ``d_x``
   surfaces at fp32 even when ``x`` is bf16, doubling activation
   grad bandwidth and breaking any downstream LN bwd that pins a
   bf16 layout. (Review thread "Why do we need this utility
   function?".)

2. "Why is this dtype casting required? I don't recall us needing
    it for the non-MoE LNMLP block."
   -- Expand the comment above the bwd activation fp32 promotion
   to explain the MoE-specific math: LN+MLP's silu sits behind a
   downstream LN that absorbs the bf16 rounding error, while
   MoE's silu sits on the *expert* side of routing -- the bf16
   rounding rides directly into expert_outputs and is summed
   across topk experts by ep_combine. Bf16 silu alone drifts ~1%
   vs fp32 silu and compounds through wo->combine into the ~1.4%
   per-element parity gap we measured against the pure-JAX
   softmax reference. Mirroring the fwd's fp32 promotion in the
   bwd keeps silu' in lock-step with silu. (Review thread on
   "# Activation bwd. Mirror the fwd's fp32 promotion of
   silu+multiply".)

3. "Do we have a use-case for user-specified alignments beyond
    128 currently? ... it'd make sense to instead hardcode
    _ALIGN_SIZE = 128 as a constant at the top of the file for
    now to simplify this MoEBlock API. We can always expand the
    API to support a user-specified align size in the future."
   -- Implement the suggestion. Drop ``align_size`` from
   ``_moe_fwd_rule`` / ``_moe_bwd_rule`` / ``_moe`` / public
   ``moe()``; shift the ``custom_vjp`` ``nondiff_argnums`` from
   ``range(9, 27)`` -> ``range(9, 26)``; replace ``effective_align
   = max(int(align_size), 128)`` with the new module-level
   ``_ALIGN_SIZE = 128`` constant. Trim the ``moe()`` docstring
   accordingly. (Review thread on
   "natural_spe = num_ep * max_tokens_per_rank".)

4. "Which axis name inputs are physical mesh axes and why can be
    logical axes? ... No need to make any changes for now, I just
    want to assess which are which and then we can discuss if it
    makes sense to support logical on some/all or if some are
    required to be physical axes."
   -- Add an "Axis-name parameters" section to ``moe()``'s
   docstring listing which kwargs are physical mesh axes
   (``ep_axis``, ``data_parallelism_axes`` -- they index
   ``Mesh.shape`` directly to compute ``num_ep`` / ``dp_size``
   and to construct the ``P((dp..., ep), None, None)`` for
   ``jax.lax.with_sharding_constraint``) vs logical axes
   (``input_axes``, ``gate_kernel_axes``, ``wi_kernel_axes``,
   ``wo_kernel_axes`` -- resolved via the Flax logical-axis
   rules). Also document why ``ep_axis`` / ``data_parallelism_axes``
   are intentionally non-logical: the EP comm-group construction
   (``dp_color = rank // ep_size``) and the bootstrap signature
   check both require concrete integer sizes. (Review thread on
   "batch_pspec_axis = (*data_parallelism_axes, ep_axis)".)

5. "Is this NaN filtering a debugging artifact or something we
    need in the final version?"
   -- Strengthen the inline comment above
   ``sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, ...)``
   to explicitly call this out as a CORRECTNESS REQUIREMENT, not
   a debugging artifact: it covers the sigmoid+K>1 underflow
   path where top-K sigmoid scores all round to zero and the
   ``weights / (weights.sum + 1e-20)`` normalisation emits NaN.
   Observationally the filter is a no-op on the dense unit-test
   distributions, but it must stay in for sparse / production
   routing. (Review thread on
   "sparse_probs = jnp.where(jnp.isnan(sparse_probs), ...).")

Not addressed in this commit (intentional):

* Review thread on the ``align_size: int = 0`` placeholder in
  ``flax/moe.py`` ("Placeholder comment for me to fix this so
  align_size is inferred automatically based on the recipe and
  doesn't need to be specified by the user"). That's
  jberchtold's own follow-up.
* Review thread on the explicit ``tree_flatten`` /
  ``tree_unflatten`` on ``_Ctx`` ("better to use the
  ``@flax_struct.dataclass``"). Deferred to a separate, testable
  commit because changing a ``custom_vjp`` residual's pytree
  registration touches subtle ordering / None-handling semantics
  that warrant their own bisect surface.
* Review thread on ``use_bias`` / ``use_expert_bias`` renames --
  handled in the immediately preceding commit
  ``jax/flax,tests: rename use_bias/use_expert_bias for symmetry``.
* Review thread on the ``expert_bias`` fp32 init -- already
  resolved during the Phuong PR NVIDIA#3036 resync (the redundant
  ``jnp.float32`` second-dtype argument on ``self.param`` was
  dropped; ``expert_bias`` now lives at ``self.dtype``).

Signed-off-by: tdophung <tdophung@nvidia.com>
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

pipeline 54608589

dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler);
dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler);

#ifdef NVTE_WITH_NCCL_EP

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the TE core library and TE/JAX may be built separately, e.g. when pip installing from PyPi the TE core is pre-built but TE/JAX is a source dist and rebuilt. In that case, it's possible TE core is built with NVTE_WITH_NCCL_EP support but when installing TE/JAX the user doesn't have this flag set.

Is this intended? If not, would it be better to have TE core expose a nvte_is_nccl_ep_supported() and we call that here instead of using the NVTE_WITH_NCCL_EP macro?



def ep_bootstrap(
world_size,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a description of each arg into the docstring? I think I could infer what "world_size" is but would be good to be explicit

tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 12, 2026
…successor

test_moe_vjp.py and test_multiprocess_moe_vjp.py both import
PermutationBackend from transformer_engine.jax.moe -- an API that
was removed during the Phuong PR NVIDIA#3036 resync. Both files have
been dead-on-import ever since; the multiprocess launcher
run_multiprocess_moe_vjp.sh only points at the dead test.

test_te_ep_moe.py (the TE-EP-only custom_vjp suite) already covers
everything the legacy files exercised that is still meaningful:
fwd, bwd parity vs the pure-JAX reference, aux loss, both score
functions, multi-process. The legacy parametrize axis
(PermutationBackend.PURE_JAX vs TRITON) no longer exists.

* Delete tests/jax/test_moe_vjp.py
* Delete tests/jax/test_multiprocess_moe_vjp.py
* Delete tests/jax/run_multiprocess_moe_vjp.sh
* qa/L0_jax_distributed_unittest/test.sh: switch the MoE VJP
  distributed suite invocation from run_multiprocess_moe_vjp.sh /
  test_multiprocess_moe_vjp.py to run_te_ep_moe.sh /
  test_te_ep_moe.py.
* tests/jax/conftest.py: docstring reference updated.
* tests/jax/test_te_ep_moe.py: drop stale "successor to ..." aside
  and the "mirroring run_multiprocess_moe_vjp.sh" parenthetical.

Net: -981 / +9.

Signed-off-by: tdophung <tdophung@nvidia.com>
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 12, 2026
…erplate

Per reviewer feedback (Jaberchtold on PR NVIDIA#3036): the manual
tree_flatten / tree_unflatten on _Ctx duplicate exactly what
@flax.struct.dataclass auto-generates, and the permutation
dataclasses elsewhere in this module already use flax.struct.

Switching to @flax.struct.dataclass:
* Removes ~75 lines of mechanical tree_flatten / tree_unflatten
  that have to be kept in sync with the field list by hand.
* Keeps cfg as the single static field via
  flax.struct.field(pytree_node=False), so the fwd -> bwd boundary
  behavior under jax.custom_vjp is unchanged.
* Drops two now-unused imports (dataclasses.dataclass,
  jax.tree_util.register_pytree_node_class) and adds flax.struct.

Field order and the (children, aux_data) split are byte-equivalent
to the previous manual implementation, so the pytree treedef seen
by jax.custom_vjp is identical.

Signed-off-by: tdophung <tdophung@nvidia.com>
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

pipeline 54620197

@tdophung tdophung mentioned this pull request Jun 12, 2026
13 tasks
phu0ngng added 16 commits June 23, 2026 02:30
…16 max_token_dtype

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… with_sharding_constraint

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…trap

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…EpLayerConfig type)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ives (lint 10.00)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…; define NVTE_WITH_NCCL_EP

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ract, drop dead helpers

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…e example) jax distributed suites

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ARY_PATH for libnccl_ep.so

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants