From 0520e06ebda91e5187c6af8a4674d87eada4cc77 Mon Sep 17 00:00:00 2001 From: kunwar-vp Date: Thu, 21 May 2026 12:02:52 -0700 Subject: [PATCH] release: graphn 0.1.6 (LoRA support + custom-model update endpoint) Two spec-sync PRs (#12 on 2026-05-14, #15 on 2026-05-16) landed regenerated _generated/ on main but neither bumped pyproject, so the new control-plane surface has been sitting in the source tree unreleased on PyPI for a week. This PR closes the gap: bumps to 0.1.6, ships matching ergonomic wrappers + typed fields on the hand-curated resource layer, and adds tests so the new surface is covered, not just compiled. New high-level surface on client.custom_models (sync + async): - update(model_id, *, name=..., min_replicas=..., max_replicas=..., cooldown_seconds=..., extra=...) issues PATCH /v1/{ws}/custom-models/{id}. In-place mutation of the live deployment - no rolling restart, no downtime. Empty PATCH is refused client-side with ValidationError(code="empty_update") one round-trip earlier than the server's 422; an `extra` mapping lets callers PATCH future fields without an SDK release. - supported_architectures() returns a typed SupportedArchitectures catalog from GET /v1/{ws}/custom-models/supported-architectures. Each ArchitectureInfo carries the capability tags (tool_calling, vision, image_input, video_input, streaming, json_mode) the architecture exposes. Intended for driving UI architecture/ capability filters before calling validate(). - create(..., base_model_id=...) wires up the LoRA-import hint. Required on weight_source=s3_* to classify the bundle as an adapter at create-time; optional on weight_source=huggingface where it overrides adapter_config.json::base_model_name_or_path from the upstream repo (useful when the recorded base id isn't a valid HF id, e.g. a local filesystem path used during training). - validate(..., model_size_gb=...) 
lets callers skip the HF head-bytes probe by supplying a weight-size hint, useful for very large models (405B-class) where the probe stalls validate. Typed LoRA fields on the existing Pydantic types: - CustomModel: artifact_type ("base"|"lora"|None), base_model_id, lora_adapter_name, lora_rank. artifact_type is None on responses from control planes that predate the LoRA work - treat that as "base" for compatibility. - ValidateModelResponse: artifact_type (defaults to "base" on fresh responses, None on legacy), detected_base_model_id, lora_rank. When artifact_type == "lora", the architectures / num_params / estimated_memory_gb / max_context_length fields describe the base model resolved from adapter_config.json, not the adapter itself. New public exports: ArchitectureInfo, SupportedArchitectures, ArtifactType from graphn (and graphn.custom_models). CustomModelCreate.huggingface_model_id is now required on the generated attrs dataclass (was str | Unset). The server has returned 422 for omitted huggingface_model_id on every weight source since 0.1.3 (voltagepark/takao#1997) and the hand-curated client.custom_models.create resource raises ValidationError client-side for S3 imports, so this is the generated type catching up - callers using the keyword-only ergonomic API are unaffected. Tests: 57 pass (43 existing + 14 new) covering both transports. ruff check clean. mypy is clean on every file this PR touches (pre-existing no-any-return errors in _transport.py and tts.py are on main and not regressions). The auto-tag job's CHANGELOG check matches "## [0.1.6] - 2026-05-21" so PyPI publish fires automatically on merge. 
--- CHANGELOG.md | 87 ++++++++ pyproject.toml | 2 +- src/graphn/__init__.py | 6 + src/graphn/custom_models/__init__.py | 13 +- src/graphn/custom_models/resource.py | 142 ++++++++++++++ src/graphn/custom_models/types.py | 70 +++++++ tests/test_custom_models.py | 283 +++++++++++++++++++++++++++ 7 files changed, 601 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a29d14..26ca4b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,93 @@ No `git tag`, no `git push --tags`, no Actions clicks. ## [Unreleased] +## [0.1.6] — 2026-05-21 + +Spec-sync release plus a matching round of ergonomic wrappers. Picks +up two new custom-model control-plane endpoints, exposes them +through `client.custom_models`, and adds typed LoRA-adapter fields +to the public `CustomModel` and `ValidateModelResponse` Pydantic +models. The low-level generated client (`graphn._generated`) and +the hand-curated resource layer (`graphn.custom_models`) are both +fully in sync with the upstream OpenAPI spec. + +### Added + +- `client.custom_models.update(model_id, *, name=..., min_replicas=..., + max_replicas=..., cooldown_seconds=..., extra=...)` (sync and async). + Issues `PATCH /v1/{workspaceId}/custom-models/{modelId}` against the + control plane and returns the refreshed :class:`CustomModel`. + Mutates a vetted set of post-create fields in place against the + live deployment — no rolling restart, no downtime. Immutable fields + (`huggingface_model_id`, `weight_source`, GPU topology, …) are not + exposed; change them by deleting and re-creating the model. The SDK + refuses an empty PATCH client-side (raises + `graphn.ValidationError` with code `empty_update`), one round-trip + earlier than the server's `422`. +- `client.custom_models.supported_architectures()` (sync and async). 
+ Returns a typed :class:`SupportedArchitectures` catalog of model + architectures the platform's serving runtimes can deploy, each + annotated with the capability tags (`tool_calling`, `vision`, + `image_input`, `video_input`, `streaming`, `json_mode`) it exposes. + Intended for driving architecture/capability filters in client UIs + before calling :meth:`client.custom_models.validate`. The list is + updated alongside platform runtime upgrades; clients should not + cache it across build cycles. +- LoRA-adapter visibility on the existing types. `CustomModel` gains + `artifact_type` (`Literal["base", "lora"] | None`), `base_model_id`, + `lora_adapter_name`, and `lora_rank` typed fields; older control + planes that predate the LoRA work leave `artifact_type` unset and + should be treated as `"base"` for compatibility. `ValidateModelResponse` + gains `artifact_type`, `detected_base_model_id`, and `lora_rank` so + callers can detect that a HuggingFace repo contains a LoRA adapter + (via `adapter_config.json`) before deploying. When + `artifact_type == "lora"` on the validate response, the + `architectures` / `num_params` / `estimated_memory_gb` / + `max_context_length` fields describe the **base** model resolved + from `adapter_config.json`, not the adapter itself. +- `client.custom_models.create(..., base_model_id=...)` (sync and async). + Required on `weight_source=s3_*` LoRA imports — it's the only way to + classify the bundle as an adapter at create time; omitting it routes + the import through the base path, and a LoRA bundle that wasn't + declared will deploy to `failed` with an actionable error. Optional + on `weight_source=huggingface`, where it **overrides** + `adapter_config.json::base_model_name_or_path` from the upstream + adapter repo — useful when the recorded base id isn't a valid + HuggingFace id (e.g. a local filesystem path used during training). 
+ The base id must be one of the platform's allowlisted bases (see + `client.custom_models.supported_architectures()`). +- `client.custom_models.validate(..., model_size_gb=...)` (sync and + async). Optional caller-supplied estimate (in GiB) of the on-disk + weights size. When provided, the platform sizes the model-weights + PVC from this hint instead of waiting for a HuggingFace head-bytes + probe; useful for very large models (e.g. 405B) where the probe + would otherwise stall the validate response. +- New public exports from `graphn`: `ArchitectureInfo`, + `SupportedArchitectures`, `ArtifactType`. + +### Changed + +- `CustomModelCreate.huggingface_model_id` is now a **required** + field on the generated `attrs` dataclass (previously + `str | Unset`). This mirrors the server-side behavior already + shipped in v0.1.3 (the control plane has returned `422` for omitted + `huggingface_model_id` on every weight source since + voltagepark/takao#1997) and the client-side `ValidationError` the + high-level `client.custom_models.create` resource has raised since + v0.1.3. The generated type just catches up; callers using the + hand-curated `client.custom_models.create` keyword-only API are + unaffected — the high-level resource still accepts + `huggingface_model_id` as a keyword argument and the existing + client-side guard fires before the request is built. +- `CustomModelCreate.s3_role_arn` docstring now records the + `graphn-byom-*` role-name prefix the platform enforces. No wire or + validation change in the SDK; the constraint has been server-side + since 0.1.3 and the customer-facing CloudFormation template + enforces the same prefix at stack-create time. Doc-only. +- `CustomModel.gpu_memory_utilization` docstring no longer names the + serving engine (`vLLM`). Engine-agnostic wording aligns with the + 0.1.3 scrub of customer-facing serving-engine references. + ## [0.1.5] — 2026-05-14 Patch release. 
Widens the upper bound on the `openai` runtime diff --git a/pyproject.toml b/pyproject.toml index ded26be..8cd6bae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "graphn" -version = "0.1.5" +version = "0.1.6" description = "Official Python SDK for the Graphn API: custom-model lifecycle, secrets, and OpenAI-compatible inference." readme = "README.md" requires-python = ">=3.10" diff --git a/src/graphn/__init__.py b/src/graphn/__init__.py index 5bbf2db..767a376 100644 --- a/src/graphn/__init__.py +++ b/src/graphn/__init__.py @@ -22,12 +22,15 @@ from graphn._pagination import AsyncPage, SyncPage from graphn._version import __version__ from graphn.custom_models.types import ( + ArchitectureInfo, + ArtifactType, Capability, CustomModel, CustomModelAccess, CustomModelStatus, GpuHoursResponse, Quantization, + SupportedArchitectures, ValidateModelResponse, WeightSource, ) @@ -36,6 +39,8 @@ __all__ = [ "APIConnectionError", "APIError", + "ArchitectureInfo", + "ArtifactType", "AsyncClient", "AsyncPage", "AuthenticationError", @@ -53,6 +58,7 @@ "RateLimitError", "Secret", "ServerError", + "SupportedArchitectures", "SyncPage", "ValidateModelResponse", "ValidationError", diff --git a/src/graphn/custom_models/__init__.py b/src/graphn/custom_models/__init__.py index 97da154..c64d93a 100644 --- a/src/graphn/custom_models/__init__.py +++ b/src/graphn/custom_models/__init__.py @@ -1,10 +1,21 @@ """Custom-model resource module.""" from graphn.custom_models.types import ( + ArchitectureInfo, + ArtifactType, Capability, CustomModel, CustomModelStatus, + SupportedArchitectures, WeightSource, ) -__all__ = ["Capability", "CustomModel", "CustomModelStatus", "WeightSource"] +__all__ = [ + "ArchitectureInfo", + "ArtifactType", + "Capability", + "CustomModel", + "CustomModelStatus", + "SupportedArchitectures", + "WeightSource", +] diff --git a/src/graphn/custom_models/resource.py b/src/graphn/custom_models/resource.py index 
664f548..28110e8 100644 --- a/src/graphn/custom_models/resource.py +++ b/src/graphn/custom_models/resource.py @@ -30,6 +30,7 @@ CustomModelStatus, GpuHoursResponse, Quantization, + SupportedArchitectures, ValidateModelResponse, WeightSource, ) @@ -65,6 +66,7 @@ def _build_create_body( min_replicas: int | None, max_replicas: int | None, cooldown_seconds: int | None, + base_model_id: str | None, extra: Mapping[str, Any] | None, ) -> dict[str, Any]: if weight_source in _S3_WEIGHT_SOURCES and not (huggingface_model_id or "").strip(): @@ -107,6 +109,8 @@ def _build_create_body( body["max_replicas"] = max_replicas if cooldown_seconds is not None: body["cooldown_seconds"] = cooldown_seconds + if base_model_id is not None: + body["base_model_id"] = base_model_id if extra: body.update(extra) return body @@ -118,6 +122,7 @@ def _build_validate_body( hf_token_secret_id: str | None, quantization: Quantization | None, gpu_memory_utilization: float | None, + model_size_gb: int | None, ) -> dict[str, Any]: body: dict[str, Any] = {"huggingface_model_id": huggingface_model_id} if hf_token_secret_id is not None: @@ -126,6 +131,42 @@ def _build_validate_body( body["quantization"] = quantization if gpu_memory_utilization is not None: body["gpu_memory_utilization"] = gpu_memory_utilization + if model_size_gb is not None: + body["model_size_gb"] = model_size_gb + return body + + +def _build_update_body( + *, + name: str | None, + min_replicas: int | None, + max_replicas: int | None, + cooldown_seconds: int | None, + extra: Mapping[str, Any] | None, +) -> dict[str, Any]: + body: dict[str, Any] = {} + if name is not None: + body["name"] = name + if min_replicas is not None: + body["min_replicas"] = min_replicas + if max_replicas is not None: + body["max_replicas"] = max_replicas + if cooldown_seconds is not None: + body["cooldown_seconds"] = cooldown_seconds + if extra: + body.update(extra) + if not body: + raise ValidationError( + ( + "update() requires at least one of: name, 
min_replicas, " + "max_replicas, cooldown_seconds (or an `extra` mapping with " + "additional fields). The control plane rejects empty PATCH " + "bodies with 422; this check fires one round-trip earlier." + ), + status_code=422, + code="empty_update", + details={}, + ) return body @@ -162,6 +203,7 @@ def create( min_replicas: int | None = None, max_replicas: int | None = None, cooldown_seconds: int | None = None, + base_model_id: str | None = None, extra: Mapping[str, Any] | None = None, idempotency_key: str | None = None, ) -> CustomModel: @@ -181,6 +223,7 @@ def create( min_replicas=min_replicas, max_replicas=max_replicas, cooldown_seconds=cooldown_seconds, + base_model_id=base_model_id, extra=extra, ) data = self._transport.request( @@ -191,6 +234,64 @@ def create( ) return CustomModel.model_validate(data) + def update( + self, + model_id: str, + *, + name: str | None = None, + min_replicas: int | None = None, + max_replicas: int | None = None, + cooldown_seconds: int | None = None, + extra: Mapping[str, Any] | None = None, + ) -> CustomModel: + """Partially update mutable fields on a custom model. + + Applies in place to the live deployment — no rolling restart, + no downtime. Immutable fields (``huggingface_model_id``, + ``weight_source``, GPU topology, …) are not exposed; change + them by deleting and re-creating the model. At least one + keyword argument (or an ``extra`` field) must be supplied. + + Raises + ------ + graphn.ValidationError + If no fields are provided (code ``empty_update``). + graphn.NotFoundError + If ``model_id`` doesn't exist in the workspace. 
+ """ + + body = _build_update_body( + name=name, + min_replicas=min_replicas, + max_replicas=max_replicas, + cooldown_seconds=cooldown_seconds, + extra=extra, + ) + data = self._transport.request( + "PATCH", + self._transport.cp_path("custom-models", model_id), + json=body, + ) + return CustomModel.model_validate(data) + + def supported_architectures(self) -> SupportedArchitectures: + """List HuggingFace architectures the platform can deploy. + + Returns the static catalog of model architectures supported by + the platform's serving runtimes, each annotated with the + capability tags (``tool_calling``, ``vision``, …) it exposes. + Use this to drive architecture/capability filters in client + UIs before calling :meth:`validate`. The list is updated + alongside platform runtime upgrades; clients should not cache + it across build cycles. + """ + + data = self._transport.request( + "GET", + self._transport.cp_path("custom-models", "supported-architectures"), + ) + return SupportedArchitectures.model_validate(data) + def list( self, *, @@ -250,12 +351,14 @@ def validate( hf_token_secret_id: str | None = None, quantization: Quantization | None = None, gpu_memory_utilization: float | None = None, + model_size_gb: int | None = None, ) -> ValidateModelResponse: body = _build_validate_body( huggingface_model_id=huggingface_model_id, hf_token_secret_id=hf_token_secret_id, quantization=quantization, gpu_memory_utilization=gpu_memory_utilization, + model_size_gb=model_size_gb, ) data = self._transport.request( "POST", @@ -332,6 +435,7 @@ async def create( min_replicas: int | None = None, max_replicas: int | None = None, cooldown_seconds: int | None = None, + base_model_id: str | None = None, extra: Mapping[str, Any] | None = None, idempotency_key: str | None = None, ) -> CustomModel: @@ -351,6 +455,7 @@ async def create( min_replicas=min_replicas, max_replicas=max_replicas, cooldown_seconds=cooldown_seconds, + base_model_id=base_model_id, extra=extra, ) data = await 
self._transport.request( @@ -361,6 +466,41 @@ async def create( ) return CustomModel.model_validate(data) + async def update( + self, + model_id: str, + *, + name: str | None = None, + min_replicas: int | None = None, + max_replicas: int | None = None, + cooldown_seconds: int | None = None, + extra: Mapping[str, Any] | None = None, + ) -> CustomModel: + """Asynchronous mirror of :meth:`CustomModels.update`.""" + + body = _build_update_body( + name=name, + min_replicas=min_replicas, + max_replicas=max_replicas, + cooldown_seconds=cooldown_seconds, + extra=extra, + ) + data = await self._transport.request( + "PATCH", + self._transport.cp_path("custom-models", model_id), + json=body, + ) + return CustomModel.model_validate(data) + + async def supported_architectures(self) -> SupportedArchitectures: + """Asynchronous mirror of :meth:`CustomModels.supported_architectures`.""" + + data = await self._transport.request( + "GET", + self._transport.cp_path("custom-models", "supported-architectures"), + ) + return SupportedArchitectures.model_validate(data) + async def list( self, *, @@ -422,12 +562,14 @@ async def validate( hf_token_secret_id: str | None = None, quantization: Quantization | None = None, gpu_memory_utilization: float | None = None, + model_size_gb: int | None = None, ) -> ValidateModelResponse: body = _build_validate_body( huggingface_model_id=huggingface_model_id, hf_token_secret_id=hf_token_secret_id, quantization=quantization, gpu_memory_utilization=gpu_memory_utilization, + model_size_gb=model_size_gb, ) data = await self._transport.request( "POST", diff --git a/src/graphn/custom_models/types.py b/src/graphn/custom_models/types.py index 12b7f71..759a61e 100644 --- a/src/graphn/custom_models/types.py +++ b/src/graphn/custom_models/types.py @@ -31,6 +31,9 @@ Quantization = Literal["awq", "gptq", "fp8", "squeezellm", "marlin", "gguf"] +ArtifactType = Literal["base", "lora"] +"""Whether a custom-model import is a full base checkpoint or a LoRA adapter.""" + 
class CustomModel(BaseModel): """Public custom-model record. @@ -69,6 +72,25 @@ class CustomModel(BaseModel): estimated_memory_gb: float | None = None architectures: list[str] | None = None + artifact_type: ArtifactType | None = None + """``"base"`` for full checkpoints, ``"lora"`` for adapter imports. + + Set eagerly at create-time. HuggingFace imports are classified by + probing ``adapter_config.json`` on the upstream repo; S3 imports are + classified as ``"lora"`` iff ``base_model_id`` is supplied on + :meth:`CustomModels.create`. Older control planes that predate the + LoRA work leave this field unset on existing records — treat + ``None`` as ``"base"`` for compatibility. + """ + base_model_id: str | None = None + """Base model id this adapter loads on top of (populated when + ``artifact_type == "lora"``).""" + lora_adapter_name: str | None = None + """Routing name the adapter is served under. Clients address + the adapter via ``model=`` in chat completions.""" + lora_rank: int | None = None + """``r`` value from the adapter's ``adapter_config.json``.""" + class CustomModelAccess(BaseModel): """Workspace allowlist check result.""" @@ -98,3 +120,51 @@ class ValidateModelResponse(BaseModel): num_params: int | None = None estimated_memory_gb: float | None = None max_context_length: int | None = None + + artifact_type: ArtifactType | None = None + """``"lora"`` when the platform detected an ``adapter_config.json`` in the + HuggingFace repo, ``"base"`` otherwise. Defaults to ``"base"`` on + fresh responses; older control planes may omit the field entirely, + in which case the bundle should be treated as a base checkpoint. + + When ``artifact_type == "lora"``, ``architectures``, ``num_params``, + ``estimated_memory_gb``, and ``max_context_length`` describe the + *base* model resolved from ``adapter_config.json`` — not the + adapter itself. + """ + detected_base_model_id: str | None = None + """Base model id read from ``adapter_config.json::base_model_name_or_path``.
+ Populated only when ``artifact_type == "lora"``.""" + lora_rank: int | None = None + """``r`` value from the adapter's ``adapter_config.json``. + Populated only when ``artifact_type == "lora"``.""" + + +class ArchitectureInfo(BaseModel): + """A HuggingFace architecture the platform can serve. + + Returned as elements of :attr:`SupportedArchitectures.architectures`. + """ + + model_config = ConfigDict(extra="allow", frozen=True) + + name: str + """HuggingFace ``architectures[0]`` value (e.g. ``"LlamaForCausalLM"``, + ``"Qwen3VLMoeForConditionalGeneration"``).""" + capabilities: list[str] = Field(default_factory=list) + """Capability tags this architecture exposes — ``"tool_calling"``, + ``"vision"``, ``"image_input"``, ``"video_input"``, ``"streaming"``, + ``"json_mode"``.""" + + +class SupportedArchitectures(BaseModel): + """Catalog of model architectures supported for custom-model import. + + Returned by :meth:`CustomModels.supported_architectures`. The list + is updated alongside platform runtime upgrades; clients should not + cache it across build cycles.
+ """ + + model_config = ConfigDict(extra="allow", frozen=True) + + architectures: list[ArchitectureInfo] = Field(default_factory=list) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index c620e82..3c7efde 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -333,3 +333,286 @@ async def test_async_create_s3_requires_huggingface_model_id( ) assert exc_info.value.code == "missing_huggingface_model_id" assert respx_mock.calls.call_count == 0 + + +def test_create_s3_lora_passes_base_model_id( + client: Client, respx_mock: respx.MockRouter +) -> None: + """`base_model_id` is the only way to classify an S3 bundle as LoRA at create.""" + + route = respx_mock.post(cp_url("custom-models")).mock( + return_value=httpx.Response( + 201, + json=_model_payload( + weight_source="s3_presigned", + huggingface_model_id="org/qwen3-finetune", + artifact_type="lora", + base_model_id="Qwen/Qwen3-4B", + ), + ) + ) + + model = client.custom_models.create( + name="my-lora", + weight_source="s3_presigned", + huggingface_model_id="org/qwen3-finetune", + s3_url="https://example.com/lora.tar.gz", + base_model_id="Qwen/Qwen3-4B", + ) + + assert route.called + sent = json.loads(route.calls.last.request.content) + assert sent["base_model_id"] == "Qwen/Qwen3-4B" + assert model.artifact_type == "lora" + assert model.base_model_id == "Qwen/Qwen3-4B" + + +def test_create_huggingface_lora_override_passes_base_model_id( + client: Client, respx_mock: respx.MockRouter +) -> None: + """On HF imports `base_model_id` overrides `adapter_config.json::base_model_name_or_path`.""" + + route = respx_mock.post(cp_url("custom-models")).mock( + return_value=httpx.Response(201, json=_model_payload()) + ) + + client.custom_models.create( + name="my-lora", + huggingface_model_id="org/some-lora", + base_model_id="meta-llama/Llama-3-8B", + ) + + sent = json.loads(route.calls.last.request.content) + assert sent["base_model_id"] == "meta-llama/Llama-3-8B" + + +def 
test_get_returns_typed_lora_fields( + client: Client, respx_mock: respx.MockRouter +) -> None: + respx_mock.get(cp_url("custom-models/cm_lora")).mock( + return_value=httpx.Response( + 200, + json=_model_payload( + id="cm_lora", + artifact_type="lora", + base_model_id="Qwen/Qwen3-4B", + lora_adapter_name="my-lora", + lora_rank=16, + ), + ) + ) + + model = client.custom_models.get("cm_lora") + assert model.artifact_type == "lora" + assert model.base_model_id == "Qwen/Qwen3-4B" + assert model.lora_adapter_name == "my-lora" + assert model.lora_rank == 16 + + +def test_get_legacy_response_treats_artifact_type_as_none( + client: Client, respx_mock: respx.MockRouter +) -> None: + """Older control planes don't return `artifact_type`; SDK must tolerate it.""" + + respx_mock.get(cp_url("custom-models/cm_legacy")).mock( + return_value=httpx.Response(200, json=_model_payload(id="cm_legacy")) + ) + + model = client.custom_models.get("cm_legacy") + assert model.artifact_type is None + assert model.base_model_id is None + assert model.lora_adapter_name is None + assert model.lora_rank is None + + +def test_validate_returns_lora_fields( + client: Client, respx_mock: respx.MockRouter +) -> None: + respx_mock.post(cp_url("custom-models/validate")).mock( + return_value=httpx.Response( + 200, + json={ + "valid": True, + "artifact_type": "lora", + "detected_base_model_id": "Qwen/Qwen3-4B", + "lora_rank": 16, + "architectures": ["Qwen3ForCausalLM"], + "num_params": 7_500_000_000, + }, + ) + ) + + resp = client.custom_models.validate(huggingface_model_id="org/some-lora") + assert resp.valid is True + assert resp.artifact_type == "lora" + assert resp.detected_base_model_id == "Qwen/Qwen3-4B" + assert resp.lora_rank == 16 + + +def test_validate_forwards_model_size_gb( + client: Client, respx_mock: respx.MockRouter +) -> None: + route = respx_mock.post(cp_url("custom-models/validate")).mock( + return_value=httpx.Response(200, json={"valid": True}) + ) + + client.custom_models.validate( + 
huggingface_model_id="meta-llama/Llama-3-405B", + model_size_gb=812, + ) + sent = json.loads(route.calls.last.request.content) + assert sent["model_size_gb"] == 812 + + +def test_update_sends_patch_with_body( + client: Client, respx_mock: respx.MockRouter +) -> None: + route = respx_mock.patch(cp_url("custom-models/cm_01")).mock( + return_value=httpx.Response( + 200, + json=_model_payload( + status="ready", + min_replicas=1, + max_replicas=4, + cooldown_seconds=300, + display_name="renamed", + ), + ) + ) + + model = client.custom_models.update( + "cm_01", + name="renamed", + min_replicas=1, + max_replicas=4, + cooldown_seconds=300, + ) + + assert route.called + assert route.calls.last.request.method == "PATCH" + sent = json.loads(route.calls.last.request.content) + assert sent == { + "name": "renamed", + "min_replicas": 1, + "max_replicas": 4, + "cooldown_seconds": 300, + } + assert model.min_replicas == 1 + assert model.max_replicas == 4 + assert model.cooldown_seconds == 300 + + +def test_update_rejects_empty_body( + client: Client, respx_mock: respx.MockRouter +) -> None: + """An empty PATCH must fail client-side, never hitting the wire.""" + + with pytest.raises(ValidationError) as exc_info: + client.custom_models.update("cm_01") + assert exc_info.value.code == "empty_update" + assert respx_mock.calls.call_count == 0 + + +def test_update_404_maps_to_not_found( + client: Client, respx_mock: respx.MockRouter +) -> None: + respx_mock.patch(cp_url("custom-models/missing")).mock( + return_value=httpx.Response( + 404, json={"code": "not_found", "message": "no such model"} + ) + ) + + with pytest.raises(NotFoundError): + client.custom_models.update("missing", min_replicas=1) + + +def test_update_extra_passes_through( + client: Client, respx_mock: respx.MockRouter +) -> None: + """`extra` lets callers PATCH future fields without an SDK release.""" + + route = respx_mock.patch(cp_url("custom-models/cm_01")).mock( + return_value=httpx.Response(200, json=_model_payload()) 
+ ) + + client.custom_models.update("cm_01", extra={"future_field": "value"}) + sent = json.loads(route.calls.last.request.content) + assert sent == {"future_field": "value"} + + +def test_supported_architectures_returns_typed( + client: Client, respx_mock: respx.MockRouter +) -> None: + respx_mock.get(cp_url("custom-models/supported-architectures")).mock( + return_value=httpx.Response( + 200, + json={ + "architectures": [ + { + "name": "LlamaForCausalLM", + "capabilities": ["tool_calling", "streaming"], + }, + { + "name": "Qwen3VLMoeForConditionalGeneration", + "capabilities": ["vision", "image_input", "streaming"], + }, + ], + }, + ) + ) + + catalog = client.custom_models.supported_architectures() + assert [a.name for a in catalog.architectures] == [ + "LlamaForCausalLM", + "Qwen3VLMoeForConditionalGeneration", + ] + assert catalog.architectures[1].capabilities == ["vision", "image_input", "streaming"] + + +@pytest.mark.asyncio +async def test_async_update_sends_patch_with_body( + async_client: AsyncClient, respx_mock: respx.MockRouter +) -> None: + route = respx_mock.patch(cp_url("custom-models/cm_01")).mock( + return_value=httpx.Response( + 200, json=_model_payload(min_replicas=2, max_replicas=6) + ) + ) + + model = await async_client.custom_models.update( + "cm_01", min_replicas=2, max_replicas=6 + ) + + assert route.called + sent = json.loads(route.calls.last.request.content) + assert sent == {"min_replicas": 2, "max_replicas": 6} + assert model.min_replicas == 2 + assert model.max_replicas == 6 + + +@pytest.mark.asyncio +async def test_async_update_rejects_empty_body( + async_client: AsyncClient, respx_mock: respx.MockRouter +) -> None: + with pytest.raises(ValidationError): + await async_client.custom_models.update("cm_01") + assert respx_mock.calls.call_count == 0 + + +@pytest.mark.asyncio +async def test_async_supported_architectures( + async_client: AsyncClient, respx_mock: respx.MockRouter +) -> None: + 
respx_mock.get(cp_url("custom-models/supported-architectures")).mock( + return_value=httpx.Response( + 200, + json={ + "architectures": [ + {"name": "LlamaForCausalLM", "capabilities": ["tool_calling"]}, + ], + }, + ) + ) + + catalog = await async_client.custom_models.supported_architectures() + assert catalog.architectures[0].name == "LlamaForCausalLM"