Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions src/pyrecest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,113 @@ def tile(x, reps):
backend.tile = tile


def _pytorch_pad_pairs(pad_width, ndim, numpy_module) -> tuple[tuple[int, int], ...]:
"""Normalize NumPy-style pad widths to per-axis pairs."""

try:
pad_pairs = numpy_module.broadcast_to(numpy_module.asarray(pad_width), (ndim, 2))
except ValueError as exc:
raise ValueError(
f"pad_width must be broadcastable to shape ({ndim}, 2)"
) from exc

if numpy_module.any(pad_pairs < 0):
raise ValueError("index can't contain negative values")

return tuple(tuple(int(value) for value in pair) for pair in pad_pairs.tolist())


def _pytorch_torch_pad_width(pad_pairs: tuple[tuple[int, int], ...]) -> list[int]:
"""Convert NumPy-ordered pad pairs to PyTorch's reversed flat order."""

return [value for pair in reversed(pad_pairs) for value in pair]


def _pytorch_constant_value_pairs(
constant_values,
ndim,
numpy_module,
torch_module,
) -> tuple[tuple[object, object], ...]:
"""Normalize NumPy-style constant pad values to per-axis pairs."""

if torch_module.is_tensor(constant_values):
constant_values = constant_values.detach().cpu().numpy()

try:
value_pairs = numpy_module.broadcast_to(
numpy_module.asarray(constant_values),
(ndim, 2),
)
except ValueError as exc:
raise ValueError(
f"constant_values must be broadcastable to shape ({ndim}, 2)"
) from exc

return tuple(tuple(pair) for pair in value_pairs.tolist())


def _patch_pytorch_pad_facade() -> None:
"""Make public PyTorch ``pad`` accept NumPy-style constant values."""

import pyrecest.backend as backend # pylint: disable=import-outside-toplevel

if getattr(backend, "__backend_name__", None) != "pytorch":
return

try:
import numpy as _np # pylint: disable=import-outside-toplevel
import torch as _torch # pylint: disable=import-outside-toplevel
except (
ModuleNotFoundError
): # pragma: no cover - backend import fails first in practice
return

original_pad = backend.pad

def pad(a, pad_width, mode="constant", constant_values=0.0):
values = backend.array(a)
if mode != "constant":
return original_pad(
values,
pad_width,
mode=mode,
constant_values=constant_values,
)

pad_pairs = _pytorch_pad_pairs(pad_width, values.ndim, _np)
torch_pad_width = _pytorch_torch_pad_width(pad_pairs)
result = _torch.nn.functional.pad(
values,
torch_pad_width,
mode="constant",
value=0.0,
)
value_pairs = _pytorch_constant_value_pairs(
constant_values,
values.ndim,
_np,
_torch,
)

for axis, ((before, after), (before_value, after_value)) in enumerate(
zip(pad_pairs, value_pairs)
):
if before:
index = [slice(None)] * result.ndim
index[axis] = slice(0, before)
result[tuple(index)] = before_value
if after:
index = [slice(None)] * result.ndim
index[axis] = slice(result.shape[axis] - after, result.shape[axis])
result[tuple(index)] = after_value
return result

pad.__name__ = "pad"
pad.__doc__ = getattr(original_pad, "__doc__", None)
backend.pad = pad


def _patch_jax_std_out_facade() -> None:
"""Make public JAX ``std`` accept NumPy's ``out`` argument."""

Expand Down Expand Up @@ -231,6 +338,7 @@ def std(

_patch_pytorch_comparison_facade()
_patch_pytorch_tile_facade()
_patch_pytorch_pad_facade()
_patch_jax_std_out_facade()

from pyrecest.backend_support import ( # noqa: E402,F401
Expand Down
39 changes: 39 additions & 0 deletions tests/backend_support/test_pytorch_pad_constant_values_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest

from tests.support.backend_runner import run_backend_code


def test_pytorch_pad_accepts_numpy_constant_value_pairs():
pytest.importorskip("torch")

code = """
import numpy as np
import numpy.testing as npt

import pyrecest.backend as backend

values = backend.asarray([[1, 2], [3, 4]], dtype=backend.int64)
result = backend.pad(values, ((1, 0), (0, 2)), constant_values=((5, 6), (7, 8)))
expected = np.pad(
np.array([[1, 2], [3, 4]]),
((1, 0), (0, 2)),
mode="constant",
constant_values=((5, 6), (7, 8)),
)
assert result.shape == expected.shape
assert backend.to_numpy(result).tolist() == expected.tolist()

complex_values = backend.asarray([1.0 + 1.0j], dtype=backend.complex128)
complex_result = backend.pad(complex_values, (1, 1), constant_values=2.0 + 3.0j)
complex_expected = np.pad(
np.array([1.0 + 1.0j], dtype=np.complex128),
(1, 1),
mode="constant",
constant_values=2.0 + 3.0j,
)
npt.assert_allclose(backend.to_numpy(complex_result), complex_expected)
"""

result = run_backend_code("pytorch", code)

assert result.returncode == 0, result.stderr
Loading