Skip to content

Commit 25739d7

Browse files
pavelkomarovclaude
andcommitted
Clean up circular variable support in rtsdiff
Replaces the per-dimension circular_vars/circular_units approach with a generic innovation_fn parameter on kalman_filter and a simple circular=False boolean on rtsdiff. Adds wrap_angle (radians-only) to utility.py and a test for wrapping angle differentiation. Addresses #178. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 37b39af commit 25739d7

3 files changed

Lines changed: 45 additions & 6 deletions

File tree

pynumdiff/kalman_smooth.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
try: import cvxpy
66
except ImportError: pass
77

8-
from pynumdiff.utils.utility import huber_const
8+
from pynumdiff.utils.utility import huber_const, wrap_angle
99

1010

11-
def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True):
11+
def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True, innovation_fn=None):
1212
"""Run the forward pass of a Kalman filter. Expects discrete-time matrices; use :func:`scipy.linalg.expm`
1313
in the caller to convert from continuous time if needed.
1414
@@ -24,6 +24,9 @@ def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True):
2424
:param np.array u: optional control inputs, stacked in the direction of axis 0
2525
:param bool save_P: whether to save history of error covariance and a priori state estimates, used with rts
2626
smoothing but nonstandard to compute for ordinary filtering
27+
:param callable innovation_fn: optional function :code:`(y_n, pred)` returning the innovation :code:`y_n - pred`,
28+
where :code:`pred = C @ xhat_`. When :code:`None` (default), standard subtraction is used. Use this to handle
29+
circular quantities, e.g. :code:`lambda y, pred: wrap_angle(y - pred)` for angular measurements in radians.
2730
2831
:return: - **xhat_pre** (np.array) -- a priori estimates of xhat, with axis=0 the batch dimension, so xhat[n] gets the nth step
2932
- **xhat_post** (np.array) -- a posteriori estimates of xhat
@@ -57,7 +60,8 @@ def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True):
5760
P = P_.copy()
5861
if not np.isnan(y[n]): # handle missing data
5962
K = P_ @ C.T @ np.linalg.inv(C @ P_ @ C.T + R)
60-
xhat += K @ (y[n] - C @ xhat_)
63+
innovation = innovation_fn(y[n], C @ xhat_) if innovation_fn is not None else y[n] - C @ xhat_
64+
xhat += K @ innovation
6165
P -= K @ C @ P_
6266
# the [n]th index of pre variables holds _{n|n-1} info; the [n]th index of post variables holds _{n|n} info
6367
xhat_post[n] = xhat
@@ -94,7 +98,7 @@ def rts_smooth(A, xhat_pre, xhat_post, P_pre, P_post, compute_P_smooth=True):
9498
return xhat_smooth if not compute_P_smooth else (xhat_smooth, P_smooth)
9599

96100

97-
def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward, axis=0):
101+
def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward, axis=0, circular=False):
98102
"""Perform Rauch-Tung-Striebel smoothing with a naive constant derivative model. Makes use of :code:`kalman_filter`
99103
and :code:`rts_smooth`, which are made public. :code:`constant_X` methods in this module call this function.
100104
@@ -109,6 +113,9 @@ def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward, axis=0):
109113
:param bool forwardbackward: indicates whether to run smoother forwards and backwards
110114
(usually achieves better estimate at end points)
111115
:param int axis: data dimension along which differentiation is performed
116+
:param bool circular: if :code:`True`, treat the measured quantity as a circular variable in radians, wrapping
117+
the innovation to :math:`[-\\pi, \\pi]`. The input :code:`x` must be in radians; convert degrees with
118+
:code:`np.deg2rad` first. Default :code:`False`.
112119
113120
:return: - **x_hat** (np.array) -- estimated (smoothed) x, same shape as input :code:`x`
114121
- **dxdt_hat** (np.array) -- estimated derivative of x, same shape as input :code:`x`
@@ -140,22 +147,24 @@ def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward, axis=0):
140147
Q_d[n] = eM[:order+1, order+1:] @ A_d[n].T
141148
if forwardbackward: A_d_bwd = np.linalg.inv(A_d[::-1]) # properly broadcasts, taking inv of each stacked 2D array
142149

150+
innovation_fn = (lambda y, pred: wrap_angle(y - pred)) if circular else None # wrap innovation for circular variables
151+
143152
x_hat = np.empty_like(x); dxdt_hat = np.empty_like(x)
144153
if forwardbackward: w = np.linspace(0, 1, N) # weights used to combine forward and backward results
145154

146155
for vec_idx in np.ndindex(x.shape[:axis] + x.shape[axis+1:]): # works properly for 1D case too
147156
s = vec_idx[:axis] + (slice(None),) + vec_idx[axis:] # for indexing the vector we wish to differentiate
148157
xhat0 = np.zeros(order+1); xhat0[0] = x[s][0] if not np.isnan(x[s][0]) else 0 # The first estimate is the first seen state. See #110
149158

150-
xhat_pre, xhat_post, P_pre, P_post = kalman_filter(x[s], xhat0, P0, A_d, Q_d, C, R)
159+
xhat_pre, xhat_post, P_pre, P_post = kalman_filter(x[s], xhat0, P0, A_d, Q_d, C, R, innovation_fn=innovation_fn)
151160
xhat_smooth = rts_smooth(A_d, xhat_pre, xhat_post, P_pre, P_post, compute_P_smooth=False)
152161
x_hat[s] = xhat_smooth[:,0] # first dimension is time, so slice first and second states at all times
153162
dxdt_hat[s] = xhat_smooth[:,1]
154163

155164
if forwardbackward:
156165
xhat0[0] = x[s][-1] if not np.isnan(x[s][-1]) else 0
157166
xhat_pre, xhat_post, P_pre, P_post = kalman_filter(x[s][::-1], xhat0, P0, A_d_bwd,
158-
Q_d if Q_d.ndim == 2 else Q_d[::-1], C, R) # Use same Q matrices as before, because noise should still grow in reverse time
167+
Q_d if Q_d.ndim == 2 else Q_d[::-1], C, R, innovation_fn=innovation_fn) # Use same Q matrices as before, because noise should still grow in reverse time
159168
xhat_smooth = rts_smooth(A_d_bwd, xhat_pre, xhat_post, P_pre, P_post, compute_P_smooth=False)
160169

161170
x_hat[s] = x_hat[s] * w + xhat_smooth[:, 0][::-1] * (1-w)

pynumdiff/tests/test_diff_methods.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,24 @@ def test_multidimensionality(multidim_method_and_params, request):
405405
legend = ax3.legend(bbox_to_anchor=(0.7, 0.8)); legend.legend_handles[0].set_facecolor(pyplot.cm.viridis(0.6))
406406
fig.suptitle(f'{diff_method.__name__}', fontsize=16)
407407

408+
def test_circular_rtsdiff():
409+
"""Ensure rtsdiff with circular=True correctly differentiates a wrapping angle signal in radians"""
410+
np.random.seed(42)
411+
N = 200
412+
dt_circ = 0.05
413+
t_circ = np.arange(N) * dt_circ
414+
true_dtheta = 2.0 # constant angular velocity in rad/s
415+
theta_true = true_dtheta * t_circ # linearly increasing angle, crosses 2*pi boundaries
416+
theta_noisy = np.angle(np.exp(1j * (theta_true + 0.1 * np.random.randn(N)))) # wrap to [-pi, pi] and add noise
417+
418+
_, dxdt_hat = rtsdiff(theta_noisy, dt_circ, order=1, log_qr_ratio=1, forwardbackward=True, circular=True)
419+
420+
# The interior of the signal (away from endpoints) should recover the true angular velocity well
421+
interior = slice(10, N-10)
422+
rmse = np.sqrt(np.mean((dxdt_hat[interior] - true_dtheta)**2))
423+
assert rmse < 0.5, f"RMSE of angular velocity estimate too large: {rmse:.3f} rad/s"
424+
425+
408426
# List of methods that can handle missing values
409427
nan_methods_and_params = [
410428
(splinediff, {'degree': 5, 's': 2}),

pynumdiff/utils/utility.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,18 @@
77
from scipy.ndimage import convolve1d
88

99

10+
def wrap_angle(angle):
11+
"""Wrap an angle (in radians) to the range [-pi, pi].
12+
13+
:param float or np.array angle: angle(s) in radians to wrap
14+
:return: (float or np.array) -- wrapped angle(s) in [-pi, pi]
15+
16+
.. note::
17+
Only radians are supported. Convert degrees to radians with :code:`np.deg2rad` before using this function.
18+
"""
19+
return (angle + np.pi) % (2*np.pi) - np.pi
20+
21+
1022
def huber_const(M):
1123
"""Scale that makes :code:`sum(huber())` interpolate :math:`\\sqrt{2}\\|\\cdot\\|_1` and :math:`\\frac{1}{2}\\|\\cdot\\|_2^2`,
1224
from https://jmlr.org/papers/volume14/aravkin13a/aravkin13a.pdf, with correction for missing sqrt. Here :code:`huber`

0 commit comments

Comments
 (0)