Skip to content

feat(jax): freeze models with hessian output#5608

Open
njzjz wants to merge 1 commit into
deepmodeling:masterfrom
njzjz:feat/jax-hessian-freeze
Open

feat(jax): freeze models with hessian output#5608
njzjz wants to merge 1 commit into
deepmodeling:masterfrom
njzjz:feat/jax-hessian-freeze

Conversation

@njzjz

@njzjz njzjz commented Jun 29, 2026

Copy link
Copy Markdown
Member

Summary

  • add dp freeze --hessian support for the JAX backend
  • propagate the Hessian flag through JAX .jax, .hlo, and .savedmodel serialization
  • mark HLO energy output definitions as Hessian-enabled and request Hessian outputs during JAX inference

Tests

  • source venv/bin/activate && pytest source/tests/jax/test_training.py::TestJAXTraining::test_freeze_entrypoint_uses_checkpoint_pointer source/tests/jax/test_training.py::TestJAXTraining::test_main_dispatches_freeze source/tests/jax/test_training.py::TestJAXTraining::test_hlo_hessian_mode_updates_output_def source/tests/jax/test_training.py::TestJAXTraining::test_deep_eval_requests_hessian_for_hessian_model -q
  • source venv/bin/activate && ruff check .
  • source venv/bin/activate && ruff format .

Summary by CodeRabbit

  • New Features
    • Added an optional Hessian flag to model freezing, allowing frozen outputs to include Hessian information.
    • Hessian-aware evaluation is now supported when the model is configured for it.
  • Bug Fixes
    • Improved export behavior so Hessian settings are preserved in saved model metadata and output definitions.
    • Updated evaluation output requests to include Hessian-related outputs when needed.
  • Tests
    • Added regression coverage for Hessian-enabled freezing and inference behavior.

@dosubot dosubot Bot added the new feature label Jun 29, 2026
@coderabbitai

coderabbitai Bot commented Jun 29, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

Warning

Review limit reached

@njzjz, you've reached your PR review limit, so we couldn't start this review.

Next review available in: 17 minutes

Enable usage-based reviews in Billing to review now. Otherwise, wait until the next included review is available.
You're only billed for reviews past your plan's rate limits ($0.25/file).

How can I continue?

After more reviews become available, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

To avoid repeated limits, reduce automatic review volume by pausing incremental auto-reviews earlier, using label-based review opt-in, excluding WIP or generated PR titles, or requesting reviews manually when the PR is ready. If your team needs uninterrupted high-volume reviews, an organization admin can enable usage-based reviews.

How do review limits work?

CodeRabbit enforces per-developer PR review limits for each organization. Most developers receive the normal plan review availability.

For paid Pro and Pro+ PR reviews, CodeRabbit uses adaptive limits for sustained high-volume activity. When a developer's recent PR review activity reaches the 95th percentile or higher among CodeRabbit users, additional reviews become available more gradually as earlier reviews age out of the rolling window.

Please refer docs for additional details.

Review details
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: ff75b8f1-10ca-40a7-84d9-5542aae55590

📥 Commits

Reviewing files that changed from the base of the PR and between 16058df and 6eabd63.

📒 Files selected for processing (7)
  • deepmd/jax/entrypoints/freeze.py
  • deepmd/jax/infer/deep_eval.py
  • deepmd/jax/jax2tf/serialization.py
  • deepmd/jax/model/hlo.py
  • deepmd/jax/utils/serialization.py
  • deepmd/main.py
  • source/tests/jax/test_training.py
📝 Walkthrough

Walkthrough

Adds a --hessian boolean flag to the JAX freeze CLI subcommand. The flag propagates through the freeze entrypoint to deserialize_to_file in both utils and jax2tf serialization modules, which call enable_hessian() on the model and set hessian_mode in model_def_script. HLO.model_output_def() remaps energy to a new energy_hessian output definition when hessian_mode is active. DeepEval gains DERV_R_DERV_R in non-atomic requests and a get_has_hessian() method.

Changes

JAX freeze Hessian support

Layer / File(s) Summary
HLO energy_hessian output definition and remapping
deepmd/jax/model/hlo.py
Adds "energy_hessian" to OUTPUT_DEFS with r_hessian=True and updates model_output_def() to substitute "energy_hessian" for "energy" when hessian_mode is set in the parsed model_def_script.
deserialize_to_file hessian propagation
deepmd/jax/utils/serialization.py, deepmd/jax/jax2tf/serialization.py
Both deserialize_to_file functions accept a new hessian: bool = False parameter. When true, each calls model.enable_hessian(), sets model_def_script["hessian_mode"] = True, writes model_def_script back to data, and forwards the flag to deserialize_to_savedmodel.
CLI flag and freeze entrypoint wiring
deepmd/main.py, deepmd/jax/entrypoints/freeze.py
Adds --hessian (action="store_true") to the freeze subcommand and adds hessian: bool = False to freeze(), passing it through to deserialize_to_file.
DeepEval Hessian request and detection
deepmd/jax/infer/deep_eval.py
_get_request_defs adds OutputVariableCategory.DERV_R_DERV_R to non-atomic output categories; new get_has_hessian() reads hessian_mode from the model definition script.
Tests
source/tests/jax/test_training.py
Updates the freeze entrypoint test to assert hessian=True forwarding; adds test_hlo_hessian_mode_updates_output_def and test_deep_eval_requests_hessian_for_hessian_model.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly matches the main change: adding JAX freeze support for Hessian output.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
source/tests/jax/test_training.py (1)

199-214: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Assert hessian in the dispatcher call.

This still passes if main() drops args.hessian before invoking freeze(). Please assert the patched call carries the new flag as well.

Suggested test tightening
         main(args)

         freeze_entrypoint.assert_called_once()
+        self.assertIn("hessian", freeze_entrypoint.call_args.kwargs)
+        self.assertFalse(freeze_entrypoint.call_args.kwargs["hessian"])
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/jax/test_training.py` around lines 199 - 214, The test for the
JAX CLI freeze dispatch is too loose because it only checks that the patched
freeze entrypoint was called. Tighten test_main_dispatches_freeze in
test_training.py by asserting the call includes the hessian flag from the
argparse.Namespace, so main() must forward args.hessian into
deepmd.jax.entrypoints.main.freeze rather than dropping it. Use the existing
freeze_entrypoint mock and verify its call arguments reflect the new flag.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@source/tests/jax/test_training.py`:
- Around line 199-214: The test for the JAX CLI freeze dispatch is too loose
because it only checks that the patched freeze entrypoint was called. Tighten
test_main_dispatches_freeze in test_training.py by asserting the call includes
the hessian flag from the argparse.Namespace, so main() must forward
args.hessian into deepmd.jax.entrypoints.main.freeze rather than dropping it.
Use the existing freeze_entrypoint mock and verify its call arguments reflect
the new flag.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 44d1886d-fd38-4bbe-a056-c56edc3c92fc

📥 Commits

Reviewing files that changed from the base of the PR and between 1550599 and 16058df.

📒 Files selected for processing (7)
  • deepmd/jax/entrypoints/freeze.py
  • deepmd/jax/infer/deep_eval.py
  • deepmd/jax/jax2tf/serialization.py
  • deepmd/jax/model/hlo.py
  • deepmd/jax/utils/serialization.py
  • deepmd/main.py
  • source/tests/jax/test_training.py

@njzjz njzjz force-pushed the feat/jax-hessian-freeze branch from 16058df to 6eabd63 Compare June 29, 2026 18:34


def deserialize_to_file(model_file: str, data: dict) -> None:
def deserialize_to_file(model_file: str, data: dict, hessian: bool = False) -> None:
@codecov

codecov Bot commented Jun 29, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 61.90476% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.41%. Comparing base (1550599) to head (6eabd63).

Files with missing lines Patch % Lines
deepmd/jax/utils/serialization.py 45.45% 6 Missing ⚠️
deepmd/jax/jax2tf/serialization.py 60.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5608      +/-   ##
==========================================
- Coverage   82.41%   82.41%   -0.01%     
==========================================
  Files         903      903              
  Lines      101846   101859      +13     
  Branches     4071     4073       +2     
==========================================
+ Hits        83940    83943       +3     
- Misses      16439    16447       +8     
- Partials     1467     1469       +2     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants