Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 52 additions & 23 deletions src/pyrecest/backend_support/_torch_dtype_promotion_contract.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,65 @@
"""PyTorch dtype promotion compatibility helpers."""
"""PyTorch compatibility helpers for NumPy-style backend contracts."""

from __future__ import annotations


def _patch_pytorch_concatenate_axis_none_contract(raw_pytorch, torch_module, backend) -> None:
"""Make PyTorch concatenate flatten inputs when ``axis=None``."""
original_concatenate = raw_pytorch.concatenate
if getattr(original_concatenate, "_pyrecest_axis_none_contract", False):
return

def concatenate(seq, axis=0, out=None):
tensors = [raw_pytorch.array(item) for item in seq]
if tensors:
tensors = raw_pytorch.convert_to_wider_dtype(tensors)
if axis is None:
tensors = [torch_module.flatten(tensor) for tensor in tensors]
axis = 0
result = torch_module.cat(tensors, dim=axis)
if out is not None:
out.copy_(result)
return out
return result

concatenate.__name__ = getattr(original_concatenate, "__name__", "concatenate")
concatenate.__doc__ = getattr(original_concatenate, "__doc__", None)
concatenate._pyrecest_axis_none_contract = True
raw_pytorch.concatenate = concatenate
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.concatenate = concatenate


def patch_pytorch_dtype_promotion_contract() -> None:
"""Make PyTorch mixed-dtype helpers use Torch's promotion rules."""
"""Make PyTorch mixed-dtype helpers and concatenate match NumPy-style contracts."""
try:
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel
import torch # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier
return

original_convert = raw_pytorch.convert_to_wider_dtype
if getattr(original_convert, "_pyrecest_torch_promotion_contract", False):
return
if not getattr(original_convert, "_pyrecest_torch_promotion_contract", False):

def convert_to_wider_dtype(tensor_list):
tensors = list(tensor_list)
if not tensors:
return tensors

promoted_dtype = tensors[0].dtype
for tensor in tensors[1:]:
promoted_dtype = torch.promote_types(promoted_dtype, tensor.dtype)

if all(tensor.dtype == promoted_dtype for tensor in tensors):
return tensors
return [raw_pytorch.cast(tensor, dtype=promoted_dtype) for tensor in tensors]

convert_to_wider_dtype.__name__ = getattr(
original_convert, "__name__", "convert_to_wider_dtype"
)
convert_to_wider_dtype.__doc__ = getattr(original_convert, "__doc__", None)
convert_to_wider_dtype._pyrecest_torch_promotion_contract = True
raw_pytorch.convert_to_wider_dtype = convert_to_wider_dtype

def convert_to_wider_dtype(tensor_list):
tensors = list(tensor_list)
if not tensors:
return tensors

promoted_dtype = tensors[0].dtype
for tensor in tensors[1:]:
promoted_dtype = torch.promote_types(promoted_dtype, tensor.dtype)

if all(tensor.dtype == promoted_dtype for tensor in tensors):
return tensors
return [raw_pytorch.cast(tensor, dtype=promoted_dtype) for tensor in tensors]

convert_to_wider_dtype.__name__ = getattr(
original_convert, "__name__", "convert_to_wider_dtype"
)
convert_to_wider_dtype.__doc__ = getattr(original_convert, "__doc__", None)
convert_to_wider_dtype._pyrecest_torch_promotion_contract = True
raw_pytorch.convert_to_wider_dtype = convert_to_wider_dtype
_patch_pytorch_concatenate_axis_none_contract(raw_pytorch, torch, backend)
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import importlib.util

import pytest
from tests.support.backend_runner import run_backend_code


@pytest.mark.backend_portable
def test_pytorch_concatenate_axis_none_flattens_public_backend_inputs():
if importlib.util.find_spec("torch") is None:
pytest.skip("PyTorch is not installed")

result = run_backend_code(
"pytorch",
"""
import pyrecest.backend as backend

actual = backend.concatenate(([[1, 2], [3, 4]], [[5, 6]]), axis=None)

assert tuple(actual.shape) == (6,)
assert actual.tolist() == [1, 2, 3, 4, 5, 6]
print("ok")
""",
)

assert result.returncode == 0, result.stderr
assert "ok" in result.stdout


@pytest.mark.backend_portable
def test_raw_pytorch_concatenate_axis_none_is_patched_with_numpy_backend():
if importlib.util.find_spec("torch") is None:
pytest.skip("PyTorch is not installed")

result = run_backend_code(
"numpy",
"""
import pyrecest._backend.pytorch as pytorch_backend

actual = pytorch_backend.concatenate(([[1, 2], [3, 4]], [[5, 6]]), axis=None)

assert tuple(actual.shape) == (6,)
assert actual.tolist() == [1, 2, 3, 4, 5, 6]
print("ok")
""",
)

assert result.returncode == 0, result.stderr
assert "ok" in result.stdout
Loading