feat(jax): freeze models with hessian output#5608
Conversation
|
Warning Review limit reached
Next review available in: 17 minutes Enable usage-based reviews in Billing to review now. Otherwise, wait until the next included review is available. How can I continue?After more reviews become available, a review can be triggered using the To avoid repeated limits, reduce automatic review volume by pausing incremental auto-reviews earlier, using label-based review opt-in, excluding WIP or generated PR titles, or requesting reviews manually when the PR is ready. If your team needs uninterrupted high-volume reviews, an organization admin can enable usage-based reviews. How do review limits work?CodeRabbit enforces per-developer PR review limits for each organization. Most developers receive the normal plan review availability. For paid Pro and Pro+ PR reviews, CodeRabbit uses adaptive limits for sustained high-volume activity. When a developer's recent PR review activity reaches the 95th percentile or higher among CodeRabbit users, additional reviews become available more gradually as earlier reviews age out of the rolling window. Please refer docs for additional details. Review details⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (7)
📝 WalkthroughWalkthroughAdds a ChangesJAX freeze Hessian support
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes 🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
source/tests/jax/test_training.py (1)
199-214: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winAssert
hessianin the dispatcher call.This still passes if
main()dropsargs.hessianbefore invokingfreeze(). Please assert the patched call carries the new flag as well.Suggested test tightening
main(args) freeze_entrypoint.assert_called_once() + self.assertIn("hessian", freeze_entrypoint.call_args.kwargs) + self.assertFalse(freeze_entrypoint.call_args.kwargs["hessian"])🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/jax/test_training.py` around lines 199 - 214, The test for the JAX CLI freeze dispatch is too loose because it only checks that the patched freeze entrypoint was called. Tighten test_main_dispatches_freeze in test_training.py by asserting the call includes the hessian flag from the argparse.Namespace, so main() must forward args.hessian into deepmd.jax.entrypoints.main.freeze rather than dropping it. Use the existing freeze_entrypoint mock and verify its call arguments reflect the new flag.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@source/tests/jax/test_training.py`:
- Around line 199-214: The test for the JAX CLI freeze dispatch is too loose
because it only checks that the patched freeze entrypoint was called. Tighten
test_main_dispatches_freeze in test_training.py by asserting the call includes
the hessian flag from the argparse.Namespace, so main() must forward
args.hessian into deepmd.jax.entrypoints.main.freeze rather than dropping it.
Use the existing freeze_entrypoint mock and verify its call arguments reflect
the new flag.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 44d1886d-fd38-4bbe-a056-c56edc3c92fc
📒 Files selected for processing (7)
deepmd/jax/entrypoints/freeze.pydeepmd/jax/infer/deep_eval.pydeepmd/jax/jax2tf/serialization.pydeepmd/jax/model/hlo.pydeepmd/jax/utils/serialization.pydeepmd/main.pysource/tests/jax/test_training.py
16058df to
6eabd63
Compare
|
|
||
|
|
||
| def deserialize_to_file(model_file: str, data: dict) -> None: | ||
| def deserialize_to_file(model_file: str, data: dict, hessian: bool = False) -> None: |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5608 +/- ##
==========================================
- Coverage 82.41% 82.41% -0.01%
==========================================
Files 903 903
Lines 101846 101859 +13
Branches 4071 4073 +2
==========================================
+ Hits 83940 83943 +3
- Misses 16439 16447 +8
- Partials 1467 1469 +2 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Summary
dp freeze --hessiansupport for the JAX backend.jax,.hlo, and.savedmodelserializationTests
source venv/bin/activate && pytest source/tests/jax/test_training.py::TestJAXTraining::test_freeze_entrypoint_uses_checkpoint_pointer source/tests/jax/test_training.py::TestJAXTraining::test_main_dispatches_freeze source/tests/jax/test_training.py::TestJAXTraining::test_hlo_hessian_mode_updates_output_def source/tests/jax/test_training.py::TestJAXTraining::test_deep_eval_requests_hessian_for_hessian_model -qsource venv/bin/activate && ruff check .source venv/bin/activate && ruff format .Summary by CodeRabbit