Skip to content

fix(jax): add hessian energy loss#5607

Open
njzjz wants to merge 1 commit into
deepmodeling:masterfrom
njzjz:fix/dpmodel-energy-hessian-loss
Open

fix(jax): add hessian energy loss#5607
njzjz wants to merge 1 commit into
deepmodeling:masterfrom
njzjz:fix/dpmodel-energy-hessian-loss

Conversation

@njzjz

@njzjz njzjz commented Jun 29, 2026

Copy link
Copy Markdown
Member

Summary

  • add Hessian prefactors directly to dpmodel EnergyLoss instead of introducing a separate loss class
  • add Hessian label requirements, loss/RMSE reporting, and serialization fields
  • enable Hessian outputs in the JAX trainer when the Hessian loss is configured

Tests

  • source venv/bin/activate && pytest source/tests/common/dpmodel/test_loss_ener.py -q
  • source venv/bin/activate && ruff check .
  • source venv/bin/activate && ruff format .

Summary by CodeRabbit

  • New Features

    • Added optional Hessian loss support to the energy loss, including configurable scheduling parameters for Hessian weighting.
    • When Hessian targets are enabled, training now requests Hessian labels and computes an additional Hessian error term (with corresponding metrics).
  • Tests

    • Added dedicated test coverage for the Hessian-loss computation path.
    • Updated loss serialization/round-trip tests to include the new Hessian weighting parameters and regenerated Hessian test data.

@dosubot dosubot Bot added the enhancement label Jun 29, 2026
@njzjz njzjz force-pushed the fix/dpmodel-energy-hessian-loss branch from b56a253 to ae453b0 Compare June 29, 2026 18:24
@coderabbitai

coderabbitai Bot commented Jun 29, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 162d37c7-9fa9-4b6f-a35c-1b8c82486b72

📥 Commits

Reviewing files that changed from the base of the PR and between b56a253 and ae453b0.

📒 Files selected for processing (3)
  • deepmd/dpmodel/loss/ener.py
  • deepmd/jax/train/trainer.py
  • source/tests/common/dpmodel/test_loss_ener.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • deepmd/jax/train/trainer.py
  • source/tests/common/dpmodel/test_loss_ener.py
  • deepmd/dpmodel/loss/ener.py

📝 Walkthrough

Walkthrough

EnergyLoss adds optional Hessian loss prefactors and computes a Hessian L2 term when enabled. The JAX trainer enables Hessian support and exposes Hessian tensors to the loss. Tests cover the new path and serialization.

Changes

Hessian Loss Feature

Layer / File(s) Summary
EnergyLoss Hessian parameters, call logic, label requirement, and serialization
deepmd/dpmodel/loss/ener.py
Adds start_pref_h/limit_pref_h params and has_h flag to __init__, computes pref_h and Hessian MSE loss in call(), appends "hessian" DataRequirementItem in label_requirement, and exports new fields in serialize().
JAX trainer Hessian wiring
deepmd/jax/train/trainer.py
DPTrainer calls self.model.enable_hessian() when loss.has_h is true; both loss_fn and loss_fn_more_loss populate model_dict["hessian"] from energy_derv_r_derv_r by squeezing axis -3.
Tests
source/tests/common/dpmodel/test_loss_ener.py
_make_data extended with hessian/hessian_key options; TestEnergyLossHessian validates loss value, rmse_h metric, and label_requirement; serialize/deserialize round-trip updated to cover Hessian prefactors.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5325: Modifies EnergyLoss.call return path and more_loss dict in the same file, directly adjacent to the Hessian loss branch added here.

Suggested labels

new feature, Python

Suggested reviewers

  • wanghan-iapcm
  • anyangml
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly summarizes the main change: adding Hessian energy loss support in JAX.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Warning

There were issues while running some tools. Please review the errors and either fix the tool's configuration or disable the tool if it's a critical failure.

🔧 OpenGrep (1.23.0)
source/tests/common/dpmodel/test_loss_ener.py

┌──────────────┐
│ Opengrep CLI │
└──────────────┘

�[32m✔�[39m �[1mOpengrep OSS�[0m
�[32m✔�[39m Basic security coverage for first-party code vulnerabilities.

[00.13][ERROR]: unable to find a config; path .coderabbit-opengrep-fallback.yml does not exist

deepmd/dpmodel/loss/ener.py

┌──────────────┐
│ Opengrep CLI │
└──────────────┘

�[32m✔�[39m �[1mOpengrep OSS�[0m
�[32m✔�[39m Basic security coverage for first-party code vulnerabilities.

[00.15][ERROR]: unable to find a config; path .coderabbit-opengrep-fallback.yml does not exist

deepmd/jax/train/trainer.py

┌──────────────┐
│ Opengrep CLI │
└──────────────┘

�[32m✔�[39m �[1mOpengrep OSS�[0m
�[32m✔�[39m Basic security coverage for first-party code vulnerabilities.

[00.14][ERROR]: unable to find a config; path .coderabbit-opengrep-fallback.yml does not exist


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/loss/ener.py`:
- Around line 565-574: The Hessian entry in `DPTrainer.data_requirements` is
advertising the wrong tensor shape for `label_requirement`. Update the
`DataRequirementItem("hessian", ...)` definition in `ener.py` so it matches the
real on-disk Hessian layout used by the new loss path and tests, rather than the
current atomic `ndof=1` schema. If the dataset loader cannot yet consume the
full Hessian tensor, add the loader support first and keep `has_h` gated until
the contract is consistent.

In `@deepmd/jax/train/trainer.py`:
- Around line 120-121: The Hessian enablement in Trainer setup is unguarded, so
`self.model.enable_hessian()` can break for model/loss combinations like the JAX
`zbl` path that do not implement it. Update the `Trainer` logic to check that
the model actually exposes `enable_hessian` before calling it, using the
existing `self.loss.has_h` condition as the trigger and keeping the current
hessian tensor handling unchanged.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 23778160-e75a-4d2b-be1f-2babb0b6710f

📥 Commits

Reviewing files that changed from the base of the PR and between 1550599 and b56a253.

📒 Files selected for processing (3)
  • deepmd/dpmodel/loss/ener.py
  • deepmd/jax/train/trainer.py
  • source/tests/common/dpmodel/test_loss_ener.py

Comment on lines +565 to +574
if self.has_h:
label_requirement.append(
DataRequirementItem(
"hessian",
ndof=1,
atomic=True,
must=False,
high_prec=False,
)
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🗄️ Data Integrity & Integration | 🟠 Major | 🏗️ Heavy lift

label_requirement advertises the wrong Hessian shape.

DataRequirementItem("hessian", ndof=1, atomic=True) describes an nframes x natoms x 1 label, but the new loss path and tests consume a full nframes x (3 * natoms) x (3 * natoms) tensor. DPTrainer.data_requirements forwards this schema to the dataset loader, so real Hessian training will request/load hessian.npy with an incompatible contract even though these unit tests pass by building label_dict manually. Please make the requirement describe the real on-disk Hessian layout, or add dedicated loader support before enabling this in training.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/loss/ener.py` around lines 565 - 574, The Hessian entry in
`DPTrainer.data_requirements` is advertising the wrong tensor shape for
`label_requirement`. Update the `DataRequirementItem("hessian", ...)` definition
in `ener.py` so it matches the real on-disk Hessian layout used by the new loss
path and tests, rather than the current atomic `ndof=1` schema. If the dataset
loader cannot yet consume the full Hessian tensor, add the loader support first
and keep `has_h` gated until the contract is consistent.

Comment thread deepmd/jax/train/trainer.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant