Source code for torch_dxdt.kalman

"""
Kalman filter/smoother for numerical differentiation.

Uses the Kalman smoother to estimate derivatives under a Brownian motion model.
"""

import torch

from .base import Derivative


[docs] class Kalman(Derivative): """ Compute numerical derivatives using Kalman smoothing. The Kalman smoother finds the maximum likelihood estimator for a process whose derivative follows Brownian motion. This provides smooth, probabilistically-principled derivative estimates. The method minimizes: ||H*x - z||^2 + alpha * ||G*(x, dx)||_Q^2 where z is the noisy observation, x is the smoothed signal, dx is its derivative, and Q encodes the Brownian motion covariance. Args: alpha: Regularization parameter. Larger values give smoother results. If None, uses a default value of 1. Note: This implementation uses a batched linear solver and is fully differentiable, but may be slow for very long time series. Example: >>> kal = Kalman(alpha=1.0) >>> t = torch.linspace(0, 2*torch.pi, 100) >>> x = torch.sin(t) + 0.1 * torch.randn(100) >>> dx = kal.d(x, t) # Kalman-smoothed derivative """
[docs] def __init__(self, alpha: float = 1.0): if alpha is None: alpha = 1.0 self.alpha = alpha
def _solve_kalman( self, z: torch.Tensor, t: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """ Solve the Kalman smoothing problem. Following the formulation from the derivative package: - State is ordered as [dx_0, x_0, dx_1, x_1, ...] - G encodes the dynamics constraint - H extracts the position states Returns: Tuple of (x_hat, x_dot_hat) where x_hat is the smoothed signal and x_dot_hat is the derivative estimate. """ n = t.shape[0] device = t.device dtype = t.dtype # Compute time differences delta_times = t[1:] - t[:-1] # Build Q matrices for each time step # Q = [[dt, dt^2/2], [dt^2/2, dt^3/3]] Qs = [] for dt in delta_times: Q = torch.tensor( [[dt, dt**2 / 2], [dt**2 / 2, dt**3 / 3]], device=device, dtype=dtype ) Qs.append(Q) # Build block diagonal Q inverse Q_inv_blocks = [torch.linalg.inv(Q) for Q in Qs] Q_inv = torch.block_diag(*Q_inv_blocks) * self.alpha # Symmetrize to avoid numerical issues Q_inv = (Q_inv + Q_inv.T) / 2 # Build G matrix (dynamics constraint) # State ordering: [dx_0, x_0, dx_1, x_1, ...] # G_left: -[[1, 0], [dt, 1]] blocks for [dx_i, x_i] # G_right: +I blocks for [dx_{i+1}, x_{i+1}] # G is 2*(n-1) x 2*n G = torch.zeros(2 * (n - 1), 2 * n, device=device, dtype=dtype) for i, dt in enumerate(delta_times): # -[[1, 0], [dt, 1]] block for current state [dx_i, x_i] G[2 * i, 2 * i] = -1 # -dx_i coefficient G[2 * i, 2 * i + 1] = 0 # -x_i coefficient for dx equation G[2 * i + 1, 2 * i] = -dt # -dt*dx_i for x equation G[2 * i + 1, 2 * i + 1] = -1 # -x_i coefficient for x equation # +I block for next state [dx_{i+1}, x_{i+1}] G[2 * i, 2 * (i + 1)] = 1 G[2 * i + 1, 2 * (i + 1) + 1] = 1 # Build H matrix (observation model) # H extracts position x from state [dx, x] - position is at odd indices # H is n x 2n H = torch.zeros(n, 2 * n, device=device, dtype=dtype) for i in range(n): H[i, 2 * i + 1] = 1 # Extract x_i from position 2*i+1 # Solve the system: (H^T H + G^T Q^{-1} G) sol = H^T z lhs = H.T @ H + G.T @ Q_inv @ G # Handle batch dimension if z.ndim == 1: rhs = H.T @ z sol = torch.linalg.solve(lhs, rhs) # Extract x_hat and x_dot_hat # Position x is at odd indices (1, 3, 5, ...) # Velocity dx is at even indices (0, 2, 4, ...) x_hat = sol[1::2] # Odd indices - position x_dot_hat = sol[0::2] # Even indices - velocity else: # z is (B, n) rhs = H.T @ z.T # (2n, B) sol = torch.linalg.solve(lhs, rhs) # (2n, B) x_hat = sol[1::2, :].T # (B, n) x_dot_hat = sol[0::2, :].T # (B, n) return x_hat, x_dot_hat return x_hat, x_dot_hat
[docs] def d(self, x: torch.Tensor, t: torch.Tensor, dim: int = -1) -> torch.Tensor: """ Compute the derivative of x with respect to t using Kalman smoothing. Args: x: Input tensor of shape (..., T) or (T,) t: Time points tensor of shape (T,) dim: Dimension along which to differentiate. Default -1. Returns: Derivative tensor of same shape as x. """ if x.numel() == 0: return x.clone() # Move differentiation dim to last position x, original_dim = self._move_dim_to_last(x, dim) # Handle 1D input was_1d = x.ndim == 1 if was_1d: x = x.unsqueeze(0) # Flatten batch dimensions batch_shape = x.shape[:-1] T = x.shape[-1] x_flat = x.reshape(-1, T) # Solve Kalman smoothing _, x_dot_hat = self._solve_kalman(x_flat, t) # Reshape back dx = x_dot_hat.reshape(*batch_shape, T) # Restore original shape if was_1d: dx = dx.squeeze(0) return self._restore_dim(dx, original_dim)
[docs] def smooth(self, x: torch.Tensor, t: torch.Tensor, dim: int = -1) -> torch.Tensor: """ Compute the smoothed version of x using Kalman smoothing. Args: x: Input tensor of shape (..., T) or (T,) t: Time points tensor of shape (T,) dim: Dimension along which to smooth. Default -1. Returns: Smoothed tensor of same shape as x. """ if x.numel() == 0: return x.clone() # Move dim to last position x, original_dim = self._move_dim_to_last(x, dim) # Handle 1D input was_1d = x.ndim == 1 if was_1d: x = x.unsqueeze(0) # Flatten batch dimensions batch_shape = x.shape[:-1] T = x.shape[-1] x_flat = x.reshape(-1, T) # Solve Kalman smoothing x_hat, _ = self._solve_kalman(x_flat, t) # Reshape back z = x_hat.reshape(*batch_shape, T) # Restore original shape if was_1d: z = z.squeeze(0) return self._restore_dim(z, original_dim)