From 4e2ba51e9c31186415121976877c3f3b3373bb5a Mon Sep 17 00:00:00 2001 From: noxymon Date: Thu, 5 Mar 2026 23:16:44 +0700 Subject: [PATCH 1/3] feat: infer parameter types from function call return types in comparison expressions. --- internal/compiler/resolve.go | 30 ++++++++++++ .../testdata/mysql_datediff_type/mysql/db.go | 31 ++++++++++++ .../mysql_datediff_type/mysql/models.go | 16 +++++++ .../mysql_datediff_type/mysql/query.sql.go | 48 +++++++++++++++++++ .../testdata/mysql_datediff_type/query.sql | 6 +++ .../testdata/mysql_datediff_type/schema.sql | 5 ++ .../testdata/mysql_datediff_type/sqlc.json | 16 +++++++ 7 files changed, 152 insertions(+) create mode 100644 internal/endtoend/testdata/mysql_datediff_type/mysql/db.go create mode 100644 internal/endtoend/testdata/mysql_datediff_type/mysql/models.go create mode 100644 internal/endtoend/testdata/mysql_datediff_type/mysql/query.sql.go create mode 100644 internal/endtoend/testdata/mysql_datediff_type/query.sql create mode 100644 internal/endtoend/testdata/mysql_datediff_type/schema.sql create mode 100644 internal/endtoend/testdata/mysql_datediff_type/sqlc.json diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index b1fbb1990e..efd60e2e13 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -142,6 +142,36 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, }) case *ast.A_Expr: + // If one side of the comparison is a direct FuncCall, use the + // function's return type for the parameter. This prevents the + // ColumnRef search below from descending into the function's + // arguments and incorrectly using a nested column's type + // (e.g. DATEDIFF(date_from, NOW()) >= ? should yield int, not date). + var funcCallSide *ast.FuncCall + if fc, ok := n.Lexpr.(*ast.FuncCall); ok { + funcCallSide = fc + } else if fc, ok := n.Rexpr.(*ast.FuncCall); ok { + funcCallSide = fc + } + if funcCallSide != nil { + fun, ferr := c.ResolveFuncCall(funcCallSide) + if ferr == nil && fun.ReturnType != nil && fun.ReturnType.Name != "any" { + defaultP := named.NewInferredParam(ref.name, true) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) + a = append(a, Parameter{ + Number: ref.ref.Number, + Column: &Column{ + Name: p.Name(), + DataType: dataType(fun.ReturnType), + NotNull: p.NotNull(), + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + }, + }) + continue + } + } + // TODO: While this works for a wide range of simple expressions, // more complicated expressions will cause this logic to fail. list := astutils.Search(n.Lexpr, func(node ast.Node) bool { diff --git a/internal/endtoend/testdata/mysql_datediff_type/mysql/db.go b/internal/endtoend/testdata/mysql_datediff_type/mysql/db.go new file mode 100644 index 0000000000..f490dfa564 --- /dev/null +++ b/internal/endtoend/testdata/mysql_datediff_type/mysql/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package mysql_datediff_type + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/mysql_datediff_type/mysql/models.go b/internal/endtoend/testdata/mysql_datediff_type/mysql/models.go new file mode 100644 index 0000000000..95c9216e49 --- /dev/null +++ b/internal/endtoend/testdata/mysql_datediff_type/mysql/models.go @@ -0,0 +1,16 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package mysql_datediff_type + +import ( + "database/sql" + "time" +) + +type WishlistItem struct { + ID uint32 + DateFrom time.Time + UpdatedAt sql.NullTime +} diff --git a/internal/endtoend/testdata/mysql_datediff_type/mysql/query.sql.go b/internal/endtoend/testdata/mysql_datediff_type/mysql/query.sql.go new file mode 100644 index 0000000000..986a2f45af --- /dev/null +++ b/internal/endtoend/testdata/mysql_datediff_type/mysql/query.sql.go @@ -0,0 +1,48 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: query.sql + +package mysql_datediff_type + +import ( + "context" + "database/sql" +) + +const getUpdateableWishlistItemIDs = `-- name: GetUpdateableWishlistItemIDs :many +SELECT id +FROM wishlist_item +WHERE DATEDIFF(date_from, NOW()) >= ? + AND DATEDIFF(date_from, NOW()) <= ? + AND updated_at < ? +` + +type GetUpdateableWishlistItemIDsParams struct { + MinDaysToDateFrom int32 + MaxDaysToDateFrom int32 + UpdatedBy sql.NullTime +} + +func (q *Queries) GetUpdateableWishlistItemIDs(ctx context.Context, arg GetUpdateableWishlistItemIDsParams) ([]uint32, error) { + rows, err := q.db.QueryContext(ctx, getUpdateableWishlistItemIDs, arg.MinDaysToDateFrom, arg.MaxDaysToDateFrom, arg.UpdatedBy) + if err != nil { + return nil, err + } + defer rows.Close() + var items []uint32 + for rows.Next() { + var id uint32 + if err := rows.Scan(&id); err != nil { + return nil, err + } + items = append(items, id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/mysql_datediff_type/query.sql b/internal/endtoend/testdata/mysql_datediff_type/query.sql new file mode 100644 index 0000000000..bed8df1ec6 --- /dev/null +++ b/internal/endtoend/testdata/mysql_datediff_type/query.sql @@ -0,0 +1,6 @@ +-- name: GetUpdateableWishlistItemIDs :many +SELECT id +FROM wishlist_item +WHERE DATEDIFF(date_from, NOW()) >= sqlc.arg('min_days_to_date_from') + AND DATEDIFF(date_from, NOW()) <= sqlc.arg('max_days_to_date_from') + AND updated_at < sqlc.arg('updated_by'); diff --git a/internal/endtoend/testdata/mysql_datediff_type/schema.sql b/internal/endtoend/testdata/mysql_datediff_type/schema.sql new file mode 100644 index 0000000000..4689e0992b --- /dev/null +++ b/internal/endtoend/testdata/mysql_datediff_type/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE wishlist_item ( + id INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, + date_from DATE NOT NULL, + updated_at TIMESTAMP NULL DEFAULT NULL +); diff --git a/internal/endtoend/testdata/mysql_datediff_type/sqlc.json b/internal/endtoend/testdata/mysql_datediff_type/sqlc.json new file mode 100644 index 0000000000..89aec053a3 --- /dev/null +++ b/internal/endtoend/testdata/mysql_datediff_type/sqlc.json @@ -0,0 +1,16 @@ +{ + "version": "2", + "sql": [ + { + "engine": "mysql", + "queries": "query.sql", + "schema": "schema.sql", + "gen": { + "go": { + "package": "mysql_datediff_type", + "out": "mysql" + } + } + } + ] +} From b634f95712833367a4c59d9402ce8ee47ac6c4a7 Mon Sep 17 00:00:00 2001 From: noxymon Date: Fri, 6 Mar 2026 08:05:26 +0700 Subject: [PATCH 2/3] fix: improve function return type resolution to correctly handle 'any' prefixes and 'record' types. --- internal/compiler/resolve.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index efd60e2e13..0ab3403786 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -4,6 +4,7 @@ import ( "fmt" "log/slog" "strconv" + "strings" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/astutils" @@ -155,7 +156,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } if funcCallSide != nil { fun, ferr := c.ResolveFuncCall(funcCallSide) - if ferr == nil && fun.ReturnType != nil && fun.ReturnType.Name != "any" { + if ferr == nil && fun.ReturnType != nil && !strings.HasPrefix(fun.ReturnType.Name, "any") && fun.ReturnType.Name != "record" { defaultP := named.NewInferredParam(ref.name, true) p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ From c5d3dfdecd4e5f6efe6bee2cf51c42ce35fbcc62 Mon Sep 17 00:00:00 2001 From: noxymon Date: Fri, 6 Mar 2026 09:10:30 +0700 Subject: [PATCH 3/3] ad guardrails when infer parameter names from column references when `ref.name` is empty. --- internal/compiler/resolve.go | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 0ab3403786..43f51f22ff 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -157,7 +157,22 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if funcCallSide != nil { fun, ferr := c.ResolveFuncCall(funcCallSide) if ferr == nil && fun.ReturnType != nil && !strings.HasPrefix(fun.ReturnType.Name, "any") && fun.ReturnType.Name != "record" { - defaultP := named.NewInferredParam(ref.name, true) + paramName := ref.name + if paramName == "" { + fcList := astutils.Search(funcCallSide, func(node ast.Node) bool { + _, ok := node.(*ast.ColumnRef) + return ok + }) + if len(fcList.Items) > 0 { + if cr, ok := fcList.Items[0].(*ast.ColumnRef); ok { + items := stringSlice(cr.Fields) + if len(items) > 0 { + paramName = items[len(items)-1] + } + } + } + } + defaultP := named.NewInferredParam(paramName, true) p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number,