Avoid unpickling the extra state when not needed#3123
Conversation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Greptile SummaryThis PR avoids unnecessary unpickling of extra state for stateless FP8 recipes and introduces a security guard (
Confidence Score: 3/5The security guard and stateless-recipe optimisation are sound, but Both
Important Files Changed
Sequence Diagram%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant M as Module (base.py / op.py)
participant ES as _extra_state.py
participant P as pickle
Note over M,P: get_extra_state
M->>ES: is_stateless_recipe(recipe)
alt STATELESS
ES-->>M: True → return empty tensor
else DYNAMIC / STATEFUL
M->>P: pickle.dumps(state)
P-->>M: byte tensor
end
Note over M,P: set_extra_state
M->>ES: should_load_extra_state_pickle(bytes, ctx)
ES->>ES: _classify_extra_state_pickle_impl(bytes)
alt No recipe key (TE 1.x) OR DelayedScaling OR delayed keys
ES-->>M: raises RuntimeError (unless env var set)
else DYNAMIC no delayed keys
ES-->>M: False (IGNORE — state silently dropped)
else STATELESS detected
ES-->>M: False (IGNORE)
end
M->>P: pickle.loads(bytes)
P-->>M: state dict → restore fp8_meta
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
participant M as Module (base.py / op.py)
participant ES as _extra_state.py
participant P as pickle
Note over M,P: get_extra_state
M->>ES: is_stateless_recipe(recipe)
alt STATELESS
ES-->>M: True → return empty tensor
else DYNAMIC / STATEFUL
M->>P: pickle.dumps(state)
P-->>M: byte tensor
end
Note over M,P: set_extra_state
M->>ES: should_load_extra_state_pickle(bytes, ctx)
ES->>ES: _classify_extra_state_pickle_impl(bytes)
alt No recipe key (TE 1.x) OR DelayedScaling OR delayed keys
ES-->>M: raises RuntimeError (unless env var set)
else DYNAMIC no delayed keys
ES-->>M: False (IGNORE — state silently dropped)
else STATELESS detected
ES-->>M: False (IGNORE)
end
M->>P: pickle.loads(bytes)
P-->>M: state dict → restore fp8_meta
Reviews (4): Last reviewed commit: "Remove pytorch changes from common and h..." | Re-trigger Greptile |
| _RECIPE_POLICIES = { | ||
| (_RECIPE_MODULE, cls.__name__): cls.checkpoint_extra_state_policy | ||
| for cls in _recipe_subclasses(Recipe) | ||
| if cls.checkpoint_extra_state_policy is not None | ||
| } |
There was a problem hiding this comment.
_RECIPE_POLICIES misses user-defined Recipe subclasses
_RECIPE_POLICIES is computed once at import time and only contains subclasses visible inside transformer_engine.common.recipe. Any Recipe subclass defined in user code (or in a different module) will not appear in this dict, so _classify_extra_state_pickle_impl will find an empty policies set and return UNSAFE_LOAD, forcing the user to set NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 even for a genuinely stateless custom recipe. The test test_checkpoint_extra_state_policy_classifier_map_covers_all_recipes only asserts coverage for first-party recipes and will not catch this gap for downstream users.
timmoon10
left a comment
There was a problem hiding this comment.
Overall this seems like a reasonable fix, although I have some design suggestions and nits. FP8 delayed scaling still has pickling, but at least we can avoid it for more modern recipes.
| """ | ||
|
|
||
| STATELESS = "stateless" | ||
| STATEFUL = "stateful" |
There was a problem hiding this comment.
We may have stateful recipes in the future, but we've learned our lesson not to naively pickle. We should make clear that this particular enum value represents stateful recipes with unsafe pickling.
| STATEFUL = "stateful" | |
| STATEFUL_FP8_DELAYED_SCALING = "stateful_fp8_delayed_scaling" |
Other possible names could be STATEFUL_PICKLE or STATEFUL_UNSAFE.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
| if not policies: | ||
| return _PickledExtraStateAction.UNSAFE_LOAD | ||
|
|
||
| return _PickledExtraStateAction.IGNORE |
There was a problem hiding this comment.
DYNAMIC without delayed state is silently ignored even with the opt-in env var
The final return _PickledExtraStateAction.IGNORE is reached when policies contains only DYNAMIC (i.e. CustomRecipe) and has_delayed_state_keys is False. should_load_extra_state_pickle short-circuits on IGNORE before ever consulting unsafe_pickle_extra_state_enabled(), so setting NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 has no effect for this case.
A CustomRecipe user whose checkpoint contains state stored under non-standard key names (anything outside _DELAYED_STATE_KEYS) will have that state silently dropped on load with no recoverable opt-in path, even if they know the checkpoint is from a trusted source and explicitly set the env var.
Signed-off-by: ksivamani <ksivamani@nvidia.com>
|
/te-ci pytorch |
Description
Avoids unpickling of the extra state if the recipe is stateless. Adds a guard prompting user to explicitly allow loading of the checkpoint when the unpickling is necessary.
Type of change
Changes
Please list the changes introduced in this PR: