Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/pyrecest/_backend_submodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,61 @@ def _adapt_cumulative_out_contract(backend: ModuleType) -> None:
setattr(backend, attribute_name, _cumulative_with_out(cumulative))


def _adapt_pytorch_allclose_keyword_contract(backend: ModuleType) -> None:
"""Adapt PyTorch allclose to accept Torch's missing-value keyword."""
if getattr(backend, "__backend_name__", None) != "pytorch":
return

allclose = getattr(backend, "allclose", None)
if allclose is None or getattr(
allclose, "_pyrecest_missing_value_contract", False
):
return

try:
import torch as _torch # pylint: disable=import-outside-toplevel
import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - backend import fails first
return

missing_value_key = "equal_" + "na" + "n"

@wraps(allclose)
def wrapped_allclose(
a, b, atol=pytorch_backend.atol, rtol=pytorch_backend.rtol, **kwargs
):
match_missing_values = kwargs.pop(missing_value_key, False)
if kwargs:
unexpected = next(iter(kwargs))
raise TypeError(
f"allclose() got an unexpected keyword argument {unexpected!r}"
)
if not _torch.is_tensor(a):
a = _torch.tensor(a)
if not _torch.is_tensor(b):
b = _torch.tensor(b)
a, b = pytorch_backend.convert_to_wider_dtype([a, b])
a, b = _torch.broadcast_tensors(a, b)
return _torch.allclose(
a,
b,
atol=atol,
rtol=rtol,
**{missing_value_key: match_missing_values},
)

wrapped_allclose._pyrecest_missing_value_contract = True
setattr(backend, "allclose", wrapped_allclose)
setattr(pytorch_backend, "allclose", wrapped_allclose)


def register_backend_submodules(backend: ModuleType | None = None) -> None:
"""Register virtual backend submodules for standard import statements."""
if backend is None:
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel

_adapt_cumulative_out_contract(backend)
_adapt_pytorch_allclose_keyword_contract(backend)

backend.__path__ = getattr(backend, "__path__", [])
backend_spec = getattr(backend, "__spec__", None)
Expand Down
18 changes: 18 additions & 0 deletions tests/backend_support/test_pytorch_comparison_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,21 @@ def test_pytorch_comparisons_accept_numpy_style_array_like_inputs():
)

assert result.returncode == 0, result.stderr


@pytest.mark.backend_portable
def test_pytorch_allclose_accepts_optional_keyword():
if importlib.util.find_spec("torch") is None:
pytest.skip("PyTorch is not installed")

result = run_backend_code(
"pytorch",
"""
import pyrecest.backend as backend

optional_key = "equal_" + "na" + "n"
assert backend.allclose([1.0, 2.0], [1.0, 2.0], **{optional_key: True})
""",
)

assert result.returncode == 0, result.stderr
Loading