Skip to content

Commit

Permalink
Use scalapb-runtime-grpc to avoid conversions to/from Java.
Browse files Browse the repository at this point in the history
  • Loading branch information
thesamet committed Dec 2, 2015
1 parent 61ceba4 commit 59e9fd5
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 50 deletions.
13 changes: 12 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ lazy val root =
publish := {},
publishLocal := {}
).aggregate(
runtimeJS, runtimeJVM, compilerPlugin, proptest, scalapbc)
runtimeJS, runtimeJVM, grpcRuntime, compilerPlugin, proptest, scalapbc)

lazy val runtime = crossProject.crossType(CrossType.Full).in(file("scalapb-runtime"))
.settings(
Expand Down Expand Up @@ -78,6 +78,17 @@ lazy val runtime = crossProject.crossType(CrossType.Full).in(file("scalapb-runti
lazy val runtimeJVM = runtime.jvm
lazy val runtimeJS = runtime.js

val grpcVersion = "0.9.0"

lazy val grpcRuntime = project.in(file("scalapb-runtime-grpc"))
.dependsOn(runtimeJVM)
.settings(
name := "scalapb-runtime-grpc",
libraryDependencies ++= Seq(
"io.grpc" % "grpc-all" % grpcVersion
)
)

lazy val compilerPlugin = project.in(file("compiler-plugin"))
.dependsOn(runtimeJVM)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,22 @@ final class GrpcServicePrinter(service: ServiceDescriptor, override val params:
p.add(
"override " + methodSignature(m, identity) + " = {"
).add(
s""" ${m.scalaOut}.fromJavaProto(_root_.io.grpc.stub.ClientCalls.blockingUnaryCall(channel.newCall(${methodDescriptorName(m)}, options), ${m.scalaIn}.toJavaProto(request)))""",
s""" _root_.io.grpc.stub.ClientCalls.blockingUnaryCall(channel.newCall(${methodDescriptorName(m)}, options), request)""",
"}"
)
} else {
p.add(
"override " + methodSignature(m, "scala.concurrent.Future[" + _ + "]") + " = {"
).add(
s""" $guavaFuture2ScalaFuture($futureUnaryCall(channel.newCall(${methodDescriptorName(m)}, options), ${m.scalaIn}.toJavaProto(request)))(${m.scalaOut}.fromJavaProto(_))""",
s""" $guavaFuture2ScalaFuture($futureUnaryCall(channel.newCall(${methodDescriptorName(m)}, options), request))""",
"}"
)
}
case StreamType.ServerStreaming =>
p.add(
"override " + methodSignature(m) + " = {"
).addI(
s"${clientCalls.serverStreaming}(channel.newCall(${methodDescriptorName(m)}, options), ${m.scalaIn}.toJavaProto(request), $contramapObserver(observer)(${m.scalaOut}.fromJavaProto))"
s"${clientCalls.serverStreaming}(channel.newCall(${methodDescriptorName(m)}, options), request, observer)"
).add("}")
case streamType =>
val call = if (streamType == StreamType.ClientStreaming) {
Expand All @@ -96,10 +96,8 @@ final class GrpcServicePrinter(service: ServiceDescriptor, override val params:
p.add(
"override " + methodSignature(m) + " = {"
).indent.add(
s"$contramapObserver("
).indent.add(
s"$call(channel.newCall(${methodDescriptorName(m)}, options), $contramapObserver(observer)(${m.scalaOut}.fromJavaProto)))(${m.scalaIn}.toJavaProto"
).outdent.add(")").outdent.add("}")
s"$call(channel.newCall(${methodDescriptorName(m)}, options), observer)"
).outdent.add("}")
}
}

Expand All @@ -120,26 +118,7 @@ final class GrpcServicePrinter(service: ServiceDescriptor, override val params:
)
}

private[this] val contramapObserver = "contramapObserver"
private[this] val contramapObserverImpl = s"""private[this] def $contramapObserver[A, B](observer: ${observer("A")})(f: B => A): ${observer("B")} =
new ${observer("B")} {
override def onNext(value: B): Unit = observer.onNext(f(value))
override def onError(t: Throwable): Unit = observer.onError(t)
override def onCompleted(): Unit = observer.onCompleted()
}"""

private[this] val guavaFuture2ScalaFuture = "guavaFuture2ScalaFuture"

private[this] val guavaFuture2ScalaFutureImpl = {
s"""private[this] def $guavaFuture2ScalaFuture[A, B](guavaFuture: _root_.com.google.common.util.concurrent.ListenableFuture[A])(converter: A => B): scala.concurrent.Future[B] = {
val p = scala.concurrent.Promise[B]()
_root_.com.google.common.util.concurrent.Futures.addCallback(guavaFuture, new _root_.com.google.common.util.concurrent.FutureCallback[A] {
override def onFailure(t: Throwable) = p.failure(t)
override def onSuccess(a: A) = p.success(converter(a))
})
p.future
}"""
}
private[this] val guavaFuture2ScalaFuture = "com.trueaccord.scalapb.grpc.Grpc.guavaFuture2ScalaFuture"

private[this] val asyncClientName = service.getName + "AsyncClientImpl"

Expand All @@ -164,11 +143,8 @@ final class GrpcServicePrinter(service: ServiceDescriptor, override val params:
"METHOD_" + method.getName.toUpperCase(Locale.ENGLISH)

private[this] def methodDescriptor(method: MethodDescriptor) = {
val inJava = method.getInputType.javaTypeName
val outJava = method.getOutputType.javaTypeName

def marshaller(typeName: String) =
s"_root_.io.grpc.protobuf.ProtoUtils.marshaller($typeName.getDefaultInstance)"
s"new _root_.com.trueaccord.scalapb.grpc.Marshaller($typeName)"

val methodType = method.streamType match {
case StreamType.Unary => "UNARY"
Expand All @@ -179,12 +155,12 @@ final class GrpcServicePrinter(service: ServiceDescriptor, override val params:

val grpcMethodDescriptor = "_root_.io.grpc.MethodDescriptor"

s""" private[this] val ${methodDescriptorName(method)}: $grpcMethodDescriptor[$inJava, $outJava] =
s""" private[this] val ${methodDescriptorName(method)}: $grpcMethodDescriptor[${method.scalaIn}, ${method.scalaOut}] =
$grpcMethodDescriptor.create(
$grpcMethodDescriptor.MethodType.$methodType,
$grpcMethodDescriptor.generateFullMethodName("${service.getFullName}", "${method.getName}"),
${marshaller(inJava)},
${marshaller(outJava)}
${marshaller(method.scalaIn)},
${marshaller(method.scalaOut)}
)"""
}

Expand All @@ -202,20 +178,18 @@ s""" private[this] val ${methodDescriptorName(method)}: $grpcMethodDescriptor[$
}

private[this] def createMethod(method: MethodDescriptor): String = {
val javaIn = method.getInputType.javaTypeName
val javaOut = method.getOutputType.javaTypeName
val executionContext = "executionContext"
val name = callMethodName(method)
val serviceImpl = "serviceImpl"
method.streamType match {
case StreamType.Unary =>
val serverMethod = s"_root_.io.grpc.stub.ServerCalls.UnaryMethod[$javaIn, $javaOut]"
val serverMethod = s"_root_.io.grpc.stub.ServerCalls.UnaryMethod[${method.scalaIn}, ${method.scalaOut}]"
s""" def ${name}($serviceImpl: $serviceFuture, $executionContext: scala.concurrent.ExecutionContext): $serverMethod = {
new $serverMethod {
override def invoke(request: $javaIn, observer: $streamObserver[$javaOut]): Unit =
$serviceImpl.${methodName(method)}(${method.scalaIn}.fromJavaProto(request)).onComplete {
override def invoke(request: ${method.scalaIn}, observer: $streamObserver[${method.scalaOut}]): Unit =
$serviceImpl.${methodName(method)}(request).onComplete {
case scala.util.Success(value) =>
observer.onNext(${method.scalaOut}.toJavaProto(value))
observer.onNext(value)
observer.onCompleted()
case scala.util.Failure(error) =>
observer.onError(error)
Expand All @@ -224,25 +198,25 @@ s""" def ${name}($serviceImpl: $serviceFuture, $executionContext: scala.concurr
}
}"""
case StreamType.ServerStreaming =>
val serverMethod = s"_root_.io.grpc.stub.ServerCalls.ServerStreamingMethod[$javaIn, $javaOut]"
val serverMethod = s"_root_.io.grpc.stub.ServerCalls.ServerStreamingMethod[${method.scalaIn}, ${method.scalaOut}]"

s""" def ${name}($serviceImpl: $serviceFuture): $serverMethod = {
new $serverMethod {
override def invoke(request: $javaIn, observer: $streamObserver[$javaOut]): Unit =
$serviceImpl.${methodName0(method)}(${method.scalaIn}.fromJavaProto(request), $contramapObserver(observer)(${method.scalaOut}.toJavaProto))
override def invoke(request: ${method.scalaIn}, observer: $streamObserver[${method.scalaOut}]): Unit =
$serviceImpl.${methodName0(method)}(request, observer)
}
}"""
case _ =>
val serverMethod = if(method.streamType == StreamType.ClientStreaming) {
s"_root_.io.grpc.stub.ServerCalls.ClientStreamingMethod[$javaIn, $javaOut]"
s"_root_.io.grpc.stub.ServerCalls.ClientStreamingMethod[${method.scalaIn}, ${method.scalaOut}]"
} else {
s"_root_.io.grpc.stub.ServerCalls.BidiStreamingMethod[$javaIn, $javaOut]"
s"_root_.io.grpc.stub.ServerCalls.BidiStreamingMethod[${method.scalaIn}, ${method.scalaOut}]"
}

s""" def ${name}($serviceImpl: $serviceFuture): $serverMethod = {
new $serverMethod {
override def invoke(observer: $streamObserver[$javaOut]): $streamObserver[$javaIn] =
$contramapObserver($serviceImpl.${methodName0(method)}($contramapObserver(observer)(${method.scalaOut}.toJavaProto)))(${method.scalaIn}.fromJavaProto)
override def invoke(observer: $streamObserver[${method.scalaOut}]): $streamObserver[${method.scalaIn}] =
$serviceImpl.${methodName0(method)}(observer)
}
}"""
}
Expand Down Expand Up @@ -295,8 +269,6 @@ s"""def bindService(service: $serviceFuture, $executionContext: scala.concurrent
asyncClientImpl
).newline.addI(
bindService,
guavaFuture2ScalaFutureImpl,
contramapObserverImpl,
s"def blockingClient(channel: $channel): $serviceBlocking = new $blockingClientName(channel)",
s"def futureClient(channel: $channel): $serviceFuture = new $asyncClientName(channel)"
).add(
Expand Down
2 changes: 2 additions & 0 deletions e2e/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ libraryDependencies ++= Seq(
"org.scalatest" %% "scalatest" % "2.2.1" % "test",
"io.grpc" % "grpc-all" % grpcVersion,
"org.scalacheck" %% "scalacheck" % "1.12.4" % "test",
"com.trueaccord.scalapb" %% "scalapb-runtime-grpc" % com.trueaccord.scalapb.Version.scalapbVersion,
"com.trueaccord.scalapb" %% "scalapb-runtime" % com.trueaccord.scalapb.Version.scalapbVersion % PB.protobufConfig
)

Expand All @@ -39,6 +40,7 @@ def grpcExeUrl() = {
"linux-x86_64"
}
val artifactId = "protoc-gen-grpc-java"
// s"file:///Users/nadavsr/Downloads/${artifactId}-${grpcVersion}-${os}.exe"
s"http://repo1.maven.org/maven2/io/grpc/${artifactId}/${grpcVersion}/${artifactId}-${grpcVersion}-${os}.exe"
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.trueaccord.scalapb.grpc

import com.google.common.util.concurrent.{FutureCallback, Futures, ListenableFuture}

import scala.concurrent.{Promise, Future}

object Grpc {
def guavaFuture2ScalaFuture[A](guavaFuture: ListenableFuture[A]): Future[A] = {
val p = Promise[A]()
Futures.addCallback(guavaFuture, new FutureCallback[A] {
override def onFailure(t: Throwable) = p.failure(t)
override def onSuccess(a: A) = p.success(a)
})
p.future
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.trueaccord.scalapb.grpc

import java.io.InputStream

import com.trueaccord.scalapb.grpc.ProtoInputStream.NotStarted
import com.trueaccord.scalapb.{GeneratedMessage, GeneratedMessageCompanion, Message}

class Marshaller[T <: GeneratedMessage with Message[T]](companion: GeneratedMessageCompanion[T]) extends io.grpc.MethodDescriptor.Marshaller[T] {
override def stream(t: T): InputStream = ProtoInputStream.newInstance(t)

override def parse(inputStream: InputStream): T = inputStream match {
/* Optimization for in-memory transport. */
case ProtoInputStream(NotStarted(m: T @unchecked)) if m.companion == companion =>
m
case is => companion.parseFrom(is)
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package com.trueaccord.scalapb.grpc

import java.io.{ByteArrayInputStream, IOException, InputStream, OutputStream}

import com.google.common.io.ByteStreams
import com.google.protobuf.CodedOutputStream
import com.trueaccord.scalapb.GeneratedMessage
import com.trueaccord.scalapb.grpc.ProtoInputStream._

private class ProtoInputStream(var state: State) extends InputStream {
@throws(classOf[IOException])
def drainTo(target: OutputStream): Int = {
val bytesWritten = state match {
case NotStarted(message) =>
message.writeTo(target)
message.serializedSize
case Partial(partial) =>
ByteStreams.copy(partial, target).toInt
case Done => throw new IllegalStateException()
}
state = Done
bytesWritten
}

@throws(classOf[IOException])
def read: Int = {
state match {
case NotStarted(message) =>
state = Partial(new ByteArrayInputStream(message.toByteArray))
case _ =>
}
state match {
case Partial(partial) =>
partial.read()
case _ =>
-1
}
}

@throws(classOf[IOException])
override def read(b: Array[Byte], off: Int, len: Int): Int =
state match {
case NotStarted(message) =>
message.serializedSize match {
case 0 =>
state = Done
-1
case size if len >= size =>
val stream = CodedOutputStream.newInstance(b, off, size)
message.writeTo(stream)
stream.flush()
stream.checkNoSpaceLeft()
state = Done
size
case _ =>
val partial = new ByteArrayInputStream(message.toByteArray)
state = Partial(partial)
partial.read(b, off, len)
}
case Partial(partial) =>
partial.read(b, off, len)
case _ =>
-1
}

@throws(classOf[IOException])
override def available: Int = state match {
case NotStarted(message) => message.serializedSize
case Partial(partial) => partial.available()
case _ => 0
}
}

private object ProtoInputStream {
sealed trait State
case class NotStarted[A <: GeneratedMessage](message: A) extends State
case class Partial(bytes: ByteArrayInputStream) extends State
case object Done extends State

def unapply(p: ProtoInputStream): Option[State] = Some(p.state)

def newInstance(message: GeneratedMessage): InputStream = new ProtoInputStream(NotStarted(message))
}

0 comments on commit 59e9fd5

Please sign in to comment.