From b22e8a641c3d4065d9dd7321bacfb0bf8192e8dc Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 27 Mar 2026 16:59:03 -0400 Subject: [PATCH 1/3] fix: ModelTrainer and HyperparameterTuner missing environment variables (5613) --- sagemaker-train/src/sagemaker/train/tuner.py | 17 ++++ .../tests/unit/train/test_tuner.py | 91 ++++++++++++++++++- 2 files changed, 107 insertions(+), 1 deletion(-) diff --git a/sagemaker-train/src/sagemaker/train/tuner.py b/sagemaker-train/src/sagemaker/train/tuner.py index cde1598481..83d5eeae6f 100644 --- a/sagemaker-train/src/sagemaker/train/tuner.py +++ b/sagemaker-train/src/sagemaker/train/tuner.py @@ -442,6 +442,23 @@ def _prepare_auto_parameters(self, static_hyperparameters, hyperparameters_to_ke return new_static_hyperparameters, auto_parameters + @staticmethod + def _get_model_trainer_environment(model_trainer): + """Extract environment variables from a ModelTrainer instance. + + Returns the environment dict if it is non-empty, otherwise None. + + Args: + model_trainer: ModelTrainer instance + + Returns: + dict or None: Environment variables dict, or None if empty/not set. + """ + env = getattr(model_trainer, "environment", None) + if env: + return dict(env) + return None + @classmethod def _prepare_model_trainer_for_tuning(cls, model_trainer, inputs=None, job_name=None, **kwargs): """Prepare ModelTrainer before tuning by building sm_drivers and code channels. diff --git a/sagemaker-train/tests/unit/train/test_tuner.py b/sagemaker-train/tests/unit/train/test_tuner.py index c0255eac47..722b8e5172 100644 --- a/sagemaker-train/tests/unit/train/test_tuner.py +++ b/sagemaker-train/tests/unit/train/test_tuner.py @@ -39,12 +39,13 @@ # --------------------------------------------------------------------------- -def _create_mock_model_trainer(with_internal_channels=False): +def _create_mock_model_trainer(with_internal_channels=False, environment=None): """Create a mock ModelTrainer with common attributes. Args: with_internal_channels: If True, adds internal channels (code, sm_drivers) to input_data_config for testing channel inclusion in tuning jobs. + environment: Optional dict of environment variables to set on the trainer. """ trainer = MagicMock() trainer.sagemaker_session = MagicMock() @@ -61,6 +62,7 @@ def _create_mock_model_trainer(with_internal_channels=False): trainer.stopping_condition = MagicMock() trainer.stopping_condition.max_runtime_in_seconds = 3600 trainer.input_data_config = None + trainer.environment = environment if environment is not None else {} if with_internal_channels: trainer.input_data_config = [ @@ -574,3 +576,90 @@ def test_build_training_job_definition_includes_internal_channels(self): assert "train" in channel_names, "User 'train' channel should be included" assert "validation" in channel_names, "User 'validation' channel should be included" assert len(channel_names) == 4, "Should have exactly 4 channels" + + def test_build_training_job_definition_includes_environment_variables(self): + """Test that _build_training_job_definition includes environment variables. + + This test verifies the fix for GitHub issue #5613 where tuning jobs were missing + environment variables that were set on the ModelTrainer. + """ + env_vars = {"RANDOM_STATE": "42", "MY_VAR": "hello"} + mock_trainer = _create_mock_model_trainer(environment=env_vars) + + tuner = HyperparameterTuner( + model_trainer=mock_trainer, + objective_metric_name="accuracy", + hyperparameter_ranges=_create_single_hp_range(), + ) + + definition = tuner._build_training_job_definition(None) + + # The definition should contain the environment variables + assert hasattr(definition, "environment") or hasattr(definition, "Environment"), \ + "Training job definition should have environment attribute" + definition_env = getattr(definition, "environment", None) or getattr(definition, "Environment", None) + assert definition_env == env_vars, \ + f"Environment should be {env_vars}, got {definition_env}" + + def test_build_training_job_definition_with_empty_environment(self): + """Test that _build_training_job_definition handles empty environment.""" + mock_trainer = _create_mock_model_trainer(environment={}) + + tuner = HyperparameterTuner( + model_trainer=mock_trainer, + objective_metric_name="accuracy", + hyperparameter_ranges=_create_single_hp_range(), + ) + + # Should not raise an error + definition = tuner._build_training_job_definition(None) + assert definition is not None + + def test_build_training_job_definition_with_none_environment(self): + """Test that _build_training_job_definition handles None environment.""" + mock_trainer = _create_mock_model_trainer() + mock_trainer.environment = None + + tuner = HyperparameterTuner( + model_trainer=mock_trainer, + objective_metric_name="accuracy", + hyperparameter_ranges=_create_single_hp_range(), + ) + + # Should not raise an error + definition = tuner._build_training_job_definition(None) + assert definition is not None + + +class TestGetModelTrainerEnvironment: + """Test _get_model_trainer_environment helper method.""" + + def test_returns_environment_when_set(self): + """Test that environment is returned when set on model trainer.""" + env_vars = {"KEY1": "val1", "KEY2": "val2"} + mock_trainer = _create_mock_model_trainer(environment=env_vars) + + result = HyperparameterTuner._get_model_trainer_environment(mock_trainer) + assert result == env_vars + + def test_returns_none_when_empty(self): + """Test that None is returned when environment is empty.""" + mock_trainer = _create_mock_model_trainer(environment={}) + + result = HyperparameterTuner._get_model_trainer_environment(mock_trainer) + assert result is None + + def test_returns_none_when_none(self): + """Test that None is returned when environment is None.""" + mock_trainer = _create_mock_model_trainer() + mock_trainer.environment = None + + result = HyperparameterTuner._get_model_trainer_environment(mock_trainer) + assert result is None + + def test_returns_none_when_attribute_missing(self): + """Test that None is returned when environment attribute doesn't exist.""" + mock_trainer = MagicMock(spec=[]) + + result = HyperparameterTuner._get_model_trainer_environment(mock_trainer) + assert result is None From 3c994881c2daaac76c01f8c5d438923069760074 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 27 Mar 2026 17:04:07 -0400 Subject: [PATCH 2/3] fix: address review comments (iteration #1) --- sagemaker-train/src/sagemaker/train/tuner.py | 15 +- .../tests/unit/train/test_tuner.py | 147 +++++++++++++++--- 2 files changed, 133 insertions(+), 29 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/tuner.py b/sagemaker-train/src/sagemaker/train/tuner.py index 83d5eeae6f..ac5e369271 100644 --- a/sagemaker-train/src/sagemaker/train/tuner.py +++ b/sagemaker-train/src/sagemaker/train/tuner.py @@ -443,18 +443,21 @@ def _prepare_auto_parameters(self, static_hyperparameters, hyperparameters_to_ke return new_static_hyperparameters, auto_parameters @staticmethod - def _get_model_trainer_environment(model_trainer): + def _get_model_trainer_environment( + model_trainer: "ModelTrainer", + ) -> Optional[Dict[str, str]]: """Extract environment variables from a ModelTrainer instance. Returns the environment dict if it is non-empty, otherwise None. Args: - model_trainer: ModelTrainer instance + model_trainer (ModelTrainer): ModelTrainer instance. Returns: - dict or None: Environment variables dict, or None if empty/not set. + Optional[Dict[str, str]]: Environment variables dict, + or None if empty/not set. """ - env = getattr(model_trainer, "environment", None) + env = model_trainer.environment if env: return dict(env) return None @@ -1530,8 +1533,8 @@ def _build_training_job_definition(self, inputs): ) # Pass through environment variables from model_trainer - env = getattr(model_trainer, "environment", None) - if env and isinstance(env, dict): + env = self._get_model_trainer_environment(model_trainer) + if env is not None: definition.environment = env # Pass through VPC config from model_trainer diff --git a/sagemaker-train/tests/unit/train/test_tuner.py b/sagemaker-train/tests/unit/train/test_tuner.py index 722b8e5172..23d429ee9b 100644 --- a/sagemaker-train/tests/unit/train/test_tuner.py +++ b/sagemaker-train/tests/unit/train/test_tuner.py @@ -578,13 +578,15 @@ def test_build_training_job_definition_includes_internal_channels(self): assert len(channel_names) == 4, "Should have exactly 4 channels" def test_build_training_job_definition_includes_environment_variables(self): - """Test that _build_training_job_definition includes environment variables. + """Test that _build_training_job_definition includes env vars. - This test verifies the fix for GitHub issue #5613 where tuning jobs were missing - environment variables that were set on the ModelTrainer. + This test verifies the fix for GitHub issue #5613 where tuning + jobs were missing environment variables set on the ModelTrainer. """ env_vars = {"RANDOM_STATE": "42", "MY_VAR": "hello"} - mock_trainer = _create_mock_model_trainer(environment=env_vars) + mock_trainer = _create_mock_model_trainer( + environment=env_vars, + ) tuner = HyperparameterTuner( model_trainer=mock_trainer, @@ -595,14 +597,13 @@ def test_build_training_job_definition_includes_environment_variables(self): definition = tuner._build_training_job_definition(None) # The definition should contain the environment variables - assert hasattr(definition, "environment") or hasattr(definition, "Environment"), \ - "Training job definition should have environment attribute" - definition_env = getattr(definition, "environment", None) or getattr(definition, "Environment", None) - assert definition_env == env_vars, \ - f"Environment should be {env_vars}, got {definition_env}" + assert definition.environment == env_vars, ( + f"Environment should be {env_vars}, " + f"got {definition.environment}" + ) def test_build_training_job_definition_with_empty_environment(self): - """Test that _build_training_job_definition handles empty environment.""" + """Test that empty env is not propagated to definition.""" mock_trainer = _create_mock_model_trainer(environment={}) tuner = HyperparameterTuner( @@ -611,12 +612,17 @@ def test_build_training_job_definition_with_empty_environment(self): hyperparameter_ranges=_create_single_hp_range(), ) - # Should not raise an error definition = tuner._build_training_job_definition(None) assert definition is not None + # Empty environment should not be set on the definition + env = getattr(definition, "environment", None) + assert env is None, ( + "Empty environment should not be propagated, " + f"got {env}" + ) def test_build_training_job_definition_with_none_environment(self): - """Test that _build_training_job_definition handles None environment.""" + """Test that None env is not propagated to definition.""" mock_trainer = _create_mock_model_trainer() mock_trainer.environment = None @@ -626,27 +632,38 @@ def test_build_training_job_definition_with_none_environment(self): hyperparameter_ranges=_create_single_hp_range(), ) - # Should not raise an error definition = tuner._build_training_job_definition(None) assert definition is not None + # None environment should not be set on the definition + env = getattr(definition, "environment", None) + assert env is None, ( + "None environment should not be propagated, " + f"got {env}" + ) class TestGetModelTrainerEnvironment: """Test _get_model_trainer_environment helper method.""" def test_returns_environment_when_set(self): - """Test that environment is returned when set on model trainer.""" + """Test that environment is returned when set.""" env_vars = {"KEY1": "val1", "KEY2": "val2"} - mock_trainer = _create_mock_model_trainer(environment=env_vars) + mock_trainer = _create_mock_model_trainer( + environment=env_vars, + ) - result = HyperparameterTuner._get_model_trainer_environment(mock_trainer) + result = HyperparameterTuner._get_model_trainer_environment( + mock_trainer, + ) assert result == env_vars def test_returns_none_when_empty(self): """Test that None is returned when environment is empty.""" mock_trainer = _create_mock_model_trainer(environment={}) - result = HyperparameterTuner._get_model_trainer_environment(mock_trainer) + result = HyperparameterTuner._get_model_trainer_environment( + mock_trainer, + ) assert result is None def test_returns_none_when_none(self): @@ -654,12 +671,96 @@ def test_returns_none_when_none(self): mock_trainer = _create_mock_model_trainer() mock_trainer.environment = None - result = HyperparameterTuner._get_model_trainer_environment(mock_trainer) + result = HyperparameterTuner._get_model_trainer_environment( + mock_trainer, + ) assert result is None - def test_returns_none_when_attribute_missing(self): - """Test that None is returned when environment attribute doesn't exist.""" - mock_trainer = MagicMock(spec=[]) - result = HyperparameterTuner._get_model_trainer_environment(mock_trainer) - assert result is None +class TestMultiTrainerEnvironmentPropagation: + """Test environment propagation for multi-trainer tuning jobs.""" + + def test_create_multi_trainer_with_environment(self): + """Test that environment is preserved on trainers in create().""" + env1 = {"VAR_A": "1"} + env2 = {"VAR_B": "2"} + trainer1 = _create_mock_model_trainer(environment=env1) + trainer2 = _create_mock_model_trainer(environment=env2) + + tuner = HyperparameterTuner.create( + model_trainer_dict={ + "trainer1": trainer1, + "trainer2": trainer2, + }, + objective_metric_name_dict={ + "trainer1": "accuracy", + "trainer2": "loss", + }, + hyperparameter_ranges_dict={ + "trainer1": _create_single_hp_range(), + "trainer2": _create_single_hp_range(), + }, + ) + + # Verify environment is preserved on each trainer + assert tuner.model_trainer_dict["trainer1"].environment == env1 + assert tuner.model_trainer_dict["trainer2"].environment == env2 + + def test_get_environment_for_each_trainer_in_dict(self): + """Test _get_model_trainer_environment for each trainer.""" + env1 = {"VAR_A": "1"} + env2 = {"VAR_B": "2"} + trainer1 = _create_mock_model_trainer(environment=env1) + trainer2 = _create_mock_model_trainer(environment=env2) + + tuner = HyperparameterTuner.create( + model_trainer_dict={ + "trainer1": trainer1, + "trainer2": trainer2, + }, + objective_metric_name_dict={ + "trainer1": "accuracy", + "trainer2": "loss", + }, + hyperparameter_ranges_dict={ + "trainer1": _create_single_hp_range(), + "trainer2": _create_single_hp_range(), + }, + ) + + for name, mt in tuner.model_trainer_dict.items(): + env = HyperparameterTuner._get_model_trainer_environment( + mt, + ) + if name == "trainer1": + assert env == env1 + elif name == "trainer2": + assert env == env2 + + def test_multi_trainer_empty_environment(self): + """Test multi-trainer with empty environment.""" + trainer1 = _create_mock_model_trainer(environment={}) + trainer2 = _create_mock_model_trainer(environment={}) + + tuner = HyperparameterTuner.create( + model_trainer_dict={ + "trainer1": trainer1, + "trainer2": trainer2, + }, + objective_metric_name_dict={ + "trainer1": "accuracy", + "trainer2": "loss", + }, + hyperparameter_ranges_dict={ + "trainer1": _create_single_hp_range(), + "trainer2": _create_single_hp_range(), + }, + ) + + for _name, mt in tuner.model_trainer_dict.items(): + env = HyperparameterTuner._get_model_trainer_environment( + mt, + ) + assert env is None, ( + "Empty environment should return None" + ) From 0be1e83b4ab084a5f183570cdeef908b923b2007 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 27 Mar 2026 17:13:30 -0400 Subject: [PATCH 3/3] fix: address review comments (iteration #2) --- sagemaker-train/src/sagemaker/train/tuner.py | 11 ++++++----- sagemaker-train/tests/unit/train/test_tuner.py | 9 +++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/tuner.py b/sagemaker-train/src/sagemaker/train/tuner.py index ac5e369271..8b05b9e9e1 100644 --- a/sagemaker-train/src/sagemaker/train/tuner.py +++ b/sagemaker-train/src/sagemaker/train/tuner.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. """Placeholder docstring""" -from __future__ import absolute_import +from __future__ import absolute_import, annotations import logging from enum import Enum @@ -445,17 +445,18 @@ def _prepare_auto_parameters(self, static_hyperparameters, hyperparameters_to_ke @staticmethod def _get_model_trainer_environment( model_trainer: "ModelTrainer", - ) -> Optional[Dict[str, str]]: + ) -> dict[str, str] | None: """Extract environment variables from a ModelTrainer instance. - Returns the environment dict if it is non-empty, otherwise None. + Returns a copy of the environment dict if it is non-empty, + otherwise None. Args: model_trainer (ModelTrainer): ModelTrainer instance. Returns: - Optional[Dict[str, str]]: Environment variables dict, - or None if empty/not set. + dict[str, str] | None: A copy of the environment variables + dict, or None if empty/not set. """ env = model_trainer.environment if env: diff --git a/sagemaker-train/tests/unit/train/test_tuner.py b/sagemaker-train/tests/unit/train/test_tuner.py index 23d429ee9b..1880759809 100644 --- a/sagemaker-train/tests/unit/train/test_tuner.py +++ b/sagemaker-train/tests/unit/train/test_tuner.py @@ -601,6 +601,11 @@ def test_build_training_job_definition_includes_environment_variables(self): f"Environment should be {env_vars}, " f"got {definition.environment}" ) + # Verify defensive copy: the dict on the definition + # should not be the same object as the original + assert definition.environment is not env_vars, ( + "Environment should be a copy, not the same object" + ) def test_build_training_job_definition_with_empty_environment(self): """Test that empty env is not propagated to definition.""" @@ -656,6 +661,10 @@ def test_returns_environment_when_set(self): mock_trainer, ) assert result == env_vars + # Verify it's a copy, not the same object + assert result is not env_vars, ( + "Should return a defensive copy" + ) def test_returns_none_when_empty(self): """Test that None is returned when environment is empty."""