-
Notifications
You must be signed in to change notification settings - Fork 107
Sinkhorn loss implementation #809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+208
−1
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| Sinkhorn Loss | ||
| =============== | ||
|
|
||
| .. currentmodule:: pina.loss.sinkhorn_loss | ||
|
|
||
| .. automodule:: pina._src.loss.sinkhorn_loss | ||
| :no-members: | ||
|
|
||
| .. autoclass:: pina._src.loss.sinkhorn_loss.SinkhornLoss | ||
| :members: | ||
| :show-inheritance: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,138 @@ | ||
| """Module for the SinkhornLoss class.""" | ||
|
|
||
| import torch | ||
| from pina._src.loss.base_dual_loss import BaseDualLoss | ||
| from pina._src.core.utils import check_consistency, check_positive_integer | ||
|
|
||
|
|
||
| class SinkhornLoss(BaseDualLoss): | ||
| r""" | ||
| Implementation of the Sinkhorn loss measuring the entropy-regularized | ||
| optimal transport distance between two empirical distributions. | ||
|
|
||
| Given an input tensor :math:`x` with :math:`N` samples and a target tensor | ||
| :math:`y` with :math:`M` samples, both in :math:`\mathbb{R}^D`, the loss is | ||
| defined through the entropy-regularized optimal transport problem: | ||
|
|
||
| .. math:: | ||
|
|
||
| W_\varepsilon(\mu, \nu) = \min_{\pi \in \Pi(\mu, \nu)} | ||
| \langle C, \pi \rangle - \varepsilon H(\pi) | ||
|
|
||
| where :math:`\mu` and :math:`\nu` are the empirical distributions associated | ||
| with :math:`x` and :math:`y`, :math:`\pi` is a transport plan, and | ||
| :math:`\Pi(\mu, \nu)` is the set of admissible transport plans with | ||
| marginals :math:`\mu` and :math:`\nu`. | ||
|
|
||
| The cost matrix is defined as: | ||
|
|
||
| .. math:: | ||
|
|
||
| C_{ij} = \left\| x_i - y_j \right\|_2^p | ||
|
|
||
| and the entropy term is: | ||
|
|
||
| .. math:: | ||
|
|
||
| H(\pi) = - \sum_{i,j} \pi_{ij} \log \pi_{ij} | ||
|
|
||
| where :math:`\varepsilon > 0` controls the strength of the entropic | ||
| regularization. | ||
|
|
||
| The Sinkhorn iterations compute the optimal dual potentials :math:`f^\ast` | ||
| and :math:`g^\ast` in log space. The regularized optimal transport cost is | ||
| then recovered from the dual formulation as: | ||
|
|
||
| .. math:: | ||
|
|
||
| W_\varepsilon = \langle a, f^\ast \rangle + \langle b, g^\ast \rangle | ||
|
|
||
| where :math:`a` and :math:`b` are uniform probability weights over the | ||
| :math:`N` input samples and :math:`M` target samples, respectively. | ||
|
|
||
| Unlike pointwise losses, the Sinkhorn loss compares whole empirical | ||
| distributions. Therefore, the output is always a scalar value. | ||
|
|
||
| Smaller values of ``eps`` provide a closer approximation to the true | ||
| Wasserstein distance, but may require more Sinkhorn iterations to converge. | ||
|
|
||
| .. seealso:: | ||
|
|
||
| **Original reference:** Patrini, G., Carioni, M., Forr'e, P., Bhargav, | ||
| S., Welling, M., Van den Berg, R., Genewein, T., and Nielsen, F. (2019). | ||
| *Sinkhorn AutoEncoders*. | ||
| In Proceedings of the 35th Conference on Uncertainty in Artificial | ||
| Intelligence. | ||
| URL: `<https://openreview.net/forum?id=BygNqoR9tm>`_. | ||
| """ | ||
|
|
||
| def __init__(self, p=2, eps=0.1, iterations=100): | ||
| """ | ||
| Initialization of the :class:`SinkhornLoss` class. | ||
|
|
||
| :param int p: The exponent of the cost function. Default is ``2``. | ||
| :param eps: The entropy regularization strength. Smaller values provide | ||
| a closer approximation to the unregularized Wasserstein distance, | ||
| but may require more iterations for convergence. Default is ``0.1``. | ||
| :type eps: int | float | ||
| :param int iterations: The number of Sinkhorn iterations. | ||
| Default is ``100``. | ||
| :raises AssertionError: If ``iterations`` is not a positive integer. | ||
| :raises AssertionError: If ``p`` is not a positive integer. | ||
| :raises ValueError: If ``eps`` is not a positive numeric value. | ||
| """ | ||
| # Initialize the base class with mean reduction | ||
| super().__init__(reduction="mean") | ||
|
|
||
| # Check consistency | ||
| check_positive_integer(iterations, strict=True) | ||
| check_positive_integer(p, strict=True) | ||
| check_consistency(eps, (int, float)) | ||
| if eps <= 0: | ||
| raise ValueError( | ||
| f"Expected 'eps' to be strictly positive, but got {eps}." | ||
| ) | ||
|
|
||
| # Initialize parameters | ||
| self.iterations = iterations | ||
| self.eps = eps | ||
| self.p = p | ||
|
|
||
| def forward(self, input, target): | ||
| """ | ||
| Forward method of the loss function. | ||
|
|
||
| :param torch.Tensor input: The input tensor. | ||
| :param torch.Tensor target: The target tensor. | ||
| :return: The computed Sinkhorn loss value. | ||
| :rtype: torch.Tensor | ||
| """ | ||
| # Extract the number of samples in input and target | ||
| n, m = input.shape[0], target.shape[0] | ||
|
|
||
| # Initialize log-uniform weights for the empirical distributions | ||
| log_a = -input.new_tensor(n).log().expand(n) | ||
| log_b = -target.new_tensor(m).log().expand(m) | ||
|
|
||
| # Initialize dual potentials f and g | ||
| f = torch.zeros(n, dtype=input.dtype, device=input.device) | ||
| g = torch.zeros(m, dtype=target.dtype, device=target.device) | ||
|
|
||
| # Define the cost matrix, shape (n, m) | ||
| C = torch.cdist(input, target, p=self.p) ** self.p | ||
|
|
||
| # Perform Sinkhorn iterations in log space for numerical stability | ||
| for _ in range(self.iterations): | ||
|
|
||
| # Update dual potential f with the softmin operation in log space | ||
| softmin_f = torch.logsumexp((g.unsqueeze(0) - C) / self.eps, dim=1) | ||
| f = self.eps * (log_a - softmin_f) | ||
|
|
||
| # Update dual potential g with the softmin operation in log space | ||
| softmin_g = torch.logsumexp((f.unsqueeze(1) - C) / self.eps, dim=0) | ||
| g = self.eps * (log_b - softmin_g) | ||
|
|
||
| # Compute the Sinkhorn loss as the sum of the means of f and g | ||
| loss = f.mean() + g.mean() | ||
|
|
||
| return self._reduction(loss.unsqueeze(0)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| import torch | ||
| import pytest | ||
| from pina.loss import SinkhornLoss | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("p", [1, 2]) | ||
| @pytest.mark.parametrize("eps", [0.01, 1]) | ||
| @pytest.mark.parametrize("iterations", [2, 5]) | ||
| def test_constructor(p, eps, iterations): | ||
|
|
||
| # Define the loss | ||
| SinkhornLoss(p=p, eps=eps, iterations=iterations) | ||
|
|
||
| # Should fail if iterations is not a positive integer | ||
| with pytest.raises(AssertionError): | ||
| SinkhornLoss(p=p, eps=eps, iterations=0) | ||
|
|
||
| # Should fail if p is not a positive integer | ||
| with pytest.raises(AssertionError): | ||
| SinkhornLoss(p=0, eps=eps, iterations=iterations) | ||
|
|
||
| # Should fail if eps is not numeric | ||
| with pytest.raises(ValueError): | ||
| SinkhornLoss(p=p, eps="invalid", iterations=iterations) | ||
|
|
||
| # Should fail if eps is not positive | ||
| with pytest.raises(ValueError): | ||
| SinkhornLoss(p=p, eps=-0.1, iterations=iterations) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("p", [2, 3]) | ||
| @pytest.mark.parametrize("eps", [0.1, 1]) | ||
| @pytest.mark.parametrize("iterations", [2, 5]) | ||
| @pytest.mark.parametrize( | ||
| "input, target", | ||
| [ | ||
| (torch.rand(10, 2), torch.rand(8, 2)), | ||
| (torch.rand(5, 3), torch.rand(5, 3)), | ||
| (torch.rand(1, 4), torch.rand(7, 4)), | ||
| (torch.rand(6, 4), torch.rand(1, 4)), | ||
| (torch.rand(3, 1), torch.rand(4, 1)), | ||
| ], | ||
| ) | ||
| def test_forward(p, eps, iterations, input, target): | ||
|
|
||
| # Define the loss | ||
| loss = SinkhornLoss(p=p, eps=eps, iterations=iterations) | ||
|
|
||
| # Forward pass | ||
| value = loss(input, target) | ||
|
|
||
| # Check shape | ||
| assert value.shape == torch.Size([1]) | ||
| assert torch.isfinite(value).all() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.