diff --git a/CHANGELOG.md b/CHANGELOG.md index 29f59d10..1a4305fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,33 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.17.12](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.17.12) - 2026-02-23 + +### Added +- `Dataset.deduplicate()` method to deduplicate images using perceptual hashing. Accepts optional `reference_ids` to deduplicate specific items, or deduplicates the entire dataset when only `threshold` is provided. Required `threshold` parameter (0-64) controls similarity matching (lower = stricter, 0 = exact matches only). +- `Dataset.deduplicate_by_ids()` method for deduplication using internal `dataset_item_ids` directly, avoiding the reference ID to item ID mapping for improved efficiency. +- `DeduplicationResult` and `DeduplicationStats` dataclasses for structured deduplication results. + +Example usage: + +```python +dataset = client.get_dataset("ds_...") + +# Deduplicate entire dataset +result = dataset.deduplicate(threshold=10) + +# Deduplicate specific items by reference IDs +result = dataset.deduplicate(threshold=10, reference_ids=["ref_1", "ref_2", "ref_3"]) + +# Deduplicate by internal item IDs (more efficient if you have them) +result = dataset.deduplicate_by_ids(threshold=10, dataset_item_ids=["item_1", "item_2"]) + +# Access results +print(f"Threshold: {result.stats.threshold}") +print(f"Original: {result.stats.original_count}, Unique: {result.stats.deduplicated_count}") +print(result.unique_reference_ids) +``` + ## [0.17.11](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.17.11) - 2025-11-03 ### Added diff --git a/nucleus/__init__.py b/nucleus/__init__.py index 3f970c2b..df97ddec 100644 --- a/nucleus/__init__.py +++ b/nucleus/__init__.py @@ -4,6 +4,8 @@ "AsyncJob", "EmbeddingsExportJob", "BoxAnnotation", + "DeduplicationResult", + "DeduplicationStats", "BoxPrediction", "CameraParams", "CategoryAnnotation", @@ -128,6 +130,7 @@ from .data_transfer_object.job_status import JobInfoRequestPayload from .dataset import Dataset from .dataset_item import DatasetItem +from .deduplication import DeduplicationResult, DeduplicationStats from .deprecation_warning import deprecated from .errors import ( DatasetItemRetrievalError, diff --git a/nucleus/constants.py b/nucleus/constants.py index 0a2bbf46..ebad94f5 100644 --- a/nucleus/constants.py +++ b/nucleus/constants.py @@ -149,6 +149,7 @@ SLICE_TAGS_KEY = "slice_tags" TAXONOMY_NAME_KEY = "taxonomy_name" TASK_ID_KEY = "task_id" +THRESHOLD_KEY = "threshold" TRACK_REFERENCE_ID_KEY = "track_reference_id" TRACK_REFERENCE_IDS_KEY = "track_reference_ids" TRACKS_KEY = "tracks" diff --git a/nucleus/dataset.py b/nucleus/dataset.py index ea95f840..bc1f244f 100644 --- a/nucleus/dataset.py +++ b/nucleus/dataset.py @@ -67,6 +67,7 @@ REQUEST_ID_KEY, SCENE_IDS_KEY, SLICE_ID_KEY, + THRESHOLD_KEY, TRACK_REFERENCE_IDS_KEY, TRACKS_KEY, TRAINED_SLICE_ID_KEY, @@ -83,6 +84,7 @@ check_items_have_dimensions, ) from .dataset_item_uploader import DatasetItemUploader +from .deduplication import DeduplicationResult, DeduplicationStats from .deprecation_warning import deprecated from .errors import NotFoundError, NucleusAPIError from .job import CustomerJobTypes, jobs_status_overview @@ -1006,6 +1008,116 @@ def create_slice_by_ids( ) return Slice(response[SLICE_ID_KEY], self._client) + def deduplicate( + self, + threshold: int, + reference_ids: Optional[List[str]] = None, + ) -> DeduplicationResult: + """Deduplicate images or frames using user-defined reference IDs. + + This method can deduplicate an entire dataset (when reference_ids is omitted) + or a specific subset of items identified by the reference_id you assigned + when uploading (e.g., "image_001", "frame_xyz"). To deduplicate using + internal Nucleus item IDs instead, use `deduplicate_by_ids()`. + + Parameters: + threshold: Hamming distance threshold (0-64). Lower = stricter. + 0 = exact matches only. + reference_ids: Optional list of user-defined reference IDs to deduplicate. + If not provided (or None), deduplicates the entire dataset. + Cannot be an empty list - use None for entire dataset. + + Returns: + DeduplicationResult with unique_reference_ids, unique_item_ids, and stats. + + Raises: + ValueError: If reference_ids is an empty list (use None for entire dataset). + NucleusAPIError: If threshold is not an integer between 0 and 64 inclusive. + NucleusAPIError: If any reference_id is not found in the dataset. + NucleusAPIError: If any item is missing a perceptual hash (pHash). + Contact Scale support if this occurs. + + Note: + - For scene datasets, this deduplicates the underlying scene frames, + not the scenes themselves. Frame reference IDs or dataset item IDs + should be provided for scene datasets. + - For very large datasets, this operation may take significant time. + """ + # Client-side validation + if reference_ids is not None and len(reference_ids) == 0: + raise ValueError( + "reference_ids cannot be empty. Omit reference_ids parameter to deduplicate entire dataset." + ) + + payload: Dict[str, Any] = {THRESHOLD_KEY: threshold} + if reference_ids is not None: + payload[REFERENCE_IDS_KEY] = reference_ids + + response = self._client.make_request( + payload, f"dataset/{self.id}/deduplicate" + ) + return DeduplicationResult( + unique_item_ids=response["unique_item_ids"], + unique_reference_ids=response["unique_reference_ids"], + stats=DeduplicationStats( + threshold=threshold, + original_count=response["stats"]["original_count"], + deduplicated_count=response["stats"]["deduplicated_count"], + ), + ) + + def deduplicate_by_ids( + self, + threshold: int, + dataset_item_ids: List[str], + ) -> DeduplicationResult: + """Deduplicate images or frames using internal Nucleus dataset item IDs. + + This method identifies items by internal Nucleus IDs (e.g., "di_abc123...") + which are system-assigned when items are uploaded. To deduplicate using + your own user-defined reference IDs instead, or to deduplicate the entire + dataset, use `deduplicate()`. + + Parameters: + threshold: Hamming distance threshold (0-64). Lower = stricter. + 0 = exact matches only. + dataset_item_ids: List of internal Nucleus dataset item IDs to deduplicate. + These IDs are generated by Nucleus; they are not + user-defined reference IDs. Must be non-empty. + + Returns: + DeduplicationResult with unique_item_ids, unique_reference_ids, and stats. + + Raises: + ValueError: If dataset_item_ids is empty. + NucleusAPIError: If threshold is not an integer between 0 and 64 inclusive. + NucleusAPIError: If any dataset_item_id is not found in the dataset. + NucleusAPIError: If any item is missing a perceptual hash (pHash). + Contact Scale support if this occurs. + """ + # Client-side validation + if not dataset_item_ids: + raise ValueError( + "dataset_item_ids must be non-empty. Use deduplicate() for entire dataset." + ) + + payload = { + DATASET_ITEM_IDS_KEY: dataset_item_ids, + THRESHOLD_KEY: threshold, + } + response = self._client.make_request( + payload, f"dataset/{self.id}/deduplicate" + ) + return DeduplicationResult( + unique_item_ids=response["unique_item_ids"], + unique_reference_ids=response["unique_reference_ids"], + stats=DeduplicationStats( + threshold=threshold, + original_count=response["stats"]["original_count"], + deduplicated_count=response["stats"]["deduplicated_count"], + ), + ) + def build_slice( self, name: str, diff --git a/nucleus/deduplication.py b/nucleus/deduplication.py new file mode 100644 index 00000000..f427c004 --- /dev/null +++ b/nucleus/deduplication.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +from typing import List + + +@dataclass +class DeduplicationStats: + threshold: int + original_count: int + deduplicated_count: int + + +@dataclass +class DeduplicationResult: + unique_item_ids: List[str] # Internal dataset item IDs + unique_reference_ids: List[str] # User-defined reference IDs + stats: DeduplicationStats diff --git a/pyproject.toml b/pyproject.toml index 4fe1aaa2..6622dcd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ ignore = ["E501", "E741", "E731", "F401"] # Easy ignore for getting it running [tool.poetry] name = "scale-nucleus" -version = "0.17.11" +version = "0.17.12" description = "The official Python client library for Nucleus, the Data Platform for AI" license = "MIT" authors = ["Scale AI Nucleus Team "] diff --git a/tests/helpers.py b/tests/helpers.py index bef9cc50..03076e6f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -24,6 +24,8 @@ EVAL_FUNCTION_THRESHOLD = 0.5 EVAL_FUNCTION_COMPARISON = ThresholdComparison.GREATER_THAN_EQUAL_TO +DEDUP_DEFAULT_TEST_THRESHOLD = 10 + TEST_IMG_URLS = [ "https://github.com/scaleapi/nucleus-python-client/raw/master/tests/testdata/airplane.jpeg", diff --git a/tests/test_deduplication.py b/tests/test_deduplication.py new file mode 100644 index 00000000..857cd126 --- /dev/null +++ b/tests/test_deduplication.py @@ -0,0 +1,386 @@ +import pytest + +from nucleus import Dataset, DatasetItem, NucleusClient, VideoScene +from nucleus.deduplication import DeduplicationResult +from nucleus.errors import NucleusAPIError + +from .helpers import ( + DEDUP_DEFAULT_TEST_THRESHOLD, + TEST_DATASET_ITEMS, + TEST_DATASET_NAME, + TEST_IMG_URLS, + TEST_VIDEO_DATASET_NAME, + TEST_VIDEO_SCENES, + TEST_VIDEO_URL, +) + + +def test_deduplicate_empty_reference_ids_raises_error(): + fake_dataset = Dataset("fake", NucleusClient("fake")) + with pytest.raises(ValueError, match="reference_ids cannot be empty"): + fake_dataset.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD, reference_ids=[]) + + +def test_deduplicate_by_ids_empty_list_raises_error(): + fake_dataset = Dataset("fake", NucleusClient("fake")) + with pytest.raises(ValueError, match="dataset_item_ids must be non-empty"): + fake_dataset.deduplicate_by_ids(threshold=DEDUP_DEFAULT_TEST_THRESHOLD, dataset_item_ids=[]) + + +@pytest.fixture(scope="module") +def dataset_image_sync(CLIENT): + """Image dataset uploaded synchronously.""" + ds = CLIENT.create_dataset(TEST_DATASET_NAME + " dedup sync", is_scene=False) + try: + ds.append(TEST_DATASET_ITEMS) + yield ds + finally: + CLIENT.delete_dataset(ds.id) + + +@pytest.fixture(scope="module") +def dataset_image_async(CLIENT): + """Image dataset uploaded asynchronously.""" + ds = CLIENT.create_dataset(TEST_DATASET_NAME + " dedup async", is_scene=False) + try: + job = ds.append(TEST_DATASET_ITEMS, asynchronous=True) + job.sleep_until_complete() + yield ds + finally: + CLIENT.delete_dataset(ds.id) + + +@pytest.mark.integration +def test_deduplicate_image_sync_entire_dataset(dataset_image_sync): + """Test deduplication on image dataset uploaded synchronously.""" + result = dataset_image_sync.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD) + assert isinstance(result, DeduplicationResult) + assert len(result.unique_reference_ids) > 0 + assert len(result.unique_item_ids) > 0 + assert result.stats.original_count == len(TEST_DATASET_ITEMS) + + +@pytest.mark.integration +def test_deduplicate_image_sync_with_reference_ids(dataset_image_sync): + """Test deduplication with reference IDs on image dataset uploaded synchronously.""" + reference_ids = [item.reference_id for item in TEST_DATASET_ITEMS[:2]] + result = dataset_image_sync.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD, reference_ids=reference_ids) + assert isinstance(result, DeduplicationResult) + assert result.stats.original_count == len(reference_ids) + assert len(result.unique_reference_ids) <= len(reference_ids) + assert len(result.unique_item_ids) <= len(reference_ids) + + +@pytest.mark.integration +def test_deduplicate_image_sync_by_ids(dataset_image_sync): + """Test deduplicate_by_ids on image dataset uploaded synchronously.""" + initial_result = dataset_image_sync.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD) + item_ids = initial_result.unique_item_ids + assert len(item_ids) > 0 + + result = dataset_image_sync.deduplicate_by_ids(threshold=DEDUP_DEFAULT_TEST_THRESHOLD, dataset_item_ids=item_ids) + assert isinstance(result, DeduplicationResult) + assert result.stats.original_count == len(item_ids) + assert result.unique_item_ids == initial_result.unique_item_ids + + +@pytest.mark.integration +def test_deduplicate_image_async_entire_dataset(dataset_image_async): + """Test deduplication on image dataset uploaded asynchronously.""" + result = dataset_image_async.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD) + assert isinstance(result, DeduplicationResult) + assert len(result.unique_reference_ids) > 0 + assert len(result.unique_item_ids) > 0 + assert result.stats.original_count == len(TEST_DATASET_ITEMS) + + +@pytest.mark.integration +def test_deduplicate_image_async_with_reference_ids(dataset_image_async): + """Test deduplication with reference IDs on image dataset uploaded asynchronously.""" + reference_ids = [item.reference_id for item in TEST_DATASET_ITEMS[:2]] + result = dataset_image_async.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD, reference_ids=reference_ids) + assert isinstance(result, DeduplicationResult) + assert result.stats.original_count == len(reference_ids) + assert len(result.unique_reference_ids) <= len(reference_ids) + assert len(result.unique_item_ids) <= len(reference_ids) + + +@pytest.mark.integration +def test_deduplicate_image_async_by_ids(dataset_image_async): + """Test deduplicate_by_ids on image dataset uploaded asynchronously.""" + initial_result = dataset_image_async.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD) + item_ids = initial_result.unique_item_ids + assert len(item_ids) > 0 + + result = dataset_image_async.deduplicate_by_ids(threshold=DEDUP_DEFAULT_TEST_THRESHOLD, dataset_item_ids=item_ids) + assert isinstance(result, DeduplicationResult) + assert result.stats.original_count == len(item_ids) + assert result.unique_item_ids == initial_result.unique_item_ids + + +@pytest.fixture(scope="module") +def dataset_video_scene_async(CLIENT): + """Video scene dataset (with frames) uploaded asynchronously.""" + ds = CLIENT.create_dataset(TEST_VIDEO_DATASET_NAME + " dedup async", is_scene=True) + try: + scene_1 = TEST_VIDEO_SCENES["scenes"][0] + scenes = [VideoScene.from_json(scene_1)] + job = ds.append(scenes, asynchronous=True) + job.sleep_until_complete() + yield ds + finally: + CLIENT.delete_dataset(ds.id) + + +def _get_scene_frame_ref_ids(): + """Extract frame reference IDs from TEST_VIDEO_SCENES scene_1.""" + return [frame["reference_id"] for frame in TEST_VIDEO_SCENES["scenes"][0]["frames"]] + + +@pytest.mark.integration +def test_deduplicate_video_scene_async_entire_dataset(dataset_video_scene_async): + """Test deduplication on video scene dataset uploaded asynchronously.""" + result = dataset_video_scene_async.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD) + assert isinstance(result, DeduplicationResult) + assert len(result.unique_reference_ids) > 0 + assert len(result.unique_item_ids) > 0 + assert result.stats.original_count == len(_get_scene_frame_ref_ids()) + + +@pytest.mark.integration +def test_deduplicate_video_scene_async_with_frame_reference_ids(dataset_video_scene_async): + """Test deduplication with frame reference IDs on video scene dataset uploaded asynchronously.""" + frame_ref_ids = _get_scene_frame_ref_ids() + result = dataset_video_scene_async.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD, reference_ids=frame_ref_ids) + assert isinstance(result, DeduplicationResult) + assert result.stats.original_count == len(frame_ref_ids) + assert len(result.unique_reference_ids) <= len(frame_ref_ids) + assert len(result.unique_item_ids) <= len(frame_ref_ids) + + +@pytest.mark.integration +def test_deduplicate_video_scene_async_by_ids(dataset_video_scene_async): + """Test deduplicate_by_ids on video scene dataset uploaded asynchronously.""" + initial_result = dataset_video_scene_async.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD) + item_ids = initial_result.unique_item_ids + assert len(item_ids) > 0 + + result = dataset_video_scene_async.deduplicate_by_ids( + threshold=DEDUP_DEFAULT_TEST_THRESHOLD, dataset_item_ids=item_ids + ) + assert isinstance(result, DeduplicationResult) + assert result.stats.original_count == len(item_ids) + assert result.unique_item_ids == initial_result.unique_item_ids + + +@pytest.fixture(scope="module") +def dataset_video_url_async(CLIENT): + """Video URL dataset uploaded asynchronously.""" + ds = CLIENT.create_dataset(TEST_VIDEO_DATASET_NAME + " video_url dedup async", is_scene=True) + try: + scene = VideoScene.from_json({ + "reference_id": "video_url_scene_async", + "video_url": TEST_VIDEO_URL, + "metadata": {"test": "video_url_dedup_async"}, + }) + job = ds.append([scene], asynchronous=True) + job.sleep_until_complete() + yield ds + finally: + CLIENT.delete_dataset(ds.id) + + +@pytest.mark.integration +def test_deduplicate_video_url_async_entire_dataset(dataset_video_url_async): + """Test deduplication on video URL dataset uploaded asynchronously.""" + result = dataset_video_url_async.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD) + assert isinstance(result, DeduplicationResult) + assert len(result.unique_reference_ids) > 0 + assert len(result.unique_item_ids) > 0 + assert result.stats.original_count > 0 + + +@pytest.mark.integration +def test_deduplicate_video_url_async_by_ids(dataset_video_url_async): + """Test deduplicate_by_ids on video URL dataset uploaded asynchronously.""" + initial_result = dataset_video_url_async.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD) + item_ids = initial_result.unique_item_ids + assert len(item_ids) > 0 + + result = dataset_video_url_async.deduplicate_by_ids( + threshold=DEDUP_DEFAULT_TEST_THRESHOLD, dataset_item_ids=item_ids + ) + assert isinstance(result, DeduplicationResult) + assert result.stats.original_count == len(item_ids) + assert result.unique_item_ids == initial_result.unique_item_ids + + +# Edge case tests + + +@pytest.mark.integration +def test_deduplicate_threshold_zero(dataset_image_sync): + """Verify threshold=0 (exact match only) succeeds and returns correct stats.""" + result = dataset_image_sync.deduplicate(threshold=0) + assert isinstance(result, DeduplicationResult) + assert result.stats.threshold == 0 + + +@pytest.mark.integration +def test_deduplicate_threshold_max(dataset_image_sync): + """Verify threshold=64 (maximum allowed) succeeds and returns correct stats.""" + result = dataset_image_sync.deduplicate(threshold=64) + assert isinstance(result, DeduplicationResult) + assert result.stats.threshold == 64 + + +@pytest.mark.integration +def test_deduplicate_threshold_negative(dataset_image_sync): + """Verify negative threshold raises NucleusAPIError (must be >= 0).""" + with pytest.raises(NucleusAPIError): + dataset_image_sync.deduplicate(threshold=-1) + + +@pytest.mark.integration +def test_deduplicate_threshold_too_high(dataset_image_sync): + """Verify threshold > 64 raises NucleusAPIError (must be <= 64).""" + with pytest.raises(NucleusAPIError): + dataset_image_sync.deduplicate(threshold=65) + + +@pytest.mark.integration +def test_deduplicate_threshold_non_integer(dataset_image_sync): + """Verify non-integer threshold raises NucleusAPIError.""" + with pytest.raises(NucleusAPIError): + dataset_image_sync.deduplicate(threshold=10.5) + + +@pytest.mark.integration +def test_deduplicate_nonexistent_reference_id(dataset_image_sync): + """Verify nonexistent reference_id raises NucleusAPIError.""" + with pytest.raises(NucleusAPIError): + dataset_image_sync.deduplicate( + threshold=DEDUP_DEFAULT_TEST_THRESHOLD, reference_ids=["nonexistent_ref_id"] + ) + + +@pytest.mark.integration +def test_deduplicate_by_ids_nonexistent_id(dataset_image_sync): + """Verify nonexistent dataset_item_id raises NucleusAPIError.""" + with pytest.raises(NucleusAPIError): + dataset_image_sync.deduplicate_by_ids( + threshold=DEDUP_DEFAULT_TEST_THRESHOLD, dataset_item_ids=["di_nonexistent"] + ) + + +@pytest.mark.integration +def test_deduplicate_idempotency(dataset_image_sync): + """Verify repeated deduplication calls return consistent results.""" + result1 = dataset_image_sync.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD) + result2 = dataset_image_sync.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD) + + assert result1.unique_item_ids == result2.unique_item_ids + assert result1.unique_reference_ids == result2.unique_reference_ids + assert result1.stats.original_count == result2.stats.original_count + assert result1.stats.deduplicated_count == result2.stats.deduplicated_count + + +@pytest.mark.integration +def test_deduplicate_response_invariants(dataset_image_sync): + """Verify response maintains expected invariants between fields.""" + result = dataset_image_sync.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD) + + assert len(result.unique_item_ids) == len(result.unique_reference_ids) + assert result.stats.deduplicated_count == len(result.unique_item_ids) + assert result.stats.deduplicated_count <= result.stats.original_count + assert result.stats.threshold == DEDUP_DEFAULT_TEST_THRESHOLD + + +@pytest.mark.integration +def test_deduplicate_by_ids_threshold_negative(dataset_image_sync): + """Verify deduplicate_by_ids rejects negative threshold.""" + initial_result = dataset_image_sync.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD) + item_ids = initial_result.unique_item_ids + + with pytest.raises(NucleusAPIError): + dataset_image_sync.deduplicate_by_ids(threshold=-1, dataset_item_ids=item_ids) + + +@pytest.mark.integration +def test_deduplicate_by_ids_threshold_too_high(dataset_image_sync): + """Verify deduplicate_by_ids rejects threshold > 64.""" + initial_result = dataset_image_sync.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD) + item_ids = initial_result.unique_item_ids + + with pytest.raises(NucleusAPIError): + dataset_image_sync.deduplicate_by_ids(threshold=65, dataset_item_ids=item_ids) + + +@pytest.mark.integration +def test_deduplicate_single_item(dataset_image_sync): + """Verify single item deduplication returns that item as unique.""" + reference_ids = [TEST_DATASET_ITEMS[0].reference_id] + result = dataset_image_sync.deduplicate( + threshold=DEDUP_DEFAULT_TEST_THRESHOLD, reference_ids=reference_ids + ) + + assert result.stats.original_count == 1 + assert result.stats.deduplicated_count == 1 + assert len(result.unique_reference_ids) == 1 + + +@pytest.fixture(scope="function") +def dataset_empty(CLIENT): + """Empty dataset with no items.""" + ds = CLIENT.create_dataset(TEST_DATASET_NAME + " empty", is_scene=False) + try: + yield ds + finally: + CLIENT.delete_dataset(ds.id) + + +@pytest.mark.integration +def test_deduplicate_empty_dataset(dataset_empty): + """Verify deduplication on empty dataset returns zero counts.""" + result = dataset_empty.deduplicate(threshold=DEDUP_DEFAULT_TEST_THRESHOLD) + + assert result.stats.original_count == 0 + assert result.stats.deduplicated_count == 0 + assert len(result.unique_reference_ids) == 0 + assert len(result.unique_item_ids) == 0 + + +@pytest.fixture(scope="function") +def dataset_with_duplicates(CLIENT): + """Dataset with duplicate images (same image uploaded twice).""" + ds = CLIENT.create_dataset(TEST_DATASET_NAME + " duplicates", is_scene=False) + try: + items = [ + DatasetItem(TEST_IMG_URLS[0], reference_id="img_original"), + DatasetItem(TEST_IMG_URLS[0], reference_id="img_duplicate"), + DatasetItem(TEST_IMG_URLS[1], reference_id="img_different"), + ] + ds.append(items) + yield ds + finally: + CLIENT.delete_dataset(ds.id) + + +@pytest.mark.integration +def test_deduplicate_identifies_duplicates(dataset_with_duplicates): + """Verify deduplication actually identifies duplicate images.""" + result = dataset_with_duplicates.deduplicate(threshold=0) + + assert result.stats.original_count == 3 + # With threshold=0, the two identical images should be deduplicated to one + assert result.stats.deduplicated_count == 2 + assert len(result.unique_reference_ids) == 2 + + +@pytest.mark.integration +def test_deduplicate_distinct_images_all_unique(dataset_image_sync): + """Distinct images should all remain after deduplication.""" + result = dataset_image_sync.deduplicate(threshold=0) + + # With threshold=0 (exact match only), all distinct images should be unique + assert result.stats.deduplicated_count == result.stats.original_count diff --git a/tests/test_jobs.py b/tests/test_jobs.py index fb5a631a..3b600665 100644 --- a/tests/test_jobs.py +++ b/tests/test_jobs.py @@ -1,10 +1,9 @@ -import time -from pathlib import Path - import pytest from nucleus import AsyncJob, NucleusClient +from .helpers import TEST_DATASET_ITEMS, TEST_DATASET_NAME + def test_reprs(): # Have to define here in order to have access to all relevant objects @@ -23,11 +22,32 @@ def test_repr(test_object: any): ) -def test_job_listing_and_retrieval(CLIENT): +@pytest.fixture(scope="module") +def job_from_dataset_upload(CLIENT): + """Create a job by doing an async dataset upload.""" + ds = CLIENT.create_dataset(TEST_DATASET_NAME + " job test", is_scene=False) + try: + job = ds.append(TEST_DATASET_ITEMS, asynchronous=True) + job.sleep_until_complete() + yield job + finally: + CLIENT.delete_dataset(ds.id) + + +@pytest.mark.integration +def test_job_listing(CLIENT): + """Test that list_jobs returns results.""" jobs = CLIENT.list_jobs() - assert len(jobs) > 0, "No jobs found" - fetch_id = jobs[0].job_id - fetched_job = CLIENT.get_job(fetch_id) - # job_last_known_status can change - fetched_job.job_last_known_status = jobs[0].job_last_known_status - assert fetched_job == jobs[0] + assert isinstance(jobs, list) + # Just verify the API works and returns AsyncJob objects + if len(jobs) > 0: + assert hasattr(jobs[0], "job_id") + + +@pytest.mark.integration +def test_job_retrieval(CLIENT, job_from_dataset_upload): + """Test that we can retrieve a job we created by ID.""" + known_job_id = job_from_dataset_upload.job_id + + fetched_job = CLIENT.get_job(known_job_id) + assert fetched_job.job_id == known_job_id