diff --git a/sqlmesh/core/engine_adapter/athena.py b/sqlmesh/core/engine_adapter/athena.py index 338381549b..f8d115acad 100644 --- a/sqlmesh/core/engine_adapter/athena.py +++ b/sqlmesh/core/engine_adapter/athena.py @@ -78,10 +78,9 @@ def s3_warehouse_location_or_raise(self) -> str: @property def catalog_support(self) -> CatalogSupport: - # Athena has the concept of catalogs but the current catalog is set in the connection parameters with no way to query or change it after that - # It also cant create new catalogs, you have to configure them in AWS. Typically, catalogs that are not "awsdatacatalog" - # are pointers to the "awsdatacatalog" of other AWS accounts - return CatalogSupport.SINGLE_CATALOG_ONLY + # Athena supports querying and writing to multiple catalogs (e.g. awsdatacatalog and s3tablescatalog) + # without needing a SET CATALOG command. + return CatalogSupport.FULL_SUPPORT def create_state_table( self, @@ -105,6 +104,9 @@ def _get_data_objects( """ schema_name = to_schema(schema_name) schema = schema_name.db + + info_schema_tables = exp.table_("tables", db="information_schema", catalog=schema_name.catalog, alias="t") + query = ( exp.select( exp.column("table_catalog").as_("catalog"), @@ -118,7 +120,7 @@ def _get_data_objects( .else_(exp.column("table_type", table="t")) .as_("type"), ) - .from_(exp.to_table("information_schema.tables", alias="t")) + .from_(info_schema_tables) .where(exp.column("table_schema", table="t").eq(schema)) ) if object_names: @@ -141,9 +143,12 @@ def columns( ) -> t.Dict[str, exp.DataType]: table = exp.to_table(table_name) # note: the data_type column contains the full parameterized type, eg 'varchar(10)' + + info_schema_columns = exp.table_("columns", db="information_schema", catalog=table.catalog) + query = ( exp.select("column_name", "data_type") - .from_("information_schema.columns") + .from_(info_schema_columns) .where(exp.column("table_schema").eq(table.db), exp.column("table_name").eq(table.name)) 
.order_by("ordinal_position") ) @@ -197,6 +202,11 @@ def _build_create_table_exp( else: table = table_name_or_schema + table_format = kwargs.pop("table_format", None) + if not table_format and table_properties and "table_format" in table_properties: + tf = table_properties.get("table_format") + table_format = tf.name if isinstance(tf, exp.Literal) else str(tf) + properties = self._build_table_properties_exp( table=table, expression=expression, @@ -205,10 +215,11 @@ def _build_create_table_exp( table_properties=table_properties, table_description=table_description, table_kind=table_kind, + table_format=table_format, **kwargs, ) - is_hive = self._table_type(kwargs.get("table_format", None)) == "hive" + is_hive = self._table_type(table_format) == "hive" # Filter any PARTITIONED BY properties from the main column list since they cant be specified in both places # ref: https://docs.aws.amazon.com/athena/latest/ug/partitions.html @@ -247,17 +258,36 @@ def _build_table_properties_exp( **kwargs: t.Any, ) -> t.Optional[exp.Properties]: properties: t.List[exp.Expr] = [] - table_properties = table_properties or {} + table_properties = table_properties.copy() if table_properties else {} + + s3_table_prop = table_properties.pop("s3_table", None) + is_s3_table = False + if s3_table_prop is not None: + if isinstance(s3_table_prop, exp.Boolean): + is_s3_table = s3_table_prop.this + elif isinstance(s3_table_prop, exp.Literal): + is_s3_table = s3_table_prop.name.lower() in ("true", "1") + else: + is_s3_table = str(s3_table_prop).lower() in ("true", "1") + elif table and table.catalog and table.catalog.startswith("s3tablescatalog/"): + is_s3_table = True + + tf = table_properties.pop("table_format", None) + if not table_format and tf: + table_format = tf.name if isinstance(tf, exp.Literal) else str(tf) is_hive = self._table_type(table_format) == "hive" is_iceberg = not is_hive + if is_s3_table and is_hive: + raise SQLMeshError("Amazon S3 Tables only support the Iceberg format") + if 
is_hive and not expression: # Hive tables are CREATE EXTERNAL TABLE, Iceberg tables are CREATE TABLE # Unless it's a CTAS, those are always CREATE TABLE properties.append(exp.ExternalProperty()) - if table_format: + if table_format and not is_s3_table: properties.append( exp.Property(this=exp.var("table_type"), value=exp.Literal.string(table_format)) ) @@ -279,9 +309,30 @@ def _build_table_properties_exp( else: schema_expressions = partitioned_by - properties.append( - exp.PartitionedByProperty(this=exp.Schema(expressions=schema_expressions)) - ) + if is_hive: + properties.append( + exp.PartitionedByProperty(this=exp.Schema(expressions=schema_expressions)) + ) + else: + if is_s3_table: + array_exprs = [] + for e in schema_expressions: + e_copy = e.copy() + e_copy.transform( + lambda n: n.name if isinstance(n, exp.Identifier) else n, copy=False + ) + expr_sql = e_copy.sql(dialect="athena") + array_exprs.append(exp.Literal.string(expr_sql)) + + properties.append( + exp.Property( + this=exp.var("partitioning"), value=exp.Array(expressions=array_exprs) + ) + ) + else: + properties.append( + exp.PartitionedByProperty(this=exp.Schema(expressions=schema_expressions)) + ) if clustered_by: # Athena itself supports CLUSTERED BY, via the syntax CLUSTERED BY (col) INTO BUCKETS @@ -293,13 +344,16 @@ def _build_table_properties_exp( if storage_format: if is_iceberg: - # TBLPROPERTIES('format'='parquet') - table_properties["format"] = exp.Literal.string(storage_format) + if not is_s3_table or storage_format.lower() == "parquet": + # TBLPROPERTIES('format'='parquet') + table_properties["format"] = exp.Literal.string(storage_format) + elif is_s3_table and storage_format.lower() != "parquet": + raise SQLMeshError("Amazon S3 Tables only support the PARQUET storage format") else: # STORED AS PARQUET properties.append(exp.FileFormatProperty(this=storage_format)) - if table and (location := self._table_location_or_raise(table_properties, table)): + if table and not is_s3_table and 
(location := self._table_location_or_raise(table_properties, table)): properties.append(location) if is_iceberg and expression: # these only work on CTAS, not CREATE TABLE # ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties properties.append(exp.Property(this=exp.var("is_external"), value="false")) - for name, value in table_properties.items(): - properties.append(exp.Property(this=exp.var(name), value=value)) + if not is_s3_table: + for name, value in table_properties.items(): + properties.append(exp.Property(this=exp.var(name), value=value)) + elif is_s3_table: + # According to AWS documentation for S3 Tables CTAS queries: + # "The `table_type` property defaults to `ICEBERG`, so you don't need to explicitly specify it" + # "If you don't specify a format, the system automatically uses `PARQUET`" + # We explicitly prevent all TBLPROPERTIES because Athena doesn't support them during CTAS + if expression: + # the only property allowed in CTAS for S3 Tables is 'format' (which we captured above) + format_val = table_properties.pop("format", exp.Literal.string("PARQUET")) + # Ensure it's uppercase PARQUET for S3 Tables just to be safe as per AWS examples + if isinstance(format_val, exp.Literal) and format_val.name.lower() == "parquet": + format_val = exp.Literal.string("PARQUET") + properties.append(exp.Property(this=exp.var("format"), value=format_val)) + + if table_properties: + logging.warning(f"Ignoring unsupported table properties for S3 Table CTAS: {list(table_properties.keys())}") + else: + # Standard CREATE TABLE for S3 Tables allows properties + for name, value in table_properties.items(): + properties.append(exp.Property(this=exp.var(name), value=value)) if properties: return exp.Properties(expressions=properties) @@ -613,7 +687,7 @@ def _boto3_client(self, name: str) -> t.Any: conn = self.connection return conn.session.client( name, - region_name=conn.region_name, + region_name=conn.region_name, 
config=conn.config, **conn._client_kwargs, ) # type: ignore