diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index 9df2b5bfa..b1b8915a8 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -198,6 +198,113 @@ def tile(x, reps): backend.tile = tile +def _pytorch_pad_pairs(pad_width, ndim, numpy_module) -> tuple[tuple[int, int], ...]: + """Normalize NumPy-style pad widths to per-axis pairs.""" + + try: + pad_pairs = numpy_module.broadcast_to(numpy_module.asarray(pad_width), (ndim, 2)) + except ValueError as exc: + raise ValueError( + f"pad_width must be broadcastable to shape ({ndim}, 2)" + ) from exc + + if numpy_module.any(pad_pairs < 0): + raise ValueError("index can't contain negative values") + + return tuple(tuple(int(value) for value in pair) for pair in pad_pairs.tolist()) + + +def _pytorch_torch_pad_width(pad_pairs: tuple[tuple[int, int], ...]) -> list[int]: + """Convert NumPy-ordered pad pairs to PyTorch's reversed flat order.""" + + return [value for pair in reversed(pad_pairs) for value in pair] + + +def _pytorch_constant_value_pairs( + constant_values, + ndim, + numpy_module, + torch_module, +) -> tuple[tuple[object, object], ...]: + """Normalize NumPy-style constant pad values to per-axis pairs.""" + + if torch_module.is_tensor(constant_values): + constant_values = constant_values.detach().cpu().numpy() + + try: + value_pairs = numpy_module.broadcast_to( + numpy_module.asarray(constant_values), + (ndim, 2), + ) + except ValueError as exc: + raise ValueError( + f"constant_values must be broadcastable to shape ({ndim}, 2)" + ) from exc + + return tuple(tuple(pair) for pair in value_pairs.tolist()) + + +def _patch_pytorch_pad_facade() -> None: + """Make public PyTorch ``pad`` accept NumPy-style constant values.""" + + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + + if getattr(backend, "__backend_name__", None) != "pytorch": + return + + try: + import numpy as _np # pylint: disable=import-outside-toplevel + import torch as _torch # pylint: disable=import-outside-toplevel + except ( + ModuleNotFoundError + ): # pragma: no cover - backend import fails first in practice + return + + original_pad = backend.pad + + def pad(a, pad_width, mode="constant", constant_values=0.0): + values = backend.array(a) + if mode != "constant": + return original_pad( + values, + pad_width, + mode=mode, + constant_values=constant_values, + ) + + pad_pairs = _pytorch_pad_pairs(pad_width, values.ndim, _np) + torch_pad_width = _pytorch_torch_pad_width(pad_pairs) + result = _torch.nn.functional.pad( + values, + torch_pad_width, + mode="constant", + value=0.0, + ) + value_pairs = _pytorch_constant_value_pairs( + constant_values, + values.ndim, + _np, + _torch, + ) + + for axis, ((before, after), (before_value, after_value)) in enumerate( + zip(pad_pairs, value_pairs) + ): + if before: + index = [slice(None)] * result.ndim + index[axis] = slice(0, before) + result[tuple(index)] = before_value + if after: + index = [slice(None)] * result.ndim + index[axis] = slice(result.shape[axis] - after, result.shape[axis]) + result[tuple(index)] = after_value + return result + + pad.__name__ = "pad" + pad.__doc__ = getattr(original_pad, "__doc__", None) + backend.pad = pad + + def _patch_jax_std_out_facade() -> None: """Make public JAX ``std`` accept NumPy's ``out`` argument.""" @@ -231,6 +338,7 @@ def std( _patch_pytorch_comparison_facade() _patch_pytorch_tile_facade() +_patch_pytorch_pad_facade() _patch_jax_std_out_facade() from pyrecest.backend_support import ( # noqa: E402,F401 diff --git a/tests/backend_support/test_pytorch_pad_constant_values_contract.py b/tests/backend_support/test_pytorch_pad_constant_values_contract.py new file mode 100644 index 000000000..7c5941bb9 --- /dev/null +++ b/tests/backend_support/test_pytorch_pad_constant_values_contract.py @@ -0,0 +1,39 @@ +import pytest + +from tests.support.backend_runner import run_backend_code + + +def test_pytorch_pad_accepts_numpy_constant_value_pairs(): + pytest.importorskip("torch") + + code = """ +import numpy as np +import numpy.testing as npt + +import pyrecest.backend as backend + +values = backend.asarray([[1, 2], [3, 4]], dtype=backend.int64) +result = backend.pad(values, ((1, 0), (0, 2)), constant_values=((5, 6), (7, 8))) +expected = np.pad( + np.array([[1, 2], [3, 4]]), + ((1, 0), (0, 2)), + mode="constant", + constant_values=((5, 6), (7, 8)), +) +assert result.shape == expected.shape +assert backend.to_numpy(result).tolist() == expected.tolist() + +complex_values = backend.asarray([1.0 + 1.0j], dtype=backend.complex128) +complex_result = backend.pad(complex_values, (1, 1), constant_values=2.0 + 3.0j) +complex_expected = np.pad( + np.array([1.0 + 1.0j], dtype=np.complex128), + (1, 1), + mode="constant", + constant_values=2.0 + 3.0j, +) +npt.assert_allclose(backend.to_numpy(complex_result), complex_expected) +""" + + result = run_backend_code("pytorch", code) + + assert result.returncode == 0, result.stderr