Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 227 additions & 43 deletions src/specify_cli/__init__.py

Large diffs are not rendered by default.

206 changes: 176 additions & 30 deletions src/specify_cli/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import json
import hashlib
import os
import sys
import tarfile
import tempfile
import zipfile
import shutil
Expand Down Expand Up @@ -106,6 +108,120 @@ def normalize_priority(value: Any, default: int = 10) -> int:
return priority if priority >= 1 else default


def detect_archive_format(url: str, content_type: str = "") -> str:
    """Detect archive format from URL path extension or Content-Type header.

    Args:
        url: URL or file path to inspect.
        content_type: Optional ``Content-Type`` header value from the HTTP response.

    Returns:
        ``"zip"`` for ZIP archives, ``"tar.gz"`` for gzipped tarballs, or ``""``
        when the format cannot be determined.
    """
    # Strip query-string / fragment before examining the path extension.
    # str.partition never raises and returns the full string when the
    # separator is absent, so plain file paths pass through unchanged.
    url_path = url.partition("?")[0].partition("#")[0].lower()
    if url_path.endswith(".zip"):
        return "zip"
    # endswith accepts a tuple of candidate suffixes — one call covers both.
    if url_path.endswith((".tar.gz", ".tgz")):
        return "tar.gz"

    # Fall back to Content-Type header inspection when the path is ambiguous.
    ct = content_type.lower()
    if "application/zip" in ct or "application/x-zip" in ct:
        return "zip"
    if any(
        t in ct
        for t in (
            "application/gzip",
            "application/x-gzip",
            "application/x-tar+gzip",
        )
    ):
        return "tar.gz"

    # Unknown format — the caller decides how to handle this.
    return ""


def safe_extract_tarball(
archive_path: Path,
dest_dir: Path,
error_class: "type[Exception]" = Exception,
) -> None:
"""Safely extract a ``.tar.gz`` or ``.tgz`` archive into *dest_dir*.

All members are validated before extraction to prevent *tar slip*
(path traversal) attacks. Symlinks, hard links, and special files
(devices, FIFOs, etc.) are rejected.

On Python 3.12 and later the ``"data"`` extraction filter is applied
for an additional layer of OS-level protection. On earlier versions
the explicit member list (containing only pre-validated regular files
and directories) is passed to ``extractall()`` — since all symlinks are
already rejected in the validation phase, no archive-introduced symlink
can be followed during extraction.

Args:
archive_path: Path to the ``.tar.gz``/``.tgz`` archive.
dest_dir: Destination directory (must already exist).
error_class: Exception class to raise on unsafe entries.

Raises:
error_class: If any member is unsafe or the archive cannot be read.
"""
dest_resolved = dest_dir.resolve()

try:
with tarfile.open(archive_path, "r:gz") as tf:
members = tf.getmembers()
safe_members = []

# Validate every member before extracting anything.
for member in members:
# Reject absolute paths and any path component that is "..".
if os.path.isabs(member.name) or any(
part == ".." for part in member.name.replace("\\", "/").split("/")
):
raise error_class(
f"Unsafe path in tar archive: {member.name} (potential path traversal)"
)

# Confirm the resolved path stays inside dest_dir.
member_path = (dest_dir / member.name).resolve()
try:
member_path.relative_to(dest_resolved)
except ValueError:
raise error_class(
f"Unsafe path in tar archive: {member.name} (potential path traversal)"
)

# Reject symlinks and hard links.
if member.issym() or member.islnk():
raise error_class(
f"Symlinks are not allowed in archive: {member.name}"
)

# Only allow regular files and directories.
if not (member.isreg() or member.isdir()):
raise error_class(
f"Non-regular file in archive: {member.name}"
)

safe_members.append(member)

# Extract — use the "data" filter on Python 3.12+ for extra hardening.
# On older versions pass only the pre-validated members so that no
# unvetted entry (added concurrently or via a race) slips through.
if sys.version_info >= (3, 12):
tf.extractall(dest_dir, filter="data") # type: ignore[call-arg]
else:
tf.extractall(dest_dir, members=safe_members) # noqa: S202 — validated above
except error_class:
raise
except (tarfile.TarError, OSError) as e:
raise error_class(f"Failed to read archive {archive_path}: {e}") from e


@dataclass
class CatalogEntry:
"""Represents a single catalog entry in the catalog stack."""
Expand Down Expand Up @@ -1202,18 +1318,19 @@ def install_from_zip(
speckit_version: str,
priority: int = 10,
) -> ExtensionManifest:
"""Install extension from ZIP file.
"""Install extension from a ZIP or ``.tar.gz``/``.tgz`` archive.

Args:
zip_path: Path to extension ZIP file
zip_path: Path to the extension archive (ZIP or gzipped tarball).
speckit_version: Current spec-kit version
priority: Resolution priority (lower = higher precedence, default 10)

Returns:
Installed extension manifest

Raises:
ValidationError: If manifest is invalid or priority is invalid
ValidationError: If manifest is invalid, the archive is unsafe, or
priority is invalid
CompatibilityError: If extension is incompatible
"""
# Validate priority early
Expand All @@ -1223,21 +1340,27 @@ def install_from_zip(
with tempfile.TemporaryDirectory() as tmpdir:
temp_path = Path(tmpdir)

# Extract ZIP safely (prevent Zip Slip attack)
with zipfile.ZipFile(zip_path, 'r') as zf:
# Validate all paths first before extracting anything
temp_path_resolved = temp_path.resolve()
for member in zf.namelist():
member_path = (temp_path / member).resolve()
# Use is_relative_to for safe path containment check
try:
member_path.relative_to(temp_path_resolved)
except ValueError:
raise ValidationError(
f"Unsafe path in ZIP archive: {member} (potential path traversal)"
)
# Only extract after all paths are validated
zf.extractall(temp_path)
archive_fmt = detect_archive_format(str(zip_path))

if archive_fmt == "tar.gz":
# Extract tarball safely (prevent tar slip attack)
safe_extract_tarball(zip_path, temp_path, ValidationError)
else:
# Extract ZIP safely (prevent Zip Slip attack)
with zipfile.ZipFile(zip_path, 'r') as zf:
# Validate all paths first before extracting anything
temp_path_resolved = temp_path.resolve()
for member in zf.namelist():
member_path = (temp_path / member).resolve()
# Use is_relative_to for safe path containment check
try:
member_path.relative_to(temp_path_resolved)
except ValueError:
raise ValidationError(
f"Unsafe path in ZIP archive: {member} (potential path traversal)"
)
# Only extract after all paths are validated
zf.extractall(temp_path)

# Find extension directory (may be nested)
extension_dir = temp_path
Expand All @@ -1251,7 +1374,7 @@ def install_from_zip(
manifest_path = extension_dir / "extension.yml"

if not manifest_path.exists():
raise ValidationError("No extension.yml found in ZIP file")
raise ValidationError("No extension.yml found in archive")

# Install from extracted directory
return self.install_from_directory(extension_dir, speckit_version, priority=priority)
Expand Down Expand Up @@ -1965,14 +2088,18 @@ def get_extension_info(self, extension_id: str) -> Optional[Dict[str, Any]]:
return None

def download_extension(self, extension_id: str, target_dir: Optional[Path] = None) -> Path:
"""Download extension ZIP from catalog.
"""Download extension archive from catalog.

Supports both ZIP (``.zip``) and gzipped tarball (``.tar.gz``/``.tgz``)
archives. The format is detected from the download URL's path extension;
when ambiguous the ``Content-Type`` header is used as a fallback.

Args:
extension_id: ID of the extension to download
target_dir: Directory to save ZIP file (defaults to temp directory)
target_dir: Directory to save the archive (defaults to cache directory)

Returns:
Path to downloaded ZIP file
Path to downloaded archive file

Raises:
ExtensionError: If extension not found or download fails
Expand Down Expand Up @@ -2011,21 +2138,40 @@ def download_extension(self, extension_id: str, target_dir: Optional[Path] = Non
target_dir.mkdir(parents=True, exist_ok=True)

version = ext_info.get("version", "unknown")
zip_filename = f"{extension_id}-{version}.zip"
zip_path = target_dir / zip_filename

# Download the ZIP file
# Detect archive format from URL; resolve via Content-Type when needed.
archive_fmt = detect_archive_format(download_url)

# Download the archive
try:
with self._open_url(download_url, timeout=60) as response:
zip_data = response.read()

zip_path.write_bytes(zip_data)
return zip_path
if not archive_fmt:
content_type = response.headers.get("Content-Type", "")
archive_fmt = detect_archive_format(download_url, content_type)
archive_data = response.read()

except urllib.error.URLError as e:
raise ExtensionError(f"Failed to download extension from {download_url}: {e}")
except IOError as e:
raise ExtensionError(f"Failed to save extension ZIP: {e}")
raise ExtensionError(f"Failed to read extension archive from {download_url}: {e}")

# Choose file extension based on detected format.
if not archive_fmt:
raise ExtensionError(
f"Could not determine archive format for {download_url}. "
"Ensure the URL points to a .zip or .tar.gz/.tgz file."
)
if archive_fmt == "tar.gz":
archive_filename = f"{extension_id}-{version}.tar.gz"
else:
archive_filename = f"{extension_id}-{version}.zip"

Comment thread
mnriem marked this conversation as resolved.
archive_path = target_dir / archive_filename
try:
archive_path.write_bytes(archive_data)
except IOError as e:
raise ExtensionError(f"Failed to save extension archive: {e}")
return archive_path

def clear_cache(self):
"""Clear the catalog cache (both legacy and URL-hash-based files)."""
Expand Down
Loading