Skip to content

Commit

Permalink
Parquet map logical type (#876)
Browse files Browse the repository at this point in the history
* Parquet map logical type

* Generate scala Map type for parquet

* Support optional map and optional value

* Ensure Map key is required

* Fix warning

* Fix parquet parser

* Consistent isEmpty definition
  • Loading branch information
RustedBones authored Jan 9, 2024
1 parent 7af1303 commit 43ce752
Show file tree
Hide file tree
Showing 11 changed files with 420 additions and 343 deletions.
99 changes: 95 additions & 4 deletions parquet/src/main/scala/magnolify/parquet/ParquetField.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ sealed trait ParquetField[T] extends Serializable {

protected val isGroup: Boolean = false
protected def isEmpty(v: T): Boolean
protected final def nonEmpty(v: T): Boolean = !isEmpty(v)

def write(c: RecordConsumer, v: T)(cm: CaseMapper): Unit
def newConverter: TypeConverter[T]

Expand Down Expand Up @@ -128,7 +130,7 @@ object ParquetField {
override def write(c: RecordConsumer, v: T)(cm: CaseMapper): Unit = {
caseClass.parameters.foreach { p =>
val x = p.dereference(v)
if (!p.typeclass.isEmpty(x)) {
if (p.typeclass.nonEmpty(x)) {
val name = cm.map(p.label)
c.startField(name, p.index)
p.typeclass.writeGroup(c, x)(cm)
Expand Down Expand Up @@ -280,7 +282,7 @@ object ParquetField {
new ParquetField[Option[T]] {
override def buildSchema(cm: CaseMapper): Type =
Schema.setRepetition(t.schema(cm), Repetition.OPTIONAL)
override protected def isEmpty(v: Option[T]): Boolean = v.isEmpty
override protected def isEmpty(v: Option[T]): Boolean = v.forall(t.isEmpty)

override def fieldDocs(cm: CaseMapper): Map[String, String] = t.fieldDocs(cm)

Expand Down Expand Up @@ -319,14 +321,14 @@ object ParquetField {
.requiredGroup()
.addField(Schema.rename(repeatedSchema, AvroArrayField))
.as(LogicalTypeAnnotation.listType())
.named(t.schema(cm).getName)
.named("iterable")
} else {
repeatedSchema
}
}

override protected val isGroup: Boolean = hasAvroArray
override protected def isEmpty(v: C[T]): Boolean = v.isEmpty
override protected def isEmpty(v: C[T]): Boolean = v.forall(t.isEmpty)

override def write(c: RecordConsumer, v: C[T])(cm: CaseMapper): Unit =
if (hasAvroArray) {
Expand Down Expand Up @@ -366,6 +368,95 @@ object ParquetField {
}
}

// Field names mandated by the Parquet MAP logical type layout:
// a repeated "key_value" group containing a "key" field and a "value" field.
private val KeyField = "key"
private val ValueField = "value"
private val KeyValueGroup = "key_value"
// Derives a ParquetField for Scala Map[K, V] using the Parquet MAP logical type,
// given ParquetField instances for the key and value types.
implicit def pfMap[K, V](implicit
pfKey: ParquetField[K],
pfValue: ParquetField[V]
): ParquetField[Map[K, V]] = {
new ParquetField[Map[K, V]] {
// Builds: required group "map" (MAP) { repeated group key_value { key; value } }.
override def buildSchema(cm: CaseMapper): Type = {
val keySchema = Schema.rename(pfKey.schema(cm), KeyField)
// The MAP logical type requires keys to be REQUIRED, so e.g. Option keys are rejected.
require(keySchema.isRepetition(Repetition.REQUIRED), "Map key must be required")
// The value keeps its own repetition, which is how optional values are supported.
val valueSchema = Schema.rename(pfValue.schema(cm), ValueField)
val keyValue = Types
.repeatedGroup()
.addField(keySchema)
.addField(valueSchema)
.named(KeyValueGroup)
Types
.requiredGroup()
.addField(keyValue)
.as(LogicalTypeAnnotation.mapType())
.named("map")
}

// Propagate the avro-array flag from either side so enclosing schemas can react.
override val hasAvroArray: Boolean = pfKey.hasAvroArray || pfValue.hasAvroArray

// An empty map has no entries to write, so the enclosing field is skipped entirely.
override protected def isEmpty(v: Map[K, V]): Boolean = v.isEmpty

// Map fields carry no per-field docs of their own.
override def fieldDocs(cm: CaseMapper): Map[String, String] = Map.empty

override val typeDoc: Option[String] = None

// Emits the map through the RecordConsumer protocol, mirroring buildSchema:
// outer group -> repeated key_value field -> one group per entry.
// Nothing at all is written for an empty map (consistent with isEmpty above).
override def write(c: RecordConsumer, v: Map[K, V])(cm: CaseMapper): Unit = {
if (v.nonEmpty) {
c.startGroup()
c.startField(KeyValueGroup, 0)
v.foreach { case (k, v) =>
c.startGroup()
// Field index 0 = key, index 1 = value; must match addField order in buildSchema.
c.startField(KeyField, 0)
pfKey.writeGroup(c, k)(cm)
c.endField(KeyField, 0)
// Skip the value field when the value is "empty" (e.g. None),
// matching the optional-value repetition in the schema.
if (pfValue.nonEmpty(v)) {
c.startField(ValueField, 1)
pfValue.writeGroup(c, v)(cm)
c.endField(ValueField, 1)
}
c.endGroup()
}
c.endField(KeyValueGroup, 0)
c.endGroup()
}
}

override def newConverter: TypeConverter[Map[K, V]] = {
// Converter for the repeated key_value group: buffers one (K, V) pair per entry.
val kvConverter = new GroupConverter with TypeConverter.Buffered[(K, V)] {
private val keyConverter = pfKey.newConverter
private val valueConverter = pfValue.newConverter
// Index 0 = key, index 1 = value — same field order as in buildSchema.
private val fieldConverters = Array(keyConverter, valueConverter)

override def isPrimitive: Boolean = false

override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex)

override def start(): Unit = ()

// At entry end, both child converters hold the current entry's value.
// NOTE(review): when the optional value field is absent, this relies on
// valueConverter.get yielding the "empty" value (e.g. None) — confirm.
override def end(): Unit = {
val key = keyConverter.get
val value = valueConverter.get
addValue(key -> value)
}
}.withRepetition(Repetition.REPEATED)

// Collapses the buffered entries into a single immutable Map.
val mapConverter = new TypeConverter.Delegate[(K, V), Map[K, V]](kvConverter) {
override def get: Map[K, V] = inner.get(_.toMap)
}

// Outer converter for the wrapping group, which has exactly one child field
// (the repeated key_value group).
new GroupConverter with TypeConverter.Buffered[Map[K, V]] {
override def getConverter(fieldIndex: Int): Converter = {
require(fieldIndex == 0, "Map field index != 0")
mapConverter
}
override def start(): Unit = ()
override def end(): Unit = addValue(mapConverter.get)
// If the map field never produced a group, fall back to an empty Map.
override def get: Map[K, V] = get(_.headOption.getOrElse(Map.empty))
}
}
}
}

// ////////////////////////////////////////////////

def logicalType[T](lta: => LogicalTypeAnnotation): LogicalTypeWord[T] =
Expand Down
16 changes: 16 additions & 0 deletions parquet/src/test/scala/magnolify/parquet/ParquetTypeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ class ParquetTypeSuite extends MagnolifySuite {

test[ParquetTypes]

test[MapPrimitive]
test[MapNested]

test("AnyVal") {
implicit val pt: ParquetType[HasValueClass] = ParquetType[HasValueClass]
test[HasValueClass]
Expand Down Expand Up @@ -193,6 +196,19 @@ class ParquetTypeSuite extends MagnolifySuite {

case class Unsafe(c: Char)
case class ParquetTypes(b: Byte, s: Short, ba: Array[Byte])

// It is technically possible to have an optional map, but the operation is not bijective:
// parquet would read Some(Map.empty) back as None
// Round-trip fixture: maps with primitive keys/values; `mvo` exercises optional map values.
case class MapPrimitive(
m: Map[String, Int],
// mo: Option[Map[String, Int]],
mvo: Map[String, Option[Int]]
)
// Round-trip fixture: maps with case-class keys/values; `mvo` exercises optional nested values.
case class MapNested(
m: Map[Integers, Nested],
// mo: Option[Map[Integers, Nested]],
mvo: Map[Integers, Option[Nested]]
)
case class Decimal(bd: BigDecimal, bdo: Option[BigDecimal])
case class Logical(u: UUID, d: LocalDate)
case class Time(i: Instant, dt: LocalDateTime, ot: OffsetTime, t: LocalTime)
Expand Down
61 changes: 30 additions & 31 deletions tools/src/main/scala/magnolify/tools/AvroParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,50 +28,47 @@ object AvroParser extends SchemaParser[avro.Schema] {

private def parseRecord(schema: avro.Schema): Record = {
val fields = schema.getFields.asScala.iterator.map { f =>
val (s, r) = parseSchemaAndRepetition(f.schema())
Field(f.name(), Option(f.doc()), s, r)
val s = parseSchema(f.schema())
Record.Field(f.name(), Option(f.doc()), s)
}.toList
Record(Option(schema.getName), Option(schema.getNamespace), Option(schema.getDoc), fields)
Record(
Some(schema.getName),
Option(schema.getDoc),
fields
)
}

private def parseEnum(schema: avro.Schema): Enum =
Enum(
Option(schema.getName),
Option(schema.getNamespace),
private def parseEnum(schema: avro.Schema): Primitive.Enum =
Primitive.Enum(
Some(schema.getName),
Option(schema.getDoc),
schema.getEnumSymbols.asScala.toList
)

private def parseSchemaAndRepetition(schema: avro.Schema): (Schema, Repetition) =
schema.getType match {
case Type.UNION
if schema.getTypes.size() == 2 &&
schema.getTypes.asScala.count(_.getType == Type.NULL) == 1 =>
val s = schema.getTypes.asScala.find(_.getType != Type.NULL).get
if (s.getType == Type.ARRAY) {
// Nullable array, e.g. ["null", {"type": "array", "items": ...}]
(parseSchema(s.getElementType), Repeated)
} else {
(parseSchema(s), Optional)
}
case Type.ARRAY =>
(parseSchema(schema.getElementType), Repeated)
// FIXME: map
case _ =>
(parseSchema(schema), Required)
}

private def parseSchema(schema: avro.Schema): Schema = schema.getType match {
// Nested types
case Type.RECORD => parseRecord(schema)
case Type.ENUM => parseEnum(schema)
// Composite types
case Type.RECORD =>
parseRecord(schema)
case Type.UNION =>
val types = schema.getTypes.asScala
if (types.size != 2 || !types.exists(_.getType == Type.NULL)) {
throw new IllegalArgumentException(s"Unsupported union $schema")
} else {
val s = types.find(_.getType != Type.NULL).get
Optional(parseSchema(s))
}
case Type.ARRAY =>
Repeated(parseSchema(schema.getElementType))
case Type.MAP =>
Mapped(Primitive.String, parseSchema(schema.getValueType))

// Logical types
case Type.STRING if isLogical(schema, LogicalTypes.uuid().getName) =>
Primitive.UUID
case Type.BYTES if schema.getLogicalType.isInstanceOf[LogicalTypes.Decimal] =>
Primitive.BigDecimal
case Type.INT if schema.getLogicalType.isInstanceOf[LogicalTypes.Date] => Primitive.LocalDate
case Type.INT if schema.getLogicalType.isInstanceOf[LogicalTypes.Date] =>
Primitive.LocalDate

// Millis
case Type.LONG if schema.getLogicalType.isInstanceOf[LogicalTypes.TimestampMillis] =>
Expand All @@ -92,9 +89,11 @@ object AvroParser extends SchemaParser[avro.Schema] {
Primitive.LocalDateTime

// BigQuery sqlType: DATETIME
case Type.STRING if isLogical(schema, "datetime") => Primitive.LocalDateTime
case Type.STRING if isLogical(schema, "datetime") =>
Primitive.LocalDateTime

// Primitive types
case Type.ENUM => parseEnum(schema)
case Type.FIXED => Primitive.Bytes
case Type.STRING => Primitive.String
case Type.BYTES => Primitive.Bytes
Expand Down
17 changes: 9 additions & 8 deletions tools/src/main/scala/magnolify/tools/BigQueryParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,7 @@ object BigQueryParser extends SchemaParser[TableSchema] {

private def parseRecord(fields: List[TableFieldSchema]): Record = {
val fs = fields.map { f =>
val r = f.getMode match {
case "REQUIRED" => Required
case "NULLABLE" => Optional
case "REPEATED" => Repeated
}
val s = f.getType match {
val schema = f.getType match {
case "INT64" => Primitive.Long
case "FLOAT64" => Primitive.Double
case "NUMERIC" => Primitive.BigDecimal
Expand All @@ -44,8 +39,14 @@ object BigQueryParser extends SchemaParser[TableSchema] {
case "DATETIME" => Primitive.LocalDateTime
case "STRUCT" => parseRecord(f.getFields.asScala.toList)
}
Field(f.getName, Option(f.getDescription), s, r)

val moddedSchema = f.getMode match {
case "REQUIRED" => schema
case "NULLABLE" => Optional(schema)
case "REPEATED" => Repeated(schema)
}
Record.Field(f.getName, Option(f.getDescription), moddedSchema)
}
Record(None, None, None, fs)
Record(None, None, fs)
}
}
Loading

0 comments on commit 43ce752

Please sign in to comment.