Skip to content

Avoid unpickling the extra state when not needed#3123

Open
ptrendx wants to merge 5 commits into
NVIDIA:mainfrom
ptrendx:pr_avoid_unpickle
Open

Avoid unpickling the extra state when not needed#3123
ptrendx wants to merge 5 commits into
NVIDIA:mainfrom
ptrendx:pr_avoid_unpickle

Conversation

@ptrendx

@ptrendx ptrendx commented Jun 12, 2026

Copy link
Copy Markdown
Member

Description

Avoids unpickling of the extra state if the recipe is stateless. Adds a guard prompting user to explicitly allow loading of the checkpoint when the unpickling is necessary.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Avoids unpickling of the stateless recipe extra state
  • Adds a guard and environment variable for the delayed scaling recipes

ptrendx added 2 commits June 12, 2026 05:24
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@greptile-apps

greptile-apps Bot commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR avoids unnecessary unpickling of extra state for stateless FP8 recipes and introduces a security guard (NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE) that blocks unsafe pickle loading of delayed-scaling checkpoints by default.

  • _extra_state.py (new) centralises pickle classification: a pickletools-based static analyser assigns each checkpoint payload an IGNORE or UNSAFE_LOAD action without executing the pickle, with STATELESS recipes returning empty tensors on save and STATEFUL_FP8_DELAYED_SCALING requiring the env-var opt-in on load.
  • base.py and op.py are updated to call is_stateless_recipe before serialising and should_load_extra_state_pickle before deserialising; the deprecated BytesIO path in base.py is now also gated behind the env var.
  • Test files are updated to set the env var around load_state_dict calls that involve delayed-scaling recipes.

Confidence Score: 3/5

The security guard and stateless-recipe optimisation are sound, but BasicOperation subclasses using CustomRecipe will silently lose their recipe on checkpoint round-trip.

Both op.py and base.py share a save/load asymmetry where CustomRecipe state is serialised on save but silently discarded on load with no recoverable opt-in path, constituting a quiet regression for users of BasicOperation with custom recipes.

transformer_engine/pytorch/ops/op.py and transformer_engine/pytorch/_extra_state.py — the DYNAMIC-without-delayed-state classification path needs a fix in both the classifier and both get_extra_state implementations.

Important Files Changed

Filename Overview
transformer_engine/pytorch/_extra_state.py New module centralising pickle classification logic; the DYNAMIC-without-delayed-state path returns IGNORE and cannot be overridden even with the opt-in env var, and user-defined Recipe subclasses always fall into UNSAFE_LOAD.
transformer_engine/pytorch/ops/op.py get_extra_state still serializes CustomRecipe (DYNAMIC) state but set_extra_state classifies it as IGNORE and returns early — silent regression for BasicOperation subclasses using CustomRecipe without delayed scaling.
transformer_engine/pytorch/module/base.py Stateless recipes now produce an empty extra-state tensor; set_extra_state is gated behind should_load_extra_state_pickle with the correct env-var guard; deprecated BytesIO path correctly upgraded to require the opt-in.
tests/pytorch/test_recipe.py New unit tests cover stateless ignore, unsafe-load guard, and policy-map coverage; does not test the DYNAMIC-without-delayed-state asymmetry in op.py.
tests/pytorch/test_checkpoint.py Adds env-var opt-in around load_state_dict for FP8 quantization; correctly restores the old env value after the test.
tests/pytorch/test_fusible_ops.py Same env-var guard pattern as test_checkpoint.py; correctly scoped to fp8 and fp8_delayed_scaling quantization modes.
tests/pytorch/test_numerics.py Adds env-var guard before load_state_dict, scoped to delayed-scaling recipes; finally block correctly restores original env state.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant M as Module (base.py / op.py)
    participant ES as _extra_state.py
    participant P as pickle

    Note over M,P: get_extra_state
    M->>ES: is_stateless_recipe(recipe)
    alt STATELESS
        ES-->>M: True → return empty tensor
    else DYNAMIC / STATEFUL
        M->>P: pickle.dumps(state)
        P-->>M: byte tensor
    end

    Note over M,P: set_extra_state
    M->>ES: should_load_extra_state_pickle(bytes, ctx)
    ES->>ES: _classify_extra_state_pickle_impl(bytes)
    alt No recipe key (TE 1.x) OR DelayedScaling OR delayed keys
        ES-->>M: raises RuntimeError (unless env var set)
    else DYNAMIC no delayed keys
        ES-->>M: False (IGNORE — state silently dropped)
    else STATELESS detected
        ES-->>M: False (IGNORE)
    end
    M->>P: pickle.loads(bytes)
    P-->>M: state dict → restore fp8_meta
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant M as Module (base.py / op.py)
    participant ES as _extra_state.py
    participant P as pickle

    Note over M,P: get_extra_state
    M->>ES: is_stateless_recipe(recipe)
    alt STATELESS
        ES-->>M: True → return empty tensor
    else DYNAMIC / STATEFUL
        M->>P: pickle.dumps(state)
        P-->>M: byte tensor
    end

    Note over M,P: set_extra_state
    M->>ES: should_load_extra_state_pickle(bytes, ctx)
    ES->>ES: _classify_extra_state_pickle_impl(bytes)
    alt No recipe key (TE 1.x) OR DelayedScaling OR delayed keys
        ES-->>M: raises RuntimeError (unless env var set)
    else DYNAMIC no delayed keys
        ES-->>M: False (IGNORE — state silently dropped)
    else STATELESS detected
        ES-->>M: False (IGNORE)
    end
    M->>P: pickle.loads(bytes)
    P-->>M: state dict → restore fp8_meta
Loading

Reviews (4): Last reviewed commit: "Remove pytorch changes from common and h..." | Re-trigger Greptile

Comment on lines +37 to +41
_RECIPE_POLICIES = {
(_RECIPE_MODULE, cls.__name__): cls.checkpoint_extra_state_policy
for cls in _recipe_subclasses(Recipe)
if cls.checkpoint_extra_state_policy is not None
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 _RECIPE_POLICIES misses user-defined Recipe subclasses

_RECIPE_POLICIES is computed once at import time and only contains subclasses visible inside transformer_engine.common.recipe. Any Recipe subclass defined in user code (or in a different module) will not appear in this dict, so _classify_extra_state_pickle_impl will find an empty policies set and return UNSAFE_LOAD, forcing the user to set NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 even for a genuinely stateless custom recipe. The test test_checkpoint_extra_state_policy_classifier_map_covers_all_recipes only asserts coverage for first-party recipes and will not catch this gap for downstream users.

@timmoon10 timmoon10 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall this seems like a reasonable fix, although I have some design suggestions and nits. FP8 delayed scaling still has pickling, but at least we can avoid it for more modern recipes.

Comment thread transformer_engine/common/recipe/__init__.py Outdated
Comment thread transformer_engine/pytorch/_extra_state.py Outdated
"""

STATELESS = "stateless"
STATEFUL = "stateful"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may have stateful recipes in the future, but we've learned our lesson not to naively pickle. We should make clear that this particular enum value represents stateful recipes with unsafe pickling.

Suggested change
STATEFUL = "stateful"
STATEFUL_FP8_DELAYED_SCALING = "stateful_fp8_delayed_scaling"

Other possible names could be STATEFUL_PICKLE or STATEFUL_UNSAFE.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Comment on lines +139 to +142
if not policies:
return _PickledExtraStateAction.UNSAFE_LOAD

return _PickledExtraStateAction.IGNORE

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 DYNAMIC without delayed state is silently ignored even with the opt-in env var

The final return _PickledExtraStateAction.IGNORE is reached when policies contains only DYNAMIC (i.e. CustomRecipe) and has_delayed_state_keys is False. should_load_extra_state_pickle short-circuits on IGNORE before ever consulting unsafe_pickle_extra_state_enabled(), so setting NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 has no effect for this case.

A CustomRecipe user whose checkpoint contains state stored under non-standard key names (anything outside _DELAYED_STATE_KEYS) will have that state silently dropped on load with no recoverable opt-in path, even if they know the checkpoint is from a trusted source and explicitly set the env var.

@ksivaman

Copy link
Copy Markdown
Member

/te-ci pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants