diff --git a/pom.xml b/pom.xml
index 960343d01..64299d059 100644
--- a/pom.xml
+++ b/pom.xml
@@ -50,7 +50,7 @@ azure-eventhubs-spark-parent 2.11 - 2.3.3 + 2.4.7 github 1.8 1g
@@ -58,6 +58,7 @@ core + schemaregistry-avro
diff --git a/schemaregistry-avro/README.md b/schemaregistry-avro/README.md
new file mode 100644
index 000000000..b9e567ebc
--- /dev/null
+++ b/schemaregistry-avro/README.md
@@ -0,0 +1,88 @@
+# azure-schemaregistry-spark (WIP)
+
+## Overview
+
+Schema Registry support in Java is provided by the official Schema Registry SDK in the Azure Java SDK repository.
+
+Schema Registry serializers craft payloads that contain a schema ID and an encoded payload. The ID references a registry-stored schema that can be used to decode the user-specified payload.
+
+However, consuming Schema Registry-backed payloads in Spark is particularly difficult, since -
+- Spark Kafka does not support plugging in KafkaSerializer and KafkaDeserializer objects, and
+- Object management is non-trivial given Spark's driver-executor model.
+
+For these reasons, Spark functions are required to simplify the Schema Registry UX in Spark. This repository contains packages that will provide Spark support in Scala for serialization and deserialization of registry-backed payloads. Code is work in progress.
+
+Currently, only Avro encodings are supported by Azure Schema Registry clients. `from_avro` and `to_avro`, found in `functions.scala`, will be usable for converting Spark SQL columns from registry-backed payloads to columns of the correct Spark SQL datatype (e.g. `StringType`, `StructType`, etc.).
+
+## Usage
+
+Compile the JAR and build with dependencies using the following Maven command:
+```bash
+mvn clean compile assembly:single
+```
+
+The JAR can then be uploaded to your Databricks environment without additional required dependencies. If using `spark-submit`, use the `--jars` option to submit the path of the custom JAR.
+
+Typical Spark/Databricks usage looks like the following:
+
+```scala
+import java.util.HashMap
+import com.microsoft.azure.schemaregistry.spark.avro.functions._
+import spark.implicits._
+
+val props: HashMap[String, String] = new HashMap()
+props.put("schema.registry.url", SCHEMA_REGISTRY_URL)
+props.put("schema.registry.tenant.id", SCHEMA_REGISTRY_TENANT_ID)
+props.put("schema.registry.client.id", SCHEMA_REGISTRY_CLIENT_ID)
+props.put("schema.registry.client.secret", SCHEMA_REGISTRY_CLIENT_SECRET)
+
+val df = spark.readStream
+  .format("kafka")
+  .option("subscribe", TOPIC)
+  .option("kafka.bootstrap.servers", BOOTSTRAP_SERVERS)
+  .option("kafka.sasl.mechanism", "PLAIN")
+  .option("kafka.security.protocol", "SASL_SSL")
+  .option("kafka.sasl.jaas.config", EH_SASL)
+  .option("kafka.request.timeout.ms", "60000")
+  .option("kafka.session.timeout.ms", "60000")
+  .option("failOnDataLoss", "false")
+  .option("startingOffsets", "earliest")
+  .option("kafka.group.id", "kafka-group")
+  .load()
+
+// from_avro() arguments:
+//   Spark SQL Column
+//   schema GUID
+//   properties for communicating with the SR service (see props above)
+df.select(from_avro($"value", "[schema guid]", props))
+  .writeStream
+  .outputMode("append")
+  .format("console")
+  .start()
+  .awaitTermination()
+```
+
+## Schema Evolution
+
+In the context of stream processing, the primary use case is where the schema GUID references a schema matching the data in the stream.
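+
+As a concrete illustration (a minimal sketch reusing `df`, `props`, and the `[schema guid]` placeholder from the Usage example above), the exact-match requirement discussed below can be kept or relaxed per call:
+
+```scala
+// Primary case: the payload schema referenced by the GUID matches the stream exactly.
+val exact = df.select(from_avro($"value", "[schema guid]", props))
+
+// Relaxed case: also accept backwards-compatible payloads (see the edge cases below).
+val relaxed = df.select(from_avro($"value", "[schema guid]", props, requireExactSchemaMatch = false))
+```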
+
+However, there are two edge cases that will be common in streaming scenarios where schema evolution is a concern -
+- Stream jobs reading old data with new schemas - only backwards-compatible data will be readable, meaning that fields may be null.
+- Stream jobs reading new data with old schemas - even if the Spark job schema is forwards compatible with the new schema, projecting data written with the new schema onto the old one will result in data loss if additional fields have been added.
+
+To handle the more dangerous second case, the Spark functions will throw if incoming data contains fields that cannot be captured by the existing schema. This behavior is based on the assumption that perceived data loss is unacceptable.
+
+To handle the first case, a parameter called `requireExactSchemaMatch` will be introduced:
+- If true, the job will throw if the schema in the payload is not an exact match to the Spark-specified schema. This allows users to specify that their pipeline contains only one schema.
+- If false, the job will attempt to read the incoming data. In the case of upgraded consumers reading backwards-compatible schemas, the job will be able to read the data properly (nullable deleted fields, newly added optional fields).
+
+## Failure Modes
+
+Two modes will be supported as dictated by Spark SQL -
+- `FailFastMode` - fail on catching any exception
+- `PermissiveMode` - continue processing if parsing exceptions are caught (currently unsupported)
+
+Customers will be able to configure the stream with a specific failure mode, but the default will be `FailFastMode` to prevent perceived data loss with `PermissiveMode`.
+
+See also:
+- aka.ms/schemaregistry
+- https://github.com/Azure/azure-schema-registry-for-kafka
diff --git a/schemaregistry-avro/pom.xml b/schemaregistry-avro/pom.xml
new file mode 100644
index 000000000..fda640079
--- /dev/null
+++ b/schemaregistry-avro/pom.xml
@@ -0,0 +1,57 @@ + + 4.0.0 + com.microsoft.azure + azure-schemaregistry-spark-avro_${scala.binary.version} + 1.0.0-beta.4 + ${project.artifactId} + + + com.microsoft.azure + azure-eventhubs-spark-parent_${scala.binary.version} + 2.3.18 + ../pom.xml + + jar + + + 1.8 + 1.8 + UTF-8 + + + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${spark.version} + test-jar + test + + + com.azure + azure-data-schemaregistry + 1.0.0-beta.4 + + + com.azure + azure-data-schemaregistry-avro + 1.0.0-beta.4 + + + com.azure + azure-identity + 1.1.3 + + + + + + + maven-assembly-plugin + + false + + + + +
diff --git a/schemaregistry-avro/src/main/scala/com/microsoft/azure/schemaregistry/spark/avro/AvroDataToCatalyst.scala b/schemaregistry-avro/src/main/scala/com/microsoft/azure/schemaregistry/spark/avro/AvroDataToCatalyst.scala
new file mode 100644
index 000000000..105312345
--- /dev/null
+++ b/schemaregistry-avro/src/main/scala/com/microsoft/azure/schemaregistry/spark/avro/AvroDataToCatalyst.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.azure.schemaregistry.spark.avro + +import java.io.ByteArrayInputStream + +import com.azure.core.util.serializer.TypeReference +import com.azure.data.schemaregistry.SchemaRegistryClientBuilder +import com.azure.data.schemaregistry.avro.{SchemaRegistryAvroSerializerBuilder} +import com.azure.identity.ClientSecretCredentialBuilder +import org.apache.avro.Schema +import org.apache.avro.generic.GenericRecord +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, SpecificInternalRow, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.types._ + +import scala.util.control.NonFatal + +case class AvroDataToCatalyst( + child: Expression, + schemaId: String, + options: Map[java.lang.String, java.lang.String], + requireExactSchemaMatch: Boolean) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[BinaryType] = Seq(BinaryType) + + override lazy val dataType: DataType = { + val dt = SchemaConverters.toSqlType(new Schema.Parser().parse(expectedSchemaString)).dataType; + dt + } + + override def nullable: Boolean = true + + private val expectedSchemaString : String = { + new String(schemaRegistryAsyncClient.getSchema(schemaId).block().getSchema) + } + + @transient private lazy val schemaRegistryCredential = new ClientSecretCredentialBuilder() + .tenantId(options.getOrElse("schema.registry.tenant.id", null)) + .clientId(options.getOrElse("schema.registry.client.id", null)) + .clientSecret(options.getOrElse("schema.registry.client.secret", null)) + .build() + + @transient private lazy val schemaRegistryAsyncClient = new SchemaRegistryClientBuilder() + .endpoint(options.getOrElse("schema.registry.url", null)) + .credential(schemaRegistryCredential) + .buildAsyncClient() + + @transient private lazy val deserializer = new SchemaRegistryAvroSerializerBuilder() + .schemaRegistryAsyncClient(schemaRegistryAsyncClient) + .schemaGroup(options.getOrElse("schema.group", null)) + .autoRegisterSchema(options.getOrElse("specific.avro.reader", false).asInstanceOf[Boolean]) + .buildSerializer() + + @transient private lazy val avroConverter = { + new AvroDeserializer(new Schema.Parser().parse(expectedSchemaString), dataType) + } + + @transient private lazy val expectedSchema = new Schema.Parser().parse(expectedSchemaString) + + @transient private lazy val parseMode: ParseMode = { + val mode = options.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode) + if (mode != PermissiveMode && mode != FailFastMode) { + throw new IllegalArgumentException(mode + "parse mode not supported.") + } + mode + } + + @transient private lazy val nullResultRow: Any = dataType match { + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + for(i <- 0 until st.length) { + resultRow.setNullAt(i) + } + resultRow + + case _ => + null + } + + override def nullSafeEval(input: Any): Any = { + try { + val binary = new 
ByteArrayInputStream(input.asInstanceOf[Array[Byte]]) + // compare schema version and datatype version + val genericRecord = deserializer.deserialize(binary, TypeReference.createInstance(classOf[GenericRecord])) + + if (requireExactSchemaMatch) { + if (!expectedSchema.equals(genericRecord.getSchema)) { + throw new IncompatibleSchemaException(s"Schema not exact match, payload schema did not match expected schema. Payload schema: ${genericRecord.getSchema}") + } + } + + avroConverter.deserialize(genericRecord) + } catch { + case NonFatal(e) => parseMode match { + case PermissiveMode => nullResultRow + case FailFastMode => + throw new Exception("Malformed records are detected in record parsing. " + + s"Current parse Mode: ${FailFastMode.name}. To process malformed records as null " + + "result, try setting the option 'mode' as 'PERMISSIVE'.", e) + case _ => + throw new Exception(s"Unknown parse mode: ${parseMode.name}") + } + } + } + + override def prettyName: String = "from_avro" + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val expr = ctx.addReferenceObj("this", this) + defineCodeGen(ctx, ev, input => + s"(${CodeGenerator.boxedType(dataType)})$expr.nullSafeEval($input)") + } +} \ No newline at end of file diff --git a/schemaregistry-avro/src/main/scala/com/microsoft/azure/schemaregistry/spark/avro/AvroDeserializer.scala b/schemaregistry-avro/src/main/scala/com/microsoft/azure/schemaregistry/spark/avro/AvroDeserializer.scala new file mode 100644 index 000000000..dde932f23 --- /dev/null +++ b/schemaregistry-avro/src/main/scala/com/microsoft/azure/schemaregistry/spark/avro/AvroDeserializer.scala @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.azure.schemaregistry.spark.avro + +import java.math.BigDecimal +import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder} +import org.apache.avro.Conversions.DecimalConversion +import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis} +import org.apache.avro.Schema.Type._ +import org.apache.avro.generic._ +import org.apache.avro.util.Utf8 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A deserializer to deserialize data in avro format to data in catalyst format. 
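+ *
+ * A converter function is built once from the writer's Avro schema (`rootAvroType`) and the
+ * target Catalyst type (`rootCatalystType`); `deserialize` then applies that converter to each
+ * decoded record or value.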
+ */ +class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { + private lazy val decimalConversions = new DecimalConversion() + + private val converter: Any => Any = rootCatalystType match { + // A shortcut for empty schema. + case st: StructType if st.isEmpty => + (data: Any) => InternalRow.empty + + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + val fieldUpdater = new RowUpdater(resultRow) + val writer = getRecordWriter(rootAvroType, st, Nil) + (data: Any) => { + val record = data.asInstanceOf[GenericRecord] + writer(fieldUpdater, record) + resultRow + } + + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val fieldUpdater = new RowUpdater(tmpRow) + val writer = newWriter(rootAvroType, rootCatalystType, Nil) + (data: Any) => { + writer(fieldUpdater, 0, data) + tmpRow.get(0, rootCatalystType) + } + } + + def deserialize(data: Any): Any = converter(data) + + /** + * Creates a writer to write avro values to Catalyst values at the given ordinal with the given + * updater. + */ + private def newWriter( + avroType: Schema, + catalystType: DataType, + path: List[String]): (CatalystDataUpdater, Int, Any) => Unit = + (avroType.getType, catalystType) match { + case (NULL, NullType) => (updater, ordinal, _) => + updater.setNullAt(ordinal) + + // TODO: we can avoid boxing if future version of avro provide primitive accessors. + case (BOOLEAN, BooleanType) => (updater, ordinal, value) => + updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) + + case (INT, IntegerType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (INT, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (LONG, LongType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + + case (LONG, TimestampType) => avroType.getLogicalType match { + case _: TimestampMillis => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long] * 1000) + case _: TimestampMicros => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + case null => (updater, ordinal, value) => + // For backward compatibility, if the Avro type is Long and it is not logical type, + // the value is processed as timestamp type with millisecond precision. + updater.setLong(ordinal, value.asInstanceOf[Long] * 1000) + case other => throw new IncompatibleSchemaException( + s"Cannot convert Avro logical type ${other} to Catalyst Timestamp type.") + } + + // Before we upgrade Avro to 1.8 for logical type support, spark-avro converts Long to Date. + // For backward compatibility, we still keep this conversion. 
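+      // The Long value is interpreted as epoch milliseconds and truncated to whole days.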
+ case (LONG, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, (value.asInstanceOf[Long] / DateTimeUtils.MILLIS_PER_DAY).toInt) + + case (FLOAT, FloatType) => (updater, ordinal, value) => + updater.setFloat(ordinal, value.asInstanceOf[Float]) + + case (DOUBLE, DoubleType) => (updater, ordinal, value) => + updater.setDouble(ordinal, value.asInstanceOf[Double]) + + case (STRING, StringType) => (updater, ordinal, value) => + val str = value match { + case s: String => UTF8String.fromString(s) + case s: Utf8 => + val bytes = new Array[Byte](s.getByteLength) + System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength) + UTF8String.fromBytes(bytes) + } + updater.set(ordinal, str) + + case (ENUM, StringType) => (updater, ordinal, value) => + updater.set(ordinal, UTF8String.fromString(value.toString)) + + case (FIXED, BinaryType) => (updater, ordinal, value) => + updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone()) + + case (BYTES, BinaryType) => (updater, ordinal, value) => + val bytes = value match { + case b: ByteBuffer => + val bytes = new Array[Byte](b.remaining) + b.get(bytes) + bytes + case b: Array[Byte] => b + case other => throw new RuntimeException(s"$other is not a valid avro binary.") + } + updater.set(ordinal, bytes) + + case (FIXED, d: DecimalType) => (updater, ordinal, value) => + val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType, + LogicalTypes.decimal(d.precision, d.scale)) + val decimal = createDecimal(bigDecimal, d.precision, d.scale) + updater.setDecimal(ordinal, decimal) + + case (BYTES, d: DecimalType) => (updater, ordinal, value) => + val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType, + LogicalTypes.decimal(d.precision, d.scale)) + val decimal = createDecimal(bigDecimal, d.precision, d.scale) + updater.setDecimal(ordinal, decimal) + + case (RECORD, st: StructType) => + val writeRecord = getRecordWriter(avroType, st, path) + (updater, ordinal, value) => + val row = new SpecificInternalRow(st) + writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord]) + updater.set(ordinal, row) + + case (ARRAY, ArrayType(elementType, containsNull)) => + val elementWriter = newWriter(avroType.getElementType, elementType, path) + (updater, ordinal, value) => + val array = value.asInstanceOf[GenericData.Array[Any]] + val len = array.size() + val result = createArrayData(elementType, len) + val elementUpdater = new ArrayDataUpdater(result) + + var i = 0 + while (i < len) { + val element = array.get(i) + if (element == null) { + if (!containsNull) { + throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + elementUpdater.setNullAt(i) + } + } else { + elementWriter(elementUpdater, i, element) + } + i += 1 + } + + updater.set(ordinal, result) + + case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == StringType => + val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType, path) + val valueWriter = newWriter(avroType.getValueType, valueType, path) + (updater, ordinal, value) => + val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]] + val keyArray = createArrayData(keyType, map.size()) + val keyUpdater = new ArrayDataUpdater(keyArray) + val valueArray = createArrayData(valueType, map.size()) + val valueUpdater = new ArrayDataUpdater(valueArray) + val iter = map.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + assert(entry.getKey != 
null) + keyWriter(keyUpdater, i, entry.getKey) + if (entry.getValue == null) { + if (!valueContainsNull) { + throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + valueUpdater.setNullAt(i) + } + } else { + valueWriter(valueUpdater, i, entry.getValue) + } + i += 1 + } + + updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) + + case (UNION, _) => + val allTypes = avroType.getTypes.asScala + val nonNullTypes = allTypes.filter(_.getType != NULL) + val nonNullAvroType = Schema.createUnion(nonNullTypes.asJava) + if (nonNullTypes.nonEmpty) { + if (nonNullTypes.length == 1) { + newWriter(nonNullTypes.head, catalystType, path) + } else { + nonNullTypes.map(_.getType) match { + case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == LongType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case l: java.lang.Long => updater.setLong(ordinal, l) + case i: java.lang.Integer => updater.setLong(ordinal, i.longValue()) + } + + case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && catalystType == DoubleType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case d: java.lang.Double => updater.setDouble(ordinal, d) + case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue()) + } + + case _ => + catalystType match { + case st: StructType if st.length == nonNullTypes.size => + val fieldWriters = nonNullTypes.zip(st.fields).map { + case (schema, field) => newWriter(schema, field.dataType, path :+ field.name) + }.toArray + (updater, ordinal, value) => { + val row = new SpecificInternalRow(st) + val fieldUpdater = new RowUpdater(row) + val i = GenericData.get().resolveUnion(nonNullAvroType, value) + fieldWriters(i)(fieldUpdater, i, value) + updater.set(ordinal, row) + } + + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Avro to catalyst because schema at path " + + s"${path.mkString(".")} is not compatible " + + s"(avroType = $avroType, sqlType = $catalystType).\n" + + s"Source Avro schema: $rootAvroType.\n" + + s"Target Catalyst type: $rootCatalystType") + } + } + } + } else { + (updater, ordinal, value) => updater.setNullAt(ordinal) + } + + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Avro to catalyst because schema at path ${path.mkString(".")} " + + s"is not compatible (avroType = $avroType, sqlType = $catalystType).\n" + + s"Source Avro schema: $rootAvroType.\n" + + s"Target Catalyst type: $rootCatalystType") + } + + // TODO: move the following method in Decimal object on creating Decimal from BigDecimal? + private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { + if (precision <= Decimal.MAX_LONG_DIGITS) { + // Constructs a `Decimal` with an unscaled `Long` value if possible. + Decimal(decimal.unscaledValue().longValue(), precision, scale) + } else { + // Otherwise, resorts to an unscaled `BigInteger` instead. 
+ Decimal(decimal, precision, scale) + } + } + + private def getRecordWriter( + avroType: Schema, + sqlType: StructType, + path: List[String]): (CatalystDataUpdater, GenericRecord) => Unit = { + val validFieldIndexes = ArrayBuffer.empty[Int] + val fieldWriters = ArrayBuffer.empty[(CatalystDataUpdater, Any) => Unit] + + val length = sqlType.length + var i = 0 + while (i < length) { + val sqlField = sqlType.fields(i) + val avroField = avroType.getField(sqlField.name) + if (avroField != null) { + validFieldIndexes += avroField.pos() + + val baseWriter = newWriter(avroField.schema(), sqlField.dataType, path :+ sqlField.name) + val ordinal = i + val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { + if (value == null) { + fieldUpdater.setNullAt(ordinal) + } else { + baseWriter(fieldUpdater, ordinal, value) + } + } + fieldWriters += fieldWriter + } else if (!sqlField.nullable) { + throw new IncompatibleSchemaException( + s""" + |Cannot find non-nullable field ${path.mkString(".")}.${sqlField.name} in Avro schema. + |Source Avro schema: $rootAvroType. + |Target Catalyst type: $rootCatalystType. + """.stripMargin) + } + i += 1 + } + + (fieldUpdater, record) => { + var i = 0 + while (i < validFieldIndexes.length) { + fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i))) + i += 1 + } + } + } + + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) + } + + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. 
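+ *
+ * A single updater abstraction lets the same writer closures populate struct fields (via
+ * `RowUpdater`) and array or map entries (via `ArrayDataUpdater`) by ordinal.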
+ */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value) + } + + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = + row.setDecimal(ordinal, value, value.precision) + } + + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value) + } +} \ No newline at end of file diff --git a/schemaregistry-avro/src/main/scala/com/microsoft/azure/schemaregistry/spark/avro/SchemaConverters.scala b/schemaregistry-avro/src/main/scala/com/microsoft/azure/schemaregistry/spark/avro/SchemaConverters.scala new file mode 100644 index 000000000..284c86dd3 --- /dev/null +++ b/schemaregistry-avro/src/main/scala/com/microsoft/azure/schemaregistry/spark/avro/SchemaConverters.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.azure.schemaregistry.spark.avro + +import scala.collection.JavaConverters._ +import scala.util.Random + +import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder} +import org.apache.avro.LogicalTypes.{Date, Decimal, TimestampMicros, TimestampMillis} +import org.apache.avro.Schema.Type._ + +import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator +import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.Decimal.{maxPrecisionForBytes, minBytesForPrecision} + +/** + * This object contains method that are used to convert sparkSQL schemas to avro schemas and vice + * versa. + */ +object SchemaConverters { + private lazy val uuidGenerator = RandomUUIDGenerator(new Random().nextLong()) + + private lazy val nullSchema = Schema.create(Schema.Type.NULL) + + case class SchemaType(dataType: DataType, nullable: Boolean) + + /** + * This function takes an avro schema and returns a sql schema. + */ + def toSqlType(avroSchema: Schema): SchemaType = { + toSqlTypeHelper(avroSchema, Set.empty) + } + + def toSqlTypeHelper(avroSchema: Schema, existingRecordNames: Set[String]): SchemaType = { + avroSchema.getType match { + case INT => avroSchema.getLogicalType match { + case _: Date => SchemaType(DateType, nullable = false) + case _ => SchemaType(IntegerType, nullable = false) + } + case STRING => SchemaType(StringType, nullable = false) + case BOOLEAN => SchemaType(BooleanType, nullable = false) + case BYTES | FIXED => avroSchema.getLogicalType match { + // For FIXED type, if the precision requires more bytes than fixed size, the logical + // type will be null, which is handled by Avro library. 
+ case d: Decimal => SchemaType(DecimalType(d.getPrecision, d.getScale), nullable = false) + case _ => SchemaType(BinaryType, nullable = false) + } + + case DOUBLE => SchemaType(DoubleType, nullable = false) + case FLOAT => SchemaType(FloatType, nullable = false) + case LONG => avroSchema.getLogicalType match { + case _: TimestampMillis | _: TimestampMicros => SchemaType(TimestampType, nullable = false) + case _ => SchemaType(LongType, nullable = false) + } + + case ENUM => SchemaType(StringType, nullable = false) + + case RECORD => + if (existingRecordNames.contains(avroSchema.getFullName)) { + throw new IncompatibleSchemaException(s""" + |Found recursive reference in Avro schema, which can not be processed by Spark: + |${avroSchema.toString(true)} + """.stripMargin) + } + val newRecordNames = existingRecordNames + avroSchema.getFullName + val fields = avroSchema.getFields.asScala.map { f => + val schemaType = toSqlTypeHelper(f.schema(), newRecordNames) + StructField(f.name, schemaType.dataType, schemaType.nullable) + } + + SchemaType(StructType(fields), nullable = false) + + case ARRAY => + val schemaType = toSqlTypeHelper(avroSchema.getElementType, existingRecordNames) + SchemaType( + ArrayType(schemaType.dataType, containsNull = schemaType.nullable), + nullable = false) + + case MAP => + val schemaType = toSqlTypeHelper(avroSchema.getValueType, existingRecordNames) + SchemaType( + MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), + nullable = false) + + case UNION => + if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { + // In case of a union with null, eliminate it and make a recursive call + val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) + if (remainingUnionTypes.size == 1) { + toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames).copy(nullable = true) + } else { + toSqlTypeHelper(Schema.createUnion(remainingUnionTypes.asJava), existingRecordNames) + .copy(nullable = true) + } + } else avroSchema.getTypes.asScala.map(_.getType) match { + case Seq(t1) => + toSqlTypeHelper(avroSchema.getTypes.get(0), existingRecordNames) + case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => + SchemaType(LongType, nullable = false) + case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => + SchemaType(DoubleType, nullable = false) + case _ => + // Convert complex unions to struct types where field names are member0, member1, etc. + // This is consistent with the behavior when converting between Avro and Parquet. 
+ val fields = avroSchema.getTypes.asScala.zipWithIndex.map { + case (s, i) => + val schemaType = toSqlTypeHelper(s, existingRecordNames) + // All fields are nullable because only one of them is set at a time + StructField(s"member$i", schemaType.dataType, nullable = true) + } + + SchemaType(StructType(fields), nullable = false) + } + + case other => throw new IncompatibleSchemaException(s"Unsupported type $other") + } + } + + def toAvroType( + catalystType: DataType, + nullable: Boolean = false, + recordName: String = "topLevelRecord", + nameSpace: String = "") + : Schema = { + val builder = SchemaBuilder.builder() + + val schema = catalystType match { + case BooleanType => builder.booleanType() + case ByteType | ShortType | IntegerType => builder.intType() + case LongType => builder.longType() + case DateType => + LogicalTypes.date().addToSchema(builder.intType()) + case TimestampType => + LogicalTypes.timestampMicros().addToSchema(builder.longType()) + + case FloatType => builder.floatType() + case DoubleType => builder.doubleType() + case StringType => builder.stringType() + case d: DecimalType => + val avroType = LogicalTypes.decimal(d.precision, d.scale) + val fixedSize = minBytesForPrecision(d.precision) + // Need to avoid naming conflict for the fixed fields + val name = nameSpace match { + case "" => s"$recordName.fixed" + case _ => s"$nameSpace.$recordName.fixed" + } + avroType.addToSchema(SchemaBuilder.fixed(name).size(fixedSize)) + + case BinaryType => builder.bytesType() + case ArrayType(et, containsNull) => + builder.array() + .items(toAvroType(et, containsNull, recordName, nameSpace)) + case MapType(StringType, vt, valueContainsNull) => + builder.map() + .values(toAvroType(vt, valueContainsNull, recordName, nameSpace)) + case st: StructType => + val childNameSpace = if (nameSpace != "") s"$nameSpace.$recordName" else recordName + val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields() + st.foreach { f => + val fieldAvroType = + toAvroType(f.dataType, f.nullable, f.name, childNameSpace) + fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault() + } + fieldsAssembler.endRecord() + + // This should never happen. + case other => throw new IncompatibleSchemaException(s"Unexpected type $other.") + } + if (nullable) { + Schema.createUnion(schema, nullSchema) + } else { + schema + } + } +} + +class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex) \ No newline at end of file diff --git a/schemaregistry-avro/src/main/scala/com/microsoft/azure/schemaregistry/spark/avro/functions.scala b/schemaregistry-avro/src/main/scala/com/microsoft/azure/schemaregistry/spark/avro/functions.scala new file mode 100644 index 000000000..c38662a7a --- /dev/null +++ b/schemaregistry-avro/src/main/scala/com/microsoft/azure/schemaregistry/spark/avro/functions.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.azure.schemaregistry.spark.avro + +import com.azure.data.schemaregistry.avro.{SchemaRegistryAvroSerializer} +import scala.collection.JavaConverters._ +import org.apache.spark.sql.Column + +/*** + * Scala object containing utility methods for serialization/deserialization with Azure Schema Registry and Spark SQL + * columns. + * + * Functions are agnostic to data source or sink and can be used with any Schema Registry payloads, including: + * - Kafka Spark connector ($value) + * - Event Hubs Spark connector ($Body) + * - Event Hubs Avro Capture blobs ($Body) + */ +object functions { + var serializer: SchemaRegistryAvroSerializer = null + + /*** + * Converts Spark SQL Column containing SR payloads into a + * @param data column with SR payloads + * @param schemaId GUID of the expected schema + * @param clientOptions map of configuration properties, including Spark run mode (permissive vs. fail-fast) + * @param requireExactSchemaMatch boolean if call should throw if data contents do not exactly match expected schema + * @return + */ + def from_avro( + data: Column, + schemaId: String, + clientOptions: java.util.Map[java.lang.String, java.lang.String], + requireExactSchemaMatch: Boolean = true): Column = { + new Column(AvroDataToCatalyst(data.expr, schemaId, clientOptions.asScala.toMap, requireExactSchemaMatch)) + } +} diff --git a/schemaregistry-avro/src/test/scala/com/microsoft/azure/schemaregistry/spark/avro/AvroDeserializerSuite.scala b/schemaregistry-avro/src/test/scala/com/microsoft/azure/schemaregistry/spark/avro/AvroDeserializerSuite.scala new file mode 100644 index 000000000..bae7a6b6c --- /dev/null +++ b/schemaregistry-avro/src/test/scala/com/microsoft/azure/schemaregistry/spark/avro/AvroDeserializerSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.microsoft.azure.schemaregistry.spark.avro + +import java.util + +import org.apache.spark.SparkException +import org.apache.spark.sql.{Column, QueryTest, Row} +import org.apache.spark.sql.execution.LocalTableScanExec +import org.apache.spark.sql.functions.{col, lit, struct} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class AvroFunctionsSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + test("do not handle null column") { + + try { + functions.from_avro(null, "schema_id", null) + fail() + } + catch { + case _: NullPointerException => + } + } + + test("do not handle null schema ID") { + try { + functions.from_avro(new Column("empty"), null, new util.HashMap()) + fail() + } + catch { + case _: NullPointerException => + } + } + + test("invalid client options") { + val configMap = new util.HashMap[String, String]() + configMap.put("schema.registry.url", "https://namespace.servicebus.windows.net") + try { + functions.from_avro(new Column("empty"), "schema_id", configMap) + fail() + } + catch { + case _: IllegalArgumentException => + } + } +}
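
For reference, the schema translation that `AvroDataToCatalyst.dataType` performs on the registry-fetched schema can be exercised on its own through `SchemaConverters.toSqlType`. The following is a minimal, illustrative sketch (the `Order` schema is a made-up example, not part of this change):

```scala
import org.apache.avro.Schema
import com.microsoft.azure.schemaregistry.spark.avro.SchemaConverters

object SchemaConversionExample {
  def main(args: Array[String]): Unit = {
    // An Avro record schema, as it might be stored in the registry for some schema GUID.
    val avroSchemaJson =
      """{
        |  "type": "record",
        |  "name": "Order",
        |  "fields": [
        |    {"name": "id", "type": "string"},
        |    {"name": "amount", "type": "double"},
        |    {"name": "description", "type": ["null", "string"], "default": null}
        |  ]
        |}""".stripMargin

    val avroSchema = new Schema.Parser().parse(avroSchemaJson)

    // toSqlType mirrors what AvroDataToCatalyst.dataType does with the schema fetched
    // from the registry; the union with null becomes a nullable StringType field.
    val sqlType = SchemaConverters.toSqlType(avroSchema)
    println(sqlType.dataType.catalogString)
    // struct<id:string,amount:double,description:string>
  }
}
```

Because the Catalyst type is derived from the registry schema when the expression is resolved, the column returned by `from_avro` is strongly typed and can be used directly in Spark SQL expressions.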