# Source code for torch_dxdt.whittaker

"""
Whittaker-Eilers Global Smoother/Derivative using PyTorch.

Implements the Whittaker-Eilers smoother which uses penalized least squares
to smooth data. Derivatives can be computed by applying difference matrices
to the smoothed signal.
"""

from collections.abc import Sequence

import torch

from .base import Derivative


class WhittakerFunction(torch.autograd.Function):
    """
    Custom autograd function for Whittaker-Eilers smoothing.

    Solves the linear system (I + lmbda * D^T D) z = x, where D is the
    difference matrix of order ``d_order`` (order 2 is the usual choice).
    The system matrix is symmetric positive definite, so it is factored once
    with Cholesky in the forward pass and the factor is reused in the
    backward pass.
    """

    @staticmethod
    def forward(ctx, x, lmbda, d_order):
        """
        Solve (I + lmbda * D^T D) z = x using Cholesky factorization.

        Args:
            ctx: Autograd context for saving tensors.
            x: Input tensor of shape (B, T); each row is smoothed
                independently with the same system matrix.
            lmbda: Smoothing parameter (larger = smoother).
            d_order: Order of the difference penalty (2 for standard
                smoothing). Must be a positive integer.

        Returns:
            Smoothed tensor z of shape (B, T).

        Raises:
            ValueError: If ``d_order`` is not a positive integer.
            RuntimeError: If the Cholesky factorization fails.
        """
        _, T = x.shape  # also enforces that x is 2-D
        device, dtype = x.device, x.dtype

        # A = I + lmbda * D^T D, shared by every row of the batch.
        D = WhittakerFunction._build_difference_matrix(T, d_order, device, dtype)
        A = torch.eye(T, device=device, dtype=dtype) + lmbda * (D.T @ D)

        # A = L L^T; cholesky_ex reports failure via `info` instead of raising.
        L, info = torch.linalg.cholesky_ex(A)
        if info.any():
            raise RuntimeError(
                "Cholesky decomposition failed. Matrix not positive definite."
            )

        # cholesky_solve solves A @ X = RHS for RHS of shape (T, k), so the
        # batch is solved column-wise: z^T = A^{-1} x^T, then transposed back.
        z = torch.cholesky_solve(x.T, L).T

        # Only the factor is needed to apply A^{-1} again in backward.
        ctx.save_for_backward(L)
        return z

    @staticmethod
    def _build_difference_matrix(
        n: int, order: int, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Build a difference matrix of the specified order.

        Row i holds the binomial finite-difference stencil starting at
        column i, e.g. [-1, 1] for order 1, [1, -2, 1] for order 2 and
        [-1, 3, -3, 1] for order 3.

        Args:
            n: Size of the signal.
            order: Order of differences (1 = first diff, 2 = second diff, ...).
            device: Torch device.
            dtype: Torch dtype.

        Returns:
            Difference matrix D of shape (max(n - order, 0), n). When the
            signal has no more than ``order`` samples, D has zero rows, so
            the penalty vanishes and the smoother returns its input.

        Raises:
            ValueError: If ``order`` is smaller than 1.
        """
        if order < 1:
            raise ValueError(
                f"Difference order {order} not supported (must be a positive integer)"
            )
        if n <= order:
            # No complete stencil fits in the signal: nothing to penalize.
            return torch.zeros(0, n, device=device, dtype=dtype)
        # Differencing the identity `order` times along rows yields exactly
        # the stencil rows, without a Python-level loop over n.
        return torch.diff(torch.eye(n, device=device, dtype=dtype), n=order, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass for the linear system solve.

        z = A^{-1} x with A symmetric, so the vector-Jacobian product is
        A^{-1} grad_output. It is computed with the cached Cholesky factor L.
        """
        (L,) = ctx.saved_tensors

        # grad_output is (B, T); solve column-wise as in forward.
        grad_x = torch.cholesky_solve(grad_output.T, L).T

        # No gradients are returned for lmbda or d_order.
        return grad_x, None, None


class Whittaker(Derivative):
    """
    Compute numerical derivatives using Whittaker-Eilers smoothing.

    The Whittaker-Eilers smoother uses penalized least squares with a
    difference penalty to smooth noisy data. Derivatives are computed by
    applying finite differences to the smoothed signal.

    This method is particularly effective for:
    - Strongly noisy data
    - Signals where global smoothness is desired
    - Cases where you want explicit control over smoothness vs. fidelity

    The smoothness is controlled by the parameter lmbda:
    - Small lmbda (~1): Less smoothing, follows data closely
    - Large lmbda (~1e6): Heavy smoothing, very smooth result

    The time grid ``t`` is assumed to be uniformly spaced: only
    ``t[1] - t[0]`` is used as the step size.

    Args:
        lmbda: Smoothing parameter. Larger values give smoother results.
            Typical values range from 1 to 1e6 depending on noise level.
        d_order: Order of the difference penalty (default 2).
            - 1: Penalizes first differences (piecewise constant)
            - 2: Penalizes second differences (piecewise linear, most common)
            - 3: Penalizes third differences (smoother curves)

    Example:
        >>> wh = Whittaker(lmbda=100.0)
        >>> t = torch.linspace(0, 2*torch.pi, 100)
        >>> x = torch.sin(t) + 0.1 * torch.randn(100)
        >>> dx = wh.d(x, t)  # Smoothed derivative
        >>> x_smooth = wh.smooth(x, t)  # Just smoothing
    """

    def __init__(self, lmbda: float = 100.0, d_order: int = 2):
        """
        Args:
            lmbda: Positive smoothing parameter.
            d_order: Difference-penalty order, one of 1, 2 or 3.

        Raises:
            ValueError: If ``lmbda`` is not positive or ``d_order`` is not
                1, 2 or 3.
        """
        if lmbda <= 0:
            raise ValueError("lmbda must be positive")
        if d_order not in (1, 2, 3):
            raise ValueError("d_order must be 1, 2, or 3")
        self.lmbda = lmbda
        self.d_order = d_order

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _to_smoothing_batch(self, x: torch.Tensor, dim: int):
        """
        Move ``dim`` last and flatten ``x`` into a (B, T) batch.

        Returns:
            Tuple ``(x_flat, layout)``. ``layout`` carries what
            :meth:`_from_smoothing_batch` needs to restore the original shape.
        """
        x, original_dim = self._move_dim_to_last(x, dim)
        was_1d = x.ndim == 1
        if was_1d:
            x = x.unsqueeze(0)
        batch_shape = x.shape[:-1]
        return x.reshape(-1, x.shape[-1]), (original_dim, was_1d, batch_shape)

    def _from_smoothing_batch(self, y: torch.Tensor, layout) -> torch.Tensor:
        """Reshape a (B, T) result back to the layout recorded by ``_to_smoothing_batch``."""
        original_dim, was_1d, batch_shape = layout
        y = y.reshape(*batch_shape, y.shape[-1])
        if was_1d:
            y = y.squeeze(0)
        return self._restore_dim(y, original_dim)

    @staticmethod
    def _time_step(t: torch.Tensor) -> float:
        """Return the sample spacing, taken from the first two time points."""
        return (t[1] - t[0]).item()

    def _smooth_internal(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply Whittaker-Eilers smoothing to the input.

        Args:
            x: Input tensor of shape (B, T).

        Returns:
            Smoothed tensor of shape (B, T).
        """
        return WhittakerFunction.apply(x, self.lmbda, self.d_order)

    def _compute_derivative(
        self, z: torch.Tensor, dt: float, order: int = 1
    ) -> torch.Tensor:
        """
        Compute derivative of smoothed signal using finite differences.

        Interior points use central differences. The two endpoints use
        one-sided stencils.

        Args:
            z: Smoothed signal of shape (B, T).
            dt: Time step.
            order: Derivative order (1 or 2).

        Returns:
            Derivative tensor of shape (B, T).

        Raises:
            ValueError: If ``order`` is not 1 or 2.
        """
        if order == 1:
            dz = torch.zeros_like(z)
            # Central differences for interior points
            dz[:, 1:-1] = (z[:, 2:] - z[:, :-2]) / (2 * dt)
            # Forward difference at start
            dz[:, 0] = (z[:, 1] - z[:, 0]) / dt
            # Backward difference at end
            dz[:, -1] = (z[:, -1] - z[:, -2]) / dt
        elif order == 2:
            dz = torch.zeros_like(z)
            # Central second difference for interior
            dz[:, 1:-1] = (z[:, 2:] - 2 * z[:, 1:-1] + z[:, :-2]) / (dt**2)
            # Edges reuse the nearest interior three-point stencil
            dz[:, 0] = (z[:, 2] - 2 * z[:, 1] + z[:, 0]) / (dt**2)
            dz[:, -1] = (z[:, -1] - 2 * z[:, -2] + z[:, -3]) / (dt**2)
        else:
            raise ValueError(f"Derivative order {order} not supported (use 1 or 2)")
        return dz

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def d(self, x: torch.Tensor, t: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """
        Compute the derivative of x with respect to t using Whittaker smoothing.

        Args:
            x: Input tensor of shape (..., T) or (T,)
            t: Time points tensor of shape (T,), assumed uniformly spaced.
            dim: Dimension along which to differentiate. Default -1.

        Returns:
            Derivative tensor of same shape as x.
        """
        return self._compute_order(x, t, order=1, dim=dim)

    def smooth(self, x: torch.Tensor, t: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """
        Compute the smoothed version of x using Whittaker-Eilers filtering.

        Args:
            x: Input tensor of shape (..., T) or (T,)
            t: Time points tensor of shape (T,) (not used but kept for API consistency)
            dim: Dimension along which to smooth. Default -1.

        Returns:
            Smoothed tensor of same shape as x.
        """
        if x.numel() == 0:
            return x.clone()
        x_flat, layout = self._to_smoothing_batch(x, dim)
        return self._from_smoothing_batch(self._smooth_internal(x_flat), layout)

    def _compute_order(
        self, x: torch.Tensor, t: torch.Tensor, order: int, dim: int
    ) -> torch.Tensor:
        """
        Compute a single derivative order.

        Raises:
            ValueError: If ``order`` is greater than 2, or otherwise not
                supported by ``_compute_derivative``.
        """
        if order > 2:
            raise ValueError("Whittaker only supports derivative orders 1 and 2")
        if x.numel() == 0:
            return x.clone()
        x_flat, layout = self._to_smoothing_batch(x, dim)
        dt = self._time_step(t)
        z = self._smooth_internal(x_flat)
        dz = self._compute_derivative(z, dt, order=order)
        return self._from_smoothing_batch(dz, layout)

    def d_orders(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        orders: Sequence[int] = (1, 2),
        dim: int = -1,
    ) -> dict[int, torch.Tensor]:
        """
        Compute multiple derivative orders simultaneously and efficiently.

        The Whittaker smoother computes the smoothed signal once and then
        derives multiple derivative orders from it, avoiding redundant
        smoothing computation.

        Args:
            x: Input tensor of shape (..., T) or (T,)
            t: Time points tensor of shape (T,), assumed uniformly spaced.
            orders: Sequence of derivative orders to compute. Default is (1, 2).
                Order 0 returns the smoothed signal. An empty sequence
                returns an empty dict.
            dim: Dimension along which to differentiate. Default -1.

        Returns:
            Dictionary mapping order -> derivative tensor.
            Each tensor has the same shape as x.

        Raises:
            ValueError: If any requested order is greater than 2.

        Example:
            >>> wh = Whittaker(lmbda=100.0)
            >>> derivs = wh.d_orders(x, t, orders=[0, 1, 2])
            >>> x_smooth = derivs[0]  # Smoothed signal
            >>> dx = derivs[1]        # First derivative
            >>> d2x = derivs[2]       # Second derivative
        """
        orders = list(orders)
        if not orders:
            # Nothing requested; avoid max() on an empty list.
            return {}
        if max(orders) > 2:
            raise ValueError("Whittaker only supports derivative orders up to 2")
        if x.numel() == 0:
            return {order: x.clone() for order in orders}

        x_flat, layout = self._to_smoothing_batch(x, dim)
        dt = self._time_step(t)
        # Smooth the signal ONCE (shared across all orders)
        z = self._smooth_internal(x_flat)

        results = {}
        for order in orders:
            if order == 0:
                dz = z.clone()
            else:
                dz = self._compute_derivative(z, dt, order=order)
            results[order] = self._from_smoothing_batch(dz, layout)
        return results