forked from MorinWang/SONATA
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_SONATA.py
More file actions
2097 lines (1722 loc) · 91.5 KB
/
model_SONATA.py
File metadata and controls
2097 lines (1722 loc) · 91.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
SONATA Model
Integrates:
- Martingale representation theorem for coreset selection
- Optimal stopping problem for data point selection
- Enhanced multi-scale time weighting with Itô formula
- Improved state estimation and prediction methods
"""
import numpy as np
import torch
import logging
import bisect
import tensorly as tl
from collections import deque
import math
import sys
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Set tensorly backend
tl.set_backend("pytorch")
# Constants
JITTER = 1e-4
class DynamicCoreSetTensorFactorization:
    """
    Base class for Dynamic CoreSet Tensor Factorization
    Implements core methods for data point coreset selection and streaming tensor factorization
    """
    def __init__(self, hyper_dict, data_dict):
        """Initialize the model.

        Args:
            hyper_dict: hyper-parameters. Required keys: "device", "R_U", "v",
                "a0", "b0", "DAMPING", "DAMPING_tau". Optional keys configure
                the coreset manager, the multi-scale weighting module, and the
                LDS kernel (see `_create_lds_params`).
            data_dict: data description. Required keys: "ndims", "tr_ind",
                "tr_y", "te_ind", "te_y", "tr_T_disct", "te_T_disct",
                "time_uni".
        """
        # Basic parameters
        self.device = hyper_dict["device"]
        self.R_U = hyper_dict["R_U"]  # factor rank
        # Prior parameters
        self.v = hyper_dict["v"] # Prior variance
        self.a0 = hyper_dict["a0"]  # Gamma prior shape for the noise precision tau
        self.b0 = hyper_dict["b0"]  # Gamma prior rate for the noise precision tau
        self.DAMPING = hyper_dict["DAMPING"]  # damping for factor messages
        self.DAMPING_tau = hyper_dict["DAMPING_tau"]  # damping for tau messages
        # Data related parameters
        self.ndims = data_dict["ndims"]
        self.nmods = len(self.ndims)
        self.tr_ind = data_dict["tr_ind"]
        self.tr_y = torch.tensor(data_dict["tr_y"]).to(self.device) # N*1
        self.te_ind = data_dict["te_ind"]
        self.te_y = torch.tensor(data_dict["te_y"]).to(self.device) # N*1
        self.train_time_ind = data_dict["tr_T_disct"] # N_train*1
        self.test_time_ind = data_dict["te_T_disct"] # N_test*1
        self.unique_train_time = list(np.unique(self.train_time_ind))
        self.time_uni = data_dict["time_uni"] # N_time*1
        self.N_time = len(self.time_uni)
        # Create LDS parameters
        self.lds_params = self._create_lds_params(hyper_dict, data_dict)
        # Store LDS trajectories for each mode and factor, keyed by (mode, uid)
        self.factor_dynamics = {}
        # Initialize coreset manager - will be replaced in subclasses
        if "coreset_manager" in hyper_dict:
            self.coreset_manager = hyper_dict["coreset_manager"]
        else:
            # Imported here to avoid circular imports
            from utils_martingale_coreset import DataPointCoreSetManager
            self.coreset_manager = DataPointCoreSetManager(
                max_size=hyper_dict.get("coreset_max_size", 100),
                initial_threshold=hyper_dict.get("coreset_threshold", 0.5),
                adaptive_threshold=hyper_dict.get("adaptive_threshold", True),
                importance_weights=hyper_dict.get("importance_weights", (0.4, 0.3, 0.3)),
                device=self.device,
                exploration_rate=hyper_dict.get("initial_exploration_rate", 0.9),
                decay_rate=hyper_dict.get("exploration_decay_rate", 0.1)
            )
        # Initialize multi-scale weighting mechanism
        if "multi_scale_weighting" in hyper_dict:
            self.multi_scale_weighting = hyper_dict["multi_scale_weighting"]
        else:
            # Imported here to avoid circular imports
            from utils_martingale_coreset import MultiScaleWeighting
            self.multi_scale_weighting = MultiScaleWeighting(
                num_scales=hyper_dict.get("num_time_scales", 3),
                hidden_dim=hyper_dict.get("scale_hidden_dim", 32),
                device=self.device,
                temperature=hyper_dict.get("attention_temperature", 1.0)
            )
        # Factor posterior distributions (random mean init, identity covariance init)
        self.post_U_m = [
            torch.rand(dim, self.R_U, 1, self.N_time).double().to(self.device)
            for dim in self.ndims
        ] # (dim, R_U, 1, T) * nmod
        self.post_U_v = [
            torch.eye(self.R_U).reshape(
                (1, self.R_U, self.R_U,
                 1)).repeat(dim, 1, 1, self.N_time).double().to(self.device)
            for dim in self.ndims
        ] # (dim, R_U, R_U, T) * nmod
        # Noise posterior distribution
        self.post_a = self.a0
        self.post_b = self.b0
        self.E_tau = 1
        # Build time-data tables: given timestamp ID, returns entry indices
        if "utils_streaming" in sys.modules:
            import utils_streaming
            self.time_data_table_tr = utils_streaming.build_time_data_table(
                self.train_time_ind)
            self.time_data_table_te = utils_streaming.build_time_data_table(
                self.test_time_ind)
        else:
            # Basic implementation if utils_streaming not available
            self.time_data_table_tr = self._build_time_data_table(self.train_time_ind)
            self.time_data_table_te = self._build_time_data_table(self.test_time_ind)
        # Placeholders, filled per time step by track_envloved_objects
        self.ind_T = None
        self.y_T = None
        self.uid_table = None
        self.data_table = None
        # Store messages (by uid order)
        self.msg_U_m = None
        self.msg_U_V = None
        # Store messages (by data-llk order)
        self.msg_U_lam_llk = None
        self.msg_U_eta_llk = None
        self.msg_a_llk = None
        self.msg_b_llk = None
        # Set product method
        self.product_method = None # Set in subclasses
        logger.info(f"Initialized DynamicCoreSetTensorFactorization: R_U={self.R_U}, nmods={self.nmods}")
def _build_time_data_table(self, time_indices):
"""Build time-data table if utils_streaming not available"""
time_data_table = {}
for i, t in enumerate(time_indices):
if t not in time_data_table:
time_data_table[t] = []
time_data_table[t].append(i)
return time_data_table
def _create_lds_params(self, hyper_dict, data_dict):
    """Create LDS (state-space) parameters for the chosen Matern kernel.

    Args:
        hyper_dict: reads "device", "R_U" and the optional keys "noise",
            "kernel" ("Matern_21" or "Matern_23"), "lengthscale", "variance".
        data_dict: unused here; kept for interface symmetry with callers.

    Returns:
        dict with keys "device", "R", "F", "H", "P_inf", "P_0", "m_0".

    Raises:
        ValueError: if hyper_dict["kernel"] names an unsupported kernel.
            (Previously an unknown kernel silently produced a dict missing
            F/H/P_inf, causing a confusing KeyError far from the cause.)
    """
    LDS_init = {}
    LDS_init["device"] = hyper_dict["device"]
    # Build F,H,R matrices
    D = hyper_dict["R_U"]
    # Get parameters from config
    LDS_init["R"] = torch.tensor(hyper_dict.get("noise", 1.0))
    kernel = hyper_dict.get("kernel", "Matern_23")
    lengthscale = hyper_dict.get("lengthscale", 0.3)
    variance = hyper_dict.get("variance", 1.0)
    if kernel == "Matern_21":
        # Matern-1/2 (Ornstein-Uhlenbeck): first-order SDE, state dim D
        LDS_init["F"] = -1 / lengthscale * torch.eye(D)
        LDS_init["H"] = torch.eye(D)
        LDS_init["P_inf"] = torch.eye(D) * variance
        LDS_init["P_0"] = LDS_init["P_inf"]
        LDS_init["m_0"] = torch.randn(D, 1) * 0.3
    elif kernel == "Matern_23":
        # Matern-3/2: second-order SDE, companion-form F on a 2D-dim state
        lamb = np.sqrt(3) / lengthscale
        F = torch.zeros((2 * D, 2 * D))
        F[:D, :D] = 0
        F[:D, D:] = torch.eye(D)
        F[D:, :D] = -lamb * lamb * torch.eye(D)
        F[D:, D:] = -2 * lamb * torch.eye(D)
        # stationary covariance of the (position, velocity) state
        P_inf = torch.diag(
            torch.cat((
                variance * torch.ones(D),
                lamb * lamb * variance * torch.ones(D),
            ))
        )
        LDS_init["F"] = F
        LDS_init["P_inf"] = P_inf
        # observe only the "position" half of the state
        LDS_init["H"] = torch.cat((torch.eye(D), torch.zeros(D, D)), dim=1)
        LDS_init["P_0"] = P_inf
        LDS_init["m_0"] = 0.1 * torch.ones(2 * D, 1)
    else:
        # Fail fast instead of returning a partially-filled parameter dict
        raise ValueError(
            f"Unsupported kernel '{kernel}'; expected 'Matern_21' or 'Matern_23'")
    logger.info(f"LDS parameters created: kernel={kernel}, lengthscale={lengthscale}, variance={variance}")
    return LDS_init
def ensure_factor_dynamics(self, mode, uid):
"""Ensure factor dynamics process exists for the specified factor"""
if (mode, uid) not in self.factor_dynamics:
# Create a new LDS_GP_streaming instance
try:
from model_LDS import LDS_GP_streaming
self.factor_dynamics[(mode, uid)] = LDS_GP_streaming(self.lds_params)
except ImportError:
# If model_LDS not available, use a simplified version
self.factor_dynamics[(mode, uid)] = SimplifiedLDS(self.lds_params)
return self.factor_dynamics[(mode, uid)]
def track_envloved_objects(self, T):
"""Get indices/values/object IDs of entries observed at time T"""
eind_T = self.time_data_table_tr[T] # List of entry IDs observed at this timestamp
self.ind_T = self.tr_ind[eind_T]
self.y_T = self.tr_y[eind_T].reshape(-1, 1, 1)
self.N_T = len(self.y_T)
# Build object ID and data tables
try:
import utils_streaming
self.uid_table, self.data_table = utils_streaming.build_id_key_table(
nmod=self.nmods, ind=self.ind_T
)
except ImportError:
# Simple implementation if utils_streaming not available
self.uid_table, self.data_table = self._build_id_key_table(self.ind_T)
def _build_id_key_table(self, ind):
"""Build ID key table if utils_streaming not available"""
uid_table = [[] for _ in range(self.nmods)]
data_table = [[] for _ in range(self.nmods)]
for i, indices in enumerate(ind):
for mode, idx in enumerate(indices):
if idx not in uid_table[mode]:
uid_table[mode].append(idx)
data_table[mode].append([i])
else:
pos = uid_table[mode].index(idx)
data_table[mode][pos].append(i)
return uid_table, data_table
def filter_predict(self, T):
"""KF prediction step for trajectories of involved objects + update posterior"""
current_time_stamp = self.time_uni[T]
# Predict for each involved object
for mode in range(self.nmods):
for uid in self.uid_table[mode]:
# Ensure factor dynamics process exists
factor_process = self.ensure_factor_dynamics(mode, uid)
# Perform prediction
factor_process.filter_predict(current_time_stamp)
# Update posterior
H = factor_process.H
m = factor_process.m_pred_list[-1]
P = factor_process.P_pred_list[-1]
self.post_U_m[mode][uid, :, :, T] = torch.mm(H, m)
self.post_U_v[mode][uid, :, :, T] = torch.mm(torch.mm(H, P), H.T)
def update_coreset(self, T):
    """Submit the entries observed at step T to the coreset manager.

    Also refreshes the manager's model-confidence estimate from held-out
    prediction error (confidence shrinks toward 0 as error grows).

    Returns:
        (added, removed): entries added to / evicted from the coreset.
    """
    # Pack current step's observations as (tensor indices, value, time index)
    data_batch = [(self.ind_T[i], self.y_T[i].item(), T) for i in range(self.N_T)]
    added, removed = self.coreset_manager.update_coreset(data_batch, self)
    if T % 10 == 0:  # Log every 10 time steps
        logger.info(f"T={T}, coreset size: {self.coreset_manager.get_coreset_size()}, "
                    f"added: {len(added)}, removed: {len(removed)}")
    if T > 0:
        pred_error = self.compute_prediction_error(T)
        if pred_error > 0:
            self.coreset_manager.update_confidence(1.0 / (1.0 + pred_error))
    return added, removed
def compute_prediction_error(self, T):
"""Calculate current prediction error for confidence update"""
if hasattr(self, 'te_ind') and hasattr(self, 'te_y'):
# Use a small subset of test data to calculate error
sample_size = min(100, len(self.te_ind))
if sample_size > 0:
indices = np.random.choice(len(self.te_ind), sample_size, replace=False)
sample_ind = self.te_ind[indices]
sample_y = self.te_y[indices]
sample_time = np.ones_like(indices) * T
pred, _ = self.model_test(sample_ind, sample_y, sample_time)
# Calculate MSE
mse = torch.mean((pred.squeeze() - sample_y.squeeze()) ** 2)
return mse.item()
return 0.0
def get_scale_weights(self, T, mode=None):
"""Get multi-scale weights"""
try:
# Collect hidden states from involved object trajectories
h_scales = []
# Collect states from different trajectories
for m in range(self.nmods):
# Only consider current mode (if specified)
if mode is not None and m != mode:
continue
# Extract first few objects from this mode (max 3)
sample_uids = self.uid_table[m][:3] if m in self.uid_table else []
for uid in sample_uids:
if (m, uid) in self.factor_dynamics:
factor_process = self.factor_dynamics[(m, uid)]
if hasattr(factor_process, 'm') and factor_process.m is not None:
# Convert to float type and ensure consistent dimensions
state = factor_process.m.clone().detach().to(torch.float32)
# Ensure flattened to 1D vector
state = state.reshape(-1)
h_scales.append(state)
# Maximum 3 hidden states
if len(h_scales) >= 3:
break
# Maximum 3 hidden states per mode
if len(h_scales) >= 3:
break
# If not enough scales, return uniform weights
if len(h_scales) < 1:
return torch.ones(1, 1, device=self.device)
# Calculate weights
weights = self.multi_scale_weighting.compute_weights(h_scales)
return weights
except Exception as e:
logger.error(f"Error calculating multi-scale weights: {e}")
# Return uniform weights on error
return torch.ones(1, 1, device=self.device)
def msg_llk_init(self):
"""Initialize llk-msg for CEP inner loop"""
N_T = len(self.y_T) # Current time step entry count
# Initialize msg_U_llk using natural parameters: lam = S_inv, eta = S_inv x m
self.msg_U_lam_llk = [
1e-4 * torch.eye(self.R_U).reshape((1, self.R_U, self.R_U)).repeat(
N_T, 1, 1).double().to(self.device) for i in range(self.nmods)
] # (N*R_U*R_U)*nmod
self.msg_U_eta_llk = [
1e-3 * torch.rand(N_T, self.R_U, 1).double().to(self.device)
for i in range(self.nmods)
] # (N*R_U*1)*nmod
# tau messages
self.msg_a = torch.ones(N_T, 1).double().to(self.device) # N*1
self.msg_b = torch.ones(N_T, 1).double().to(self.device) # N*1
def filter_update(self, T, mode, add_to_list=True):
"""KF update step for trajectories of involved objects"""
# Update all objects involved in this mode
for msg_id, uid in enumerate(self.uid_table[mode]):
# Check if corresponding message exists
if msg_id >= len(self.msg_U_m[mode]):
continue
# Get approximate msg as KF observation
y = self.msg_U_m[mode][msg_id]
R = self.msg_U_V[mode][msg_id]
# Get factor process
factor_process = self.ensure_factor_dynamics(mode, uid)
# KF update step
factor_process.filter_update(y=y, R=R, add_to_list=add_to_list)
# Update posterior objects
H = factor_process.H
m = factor_process.m
P = factor_process.P
# Update posterior
self.post_U_m[mode][uid, :, :, T] = torch.mm(H, m)
self.post_U_v[mode][uid, :, :, T] = torch.mm(torch.mm(H, P), H.T)
def smooth(self):
"""Smooth all object trajectories"""
for (mode, uid), factor_process in self.factor_dynamics.items():
factor_process.smooth()
def inner_smooth(self):
"""Online smoothing for evaluation during training, clean up and update post_U after smoothing"""
self.smooth()
self.get_post_U()
# Reset all factor processes' smooth lists
for (_, _), factor_dynamics in self.factor_dynamics.items():
factor_dynamics.reset_smooth_list()
def get_post_U(self):
    """Get final post_U using smoothing results.

    For every (time step, mode, object) triple with a tracked trajectory,
    writes the smoothed factor mean/covariance into post_U_m / post_U_v.
    Timestamps the trajectory was never observed at are filled in by
    Gaussian extrapolation (before the first / after the last observation)
    or by a two-sided precision-weighted merge (between observations).
    """
    for T, time_stamp in enumerate(self.time_uni):
        for mode in range(self.nmods):
            for uid in range(self.ndims[mode]):
                if (mode, uid) in self.factor_dynamics:
                    factor_process = self.factor_dynamics[(mode, uid)]
                    if len(factor_process.time_stamp_list) > 0:
                        # At least one observation
                        if time_stamp in factor_process.time_stamp_list:
                            # Timestamp appeared before
                            T_id = factor_process.time_2_ind_table[time_stamp]
                            # Update posterior based on smoothed state
                            if T_id < len(factor_process.m_smooth_list):
                                H = factor_process.H
                                m = factor_process.m_smooth_list[T_id]
                                P = factor_process.P_smooth_list[T_id]
                                self.post_U_m[mode][uid, :, :, T] = torch.mm(H, m)
                                self.post_U_v[mode][uid, :, :, T] = torch.mm(
                                    torch.mm(H, P), H.T)
                        else:
                            # Timestamp never appeared before
                            # Locate position of unseen timestamp
                            loc = bisect.bisect(factor_process.time_stamp_list,
                                                time_stamp)
                            if loc == 0 and len(factor_process.m_smooth_list) > 0:
                                # First backwards Gaussian jump extrapolation:
                                # run the transition backwards from the first
                                # observed state (A is inverted, dt > 0)
                                prev_time_stamp = factor_process.time_stamp_list[loc]
                                prev_m = factor_process.m_smooth_list[loc]
                                prev_P = factor_process.P_smooth_list[loc]
                                prev_time_int = prev_time_stamp - time_stamp
                                prev_A = torch.inverse(
                                    torch.matrix_exp(factor_process.F *
                                                     prev_time_int).double())
                                # process noise over the interval, from P_inf
                                prev_Q = factor_process.P_inf - torch.mm(
                                    torch.mm(prev_A, factor_process.P_inf), prev_A.T)
                                jump_m = torch.mm(prev_A, prev_m)
                                jump_P = (torch.mm(torch.mm(prev_A, prev_P),
                                                   prev_A.T) + prev_Q)
                                H = factor_process.H
                                self.post_U_m[mode][uid, :, :, T] = torch.mm(H, jump_m)
                                self.post_U_v[mode][uid, :, :, T] = torch.mm(
                                    torch.mm(H, jump_P), H.T)
                            elif loc < len(factor_process.time_stamp_list) and len(factor_process.m_smooth_list) > loc:
                                # Interpolation, merge (according to time sequence interpolation formula)
                                prev_time_stamp = factor_process.time_stamp_list[loc - 1]
                                next_time_stamp = factor_process.time_stamp_list[loc]
                                if loc - 1 < len(factor_process.m_smooth_list) and loc < len(factor_process.m_smooth_list):
                                    prev_m = factor_process.m_smooth_list[loc - 1]
                                    prev_P = factor_process.P_smooth_list[loc - 1]
                                    next_m = factor_process.m_smooth_list[loc]
                                    next_P = factor_process.P_smooth_list[loc]
                                    prev_time_int = time_stamp - prev_time_stamp
                                    next_time_int = next_time_stamp - time_stamp
                                    # forward prediction from the previous state
                                    prev_A = torch.matrix_exp(
                                        factor_process.F * prev_time_int).double()
                                    prev_Q = factor_process.P_inf - torch.mm(
                                        torch.mm(prev_A, factor_process.P_inf), prev_A.T)
                                    Q1_inv = torch.inverse(
                                        torch.mm(torch.mm(prev_A, prev_P),
                                                 prev_A.T) + prev_Q)
                                    # backward prediction from the next state
                                    next_A = torch.matrix_exp(
                                        factor_process.F * next_time_int).double()
                                    next_Q = factor_process.P_inf - torch.mm(
                                        torch.mm(next_A, factor_process.P_inf), next_A.T)
                                    Q2_inv = torch.inverse(
                                        torch.mm(torch.mm(next_A, next_P),
                                                 next_A.T) + next_Q)
                                    # precision-weighted merge of the two predictions
                                    merge_P = torch.inverse(Q1_inv + torch.mm(
                                        next_A.T, torch.mm(Q2_inv, next_A)))
                                    temp_term = torch.mm(
                                        Q1_inv, torch.mm(
                                            prev_A, prev_m)) + torch.mm(
                                                Q2_inv, torch.mm(next_A, next_m))
                                    merge_m = torch.mm(merge_P, temp_term)
                                    H = factor_process.H
                                    self.post_U_m[mode][uid, :, :, T] = torch.mm(H, merge_m)
                                    self.post_U_v[mode][uid, :, :, T] = torch.mm(
                                        torch.mm(H, merge_P), H.T)
                            elif loc > 0 and loc - 1 < len(factor_process.m_smooth_list):
                                # Extrapolate at end, forward Gaussian jump
                                prev_time_stamp = factor_process.time_stamp_list[loc - 1]
                                prev_m = factor_process.m_smooth_list[loc - 1]
                                prev_P = factor_process.P_smooth_list[loc - 1]
                                prev_time_int = time_stamp - prev_time_stamp
                                prev_A = torch.matrix_exp(
                                    factor_process.F * prev_time_int).double()
                                prev_Q = factor_process.P_inf - torch.mm(
                                    torch.mm(prev_A, factor_process.P_inf), prev_A.T)
                                jump_m = torch.mm(prev_A, prev_m)
                                jump_P = (torch.mm(torch.mm(prev_A, prev_P),
                                                   prev_A.T) + prev_Q)
                                H = factor_process.H
                                self.post_U_m[mode][uid, :, :, T] = torch.mm(H, jump_m)
                                self.post_U_v[mode][uid, :, :, T] = torch.mm(
                                    torch.mm(H, jump_P), H.T)
def model_test(self, test_ind, test_y, test_time):
    """
    Model testing and evaluation - implemented for data point coreset
    Handles coreset and non-coreset data points differently.

    Args:
        test_ind: test entry index array (indexable with an int list).
        test_y: test target values (tensor-like; squeezed for metrics).
        test_time: per-entry time indices, aligned with test_ind.

    Returns:
        (pred, loss_test): predictions for every entry, and a dict with
        "rmse" and "MAE". Prediction failures fall back to zero predictions;
        metric failures fall back to large sentinel values.
    """
    MSE_loss = torch.nn.MSELoss()
    MAE_loss = torch.nn.L1Loss()
    loss_test = {}
    # Check which test data points are in coreset
    is_coreset = []
    for i, indices in enumerate(test_ind):
        # Check if data point is in coreset
        is_core = self.coreset_manager.is_in_coreset(indices)
        is_coreset.append(is_core)
    # Handle coreset and non-coreset data points separately
    pred = torch.zeros(len(test_ind), device=self.device)
    # Handle coreset data points (using full model)
    core_indices = [i for i, is_core in enumerate(is_coreset) if is_core]
    if core_indices:
        try:
            core_pred = self.model_test_coreset(
                test_ind[core_indices], test_time[core_indices])
            pred[core_indices] = core_pred.squeeze().to(pred.dtype)
        except Exception as e:
            logger.error(f"Coreset prediction error: {e}")
            # Fill with zeros on error
            pred[core_indices] = 0.0
    # Handle non-coreset data points (using approximate model)
    noncore_indices = [i for i, is_core in enumerate(is_coreset) if not is_core]
    if noncore_indices:
        try:
            noncore_pred = self.model_test_noncore(
                test_ind[noncore_indices], test_time[noncore_indices])
            pred[noncore_indices] = noncore_pred.squeeze().to(pred.dtype)
        except Exception as e:
            logger.error(f"Non-coreset prediction error: {e}")
            # Fill with zeros on error
            pred[noncore_indices] = 0.0
    try:
        # Calculate error metrics
        loss_test["rmse"] = torch.sqrt(
            MSE_loss(pred.squeeze(),
                     test_y.squeeze().to(self.device)))
        loss_test["MAE"] = MAE_loss(pred.squeeze(),
                                    test_y.squeeze().to(self.device))
    except Exception as e:
        logger.error(f"Error calculating error metrics: {e}")
        # Use large values on error
        loss_test["rmse"] = torch.tensor(9999.0, device=self.device)
        loss_test["MAE"] = torch.tensor(999.0, device=self.device)
    return pred, loss_test
def model_test_coreset(self, test_ind, test_time):
"""Handle coreset data points testing method"""
# Implemented in subclasses
raise NotImplementedError("Should be implemented in subclass")
def model_test_noncore(self, test_ind, test_time):
"""Handle non-coreset data points testing method"""
# Implemented in subclasses
raise NotImplementedError("Should be implemented in subclass")
def reset(self):
"""Reset model state"""
for (_, _), factor_dynamics in self.factor_dynamics.items():
factor_dynamics.reset_list()
self.factor_dynamics = {}
class DCTF_CP(DynamicCoreSetTensorFactorization):
    """
    Dynamic CoreSet Tensor Factorization CP form
    CP model with data point coreset support
    """
    def __init__(self, hyper_dict, data_dict):
        """Initialize CP model.

        Args:
            hyper_dict: hyper-parameter dict (see base class).
            data_dict: data dict (see base class).
        """
        super().__init__(hyper_dict, data_dict)
        self.product_method = "hadamard" # CP
        # For CP, gamma is a constant all-one vector
        self.post_gamma_m = torch.ones(self.R_U, 1).double().to(self.device) # (R)*1
        logger.info("Initialized DCTF_CP model")
def product_with_gamma(self, E_z, E_z_2, mode):
    """Scale the moments by gamma.

    For CP, gamma is the constant all-ones vector, so this is the identity;
    the hook exists for subclasses/variants with a non-trivial gamma.
    """
    return E_z, E_z_2
def msg_approx_U(self, T, mode):
    """Approximate msg from data-llk groups at T.

    Computes the conditional moments of the other modes' factors, converts
    them to natural-parameter likelihood messages for `mode` (down-weighting
    non-coreset entries and applying the global multi-scale weight), damps
    them against the previous messages, then aggregates per object into
    (mean, covariance) messages appended to msg_U_m / msg_U_V.
    """
    # Reset msg_U_m, msg_U_V
    msg_U_m_mode = []
    msg_U_V_mode = []
    condi_modes = [i for i in range(self.nmods)]
    condi_modes.remove(mode) # [1,2], [0,2]
    # Import utils_streaming for moment product
    try:
        import utils_streaming
        E_z, E_z_2 = utils_streaming.moment_product(
            modes=condi_modes,
            ind=self.ind_T,
            U_m=[ele[:, :, :, T] for ele in self.post_U_m],
            U_v=[ele[:, :, :, T] for ele in self.post_U_v],
            order="second",
            sum_2_scaler=False,
            device=self.device,
            product_method=self.product_method,
        )
    except ImportError:
        # Simple implementation if utils_streaming not available
        E_z, E_z_2 = self._moment_product(
            modes=condi_modes,
            ind=self.ind_T,
            U_m=[ele[:, :, :, T] for ele in self.post_U_m],
            U_v=[ele[:, :, :, T] for ele in self.post_U_v],
            T=T
        )
    E_z, E_z_2 = self.product_with_gamma(E_z, E_z_2, mode)
    # Apply multi-scale weights
    weights = self.get_scale_weights(T, mode)
    # First use natural parameters for easier msg merging
    msg_U_lam_new = self.E_tau * E_z_2 # (N,R,R)
    msg_U_eta_new = self.y_T * E_z * self.E_tau # (N,R,1)
    # Distinguish between coreset and non-coreset data points, apply different weights
    for i in range(len(self.ind_T)):
        # Check if data point is in coreset
        indices = self.ind_T[i]
        is_in_coreset = self.coreset_manager.is_in_coreset(indices)
        # Apply different weights to coreset and non-coreset data points
        if is_in_coreset:
            # Standard weights for coreset data points
            pass
        else:
            # Use smaller weights for non-coreset data points, reducing influence
            msg_U_lam_new[i] = msg_U_lam_new[i] * 0.9
            msg_U_eta_new[i] = msg_U_eta_new[i] * 0.9
    # Apply multi-scale weights
    if weights is not None and weights.numel() > 0:
        try:
            # Ensure weights have correct shape
            if weights.dim() == 1:
                weights = weights.unsqueeze(1) # Convert to column vector
            # Limit scale count
            max_scales = min(weights.shape[0], 3) # Consider max 3 scales
            # Use only global scale weight to simplify implementation
            if max_scales > 0:
                # Get weight value and ensure it's valid (finite, positive)
                scale_weight = weights[0].item()
                if scale_weight > 0 and not math.isnan(scale_weight) and not math.isinf(scale_weight):
                    msg_U_lam_new = scale_weight * msg_U_lam_new
                    msg_U_eta_new = scale_weight * msg_U_eta_new
        except Exception as e:
            logger.error(f"Error applying weights: {e}")
    # DAMPING step: convex combination of old and new natural parameters
    self.msg_U_lam_llk[mode] = (self.DAMPING * self.msg_U_lam_llk[mode] +
                                (1 - self.DAMPING) * msg_U_lam_new)
    self.msg_U_eta_llk[mode] = (self.DAMPING * self.msg_U_eta_llk[mode] +
                                (1 - self.DAMPING) * msg_U_eta_new)
    # Fill msg_U_M, msg_U_V: aggregate entry-wise messages per object
    for i in range(len(self.uid_table[mode])):
        uid = self.uid_table[mode][i] # Embedding id
        eid = self.data_table[mode][i] # Associated entry id
        S_inv_cur = self.msg_U_lam_llk[mode][eid].sum(dim=0) # (R,R)
        S_inv_Beta_cur = self.msg_U_eta_llk[mode][eid].sum(dim=0) # (R,1)
        try:
            # Calculate covariance and mean from natural parameters
            U_V = torch.linalg.inv(S_inv_cur)
            U_M = torch.mm(U_V, S_inv_Beta_cur) # (R,1)
        except Exception as e:
            # Handle matrix inversion failure by adding diagonal jitter
            logger.warning(f"Matrix inversion failed: {e}")
            jitter = 1e-3 * torch.eye(S_inv_cur.size(0)).to(S_inv_cur.device)
            U_V = torch.linalg.inv(S_inv_cur + jitter)
            U_M = torch.mm(U_V, S_inv_Beta_cur)
        msg_U_m_mode.append(U_M)
        msg_U_V_mode.append(U_V)
    self.msg_U_m.append(msg_U_m_mode)
    self.msg_U_V.append(msg_U_V_mode)
def _moment_product(self, modes, ind, U_m, U_v, T):
    """Fallback for utils_streaming.moment_product (CP/Hadamard, order="second").

    For each entry, returns the first moment E[z] (Hadamard product of the
    factor means over `modes`) and the second moment E[z z^T] (elementwise
    product of the per-mode second moments v + m m^T).

    Fixes two defects of the previous fallback:
    - the second-moment accumulator was initialized with torch.eye, which
      zeroed every off-diagonal term of the elementwise product; the neutral
      element of a Hadamard product is the all-ones matrix;
    - accumulators were float32 while the posteriors are double, silently
      down-casting results (the utils_streaming path works in double).
    """
    N = len(ind)
    E_z = torch.zeros(N, self.R_U, 1, dtype=torch.float64, device=self.device)
    E_z_2 = torch.zeros(N, self.R_U, self.R_U, dtype=torch.float64, device=self.device)
    # For each data point
    for i in range(N):
        # For CP, we take element-wise products across the conditional modes
        mean_prod = torch.ones(self.R_U, 1, dtype=torch.float64, device=self.device)
        # neutral element for an elementwise (Hadamard) product
        cov_prod = torch.ones(self.R_U, self.R_U, dtype=torch.float64, device=self.device)
        for mode in modes:
            idx = ind[i][mode]
            m = U_m[mode][idx]
            v = U_v[mode][idx]
            # Hadamard product for CP
            mean_prod = mean_prod * m
            cov_prod = cov_prod * (v + torch.mm(m, m.T))
        E_z[i] = mean_prod
        E_z_2[i] = cov_prod
    return E_z, E_z_2
def msg_approx_tau(self, T):
    """Approximate msg for tau (the observation noise precision).

    Computes the full-model moments at step T, sets the shape message
    msg_a to the constant 1.5, and damps the rate message msg_b toward the
    per-entry quantity 0.5*y^2 - y*E[pred] + 0.5*E[pred^2] (prediction sums
    the CP rank via the all-ones vector).
    """
    all_modes = [i for i in range(self.nmods)]
    try:
        import utils_streaming
        E_z, E_z_2 = utils_streaming.moment_product(
            modes=all_modes,
            ind=self.ind_T,
            U_m=[ele[:, :, :, T] for ele in self.post_U_m],
            U_v=[ele[:, :, :, T] for ele in self.post_U_v],
            order="second",
            sum_2_scaler=False,
            device=self.device,
            product_method=self.product_method,
        )
    except ImportError:
        # Simple implementation if utils_streaming not available
        E_z, E_z_2 = self._moment_product(
            modes=all_modes,
            ind=self.ind_T,
            U_m=[ele[:, :, :, T] for ele in self.post_U_m],
            U_v=[ele[:, :, :, T] for ele in self.post_U_v],
            T=T
        )
    # shape message: constant, not damped
    self.msg_a = 1.5 * torch.ones(self.N_T, 1).to(self.device)
    term1 = 0.5 * torch.square(self.y_T) # N_T*1
    # y * (1^T E[z]): prediction is the rank-sum of the Hadamard product
    term2 = self.y_T.reshape(-1, 1) * torch.matmul(
        E_z.transpose(dim0=1, dim1=2),
        torch.ones(self.R_U, 1).double().to(self.device)).reshape(-1, 1) # N_T*1
    # 0.5 * (1^T E[z z^T] 1): second moment of the prediction
    temp = torch.matmul(E_z_2, torch.ones(self.R_U, 1).double().to(self.device)) # N*R*1
    term3 = 0.5 * torch.matmul(temp.transpose(dim0=1, dim1=2),
                               torch.ones(self.R_U, 1).double().to(self.device)).reshape(-1, 1) # N*1
    self.msg_b = self.DAMPING_tau * self.msg_b + (1 - self.DAMPING_tau) * (
        term1.reshape(-1, 1) - term2.reshape(-1, 1) + term3.reshape(-1, 1)
    ) # N*1
def post_update_tau(self, T=None):
    """Fold the current tau messages into the Gamma posterior and refresh E[tau].

    Args:
        T: unused; accepted for interface symmetry with the other updates.
    """
    # accumulate natural parameters from the per-entry messages
    self.post_a = self.post_a + self.msg_a.sum() - self.N_T
    self.post_b = self.post_b + self.msg_b.sum()
    # posterior mean of a Gamma(a, b) precision
    self.E_tau = self.post_a / self.post_b
def model_test_coreset(self, test_ind, test_time):
    """Predict coreset test entries using the full CP decomposition.

    Falls back to the local implementation when utils_streaming is absent.
    """
    try:
        import utils_streaming
    except ImportError:
        # Simple implementation if utils_streaming not available
        return self._moment_product_T(test_ind=test_ind, test_time=test_time)
    return utils_streaming.moment_product_T(
        modes=list(range(self.nmods)),
        ind=test_ind,
        ind_T=test_time,
        U_m_T=self.post_U_m,
        U_v_T=self.post_U_v,
        order="first",
        sum_2_scaler=True,
        device=self.device,
        product_method=self.product_method,
    )
def _moment_product_T(self, test_ind, test_time):
    """Fallback CP predictor for utils_streaming.moment_product_T.

    For each test entry, takes the Hadamard product of its factor means at
    the entry's time step and sums over the rank.
    """
    pred = torch.zeros(len(test_ind), device=self.device)
    for row, (indices, t) in enumerate(zip(test_ind, test_time)):
        rank_prod = torch.ones(self.R_U, 1, device=self.device)
        for mode in range(self.nmods):
            rank_prod = rank_prod * self.post_U_m[mode][indices[mode], :, :, t]
        pred[row] = torch.sum(rank_prod)
    return pred
def model_test_noncore(self, test_ind, test_time):
    """
    Test prediction function for non-coreset data points.

    For each test entry, finds the most similar coreset entry (similarity =
    0.7 * index overlap + 0.3 * exp(-0.1 * |time difference|)), predicts both
    that coreset entry and the test entry with the CP model, and blends the
    two predictions by the similarity. Falls back to the plain CP prediction
    when the coreset is empty or no neighbor is found.
    """
    # Get data points in coreset
    coreset_data = self.coreset_manager.get_coreset_data()
    if not coreset_data:
        # If coreset is empty, use standard test method
        return self.model_test_coreset(test_ind, test_time)
    # For each test data point, find most similar coreset data point
    pred = torch.zeros(len(test_ind), device=self.device)
    for i, (indices, time) in enumerate(zip(test_ind, test_time)):
        # Find nearest coreset data point (linear scan over the coreset)
        nearest_core_idx = -1
        min_distance = float('inf')
        for core_idx, (core_indices, _, core_time) in enumerate(coreset_data):
            # Calculate index similarity: number of matching indices
            common_indices = sum(1 for a, b in zip(indices, core_indices) if a == b)
            index_sim = common_indices / len(indices)
            # Calculate time similarity: closeness of time steps
            time_diff = abs(time - core_time)
            time_sim = math.exp(-0.1 * time_diff) # Smaller time diff means higher similarity
            # Combined similarity
            similarity = 0.7 * index_sim + 0.3 * time_sim
            distance = 1.0 - similarity
            if distance < min_distance:
                min_distance = distance
                nearest_core_idx = core_idx
        # If found most similar coreset data point, use its prediction result
        if nearest_core_idx >= 0:
            core_indices, _, core_time = coreset_data[nearest_core_idx]
            # Generate prediction for this coreset data point
            try:
                import utils_streaming
                core_pred = utils_streaming.moment_product(
                    modes=list(range(self.nmods)),
                    ind=np.array([core_indices]),
                    U_m=[ele[:, :, :, core_time] for ele in self.post_U_m],
                    U_v=[ele[:, :, :, core_time] for ele in self.post_U_v],
                    order="first",
                    sum_2_scaler=True,
                    device=self.device,
                    product_method=self.product_method,
                )
            except ImportError:
                # Simple implementation
                core_pred = self._single_prediction(core_indices, core_time)
            # Generate prediction for current test data point
            try:
                import utils_streaming
                test_pred = utils_streaming.moment_product(
                    modes=list(range(self.nmods)),
                    ind=np.array([indices]),
                    U_m=[ele[:, :, :, time] for ele in self.post_U_m],
                    U_v=[ele[:, :, :, time] for ele in self.post_U_v],
                    order="first",
                    sum_2_scaler=True,
                    device=self.device,
                    product_method=self.product_method,
                )
            except ImportError:
                # Simple implementation
                test_pred = self._single_prediction(indices, time)
            # Mix prediction results based on similarity
            similarity = 1.0 - min_distance
            pred[i] = similarity * core_pred.item() + (1 - similarity) * test_pred.item()
        else:
            # If no similar coreset data point found, use standard model
            try:
                import utils_streaming
                test_pred = utils_streaming.moment_product(
                    modes=list(range(self.nmods)),
                    ind=np.array([indices]),
                    U_m=[ele[:, :, :, time] for ele in self.post_U_m],
                    U_v=[ele[:, :, :, time] for ele in self.post_U_v],
                    order="first",
                    sum_2_scaler=True,
                    device=self.device,
                    product_method=self.product_method,
                )
            except ImportError:
                # Simple implementation
                test_pred = self._single_prediction(indices, time)
            pred[i] = test_pred.item()
    return pred
def _single_prediction(self, indices, time):
"""Make prediction for a single data point"""
# Get factor means for each mode
factors = []
for mode in range(self.nmods):
idx = indices[mode]
factors.append(self.post_U_m[mode][idx, :, :, time])
# CP prediction: product of factors summed over rank
prod = torch.ones(self.R_U, 1, device=self.device)
for f in factors:
prod = prod * f