# Source code for torch_dxdt.spectral

"""
Spectral differentiation using PyTorch FFT.

Computes derivatives by multiplying by frequency in Fourier space.
"""

import math

import torch

from .base import Derivative


class Spectral(Derivative):
    """
    Compute numerical derivatives using spectral (Fourier) methods.

    Transforms to Fourier space, multiplies by (i * omega)^order, and
    transforms back. This method is very accurate for smooth, periodic data.

    Args:
        order: Order of the derivative. Default is 1.
        filter_func: Optional function to filter frequencies before
            differentiation. It receives the integer wavenumbers
            0, 1, ..., T // 2 (one per rfft bin) and returns weights that
            are multiplied into the spectral multiplier.
            Example: lambda k: (torch.abs(k) < 10).float()

    Note:
        - Assumes the data is periodic over the sample interval.
        - Assumes uniform sampling: only the spacing t[1] - t[0] is used.
        - Works best for smooth, band-limited signals.
        - For non-periodic data, consider windowing or other methods.

    Example:
        >>> spec = Spectral(order=1)
        >>> t = torch.arange(100) * (2 * torch.pi / 100)  # one period, no endpoint
        >>> x = torch.sin(t)
        >>> dx = spec.d(x, t)  # Should approximate cos(t)
    """

    def __init__(self, order: int = 1, filter_func=None):
        self.order = order
        self.filter_func = filter_func

    def d(self, x: torch.Tensor, t: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """
        Compute the derivative of x with respect to t using spectral methods.

        Args:
            x: Input tensor of shape (..., T) or (T,)
            t: Time points tensor of shape (T,). Only the spacing
                t[1] - t[0] is used, so samples are assumed uniform.
            dim: Dimension along which to differentiate. Default -1.

        Returns:
            Derivative tensor of same shape as x.
        """
        if x.numel() == 0:
            return x.clone()

        # Move differentiation dim to last position
        x, original_dim = self._move_dim_to_last(x, dim)

        # Handle 1D input by adding a leading batch dimension
        was_1d = x.ndim == 1
        if was_1d:
            x = x.unsqueeze(0)

        T = x.shape[-1]

        # Build the wavenumbers in the input's real precision. rfftfreq
        # otherwise defaults to float32, which would degrade float64 results.
        real_dtype = (
            x.dtype if x.is_floating_point() else torch.get_default_dtype()
        )
        # With a single sample only the DC bin exists (omega == 0), so the
        # spacing does not matter; avoid indexing t[1] in that case.
        dt = float(t[1] - t[0]) if T > 1 else 1.0
        # For rfft, we only get non-negative frequencies up to Nyquist
        freqs = torch.fft.rfftfreq(T, d=dt, dtype=real_dtype, device=x.device)
        omega = 2 * math.pi * freqs

        # FFT
        x_fft = torch.fft.rfft(x, dim=-1)

        # (i * omega)^order == i^order * omega^order. Raising the real omega
        # avoids complex pow at omega == 0 (the DC bin), and multiplying by a
        # Python complex promotes to the matching complex dtype.
        # For odd orders with even T, the Nyquist term becomes purely
        # imaginary and irfft ignores it, which is the standard treatment.
        multiplier = (1j ** self.order) * omega ** self.order

        # Apply filter if provided
        if self.filter_func is not None:
            # Integer wavenumbers 0..T//2, one per rfft bin
            k = torch.arange(omega.shape[0], device=x.device, dtype=x.dtype)
            multiplier = multiplier * self.filter_func(k)

        # Differentiate in frequency space, then invert back to length T
        dx = torch.fft.irfft(x_fft * multiplier, n=T, dim=-1)

        # Restore original shape
        if was_1d:
            dx = dx.squeeze(0)

        return self._restore_dim(dx, original_dim)