diff --git a/pyproject.toml b/pyproject.toml index ebfc112567..56d66ecff5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "requests", "rich[jupyter]", "ruamel.yaml", - "sqlglot[rs]~=28.10.1", + "sqlglot~=30.0.1", "tenacity", "time-machine", "json-stream" diff --git a/sqlmesh/core/_typing.py b/sqlmesh/core/_typing.py index 8e28312c1a..2bc69e901b 100644 --- a/sqlmesh/core/_typing.py +++ b/sqlmesh/core/_typing.py @@ -8,8 +8,8 @@ if t.TYPE_CHECKING: TableName = t.Union[str, exp.Table] SchemaName = t.Union[str, exp.Table] - SessionProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]] - CustomMaterializationProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]] + SessionProperties = t.Dict[str, t.Union[exp.Expr, str, int, float, bool]] + CustomMaterializationProperties = t.Dict[str, t.Union[exp.Expr, str, int, float, bool]] if sys.version_info >= (3, 11): diff --git a/sqlmesh/core/audit/definition.py b/sqlmesh/core/audit/definition.py index 9f470872fe..4c90151ee4 100644 --- a/sqlmesh/core/audit/definition.py +++ b/sqlmesh/core/audit/definition.py @@ -67,7 +67,7 @@ class AuditMixin(AuditCommonMetaMixin): """ query_: ParsableSql - defaults: t.Dict[str, exp.Expression] + defaults: t.Dict[str, exp.Expr] expressions_: t.Optional[t.List[ParsableSql]] jinja_macros: JinjaMacroRegistry formatting: t.Optional[bool] @@ -77,10 +77,10 @@ def query(self) -> t.Union[exp.Query, d.JinjaQuery]: return t.cast(t.Union[exp.Query, d.JinjaQuery], self.query_.parse(self.dialect)) @property - def expressions(self) -> t.List[exp.Expression]: + def expressions(self) -> t.List[exp.Expr]: if not self.expressions_: return [] - result = [] + result: t.List[exp.Expr] = [] for e in self.expressions_: parsed = e.parse(self.dialect) if not isinstance(parsed, exp.Semicolon): @@ -95,7 +95,7 @@ def macro_definitions(self) -> t.List[d.MacroDef]: @field_validator("name", "dialect", mode="before", check_fields=False) def 
audit_string_validator(cls: t.Type, v: t.Any) -> t.Optional[str]: - if isinstance(v, exp.Expression): + if isinstance(v, exp.Expr): return v.name.lower() return str(v).lower() if v is not None else None @@ -111,9 +111,7 @@ def audit_map_validator(cls: t.Type, v: t.Any, values: t.Any) -> t.Dict[str, t.A if isinstance(v, dict): dialect = get_dialect(values) return { - key: value - if isinstance(value, exp.Expression) - else d.parse_one(str(value), dialect=dialect) + key: value if isinstance(value, exp.Expr) else d.parse_one(str(value), dialect=dialect) for key, value in v.items() } raise_config_error("Defaults must be a tuple of exp.EQ or a dict", error_type=AuditConfigError) @@ -133,7 +131,7 @@ class ModelAudit(PydanticModel, AuditMixin, DbtInfoMixin, frozen=True): blocking: bool = True standalone: t.Literal[False] = False query_: ParsableSql = Field(alias="query") - defaults: t.Dict[str, exp.Expression] = {} + defaults: t.Dict[str, exp.Expr] = {} expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions") jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry() formatting: t.Optional[bool] = Field(default=None, exclude=True) @@ -169,7 +167,7 @@ class StandaloneAudit(_Node, AuditMixin): blocking: bool = False standalone: t.Literal[True] = True query_: ParsableSql = Field(alias="query") - defaults: t.Dict[str, exp.Expression] = {} + defaults: t.Dict[str, exp.Expr] = {} expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions") jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry() default_catalog: t.Optional[str] = None @@ -323,13 +321,13 @@ def render_definition( include_python: bool = True, include_defaults: bool = False, render_query: bool = False, - ) -> t.List[exp.Expression]: + ) -> t.List[exp.Expr]: """Returns the original list of sql expressions comprising the model definition. Args: include_python: Whether or not to include Python code in the rendered definition. 
""" - expressions: t.List[exp.Expression] = [] + expressions: t.List[exp.Expr] = [] comment = None for field_name in sorted(self.meta_fields): field_value = getattr(self, field_name) @@ -381,7 +379,7 @@ def meta_fields(self) -> t.Iterable[str]: return set(AuditCommonMetaMixin.__annotations__) | set(_Node.all_field_infos()) @property - def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expression]]]: + def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expr]]]: return [(self, {})] @@ -389,7 +387,7 @@ def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expression]] def load_audit( - expressions: t.List[exp.Expression], + expressions: t.List[exp.Expr], *, path: Path = Path(), module_path: Path = Path(), @@ -499,7 +497,7 @@ def load_audit( def load_multiple_audits( - expressions: t.List[exp.Expression], + expressions: t.List[exp.Expr], *, path: Path = Path(), module_path: Path = Path(), @@ -510,7 +508,7 @@ def load_multiple_audits( variables: t.Optional[t.Dict[str, t.Any]] = None, project: t.Optional[str] = None, ) -> t.Generator[Audit, None, None]: - audit_block: t.List[exp.Expression] = [] + audit_block: t.List[exp.Expr] = [] for expression in expressions: if isinstance(expression, d.Audit): if audit_block: @@ -543,7 +541,7 @@ def _raise_config_error(msg: str, path: pathlib.Path) -> None: # mypy doesn't realize raise_config_error raises an exception @t.no_type_check -def _maybe_parse_arg_pair(e: exp.Expression) -> t.Tuple[str, exp.Expression]: +def _maybe_parse_arg_pair(e: exp.Expr) -> t.Tuple[str, exp.Expr]: if isinstance(e, exp.EQ): return e.left.name, e.right diff --git a/sqlmesh/core/config/linter.py b/sqlmesh/core/config/linter.py index c2a40e09aa..11d700c540 100644 --- a/sqlmesh/core/config/linter.py +++ b/sqlmesh/core/config/linter.py @@ -34,7 +34,7 @@ def _validate_rules(cls, v: t.Any) -> t.Set[str]: v = v.unnest().name elif isinstance(v, (exp.Tuple, exp.Array)): v = [e.name for e in v.expressions] - elif 
isinstance(v, exp.Expression): + elif isinstance(v, exp.Expr): v = v.name return {name.lower() for name in ensure_collection(v)} diff --git a/sqlmesh/core/config/model.py b/sqlmesh/core/config/model.py index aeefdf2557..ac41d75fe3 100644 --- a/sqlmesh/core/config/model.py +++ b/sqlmesh/core/config/model.py @@ -71,9 +71,9 @@ class ModelDefaultsConfig(BaseConfig): enabled: t.Optional[t.Union[str, bool]] = None formatting: t.Optional[t.Union[str, bool]] = None batch_concurrency: t.Optional[int] = None - pre_statements: t.Optional[t.List[t.Union[str, exp.Expression]]] = None - post_statements: t.Optional[t.List[t.Union[str, exp.Expression]]] = None - on_virtual_update: t.Optional[t.List[t.Union[str, exp.Expression]]] = None + pre_statements: t.Optional[t.List[t.Union[str, exp.Expr]]] = None + post_statements: t.Optional[t.List[t.Union[str, exp.Expr]]] = None + on_virtual_update: t.Optional[t.List[t.Union[str, exp.Expr]]] = None _model_kind_validator = model_kind_validator _on_destructive_change_validator = on_destructive_change_validator diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 860194278b..dc51aad2a7 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -234,7 +234,7 @@ def resolve_table(self, model_name: str) -> str: ) def fetchdf( - self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False ) -> pd.DataFrame: """Fetches a dataframe given a sql string or sqlglot expression. @@ -248,7 +248,7 @@ def fetchdf( return self.engine_adapter.fetchdf(query, quote_identifiers=quote_identifiers) def fetch_pyspark_df( - self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False ) -> PySparkDataFrame: """Fetches a PySpark dataframe given a sql string or sqlglot expression. 
@@ -1105,7 +1105,7 @@ def render( execution_time: t.Optional[TimeLike] = None, expand: t.Union[bool, t.Iterable[str]] = False, **kwargs: t.Any, - ) -> exp.Expression: + ) -> exp.Expr: """Renders a model's query, expanding macros with provided kwargs, and optionally expanding referenced models. Args: @@ -1860,10 +1860,10 @@ def table_diff( self, source: str, target: str, - on: t.Optional[t.List[str] | exp.Condition] = None, + on: t.Optional[t.List[str] | exp.Expr] = None, skip_columns: t.Optional[t.List[str]] = None, select_models: t.Optional[t.Collection[str]] = None, - where: t.Optional[str | exp.Condition] = None, + where: t.Optional[str | exp.Expr] = None, limit: int = 20, show: bool = True, show_sample: bool = True, @@ -1922,7 +1922,7 @@ def table_diff( raise SQLMeshError(e) models_to_diff: t.List[ - t.Tuple[Model, EngineAdapter, str, str, t.Optional[t.List[str] | exp.Condition]] + t.Tuple[Model, EngineAdapter, str, str, t.Optional[t.List[str] | exp.Expr]] ] = [] models_without_grain: t.List[Model] = [] source_snapshots_to_name = { @@ -2041,9 +2041,9 @@ def _model_diff( target_alias: str, limit: int, decimals: int, - on: t.Optional[t.List[str] | exp.Condition] = None, + on: t.Optional[t.List[str] | exp.Expr] = None, skip_columns: t.Optional[t.List[str]] = None, - where: t.Optional[str | exp.Condition] = None, + where: t.Optional[str | exp.Expr] = None, show: bool = True, temp_schema: t.Optional[str] = None, skip_grain_check: bool = False, @@ -2083,10 +2083,10 @@ def _table_diff( limit: int, decimals: int, adapter: EngineAdapter, - on: t.Optional[t.List[str] | exp.Condition] = None, + on: t.Optional[t.List[str] | exp.Expr] = None, model: t.Optional[Model] = None, skip_columns: t.Optional[t.List[str]] = None, - where: t.Optional[str | exp.Condition] = None, + where: t.Optional[str | exp.Expr] = None, schema_diff_ignore_case: bool = False, ) -> TableDiff: if not on: @@ -2344,7 +2344,7 @@ def audit( return not errors @python_api_analytics - def rewrite(self, sql: 
str, dialect: str = "") -> exp.Expression: + def rewrite(self, sql: str, dialect: str = "") -> exp.Expr: """Rewrite a sql expression with semantic references into an executable query. https://sqlmesh.readthedocs.io/en/latest/concepts/metrics/overview/ diff --git a/sqlmesh/core/context_diff.py b/sqlmesh/core/context_diff.py index 07d13b1c2f..047e58609a 100644 --- a/sqlmesh/core/context_diff.py +++ b/sqlmesh/core/context_diff.py @@ -36,7 +36,7 @@ from sqlmesh.utils.metaprogramming import Executable # noqa from sqlmesh.core.environment import EnvironmentStatements -IGNORED_PACKAGES = {"sqlmesh", "sqlglot"} +IGNORED_PACKAGES = {"sqlmesh", "sqlglot", "sqlglotc"} class ContextDiff(PydanticModel): diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index c0a48326f2..122b287ac0 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import DialectType from sqlglot.dialects import DuckDB, Snowflake, TSQL import sqlglot.dialects.athena as athena +from sqlglot.parsers.athena import AthenaTrinoParser from sqlglot.helper import seq_get from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import quote_identifiers @@ -52,7 +53,7 @@ class Metric(exp.Expression): arg_types = {"expressions": True} -class Jinja(exp.Func): +class Jinja(exp.Expression, exp.Func): arg_types = {"this": True} @@ -76,7 +77,7 @@ class MacroVar(exp.Var): pass -class MacroFunc(exp.Func): +class MacroFunc(exp.Expression, exp.Func): @property def name(self) -> str: return self.this.name @@ -102,7 +103,7 @@ class DColonCast(exp.Cast): pass -class MetricAgg(exp.AggFunc): +class MetricAgg(exp.Expression, exp.AggFunc): """Used for computing metrics.""" arg_types = {"this": True} @@ -118,7 +119,7 @@ class StagedFilePath(exp.Expression): arg_types = exp.Table.arg_types.copy() -def _parse_statement(self: Parser) -> t.Optional[exp.Expression]: +def _parse_statement(self: Parser) -> 
t.Optional[exp.Expr]: if self._curr is None: return None @@ -152,7 +153,7 @@ def _parse_statement(self: Parser) -> t.Optional[exp.Expression]: raise -def _parse_lambda(self: Parser, alias: bool = False) -> t.Optional[exp.Expression]: +def _parse_lambda(self: Parser, alias: bool = False) -> t.Optional[exp.Expr]: node = self.__parse_lambda(alias=alias) # type: ignore if isinstance(node, exp.Lambda): node.set("this", self._parse_alias(node.this)) @@ -163,7 +164,7 @@ def _parse_id_var( self: Parser, any_token: bool = True, tokens: t.Optional[t.Collection[TokenType]] = None, -) -> t.Optional[exp.Expression]: +) -> t.Optional[exp.Expr]: if self._prev and self._prev.text == SQLMESH_MACRO_PREFIX and self._match(TokenType.L_BRACE): identifier = self.__parse_id_var(any_token=any_token, tokens=tokens) # type: ignore if not self._match(TokenType.R_BRACE): @@ -207,12 +208,12 @@ def _parse_id_var( else: self.raise_error("Expecting }") - identifier = self.expression(exp.Identifier, this=this, quoted=identifier.quoted) + identifier = self.expression(exp.Identifier(this=this, quoted=identifier.quoted)) return identifier -def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expression]: +def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expr]: if self._prev.text != SQLMESH_MACRO_PREFIX: return self._parse_parameter() @@ -220,7 +221,7 @@ def _parse_macro(self: Parser, keyword_macro: str = "") -> t.Optional[exp.Expres index = self._index field = self._parse_primary() or self._parse_function(functions={}) or self._parse_id_var() - def _build_macro(field: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + def _build_macro(field: t.Optional[exp.Expr]) -> t.Optional[exp.Expr]: if isinstance(field, exp.Func): macro_name = field.name.upper() if macro_name != keyword_macro and macro_name in KEYWORD_MACROS: @@ -230,37 +231,39 @@ def _build_macro(field: t.Optional[exp.Expression]) -> t.Optional[exp.Expression if isinstance(field, 
exp.Anonymous): if macro_name == "DEF": return self.expression( - MacroDef, - this=field.expressions[0], - expression=field.expressions[1], + MacroDef( + this=field.expressions[0], + expression=field.expressions[1], + ), comments=comments, ) if macro_name == "SQL": into = field.expressions[1].this.lower() if len(field.expressions) > 1 else None return self.expression( - MacroSQL, this=field.expressions[0], into=into, comments=comments + MacroSQL(this=field.expressions[0], into=into), comments=comments ) else: field = self.expression( - exp.Anonymous, - this=field.sql_name(), - expressions=list(field.args.values()), + exp.Anonymous( + this=field.sql_name(), + expressions=list(field.args.values()), + ), comments=comments, ) - return self.expression(MacroFunc, this=field, comments=comments) + return self.expression(MacroFunc(this=field), comments=comments) if field is None: return None if field.is_string or (isinstance(field, exp.Identifier) and field.quoted): return self.expression( - MacroStrReplace, this=exp.Literal.string(field.this), comments=comments + MacroStrReplace(this=exp.Literal.string(field.this)), comments=comments ) if "@" in field.this: - return field - return self.expression(MacroVar, this=field.this, comments=comments) + return field # type: ignore[return-value] + return self.expression(MacroVar(this=field.this), comments=comments) if isinstance(field, (exp.Window, exp.IgnoreNulls, exp.RespectNulls)): field.set("this", _build_macro(field.this)) @@ -273,7 +276,7 @@ def _build_macro(field: t.Optional[exp.Expression]) -> t.Optional[exp.Expression KEYWORD_MACROS = {"WITH", "JOIN", "WHERE", "GROUP_BY", "HAVING", "ORDER_BY", "LIMIT"} -def _parse_matching_macro(self: Parser, name: str) -> t.Optional[exp.Expression]: +def _parse_matching_macro(self: Parser, name: str) -> t.Optional[exp.Expr]: if not self._match_pair(TokenType.PARAMETER, TokenType.VAR, advance=False) or ( self._next and self._next.text.upper() != name.upper() ): @@ -283,7 +286,7 @@ def 
_parse_matching_macro(self: Parser, name: str) -> t.Optional[exp.Expression] return _parse_macro(self, keyword_macro=name) -def _parse_body_macro(self: Parser) -> t.Tuple[str, t.Optional[exp.Expression]]: +def _parse_body_macro(self: Parser) -> t.Tuple[str, t.Optional[exp.Expr]]: name = self._next and self._next.text.upper() if name == "JOIN": @@ -301,7 +304,7 @@ def _parse_body_macro(self: Parser) -> t.Tuple[str, t.Optional[exp.Expression]]: return ("", None) -def _parse_with(self: Parser, skip_with_token: bool = False) -> t.Optional[exp.Expression]: +def _parse_with(self: Parser, skip_with_token: bool = False) -> t.Optional[exp.Expr]: macro = _parse_matching_macro(self, "WITH") if not macro: return self.__parse_with(skip_with_token=skip_with_token) # type: ignore @@ -312,7 +315,7 @@ def _parse_with(self: Parser, skip_with_token: bool = False) -> t.Optional[exp.E def _parse_join( self: Parser, skip_join_token: bool = False, parse_bracket: bool = False -) -> t.Optional[exp.Expression]: +) -> t.Optional[exp.Expr]: index = self._index method, side, kind = self._parse_join_parts() macro = _parse_matching_macro(self, "JOIN") @@ -351,7 +354,7 @@ def _parse_select( parse_set_operation: bool = True, consume_pipe: bool = True, from_: t.Optional[exp.From] = None, -) -> t.Optional[exp.Expression]: +) -> t.Optional[exp.Expr]: select = self.__parse_select( # type: ignore nested=nested, table=table, @@ -372,7 +375,7 @@ def _parse_select( return select -def _parse_where(self: Parser, skip_where_token: bool = False) -> t.Optional[exp.Expression]: +def _parse_where(self: Parser, skip_where_token: bool = False) -> t.Optional[exp.Expr]: macro = _parse_matching_macro(self, "WHERE") if not macro: return self.__parse_where(skip_where_token=skip_where_token) # type: ignore @@ -381,7 +384,7 @@ def _parse_where(self: Parser, skip_where_token: bool = False) -> t.Optional[exp return macro -def _parse_group(self: Parser, skip_group_by_token: bool = False) -> t.Optional[exp.Expression]: +def 
_parse_group(self: Parser, skip_group_by_token: bool = False) -> t.Optional[exp.Expr]: macro = _parse_matching_macro(self, "GROUP_BY") if not macro: return self.__parse_group(skip_group_by_token=skip_group_by_token) # type: ignore @@ -390,7 +393,7 @@ def _parse_group(self: Parser, skip_group_by_token: bool = False) -> t.Optional[ return macro -def _parse_having(self: Parser, skip_having_token: bool = False) -> t.Optional[exp.Expression]: +def _parse_having(self: Parser, skip_having_token: bool = False) -> t.Optional[exp.Expr]: macro = _parse_matching_macro(self, "HAVING") if not macro: return self.__parse_having(skip_having_token=skip_having_token) # type: ignore @@ -400,8 +403,8 @@ def _parse_having(self: Parser, skip_having_token: bool = False) -> t.Optional[e def _parse_order( - self: Parser, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False -) -> t.Optional[exp.Expression]: + self: Parser, this: t.Optional[exp.Expr] = None, skip_order_token: bool = False +) -> t.Optional[exp.Expr]: macro = _parse_matching_macro(self, "ORDER_BY") if not macro: return self.__parse_order(this, skip_order_token=skip_order_token) # type: ignore @@ -412,10 +415,10 @@ def _parse_order( def _parse_limit( self: Parser, - this: t.Optional[exp.Expression] = None, + this: t.Optional[exp.Expr] = None, top: bool = False, skip_limit_token: bool = False, -) -> t.Optional[exp.Expression]: +) -> t.Optional[exp.Expr]: macro = _parse_matching_macro(self, "TOP" if top else "LIMIT") if not macro: return self.__parse_limit(this, top=top, skip_limit_token=skip_limit_token) # type: ignore @@ -424,7 +427,7 @@ def _parse_limit( return macro -def _parse_value(self: Parser, values: bool = True) -> t.Optional[exp.Expression]: +def _parse_value(self: Parser, values: bool = True) -> t.Optional[exp.Expr]: wrapped = self._match(TokenType.L_PAREN, advance=False) # The base _parse_value method always constructs a Tuple instance. 
This is problematic when @@ -438,11 +441,11 @@ def _parse_value(self: Parser, values: bool = True) -> t.Optional[exp.Expression return expr -def _parse_macro_or_clause(self: Parser, parser: t.Callable) -> t.Optional[exp.Expression]: +def _parse_macro_or_clause(self: Parser, parser: t.Callable) -> t.Optional[exp.Expr]: return _parse_macro(self) if self._match(TokenType.PARAMETER) else parser() -def _parse_props(self: Parser) -> t.Optional[exp.Expression]: +def _parse_props(self: Parser) -> t.Optional[exp.Expr]: key = self._parse_id_var(any_token=True) if not key: return None @@ -460,7 +463,7 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]: elif name == "merge_filter": value = self._parse_conjunction() elif self._match(TokenType.L_PAREN): - value = self.expression(exp.Tuple, expressions=self._parse_csv(self._parse_equality)) + value = self.expression(exp.Tuple(expressions=self._parse_csv(self._parse_equality))) self._match_r_paren() else: value = self._parse_bracket(self._parse_field(any_token=True)) @@ -469,7 +472,7 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]: # Make sure if we get a windows path that it is converted to posix value = exp.Literal.string(value.this.replace("\\", "/")) # type: ignore - return self.expression(exp.Property, this=name, value=value) + return self.expression(exp.Property(this=name, value=value)) def _parse_types( @@ -477,7 +480,7 @@ def _parse_types( check_func: bool = False, schema: bool = False, allow_identifiers: bool = True, -) -> t.Optional[exp.Expression]: +) -> t.Optional[exp.Expr]: start = self._curr parsed_type = self.__parse_types( # type: ignore check_func=check_func, schema=schema, allow_identifiers=allow_identifiers @@ -534,7 +537,7 @@ def _parse_table_parts( return table -def _parse_if(self: Parser) -> t.Optional[exp.Expression]: +def _parse_if(self: Parser) -> t.Optional[exp.Expr]: # If we fail to parse an IF function with expressions as arguments, we then try # to parse a statement / 
command to support the macro @IF(condition, statement) index = self._index @@ -566,11 +569,11 @@ def _parse_if(self: Parser) -> t.Optional[exp.Expression]: return exp.Anonymous(this="IF", expressions=[cond, stmt]) -def _create_parser(expression_type: t.Type[exp.Expression], table_keys: t.List[str]) -> t.Callable: - def parse(self: Parser) -> t.Optional[exp.Expression]: +def _create_parser(expression_type: t.Type[exp.Expr], table_keys: t.List[str]) -> t.Callable: + def parse(self: Parser) -> t.Optional[exp.Expr]: from sqlmesh.core.model.kind import ModelKindName - expressions: t.List[exp.Expression] = [] + expressions: t.List[exp.Expr] = [] while True: prev_property = seq_get(expressions, -1) @@ -589,7 +592,7 @@ def parse(self: Parser) -> t.Optional[exp.Expression]: key = key_expression.name.lower() start = self._curr - value: t.Optional[exp.Expression | str] + value: t.Optional[exp.Expr | str] if key in table_keys: value = self._parse_table_parts() @@ -629,7 +632,7 @@ def parse(self: Parser) -> t.Optional[exp.Expression]: else: props = None - value = self.expression(ModelKind, this=kind.value, expressions=props) + value = self.expression(ModelKind(this=kind.value, expressions=props)) elif key == "expression": value = self._parse_conjunction() elif key == "partitioned_by": @@ -641,12 +644,12 @@ def parse(self: Parser) -> t.Optional[exp.Expression]: else: value = self._parse_bracket(self._parse_field(any_token=True)) - if isinstance(value, exp.Expression): + if isinstance(value, exp.Expr): value.meta["sql"] = self._find_sql(start, self._prev) - expressions.append(self.expression(exp.Property, this=key, value=value)) + expressions.append(self.expression(exp.Property(this=key, value=value))) - return self.expression(expression_type, expressions=expressions) + return self.expression(expression_type(expressions=expressions)) return parse @@ -658,7 +661,7 @@ def parse(self: Parser) -> t.Optional[exp.Expression]: } -def _props_sql(self: Generator, expressions: 
t.List[exp.Expression]) -> str: +def _props_sql(self: Generator, expressions: t.List[exp.Expr]) -> str: props = [] size = len(expressions) @@ -676,7 +679,7 @@ def _props_sql(self: Generator, expressions: t.List[exp.Expression]) -> str: return "\n".join(props) -def _on_virtual_update_sql(self: Generator, expressions: t.List[exp.Expression]) -> str: +def _on_virtual_update_sql(self: Generator, expressions: t.List[exp.Expr]) -> str: statements = "\n".join( self.sql(expression) if isinstance(expression, JinjaStatement) @@ -697,7 +700,7 @@ def _model_kind_sql(self: Generator, expression: ModelKind) -> str: return expression.name.upper() -def _macro_keyword_func_sql(self: Generator, expression: exp.Expression) -> str: +def _macro_keyword_func_sql(self: Generator, expression: exp.Expr) -> str: name = expression.name keyword = name.replace("_", " ") *args, clause = expression.expressions @@ -731,7 +734,7 @@ def _override(klass: t.Type[Tokenizer | Parser], func: t.Callable) -> None: def format_model_expressions( - expressions: t.List[exp.Expression], + expressions: t.List[exp.Expr], dialect: t.Optional[str] = None, rewrite_casts: bool = True, **kwargs: t.Any, @@ -752,7 +755,7 @@ def format_model_expressions( if rewrite_casts: - def cast_to_colon(node: exp.Expression) -> exp.Expression: + def cast_to_colon(node: exp.Expr) -> exp.Expr: if isinstance(node, exp.Cast) and not any( # Only convert CAST into :: if it doesn't have additional args set, otherwise this # conversion could alter the semantics (eg. 
changing SAFE_CAST in BigQuery to CAST) @@ -784,8 +787,8 @@ def cast_to_colon(node: exp.Expression) -> exp.Expression: def text_diff( - a: t.List[exp.Expression], - b: t.List[exp.Expression], + a: t.List[exp.Expr], + b: t.List[exp.Expr], a_dialect: t.Optional[str] = None, b_dialect: t.Optional[str] = None, ) -> str: @@ -860,7 +863,7 @@ def _is_virtual_statement_end(tokens: t.List[Token], pos: int) -> bool: return _is_command_statement(ON_VIRTUAL_UPDATE_END, tokens, pos) -def virtual_statement(statements: t.List[exp.Expression]) -> VirtualUpdateStatement: +def virtual_statement(statements: t.List[exp.Expr]) -> VirtualUpdateStatement: return VirtualUpdateStatement(expressions=statements) @@ -874,7 +877,7 @@ class ChunkType(Enum): def parse_one( sql: str, dialect: t.Optional[str] = None, into: t.Optional[exp.IntoType] = None -) -> exp.Expression: +) -> exp.Expr: expressions = parse(sql, default_dialect=dialect, match_dialect=False, into=into) if not expressions: raise SQLMeshError(f"No expressions found in '{sql}'") @@ -888,7 +891,7 @@ def parse( default_dialect: t.Optional[str] = None, match_dialect: bool = True, into: t.Optional[exp.IntoType] = None, -) -> t.List[exp.Expression]: +) -> t.List[exp.Expr]: """Parse a sql string. Supports parsing model definition. 
@@ -952,10 +955,10 @@ def parse( pos += 1 parser = dialect.parser() - expressions: t.List[exp.Expression] = [] + expressions: t.List[exp.Expr] = [] - def parse_sql_chunk(chunk: t.List[Token], meta_sql: bool = True) -> t.List[exp.Expression]: - parsed_expressions: t.List[t.Optional[exp.Expression]] = ( + def parse_sql_chunk(chunk: t.List[Token], meta_sql: bool = True) -> t.List[exp.Expr]: + parsed_expressions: t.List[t.Optional[exp.Expr]] = ( parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql) ) expressions = [] @@ -966,7 +969,7 @@ def parse_sql_chunk(chunk: t.List[Token], meta_sql: bool = True) -> t.List[exp.E expressions.append(expression) return expressions - def parse_jinja_chunk(chunk: t.List[Token], meta_sql: bool = True) -> exp.Expression: + def parse_jinja_chunk(chunk: t.List[Token], meta_sql: bool = True) -> exp.Expr: start, *_, end = chunk segment = sql[start.end + 2 : end.start - 1] factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement @@ -977,9 +980,9 @@ def parse_jinja_chunk(chunk: t.List[Token], meta_sql: bool = True) -> exp.Expres def parse_virtual_statement( chunks: t.List[t.Tuple[t.List[Token], ChunkType]], pos: int - ) -> t.Tuple[t.List[exp.Expression], int]: + ) -> t.Tuple[t.List[exp.Expr], int]: # For virtual statements we need to handle both SQL and Jinja nested blocks within the chunk - virtual_update_statements = [] + virtual_update_statements: t.List[exp.Expr] = [] start = chunks[pos][0][0].start while ( @@ -1031,7 +1034,7 @@ def extend_sqlglot() -> None: # so this ensures that the extra ones it defines are also extended if dialect == athena.Athena: tokenizers.add(athena._TrinoTokenizer) - parsers.add(athena._TrinoParser) + parsers.add(AthenaTrinoParser) generators.add(athena._TrinoGenerator) generators.add(athena._HiveGenerator) @@ -1251,7 +1254,7 @@ def normalize_model_name( def find_tables( - expression: exp.Expression, default_catalog: t.Optional[str], dialect: DialectType = None + 
expression: exp.Expr, default_catalog: t.Optional[str], dialect: DialectType = None ) -> t.Set[str]: """Find all tables referenced in a query. @@ -1274,10 +1277,10 @@ def find_tables( return expression.meta[TABLES_META] -def add_table(node: exp.Expression, table: str) -> exp.Expression: +def add_table(node: exp.Expr, table: str) -> exp.Expr: """Add a table to all columns in an expression.""" - def _transform(node: exp.Expression) -> exp.Expression: + def _transform(node: exp.Expr) -> exp.Expr: if isinstance(node, exp.Column) and not node.table: return exp.column(node.this, table=table) if isinstance(node, exp.Identifier): @@ -1387,7 +1390,7 @@ def normalize_and_quote( quote_identifiers(query, dialect=dialect) -def interpret_expression(e: exp.Expression) -> exp.Expression | str | int | float | bool: +def interpret_expression(e: exp.Expr) -> exp.Expr | str | int | float | bool: if e.is_int: return int(e.this) if e.is_number: @@ -1399,13 +1402,13 @@ def interpret_expression(e: exp.Expression) -> exp.Expression | str | int | floa def interpret_key_value_pairs( e: exp.Tuple, -) -> t.Dict[str, exp.Expression | str | int | float | bool]: +) -> t.Dict[str, exp.Expr | str | int | float | bool]: return {i.this.name: interpret_expression(i.expression) for i in e.expressions} def extract_func_call( - v: exp.Expression, allow_tuples: bool = False -) -> t.Tuple[str, t.Dict[str, exp.Expression]]: + v: exp.Expr, allow_tuples: bool = False +) -> t.Tuple[str, t.Dict[str, exp.Expr]]: kwargs = {} if isinstance(v, exp.Anonymous): @@ -1442,7 +1445,7 @@ def extract_function_calls(func_calls: t.Any, allow_tuples: bool = False) -> t.A return [extract_func_call(i, allow_tuples=allow_tuples) for i in func_calls.expressions] if isinstance(func_calls, exp.Paren): return [extract_func_call(func_calls.this, allow_tuples=allow_tuples)] - if isinstance(func_calls, exp.Expression): + if isinstance(func_calls, exp.Expr): return [extract_func_call(func_calls, allow_tuples=allow_tuples)] if 
isinstance(func_calls, list): function_calls = [] @@ -1474,9 +1477,7 @@ def is_meta_expression(v: t.Any) -> bool: return isinstance(v, (Audit, Metric, Model)) -def replace_merge_table_aliases( - expression: exp.Expression, dialect: t.Optional[str] = None -) -> exp.Expression: +def replace_merge_table_aliases(expression: exp.Expr, dialect: t.Optional[str] = None) -> exp.Expr: """ Resolves references from the "source" and "target" tables (or their DBT equivalents) with the corresponding SQLMesh merge aliases (MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS) diff --git a/sqlmesh/core/engine_adapter/athena.py b/sqlmesh/core/engine_adapter/athena.py index bd84ba5276..338381549b 100644 --- a/sqlmesh/core/engine_adapter/athena.py +++ b/sqlmesh/core/engine_adapter/athena.py @@ -158,7 +158,7 @@ def _create_schema( schema_name: SchemaName, ignore_if_exists: bool, warn_on_error: bool, - properties: t.List[exp.Expression], + properties: t.List[exp.Expr], kind: str, ) -> None: if location := self._table_location(table_properties=None, table=exp.to_table(schema_name)): @@ -177,14 +177,14 @@ def _create_schema( def _build_create_table_exp( self, table_name_or_schema: t.Union[exp.Schema, TableName], - expression: t.Optional[exp.Expression], + expression: t.Optional[exp.Expr], exists: bool = True, replace: bool = False, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, - partitioned_by: t.Optional[t.List[exp.Expression]] = None, - table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + partitioned_by: t.Optional[t.List[exp.Expr]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expr]] = None, **kwargs: t.Any, ) -> exp.Create: exists = False if replace else exists @@ -235,18 +235,18 @@ def _build_table_properties_exp( catalog_name: t.Optional[str] = None, table_format: t.Optional[str] = None, storage_format: t.Optional[str] = None, - partitioned_by: 
t.Optional[t.List[exp.Expression]] = None, + partitioned_by: t.Optional[t.List[exp.Expr]] = None, partition_interval_unit: t.Optional[IntervalUnit] = None, - clustered_by: t.Optional[t.List[exp.Expression]] = None, - table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + clustered_by: t.Optional[t.List[exp.Expr]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expr]] = None, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, table: t.Optional[exp.Table] = None, - expression: t.Optional[exp.Expression] = None, + expression: t.Optional[exp.Expr] = None, **kwargs: t.Any, ) -> t.Optional[exp.Properties]: - properties: t.List[exp.Expression] = [] + properties: t.List[exp.Expr] = [] table_properties = table_properties or {} is_hive = self._table_type(table_format) == "hive" @@ -266,7 +266,7 @@ def _build_table_properties_exp( properties.append(exp.SchemaCommentProperty(this=exp.Literal.string(table_description))) if partitioned_by: - schema_expressions: t.List[exp.Expression] = [] + schema_expressions: t.List[exp.Expr] = [] if is_hive and target_columns_to_types: # For Hive-style tables, you cannot include the partitioned by columns in the main set of columns # In the PARTITIONED BY expression, you also cant just include the column names, you need to include the data type as well @@ -381,7 +381,7 @@ def _is_hive_partitioned_table(self, table: exp.Table) -> bool: raise e def _table_location_or_raise( - self, table_properties: t.Optional[t.Dict[str, exp.Expression]], table: exp.Table + self, table_properties: t.Optional[t.Dict[str, exp.Expr]], table: exp.Table ) -> exp.LocationProperty: location = self._table_location(table_properties, table) if not location: @@ -392,7 +392,7 @@ def _table_location_or_raise( def _table_location( self, - table_properties: t.Optional[t.Dict[str, exp.Expression]], + table_properties: t.Optional[t.Dict[str, exp.Expr]], table: 
exp.Table, ) -> t.Optional[exp.LocationProperty]: base_uri: str @@ -402,7 +402,7 @@ def _table_location( s3_base_location_property = table_properties.pop( "s3_base_location" ) # pop because it's handled differently and we dont want it to end up in the TBLPROPERTIES clause - if isinstance(s3_base_location_property, exp.Expression): + if isinstance(s3_base_location_property, exp.Expr): base_uri = s3_base_location_property.name else: base_uri = s3_base_location_property @@ -419,7 +419,7 @@ def _table_location( return exp.LocationProperty(this=exp.Literal.string(full_uri)) def _find_matching_columns( - self, partitioned_by: t.List[exp.Expression], columns_to_types: t.Dict[str, exp.DataType] + self, partitioned_by: t.List[exp.Expr], columns_to_types: t.Dict[str, exp.DataType] ) -> t.List[t.Tuple[str, exp.DataType]]: matches = [] for col in partitioned_by: @@ -557,7 +557,7 @@ def _chunks() -> t.Iterable[t.List[t.List[str]]]: PartitionsToDelete=[{"Values": v} for v in batch], ) - def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None: + def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None: table = exp.to_table(table_name) table_type = self._query_table_type(table) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index e2dbb51459..8de7b79398 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -236,7 +236,7 @@ def _casted_columns( cls, target_columns_to_types: t.Dict[str, exp.DataType], source_columns: t.Optional[t.List[str]] = None, - ) -> t.List[exp.Alias]: + ) -> t.List[exp.Expr]: source_columns_lookup = set(source_columns or target_columns_to_types) return [ exp.alias_( @@ -591,7 +591,7 @@ def create_index( def _pop_creatable_type_from_properties( self, - properties: t.Dict[str, exp.Expression], + properties: t.Dict[str, exp.Expr], ) -> t.Optional[exp.Property]: """Pop out the creatable_type from the properties dictionary (if 
exists (return it/remove it) else return none). It also checks that none of the expressions are MATERIALIZE as that conflicts with the `materialize` parameter. @@ -652,9 +652,9 @@ def create_managed_table( table_name: TableName, query: Query, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - partitioned_by: t.Optional[t.List[exp.Expression]] = None, - clustered_by: t.Optional[t.List[exp.Expression]] = None, - table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + partitioned_by: t.Optional[t.List[exp.Expr]] = None, + clustered_by: t.Optional[t.List[exp.Expr]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expr]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, source_columns: t.Optional[t.List[str]] = None, @@ -964,7 +964,7 @@ def _create_table_from_source_queries( def _create_table( self, table_name_or_schema: t.Union[exp.Schema, TableName], - expression: t.Optional[exp.Expression], + expression: t.Optional[exp.Expr], exists: bool = True, replace: bool = False, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, @@ -1002,7 +1002,7 @@ def _create_table( def _build_create_table_exp( self, table_name_or_schema: t.Union[exp.Schema, TableName], - expression: t.Optional[exp.Expression], + expression: t.Optional[exp.Expr], exists: bool = True, replace: bool = False, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, @@ -1203,7 +1203,7 @@ def create_view( materialized_properties: t.Optional[t.Dict[str, t.Any]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, - view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + view_properties: t.Optional[t.Dict[str, exp.Expr]] = None, source_columns: t.Optional[t.List[str]] = None, **create_kwargs: t.Any, ) -> None: @@ -1382,7 +1382,7 @@ def create_schema( schema_name: SchemaName, ignore_if_exists: bool = True, 
warn_on_error: bool = True, - properties: t.Optional[t.List[exp.Expression]] = None, + properties: t.Optional[t.List[exp.Expr]] = None, ) -> None: properties = properties or [] return self._create_schema( @@ -1398,7 +1398,7 @@ def _create_schema( schema_name: SchemaName, ignore_if_exists: bool, warn_on_error: bool, - properties: t.List[exp.Expression], + properties: t.List[exp.Expr], kind: str, ) -> None: """Create a schema from a name or qualified table name.""" @@ -1423,7 +1423,7 @@ def drop_schema( schema_name: SchemaName, ignore_if_not_exists: bool = True, cascade: bool = False, - **drop_args: t.Dict[str, exp.Expression], + **drop_args: t.Dict[str, exp.Expr], ) -> None: return self._drop_object( name=schema_name, @@ -1494,7 +1494,7 @@ def table_exists(self, table_name: TableName) -> bool: except Exception: return False - def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None: + def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None: self.execute(exp.delete(table_name, where)) def insert_append( @@ -1552,7 +1552,7 @@ def insert_overwrite_by_partition( self, table_name: TableName, query_or_df: QueryOrDF, - partitioned_by: t.List[exp.Expression], + partitioned_by: t.List[exp.Expr], target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> None: @@ -1583,10 +1583,8 @@ def insert_overwrite_by_time_partition( query_or_df: QueryOrDF, start: TimeLike, end: TimeLike, - time_formatter: t.Callable[ - [TimeLike, t.Optional[t.Dict[str, exp.DataType]]], exp.Expression - ], - time_column: TimeColumn | exp.Expression | str, + time_formatter: t.Callable[[TimeLike, t.Optional[t.Dict[str, exp.DataType]]], exp.Expr], + time_column: TimeColumn | exp.Expr | str, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, @@ -1726,7 +1724,7 @@ def _merge( self, target_table: 
TableName, query: Query, - on: exp.Expression, + on: exp.Expr, whens: exp.Whens, ) -> None: this = exp.alias_(exp.to_table(target_table), alias=MERGE_TARGET_ALIAS, table=True) @@ -1741,7 +1739,7 @@ def scd_type_2_by_time( self, target_table: TableName, source_table: QueryOrDF, - unique_key: t.Sequence[exp.Expression], + unique_key: t.Sequence[exp.Expr], valid_from_col: exp.Column, valid_to_col: exp.Column, execution_time: t.Union[TimeLike, exp.Column], @@ -1777,11 +1775,11 @@ def scd_type_2_by_column( self, target_table: TableName, source_table: QueryOrDF, - unique_key: t.Sequence[exp.Expression], + unique_key: t.Sequence[exp.Expr], valid_from_col: exp.Column, valid_to_col: exp.Column, execution_time: t.Union[TimeLike, exp.Column], - check_columns: t.Union[exp.Star, t.Sequence[exp.Expression]], + check_columns: t.Union[exp.Star, t.Sequence[exp.Expr]], invalidate_hard_deletes: bool = True, execution_time_as_valid_from: bool = False, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, @@ -1813,13 +1811,13 @@ def _scd_type_2( self, target_table: TableName, source_table: QueryOrDF, - unique_key: t.Sequence[exp.Expression], + unique_key: t.Sequence[exp.Expr], valid_from_col: exp.Column, valid_to_col: exp.Column, execution_time: t.Union[TimeLike, exp.Column], invalidate_hard_deletes: bool = True, updated_at_col: t.Optional[exp.Column] = None, - check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None, + check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expr]]] = None, updated_at_as_valid_from: bool = False, execution_time_as_valid_from: bool = False, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, @@ -1908,7 +1906,7 @@ def remove_managed_columns( raise SQLMeshError( "Cannot use `updated_at_as_valid_from` without `updated_at_name` for SCD Type 2" ) - update_valid_from_start: t.Union[str, exp.Expression] = updated_at_col + update_valid_from_start: t.Union[str, exp.Expr] = updated_at_col # If using 
check_columns and the user doesn't always want execution_time for valid from # then we only use epoch 0 if we are truncating the table and loading rows for the first time. # All future new rows should have execution time. @@ -2207,9 +2205,9 @@ def merge( target_table: TableName, source_table: QueryOrDF, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], - unique_key: t.Sequence[exp.Expression], + unique_key: t.Sequence[exp.Expr], when_matched: t.Optional[exp.Whens] = None, - merge_filter: t.Optional[exp.Expression] = None, + merge_filter: t.Optional[exp.Expr] = None, source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: @@ -2382,7 +2380,7 @@ def get_data_objects( def fetchone( self, - query: t.Union[exp.Expression, str], + query: t.Union[exp.Expr, str], ignore_unsupported_errors: bool = False, quote_identifiers: bool = False, ) -> t.Optional[t.Tuple]: @@ -2396,7 +2394,7 @@ def fetchone( def fetchall( self, - query: t.Union[exp.Expression, str], + query: t.Union[exp.Expr, str], ignore_unsupported_errors: bool = False, quote_identifiers: bool = False, ) -> t.List[t.Tuple]: @@ -2409,7 +2407,7 @@ def fetchall( return self.cursor.fetchall() def _fetch_native_df( - self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False ) -> DF: """Fetches a DataFrame that can be either Pandas or PySpark from the cursor""" with self.transaction(): @@ -2432,7 +2430,7 @@ def _native_df_to_pandas_df( raise NotImplementedError(f"Unable to convert {type(query_or_df)} to Pandas") def fetchdf( - self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False ) -> pd.DataFrame: """Fetches a Pandas DataFrame from the cursor""" import pandas as pd @@ -2445,7 +2443,7 @@ def fetchdf( return df def fetch_pyspark_df( - self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + 
self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False ) -> PySparkDataFrame: """Fetches a PySpark DataFrame from the cursor""" raise NotImplementedError(f"Engine does not support PySpark DataFrames: {type(self)}") @@ -2575,7 +2573,7 @@ def _is_session_active(self) -> bool: def execute( self, - expressions: t.Union[str, exp.Expression, t.Sequence[exp.Expression]], + expressions: t.Union[str, exp.Expr, t.Sequence[exp.Expr]], ignore_unsupported_errors: bool = False, quote_identifiers: bool = True, track_rows_processed: bool = False, @@ -2587,7 +2585,7 @@ def execute( ) with self.transaction(): for e in ensure_list(expressions): - if isinstance(e, exp.Expression): + if isinstance(e, exp.Expr): self._check_identifier_length(e) sql = self._to_sql(e, quote=quote_identifiers, **to_sql_kwargs) else: @@ -2597,7 +2595,7 @@ def execute( self._log_sql( sql, - expression=e if isinstance(e, exp.Expression) else None, + expression=e if isinstance(e, exp.Expr) else None, quote_identifiers=quote_identifiers, ) self._execute(sql, track_rows_processed, **kwargs) @@ -2610,7 +2608,7 @@ def _attach_correlation_id(self, sql: str) -> str: def _log_sql( self, sql: str, - expression: t.Optional[exp.Expression] = None, + expression: t.Optional[exp.Expr] = None, quote_identifiers: bool = True, ) -> None: if not logger.isEnabledFor(self._execute_log_level): @@ -2702,7 +2700,7 @@ def temp_table( self.drop_table(table) def _table_or_view_properties_to_expressions( - self, table_or_view_properties: t.Optional[t.Dict[str, exp.Expression]] = None + self, table_or_view_properties: t.Optional[t.Dict[str, exp.Expr]] = None ) -> t.List[exp.Property]: """Converts model properties (either physical or virtual) to a list of property expressions.""" if not table_or_view_properties: @@ -2714,7 +2712,7 @@ def _table_or_view_properties_to_expressions( def _build_partitioned_by_exp( self, - partitioned_by: t.List[exp.Expression], + partitioned_by: t.List[exp.Expr], *, partition_interval_unit: 
t.Optional[IntervalUnit] = None, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, @@ -2725,7 +2723,7 @@ def _build_partitioned_by_exp( def _build_clustered_by_exp( self, - clustered_by: t.List[exp.Expression], + clustered_by: t.List[exp.Expr], **kwargs: t.Any, ) -> t.Optional[exp.Cluster]: return None @@ -2735,17 +2733,17 @@ def _build_table_properties_exp( catalog_name: t.Optional[str] = None, table_format: t.Optional[str] = None, storage_format: t.Optional[str] = None, - partitioned_by: t.Optional[t.List[exp.Expression]] = None, + partitioned_by: t.Optional[t.List[exp.Expr]] = None, partition_interval_unit: t.Optional[IntervalUnit] = None, - clustered_by: t.Optional[t.List[exp.Expression]] = None, - table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + clustered_by: t.Optional[t.List[exp.Expr]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expr]] = None, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, **kwargs: t.Any, ) -> t.Optional[exp.Properties]: """Creates a SQLGlot table properties expression for ddl.""" - properties: t.List[exp.Expression] = [] + properties: t.List[exp.Expr] = [] if table_description: properties.append( @@ -2764,12 +2762,12 @@ def _build_table_properties_exp( def _build_view_properties_exp( self, - view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + view_properties: t.Optional[t.Dict[str, exp.Expr]] = None, table_description: t.Optional[str] = None, **kwargs: t.Any, ) -> t.Optional[exp.Properties]: """Creates a SQLGlot table properties expression for view""" - properties: t.List[exp.Expression] = [] + properties: t.List[exp.Expr] = [] if table_description: properties.append( @@ -2791,7 +2789,7 @@ def _truncate_table_comment(self, comment: str) -> str: def _truncate_column_comment(self, comment: str) -> str: return self._truncate_comment(comment, self.MAX_COLUMN_COMMENT_LENGTH) - 
def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str: + def _to_sql(self, expression: exp.Expr, quote: bool = True, **kwargs: t.Any) -> str: """ Converts an expression to a SQL string. Has a set of default kwargs to apply, and then default kwargs defined for the given dialect, and then kwargs provided by the user when defining the engine @@ -2852,7 +2850,7 @@ def _order_projections_and_filter( self, query: Query, target_columns_to_types: t.Dict[str, exp.DataType], - where: t.Optional[exp.Expression] = None, + where: t.Optional[exp.Expr] = None, coerce_types: bool = False, ) -> Query: if not isinstance(query, exp.Query) or ( @@ -2863,7 +2861,7 @@ def _order_projections_and_filter( query = t.cast(exp.Query, query.copy()) with_ = query.args.pop("with_", None) - select_exprs: t.List[exp.Expression] = [ + select_exprs: t.List[exp.Expr] = [ exp.column(c, quoted=True) for c in target_columns_to_types ] if coerce_types and columns_to_types_all_known(target_columns_to_types): @@ -2914,7 +2912,7 @@ def _replace_by_key( target_table: TableName, source_table: QueryOrDF, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], - key: t.Sequence[exp.Expression], + key: t.Sequence[exp.Expr], is_unique_key: bool, source_columns: t.Optional[t.List[str]] = None, ) -> None: @@ -3055,7 +3053,7 @@ def _select_columns( ) ) - def _check_identifier_length(self, expression: exp.Expression) -> None: + def _check_identifier_length(self, expression: exp.Expr) -> None: if self.MAX_IDENTIFIER_LENGTH is None or not isinstance(expression, exp.DDL): return @@ -3147,7 +3145,7 @@ def _apply_grants_config_expr( table: exp.Table, grants_config: GrantsConfig, table_type: DataObjectType = DataObjectType.TABLE, - ) -> t.List[exp.Expression]: + ) -> t.List[exp.Expr]: """Returns SQLGlot Grant expressions to apply grants to a table. 
Args: @@ -3170,7 +3168,7 @@ def _revoke_grants_config_expr( table: exp.Table, grants_config: GrantsConfig, table_type: DataObjectType = DataObjectType.TABLE, - ) -> t.List[exp.Expression]: + ) -> t.List[exp.Expr]: """Returns SQLGlot expressions to revoke grants from a table. Args: diff --git a/sqlmesh/core/engine_adapter/base_postgres.py b/sqlmesh/core/engine_adapter/base_postgres.py index 11f56da133..e2347b1263 100644 --- a/sqlmesh/core/engine_adapter/base_postgres.py +++ b/sqlmesh/core/engine_adapter/base_postgres.py @@ -110,7 +110,7 @@ def create_view( materialized_properties: t.Optional[t.Dict[str, t.Any]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, - view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + view_properties: t.Optional[t.Dict[str, exp.Expr]] = None, source_columns: t.Optional[t.List[str]] = None, **create_kwargs: t.Any, ) -> None: diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 59a56b6ace..4741f90d27 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -67,7 +67,7 @@ class BigQueryEngineAdapter(ClusteredByMixin, RowDiffMixin, GrantsFromInfoSchema SUPPORTS_MATERIALIZED_VIEWS = True SUPPORTS_CLONING = True SUPPORTS_GRANTS = True - CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("session_user") + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expr = exp.func("session_user") SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True USE_CATALOG_IN_GRANTS = True GRANT_INFORMATION_SCHEMA_TABLE_NAME = "OBJECT_PRIVILEGES" @@ -288,7 +288,7 @@ def create_schema( schema_name: SchemaName, ignore_if_exists: bool = True, warn_on_error: bool = True, - properties: t.List[exp.Expression] = [], + properties: t.List[exp.Expr] = [], ) -> None: """Create a schema from a name or qualified table name.""" from google.api_core.exceptions import Conflict @@ -433,7 +433,7 @@ def alter_table( def fetchone( self, - 
query: t.Union[exp.Expression, str], + query: t.Union[exp.Expr, str], ignore_unsupported_errors: bool = False, quote_identifiers: bool = False, ) -> t.Optional[t.Tuple]: @@ -453,7 +453,7 @@ def fetchone( def fetchall( self, - query: t.Union[exp.Expression, str], + query: t.Union[exp.Expr, str], ignore_unsupported_errors: bool = False, quote_identifiers: bool = False, ) -> t.List[t.Tuple]: @@ -689,7 +689,7 @@ def insert_overwrite_by_partition( self, table_name: TableName, query_or_df: QueryOrDF, - partitioned_by: t.List[exp.Expression], + partitioned_by: t.List[exp.Expr], target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> None: @@ -803,7 +803,7 @@ def _table_name(self, table_name: TableName) -> str: return ".".join(part.name for part in exp.to_table(table_name).parts) def _fetch_native_df( - self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False ) -> DF: self.execute(query, quote_identifiers=quote_identifiers) query_job = self._query_job @@ -863,7 +863,7 @@ def _build_description_property_exp( def _build_partitioned_by_exp( self, - partitioned_by: t.List[exp.Expression], + partitioned_by: t.List[exp.Expr], *, partition_interval_unit: t.Optional[IntervalUnit] = None, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, @@ -909,16 +909,16 @@ def _build_table_properties_exp( catalog_name: t.Optional[str] = None, table_format: t.Optional[str] = None, storage_format: t.Optional[str] = None, - partitioned_by: t.Optional[t.List[exp.Expression]] = None, + partitioned_by: t.Optional[t.List[exp.Expr]] = None, partition_interval_unit: t.Optional[IntervalUnit] = None, - clustered_by: t.Optional[t.List[exp.Expression]] = None, - table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + clustered_by: t.Optional[t.List[exp.Expr]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expr]] = 
None, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, **kwargs: t.Any, ) -> t.Optional[exp.Properties]: - properties: t.List[exp.Expression] = [] + properties: t.List[exp.Expr] = [] if partitioned_by and ( partitioned_by_prop := self._build_partitioned_by_exp( @@ -1025,12 +1025,12 @@ def _build_col_comment_exp( def _build_view_properties_exp( self, - view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + view_properties: t.Optional[t.Dict[str, exp.Expr]] = None, table_description: t.Optional[str] = None, **kwargs: t.Any, ) -> t.Optional[exp.Properties]: """Creates a SQLGlot table properties expression for view""" - properties: t.List[exp.Expression] = [] + properties: t.List[exp.Expr] = [] if table_description: properties.append( @@ -1257,10 +1257,10 @@ def _update_clustering_key(self, operation: TableAlterClusterByOperation) -> Non ) ) - def _normalize_decimal_value(self, col: exp.Expression, precision: int) -> exp.Expression: + def _normalize_decimal_value(self, col: exp.Expr, precision: int) -> exp.Expr: return exp.func("FORMAT", exp.Literal.string(f"%.{precision}f"), col) - def _normalize_nested_value(self, col: exp.Expression) -> exp.Expression: + def _normalize_nested_value(self, col: exp.Expr) -> exp.Expr: return exp.func("TO_JSON_STRING", col, dialect=self.dialect) @t.overload @@ -1338,7 +1338,7 @@ def _get_current_schema(self) -> str: def _get_bq_dataset_location(self, project: str, dataset: str) -> str: return self._db_call(self.client.get_dataset, dataset_ref=f"{project}.{dataset}").location - def _get_grant_expression(self, table: exp.Table) -> exp.Expression: + def _get_grant_expression(self, table: exp.Table) -> exp.Expr: if not table.db: raise ValueError( f"Table {table.sql(dialect=self.dialect)} does not have a schema (dataset)" @@ -1392,8 +1392,8 @@ def _dcl_grants_config_expr( table: exp.Table, grants_config: GrantsConfig, table_type: 
DataObjectType = DataObjectType.TABLE, - ) -> t.List[exp.Expression]: - expressions: t.List[exp.Expression] = [] + ) -> t.List[exp.Expr]: + expressions: t.List[exp.Expr] = [] if not grants_config: return expressions diff --git a/sqlmesh/core/engine_adapter/clickhouse.py b/sqlmesh/core/engine_adapter/clickhouse.py index 45c22a6e55..71a834ecfc 100644 --- a/sqlmesh/core/engine_adapter/clickhouse.py +++ b/sqlmesh/core/engine_adapter/clickhouse.py @@ -64,7 +64,7 @@ def cluster(self) -> t.Optional[str]: # doesn't use the row index at all def fetchone( self, - query: t.Union[exp.Expression, str], + query: t.Union[exp.Expr, str], ignore_unsupported_errors: bool = False, quote_identifiers: bool = False, ) -> t.Tuple: @@ -77,13 +77,11 @@ def fetchone( return self.cursor.fetchall()[0] def _fetch_native_df( - self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False ) -> pd.DataFrame: """Fetches a Pandas DataFrame from the cursor""" return self.cursor.client.query_df( - self._to_sql(query, quote=quote_identifiers) - if isinstance(query, exp.Expression) - else query, + self._to_sql(query, quote=quote_identifiers) if isinstance(query, exp.Expr) else query, use_extended_dtypes=True, ) @@ -168,7 +166,7 @@ def create_schema( schema_name: SchemaName, ignore_if_exists: bool = True, warn_on_error: bool = True, - properties: t.List[exp.Expression] = [], + properties: t.List[exp.Expr] = [], ) -> None: """Create a Clickhouse database from a name or qualified table name. 
@@ -229,7 +227,7 @@ def _insert_overwrite_by_condition( # REPLACE BY KEY: extract kwargs if present dynamic_key = kwargs.get("dynamic_key") if dynamic_key: - dynamic_key_exp = t.cast(exp.Expression, kwargs.get("dynamic_key_exp")) + dynamic_key_exp = t.cast(exp.Expr, kwargs.get("dynamic_key_exp")) dynamic_key_unique = t.cast(bool, kwargs.get("dynamic_key_unique")) try: @@ -414,7 +412,7 @@ def _replace_by_key( target_table: TableName, source_table: QueryOrDF, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], - key: t.Sequence[exp.Expression], + key: t.Sequence[exp.Expr], is_unique_key: bool, source_columns: t.Optional[t.List[str]] = None, ) -> None: @@ -440,7 +438,7 @@ def insert_overwrite_by_partition( self, table_name: TableName, query_or_df: QueryOrDF, - partitioned_by: t.List[exp.Expression], + partitioned_by: t.List[exp.Expr], target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> None: @@ -487,7 +485,7 @@ def _get_partition_ids( def _create_table( self, table_name_or_schema: t.Union[exp.Schema, TableName], - expression: t.Optional[exp.Expression], + expression: t.Optional[exp.Expr], exists: bool = True, replace: bool = False, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, @@ -595,7 +593,7 @@ def _rename_table( self.execute(f"RENAME TABLE {old_table_sql} TO {new_table_sql}{self._on_cluster_sql()}") - def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None: + def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None: delete_expr = exp.delete(table_name, where) if self.engine_run_mode.is_cluster: delete_expr.set("cluster", exp.OnCluster(this=exp.to_identifier(self.cluster))) @@ -649,7 +647,7 @@ def _drop_object( def _build_partitioned_by_exp( self, - partitioned_by: t.List[exp.Expression], + partitioned_by: t.List[exp.Expr], **kwargs: t.Any, ) -> t.Optional[t.Union[exp.PartitionedByProperty, 
exp.Property]]: return exp.PartitionedByProperty( @@ -714,14 +712,14 @@ def use_server_nulls_for_unmatched_after_join( return query def _build_settings_property( - self, key: str, value: exp.Expression | str | int | float + self, key: str, value: exp.Expr | str | int | float ) -> exp.SettingsProperty: return exp.SettingsProperty( expressions=[ exp.EQ( this=exp.var(key.lower()), expression=value - if isinstance(value, exp.Expression) + if isinstance(value, exp.Expr) else exp.Literal(this=value, is_string=isinstance(value, str)), ) ] @@ -732,17 +730,17 @@ def _build_table_properties_exp( catalog_name: t.Optional[str] = None, table_format: t.Optional[str] = None, storage_format: t.Optional[str] = None, - partitioned_by: t.Optional[t.List[exp.Expression]] = None, + partitioned_by: t.Optional[t.List[exp.Expr]] = None, partition_interval_unit: t.Optional[IntervalUnit] = None, - clustered_by: t.Optional[t.List[exp.Expression]] = None, - table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + clustered_by: t.Optional[t.List[exp.Expr]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expr]] = None, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, empty_ctas: bool = False, **kwargs: t.Any, ) -> t.Optional[exp.Properties]: - properties: t.List[exp.Expression] = [] + properties: t.List[exp.Expr] = [] table_engine = self.DEFAULT_TABLE_ENGINE if storage_format: @@ -809,9 +807,7 @@ def _build_table_properties_exp( ttl = table_properties_copy.pop("TTL", None) if ttl: properties.append( - exp.MergeTreeTTL( - expressions=[ttl if isinstance(ttl, exp.Expression) else exp.var(ttl)] - ) + exp.MergeTreeTTL(expressions=[ttl if isinstance(ttl, exp.Expr) else exp.var(ttl)]) ) if ( @@ -845,12 +841,12 @@ def _build_table_properties_exp( def _build_view_properties_exp( self, - view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + view_properties: 
t.Optional[t.Dict[str, exp.Expr]] = None, table_description: t.Optional[str] = None, **kwargs: t.Any, ) -> t.Optional[exp.Properties]: """Creates a SQLGlot table properties expression for view""" - properties: t.List[exp.Expression] = [] + properties: t.List[exp.Expr] = [] view_properties_copy = view_properties.copy() if view_properties else {} diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index 870b946e7d..e3d029a17d 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -163,7 +163,7 @@ def _grant_object_kind(table_type: DataObjectType) -> str: return "MATERIALIZED VIEW" return "TABLE" - def _get_grant_expression(self, table: exp.Table) -> exp.Expression: + def _get_grant_expression(self, table: exp.Table) -> exp.Expr: # We only care about explicitly granted privileges and not inherited ones # if this is removed you would see grants inherited from the catalog get returned expression = super()._get_grant_expression(table) @@ -210,7 +210,7 @@ def query_factory() -> Query: return [SourceQuery(query_factory=query_factory)] def _fetch_native_df( - self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False ) -> DF: """Fetches a DataFrame that can be either Pandas or PySpark from the cursor""" if self.is_spark_session_connection: @@ -223,7 +223,7 @@ def _fetch_native_df( return self.cursor.fetchall_arrow().to_pandas() def fetchdf( - self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False ) -> pd.DataFrame: """ Returns a Pandas DataFrame from a query or expression. 
@@ -364,10 +364,10 @@ def _build_table_properties_exp( catalog_name: t.Optional[str] = None, table_format: t.Optional[str] = None, storage_format: t.Optional[str] = None, - partitioned_by: t.Optional[t.List[exp.Expression]] = None, + partitioned_by: t.Optional[t.List[exp.Expr]] = None, partition_interval_unit: t.Optional[IntervalUnit] = None, - clustered_by: t.Optional[t.List[exp.Expression]] = None, - table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + clustered_by: t.Optional[t.List[exp.Expr]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expr]] = None, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, diff --git a/sqlmesh/core/engine_adapter/duckdb.py b/sqlmesh/core/engine_adapter/duckdb.py index 3b057219e0..ebfcaa7901 100644 --- a/sqlmesh/core/engine_adapter/duckdb.py +++ b/sqlmesh/core/engine_adapter/duckdb.py @@ -145,7 +145,7 @@ def _get_data_objects( for row in df.itertuples() ] - def _normalize_decimal_value(self, col: exp.Expression, precision: int) -> exp.Expression: + def _normalize_decimal_value(self, col: exp.Expr, precision: int) -> exp.Expr: """ duckdb truncates instead of rounding when casting to decimal. 
@@ -163,7 +163,7 @@ def _normalize_decimal_value(self, col: exp.Expression, precision: int) -> exp.E def _create_table( self, table_name_or_schema: t.Union[exp.Schema, TableName], - expression: t.Optional[exp.Expression], + expression: t.Optional[exp.Expr], exists: bool = True, replace: bool = False, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, diff --git a/sqlmesh/core/engine_adapter/mixins.py b/sqlmesh/core/engine_adapter/mixins.py index c8ef32b9da..bf4bb970a2 100644 --- a/sqlmesh/core/engine_adapter/mixins.py +++ b/sqlmesh/core/engine_adapter/mixins.py @@ -38,9 +38,9 @@ def merge( target_table: TableName, source_table: QueryOrDF, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], - unique_key: t.Sequence[exp.Expression], + unique_key: t.Sequence[exp.Expr], when_matched: t.Optional[exp.Whens] = None, - merge_filter: t.Optional[exp.Expression] = None, + merge_filter: t.Optional[exp.Expr] = None, source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: @@ -58,18 +58,14 @@ def merge( class PandasNativeFetchDFSupportMixin(EngineAdapter): def _fetch_native_df( - self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False ) -> DF: """Fetches a Pandas DataFrame from a SQL query.""" from warnings import catch_warnings, filterwarnings from pandas.io.sql import read_sql_query - sql = ( - self._to_sql(query, quote=quote_identifiers) - if isinstance(query, exp.Expression) - else query - ) + sql = self._to_sql(query, quote=quote_identifiers) if isinstance(query, exp.Expr) else query logger.debug(f"Executing SQL:\n{sql}") with catch_warnings(), self.transaction(): filterwarnings( @@ -87,7 +83,7 @@ class HiveMetastoreTablePropertiesMixin(EngineAdapter): def _build_partitioned_by_exp( self, - partitioned_by: t.List[exp.Expression], + partitioned_by: t.List[exp.Expr], *, catalog_name: t.Optional[str] = None, **kwargs: t.Any, @@ -120,16 
+116,16 @@ def _build_table_properties_exp( catalog_name: t.Optional[str] = None, table_format: t.Optional[str] = None, storage_format: t.Optional[str] = None, - partitioned_by: t.Optional[t.List[exp.Expression]] = None, + partitioned_by: t.Optional[t.List[exp.Expr]] = None, partition_interval_unit: t.Optional[IntervalUnit] = None, - clustered_by: t.Optional[t.List[exp.Expression]] = None, - table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + clustered_by: t.Optional[t.List[exp.Expr]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expr]] = None, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, **kwargs: t.Any, ) -> t.Optional[exp.Properties]: - properties: t.List[exp.Expression] = [] + properties: t.List[exp.Expr] = [] if table_format and self.dialect == "spark": properties.append(exp.FileFormatProperty(this=exp.Var(this=table_format))) @@ -166,12 +162,12 @@ def _build_table_properties_exp( def _build_view_properties_exp( self, - view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + view_properties: t.Optional[t.Dict[str, exp.Expr]] = None, table_description: t.Optional[str] = None, **kwargs: t.Any, ) -> t.Optional[exp.Properties]: """Creates a SQLGlot table properties expression for view""" - properties: t.List[exp.Expression] = [] + properties: t.List[exp.Expr] = [] if table_description: properties.append( @@ -194,7 +190,7 @@ def _truncate_comment(self, comment: str, length: t.Optional[int]) -> str: class GetCurrentCatalogFromFunctionMixin(EngineAdapter): - CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.func("current_catalog") + CURRENT_CATALOG_EXPRESSION: exp.Expr = exp.func("current_catalog") def get_current_catalog(self) -> t.Optional[str]: """Returns the catalog name of the current connection.""" @@ -240,7 +236,7 @@ def _default_precision_to_max( def _build_create_table_exp( self, table_name_or_schema: t.Union[exp.Schema, 
TableName], - expression: t.Optional[exp.Expression], + expression: t.Optional[exp.Expr], exists: bool = True, replace: bool = False, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, @@ -322,11 +318,11 @@ def is_destructive(self) -> bool: return False @property - def _alter_actions(self) -> t.List[exp.Expression]: + def _alter_actions(self) -> t.List[exp.Expr]: return [exp.Cluster(expressions=self.cluster_key_expressions)] @property - def cluster_key_expressions(self) -> t.List[exp.Expression]: + def cluster_key_expressions(self) -> t.List[exp.Expr]: # Note: Assumes `clustering_key` as a string like: # - "(col_a)" # - "(col_a, col_b)" @@ -346,14 +342,14 @@ def is_destructive(self) -> bool: return False @property - def _alter_actions(self) -> t.List[exp.Expression]: + def _alter_actions(self) -> t.List[exp.Expr]: return [exp.Command(this="DROP", expression="CLUSTERING KEY")] class ClusteredByMixin(EngineAdapter): def _build_clustered_by_exp( self, - clustered_by: t.List[exp.Expression], + clustered_by: t.List[exp.Expr], **kwargs: t.Any, ) -> t.Optional[exp.Cluster]: return exp.Cluster(expressions=[c.copy() for c in clustered_by]) @@ -410,9 +406,9 @@ def logical_merge( target_table: TableName, source_table: QueryOrDF, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], - unique_key: t.Sequence[exp.Expression], + unique_key: t.Sequence[exp.Expr], when_matched: t.Optional[exp.Whens] = None, - merge_filter: t.Optional[exp.Expression] = None, + merge_filter: t.Optional[exp.Expr] = None, source_columns: t.Optional[t.List[str]] = None, ) -> None: """ @@ -452,12 +448,12 @@ def concat_columns( decimal_precision: int = 3, timestamp_precision: int = MAX_TIMESTAMP_PRECISION, delimiter: str = ",", - ) -> exp.Expression: + ) -> exp.Expr: """ Produce an expression that generates a string version of a record, that is: - Every column converted to a string representation, joined together into a single string using the specified :delimiter """ - 
expressions_to_concat: t.List[exp.Expression] = [] + expressions_to_concat: t.List[exp.Expr] = [] for idx, (column, type) in enumerate(columns_to_types.items()): expressions_to_concat.append( exp.func( @@ -475,11 +471,11 @@ def concat_columns( def normalize_value( self, - expr: exp.Expression, + expr: exp.Expr, type: exp.DataType, decimal_precision: int = 3, timestamp_precision: int = MAX_TIMESTAMP_PRECISION, - ) -> exp.Expression: + ) -> exp.Expr: """ Return an expression that converts the values inside the column `col` to a normalized string @@ -490,6 +486,7 @@ def normalize_value( - `boolean` columns -> '1' or '0' - NULLS -> "" (empty string) """ + value: exp.Expr if type.is_type(exp.DataType.Type.BOOLEAN): value = self._normalize_boolean_value(expr) elif type.is_type(*exp.DataType.INTEGER_TYPES): @@ -512,12 +509,12 @@ def normalize_value( return exp.cast(value, to=exp.DataType.build("VARCHAR")) - def _normalize_nested_value(self, expr: exp.Expression) -> exp.Expression: + def _normalize_nested_value(self, expr: exp.Expr) -> exp.Expr: return expr def _normalize_timestamp_value( - self, expr: exp.Expression, type: exp.DataType, precision: int - ) -> exp.Expression: + self, expr: exp.Expr, type: exp.DataType, precision: int + ) -> exp.Expr: if precision > self.MAX_TIMESTAMP_PRECISION: raise ValueError( f"Requested timestamp precision '{precision}' exceeds maximum supported precision: {self.MAX_TIMESTAMP_PRECISION}" @@ -547,18 +544,18 @@ def _normalize_timestamp_value( return expr - def _normalize_integer_value(self, expr: exp.Expression) -> exp.Expression: + def _normalize_integer_value(self, expr: exp.Expr) -> exp.Expr: return exp.cast(expr, "BIGINT") - def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp.Expression: + def _normalize_decimal_value(self, expr: exp.Expr, precision: int) -> exp.Expr: return exp.cast(expr, f"DECIMAL(38,{precision})") - def _normalize_boolean_value(self, expr: exp.Expression) -> exp.Expression: + def 
_normalize_boolean_value(self, expr: exp.Expr) -> exp.Expr: return exp.cast(expr, "INT") class GrantsFromInfoSchemaMixin(EngineAdapter): - CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("current_user") + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expr = exp.func("current_user") SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = False USE_CATALOG_IN_GRANTS = False GRANT_INFORMATION_SCHEMA_TABLE_NAME = "table_privileges" @@ -578,8 +575,8 @@ def _dcl_grants_config_expr( table: exp.Table, grants_config: GrantsConfig, table_type: DataObjectType = DataObjectType.TABLE, - ) -> t.List[exp.Expression]: - expressions: t.List[exp.Expression] = [] + ) -> t.List[exp.Expr]: + expressions: t.List[exp.Expr] = [] if not grants_config: return expressions @@ -617,7 +614,7 @@ def _apply_grants_config_expr( table: exp.Table, grants_config: GrantsConfig, table_type: DataObjectType = DataObjectType.TABLE, - ) -> t.List[exp.Expression]: + ) -> t.List[exp.Expr]: return self._dcl_grants_config_expr(exp.Grant, table, grants_config, table_type) def _revoke_grants_config_expr( @@ -625,10 +622,10 @@ def _revoke_grants_config_expr( table: exp.Table, grants_config: GrantsConfig, table_type: DataObjectType = DataObjectType.TABLE, - ) -> t.List[exp.Expression]: + ) -> t.List[exp.Expr]: return self._dcl_grants_config_expr(exp.Revoke, table, grants_config, table_type) - def _get_grant_expression(self, table: exp.Table) -> exp.Expression: + def _get_grant_expression(self, table: exp.Table) -> exp.Expr: schema_identifier = table.args.get("db") or normalize_identifiers( exp.to_identifier(self._get_current_schema(), quoted=True), dialect=self.dialect ) diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py index 359d1f0818..e381c0a198 100644 --- a/sqlmesh/core/engine_adapter/mssql.py +++ b/sqlmesh/core/engine_adapter/mssql.py @@ -176,7 +176,7 @@ def drop_schema( schema_name: SchemaName, ignore_if_not_exists: bool = True, cascade: bool = False, - **drop_args: t.Dict[str, 
exp.Expression], + **drop_args: t.Dict[str, exp.Expr], ) -> None: """ MsSql doesn't support CASCADE clause and drops schemas unconditionally. @@ -205,9 +205,9 @@ def merge( target_table: TableName, source_table: QueryOrDF, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], - unique_key: t.Sequence[exp.Expression], + unique_key: t.Sequence[exp.Expr], when_matched: t.Optional[exp.Whens] = None, - merge_filter: t.Optional[exp.Expression] = None, + merge_filter: t.Optional[exp.Expr] = None, source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: @@ -401,7 +401,7 @@ def _get_data_objects( for row in dataframe.itertuples() ] - def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str: + def _to_sql(self, expression: exp.Expr, quote: bool = True, **kwargs: t.Any) -> str: sql = super()._to_sql(expression, quote=quote, **kwargs) return f"{sql};" @@ -448,7 +448,7 @@ def _insert_overwrite_by_condition( **kwargs, ) - def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None: + def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expr]) -> None: if where == exp.true(): # "A TRUNCATE TABLE operation can be rolled back within a transaction." # ref: https://learn.microsoft.com/en-us/sql/t-sql/statements/truncate-table-transact-sql?view=sql-server-ver15#remarks diff --git a/sqlmesh/core/engine_adapter/mysql.py b/sqlmesh/core/engine_adapter/mysql.py index 31773d6c63..66759dc440 100644 --- a/sqlmesh/core/engine_adapter/mysql.py +++ b/sqlmesh/core/engine_adapter/mysql.py @@ -73,7 +73,7 @@ def drop_schema( schema_name: SchemaName, ignore_if_not_exists: bool = True, cascade: bool = False, - **drop_args: t.Dict[str, exp.Expression], + **drop_args: t.Dict[str, exp.Expr], ) -> None: # MySQL doesn't support CASCADE clause and drops schemas unconditionally. 
super().drop_schema(schema_name, ignore_if_not_exists=ignore_if_not_exists, cascade=False) diff --git a/sqlmesh/core/engine_adapter/postgres.py b/sqlmesh/core/engine_adapter/postgres.py index 3dd108cf91..6794169322 100644 --- a/sqlmesh/core/engine_adapter/postgres.py +++ b/sqlmesh/core/engine_adapter/postgres.py @@ -40,7 +40,7 @@ class PostgresEngineAdapter( MAX_IDENTIFIER_LENGTH: t.Optional[int] = 63 SUPPORTS_QUERY_EXECUTION_TRACKING = True GRANT_INFORMATION_SCHEMA_TABLE_NAME = "role_table_grants" - CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.column("current_role") + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expr = exp.column("current_role") SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { @@ -73,7 +73,7 @@ class PostgresEngineAdapter( } def _fetch_native_df( - self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False ) -> DF: """ `read_sql_query` when using psycopg will result on a hanging transaction that must be committed @@ -113,9 +113,9 @@ def merge( target_table: TableName, source_table: QueryOrDF, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], - unique_key: t.Sequence[exp.Expression], + unique_key: t.Sequence[exp.Expr], when_matched: t.Optional[exp.Whens] = None, - merge_filter: t.Optional[exp.Expression] = None, + merge_filter: t.Optional[exp.Expr] = None, source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: diff --git a/sqlmesh/core/engine_adapter/redshift.py b/sqlmesh/core/engine_adapter/redshift.py index 03dc89053e..c2a27954cd 100644 --- a/sqlmesh/core/engine_adapter/redshift.py +++ b/sqlmesh/core/engine_adapter/redshift.py @@ -143,7 +143,7 @@ def cursor(self) -> t.Any: return cursor def _fetch_native_df( - self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False ) -> pd.DataFrame: 
"""Fetches a Pandas DataFrame from the cursor""" import pandas as pd @@ -217,7 +217,7 @@ def create_view( materialized_properties: t.Optional[t.Dict[str, t.Any]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, - view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + view_properties: t.Optional[t.Dict[str, exp.Expr]] = None, source_columns: t.Optional[t.List[str]] = None, **create_kwargs: t.Any, ) -> None: @@ -227,7 +227,7 @@ def create_view( swap tables out from under views. Therefore, we create the view as non-binding. """ no_schema_binding = True - if isinstance(query_or_df, exp.Expression): + if isinstance(query_or_df, exp.Expr): # We can't include NO SCHEMA BINDING if the query has a recursive CTE has_recursive_cte = any( w.args.get("recursive", False) for w in query_or_df.find_all(exp.With) @@ -367,9 +367,9 @@ def merge( target_table: TableName, source_table: QueryOrDF, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], - unique_key: t.Sequence[exp.Expression], + unique_key: t.Sequence[exp.Expr], when_matched: t.Optional[exp.Whens] = None, - merge_filter: t.Optional[exp.Expression] = None, + merge_filter: t.Optional[exp.Expr] = None, source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: @@ -400,12 +400,12 @@ def _merge( self, target_table: TableName, query: Query, - on: exp.Expression, + on: exp.Expr, whens: exp.Whens, ) -> None: # Redshift does not support table aliases in the target table of a MERGE statement. # So we must use the actual table name instead of an alias, as we do with the source table. 
- def resolve_target_table(expression: exp.Expression) -> exp.Expression: + def resolve_target_table(expression: exp.Expr) -> exp.Expr: if ( isinstance(expression, exp.Column) and expression.table.upper() == MERGE_TARGET_ALIAS @@ -436,7 +436,7 @@ def resolve_target_table(expression: exp.Expression) -> exp.Expression: track_rows_processed=True, ) - def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp.Expression: + def _normalize_decimal_value(self, expr: exp.Expr, precision: int) -> exp.Expr: # Redshift is finicky. It truncates when the data is already in a table, but rounds when the data is generated as part of a SELECT. # # The following works: diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index a8eabe070d..09c530b8f3 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -83,7 +83,7 @@ class SnowflakeEngineAdapter( SNOWPARK = "snowpark" SUPPORTS_QUERY_EXECUTION_TRACKING = True SUPPORTS_GRANTS = True - CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("CURRENT_ROLE") + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expr = exp.func("CURRENT_ROLE") USE_CATALOG_IN_GRANTS = True @contextlib.contextmanager @@ -95,7 +95,7 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]: if isinstance(warehouse, str): warehouse = exp.to_identifier(warehouse) - if not isinstance(warehouse, exp.Expression): + if not isinstance(warehouse, exp.Expr): raise SQLMeshError(f"Invalid warehouse: '{warehouse}'") warehouse_exp = quote_identifiers( @@ -189,7 +189,7 @@ def _drop_catalog(self, catalog_name: exp.Identifier) -> None: def _create_table( self, table_name_or_schema: t.Union[exp.Schema, TableName], - expression: t.Optional[exp.Expression], + expression: t.Optional[exp.Expr], exists: bool = True, replace: bool = False, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, @@ -225,9 +225,9 @@ def create_managed_table( table_name: 
TableName, query: Query, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - partitioned_by: t.Optional[t.List[exp.Expression]] = None, - clustered_by: t.Optional[t.List[exp.Expression]] = None, - table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + partitioned_by: t.Optional[t.List[exp.Expr]] = None, + clustered_by: t.Optional[t.List[exp.Expr]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expr]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, source_columns: t.Optional[t.List[str]] = None, @@ -278,7 +278,7 @@ def create_view( materialized_properties: t.Optional[t.Dict[str, t.Any]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, - view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + view_properties: t.Optional[t.Dict[str, exp.Expr]] = None, source_columns: t.Optional[t.List[str]] = None, **create_kwargs: t.Any, ) -> None: @@ -311,16 +311,16 @@ def _build_table_properties_exp( catalog_name: t.Optional[str] = None, table_format: t.Optional[str] = None, storage_format: t.Optional[str] = None, - partitioned_by: t.Optional[t.List[exp.Expression]] = None, + partitioned_by: t.Optional[t.List[exp.Expr]] = None, partition_interval_unit: t.Optional[IntervalUnit] = None, - clustered_by: t.Optional[t.List[exp.Expression]] = None, - table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + clustered_by: t.Optional[t.List[exp.Expr]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expr]] = None, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, **kwargs: t.Any, ) -> t.Optional[exp.Properties]: - properties: t.List[exp.Expression] = [] + properties: t.List[exp.Expr] = [] # TODO: there is some overlap with the base class and other engine adapters # we need a better way of filtering table 
properties relevent to the current engine @@ -471,7 +471,7 @@ def cleanup() -> None: return [SourceQuery(query_factory=query_factory, cleanup_func=cleanup)] def _fetch_native_df( - self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False ) -> DF: import pandas as pd from snowflake.connector.errors import NotSupportedError @@ -561,7 +561,7 @@ def _get_data_objects( for row in df.rename(columns={col: col.lower() for col in df.columns}).itertuples() ] - def _get_grant_expression(self, table: exp.Table) -> exp.Expression: + def _get_grant_expression(self, table: exp.Table) -> exp.Expr: # Upon execute the catalog in table expressions are properly normalized to handle the case where a user provides # the default catalog in their connection config. This doesn't though update catalogs in strings like when querying # the information schema. So we need to manually replace those here. @@ -586,7 +586,7 @@ def set_current_catalog(self, catalog: str) -> None: def set_current_schema(self, schema: str) -> None: self.execute(exp.Use(kind="SCHEMA", this=to_schema(schema))) - def _normalize_catalog(self, expression: exp.Expression) -> exp.Expression: + def _normalize_catalog(self, expression: exp.Expr) -> exp.Expr: # note: important to use self._default_catalog instead of the self.default_catalog property # otherwise we get RecursionError: maximum recursion depth exceeded # because it calls get_current_catalog(), which executes a query, which needs the default catalog, which calls get_current_catalog()... 
etc @@ -604,7 +604,7 @@ def unquote_and_lower(identifier: str) -> str: self._default_catalog, dialect=self.dialect ) - def catalog_rewriter(node: exp.Expression) -> exp.Expression: + def catalog_rewriter(node: exp.Expr) -> exp.Expr: if isinstance(node, exp.Table): if node.catalog: # only replace the catalog on the model with the target catalog if the two are functionally equivalent @@ -621,7 +621,7 @@ def catalog_rewriter(node: exp.Expression) -> exp.Expression: expression = expression.transform(catalog_rewriter) return expression - def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str: + def _to_sql(self, expression: exp.Expr, quote: bool = True, **kwargs: t.Any) -> str: return super()._to_sql( expression=self._normalize_catalog(expression), quote=quote, **kwargs ) diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py index 5216b0a329..9199aa3bcd 100644 --- a/sqlmesh/core/engine_adapter/spark.py +++ b/sqlmesh/core/engine_adapter/spark.py @@ -340,12 +340,12 @@ def _get_temp_table( return table def fetchdf( - self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False ) -> pd.DataFrame: return self.fetch_pyspark_df(query, quote_identifiers=quote_identifiers).toPandas() def fetch_pyspark_df( - self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + self, query: t.Union[exp.Expr, str], quote_identifiers: bool = False ) -> PySparkDataFrame: return self._ensure_pyspark_df( self._fetch_native_df(query, quote_identifiers=quote_identifiers) @@ -437,7 +437,7 @@ def _native_df_to_pandas_df( def _create_table( self, table_name_or_schema: t.Union[exp.Schema, TableName], - expression: t.Optional[exp.Expression], + expression: t.Optional[exp.Expr], exists: bool = True, replace: bool = False, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, diff --git 
a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index 89470728f2..00acddb26c 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -129,7 +129,7 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]: yield return - if not isinstance(authorization, exp.Expression): + if not isinstance(authorization, exp.Expr): authorization = exp.Literal.string(authorization) if not authorization.is_string: @@ -326,13 +326,13 @@ def _scd_type_2( self, target_table: TableName, source_table: QueryOrDF, - unique_key: t.Sequence[exp.Expression], + unique_key: t.Sequence[exp.Expr], valid_from_col: exp.Column, valid_to_col: exp.Column, execution_time: t.Union[TimeLike, exp.Column], invalidate_hard_deletes: bool = True, updated_at_col: t.Optional[exp.Column] = None, - check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None, + check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expr]]] = None, updated_at_as_valid_from: bool = False, execution_time_as_valid_from: bool = False, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, @@ -409,7 +409,7 @@ def _create_schema( schema_name: SchemaName, ignore_if_exists: bool, warn_on_error: bool, - properties: t.List[exp.Expression], + properties: t.List[exp.Expr], kind: str, ) -> None: if mapped_location := self._schema_location(schema_name): @@ -426,7 +426,7 @@ def _create_schema( def _create_table( self, table_name_or_schema: t.Union[exp.Schema, TableName], - expression: t.Optional[exp.Expression], + expression: t.Optional[exp.Expr], exists: bool = True, replace: bool = False, target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, diff --git a/sqlmesh/core/lineage.py b/sqlmesh/core/lineage.py index 777a2a7d9a..8363979034 100644 --- a/sqlmesh/core/lineage.py +++ b/sqlmesh/core/lineage.py @@ -16,7 +16,7 @@ from sqlmesh.core.model import Model -CACHE: t.Dict[str, t.Tuple[int, exp.Expression, Scope]] 
= {} +CACHE: t.Dict[str, t.Tuple[int, exp.Expr, Scope]] = {} def lineage( @@ -25,8 +25,8 @@ def lineage( trim_selects: bool = True, **kwargs: t.Any, ) -> Node: - query = None - scope = None + query: t.Optional[exp.Expr] = None + scope: t.Optional[Scope] = None if model.name in CACHE: obj_id, query, scope = CACHE[model.name] diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index af7c344081..888acbb8eb 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -110,7 +110,7 @@ def _macro_sql(sql: str, into: t.Optional[str] = None) -> str: return f"self.parse_one({', '.join(args)})" -def _macro_func_sql(self: Generator, e: exp.Expression) -> str: +def _macro_func_sql(self: Generator, e: exp.Expr) -> str: func = e.this if isinstance(func, exp.Anonymous): @@ -178,7 +178,7 @@ def __init__( schema: t.Optional[MappingSchema] = None, runtime_stage: RuntimeStage = RuntimeStage.LOADING, resolve_table: t.Optional[t.Callable[[str | exp.Table], str]] = None, - resolve_tables: t.Optional[t.Callable[[exp.Expression], exp.Expression]] = None, + resolve_tables: t.Optional[t.Callable[[exp.Expr], exp.Expr]] = None, snapshots: t.Optional[t.Dict[str, Snapshot]] = None, default_catalog: t.Optional[str] = None, path: t.Optional[Path] = None, @@ -237,7 +237,7 @@ def __init__( def send( self, name: str, *args: t.Any, **kwargs: t.Any - ) -> t.Union[None, exp.Expression, t.List[exp.Expression]]: + ) -> t.Union[None, exp.Expr, t.List[exp.Expr]]: func = self.macros.get(normalize_macro_name(name)) if not callable(func): @@ -253,14 +253,12 @@ def send( + format_evaluated_code_exception(e, self.python_env) ) - def transform( - self, expression: exp.Expression - ) -> exp.Expression | t.List[exp.Expression] | None: + def transform(self, expression: exp.Expr) -> exp.Expr | t.List[exp.Expr] | None: changed = False def evaluate_macros( - node: exp.Expression, - ) -> exp.Expression | t.List[exp.Expression] | None: + node: exp.Expr, + ) -> exp.Expr | t.List[exp.Expr] | None: nonlocal 
changed if isinstance(node, MacroVar): @@ -281,14 +279,10 @@ def evaluate_macros( value = self.locals.get(var_name, variables.get(var_name)) if isinstance(value, list): return exp.convert( - tuple( - self.transform(v) if isinstance(v, exp.Expression) else v for v in value - ) + tuple(self.transform(v) if isinstance(v, exp.Expr) else v for v in value) ) - return exp.convert( - self.transform(value) if isinstance(value, exp.Expression) else value - ) + return exp.convert(self.transform(value) if isinstance(value, exp.Expr) else value) if isinstance(node, exp.Identifier) and "@" in node.this: text = self.template(node.this, {}) if node.this != text: @@ -311,7 +305,7 @@ def evaluate_macros( self.parse_one(node.sql(dialect=self.dialect, copy=False)) for node in transformed ] - if isinstance(transformed, exp.Expression): + if isinstance(transformed, exp.Expr): return self.parse_one(transformed.sql(dialect=self.dialect, copy=False)) return transformed @@ -339,7 +333,7 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str: } return MacroStrTemplate(str(text)).safe_substitute(CaseInsensitiveMapping(base_mapping)) - def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | None: + def evaluate(self, node: MacroFunc) -> exp.Expr | t.List[exp.Expr] | None: if isinstance(node, MacroDef): if isinstance(node.expression, exp.Lambda): _, fn = _norm_var_arg_lambda(self, node.expression) @@ -353,7 +347,7 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | return node if isinstance(node, (MacroSQL, MacroStrReplace)): - result: t.Optional[exp.Expression | t.List[exp.Expression]] = exp.convert( + result: t.Optional[exp.Expr | t.List[exp.Expr]] = exp.convert( self.eval_expression(node) ) else: @@ -421,7 +415,7 @@ def eval_expression(self, node: t.Any) -> t.Any: Returns: The return value of the evaled Python Code. 
""" - if not isinstance(node, exp.Expression): + if not isinstance(node, exp.Expr): return node code = node.sql() try: @@ -434,8 +428,8 @@ def eval_expression(self, node: t.Any) -> t.Any: ) def parse_one( - self, sql: str | exp.Expression, into: t.Optional[exp.IntoType] = None, **opts: t.Any - ) -> exp.Expression: + self, sql: str | exp.Expr, into: t.Optional[exp.IntoType] = None, **opts: t.Any + ) -> exp.Expr: """Parses the given SQL string and returns a syntax tree for the first parsed SQL statement. @@ -497,7 +491,7 @@ def resolve_table(self, table: str | exp.Table) -> str: ) return self._resolve_table(table) - def resolve_tables(self, query: exp.Expression) -> exp.Expression: + def resolve_tables(self, query: exp.Expr) -> exp.Expr: """Resolves queries with references to SQLMesh model names to their physical tables.""" if not self._resolve_tables: raise SQLMeshError( @@ -588,7 +582,7 @@ def variables(self) -> t.Dict[str, t.Any]: **self.locals.get(c.SQLMESH_BLUEPRINT_VARS_METADATA, {}), } - def _coerce(self, expr: exp.Expression, typ: t.Any, strict: bool = False) -> t.Any: + def _coerce(self, expr: exp.Expr, typ: t.Any, strict: bool = False) -> t.Any: """Coerces the given expression to the specified type on a best-effort basis.""" return _coerce(expr, typ, self.dialect, self._path, strict) @@ -648,8 +642,8 @@ def _norm_var_arg_lambda( """ def substitute( - node: exp.Expression, args: t.Dict[str, exp.Expression] - ) -> exp.Expression | t.List[exp.Expression] | None: + node: exp.Expr, args: t.Dict[str, exp.Expr] + ) -> exp.Expr | t.List[exp.Expr] | None: if isinstance(node, (exp.Identifier, exp.Var)): if not isinstance(node.parent, exp.Column): name = node.name.lower() @@ -798,8 +792,8 @@ def filter_(evaluator: MacroEvaluator, *args: t.Any) -> t.List[t.Any]: def _optional_expression( evaluator: MacroEvaluator, condition: exp.Condition, - expression: exp.Expression, -) -> t.Optional[exp.Expression]: + expression: exp.Expr, +) -> t.Optional[exp.Expr]: """Inserts 
expression when the condition is True The following examples express the usage of this function in the context of the macros which wrap it. @@ -864,7 +858,7 @@ def star( suffix: exp.Literal = exp.Literal.string(""), quote_identifiers: exp.Boolean = exp.true(), except_: t.Union[exp.Array, exp.Tuple] = exp.Tuple(expressions=[]), -) -> t.List[exp.Alias]: +) -> t.List[exp.Expr]: """Returns a list of projections for the given relation. Args: @@ -939,7 +933,7 @@ def star( @macro() def generate_surrogate_key( evaluator: MacroEvaluator, - *fields: exp.Expression, + *fields: exp.Expr, hash_function: exp.Literal = exp.Literal.string("MD5"), ) -> exp.Func: """Generates a surrogate key (string) for the given fields. @@ -956,7 +950,7 @@ def generate_surrogate_key( >>> MacroEvaluator(dialect="bigquery").transform(parse_one(sql, dialect="bigquery")).sql("bigquery") "SELECT SHA256(CONCAT(COALESCE(CAST(a AS STRING), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(b AS STRING), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(c AS STRING), '_sqlmesh_surrogate_key_null_'))) FROM foo" """ - string_fields: t.List[exp.Expression] = [] + string_fields: t.List[exp.Expr] = [] for i, field in enumerate(fields): if i > 0: string_fields.append(exp.Literal.string("|")) @@ -980,7 +974,7 @@ def generate_surrogate_key( @macro() -def safe_add(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case: +def safe_add(_: MacroEvaluator, *fields: exp.Expr) -> exp.Case: """Adds numbers together, substitutes nulls for 0s and only returns null if all fields are null. Example: @@ -998,7 +992,7 @@ def safe_add(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case: @macro() -def safe_sub(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case: +def safe_sub(_: MacroEvaluator, *fields: exp.Expr) -> exp.Case: """Subtract numbers, substitutes nulls for 0s and only returns null if all fields are null. 
Example: @@ -1016,7 +1010,7 @@ def safe_sub(_: MacroEvaluator, *fields: exp.Expression) -> exp.Case: @macro() -def safe_div(_: MacroEvaluator, numerator: exp.Expression, denominator: exp.Expression) -> exp.Div: +def safe_div(_: MacroEvaluator, numerator: exp.Expr, denominator: exp.Expr) -> exp.Div: """Divides numbers, returns null if the denominator is 0. Example: @@ -1032,7 +1026,7 @@ def safe_div(_: MacroEvaluator, numerator: exp.Expression, denominator: exp.Expr @macro() def union( evaluator: MacroEvaluator, - *args: exp.Expression, + *args: exp.Expr, ) -> exp.Query: """Returns a UNION of the given tables. Only choosing columns that have the same name and type. @@ -1107,10 +1101,10 @@ def union( @macro() def haversine_distance( _: MacroEvaluator, - lat1: exp.Expression, - lon1: exp.Expression, - lat2: exp.Expression, - lon2: exp.Expression, + lat1: exp.Expr, + lon1: exp.Expr, + lat2: exp.Expr, + lon2: exp.Expr, unit: exp.Literal = exp.Literal.string("mi"), ) -> exp.Mul: """Returns the haversine distance between two points. @@ -1150,17 +1144,17 @@ def haversine_distance( def pivot( evaluator: MacroEvaluator, column: SQL, - values: t.List[exp.Expression], + values: t.List[exp.Expr], alias: bool = True, - agg: exp.Expression = exp.Literal.string("SUM"), - cmp: exp.Expression = exp.Literal.string("="), - prefix: exp.Expression = exp.Literal.string(""), - suffix: exp.Expression = exp.Literal.string(""), + agg: exp.Expr = exp.Literal.string("SUM"), + cmp: exp.Expr = exp.Literal.string("="), + prefix: exp.Expr = exp.Literal.string(""), + suffix: exp.Expr = exp.Literal.string(""), then_value: SQL = SQL("1"), else_value: SQL = SQL("0"), quote: bool = True, distinct: bool = False, -) -> t.List[exp.Expression]: +) -> t.List[exp.Expr]: """Returns a list of projections as a result of pivoting the given column on the given values. 
Example: @@ -1173,14 +1167,14 @@ def pivot( >>> MacroEvaluator(dialect="bigquery").transform(parse_one(sql)).sql("bigquery") "SELECT SUM(CASE WHEN a = 'v' THEN tv ELSE 0 END) AS v_sfx" """ - aggregates: t.List[exp.Expression] = [] + aggregates: t.List[exp.Expr] = [] for value in values: proj = f"{agg.name}(" if distinct: proj += "DISTINCT " proj += f"CASE WHEN {column} {cmp.name} {value.sql(evaluator.dialect)} THEN {then_value} ELSE {else_value} END) " - node = evaluator.parse_one(proj) + node: exp.Expr = evaluator.parse_one(proj) if alias: node = node.as_( @@ -1196,7 +1190,7 @@ def pivot( @macro("AND") -def and_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expression]) -> exp.Condition: +def and_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expr]) -> exp.Condition: """Returns an AND statement filtering out any NULL expressions.""" conditions = [e for e in expressions if not isinstance(e, exp.Null)] @@ -1207,7 +1201,7 @@ def and_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expression]) -> @macro("OR") -def or_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expression]) -> exp.Condition: +def or_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expr]) -> exp.Condition: """Returns an OR statement filtering out any NULL expressions.""" conditions = [e for e in expressions if not isinstance(e, exp.Null)] @@ -1219,8 +1213,8 @@ def or_(evaluator: MacroEvaluator, *expressions: t.Optional[exp.Expression]) -> @macro("VAR") def var( - evaluator: MacroEvaluator, var_name: exp.Expression, default: t.Optional[exp.Expression] = None -) -> exp.Expression: + evaluator: MacroEvaluator, var_name: exp.Expr, default: t.Optional[exp.Expr] = None +) -> exp.Expr: """Returns the value of a variable or the default value if the variable is not set.""" if not var_name.is_string: raise SQLMeshError(f"Invalid variable name '{var_name.sql()}'. 
Expected a string literal.") @@ -1230,8 +1224,8 @@ def var( @macro("BLUEPRINT_VAR") def blueprint_var( - evaluator: MacroEvaluator, var_name: exp.Expression, default: t.Optional[exp.Expression] = None -) -> exp.Expression: + evaluator: MacroEvaluator, var_name: exp.Expr, default: t.Optional[exp.Expr] = None +) -> exp.Expr: """Returns the value of a blueprint variable or the default value if the variable is not set.""" if not var_name.is_string: raise SQLMeshError( @@ -1244,8 +1238,8 @@ def blueprint_var( @macro() def deduplicate( evaluator: MacroEvaluator, - relation: exp.Expression, - partition_by: t.List[exp.Expression], + relation: exp.Expr, + partition_by: t.List[exp.Expr], order_by: t.List[str], ) -> exp.Query: """Returns a QUERY to deduplicate rows within a table @@ -1301,9 +1295,9 @@ def deduplicate( @macro() def date_spine( evaluator: MacroEvaluator, - datepart: exp.Expression, - start_date: exp.Expression, - end_date: exp.Expression, + datepart: exp.Expr, + start_date: exp.Expr, + end_date: exp.Expr, ) -> exp.Select: """Returns a query that produces a date spine with the given datepart, and range of start_date and end_date. Useful for joining as a date lookup table. @@ -1491,7 +1485,7 @@ def _coerce( """Coerces the given expression to the specified type on a best-effort basis.""" base_err_msg = f"Failed to coerce expression '{expr}' to type '{typ}'." 
try: - if typ is None or typ is t.Any or not isinstance(expr, exp.Expression): + if typ is None or typ is t.Any or not isinstance(expr, exp.Expr): return expr base = t.get_origin(typ) or typ @@ -1503,7 +1497,7 @@ def _coerce( except Exception: pass raise SQLMeshError(base_err_msg) - if base is SQL and isinstance(expr, exp.Expression): + if base is SQL and isinstance(expr, exp.Expr): return expr.sql(dialect) if base is t.Literal: @@ -1528,7 +1522,7 @@ def _coerce( if isinstance(expr, base): return expr - if issubclass(base, exp.Expression): + if issubclass(base, exp.Expr): d = Dialect.get_or_raise(dialect) into = base if base in d.parser_class.EXPRESSION_PARSERS else None if into is None: @@ -1603,7 +1597,7 @@ def _convert_sql(v: t.Any, dialect: DialectType) -> t.Any: except Exception: pass - if isinstance(v, exp.Expression): + if isinstance(v, exp.Expr): if (isinstance(v, exp.Column) and not v.table) or ( isinstance(v, exp.Identifier) or v.is_string ): diff --git a/sqlmesh/core/metric/definition.py b/sqlmesh/core/metric/definition.py index dd11cfd38d..70f10b2347 100644 --- a/sqlmesh/core/metric/definition.py +++ b/sqlmesh/core/metric/definition.py @@ -16,7 +16,7 @@ def load_metric_ddl( - expression: exp.Expression, dialect: t.Optional[str], path: Path = Path(), **kwargs: t.Any + expression: exp.Expr, dialect: t.Optional[str], path: Path = Path(), **kwargs: t.Any ) -> MetricMeta: """Returns a MetricMeta from raw Metric DDL.""" if not isinstance(expression, d.Metric): @@ -70,7 +70,7 @@ class MetricMeta(PydanticModel, frozen=True): name: str dialect: str - expression: exp.Expression + expression: exp.Expr description: t.Optional[str] = None owner: t.Optional[str] = None @@ -87,11 +87,11 @@ def _string_validator(cls, v: t.Any) -> t.Optional[str]: return str_or_exp_to_str(v) @field_validator("expression", mode="before") - def _validate_expression(cls, v: t.Any, info: ValidationInfo) -> exp.Expression: + def _validate_expression(cls, v: t.Any, info: ValidationInfo) -> 
exp.Expr: if isinstance(v, str): dialect = info.data.get("dialect") return d.parse_one(v, dialect=dialect) - if isinstance(v, exp.Expression): + if isinstance(v, exp.Expr): return v return v @@ -139,7 +139,7 @@ def to_metric( class Metric(MetricMeta, frozen=True): - expanded: exp.Expression + expanded: exp.Expr @property def aggs(self) -> t.Dict[exp.AggFunc, MeasureAndDimTables]: @@ -150,7 +150,7 @@ def aggs(self) -> t.Dict[exp.AggFunc, MeasureAndDimTables]: return { t.cast( exp.AggFunc, - t.cast(exp.Expression, agg.parent).transform( + t.cast(exp.Expr, agg.parent).transform( lambda node: ( exp.column(node.this, table=remove_namespace(node)) if isinstance(node, exp.Column) and node.table @@ -162,7 +162,7 @@ def aggs(self) -> t.Dict[exp.AggFunc, MeasureAndDimTables]: } @property - def formula(self) -> exp.Expression: + def formula(self) -> exp.Expr: """Returns the post aggregation formula of a metric. For simple metrics it is just the metric name. For derived metrics, @@ -181,7 +181,7 @@ def _raise_metric_config_error(msg: str, path: Path) -> None: raise ConfigError(f"{msg}. '{path}'") -def _get_measure_and_dim_tables(expression: exp.Expression) -> MeasureAndDimTables: +def _get_measure_and_dim_tables(expression: exp.Expr) -> MeasureAndDimTables: """Finds all the table references in a metric definition. Additionally ensure that the first table returned is the 'measure' or numeric value being aggregated. 
@@ -190,7 +190,7 @@ def _get_measure_and_dim_tables(expression: exp.Expression) -> MeasureAndDimTabl tables = {} measure_table = None - def is_measure(node: exp.Expression) -> bool: + def is_measure(node: exp.Expr) -> bool: parent = node.parent if isinstance(parent, exp.AggFunc) and node.arg_key == "this": diff --git a/sqlmesh/core/metric/rewriter.py b/sqlmesh/core/metric/rewriter.py index bbdc6c6135..6c9ec429a8 100644 --- a/sqlmesh/core/metric/rewriter.py +++ b/sqlmesh/core/metric/rewriter.py @@ -34,13 +34,13 @@ def __init__( self.join_type = join_type self.semantic_name = f"{semantic_schema}.{semantic_table}" - def rewrite(self, expression: exp.Expression) -> exp.Expression: + def rewrite(self, expression: exp.Expr) -> exp.Expr: for select in list(expression.find_all(exp.Select)): self._expand(select) return expression - def _build_sources(self, projections: t.List[exp.Expression]) -> SourceAggsAndJoins: + def _build_sources(self, projections: t.List[exp.Expr]) -> SourceAggsAndJoins: sources: SourceAggsAndJoins = {} for projection in projections: @@ -78,7 +78,7 @@ def _expand(self, select: exp.Select) -> None: explicit_joins = {exp.table_name(join.this): join for join in select.args.pop("joins", [])} for i, (name, (aggs, joins)) in enumerate(sources.items()): - source: exp.Expression = exp.to_table(name) + source: exp.Expr = exp.to_table(name) table_name = remove_namespace(name) if not isinstance(source, exp.Select): @@ -110,7 +110,7 @@ def _expand(self, select: exp.Select) -> None: copy=False, ) - for node in find_all_in_scope(query, (exp.Column, exp.TableAlias)): + for node in find_all_in_scope(query, exp.Column, exp.TableAlias): # type: ignore[arg-type,var-annotated] if isinstance(node, exp.Column): if node.table in mapping: node.set("table", exp.to_identifier(mapping[node.table])) @@ -123,7 +123,7 @@ def _add_joins( source: exp.Select, name: str, joins: t.Dict[str, t.Optional[exp.Join]], - group_by: t.List[exp.Expression], + group_by: t.List[exp.Expr], 
mapping: t.Dict[str, str], ) -> exp.Select: grain = [e.copy() for e in group_by] @@ -177,7 +177,7 @@ def _add_joins( return source.select(*grain, copy=False).group_by(*grain, copy=False) -def _replace_table(node: exp.Expression, table: str, base_alias: str) -> exp.Expression: +def _replace_table(node: exp.Expr, table: str, base_alias: str) -> exp.Expr: for column in find_all_in_scope(node, exp.Column): if column.table == base_alias: column.args["table"] = exp.to_identifier(table) @@ -185,11 +185,11 @@ def _replace_table(node: exp.Expression, table: str, base_alias: str) -> exp.Exp def rewrite( - sql: str | exp.Expression, + sql: str | exp.Expr, graph: ReferenceGraph, metrics: t.Dict[str, Metric], dialect: t.Optional[str] = "", -) -> exp.Expression: +) -> exp.Expr: rewriter = Rewriter(graph=graph, metrics=metrics, dialect=dialect) return optimize( diff --git a/sqlmesh/core/model/cache.py b/sqlmesh/core/model/cache.py index 774bfa402b..1f038c5d79 100644 --- a/sqlmesh/core/model/cache.py +++ b/sqlmesh/core/model/cache.py @@ -81,7 +81,7 @@ def get(self, name: str, entry_id: str = "") -> t.List[Model]: @dataclass class OptimizedQueryCacheEntry: - optimized_rendered_query: t.Optional[exp.Expression] + optimized_rendered_query: t.Optional[exp.Query] renderer_violations: t.Optional[t.Dict[type[Rule], t.Any]] diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index dc51b3379c..ccde7624bd 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -33,8 +33,8 @@ def make_python_env( expressions: t.Union[ - exp.Expression, - t.List[t.Union[exp.Expression, t.Tuple[exp.Expression, bool]]], + exp.Expr, + t.List[t.Union[exp.Expr, t.Tuple[exp.Expr, bool]]], ], jinja_macro_references: t.Optional[t.Set[MacroReference]], module_path: Path, @@ -71,7 +71,7 @@ def make_python_env( visited_macro_funcs: t.Set[int] = set() def _is_metadata_var( - name: str, expression: exp.Expression, appears_in_metadata_expression: bool + name: str, expression: 
exp.Expr, appears_in_metadata_expression: bool ) -> t.Optional[bool]: is_metadata_so_far = used_variables.get(name, True) if is_metadata_so_far is False: @@ -202,7 +202,7 @@ def _is_metadata_macro(name: str, appears_in_metadata_expression: bool) -> bool: def _extract_macro_func_variable_references( - macro_func: exp.Expression, + macro_func: exp.Expr, is_metadata: bool, ) -> t.Tuple[t.Set[str], t.Dict[int, bool], t.Set[int]]: var_references = set() @@ -292,12 +292,12 @@ def _add_variables_to_python_env( if blueprint_variables: metadata_blueprint_variables = { - k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v + k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expr) else v for k, v in blueprint_variables.items() if k in metadata_used_variables } blueprint_variables = { - k.lower(): SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v + k.lower(): SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expr) else v for k, v in blueprint_variables.items() if k in non_metadata_used_variables } @@ -469,9 +469,9 @@ def single_value_or_tuple(values: t.Sequence) -> exp.Identifier | exp.Tuple: def parse_expression( cls: t.Type, - v: t.Union[t.List[str], t.List[exp.Expression], str, exp.Expression, t.Callable, None], + v: t.Union[t.List[str], t.List[exp.Expr], str, exp.Expr, t.Callable, None], info: t.Optional[ValidationInfo], -) -> t.List[exp.Expression] | exp.Expression | t.Callable | None: +) -> t.List[exp.Expr] | exp.Expr | t.Callable | None: """Helper method to deserialize SQLGlot expressions in Pydantic Models.""" if v is None: return None @@ -483,7 +483,7 @@ def parse_expression( if isinstance(v, list): return [ - e if isinstance(e, exp.Expression) else d.parse_one(e, dialect=dialect) + e if isinstance(e, exp.Expr) else d.parse_one(e, dialect=dialect) # type: ignore[misc] for e in v if not isinstance(e, exp.Semicolon) ] @@ -498,7 +498,7 @@ def parse_expression( def parse_bool(v: t.Any) -> bool: - if 
isinstance(v, exp.Expression): + if isinstance(v, exp.Expr): if not isinstance(v, exp.Boolean): from sqlglot.optimizer.simplify import simplify @@ -524,7 +524,7 @@ def parse_properties( if isinstance(v, str): v = d.parse_one(v, dialect=dialect) if isinstance(v, (exp.Array, exp.Paren, exp.Tuple)): - eq_expressions: t.List[exp.Expression] = ( + eq_expressions: t.List[exp.Expr] = ( [v.unnest()] if isinstance(v, exp.Paren) else v.expressions ) @@ -665,18 +665,18 @@ class ParsableSql(PydanticModel): sql: str transaction: t.Optional[bool] = None - _parsed: t.Optional[exp.Expression] = None + _parsed: t.Optional[exp.Expr] = None _parsed_dialect: t.Optional[str] = None - def parse(self, dialect: str) -> exp.Expression: + def parse(self, dialect: str) -> exp.Expr: if self._parsed is None or self._parsed_dialect != dialect: self._parsed = d.parse_one(self.sql, dialect=dialect) self._parsed_dialect = dialect - return self._parsed + return self._parsed # type: ignore[return-value] @classmethod def from_parsed_expression( - cls, parsed_expression: exp.Expression, dialect: str, use_meta_sql: bool = False + cls, parsed_expression: exp.Expr, dialect: str, use_meta_sql: bool = False ) -> ParsableSql: sql = ( parsed_expression.meta.get("sql") or parsed_expression.sql(dialect=dialect) @@ -697,7 +697,7 @@ def _validate_parsable_sql( return v if isinstance(v, str): return ParsableSql(sql=v) - if isinstance(v, exp.Expression): + if isinstance(v, exp.Expr): return ParsableSql.from_parsed_expression( v, get_dialect(info.data), use_meta_sql=False ) @@ -707,7 +707,7 @@ def _validate_parsable_sql( ParsableSql(sql=s) if isinstance(s, str) else ParsableSql.from_parsed_expression(s, dialect, use_meta_sql=False) - if isinstance(s, exp.Expression) + if isinstance(s, exp.Expr) else ParsableSql.parse_obj(s) for s in v ] diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py index 73452cc165..328b763f9f 100644 --- a/sqlmesh/core/model/decorator.py +++ 
b/sqlmesh/core/model/decorator.py @@ -193,7 +193,7 @@ def model( ) rendered_name = rendered_fields["name"] - if isinstance(rendered_name, exp.Expression): + if isinstance(rendered_name, exp.Expr): rendered_fields["name"] = rendered_name.sql(dialect=dialect) rendered_defaults = ( diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 831b52a44e..8d4f72e918 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -215,7 +215,7 @@ def render_definition( include_python: bool = True, include_defaults: bool = False, render_query: bool = False, - ) -> t.List[exp.Expression]: + ) -> t.List[exp.Expr]: """Returns the original list of sql expressions comprising the model definition. Args: @@ -366,7 +366,7 @@ def render_pre_statements( engine_adapter: t.Optional[EngineAdapter] = None, inside_transaction: t.Optional[bool] = True, **kwargs: t.Any, - ) -> t.List[exp.Expression]: + ) -> t.List[exp.Expr]: """Renders pre-statements for a model. Pre-statements are statements that precede the model's SELECT query. @@ -413,7 +413,7 @@ def render_post_statements( engine_adapter: t.Optional[EngineAdapter] = None, inside_transaction: t.Optional[bool] = True, **kwargs: t.Any, - ) -> t.List[exp.Expression]: + ) -> t.List[exp.Expr]: """Renders post-statements for a model. Post-statements are statements that follow after the model's SELECT query. 
@@ -460,7 +460,7 @@ def render_on_virtual_update( deployability_index: t.Optional[DeployabilityIndex] = None, engine_adapter: t.Optional[EngineAdapter] = None, **kwargs: t.Any, - ) -> t.List[exp.Expression]: + ) -> t.List[exp.Expr]: return self._render_statements( self.on_virtual_update, start=start, @@ -552,15 +552,15 @@ def render_audit_query( return rendered_query @property - def pre_statements(self) -> t.List[exp.Expression]: + def pre_statements(self) -> t.List[exp.Expr]: return self._get_parsed_statements("pre_statements_") @property - def post_statements(self) -> t.List[exp.Expression]: + def post_statements(self) -> t.List[exp.Expr]: return self._get_parsed_statements("post_statements_") @property - def on_virtual_update(self) -> t.List[exp.Expression]: + def on_virtual_update(self) -> t.List[exp.Expr]: return self._get_parsed_statements("on_virtual_update_") @property @@ -572,7 +572,7 @@ def macro_definitions(self) -> t.List[d.MacroDef]: if isinstance(s, d.MacroDef) ] - def _get_parsed_statements(self, attr_name: str) -> t.List[exp.Expression]: + def _get_parsed_statements(self, attr_name: str) -> t.List[exp.Expr]: value = getattr(self, attr_name) if not value: return [] @@ -587,9 +587,9 @@ def _get_parsed_statements(self, attr_name: str) -> t.List[exp.Expression]: def _render_statements( self, - statements: t.Iterable[exp.Expression], + statements: t.Iterable[exp.Expr], **kwargs: t.Any, - ) -> t.List[exp.Expression]: + ) -> t.List[exp.Expr]: rendered = ( self._statement_renderer(statement).render(**kwargs) for statement in statements @@ -597,7 +597,7 @@ def _render_statements( ) return [r for expressions in rendered if expressions for r in expressions] - def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer: + def _statement_renderer(self, expression: exp.Expr) -> ExpressionRenderer: expression_key = id(expression) if expression_key not in self._statement_renderer_cache: self._statement_renderer_cache[expression_key] = 
ExpressionRenderer( @@ -631,7 +631,7 @@ def render_signals( The list of rendered expressions. """ - def _render(e: exp.Expression) -> str | int | float | bool: + def _render(e: exp.Expr) -> str | int | float | bool: rendered_exprs = ( self._create_renderer(e).render(start=start, end=end, execution_time=execution_time) or [] @@ -676,7 +676,7 @@ def render_merge_filter( start: t.Optional[TimeLike] = None, end: t.Optional[TimeLike] = None, execution_time: t.Optional[TimeLike] = None, - ) -> t.Optional[exp.Expression]: + ) -> t.Optional[exp.Expr]: if self.merge_filter is None: return None rendered_exprs = ( @@ -690,9 +690,9 @@ def render_merge_filter( return rendered_exprs[0].transform(d.replace_merge_table_aliases, dialect=self.dialect) def _render_properties( - self, properties: t.Dict[str, exp.Expression] | SessionProperties, **render_kwargs: t.Any + self, properties: t.Dict[str, exp.Expr] | SessionProperties, **render_kwargs: t.Any ) -> t.Dict[str, t.Any]: - def _render(expression: exp.Expression) -> exp.Expression | None: + def _render(expression: exp.Expr) -> exp.Expr | None: # note: we use the _statement_renderer instead of _create_renderer because it sets model_fqn which # in turn makes @this_model available in the evaluation context rendered_exprs = self._statement_renderer(expression).render(**render_kwargs) @@ -714,7 +714,7 @@ def _render(expression: exp.Expression) -> exp.Expression | None: return { k: rendered for k, v in properties.items() - if (rendered := (_render(v) if isinstance(v, exp.Expression) else v)) + if (rendered := (_render(v) if isinstance(v, exp.Expr) else v)) } def render_physical_properties(self, **render_kwargs: t.Any) -> t.Dict[str, t.Any]: @@ -726,7 +726,7 @@ def render_virtual_properties(self, **render_kwargs: t.Any) -> t.Dict[str, t.Any def render_session_properties(self, **render_kwargs: t.Any) -> t.Dict[str, t.Any]: return self._render_properties(properties=self.session_properties, **render_kwargs) - def _create_renderer(self, 
expression: exp.Expression) -> ExpressionRenderer: + def _create_renderer(self, expression: exp.Expr) -> ExpressionRenderer: return ExpressionRenderer( expression, self.dialect, @@ -822,7 +822,7 @@ def set_time_format(self, default_time_format: str = c.DEFAULT_TIME_COLUMN_FORMA def convert_to_time_column( self, time: TimeLike, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> exp.Expression: + ) -> exp.Expr: """Convert a TimeLike object to the same time format and type as the model's time column.""" if self.time_column: if columns_to_types is None: @@ -970,7 +970,7 @@ def validate_definition(self) -> None: col.name for expr in values for col in t.cast( - exp.Expression, exp.maybe_parse(expr, dialect=self.dialect) + exp.Expr, exp.maybe_parse(expr, dialect=self.dialect) ).find_all(exp.Column) ] @@ -1266,7 +1266,7 @@ def _additional_metadata(self) -> t.List[str]: return additional_metadata - def _is_metadata_statement(self, statement: exp.Expression) -> bool: + def _is_metadata_statement(self, statement: exp.Expr) -> bool: if isinstance(statement, d.MacroDef): return True if isinstance(statement, d.MacroFunc): @@ -1295,7 +1295,7 @@ def full_depends_on(self) -> t.Set[str]: return self._full_depends_on @property - def partitioned_by(self) -> t.List[exp.Expression]: + def partitioned_by(self) -> t.List[exp.Expr]: """Columns to partition the model by, including the time column if it is not already included.""" if self.time_column and not self._is_time_column_in_partitioned_by: # This allows the user to opt out of automatic time_column injection @@ -1323,7 +1323,7 @@ def partition_interval_unit(self) -> t.Optional[IntervalUnit]: return None @property - def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expression]]]: + def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expr]]]: from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS audits_by_name = {**BUILT_IN_AUDITS, **self.audit_definitions} @@ -1422,8 +1422,8 @@ def 
render_definition( include_python: bool = True, include_defaults: bool = False, render_query: bool = False, - ) -> t.List[exp.Expression]: - result = super().render_definition( + ) -> t.List[exp.Expr]: + result: t.List[exp.Expr] = super().render_definition( include_python=include_python, include_defaults=include_defaults ) @@ -1946,7 +1946,7 @@ def render_definition( include_python: bool = True, include_defaults: bool = False, render_query: bool = False, - ) -> t.List[exp.Expression]: + ) -> t.List[exp.Expr]: # Ignore the provided value for the include_python flag, since the Python model's # definition without Python code is meaningless. return super().render_definition( @@ -2001,7 +2001,7 @@ class AuditResult(PydanticModel): """The model this audit is for.""" count: t.Optional[int] = None """The number of records returned by the audit query. This could be None if the audit was skipped.""" - query: t.Optional[exp.Expression] = None + query: t.Optional[exp.Expr] = None """The rendered query used by the audit. This could be None if the audit was skipped.""" skipped: bool = False """Whether or not the audit was blocking. This can be overridden by the user.""" @@ -2009,7 +2009,7 @@ class AuditResult(PydanticModel): class EvaluatableSignals(PydanticModel): - signals_to_kwargs: t.Dict[str, t.Dict[str, t.Optional[exp.Expression]]] + signals_to_kwargs: t.Dict[str, t.Dict[str, t.Optional[exp.Expr]]] """A mapping of signal names to the kwargs passed to the signal.""" python_env: t.Dict[str, Executable] """The Python environment that should be used to evaluate the rendered signal calls.""" @@ -2054,7 +2054,7 @@ def _extract_blueprint_variables(blueprint: t.Any, path: Path) -> t.Dict[str, t.
def create_models_from_blueprints( - gateway: t.Optional[str | exp.Expression], + gateway: t.Optional[str | exp.Expr], blueprints: t.Any, get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]], loader: t.Callable[..., Model], @@ -2105,7 +2105,7 @@ def create_models_from_blueprints( def load_sql_based_models( - expressions: t.List[exp.Expression], + expressions: t.List[exp.Expr], get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]], path: Path = Path(), module_path: Path = Path(), @@ -2113,8 +2113,8 @@ def load_sql_based_models( default_catalog_per_gateway: t.Optional[t.Dict[str, str]] = None, **loader_kwargs: t.Any, ) -> t.List[Model]: - gateway: t.Optional[exp.Expression] = None - blueprints: t.Optional[exp.Expression] = None + gateway: t.Optional[exp.Expr] = None + blueprints: t.Optional[exp.Expr] = None model_meta = seq_get(expressions, 0) for prop in (isinstance(model_meta, d.Model) and model_meta.expressions) or []: @@ -2160,7 +2160,7 @@ def load_sql_based_models( def load_sql_based_model( - expressions: t.List[exp.Expression], + expressions: t.List[exp.Expr], *, defaults: t.Optional[t.Dict[str, t.Any]] = None, path: t.Optional[Path] = None, @@ -2306,7 +2306,7 @@ def load_sql_based_model( if kind_prop.name.lower() == "merge_filter": meta_fields["kind"].expressions[idx] = unrendered_merge_filter - if isinstance(meta_fields.get("dialect"), exp.Expression): + if isinstance(meta_fields.get("dialect"), exp.Expr): meta_fields["dialect"] = meta_fields["dialect"].name # The name of the model will be inferred from its path relative to `models/`, if it's not explicitly specified @@ -2367,7 +2367,7 @@ def load_sql_based_model( def create_sql_model( name: TableName, - query: t.Optional[exp.Expression], + query: t.Optional[exp.Expr], **kwargs: t.Any, ) -> Model: """Creates a SQL model. 
@@ -2492,7 +2492,7 @@ def create_python_model( ) depends_on = { dep.sql(dialect=dialect) - for dep in t.cast(t.List[exp.Expression], depends_on_rendered)[0].expressions + for dep in t.cast(t.List[exp.Expr], depends_on_rendered)[0].expressions } used_variables = {k: v for k, v in (variables or {}).items() if k in referenced_variables} @@ -2597,7 +2597,7 @@ def _create_model( if not issubclass(klass, SqlModel): defaults.pop("optimize_query", None) - statements: t.List[t.Union[exp.Expression, t.Tuple[exp.Expression, bool]]] = [] + statements: t.List[t.Union[exp.Expr, t.Tuple[exp.Expr, bool]]] = [] if "query" in kwargs: statements.append(kwargs["query"]) @@ -2636,11 +2636,11 @@ def _create_model( if isinstance(property_values, exp.Tuple): statements.extend(property_values.expressions) - if isinstance(getattr(kwargs.get("kind"), "merge_filter", None), exp.Expression): + if isinstance(getattr(kwargs.get("kind"), "merge_filter", None), exp.Expr): statements.append(kwargs["kind"].merge_filter) jinja_macro_references, referenced_variables = extract_macro_references_and_variables( - *(gen(e if isinstance(e, exp.Expression) else e[0]) for e in statements) + *(gen(e if isinstance(e, exp.Expr) else e[0]) for e in statements) ) if jinja_macros: @@ -2687,7 +2687,7 @@ def _create_model( model.audit_definitions.update(audit_definitions) # Any macro referenced in audits or signals needs to be treated as metadata-only - statements.extend((audit.query, True) for audit in audit_definitions.values()) + statements.extend((audit.query, True) for audit in audit_definitions.values()) # type: ignore[misc] # Ensure that all audits referenced in the model are defined from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS @@ -2743,14 +2743,14 @@ def _create_model( def _split_sql_model_statements( - expressions: t.List[exp.Expression], + expressions: t.List[exp.Expr], path: t.Optional[Path], dialect: t.Optional[str] = None, ) -> t.Tuple[ - t.Optional[exp.Expression], - t.List[exp.Expression], - 
t.List[exp.Expression], - t.List[exp.Expression], + t.Optional[exp.Expr], + t.List[exp.Expr], + t.List[exp.Expr], + t.List[exp.Expr], UniqueKeyDict[str, ModelAudit], ]: """Extracts the SELECT query from a sequence of expressions. @@ -2811,8 +2811,8 @@ def _split_sql_model_statements( def _resolve_properties( default: t.Optional[t.Dict[str, t.Any]], - provided: t.Optional[exp.Expression | t.Dict[str, t.Any]], -) -> t.Optional[exp.Expression]: + provided: t.Optional[exp.Expr | t.Dict[str, t.Any]], +) -> t.Optional[exp.Expr]: if isinstance(provided, dict): properties = {k: exp.Literal.string(k).eq(v) for k, v in provided.items()} elif provided: @@ -2834,7 +2834,7 @@ def _resolve_properties( return None -def _list_of_calls_to_exp(value: t.List[t.Tuple[str, t.Dict[str, t.Any]]]) -> exp.Expression: +def _list_of_calls_to_exp(value: t.List[t.Tuple[str, t.Dict[str, t.Any]]]) -> exp.Expr: return exp.Tuple( expressions=[ exp.Anonymous( @@ -2849,16 +2849,16 @@ def _list_of_calls_to_exp(value: t.List[t.Tuple[str, t.Dict[str, t.Any]]]) -> ex ) -def _is_projection(expr: exp.Expression) -> bool: +def _is_projection(expr: exp.Expr) -> bool: parent = expr.parent return isinstance(parent, exp.Select) and expr.arg_key == "expressions" -def _single_expr_or_tuple(values: t.Sequence[exp.Expression]) -> exp.Expression | exp.Tuple: +def _single_expr_or_tuple(values: t.Sequence[exp.Expr]) -> exp.Expr | exp.Tuple: return values[0] if len(values) == 1 else exp.Tuple(expressions=values) -def _refs_to_sql(values: t.Any) -> exp.Expression: +def _refs_to_sql(values: t.Any) -> exp.Expr: return exp.Tuple(expressions=values) @@ -2874,7 +2874,7 @@ def render_meta_fields( blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, ) -> t.Dict[str, t.Any]: def render_field_value(value: t.Any) -> t.Any: - if isinstance(value, exp.Expression) or (isinstance(value, str) and "@" in value): + if isinstance(value, exp.Expr) or (isinstance(value, str) and "@" in value): expression = exp.maybe_parse(value, 
dialect=dialect) rendered_expr = render_expression( expression=expression, @@ -3011,7 +3011,7 @@ def parse_defaults_properties( def render_expression( - expression: exp.Expression, + expression: exp.Expr, module_path: Path, path: t.Optional[Path], jinja_macros: t.Optional[JinjaMacroRegistry] = None, @@ -3020,7 +3020,7 @@ def render_expression( variables: t.Optional[t.Dict[str, t.Any]] = None, default_catalog: t.Optional[str] = None, blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, -) -> t.Optional[t.List[exp.Expression]]: +) -> t.Optional[t.List[exp.Expr]]: meta_python_env = make_python_env( expressions=expression, jinja_macro_references=None, @@ -3092,8 +3092,8 @@ def get_model_name(path: Path) -> str: # function applied to time column when automatically used for partitioning in INCREMENTAL_BY_TIME_RANGE models def clickhouse_partition_func( - column: exp.Expression, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] -) -> exp.Expression: + column: exp.Expr, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] +) -> exp.Expr: # `toMonday()` function accepts a Date or DateTime type column col_type = (columns_to_types and columns_to_types.get(column.name)) or exp.DataType.build( diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py index 9abaa9c650..d7a7bb9579 100644 --- a/sqlmesh/core/model/kind.py +++ b/sqlmesh/core/model/kind.py @@ -279,7 +279,7 @@ def model_kind_name(self) -> t.Optional[ModelKindName]: return self.name def to_expression( - self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any + self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any ) -> d.ModelKind: kwargs["expressions"] = expressions return d.ModelKind(this=self.name.value.upper(), **kwargs) @@ -294,7 +294,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]: class TimeColumn(PydanticModel): - column: exp.Expression + column: exp.Expr format: t.Optional[str] = None @classmethod @@ -306,7 +306,7 @@ def 
_time_column_validator(v: t.Any, info: ValidationInfo) -> TimeColumn: @field_validator("column", mode="before") @classmethod - def _column_validator(cls, v: t.Union[str, exp.Expression]) -> exp.Expression: + def _column_validator(cls, v: t.Union[str, exp.Expr]) -> exp.Expr: if not v: raise ConfigError("Time Column cannot be empty.") if isinstance(v, str): @@ -314,14 +314,14 @@ def _column_validator(cls, v: t.Union[str, exp.Expression]) -> exp.Expression: return v @property - def expression(self) -> exp.Expression: + def expression(self) -> exp.Expr: """Convert this pydantic model into a time_column SQLGlot expression.""" if not self.format: return self.column return exp.Tuple(expressions=[self.column, exp.Literal.string(self.format)]) - def to_expression(self, dialect: str) -> exp.Expression: + def to_expression(self, dialect: str) -> exp.Expr: """Convert this pydantic model into a time_column SQLGlot expression.""" if not self.format: return self.column @@ -346,7 +346,7 @@ def create(cls, v: t.Any, dialect: str) -> Self: exp.column(column_expr) if isinstance(column_expr, exp.Identifier) else column_expr ) format = v.expressions[1].name if len(v.expressions) > 1 else None - elif isinstance(v, exp.Expression): + elif isinstance(v, exp.Expr): column = exp.column(v) if isinstance(v, exp.Identifier) else v format = None elif isinstance(v, str): @@ -400,7 +400,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]: ] def to_expression( - self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any + self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any ) -> d.ModelKind: return super().to_expression( expressions=[ @@ -444,7 +444,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]: ] def to_expression( - self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any + self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any ) -> d.ModelKind: return super().to_expression( expressions=[ @@ 
-473,7 +473,7 @@ class IncrementalByTimeRangeKind(_IncrementalBy): _time_column_validator = TimeColumn.validator() def to_expression( - self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any + self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any ) -> d.ModelKind: return super().to_expression( expressions=[ @@ -513,7 +513,7 @@ class IncrementalByUniqueKeyKind(_IncrementalBy): ) unique_key: SQLGlotListOfFields when_matched: t.Optional[exp.Whens] = None - merge_filter: t.Optional[exp.Expression] = None + merge_filter: t.Optional[exp.Expr] = None batch_concurrency: t.Literal[1] = 1 @field_validator("when_matched", mode="before") @@ -543,9 +543,9 @@ def _when_matched_validator( @field_validator("merge_filter", mode="before") def _merge_filter_validator( cls, - v: t.Optional[exp.Expression], + v: t.Optional[exp.Expr], info: ValidationInfo, - ) -> t.Optional[exp.Expression]: + ) -> t.Optional[exp.Expr]: if v is None: return v @@ -568,7 +568,7 @@ def data_hash_values(self) -> t.List[t.Optional[str]]: ] def to_expression( - self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any + self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any ) -> d.ModelKind: return super().to_expression( expressions=[ @@ -590,7 +590,7 @@ class IncrementalByPartitionKind(_Incremental): disable_restatement: SQLGlotBool = False @field_validator("forward_only", mode="before") - def _forward_only_validator(cls, v: t.Union[bool, exp.Expression]) -> t.Literal[True]: + def _forward_only_validator(cls, v: t.Union[bool, exp.Expr]) -> t.Literal[True]: if v is not True: raise ConfigError( "Do not specify the `forward_only` configuration key - INCREMENTAL_BY_PARTITION models are always forward_only." 
@@ -606,7 +606,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]: ] def to_expression( - self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any + self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any ) -> d.ModelKind: return super().to_expression( expressions=[ @@ -640,7 +640,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]: ] def to_expression( - self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any + self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any ) -> d.ModelKind: return super().to_expression( expressions=[ @@ -669,7 +669,7 @@ def supports_python_models(self) -> bool: return False def to_expression( - self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any + self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any ) -> d.ModelKind: return super().to_expression( expressions=[ @@ -690,7 +690,7 @@ class SeedKind(_ModelKind): def _parse_csv_settings(cls, v: t.Any) -> t.Optional[CsvSettings]: if v is None or isinstance(v, CsvSettings): return v - if isinstance(v, exp.Expression): + if isinstance(v, exp.Expr): tuple_exp = parse_properties(cls, v, None) if not tuple_exp: return None @@ -700,7 +700,7 @@ def _parse_csv_settings(cls, v: t.Any) -> t.Optional[CsvSettings]: return v def to_expression( - self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any + self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any ) -> d.ModelKind: """Convert the seed kind into a SQLGlot expression.""" return super().to_expression( @@ -756,13 +756,16 @@ class _SCDType2Kind(_Incremental): @field_validator("time_data_type", mode="before") @classmethod - def _time_data_type_validator( - cls, v: t.Union[str, exp.Expression], values: t.Any - ) -> exp.Expression: - if isinstance(v, exp.Expression) and not isinstance(v, exp.DataType): + def _time_data_type_validator(cls, v: t.Union[str, exp.Expr], values: 
t.Any) -> exp.Expr: + if isinstance(v, exp.Expr) and not isinstance(v, exp.DataType): v = v.name dialect = get_dialect(values) data_type = exp.DataType.build(v, dialect=dialect) + # Clear meta["sql"] (set by our parser extension) so the pydantic encoder + # uses dialect-aware rendering: e.sql(dialect=meta["dialect"]). Without this, + # the raw SQL text takes priority, which can be wrong for dialect-normalized + # types (e.g., default "TIMESTAMP" should render as "DATETIME" in BigQuery). + data_type.meta.pop("sql", None) data_type.meta["dialect"] = dialect return data_type @@ -795,7 +798,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]: ] def to_expression( - self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any + self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any ) -> d.ModelKind: return super().to_expression( expressions=[ @@ -835,7 +838,7 @@ def data_hash_values(self) -> t.List[t.Optional[str]]: ] def to_expression( - self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any + self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any ) -> d.ModelKind: return super().to_expression( expressions=[ @@ -871,7 +874,7 @@ def data_hash_values(self) -> t.List[t.Optional[str]]: ] def to_expression( - self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any + self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any ) -> d.ModelKind: return super().to_expression( expressions=[ @@ -922,7 +925,7 @@ def data_hash_values(self) -> t.List[t.Optional[str]]: ] def to_expression( - self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any + self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any ) -> d.ModelKind: return super().to_expression( expressions=[ @@ -1005,7 +1008,7 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]: ] def to_expression( - self, expressions: t.Optional[t.List[exp.Expression]] = None, 
**kwargs: t.Any + self, expressions: t.Optional[t.List[exp.Expr]] = None, **kwargs: t.Any ) -> d.ModelKind: return super().to_expression( expressions=[ @@ -1142,7 +1145,7 @@ def create_model_kind(v: t.Any, dialect: str, defaults: t.Dict[str, t.Any]) -> M ) return kind_type(**props) - name = (v.name if isinstance(v, exp.Expression) else str(v)).upper() + name = (v.name if isinstance(v, exp.Expr) else str(v)).upper() return model_kind_type_from_name(name)(name=name) # type: ignore diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py index c48b7d1524..a73d6d871a 100644 --- a/sqlmesh/core/model/meta.py +++ b/sqlmesh/core/model/meta.py @@ -50,7 +50,7 @@ from sqlmesh.core._typing import CustomMaterializationProperties, SessionProperties from sqlmesh.core.engine_adapter._typing import GrantsConfig -FunctionCall = t.Tuple[str, t.Dict[str, exp.Expression]] +FunctionCall = t.Tuple[str, t.Dict[str, exp.Expr]] class GrantsTargetLayer(str, Enum): @@ -92,8 +92,8 @@ class ModelMeta(_Node): retention: t.Optional[int] = None # not implemented yet table_format: t.Optional[str] = None storage_format: t.Optional[str] = None - partitioned_by_: t.List[exp.Expression] = Field(default=[], alias="partitioned_by") - clustered_by: t.List[exp.Expression] = [] + partitioned_by_: t.List[exp.Expr] = Field(default=[], alias="partitioned_by") + clustered_by: t.List[exp.Expr] = [] default_catalog: t.Optional[str] = None depends_on_: t.Optional[t.Set[str]] = Field(default=None, alias="depends_on") columns_to_types_: t.Optional[t.Dict[str, exp.DataType]] = Field(default=None, alias="columns") @@ -101,8 +101,8 @@ class ModelMeta(_Node): default=None, alias="column_descriptions" ) audits: t.List[FunctionCall] = [] - grains: t.List[exp.Expression] = [] - references: t.List[exp.Expression] = [] + grains: t.List[exp.Expr] = [] + references: t.List[exp.Expr] = [] physical_schema_override: t.Optional[str] = None physical_properties_: t.Optional[exp.Tuple] = Field(default=None, 
alias="physical_properties") virtual_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="virtual_properties") @@ -151,11 +151,11 @@ def _normalize(value: t.Any) -> t.Any: if isinstance(v, (exp.Tuple, exp.Array)): return [_normalize(e).name for e in v.expressions] - if isinstance(v, exp.Expression): + if isinstance(v, exp.Expr): return _normalize(v).name if isinstance(v, str): value = _normalize(v) - return value.name if isinstance(value, exp.Expression) else value + return value.name if isinstance(value, exp.Expr) else value if isinstance(v, (list, tuple)): return [cls._validate_value_or_tuple(elm, data, normalize=normalize) for elm in v] @@ -163,7 +163,7 @@ def _normalize(value: t.Any) -> t.Any: @field_validator("table_format", "storage_format", mode="before") def _format_validator(cls, v: t.Any, info: ValidationInfo) -> t.Optional[str]: - if isinstance(v, exp.Expression) and not (isinstance(v, (exp.Literal, exp.Identifier))): + if isinstance(v, exp.Expr) and not (isinstance(v, (exp.Literal, exp.Identifier))): return v.sql(info.data.get("dialect")) return str_or_exp_to_str(v) @@ -188,9 +188,7 @@ def _gateway_validator(cls, v: t.Any) -> t.Optional[str]: return gateway and gateway.lower() @field_validator("partitioned_by_", "clustered_by", mode="before") - def _partition_and_cluster_validator( - cls, v: t.Any, info: ValidationInfo - ) -> t.List[exp.Expression]: + def _partition_and_cluster_validator(cls, v: t.Any, info: ValidationInfo) -> t.List[exp.Expr]: if ( isinstance(v, list) and all(isinstance(i, str) for i in v) @@ -244,9 +242,33 @@ def _columns_validator( return columns_to_types if isinstance(v, dict): - udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES + dialect_obj = Dialect.get_or_raise(dialect) + udt = dialect_obj.SUPPORTS_USER_DEFINED_TYPES for k, data_type in v.items(): + is_string_type = isinstance(data_type, str) expr = exp.DataType.build(data_type, dialect=dialect, udt=udt) + # When deserializing from a string (e.g. 
JSON roundtrip), normalize the type + # through the dialect's type system so that aliases (e.g. INT in BigQuery, + # which is an alias for INT64/BIGINT) are resolved to their canonical form. + # This ensures stable data hash computation across serialization/deserialization + # roundtrips. We skip this for DataType objects passed directly (Python API) + # since those should be used as-is. + if ( + is_string_type + and dialect + and expr.this + not in ( + exp.DataType.Type.USERDEFINED, + exp.DataType.Type.UNKNOWN, + ) + ): + sql_repr = expr.sql(dialect=dialect) + try: + normalized = parse_one(sql_repr, read=dialect, into=exp.DataType) + if normalized is not None: + expr = normalized + except Exception: + pass expr.meta["dialect"] = dialect columns_to_types[normalize_identifiers(k, dialect=dialect).name] = expr @@ -295,7 +317,7 @@ def _column_descriptions_validator( return col_descriptions @field_validator("grains", "references", mode="before") - def _refs_validator(cls, vs: t.Any, info: ValidationInfo) -> t.List[exp.Expression]: + def _refs_validator(cls, vs: t.Any, info: ValidationInfo) -> t.List[exp.Expr]: dialect = info.data.get("dialect") if isinstance(vs, exp.Paren): @@ -349,7 +371,7 @@ def session_properties_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any: "Invalid value for `session_properties.query_label`. Must be an array or tuple." 
) - label_tuples: t.List[exp.Expression] = ( + label_tuples: t.List[exp.Expr] = ( [query_label.unnest()] if isinstance(query_label, exp.Paren) else query_label.expressions @@ -449,7 +471,7 @@ def time_column(self) -> t.Optional[TimeColumn]: return getattr(self.kind, "time_column", None) @property - def unique_key(self) -> t.List[exp.Expression]: + def unique_key(self) -> t.List[exp.Expr]: if isinstance( self.kind, (SCDType2ByTimeKind, SCDType2ByColumnKind, IncrementalByUniqueKeyKind) ): @@ -485,14 +507,14 @@ def batch_concurrency(self) -> t.Optional[int]: return getattr(self.kind, "batch_concurrency", None) @cached_property - def physical_properties(self) -> t.Dict[str, exp.Expression]: + def physical_properties(self) -> t.Dict[str, exp.Expr]: """A dictionary of properties that will be applied to the physical layer. It replaces table_properties which is deprecated.""" if self.physical_properties_: return {e.this.name: e.expression for e in self.physical_properties_.expressions} return {} @cached_property - def virtual_properties(self) -> t.Dict[str, exp.Expression]: + def virtual_properties(self) -> t.Dict[str, exp.Expr]: """A dictionary of properties that will be applied to the virtual layer.""" if self.virtual_properties_: return {e.this.name: e.expression for e in self.virtual_properties_.expressions} @@ -568,7 +590,7 @@ def when_matched(self) -> t.Optional[exp.Whens]: return None @property - def merge_filter(self) -> t.Optional[exp.Expression]: + def merge_filter(self) -> t.Optional[exp.Expr]: if isinstance(self.kind, IncrementalByUniqueKeyKind): return self.kind.merge_filter return None @@ -601,7 +623,7 @@ def on_additive_change(self) -> OnAdditiveChange: def ignored_rules(self) -> t.Set[str]: return self.ignored_rules_ or set() - def _validate_config_expression(self, expr: exp.Expression) -> str: + def _validate_config_expression(self, expr: exp.Expr) -> str: if isinstance(expr, (d.MacroFunc, d.MacroVar)): raise ConfigError(f"Unresolved macro: 
{expr.sql(dialect=self.dialect)}") @@ -614,10 +636,10 @@ def _validate_config_expression(self, expr: exp.Expression) -> str: return expr.name return expr.sql(dialect=self.dialect).strip() - def _validate_nested_config_values(self, value_expr: exp.Expression) -> t.List[str]: + def _validate_nested_config_values(self, value_expr: exp.Expr) -> t.List[str]: result = [] - def flatten_expr(expr: exp.Expression) -> None: + def flatten_expr(expr: exp.Expr) -> None: if isinstance(expr, exp.Array): for elem in expr.expressions: flatten_expr(elem) diff --git a/sqlmesh/core/model/seed.py b/sqlmesh/core/model/seed.py index fe1aa85204..9fd57fe6d3 100644 --- a/sqlmesh/core/model/seed.py +++ b/sqlmesh/core/model/seed.py @@ -49,7 +49,7 @@ def _bool_validator(cls, v: t.Any) -> t.Optional[bool]: ) @classmethod def _str_validator(cls, v: t.Any) -> t.Optional[str]: - if v is None or not isinstance(v, exp.Expression): + if v is None or not isinstance(v, exp.Expr): return v # SQLGlot parses escape sequences like \t as \\t for dialects that don't treat \ as @@ -60,7 +60,7 @@ def _str_validator(cls, v: t.Any) -> t.Optional[str]: @field_validator("na_values", mode="before") @classmethod def _na_values_validator(cls, v: t.Any) -> t.Optional[NaValues]: - if v is None or not isinstance(v, exp.Expression): + if v is None or not isinstance(v, exp.Expr): return v try: diff --git a/sqlmesh/core/node.py b/sqlmesh/core/node.py index 4a3bf2564b..d3b63312f1 100644 --- a/sqlmesh/core/node.py +++ b/sqlmesh/core/node.py @@ -215,7 +215,7 @@ def post_init(self) -> Self: self.alias = None return self - def to_expression(self) -> exp.Expression: + def to_expression(self) -> exp.Expr: """Produce a SQLGlot expression representing this object, for use in things like the model/audit definition renderers""" return exp.tuple_( *( @@ -324,7 +324,7 @@ def copy(self, **kwargs: t.Any) -> Self: def _name_validator(cls, v: t.Any) -> t.Optional[str]: if v is None: return None - if isinstance(v, exp.Expression): + if 
isinstance(v, exp.Expr): return v.meta["sql"] return str(v) @@ -352,7 +352,7 @@ def _cron_tz_validator(cls, v: t.Any) -> t.Optional[zoneinfo.ZoneInfo]: @field_validator("start", "end", mode="before") @classmethod def _date_validator(cls, v: t.Any) -> t.Optional[TimeLike]: - if isinstance(v, exp.Expression): + if isinstance(v, exp.Expr): v = v.name if v and not to_datetime(v): raise ConfigError(f"'{v}' needs to be time-like: https://pypi.org/project/dateparser") @@ -555,6 +555,6 @@ def __str__(self) -> str: def str_or_exp_to_str(v: t.Any) -> t.Optional[str]: - if isinstance(v, exp.Expression): + if isinstance(v, exp.Expr): return v.name return str(v) if v is not None else None diff --git a/sqlmesh/core/reference.py b/sqlmesh/core/reference.py index 2bf2c04e98..9e93ce7b38 100644 --- a/sqlmesh/core/reference.py +++ b/sqlmesh/core/reference.py @@ -14,7 +14,7 @@ class Reference(PydanticModel, frozen=True): model_name: str - expression: exp.Expression + expression: exp.Expr unique: bool = False _name: str = "" diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index 50c1faeb63..7683956064 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -48,7 +48,7 @@ class BaseExpressionRenderer: def __init__( self, - expression: exp.Expression, + expression: exp.Expr, dialect: DialectType, macro_definitions: t.List[d.MacroDef], path: t.Optional[Path] = None, @@ -73,7 +73,7 @@ def __init__( self._normalize_identifiers = normalize_identifiers self._quote_identifiers = quote_identifiers self.update_schema({} if schema is None else schema) - self._cache: t.List[t.Optional[exp.Expression]] = [] + self._cache: t.List[t.Optional[exp.Expr]] = [] self._model_fqn = model.fqn if model else None self._optimize_query_flag = optimize_query is not False self._model = model @@ -91,7 +91,7 @@ def _render( deployability_index: t.Optional[DeployabilityIndex] = None, runtime_stage: RuntimeStage = RuntimeStage.LOADING, **kwargs: t.Any, - ) -> 
t.List[t.Optional[exp.Expression]]: + ) -> t.List[t.Optional[exp.Expr]]: """Renders a expression, expanding macros with provided kwargs Args: @@ -205,7 +205,7 @@ def _resolve_table(table: str | exp.Table) -> str: if variables: macro_evaluator.locals.setdefault(c.SQLMESH_VARS, {}).update(variables) - expressions = [self._expression] + expressions: t.List[exp.Expr] = [self._expression] if isinstance(self._expression, d.Jinja): try: jinja_env_kwargs = { @@ -283,7 +283,7 @@ def _resolve_table(table: str | exp.Table) -> str: f"Failed to evaluate macro '{definition}'.\n\n{ex}\n", self._path ) - resolved_expressions: t.List[t.Optional[exp.Expression]] = [] + resolved_expressions: t.List[t.Optional[exp.Expr]] = [] for expression in expressions: try: @@ -294,7 +294,7 @@ def _resolve_table(table: str | exp.Table) -> str: self._path, ) - for expression in t.cast(t.List[exp.Expression], transformed_expressions): + for expression in t.cast(t.List[exp.Expr], transformed_expressions): with self._normalize_and_quote(expression) as expression: if hasattr(expression, "selects"): for select in expression.selects: @@ -320,12 +320,12 @@ def _resolve_table(table: str | exp.Table) -> str: self._cache = resolved_expressions return resolved_expressions - def update_cache(self, expression: t.Optional[exp.Expression]) -> None: + def update_cache(self, expression: t.Optional[exp.Expr]) -> None: self._cache = [expression] def _resolve_table( self, - table_name: str | exp.Expression, + table_name: str | exp.Expr, snapshots: t.Optional[t.Dict[str, Snapshot]] = None, table_mapping: t.Optional[t.Dict[str, str]] = None, deployability_index: t.Optional[DeployabilityIndex] = None, @@ -380,7 +380,7 @@ def _resolve_tables( if snapshot.is_model } - def _expand(node: exp.Expression) -> exp.Expression: + def _expand(node: exp.Expr) -> exp.Expr: if isinstance(node, exp.Table) and snapshots: name = exp.table_name(node, identify=True) model = model_mapping.get(name) @@ -449,7 +449,7 @@ def render( 
deployability_index: t.Optional[DeployabilityIndex] = None, expand: t.Iterable[str] = tuple(), **kwargs: t.Any, - ) -> t.Optional[t.List[exp.Expression]]: + ) -> t.Optional[t.List[exp.Expr]]: try: expressions = super()._render( start=start, @@ -631,7 +631,7 @@ def render( def update_cache( self, - expression: t.Optional[exp.Expression], + expression: t.Optional[exp.Expr], violated_rules: t.Optional[t.Dict[type[Rule], t.Any]] = None, optimized: bool = False, ) -> None: diff --git a/sqlmesh/core/schema_diff.py b/sqlmesh/core/schema_diff.py index e1f9d72a6c..ecf38b18a8 100644 --- a/sqlmesh/core/schema_diff.py +++ b/sqlmesh/core/schema_diff.py @@ -37,7 +37,7 @@ def is_additive(self) -> bool: @property @abc.abstractmethod - def _alter_actions(self) -> t.List[exp.Expression]: + def _alter_actions(self) -> t.List[exp.Expr]: pass @property @@ -104,7 +104,7 @@ def is_destructive(self) -> bool: return self.is_part_of_destructive_change @property - def _alter_actions(self) -> t.List[exp.Expression]: + def _alter_actions(self) -> t.List[exp.Expr]: column_def = exp.ColumnDef( this=self.column, kind=self.column_type, @@ -127,7 +127,7 @@ def is_destructive(self) -> bool: return True @property - def _alter_actions(self) -> t.List[exp.Expression]: + def _alter_actions(self) -> t.List[exp.Expr]: return [exp.Drop(this=self.column, kind="COLUMN", cascade=self.cascade)] @@ -145,7 +145,7 @@ def is_destructive(self) -> bool: return self.is_part_of_destructive_change @property - def _alter_actions(self) -> t.List[exp.Expression]: + def _alter_actions(self) -> t.List[exp.Expr]: return [ exp.AlterColumn( this=self.column, @@ -363,14 +363,12 @@ class SchemaDiffer(PydanticModel): coerceable_types_: t.Dict[exp.DataType, t.Set[exp.DataType]] = Field( default_factory=dict, alias="coerceable_types" ) - precision_increase_allowed_types: t.Optional[t.Set[exp.DataType.Type]] = None + precision_increase_allowed_types: t.Optional[t.Set[exp.DType]] = None support_coercing_compatible_types: bool = False 
drop_cascade: bool = False - parameterized_type_defaults: t.Dict[ - exp.DataType.Type, t.List[t.Tuple[t.Union[int, float], ...]] - ] = {} - max_parameter_length: t.Dict[exp.DataType.Type, t.Union[int, float]] = {} - types_with_unlimited_length: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} + parameterized_type_defaults: t.Dict[exp.DType, t.List[t.Tuple[t.Union[int, float], ...]]] = {} + max_parameter_length: t.Dict[exp.DType, t.Union[int, float]] = {} + types_with_unlimited_length: t.Dict[exp.DType, t.Set[exp.DType]] = {} treat_alter_data_type_as_destructive: bool = False _coerceable_types: t.Dict[exp.DataType, t.Set[exp.DataType]] = {} diff --git a/sqlmesh/core/selector.py b/sqlmesh/core/selector.py index 3865327acd..9eaf4995c8 100644 --- a/sqlmesh/core/selector.py +++ b/sqlmesh/core/selector.py @@ -191,7 +191,7 @@ def expand_model_selections( models_by_tags.setdefault(tag, set()) models_by_tags[tag].add(model.fqn) - def evaluate(node: exp.Expression) -> t.Set[str]: + def evaluate(node: exp.Expr) -> t.Set[str]: if isinstance(node, exp.Var): pattern = node.this if "*" in pattern: @@ -400,7 +400,7 @@ class Direction(exp.Expression): pass -def parse(selector: str, dialect: DialectType = None) -> exp.Expression: +def parse(selector: str, dialect: DialectType = None) -> exp.Expr: tokens = SelectorDialect().tokenize(selector) i = 0 @@ -444,7 +444,7 @@ def _parse_kind(kind: str) -> bool: return True return False - def _parse_var() -> exp.Expression: + def _parse_var() -> exp.Expr: upstream = _match(TokenType.PLUS) downstream = None tag = _parse_kind("tag") @@ -457,7 +457,7 @@ def _parse_var() -> exp.Expression: name = _prev().text rstar = "*" if _match(TokenType.STAR) else "" downstream = _match(TokenType.PLUS) - this: exp.Expression = exp.Var(this=f"{lstar}{name}{rstar}") + this: exp.Expr = exp.Var(this=f"{lstar}{name}{rstar}") elif _match(TokenType.L_PAREN): this = exp.Paren(this=_parse_conjunction()) @@ -483,12 +483,12 @@ def _parse_var() -> exp.Expression: 
this = Direction(this=this, **directions) return this - def _parse_unary() -> exp.Expression: + def _parse_unary() -> exp.Expr: if _match(TokenType.CARET): return exp.Not(this=_parse_unary()) return _parse_var() - def _parse_conjunction() -> exp.Expression: + def _parse_conjunction() -> exp.Expr: this = _parse_unary() if _match(TokenType.AMP): diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 4f5102cbef..b1ffd4dc26 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -249,7 +249,7 @@ def evaluate_and_fetch( query_or_df = next(queries_or_dfs) if isinstance(query_or_df, pd.DataFrame): return query_or_df.head(limit) - if not isinstance(query_or_df, exp.Expression): + if not isinstance(query_or_df, exp.Expr): # We assume that if this branch is reached, `query_or_df` is a pyspark / snowpark / bigframe dataframe, # so we use `limit` instead of `head` to get back a dataframe instead of List[Row] # https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrame.head.html#pyspark.sql.DataFrame.head @@ -940,7 +940,7 @@ def _render_and_insert_snapshot( snapshots: t.Dict[str, Snapshot], render_kwargs: t.Dict[str, t.Any], create_render_kwargs: t.Dict[str, t.Any], - rendered_physical_properties: t.Dict[str, exp.Expression], + rendered_physical_properties: t.Dict[str, exp.Expr], deployability_index: DeployabilityIndex, target_table_name: str, is_first_insert: bool, @@ -1069,7 +1069,7 @@ def _clone_snapshot_in_dev( snapshots: t.Dict[str, Snapshot], deployability_index: DeployabilityIndex, render_kwargs: t.Dict[str, t.Any], - rendered_physical_properties: t.Dict[str, exp.Expression], + rendered_physical_properties: t.Dict[str, exp.Expr], allow_destructive_snapshots: t.Set[str], allow_additive_snapshots: t.Set[str], run_pre_post_statements: bool = False, @@ -1186,7 +1186,7 @@ def _migrate_target_table( snapshots: t.Dict[str, Snapshot], deployability_index: DeployabilityIndex, 
render_kwargs: t.Dict[str, t.Any], - rendered_physical_properties: t.Dict[str, exp.Expression], + rendered_physical_properties: t.Dict[str, exp.Expr], allow_destructive_snapshots: t.Set[str], allow_additive_snapshots: t.Set[str], run_pre_post_statements: bool = False, @@ -1472,7 +1472,7 @@ def _execute_create( is_table_deployable: bool, deployability_index: DeployabilityIndex, create_render_kwargs: t.Dict[str, t.Any], - rendered_physical_properties: t.Dict[str, exp.Expression], + rendered_physical_properties: t.Dict[str, exp.Expr], dry_run: bool, run_pre_post_statements: bool = True, skip_grants: bool = False, @@ -3106,7 +3106,7 @@ def create( query=model.render_query_or_raise(**render_kwargs), target_columns_to_types=model.columns_to_types, partitioned_by=model.partitioned_by, - clustered_by=model.clustered_by, + clustered_by=model.clustered_by, # type: ignore[arg-type] table_properties=kwargs.get("physical_properties", model.physical_properties), table_description=model.description, column_descriptions=model.column_descriptions, @@ -3151,7 +3151,7 @@ def insert( query=query_or_df, # type: ignore target_columns_to_types=model.columns_to_types, partitioned_by=model.partitioned_by, - clustered_by=model.clustered_by, + clustered_by=model.clustered_by, # type: ignore[arg-type] table_properties=kwargs.get("physical_properties", model.physical_properties), table_description=model.description, column_descriptions=model.column_descriptions, diff --git a/sqlmesh/core/state_sync/common.py b/sqlmesh/core/state_sync/common.py index 2e8c67ac29..d1208c5213 100644 --- a/sqlmesh/core/state_sync/common.py +++ b/sqlmesh/core/state_sync/common.py @@ -141,8 +141,8 @@ def _expanded_tuple_comparison( cls, columns: t.List[exp.Column], values: t.List[t.Union[exp.Literal, exp.Neg]], - operator: t.Type[exp.Expression], - ) -> exp.Expression: + operator: t.Type[exp.Expr], + ) -> exp.Condition: """Generate expanded tuple comparison that works across all SQL engines. 
Converts tuple comparisons like (a, b, c) OP (x, y, z) into an expanded form @@ -177,8 +177,8 @@ def _expanded_tuple_comparison( # e.g., (a, b) <= (x, y) becomes: a < x OR (a = x AND b <= y) # For < and >, we use the strict operator throughout # e.g., (a, b) > (x, y) becomes: a > x OR (a = x AND b > x) - strict_operator: t.Type[exp.Expression] - final_operator: t.Type[exp.Expression] + strict_operator: t.Type[exp.Expr] + final_operator: t.Type[exp.Expr] if operator in (exp.LTE, exp.GTE): # For inclusive operators (<=, >=), use strict form for intermediate columns @@ -190,7 +190,7 @@ def _expanded_tuple_comparison( strict_operator = operator final_operator = operator - conditions: t.List[exp.Expression] = [] + conditions: t.List[exp.Expr] = [] for i in range(len(columns)): # Build equality conditions for all columns before current equality_conditions = [exp.EQ(this=columns[j], expression=values[j]) for j in range(i)] @@ -204,10 +204,10 @@ def _expanded_tuple_comparison( else: conditions.append(comparison_condition) - return exp.or_(*conditions) if len(conditions) > 1 else conditions[0] + return exp.or_(*conditions) if len(conditions) > 1 else t.cast(exp.Condition, conditions[0]) @property - def where_filter(self) -> exp.Expression: + def where_filter(self) -> exp.Condition: # Use expanded tuple comparisons for cross-engine compatibility # Native tuple comparisons like (a, b) > (x, y) don't work reliably across all SQL engines columns = [ @@ -223,7 +223,7 @@ def where_filter(self) -> exp.Expression: start_condition = self._expanded_tuple_comparison(columns, start_values, exp.GT) - range_filter: exp.Expression + range_filter: exp.Condition if isinstance(self.end, RowBoundary): end_values = [ exp.Literal.number(self.end.updated_ts), diff --git a/sqlmesh/core/state_sync/db/environment.py b/sqlmesh/core/state_sync/db/environment.py index e3f1d1ec9e..713ce0193e 100644 --- a/sqlmesh/core/state_sync/db/environment.py +++ b/sqlmesh/core/state_sync/db/environment.py @@ -296,7 
+296,7 @@ def _environment_summmary_from_row(self, row: t.Tuple[str, ...]) -> EnvironmentS def _environments_query( self, - where: t.Optional[str | exp.Expression] = None, + where: t.Optional[str | exp.Expr] = None, lock_for_update: bool = False, required_fields: t.Optional[t.List[str]] = None, ) -> exp.Select: @@ -310,7 +310,7 @@ def _environments_query( return query.lock(copy=False) return query - def _create_expiration_filter_expr(self, current_ts: int) -> exp.Expression: + def _create_expiration_filter_expr(self, current_ts: int) -> exp.Expr: """Creates a SQLGlot filter expression to find expired environments. Args: @@ -322,7 +322,7 @@ def _create_expiration_filter_expr(self, current_ts: int) -> exp.Expression: ) def _fetch_environment_summaries( - self, where: t.Optional[str | exp.Expression] = None + self, where: t.Optional[str | exp.Expr] = None ) -> t.List[EnvironmentSummary]: return [ self._environment_summmary_from_row(row) diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index d584c69d65..8ca98f2d48 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -623,7 +623,7 @@ def _get_snapshots_expressions( self, snapshot_ids: t.Iterable[SnapshotIdLike], lock_for_update: bool = False, - ) -> t.Iterator[exp.Expression]: + ) -> t.Iterator[exp.Expr]: for where in snapshot_id_filter( self.engine_adapter, snapshot_ids, diff --git a/sqlmesh/core/state_sync/db/utils.py b/sqlmesh/core/state_sync/db/utils.py index 87c259f5d6..b0f321e21f 100644 --- a/sqlmesh/core/state_sync/db/utils.py +++ b/sqlmesh/core/state_sync/db/utils.py @@ -123,11 +123,9 @@ def create_batches(l: t.List[T], batch_size: int) -> t.List[t.List[T]]: return [l[i : i + batch_size] for i in range(0, len(l), batch_size)] -def fetchone( - engine_adapter: EngineAdapter, query: t.Union[exp.Expression, str] -) -> t.Optional[t.Tuple]: +def fetchone(engine_adapter: EngineAdapter, query: t.Union[exp.Expr, str]) -> 
t.Optional[t.Tuple]: return engine_adapter.fetchone(query, ignore_unsupported_errors=True, quote_identifiers=True) -def fetchall(engine_adapter: EngineAdapter, query: t.Union[exp.Expression, str]) -> t.List[t.Tuple]: +def fetchall(engine_adapter: EngineAdapter, query: t.Union[exp.Expr, str]) -> t.List[t.Tuple]: return engine_adapter.fetchall(query, ignore_unsupported_errors=True, quote_identifiers=True) diff --git a/sqlmesh/core/state_sync/export_import.py b/sqlmesh/core/state_sync/export_import.py index 3a63351ddb..2461ee50fa 100644 --- a/sqlmesh/core/state_sync/export_import.py +++ b/sqlmesh/core/state_sync/export_import.py @@ -29,7 +29,7 @@ class SQLMeshJSONStreamEncoder(JSONStreamEncoder): def default(self, obj: t.Any) -> t.Any: - if isinstance(obj, exp.Expression): + if isinstance(obj, exp.Expr): return _expression_encoder(obj) return super().default(obj) diff --git a/sqlmesh/core/table_diff.py b/sqlmesh/core/table_diff.py index bd32cc170f..df99227f89 100644 --- a/sqlmesh/core/table_diff.py +++ b/sqlmesh/core/table_diff.py @@ -224,9 +224,9 @@ def __init__( adapter: EngineAdapter, source: TableName, target: TableName, - on: t.List[str] | exp.Condition, + on: t.List[str] | exp.Expr, skip_columns: t.List[str] | None = None, - where: t.Optional[str | exp.Condition] = None, + where: t.Optional[str | exp.Expr] = None, limit: int = 20, source_alias: t.Optional[str] = None, target_alias: t.Optional[str] = None, @@ -305,18 +305,18 @@ def key_columns(self) -> t.Tuple[t.List[exp.Column], t.List[exp.Column], t.List[ return s_index, t_index, index_cols @property - def source_key_expression(self) -> exp.Expression: + def source_key_expression(self) -> exp.Expr: s_index, _, _ = self.key_columns return self._key_expression(s_index, self.source_schema) @property - def target_key_expression(self) -> exp.Expression: + def target_key_expression(self) -> exp.Expr: _, t_index, _ = self.key_columns return self._key_expression(t_index, self.target_schema) def _key_expression( self, 
cols: t.List[exp.Column], schema: t.Dict[str, exp.DataType] - ) -> exp.Expression: + ) -> exp.Expr: # if there is a single column, dont do anything fancy to it in order to allow existing indexes to be hit if len(cols) == 1: return exp.to_column(cols[0].name) @@ -363,7 +363,7 @@ def row_diff( s_index_names = [c.name for c in s_index] t_index_names = [t.name for t in t_index] - def _column_expr(name: str, table: str) -> exp.Expression: + def _column_expr(name: str, table: str) -> exp.Expr: column_type = matched_columns[name] qualified_column = exp.column(name, table) @@ -678,9 +678,9 @@ def _column_expr(name: str, table: str) -> exp.Expression: def _fetch_sample( self, sample_table: exp.Table, - s_selects: t.Dict[str, exp.Alias], + s_selects: t.Dict[str, exp.Expr], s_index: t.List[exp.Column], - t_selects: t.Dict[str, exp.Alias], + t_selects: t.Dict[str, exp.Expr], t_index: t.List[exp.Column], limit: int, ) -> pd.DataFrame: @@ -742,5 +742,5 @@ def _fetch_sample( return self.adapter.fetchdf(query, quote_identifiers=True) -def name(e: exp.Expression) -> str: +def name(e: exp.Expr) -> str: return e.args["alias"].sql(identify=True) diff --git a/sqlmesh/core/test/definition.py b/sqlmesh/core/test/definition.py index 2a838753de..629e8f8d5b 100644 --- a/sqlmesh/core/test/definition.py +++ b/sqlmesh/core/test/definition.py @@ -674,7 +674,7 @@ def _add_missing_columns( class SqlModelTest(ModelTest): - def test_ctes(self, ctes: t.Dict[str, exp.Expression], recursive: bool = False) -> None: + def test_ctes(self, ctes: t.Dict[str, exp.Expr], recursive: bool = False) -> None: """Run CTE queries and compare output to expected output""" for cte_name, values in self.body["outputs"].get("ctes", {}).items(): with self.subTest(cte=cte_name): @@ -819,7 +819,7 @@ def _execute_model(self) -> pd.DataFrame: time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables} df = next(self.model.render(context=self.context, variables=variables, **time_kwargs)) - assert not 
isinstance(df, exp.Expression) + assert not isinstance(df, exp.Expr) return df if isinstance(df, pd.DataFrame) else df.toPandas() diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index 41cea9b9ae..55994abf85 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -485,7 +485,7 @@ def model_kind(self, context: DbtContext) -> ModelKind: raise ConfigError(f"{materialization.value} materialization not supported.") - def _big_query_partition_by_expr(self, context: DbtContext) -> exp.Expression: + def _big_query_partition_by_expr(self, context: DbtContext) -> exp.Expr: assert isinstance(self.partition_by, dict) data_type = self.partition_by["data_type"].lower() raw_field = self.partition_by["field"] diff --git a/sqlmesh/lsp/hints.py b/sqlmesh/lsp/hints.py index a8d56e2f31..611ce8608d 100644 --- a/sqlmesh/lsp/hints.py +++ b/sqlmesh/lsp/hints.py @@ -5,7 +5,6 @@ from lsprotocol import types from sqlglot import exp -from sqlglot.expressions import Expression from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlmesh.core.model.definition import SqlModel from sqlmesh.lsp.context import LSPContext, ModelTarget @@ -60,7 +59,7 @@ def get_hints( def _get_type_hints_for_select( - expression: exp.Expression, + expression: exp.Expr, dialect: str, columns_to_types: t.Dict[str, exp.DataType], start_line: int, @@ -113,7 +112,7 @@ def _get_type_hints_for_select( def _get_type_hints_for_model_from_query( - query: Expression, + query: exp.Expr, dialect: str, columns_to_types: t.Dict[str, exp.DataType], start_line: int, diff --git a/sqlmesh/lsp/reference.py b/sqlmesh/lsp/reference.py index 80d401f79c..73c4e5681b 100644 --- a/sqlmesh/lsp/reference.py +++ b/sqlmesh/lsp/reference.py @@ -209,7 +209,7 @@ def get_macro_reference( target: t.Union[Model, StandaloneAudit], read_file: t.List[str], config_path: t.Optional[Path], - node: exp.Expression, + node: exp.Expr, macro_name: str, ) -> t.Optional[Reference]: # Get the file path where the macro is 
defined diff --git a/sqlmesh/utils/date.py b/sqlmesh/utils/date.py index c9bb19c835..bdc15125d4 100644 --- a/sqlmesh/utils/date.py +++ b/sqlmesh/utils/date.py @@ -168,7 +168,7 @@ def to_datetime( dt: t.Optional[datetime] = value elif isinstance(value, date): dt = datetime(value.year, value.month, value.day) - elif isinstance(value, exp.Expression): + elif isinstance(value, exp.Expr): return to_datetime(value.name) else: try: @@ -401,7 +401,7 @@ def to_time_column( dialect: str, time_column_format: t.Optional[str] = None, nullable: bool = False, -) -> exp.Expression: +) -> exp.Expr: """Convert a TimeLike object to the same time format and type as the model's time column.""" if dialect == "clickhouse" and time_column_type.is_type( *(exp.DataType.TEMPORAL_TYPES - {exp.DataType.Type.DATE, exp.DataType.Type.DATE32}) diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py index 240b183391..725842c842 100644 --- a/sqlmesh/utils/jinja.py +++ b/sqlmesh/utils/jinja.py @@ -12,7 +12,8 @@ from jinja2 import Environment, Template, nodes, UndefinedError from jinja2.runtime import Macro -from sqlglot import Dialect, Expression, Parser, TokenType +from sqlglot import Dialect, Parser, TokenType +from sqlglot.expressions import Expression from sqlmesh.core import constants as c from sqlmesh.core import dialect as d diff --git a/sqlmesh/utils/lineage.py b/sqlmesh/utils/lineage.py index f5b4506c68..f63395708d 100644 --- a/sqlmesh/utils/lineage.py +++ b/sqlmesh/utils/lineage.py @@ -70,7 +70,7 @@ class MacroReference(PydanticModel): def extract_references_from_query( - query: exp.Expression, + query: exp.Expr, context: t.Union["Context", "GenericContext[t.Any]"], document_path: Path, read_file: t.List[str], @@ -95,7 +95,11 @@ def extract_references_from_query( # Check if this table reference is a CTE in the current scope if cte_scope := scope.cte_sources.get(table_name): + if cte_scope.expression is None: + continue cte = cte_scope.expression.parent + if cte is None: + continue 
alias = cte.args["alias"] if isinstance(alias, exp.TableAlias): identifier = alias.this diff --git a/sqlmesh/utils/metaprogramming.py b/sqlmesh/utils/metaprogramming.py index 753db427f3..cd77c36353 100644 --- a/sqlmesh/utils/metaprogramming.py +++ b/sqlmesh/utils/metaprogramming.py @@ -444,6 +444,41 @@ def value( ) +def _resolve_import_module(obj: t.Any, name: str) -> str: + """Resolve the most appropriate module path for importing an object. + + When a callable's ``__module__`` points to a submodule of a known public + module (e.g. ``sqlglot.expressions.builders`` is a submodule of + ``sqlglot.expressions``), and the object is re-exported from that public + parent module, prefer the public parent so that generated import statements + remain stable across internal restructurings of third-party packages. + + Args: + obj: The callable to resolve. + name: The name under which the object will be imported. + + Returns: + The module path to use in the ``from import `` statement. + """ + module_name = getattr(obj, "__module__", None) or "" + parts = module_name.split(".") + + # Walk from the shallowest ancestor (excluding the top-level package) up to + # the immediate parent, returning the shallowest one that re-exports the object. + # We skip the top-level package to avoid over-normalizing (e.g. ``sqlglot`` + # re-exports everything, but callers expect ``sqlglot.expressions``). + for i in range(2, len(parts)): + parent = ".".join(parts[:i]) + try: + parent_module = sys.modules.get(parent) or importlib.import_module(parent) + if getattr(parent_module, name, None) is obj: + return parent + except Exception: + continue + + return module_name + + def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable]: """Serializes a python function into a self contained dictionary. 
@@ -512,7 +547,7 @@ def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable ) else: serialized[k] = Executable( - payload=f"from {v.__module__} import {name}", + payload=f"from {_resolve_import_module(v, name)} import {name}", kind=ExecutableKind.IMPORT, is_metadata=is_metadata, ) diff --git a/sqlmesh/utils/pydantic.py b/sqlmesh/utils/pydantic.py index 2c9c570e5b..8bc81e2774 100644 --- a/sqlmesh/utils/pydantic.py +++ b/sqlmesh/utils/pydantic.py @@ -56,7 +56,7 @@ def get_dialect(values: t.Any) -> str: return model._dialect if dialect is None else dialect # type: ignore -def _expression_encoder(e: exp.Expression) -> str: +def _expression_encoder(e: exp.Expr) -> str: return e.meta.get("sql") or e.sql(dialect=e.meta.get("dialect")) @@ -70,7 +70,7 @@ class PydanticModel(pydantic.BaseModel): # crippled badly. Here we need to enumerate all different ways of how sqlglot expressions # show up in pydantic models. json_encoders={ - exp.Expression: _expression_encoder, + exp.Expr: _expression_encoder, exp.DataType: _expression_encoder, exp.Tuple: _expression_encoder, AuditQueryTypes: _expression_encoder, # type: ignore @@ -190,7 +190,7 @@ def validate_list_of_strings(v: t.Any) -> t.List[str]: def validate_string(v: t.Any) -> str: - if isinstance(v, exp.Expression): + if isinstance(v, exp.Expr): return v.name return str(v) @@ -204,13 +204,13 @@ def validate_expression(expression: E, dialect: str) -> E: def bool_validator(v: t.Any) -> bool: if isinstance(v, exp.Boolean): return v.this - if isinstance(v, exp.Expression): + if isinstance(v, exp.Expr): return str_to_bool(v.name) return str_to_bool(str(v or "")) def positive_int_validator(v: t.Any) -> int: - if isinstance(v, exp.Expression) and v.is_int: + if isinstance(v, exp.Expr) and v.is_int: v = int(v.name) if not isinstance(v, int): raise ValueError(f"Invalid num {v}. 
Value must be an integer value") @@ -237,10 +237,10 @@ def _formatted_validation_errors(error: pydantic.ValidationError) -> t.List[str] def _get_field( v: t.Any, values: t.Any, -) -> exp.Expression: +) -> exp.Expr: dialect = get_dialect(values) - if isinstance(v, exp.Expression): + if isinstance(v, exp.Expr): expression = v else: expression = parse_one(v, dialect=dialect) @@ -257,16 +257,16 @@ def _get_field( def _get_fields( v: t.Any, values: t.Any, -) -> t.List[exp.Expression]: +) -> t.List[exp.Expr]: dialect = get_dialect(values) if isinstance(v, (exp.Tuple, exp.Array)): - expressions: t.List[exp.Expression] = v.expressions - elif isinstance(v, exp.Expression): + expressions: t.List[exp.Expr] = v.expressions + elif isinstance(v, exp.Expr): expressions = [v] else: expressions = [ - parse_one(entry, dialect=dialect) if isinstance(entry, str) else entry + parse_one(entry, dialect=dialect) if isinstance(entry, str) else entry # type: ignore[misc] for entry in ensure_list(v) ] @@ -278,7 +278,7 @@ def _get_fields( return results -def list_of_fields_validator(v: t.Any, values: t.Any) -> t.List[exp.Expression]: +def list_of_fields_validator(v: t.Any, values: t.Any) -> t.List[exp.Expr]: return _get_fields(v, values) @@ -291,15 +291,15 @@ def column_validator(v: t.Any, values: t.Any) -> exp.Column: def list_of_fields_or_star_validator( v: t.Any, values: t.Any -) -> t.Union[exp.Star, t.List[exp.Expression]]: +) -> t.Union[exp.Star, t.List[exp.Expr]]: expressions = _get_fields(v, values) if len(expressions) == 1 and isinstance(expressions[0], exp.Star): return t.cast(exp.Star, expressions[0]) - return t.cast(t.List[exp.Expression], expressions) + return t.cast(t.List[exp.Expr], expressions) def cron_validator(v: t.Any) -> str: - if isinstance(v, exp.Expression): + if isinstance(v, exp.Expr): v = v.name from croniter import CroniterBadCronError, croniter @@ -338,7 +338,7 @@ def get_concrete_types_from_typehint(typehint: type[t.Any]) -> set[type[t.Any]]: SQLGlotBool = bool 
SQLGlotPositiveInt = int SQLGlotColumn = exp.Column - SQLGlotListOfFields = t.List[exp.Expression] + SQLGlotListOfFields = t.List[exp.Expr] SQLGlotListOfFieldsOrStar = t.Union[SQLGlotListOfFields, exp.Star] SQLGlotCron = str else: @@ -348,10 +348,8 @@ def get_concrete_types_from_typehint(typehint: type[t.Any]) -> set[type[t.Any]]: SQLGlotString = t.Annotated[str, BeforeValidator(validate_string)] SQLGlotBool = t.Annotated[bool, BeforeValidator(bool_validator)] SQLGlotPositiveInt = t.Annotated[int, BeforeValidator(positive_int_validator)] - SQLGlotColumn = t.Annotated[exp.Expression, BeforeValidator(column_validator)] - SQLGlotListOfFields = t.Annotated[ - t.List[exp.Expression], BeforeValidator(list_of_fields_validator) - ] + SQLGlotColumn = t.Annotated[exp.Expr, BeforeValidator(column_validator)] + SQLGlotListOfFields = t.Annotated[t.List[exp.Expr], BeforeValidator(list_of_fields_validator)] SQLGlotListOfFieldsOrStar = t.Annotated[ t.Union[SQLGlotListOfFields, exp.Star], BeforeValidator(list_of_fields_or_star_validator) ] diff --git a/tests/conftest.py b/tests/conftest.py index b18271465d..46086444bd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -381,7 +381,7 @@ def _make_function( @pytest.fixture def assert_exp_eq() -> t.Callable: def _assert_exp_eq( - source: exp.Expression | str, expected: exp.Expression | str, dialect: DialectType = None + source: exp.Expr | str, expected: exp.Expr | str, dialect: DialectType = None ) -> None: source_exp = maybe_parse(source, dialect=dialect) expected_exp = maybe_parse(expected, dialect=dialect) diff --git a/tests/core/engine_adapter/__init__.py b/tests/core/engine_adapter/__init__.py index 4761c4100b..a9370b8cc3 100644 --- a/tests/core/engine_adapter/__init__.py +++ b/tests/core/engine_adapter/__init__.py @@ -11,7 +11,7 @@ def to_sql_calls(adapter: EngineAdapter, identify: bool = True) -> t.List[str]: value = call[0][0] sql = ( value.sql(dialect=adapter.dialect, identify=identify) - if isinstance(value, 
exp.Expression) + if isinstance(value, exp.Expr) else str(value) ) output.append(sql) diff --git a/tests/core/engine_adapter/integration/__init__.py b/tests/core/engine_adapter/integration/__init__.py index 4ad6a17944..47ccdc876a 100644 --- a/tests/core/engine_adapter/integration/__init__.py +++ b/tests/core/engine_adapter/integration/__init__.py @@ -276,7 +276,7 @@ def time_formatter(self) -> t.Callable: return lambda x, _: exp.Literal.string(to_ds(x)) @property - def partitioned_by(self) -> t.List[exp.Expression]: + def partitioned_by(self) -> t.List[exp.Expr]: return [parse_one(self.time_column)] @property @@ -388,8 +388,8 @@ def table(self, table_name: TableName, schema: str = TEST_SCHEMA) -> exp.Table: ) def physical_properties( - self, properties_for_dialect: t.Dict[str, t.Dict[str, str | exp.Expression]] - ) -> t.Dict[str, exp.Expression]: + self, properties_for_dialect: t.Dict[str, t.Dict[str, str | exp.Expr]] + ) -> t.Dict[str, exp.Expr]: if props := properties_for_dialect.get(self.dialect): return {k: exp.Literal.string(v) if isinstance(v, str) else v for k, v in props.items()} return {} diff --git a/tests/core/engine_adapter/integration/test_integration_athena.py b/tests/core/engine_adapter/integration/test_integration_athena.py index 1c0ece6d78..9d23af206e 100644 --- a/tests/core/engine_adapter/integration/test_integration_athena.py +++ b/tests/core/engine_adapter/integration/test_integration_athena.py @@ -378,7 +378,7 @@ def test_insert_overwrite_by_time_partition_date_type( ), # note: columns_to_types_from_df() would infer this as TEXT but we need a DATE type } - def time_formatter(time: TimeLike, _: t.Optional[t.Dict[str, exp.DataType]]) -> exp.Expression: + def time_formatter(time: TimeLike, _: t.Optional[t.Dict[str, exp.DataType]]) -> exp.Expr: return exp.cast(exp.Literal.string(to_ds(time)), "date") engine_adapter.create_table( @@ -440,7 +440,7 @@ def test_insert_overwrite_by_time_partition_datetime_type( ), # note: columns_to_types_from_df() 
would infer this as TEXT but we need a DATETIME type } - def time_formatter(time: TimeLike, _: t.Optional[t.Dict[str, exp.DataType]]) -> exp.Expression: + def time_formatter(time: TimeLike, _: t.Optional[t.Dict[str, exp.DataType]]) -> exp.Expr: return exp.cast(exp.Literal.string(to_ts(time)), "datetime") engine_adapter.create_table( diff --git a/tests/core/engine_adapter/integration/test_integration_clickhouse.py b/tests/core/engine_adapter/integration/test_integration_clickhouse.py index f09360c673..4420acec71 100644 --- a/tests/core/engine_adapter/integration/test_integration_clickhouse.py +++ b/tests/core/engine_adapter/integration/test_integration_clickhouse.py @@ -64,9 +64,7 @@ def _create_table_and_insert_existing_data( "ds": exp.DataType.build("Date", "clickhouse"), }, table_name: str = "data_existing", - partitioned_by: t.Optional[t.List[exp.Expression]] = [ - parse_one("toMonth(ds)", dialect="clickhouse") - ], + partitioned_by: t.Optional[t.List[exp.Expr]] = [parse_one("toMonth(ds)", dialect="clickhouse")], ) -> exp.Table: existing_data = existing_data existing_table_name: exp.Table = ctx.table(table_name) diff --git a/tests/core/engine_adapter/test_athena.py b/tests/core/engine_adapter/test_athena.py index 66e84ae025..19c92f66ac 100644 --- a/tests/core/engine_adapter/test_athena.py +++ b/tests/core/engine_adapter/test_athena.py @@ -81,7 +81,7 @@ def table_diff(adapter: AthenaEngineAdapter) -> TableDiff: def test_table_location( adapter: AthenaEngineAdapter, config_s3_warehouse_location: t.Optional[str], - table_properties: t.Optional[t.Dict[str, exp.Expression]], + table_properties: t.Optional[t.Dict[str, exp.Expr]], table: exp.Table, expected_location: t.Optional[str], ) -> None: diff --git a/tests/core/engine_adapter/test_bigquery.py b/tests/core/engine_adapter/test_bigquery.py index 9a6bc7d851..134f144df1 100644 --- a/tests/core/engine_adapter/test_bigquery.py +++ b/tests/core/engine_adapter/test_bigquery.py @@ -593,7 +593,7 @@ def 
_to_sql_calls(execute_mock: t.Any, identify: bool = True) -> t.List[str]: for value in values: sql = ( value.sql(dialect="bigquery", identify=identify) - if isinstance(value, exp.Expression) + if isinstance(value, exp.Expr) else str(value) ) output.append(sql) diff --git a/tests/core/engine_adapter/test_clickhouse.py b/tests/core/engine_adapter/test_clickhouse.py index 54fbe7c323..7ff971b742 100644 --- a/tests/core/engine_adapter/test_clickhouse.py +++ b/tests/core/engine_adapter/test_clickhouse.py @@ -1365,7 +1365,7 @@ def test_exchange_tables( # The EXCHANGE TABLES call errored, so we RENAME TABLE instead assert [ quote_identifiers(call.args[0]).sql("clickhouse") - if isinstance(call.args[0], exp.Expression) + if isinstance(call.args[0], exp.Expr) else call.args[0] for call in execute_mock.call_args_list ] == [ diff --git a/tests/core/engine_adapter/test_snowflake.py b/tests/core/engine_adapter/test_snowflake.py index 60f6d38e5f..dcb6820297 100644 --- a/tests/core/engine_adapter/test_snowflake.py +++ b/tests/core/engine_adapter/test_snowflake.py @@ -123,7 +123,7 @@ def test_get_data_objects_lowercases_columns( def test_session( mocker: MockerFixture, make_mocked_engine_adapter: t.Callable, - current_warehouse: t.Union[str, exp.Expression], + current_warehouse: t.Union[str, exp.Expr], current_warehouse_exp: str, configured_warehouse: t.Optional[str], configured_warehouse_exp: t.Optional[str], diff --git a/tests/core/integration/test_auto_restatement.py b/tests/core/integration/test_auto_restatement.py index 70ca227fd3..1bda373a8f 100644 --- a/tests/core/integration/test_auto_restatement.py +++ b/tests/core/integration/test_auto_restatement.py @@ -27,7 +27,7 @@ def test_run_auto_restatement(init_and_plan_context: t.Callable): @macro() def record_intervals( - evaluator, name: exp.Expression, start: exp.Expression, end: exp.Expression, **kwargs: t.Any + evaluator, name: exp.Expr, start: exp.Expr, end: exp.Expr, **kwargs: t.Any ) -> None: if evaluator.runtime_stage == 
"evaluating": evaluator.engine_adapter.insert_append( @@ -178,7 +178,7 @@ def test_run_auto_restatement_failure(init_and_plan_context: t.Callable): context, _ = init_and_plan_context("examples/sushi") @macro() - def fail_auto_restatement(evaluator, start: exp.Expression, **kwargs: t.Any) -> None: + def fail_auto_restatement(evaluator, start: exp.Expr, **kwargs: t.Any) -> None: if evaluator.runtime_stage == "evaluating" and start.name != "2023-01-01": raise Exception("Failed") diff --git a/tests/core/integration/utils.py b/tests/core/integration/utils.py index bc731e6cc8..ba233080b5 100644 --- a/tests/core/integration/utils.py +++ b/tests/core/integration/utils.py @@ -105,7 +105,10 @@ def apply_to_environment( def change_data_type( - context: Context, model_name: str, old_type: DataType.Type, new_type: DataType.Type + context: Context, + model_name: str, + old_type: exp.DType, + new_type: exp.DType, ) -> None: model = context.get_model(model_name) assert model is not None diff --git a/tests/core/test_audit.py b/tests/core/test_audit.py index 66897ed088..90ac655cc6 100644 --- a/tests/core/test_audit.py +++ b/tests/core/test_audit.py @@ -329,7 +329,7 @@ def test_load_with_dictionary_defaults(): audit = load_audit(expressions, dialect="spark") assert audit.defaults.keys() == {"field1", "field2"} for value in audit.defaults.values(): - assert isinstance(value, exp.Expression) + assert isinstance(value, exp.Expr) def test_load_with_single_defaults(): @@ -350,7 +350,7 @@ def test_load_with_single_defaults(): audit = load_audit(expressions, dialect="duckdb") assert audit.defaults.keys() == {"field1"} for value in audit.defaults.values(): - assert isinstance(value, exp.Expression) + assert isinstance(value, exp.Expr) def test_no_audit_statement(): diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 9ae239f298..8c81a90b8d 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -570,7 +570,8 @@ def test_variables(): assert 
config.get_gateway("local").variables == {"uppercase_var": 2} with pytest.raises( - ConfigError, match="Unsupported variable value type: " + ConfigError, + match=r"Unsupported variable value type: ", ): Config(variables={"invalid_var": exp.column("sqlglot_expr")}) diff --git a/tests/core/test_macros.py b/tests/core/test_macros.py index fb10f64b27..e37a7ec05b 100644 --- a/tests/core/test_macros.py +++ b/tests/core/test_macros.py @@ -98,7 +98,7 @@ def test_select_macro(evaluator): @macro() def test_literal_type(evaluator, a: t.Literal["test_literal_a", "test_literal_b", 1, True]): - if isinstance(a, exp.Expression): + if isinstance(a, exp.Expr): raise SQLMeshError("Coercion failed") return f"'{a}'" @@ -694,8 +694,8 @@ def test_macro_coercion(macro_evaluator: MacroEvaluator, assert_exp_eq): ) == (1, "2", (3.0,)) # Using exp.Expression will always return the input expression - assert coerce(parse_one("order", into=exp.Column), exp.Expression) == exp.column("order") - assert coerce(exp.Literal.string("OK"), exp.Expression) == exp.Literal.string("OK") + assert coerce(parse_one("order", into=exp.Column), exp.Expr) == exp.column("order") + assert coerce(exp.Literal.string("OK"), exp.Expr) == exp.Literal.string("OK") # Strict flag allows raising errors and is used when recursively coercing expressions # otherwise, in general, we want to be lenient and just warn the user when something is not possible @@ -930,12 +930,10 @@ def test_date_spine(assert_exp_eq, dialect, date_part): FLATTEN( INPUT => ARRAY_GENERATE_RANGE( 0, - ( - DATEDIFF( - {date_part.upper()}, - CAST('2022-01-01' AS DATE), - CAST('2024-12-31' AS DATE) - ) + 1 - 1 + DATEDIFF( + {date_part.upper()}, + CAST('2022-01-01' AS DATE), + CAST('2024-12-31' AS DATE) ) + 1 ) ) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index cfcb843739..81707c075f 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -6011,7 +6011,7 @@ def test_when_matched_normalization() -> None: assert 
isinstance(model.kind, IncrementalByUniqueKeyKind) assert isinstance(model.kind.when_matched, exp.Whens) first_expression = model.kind.when_matched.expressions[0] - assert isinstance(first_expression, exp.Expression) + assert isinstance(first_expression, exp.Expr) assert ( first_expression.sql(dialect="snowflake") == 'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."KEY_A" = "__MERGE_SOURCE__"."KEY_A", "__MERGE_TARGET__"."KEY_B" = "__MERGE_SOURCE__"."KEY_B"' @@ -6039,7 +6039,7 @@ def test_when_matched_normalization() -> None: assert isinstance(model.kind, IncrementalByUniqueKeyKind) assert isinstance(model.kind.when_matched, exp.Whens) first_expression = model.kind.when_matched.expressions[0] - assert isinstance(first_expression, exp.Expression) + assert isinstance(first_expression, exp.Expr) assert ( first_expression.sql(dialect="snowflake") == 'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."kEy_A" = "__MERGE_SOURCE__"."kEy_A", "__MERGE_TARGET__"."kEY_b" = "__MERGE_SOURCE__"."KEY_B"' @@ -6447,7 +6447,7 @@ def test_end_no_start(): def test_variables(): @macro() - def test_macro_var(evaluator) -> exp.Expression: + def test_macro_var(evaluator) -> exp.Expr: return exp.convert(evaluator.var("TEST_VAR_D") + 10) expressions = parse( @@ -6946,7 +6946,7 @@ def test_unrendered_macros_sql_model(mocker: MockerFixture) -> None: # merge_filter will stay unrendered as well assert model.unique_key[0] == exp.column("a", quoted=True) assert ( - t.cast(exp.Expression, model.merge_filter).sql() + t.cast(exp.Expr, model.merge_filter).sql() == '"__MERGE_SOURCE__"."id" > 0 AND "__MERGE_TARGET__"."updated_at" < @end_ds AND "__MERGE_SOURCE__"."updated_at" > @start_ds AND @merge_filter_var' ) @@ -7149,7 +7149,7 @@ def test_gateway_macro() -> None: assert model.render_query_or_raise().sql() == "SELECT 'in_memory' AS \"gateway\"" @macro() - def macro_uses_gateway(evaluator) -> exp.Expression: + def macro_uses_gateway(evaluator) -> exp.Expr: return exp.convert(evaluator.gateway + 
"_from_macro") model = load_sql_based_model( @@ -8729,7 +8729,7 @@ def test_merge_filter_macro(): def predicate( evaluator: MacroEvaluator, cluster_column: exp.Column, - ) -> exp.Expression: + ) -> exp.Expr: return parse_one(f"source.{cluster_column} > dateadd(day, -7, target.{cluster_column})") expressions = d.parse( @@ -9904,7 +9904,7 @@ def entrypoint(evaluator): {"customer": SqlValue(sql="customer1"), "customer_field": SqlValue(sql="'bar'")} ) - assert t.cast(exp.Expression, customer1_model.render_query()).sql() == ( + assert t.cast(exp.Expr, customer1_model.render_query()).sql() == ( """SELECT 'bar' AS "foo", "bar" AS "foo2", 'bar' AS "foo3" FROM "db"."customer1"."my_source" AS "my_source\"""" ) @@ -9917,7 +9917,7 @@ def entrypoint(evaluator): {"customer": SqlValue(sql="customer2"), "customer_field": SqlValue(sql="qux")} ) - assert t.cast(exp.Expression, customer2_model.render_query()).sql() == ( + assert t.cast(exp.Expr, customer2_model.render_query()).sql() == ( '''SELECT "qux" AS "foo", "qux" AS "foo2", "qux" AS "foo3" FROM "db"."customer2"."my_source" AS "my_source"''' ) @@ -10703,12 +10703,12 @@ def m4_non_metadata_references_v6(evaluator): query_with_vars = macro_evaluator.transform( parse_one("SELECT " + ", ".join(f"@v{var}, @VAR('v{var}')" for var in [1, 2, 3, 6])) ) - assert t.cast(exp.Expression, query_with_vars).sql() == "SELECT 1, 1, 2, 2, 3, 3, 6, 6" + assert t.cast(exp.Expr, query_with_vars).sql() == "SELECT 1, 1, 2, 2, 3, 3, 6, 6" query_with_blueprint_vars = macro_evaluator.transform( parse_one("SELECT " + ", ".join(f"@v{var}, @BLUEPRINT_VAR('v{var}')" for var in [4, 5])) ) - assert t.cast(exp.Expression, query_with_blueprint_vars).sql() == "SELECT 4, 4, 5, 5" + assert t.cast(exp.Expr, query_with_blueprint_vars).sql() == "SELECT 4, 4, 5, 5" def test_variable_mentioned_in_both_metadata_and_non_metadata_macro(tmp_path: Path) -> None: diff --git a/tests/core/test_plan.py b/tests/core/test_plan.py index 4b330c376f..590cda01ec 100644 --- 
a/tests/core/test_plan.py +++ b/tests/core/test_plan.py @@ -1795,7 +1795,7 @@ def test_forward_only_models_model_kind_changed(make_snapshot, mocker: MockerFix ) def test_forward_only_models_model_kind_changed_to_incremental_by_time_range( make_snapshot, - partitioned_by: t.List[exp.Expression], + partitioned_by: t.List[exp.Expr], expected_forward_only: bool, ): snapshot = make_snapshot( diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 1413ac81f1..f3fae15e8a 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -3683,7 +3683,7 @@ def test_custom_materialization_strategy_with_custom_properties(adapter_mock, ma custom_insert_kind = None class TestCustomKind(CustomKind): - _primary_key: t.List[exp.Expression] # type: ignore[no-untyped-def] + _primary_key: t.List[exp.Expr] # type: ignore[no-untyped-def] @model_validator(mode="after") def _validate_model(self) -> Self: @@ -3695,7 +3695,7 @@ def _validate_model(self) -> Self: return self @property - def primary_key(self) -> t.List[exp.Expression]: + def primary_key(self) -> t.List[exp.Expr]: return self._primary_key class TestCustomMaterializationStrategy(CustomMaterialization[TestCustomKind]): diff --git a/tests/utils/test_metaprogramming.py b/tests/utils/test_metaprogramming.py index 4e55ae490e..9a6f0c95cd 100644 --- a/tests/utils/test_metaprogramming.py +++ b/tests/utils/test_metaprogramming.py @@ -23,6 +23,7 @@ Executable, ExecutableKind, _dict_sort, + _resolve_import_module, build_env, func_globals, normalize_source, @@ -49,7 +50,7 @@ def test_print_exception(mocker: MockerFixture): except Exception as ex: print_exception(ex, test_env, out_mock) - expected_message = r""" File ".*?.tests.utils.test_metaprogramming\.py", line 48, in test_print_exception + expected_message = r""" File ".*?.tests.utils.test_metaprogramming\.py", line 49, in test_print_exception eval\("test_fun\(\)", env\).* File '/test/path.py' \(or imported 
file\), line 2, in test_fun @@ -638,3 +639,18 @@ def test_dict_sort_executable_integration(): # non-deterministic repr should not change the payload exec3 = Executable.value(variables1) assert exec3.payload == "{'env': 'dev', 'debug': True, 'timeout': 30}" + + +def test_resolve_import_module(): + """Test that _resolve_import_module finds the shallowest public re-exporting module.""" + # to_table lives in sqlglot.expressions.builders but is re-exported from sqlglot.expressions + assert _resolve_import_module(to_table, "to_table") == "sqlglot.expressions" + + # Objects whose __module__ is already the public module should be returned as-is + assert _resolve_import_module(exp.Column, "Column") == "sqlglot.expressions" + + # Objects not re-exported by any parent should return the original module + class _Local: + __module__ = "some.deep.internal.module" + + assert _resolve_import_module(_Local, "_Local") == "some.deep.internal.module" diff --git a/web/server/api/endpoints/table_diff.py b/web/server/api/endpoints/table_diff.py index d441d49e5a..b0167ed032 100644 --- a/web/server/api/endpoints/table_diff.py +++ b/web/server/api/endpoints/table_diff.py @@ -126,7 +126,7 @@ def get_table_diff( table_diffs = context.table_diff( source=source, target=target, - on=exp.condition(on) if on else None, + on=t.cast(exp.Condition, exp.condition(on)) if on else None, select_models={model_or_snapshot} if model_or_snapshot else None, where=where, limit=limit,