diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 0c289183e..ecd50ec7d 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -330,6 +330,8 @@ Losses BaseDualLoss LpLoss PowerLoss + SinkhornLoss + Weighting Schemas -------------------- @@ -343,4 +345,4 @@ Weighting Schemas Neural-Tangent-Kernel Weighting No Weighting Scalar Weighting - Self-Adaptive Weighting \ No newline at end of file + Self-Adaptive Weighting diff --git a/docs/source/_rst/loss/sinkhorn_loss.rst b/docs/source/_rst/loss/sinkhorn_loss.rst new file mode 100644 index 000000000..17aa370ad --- /dev/null +++ b/docs/source/_rst/loss/sinkhorn_loss.rst @@ -0,0 +1,11 @@ +Sinkhorn Loss +=============== + +.. currentmodule:: pina.loss.sinkhorn_loss + +.. automodule:: pina._src.loss.sinkhorn_loss + :no-members: + +.. autoclass:: pina._src.loss.sinkhorn_loss.SinkhornLoss + :members: + :show-inheritance: diff --git a/pina/_src/loss/sinkhorn_loss.py b/pina/_src/loss/sinkhorn_loss.py new file mode 100644 index 000000000..9feddc458 --- /dev/null +++ b/pina/_src/loss/sinkhorn_loss.py @@ -0,0 +1,138 @@ +"""Module for the SinkhornLoss class.""" + +import torch +from pina._src.loss.base_dual_loss import BaseDualLoss +from pina._src.core.utils import check_consistency, check_positive_integer + + +class SinkhornLoss(BaseDualLoss): + r""" + Implementation of the Sinkhorn loss measuring the entropy-regularized + optimal transport distance between two empirical distributions. + + Given an input tensor :math:`x` with :math:`N` samples and a target tensor + :math:`y` with :math:`M` samples, both in :math:`\mathbb{R}^D`, the loss is + defined through the entropy-regularized optimal transport problem: + + .. math:: + + W_\varepsilon(\mu, \nu) = \min_{\pi \in \Pi(\mu, \nu)} + \langle C, \pi \rangle - \varepsilon H(\pi) + + where :math:`\mu` and :math:`\nu` are the empirical distributions associated + with :math:`x` and :math:`y`, :math:`\pi` is a transport plan, and + :math:`\Pi(\mu, \nu)` is the set of admissible transport plans with + marginals :math:`\mu` and :math:`\nu`. + + The cost matrix is defined as: + + .. math:: + + C_{ij} = \left\| x_i - y_j \right\|_2^p + + and the entropy term is: + + .. math:: + + H(\pi) = - \sum_{i,j} \pi_{ij} \log \pi_{ij} + + where :math:`\varepsilon > 0` controls the strength of the entropic + regularization. + + The Sinkhorn iterations compute the optimal dual potentials :math:`f^\ast` + and :math:`g^\ast` in log space. The regularized optimal transport cost is + then recovered from the dual formulation as: + + .. math:: + + W_\varepsilon = \langle a, f^\ast \rangle + \langle b, g^\ast \rangle + + where :math:`a` and :math:`b` are uniform probability weights over the + :math:`N` input samples and :math:`M` target samples, respectively. + + Unlike pointwise losses, the Sinkhorn loss compares whole empirical + distributions. Therefore, the output is always a scalar value. + + Smaller values of ``eps`` provide a closer approximation to the true + Wasserstein distance, but may require more Sinkhorn iterations to converge. + + .. seealso:: + + **Original reference:** Patrini, G., Carioni, M., Forr'e, P., Bhargav, + S., Welling, M., Van den Berg, R., Genewein, T., and Nielsen, F. (2019). + *Sinkhorn AutoEncoders*. + In Proceedings of the 35th Conference on Uncertainty in Artificial + Intelligence. + URL: ``_. + """ + + def __init__(self, p=2, eps=0.1, iterations=100): + """ + Initialization of the :class:`SinkhornLoss` class. + + :param int p: The exponent of the cost function. Default is ``2``. + :param eps: The entropy regularization strength. Smaller values provide + a closer approximation to the unregularized Wasserstein distance, + but may require more iterations for convergence. Default is ``0.1``. + :type eps: int | float + :param int iterations: The number of Sinkhorn iterations. + Default is ``100``. + :raises AssertionError: If ``iterations`` is not a positive integer. + :raises AssertionError: If ``p`` is not a positive integer. + :raises ValueError: If ``eps`` is not a positive numeric value. + """ + # Initialize the base class with mean reduction + super().__init__(reduction="mean") + + # Check consistency + check_positive_integer(iterations, strict=True) + check_positive_integer(p, strict=True) + check_consistency(eps, (int, float)) + if eps <= 0: + raise ValueError( + f"Expected 'eps' to be strictly positive, but got {eps}." + ) + + # Initialize parameters + self.iterations = iterations + self.eps = eps + self.p = p + + def forward(self, input, target): + """ + Forward method of the loss function. + + :param torch.Tensor input: The input tensor. + :param torch.Tensor target: The target tensor. + :return: The computed Sinkhorn loss value. + :rtype: torch.Tensor + """ + # Extract the number of samples in input and target + n, m = input.shape[0], target.shape[0] + + # Initialize log-uniform weights for the empirical distributions + log_a = -input.new_tensor(n).log().expand(n) + log_b = -target.new_tensor(m).log().expand(m) + + # Initialize dual potentials f and g + f = torch.zeros(n, dtype=input.dtype, device=input.device) + g = torch.zeros(m, dtype=target.dtype, device=target.device) + + # Define the cost matrix, shape (n, m) + C = torch.cdist(input, target, p=self.p) ** self.p + + # Perform Sinkhorn iterations in log space for numerical stability + for _ in range(self.iterations): + + # Update dual potential f with the softmin operation in log space + softmin_f = torch.logsumexp((g.unsqueeze(0) - C) / self.eps, dim=1) + f = self.eps * (log_a - softmin_f) + + # Update dual potential g with the softmin operation in log space + softmin_g = torch.logsumexp((f.unsqueeze(1) - C) / self.eps, dim=0) + g = self.eps * (log_b - softmin_g) + + # Compute the Sinkhorn loss as the sum of the means of f and g + loss = f.mean() + g.mean() + + return self._reduction(loss.unsqueeze(0)) diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py index 52ed278c7..280cbf76a 100644 --- a/pina/loss/__init__.py +++ b/pina/loss/__init__.py @@ -5,12 +5,14 @@ "BaseDualLoss", "LpLoss", "PowerLoss", + "SinkhornLoss", ] from pina._src.loss.dual_loss_interface import DualLossInterface from pina._src.loss.base_dual_loss import BaseDualLoss from pina._src.loss.power_loss import PowerLoss from pina._src.loss.lp_loss import LpLoss +from pina._src.loss.sinkhorn_loss import SinkhornLoss # Back-compatibility with version 0.2, to be removed soon import warnings diff --git a/tests/test_loss/test_sinkhorn_loss.py b/tests/test_loss/test_sinkhorn_loss.py new file mode 100644 index 000000000..86eb4de62 --- /dev/null +++ b/tests/test_loss/test_sinkhorn_loss.py @@ -0,0 +1,54 @@ +import torch +import pytest +from pina.loss import SinkhornLoss + + +@pytest.mark.parametrize("p", [1, 2]) +@pytest.mark.parametrize("eps", [0.01, 1]) +@pytest.mark.parametrize("iterations", [2, 5]) +def test_constructor(p, eps, iterations): + + # Define the loss + SinkhornLoss(p=p, eps=eps, iterations=iterations) + + # Should fail if iterations is not a positive integer + with pytest.raises(AssertionError): + SinkhornLoss(p=p, eps=eps, iterations=0) + + # Should fail if p is not a positive integer + with pytest.raises(AssertionError): + SinkhornLoss(p=0, eps=eps, iterations=iterations) + + # Should fail if eps is not numeric + with pytest.raises(ValueError): + SinkhornLoss(p=p, eps="invalid", iterations=iterations) + + # Should fail if eps is not positive + with pytest.raises(ValueError): + SinkhornLoss(p=p, eps=-0.1, iterations=iterations) + + +@pytest.mark.parametrize("p", [2, 3]) +@pytest.mark.parametrize("eps", [0.1, 1]) +@pytest.mark.parametrize("iterations", [2, 5]) +@pytest.mark.parametrize( + "input, target", + [ + (torch.rand(10, 2), torch.rand(8, 2)), + (torch.rand(5, 3), torch.rand(5, 3)), + (torch.rand(1, 4), torch.rand(7, 4)), + (torch.rand(6, 4), torch.rand(1, 4)), + (torch.rand(3, 1), torch.rand(4, 1)), + ], +) +def test_forward(p, eps, iterations, input, target): + + # Define the loss + loss = SinkhornLoss(p=p, eps=eps, iterations=iterations) + + # Forward pass + value = loss(input, target) + + # Check shape + assert value.shape == torch.Size([1]) + assert torch.isfinite(value).all()