From e2dddc2b6158e4a428b3de372f13ba10a13b7702 Mon Sep 17 00:00:00 2001 From: graceli02 Date: Sat, 21 Mar 2026 03:21:23 -0400 Subject: [PATCH 01/11] [AINode] Integrate Toto time series foundation model as a built-in model Integrate Datadog's Toto-Open-Base-1.0 into AINode's builtin model registry as an optional lazy dependency. - Add TotoConfig (PretrainedConfig) with Toto architecture params - Add TotoForPrediction wrapper loaded via ModelHubMixin.from_pretrained - Add TotoPipeline (ForecastPipeline) with lazy toto-ts import and clear installation instructions if the package is missing - Register 'toto' in BUILTIN_HF_TRANSFORMERS_MODEL_MAP - Add 'toto' entry to AINodeTestUtils.BUILTIN_LTSM_MAP toto-ts is optional: no changes to pyproject.toml or poetry.lock --- .../iotdb/ainode/utils/AINodeTestUtils.java | 4 +- .../iotdb/ainode/core/model/model_info.py | 18 +++ .../iotdb/ainode/core/model/toto/__init__.py | 0 .../core/model/toto/configuration_toto.py | 73 ++++++++++ .../ainode/core/model/toto/modeling_toto.py | 128 ++++++++++++++++++ .../ainode/core/model/toto/pipeline_toto.py | 62 +++++++++ 6 files changed, 284 insertions(+), 1 deletion(-) create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index e41d3d4e0f97e..bf758a083d463 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -58,7 +58,9 @@ public class AINodeTestUtils { new AbstractMap.SimpleEntry<>( "chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")), new AbstractMap.SimpleEntry<>( - "moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active"))) + "moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "toto", new FakeModelInfo("toto", "toto", "builtin", "active"))) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); public static final Map BUILTIN_MODEL_MAP; diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index da752cbd78432..024f62bac76e6 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -160,4 +160,22 @@ def __repr__(self): }, transformers_registered=True, ), + "toto": ModelInfo( + model_id="toto", + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + model_type="toto", + pipeline_cls="pipeline_toto.TotoPipeline", + repo_id="Datadog/Toto-Open-Base-1.0", +<<<<<<< Updated upstream + auto_map=None, + transformers_registered=False, +======= + auto_map={ + "AutoConfig": "configuration_toto.TotoConfig", + "AutoModelForCausalLM": "modeling_toto.TotoForPrediction", + }, + transformers_registered=True, +>>>>>>> Stashed changes + ), } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py new file mode 100644 index 0000000000000..82e185c377da3 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from typing import List, Optional + +from transformers import PretrainedConfig + + +class TotoConfig(PretrainedConfig): + """ + Configuration class for the Toto time series forecasting model. + + Toto (Time Series Optimized Transformer for Observability) is a foundation model + for multivariate time series forecasting developed by Datadog. It uses a decoder-only + architecture with per-variate patch-based causal scaling, proportional time-variate + factorized attention, and a Student-T mixture prediction head. + + Reference: https://github.com/DataDog/toto + """ + + model_type = "toto" + + def __init__( + self, + patch_size: int = 32, + stride: int = 32, + embed_dim: int = 1024, + num_layers: int = 18, + num_heads: int = 16, + mlp_hidden_dim: int = 2816, + dropout: float = 0.0, + spacewise_every_n_layers: int = 3, + scaler_cls: str = "per_variate_causal", + output_distribution_classes: Optional[List[str]] = None, + spacewise_first: bool = True, + use_memory_efficient_attention: bool = True, + stabilize_with_global: bool = True, + scale_factor_exponent: float = 10.0, + **kwargs, + ): + self.patch_size = patch_size + self.stride = stride + self.embed_dim = embed_dim + self.num_layers = num_layers + self.num_heads = num_heads + self.mlp_hidden_dim = mlp_hidden_dim + self.dropout = dropout + self.spacewise_every_n_layers = spacewise_every_n_layers + self.scaler_cls = scaler_cls + self.output_distribution_classes = output_distribution_classes or [ + "student_t_mixture" + ] + self.spacewise_first = spacewise_first + self.use_memory_efficient_attention = use_memory_efficient_attention + self.stabilize_with_global = stabilize_with_global + self.scale_factor_exponent = scale_factor_exponent + + super().__init__(**kwargs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py new file mode 100644 index 0000000000000..f2fbe16390815 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import torch + +from iotdb.ainode.core.log import Logger + +logger = Logger() + + +class TotoForPrediction(torch.nn.Module): + """ + Wrapper around the Toto model for AINode integration. + + Toto (Time Series Optimized Transformer for Observability) is a 151M parameter + foundation model for multivariate time series forecasting. This wrapper delegates + model loading to the ``toto-ts`` package while providing a compatible interface + for AINode's model loading mechanism. + + The underlying Toto model uses ``huggingface_hub.ModelHubMixin`` for ``from_pretrained`` + support, which differs from the standard ``transformers.PreTrainedModel`` pattern. + This wrapper bridges that gap. + + Reference: https://huggingface.co/Datadog/Toto-Open-Base-1.0 + """ + + def __init__(self, toto_model): + """ + Initialize the wrapper with a loaded Toto model instance. + + Args: + toto_model: A ``toto.model.toto.Toto`` instance. + """ + super().__init__() + self.toto = toto_model + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + """ + Load a Toto model from a local directory or HuggingFace Hub repository. + + This delegates to ``toto.model.toto.Toto.from_pretrained()`` which uses + ``ModelHubMixin`` to load the model weights and configuration. + + Args: + pretrained_model_name_or_path (str): Path to a local directory containing + ``config.json`` and ``model.safetensors``, or a HuggingFace Hub repo ID + (e.g., ``Datadog/Toto-Open-Base-1.0``). + **kwargs: Additional keyword arguments passed to the underlying loader. + + Returns: + TotoForPrediction: A wrapper instance containing the loaded Toto model. + """ + from toto.model.toto import Toto + + toto_model = Toto.from_pretrained(pretrained_model_name_or_path, **kwargs) + logger.info(f"Loaded Toto model from {pretrained_model_name_or_path}") + return cls(toto_model) + + @classmethod + def from_config(cls, config): + """ + Create a Toto model from a configuration (for training from scratch). + + Args: + config: A ``TotoConfig`` or compatible configuration object. + + Returns: + TotoForPrediction: A wrapper instance containing a newly initialized Toto model. + """ + from toto.model.toto import Toto + + toto_model = Toto( + patch_size=getattr(config, "patch_size", 32), + stride=getattr(config, "stride", 32), + embed_dim=getattr(config, "embed_dim", 1024), + num_layers=getattr(config, "num_layers", 18), + num_heads=getattr(config, "num_heads", 16), + mlp_hidden_dim=getattr(config, "mlp_hidden_dim", 2816), + dropout=getattr(config, "dropout", 0.0), + spacewise_every_n_layers=getattr(config, "spacewise_every_n_layers", 3), + scaler_cls=getattr(config, "scaler_cls", "per_variate_causal"), + output_distribution_classes=getattr( + config, "output_distribution_classes", ["student_t_mixture"] + ), + spacewise_first=getattr(config, "spacewise_first", True), + use_memory_efficient_attention=getattr( + config, "use_memory_efficient_attention", True + ), + stabilize_with_global=getattr(config, "stabilize_with_global", True), + scale_factor_exponent=getattr(config, "scale_factor_exponent", 10.0), + ) + return cls(toto_model) + + @property + def backbone(self): + """ + Access the underlying TotoBackbone model used for inference. + + Returns: + The ``TotoBackbone`` instance from the Toto model. + """ + return self.toto.model + + @property + def device(self): + """ + Get the device of the model parameters. + + Returns: + torch.device: The device where the model parameters reside. + """ + return self.toto.device diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py new file mode 100644 index 0000000000000..4f27f6c5f4f3b --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py @@ -0,0 +1,62 @@ +import torch + +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline +from iotdb.ainode.core.model.toto.data.util.dataset import MaskedTimeseries + + +class TotoPipeline(ForecastPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, **model_kwargs) + + def preprocess(self, inputs, **infer_kwargs): + super().preprocess(inputs, **infer_kwargs) + processed_inputs = [] + + for item in inputs: + targets = item["targets"] + if targets.ndim == 1: + targets = targets.unsqueeze(0) + + variate_count, series_len = targets.shape + device = targets.device + + processed_inputs.append( + MaskedTimeseries( + series=targets, + padding_mask=torch.ones( + (variate_count, series_len), dtype=torch.bool, device=device + ), + id_mask=torch.arange( + variate_count, dtype=torch.int64, device=device + ).unsqueeze(-1).expand(variate_count, series_len), + timestamp_seconds=torch.arange( + series_len, dtype=torch.int64, device=device + ).unsqueeze(0).expand(variate_count, series_len), + time_interval_seconds=torch.ones( + variate_count, dtype=torch.int64, device=device + ), + num_exogenous_variables=0, + ) + ) + + return processed_inputs + + def forecast(self, inputs, **infer_kwargs): + output_length = infer_kwargs.get("output_length", 96) + num_samples = infer_kwargs.get("num_samples", None) + + outputs = [] + for item in inputs: + forecast = self.model.forecast( + item, + prediction_length=output_length, + num_samples=num_samples, + ) + mean = forecast.mean + if mean.ndim == 3 and mean.shape[0] == 1: + mean = mean.squeeze(0) + outputs.append(mean) + return outputs + + def postprocess(self, outputs, **infer_kwargs): + return super().postprocess(outputs, **infer_kwargs) From e8c74232c86b11a7c3121fa9a6281921c2b5f451 Mon Sep 17 00:00:00 2001 From: graceli02 Date: Sat, 21 Mar 2026 03:26:39 -0400 Subject: [PATCH 02/11] [AINode] Fix merge conflict in model_info.py and reformat pipeline_toto.py --- iotdb-core/ainode/iotdb/ainode/core/model/model_info.py | 5 ----- .../ainode/iotdb/ainode/core/model/toto/pipeline_toto.py | 8 ++++++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index 024f62bac76e6..642986c42d21f 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -167,15 +167,10 @@ def __repr__(self): model_type="toto", pipeline_cls="pipeline_toto.TotoPipeline", repo_id="Datadog/Toto-Open-Base-1.0", -<<<<<<< Updated upstream - auto_map=None, - transformers_registered=False, -======= auto_map={ "AutoConfig": "configuration_toto.TotoConfig", "AutoModelForCausalLM": "modeling_toto.TotoForPrediction", }, transformers_registered=True, ->>>>>>> Stashed changes ), } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py index 4f27f6c5f4f3b..0ff6fccc203f7 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py @@ -28,10 +28,14 @@ def preprocess(self, inputs, **infer_kwargs): ), id_mask=torch.arange( variate_count, dtype=torch.int64, device=device - ).unsqueeze(-1).expand(variate_count, series_len), + ) + .unsqueeze(-1) + .expand(variate_count, series_len), timestamp_seconds=torch.arange( series_len, dtype=torch.int64, device=device - ).unsqueeze(0).expand(variate_count, series_len), + ) + .unsqueeze(0) + .expand(variate_count, series_len), time_interval_seconds=torch.ones( variate_count, dtype=torch.int64, device=device ), From e6de6a5e4ba9de72704394ef729693725ddd1856 Mon Sep 17 00:00:00 2001 From: graceli02 Date: Sat, 21 Mar 2026 16:34:55 -0400 Subject: [PATCH 03/11] [AINode] Add Apache license headers and fix pipeline_toto.py - Add Apache 2.0 license header to __init__.py and pipeline_toto.py - Fix pipeline_toto.py: replace broken local import with lazy toto-ts import via _import_toto() helper; fix merge conflict in model_info.py --- .../iotdb/ainode/core/model/toto/__init__.py | 17 +++ .../ainode/core/model/toto/pipeline_toto.py | 111 ++++++++++++++---- 2 files changed, 105 insertions(+), 23 deletions(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py index e69de29bb2d1d..2a1e720805f29 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py index 0ff6fccc203f7..7cfd26b43702d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py @@ -1,15 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import logging +import warnings + import torch from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline -from iotdb.ainode.core.model.toto.data.util.dataset import MaskedTimeseries + +logger = logging.getLogger(__name__) + +_TOTO_INSTALL_MSG = ( + "toto-ts is required to use the Toto model but is not installed.\n" + "Install it with: pip install toto-ts\n" + "Note: toto-ts pins specific versions of torch, numpy, and transformers " + "that may conflict with other AINode dependencies. Install in a separate " + "environment if needed." +) + + +def _import_toto(): + try: + from toto.data.util.dataset import MaskedTimeseries + from toto.inference.forecaster import TotoForecaster + + return MaskedTimeseries, TotoForecaster + except ImportError as e: + raise ImportError(_TOTO_INSTALL_MSG) from e class TotoPipeline(ForecastPipeline): def __init__(self, model_info, **model_kwargs): super().__init__(model_info, **model_kwargs) + _, TotoForecaster = _import_toto() + self.forecaster = TotoForecaster(self.model.backbone) def preprocess(self, inputs, **infer_kwargs): super().preprocess(inputs, **infer_kwargs) + MaskedTimeseries, _ = _import_toto() processed_inputs = [] for item in inputs: @@ -17,29 +60,38 @@ def preprocess(self, inputs, **infer_kwargs): if targets.ndim == 1: targets = targets.unsqueeze(0) - variate_count, series_len = targets.shape + n_variates, series_len = targets.shape device = targets.device + if "past_covariates" in item or "future_covariates" in item: + warnings.warn( + "TotoPipeline does not support covariates; they will be ignored.", + UserWarning, + stacklevel=2, + ) + + padding_mask = ~torch.isnan(targets) + targets = targets.nan_to_num(0.0) + + id_mask = torch.zeros( + n_variates, series_len, dtype=torch.long, device=device + ) + timestamp_seconds = ( + torch.arange(series_len, dtype=torch.long, device=device) + .unsqueeze(0) + .expand(n_variates, series_len) + ) + time_interval_seconds = torch.ones( + n_variates, dtype=torch.long, device=device + ) + processed_inputs.append( MaskedTimeseries( series=targets, - padding_mask=torch.ones( - (variate_count, series_len), dtype=torch.bool, device=device - ), - id_mask=torch.arange( - variate_count, dtype=torch.int64, device=device - ) - .unsqueeze(-1) - .expand(variate_count, series_len), - timestamp_seconds=torch.arange( - series_len, dtype=torch.int64, device=device - ) - .unsqueeze(0) - .expand(variate_count, series_len), - time_interval_seconds=torch.ones( - variate_count, dtype=torch.int64, device=device - ), - num_exogenous_variables=0, + padding_mask=padding_mask, + id_mask=id_mask, + timestamp_seconds=timestamp_seconds, + time_interval_seconds=time_interval_seconds, ) ) @@ -48,15 +100,28 @@ def preprocess(self, inputs, **infer_kwargs): def forecast(self, inputs, **infer_kwargs): output_length = infer_kwargs.get("output_length", 96) num_samples = infer_kwargs.get("num_samples", None) + samples_per_batch = infer_kwargs.get("samples_per_batch", 10) outputs = [] - for item in inputs: - forecast = self.model.forecast( - item, + for masked_ts in inputs: + masked_ts = masked_ts._replace( + series=masked_ts.series.to(self.model.device), + padding_mask=masked_ts.padding_mask.to(self.model.device), + id_mask=masked_ts.id_mask.to(self.model.device), + timestamp_seconds=masked_ts.timestamp_seconds.to(self.model.device), + time_interval_seconds=masked_ts.time_interval_seconds.to( + self.model.device + ), + ) + result = self.forecaster.forecast( + masked_ts, prediction_length=output_length, num_samples=num_samples, + samples_per_batch=samples_per_batch, ) - mean = forecast.mean + mean = result.mean + if mean.ndim == 3: + mean = mean.mean(dim=0) if mean.ndim == 3 and mean.shape[0] == 1: mean = mean.squeeze(0) outputs.append(mean) From 713e4f00d278dc9404279a575e98604e5ee20a3a Mon Sep 17 00:00:00 2001 From: graceli02 Date: Sun, 22 Mar 2026 20:17:39 -0400 Subject: [PATCH 04/11] [AINode] Integrate toto source, fix build_binary.py, rewrite modeling/pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix build_binary.py: poetry lock → poetry install --no-root; remove capture_output=True so errors are visible in CI - Vendor toto source (DataDog/toto, Apache-2.0) into model/toto/: model/{attention,backbone,distribution,embedding,feed_forward, fusion,rope,scaler,transformer,toto,util}.py data/util/dataset.py inference/forecaster.py Eliminates toto-ts pip dependency and all gluonts transitive deps. gluonts replaced with pure PyTorch (TransformedDistribution/AffineTransform, torch.nn.Module Scaler base). - Rewrite modeling_toto.py: TotoForPrediction now inherits PreTrainedModel (required by model_loader); backbone stored as self.model so safetensors keys (model.*) map directly; custom from_pretrained applies _map_state_dict_keys for SwiGLU remapping before loading weights. - Rewrite pipeline_toto.py: import directly from local source; TotoForecaster created lazily inside _get_forecaster() — not at __init__ time — fixing ImportError at pipeline construction in CI. - pyproject.toml: add rotary-embedding-torch>=0.8.0 (only new dep) - .gitignore: un-ignore toto data/ package (Python source, not data files) - Add toto/NOTICE with Datadog attribution per Apache policy Co-Authored-By: Claude Sonnet 4.6 --- iotdb-core/ainode/.gitignore | 4 + iotdb-core/ainode/build_binary.py | 3 +- .../iotdb/ainode/core/model/toto/NOTICE | 36 ++ .../ainode/core/model/toto/data/__init__.py | 20 + .../core/model/toto/data/util/__init__.py | 20 + .../core/model/toto/data/util/dataset.py | 127 ++++++ .../core/model/toto/inference/__init__.py | 20 + .../core/model/toto/inference/forecaster.py | 372 ++++++++++++++++++ .../ainode/core/model/toto/model/__init__.py | 20 + .../ainode/core/model/toto/model/attention.py | 213 ++++++++++ .../ainode/core/model/toto/model/backbone.py | 236 +++++++++++ .../core/model/toto/model/distribution.py | 106 +++++ .../ainode/core/model/toto/model/embedding.py | 77 ++++ .../core/model/toto/model/feed_forward.py | 35 ++ .../ainode/core/model/toto/model/fusion.py | 51 +++ .../ainode/core/model/toto/model/rope.py | 94 +++++ .../ainode/core/model/toto/model/scaler.py | 299 ++++++++++++++ .../ainode/core/model/toto/model/toto.py | 151 +++++++ .../core/model/toto/model/transformer.py | 287 ++++++++++++++ .../ainode/core/model/toto/model/util.py | 210 ++++++++++ .../ainode/core/model/toto/modeling_toto.py | 194 +++++---- .../ainode/core/model/toto/pipeline_toto.py | 46 +-- iotdb-core/ainode/pyproject.toml | 1 + 23 files changed, 2519 insertions(+), 103 deletions(-) create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/NOTICE create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/data/__init__.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/__init__.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/dataset.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/__init__.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/model/__init__.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py diff --git a/iotdb-core/ainode/.gitignore b/iotdb-core/ainode/.gitignore index e4947e516d37a..60ecdcf76b52e 100644 --- a/iotdb-core/ainode/.gitignore +++ b/iotdb-core/ainode/.gitignore @@ -20,3 +20,7 @@ poetry.lock # generated by pyinstaller /dist/ /build/ + +# Un-ignore toto source data/ package (Python source, not data files) +!iotdb/ainode/core/model/toto/data/ +!iotdb/ainode/core/model/toto/data/** diff --git a/iotdb-core/ainode/build_binary.py b/iotdb-core/ainode/build_binary.py index c943de4158170..bea1970626b56 100644 --- a/iotdb-core/ainode/build_binary.py +++ b/iotdb-core/ainode/build_binary.py @@ -438,11 +438,10 @@ def verify_poetry_env(): print("Running poetry install...") subprocess.run( - [str(poetry_exe), "lock"], + [str(poetry_exe), "install", "--no-root"], cwd=str(script_dir), env=venv_env, check=True, - capture_output=True, text=True, ) verify_poetry_env() # Verify before install diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/NOTICE b/iotdb-core/ainode/iotdb/ainode/core/model/toto/NOTICE new file mode 100644 index 0000000000000..23999458c3eee --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/NOTICE @@ -0,0 +1,36 @@ +Apache IoTDB – AINode: Toto model +Copyright 2025 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +============================================================================ + +This directory includes source code derived from the DataDog/toto project: + + Toto – Timeseries-Optimized Transformer for Observability + Copyright 2025 Datadog, Inc. + Licensed under the Apache License, Version 2.0 + https://github.com/DataDog/toto + +The following files are derived from that project: + + model/attention.py + model/backbone.py + model/distribution.py + model/embedding.py + model/feed_forward.py + model/fusion.py + model/rope.py + model/scaler.py + model/transformer.py + model/toto.py + model/util.py + data/util/dataset.py + inference/forecaster.py + +Each derived file carries the original copyright notice and the Apache 2.0 +license header as required by the original project's license. + +A copy of the Apache License, Version 2.0 is available at: + http://www.apache.org/licenses/LICENSE-2.0 diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/__init__.py new file mode 100644 index 0000000000000..ba26b1edd945e --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/__init__.py new file mode 100644 index 0000000000000..ba26b1edd945e --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/dataset.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/dataset.py new file mode 100644 index 0000000000000..6bccf35988c24 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/dataset.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. + +from functools import reduce +from typing import NamedTuple + +import numpy as np +import torch +import torch.utils.data +from einops import repeat +from jaxtyping import Bool, Float, Int, Shaped + + +def pad_array( + values: Shaped[torch.Tensor, "*batch variates series_len"], # noqa: F722 + patch_stride: int, +) -> Shaped[torch.Tensor, "*batch variates padded_length"]: # noqa: F722 + """ + Makes sure that the series length is divisible by the patch_stride + by adding left-padding. + """ + if isinstance(values, np.ndarray): + values = torch.from_numpy(values) + series_len = values.shape[-1] + padded_length = int(np.ceil(series_len / patch_stride) * patch_stride) + if values.ndim == 2: + padded_values = torch.zeros((values.shape[0], padded_length), dtype=values.dtype, device=values.device) + elif values.ndim == 3: + padded_values = torch.zeros( + (values.shape[0], values.shape[1], padded_length), + dtype=values.dtype, + device=values.device, + ) + else: + raise ValueError(f"Unsupported number of dimensions: {values.ndim}") + padded_values[..., -series_len:] = values + + return padded_values + + +def pad_id_mask( + id_mask: Int[torch.Tensor, "*batch variates series_len"], # noqa: F722 + patch_stride: int, +) -> Int[torch.Tensor, "*batch variates padded_length"]: # noqa: F722 + """ + Makes sure that the series length is divisible by the patch_stride + by adding left-padding to the id mask. + """ + series_len = id_mask.shape[-1] + padded_length = int(np.ceil(series_len / patch_stride) * patch_stride) + padding_amount = padded_length - series_len + left_edge: Int[torch.Tensor, "*batch variates"] = id_mask[..., 0] # noqa: F722 + if id_mask.ndim == 2: + padding = repeat( + left_edge, + "variates -> variates padding_amount", + padding_amount=padding_amount, + ) + id_mask = torch.cat([padding, id_mask], dim=1) + elif id_mask.ndim == 3: + padding = repeat( + left_edge, + "batch variates -> batch variates padding_amount", + padding_amount=padding_amount, + ) + id_mask = torch.cat([padding, id_mask], dim=2) + else: + raise ValueError(f"Unsupported number of dimensions: {id_mask.ndim}") + + return id_mask + + +class MaskedTimeseries(NamedTuple): + series: Float[torch.Tensor, "*batch variates series_len"] # noqa: F722 + padding_mask: Bool[torch.Tensor, "*batch variates series_len"] # noqa: F722 + id_mask: Int[torch.Tensor, "*batch variates #series_len"] # noqa: F722 + timestamp_seconds: Int[torch.Tensor, "*batch variates series_len"] # noqa: F722 + time_interval_seconds: Int[torch.Tensor, "*batch variates"] # noqa: F722 + num_exogenous_variables: int = 0 + + def to(self, device: torch.device) -> "MaskedTimeseries": + return MaskedTimeseries( + series=self.series.to(device), + padding_mask=self.padding_mask.to(device), + id_mask=self.id_mask.to(device), + timestamp_seconds=self.timestamp_seconds.to(device), + time_interval_seconds=self.time_interval_seconds.to(device), + num_exogenous_variables=self.num_exogenous_variables, + ) + + +def is_extreme_value(t: torch.Tensor) -> torch.Tensor: + if torch.is_floating_point(t): + max_value = torch.finfo(t.dtype).max + else: + max_value = torch.iinfo(t.dtype).max + + return reduce( + torch.logical_or, + ( + torch.isinf(t), + torch.isnan(t), + t.abs() >= max_value / 2, + ), + ) + + +def replace_extreme_values(t: torch.Tensor, replacement: float = 0.0) -> torch.Tensor: + return torch.where(is_extreme_value(t), torch.tensor(replacement, dtype=t.dtype, device=t.device), t) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/__init__.py new file mode 100644 index 0000000000000..ba26b1edd945e --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py new file mode 100644 index 0000000000000..2099892293be8 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py @@ -0,0 +1,372 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. + +from dataclasses import dataclass +from typing import cast + +import numpy as np +import torch +from einops import rearrange, repeat +from jaxtyping import Bool, Float, Int +from torch.distributions import Distribution +from torch.distributions import TransformedDistribution +from torch.distributions.transforms import AffineTransform + +from ..data.util.dataset import ( + MaskedTimeseries, + pad_array, + pad_id_mask, + replace_extreme_values, +) +from ..model.backbone import TotoBackbone + + +class AffineTransformed(TransformedDistribution): + """ + Thin wrapper around TransformedDistribution with AffineTransform, + replacing gluonts.torch.distributions.AffineTransformed. + """ + + def __init__(self, base_distribution, loc=0.0, scale=1.0): + super().__init__(base_distribution, AffineTransform(loc=loc, scale=scale)) + + @property + def mean(self): + loc = self.transforms[0].loc + scale = self.transforms[0].scale + return loc + scale * self.base_dist.mean + + # Note: Do NOT override sample() here. TransformedDistribution.sample() correctly + # calls base_dist.sample() (not rsample), which works for non-reparameterizable + # distributions like MixtureSameFamily. + + +@dataclass(frozen=True) +class Forecast: + mean: Float[torch.Tensor, "batch variate future_time_steps"] + samples: Float[torch.Tensor, "batch variate future_time_steps samples"] | None = None + + def quantile(self, q: float | torch.Tensor) -> Float[torch.Tensor, "batch variate future_time_steps"]: + assert self.samples is not None, "samples must be provided to compute quantiles" + assert isinstance(q, (float, torch.Tensor)), "q must be a float or a tensor" + if isinstance(q, float): + q = torch.tensor(q, device=self.samples.device, dtype=self.samples.dtype) + return self.samples.quantile(q, dim=-1) + + @property + def median(self) -> Float[torch.Tensor, "batch variate future_time_steps"]: + return self.quantile(0.5) + + @property + def std(self) -> Float[torch.Tensor, "batch variate future_time_steps"]: + assert self.samples is not None, "samples must be provided to compute standard deviation" + return self.samples.std(dim=-1) + + +class TotoForecaster: + """ + A forecaster class for the Toto model that handles autoregressive decoding + for time series forecasting. + """ + + model: TotoBackbone + + def __init__(self, model: TotoBackbone): + self.model = model + self.model.eval() + + def forecast( + self, + inputs: MaskedTimeseries, + prediction_length: int, + num_samples: int | None = None, + samples_per_batch: int = 10, + use_kv_cache: bool = True, + future_exogenous_variables: Float[torch.Tensor, "batch exogenous_variables future_time_steps"] | None = None, + ) -> Forecast: + if len(inputs.series.shape) == 2: + batch = cast(MaskedTimeseries, torch.utils.data.default_collate([inputs])) + else: + batch = inputs + + if future_exogenous_variables is not None and len(future_exogenous_variables.shape) == 2: + future_exogenous_variables = future_exogenous_variables.unsqueeze(0) + + series = pad_array(batch.series, self.model.patch_embed.stride) + padding_mask = pad_array(batch.padding_mask, self.model.patch_embed.stride) + id_mask = batch.id_mask + if id_mask is not None: + id_mask = pad_id_mask(batch.id_mask, self.model.patch_embed.stride) + timestamp_seconds = pad_array(batch.timestamp_seconds, self.model.patch_embed.stride) + time_interval_seconds: Int[torch.Tensor, "batch variate series_len"] = torch.as_tensor( + batch.time_interval_seconds, device=series.device, dtype=torch.int + ) + + if num_samples is not None: + samples = self.generate_samples( + inputs=series, + prediction_length=prediction_length, + num_samples=num_samples, + timestamp_seconds=timestamp_seconds, + time_interval_seconds=time_interval_seconds, + input_padding_mask=padding_mask, + id_mask=id_mask, + sampling_batch_size=samples_per_batch, + use_kv_cache=use_kv_cache, + future_exogenous_variables=future_exogenous_variables, + num_exogenous_variables=batch.num_exogenous_variables, + ) + mean = samples.mean(dim=-1) + else: + mean = self.generate_mean( + inputs=series, + prediction_length=prediction_length, + timestamp_seconds=timestamp_seconds, + time_interval_seconds=time_interval_seconds, + input_padding_mask=padding_mask, + id_mask=id_mask, + use_kv_cache=use_kv_cache, + future_exogenous_variables=future_exogenous_variables, + num_exogenous_variables=batch.num_exogenous_variables, + ) + samples = None + + return Forecast(mean=mean, samples=samples) + + def assert_ev_compatibility( + self, + inputs, + future_exogenous_variables, + prediction_length, + num_exogenous_variables, + ) -> None: + assert future_exogenous_variables.shape[-1] == prediction_length + assert future_exogenous_variables.shape[0] == inputs.shape[0] + assert num_exogenous_variables == future_exogenous_variables.shape[-2] + + def round_ft_ev(self, future_exogenous_variables, T_rounded): + B, V_ev, T_future = future_exogenous_variables.shape + dtype = future_exogenous_variables.dtype + device = future_exogenous_variables.device + padding = torch.zeros(B, V_ev, T_rounded - T_future, device=device, dtype=dtype) + return torch.cat([future_exogenous_variables, padding], dim=-1) + + @torch.no_grad() + def generate_mean( + self, + inputs: Float[torch.Tensor, "batch variate time_steps"], + prediction_length: int, + timestamp_seconds: Int[torch.Tensor, "batch variate time_steps"], + time_interval_seconds: Int[torch.Tensor, "batch variate"], + input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"] | None = None, + id_mask: Float[torch.Tensor, "batch #variate time_steps"] | None = None, + use_kv_cache: bool = False, + future_exogenous_variables=None, + num_exogenous_variables: int = 0, + ) -> Float[torch.Tensor, "batch variate time_steps"]: + if input_padding_mask is None: + input_padding_mask = torch.ones_like(inputs, dtype=torch.bool, device=inputs.device) + if id_mask is None: + id_mask = torch.zeros_like(inputs, dtype=torch.int, device=inputs.device) + + if future_exogenous_variables is not None: + self.assert_ev_compatibility(inputs, future_exogenous_variables, prediction_length, num_exogenous_variables) + + patch_size = self.model.patch_embed.stride + rounded_steps = int(np.ceil(prediction_length / patch_size) * patch_size) + if rounded_steps > prediction_length and future_exogenous_variables is not None: + future_exogenous_variables = self.round_ft_ev(future_exogenous_variables, rounded_steps) + start_index = inputs.shape[-1] + end_index = start_index + prediction_length + + dummy_padding = torch.ones( + (input_padding_mask.shape[0], input_padding_mask.shape[1], patch_size), + device=inputs.device, + dtype=torch.bool, + ) + dummy_id_mask = repeat( + id_mask[:, :, -1:], + "batch variates 1 -> batch variates patch_size", + patch_size=patch_size, + ) + if use_kv_cache: + kv_cache = self.model.allocate_kv_cache( + batch_size=inputs.shape[0], + num_variates=inputs.shape[1], + max_time_steps=inputs.shape[2] + rounded_steps, + device=inputs.device, + dtype=inputs.dtype, + ) + else: + kv_cache = None + + scaling_prefix_length = inputs.shape[-1] + + for idx in range(rounded_steps // patch_size): + base_distr, loc, scale = self.model( + inputs=inputs, + input_padding_mask=input_padding_mask, + id_mask=id_mask, + kv_cache=kv_cache, + scaling_prefix_length=scaling_prefix_length, + num_exogenous_variables=num_exogenous_variables, + ) + distr = self.create_affine_transformed(base_distr, loc, scale) + + samples = replace_extreme_values(distr.mean[:, :, -patch_size:]) + + if future_exogenous_variables is not None: + start, stop = idx * patch_size, (idx + 1) * patch_size + samples[:, -num_exogenous_variables:] = future_exogenous_variables[:, :, start:stop] + + inputs = torch.cat([inputs, samples], dim=-1) + id_mask = torch.cat([id_mask, dummy_id_mask], dim=-1) + input_padding_mask = torch.cat([input_padding_mask, dummy_padding], dim=-1) + for _ in range(patch_size): + next_timestamp = timestamp_seconds[:, :, -1] + time_interval_seconds + timestamp_seconds = torch.cat([timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1) + + return inputs.detach()[:, :, start_index:end_index] + + @torch.no_grad() + def generate_samples( + self, + inputs: Float[torch.Tensor, "batch variate time_steps"], + prediction_length: int, + num_samples: int, + timestamp_seconds: Int[torch.Tensor, "batch variate time_steps"], + time_interval_seconds: Int[torch.Tensor, "batch variate"], + input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"] | None = None, + id_mask: Float[torch.Tensor, "batch #variate time_steps"] | None = None, + sampling_batch_size: int = 10, + use_kv_cache: bool = False, + future_exogenous_variables=None, + num_exogenous_variables: int = 0, + ) -> Float[torch.Tensor, "batch variate time_steps samples"]: + if input_padding_mask is None: + input_padding_mask = torch.ones_like(inputs, dtype=torch.bool, device=inputs.device) + if id_mask is None: + id_mask = torch.zeros_like(inputs, dtype=torch.int, device=inputs.device) + + if future_exogenous_variables is not None: + self.assert_ev_compatibility(inputs, future_exogenous_variables, prediction_length, num_exogenous_variables) + + assert num_samples % sampling_batch_size == 0, "num_samples must be divisible by sampling_batch_size" + num_batches = num_samples // sampling_batch_size + + patch_size = self.model.patch_embed.patch_size + rounded_steps = int(np.ceil(prediction_length / patch_size) * patch_size) + if rounded_steps > prediction_length and future_exogenous_variables is not None: + future_exogenous_variables = self.round_ft_ev(future_exogenous_variables, rounded_steps) + start_index = inputs.shape[-1] + end_index = start_index + prediction_length + + dummy_padding = torch.ones( + (input_padding_mask.shape[0] * sampling_batch_size, input_padding_mask.shape[1], patch_size), + dtype=torch.bool, + device=inputs.device, + ) + dummy_id_mask = repeat( + id_mask[:, :, -1:], + "batch variates 1 -> (sampling_batch_size batch) variates patch_size", + sampling_batch_size=sampling_batch_size, + patch_size=patch_size, + ) + inputs = repeat(inputs, "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", sampling_batch_size=sampling_batch_size) + if future_exogenous_variables is not None: + future_exogenous_variables = repeat( + future_exogenous_variables, + "batch exogenous_variables future_time_steps -> (sampling_batch_size batch) exogenous_variables future_time_steps", + sampling_batch_size=sampling_batch_size, + ) + input_padding_mask = repeat(input_padding_mask, "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", sampling_batch_size=sampling_batch_size) + id_mask = repeat(id_mask, "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", sampling_batch_size=sampling_batch_size) + timestamp_seconds = repeat(timestamp_seconds, "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", sampling_batch_size=sampling_batch_size) + time_interval_seconds = repeat(time_interval_seconds, "batch variates -> (sampling_batch_size batch) variates", sampling_batch_size=sampling_batch_size) + + all_samples = [] + if use_kv_cache: + kv_cache = self.model.allocate_kv_cache( + batch_size=inputs.shape[0], + num_variates=inputs.shape[1], + max_time_steps=inputs.shape[2] + rounded_steps, + device=inputs.device, + dtype=inputs.dtype, + ) + else: + kv_cache = None + + scaling_prefix_length = inputs.shape[-1] + + for _ in range(num_batches): + batch_inputs = torch.clone(inputs) + batch_input_padding_mask = torch.clone(input_padding_mask) + batch_id_mask = torch.clone(id_mask) + batch_timestamp_seconds = torch.clone(timestamp_seconds) + + for idx in range(rounded_steps // patch_size): + base_distr, loc, scale = self.model( + inputs=batch_inputs, + input_padding_mask=batch_input_padding_mask, + id_mask=batch_id_mask, + kv_cache=kv_cache, + scaling_prefix_length=scaling_prefix_length, + num_exogenous_variables=num_exogenous_variables, + ) + distr = self.create_affine_transformed(base_distr, loc, scale) + + sample = distr.sample() + assert sample is not None + + samples = replace_extreme_values(sample[:, :, -patch_size:]) + + if future_exogenous_variables is not None: + start, stop = idx * patch_size, (idx + 1) * patch_size + samples[:, -num_exogenous_variables:] = future_exogenous_variables[:, :, start:stop] + batch_inputs = torch.cat([batch_inputs, samples], dim=-1) + batch_id_mask = torch.cat([batch_id_mask, dummy_id_mask], dim=-1) + batch_input_padding_mask = torch.cat([batch_input_padding_mask, dummy_padding], dim=-1) + for _ in range(patch_size): + next_timestamp = batch_timestamp_seconds[:, :, -1] + time_interval_seconds + batch_timestamp_seconds = torch.cat([batch_timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1) + all_samples.append(batch_inputs) + if kv_cache is not None: + kv_cache.reset() + + outputs = torch.cat(all_samples, dim=0) + unfolded_outputs = rearrange( + outputs, + "(samples batch) variates seq_len -> batch variates seq_len samples", + samples=num_samples, + ).detach() + + return unfolded_outputs[:, :, start_index:end_index, :] + + @staticmethod + def create_affine_transformed(base_distr: Distribution, loc: torch.Tensor, scale: torch.Tensor) -> Distribution: + base_shape = base_distr.mean.shape + base_time_dim = base_shape[-1] + loc_time_dim = loc.shape[-1] + + if loc_time_dim == 1: + return AffineTransformed(base_distr, loc=loc, scale=scale) + + return AffineTransformed(base_distr, loc=loc[:, :, -base_time_dim:], scale=scale[:, :, -base_time_dim:]) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/__init__.py new file mode 100644 index 0000000000000..ba26b1edd945e --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py new file mode 100644 index 0000000000000..387ea30204c11 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py @@ -0,0 +1,213 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. + +import logging +import warnings +from enum import Enum +from typing import TYPE_CHECKING, Optional, Union + +import torch +from einops import rearrange +from jaxtyping import Bool, Float, Int + +from .rope import TimeAwareRotaryEmbedding + +if TYPE_CHECKING: + from .util import KVCache + +log = logging.getLogger(__name__) + +try: + from xformers.ops import LowerTriangularMask, memory_efficient_attention + + XFORMERS_AVAILABLE = True + log.info("xFormers Memory-Efficient Attention available.") +except ImportError: + warnings.warn( + "xFormers Memory-Efficient Attention not available. " + "Falling back to native PyTorch scaled_dot_product_attention.", + ImportWarning, + ) + + XFORMERS_AVAILABLE = False + +from torch.nn.functional import scaled_dot_product_attention + + +class AttentionAxis(Enum): + TIME = 1 + SPACE = 2 + + +class BaseMultiheadAttention(torch.nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float, + rotary_emb: Optional[TimeAwareRotaryEmbedding], + use_memory_efficient_attention: bool, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads." + self.head_dim = embed_dim // num_heads + self.rotary_emb = rotary_emb + + self.wQKV = torch.nn.Linear(embed_dim, embed_dim * 3) + self.dropout = dropout + self.use_memory_efficient_attention = use_memory_efficient_attention + self.wO = torch.nn.Linear(embed_dim, embed_dim) + + assert not ( + not XFORMERS_AVAILABLE and self.use_memory_efficient_attention + ), "XFORMERS_AVAILABLE is False, so use_memory_efficient_attention must be False" + + if not hasattr(self, "attention_axis") or self.attention_axis not in (AttentionAxis.TIME, AttentionAxis.SPACE): + raise ValueError("Child class must define attention_axis as AttentionAxis.TIME or AttentionAxis.SPACE.") + + def rearrange_inputs( + self, inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"] + ) -> Float[torch.Tensor, "... embed_dim"]: + pattern = ( + "batch variate seq_len embed_dim -> (batch variate) seq_len embed_dim" + if self.attention_axis == AttentionAxis.TIME + else "batch variate seq_len embed_dim -> (batch seq_len) variate embed_dim" + ) + return rearrange(inputs, pattern) + + def get_qkv(self, inputs: torch.Tensor) -> tuple[torch.Tensor, ...]: + if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention: + pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate seq_len n_heads head_dim" + elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention: + pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate n_heads seq_len head_dim" + elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention: + pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len variate n_heads head_dim" + elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention: + pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len n_heads variate head_dim" + + qkv = self.wQKV(inputs.contiguous()) + return rearrange(qkv, pattern, qkv=3, head_dim=self.head_dim, n_heads=self.num_heads).unbind(dim=0) + + def positional_embedding(self, q, k, v, kv_cache, layer_idx): + seq_pos_offset = 0 + if self.rotary_emb is not None and self.attention_axis == AttentionAxis.TIME: + if kv_cache is not None: + seq_pos_offset = kv_cache.seq_len(layer_idx) + q, k = self.rotary_emb.rotate_queries_and_keys(q, k, seq_pos_offset=seq_pos_offset) + + if kv_cache is not None and self.attention_axis == AttentionAxis.TIME: + kv_cache.append(layer_idx, (k, v)) + k, v = kv_cache[layer_idx] + + q = q.contiguous() + k = k.contiguous().to(q.dtype) + v = v.contiguous().to(q.dtype) + + return q, k, v, seq_pos_offset + + def rearrange_output( + self, output: torch.Tensor, batch: int, variate: int, seq_len: int + ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]: + if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention: + pattern = "(batch variate) seq_len n_heads head_dim -> batch variate seq_len (n_heads head_dim)" + elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention: + pattern = "(batch variate) n_heads seq_len head_dim -> batch variate seq_len (n_heads head_dim)" + elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention: + pattern = "(batch seq_len) variate n_heads head_dim -> batch variate seq_len (n_heads head_dim)" + elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention: + pattern = "(batch seq_len) n_heads variate head_dim -> batch variate seq_len (n_heads head_dim)" + + return rearrange(output, pattern, batch=batch, variate=variate, seq_len=seq_len) + + def run_attention(self, attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate): + q_dim_start, q_dim_end = seq_pos_offset, seq_pos_offset + seq_len + kv_dim_start, kv_dim_end = 0, v.shape[1] if self.use_memory_efficient_attention else v.shape[2] + if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention: + attention_mask = ( + attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end] + if torch.is_tensor(attention_mask) + else LowerTriangularMask() if seq_pos_offset == 0 else None + ) + return memory_efficient_attention(q, k, v, attn_bias=attention_mask, p=dropout) + elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention: + attention_mask = ( + attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end] + if torch.is_tensor(attention_mask) + else None + ) + return scaled_dot_product_attention( + q, k, v, + attn_mask=attention_mask, + dropout_p=dropout, + is_causal=(attention_mask is None and seq_pos_offset == 0), + ) + elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention: + attention_mask = ( + attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end] + if torch.is_tensor(attention_mask) + else None + ) + return memory_efficient_attention(q, k, v, attn_bias=attention_mask, p=dropout) + elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention: + attention_mask = ( + attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end] + if torch.is_tensor(attention_mask) + else None + ) + return scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False) + + def forward( + self, + layer_idx: int, + inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"], + attention_mask: Optional[ + Union[ + Bool[torch.Tensor, "batch_X_variate n_heads seq_len seq_len"], + Bool[torch.Tensor, "batch_X_seq_len n_heads variate variate"], + ] + ] = None, + kv_cache: Optional["KVCache"] = None, + ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]: + batch_size, variate, seq_len, _ = inputs.shape + dropout = self.dropout if self.training else 0.0 + + rearranged_inputs = self.rearrange_inputs(inputs) + q, k, v = self.get_qkv(rearranged_inputs) + + q, k, v, seq_pos_offset = self.positional_embedding(q, k, v, kv_cache, layer_idx) + + output = self.run_attention(attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate) + + output = self.rearrange_output(output, batch_size, variate, seq_len) + return self.wO(output) + + +class TimeWiseMultiheadAttention(BaseMultiheadAttention): + attention_axis = AttentionAxis.TIME + + +class SpaceWiseMultiheadAttention(BaseMultiheadAttention): + attention_axis = AttentionAxis.SPACE + + +MultiHeadAttention = TimeWiseMultiheadAttention | SpaceWiseMultiheadAttention diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py new file mode 100644 index 0000000000000..50cf4943185e8 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py @@ -0,0 +1,236 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. + +from math import ceil +from typing import NamedTuple, Optional, Type, cast + +import torch +from einops import rearrange, repeat +from jaxtyping import Bool, Float, Int + +from .distribution import DISTRIBUTION_CLASSES_LOOKUP, DistributionOutput +from .embedding import PatchEmbedding +from .fusion import Fusion +from .scaler import scaler_types +from .transformer import Transformer +from .util import KVCache + + +class TotoOutput(NamedTuple): + """ + Output of the Toto model. Contains the output distribution, the location parameters, + and the scale parameters. + """ + + distribution: torch.distributions.Distribution + loc: Float[torch.Tensor, "batch variate"] + scale: Float[torch.Tensor, "batch variate"] + + +class TotoBackbone(torch.nn.Module): + """ + Toto (Timeseries-Optimized Transformer for Observability) is a transformer-based model + for multivariate time series forecasting. + """ + + def __init__( + self, + patch_size: int, + stride: int, + embed_dim: int, + num_layers: int, + num_heads: int, + mlp_hidden_dim: int, + dropout: float, + spacewise_every_n_layers: int, + scaler_cls: str, + output_distribution_classes: list[str], + spacewise_first: bool = True, + output_distribution_kwargs: dict | None = None, + use_memory_efficient_attention: bool = True, + stabilize_with_global: bool = True, + scale_factor_exponent: float = 10.0, + ): + super().__init__() + self.embed_dim = embed_dim + self.fusion: Optional[Fusion] = None + self.num_prepended_tokens: int = 0 + self.target_variate_label: Optional[torch.nn.Parameter] = None + self.exogenous_variate_label: Optional[torch.nn.Parameter] = None + + if scaler_cls == "": + self.scaler = scaler_types[scaler_cls]( + patch_size=patch_size, + stabilize_with_global=stabilize_with_global, + scale_factor_exponent=scale_factor_exponent, + ) + else: + self.scaler = scaler_types[scaler_cls]() + + self.patch_embed = PatchEmbedding(patch_size, stride, embed_dim) + self.dropout = dropout + self.num_layers = num_layers + self.use_memory_efficient_attention = use_memory_efficient_attention + self.transformer = Transformer( + embed_dim=embed_dim, + num_heads=num_heads, + num_layers=self.num_layers, + mlp_hidden_dim=mlp_hidden_dim, + dropout=dropout, + spacewise_every_n_layers=spacewise_every_n_layers, + spacewise_first=spacewise_first, + use_memory_efficient_attention=self.use_memory_efficient_attention, + fusion=self.fusion, + ) + self.unembed = torch.nn.Linear(embed_dim, embed_dim * patch_size) + + output_distribution_classes_ = [DISTRIBUTION_CLASSES_LOOKUP[c] for c in output_distribution_classes] + self.output_distribution = output_distribution_classes_[0](embed_dim, **(output_distribution_kwargs or {})) + + def allocate_kv_cache( + self, + batch_size: int, + num_variates: int, + max_time_steps: int, + device: torch.device, + dtype: torch.dtype, + ) -> KVCache: + return KVCache( + batch_size=batch_size, + num_variates=num_variates, + transformer_layers=list(self.transformer.layers), + num_layers=self.num_layers, + embed_dim=self.embed_dim, + num_heads=cast(int, self.transformer.layers[0].num_heads), + max_seq_len=ceil(max_time_steps / self.patch_embed.stride), + device=device, + dtype=dtype, + use_memory_efficient_attention=self.use_memory_efficient_attention, + ) + + def backbone( + self, + inputs: Float[torch.Tensor, "batch variate time_steps"], + input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"], + id_mask: Float[torch.Tensor, "batch #variate time_steps"], + kv_cache: Optional[KVCache] = None, + scaling_prefix_length: Optional[int] = None, + num_exogenous_variables: int = 0, + ) -> tuple[ + Float[torch.Tensor, "batch variates time_steps embed_dim"], + Float[torch.Tensor, "batch variates time_steps"], + Float[torch.Tensor, "batch variates time_steps"], + ]: + scaled_inputs, loc, scale = self.scaler( + inputs, + weights=torch.ones_like(inputs, device=inputs.device), + padding_mask=input_padding_mask, + prefix_length=scaling_prefix_length, + ) + + if kv_cache is not None: + kv_cache_len_tensor = kv_cache.current_len(0) + kv_cache_len = ( + int(kv_cache_len_tensor) if isinstance(kv_cache_len_tensor, torch.Tensor) else kv_cache_len_tensor + ) + prefix_len = max(0, self.patch_embed.stride * (kv_cache_len - self.num_prepended_tokens)) + + scaled_inputs = scaled_inputs[:, :, prefix_len:] + + assert (prefix_len == 0) or ( + scaled_inputs.shape[-1] == self.patch_embed.stride + ), "Must decode one step at a time." + + input_padding_mask = input_padding_mask[:, :, prefix_len:] + id_mask = id_mask[:, :, prefix_len:] + + embeddings, reduced_id_mask = self.patch_embed(scaled_inputs, id_mask) + + variate_label_embeds = self.build_variate_label_embeds(num_exogenous_variables, embeddings) + + original_seq_len = embeddings.shape[2] + transformed = self.transformer(embeddings, reduced_id_mask, kv_cache, variate_label_embeds=variate_label_embeds) + added_tokens = transformed.shape[2] - original_seq_len + if added_tokens > 0: + transformed = transformed[:, :, added_tokens:] + + flattened: Float[torch.Tensor, "batch variates new_seq_len embed_dim"] = rearrange( + self.unembed(transformed), + "batch variates seq_len (patch_size embed_dim) -> batch variates (seq_len patch_size) embed_dim", + embed_dim=self.embed_dim, + ) + return flattened, loc, scale + + def forward( + self, + inputs: Float[torch.Tensor, "batch variate time_steps"], + input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"], + id_mask: Float[torch.Tensor, "batch #variate time_steps"], + kv_cache: Optional[KVCache] = None, + scaling_prefix_length: Optional[int] = None, + num_exogenous_variables: int = 0, + ) -> TotoOutput: + flattened, loc, scale = self.backbone( + inputs, + input_padding_mask, + id_mask, + kv_cache, + scaling_prefix_length, + num_exogenous_variables, + ) + + return TotoOutput(self.output_distribution(flattened), loc, scale) + + @property + def device(self): + return next(self.parameters()).device + + def enable_variate_labels(self) -> None: + self.fusion = Fusion() + self.num_prepended_tokens = 1 + self.target_variate_label = torch.nn.Parameter(torch.randn(self.embed_dim)) + self.exogenous_variate_label = torch.nn.Parameter(torch.randn(self.embed_dim)) + if hasattr(self, "transformer") and self.transformer is not None: + self.transformer.fusion = self.fusion + + def build_variate_label_embeds( + self, + num_exogenous_variables: int, + embeddings: Float[torch.Tensor, "batch variate seq_len embed_dim"], + ) -> Optional[Float[torch.Tensor, "batch variate 1 embed_dim"]]: + if self.fusion is None: + return None + + assert self.target_variate_label is not None + assert self.exogenous_variate_label is not None + + batch_size, num_variates, _, _ = embeddings.shape + + target_variate_label = repeat(self.target_variate_label, "d -> b v 1 d", b=batch_size, v=num_variates).to( + device=embeddings.device, dtype=embeddings.dtype + ) + exogenous_variate_label = repeat(self.exogenous_variate_label, "d -> b v 1 d", b=batch_size, v=num_variates).to( + device=embeddings.device, dtype=embeddings.dtype + ) + exog_mask = torch.zeros(1, num_variates, 1, 1, dtype=torch.bool, device=embeddings.device) + if num_exogenous_variables > 0: + exog_mask[:, -num_exogenous_variables:] = True + return torch.where(exog_mask, exogenous_variate_label, target_variate_label) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py new file mode 100644 index 0000000000000..ac7023321a31c --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. + +from abc import ABC + +import torch +import torch.nn.functional as F +from torch.distributions import TransformedDistribution +from torch.distributions.transforms import AffineTransform + + +class AffineTransformed(TransformedDistribution): + """ + A thin wrapper around TransformedDistribution with an AffineTransform, + replacing the gluonts.torch.distributions.AffineTransformed dependency. + Provides the same interface: mean, variance, sample(), log_prob(). + """ + + def __init__(self, base_distribution, loc=0.0, scale=1.0): + super().__init__(base_distribution, AffineTransform(loc=loc, scale=scale)) + + @property + def mean(self): + # mean(aX + b) = a * mean(X) + b + loc = self.transforms[0].loc + scale = self.transforms[0].scale + return loc + scale * self.base_dist.mean + + # Note: Do NOT override sample() here. TransformedDistribution.sample() correctly + # calls base_dist.sample() (not rsample), which works for non-reparameterizable + # distributions like MixtureSameFamily. + + +class DistributionOutput(ABC, torch.nn.Module): + pass + + +class StudentTOutput(DistributionOutput): + def __init__(self, embed_dim): + super().__init__() + self.embed_dim = embed_dim + self.df = torch.nn.Linear(embed_dim, 1) + self.loc_proj = torch.nn.Linear(embed_dim, 1) + self.scale_proj = torch.nn.Linear(embed_dim, 1) + + def forward(self, inputs, loc=None, scale=None): + eps = torch.finfo(inputs.dtype).eps + df = 2.0 + F.softplus(self.df(inputs)).clamp_min(eps).squeeze(-1) + base_loc = self.loc_proj(inputs).squeeze(-1) + base_scale = F.softplus(self.scale_proj(inputs)).clamp_min(eps).squeeze(-1) + + base_dist = torch.distributions.StudentT(df, base_loc, base_scale, validate_args=False) + + if loc is not None and scale is not None: + return AffineTransformed(base_dist, loc=loc, scale=scale) + return base_dist + + +class MixtureOfStudentTsOutput(DistributionOutput): + def __init__(self, embed_dim, k_components): + super().__init__() + self.embed_dim = embed_dim + self.k_components = k_components + + self.df = torch.nn.Linear(embed_dim, k_components) + self.loc_proj = torch.nn.Linear(embed_dim, k_components) + self.scale_proj = torch.nn.Linear(embed_dim, k_components) + self.mixture_weights = torch.nn.Linear(embed_dim, k_components) + + def forward(self, inputs, loc=None, scale=None): + df = 2.0 + F.softplus(self.df(inputs)).clamp_min(torch.finfo(inputs.dtype).eps) + component_loc = self.loc_proj(inputs) + component_scale = F.softplus(self.scale_proj(inputs)).clamp_min(torch.finfo(inputs.dtype).eps) + logits = self.mixture_weights(inputs) + probs = F.softmax(logits, dim=-1) + components = torch.distributions.StudentT(df, component_loc, component_scale, validate_args=False) + mixture_distribution = torch.distributions.Categorical(probs=probs) + + return torch.distributions.MixtureSameFamily(mixture_distribution, components) + + +DISTRIBUTION_CLASSES_LOOKUP = { + "": StudentTOutput, + "": MixtureOfStudentTsOutput, + # Short-form aliases for convenience + "student_t": StudentTOutput, + "student_t_mixture": MixtureOfStudentTsOutput, +} diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py new file mode 100644 index 0000000000000..248fa1326bedb --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. + +from typing import Optional + +import torch +from jaxtyping import Float, Int, Num + + +def patchify_id_mask( + id_mask: Int[torch.Tensor, "batch variate time_steps"], patch_size: int +) -> Int[torch.Tensor, "batch variate seq_len patch_size"]: + patched_id_mask = id_mask.unfold(dimension=-1, size=patch_size, step=patch_size) + patched_id_mask_min = patched_id_mask.min(-1).values + patched_id_mask_max = patched_id_mask.max(-1).values + assert torch.eq(patched_id_mask_min, patched_id_mask_max).all(), "Patches cannot span multiple datasets" + return patched_id_mask_min + + +class PatchEmbedding(torch.nn.Module): + """ + Multivariate time series patch embedding. + Patchifies each variate separately. + """ + + def __init__(self, patch_size: int, stride: int, embed_dim: int): + super().__init__() + self.patch_size = patch_size + self.embed_dim = embed_dim + self.stride = stride + self.projection = torch.nn.Linear(self.patch_size, self.embed_dim) + + def _patchify( + self, x: Num[torch.Tensor, "batch variate time_steps"] + ) -> Num[torch.Tensor, "batch variate seq_len patch_size"]: + return x.unfold(dimension=-1, size=self.patch_size, step=self.stride) + + def forward( + self, + x: Float[torch.Tensor, "batch #variate time_steps"], + id_mask: Float[torch.Tensor, "batch time_steps"], + ) -> tuple[ + Float[torch.Tensor, "batch variate seq_len embed_dim"], + Int[torch.Tensor, "batch seq_len"], + ]: + assert ( + x.shape[-1] % self.patch_size == 0 + ), f"Series length ({x.shape=}) must be divisible by ({self.patch_size=})" + x_patched: Float[torch.Tensor, "batch variate seq_len patch_size"] = self._patchify(x) + id_mask_patched: Int[torch.Tensor, "batch variate seq_len patch_size"] = self._patchify(id_mask) + + assert torch.eq( + id_mask_patched.min(-1).values, id_mask_patched.max(-1).values + ).all(), "Patches cannot span multiple datasets" + + return ( + self.projection(x_patched), + id_mask_patched.min(-1).values, + ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py new file mode 100644 index 0000000000000..024a8bed7277f --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. + +import torch +import torch.nn.functional as F + + +class SwiGLU(torch.nn.Module): + """ + https://arxiv.org/abs/2002.05202 + NOTE: x should be 2x the size you want + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Note this ordering is unusual, but is done so to match xFormers + gate, x = x.chunk(2, dim=-1) + return F.silu(gate) * x diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py new file mode 100644 index 0000000000000..4d6723b251cfa --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. + +from typing import Optional + +import torch +import torch.nn.functional as F +from jaxtyping import Float + + +class Fusion(torch.nn.Module): + """ + Prepends variate label embeddings to the input embeddings along the sequence dimension. + """ + + def __init__(self) -> None: + super().__init__() + + def forward( + self, + embeddings: Float[torch.Tensor, "batch variate seq_len embed_dim"], + variate_label_embeds: Optional[Float[torch.Tensor, "batch variate 1 embed_dim"]] = None, + ) -> Float[torch.Tensor, "batch variate new_seq_len embed_dim"]: + + if variate_label_embeds is None: + return embeddings + + processed_embeddings = F.normalize(variate_label_embeds, p=2, dim=-1) + + return torch.cat( + [processed_embeddings.to(dtype=embeddings.dtype, device=embeddings.device, non_blocking=True), embeddings], + dim=2, + ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py new file mode 100644 index 0000000000000..96e6251707705 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. + +from typing import Optional + +import torch +from einops import rearrange +from jaxtyping import Int +from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb +from rotary_embedding_torch.rotary_embedding_torch import default + + +def exists(val): + return val is not None + + +class TimeAwareRotaryEmbedding(RotaryEmbedding): + """ + A variant of the rotary position embedding that (optionally) uses the time index + to compute the sinusoidal and cosine embeddings. Useful for time series data. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # If the parent stored `freqs` as a Parameter, remove it and register as a buffer + if hasattr(self, "freqs") and isinstance(self.freqs, torch.nn.Parameter): + freqs_data = self.freqs.data + self._parameters.pop("freqs") + self.register_buffer("freqs", freqs_data, persistent=False) + + def rotate_queries_and_keys( + self, + q: torch.Tensor, + k: torch.Tensor, + seq_dim: Optional[int] = None, + seq_pos: Optional[Int[torch.Tensor, "... seq_len"]] = None, + seq_pos_offset: int = 0, + ): + if seq_dim is None: + seq_dim = self.default_seq_dim + + assert self.use_xpos + device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] + + seq = default(seq_pos, self.get_seq_pos(seq_len, dtype=dtype, device=device)) + seq = seq + seq_pos_offset + + freqs = self.forward(seq) + + scale = self.get_scale(seq).to(dtype) + + if seq_dim == -3: + num_heads = q.shape[-2] + freqs = freqs.unsqueeze(1).expand(-1, num_heads, -1) + scale = scale.unsqueeze(1).expand(-1, num_heads, -1) + + rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim) + rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def get_scale( + self, + t: torch.Tensor, + ): + assert self.use_xpos + + power = (t - t.max(-1).values.unsqueeze(-1) // 2) / self.scale_base + + scale = self.scale ** rearrange(power, "... n -> ... n 1") + scale = torch.cat((scale, scale), dim=-1) + + return scale diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py new file mode 100644 index 0000000000000..255211f3772c5 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py @@ -0,0 +1,299 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. + +import warnings +from typing import Tuple + +import torch +from einops import reduce, repeat + + +class Scaler(torch.nn.Module): + """ + Minimal base class replacing gluonts.torch.scaler.Scaler. + Provides a __call__ interface for scaling data. + """ + + pass + + +class StdMeanScaler(Scaler): + """ + Scales data to have zero mean and unit variance along a given dimension. + """ + + def __init__( + self, + dim: int = -1, + keepdim: bool = True, + minimum_scale: float = 1e-3, + ) -> None: + super().__init__() + self.dim = dim + self.keepdim = keepdim + self.minimum_scale = minimum_scale + + def __call__( + self, + data: torch.Tensor, + padding_mask: torch.Tensor, + weights: torch.Tensor, + prefix_length: int | None = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert data.shape == weights.shape, "data and weights must have same shape" + with torch.no_grad(): + if prefix_length is not None: + prefix_mask = torch.zeros_like(weights) + prefix_mask[..., :prefix_length] = 1.0 + weights = weights * prefix_mask + + weights = weights * padding_mask + + try: + high_precision_data = data.to(torch.float64) + except TypeError: + warnings.warn( + f"Float64 is not supported by device {data.device}. " + "Using float32 instead for accumulating denominator in input scaler. " + "This may lead to overflow issues if the data contains extreme values.", + RuntimeWarning, + ) + high_precision_data = data.to(torch.float32) + + denominator = weights.sum(self.dim, keepdim=self.keepdim).clamp_min(1.0).to(high_precision_data.dtype) + means = (high_precision_data * weights).sum(self.dim, keepdim=self.keepdim) / denominator + means = torch.nan_to_num(means) + + variance = (((high_precision_data - means) * weights) ** 2).sum( + self.dim, keepdim=self.keepdim + ) / denominator + scale = torch.sqrt(variance + self.minimum_scale).to(data.dtype) + loc = means.to(data.dtype) + + return (data - loc) / scale, loc, scale + + +def compute_causal_statistics( + data: torch.Tensor, + weights: torch.Tensor, + padding_mask: torch.Tensor, + dim: int, + minimum_scale: float, + use_bessel_correction: bool = True, + stabilize_with_global: bool = False, + scale_factor_exponent: float = 10.0, + prefix_length: int | None = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dim == -1, "compute_causal_statistics only supports dim=-1 (last dimension)" + + with torch.no_grad(): + weights = weights * padding_mask + + try: + high_precision_data = data.to(torch.float64) + high_precision_weights = weights.to(torch.float64) + except TypeError: + warnings.warn( + f"Float64 is not supported by device {data.device}. " + "Using float32 instead for causal scaler calculations.", + RuntimeWarning, + ) + high_precision_data = data.to(torch.float32) + high_precision_weights = weights.to(torch.float32) + + prev_deterministic = torch.are_deterministic_algorithms_enabled() + if prev_deterministic and data.device.type == "cuda": + torch.use_deterministic_algorithms(False) + + try: + weighted_data = high_precision_weights * high_precision_data + + cum_weights = torch.cumsum(high_precision_weights, dim=dim) + cum_values = torch.cumsum(weighted_data, dim=dim) + + denominator = cum_weights.clamp_min(1.0) + causal_means = cum_values / denominator + + shifted_means = torch.zeros_like(causal_means) + shifted_means[..., 1:] = causal_means[..., :-1] + + delta = high_precision_data - shifted_means + increment = delta * (high_precision_data - causal_means) * high_precision_weights + m_2 = torch.cumsum(increment, dim=dim) + + if use_bessel_correction: + causal_variance = m_2 / torch.clamp(denominator - 1.0, min=1.0) + else: + causal_variance = m_2 / denominator + + causal_scale = torch.sqrt(causal_variance + minimum_scale) + + if stabilize_with_global: + if prefix_length is not None: + prefix_mask = torch.zeros_like(weights) + prefix_mask[..., :prefix_length] = 1.0 + weighted_data = weighted_data * prefix_mask + weights = weights * prefix_mask + padding_mask = padding_mask * prefix_mask + + scale_factor_min = 10.0 ** (-scale_factor_exponent) + scale_factor_max = 10.0**scale_factor_exponent + + global_denominator = (weights * padding_mask).sum(dim, keepdim=True).clamp_min(1.0) + global_means = (weighted_data).sum(dim, keepdim=True) / global_denominator + global_means = torch.nan_to_num(global_means) + + global_variance = (((high_precision_data - global_means) * weights * padding_mask) ** 2).sum( + dim, keepdim=True + ) / global_denominator + global_scale = torch.sqrt(global_variance + minimum_scale) + + expanded_global_scale = global_scale.expand_as(causal_scale) + min_allowed_scale = expanded_global_scale * scale_factor_min + max_allowed_scale = expanded_global_scale * scale_factor_max + + causal_scale = torch.clamp( + causal_scale, + min=torch.max(torch.tensor(minimum_scale, device=causal_scale.device), min_allowed_scale), + max=max_allowed_scale, + ) + + causal_means = causal_means.to(data.dtype) + causal_scale = causal_scale.to(data.dtype) + + finally: + if prev_deterministic and data.device.type == "cuda": + torch.use_deterministic_algorithms(True) + + return causal_means, causal_scale + + +class CausalStdMeanScaler(Scaler): + def __init__( + self, + dim: int = -1, + minimum_scale: float = 0.1, + use_bessel_correction: bool = True, + stabilize_with_global: bool = False, + scale_factor_exponent: float = 10.0, + ) -> None: + super().__init__() + assert dim == -1, "CausalStdMeanScaler only supports dim=-1 (last dimension)" + self.dim = dim + self.minimum_scale = minimum_scale + self.use_bessel_correction = use_bessel_correction + self.stabilize_with_global = stabilize_with_global + self.scale_factor_exponent = scale_factor_exponent + + def __call__( + self, + data: torch.Tensor, + padding_mask: torch.Tensor, + weights: torch.Tensor, + prefix_length: int | None = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert data.shape == weights.shape, "data and weights must have same shape" + assert len(data.shape) == 3, "Input data must have shape [batch, variates, time_steps]" + + causal_means, causal_scale = compute_causal_statistics( + data, + weights, + padding_mask, + self.dim, + self.minimum_scale, + self.use_bessel_correction, + self.stabilize_with_global, + self.scale_factor_exponent, + prefix_length, + ) + + scaled_data = (data - causal_means) / causal_scale + + return scaled_data, causal_means, causal_scale + + +class CausalPatchStdMeanScaler(Scaler): + def __init__( + self, + dim: int = -1, + patch_size: int = 32, + minimum_scale: float = 0.1, + use_bessel_correction: bool = True, + stabilize_with_global: bool = False, + scale_factor_exponent: float = 10.0, + ) -> None: + super().__init__() + assert dim == -1, "CausalPatchStdMeanScaler only supports dim=-1 (last dimension)" + self.dim = dim + self.patch_size = patch_size + self.minimum_scale = minimum_scale + self.use_bessel_correction = use_bessel_correction + self.stabilize_with_global = stabilize_with_global + self.scale_factor_exponent = scale_factor_exponent + + def __call__( + self, + data: torch.Tensor, + padding_mask: torch.Tensor, + weights: torch.Tensor, + prefix_length: int | None = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert data.shape == weights.shape, "data and weights must have same shape" + assert len(data.shape) == 3, "Input data must have shape [batch, variates, time_steps]" + + with torch.no_grad(): + time_steps = data.shape[-1] + assert ( + time_steps % self.patch_size == 0 + ), f"Time steps ({time_steps}) must be divisible by patch size ({self.patch_size})" + + causal_means, causal_scale = compute_causal_statistics( + data, + weights, + padding_mask, + -1, + self.minimum_scale, + self.use_bessel_correction, + self.stabilize_with_global, + self.scale_factor_exponent, + prefix_length, + ) + + means_unfolded = causal_means.unfold(-1, self.patch_size, self.patch_size) + scales_unfolded = causal_scale.unfold(-1, self.patch_size, self.patch_size) + + patch_stats_means = means_unfolded[..., -1] + patch_stats_scales = scales_unfolded[..., -1] + + patch_means = repeat(patch_stats_means, "b v p -> b v (p s)", s=self.patch_size) + patch_scales = repeat(patch_stats_scales, "b v p -> b v (p s)", s=self.patch_size) + + scaled_data = (data - patch_means) / patch_scales + + return scaled_data, patch_means, patch_scales + + +# for deserialization of SafeTensors checkpoints +scaler_types = { + "": StdMeanScaler, + "": CausalStdMeanScaler, + "": CausalPatchStdMeanScaler, +} diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py new file mode 100644 index 0000000000000..1710c9eeadf05 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. + +import json +import os +import re +from pathlib import Path +from typing import Dict, Optional, Union + +import safetensors.torch as safetorch +import torch + +from .attention import XFORMERS_AVAILABLE +from .backbone import TotoBackbone +from .transformer import XFORMERS_SWIGLU_AVAILABLE + + +class Toto(torch.nn.Module): + """ + PyTorch module for Toto (Timeseries-Optimized Transformer for Observability). + This class is used internally for checkpoint loading logic. + """ + + def __init__( + self, + patch_size: int, + stride: int, + embed_dim: int, + num_layers: int, + num_heads: int, + mlp_hidden_dim: int, + dropout: float, + spacewise_every_n_layers: int, + scaler_cls: str, + output_distribution_classes: list[str], + spacewise_first: bool = True, + output_distribution_kwargs: dict | None = None, + use_memory_efficient_attention: bool = True, + stabilize_with_global: bool = True, + scale_factor_exponent: float = 10.0, + **model_kwargs, + ): + super().__init__() + self.model = TotoBackbone( + patch_size=patch_size, + stride=stride, + embed_dim=embed_dim, + num_layers=num_layers, + num_heads=num_heads, + mlp_hidden_dim=mlp_hidden_dim, + dropout=dropout, + spacewise_every_n_layers=spacewise_every_n_layers, + scaler_cls=scaler_cls, + output_distribution_classes=output_distribution_classes, + spacewise_first=spacewise_first, + output_distribution_kwargs=output_distribution_kwargs, + use_memory_efficient_attention=use_memory_efficient_attention, + stabilize_with_global=stabilize_with_global, + scale_factor_exponent=scale_factor_exponent, + ) + + @classmethod + def load_from_checkpoint( + cls, + checkpoint_path, + map_location: str = "cpu", + strict=True, + **model_kwargs, + ): + if os.path.isdir(checkpoint_path): + safetensors_file = os.path.join(checkpoint_path, "model.safetensors") + else: + safetensors_file = checkpoint_path + + if os.path.exists(safetensors_file): + model_state = safetorch.load_file(safetensors_file, device=map_location) + else: + raise FileNotFoundError(f"Model checkpoint not found at: {safetensors_file}") + + config_file = os.path.join(checkpoint_path, "config.json") + config = {} + if os.path.exists(config_file): + with open(config_file, "r") as f: + config = json.load(f) + + config.update(model_kwargs) + + remapped_state_dict = cls._map_state_dict_keys( + model_state, XFORMERS_SWIGLU_AVAILABLE and not config.get("pre_xformers_checkpoint", False) + ) + + if not XFORMERS_AVAILABLE and config.get("use_memory_efficient_attention", True): + config["use_memory_efficient_attention"] = False + + instance = cls(**config) + instance.to(map_location) + + filtered_remapped_state_dict = { + k: v + for k, v in remapped_state_dict.items() + if k in instance.state_dict() and not k.endswith("rotary_emb.freqs") + } + + instance.load_state_dict(filtered_remapped_state_dict, strict=strict) + return instance + + @staticmethod + def _map_state_dict_keys(state_dict, use_fused_swiglu): + if use_fused_swiglu: + remap_keys = { + "mlp.0.weight": "mlp.0.w12.weight", + "mlp.0.bias": "mlp.0.w12.bias", + "mlp.2.weight": "mlp.0.w3.weight", + "mlp.2.bias": "mlp.0.w3.bias", + } + else: + remap_keys = { + "mlp.0.w12.weight": "mlp.0.weight", + "mlp.0.w12.bias": "mlp.0.bias", + "mlp.0.w3.weight": "mlp.2.weight", + "mlp.0.w3.bias": "mlp.2.bias", + } + + def replace_key(text): + for pattern, replacement in remap_keys.items(): + text = re.sub(pattern, replacement, text) + return text + + return {replace_key(k): v for k, v in state_dict.items()} + + @property + def device(self): + return next(self.model.parameters()).device diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py new file mode 100644 index 0000000000000..9979bac960269 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py @@ -0,0 +1,287 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. + +import warnings +from typing import Literal, Optional, Union, cast + +import torch +import torch.nn.functional as F +from einops import rearrange +from jaxtyping import Bool, Float, Int +from rotary_embedding_torch import RotaryEmbedding + +from .attention import ( + AttentionAxis, + MultiHeadAttention, + SpaceWiseMultiheadAttention, + TimeWiseMultiheadAttention, +) +from .feed_forward import SwiGLU +from .fusion import Fusion +from .rope import TimeAwareRotaryEmbedding +from .util import KVCache, RMSNorm, make_batched_block_mask + +try: + from xformers.ops.swiglu_op import SwiGLU as SwiGLU_fused + + XFORMERS_SWIGLU_AVAILABLE = True +except ImportError: + warnings.warn( + "xFormers fused SwiGLU kernel not found. " + "Using native PyTorch implementation for feed-forward layers.", + ImportWarning, + ) + XFORMERS_SWIGLU_AVAILABLE = False + + +class TransformerLayer(torch.nn.Module): + embed_dim: int + num_heads: int + mlp_hidden_dim: int + dropout: float + attention_axis: AttentionAxis + + def __init__( + self, + embed_dim: int, + num_heads: int, + mlp_hidden_dim: int, + dropout: float, + rotary_emb: RotaryEmbedding = None, + attention_axis: AttentionAxis = AttentionAxis.TIME, + RMS_norm: bool = True, + use_memory_efficient_attention: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.mlp_hidden_dim = mlp_hidden_dim + self.dropout = dropout + self.attention_axis = attention_axis + + if RMS_norm: + self.norm1: Union[RMSNorm, torch.nn.LayerNorm] = RMSNorm(embed_dim) + self.norm2: Union[RMSNorm, torch.nn.LayerNorm] = RMSNorm(embed_dim) + else: + self.norm1 = torch.nn.LayerNorm(embed_dim) + self.norm2 = torch.nn.LayerNorm(embed_dim) + + self.attention: MultiHeadAttention + + if attention_axis == AttentionAxis.TIME: + self.attention = TimeWiseMultiheadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + dropout=dropout, + rotary_emb=rotary_emb, + use_memory_efficient_attention=use_memory_efficient_attention, + ) + elif attention_axis == AttentionAxis.SPACE: + self.attention = SpaceWiseMultiheadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + dropout=dropout, + rotary_emb=None, + use_memory_efficient_attention=use_memory_efficient_attention, + ) + else: + raise ValueError("Invalid attention axis") + + if XFORMERS_SWIGLU_AVAILABLE: + self.mlp = torch.nn.Sequential( + SwiGLU_fused(in_features=embed_dim, hidden_features=mlp_hidden_dim), + torch.nn.Dropout(dropout), + ) + else: + self.mlp = torch.nn.Sequential( + torch.nn.Linear(embed_dim, 2 * mlp_hidden_dim), + SwiGLU(), + torch.nn.Linear(mlp_hidden_dim, embed_dim), + torch.nn.Dropout(dropout), + ) + + def forward( + self, + layer_idx: int, + inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"], + attention_mask: Optional[ + Union[ + Bool[torch.Tensor, "batch seq_len variate variate"], + Bool[torch.Tensor, "batch #variate seq_len seq_len"], + ] + ] = None, + kv_cache: Optional[KVCache] = None, + ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]: + pre_norm_1 = self.norm1(inputs) + hidden_state = inputs + self.attention(layer_idx, pre_norm_1, attention_mask, kv_cache).contiguous() + + pre_norm_2 = self.norm2(hidden_state) + return hidden_state + self.mlp(pre_norm_2) + + +class Transformer(torch.nn.Module): + def __init__( + self, + num_layers: int, + embed_dim: int, + num_heads: int, + mlp_hidden_dim: int, + dropout: float, + spacewise_every_n_layers: int, + spacewise_first: bool, + use_memory_efficient_attention: bool = True, + *, + fusion: Optional[Fusion] = None, + ): + super().__init__() + + assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads." + + self.rotary_emb = TimeAwareRotaryEmbedding( + embed_dim // num_heads, + use_xpos=True, + cache_if_possible=True, + seq_before_head_dim=use_memory_efficient_attention, + ) + attention_axes = self._get_layer_types(num_layers, spacewise_every_n_layers, spacewise_first) + + self.use_memory_efficient_attention = use_memory_efficient_attention + self.fusion = fusion + + self.layers = torch.nn.ModuleList( + [ + TransformerLayer( + embed_dim=embed_dim, + num_heads=num_heads, + mlp_hidden_dim=mlp_hidden_dim, + dropout=dropout, + rotary_emb=self.rotary_emb, + attention_axis=attention_axes[i], + use_memory_efficient_attention=self.use_memory_efficient_attention, + ) + for i in range(num_layers) + ] + ) + + def _get_mask( + self, + num_heads: int, + dtype: torch.dtype, + id_mask: Optional[torch.Tensor] = None, + ) -> Union[ + Bool[torch.Tensor, "batch num_heads seq_len seq_len"], + Float[torch.Tensor, "batch num_heads seq_len seq_len"], + Bool[torch.Tensor, "batch num_heads variate variate"], + Float[torch.Tensor, "batch num_heads variate variate"], + ]: + if id_mask is None: + raise ValueError("id_mask must be provided for spacewise masks.") + + mask = make_batched_block_mask(id_mask.transpose(-1, -2)) + + if self.use_memory_efficient_attention: + mask = self._pad_to_multiple(mask) + mask = mask.float().masked_fill(~mask, float("-inf")).masked_fill(mask, 0.0).to(dtype) + + mask = rearrange(mask, "batch seq_len variate1 variate2 -> (batch seq_len) 1 variate1 variate2") + return mask.expand(-1, num_heads, -1, -1).contiguous() + + def _pad_to_multiple( + self, + tensor: torch.Tensor, + multiple: int = 8, + causal: bool = False, + ) -> torch.Tensor: + pad_amount = (multiple - tensor.shape[-1] % multiple) % multiple + if pad_amount > 0: + new_size = tensor.shape[-1] + pad_amount + if causal: + full_mask = torch.tril(torch.ones((new_size, new_size), dtype=tensor.dtype, device=tensor.device)) + full_mask[: tensor.shape[-1], : tensor.shape[-1]] = tensor + tensor = full_mask + else: + tensor = F.pad(tensor, (0, pad_amount, 0, pad_amount)) + return tensor + + def _get_layer_types( + self, + num_layers: int, + spacewise_every_n_layers: int, + spacewise_first: bool, + ) -> list[AttentionAxis]: + if spacewise_every_n_layers == -1: + return [AttentionAxis.TIME] * num_layers + assert num_layers % spacewise_every_n_layers == 0 + + block = [AttentionAxis.TIME] * (spacewise_every_n_layers - 1) + + if spacewise_first: + block = [AttentionAxis.SPACE] + block + else: + block = block + [AttentionAxis.SPACE] + + return block * (num_layers // spacewise_every_n_layers) + + def forward( + self, + inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"], + id_mask: Float[torch.Tensor, "batch #variate seq_len"], + kv_cache: Optional[KVCache] = None, + variate_label_embeds: Optional[Float[torch.Tensor, "batch variate 1 embed_dim"]] = None, + ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]: + + if self.fusion is not None and variate_label_embeds is not None: + should_apply_fusion = True + if kv_cache is not None: + kv_len_tensor = kv_cache.current_len(0) + kv_len = int(kv_len_tensor) if isinstance(kv_len_tensor, torch.Tensor) else kv_len_tensor + should_apply_fusion = kv_len == 0 + if should_apply_fusion: + inputs = self.fusion(inputs, variate_label_embeds=variate_label_embeds) + + batch, _, seq_len, _ = inputs.shape + + if id_mask is not None and id_mask.shape[-1] != seq_len: + added = int(seq_len - id_mask.shape[-1]) + if added > 0: + pad_slice = id_mask[..., :1] + id_mask = torch.cat([pad_slice.expand(-1, -1, added), id_mask], dim=-1) + + seq_len = (kv_cache.seq_len(1) if kv_cache else 0) + seq_len + + num_heads: int = cast(int, self.layers[0].num_heads) + + timewise_attention_mask = None + + spacewise_attention_mask = self._get_mask( + num_heads=num_heads, + dtype=inputs.dtype, + id_mask=id_mask, + ) + + for layer_idx, layer in enumerate(self.layers): + inputs = layer( + layer_idx, + inputs, + (timewise_attention_mask if layer.attention_axis == AttentionAxis.TIME else spacewise_attention_mask), + kv_cache, + ) + return inputs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py new file mode 100644 index 0000000000000..182c38d195653 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py @@ -0,0 +1,210 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. + +import warnings +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional, TypeAlias, Union + +import torch +from einops import rearrange +from jaxtyping import Float, Int + +from .attention import TimeWiseMultiheadAttention + +if TYPE_CHECKING: + from .transformer import TransformerLayer + +try: + from xformers import _is_triton_available + from xformers.ops.rmsnorm import rms_norm, rms_norm_add + + XFORMERS_RMSNORM_AVAILABLE = True +except ImportError: + warnings.warn( + "xFormers fused RMSNorm implementation not available. Will not use " + "optimized kernel for inference.", + ImportWarning, + ) + + def _is_triton_available(): + return False + + XFORMERS_RMSNORM_AVAILABLE = False + + +class RMSNorm(torch.nn.Module): + """ + Wraps xFormers' rms_norm for eval/frozen mode, and does a Python fallback for train mode. + """ + + def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-8): + super(RMSNorm, self).__init__() + self.eps = eps + if include_weight: + self.scale: Optional[torch.nn.Parameter] = torch.nn.Parameter(torch.ones(dim)) + else: + self.scale = None + + def forward(self, x: torch.Tensor): + if ( + ((not self.training) or (self.scale is not None and not self.scale.requires_grad)) + and XFORMERS_RMSNORM_AVAILABLE + and _is_triton_available() + ): + return rms_norm(x, self.scale, self.eps) + + x_normed = x / torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + return x_normed if self.scale is None else x_normed * self.scale + + def increment_and_forward_(self, x: torch.Tensor, y: torch.Tensor): + if (not self.training) or (self.scale is not None and not self.scale.requires_grad): + return rms_norm_add(x, y, self.scale, self.eps) + return self.forward(x + y) + + +def make_batched_block_mask(t: torch.Tensor) -> torch.Tensor: + unsqueezed = rearrange(t, "... d -> ... 1 d") + return unsqueezed == unsqueezed.transpose(-1, -2) + + +K: TypeAlias = Float[torch.Tensor, "batch_size_X_num_variates num_heads seq_len head_dim"] +V: TypeAlias = Float[torch.Tensor, "batch_size_X_num_variates num_heads seq_len head_dim"] +KV: TypeAlias = tuple[K, V] + + +@dataclass +class KVCache: + """ + Key/Value cache for storing intermediate attention values during multistep inference. + Only stores KV cache for timewise layers, skipping spacewise layers. + """ + + batch_size: int + num_variates: int + transformer_layers: List["TransformerLayer"] + num_layers: int + embed_dim: int + num_heads: int + max_seq_len: int + device: torch.device = torch.device("cpu") + dtype: torch.dtype = torch.float32 + use_memory_efficient_attention: bool = True + + _keys: Union[ + Float[torch.Tensor, "time_layer_count batch_size_X_num_variates max_seq_len num_heads head_dim"], + Float[torch.Tensor, "time_layer_count batch_size_X_num_variates num_heads max_seq_len head_dim"], + ] = field(init=False) + + _values: Union[ + Float[torch.Tensor, "time_layer_count batch_size_X_num_variates max_seq_len num_heads head_dim"], + Float[torch.Tensor, "time_layer_count batch_size_X_num_variates num_heads max_seq_len head_dim"], + ] = field(init=False) + + _current_idx: Int[torch.Tensor, "time_layer_count"] = field(init=False) + _layer_cache_map: Int[torch.Tensor, "num_layers"] = field(init=False) + + def __post_init__(self): + assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads" + head_dim = self.embed_dim // self.num_heads + + time_layer_indices = [ + i + for i in range(self.num_layers) + if isinstance(self.transformer_layers[i].attention, TimeWiseMultiheadAttention) + ] + + time_layer_count = max(1, len(time_layer_indices)) + if self.use_memory_efficient_attention: + shape = ( + time_layer_count, + self.batch_size * self.num_variates, + self.max_seq_len, + self.num_heads, + head_dim, + ) + else: + shape = ( + time_layer_count, + self.batch_size * self.num_variates, + self.num_heads, + self.max_seq_len, + head_dim, + ) + self._keys = torch.zeros(shape, device=self.device, dtype=self.dtype) + self._values = torch.zeros_like(self._keys) + self._current_idx = torch.zeros(time_layer_count, device=self.device, dtype=torch.int) + self._layer_cache_map = torch.zeros((self.num_layers,), dtype=torch.int, device=self.device) + for cache_idx, layer_idx in enumerate(time_layer_indices): + self._layer_cache_map[layer_idx] = int(cache_idx) + + def __getitem__(self, layer_idx: int) -> KV: + cache_idx = int(self._layer_cache_map[layer_idx].item()) + end_idx = int(self._current_idx[cache_idx].item()) + + if self.use_memory_efficient_attention: + return self._keys[cache_idx, :, :end_idx, :, :], self._values[cache_idx, :, :end_idx, :, :] + else: + return self._keys[cache_idx, :, :, :end_idx, :], self._values[cache_idx, :, :, :end_idx, :] + + def current_len(self, cache_idx: int) -> int: + return int(self._current_idx[cache_idx].item()) if self._current_idx.numel() > 0 else 0 + + def seq_len(self, layer_idx: int) -> int: + cache_idx = int(self._layer_cache_map[layer_idx].item()) + return self.current_len(cache_idx) + + def append(self, layer_idx: int, kv: KV): + cache_idx = int(self._layer_cache_map[layer_idx].item()) + keys, values = kv + + assert keys.shape == values.shape, "keys and values must have the same shape" + assert ( + keys.shape[0] == self.batch_size * self.num_variates + ), "keys and values must have batch_size * num_variates as their first dimension" + + if self.use_memory_efficient_attention: + assert keys.shape[2] == self.num_heads + else: + assert keys.shape[1] == self.num_heads + assert keys.shape[3] == self.embed_dim // self.num_heads + + start_idx = self._current_idx[cache_idx] + if self.use_memory_efficient_attention: + end_idx = start_idx + keys.shape[1] + else: + end_idx = start_idx + keys.shape[2] + assert end_idx <= self.max_seq_len, ( + f"max_seq_len exceeded {end_idx} > {self.max_seq_len}, keys.shape: {keys.shape}" + ) + + if self.use_memory_efficient_attention: + self._keys[cache_idx, :, start_idx:end_idx, :, :] = keys + self._values[cache_idx, :, start_idx:end_idx, :, :] = values + else: + self._keys[cache_idx, :, :, start_idx:end_idx, :] = keys + self._values[cache_idx, :, :, start_idx:end_idx, :] = values + + self._current_idx[cache_idx] = end_idx + + def reset(self): + self._keys.zero_() + self._values.zero_() + self._current_idx.zero_() diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py index f2fbe16390815..d5f0ca34003ee 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py @@ -16,113 +16,151 @@ # under the License. # -import torch +import json +import os + +import safetensors.torch as safetorch +from transformers import PreTrainedModel from iotdb.ainode.core.log import Logger +from .configuration_toto import TotoConfig +from .model.attention import XFORMERS_AVAILABLE +from .model.backbone import TotoBackbone +from .model.toto import Toto +from .model.transformer import XFORMERS_SWIGLU_AVAILABLE + logger = Logger() -class TotoForPrediction(torch.nn.Module): +class TotoPreTrainedModel(PreTrainedModel): + """Abstract base class for all Toto model variants.""" + + config_class = TotoConfig + base_model_prefix = "model" + supports_gradient_checkpointing = False + + def _init_weights(self, module): + # Weights are loaded from the pretrained checkpoint; no random initialisation needed. + pass + + +class TotoForPrediction(TotoPreTrainedModel): """ - Wrapper around the Toto model for AINode integration. + Toto (Timeseries-Optimized Transformer for Observability) model for time series prediction. - Toto (Time Series Optimized Transformer for Observability) is a 151M parameter - foundation model for multivariate time series forecasting. This wrapper delegates - model loading to the ``toto-ts`` package while providing a compatible interface - for AINode's model loading mechanism. + Integrates the Toto backbone with AINode's model loading mechanism using the + transformers PreTrainedModel interface. Weights are loaded directly from the + Datadog/Toto-Open-Base-1.0 safetensors checkpoint. - The underlying Toto model uses ``huggingface_hub.ModelHubMixin`` for ``from_pretrained`` - support, which differs from the standard ``transformers.PreTrainedModel`` pattern. - This wrapper bridges that gap. + The backbone is stored as ``self.model`` so that safetensors key prefixes + (``model.*``) map directly to parameters without any renaming. Reference: https://huggingface.co/Datadog/Toto-Open-Base-1.0 """ - def __init__(self, toto_model): - """ - Initialize the wrapper with a loaded Toto model instance. - - Args: - toto_model: A ``toto.model.toto.Toto`` instance. - """ - super().__init__() - self.toto = toto_model + def __init__(self, config: TotoConfig): + super().__init__(config) + # Backbone stored as self.model so safetensors keys (model.*) match directly. + self.model = TotoBackbone( + patch_size=config.patch_size, + stride=config.stride, + embed_dim=config.embed_dim, + num_layers=config.num_layers, + num_heads=config.num_heads, + mlp_hidden_dim=config.mlp_hidden_dim, + dropout=config.dropout, + spacewise_every_n_layers=config.spacewise_every_n_layers, + scaler_cls=config.scaler_cls, + output_distribution_classes=config.output_distribution_classes, + spacewise_first=config.spacewise_first, + use_memory_efficient_attention=config.use_memory_efficient_attention, + stabilize_with_global=config.stabilize_with_global, + scale_factor_exponent=config.scale_factor_exponent, + ) + self.post_init() @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): """ - Load a Toto model from a local directory or HuggingFace Hub repository. + Load TotoForPrediction from a local directory containing ``config.json`` + and ``model.safetensors``. - This delegates to ``toto.model.toto.Toto.from_pretrained()`` which uses - ``ModelHubMixin`` to load the model weights and configuration. + This override is required because: + 1. The safetensors file uses legacy SwiGLU key names that need remapping. + 2. The config uses class-path strings for ``scaler_cls`` and + ``output_distribution_classes`` that must not be filtered out. Args: - pretrained_model_name_or_path (str): Path to a local directory containing - ``config.json`` and ``model.safetensors``, or a HuggingFace Hub repo ID - (e.g., ``Datadog/Toto-Open-Base-1.0``). - **kwargs: Additional keyword arguments passed to the underlying loader. + pretrained_model_name_or_path (str): Path to a local directory. + **kwargs: Extra key/value pairs merged into the config before construction. Returns: - TotoForPrediction: A wrapper instance containing the loaded Toto model. + TotoForPrediction: Fully initialised and weight-loaded model in eval mode. """ - from toto.model.toto import Toto + if os.path.isdir(pretrained_model_name_or_path): + config_file = os.path.join(pretrained_model_name_or_path, "config.json") + safetensors_file = os.path.join( + pretrained_model_name_or_path, "model.safetensors" + ) + else: + raise ValueError( + f"pretrained_model_name_or_path must be a local directory, " + f"got: {pretrained_model_name_or_path}" + ) + + # ── Load config ────────────────────────────────────────────────────── + config_dict: dict = {} + if os.path.exists(config_file): + with open(config_file, "r") as f: + config_dict = json.load(f) + config_dict.update(kwargs) + + # Disable xFormers memory-efficient attention if the library is absent. + if not XFORMERS_AVAILABLE and config_dict.get( + "use_memory_efficient_attention", True + ): + config_dict["use_memory_efficient_attention"] = False + + config = TotoConfig(**config_dict) + + # ── Instantiate model ───────────────────────────────────────────────── + instance = cls(config) + + # ── Load safetensors weights ────────────────────────────────────────── + if not os.path.exists(safetensors_file): + raise FileNotFoundError( + f"Model checkpoint not found at: {safetensors_file}" + ) + + state_dict = safetorch.load_file(safetensors_file, device="cpu") + + # Remap SwiGLU weight names if the fused xFormers kernel is available. + use_fused_swiglu = XFORMERS_SWIGLU_AVAILABLE and not config_dict.get( + "pre_xformers_checkpoint", False + ) + state_dict = Toto._map_state_dict_keys(state_dict, use_fused_swiglu) - toto_model = Toto.from_pretrained(pretrained_model_name_or_path, **kwargs) - logger.info(f"Loaded Toto model from {pretrained_model_name_or_path}") - return cls(toto_model) + # Filter to keys that exist in the model, skipping cached rotary buffers. + model_state = instance.state_dict() + filtered_state_dict = { + k: v + for k, v in state_dict.items() + if k in model_state and not k.endswith("rotary_emb.freqs") + } - @classmethod - def from_config(cls, config): - """ - Create a Toto model from a configuration (for training from scratch). + instance.load_state_dict(filtered_state_dict, strict=False) + instance.eval() - Args: - config: A ``TotoConfig`` or compatible configuration object. - - Returns: - TotoForPrediction: A wrapper instance containing a newly initialized Toto model. - """ - from toto.model.toto import Toto - - toto_model = Toto( - patch_size=getattr(config, "patch_size", 32), - stride=getattr(config, "stride", 32), - embed_dim=getattr(config, "embed_dim", 1024), - num_layers=getattr(config, "num_layers", 18), - num_heads=getattr(config, "num_heads", 16), - mlp_hidden_dim=getattr(config, "mlp_hidden_dim", 2816), - dropout=getattr(config, "dropout", 0.0), - spacewise_every_n_layers=getattr(config, "spacewise_every_n_layers", 3), - scaler_cls=getattr(config, "scaler_cls", "per_variate_causal"), - output_distribution_classes=getattr( - config, "output_distribution_classes", ["student_t_mixture"] - ), - spacewise_first=getattr(config, "spacewise_first", True), - use_memory_efficient_attention=getattr( - config, "use_memory_efficient_attention", True - ), - stabilize_with_global=getattr(config, "stabilize_with_global", True), - scale_factor_exponent=getattr(config, "scale_factor_exponent", 10.0), - ) - return cls(toto_model) + logger.info(f"Loaded Toto model from {pretrained_model_name_or_path}") + return instance @property def backbone(self): - """ - Access the underlying TotoBackbone model used for inference. - - Returns: - The ``TotoBackbone`` instance from the Toto model. - """ - return self.toto.model + """The underlying ``TotoBackbone`` used for inference.""" + return self.model @property def device(self): - """ - Get the device of the model parameters. - - Returns: - torch.device: The device where the model parameters reside. - """ - return self.toto.device + """Device on which model parameters reside.""" + return next(self.parameters()).device diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py index 7cfd26b43702d..311e679d3b703 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py @@ -23,36 +23,35 @@ from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline -logger = logging.getLogger(__name__) - -_TOTO_INSTALL_MSG = ( - "toto-ts is required to use the Toto model but is not installed.\n" - "Install it with: pip install toto-ts\n" - "Note: toto-ts pins specific versions of torch, numpy, and transformers " - "that may conflict with other AINode dependencies. Install in a separate " - "environment if needed." -) +from .data.util.dataset import MaskedTimeseries +from .inference.forecaster import TotoForecaster +logger = logging.getLogger(__name__) -def _import_toto(): - try: - from toto.data.util.dataset import MaskedTimeseries - from toto.inference.forecaster import TotoForecaster - return MaskedTimeseries, TotoForecaster - except ImportError as e: - raise ImportError(_TOTO_INSTALL_MSG) from e +class TotoPipeline(ForecastPipeline): + """ + Inference pipeline for the Toto time series foundation model. + Converts raw input tensors into ``MaskedTimeseries`` objects and delegates + autoregressive decoding to ``TotoForecaster``. The forecaster is created + lazily on the first call to ``forecast()`` so that pipeline construction + does not require a live model (useful during import / registration time). + """ -class TotoPipeline(ForecastPipeline): def __init__(self, model_info, **model_kwargs): super().__init__(model_info, **model_kwargs) - _, TotoForecaster = _import_toto() - self.forecaster = TotoForecaster(self.model.backbone) + # Forecaster is created lazily to avoid issues at construction time. + self._forecaster: TotoForecaster | None = None + + def _get_forecaster(self) -> TotoForecaster: + """Return the cached forecaster, creating it on first call.""" + if self._forecaster is None: + self._forecaster = TotoForecaster(self.model.backbone) + return self._forecaster def preprocess(self, inputs, **infer_kwargs): super().preprocess(inputs, **infer_kwargs) - MaskedTimeseries, _ = _import_toto() processed_inputs = [] for item in inputs: @@ -102,6 +101,8 @@ def forecast(self, inputs, **infer_kwargs): num_samples = infer_kwargs.get("num_samples", None) samples_per_batch = infer_kwargs.get("samples_per_batch", 10) + forecaster = self._get_forecaster() + outputs = [] for masked_ts in inputs: masked_ts = masked_ts._replace( @@ -113,15 +114,14 @@ def forecast(self, inputs, **infer_kwargs): self.model.device ), ) - result = self.forecaster.forecast( + result = forecaster.forecast( masked_ts, prediction_length=output_length, num_samples=num_samples, samples_per_batch=samples_per_batch, ) mean = result.mean - if mean.ndim == 3: - mean = mean.mean(dim=0) + # Remove batch dimension if present (batch=1 squeeze). if mean.ndim == 3 and mean.shape[0] == 1: mean = mean.squeeze(0) outputs.append(mean) diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml index 9a142fe72596b..0dc630fb0ff60 100644 --- a/iotdb-core/ainode/pyproject.toml +++ b/iotdb-core/ainode/pyproject.toml @@ -117,6 +117,7 @@ setuptools = ">=75.3.0" joblib = ">=1.4.2" urllib3 = "2.6.3" jaxtyping = ">=0.2.24" +rotary-embedding-torch = ">=0.8.0" [tool.poetry.scripts] ainode = "iotdb.ainode.core.script:main" From 60282a3ca515124e3492dc6282771043ab12b759 Mon Sep 17 00:00:00 2001 From: graceli02 Date: Sun, 22 Mar 2026 21:43:51 -0400 Subject: [PATCH 05/11] [AINode] Fix RAT license check: move toto attribution to root NOTICE Apache RAT flagged the standalone NOTICE file inside the toto Python package because the project's RAT config does not exclude plain NOTICE files. Moved the Datadog/toto attribution to the standard location (project root NOTICE) and removed the inner NOTICE file. Co-Authored-By: Claude Sonnet 4.6 --- NOTICE | 9 +++++ .../iotdb/ainode/core/model/toto/NOTICE | 36 ------------------- 2 files changed, 9 insertions(+), 36 deletions(-) delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/toto/NOTICE diff --git a/NOTICE b/NOTICE index fa52a36987f48..429495c377b09 100644 --- a/NOTICE +++ b/NOTICE @@ -17,6 +17,15 @@ grant the users the right to the use of patent under the requirement of Apache 2 ============================================================================ +This product includes source code derived from the DataDog/toto project: + + Toto – Timeseries-Optimized Transformer for Observability + Copyright 2025 Datadog, Inc. + Licensed under the Apache License, Version 2.0 + https://github.com/DataDog/toto + +============================================================================ + Apache Commons Collections Copyright 2001-2019 The Apache Software Foundation diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/NOTICE b/iotdb-core/ainode/iotdb/ainode/core/model/toto/NOTICE deleted file mode 100644 index 23999458c3eee..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/NOTICE +++ /dev/null @@ -1,36 +0,0 @@ -Apache IoTDB – AINode: Toto model -Copyright 2025 The Apache Software Foundation - -This product includes software developed at -The Apache Software Foundation (http://www.apache.org/). - -============================================================================ - -This directory includes source code derived from the DataDog/toto project: - - Toto – Timeseries-Optimized Transformer for Observability - Copyright 2025 Datadog, Inc. - Licensed under the Apache License, Version 2.0 - https://github.com/DataDog/toto - -The following files are derived from that project: - - model/attention.py - model/backbone.py - model/distribution.py - model/embedding.py - model/feed_forward.py - model/fusion.py - model/rope.py - model/scaler.py - model/transformer.py - model/toto.py - model/util.py - data/util/dataset.py - inference/forecaster.py - -Each derived file carries the original copyright notice and the Apache 2.0 -license header as required by the original project's license. - -A copy of the Apache License, Version 2.0 is available at: - http://www.apache.org/licenses/LICENSE-2.0 From 13edc2352483069a1ae6d6b3027c507e049be38e Mon Sep 17 00:00:00 2001 From: graceli02 Date: Mon, 23 Mar 2026 00:52:37 -0400 Subject: [PATCH 06/11] [AINode] Fix toto runtime bugs: scaler aliases and output_distribution_kwargs - scaler.py: add short-name aliases ("per_variate", "per_variate_causal", "per_variate_causal_patch") to scaler_types dict so that config.json string values work without KeyError at backbone init time. - backbone.py: recognise "per_variate_causal_patch" string in the CausalPatchStdMeanScaler branch (alongside the legacy class-path string). - configuration_toto.py: add output_distribution_kwargs parameter with default {"k_components": 5} matching Datadog/Toto-Open-Base-1.0. - modeling_toto.py: pass output_distribution_kwargs from config to TotoBackbone so MixtureOfStudentTsOutput receives k_components. Fixes: KeyError 'per_variate_causal' in scaler_types lookup. Fixes: MixtureOfStudentTsOutput missing required positional arg k_components. Co-Authored-By: Claude Sonnet 4.6 --- .../iotdb/ainode/core/model/toto/configuration_toto.py | 3 +++ .../ainode/iotdb/ainode/core/model/toto/model/backbone.py | 5 ++++- .../ainode/iotdb/ainode/core/model/toto/model/scaler.py | 4 ++++ .../ainode/iotdb/ainode/core/model/toto/modeling_toto.py | 1 + 4 files changed, 12 insertions(+), 1 deletion(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py index 82e185c377da3..3da6b98329f4d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py @@ -47,6 +47,7 @@ def __init__( spacewise_every_n_layers: int = 3, scaler_cls: str = "per_variate_causal", output_distribution_classes: Optional[List[str]] = None, + output_distribution_kwargs: Optional[dict] = None, spacewise_first: bool = True, use_memory_efficient_attention: bool = True, stabilize_with_global: bool = True, @@ -65,6 +66,8 @@ def __init__( self.output_distribution_classes = output_distribution_classes or [ "student_t_mixture" ] + # k_components=5 is the default used by Datadog/Toto-Open-Base-1.0 + self.output_distribution_kwargs = output_distribution_kwargs or {"k_components": 5} self.spacewise_first = spacewise_first self.use_memory_efficient_attention = use_memory_efficient_attention self.stabilize_with_global = stabilize_with_global diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py index 50cf4943185e8..b11c07c89080a 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py @@ -76,7 +76,10 @@ def __init__( self.target_variate_label: Optional[torch.nn.Parameter] = None self.exogenous_variate_label: Optional[torch.nn.Parameter] = None - if scaler_cls == "": + if scaler_cls in ( + "", + "per_variate_causal_patch", + ): self.scaler = scaler_types[scaler_cls]( patch_size=patch_size, stabilize_with_global=stabilize_with_global, diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py index 255211f3772c5..c787b64fea849 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py @@ -296,4 +296,8 @@ def __call__( "": StdMeanScaler, "": CausalStdMeanScaler, "": CausalPatchStdMeanScaler, + # Short aliases used in config.json + "per_variate": StdMeanScaler, + "per_variate_causal": CausalStdMeanScaler, + "per_variate_causal_patch": CausalPatchStdMeanScaler, } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py index d5f0ca34003ee..08fda1c3c7231 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py @@ -73,6 +73,7 @@ def __init__(self, config: TotoConfig): spacewise_every_n_layers=config.spacewise_every_n_layers, scaler_cls=config.scaler_cls, output_distribution_classes=config.output_distribution_classes, + output_distribution_kwargs=config.output_distribution_kwargs, spacewise_first=config.spacewise_first, use_memory_efficient_attention=config.use_memory_efficient_attention, stabilize_with_global=config.stabilize_with_global, From 41a6db69f94f090e162ef263d651521bea11fb6b Mon Sep 17 00:00:00 2001 From: graceli02 Date: Mon, 23 Mar 2026 16:25:26 -0400 Subject: [PATCH 07/11] [AINode] Fix toto code style: Black formatting and isort imports Co-Authored-By: Claude Opus 4.6 --- .../core/model/toto/configuration_toto.py | 4 +- .../core/model/toto/inference/forecaster.py | 146 ++++++++++++++---- .../ainode/core/model/toto/model/attention.py | 113 +++++++++++--- .../ainode/core/model/toto/model/backbone.py | 51 ++++-- .../core/model/toto/model/distribution.py | 12 +- .../ainode/core/model/toto/model/embedding.py | 12 +- .../ainode/core/model/toto/model/fusion.py | 11 +- .../ainode/core/model/toto/model/scaler.py | 53 +++++-- .../ainode/core/model/toto/model/toto.py | 12 +- .../core/model/toto/model/transformer.py | 49 ++++-- .../ainode/core/model/toto/model/util.py | 79 +++++++--- 11 files changed, 414 insertions(+), 128 deletions(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py index 3da6b98329f4d..2a00fcc3be454 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py @@ -67,7 +67,9 @@ def __init__( "student_t_mixture" ] # k_components=5 is the default used by Datadog/Toto-Open-Base-1.0 - self.output_distribution_kwargs = output_distribution_kwargs or {"k_components": 5} + self.output_distribution_kwargs = output_distribution_kwargs or { + "k_components": 5 + } self.spacewise_first = spacewise_first self.use_memory_efficient_attention = use_memory_efficient_attention self.stabilize_with_global = stabilize_with_global diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py index 2099892293be8..2a9db2aa62950 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py @@ -26,8 +26,7 @@ import torch from einops import rearrange, repeat from jaxtyping import Bool, Float, Int -from torch.distributions import Distribution -from torch.distributions import TransformedDistribution +from torch.distributions import Distribution, TransformedDistribution from torch.distributions.transforms import AffineTransform from ..data.util.dataset import ( @@ -62,9 +61,13 @@ def mean(self): @dataclass(frozen=True) class Forecast: mean: Float[torch.Tensor, "batch variate future_time_steps"] - samples: Float[torch.Tensor, "batch variate future_time_steps samples"] | None = None + samples: Float[torch.Tensor, "batch variate future_time_steps samples"] | None = ( + None + ) - def quantile(self, q: float | torch.Tensor) -> Float[torch.Tensor, "batch variate future_time_steps"]: + def quantile( + self, q: float | torch.Tensor + ) -> Float[torch.Tensor, "batch variate future_time_steps"]: assert self.samples is not None, "samples must be provided to compute quantiles" assert isinstance(q, (float, torch.Tensor)), "q must be a float or a tensor" if isinstance(q, float): @@ -77,7 +80,9 @@ def median(self) -> Float[torch.Tensor, "batch variate future_time_steps"]: @property def std(self) -> Float[torch.Tensor, "batch variate future_time_steps"]: - assert self.samples is not None, "samples must be provided to compute standard deviation" + assert ( + self.samples is not None + ), "samples must be provided to compute standard deviation" return self.samples.std(dim=-1) @@ -100,14 +105,19 @@ def forecast( num_samples: int | None = None, samples_per_batch: int = 10, use_kv_cache: bool = True, - future_exogenous_variables: Float[torch.Tensor, "batch exogenous_variables future_time_steps"] | None = None, + future_exogenous_variables: ( + Float[torch.Tensor, "batch exogenous_variables future_time_steps"] | None + ) = None, ) -> Forecast: if len(inputs.series.shape) == 2: batch = cast(MaskedTimeseries, torch.utils.data.default_collate([inputs])) else: batch = inputs - if future_exogenous_variables is not None and len(future_exogenous_variables.shape) == 2: + if ( + future_exogenous_variables is not None + and len(future_exogenous_variables.shape) == 2 + ): future_exogenous_variables = future_exogenous_variables.unsqueeze(0) series = pad_array(batch.series, self.model.patch_embed.stride) @@ -115,9 +125,13 @@ def forecast( id_mask = batch.id_mask if id_mask is not None: id_mask = pad_id_mask(batch.id_mask, self.model.patch_embed.stride) - timestamp_seconds = pad_array(batch.timestamp_seconds, self.model.patch_embed.stride) - time_interval_seconds: Int[torch.Tensor, "batch variate series_len"] = torch.as_tensor( - batch.time_interval_seconds, device=series.device, dtype=torch.int + timestamp_seconds = pad_array( + batch.timestamp_seconds, self.model.patch_embed.stride + ) + time_interval_seconds: Int[torch.Tensor, "batch variate series_len"] = ( + torch.as_tensor( + batch.time_interval_seconds, device=series.device, dtype=torch.int + ) ) if num_samples is not None: @@ -176,24 +190,35 @@ def generate_mean( prediction_length: int, timestamp_seconds: Int[torch.Tensor, "batch variate time_steps"], time_interval_seconds: Int[torch.Tensor, "batch variate"], - input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"] | None = None, + input_padding_mask: ( + Bool[torch.Tensor, "batch variate time_steps"] | None + ) = None, id_mask: Float[torch.Tensor, "batch #variate time_steps"] | None = None, use_kv_cache: bool = False, future_exogenous_variables=None, num_exogenous_variables: int = 0, ) -> Float[torch.Tensor, "batch variate time_steps"]: if input_padding_mask is None: - input_padding_mask = torch.ones_like(inputs, dtype=torch.bool, device=inputs.device) + input_padding_mask = torch.ones_like( + inputs, dtype=torch.bool, device=inputs.device + ) if id_mask is None: id_mask = torch.zeros_like(inputs, dtype=torch.int, device=inputs.device) if future_exogenous_variables is not None: - self.assert_ev_compatibility(inputs, future_exogenous_variables, prediction_length, num_exogenous_variables) + self.assert_ev_compatibility( + inputs, + future_exogenous_variables, + prediction_length, + num_exogenous_variables, + ) patch_size = self.model.patch_embed.stride rounded_steps = int(np.ceil(prediction_length / patch_size) * patch_size) if rounded_steps > prediction_length and future_exogenous_variables is not None: - future_exogenous_variables = self.round_ft_ev(future_exogenous_variables, rounded_steps) + future_exogenous_variables = self.round_ft_ev( + future_exogenous_variables, rounded_steps + ) start_index = inputs.shape[-1] end_index = start_index + prediction_length @@ -235,14 +260,18 @@ def generate_mean( if future_exogenous_variables is not None: start, stop = idx * patch_size, (idx + 1) * patch_size - samples[:, -num_exogenous_variables:] = future_exogenous_variables[:, :, start:stop] + samples[:, -num_exogenous_variables:] = future_exogenous_variables[ + :, :, start:stop + ] inputs = torch.cat([inputs, samples], dim=-1) id_mask = torch.cat([id_mask, dummy_id_mask], dim=-1) input_padding_mask = torch.cat([input_padding_mask, dummy_padding], dim=-1) for _ in range(patch_size): next_timestamp = timestamp_seconds[:, :, -1] + time_interval_seconds - timestamp_seconds = torch.cat([timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1) + timestamp_seconds = torch.cat( + [timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1 + ) return inputs.detach()[:, :, start_index:end_index] @@ -254,7 +283,9 @@ def generate_samples( num_samples: int, timestamp_seconds: Int[torch.Tensor, "batch variate time_steps"], time_interval_seconds: Int[torch.Tensor, "batch variate"], - input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"] | None = None, + input_padding_mask: ( + Bool[torch.Tensor, "batch variate time_steps"] | None + ) = None, id_mask: Float[torch.Tensor, "batch #variate time_steps"] | None = None, sampling_batch_size: int = 10, use_kv_cache: bool = False, @@ -262,25 +293,40 @@ def generate_samples( num_exogenous_variables: int = 0, ) -> Float[torch.Tensor, "batch variate time_steps samples"]: if input_padding_mask is None: - input_padding_mask = torch.ones_like(inputs, dtype=torch.bool, device=inputs.device) + input_padding_mask = torch.ones_like( + inputs, dtype=torch.bool, device=inputs.device + ) if id_mask is None: id_mask = torch.zeros_like(inputs, dtype=torch.int, device=inputs.device) if future_exogenous_variables is not None: - self.assert_ev_compatibility(inputs, future_exogenous_variables, prediction_length, num_exogenous_variables) + self.assert_ev_compatibility( + inputs, + future_exogenous_variables, + prediction_length, + num_exogenous_variables, + ) - assert num_samples % sampling_batch_size == 0, "num_samples must be divisible by sampling_batch_size" + assert ( + num_samples % sampling_batch_size == 0 + ), "num_samples must be divisible by sampling_batch_size" num_batches = num_samples // sampling_batch_size patch_size = self.model.patch_embed.patch_size rounded_steps = int(np.ceil(prediction_length / patch_size) * patch_size) if rounded_steps > prediction_length and future_exogenous_variables is not None: - future_exogenous_variables = self.round_ft_ev(future_exogenous_variables, rounded_steps) + future_exogenous_variables = self.round_ft_ev( + future_exogenous_variables, rounded_steps + ) start_index = inputs.shape[-1] end_index = start_index + prediction_length dummy_padding = torch.ones( - (input_padding_mask.shape[0] * sampling_batch_size, input_padding_mask.shape[1], patch_size), + ( + input_padding_mask.shape[0] * sampling_batch_size, + input_padding_mask.shape[1], + patch_size, + ), dtype=torch.bool, device=inputs.device, ) @@ -290,17 +336,37 @@ def generate_samples( sampling_batch_size=sampling_batch_size, patch_size=patch_size, ) - inputs = repeat(inputs, "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", sampling_batch_size=sampling_batch_size) + inputs = repeat( + inputs, + "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", + sampling_batch_size=sampling_batch_size, + ) if future_exogenous_variables is not None: future_exogenous_variables = repeat( future_exogenous_variables, "batch exogenous_variables future_time_steps -> (sampling_batch_size batch) exogenous_variables future_time_steps", sampling_batch_size=sampling_batch_size, ) - input_padding_mask = repeat(input_padding_mask, "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", sampling_batch_size=sampling_batch_size) - id_mask = repeat(id_mask, "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", sampling_batch_size=sampling_batch_size) - timestamp_seconds = repeat(timestamp_seconds, "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", sampling_batch_size=sampling_batch_size) - time_interval_seconds = repeat(time_interval_seconds, "batch variates -> (sampling_batch_size batch) variates", sampling_batch_size=sampling_batch_size) + input_padding_mask = repeat( + input_padding_mask, + "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", + sampling_batch_size=sampling_batch_size, + ) + id_mask = repeat( + id_mask, + "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", + sampling_batch_size=sampling_batch_size, + ) + timestamp_seconds = repeat( + timestamp_seconds, + "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", + sampling_batch_size=sampling_batch_size, + ) + time_interval_seconds = repeat( + time_interval_seconds, + "batch variates -> (sampling_batch_size batch) variates", + sampling_batch_size=sampling_batch_size, + ) all_samples = [] if use_kv_cache: @@ -340,13 +406,21 @@ def generate_samples( if future_exogenous_variables is not None: start, stop = idx * patch_size, (idx + 1) * patch_size - samples[:, -num_exogenous_variables:] = future_exogenous_variables[:, :, start:stop] + samples[:, -num_exogenous_variables:] = future_exogenous_variables[ + :, :, start:stop + ] batch_inputs = torch.cat([batch_inputs, samples], dim=-1) batch_id_mask = torch.cat([batch_id_mask, dummy_id_mask], dim=-1) - batch_input_padding_mask = torch.cat([batch_input_padding_mask, dummy_padding], dim=-1) + batch_input_padding_mask = torch.cat( + [batch_input_padding_mask, dummy_padding], dim=-1 + ) for _ in range(patch_size): - next_timestamp = batch_timestamp_seconds[:, :, -1] + time_interval_seconds - batch_timestamp_seconds = torch.cat([batch_timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1) + next_timestamp = ( + batch_timestamp_seconds[:, :, -1] + time_interval_seconds + ) + batch_timestamp_seconds = torch.cat( + [batch_timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1 + ) all_samples.append(batch_inputs) if kv_cache is not None: kv_cache.reset() @@ -361,7 +435,9 @@ def generate_samples( return unfolded_outputs[:, :, start_index:end_index, :] @staticmethod - def create_affine_transformed(base_distr: Distribution, loc: torch.Tensor, scale: torch.Tensor) -> Distribution: + def create_affine_transformed( + base_distr: Distribution, loc: torch.Tensor, scale: torch.Tensor + ) -> Distribution: base_shape = base_distr.mean.shape base_time_dim = base_shape[-1] loc_time_dim = loc.shape[-1] @@ -369,4 +445,8 @@ def create_affine_transformed(base_distr: Distribution, loc: torch.Tensor, scale if loc_time_dim == 1: return AffineTransformed(base_distr, loc=loc, scale=scale) - return AffineTransformed(base_distr, loc=loc[:, :, -base_time_dim:], scale=scale[:, :, -base_time_dim:]) + return AffineTransformed( + base_distr, + loc=loc[:, :, -base_time_dim:], + scale=scale[:, :, -base_time_dim:], + ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py index 387ea30204c11..80f6d381ff20b 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py @@ -69,7 +69,9 @@ def __init__( super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads - assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads." + assert ( + embed_dim % num_heads == 0 + ), "Embedding dimension must be divisible by number of heads." self.head_dim = embed_dim // num_heads self.rotary_emb = rotary_emb @@ -82,8 +84,13 @@ def __init__( not XFORMERS_AVAILABLE and self.use_memory_efficient_attention ), "XFORMERS_AVAILABLE is False, so use_memory_efficient_attention must be False" - if not hasattr(self, "attention_axis") or self.attention_axis not in (AttentionAxis.TIME, AttentionAxis.SPACE): - raise ValueError("Child class must define attention_axis as AttentionAxis.TIME or AttentionAxis.SPACE.") + if not hasattr(self, "attention_axis") or self.attention_axis not in ( + AttentionAxis.TIME, + AttentionAxis.SPACE, + ): + raise ValueError( + "Child class must define attention_axis as AttentionAxis.TIME or AttentionAxis.SPACE." + ) def rearrange_inputs( self, inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"] @@ -96,24 +103,40 @@ def rearrange_inputs( return rearrange(inputs, pattern) def get_qkv(self, inputs: torch.Tensor) -> tuple[torch.Tensor, ...]: - if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention: + if ( + self.attention_axis == AttentionAxis.TIME + and self.use_memory_efficient_attention + ): pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate seq_len n_heads head_dim" - elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention: + elif ( + self.attention_axis == AttentionAxis.TIME + and not self.use_memory_efficient_attention + ): pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate n_heads seq_len head_dim" - elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention: + elif ( + self.attention_axis == AttentionAxis.SPACE + and self.use_memory_efficient_attention + ): pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len variate n_heads head_dim" - elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention: + elif ( + self.attention_axis == AttentionAxis.SPACE + and not self.use_memory_efficient_attention + ): pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len n_heads variate head_dim" qkv = self.wQKV(inputs.contiguous()) - return rearrange(qkv, pattern, qkv=3, head_dim=self.head_dim, n_heads=self.num_heads).unbind(dim=0) + return rearrange( + qkv, pattern, qkv=3, head_dim=self.head_dim, n_heads=self.num_heads + ).unbind(dim=0) def positional_embedding(self, q, k, v, kv_cache, layer_idx): seq_pos_offset = 0 if self.rotary_emb is not None and self.attention_axis == AttentionAxis.TIME: if kv_cache is not None: seq_pos_offset = kv_cache.seq_len(layer_idx) - q, k = self.rotary_emb.rotate_queries_and_keys(q, k, seq_pos_offset=seq_pos_offset) + q, k = self.rotary_emb.rotate_queries_and_keys( + q, k, seq_pos_offset=seq_pos_offset + ) if kv_cache is not None and self.attention_axis == AttentionAxis.TIME: kv_cache.append(layer_idx, (k, v)) @@ -128,53 +151,89 @@ def positional_embedding(self, q, k, v, kv_cache, layer_idx): def rearrange_output( self, output: torch.Tensor, batch: int, variate: int, seq_len: int ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]: - if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention: + if ( + self.attention_axis == AttentionAxis.TIME + and self.use_memory_efficient_attention + ): pattern = "(batch variate) seq_len n_heads head_dim -> batch variate seq_len (n_heads head_dim)" - elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention: + elif ( + self.attention_axis == AttentionAxis.TIME + and not self.use_memory_efficient_attention + ): pattern = "(batch variate) n_heads seq_len head_dim -> batch variate seq_len (n_heads head_dim)" - elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention: + elif ( + self.attention_axis == AttentionAxis.SPACE + and self.use_memory_efficient_attention + ): pattern = "(batch seq_len) variate n_heads head_dim -> batch variate seq_len (n_heads head_dim)" - elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention: + elif ( + self.attention_axis == AttentionAxis.SPACE + and not self.use_memory_efficient_attention + ): pattern = "(batch seq_len) n_heads variate head_dim -> batch variate seq_len (n_heads head_dim)" return rearrange(output, pattern, batch=batch, variate=variate, seq_len=seq_len) - def run_attention(self, attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate): + def run_attention( + self, attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate + ): q_dim_start, q_dim_end = seq_pos_offset, seq_pos_offset + seq_len - kv_dim_start, kv_dim_end = 0, v.shape[1] if self.use_memory_efficient_attention else v.shape[2] - if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention: + kv_dim_start, kv_dim_end = 0, ( + v.shape[1] if self.use_memory_efficient_attention else v.shape[2] + ) + if ( + self.attention_axis == AttentionAxis.TIME + and self.use_memory_efficient_attention + ): attention_mask = ( attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end] if torch.is_tensor(attention_mask) else LowerTriangularMask() if seq_pos_offset == 0 else None ) - return memory_efficient_attention(q, k, v, attn_bias=attention_mask, p=dropout) - elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention: + return memory_efficient_attention( + q, k, v, attn_bias=attention_mask, p=dropout + ) + elif ( + self.attention_axis == AttentionAxis.TIME + and not self.use_memory_efficient_attention + ): attention_mask = ( attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end] if torch.is_tensor(attention_mask) else None ) return scaled_dot_product_attention( - q, k, v, + q, + k, + v, attn_mask=attention_mask, dropout_p=dropout, is_causal=(attention_mask is None and seq_pos_offset == 0), ) - elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention: + elif ( + self.attention_axis == AttentionAxis.SPACE + and self.use_memory_efficient_attention + ): attention_mask = ( attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end] if torch.is_tensor(attention_mask) else None ) - return memory_efficient_attention(q, k, v, attn_bias=attention_mask, p=dropout) - elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention: + return memory_efficient_attention( + q, k, v, attn_bias=attention_mask, p=dropout + ) + elif ( + self.attention_axis == AttentionAxis.SPACE + and not self.use_memory_efficient_attention + ): attention_mask = ( attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end] if torch.is_tensor(attention_mask) else None ) - return scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False) + return scaled_dot_product_attention( + q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False + ) def forward( self, @@ -194,9 +253,13 @@ def forward( rearranged_inputs = self.rearrange_inputs(inputs) q, k, v = self.get_qkv(rearranged_inputs) - q, k, v, seq_pos_offset = self.positional_embedding(q, k, v, kv_cache, layer_idx) + q, k, v, seq_pos_offset = self.positional_embedding( + q, k, v, kv_cache, layer_idx + ) - output = self.run_attention(attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate) + output = self.run_attention( + attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate + ) output = self.rearrange_output(output, batch_size, variate, seq_len) return self.wO(output) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py index b11c07c89080a..84fa537e3fc2b 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py @@ -105,8 +105,12 @@ def __init__( ) self.unembed = torch.nn.Linear(embed_dim, embed_dim * patch_size) - output_distribution_classes_ = [DISTRIBUTION_CLASSES_LOOKUP[c] for c in output_distribution_classes] - self.output_distribution = output_distribution_classes_[0](embed_dim, **(output_distribution_kwargs or {})) + output_distribution_classes_ = [ + DISTRIBUTION_CLASSES_LOOKUP[c] for c in output_distribution_classes + ] + self.output_distribution = output_distribution_classes_[0]( + embed_dim, **(output_distribution_kwargs or {}) + ) def allocate_kv_cache( self, @@ -152,9 +156,13 @@ def backbone( if kv_cache is not None: kv_cache_len_tensor = kv_cache.current_len(0) kv_cache_len = ( - int(kv_cache_len_tensor) if isinstance(kv_cache_len_tensor, torch.Tensor) else kv_cache_len_tensor + int(kv_cache_len_tensor) + if isinstance(kv_cache_len_tensor, torch.Tensor) + else kv_cache_len_tensor + ) + prefix_len = max( + 0, self.patch_embed.stride * (kv_cache_len - self.num_prepended_tokens) ) - prefix_len = max(0, self.patch_embed.stride * (kv_cache_len - self.num_prepended_tokens)) scaled_inputs = scaled_inputs[:, :, prefix_len:] @@ -167,18 +175,27 @@ def backbone( embeddings, reduced_id_mask = self.patch_embed(scaled_inputs, id_mask) - variate_label_embeds = self.build_variate_label_embeds(num_exogenous_variables, embeddings) + variate_label_embeds = self.build_variate_label_embeds( + num_exogenous_variables, embeddings + ) original_seq_len = embeddings.shape[2] - transformed = self.transformer(embeddings, reduced_id_mask, kv_cache, variate_label_embeds=variate_label_embeds) + transformed = self.transformer( + embeddings, + reduced_id_mask, + kv_cache, + variate_label_embeds=variate_label_embeds, + ) added_tokens = transformed.shape[2] - original_seq_len if added_tokens > 0: transformed = transformed[:, :, added_tokens:] - flattened: Float[torch.Tensor, "batch variates new_seq_len embed_dim"] = rearrange( - self.unembed(transformed), - "batch variates seq_len (patch_size embed_dim) -> batch variates (seq_len patch_size) embed_dim", - embed_dim=self.embed_dim, + flattened: Float[torch.Tensor, "batch variates new_seq_len embed_dim"] = ( + rearrange( + self.unembed(transformed), + "batch variates seq_len (patch_size embed_dim) -> batch variates (seq_len patch_size) embed_dim", + embed_dim=self.embed_dim, + ) ) return flattened, loc, scale @@ -227,13 +244,15 @@ def build_variate_label_embeds( batch_size, num_variates, _, _ = embeddings.shape - target_variate_label = repeat(self.target_variate_label, "d -> b v 1 d", b=batch_size, v=num_variates).to( - device=embeddings.device, dtype=embeddings.dtype - ) - exogenous_variate_label = repeat(self.exogenous_variate_label, "d -> b v 1 d", b=batch_size, v=num_variates).to( - device=embeddings.device, dtype=embeddings.dtype + target_variate_label = repeat( + self.target_variate_label, "d -> b v 1 d", b=batch_size, v=num_variates + ).to(device=embeddings.device, dtype=embeddings.dtype) + exogenous_variate_label = repeat( + self.exogenous_variate_label, "d -> b v 1 d", b=batch_size, v=num_variates + ).to(device=embeddings.device, dtype=embeddings.dtype) + exog_mask = torch.zeros( + 1, num_variates, 1, 1, dtype=torch.bool, device=embeddings.device ) - exog_mask = torch.zeros(1, num_variates, 1, 1, dtype=torch.bool, device=embeddings.device) if num_exogenous_variables > 0: exog_mask[:, -num_exogenous_variables:] = True return torch.where(exog_mask, exogenous_variate_label, target_variate_label) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py index ac7023321a31c..f34bd4afdf0aa 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py @@ -67,7 +67,9 @@ def forward(self, inputs, loc=None, scale=None): base_loc = self.loc_proj(inputs).squeeze(-1) base_scale = F.softplus(self.scale_proj(inputs)).clamp_min(eps).squeeze(-1) - base_dist = torch.distributions.StudentT(df, base_loc, base_scale, validate_args=False) + base_dist = torch.distributions.StudentT( + df, base_loc, base_scale, validate_args=False + ) if loc is not None and scale is not None: return AffineTransformed(base_dist, loc=loc, scale=scale) @@ -88,10 +90,14 @@ def __init__(self, embed_dim, k_components): def forward(self, inputs, loc=None, scale=None): df = 2.0 + F.softplus(self.df(inputs)).clamp_min(torch.finfo(inputs.dtype).eps) component_loc = self.loc_proj(inputs) - component_scale = F.softplus(self.scale_proj(inputs)).clamp_min(torch.finfo(inputs.dtype).eps) + component_scale = F.softplus(self.scale_proj(inputs)).clamp_min( + torch.finfo(inputs.dtype).eps + ) logits = self.mixture_weights(inputs) probs = F.softmax(logits, dim=-1) - components = torch.distributions.StudentT(df, component_loc, component_scale, validate_args=False) + components = torch.distributions.StudentT( + df, component_loc, component_scale, validate_args=False + ) mixture_distribution = torch.distributions.Categorical(probs=probs) return torch.distributions.MixtureSameFamily(mixture_distribution, components) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py index 248fa1326bedb..fc7eadac9af94 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py @@ -31,7 +31,9 @@ def patchify_id_mask( patched_id_mask = id_mask.unfold(dimension=-1, size=patch_size, step=patch_size) patched_id_mask_min = patched_id_mask.min(-1).values patched_id_mask_max = patched_id_mask.max(-1).values - assert torch.eq(patched_id_mask_min, patched_id_mask_max).all(), "Patches cannot span multiple datasets" + assert torch.eq( + patched_id_mask_min, patched_id_mask_max + ).all(), "Patches cannot span multiple datasets" return patched_id_mask_min @@ -64,8 +66,12 @@ def forward( assert ( x.shape[-1] % self.patch_size == 0 ), f"Series length ({x.shape=}) must be divisible by ({self.patch_size=})" - x_patched: Float[torch.Tensor, "batch variate seq_len patch_size"] = self._patchify(x) - id_mask_patched: Int[torch.Tensor, "batch variate seq_len patch_size"] = self._patchify(id_mask) + x_patched: Float[torch.Tensor, "batch variate seq_len patch_size"] = ( + self._patchify(x) + ) + id_mask_patched: Int[torch.Tensor, "batch variate seq_len patch_size"] = ( + self._patchify(id_mask) + ) assert torch.eq( id_mask_patched.min(-1).values, id_mask_patched.max(-1).values diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py index 4d6723b251cfa..cfe364ac91eb9 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py @@ -37,7 +37,9 @@ def __init__(self) -> None: def forward( self, embeddings: Float[torch.Tensor, "batch variate seq_len embed_dim"], - variate_label_embeds: Optional[Float[torch.Tensor, "batch variate 1 embed_dim"]] = None, + variate_label_embeds: Optional[ + Float[torch.Tensor, "batch variate 1 embed_dim"] + ] = None, ) -> Float[torch.Tensor, "batch variate new_seq_len embed_dim"]: if variate_label_embeds is None: @@ -46,6 +48,11 @@ def forward( processed_embeddings = F.normalize(variate_label_embeds, p=2, dim=-1) return torch.cat( - [processed_embeddings.to(dtype=embeddings.dtype, device=embeddings.device, non_blocking=True), embeddings], + [ + processed_embeddings.to( + dtype=embeddings.dtype, device=embeddings.device, non_blocking=True + ), + embeddings, + ], dim=2, ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py index c787b64fea849..e640e3ef3a213 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py @@ -78,8 +78,14 @@ def __call__( ) high_precision_data = data.to(torch.float32) - denominator = weights.sum(self.dim, keepdim=self.keepdim).clamp_min(1.0).to(high_precision_data.dtype) - means = (high_precision_data * weights).sum(self.dim, keepdim=self.keepdim) / denominator + denominator = ( + weights.sum(self.dim, keepdim=self.keepdim) + .clamp_min(1.0) + .to(high_precision_data.dtype) + ) + means = (high_precision_data * weights).sum( + self.dim, keepdim=self.keepdim + ) / denominator means = torch.nan_to_num(means) variance = (((high_precision_data - means) * weights) ** 2).sum( @@ -136,7 +142,9 @@ def compute_causal_statistics( shifted_means[..., 1:] = causal_means[..., :-1] delta = high_precision_data - shifted_means - increment = delta * (high_precision_data - causal_means) * high_precision_weights + increment = ( + delta * (high_precision_data - causal_means) * high_precision_weights + ) m_2 = torch.cumsum(increment, dim=dim) if use_bessel_correction: @@ -157,13 +165,17 @@ def compute_causal_statistics( scale_factor_min = 10.0 ** (-scale_factor_exponent) scale_factor_max = 10.0**scale_factor_exponent - global_denominator = (weights * padding_mask).sum(dim, keepdim=True).clamp_min(1.0) - global_means = (weighted_data).sum(dim, keepdim=True) / global_denominator - global_means = torch.nan_to_num(global_means) - - global_variance = (((high_precision_data - global_means) * weights * padding_mask) ** 2).sum( + global_denominator = ( + (weights * padding_mask).sum(dim, keepdim=True).clamp_min(1.0) + ) + global_means = (weighted_data).sum( dim, keepdim=True ) / global_denominator + global_means = torch.nan_to_num(global_means) + + global_variance = ( + ((high_precision_data - global_means) * weights * padding_mask) ** 2 + ).sum(dim, keepdim=True) / global_denominator global_scale = torch.sqrt(global_variance + minimum_scale) expanded_global_scale = global_scale.expand_as(causal_scale) @@ -172,7 +184,10 @@ def compute_causal_statistics( causal_scale = torch.clamp( causal_scale, - min=torch.max(torch.tensor(minimum_scale, device=causal_scale.device), min_allowed_scale), + min=torch.max( + torch.tensor(minimum_scale, device=causal_scale.device), + min_allowed_scale, + ), max=max_allowed_scale, ) @@ -211,7 +226,9 @@ def __call__( prefix_length: int | None = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert data.shape == weights.shape, "data and weights must have same shape" - assert len(data.shape) == 3, "Input data must have shape [batch, variates, time_steps]" + assert ( + len(data.shape) == 3 + ), "Input data must have shape [batch, variates, time_steps]" causal_means, causal_scale = compute_causal_statistics( data, @@ -241,7 +258,9 @@ def __init__( scale_factor_exponent: float = 10.0, ) -> None: super().__init__() - assert dim == -1, "CausalPatchStdMeanScaler only supports dim=-1 (last dimension)" + assert ( + dim == -1 + ), "CausalPatchStdMeanScaler only supports dim=-1 (last dimension)" self.dim = dim self.patch_size = patch_size self.minimum_scale = minimum_scale @@ -257,7 +276,9 @@ def __call__( prefix_length: int | None = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert data.shape == weights.shape, "data and weights must have same shape" - assert len(data.shape) == 3, "Input data must have shape [batch, variates, time_steps]" + assert ( + len(data.shape) == 3 + ), "Input data must have shape [batch, variates, time_steps]" with torch.no_grad(): time_steps = data.shape[-1] @@ -283,8 +304,12 @@ def __call__( patch_stats_means = means_unfolded[..., -1] patch_stats_scales = scales_unfolded[..., -1] - patch_means = repeat(patch_stats_means, "b v p -> b v (p s)", s=self.patch_size) - patch_scales = repeat(patch_stats_scales, "b v p -> b v (p s)", s=self.patch_size) + patch_means = repeat( + patch_stats_means, "b v p -> b v (p s)", s=self.patch_size + ) + patch_scales = repeat( + patch_stats_scales, "b v p -> b v (p s)", s=self.patch_size + ) scaled_data = (data - patch_means) / patch_scales diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py index 1710c9eeadf05..61595334171f5 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py @@ -93,7 +93,9 @@ def load_from_checkpoint( if os.path.exists(safetensors_file): model_state = safetorch.load_file(safetensors_file, device=map_location) else: - raise FileNotFoundError(f"Model checkpoint not found at: {safetensors_file}") + raise FileNotFoundError( + f"Model checkpoint not found at: {safetensors_file}" + ) config_file = os.path.join(checkpoint_path, "config.json") config = {} @@ -104,10 +106,14 @@ def load_from_checkpoint( config.update(model_kwargs) remapped_state_dict = cls._map_state_dict_keys( - model_state, XFORMERS_SWIGLU_AVAILABLE and not config.get("pre_xformers_checkpoint", False) + model_state, + XFORMERS_SWIGLU_AVAILABLE + and not config.get("pre_xformers_checkpoint", False), ) - if not XFORMERS_AVAILABLE and config.get("use_memory_efficient_attention", True): + if not XFORMERS_AVAILABLE and config.get( + "use_memory_efficient_attention", True + ): config["use_memory_efficient_attention"] = False instance = cls(**config) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py index 9979bac960269..58220c30e6259 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py @@ -131,7 +131,12 @@ def forward( kv_cache: Optional[KVCache] = None, ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]: pre_norm_1 = self.norm1(inputs) - hidden_state = inputs + self.attention(layer_idx, pre_norm_1, attention_mask, kv_cache).contiguous() + hidden_state = ( + inputs + + self.attention( + layer_idx, pre_norm_1, attention_mask, kv_cache + ).contiguous() + ) pre_norm_2 = self.norm2(hidden_state) return hidden_state + self.mlp(pre_norm_2) @@ -153,7 +158,9 @@ def __init__( ): super().__init__() - assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads." + assert ( + embed_dim % num_heads == 0 + ), "Embedding dimension must be divisible by number of heads." self.rotary_emb = TimeAwareRotaryEmbedding( embed_dim // num_heads, @@ -161,7 +168,9 @@ def __init__( cache_if_possible=True, seq_before_head_dim=use_memory_efficient_attention, ) - attention_axes = self._get_layer_types(num_layers, spacewise_every_n_layers, spacewise_first) + attention_axes = self._get_layer_types( + num_layers, spacewise_every_n_layers, spacewise_first + ) self.use_memory_efficient_attention = use_memory_efficient_attention self.fusion = fusion @@ -199,9 +208,17 @@ def _get_mask( if self.use_memory_efficient_attention: mask = self._pad_to_multiple(mask) - mask = mask.float().masked_fill(~mask, float("-inf")).masked_fill(mask, 0.0).to(dtype) + mask = ( + mask.float() + .masked_fill(~mask, float("-inf")) + .masked_fill(mask, 0.0) + .to(dtype) + ) - mask = rearrange(mask, "batch seq_len variate1 variate2 -> (batch seq_len) 1 variate1 variate2") + mask = rearrange( + mask, + "batch seq_len variate1 variate2 -> (batch seq_len) 1 variate1 variate2", + ) return mask.expand(-1, num_heads, -1, -1).contiguous() def _pad_to_multiple( @@ -214,7 +231,11 @@ def _pad_to_multiple( if pad_amount > 0: new_size = tensor.shape[-1] + pad_amount if causal: - full_mask = torch.tril(torch.ones((new_size, new_size), dtype=tensor.dtype, device=tensor.device)) + full_mask = torch.tril( + torch.ones( + (new_size, new_size), dtype=tensor.dtype, device=tensor.device + ) + ) full_mask[: tensor.shape[-1], : tensor.shape[-1]] = tensor tensor = full_mask else: @@ -245,14 +266,20 @@ def forward( inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"], id_mask: Float[torch.Tensor, "batch #variate seq_len"], kv_cache: Optional[KVCache] = None, - variate_label_embeds: Optional[Float[torch.Tensor, "batch variate 1 embed_dim"]] = None, + variate_label_embeds: Optional[ + Float[torch.Tensor, "batch variate 1 embed_dim"] + ] = None, ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]: if self.fusion is not None and variate_label_embeds is not None: should_apply_fusion = True if kv_cache is not None: kv_len_tensor = kv_cache.current_len(0) - kv_len = int(kv_len_tensor) if isinstance(kv_len_tensor, torch.Tensor) else kv_len_tensor + kv_len = ( + int(kv_len_tensor) + if isinstance(kv_len_tensor, torch.Tensor) + else kv_len_tensor + ) should_apply_fusion = kv_len == 0 if should_apply_fusion: inputs = self.fusion(inputs, variate_label_embeds=variate_label_embeds) @@ -281,7 +308,11 @@ def forward( inputs = layer( layer_idx, inputs, - (timewise_attention_mask if layer.attention_axis == AttentionAxis.TIME else spacewise_attention_mask), + ( + timewise_attention_mask + if layer.attention_axis == AttentionAxis.TIME + else spacewise_attention_mask + ), kv_cache, ) return inputs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py index 182c38d195653..d913329e7e89d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py @@ -59,13 +59,18 @@ def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-8): super(RMSNorm, self).__init__() self.eps = eps if include_weight: - self.scale: Optional[torch.nn.Parameter] = torch.nn.Parameter(torch.ones(dim)) + self.scale: Optional[torch.nn.Parameter] = torch.nn.Parameter( + torch.ones(dim) + ) else: self.scale = None def forward(self, x: torch.Tensor): if ( - ((not self.training) or (self.scale is not None and not self.scale.requires_grad)) + ( + (not self.training) + or (self.scale is not None and not self.scale.requires_grad) + ) and XFORMERS_RMSNORM_AVAILABLE and _is_triton_available() ): @@ -75,7 +80,9 @@ def forward(self, x: torch.Tensor): return x_normed if self.scale is None else x_normed * self.scale def increment_and_forward_(self, x: torch.Tensor, y: torch.Tensor): - if (not self.training) or (self.scale is not None and not self.scale.requires_grad): + if (not self.training) or ( + self.scale is not None and not self.scale.requires_grad + ): return rms_norm_add(x, y, self.scale, self.eps) return self.forward(x + y) @@ -85,8 +92,12 @@ def make_batched_block_mask(t: torch.Tensor) -> torch.Tensor: return unsqueezed == unsqueezed.transpose(-1, -2) -K: TypeAlias = Float[torch.Tensor, "batch_size_X_num_variates num_heads seq_len head_dim"] -V: TypeAlias = Float[torch.Tensor, "batch_size_X_num_variates num_heads seq_len head_dim"] +K: TypeAlias = Float[ + torch.Tensor, "batch_size_X_num_variates num_heads seq_len head_dim" +] +V: TypeAlias = Float[ + torch.Tensor, "batch_size_X_num_variates num_heads seq_len head_dim" +] KV: TypeAlias = tuple[K, V] @@ -109,26 +120,42 @@ class KVCache: use_memory_efficient_attention: bool = True _keys: Union[ - Float[torch.Tensor, "time_layer_count batch_size_X_num_variates max_seq_len num_heads head_dim"], - Float[torch.Tensor, "time_layer_count batch_size_X_num_variates num_heads max_seq_len head_dim"], + Float[ + torch.Tensor, + "time_layer_count batch_size_X_num_variates max_seq_len num_heads head_dim", + ], + Float[ + torch.Tensor, + "time_layer_count batch_size_X_num_variates num_heads max_seq_len head_dim", + ], ] = field(init=False) _values: Union[ - Float[torch.Tensor, "time_layer_count batch_size_X_num_variates max_seq_len num_heads head_dim"], - Float[torch.Tensor, "time_layer_count batch_size_X_num_variates num_heads max_seq_len head_dim"], + Float[ + torch.Tensor, + "time_layer_count batch_size_X_num_variates max_seq_len num_heads head_dim", + ], + Float[ + torch.Tensor, + "time_layer_count batch_size_X_num_variates num_heads max_seq_len head_dim", + ], ] = field(init=False) _current_idx: Int[torch.Tensor, "time_layer_count"] = field(init=False) _layer_cache_map: Int[torch.Tensor, "num_layers"] = field(init=False) def __post_init__(self): - assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads" + assert ( + self.embed_dim % self.num_heads == 0 + ), "embed_dim must be divisible by num_heads" head_dim = self.embed_dim // self.num_heads time_layer_indices = [ i for i in range(self.num_layers) - if isinstance(self.transformer_layers[i].attention, TimeWiseMultiheadAttention) + if isinstance( + self.transformer_layers[i].attention, TimeWiseMultiheadAttention + ) ] time_layer_count = max(1, len(time_layer_indices)) @@ -150,8 +177,12 @@ def __post_init__(self): ) self._keys = torch.zeros(shape, device=self.device, dtype=self.dtype) self._values = torch.zeros_like(self._keys) - self._current_idx = torch.zeros(time_layer_count, device=self.device, dtype=torch.int) - self._layer_cache_map = torch.zeros((self.num_layers,), dtype=torch.int, device=self.device) + self._current_idx = torch.zeros( + time_layer_count, device=self.device, dtype=torch.int + ) + self._layer_cache_map = torch.zeros( + (self.num_layers,), dtype=torch.int, device=self.device + ) for cache_idx, layer_idx in enumerate(time_layer_indices): self._layer_cache_map[layer_idx] = int(cache_idx) @@ -160,12 +191,22 @@ def __getitem__(self, layer_idx: int) -> KV: end_idx = int(self._current_idx[cache_idx].item()) if self.use_memory_efficient_attention: - return self._keys[cache_idx, :, :end_idx, :, :], self._values[cache_idx, :, :end_idx, :, :] + return ( + self._keys[cache_idx, :, :end_idx, :, :], + self._values[cache_idx, :, :end_idx, :, :], + ) else: - return self._keys[cache_idx, :, :, :end_idx, :], self._values[cache_idx, :, :, :end_idx, :] + return ( + self._keys[cache_idx, :, :, :end_idx, :], + self._values[cache_idx, :, :, :end_idx, :], + ) def current_len(self, cache_idx: int) -> int: - return int(self._current_idx[cache_idx].item()) if self._current_idx.numel() > 0 else 0 + return ( + int(self._current_idx[cache_idx].item()) + if self._current_idx.numel() > 0 + else 0 + ) def seq_len(self, layer_idx: int) -> int: cache_idx = int(self._layer_cache_map[layer_idx].item()) @@ -191,9 +232,9 @@ def append(self, layer_idx: int, kv: KV): end_idx = start_idx + keys.shape[1] else: end_idx = start_idx + keys.shape[2] - assert end_idx <= self.max_seq_len, ( - f"max_seq_len exceeded {end_idx} > {self.max_seq_len}, keys.shape: {keys.shape}" - ) + assert ( + end_idx <= self.max_seq_len + ), f"max_seq_len exceeded {end_idx} > {self.max_seq_len}, keys.shape: {keys.shape}" if self.use_memory_efficient_attention: self._keys[cache_idx, :, start_idx:end_idx, :, :] = keys From 2ed5e91fa024656863f91d2c4b32e0904910c7f2 Mon Sep 17 00:00:00 2001 From: graceli02 Date: Mon, 23 Mar 2026 17:29:05 -0400 Subject: [PATCH 08/11] [AINode] Fix build_binary.py: print poetry lock errors to CI logs --- iotdb-core/ainode/build_binary.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/iotdb-core/ainode/build_binary.py b/iotdb-core/ainode/build_binary.py index bea1970626b56..f3b7fa1cedf3e 100644 --- a/iotdb-core/ainode/build_binary.py +++ b/iotdb-core/ainode/build_binary.py @@ -423,7 +423,7 @@ def verify_poetry_env(): [str(poetry_exe), "lock"], cwd=str(script_dir), env=venv_env, - check=True, + check=False, capture_output=True, text=True, ) @@ -431,6 +431,9 @@ def verify_poetry_env(): print(result.stdout) if result.stderr: print(result.stderr) + if result.returncode != 0: + print(f"ERROR: poetry lock failed with exit code {result.returncode}") + sys.exit(1) verify_poetry_env() # Verify after lock accelerator = detect_accelerator() From b93ecfec3ac294e310ad0a3105de997738b89f24 Mon Sep 17 00:00:00 2001 From: graceli02 Date: Wed, 25 Mar 2026 14:58:12 -0400 Subject: [PATCH 09/11] address review comments --- iotdb-core/ainode/.gitignore | 3 -- .../ainode/core/model/toto/pipeline_toto.py | 46 +++++++++++++------ 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/iotdb-core/ainode/.gitignore b/iotdb-core/ainode/.gitignore index 60ecdcf76b52e..e2e900a491a27 100644 --- a/iotdb-core/ainode/.gitignore +++ b/iotdb-core/ainode/.gitignore @@ -21,6 +21,3 @@ poetry.lock /dist/ /build/ -# Un-ignore toto source data/ package (Python source, not data files) -!iotdb/ainode/core/model/toto/data/ -!iotdb/ainode/core/model/toto/data/** diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py index 311e679d3b703..62c8c98b5f6ca 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py @@ -16,17 +16,14 @@ # under the License. # -import logging -import warnings - import torch from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.toto.data.util.dataset import MaskedTimeseries +from iotdb.ainode.core.model.toto.inference.forecaster import TotoForecaster -from .data.util.dataset import MaskedTimeseries -from .inference.forecaster import TotoForecaster - -logger = logging.getLogger(__name__) +logger = Logger() class TotoPipeline(ForecastPipeline): @@ -51,9 +48,30 @@ def _get_forecaster(self) -> TotoForecaster: return self._forecaster def preprocess(self, inputs, **infer_kwargs): - super().preprocess(inputs, **infer_kwargs) - processed_inputs = [] + """ + Preprocess input data for Toto. + + Delegates to the base class for input validation, then converts each + validated input dict into a ``MaskedTimeseries`` named-tuple that the + ``TotoForecaster`` expects. + Parameters + ---------- + inputs : list of dict + A list of dictionaries containing input data. Each dictionary contains: + - 'targets': A tensor (1D or 2D) of shape (input_length,) or (target_count, input_length). + + infer_kwargs: Additional keyword arguments for inference, such as: + - `output_length`(int): Prediction length. + + Returns + ------- + list of MaskedTimeseries + Processed inputs compatible with Toto's forecaster. + """ + inputs = super().preprocess(inputs, **infer_kwargs) + + processed_inputs = [] for item in inputs: targets = item["targets"] if targets.ndim == 1: @@ -63,10 +81,8 @@ def preprocess(self, inputs, **infer_kwargs): device = targets.device if "past_covariates" in item or "future_covariates" in item: - warnings.warn( - "TotoPipeline does not support covariates; they will be ignored.", - UserWarning, - stacklevel=2, + logger.warning( + "TotoPipeline does not support covariates; they will be ignored." ) padding_mask = ~torch.isnan(targets) @@ -96,7 +112,7 @@ def preprocess(self, inputs, **infer_kwargs): return processed_inputs - def forecast(self, inputs, **infer_kwargs): + def forecast(self, inputs, **infer_kwargs) -> list[torch.Tensor]: output_length = infer_kwargs.get("output_length", 96) num_samples = infer_kwargs.get("num_samples", None) samples_per_batch = infer_kwargs.get("samples_per_batch", 10) @@ -127,5 +143,5 @@ def forecast(self, inputs, **infer_kwargs): outputs.append(mean) return outputs - def postprocess(self, outputs, **infer_kwargs): + def postprocess(self, outputs, **infer_kwargs) -> list[torch.Tensor]: return super().postprocess(outputs, **infer_kwargs) From cb7f6f1f4fc42379480d45ce211227e5ecc1ece0 Mon Sep 17 00:00:00 2001 From: graceli02 Date: Wed, 25 Mar 2026 21:59:32 -0400 Subject: [PATCH 10/11] [AINode] Adapt TotoPipeline to new inherit policy Rename preprocess/postprocess to _preprocess/_postprocess and remove super() calls to match the refactored ForecastPipeline base class. Co-Authored-By: Claude Opus 4.6 --- .../iotdb/ainode/core/model/toto/pipeline_toto.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py index 62c8c98b5f6ca..c6778a5e90bb6 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py @@ -47,13 +47,12 @@ def _get_forecaster(self) -> TotoForecaster: self._forecaster = TotoForecaster(self.model.backbone) return self._forecaster - def preprocess(self, inputs, **infer_kwargs): + def _preprocess(self, inputs, **infer_kwargs): """ Preprocess input data for Toto. - Delegates to the base class for input validation, then converts each - validated input dict into a ``MaskedTimeseries`` named-tuple that the - ``TotoForecaster`` expects. + Converts each input dict into a ``MaskedTimeseries`` named-tuple that + the ``TotoForecaster`` expects. Parameters ---------- @@ -69,8 +68,6 @@ def preprocess(self, inputs, **infer_kwargs): list of MaskedTimeseries Processed inputs compatible with Toto's forecaster. """ - inputs = super().preprocess(inputs, **infer_kwargs) - processed_inputs = [] for item in inputs: targets = item["targets"] @@ -143,5 +140,5 @@ def forecast(self, inputs, **infer_kwargs) -> list[torch.Tensor]: outputs.append(mean) return outputs - def postprocess(self, outputs, **infer_kwargs) -> list[torch.Tensor]: - return super().postprocess(outputs, **infer_kwargs) + def _postprocess(self, outputs, **infer_kwargs) -> list[torch.Tensor]: + return outputs From 7ceda398c8eb1bc338ed632e94c6dad872cb4131 Mon Sep 17 00:00:00 2001 From: graceli02 Date: Wed, 25 Mar 2026 22:01:13 -0400 Subject: [PATCH 11/11] [AINode] Revert .gitignore trailing newline diff Co-Authored-By: Claude Opus 4.6 --- iotdb-core/ainode/.gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/iotdb-core/ainode/.gitignore b/iotdb-core/ainode/.gitignore index e2e900a491a27..e4947e516d37a 100644 --- a/iotdb-core/ainode/.gitignore +++ b/iotdb-core/ainode/.gitignore @@ -20,4 +20,3 @@ poetry.lock # generated by pyinstaller /dist/ /build/ -