Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 76 additions & 3 deletions app/api/v1/tags.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from uuid import UUID

from fastapi import Query, Depends, Security, APIRouter
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.session import get_db
Expand Down Expand Up @@ -78,18 +79,23 @@ async def delete_tag_group(group_id: str, force: bool = Query(False), db: AsyncS


@router.post("/{group_id}/tags", dependencies=[Security(require_write_access)])
async def add_tag_to_group(group_id: str, data: TagCreate, db: AsyncSession = Depends(get_db)) -> dict:
async def add_tag_to_group(
group_id: str,
data: TagCreate,
skip_duplicates: bool = Query(False, description="Skip duplicate tags instead of raising error"),
db: AsyncSession = Depends(get_db),
) -> dict:
"""Add a tag to a group."""
try:
UUID(group_id)
except ValueError:
raise APIException(404, f"Tag group {group_id} not found", 404)
service = TagService(db)
try:
group = await service.add_tag_to_group(group_id, data)
group, was_created = await service.add_tag_to_group(group_id, data, skip_duplicates=skip_duplicates)
if not group:
raise APIException(404, f"Tag group {group_id} not found", 404)
return success_response(group)
return success_response({"group": group, "wasCreated": was_created})
except ValueError as e:
raise APIException(409, str(e), 409)

Expand Down Expand Up @@ -122,3 +128,70 @@ async def remove_tag(group_id: str, tag_id: str, db: AsyncSession = Depends(get_
if not group:
raise APIException(404, f"Tag {tag_id} not found in group {group_id}", 404)
return success_response(group)


class BulkTagImportRequest(BaseModel):
    """Request schema for bulk tag import."""

    # Mapping of group name -> list of tag names to create in that group.
    # Groups that do not exist yet are created on the fly by the endpoint.
    groups: dict[str, list[str]]  # group_name -> list of tag names


class BulkTagImportResponse(BaseModel):
    """Response schema for bulk tag import.

    Field names are camelCase to match the API's JSON payload convention
    (see `createdDate` / `lastModifiedDate` elsewhere in this API).
    """

    groupsCreated: int  # number of tag groups newly created by this import
    tagsCreated: int  # number of tags newly created across all groups
    tagsSkipped: int  # duplicates that already existed and were skipped
    warnings: list[str]  # human-readable notes for skipped or failed items


@router.post("/bulk", response_model=dict)
async def bulk_import_tags(data: BulkTagImportRequest, db: AsyncSession = Depends(get_db)):
"""Bulk import tags with duplicate handling."""
service = TagService(db)

groups_created = 0
tags_created = 0
tags_skipped = 0
warnings: list[str] = []

# Get existing tag groups
existing_summaries = await service.get_all_groups_summary()
existing_group_map = {g.name: g.id for g in existing_summaries}

for group_name, tag_names in data.groups.items():
# Get or create group
group_id = existing_group_map.get(group_name)
if not group_id:
# Create new group
try:
group = await service.create_group(TagGroupCreate(name=group_name))
group_id = group.id
existing_group_map[group_name] = group_id
groups_created += 1
except ValueError as e:
warnings.append(f"Failed to create group '{group_name}': {e}")
continue

# Add tags to group
for tag_name in tag_names:
try:
_, was_created = await service.add_tag_to_group(
group_id, TagCreate(name=tag_name), skip_duplicates=True
)
if was_created:
tags_created += 1
else:
tags_skipped += 1
warnings.append(f"Tag '{tag_name}' already exists in group '{group_name}'")
except Exception as e:
warnings.append(f"Failed to add tag '{tag_name}' to group '{group_name}': {e}")

return success_response(
BulkTagImportResponse(
groupsCreated=groups_created,
tagsCreated=tags_created,
tagsSkipped=tags_skipped,
warnings=warnings,
)
)
52 changes: 39 additions & 13 deletions app/services/tag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,31 +100,57 @@ async def delete_group(self, group_id: str, force: bool = False) -> bool:
await self.group_repo.delete(group)
return True

async def add_tag_to_group(
    self, group_id: str, data: TagCreate, skip_duplicates: bool = False
) -> tuple[TagGroupDTO | None, bool]:
    """
    Add a tag to a group.

    Args:
        group_id: The group ID to add the tag to
        data: Tag creation data
        skip_duplicates: If True, return success even if tag already exists

    Returns:
        Tuple of (group_dto, was_created)
        - group_dto: The updated group, or None if group not found
        - was_created: True if tag was created, False if it already existed

    Raises:
        ValueError: If the tag name already exists in the group and
            skip_duplicates is False (backward-compatible strict mode).
    """
    group = await self.group_repo.get_with_tags(group_id)
    if not group:
        return None, False

    # Check for a duplicate tag name within the group before inserting.
    existing = await self.tag_repo.find_by_group_and_name(group_id, data.name)
    if existing:
        if skip_duplicates:
            # Report success but signal that nothing new was created.
            group_dto = await self.get_group_by_id(group_id)
            return group_dto, False
        # Strict mode (backward compatible): duplicates are an error.
        raise ValueError(f"Tag '{data.name}' already exists in this group")

    tag = Tag(name=data.name, description=data.description, group_id=group_id)
    await self.tag_repo.create(tag)

    # Expire the cached group and re-read so the DTO reflects the new tag.
    self.group_repo.session.expire(group)
    updated_group = await self.group_repo.get_with_tags(group_id)

    # Type narrowing: the group must still exist — we just created a tag in it.
    assert updated_group is not None

    return (
        TagGroupDTO(
            id=updated_group.id,
            name=updated_group.name,
            description=updated_group.description,
            tags=[TagDTO(id=t.id, name=t.name, description=t.description) for t in updated_group.tags],
            createdDate=updated_group.created_at,
            lastModifiedDate=updated_group.updated_at,
        ),
        True,
    )

async def update_tag(self, group_id: str, tag_id: str, data: TagUpdate) -> TagGroupDTO | None:
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Pytest configuration and fixtures for squirrel-backend tests.
"""

import asyncio
import logging
from datetime import datetime
Expand Down Expand Up @@ -161,7 +162,7 @@ async def sample_tag(client: AsyncClient, sample_tag_group: dict) -> tuple[dict,
f"/v1/tags/{group_id}/tags", json={"name": "Building-A", "description": "Building A location"}
)
assert response.status_code == 200
group = response.json()["payload"]
group = response.json()["payload"]["group"]
tag = group["tags"][0]
return group, tag

Expand Down
108 changes: 104 additions & 4 deletions tests/test_api/test_tags.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Tests for Tags API endpoints.
"""

import pytest
from httpx import AsyncClient

Expand Down Expand Up @@ -164,8 +165,8 @@ async def test_add_tag_to_group(self, client: AsyncClient, sample_tag_group: dic
assert response.status_code == 200
data = response.json()
assert data["errorCode"] == 0
assert len(data["payload"]["tags"]) == 1
assert data["payload"]["tags"][0]["name"] == "Building-A"
assert len(data["payload"]["group"]["tags"]) == 1
assert data["payload"]["group"]["tags"][0]["name"] == "Building-A"

@pytest.mark.asyncio
async def test_add_multiple_tags_to_group(self, client: AsyncClient, sample_tag_group: dict):
Expand All @@ -180,11 +181,11 @@ async def test_add_multiple_tags_to_group(self, client: AsyncClient, sample_tag_

assert response.status_code == 200
data = response.json()
assert len(data["payload"]["tags"]) == 2
assert len(data["payload"]["group"]["tags"]) == 2

@pytest.mark.asyncio
async def test_add_duplicate_tag_fails(self, client: AsyncClient, sample_tag: tuple):
"""Test that duplicate tag names within a group are rejected."""
"""Test that duplicate tag names within a group are rejected by default."""
group, tag = sample_tag
group_id = group["id"]

Expand All @@ -194,6 +195,23 @@ async def test_add_duplicate_tag_fails(self, client: AsyncClient, sample_tag: tu
data = response.json()
assert "already exists" in data["errorMessage"]

@pytest.mark.asyncio
async def test_add_duplicate_tag_with_skip_duplicates_succeeds(self, client: AsyncClient, sample_tag: tuple):
    """Posting an existing tag name with skip_duplicates=true must succeed."""
    group, tag = sample_tag
    group_id = group["id"]

    resp = await client.post(
        f"/v1/tags/{group_id}/tags?skip_duplicates=true",
        json={"name": tag["name"]},  # same name as the fixture tag
    )
    assert resp.status_code == 200

    body = resp.json()
    assert body["errorCode"] == 0
    # The tag already existed, so the API reports wasCreated=False.
    assert body["payload"]["wasCreated"] is False
    assert body["payload"]["group"]["id"] == group_id

@pytest.mark.asyncio
async def test_add_tag_to_nonexistent_group(self, client: AsyncClient):
"""Test adding tag to non-existent group."""
Expand Down Expand Up @@ -249,3 +267,85 @@ async def test_remove_tag_not_found(self, client: AsyncClient, sample_tag_group:
response = await client.delete(f"/v1/tags/{group_id}/tags/nonexistent-tag")

assert response.status_code == 404


class TestBulkTagImport:
    """Tests for the bulk tag import endpoint (`POST /v1/tags/bulk`)."""

    @pytest.mark.asyncio
    async def test_bulk_import_tags_creates_new_groups_and_tags(self, client: AsyncClient):
        """A bulk request containing only new names creates every group and tag."""
        payload = {
            "groups": {
                "Location": ["Building-A", "Building-B", "Building-C"],
                "System": ["System-1", "System-2"],
            }
        }
        response = await client.post("/v1/tags/bulk", json=payload)

        assert response.status_code == 200
        body = response.json()
        assert body["errorCode"] == 0
        result = body["payload"]
        assert result["groupsCreated"] == 2
        assert result["tagsCreated"] == 5
        assert result["tagsSkipped"] == 0
        assert not result["warnings"]

    @pytest.mark.asyncio
    async def test_bulk_import_tags_skips_duplicates(self, client: AsyncClient, sample_tag: tuple):
        """Duplicate tags are skipped and surfaced via counters and warnings."""
        group, tag = sample_tag

        response = await client.post(
            "/v1/tags/bulk",
            json={"groups": {group["name"]: [tag["name"], "New-Tag-1", "New-Tag-2"]}},
        )

        assert response.status_code == 200
        body = response.json()
        assert body["errorCode"] == 0
        result = body["payload"]
        # The fixture group already exists, so nothing new at the group level.
        assert result["groupsCreated"] == 0
        # Only the two fresh tags are created; the duplicate is skipped.
        assert result["tagsCreated"] == 2
        assert result["tagsSkipped"] == 1
        assert len(result["warnings"]) == 1
        assert "already exists" in result["warnings"][0]

    @pytest.mark.asyncio
    async def test_bulk_import_tags_with_existing_group(self, client: AsyncClient, sample_tag_group: dict):
        """Tags can be bulk-added into a group that already exists."""
        response = await client.post(
            "/v1/tags/bulk",
            json={"groups": {sample_tag_group["name"]: ["New-Tag-A", "New-Tag-B"]}},
        )

        assert response.status_code == 200
        body = response.json()
        assert body["errorCode"] == 0
        result = body["payload"]
        assert result["groupsCreated"] == 0  # group pre-existed
        assert result["tagsCreated"] == 2
        assert result["tagsSkipped"] == 0

    @pytest.mark.asyncio
    async def test_bulk_import_tags_empty_groups(self, client: AsyncClient):
        """An empty groups mapping is a valid no-op request."""
        response = await client.post("/v1/tags/bulk", json={"groups": {}})

        assert response.status_code == 200
        body = response.json()
        assert body["errorCode"] == 0
        for counter in ("groupsCreated", "tagsCreated", "tagsSkipped"):
            assert body["payload"][counter] == 0
Loading