diff --git a/NOTICE b/NOTICE index fa52a36987f4..429495c377b0 100644 --- a/NOTICE +++ b/NOTICE @@ -17,6 +17,15 @@ grant the users the right to the use of patent under the requirement of Apache 2 ============================================================================ +This product includes source code derived from the DataDog/toto project: + + Toto – Timeseries-Optimized Transformer for Observability + Copyright 2025 Datadog, Inc. + Licensed under the Apache License, Version 2.0 + https://github.com/DataDog/toto + +============================================================================ + Apache Commons Collections Copyright 2001-2019 The Apache Software Foundation diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index e41d3d4e0f97..bf758a083d46 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -58,7 +58,9 @@ public class AINodeTestUtils { new AbstractMap.SimpleEntry<>( "chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")), new AbstractMap.SimpleEntry<>( - "moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active"))) + "moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "toto", new FakeModelInfo("toto", "toto", "builtin", "active"))) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); public static final Map BUILTIN_MODEL_MAP; diff --git a/iotdb-core/ainode/build_binary.py b/iotdb-core/ainode/build_binary.py index c943de415817..f3b7fa1cedf3 100644 --- a/iotdb-core/ainode/build_binary.py +++ b/iotdb-core/ainode/build_binary.py @@ -423,7 +423,7 @@ def verify_poetry_env(): [str(poetry_exe), "lock"], cwd=str(script_dir), env=venv_env, - check=True, + check=False, capture_output=True, text=True, ) @@ 
-431,6 +431,9 @@ def verify_poetry_env(): print(result.stdout) if result.stderr: print(result.stderr) + if result.returncode != 0: + print(f"ERROR: poetry lock failed with exit code {result.returncode}") + sys.exit(1) verify_poetry_env() # Verify after lock accelerator = detect_accelerator() @@ -438,11 +441,10 @@ def verify_poetry_env(): print("Running poetry install...") subprocess.run( - [str(poetry_exe), "lock"], + [str(poetry_exe), "install", "--no-root"], cwd=str(script_dir), env=venv_env, check=True, - capture_output=True, text=True, ) verify_poetry_env() # Verify before install diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index da752cbd7843..642986c42d21 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -160,4 +160,17 @@ def __repr__(self): }, transformers_registered=True, ), + "toto": ModelInfo( + model_id="toto", + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + model_type="toto", + pipeline_cls="pipeline_toto.TotoPipeline", + repo_id="Datadog/Toto-Open-Base-1.0", + auto_map={ + "AutoConfig": "configuration_toto.TotoConfig", + "AutoModelForCausalLM": "modeling_toto.TotoForPrediction", + }, + transformers_registered=True, + ), } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py new file mode 100644 index 000000000000..2a1e720805f2 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py new file mode 100644 index 000000000000..2a00fcc3be45 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from typing import List, Optional + +from transformers import PretrainedConfig + + +class TotoConfig(PretrainedConfig): + """ + Configuration class for the Toto time series forecasting model. + + Toto (Time Series Optimized Transformer for Observability) is a foundation model + for multivariate time series forecasting developed by Datadog. 
It uses a decoder-only + architecture with per-variate patch-based causal scaling, proportional time-variate + factorized attention, and a Student-T mixture prediction head. + + Reference: https://github.com/DataDog/toto + """ + + model_type = "toto" + + def __init__( + self, + patch_size: int = 32, + stride: int = 32, + embed_dim: int = 1024, + num_layers: int = 18, + num_heads: int = 16, + mlp_hidden_dim: int = 2816, + dropout: float = 0.0, + spacewise_every_n_layers: int = 3, + scaler_cls: str = "per_variate_causal", + output_distribution_classes: Optional[List[str]] = None, + output_distribution_kwargs: Optional[dict] = None, + spacewise_first: bool = True, + use_memory_efficient_attention: bool = True, + stabilize_with_global: bool = True, + scale_factor_exponent: float = 10.0, + **kwargs, + ): + self.patch_size = patch_size + self.stride = stride + self.embed_dim = embed_dim + self.num_layers = num_layers + self.num_heads = num_heads + self.mlp_hidden_dim = mlp_hidden_dim + self.dropout = dropout + self.spacewise_every_n_layers = spacewise_every_n_layers + self.scaler_cls = scaler_cls + self.output_distribution_classes = output_distribution_classes or [ + "student_t_mixture" + ] + # k_components=5 is the default used by Datadog/Toto-Open-Base-1.0 + self.output_distribution_kwargs = output_distribution_kwargs or { + "k_components": 5 + } + self.spacewise_first = spacewise_first + self.use_memory_efficient_attention = use_memory_efficient_attention + self.stabilize_with_global = stabilize_with_global + self.scale_factor_exponent = scale_factor_exponent + + super().__init__(**kwargs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/__init__.py new file mode 100644 index 000000000000..ba26b1edd945 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor 
license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/__init__.py new file mode 100644 index 000000000000..ba26b1edd945 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/dataset.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/dataset.py new file mode 100644 index 000000000000..6bccf35988c2 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/dataset.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. + +from functools import reduce +from typing import NamedTuple + +import numpy as np +import torch +import torch.utils.data +from einops import repeat +from jaxtyping import Bool, Float, Int, Shaped + + +def pad_array( + values: Shaped[torch.Tensor, "*batch variates series_len"], # noqa: F722 + patch_stride: int, +) -> Shaped[torch.Tensor, "*batch variates padded_length"]: # noqa: F722 + """ + Makes sure that the series length is divisible by the patch_stride + by adding left-padding. 
+ """ + if isinstance(values, np.ndarray): + values = torch.from_numpy(values) + series_len = values.shape[-1] + padded_length = int(np.ceil(series_len / patch_stride) * patch_stride) + if values.ndim == 2: + padded_values = torch.zeros((values.shape[0], padded_length), dtype=values.dtype, device=values.device) + elif values.ndim == 3: + padded_values = torch.zeros( + (values.shape[0], values.shape[1], padded_length), + dtype=values.dtype, + device=values.device, + ) + else: + raise ValueError(f"Unsupported number of dimensions: {values.ndim}") + padded_values[..., -series_len:] = values + + return padded_values + + +def pad_id_mask( + id_mask: Int[torch.Tensor, "*batch variates series_len"], # noqa: F722 + patch_stride: int, +) -> Int[torch.Tensor, "*batch variates padded_length"]: # noqa: F722 + """ + Makes sure that the series length is divisible by the patch_stride + by adding left-padding to the id mask. + """ + series_len = id_mask.shape[-1] + padded_length = int(np.ceil(series_len / patch_stride) * patch_stride) + padding_amount = padded_length - series_len + left_edge: Int[torch.Tensor, "*batch variates"] = id_mask[..., 0] # noqa: F722 + if id_mask.ndim == 2: + padding = repeat( + left_edge, + "variates -> variates padding_amount", + padding_amount=padding_amount, + ) + id_mask = torch.cat([padding, id_mask], dim=1) + elif id_mask.ndim == 3: + padding = repeat( + left_edge, + "batch variates -> batch variates padding_amount", + padding_amount=padding_amount, + ) + id_mask = torch.cat([padding, id_mask], dim=2) + else: + raise ValueError(f"Unsupported number of dimensions: {id_mask.ndim}") + + return id_mask + + +class MaskedTimeseries(NamedTuple): + series: Float[torch.Tensor, "*batch variates series_len"] # noqa: F722 + padding_mask: Bool[torch.Tensor, "*batch variates series_len"] # noqa: F722 + id_mask: Int[torch.Tensor, "*batch variates #series_len"] # noqa: F722 + timestamp_seconds: Int[torch.Tensor, "*batch variates series_len"] # noqa: F722 + 
time_interval_seconds: Int[torch.Tensor, "*batch variates"] # noqa: F722 + num_exogenous_variables: int = 0 + + def to(self, device: torch.device) -> "MaskedTimeseries": + return MaskedTimeseries( + series=self.series.to(device), + padding_mask=self.padding_mask.to(device), + id_mask=self.id_mask.to(device), + timestamp_seconds=self.timestamp_seconds.to(device), + time_interval_seconds=self.time_interval_seconds.to(device), + num_exogenous_variables=self.num_exogenous_variables, + ) + + +def is_extreme_value(t: torch.Tensor) -> torch.Tensor: + if torch.is_floating_point(t): + max_value = torch.finfo(t.dtype).max + else: + max_value = torch.iinfo(t.dtype).max + + return reduce( + torch.logical_or, + ( + torch.isinf(t), + torch.isnan(t), + t.abs() >= max_value / 2, + ), + ) + + +def replace_extreme_values(t: torch.Tensor, replacement: float = 0.0) -> torch.Tensor: + return torch.where(is_extreme_value(t), torch.tensor(replacement, dtype=t.dtype, device=t.device), t) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/__init__.py new file mode 100644 index 000000000000..ba26b1edd945 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py new file mode 100644 index 000000000000..2a9db2aa6295 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py @@ -0,0 +1,452 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. 
+ +from dataclasses import dataclass +from typing import cast + +import numpy as np +import torch +from einops import rearrange, repeat +from jaxtyping import Bool, Float, Int +from torch.distributions import Distribution, TransformedDistribution +from torch.distributions.transforms import AffineTransform + +from ..data.util.dataset import ( + MaskedTimeseries, + pad_array, + pad_id_mask, + replace_extreme_values, +) +from ..model.backbone import TotoBackbone + + +class AffineTransformed(TransformedDistribution): + """ + Thin wrapper around TransformedDistribution with AffineTransform, + replacing gluonts.torch.distributions.AffineTransformed. + """ + + def __init__(self, base_distribution, loc=0.0, scale=1.0): + super().__init__(base_distribution, AffineTransform(loc=loc, scale=scale)) + + @property + def mean(self): + loc = self.transforms[0].loc + scale = self.transforms[0].scale + return loc + scale * self.base_dist.mean + + # Note: Do NOT override sample() here. TransformedDistribution.sample() correctly + # calls base_dist.sample() (not rsample), which works for non-reparameterizable + # distributions like MixtureSameFamily. 
+ + +@dataclass(frozen=True) +class Forecast: + mean: Float[torch.Tensor, "batch variate future_time_steps"] + samples: Float[torch.Tensor, "batch variate future_time_steps samples"] | None = ( + None + ) + + def quantile( + self, q: float | torch.Tensor + ) -> Float[torch.Tensor, "batch variate future_time_steps"]: + assert self.samples is not None, "samples must be provided to compute quantiles" + assert isinstance(q, (float, torch.Tensor)), "q must be a float or a tensor" + if isinstance(q, float): + q = torch.tensor(q, device=self.samples.device, dtype=self.samples.dtype) + return self.samples.quantile(q, dim=-1) + + @property + def median(self) -> Float[torch.Tensor, "batch variate future_time_steps"]: + return self.quantile(0.5) + + @property + def std(self) -> Float[torch.Tensor, "batch variate future_time_steps"]: + assert ( + self.samples is not None + ), "samples must be provided to compute standard deviation" + return self.samples.std(dim=-1) + + +class TotoForecaster: + """ + A forecaster class for the Toto model that handles autoregressive decoding + for time series forecasting. 
+ """ + + model: TotoBackbone + + def __init__(self, model: TotoBackbone): + self.model = model + self.model.eval() + + def forecast( + self, + inputs: MaskedTimeseries, + prediction_length: int, + num_samples: int | None = None, + samples_per_batch: int = 10, + use_kv_cache: bool = True, + future_exogenous_variables: ( + Float[torch.Tensor, "batch exogenous_variables future_time_steps"] | None + ) = None, + ) -> Forecast: + if len(inputs.series.shape) == 2: + batch = cast(MaskedTimeseries, torch.utils.data.default_collate([inputs])) + else: + batch = inputs + + if ( + future_exogenous_variables is not None + and len(future_exogenous_variables.shape) == 2 + ): + future_exogenous_variables = future_exogenous_variables.unsqueeze(0) + + series = pad_array(batch.series, self.model.patch_embed.stride) + padding_mask = pad_array(batch.padding_mask, self.model.patch_embed.stride) + id_mask = batch.id_mask + if id_mask is not None: + id_mask = pad_id_mask(batch.id_mask, self.model.patch_embed.stride) + timestamp_seconds = pad_array( + batch.timestamp_seconds, self.model.patch_embed.stride + ) + time_interval_seconds: Int[torch.Tensor, "batch variate series_len"] = ( + torch.as_tensor( + batch.time_interval_seconds, device=series.device, dtype=torch.int + ) + ) + + if num_samples is not None: + samples = self.generate_samples( + inputs=series, + prediction_length=prediction_length, + num_samples=num_samples, + timestamp_seconds=timestamp_seconds, + time_interval_seconds=time_interval_seconds, + input_padding_mask=padding_mask, + id_mask=id_mask, + sampling_batch_size=samples_per_batch, + use_kv_cache=use_kv_cache, + future_exogenous_variables=future_exogenous_variables, + num_exogenous_variables=batch.num_exogenous_variables, + ) + mean = samples.mean(dim=-1) + else: + mean = self.generate_mean( + inputs=series, + prediction_length=prediction_length, + timestamp_seconds=timestamp_seconds, + time_interval_seconds=time_interval_seconds, + input_padding_mask=padding_mask, + 
id_mask=id_mask, + use_kv_cache=use_kv_cache, + future_exogenous_variables=future_exogenous_variables, + num_exogenous_variables=batch.num_exogenous_variables, + ) + samples = None + + return Forecast(mean=mean, samples=samples) + + def assert_ev_compatibility( + self, + inputs, + future_exogenous_variables, + prediction_length, + num_exogenous_variables, + ) -> None: + assert future_exogenous_variables.shape[-1] == prediction_length + assert future_exogenous_variables.shape[0] == inputs.shape[0] + assert num_exogenous_variables == future_exogenous_variables.shape[-2] + + def round_ft_ev(self, future_exogenous_variables, T_rounded): + B, V_ev, T_future = future_exogenous_variables.shape + dtype = future_exogenous_variables.dtype + device = future_exogenous_variables.device + padding = torch.zeros(B, V_ev, T_rounded - T_future, device=device, dtype=dtype) + return torch.cat([future_exogenous_variables, padding], dim=-1) + + @torch.no_grad() + def generate_mean( + self, + inputs: Float[torch.Tensor, "batch variate time_steps"], + prediction_length: int, + timestamp_seconds: Int[torch.Tensor, "batch variate time_steps"], + time_interval_seconds: Int[torch.Tensor, "batch variate"], + input_padding_mask: ( + Bool[torch.Tensor, "batch variate time_steps"] | None + ) = None, + id_mask: Float[torch.Tensor, "batch #variate time_steps"] | None = None, + use_kv_cache: bool = False, + future_exogenous_variables=None, + num_exogenous_variables: int = 0, + ) -> Float[torch.Tensor, "batch variate time_steps"]: + if input_padding_mask is None: + input_padding_mask = torch.ones_like( + inputs, dtype=torch.bool, device=inputs.device + ) + if id_mask is None: + id_mask = torch.zeros_like(inputs, dtype=torch.int, device=inputs.device) + + if future_exogenous_variables is not None: + self.assert_ev_compatibility( + inputs, + future_exogenous_variables, + prediction_length, + num_exogenous_variables, + ) + + patch_size = self.model.patch_embed.stride + rounded_steps = 
int(np.ceil(prediction_length / patch_size) * patch_size) + if rounded_steps > prediction_length and future_exogenous_variables is not None: + future_exogenous_variables = self.round_ft_ev( + future_exogenous_variables, rounded_steps + ) + start_index = inputs.shape[-1] + end_index = start_index + prediction_length + + dummy_padding = torch.ones( + (input_padding_mask.shape[0], input_padding_mask.shape[1], patch_size), + device=inputs.device, + dtype=torch.bool, + ) + dummy_id_mask = repeat( + id_mask[:, :, -1:], + "batch variates 1 -> batch variates patch_size", + patch_size=patch_size, + ) + if use_kv_cache: + kv_cache = self.model.allocate_kv_cache( + batch_size=inputs.shape[0], + num_variates=inputs.shape[1], + max_time_steps=inputs.shape[2] + rounded_steps, + device=inputs.device, + dtype=inputs.dtype, + ) + else: + kv_cache = None + + scaling_prefix_length = inputs.shape[-1] + + for idx in range(rounded_steps // patch_size): + base_distr, loc, scale = self.model( + inputs=inputs, + input_padding_mask=input_padding_mask, + id_mask=id_mask, + kv_cache=kv_cache, + scaling_prefix_length=scaling_prefix_length, + num_exogenous_variables=num_exogenous_variables, + ) + distr = self.create_affine_transformed(base_distr, loc, scale) + + samples = replace_extreme_values(distr.mean[:, :, -patch_size:]) + + if future_exogenous_variables is not None: + start, stop = idx * patch_size, (idx + 1) * patch_size + samples[:, -num_exogenous_variables:] = future_exogenous_variables[ + :, :, start:stop + ] + + inputs = torch.cat([inputs, samples], dim=-1) + id_mask = torch.cat([id_mask, dummy_id_mask], dim=-1) + input_padding_mask = torch.cat([input_padding_mask, dummy_padding], dim=-1) + for _ in range(patch_size): + next_timestamp = timestamp_seconds[:, :, -1] + time_interval_seconds + timestamp_seconds = torch.cat( + [timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1 + ) + + return inputs.detach()[:, :, start_index:end_index] + + @torch.no_grad() + def generate_samples( + 
self, + inputs: Float[torch.Tensor, "batch variate time_steps"], + prediction_length: int, + num_samples: int, + timestamp_seconds: Int[torch.Tensor, "batch variate time_steps"], + time_interval_seconds: Int[torch.Tensor, "batch variate"], + input_padding_mask: ( + Bool[torch.Tensor, "batch variate time_steps"] | None + ) = None, + id_mask: Float[torch.Tensor, "batch #variate time_steps"] | None = None, + sampling_batch_size: int = 10, + use_kv_cache: bool = False, + future_exogenous_variables=None, + num_exogenous_variables: int = 0, + ) -> Float[torch.Tensor, "batch variate time_steps samples"]: + if input_padding_mask is None: + input_padding_mask = torch.ones_like( + inputs, dtype=torch.bool, device=inputs.device + ) + if id_mask is None: + id_mask = torch.zeros_like(inputs, dtype=torch.int, device=inputs.device) + + if future_exogenous_variables is not None: + self.assert_ev_compatibility( + inputs, + future_exogenous_variables, + prediction_length, + num_exogenous_variables, + ) + + assert ( + num_samples % sampling_batch_size == 0 + ), "num_samples must be divisible by sampling_batch_size" + num_batches = num_samples // sampling_batch_size + + patch_size = self.model.patch_embed.patch_size + rounded_steps = int(np.ceil(prediction_length / patch_size) * patch_size) + if rounded_steps > prediction_length and future_exogenous_variables is not None: + future_exogenous_variables = self.round_ft_ev( + future_exogenous_variables, rounded_steps + ) + start_index = inputs.shape[-1] + end_index = start_index + prediction_length + + dummy_padding = torch.ones( + ( + input_padding_mask.shape[0] * sampling_batch_size, + input_padding_mask.shape[1], + patch_size, + ), + dtype=torch.bool, + device=inputs.device, + ) + dummy_id_mask = repeat( + id_mask[:, :, -1:], + "batch variates 1 -> (sampling_batch_size batch) variates patch_size", + sampling_batch_size=sampling_batch_size, + patch_size=patch_size, + ) + inputs = repeat( + inputs, + "batch variates seq_len -> 
(sampling_batch_size batch) variates seq_len", + sampling_batch_size=sampling_batch_size, + ) + if future_exogenous_variables is not None: + future_exogenous_variables = repeat( + future_exogenous_variables, + "batch exogenous_variables future_time_steps -> (sampling_batch_size batch) exogenous_variables future_time_steps", + sampling_batch_size=sampling_batch_size, + ) + input_padding_mask = repeat( + input_padding_mask, + "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", + sampling_batch_size=sampling_batch_size, + ) + id_mask = repeat( + id_mask, + "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", + sampling_batch_size=sampling_batch_size, + ) + timestamp_seconds = repeat( + timestamp_seconds, + "batch variates seq_len -> (sampling_batch_size batch) variates seq_len", + sampling_batch_size=sampling_batch_size, + ) + time_interval_seconds = repeat( + time_interval_seconds, + "batch variates -> (sampling_batch_size batch) variates", + sampling_batch_size=sampling_batch_size, + ) + + all_samples = [] + if use_kv_cache: + kv_cache = self.model.allocate_kv_cache( + batch_size=inputs.shape[0], + num_variates=inputs.shape[1], + max_time_steps=inputs.shape[2] + rounded_steps, + device=inputs.device, + dtype=inputs.dtype, + ) + else: + kv_cache = None + + scaling_prefix_length = inputs.shape[-1] + + for _ in range(num_batches): + batch_inputs = torch.clone(inputs) + batch_input_padding_mask = torch.clone(input_padding_mask) + batch_id_mask = torch.clone(id_mask) + batch_timestamp_seconds = torch.clone(timestamp_seconds) + + for idx in range(rounded_steps // patch_size): + base_distr, loc, scale = self.model( + inputs=batch_inputs, + input_padding_mask=batch_input_padding_mask, + id_mask=batch_id_mask, + kv_cache=kv_cache, + scaling_prefix_length=scaling_prefix_length, + num_exogenous_variables=num_exogenous_variables, + ) + distr = self.create_affine_transformed(base_distr, loc, scale) + + sample = distr.sample() + assert 
sample is not None + + samples = replace_extreme_values(sample[:, :, -patch_size:]) + + if future_exogenous_variables is not None: + start, stop = idx * patch_size, (idx + 1) * patch_size + samples[:, -num_exogenous_variables:] = future_exogenous_variables[ + :, :, start:stop + ] + batch_inputs = torch.cat([batch_inputs, samples], dim=-1) + batch_id_mask = torch.cat([batch_id_mask, dummy_id_mask], dim=-1) + batch_input_padding_mask = torch.cat( + [batch_input_padding_mask, dummy_padding], dim=-1 + ) + for _ in range(patch_size): + next_timestamp = ( + batch_timestamp_seconds[:, :, -1] + time_interval_seconds + ) + batch_timestamp_seconds = torch.cat( + [batch_timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1 + ) + all_samples.append(batch_inputs) + if kv_cache is not None: + kv_cache.reset() + + outputs = torch.cat(all_samples, dim=0) + unfolded_outputs = rearrange( + outputs, + "(samples batch) variates seq_len -> batch variates seq_len samples", + samples=num_samples, + ).detach() + + return unfolded_outputs[:, :, start_index:end_index, :] + + @staticmethod + def create_affine_transformed( + base_distr: Distribution, loc: torch.Tensor, scale: torch.Tensor + ) -> Distribution: + base_shape = base_distr.mean.shape + base_time_dim = base_shape[-1] + loc_time_dim = loc.shape[-1] + + if loc_time_dim == 1: + return AffineTransformed(base_distr, loc=loc, scale=scale) + + return AffineTransformed( + base_distr, + loc=loc[:, :, -base_time_dim:], + scale=scale[:, :, -base_time_dim:], + ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/__init__.py new file mode 100644 index 000000000000..ba26b1edd945 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py new file mode 100644 index 000000000000..80f6d381ff20 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py @@ -0,0 +1,276 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
# This file includes code derived from DataDog/toto
# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
# Copyright 2025 Datadog, Inc.

import logging
import warnings
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union

import torch
from einops import rearrange
from jaxtyping import Bool, Float, Int

from .rope import TimeAwareRotaryEmbedding

if TYPE_CHECKING:
    from .util import KVCache

log = logging.getLogger(__name__)

# xFormers is an optional accelerator; when it is missing we fall back to
# PyTorch's native scaled_dot_product_attention (imported unconditionally below).
try:
    from xformers.ops import LowerTriangularMask, memory_efficient_attention

    XFORMERS_AVAILABLE = True
    log.info("xFormers Memory-Efficient Attention available.")
except ImportError:
    warnings.warn(
        "xFormers Memory-Efficient Attention not available. "
        "Falling back to native PyTorch scaled_dot_product_attention.",
        ImportWarning,
    )

    XFORMERS_AVAILABLE = False

from torch.nn.functional import scaled_dot_product_attention


class AttentionAxis(Enum):
    # Attend along the time-step axis, independently per variate.
    TIME = 1
    # Attend along the variate axis, independently per time step.
    SPACE = 2


class BaseMultiheadAttention(torch.nn.Module):
    """Multi-head self-attention over a (batch, variate, seq_len, embed_dim) input.

    The attention can run along either the time axis or the space (variate)
    axis; subclasses select this by defining the class attribute
    ``attention_axis``. The chosen axis and the ``use_memory_efficient_attention``
    flag together determine how the 4-D input is flattened into attention
    batches and which einops layouts are used for the fused QKV split and the
    output merge (xFormers expects (B, seq, heads, head_dim); SDPA expects
    (B, heads, seq, head_dim)).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float,
        rotary_emb: Optional[TimeAwareRotaryEmbedding],
        use_memory_efficient_attention: bool,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert (
            embed_dim % num_heads == 0
        ), "Embedding dimension must be divisible by number of heads."

        self.head_dim = embed_dim // num_heads
        self.rotary_emb = rotary_emb

        # Single fused projection producing Q, K and V; split in get_qkv().
        self.wQKV = torch.nn.Linear(embed_dim, embed_dim * 3)
        self.dropout = dropout
        self.use_memory_efficient_attention = use_memory_efficient_attention
        # Output projection applied after heads are merged back together.
        self.wO = torch.nn.Linear(embed_dim, embed_dim)

        # Memory-efficient attention may only be requested when xFormers imported.
        assert not (
            not XFORMERS_AVAILABLE and self.use_memory_efficient_attention
        ), "XFORMERS_AVAILABLE is False, so use_memory_efficient_attention must be False"

        if not hasattr(self, "attention_axis") or self.attention_axis not in (
            AttentionAxis.TIME,
            AttentionAxis.SPACE,
        ):
            raise ValueError(
                "Child class must define attention_axis as AttentionAxis.TIME or AttentionAxis.SPACE."
            )

    def rearrange_inputs(
        self, inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"]
    ) -> Float[torch.Tensor, "... embed_dim"]:
        # Fold the non-attended axis into the batch dimension: time-wise
        # attention batches over variates, space-wise over time steps.
        pattern = (
            "batch variate seq_len embed_dim -> (batch variate) seq_len embed_dim"
            if self.attention_axis == AttentionAxis.TIME
            else "batch variate seq_len embed_dim -> (batch seq_len) variate embed_dim"
        )
        return rearrange(inputs, pattern)

    def get_qkv(self, inputs: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Project inputs through wQKV and split into (q, k, v).

        The four patterns differ only in which axis is attended and whether the
        heads axis precedes the sequence axis (SDPA layout) or follows it
        (xFormers layout).
        NOTE(review): the fused dim is factored as (qkv head_dim n_heads) here
        but merged back as (n_heads head_dim) in rearrange_output — presumably
        this matches the upstream checkpoint's weight layout; confirm against
        DataDog/toto before retraining or converting weights.
        """
        if (
            self.attention_axis == AttentionAxis.TIME
            and self.use_memory_efficient_attention
        ):
            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate seq_len n_heads head_dim"
        elif (
            self.attention_axis == AttentionAxis.TIME
            and not self.use_memory_efficient_attention
        ):
            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate n_heads seq_len head_dim"
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and self.use_memory_efficient_attention
        ):
            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len variate n_heads head_dim"
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and not self.use_memory_efficient_attention
        ):
            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len n_heads variate head_dim"

        qkv = self.wQKV(inputs.contiguous())
        return rearrange(
            qkv, pattern, qkv=3, head_dim=self.head_dim, n_heads=self.num_heads
        ).unbind(dim=0)

    def positional_embedding(self, q, k, v, kv_cache, layer_idx):
        """Apply rotary position embedding (time axis only) and update the KV cache.

        Returns (q, k, v, seq_pos_offset) where seq_pos_offset is the number of
        positions already cached, used to offset the rotary phase during
        incremental decoding.
        """
        seq_pos_offset = 0
        if self.rotary_emb is not None and self.attention_axis == AttentionAxis.TIME:
            if kv_cache is not None:
                seq_pos_offset = kv_cache.seq_len(layer_idx)
            q, k = self.rotary_emb.rotate_queries_and_keys(
                q, k, seq_pos_offset=seq_pos_offset
            )

        # Cache rotated keys/values; read back the full prefix for attention.
        if kv_cache is not None and self.attention_axis == AttentionAxis.TIME:
            kv_cache.append(layer_idx, (k, v))
            k, v = kv_cache[layer_idx]

        q = q.contiguous()
        k = k.contiguous().to(q.dtype)
        v = v.contiguous().to(q.dtype)

        return q, k, v, seq_pos_offset

    def rearrange_output(
        self, output: torch.Tensor, batch: int, variate: int, seq_len: int
    ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]:
        # Inverse of rearrange_inputs + head split: merge heads and restore the
        # (batch, variate, seq_len, embed_dim) layout.
        if (
            self.attention_axis == AttentionAxis.TIME
            and self.use_memory_efficient_attention
        ):
            pattern = "(batch variate) seq_len n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
        elif (
            self.attention_axis == AttentionAxis.TIME
            and not self.use_memory_efficient_attention
        ):
            pattern = "(batch variate) n_heads seq_len head_dim -> batch variate seq_len (n_heads head_dim)"
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and self.use_memory_efficient_attention
        ):
            pattern = "(batch seq_len) variate n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and not self.use_memory_efficient_attention
        ):
            pattern = "(batch seq_len) n_heads variate head_dim -> batch variate seq_len (n_heads head_dim)"

        return rearrange(output, pattern, batch=batch, variate=variate, seq_len=seq_len)

    def run_attention(
        self, attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate
    ):
        """Dispatch to xFormers or SDPA with a mask sliced to the query window.

        Time-wise attention is causal; during incremental decoding
        (seq_pos_offset > 0) the query rows cover only the new positions, so an
        explicit tensor mask is sliced to [q window, kv prefix]. Space-wise
        attention is never causal.
        """
        q_dim_start, q_dim_end = seq_pos_offset, seq_pos_offset + seq_len
        # Sequence axis is dim 1 in the xFormers layout, dim 2 in SDPA layout.
        kv_dim_start, kv_dim_end = 0, (
            v.shape[1] if self.use_memory_efficient_attention else v.shape[2]
        )
        if (
            self.attention_axis == AttentionAxis.TIME
            and self.use_memory_efficient_attention
        ):
            attention_mask = (
                attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                # No explicit mask: causal for the prefill, dense (None) for
                # single-step decode where all cached positions are visible.
                else LowerTriangularMask() if seq_pos_offset == 0 else None
            )
            return memory_efficient_attention(
                q, k, v, attn_bias=attention_mask, p=dropout
            )
        elif (
            self.attention_axis == AttentionAxis.TIME
            and not self.use_memory_efficient_attention
        ):
            attention_mask = (
                attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                else None
            )
            return scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attention_mask,
                dropout_p=dropout,
                is_causal=(attention_mask is None and seq_pos_offset == 0),
            )
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and self.use_memory_efficient_attention
        ):
            attention_mask = (
                attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                else None
            )
            return memory_efficient_attention(
                q, k, v, attn_bias=attention_mask, p=dropout
            )
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and not self.use_memory_efficient_attention
        ):
            attention_mask = (
                attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                else None
            )
            return scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False
            )

    def forward(
        self,
        layer_idx: int,
        inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"],
        attention_mask: Optional[
            Union[
                Bool[torch.Tensor, "batch_X_variate n_heads seq_len seq_len"],
                Bool[torch.Tensor, "batch_X_seq_len n_heads variate variate"],
            ]
        ] = None,
        kv_cache: Optional["KVCache"] = None,
    ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]:
        """Run one attention layer; layer_idx addresses this layer's KV-cache slot."""
        batch_size, variate, seq_len, _ = inputs.shape
        # Dropout only during training, matching torch convention.
        dropout = self.dropout if self.training else 0.0

        rearranged_inputs = self.rearrange_inputs(inputs)
        q, k, v = self.get_qkv(rearranged_inputs)

        q, k, v, seq_pos_offset = self.positional_embedding(
            q, k, v, kv_cache, layer_idx
        )

        output = self.run_attention(
            attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate
        )

        output = self.rearrange_output(output, batch_size, variate, seq_len)
        return self.wO(output)


class TimeWiseMultiheadAttention(BaseMultiheadAttention):
    attention_axis = AttentionAxis.TIME


class SpaceWiseMultiheadAttention(BaseMultiheadAttention):
    attention_axis = AttentionAxis.SPACE


# Type alias covering both concrete attention flavors (PEP 604 union).
MultiHeadAttention = TimeWiseMultiheadAttention | SpaceWiseMultiheadAttention
# Copyright 2025 Datadog, Inc.

from math import ceil
from typing import NamedTuple, Optional, Type, cast

import torch
from einops import rearrange, repeat
from jaxtyping import Bool, Float, Int

from .distribution import DISTRIBUTION_CLASSES_LOOKUP, DistributionOutput
from .embedding import PatchEmbedding
from .fusion import Fusion
from .scaler import scaler_types
from .transformer import Transformer
from .util import KVCache


class TotoOutput(NamedTuple):
    """
    Output of the Toto model. Contains the output distribution, the location parameters,
    and the scale parameters.
    """

    distribution: torch.distributions.Distribution
    loc: Float[torch.Tensor, "batch variate"]
    scale: Float[torch.Tensor, "batch variate"]


class TotoBackbone(torch.nn.Module):
    """
    Toto (Timeseries-Optimized Transformer for Observability) is a transformer-based model
    for multivariate time series forecasting.

    Pipeline: scale inputs -> patchify + embed -> (optionally) prepend variate
    labels -> transformer -> unembed back to per-time-step embeddings ->
    distribution head. Variate-label fusion is off by default and switched on
    via enable_variate_labels().
    """

    def __init__(
        self,
        patch_size: int,
        stride: int,
        embed_dim: int,
        num_layers: int,
        num_heads: int,
        mlp_hidden_dim: int,
        dropout: float,
        spacewise_every_n_layers: int,
        scaler_cls: str,
        output_distribution_classes: list[str],
        spacewise_first: bool = True,
        output_distribution_kwargs: dict | None = None,
        use_memory_efficient_attention: bool = True,
        stabilize_with_global: bool = True,
        scale_factor_exponent: float = 10.0,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # Fusion / label state is populated lazily by enable_variate_labels().
        self.fusion: Optional[Fusion] = None
        self.num_prepended_tokens: int = 0
        self.target_variate_label: Optional[torch.nn.Parameter] = None
        self.exogenous_variate_label: Optional[torch.nn.Parameter] = None

        # Causal-patch scalers take extra constructor arguments; other scaler
        # types are constructed bare.
        # NOTE(review): the empty-string entry here looks like a placeholder
        # for a fully-qualified scaler key — confirm against the registered
        # keys in scaler_types.
        if scaler_cls in (
            "",
            "per_variate_causal_patch",
        ):
            self.scaler = scaler_types[scaler_cls](
                patch_size=patch_size,
                stabilize_with_global=stabilize_with_global,
                scale_factor_exponent=scale_factor_exponent,
            )
        else:
            self.scaler = scaler_types[scaler_cls]()

        self.patch_embed = PatchEmbedding(patch_size, stride, embed_dim)
        self.dropout = dropout
        self.num_layers = num_layers
        self.use_memory_efficient_attention = use_memory_efficient_attention
        self.transformer = Transformer(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=self.num_layers,
            mlp_hidden_dim=mlp_hidden_dim,
            dropout=dropout,
            spacewise_every_n_layers=spacewise_every_n_layers,
            spacewise_first=spacewise_first,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            fusion=self.fusion,
        )
        # Projects each patch embedding back to patch_size per-step embeddings.
        self.unembed = torch.nn.Linear(embed_dim, embed_dim * patch_size)

        # Only the first configured distribution class is instantiated.
        output_distribution_classes_ = [
            DISTRIBUTION_CLASSES_LOOKUP[c] for c in output_distribution_classes
        ]
        self.output_distribution = output_distribution_classes_[0](
            embed_dim, **(output_distribution_kwargs or {})
        )

    def allocate_kv_cache(
        self,
        batch_size: int,
        num_variates: int,
        max_time_steps: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> KVCache:
        """Pre-allocate a KV cache sized for max_time_steps of decoding.

        max_seq_len is in patches, hence the ceil division by the patch stride.
        """
        return KVCache(
            batch_size=batch_size,
            num_variates=num_variates,
            transformer_layers=list(self.transformer.layers),
            num_layers=self.num_layers,
            embed_dim=self.embed_dim,
            num_heads=cast(int, self.transformer.layers[0].num_heads),
            max_seq_len=ceil(max_time_steps / self.patch_embed.stride),
            device=device,
            dtype=dtype,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
        )

    def backbone(
        self,
        inputs: Float[torch.Tensor, "batch variate time_steps"],
        input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"],
        id_mask: Float[torch.Tensor, "batch #variate time_steps"],
        kv_cache: Optional[KVCache] = None,
        scaling_prefix_length: Optional[int] = None,
        num_exogenous_variables: int = 0,
    ) -> tuple[
        Float[torch.Tensor, "batch variates time_steps embed_dim"],
        Float[torch.Tensor, "batch variates time_steps"],
        Float[torch.Tensor, "batch variates time_steps"],
    ]:
        """Scale, embed and transform the inputs; returns (embeddings, loc, scale)."""
        scaled_inputs, loc, scale = self.scaler(
            inputs,
            weights=torch.ones_like(inputs, device=inputs.device),
            padding_mask=input_padding_mask,
            prefix_length=scaling_prefix_length,
        )

        if kv_cache is not None:
            # During incremental decoding, drop the already-cached prefix so
            # only the new time steps are embedded and fed to the transformer.
            kv_cache_len_tensor = kv_cache.current_len(0)
            kv_cache_len = (
                int(kv_cache_len_tensor)
                if isinstance(kv_cache_len_tensor, torch.Tensor)
                else kv_cache_len_tensor
            )
            # Prepended (variate-label) tokens occupy cache slots but do not
            # correspond to input time steps.
            prefix_len = max(
                0, self.patch_embed.stride * (kv_cache_len - self.num_prepended_tokens)
            )

            scaled_inputs = scaled_inputs[:, :, prefix_len:]

            assert (prefix_len == 0) or (
                scaled_inputs.shape[-1] == self.patch_embed.stride
            ), "Must decode one step at a time."

            input_padding_mask = input_padding_mask[:, :, prefix_len:]
            id_mask = id_mask[:, :, prefix_len:]

        embeddings, reduced_id_mask = self.patch_embed(scaled_inputs, id_mask)

        variate_label_embeds = self.build_variate_label_embeds(
            num_exogenous_variables, embeddings
        )

        original_seq_len = embeddings.shape[2]
        transformed = self.transformer(
            embeddings,
            reduced_id_mask,
            kv_cache,
            variate_label_embeds=variate_label_embeds,
        )
        # Strip any tokens the transformer's fusion prepended so the output
        # length matches the input patch sequence again.
        added_tokens = transformed.shape[2] - original_seq_len
        if added_tokens > 0:
            transformed = transformed[:, :, added_tokens:]

        # Unembed each patch token into patch_size per-time-step embeddings.
        flattened: Float[torch.Tensor, "batch variates new_seq_len embed_dim"] = (
            rearrange(
                self.unembed(transformed),
                "batch variates seq_len (patch_size embed_dim) -> batch variates (seq_len patch_size) embed_dim",
                embed_dim=self.embed_dim,
            )
        )
        return flattened, loc, scale

    def forward(
        self,
        inputs: Float[torch.Tensor, "batch variate time_steps"],
        input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"],
        id_mask: Float[torch.Tensor, "batch #variate time_steps"],
        kv_cache: Optional[KVCache] = None,
        scaling_prefix_length: Optional[int] = None,
        num_exogenous_variables: int = 0,
    ) -> TotoOutput:
        """Full forward pass: backbone then distribution head."""
        flattened, loc, scale = self.backbone(
            inputs,
            input_padding_mask,
            id_mask,
            kv_cache,
            scaling_prefix_length,
            num_exogenous_variables,
        )

        return TotoOutput(self.output_distribution(flattened), loc, scale)

    @property
    def device(self):
        # Device of the module's parameters (assumes all on one device).
        return next(self.parameters()).device

    def enable_variate_labels(self) -> None:
        """Turn on variate-label fusion (one learned token prepended per variate)."""
        self.fusion = Fusion()
        self.num_prepended_tokens = 1
        self.target_variate_label = torch.nn.Parameter(torch.randn(self.embed_dim))
        self.exogenous_variate_label = torch.nn.Parameter(torch.randn(self.embed_dim))
        # Propagate to an already-built transformer (construction order safety).
        if hasattr(self, "transformer") and self.transformer is not None:
            self.transformer.fusion = self.fusion

    def build_variate_label_embeds(
        self,
        num_exogenous_variables: int,
        embeddings: Float[torch.Tensor, "batch variate seq_len embed_dim"],
    ) -> Optional[Float[torch.Tensor, "batch variate 1 embed_dim"]]:
        """Build per-variate label tokens: exogenous label for the trailing
        num_exogenous_variables variates, target label for the rest.

        Returns None when fusion is disabled.
        """
        if self.fusion is None:
            return None

        assert self.target_variate_label is not None
        assert self.exogenous_variate_label is not None

        batch_size, num_variates, _, _ = embeddings.shape

        target_variate_label = repeat(
            self.target_variate_label, "d -> b v 1 d", b=batch_size, v=num_variates
        ).to(device=embeddings.device, dtype=embeddings.dtype)
        exogenous_variate_label = repeat(
            self.exogenous_variate_label, "d -> b v 1 d", b=batch_size, v=num_variates
        ).to(device=embeddings.device, dtype=embeddings.dtype)
        # Exogenous variates are assumed to occupy the trailing positions of
        # the variate axis (matches the sampling code's layout).
        exog_mask = torch.zeros(
            1, num_variates, 1, 1, dtype=torch.bool, device=embeddings.device
        )
        if num_exogenous_variables > 0:
            exog_mask[:, -num_exogenous_variables:] = True
        return torch.where(exog_mask, exogenous_variate_label, target_variate_label)
# The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# This file includes code derived from DataDog/toto
# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
# Copyright 2025 Datadog, Inc.

from abc import ABC

import torch
import torch.nn.functional as F
from torch.distributions import TransformedDistribution
from torch.distributions.transforms import AffineTransform


class AffineTransformed(TransformedDistribution):
    """
    A thin wrapper around TransformedDistribution with an AffineTransform,
    replacing the gluonts.torch.distributions.AffineTransformed dependency.
    Provides the same interface: mean, variance, sample(), log_prob().
    """

    def __init__(self, base_distribution, loc=0.0, scale=1.0):
        super().__init__(base_distribution, AffineTransform(loc=loc, scale=scale))

    @property
    def mean(self):
        # mean(aX + b) = a * mean(X) + b
        loc = self.transforms[0].loc
        scale = self.transforms[0].scale
        return loc + scale * self.base_dist.mean

    # Note: Do NOT override sample() here. TransformedDistribution.sample() correctly
    # calls base_dist.sample() (not rsample), which works for non-reparameterizable
    # distributions like MixtureSameFamily.


class DistributionOutput(ABC, torch.nn.Module):
    """Marker base for heads mapping embeddings to torch Distributions."""

    pass


class StudentTOutput(DistributionOutput):
    """Head producing an independent Student-T distribution per position.

    forward(inputs) maps (..., embed_dim) embeddings to a StudentT with
    batch shape (...); when both loc and scale are given, the distribution
    is additionally rescaled via AffineTransformed.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # One scalar parameter per position for each of df / loc / scale.
        self.df = torch.nn.Linear(embed_dim, 1)
        self.loc_proj = torch.nn.Linear(embed_dim, 1)
        self.scale_proj = torch.nn.Linear(embed_dim, 1)

    def forward(self, inputs, loc=None, scale=None):
        eps = torch.finfo(inputs.dtype).eps
        # df > 2 guarantees a finite variance; softplus keeps it smooth.
        df = 2.0 + F.softplus(self.df(inputs)).clamp_min(eps).squeeze(-1)
        base_loc = self.loc_proj(inputs).squeeze(-1)
        base_scale = F.softplus(self.scale_proj(inputs)).clamp_min(eps).squeeze(-1)

        # validate_args=False: clamped parameters are already in-range; skips
        # per-call validation overhead.
        base_dist = torch.distributions.StudentT(
            df, base_loc, base_scale, validate_args=False
        )

        if loc is not None and scale is not None:
            return AffineTransformed(base_dist, loc=loc, scale=scale)
        return base_dist


class MixtureOfStudentTsOutput(DistributionOutput):
    """Head producing a k-component Student-T mixture per position."""

    def __init__(self, embed_dim, k_components):
        super().__init__()
        self.embed_dim = embed_dim
        self.k_components = k_components

        # k parameters per position for each of df / loc / scale / weights.
        self.df = torch.nn.Linear(embed_dim, k_components)
        self.loc_proj = torch.nn.Linear(embed_dim, k_components)
        self.scale_proj = torch.nn.Linear(embed_dim, k_components)
        self.mixture_weights = torch.nn.Linear(embed_dim, k_components)

    def forward(self, inputs, loc=None, scale=None):
        eps = torch.finfo(inputs.dtype).eps
        df = 2.0 + F.softplus(self.df(inputs)).clamp_min(eps)
        component_loc = self.loc_proj(inputs)
        component_scale = F.softplus(self.scale_proj(inputs)).clamp_min(eps)
        logits = self.mixture_weights(inputs)
        probs = F.softmax(logits, dim=-1)
        components = torch.distributions.StudentT(
            df, component_loc, component_scale, validate_args=False
        )
        mixture_distribution = torch.distributions.Categorical(probs=probs)
        mixture = torch.distributions.MixtureSameFamily(
            mixture_distribution, components
        )

        # Consistency fix: mirror StudentTOutput — previously loc/scale were
        # accepted but silently ignored here.
        if loc is not None and scale is not None:
            return AffineTransformed(mixture, loc=loc, scale=scale)
        return mixture


DISTRIBUTION_CLASSES_LOOKUP = {
    # Bug fix: the literal previously contained a duplicate "" key mapped to
    # both StudentTOutput and MixtureOfStudentTsOutput; Python keeps only the
    # last entry, so the effective (preserved) mapping is the mixture head.
    "": MixtureOfStudentTsOutput,
    # Short-form aliases for convenience
    "student_t": StudentTOutput,
    "student_t_mixture": MixtureOfStudentTsOutput,
}
from typing import Optional

import torch
from jaxtyping import Float, Int, Num


def patchify_id_mask(
    id_mask: Int[torch.Tensor, "batch variate time_steps"], patch_size: int
) -> Int[torch.Tensor, "batch variate seq_len"]:
    """Reduce a per-time-step dataset-id mask to one id per patch.

    Uses non-overlapping windows (step == patch_size). Asserts every patch
    holds a single dataset id, then returns that id per patch.
    (Annotation fix: min(-1) removes the patch_size axis, so the result is
    rank 3, not rank 4 as previously annotated.)
    """
    patched_id_mask = id_mask.unfold(dimension=-1, size=patch_size, step=patch_size)
    patched_id_mask_min = patched_id_mask.min(-1).values
    patched_id_mask_max = patched_id_mask.max(-1).values
    assert torch.eq(
        patched_id_mask_min, patched_id_mask_max
    ).all(), "Patches cannot span multiple datasets"
    return patched_id_mask_min


class PatchEmbedding(torch.nn.Module):
    """
    Multivariate time series patch embedding.
    Patchifies each variate separately.
    """

    def __init__(self, patch_size: int, stride: int, embed_dim: int):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.stride = stride
        # Linear projection from a raw patch to its embedding.
        self.projection = torch.nn.Linear(self.patch_size, self.embed_dim)

    def _patchify(
        self, x: Num[torch.Tensor, "batch variate time_steps"]
    ) -> Num[torch.Tensor, "batch variate seq_len patch_size"]:
        # Unlike patchify_id_mask above, windows here advance by self.stride,
        # so patches may overlap when stride < patch_size.
        return x.unfold(dimension=-1, size=self.patch_size, step=self.stride)

    def forward(
        self,
        x: Float[torch.Tensor, "batch #variate time_steps"],
        id_mask: Float[torch.Tensor, "batch time_steps"],
    ) -> tuple[
        Float[torch.Tensor, "batch variate seq_len embed_dim"],
        Int[torch.Tensor, "batch seq_len"],
    ]:
        """Embed patches and reduce the id mask to one id per patch.

        Returns (patch embeddings, per-patch id mask).
        NOTE(review): the declared tensor shapes for id_mask and the returned
        mask look inconsistent with the 3-D indexing used by callers — confirm
        whether the variate axis should appear in these annotations.
        """
        assert (
            x.shape[-1] % self.patch_size == 0
        ), f"Series length ({x.shape=}) must be divisible by ({self.patch_size=})"
        x_patched: Float[torch.Tensor, "batch variate seq_len patch_size"] = (
            self._patchify(x)
        )
        id_mask_patched: Int[torch.Tensor, "batch variate seq_len patch_size"] = (
            self._patchify(id_mask)
        )

        # Compute the per-patch min once and reuse it (previously the min
        # reduction was evaluated twice: once in the assert, once for the
        # return value).
        id_mask_min = id_mask_patched.min(-1).values
        assert torch.eq(
            id_mask_min, id_mask_patched.max(-1).values
        ).all(), "Patches cannot span multiple datasets"

        return (
            self.projection(x_patched),
            id_mask_min,
        )
a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py new file mode 100644 index 000000000000..024a8bed7277 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. 
import torch
import torch.nn.functional as F


class SwiGLU(torch.nn.Module):
    """
    SwiGLU gated activation (https://arxiv.org/abs/2002.05202).

    The last dimension of the input must be twice the desired output size:
    the first half is the gate, the second half the value. This (gate, value)
    ordering is unusual but deliberately matches xFormers.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        halves = x.chunk(2, dim=-1)
        gate, value = halves[0], halves[1]
        return value * F.silu(gate)
+ """ + + def __init__(self) -> None: + super().__init__() + + def forward( + self, + embeddings: Float[torch.Tensor, "batch variate seq_len embed_dim"], + variate_label_embeds: Optional[ + Float[torch.Tensor, "batch variate 1 embed_dim"] + ] = None, + ) -> Float[torch.Tensor, "batch variate new_seq_len embed_dim"]: + + if variate_label_embeds is None: + return embeddings + + processed_embeddings = F.normalize(variate_label_embeds, p=2, dim=-1) + + return torch.cat( + [ + processed_embeddings.to( + dtype=embeddings.dtype, device=embeddings.device, non_blocking=True + ), + embeddings, + ], + dim=2, + ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py new file mode 100644 index 000000000000..96e625170770 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. 
from typing import Optional

import torch
from einops import rearrange
from jaxtyping import Int
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
from rotary_embedding_torch.rotary_embedding_torch import default


def exists(val):
    # Small helper mirroring rotary_embedding_torch's convention.
    return val is not None


class TimeAwareRotaryEmbedding(RotaryEmbedding):
    """
    A variant of the rotary position embedding that (optionally) uses the time index
    to compute the sinusoidal and cosine embeddings. Useful for time series data.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # If the parent stored `freqs` as a Parameter, remove it and register as a buffer
        # (keeps the frequencies out of the optimizer/state_dict; persistent=False).
        if hasattr(self, "freqs") and isinstance(self.freqs, torch.nn.Parameter):
            freqs_data = self.freqs.data
            self._parameters.pop("freqs")
            self.register_buffer("freqs", freqs_data, persistent=False)

    def rotate_queries_and_keys(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_dim: Optional[int] = None,
        seq_pos: Optional[Int[torch.Tensor, "... seq_len"]] = None,
        seq_pos_offset: int = 0,
    ):
        """Apply xPos-style rotary embedding to q and k.

        seq_pos_offset shifts the positions, used during incremental decoding
        where cached positions precede the current query window.
        """
        if seq_dim is None:
            seq_dim = self.default_seq_dim

        # This subclass only supports the xPos (length-extrapolating) variant.
        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        # Use caller-provided positions if given, else 0..seq_len-1.
        seq = default(seq_pos, self.get_seq_pos(seq_len, dtype=dtype, device=device))
        seq = seq + seq_pos_offset

        freqs = self.forward(seq)

        scale = self.get_scale(seq).to(dtype)

        if seq_dim == -3:
            # (..., seq, heads, head_dim) layout (xFormers-style): broadcast
            # freqs/scale across the heads axis.
            num_heads = q.shape[-2]
            freqs = freqs.unsqueeze(1).expand(-1, num_heads, -1)
            scale = scale.unsqueeze(1).expand(-1, num_heads, -1)

        # xPos applies reciprocal scales to queries vs keys.
        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(
        self,
        t: torch.Tensor,
    ):
        """Compute the xPos decay scale for positions t.

        NOTE(review): by operator precedence this is
        t - (t.max(-1) // 2), i.e. centered on half the *maximum position
        value*, whereas upstream rotary_embedding_torch centers on half the
        sequence *length*. Presumably intentional for offset/time-indexed
        positions — confirm against DataDog/toto.
        """
        assert self.use_xpos

        power = (t - t.max(-1).values.unsqueeze(-1) // 2) / self.scale_base

        # Duplicate along the feature axis to cover the (paired) rotary dims.
        scale = self.scale ** rearrange(power, "... n -> ... n 1")
        scale = torch.cat((scale, scale), dim=-1)

        return scale
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This file includes code derived from DataDog/toto +# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License. +# Copyright 2025 Datadog, Inc. + +import warnings +from typing import Tuple + +import torch +from einops import reduce, repeat + + +class Scaler(torch.nn.Module): + """ + Minimal base class replacing gluonts.torch.scaler.Scaler. + Provides a __call__ interface for scaling data. + """ + + pass + + +class StdMeanScaler(Scaler): + """ + Scales data to have zero mean and unit variance along a given dimension. + """ + + def __init__( + self, + dim: int = -1, + keepdim: bool = True, + minimum_scale: float = 1e-3, + ) -> None: + super().__init__() + self.dim = dim + self.keepdim = keepdim + self.minimum_scale = minimum_scale + + def __call__( + self, + data: torch.Tensor, + padding_mask: torch.Tensor, + weights: torch.Tensor, + prefix_length: int | None = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert data.shape == weights.shape, "data and weights must have same shape" + with torch.no_grad(): + if prefix_length is not None: + prefix_mask = torch.zeros_like(weights) + prefix_mask[..., :prefix_length] = 1.0 + weights = weights * prefix_mask + + weights = weights * padding_mask + + try: + high_precision_data = data.to(torch.float64) + except TypeError: + warnings.warn( + f"Float64 is not supported by device {data.device}. " + "Using float32 instead for accumulating denominator in input scaler. 
" + "This may lead to overflow issues if the data contains extreme values.", + RuntimeWarning, + ) + high_precision_data = data.to(torch.float32) + + denominator = ( + weights.sum(self.dim, keepdim=self.keepdim) + .clamp_min(1.0) + .to(high_precision_data.dtype) + ) + means = (high_precision_data * weights).sum( + self.dim, keepdim=self.keepdim + ) / denominator + means = torch.nan_to_num(means) + + variance = (((high_precision_data - means) * weights) ** 2).sum( + self.dim, keepdim=self.keepdim + ) / denominator + scale = torch.sqrt(variance + self.minimum_scale).to(data.dtype) + loc = means.to(data.dtype) + + return (data - loc) / scale, loc, scale + + +def compute_causal_statistics( + data: torch.Tensor, + weights: torch.Tensor, + padding_mask: torch.Tensor, + dim: int, + minimum_scale: float, + use_bessel_correction: bool = True, + stabilize_with_global: bool = False, + scale_factor_exponent: float = 10.0, + prefix_length: int | None = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dim == -1, "compute_causal_statistics only supports dim=-1 (last dimension)" + + with torch.no_grad(): + weights = weights * padding_mask + + try: + high_precision_data = data.to(torch.float64) + high_precision_weights = weights.to(torch.float64) + except TypeError: + warnings.warn( + f"Float64 is not supported by device {data.device}. 
" + "Using float32 instead for causal scaler calculations.", + RuntimeWarning, + ) + high_precision_data = data.to(torch.float32) + high_precision_weights = weights.to(torch.float32) + + prev_deterministic = torch.are_deterministic_algorithms_enabled() + if prev_deterministic and data.device.type == "cuda": + torch.use_deterministic_algorithms(False) + + try: + weighted_data = high_precision_weights * high_precision_data + + cum_weights = torch.cumsum(high_precision_weights, dim=dim) + cum_values = torch.cumsum(weighted_data, dim=dim) + + denominator = cum_weights.clamp_min(1.0) + causal_means = cum_values / denominator + + shifted_means = torch.zeros_like(causal_means) + shifted_means[..., 1:] = causal_means[..., :-1] + + delta = high_precision_data - shifted_means + increment = ( + delta * (high_precision_data - causal_means) * high_precision_weights + ) + m_2 = torch.cumsum(increment, dim=dim) + + if use_bessel_correction: + causal_variance = m_2 / torch.clamp(denominator - 1.0, min=1.0) + else: + causal_variance = m_2 / denominator + + causal_scale = torch.sqrt(causal_variance + minimum_scale) + + if stabilize_with_global: + if prefix_length is not None: + prefix_mask = torch.zeros_like(weights) + prefix_mask[..., :prefix_length] = 1.0 + weighted_data = weighted_data * prefix_mask + weights = weights * prefix_mask + padding_mask = padding_mask * prefix_mask + + scale_factor_min = 10.0 ** (-scale_factor_exponent) + scale_factor_max = 10.0**scale_factor_exponent + + global_denominator = ( + (weights * padding_mask).sum(dim, keepdim=True).clamp_min(1.0) + ) + global_means = (weighted_data).sum( + dim, keepdim=True + ) / global_denominator + global_means = torch.nan_to_num(global_means) + + global_variance = ( + ((high_precision_data - global_means) * weights * padding_mask) ** 2 + ).sum(dim, keepdim=True) / global_denominator + global_scale = torch.sqrt(global_variance + minimum_scale) + + expanded_global_scale = global_scale.expand_as(causal_scale) + 
class CausalStdMeanScaler(Scaler):
    """Causally standardize data: each time step is normalized with the
    running mean/scale of everything observed up to and including that step.

    Thin wrapper over :func:`compute_causal_statistics`; returns
    ``(scaled_data, causal_means, causal_scale)``.
    """

    def __init__(
        self,
        dim: int = -1,
        minimum_scale: float = 0.1,
        use_bessel_correction: bool = True,
        stabilize_with_global: bool = False,
        scale_factor_exponent: float = 10.0,
    ) -> None:
        super().__init__()
        assert dim == -1, "CausalStdMeanScaler only supports dim=-1 (last dimension)"
        self.dim = dim
        self.minimum_scale = minimum_scale
        self.use_bessel_correction = use_bessel_correction
        self.stabilize_with_global = stabilize_with_global
        self.scale_factor_exponent = scale_factor_exponent

    def __call__(
        self,
        data: torch.Tensor,
        padding_mask: torch.Tensor,
        weights: torch.Tensor,
        prefix_length: int | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert data.shape == weights.shape, "data and weights must have same shape"
        assert (
            len(data.shape) == 3
        ), "Input data must have shape [batch, variates, time_steps]"

        loc, scale = compute_causal_statistics(
            data,
            weights,
            padding_mask,
            self.dim,
            self.minimum_scale,
            self.use_bessel_correction,
            self.stabilize_with_global,
            self.scale_factor_exponent,
            prefix_length,
        )

        return (data - loc) / scale, loc, scale


class CausalPatchStdMeanScaler(Scaler):
    """Patch-wise causal standardization.

    Running statistics are computed per step, then frozen at the *last* step
    of each non-overlapping patch of ``patch_size`` steps and broadcast over
    that patch, so every step in a patch shares one (causal) mean/scale.

    Returns ``(scaled_data, patch_means, patch_scales)``.
    """

    def __init__(
        self,
        dim: int = -1,
        patch_size: int = 32,
        minimum_scale: float = 0.1,
        use_bessel_correction: bool = True,
        stabilize_with_global: bool = False,
        scale_factor_exponent: float = 10.0,
    ) -> None:
        super().__init__()
        assert (
            dim == -1
        ), "CausalPatchStdMeanScaler only supports dim=-1 (last dimension)"
        self.dim = dim
        self.patch_size = patch_size
        self.minimum_scale = minimum_scale
        self.use_bessel_correction = use_bessel_correction
        self.stabilize_with_global = stabilize_with_global
        self.scale_factor_exponent = scale_factor_exponent

    def __call__(
        self,
        data: torch.Tensor,
        padding_mask: torch.Tensor,
        weights: torch.Tensor,
        prefix_length: int | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert data.shape == weights.shape, "data and weights must have same shape"
        assert (
            len(data.shape) == 3
        ), "Input data must have shape [batch, variates, time_steps]"

        with torch.no_grad():
            time_steps = data.shape[-1]
            assert (
                time_steps % self.patch_size == 0
            ), f"Time steps ({time_steps}) must be divisible by patch size ({self.patch_size})"

            step_means, step_scales = compute_causal_statistics(
                data,
                weights,
                padding_mask,
                -1,
                self.minimum_scale,
                self.use_bessel_correction,
                self.stabilize_with_global,
                self.scale_factor_exponent,
                prefix_length,
            )

            # Keep only the statistics at the final step of each patch …
            last_in_patch_means = step_means.unfold(
                -1, self.patch_size, self.patch_size
            )[..., -1]
            last_in_patch_scales = step_scales.unfold(
                -1, self.patch_size, self.patch_size
            )[..., -1]

            # … and broadcast them across every step of that patch.
            patch_means = repeat(
                last_in_patch_means, "b v p -> b v (p s)", s=self.patch_size
            )
            patch_scales = repeat(
                last_in_patch_scales, "b v p -> b v (p s)", s=self.patch_size
            )

        scaled_data = (data - patch_means) / patch_scales

        return scaled_data, patch_means, patch_scales
# For deserialization of SafeTensors checkpoints: maps the `scaler_cls`
# string found in a checkpoint's config.json to a concrete Scaler class.
#
# BUG FIX: the previous dict used three identical empty-string keys, so the
# first two entries were silently discarded and "" always resolved to
# CausalPatchStdMeanScaler. The repr-style keys below restore one distinct
# key per class.
# NOTE(review): these keys assume the checkpoint stores str(type(...)) of the
# scaler class ("<class '...'>") — confirm against the `scaler_cls` value in
# the Datadog/Toto-Open-Base-1.0 config.json.
scaler_types = {
    "<class 'model.scaler.StdMeanScaler'>": StdMeanScaler,
    "<class 'model.scaler.CausalStdMeanScaler'>": CausalStdMeanScaler,
    "<class 'model.scaler.CausalPatchStdMeanScaler'>": CausalPatchStdMeanScaler,
    # Preserve the binding the collapsed dict effectively had, so any caller
    # that looked up "" keeps working.
    "": CausalPatchStdMeanScaler,
    # Short aliases used in config.json
    "per_variate": StdMeanScaler,
    "per_variate_causal": CausalStdMeanScaler,
    "per_variate_causal_patch": CausalPatchStdMeanScaler,
}
+ """ + + def __init__( + self, + patch_size: int, + stride: int, + embed_dim: int, + num_layers: int, + num_heads: int, + mlp_hidden_dim: int, + dropout: float, + spacewise_every_n_layers: int, + scaler_cls: str, + output_distribution_classes: list[str], + spacewise_first: bool = True, + output_distribution_kwargs: dict | None = None, + use_memory_efficient_attention: bool = True, + stabilize_with_global: bool = True, + scale_factor_exponent: float = 10.0, + **model_kwargs, + ): + super().__init__() + self.model = TotoBackbone( + patch_size=patch_size, + stride=stride, + embed_dim=embed_dim, + num_layers=num_layers, + num_heads=num_heads, + mlp_hidden_dim=mlp_hidden_dim, + dropout=dropout, + spacewise_every_n_layers=spacewise_every_n_layers, + scaler_cls=scaler_cls, + output_distribution_classes=output_distribution_classes, + spacewise_first=spacewise_first, + output_distribution_kwargs=output_distribution_kwargs, + use_memory_efficient_attention=use_memory_efficient_attention, + stabilize_with_global=stabilize_with_global, + scale_factor_exponent=scale_factor_exponent, + ) + + @classmethod + def load_from_checkpoint( + cls, + checkpoint_path, + map_location: str = "cpu", + strict=True, + **model_kwargs, + ): + if os.path.isdir(checkpoint_path): + safetensors_file = os.path.join(checkpoint_path, "model.safetensors") + else: + safetensors_file = checkpoint_path + + if os.path.exists(safetensors_file): + model_state = safetorch.load_file(safetensors_file, device=map_location) + else: + raise FileNotFoundError( + f"Model checkpoint not found at: {safetensors_file}" + ) + + config_file = os.path.join(checkpoint_path, "config.json") + config = {} + if os.path.exists(config_file): + with open(config_file, "r") as f: + config = json.load(f) + + config.update(model_kwargs) + + remapped_state_dict = cls._map_state_dict_keys( + model_state, + XFORMERS_SWIGLU_AVAILABLE + and not config.get("pre_xformers_checkpoint", False), + ) + + if not XFORMERS_AVAILABLE and config.get( + 
"use_memory_efficient_attention", True + ): + config["use_memory_efficient_attention"] = False + + instance = cls(**config) + instance.to(map_location) + + filtered_remapped_state_dict = { + k: v + for k, v in remapped_state_dict.items() + if k in instance.state_dict() and not k.endswith("rotary_emb.freqs") + } + + instance.load_state_dict(filtered_remapped_state_dict, strict=strict) + return instance + + @staticmethod + def _map_state_dict_keys(state_dict, use_fused_swiglu): + if use_fused_swiglu: + remap_keys = { + "mlp.0.weight": "mlp.0.w12.weight", + "mlp.0.bias": "mlp.0.w12.bias", + "mlp.2.weight": "mlp.0.w3.weight", + "mlp.2.bias": "mlp.0.w3.bias", + } + else: + remap_keys = { + "mlp.0.w12.weight": "mlp.0.weight", + "mlp.0.w12.bias": "mlp.0.bias", + "mlp.0.w3.weight": "mlp.2.weight", + "mlp.0.w3.bias": "mlp.2.bias", + } + + def replace_key(text): + for pattern, replacement in remap_keys.items(): + text = re.sub(pattern, replacement, text) + return text + + return {replace_key(k): v for k, v in state_dict.items()} + + @property + def device(self): + return next(self.model.parameters()).device diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py new file mode 100644 index 000000000000..58220c30e625 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py @@ -0,0 +1,318 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
import warnings
from typing import Literal, Optional, Union, cast

import torch
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Bool, Float, Int
from rotary_embedding_torch import RotaryEmbedding

from .attention import (
    AttentionAxis,
    MultiHeadAttention,
    SpaceWiseMultiheadAttention,
    TimeWiseMultiheadAttention,
)
from .feed_forward import SwiGLU
from .fusion import Fusion
from .rope import TimeAwareRotaryEmbedding
from .util import KVCache, RMSNorm, make_batched_block_mask

try:
    from xformers.ops.swiglu_op import SwiGLU as SwiGLU_fused

    XFORMERS_SWIGLU_AVAILABLE = True
except ImportError:
    warnings.warn(
        "xFormers fused SwiGLU kernel not found. "
        "Using native PyTorch implementation for feed-forward layers.",
        ImportWarning,
    )
    XFORMERS_SWIGLU_AVAILABLE = False


class TransformerLayer(torch.nn.Module):
    """One pre-norm transformer block: (RMS/Layer)Norm → attention → residual,
    then norm → SwiGLU MLP → residual.

    The attention operates either along the time axis (with rotary
    embeddings) or along the variate ("space") axis, selected by
    ``attention_axis``.
    """

    # Configuration recorded as class-level annotations (set in __init__).
    embed_dim: int
    num_heads: int
    mlp_hidden_dim: int
    dropout: float
    attention_axis: AttentionAxis

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        dropout: float,
        rotary_emb: Optional[RotaryEmbedding] = None,
        attention_axis: AttentionAxis = AttentionAxis.TIME,
        RMS_norm: bool = True,
        use_memory_efficient_attention: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.mlp_hidden_dim = mlp_hidden_dim
        self.dropout = dropout
        self.attention_axis = attention_axis

        if RMS_norm:
            self.norm1: Union[RMSNorm, torch.nn.LayerNorm] = RMSNorm(embed_dim)
            self.norm2: Union[RMSNorm, torch.nn.LayerNorm] = RMSNorm(embed_dim)
        else:
            self.norm1 = torch.nn.LayerNorm(embed_dim)
            self.norm2 = torch.nn.LayerNorm(embed_dim)

        self.attention: MultiHeadAttention

        if attention_axis == AttentionAxis.TIME:
            # Time-wise attention uses rotary position embeddings.
            self.attention = TimeWiseMultiheadAttention(
                embed_dim=embed_dim,
                num_heads=num_heads,
                dropout=dropout,
                rotary_emb=rotary_emb,
                use_memory_efficient_attention=use_memory_efficient_attention,
            )
        elif attention_axis == AttentionAxis.SPACE:
            # Space-wise (cross-variate) attention has no positional notion,
            # hence no rotary embedding.
            self.attention = SpaceWiseMultiheadAttention(
                embed_dim=embed_dim,
                num_heads=num_heads,
                dropout=dropout,
                rotary_emb=None,
                use_memory_efficient_attention=use_memory_efficient_attention,
            )
        else:
            raise ValueError("Invalid attention axis")

        # NOTE: the two branches produce different parameter names
        # (mlp.0.w12/mlp.0.w3 vs mlp.0/mlp.2); Toto._map_state_dict_keys
        # translates checkpoints between the layouts.
        if XFORMERS_SWIGLU_AVAILABLE:
            self.mlp = torch.nn.Sequential(
                SwiGLU_fused(in_features=embed_dim, hidden_features=mlp_hidden_dim),
                torch.nn.Dropout(dropout),
            )
        else:
            # Linear projects to 2*hidden (gate + value); SwiGLU halves it back.
            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(embed_dim, 2 * mlp_hidden_dim),
                SwiGLU(),
                torch.nn.Linear(mlp_hidden_dim, embed_dim),
                torch.nn.Dropout(dropout),
            )

    def forward(
        self,
        layer_idx: int,
        inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"],
        attention_mask: Optional[
            Union[
                Bool[torch.Tensor, "batch seq_len variate variate"],
                Bool[torch.Tensor, "batch #variate seq_len seq_len"],
            ]
        ] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]:
        """Apply attention and MLP sub-blocks with pre-norm residuals.

        ``layer_idx`` identifies this layer inside the shared ``kv_cache``.
        """
        pre_norm_1 = self.norm1(inputs)
        hidden_state = (
            inputs
            + self.attention(
                layer_idx, pre_norm_1, attention_mask, kv_cache
            ).contiguous()
        )

        pre_norm_2 = self.norm2(hidden_state)
        return hidden_state + self.mlp(pre_norm_2)
class Transformer(torch.nn.Module):
    """Stack of pre-norm transformer layers, interleaving time-wise attention
    with periodic space-wise (cross-variate) attention layers, with a single
    shared time-aware rotary embedding.
    """

    def __init__(
        self,
        num_layers: int,
        embed_dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        dropout: float,
        spacewise_every_n_layers: int,
        spacewise_first: bool,
        use_memory_efficient_attention: bool = True,
        *,
        fusion: Optional[Fusion] = None,
    ):
        super().__init__()

        assert (
            embed_dim % num_heads == 0
        ), "Embedding dimension must be divisible by number of heads."

        # One rotary embedding instance shared by every time-wise layer.
        self.rotary_emb = TimeAwareRotaryEmbedding(
            embed_dim // num_heads,
            use_xpos=True,
            cache_if_possible=True,
            seq_before_head_dim=use_memory_efficient_attention,
        )
        attention_axes = self._get_layer_types(
            num_layers, spacewise_every_n_layers, spacewise_first
        )

        self.use_memory_efficient_attention = use_memory_efficient_attention
        self.fusion = fusion

        self.layers = torch.nn.ModuleList(
            [
                TransformerLayer(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    mlp_hidden_dim=mlp_hidden_dim,
                    dropout=dropout,
                    rotary_emb=self.rotary_emb,
                    attention_axis=attention_axes[i],
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                )
                for i in range(num_layers)
            ]
        )

    def _get_mask(
        self,
        num_heads: int,
        dtype: torch.dtype,
        id_mask: Optional[torch.Tensor] = None,
    ) -> Union[
        Bool[torch.Tensor, "batch num_heads seq_len seq_len"],
        Float[torch.Tensor, "batch num_heads seq_len seq_len"],
        Bool[torch.Tensor, "batch num_heads variate variate"],
        Float[torch.Tensor, "batch num_heads variate variate"],
    ]:
        """Build the space-wise attention mask from the per-variate id mask.

        Variates attend to each other only when they share the same id
        (block-diagonal mask). For the memory-efficient path the mask is
        padded to a multiple of 8 and converted to additive-bias form
        (0 / -inf), as xFormers expects.
        """
        if id_mask is None:
            raise ValueError("id_mask must be provided for spacewise masks.")

        mask = make_batched_block_mask(id_mask.transpose(-1, -2))

        if self.use_memory_efficient_attention:
            mask = self._pad_to_multiple(mask)
            # Boolean → additive bias: allowed pairs 0.0, forbidden -inf.
            mask = (
                mask.float()
                .masked_fill(~mask, float("-inf"))
                .masked_fill(mask, 0.0)
                .to(dtype)
            )

        # Space-wise attention treats each time step as a separate batch item.
        mask = rearrange(
            mask,
            "batch seq_len variate1 variate2 -> (batch seq_len) 1 variate1 variate2",
        )
        return mask.expand(-1, num_heads, -1, -1).contiguous()

    def _pad_to_multiple(
        self,
        tensor: torch.Tensor,
        multiple: int = 8,
        causal: bool = False,
    ) -> torch.Tensor:
        """Pad the last two dims of a square mask up to a multiple of
        ``multiple`` (xFormers kernel alignment). With ``causal=True`` the
        padding region is filled with a lower-triangular pattern instead of
        zeros.
        """
        pad_amount = (multiple - tensor.shape[-1] % multiple) % multiple
        if pad_amount > 0:
            new_size = tensor.shape[-1] + pad_amount
            if causal:
                full_mask = torch.tril(
                    torch.ones(
                        (new_size, new_size), dtype=tensor.dtype, device=tensor.device
                    )
                )
                full_mask[: tensor.shape[-1], : tensor.shape[-1]] = tensor
                tensor = full_mask
            else:
                tensor = F.pad(tensor, (0, pad_amount, 0, pad_amount))
        return tensor

    def _get_layer_types(
        self,
        num_layers: int,
        spacewise_every_n_layers: int,
        spacewise_first: bool,
    ) -> list[AttentionAxis]:
        """Return the per-layer attention axis: one SPACE layer per block of
        ``spacewise_every_n_layers`` layers (at the start or end of the block),
        or all TIME when ``spacewise_every_n_layers == -1``.
        """
        if spacewise_every_n_layers == -1:
            return [AttentionAxis.TIME] * num_layers
        assert num_layers % spacewise_every_n_layers == 0

        block = [AttentionAxis.TIME] * (spacewise_every_n_layers - 1)

        if spacewise_first:
            block = [AttentionAxis.SPACE] + block
        else:
            block = block + [AttentionAxis.SPACE]

        return block * (num_layers // spacewise_every_n_layers)

    def forward(
        self,
        inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"],
        id_mask: Float[torch.Tensor, "batch #variate seq_len"],
        kv_cache: Optional[KVCache] = None,
        variate_label_embeds: Optional[
            Float[torch.Tensor, "batch variate 1 embed_dim"]
        ] = None,
    ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]:
        """Run the full layer stack.

        Optional label fusion is applied once, on the first decode step only
        (i.e. when the KV cache is still empty).
        """
        if self.fusion is not None and variate_label_embeds is not None:
            should_apply_fusion = True
            if kv_cache is not None:
                # NOTE(review): cache slot 0 is used to detect "first step" —
                # assumes slot 0 fills on the first forward pass; confirm.
                kv_len_tensor = kv_cache.current_len(0)
                kv_len = (
                    int(kv_len_tensor)
                    if isinstance(kv_len_tensor, torch.Tensor)
                    else kv_len_tensor
                )
                should_apply_fusion = kv_len == 0
            if should_apply_fusion:
                inputs = self.fusion(inputs, variate_label_embeds=variate_label_embeds)

        batch, _, seq_len, _ = inputs.shape

        # Left-pad the id mask (repeating its first column) if the inputs are
        # longer than the mask (e.g. after fusion prepends tokens).
        if id_mask is not None and id_mask.shape[-1] != seq_len:
            added = int(seq_len - id_mask.shape[-1])
            if added > 0:
                pad_slice = id_mask[..., :1]
                id_mask = torch.cat([pad_slice.expand(-1, -1, added), id_mask], dim=-1)

        # NOTE(review): total sequence length including cached steps; queries
        # layer index 1 and is not used below — confirm whether this is dead
        # code or retained for side effects.
        seq_len = (kv_cache.seq_len(1) if kv_cache else 0) + seq_len

        num_heads: int = cast(int, self.layers[0].num_heads)

        # Time-wise layers receive no explicit mask here (causality is
        # handled inside the attention implementation).
        timewise_attention_mask = None

        spacewise_attention_mask = self._get_mask(
            num_heads=num_heads,
            dtype=inputs.dtype,
            id_mask=id_mask,
        )

        for layer_idx, layer in enumerate(self.layers):
            inputs = layer(
                layer_idx,
                inputs,
                (
                    timewise_attention_mask
                    if layer.attention_axis == AttentionAxis.TIME
                    else spacewise_attention_mask
                ),
                kv_cache,
            )
        return inputs
try:
    from xformers import _is_triton_available
    from xformers.ops.rmsnorm import rms_norm, rms_norm_add

    XFORMERS_RMSNORM_AVAILABLE = True
except ImportError:
    warnings.warn(
        "xFormers fused RMSNorm implementation not available. Will not use "
        "optimized kernel for inference.",
        ImportWarning,
    )

    def _is_triton_available():
        return False

    XFORMERS_RMSNORM_AVAILABLE = False


class RMSNorm(torch.nn.Module):
    """
    Wraps xFormers' rms_norm for eval/frozen mode, and does a Python fallback for train mode.
    """

    def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-8):
        """RMS normalization over the last dimension.

        Args:
            dim: Size of the normalized (last) dimension.
            include_weight: If True, apply a learnable per-channel gain.
            eps: Small constant added under the square root for stability.
        """
        super(RMSNorm, self).__init__()
        self.eps = eps
        if include_weight:
            self.scale: Optional[torch.nn.Parameter] = torch.nn.Parameter(
                torch.ones(dim)
            )
        else:
            self.scale = None

    def _use_fused_kernel(self) -> bool:
        """True when the fused xFormers/Triton kernel may be used.

        The fused path is restricted to eval mode (or a frozen gain), AND
        requires the xFormers RMSNorm import to have succeeded AND Triton to
        be available.
        """
        frozen = (not self.training) or (
            self.scale is not None and not self.scale.requires_grad
        )
        return frozen and XFORMERS_RMSNORM_AVAILABLE and _is_triton_available()

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` to unit RMS along the last dim (optionally scaled)."""
        if self._use_fused_kernel():
            return rms_norm(x, self.scale, self.eps)

        # Pure-PyTorch fallback: x / sqrt(mean(x^2) + eps), then optional gain.
        x_normed = x / torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
        return x_normed if self.scale is None else x_normed * self.scale

    def increment_and_forward_(self, x: torch.Tensor, y: torch.Tensor):
        """Fused residual-add + RMSNorm of ``x + y``.

        BUG FIX: the previous implementation called ``rms_norm_add`` whenever
        the module was in eval/frozen mode, WITHOUT checking
        ``XFORMERS_RMSNORM_AVAILABLE`` — when xFormers is not installed,
        ``rms_norm_add`` is undefined and evaluation crashed with NameError.
        Now the availability check mirrors ``forward``; otherwise we fall back
        to the plain-Python path.
        """
        if self._use_fused_kernel():
            return rms_norm_add(x, y, self.scale, self.eps)
        return self.forward(x + y)


def make_batched_block_mask(t: torch.Tensor) -> torch.Tensor:
    """Boolean block mask: entry (i, j) is True iff ids i and j are equal
    along the last dimension (built by comparing the tensor to its own
    transpose after unsqueezing).
    """
    unsqueezed = rearrange(t, "... d -> ... 1 d")
    return unsqueezed == unsqueezed.transpose(-1, -2)
+ """ + + batch_size: int + num_variates: int + transformer_layers: List["TransformerLayer"] + num_layers: int + embed_dim: int + num_heads: int + max_seq_len: int + device: torch.device = torch.device("cpu") + dtype: torch.dtype = torch.float32 + use_memory_efficient_attention: bool = True + + _keys: Union[ + Float[ + torch.Tensor, + "time_layer_count batch_size_X_num_variates max_seq_len num_heads head_dim", + ], + Float[ + torch.Tensor, + "time_layer_count batch_size_X_num_variates num_heads max_seq_len head_dim", + ], + ] = field(init=False) + + _values: Union[ + Float[ + torch.Tensor, + "time_layer_count batch_size_X_num_variates max_seq_len num_heads head_dim", + ], + Float[ + torch.Tensor, + "time_layer_count batch_size_X_num_variates num_heads max_seq_len head_dim", + ], + ] = field(init=False) + + _current_idx: Int[torch.Tensor, "time_layer_count"] = field(init=False) + _layer_cache_map: Int[torch.Tensor, "num_layers"] = field(init=False) + + def __post_init__(self): + assert ( + self.embed_dim % self.num_heads == 0 + ), "embed_dim must be divisible by num_heads" + head_dim = self.embed_dim // self.num_heads + + time_layer_indices = [ + i + for i in range(self.num_layers) + if isinstance( + self.transformer_layers[i].attention, TimeWiseMultiheadAttention + ) + ] + + time_layer_count = max(1, len(time_layer_indices)) + if self.use_memory_efficient_attention: + shape = ( + time_layer_count, + self.batch_size * self.num_variates, + self.max_seq_len, + self.num_heads, + head_dim, + ) + else: + shape = ( + time_layer_count, + self.batch_size * self.num_variates, + self.num_heads, + self.max_seq_len, + head_dim, + ) + self._keys = torch.zeros(shape, device=self.device, dtype=self.dtype) + self._values = torch.zeros_like(self._keys) + self._current_idx = torch.zeros( + time_layer_count, device=self.device, dtype=torch.int + ) + self._layer_cache_map = torch.zeros( + (self.num_layers,), dtype=torch.int, device=self.device + ) + for cache_idx, layer_idx in 
enumerate(time_layer_indices): + self._layer_cache_map[layer_idx] = int(cache_idx) + + def __getitem__(self, layer_idx: int) -> KV: + cache_idx = int(self._layer_cache_map[layer_idx].item()) + end_idx = int(self._current_idx[cache_idx].item()) + + if self.use_memory_efficient_attention: + return ( + self._keys[cache_idx, :, :end_idx, :, :], + self._values[cache_idx, :, :end_idx, :, :], + ) + else: + return ( + self._keys[cache_idx, :, :, :end_idx, :], + self._values[cache_idx, :, :, :end_idx, :], + ) + + def current_len(self, cache_idx: int) -> int: + return ( + int(self._current_idx[cache_idx].item()) + if self._current_idx.numel() > 0 + else 0 + ) + + def seq_len(self, layer_idx: int) -> int: + cache_idx = int(self._layer_cache_map[layer_idx].item()) + return self.current_len(cache_idx) + + def append(self, layer_idx: int, kv: KV): + cache_idx = int(self._layer_cache_map[layer_idx].item()) + keys, values = kv + + assert keys.shape == values.shape, "keys and values must have the same shape" + assert ( + keys.shape[0] == self.batch_size * self.num_variates + ), "keys and values must have batch_size * num_variates as their first dimension" + + if self.use_memory_efficient_attention: + assert keys.shape[2] == self.num_heads + else: + assert keys.shape[1] == self.num_heads + assert keys.shape[3] == self.embed_dim // self.num_heads + + start_idx = self._current_idx[cache_idx] + if self.use_memory_efficient_attention: + end_idx = start_idx + keys.shape[1] + else: + end_idx = start_idx + keys.shape[2] + assert ( + end_idx <= self.max_seq_len + ), f"max_seq_len exceeded {end_idx} > {self.max_seq_len}, keys.shape: {keys.shape}" + + if self.use_memory_efficient_attention: + self._keys[cache_idx, :, start_idx:end_idx, :, :] = keys + self._values[cache_idx, :, start_idx:end_idx, :, :] = values + else: + self._keys[cache_idx, :, :, start_idx:end_idx, :] = keys + self._values[cache_idx, :, :, start_idx:end_idx, :] = values + + self._current_idx[cache_idx] = end_idx + + def 
reset(self): + self._keys.zero_() + self._values.zero_() + self._current_idx.zero_() diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py new file mode 100644 index 000000000000..08fda1c3c723 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py @@ -0,0 +1,167 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import json +import os + +import safetensors.torch as safetorch +from transformers import PreTrainedModel + +from iotdb.ainode.core.log import Logger + +from .configuration_toto import TotoConfig +from .model.attention import XFORMERS_AVAILABLE +from .model.backbone import TotoBackbone +from .model.toto import Toto +from .model.transformer import XFORMERS_SWIGLU_AVAILABLE + +logger = Logger() + + +class TotoPreTrainedModel(PreTrainedModel): + """Abstract base class for all Toto model variants.""" + + config_class = TotoConfig + base_model_prefix = "model" + supports_gradient_checkpointing = False + + def _init_weights(self, module): + # Weights are loaded from the pretrained checkpoint; no random initialisation needed. 
        pass


class TotoForPrediction(TotoPreTrainedModel):
    """
    Toto (Timeseries-Optimized Transformer for Observability) model for time series prediction.

    Integrates the Toto backbone with AINode's model loading mechanism using the
    transformers PreTrainedModel interface. Weights are loaded directly from the
    Datadog/Toto-Open-Base-1.0 safetensors checkpoint.

    The backbone is stored as ``self.model`` so that safetensors key prefixes
    (``model.*``) map directly to parameters without any renaming.

    Reference: https://huggingface.co/Datadog/Toto-Open-Base-1.0
    """

    def __init__(self, config: TotoConfig):
        super().__init__(config)
        # Backbone stored as self.model so safetensors keys (model.*) match directly.
        self.model = TotoBackbone(
            patch_size=config.patch_size,
            stride=config.stride,
            embed_dim=config.embed_dim,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            mlp_hidden_dim=config.mlp_hidden_dim,
            dropout=config.dropout,
            spacewise_every_n_layers=config.spacewise_every_n_layers,
            scaler_cls=config.scaler_cls,
            output_distribution_classes=config.output_distribution_classes,
            output_distribution_kwargs=config.output_distribution_kwargs,
            spacewise_first=config.spacewise_first,
            use_memory_efficient_attention=config.use_memory_efficient_attention,
            stabilize_with_global=config.stabilize_with_global,
            scale_factor_exponent=config.scale_factor_exponent,
        )
        # transformers post-construction hook (weight tying, init bookkeeping).
        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load TotoForPrediction from a local directory containing ``config.json``
        and ``model.safetensors``.

        This override is required because:
        1. The safetensors file uses legacy SwiGLU key names that need remapping.
        2. The config uses class-path strings for ``scaler_cls`` and
           ``output_distribution_classes`` that must not be filtered out.

        Args:
            pretrained_model_name_or_path (str): Path to a local directory.
            **kwargs: Extra key/value pairs merged into the config before construction.

        Returns:
            TotoForPrediction: Fully initialised and weight-loaded model in eval mode.

        Raises:
            ValueError: If ``pretrained_model_name_or_path`` is not a local directory.
            FileNotFoundError: If ``model.safetensors`` is missing from the directory.
        """
        if os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, "config.json")
            safetensors_file = os.path.join(
                pretrained_model_name_or_path, "model.safetensors"
            )
        else:
            # Unlike the stock transformers loader, hub repo IDs are not accepted here.
            raise ValueError(
                f"pretrained_model_name_or_path must be a local directory, "
                f"got: {pretrained_model_name_or_path}"
            )

        # ── Load config ──────────────────────────────────────────────────────
        # A missing config.json yields an empty dict, i.e. all TotoConfig defaults.
        config_dict: dict = {}
        if os.path.exists(config_file):
            with open(config_file, "r") as f:
                config_dict = json.load(f)
        # Caller-supplied kwargs take precedence over on-disk config values.
        config_dict.update(kwargs)

        # Disable xFormers memory-efficient attention if the library is absent.
        if not XFORMERS_AVAILABLE and config_dict.get(
            "use_memory_efficient_attention", True
        ):
            config_dict["use_memory_efficient_attention"] = False

        config = TotoConfig(**config_dict)

        # ── Instantiate model ─────────────────────────────────────────────────
        instance = cls(config)

        # ── Load safetensors weights ──────────────────────────────────────────
        if not os.path.exists(safetensors_file):
            raise FileNotFoundError(
                f"Model checkpoint not found at: {safetensors_file}"
            )

        # Load on CPU; the caller moves the model to its target device.
        state_dict = safetorch.load_file(safetensors_file, device="cpu")

        # Remap SwiGLU weight names if the fused xFormers kernel is available.
        use_fused_swiglu = XFORMERS_SWIGLU_AVAILABLE and not config_dict.get(
            "pre_xformers_checkpoint", False
        )
        state_dict = Toto._map_state_dict_keys(state_dict, use_fused_swiglu)

        # Filter to keys that exist in the model, skipping cached rotary buffers.
        model_state = instance.state_dict()
        filtered_state_dict = {
            k: v
            for k, v in state_dict.items()
            if k in model_state and not k.endswith("rotary_emb.freqs")
        }

        # strict=False: keys dropped by the filter above must not abort loading.
        instance.load_state_dict(filtered_state_dict, strict=False)
        instance.eval()

        logger.info(f"Loaded Toto model from {pretrained_model_name_or_path}")
        return instance

    @property
    def backbone(self):
        """The underlying ``TotoBackbone`` used for inference."""
        return self.model

    @property
    def device(self):
        """Device on which model parameters reside."""
        return next(self.parameters()).device

# ===========================================================================
# New file: iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py
# ===========================================================================
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

import torch

from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.model.toto.data.util.dataset import MaskedTimeseries
from iotdb.ainode.core.model.toto.inference.forecaster import TotoForecaster

logger = Logger()


class TotoPipeline(ForecastPipeline):
    """
    Forecasting pipeline that adapts the Toto foundation model to AINode.

    Input dicts are turned into the ``MaskedTimeseries`` structures consumed by
    ``TotoForecaster``, which performs the autoregressive decoding. The
    forecaster itself is built on demand, so merely constructing the pipeline
    (e.g. at import / registration time) never touches a live model.
    """

    def __init__(self, model_info, **model_kwargs):
        super().__init__(model_info, **model_kwargs)
        # Deferred: a live model is only required once forecast() is called.
        self._forecaster: TotoForecaster | None = None

    def _get_forecaster(self) -> TotoForecaster:
        """Build the forecaster on first use; reuse the cached one afterwards."""
        if self._forecaster is not None:
            return self._forecaster
        self._forecaster = TotoForecaster(self.model.backbone)
        return self._forecaster

    def _preprocess(self, inputs, **infer_kwargs):
        """
        Convert raw input dicts into ``MaskedTimeseries`` structures.

        Parameters
        ----------
        inputs : list of dict
            Each dict carries 'targets': a 1-D ``(input_length,)`` or 2-D
            ``(target_count, input_length)`` tensor. Covariate keys, if
            present, are ignored with a warning.
        infer_kwargs
            Additional inference options such as `output_length` (int).

        Returns
        -------
        list of MaskedTimeseries
            One entry per input dict, ready for ``TotoForecaster``.
        """
        prepared = []
        for sample in inputs:
            series = sample["targets"]
            # Promote 1-D series to the (variates, length) layout Toto expects.
            series = series.unsqueeze(0) if series.ndim == 1 else series

            num_series, length = series.shape
            dev = series.device

            if "past_covariates" in sample or "future_covariates" in sample:
                logger.warning(
                    "TotoPipeline does not support covariates; they will be ignored."
                )

            # NaN positions mark padding; record them, then neutralise the NaNs.
            valid_mask = ~torch.isnan(series)
            series = series.nan_to_num(0.0)

            # Single-group id mask, synthetic per-step timestamps, unit interval.
            group_ids = torch.zeros(
                num_series, length, dtype=torch.long, device=dev
            )
            step_stamps = (
                torch.arange(length, dtype=torch.long, device=dev)
                .unsqueeze(0)
                .expand(num_series, length)
            )
            unit_intervals = torch.ones(num_series, dtype=torch.long, device=dev)

            prepared.append(
                MaskedTimeseries(
                    series=series,
                    padding_mask=valid_mask,
                    id_mask=group_ids,
                    timestamp_seconds=step_stamps,
                    time_interval_seconds=unit_intervals,
                )
            )

        return prepared

    def forecast(self, inputs, **infer_kwargs) -> list[torch.Tensor]:
        """
        Run autoregressive forecasting over preprocessed inputs.

        Recognised ``infer_kwargs``: `output_length` (default 96),
        `num_samples` (default None), `samples_per_batch` (default 10).
        Returns one mean-forecast tensor per input series.
        """
        horizon = infer_kwargs.get("output_length", 96)
        num_samples = infer_kwargs.get("num_samples", None)
        per_batch = infer_kwargs.get("samples_per_batch", 10)

        forecaster = self._get_forecaster()
        device = self.model.device

        results = []
        for ts in inputs:
            # Move every field of the MaskedTimeseries onto the model's device.
            ts = ts._replace(
                series=ts.series.to(device),
                padding_mask=ts.padding_mask.to(device),
                id_mask=ts.id_mask.to(device),
                timestamp_seconds=ts.timestamp_seconds.to(device),
                time_interval_seconds=ts.time_interval_seconds.to(device),
            )
            prediction = forecaster.forecast(
                ts,
                prediction_length=horizon,
                num_samples=num_samples,
                samples_per_batch=per_batch,
            )
            mean = prediction.mean
            # Drop a leading singleton batch dimension, if present.
            if mean.ndim == 3 and mean.shape[0] == 1:
                mean = mean.squeeze(0)
            results.append(mean)

        return results

    def _postprocess(self, outputs, **infer_kwargs) -> list[torch.Tensor]:
        """Toto forecasts need no further transformation."""
        return outputs

# ===========================================================================
# Trailing hunk in this patch: iotdb-core/ainode/pyproject.toml adds, after
# jaxtyping = ">=0.2.24", the dependency line:
#     rotary-embedding-torch = ">=0.8.0"
# ===========================================================================