Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions src/pyrecest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,64 @@ def tile(x, reps):
backend.tile = tile


def _pytorch_reshape_shape(shape, torch_module) -> tuple[int, ...]:
"""Normalize NumPy-style reshape dimensions for ``torch.reshape``."""

if torch_module.is_tensor(shape):
if shape.ndim == 0:
return (_operator_index(shape.item()),)
shape = shape.detach().cpu().tolist()
elif getattr(shape, "ndim", None) == 0 and hasattr(shape, "item"):
return (_operator_index(shape.item()),)

try:
return (_operator_index(shape),)
except TypeError:
pass

if isinstance(shape, (str, bytes)):
raise TypeError("reshape shape must be an integer or a sequence of integers")

try:
return tuple(_operator_index(dimension) for dimension in shape)
except TypeError as exc:
raise TypeError(
"reshape shape must be an integer or a sequence of integers"
) from exc


def _patch_pytorch_reshape_facade() -> None:
"""Make PyTorch ``reshape`` accept array-like inputs and NumPy-style shapes."""

import pyrecest.backend as backend # pylint: disable=import-outside-toplevel

if getattr(backend, "__backend_name__", None) != "pytorch":
return

try:
import torch as _torch # pylint: disable=import-outside-toplevel
import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel
except (
ModuleNotFoundError
): # pragma: no cover - backend import fails first in practice
return

original_reshape = pytorch_backend.reshape
if getattr(original_reshape, "_pyrecest_arraylike_contract", False):
return

def reshape(x, shape):
return original_reshape(
pytorch_backend.array(x), _pytorch_reshape_shape(shape, _torch)
)

reshape.__name__ = getattr(original_reshape, "__name__", "reshape")
reshape.__doc__ = getattr(original_reshape, "__doc__", None)
reshape._pyrecest_arraylike_contract = True
pytorch_backend.reshape = reshape
backend.reshape = reshape


def _patch_jax_std_out_facade() -> None:
"""Make public JAX ``std`` accept NumPy's ``out`` argument."""

Expand Down Expand Up @@ -231,6 +289,7 @@ def std(

_patch_pytorch_comparison_facade()
_patch_pytorch_tile_facade()
_patch_pytorch_reshape_facade()
_patch_jax_std_out_facade()

from pyrecest.backend_support import ( # noqa: E402,F401
Expand Down
46 changes: 46 additions & 0 deletions tests/backend_support/test_pytorch_reshape_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import importlib.util
import os
import subprocess
import sys

import pytest


@pytest.mark.backend_portable
def test_pytorch_reshape_accepts_array_like_values_and_numpy_style_shapes():
if importlib.util.find_spec("torch") is None:
pytest.skip("torch is not installed")

env = os.environ.copy()
env["PYRECEST_BACKEND"] = "pytorch"
src_path = os.path.abspath("src")
env["PYTHONPATH"] = (
src_path
if not env.get("PYTHONPATH")
else os.pathsep.join([src_path, env["PYTHONPATH"]])
)

code = """
import pyrecest.backend as backend
import pyrecest._backend.pytorch as pytorch_backend

for backend_module in (backend, pytorch_backend):
matrix = backend_module.reshape([1, 2, 3, 4], (2, 2))
assert tuple(matrix.shape) == (2, 2)
assert matrix.detach().cpu().numpy().tolist() == [[1, 2], [3, 4]]

vector = backend_module.reshape([[1, 2], [3, 4]], 4)
assert tuple(vector.shape) == (4,)
assert vector.detach().cpu().numpy().tolist() == [1, 2, 3, 4]

inferred = backend_module.reshape([1, 2, 3, 4], backend.array([2, -1]))
assert tuple(inferred.shape) == (2, 2)

try:
backend_module.reshape([1, 2], [2.0])
except TypeError:
pass
else:
raise AssertionError("reshape accepted non-integer dimensions")
"""
subprocess.run([sys.executable, "-c", code], check=True, env=env)
Loading