From 3941be22d17836e1f8486e9e111267a9397c7615 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:09:04 +0200 Subject: [PATCH 1/2] Fix PyTorch FFT array-like input handling --- src/pyrecest/_backend/pytorch/fft.py | 50 +++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/src/pyrecest/_backend/pytorch/fft.py b/src/pyrecest/_backend/pytorch/fft.py index 262d6800c..f9c507da7 100644 --- a/src/pyrecest/_backend/pytorch/fft.py +++ b/src/pyrecest/_backend/pytorch/fft.py @@ -1,10 +1,44 @@ # For ffts. Added for pyrecest. import torch as _torch -from torch.fft import ( - fftn, - fftshift, - ifftn, - ifftshift, - irfft, - rfft, -) + + +def _as_tensor(x): + return x if _torch.is_tensor(x) else _torch.as_tensor(x) + + +def _resolve_dim_alias(dim, alias, alias_name, func_name, *, default=None): + if alias is None: + return default if dim is None else dim + if dim is not None and dim != alias: + raise TypeError(f"{func_name}() got both 'dim' and '{alias_name}'") + return alias + + +def fftn(input, s=None, dim=None, norm=None, *, axes=None, out=None): + dim = _resolve_dim_alias(dim, axes, "axes", "fftn") + return _torch.fft.fftn(_as_tensor(input), s=s, dim=dim, norm=norm, out=out) + + +def ifftn(input, s=None, dim=None, norm=None, *, axes=None, out=None): + dim = _resolve_dim_alias(dim, axes, "axes", "ifftn") + return _torch.fft.ifftn(_as_tensor(input), s=s, dim=dim, norm=norm, out=out) + + +def rfft(input, n=None, dim=None, norm=None, *, axis=None, out=None): + dim = _resolve_dim_alias(dim, axis, "axis", "rfft", default=-1) + return _torch.fft.rfft(_as_tensor(input), n=n, dim=dim, norm=norm, out=out) + + +def irfft(input, n=None, dim=None, norm=None, *, axis=None, out=None): + dim = _resolve_dim_alias(dim, axis, "axis", "irfft", default=-1) + return _torch.fft.irfft(_as_tensor(input), n=n, dim=dim, norm=norm, out=out) + + +def fftshift(input, dim=None, *, axes=None): + dim = _resolve_dim_alias(dim, axes, "axes", "fftshift") + return _torch.fft.fftshift(_as_tensor(input), dim=dim) + + +def ifftshift(input, dim=None, *, axes=None): + dim = _resolve_dim_alias(dim, axes, "axes", "ifftshift") + return _torch.fft.ifftshift(_as_tensor(input), dim=dim) From b31bf9ab80cf1e53fd16f4d96b8d69322d8ad299 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:09:21 +0200 Subject: [PATCH 2/2] Add PyTorch FFT backend regression test --- .../test_pytorch_fft_numpy_signature.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 tests/backend_support/test_pytorch_fft_numpy_signature.py diff --git a/tests/backend_support/test_pytorch_fft_numpy_signature.py b/tests/backend_support/test_pytorch_fft_numpy_signature.py new file mode 100644 index 000000000..7424fafcd --- /dev/null +++ b/tests/backend_support/test_pytorch_fft_numpy_signature.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import importlib.util + +import pytest + +from tests.support.backend_runner import run_backend_code + + +def test_pytorch_fft_accepts_array_like_inputs_and_numpy_axis_aliases(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + result = run_backend_code( + "pytorch", + """ +import numpy as np +import pyrecest.backend as backend + +values = [[1.0, 2.0], [3.0, 4.0]] +fftn_result = backend.fft.fftn(values, axes=(0, 1)) +np.testing.assert_allclose( + backend.to_numpy(fftn_result), + np.fft.fftn(np.asarray(values), axes=(0, 1)), +) + +rfft_result = backend.fft.rfft([0.0, 1.0, 0.0, -1.0], axis=0) +round_trip = backend.fft.irfft(rfft_result, n=4, axis=0) +np.testing.assert_allclose( + backend.to_numpy(round_trip), + np.asarray([0.0, 1.0, 0.0, -1.0]), +) + +np.testing.assert_array_equal( + backend.to_numpy(backend.fft.fftshift([0, 1, 2, 3], axes=0)), + np.asarray([2, 3, 0, 1]), +) +np.testing.assert_array_equal( + backend.to_numpy(backend.fft.ifftshift([2, 3, 0, 1], axes=0)), + np.asarray([0, 1, 2, 3]), +) +print("ok") +""", + ) + + assert result.returncode == 0, result.stderr + assert "ok" in result.stdout + + +def test_pytorch_fft_rejects_conflicting_axis_aliases(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + result = run_backend_code( + "pytorch", + """ +import pyrecest.backend as backend + +try: + backend.fft.rfft([0.0, 1.0], dim=0, axis=1) +except TypeError as exc: + assert "dim" in str(exc) + assert "axis" in str(exc) +else: + raise AssertionError("conflicting dim/axis aliases were accepted") +print("ok") +""", + ) + + assert result.returncode == 0, result.stderr + assert "ok" in result.stdout