Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 30 additions & 24 deletions src/pyrecest/backend_support/_torch_dtype_promotion_contract.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,42 @@
"""PyTorch dtype promotion compatibility helpers."""
"""PyTorch dtype compatibility helpers."""

from __future__ import annotations

from pyrecest.backend_support._torch_sort_contract import patch_pytorch_sort_numpy_contract


def patch_pytorch_dtype_promotion_contract() -> None:
"""Make PyTorch mixed-dtype helpers use Torch's promotion rules."""
"""Make PyTorch mixed-dtype helpers use Torch-compatible promotion rules."""
try:
import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel
import torch # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier
return

original_convert = raw_pytorch.convert_to_wider_dtype
if getattr(original_convert, "_pyrecest_torch_promotion_contract", False):
return

def convert_to_wider_dtype(tensor_list):
tensors = list(tensor_list)
if not tensors:
return tensors

promoted_dtype = tensors[0].dtype
for tensor in tensors[1:]:
promoted_dtype = torch.promote_types(promoted_dtype, tensor.dtype)

if all(tensor.dtype == promoted_dtype for tensor in tensors):
return tensors
return [raw_pytorch.cast(tensor, dtype=promoted_dtype) for tensor in tensors]

convert_to_wider_dtype.__name__ = getattr(
original_convert, "__name__", "convert_to_wider_dtype"
)
convert_to_wider_dtype.__doc__ = getattr(original_convert, "__doc__", None)
convert_to_wider_dtype._pyrecest_torch_promotion_contract = True
raw_pytorch.convert_to_wider_dtype = convert_to_wider_dtype
if not getattr(original_convert, "_pyrecest_torch_promotion_contract", False):

def convert_to_wider_dtype(tensor_list):
tensors = list(tensor_list)
if not tensors:
return tensors

dtype = tensors[0].dtype
for tensor in tensors[1:]:
dtype = torch.result_type(
torch.empty((), dtype=dtype),
torch.empty((), dtype=tensor.dtype),
)

if all(tensor.dtype == dtype for tensor in tensors):
return tensors
return [raw_pytorch.cast(tensor, dtype=dtype) for tensor in tensors]

convert_to_wider_dtype.__name__ = getattr(
original_convert, "__name__", "convert_to_wider_dtype"
)
convert_to_wider_dtype.__doc__ = getattr(original_convert, "__doc__", None)
convert_to_wider_dtype._pyrecest_torch_promotion_contract = True
raw_pytorch.convert_to_wider_dtype = convert_to_wider_dtype

patch_pytorch_sort_numpy_contract()
57 changes: 57 additions & 0 deletions src/pyrecest/backend_support/_torch_sort_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""PyTorch sort compatibility helpers."""

from __future__ import annotations

from operator import index as _operator_index

_SORT_KINDS = {"quicksort", "heapsort", "mergesort", "stable"}
_STABLE_SORT_KINDS = {"mergesort", "stable"}


def patch_pytorch_sort_numpy_contract() -> None:
"""Make PyTorch sort follow NumPy axis and stable-kind contracts."""
try:
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel
import torch # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable
return

original_sort = raw_pytorch.sort
if getattr(original_sort, "_pyrecest_numpy_sort_contract", False):
return

def sort(a, axis=-1, kind=None, order=None, *, stable=None, descending=False):
if order is not None:
raise TypeError("PyTorch backend sort does not support field-order sorting")
if kind is not None:
if kind not in _SORT_KINDS:
raise ValueError(
"sort kind must be one of 'quicksort', 'heapsort', 'mergesort', or 'stable'"
)
if kind in _STABLE_SORT_KINDS:
if stable is False:
raise TypeError("sort() got inconsistent 'kind' and 'stable' arguments")
stable = True
stable = bool(stable) if stable is not None else False

values = raw_pytorch.array(a)
if axis is None:
values = torch.flatten(values)
axis = -1
else:
axis = _operator_index(axis)
sorted_values, _ = torch.sort(
values,
dim=axis,
stable=stable,
descending=descending,
)
return sorted_values

sort.__name__ = getattr(original_sort, "__name__", "sort")
sort.__doc__ = getattr(original_sort, "__doc__", None)
sort._pyrecest_numpy_sort_contract = True
raw_pytorch.sort = sort
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.sort = sort
28 changes: 28 additions & 0 deletions tests/backend_support/test_pytorch_sort_raw_backend_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import importlib.util

import pytest


def _to_python(sort_backend, value):
value = sort_backend.to_numpy(value)
if hasattr(value, "tolist"):
return value.tolist()
return value


@pytest.mark.backend_portable
def test_pytorch_sort_accepts_numpy_axis_none_and_stable_kind():
if importlib.util.find_spec("torch") is None:
pytest.skip("torch is not installed")

import pyrecest # noqa: F401 # pylint: disable=import-outside-toplevel,unused-import
import pyrecest.backend as public_backend # pylint: disable=import-outside-toplevel
import pyrecest._backend.pytorch as sort_backend # pylint: disable=import-outside-toplevel

assert _to_python(sort_backend, sort_backend.sort([[3, 1], [2, 4]], axis=None)) == [1, 2, 3, 4]
assert _to_python(sort_backend, sort_backend.sort([3, 1, 2, 1], kind="stable")) == [1, 1, 2, 3]
assert _to_python(sort_backend, sort_backend.sort([3, 1, 2, 1], kind="mergesort")) == [1, 1, 2, 3]

if public_backend.__backend_name__ == "pytorch":
assert _to_python(public_backend, public_backend.sort([[3, 1], [2, 4]], axis=None)) == [1, 2, 3, 4]
assert _to_python(public_backend, public_backend.sort([3, 1, 2, 1], kind="stable")) == [1, 1, 2, 3]
Loading