From d2d4cba9c5b63eaf6f647dbb1f45f7e7d709caba Mon Sep 17 00:00:00 2001 From: "Jonathan A. Sternberg" Date: Thu, 9 Jun 2016 10:52:43 -0500 Subject: [PATCH] Setup an interface for writing http responses based on the Content-Type header Only `application/json` is supported right now, but this opens up the easier possibility of additional content types to be returned from the server. --- cmd/influxd/run/server_helpers_test.go | 4 +- services/httpd/handler.go | 121 +++++++++++-------------- services/httpd/handler_test.go | 66 +++++--------- services/httpd/response_writer.go | 66 ++++++++++++++ 4 files changed, 145 insertions(+), 112 deletions(-) create mode 100644 services/httpd/response_writer.go diff --git a/cmd/influxd/run/server_helpers_test.go b/cmd/influxd/run/server_helpers_test.go index 31f795bc82f..64a655f4e5a 100644 --- a/cmd/influxd/run/server_helpers_test.go +++ b/cmd/influxd/run/server_helpers_test.go @@ -161,7 +161,7 @@ func (s *Server) HTTPGet(url string) (results string, err error) { if err != nil { return "", err } - body := string(MustReadAll(resp.Body)) + body := strings.TrimSpace(string(MustReadAll(resp.Body))) switch resp.StatusCode { case http.StatusBadRequest: if !expectPattern(".*error parsing query*.", body) { @@ -182,7 +182,7 @@ func (s *Server) HTTPPost(url string, content []byte) (results string, err error if err != nil { return "", err } - body := string(MustReadAll(resp.Body)) + body := strings.TrimSpace(string(MustReadAll(resp.Body))) switch resp.StatusCode { case http.StatusBadRequest: if !expectPattern(".*error parsing query*.", body) { diff --git a/services/httpd/handler.go b/services/httpd/handler.go index 7756d6103aa..d49fb02c916 100644 --- a/services/httpd/handler.go +++ b/services/httpd/handler.go @@ -160,6 +160,7 @@ func (h *Handler) AddRoutes(routes ...Route) { handler = http.HandlerFunc(hf) } + handler = h.responseWriter(handler) if r.Gzipped { handler = gzipFilter(handler) } @@ -266,12 +267,16 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta. h.statMap.Add(statQueryRequestDuration, time.Since(start).Nanoseconds()) }(time.Now()) - pretty := r.FormValue("pretty") == "true" - nodeID, _ := strconv.ParseUint(r.FormValue("node_id"), 10, 64) + // Retrieve the underlying ResponseWriter or initialize our own. + rw, ok := w.(ResponseWriter) + if !ok { + rw = NewResponseWriter(w, r) + } + nodeID, _ := strconv.ParseUint(r.FormValue("node_id"), 10, 64) qp := strings.TrimSpace(r.FormValue("q")) if qp == "" { - h.httpError(w, `missing required parameter "q"`, pretty, http.StatusBadRequest) + h.httpError(rw, `missing required parameter "q"`, http.StatusBadRequest) return } @@ -291,7 +296,7 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta. decoder := json.NewDecoder(strings.NewReader(rawParams)) decoder.UseNumber() if err := decoder.Decode(¶ms); err != nil { - h.httpError(w, "error parsing query parameters: "+err.Error(), pretty, http.StatusBadRequest) + h.httpError(rw, "error parsing query parameters: "+err.Error(), http.StatusBadRequest) return } @@ -306,7 +311,7 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta. } if err != nil { - h.httpError(w, "error parsing json value: "+err.Error(), pretty, http.StatusBadRequest) + h.httpError(rw, "error parsing json value: "+err.Error(), http.StatusBadRequest) return } } @@ -317,7 +322,7 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta. // Parse query from query string. query, err := p.ParseQuery() if err != nil { - h.httpError(w, "error parsing query: "+err.Error(), pretty, http.StatusBadRequest) + h.httpError(rw, "error parsing query: "+err.Error(), http.StatusBadRequest) return } @@ -327,7 +332,7 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta. if err, ok := err.(meta.ErrAuthorize); ok { h.Logger.Printf("Unauthorized request | user: %q | query: %q | database %q\n", err.User, err.Query.String(), err.Database) } - h.httpError(w, "error authorizing query: "+err.Error(), pretty, http.StatusUnauthorized) + h.httpError(rw, "error authorizing query: "+err.Error(), http.StatusUnauthorized) return } } @@ -365,8 +370,7 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta. } // Execute query. - w.Header().Add("Connection", "close") - w.Header().Add("Content-Type", "application/json") + rw.Header().Add("Connection", "close") results := h.QueryExecutor.ExecuteQuery(query, influxql.ExecutionOptions{ Database: db, ChunkSize: chunkSize, @@ -378,7 +382,7 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta. resp := Response{Results: make([]*influxql.Result, 0)} // Status header is OK once this point is reached. - h.writeHeader(w, http.StatusOK) + h.writeHeader(rw, http.StatusOK) // pull all results from the channel rows := 0 @@ -395,12 +399,9 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta. // Write out result immediately if chunked. if chunked { - n, _ := w.Write(MarshalJSON(Response{ + n, _ := rw.WriteResponse(Response{ Results: []*influxql.Result{r}, - }, pretty)) - if !pretty { - w.Write([]byte("\n")) - } + }) h.statMap.Add(statQueryRequestBytesTransmitted, int64(n)) w.(http.Flusher).Flush() continue @@ -455,7 +456,7 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta. // If it's not chunked we buffered everything in memory, so write it out if !chunked { - n, _ := w.Write(MarshalJSON(resp, pretty)) + n, _ := rw.WriteResponse(resp) h.statMap.Add(statQueryRequestBytesTransmitted, int64(n)) } } @@ -471,23 +472,23 @@ func (h *Handler) serveWrite(w http.ResponseWriter, r *http.Request, user *meta. database := r.URL.Query().Get("db") if database == "" { - h.resultError(w, influxql.Result{Err: fmt.Errorf("database is required")}, http.StatusBadRequest) + h.httpError(w, "database is required", http.StatusBadRequest) return } if di := h.MetaClient.Database(database); di == nil { - h.resultError(w, influxql.Result{Err: fmt.Errorf("database not found: %q", database)}, http.StatusNotFound) + h.httpError(w, fmt.Sprintf("database not found: %q", database), http.StatusNotFound) return } if h.Config.AuthEnabled && user == nil { - h.resultError(w, influxql.Result{Err: fmt.Errorf("user is required to write to database %q", database)}, http.StatusUnauthorized) + h.httpError(w, fmt.Sprintf("user is required to write to database %q", database), http.StatusUnauthorized) return } if h.Config.AuthEnabled { if err := h.WriteAuthorizer.AuthorizeWrite(user.Name, database); err != nil { - h.resultError(w, influxql.Result{Err: fmt.Errorf("%q user is not authorized to write to database %q", user.Name, database)}, http.StatusUnauthorized) + h.httpError(w, fmt.Sprintf("%q user is not authorized to write to database %q", user.Name, database), http.StatusUnauthorized) return } } @@ -497,7 +498,7 @@ func (h *Handler) serveWrite(w http.ResponseWriter, r *http.Request, user *meta. if r.Header.Get("Content-Encoding") == "gzip" { b, err := gzip.NewReader(r.Body) if err != nil { - h.resultError(w, influxql.Result{Err: err}, http.StatusBadRequest) + h.httpError(w, err.Error(), http.StatusBadRequest) return } defer b.Close() @@ -519,7 +520,7 @@ func (h *Handler) serveWrite(w http.ResponseWriter, r *http.Request, user *meta. if h.Config.WriteTracing { h.Logger.Print("Write handler unable to read bytes from request body") } - h.resultError(w, influxql.Result{Err: err}, http.StatusBadRequest) + h.httpError(w, err.Error(), http.StatusBadRequest) return } h.statMap.Add(statWriteRequestBytesReceived, int64(buf.Len())) @@ -535,7 +536,7 @@ func (h *Handler) serveWrite(w http.ResponseWriter, r *http.Request, user *meta. h.writeHeader(w, http.StatusOK) return } - h.resultError(w, influxql.Result{Err: parseError}, http.StatusBadRequest) + h.httpError(w, parseError.Error(), http.StatusBadRequest) return } @@ -546,7 +547,7 @@ func (h *Handler) serveWrite(w http.ResponseWriter, r *http.Request, user *meta. var err error consistency, err = models.ParseConsistencyLevel(level) if err != nil { - h.resultError(w, influxql.Result{Err: err}, http.StatusBadRequest) + h.httpError(w, err.Error(), http.StatusBadRequest) return } } @@ -554,18 +555,18 @@ func (h *Handler) serveWrite(w http.ResponseWriter, r *http.Request, user *meta. // Write points. if err := h.PointsWriter.WritePoints(database, r.URL.Query().Get("rp"), consistency, points); influxdb.IsClientError(err) { h.statMap.Add(statPointsWrittenFail, int64(len(points))) - h.resultError(w, influxql.Result{Err: err}, http.StatusBadRequest) + h.httpError(w, err.Error(), http.StatusBadRequest) return } else if err != nil { h.statMap.Add(statPointsWrittenFail, int64(len(points))) - h.resultError(w, influxql.Result{Err: err}, http.StatusInternalServerError) + h.httpError(w, err.Error(), http.StatusInternalServerError) return } else if parseError != nil { // We wrote some of the points h.statMap.Add(statPointsWrittenOK, int64(len(points))) // The other points failed to parse which means the client sent invalid line protocol. We return a 400 // response code as well as the lines that failed to parse. - h.resultError(w, influxql.Result{Err: fmt.Errorf("partial write:\n%v", parseError)}, http.StatusBadRequest) + h.httpError(w, fmt.Sprintf("partial write:\n%v", parseError), http.StatusBadRequest) return } @@ -617,22 +618,6 @@ func convertToEpoch(r *influxql.Result, epoch string) { } } -// MarshalJSON will marshal v to JSON. Pretty prints if pretty is true. -func MarshalJSON(v interface{}, pretty bool) []byte { - var b []byte - var err error - if pretty { - b, err = json.MarshalIndent(v, "", " ") - } else { - b, err = json.Marshal(v) - } - - if err != nil { - return []byte(err.Error()) - } - return b -} - // serveExpvar serves registered expvar information over HTTP. func serveExpvar(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json; charset=utf-8") @@ -649,23 +634,20 @@ func serveExpvar(w http.ResponseWriter, r *http.Request) { } // h.httpError writes an error to the client in a standard format. -func (h *Handler) httpError(w http.ResponseWriter, error string, pretty bool, code int) { - w.Header().Add("Content-Type", "application/json") - h.writeHeader(w, code) +func (h *Handler) httpError(w http.ResponseWriter, error string, code int) { response := Response{Err: errors.New(error)} - var b []byte - if pretty { - b, _ = json.MarshalIndent(response, "", " ") - } else { - b, _ = json.Marshal(response) + if rw, ok := w.(ResponseWriter); ok { + h.writeHeader(w, code) + rw.WriteResponse(response) + return } - w.Write(b) -} -func (h *Handler) resultError(w http.ResponseWriter, result influxql.Result, code int) { + // Default implementation if the response writer hasn't been replaced + // with our special response writer type. w.Header().Add("Content-Type", "application/json") h.writeHeader(w, code) - _ = json.NewEncoder(w).Encode(&result) + b, _ := json.Marshal(response) + w.Write(b) } // Filters and filter helpers @@ -750,7 +732,7 @@ func authenticate(inner func(http.ResponseWriter, *http.Request, *meta.UserInfo) creds, err := parseCredentials(r) if err != nil { h.statMap.Add(statAuthFail, 1) - h.httpError(w, err.Error(), false, http.StatusUnauthorized) + h.httpError(w, err.Error(), http.StatusUnauthorized) return } @@ -758,14 +740,14 @@ func authenticate(inner func(http.ResponseWriter, *http.Request, *meta.UserInfo) case UserAuthentication: if creds.Username == "" { h.statMap.Add(statAuthFail, 1) - h.httpError(w, "username required", false, http.StatusUnauthorized) + h.httpError(w, "username required", http.StatusUnauthorized) return } user, err = h.MetaClient.Authenticate(creds.Username, creds.Password) if err != nil { h.statMap.Add(statAuthFail, 1) - h.httpError(w, "authorization failed", false, http.StatusUnauthorized) + h.httpError(w, "authorization failed", http.StatusUnauthorized) return } case BearerAuthentication: @@ -780,39 +762,39 @@ func authenticate(inner func(http.ResponseWriter, *http.Request, *meta.UserInfo) // Parse and validate the token. token, err := jwt.Parse(creds.Token, keyLookupFn) if err != nil { - h.httpError(w, err.Error(), false, http.StatusUnauthorized) + h.httpError(w, err.Error(), http.StatusUnauthorized) return } else if !token.Valid { - h.httpError(w, "invalid token", false, http.StatusUnauthorized) + h.httpError(w, "invalid token", http.StatusUnauthorized) return } // Make sure an expiration was set on the token. if exp, ok := token.Claims["exp"].(float64); !ok || exp <= 0.0 { - h.httpError(w, "token expiration required", false, http.StatusUnauthorized) + h.httpError(w, "token expiration required", http.StatusUnauthorized) return } // Get the username from the token. username, ok := token.Claims["username"].(string) if !ok { - h.httpError(w, "username in token must be a string", false, http.StatusUnauthorized) + h.httpError(w, "username in token must be a string", http.StatusUnauthorized) return } else if username == "" { - h.httpError(w, "token must contain a username", false, http.StatusUnauthorized) + h.httpError(w, "token must contain a username", http.StatusUnauthorized) return } // Lookup user in the metastore. if user, err = h.MetaClient.User(username); err != nil { - h.httpError(w, err.Error(), false, http.StatusUnauthorized) + h.httpError(w, err.Error(), http.StatusUnauthorized) return } else if user == nil { - h.httpError(w, meta.ErrUserNotFound.Error(), false, http.StatusUnauthorized) + h.httpError(w, meta.ErrUserNotFound.Error(), http.StatusUnauthorized) return } default: - h.httpError(w, "unsupported authentication", false, http.StatusUnauthorized) + h.httpError(w, "unsupported authentication", http.StatusUnauthorized) } } @@ -920,6 +902,13 @@ func (h *Handler) logging(inner http.Handler, name string) http.Handler { }) } +func (h *Handler) responseWriter(inner http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w = NewResponseWriter(w, r) + inner.ServeHTTP(w, r) + }) +} + func (h *Handler) recovery(inner http.Handler, name string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() diff --git a/services/httpd/handler_test.go b/services/httpd/handler_test.go index e402fe5357b..41e97bef5ff 100644 --- a/services/httpd/handler_test.go +++ b/services/httpd/handler_test.go @@ -38,8 +38,8 @@ func TestHandler_Query(t *testing.T) { h.ServeHTTP(w, MustNewJSONRequest("GET", "/query?db=foo&q=SELECT+*+FROM+bar", nil)) if w.Code != http.StatusOK { t.Fatalf("unexpected status: %d", w.Code) - } else if w.Body.String() != `{"results":[{"series":[{"name":"series0"}]},{"series":[{"name":"series1"}]}]}` { - t.Fatalf("unexpected body: %s", w.Body.String()) + } else if body := strings.TrimSpace(w.Body.String()); body != `{"results":[{"series":[{"name":"series0"}]},{"series":[{"name":"series1"}]}]}` { + t.Fatalf("unexpected body: %s", body) } } @@ -102,8 +102,8 @@ func TestHandler_Query_Auth(t *testing.T) { h.ServeHTTP(w, MustNewJSONRequest("GET", "/query?u=user1&p=abcd&db=foo&q=SELECT+*+FROM+bar", nil)) if w.Code != http.StatusOK { t.Fatalf("unexpected status: %d: %s", w.Code, w.Body.String()) - } else if w.Body.String() != `{"results":[{"series":[{"name":"series0"}]},{"series":[{"name":"series1"}]}]}` { - t.Fatalf("unexpected body: %s", w.Body.String()) + } else if body := strings.TrimSpace(w.Body.String()); body != `{"results":[{"series":[{"name":"series0"}]},{"series":[{"name":"series1"}]}]}` { + t.Fatalf("unexpected body: %s", body) } // Test the handler with valid JWT bearer token. @@ -116,8 +116,8 @@ func TestHandler_Query_Auth(t *testing.T) { h.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("unexpected status: %d: %s", w.Code, w.Body.String()) - } else if w.Body.String() != `{"results":[{"series":[{"name":"series0"}]},{"series":[{"name":"series1"}]}]}` { - t.Fatalf("unexpected body: %s", w.Body.String()) + } else if body := strings.TrimSpace(w.Body.String()); body != `{"results":[{"series":[{"name":"series0"}]},{"series":[{"name":"series1"}]}]}` { + t.Fatalf("unexpected body: %s", body) } // Test the handler with JWT token signed with invalid key. @@ -130,8 +130,8 @@ func TestHandler_Query_Auth(t *testing.T) { h.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Fatalf("unexpected status: %d: %s", w.Code, w.Body.String()) - } else if w.Body.String() != `{"error":"signature is invalid"}` { - t.Fatalf("unexpected body: %s", w.Body.String()) + } else if body := strings.TrimSpace(w.Body.String()); body != `{"error":"signature is invalid"}` { + t.Fatalf("unexpected body: %s", body) } // Test handler with valid JWT token carrying non-existant user. @@ -142,8 +142,8 @@ func TestHandler_Query_Auth(t *testing.T) { h.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Fatalf("unexpected status: %d: %s", w.Code, w.Body.String()) - } else if w.Body.String() != `{"error":"user not found"}` { - t.Fatalf("unexpected body: %s", w.Body.String()) + } else if body := strings.TrimSpace(w.Body.String()); body != `{"error":"user not found"}` { + t.Fatalf("unexpected body: %s", body) } // Test handler with expired JWT token. @@ -170,8 +170,8 @@ func TestHandler_Query_Auth(t *testing.T) { h.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Fatalf("unexpected status: %d: %s", w.Code, w.Body.String()) - } else if w.Body.String() != `{"error":"token expiration required"}` { - t.Fatalf("unexpected body: %s", w.Body.String()) + } else if body := strings.TrimSpace(w.Body.String()); body != `{"error":"token expiration required"}` { + t.Fatalf("unexpected body: %s", body) } } @@ -205,8 +205,8 @@ func TestHandler_Query_MergeResults(t *testing.T) { h.ServeHTTP(w, MustNewJSONRequest("GET", "/query?db=foo&q=SELECT+*+FROM+bar", nil)) if w.Code != http.StatusOK { t.Fatalf("unexpected status: %d", w.Code) - } else if w.Body.String() != `{"results":[{"series":[{"name":"series0"},{"name":"series1"}]}]}` { - t.Fatalf("unexpected body: %s", w.Body.String()) + } else if body := strings.TrimSpace(w.Body.String()); body != `{"results":[{"series":[{"name":"series0"},{"name":"series1"}]}]}` { + t.Fatalf("unexpected body: %s", body) } } @@ -223,8 +223,8 @@ func TestHandler_Query_MergeEmptyResults(t *testing.T) { h.ServeHTTP(w, MustNewJSONRequest("GET", "/query?db=foo&q=SELECT+*+FROM+bar", nil)) if w.Code != http.StatusOK { t.Fatalf("unexpected status: %d", w.Code) - } else if w.Body.String() != `{"results":[{"series":[{"name":"series1"}]}]}` { - t.Fatalf("unexpected body: %s", w.Body.String()) + } else if body := strings.TrimSpace(w.Body.String()); body != `{"results":[{"series":[{"name":"series1"}]}]}` { + t.Fatalf("unexpected body: %s", body) } } @@ -258,8 +258,8 @@ func TestHandler_Query_ErrQueryRequired(t *testing.T) { h.ServeHTTP(w, MustNewJSONRequest("GET", "/query", nil)) if w.Code != http.StatusBadRequest { t.Fatalf("unexpected status: %d", w.Code) - } else if w.Body.String() != `{"error":"missing required parameter \"q\""}` { - t.Fatalf("unexpected body: %s", w.Body.String()) + } else if body := strings.TrimSpace(w.Body.String()); body != `{"error":"missing required parameter \"q\""}` { + t.Fatalf("unexpected body: %s", body) } } @@ -270,8 +270,8 @@ func TestHandler_Query_ErrInvalidQuery(t *testing.T) { h.ServeHTTP(w, MustNewJSONRequest("GET", "/query?q=SELECT", nil)) if w.Code != http.StatusBadRequest { t.Fatalf("unexpected status: %d", w.Code) - } else if w.Body.String() != `{"error":"error parsing query: found EOF, expected identifier, string, number, bool at line 1, char 8"}` { - t.Fatalf("unexpected body: %s", w.Body.String()) + } else if body := strings.TrimSpace(w.Body.String()); body != `{"error":"error parsing query: found EOF, expected identifier, string, number, bool at line 1, char 8"}` { + t.Fatalf("unexpected body: %s", body) } } @@ -300,8 +300,8 @@ func TestHandler_Query_ErrResult(t *testing.T) { h.ServeHTTP(w, MustNewJSONRequest("GET", "/query?db=foo&q=SHOW+SERIES+from+bin", nil)) if w.Code != http.StatusOK { t.Fatalf("unexpected status: %d", w.Code) - } else if w.Body.String() != `{"results":[{"error":"measurement not found"}]}` { - t.Fatalf("unexpected body: %s", w.Body.String()) + } else if body := strings.TrimSpace(w.Body.String()); body != `{"results":[{"error":"measurement not found"}]}` { + t.Fatalf("unexpected body: %s", body) } } @@ -391,28 +391,6 @@ func TestHandler_HandleBadRequestBody(t *testing.T) { } } -func TestMarshalJSON_NoPretty(t *testing.T) { - if b := httpd.MarshalJSON(struct { - Name string `json:"name"` - }{Name: "foo"}, false); string(b) != `{"name":"foo"}` { - t.Fatalf("unexpected bytes: %s", b) - } -} - -func TestMarshalJSON_Pretty(t *testing.T) { - if b := httpd.MarshalJSON(struct { - Name string `json:"name"` - }{Name: "foo"}, true); string(b) != "{\n \"name\": \"foo\"\n}" { - t.Fatalf("unexpected bytes: %q", string(b)) - } -} - -func TestMarshalJSON_Error(t *testing.T) { - if b := httpd.MarshalJSON(&invalidJSON{}, true); string(b) != "json: error calling MarshalJSON for type *httpd_test.invalidJSON: marker" { - t.Fatalf("unexpected bytes: %q", string(b)) - } -} - type invalidJSON struct{} func (*invalidJSON) MarshalJSON() ([]byte, error) { return nil, errors.New("marker") } diff --git a/services/httpd/response_writer.go b/services/httpd/response_writer.go new file mode 100644 index 00000000000..bac2bd7ca1a --- /dev/null +++ b/services/httpd/response_writer.go @@ -0,0 +1,66 @@ +package httpd + +import ( + "encoding/json" + "io" + "net/http" +) + +// ResponseWriter is an interface for writing a response. +type ResponseWriter interface { + // WriteResponse writes a response. + WriteResponse(resp Response) (int, error) + + http.ResponseWriter +} + +// NewResponseWriter creates a new ResponseWriter based on the Content-Type of the request +// that wraps the ResponseWriter. +func NewResponseWriter(w http.ResponseWriter, r *http.Request) ResponseWriter { + pretty := r.URL.Query().Get("pretty") == "true" + switch r.Header.Get("Content-Type") { + case "application/json": + fallthrough + default: + w.Header().Add("Content-Type", "application/json") + return &jsonResponseWriter{Pretty: pretty, ResponseWriter: w} + } +} + +// WriteError is a convenience function for writing an error response to the ResponseWriter. +func WriteError(w ResponseWriter, err error) (int, error) { + return w.WriteResponse(Response{Err: err}) +} + +type jsonResponseWriter struct { + Pretty bool + http.ResponseWriter +} + +func (w *jsonResponseWriter) WriteResponse(resp Response) (n int, err error) { + var b []byte + if w.Pretty { + b, err = json.MarshalIndent(resp, "", " ") + } else { + b, err = json.Marshal(resp) + } + + if err != nil { + n, err = io.WriteString(w, err.Error()) + } else { + n, err = w.Write(b) + } + + if !w.Pretty { + w.Write([]byte("\n")) + n++ + } + return n, err +} + +// Flush flushes the ResponseWriter if it has a Flush() method. +func (w *jsonResponseWriter) Flush() { + if w, ok := w.ResponseWriter.(http.Flusher); ok { + w.Flush() + } +}