fix(jax): skip padding for non-XLA SavedModels#5602
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (2)
📝 WalkthroughWalkthrough
ChangesConditional XLA padding in DeepPotJAX
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5602 +/- ##
==========================================
- Coverage 82.35% 81.96% -0.40%
==========================================
Files 896 959 +63
Lines 100952 105593 +4641
Branches 4059 4104 +45
==========================================
+ Hits 83138 86547 +3409
- Misses 16349 17559 +1210
- Partials 1465 1487 +22 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
| inline bool contains_xla_call_module(const std::vector<TF_Function*>& funcs, | ||
| TF_Status* status) { | ||
| for (TF_Function* func : funcs) { | ||
| if (function_def_contains(func, "XlaCallModule", status)) { |
There was a problem hiding this comment.
Non-blocking, but worth checking before merge: this XlaCallModule signal looks stale relative to current master. After #5598 (eager TF2 backend) merged, jax2tf native serialization was removed — git grep "jax2tf.convert|XlaCallModule|native_serialization" over deepmd/ on master now returns nothing, and deepmd/jax/jax2tf/serialization.py is a shim that re-exports the eager deepmd/tf2/utils/serialization.py exporter. Both .savedmodel and .savedmodeltf are now produced by that eager tf.function exporter, whose XLA-ness is controlled by jit_compile (defaulting to the DP_JIT env flag), and a @tf.function(jit_compile=True) SavedModel does not emit an XlaCallModule op.
Consequence: no model produced by the current toolchain contains XlaCallModule, so has_xla_call_module_ is false even for a DP_JIT=1 XLA-compiled export — which then skips the padding that #4307 added to avoid per-shape XLA recompilation. The detector effectively only fires for legacy on-disk jax2tf .savedmodel artifacts. Impact is performance-only (recompilation churn, never wrong results) and DP_JIT defaults off, so this isn't a correctness blocker — but the comment/title premise ("XLA models still pad") no longer holds for freshly-exported XLA models. Consider keying the padding decision on the actual XLA-enabling signal (jit_compile/DP_JIT, or a flag recorded in model metadata) rather than the op name.
Secondary point on the mechanism itself: function_def_contains does a raw-bytes substring find() over the serialized FunctionDef to drive control flow, which is fragile — it can false-positive if the literal string appears in any unrelated node/attr name, and false-negative across TF serialization changes. This is the same class of "string comparison for control flow" that was raised and refactored on #4363. Parsing node op fields (or the metadata-flag approach above) would be more robust. (Resource handling here is fine — TF_DeleteBuffer on both paths, length-based search so embedded NULs don't truncate.)
There was a problem hiding this comment.
Thanks, both points are valid. I updated this PR in cb22e27.
XlaCallModule is not a generic TF2 JIT marker. A minimal tf.function(jit_compile=True) SavedModel serializes _XlaMustCompile=true on the PartitionedCall/FunctionDef, while the graph body can still contain ordinary TF ops. XlaCallModule is the marker for the native jax2tf serialization path.
The loader now detects both cases:
- top-level graph ops via TensorFlow C API (
TF_OperationGetAttrBoolfor_XlaMustCompile) - function attrs via TensorFlow C API (
TF_FunctionGetAttrValueProto) - function body nodes via a minimal FunctionDef wire-format reader, checking
NodeDef.op == "XlaCallModule"and_XlaMustCompile=true
I removed the raw serialized-bytes substring search. I also kept this within the TensorFlow C API boundary: the C API can serialize a TF_Function to FunctionDef, but it does not expose a C accessor to enumerate the FunctionDef NodeDefs, so the small wire-format reader avoids pulling in TensorFlow C++ protobuf headers.
Separately, #5613 restores the intended suffix split: .savedmodel is the JAX/jax2tf SavedModel path, and .savedmodeltf is the TF2 SavedModel path.
There was a problem hiding this comment.
Moving my earlier top-level note here for context:\n\nI checked this more carefully against source/tests/infer/convert-models.sh.
The confusing part was introduced by #5598: the .savedmodel suffix still belongs to the JAX backend registry and the test script still documents it as the JAX/JAX2TF output suffix, but the export implementation had been changed to delegate to the TF2 SavedModel exporter. As a result, freshly exported .savedmodel and .savedmodeltf artifacts had the same ordinary TF op structure and neither contained XlaCallModule.
I opened #5613 to restore the intended split:
.savedmodel-> JAX/jax2tf SavedModel, withjax2tf.convert(...)andXlaCallModulenodes.savedmodeltf-> TF2 eager SavedModel exporter
I also checked the DP_JIT=1/TF2 case separately. A minimal tf.function(jit_compile=True) SavedModel in TF 2.21 does not serialize XlaCallModule; it stores _XlaMustCompile: true on the FunctionDef/PartitionedCall while the graph still contains ordinary TF ops. A minimal jax2tf.convert(...) SavedModel does serialize XlaCallModule. So XlaCallModule is a marker for the jax2tf native serialization path, not a generic marker for TF2 jit_compile=True.
With #5613 applied, my local checks show:
deeppot_sea.savedmodel: 8XlaCallModuleopsdeeppot_dpa.savedmodel: 8XlaCallModuleopsdeeppot_sea.savedmodeltf: 0 XLA op names, expected for TF2 eager export
|
Moved to wanghan-iapcm's inline review thread: #5602 (comment) |
Summary
XlaCallModuleduring C++ initializationnall_realshapes for non-XLA SavedModelsThis is mainly to support #5598: padding only has value for XLA static-shape execution and otherwise changes the non-XLA inference shape unnecessarily.
Tests
git diff --checkcmake --build source/build --target deepmd_cc -j2ruff check .ruff format --check .Not run: JAX C++ SavedModel runtime test, because local
source/tests/infer/deeppot_dpa.savedmodelis not available.Summary by CodeRabbit