|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# |
| 18 | + |
import torch

from iotdb.ainode.core.log import Logger

# Module-level logger instance (AINode's project logging facility) used by
# this module's model-loading messages.
logger = Logger()
| 24 | + |
| 25 | + |
class TotoForPrediction(torch.nn.Module):
    """
    AINode-facing wrapper around the Toto time-series foundation model.

    Toto (Time Series Optimized Transformer for Observability) is a 151M
    parameter foundation model for multivariate time series forecasting.
    This wrapper defers all model loading to the ``toto-ts`` package while
    exposing an interface compatible with AINode's model-loading machinery.

    The underlying Toto model implements ``from_pretrained`` through
    ``huggingface_hub.ModelHubMixin`` rather than the usual
    ``transformers.PreTrainedModel`` route; this class bridges that gap.

    Reference: https://huggingface.co/Datadog/Toto-Open-Base-1.0
    """

    def __init__(self, toto_model):
        """
        Wrap an already-loaded Toto model instance.

        Args:
            toto_model: A ``toto.model.toto.Toto`` instance.
        """
        super().__init__()
        self.toto = toto_model

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Build a wrapper from a local checkpoint or a HuggingFace Hub repo.

        Loading is handled entirely by ``toto.model.toto.Toto.from_pretrained()``,
        which relies on ``ModelHubMixin`` to resolve weights and configuration.

        Args:
            pretrained_model_name_or_path (str): Either a local directory
                holding ``config.json`` and ``model.safetensors``, or a
                HuggingFace Hub repo ID such as ``Datadog/Toto-Open-Base-1.0``.
            **kwargs: Forwarded verbatim to the underlying loader.

        Returns:
            TotoForPrediction: Wrapper holding the freshly loaded model.
        """
        from toto.model.toto import Toto

        loaded = Toto.from_pretrained(pretrained_model_name_or_path, **kwargs)
        logger.info(f"Loaded Toto model from {pretrained_model_name_or_path}")
        return cls(loaded)

    @classmethod
    def from_config(cls, config):
        """
        Instantiate a fresh (untrained) Toto model from a configuration.

        Args:
            config: A ``TotoConfig`` or any object exposing the same
                attribute names; attributes absent from ``config`` fall
                back to the hard-coded values below.

        Returns:
            TotoForPrediction: Wrapper around the newly initialized model.
        """
        from toto.model.toto import Toto

        # Fallback values applied for any attribute ``config`` does not
        # define; each key maps directly to a Toto constructor kwarg.
        defaults = {
            "patch_size": 32,
            "stride": 32,
            "embed_dim": 1024,
            "num_layers": 18,
            "num_heads": 16,
            "mlp_hidden_dim": 2816,
            "dropout": 0.0,
            "spacewise_every_n_layers": 3,
            "scaler_cls": "per_variate_causal",
            "output_distribution_classes": ["student_t_mixture"],
            "spacewise_first": True,
            "use_memory_efficient_attention": True,
            "stabilize_with_global": True,
            "scale_factor_exponent": 10.0,
        }
        params = {
            name: getattr(config, name, fallback)
            for name, fallback in defaults.items()
        }
        return cls(Toto(**params))

    @property
    def backbone(self):
        """
        The underlying ``TotoBackbone`` instance that performs inference.
        """
        return self.toto.model

    @property
    def device(self):
        """
        torch.device: Device on which the wrapped model's parameters live.
        """
        return self.toto.device
0 commit comments