Skip to content

Commit

Permalink
Allow typemappers to throw exception on default values
Browse files Browse the repository at this point in the history
  • Loading branch information
thesamet committed Dec 24, 2020
1 parent fe8749a commit d07dcde
Show file tree
Hide file tree
Showing 121 changed files with 998 additions and 773 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -725,19 +725,54 @@ class ProtobufGenerator(
printer.addWithDelimiter(",")(constructorFields(message).map(_.fullString))
}

private def usesBaseTypeInBuilder(field: FieldDescriptor) = field.isRequired || field.noBox

def generateBuilder(message: Descriptor)(printer: FunctionalPrinter): FunctionalPrinter = {
val myFullScalaName = message.scalaType.fullNameWithMaybeRoot(message)
val requiredFieldMap: Map[FieldDescriptor, Int] =
message.fields.filter(_.isRequired).zipWithIndex.toMap
case class Field(name: String, typeName: String, default: String, accessor: String)
case class Field(
name: String,
targetName: String,
typeName: String,
default: String,
accessor: String,
builder: String
)

val fields = message.fieldsWithoutOneofs.map { field =>
if (!field.isRepeated)
if (usesBaseTypeInBuilder(field)) {
// To handle custom types that have no default values, we wrap required/no-boxed messages in
// Option during parsing. We also apply the type mapper after parsing is complete.
if (field.isMessage)
Field(
s"__${field.scalaName}",
field.scalaName.asSymbol,
s"_root_.scala.Option[${field.baseSingleScalaTypeName}]",
C.None,
s"_root_.scala.Some(${toBaseTypeExpr(field)(s"_message__.${field.scalaName.asSymbol}", EnclosingType.None)})",
toCustomTypeExpr(field)(
s"__${field.scalaName}.getOrElse(${field.getMessageType.scalaType.fullName}.defaultInstance)",
EnclosingType.None
)
)
else
Field(
s"__${field.scalaName}",
field.scalaName.asSymbol,
field.baseSingleScalaTypeName,
defaultValueForGet(field, uncustomized = true),
toBaseTypeExpr(field)(s"_message__.${field.scalaName.asSymbol}", EnclosingType.None),
toCustomTypeExpr(field)(s"__${field.scalaName}", EnclosingType.None)
)
} else if (!field.isRepeated)
Field(
s"__${field.scalaName}",
field.scalaName.asSymbol,
field.scalaTypeName,
defaultValueForDefaultInstance(field),
s"_message__.${field.scalaName.asSymbol}"
s"_message__.${field.scalaName.asSymbol}",
s"__${field.scalaName}"
)
else {
val it =
Expand All @@ -746,25 +781,31 @@ class ProtobufGenerator(
else s"_message__.${field.scalaName.asSymbol}"
Field(
s"__${field.scalaName}",
field.scalaName.asSymbol,
s"collection.mutable.Builder[${field.singleScalaTypeName}, ${field.scalaTypeName}]",
field.collection.newBuilder,
s"${field.collection.newBuilder} ++= $it"
s"${field.collection.newBuilder} ++= $it",
s"__${field.scalaName}.result()"
)
}
} ++ message.getOneofs.asScala.map { oneof =>
Field(
s"__${oneof.scalaName.name}",
oneof.scalaName.name.asSymbol,
oneof.scalaType.fullName,
oneof.empty.fullName,
s"_message__.${oneof.scalaName.nameSymbol}"
s"_message__.${oneof.scalaName.nameSymbol}",
s"__${oneof.scalaName.name}"
)
} ++ (if (message.preservesUnknownFields)
Seq(
Field(
"`_unknownFields__`",
"unknownFields",
"_root_.scalapb.UnknownFieldSet.Builder",
"null",
"new _root_.scalapb.UnknownFieldSet.Builder(_message__.unknownFields)"
"new _root_.scalapb.UnknownFieldSet.Builder(_message__.unknownFields)",
"if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()"
)
)
else Seq.empty)
Expand Down Expand Up @@ -800,20 +841,7 @@ class ProtobufGenerator(
)
}.add(s"$myFullScalaName(")
.indented(
_.addWithDelimiter(",")(
(message.fieldsWithoutOneofs ++ message.getOneofs.asScala).map {
case e: FieldDescriptor if e.isRepeated =>
s" ${e.scalaName.asSymbol} = __${e.scalaName}.result()"
case e: FieldDescriptor =>
s" ${e.scalaName.asSymbol} = __${e.scalaName}"
case e: OneofDescriptor =>
s" ${e.scalaName.nameSymbol} = __${e.scalaName.name}"
} ++ (if (message.preservesUnknownFields)
Seq(
" unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()"
)
else Seq())
)
_.addWithDelimiter(",")(fields.map(e => s"${e.targetName} = ${e.builder}"))
)
.add(")")
)
Expand Down Expand Up @@ -856,28 +884,32 @@ class ProtobufGenerator(
.print(message.fields) { (printer, field) =>
val p = {
val newValBase = if (field.isMessage) {
val defInstance =
s"${field.getMessageType.scalaType.fullNameWithMaybeRoot(message)}.defaultInstance"
val baseInstance =
if (field.isRepeated) defInstance
else {
// In 0.10.x we can't simply call any of the new methods that relies on Builder,
// since the references message may have been generated using an older version of
// ScalaPB.
val baseName = field.baseSingleScalaTypeName
val read =
if (field.isRepeated) s"_root_.scalapb.LiteParser.readMessage[$baseName](_input__)"
else if (usesBaseTypeInBuilder(field)) {
s"_root_.scala.Some(__${field.scalaName}.fold(_root_.scalapb.LiteParser.readMessage[$baseName](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _)))"
} else {
val expr =
if (field.isInOneof)
s"__${fieldAccessorSymbol(field)}"
else s"__${field.scalaName}"
val mappedType =
toBaseFieldType(field).apply(expr, field.enclosingType)
val mappedType = toBaseFieldType(field).apply(expr, field.enclosingType)
if (field.isInOneof || field.supportsPresence)
(mappedType + s".getOrElse($defInstance)")
else mappedType
s"$mappedType.fold(_root_.scalapb.LiteParser.readMessage[$baseName](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))"
else s"_root_.scalapb.LiteParser.readMessage[$baseName](_input__, $mappedType)"
}
s"_root_.scalapb.LiteParser.readMessage(_input__, $baseInstance)"
read
} else if (field.isEnum)
s"${field.getEnumType.scalaType.fullNameWithMaybeRoot(message)}.fromValue(_input__.readEnum())"
else if (field.getType == Type.STRING) s"_input__.readStringRequireUtf8()"
else s"_input__.read${Types.capitalizedType(field.getType)}()"

val newVal = toCustomType(field)(newValBase)
val newVal =
if (!usesBaseTypeInBuilder(field)) toCustomType(field)(newValBase) else newValBase

val updateOp =
if (field.supportsPresence) s"__${field.scalaName} = Option($newVal)"
Expand Down Expand Up @@ -1425,6 +1457,9 @@ class ProtobufGenerator(
.add(s"""object $className extends $companionType {
| implicit def messageCompanion: $companionType = this""".stripMargin)
.indent
.add(
s"override def parseFrom(input: _root_.com.google.protobuf.CodedInputStream): ${message.scalaType.fullName} = newBuilder.merge(input).result"
)
.when(message.javaConversions)(generateToJavaProto(message))
.when(message.javaConversions)(generateFromJavaProto(message))
.call(generateMerge(message))
Expand Down
7 changes: 4 additions & 3 deletions docs/src/main/scala/com/thesamet/docs/json/MyContainer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ final case class MyContainer(

object MyContainer extends scalapb.GeneratedMessageCompanion[com.thesamet.docs.json.MyContainer] {
implicit def messageCompanion: scalapb.GeneratedMessageCompanion[com.thesamet.docs.json.MyContainer] = this
override def parseFrom(input: _root_.com.google.protobuf.CodedInputStream): com.thesamet.docs.json.MyContainer = newBuilder.merge(input).result
def merge(`_message__`: com.thesamet.docs.json.MyContainer, `_input__`: _root_.com.google.protobuf.CodedInputStream): com.thesamet.docs.json.MyContainer = newBuilder(_message__).merge(_input__).result()
implicit def messageReads: _root_.scalapb.descriptors.Reads[com.thesamet.docs.json.MyContainer] = _root_.scalapb.descriptors.Reads{
case _root_.scalapb.descriptors.PMessage(__fieldsMap) =>
Expand Down Expand Up @@ -94,7 +95,7 @@ object MyContainer extends scalapb.GeneratedMessageCompanion[com.thesamet.docs.j
_tag__ match {
case 0 => _done__ = true
case 10 =>
__myAny = Option(_root_.scalapb.LiteParser.readMessage(_input__, __myAny.getOrElse(com.google.protobuf.any.Any.defaultInstance)))
__myAny = Option(__myAny.fold(_root_.scalapb.LiteParser.readMessage[com.google.protobuf.any.Any](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _)))
case tag =>
if (_unknownFields__ == null) {
_unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder()
Expand All @@ -106,8 +107,8 @@ object MyContainer extends scalapb.GeneratedMessageCompanion[com.thesamet.docs.j
}
def result(): com.thesamet.docs.json.MyContainer = {
com.thesamet.docs.json.MyContainer(
myAny = __myAny,
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
myAny = __myAny,
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
)
}
}
Expand Down
5 changes: 3 additions & 2 deletions docs/src/main/scala/com/thesamet/docs/json/MyMessage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ final case class MyMessage(

object MyMessage extends scalapb.GeneratedMessageCompanion[com.thesamet.docs.json.MyMessage] {
implicit def messageCompanion: scalapb.GeneratedMessageCompanion[com.thesamet.docs.json.MyMessage] = this
override def parseFrom(input: _root_.com.google.protobuf.CodedInputStream): com.thesamet.docs.json.MyMessage = newBuilder.merge(input).result
def merge(`_message__`: com.thesamet.docs.json.MyMessage, `_input__`: _root_.com.google.protobuf.CodedInputStream): com.thesamet.docs.json.MyMessage = newBuilder(_message__).merge(_input__).result()
implicit def messageReads: _root_.scalapb.descriptors.Reads[com.thesamet.docs.json.MyMessage] = _root_.scalapb.descriptors.Reads{
case _root_.scalapb.descriptors.PMessage(__fieldsMap) =>
Expand Down Expand Up @@ -104,8 +105,8 @@ object MyMessage extends scalapb.GeneratedMessageCompanion[com.thesamet.docs.jso
}
def result(): com.thesamet.docs.json.MyMessage = {
com.thesamet.docs.json.MyMessage(
x = __x,
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
x = __x,
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
)
}
}
Expand Down
5 changes: 3 additions & 2 deletions docs/src/main/scala/mytypes/duration/Duration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ final case class Duration(

object Duration extends scalapb.GeneratedMessageCompanion[mytypes.duration.Duration] {
implicit def messageCompanion: scalapb.GeneratedMessageCompanion[mytypes.duration.Duration] = this
override def parseFrom(input: _root_.com.google.protobuf.CodedInputStream): mytypes.duration.Duration = newBuilder.merge(input).result
def merge(`_message__`: mytypes.duration.Duration, `_input__`: _root_.com.google.protobuf.CodedInputStream): mytypes.duration.Duration = newBuilder(_message__).merge(_input__).result()
implicit def messageReads: _root_.scalapb.descriptors.Reads[mytypes.duration.Duration] = _root_.scalapb.descriptors.Reads{
case _root_.scalapb.descriptors.PMessage(__fieldsMap) =>
Expand Down Expand Up @@ -104,8 +105,8 @@ object Duration extends scalapb.GeneratedMessageCompanion[mytypes.duration.Durat
}
def result(): mytypes.duration.Duration = {
mytypes.duration.Duration(
seconds = __seconds,
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
seconds = __seconds,
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
)
}
}
Expand Down
20 changes: 11 additions & 9 deletions docs/src/main/scala/scalapb/docs/person/Person.scala
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ final case class Person(

object Person extends scalapb.GeneratedMessageCompanion[scalapb.docs.person.Person] {
implicit def messageCompanion: scalapb.GeneratedMessageCompanion[scalapb.docs.person.Person] = this
override def parseFrom(input: _root_.com.google.protobuf.CodedInputStream): scalapb.docs.person.Person = newBuilder.merge(input).result
def merge(`_message__`: scalapb.docs.person.Person, `_input__`: _root_.com.google.protobuf.CodedInputStream): scalapb.docs.person.Person = newBuilder(_message__).merge(_input__).result()
implicit def messageReads: _root_.scalapb.descriptors.Reads[scalapb.docs.person.Person] = _root_.scalapb.descriptors.Reads{
case _root_.scalapb.descriptors.PMessage(__fieldsMap) =>
Expand Down Expand Up @@ -148,7 +149,7 @@ object Person extends scalapb.GeneratedMessageCompanion[scalapb.docs.person.Pers
case 16 =>
__age = _input__.readInt32()
case 26 =>
__addresses += _root_.scalapb.LiteParser.readMessage(_input__, scalapb.docs.person.Person.Address.defaultInstance)
__addresses += _root_.scalapb.LiteParser.readMessage[scalapb.docs.person.Person.Address](_input__)
case tag =>
if (_unknownFields__ == null) {
_unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder()
Expand All @@ -160,10 +161,10 @@ object Person extends scalapb.GeneratedMessageCompanion[scalapb.docs.person.Pers
}
def result(): scalapb.docs.person.Person = {
scalapb.docs.person.Person(
name = __name,
age = __age,
addresses = __addresses.result(),
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
name = __name,
age = __age,
addresses = __addresses.result(),
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
)
}
}
Expand Down Expand Up @@ -319,6 +320,7 @@ object Person extends scalapb.GeneratedMessageCompanion[scalapb.docs.person.Pers

object Address extends scalapb.GeneratedMessageCompanion[scalapb.docs.person.Person.Address] {
implicit def messageCompanion: scalapb.GeneratedMessageCompanion[scalapb.docs.person.Person.Address] = this
override def parseFrom(input: _root_.com.google.protobuf.CodedInputStream): scalapb.docs.person.Person.Address = newBuilder.merge(input).result
def merge(`_message__`: scalapb.docs.person.Person.Address, `_input__`: _root_.com.google.protobuf.CodedInputStream): scalapb.docs.person.Person.Address = newBuilder(_message__).merge(_input__).result()
implicit def messageReads: _root_.scalapb.descriptors.Reads[scalapb.docs.person.Person.Address] = _root_.scalapb.descriptors.Reads{
case _root_.scalapb.descriptors.PMessage(__fieldsMap) =>
Expand Down Expand Up @@ -373,10 +375,10 @@ object Person extends scalapb.GeneratedMessageCompanion[scalapb.docs.person.Pers
}
def result(): scalapb.docs.person.Person.Address = {
scalapb.docs.person.Person.Address(
addressType = __addressType,
street = __street,
city = __city,
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
addressType = __addressType,
street = __street,
city = __city,
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
)
}
}
Expand Down
5 changes: 3 additions & 2 deletions docs/src/main/scala/scalapb/perf/protos/Enum.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ final case class Enum(

object Enum extends scalapb.GeneratedMessageCompanion[scalapb.perf.protos.Enum] {
implicit def messageCompanion: scalapb.GeneratedMessageCompanion[scalapb.perf.protos.Enum] = this
override def parseFrom(input: _root_.com.google.protobuf.CodedInputStream): scalapb.perf.protos.Enum = newBuilder.merge(input).result
def merge(`_message__`: scalapb.perf.protos.Enum, `_input__`: _root_.com.google.protobuf.CodedInputStream): scalapb.perf.protos.Enum = newBuilder(_message__).merge(_input__).result()
implicit def messageReads: _root_.scalapb.descriptors.Reads[scalapb.perf.protos.Enum] = _root_.scalapb.descriptors.Reads{
case _root_.scalapb.descriptors.PMessage(__fieldsMap) =>
Expand Down Expand Up @@ -108,8 +109,8 @@ object Enum extends scalapb.GeneratedMessageCompanion[scalapb.perf.protos.Enum]
}
def result(): scalapb.perf.protos.Enum = {
scalapb.perf.protos.Enum(
color = __color,
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
color = __color,
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
)
}
}
Expand Down
5 changes: 3 additions & 2 deletions docs/src/main/scala/scalapb/perf/protos/EnumVector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ final case class EnumVector(

object EnumVector extends scalapb.GeneratedMessageCompanion[scalapb.perf.protos.EnumVector] {
implicit def messageCompanion: scalapb.GeneratedMessageCompanion[scalapb.perf.protos.EnumVector] = this
override def parseFrom(input: _root_.com.google.protobuf.CodedInputStream): scalapb.perf.protos.EnumVector = newBuilder.merge(input).result
def merge(`_message__`: scalapb.perf.protos.EnumVector, `_input__`: _root_.com.google.protobuf.CodedInputStream): scalapb.perf.protos.EnumVector = newBuilder(_message__).merge(_input__).result()
implicit def messageReads: _root_.scalapb.descriptors.Reads[scalapb.perf.protos.EnumVector] = _root_.scalapb.descriptors.Reads{
case _root_.scalapb.descriptors.PMessage(__fieldsMap) =>
Expand Down Expand Up @@ -121,8 +122,8 @@ object EnumVector extends scalapb.GeneratedMessageCompanion[scalapb.perf.protos.
}
def result(): scalapb.perf.protos.EnumVector = {
scalapb.perf.protos.EnumVector(
colors = __colors.result(),
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
colors = __colors.result(),
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
)
}
}
Expand Down
5 changes: 3 additions & 2 deletions docs/src/main/scala/scalapb/perf/protos/IntVector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ final case class IntVector(

object IntVector extends scalapb.GeneratedMessageCompanion[scalapb.perf.protos.IntVector] {
implicit def messageCompanion: scalapb.GeneratedMessageCompanion[scalapb.perf.protos.IntVector] = this
override def parseFrom(input: _root_.com.google.protobuf.CodedInputStream): scalapb.perf.protos.IntVector = newBuilder.merge(input).result
def merge(`_message__`: scalapb.perf.protos.IntVector, `_input__`: _root_.com.google.protobuf.CodedInputStream): scalapb.perf.protos.IntVector = newBuilder(_message__).merge(_input__).result()
implicit def messageReads: _root_.scalapb.descriptors.Reads[scalapb.perf.protos.IntVector] = _root_.scalapb.descriptors.Reads{
case _root_.scalapb.descriptors.PMessage(__fieldsMap) =>
Expand Down Expand Up @@ -117,8 +118,8 @@ object IntVector extends scalapb.GeneratedMessageCompanion[scalapb.perf.protos.I
}
def result(): scalapb.perf.protos.IntVector = {
scalapb.perf.protos.IntVector(
ints = __ints.result(),
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
ints = __ints.result(),
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
)
}
}
Expand Down
Loading

0 comments on commit d07dcde

Please sign in to comment.