fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504) #5688
mufaddal-rohawala
left a comment
🤖 AI Code Review
This PR fixes a bug where Pipeline parameters (PipelineVariable types) fail in ModelTrainer validation because truthiness checks on PipelineVariable objects don't behave like regular strings. The core fix is correct, but there are several issues: type annotations are removed instead of being updated, the import placement is suboptimal, and the test changes appear incomplete.
| def _get_repo_name_from_image(image: str) -> str: | ||
| def _get_repo_name_from_image(image) -> str: |
Same issue — don't remove the type annotation, update it:
def _get_repo_name_from_image(image: str | PipelineVariable) -> str | None:
Note the return type should also be updated to str | None since you now return None for PipelineVariable inputs.
| str: The repository name, or None if image is a PipelineVariable | ||
| """ | ||
| from sagemaker.core.helper.pipeline_variable import PipelineVariable | ||
Move the import to the top of the module (or at minimum to the top of the function). Inline imports inside functions are acceptable for avoiding circular dependencies, but please add a comment explaining why it's done here:
# Import here to avoid circular dependency
from sagemaker.core.helper.pipeline_variable import PipelineVariable
Also, is there actually a circular dependency risk? If not, this import should be at the module level with other imports.
| OutputDataConfig, | ||
| ) | ||
| from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE | ||
| from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE, TrainDefaults |
TrainDefaults is imported but never used in the diff. Is this import used elsewhere in the file (not shown in the diff)? If not, this is an unused import that will fail linting. If it IS used in existing code not shown in the diff, please disregard this comment.
| OutputDataConfig, | ||
| ) | ||
| from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE | ||
| from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE, TrainDefaults |
The test changes seem insufficient for the scope of the fix. The PR modifies validation logic in _validate_training_image_and_algorithm_name and adds PipelineVariable handling in _get_repo_name_from_image, but the test diff only shows an import change. Where are the new test cases that:
- Pass a PipelineVariable as training_image and verify validation passes?
- Pass a PipelineVariable to _get_repo_name_from_image and verify it returns None?
- Verify that providing both a PipelineVariable for training_image AND an algorithm_name still raises ValueError?
- Verify that None for both still raises ValueError?
Please add explicit unit tests for the changed behavior. Target >90% coverage per SDK standards.
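The four cases above could be sketched roughly as follows. This is not the SDK's test suite: PipelineVariable and the two functions under test are hypothetical stand-ins defined inline so the example is self-contained, and the "exactly one of the two" rule is inferred from the cases listed in this comment.

```python
import unittest


class PipelineVariable:
    """Stand-in for the SDK's PipelineVariable (assumption)."""


def validate_training_image_and_algorithm_name(training_image, algorithm_name):
    """Hypothetical stand-in: exactly one of the two arguments must be provided."""
    if (training_image is not None) == (algorithm_name is not None):
        raise ValueError("Provide exactly one of training_image or algorithm_name.")


def get_repo_name_from_image(image):
    """Hypothetical stand-in: a PipelineVariable cannot be parsed before execution."""
    if isinstance(image, PipelineVariable):
        return None
    return image.split("/")[-1].split(":")[0]


class TestPipelineVariableValidation(unittest.TestCase):
    def test_pipeline_variable_training_image_passes(self):
        # Should not raise.
        validate_training_image_and_algorithm_name(PipelineVariable(), None)

    def test_pipeline_variable_repo_name_is_none(self):
        self.assertIsNone(get_repo_name_from_image(PipelineVariable()))

    def test_both_provided_raises(self):
        with self.assertRaises(ValueError):
            validate_training_image_and_algorithm_name(PipelineVariable(), "my-algorithm")

    def test_neither_provided_raises(self):
        with self.assertRaises(ValueError):
            validate_training_image_and_algorithm_name(None, None)
```

In the real test module the stand-ins would be replaced by imports of the SDK's own class and functions.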
mufaddal-rohawala
left a comment
🤖 AI Code Review
This PR fixes a bug where PipelineVariable objects (ParameterInteger, ParameterString) were failing in ModelTrainer validation because truthiness checks on these objects don't behave like regular strings/None. The fix correctly switches to is not None checks and adds a PipelineVariable guard in _get_repo_name_from_image. The approach is sound, but there are a few issues worth addressing.
| def _get_repo_name_from_image(image: str) -> str: | ||
| def _get_repo_name_from_image(image: "str | PipelineVariable") -> "str | None": |
The type annotation uses PEP 604 string-quoted syntax ("str | PipelineVariable") for the parameter, but the module doesn't use from __future__ import annotations. While this works at runtime because both the parameter and return annotations are quoted, quoted PEP 604 unions are inconsistent with the annotation style used elsewhere. Per SDK conventions, prefer adding from __future__ import annotations at the top of the module and using unquoted str | PipelineVariable and str | None, or use Union[str, PipelineVariable] / Optional[str] consistently with the rest of the codebase (as done in model_trainer.py).
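To illustrate what the quoting changes: a quoted annotation is stored as a plain string until a tool resolves it, while an unquoted typing annotation is evaluated when the function is defined. The sketch below uses a stand-in PipelineVariable class and hypothetical function names, and sticks to Union/Optional so it runs on Python versions before 3.10:

```python
from typing import Optional, Union, get_type_hints


class PipelineVariable:
    """Stand-in for the SDK's PipelineVariable (assumption)."""


def repo_name_quoted(image: "Union[str, PipelineVariable]") -> "Optional[str]":
    """Quoted annotations: stored as strings, resolved lazily."""
    return None if isinstance(image, PipelineVariable) else image


def repo_name_unquoted(image: Union[str, PipelineVariable]) -> Optional[str]:
    """Unquoted annotations: evaluated at definition time."""
    return None if isinstance(image, PipelineVariable) else image


# The raw annotation of the quoted variant is still a string...
raw_annotation = repo_name_quoted.__annotations__["image"]

# ...but both variants resolve to the same types once evaluated.
hints_match = get_type_hints(repo_name_quoted) == get_type_hints(repo_name_unquoted)
```

Either style type-checks the same way; the point of the comment is only to pick one and use it consistently.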
| from typing import Literal, Any | ||
| from sagemaker.core.helper.session_helper import Session | ||
| from sagemaker.core.helper.pipeline_variable import PipelineVariable |
The import was changed from sagemaker.core.workflow.parameters.PipelineVariable to sagemaker.core.helper.pipeline_variable.PipelineVariable. Please verify these are the same class (or that the helper version is the correct base class that ParameterString/ParameterInteger inherit from). If sagemaker.core.workflow.parameters.PipelineVariable is a re-export of the same class, this is fine, but if they're different classes, the isinstance check on line 163 might not catch all pipeline variable types.
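The failure mode described here can be shown with stand-in classes. Everything below is hypothetical and does not import the SDK: it only demonstrates that a re-export keeps isinstance working, while a separately defined class with the same role does not.

```python
class HelperPipelineVariable:
    """Stand-in for the class imported from the helper module (assumption)."""


# A re-export is just another name bound to the same class object.
ReExportedPipelineVariable = HelperPipelineVariable


class SeparatePipelineVariable:
    """Stand-in for a hypothetical, independently defined class."""


class ParameterString(HelperPipelineVariable):
    """Stand-in parameter type that inherits from the helper base class."""


def is_same_class(first: type, second: type) -> bool:
    """True only when both names refer to the very same class object."""
    return first is second


param = ParameterString()
caught_by_reexport = isinstance(param, ReExportedPipelineVariable)
caught_by_separate = isinstance(param, SeparatePipelineVariable)
```

Against the real SDK, the equivalent check would be to import PipelineVariable from both module paths and compare the two names with `is`, and to confirm with issubclass that ParameterString and ParameterInteger derive from it; that check is not run here.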
| def _validate_training_image_and_algorithm_name( | ||
| self, training_image: Optional[str], algorithm_name: Optional[str] | ||
| self, |
Good fix switching from truthiness (not training_image) to identity (is not None) checks. The truthiness check would evaluate a PipelineVariable object as truthy but could also fail for empty strings. The is not None approach is more correct and explicit.
Minor: The type annotation uses Union[str, PipelineVariable, None] — consider whether Optional[Union[str, PipelineVariable]] or str | PipelineVariable | None (with from __future__ import annotations) would be more consistent with the SDK's preference for newer syntax in new code.
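A small sketch of how the two kinds of check differ, using a stand-in PipelineVariable class and hypothetical helper names. It makes no claim about how the real PipelineVariable behaves under bool(); it only shows that the two checks disagree for an empty string and agree for None and for an ordinary object:

```python
class PipelineVariable:
    """Stand-in for the SDK's PipelineVariable (assumption)."""


def provided_by_truthiness(value) -> bool:
    """Treats every falsy value (None, "", 0, empty containers) as missing."""
    return bool(value)


def provided_by_identity(value) -> bool:
    """Treats only None as missing."""
    return value is not None


samples = {
    "none": None,
    "empty string": "",
    "image uri": "my-repo:latest",
    "pipeline variable": PipelineVariable(),
}

# Collect the labels where the two checks disagree.
disagreements = [
    label
    for label, value in samples.items()
    if provided_by_truthiness(value) != provided_by_identity(value)
]
```

With the identity check, an explicitly passed empty string counts as "provided", so rejecting it (if desired) needs its own validation step.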
| def test_safe_serialize_parameter_string_returns_variable(self): | ||
| """safe_serialize should return ParameterString as-is.""" | ||
| param = ParameterString(name="MyParam", default_value="val") | ||
| result = safe_serialize(param) |
The TestSafeSerialize tests are good additions, but I don't see any corresponding changes to the safe_serialize function in utils.py in this diff. If safe_serialize already handles PipelineVariable correctly, these tests are documenting existing behavior (which is fine). But if the bug also affects safe_serialize and it needs a fix, that change appears to be missing from this PR. Could you confirm that safe_serialize already has the PipelineVariable guard?
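For reference, the kind of guard being asked about might look like the sketch below. This is an assumption about the shape of such a guard, not the contents of utils.py: the PipelineVariable class is a stand-in, and the fallback behaviour for other types is hypothetical.

```python
import json
from typing import Any


class PipelineVariable:
    """Stand-in for the SDK's PipelineVariable (assumption)."""


def safe_serialize(data: Any) -> Any:
    """Hypothetical sketch of a serializer that passes pipeline variables through."""
    if isinstance(data, PipelineVariable):
        # Leave unresolved pipeline variables untouched; they are resolved at execution time.
        return data
    if isinstance(data, str):
        # Strings are returned as-is rather than JSON-quoted.
        return data
    try:
        return json.dumps(data)
    except TypeError:
        # Fall back to str() for objects json cannot encode.
        return str(data)
```

If the real function already has an equivalent first branch, the new tests simply document existing behaviour.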
| """Tests for _validate_training_image_and_algorithm_name with PipelineVariable.""" | ||
| def test_pipeline_variable_training_image_passes_validation(self): | ||
| """PipelineVariable as training_image should pass validation.""" |
Good test coverage for the validation fix. Consider also adding a test case where training_image is a ParameterString with an empty string default value (default_value="") to verify the is not None check doesn't regress for edge cases where the resolved value might be empty.
Description
No response from agent
Related Issue
Related issue: 5504
Changes Made
- sagemaker-train/src/sagemaker/train/utils.py
- sagemaker-train/src/sagemaker/train/model_trainer.py
- sagemaker-train/tests/unit/train/test_model_trainer_pipeline_variable.py
AI-Generated PR
This PR was automatically generated by the PySDK Issue Agent.
Merge Checklist
prefix: description format