
Commit 17d4276

[AINode] Add Toto time series foundation model as a built-in model
Integrate Datadog's Toto-Open-Base-1.0 model into AINode's builtin model registry.

- Add TotoConfig (PretrainedConfig) with Toto architecture params
- Add TotoForPrediction wrapper (from_pretrained via ModelHubMixin)
- Add TotoPipeline (ForecastPipeline) with lazy toto-ts import
- Register 'toto' in BUILTIN_HF_TRANSFORMERS_MODEL_MAP
- Add 'toto' entry to AINodeTestUtils.BUILTIN_LTSM_MAP

toto-ts is an optional dependency (lazy import with clear install instructions); no changes to pyproject.toml or poetry.lock required.
1 parent 49b0c35 commit 17d4276
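The "lazy import with clear install instructions" mentioned in the message usually takes the shape below. This is a minimal illustrative sketch, not the committed TotoPipeline code (which is not part of this diff); the helper name _load_toto_class and the exact error text are assumptions.

def _load_toto_class():
    """Lazily import the optional toto-ts dependency.

    Hypothetical helper illustrating the pattern the commit message
    describes; the real pipeline code is not shown in this diff.
    """
    try:
        from toto.model.toto import Toto  # provided by the toto-ts package
    except ImportError as exc:
        raise ImportError(
            "The built-in 'toto' model requires the optional dependency "
            "toto-ts. Install it with: pip install toto-ts"
        ) from exc
    return Toto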

3 files changed

Lines changed: 206 additions & 2 deletions

File tree

iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
configuration_toto.py
modeling_toto.py

iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
Lines changed: 5 additions & 2 deletions
@@ -165,7 +165,10 @@ def __repr__(self):
         model_type="toto",
         pipeline_cls="pipeline_toto.TotoPipeline",
         repo_id="Datadog/Toto-Open-Base-1.0",
-        auto_map=None,
-        transformers_registered=False,
+        auto_map={
+            "AutoConfig": "configuration_toto.TotoConfig",
+            "AutoModelForCausalLM": "modeling_toto.TotoForPrediction",
+        },
+        transformers_registered=True,
     ),
 }
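For reference, auto_map entries of this shape follow the transformers dynamic-module convention (module_name.ClassName), and transformers_registered=True implies the registry binds the classes to the auto-factories. A minimal sketch of that binding, assuming both modules are importable; this is not AINode's actual registration code:

from transformers import AutoConfig, AutoModelForCausalLM

from configuration_toto import TotoConfig
from modeling_toto import TotoForPrediction

# Bind model_type="toto" to the config class, then bind the config class
# to the prediction wrapper, so the Auto* factories can resolve "toto".
AutoConfig.register("toto", TotoConfig)
AutoModelForCausalLM.register(TotoConfig, TotoForPrediction)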
configuration_toto.py
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from typing import List, Optional

from transformers import PretrainedConfig


class TotoConfig(PretrainedConfig):
    """
    Configuration class for the Toto time series forecasting model.

    Toto (Time Series Optimized Transformer for Observability) is a foundation model
    for multivariate time series forecasting developed by Datadog. It uses a decoder-only
    architecture with per-variate patch-based causal scaling, proportional factorized
    space-time attention, and a Student-T mixture prediction head.

    Reference: https://github.com/DataDog/toto
    """

    model_type = "toto"

    def __init__(
        self,
        patch_size: int = 32,
        stride: int = 32,
        embed_dim: int = 1024,
        num_layers: int = 18,
        num_heads: int = 16,
        mlp_hidden_dim: int = 2816,
        dropout: float = 0.0,
        spacewise_every_n_layers: int = 3,
        scaler_cls: str = "per_variate_causal",
        output_distribution_classes: Optional[List[str]] = None,
        spacewise_first: bool = True,
        use_memory_efficient_attention: bool = True,
        stabilize_with_global: bool = True,
        scale_factor_exponent: float = 10.0,
        **kwargs,
    ):
        self.patch_size = patch_size
        self.stride = stride
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_hidden_dim = mlp_hidden_dim
        self.dropout = dropout
        self.spacewise_every_n_layers = spacewise_every_n_layers
        self.scaler_cls = scaler_cls
        self.output_distribution_classes = output_distribution_classes or [
            "student_t_mixture"
        ]
        self.spacewise_first = spacewise_first
        self.use_memory_efficient_attention = use_memory_efficient_attention
        self.stabilize_with_global = stabilize_with_global
        self.scale_factor_exponent = scale_factor_exponent

        super().__init__(**kwargs)
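Since TotoConfig subclasses transformers.PretrainedConfig, it inherits the standard config.json round-trip for free. A quick usage sketch (the output directory name is arbitrary):

from configuration_toto import TotoConfig

# Defaults mirror the Toto-Open-Base-1.0 architecture (1024-dim, 18 layers).
config = TotoConfig()
assert config.model_type == "toto"
assert config.output_distribution_classes == ["student_t_mixture"]

# Inherited PretrainedConfig serialization: writes config.json to the directory.
config.save_pretrained("./toto_config")
restored = TotoConfig.from_pretrained("./toto_config")
assert restored.mlp_hidden_dim == 2816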
modeling_toto.py
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import torch

from iotdb.ainode.core.log import Logger

logger = Logger()


class TotoForPrediction(torch.nn.Module):
    """
    Wrapper around the Toto model for AINode integration.

    Toto (Time Series Optimized Transformer for Observability) is a 151M-parameter
    foundation model for multivariate time series forecasting. This wrapper delegates
    model loading to the ``toto-ts`` package while providing a compatible interface
    for AINode's model loading mechanism.

    The underlying Toto model uses ``huggingface_hub.ModelHubMixin`` for ``from_pretrained``
    support, which differs from the standard ``transformers.PreTrainedModel`` pattern.
    This wrapper bridges that gap.

    Reference: https://huggingface.co/Datadog/Toto-Open-Base-1.0
    """

    def __init__(self, toto_model):
        """
        Initialize the wrapper with a loaded Toto model instance.

        Args:
            toto_model: A ``toto.model.toto.Toto`` instance.
        """
        super().__init__()
        self.toto = toto_model

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load a Toto model from a local directory or HuggingFace Hub repository.

        This delegates to ``toto.model.toto.Toto.from_pretrained()``, which uses
        ``ModelHubMixin`` to load the model weights and configuration.

        Args:
            pretrained_model_name_or_path (str): Path to a local directory containing
                ``config.json`` and ``model.safetensors``, or a HuggingFace Hub repo ID
                (e.g., ``Datadog/Toto-Open-Base-1.0``).
            **kwargs: Additional keyword arguments passed to the underlying loader.

        Returns:
            TotoForPrediction: A wrapper instance containing the loaded Toto model.
        """
        # Lazy import: toto-ts is an optional dependency of AINode.
        from toto.model.toto import Toto

        toto_model = Toto.from_pretrained(pretrained_model_name_or_path, **kwargs)
        logger.info(f"Loaded Toto model from {pretrained_model_name_or_path}")
        return cls(toto_model)

    @classmethod
    def from_config(cls, config):
        """
        Create a Toto model from a configuration (for training from scratch).

        Args:
            config: A ``TotoConfig`` or compatible configuration object.

        Returns:
            TotoForPrediction: A wrapper instance containing a newly initialized Toto model.
        """
        # Lazy import: toto-ts is an optional dependency of AINode.
        from toto.model.toto import Toto

        # Fall back to the Toto-Open-Base-1.0 defaults for any missing attribute.
        toto_model = Toto(
            patch_size=getattr(config, "patch_size", 32),
            stride=getattr(config, "stride", 32),
            embed_dim=getattr(config, "embed_dim", 1024),
            num_layers=getattr(config, "num_layers", 18),
            num_heads=getattr(config, "num_heads", 16),
            mlp_hidden_dim=getattr(config, "mlp_hidden_dim", 2816),
            dropout=getattr(config, "dropout", 0.0),
            spacewise_every_n_layers=getattr(config, "spacewise_every_n_layers", 3),
            scaler_cls=getattr(config, "scaler_cls", "per_variate_causal"),
            output_distribution_classes=getattr(
                config, "output_distribution_classes", ["student_t_mixture"]
            ),
            spacewise_first=getattr(config, "spacewise_first", True),
            use_memory_efficient_attention=getattr(
                config, "use_memory_efficient_attention", True
            ),
            stabilize_with_global=getattr(config, "stabilize_with_global", True),
            scale_factor_exponent=getattr(config, "scale_factor_exponent", 10.0),
        )
        return cls(toto_model)

    @property
    def backbone(self):
        """
        Access the underlying TotoBackbone model used for inference.

        Returns:
            The ``TotoBackbone`` instance from the Toto model.
        """
        return self.toto.model

    @property
    def device(self):
        """
        Get the device of the model parameters.

        Returns:
            torch.device: The device where the model parameters reside.
        """
        return self.toto.device
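Putting the wrapper together, loading and device handling would look roughly like this (requires the optional toto-ts package; the actual forecasting call goes through toto-ts APIs outside this diff):

import torch

from modeling_toto import TotoForPrediction

# Download (or load from cache) the published 151M-parameter checkpoint.
model = TotoForPrediction.from_pretrained("Datadog/Toto-Open-Base-1.0")

# The wrapped Toto instance is a registered submodule, so .to() moves it too.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()

print(model.device)        # forwarded from the underlying Toto model
backbone = model.backbone  # TotoBackbone used for inference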
