Skip to content

Commit

Permalink
add grpc support
Browse files Browse the repository at this point in the history
  • Loading branch information
xuwei-k authored and thesamet committed Dec 2, 2015
1 parent de53ada commit 61ceba4
Show file tree
Hide file tree
Showing 17 changed files with 907 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ trait DescriptorPimps {
def asSymbol: String = if (SCALA_RESERVED_WORDS.contains(s)) s"`$s`" else s
}

private def snakeCaseToCamelCase(name: String, upperInitial: Boolean = false): String = {
protected final def snakeCaseToCamelCase(name: String, upperInitial: Boolean = false): String = {
val b = new StringBuilder()
@annotation.tailrec
def inner(name: String, index: Int, capNext: Boolean): Unit = if (name.nonEmpty) {
Expand All @@ -45,6 +45,22 @@ trait DescriptorPimps {
b.toString
}

implicit final class MethodDescriptorPimp(self: MethodDescriptor) {
def scalaOut: String = self.getOutputType.scalaTypeName

def scalaIn: String = self.getInputType.scalaTypeName

def streamType: StreamType = {
val p = self.toProto
(p.getClientStreaming, p.getServerStreaming) match {
case (false, false) => StreamType.Unary
case (true, false) => StreamType.ClientStreaming
case (false, true) => StreamType.ServerStreaming
case (true, true) => StreamType.Bidirectional
}
}
}

implicit class FieldDescriptorPimp(val fd: FieldDescriptor) {
def containingOneOf: Option[OneofDescriptor] = Option(fd.getContainingOneof)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,36 @@
package com.trueaccord.scalapb.compiler

import com.trueaccord.scalapb.compiler.FunctionalPrinter.PrinterEndo

trait FPrintable {
def print(printer: FunctionalPrinter): FunctionalPrinter
}

object PrinterEndo {
def apply(endo: PrinterEndo): PrinterEndo = endo
}

object FunctionalPrinter {
type PrinterEndo = FunctionalPrinter => FunctionalPrinter
val newline: PrinterEndo = _.newline
}

case class FunctionalPrinter(content: List[String] = Nil, indentLevel: Int = 0) {
val INDENT_SIZE = 2

def seq(s: Seq[String]): FunctionalPrinter = add(s: _*)

def add(s: String*): FunctionalPrinter = {
copy(content = s.map(l => " " * (indentLevel * INDENT_SIZE) + l).reverseIterator.toList ::: content)
}

/** add with indent */
def addI(s: String*): FunctionalPrinter = {
this.indent.seq(s).outdent
}

def newline: FunctionalPrinter = add("")

def addM(s: String): FunctionalPrinter =
add(s.stripMargin.split("\n", -1): _*)

Expand All @@ -31,7 +52,11 @@ case class FunctionalPrinter(content: List[String] = Nil, indentLevel: Int = 0)
def indent = copy(indentLevel = indentLevel + 1)
def outdent = copy(indentLevel = indentLevel - 1)

def call(f: FunctionalPrinter => FunctionalPrinter) = f(this)
def call(f: (FunctionalPrinter => FunctionalPrinter)*): FunctionalPrinter =
f.foldLeft(this)((p, f) => f(p))

def withIndent(f: (FunctionalPrinter => FunctionalPrinter)*): FunctionalPrinter =
f.foldLeft(this.indent)((p, f) => f(p)).outdent

def when(cond: => Boolean)(func: FunctionalPrinter => FunctionalPrinter) =
if (cond) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
package com.trueaccord.scalapb.compiler

import java.util.Locale

import com.google.protobuf.Descriptors.{MethodDescriptor, ServiceDescriptor}
import com.trueaccord.scalapb.compiler.FunctionalPrinter.PrinterEndo
import scala.collection.JavaConverters._

final class GrpcServicePrinter(service: ServiceDescriptor, override val params: GeneratorParams) extends DescriptorPimps {

private[this] def methodName0(method: MethodDescriptor): String = snakeCaseToCamelCase(method.getName)
private[this] def methodName(method: MethodDescriptor): String = methodName0(method).asSymbol
private[this] def observer(typeParam: String): String = s"$streamObserver[$typeParam]"

private[this] def methodSignature(method: MethodDescriptor, t: String => String = identity[String]): String = {
method.streamType match {
case StreamType.Unary =>
s"def ${methodName(method)}(request: ${method.scalaIn}): ${t(method.scalaOut)}"
case StreamType.ServerStreaming =>
s"def ${methodName(method)}(request: ${method.scalaIn}, observer: ${observer(method.scalaOut)}): Unit"
case StreamType.ClientStreaming | StreamType.Bidirectional =>
s"def ${methodName(method)}(observer: ${observer(method.scalaOut)}): ${observer(method.scalaIn)}"
}
}

private[this] def base: PrinterEndo = {
val F = "F[" + (_: String) + "]"

val methods: PrinterEndo = { p =>
p.seq(service.getMethods.asScala.map(methodSignature(_, F)))
}

{ p =>
p.add(s"trait ${serviceName(F("_"))} {").withIndent(methods).add("}")
}
}

private[this] val channel = "_root_.io.grpc.Channel"
private[this] val callOptions = "_root_.io.grpc.CallOptions"

private[this] def serviceName0 = service.getName.asSymbol
private[this] def serviceName(p: String) = serviceName0 + "[" + p + "]"
private[this] val serviceBlocking = serviceName("({type l[a] = a})#l")
private[this] val serviceFuture = serviceName("scala.concurrent.Future")

private[this] val futureUnaryCall = "_root_.io.grpc.stub.ClientCalls.futureUnaryCall"
private[this] val abstractStub = "_root_.io.grpc.stub.AbstractStub"
private[this] val streamObserver = "_root_.io.grpc.stub.StreamObserver"

private[this] object serverCalls {
val unary = "_root_.io.grpc.stub.ServerCalls.asyncUnaryCall"
val clientStreaming = "_root_.io.grpc.stub.ServerCalls.asyncClientStreamingCall"
val serverStreaming = "_root_.io.grpc.stub.ServerCalls.asyncServerStreamingCall"
val bidiStreaming = "_root_.io.grpc.stub.ServerCalls.asyncBidiStreamingCall"
}

private[this] object clientCalls {
val clientStreaming = "_root_.io.grpc.stub.ClientCalls.asyncClientStreamingCall"
val serverStreaming = "_root_.io.grpc.stub.ClientCalls.asyncServerStreamingCall"
val bidiStreaming = "_root_.io.grpc.stub.ClientCalls.asyncBidiStreamingCall"
}

private[this] val blockingClientName: String = service.getName + "BlockingClientImpl"

private[this] def clientMethodImpl(m: MethodDescriptor, blocking: Boolean) = PrinterEndo{ p =>
m.streamType match {
case StreamType.Unary =>
if(blocking) {
p.add(
"override " + methodSignature(m, identity) + " = {"
).add(
s""" ${m.scalaOut}.fromJavaProto(_root_.io.grpc.stub.ClientCalls.blockingUnaryCall(channel.newCall(${methodDescriptorName(m)}, options), ${m.scalaIn}.toJavaProto(request)))""",
"}"
)
} else {
p.add(
"override " + methodSignature(m, "scala.concurrent.Future[" + _ + "]") + " = {"
).add(
s""" $guavaFuture2ScalaFuture($futureUnaryCall(channel.newCall(${methodDescriptorName(m)}, options), ${m.scalaIn}.toJavaProto(request)))(${m.scalaOut}.fromJavaProto(_))""",
"}"
)
}
case StreamType.ServerStreaming =>
p.add(
"override " + methodSignature(m) + " = {"
).addI(
s"${clientCalls.serverStreaming}(channel.newCall(${methodDescriptorName(m)}, options), ${m.scalaIn}.toJavaProto(request), $contramapObserver(observer)(${m.scalaOut}.fromJavaProto))"
).add("}")
case streamType =>
val call = if (streamType == StreamType.ClientStreaming) {
clientCalls.clientStreaming
} else {
clientCalls.bidiStreaming
}

p.add(
"override " + methodSignature(m) + " = {"
).indent.add(
s"$contramapObserver("
).indent.add(
s"$call(channel.newCall(${methodDescriptorName(m)}, options), $contramapObserver(observer)(${m.scalaOut}.fromJavaProto)))(${m.scalaIn}.toJavaProto"
).outdent.add(")").outdent.add("}")
}
}

private[this] val blockingClientImpl: PrinterEndo = { p =>
val methods = service.getMethods.asScala.map(clientMethodImpl(_, true))

val build =
s" override def build(channel: $channel, options: $callOptions): $blockingClientName = new $blockingClientName(channel, options)"

p.add(
s"class $blockingClientName(channel: $channel, options: $callOptions = $callOptions.DEFAULT) extends $abstractStub[$blockingClientName](channel, options) with $serviceBlocking {"
).withIndent(
methods : _*
).add(
build
).add(
"}"
)
}

private[this] val contramapObserver = "contramapObserver"
private[this] val contramapObserverImpl = s"""private[this] def $contramapObserver[A, B](observer: ${observer("A")})(f: B => A): ${observer("B")} =
new ${observer("B")} {
override def onNext(value: B): Unit = observer.onNext(f(value))
override def onError(t: Throwable): Unit = observer.onError(t)
override def onCompleted(): Unit = observer.onCompleted()
}"""

private[this] val guavaFuture2ScalaFuture = "guavaFuture2ScalaFuture"

private[this] val guavaFuture2ScalaFutureImpl = {
s"""private[this] def $guavaFuture2ScalaFuture[A, B](guavaFuture: _root_.com.google.common.util.concurrent.ListenableFuture[A])(converter: A => B): scala.concurrent.Future[B] = {
val p = scala.concurrent.Promise[B]()
_root_.com.google.common.util.concurrent.Futures.addCallback(guavaFuture, new _root_.com.google.common.util.concurrent.FutureCallback[A] {
override def onFailure(t: Throwable) = p.failure(t)
override def onSuccess(a: A) = p.success(converter(a))
})
p.future
}"""
}

private[this] val asyncClientName = service.getName + "AsyncClientImpl"

private[this] val asyncClientImpl: PrinterEndo = { p =>
val methods = service.getMethods.asScala.map(clientMethodImpl(_, false))

val build =
s" override def build(channel: $channel, options: $callOptions): $asyncClientName = new $asyncClientName(channel, options)"

p.add(
s"class $asyncClientName(channel: $channel, options: $callOptions = $callOptions.DEFAULT) extends $abstractStub[$asyncClientName](channel, options) with $serviceFuture {"
).withIndent(
methods : _*
).add(
build
).add(
"}"
)
}

private[this] def methodDescriptorName(method: MethodDescriptor): String =
"METHOD_" + method.getName.toUpperCase(Locale.ENGLISH)

private[this] def methodDescriptor(method: MethodDescriptor) = {
val inJava = method.getInputType.javaTypeName
val outJava = method.getOutputType.javaTypeName

def marshaller(typeName: String) =
s"_root_.io.grpc.protobuf.ProtoUtils.marshaller($typeName.getDefaultInstance)"

val methodType = method.streamType match {
case StreamType.Unary => "UNARY"
case StreamType.ClientStreaming => "CLIENT_STREAMING"
case StreamType.ServerStreaming => "SERVER_STREAMING"
case StreamType.Bidirectional => "BIDI_STREAMING"
}

val grpcMethodDescriptor = "_root_.io.grpc.MethodDescriptor"

s""" private[this] val ${methodDescriptorName(method)}: $grpcMethodDescriptor[$inJava, $outJava] =
$grpcMethodDescriptor.create(
$grpcMethodDescriptor.MethodType.$methodType,
$grpcMethodDescriptor.generateFullMethodName("${service.getFullName}", "${method.getName}"),
${marshaller(inJava)},
${marshaller(outJava)}
)"""
}

private[this] val methodDescriptors: Seq[String] = service.getMethods.asScala.map(methodDescriptor)

private[this] def callMethodName(method: MethodDescriptor) =
methodName0(method) + "Method"

private[this] def callMethod(method: MethodDescriptor) =
method.streamType match {
case StreamType.Unary =>
s"${callMethodName(method)}(service, executionContext)"
case _ =>
s"${callMethodName(method)}(service)"
}

private[this] def createMethod(method: MethodDescriptor): String = {
val javaIn = method.getInputType.javaTypeName
val javaOut = method.getOutputType.javaTypeName
val executionContext = "executionContext"
val name = callMethodName(method)
val serviceImpl = "serviceImpl"
method.streamType match {
case StreamType.Unary =>
val serverMethod = s"_root_.io.grpc.stub.ServerCalls.UnaryMethod[$javaIn, $javaOut]"
s""" def ${name}($serviceImpl: $serviceFuture, $executionContext: scala.concurrent.ExecutionContext): $serverMethod = {
new $serverMethod {
override def invoke(request: $javaIn, observer: $streamObserver[$javaOut]): Unit =
$serviceImpl.${methodName(method)}(${method.scalaIn}.fromJavaProto(request)).onComplete {
case scala.util.Success(value) =>
observer.onNext(${method.scalaOut}.toJavaProto(value))
observer.onCompleted()
case scala.util.Failure(error) =>
observer.onError(error)
observer.onCompleted()
}($executionContext)
}
}"""
case StreamType.ServerStreaming =>
val serverMethod = s"_root_.io.grpc.stub.ServerCalls.ServerStreamingMethod[$javaIn, $javaOut]"

s""" def ${name}($serviceImpl: $serviceFuture): $serverMethod = {
new $serverMethod {
override def invoke(request: $javaIn, observer: $streamObserver[$javaOut]): Unit =
$serviceImpl.${methodName0(method)}(${method.scalaIn}.fromJavaProto(request), $contramapObserver(observer)(${method.scalaOut}.toJavaProto))
}
}"""
case _ =>
val serverMethod = if(method.streamType == StreamType.ClientStreaming) {
s"_root_.io.grpc.stub.ServerCalls.ClientStreamingMethod[$javaIn, $javaOut]"
} else {
s"_root_.io.grpc.stub.ServerCalls.BidiStreamingMethod[$javaIn, $javaOut]"
}

s""" def ${name}($serviceImpl: $serviceFuture): $serverMethod = {
new $serverMethod {
override def invoke(observer: $streamObserver[$javaOut]): $streamObserver[$javaIn] =
$contramapObserver($serviceImpl.${methodName0(method)}($contramapObserver(observer)(${method.scalaOut}.toJavaProto)))(${method.scalaIn}.fromJavaProto)
}
}"""
}
}

private[this] val bindService = {
val executionContext = "executionContext"
val methods = service.getMethods.asScala.map { m =>

val call = m.streamType match {
case StreamType.Unary => serverCalls.unary
case StreamType.ClientStreaming => serverCalls.clientStreaming
case StreamType.ServerStreaming => serverCalls.serverStreaming
case StreamType.Bidirectional => serverCalls.bidiStreaming
}

s""".addMethod(
${methodDescriptorName(m)},
$call(
${callMethod(m)}
)
)"""
}.mkString

val serverServiceDef = "_root_.io.grpc.ServerServiceDefinition"

s"""def bindService(service: $serviceFuture, $executionContext: scala.concurrent.ExecutionContext): $serverServiceDef =
$serverServiceDef.builder("${service.getFullName}")$methods.build()
"""
}

val objectName = service.getName + "Grpc"

def printService(printer: FunctionalPrinter): FunctionalPrinter = {
printer.add(
"package " + service.getFile.scalaPackageName,
"",
"import scala.language.higherKinds",
"",
s"object $objectName {"
).seq(
service.getMethods.asScala.map(createMethod)
).seq(
methodDescriptors
).newline.withIndent(
base,
FunctionalPrinter.newline,
blockingClientImpl,
FunctionalPrinter.newline,
asyncClientImpl
).newline.addI(
bindService,
guavaFuture2ScalaFutureImpl,
contramapObserverImpl,
s"def blockingClient(channel: $channel): $serviceBlocking = new $blockingClientName(channel)",
s"def futureClient(channel: $channel): $serviceFuture = new $asyncClientName(channel)"
).add(
""
).outdent.add("}")
}
}
Loading

0 comments on commit 61ceba4

Please sign in to comment.