55try : import cvxpy
66except ImportError : pass
77
8- from pynumdiff .utils .utility import huber_const
8+ from pynumdiff .utils .utility import huber_const , wrap_angle
99
1010
11- def kalman_filter (y , xhat0 , P0 , A , Q , C , R , B = None , u = None , save_P = True ):
11+ def kalman_filter (y , xhat0 , P0 , A , Q , C , R , B = None , u = None , save_P = True , innovation_fn = None ):
1212 """Run the forward pass of a Kalman filter. Expects discrete-time matrices; use :func:`scipy.linalg.expm`
1313 in the caller to convert from continuous time if needed.
1414
@@ -24,6 +24,9 @@ def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True):
2424 :param np.array u: optional control inputs, stacked in the direction of axis 0
2525 :param bool save_P: whether to save history of error covariance and a priori state estimates, used with rts
2626 smoothing but nonstandard to compute for ordinary filtering
27+ :param callable innovation_fn: optional function :code:`(y_n, pred)` returning the innovation :code:`y_n - pred`,
28+ where :code:`pred = C @ xhat_`. When :code:`None` (default), standard subtraction is used. Use this to handle
29+ circular quantities, e.g. :code:`lambda y, pred: wrap_angle(y - pred)` for angular measurements in radians.
2730
2831 :return: - **xhat_pre** (np.array) -- a priori estimates of xhat, with axis=0 the batch dimension, so xhat[n] gets the nth step
2932 - **xhat_post** (np.array) -- a posteriori estimates of xhat
@@ -57,7 +60,8 @@ def kalman_filter(y, xhat0, P0, A, Q, C, R, B=None, u=None, save_P=True):
5760 P = P_ .copy ()
5861 if not np .isnan (y [n ]): # handle missing data
5962 K = P_ @ C .T @ np .linalg .inv (C @ P_ @ C .T + R )
60- xhat += K @ (y [n ] - C @ xhat_ )
63+ innovation = innovation_fn (y [n ], C @ xhat_ ) if innovation_fn is not None else y [n ] - C @ xhat_
64+ xhat += K @ innovation
6165 P -= K @ C @ P_
6266 # the [n]th index of pre variables holds _{n|n-1} info; the [n]th index of post variables holds _{n|n} info
6367 xhat_post [n ] = xhat
@@ -94,7 +98,7 @@ def rts_smooth(A, xhat_pre, xhat_post, P_pre, P_post, compute_P_smooth=True):
9498 return xhat_smooth if not compute_P_smooth else (xhat_smooth , P_smooth )
9599
96100
97- def rtsdiff (x , dt_or_t , order , log_qr_ratio , forwardbackward , axis = 0 ):
101+ def rtsdiff (x , dt_or_t , order , log_qr_ratio , forwardbackward , axis = 0 , circular = False ):
98102 """Perform Rauch-Tung-Striebel smoothing with a naive constant derivative model. Makes use of :code:`kalman_filter`
99103 and :code:`rts_smooth`, which are made public. :code:`constant_X` methods in this module call this function.
100104
@@ -109,6 +113,9 @@ def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward, axis=0):
109113 :param bool forwardbackward: indicates whether to run smoother forwards and backwards
110114 (usually achieves better estimate at end points)
111115 :param int axis: data dimension along which differentiation is performed
116+ :param bool circular: if :code:`True`, treat the measured quantity as a circular variable in radians, wrapping
117+ the innovation to :math:`[-\\ pi, \\ pi]`. The input :code:`x` must be in radians; convert degrees with
118+ :code:`np.deg2rad` first. Default :code:`False`.
112119
113120 :return: - **x_hat** (np.array) -- estimated (smoothed) x, same shape as input :code:`x`
114121 - **dxdt_hat** (np.array) -- estimated derivative of x, same shape as input :code:`x`
@@ -140,22 +147,24 @@ def rtsdiff(x, dt_or_t, order, log_qr_ratio, forwardbackward, axis=0):
140147 Q_d [n ] = eM [:order + 1 , order + 1 :] @ A_d [n ].T
141148 if forwardbackward : A_d_bwd = np .linalg .inv (A_d [::- 1 ]) # properly broadcasts, taking inv of each stacked 2D array
142149
150+ innovation_fn = (lambda y , pred : wrap_angle (y - pred )) if circular else None # wrap innovation for circular variables
151+
143152 x_hat = np .empty_like (x ); dxdt_hat = np .empty_like (x )
144153 if forwardbackward : w = np .linspace (0 , 1 , N ) # weights used to combine forward and backward results
145154
146155 for vec_idx in np .ndindex (x .shape [:axis ] + x .shape [axis + 1 :]): # works properly for 1D case too
147156 s = vec_idx [:axis ] + (slice (None ),) + vec_idx [axis :] # for indexing the vector we wish to differentiate
148157 xhat0 = np .zeros (order + 1 ); xhat0 [0 ] = x [s ][0 ] if not np .isnan (x [s ][0 ]) else 0 # The first estimate is the first seen state. See #110
149158
150- xhat_pre , xhat_post , P_pre , P_post = kalman_filter (x [s ], xhat0 , P0 , A_d , Q_d , C , R )
159+ xhat_pre , xhat_post , P_pre , P_post = kalman_filter (x [s ], xhat0 , P0 , A_d , Q_d , C , R , innovation_fn = innovation_fn )
151160 xhat_smooth = rts_smooth (A_d , xhat_pre , xhat_post , P_pre , P_post , compute_P_smooth = False )
152161 x_hat [s ] = xhat_smooth [:,0 ] # first dimension is time, so slice first and second states at all times
153162 dxdt_hat [s ] = xhat_smooth [:,1 ]
154163
155164 if forwardbackward :
156165 xhat0 [0 ] = x [s ][- 1 ] if not np .isnan (x [s ][- 1 ]) else 0
157166 xhat_pre , xhat_post , P_pre , P_post = kalman_filter (x [s ][::- 1 ], xhat0 , P0 , A_d_bwd ,
158- Q_d if Q_d .ndim == 2 else Q_d [::- 1 ], C , R ) # Use same Q matrices as before, because noise should still grow in reverse time
167+ Q_d if Q_d .ndim == 2 else Q_d [::- 1 ], C , R , innovation_fn = innovation_fn ) # Use same Q matrices as before, because noise should still grow in reverse time
159168 xhat_smooth = rts_smooth (A_d_bwd , xhat_pre , xhat_post , P_pre , P_post , compute_P_smooth = False )
160169
161170 x_hat [s ] = x_hat [s ] * w + xhat_smooth [:, 0 ][::- 1 ] * (1 - w )
0 commit comments