-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504) #5688
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -116,7 +116,7 @@ | |
| from sagemaker.core.jumpstart.utils import get_eula_url | ||
| from sagemaker.train.defaults import TrainDefaults, JumpStartTrainDefaults | ||
| from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline | ||
| from sagemaker.core.helper.pipeline_variable import StrPipeVar | ||
| from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar | ||
|
|
||
| from sagemaker.train.local.local_container import _LocalContainer | ||
|
|
||
|
|
@@ -410,14 +410,18 @@ def __del__(self): | |
| self._temp_code_dir.cleanup() | ||
|
|
||
| def _validate_training_image_and_algorithm_name( | ||
| self, training_image: Optional[str], algorithm_name: Optional[str] | ||
| self, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good fix switching from truthiness ( Minor: The type annotation uses |
||
| training_image: Union[str, PipelineVariable, None], | ||
| algorithm_name: Union[str, PipelineVariable, None], | ||
| ): | ||
| """Validate that only one of 'training_image' or 'algorithm_name' is provided.""" | ||
| if not training_image and not algorithm_name: | ||
| has_image = training_image is not None | ||
| has_algo = algorithm_name is not None | ||
| if not has_image and not has_algo: | ||
| raise ValueError( | ||
| "Atleast one of 'training_image' or 'algorithm_name' must be provided.", | ||
| ) | ||
| if training_image and algorithm_name: | ||
| if has_image and has_algo: | ||
| raise ValueError( | ||
| "Only one of 'training_image' or 'algorithm_name' must be provided.", | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,9 +24,9 @@ | |
| from typing import Literal, Any | ||
|
|
||
| from sagemaker.core.helper.session_helper import Session | ||
| from sagemaker.core.helper.pipeline_variable import PipelineVariable | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The import was changed from |
||
| from sagemaker.core.shapes import Unassigned | ||
| from sagemaker.train import logger | ||
| from sagemaker.core.workflow.parameters import PipelineVariable | ||
|
|
||
|
|
||
| def _default_bucket_and_prefix(session: Session) -> str: | ||
|
|
@@ -142,7 +142,7 @@ def _get_unique_name(base, max_length=63): | |
| return unique_name | ||
|
|
||
|
|
||
| def _get_repo_name_from_image(image: str) -> str: | ||
| def _get_repo_name_from_image(image: "str | PipelineVariable") -> "str | None": | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The type annotation uses PEP 604 string-quoted syntax ( |
||
| """Get the repository name from the image URI. | ||
|
|
||
| Example: | ||
|
|
@@ -152,11 +152,13 @@ def _get_repo_name_from_image(image: str) -> str: | |
| ``` | ||
|
|
||
| Args: | ||
| image (str): The image URI | ||
| image (str | PipelineVariable): The image URI | ||
|
|
||
| Returns: | ||
| str: The repository name | ||
| str | None: The repository name, or None if image is a PipelineVariable | ||
| """ | ||
| if isinstance(image, PipelineVariable): | ||
| return None | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return image.split("/")[-1].split(":")[0].split("@")[0] | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,14 +26,15 @@ | |
|
|
||
| from sagemaker.core.helper.session_helper import Session | ||
| from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar | ||
| from sagemaker.core.workflow.parameters import ParameterString | ||
| from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger | ||
| from sagemaker.train.model_trainer import ModelTrainer, Mode | ||
| from sagemaker.train.configs import ( | ||
| Compute, | ||
| StoppingCondition, | ||
| OutputDataConfig, | ||
| ) | ||
| from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE | ||
| from sagemaker.train.utils import _get_repo_name_from_image, safe_serialize | ||
|
|
||
|
|
||
| DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest" | ||
|
|
@@ -176,3 +177,116 @@ def test_training_image_rejects_invalid_type(self): | |
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| ) | ||
|
|
||
|
|
||
| class TestValidateTrainingImageAndAlgorithmName: | ||
| """Tests for _validate_training_image_and_algorithm_name with PipelineVariable.""" | ||
|
|
||
| def test_pipeline_variable_training_image_passes_validation(self): | ||
| """PipelineVariable as training_image should pass validation.""" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good test coverage for the validation fix. Consider also adding a test case where |
||
| param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE) | ||
| trainer = ModelTrainer( | ||
| training_image=param, | ||
| base_job_name="pipeline-test-job", | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| ) | ||
| assert trainer.training_image is param | ||
|
|
||
| def test_pipeline_variable_algorithm_name_passes_validation(self): | ||
| """PipelineVariable as algorithm_name should pass validation.""" | ||
| param = ParameterString(name="AlgoName", default_value="my-algo") | ||
| trainer = ModelTrainer( | ||
| algorithm_name=param, | ||
| base_job_name="pipeline-test-job", | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| ) | ||
| assert trainer.algorithm_name is param | ||
|
|
||
| def test_both_pipeline_variables_raises_value_error(self): | ||
| """Both training_image and algorithm_name as PipelineVariable should raise ValueError.""" | ||
| image_param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE) | ||
| algo_param = ParameterString(name="AlgoName", default_value="my-algo") | ||
| with pytest.raises(ValueError, match="Only one of"): | ||
| ModelTrainer( | ||
| training_image=image_param, | ||
| algorithm_name=algo_param, | ||
| base_job_name="pipeline-test-job", | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| ) | ||
|
|
||
| def test_neither_provided_raises_value_error(self): | ||
| """Neither training_image nor algorithm_name should raise ValueError.""" | ||
| with pytest.raises(ValueError, match="Atleast one of"): | ||
| ModelTrainer( | ||
| training_image=None, | ||
| algorithm_name=None, | ||
| base_job_name="pipeline-test-job", | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| ) | ||
|
|
||
|
|
||
| class TestGetRepoNameFromImage: | ||
| """Tests for _get_repo_name_from_image with PipelineVariable.""" | ||
|
|
||
| def test_returns_none_for_pipeline_variable(self): | ||
| """_get_repo_name_from_image should return None for PipelineVariable.""" | ||
| param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE) | ||
| result = _get_repo_name_from_image(param) | ||
| assert result is None | ||
|
|
||
| def test_returns_repo_name_for_string(self): | ||
| """_get_repo_name_from_image should return repo name for a normal string.""" | ||
| result = _get_repo_name_from_image( | ||
| "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo:latest" | ||
| ) | ||
| assert result == "my-repo" | ||
|
|
||
| def test_returns_repo_name_without_tag(self): | ||
| """_get_repo_name_from_image should handle image URIs without tags.""" | ||
| result = _get_repo_name_from_image( | ||
| "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo" | ||
| ) | ||
| assert result == "my-repo" | ||
|
|
||
|
|
||
| class TestSafeSerialize: | ||
| """Tests for safe_serialize with PipelineVariable.""" | ||
|
|
||
| def test_safe_serialize_pipeline_variable_returns_variable(self): | ||
| """safe_serialize should return the PipelineVariable object as-is.""" | ||
| param = ParameterInteger(name="MaxDepth", default_value=5) | ||
| result = safe_serialize(param) | ||
| assert result is param | ||
|
|
||
| def test_safe_serialize_string_returns_string(self): | ||
| """safe_serialize should return strings as-is.""" | ||
| result = safe_serialize("hello") | ||
| assert result == "hello" | ||
|
|
||
| def test_safe_serialize_int_returns_json(self): | ||
| """safe_serialize should JSON-encode integers.""" | ||
| result = safe_serialize(5) | ||
| assert result == "5" | ||
|
|
||
| def test_safe_serialize_dict_returns_json(self): | ||
| """safe_serialize should JSON-encode dicts.""" | ||
| result = safe_serialize({"key": "value"}) | ||
| assert result == '{"key": "value"}' | ||
|
|
||
| def test_safe_serialize_parameter_string_returns_variable(self): | ||
| """safe_serialize should return ParameterString as-is.""" | ||
| param = ParameterString(name="MyParam", default_value="val") | ||
| result = safe_serialize(param) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| assert result is param | ||
Uh oh!
There was an error while loading. Please reload this page.