From 7528f71b28f199489168231e10ebf0d889abf820 Mon Sep 17 00:00:00 2001 From: Elisa Tsai Date: Thu, 25 Jun 2026 22:35:50 +0000 Subject: [PATCH] 2D ring (USP) attention with custom splash kernel --- src/maxdiffusion/configs/base_wan_27b.yml | 8 +- .../kernels/custom_splash_attention.py | 127 +++++++- .../splash_attention/ring_attention_kernel.py | 298 ++++++++++++++++++ .../loaders/ltx2_lora_nnx_loader.py | 4 +- src/maxdiffusion/models/attention_flax.py | 297 +++++++++++++++-- 5 files changed, 704 insertions(+), 30 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 19153da19..6f1c5035f 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -83,7 +83,13 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, tokamax_ring_custom, ulysses, ulysses_custom, ulysses_ring, ulysses_ring_custom, ulysses_ring_custom_bidir +# +# Best 2D-ring / USP (Ulysses x ring) configs for WAN2.2-T2V-A14B (720x1280, 81 frames) +# Set attention=ulysses_ring_custom and ulysses_shards=U (ring degree R=CP/U, where CP is the number of context shards; BQ below is the flash_block_sizes block_q): +# CP4 (v7x-8): ulysses_shards=2 (R=2), BQ=9472 +# CP8 (v7x-8): ulysses_shards=4 (R=2), BQ=9472 +# CP16 (v7x-16): ulysses_shards=8 (R=2), BQ=9472 use_base2_exp: True use_experimental_scheduler: True # For attention=ulysses_ring, hidden Ulysses shard count; ring shards are context / this. 
diff --git a/src/maxdiffusion/kernels/custom_splash_attention.py b/src/maxdiffusion/kernels/custom_splash_attention.py index fb50a51a9..ab5b6ee6a 100644 --- a/src/maxdiffusion/kernels/custom_splash_attention.py +++ b/src/maxdiffusion/kernels/custom_splash_attention.py @@ -59,6 +59,8 @@ def _flash_attention_kernel( l_scratch_ref, o_scratch_ref, o_ref, + l_ring_ref=None, + m_ring_ref=None, *, mask_value: float, grid_width: int, @@ -70,6 +72,7 @@ def _flash_attention_kernel( q_seq_len: int, kv_seq_len: int, use_base2_exp: bool = True, + fuse_reciprocal: bool = True, ): float32 = jnp.float32 head_dim_v_repeats, rem = divmod(head_dim_v, NUM_SUBLANES) @@ -192,8 +195,18 @@ def last_body(): @pl.when(j == grid_width - 1) def end(): l = l_scratch_ref[...] - l_inv = jnp.tile(1.0 / l, (head_dim_v_repeats, 1)) - o_ref[...] = (o_scratch_ref[...] * l_inv).astype(o_ref.dtype) + if fuse_reciprocal: + l_inv = jnp.tile(1.0 / l, (head_dim_v_repeats, 1)) + o_ref[...] = (o_scratch_ref[...] * l_inv).astype(o_ref.dtype) + else: + # Ring path: emit the un-normalized numerator plus the running softmax + # stats (max logit `m` and linear denominator `l`) so the outer ring loop + # can merge shard contributions and normalize only once at the very end. + o_ref[...] = o_scratch_ref[...].astype(o_ref.dtype) + if l_ring_ref is not None: + l_ring_ref[...] = l.astype(l_ring_ref.dtype) + if m_ring_ref is not None: + m_ring_ref[...] = m_scratch_ref[...].astype(m_ring_ref.dtype) def _flash_attention_kernel_mhpt( @@ -437,6 +450,116 @@ def v_index_map(h, i, j, *_): return all_out[-1] +def _splash_attention_forward_ring( + q: jax.Array, + k: jax.Array, + v: jax.Array, + block_sizes: _BlockSizes, + bkv_compute_in: int, + q_seq_len: int | None = None, + kv_seq_len: int | None = None, + use_base2_exp: bool = True, + use_experimental_scheduler: bool = False, + vmem_limit_bytes: int | None = None, +): + """Ring-specific forward path that returns pre-reciprocal fp32 accumulators. 
+ + Mirrors `_splash_attention_forward`, but instead of normalizing the output by + the softmax denominator inside the kernel, it returns the un-normalized + numerator (`out`) together with the per-row max logit (`m`) and linear softmax + denominator (`l`). The outer ring loop merges these shard contributions and + normalizes only once at the very end (see + `ring_attention_kernel._custom_ring_attention_forward`). + + Returns: + A tuple `(out, m, l)` where + - `out` has shape `(num_q_heads, q_seq_len, head_dim_v)` (fp32, un-normalized), + - `m` and `l` have shape `(num_q_heads, q_seq_len)` (fp32). + """ + num_q_heads, padded_q_seq_len, head_dim_qk = q.shape + head_dim_v = v.shape[-1] + bq, bkv = block_sizes.block_q, block_sizes.block_kv + bkv_compute = block_sizes.block_kv_compute + num_kv_heads = k.shape[0] + padded_kv_seq_len = k.shape[1] + + actual_q_seq_len = q_seq_len if q_seq_len is not None else padded_q_seq_len + actual_kv_seq_len = kv_seq_len if kv_seq_len is not None else padded_kv_seq_len + q_heads_per_kv_head = num_q_heads // num_kv_heads + + def q_index_map(h, i, j, *_): + return (h, i, 0) + + def out_index_map(h, i, j, *_): + return h, 0, i + + def k_index_map(h, i, j, *_): + return (h // q_heads_per_kv_head, j, 0) + + def v_index_map(h, i, j, *_): + return (h // q_heads_per_kv_head, j, 0) + + in_specs = [ + pl.BlockSpec((None, bq, head_dim_qk), q_index_map), + pl.BlockSpec((None, bkv, head_dim_qk), k_index_map), + pl.BlockSpec((None, bkv, head_dim_v), v_index_map), + ] + out_shapes = [ + jax.ShapeDtypeStruct((NUM_SUBLANES, bq), jnp.float32), + jax.ShapeDtypeStruct((NUM_SUBLANES, bq), jnp.float32), + jax.ShapeDtypeStruct((head_dim_v, bq), jnp.float32), + jax.ShapeDtypeStruct((num_q_heads, head_dim_v, actual_q_seq_len), jnp.float32), + jax.ShapeDtypeStruct((num_q_heads, NUM_SUBLANES, actual_q_seq_len), jnp.float32), + jax.ShapeDtypeStruct((num_q_heads, NUM_SUBLANES, actual_q_seq_len), jnp.float32), + ] + out_specs = [ + pl.BlockSpec((NUM_SUBLANES, bq), 
lambda *_: (0, 0)), + pl.BlockSpec((NUM_SUBLANES, bq), lambda *_: (0, 0)), + pl.BlockSpec((head_dim_v, bq), lambda *_: (0, 0)), + pl.BlockSpec((None, head_dim_v, bq), out_index_map), + pl.BlockSpec((None, NUM_SUBLANES, bq), out_index_map), + pl.BlockSpec((None, NUM_SUBLANES, bq), out_index_map), + ] + grid_width = (actual_kv_seq_len + bkv - 1) // bkv + grid_height = (actual_q_seq_len + bq - 1) // bq + grid = (num_q_heads, grid_height, grid_width) + + all_out = pl.pallas_call( + functools.partial( + _flash_attention_kernel, + mask_value=DEFAULT_MASK_VALUE, + grid_width=grid_width, + bq=bq, + bkv=bkv, + bkv_compute=bkv_compute, + bkv_compute_in=bkv_compute_in, + head_dim_v=head_dim_v, + q_seq_len=actual_q_seq_len, + kv_seq_len=actual_kv_seq_len, + use_base2_exp=use_base2_exp, + fuse_reciprocal=False, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=in_specs, + out_specs=out_specs, + grid=grid, + ), + compiler_params=pltpu.CompilerParams( + dimension_semantics=("parallel", "arbitrary", "arbitrary"), + flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": use_experimental_scheduler}, + disable_bounds_checks=True, + skip_device_barrier=True, + vmem_limit_bytes=vmem_limit_bytes, + ), + out_shape=out_shapes, + )(q, k, v) + out = jnp.swapaxes(all_out[3], 1, 2) # (h, head_dim_v, s) -> (h, s, head_dim_v) + l = all_out[4][:, 0, :] # (h, s) + m = all_out[5][:, 0, :] # (h, s) + return out, m, l + + def _splash_attention_forward_mhpt( q: jax.Array, k: jax.Array, diff --git a/src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py b/src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py index e1e52b794..d8a80b512 100644 --- a/src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py +++ b/src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py @@ -27,6 +27,7 @@ from . import splash_attention_kernel as splash_kernel from . import splash_attention_mask as mask_lib from . 
import splash_attention_mask_info as mask_info_lib +from .. import custom_splash_attention as custom_splash P = jax.P MaskInfo = mask_info_lib.MaskInfo @@ -711,3 +712,300 @@ def make_ring_attention( fwd_mask_sparsity=fwd_mask_sparsity, dkv_mask_sparsity=dkv_mask_sparsity, ) + + +# --------------------------------------------------------------------------- +# Ring attention backed by the custom (head-dim-minor) splash kernel. +# +# This mirrors `_ring_attention_forward` above, but uses the dense custom Pallas +# kernel (`custom_splash_attention`) as the per-shard compute instead of the +# splash kernel. The custom kernel is a FullMask / dense kernel: it does not use +# MaskInfo or in-kernel segment ids (padding is handled by the caller via +# `_pad_data_for_flash` and the `q_seq_len` / `kv_seq_len` bounds), so this ring +# variant drops all of the MaskInfo slicing machinery and is forward-only. +# --------------------------------------------------------------------------- + + +def _custom_bidirectional_ring_forward( + q: jax.Array, + k: jax.Array, + v: jax.Array, + *, + block_sizes: "custom_splash._BlockSizes", + bkv_compute_in: int, + orig_q_seq_len: int, + orig_kv_seq_len: int, + use_base2_exp: bool, + use_experimental_scheduler: bool, + vmem_limit_bytes: int | None, + mask_value: float, + ring_axis: str, +) -> jax.Array: + """Wrap-free (bidirectional) ring attention for a NON-wrapping ring axis. + + On a torus dimension the +1-mod-R ppermute is nearest-neighbor, but on a + cut/non-wrapping axis (e.g. the size-4 z line of a v7x-16 slice) the wrap edge + (R-1 -> 0) spans the whole line diameter. Instead of one rotating stream with + that long, congested wrap, stream K/V BOTH directions one hop at a time: + - rightward stream: device i holds KV_{i-t} after t hops, + - leftward stream: device i holds KV_{i+t} after t hops, + with out-of-range shards (line ends) masked out of the online softmax. 
Every + step is a single hop and uses both link directions; no edge ever spans the + diameter. + + Trade-off: each device computes ~2x attention blocks (the line-end ones are + masked), traded for the removed multi-hop wrap. Net win when the ring is + comms-bound (the case on a non-wrapping axis). Operates on the FULL real ring + axis (no sub-group perm). + """ + axis_size = lax.axis_size(ring_axis) + idx = lax.axis_index(ring_axis) + exp_fn = jnp.exp2 if use_base2_exp else jnp.exp + + def _attn(kc, vc): + o, m, l = custom_splash._splash_attention_forward_ring( # pylint: disable=protected-access + q, + kc, + vc, + block_sizes, + bkv_compute_in, + q_seq_len=orig_q_seq_len, + kv_seq_len=orig_kv_seq_len, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + vmem_limit_bytes=vmem_limit_bytes, + ) + return o.astype(jnp.float32), m.astype(jnp.float32), l.astype(jnp.float32) + + def _merge(m, l, o, mc, lc, oc, valid): + # Nullify invalid (line-end) contributions. Force mc to mask_value (beta -> 0) + # AND zero lc/oc so a non-finite zero-buffer result can't leak via 0*inf=nan. + mc = jnp.where(valid, mc, mask_value) + lc = jnp.where(valid, lc, 0.0) + oc = jnp.where(valid, oc, 0.0) + m_next = jnp.maximum(m, mc) + alpha = exp_fn(m - m_next) + beta = exp_fn(mc - m_next) + return m_next, alpha * l + beta * lc, alpha[..., None] * o + beta[..., None] * oc + + # t=0: own shard (always valid). _attn returns (o, m, l). + o, m, l = _attn(k, v) + + # Non-wrapping one-hop shifts (line ends send/receive nothing). + shift_r = partial(lax.ppermute, axis_name=ring_axis, perm=[(i, i + 1) for i in range(axis_size - 1)]) + shift_l = partial(lax.ppermute, axis_name=ring_axis, perm=[(i, i - 1) for i in range(1, axis_size)]) + + # Prime buffers for t=1 (one hop each direction): device i -> KV_{i-1}, KV_{i+1}. 
+ kr, vr = shift_r(k), shift_r(v) + kl, vl = shift_l(k), shift_l(v) + + def body(carry, t): + m, l, o, kr, vr, kl, vl = carry + valid_r = (idx - t) >= 0 + valid_l = (idx + t) <= (axis_size - 1) + # Feed real (own) K/V on invalid steps so _attn never runs on a degenerate + # zero buffer (line ends receive 0 from the partial ppermute); masked below. + kr_s, vr_s = jnp.where(valid_r, kr, k), jnp.where(valid_r, vr, v) + kl_s, vl_s = jnp.where(valid_l, kl, k), jnp.where(valid_l, vl, v) + # Compute against the current shards (KV_{i-t}, KV_{i+t}) ... + o_r, m_r, l_r = _attn(kr_s, vr_s) + m, l, o = _merge(m, l, o, m_r, l_r, o_r, valid_r) + o_l, m_l, l_l = _attn(kl_s, vl_s) + m, l, o = _merge(m, l, o, m_l, l_l, o_l, valid_l) + # ... and prefetch the next hop (independent of the matmuls above -> overlaps). + kr_n, vr_n = shift_r(kr), shift_r(vr) + kl_n, vl_n = shift_l(kl), shift_l(vl) + return (m, l, o, kr_n, vr_n, kl_n, vl_n), None + + (_, l_final, o_final, *_), _ = lax.scan( + body, + (m, l, o, kr, vr, kl, vl), + xs=jnp.arange(1, axis_size), + length=axis_size - 1, + unroll=True, + ) + + l_inv = jnp.where(l_final == 0.0, 0.0, 1.0 / l_final) + return (o_final * l_inv[..., None]).astype(q.dtype) + + +def _custom_ring_attention_forward( + q: jax.Array, + k: jax.Array, + v: jax.Array, + *, + block_sizes: "custom_splash._BlockSizes", + bkv_compute_in: int, + orig_q_seq_len: int, + orig_kv_seq_len: int, + use_base2_exp: bool, + use_experimental_scheduler: bool, + vmem_limit_bytes: int | None, + mask_value: float, + ring_axis: str, + ring_size: int | None = None, + perm: list[tuple[int, int]] | None = None, + bidirectional: bool = False, +) -> jax.Array: + """Forward-only ring attention using the custom dense splash kernel. + + Args: + q: Query shard, shape `(num_q_heads, q_seq_len, head_dim_qk)`. Stationary + across ring steps. Must already be padded (and pre-scaled by LOG2E when + `use_base2_exp`) by the caller. + k: Key shard, shape `(num_kv_heads, kv_seq_len, head_dim_qk)`. 
Rotated across + the ring axis. + v: Value shard, shape `(num_kv_heads, kv_seq_len, head_dim_v)`. Rotated. + block_sizes: Custom-kernel block sizes (block_q / block_kv / block_kv_compute). + bkv_compute_in: Inner VPU register-tiling step for the custom kernel. + orig_q_seq_len: Un-padded local query length (grid bound). + orig_kv_seq_len: Un-padded local key/value length (grid bound). Assumed equal + across all shards (uniform per-shard padding), matching the + `rotate_segment_ids=False` convention of the tokamax ring path. + use_base2_exp: Whether the kernel uses base-2 exp (must match the LOG2E + pre-scaling applied to `q` by the caller). + use_experimental_scheduler: Forwarded to the custom kernel. + vmem_limit_bytes: Forwarded to the custom kernel. + mask_value: Initial running-max value for the online softmax. + ring_axis: Name of the mesh axis to rotate K/V over (e.g. "context"). + ring_size: Number of ring steps to scan over. Defaults to the full size of + `ring_axis`. For a hybrid Ulysses+Ring (USP) split this is the ring + sub-group size R (< full axis size), so each device only rotates within its + ring sub-group. + perm: Explicit `ppermute` permutation. Defaults to a full-axis +1 rotation. + For the hybrid split, pass a perm that rotates K/V *within each ring + sub-group only* (built by the caller from the U x R factorization). + + Returns: + Normalized attention output, shape `(num_q_heads, q_seq_len, head_dim_v)`. + """ + axis_size = lax.axis_size(ring_axis) + if bidirectional: + if perm is not None or (ring_size is not None and ring_size != axis_size): + raise ValueError( + "bidirectional (wrap-free) ring requires perm=None and ring_size==axis_size " + "(it operates on the full real ring axis)." 
+ ) + return _custom_bidirectional_ring_forward( + q, + k, + v, + block_sizes=block_sizes, + bkv_compute_in=bkv_compute_in, + orig_q_seq_len=orig_q_seq_len, + orig_kv_seq_len=orig_kv_seq_len, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + vmem_limit_bytes=vmem_limit_bytes, + mask_value=mask_value, + ring_axis=ring_axis, + ) + if ring_size is None: + ring_size = axis_size + if perm is None: + perm = [(i, (i + 1) % axis_size) for i in range(axis_size)] + + shift = partial(lax.ppermute, axis_name=ring_axis, perm=perm) + + exp_fn = jnp.exp2 if use_base2_exp else jnp.exp + + num_q_heads = q.shape[0] + head_dim_v = v.shape[-1] + o_init = jnp.zeros((num_q_heads, orig_q_seq_len, head_dim_v), jnp.float32) + l_init = jnp.zeros((num_q_heads, orig_q_seq_len), jnp.float32) + m_init = jnp.full((num_q_heads, orig_q_seq_len), mask_value, jnp.float32) + + def body(carry, i): + m_prev, l_prev, o_prev, k_current, v_current = carry + # Prefetch the next shard while we compute on the current one. 
+ k_next = shift(k_current) + v_next = shift(v_current) + + o_curr, m_curr, l_curr = custom_splash._splash_attention_forward_ring( # pylint: disable=protected-access + q, + k_current, + v_current, + block_sizes, + bkv_compute_in, + q_seq_len=orig_q_seq_len, + kv_seq_len=orig_kv_seq_len, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + vmem_limit_bytes=vmem_limit_bytes, + ) + m_curr = m_curr.astype(jnp.float32) + l_curr = l_curr.astype(jnp.float32) + o_curr = o_curr.astype(jnp.float32) + + m_next = jnp.maximum(m_prev, m_curr) + alpha = exp_fn(m_prev - m_next) + beta = exp_fn(m_curr - m_next) + l_next = alpha * l_prev + beta * l_curr + o_next = alpha[..., None] * o_prev + beta[..., None] * o_curr + return (m_next, l_next, o_next, k_next, v_next), None + + initial_carry = (m_init, l_init, o_init, k, v) + (_, l_final, o_final, _, _), _ = lax.scan( + body, + initial_carry, + xs=jnp.arange(0, ring_size), + length=ring_size, + unroll=True, + ) + + l_inv = jnp.where(l_final == 0.0, 0.0, 1.0 / l_final) + out = (o_final * l_inv[..., None]).astype(q.dtype) + return out + + +def make_custom_ring_attention( + *, + block_sizes: "custom_splash._BlockSizes", + bkv_compute_in: int, + orig_q_seq_len: int, + orig_kv_seq_len: int, + use_base2_exp: bool = True, + use_experimental_scheduler: bool = False, + vmem_limit_bytes: int | None = None, + mask_value: float = base.DEFAULT_MASK_VALUE, + ring_axis: str = "context", + ring_size: int | None = None, + perm: list[tuple[int, int]] | None = None, + bidirectional: bool = False, +): + """Builds a forward-only ring-attention callable around the custom kernel. + + The returned function takes a single (un-batched) `(q, k, v)` triple of shape + `(num_heads, seq, head_dim)` and is meant to be `jax.vmap`-ped over the batch + axis inside the attention `shard_map` (the `ppermute` rotates over `ring_axis`, + which is a mesh axis and independent of the vmap batch axis). 
+ + `ring_size` / `perm` let a caller restrict the rotation to a ring sub-group of + the axis (for the hybrid Ulysses+Ring / USP split); when omitted the rotation + covers the whole `ring_axis`. + + `bidirectional=True` selects the wrap-free schedule (streams K/V both directions + one hop at a time) for a NON-wrapping ring axis, avoiding the diameter-length + wrap hop. Requires `perm=None` and the full real ring axis (no sub-group). + """ + + def _ring(q, k, v): + return _custom_ring_attention_forward( + q, + k, + v, + block_sizes=block_sizes, + bkv_compute_in=bkv_compute_in, + orig_q_seq_len=orig_q_seq_len, + orig_kv_seq_len=orig_kv_seq_len, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + vmem_limit_bytes=vmem_limit_bytes, + mask_value=mask_value, + ring_axis=ring_axis, + ring_size=ring_size, + perm=perm, + bidirectional=bidirectional, + ) + + return _ring diff --git a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py b/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py index 1fad541c6..a3c4d0d38 100644 --- a/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py +++ b/src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py @@ -76,8 +76,6 @@ def translate_fn(nnx_path_str): # the merge_fn warns about unmatched keys in each dict, so we only warn about any leftovers unmatched_keys = set(h_state_dict) - set(transformer_state_dict) - set(connector_state_dict) if unmatched_keys: - max_logging.log( - f"{len(unmatched_keys)} key(s) in LoRA dictionary routed to no merge target: {unmatched_keys}" - ) + max_logging.log(f"{len(unmatched_keys)} key(s) in LoRA dictionary routed to no merge target: {unmatched_keys}") return pipeline diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index edc9f4f7b..f6edc8309 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -347,6 +347,37 @@ def convert_to_tokamax_splash_config( ) +def 
_extract_custom_block_sizes(flash_block_sizes): + """Pulls custom-kernel block sizes out of the (dict or BlockSizes-like) config. + + Mirrors the extraction used by the `ulysses_custom` path so the custom ring + kernel honors the same `flash_block_sizes={...}` knobs; unset or falsy (None/0) entries keep the defaults. + """ + bq = 4864 + bkv = 1024 + bkv_compute = 1024 + bkv_compute_in = 1024 + heads_per_tile = 1 + vmem_limit_bytes = None + if flash_block_sizes is not None: # `or <default>` below: an entry that is present but None/0 must not override the default (same as the previous ulysses_custom extraction) + if isinstance(flash_block_sizes, dict): + get = flash_block_sizes.get + bq = get("block_q") or bq + bkv = get("block_kv") or bkv + bkv_compute = get("block_kv_compute") or bkv_compute + bkv_compute_in = get("block_kv_compute_in") or bkv_compute_in + heads_per_tile = get("heads_per_tile") or heads_per_tile + vmem_limit_bytes = get("vmem_limit_bytes") or vmem_limit_bytes + else: + bq = getattr(flash_block_sizes, "block_q", None) or bq + bkv = getattr(flash_block_sizes, "block_kv", None) or bkv + bkv_compute = getattr(flash_block_sizes, "block_kv_compute", None) or bkv_compute + bkv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", None) or bkv_compute_in + heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", None) or heads_per_tile + vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", None) or vmem_limit_bytes + return bq, bkv, bkv_compute, bkv_compute_in, heads_per_tile, vmem_limit_bytes + + def _build_padding_segment_ids( query_seq_len: int, q_padded_len: int, @@ -418,6 +449,32 @@ def _tpu_flash_attention( check_rep=False, ) def wrap_flash_attention(query, key, value): + if attention_kernel == "tokamax_ring_custom": + # Ring attention backed by the custom dense splash kernel. q stays local, + # k/v rotate over the "context" axis (handled inside the ring kernel). 
+ bq, bkv, bkv_compute, bkv_compute_in, heads_per_tile, vmem_limit_bytes = _extract_custom_block_sizes(flash_block_sizes) + if heads_per_tile > 1: + raise NotImplementedError("tokamax_ring_custom currently supports heads_per_tile == 1 only.") + query_local = query * LOG2E if use_base2_exp else query + query_local, kv_size, query_seq_len = _pad_data_for_flash(query_local, heads, bq) + key_local, _, key_seq_len = _pad_data_for_flash(key, heads, bkv) + value_local, _, _ = _pad_data_for_flash(value, heads, bkv) + + bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute) + ring_kernel = tokamax_ring_attention_kernel.make_custom_ring_attention( + block_sizes=bsizes, + bkv_compute_in=bkv_compute_in, + orig_q_seq_len=query_seq_len, + orig_kv_seq_len=key_seq_len, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + vmem_limit_bytes=vmem_limit_bytes, + ring_axis="context", + ) + vmapped_ring = jax.vmap(ring_kernel, in_axes=(0, 0, 0)) + attention_output = vmapped_ring(query_local, key_local, value_local) + return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) + uses_fused_kernel = block_sizes.use_fused_bwd_kernel block_q_sizes = ( block_sizes.block_q, @@ -626,28 +683,12 @@ def wrap_ulysses_attention(query, key, value): value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) if use_custom_kernel: - bq = 4864 - bkv = 1024 - bkv_compute = 1024 - bkv_compute_in = 1024 - heads_per_tile = 1 - vmem_limit_bytes = None - - if flash_block_sizes is not None: - if isinstance(flash_block_sizes, dict): - bq = flash_block_sizes.get("block_q", None) or bq - bkv = flash_block_sizes.get("block_kv", None) or bkv - bkv_compute = flash_block_sizes.get("block_kv_compute", None) or bkv_compute - bkv_compute_in = flash_block_sizes.get("block_kv_compute_in", None) or bkv_compute_in - heads_per_tile = flash_block_sizes.get("heads_per_tile", None) or heads_per_tile - 
vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", None) or vmem_limit_bytes - else: - bq = getattr(flash_block_sizes, "block_q", None) or bq - bkv = getattr(flash_block_sizes, "block_kv", None) or bkv - bkv_compute = getattr(flash_block_sizes, "block_kv_compute", None) or bkv_compute - bkv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", None) or bkv_compute_in - heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", None) or heads_per_tile - vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", None) or vmem_limit_bytes + if attention_mask is not None: + raise NotImplementedError( + "The custom dense splash kernel (use_custom_kernel) does not support attention_mask " + "(it only handles padding via orig_seq_len); got a non-None attention_mask." + ) + bq, bkv, bkv_compute, bkv_compute_in, heads_per_tile, vmem_limit_bytes = _extract_custom_block_sizes(flash_block_sizes) if use_base2_exp: query = query * LOG2E @@ -883,6 +924,146 @@ def wrap_ulysses_ring_attention(query, key, value): return x +def _ulysses_ring_custom_attention( + query: jax.Array, + key: jax.Array, + value: jax.Array, + heads: int, + mesh: Mesh, + axis_names_q: AxisNames, + axis_names_kv: AxisNames, + flash_block_sizes: BlockSizes, + dtype: jnp.dtype = jnp.float32, + mask_padding_tokens: bool = True, + residual_checkpoint_name: str | None = None, + attention_mask: jax.Array = None, + ulysses_shards: int = -1, + use_base2_exp: bool = True, + use_experimental_scheduler: bool = False, + bidirectional: bool = False, +) -> jax.Array: + """Hybrid Ulysses + Ring (USP) with the CUSTOM splash kernel on main's mesh. + + Uses origin/main's explicit internal `(ring, ulysses)` mesh + (`_create_internal_ulysses_ring_mesh`, commit c104db51) instead of single-axis + collective sub-groups: the public `context` axis is reshaped with the Ulysses + axis innermost, so the Ulysses all-to-all stays INTRA-chip and the ring rotates + ACROSS chips. 
The per-shard attention is our custom splash kernel + (`make_custom_ring_attention`), not the tokamax_ring kernel main uses. + + 1. all-to-all over the (intra-chip) Ulysses axis: trade sequence for heads; + 2. ring (full ppermute) over the (cross-chip) ring axis, online-softmax merge; + 3. all-to-all back to restore the sequence-sharded / full-heads layout. + + U = ulysses_shards (from config); R = context // U. U=context -> pure + Ulysses, U=1 -> pure Ring (all on the same custom kernel). + """ + if attention_mask is not None: + raise NotImplementedError( + "ulysses_ring_custom does not support attention_mask (the custom splash kernels only " + "handle padding via orig_seq_len); got a non-None attention_mask." + ) + axis_name = "context" + num_context_shards = mesh.shape[axis_name] + num_ulysses_shards = ulysses_shards + if num_ulysses_shards <= 0: + raise ValueError("ulysses_ring_custom requires ulysses_shards to be set from config or command line.") + if num_context_shards % num_ulysses_shards != 0: + raise ValueError( + f"ulysses_ring_custom requires ulysses_shards to divide the context shard count, " + f"got context_shards={num_context_shards} and ulysses_shards={num_ulysses_shards}." 
+ ) + num_ring_shards = num_context_shards // num_ulysses_shards + + query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards) + key, _ = _reshape_data_for_flash(key, heads, num_context_shards) + value, _ = _reshape_data_for_flash(value, heads, num_context_shards) + num_heads = query.shape[1] + if num_heads % num_ulysses_shards != 0: + raise ValueError(f"Ulysses+Ring requires heads divisible by U={num_ulysses_shards}, got heads={num_heads}.") + + bq, bkv, bkv_compute, bkv_compute_in, heads_per_tile, vmem_limit_bytes = _extract_custom_block_sizes(flash_block_sizes) + if heads_per_tile > 1: + raise NotImplementedError("ulysses_ring_custom currently supports heads_per_tile == 1 only.") + + internal_mesh = _create_internal_ulysses_ring_mesh(mesh, num_ring_shards, num_ulysses_shards) + ring_axis = INTERNAL_RING_AXIS + ulysses_axis = INTERNAL_ULYSSES_AXIS + + q_axis_names = nn.logical_to_mesh_axes(axis_names_q) + kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) + internal_q_axis_names = _replace_mesh_axis_names(q_axis_names, axis_name, (ring_axis, ulysses_axis)) + internal_kv_axis_names = _replace_mesh_axis_names(kv_axis_names, axis_name, (ring_axis, ulysses_axis)) + + @functools.partial( + jax.shard_map, + mesh=internal_mesh, + in_specs=(internal_q_axis_names, internal_kv_axis_names, internal_kv_axis_names), + out_specs=internal_q_axis_names, + check_vma=False, + ) + def wrap_ulysses_ring_attention(query, key, value): + # (1) Ulysses all-to-all over the (intra-chip) ulysses axis: heads -> sequence, + # so each device holds the full ring-chunk sequence with heads/U heads. 
+ a2a = functools.partial(jax.lax.all_to_all, axis_name=ulysses_axis, tiled=True) + query = a2a(query, split_axis=1, concat_axis=2) + key = a2a(key, split_axis=1, concat_axis=2) + value = a2a(value, split_axis=1, concat_axis=2) + + if use_base2_exp: + query = query * LOG2E + + query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq) + key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv) + value, _, _ = _pad_data_for_flash(value, heads, bkv) + + bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute) + if num_ring_shards == 1: + # (2a) R=1: the ring is trivial (no rotation) -> use the lighter dedicated + # splash kernel (fuse_reciprocal, no fp32 online-softmax residual windows). + # Same math as the 1-step ring, and it fits BQ=8448 where the ring kernel + # OOMs (its 3x residual windows). make_splash_mha returns [H, D, S]. + splash_kernel = custom_splash.make_splash_mha( + block_sizes=bsizes, + bkv_compute_in=bkv_compute_in, + orig_q_seq_len=query_seq_len, + orig_kv_seq_len=key_seq_len, + heads_per_tile=heads_per_tile, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + vmem_limit_bytes=vmem_limit_bytes, + ) + attention_output = jnp.swapaxes(jax.vmap(splash_kernel, in_axes=(0, 0, 0))(query, key, value), 2, 3) + else: + # (2b) Ring (full ppermute over the cross-chip ring axis) with the custom kernel. + # bidirectional=True -> wrap-free schedule (streams K/V both directions one hop + # at a time), for a non-wrapping ring axis. Selected by attention=ulysses_ring_custom_bidir. 
+ ring_kernel = tokamax_ring_attention_kernel.make_custom_ring_attention( + block_sizes=bsizes, + bkv_compute_in=bkv_compute_in, + orig_q_seq_len=query_seq_len, + orig_kv_seq_len=key_seq_len, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + vmem_limit_bytes=vmem_limit_bytes, + ring_axis=ring_axis, + ring_size=num_ring_shards, + bidirectional=bidirectional, + ) + attention_output = jax.vmap(ring_kernel, in_axes=(0, 0, 0))(query, key, value) + attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) + + # (3) Ulysses all-to-all back: sequence -> heads, restoring the layout. + attention_output = a2a(attention_output, split_axis=2, concat_axis=1) + return attention_output + + x = wrap_ulysses_ring_attention(query, key, value) + x = jax.lax.with_sharding_constraint(x, q_axis_names) + x = x[:, :, :orig_q_seq_len, :] + x = _reshape_heads_to_head_dim(x) + return x + + def _apply_attention_dot( query: Array, key: Array, @@ -1029,6 +1210,52 @@ def ulysses_custom_kernel(q, k, v, context): ) +@register_kernel("ulysses_ring_custom") +def ulysses_ring_custom_kernel(q, k, v, context): + return _ulysses_ring_custom_attention( + q, + k * context["scale"], + v, + context["heads"], + context["mesh"], + context["axis_names_q"], + context["axis_names_kv"], + context["flash_block_sizes"], + context["dtype"], + mask_padding_tokens=context["mask_padding_tokens"], + residual_checkpoint_name=context["residual_checkpoint_name"], + attention_mask=context["attention_mask"], + ulysses_shards=context["ulysses_shards"], + use_base2_exp=context.get("use_base2_exp", True), + use_experimental_scheduler=context.get("use_experimental_scheduler", False), + ) + + +@register_kernel("ulysses_ring_custom_bidir") +def ulysses_ring_custom_bidir_kernel(q, k, v, context): + """Wrap-free (bidirectional) variant of ulysses_ring_custom: the ring streams + K/V both directions one hop at a time, avoiding the diameter-length wrap hop + on a 
non-wrapping ring axis. Same USP split as ulysses_ring_custom otherwise.""" + return _ulysses_ring_custom_attention( + q, + k * context["scale"], + v, + context["heads"], + context["mesh"], + context["axis_names_q"], + context["axis_names_kv"], + context["flash_block_sizes"], + context["dtype"], + mask_padding_tokens=context["mask_padding_tokens"], + residual_checkpoint_name=context["residual_checkpoint_name"], + attention_mask=context["attention_mask"], + ulysses_shards=context["ulysses_shards"], + use_base2_exp=context.get("use_base2_exp", True), + use_experimental_scheduler=context.get("use_experimental_scheduler", False), + bidirectional=True, + ) + + @register_kernel("ulysses") def ulysses_kernel(q, k, v, context): return _ulysses_attention( @@ -1128,6 +1355,26 @@ def tokamax_ring_kernel(q, k, v, context): ) +@register_kernel("tokamax_ring_custom") +def tokamax_ring_custom_kernel(q, k, v, context): + return _tpu_flash_attention( + q, + k * context["scale"], + v, + context["heads"], + context["mesh"], + context["axis_names_q"], + context["axis_names_kv"], + context["flash_block_sizes"], + context["dtype"], + attention_kernel="tokamax_ring_custom", + mask_padding_tokens=context["mask_padding_tokens"], + attention_mask=context["attention_mask"], + use_base2_exp=context.get("use_base2_exp", True), + use_experimental_scheduler=context.get("use_experimental_scheduler", False), + ) + + @register_kernel("cudnn_flash_te") def cudnn_flash_te_kernel(q, k, v, context): return _cudnn_flash_attention(q, k, v, context["heads"], context["mesh"], context["dpa_layer"]) @@ -1602,8 +1849,10 @@ def __init__( else: axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV) axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV) - if attention_kernel in ("tokamax_ring", "ulysses_ring") and not is_self_attention: - attention_kernel = "tokamax_flash" + if attention_kernel in ("tokamax_ring", "tokamax_ring_custom", "ulysses_ring") and not is_self_attention: + 
attention_kernel = "tokamax_flash" # do not use ring attention for cross attention + if attention_kernel in ("ulysses_ring_custom", "ulysses_ring_custom_bidir") and not is_self_attention: + attention_kernel = "ulysses_custom" # plain ulysses (no ring) for cross attention self.added_kv_proj_dim = added_kv_proj_dim # New for I2V self.image_seq_len = image_seq_len # New for I2V tpu_type = get_tpu_type()