diff --git a/server/http/interceptors/interceptors.go b/server/http/interceptors/interceptors.go index 151ffdf9779..26acf96613f 100644 --- a/server/http/interceptors/interceptors.go +++ b/server/http/interceptors/interceptors.go @@ -89,6 +89,19 @@ func (w *wrappedCompressionResponseWriter) Write(b []byte) (int, error) { return w.Writer.Write(b) } +func (w *wrappedCompressionResponseWriter) Flush() { + if f, ok := w.Writer.(flusherWithError); ok { + f.Flush() + } + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +type flusherWithError interface { + Flush() error +} + func Gzip(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { @@ -254,6 +267,12 @@ func (w *instrumentedResponseWriter) WriteHeader(statusCode int) { w.ResponseWriter.WriteHeader(statusCode) } +func (w *instrumentedResponseWriter) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + func alertOnPanic() { buf := make([]byte, 1<<20) n := runtime.Stack(buf, true) diff --git a/server/http/protolet/BUILD b/server/http/protolet/BUILD index 15370a56148..c68d7eae034 100644 --- a/server/http/protolet/BUILD +++ b/server/http/protolet/BUILD @@ -8,6 +8,8 @@ go_library( deps = [ "//server/util/request_context", "@io_opentelemetry_go_otel_trace//:trace", + "@org_golang_google_grpc//:go_default_library", + "@org_golang_google_grpc//codes", "@org_golang_google_protobuf//encoding/protojson", "@org_golang_google_protobuf//encoding/prototext", "@org_golang_google_protobuf//proto", diff --git a/server/http/protolet/protolet.go b/server/http/protolet/protolet.go index 1a562b48499..3fe1da0fa6b 100644 --- a/server/http/protolet/protolet.go +++ b/server/http/protolet/protolet.go @@ -1,21 +1,39 @@ package protolet import ( + "bytes" "context" "fmt" "io" + "io/ioutil" "net/http" "reflect" + "strconv" + "strings" - "github.com/buildbuddy-io/buildbuddy/server/util/request_context" "go.opentelemetry.io/otel/trace" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" + + requestcontext "github.com/buildbuddy-io/buildbuddy/server/util/request_context" ) const ( contextProtoMessageKey = "protolet.requestMessage" + // GRPC over HTTP requires protobuf messages to be sent in a series of `Length-Prefixed-Message`s + // Here's what a Length-Prefixed-Message looks like: + // Length-Prefixed-Message → Compressed-Flag Message-Length Message + // Compressed-Flag → 0 / 1 # encoded as 1 byte unsigned integer + // Message-Length → {length of Message} # encoded as 4 byte unsigned integer (big endian) + // Message → *{binary octet} + // This means the actual proto we want to deserialize starts at byte 5 because there is 1 + // byte that tells us whether or not the message is compressed, and then 4 bytes that tell + // us the length of the message. + // For more info, see: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md + messageByteOffset = 5 ) func isRPCMethod(m reflect.Method) bool { @@ -42,6 +60,26 @@ func isRPCMethod(m reflect.Method) bool { return true } +func isStreamingRPCMethod(m reflect.Method) bool { + t := m.Type + if t.Kind() != reflect.Func { + return false + } + if t.NumIn() != 3 || t.NumOut() != 1 { + return false + } + if !t.In(1).Implements(reflect.TypeOf((*proto.Message)(nil)).Elem()) { + return false + } + if !t.In(2).Implements(reflect.TypeOf((*grpc.ServerStream)(nil)).Elem()) { + return false + } + if !t.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return false + } + return true +} + func ReadRequestToProto(r *http.Request, req proto.Message) error { body, err := io.ReadAll(r.Body) if err != nil { @@ -55,6 +93,9 @@ func ReadRequestToProto(r *http.Request, req proto.Message) error { return proto.Unmarshal(body, req) case "application/protobuf-text": return prototext.Unmarshal(body, req) + case "application/grpc+proto": + r.Body = ioutil.NopCloser(bytes.NewReader(body)) + return proto.Unmarshal(body[messageByteOffset:], req) default: return fmt.Errorf("Unknown Content-Type: %s, expected application/json or application/protobuf", ct) } @@ -101,7 +142,7 @@ type HTTPHandlers struct { RequestHandler http.Handler } -func GenerateHTTPHandlers(servicePrefix string, server interface{}) (*HTTPHandlers, error) { +func GenerateHTTPHandlers(servicePrefix, serviceName string, server interface{}, grpcServer *grpc.Server) (*HTTPHandlers, error) { if reflect.ValueOf(server).Type().Kind() != reflect.Ptr { return nil, fmt.Errorf("GenerateHTTPHandlers must be called with a pointer to an RPC service implementation") } @@ -110,7 +151,7 @@ func GenerateHTTPHandlers(servicePrefix string, server interface{}) (*HTTPHandle serverType := reflect.TypeOf(server) for i := 0; i < serverType.NumMethod(); i++ { method := serverType.Method(i) - if !isRPCMethod(method) { + if !isRPCMethod(method) && !isStreamingRPCMethod(method) { continue } handlerFns[servicePrefix+method.Name] = method.Func @@ -125,7 +166,13 @@ func GenerateHTTPHandlers(servicePrefix string, server interface{}) (*HTTPHandle } methodType := method.Type() - reqVal := reflect.New(methodType.In(2).Elem()) + requestIndex := 2 + // If we're dealing with a streaming method, the request proto is the first input + if method.Type().In(1).Implements(reflect.TypeOf((*proto.Message)(nil)).Elem()) { + requestIndex = 1 + } + + reqVal := reflect.New(methodType.In(requestIndex).Elem()) req := reqVal.Interface().(proto.Message) if err := ReadRequestToProto(r, req); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) @@ -146,6 +193,24 @@ func GenerateHTTPHandlers(servicePrefix string, server interface{}) (*HTTPHandle return } + // If we're getting a grpc+proto request over http, we rewrite the path to point at + // the grpc server's http handler endpoints and make the request look like an http2 request. + // We also wrap the ResponseWriter so we can return proper errors to the web front-end. + if r.Header.Get("content-type") == "application/grpc+proto" { + r.URL.Path = fmt.Sprintf("/%s/%s", serviceName, strings.TrimPrefix(r.URL.Path, servicePrefix)) + r.ProtoMajor = 2 + r.ProtoMinor = 0 + wrapped := &wrappedResponse{w: w} + grpcServer.ServeHTTP(wrapped, r) + wrapped.sendErrorIfNeeded(r) + return + } + + if method.Type().NumOut() != 2 { + http.Error(w, "Streaming not enabled.", http.StatusNotImplemented) + return + } + // If we know this is a protolet request and we expect to handle it, // override the span name to something legible instead of the generic // handled-path name. This means instead of the span appearing with a @@ -179,3 +244,51 @@ func GenerateHTTPHandlers(servicePrefix string, server interface{}) (*HTTPHandle RequestHandler: requestHandler, }, nil } + +type wrappedResponse struct { + w http.ResponseWriter + wroteHeader bool + wroteBody bool +} + +func (w *wrappedResponse) Header() http.Header { + return w.w.Header() +} + +func (w *wrappedResponse) Write(b []byte) (int, error) { + w.wroteBody, w.wroteHeader = true, true + return w.w.Write(b) +} + +func (w *wrappedResponse) WriteHeader(code int) { + w.wroteHeader = true + w.w.WriteHeader(code) +} + +func (w *wrappedResponse) Flush() { + if !w.wroteHeader && !w.wroteBody { + return + } + if f, ok := w.w.(http.Flusher); ok { + f.Flush() + } +} + +func (w *wrappedResponse) sendErrorIfNeeded(req *http.Request) { + if w.wroteHeader || w.wroteBody { + return + } + i, err := strconv.Atoi(w.Header().Get("grpc-status")) + if err != nil { + i = int(codes.Unknown) + } + if i == 0 { + w.WriteHeader(200) + } + + // Match our current behavior where we return 500 for all errors and return the message in the response body + w.WriteHeader(500) + code := codes.Code(i).String() + w.Write([]byte(fmt.Sprintf("rpc error: code = %s desc = %s", code, w.Header().Get("grpc-message")))) + w.Flush() +} diff --git a/server/libmain/libmain.go b/server/libmain/libmain.go index 8730e1677d6..2de60119b54 100644 --- a/server/libmain/libmain.go +++ b/server/libmain/libmain.go @@ -310,13 +310,6 @@ func StartAndRunServices(env environment.Env) { log.Fatalf("%v", err) } - // Generate HTTP (protolet) handlers for the BuildBuddy API, so it - // can be called over HTTP(s). - protoletHandler, err := protolet.GenerateHTTPHandlers("/rpc/BuildBuddyService/", env.GetBuildBuddyServer()) - if err != nil { - log.Fatalf("Error initializing RPC over HTTP handlers for BuildBuddy server: %s", err) - } - monitoring.StartMonitoringHandler(env, fmt.Sprintf("%s:%d", *listen, *monitoringPort)) if err := build_event_server.Register(env); err != nil { @@ -357,6 +350,13 @@ func StartAndRunServices(env environment.Env) { log.Fatalf("%v", err) } + // Generate HTTP (protolet) handlers for the BuildBuddy API, so it + // can be called over HTTP(s). + protoletHandler, err := protolet.GenerateHTTPHandlers("/rpc/BuildBuddyService/", "buildbuddy.service.BuildBuddyService", env.GetBuildBuddyServer(), env.GetGRPCServer()) + if err != nil { + log.Fatalf("Error initializing RPC over HTTP handlers for BuildBuddy server: %s", err) + } + mux := env.GetMux() // Register all of our HTTP handlers on the default mux. mux.Handle("/", interceptors.WrapExternalHandler(env, staticFileServer)) @@ -382,7 +382,7 @@ func StartAndRunServices(env environment.Env) { // Register API as an HTTP service. if api := env.GetAPIService(); api != nil { - apiProtoHandlers, err := protolet.GenerateHTTPHandlers("/api/v1/", api) + apiProtoHandlers, err := protolet.GenerateHTTPHandlers("/api/v1/", "api.v1", api, env.GetGRPCServer()) if err != nil { log.Fatalf("Error initializing RPC over HTTP handlers for API: %s", err) }