feat(dpmodel): add descriptor compression #5592
📝 Walkthrough
Adds descriptor compression support for
Changes
Descriptor Compression and CLI Entrypoints
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 2
🧹 Nitpick comments (3)
source/tests/jax/test_model_compression.py (1)
59-82: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Add an entrypoint round-trip case for DPA1 compression.
The descriptor tests cover DPA1 in-memory serialization, but both CLI/Orbax entrypoint tests use only `se_e2_a`. Add a DPA1 model-data variant so `serialize_from_file()` exercises restored `type_embd_data`/geometric compression state from a real `.jax` checkpoint.
Also applies to: 216-301
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/jax/test_model_compression.py` around lines 59 - 82, Add a DPA1 model-data variant to the compression tests so the entrypoint round-trip covers more than se_e2_a. Update the model fixture-building helpers in test_model_compression.py, especially _make_model_data and any related CLI/Orbax test setup, to create a DPA1 checkpoint and verify serialize_from_file() restores type_embd_data and geometric compression state from the saved .jax file. Keep the existing se_e2_a coverage, but add a separate DPA1 case that exercises the same entrypoint flow end-to-end.
deepmd/dpmodel/entrypoints/compress.py (2)
53-93: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value
Consider de-duplicating the compress entrypoint logic shared with the JAX path.
`enable_compression`, `_compute_min_nbor_dist`, and the `_get_saved_min_nbor_dist` resolution are nearly identical to `deepmd/jax/entrypoints/compress.py`. Extracting the backend-agnostic min_nbor_dist resolution/compression flow into a shared helper would reduce drift (e.g., this path skips the `_to_float` coercion the JAX path applies). Optional given the differing serialization imports per backend.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/dpmodel/entrypoints/compress.py` around lines 53 - 93, The compression flow in enable_compression duplicates the JAX entrypoint logic, so extract the shared min_nbor_dist resolution and compression steps into a backend-agnostic helper used by both paths. Keep the backend-specific model loading/saving in place, but centralize the common _get_saved_min_nbor_dist, _compute_min_nbor_dist, and model.enable_compression handling so behavior stays aligned and drift like missing _to_float coercion is avoided.
82-87: 🎯 Functional Correctness | 🔵 Trivial
Confirm the positional argument order for `model.enable_compression`.
The signature of `BaseModel.enable_compression` is: `(table_extrapolate, table_stride_1, table_stride_2, check_frequency)`.
The call at lines 82-87, `model.enable_compression(extrapolate, stride, stride * 10, check_frequency)`, correctly maps to the signature parameters in order. The calculation `stride * 10` correctly assigns to `table_stride_2`.
While the current positional usage is correct, using explicit keyword arguments would improve readability and prevent future regressions if the signature changes: `model.enable_compression(table_extrapolate=extrapolate, table_stride_1=stride, table_stride_2=stride * 10, check_frequency=check_frequency)`
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/dpmodel/entrypoints/compress.py` around lines 82 - 87, The call to model.enable_compression already matches BaseModel.enable_compression’s parameter order, but it should be rewritten with explicit keyword arguments for clarity and to avoid regressions; update the enable_compression call in compress.py to name table_extrapolate, table_stride_1, table_stride_2, and check_frequency explicitly while keeping the same values.
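As a rough illustration of the de-duplication nitpick for `compress.py`, the backend-agnostic part could live in one helper that both entrypoints call. Everything in this sketch is hypothetical: `resolve_min_nbor_dist`, its parameters, and the stand-in `_to_float` are illustrative names and do not reproduce the project's actual helpers.

```python
from typing import Callable, Optional


def _to_float(value) -> float:
    """Coerce a scalar or one-element array-like to a plain Python float."""
    # NumPy/JAX scalars expose .item(); plain Python numbers do not.
    item = getattr(value, "item", None)
    return float(item() if callable(item) else value)


def resolve_min_nbor_dist(
    saved: Optional[object],
    compute: Callable[[], object],
) -> float:
    """Prefer the value stored with the model; otherwise compute it once."""
    raw = saved if saved is not None else compute()
    if raw is None:
        raise ValueError("min_nbor_dist is neither saved nor computable")
    return _to_float(raw)


# Both backends would call the same helper, so the coercion cannot drift.
from_saved = resolve_min_nbor_dist(0.85, lambda: 1.0)  # saved value wins
from_compute = resolve_min_nbor_dist(None, lambda: 1)  # falls back to compute()
print(from_saved, from_compute)
```

Backend-specific loading and saving would stay in each entrypoint; only the resolution and coercion step is shared here.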
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/descriptor/se_e2_a.py`:
- Line 738: The compressed `DescrptSeAArrayAPI` path uses an in-place slice
update on `gr_s`, which is not JAX-compatible. Update the accumulation logic
around `self._tabulate_fusion_se_a(...)` to use functional array updates (for
example via the array API update pattern on `gr_s`) instead of direct slice
assignment. Make sure the fix is applied in the `DescrptSeAArrayAPI` flow so
both `type_one_side=True` and `type_one_side=False` work correctly under JAX.
In `@deepmd/jax/utils/serialization.py`:
- Around line 47-54: The restore logic in the serialization helper only
recreates compress_data and compress_info, but DPA1/se_atten_v2 compressed
checkpoints also need type_embd_data to be restored before replace_by_pure_dict
runs. Update the state-reconstruction branch in the helper that checks
obj.compress so it also detects and assigns type_embd_data from state, using the
same sequence-to-numpy conversion pattern as the other compressed slots, so
compressed DPA1 checkpoints keep matching keys and preserve the type-embedding
payload.
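The in-place update flagged for `se_e2_a.py` above can be contrasted with a functional one in a small sketch. The function names, shapes, and data are illustrative and not taken from the PR; a read-only NumPy array is used as a stand-in for an immutable JAX array, which is an assumption made so the example runs with NumPy alone.

```python
import numpy as np


def accumulate_in_place(gr, contributions):
    """Slice assignment: only works when the array is writable."""
    for start, block in contributions:
        gr[start : start + block.shape[0]] += block
    return gr


def accumulate_functional(gr, contributions):
    """Build a new array at every step; the input is never mutated."""
    for start, block in contributions:
        stop = start + block.shape[0]
        updated = gr[start:stop] + block
        # Under JAX the same step can be written gr.at[start:stop].add(block).
        gr = np.concatenate([gr[:start], updated, gr[stop:]], axis=0)
    return gr


gr_s = np.zeros((4, 2))
gr_s.flags.writeable = False  # stand-in for an immutable array
parts = [(0, np.ones((2, 2))), (2, 2.0 * np.ones((2, 2)))]

try:
    accumulate_in_place(gr_s, parts)
except ValueError as err:
    print("in-place update rejected:", err)

result = accumulate_functional(gr_s, parts)
print(result)
```

The functional variant returns the same accumulated values while leaving `gr_s` untouched, which is the property the review asks for on the compressed path.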
📒 Files selected for processing (14)
- deepmd/backend/dpmodel.py
- deepmd/dpmodel/descriptor/dpa1.py
- deepmd/dpmodel/descriptor/se_atten_v2.py
- deepmd/dpmodel/descriptor/se_e2_a.py
- deepmd/dpmodel/descriptor/se_r.py
- deepmd/dpmodel/entrypoints/__init__.py
- deepmd/dpmodel/entrypoints/compress.py
- deepmd/dpmodel/entrypoints/main.py
- deepmd/jax/entrypoints/compress.py
- deepmd/jax/entrypoints/main.py
- deepmd/jax/utils/serialization.py
- deepmd/main.py
- source/tests/common/dpmodel/test_model_compression.py
- source/tests/jax/test_model_compression.py
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #5592 +/- ##
==========================================
+ Coverage 82.37% 82.42% +0.05%
==========================================
Files 902 907 +5
Lines 101527 101998 +471
Branches 4056 4056
==========================================
+ Hits 83630 84074 +444
- Misses 16432 16459 +27
  Partials 1465 1465
🧹 Nitpick comments (2)
source/lib/tests/test_tabulate_extrapolate.cc (2)
202-210: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value
`grad_central_diff` is identical to `central_diff`.
Both compute the same central-difference formula. Either reuse `central_diff` for the constant-gradient check or, if the second helper is meant to document a distinct intent (numerically differentiating the gradient), add a brief comment; the duplicated body is otherwise confusing.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/lib/tests/test_tabulate_extrapolate.cc` around lines 202 - 210, grad_central_diff duplicates central_diff with the same central-difference formula, so either call central_diff from grad_central_diff to reuse the shared logic or add a short comment in grad_central_diff explaining the distinct intent if it is meant to represent differentiating the gradient. Update the test helpers in test_tabulate_extrapolate.cc so the relationship between central_diff and grad_central_diff is explicit and not misleading.
232-268: 📐 Maintainability & Code Quality | 🔵 Trivial | 🏗️ Heavy lift
Second-order (`*_grad_grad`) extrapolation paths are untested.
These tests validate value and first-gradient linear tails, but the substantially-changed `tabulate_fusion_se_a/se_t/se_t_tebd/se_r_grad_grad_*` paths (which now fold the extrapolation term into `var = poly5 + var_grad * extrapolate_delta`) get no coverage here. Second-order math is the easiest to get wrong; consider adding grad_grad cases (e.g. a forward-over-reverse / numeric check on `dz_dy` below-lower and above-max).
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/lib/tests/test_tabulate_extrapolate.cc` around lines 232 - 268, Add coverage for the second-order extrapolation paths in the tabulation tests: the current cases in TabulateExtrapolate only exercise value and first-derivative linear tails, but do not validate the new `*_grad_grad` logic in `tabulate_fusion_se_a`, `se_r`, `se_t`, and `se_t_tebd`. Extend the existing test helpers around `expect_linear_tail`, `expect_boundary`, and the `se_*_grad`/`se_*_value` checks with cases that call the `*_grad_grad` functions below `kLower`/`kMin` and above `kMax`, and verify the second-derivative output against a numeric or forward-over-reverse expectation. Use the existing `locate_se_a_or_r`, `locate_se_t`, and `expected_table_*` helpers to keep the new cases aligned with the current test structure.
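Both nitpicks for `test_tabulate_extrapolate.cc` can be pictured with a toy numeric check, written here in Python for brevity. The quintic, the table range, and the function names below are illustrative assumptions and not the project's kernels; the point is that a single `central_diff` helper can differentiate either the value or the gradient, and that in a linear tail the derivative of the gradient should vanish.

```python
def poly(t):
    """Toy quintic standing in for a tabulated polynomial."""
    return t**5 + t


def dpoly(t):
    """Analytic derivative of the toy quintic."""
    return 5.0 * t**4 + 1.0


def tabulated(x, lower=0.0, upper=1.0):
    """Quintic inside [lower, upper], linear extrapolation outside."""
    if x < lower:
        return poly(lower) + dpoly(lower) * (x - lower)
    if x > upper:
        return poly(upper) + dpoly(upper) * (x - upper)
    return poly(x)


def tabulated_grad(x, lower=0.0, upper=1.0):
    """Gradient of tabulated(): constant in both tails."""
    return dpoly(min(max(x, lower), upper))


def central_diff(f, x, h=1e-3):
    """Symmetric difference; reusable for values and for gradients."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


for x in (-0.5, 1.5):
    # First order: the numeric slope matches the analytic tail gradient.
    assert abs(central_diff(tabulated, x) - tabulated_grad(x)) < 1e-8
    # Second order: differentiating the gradient gives zero in the tail.
    assert central_diff(tabulated_grad, x) == 0.0

# Inside the table the second derivative is 20 * t**3, here 2.5 at t = 0.5.
inner = central_diff(tabulated_grad, 0.5)
print(inner)
```

A C++ test along the same lines would call the `*_grad_grad` functions below the lower bound and above the maximum and compare against such a numeric estimate.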
📒 Files selected for processing (3)
- source/lib/src/gpu/tabulate.cu
- source/lib/src/tabulate.cc
- source/lib/tests/test_tabulate_extrapolate.cc
661b3bd to
52f8c44
Actionable comments posted: 1
🧹 Nitpick comments (1)
source/tests/common/dpmodel/test_model_compression.py (1)
349-400: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Exercise one inference pass on the reloaded compressed model.
This test only checks serialized metadata. If the entrypoint preserves `compress` and `min_nbor_dist` but corrupts descriptor state during save/load, it still passes. Comparing one pre/post forward call on the fixture would close that gap.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/common/dpmodel/test_model_compression.py` around lines 349 - 400, The test in test_dpmodel_compress_entrypoint only validates serialized metadata after enable_compression and load_dp_model, so it misses descriptor corruption in the actual model state. After reloading the compressed model, run one forward/inference pass on the compressed model and compare its output to the original model from get_model/serialize to ensure the entrypoint preserves usable behavior, not just the presence of compress and min_nbor_dist.
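The gap described in this nitpick can be sketched with a toy model. `ToyModel` and `round_trip` are hypothetical stand-ins and not deepmd classes or functions; the sketch only shows the shape of a test that checks behaviour as well as metadata after a save/load cycle.

```python
import json


class ToyModel:
    """Hypothetical stand-in for a compressed model."""

    def __init__(self, weights, compress=False, min_nbor_dist=None):
        self.weights = list(weights)
        self.compress = compress
        self.min_nbor_dist = min_nbor_dist

    def forward(self, coords):
        return sum(w * x for w, x in zip(self.weights, coords))

    def serialize(self):
        return {
            "weights": self.weights,
            "compress": self.compress,
            "min_nbor_dist": self.min_nbor_dist,
        }

    @classmethod
    def deserialize(cls, data):
        return cls(data["weights"], data["compress"], data["min_nbor_dist"])


def round_trip(model):
    """Dump to a string and load back, mimicking a file round trip."""
    return ToyModel.deserialize(json.loads(json.dumps(model.serialize())))


original = ToyModel([0.5, -1.25, 2.0], compress=True, min_nbor_dist=0.9)
reloaded = round_trip(original)
coords = [1.0, 2.0, 3.0]

# Metadata-only check, which would miss corrupted weights.
assert reloaded.compress and reloaded.min_nbor_dist == 0.9
# Behavioural check: one forward pass before and after the round trip.
assert abs(original.forward(coords) - reloaded.forward(coords)) < 1e-12
```

If `deserialize` dropped or reordered the weights, the first assertion would still pass and only the second would fail.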
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@source/tests/common/dpmodel/test_model_compression.py`:
- Around line 176-185: The `np.testing.assert_allclose` checks in
`test_model_compression` are too lenient because they still use the default
relative tolerance, so make the extrapolation/parity assertions strict by
setting `rtol=0` wherever `assert_allclose` is used in this test. Update the
affected checks in `test_model_compression` (including the repeated cases later
in the file) so they rely only on the explicit `atol` values.
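The effect of the default relative tolerance can be seen in a short example. `np.testing.assert_allclose` accepts a difference up to `atol + rtol * |desired|`, and its default `rtol` is `1e-7`; the values below are made up for illustration and do not come from the test file.

```python
import numpy as np

desired = np.array([1.0e6])
actual = desired + 0.05  # absolute error of 0.05
atol = 1e-3

# Default rtol=1e-7 adds rtol * |desired| = 0.1 to the allowed error,
# so the 0.05 discrepancy slips through despite atol=1e-3.
np.testing.assert_allclose(actual, desired, atol=atol)
lenient_passed = True

# With rtol=0 only atol counts, and the same discrepancy is caught.
try:
    np.testing.assert_allclose(actual, desired, rtol=0, atol=atol)
    strict_passed = True
except AssertionError:
    strict_passed = False

print(lenient_passed, strict_passed)
```

For outputs of small magnitude the default `rtol` contributes little, so the difference matters most when compared values are large relative to `atol`.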
📒 Files selected for processing (4)
- deepmd/dpmodel/descriptor/dpa1.py
- deepmd/dpmodel/descriptor/se_e2_a.py
- deepmd/dpmodel/descriptor/se_r.py
- source/tests/common/dpmodel/test_model_compression.py
🚧 Files skipped from review as they are similar to previous changes (3)
- deepmd/dpmodel/descriptor/se_r.py
- deepmd/dpmodel/descriptor/se_e2_a.py
- deepmd/dpmodel/descriptor/dpa1.py
## Summary
- add dpmodel compression entrypoints and wire dp --dp/--jax compress
- persist and restore compressed dpmodel/JAX descriptor state, including HLO export metadata
- implement dpmodel compression for se_e2_a, se_e2_r, and se_atten/dpa1 (with se_atten_v2 state sync)
- add common dpmodel and JAX compression coverage
## Tests
- python -m pytest source/tests/common/dpmodel/test_model_compression.py -q
- python -m pytest source/tests/jax/test_model_compression.py -q
- python -m pytest source/tests/common/test_argument_parser.py -q
- python -m pytest source/tests/pt_expt/descriptor/test_se_r.py source/tests/pt_expt/descriptor/test_dpa1.py source/tests/pt_expt/descriptor/test_se_atten_v2.py -q
- python -m pytest source/tests/pt_expt/model/test_model_compression.py -q
- ruff check .
- ruff format --check .
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit
* **New Features**
  * Added model compression support for DPModel and JAX workflows.
  * Introduced a new `compress` command in the CLI for creating compressed model files.
  * Expanded support for compressed descriptors, including saving and restoring compressed models.
* **Documentation**
  * Updated CLI help text with new examples and supported model file formats.
* **Tests**
  * Added end-to-end and descriptor-level coverage for compression behavior and model reloads.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->