Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
44d4614
merge
cedricvincentcuaz Sep 10, 2024
63477c2
Merge branch 'master' of https://github.com/cedricvincentcuaz/POT
cedricvincentcuaz Sep 10, 2024
a94c6ac
Merge branch 'master' of https://github.com/cedricvincentcuaz/POT
cedricvincentcuaz Nov 6, 2024
27944a5
Merge branch 'master' of https://github.com/cedricvincentcuaz/POT
cedricvincentcuaz Nov 6, 2024
0392961
Merge branch 'master' of https://github.com/cedricvincentcuaz/POT
cedricvincentcuaz Nov 16, 2024
60d1295
Merge branch 'master' of https://github.com/cedricvincentcuaz/POT
cedricvincentcuaz Apr 2, 2025
a93c60c
first commit
cedricvincentcuaz Apr 21, 2025
9e25e80
handle masses in unbalanced cases
cedricvincentcuaz Apr 22, 2025
46c4638
update free support
cedricvincentcuaz Apr 24, 2025
671788d
trying to fix tests
cedricvincentcuaz Apr 24, 2025
a5b0f70
Merge branch 'master' into solvers
rflamary May 23, 2025
9cf60fd
Merge branch 'master' into solvers
rflamary Jun 3, 2025
48a63ea
merge
cedricvincentcuaz Oct 11, 2025
df1da8d
small updates
cedricvincentcuaz Oct 12, 2025
75d6c11
fix fun name
cedricvincentcuaz Oct 12, 2025
c71e544
update tests
cedricvincentcuaz Oct 12, 2025
ed9c992
update tests
cedricvincentcuaz Oct 12, 2025
7132f62
fix tests
cedricvincentcuaz Oct 15, 2025
8e55777
Merge branch 'master' into solvers
rflamary Oct 21, 2025
4a0b5ec
fix tests
cedricvincentcuaz Mar 6, 2026
b8219ac
Merge branch 'solvers' of https://github.com/cedricvincentcuaz/POT in…
cedricvincentcuaz Mar 6, 2026
6f900c6
Merge branch 'master' into solvers
cedricvincentcuaz Mar 6, 2026
3d9f987
update docstring for solve_bary_sample
cedricvincentcuaz Mar 9, 2026
7b7cfc0
update plot quickstart guide
cedricvincentcuaz Mar 9, 2026
406d026
add ex
cedricvincentcuaz Mar 9, 2026
0172a89
fix ex
cedricvincentcuaz Mar 10, 2026
d493194
fix docs
cedricvincentcuaz Mar 11, 2026
95e5a1c
fix docs
cedricvincentcuaz Mar 11, 2026
05b47fc
fix sphinx
cedricvincentcuaz Mar 11, 2026
ab50009
Merge branch 'master' into solvers
rflamary May 27, 2026
1c70c2e
add callable cost functions
cedricvincentcuaz May 28, 2026
c8d48b2
merge
cedricvincentcuaz May 28, 2026
53d103b
Merge branch 'PythonOT:master' into solvers
cedricvincentcuaz May 28, 2026
1188611
merge
cedricvincentcuaz May 28, 2026
ea84a70
Merge branch 'master' into solvers
rflamary Jun 1, 2026
0d5c4d2
Merge branch 'master' into solvers
rflamary Jun 3, 2026
6098b72
update tests with callable metrics
cedricvincentcuaz Jun 8, 2026
17f9bb6
merge
cedricvincentcuaz Jun 8, 2026
38973c8
merge
cedricvincentcuaz Jun 8, 2026
50582d2
merge
cedricvincentcuaz Jun 8, 2026
82d2ff8
explicit ground_bary requirements in lp/_barycenter_solvers.py
cedricvincentcuaz Jun 9, 2026
33a37a9
Merge branch 'master' into solvers
rflamary Jun 16, 2026
515b24b
updates after first review
cedricvincentcuaz Jun 17, 2026
59110f0
Merge branch 'solvers' of https://github.com/cedricvincentcuaz/POT in…
cedricvincentcuaz Jun 17, 2026
b0f51d2
Merge branch 'master' into solvers
rflamary Jun 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,18 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
- Add "BSP-OT: Sparse transport plans between discrete measures in loglinear time" (PR #768)
- Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765)
- Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765)
- Add cost functions between linear operators following
[A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920),
implemented in `ot.sgot` (PR #792)
- Add `ot.utils.DataScaler` class for backend-aware joint normalization of input distributions, with sklearn-compatible `fit`/`transform`/`fit_transform` API and support for `'standard'`, `'minmax'`, and `'l2'` methods (PR #808)
- Add `ot.utils.apply_scaler` helper that dispatches preprocessing to a scaler object,
a callable, or a no-op (PR #808)
- Add optional `scaler` parameter to `sliced_wasserstein_distance` and `max_sliced_wasserstein_distance` (PR #808)
- Add a numerically stable log-domain solver for entropic partial Wasserstein, selectable via the new `method` parameter of `entropic_partial_wasserstein` (`method='sinkhorn_log'`) or directly through `entropic_partial_wasserstein_logscale` (Issue #723)
- Add cost functions between linear operators following [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920), implemented in `ot.sgot` (PR #792)
- Build wheels on ubuntu ARM to avoid QEMU emulation (PR #818)
- Wrapper for barycenter solvers with free support `ot.solvers.bary_free_support` (PR #730)
- Add new methods to compute the linear transport map and the related 2-Wasserstein distance betweeen high-dimensional (HD) Gaussian distributions as described in [88], implemented in `ot.gaussian.bures_wasserstein_mapping_hd` and `ot.gaussian.bures_wasserstein_distance_hd`, respectively. Two additional methods estimate the same quantities from the source and destination observed data and are implemented in `ot.gaussian.empirical_bures_wasserstein_mapping_hd` and `ot.gaussian.empirical_bures_wasserstein_distance_hd`, respectively (PR #814)


#### Closed issues

- Mitigate NaN regime of `entropic_partial_wasserstein` at small `reg` via a new log-domain solver, reachable with `entropic_partial_wasserstein(..., method='sinkhorn_log')` (Issue #723; the default `method='sinkhorn'` path is unchanged — callers opt into the log-domain variant)
Expand Down
132 changes: 132 additions & 0 deletions examples/barycenters/plot_solve_barycenter_variants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-
"""
======================================
Optimal Transport Barycenter solvers comparison
======================================

This example illustrates solutions returned for different variants of exact,
regularized and unbalanced OT barycenter problems with free support using our wrapper `ot.solve_bary_sample`.
"""

# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 2

# %%

import numpy as np
import matplotlib.pylab as pl
import ot
from ot.plot import plot2D_samples_mat

# %%
# 2D data example
# ---------------
#
# We first generate two sets of samples in 2D of 8 and 16
# points uniformly separated on circles. The weights of the samples are uniform.

# Problem size
n1, n2 = 8, 16
nbary = 12

# Generate random data
np.random.seed(0)

r1, r2 = 1, 3
x1 = r1 * np.array(
[(np.cos(2 * i * np.pi / n1), np.sin(2 * i * np.pi / n1)) for i in range(n1)]
)

x2 = r2 * np.array(
[(np.cos(2 * i * np.pi / n2), np.sin(2 * i * np.pi / n2)) for i in range(n2)]
)

style = {"markeredgecolor": "k"}

pl.figure(1, (4, 4))
pl.plot(x1[:, 0], x1[:, 1], "ob", **style)
pl.plot(x2[:, 0], x2[:, 1], "or", **style)
pl.title("Source distributions")
pl.show()


# %%
# Set up parameters for barycenter solvers and solve
# ---------------------------------------

lst_regs = [
"No Reg.",
"Entropic",
] # support e.g ["No Reg.", "Entropic", "L2", "Group Lasso + L2"]
lst_unbalanced = [
"Balanced",
"Unbalanced KL",
] # ["Balanced", "Unb. KL", "Unb. L2", "Unb L1 (partial)"]

lst_solvers = [ # name, param for ot.solve function
# balanced OT
("Exact OT", dict()),
("Entropic Reg. OT", dict(reg=1.0)),
# unbalanced OT KL
("Unbalanced KL No Reg.", dict(unbalanced=0.05)),
(
"Unbalanced KL with KL Reg.",
dict(reg=0.1, unbalanced=0.05, unbalanced_type="kl", reg_type="kl"),
),
]

lst_res = []
for name, param in lst_solvers:
print(f"-- name = {name} / param = {param}")
res = ot.solve_bary_sample(X_a_list=[x1, x2], n=nbary, **param)
lst_res.append(res)
list_P = [res.list_res[k].plan for k in range(2)]
print("X:", res.X)
print("loss:", res.value)
print("loss:", res.log)
print(
"marginals OT 1:",
res.list_res[0].plan.sum(axis=1),
res.list_res[0].plan.sum(axis=0),
)
print(
"marginals OT 2:",
res.list_res[1].plan.sum(axis=1),
res.list_res[1].plan.sum(axis=0),
)

##############################################################################
# Plot distributions and plans
# ----------

pl.figure(2, figsize=(16, 16))

style.update({"markersize": 20})

for i, bname in enumerate(lst_unbalanced):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should do it in two pats: first sharp vs entropic baryenter (easier to compare) and then a second section where you vary the marginal violation weight so that we see the transformation from exact bary to points on the left and right (as we can see there)

for j, rname in enumerate(lst_regs):
pl.subplot(len(lst_unbalanced), len(lst_regs), i * len(lst_regs) + j + 1)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe plot first the barycenter and belwo the barycenter + the OT plan? this is a bit dens and hard to parse visually.


X = lst_res[i * len(lst_regs) + j].X
list_P = [lst_res[i * len(lst_regs) + j].list_res[k].plan for k in range(2)]
loss = lst_res[i * len(lst_regs) + j].value

plot2D_samples_mat(x1, X, list_P[0])
plot2D_samples_mat(x2, X, list_P[1])

if i == 0 and j == 0: # add labels
pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source distribution 1", **style)
pl.plot(x2[:, 0], x2[:, 1], "or", label="Source distribution 2", **style)
pl.plot(X[:, 0], X[:, 1], "og", label="Barycenter distribution", **style)
pl.legend(loc="best")
else:
pl.plot(x1[:, 0], x1[:, 1], "ob", **style)
pl.plot(x2[:, 0], x2[:, 1], "or", **style)
pl.plot(X[:, 0], X[:, 1], "og", **style)

if i == 0:
pl.title(rname)
if j == 0:
pl.ylabel(bname, fontsize=14)
53 changes: 53 additions & 0 deletions examples/plot_quickstart_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,59 @@ def df(G):
plot_plan(P_fgw, "Fused GW plan", axis=False)
pl.show()

# sphinx_gallery_end_ignore

# %%
#
# Solving barycenter problems
# -------------------------------
# Solve Optimal transport barycenter problem with free support between several input distributions.
# ~~~~~~~~~~~~~~~~~~~~
#
# The :func:`ot.solve_bary_sample` function can be used to solve the Optimal Transport barycenter problem
# between multiple sets of samples while optimizing the support of the barycenter and letting fixed their probability weights.
# The function takes as its first argument the list of samples in each input distribution,
# and as second argument the number of samples to learn in the barycenter. By default, the probability weights in each distribution and the barycentric weights are uniform but they can be customized by the user.
#
# The function returns an :class:`ot.utils.OTBaryResult` object that contains in part the barycenter samples and the OT plans between the barycenter and each input distribution.
#
# In the following, we illustrate the use of this function with the same 2D data as above considered as input distributions and compute their barycenter while using exact OT.
# Notice that most of the arguments of the :func:`ot.solve_bary_sample` function are similar to those of the :func:`ot.solve_sample` function and that the same regularization and unbalanced parameters can be used to solve regularized and unbalanced barycenter problems.

# Solve the OT barycenter problem (exact OT without any regularization)
sol = ot.solve_bary_sample([x1, x2], n=35)

# get the barycenter support
X = sol.X

# get the OT plans between the barycenter and each input distribution
list_P = [sol.list_res[i].plan for i in range(2)]

# get the barycenterOT loss
loss = sol.value

print(f"Barycenter OT loss = {loss:1.3f}")

# sphinx_gallery_start_ignore
pl.figure(1, (8, 8))
plot2D_samples_mat(x1, X, list_P[0])
plot2D_samples_mat(x2, X, list_P[1])

pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source distribution 1", **style)
pl.plot(x2[:, 0], x2[:, 1], "or", label="Source distribution 2", **style)
pl.plot(X[:, 0], X[:, 1], "og", label="Barycenter distribution", **style)

pl.title(
"Barycenter samples and OT plans \n total loss= %s = 0.5 * %s + 0.5 * %s"
% (
np.round(loss, 3),
np.round(sol.list_res[0].value, 3),
np.round(sol.list_res[1].value, 3),
)
)
pl.legend(loc="best")
pl.show()

# sphinx_gallery_end_ignore
# %%
#
3 changes: 2 additions & 1 deletion ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
)
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
from .solvers import solve, solve_gromov, solve_sample
from .solvers import solve, solve_gromov, solve_sample, solve_bary_sample
from .lowrank import lowrank_sinkhorn

from .batch import solve_batch, solve_sample_batch, solve_gromov_batch, dist_batch
Expand Down Expand Up @@ -145,6 +145,7 @@
"solve",
"solve_gromov",
"solve_sample",
"solve_bary_sample",
"smooth",
"stochastic",
"unbalanced",
Expand Down
7 changes: 5 additions & 2 deletions ot/lp/_barycenter_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ def free_support_barycenter_generic_costs(
of shape :math:`(n\times d_K)`, computing the ground barycenters
(broadcasted over n). If not provided, done with Adam on PyTorch
(requires PyTorch backend), inefficiently using the cost functions in
`cost_list`.
`cost_list`. This function must be provided if `method="true_fixed_point"` is used.
a : array-like, optional
Array of shape (n,) representing weights of the barycenter
measure.Defaults to uniform.
Expand Down Expand Up @@ -673,8 +673,11 @@ def free_support_barycenter_generic_costs(

if ground_bary is None:
auto_ground_bary = True
assert (
method == "L2_barycentric_proj"
), "ground_bary must be provided if method is 'true_fixed_point'"
assert str(nx) == "torch", (
f"Backend {str(nx)} is not compatible with ground_bary=None, it"
f"Backend {str(nx)} is not compatible with ground_bary=None, it "
"must be provided if not using PyTorch backend"
)
try:
Expand Down
28 changes: 28 additions & 0 deletions ot/solvers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
"""
General OT solvers with unified API
"""

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
#
# License: MIT License

# All submodules and packages
from ._linear import solve, solve_sample

from ._gromov import (
solve_gromov,
)

from ._bary import (
solve_bary_sample,
)


__all__ = [
"solve",
"solve_sample",
"solve_gromov",
"solve_bary_sample",
]
Loading
Loading