Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 92 additions & 18 deletions sqlmesh/core/engine_adapter/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,9 @@ def s3_warehouse_location_or_raise(self) -> str:

@property
def catalog_support(self) -> CatalogSupport:
# Athena has the concept of catalogs but the current catalog is set in the connection parameters with no way to query or change it after that
# It also cant create new catalogs, you have to configure them in AWS. Typically, catalogs that are not "awsdatacatalog"
# are pointers to the "awsdatacatalog" of other AWS accounts
return CatalogSupport.SINGLE_CATALOG_ONLY
# Athena supports querying and writing to multiple catalogs (e.g. awsdatacatalog and s3tablescatalog)
# without needing a SET CATALOG command.
return CatalogSupport.FULL_SUPPORT

def create_state_table(
self,
Expand All @@ -105,6 +104,9 @@ def _get_data_objects(
"""
schema_name = to_schema(schema_name)
schema = schema_name.db

info_schema_tables = exp.table_("tables", db="information_schema", catalog=schema_name.catalog, alias="t")

query = (
exp.select(
exp.column("table_catalog").as_("catalog"),
Expand All @@ -118,7 +120,7 @@ def _get_data_objects(
.else_(exp.column("table_type", table="t"))
.as_("type"),
)
.from_(exp.to_table("information_schema.tables", alias="t"))
.from_(info_schema_tables)
.where(exp.column("table_schema", table="t").eq(schema))
)
if object_names:
Expand All @@ -141,9 +143,12 @@ def columns(
) -> t.Dict[str, exp.DataType]:
table = exp.to_table(table_name)
# note: the data_type column contains the full parameterized type, eg 'varchar(10)'

info_schema_columns = exp.table_("columns", db="information_schema", catalog=table.catalog)

query = (
exp.select("column_name", "data_type")
.from_("information_schema.columns")
.from_(info_schema_columns)
.where(exp.column("table_schema").eq(table.db), exp.column("table_name").eq(table.name))
.order_by("ordinal_position")
)
Expand Down Expand Up @@ -197,6 +202,11 @@ def _build_create_table_exp(
else:
table = table_name_or_schema

table_format = kwargs.pop("table_format", None)
if not table_format and table_properties and "table_format" in table_properties:
tf = table_properties.get("table_format")
table_format = tf.name if isinstance(tf, exp.Literal) else str(tf)

properties = self._build_table_properties_exp(
table=table,
expression=expression,
Expand All @@ -205,10 +215,11 @@ def _build_create_table_exp(
table_properties=table_properties,
table_description=table_description,
table_kind=table_kind,
table_format=table_format,
**kwargs,
)

is_hive = self._table_type(kwargs.get("table_format", None)) == "hive"
is_hive = self._table_type(table_format) == "hive"

# Filter any PARTITIONED BY properties from the main column list since they can't be specified in both places
# ref: https://docs.aws.amazon.com/athena/latest/ug/partitions.html
Expand Down Expand Up @@ -247,17 +258,36 @@ def _build_table_properties_exp(
**kwargs: t.Any,
) -> t.Optional[exp.Properties]:
properties: t.List[exp.Expr] = []
table_properties = table_properties or {}
table_properties = table_properties.copy() if table_properties else {}

s3_table_prop = table_properties.pop("s3_table", None)
is_s3_table = False
if s3_table_prop is not None:
if isinstance(s3_table_prop, exp.Boolean):
is_s3_table = s3_table_prop.this
elif isinstance(s3_table_prop, exp.Literal):
is_s3_table = s3_table_prop.name.lower() in ("true", "1")
else:
is_s3_table = str(s3_table_prop).lower() in ("true", "1")
elif table and table.catalog and table.catalog.startswith("s3tablescatalog/"):
is_s3_table = True

tf = table_properties.pop("table_format", None)
if not table_format and tf:
table_format = tf.name if isinstance(tf, exp.Literal) else str(tf)

is_hive = self._table_type(table_format) == "hive"
is_iceberg = not is_hive

if is_s3_table and is_hive:
raise SQLMeshError("Amazon S3 Tables only support the Iceberg format")

if is_hive and not expression:
# Hive tables are CREATE EXTERNAL TABLE, Iceberg tables are CREATE TABLE
# Unless it's a CTAS, those are always CREATE TABLE
properties.append(exp.ExternalProperty())

if table_format:
if table_format and not is_s3_table:
properties.append(
exp.Property(this=exp.var("table_type"), value=exp.Literal.string(table_format))
)
Expand All @@ -279,9 +309,30 @@ def _build_table_properties_exp(
else:
schema_expressions = partitioned_by

properties.append(
exp.PartitionedByProperty(this=exp.Schema(expressions=schema_expressions))
)
if is_hive:
properties.append(
exp.PartitionedByProperty(this=exp.Schema(expressions=schema_expressions))
)
else:
if is_s3_table:
array_exprs = []
for e in schema_expressions:
e_copy = e.copy()
e_copy.transform(
lambda n: n.name if isinstance(n, exp.Identifier) else n, copy=False
)
expr_sql = e_copy.sql(dialect="athena")
array_exprs.append(exp.Literal.string(expr_sql))

properties.append(
exp.Property(
this=exp.var("partitioning"), value=exp.Array(expressions=array_exprs)
)
)
else:
properties.append(
exp.PartitionedByProperty(this=exp.Schema(expressions=schema_expressions))
)

if clustered_by:
# Athena itself supports CLUSTERED BY, via the syntax CLUSTERED BY (col) INTO <n> BUCKETS
Expand All @@ -293,13 +344,16 @@ def _build_table_properties_exp(

if storage_format:
if is_iceberg:
# TBLPROPERTIES('format'='parquet')
table_properties["format"] = exp.Literal.string(storage_format)
if not is_s3_table or storage_format.lower() == "parquet":
# TBLPROPERTIES('format'='parquet')
table_properties["format"] = exp.Literal.string(storage_format)
elif is_s3_table and storage_format.lower() != "parquet":
raise SQLMeshError("Amazon S3 Tables only support the PARQUET storage format")
else:
# STORED AS PARQUET
properties.append(exp.FileFormatProperty(this=storage_format))

if table and (location := self._table_location_or_raise(table_properties, table)):
if table and not is_s3_table and (location := self._table_location_or_raise(table_properties, table)):
properties.append(location)

if is_iceberg and expression:
Expand All @@ -308,8 +362,28 @@ def _build_table_properties_exp(
# ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties
properties.append(exp.Property(this=exp.var("is_external"), value="false"))

for name, value in table_properties.items():
properties.append(exp.Property(this=exp.var(name), value=value))
if not is_s3_table:
for name, value in table_properties.items():
properties.append(exp.Property(this=exp.var(name), value=value))
elif is_s3_table:
# According to AWS documentation for S3 Tables CTAS queries:
# "The `table_type` property defaults to `ICEBERG`, so you don't need to explicitly specify it"
# "If you don't specify a format, the system automatically uses `PARQUET`"
# We explicitly prevent all TBLPROPERTIES because Athena doesn't support them during CTAS
if expression:
# the only property allowed in CTAS for S3 Tables is 'format' (which we captured above)
format_val = table_properties.pop("format", exp.Literal.string("PARQUET"))
# Ensure it's uppercase PARQUET for S3 Tables just to be safe as per AWS examples
if isinstance(format_val, exp.Literal) and format_val.name.lower() == "parquet":
format_val = exp.Literal.string("PARQUET")
properties.append(exp.Property(this=exp.var("format"), value=format_val))

if table_properties:
logging.warning(f"Ignoring unsupported table properties for S3 Table CTAS: {list(table_properties.keys())}")
else:
# Standard CREATE TABLE for S3 Tables allows properties
for name, value in table_properties.items():
properties.append(exp.Property(this=exp.var(name), value=value))

if properties:
return exp.Properties(expressions=properties)
Expand Down Expand Up @@ -613,7 +687,7 @@ def _boto3_client(self, name: str) -> t.Any:
conn = self.connection
return conn.session.client(
name,
region_name=conn.region_name,
region_name=conn.region_name,
config=conn.config,
**conn._client_kwargs,
) # type: ignore
Expand Down