From 71cf29343d8184ac1e3397b6178c7a5f702259d7 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:55:58 +0200 Subject: [PATCH 1/3] Fix PyTorch FFT array-like inputs --- src/pyrecest/_backend/pytorch/fft.py | 35 +++++++++++++++++++++------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/src/pyrecest/_backend/pytorch/fft.py b/src/pyrecest/_backend/pytorch/fft.py index 262d6800c..e3cc1cd51 100644 --- a/src/pyrecest/_backend/pytorch/fft.py +++ b/src/pyrecest/_backend/pytorch/fft.py @@ -1,10 +1,29 @@ # For ffts. Added for pyrecest. +from functools import wraps as _wraps + import torch as _torch -from torch.fft import ( - fftn, - fftshift, - ifftn, - ifftshift, - irfft, - rfft, -) + +from ._common import array as _array + + +def _as_fft_tensor(value): + """Convert array-like FFT inputs to torch tensors.""" + return value if _torch.is_tensor(value) else _array(value) + + +def _wrap_arraylike_fft(torch_func): + """Return a PyRecEst-compatible FFT helper accepting array-like input.""" + + @_wraps(torch_func) + def fft_func(input, *args, **kwargs): # pylint: disable=redefined-builtin + return torch_func(_as_fft_tensor(input), *args, **kwargs) + + return fft_func + + +rfft = _wrap_arraylike_fft(_torch.fft.rfft) +irfft = _wrap_arraylike_fft(_torch.fft.irfft) +fftshift = _wrap_arraylike_fft(_torch.fft.fftshift) +ifftshift = _wrap_arraylike_fft(_torch.fft.ifftshift) +fftn = _wrap_arraylike_fft(_torch.fft.fftn) +ifftn = _wrap_arraylike_fft(_torch.fft.ifftn) From c6d6b6d586b73b1da7addb363dc51448e6f0a57b Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:57:34 +0200 Subject: [PATCH 2/3] Test PyTorch FFT array-like inputs --- .../test_pytorch_fftconvolve_contract.py | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/backend_support/test_pytorch_fftconvolve_contract.py b/tests/backend_support/test_pytorch_fftconvolve_contract.py index 6f3b75429..3c2a7c2d1 100644 --- a/tests/backend_support/test_pytorch_fftconvolve_contract.py +++ b/tests/backend_support/test_pytorch_fftconvolve_contract.py @@ -89,7 +89,31 @@ def test_pytorch_fftconvolve_scalar_empty_axes_matches_scipy(): first = backend.asarray(2.0) second = backend.asarray(3.0) - actual = _as_numpy(backend.signal.fftconvolve(first, second, axes=())) + actual = _as_numpy(backend.signal.fftconvolve(first, second, axes=()) expected = scipy_fftconvolve(_as_numpy(first), _as_numpy(second), axes=()) assert np.allclose(actual, expected) + + +def test_pytorch_fft_helpers_accept_array_like_inputs(): + if backend.__backend_name__ != "pytorch": + pytest.skip("PyTorch-specific FFT backend contract") + + matrix = [[1.0, 2.0], [3.0, 4.0]] + assert np.allclose(_as_numpy(backend.fft.fftn(matrix)), np.fft.fftn(matrix)) + + vector = [1.0, 2.0, 3.0] + spectrum = np.fft.rfft(vector) + assert np.allclose(_as_numpy(backend.fft.rfft(vector)), spectrum) + assert np.allclose( + _as_numpy(backend.fft.irfft(spectrum, n=len(vector))), + np.fft.irfft(spectrum, n=len(vector)), + ) + + shift_source = [1, 2, 3, 4] + assert _as_numpy(backend.fft.fftshift(shift_source)).tolist() == np.fft.fftshift( + shift_source + ).tolist() + assert _as_numpy(backend.fft.ifftshift(shift_source)).tolist() == np.fft.ifftshift( + shift_source + ).tolist() From 1ddd4aae8f5facf89eeed300098d962ddf1c7316 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:57:57 +0200 Subject: [PATCH 3/3] Fix PyTorch FFT regression test syntax --- tests/backend_support/test_pytorch_fftconvolve_contract.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/backend_support/test_pytorch_fftconvolve_contract.py b/tests/backend_support/test_pytorch_fftconvolve_contract.py index 3c2a7c2d1..b31f326ee 100644 --- a/tests/backend_support/test_pytorch_fftconvolve_contract.py +++ b/tests/backend_support/test_pytorch_fftconvolve_contract.py @@ -89,7 +89,7 @@ def test_pytorch_fftconvolve_scalar_empty_axes_matches_scipy(): first = backend.asarray(2.0) second = backend.asarray(3.0) - actual = _as_numpy(backend.signal.fftconvolve(first, second, axes=()) + actual = _as_numpy(backend.signal.fftconvolve(first, second, axes=())) expected = scipy_fftconvolve(_as_numpy(first), _as_numpy(second), axes=()) assert np.allclose(actual, expected)