[WIP] Add backend-agnostic semi-discrete OT module and SGD-based solver in ot.semidiscrete#812
[WIP] Add backend-agnostic semi-discrete OT module and SGD-based solver in ot.semidiscrete#812Ferdinand-Genans wants to merge 11 commits into
Conversation
Introduces ot.semidiscrete: Projected Averaged SGD on the semi-dual, with an optional decreasing entropic-regularization schedule (DRAG) from Genans et al. 2025. Works with NumPy, PyTorch, JAX, CuPy and TensorFlow via ot.backend. - ot/semidiscrete.py: solve_semidiscrete, atom_weights, ot_map, c_transform. Closed-form gradient, no autograd graph through the loop; quadratic cost by default with custom-callable override. - ot/__init__.py: register the new submodule. - test/test_semidiscrete.py: convergence on three toy problems with known optimal potentials, helper-function contracts (row- stochasticity of atom_weights, identity for c_transform at g=0, shape and finiteness of ot_map), and solver options (warm-start, projection, log, polyak_average off, entropic regime, custom cost). All tests parametrized over the nx fixture (NumPy + PyTorch). - examples/others/plot_semidiscrete.py: gallery example on a small 2D toy problem with Laguerre cells, empirical cell masses and a Monte Carlo estimate of the semi-dual cost. - RELEASES.md: new-features entry under 0.9.7.dev0.
modified: RELEASES.md modified: examples/others/plot_semidiscrete.py Final small doc modifications.
solve_semidiscrete, and added a more detailed explanation of the effect of this argument on convergence in the example scipt.
"max_cost" explanation.
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #812 +/- ##
==========================================
+ Coverage 96.81% 96.84% +0.02%
==========================================
Files 124 126 +2
Lines 24276 24502 +226
==========================================
+ Hits 23503 23729 +226
Misses 773 773 🚀 New features to boost your workflow:
|
|
Hello @Ferdinand-Genans , thanks for the PR, we have since merged a paper, could you please resolve conflits and update the reference number ()there is already a 83 now in teh readme file). I will have a look and do a proper code review ASAP. |
- README.md: keep new [83] (Spectral-Grassmann) and [84] (BSP-OT) from upstream; renumber the DRAG citation to [85]. - RELEASES.md: keep both feature bullets (semidiscrete + sgot). - ot/__init__.py: keep both new submodule imports (semidiscrete + sgot). - examples/others/plot_semidiscrete.py: update local refs [83] -> [85] consistently with the README.
bc2b63c to
50f4094
Compare
|
Hi @rflamary, |
rflamary
left a comment
There was a problem hiding this comment.
Hello @Ferdinand-Genans ,
This is a very nice PR, I did a code reveiw and have a few comments mainly on the API to make i more consistent with the other POT functions. Also some suggetsion for the example.
| return nx.where(mask, one, zero) | ||
|
|
||
|
|
||
| def atom_weights( |
There was a problem hiding this comment.
| def atom_weights( | |
| def semidiscrete_atom_weights( |
| return _atom_weights(score, reg, log_b, nx) | ||
|
|
||
|
|
||
| def ot_map( |
There was a problem hiding this comment.
| def ot_map( | |
| def semidiscrete_ot_map( |
| cost=None, | ||
| reg=0.0, | ||
| ): | ||
| r"""Row-stochastic atom-assignment weights induced by ``semi_dual_potential``. |
There was a problem hiding this comment.
please do a proper documentation with detailed input and output description. define what is computed with math equations and refer to eq. in papers/books.
| cost=None, | ||
| reg=0.0, | ||
| ): | ||
| r"""Transport map :math:`T(x) = \sum_j w_j(x)\, y_j` induced by the potential.""" |
There was a problem hiding this comment.
Same here alos add detailed documentation
| # A single call to :func:`solve_semidiscrete` runs DRAG with the default | ||
| # arguments (``decreasing_reg=True``). We show the initial Voronoi cells | ||
| # (:math:`g = 0`) next to the Laguerre cells at the optimum. | ||
| # In this problem, the maximum cost between samples is 1.0, so we pass it as |
There was a problem hiding this comment.
is it 1? I would say since it is the sqare it would be \sqrt(2) no? what max cost is it between the continuous dist and the discrete one?
| plot_laguerre_cells(target_positions, g_drag, axes[1], "DRAG") | ||
| plt.tight_layout() | ||
| plt.show() | ||
|
|
There was a problem hiding this comment.
maybe do a visualization of the map from a grid on source with arrows? it would e a nice complementary visualization than the cells and would show an example of the use of the mapping function
| # .. math:: | ||
| # \mathcal{S}(g) = \langle g, b\rangle + \mathbb{E}_X[\varphi_g(X)] | ||
| # | ||
| # estimated by Monte Carlo. The solver maximises :math:`\mathcal{S}`. |
There was a problem hiding this comment.
visualize the mass with a bar plot to see the small discrepancy (with ylim around 1/15 and betwen 0 and 1/15 to show clearly)?
| def solve_semidiscrete( | ||
| target_positions, | ||
| source_sampler, | ||
| target_weights=None, | ||
| cost=None, | ||
| reg=0.0, | ||
| n_iter=10_000, | ||
| batch_size=32, | ||
| lr0=None, | ||
| lr_exponent=2.0 / 3.0, | ||
| init_potential=None, | ||
| decreasing_reg=True, | ||
| decreasing_reg_initial_eps=0.1, | ||
| decreasing_reg_exponent=0.5, | ||
| max_cost=None, | ||
| polyak_average=True, | ||
| log=False, | ||
| ): |
There was a problem hiding this comment.
| def solve_semidiscrete( | |
| target_positions, | |
| source_sampler, | |
| target_weights=None, | |
| cost=None, | |
| reg=0.0, | |
| n_iter=10_000, | |
| batch_size=32, | |
| lr0=None, | |
| lr_exponent=2.0 / 3.0, | |
| init_potential=None, | |
| decreasing_reg=True, | |
| decreasing_reg_initial_eps=0.1, | |
| decreasing_reg_exponent=0.5, | |
| max_cost=None, | |
| polyak_average=True, | |
| log=False, | |
| ): | |
| def solve_semidiscrete( | |
| X_target, | |
| sampler_source, | |
| a_target=None, | |
| metric=None, | |
| reg=0.0, | |
| max_iter=10_000, | |
| batch_size=32, | |
| lr0=None, | |
| lr_exponent=2.0 / 3.0, | |
| init_potential=None, | |
| decreasing_reg=True, | |
| decreasing_reg_initial_eps=0.1, | |
| decreasing_reg_exponent=0.5, | |
| max_metric=None, | |
| polyak_average=True, | |
| log=False, | |
| ): |
Renamed parameter to have a closer API to ot.solve_sample. all other public functions shoudl also have renamed params (map, weight, c_transform)
cost-> metric should accept 'sqeuclidean' by default but also euclidean and other stuff from ot.dist (maybe should call ot.dist if not a function?)
sampler_source should accept 'unif', 'unit_ball' et 'normal' in addition to a function that takes the size of eh batch in parameter
| return float(np.linalg.norm(estimated - reference)) | ||
|
|
||
|
|
||
| def lift(nx, target_np, weights_np): |
There was a problem hiding this comment.
you can replace this function by
target, weights = nx.from_numpy(target_np,weights)
no need for a new function there
| def solve_semidiscrete( | ||
| target_positions, | ||
| source_sampler, | ||
| target_weights=None, | ||
| cost=None, | ||
| reg=0.0, | ||
| n_iter=10_000, | ||
| batch_size=32, | ||
| lr0=None, | ||
| lr_exponent=2.0 / 3.0, | ||
| init_potential=None, | ||
| decreasing_reg=True, | ||
| decreasing_reg_initial_eps=0.1, | ||
| decreasing_reg_exponent=0.5, | ||
| max_cost=None, | ||
| polyak_average=True, | ||
| log=False, | ||
| ): |
There was a problem hiding this comment.
| def solve_semidiscrete( | |
| target_positions, | |
| source_sampler, | |
| target_weights=None, | |
| cost=None, | |
| reg=0.0, | |
| n_iter=10_000, | |
| batch_size=32, | |
| lr0=None, | |
| lr_exponent=2.0 / 3.0, | |
| init_potential=None, | |
| decreasing_reg=True, | |
| decreasing_reg_initial_eps=0.1, | |
| decreasing_reg_exponent=0.5, | |
| max_cost=None, | |
| polyak_average=True, | |
| log=False, | |
| ): | |
| def solve_semidiscrete( | |
| X_target, | |
| sampler_source, | |
| a_target=None, | |
| metric=None, | |
| reg=0.0, | |
| max_iter=10_000, | |
| batch_size=32, | |
| lr0=None, | |
| lr_exponent=2.0 / 3.0, | |
| init_potential=None, | |
| decreasing_reg=True, | |
| decreasing_reg_initial_eps=0.1, | |
| decreasing_reg_exponent=0.5, | |
| max_cost=None, | |
| polyak_average=True, | |
| log=False, | |
| ): |
few change in name of parameters to be more consistent with ot.solve_sample. please change them in all public facing function (atom_weight, map, c_transform)
sampler_source should accept also strings for 'unif' (or 'unif_cube'), 'ball' or 'unif_ball' and 'normal' if not callable , should be unif by default.
cost -> metric should also accept any existing string from ot.dist (it should use ot.dist in the function) if not callable. shoudl be 'sqeuclidean' by default.
Types of changes
Motivation and context / Related issue
Semi-discrete optimal transport — continuous source, discrete target — is a setting that does not naturally fit POT's discrete-discrete solvers, yet appears in many applications: Monge–Kantorovich statistical depth and quantiles, generative modeling with a continuous prior, Brenier-map estimation, and any setup where the source is given by a sampler rather than a finite empirical distribution.
This PR adds
ot.semidiscrete, a single-file backend-agnostic solver for the semi-dual of semi-discrete OT. Every public function takes areg: floatargument, so the user can switch between unregularized OT (reg=0) and the entropic formulation (reg>0) with a single parameter. The underlying algorithm is Averaged SGD on the semi-dual; optionally, the regularization can be decreased throughout the iterations via the DRAG schedule ([83] inREADME.md; Genans et al., NeurIPS 2025), which empirically improves convergence, especially on large scale problems, when the target regularization is small.No existing issue tracks this addition.
How has this been tested
test/test_semidiscrete.py— 34 tests, automatically parametrized over every available POT backend (NumPy and PyTorch on the CI runner). Covers:atom_weightsrow-stochasticity,c_transformreducing tomin_j c(x, y_j)atg = 0,ot_mapshape and finiteness);polyak_average=False,init_potentialis not mutated).solve_semidiscrete(a 500-iter smoke run).35 passed in 5.88 s.examples/others/plot_semidiscrete.pyruns end-to-end in ~5 s on CPU NumPy and produces the Laguerre-cells figure that becomes the gallery page.PR checklist