[Pytorch][Common] Hybrid quantization#2817
Conversation
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR introduces hybrid (per-direction) quantization for PyTorch: a
Confidence Score: 4/5The PR is broadly safe to merge for the targeted use cases, but one GroupedLinear code path produces incorrect quantized outputs for a factory that ships with this same PR. The transformer_engine/pytorch/module/grouped_linear.py ( Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
HQ["HybridQuantizer\n(rowwise_quantizer, columnwise_quantizer,\ncolumnwise_source)"]
subgraph quantize_impl["quantize_impl(tensor) — non-grouped path"]
RW["rowwise_quantizer.quantize(tensor)\n→ rowwise_storage"]
CS{columnwise_source}
ORIG["columnwise_quantizer.quantize(tensor)\n→ columnwise_storage"]
DEQUANT["rowwise_storage.dequantize()\ncolumnwise_quantizer.quantize(dequantized)\n→ columnwise_storage"]
CS -- original --> ORIG
CS -- rowwise_dequantized --> DEQUANT
RW --> CS
end
subgraph grouped_path["_hybrid_split_quantize — GroupedLinear path"]
GRW["tex.split_quantize(tensor, row_quantizers)\n→ row_results"]
GCW["tex.split_quantize(tensor, col_quantizers)\n← always original tensor\ncolumnwise_source ignored ⚠️"]
GZIP["zip → HybridQuantizedTensorStorage per split"]
GRW --> GZIP
GCW --> GZIP
end
HQ --> quantize_impl
HQ --> grouped_path
GZIP --> GEMM["_unwrap_hybrid_A/B\n→ native sub-storage → C++ GEMM"]
ORIG --> HybridTensor["HybridQuantizedTensor"]
DEQUANT --> HybridTensor
HybridTensor --> GEMM
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
HQ["HybridQuantizer\n(rowwise_quantizer, columnwise_quantizer,\ncolumnwise_source)"]
subgraph quantize_impl["quantize_impl(tensor) — non-grouped path"]
RW["rowwise_quantizer.quantize(tensor)\n→ rowwise_storage"]
CS{columnwise_source}
ORIG["columnwise_quantizer.quantize(tensor)\n→ columnwise_storage"]
DEQUANT["rowwise_storage.dequantize()\ncolumnwise_quantizer.quantize(dequantized)\n→ columnwise_storage"]
CS -- original --> ORIG
CS -- rowwise_dequantized --> DEQUANT
RW --> CS
end
subgraph grouped_path["_hybrid_split_quantize — GroupedLinear path"]
GRW["tex.split_quantize(tensor, row_quantizers)\n→ row_results"]
GCW["tex.split_quantize(tensor, col_quantizers)\n← always original tensor\ncolumnwise_source ignored ⚠️"]
GZIP["zip → HybridQuantizedTensorStorage per split"]
GRW --> GZIP
GCW --> GZIP
end
HQ --> quantize_impl
HQ --> grouped_path
GZIP --> GEMM["_unwrap_hybrid_A/B\n→ native sub-storage → C++ GEMM"]
ORIG --> HybridTensor["HybridQuantizedTensor"]
DEQUANT --> HybridTensor
HybridTensor --> GEMM
Reviews (17): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
timmoon10
left a comment
There was a problem hiding this comment.
Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.
| rowwise_result = self.rowwise_quantizer.quantize(tensor) | ||
| columnwise_result = self.columnwise_quantizer.quantize(tensor) |
There was a problem hiding this comment.
Do we handle the case where not all usages are needed? I'd expect something like:
| rowwise_result = self.rowwise_quantizer.quantize(tensor) | |
| columnwise_result = self.columnwise_quantizer.quantize(tensor) | |
| rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None | |
| columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None |
| requires_grad: bool = False, | ||
| pin_memory: bool = False, | ||
| ) -> HybridQuantizedTensor: | ||
| self.rowwise_quantizer.internal = True |
There was a problem hiding this comment.
Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.
There was a problem hiding this comment.
This would not work under FSDP2.
| def factory(role): | ||
| if role == "linear_weight": | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_fp8_quantizer(), | ||
| columnwise_quantizer=_make_mxfp8_quantizer(), | ||
| ) | ||
| if role == "linear_input": | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_fp8_quantizer(), | ||
| columnwise_quantizer=_make_nvfp4_quantizer(), | ||
| ) | ||
| if role in ("linear_grad_output", "linear_grad_input"): | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_mxfp8_quantizer(), | ||
| columnwise_quantizer=_make_nvfp4_quantizer(), | ||
| ) | ||
| return None |
There was a problem hiding this comment.
This is horrifying. Good test.
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
| # DCP serializes ``CustomRecipe`` via ``pickle``; closure-based qfactories | ||
| # (lambdas, inner functions referencing captured state) are not picklable, | ||
| # so the qfactory must live at module scope. See | ||
| # ``run_fsdp2_fused_adam.py::test_hybrid_dcp_output_parity``. |
There was a problem hiding this comment.
This comment is potentially useful, but I don't think it is in the right place - shouldn't it be closer to the actual implementation?
| for param in model.parameters(): | ||
| state = optimizer.state[param] | ||
| assert state["exp_avg"].dtype == torch.float32 | ||
| assert state["exp_avg_sq"].dtype == torch.float32 | ||
| if "master_param" in state: | ||
| assert state["master_param"].dtype == torch.float32 | ||
|
|
||
| assert losses[-1] < losses[0], f"Loss did not decrease: {losses[0]:.4f} -> {losses[-1]:.4f}" |
There was a problem hiding this comment.
That's not a very strict test, is there a way for us to do some numerical correctness comparisons?
There was a problem hiding this comment.
Enabled check for the monotonic loss decrease (still mostly sanity), and also enabled hybrid vs vanilla bitwise recipe comparizon, see e.g. test_fused_adam_hybrid_vs_base_recipe_parity.
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci pytorch L1 |
Signed-off-by: Evgeny <etsykunov@nvidia.com>
|
/te-ci pytorch L1 |
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
|
/te-ci pytorch L1 |
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>
|
/te-ci pytorch L1 |
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
|
Enable columnwise_source and hybrid recipes Respect quantizer veto for save_original_inp |
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Description
Hybrid (per-direction) quantization. Hybrid means rowwise/colwise can use different formats via CustomRecipe(qfactory).
This is an experimental feature.
The main problem that it tries to solve is that precision requirements are non-uniform.
Current recipes set one format for both rowwise and colwise directions.
Hybrid quantization enables, e.g. MXFP8 fwd and NVFP4 bwd (or vice versa) or any other valid combination. No need for a hardcoded recipe for every combination.
Composer-style (Composer 2 paper) grouped GEMM recipe, e.g. row-scaled NVFP4 fwd + MXFP8 bwd:
By default, the above factory uses
columnwise_source="original", so MXFP8 backward operands are quantized from the original high-precision tensor. Usecolumnwise_source="rowwise_dequantized"when the backward operand should be derived from the dequantized rowwise NVFP4 forward value.C++ optimizations (fusions, etc.) will come as standalone PRs. cc @kainzhong
TODO:
Integration
Ecosystem integration (all functional, unit-tested):
Megatron-LM integration status:
--fp{4,8}-param-gather+ dist opt (persistent low-precision params viaquantized_model_init+ sharded-master FP32 → quantized cast viaquantize_master_weights.)- [Done] Per-tensor Float8 hybrid (delayed and/or current, any per-direction combination
including same-format, cross-format Float8, single-direction)
- [TODO] Per-block hybrid sub-quantizers (MXFP8, NVFP4, Float8Blockwise) — each rejected per-direction by
quantize_master_weights; unblocker is TE-side cast-helper / kernel.--fp{4,8}-param-gather(fix private attribute access)--fp{4,8}-param-gather- [Done] TE-side hybrid FSDP2 path works end-to-end for Float8 / MXFP8 / Float8Blockwise sub-storages (TODO: need some minor MLM update)
- [TODO] NVFP4 sub-storage FSDP2 hooks
_hybrid_split_quantizeunder Megatron MoE)Review
Total diff +14000
New hybrid source (
hybrid_tensor.py,hybrid_tensor_storage.py,identity_tensor.py,identity_tensor_storage.py) ~1800Adjacent modifications ~1500
Tests are the rest (~10K)
Suggested reading order
-columnwise_source controls whether columnwise quantization uses the original input or the rowwise-dequantized value.
1.1 Identity passthrough — b99277a
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: