Skip to content

fix: Enhancements Needed for Secure Tar Extraction (5560)#5690

Draft
aviruthen wants to merge 2 commits intoaws:masterfrom
aviruthen:fix/enhancements-needed-for-secure-tar-extraction-5560
Draft

fix: Enhancements Needed for Secure Tar Extraction (5560)#5690
aviruthen wants to merge 2 commits intoaws:masterfrom
aviruthen:fix/enhancements-needed-for-secure-tar-extraction-5560

Conversation

@aviruthen
Copy link
Copy Markdown
Collaborator

Description

The issue reports three security bugs in tar extraction logic. Examining the codebase reveals two affected files:

  1. sagemaker-mlops/src/sagemaker/mlops/workflow/_repack_model.py - This file has ALREADY been fixed correctly: _is_bad_path() uses os.path.commonpath(), _get_safe_members() accepts a members list and base parameter, and custom_extractall_tarfile() passes tar.getmembers() and _get_resolved_path(extract_path) as base.

  2. sagemaker-core/src/sagemaker/core/common_utils.py - This is the V3 core utilities file that also defines _get_resolved_path, _is_bad_path, _get_safe_members, and custom_extractall_tarfile (in the truncated portion past line 1420). These functions are used by _create_or_update_code_dir() and _extract_model() in the same file. The sagemaker-core copy likely still has the old buggy implementations since the mlops version was independently fixed.

Additionally, _validate_source_directory() and _validate_dependency_path() in common_utils.py both use abs_source.startswith(sensitive_path) which has the same prefix-collision vulnerability as issue #3 (e.g. /etc2 would match /etc). These should use os.path.commonpath() instead.

The fix needs to:

  1. Ensure _get_safe_members() in common_utils.py takes members list and base params (not the TarFile object directly and not _get_resolved_path(''))
  2. Ensure custom_extractall_tarfile() passes tar.getmembers() and _get_resolved_path(extract_path) as base
  3. Ensure _is_bad_path() uses os.path.commonpath() instead of startswith()
  4. Fix _validate_source_directory() and _validate_dependency_path() to use os.path.commonpath() instead of startswith()

Related Issue

Related issue: 5560

Changes Made

  • sagemaker-core/src/sagemaker/core/common_utils.py
  • sagemaker-core/tests/unit/test_common_utils.py

AI-Generated PR

This PR was automatically generated by the PySDK Issue Agent.

  • Confidence score: 85%
  • Classification: bug
  • SDK version target: V3

Merge Checklist

  • Changes are backward compatible
  • Commit message follows prefix: description format
  • Unit tests added/updated
  • Integration tests added (if applicable)
  • Documentation updated (if applicable)

Copy link
Copy Markdown
Member

@mufaddal-rohawala mufaddal-rohawala left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 AI Code Review

This PR fixes a legitimate security vulnerability where startswith() was used for path prefix matching instead of os.path.commonpath(), which could allow prefix-collision bypasses (e.g., /etc2 matching /etc). The fix is correct in principle, but the PR description claims additional changes to _get_safe_members, _is_bad_path, and custom_extractall_tarfile that are NOT actually present in the diff. The test for custom_extractall_tarfile mocks so heavily that it doesn't actually verify the claimed fix. Several tests also have line length violations and structural issues.

def test_validate_dependency_path_blocks_sensitive_path(self):
"""Test that actual sensitive paths are blocked."""
from sagemaker.core.common_utils import _validate_dependency_path

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line length exceeds 100 characters: Same line-length issue. Please wrap the string.

Copy link
Copy Markdown
Member

@mufaddal-rohawala mufaddal-rohawala left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 AI Code Review

This PR fixes security vulnerabilities in tar extraction and path validation by replacing startswith() with os.path.commonpath() to prevent prefix collision attacks. The changes are well-motivated and align with the already-fixed mlops version. However, there are a few issues: a potential ValueError from os.path.commonpath() when paths are on different drives/have no common path, and some test concerns.

@@ -647,7 +647,10 @@ def _validate_source_directory(source_directory):

# Check if the source path is under any sensitive directory
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: os.path.commonpath() raises ValueError when paths have no common path or are a mix of absolute and relative paths. For example, on Windows with different drives, or if abs_source is somehow relative. More critically, if abs_source equals a sensitive path exactly (e.g., /etc), os.path.commonpath(['/etc', '/etc']) returns '/etc' which equals sensitive_path — this is correct and desired. But consider wrapping this in a try/except for ValueError:

try:
    common = os.path.commonpath([abs_source, sensitive_path])
    if abs_source != "/" and common == sensitive_path:
        raise ValueError(...)
except ValueError:
    # No common path (e.g., different drives on Windows) — not under sensitive path
    pass

This defensive pattern should be applied to all four locations where os.path.commonpath() is used in this PR.

"""
# joinpath will ignore base if path is absolute
return not _get_resolved_path(joinpath(base, path)).startswith(base)
resolved = _get_resolved_path(joinpath(base, path))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same ValueError risk from os.path.commonpath(): If resolved and base are on different mount points or one is empty, os.path.commonpath([resolved, base]) will raise ValueError. Since this is a security function, it should treat such cases as bad paths (return True):

def _is_bad_path(path, base):
    resolved = _get_resolved_path(joinpath(base, path))
    try:
        return os.path.commonpath([resolved, base]) != base
    except ValueError:
        return True

@@ -686,10 +692,13 @@ def _create_or_update_code_dir(
"""Placeholder docstring"""
code_dir = os.path.join(model_dir, "code")
resolved_code_dir = _get_resolved_path(code_dir)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same os.path.commonpath() ValueError risk here in _create_or_update_code_dir. Please add the same defensive try/except pattern.

def test_is_bad_path_absolute_escape(self):
"""Test _is_bad_path returns True for absolute path outside base."""
base = _get_resolved_path("/tmp/safe")
assert _is_bad_path("/etc/passwd", base) is True
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test uses hardcoded /tmp/safe path which may not exist on all systems. While _is_bad_path doesn't check existence, _get_resolved_path likely calls os.path.realpath() which could resolve differently depending on the OS. Consider using tmp_path fixture for more portable tests:

def test_is_bad_path_absolute_escape(self, tmp_path):
    base = _get_resolved_path(str(tmp_path / "safe"))
    assert _is_bad_path("/etc/passwd", base) is True

"""Test _is_bad_path correctly flags prefix collision.

/tmp/safe2 starts with /tmp/safe but is NOT under /tmp/safe.
The old startswith() check would miss this; commonpath catches it.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the key regression test for the prefix collision fix — good! However, this test also uses hardcoded /tmp/safe and /tmp/safe2. Consider using tmp_path for portability, and add a comment explaining this is the core test for the startswith()commonpath() fix.

mock_hardlink = Mock()
mock_hardlink.name = "bad/hardlink"
mock_hardlink.issym = Mock(return_value=False)
mock_hardlink.islnk = Mock(return_value=True)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test test_validate_source_directory_prefix_collision assumes /etcetera doesn't resolve to a sensitive path. This is a good test for the prefix collision fix, but note that _validate_source_directory may also check if the path exists or do other validation. If the function only checks against _SENSITIVE_SYSTEM_PATHS, this is fine. Just ensure the test won't break if /etcetera happens to exist on a CI system and resolves to something unexpected.


def test_validate_source_directory_prefix_collision(self):
"""Test /etcetera is NOT blocked by /etc sensitive path.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same concern as the source directory prefix collision test. /rootkit is used to test that it's not blocked by /root. This is a good test case but consider using a more clearly safe path like /root_not_sensitive to avoid any ambiguity.

extract_path = tmp_path / "extract"
extract_path.mkdir()

with tarfile.open(tar_path, "r:gz") as tar:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fragile test pattern: Deleting tarfile.data_filter with delattr and restoring it in a finally block is brittle. If another test runs concurrently or the attribute doesn't exist on older Python versions, this could cause issues. Consider using unittest.mock.patch to mock hasattr or the attribute check instead:

with patch.object(tarfile, 'data_filter', create=False, new_callable=lambda: None):

Or more simply:

with patch('sagemaker.core.common_utils.hasattr', side_effect=lambda obj, name: False if name == 'data_filter' else hasattr(obj, name)):

Actually, the simplest approach would be to mock at the check point rather than mutating the tarfile module.



def _get_safe_members(members):
def _get_safe_members(members, base):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Signature change to _get_safe_members is a breaking change for any internal callers. While this is a private function (prefixed with _), verify that no other code in the SDK calls _get_safe_members with the old signature (just members as a TarFile object). The mlops copy was independently fixed, but there could be other callers.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants