From ae453b010110c98eb1813deb819c1f0f8d6a4fa8 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 30 Jun 2026 02:09:21 +0800 Subject: [PATCH] fix(jax): add hessian energy loss --- deepmd/dpmodel/loss/ener.py | 39 ++++++++++++++ deepmd/jax/train/trainer.py | 6 +++ source/tests/common/dpmodel/test_loss_ener.py | 54 ++++++++++++++++++- 3 files changed, 97 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/loss/ener.py b/deepmd/dpmodel/loss/ener.py index 7515f19b9a..0a8d57fc8c 100644 --- a/deepmd/dpmodel/loss/ener.py +++ b/deepmd/dpmodel/loss/ener.py @@ -71,6 +71,10 @@ class EnergyLoss(Loss): The prefactor of generalized force loss at the end of the training. numb_generalized_coord : int The dimension of generalized coordinates. + start_pref_h : float + The prefactor of Hessian loss at the start of the training. + limit_pref_h : float + The prefactor of Hessian loss at the end of the training. use_default_pf : bool If true, use default atom_pref of 1.0 for all atoms when atom_pref data is not provided. This allows using the prefactor force loss (pf) without requiring atom_pref.npy files. @@ -123,6 +127,8 @@ def __init__( start_pref_gf: float = 0.0, limit_pref_gf: float = 0.0, numb_generalized_coord: int = 0, + start_pref_h: float = 0.0, + limit_pref_h: float = 0.0, use_huber: bool = False, huber_delta: float | list[float] = 0.01, loss_func: str = "mse", @@ -155,12 +161,15 @@ def __init__( self.start_pref_gf = start_pref_gf self.limit_pref_gf = limit_pref_gf self.numb_generalized_coord = numb_generalized_coord + self.start_pref_h = start_pref_h + self.limit_pref_h = limit_pref_h self.has_e = self.start_pref_e != 0.0 or self.limit_pref_e != 0.0 self.has_f = self.start_pref_f != 0.0 or self.limit_pref_f != 0.0 self.has_v = self.start_pref_v != 0.0 or self.limit_pref_v != 0.0 self.has_ae = self.start_pref_ae != 0.0 or self.limit_pref_ae != 0.0 self.has_pf = self.start_pref_pf != 0.0 or self.limit_pref_pf != 0.0 self.has_gf = self.start_pref_gf != 0.0 or self.limit_pref_gf != 0.0 + self.has_h = self.start_pref_h != 0.0 or self.limit_pref_h != 0.0 if self.has_gf and self.numb_generalized_coord < 1: raise RuntimeError( "When generalized force loss is used, the dimension of generalized coordinates should be larger than 0" @@ -270,6 +279,7 @@ def call( pref_pf = find_atom_pref * ( self.limit_pref_pf + (self.start_pref_pf - self.limit_pref_pf) * lr_ratio ) + pref_h = self.limit_pref_h + (self.start_pref_h - self.limit_pref_h) * lr_ratio loss = 0 more_loss = {} @@ -457,6 +467,23 @@ def call( more_loss["rmse_gf"] = self.display_if_exist( xp.sqrt(l2_gen_force_loss), find_drdq ) + hessian = model_dict.get( + "hessian", model_dict.get("energy_derv_r_derv_r", None) + ) + if self.has_h and hessian is not None and "hessian" in label_dict: + find_hessian = label_dict.get("find_hessian", 0.0) + diff_h = xp.reshape(label_dict["hessian"], (-1,)) - xp.reshape( + hessian, + (-1,), + ) + l2_hessian_loss = xp.mean(xp.square(diff_h)) + loss += pref_h * find_hessian * l2_hessian_loss + more_loss["rmse_h"] = self.display_if_exist( + xp.sqrt(l2_hessian_loss), find_hessian + ) + if mae: + mae_h = xp.mean(xp.abs(diff_h)) + more_loss["mae_h"] = self.display_if_exist(mae_h, find_hessian) self.l2_l = loss more_loss["rmse"] = xp.sqrt(loss) @@ -535,6 +562,16 @@ def label_requirement(self) -> list[DataRequirementItem]: default=1.0, ) ) + if self.has_h: + label_requirement.append( + DataRequirementItem( + "hessian", + ndof=1, + atomic=True, + must=False, + high_prec=False, + ) + ) return label_requirement def serialize(self) -> dict: @@ -564,6 +601,8 @@ def serialize(self) -> dict: "start_pref_gf": self.start_pref_gf, "limit_pref_gf": self.limit_pref_gf, "numb_generalized_coord": self.numb_generalized_coord, + "start_pref_h": self.start_pref_h, + "limit_pref_h": self.limit_pref_h, "use_huber": self.use_huber, "huber_delta": self.huber_delta, "loss_func": self.loss_func, diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py index 180249eaef..46d8fc581e 100644 --- a/deepmd/jax/train/trainer.py +++ b/deepmd/jax/train/trainer.py @@ -117,6 +117,8 @@ def get_lr_and_coef(lr_param: dict) -> LearningRateExp: loss_type = loss_param.get("type", "ener") if loss_type == "ener": self.loss = EnergyLoss.get_loss(loss_param) + if getattr(self.loss, "has_h", False): + self.model.enable_hessian() else: raise RuntimeError("unknown loss type " + loss_type) @@ -211,6 +213,8 @@ def loss_fn( model_dict["energy"] = model_dict["energy_redu"] model_dict["force"] = model_dict["energy_derv_r"].squeeze(-2) model_dict["virial"] = model_dict["energy_derv_c_redu"].squeeze(-2) + if model_dict.get("energy_derv_r_derv_r") is not None: + model_dict["hessian"] = model_dict["energy_derv_r_derv_r"].squeeze(-3) loss, more_loss = self.loss( learning_rate=lr, natoms=label_dict["type"].shape[1], @@ -249,6 +253,8 @@ def loss_fn_more_loss( model_dict["energy"] = model_dict["energy_redu"] model_dict["force"] = model_dict["energy_derv_r"].squeeze(-2) model_dict["virial"] = model_dict["energy_derv_c_redu"].squeeze(-2) + if model_dict.get("energy_derv_r_derv_r") is not None: + model_dict["hessian"] = model_dict["energy_derv_r_derv_r"].squeeze(-3) loss, more_loss = self.loss( learning_rate=lr, natoms=label_dict["type"].shape[1], diff --git a/source/tests/common/dpmodel/test_loss_ener.py b/source/tests/common/dpmodel/test_loss_ener.py index ebf9ba0a64..cd6803a4ab 100644 --- a/source/tests/common/dpmodel/test_loss_ener.py +++ b/source/tests/common/dpmodel/test_loss_ener.py @@ -15,7 +15,14 @@ class TestEnergyLossBase(unittest.TestCase): """Base class providing common setup for dpmodel EnergyLoss tests.""" - def _make_data(self, natoms=5, nframes=2, numb_generalized_coord=0): + def _make_data( + self, + natoms=5, + nframes=2, + numb_generalized_coord=0, + hessian=False, + hessian_key="hessian", + ): """Generate fake model predictions and labels.""" rng = np.random.default_rng(GLOBAL_SEED) model_dict = { @@ -43,6 +50,10 @@ def _make_data(self, natoms=5, nframes=2, numb_generalized_coord=0): label_dict["find_drdq"] = 1.0 if hasattr(self, "enable_atom_ener_coeff") and self.enable_atom_ener_coeff: label_dict["atom_ener_coeff"] = rng.random((nframes, natoms, 1)) + if hessian: + model_dict[hessian_key] = rng.random((nframes, 3 * natoms, 3 * natoms)) + label_dict["hessian"] = rng.random((nframes, 3 * natoms, 3 * natoms)) + label_dict["find_hessian"] = 1.0 return model_dict, label_dict, natoms @@ -145,6 +156,38 @@ def test_forward(self) -> None: self.assertIsNotNone(loss) +class TestEnergyLossHessian(TestEnergyLossBase): + """Test Hessian loss inside the dpmodel energy loss.""" + + def test_forward_hessian(self) -> None: + loss_fn = EnergyLoss( + starter_learning_rate=1.0, + start_pref_e=0.0, + limit_pref_e=0.0, + start_pref_f=0.0, + limit_pref_f=0.0, + start_pref_v=0.0, + limit_pref_v=0.0, + start_pref_h=2.0, + limit_pref_h=1.0, + ) + model_dict, label_dict, natoms = self._make_data( + hessian=True, + hessian_key="energy_derv_r_derv_r", + ) + loss, more_loss = loss_fn.call(1.0, natoms, model_dict, label_dict) + diff_h = label_dict["hessian"].reshape(-1) - model_dict[ + "energy_derv_r_derv_r" + ].reshape(-1) + l2_hessian_loss = np.mean(np.square(diff_h)) + np.testing.assert_allclose(loss, 2.0 * l2_hessian_loss) + np.testing.assert_allclose(more_loss["rmse_h"], np.sqrt(l2_hessian_loss)) + self.assertIn( + "hessian", + {item.key for item in loss_fn.label_requirement}, + ) + + class TestEnergyLossSerialize(TestEnergyLossBase): """Test serialize/deserialize round-trip.""" @@ -160,10 +203,17 @@ def test_serialize_deserialize(self) -> None: start_pref_gf=1.0, limit_pref_gf=0.5, numb_generalized_coord=2, + start_pref_h=2.0, + limit_pref_h=0.5, ) data = loss_fn.serialize() + self.assertEqual(data["start_pref_h"], 2.0) + self.assertEqual(data["limit_pref_h"], 0.5) loss_fn2 = EnergyLoss.deserialize(data) - model_dict, label_dict, natoms = self._make_data(numb_generalized_coord=2) + model_dict, label_dict, natoms = self._make_data( + numb_generalized_coord=2, + hessian=True, + ) loss1, more1 = loss_fn.call(1.0, natoms, model_dict, label_dict) loss2, more2 = loss_fn2.call(1.0, natoms, model_dict, label_dict) np.testing.assert_allclose(loss1, loss2)