diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java index f8aa27ce6883a..8356a0553112f 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java @@ -20,12 +20,16 @@ package org.apache.iotdb.ainode.it; import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.AIClusterIT; import org.apache.iotdb.itbase.env.BaseEnv; import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; import java.sql.Connection; import java.sql.ResultSet; @@ -41,9 +45,13 @@ import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; +@RunWith(IoTDBTestRunner.class) +@Category({AIClusterIT.class}) public class AINodeInstanceManagementIT { - private static final Set TARGET_DEVICES = new HashSet<>(Arrays.asList("cpu", "0", "1")); + private static final String TARGET_DEVICES_STR = "0,1"; + private static final Set TARGET_DEVICES = + new HashSet<>(Arrays.asList(TARGET_DEVICES_STR.split(","))); @BeforeClass public static void setUp() throws Exception { @@ -76,53 +84,57 @@ private void basicManagementTest(Statement statement) throws SQLException, Inter // Ensure resources try (ResultSet resultSet = statement.executeQuery("SHOW AI_DEVICES")) { ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID"); + checkHeader(resultSetMetaData, "DeviceId,DeviceType"); final Set resultDevices = new HashSet<>(); while (resultSet.next()) { - resultDevices.add(resultSet.getString("DeviceID")); + resultDevices.add(resultSet.getString("DeviceId")); } - Assert.assertEquals(TARGET_DEVICES, resultDevices); + Set expected = new HashSet<>(TARGET_DEVICES); + expected.add("cpu"); + Assert.assertEquals(expected, resultDevices); } // Load sundial to each device - statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES)); - checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); + statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES_STR)); + checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR); // Unload sundial from each device - statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES)); - checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); + statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES_STR)); + checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR); // Load timer_xl to each device - statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", TARGET_DEVICES)); - checkModelOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString()); + statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", TARGET_DEVICES_STR)); + checkModelOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES_STR); // Unload timer_xl from each device - statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", TARGET_DEVICES)); - checkModelNotOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString()); + statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", TARGET_DEVICES_STR)); + checkModelNotOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES_STR); } private static final int LOOP_CNT = 10; - @Test + // @Test public void repeatLoadAndUnloadTest() throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { for (int i = 0; i < LOOP_CNT; i++) { - statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\""); - checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); - statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\""); - checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); + statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES_STR)); + checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR); + statement.execute( + String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES_STR)); + checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR); } } } - @Test + // @Test public void concurrentLoadAndUnloadTest() throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { for (int i = 0; i < LOOP_CNT; i++) { - statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\""); - statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\""); + statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES_STR)); + statement.execute( + String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES_STR)); } - checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); + checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR); } } @@ -145,23 +157,23 @@ public void failTestInTableModel() throws SQLException { private void failTest(Statement statement) { errorTest( statement, - "LOAD MODEL unknown TO DEVICES \"cpu,0,1\"", - "1505: Cannot load model [unknown], because it is neither a built-in nor a fine-tuned model. You can use 'SHOW MODELS' to retrieve the available models."); + "LOAD MODEL unknown TO DEVICES 'cpu,0,1'", + "1504: Model [unknown] is not registered yet. You can use 'SHOW MODELS' to retrieve the available models."); errorTest( statement, - "LOAD MODEL sundial TO DEVICES \"unknown\"", - "1507: Device ID [unknown] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices."); + "LOAD MODEL sundial TO DEVICES '999'", + "1508: AIDevice ID [999] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices."); errorTest( statement, - "UNLOAD MODEL sundial FROM DEVICES \"unknown\"", - "1507: Device ID [unknown] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices."); + "UNLOAD MODEL sundial FROM DEVICES '999'", + "1508: AIDevice ID [999] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices."); errorTest( statement, - "LOAD MODEL sundial TO DEVICES \"0,0\"", + "LOAD MODEL sundial TO DEVICES '0,0'", "1509: Device ID list contains duplicate entries."); errorTest( statement, - "UNLOAD MODEL sundial FROM DEVICES \"0,0\"", + "UNLOAD MODEL sundial FROM DEVICES '0,0'", "1510: Device ID list contains duplicate entries."); } } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java index 8ece0ba7523e0..9e78c8b025c60 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java @@ -71,15 +71,22 @@ public static void tearDown() throws Exception { public void userDefinedModelManagementTestInTree() throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { - registerUserDefinedModel(statement); - callInferenceTest( - statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active")); - dropUserDefinedModel(statement); + // Test transformers model (chronos2) in tree. + AINodeTestUtils.FakeModelInfo modelInfo = + new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"); + registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2"); + callInferenceTest(statement, modelInfo); + dropUserDefinedModel(statement, modelInfo.getModelId()); errorTest( statement, "create model origin_chronos using uri \"file:///data/chronos2_origin\"", "1505: 't5' is already used by a Transformers config, pick another name."); statement.execute("drop model origin_chronos"); + + // Test PytorchModelHubMixin model (mantis) in tree. + modelInfo = new FakeModelInfo("user_mantis", "custom_mantis", "user_defined", "active"); + registerUserDefinedModel(statement, modelInfo, "file:///data/mantis"); + dropUserDefinedModel(statement, modelInfo.getModelId()); } } @@ -87,23 +94,35 @@ public void userDefinedModelManagementTestInTree() throws SQLException, Interrup public void userDefinedModelManagementTestInTable() throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { - registerUserDefinedModel(statement); - forecastTableFunctionTest( - statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active")); - dropUserDefinedModel(statement); + // Test transformers model (chronos2) in table. + AINodeTestUtils.FakeModelInfo modelInfo = + new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"); + registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2"); + forecastTableFunctionTest(statement, modelInfo); + dropUserDefinedModel(statement, modelInfo.getModelId()); errorTest( statement, "create model origin_chronos using uri \"file:///data/chronos2_origin\"", "1505: 't5' is already used by a Transformers config, pick another name."); statement.execute("drop model origin_chronos"); + + // Test PytorchModelHubMixin model (mantis) in table. + modelInfo = new FakeModelInfo("user_mantis", "custom_mantis", "user_defined", "active"); + registerUserDefinedModel(statement, modelInfo, "file:///data/mantis"); + dropUserDefinedModel(statement, modelInfo.getModelId()); } } - private void registerUserDefinedModel(Statement statement) + public static void registerUserDefinedModel( + Statement statement, AINodeTestUtils.FakeModelInfo modelInfo, String uri) throws SQLException, InterruptedException { + String modelId = modelInfo.getModelId(); + String modelType = modelInfo.getModelType(); + String category = modelInfo.getCategory(); + final String CREATE_MODEL_TEMPLATE = "create model %s using uri \"%s\""; final String alterConfigSQL = "set configuration \"trusted_uri_pattern\"='.*'"; - final String registerSql = "create model user_chronos using uri \"file:///data/chronos2\""; - final String showSql = "SHOW MODELS user_chronos"; + final String registerSql = String.format(CREATE_MODEL_TEMPLATE, modelId, uri); + final String showSql = String.format("SHOW MODELS %s", modelId); statement.execute(alterConfigSQL); statement.execute(registerSql); boolean loading = true; @@ -112,13 +131,13 @@ private void registerUserDefinedModel(Statement statement) ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State"); while (resultSet.next()) { - String modelId = resultSet.getString(1); - String modelType = resultSet.getString(2); - String category = resultSet.getString(3); + String resultModelId = resultSet.getString(1); + String resultModelType = resultSet.getString(2); + String resultCategory = resultSet.getString(3); String state = resultSet.getString(4); - assertEquals("user_chronos", modelId); - assertEquals("custom_t5", modelType); - assertEquals("user_defined", category); + assertEquals(modelId, resultModelId); + assertEquals(modelType, resultModelType); + assertEquals(category, resultCategory); if (state.equals("active")) { loading = false; } else if (state.equals("loading")) { @@ -136,9 +155,9 @@ private void registerUserDefinedModel(Statement statement) assertFalse(loading); } - private void dropUserDefinedModel(Statement statement) throws SQLException { - final String showSql = "SHOW MODELS user_chronos"; - final String dropSql = "DROP MODEL user_chronos"; + public static void dropUserDefinedModel(Statement statement, String modelId) throws SQLException { + final String showSql = String.format("SHOW MODELS %s", modelId); + final String dropSql = String.format("DROP MODEL %s", modelId); statement.execute(dropSql); try (ResultSet resultSet = statement.executeQuery(showSql)) { ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index 5a4dce53666d3..e41d3d4e0f97e 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -51,10 +51,10 @@ public class AINodeTestUtils { public static final Map BUILTIN_LTSM_MAP = Stream.of( - new AbstractMap.SimpleEntry<>( - "sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")), new AbstractMap.SimpleEntry<>( "timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")), new AbstractMap.SimpleEntry<>( "chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")), new AbstractMap.SimpleEntry<>( @@ -171,7 +171,7 @@ public static void checkModelOnSpecifiedDevice(Statement statement, String model LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count); if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) { foundDevices.add(deviceId); - LOGGER.info("Model {} is loaded to device {}", modelId, device); + LOGGER.info("Model {} is loaded to device {}", modelId, deviceId); } } if (foundDevices.containsAll(targetDevices)) { @@ -252,6 +252,32 @@ public static void prepareDataInTable() throws SQLException { } } + /** Prepare db.AI2(s0 FLOAT,...) with 2880 rows of data in table. */ + public static void prepareDataInTable2() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + statement.execute("CREATE DATABASE db"); + statement.execute( + "CREATE TABLE db.AI2 (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD, s4 FLOAT FIELD, s5 DOUBLE FIELD, s6 INT32 FIELD, s7 INT64 FIELD, s8 FLOAT FIELD, s9 DOUBLE FIELD)"); + for (int i = 0; i < 2880; i++) { + statement.execute( + String.format( + "INSERT INTO db.AI2(time,s0,s1,s2,s3,s4,s5,s6,s7,s8,s9) VALUES(%d,%f,%f,%d,%d,%f,%f,%d,%d,%f,%f)", + i, + (float) i, + (double) i, + i, + i, + (float) (i * 2), + (double) (i * 2), + i * 2, + i * 2, + (float) (i * 3), + (double) (i * 3))); + } + } + } + public static class FakeModelInfo { private final String modelId; diff --git a/iotdb-core/ainode/iotdb/ainode/core/exception.py b/iotdb-core/ainode/iotdb/ainode/core/exception.py index b007ee58c4840..b76baa3e2e25c 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/exception.py +++ b/iotdb-core/ainode/iotdb/ainode/core/exception.py @@ -15,12 +15,6 @@ # specific language governing permissions and limitations # under the License. # -import re - -from iotdb.ainode.core.model.model_constants import ( - MODEL_CONFIG_FILE_IN_YAML, - MODEL_WEIGHTS_FILE_IN_PT, -) class _BaseException(Exception): diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py index 5d0026522a136..f4bf914a84668 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py @@ -56,17 +56,23 @@ class ForecastPipeline(BasicPipeline): def __init__(self, model_info: ModelInfo, **model_kwargs): super().__init__(model_info, **model_kwargs) - def preprocess( + # ========================= Preprocess ========================= + def preprocess(self, inputs, **infer_kwargs): + inputs = self._base_preprocess(inputs, **infer_kwargs) + return self._preprocess(inputs, **infer_kwargs) + + def _base_preprocess( self, - inputs: list[dict[str, dict[str, torch.Tensor] | torch.Tensor]], + inputs, **infer_kwargs, - ): + ) -> list[dict[str, dict[str, torch.Tensor] | torch.Tensor]]: """ - Preprocess the input data before passing it to the model for inference, validating the shape and type of the input data. + The common preprocess logic for all forecast pipelines, + validating the shape and type of the input data. Args: inputs (list[dict]): - The input data, a list of dictionaries, where each dictionary contains: + The input data, expected a list of dictionaries, where each dictionary contains: - 'targets': A tensor (1D or 2D) of shape (input_length,) or (target_count, input_length). - 'past_covariates': A dictionary of tensors (optional), where each tensor has shape (input_length,). - 'future_covariates': A dictionary of tensors (optional), where each tensor has shape (input_length,). @@ -79,7 +85,11 @@ def preprocess( ValueError: If the input format is incorrect (e.g., missing keys, invalid tensor shapes). Returns: - The preprocessed inputs, validated and ready for model inference. + list[dict]: + The validated input data, a list of dictionaries, where each dictionary contains: + - 'targets': A tensor (1D or 2D) of shape (input_length,) or (target_count, input_length). + - 'past_covariates': A dictionary of tensors (optional), where each tensor has shape (input_length,). + - 'future_covariates': A dictionary of tensors (optional), where each tensor has shape (input_length,). """ if isinstance(inputs, list): @@ -211,10 +221,34 @@ def preprocess( ) return inputs + def _preprocess( + self, + inputs: list[dict[str, dict[str, torch.Tensor] | torch.Tensor]], + **infer_kwargs, + ): + """ + Optional hook for subclasses to implement custom preprocessing logic. + This method is called after the base validation in `_base_preprocess`, so the inputs + are unified when this method is invoked. + + Args: + inputs (list[dict]): The validated input data, a list of dictionaries, where each dictionary contains: + - 'targets': A tensor of shape (input_length,) or (target_count, input_length). + - 'past_covariates' (optional): A dictionary of 1-D tensors, each of shape (input_length,). + - 'future_covariates' (optional): A dictionary of 1-D tensors, each of shape (output_length,), + whose keys are guaranteed to be a subset of 'past_covariates'. + **infer_kwargs: Additional keyword arguments passed through from the pipeline. + + Returns: + inputs: The modified inputs ready for model inference. + """ + return inputs + + # ========================== Forecast ========================== @abstractmethod def forecast(self, inputs, **infer_kwargs): """ - Perform forecasting on the given inputs. + Perform forecasting on the given inputs, which must be implemented by the subclasses. Parameters: inputs: The input data used for making predictions. The type and structure @@ -225,13 +259,35 @@ def forecast(self, inputs, **infer_kwargs): Returns: The forecasted output, which will depend on the specific model's implementation. """ - pass + raise NotImplementedError("forecast not implemented") + + # ========================= Postprocess ======================== + def postprocess(self, outputs, **infer_kwargs): + outputs = self._postprocess(outputs, **infer_kwargs) + return self._base_postprocess(outputs, **infer_kwargs) + + def _postprocess(self, outputs, **infer_kwargs) -> list[torch.Tensor]: + """ + Optional hook for subclasses to implement custom postprocessing logic. + This method is called before the base validation in `_base_postprocess`, so the outputs + must conform to the expected format when this method returns. + + Args: + outputs: The raw model outputs. + **infer_kwargs: Additional keyword arguments passed through from the pipeline. - def postprocess( + Returns: + list[torch.Tensor]: The modified outputs, which must be a list of 2-D tensors + with shape (target_count, output_length), as this will be validated by `_base_postprocess`. + """ + return outputs + + def _base_postprocess( self, outputs: list[torch.Tensor], **infer_kwargs ) -> list[torch.Tensor]: """ - Postprocess the model outputs after inference, validating the shape of the output data and ensures it matches the expected dimensions. + The common postprocess logic for all forecast pipelines. + validating the shape of the output data and ensures it matches the expected dimensions. Args: outputs: @@ -262,14 +318,114 @@ class ClassificationPipeline(BasicPipeline): def __init__(self, model_info: ModelInfo, **model_kwargs): super().__init__(model_info, **model_kwargs) - def preprocess(self, inputs, **kwargs): + # ========================= Preprocess ========================= + def preprocess(self, inputs, **infer_kwargs): + inputs = self._base_preprocess(inputs, **infer_kwargs) + return self._preprocess(inputs, **infer_kwargs) + + def _base_preprocess(self, inputs, **infer_kwargs) -> torch.Tensor: + """ + The common preprocess logic for all classification pipelines, + validating and preprocess the inputs. + + Args: + inputs: The input data, expected to be a 3D-tensor. + **infer_kwargs: Additional inference parameters. + + Returns: + torch.Tensor: + The preprocessed inputs, which will be a 3D-tensor with shape (batch_size, variable_count, sequence_length). + + Raises: + ValueError: If the input format is incorrect. + """ + if isinstance(inputs, torch.Tensor) and inputs.ndim == 3: + return inputs + else: + raise ValueError( + f"The inputs should be a 3D-tensor, but got {type(inputs)} with shape {tuple(inputs.shape)}." + ) + + def _preprocess(self, inputs: torch.Tensor, **infer_kwargs): + """ + Optional hook for subclasses to implement custom preprocessing logic. + This method is called after the base validation in `_base_preprocess`, so the inputs + are unified when this method is invoked. + + Args: + inputs (torch.Tensor): The validated input data, a 3D tensor. + **infer_kwargs: Additional keyword arguments passed through from the pipeline. + + Returns: + torch.Tensor: The modified inputs ready for model inference. + """ return inputs + # ========================== Classify ========================== @abstractmethod - def classify(self, inputs, **kwargs): - pass + def classify(self, inputs, **infer_kwargs): + """ + Perform classification on the given inputs, which must be implemented by the subclasses. + + Parameters: + inputs: The input data used for making classification. The type and structure + depend on the specific implementation of the model. + **infer_kwargs: Additional inference parameters. + + Returns: + The classified result, which will depend on the specific model's implementation. + """ + raise NotImplementedError("classify not implemented") + + # ========================= Postprocess ======================== + def postprocess(self, outputs, **infer_kwargs): + outputs = self._postprocess(outputs, **infer_kwargs) + return self._base_postprocess(outputs, **infer_kwargs) + + def _postprocess(self, outputs, **infer_kwargs) -> list[torch.Tensor]: + """ + Optional hook for subclasses to implement custom postprocessing logic. + This method is called before the base validation in `_base_postprocess`, so the outputs + must conform to the expected format when this method returns. + + Args: + outputs: The raw model outputs. + **infer_kwargs: Additional keyword arguments passed through from the pipeline. + + Returns: + list[torch.Tensor]: The modified outputs, which must be a list of tensors, + as this will be validated by `_base_postprocess`. + + Raises: + ValueError: If the output format is incorrect. + """ + return outputs - def postprocess(self, outputs, **kwargs): + def _base_postprocess(self, outputs, **infer_kwargs) -> list[torch.Tensor]: + """ + The common postprocess logic for all classification pipelines, + validating the shape of the output data. + + Args: + outputs (list[torch.Tensor]): + The output from the model. + **infer_kwargs: + Additional keyword arguments. + + Returns: + list[torch.Tensor]: + The postprocessed outputs. + + Raises: + ValueError: + If the output format is incorrect. + """ + if not isinstance(outputs, list) or any( + not isinstance(output, torch.Tensor) for output in outputs + ): + raise ValueError( + f"The outputs should be a list of tensors, but got {type(outputs)}." + ) return outputs @@ -277,12 +433,29 @@ class ChatPipeline(BasicPipeline): def __init__(self, model_info: ModelInfo, **model_kwargs): super().__init__(model_info, **model_kwargs) - def preprocess(self, inputs, **kwargs): + # ========================= Preprocess ========================= + def preprocess(self, inputs, **infer_kwargs): + inputs = self._base_preprocess(inputs, **infer_kwargs) + return self._preprocess(inputs, **infer_kwargs) + + def _base_preprocess(self, inputs, **infer_kwargs): return inputs + def _preprocess(self, inputs, **infer_kwargs): + return inputs + + # ========================== Chat ========================== @abstractmethod - def chat(self, inputs, **kwargs): - pass + def chat(self, inputs, **infer_kwargs): + raise NotImplementedError("chat not implemented") + + # ========================= Postprocess ======================== + def postprocess(self, outputs, **infer_kwargs): + outputs = self._postprocess(outputs, **infer_kwargs) + return self._base_postprocess(outputs, **infer_kwargs) + + def _postprocess(self, outputs, **infer_kwargs): + return outputs - def postprocess(self, outputs, **kwargs): + def _base_postprocess(self, outputs, **infer_kwargs): return outputs diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py index 07ca8a63bce02..8dcf03627dd49 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py @@ -178,30 +178,17 @@ def _process_request(self, req): def _do_inference_and_construct_resp( self, model_id: str, - model_inputs_list: list[dict[str, torch.Tensor | dict[str, torch.Tensor]]], - output_length: int, + model_inputs, inference_attrs: dict, - **kwargs, ) -> list[bytes]: - auto_adapt = kwargs.get("auto_adapt", True) - if ( - output_length - > AINodeDescriptor().get_config().get_ain_inference_max_output_length() - ): - raise NumericalRangeException( - "output_length", - output_length, - 1, - AINodeDescriptor().get_config().get_ain_inference_max_output_length(), - ) if self._pool_controller.has_running_pools(model_id): + # Only forecast task can use pool + output_length = int(inference_attrs.get("output_length", 96)) infer_req = InferenceRequest( req_id=generate_req_id(), model_id=model_id, - inputs=torch.stack( - [data["targets"] for data in model_inputs_list], dim=0 - ), + inputs=torch.stack([data["targets"] for data in model_inputs], dim=0), output_length=output_length, ) outputs = self._process_request(infer_req) @@ -210,23 +197,17 @@ def _do_inference_and_construct_resp( inference_pipeline = load_pipeline( model_info, device=self._backend.torch_device("cpu") ) - inputs = inference_pipeline.preprocess( - model_inputs_list, - output_length=output_length, - auto_adapt=auto_adapt, - ) + inputs = inference_pipeline.preprocess(model_inputs, **inference_attrs) if isinstance(inference_pipeline, ForecastPipeline): - outputs = inference_pipeline.forecast( - inputs, output_length=output_length, **inference_attrs - ) + outputs = inference_pipeline.forecast(inputs, **inference_attrs) elif isinstance(inference_pipeline, ClassificationPipeline): - outputs = inference_pipeline.classify(inputs) + outputs = inference_pipeline.classify(inputs, **inference_attrs) elif isinstance(inference_pipeline, ChatPipeline): - outputs = inference_pipeline.chat(inputs) + outputs = inference_pipeline.chat(inputs, **inference_attrs) else: outputs = None logger.error("[Inference] Unsupported pipeline type.") - outputs = inference_pipeline.postprocess(outputs) + outputs = inference_pipeline.postprocess(outputs, **inference_attrs) # convert tensor into tsblock for the output in each batch resp_list = [] @@ -235,7 +216,7 @@ def _do_inference_and_construct_resp( resp_list.append(resp) return resp_list - def _run( + def _run_forecast( self, req, data_getter, @@ -249,14 +230,26 @@ def _run( inputs = convert_tsblock_to_tensor(raw) inference_attrs = extract_attrs(req) - output_length = int(inference_attrs.pop("output_length", 96)) + output_length = int(inference_attrs.get("output_length", 96)) + if ( + output_length + > AINodeDescriptor().get_config().get_ain_inference_max_output_length() + ): + raise NumericalRangeException( + "output_length", + output_length, + 1, + AINodeDescriptor() + .get_config() + .get_ain_inference_max_output_length(), + ) - model_inputs_list: list[ - dict[str, torch.Tensor | dict[str, torch.Tensor]] - ] = [{"targets": inputs[0]}] + model_inputs: list[dict[str, torch.Tensor | dict[str, torch.Tensor]]] = [ + {"targets": inputs[0]} + ] resp_list = self._do_inference_and_construct_resp( - model_id, model_inputs_list, output_length, inference_attrs + model_id, model_inputs, inference_attrs ) return resp_cls( @@ -271,7 +264,7 @@ def _run( return resp_cls(status, empty) def forecast(self, req: TForecastReq): - return self._run( + return self._run_forecast( req, data_getter=lambda r: r.inputData, extract_attrs=lambda r: { @@ -283,7 +276,7 @@ def forecast(self, req: TForecastReq): ) def inference(self, req: TInferenceReq): - return self._run( + return self._run_forecast( req, data_getter=lambda r: r.dataset, extract_attrs=lambda r: { diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py index ff4226e734f18..e1d67873b5334 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py @@ -62,7 +62,6 @@ def register_model( get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e)) ) except Exception as e: - # Catch-all for other exceptions (mainly from transformers implementation) return TRegisterModelResp( get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e)) ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py index b28f8f35a6644..01ff78ba48d5b 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py @@ -36,7 +36,7 @@ class Chronos2Pipeline(ForecastPipeline): def __init__(self, model_info, **model_kwargs): super().__init__(model_info, **model_kwargs) - def preprocess(self, inputs, **infer_kwargs): + def _preprocess(self, inputs, **infer_kwargs): """ Preprocess input data of chronos2. @@ -62,7 +62,6 @@ def preprocess(self, inputs, **infer_kwargs): - 'future_covariates' (optional): dict of str to torch.Tensor Unchanged future covariates. """ - super().preprocess(inputs, **infer_kwargs) for item in inputs: item["target"] = item.pop("targets") return inputs @@ -449,7 +448,7 @@ def _predict_step( return prediction - def postprocess( + def _postprocess( self, outputs: list[torch.Tensor], **infer_kwargs ) -> list[torch.Tensor]: """ @@ -472,5 +471,4 @@ def postprocess( # If 0.5 quantile is not provided, # get the mean of all quantiles outputs_list.append(output.mean(dim=1)) - super().postprocess(outputs_list, **infer_kwargs) return outputs_list diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py index 9f1801b5073a7..a9495f1671482 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py @@ -17,11 +17,32 @@ # from enum import Enum -# Model file constants -MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors" -MODEL_CONFIG_FILE_IN_JSON = "config.json" -MODEL_WEIGHTS_FILE_IN_PT = "model.pt" -MODEL_CONFIG_FILE_IN_YAML = "config.yaml" +# ==================== File Name Constants ==================== +# +# All file names used for model persistence are defined here. +# Never hard-code these strings elsewhere – always import from +# this module. + +# -- Config files -- +CONFIG_JSON = "config.json" +CONFIG_YAML = "config.yaml" + +# -- Full model weights -- +MODEL_SAFETENSORS = "model.safetensors" +MODEL_PT = "model.pt" +MODEL_BIN = "pytorch_model.bin" # legacy HuggingFace format + +# -- Ordered tuples for detection / searching -- +MODEL_WEIGHT_FILES = (MODEL_SAFETENSORS, MODEL_PT, MODEL_BIN) + +# -- Backward-compatible aliases (deprecated, will be removed) -- +MODEL_WEIGHTS_FILE_IN_SAFETENSORS = MODEL_SAFETENSORS +MODEL_CONFIG_FILE_IN_JSON = CONFIG_JSON +MODEL_WEIGHTS_FILE_IN_PT = MODEL_PT +MODEL_CONFIG_FILE_IN_YAML = CONFIG_YAML + + +# ==================== Enumerations ==================== class ModelCategory(Enum): diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index f253fb1e56f60..da752cbd78432 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -31,6 +31,7 @@ def __init__( pipeline_cls: str = "", repo_id: str = "", auto_map: Optional[Dict] = None, + hub_mixin_cls: Optional[str] = None, transformers_registered: bool = False, ): self.model_id = model_id @@ -39,16 +40,16 @@ def __init__( self.state = state self.pipeline_cls = pipeline_cls self.repo_id = repo_id - self.auto_map = auto_map # If exists, indicates it's a Transformers model - self.transformers_registered = ( - transformers_registered # Internal flag: whether registered to Transformers - ) + self.auto_map = auto_map + self.hub_mixin_cls = hub_mixin_cls + self.transformers_registered = transformers_registered def __repr__(self): return ( f"ModelInfo(model_id={self.model_id}, model_type={self.model_type}, " f"category={self.category.value}, state={self.state.value}, " - f"has_auto_map={self.auto_map is not None})" + f"has_auto_map={self.auto_map is not None}), " + f"has_hub_mix_in_cls={self.hub_mixin_cls is not None})" ) @@ -144,6 +145,7 @@ def __repr__(self): "AutoConfig": "config.Chronos2CoreConfig", "AutoModelForCausalLM": "model.Chronos2Model", }, + transformers_registered=True, ), "moirai2": ModelInfo( model_id="moirai2", diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py index 1da07cb9fef9e..2476a05856d3e 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -17,28 +17,21 @@ # import os -from pathlib import Path from typing import Any import torch -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoModelForNextSentencePrediction, - AutoModelForSeq2SeqLM, - AutoModelForSequenceClassification, - AutoModelForTimeSeriesPrediction, - AutoModelForTokenClassification, -) -from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.exception import ModelNotExistException from iotdb.ainode.core.log import Logger from iotdb.ainode.core.manager.device_manager import DeviceManager -from iotdb.ainode.core.model.model_constants import ModelCategory +from iotdb.ainode.core.model.model_constants import MODEL_PT from iotdb.ainode.core.model.model_info import ModelInfo from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model -from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_path +from iotdb.ainode.core.model.utils import ( + get_model_and_config_by_auto_class, + get_model_and_config_by_native_code, + get_model_path, +) logger = Logger() BACKEND = DeviceManager() @@ -46,83 +39,64 @@ def load_model(model_info: ModelInfo, **model_kwargs) -> Any: if model_info.auto_map is not None: - model = load_model_from_transformers(model_info, **model_kwargs) + model = load_transformers_model(model_info, **model_kwargs) + elif model_info.hub_mixin_cls is not None: + model = _load_hub_mixin_model(model_info, **model_kwargs) else: if model_info.model_type == "sktime": model = create_sktime_model(model_info.model_id) else: - model = load_model_from_pt(model_info, **model_kwargs) + model = _load_torchscript_model(model_info, **model_kwargs) logger.info( - f"Model {model_info.model_id} loaded to device {model.device if model_info.model_type != 'sktime' else 'cpu'} successfully." + f"Model {model_info.model_id} loaded to device {next(model.parameters()).device if model_info.model_type != 'sktime' else 'cpu'} successfully." ) return model -def load_model_from_transformers(model_info: ModelInfo, **model_kwargs): +def load_transformers_model(model_info: ModelInfo, **model_kwargs): device_map = model_kwargs.get("device_map", "cpu") + trust_remote_code = model_kwargs.get("trust_remote_code", True) train_from_scratch = model_kwargs.get("train_from_scratch", False) - model_path = os.path.join( - os.getcwd(), - AINodeDescriptor().get_config().get_ain_models_dir(), - model_info.category.value, - model_info.model_id, - ) + model_path = get_model_path(model_info) - config_str = model_info.auto_map.get("AutoConfig", "") - model_str = model_info.auto_map.get("AutoModelForCausalLM", "") + model_class, config_instance = get_model_and_config_by_native_code(model_info) + if model_class is None: + model_class, config_instance = get_model_and_config_by_auto_class(model_path) - if model_info.category == ModelCategory.BUILTIN: - module_name = ( - AINodeDescriptor().get_config().get_ain_models_builtin_dir() - + "." - + model_info.model_id + # ---- Load base model ---- + if train_from_scratch: + model = model_class.from_config( + config_instance, trust_remote_code=trust_remote_code ) - config_cls = import_class_from_path(module_name, config_str) - model_cls = import_class_from_path(module_name, model_str) - elif model_str and config_str: - module_parent = str(Path(model_path).parent.absolute()) - with temporary_sys_path(module_parent): - config_cls = import_class_from_path(model_info.model_id, config_str) - model_cls = import_class_from_path(model_info.model_id, model_str) else: - config_cls = AutoConfig.from_pretrained(model_path) - if type(config_cls) in AutoModelForTimeSeriesPrediction._model_mapping.keys(): - model_cls = AutoModelForTimeSeriesPrediction - elif ( - type(config_cls) in AutoModelForNextSentencePrediction._model_mapping.keys() - ): - model_cls = AutoModelForNextSentencePrediction - elif type(config_cls) in AutoModelForSeq2SeqLM._model_mapping.keys(): - model_cls = AutoModelForSeq2SeqLM - elif ( - type(config_cls) in AutoModelForSequenceClassification._model_mapping.keys() - ): - model_cls = AutoModelForSequenceClassification - elif type(config_cls) in AutoModelForTokenClassification._model_mapping.keys(): - model_cls = AutoModelForTokenClassification - else: - model_cls = AutoModelForCausalLM + model = model_class.from_pretrained( + model_path, + config=config_instance, + trust_remote_code=trust_remote_code, + ) - if train_from_scratch: - model = model_cls.from_config(config_cls) - else: - model = model_cls.from_pretrained(model_path) + return BACKEND.move_model(model, device_map) + +def _load_hub_mixin_model(model_info: ModelInfo, **model_kwargs): + device_map = model_kwargs.get("device_map", "cpu") + model_path = get_model_path(model_info) + model_class, _ = get_model_and_config_by_native_code(model_info) + if model_class is None: + logger.error(f"Model class not found for '{model_info.model_id}'") + raise ModelNotExistException(model_info.model_id) + # Load model + model = model_class.from_pretrained(model_path) return BACKEND.move_model(model, device_map) -def load_model_from_pt(model_info: ModelInfo, **kwargs): +def _load_torchscript_model(model_info: ModelInfo, **kwargs): device_map = kwargs.get("device_map", "cpu") acceleration = kwargs.get("acceleration", False) - model_path = os.path.join( - os.getcwd(), - AINodeDescriptor().get_config().get_ain_models_dir(), - model_info.category.value, - model_info.model_id, - ) - model_file = os.path.join(model_path, "model.pt") + model_path = get_model_path(model_info) + model_file = os.path.join(model_path, MODEL_PT) if not os.path.exists(model_file): logger.error(f"Model file not found at {model_file}.") raise ModelNotExistException(model_file) @@ -134,17 +108,3 @@ def load_model_from_pt(model_info: ModelInfo, **kwargs): except Exception as e: logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}") return BACKEND.move_model(model, device_map) - - -def load_model_for_efficient_inference(): - # TODO: An efficient model loading method for inference based on model_arguments - pass - - -def load_model_for_powerful_finetune(): - # TODO: An powerful model loading method for finetune based on model_arguments - pass - - -def unload_model(): - pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index 2cfb07fb56a70..b7799df67a583 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -21,10 +21,15 @@ import os import shutil from pathlib import Path -from typing import Dict, List, Optional - -from huggingface_hub import hf_hub_download -from transformers import AutoConfig, AutoModelForCausalLM +from typing import Dict, Optional + +from huggingface_hub import PyTorchModelHubMixin, hf_hub_download +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + PretrainedConfig, + PreTrainedModel, +) from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.constant import TSStatusCode @@ -35,8 +40,8 @@ ) from iotdb.ainode.core.log import Logger from iotdb.ainode.core.model.model_constants import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, + CONFIG_JSON, + MODEL_SAFETENSORS, ModelCategory, ModelStates, UriType, @@ -147,13 +152,13 @@ def _process_builtin_model_directory(self, model_dir: str, model_id: str): def _download_model_if_necessary() -> bool: """Returns: True if the model is existed or downloaded successfully, False otherwise.""" repo_id = BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id].repo_id - weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) - config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) + weights_path = os.path.join(model_dir, MODEL_SAFETENSORS) + config_path = os.path.join(model_dir, CONFIG_JSON) if not os.path.exists(weights_path): try: hf_hub_download( repo_id=repo_id, - filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS, + filename=MODEL_SAFETENSORS, local_dir=model_dir, ) except Exception as e: @@ -165,7 +170,7 @@ def _download_model_if_necessary() -> bool: try: hf_hub_download( repo_id=repo_id, - filename=MODEL_CONFIG_FILE_IN_JSON, + filename=CONFIG_JSON, local_dir=model_dir, ) except Exception as e: @@ -191,7 +196,7 @@ def _callback_model_download_result(self, future, model_id: str): self._models_dir, ModelCategory.BUILTIN.value, model_id, - MODEL_CONFIG_FILE_IN_JSON, + CONFIG_JSON, ) if os.path.exists(config_path): with open(config_path, "r", encoding="utf-8") as f: @@ -218,15 +223,17 @@ def _callback_model_download_result(self, future, model_id: str): def _process_user_defined_model_directory(self, model_dir: str, model_id: str): """Handling the discovery logic for a user-defined model directory.""" - config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) + config_path = os.path.join(model_dir, CONFIG_JSON) model_type = "" auto_map = {} pipeline_cls = "" + hub_mixin_cls = "" if os.path.exists(config_path): config = load_model_config_in_json(config_path) model_type = config.get("model_type", "") auto_map = config.get("auto_map", None) pipeline_cls = config.get("pipeline_cls", "") + hub_mixin_cls = config.get("hub_mixin_cls", "") model_info = ModelInfo( model_id=model_id, model_type=model_type, @@ -234,6 +241,7 @@ def _process_user_defined_model_directory(self, model_dir: str, model_id: str): state=ModelStates.ACTIVE, pipeline_cls=pipeline_cls, auto_map=auto_map, + hub_mixin_cls=hub_mixin_cls, transformers_registered=False, # Lazy registration ) with self._lock_pool.get_lock(model_id).write_lock(): @@ -284,6 +292,7 @@ def register_model(self, model_id: str, uri: str): model_type = config.get("model_type", "") auto_map = config.get("auto_map") pipeline_cls = config.get("pipeline_cls", "") + hub_mixin_cls = config.get("hub_mixin_cls", "") with self._lock_pool.get_lock(model_id).write_lock(): model_info = ModelInfo( @@ -293,6 +302,7 @@ def register_model(self, model_id: str, uri: str): state=ModelStates.ACTIVE, pipeline_cls=pipeline_cls, auto_map=auto_map, + hub_mixin_cls=hub_mixin_cls, transformers_registered=False, # Register later ) self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info @@ -308,6 +318,17 @@ def register_model(self, model_id: str, uri: str): f"Failed to register Transformers model {model_id}, because {e}" ) raise e + elif hub_mixin_cls: + # PyTorchModelHubMixin model: immediately register + try: + if self._register_hub_mixin_model(model_info): + model_info.transformers_registered = True + except Exception as e: + model_info.state = ModelStates.INACTIVE + logger.error( + f"Failed to register HubMixin model {model_id}, because {e}" + ) + raise e else: # Other type models: only log self._register_other_model(model_info) @@ -321,6 +342,7 @@ def _register_transformers_model(self, model_info: ModelInfo) -> bool: True if registration is successful Raises: Exception: Transformers internal exception if registration fails + ValueError: If class is invalid """ auto_map = model_info.auto_map if not auto_map: @@ -338,6 +360,14 @@ def _register_transformers_model(self, model_info: ModelInfo) -> bool: config_class = import_class_from_path( model_info.model_id, auto_config_path ) + # Validate config_class is a subclass of PretrainedConfig + if not ( + isinstance(config_class, type) + and issubclass(config_class, PretrainedConfig) + ): + raise ValueError( + f"AutoConfig class '{auto_config_path}' must be a subclass of PretrainedConfig" + ) AutoConfig.register(model_info.model_type, config_class) logger.info( f"Registered AutoConfig: {model_info.model_type} -> {auto_config_path}" @@ -346,6 +376,14 @@ def _register_transformers_model(self, model_info: ModelInfo) -> bool: model_class = import_class_from_path( model_info.model_id, auto_model_path ) + # Validate model_class is a subclass of PreTrainedModel + if not ( + isinstance(model_class, type) + and issubclass(model_class, PreTrainedModel) + ): + raise ValueError( + f"AutoModelForCausalLM class '{auto_model_path}' must be a subclass of PreTrainedModel" + ) AutoModelForCausalLM.register(config_class, model_class) logger.info( f"Registered AutoModelForCausalLM: {config_class.__name__} -> {auto_model_path}" @@ -357,6 +395,48 @@ def _register_transformers_model(self, model_info: ModelInfo) -> bool: ) raise e + def _register_hub_mixin_model(self, model_info: ModelInfo) -> bool: + """ + Register PyTorchModelHubMixin model (internal method). + For now, just validate the class. + + Returns: + True if registration is successful + Raises: + ValueError: If class is invalid + Exception: For other errors + """ + hub_mixin_cls = model_info.hub_mixin_cls + if not hub_mixin_cls: + return False + + try: + model_path = os.path.join( + self._models_dir, model_info.category.value, model_info.model_id + ) + module_parent = str(Path(model_path).parent.absolute()) + with temporary_sys_path(module_parent): + model_class = import_class_from_path(model_info.model_id, hub_mixin_cls) + + # Validate that the class inherits from PyTorchModelHubMixin + if not issubclass(model_class, PyTorchModelHubMixin): + raise ValueError( + f"Class '{model_class}' does not inherit from " + "PyTorchModelHubMixin." + ) + + logger.info( + f"Registered PyTorchModelHubMixin model: " + f"{model_info.model_id} -> {hub_mixin_cls}" + ) + return True + + except Exception as e: + logger.warning( + f"Failed to register PyTorchModelHubMixin model {model_info.model_id}: {e}." + ) + raise e + def _register_other_model(self, model_info: ModelInfo): """Register other type models (non-Transformers models)""" logger.info( @@ -526,7 +606,7 @@ def get_model_info( return self._models[category.value].get(model_id) else: # Category not specified, need to traverse all dictionaries, use global lock - with self._lock_pool.get_lock("").read_lock(): + with self._lock_pool.get_lock(model_id).read_lock(): for category_dict in self._models.values(): if model_id in category_dict: return category_dict[model_id] diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/pipeline_moirai2.py b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/pipeline_moirai2.py index fe2fb63236226..666c3063df35b 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/pipeline_moirai2.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/pipeline_moirai2.py @@ -30,7 +30,7 @@ class Moirai2Pipeline(ForecastPipeline): def __init__(self, model_info, **model_kwargs): super().__init__(model_info, **model_kwargs) - def preprocess(self, inputs, **infer_kwargs): + def _preprocess(self, inputs, **infer_kwargs): """ Preprocess input data for moirai2. @@ -48,7 +48,6 @@ def preprocess(self, inputs, **infer_kwargs): list of dict Processed inputs compatible with moirai2 format (time, features). """ - super().preprocess(inputs, **infer_kwargs) # Moirai2.predict() expects past_target in (time, features) format processed_inputs = [] for item in inputs: @@ -141,7 +140,7 @@ def forecast(self, inputs, **infer_kwargs) -> list[torch.Tensor]: f"Model must be an instance of Moirai2ForPrediction, got {type(self.model)}" ) - def postprocess( + def _postprocess( self, outputs: list[torch.Tensor], **infer_kwargs ) -> list[torch.Tensor]: """ @@ -165,5 +164,4 @@ def postprocess( else: # If no quantiles, get the mean outputs_list.append(output.mean(dim=1)) - super().postprocess(outputs_list, **infer_kwargs) return outputs_list diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py index 12b2668543ef5..a528ce0ffc19d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py @@ -33,7 +33,7 @@ def __init__(self, model_info: ModelInfo, **model_kwargs): model_kwargs.pop("device", None) # sktime models run on CPU super().__init__(model_info, **model_kwargs) - def preprocess( + def _preprocess( self, inputs: list[dict[str, dict[str, torch.Tensor] | torch.Tensor]], **infer_kwargs, @@ -49,8 +49,6 @@ def preprocess( """ model_id = self.model_info.model_id - inputs = super().preprocess(inputs, **infer_kwargs) - # Here, we assume element in list has same history_length, # otherwise, the model cannot proceed if inputs[0].get("past_covariates", None) or inputs[0].get( @@ -96,7 +94,7 @@ def forecast(self, inputs: list[pd.Series], **infer_kwargs) -> np.ndarray: return outputs - def postprocess(self, outputs: np.ndarray, **infer_kwargs) -> list[torch.Tensor]: + def _postprocess(self, outputs: np.ndarray, **infer_kwargs) -> list[torch.Tensor]: """ Postprocess the model's outputs. @@ -111,5 +109,4 @@ def postprocess(self, outputs: np.ndarray, **infer_kwargs) -> list[torch.Tensor] # Transform outputs into a 2D-tensor: [batch_size, output_length] outputs = torch.from_numpy(outputs).float() outputs = [outputs[i].unsqueeze(0) for i in range(outputs.size(0))] - outputs = super().postprocess(outputs, **infer_kwargs) return outputs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py index 8aa9b175169c1..8e4ffefe3164c 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py @@ -30,7 +30,7 @@ class SundialPipeline(ForecastPipeline): def __init__(self, model_info: ModelInfo, **model_kwargs): super().__init__(model_info, **model_kwargs) - def preprocess(self, inputs, **infer_kwargs) -> torch.Tensor: + def _preprocess(self, inputs, **infer_kwargs) -> torch.Tensor: """ Preprocess the input data by converting it to a 2D tensor (Sundial only supports 2D inputs). @@ -48,7 +48,6 @@ def preprocess(self, inputs, **infer_kwargs) -> torch.Tensor: (i.e., when inputs.shape[1] != 1). """ model_id = self.model_info.model_id - inputs = super().preprocess(inputs, **infer_kwargs) # Here, we assume element in list has same history_length, # otherwise, the model cannot proceed if inputs[0].get("past_covariates", None) or inputs[0].get( @@ -93,7 +92,7 @@ def forecast(self, inputs: torch.Tensor, **infer_kwargs) -> torch.Tensor: ) return outputs - def postprocess(self, outputs: torch.Tensor, **infer_kwargs) -> list[torch.Tensor]: + def _postprocess(self, outputs: torch.Tensor, **infer_kwargs) -> list[torch.Tensor]: """ Postprocess the model's output by averaging across the num_samples dimension and expanding the dimensions to match the expected shape. @@ -107,5 +106,4 @@ def postprocess(self, outputs: torch.Tensor, **infer_kwargs) -> list[torch.Tenso """ outputs = outputs.mean(dim=1).unsqueeze(1) outputs = [outputs[i] for i in range(outputs.size(0))] - outputs = super().postprocess(outputs, **infer_kwargs) return outputs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py index 213e6102c8b64..3b7957259d651 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py @@ -30,7 +30,7 @@ class TimerPipeline(ForecastPipeline): def __init__(self, model_info: ModelInfo, **model_kwargs): super().__init__(model_info, **model_kwargs) - def preprocess(self, inputs, **infer_kwargs) -> torch.Tensor: + def _preprocess(self, inputs, **infer_kwargs) -> torch.Tensor: """ Preprocess the input data by converting it to a 2D tensor (Timer-XL only supports 2D inputs). @@ -48,7 +48,6 @@ def preprocess(self, inputs, **infer_kwargs) -> torch.Tensor: (i.e., when inputs.shape[1] != 1). """ model_id = self.model_info.model_id - inputs = super().preprocess(inputs, **infer_kwargs) # Here, we assume element in list has same history_length, # otherwise, the model cannot proceed if inputs[0].get("past_covariates", None) or inputs[0].get( @@ -86,7 +85,7 @@ def forecast(self, inputs: torch.Tensor, **infer_kwargs) -> torch.Tensor: outputs = self.model.generate(inputs, max_new_tokens=output_length, revin=revin) return outputs - def postprocess(self, outputs: torch.Tensor, **infer_kwargs) -> list[torch.Tensor]: + def _postprocess(self, outputs: torch.Tensor, **infer_kwargs) -> list[torch.Tensor]: """ Postprocess the model's output by expanding its dimensions to match the expected shape. @@ -98,5 +97,4 @@ def postprocess(self, outputs: torch.Tensor, **infer_kwargs) -> list[torch.Tenso list of torch.Tensor: A list of 2D tensors with shape [target_count(1), output_length]. """ outputs = [outputs[i].unsqueeze(0) for i in range(outputs.size(0))] - outputs = super().postprocess(outputs, **infer_kwargs) return outputs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py index 815232c52b0d6..815f1076101a4 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py @@ -23,17 +23,31 @@ import sys from contextlib import contextmanager from pathlib import Path -from typing import Dict, Tuple - -from huggingface_hub import snapshot_download +from typing import Any, Dict, Optional, Tuple, Type + +from huggingface_hub import PyTorchModelHubMixin, snapshot_download +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForNextSentencePrediction, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, + AutoModelForTimeSeriesPrediction, + AutoModelForTokenClassification, + PretrainedConfig, + PreTrainedModel, +) +from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.exception import InvalidModelUriException from iotdb.ainode.core.log import Logger from iotdb.ainode.core.model.model_constants import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, + CONFIG_JSON, + MODEL_SAFETENSORS, + ModelCategory, UriType, ) +from iotdb.ainode.core.model.model_info import ModelInfo logger = Logger() @@ -69,11 +83,95 @@ def load_model_config_in_json(config_path: str) -> Dict: return json.load(f) +def get_model_path(model_info: ModelInfo) -> str: + return os.path.join( + os.getcwd(), + AINodeDescriptor().get_config().get_ain_models_dir(), + model_info.category.value, + model_info.model_id, + ) + + +def get_model_and_config_by_native_code( + model_info: ModelInfo, +) -> Tuple[ + Optional[Type[PreTrainedModel | PyTorchModelHubMixin]], Optional[PretrainedConfig] +]: + """ + Return model_class and config_instance (optionally) from the model's native code. + """ + + # Try to get model str and config str. + config_str = None + if model_info.auto_map: + config_str = model_info.auto_map.get("AutoConfig", "") + model_str = model_info.auto_map.get("AutoModelForCausalLM", "") + if not config_str or not model_str: + return None, None + elif model_info.hub_mixin_cls: + model_str = model_info.hub_mixin_cls + else: + return None, None + + model_path = get_model_path(model_info) + + # Try to import model and config class. + config_class, config_instance = None, None + model_class = None + if model_info.category == ModelCategory.BUILTIN: + module_name = ( + AINodeDescriptor().get_config().get_ain_models_builtin_dir() + + "." + + model_info.model_id + ) + if config_str: + # For Transformer models + config_class = import_class_from_path(module_name, config_str) + config_instance = config_class.from_pretrained(model_path) + model_class = import_class_from_path(module_name, model_str) + else: + module_parent = str(Path(model_path).parent.absolute()) + with temporary_sys_path(module_parent): + if config_str: + # For Transformer models + config_class = import_class_from_path(model_info.model_id, config_str) + config_instance = config_class.from_pretrained(model_path) + model_class = import_class_from_path(model_info.model_id, model_str) + + return model_class, config_instance + + +def get_model_and_config_by_auto_class(model_path: str) -> Tuple[type, Any]: + """Return model_class and config_instance from Huggingface Transformers's AutoClass.""" + config_instance = AutoConfig.from_pretrained(model_path) + + if type(config_instance) in AutoModelForTimeSeriesPrediction._model_mapping.keys(): + model_class = AutoModelForTimeSeriesPrediction + elif ( + type(config_instance) + in AutoModelForNextSentencePrediction._model_mapping.keys() + ): + model_class = AutoModelForNextSentencePrediction + elif type(config_instance) in AutoModelForSeq2SeqLM._model_mapping.keys(): + model_class = AutoModelForSeq2SeqLM + elif ( + type(config_instance) + in AutoModelForSequenceClassification._model_mapping.keys() + ): + model_class = AutoModelForSequenceClassification + elif type(config_instance) in AutoModelForTokenClassification._model_mapping.keys(): + model_class = AutoModelForTokenClassification + else: + model_class = AutoModelForCausalLM + + return model_class, config_instance + + def validate_model_files(model_dir: str) -> Tuple[str, str]: """Validate model files exist, return config and weights file paths""" - config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) - weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) + config_path = os.path.join(model_dir, CONFIG_JSON) + weights_path = os.path.join(model_dir, MODEL_SAFETENSORS) if not os.path.exists(config_path): raise InvalidModelUriException( @@ -116,9 +214,9 @@ def _fetch_model_from_local(source_path: str, storage_path: str): if not source_dir.is_dir(): raise InvalidModelUriException(f"Source path is not a directory: {source_path}") storage_dir = Path(storage_path) - for file in source_dir.iterdir(): - if file.is_file(): - shutil.copy2(file, storage_dir / file.name) + if storage_dir.exists(): + shutil.rmtree(storage_dir) + shutil.copytree(source_dir, storage_dir) def _fetch_model_from_hf_repo(repo_id: str, storage_path: str): diff --git a/iotdb-core/ainode/iotdb/ainode/core/util/decorator.py b/iotdb-core/ainode/iotdb/ainode/core/util/decorator.py index 5a84c3d6bb203..2abe08ca9dc6a 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/util/decorator.py +++ b/iotdb-core/ainode/iotdb/ainode/core/util/decorator.py @@ -15,15 +15,21 @@ # specific language governing permissions and limitations # under the License. # +import threading from functools import wraps def singleton(cls): + """Thread-safe singleton decorator.""" instances = {} + lock = threading.Lock() + @wraps(cls) def get_instance(*args, **kwargs): if cls not in instances: - instances[cls] = cls(*args, **kwargs) + with lock: + if cls not in instances: + instances[cls] = cls(*args, **kwargs) return instances[cls] return get_instance diff --git a/iotdb-core/ainode/iotdb/ainode/core/util/serde.py b/iotdb-core/ainode/iotdb/ainode/core/util/serde.py index a61032ba26ff1..f03d323f1b8ad 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/util/serde.py +++ b/iotdb-core/ainode/iotdb/ainode/core/util/serde.py @@ -75,10 +75,17 @@ def convert_tsblock_to_tensor(tsblock_data: bytes): # Convert DataFrame to TsBlock in binary, input shouldn't contain time column. # Maybe contain multiple value columns. def convert_tensor_to_tsblock(data_tensor: torch.Tensor): - data_frame = pd.DataFrame(data_tensor).T - data_shape = data_frame.shape - value_column_size = data_shape[1] - position_count = data_shape[0] + # Ensure the tensor is 2D with size [target_count, sequence_length] + if data_tensor.dim() == 0: + data_tensor = data_tensor.unsqueeze(0).unsqueeze(0) + elif data_tensor.dim() == 1: + data_tensor = data_tensor.unsqueeze(0) + + # Transpose the tensor to [sequence_length, target_count] + data_frame = pd.DataFrame(data_tensor.cpu()).T + # sequence_length, target_count + position_count, value_column_size = data_frame.shape[0], data_frame.shape[1] + keys = data_frame.keys() binary = value_column_size.to_bytes(4, byteorder="big") diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java index 1a0fdd4bb9634..f406c16c08545 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java @@ -62,8 +62,6 @@ import java.util.Arrays; import java.util.Collections; import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Locale; @@ -74,6 +72,8 @@ import java.util.stream.Collectors; import static org.apache.iotdb.commons.udf.builtin.relational.tvf.WindowTVFUtils.findColumnIndex; +import static org.apache.iotdb.db.queryengine.plan.relational.function.tvf.TableFunctionUtils.checkType; +import static org.apache.iotdb.db.queryengine.plan.relational.function.tvf.TableFunctionUtils.parseOptions; import static org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender.createResultColumnAppender; public class ForecastTableFunction implements TableFunction { @@ -201,17 +201,6 @@ public int hashCode() { protected static final String DEFAULT_OPTIONS = ""; protected static final int MAX_INPUT_LENGTH = 2880; - private static final String INVALID_OPTIONS_FORMAT = "Invalid options: %s"; - - private static final Set ALLOWED_INPUT_TYPES = new HashSet<>(); - - static { - ALLOWED_INPUT_TYPES.add(Type.INT32); - ALLOWED_INPUT_TYPES.add(Type.INT64); - ALLOWED_INPUT_TYPES.add(Type.FLOAT); - ALLOWED_INPUT_TYPES.add(Type.DOUBLE); - } - @Override public List getArgumentsSpecifications() { return Arrays.asList( @@ -367,38 +356,6 @@ public TableFunctionDataProcessor getDataProcessor() { }; } - // only allow for INT32, INT64, FLOAT, DOUBLE - public void checkType(Type type, String columnName) { - if (!ALLOWED_INPUT_TYPES.contains(type)) { - throw new SemanticException( - String.format( - "The type of the column [%s] is [%s], only INT32, INT64, FLOAT, DOUBLE is allowed", - columnName, type)); - } - } - - public static Map parseOptions(String options) { - if (options.isEmpty()) { - return Collections.emptyMap(); - } - String[] optionArray = options.split(","); - if (optionArray.length == 0) { - throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT, options)); - } - - Map optionsMap = new HashMap<>(optionArray.length); - for (String option : optionArray) { - int index = option.indexOf('='); - if (index == -1 || index == option.length() - 1) { - throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT, option)); - } - String key = option.substring(0, index).trim(); - String value = option.substring(index + 1).trim(); - optionsMap.put(key, value); - } - return optionsMap; - } - protected static class ForecastDataProcessor implements TableFunctionDataProcessor { protected static final TsBlockSerde SERDE = new TsBlockSerde(); @@ -474,7 +431,7 @@ public void finish( int columnSize = properColumnBuilders.size(); // sort inputRecords in ascending order by timestamp - inputRecords.sort(Comparator.comparingLong(record -> record.getLong(0))); + inputRecords.sort(Comparator.comparingLong(r -> r.getLong(0))); // time column long inputStartTime = inputRecords.getFirst().getLong(0); @@ -514,6 +471,7 @@ public void finish( TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode()); } + // construct result column for (int columnIndex = 1, size = predicatedResult.getValueColumnCount(); columnIndex <= size; columnIndex++) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/TableFunctionUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/TableFunctionUtils.java new file mode 100644 index 0000000000000..499f133c437e3 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/TableFunctionUtils.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.function.tvf; + +import org.apache.iotdb.db.exception.sql.SemanticException; +import org.apache.iotdb.udf.api.type.Type; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +public class TableFunctionUtils { + private static final String INVALID_OPTIONS_FORMAT = "Invalid options: %s"; + + public static Map parseOptions(String options) { + if (options.isEmpty()) { + return Collections.emptyMap(); + } + String[] optionArray = options.split(","); + if (optionArray.length == 0) { + throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT, options)); + } + + Map optionsMap = new HashMap<>(optionArray.length); + for (String option : optionArray) { + int index = option.indexOf('='); + if (index == -1 || index == option.length() - 1) { + throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT, option)); + } + String key = option.substring(0, index).trim(); + String value = option.substring(index + 1).trim(); + optionsMap.put(key, value); + } + return optionsMap; + } + + private static final Set ALLOWED_INPUT_TYPES = new HashSet<>(); + + static { + ALLOWED_INPUT_TYPES.add(Type.INT32); + ALLOWED_INPUT_TYPES.add(Type.INT64); + ALLOWED_INPUT_TYPES.add(Type.FLOAT); + ALLOWED_INPUT_TYPES.add(Type.DOUBLE); + } + + // only allow for INT32, INT64, FLOAT, DOUBLE + public static void checkType(Type type, String columnName) { + if (!ALLOWED_INPUT_TYPES.contains(type)) { + throw new SemanticException( + String.format( + "The type of the column [%s] is [%s], only INT32, INT64, FLOAT, DOUBLE is allowed", + columnName, type)); + } + } +}