Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
"TRAINING_JOB_PREFIX = \"e2e-v3-pytorch\"\n",
"\n",
"# AWS Configuration\n",
"AWS_REGION = Session.boto_region_name\n",
"AWS_REGION = Session().boto_region_name\n",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as the other notebook — consider storing the Session() instance in a variable for reuse rather than creating a throwaway object:

sagemaker_session = Session()
AWS_REGION = sagemaker_session.boto_region_name

"PYTORCH_TRAINING_IMAGE = f\"763104351884.dkr.ecr.{AWS_REGION}.amazonaws.com/pytorch-training:1.13.1-cpu-py39\"\n",
"\n",
"# Generate unique identifiers\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
"MLFLOW_TRACKING_ARN = \"XXXXX\"\n",
"\n",
"# AWS Configuration\n",
"AWS_REGION = Session.boto_region_name\n",
"AWS_REGION = Session().boto_region_name\n",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While Session().boto_region_name is correct (it needs instantiation), this creates a throwaway Session object just to get the region. Consider using boto3.session.Session().region_name for a lighter-weight approach, or storing the Session() instance for reuse later in the notebook:

sagemaker_session = Session()
AWS_REGION = sagemaker_session.boto_region_name

This avoids creating multiple Session objects if it's used elsewhere in the notebook.

"\n",
"# Get PyTorch training image dynamically\n",
"PYTORCH_TRAINING_IMAGE = image_uris.retrieve(\n",
Expand Down Expand Up @@ -330,25 +330,33 @@
"outputs": [],
"source": [
"# Get the latest version of the registered model\n",
"# NOTE: MLflow 3.x removed `registered_model.latest_versions`. Use\n",
"# `client.search_model_versions()` instead.\n",
"from mlflow import MlflowClient\n",
"\n",
"client = MlflowClient()\n",
"registered_model = client.get_registered_model(name=MLFLOW_REGISTERED_MODEL_NAME)\n",
"\n",
"latest_version = registered_model.latest_versions[0]\n",
"# Search for the latest version of the registered model (MLflow 3.x compatible)\n",
"versions = client.search_model_versions(\n",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: The order_by parameter value 'version_number DESC' — please verify this is the correct field name for MLflow's search_model_versions. The MLflow documentation uses 'version_number DESC' in some examples but older versions may use different field names. Since this notebook targets mlflow==3.4.0, it would be good to confirm this works with that specific version.

" filter_string=f\"name='{MLFLOW_REGISTERED_MODEL_NAME}'\",\n",
" order_by=['version_number DESC'],\n",
" max_results=1\n",
")\n",
"\n",
"if not versions:\n",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good defensive check raising ValueError with a descriptive message including the model name. This follows the SDK convention of fail-fast validation with specific exceptions.

" raise ValueError(f\"No versions found for model '{MLFLOW_REGISTERED_MODEL_NAME}'\")\n",
"\n",
"latest_version = versions[0]\n",
"model_version = latest_version.version\n",
"model_source = latest_version.source\n",
"\n",
"# Get S3 URL of model files (for info only)\n",
"artifact_uri = client.get_model_version_download_uri(MLFLOW_REGISTERED_MODEL_NAME, model_version)\n",
"\n",
"# MLflow model registry path to use with ModelBuilder\n",
"mlflow_model_path = f\"models:/{MLFLOW_REGISTERED_MODEL_NAME}/{model_version}\"\n",
"\n",
"print(f\"Registered Model: {MLFLOW_REGISTERED_MODEL_NAME}\")\n",
"print(f\"Latest Version: {model_version}\")\n",
"print(f\"Source: {model_source}\")\n",
"print(f\"Model artifacts location: {artifact_uri}\")"
"print(f\"Source (artifact location): {model_source}\")\n",
"print(f\"MLflow model path for deployment: {mlflow_model_path}\")"
]
},
{
Expand Down Expand Up @@ -427,6 +435,8 @@
"from sagemaker.serve.mode.function_pointers import Mode\n",
"\n",
"# Cloud deployment to SageMaker endpoint\n",
"# Note: 'dependencies' parameter is deprecated. You may see a deprecation warning.\n",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says 'dependencies' parameter is deprecated and suggests configure_for_torchserve() — is this accurate? If so, should the example code itself be updated to use the non-deprecated approach rather than just adding a comment about it? Leaving deprecated code in an example notebook may confuse users.

"# Use configure_for_torchserve() for new projects.\n",
"model_builder = ModelBuilder(\n",
" mode=Mode.SAGEMAKER_ENDPOINT,\n",
" schema_builder=schema_builder,\n",
Expand Down Expand Up @@ -481,23 +491,43 @@
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"\n",
"# Test with JSON input\n",
"# Test with JSON input using V3-native endpoint invocation\n",
"test_data = [[0.1, 0.2, 0.3, 0.4]]\n",
"\n",
"runtime_client = boto3.client('sagemaker-runtime')\n",
"response = runtime_client.invoke_endpoint(\n",
" EndpointName=core_endpoint.endpoint_name,\n",
" Body=json.dumps(test_data),\n",
" ContentType='application/json'\n",
"result = core_endpoint.invoke(\n",
" body=json.dumps(test_data),\n",
" content_type='application/json'\n",
")\n",
"\n",
"prediction = json.loads(response['Body'].read().decode('utf-8'))\n",
"prediction = json.loads(result.body.read().decode('utf-8'))\n",
"print(f\"Input: {test_data}\")\n",
"print(f\"Prediction: {prediction}\")"
]
},
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good change replacing raw boto3 sagemaker-runtime client with core_endpoint.invoke() — this aligns with the V3 tenet that all subpackages should use sagemaker-core rather than calling boto3 directly.

However, please verify the exact parameter names for core_endpoint.invoke(). The sagemaker-core Endpoint.invoke() method may use Body and ContentType (PascalCase matching the API model) rather than body and content_type (snake_case). Similarly, the response attribute might be Body rather than body. If the snake_case versions work, that's fine — just want to make sure this has been tested.

{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test with multiple inputs\n",
"test_inputs = [\n",
" [[0.5, 0.3, 0.2, 0.1]],\n",
" [[0.9, 0.1, 0.8, 0.2]],\n",
" [[0.2, 0.7, 0.4, 0.6]]\n",
"]\n",
"\n",
"for i, test_input in enumerate(test_inputs, 1):\n",
" result = core_endpoint.invoke(\n",
" body=json.dumps(test_input),\n",
" content_type='application/json'\n",
" )\n",
" \n",
" prediction = json.loads(result.body.read().decode('utf-8'))\n",
" print(f\"Test {i} - Input {test_input}: {prediction}\")\n",
" print('-' * 50)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -551,7 +581,12 @@
"- `ModelBuilder` with `MLFLOW_MODEL_PATH` - deploy from registry\n",
"\n",
"Key patterns:\n",
"- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n"
"- Custom `PayloadTranslator` classes for PyTorch tensor serialization\n",
"- V3-native `core_endpoint.invoke()` for inference\n",
"\n",
"**MLflow 3.x API Note:**\n",
"- Use `client.search_model_versions()` instead of the removed `registered_model.latest_versions` attribute\n",
"- Use `latest_version.source` for artifact location instead of `client.get_model_version_download_uri()`\n"
]
}
],
Expand Down
Loading