feat(dpmodel): add descriptor compression#5592

Merged
njzjz merged 7 commits into deepmodeling:master from njzjz:feat/dpmodel-compress-descriptors
Jun 29, 2026

Conversation

@njzjz njzjz commented Jun 26, 2026

Member

Summary

  • add dpmodel compression entrypoints and wire dp --dp/--jax compress
  • persist and restore compressed dpmodel/JAX descriptor state, including HLO export metadata
  • implement dpmodel compression for se_e2_a, se_e2_r, and se_atten/dpa1 (with se_atten_v2 state sync)
  • add common dpmodel and JAX compression coverage

Tests

  • python -m pytest source/tests/common/dpmodel/test_model_compression.py -q
  • python -m pytest source/tests/jax/test_model_compression.py -q
  • python -m pytest source/tests/common/test_argument_parser.py -q
  • python -m pytest source/tests/pt_expt/descriptor/test_se_r.py source/tests/pt_expt/descriptor/test_dpa1.py source/tests/pt_expt/descriptor/test_se_atten_v2.py -q
  • python -m pytest source/tests/pt_expt/model/test_model_compression.py -q
  • ruff check .
  • ruff format --check .

Summary by CodeRabbit

  • New Features

    • Added model compression support for DPModel and JAX workflows.
    • Introduced a new compress command in the CLI for creating compressed model files.
    • Expanded support for compressed descriptors, including saving and restoring compressed models.
  • Documentation

    • Updated CLI help text with new examples and supported model file formats.
  • Tests

    • Added end-to-end and descriptor-level coverage for compression behavior and model reloads.

coderabbitai Bot commented Jun 26, 2026

Contributor

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Adds descriptor compression support for DescrptSeA, DescrptSeR, and DescrptDPA1 in the DPModel backend, including tabulation-based forward paths and serialize/deserialize with version 3 payloads. Introduces shared compress_common helpers, new compress CLI entrypoints for DPModel and JAX backends, JAX serialization utilities for restoring compression slots from Orbax checkpoints, and corresponding unit/integration tests.

Changes

Descriptor Compression and CLI Entrypoints

| Layer / File(s) | Summary |
| --- | --- |
| Shared compression helpers and CLI wiring: deepmd/dpmodel/entrypoints/compress_common.py, deepmd/dpmodel/entrypoints/compress.py, deepmd/dpmodel/entrypoints/main.py, deepmd/dpmodel/entrypoints/__init__.py, deepmd/jax/entrypoints/compress.py, deepmd/jax/entrypoints/main.py, deepmd/backend/dpmodel.py, deepmd/main.py | compress_common provides resolve_min_nbor_dist and enable_model_compression; DPModel and JAX entrypoints dispatch the compress command; DPModelBackend.entry_point_hook returns the new main callable instead of raising; CLI help text documents .dp/.hlo/.jax suffixes. |
| SeR and SeA descriptor compression: deepmd/dpmodel/descriptor/se_r.py, deepmd/dpmodel/descriptor/se_e2_a.py | DescrptSeR and DescrptSeA gain enable_compression, DPTabulate-based tabulation, _tabulate_fusion_se_r/_tabulate_fusion_se_a helpers, compressed call paths, and serialize/deserialize at @version: 3 with a compress payload. |
| DPA1 descriptor compression (tebd + geometric): deepmd/dpmodel/descriptor/dpa1.py, deepmd/dpmodel/descriptor/se_atten_v2.py | DescrptDPA1 and DescrptBlockSeAtten gain tebd_compress/geo_compress state, enable_compression with strip-mode validation, type_embedding_compression, _tabulate_fusion_se_atten, compressed call branching using precomputed type_embd_data, and serialization with compress_data/compress_info. |
| JAX serialization compression slot restoration: deepmd/jax/utils/serialization.py | Adds _restore_compression_slots_from_state to rebuild NNX compression attributes from Orbax state before serialization; injects min_nbor_dist into .jax/.hlo artifacts during deserialization and includes it in serialize_from_file output. |
| Compression tests: source/tests/common/dpmodel/test_model_compression.py, source/tests/jax/test_model_compression.py | Tests cover tabulation extrapolation (C1 continuity for SeR, SeA, DPA1), descriptor compression flags/metadata/output parity before and after deserialize, and end-to-end CLI compress entrypoints with min_nbor_dist preservation for both DPModel and JAX backends. |
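
To make the "C1 continuity" property in the test summary concrete, here is a minimal sketch of a lookup table with linear tails. It is not the DPTabulate implementation: it assumes a cubic Hermite interpolant on a uniform grid purely for brevity, and `build_table` and `eval_table` are hypothetical names that do not appear in the PR.

```python
import numpy as np


def build_table(func, dfunc, lower, upper, stride):
    """Sample a function and its derivative on a uniform grid over [lower, upper]."""
    n = int(round((upper - lower) / stride))
    xs = lower + stride * np.arange(n + 1)
    return xs, func(xs), dfunc(xs)


def eval_table(x, xs, ys, dys):
    """Cubic Hermite interpolation inside the grid, linear tails outside it."""
    x = np.asarray(x, dtype=float)
    lower, upper = xs[0], xs[-1]
    h = xs[1] - xs[0]
    xc = np.clip(x, lower, upper)
    # Index of the grid interval that contains each clipped point.
    idx = np.minimum(((xc - lower) / h).astype(int), len(xs) - 2)
    t = (xc - xs[idx]) / h
    # Standard cubic Hermite basis functions on the unit interval.
    h00 = 2 * t**3 - 3 * t**2 + 1
    h10 = t**3 - 2 * t**2 + t
    h01 = -2 * t**3 + 3 * t**2
    h11 = t**3 - t**2
    inside = (
        h00 * ys[idx]
        + h10 * h * dys[idx]
        + h01 * ys[idx + 1]
        + h11 * h * dys[idx + 1]
    )
    # Tail slope: the derivative stored at the nearest boundary node.
    slope = np.where(x < lower, dys[0], dys[-1])
    # x - xc is zero inside the grid, so the tail term only acts outside it.
    return inside + slope * (x - xc)
```

Because the tail reuses the boundary value and the boundary derivative, both the value and the first derivative match across the table edge, which is the property the extrapolation tests are described as checking.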

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5323: Directly precedes this PR's work — adds the initial compress serialization/deserialization payload and version handling in dpa1.py that this PR greatly expands with enable_compression, tebd_compress, geo_compress, and tabulation logic.
  • deepmodeling/deepmd-kit#5428: Adjusted embedding-net output ranges and float32 compression test tolerances for DPA1/SeAttenV2, directly affecting the compressed-forward tests validated in this PR.

Suggested labels

Docs

Suggested reviewers

  • iProzd
  • wanghan-iapcm
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 30.30% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title is concise and matches the main theme of descriptor compression, though it understates the added JAX entrypoints and serialization work. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 2

🧹 Nitpick comments (3)
source/tests/jax/test_model_compression.py (1)

59-82: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Add an entrypoint round-trip case for DPA1 compression.

The descriptor tests cover DPA1 in-memory serialization, but both CLI/Orbax entrypoint tests use only se_e2_a. Add a DPA1 model-data variant so serialize_from_file() exercises restored type_embd_data/geometric compression state from a real .jax checkpoint.

Also applies to: 216-301

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/jax/test_model_compression.py` around lines 59 - 82, Add a DPA1
model-data variant to the compression tests so the entrypoint round-trip covers
more than se_e2_a. Update the model fixture-building helpers in
test_model_compression.py, especially _make_model_data and any related CLI/Orbax
test setup, to create a DPA1 checkpoint and verify serialize_from_file()
restores type_embd_data and geometric compression state from the saved .jax
file. Keep the existing se_e2_a coverage, but add a separate DPA1 case that
exercises the same entrypoint flow end-to-end.
deepmd/dpmodel/entrypoints/compress.py (2)

53-93: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

Consider de-duplicating the compress entrypoint logic shared with the JAX path.

enable_compression, _compute_min_nbor_dist, and the _get_saved_min_nbor_dist resolution are nearly identical to deepmd/jax/entrypoints/compress.py. Extracting the backend-agnostic min_nbor_dist resolution/compression flow into a shared helper would reduce drift (e.g., this path skips the _to_float coercion the JAX path applies). Optional given the differing serialization imports per backend.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/entrypoints/compress.py` around lines 53 - 93, The compression
flow in enable_compression duplicates the JAX entrypoint logic, so extract the
shared min_nbor_dist resolution and compression steps into a backend-agnostic
helper used by both paths. Keep the backend-specific model loading/saving in
place, but centralize the common _get_saved_min_nbor_dist,
_compute_min_nbor_dist, and model.enable_compression handling so behavior stays
aligned and drift like missing _to_float coercion is avoided.
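
A backend-agnostic helper of the kind suggested here could look roughly like the sketch below. The name `resolve_min_nbor_dist_sketch`, its signature, and the `_to_float` body are assumptions made for illustration; the PR's actual helpers in compress_common.py may differ.

```python
from typing import Callable, Optional


def _to_float(value) -> float:
    """Coerce a scalar or a one-element array-like to a built-in float."""
    if hasattr(value, "item"):
        # NumPy and JAX scalars/arrays expose .item() for this conversion.
        value = value.item()
    return float(value)


def resolve_min_nbor_dist_sketch(
    saved: Optional[object],
    compute: Callable[[], object],
) -> float:
    """Prefer the value stored with the model; otherwise compute it once."""
    value = saved if saved is not None else compute()
    if value is None:
        raise ValueError("min_nbor_dist is unavailable: none saved, none computed")
    # Applying the coercion in one place keeps both backends consistent.
    return _to_float(value)
```

Each backend would keep its own loading and saving code and pass in only the saved value and a callback that computes the distance from training data, so the coercion cannot drift between the two paths.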

82-87: 🎯 Functional Correctness | 🔵 Trivial

Confirm the positional argument order for model.enable_compression.

The signature of BaseModel.enable_compression is:
(table_extrapolate, table_stride_1, table_stride_2, check_frequency).

The call at lines 82-87:

model.enable_compression(
    extrapolate,
    stride,
    stride * 10,
    check_frequency,
)

correctly maps to the signature parameters in order. The calculation stride * 10 correctly assigns to table_stride_2.

While the current positional usage is correct, using explicit keyword arguments would improve readability and prevent future regressions if the signature changes:

model.enable_compression(
    table_extrapolate=extrapolate,
    table_stride_1=stride,
    table_stride_2=stride * 10,
    check_frequency=check_frequency,
)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/entrypoints/compress.py` around lines 82 - 87, The call to
model.enable_compression already matches BaseModel.enable_compression’s
parameter order, but it should be rewritten with explicit keyword arguments for
clarity and to avoid regressions; update the enable_compression call in
compress.py to name table_extrapolate, table_stride_1, table_stride_2, and
check_frequency explicitly while keeping the same values.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/descriptor/se_e2_a.py`:
- Line 738: The compressed `DescrptSeAArrayAPI` path uses an in-place slice
update on `gr_s`, which is not JAX-compatible. Update the accumulation logic
around `self._tabulate_fusion_se_a(...)` to use functional array updates (for
example via the array API update pattern on `gr_s`) instead of direct slice
assignment. Make sure the fix is applied in the `DescrptSeAArrayAPI` flow so
both `type_one_side=True` and `type_one_side=False` work correctly under JAX.

In `@deepmd/jax/utils/serialization.py`:
- Around line 47-54: The restore logic in the serialization helper only
recreates compress_data and compress_info, but DPA1/se_atten_v2 compressed
checkpoints also need type_embd_data to be restored before replace_by_pure_dict
runs. Update the state-reconstruction branch in the helper that checks
obj.compress so it also detects and assigns type_embd_data from state, using the
same sequence-to-numpy conversion pattern as the other compressed slots, so
compressed DPA1 checkpoints keep matching keys and preserve the type-embedding
payload.
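
The difference between in-place slice assignment and a functional update, as raised in the first inline comment, can be sketched with plain NumPy. The function names and array shapes below are hypothetical and are not taken from se_e2_a.py. JAX arrays are immutable, so the first pattern fails under JAX, where the usual replacement is either to build the pieces and join them or to use `x.at[...].add(...)`.

```python
import numpy as np


def accumulate_in_place(blocks, width):
    """In-place pattern: fine for NumPy, rejected by immutable JAX arrays."""
    out = np.zeros((blocks[0].shape[0], width * len(blocks)))
    for i, block in enumerate(blocks):
        out[:, i * width : (i + 1) * width] += block
    return out


def accumulate_functional(blocks, xp=np):
    """Functional pattern: join the per-slice results without mutating anything."""
    return xp.concatenate(list(blocks), axis=1)
```

Passing `jax.numpy` as `xp` would run the second function unchanged, since it never writes into an existing array.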

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: e7e1c4a8-5fb1-4cfb-8f9c-64794a8b169e

📥 Commits

Reviewing files that changed from the base of the PR and between 330fa75 and fea615f.

📒 Files selected for processing (14)
  • deepmd/backend/dpmodel.py
  • deepmd/dpmodel/descriptor/dpa1.py
  • deepmd/dpmodel/descriptor/se_atten_v2.py
  • deepmd/dpmodel/descriptor/se_e2_a.py
  • deepmd/dpmodel/descriptor/se_r.py
  • deepmd/dpmodel/entrypoints/__init__.py
  • deepmd/dpmodel/entrypoints/compress.py
  • deepmd/dpmodel/entrypoints/main.py
  • deepmd/jax/entrypoints/compress.py
  • deepmd/jax/entrypoints/main.py
  • deepmd/jax/utils/serialization.py
  • deepmd/main.py
  • source/tests/common/dpmodel/test_model_compression.py
  • source/tests/jax/test_model_compression.py

Comment thread deepmd/dpmodel/descriptor/se_e2_a.py Outdated
Comment thread deepmd/jax/utils/serialization.py Outdated
codecov Bot commented Jun 26, 2026


Codecov Report

❌ Patch coverage is 84.17850% with 78 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.42%. Comparing base (5082854) to head (a681715).
⚠️ Report is 1 commit behind head on master.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| deepmd/jax/utils/serialization.py | 83.01% | 18 Missing ⚠️ |
| deepmd/dpmodel/entrypoints/main.py | 0.00% | 16 Missing ⚠️ |
| deepmd/dpmodel/entrypoints/compress_common.py | 74.41% | 11 Missing ⚠️ |
| deepmd/dpmodel/descriptor/dpa1.py | 89.69% | 10 Missing ⚠️ |
| deepmd/dpmodel/descriptor/se_atten_v2.py | 0.00% | 9 Missing ⚠️ |
| deepmd/dpmodel/descriptor/se_e2_a.py | 92.78% | 7 Missing ⚠️ |
| deepmd/dpmodel/descriptor/se_r.py | 88.37% | 5 Missing ⚠️ |
| deepmd/backend/dpmodel.py | 0.00% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5592      +/-   ##
==========================================
+ Coverage   82.37%   82.42%   +0.05%     
==========================================
  Files         902      907       +5     
  Lines      101527   101998     +471     
  Branches     4056     4056              
==========================================
+ Hits        83630    84074     +444     
- Misses      16432    16459      +27     
  Partials     1465     1465              

@coderabbitai coderabbitai Bot left a comment

Contributor

🧹 Nitpick comments (2)
source/lib/tests/test_tabulate_extrapolate.cc (2)

202-210: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

grad_central_diff is identical to central_diff. Both compute the same central-difference formula. Either reuse central_diff for the constant-gradient check or, if the second helper is meant to document a distinct intent (numerically differentiating the gradient), add a brief comment; the duplicated body is otherwise confusing.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/lib/tests/test_tabulate_extrapolate.cc` around lines 202 - 210,
grad_central_diff duplicates central_diff with the same central-difference
formula, so either call central_diff from grad_central_diff to reuse the shared
logic or add a short comment in grad_central_diff explaining the distinct intent
if it is meant to represent differentiating the gradient. Update the test
helpers in test_tabulate_extrapolate.cc so the relationship between central_diff
and grad_central_diff is explicit and not misleading.
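
The reuse option can be shown in a few lines. The sketch below is written in Python rather than the C++ of the test file and only mirrors the helper names from the comment; the bodies are assumptions, not the code in test_tabulate_extrapolate.cc.

```python
def central_diff(f, x, eps=1e-5):
    """Second-order central-difference estimate of f'(x)."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


def grad_central_diff(grad, x, eps=1e-5):
    """Numerically differentiate a gradient function.

    The formula is identical to central_diff; the separate name only records
    that the input is already a derivative, so the result estimates f''(x).
    """
    return central_diff(grad, x, eps)
```

Delegating keeps a single copy of the formula while the docstring carries the distinct intent, which covers both alternatives the comment offers.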

232-268: 📐 Maintainability & Code Quality | 🔵 Trivial | 🏗️ Heavy lift

Second-order (*_grad_grad) extrapolation paths are untested.

These tests validate value and first-gradient linear tails, but the substantially-changed tabulate_fusion_se_a/se_t/se_t_tebd/se_r_grad_grad_* paths (which now fold the extrapolation term into var = poly5 + var_grad * extrapolate_delta) get no coverage here. Second-order math is the easiest to get wrong; consider adding grad_grad cases (e.g. a forward-over-reverse / numeric check on dz_dy below-lower and above-max).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/lib/tests/test_tabulate_extrapolate.cc` around lines 232 - 268, Add
coverage for the second-order extrapolation paths in the tabulation tests: the
current cases in TabulateExtrapolate only exercise value and first-derivative
linear tails, but do not validate the new `*_grad_grad` logic in
`tabulate_fusion_se_a`, `se_r`, `se_t`, and `se_t_tebd`. Extend the existing
test helpers around `expect_linear_tail`, `expect_boundary`, and the
`se_*_grad`/`se_*_value` checks with cases that call the `*_grad_grad` functions
below `kLower`/`kMin` and above `kMax`, and verify the second-derivative output
against a numeric or forward-over-reverse expectation. Use the existing
`locate_se_a_or_r`, `locate_se_t`, and `expected_table_*` helpers to keep the
new cases aligned with the current test structure.
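
As a rough picture of what such a grad_grad check would assert, the sketch below evaluates `poly5 + var_grad * extrapolate_delta` for a single hypothetical table interval and verifies numerically that the tail has the boundary slope and zero curvature. The coefficients, the single-interval layout, and the function names other than the three quoted symbols are assumptions, and the code is Python rather than the C++ of the kernels.

```python
def poly5(coeffs, t):
    """Evaluate a0 + a1*t + ... + a5*t**5 with Horner's rule."""
    result = 0.0
    for c in reversed(coeffs):
        result = result * t + c
    return result


def poly5_grad(coeffs, t):
    """Evaluate the derivative a1 + 2*a2*t + ... + 5*a5*t**4."""
    result = 0.0
    for k in range(5, 0, -1):
        result = result * t + k * coeffs[k]
    return result


def extrapolated_value(coeffs, x, x_max):
    """Quintic value inside the table, linear tail above x_max."""
    xc = min(x, x_max)
    extrapolate_delta = x - xc  # zero while x stays inside the table
    var_grad = poly5_grad(coeffs, xc)
    return poly5(coeffs, xc) + var_grad * extrapolate_delta


def second_diff(f, x, eps=1e-3):
    """Central finite-difference estimate of f''(x)."""
    return (f(x + eps) - 2.0 * f(x) + f(x - eps)) / (eps * eps)
```

A test along these lines would compare the kernel's second-order output above the upper bound against zero, and its first-order output against the stored boundary gradient.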

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 2ae98b78-4092-4e73-98f2-b7625c53eab2

📥 Commits

Reviewing files that changed from the base of the PR and between 52f8c44 and 661b3bd.

📒 Files selected for processing (3)
  • source/lib/src/gpu/tabulate.cu
  • source/lib/src/tabulate.cc
  • source/lib/tests/test_tabulate_extrapolate.cc

@njzjz njzjz force-pushed the feat/dpmodel-compress-descriptors branch from 661b3bd to 52f8c44 Compare June 28, 2026 16:04
@njzjz njzjz requested a review from wanghan-iapcm June 28, 2026 16:20

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 1

🧹 Nitpick comments (1)
source/tests/common/dpmodel/test_model_compression.py (1)

349-400: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Exercise one inference pass on the reloaded compressed model.

This test only checks serialized metadata. If the entrypoint preserves compress and min_nbor_dist but corrupts descriptor state during save/load, it still passes. Comparing one pre/post forward call on the fixture would close that gap.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/common/dpmodel/test_model_compression.py` around lines 349 -
400, The test in test_dpmodel_compress_entrypoint only validates serialized
metadata after enable_compression and load_dp_model, so it misses descriptor
corruption in the actual model state. After reloading the compressed model, run
one forward/inference pass on the compressed model and compare its output to the
original model from get_model/serialize to ensure the entrypoint preserves
usable behavior, not just the presence of compress and min_nbor_dist.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@source/tests/common/dpmodel/test_model_compression.py`:
- Around line 176-185: The `np.testing.assert_allclose` checks in
`test_model_compression` are too lenient because they still use the default
relative tolerance, so make the extrapolation/parity assertions strict by
setting `rtol=0` wherever `assert_allclose` is used in this test. Update the
affected checks in `test_model_compression` (including the repeated cases later
in the file) so they rely only on the explicit `atol` values.
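
The effect of the default relative tolerance can be reproduced in isolation. `np.testing.assert_allclose` accepts a difference of up to `atol + rtol * abs(desired)` and defaults to `rtol=1e-7`, so for large reference values the relative term can outweigh a small `atol`. The numbers below are made up for illustration and are unrelated to the PR's fixtures; `is_close_strict` is a hypothetical helper.

```python
import numpy as np

desired = np.array([1.0e6])
actual = desired + 0.05
atol = 1e-3

# The default rtol=1e-7 adds 1e-7 * 1e6 = 0.1 to the allowed error, so a
# deviation of 0.05 passes even though it is far larger than atol.
np.testing.assert_allclose(actual, desired, atol=atol)


def is_close_strict(actual, desired, atol):
    """Return True only if the arrays agree within atol alone (rtol=0)."""
    try:
        np.testing.assert_allclose(actual, desired, rtol=0, atol=atol)
    except AssertionError:
        return False
    return True
```

With `rtol=0` the same pair is rejected, which is the stricter behavior the comment asks for.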

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 157bc41e-23bb-4ee8-b572-69a0856c0ff3

📥 Commits

Reviewing files that changed from the base of the PR and between 661b3bd and 48e01a8.

📒 Files selected for processing (4)
  • deepmd/dpmodel/descriptor/dpa1.py
  • deepmd/dpmodel/descriptor/se_e2_a.py
  • deepmd/dpmodel/descriptor/se_r.py
  • source/tests/common/dpmodel/test_model_compression.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • deepmd/dpmodel/descriptor/se_r.py
  • deepmd/dpmodel/descriptor/se_e2_a.py
  • deepmd/dpmodel/descriptor/dpa1.py

Comment thread source/tests/common/dpmodel/test_model_compression.py Outdated
Comment thread deepmd/dpmodel/descriptor/dpa1.py
@njzjz njzjz requested a review from wanghan-iapcm June 29, 2026 11:11
Comment thread deepmd/dpmodel/descriptor/se_e2_a.py
Comment thread deepmd/dpmodel/descriptor/se_e2_a.py
@njzjz njzjz requested a review from wanghan-iapcm June 29, 2026 12:09
@njzjz njzjz added this pull request to the merge queue Jun 29, 2026
Merged via the queue into deepmodeling:master with commit f143171 Jun 29, 2026
70 checks passed
@njzjz njzjz deleted the feat/dpmodel-compress-descriptors branch June 29, 2026 23:19
SchrodingersCattt pushed a commit to SchrodingersCattt/deepmd-kit that referenced this pull request Jun 30, 2026