Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sagemaker-mlops/src/sagemaker/mlops/workflow/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def _find_dependencies_in_step_arguments(
else:
dependencies.add(self._get_step_name_from_str(referenced_step, step_map))

from sagemaker.core.workflow.function_step import DelayedReturn
from sagemaker.mlops.workflow.function_step import DelayedReturn

# TODO: we can remove the if-elif once move the validators to JsonGet constructor
if isinstance(pipeline_variable, JsonGet):
Expand Down
13 changes: 10 additions & 3 deletions sagemaker-mlops/tests/unit/workflow/test_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def test_step_find_dependencies_in_step_arguments_with_json_get():
obj = {"key": json_get}

with patch('sagemaker.mlops.workflow.steps.TYPE_CHECKING', False):
with patch.dict('sys.modules', {'sagemaker.core.workflow.function_step': Mock()}):
with patch.dict('sys.modules', {'sagemaker.mlops.workflow.function_step': Mock()}):
dependencies = Step._find_dependencies_in_step_arguments(step2, obj, {"step1": step1})
assert "step1" in dependencies

Expand Down Expand Up @@ -445,7 +445,7 @@ def test_step_find_dependencies_in_step_arguments_with_delayed_return():
mock_module = Mock()
mock_module.DelayedReturn = delayed_return_class

with patch.dict('sys.modules', {'sagemaker.core.workflow.function_step': mock_module}):
with patch.dict('sys.modules', {'sagemaker.mlops.workflow.function_step': mock_module}):
dependencies = Step._find_dependencies_in_step_arguments(step2, obj, {"step1": step1})
assert "step1" in dependencies

Expand Down Expand Up @@ -473,11 +473,18 @@ def test_step_find_dependencies_in_step_arguments_with_string_reference():
mock_module = Mock()
mock_module.DelayedReturn = delayed_return_class

with patch.dict('sys.modules', {'sagemaker.core.workflow.function_step': mock_module}):
with patch.dict('sys.modules', {'sagemaker.mlops.workflow.function_step': mock_module}):
dependencies = Step._find_dependencies_in_step_arguments(step2, obj, step_map)
assert "step1" in dependencies


def test_delayed_return_import_from_correct_module():
"""Verify that DelayedReturn can be imported from sagemaker.mlops.workflow.function_step."""
from sagemaker.mlops.workflow.function_step import DelayedReturn
assert DelayedReturn is not None
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test performs an actual import of DelayedReturn from the real module, which makes it more of a smoke/integration test than a unit test. If sagemaker.mlops.workflow.function_step has heavy dependencies that aren't available in the unit test environment, this could fail unexpectedly.

Also, the assertion hasattr(DelayedReturn, '_to_json_get') couples this test to a private implementation detail. If the private method is renamed or removed, this test will break even though the import fix is still valid. Consider either:

  1. Removing the hasattr check and just verifying the import succeeds, or
  2. Checking a more stable public attribute if one exists.
def test_delayed_return_import_from_correct_module():
    """Verify that DelayedReturn can be imported from sagemaker.mlops.workflow.function_step."""
    from sagemaker.mlops.workflow.function_step import DelayedReturn
    assert DelayedReturn is not None

assert hasattr(DelayedReturn, '_to_json_get')


def test_tuning_step_requires_step_args():
from sagemaker.mlops.workflow.steps import TuningStep

Expand Down
Loading