From 04d768dd719ade04e6e777dec312a4a29e8e3e2a Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 27 Mar 2026 11:53:01 -0400 Subject: [PATCH 1/2] fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504) --- sagemaker-train/src/sagemaker/train/model_trainer.py | 8 +++++--- sagemaker-train/src/sagemaker/train/utils.py | 10 +++++++--- .../unit/train/test_model_trainer_pipeline_variable.py | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index d07edeb025..cbbbbf6a29 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -410,14 +410,16 @@ def __del__(self): self._temp_code_dir.cleanup() def _validate_training_image_and_algorithm_name( - self, training_image: Optional[str], algorithm_name: Optional[str] + self, training_image, algorithm_name ): """Validate that only one of 'training_image' or 'algorithm_name' is provided.""" - if not training_image and not algorithm_name: + has_image = training_image is not None + has_algo = algorithm_name is not None + if not has_image and not has_algo: raise ValueError( "Atleast one of 'training_image' or 'algorithm_name' must be provided.", ) - if training_image and algorithm_name: + if has_image and has_algo: raise ValueError( "Only one of 'training_image' or 'algorithm_name' must be provided.", ) diff --git a/sagemaker-train/src/sagemaker/train/utils.py b/sagemaker-train/src/sagemaker/train/utils.py index 0abd7596b5..1947876ea8 100644 --- a/sagemaker-train/src/sagemaker/train/utils.py +++ b/sagemaker-train/src/sagemaker/train/utils.py @@ -142,7 +142,7 @@ def _get_unique_name(base, max_length=63): return unique_name -def _get_repo_name_from_image(image: str) -> str: +def _get_repo_name_from_image(image) -> str: """Get the repository name from the image URI. Example: @@ -152,11 +152,15 @@ def _get_repo_name_from_image(image: str) -> str: ``` Args: - image (str): The image URI + image: The image URI (str or PipelineVariable) Returns: - str: The repository name + str: The repository name, or None if image is a PipelineVariable """ + from sagemaker.core.helper.pipeline_variable import PipelineVariable + + if isinstance(image, PipelineVariable): + return None return image.split("/")[-1].split(":")[0].split("@")[0] diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py index 3fd34fa47b..b6fa0867df 100644 --- a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py +++ b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py @@ -33,7 +33,7 @@ StoppingCondition, OutputDataConfig, ) -from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE +from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE, TrainDefaults DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest" From a2f9e4793e6aa99b83dc3bea433903ee022d4087 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 27 Mar 2026 11:59:42 -0400 Subject: [PATCH 2/2] fix: address review comments (iteration #1) --- .../src/sagemaker/train/model_trainer.py | 6 +- sagemaker-train/src/sagemaker/train/utils.py | 10 +- .../test_model_trainer_pipeline_variable.py | 118 +++++++++++++++++- 3 files changed, 124 insertions(+), 10 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index cbbbbf6a29..e88afc9b1e 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -116,7 +116,7 @@ from sagemaker.core.jumpstart.utils import get_eula_url from sagemaker.train.defaults import TrainDefaults, JumpStartTrainDefaults from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline -from sagemaker.core.helper.pipeline_variable import StrPipeVar +from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar from sagemaker.train.local.local_container import _LocalContainer @@ -410,7 +410,9 @@ def __del__(self): self._temp_code_dir.cleanup() def _validate_training_image_and_algorithm_name( - self, training_image, algorithm_name + self, + training_image: Union[str, PipelineVariable, None], + algorithm_name: Union[str, PipelineVariable, None], ): """Validate that only one of 'training_image' or 'algorithm_name' is provided.""" has_image = training_image is not None diff --git a/sagemaker-train/src/sagemaker/train/utils.py b/sagemaker-train/src/sagemaker/train/utils.py index 1947876ea8..8bcdb2ecf3 100644 --- a/sagemaker-train/src/sagemaker/train/utils.py +++ b/sagemaker-train/src/sagemaker/train/utils.py @@ -24,9 +24,9 @@ from typing import Literal, Any from sagemaker.core.helper.session_helper import Session +from sagemaker.core.helper.pipeline_variable import PipelineVariable from sagemaker.core.shapes import Unassigned from sagemaker.train import logger -from sagemaker.core.workflow.parameters import PipelineVariable def _default_bucket_and_prefix(session: Session) -> str: @@ -142,7 +142,7 @@ def _get_unique_name(base, max_length=63): return unique_name -def _get_repo_name_from_image(image) -> str: +def _get_repo_name_from_image(image: "str | PipelineVariable") -> "str | None": """Get the repository name from the image URI. Example: @@ -152,13 +152,11 @@ def _get_repo_name_from_image(image) -> str: ``` Args: - image: The image URI (str or PipelineVariable) + image (str | PipelineVariable): The image URI Returns: - str: The repository name, or None if image is a PipelineVariable + str | None: The repository name, or None if image is a PipelineVariable """ - from sagemaker.core.helper.pipeline_variable import PipelineVariable - if isinstance(image, PipelineVariable): return None return image.split("/")[-1].split(":")[0].split("@")[0] diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py index b6fa0867df..2a97e0606f 100644 --- a/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py +++ b/sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py @@ -26,14 +26,15 @@ from sagemaker.core.helper.session_helper import Session from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar -from sagemaker.core.workflow.parameters import ParameterString +from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger from sagemaker.train.model_trainer import ModelTrainer, Mode from sagemaker.train.configs import ( Compute, StoppingCondition, OutputDataConfig, ) -from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE, TrainDefaults +from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE +from sagemaker.train.utils import _get_repo_name_from_image, safe_serialize DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest" @@ -176,3 +177,116 @@ def test_training_image_rejects_invalid_type(self): stopping_condition=DEFAULT_STOPPING, output_data_config=DEFAULT_OUTPUT, ) + + +class TestValidateTrainingImageAndAlgorithmName: + """Tests for _validate_training_image_and_algorithm_name with PipelineVariable.""" + + def test_pipeline_variable_training_image_passes_validation(self): + """PipelineVariable as training_image should pass validation.""" + param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE) + trainer = ModelTrainer( + training_image=param, + base_job_name="pipeline-test-job", + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + ) + assert trainer.training_image is param + + def test_pipeline_variable_algorithm_name_passes_validation(self): + """PipelineVariable as algorithm_name should pass validation.""" + param = ParameterString(name="AlgoName", default_value="my-algo") + trainer = ModelTrainer( + algorithm_name=param, + base_job_name="pipeline-test-job", + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + ) + assert trainer.algorithm_name is param + + def test_both_pipeline_variables_raises_value_error(self): + """Both training_image and algorithm_name as PipelineVariable should raise ValueError.""" + image_param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE) + algo_param = ParameterString(name="AlgoName", default_value="my-algo") + with pytest.raises(ValueError, match="Only one of"): + ModelTrainer( + training_image=image_param, + algorithm_name=algo_param, + base_job_name="pipeline-test-job", + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + ) + + def test_neither_provided_raises_value_error(self): + """Neither training_image nor algorithm_name should raise ValueError.""" + with pytest.raises(ValueError, match="Atleast one of"): + ModelTrainer( + training_image=None, + algorithm_name=None, + base_job_name="pipeline-test-job", + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE, + stopping_condition=DEFAULT_STOPPING, + output_data_config=DEFAULT_OUTPUT, + ) + + +class TestGetRepoNameFromImage: + """Tests for _get_repo_name_from_image with PipelineVariable.""" + + def test_returns_none_for_pipeline_variable(self): + """_get_repo_name_from_image should return None for PipelineVariable.""" + param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE) + result = _get_repo_name_from_image(param) + assert result is None + + def test_returns_repo_name_for_string(self): + """_get_repo_name_from_image should return repo name for a normal string.""" + result = _get_repo_name_from_image( + "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo:latest" + ) + assert result == "my-repo" + + def test_returns_repo_name_without_tag(self): + """_get_repo_name_from_image should handle image URIs without tags.""" + result = _get_repo_name_from_image( + "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo" + ) + assert result == "my-repo" + + +class TestSafeSerialize: + """Tests for safe_serialize with PipelineVariable.""" + + def test_safe_serialize_pipeline_variable_returns_variable(self): + """safe_serialize should return the PipelineVariable object as-is.""" + param = ParameterInteger(name="MaxDepth", default_value=5) + result = safe_serialize(param) + assert result is param + + def test_safe_serialize_string_returns_string(self): + """safe_serialize should return strings as-is.""" + result = safe_serialize("hello") + assert result == "hello" + + def test_safe_serialize_int_returns_json(self): + """safe_serialize should JSON-encode integers.""" + result = safe_serialize(5) + assert result == "5" + + def test_safe_serialize_dict_returns_json(self): + """safe_serialize should JSON-encode dicts.""" + result = safe_serialize({"key": "value"}) + assert result == '{"key": "value"}' + + def test_safe_serialize_parameter_string_returns_variable(self): + """safe_serialize should return ParameterString as-is.""" + param = ParameterString(name="MyParam", default_value="val") + result = safe_serialize(param) + assert result is param