Source code for torch_dxdt.base

"""
Base class for differentiable numerical differentiation methods.
"""

import abc
from collections.abc import Sequence

import torch


class Derivative(abc.ABC):
    """
    Abstract base class for numerical differentiation methods.

    All differentiation methods should inherit from this class and implement
    the `d` method for computing derivatives.
    """
    @abc.abstractmethod
    def d(self, x: torch.Tensor, t: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """
        Compute the derivative of x with respect to t.

        Args:
            x: Tensor of shape (..., T) containing the signal values.
                Multiple signals can be batched along leading dimensions.
            t: Tensor of shape (T,) containing the time points. Must be
                evenly spaced for most methods.
            dim: The dimension along which to differentiate. Default is -1
                (last dimension).

        Returns:
            Tensor of the same shape as x containing the derivative dx/dt.
        """
        pass
    def d_orders(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        orders: Sequence[int] = (1, 2),
        dim: int = -1,
    ) -> dict[int, torch.Tensor]:
        """
        Compute multiple derivative orders simultaneously.

        This method computes multiple derivative orders efficiently, avoiding
        redundant computation where possible. For methods that support it
        (e.g., SavitzkyGolay), shared computation such as the polynomial fit
        is reused across orders.

        Args:
            x: Tensor of shape (..., T) containing the signal values.
            t: Tensor of shape (T,) containing the time points.
            orders: Sequence of derivative orders to compute. Default is
                (1, 2). Order 0 returns the smoothed signal (if supported).
            dim: The dimension along which to differentiate. Default is -1
                (last dimension).

        Returns:
            Dictionary mapping order -> derivative tensor. Each tensor has
            the same shape as x.

        Example:
            >>> sg = SavitzkyGolay(window_length=11, polyorder=4)
            >>> derivs = sg.d_orders(x, t, orders=[1, 2])
            >>> dx = derivs[1]   # First derivative
            >>> d2x = derivs[2]  # Second derivative
        """
        # Default implementation: compute each requested order independently.
        # Subclasses can override this for more efficient implementations.
        results = {}
        for order in orders:
            if order == 0:
                try:
                    results[0] = self.smooth(x, t, dim=dim)
                except NotImplementedError:
                    # If smoothing is not supported, return the original signal.
                    results[0] = x.clone()
            else:
                # Delegate to _compute_order, which subclasses override to
                # support orders beyond the first.
                results[order] = self._compute_order(x, t, order, dim)
        return results
    def _compute_order(
        self, x: torch.Tensor, t: torch.Tensor, order: int, dim: int
    ) -> torch.Tensor:
        """
        Compute a specific derivative order. Override in subclasses.

        Default implementation raises NotImplementedError if order != 1.
        """
        if order == 1:
            return self.d(x, t, dim=dim)
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support d_orders with order={order}. "
            "Override _compute_order or d_orders for multi-order support."
        )
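    # Illustrative override pattern (a sketch, not part of this module): a
    # hypothetical subclass with analytic access to higher derivatives could
    # support order=2 roughly like so:
    #
    #     def _compute_order(self, x, t, order, dim):
    #         if order == 2:
    #             return self._second_derivative(x, t, dim=dim)  # hypothetical helper
    #         return super()._compute_order(x, t, order, dim)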
    def smooth(self, x: torch.Tensor, t: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """
        Compute the smoothed version of x (if supported by the method).

        Args:
            x: Tensor of shape (..., T) containing the signal values.
            t: Tensor of shape (T,) containing the time points.
            dim: The dimension along which to smooth. Default is -1
                (last dimension).

        Returns:
            Tensor of the same shape as x containing the smoothed signal.

        Raises:
            NotImplementedError: If the method does not support smoothing.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support smoothing. "
            "Only certain global methods (like Kalman, Kernel, Spline) support this."
        )
    def _move_dim_to_last(
        self, x: torch.Tensor, dim: int
    ) -> tuple[torch.Tensor, int]:
        """
        Move the specified dimension to the last position.

        Returns:
            Tuple of (moved tensor, original dim position).
        """
        if dim == -1 or dim == x.ndim - 1:
            return x, dim
        # Normalize negative dim
        dim = dim if dim >= 0 else x.ndim + dim
        # Move dim to the last position
        perm = list(range(x.ndim))
        perm.remove(dim)
        perm.append(dim)
        return x.permute(*perm), dim

    def _restore_dim(self, x: torch.Tensor, original_dim: int) -> torch.Tensor:
        """
        Restore the dimension to its original position.
        """
        if original_dim == -1 or original_dim == x.ndim - 1:
            return x
        # Normalize negative dim
        original_dim = original_dim if original_dim >= 0 else x.ndim + original_dim
        # Move the last dim back to its original position
        perm = list(range(x.ndim - 1))
        perm.insert(original_dim, x.ndim - 1)
        return x.permute(*perm)
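
# ---------------------------------------------------------------------------
# Illustrative sketch only -- not part of torch_dxdt.base. The subclass name
# `CentralDifference` below is hypothetical; it exists to show how the
# Derivative interface is implemented and called. It uses torch.gradient,
# which computes second-order central differences in the interior and
# one-sided differences at the boundaries.
# ---------------------------------------------------------------------------


class CentralDifference(Derivative):
    """Minimal concrete Derivative using torch.gradient (illustrative)."""

    def d(self, x: torch.Tensor, t: torch.Tensor, dim: int = -1) -> torch.Tensor:
        # torch.gradient accepts coordinate tensors via `spacing`, so t need
        # not be evenly spaced here.
        return torch.gradient(x, spacing=(t,), dim=dim)[0]


if __name__ == "__main__":
    t = torch.linspace(0, 2 * torch.pi, 100)
    x = torch.sin(t)

    method = CentralDifference()
    dx = method.d(x, t)  # approximately cos(t)
    assert torch.allclose(dx, torch.cos(t), atol=1e-2)

    # Order 0 falls back to x.clone() because smooth() is not implemented;
    # order 1 is routed through _compute_order -> d().
    derivs = method.d_orders(x, t, orders=[0, 1])
    assert torch.equal(derivs[0], x)

    # The dim-handling helpers round-trip a tensor whose time axis is not last:
    xb = torch.stack([x, 2 * x], dim=-1)           # shape (100, 2), time along dim 0
    moved, orig = method._move_dim_to_last(xb, 0)  # shape (2, 100)
    restored = method._restore_dim(moved, orig)    # shape (100, 2) again
    assert torch.equal(restored, xb)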