Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,21 @@ jobs:
run: make test_sqlite_regexp
env:
PYTHONDEVMODE: 1
- name: Test FastAPI/Blacksheep/Sanic Examples
- name: Test FastAPI/Blacksheep/Sanic/Starlette Examples
run: |
PYTHONPATH=$DEST_FASTAPI uv run --frozen tortoise -c config.TORTOISE_ORM migrate
PYTHONPATH=$DEST_FASTAPI uv run --frozen pytest $PYTEST_ARGS_SEQ $DEST_FASTAPI/_tests.py
rm -f $DEST_FASTAPI/db.sqlite3
PYTHONPATH=$DEST_BLACKSHEEP uv run --frozen pytest $PYTEST_ARGS $DEST_BLACKSHEEP/_tests.py
PYTHONPATH=$DEST_SANIC uv run --frozen pytest $PYTEST_ARGS $DEST_SANIC/_tests.py
PYTHONPATH=$DEST_STARLETTE uv run --frozen pytest $PYTEST_ARGS $DEST_STARLETTE/_tests.py
uv pip install "starlette<1.0"
PYTHONPATH=$DEST_STARLETTE uv run --no-sync pytest $PYTEST_ARGS $DEST_STARLETTE/_tests.py
env:
DEST_FASTAPI: examples/fastapi
DEST_BLACKSHEEP: examples/blacksheep
DEST_SANIC: examples/sanic
DEST_STARLETTE: examples/starlette
PYTHONDEVMODE: 1
PYTEST_ARGS: "-n auto --cov=tortoise --cov-append --cov-branch --tb=native -q"
PYTEST_ARGS_SEQ: "--cov=tortoise --cov-append --cov-branch --tb=native -q"
Expand Down
38 changes: 38 additions & 0 deletions examples/starlette/_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from pathlib import Path

import pytest
from asgi_lifespan import LifespanManager
from httpx import ASGITransport, AsyncClient

try:
from main import app
from models import Users
except ImportError:
if (cwd := Path.cwd()) == (parent := Path(__file__).parent):
dirpath = "."
else:
dirpath = str(parent.relative_to(cwd))
print(f"You may need to explicitly declare python path:\n\nexport PYTHONPATH={dirpath}\n")
raise


@pytest.mark.anyio
async def test_app() -> None:
async with LifespanManager(app):
transport = ASGITransport(app=app)
# note: you _must_ set `base_url` for relative urls like "/" to work
async with AsyncClient(transport=transport, base_url="http://testserver") as client:
r = await client.get("/")
assert r.status_code == 200
assert r.json() == {"users": []}
(await Users.all()) == []

r = await client.post("/user/", json={"username": "Iron"})
assert r.status_code == 201
assert r.json() == {"user": "Users(id=1, username='Iron')"}
await Users.get(id=1) == await Users.last()

r = await client.get("/")
assert r.status_code == 200
assert r.json() == {"users": ["User 1: Iron"]}
(await Users.all()) == [await Users.first()]
24 changes: 17 additions & 7 deletions examples/starlette/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Mount, Route
from starlette.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST
from uvicorn.main import run

from tortoise.contrib.starlette import register_tortoise

Expand All @@ -17,29 +17,39 @@
app = Starlette()


@app.route("/", methods=["GET"])
async def list_all(_: Request) -> JSONResponse:
users = await Users.all()
return JSONResponse({"users": [str(user) for user in users]})


@app.route("/user", methods=["POST"])
async def add_user(request: Request) -> JSONResponse:
try:
payload = await request.json()
username = payload["username"]
except JSONDecodeError:
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="cannot parse request body")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST, detail="cannot parse request body"
) from None
except KeyError:
raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="username is required")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST, detail="username is required"
) from None

user = await Users.create(username=username)
return JSONResponse({"user": str(user)}, status_code=HTTP_201_CREATED)
return JSONResponse({"user": repr(user)}, status_code=HTTP_201_CREATED)


app = Starlette(
routes=[
Route("/", list_all),
Mount("/user", routes=[Route("/", add_user, methods=["POST"])]),
]
)
register_tortoise(
app, db_url="sqlite://:memory:", modules={"models": ["models"]}, generate_schemas=True
)

if __name__ == "__main__":
run(app)
import uvicorn

uvicorn.run("__main__:app", reload=True)
5 changes: 5 additions & 0 deletions examples/starlette/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@ class Users(models.Model):

def __str__(self) -> str:
return f"User {self.id}: {self.username}"

def __repr__(self) -> str:
fields = sorted(self._meta.db_fields) # ['id', 'username']
values = ", ".join(f"{f}={getattr(self, f)!r}" for f in fields)
return f"{self.__class__.__name__}({values})"
62 changes: 58 additions & 4 deletions tortoise/contrib/starlette/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
from __future__ import annotations

from collections.abc import Iterable
import functools
from collections.abc import Awaitable, Callable, Iterable
from contextlib import asynccontextmanager
from types import ModuleType
from typing import TypeVar

import starlette
from starlette.applications import Starlette # pylint: disable=E0401
from starlette.routing import Host, Mount, Route, is_async_callable
from starlette.routing import _DefaultLifespan as StarletteDefaultLifespan

from tortoise import Tortoise
from tortoise.config import TortoiseConfig
from tortoise.connection import get_connections
from tortoise.context import TortoiseContext
from tortoise.log import logger


Expand Down Expand Up @@ -79,16 +87,62 @@ def register_tortoise(
ConfigurationError
For any configuration error
"""
typed_config = TortoiseConfig.resolve_args(config, config_file, db_url, modules)

@app.on_event("startup")
async def init_orm() -> None: # pylint: disable=W0612
await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules)
await Tortoise.init(config=typed_config, _enable_global_fallback=True)
logger.info("Tortoise-ORM started, %s, %s", get_connections()._get_storage(), Tortoise.apps)
if generate_schemas:
logger.info("Tortoise-ORM generating schema")
await Tortoise.generate_schemas()

@app.on_event("shutdown")
async def close_orm() -> None: # pylint: disable=W0612
await Tortoise.close_connections()
logger.info("Tortoise-ORM shutdown")

if starlette.__version__ < "1":
if (on_event := getattr(app, "on_event", None)) is not None:
on_event("startup")(init_orm)
on_event("shutdown")(close_orm)
else:
original_lifespan = app.router.lifespan_context

if generate_schemas or not isinstance(original_lifespan, StarletteDefaultLifespan):

@asynccontextmanager
async def orm_inited_lifespan(app_):
await init_orm()
try:
async with original_lifespan(app_) as maybe_state:
yield maybe_state
finally:
await close_orm()

app.router.lifespan_context = orm_inited_lifespan

if app.router.routes:
T = TypeVar("T")

def db_session(func: Callable[..., Awaitable[T]]):
@functools.wraps(func)
async def runner(*args, **kw) -> T:
async with TortoiseContext() as ctx:
await ctx.init(typed_config)
return await func(*args, **kw)

return runner

key = "_patch_tortoise"

def patch_endpoints(routes: list) -> None:
for r in routes:
if isinstance(r, Route):
if getattr(r, key, False):
continue
setattr(r, key, True)
if is_async_callable(endpoint := r.endpoint):
r.endpoint = db_session(endpoint)
elif isinstance(r, Mount | Host) or getattr(r, "routes", []):
patch_endpoints(r.routes)

patch_endpoints(app.router.routes)
2 changes: 1 addition & 1 deletion tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def __iter__(self) -> Iterable[tuple]:
yield field, getattr(self, field)

def __eq__(self, other: object) -> bool:
return type(other) is type(self) and self.pk == other.pk # type: ignore
return type(other) is type(self) and self.pk == other.pk

def _get_pk_val(self) -> Any:
return getattr(self, self._meta.pk_attr, None)
Expand Down
Loading
Loading