From c847433dfac20aad7d78e0c2c7306960b7c22bc9 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:47:46 +0200 Subject: [PATCH 1/3] Fix PyTorch pad constant values --- src/pyrecest/__init__.py | 140 ++++++++++++++++++++++++++++++--------- 1 file changed, 109 insertions(+), 31 deletions(-) diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index 9df2b5bfa..aaf971b37 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -198,6 +198,113 @@ def tile(x, reps): backend.tile = tile +def _pytorch_pad_pairs(pad_width, ndim, numpy_module) -> tuple[tuple[int, int], ...]: + """Normalize NumPy-style pad widths to per-axis pairs.""" + + try: + pad_pairs = numpy_module.broadcast_to(numpy_module.asarray(pad_width), (ndim, 2)) + except ValueError as exc: + raise ValueError( + f"pad_width must be broadcastable to shape ({ndim}, 2)" + ) from exc + + if numpy_module.any(pad_pairs < 0): + raise ValueError("index can't contain negative values") + + return tuple(tuple(int(value) for value in pair) for pair in pad_pairs.tolist()) + + +def _pytorch_torch_pad_width(pad_pairs: tuple[tuple[int, int], ...]) -> list[int]: + """Convert NumPy-ordered pad pairs to PyTorch's reversed flat order.""" + + return [value for pair in reversed(pad_pairs) for value in pair] + + +def _pytorch_constant_value_pairs( + constant_values, + ndim, + numpy_module, + torch_module, +) -> tuple[tuple[object, object], ...]: + """Normalize NumPy-style constant pad values to per-axis pairs.""" + + if torch_module.is_tensor(constant_values): + constant_values = constant_values.detach().cpu().numpy() + + try: + value_pairs = numpy_module.broadcast_to( + numpy_module.asarray(constant_values), + (ndim, 2), + ) + except ValueError as exc: + raise ValueError( + f"constant_values must be broadcastable to shape ({ndim}, 2)" + ) from exc + + return tuple(tuple(pair) for pair in value_pairs.tolist()) + + +def _patch_pytorch_pad_facade() -> None: + """Make public PyTorch ``pad`` accept NumPy-style constant values.""" + + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + + if getattr(backend, "__backend_name__", None) != "pytorch": + return + + try: + import numpy as _np # pylint: disable=import-outside-toplevel + import torch as _torch # pylint: disable=import-outside-toplevel + except ( + ModuleNotFoundError + ): # pragma: no cover - backend import fails first in practice + return + + original_pad = backend.pad + + def pad(a, pad_width, mode="constant", constant_values=0.0): + values = backend.array(a) + if mode != "constant": + return original_pad( + values, + pad_width, + mode=mode, + constant_values=constant_values, + ) + + pad_pairs = _pytorch_pad_pairs(pad_width, values.ndim, _np) + torch_pad_width = _pytorch_torch_pad_width(pad_pairs) + result = _torch.nn.functional.pad( + values, + torch_pad_width, + mode="constant", + value=0.0, + ) + value_pairs = _pytorch_constant_value_pairs( + constant_values, + values.ndim, + _np, + _torch, + ) + + for axis, ((before, after), (before_value, after_value)) in enumerate( + zip(pad_pairs, value_pairs) + ): + if before: + index = [slice(None)] * result.ndim + index[axis] = slice(0, before) + result[tuple(index)] = before_value + if after: + index = [slice(None)] * result.ndim + index[axis] = slice(result.shape[axis] - after, result.shape[axis]) + result[tuple(index)] = after_value + return result + + pad.__name__ = "pad" + pad.__doc__ = getattr(original_pad, "__doc__", None) + backend.pad = pad + + def _patch_jax_std_out_facade() -> None: """Make public JAX ``std`` accept NumPy's ``out`` argument.""" @@ -231,6 +338,7 @@ def std( _patch_pytorch_comparison_facade() _patch_pytorch_tile_facade() +_patch_pytorch_pad_facade() _patch_jax_std_out_facade() from pyrecest.backend_support import ( # noqa: E402,F401 @@ -258,38 +366,8 @@ def std( ShapeError, ValidationError, ) -from pyrecest.stability import ( # noqa: E402,F401 - get_public_api_status, - iter_public_api_status, - stability, -) try: __version__ = version("pyrecest") -except PackageNotFoundError: # pragma: no cover - source tree without install metadata +except PackageNotFoundError: # pragma: no cover - editable/source tree without install metadata __version__ = "0+unknown" - -__all__ = [ - "BackendNotSupportedError", - "BackendSupportError", - "DimensionMismatchError", - "EvidenceComputationMode", - "NumericalStabilityError", - "OptionalDependencyError", - "PyRecEstError", - "ShapeError", - "ValidationError", - "__version__", - "assert_backend", - "backend_support", - "copy", - "format_backend_support_markdown", - "get_backend_name", - "get_backend_support", - "get_public_api_status", - "is_backend", - "iter_public_api_status", - "stability", - "warn_if_backend_env_changed", - "resolve_evidence_computation_mode", -] From ab20dca0b1bb87e3870801118363666f632464a5 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:47:58 +0200 Subject: [PATCH 2/3] Add PyTorch pad constant value regression test --- ...st_pytorch_pad_constant_values_contract.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/backend_support/test_pytorch_pad_constant_values_contract.py diff --git a/tests/backend_support/test_pytorch_pad_constant_values_contract.py b/tests/backend_support/test_pytorch_pad_constant_values_contract.py new file mode 100644 index 000000000..7c5941bb9 --- /dev/null +++ b/tests/backend_support/test_pytorch_pad_constant_values_contract.py @@ -0,0 +1,39 @@ +import pytest + +from tests.support.backend_runner import run_backend_code + + +def test_pytorch_pad_accepts_numpy_constant_value_pairs(): + pytest.importorskip("torch") + + code = """ +import numpy as np +import numpy.testing as npt + +import pyrecest.backend as backend + +values = backend.asarray([[1, 2], [3, 4]], dtype=backend.int64) +result = backend.pad(values, ((1, 0), (0, 2)), constant_values=((5, 6), (7, 8))) +expected = np.pad( + np.array([[1, 2], [3, 4]]), + ((1, 0), (0, 2)), + mode="constant", + constant_values=((5, 6), (7, 8)), +) +assert result.shape == expected.shape +assert backend.to_numpy(result).tolist() == expected.tolist() + +complex_values = backend.asarray([1.0 + 1.0j], dtype=backend.complex128) +complex_result = backend.pad(complex_values, (1, 1), constant_values=2.0 + 3.0j) +complex_expected = np.pad( + np.array([1.0 + 1.0j], dtype=np.complex128), + (1, 1), + mode="constant", + constant_values=2.0 + 3.0j, +) +npt.assert_allclose(backend.to_numpy(complex_result), complex_expected) +""" + + result = run_backend_code("pytorch", code) + + assert result.returncode == 0, result.stderr From 9607af177fcdfd0d02f49f525afb2e3e67bc4c98 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:49:07 +0200 Subject: [PATCH 3/3] Restore PyRecEst package exports --- src/pyrecest/__init__.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index aaf971b37..b1b8915a8 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -366,8 +366,38 @@ def std( ShapeError, ValidationError, ) +from pyrecest.stability import ( # noqa: E402,F401 + get_public_api_status, + iter_public_api_status, + stability, +) try: __version__ = version("pyrecest") -except PackageNotFoundError: # pragma: no cover - editable/source tree without install metadata +except PackageNotFoundError: # pragma: no cover - source tree without install metadata __version__ = "0+unknown" + +__all__ = [ + "BackendNotSupportedError", + "BackendSupportError", + "DimensionMismatchError", + "EvidenceComputationMode", + "NumericalStabilityError", + "OptionalDependencyError", + "PyRecEstError", + "ShapeError", + "ValidationError", + "__version__", + "assert_backend", + "backend_support", + "copy", + "format_backend_support_markdown", + "get_backend_name", + "get_backend_support", + "get_public_api_status", + "is_backend", + "iter_public_api_status", + "stability", + "warn_if_backend_env_changed", + "resolve_evidence_computation_mode", +]