Skip to content

Commit

Permalink
Allow typemappers to throw exception on default values
Browse files Browse the repository at this point in the history
  • Loading branch information
thesamet committed Jan 20, 2021
1 parent a647fea commit 23d2d96
Show file tree
Hide file tree
Showing 75 changed files with 247 additions and 162 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -720,19 +720,54 @@ class ProtobufGenerator(
printer.addWithDelimiter(",")(constructorFields(message).map(_.fullString))
}

private def usesBaseTypeInBuilder(field: FieldDescriptor) = field.isRequired || field.noBox

def generateBuilder(message: Descriptor)(printer: FunctionalPrinter): FunctionalPrinter = {
val myFullScalaName = message.scalaType.fullNameWithMaybeRoot(message)
val requiredFieldMap: Map[FieldDescriptor, Int] =
message.fields.filter(_.isRequired).zipWithIndex.toMap
case class Field(name: String, typeName: String, default: String, accessor: String)
case class Field(
name: String,
targetName: String,
typeName: String,
default: String,
accessor: String,
builder: String
)

val fields = message.fieldsWithoutOneofs.map { field =>
if (!field.isRepeated)
if (usesBaseTypeInBuilder(field)) {
// To handle custom types that have no default values, we wrap required/no-boxed messages in
// Option during parsing. We also apply the type mapper after parsing is complete.
if (field.isMessage)
Field(
s"__${field.scalaName}",
field.scalaName.asSymbol,
s"_root_.scala.Option[${field.baseSingleScalaTypeName}]",
C.None,
s"_root_.scala.Some(${toBaseTypeExpr(field)(s"_message__.${field.scalaName.asSymbol}", EnclosingType.None)})",
toCustomTypeExpr(field)(
s"__${field.scalaName}.getOrElse(${field.getMessageType.scalaType.fullName}.defaultInstance)",
EnclosingType.None
)
)
else
Field(
s"__${field.scalaName}",
field.scalaName.asSymbol,
field.baseSingleScalaTypeName,
defaultValueForGet(field, uncustomized = true),
toBaseTypeExpr(field)(s"_message__.${field.scalaName.asSymbol}", EnclosingType.None),
toCustomTypeExpr(field)(s"__${field.scalaName}", EnclosingType.None)
)
} else if (!field.isRepeated)
Field(
s"__${field.scalaName}",
field.scalaName.asSymbol,
field.scalaTypeName,
defaultValueForDefaultInstance(field),
s"_message__.${field.scalaName.asSymbol}"
s"_message__.${field.scalaName.asSymbol}",
s"__${field.scalaName}"
)
else {
val it =
Expand All @@ -741,25 +776,31 @@ class ProtobufGenerator(
else s"_message__.${field.scalaName.asSymbol}"
Field(
s"__${field.scalaName}",
field.scalaName.asSymbol,
s"collection.mutable.Builder[${field.singleScalaTypeName}, ${field.scalaTypeName}]",
field.collection.newBuilder,
s"${field.collection.newBuilder} ++= $it"
s"${field.collection.newBuilder} ++= $it",
s"__${field.scalaName}.result()"
)
}
} ++ message.getRealOneofs.asScala.map { oneof =>
Field(
s"__${oneof.scalaName.name}",
oneof.scalaName.name.asSymbol,
oneof.scalaType.fullName,
oneof.empty.fullName,
s"_message__.${oneof.scalaName.nameSymbol}"
s"_message__.${oneof.scalaName.nameSymbol}",
s"__${oneof.scalaName.name}"
)
} ++ (if (message.preservesUnknownFields)
Seq(
Field(
"`_unknownFields__`",
"unknownFields",
"_root_.scalapb.UnknownFieldSet.Builder",
"null",
"new _root_.scalapb.UnknownFieldSet.Builder(_message__.unknownFields)"
"new _root_.scalapb.UnknownFieldSet.Builder(_message__.unknownFields)",
"if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()"
)
)
else Seq.empty)
Expand Down Expand Up @@ -795,20 +836,7 @@ class ProtobufGenerator(
)
}.add(s"$myFullScalaName(")
.indented(
_.addWithDelimiter(",")(
(message.fieldsWithoutOneofs ++ message.getOneofs.asScala).map {
case e: FieldDescriptor if e.isRepeated =>
s" ${e.scalaName.asSymbol} = __${e.scalaName}.result()"
case e: FieldDescriptor =>
s" ${e.scalaName.asSymbol} = __${e.scalaName}"
case e: OneofDescriptor =>
s" ${e.scalaName.nameSymbol} = __${e.scalaName.name}"
} ++ (if (message.preservesUnknownFields)
Seq(
" unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()"
)
else Seq())
)
_.addWithDelimiter(",")(fields.map(e => s" ${e.targetName} = ${e.builder}"))
)
.add(")")
)
Expand Down Expand Up @@ -851,28 +879,32 @@ class ProtobufGenerator(
.print(message.fields) { (printer, field) =>
val p = {
val newValBase = if (field.isMessage) {
val defInstance =
s"${field.getMessageType.scalaType.fullNameWithMaybeRoot(message)}.defaultInstance"
val baseInstance =
if (field.isRepeated) defInstance
else {
// In 0.10.x we can't simply call any of the new methods that relies on Builder,
// since the references message may have been generated using an older version of
// ScalaPB.
val baseName = field.baseSingleScalaTypeName
val read =
if (field.isRepeated) s"_root_.scalapb.LiteParser.readMessage[$baseName](_input__)"
else if (usesBaseTypeInBuilder(field)) {
s"_root_.scala.Some(__${field.scalaName}.fold(_root_.scalapb.LiteParser.readMessage[$baseName](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _)))"
} else {
val expr =
if (field.isInOneof)
s"__${fieldAccessorSymbol(field)}"
else s"__${field.scalaName}"
val mappedType =
toBaseFieldType(field).apply(expr, field.enclosingType)
val mappedType = toBaseFieldType(field).apply(expr, field.enclosingType)
if (field.isInOneof || field.supportsPresence)
(mappedType + s".getOrElse($defInstance)")
else mappedType
s"$mappedType.fold(_root_.scalapb.LiteParser.readMessage[$baseName](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))"
else s"_root_.scalapb.LiteParser.readMessage[$baseName](_input__, $mappedType)"
}
s"_root_.scalapb.LiteParser.readMessage(_input__, $baseInstance)"
read
} else if (field.isEnum)
s"${field.getEnumType.scalaType.fullNameWithMaybeRoot(message)}.fromValue(_input__.readEnum())"
else if (field.getType == Type.STRING) s"_input__.readStringRequireUtf8()"
else s"_input__.read${Types.capitalizedType(field.getType)}()"

val newVal = toCustomType(field)(newValBase)
val newVal =
if (!usesBaseTypeInBuilder(field)) toCustomType(field)(newValBase) else newValBase

val updateOp =
if (field.supportsPresence) s"__${field.scalaName} = Option($newVal)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ object MyContainer extends scalapb.GeneratedMessageCompanion[com.thesamet.docs.j
_tag__ match {
case 0 => _done__ = true
case 10 =>
__myAny = Option(_root_.scalapb.LiteParser.readMessage(_input__, __myAny.getOrElse(com.google.protobuf.any.Any.defaultInstance)))
__myAny = Option(__myAny.fold(_root_.scalapb.LiteParser.readMessage[com.google.protobuf.any.Any](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _)))
case tag =>
if (_unknownFields__ == null) {
_unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder()
Expand Down
2 changes: 1 addition & 1 deletion docs/src/main/scala/scalapb/docs/person/Person.scala
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ object Person extends scalapb.GeneratedMessageCompanion[scalapb.docs.person.Pers
case 16 =>
__age = _input__.readInt32()
case 26 =>
__addresses += _root_.scalapb.LiteParser.readMessage(_input__, scalapb.docs.person.Person.Address.defaultInstance)
__addresses += _root_.scalapb.LiteParser.readMessage[scalapb.docs.person.Person.Address](_input__)
case tag =>
if (_unknownFields__ == null) {
_unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ object MessageContainer extends scalapb.GeneratedMessageCompanion[scalapb.perf.p
_tag__ match {
case 0 => _done__ = true
case 10 =>
__opt = Option(_root_.scalapb.LiteParser.readMessage(_input__, __opt.getOrElse(scalapb.perf.protos.SimpleMessage.defaultInstance)))
__opt = Option(__opt.fold(_root_.scalapb.LiteParser.readMessage[scalapb.perf.protos.SimpleMessage](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _)))
case 18 =>
__rep += _root_.scalapb.LiteParser.readMessage(_input__, scalapb.perf.protos.SimpleMessage.defaultInstance)
__rep += _root_.scalapb.LiteParser.readMessage[scalapb.perf.protos.SimpleMessage](_input__)
case tag =>
if (_unknownFields__ == null) {
_unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder()
Expand Down
16 changes: 16 additions & 0 deletions e2e-withjava/src/main/protobuf/custom_types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ message CustomMessage {
repeated Name names = 15 [(scalapb.field).type = "com.thesamet.pb.FullName"];
}

message HasEmail {
optional string optional_email = 1 [(scalapb.field).type = "com.thesamet.pb.Email"];
required string required_email = 2 [(scalapb.field).type = "com.thesamet.pb.Email"];
repeated string repeated_email = 3 [(scalapb.field).type = "com.thesamet.pb.Email"];
}

message NoBoxEmail {
optional string no_box_email = 1 [(scalapb.field).type = "com.thesamet.pb.Email", (scalapb.field).no_box=true];
}

message ContainsHasEmail {
optional HasEmail optional_has_email = 1;
repeated HasEmail repeated_has_email = 2;
required HasEmail required_has_email = 3;
}

message OneofMessage {
oneof one_of {
string person_id = 1 [(scalapb.field).type = "com.thesamet.pb.PersonId"];
Expand Down
16 changes: 16 additions & 0 deletions e2e-withjava/src/main/scala/com/thesamet/pb/Email.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.thesamet.pb

import scalapb.TypeMapper

case class Email(user: String, domain: String) {
override def toString = s"$user@$domain"
}

object Email {
def fromString(s: String) = s.split("@", 2).toSeq match {
case Seq(user, domain) => Email(user, domain)
case _ => throw new IllegalArgumentException(s"Expected @ in email. Got: $s")
}

implicit val emailTypeMapper = TypeMapper[String, Email](fromString)(_.toString)
}
21 changes: 21 additions & 0 deletions e2e/src/test/scala/CustomTypesSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,27 @@ class CustomTypesSpec extends AnyFlatSpec with Matchers {
CustomerEvent mustBe a[DomainEventCompanion]
CustomerEvent.thisIs must be("The companion object")
}

"HasEmail" should "serialize and parse valid instances" in {
val dm = HasEmail(
requiredEmail = Email("foo", "bar")
)
HasEmail.parseFrom(dm.toByteArray) must be(dm)
}

"NoBoxEmail" should "serialize and parse valid instances" in {
val dm = NoBoxEmail(
noBoxEmail = Email("foo", "bar")
)
NoBoxEmail.parseFrom(dm.toByteArray) must be(dm)
}

"ContainsHasEmail" should "serialize and parse valid instances" in {
val cem = ContainsHasEmail(
requiredHasEmail = HasEmail(requiredEmail = Email("foo", "bar"))
)
ContainsHasEmail.parseFrom(cem.toByteArray) must be(cem)
}
}

object CustomTypesSpec {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ object FieldOptions extends scalapb.GeneratedMessageCompanion[scalapb.options.Fi
case 26 =>
__collectionType = Option(_input__.readStringRequireUtf8())
case 66 =>
__collection = Option(_root_.scalapb.LiteParser.readMessage(_input__, __collection.getOrElse(scalapb.options.Collection.defaultInstance)))
__collection = Option(__collection.fold(_root_.scalapb.LiteParser.readMessage[scalapb.options.Collection](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _)))
case 34 =>
__keyType = Option(_input__.readStringRequireUtf8())
case 42 =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -619,13 +619,13 @@ object ScalaPbOptions extends scalapb.GeneratedMessageCompanion[scalapb.options.
case 184 =>
__javaConversions = Option(_input__.readBool())
case 146 =>
__auxMessageOptions += _root_.scalapb.LiteParser.readMessage(_input__, scalapb.options.ScalaPbOptions.AuxMessageOptions.defaultInstance)
__auxMessageOptions += _root_.scalapb.LiteParser.readMessage[scalapb.options.ScalaPbOptions.AuxMessageOptions](_input__)
case 154 =>
__auxFieldOptions += _root_.scalapb.LiteParser.readMessage(_input__, scalapb.options.ScalaPbOptions.AuxFieldOptions.defaultInstance)
__auxFieldOptions += _root_.scalapb.LiteParser.readMessage[scalapb.options.ScalaPbOptions.AuxFieldOptions](_input__)
case 162 =>
__auxEnumOptions += _root_.scalapb.LiteParser.readMessage(_input__, scalapb.options.ScalaPbOptions.AuxEnumOptions.defaultInstance)
__auxEnumOptions += _root_.scalapb.LiteParser.readMessage[scalapb.options.ScalaPbOptions.AuxEnumOptions](_input__)
case 178 =>
__auxEnumValueOptions += _root_.scalapb.LiteParser.readMessage(_input__, scalapb.options.ScalaPbOptions.AuxEnumValueOptions.defaultInstance)
__auxEnumValueOptions += _root_.scalapb.LiteParser.readMessage[scalapb.options.ScalaPbOptions.AuxEnumValueOptions](_input__)
case 7992 =>
__testOnlyNoJavaConversions = Option(_input__.readBool())
case tag =>
Expand Down Expand Up @@ -930,7 +930,7 @@ object ScalaPbOptions extends scalapb.GeneratedMessageCompanion[scalapb.options.
case 10 =>
__target = Option(_input__.readStringRequireUtf8())
case 18 =>
__options = Option(_root_.scalapb.LiteParser.readMessage(_input__, __options.getOrElse(scalapb.options.MessageOptions.defaultInstance)))
__options = Option(__options.fold(_root_.scalapb.LiteParser.readMessage[scalapb.options.MessageOptions](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _)))
case tag =>
if (_unknownFields__ == null) {
_unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder()
Expand Down Expand Up @@ -1099,7 +1099,7 @@ object ScalaPbOptions extends scalapb.GeneratedMessageCompanion[scalapb.options.
case 10 =>
__target = Option(_input__.readStringRequireUtf8())
case 18 =>
__options = Option(_root_.scalapb.LiteParser.readMessage(_input__, __options.getOrElse(scalapb.options.FieldOptions.defaultInstance)))
__options = Option(__options.fold(_root_.scalapb.LiteParser.readMessage[scalapb.options.FieldOptions](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _)))
case tag =>
if (_unknownFields__ == null) {
_unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder()
Expand Down Expand Up @@ -1268,7 +1268,7 @@ object ScalaPbOptions extends scalapb.GeneratedMessageCompanion[scalapb.options.
case 10 =>
__target = Option(_input__.readStringRequireUtf8())
case 18 =>
__options = Option(_root_.scalapb.LiteParser.readMessage(_input__, __options.getOrElse(scalapb.options.EnumOptions.defaultInstance)))
__options = Option(__options.fold(_root_.scalapb.LiteParser.readMessage[scalapb.options.EnumOptions](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _)))
case tag =>
if (_unknownFields__ == null) {
_unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder()
Expand Down Expand Up @@ -1437,7 +1437,7 @@ object ScalaPbOptions extends scalapb.GeneratedMessageCompanion[scalapb.options.
case 10 =>
__target = Option(_input__.readStringRequireUtf8())
case 18 =>
__options = Option(_root_.scalapb.LiteParser.readMessage(_input__, __options.getOrElse(scalapb.options.EnumValueOptions.defaultInstance)))
__options = Option(__options.fold(_root_.scalapb.LiteParser.readMessage[scalapb.options.EnumValueOptions](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _)))
case tag =>
if (_unknownFields__ == null) {
_unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,15 @@ object Api extends scalapb.GeneratedMessageCompanion[com.google.protobuf.api.Api
case 10 =>
__name = _input__.readStringRequireUtf8()
case 18 =>
__methods += _root_.scalapb.LiteParser.readMessage(_input__, com.google.protobuf.api.Method.defaultInstance)
__methods += _root_.scalapb.LiteParser.readMessage[com.google.protobuf.api.Method](_input__)
case 26 =>
__options += _root_.scalapb.LiteParser.readMessage(_input__, com.google.protobuf.`type`.OptionProto.defaultInstance)
__options += _root_.scalapb.LiteParser.readMessage[com.google.protobuf.`type`.OptionProto](_input__)
case 34 =>
__version = _input__.readStringRequireUtf8()
case 42 =>
__sourceContext = Option(_root_.scalapb.LiteParser.readMessage(_input__, __sourceContext.getOrElse(com.google.protobuf.source_context.SourceContext.defaultInstance)))
__sourceContext = Option(__sourceContext.fold(_root_.scalapb.LiteParser.readMessage[com.google.protobuf.source_context.SourceContext](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _)))
case 50 =>
__mixins += _root_.scalapb.LiteParser.readMessage(_input__, com.google.protobuf.api.Mixin.defaultInstance)
__mixins += _root_.scalapb.LiteParser.readMessage[com.google.protobuf.api.Mixin](_input__)
case 56 =>
__syntax = com.google.protobuf.`type`.Syntax.fromValue(_input__.readEnum())
case tag =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ object Method extends scalapb.GeneratedMessageCompanion[com.google.protobuf.api.
case 40 =>
__responseStreaming = _input__.readBool()
case 50 =>
__options += _root_.scalapb.LiteParser.readMessage(_input__, com.google.protobuf.`type`.OptionProto.defaultInstance)
__options += _root_.scalapb.LiteParser.readMessage[com.google.protobuf.`type`.OptionProto](_input__)
case 56 =>
__syntax = com.google.protobuf.`type`.Syntax.fromValue(_input__.readEnum())
case tag =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ object CodeGeneratorRequest extends scalapb.GeneratedMessageCompanion[com.google
case 18 =>
__parameter = Option(_input__.readStringRequireUtf8())
case 122 =>
__protoFile += _root_.scalapb.LiteParser.readMessage(_input__, com.google.protobuf.descriptor.FileDescriptorProto.defaultInstance)
__protoFile += _root_.scalapb.LiteParser.readMessage[com.google.protobuf.descriptor.FileDescriptorProto](_input__)
case 26 =>
__compilerVersion = Option(_root_.scalapb.LiteParser.readMessage(_input__, __compilerVersion.getOrElse(com.google.protobuf.compiler.plugin.Version.defaultInstance)))
__compilerVersion = Option(__compilerVersion.fold(_root_.scalapb.LiteParser.readMessage[com.google.protobuf.compiler.plugin.Version](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _)))
case tag =>
if (_unknownFields__ == null) {
_unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder()
Expand Down
Loading

0 comments on commit 23d2d96

Please sign in to comment.