[JAX] Expert Parallelism: JAX primitives + VJPs#3036
Conversation
| Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts, | ||
| Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) { | ||
| auto topk_dims = topk_idx.dimensions(); | ||
| NVTE_CHECK(topk_dims.size() >= 2, |
There was a problem hiding this comment.
nit: can we return FFI InvalidArgument instead of a NVTE_CHECK for these inputs?
There was a problem hiding this comment.
This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.
|
I would appreciate your help to review this PR @tdophung @jberchtold-nvidia! |
jberchtold-nvidia
left a comment
There was a problem hiding this comment.
LGTM pending CI
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype field. The C++ backend (ep_backend.cpp:349) enforces typeToSize(tok_dtype) <= typeToSize(max_token_dtype) at every dispatch, and the field is also used at group create to size the NCCL EP staging buffers (ep_backend.cpp:221-222). PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written before this field existed and never set it, so any JAX EP group landed with the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from JAX then failed immediately with: tokens dtype (6) wider than group max_token_dtype (0) This commit threads max_token_dtype end-to-end: - transformer_engine/jax/csrc/extensions.h update SetEpBootstrapParams declaration to match the new arity. - transformer_engine/jax/csrc/extensions/ep.cpp add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams; forward it into NVTEEpGroupConfig in the EpResources ctor. - transformer_engine/jax/csrc/extensions/pybind.cpp add the matching pybind11::arg("max_token_dtype") = 0. - transformer_engine/jax/ep.py add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to NVTEDType int, forward to the C++ setter. Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream. See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
06f8a13 to
c34771d
Compare
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype field. The C++ backend (ep_backend.cpp:349) enforces typeToSize(tok_dtype) <= typeToSize(max_token_dtype) at every dispatch, and the field is also used at group create to size the NCCL EP staging buffers (ep_backend.cpp:221-222). PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written before this field existed and never set it, so any JAX EP group landed with the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from JAX then failed immediately with: tokens dtype (6) wider than group max_token_dtype (0) This commit threads max_token_dtype end-to-end: - transformer_engine/jax/csrc/extensions.h update SetEpBootstrapParams declaration to match the new arity. - transformer_engine/jax/csrc/extensions/ep.cpp add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams; forward it into NVTEEpGroupConfig in the EpResources ctor. - transformer_engine/jax/csrc/extensions/pybind.cpp add the matching pybind11::arg("max_token_dtype") = 0. - transformer_engine/jax/ep.py add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to NVTEDType int, forward to the C++ setter. Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream. See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied the three deltas uniquely ours: * transformer_engine/jax/moe.py: replaces upstream's multi-backend MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed in place of handle, ep_prepare arg order swapped, top_k= dropped from ep_dispatch_bwd since it's now in cfg. * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped (no longer supported; ep_size is derived from mesh axes and the handle_mem reloc gating is gone). * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept). Pre-sync state preserved at branch teddy/te_ep_integration.backup-pre-phuong-sync. EOF )
c34771d to
351b9df
Compare
|
/te-ci JAX L1 |
|
/te-ci JAX L1 |
After the upstream PR NVIDIA#3036 resync the moe() API surface lost PermutationBackend (TE-EP is the only backend now), gate_inside_vjp (always True), and the per-call quantizer_sets knob (quantization flows through the standard TE autocast / with_quantizer_set context). It also gained apply_topk_weights_early and renamed the wrapper's private _align_size to the public align_size the test suite already uses. The Flax _MoEBlock wrapper was still passing the old kwargs, which broke every test that touched the wrapper. Wrapper changes: * drop "from ..moe import PermutationBackend" plus the dataclass field, the isinstance(..., PermutationBackend) validation in __post_init__, and the pass-through to moe(). * drop "from ..quantize import noop_quantizer_set" and the quantizer_sets=(noop, noop, noop) pass-through. * drop gate_inside_vjp=True. * rename _align_size: int = 0 -> align_size: int = 0 (matches what tests/jax/test_te_ep_moe.py already passes). * add apply_topk_weights_early: bool = False and pass it through to moe(). * refresh class docstring: drop permutation_backend / _align_size / quantizer_sets descriptions, add apply_topk_weights_early / align_size, note that quantization currently flows only through fp8_autocast. Signed-off-by: tdophung <tdophung@nvidia.com>
…and inline justifications) Responds to jberchtold-nvidia's PR NVIDIA#3116 review threads on ``transformer_engine/jax/moe.py``. All changes are confined to a single file because each review thread targets a localized region and splitting mid-file would risk reordering bugs. Per review thread: 1. "Why do we need _with_sharding_constraint_cast_bwd? I haven't seen something like this required for our other VJPs." -- Expand the helper's docstring to spell out exactly why MoE needs it: unlike LN+MLP, the MoE bwd composes a bf16 cotangent from ep_dispatch_bwd with an fp32 cotangent from fused_topk_with_score_function_bwd (which the fwd's logits_2d -> fp32 promotion forces). Without the cast, ``d_x`` surfaces at fp32 even when ``x`` is bf16, doubling activation grad bandwidth and breaking any downstream LN bwd that pins a bf16 layout. (Review thread "Why do we need this utility function?".) 2. "Why is this dtype casting required? I don't recall us needing it for the non-MoE LNMLP block." -- Expand the comment above the bwd activation fp32 promotion to explain the MoE-specific math: LN+MLP's silu sits behind a downstream LN that absorbs the bf16 rounding error, while MoE's silu sits on the *expert* side of routing -- the bf16 rounding rides directly into expert_outputs and is summed across topk experts by ep_combine. Bf16 silu alone drifts ~1% vs fp32 silu and compounds through wo->combine into the ~1.4% per-element parity gap we measured against the pure-JAX softmax reference. Mirroring the fwd's fp32 promotion in the bwd keeps silu' in lock-step with silu. (Review thread on "# Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply".) 3. "Do we have a use-case for user-specified alignments beyond 128 currently? ... it'd make sense to instead hardcode _ALIGN_SIZE = 128 as a constant at the top of the file for now to simplify this MoEBlock API. We can always expand the API to support a user-specified align size in the future." -- Implement the suggestion. Drop ``align_size`` from ``_moe_fwd_rule`` / ``_moe_bwd_rule`` / ``_moe`` / public ``moe()``; shift the ``custom_vjp`` ``nondiff_argnums`` from ``range(9, 27)`` -> ``range(9, 26)``; replace ``effective_align = max(int(align_size), 128)`` with the new module-level ``_ALIGN_SIZE = 128`` constant. Trim the ``moe()`` docstring accordingly. (Review thread on "natural_spe = num_ep * max_tokens_per_rank".) 4. "Which axis name inputs are physical mesh axes and why can be logical axes? ... No need to make any changes for now, I just want to assess which are which and then we can discuss if it makes sense to support logical on some/all or if some are required to be physical axes." -- Add an "Axis-name parameters" section to ``moe()``'s docstring listing which kwargs are physical mesh axes (``ep_axis``, ``data_parallelism_axes`` -- they index ``Mesh.shape`` directly to compute ``num_ep`` / ``dp_size`` and to construct the ``P((dp..., ep), None, None)`` for ``jax.lax.with_sharding_constraint``) vs logical axes (``input_axes``, ``gate_kernel_axes``, ``wi_kernel_axes``, ``wo_kernel_axes`` -- resolved via the Flax logical-axis rules). Also document why ``ep_axis`` / ``data_parallelism_axes`` are intentionally non-logical: the EP comm-group construction (``dp_color = rank // ep_size``) and the bootstrap signature check both require concrete integer sizes. (Review thread on "batch_pspec_axis = (*data_parallelism_axes, ep_axis)".) 5. "Is this NaN filtering a debugging artifact or something we need in the final version?" -- Strengthen the inline comment above ``sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, ...)`` to explicitly call this out as a CORRECTNESS REQUIREMENT, not a debugging artifact: it covers the sigmoid+K>1 underflow path where top-K sigmoid scores all round to zero and the ``weights / (weights.sum + 1e-20)`` normalisation emits NaN. Observationally the filter is a no-op on the dense unit-test distributions, but it must stay in for sparse / production routing. (Review thread on "sparse_probs = jnp.where(jnp.isnan(sparse_probs), ...).") Not addressed in this commit (intentional): * Review thread on the ``align_size: int = 0`` placeholder in ``flax/moe.py`` ("Placeholder comment for me to fix this so align_size is inferred automatically based on the recipe and doesn't need to be specified by the user"). That's jberchtold's own follow-up. * Review thread on the explicit ``tree_flatten`` / ``tree_unflatten`` on ``_Ctx`` ("better to use the ``@flax_struct.dataclass``"). Deferred to a separate, testable commit because changing a ``custom_vjp`` residual's pytree registration touches subtle ordering / None-handling semantics that warrant their own bisect surface. * Review thread on ``use_bias`` / ``use_expert_bias`` renames -- handled in the immediately preceding commit ``jax/flax,tests: rename use_bias/use_expert_bias for symmetry``. * Review thread on the ``expert_bias`` fp32 init -- already resolved during the Phuong PR NVIDIA#3036 resync (the redundant ``jnp.float32`` second-dtype argument on ``self.param`` was dropped; ``expert_bias`` now lives at ``self.dtype``). Signed-off-by: tdophung <tdophung@nvidia.com>
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied the three deltas uniquely ours: * transformer_engine/jax/moe.py: replaces upstream's multi-backend MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed in place of handle, ep_prepare arg order swapped, top_k= dropped from ep_dispatch_bwd since it's now in cfg. * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped (no longer supported; ep_size is derived from mesh axes and the handle_mem reloc gating is gone). * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept). Pre-sync state preserved at branch teddy/te_ep_integration.backup-pre-phuong-sync. EOF ) Signed-off-by: tdophung <tdophung@nvidia.com>
After the upstream PR NVIDIA#3036 resync the moe() API surface lost PermutationBackend (TE-EP is the only backend now), gate_inside_vjp (always True), and the per-call quantizer_sets knob (quantization flows through the standard TE autocast / with_quantizer_set context). It also gained apply_topk_weights_early and renamed the wrapper's private _align_size to the public align_size the test suite already uses. The Flax _MoEBlock wrapper was still passing the old kwargs, which broke every test that touched the wrapper. Wrapper changes: * drop "from ..moe import PermutationBackend" plus the dataclass field, the isinstance(..., PermutationBackend) validation in __post_init__, and the pass-through to moe(). * drop "from ..quantize import noop_quantizer_set" and the quantizer_sets=(noop, noop, noop) pass-through. * drop gate_inside_vjp=True. * rename _align_size: int = 0 -> align_size: int = 0 (matches what tests/jax/test_te_ep_moe.py already passes). * add apply_topk_weights_early: bool = False and pass it through to moe(). * refresh class docstring: drop permutation_backend / _align_size / quantizer_sets descriptions, add apply_topk_weights_early / align_size, note that quantization currently flows only through fp8_autocast. Signed-off-by: tdophung <tdophung@nvidia.com>
…and inline justifications) Responds to jberchtold-nvidia's PR NVIDIA#3116 review threads on ``transformer_engine/jax/moe.py``. All changes are confined to a single file because each review thread targets a localized region and splitting mid-file would risk reordering bugs. Per review thread: 1. "Why do we need _with_sharding_constraint_cast_bwd? I haven't seen something like this required for our other VJPs." -- Expand the helper's docstring to spell out exactly why MoE needs it: unlike LN+MLP, the MoE bwd composes a bf16 cotangent from ep_dispatch_bwd with an fp32 cotangent from fused_topk_with_score_function_bwd (which the fwd's logits_2d -> fp32 promotion forces). Without the cast, ``d_x`` surfaces at fp32 even when ``x`` is bf16, doubling activation grad bandwidth and breaking any downstream LN bwd that pins a bf16 layout. (Review thread "Why do we need this utility function?".) 2. "Why is this dtype casting required? I don't recall us needing it for the non-MoE LNMLP block." -- Expand the comment above the bwd activation fp32 promotion to explain the MoE-specific math: LN+MLP's silu sits behind a downstream LN that absorbs the bf16 rounding error, while MoE's silu sits on the *expert* side of routing -- the bf16 rounding rides directly into expert_outputs and is summed across topk experts by ep_combine. Bf16 silu alone drifts ~1% vs fp32 silu and compounds through wo->combine into the ~1.4% per-element parity gap we measured against the pure-JAX softmax reference. Mirroring the fwd's fp32 promotion in the bwd keeps silu' in lock-step with silu. (Review thread on "# Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply".) 3. "Do we have a use-case for user-specified alignments beyond 128 currently? ... it'd make sense to instead hardcode _ALIGN_SIZE = 128 as a constant at the top of the file for now to simplify this MoEBlock API. We can always expand the API to support a user-specified align size in the future." -- Implement the suggestion. Drop ``align_size`` from ``_moe_fwd_rule`` / ``_moe_bwd_rule`` / ``_moe`` / public ``moe()``; shift the ``custom_vjp`` ``nondiff_argnums`` from ``range(9, 27)`` -> ``range(9, 26)``; replace ``effective_align = max(int(align_size), 128)`` with the new module-level ``_ALIGN_SIZE = 128`` constant. Trim the ``moe()`` docstring accordingly. (Review thread on "natural_spe = num_ep * max_tokens_per_rank".) 4. "Which axis name inputs are physical mesh axes and why can be logical axes? ... No need to make any changes for now, I just want to assess which are which and then we can discuss if it makes sense to support logical on some/all or if some are required to be physical axes." -- Add an "Axis-name parameters" section to ``moe()``'s docstring listing which kwargs are physical mesh axes (``ep_axis``, ``data_parallelism_axes`` -- they index ``Mesh.shape`` directly to compute ``num_ep`` / ``dp_size`` and to construct the ``P((dp..., ep), None, None)`` for ``jax.lax.with_sharding_constraint``) vs logical axes (``input_axes``, ``gate_kernel_axes``, ``wi_kernel_axes``, ``wo_kernel_axes`` -- resolved via the Flax logical-axis rules). Also document why ``ep_axis`` / ``data_parallelism_axes`` are intentionally non-logical: the EP comm-group construction (``dp_color = rank // ep_size``) and the bootstrap signature check both require concrete integer sizes. (Review thread on "batch_pspec_axis = (*data_parallelism_axes, ep_axis)".) 5. "Is this NaN filtering a debugging artifact or something we need in the final version?" -- Strengthen the inline comment above ``sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, ...)`` to explicitly call this out as a CORRECTNESS REQUIREMENT, not a debugging artifact: it covers the sigmoid+K>1 underflow path where top-K sigmoid scores all round to zero and the ``weights / (weights.sum + 1e-20)`` normalisation emits NaN. Observationally the filter is a no-op on the dense unit-test distributions, but it must stay in for sparse / production routing. (Review thread on "sparse_probs = jnp.where(jnp.isnan(sparse_probs), ...).") Not addressed in this commit (intentional): * Review thread on the ``align_size: int = 0`` placeholder in ``flax/moe.py`` ("Placeholder comment for me to fix this so align_size is inferred automatically based on the recipe and doesn't need to be specified by the user"). That's jberchtold's own follow-up. * Review thread on the explicit ``tree_flatten`` / ``tree_unflatten`` on ``_Ctx`` ("better to use the ``@flax_struct.dataclass``"). Deferred to a separate, testable commit because changing a ``custom_vjp`` residual's pytree registration touches subtle ordering / None-handling semantics that warrant their own bisect surface. * Review thread on ``use_bias`` / ``use_expert_bias`` renames -- handled in the immediately preceding commit ``jax/flax,tests: rename use_bias/use_expert_bias for symmetry``. * Review thread on the ``expert_bias`` fp32 init -- already resolved during the Phuong PR NVIDIA#3036 resync (the redundant ``jnp.float32`` second-dtype argument on ``self.param`` was dropped; ``expert_bias`` now lives at ``self.dtype``). Signed-off-by: tdophung <tdophung@nvidia.com>
|
pipeline 54608589 |
| dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler); | ||
| dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); | ||
|
|
||
| #ifdef NVTE_WITH_NCCL_EP |
There was a problem hiding this comment.
Since the TE core library and TE/JAX may be built separately, e.g. when pip installing from PyPi the TE core is pre-built but TE/JAX is a source dist and rebuilt. In that case, it's possible TE core is built with NVTE_WITH_NCCL_EP support but when installing TE/JAX the user doesn't have this flag set.
Is this intended? If not, would it be better to have TE core expose a nvte_is_nccl_ep_supported() and we call that here instead of using the NVTE_WITH_NCCL_EP macro?
|
|
||
|
|
||
| def ep_bootstrap( | ||
| world_size, |
There was a problem hiding this comment.
Can we add a description of each arg into the docstring? I think I could infer what "world_size" is but would be good to be explicit
…successor test_moe_vjp.py and test_multiprocess_moe_vjp.py both import PermutationBackend from transformer_engine.jax.moe -- an API that was removed during the Phuong PR NVIDIA#3036 resync. Both files have been dead-on-import ever since; the multiprocess launcher run_multiprocess_moe_vjp.sh only points at the dead test. test_te_ep_moe.py (the TE-EP-only custom_vjp suite) already covers everything the legacy files exercised that is still meaningful: fwd, bwd parity vs the pure-JAX reference, aux loss, both score functions, multi-process. The legacy parametrize axis (PermutationBackend.PURE_JAX vs TRITON) no longer exists. * Delete tests/jax/test_moe_vjp.py * Delete tests/jax/test_multiprocess_moe_vjp.py * Delete tests/jax/run_multiprocess_moe_vjp.sh * qa/L0_jax_distributed_unittest/test.sh: switch the MoE VJP distributed suite invocation from run_multiprocess_moe_vjp.sh / test_multiprocess_moe_vjp.py to run_te_ep_moe.sh / test_te_ep_moe.py. * tests/jax/conftest.py: docstring reference updated. * tests/jax/test_te_ep_moe.py: drop stale "successor to ..." aside and the "mirroring run_multiprocess_moe_vjp.sh" parenthetical. Net: -981 / +9. Signed-off-by: tdophung <tdophung@nvidia.com>
…erplate Per reviewer feedback (Jaberchtold on PR NVIDIA#3036): the manual tree_flatten / tree_unflatten on _Ctx duplicate exactly what @flax.struct.dataclass auto-generates, and the permutation dataclasses elsewhere in this module already use flax.struct. Switching to @flax.struct.dataclass: * Removes ~75 lines of mechanical tree_flatten / tree_unflatten that have to be kept in sync with the field list by hand. * Keeps cfg as the single static field via flax.struct.field(pytree_node=False), so the fwd -> bwd boundary behavior under jax.custom_vjp is unchanged. * Drops two now-unused imports (dataclasses.dataclass, jax.tree_util.register_pytree_node_class) and adds flax.struct. Field order and the (children, aux_data) split are byte-equivalent to the previous manual implementation, so the pytree treedef seen by jax.custom_vjp is identical. Signed-off-by: tdophung <tdophung@nvidia.com>
|
pipeline 54620197 |
…16 max_token_dtype Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… with_sharding_constraint Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…trap Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…EpLayerConfig type) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ives (lint 10.00) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…; define NVTE_WITH_NCCL_EP Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ract, drop dead helpers Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…e example) jax distributed suites Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ARY_PATH for libnccl_ep.so Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
9184ebd to
31fe375
Compare
Summary
Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the
nvte_ep_*C API, a Python wrapper withcustom_vjpfor autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCLncclEpDispatch/ncclEpCombineare exposed as XLA primitives and work with CUDA-graph capture.Implementation
Public Python API (
transformer_engine/jax/ep.py)ep_dispatch/ep_combinearejax.custom_vjpfunctions: forward is the FFI primitive, backward calls the matchingnvte_ep_*_bwdFFI primitive directly (noep_preparein the bwd — routing state is already cached inhandle.mem). Note thatep_dispatchalso callsep_preparein the forward path, which all-gathers and preprocesses routing maps.XLA FFI bindings (
transformer_engine/jax/csrc/extensions/ep.cpp)Five
XLA_FFI_DEFINE_HANDLER_SYMBOLentries —EpPrepareHandler,EpDispatchHandler,EpCombineHandler,EpDispatchBwdHandler,EpCombineBwdHandler— each calling the correspondingnvte_ep_*C entry point. All markedFFI_CudaGraph_Traitsso they capture cleanly.handle_idis a static FFI attribute baked at jit trace time.Primitives + Python layer (
transformer_engine/jax/cpp_extensions/ep.py, +951 lines)Standard TE primitive plumbing:
abstract_eval(shape/dtype inference),lowering,impl,outer_primitiveregistration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).Sharding (
transformer_engine/jax/sharding.py, +12 lines)Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.
Build wiring (
build_tools/jax.py, +41 lines)Threads NCCL EP linkage through the JAX
transformer_engine_jaxextension. No new top-level build flags; rides on the parent PR'sNVTE_BUILD_WITH_NCCL_EP.Tests & example
tests/jax/test_multi_process_ep.py(+690 lines): 13 tests covering bootstrap,ep_prepareshape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing),custom_vjpfwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).tests/jax/multi_process_launch_ep.sh: 4-rank launcher; setsXLA_FLAGSto keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).examples/jax/ep/ep_moe.py(+394 lines) +run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison--checkthat verifies fwd+bwd vs a single-process reference.Type of change
Checklist: