From 61ceba411f33368112d16cba6f9a3c1cd71a764f Mon Sep 17 00:00:00 2001 From: xuwei-k <6b656e6a69@gmail.com> Date: Thu, 19 Nov 2015 17:50:22 +0900 Subject: [PATCH] add grpc support http://www.grpc.io/ --- .../scalapb/compiler/DescriptorPimps.scala | 18 +- .../scalapb/compiler/FunctionalPrinter.scala | 27 +- .../scalapb/compiler/GrpcServicePrinter.scala | 306 ++++++++++++++++++ .../scalapb/compiler/ProtobufGenerator.scala | 16 +- .../scalapb/compiler/StreamType.scala | 9 + e2e/build.sbt | 58 +++- e2e/src/main/protobuf/service.proto | 36 +++ .../com/trueaccord/pb/Service1JavaImpl.scala | 53 +++ .../com/trueaccord/pb/Service1ScalaImpl.scala | 56 ++++ .../scala/GrpcServiceJavaServerSpec.scala | 84 +++++ .../scala/GrpcServiceScalaServerSpec.scala | 109 +++++++ e2e/src/test/scala/GrpcServiceSpecBase.scala | 57 ++++ e2e/src/test/scala/UniquePortGenerator.scala | 34 ++ proptest/build.sbt | 1 + proptest/src/test/scala/GraphGen.scala | 20 +- proptest/src/test/scala/Nodes.scala | 27 +- .../src/test/scala/SchemaGenerators.scala | 5 +- 17 files changed, 907 insertions(+), 9 deletions(-) create mode 100644 compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/GrpcServicePrinter.scala create mode 100644 compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/StreamType.scala create mode 100644 e2e/src/main/protobuf/service.proto create mode 100644 e2e/src/main/scala/com/trueaccord/pb/Service1JavaImpl.scala create mode 100644 e2e/src/main/scala/com/trueaccord/pb/Service1ScalaImpl.scala create mode 100644 e2e/src/test/scala/GrpcServiceJavaServerSpec.scala create mode 100644 e2e/src/test/scala/GrpcServiceScalaServerSpec.scala create mode 100644 e2e/src/test/scala/GrpcServiceSpecBase.scala create mode 100644 e2e/src/test/scala/UniquePortGenerator.scala diff --git a/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/DescriptorPimps.scala b/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/DescriptorPimps.scala index 8d0527be8..c07363763 100644 --- a/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/DescriptorPimps.scala +++ b/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/DescriptorPimps.scala @@ -26,7 +26,7 @@ trait DescriptorPimps { def asSymbol: String = if (SCALA_RESERVED_WORDS.contains(s)) s"`$s`" else s } - private def snakeCaseToCamelCase(name: String, upperInitial: Boolean = false): String = { + protected final def snakeCaseToCamelCase(name: String, upperInitial: Boolean = false): String = { val b = new StringBuilder() @annotation.tailrec def inner(name: String, index: Int, capNext: Boolean): Unit = if (name.nonEmpty) { @@ -45,6 +45,22 @@ trait DescriptorPimps { b.toString } + implicit final class MethodDescriptorPimp(self: MethodDescriptor) { + def scalaOut: String = self.getOutputType.scalaTypeName + + def scalaIn: String = self.getInputType.scalaTypeName + + def streamType: StreamType = { + val p = self.toProto + (p.getClientStreaming, p.getServerStreaming) match { + case (false, false) => StreamType.Unary + case (true, false) => StreamType.ClientStreaming + case (false, true) => StreamType.ServerStreaming + case (true, true) => StreamType.Bidirectional + } + } + } + implicit class FieldDescriptorPimp(val fd: FieldDescriptor) { def containingOneOf: Option[OneofDescriptor] = Option(fd.getContainingOneof) diff --git a/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/FunctionalPrinter.scala b/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/FunctionalPrinter.scala index b9f601000..5aa74df55 100644 --- a/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/FunctionalPrinter.scala +++ b/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/FunctionalPrinter.scala @@ -1,15 +1,36 @@ package com.trueaccord.scalapb.compiler +import com.trueaccord.scalapb.compiler.FunctionalPrinter.PrinterEndo + trait FPrintable { def print(printer: FunctionalPrinter): FunctionalPrinter } +object PrinterEndo { + def apply(endo: PrinterEndo): PrinterEndo = endo +} + +object FunctionalPrinter { + type PrinterEndo = FunctionalPrinter => FunctionalPrinter + val newline: PrinterEndo = _.newline +} + case class FunctionalPrinter(content: List[String] = Nil, indentLevel: Int = 0) { val INDENT_SIZE = 2 + + def seq(s: Seq[String]): FunctionalPrinter = add(s: _*) + def add(s: String*): FunctionalPrinter = { copy(content = s.map(l => " " * (indentLevel * INDENT_SIZE) + l).reverseIterator.toList ::: content) } + /** add with indent */ + def addI(s: String*): FunctionalPrinter = { + this.indent.seq(s).outdent + } + + def newline: FunctionalPrinter = add("") + def addM(s: String): FunctionalPrinter = add(s.stripMargin.split("\n", -1): _*) @@ -31,7 +52,11 @@ case class FunctionalPrinter(content: List[String] = Nil, indentLevel: Int = 0) def indent = copy(indentLevel = indentLevel + 1) def outdent = copy(indentLevel = indentLevel - 1) - def call(f: FunctionalPrinter => FunctionalPrinter) = f(this) + def call(f: (FunctionalPrinter => FunctionalPrinter)*): FunctionalPrinter = + f.foldLeft(this)((p, f) => f(p)) + + def withIndent(f: (FunctionalPrinter => FunctionalPrinter)*): FunctionalPrinter = + f.foldLeft(this.indent)((p, f) => f(p)).outdent def when(cond: => Boolean)(func: FunctionalPrinter => FunctionalPrinter) = if (cond) { diff --git a/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/GrpcServicePrinter.scala b/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/GrpcServicePrinter.scala new file mode 100644 index 000000000..1a05cb2dc --- /dev/null +++ b/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/GrpcServicePrinter.scala @@ -0,0 +1,306 @@ +package com.trueaccord.scalapb.compiler + +import java.util.Locale + +import com.google.protobuf.Descriptors.{MethodDescriptor, ServiceDescriptor} +import com.trueaccord.scalapb.compiler.FunctionalPrinter.PrinterEndo +import scala.collection.JavaConverters._ + +final class GrpcServicePrinter(service: ServiceDescriptor, override val params: GeneratorParams) extends DescriptorPimps { + + private[this] def methodName0(method: MethodDescriptor): String = snakeCaseToCamelCase(method.getName) + private[this] def methodName(method: MethodDescriptor): String = methodName0(method).asSymbol + private[this] def observer(typeParam: String): String = s"$streamObserver[$typeParam]" + + private[this] def methodSignature(method: MethodDescriptor, t: String => String = identity[String]): String = { + method.streamType match { + case StreamType.Unary => + s"def ${methodName(method)}(request: ${method.scalaIn}): ${t(method.scalaOut)}" + case StreamType.ServerStreaming => + s"def ${methodName(method)}(request: ${method.scalaIn}, observer: ${observer(method.scalaOut)}): Unit" + case StreamType.ClientStreaming | StreamType.Bidirectional => + s"def ${methodName(method)}(observer: ${observer(method.scalaOut)}): ${observer(method.scalaIn)}" + } + } + + private[this] def base: PrinterEndo = { + val F = "F[" + (_: String) + "]" + + val methods: PrinterEndo = { p => + p.seq(service.getMethods.asScala.map(methodSignature(_, F))) + } + + { p => + p.add(s"trait ${serviceName(F("_"))} {").withIndent(methods).add("}") + } + } + + private[this] val channel = "_root_.io.grpc.Channel" + private[this] val callOptions = "_root_.io.grpc.CallOptions" + + private[this] def serviceName0 = service.getName.asSymbol + private[this] def serviceName(p: String) = serviceName0 + "[" + p + "]" + private[this] val serviceBlocking = serviceName("({type l[a] = a})#l") + private[this] val serviceFuture = serviceName("scala.concurrent.Future") + + private[this] val futureUnaryCall = "_root_.io.grpc.stub.ClientCalls.futureUnaryCall" + private[this] val abstractStub = "_root_.io.grpc.stub.AbstractStub" + private[this] val streamObserver = "_root_.io.grpc.stub.StreamObserver" + + private[this] object serverCalls { + val unary = "_root_.io.grpc.stub.ServerCalls.asyncUnaryCall" + val clientStreaming = "_root_.io.grpc.stub.ServerCalls.asyncClientStreamingCall" + val serverStreaming = "_root_.io.grpc.stub.ServerCalls.asyncServerStreamingCall" + val bidiStreaming = "_root_.io.grpc.stub.ServerCalls.asyncBidiStreamingCall" + } + + private[this] object clientCalls { + val clientStreaming = "_root_.io.grpc.stub.ClientCalls.asyncClientStreamingCall" + val serverStreaming = "_root_.io.grpc.stub.ClientCalls.asyncServerStreamingCall" + val bidiStreaming = "_root_.io.grpc.stub.ClientCalls.asyncBidiStreamingCall" + } + + private[this] val blockingClientName: String = service.getName + "BlockingClientImpl" + + private[this] def clientMethodImpl(m: MethodDescriptor, blocking: Boolean) = PrinterEndo{ p => + m.streamType match { + case StreamType.Unary => + if(blocking) { + p.add( + "override " + methodSignature(m, identity) + " = {" + ).add( + s""" ${m.scalaOut}.fromJavaProto(_root_.io.grpc.stub.ClientCalls.blockingUnaryCall(channel.newCall(${methodDescriptorName(m)}, options), ${m.scalaIn}.toJavaProto(request)))""", + "}" + ) + } else { + p.add( + "override " + methodSignature(m, "scala.concurrent.Future[" + _ + "]") + " = {" + ).add( + s""" $guavaFuture2ScalaFuture($futureUnaryCall(channel.newCall(${methodDescriptorName(m)}, options), ${m.scalaIn}.toJavaProto(request)))(${m.scalaOut}.fromJavaProto(_))""", + "}" + ) + } + case StreamType.ServerStreaming => + p.add( + "override " + methodSignature(m) + " = {" + ).addI( + s"${clientCalls.serverStreaming}(channel.newCall(${methodDescriptorName(m)}, options), ${m.scalaIn}.toJavaProto(request), $contramapObserver(observer)(${m.scalaOut}.fromJavaProto))" + ).add("}") + case streamType => + val call = if (streamType == StreamType.ClientStreaming) { + clientCalls.clientStreaming + } else { + clientCalls.bidiStreaming + } + + p.add( + "override " + methodSignature(m) + " = {" + ).indent.add( + s"$contramapObserver(" + ).indent.add( + s"$call(channel.newCall(${methodDescriptorName(m)}, options), $contramapObserver(observer)(${m.scalaOut}.fromJavaProto)))(${m.scalaIn}.toJavaProto" + ).outdent.add(")").outdent.add("}") + } + } + + private[this] val blockingClientImpl: PrinterEndo = { p => + val methods = service.getMethods.asScala.map(clientMethodImpl(_, true)) + + val build = + s" override def build(channel: $channel, options: $callOptions): $blockingClientName = new $blockingClientName(channel, options)" + + p.add( + s"class $blockingClientName(channel: $channel, options: $callOptions = $callOptions.DEFAULT) extends $abstractStub[$blockingClientName](channel, options) with $serviceBlocking {" + ).withIndent( + methods : _* + ).add( + build + ).add( + "}" + ) + } + + private[this] val contramapObserver = "contramapObserver" + private[this] val contramapObserverImpl = s"""private[this] def $contramapObserver[A, B](observer: ${observer("A")})(f: B => A): ${observer("B")} = + new ${observer("B")} { + override def onNext(value: B): Unit = observer.onNext(f(value)) + override def onError(t: Throwable): Unit = observer.onError(t) + override def onCompleted(): Unit = observer.onCompleted() + }""" + + private[this] val guavaFuture2ScalaFuture = "guavaFuture2ScalaFuture" + + private[this] val guavaFuture2ScalaFutureImpl = { + s"""private[this] def $guavaFuture2ScalaFuture[A, B](guavaFuture: _root_.com.google.common.util.concurrent.ListenableFuture[A])(converter: A => B): scala.concurrent.Future[B] = { + val p = scala.concurrent.Promise[B]() + _root_.com.google.common.util.concurrent.Futures.addCallback(guavaFuture, new _root_.com.google.common.util.concurrent.FutureCallback[A] { + override def onFailure(t: Throwable) = p.failure(t) + override def onSuccess(a: A) = p.success(converter(a)) + }) + p.future + }""" + } + + private[this] val asyncClientName = service.getName + "AsyncClientImpl" + + private[this] val asyncClientImpl: PrinterEndo = { p => + val methods = service.getMethods.asScala.map(clientMethodImpl(_, false)) + + val build = + s" override def build(channel: $channel, options: $callOptions): $asyncClientName = new $asyncClientName(channel, options)" + + p.add( + s"class $asyncClientName(channel: $channel, options: $callOptions = $callOptions.DEFAULT) extends $abstractStub[$asyncClientName](channel, options) with $serviceFuture {" + ).withIndent( + methods : _* + ).add( + build + ).add( + "}" + ) + } + + private[this] def methodDescriptorName(method: MethodDescriptor): String = + "METHOD_" + method.getName.toUpperCase(Locale.ENGLISH) + + private[this] def methodDescriptor(method: MethodDescriptor) = { + val inJava = method.getInputType.javaTypeName + val outJava = method.getOutputType.javaTypeName + + def marshaller(typeName: String) = + s"_root_.io.grpc.protobuf.ProtoUtils.marshaller($typeName.getDefaultInstance)" + + val methodType = method.streamType match { + case StreamType.Unary => "UNARY" + case StreamType.ClientStreaming => "CLIENT_STREAMING" + case StreamType.ServerStreaming => "SERVER_STREAMING" + case StreamType.Bidirectional => "BIDI_STREAMING" + } + + val grpcMethodDescriptor = "_root_.io.grpc.MethodDescriptor" + +s""" private[this] val ${methodDescriptorName(method)}: $grpcMethodDescriptor[$inJava, $outJava] = + $grpcMethodDescriptor.create( + $grpcMethodDescriptor.MethodType.$methodType, + $grpcMethodDescriptor.generateFullMethodName("${service.getFullName}", "${method.getName}"), + ${marshaller(inJava)}, + ${marshaller(outJava)} + )""" + } + + private[this] val methodDescriptors: Seq[String] = service.getMethods.asScala.map(methodDescriptor) + + private[this] def callMethodName(method: MethodDescriptor) = + methodName0(method) + "Method" + + private[this] def callMethod(method: MethodDescriptor) = + method.streamType match { + case StreamType.Unary => + s"${callMethodName(method)}(service, executionContext)" + case _ => + s"${callMethodName(method)}(service)" + } + + private[this] def createMethod(method: MethodDescriptor): String = { + val javaIn = method.getInputType.javaTypeName + val javaOut = method.getOutputType.javaTypeName + val executionContext = "executionContext" + val name = callMethodName(method) + val serviceImpl = "serviceImpl" + method.streamType match { + case StreamType.Unary => + val serverMethod = s"_root_.io.grpc.stub.ServerCalls.UnaryMethod[$javaIn, $javaOut]" +s""" def ${name}($serviceImpl: $serviceFuture, $executionContext: scala.concurrent.ExecutionContext): $serverMethod = { + new $serverMethod { + override def invoke(request: $javaIn, observer: $streamObserver[$javaOut]): Unit = + $serviceImpl.${methodName(method)}(${method.scalaIn}.fromJavaProto(request)).onComplete { + case scala.util.Success(value) => + observer.onNext(${method.scalaOut}.toJavaProto(value)) + observer.onCompleted() + case scala.util.Failure(error) => + observer.onError(error) + observer.onCompleted() + }($executionContext) + } + }""" + case StreamType.ServerStreaming => + val serverMethod = s"_root_.io.grpc.stub.ServerCalls.ServerStreamingMethod[$javaIn, $javaOut]" + + s""" def ${name}($serviceImpl: $serviceFuture): $serverMethod = { + new $serverMethod { + override def invoke(request: $javaIn, observer: $streamObserver[$javaOut]): Unit = + $serviceImpl.${methodName0(method)}(${method.scalaIn}.fromJavaProto(request), $contramapObserver(observer)(${method.scalaOut}.toJavaProto)) + } + }""" + case _ => + val serverMethod = if(method.streamType == StreamType.ClientStreaming) { + s"_root_.io.grpc.stub.ServerCalls.ClientStreamingMethod[$javaIn, $javaOut]" + } else { + s"_root_.io.grpc.stub.ServerCalls.BidiStreamingMethod[$javaIn, $javaOut]" + } + + s""" def ${name}($serviceImpl: $serviceFuture): $serverMethod = { + new $serverMethod { + override def invoke(observer: $streamObserver[$javaOut]): $streamObserver[$javaIn] = + $contramapObserver($serviceImpl.${methodName0(method)}($contramapObserver(observer)(${method.scalaOut}.toJavaProto)))(${method.scalaIn}.fromJavaProto) + } + }""" + } + } + + private[this] val bindService = { + val executionContext = "executionContext" + val methods = service.getMethods.asScala.map { m => + + val call = m.streamType match { + case StreamType.Unary => serverCalls.unary + case StreamType.ClientStreaming => serverCalls.clientStreaming + case StreamType.ServerStreaming => serverCalls.serverStreaming + case StreamType.Bidirectional => serverCalls.bidiStreaming + } + +s""".addMethod( + ${methodDescriptorName(m)}, + $call( + ${callMethod(m)} + ) + )""" + }.mkString + + val serverServiceDef = "_root_.io.grpc.ServerServiceDefinition" + +s"""def bindService(service: $serviceFuture, $executionContext: scala.concurrent.ExecutionContext): $serverServiceDef = + $serverServiceDef.builder("${service.getFullName}")$methods.build() + """ + } + + val objectName = service.getName + "Grpc" + + def printService(printer: FunctionalPrinter): FunctionalPrinter = { + printer.add( + "package " + service.getFile.scalaPackageName, + "", + "import scala.language.higherKinds", + "", + s"object $objectName {" + ).seq( + service.getMethods.asScala.map(createMethod) + ).seq( + methodDescriptors + ).newline.withIndent( + base, + FunctionalPrinter.newline, + blockingClientImpl, + FunctionalPrinter.newline, + asyncClientImpl + ).newline.addI( + bindService, + guavaFuture2ScalaFutureImpl, + contramapObserverImpl, + s"def blockingClient(channel: $channel): $serviceBlocking = new $blockingClientName(channel)", + s"def futureClient(channel: $channel): $serviceFuture = new $asyncClientName(channel)" + ).add( + "" + ).outdent.add("}") + } +} diff --git a/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/ProtobufGenerator.scala b/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/ProtobufGenerator.scala index dc6dea34c..e6f7179a7 100644 --- a/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/ProtobufGenerator.scala +++ b/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/ProtobufGenerator.scala @@ -6,7 +6,7 @@ import com.google.protobuf.{ByteString => GoogleByteString} import com.google.protobuf.compiler.PluginProtos.{CodeGeneratorRequest, CodeGeneratorResponse} import scala.collection.JavaConversions._ -case class GeneratorParams(javaConversions: Boolean = false, flatPackage: Boolean = false) +case class GeneratorParams(javaConversions: Boolean = false, flatPackage: Boolean = false, grpc: Boolean = false) class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps { def printEnum(e: EnumDescriptor, printer: FunctionalPrinter): FunctionalPrinter = { @@ -942,6 +942,17 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps { } def generateScalaFilesForFileDescriptor(file: FileDescriptor): Seq[CodeGeneratorResponse.File] = { + val serviceFiles = if(params.grpc) { + file.getServices.map { service => + val p = new GrpcServicePrinter(service, params) + val code = p.printService(FunctionalPrinter()).result() + val b = CodeGeneratorResponse.File.newBuilder() + b.setName(file.scalaPackageName.replace('.', '/') + "/" + p.objectName + ".scala") + b.setContent(code) + b.build + } + } else Nil + val enumFiles = for { enum <- file.getEnumTypes } yield { @@ -977,7 +988,7 @@ class ProtobufGenerator(val params: GeneratorParams) extends DescriptorPimps { b.build } - enumFiles ++ messageFiles :+ fileDescriptorObjectFile + serviceFiles ++ enumFiles ++ messageFiles :+ fileDescriptorObjectFile } } @@ -986,6 +997,7 @@ object ProtobufGenerator { params.split(",").map(_.trim).filter(_.nonEmpty).foldLeft[Either[String, GeneratorParams]](Right(GeneratorParams())) { case (Right(params), "java_conversions") => Right(params.copy(javaConversions = true)) case (Right(params), "flat_package") => Right(params.copy(flatPackage = true)) + case (Right(params), "grpc") => Right(params.copy(grpc = true, javaConversions = true)) case (Right(params), p) => Left(s"Unrecognized parameter: '$p'") case (x, _) => x } diff --git a/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/StreamType.scala b/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/StreamType.scala new file mode 100644 index 000000000..f6b4a859d --- /dev/null +++ b/compiler-plugin/src/main/scala/com/trueaccord/scalapb/compiler/StreamType.scala @@ -0,0 +1,9 @@ +package com.trueaccord.scalapb.compiler + +sealed abstract class StreamType extends Product with Serializable +object StreamType { + case object Unary extends StreamType + case object ClientStreaming extends StreamType + case object ServerStreaming extends StreamType + case object Bidirectional extends StreamType +} diff --git a/e2e/build.sbt b/e2e/build.sbt index d72f1203d..873352aa6 100644 --- a/e2e/build.sbt +++ b/e2e/build.sbt @@ -6,12 +6,66 @@ PB.scalapbVersion in PB.protobufConfig := com.trueaccord.scalapb.Version.scalapb PB.javaConversions in PB.protobufConfig := true -PB.runProtoc in PB.protobufConfig := (args => - com.github.os72.protocjar.Protoc.runProtoc("-v300" +: args.toArray)) +PB.runProtoc in PB.protobufConfig := {args0 => + IO.withTemporaryDirectory{ dir => + val exe = dir / "grpc.exe" + java.nio.file.Files.write(exe.toPath, grpcExe.value.get()) + exe.setExecutable(true) + val args = args0 ++ Array( + s"--plugin=protoc-gen-java_rpc=${exe.getAbsolutePath}", + s"--java_rpc_out=${((sourceManaged in Compile).value / "compiled_protobuf").getAbsolutePath}" + ) + com.github.os72.protocjar.Protoc.runProtoc("-v300" +: args.toArray) + } +} + +val grpcVersion = "0.9.0" libraryDependencies ++= Seq( "org.scalatest" %% "scalatest" % "2.2.1" % "test", + "io.grpc" % "grpc-all" % grpcVersion, "org.scalacheck" %% "scalacheck" % "1.12.4" % "test", "com.trueaccord.scalapb" %% "scalapb-runtime" % com.trueaccord.scalapb.Version.scalapbVersion % PB.protobufConfig ) +val grpcExe = SettingKey[xsbti.api.Lazy[Array[Byte]]]("grpcExeFile") + +def grpcExeUrl() = { + val os = if(scala.util.Properties.isMac){ + "osx-x86_64" + }else if(scala.util.Properties.isWin){ + "windows-x86_64" + }else{ + "linux-x86_64" + } + val artifactId = "protoc-gen-grpc-java" + s"http://repo1.maven.org/maven2/io/grpc/${artifactId}/${grpcVersion}/${artifactId}-${grpcVersion}-${os}.exe" +} + +grpcExe := xsbti.SafeLazy{ + IO.withTemporaryDirectory{ dir => + val f = dir / "temp.exe" + val u = grpcExeUrl() + println("download from " + u) + IO.download(url(u), f) + java.nio.file.Files.readAllBytes(f.toPath) + } +} + +// TODO add `grpc: SettingKey[Boolean]` to sbt-scalapb +PB.protocOptions in PB.protobufConfig := { + val conf = (PB.generatedTargets in PB.protobufConfig).value + val scalaOpts = conf.find(_._2.endsWith(".scala")) match { + case Some(targetForScala) => + Seq(s"--scala_out=grpc:${targetForScala._1.absolutePath}") + case None => + Nil + } + val javaOpts = conf.find(_._2.endsWith(".java")) match { + case Some(targetForJava) => + Seq(s"--java_out=${targetForJava._1.absolutePath}") + case None => + Nil + } + scalaOpts ++ javaOpts +} diff --git a/e2e/src/main/protobuf/service.proto b/e2e/src/main/protobuf/service.proto new file mode 100644 index 000000000..522c6ed75 --- /dev/null +++ b/e2e/src/main/protobuf/service.proto @@ -0,0 +1,36 @@ +syntax = "proto3"; + +package com.trueaccord.proto.e2e; + +import "scalapb/scalapb.proto"; + +message Req1 { + string request = 1; +} +message Res1 { + int32 length = 1; +} + +message Req2 {} +message Res2 { + uint32 count = 1; +} + +message Req3 { + uint32 num = 1; +} +message Res3 {} + +message Req4 { + int32 a = 1; +} +message Res4 { + int32 b = 1; +} + +service Service1 { + rpc method1(Req1) returns (Res1) {} + rpc method2(stream Req2) returns (Res2) {} + rpc method3(Req3) returns (stream Res3) {} + rpc method4(stream Req4) returns (stream Res4) {} +} diff --git a/e2e/src/main/scala/com/trueaccord/pb/Service1JavaImpl.scala b/e2e/src/main/scala/com/trueaccord/pb/Service1JavaImpl.scala new file mode 100644 index 000000000..9ce3e2a59 --- /dev/null +++ b/e2e/src/main/scala/com/trueaccord/pb/Service1JavaImpl.scala @@ -0,0 +1,53 @@ +package com.trueaccord.pb + +import java.util.concurrent.atomic.AtomicInteger + +import com.trueaccord.proto.e2e.Service._ +import com.trueaccord.proto.e2e.Service1Grpc._ +import io.grpc.stub.StreamObserver + +class Service1JavaImpl extends Service1{ + + override def method1(request: Req1, observer: StreamObserver[Res1]): Unit = { + val res = Res1.newBuilder.setLength(request.getRequest.length).build() + observer.onNext(res) + observer.onCompleted() + } + + override def method2(observer: StreamObserver[Res2]) = + new StreamObserver[Req2] { + private[this] val counter = new AtomicInteger() + override def onError(e: Throwable): Unit = + observer.onError(e) + + override def onCompleted(): Unit = { + val res = Res2.newBuilder().setCount(counter.getAndSet(0)).build() + observer.onNext(res) + } + + override def onNext(v: Req2): Unit = { + counter.incrementAndGet() + } + } + + private[this] var method3Counter = 0 + + override def method3(request: Req3, observer: StreamObserver[Res3]): Unit = synchronized{ + method3Counter += request.getNum + if(method3Counter > Service1ScalaImpl.method3Limit){ + observer.onNext(Res3.getDefaultInstance) + observer.onCompleted() + } + } + + override def method4(observer: StreamObserver[Res4]): StreamObserver[Req4] = + new StreamObserver[Req4] { + override def onError(e: Throwable): Unit = {} + override def onCompleted(): Unit = {} + override def onNext(request: Req4): Unit = { + observer.onNext(Res4.newBuilder.setB(request.getA * 2).build()) + observer.onCompleted() + } + } + +} diff --git a/e2e/src/main/scala/com/trueaccord/pb/Service1ScalaImpl.scala b/e2e/src/main/scala/com/trueaccord/pb/Service1ScalaImpl.scala new file mode 100644 index 000000000..2d23e32da --- /dev/null +++ b/e2e/src/main/scala/com/trueaccord/pb/Service1ScalaImpl.scala @@ -0,0 +1,56 @@ +package com.trueaccord.pb + +import java.util.concurrent.atomic.AtomicInteger + +import com.trueaccord.proto.e2e.service.Service1Grpc.Service1 +import com.trueaccord.proto.e2e.service._ +import io.grpc.stub.StreamObserver + +import scala.concurrent.Future + +object Service1ScalaImpl { + + val method3Limit = 50 + +} + +class Service1ScalaImpl extends Service1[Future]{ + + override def method1(request: Req1): Future[Res1] = + Future.successful(Res1(length = request.request.length)) + + override def method2(observer: StreamObserver[Res2]) = + new StreamObserver[Req2] { + private[this] val counter = new AtomicInteger() + override def onError(e: Throwable): Unit = + observer.onError(e) + + override def onCompleted(): Unit = { + observer.onNext(Res2(counter.getAndSet(0))) + } + + override def onNext(v: Req2): Unit = { + counter.incrementAndGet() + } + } + + private[this] var method3Counter = 0 + + override def method3(request: Req3, observer: StreamObserver[Res3]): Unit = synchronized{ + method3Counter += request.num + if(method3Counter > Service1ScalaImpl.method3Limit){ + observer.onNext(Res3()) + observer.onCompleted() + } + } + + override def method4(observer: StreamObserver[Res4]): StreamObserver[Req4] = + new StreamObserver[Req4] { + override def onError(e: Throwable): Unit = {} + override def onCompleted(): Unit = {} + override def onNext(request: Req4): Unit = { + observer.onNext(Res4(request.a * 2)) + observer.onCompleted() + } + } +} diff --git a/e2e/src/test/scala/GrpcServiceJavaServerSpec.scala b/e2e/src/test/scala/GrpcServiceJavaServerSpec.scala new file mode 100644 index 000000000..1b9a5c364 --- /dev/null +++ b/e2e/src/test/scala/GrpcServiceJavaServerSpec.scala @@ -0,0 +1,84 @@ +import java.util.concurrent.TimeoutException + +import com.trueaccord.pb.Service1ScalaImpl +import com.trueaccord.proto.e2e.service.{Service1Grpc => Service1GrpcScala, _} + +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.util.Random + +class GrpcServiceJavaServerSpec extends GrpcServiceSpecBase { + + describe("java server") { + + it("method1 blockingClient") { + withJavaServer { channel => + val client = Service1GrpcScala.blockingClient(channel) + val string = randomString() + assert(client.method1(Req1(string)).length === string.length) + } + } + + it("method1 futureClient") { + withJavaServer { channel => + val client = Service1GrpcScala.futureClient(channel) + val string = randomString() + assert(Await.result(client.method1(Req1(string)), 2.seconds).length === string.length) + } + } + + it("method2") { + withJavaServer { channel => + val client = Service1GrpcScala.futureClient(channel) + val (responseObserver, future) = getObserverAndFuture[Res2] + val requestObserver = client.method2(responseObserver) + val n = Random.nextInt(10) + for (_ <- 1 to n) { + requestObserver.onNext(Req2()) + } + + intercept[TimeoutException]{ + Await.result(future, 2.seconds) + } + + requestObserver.onCompleted() + assert(Await.result(future, 2.seconds).count === n) + } + } + + it("method3") { + withJavaServer { channel => + val client = Service1GrpcScala.futureClient(channel) + val (observer, future) = getObserverAndFuture[Res3] + val requests = Stream.continually(Req3(num = Random.nextInt(10))) + val count = requests.scanLeft(0)(_ + _.num).takeWhile(_ < Service1ScalaImpl.method3Limit).size - 1 + + requests.take(count).foreach { req => + client.method3(req, observer) + } + + intercept[TimeoutException]{ + Await.result(future, 2.seconds) + } + + client.method3(Req3(1000), observer) + Await.result(future, 2.seconds) + } + } + + it("method4") { + withJavaServer { channel => + val client = Service1GrpcScala.futureClient(channel) + val (responseObserver, future) = getObserverAndFuture[Res4] + val requestObserver = client.method4(responseObserver) + intercept[TimeoutException]{ + Await.result(future, 2.seconds) + } + val request = Req4(a = Random.nextInt()) + requestObserver.onNext(request) + assert(Await.result(future, 2.seconds).b === (request.a * 2)) + } + } + } + +} diff --git a/e2e/src/test/scala/GrpcServiceScalaServerSpec.scala b/e2e/src/test/scala/GrpcServiceScalaServerSpec.scala new file mode 100644 index 000000000..35688a510 --- /dev/null +++ b/e2e/src/test/scala/GrpcServiceScalaServerSpec.scala @@ -0,0 +1,109 @@ +import java.util.concurrent.TimeoutException + +import com.trueaccord.pb.Service1ScalaImpl + +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.util.Random + +class GrpcServiceScalaServerSpec extends GrpcServiceSpecBase { + + describe("scala server") { + + describe("java client") { + import com.trueaccord.proto.e2e.{Service1Grpc => Service1GrpcJava, _} + + it("method1 BlockingStub") { + withScalaServer { channel => + val client = Service1GrpcJava.newBlockingStub(channel) + val string = randomString() + val request = Service.Req1.newBuilder.setRequest(string).build() + assert(client.method1(request).getLength === string.length) + } + } + + it("method1 FeatureStub") { + withScalaServer { channel => + val client = Service1GrpcJava.newFutureStub(channel) + val string = randomString() + val request = Service.Req1.newBuilder.setRequest(string).build() + assert(client.method1(request).get().getLength === string.length) + } + } + } + + describe("scala client") { + import com.trueaccord.proto.e2e.service.{Service1Grpc => Service1GrpcScala, _} + + it("method1 blockingClient") { + withScalaServer { channel => + val client = Service1GrpcScala.blockingClient(channel) + val string = randomString() + assert(client.method1(Req1(string)).length === string.length) + } + } + + it("method1 futureClient") { + withScalaServer { channel => + val client = Service1GrpcScala.futureClient(channel) + val string = randomString() + assert(Await.result(client.method1(Req1(string)), 2.seconds).length === string.length) + } + } + + it("method2") { + withScalaServer { channel => + val client = Service1GrpcScala.futureClient(channel) + val (responseObserver, future) = getObserverAndFuture[Res2] + val requestObserver = client.method2(responseObserver) + val n = Random.nextInt(10) + for (_ <- 1 to n) { + requestObserver.onNext(Req2()) + } + + intercept[TimeoutException]{ + Await.result(future, 2.seconds) + } + + requestObserver.onCompleted() + assert(Await.result(future, 2.seconds).count === n) + } + } + + it("method3") { + withScalaServer { channel => + val client = Service1GrpcScala.futureClient(channel) + val (observer, future) = getObserverAndFuture[Res3] + val requests = Stream.continually(Req3(num = Random.nextInt(10))) + val count = requests.scanLeft(0)(_ + _.num).takeWhile(_ < Service1ScalaImpl.method3Limit).size - 1 + + requests.take(count).foreach { req => + client.method3(req, observer) + } + + intercept[TimeoutException]{ + Await.result(future, 2.seconds) + } + + client.method3(Req3(1000), observer) + Await.result(future, 2.seconds) + } + } + + it("method4") { + withScalaServer { channel => + val client = Service1GrpcScala.futureClient(channel) + val (responseObserver, future) = getObserverAndFuture[Res4] + val requestObserver = client.method4(responseObserver) + intercept[TimeoutException]{ + Await.result(future, 2.seconds) + } + val request = Req4(a = Random.nextInt()) + requestObserver.onNext(request) + assert(Await.result(future, 2.seconds).b === (request.a * 2)) + } + } + } + } + +} diff --git a/e2e/src/test/scala/GrpcServiceSpecBase.scala b/e2e/src/test/scala/GrpcServiceSpecBase.scala new file mode 100644 index 000000000..fc4ca67a2 --- /dev/null +++ b/e2e/src/test/scala/GrpcServiceSpecBase.scala @@ -0,0 +1,57 @@ +import java.util.concurrent.TimeUnit + +import com.trueaccord.pb.{Service1JavaImpl, Service1ScalaImpl} +import com.trueaccord.proto.e2e.service.{Service1Grpc => Service1GrpcScala} +import com.trueaccord.proto.e2e.{Service1Grpc => Service1GrpcJava} +import io.grpc.netty.{NegotiationType, NettyChannelBuilder, NettyServerBuilder} +import io.grpc.stub.StreamObserver +import io.grpc.{ManagedChannel, ServerServiceDefinition} +import org.scalatest.FunSpec + +import scala.concurrent.{ExecutionContext, Future, Promise} +import scala.util.Random + +abstract class GrpcServiceSpecBase extends FunSpec { + + protected[this] final def withScalaServer[A](f: ManagedChannel => A): A = { + withServer(Service1GrpcScala.bindService(new Service1ScalaImpl, singleThreadExecutionContext))(f) + } + + protected[this] final def withJavaServer[A](f: ManagedChannel => A): A = { + withServer(Service1GrpcJava.bindService(new Service1JavaImpl))(f) + } + + private[this] def withServer[A](services: ServerServiceDefinition*)(f: ManagedChannel => A): A = { + val port = UniquePortGenerator.get() + val server = services.foldLeft(NettyServerBuilder.forPort(port))(_.addService(_)).build() + try { + server.start() + val channel = NettyChannelBuilder.forAddress("localhost", port).negotiationType(NegotiationType.PLAINTEXT).build() + f(channel) + } finally { + server.shutdown() + server.awaitTermination(3000, TimeUnit.MILLISECONDS) + } + } + + private[this] val singleThreadExecutionContext = new ExecutionContext { + override def reportFailure(cause: Throwable): Unit = cause.printStackTrace() + + override def execute(runnable: Runnable): Unit = runnable.run() + } + + protected[this] final def getObserverAndFuture[A]: (StreamObserver[A], Future[A]) = { + val promise = Promise[A]() + val observer = new StreamObserver[A] { + override def onError(t: Throwable): Unit = {} + + override def onCompleted(): Unit = {} + + override def onNext(value: A): Unit = promise.success(value) + } + (observer, promise.future) + } + + protected[this] final def randomString(): String = Random.alphanumeric.take(Random.nextInt(10)).mkString + +} diff --git a/e2e/src/test/scala/UniquePortGenerator.scala b/e2e/src/test/scala/UniquePortGenerator.scala new file mode 100644 index 000000000..cd1862bfe --- /dev/null +++ b/e2e/src/test/scala/UniquePortGenerator.scala @@ -0,0 +1,34 @@ +import java.net.ServerSocket + +import scala.util.Random + +object UniquePortGenerator { + private[this] val usingPorts = collection.mutable.HashSet.empty[Int] + + def getOpt(): Option[Int] = synchronized { + @annotation.tailrec + def loop(loopCount: Int): Option[Int] = { + val socket = new ServerSocket(0) + val port = try { + socket.getLocalPort + } finally { + socket.close() + } + + if (usingPorts(port)) { + if (loopCount == 0) { + None + } else { + Thread.sleep(Random.nextInt(50)) + loop(loopCount - 1) + } + } else { + usingPorts += port + Option(port) + } + } + loop(30) + } + + def get(): Int = getOpt().getOrElse(sys.error("could not get port")) +} diff --git a/proptest/build.sbt b/proptest/build.sbt index 088b087cd..09742c414 100644 --- a/proptest/build.sbt +++ b/proptest/build.sbt @@ -5,6 +5,7 @@ resolvers ++= Seq( libraryDependencies ++= Seq( "com.github.os72" % "protoc-jar" % "3.0.0-b1", "com.google.protobuf" % "protobuf-java" % "3.0.0-beta-1", + "io.grpc" % "grpc-all" % "0.9.0" % "test", "com.trueaccord.lenses" %% "lenses" % "0.4.1", "org.scalacheck" %% "scalacheck" % "1.12.4" % "test", "org.scalatest" %% "scalatest" % (if (scalaVersion.value.startsWith("2.12")) "2.2.5-M2" else "2.2.5") % "test" diff --git a/proptest/src/test/scala/GraphGen.scala b/proptest/src/test/scala/GraphGen.scala index 97c68f6c6..78e64e03d 100644 --- a/proptest/src/test/scala/GraphGen.scala +++ b/proptest/src/test/scala/GraphGen.scala @@ -1,6 +1,7 @@ import GenUtils._ import GenTypes.{FieldOptions, ProtoType, FieldModifier} import com.trueaccord.scalapb.Scalapb.ScalaPbOptions +import com.trueaccord.scalapb.compiler.StreamType import org.scalacheck.{Arbitrary, Gen} object GraphGen { @@ -138,6 +139,22 @@ object GraphGen { (Some(b.build), state) } + val genStreamType: Gen[StreamType] = Gen.oneOf( + StreamType.Unary, StreamType.ClientStreaming, StreamType.ServerStreaming, StreamType.Bidirectional + ) + + def genService(messages: Seq[MessageNode])(state: State): Gen[(ServiceNode, State)] = for{ + (methods, state) <- listWithStatefulGen(state , maxSize = 3)(genMethod(messages)) + (name, state) <- state.generateName + } yield ServiceNode(name, methods) -> state + + def genMethod(messages: Seq[MessageNode])(state: State): Gen[(MethodNode, State)] = for{ + req <- Gen.oneOf(messages) + res <- Gen.oneOf(messages) + stream <- genStreamType + (name, state) <- state.generateName + } yield MethodNode(name, req, res, stream) -> state + def genFileNode(state: State): Gen[(FileNode, State)] = sized { s => for { @@ -152,7 +169,8 @@ object GraphGen { (messages, state) <- listWithStatefulGen(state, maxSize = 4)(genMessageNode(0, None, protoSyntax)) (enums, state) <- listWithStatefulGen(state, maxSize = 3)(genEnumNode(None, protoSyntax)) javaMulti <- implicitly[Arbitrary[Boolean]].arbitrary - } yield (FileNode(baseName, protoSyntax, protoPackageOption, javaPackageOption, javaMulti, scalaOptions, messages, enums, fileId), + (services, state) <- listWithStatefulGen(state, maxSize = 3)(genService(messages)) + } yield (FileNode(baseName, protoSyntax, protoPackageOption, javaPackageOption, javaMulti, scalaOptions, messages, services, enums, fileId), if (protoPackage.isEmpty) state else state.closeNamespace) } diff --git a/proptest/src/test/scala/Nodes.scala b/proptest/src/test/scala/Nodes.scala index a970e9df9..e918a52aa 100644 --- a/proptest/src/test/scala/Nodes.scala +++ b/proptest/src/test/scala/Nodes.scala @@ -1,7 +1,7 @@ import com.trueaccord.scalapb.Scalapb.ScalaPbOptions import com.trueaccord.scalapb.compiler -import com.trueaccord.scalapb.compiler.{FunctionalPrinter, FPrintable} +import com.trueaccord.scalapb.compiler.{StreamType, FunctionalPrinter, FPrintable} import scala.collection.mutable import scala.util.Try @@ -107,6 +107,29 @@ object Nodes { lazy val filesById: Map[Int, FileNode] = files.map(f => (f.fileId, f)).toMap } + final case class MethodNode(name: String, request: MessageNode, response: MessageNode, streamType: StreamType) { + def print(printer: FunctionalPrinter): FunctionalPrinter = { + val method = streamType match { + case StreamType.Unary => + s"rpc $name (${request.name}) returns (${response.name}) {};" + case StreamType.ClientStreaming => + s"rpc $name (stream ${request.name}) returns (${response.name}) {};" + case StreamType.ServerStreaming => + s"rpc $name (${request.name}) returns (stream ${response.name}) {};" + case StreamType.Bidirectional => + s"rpc $name (stream ${request.name}) returns (stream ${response.name}) {};" + } + printer.add(method) + } + } + + final case class ServiceNode(name: String, methods: Seq[MethodNode]) { + def print(printer: FunctionalPrinter): FunctionalPrinter = + printer.add(s"service $name {").indent + .print(methods)(_ print _).outdent + .add("}") + } + case class FileNode(baseFileName: String, protoSyntax: ProtoSyntax, protoPackage: Option[String], @@ -114,6 +137,7 @@ object Nodes { javaMultipleFiles: Boolean, scalaOptions: Option[ScalaPbOptions], messages: Seq[MessageNode], + services: Seq[ServiceNode], enums: Seq[EnumNode], fileId: Int) extends Node { def allMessages = messages.foldLeft(Stream.empty[MessageNode])(_ ++ _.allMessages) @@ -154,6 +178,7 @@ object Nodes { }).toSeq: _*) .printAll(enums) .print(messages)(_.print(rootNode, this, _)) + .print(services)(_ print _) /** * @return diff --git a/proptest/src/test/scala/SchemaGenerators.scala b/proptest/src/test/scala/SchemaGenerators.scala index 30b1a092e..165d9f4ff 100644 --- a/proptest/src/test/scala/SchemaGenerators.scala +++ b/proptest/src/test/scala/SchemaGenerators.scala @@ -102,7 +102,7 @@ object SchemaGenerators { val args = Seq("--proto_path", (tmpDir.toString + ":protobuf:third_party"), "--java_out", tmpDir.toString, - "--scala_out", "java_conversions:" + tmpDir.toString) ++ files + "--scala_out", "grpc,java_conversions:" + tmpDir.toString) ++ files runProtoc(args: _*) } @@ -141,6 +141,9 @@ object SchemaGenerators { jarForClass[com.trueaccord.scalapb.GeneratedMessage].getPath, jarForClass[com.trueaccord.scalapb.Scalapb].getPath, jarForClass[com.google.protobuf.Message].getPath, + jarForClass[io.grpc.Channel].getPath, + jarForClass[com.google.common.util.concurrent.ListenableFuture[_]], + jarForClass[javax.annotation.Nullable], jarForClass[com.trueaccord.lenses.Lens[_, _]].getPath, rootDir )