diff --git a/src/pyrecest/_backend/pytorch/fft.py b/src/pyrecest/_backend/pytorch/fft.py index 262d6800c..e3cc1cd51 100644 --- a/src/pyrecest/_backend/pytorch/fft.py +++ b/src/pyrecest/_backend/pytorch/fft.py @@ -1,10 +1,29 @@ # For ffts. Added for pyrecest. +from functools import wraps as _wraps + import torch as _torch -from torch.fft import ( - fftn, - fftshift, - ifftn, - ifftshift, - irfft, - rfft, -) + +from ._common import array as _array + + +def _as_fft_tensor(value): + """Convert array-like FFT inputs to torch tensors.""" + return value if _torch.is_tensor(value) else _array(value) + + +def _wrap_arraylike_fft(torch_func): + """Return a PyRecEst-compatible FFT helper accepting array-like input.""" + + @_wraps(torch_func) + def fft_func(input, *args, **kwargs): # pylint: disable=redefined-builtin + return torch_func(_as_fft_tensor(input), *args, **kwargs) + + return fft_func + + +rfft = _wrap_arraylike_fft(_torch.fft.rfft) +irfft = _wrap_arraylike_fft(_torch.fft.irfft) +fftshift = _wrap_arraylike_fft(_torch.fft.fftshift) +ifftshift = _wrap_arraylike_fft(_torch.fft.ifftshift) +fftn = _wrap_arraylike_fft(_torch.fft.fftn) +ifftn = _wrap_arraylike_fft(_torch.fft.ifftn) diff --git a/tests/backend_support/test_pytorch_fftconvolve_contract.py b/tests/backend_support/test_pytorch_fftconvolve_contract.py index 6f3b75429..b31f326ee 100644 --- a/tests/backend_support/test_pytorch_fftconvolve_contract.py +++ b/tests/backend_support/test_pytorch_fftconvolve_contract.py @@ -93,3 +93,27 @@ def test_pytorch_fftconvolve_scalar_empty_axes_matches_scipy(): expected = scipy_fftconvolve(_as_numpy(first), _as_numpy(second), axes=()) assert np.allclose(actual, expected) + + +def test_pytorch_fft_helpers_accept_array_like_inputs(): + if backend.__backend_name__ != "pytorch": + pytest.skip("PyTorch-specific FFT backend contract") + + matrix = [[1.0, 2.0], [3.0, 4.0]] + assert np.allclose(_as_numpy(backend.fft.fftn(matrix)), np.fft.fftn(matrix)) + + vector = [1.0, 2.0, 3.0] + spectrum = np.fft.rfft(vector) + assert np.allclose(_as_numpy(backend.fft.rfft(vector)), spectrum) + assert np.allclose( + _as_numpy(backend.fft.irfft(spectrum, n=len(vector))), + np.fft.irfft(spectrum, n=len(vector)), + ) + + shift_source = [1, 2, 3, 4] + assert _as_numpy(backend.fft.fftshift(shift_source)).tolist() == np.fft.fftshift( + shift_source + ).tolist() + assert _as_numpy(backend.fft.ifftshift(shift_source)).tolist() == np.fft.ifftshift( + shift_source + ).tolist()