Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions sagemaker-core/src/sagemaker/core/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,10 @@ def _validate_source_directory(source_directory):

# Check if the source path is under any sensitive directory
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: os.path.commonpath() raises ValueError when the paths have no common path or mix absolute and relative paths — for example, on Windows when they are on different drives, or if abs_source is somehow relative. (The exact-match case is fine: when abs_source equals a sensitive path such as /etc, os.path.commonpath(['/etc', '/etc']) returns '/etc', which equals sensitive_path, so the check correctly rejects it.) Consider wrapping the call in a try/except for ValueError:

try:
    common = os.path.commonpath([abs_source, sensitive_path])
except ValueError:
    # No common path (e.g., different drives on Windows) — not under sensitive path
    common = None
if abs_source != "/" and common == sensitive_path:
    raise ValueError(...)

This defensive pattern should be applied to all four locations where os.path.commonpath() is used in this PR.

if abs_source != "/" and abs_source.startswith(sensitive_path):
if abs_source != "/" and (
os.path.commonpath([abs_source, sensitive_path])
== sensitive_path
):
raise ValueError(
f"source_directory cannot access sensitive system paths. "
f"Got: {source_directory} (resolved to {abs_source})"
Expand All @@ -673,7 +676,10 @@ def _validate_dependency_path(dependency):

# Check if the dependency path is under any sensitive directory
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
if abs_dependency != "/" and abs_dependency.startswith(sensitive_path):
if abs_dependency != "/" and (
os.path.commonpath([abs_dependency, sensitive_path])
== sensitive_path
):
raise ValueError(
f"dependency path cannot access sensitive system paths. "
f"Got: {dependency} (resolved to {abs_dependency})"
Expand All @@ -686,10 +692,13 @@ def _create_or_update_code_dir(
"""Placeholder docstring"""
code_dir = os.path.join(model_dir, "code")
resolved_code_dir = _get_resolved_path(code_dir)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same os.path.commonpath() ValueError risk here in _create_or_update_code_dir. Please add the same defensive try/except pattern.


# Validate that code_dir does not resolve to a sensitive system path
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
if resolved_code_dir != "/" and resolved_code_dir.startswith(sensitive_path):
if resolved_code_dir != "/" and (
os.path.commonpath([resolved_code_dir, sensitive_path])
== sensitive_path
):
raise ValueError(
f"Invalid code_dir path: {code_dir} resolves to sensitive system path {resolved_code_dir}"
)
Expand Down Expand Up @@ -1688,7 +1697,8 @@ def _is_bad_path(path, base):
bool: True if the path is not rooted under the base directory, False otherwise.
"""
# joinpath will ignore base if path is absolute
return not _get_resolved_path(joinpath(base, path)).startswith(base)
resolved = _get_resolved_path(joinpath(base, path))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same ValueError risk from os.path.commonpath(): if resolved and base are on different drives (on Windows), or one of them is relative or empty while the other is absolute, os.path.commonpath([resolved, base]) will raise ValueError. Since this is a security function, it should treat such cases as bad paths (return True):

def _is_bad_path(path, base):
    resolved = _get_resolved_path(joinpath(base, path))
    try:
        return os.path.commonpath([resolved, base]) != base
    except ValueError:
        return True

return os.path.commonpath([resolved, base]) != base


def _is_bad_link(info, base):
Expand All @@ -1708,19 +1718,18 @@ def _is_bad_link(info, base):
return _is_bad_path(info.linkname, base=tip)


def _get_safe_members(members):
def _get_safe_members(members, base):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The signature change to _get_safe_members is a breaking change for any internal callers. Although this is a private function (prefixed with _), please verify that no other code in the SDK still calls _get_safe_members with the old signature (a single members argument, passed as the TarFile object). The mlops copy was independently fixed, but there could be other callers.

"""A generator that yields members that are safe to extract.

It filters out bad paths and bad links.

Args:
members (list): A list of members to check.
members (list): A list of TarInfo members to check.
base (str): The resolved base directory for extraction.

Yields:
tarfile.TarInfo: The tar file info.
"""
base = _get_resolved_path("")

for file_info in members:
if _is_bad_path(file_info.name, base):
logger.error("%s is blocked (illegal path)", file_info.name)
Expand Down Expand Up @@ -1783,7 +1792,11 @@ def custom_extractall_tarfile(tar, extract_path):
if hasattr(tarfile, "data_filter"):
tar.extractall(path=extract_path, filter="data")
else:
tar.extractall(path=extract_path, members=_get_safe_members(tar))
base = _get_resolved_path(extract_path)
tar.extractall(
path=extract_path,
members=_get_safe_members(tar.getmembers(), base),
)
# Re-validate extracted paths to catch symlink race conditions
_validate_extracted_paths(extract_path)

Expand Down
Loading
Loading