Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions sagemaker-train/src/sagemaker/train/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# language governing permissions and limitations under the License.
"""Placeholder docstring"""

from __future__ import absolute_import
from __future__ import absolute_import, annotations

import logging
from enum import Enum
Expand Down Expand Up @@ -442,6 +442,27 @@ def _prepare_auto_parameters(self, static_hyperparameters, hyperparameters_to_ke

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor design question: This static method is essentially a 4-line null-safe copy. It's only called in one place (_build_training_job_definition). Is extracting it as a separate static method warranted? It adds indirection without much reuse benefit. If the intent is to also use it in _build_training_job_definitions (the multi-trainer path), that usage is currently missing. If it's only used once, consider inlining the logic.

return new_static_hyperparameters, auto_parameters

@staticmethod
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type annotation uses a forward reference string "ModelTrainer" but ModelTrainer is never imported (even conditionally under TYPE_CHECKING). For proper type checking support, consider adding:

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from sagemaker.train.model_trainer import ModelTrainer

This avoids a runtime import cycle while enabling static analysis tools to resolve the type.

def _get_model_trainer_environment(
model_trainer: "ModelTrainer",
) -> dict[str, str] | None:
"""Extract environment variables from a ModelTrainer instance.

Returns a copy of the environment dict if it is non-empty,
otherwise None.

Args:
model_trainer (ModelTrainer): ModelTrainer instance.

Returns:
dict[str, str] | None: A copy of the environment variables
dict, or None if empty/not set.
"""
env = model_trainer.environment
if env:
return dict(env)
return None

@classmethod
def _prepare_model_trainer_for_tuning(cls, model_trainer, inputs=None, job_name=None, **kwargs):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR description states: "Both the single-trainer path (_build_training_job_definition) and the multi-trainer path (_build_training_job_definitions) need this fix." However, only the single-trainer path (_build_training_job_definition) is modified here. Is the multi-trainer path (_build_training_job_definitions) already handling environment variables correctly, or is this fix incomplete? If the multi-trainer path also needs the fix, please add the corresponding change and a test that calls _build_training_job_definitions and verifies definition.environment is set on each resulting definition.

"""Prepare ModelTrainer before tuning by building sm_drivers and code channels.
Expand Down Expand Up @@ -1513,8 +1534,8 @@ def _build_training_job_definition(self, inputs):
)

# Pass through environment variables from model_trainer
env = getattr(model_trainer, "environment", None)
if env and isinstance(env, dict):
env = self._get_model_trainer_environment(model_trainer)
if env is not None:
definition.environment = env

# Pass through VPC config from model_trainer
Expand Down
201 changes: 200 additions & 1 deletion sagemaker-train/tests/unit/train/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@
# ---------------------------------------------------------------------------


def _create_mock_model_trainer(with_internal_channels=False):
def _create_mock_model_trainer(with_internal_channels=False, environment=None):
"""Create a mock ModelTrainer with common attributes.

Args:
with_internal_channels: If True, adds internal channels (code, sm_drivers)
to input_data_config for testing channel inclusion in tuning jobs.
environment: Optional dict of environment variables to set on the trainer.
"""
trainer = MagicMock()
trainer.sagemaker_session = MagicMock()
Expand All @@ -61,6 +62,7 @@ def _create_mock_model_trainer(with_internal_channels=False):
trainer.stopping_condition = MagicMock()
trainer.stopping_condition.max_runtime_in_seconds = 3600
trainer.input_data_config = None
trainer.environment = environment if environment is not None else {}

if with_internal_channels:
trainer.input_data_config = [
Expand Down Expand Up @@ -574,3 +576,200 @@ def test_build_training_job_definition_includes_internal_channels(self):
assert "train" in channel_names, "User 'train' channel should be included"
assert "validation" in channel_names, "User 'validation' channel should be included"
assert len(channel_names) == 4, "Should have exactly 4 channels"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good test — this directly validates the fix for issue #5613 by calling _build_training_job_definition and checking definition.environment. The defensive copy assertion is a nice touch.


def test_build_training_job_definition_includes_environment_variables(self):
    """Verify env vars flow into the built training job definition.

    Regression test for GitHub issue #5613, where tuning jobs were
    missing environment variables set on the ModelTrainer.
    """
    expected_env = {"RANDOM_STATE": "42", "MY_VAR": "hello"}
    trainer = _create_mock_model_trainer(environment=expected_env)

    tuner = HyperparameterTuner(
        model_trainer=trainer,
        objective_metric_name="accuracy",
        hyperparameter_ranges=_create_single_hp_range(),
    )

    definition = tuner._build_training_job_definition(None)

    # The environment variables must be carried onto the definition.
    actual = definition.environment
    assert actual == expected_env, (
        f"Environment should be {expected_env}, got {actual}"
    )
    # Defensive-copy check: the definition must not alias the input dict.
    assert actual is not expected_env, (
        "Environment should be a copy, not the same object"
    )

def test_build_training_job_definition_with_empty_environment(self):
    """An empty environment dict must not be copied onto the definition."""
    trainer = _create_mock_model_trainer(environment={})

    tuner = HyperparameterTuner(
        model_trainer=trainer,
        objective_metric_name="accuracy",
        hyperparameter_ranges=_create_single_hp_range(),
    )

    definition = tuner._build_training_job_definition(None)
    assert definition is not None

    # An empty dict is falsy and should leave the attribute unset.
    propagated = getattr(definition, "environment", None)
    assert propagated is None, (
        f"Empty environment should not be propagated, got {propagated}"
    )

def test_build_training_job_definition_with_none_environment(self):
    """A None environment must not be copied onto the definition."""
    trainer = _create_mock_model_trainer()
    trainer.environment = None

    tuner = HyperparameterTuner(
        model_trainer=trainer,
        objective_metric_name="accuracy",
        hyperparameter_ranges=_create_single_hp_range(),
    )

    definition = tuner._build_training_job_definition(None)
    assert definition is not None

    # None should leave the definition's environment attribute unset.
    propagated = getattr(definition, "environment", None)
    assert propagated is None, (
        f"None environment should not be propagated, got {propagated}"
    )


class TestGetModelTrainerEnvironment:
    """Unit tests for the _get_model_trainer_environment helper."""

    def test_returns_environment_when_set(self):
        """A populated environment is returned as a defensive copy."""
        source_env = {"KEY1": "val1", "KEY2": "val2"}
        trainer = _create_mock_model_trainer(environment=source_env)

        returned = HyperparameterTuner._get_model_trainer_environment(trainer)

        assert returned == source_env
        # Identity must differ: the helper returns a copy, not the original.
        assert returned is not source_env, "Should return a defensive copy"

    def test_returns_none_when_empty(self):
        """An empty environment dict yields None."""
        trainer = _create_mock_model_trainer(environment={})

        returned = HyperparameterTuner._get_model_trainer_environment(trainer)

        assert returned is None

    def test_returns_none_when_none(self):
        """A None environment yields None."""
        trainer = _create_mock_model_trainer()
        trainer.environment = None

        returned = HyperparameterTuner._get_model_trainer_environment(trainer)

        assert returned is None


class TestMultiTrainerEnvironmentPropagation:
    """Test environment propagation for multi-trainer tuning jobs.

    NOTE(review): these tests verify that each trainer's environment
    survives HyperparameterTuner.create() and is readable through
    _get_model_trainer_environment; they do not exercise the
    _build_training_job_definitions build path — TODO confirm coverage.
    """

    def test_create_multi_trainer_with_environment(self):
        """Test that environment is preserved on trainers in create()."""
        env1 = {"VAR_A": "1"}
        env2 = {"VAR_B": "2"}
        trainer1 = _create_mock_model_trainer(environment=env1)
        trainer2 = _create_mock_model_trainer(environment=env2)

        tuner = HyperparameterTuner.create(
            model_trainer_dict={
                "trainer1": trainer1,
                "trainer2": trainer2,
            },
            objective_metric_name_dict={
                "trainer1": "accuracy",
                "trainer2": "loss",
            },
            hyperparameter_ranges_dict={
                "trainer1": _create_single_hp_range(),
                "trainer2": _create_single_hp_range(),
            },
        )

        # Each trainer registered on the tuner keeps its own environment.
        assert tuner.model_trainer_dict["trainer1"].environment == env1
        assert tuner.model_trainer_dict["trainer2"].environment == env2

    def test_get_environment_for_each_trainer_in_dict(self):
        """Test _get_model_trainer_environment for each trainer."""
        env1 = {"VAR_A": "1"}
        env2 = {"VAR_B": "2"}
        trainer1 = _create_mock_model_trainer(environment=env1)
        trainer2 = _create_mock_model_trainer(environment=env2)

        tuner = HyperparameterTuner.create(
            model_trainer_dict={
                "trainer1": trainer1,
                "trainer2": trainer2,
            },
            objective_metric_name_dict={
                "trainer1": "accuracy",
                "trainer2": "loss",
            },
            hyperparameter_ranges_dict={
                "trainer1": _create_single_hp_range(),
                "trainer2": _create_single_hp_range(),
            },
        )

        for name, model_trainer in tuner.model_trainer_dict.items():
            env = HyperparameterTuner._get_model_trainer_environment(
                model_trainer,
            )
            if name == "trainer1":
                assert env == env1
            elif name == "trainer2":
                assert env == env2

    def test_multi_trainer_empty_environment(self):
        """Test multi-trainer with empty environment."""
        trainer1 = _create_mock_model_trainer(environment={})
        trainer2 = _create_mock_model_trainer(environment={})

        tuner = HyperparameterTuner.create(
            model_trainer_dict={
                "trainer1": trainer1,
                "trainer2": trainer2,
            },
            objective_metric_name_dict={
                "trainer1": "accuracy",
                "trainer2": "loss",
            },
            hyperparameter_ranges_dict={
                "trainer1": _create_single_hp_range(),
                "trainer2": _create_single_hp_range(),
            },
        )

        # Empty environments normalize to None via the helper.
        for _name, model_trainer in tuner.model_trainer_dict.items():
            env = HyperparameterTuner._get_model_trainer_environment(
                model_trainer,
            )
            assert env is None, "Empty environment should return None"
Loading