diff --git a/src/pyrecest/backend_support/_torch_dtype_promotion_contract.py b/src/pyrecest/backend_support/_torch_dtype_promotion_contract.py
index 3cb184d28..5af64b072 100644
--- a/src/pyrecest/backend_support/_torch_dtype_promotion_contract.py
+++ b/src/pyrecest/backend_support/_torch_dtype_promotion_contract.py
@@ -1,10 +1,12 @@
-"""PyTorch dtype promotion compatibility helpers."""
+"""PyTorch dtype compatibility helpers."""
 
 from __future__ import annotations
 
+from pyrecest.backend_support._torch_sort_contract import patch_pytorch_sort_numpy_contract
+
 
 def patch_pytorch_dtype_promotion_contract() -> None:
-    """Make PyTorch mixed-dtype helpers use Torch's promotion rules."""
+    """Make PyTorch mixed-dtype helpers use Torch-compatible promotion rules."""
     try:
         import pyrecest._backend.pytorch as raw_pytorch  # pylint: disable=import-outside-toplevel
         import torch  # pylint: disable=import-outside-toplevel
@@ -12,25 +14,29 @@ def patch_pytorch_dtype_promotion_contract() -> None:
         return
 
     original_convert = raw_pytorch.convert_to_wider_dtype
-    if getattr(original_convert, "_pyrecest_torch_promotion_contract", False):
-        return
-
-    def convert_to_wider_dtype(tensor_list):
-        tensors = list(tensor_list)
-        if not tensors:
-            return tensors
-
-        promoted_dtype = tensors[0].dtype
-        for tensor in tensors[1:]:
-            promoted_dtype = torch.promote_types(promoted_dtype, tensor.dtype)
-
-        if all(tensor.dtype == promoted_dtype for tensor in tensors):
-            return tensors
-        return [raw_pytorch.cast(tensor, dtype=promoted_dtype) for tensor in tensors]
-
-    convert_to_wider_dtype.__name__ = getattr(
-        original_convert, "__name__", "convert_to_wider_dtype"
-    )
-    convert_to_wider_dtype.__doc__ = getattr(original_convert, "__doc__", None)
-    convert_to_wider_dtype._pyrecest_torch_promotion_contract = True
-    raw_pytorch.convert_to_wider_dtype = convert_to_wider_dtype
+    if not getattr(original_convert, "_pyrecest_torch_promotion_contract", False):
+
+        def convert_to_wider_dtype(tensor_list):
+            tensors = list(tensor_list)
+            if not tensors:
+                return tensors
+
+            dtype = tensors[0].dtype
+            for tensor in tensors[1:]:
+                dtype = torch.result_type(
+                    torch.empty((), dtype=dtype),
+                    torch.empty((), dtype=tensor.dtype),
+                )
+
+            if all(tensor.dtype == dtype for tensor in tensors):
+                return tensors
+            return [raw_pytorch.cast(tensor, dtype=dtype) for tensor in tensors]
+
+        convert_to_wider_dtype.__name__ = getattr(
+            original_convert, "__name__", "convert_to_wider_dtype"
+        )
+        convert_to_wider_dtype.__doc__ = getattr(original_convert, "__doc__", None)
+        convert_to_wider_dtype._pyrecest_torch_promotion_contract = True
+        raw_pytorch.convert_to_wider_dtype = convert_to_wider_dtype
+
+    patch_pytorch_sort_numpy_contract()
diff --git a/src/pyrecest/backend_support/_torch_sort_contract.py b/src/pyrecest/backend_support/_torch_sort_contract.py
new file mode 100644
index 000000000..12ce3ad31
--- /dev/null
+++ b/src/pyrecest/backend_support/_torch_sort_contract.py
@@ -0,0 +1,57 @@
+"""PyTorch sort compatibility helpers."""
+
+from __future__ import annotations
+
+from operator import index as _operator_index
+
+_SORT_KINDS = {"quicksort", "heapsort", "mergesort", "stable"}
+_STABLE_SORT_KINDS = {"mergesort", "stable"}
+
+
+def patch_pytorch_sort_numpy_contract() -> None:
+    """Make PyTorch sort follow NumPy axis and stable-kind contracts."""
+    try:
+        import pyrecest.backend as backend  # pylint: disable=import-outside-toplevel
+        import pyrecest._backend.pytorch as raw_pytorch  # pylint: disable=import-outside-toplevel
+        import torch  # pylint: disable=import-outside-toplevel
+    except ModuleNotFoundError:  # pragma: no cover - PyTorch backend may be unavailable
+        return
+
+    original_sort = raw_pytorch.sort
+    if getattr(original_sort, "_pyrecest_numpy_sort_contract", False):
+        return
+
+    def sort(a, axis=-1, kind=None, order=None, *, stable=None, descending=False):
+        if order is not None:
+            raise TypeError("PyTorch backend sort does not support field-order sorting")
+        if kind is not None:
+            if kind not in _SORT_KINDS:
+                raise ValueError(
+                    "sort kind must be one of 'quicksort', 'heapsort', 'mergesort', or 'stable'"
+                )
+            if kind in _STABLE_SORT_KINDS:
+                if stable is False:
+                    raise TypeError("sort() got inconsistent 'kind' and 'stable' arguments")
+                stable = True
+        stable = bool(stable) if stable is not None else False
+
+        values = raw_pytorch.array(a)
+        if axis is None:
+            values = torch.flatten(values)
+            axis = -1
+        else:
+            axis = _operator_index(axis)
+        sorted_values, _ = torch.sort(
+            values,
+            dim=axis,
+            stable=stable,
+            descending=descending,
+        )
+        return sorted_values
+
+    sort.__name__ = getattr(original_sort, "__name__", "sort")
+    sort.__doc__ = getattr(original_sort, "__doc__", None)
+    sort._pyrecest_numpy_sort_contract = True
+    raw_pytorch.sort = sort
+    if getattr(backend, "__backend_name__", None) == "pytorch":
+        backend.sort = sort
diff --git a/tests/backend_support/test_pytorch_sort_raw_backend_contract.py b/tests/backend_support/test_pytorch_sort_raw_backend_contract.py
new file mode 100644
index 000000000..a63fb2876
--- /dev/null
+++ b/tests/backend_support/test_pytorch_sort_raw_backend_contract.py
@@ -0,0 +1,28 @@
+import importlib.util
+
+import pytest
+
+
+def _to_python(sort_backend, value):
+    value = sort_backend.to_numpy(value)
+    if hasattr(value, "tolist"):
+        return value.tolist()
+    return value
+
+
+@pytest.mark.backend_portable
+def test_pytorch_sort_accepts_numpy_axis_none_and_stable_kind():
+    if importlib.util.find_spec("torch") is None:
+        pytest.skip("torch is not installed")
+
+    import pyrecest  # noqa: F401  # pylint: disable=import-outside-toplevel,unused-import
+    import pyrecest.backend as public_backend  # pylint: disable=import-outside-toplevel
+    import pyrecest._backend.pytorch as sort_backend  # pylint: disable=import-outside-toplevel
+
+    assert _to_python(sort_backend, sort_backend.sort([[3, 1], [2, 4]], axis=None)) == [1, 2, 3, 4]
+    assert _to_python(sort_backend, sort_backend.sort([3, 1, 2, 1], kind="stable")) == [1, 1, 2, 3]
+    assert _to_python(sort_backend, sort_backend.sort([3, 1, 2, 1], kind="mergesort")) == [1, 1, 2, 3]
+
+    if public_backend.__backend_name__ == "pytorch":
+        assert _to_python(public_backend, public_backend.sort([[3, 1], [2, 4]], axis=None)) == [1, 2, 3, 4]
+        assert _to_python(public_backend, public_backend.sort([3, 1, 2, 1], kind="stable")) == [1, 1, 2, 3]