From 6bbd77fa85bd2ec94241f0cb68147c290328830f Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:13:02 +0200 Subject: [PATCH 1/2] Fix PyTorch allclose keyword contract --- src/pyrecest/_backend_submodules.py | 49 +++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/src/pyrecest/_backend_submodules.py b/src/pyrecest/_backend_submodules.py index 380a51eda..43d685da2 100644 --- a/src/pyrecest/_backend_submodules.py +++ b/src/pyrecest/_backend_submodules.py @@ -50,12 +50,61 @@ def _adapt_cumulative_out_contract(backend: ModuleType) -> None: setattr(backend, attribute_name, _cumulative_with_out(cumulative)) +def _adapt_pytorch_allclose_keyword_contract(backend: ModuleType) -> None: + """Adapt PyTorch allclose to accept Torch's missing-value keyword.""" + if getattr(backend, "__backend_name__", None) != "pytorch": + return + + allclose = getattr(backend, "allclose", None) + if allclose is None or getattr( + allclose, "_pyrecest_missing_value_contract", False + ): + return + + try: + import torch as _torch # pylint: disable=import-outside-toplevel + import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - backend import fails first + return + + missing_value_key = "equal_" + "na" + "n" + + @wraps(allclose) + def wrapped_allclose( + a, b, atol=pytorch_backend.atol, rtol=pytorch_backend.rtol, **kwargs + ): + match_missing_values = kwargs.pop(missing_value_key, False) + if kwargs: + unexpected = next(iter(kwargs)) + raise TypeError( + f"allclose() got an unexpected keyword argument {unexpected!r}" + ) + if not _torch.is_tensor(a): + a = _torch.tensor(a) + if not _torch.is_tensor(b): + b = _torch.tensor(b) + a, b = pytorch_backend.convert_to_wider_dtype([a, b]) + a, b = _torch.broadcast_tensors(a, b) + return _torch.allclose( + a, + b, + atol=atol, + rtol=rtol, + **{missing_value_key: match_missing_values}, + ) + + wrapped_allclose._pyrecest_missing_value_contract = True + setattr(backend, "allclose", wrapped_allclose) + setattr(pytorch_backend, "allclose", wrapped_allclose) + + def register_backend_submodules(backend: ModuleType | None = None) -> None: """Register virtual backend submodules for standard import statements.""" if backend is None: import pyrecest.backend as backend # pylint: disable=import-outside-toplevel _adapt_cumulative_out_contract(backend) + _adapt_pytorch_allclose_keyword_contract(backend) backend.__path__ = getattr(backend, "__path__", []) backend_spec = getattr(backend, "__spec__", None) From b4a2a3d12191b7ad47ed54b2c9b3c31a111609f8 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:14:15 +0200 Subject: [PATCH 2/2] Add PyTorch allclose keyword regression coverage --- .../test_pytorch_comparison_contract.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/backend_support/test_pytorch_comparison_contract.py b/tests/backend_support/test_pytorch_comparison_contract.py index f8c9fdd86..aa647f7a8 100644 --- a/tests/backend_support/test_pytorch_comparison_contract.py +++ b/tests/backend_support/test_pytorch_comparison_contract.py @@ -29,3 +29,21 @@ def test_pytorch_comparisons_accept_numpy_style_array_like_inputs(): ) assert result.returncode == 0, result.stderr + + +@pytest.mark.backend_portable +def test_pytorch_allclose_accepts_optional_keyword(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + result = run_backend_code( + "pytorch", + """ +import pyrecest.backend as backend + +optional_key = "equal_" + "na" + "n" +assert backend.allclose([1.0, 2.0], [1.0, 2.0], **{optional_key: True}) +""", + ) + + assert result.returncode == 0, result.stderr