Skip to content

fix(jax): skip padding for non-XLA SavedModels#5602

Open
njzjz wants to merge 2 commits into
deepmodeling:masterfrom
njzjz:fix/jax-padding-xla
Open

fix(jax): skip padding for non-XLA SavedModels#5602
njzjz wants to merge 2 commits into
deepmodeling:masterfrom
njzjz:fix/jax-padding-xla

Conversation

@njzjz

@njzjz njzjz commented Jun 28, 2026

Copy link
Copy Markdown
Member

Summary

  • detect whether a JAX SavedModel contains XlaCallModule during C++ initialization
  • keep the dynamic atom-count padding only for XLA-compiled lower calls
  • pass exact nall_real shapes for non-XLA SavedModels

This is mainly to support #5598: padding only has value for XLA static-shape execution and otherwise changes the non-XLA inference shape unnecessarily.

Tests

  • git diff --check
  • cmake --build source/build --target deepmd_cc -j2
  • ruff check .
  • ruff format --check .

Not run: JAX C++ SavedModel runtime test, because local source/tests/infer/deeppot_dpa.savedmodel is not available.

Summary by CodeRabbit

  • Bug Fixes
    • Improved inference stability and correctness by applying neighbor-list input padding only for models that benefit from it.
    • Updated tensor sizing for neighbor-list computations (including coordinate/atom-type and mapping shapes) to use the appropriate atom count, reducing the risk of shape mismatches and unnecessary overhead.
    • Enhanced detection of whether the loaded computation graph uses XLA-style compilation to drive the padding behavior.

@dosubot dosubot Bot added the bug label Jun 28, 2026
@github-actions github-actions Bot added the C++ label Jun 28, 2026
@coderabbitai

coderabbitai Bot commented Jun 28, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 1ac15463-2f67-4fea-9ce7-1c4d5f2ab2b9

📥 Commits

Reviewing files that changed from the base of the PR and between 92f1b37 and cb22e27.

📒 Files selected for processing (2)
  • source/api_cc/include/DeepPotJAX.h
  • source/api_cc/src/DeepPotJAX.cc

📝 Walkthrough

Walkthrough

DeepPotJAX now records whether loaded TensorFlow artifacts use XLA-style compilation markers, and the neighbor-list compute path uses that flag to decide when to apply atom-count padding and matching tensor shapes.

Changes

Conditional XLA padding in DeepPotJAX

Layer / File(s) Summary
XLA detection helpers and init flag
source/api_cc/include/DeepPotJAX.h, source/api_cc/src/DeepPotJAX.cc
Adds the private XLA-usage flag, helper code that scans graph ops and serialized function bodies for XLA markers, and initialization code that stores the detection result.
Conditional padding in neighbor-list compute
source/api_cc/src/DeepPotJAX.cc
Changes neighbor-list compute to derive nall_model conditionally, resize coord_double and atype to match, and pass the updated size through add_input and mapping allocation.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly matches the main change: conditionally skipping padding for non-XLA JAX SavedModels.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@njzjz njzjz requested a review from wanghan-iapcm June 28, 2026 17:56
@codecov

codecov Bot commented Jun 28, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 63.58382% with 63 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.96%. Comparing base (a9bcbc5) to head (cb22e27).
⚠️ Report is 6 commits behind head on master.

Files with missing lines Patch % Lines
source/api_cc/src/DeepPotJAX.cc 63.58% 41 Missing and 22 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5602      +/-   ##
==========================================
- Coverage   82.35%   81.96%   -0.40%     
==========================================
  Files         896      959      +63     
  Lines      100952   105593    +4641     
  Branches     4059     4104      +45     
==========================================
+ Hits        83138    86547    +3409     
- Misses      16349    17559    +1210     
- Partials     1465     1487      +22     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Comment thread source/api_cc/src/DeepPotJAX.cc Outdated
Comment on lines +70 to +73
inline bool contains_xla_call_module(const std::vector<TF_Function*>& funcs,
TF_Status* status) {
for (TF_Function* func : funcs) {
if (function_def_contains(func, "XlaCallModule", status)) {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking, but worth checking before merge: this XlaCallModule signal looks stale relative to current master. After #5598 (eager TF2 backend) merged, jax2tf native serialization was removed — git grep "jax2tf.convert|XlaCallModule|native_serialization" over deepmd/ on master now returns nothing, and deepmd/jax/jax2tf/serialization.py is a shim that re-exports the eager deepmd/tf2/utils/serialization.py exporter. Both .savedmodel and .savedmodeltf are now produced by that eager tf.function exporter, whose XLA-ness is controlled by jit_compile (defaulting to the DP_JIT env flag), and a @tf.function(jit_compile=True) SavedModel does not emit an XlaCallModule op.

Consequence: no model produced by the current toolchain contains XlaCallModule, so has_xla_call_module_ is false even for a DP_JIT=1 XLA-compiled export — which then skips the padding that #4307 added to avoid per-shape XLA recompilation. The detector effectively only fires for legacy on-disk jax2tf .savedmodel artifacts. Impact is performance-only (recompilation churn, never wrong results) and DP_JIT defaults off, so this isn't a correctness blocker — but the comment/title premise ("XLA models still pad") no longer holds for freshly-exported XLA models. Consider keying the padding decision on the actual XLA-enabling signal (jit_compile/DP_JIT, or a flag recorded in model metadata) rather than the op name.

Secondary point on the mechanism itself: function_def_contains does a raw-bytes substring find() over the serialized FunctionDef to drive control flow, which is fragile — it can false-positive if the literal string appears in any unrelated node/attr name, and false-negative across TF serialization changes. This is the same class of "string comparison for control flow" that was raised and refactored on #4363. Parsing node op fields (or the metadata-flag approach above) would be more robust. (Resource handling here is fine — TF_DeleteBuffer on both paths, length-based search so embedded NULs don't truncate.)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, both points are valid. I updated this PR in cb22e27.

XlaCallModule is not a generic TF2 JIT marker. A minimal tf.function(jit_compile=True) SavedModel serializes _XlaMustCompile=true on the PartitionedCall/FunctionDef, while the graph body can still contain ordinary TF ops. XlaCallModule is the marker for the native jax2tf serialization path.

The loader now detects both cases:

  • top-level graph ops via TensorFlow C API (TF_OperationGetAttrBool for _XlaMustCompile)
  • function attrs via TensorFlow C API (TF_FunctionGetAttrValueProto)
  • function body nodes via a minimal FunctionDef wire-format reader, checking NodeDef.op == "XlaCallModule" and _XlaMustCompile=true

I removed the raw serialized-bytes substring search. I also kept this within the TensorFlow C API boundary: the C API can serialize a TF_Function to FunctionDef, but it does not expose a C accessor to enumerate the FunctionDef NodeDefs, so the small wire-format reader avoids pulling in TensorFlow C++ protobuf headers.

Separately, #5613 restores the intended suffix split: .savedmodel is the JAX/jax2tf SavedModel path, and .savedmodeltf is the TF2 SavedModel path.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving my earlier top-level note here for context:\n\nI checked this more carefully against source/tests/infer/convert-models.sh.

The confusing part was introduced by #5598: the .savedmodel suffix still belongs to the JAX backend registry and the test script still documents it as the JAX/JAX2TF output suffix, but the export implementation had been changed to delegate to the TF2 SavedModel exporter. As a result, freshly exported .savedmodel and .savedmodeltf artifacts had the same ordinary TF op structure and neither contained XlaCallModule.

I opened #5613 to restore the intended split:

  • .savedmodel -> JAX/jax2tf SavedModel, with jax2tf.convert(...) and XlaCallModule nodes
  • .savedmodeltf -> TF2 eager SavedModel exporter

I also checked the DP_JIT=1/TF2 case separately. A minimal tf.function(jit_compile=True) SavedModel in TF 2.21 does not serialize XlaCallModule; it stores _XlaMustCompile: true on the FunctionDef/PartitionedCall while the graph still contains ordinary TF ops. A minimal jax2tf.convert(...) SavedModel does serialize XlaCallModule. So XlaCallModule is a marker for the jax2tf native serialization path, not a generic marker for TF2 jit_compile=True.

With #5613 applied, my local checks show:

  • deeppot_sea.savedmodel: 8 XlaCallModule ops
  • deeppot_dpa.savedmodel: 8 XlaCallModule ops
  • deeppot_sea.savedmodeltf: 0 XLA op names, expected for TF2 eager export

@njzjz

njzjz commented Jun 30, 2026

Copy link
Copy Markdown
Member Author

Moved to wanghan-iapcm's inline review thread: #5602 (comment)

@njzjz njzjz requested a review from wanghan-iapcm June 30, 2026 05:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants