Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ version = { source = "file", path = "tortoise/__init__.py" }
excludes = ["./**/.git", "./**/.*_cache", "examples"]
include = ["CHANGELOG.rst", "LICENSE", "README.rst"]

[tool.uv.sources]
pypika-tortoise = { git = "https://github.com/seladb/pypika-tortoise", branch = "add-functions" }

[tool.mypy]
pretty = true
exclude = ["docs"]
Expand Down
54 changes: 52 additions & 2 deletions tests/contrib/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import pytest
import pytest_asyncio

from tests.testmodels import IntFields
from tests.testmodels import IntFields, Tournament
from tortoise.contrib import test
from tortoise.contrib.mysql.functions import LPad as MySqlLPad
from tortoise.contrib.mysql.functions import Rand
from tortoise.contrib.postgres.functions import Random as PostgresRandom
from tortoise.contrib.mysql.functions import RPad as MySqlRPad
from tortoise.contrib.postgres.functions import (
LPad as PostgresLPad,
)
from tortoise.contrib.postgres.functions import (
Random as PostgresRandom,
)
from tortoise.contrib.postgres.functions import (
RPad as PostgresRPad,
)
from tortoise.contrib.sqlite.functions import Random as SqliteRandom


Expand Down Expand Up @@ -43,3 +53,43 @@ async def test_sqlite_func_rand(db, intfields):
sql = IntFields.all().annotate(randnum=SqliteRandom()).values("intnum", "randnum").sql()
expected_sql = 'SELECT "intnum" "intnum",RANDOM() "randnum" FROM "intfields"'
assert sql == expected_sql


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_postgres_func_lpad(db):
await Tournament.create(name="hello")
await Tournament.create(name="my world")
tournaments = await Tournament.annotate(pad_name=PostgresLPad("name", 12, "x"))
result = set(tournament.pad_name for tournament in tournaments)
assert result == {"xxxxmy world", "xxxxxxxhello"}


@test.requireCapability(dialect="mysql")
@pytest.mark.asyncio
async def test_mysql_func_lpad(db):
await Tournament.create(name="hello")
await Tournament.create(name="my world")
tournaments = await Tournament.annotate(pad_name=MySqlLPad("name", 12, "x"))
result = set(tournament.pad_name for tournament in tournaments)
assert result == {"xxxxmy world", "xxxxxxxhello"}


@test.requireCapability(dialect="postgres")
@pytest.mark.asyncio
async def test_postgres_func_rpad(db):
await Tournament.create(name="hello")
await Tournament.create(name="my world")
tournaments = await Tournament.annotate(pad_name=PostgresRPad("name", 12, "x"))
result = set(tournament.pad_name for tournament in tournaments)
assert result == {"my worldxxxx", "helloxxxxxxx"}


@test.requireCapability(dialect="mysql")
@pytest.mark.asyncio
async def test_mysql_func_rpad(db):
await Tournament.create(name="hello")
await Tournament.create(name="my world")
tournaments = await Tournament.annotate(pad_name=MySqlRPad("name", 12, "x"))
result = set(tournament.pad_name for tournament in tournaments)
assert result == {"my worldxxxx", "helloxxxxxxx"}
75 changes: 74 additions & 1 deletion tests/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,18 @@
from tortoise.contrib import test
from tortoise.contrib.test.condition import In, NotEQ
from tortoise.expressions import Case, F, Q, When
from tortoise.functions import Coalesce, Count, Length, Lower, Max, Trim, Upper
from tortoise.functions import (
Coalesce,
Count,
Length,
Lower,
LTrim,
Max,
Replace,
RTrim,
Trim,
Upper,
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -379,6 +390,68 @@ async def test_filter_by_aggregation_field_trim(db):
assert {(t.name, t.trimmed_name) for t in tournaments} == {(" 1 ", "1")}


@pytest.mark.asyncio
@pytest.mark.parametrize(
["name", "trim_chars", "trimmed_name"],
[
("xxxhellox", "x", "hello"),
("ababhelloab", "ab", "hello"),
],
)
async def test_filter_by_trim_with_chars(db, name, trim_chars, trimmed_name):
await Tournament.create(name=name)
tournaments = await Tournament.annotate(trimmed_name=Trim("name", trim_chars)).filter(
trimmed_name=trimmed_name
)

assert len(tournaments) == 1
assert {(t.name, t.trimmed_name) for t in tournaments} == {(name, trimmed_name)}


@pytest.mark.asyncio
async def test_filter_by_ltrim(db):
await Tournament.create(name=" hello ")
tournaments = await Tournament.annotate(trimmed_name=LTrim("name")).filter(
trimmed_name="hello "
)

assert len(tournaments) == 1
assert {(t.name, t.trimmed_name) for t in tournaments} == {(" hello ", "hello ")}


@pytest.mark.asyncio
async def test_filter_by_rtrim(db):
await Tournament.create(name=" hello ")
tournaments = await Tournament.annotate(trimmed_name=RTrim("name")).filter(
trimmed_name=" hello"
)

assert len(tournaments) == 1
assert {(t.name, t.trimmed_name) for t in tournaments} == {(" hello ", " hello")}


@pytest.mark.asyncio
async def test_replace(db):
await Tournament.create(name="Tournament A")
await Tournament.create(name="Tournament B")
tournaments = await Tournament.annotate(replaced_name=Replace("name", "Tournament", "Contest"))
result = {t.replaced_name for t in tournaments}
assert result == {"Contest A", "Contest B"}


@pytest.mark.asyncio
async def test_filter_by_replace(db):
await Tournament.create(name="1st Tournament")
await Tournament.create(name="2nd Tournament")
await Tournament.create(name="3rd Place")

tournaments = await Tournament.annotate(
replaced_name=Replace("name", "Tournament", "Contest")
).filter(replaced_name="1st Contest")
assert len(tournaments) == 1
assert {(t.name, t.replaced_name) for t in tournaments} == {("1st Tournament", "1st Contest")}


@test.requireCapability(dialect=NotEQ("mssql"))
@pytest.mark.asyncio
async def test_filter_by_aggregation_field_length(db):
Expand Down
42 changes: 41 additions & 1 deletion tortoise/contrib/mysql/functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import annotations

from pypika_tortoise.terms import Function
from pypika_tortoise import functions
from pypika_tortoise.terms import Function, Term

from tortoise.expressions import CombinedExpression, F
from tortoise.functions import Function as TortoiseFunction


class Rand(Function):
Expand All @@ -13,3 +17,39 @@ class Rand(Function):
def __init__(self, seed: int | None = None, alias=None) -> None:
super().__init__("RAND", seed, alias=alias)
self.args = [self.wrap_constant(seed)] if seed is not None else []


class LPad(TortoiseFunction):
"""
Pads the left side of a string with a specified character to reach a certain length.

:samp:`LPad("{FIELD_NAME}", length, fill_text)`
"""

def __init__(
self,
field: str | F | CombinedExpression | TortoiseFunction | Term,
length: int,
fill_text: str = " ",
) -> None:
super().__init__(field, length, fill_text)

database_func = functions.LPad


class RPad(TortoiseFunction):
"""
Pads the right side of a string with a specified character to reach a certain length.

:samp:`RPad("{FIELD_NAME}", length, fill_text)`
"""

def __init__(
self,
field: str | F | CombinedExpression | TortoiseFunction | Term,
length: int,
fill_text: str = " ",
) -> None:
super().__init__(field, length, fill_text)

database_func = functions.RPad
40 changes: 40 additions & 0 deletions tortoise/contrib/postgres/functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from pypika_tortoise import functions
from pypika_tortoise.terms import Function, Term

from tortoise.expressions import CombinedExpression, F
from tortoise.functions import Function as TortoiseFunction


class ToTsVector(Function):
"""
Expand Down Expand Up @@ -37,3 +41,39 @@ class Random(Function):

def __init__(self, alias=None) -> None:
super().__init__("RANDOM", alias=alias)


class LPad(TortoiseFunction):
"""
Pads the left side of a string with a specified character to reach a certain length.

:samp:`LPad("{FIELD_NAME}", length, fill_text)`
"""

def __init__(
self,
field: str | F | CombinedExpression | TortoiseFunction | Term,
length: int,
fill_text: str = " ",
) -> None:
super().__init__(field, length, fill_text)

database_func = functions.LPad


class RPad(TortoiseFunction):
"""
Pads the right side of a string with a specified character to reach a certain length.

:samp:`RPad("{FIELD_NAME}", length, fill_text)`
"""

def __init__(
self,
field: str | F | CombinedExpression | TortoiseFunction | Term,
length: int,
fill_text: str = " ",
) -> None:
super().__init__(field, length, fill_text)

database_func = functions.RPad
54 changes: 53 additions & 1 deletion tortoise/functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Any

from pypika_tortoise import SqlContext, functions
from pypika_tortoise.terms import Term

from tortoise.expressions import Aggregate, Function
from tortoise.expressions import Aggregate, CombinedExpression, F, Function

##############################################################################
# Standard functions
Expand All @@ -16,6 +19,55 @@ class Trim(Function):

database_func = functions.Trim

def __init__(
self,
field: str | F | CombinedExpression | Function | Term,
trim_chars: str = " ",
*default_values: Any,
) -> None:
super().__init__(field, trim_chars, *default_values)

database_func = functions.Trim


class LTrim(Function):
"""
Trims whitespace from the left side of text.

:samp:`LTrim("{FIELD_NAME}")`
"""

database_func = functions.LTrim


class RTrim(Function):
"""
Trims whitespace from the right side of text.

:samp:`RTrim("{FIELD_NAME}")`
"""

database_func = functions.RTrim


class Replace(Function):
"""
Replaces all occurrences of a search string with a replacement string.

:samp:`Replace("{FIELD_NAME}", "search", "replacement")`
"""

def __init__(
self,
field: str | F | CombinedExpression | Function | Term,
search: str,
replacement: str,
*default_values: Any,
) -> None:
super().__init__(field, search, replacement, *default_values)

database_func = functions.Replace


class Length(Function):
"""
Expand Down
8 changes: 2 additions & 6 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading