[WIP] Add backend-agnostic semi-discrete OT module and SGD-based solver in ot.semidiscrete #812

Open
Ferdinand-Genans wants to merge 11 commits into PythonOT:master from Ferdinand-Genans:feature/semi-discrete

@Ferdinand-Genans

Types of changes

  • New feature (non-breaking change which adds functionality)
  • Bug fix (non-breaking change which fixes an issue)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation / Examples update

Motivation and context / Related issue

Semi-discrete optimal transport — continuous source, discrete target — is a setting that does not naturally fit POT's discrete-discrete solvers, yet appears in many applications: Monge–Kantorovich statistical depth and quantiles, generative modeling with a continuous prior, Brenier-map estimation, and any setup where the source is given by a sampler rather than a finite empirical distribution.

This PR adds ot.semidiscrete, a single-file backend-agnostic solver for the semi-dual of semi-discrete OT. Every public function takes a reg: float argument, so the user can switch between unregularized OT (reg=0) and the entropic formulation (reg>0) with a single parameter. The underlying algorithm is averaged SGD on the semi-dual; optionally, the regularization can be decreased over the iterations via the DRAG schedule ([83] in README.md; Genans et al., NeurIPS 2025), which empirically improves convergence when the target regularization is small, especially on large-scale problems.
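To make the algorithm concrete, here is a minimal NumPy sketch of averaged SGD on the unregularized semi-dual (reg=0) with a squared Euclidean cost. It is not the code from this PR: the names `averaged_sgd_semidual` and `sq_cost`, the step size `lr0 / t**lr_exponent` and the 1D toy problem are assumptions made for illustration, and projection and the DRAG schedule are left out.

```python
import numpy as np


def sq_cost(x, y):
    """Squared Euclidean cost matrix, shape (n_samples, n_atoms)."""
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)


def averaged_sgd_semidual(y, b, sampler, n_iter=3000, batch_size=32,
                          lr0=1.0, lr_exponent=2.0 / 3.0):
    """Stochastic ascent on S(g) = <g, b> + E[min_j c(X, y_j) - g_j]."""
    g = np.zeros(len(b))
    g_avg = np.zeros(len(b))
    for t in range(1, n_iter + 1):
        x = sampler(batch_size)
        # Laguerre cell index of each sample under the current potential.
        j_star = np.argmin(sq_cost(x, y) - g[None, :], axis=1)
        cell_mass = np.bincount(j_star, minlength=len(b)) / batch_size
        # Stochastic gradient of the semi-dual: target mass minus cell mass.
        g = g + lr0 / t**lr_exponent * (b - cell_mass)
        # Running (Polyak) average of the iterates.
        g_avg += (g - g_avg) / t
    return g_avg


# Hypothetical toy problem: uniform source on [0, 1], two weighted atoms.
rng = np.random.default_rng(0)
y = np.array([[0.25], [0.75]])
b = np.array([0.25, 0.75])
g_hat = averaged_sgd_semidual(y, b, lambda n: rng.random((n, 1)))

# For this problem the cell boundary sits at 0.5 + g[0] - g[1];
# at the optimum the first cell must carry mass b[0] = 0.25.
boundary = 0.5 + g_hat[0] - g_hat[1]
```

In this toy problem the first Laguerre cell is the interval from 0 to `boundary`, so a `boundary` near 0.25 indicates that the averaged potential assigns the requested mass to each atom.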

No existing issue tracks this addition.

How has this been tested

  • test/test_semidiscrete.py: 34 tests, automatically parametrized over every available POT backend (NumPy and PyTorch on the CI runner). Covers:
    • convergence to a known closed-form optimum on three toy problems (regular grid, nonuniform target weights, shifted 1D);
    • both plain SGD and the DRAG schedule;
    • the entropic regime;
    • a user-supplied custom cost;
    • helper functions (atom_weights row-stochasticity, c_transform reducing to min_j c(x, y_j) at g = 0, ot_map shape and finiteness);
    • solver options (warm start, projection bound, log dict, polyak_average=False, init_potential is not mutated).
  • One doctest on solve_semidiscrete (a 500-iter smoke run).
  • Full suite + doctest run locally: 35 passed in 5.88 s.
  • examples/others/plot_semidiscrete.py runs end-to-end in ~5 s on CPU NumPy and produces the Laguerre-cells figure that becomes the gallery page.
  • Also verified locally on a larger-scale problem on GPU with PyTorch: a discrete measure with 20 000 atoms, solved with 30 000 iterations on minibatches of 64, runs in ~17 s on a single CUDA device.
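The helper contracts listed above can be illustrated with a standalone NumPy sketch. The functions `toy_c_transform` and `toy_atom_weights` are hypothetical stand-ins written for this illustration (squared Euclidean cost, softmax weights for reg > 0). They are not the functions in ot/semidiscrete.py, and the actual tests may check these properties differently.

```python
import numpy as np


def sq_cost(x, y):
    """Squared Euclidean cost matrix, shape (n_samples, n_atoms)."""
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)


def toy_c_transform(x, y, g):
    """Unregularized c-transform: min_j c(x, y_j) - g_j for each sample."""
    return (sq_cost(x, y) - g[None, :]).min(axis=1)


def toy_atom_weights(x, y, g, b, reg):
    """Entropic assignment weights: a softmax over atoms for each sample."""
    score = (g[None, :] - sq_cost(x, y)) / reg + np.log(b)[None, :]
    score -= score.max(axis=1, keepdims=True)  # stabilise the exponentials
    w = np.exp(score)
    return w / w.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
x = rng.random((8, 2))   # source samples
y = rng.random((5, 2))   # target atoms
b = np.full(5, 1 / 5)    # uniform target weights

# At g = 0 the c-transform reduces to the cost to the nearest atom.
ctransform_ok = np.allclose(
    toy_c_transform(x, y, np.zeros(5)), sq_cost(x, y).min(axis=1)
)

# Each row of the weight matrix is a probability vector over the atoms.
weights = toy_atom_weights(x, y, rng.standard_normal(5), b, reg=0.1)
row_sums = weights.sum(axis=1)
```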

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

Introduces ot.semidiscrete: Projected Averaged SGD on the semi-dual,
with an optional decreasing entropic-regularization schedule (DRAG)
from Genans et al. 2025. Works with NumPy, PyTorch, JAX, CuPy and
TensorFlow via ot.backend.

- ot/semidiscrete.py: solve_semidiscrete, atom_weights, ot_map,
  c_transform. Closed-form gradient, no autograd graph through the
  loop; quadratic cost by default with custom-callable override.
- ot/__init__.py: register the new submodule.
- test/test_semidiscrete.py: convergence on three toy problems
  with known optimal potentials, helper-function contracts (row-
  stochasticity of atom_weights, identity for c_transform at g=0,
  shape and finiteness of ot_map), and solver options (warm-start,
  projection, log, polyak_average off, entropic regime, custom
  cost). All tests parametrized over the nx fixture (NumPy + PyTorch).
- examples/others/plot_semidiscrete.py: gallery example on a small
  2D toy problem with Laguerre cells, empirical cell masses and a
  Monte Carlo estimate of the semi-dual cost.
- RELEASES.md: new-features entry under 0.9.7.dev0.
	modified:   RELEASES.md
	modified:   examples/others/plot_semidiscrete.py

Final small doc modifications.
solve_semidiscrete, and added a more detailed
explanation of the effect of this argument on
convergence in the example script.
@codecov

codecov Bot commented May 22, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 96.84%. Comparing base (3e10dd6) to head (a5aa34d).

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #812      +/-   ##
==========================================
+ Coverage   96.81%   96.84%   +0.02%     
==========================================
  Files         124      126       +2     
  Lines       24276    24502     +226     
==========================================
+ Hits        23503    23729     +226     
  Misses        773      773              

@rflamary

Collaborator

Hello @Ferdinand-Genans ,

Thanks for the PR. We have since merged a paper, so could you please resolve the conflicts and update the reference number (there is already an [83] in the README file now)?

I will have a look and do a proper code review ASAP.

@rflamary changed the title from "[MRG] Add backend-agnostic semi-discrete OT module and SGD-based solver in ot.semidiscrete" to "[WIP] Add backend-agnostic semi-discrete OT module and SGD-based solver in ot.semidiscrete" on May 27, 2026
- README.md: keep new [83] (Spectral-Grassmann) and [84] (BSP-OT) from
  upstream; renumber the DRAG citation to [85].
- RELEASES.md: keep both feature bullets (semidiscrete + sgot).
- ot/__init__.py: keep both new submodule imports (semidiscrete + sgot).
- examples/others/plot_semidiscrete.py: update local refs [83] -> [85]
  consistently with the README.
@Ferdinand-Genans force-pushed the feature/semi-discrete branch from bc2b63c to 50f4094 on May 29, 2026 08:04
@Ferdinand-Genans

Author

Hi @rflamary,
I think it is good now.
Best regards

@rflamary left a comment

Collaborator

Hello @Ferdinand-Genans ,

This is a very nice PR. I did a code review and have a few comments, mainly on the API, to make it more consistent with the other POT functions. Also some suggestions for the example.

Comment thread ot/semidiscrete.py
return nx.where(mask, one, zero)


def atom_weights(

Collaborator

Suggested change
- def atom_weights(
+ def semidiscrete_atom_weights(

Comment thread ot/semidiscrete.py
return _atom_weights(score, reg, log_b, nx)


def ot_map(

Collaborator

Suggested change
- def ot_map(
+ def semidiscrete_ot_map(

Comment thread ot/semidiscrete.py
cost=None,
reg=0.0,
):
r"""Row-stochastic atom-assignment weights induced by ``semi_dual_potential``.

Collaborator

Please write proper documentation with detailed input and output descriptions. Define what is computed with math equations and refer to the equations in papers/books.
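For reference, the quantities such a docstring would typically define can be written as follows. This is the standard semi-dual formulation, not a transcription of the implementation: the notation (potential g, target weights b, atoms y_j, cost c, and ε standing for reg) is an assumption chosen to match the names used in this PR.

```latex
\begin{aligned}
\text{reg} = 0:\quad
  & \varphi_g(x) = \min_{j}\, \bigl( c(x, y_j) - g_j \bigr), \\
  & w_j(x) = \mathbf{1}\bigl[\, j = \arg\min_{k}\, \bigl( c(x, y_k) - g_k \bigr) \bigr]
    \quad \text{(for almost every } x\text{)}, \\[4pt]
\text{reg} = \varepsilon > 0:\quad
  & \varphi_g^{\varepsilon}(x) = -\varepsilon \log \sum_{k} b_k
    \exp\!\Bigl( \tfrac{g_k - c(x, y_k)}{\varepsilon} \Bigr), \\
  & w_j^{\varepsilon}(x) =
    \frac{b_j \exp\bigl( (g_j - c(x, y_j)) / \varepsilon \bigr)}
         {\sum_{k} b_k \exp\bigl( (g_k - c(x, y_k)) / \varepsilon \bigr)}, \\[4pt]
\text{in both cases:}\quad
  & \sum_{j} w_j(x) = 1, \qquad
    \nabla_g\, \mathcal{S}(g) = b - \mathbb{E}_X\bigl[ w(X) \bigr].
\end{aligned}
```

The last line is the row-stochasticity property that the tests check, together with the gradient of the semi-dual objective that the SGD iterations estimate from minibatches.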

Comment thread ot/semidiscrete.py
cost=None,
reg=0.0,
):
r"""Transport map :math:`T(x) = \sum_j w_j(x)\, y_j` induced by the potential."""

Collaborator

Same here: also add detailed documentation.

# A single call to :func:`solve_semidiscrete` runs DRAG with the default
# arguments (``decreasing_reg=True``). We show the initial Voronoi cells
# (:math:`g = 0`) next to the Laguerre cells at the optimum.
# In this problem, the maximum cost between samples is 1.0, so we pass it as

Collaborator

Is it 1? I would say that since it is the square it would be \sqrt(2), no? Which max cost is it, the one between the continuous distribution and the discrete one?
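The answer depends on the domain and on the cost, which can be checked numerically. The sketch below assumes, hypothetically, a source supported on the unit square and 15 random atoms inside it; neither is taken from the example script. On that domain the Euclidean distance is at most sqrt(2) and the squared Euclidean cost at most 2, and the actual maximum between the source support and a given set of atoms can be smaller.

```python
from itertools import product

import numpy as np

rng = np.random.default_rng(0)
atoms = rng.random((15, 2))  # hypothetical target atoms in [0, 1]^2

# The squared distance to a fixed atom is convex in x, so its maximum over
# the square is attained at one of the four corners.
corners = np.array(list(product([0.0, 1.0], repeat=2)))
sq = ((corners[:, None, :] - atoms[None, :, :]) ** 2).sum(axis=-1)

max_sq_cost = sq.max()           # max of the squared Euclidean cost
max_dist = np.sqrt(max_sq_cost)  # max of the Euclidean distance
```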

plot_laguerre_cells(target_positions, g_drag, axes[1], "DRAG")
plt.tight_layout()
plt.show()

Collaborator

Maybe do a visualization of the map from a grid on the source with arrows? It would be a nice visualization complementary to the cells and would show an example of the use of the mapping function.
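One way to prepare such an arrow plot is sketched below in plain NumPy. It assumes the unregularized map with a squared Euclidean cost, where each grid point is sent to the atom of its Laguerre cell. The atoms and the zero potential are placeholders rather than the example's data, and the PR's own mapping function is not called here.

```python
import numpy as np

rng = np.random.default_rng(0)
atoms = rng.random((15, 2))  # placeholder target atoms
g = np.zeros(15)             # placeholder potential (g = 0 gives Voronoi cells)

# Regular grid of source points in the unit square.
xs = np.linspace(0.05, 0.95, 10)
grid = np.stack(np.meshgrid(xs, xs), axis=-1).reshape(-1, 2)

# Hard assignment: each grid point goes to the atom minimising c(x, y_j) - g_j.
cost = ((grid[:, None, :] - atoms[None, :, :]) ** 2).sum(axis=-1)
j_star = np.argmin(cost - g[None, :], axis=1)
mapped = atoms[j_star]
disp = mapped - grid  # arrow from each grid point to its image

# With matplotlib, the arrows could then be drawn with, for example:
# plt.quiver(grid[:, 0], grid[:, 1], disp[:, 0], disp[:, 1],
#            angles="xy", scale_units="xy", scale=1)
```

Replacing `g` with the potential returned by the solver would turn the Voronoi assignment into the Laguerre one, so the same arrows would show how the optimal map differs from nearest-atom assignment.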

# .. math::
# \mathcal{S}(g) = \langle g, b\rangle + \mathbb{E}_X[\varphi_g(X)]
#
# estimated by Monte Carlo. The solver maximises :math:`\mathcal{S}`.

Collaborator

Visualize the mass with a bar plot to see the small discrepancy (with ylim around 1/15 and between 0 and 1/15 to show it clearly)?

Comment thread ot/semidiscrete.py
Comment on lines +132 to +149
def solve_semidiscrete(
target_positions,
source_sampler,
target_weights=None,
cost=None,
reg=0.0,
n_iter=10_000,
batch_size=32,
lr0=None,
lr_exponent=2.0 / 3.0,
init_potential=None,
decreasing_reg=True,
decreasing_reg_initial_eps=0.1,
decreasing_reg_exponent=0.5,
max_cost=None,
polyak_average=True,
log=False,
):

Collaborator

Suggested change
- def solve_semidiscrete(
-     target_positions,
-     source_sampler,
-     target_weights=None,
-     cost=None,
-     reg=0.0,
-     n_iter=10_000,
-     batch_size=32,
-     lr0=None,
-     lr_exponent=2.0 / 3.0,
-     init_potential=None,
-     decreasing_reg=True,
-     decreasing_reg_initial_eps=0.1,
-     decreasing_reg_exponent=0.5,
-     max_cost=None,
-     polyak_average=True,
-     log=False,
- ):
+ def solve_semidiscrete(
+     X_target,
+     sampler_source,
+     a_target=None,
+     metric=None,
+     reg=0.0,
+     max_iter=10_000,
+     batch_size=32,
+     lr0=None,
+     lr_exponent=2.0 / 3.0,
+     init_potential=None,
+     decreasing_reg=True,
+     decreasing_reg_initial_eps=0.1,
+     decreasing_reg_exponent=0.5,
+     max_metric=None,
+     polyak_average=True,
+     log=False,
+ ):

Renamed parameters to have an API closer to ot.solve_sample. All other public functions should also have renamed params (map, weight, c_transform).

cost -> metric should accept 'sqeuclidean' by default, but also 'euclidean' and other options from ot.dist (maybe it should call ot.dist if not a function?).

sampler_source should accept 'unif', 'unit_ball' and 'normal' in addition to a function that takes the batch size as a parameter.
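A possible shape for the string-or-callable `sampler_source` argument is sketched below. The helper name `resolve_sampler`, its `dim` and `seed` parameters, and the exact set of accepted strings are assumptions for illustration only; nothing here is part of POT or of this PR.

```python
import numpy as np


def resolve_sampler(sampler_source, dim, seed=None):
    """Return a function n -> (n, dim) array of source samples.

    Hypothetical helper: callables are passed through unchanged, strings
    select a built-in distribution.
    """
    if callable(sampler_source):
        return sampler_source

    rng = np.random.default_rng(seed)
    if sampler_source == "unif":
        return lambda n: rng.random((n, dim))  # uniform on [0, 1)^dim
    if sampler_source == "normal":
        return lambda n: rng.standard_normal((n, dim))
    if sampler_source == "unit_ball":
        def sample_ball(n):
            # Uniform direction on the sphere times a radius with density
            # proportional to r^(dim - 1) gives a uniform draw in the ball.
            direction = rng.standard_normal((n, dim))
            direction /= np.linalg.norm(direction, axis=1, keepdims=True)
            radius = rng.random((n, 1)) ** (1.0 / dim)
            return direction * radius
        return sample_ball
    raise ValueError(f"Unknown sampler_source: {sampler_source!r}")


ball_batch = resolve_sampler("unit_ball", dim=3, seed=0)(200)
unif_batch = resolve_sampler("unif", dim=2, seed=0)(50)
```

Resolving the argument once, before the SGD loop, keeps the loop itself unaware of whether the user passed a string or a function.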

Comment thread test/test_semidiscrete.py
return float(np.linalg.norm(estimated - reference))


def lift(nx, target_np, weights_np):

Collaborator

You can replace this function with
target, weights = nx.from_numpy(target_np, weights_np)

No need for a new function there.

Comment thread ot/semidiscrete.py
Comment on lines +132 to +149
def solve_semidiscrete(
target_positions,
source_sampler,
target_weights=None,
cost=None,
reg=0.0,
n_iter=10_000,
batch_size=32,
lr0=None,
lr_exponent=2.0 / 3.0,
init_potential=None,
decreasing_reg=True,
decreasing_reg_initial_eps=0.1,
decreasing_reg_exponent=0.5,
max_cost=None,
polyak_average=True,
log=False,
):

Collaborator

Suggested change
- def solve_semidiscrete(
-     target_positions,
-     source_sampler,
-     target_weights=None,
-     cost=None,
-     reg=0.0,
-     n_iter=10_000,
-     batch_size=32,
-     lr0=None,
-     lr_exponent=2.0 / 3.0,
-     init_potential=None,
-     decreasing_reg=True,
-     decreasing_reg_initial_eps=0.1,
-     decreasing_reg_exponent=0.5,
-     max_cost=None,
-     polyak_average=True,
-     log=False,
- ):
+ def solve_semidiscrete(
+     X_target,
+     sampler_source,
+     a_target=None,
+     metric=None,
+     reg=0.0,
+     max_iter=10_000,
+     batch_size=32,
+     lr0=None,
+     lr_exponent=2.0 / 3.0,
+     init_potential=None,
+     decreasing_reg=True,
+     decreasing_reg_initial_eps=0.1,
+     decreasing_reg_exponent=0.5,
+     max_cost=None,
+     polyak_average=True,
+     log=False,
+ ):

A few changes in parameter names to be more consistent with ot.solve_sample. Please change them in all public-facing functions (atom_weight, map, c_transform).

sampler_source should also accept strings for 'unif' (or 'unif_cube'), 'ball' or 'unif_ball', and 'normal' if not callable; it should be 'unif' by default.

cost -> metric should also accept any existing string from ot.dist (it should use ot.dist in the function) if not callable. It should be 'sqeuclidean' by default.
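The string-or-callable dispatch for `metric` could look roughly like the sketch below. To stay self-contained it implements three metrics directly in NumPy instead of calling ot.dist; inside POT the string branch would presumably delegate to ot.dist(x, y, metric=metric). The helper name `resolve_metric` is hypothetical.

```python
import numpy as np


def resolve_metric(metric="sqeuclidean"):
    """Return a function (x, y) -> cost matrix of shape (n, m).

    Hypothetical helper: callables are passed through unchanged, strings
    pick one of a few built-in pairwise costs.
    """
    if callable(metric):
        return metric

    def diff(x, y):
        return x[:, None, :] - y[None, :, :]

    table = {
        "sqeuclidean": lambda x, y: (diff(x, y) ** 2).sum(axis=-1),
        "euclidean": lambda x, y: np.sqrt((diff(x, y) ** 2).sum(axis=-1)),
        "cityblock": lambda x, y: np.abs(diff(x, y)).sum(axis=-1),
    }
    if metric not in table:
        raise ValueError(f"Unknown metric: {metric!r}")
    return table[metric]


x = np.array([[0.0, 0.0]])
y = np.array([[3.0, 4.0]])
sq_value = resolve_metric()(x, y)[0, 0]            # default: 'sqeuclidean'
eu_value = resolve_metric("euclidean")(x, y)[0, 0]
l1_value = resolve_metric("cityblock")(x, y)[0, 0]
```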
