Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions sagemaker-train/src/sagemaker/train/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
from sagemaker.core.jumpstart.utils import get_eula_url
from sagemaker.train.defaults import TrainDefaults, JumpStartTrainDefaults
from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
from sagemaker.core.helper.pipeline_variable import StrPipeVar
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar

from sagemaker.train.local.local_container import _LocalContainer

Expand Down Expand Up @@ -410,14 +410,18 @@ def __del__(self):
self._temp_code_dir.cleanup()

def _validate_training_image_and_algorithm_name(
self, training_image: Optional[str], algorithm_name: Optional[str]
self,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good fix switching from truthiness (not training_image) to identity (is not None) checks. The truthiness check would evaluate a PipelineVariable object as truthy but could also fail for empty strings. The is not None approach is more correct and explicit.

Minor: The type annotation uses Union[str, PipelineVariable, None] — consider whether Optional[Union[str, PipelineVariable]] or str | PipelineVariable | None (with from __future__ import annotations) would be more consistent with the SDK's preference for newer syntax in new code.

training_image: Union[str, PipelineVariable, None],
algorithm_name: Union[str, PipelineVariable, None],
):
"""Validate that only one of 'training_image' or 'algorithm_name' is provided."""
if not training_image and not algorithm_name:
has_image = training_image is not None
has_algo = algorithm_name is not None
if not has_image and not has_algo:
raise ValueError(
"Atleast one of 'training_image' or 'algorithm_name' must be provided.",
)
if training_image and algorithm_name:
if has_image and has_algo:
raise ValueError(
"Only one of 'training_image' or 'algorithm_name' must be provided.",
)
Expand Down
10 changes: 6 additions & 4 deletions sagemaker-train/src/sagemaker/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from typing import Literal, Any

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.helper.pipeline_variable import PipelineVariable
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The import was changed from sagemaker.core.workflow.parameters.PipelineVariable to sagemaker.core.helper.pipeline_variable.PipelineVariable. Please verify these are the same class (or that the helper version is the correct base class that ParameterString/ParameterInteger inherit from). If sagemaker.core.workflow.parameters.PipelineVariable is a re-export of the same class, this is fine, but if they're different classes, the isinstance check on line 163 might not catch all pipeline variable types.

from sagemaker.core.shapes import Unassigned
from sagemaker.train import logger
from sagemaker.core.workflow.parameters import PipelineVariable


def _default_bucket_and_prefix(session: Session) -> str:
Expand Down Expand Up @@ -142,7 +142,7 @@ def _get_unique_name(base, max_length=63):
return unique_name


def _get_repo_name_from_image(image: str) -> str:
def _get_repo_name_from_image(image: "str | PipelineVariable") -> "str | None":
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type annotation uses PEP 604 string-quoted syntax ("str | PipelineVariable") for the parameter but the module doesn't use from __future__ import annotations. While this works at runtime due to quoting, it's inconsistent with the return type which is also quoted. Per SDK conventions, prefer adding from __future__ import annotations at the top of the module and using unquoted str | PipelineVariable and str | None, or use Union[str, PipelineVariable] / Optional[str] consistently with the rest of the codebase (as done in model_trainer.py).

"""Get the repository name from the image URI.

Example:
Expand All @@ -152,11 +152,13 @@ def _get_repo_name_from_image(image: str) -> str:
```

Args:
image (str): The image URI
image (str | PipelineVariable): The image URI

Returns:
str: The repository name
str | None: The repository name, or None if image is a PipelineVariable
"""
if isinstance(image, PipelineVariable):
return None
return image.split("/")[-1].split(":")[0].split("@")[0]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
from sagemaker.core.workflow.parameters import ParameterString
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger
from sagemaker.train.model_trainer import ModelTrainer, Mode
from sagemaker.train.configs import (
Compute,
StoppingCondition,
OutputDataConfig,
)
from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE
from sagemaker.train.utils import _get_repo_name_from_image, safe_serialize


DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest"
Expand Down Expand Up @@ -176,3 +177,116 @@ def test_training_image_rejects_invalid_type(self):
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
)


class TestValidateTrainingImageAndAlgorithmName:
"""Tests for _validate_training_image_and_algorithm_name with PipelineVariable."""

def test_pipeline_variable_training_image_passes_validation(self):
"""PipelineVariable as training_image should pass validation."""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good test coverage for the validation fix. Consider also adding a test case where training_image is a ParameterString with an empty string default value (default_value="") to verify the is not None check doesn't regress for edge cases where the resolved value might be empty.

param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE)
trainer = ModelTrainer(
training_image=param,
base_job_name="pipeline-test-job",
role=DEFAULT_ROLE,
compute=DEFAULT_COMPUTE,
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
)
assert trainer.training_image is param

def test_pipeline_variable_algorithm_name_passes_validation(self):
"""PipelineVariable as algorithm_name should pass validation."""
param = ParameterString(name="AlgoName", default_value="my-algo")
trainer = ModelTrainer(
algorithm_name=param,
base_job_name="pipeline-test-job",
role=DEFAULT_ROLE,
compute=DEFAULT_COMPUTE,
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
)
assert trainer.algorithm_name is param

def test_both_pipeline_variables_raises_value_error(self):
"""Both training_image and algorithm_name as PipelineVariable should raise ValueError."""
image_param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE)
algo_param = ParameterString(name="AlgoName", default_value="my-algo")
with pytest.raises(ValueError, match="Only one of"):
ModelTrainer(
training_image=image_param,
algorithm_name=algo_param,
base_job_name="pipeline-test-job",
role=DEFAULT_ROLE,
compute=DEFAULT_COMPUTE,
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
)

def test_neither_provided_raises_value_error(self):
"""Neither training_image nor algorithm_name should raise ValueError."""
with pytest.raises(ValueError, match="Atleast one of"):
ModelTrainer(
training_image=None,
algorithm_name=None,
base_job_name="pipeline-test-job",
role=DEFAULT_ROLE,
compute=DEFAULT_COMPUTE,
stopping_condition=DEFAULT_STOPPING,
output_data_config=DEFAULT_OUTPUT,
)


class TestGetRepoNameFromImage:
"""Tests for _get_repo_name_from_image with PipelineVariable."""

def test_returns_none_for_pipeline_variable(self):
"""_get_repo_name_from_image should return None for PipelineVariable."""
param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE)
result = _get_repo_name_from_image(param)
assert result is None

def test_returns_repo_name_for_string(self):
"""_get_repo_name_from_image should return repo name for a normal string."""
result = _get_repo_name_from_image(
"123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo:latest"
)
assert result == "my-repo"

def test_returns_repo_name_without_tag(self):
"""_get_repo_name_from_image should handle image URIs without tags."""
result = _get_repo_name_from_image(
"123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo"
)
assert result == "my-repo"


class TestSafeSerialize:
"""Tests for safe_serialize with PipelineVariable."""

def test_safe_serialize_pipeline_variable_returns_variable(self):
"""safe_serialize should return the PipelineVariable object as-is."""
param = ParameterInteger(name="MaxDepth", default_value=5)
result = safe_serialize(param)
assert result is param

def test_safe_serialize_string_returns_string(self):
"""safe_serialize should return strings as-is."""
result = safe_serialize("hello")
assert result == "hello"

def test_safe_serialize_int_returns_json(self):
"""safe_serialize should JSON-encode integers."""
result = safe_serialize(5)
assert result == "5"

def test_safe_serialize_dict_returns_json(self):
"""safe_serialize should JSON-encode dicts."""
result = safe_serialize({"key": "value"})
assert result == '{"key": "value"}'

def test_safe_serialize_parameter_string_returns_variable(self):
"""safe_serialize should return ParameterString as-is."""
param = ParameterString(name="MyParam", default_value="val")
result = safe_serialize(param)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TestSafeSerialize tests are good additions, but I don't see any corresponding changes to the safe_serialize function in utils.py in this diff. If safe_serialize already handles PipelineVariable correctly, these tests are documenting existing behavior (which is fine). But if the bug also affects safe_serialize and it needs a fix, that change appears to be missing from this PR. Could you confirm that safe_serialize already has the PipelineVariable guard?

assert result is param
Loading