From da2163ed9cbff6d1ff5e4d0dec199f890f2a47e7 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:34:53 +0200 Subject: [PATCH 1/2] Fix PyTorch reshape array-like contract --- src/pyrecest/__init__.py | 59 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/src/pyrecest/__init__.py b/src/pyrecest/__init__.py index 9df2b5bfa..b1893c728 100644 --- a/src/pyrecest/__init__.py +++ b/src/pyrecest/__init__.py @@ -198,6 +198,64 @@ def tile(x, reps): backend.tile = tile +def _pytorch_reshape_shape(shape, torch_module) -> tuple[int, ...]: + """Normalize NumPy-style reshape dimensions for ``torch.reshape``.""" + + if torch_module.is_tensor(shape): + if shape.ndim == 0: + return (_operator_index(shape.item()),) + shape = shape.detach().cpu().tolist() + elif getattr(shape, "ndim", None) == 0 and hasattr(shape, "item"): + return (_operator_index(shape.item()),) + + try: + return (_operator_index(shape),) + except TypeError: + pass + + if isinstance(shape, (str, bytes)): + raise TypeError("reshape shape must be an integer or a sequence of integers") + + try: + return tuple(_operator_index(dimension) for dimension in shape) + except TypeError as exc: + raise TypeError( + "reshape shape must be an integer or a sequence of integers" + ) from exc + + +def _patch_pytorch_reshape_facade() -> None: + """Make PyTorch ``reshape`` accept array-like inputs and NumPy-style shapes.""" + + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + + if getattr(backend, "__backend_name__", None) != "pytorch": + return + + try: + import torch as _torch # pylint: disable=import-outside-toplevel + import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel + except ( + ModuleNotFoundError + ): # pragma: no cover - backend import fails first in practice + return + + original_reshape = pytorch_backend.reshape + if getattr(original_reshape, "_pyrecest_arraylike_contract", False): + return + + def reshape(x, shape): + return original_reshape( + pytorch_backend.array(x), _pytorch_reshape_shape(shape, _torch) + ) + + reshape.__name__ = getattr(original_reshape, "__name__", "reshape") + reshape.__doc__ = getattr(original_reshape, "__doc__", None) + reshape._pyrecest_arraylike_contract = True + pytorch_backend.reshape = reshape + backend.reshape = reshape + + def _patch_jax_std_out_facade() -> None: """Make public JAX ``std`` accept NumPy's ``out`` argument.""" @@ -231,6 +289,7 @@ def std( _patch_pytorch_comparison_facade() _patch_pytorch_tile_facade() +_patch_pytorch_reshape_facade() _patch_jax_std_out_facade() from pyrecest.backend_support import ( # noqa: E402,F401 From 4d65569e2499913927048145a2688d85ffed6f57 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:35:05 +0200 Subject: [PATCH 2/2] Add PyTorch reshape contract regression test --- .../test_pytorch_reshape_contract.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 tests/backend_support/test_pytorch_reshape_contract.py diff --git a/tests/backend_support/test_pytorch_reshape_contract.py b/tests/backend_support/test_pytorch_reshape_contract.py new file mode 100644 index 000000000..9928b8035 --- /dev/null +++ b/tests/backend_support/test_pytorch_reshape_contract.py @@ -0,0 +1,46 @@ +import importlib.util +import os +import subprocess +import sys + +import pytest + + +@pytest.mark.backend_portable +def test_pytorch_reshape_accepts_array_like_values_and_numpy_style_shapes(): + if importlib.util.find_spec("torch") is None: + pytest.skip("torch is not installed") + + env = os.environ.copy() + env["PYRECEST_BACKEND"] = "pytorch" + src_path = os.path.abspath("src") + env["PYTHONPATH"] = ( + src_path + if not env.get("PYTHONPATH") + else os.pathsep.join([src_path, env["PYTHONPATH"]]) + ) + + code = """ +import pyrecest.backend as backend +import pyrecest._backend.pytorch as pytorch_backend + +for backend_module in (backend, pytorch_backend): + matrix = backend_module.reshape([1, 2, 3, 4], (2, 2)) + assert tuple(matrix.shape) == (2, 2) + assert matrix.detach().cpu().numpy().tolist() == [[1, 2], [3, 4]] + + vector = backend_module.reshape([[1, 2], [3, 4]], 4) + assert tuple(vector.shape) == (4,) + assert vector.detach().cpu().numpy().tolist() == [1, 2, 3, 4] + + inferred = backend_module.reshape([1, 2, 3, 4], backend.array([2, -1])) + assert tuple(inferred.shape) == (2, 2) + + try: + backend_module.reshape([1, 2], [2.0]) + except TypeError: + pass + else: + raise AssertionError("reshape accepted non-integer dimensions") +""" + subprocess.run([sys.executable, "-c", code], check=True, env=env)