Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/maxdiffusion/configs/base_wan_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,13 @@ jit_initializers: True
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, tokamax_ring_custom, ulysses, ulysses_custom, ulysses_ring, ulysses_ring_custom, ulysses_ring_custom_bidir

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we document the best flash block sizes right here in a comment? That would make them easy for anyone to look up.

#
# Best 2D-ring / USP (Ulysses x ring) configs for WAN2.2-T2V-A14B (720x1280, 81 frames)
# Set attention=ulysses_ring_custom and ulysses_shards=U (ring degree R=CP/U):
# CP4 (v7x-8): ulysses_shards=2 (R=2), BQ=9472
# CP8 (v7x-8): ulysses_shards=4 (R=2), BQ=9472
# CP16 (v7x-16): ulysses_shards=8 (R=2), BQ=9472
use_base2_exp: True
use_experimental_scheduler: True
# For attention=ulysses_ring, hidden Ulysses shard count; the ring shard count is the context shard count divided by this value.
Expand Down
127 changes: 125 additions & 2 deletions src/maxdiffusion/kernels/custom_splash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def _flash_attention_kernel(
l_scratch_ref,
o_scratch_ref,
o_ref,
l_ring_ref=None,
m_ring_ref=None,
*,
mask_value: float,
grid_width: int,
Expand All @@ -70,6 +72,7 @@ def _flash_attention_kernel(
q_seq_len: int,
kv_seq_len: int,
use_base2_exp: bool = True,
fuse_reciprocal: bool = True,
):
float32 = jnp.float32
head_dim_v_repeats, rem = divmod(head_dim_v, NUM_SUBLANES)
Expand Down Expand Up @@ -192,8 +195,18 @@ def last_body():
@pl.when(j == grid_width - 1)
def end():
l = l_scratch_ref[...]
l_inv = jnp.tile(1.0 / l, (head_dim_v_repeats, 1))
o_ref[...] = (o_scratch_ref[...] * l_inv).astype(o_ref.dtype)
if fuse_reciprocal:
l_inv = jnp.tile(1.0 / l, (head_dim_v_repeats, 1))
o_ref[...] = (o_scratch_ref[...] * l_inv).astype(o_ref.dtype)
else:
# Ring path: emit the un-normalized numerator plus the running softmax
# stats (max logit `m` and linear denominator `l`) so the outer ring loop
# can merge shard contributions and normalize only once at the very end.
o_ref[...] = o_scratch_ref[...].astype(o_ref.dtype)
if l_ring_ref is not None:
l_ring_ref[...] = l.astype(l_ring_ref.dtype)
if m_ring_ref is not None:
m_ring_ref[...] = m_scratch_ref[...].astype(m_ring_ref.dtype)


def _flash_attention_kernel_mhpt(
Expand Down Expand Up @@ -437,6 +450,116 @@ def v_index_map(h, i, j, *_):
return all_out[-1]


def _splash_attention_forward_ring(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    block_sizes: _BlockSizes,
    bkv_compute_in: int,
    q_seq_len: int | None = None,
    kv_seq_len: int | None = None,
    use_base2_exp: bool = True,
    use_experimental_scheduler: bool = False,
    vmem_limit_bytes: int | None = None,
):
  """Forward pass for the ring path: un-normalized output plus softmax stats.

  Launches `_flash_attention_kernel` with `fuse_reciprocal=False`, so the
  kernel writes the raw accumulator instead of dividing it by the softmax
  denominator. The running max logit (`m`) and the linear denominator (`l`)
  are emitted alongside it, which lets the caller combine the partial results
  of several KV shards and normalize a single time afterwards (the original
  note points at `ring_attention_kernel._custom_ring_attention_forward` as
  that caller).

  Args:
    q: Query array, unpacked as `(num_q_heads, padded_q_seq_len, head_dim_qk)`.
    k: Key array; axis 0 is the KV head count and axis 1 the padded KV length.
    v: Value array; its last axis is `head_dim_v`.
    block_sizes: Supplies `block_q`, `block_kv` and `block_kv_compute`.
    bkv_compute_in: Forwarded unchanged to the kernel.
    q_seq_len: Query length to use; defaults to `q.shape[1]` when `None`.
    kv_seq_len: KV length to use; defaults to `k.shape[1]` when `None`.
    use_base2_exp: Forwarded unchanged to the kernel.
    use_experimental_scheduler: Value passed for the
      `XLA_TPU_FORCE_LP_LLO_SCHEDULER` compiler flag.
    vmem_limit_bytes: Forwarded to the TPU compiler params.

  Returns:
    A tuple `(out, m, l)` where
      - `out` has shape `(num_q_heads, q_seq_len, head_dim_v)` (fp32,
        un-normalized),
      - `m` and `l` have shape `(num_q_heads, q_seq_len)` (fp32).
  """
  n_heads_q, q_len_padded, qk_dim = q.shape
  n_heads_kv, kv_len_padded = k.shape[0], k.shape[1]
  v_dim = v.shape[-1]

  blk_q = block_sizes.block_q
  blk_kv = block_sizes.block_kv
  blk_kv_compute = block_sizes.block_kv_compute

  # Fall back to the array extents when no explicit lengths are supplied.
  q_len = q_len_padded if q_seq_len is None else q_seq_len
  kv_len = kv_len_padded if kv_seq_len is None else kv_seq_len
  group_size = n_heads_q // n_heads_kv

  def _ceil_div(total, block):
    return (total + block - 1) // block

  # Grid axes are (query head, query block, kv block).
  def _q_block(h, i, j, *_):
    return (h, i, 0)

  def _kv_block(h, i, j, *_):
    # Every `group_size` consecutive query heads read the same KV head.
    return (h // group_size, j, 0)

  def _result_block(h, i, j, *_):
    # Results are laid out with the sequence axis last.
    return h, 0, i

  def _first_block(*_):
    return (0, 0)

  in_specs = [
      pl.BlockSpec((None, blk_q, qk_dim), _q_block),
      pl.BlockSpec((None, blk_kv, qk_dim), _kv_block),
      pl.BlockSpec((None, blk_kv, v_dim), _kv_block),
  ]

  # Output order must match the kernel's ref order:
  # m_scratch, l_scratch, o_scratch, o, l_ring, m_ring.
  stat_block = (NUM_SUBLANES, blk_q)
  acc_block = (v_dim, blk_q)
  out_shapes = [
      jax.ShapeDtypeStruct(stat_block, jnp.float32),
      jax.ShapeDtypeStruct(stat_block, jnp.float32),
      jax.ShapeDtypeStruct(acc_block, jnp.float32),
      jax.ShapeDtypeStruct((n_heads_q, v_dim, q_len), jnp.float32),
      jax.ShapeDtypeStruct((n_heads_q, NUM_SUBLANES, q_len), jnp.float32),
      jax.ShapeDtypeStruct((n_heads_q, NUM_SUBLANES, q_len), jnp.float32),
  ]
  out_specs = [
      pl.BlockSpec(stat_block, _first_block),
      pl.BlockSpec(stat_block, _first_block),
      pl.BlockSpec(acc_block, _first_block),
      pl.BlockSpec((None, v_dim, blk_q), _result_block),
      pl.BlockSpec((None, NUM_SUBLANES, blk_q), _result_block),
      pl.BlockSpec((None, NUM_SUBLANES, blk_q), _result_block),
  ]

  kv_blocks = _ceil_div(kv_len, blk_kv)
  q_blocks = _ceil_div(q_len, blk_q)

  kernel = functools.partial(
      _flash_attention_kernel,
      mask_value=DEFAULT_MASK_VALUE,
      grid_width=kv_blocks,
      bq=blk_q,
      bkv=blk_kv,
      bkv_compute=blk_kv_compute,
      bkv_compute_in=bkv_compute_in,
      head_dim_v=v_dim,
      q_seq_len=q_len,
      kv_seq_len=kv_len,
      use_base2_exp=use_base2_exp,
      # Keep the numerator un-normalized and emit `l` / `m` instead.
      fuse_reciprocal=False,
  )
  grid_spec = pltpu.PrefetchScalarGridSpec(
      num_scalar_prefetch=0,
      in_specs=in_specs,
      out_specs=out_specs,
      grid=(n_heads_q, q_blocks, kv_blocks),
  )
  compiler_params = pltpu.CompilerParams(
      dimension_semantics=("parallel", "arbitrary", "arbitrary"),
      flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": use_experimental_scheduler},
      disable_bounds_checks=True,
      skip_device_barrier=True,
      vmem_limit_bytes=vmem_limit_bytes,
  )

  results = pl.pallas_call(
      kernel,
      grid_spec=grid_spec,
      compiler_params=compiler_params,
      out_shape=out_shapes,
  )(q, k, v)

  # The first three results are the kernel's scratch buffers; only the last
  # three are returned.
  numerator, denom, row_max = results[3], results[4], results[5]
  out = jnp.swapaxes(numerator, 1, 2)  # (h, head_dim_v, s) -> (h, s, head_dim_v)
  l = denom[:, 0, :]  # (h, s), taken from sublane row 0
  m = row_max[:, 0, :]  # (h, s), taken from sublane row 0
  return out, m, l


def _splash_attention_forward_mhpt(
q: jax.Array,
k: jax.Array,
Expand Down
Loading
Loading