diff --git a/RELEASES.md b/RELEASES.md index c7024ad04..12095a12a 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -22,16 +22,18 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver - Add "BSP-OT: Sparse transport plans between discrete measures in loglinear time" (PR #768) - Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765) - Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765) +- Add cost functions between linear operators following + [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920), + implemented in `ot.sgot` (PR #792) - Add `ot.utils.DataScaler` class for backend-aware joint normalization of input distributions, with sklearn-compatible `fit`/`transform`/`fit_transform` API and support for `'standard'`, `'minmax'`, and `'l2'` methods (PR #808) - Add `ot.utils.apply_scaler` helper that dispatches preprocessing to a scaler object, a callable, or a no-op (PR #808) - Add optional `scaler` parameter to `sliced_wasserstein_distance` and `max_sliced_wasserstein_distance` (PR #808) - Add a numerically stable log-domain solver for entropic partial Wasserstein, selectable via the new `method` parameter of `entropic_partial_wasserstein` (`method='sinkhorn_log'`) or directly through `entropic_partial_wasserstein_logscale` (Issue #723) -- Add cost functions between linear operators following [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920), implemented in `ot.sgot` (PR #792) - Build wheels on ubuntu ARM to avoid QEMU emulation (PR #818) +- Add `ot.solve_bary_sample`, a unified wrapper for free-support barycenter solvers (PR #730) - Add new methods to compute the linear transport map and the related 2-Wasserstein distance betweeen high-dimensional (HD) Gaussian distributions as described in [88], implemented in `ot.gaussian.bures_wasserstein_mapping_hd` and 
# -*- coding: utf-8 -*-
"""
===============================================
Optimal Transport Barycenter solvers comparison
===============================================

This example illustrates solutions returned for different variants of exact,
regularized and unbalanced OT barycenter problems with free support using our
wrapper `ot.solve_bary_sample`.
"""

# Author: Cédric Vincent-Cuaz
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 2

# %%

import numpy as np
import matplotlib.pylab as pl
import ot
from ot.plot import plot2D_samples_mat

# %%
# 2D data example
# ---------------
#
# We first generate two sets of samples in 2D of 8 and 16
# points uniformly separated on circles. The weights of the samples are uniform.

# Problem size
n1, n2 = 8, 16
nbary = 12

# Fix NumPy's global seed (the two circles built below are deterministic)
np.random.seed(0)

r1, r2 = 1, 3
x1 = r1 * np.array(
    [(np.cos(2 * i * np.pi / n1), np.sin(2 * i * np.pi / n1)) for i in range(n1)]
)

x2 = r2 * np.array(
    [(np.cos(2 * i * np.pi / n2), np.sin(2 * i * np.pi / n2)) for i in range(n2)]
)

style = {"markeredgecolor": "k"}

pl.figure(1, (4, 4))
pl.plot(x1[:, 0], x1[:, 1], "ob", **style)
pl.plot(x2[:, 0], x2[:, 1], "or", **style)
pl.title("Source distributions")
pl.show()


# %%
# Set up parameters for barycenter solvers and solve
# --------------------------------------------------

# Column titles of the final figure (one per regularization)
lst_regs = [
    "No Reg.",
    "Entropic",
]  # support e.g ["No Reg.", "Entropic", "L2", "Group Lasso + L2"]
# Row labels of the final figure (one per marginal constraint type)
lst_unbalanced = [
    "Balanced",
    "Unbalanced KL",
]  # ["Balanced", "Unb. KL", "Unb. L2", "Unb L1 (partial)"]

# The solvers are listed row by row (balanced first, then unbalanced), in the
# same order as the subplots: index = row * len(lst_regs) + column.
lst_solvers = [  # name, param for ot.solve_bary_sample function
    # balanced OT
    ("Exact OT", dict()),
    ("Entropic Reg. OT", dict(reg=1.0)),
    # unbalanced OT KL
    ("Unbalanced KL No Reg.", dict(unbalanced=0.05)),
    (
        "Unbalanced KL with KL Reg.",
        dict(reg=0.1, unbalanced=0.05, unbalanced_type="kl", reg_type="kl"),
    ),
]

lst_res = []
for name, param in lst_solvers:
    print(f"-- name = {name} / param = {param}")
    res = ot.solve_bary_sample(X_a_list=[x1, x2], n=nbary, **param)
    lst_res.append(res)
    print("X:", res.X)
    print("loss:", res.value)
    print("log:", res.log)
    print(
        "marginals OT 1:",
        res.list_res[0].plan.sum(axis=1),
        res.list_res[0].plan.sum(axis=0),
    )
    print(
        "marginals OT 2:",
        res.list_res[1].plan.sum(axis=1),
        res.list_res[1].plan.sum(axis=0),
    )

##############################################################################
# Plot distributions and plans
# ----------------------------

pl.figure(2, figsize=(16, 16))

style.update({"markersize": 20})

n_rows, n_cols = len(lst_unbalanced), len(lst_regs)

for i, bname in enumerate(lst_unbalanced):
    for j, rname in enumerate(lst_regs):
        idx = i * n_cols + j
        pl.subplot(n_rows, n_cols, idx + 1)

        X = lst_res[idx].X
        list_P = [lst_res[idx].list_res[k].plan for k in range(2)]

        plot2D_samples_mat(x1, X, list_P[0])
        plot2D_samples_mat(x2, X, list_P[1])

        if i == 0 and j == 0:  # add labels (legend only on the first subplot)
            pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source distribution 1", **style)
            pl.plot(x2[:, 0], x2[:, 1], "or", label="Source distribution 2", **style)
            pl.plot(X[:, 0], X[:, 1], "og", label="Barycenter distribution", **style)
            pl.legend(loc="best")
        else:
            pl.plot(x1[:, 0], x1[:, 1], "ob", **style)
            pl.plot(x2[:, 0], x2[:, 1], "or", **style)
            pl.plot(X[:, 0], X[:, 1], "og", **style)

        if i == 0:
            pl.title(rname)
        if j == 0:
            pl.ylabel(bname, fontsize=14)

pl.show()
+# +# The function returns an :class:`ot.utils.BaryResult` object that contains, among others, the barycenter samples and the OT plans between the barycenter and each input distribution. +# +# In the following, we illustrate the use of this function with the same 2D data as above considered as input distributions and compute their barycenter while using exact OT. +# Notice that most of the arguments of the :func:`ot.solve_bary_sample` function are similar to those of the :func:`ot.solve_sample` function and that the same regularization and unbalanced parameters can be used to solve regularized and unbalanced barycenter problems. + +# Solve the OT barycenter problem (exact OT without any regularization) +sol = ot.solve_bary_sample([x1, x2], n=35) + +# get the barycenter support +X = sol.X + +# get the OT plans between the barycenter and each input distribution +list_P = [sol.list_res[i].plan for i in range(2)] + +# get the barycenter OT loss +loss = sol.value + +print(f"Barycenter OT loss = {loss:1.3f}") + +# sphinx_gallery_start_ignore +pl.figure(1, (8, 8)) +plot2D_samples_mat(x1, X, list_P[0]) +plot2D_samples_mat(x2, X, list_P[1]) + +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source distribution 1", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Source distribution 2", **style) +pl.plot(X[:, 0], X[:, 1], "og", label="Barycenter distribution", **style) + +pl.title( + "Barycenter samples and OT plans \n total loss= %s = 0.5 * %s + 0.5 * %s" + % ( + np.round(loss, 3), + np.round(sol.list_res[0].value, 3), + np.round(sol.list_res[1].value, 3), + ) +) +pl.legend(loc="best") +pl.show() + # sphinx_gallery_end_ignore # %% # diff --git a/ot/__init__.py b/ot/__init__.py index 378209472..deb3a087d 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -82,7 +82,7 @@ ) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov, solve_sample +from .solvers import solve, solve_gromov, solve_sample, 
solve_bary_sample from .lowrank import lowrank_sinkhorn from .batch import solve_batch, solve_sample_batch, solve_gromov_batch, dist_batch @@ -145,6 +145,7 @@ "solve", "solve_gromov", "solve_sample", + "solve_bary_sample", "smooth", "stochastic", "unbalanced", diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index 46fcd1a48..b938cb799 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -590,7 +590,7 @@ def free_support_barycenter_generic_costs( of shape :math:`(n\times d_K)`, computing the ground barycenters (broadcasted over n). If not provided, done with Adam on PyTorch (requires PyTorch backend), inefficiently using the cost functions in - `cost_list`. + `cost_list`. This function must be provided if `method="true_fixed_point"` is used. a : array-like, optional Array of shape (n,) representing weights of the barycenter measure.Defaults to uniform. @@ -673,8 +673,11 @@ def free_support_barycenter_generic_costs( if ground_bary is None: auto_ground_bary = True + assert ( + method == "L2_barycentric_proj" + ), "ground_bary must be provided if method is 'true_fixed_point'" assert str(nx) == "torch", ( - f"Backend {str(nx)} is not compatible with ground_bary=None, it" + f"Backend {str(nx)} is not compatible with ground_bary=None, it " "must be provided if not using PyTorch backend" ) try: diff --git a/ot/solvers/__init__.py b/ot/solvers/__init__.py new file mode 100644 index 000000000..d6c3ea835 --- /dev/null +++ b/ot/solvers/__init__.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +""" +General OT solvers with unified API +""" + +# Author: Remi Flamary +# Cédric Vincent-Cuaz +# +# License: MIT License + +# All submodules and packages +from ._linear import solve, solve_sample + +from ._gromov import ( + solve_gromov, +) + +from ._bary import ( + solve_bary_sample, +) + + +__all__ = [ + "solve", + "solve_sample", + "solve_gromov", + "solve_bary_sample", +] diff --git a/ot/solvers/_bary.py b/ot/solvers/_bary.py new file 
def _bcd_converged(new_criterion, prev_criterion, tol):
    """Return True when the relative change of the criterion is below `tol`.

    A zero previous criterion (e.g. the barycenter did not move at the
    previous iteration) is handled explicitly: dividing by it would give
    nan/inf and the test could never succeed.
    """
    change = abs(new_criterion - prev_criterion)
    scale = abs(prev_criterion)
    if scale == 0:
        return bool(change == 0)
    return bool(change / scale < tol)


def _init_bary_support(X_a_list, n, random_state, nx):
    """Draw `n` barycenter samples from a Gaussian fitted on the sources.

    The location (resp. scale) is, for each dimension, the average over the
    source distributions of their per-dimension mean (resp. standard
    deviation). Statistics are converted to NumPy before sampling so that the
    NumPy random generator never receives backend tensors.
    """
    rng = np.random.RandomState(random_state)
    # stack -> shape (N, dim); concatenating the 1-D vectors would mix
    # dimensions and collapse the statistics to a single scalar.
    means = np.stack([nx.to_numpy(nx.mean(X_a, axis=0)) for X_a in X_a_list])
    stds = np.stack([nx.to_numpy(nx.std(X_a, axis=0)) for X_a in X_a_list])
    X_b = rng.normal(
        loc=means.mean(axis=0),
        scale=stds.mean(axis=0),
        size=(n, X_a_list[0].shape[1]),
    )
    return nx.from_numpy(X_b, type_as=X_a_list[0])


def _bary_sample_bcd(
    X_a_list,
    X_b_init,
    a_list,
    b_init,
    w,
    metric,
    inner_solver,
    update_masses,
    warmstart_plan,
    warmstart_potentials,
    stopping_criterion,
    max_iter_bary,
    tol_bary,
    verbose,
    log,
    nx,
):
    """Compute the barycenter using Block-Coordinate Descent (BCD).

    Each iteration solves one inner OT problem per source distribution for
    the current barycenter, then updates the barycenter samples (and its
    weights when `update_masses` is True) in closed form.

    Parameters
    ----------
    X_a_list : list of array-like, shape (n_samples_k, dim)
        List of samples in each source distribution
    X_b_init : array-like, shape (n_samples_b, dim),
        Initialization of the barycenter samples.
    a_list : list of array-like, shape (n_samples_k,)
        List of samples weights in each source distribution
    b_init : array-like, shape (n_samples_b,)
        Initialization of the barycenter weights.
    w : list of array-like, shape (N,)
        Samples barycentric weights
    metric : str
        Metric to use for the cost matrix, either "sqeuclidean" or "euclidean"
    inner_solver : callable with parameters (X_a, X_b, a, b, plan_init, potentials_init)
        Function to solve the inner OT problem with inputs: source and target samples (`X_a`, `X_b`),
        their respective masses (`a`, `b`), optional initial transport plan `plan_init`, optional initial
        dual potentials `potentials_init` used for instance with sinkhorn-like inner solvers.
    update_masses : bool
        Update the masses of the barycenter, depending on whether balanced or unbalanced OT is used.
    warmstart_plan : bool
        Use the previous plan as initialization for the inner solver.
        Set based on inner solver type in `ot.solve_bary_sample`.
        Takes precedence over `warmstart_potentials`.
    warmstart_potentials : bool
        Use the previous potentials as initialization for the inner solver.
        Set based on inner solver type in `ot.solve_bary_sample`.
    stopping_criterion : str
        Stopping criterion for the BCD algorithm. Can be "loss" or "bary".
    max_iter_bary : int
        Maximum number of iterations for the barycenter (at least 1)
    tol_bary : float
        Tolerance on the relative variation of the criterion between two iterations
    verbose : bool
        Print information in the solver
    log : bool
        Log the stopping criterion during the iterations
    nx : backend
        Backend to use for the computation. Must match the backend of the input arrays.

    Returns
    -------

    res : BaryResult()
        Result of the optimization problem. The information can be obtained as follows:

        - res.X : Barycenter samples
        - res.b : Barycenter weights
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plan
        - res.list_res: List of OTResult for each inner OT problem (one per source distribution)
        - res.log: log of the optimization process (if log=True)

        See :any:`BaryResult` for more information.

    Raises
    ------
    ValueError
        If `max_iter_bary` is lower than 1.
    NotImplementedError
        If `metric` has no closed-form barycenter update.

    """
    if max_iter_bary < 1:
        raise ValueError(
            "max_iter_bary must be at least 1, got {}".format(max_iter_bary)
        )
    if metric not in ["sqeuclidean", "euclidean"]:
        raise NotImplementedError('Not implemented metric="{}"'.format(metric))

    X_b = X_b_init
    b = b_init
    # zero-mass barycenter atoms get a unit factor instead of inf/nan
    inv_b = nx.nan_to_num(1.0 / b, nan=1.0, posinf=1.0, neginf=1.0)

    n_dists = len(X_a_list)
    log_ = {"stopping_criterion": []} if log else None

    list_res = None
    prev_criterion = None  # no previous value at iteration 0

    # Compute the barycenter using BCD
    for it in range(max_iter_bary):
        # Solve the inner OT problem for each source distribution
        # (no pre-defined warmstart is used at iteration 0)
        new_res = []
        for k in range(n_dists):
            plan_init, potentials_init = None, None
            if list_res is not None:
                if warmstart_plan:
                    plan_init = list_res[k].plan
                elif warmstart_potentials:
                    potentials_init = list_res[k].potentials
            new_res.append(
                inner_solver(
                    X_a_list[k], X_b, a_list[k], b, plan_init, potentials_init
                )
            )
        list_res = new_res

        # Update the estimated barycenter weights in unbalanced cases
        if update_masses:
            b = sum(w[k] * list_res[k].plan.sum(axis=0) for k in range(n_dists))
            inv_b = nx.nan_to_num(1.0 / b, nan=1.0, posinf=1.0, neginf=1.0)

        # Update the barycenter samples (weighted barycentric projections)
        X_b_new = (
            sum(w[k] * list_res[k].plan.T @ X_a_list[k] for k in range(n_dists))
            * inv_b[:, None]
        )

        # compute criterion
        if stopping_criterion == "loss":
            new_criterion = sum(w[k] * list_res[k].value for k in range(n_dists))
        else:  # stopping_criterion = "bary"
            new_criterion = nx.sum((X_b_new - X_b) ** 2)

        if verbose:
            print(
                f"BCD iteration {it}: criterion {stopping_criterion} = {new_criterion:.4f}"
            )

        if log:
            log_["stopping_criterion"].append(new_criterion)

        # Check convergence; on convergence X_b is kept so that the returned
        # support is the one the plans in list_res were computed for.
        if prev_criterion is not None and _bcd_converged(
            new_criterion, prev_criterion, tol_bary
        ):
            if verbose:
                print(f"BCD converged in {it} iterations")
            break

        X_b = X_b_new
        prev_criterion = new_criterion

    # compute loss values
    value_linear = sum(w[k] * list_res[k].value_linear for k in range(n_dists))
    if stopping_criterion == "loss":
        value = new_criterion
    else:
        value = sum(w[k] * list_res[k].value for k in range(n_dists))

    return BaryResult(
        X=X_b,
        b=b,
        value=value,
        value_linear=value_linear,
        log=log_,
        list_res=list_res,
        backend=nx,
    )


def solve_bary_sample(
    X_a_list,
    n,
    a_list=None,
    w=None,
    X_b_init=None,
    b_init=None,
    metric="sqeuclidean",
    reg=None,
    c=None,
    reg_type="KL",
    unbalanced=None,
    unbalanced_type="KL",
    lazy=False,
    method=None,
    auto_bary_method="L2_barycentric_proj",
    warmstart=False,
    stopping_criterion="loss",
    max_iter_bary=1000,
    tol_bary=1e-5,
    random_state=0,
    verbose=False,
    **kwargs,
):
    r"""Solve the discrete OT barycenter problem over source distributions optimizing the barycenter support using Block-Coordinate Descent.

    The function solves the following general OT barycenter problem

    .. math::
        \min_{\mathbf{X} \in \mathbb{R}^{n \times d}} \min_{\{ \mathbf{T}^{(k)} \}_k \in \mathbb{R}_+^{n_k \times n}} \quad \sum_k w_k \{ \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda_r R(\mathbf{T}^{(k)}) +
        \lambda_u U(\mathbf{T}^{(k)}\mathbf{1},\mathbf{a}^{(k)}) +
        \lambda_u U(\mathbf{T}^{(k)T}\mathbf{1},\mathbf{b}) \}

    where the cost matrices :math:`\mathbf{M}^{(k)}` from each input distribution :math:`(\mathbf{X}^{(k)}, \mathbf{a}^{(k)})`
    to the barycenter domain are computed as :math:`M^{(k)}_{i,j} = d(x^{(k)}_i,x_j)` where
    :math:`d` is a metric (by default the squared Euclidean distance). For common metrics the barycenter is computed in closed-form.
    For balanced OT, the `metric` parameter can also be any callable function, or list of functions, that computes the distance from an input to the barycenter.
    In which case, the barycenter is updated by gradient descent using the provided metric(s) and the optimal transport plan(s) at each iteration.
    The barycenter probability weights are fixed to :math:`\mathbf{b}`.

    The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By
    default ``reg=None`` and there is no regularization. The unbalanced marginal
    penalization can be selected with `unbalanced` (:math:`\lambda_u`) and
    `unbalanced_type`. By default ``unbalanced=None`` and the function
    solves the exact optimal transport problem (respecting the marginals).

    Parameters
    ----------
    X_a_list : list of array-like, shape (n_samples_k, dim)
        List of N samples in each source distribution
    n : int
        number of samples in the barycenter domain
    a_list : list of array-like, shape (n_samples_k,), optional
        List of samples weights in each source distribution (default is uniform)
    w : list of array-like, shape (N,), optional
        Samples barycentric weights (default is uniform)
    X_b_init : array-like, shape (n, dim), optional
        Initialization of the barycenter samples (default is gaussian random sampling)
    b_init : array-like, shape (n,), optional
        Initialization of the barycenter weights (default is uniform)
    metric : str, callable or list of callables optional
        Metric to use for the computation of the cost matrix, by default "sqeuclidean".
        It can be a list of callables (bary, source) of length N (number of source distributions) to use different metrics for each source distribution.
        In this case, the barycenter is updated by gradient descent using the provided metric(s) and the optimal transport plan(s) at each iteration.
        If only callable is provided the same cost function is used for all source distributions.
    reg : float, optional
        Regularization weight :math:`\lambda_r`, by default None (no reg., exact
        OT)
    c : array-like, shape (dim_a, dim_b), optional (default=None)
        Reference measure for the regularization.
        If None, then use :math:`\mathbf{c} = \mathbf{a}^{(k)} \mathbf{b}^T`.
        If :math:`\texttt{reg_type}=`'entropy', then :math:`\mathbf{c} = 1_{|a^{(k)}|} 1_{|b|}^T`.
    reg_type : str, optional
        Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL"
    unbalanced : float or indexable object of length 1 or 2
        Marginal relaxation term.
        If it is a scalar or an indexable object of length 1,
        then the same relaxation is applied to both marginal relaxations.
        The balanced OT can be recovered using :math:`unbalanced=float("inf")`.
        For semi-relaxed case, use either
        :math:`unbalanced=(float("inf"), scalar)` or
        :math:`unbalanced=(scalar, float("inf"))`.
        If unbalanced is an array,
        it must have the same backend as input arrays `(a, b, M)`.
    unbalanced_type : str, optional
        Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL"
    lazy : bool, optional
        Return :any:`OTResultlazy` object to reduce memory cost when True, by
        default False. Not implemented yet for barycenter problems.
    method : str, optional
        Method for solving the problem, this can be used to select the solver
        for unbalanced problems (see :any:`ot.solve`), or to select a specific
        large scale solver.
    auto_bary_method: str, optional
        For balanced OT with callable metric functions, the barycenter method to use in 'L2_barycentric_proj' (default) for Euclidean
        barycentric projection, or 'true_fixed_point' for iterates using the North West Corner multi-marginal gluing method.
        The latter requires `ground_bary` to be given in `kwargs`.
    warmstart : bool, optional
        Use the previous OT or potentials as initialization for the next inner solver iteration, by default False.
    stopping_criterion : str, optional
        Stopping criterion for the outer loop of the BCD solver, by default 'loss'.
        Either 'loss' to use the optimize objective or 'bary' for variations of the barycenter w.r.t the Frobenius norm.
    max_iter_bary : int, optional
        Maximum number of iteration for the outer loop of the BCD solver, by default 1000.
    tol_bary : float, optional
        Tolerance for solution precision of the barycenter problem, by default 1e-5.
    random_state : int, optional
        Random seed for the initialization of the barycenter samples, by default 0.
        Only used if `X_b_init` is None.
    verbose : bool, optional
        Print information in the solver, by default False
    kwargs : optional
        Additional parameters for the inner solver (see :any:`ot.solve_sample` and :any:`ot.lp.free_support_barycenter_generic_costs`)

    Returns
    -------

    res : BaryResult()
        Result of the optimization problem. The information can be obtained as follows:

        - res.X : Barycenter samples
        - res.b : Barycenter weights
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plan
        - res.list_res: List of OTResult for each inner OT problem (one per source distribution)
        - res.log: log of the optimization process

        See :any:`BaryResult` for more information.

    Raises
    ------
    NotImplementedError
        If `lazy=True`, if `method` is a lazy-tensor method, if a callable
        metric is combined with `reg` or `unbalanced`, or if a non-callable
        metric is neither "sqeuclidean" nor "euclidean".
    ValueError
        If `stopping_criterion` or `auto_bary_method` is unknown, if
        `X_b_init` does not have shape (n, dim), if a list of callable metrics
        does not have one entry per source distribution, or if `ground_bary`
        is missing for `auto_bary_method='true_fixed_point'`.

    Notes
    -----

    The following methods are available for solving barycenter problems with respect to these inner OT problems:

    - **Classical exact OT problem [1]** (default parameters) :

    .. math::
        \forall k, \quad \min_{\mathbf{T}^{(k)}} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F

        s.t. \ \mathbf{T}^{(k)} \mathbf{1} = \mathbf{a}^{(k)}

             \mathbf{T}^{(k)^T} \mathbf{1} = \mathbf{b}

             \mathbf{T}^{(k)} \geq 0, M^{(k)}_{i,j} = d(x^{(k)}_i,x_j)



    can be solved with the following code for various cost metrics between the source distributions and the barycenter:

    .. code-block:: python

        # for squared Euclidean cost, where closed-form solutions are used to update the barycenter
        res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, metric='sqeuclidean')

        # for uniform sample weights and barycentric weights,
        res = ot.solve_bary_sample([x1, x2], n, metric='sqeuclidean')

        # for other cost functions, where the barycenter is updated with gradient descent using Pytorch
        # refer to the documentation and examples for more details.

    - **Entropic regularized OT [2]** (when ``reg!=None``):

    .. math::
        \min_{\mathbf{T}^{(k)}} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda R(\mathbf{T}^{(k)})

        s.t. \ \mathbf{T}^{(k)} \mathbf{1} = \mathbf{a}^{(k)}

             \mathbf{T}^{(k)^T} \mathbf{1} = \mathbf{b}

             \mathbf{T}^{(k)} \geq 0, M^{(k)}_{i,j} = d(x^{(k)}_i,x_j)



    can be solved with the following code:

    .. code-block:: python

        # default is ``"KL"`` regularization (``reg_type="KL"``)
        res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0)

        # or for original Sinkhorn paper formulation [2]
        res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0, reg_type='entropy')


    - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``):

    .. math::
        \min_{\mathbf{T}^{(k)}} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda R(\mathbf{T}^{(k)})

        s.t. \ \mathbf{T}^{(k)} \mathbf{1} = \mathbf{a}^{(k)}

             \mathbf{T}^{(k)^T} \mathbf{1} = \mathbf{b}

             \mathbf{T}^{(k)} \geq 0, M^{(k)}_{i,j} = d(x^{(k)}_i,x_j)

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0, reg_type='L2')

    - **Unbalanced OT [41]** (when ``unbalanced!=None``):

    .. math::
        \min_{\mathbf{T}^{(k)}\geq 0} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda_u U(\mathbf{T}^{(k)}\mathbf{1},\mathbf{a}^{(k)}) + \lambda_u U(\mathbf{T}^{(k)^T}\mathbf{1},\mathbf{b})

    can be solved with the following code:

    .. code-block:: python

        # default is ``"KL"``
        res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, unbalanced=1.0)

        # quadratic unbalanced OT
        res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, unbalanced=1.0, unbalanced_type='L2')
        # TV = partial OT
        res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, unbalanced=1.0, unbalanced_type='TV')


    - **Regularized unbalanced regularized OT [34]** (when ``unbalanced!=None`` and ``reg!=None``):

    .. math::
        \min_{\mathbf{T}^{(k)} \geq 0} \quad \langle \mathbf{T}^{(k)}, \mathbf{M}^{(k)} \rangle_F + \lambda_r R(\mathbf{T}^{(k)}) + \lambda_u U(\mathbf{T}^{(k)}\mathbf{1},\mathbf{a}^{(k)}) + \lambda_u U(\mathbf{T}^{(k)^T}\mathbf{1},\mathbf{b})


    can be solved with the following code:

    .. code-block:: python

        # default is ``"KL"`` for both
        res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0, unbalanced=1.0)
        # quadratic unbalanced OT with KL regularization
        res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0, unbalanced=1.0, unbalanced_type='L2')
        # both quadratic
        res = ot.solve_bary_sample([x1, x2], n , [a1, a2], w, reg=1.0, reg_type='L2', unbalanced=1.0, unbalanced_type='L2')

    .. _references-solve_bary_sample:
    References
    ----------

    .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.

    .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.

    """

    if method is not None and method.lower() in lst_method_lazy:
        raise NotImplementedError(
            f"method {method} operating on lazy tensors is not implemented yet"
        )

    if stopping_criterion not in ["loss", "bary"]:
        raise ValueError(
            "stopping_criterion must be either 'loss' or 'bary', got {}".format(
                stopping_criterion
            )
        )

    if lazy:
        raise NotImplementedError("Barycenter solver with lazy=True not implemented")

    # default non lazy solver calls ot.solve_sample within _bary_sample_bcd
    n_dists = len(X_a_list)

    # Detect backend
    nx = get_backend(*X_a_list, X_b_init, b_init, w)

    # check sample weights
    if a_list is None:
        a_list = [
            nx.ones((X_a.shape[0],), type_as=X_a) / X_a.shape[0] for X_a in X_a_list
        ]

    # check samples barycentric weights
    if w is None:
        w = nx.ones((n_dists,), type_as=X_a_list[0]) / n_dists

    # check X_b_init
    if X_b_init is None:
        X_b_init = _init_bary_support(X_a_list, n, random_state, nx)
    elif (X_b_init.shape[0] != n) or (X_b_init.shape[1] != X_a_list[0].shape[1]):
        raise ValueError("X_b_init must have shape (n, dim)")

    # check b_init
    if b_init is None:
        b_init = nx.ones((n,), type_as=X_a_list[0]) / n

    metric_is_list = isinstance(metric, list) and all(callable(m) for m in metric)

    if callable(metric) or metric_is_list:
        if reg is not None or unbalanced is not None:
            raise NotImplementedError(
                "Custom callable metric only available for balanced OT (reg=None and unbalanced=None)"
            )
        if metric_is_list and len(metric) != n_dists:
            raise ValueError(
                "metric must contain one callable per source distribution, "
                "got {} callables for {} distributions".format(len(metric), n_dists)
            )
        if auto_bary_method not in ["L2_barycentric_proj", "true_fixed_point"]:
            raise ValueError(
                "auto_bary_method must be either 'L2_barycentric_proj' or "
                "'true_fixed_point', got {}".format(auto_bary_method)
            )

        # ground_bary is forwarded explicitly: leaving it in kwargs while also
        # passing the keyword would raise a duplicate-keyword TypeError.
        ground_bary = kwargs.pop("ground_bary", None)
        if auto_bary_method == "true_fixed_point" and ground_bary is None:
            raise ValueError(
                "ground_bary must be provided in kwargs for true_fixed_point method with callable metrics"
            )

        outputs = free_support_barycenter_generic_costs(
            X_a_list,
            a_list,
            X_b_init,
            metric,
            ground_bary=ground_bary,
            a=b_init,
            numItermax=max_iter_bary,
            method=auto_bary_method,
            stopThr=tol_bary,
            log=True,
            **kwargs,
        )
        if auto_bary_method == "L2_barycentric_proj":
            X_b, log_ = outputs
            b = b_init
        else:
            # the true fixed point method potentially modifies the masses
            # of the barycenter
            X_b, b, log_ = outputs

        # compute the pairwise transport plans and losses
        metric_list = metric if metric_is_list else [metric] * n_dists
        # The cost functions are called as cost(barycenter, source); the
        # result is transposed so that plans are indexed (source, barycenter)
        # like the ones returned by the BCD branch.
        list_res = [
            solve(
                M=metric_list[k](X_b, X_a_list[k]).T,
                a=a_list[k],
                b=b,
                reg=None,
                unbalanced=None,
            )
            for k in range(n_dists)
        ]

        value_linear = sum(w[k] * list_res[k].value_linear for k in range(n_dists))
        return BaryResult(
            X=X_b,
            b=b,
            value=value_linear,
            value_linear=value_linear,
            log=log_,
            list_res=list_res,
            backend=nx,
        )

    # check metric
    if metric not in ["sqeuclidean", "euclidean"]:
        raise NotImplementedError(
            'Not implemented BCD with closed-form on the barycenter samples with metric="{}"'.format(
                metric
            )
        )

    # select what is carried over between two outer iterations
    warmstart_plan = False
    warmstart_potentials = False
    if warmstart:
        kl_reg_and_kl_marginals = (
            reg is not None
            and not isinstance(reg_type, tuple)
            and reg_type.lower() in ["kl"]
            and unbalanced_type.lower() == "kl"
        )
        if kl_reg_and_kl_marginals:
            warmstart_potentials = True
        else:  # exact OT or any other regularized OT
            warmstart_plan = True

    def inner_solver(X_a, X_b, a, b, plan_init, potentials_init):
        return solve_sample(
            X_a=X_a,
            X_b=X_b,
            a=a,
            b=b,
            metric=metric,
            reg=reg,
            c=c,
            reg_type=reg_type,
            unbalanced=unbalanced,
            unbalanced_type=unbalanced_type,
            method=method,
            plan_init=plan_init,
            potentials_init=potentials_init,
            verbose=False,
            **kwargs,
        )

    # compute the barycenter using BCD; the barycenter masses are only
    # re-estimated for unbalanced problems
    update_masses = unbalanced is not None
    return _bary_sample_bcd(
        X_a_list,
        X_b_init,
        a_list,
        b_init,
        w,
        metric,
        inner_solver,
        update_masses,
        warmstart_plan,
        warmstart_potentials,
        stopping_criterion,
        max_iter_bary,
        tol_bary,
        verbose,
        True,  # log set to True by default
        nx,
    )
discrete (Fused) Gromov-Wasserstein and return :any:`OTResult` object + + The function solves the following optimization problem: + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + The regularization is selected with `reg` (:math:`\lambda_r`) and + `reg_type`. By default ``reg=None`` and there is no regularization. The + unbalanced marginal penalization can be selected with `unbalanced` + (:math:`\lambda_u`) and `unbalanced_type`. By default ``unbalanced=None`` + and the function solves the exact optimal transport problem (respecting the + marginals). + + Parameters + ---------- + Ca : array-like, shape (dim_a, dim_a) + Cost matrix in the source domain + Cb : array-like, shape (dim_b, dim_b) + Cost matrix in the target domain + M : array-like, shape (dim_a, dim_b), optional + Linear cost matrix for Fused Gromov-Wasserstein (default is None). + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + loss : str, optional + Type of loss function, either ``"L2"`` or ``"KL"``, by default ``"L2"`` + symmetric : bool, optional + Use symmetric version of the Gromov-Wasserstein problem, by default None + tests whether the matrices are symmetric or True/False to avoid the test. + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R`, by default "entropy" (only used when + ``reg!=None``) + alpha : float, optional + Weight the quadratic term (alpha*Gromov) and the linear term + ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. 
Not used for + Gromov problem (when M is not provided). By default ``alpha=None`` + corresponds to ``alpha=1`` for Gromov problem (``M==None``) and + ``alpha=0.5`` for Fused Gromov-Wasserstein problem (``M!=None``) + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT). Not implemented yet for "KL" unbalanced penalization + function :math:`U`. Corresponds to the total transport mass for partial OT. + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "semirelaxed", + "partial", by default "KL" but note that it is not implemented yet. + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + method : str, optional + Method for solving the problem when multiple algorithms are available, + default None for automatic selection. + max_iter : int, optional + Maximum number of iterations, by default None (default values in each + solvers) + plan_init : array-like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + tol : float, optional + Tolerance for solution precision, by default None (default values in + each solvers) + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + res : OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + - res.value_quad : Quadratic (GW) part of the OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + Notes + ----- + The following methods are available for solving the Gromov-Wasserstein + problem: + + - **Classical Gromov-Wasserstein (GW) problem [3]** (default parameters): + + .. 
math:: + \min_{\mathbf{T}\geq 0} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_gromov(Ca, Cb) # uniform weights + res = ot.solve_gromov(Ca, Cb, a=a, b=b) # given weights + res = ot.solve_gromov(Ca, Cb, loss='KL') # KL loss + + plan = res.plan # GW plan + value = res.value # GW value + + - **Fused Gromov-Wasserstein (FGW) problem [24]** (when ``M!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_gromov(Ca, Cb, M) # uniform weights, alpha=0.5 (default) + res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, alpha=0.1) # given weights and alpha + + plan = res.plan # FGW plan + loss_linear_term = res.value_linear # Wasserstein part of the loss + loss_quad_term = res.value_quad # Gromov part of the loss + loss = res.value # FGW value + + - **Regularized (Fused) Gromov-Wasserstein (GW) problem [12]** (when ``reg!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + \lambda_r R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. 
code-block:: python + + res = ot.solve_gromov(Ca, Cb, reg=1.0) # GW entropy regularization (default) + res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, reg=10, alpha=0.1) # FGW with entropy + + plan = res.plan # FGW plan + loss_linear_term = res.value_linear # Wasserstein part of the loss + loss_quad_term = res.value_quad # Gromov part of the loss + loss = res.value # FGW value (including regularization) + + - **Semi-relaxed (Fused) Gromov-Wasserstein (GW) [48]** (when ``unbalanced='semirelaxed'``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed') # semirelaxed GW + res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed', reg=1) # entropic semirelaxed GW + res = ot.solve_gromov(Ca, Cb, M, unbalanced='semirelaxed', alpha=0.1) # semirelaxed FGW + + plan = res.plan # FGW plan + right_marginal = res.marginal_b # right marginal of the plan + + - **Partial (Fused) Gromov-Wasserstein (GW) problem [29]** (when ``unbalanced='partial'``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} \leq \mathbf{a} + + \mathbf{T}^T \mathbf{1} \leq \mathbf{b} + + \mathbf{T} \geq 0 + + \mathbf{1}^T\mathbf{T}\mathbf{1} = m + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=0.8) # partial GW with m=0.8 + res = ot.solve_gromov(Ca, Cb, M, unbalanced_type='partial', unbalanced=0.8, alpha=0.5) # partial FGW with m=0.8 + + + .. _references-solve-gromov: + References + ---------- + + .. 
[3] Mémoli, F. (2011). Gromov–Wasserstein distances and the metric + approach to object matching. Foundations of computational mathematics, + 11(4), 417-487. + + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), + Gromov-Wasserstein averaging of kernel and distance matrices + International Conference on Machine Learning (ICML). + + .. [24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. + (2019). Optimal Transport for structured data with application on graphs + Proceedings of the 36th International Conference on Machine Learning + (ICML). + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, + Nicolas Courty (2022). Semi-relaxed Gromov-Wasserstein divergence and + applications on graphs. International Conference on Learning + Representations (ICLR), 2022. + + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). Partial Optimal Transport + with Applications on Positive-Unlabeled Learning, Advances in Neural + Information Processing Systems (NeurIPS), 2020. 
+ + """ + + # detect backend + nx = get_backend(Ca, Cb, M, a, b) + + # create uniform weights if not given + if a is None: + a = nx.ones(Ca.shape[0], type_as=Ca) / Ca.shape[0] + if b is None: + b = nx.ones(Cb.shape[1], type_as=Cb) / Cb.shape[1] + + # default values for solutions + potentials = None + value = None + value_linear = None + value_quad = None + plan = None + status = None + log = None + + loss_dict = {"l2": "square_loss", "kl": "kl_loss"} + + if loss.lower() not in loss_dict.keys(): + raise (NotImplementedError('Not implemented GW loss="{}"'.format(loss))) + loss_fun = loss_dict[loss.lower()] + + if reg is None or reg == 0: # exact OT + if unbalanced is None and unbalanced_type.lower() not in [ + "semirelaxed", + ]: # Exact balanced OT + if unbalanced_type.lower() in ["partial"]: + warnings.warn( + "Exact balanced OT is computed as `unbalanced=None` even though " + f"unbalanced_type = {unbalanced_type}.", + stacklevel=2, + ) + + if M is None or alpha == 1: # Gromov-Wasserstein problem + # default values for solver + if max_iter is None: + max_iter = 10000 + if tol is None: + tol = 1e-9 + + value, log = gromov_wasserstein2( + Ca, + Cb, + a, + b, + loss_fun=loss_fun, + log=True, + symmetric=symmetric, + max_iter=max_iter, + G0=plan_init, + tol_rel=tol, + tol_abs=tol, + verbose=verbose, + ) + + value_quad = value + if alpha == 1: # set to 0 for FGW with alpha=1 + value_linear = 0 + plan = log["T"] + potentials = (log["u"], log["v"]) + + elif alpha == 0: # Wasserstein problem + # default values for EMD solver + if max_iter is None: + max_iter = 1000000 + + value_linear, log = emd2( + a, + b, + M, + numItermax=max_iter, + log=True, + return_matrix=True, + numThreads=n_threads, + ) + + value = value_linear + potentials = (log["u"], log["v"]) + plan = log["G"] + status = log["warning"] if log["warning"] is not None else "Converged" + value_quad = 0 + + else: # Fused Gromov-Wasserstein problem + # default values for solver + if max_iter is None: + max_iter = 
10000 + if tol is None: + tol = 1e-9 + + value, log = fused_gromov_wasserstein2( + M, + Ca, + Cb, + a, + b, + loss_fun=loss_fun, + alpha=alpha, + log=True, + symmetric=symmetric, + max_iter=max_iter, + G0=plan_init, + tol_rel=tol, + tol_abs=tol, + verbose=verbose, + ) + + value_linear = log["lin_loss"] + value_quad = log["quad_loss"] + plan = log["T"] + potentials = (log["u"], log["v"]) + + elif unbalanced_type.lower() in ["semirelaxed"]: # Semi-relaxed OT + if M is None or alpha == 1: # Semi relaxed Gromov-Wasserstein problem + # default values for solver + if max_iter is None: + max_iter = 10000 + if tol is None: + tol = 1e-9 + + value, log = semirelaxed_gromov_wasserstein2( + Ca, + Cb, + a, + loss_fun=loss_fun, + log=True, + symmetric=symmetric, + max_iter=max_iter, + G0=plan_init, + tol_rel=tol, + tol_abs=tol, + verbose=verbose, + ) + + value_quad = value + if alpha == 1: # set to 0 for FGW with alpha=1 + value_linear = 0 + plan = log["T"] + # potentials = (log['u'], log['v']) TODO + + else: # Semi relaxed Fused Gromov-Wasserstein problem + # default values for solver + if max_iter is None: + max_iter = 10000 + if tol is None: + tol = 1e-9 + + value, log = semirelaxed_fused_gromov_wasserstein2( + M, + Ca, + Cb, + a, + loss_fun=loss_fun, + alpha=alpha, + log=True, + symmetric=symmetric, + max_iter=max_iter, + G0=plan_init, + tol_rel=tol, + tol_abs=tol, + verbose=verbose, + ) + + value_linear = log["lin_loss"] + value_quad = log["quad_loss"] + plan = log["T"] + # potentials = (log['u'], log['v']) TODO + + elif unbalanced_type.lower() in ["partial"]: # Partial OT + if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-7 + + value, log = partial_gromov_wasserstein2( + Ca, + Cb, + a, + b, + m=unbalanced, + 
loss_fun=loss_fun, + log=True, + numItermax=max_iter, + G0=plan_init, + tol=tol, + symmetric=symmetric, + verbose=verbose, + ) + + value_quad = value + plan = log["T"] + # potentials = (log['u'], log['v']) TODO + + else: # partial FGW + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-7 + + value, log = partial_fused_gromov_wasserstein2( + M, + Ca, + Cb, + a, + b, + m=unbalanced, + loss_fun=loss_fun, + alpha=alpha, + log=True, + numItermax=max_iter, + G0=plan_init, + tol=tol, + symmetric=symmetric, + verbose=verbose, + ) + + value_linear = log["lin_loss"] + value_quad = log["quad_loss"] + plan = log["T"] + # potentials = (log['u'], log['v']) TODO + + elif unbalanced_type.lower() in ["kl", "l2"]: # unbalanced exact OT + raise (NotImplementedError('Unbalanced_type="{}"'.format(unbalanced_type))) + + else: + raise ( + NotImplementedError( + 'Unknown unbalanced_type="{}"'.format(unbalanced_type) + ) + ) + + else: # regularized OT + if unbalanced is None and unbalanced_type.lower() not in [ + "semirelaxed", + ]: # Balanced regularized OT + if unbalanced_type.lower() in ["partial"]: + warnings.warn( + "Exact balanced OT is computed as `unbalanced=None` even though " + f"unbalanced_type = {unbalanced_type}.", + stacklevel=2, + ) + + if reg_type.lower() in ["entropy"] and ( + M is None or alpha == 1 + ): # Entropic Gromov-Wasserstein problem + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if method is None: + method = "PGD" + + value_quad, log = entropic_gromov_wasserstein2( + Ca, + Cb, + a, + b, + epsilon=reg, + loss_fun=loss_fun, + log=True, + symmetric=symmetric, + solver=method, + max_iter=max_iter, + G0=plan_init, + tol_rel=tol, + tol_abs=tol, + verbose=verbose, + ) + + plan = log["T"] + value_linear = 0 + value = value_quad + reg 
* nx.sum(plan * nx.log(plan + 1e-16)) + # potentials = (log['log_u'], log['log_v']) #TODO + + elif ( + reg_type.lower() in ["entropy"] and M is not None and alpha == 0 + ): # Entropic Wasserstein problem + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + plan, log = sinkhorn_log( + a, + b, + M, + reg=reg, + numItermax=max_iter, + stopThr=tol, + log=True, + verbose=verbose, + ) + + value_linear = nx.sum(M * plan) + value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16)) + potentials = (log["log_u"], log["log_v"]) + + elif ( + reg_type.lower() in ["entropy"] and M is not None + ): # Entropic Fused Gromov-Wasserstein problem + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if method is None: + method = "PGD" + + value_noreg, log = entropic_fused_gromov_wasserstein2( + M, + Ca, + Cb, + a, + b, + loss_fun=loss_fun, + alpha=alpha, + log=True, + symmetric=symmetric, + solver=method, + max_iter=max_iter, + G0=plan_init, + tol_rel=tol, + tol_abs=tol, + verbose=verbose, + ) + + value_linear = log["lin_loss"] + value_quad = log["quad_loss"] + plan = log["T"] + # potentials = (log['u'], log['v']) + value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) + + else: + raise ( + NotImplementedError( + 'Not implemented reg_type="{}"'.format(reg_type) + ) + ) + + elif unbalanced_type.lower() in ["semirelaxed"]: # Semi-relaxed OT + if reg_type.lower() in ["entropy"] and ( + M is None or alpha == 1 + ): # Entropic Semi-relaxed Gromov-Wasserstein problem + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + value_quad, log = entropic_semirelaxed_gromov_wasserstein2( + Ca, + Cb, + a, + epsilon=reg, + loss_fun=loss_fun, + log=True, + symmetric=symmetric, + max_iter=max_iter, + G0=plan_init, + tol=tol, + verbose=verbose, + ) + + plan = log["T"] + value_linear = 0 + value = value_quad + reg * nx.sum(plan * 
nx.log(plan + 1e-16)) + + else: # Entropic Semi-relaxed FGW problem + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2( + M, + Ca, + Cb, + a, + loss_fun=loss_fun, + alpha=alpha, + log=True, + symmetric=symmetric, + max_iter=max_iter, + G0=plan_init, + tol=tol, + verbose=verbose, + ) + + value_linear = log["lin_loss"] + value_quad = log["quad_loss"] + plan = log["T"] + value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) + + elif unbalanced_type.lower() in ["partial"]: # Partial OT + if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-7 + + value_noreg, log = entropic_partial_gromov_wasserstein2( + Ca, + Cb, + a, + b, + reg=reg, + loss_fun=loss_fun, + m=unbalanced, + log=True, + numItermax=max_iter, + G0=plan_init, + tol=tol, + symmetric=symmetric, + verbose=verbose, + ) + + value_quad = value_noreg + plan = log["T"] + # potentials = (log['u'], log['v']) TODO + value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) + else: # partial FGW + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-7 + + value_noreg, log = entropic_partial_fused_gromov_wasserstein2( + M, + Ca, + Cb, + a, + b, + reg=reg, + loss_fun=loss_fun, + alpha=alpha, + m=unbalanced, + log=True, + numItermax=max_iter, + G0=plan_init, + tol=tol, + symmetric=symmetric, + verbose=verbose, + ) + + value_linear = log["lin_loss"] + value_quad = log["quad_loss"] + plan = log["T"] + # potentials = (log['u'], log['v']) TODO + value = value_noreg 
+ reg * nx.sum(plan * nx.log(plan + 1e-16)) + + else: # unbalanced AND regularized OT + raise ( + NotImplementedError( + 'Not implemented reg_type="{}" and unbalanced_type="{}"'.format( + reg_type, unbalanced_type + ) + ) + ) + + res = OTResult( + potentials=potentials, + value=value, + value_linear=value_linear, + value_quad=value_quad, + plan=plan, + status=status, + backend=nx, + log=log, + ) + + return res diff --git a/ot/solvers.py b/ot/solvers/_linear.py similarity index 60% rename from ot/solvers.py rename to ot/solvers/_linear.py index 88cf5c7ab..d9df35aa4 100644 --- a/ot/solvers.py +++ b/ot/solvers/_linear.py @@ -4,40 +4,25 @@ """ # Author: Remi Flamary +# Cédric Vincent-Cuaz # # License: MIT License -from .utils import OTResult, dist -from .lp import emd2, emd2_lazy, wasserstein_1d -from .backend import get_backend -from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced -from .bregman import ( +from ..utils import OTResult, dist +from ..lp import emd2, emd2_lazy, wasserstein_1d +from ..backend import get_backend +from ..unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced +from ..bregman import ( sinkhorn_log, empirical_sinkhorn2, empirical_sinkhorn2_geomloss, empirical_sinkhorn_nystroem2, ) -from .smooth import smooth_ot_dual -from .gromov import ( - gromov_wasserstein2, - fused_gromov_wasserstein2, - entropic_gromov_wasserstein2, - entropic_fused_gromov_wasserstein2, - semirelaxed_gromov_wasserstein2, - semirelaxed_fused_gromov_wasserstein2, - entropic_semirelaxed_fused_gromov_wasserstein2, - entropic_semirelaxed_gromov_wasserstein2, - partial_gromov_wasserstein2, - partial_fused_gromov_wasserstein2, - entropic_partial_gromov_wasserstein2, - entropic_partial_fused_gromov_wasserstein2, -) -from .gaussian import empirical_bures_wasserstein_distance -from .factored import factored_optimal_transport -from .lowrank import lowrank_sinkhorn -from .optim import cg - -import warnings +from ..smooth import 
smooth_ot_dual +from ..gaussian import empirical_bures_wasserstein_distance +from ..factored import factored_optimal_transport +from ..lowrank import lowrank_sinkhorn +from ..optim import cg lst_method_lazy = [ @@ -604,753 +589,6 @@ def solve( return res -def solve_gromov( - Ca, - Cb, - M=None, - a=None, - b=None, - loss="L2", - symmetric=None, - alpha=0.5, - reg=None, - reg_type="entropy", - unbalanced=None, - unbalanced_type="KL", - n_threads=1, - method=None, - max_iter=None, - plan_init=None, - tol=None, - verbose=False, -): - r"""Solve the discrete (Fused) Gromov-Wasserstein and return :any:`OTResult` object - - The function solves the following optimization problem: - - .. math:: - \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) - - The regularization is selected with `reg` (:math:`\lambda_r`) and - `reg_type`. By default ``reg=None`` and there is no regularization. The - unbalanced marginal penalization can be selected with `unbalanced` - (:math:`\lambda_u`) and `unbalanced_type`. By default ``unbalanced=None`` - and the function solves the exact optimal transport problem (respecting the - marginals). - - Parameters - ---------- - Ca : array-like, shape (dim_a, dim_a) - Cost matrix in the source domain - Cb : array-like, shape (dim_b, dim_b) - Cost matrix in the target domain - M : array-like, shape (dim_a, dim_b), optional - Linear cost matrix for Fused Gromov-Wasserstein (default is None). 
- a : array-like, shape (dim_a,), optional - Samples weights in the source domain (default is uniform) - b : array-like, shape (dim_b,), optional - Samples weights in the source domain (default is uniform) - loss : str, optional - Type of loss function, either ``"L2"`` or ``"KL"``, by default ``"L2"`` - symmetric : bool, optional - Use symmetric version of the Gromov-Wasserstein problem, by default None - tests whether the matrices are symmetric or True/False to avoid the test. - reg : float, optional - Regularization weight :math:`\lambda_r`, by default None (no reg., exact - OT) - reg_type : str, optional - Type of regularization :math:`R`, by default "entropy" (only used when - ``reg!=None``) - alpha : float, optional - Weight the quadratic term (alpha*Gromov) and the linear term - ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. Not used for - Gromov problem (when M is not provided). By default ``alpha=None`` - corresponds to ``alpha=1`` for Gromov problem (``M==None``) and - ``alpha=0.5`` for Fused Gromov-Wasserstein problem (``M!=None``) - unbalanced : float, optional - Unbalanced penalization weight :math:`\lambda_u`, by default None - (balanced OT). Not implemented yet for "KL" unbalanced penalization - function :math:`U`. Corresponds to the total transport mass for partial OT. - unbalanced_type : str, optional - Type of unbalanced penalization function :math:`U` either "KL", "semirelaxed", - "partial", by default "KL" but note that it is not implemented yet. - n_threads : int, optional - Number of OMP threads for exact OT solver, by default 1 - method : str, optional - Method for solving the problem when multiple algorithms are available, - default None for automatic selection. 
- max_iter : int, optional - Maximum number of iterations, by default None (default values in each - solvers) - plan_init : array-like, shape (dim_a, dim_b), optional - Initialization of the OT plan for iterative methods, by default None - tol : float, optional - Tolerance for solution precision, by default None (default values in - each solvers) - verbose : bool, optional - Print information in the solver, by default False - - Returns - ------- - res : OTResult() - Result of the optimization problem. The information can be obtained as follows: - - - res.plan : OT plan :math:`\mathbf{T}` - - res.potentials : OT dual potentials - - res.value : Optimal value of the optimization problem - - res.value_linear : Linear OT loss with the optimal OT plan - - res.value_quad : Quadratic (GW) part of the OT loss with the optimal OT plan - - See :any:`OTResult` for more information. - - Notes - ----- - The following methods are available for solving the Gromov-Wasserstein - problem: - - - **Classical Gromov-Wasserstein (GW) problem [3]** (default parameters): - - .. math:: - \min_{\mathbf{T}\geq 0} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} - - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} - - \mathbf{T}^T \mathbf{1} = \mathbf{b} - - \mathbf{T} \geq 0 - - can be solved with the following code: - - .. code-block:: python - - res = ot.solve_gromov(Ca, Cb) # uniform weights - res = ot.solve_gromov(Ca, Cb, a=a, b=b) # given weights - res = ot.solve_gromov(Ca, Cb, loss='KL') # KL loss - - plan = res.plan # GW plan - value = res.value # GW value - - - **Fused Gromov-Wasserstein (FGW) problem [24]** (when ``M!=None``): - - .. math:: - \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} - - s.t. 
\ \mathbf{T} \mathbf{1} = \mathbf{a} - - \mathbf{T}^T \mathbf{1} = \mathbf{b} - - \mathbf{T} \geq 0 - - can be solved with the following code: - - .. code-block:: python - - res = ot.solve_gromov(Ca, Cb, M) # uniform weights, alpha=0.5 (default) - res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, alpha=0.1) # given weights and alpha - - plan = res.plan # FGW plan - loss_linear_term = res.value_linear # Wasserstein part of the loss - loss_quad_term = res.value_quad # Gromov part of the loss - loss = res.value # FGW value - - - **Regularized (Fused) Gromov-Wasserstein (GW) problem [12]** (when ``reg!=None``): - - .. math:: - \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + \lambda_r R(\mathbf{T}) - - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} - - \mathbf{T}^T \mathbf{1} = \mathbf{b} - - \mathbf{T} \geq 0 - - can be solved with the following code: - - .. code-block:: python - - res = ot.solve_gromov(Ca, Cb, reg=1.0) # GW entropy regularization (default) - res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, reg=10, alpha=0.1) # FGW with entropy - - plan = res.plan # FGW plan - loss_linear_term = res.value_linear # Wasserstein part of the loss - loss_quad_term = res.value_quad # Gromov part of the loss - loss = res.value # FGW value (including regularization) - - - **Semi-relaxed (Fused) Gromov-Wasserstein (GW) [48]** (when ``unbalanced='semirelaxed'``): - - .. math:: - \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} - - s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} - - \mathbf{T} \geq 0 - - can be solved with the following code: - - .. 
code-block:: python - - res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed') # semirelaxed GW - res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed', reg=1) # entropic semirelaxed GW - res = ot.solve_gromov(Ca, Cb, M, unbalanced='semirelaxed', alpha=0.1) # semirelaxed FGW - - plan = res.plan # FGW plan - right_marginal = res.marginal_b # right marginal of the plan - - - **Partial (Fused) Gromov-Wasserstein (GW) problem [29]** (when ``unbalanced='partial'``): - - .. math:: - \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} - - s.t. \ \mathbf{T} \mathbf{1} \leq \mathbf{a} - - \mathbf{T}^T \mathbf{1} \leq \mathbf{b} - - \mathbf{T} \geq 0 - - \mathbf{1}^T\mathbf{T}\mathbf{1} = m - - can be solved with the following code: - - .. code-block:: python - - res = ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=0.8) # partial GW with m=0.8 - res = ot.solve_gromov(Ca, Cb, M, unbalanced_type='partial', unbalanced=0.8, alpha=0.5) # partial FGW with m=0.8 - - - .. _references-solve-gromov: - References - ---------- - - .. [3] Mémoli, F. (2011). Gromov–Wasserstein distances and the metric - approach to object matching. Foundations of computational mathematics, - 11(4), 417-487. - - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), - Gromov-Wasserstein averaging of kernel and distance matrices - International Conference on Machine Learning (ICML). - - .. [24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. - (2019). Optimal Transport for structured data with application on graphs - Proceedings of the 36th International Conference on Machine Learning - (ICML). - - .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, - Nicolas Courty (2022). Semi-relaxed Gromov-Wasserstein divergence and - applications on graphs. International Conference on Learning - Representations (ICLR), 2022. 
- - .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). Partial Optimal Transport - with Applications on Positive-Unlabeled Learning, Advances in Neural - Information Processing Systems (NeurIPS), 2020. - - """ - - # detect backend - nx = get_backend(Ca, Cb, M, a, b) - - # create uniform weights if not given - if a is None: - a = nx.ones(Ca.shape[0], type_as=Ca) / Ca.shape[0] - if b is None: - b = nx.ones(Cb.shape[1], type_as=Cb) / Cb.shape[1] - - # default values for solutions - potentials = None - value = None - value_linear = None - value_quad = None - plan = None - status = None - log = None - - loss_dict = {"l2": "square_loss", "kl": "kl_loss"} - - if loss.lower() not in loss_dict.keys(): - raise (NotImplementedError('Not implemented GW loss="{}"'.format(loss))) - loss_fun = loss_dict[loss.lower()] - - if reg is None or reg == 0: # exact OT - if unbalanced is None and unbalanced_type.lower() not in [ - "semirelaxed", - ]: # Exact balanced OT - if unbalanced_type.lower() in ["partial"]: - warnings.warn( - "Exact balanced OT is computed as `unbalanced=None` even though " - f"unbalanced_type = {unbalanced_type}.", - stacklevel=2, - ) - - if M is None or alpha == 1: # Gromov-Wasserstein problem - # default values for solver - if max_iter is None: - max_iter = 10000 - if tol is None: - tol = 1e-9 - - value, log = gromov_wasserstein2( - Ca, - Cb, - a, - b, - loss_fun=loss_fun, - log=True, - symmetric=symmetric, - max_iter=max_iter, - G0=plan_init, - tol_rel=tol, - tol_abs=tol, - verbose=verbose, - ) - - value_quad = value - if alpha == 1: # set to 0 for FGW with alpha=1 - value_linear = 0 - plan = log["T"] - potentials = (log["u"], log["v"]) - - elif alpha == 0: # Wasserstein problem - # default values for EMD solver - if max_iter is None: - max_iter = 1000000 - - value_linear, log = emd2( - a, - b, - M, - numItermax=max_iter, - log=True, - return_matrix=True, - numThreads=n_threads, - ) - - value = value_linear - potentials = (log["u"], log["v"]) - plan = log["G"] - 
status = log["warning"] if log["warning"] is not None else "Converged" - value_quad = 0 - - else: # Fused Gromov-Wasserstein problem - # default values for solver - if max_iter is None: - max_iter = 10000 - if tol is None: - tol = 1e-9 - - value, log = fused_gromov_wasserstein2( - M, - Ca, - Cb, - a, - b, - loss_fun=loss_fun, - alpha=alpha, - log=True, - symmetric=symmetric, - max_iter=max_iter, - G0=plan_init, - tol_rel=tol, - tol_abs=tol, - verbose=verbose, - ) - - value_linear = log["lin_loss"] - value_quad = log["quad_loss"] - plan = log["T"] - potentials = (log["u"], log["v"]) - - elif unbalanced_type.lower() in ["semirelaxed"]: # Semi-relaxed OT - if M is None or alpha == 1: # Semi relaxed Gromov-Wasserstein problem - # default values for solver - if max_iter is None: - max_iter = 10000 - if tol is None: - tol = 1e-9 - - value, log = semirelaxed_gromov_wasserstein2( - Ca, - Cb, - a, - loss_fun=loss_fun, - log=True, - symmetric=symmetric, - max_iter=max_iter, - G0=plan_init, - tol_rel=tol, - tol_abs=tol, - verbose=verbose, - ) - - value_quad = value - if alpha == 1: # set to 0 for FGW with alpha=1 - value_linear = 0 - plan = log["T"] - # potentials = (log['u'], log['v']) TODO - - else: # Semi relaxed Fused Gromov-Wasserstein problem - # default values for solver - if max_iter is None: - max_iter = 10000 - if tol is None: - tol = 1e-9 - - value, log = semirelaxed_fused_gromov_wasserstein2( - M, - Ca, - Cb, - a, - loss_fun=loss_fun, - alpha=alpha, - log=True, - symmetric=symmetric, - max_iter=max_iter, - G0=plan_init, - tol_rel=tol, - tol_abs=tol, - verbose=verbose, - ) - - value_linear = log["lin_loss"] - value_quad = log["quad_loss"] - plan = log["T"] - # potentials = (log['u'], log['v']) TODO - - elif unbalanced_type.lower() in ["partial"]: # Partial OT - if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise ( - ValueError("Partial GW mass given in `unbalanced` is too large") - ) - - 
# default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-7 - - value, log = partial_gromov_wasserstein2( - Ca, - Cb, - a, - b, - m=unbalanced, - loss_fun=loss_fun, - log=True, - numItermax=max_iter, - G0=plan_init, - tol=tol, - symmetric=symmetric, - verbose=verbose, - ) - - value_quad = value - plan = log["T"] - # potentials = (log['u'], log['v']) TODO - - else: # partial FGW - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise ( - ValueError("Partial GW mass given in `unbalanced` is too large") - ) - # default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-7 - - value, log = partial_fused_gromov_wasserstein2( - M, - Ca, - Cb, - a, - b, - m=unbalanced, - loss_fun=loss_fun, - alpha=alpha, - log=True, - numItermax=max_iter, - G0=plan_init, - tol=tol, - symmetric=symmetric, - verbose=verbose, - ) - - value_linear = log["lin_loss"] - value_quad = log["quad_loss"] - plan = log["T"] - # potentials = (log['u'], log['v']) TODO - - elif unbalanced_type.lower() in ["kl", "l2"]: # unbalanced exact OT - raise (NotImplementedError('Unbalanced_type="{}"'.format(unbalanced_type))) - - else: - raise ( - NotImplementedError( - 'Unknown unbalanced_type="{}"'.format(unbalanced_type) - ) - ) - - else: # regularized OT - if unbalanced is None and unbalanced_type.lower() not in [ - "semirelaxed", - ]: # Balanced regularized OT - if unbalanced_type.lower() in ["partial"]: - warnings.warn( - "Exact balanced OT is computed as `unbalanced=None` even though " - f"unbalanced_type = {unbalanced_type}.", - stacklevel=2, - ) - - if reg_type.lower() in ["entropy"] and ( - M is None or alpha == 1 - ): # Entropic Gromov-Wasserstein problem - # default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - if method is None: - method = "PGD" - - value_quad, log = entropic_gromov_wasserstein2( - Ca, - Cb, - a, - b, - epsilon=reg, - loss_fun=loss_fun, - log=True, - 
symmetric=symmetric, - solver=method, - max_iter=max_iter, - G0=plan_init, - tol_rel=tol, - tol_abs=tol, - verbose=verbose, - ) - - plan = log["T"] - value_linear = 0 - value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16)) - # potentials = (log['log_u'], log['log_v']) #TODO - - elif ( - reg_type.lower() in ["entropy"] and M is not None and alpha == 0 - ): # Entropic Wasserstein problem - # default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - - plan, log = sinkhorn_log( - a, - b, - M, - reg=reg, - numItermax=max_iter, - stopThr=tol, - log=True, - verbose=verbose, - ) - - value_linear = nx.sum(M * plan) - value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16)) - potentials = (log["log_u"], log["log_v"]) - - elif ( - reg_type.lower() in ["entropy"] and M is not None - ): # Entropic Fused Gromov-Wasserstein problem - # default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - if method is None: - method = "PGD" - - value_noreg, log = entropic_fused_gromov_wasserstein2( - M, - Ca, - Cb, - a, - b, - loss_fun=loss_fun, - alpha=alpha, - log=True, - symmetric=symmetric, - solver=method, - max_iter=max_iter, - G0=plan_init, - tol_rel=tol, - tol_abs=tol, - verbose=verbose, - ) - - value_linear = log["lin_loss"] - value_quad = log["quad_loss"] - plan = log["T"] - # potentials = (log['u'], log['v']) - value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) - - else: - raise ( - NotImplementedError( - 'Not implemented reg_type="{}"'.format(reg_type) - ) - ) - - elif unbalanced_type.lower() in ["semirelaxed"]: # Semi-relaxed OT - if reg_type.lower() in ["entropy"] and ( - M is None or alpha == 1 - ): # Entropic Semi-relaxed Gromov-Wasserstein problem - # default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - - value_quad, log = entropic_semirelaxed_gromov_wasserstein2( - Ca, - Cb, - a, - epsilon=reg, - loss_fun=loss_fun, - 
log=True, - symmetric=symmetric, - max_iter=max_iter, - G0=plan_init, - tol=tol, - verbose=verbose, - ) - - plan = log["T"] - value_linear = 0 - value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16)) - - else: # Entropic Semi-relaxed FGW problem - # default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - - value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2( - M, - Ca, - Cb, - a, - loss_fun=loss_fun, - alpha=alpha, - log=True, - symmetric=symmetric, - max_iter=max_iter, - G0=plan_init, - tol=tol, - verbose=verbose, - ) - - value_linear = log["lin_loss"] - value_quad = log["quad_loss"] - plan = log["T"] - value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) - - elif unbalanced_type.lower() in ["partial"]: # Partial OT - if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise ( - ValueError("Partial GW mass given in `unbalanced` is too large") - ) - - # default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-7 - - value_noreg, log = entropic_partial_gromov_wasserstein2( - Ca, - Cb, - a, - b, - reg=reg, - loss_fun=loss_fun, - m=unbalanced, - log=True, - numItermax=max_iter, - G0=plan_init, - tol=tol, - symmetric=symmetric, - verbose=verbose, - ) - - value_quad = value_noreg - plan = log["T"] - # potentials = (log['u'], log['v']) TODO - value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) - else: # partial FGW - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise ( - ValueError("Partial GW mass given in `unbalanced` is too large") - ) - - # default values for solver - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-7 - - value_noreg, log = entropic_partial_fused_gromov_wasserstein2( - M, - Ca, - Cb, - a, - b, - reg=reg, - loss_fun=loss_fun, - alpha=alpha, - m=unbalanced, - log=True, - numItermax=max_iter, - G0=plan_init, - tol=tol, - 
symmetric=symmetric, - verbose=verbose, - ) - - value_linear = log["lin_loss"] - value_quad = log["quad_loss"] - plan = log["T"] - # potentials = (log['u'], log['v']) TODO - value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) - - else: # unbalanced AND regularized OT - raise ( - NotImplementedError( - 'Not implemented reg_type="{}" and unbalanced_type="{}"'.format( - reg_type, unbalanced_type - ) - ) - ) - - res = OTResult( - potentials=potentials, - value=value, - value_linear=value_linear, - value_quad=value_quad, - plan=plan, - status=status, - backend=nx, - log=log, - ) - - return res - - def solve_sample( X_a, X_b, @@ -1444,7 +682,7 @@ def solve_sample( plan_init : array-like, shape (dim_a, dim_b), optional Initialization of the OT plan for iterative methods, by default None rank : int, optional - Rank of the OT matrix for lazy solers (method='factored') or (method='nystroem'), by default 100 + Rank of the OT matrix for lazy solvers (method='factored') or (method='nystroem'), by default 100 scaling : float, optional Scaling factor for the epsilon scaling lazy solvers (method='geomloss'), by default 0.95 potentials_init : (array-like(dim_a,),array-like(dim_b,)), optional diff --git a/ot/utils.py b/ot/utils.py index eb90cb22d..f79d69809 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1422,6 +1422,184 @@ def citation(self): """ +class BaryResult: + """Base class for OT barycenter results. + + Parameters + ---------- + X : array-like, shape (`n`, `d`) + Barycenter features. + C: array-like, shape (`n`, `n`) + Barycenter structure for Gromov Wasserstein solutions. + b : array-like, shape (`n`,) + Barycenter weights. + value : float, array-like + Full transport cost, including possible regularization terms and + quadratic term for Gromov Wasserstein solutions. + value_linear : float, array-like + The linear part of the transport cost, i.e. the product between the + transport plan and the cost. 
+ value_quad : float, array-like + The quadratic part of the transport cost for Gromov-Wasserstein + solutions. + log : dict + Dictionary containing potential information about the solver. + list_res: list of OTResult + List of results for the individual OT matching with input distributions considered as + sources and the learned barycenter distribution as target. + status : int or str + Status of the solver. + + Attributes + ---------- + + X : array-like, shape (`n`, `d`) + Barycenter features. + C: array-like, shape (`n`, `n`) + Barycenter structure for Gromov Wasserstein solutions. + b : array-like, shape (`n`,) + Barycenter weights. + value : float, array-like + Full transport cost, including possible regularization terms and + quadratic term for Gromov Wasserstein solutions. + value_linear : float, array-like + The linear part of the transport cost, i.e. the product between the + transport plan and the cost. + value_quad : float, array-like + The quadratic part of the transport cost for Gromov-Wasserstein + solutions. + log : dict + Dictionary containing potential information about the solver. + list_res: list of OTResult + List of results for the individual OT matching. + status : int or str + Status of the solver. + backend : Backend + Backend used to compute the results. 
+ """ + + def __init__( + self, + X=None, + C=None, + b=None, + value=None, + value_linear=None, + value_quad=None, + log=None, + list_res=None, + status=None, + backend=None, + ): + self._X = X + self._C = C + self._b = b + self._value = value + self._value_linear = value_linear + self._value_quad = value_quad + self._log = log + self._list_res = list_res + self._status = status + self._backend = backend if backend is not None else NumpyBackend() + + def __repr__(self): + s = "BaryResult(" + if self._value is not None: + s += "value={},".format(self._value) + if self._value_linear is not None: + s += "value_linear={},".format(self._value_linear) + if self._X is not None: + s += "X={}(shape={}),".format(self._X.__class__.__name__, self._X.shape) + if self._C is not None: + s += "C={}(shape={}),".format(self._C.__class__.__name__, self._C.shape) + if self._b is not None: + s += "b={}(shape={}),".format(self._b.__class__.__name__, self._b.shape) + if s[-1] != "(": + s = s[:-1] + ")" + else: + s = s + ")" + return s + + # Barycenters -------------------------------- + + @property + def X(self): + """Barycenter features.""" + return self._X + + @property + def C(self): + """Barycenter structure for Gromov Wasserstein solutions.""" + return self._C + + @property + def b(self): + """Barycenter weights.""" + return self._b + + # Loss values -------------------------------- + + @property + def value(self): + """Full transport cost, including possible regularization terms and + quadratic term for Gromov Wasserstein solutions.""" + return self._value + + @property + def value_linear(self): + """The "minimal" transport cost, i.e. 
the product between the transport plan and the cost.""" + return self._value_linear + + @property + def value_quad(self): + """The quadratic part of the transport cost for Gromov-Wasserstein solutions.""" + return self._value_quad + + # List of OTResult objects ------------------------- + + @property + def list_res(self): + """List of results for the individual OT matching.""" + return self._list_res + + @property + def status(self): + """Optimization status of the solver.""" + return self._status + + @property + def log(self): + """Dictionary containing potential information about the solver.""" + return self._log + + # Miscellaneous -------------------------------- + + @property + def citation(self): + """Appropriate citation(s) for this result, in plain text and BibTex formats.""" + + # The string below refers to the POT library: + # successor methods may concatenate the relevant references + # to the original definitions, solvers and underlying numerical backends. + return """POT library: + + POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. + Website: https://pythonot.github.io/ + Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer; + + @article{flamary2021pot, + author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. 
Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, + title = {{POT}: {Python} {Optimal} {Transport}}, + journal = {Journal of Machine Learning Research}, + year = {2021}, + volume = {22}, + number = {78}, + pages = {1-8}, + url = {http://jmlr.org/papers/v22/20-451.html} + } + """ + + class LazyTensor(object): """A lazy tensor is a tensor that is not stored in memory. Instead, it is defined by a function that computes its values on the fly from slices. diff --git a/test/test_solvers.py b/test/test_solvers.py index 802aca631..62a9d3154 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -1,6 +1,7 @@ """Tests for ot solvers""" # Author: Remi Flamary +# Cédric Vincent-Cuaz # # License: MIT License @@ -12,6 +13,7 @@ import ot from ot.bregman import geomloss from ot.backend import torch +from ot.solvers._linear import lst_method_lazy lst_reg = [None, 1] @@ -61,6 +63,13 @@ }, # fail lazy for unbalanced and regularized ] +lst_parameters_solve_bary_sample_NotImplemented = [ + {"method": method} for method in lst_method_lazy +] + [ + {"lazy": True}, # fail lazy + {"metric": "cosine"}, # fail on invalid metric +] + # set readable ids for each param lst_method_params_solve_sample = [ pytest.param(param, id=str(param)) for param in lst_method_params_solve_sample @@ -69,6 +78,10 @@ pytest.param(param, id=str(param)) for param in lst_parameters_solve_sample_NotImplemented ] +lst_parameters_solve_bary_sample_NotImplemented = [ + pytest.param(param, id=str(param)) + for param in lst_parameters_solve_bary_sample_NotImplemented +] def assert_allclose_sol(sol1, sol2): @@ -800,3 +813,279 @@ def test_solve_sample_NotImplemented(nx, method_params): with pytest.raises(NotImplementedError): ot.solve_sample(xb, yb, ab, bb, **method_params) + + +def assert_allclose_bary_sol(sol1, sol2): + lst_attr = ["X", "b", "value", "value_linear", "log"] + + nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() + nx2 = sol2._backend if 
sol2._backend is not None else ot.backend.NumpyBackend() + + for attr in lst_attr: + if getattr(sol1, attr) is not None and getattr(sol2, attr) is not None: + try: + var1 = getattr(sol1, attr) + var2 = getattr(sol2, attr) + if isinstance(var1, dict): # only contains lists + for key in var1.keys(): + np.allclose( + np.array(var1[key]), + np.array(var2[key]), + equal_nan=True, + ) + else: + np.allclose( + nx1.to_numpy(getattr(sol1, attr)), + nx2.to_numpy(getattr(sol2, attr)), + equal_nan=True, + ) + except NotImplementedError: + pass + elif getattr(sol1, attr) is None and getattr(sol2, attr) is None: + return True + else: + return False + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +@pytest.mark.parametrize( + "reg,reg_type,unbalanced,unbalanced_type,warmstart", + itertools.product( + lst_reg, + lst_reg_type, + lst_unbalanced, + lst_unbalanced_type, + [True, False], + # lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type, warmstart + ), +) +def test_solve_bary_sample(nx, reg, reg_type, unbalanced, unbalanced_type, warmstart): + # test bary_sample when is_Lazy = False + rng = np.random.RandomState() + + K = 2 # number of distributions + ns = rng.randint(10, 20, K) # number of samples within each distribution + n = 5 # number of samples in the barycenter + + X_list = [rng.randn(ns_i, 2) for ns_i in ns] + # X_init = np.reshape(1.0 * np.randn(n, 2), (n, 1)) + + a_list = [ot.utils.unif(X.shape[0]) for X in X_list] + b = ot.utils.unif(n) + + w = ot.utils.unif(K) + + stopping_criterion = "loss" if rng.choice([True, False]) else "bary" + + # try: + if reg_type == "tuple": + + def f(G): + return np.sum(G**2) + + def df(G): + return 2 * G + + reg_type = (f, df) + # print('test reg_type:', reg_type[0](None), reg_type[1](None)) + # solve default None weights + sol0 = ot.solve_bary_sample( + X_list, + n, + w=None, + metric="sqeuclidean", + reg=reg, + reg_type=reg_type, + 
unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + warmstart=warmstart, + max_iter_bary=2, + tol_bary=1e-3, + stopping_criterion=stopping_criterion, + verbose=True, + ) + print("------ [done] sol0 - no backend") + + # solve provided uniform weights + + sol = ot.solve_bary_sample( + X_list, + n, + a_list=a_list, + b_init=b, + w=w, + metric="sqeuclidean", + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + warmstart=warmstart, + max_iter_bary=2, + tol_bary=1e-3, + stopping_criterion=stopping_criterion, + verbose=True, + ) + print("------ [done] sol - no backend") + + assert_allclose_bary_sol(sol0, sol) + + # solve in backend + X_listb = nx.from_numpy(*X_list) + a_listb = nx.from_numpy(*a_list) + wb, bb = nx.from_numpy(w, b) + + if isinstance(reg_type, tuple): + + def fb(G): + return nx.sum( + G**2 + ) # otherwise we keep previously defined (f, df) as required by inner solver + + reg_type = (fb, df) + + solb = ot.solve_bary_sample( + X_listb, + n, + a_list=a_listb, + b_init=bb, + w=wb, + metric="sqeuclidean", + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + warmstart=warmstart, + max_iter_bary=2, + tol_bary=1e-3, + stopping_criterion=stopping_criterion, + verbose=True, + ) + print("------ [done] sol - with backend") + + assert_allclose_bary_sol(sol, solb) + + # except NotImplementedError: + # pytest.skip("Not implemented") + + +@pytest.mark.parametrize( + "method_params", lst_parameters_solve_bary_sample_NotImplemented +) +def test_solve_bary_sample_NotImplemented(nx, method_params): + # test bary_sample when is_Lazy = False + rng = np.random.RandomState() + + K = 2 # number of distributions + ns = rng.randint(10, 20, K) # number of samples within each distribution + n = 5 # number of samples in the barycenter + + X_list = [rng.randn(ns_i, 2) for ns_i in ns] + # X_init = np.reshape(1.0 * np.randn(n, 2), (n, 1)) + + a_list = [ot.utils.unif(X.shape[0]) for X in X_list] + b = 
ot.utils.unif(n) + + w = ot.utils.unif(K) + + # solve in backend + X_listb = nx.from_numpy(*X_list) + a_listb = nx.from_numpy(*a_list) + wb, bb = nx.from_numpy(w, b) + + with pytest.raises(NotImplementedError): + ot.solve_bary_sample( + X_listb, n, a_list=a_listb, b_init=bb, w=wb, **method_params + ) + + +def test_solve_bary_sample_ValueError(nx): + # Test ValueError cases: stopping_criterion and X_b_init shape + rng = np.random.RandomState(42) + + K = 2 + ns = [10, 12] + n = 5 + X_list = [rng.randn(ns_i, 2) for ns_i in ns] + a_list = [ot.utils.unif(ns_i) for ns_i in ns] + w = ot.utils.unif(K) + + X_listb = nx.from_numpy(*X_list) + a_listb = nx.from_numpy(*a_list) + wb = nx.from_numpy(w) + + # Test Invalid stopping_criterion + with pytest.raises(ValueError, match="stopping_criterion must be"): + ot.solve_bary_sample( + X_listb, n, a_list=a_listb, w=wb, stopping_criterion="invalid" + ) + + # Test Invalid X_b_init shape + bad_X_b_init = rng.randn(n + 1, 2) + bad_X_b_init = nx.from_numpy(bad_X_b_init) + with pytest.raises(ValueError, match="X_b_init must have shape"): + ot.solve_bary_sample(X_listb, n, X_b_init=bad_X_b_init, a_list=a_listb, w=wb) + + +def test_solve_bary_sample_callable_metric(nx): + # Test callable metric paths + rng = np.random.RandomState(42) + + K = 2 + ns = [8, 9] + n = 4 + X_list = [rng.randn(ns_i, 2) for ns_i in ns] + a_list = [ot.utils.unif(ns_i) for ns_i in ns] + w = ot.utils.unif(K) + + X_listb = nx.from_numpy(*X_list) + a_listb = nx.from_numpy(*a_list) + wb = nx.from_numpy(w) + + # Define custom metric function + def custom_metric(X_b, X_a): + return nx.sqrt(nx.sum((X_b[:, None, :] - X_a[None, :, :]) ** 2, axis=2)) + + if str(nx) == "torch": + # Test single callable metric + sol_callable = ot.solve_bary_sample( + X_listb, n, a_list=a_listb, w=wb, metric=custom_metric, max_iter_bary=1 + ) + assert sol_callable is not None + assert sol_callable.X is not None + + # Test list of callable metrics + metrics_list = [custom_metric, custom_metric] + 
sol_list = ot.solve_bary_sample( + X_listb, n, a_list=a_listb, w=wb, metric=metrics_list, max_iter_bary=1 + ) + assert sol_list is not None + assert sol_list.X is not None + + # Test with auto_bary_method="true_fixed_point" + with pytest.raises( + ValueError, + match="ground_bary must be provided in kwargs for true_fixed_point method with callable metrics", + ): + sol_fixed_point = ot.solve_bary_sample( + X_listb, + n, + a_list=a_listb, + w=wb, + metric=custom_metric, + auto_bary_method="true_fixed_point", + max_iter_bary=1, + ) + else: + with pytest.raises( + AssertionError, + match=f"Backend {str(nx)} is not compatible with ground_bary=None, it must be provided if not using PyTorch backend", + ): + sol_wrong_backend = ot.solve_bary_sample( + X_listb, + n, + a_list=a_listb, + w=wb, + metric=custom_metric, + max_iter_bary=1, + ) diff --git a/test/test_utils.py b/test/test_utils.py index db4adeb68..05885a1c2 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -479,7 +479,7 @@ def test_OTResult(): # test print print(res) - # tets get citation + # test get citation print(res.citation) lst_attributes = [ @@ -509,6 +509,31 @@ def test_OTResult(): getattr(res, at) +def test_BaryResult(): + res = ot.utils.BaryResult() + + # test print + print(res) + + # test get citation + print(res.citation) + + lst_attributes = [ + "X", + "C", + "b", + "value", + "value_linear", + "value_quad", + "list_res", + "status", + "log", + ] + for at in lst_attributes: + print(at) + assert getattr(res, at) is None + + def test_get_coordinate_circle(): rng = np.random.RandomState(42) u = rng.rand(1, 100)