Source code for torch_dxdt.spline_method

"""
Spline-based numerical differentiation using PyTorch.

Uses differentiable B-spline fitting for smooth derivative estimation.
"""

import torch

from .base import Derivative


class Spline(Derivative):
    """
    Compute numerical derivatives of a smoothed signal.

    The signal is first smoothed with a Whittaker smoother (a discrete
    analogue of a smoothing spline) by solving ``(I + s * D^T D) z = x``,
    where ``D`` is the second-difference matrix. The derivative of ``z`` is
    then estimated with second-order central differences on the actual
    sample times ``t`` (one-sided differences at the two end points).
    Every step is a differentiable torch op (``torch.linalg.solve``,
    slicing, ``torch.cat``), so gradients can flow back to ``x`` and ``t``.

    Args:
        s: Smoothing parameter (must be >= 0). Larger values give smoother
            results. s=0 reproduces the data points exactly.
        order: Nominal spline order, default 3. In this implementation it is
            only used by :meth:`d` to require more than ``order`` samples.

    Note:
        The current implementation uses a simplified Whittaker smoother
        approach rather than true B-splines, as it's more amenable to
        differentiable implementation in PyTorch. The smoothing penalty is
        built on sample indices, so it implicitly assumes roughly uniform
        spacing; only the differentiation step accounts for non-uniform
        ``t``.

    Example:
        >>> spl = Spline(s=0.01)
        >>> t = torch.linspace(0, 2*torch.pi, 100)
        >>> x = torch.sin(t) + 0.1 * torch.randn(100)
        >>> dx = spl.d(x, t)  # Smoothed derivative
    """

    def __init__(self, s: float = 0.01, order: int = 3):
        # A negative weight can make I + s * D^T D indefinite or singular, so
        # the "smoothing" solve would be meaningless; reject it up front.
        if s < 0:
            raise ValueError(f"Smoothing parameter s must be >= 0, got {s}")
        self.s = s
        self.order = order

    def _build_difference_matrix(
        self, n: int, order: int, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Build the ``max(n - order, 0) x n`` finite-difference matrix of given order.

        Row ``i`` applied to a vector ``v`` gives the ``order``-th forward
        difference starting at index ``i`` (order=1: ``v[i+1] - v[i]``;
        order=2: ``v[i] - 2 v[i+1] + v[i+2]``; higher orders compose first
        differences). Differencing the rows of the identity builds the whole
        matrix in one vectorized call instead of element-by-element writes,
        each of which is a separate kernel launch on a GPU.
        """
        eye = torch.eye(n, device=device, dtype=dtype)
        return torch.diff(eye, n=order, dim=0)

    def _smooth(
        self, x: torch.Tensor, lmbda: float, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Apply Whittaker smoothing along the last dim: solve ``(I + lmbda * D^T D) z = x``.

        ``x`` may have any number of leading (batch) dims. All series share
        the same system matrix, so they are passed to a single
        ``torch.linalg.solve`` call as the columns of one right-hand side.

        Returns:
            Smoothed tensor with the same shape as ``x``.
        """
        n = x.shape[-1]
        D = self._build_difference_matrix(n, 2, device, dtype)
        A = torch.eye(n, device=device, dtype=dtype) + lmbda * (D.T @ D)
        rhs = x.reshape(-1, n).T  # (n, batch): one column per series
        z = torch.linalg.solve(A, rhs).T
        return z.reshape(x.shape)

    def _differentiate_smooth(
        self, z: torch.Tensor, dt, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Differentiate ``z`` along its last dim with finite differences.

        Args:
            z: Smoothed signal of shape (..., T) with T >= 2.
            dt: Sample spacing: a scalar (uniform grid) or a 1-D tensor of the
                T - 1 interval widths ``t[i+1] - t[i]``.
            device: Device on which to place the spacing tensor.
            dtype: Unused; kept for signature compatibility.

        Returns:
            Tensor with the same shape as ``z``. Interior points use the
            second-order central difference for (possibly) non-uniform
            spacing, written as the width-weighted average of the two
            adjacent one-sided slopes; on a uniform grid this equals
            ``(z[i+1] - z[i-1]) / (2 * dt)``. The two end points use
            one-sided first differences.
        """
        # Spacing is always real, even if z is complex.
        real_dtype = z.real.dtype if z.is_complex() else z.dtype
        h = torch.as_tensor(dt, device=device).to(real_dtype)
        if h.ndim == 0:
            h = h.expand(z.shape[-1] - 1)

        slopes = (z[..., 1:] - z[..., :-1]) / h  # one slope per interval
        h_left, h_right = h[:-1], h[1:]
        # Weight each neighbouring slope by the width of the *opposite*
        # interval: (h_r*s_l + h_l*s_r) / (h_l + h_r) is second-order accurate.
        interior = (h_right * slopes[..., :-1] + h_left * slopes[..., 1:]) / (
            h_left + h_right
        )
        return torch.cat([slopes[..., :1], interior, slopes[..., -1:]], dim=-1)

    def _apply_along_last(self, x: torch.Tensor, dim: int, fn) -> torch.Tensor:
        """
        Run ``fn`` on ``x`` viewed as a (batch, T) matrix, with ``dim`` as T.

        Moves ``dim`` to the end, flattens the leading dims, applies ``fn``
        (which must return a tensor of the same shape), and restores the
        original layout. Inputs that are neither floating point nor complex
        (e.g. integer tensors) are promoted to the default float dtype first,
        because ``torch.linalg.solve`` cannot operate on them.
        """
        if not (x.is_floating_point() or x.is_complex()):
            x = x.to(torch.get_default_dtype())
        x, original_dim = self._move_dim_to_last(x, dim)
        out = fn(x.reshape(-1, x.shape[-1]))
        return self._restore_dim(out.reshape(x.shape), original_dim)

    def d(self, x: torch.Tensor, t: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """
        Compute the derivative of x with respect to t using spline smoothing.

        Args:
            x: Input tensor of shape (..., T) or (T,)
            t: Time points with one entry per sample along ``dim`` (shape (T,);
                other shapes with T elements are flattened). Need not be
                uniformly spaced.
            dim: Dimension along which to differentiate. Default -1.

        Returns:
            Derivative tensor of same shape as x.

        Raises:
            TypeError: If x has ``<= order`` samples, or fewer than 2 samples,
                along ``dim``.
            ValueError: If ``t`` does not have exactly one entry per sample.
        """
        if x.numel() == 0:
            return x.clone()

        time_len = x.shape[dim]
        # Kept as TypeError so existing callers that catch it keep working.
        if time_len <= self.order:
            raise TypeError(
                f"Input length ({time_len}) must be > order ({self.order})"
            )
        if time_len < 2:
            raise TypeError(
                f"Input length ({time_len}) must be at least 2 to differentiate"
            )

        t = torch.as_tensor(t, device=x.device).reshape(-1)
        if t.numel() != time_len:
            raise ValueError(
                f"t has {t.numel()} points but x has {time_len} samples along dim {dim}"
            )
        # Use every interval width rather than only t[1] - t[0], so
        # non-uniform grids are differentiated correctly. Staying in torch
        # (no .item()) avoids a host sync and keeps t in the autograd graph.
        steps = torch.diff(t)

        def smooth_and_differentiate(x_flat: torch.Tensor) -> torch.Tensor:
            z = self._smooth(x_flat, self.s, x_flat.device, x_flat.dtype)
            return self._differentiate_smooth(z, steps, x_flat.device, x_flat.dtype)

        return self._apply_along_last(x, dim, smooth_and_differentiate)

    def smooth(self, x: torch.Tensor, t: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """
        Compute the smoothed version of x without differentiation.

        Args:
            x: Input tensor of shape (..., T) or (T,)
            t: Time points tensor of shape (T,). Not used by the computation
                (the smoothing penalty is index-based); accepted for API
                consistency with :meth:`d`.
            dim: Dimension along which to smooth. Default -1.

        Returns:
            Smoothed tensor of same shape as x.
        """
        if x.numel() == 0:
            return x.clone()
        return self._apply_along_last(
            x,
            dim,
            lambda x_flat: self._smooth(x_flat, self.s, x_flat.device, x_flat.dtype),
        )