diff --git a/simvue/api/objects/artifact/fetch.py b/simvue/api/objects/artifact/fetch.py index 80a48312..67aab36d 100644 --- a/simvue/api/objects/artifact/fetch.py +++ b/simvue/api/objects/artifact/fetch.py @@ -115,7 +115,7 @@ def from_run( @classmethod def from_name( - cls, run_id: str, name: str, **kwargs + cls, run_id: str, name: str, force_overwrite: bool = False, **kwargs ) -> typing.Union[FileArtifact | ObjectArtifact, None]: """Retrieve an artifact by name. @@ -125,11 +125,20 @@ def from_name( the identifier of the run to retrieve from. name : str the name of the artifact to retrieve. + force_overwrite : bool, optional + if duplicates are detected force download + the first match, default of False + will raise an exception Returns ------- FileArtifact | ObjectArtifact | None the artifact if found + + Raises + ------ + RuntimeError + when duplicate artifacts are found within a single run """ _temp = ArtifactBase(**kwargs) _url = URL(_temp._user_config.server.url) / f"runs/{run_id}/artifacts" @@ -144,7 +153,7 @@ def from_name( if _response.status_code == http.HTTPStatus.NOT_FOUND or not _json_response: raise ObjectNotFoundError(_temp._label, name, extra=f"for run '{run_id}'") - if (_n_res := len(_json_response)) > 1: + if (_n_res := len(_json_response)) > 1 and not force_overwrite: raise RuntimeError( f"Expected single result for artifact '{name}' for run '{run_id}'" f" but got {_n_res}" diff --git a/tests/unit/test_file_artifact.py b/tests/unit/test_file_artifact.py index 0c778d17..8127d2c5 100644 --- a/tests/unit/test_file_artifact.py +++ b/tests/unit/test_file_artifact.py @@ -175,3 +175,48 @@ def test_file_artifact_creation_offline_updated(offline_cache_setup, caplog, sna _run.delete() _folder.delete() + +@pytest.mark.api +@pytest.mark.online +@pytest.mark.parametrize( + "force_overwrite", (True, False), + ids=("allow_overwrite", "raise_exception") +) +def test_download_duplicate_artifact(force_overwrite: bool) -> None: + _uuid: str = f"{uuid.uuid4()}".split("-")[0] + _folder_name = f"/simvue_unit_testing/{_uuid}" + _folder = Folder.new(path=_folder_name) + _run = Run.new(folder=_folder_name) + _folder.commit() + _run.commit() + + with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_f: + _path = pathlib.Path(temp_f.name) + with _path.open("w") as out_f: + out_f.write(f"Hello World! {_uuid}") + for _ in range(2): + _artifact = FileArtifact.new( + name=f"test_file_artifact_{_uuid}", + file_path=_path, + storage=None, + mime_type=None, + metadata=None, + ) + _artifact.attach_to_run(_run.id, "input") + time.sleep(1) + + if force_overwrite: + assert Artifact.from_name( + run_id=_run.id, + name=f"test_file_artifact_{_uuid}", + force_overwrite=True + ) + else: + with pytest.raises(RuntimeError): + Artifact.from_name( + run_id=_run.id, + name=f"test_file_artifact_{_uuid}", + force_overwrite=False + ) + _run.delete() + _folder.delete()