Skip to content

[Common/PyTorch] Grouped-quantize kernels for 1D and 2D FP8 block-scaling#3135

Open
denera wants to merge 4 commits into
NVIDIA:mainfrom
denera:common/fp8-block-scaling-grouped-quantize
Open

[Common/PyTorch] Grouped-quantize kernels for 1D and 2D FP8 block-scaling#3135
denera wants to merge 4 commits into
NVIDIA:mainfrom
denera:common/fp8-block-scaling-grouped-quantize

Conversation

@denera

@denera denera commented Jun 17, 2026

Copy link
Copy Markdown
Collaborator

Description

Implements grouped-tensor quantize for the FP8 1D (1x128) and 2D (128x128) block-scaling recipes in row-wise (RW), column-wise (CW) and BOTH quantization directions. A single CUDA kernel launch walks 128x128 tiles across every tensor in the group, with each CTA decoding its owning tensor from the device-side GroupedTensor metadata with (N, R, K) shapes. Supports SAME_BOTH_DIMS (all tensors identical) and VARYING_FIRST_DIM (constant K, varying R) shape representations.

Three kernels share the dispatcher in group_quantize_blockwise_{1d,2d}:

  • group_block_scaled_1d_rw_kernel — RW-only dispatch; 8 threads/row, reads global memory directly into vec-16 registers; bypasses TMA because the shared memory roundtrip and ptx::mbarrier does not buy anything without re-use in CW path.
  • group_block_scaled_1d_tma_kernel — CW-only and BOTH dispatch; TMA bulk-load fills shared memory input cache. BOTH runs RW pass first (8 threads/row, vec-16 read from shared memory) then CW pass (2 threads/column, 64-row register stage); CW-only skips the RW pass. CW path writes the transposed-FP8 tile to a shared memory transpose staging buffer, then drains to global memory.
  • group_block_scaled_2d_tma_kernel — RW-only, CW-only and BOTH dispatch; TMA bulk-load fills shared memory input cache. Pass 1 stages 8 IVecs/thread in registers while computing the per-tile scalar amax. Pass 2 quantizes from registers, emits row-wise output, stages column-wise output to shared memory transpose staging buffer, then drains to global memory.

Kernels are gated to Hopper (sm_90) at the host dispatcher (cuBlasLt grouped GEMM supports FP8 block-scaling only on Hopper).

PR includes PyTorch integration.

JAX integration is intentionally left out-of-scope and deferred to a follow-up PR because it requires non-trivial new scaffolding on the framework side.

Resolves #2525

Performance

Table below measures performance on H200 with a sweep of grouped tensors in (N, M, K) shapes with:

  • N ∈ {4, 8, 16, 32, 64, 128} (# of device-local experts)
  • M = 4096 @ N = 4 —> M = 128 @ N = 128 (# of tokens/expert, scaling inversely with # of experts)
  • K ∈ {1024, 1792, 2048, 3584, 4096, 7168} (device-local shard of TP-hidden/intermediate-FFN dim)

The shapes are split into two buckets:

  • Small/Unsaturated (S): N x M x K <= 32M (< 2048 tiles and < 15 waves on H200's 132 SMs)
  • Large/Saturated (L): N x M x K > 32M (> 2048 tiles with enough work to keep SMs busy across many waves)

Reported kernel times and throughput ratios are bucket medians.

Speedup is measured relative to the split-quantized fallback that loops over the grouped tensor and sequentially quantizes each one.

% of "mono" throughput is measured relative to the throughput of a single non-grouped FP8 block-scaling quantize kernel invoked with the equivalent monolithic (NxM, K) tensor where the # of experts are collapsed with # of tokens/expert.

Bucket Path Grouped (ms) Split (ms) Speedup % memcpy tput % mono tput
S 1D RW 0.028 0.082 4.53× 76.5 % 117.2 %
S 1D CW 0.031 0.089 4.44× 66.1 % 116.9 %
S 1D BOTH 0.044 0.116 4.04× 63.5 % 115.4 %
S 2D RW 0.027 0.075 4.25× 74.2 % 99.7 %
S 2D CW 0.028 0.086 4.74× 72.3 % 128.9 %
S 2D BOTH 0.037 0.088 3.66× 74.5 % 97.6 %
L 1D RW 0.056 0.195 2.24× 88.9 % 119.9 %
L 1D CW 0.065 0.211 2.10× 79.9 % 122.1 %
L 1D BOTH 0.093 0.281 1.94× 74.0 % 118.4 %
L 2D RW 0.056 0.177 2.01× 88.6 % 99.6 %
L 2D CW 0.059 0.211 2.22× 85.8 % 135.0 %
L 2D BOTH 0.078 0.210 1.69× 84.2 % 99.1 %
# experts (N) S bucket L bucket
4 1.67× 1.45×
8 2.51× 1.49×
16 4.34× 1.97×
32 5.66× 2.92×
64 10.08× 6.40×
128 20.18× 9.06×

Notes

  • % of mono throughput is roughly consistent across buckets for every path, confirms no per-expert overhead in the new kernels.
  • Greater than 100% mono throughput cases are due to TMA bulk-loads, register staging and and vec-16 reads missing from the non-grouped FP8 block-scaling kernels.
  • Speedup over split-quantize scales as expected with # of experts (roughly linearly in the unsaturated regime) .

Known Sub-Optimalities

1D CW has bank conflicts on ~35% of load wavefronts (reading from the shared memory input-cache)

  • No possible stride padding or XOR swizzle to alleviate.
  • TMA hardware swizzle with CU_TENSOR_MAP_SWIZZLE_128B has the right pattern but caps FP16/BF16 at 64-elements; does not fit the 128-element tile for FP8 block-scaling without doubling per-tile launch overhead (quadrupling for FP32).
  • Threading restructure shifts bottleneck with no perf gain. Increasing threads/column loses the savings to additional cross-warp amax reduction plus sync. Decreasing to thread/column collapses occupancy to 1 CTA/SM under higher register pressure and shared memory footprint.

1D BOTH reads the shared memory input-cache twice

  • The RW (8 threads/row) and CW (2 threads/column) passes have different threading.
  • Attempted to unify with 8 threads/row for both RW and CW. Caused bank conflicts on ~76% of store wavefronts (writing to the shared memory transpose buffer), reduced to ~43% with a XOR swizzle but not enough to beat separate RW/CW passes.
  • Did not pursue the 2 threads/column unification; costs 40x more shfl ops than 8 threads/row attempt, plus a shared memory partial buffer and sync.

2D CW/BOTH has bank conflicts on ~16% of store wavefronts (when writing to the shared memory transpose buffer)

  • Already reduced from ~75% via a XOR swizzle, further reduction was not possible.
  • Minimal impact (< 5%) on kernel time.

No TMA-store

  • MXFP8 grouped quantize kernel leverages this by decomposing a 128x128 tile into 32-row sub-stages that each have their own independent 32x1 or 1x32 scale; shared memory footprint is based on a single sub-stage; can be quantized and TMA-stored independently; hides TMA-store of one stage under the compute of next stage.
  • FP8 block-scaling 128-element scale-block spans the entire 128-row tile. Cannot decompose into independent sub-stages and pipeline the TMA-stores. Single non-pipelined TMA-store requires holding the transposed staging buffer for the entire tile until all work on tile is finished, blows up shared memory footprint, collapses occupancy to 2CTA/SM. The recipe itself is the roadblock.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Implements grouped-tensor quantize for the FP8 1D (1x128) and 2D (128x128)
block-scaling recipes. A single CUDA kernel launch walks 128x128 tiles
across every tensor in the group, with each CTA decoding its owning
tensor from the device-side GroupedTensor metadata.

Supported shape representations:
  - SAME_BOTH_DIMS (all tensors identical)
  - VARYING_FIRST_DIM (constant K, varying R - the common MoE topology)

Supported directions: rowwise-only, columnwise-only, and both.

These kernels are gated to Hopper (sm_90) at the host dispatcher because
the consumer cuBLAS FP8 block-scaling *grouped* GEMM is itself
Hopper-only (cuBLAS does not provide native FP8 block-scaling grouped
GEMM on Blackwell; the recommended quantization recipe on Blackwell is
MXFP8). The device-side kernel bodies are gated on __CUDA_ARCH__ >= 900
so the kernels compile and link as part of multi-arch builds, but the
host gate prevents launches on Blackwell.

Three kernels share the dispatcher in
group_quantize_blockwise_{1d,2d}:

| Kernel | Dispatched when | Threading | Smem |
|--------|-----------------|-----------|------|
| group_block_scaled_1d_rw_kernel  | 1D RW-only       | 8 threads/row x 32 row-warps x 4 iters; reads gmem directly into vec-16 registers | none |
| group_block_scaled_1d_tma_kernel | 1D CW or 1D BOTH | TMA bulk-load fills 32 KB input cache. BOTH runs RW pass first (8 t/row, vec-16) then CW pass (2 t/col, 64-row register stage); CW-only skips the RW pass. CW writes the transposed-FP8 tile to a 16.5 KB smem_T staging buffer, then drains to gmem. | 32 KB + 16.5 KB |
| group_block_scaled_2d_tma_kernel | 2D RW / CW / BOTH | TMA bulk-load fills 32 KB cache. Pass 1 stages 8 IVecs/thread in registers while computing the per-tile scalar amax. Pass 2 quantizes from registers, emits rowwise output, stages columnwise output to smem_T, then drains. | 32 KB + 16.5 KB |

The RW-only 1D path bypasses TMA because a streaming read has no reuse
- the smem round-trip and mbarrier overhead would just add latency.

The C++ test tests/cpp/operator/test_cast_float8blockwise_grouped.cu
exercises 72 configurations covering RW/CW/BOTH x 1D/2D x SAME/VARYING
shape representations against a per-tensor split-quantize reference.

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera requested review from ptrendx and vthumbe1503 June 17, 2026 13:01
@denera denera self-assigned this Jun 17, 2026
@denera denera added performance Performance issues FP8 MoE labels Jun 17, 2026
constexpr int kThreadsPerBlock = 256;
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;

// Align a dynamic-smem pointer to 128 bytes (TMA requirement).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we reuse the existing align_smem_ptr_per_TMA_requirements() helper from transformer_engine/cast/core/common.h here?

size_t total_row_blocks) {
using namespace transformer_engine::dispatch::mxfp8::swizzle;
const size_t num_tiles_X =
(total_row_blocks + GEMM_SWIZZLED_SCALE_TILE_DIM_X - 1) / GEMM_SWIZZLED_SCALE_TILE_DIM_X;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also reuse the existing DIVUP() helper here (defined in transformer_engin/common/common.h).


// ---- Tensor-lookup helpers ----------------------------------------------------

// Map a global tile-row index to its owning tensor by binary-searching

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also reuse the existing get_current_tensor_id() helper defined in transformer_engine/cast/core/common.cuh

@greptile-apps

greptile-apps Bot commented Jun 17, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds CUDA kernels for grouped FP8 block-scaling quantization (1D and 2D) on Hopper (SM90), plus PyTorch integration. Three kernel variants handle rowwise-only, column-wise-only, and both-directions paths using TMA bulk-loads, register staging, and XOR-swizzled shared-memory transpose buffers.

  • group_quantize_fp8_blockwise.cuh: three kernels share a tile-walking dispatch; each CTA decodes its owning tensor from device-resident GroupedTensor metadata. Both SAME_BOTH_DIMS and VARYING_FIRST_DIM shape representations are supported.
  • ptx.cuh: mbarrier_* helpers are lowered from SM10+ to SM9+; a new cp_async_bulk_tensor_2d_global_to_shared_cta function adds CTA-scoped TMA for non-cluster Hopper launches.
  • quantizer.cpp / cast.cpp: Float8BlockQuantizer::create_grouped_tensor and group_quantize wire the new kernels into the PyTorch path, but with_gemm_swizzled_scales is hardcoded to false, ignoring optimize_for_gemm.

Confidence Score: 4/5

Merge-ready for non-GEMM use cases; callers with optimize_for_gemm=True will silently produce wrong colwise scale layouts for cuBLAS FP8 block-scaling GEMM.

The kernel implementations, TMA setup, smem layouts, and swizzle logic are correct. The one concrete defect is in Float8BlockQuantizer::create_grouped_tensor: with_gemm_swizzled_scales is hardcoded to false, so the kernel's native swizzled-scale path (kSwizzled=true) is never activated through the PyTorch group_quantize API even when the quantizer has optimize_for_gemm=True. Since cuBLAS FP8 block-scaling GEMM is the declared primary consumer of this feature, the wrong scale layout would cause incorrect GEMM results without any error signal.

transformer_engine/pytorch/csrc/quantizer.cpp (line 1153) and transformer_engine/pytorch/csrc/extensions/cast.cpp need to propagate optimize_for_gemm into with_gemm_swizzled_scales on the grouped output tensor before the kernel is launched.

Important Files Changed

Filename Overview
transformer_engine/common/cast/fp8_blockwise/group_quantize_fp8_blockwise.cuh Core implementation: three CUDA kernels (1D RW, 1D TMA, 2D TMA) and host dispatchers for grouped FP8 block-scaling quantize. TMA setup, smem layout, XOR swizzle for bank-conflict reduction, and scale indexing all look correct.
transformer_engine/pytorch/csrc/quantizer.cpp New Float8BlockQuantizer::create_grouped_tensor hardcodes with_gemm_swizzled_scales=false, ignoring this->optimize_for_gemm. Callers with optimize_for_gemm=True will receive non-swizzled colwise scales, silently breaking downstream cuBLAS FP8 block-scaling GEMM correctness.
transformer_engine/common/util/ptx.cuh Lowers mbarrier_* and related PTX helpers from SM 10.0 to SM 9.0; adds new cp_async_bulk_tensor_2d_global_to_shared_cta variant for non-cluster TMA on Hopper. PTX instructions are valid on SM 9.0+; existing SM 10.0 callers unaffected.
tests/cpp/operator/test_cast_float8blockwise_grouped.cu New test covers SAME_BOTH_DIMS and VARYING_FIRST_DIM shapes, all three scaling directions, and both 1D/2D block modes with swizzled-scale variants; only uses force_pow_2_scales=true, leaving the non-pow2 scale computation path untested.
transformer_engine/common/cast/dispatch/quantize.cuh Wires the new 1D and 2D grouped block-scaling dispatchers into group_quantize_fwd_helper; IS_ACT guard added for both modes. Straightforward integration.
transformer_engine/pytorch/csrc/extensions/cast.cpp Adds FP8_BLOCKWISE_GROUPED_QUANTIZE mode to group_quantize, correctly propagating force_pow_2_scales and amax_epsilon. Missing propagation of optimize_for_gemm → with_gemm_swizzled_scales (root is in quantizer.cpp).

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["group_quantize() [cast.cpp]"] --> B{Quantizer type}
    B -->|Float8Blockwise| C["create_grouped_tensor()\n[quantizer.cpp]\nwith_gemm_swizzled_scales=false ⚠️"]
    C --> D["group_quantize_fwd_helper()\n[dispatch/quantize.cuh]"]
    D -->|BLOCK_SCALING_1D| E["group_quantize_blockwise_1d()"]
    D -->|BLOCK_SCALING_2D| F["group_quantize_blockwise_2d()"]
    E --> G{use_rowwise only?}
    G -->|yes| H["group_block_scaled_1d_rw_kernel\n8 t/row, vec-16 global reads"]
    G -->|no| I["group_block_scaled_1d_tma_kernel\nTMA → smem_in, RW+CW passes"]
    F --> J["group_block_scaled_2d_tma_kernel\nTMA → smem_in, reg-staged amax,\nquantize + drain smem_T"]
    I --> K["kSwizzled = output->with_gemm_swizzled_scales\n(always false via PyTorch path ⚠️)"]
    J --> K
    K -->|false| M["Non-swizzled CW scales"]
    K -->|true| N["GEMM-swizzled CW scales\nfor cuBLAS FP8 GEMM"]
    style C fill:#ffcccc
    style K fill:#ffcccc
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A["group_quantize() [cast.cpp]"] --> B{Quantizer type}
    B -->|Float8Blockwise| C["create_grouped_tensor()\n[quantizer.cpp]\nwith_gemm_swizzled_scales=false ⚠️"]
    C --> D["group_quantize_fwd_helper()\n[dispatch/quantize.cuh]"]
    D -->|BLOCK_SCALING_1D| E["group_quantize_blockwise_1d()"]
    D -->|BLOCK_SCALING_2D| F["group_quantize_blockwise_2d()"]
    E --> G{use_rowwise only?}
    G -->|yes| H["group_block_scaled_1d_rw_kernel\n8 t/row, vec-16 global reads"]
    G -->|no| I["group_block_scaled_1d_tma_kernel\nTMA → smem_in, RW+CW passes"]
    F --> J["group_block_scaled_2d_tma_kernel\nTMA → smem_in, reg-staged amax,\nquantize + drain smem_T"]
    I --> K["kSwizzled = output->with_gemm_swizzled_scales\n(always false via PyTorch path ⚠️)"]
    J --> K
    K -->|false| M["Non-swizzled CW scales"]
    K -->|true| N["GEMM-swizzled CW scales\nfor cuBLAS FP8 GEMM"]
    style C fill:#ffcccc
    style K fill:#ffcccc
Loading

Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +655 to +665
} else {
info.common_first_dim_blocks = 0;
info.R_total = output->logical_shape.data[0];
info.tensor_offsets_d = reinterpret_cast<const int64_t*>(output->tensor_offsets.dptr);
NVTE_CHECK(info.tensor_offsets_d != nullptr,
"VARYING_FIRST_DIM requires tensor_offsets to be set on the GroupedTensor.");
}
info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim;
info.blocks_X = (info.K + kTileDim - 1) / kTileDim;
return info;
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 VARYING_FIRST_DIM path silently requires 128-aligned per-tensor first dims

The SAME_BOTH_DIMS branch (line 651) enforces common_first_dim % kTileDim == 0, but the VARYING_FIRST_DIM branch has no equivalent check. The kernel's correctness depends entirely on this alignment: tensor_block_y_base_from_offsets divides element offsets by kTileDim * K using integer truncation, and tensor_row_blocks is derived the same way. A tensor with first_dim = 192 (not a multiple of 128) would produce tensor_row_blocks = 1 instead of 2, causing the second 64-row slice (rows 128–191) to be silently skipped by the in-kernel bounds guard and left un-quantized. The offsets are device-resident so host validation isn't straightforward, but a prominent NVTE_CHECK comment or a note in the function contract would prevent silent data loss from callers with unexpected shapes.

Comment on lines +747 to +750
NVTE_CHECK(sm >= 90 && sm < 100,
"Grouped FP8 block-scaling quantize is only supported on Hopper (SM90); "
"use MXFP8 on Blackwell (SM100) or newer. Got SM",
sm, ".");

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The error message in group_quantize_blockwise_1d says "SM90" while the identical check in group_quantize_blockwise_2d correctly says "SM90-SM99". The condition sm >= 90 && sm < 100 covers the full Hopper range, so the 1D message is misleading.

Suggested change
NVTE_CHECK(sm >= 90 && sm < 100,
"Grouped FP8 block-scaling quantize is only supported on Hopper (SM90); "
"use MXFP8 on Blackwell (SM100) or newer. Got SM",
sm, ".");
NVTE_CHECK(sm >= 90 && sm < 100,
"Grouped FP8 block-scaling quantize is only supported on Hopper (SM90-SM99); "
"use MXFP8 on Blackwell (SM100) or newer. Got SM",
sm, ".");

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +325 to +355
if (first_dims_d) cudaFree(first_dims_d);
}

struct TestConfig {
ShapeRep shape_rep;
BlockDim block_dim;
ScalingDir dir;
std::vector<size_t> first_dims;
size_t K;
};

class GroupedFP8BlockwiseTestSuite : public ::testing::TestWithParam<TestConfig> {};

TEST_P(GroupedFP8BlockwiseTestSuite, Test) {
const TestConfig& cfg = GetParam();
perform_test<bf16, fp8e4m3>(cfg.shape_rep, cfg.block_dim, cfg.dir, cfg.first_dims, cfg.K,
/*force_pow_2_scales=*/true, /*epsilon=*/0.0f);
}

std::vector<TestConfig> make_configs() {
std::vector<TestConfig> configs;
std::vector<std::vector<size_t>> uniform = {{128, 128}, {256, 256, 256, 256}};
std::vector<std::vector<size_t>> jagged = {
{128, 256, 384, 512}, {256, 128, 512, 384, 1024}};
std::vector<size_t> Ks = {128, 256, 512};
for (auto bd : {BlockDim::ONE_D, BlockDim::TWO_D}) {
for (auto dir : {ScalingDir::ROWWISE, ScalingDir::COLWISE, ScalingDir::BOTH}) {
for (size_t K : Ks) {
for (const auto& v : uniform) {
configs.push_back({ShapeRep::SAME_BOTH_DIMS, bd, dir, v, K});
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Swizzled-scale path (with_gemm_swizzled_scales=true) is not exercised

The host dispatchers plumb output->with_gemm_swizzled_scales into both the 1D and 2D kernels (the kSwizzled template parameter), and the swizzled-scale indexing in swizzled_colwise_scale_idx is a separate non-trivial code path. Neither make_configs() nor any test fixture sets this flag, so the swizzled layout is never compared against a reference. Since cuBLAS FP8 block-scaling GEMM is the primary consumer of the swizzled layout, a bug there would be invisible until GEMM produces wrong results.


// ---- TMA async load of the input tile ----
if (leading_thread) {
ptx::mbarrier_init(&tma_mbar, 1);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since mbar resides in shared memory, a cross-proxy fence between the async and generic proxies needs to be issued here before __syncthreads() so that both the TMA engine and the threads observe mbar in the correct state. We can use ptx::fence_proxy_async_shared_cta() defined in transformer_engine/common/util/ptx.cuh.

}

CType amax = compute_row_amax<IType, CType, kVec>(in_vec[it]);
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 1));

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we reuse the existing amax warp-reduction helpers (warp_reduce_max() or reduce_max()) from transformer_engine/common/utils.cuh here?

// ---- TMA async load of the input tile ----
if (leading_thread) {
ptx::mbarrier_init(&tma_mbar, 1);
}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the above:

Suggested change
}
ptx::fence_proxy_async_shared_cta();
}

Comment on lines +535 to +537
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 1));
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 2));
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 4));

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also reuse reduce_max() or warp_reduce_max() here.


// ----- Host-side dispatchers --------------------------------------------------------------------

inline size_t align_up_to(size_t x, size_t a) { return ((x + a - 1) / a) * a; }

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can reuse DIVUP_TO_MULTIPLE() defined in transformer_engine/common/common.h.

NVTE_CHECK(info.tensor_offsets_d != nullptr,
"VARYING_FIRST_DIM requires tensor_offsets to be set on the GroupedTensor.");
}
info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim;
info.total_row_blocks = DIVUP(info.R_total, kTileDim);

"VARYING_FIRST_DIM requires tensor_offsets to be set on the GroupedTensor.");
}
info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim;
info.blocks_X = (info.K + kTileDim - 1) / kTileDim;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
info.blocks_X = (info.K + kTileDim - 1) / kTileDim;
info.blocks_X = DIVUP(info.K, kTileDim);

info.same_both_dims = same_both_dims;
info.num_tensors = output->num_tensors;
info.K = output->get_common_last_dim();
NVTE_CHECK(info.K % 16 == 0, "Last dim must be multiple of 16 (FP8 alignment).");

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is a TMA requirement, we can use the TMA_GMEM_ALIGNMENT constant defined in transformer_engine/common/common.h

const float* noop_ptr =
(noop != nullptr) ? reinterpret_cast<const float*>(noop->data.dptr) : nullptr;

const size_t scale_stride_y = align_up_to(info.blocks_X, 4);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const size_t scale_stride_y = align_up_to(info.blocks_X, 4);
const size_t scale_stride_y = DIVUP_TO_MULTIPLE(info.blocks_X, 4);

const size_t scale_stride_y = align_up_to(info.blocks_X, 4);
// CW scales are stored [blocks_X, align4(total_row_blocks)] -- transposed to
// match the physically-transposed columnwise data the TN cuBLAS GEMM consumes.
const size_t scale_t_stride_y = align_up_to(info.total_row_blocks, 4);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const size_t scale_t_stride_y = align_up_to(info.total_row_blocks, 4);
const size_t scale_t_stride_y = DIVUP_TO_MULTIPLE(info.total_row_blocks, 4);

const float* noop_ptr =
(noop != nullptr) ? reinterpret_cast<const float*>(noop->data.dptr) : nullptr;

const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4);
const size_t scale_stride_aligned_R = DIVUP_TO_MULTIPLE(info.R_total, 4);

(noop != nullptr) ? reinterpret_cast<const float*>(noop->data.dptr) : nullptr;

const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4);
const size_t scale_t_stride_aligned_K = align_up_to(info.K, 4);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const size_t scale_t_stride_aligned_K = align_up_to(info.K, 4);
const size_t scale_t_stride_aligned_K = DIVUP_TO_MULTIPLE(info.K, 4);

// ---- TMA async load of the input tile ----
if (leading_thread) {
ptx::mbarrier_init(&tma_mbar, 1);
}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
}
ptx::fence_proxy_async_shared_cta();
}

denera and others added 2 commits June 22, 2026 22:49
- Reuse shared helpers (DIVUP, DIVUP_TO_MULTIPLE, TMA_GMEM_ALIGNMENT,
  align_smem_ptr_per_TMA_requirements, get_current_tensor_id,
  subwarp_reduce_max_broadcast) in place of local equivalents.
- Add proxy-async fence after mbarrier_init in 2D + 1D TMA kernels.
- Enforce per-tensor first_dim % 128 device-side for VARYING_FIRST_DIM
  (matches MXFP8 grouped quantize behavior).
- Fix Hopper SM range wording in 1D dispatcher.
- Extend cpp tests to cover with_gemm_swizzled_scales path.

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera requested a review from Oleg-Goncharov June 22, 2026 23:06
// num_tiles_X = DIVUP(total_row_blocks, TILE_DIM_X=4)
__device__ __forceinline__ size_t swizzled_colwise_scale_idx(size_t i, size_t j,
size_t total_row_blocks) {
using namespace transformer_engine::dispatch::mxfp8::swizzle;

@vthumbe1503 vthumbe1503 Jun 22, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should rename the namespace for swizzle...given that we use the same constants for mxfp8, nvfp4, fp8 block scaling

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

FP8 MoE performance Performance issues

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Blockwise (1x128 and 128x128) FP8 grouped quantization

3 participants