diff --git a/sagemaker-core/src/sagemaker/core/workflow/utilities.py b/sagemaker-core/src/sagemaker/core/workflow/utilities.py index c07a31c51e..2c33c6ae7b 100644 --- a/sagemaker-core/src/sagemaker/core/workflow/utilities.py +++ b/sagemaker-core/src/sagemaker/core/workflow/utilities.py @@ -175,7 +175,9 @@ def get_code_hash(step: Entity) -> str: source_dir = source_code.source_dir requirements = source_code.requirements entry_point = source_code.entry_script - return get_training_code_hash(entry_point, source_dir, requirements) + return get_training_code_hash( + entry_point, source_dir, requirements + ) return None diff --git a/sagemaker-core/tests/unit/workflow/test_utilities.py b/sagemaker-core/tests/unit/workflow/test_utilities.py index 5e9ed7bbbd..418d1f11d7 100644 --- a/sagemaker-core/tests/unit/workflow/test_utilities.py +++ b/sagemaker-core/tests/unit/workflow/test_utilities.py @@ -24,6 +24,7 @@ get_processing_dependencies, get_processing_code_hash, get_training_code_hash, + get_code_hash, validate_step_args_input, override_pipeline_parameter_var, trim_request_dict, @@ -273,10 +274,14 @@ def test_get_training_code_hash_with_source_dir(self): requirements_file.write_text("numpy==1.21.0") result_no_deps = get_training_code_hash( - entry_point=str(entry_file), source_dir=temp_dir, dependencies=None + entry_point=str(entry_file), + source_dir=temp_dir, + dependencies=None, ) result_with_deps = get_training_code_hash( - entry_point=str(entry_file), source_dir=temp_dir, dependencies=str(requirements_file) + entry_point=str(entry_file), + source_dir=temp_dir, + dependencies=str(requirements_file), ) assert result_no_deps is not None @@ -285,6 +290,33 @@ def test_get_training_code_hash_with_source_dir(self): assert len(result_with_deps) == 64 assert result_no_deps != result_with_deps + def test_get_training_code_hash_source_dir_none_deps( + self, + ): + """Test get_training_code_hash with source_dir + and None dependencies does not raise TypeError. + """ + with tempfile.TemporaryDirectory() as temp_dir: + entry_file = Path(temp_dir, "train.py") + entry_file.write_text("print('training')") + + # Should NOT raise TypeError + result_none = get_training_code_hash( + entry_point=str(entry_file), + source_dir=temp_dir, + dependencies=None, + ) + # Empty list should be equivalent to None + result_empty = get_training_code_hash( + entry_point=str(entry_file), + source_dir=temp_dir, + dependencies=[], + ) + + assert result_none is not None + assert len(result_none) == 64 + assert result_none == result_empty + def test_get_training_code_hash_entry_point_only(self): """Test get_training_code_hash with entry_point only""" with tempfile.TemporaryDirectory() as temp_dir: @@ -295,11 +327,15 @@ def test_get_training_code_hash_entry_point_only(self): # Without dependencies result_no_deps = get_training_code_hash( - entry_point=str(entry_file), source_dir=None, dependencies=None + entry_point=str(entry_file), + source_dir=None, + dependencies=None, ) # With dependencies result_with_deps = get_training_code_hash( - entry_point=str(entry_file), source_dir=None, dependencies=str(requirements_file) + entry_point=str(entry_file), + source_dir=None, + dependencies=str(requirements_file), ) assert result_no_deps is not None @@ -308,6 +344,33 @@ def test_get_training_code_hash_entry_point_only(self): assert len(result_with_deps) == 64 assert result_no_deps != result_with_deps + def test_get_training_code_hash_entry_point_none_deps( + self, + ): + """Test get_training_code_hash with entry_point + and None dependencies does not raise TypeError. + """ + with tempfile.TemporaryDirectory() as temp_dir: + entry_file = Path(temp_dir, "train.py") + entry_file.write_text("print('training')") + + # Should NOT raise TypeError + result_none = get_training_code_hash( + entry_point=str(entry_file), + source_dir=None, + dependencies=None, + ) + # Empty list should be equivalent to None + result_empty = get_training_code_hash( + entry_point=str(entry_file), + source_dir=None, + dependencies=[], + ) + + assert result_none is not None + assert len(result_none) == 64 + assert result_none == result_empty + def test_get_training_code_hash_s3_uri(self): """Test get_training_code_hash with S3 URI returns None""" result = get_training_code_hash( @@ -325,6 +388,111 @@ def test_get_training_code_hash_pipeline_variable(self): assert result is None + def test_get_code_hash_training_step_no_requirements( + self, + ): + """Test get_code_hash with TrainingStep where + SourceCode has requirements=None. + """ + # Create a fake TrainingStep class to patch isinstance + FakeTrainingStep = type( + "TrainingStep", (), {} + ) + + with tempfile.TemporaryDirectory() as temp_dir: + entry_file = Path(temp_dir, "train.py") + entry_file.write_text("print('training')") + + mock_source_code = Mock() + mock_source_code.source_dir = temp_dir + mock_source_code.requirements = None + mock_source_code.entry_script = str(entry_file) + + mock_model_trainer = Mock() + mock_model_trainer.source_code = mock_source_code + + mock_step_args = Mock() + mock_step_args.func_args = [ + mock_model_trainer + ] + + mock_step = MagicMock(spec=FakeTrainingStep) + mock_step.step_args = mock_step_args + + with patch( + "sagemaker.core.workflow.utilities" + ".TrainingStep", + new=FakeTrainingStep, + ): + result = get_code_hash(mock_step) + + assert result is not None + assert len(result) == 64 + + def test_get_code_hash_training_step_with_requirements( + self, + ): + """Test get_code_hash with TrainingStep where + SourceCode has valid requirements. + """ + FakeTrainingStep = type( + "TrainingStep", (), {} + ) + + with tempfile.TemporaryDirectory() as temp_dir: + entry_file = Path(temp_dir, "train.py") + entry_file.write_text("print('training')") + req_file = Path(temp_dir, "requirements.txt") + req_file.write_text("numpy==1.21.0") + + mock_sc_no_req = Mock() + mock_sc_no_req.source_dir = temp_dir + mock_sc_no_req.requirements = None + mock_sc_no_req.entry_script = str(entry_file) + + mock_sc_with_req = Mock() + mock_sc_with_req.source_dir = temp_dir + mock_sc_with_req.requirements = str(req_file) + mock_sc_with_req.entry_script = str(entry_file) + + mock_mt_no_req = Mock() + mock_mt_no_req.source_code = mock_sc_no_req + + mock_mt_with_req = Mock() + mock_mt_with_req.source_code = mock_sc_with_req + + mock_step_no_req = MagicMock( + spec=FakeTrainingStep + ) + mock_step_no_req.step_args = Mock() + mock_step_no_req.step_args.func_args = [ + mock_mt_no_req + ] + + mock_step_with_req = MagicMock( + spec=FakeTrainingStep + ) + mock_step_with_req.step_args = Mock() + mock_step_with_req.step_args.func_args = [ + mock_mt_with_req + ] + + with patch( + "sagemaker.core.workflow.utilities" + ".TrainingStep", + new=FakeTrainingStep, + ): + result_no_req = get_code_hash( + mock_step_no_req + ) + result_with_req = get_code_hash( + mock_step_with_req + ) + + assert result_no_req is not None + assert result_with_req is not None + assert result_no_req != result_with_req + def test_validate_step_args_input_valid(self): """Test validate_step_args_input with valid input""" step_args = _StepArguments(