From d54f8bc19dca335eb35ebbca13cbd92cb5125823 Mon Sep 17 00:00:00 2001 From: karthikreddy-02 Date: Thu, 19 Mar 2026 23:52:27 -0400 Subject: [PATCH 1/3] feat(AINode): [Issue-17301] Import PatchTST-FM-R1 architecture and register in model_info --- .../iotdb/ainode/core/model/model_info.py | 12 + .../ainode/core/model/patchtst_fm/__init__.py | 0 .../patchtst_fm/configuration_patchtst_fm.py | 54 +++ .../model/patchtst_fm/modeling_patchtst_fm.py | 426 ++++++++++++++++++ 4 files changed, 492 insertions(+) create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/__init__.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/configuration_patchtst_fm.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/modeling_patchtst_fm.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index f253fb1e56f60..778302a923573 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -158,4 +158,16 @@ def __repr__(self): }, transformers_registered=True, ), + "patchtst_fm": ModelInfo( + model_id = "patchtst_fm", + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + model_type="patchtst_fm", + pipeline_cls="pipeline_patchtst_fm.PatchTSTFMPipeline", + repo_id="ibm-research/patchtst-fm-r1", + auto_map={ + "AutoConfig": "configuration_patchtst_fm.PatchTSTFMConfig", + "AutoModelForCausalLM": "modeling_patchtst_fm.PatchTSTFMForPrediction", + }, + ), } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/configuration_patchtst_fm.py b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/configuration_patchtst_fm.py new file mode 100644 index 
# Copyright contributors to the TSFM project
#
"""PatchTST-FM model configuration."""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)

PATCHTSTFM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class PatchTSTFMConfig(PretrainedConfig):
    """Configuration for the PatchTST-FM foundation model.

    Holds the patching geometry (``context_length`` / ``d_patch``), the
    transformer sizing (``d_model`` / ``n_head`` / ``n_layer``) and the
    quantile grid used by the prediction head.
    """

    model_type = "patchtst_fm"
    attribute_map = {
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        context_length: int = 8192,
        prediction_length: int = 64,
        d_patch: int = 16,
        d_model: int = 384,
        n_head: int = 6,
        n_layer: int = 6,
        norm_first: bool = True,
        pretrain_mask_ratio: float = 0.4,
        pretrain_mask_cont: int = 8,
        num_quantile: int = 99,
        **kwargs,
    ):
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.d_patch = d_patch
        # Number of non-overlapping patches covering the context window.
        self.n_patch = int(context_length // d_patch)
        self.d_model = d_model
        self.n_head = n_head
        self.n_layer = n_layer
        self.norm_first = norm_first
        self.pretrain_mask_ratio = pretrain_mask_ratio
        self.pretrain_mask_cont = pretrain_mask_cont
        self.num_quantile = num_quantile

        # Build the quantile grid.
        # NOTE(review): the `% 9` test looks suspicious (possibly meant `% 2`?),
        # and because 0.01 / 0.99 are added unconditionally the default
        # num_quantile=99 yields num_quantile + 2 levels — confirm this matches
        # the prediction head's quantile channel count.
        if num_quantile % 9 == 0:
            lo, hi, den = 1, self.num_quantile + 1, self.num_quantile + 1
        else:
            lo, hi, den = 1, self.num_quantile - 1, self.num_quantile - 1
        inner_levels = [step / den for step in range(lo, hi)]
        self.quantile_levels = [0.01, *inner_levels, 0.99]
        super().__init__(**kwargs)
# Copyright contributors to the TSFM project
#
"""PatchTST-FM model implementation"""

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging

from .basic import (
    TransformerBlock,
    make_attn_mask,
)
from .configuration_patchtst_fm import PatchTSTFMConfig
from .normalization import RevIN
from .tools import count_parameters


logger = logging.get_logger(__name__)


class LearnedPositionalEmbedding(nn.Module):
    """Learned positional embedding, combined additively or multiplicatively."""

    def __init__(self, d_model, max_len=5000, type="add"):
        super().__init__()
        self.embedding = nn.Embedding(max_len, d_model)
        self.type = type

    def forward(self, x):
        """Add (or multiply) a learned embedding over the second-to-last axis of x."""
        positions = torch.arange(x.size(-2), device=x.device).unsqueeze(0)
        pe = self.embedding(positions)
        if x.ndim == 4:
            # Broadcast the embedding over the extra leading dimension.
            pe = pe.unsqueeze(1)
        if self.type == "add":
            return x + pe
        elif self.type == "mul":
            return x * pe
        else:
            raise ValueError(f"Invalid type: {self.type}")


class ResidualBlock(nn.Module):
    """Two-layer projection with a linear residual path (sigmoid hidden activation)."""

    def __init__(self, d_in, d_out, d_hidden):
        super().__init__()

        self.layer1 = nn.Linear(d_in, d_hidden)
        self.layer2 = nn.Linear(d_hidden, d_out)
        self.residual = nn.Linear(d_in, d_out)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        return self.layer2(self.activation(self.layer1(x))) + self.residual(x)


class PatchTSTFMPreTrainedModel(PreTrainedModel):
    # Shared Hugging Face plumbing for all PatchTST-FM heads.
    config_class = PatchTSTFMConfig
    base_model_prefix = "model"
    main_input_name = "inputs"
    supports_gradient_checkpointing = False


@dataclass
class PatchTSTFMModelOutput(ModelOutput):
    # Mask (float) of positions that should contribute to the training loss.
    loss_mask: torch.Tensor = None
    # Instance-normalized (RevIN-space) target series.
    normed_target: torch.Tensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Monotone quantile predictions, shape (B, T, num_quantile).
    quantile_predictions: torch.FloatTensor = None


@dataclass
class PatchTSTFMPretrainingOutput(ModelOutput):
    loss: torch.Tensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # FIX: this field was misspelled `quanitle_predictions`, which made
    # PatchTSTFMForPretraining.forward's keyword construction
    # (`quantile_predictions=...`) raise a TypeError at runtime.
    quantile_predictions: torch.Tensor = None


@dataclass
class PatchTSTFMPredictionOutput(ModelOutput):
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    quantile_predictions: torch.Tensor = None


class PatchTSTFMModel(PatchTSTFMPreTrainedModel):
    """Backbone: patchify -> masked RevIN -> transformer -> monotone quantile head."""

    def __init__(self, config: PatchTSTFMConfig):
        super().__init__(config)
        self.config = config
        self.quantile_levels = config.quantile_levels
        self.pos_embed = LearnedPositionalEmbedding(d_model=config.d_model, max_len=config.n_patch, type="add")
        assert config.d_model % config.n_head == 0, "[QuantileDecoder] d_model must be divisible by n_head"

        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    config.d_model,
                    config.n_head,
                    mlp_ratio=4.0,
                    norm_first=True,
                    dropout=0.1,
                )
                for _ in range(config.n_layer)
            ]
        )
        # Input embeds (patch values, patch mask); output emits one set of
        # patch values per raw quantile channel (num_quantile + 1 of them).
        self.in_layer = ResidualBlock(config.d_patch * 2, config.d_model, config.d_model)
        self.out_layer = ResidualBlock(config.d_model, config.d_patch * (config.num_quantile + 1), config.d_model)

        self.norm_fn = RevIN(dim=-1, std_min=1e-5, use_sinh=True)

    def model_summary(self):
        """Return a printable parameter-count summary of the backbone."""
        s = ""
        model_name = "PatchTST-FM"
        s += f"{'=' * 5:<10} {model_name} {'=' * 5:>9}\n"
        s += f"{'Transformer:':<20} {count_parameters(self.blocks)[0] / 1e6:>8.2f}M\n"
        s += f"{'=' * 30}\n"
        p = count_parameters(self)
        s += f"{'Trainable:':<20} {p[1] / 1e6:>8.2f}M\n"
        s += f"{'Frozen:':<20} {p[2] / 1e6:>8.2f}M\n"
        s += f"{'Total:':<20} {p[0] / 1e6:>8.2f}M\n"
        s += f"{'=' * 30}\n"
        return s

    def forward(
        self,
        inputs: torch.Tensor,
        pred_mask: torch.Tensor,
        miss_mask: torch.Tensor,
        pad_mask: torch.Tensor,
        output_hidden_states: Optional[bool] = False,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
    ) -> PatchTSTFMModelOutput:
        """Normalize, patchify and decode quantile predictions.

        Args:
            inputs: series values, (B, T) or (B, N, T); an (B, N, T) input is
                flattened to ((B*N), T).
            pred_mask: positions to be predicted (forecast horizon).
            miss_mask: positions with missing observations.
            pad_mask: left-padding positions.
        """
        x = inputs.to(self.device)
        pad_mask = pad_mask.to(self.device).bool()
        pred_mask = pred_mask.to(self.device).bool()
        miss_mask = miss_mask.to(self.device).bool()
        if x.ndim > 2:
            # Treat each variate as an independent series.
            x = rearrange(x, "B N T -> (B N) T")
            pad_mask = rearrange(pad_mask, "B N T -> (B N) T")
            pred_mask = rearrange(pred_mask, "B N T -> (B N) T")
            miss_mask = rearrange(miss_mask, "B N T -> (B N) T")

        B, T = x.shape
        ts_mask = pred_mask | pad_mask | miss_mask

        # Statistics are fit on observed positions only; masked positions are zeroed.
        x_target = self.norm_fn.fit_transform(x, mask=ts_mask)
        x_input = torch.where(ts_mask, torch.zeros_like(x_target), x_target)

        x_patch = x_input.reshape(B, self.config.n_patch, self.config.d_patch)
        mask_patch = ts_mask.reshape(B, self.config.n_patch, self.config.d_patch)
        # A patch counts as padding if >90% of its positions are padded.
        pad_patch_mask = pad_mask.reshape(B, self.config.n_patch, self.config.d_patch).float().mean(dim=-1).gt(0.9)

        q_pred, q_raw = self.decode(x=x_patch, mask=mask_patch.float(), t_pad_mask=pad_patch_mask)
        q_pred = q_pred.permute(0, 2, 3, 1)

        B, N, D, Q = q_pred.shape
        # Un-patchify: (B, n_patch, d_patch, Q) -> (B, T, Q).
        q_pred = q_pred.reshape(B, N * D, Q)

        if output_hidden_states:
            hidden_states = q_raw.reshape(B, N * D, Q)
        else:
            hidden_states = None

        return PatchTSTFMModelOutput(
            normed_target=x_target,
            quantile_predictions=q_pred,
            loss_mask=(pred_mask & ~pad_mask & ~miss_mask).float(),
            hidden_states=hidden_states,
        )

    def decode(self, x, mask, t_pad_mask=None):
        """Transformer decode of patched input into monotone quantile channels.

        Returns (q, q_raw): q has shape (B, num_quantile, n_patch, d_patch) and
        is non-decreasing along the quantile axis by construction (base channel
        plus a cumulative sum of softplus increments).
        """
        B, N, D = x.shape
        # x = self.in_layer(torch.cat([x, t, 1 - mask], dim=-1))
        x = self.in_layer(torch.cat([x, 1 - mask], dim=-1))
        pad_attn_mask = make_attn_mask(t_pad_mask, t_pad_mask).unsqueeze(1)

        x = self.pos_embed(x)
        for block in self.blocks:
            x = block(x, pad_attn_mask)
        x = self.out_layer(x)
        q_raw = x.reshape(B, N, self.config.num_quantile + 1, self.config.d_patch).permute(0, 2, 1, 3)
        q = q_raw[:, 0, :, :].unsqueeze(1) + torch.cumsum(
            F.softplus(q_raw[:, 1:, :, :]) / self.config.num_quantile, dim=1
        )
        return q, q_raw
class PatchTSTFMForPretraining(PatchTSTFMPreTrainedModel):
    """Pretraining head: backbone plus pinball (quantile) loss over masked patches."""

    def __init__(self, config: PatchTSTFMConfig):
        super().__init__(config)

        self.config = config
        self.backbone = PatchTSTFMModel(config)

    def forward(
        self,
        inputs: torch.Tensor,
        pred_mask: torch.Tensor,
        miss_mask: torch.Tensor,
        pad_mask: torch.Tensor,
        output_hidden_states: Optional[bool] = False,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
    ) -> PatchTSTFMPretrainingOutput:
        """Run the backbone; optionally compute the pinball loss on masked positions."""
        model_outputs = self.backbone(
            inputs,
            pred_mask=pred_mask,
            miss_mask=miss_mask,
            pad_mask=pad_mask,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        q_pred = model_outputs.quantile_predictions
        x_target = model_outputs.normed_target
        loss_mask = model_outputs.loss_mask

        if return_loss:
            # Pinball / quantile loss, averaged over masked positions only.
            x_target = x_target.unsqueeze(-1)
            quantiles = torch.tensor(self.backbone.quantile_levels, device=x_target.device).view(1, 1, -1)
            loss = 2 * torch.abs((x_target - q_pred) * ((x_target <= q_pred).float() - quantiles))
            loss = loss * loss_mask.unsqueeze(-1)
            loss = loss.sum(dim=1) / torch.clamp(loss_mask.sum(dim=1, keepdim=True), min=1)
            loss = loss.sum(dim=-1).mean() / math.sqrt(self.config.num_quantile)
        else:
            loss = None

        # Back to the original scale: (B, Q, T), then invert RevIN.
        x_pred = q_pred.permute(0, 2, 1)
        x_pred = self.backbone.norm_fn.inverse_transform(x_pred)

        return PatchTSTFMPretrainingOutput(
            quantile_predictions=x_pred, loss=loss, hidden_states=model_outputs.hidden_states
        )


class PatchTSTFMForPrediction(PatchTSTFMPreTrainedModel):
    """Inference head: zero-shot quantile forecasting over a list of series."""

    def __init__(self, config: PatchTSTFMConfig):
        super().__init__(config)

        self.config = config
        self.backbone = PatchTSTFMModel(config)

    def model_summary(self) -> str:
        return self.backbone.model_summary()

    def forward(
        self,
        inputs: List[torch.Tensor] | torch.Tensor,
        prediction_length: Optional[int] = None,
        quantile_levels: Optional[List[float]] = None,
        output_hidden_states: Optional[bool] = False,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
    ):
        """Forecast `prediction_length` steps for each series in `inputs`.

        Returns PatchTSTFMPredictionOutput whose quantile_predictions has shape
        (n_series, n_quantiles, prediction_length).
        """
        forecast_len = prediction_length if prediction_length else self.config.prediction_length

        cl = self.config.context_length
        ul = -1
        logger.info("Context Len: %s | Forecast Len: %s ", cl, forecast_len)
        cl = [cl] * len(inputs)
        # The model always predicts at least a couple of full patches.
        fl = [
            max(
                forecast_len,
                ul,
                self.config.d_patch * max(self.config.pretrain_mask_cont, 2),
            )
        ] * len(inputs)
        forecast_samples, hidden_states = self.forecast_single_step(
            inputs, fl, context_len=cl, output_hidden_states=output_hidden_states
        )
        forecast_samples = torch.stack(forecast_samples, dim=0)[:, :, :forecast_len]

        if quantile_levels is not None:
            # NOTE(review): relies on exact float membership in quantile_levels;
            # a non-grid level raises ValueError.
            quantile_indices = [self.backbone.quantile_levels.index(q) for q in quantile_levels]
            forecast_samples = forecast_samples[:, quantile_indices, :]
        return PatchTSTFMPredictionOutput(quantile_predictions=forecast_samples, hidden_states=hidden_states)

    def forecast_single_step(
        self,
        x: List[torch.Tensor],
        forecast_len: List[int],
        context_len: List[int],
        output_hidden_states: Optional[bool] = False,
    ):
        """Single-shot forecast for a list of (possibly different-length) 1D series.

        Each series is right-extended by its forecast horizon, then left-padded
        or nearest-interpolated to the fixed model context length.
        """
        inputs = []
        pad_mask = []
        pred_mask = []
        miss_mask = []
        ts_ends = []
        time_index = []
        sample_lengths = []

        for x_i, c_i, f_i in zip(x, context_len, forecast_len):
            c_i = min(x_i.shape[0] + f_i, c_i)
            s_i = c_i - f_i
            x_in = x_i[-s_i:]
            pad_mask_i = torch.zeros_like(x_in)
            miss_mask_i = torch.zeros_like(x_in)
            x_in = torch.nan_to_num(x_in, nan=x_in.nanmean().item())
            # FIX: allocate horizon extensions on the input's device/dtype
            # (bare torch.ones/zeros would land on CPU and break GPU inputs).
            pred_mask_i = torch.cat([torch.zeros_like(x_in), x_in.new_ones(f_i)], dim=-1)
            miss_mask_i = torch.cat([miss_mask_i, x_in.new_zeros(f_i)], dim=-1)
            pad_mask_i = torch.cat([pad_mask_i, x_in.new_zeros(f_i)], dim=-1)
            # Horizon placeholder values: the series mean.
            x_in = torch.cat([x_in, x_in.new_full((f_i,), x_in.nanmean().item())], dim=-1)
            time_index_i = (
                torch.arange(
                    self.config.context_length - x_in.shape[-1] + 1,
                    self.config.context_length + 1,
                    device=x_in.device,
                ).float()
                / self.config.context_length
            )
            sample_len = x_in.shape[-1]
            if sample_len == self.config.context_length:
                inputs.append(x_in)
                pred_mask.append(pred_mask_i)
                pad_mask.append(pad_mask_i)
                miss_mask.append(miss_mask_i)
                time_index.append(time_index_i)
                ts_ends.append(torch.tensor([0, sample_len]).float())
                sample_lengths.append(sample_len)
            elif sample_len < self.config.context_length:  # padding
                left_pad = self.config.context_length - sample_len
                inputs.append(
                    F.pad(
                        x_in,
                        (left_pad, 0),
                        mode="constant",
                        value=x_in.nanmean().item(),
                    )
                )
                pred_mask.append(F.pad(pred_mask_i, (left_pad, 0), mode="constant", value=0.0))
                pad_mask.append(F.pad(pad_mask_i, (left_pad, 0), mode="constant", value=1.0))
                miss_mask.append(F.pad(miss_mask_i, (left_pad, 0), mode="constant", value=0.0))
                time_index.append(F.pad(time_index_i, (left_pad, 0), mode="constant", value=-1))
                ts_ends.append(torch.tensor([left_pad, left_pad + sample_len]).float())
                sample_lengths.append(sample_len)
            else:  # subsample: compress a too-long series into the context window
                inputs.append(
                    F.interpolate(
                        x_in.view(1, 1, -1),
                        size=self.config.context_length,
                        mode="nearest",
                    ).squeeze()
                )
                pred_mask.append(
                    F.interpolate(
                        pred_mask_i.view(1, 1, -1),
                        size=self.config.context_length,
                        mode="nearest",
                    ).squeeze()
                )
                pad_mask.append(
                    F.interpolate(
                        pad_mask_i.view(1, 1, -1),
                        size=self.config.context_length,
                        mode="nearest",
                    ).squeeze()
                )
                miss_mask.append(
                    F.interpolate(
                        miss_mask_i.view(1, 1, -1),
                        size=self.config.context_length,
                        mode="nearest",
                    ).squeeze()
                )
                time_index.append(
                    F.interpolate(
                        time_index_i.view(1, 1, -1),
                        size=self.config.context_length,
                        mode="nearest",
                    ).squeeze()
                )
                ts_ends.append(torch.tensor([0, self.config.context_length]).float())
                sample_lengths.append(sample_len)

        inputs = torch.stack(inputs, dim=0)
        pred_mask = torch.stack(pred_mask, dim=0)
        pad_mask = torch.stack(pad_mask, dim=0)
        miss_mask = torch.stack(miss_mask, dim=0)
        time_index = torch.stack(time_index, dim=0)
        ts_ends = torch.stack(ts_ends, dim=0)

        precision = (
            torch.bfloat16
            if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
            else torch.float16
        )
        # FIX: use the documented torch.backends.mps API (torch.mps.is_available
        # is not the stable, documented entry point).
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
        # NOTE(review): on pure CPU this autocasts to float16 — confirm the
        # installed torch version supports fp16 CPU autocast.

        with torch.autocast(device_type=device, dtype=precision, enabled=True):
            model_output = self.backbone(
                inputs=inputs,
                pred_mask=pred_mask,
                miss_mask=miss_mask,
                pad_mask=pad_mask,
                return_loss=False,
                output_hidden_states=output_hidden_states,
            )
            outputs = model_output.quantile_predictions

        # (B, T, Q) -> (B, Q, T), then back to the original scale.
        outputs = outputs.permute(0, 2, 1)
        outputs = self.backbone.norm_fn.inverse_transform(outputs)

        x_preds = []
        for i in range(outputs.shape[0]):
            if sample_lengths[i] <= self.config.context_length:
                x_pred = outputs[i][:, int(ts_ends[i][0]) : int(ts_ends[i][1])]
            else:
                # Undo the subsampling by stretching back to the original length.
                x_pred = F.interpolate(outputs[i].unsqueeze(1), size=sample_lengths[i], mode="linear").squeeze(1)
            x_preds.append(x_pred[:, -forecast_len[i] :])
        return x_preds, model_output.hidden_states
import torch

from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline


class PatchTSTFMPipeline(ForecastPipeline):
    """Forecast pipeline wrapping PatchTSTFMForPrediction for AINode inference."""

    def __init__(self, model_info, **model_kwargs):
        super().__init__(model_info, **model_kwargs)

    def preprocess(self, inputs, **infer_kwargs):
        """Coerce each item's targets to float32 with shape [n_variates, length]."""
        inputs = super().preprocess(inputs, **infer_kwargs)
        for item in inputs:
            # Model expects float32.
            target_tensor = item["targets"].to(torch.float32)
            # Expand 1D tensor [length] to [n_variates=1, length].
            if target_tensor.ndim == 1:
                target_tensor = target_tensor.unsqueeze(0)
            item["targets"] = target_tensor
        return inputs

    def forecast(self, inputs, **infer_kwargs) -> list[torch.Tensor]:
        """Run the PatchTST-FM-R1 forward pass for each input in the batch.

        FIX: the model's forward signature is `forward(inputs=..., prediction_length=...)`
        and its output field is `quantile_predictions` — the earlier
        `past_values=` / `prediction_outputs` calls did not exist on this model.
        It expects a list of 1D tensors (one per variate).
        """
        forecasts = []
        pred_length = infer_kwargs.get("output_length", 96)
        for item in inputs:
            tensor = item["targets"].to(self.device)
            # [n_variates, length] -> list of 1D series, as required by
            # PatchTSTFMForPrediction.forward().
            tensor_list = [tensor[i] for i in range(tensor.shape[0])]
            with torch.no_grad():
                output = self.model(inputs=tensor_list, prediction_length=pred_length)
            forecasts.append(output.quantile_predictions)
        return forecasts

    def postprocess(self, outputs: list[torch.Tensor], **infer_kwargs) -> list[torch.Tensor]:
        """Reduce quantile forecasts to point forecasts of [variates, prediction_length].

        The model returns [variates, quantiles, prediction_length] per input
        (see forecast_single_step: each sample is [quantiles, prediction_length]
        and samples are stacked on dim 0).
        """
        final_outputs = []
        for output in outputs:
            # Remove batch dimension if it is just a single batch.
            if output.ndim == 4:
                output = output.squeeze(0)
            # FIX: the quantile axis is dim=-2 ([variates, quantiles, length]);
            # the previous mean(dim=-1) averaged over the time axis instead.
            point_forecast = output.mean(dim=-2)
            final_outputs.append(point_forecast)

        return super().postprocess(final_outputs, **infer_kwargs)
from typing import Optional, Type

import torch
import torch.nn as nn
import torch.nn.functional as F


def make_attn_mask(query_pad: torch.Tensor, key_pad: torch.Tensor) -> torch.Tensor:
    """Build an additive (B, Q, K) attention mask from padding masks.

    Args:
        query_pad: (B, Q) bool or 0/1 tensor; 1/True marks a padded query.
        key_pad: (B, K) bool or 0/1 tensor; 1/True marks a padded key.

    Returns:
        (B, Q, K) float tensor with -inf at positions where either side is
        padded and 0.0 elsewhere (suitable for SDPA's ``attn_mask``).
    """
    # A (q, k) pair is invalid when either the query or the key is padded.
    invalid = query_pad.bool().unsqueeze(-1) | key_pad.bool().unsqueeze(-2)
    additive = torch.zeros(invalid.shape, dtype=torch.float32, device=invalid.device)
    return additive.masked_fill(invalid, float("-inf"))


class MLP(nn.Module):
    """Plain feed-forward stack: Linear -> act (-> [drop, norm, Linear, act]*) -> Linear."""

    def __init__(
        self,
        in_dim,
        out_dim,
        hidden_dim=256,
        num_hidden_layers=1,
        dropout=0,
        norm=False,
        activation=nn.GELU(approximate="tanh"),
        output_activation=nn.Identity(),
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()

        def maybe_norm():
            # Normalization layers are only inserted when requested.
            return norm_layer(hidden_dim) if norm else nn.Identity()

        stack = [nn.Linear(in_dim, hidden_dim), activation]
        for _ in range(num_hidden_layers - 1):
            stack += [nn.Dropout(dropout), maybe_norm(), nn.Linear(hidden_dim, hidden_dim), activation]
        stack += [nn.Dropout(dropout), maybe_norm(), nn.Linear(hidden_dim, out_dim), output_activation]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        return self.layers(x)


class SwiGLU(nn.Module):
    """Gated feed-forward block: fc3(fc1(x) * SiLU(fc2(x)))."""

    def __init__(self, in_dim, out_dim, hidden_dim=384, dropout=0):
        super().__init__()
        # Conventional 2/3 shrink so parameter count matches a plain MLP.
        gated_dim = round(hidden_dim * 2 / 3)
        self.fc1 = nn.Linear(in_dim, gated_dim)
        self.fc2 = nn.Linear(in_dim, gated_dim)
        self.fc3 = nn.Linear(gated_dim, out_dim)
        self.activation = nn.SiLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gated = self.fc1(x) * self.activation(self.fc2(x))
        return self.dropout(self.fc3(gated))
class Attention(nn.Module):
    """Multi-head self-attention using F.scaled_dot_product_attention.

    Supports 3D input (B, N, C) and 4D input (B, M, N, C); in the 4D case
    attention is computed independently per M slice and a provided attn_mask
    is broadcast over the head dimension.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Kept for reference; SDPA applies its own 1/sqrt(head_dim) scaling.
        self.scale = self.head_dim**-0.5

        # Fused projection producing q, k and v in a single matmul.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
        """Self-attend over the sequence axis; attn_mask is additive (0 / -inf)."""
        if x.ndim == 3:
            B, N, C = x.shape
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # (B, num_heads, N, head_dim)
            q, k = self.q_norm(q), self.k_norm(k)
            # Attention dropout is only active in training mode.
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                attn_mask=attn_mask,
            )
            x = x.transpose(1, 2).reshape(B, N, C)
        elif x.ndim == 4:
            B, M, N, C = x.shape
            qkv = self.qkv(x).reshape(B, M, N, 3, self.num_heads, self.head_dim).permute(3, 0, 4, 1, 2, 5)
            q, k, v = qkv.unbind(0)  # (B, num_heads, M, N, head_dim)
            q, k = self.q_norm(q), self.k_norm(k)
            # print('q', q.shape, 'k', k.shape, 'v', v.shape, 'attn_mask', attn_mask.shape if attn_mask is not None else "None")
            # The unsqueeze broadcasts the (B, 1, N, N)-style mask over heads.
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                attn_mask=attn_mask.unsqueeze(1) if attn_mask is not None else None,
            )
            x = x.permute(0, 2, 3, 1, 4).reshape(B, M, N, C)
        else:
            raise ValueError(f"Unsupported input dimension: {x.ndim}")
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from x, keys/values from memory m.

    k and v are projected into the SAME head dimension as q even when the
    memory has a different feature width (kv_dim).
    """

    def __init__(
        self,
        q_dim: int,  # dim of x
        kv_dim: Optional[int] = None,  # dim of m (defaults to q_dim)
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        kv_dim = kv_dim if kv_dim is not None else q_dim
        assert q_dim % num_heads == 0, "q_dim must be divisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = q_dim // num_heads

        self.q = nn.Linear(q_dim, q_dim, bias=qkv_bias)
        self.kv = nn.Linear(kv_dim, 2 * q_dim, bias=qkv_bias)  # produce k and v in the SAME head dim as q
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(q_dim, q_dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
        self,
        x: torch.Tensor,  # (B, Nq, q_dim)
        m: torch.Tensor,  # (B, Nk, kv_dim)
        attn_mask: Optional[torch.Tensor] = None,  # broadcastable to (B, num_heads, Nq, Nk) or (Nq, Nk)
        is_causal: bool = False,
    ) -> torch.Tensor:
        """Cross-attend x onto m; supports 3D (B, Nq, C) and 4D (B, M, Nq, C) queries."""
        if x.ndim == 3:
            B, Nq, Cq = x.shape
            _, Nk, _ = m.shape
            q = self.q(x).reshape(B, Nq, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # (B, H, Nq, Hd)
            kv = self.kv(m).reshape(B, Nk, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv.unbind(0)  # (B, H, Nk, Hd)
            q, k = self.q_norm(q), self.k_norm(k)
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                is_causal=is_causal,
            )  # (B, H, Nq, Hd)
            x = x.transpose(1, 2).reshape(B, Nq, Cq)  # back to (B, Nq, q_dim)
        elif x.ndim == 4:
            B, M, Nq, Cq = x.shape
            _, Nk, _ = m.shape
            q = self.q(x).reshape(B, M, Nq, self.num_heads, self.head_dim).permute(0, 3, 1, 2, 4)  # (B, H, M, Nq, Hd)
            kv = self.kv(m).reshape(B, Nk, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv.unbind(0)  # (B, H, Nk, Hd)
            q, k = self.q_norm(q), self.k_norm(k)
            # k/v gain a singleton M axis so the shared memory broadcasts
            # over every query slice.
            x = F.scaled_dot_product_attention(
                q,
                k.unsqueeze(2),
                v.unsqueeze(2),
                attn_mask=attn_mask.unsqueeze(1) if attn_mask is not None else None,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                is_causal=is_causal,
            )  # (B, H, M, Nq, Hd)
            x = x.permute(0, 2, 3, 1, 4).reshape(B, M, Nq, Cq)
        else:
            raise ValueError(f"Unsupported input dimension: {x.ndim}")
        x = self.proj_drop(self.proj(x))
        return x


class TransformerBlock(nn.Module):
    """
    A standard Transformer block.

    With norm_first=True this is a pre-norm block (norm inside the residual
    branch); otherwise post-norm (norm applied after the residual sum).
    """

    def __init__(
        self,
        d_model,
        num_heads,
        mlp_ratio=4.0,
        dropout=0.1,
        norm_first=True,
        norm_layer=nn.LayerNorm,
        mlp_type="mlp",
    ):
        super().__init__()
        self.norm_first = norm_first
        self.norm1 = norm_layer(d_model, elementwise_affine=True, eps=1e-6)
        self.attn = Attention(d_model, num_heads, qkv_bias=True, attn_drop=dropout, proj_drop=dropout)
        self.norm2 = norm_layer(d_model, elementwise_affine=True, eps=1e-6)
        if mlp_type == "swiglu":
            self.mlp = SwiGLU(d_model, d_model, hidden_dim=int(mlp_ratio * d_model), dropout=dropout)
        elif mlp_type == "mlp":
            self.mlp = MLP(
                in_dim=d_model,
                out_dim=d_model,
                hidden_dim=int(mlp_ratio * d_model),
                dropout=dropout,
            )
        else:
            raise ValueError(f"Unsupported MLP type: {mlp_type}")
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        """Apply self-attention and MLP sublayers with residual connections."""
        if self.norm_first:
            x = x + self.attn(self.norm1(x), attn_mask)
            x = x + self.dropout(self.mlp(self.norm2(x)))
        else:
            x = self.norm1(x + self.attn(x, attn_mask))
            x = self.norm2(x + self.dropout(self.mlp(x)))
        return x
class TransformerBlockCrossAttention(nn.Module):
    """Transformer block whose attention attends to a conditioning memory `m`."""

    def __init__(
        self,
        d_model,
        num_heads,
        d_cond=None,
        mlp_ratio=4.0,
        dropout=0.1,
        norm_first=True,
        norm_layer=nn.LayerNorm,
        mlp_type="mlp",
    ):
        super().__init__()
        if d_cond is None:
            # Memory defaults to the same width as the queries.
            d_cond = d_model
        self.norm_first = norm_first
        self.norm1 = norm_layer(d_model, elementwise_affine=True, eps=1e-6)
        self.attn = CrossAttention(
            d_model,
            d_cond,
            num_heads,
            qkv_bias=True,
            attn_drop=dropout,
            proj_drop=dropout,
        )
        self.norm2 = norm_layer(d_model, elementwise_affine=True, eps=1e-6)
        hidden = int(mlp_ratio * d_model)
        if mlp_type == "swiglu":
            self.mlp = SwiGLU(d_model, d_model, hidden_dim=hidden, dropout=dropout)
        elif mlp_type == "mlp":
            self.mlp = MLP(in_dim=d_model, out_dim=d_model, hidden_dim=hidden, dropout=dropout)
        else:
            raise ValueError(f"Unsupported MLP type: {mlp_type}")
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, m, attn_mask=None):
        """Cross-attend x onto memory m, then apply the MLP sublayer."""
        if self.norm_first:
            x = x + self.attn(self.norm1(x), m, attn_mask)
            return x + self.dropout(self.mlp(self.norm2(x)))
        x = self.norm1(x + self.attn(x, m, attn_mask))
        return self.norm2(x + self.dropout(self.mlp(x)))


class RevIN(nn.Module):
    """Reversible instance normalization with optional asinh squashing.

    Statistics (mean/std) are stored on the module by fit_transform and
    reused by transform / inverse_transform. All statistic math runs in
    full precision (autocast disabled).
    """

    def __init__(self, dim=-1, std_min=1e-5, max_val=100, use_sinh=False):
        super().__init__()
        self.dim = dim
        self.std_min = std_min
        self.max_val = max_val
        self.use_sinh = use_sinh

    def fit_transform(self, x, mask=None):
        """Fit mean/std on x (observed positions only when mask is given), then normalize."""
        with torch.autocast(device_type="cuda", enabled=False):
            self._get_statistics(x, mask)
            return self.transform(x)

    def transform(self, x):
        """Normalize x with the stored statistics; optionally squash with asinh."""
        with torch.autocast(device_type="cuda", enabled=False):
            out = (x - self.mean) / self.std
            if self.use_sinh:
                out = torch.asinh(out)
            return out

    def inverse_transform(self, x):
        """Undo transform(); handles an extra middle dim on x (e.g. quantiles)."""
        with torch.autocast(device_type="cuda", enabled=False):
            if self.use_sinh:
                x = torch.sinh(x)
            if x.ndim != self.mean.ndim:
                # Stats were fit on (B, T); x carries an extra middle axis.
                return x * self.std.unsqueeze(1) + self.mean.unsqueeze(1)
            return x * self.std + self.mean

    def get_statistics(self):
        return self.mean, self.std

    def _get_statistics(self, x, mask=None):
        """Compute mean/std along self.dim; mask marks positions to EXCLUDE."""
        if mask is None:
            mu = x.mean(dim=self.dim, keepdim=True)
            sigma = x.std(dim=self.dim, keepdim=True)
        else:
            keep = (~mask.bool()).float()
            count = keep.sum(dim=self.dim, keepdim=True).clamp(min=1)  # avoid division by zero
            mu = (x * keep).sum(dim=self.dim, keepdim=True) / count
            sigma = (((x - mu) * keep) ** 2).sum(dim=self.dim, keepdim=True) / count
            sigma = sigma.sqrt()
        self.mean = mu
        # Degenerate (near-constant) series fall back to unit scale.
        self.std = torch.where(sigma > self.std_min, sigma, torch.ones_like(sigma))


class CausalRevIN(nn.Module):
    """Causal RevIN: per-step cumulative statistics for parallel training.

    Enables parallel predictions during training of FlowState. With a mask,
    masked positions are excluded from the running mean/std (contiguous patch
    masking, nans as missing values).
    """

    def __init__(self, dim=-1, std_min=1e-5, max_val=100):
        super().__init__()
        self.dim = dim
        self.std_min = std_min
        self.max_val = max_val

    def fit_transform(self, x, mask=None):
        self._get_statistics(x, mask)
        return self.transform(x)

    def transform(self, x):
        # Clamp keeps early, high-variance steps from exploding.
        return torch.clamp((x - self.mean) / self.std, min=-self.max_val, max=self.max_val)

    def inverse_transform(self, x):
        if x.ndim == 2:
            return x * self.std + self.mean
        if x.ndim == 3:
            return x * self.std.unsqueeze(-1) + self.mean.unsqueeze(-1)
        raise ValueError(f"Invalid input dimension: {x.shape}")

    def get_statistics(self):
        return self.mean, self.std

    def _get_statistics(self, x, mask=None):
        """Running (causal) mean/std along dim=1; mask marks positions to EXCLUDE."""
        if mask is not None:
            counts = torch.cumsum(1 - mask.float(), dim=1)
            counts = torch.where(counts == 0, 1.0, counts)
            keep = 1 - mask.float()
        else:
            counts = torch.arange(1, x.shape[1] + 1, device=x.device)
            keep = 1
        self.mean = (torch.cumsum(x, dim=1) / counts).detach()
        self.std = torch.sqrt(torch.cumsum(((x - self.mean) * keep) ** 2, 1) / counts).detach()
        self.std = torch.where(self.std > self.std_min, self.std, torch.ones_like(self.std))

    def set_statistics(self, mean, std):
        self.mean = mean
        self.std = std
def forecast(self, inputs, **infer_kwargs) -> list[torch.Tensor]:
    """Run the PatchTST-FM-R1 forward pass for each input in the batch.

    The model expects a list of 1-D tensors (one per variate) and returns
    a PatchTSTFMPredictionOutput with a `quantile_predictions` attribute.
    """
    # Prediction horizon is the same for every item; hoist out of the loop.
    pred_length = infer_kwargs.get("output_length", 96)
    forecasts = []
    for item in inputs:
        # Move to device and convert [n_variates, length] -> list of 1-D
        # tensors, as required by PatchTSTFMForPrediction.forward().
        tensor = item["targets"].to(self.device)
        tensor_list = [tensor[i] for i in range(tensor.shape[0])]
        with torch.no_grad():
            output = self.model(inputs=tensor_list, prediction_length=pred_length)
        forecasts.append(output.quantile_predictions)
    return forecasts


def seed_everything(seed: int = 42):
    """Seed every RNG used in training so runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # no-op on CPU-only builds
    # Trade cuDNN autotuning speed for deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def count_parameters(model):
    """Return (total, trainable, frozen) parameter counts of *model*."""
    total_params = sum(p.numel() for p in model.parameters())
    grad_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, grad_params, total_params - grad_params


def to_hms(seconds):
    """Format a duration in seconds as a zero-padded "HH:MM:SS" string."""
    h = int(seconds // 3600)
    m = int((seconds % 3600) // 60)
    s = int(seconds % 60)
    return f"{h:02d}:{m:02d}:{s:02d}"


def hms_to_seconds(hms):
    """Inverse of to_hms: parse "HH:MM:SS" into total seconds."""
    h, m, s = map(int, hms.split(":"))
    return h * 3600 + m * 60 + s


def compute_remaining_time(start_time, current_step, max_steps):
    """Return an "elapsed<remaining (rate s/step)" progress string.

    Fix: guards against current_step == 0, which previously raised
    ZeroDivisionError when called before the first step completed.
    """
    elapsed_time = time.time() - start_time
    steps_done = max(current_step, 1)  # avoid division by zero at step 0
    remaining_steps = max_steps - current_step
    remaining_time = elapsed_time * remaining_steps / steps_done
    second_per_step = elapsed_time / steps_done
    return f"{to_hms(elapsed_time)}<{to_hms(remaining_time)} ({second_per_step:.2f}s/step)"


class Timer:
    """Progress timer: same string format as compute_remaining_time, but the
    per-step rate is estimated from the interval since the previous call."""

    def __init__(self, start_step: int = 0, max_step: int = 0):
        now = time.time()
        self.start_time = now
        self.last_time = now
        self.last_step = start_step
        self.max_step = max_step

    def __call__(self, step):
        now = time.time()
        delta_time = now - self.last_time
        delta_step = max(step - self.last_step, 1)  # guard repeated/first step
        remaining_time = delta_time * (self.max_step - step) / delta_step
        second_per_step = delta_time / delta_step
        elapsed_time = now - self.start_time
        self.last_time = now
        self.last_step = step
        return f"{to_hms(elapsed_time)}<{to_hms(remaining_time)} ({second_per_step:.2f}s/step)"
class StandardScaler:
    """A numpy re-implementation of sklearn's StandardScaler.

    Standardizes features by removing the per-column mean and scaling to
    unit variance. Expects 2-D input of shape [n_samples, n_features].
    """

    def __init__(self, with_mean=True, with_std=True):
        self.with_mean = with_mean
        self.with_std = with_std
        self.mean_ = None
        self.scale_ = None
        self.var_ = None
        self.n_samples_seen_ = 0

    def fit(self, X):
        """Learn per-feature mean and scale from X; return self."""
        X = np.array(X, dtype=np.float64)

        if self.with_mean:
            self.mean_ = np.mean(X, axis=0)
        else:
            self.mean_ = np.zeros(X.shape[1], dtype=np.float64)

        if self.with_std:
            self.var_ = np.var(X, axis=0)
            scale = np.sqrt(self.var_)
            # Constant columns get unit scale so transform is a no-op there.
            self.scale_ = np.where(scale == 0, 1.0, scale)
        else:
            self.var_ = np.ones(X.shape[1], dtype=np.float64)
            self.scale_ = np.ones(X.shape[1], dtype=np.float64)

        self.n_samples_seen_ = X.shape[0]
        return self

    def transform(self, X):
        """Center and scale X using the fitted statistics."""
        X = np.array(X, dtype=np.float64)
        if self.with_mean:
            X = X - self.mean_
        if self.with_std:
            X = X / self.scale_
        return X

    def fit_transform(self, X):
        """Fit to X, then return the standardized X."""
        return self.fit(X).transform(X)

    def inverse_transform(self, X):
        """Map standardized data back to the original scale."""
        X = np.array(X, dtype=np.float64)
        if self.with_std:
            X = X * self.scale_
        if self.with_mean:
            X = X + self.mean_
        return X


class CSVLogger:
    """Simple CSV logger that stores training metrics and figure paths."""

    def __init__(self, log_dir: str):
        """Create a timestamped run directory under *log_dir*.

        Fix: all directories are created with exist_ok=True, so a second
        logger started within the same second (the run name has 1-second
        resolution) or a rerun over a leftover directory no longer raises
        FileExistsError.
        """
        self.log_dir = f"{log_dir}/run_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}"
        os.makedirs(self.log_dir, exist_ok=True)

        # Scalar metrics are buffered in memory and flushed by save().
        self.scalar_data = []
        self.scalar_file = os.path.join(self.log_dir, "scalars.csv")

        self.fig_dir = os.path.join(self.log_dir, "figures")
        os.makedirs(self.fig_dir, exist_ok=True)
        os.makedirs(f"{self.fig_dir}/TRAIN", exist_ok=True)
        os.makedirs(f"{self.fig_dir}/VAL", exist_ok=True)

    def log_scalar(self, tag: str, value: float, step: int):
        """Buffer one (timestamp, step, tag, value) row."""
        self.scalar_data.append(
            {"timestamp": datetime.now(), "step": step, "tag": tag, "value": value}
        )

    def save(self):
        """Flush all buffered scalar rows to scalars.csv."""
        pd.DataFrame(self.scalar_data).to_csv(self.scalar_file, index=False)

    def log_figure(self, tag: str, figure, step: int):
        """Save *figure* under the figures directory as "<tag>_<step>.png"."""
        figure.savefig(os.path.join(self.fig_dir, f"{tag}_{step}.png"))