@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.spark.format

import org.apache.paimon.table.FormatTable

import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.types.StructType

/**
* Spark 4.0-compatible shadow of the `paimon-spark4-common` `FormatTableBatchWrite`. It is
* compiled against Spark 4.0.2 so that its class file's method table does not carry the
* `commit(.., WriteSummary)` signature added by Spark 4.1's `BatchWrite` default method, avoiding
* the `ClassNotFoundException: WriteSummary` raised by lazy linking during task serialization on
* Spark 4.0.
*/
class FormatTableBatchWrite(
table: FormatTable,
overwriteDynamic: Option[Boolean],
overwritePartitions: Option[Map[String, String]],
writeSchema: StructType)
extends FormatTableBatchWriteBase(table, overwriteDynamic, overwritePartitions, writeSchema)
with BatchWrite
with Serializable {

override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory =
createFormatTableDataWriterFactory()

override def useCommitCoordinator(): Boolean = false

override def commit(messages: Array[WriterCommitMessage]): Unit = commitMessages(messages)

override def abort(messages: Array[WriterCommitMessage]): Unit = abortMessages(messages)
}
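A hedged verification sketch for the claim in the comment above (not part of the change): on a Spark 4.0 classpath, reflecting over this shadow class should succeed precisely because its method table never mentions WriteSummary, and the same reflection is what Java serialization performs before the write is shipped to executors. The class and package names are taken from this diff; wiring the snippet into an actual test module is an assumption.

import org.apache.paimon.spark.format.FormatTableBatchWrite

object ShadowClassLinkCheck {
  def main(args: Array[String]): Unit = {
    val cls = classOf[FormatTableBatchWrite]
    // Class.getDeclaredMethods resolves every parameter type; if this class file
    // referenced Spark 4.1's WriteSummary, the call would fail with
    // NoClassDefFoundError on a 4.0 runtime -- the same lazy-linking path that
    // ObjectStreamClass walks during Spark task serialization.
    cls.getDeclaredMethods.foreach(m => println(m.toGenericString))
    // ObjectStreamClass.lookup mirrors what Java serialization does on the driver
    // before the BatchWrite instance is serialized for the executors.
    val descriptor = java.io.ObjectStreamClass.lookup(cls)
    println(s"serializable descriptor resolved: ${descriptor.getName}")
  }
}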
@@ -18,121 +18,37 @@

package org.apache.paimon.spark.write

import org.apache.paimon.io.{CompactIncrement, DataFileMeta, DataIncrement}
import org.apache.paimon.spark.catalyst.Compatibility
import org.apache.paimon.spark.commands.SparkDataFileMeta
import org.apache.paimon.spark.metric.SparkMetricRegistry
import org.apache.paimon.spark.rowops.PaimonCopyOnWriteScan
import org.apache.paimon.table.FileStoreTable
import org.apache.paimon.table.sink.{BatchWriteBuilder, CommitMessage, CommitMessageImpl}

import org.apache.spark.sql.PaimonSparkSession
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.StructType

import java.util.Collections

import scala.collection.JavaConverters._

case class PaimonBatchWrite(
/**
* Spark-4.0 shadow wrapper. Source-identical to the `paimon-spark4-common` version but compiled
* against Spark 4.0.2; the Maven shade order picks `paimon-spark-4.0/target/classes` ahead of the
* shaded 4-common copy, so the class metadata loaded at runtime does not include the 4.1-only
* `BatchWrite.commit(.., WriteSummary)` signature that triggers `ClassNotFoundException` via
* `ObjectStreamClass.getPrivateMethod` during Spark task serialization.
*/
class PaimonBatchWrite(
table: FileStoreTable,
writeSchema: StructType,
dataSchema: StructType,
overwritePartitions: Option[Map[String, String]],
copyOnWriteScan: Option[PaimonCopyOnWriteScan])
extends BatchWrite
with WriteHelper {

protected val metricRegistry = SparkMetricRegistry()
extends PaimonBatchWriteBase(table, writeSchema, dataSchema, overwritePartitions, copyOnWriteScan)
with BatchWrite
with Serializable {

protected val batchWriteBuilder: BatchWriteBuilder = {
val builder = table.newBatchWriteBuilder()
overwritePartitions.foreach(partitions => builder.withOverwrite(partitions.asJava))
builder
}

override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
(_: Int, _: Long) =>
{
PaimonV2DataWriter(
batchWriteBuilder,
writeSchema,
dataSchema,
coreOptions,
table.catalogEnvironment().catalogContext())
}
}
override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory =
createPaimonDataWriterFactory(info)

override def useCommitCoordinator(): Boolean = false

override def commit(messages: Array[WriterCommitMessage]): Unit = {
logInfo(s"Committing to table ${table.name()}")
val batchTableCommit = batchWriteBuilder.newCommit()
batchTableCommit.withMetricRegistry(metricRegistry)
val addCommitMessage = WriteTaskResult.merge(messages)
val deletedCommitMessage = copyOnWriteScan match {
case Some(scan) => buildDeletedCommitMessage(scan.scannedFiles)
case None => Seq.empty
}
val commitMessages = addCommitMessage ++ deletedCommitMessage
try {
val start = System.currentTimeMillis()
batchTableCommit.commit(commitMessages.asJava)
logInfo(s"Committed in ${System.currentTimeMillis() - start} ms")
} finally {
batchTableCommit.close()
}
postDriverMetrics()
postCommit(commitMessages)
}

// Spark support v2 write driver metrics since 4.0, see https://github.com/apache/spark/pull/48573
// To ensure compatibility with 3.x, manually post driver metrics here instead of using Spark's API.
protected def postDriverMetrics(): Unit = {
val spark = PaimonSparkSession.active
// todo: find a more suitable way to get metrics.
val commitMetrics = metricRegistry.buildSparkCommitMetrics()
val executionId = spark.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
val executionMetrics = Compatibility.getExecutionMetrics(spark, executionId.toLong).distinct
val metricUpdates = executionMetrics.flatMap {
m =>
commitMetrics.find(x => m.metricType.toLowerCase.contains(x.name.toLowerCase)) match {
case Some(customTaskMetric) => Some((m.accumulatorId, customTaskMetric.value()))
case None => None
}
}
SQLMetrics.postDriverMetricsUpdatedByValue(spark.sparkContext, executionId, metricUpdates)
}
override def commit(messages: Array[WriterCommitMessage]): Unit = commitMessages(messages)

override def abort(messages: Array[WriterCommitMessage]): Unit = {
// TODO clean uncommitted files
}

private def buildDeletedCommitMessage(
deletedFiles: Seq[SparkDataFileMeta]): Seq[CommitMessage] = {
logInfo(s"[V2 Write] Building deleted commit message for ${deletedFiles.size} files")
deletedFiles
.groupBy(f => (f.partition, f.bucket))
.map {
case ((partition, bucket), files) =>
val deletedDataFileMetas = files.map(_.dataFileMeta).toList.asJava

new CommitMessageImpl(
partition,
bucket,
files.head.totalBuckets,
new DataIncrement(
Collections.emptyList[DataFileMeta],
deletedDataFileMetas,
Collections.emptyList[DataFileMeta]),
new CompactIncrement(
Collections.emptyList[DataFileMeta],
Collections.emptyList[DataFileMeta],
Collections.emptyList[DataFileMeta])
)
}
.toSeq
}
}
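A hedged sanity-check sketch for the shade-order point in the comment above: printing the code source shows which copy of the class the JVM actually resolved, which is useful while both module outputs are still on a development classpath. In a fully shaded jar only one copy survives, so the decisive check remains whether the method table references WriteSummary at all. The class name comes from this diff; where such a snippet would live is an assumption.

import org.apache.paimon.spark.write.PaimonBatchWrite

object ShadeOrderCheck {
  def main(args: Array[String]): Unit = {
    val cls = classOf[PaimonBatchWrite]
    // getCodeSource can be null for classes loaded from the bootstrap classpath,
    // so guard it; for an application class it normally points at the jar or
    // target/classes directory the class was resolved from.
    val location = Option(cls.getProtectionDomain.getCodeSource).map(_.getLocation.toString)
    println(s"${cls.getName} loaded from ${location.getOrElse("<no code source>")}")
  }
}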
@@ -22,6 +22,10 @@ import org.apache.paimon.data.variant.{GenericVariant, Variant}
import org.apache.paimon.spark.catalyst.analysis.Spark4ResolutionRules
import org.apache.paimon.spark.catalyst.parser.extensions.PaimonSpark4SqlExtensionsParser
import org.apache.paimon.spark.data.{Spark4ArrayData, Spark4InternalRow, Spark4InternalRowWithBlob, SparkArrayData, SparkInternalRow}
import org.apache.paimon.spark.format.FormatTableBatchWrite
import org.apache.paimon.spark.rowops.PaimonCopyOnWriteScan
import org.apache.paimon.spark.write.PaimonBatchWrite
import org.apache.paimon.table.{FileStoreTable, FormatTable}
import org.apache.paimon.types.{DataType, RowType}

import org.apache.hadoop.conf.Configuration
@@ -38,6 +42,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, Table, TableCatalog}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.write.BatchWrite
import org.apache.spark.sql.execution.SparkFormatTable
import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitionSpec}
import org.apache.spark.sql.execution.streaming.{FileStreamSink, MetadataLogFileIndex}
@@ -101,6 +106,21 @@ class Spark4Shim extends SparkShim {
tableCatalog.createTable(ident, columns, partitions, properties)
}

override def createPaimonBatchWrite(
table: FileStoreTable,
writeSchema: StructType,
dataSchema: StructType,
overwritePartitions: Option[Map[String, String]],
copyOnWriteScan: Option[PaimonCopyOnWriteScan]): BatchWrite =
new PaimonBatchWrite(table, writeSchema, dataSchema, overwritePartitions, copyOnWriteScan)

override def createFormatTableBatchWrite(
table: FormatTable,
overwriteDynamic: Option[Boolean],
overwritePartitions: Option[Map[String, String]],
writeSchema: StructType): BatchWrite =
new FormatTableBatchWrite(table, overwriteDynamic, overwritePartitions, writeSchema)

override def createCTERelationRef(
cteId: Long,
resolved: Boolean,
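A hedged usage sketch for the two shim hooks added above: a version-agnostic write builder in the common module would ask the SparkShim for the concrete BatchWrite instead of instantiating one directly, so only the shim implementation references the per-version classes. The SparkShim import path and the way the shim instance reaches the caller are assumptions; the method names and parameter lists match the overrides in this diff.

import org.apache.paimon.spark.SparkShim // package assumed for illustration
import org.apache.paimon.spark.rowops.PaimonCopyOnWriteScan
import org.apache.paimon.table.FileStoreTable
import org.apache.spark.sql.connector.write.BatchWrite
import org.apache.spark.sql.types.StructType

object ShimBatchWriteUsage {
  def toBatch(
      shim: SparkShim,
      table: FileStoreTable,
      writeSchema: StructType,
      dataSchema: StructType,
      overwritePartitions: Option[Map[String, String]],
      copyOnWriteScan: Option[PaimonCopyOnWriteScan]): BatchWrite = {
    // Only the shim implementation (e.g. the Spark4Shim above) references the
    // version-specific PaimonBatchWrite class, so the common module never loads a
    // class whose method table mentions 4.1-only types such as WriteSummary.
    shim.createPaimonBatchWrite(table, writeSchema, dataSchema, overwritePartitions, copyOnWriteScan)
  }
}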
@@ -18,89 +18,38 @@

package org.apache.paimon.spark.format

import org.apache.paimon.format.csv.CsvOptions
import org.apache.paimon.spark.{BaseTable, FormatTableScanBuilder, SparkInternalRowWrapper}
import org.apache.paimon.spark.write.{BaseV2WriteBuilder, FormatTableWriteTaskResult, V2DataWrite, WriteTaskResult}
import org.apache.paimon.spark.SparkInternalRowWrapper
import org.apache.paimon.spark.write.{FormatTableWriteTaskResult, V2DataWrite, WriteTaskResult}
import org.apache.paimon.table.FormatTable
import org.apache.paimon.table.sink.{BatchTableWrite, BatchWriteBuilder, CommitMessage}
import org.apache.paimon.types.RowType

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, TableCapability, TableCatalog}
import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, OVERWRITE_BY_FILTER, OVERWRITE_DYNAMIC}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
import org.apache.spark.sql.connector.write.{DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

import java.util
import java.util.Locale

import scala.collection.JavaConverters._

case class PaimonFormatTable(table: FormatTable)
extends BaseTable
with SupportsRead
with SupportsWrite {

override def capabilities(): util.Set[TableCapability] = {
util.EnumSet.of(BATCH_READ, BATCH_WRITE, OVERWRITE_DYNAMIC, OVERWRITE_BY_FILTER)
}

override def properties: util.Map[String, String] = {
val properties = new util.HashMap[String, String](table.options())
properties.put(TableCatalog.PROP_PROVIDER, table.format.name().toLowerCase(Locale.ROOT))
if (table.comment.isPresent) {
properties.put(TableCatalog.PROP_COMMENT, table.comment.get)
}
if (FormatTable.Format.CSV == table.format) {
properties.put(
"sep",
properties.getOrDefault(
CsvOptions.FIELD_DELIMITER.key(),
CsvOptions.FIELD_DELIMITER.defaultValue()))
}
properties
}

override def newScanBuilder(caseInsensitiveStringMap: CaseInsensitiveStringMap): ScanBuilder = {
val scanBuilder = FormatTableScanBuilder(table.copy(caseInsensitiveStringMap))
scanBuilder.pruneColumns(schema)
scanBuilder
}

override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
PaimonFormatTableWriterBuilder(table, info.schema)
}
}

case class PaimonFormatTableWriterBuilder(table: FormatTable, writeSchema: StructType)
extends BaseV2WriteBuilder(table) {

override def partitionRowType(): RowType = table.partitionType

override def build: Write = new Write() {
override def toBatch: BatchWrite = {
FormatTableBatchWrite(table, overwriteDynamic, overwritePartitions, writeSchema)
}

override def toStreaming: StreamingWrite = {
throw new UnsupportedOperationException("FormatTable doesn't support streaming write")
}
}
}

private case class FormatTableBatchWrite(
/**
* Business logic for `FormatTable` batch writes, deliberately *not* extending
* `org.apache.spark.sql.connector.write.BatchWrite`. See
* [[org.apache.paimon.spark.write.PaimonBatchWriteBase]] for the full rationale: Spark 4.1 added a
* default method `BatchWrite.commit(.., WriteSummary)` whose `WriteSummary` parameter type is
* unavailable on Spark 4.0, so a class compiled against 4.1 that mixes in `BatchWrite` triggers a
* `ClassNotFoundException: WriteSummary` through lazy linking on 4.0 runtimes during Spark task
* serialization. Keeping this base off `BatchWrite` lets the common module ship the implementation
* once; per-version `paimon-spark{3,4}-common` modules supply a thin wrapper that mixes in
* `BatchWrite`, and `paimon-spark-4.0/src/main` shadows that wrapper at the 4.0.2 compile target.
*/
abstract class FormatTableBatchWriteBase(
table: FormatTable,
overwriteDynamic: Option[Boolean],
overwritePartitions: Option[Map[String, String]],
writeSchema: StructType)
extends BatchWrite
with Logging {
extends Logging
with Serializable {

private val batchWriteBuilder = {
protected val batchWriteBuilder: BatchWriteBuilder = {
val builder = table.newBatchWriteBuilder()
// todo: add test for static overwrite the whole table
if (overwriteDynamic.contains(true)) {
@@ -111,13 +60,11 @@ private case class FormatTableBatchWrite(
builder
}

override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
protected def createFormatTableDataWriterFactory(): DataWriterFactory = {
(_: Int, _: Long) => new FormatTableDataWriter(batchWriteBuilder, writeSchema)
}

override def useCommitCoordinator(): Boolean = false

override def commit(messages: Array[WriterCommitMessage]): Unit = {
protected def commitMessages(messages: Array[WriterCommitMessage]): Unit = {
logInfo(s"Committing to FormatTable ${table.name()}")
val batchTableCommit = batchWriteBuilder.newCommit()
val commitMessages = WriteTaskResult.merge(messages).asJava
@@ -132,7 +79,7 @@ }
}
}

override def abort(messages: Array[WriterCommitMessage]): Unit = {
protected def abortMessages(messages: Array[WriterCommitMessage]): Unit = {
logInfo(s"Aborting write to FormatTable ${table.name()}")
val batchTableCommit = batchWriteBuilder.newCommit()
val commitMessages = WriteTaskResult.merge(messages).asJava
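A hedged sketch of the per-version wrapper described in the class comment above: a Spark 3.x module would mix BatchWrite into this base the same way the Spark 4.0 shadow class at the top of this diff does. The wrapper's class name and module placement are assumptions; the constructor and the protected members it delegates to come from FormatTableBatchWriteBase in this diff.

import org.apache.paimon.spark.format.FormatTableBatchWriteBase
import org.apache.paimon.table.FormatTable
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.types.StructType

// Hypothetical Spark 3.x wrapper: only this thin class is compiled against the
// 3.x BatchWrite interface, while all write/commit logic stays in the shared base.
class FormatTableBatchWrite3x(
    table: FormatTable,
    overwriteDynamic: Option[Boolean],
    overwritePartitions: Option[Map[String, String]],
    writeSchema: StructType)
  extends FormatTableBatchWriteBase(table, overwriteDynamic, overwritePartitions, writeSchema)
  with BatchWrite
  with Serializable {

  override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory =
    createFormatTableDataWriterFactory()

  override def useCommitCoordinator(): Boolean = false

  // commit/abort delegate to the protected helpers on the shared base, so the
  // version-specific class file stays as small as possible.
  override def commit(messages: Array[WriterCommitMessage]): Unit = commitMessages(messages)

  override def abort(messages: Array[WriterCommitMessage]): Unit = abortMessages(messages)
}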