Skip to content

Commit

Permalink
Setup an interface for writing http responses based on the Content-Ty…
Browse files Browse the repository at this point in the history
…pe header

Only `application/json` is supported right now, but this opens up the
easier possibility of additional content types to be returned from the
server.
  • Loading branch information
jsternberg committed Jun 10, 2016
1 parent 48f1a6d commit d2d4cba
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 112 deletions.
4 changes: 2 additions & 2 deletions cmd/influxd/run/server_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (s *Server) HTTPGet(url string) (results string, err error) {
if err != nil {
return "", err
}
body := string(MustReadAll(resp.Body))
body := strings.TrimSpace(string(MustReadAll(resp.Body)))
switch resp.StatusCode {
case http.StatusBadRequest:
if !expectPattern(".*error parsing query*.", body) {
Expand All @@ -182,7 +182,7 @@ func (s *Server) HTTPPost(url string, content []byte) (results string, err error
if err != nil {
return "", err
}
body := string(MustReadAll(resp.Body))
body := strings.TrimSpace(string(MustReadAll(resp.Body)))
switch resp.StatusCode {
case http.StatusBadRequest:
if !expectPattern(".*error parsing query*.", body) {
Expand Down
121 changes: 55 additions & 66 deletions services/httpd/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ func (h *Handler) AddRoutes(routes ...Route) {
handler = http.HandlerFunc(hf)
}

handler = h.responseWriter(handler)
if r.Gzipped {
handler = gzipFilter(handler)
}
Expand Down Expand Up @@ -266,12 +267,16 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta.
h.statMap.Add(statQueryRequestDuration, time.Since(start).Nanoseconds())
}(time.Now())

pretty := r.FormValue("pretty") == "true"
nodeID, _ := strconv.ParseUint(r.FormValue("node_id"), 10, 64)
// Retrieve the underlying ResponseWriter or initialize our own.
rw, ok := w.(ResponseWriter)
if !ok {
rw = NewResponseWriter(w, r)
}

nodeID, _ := strconv.ParseUint(r.FormValue("node_id"), 10, 64)
qp := strings.TrimSpace(r.FormValue("q"))
if qp == "" {
h.httpError(w, `missing required parameter "q"`, pretty, http.StatusBadRequest)
h.httpError(rw, `missing required parameter "q"`, http.StatusBadRequest)
return
}

Expand All @@ -291,7 +296,7 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta.
decoder := json.NewDecoder(strings.NewReader(rawParams))
decoder.UseNumber()
if err := decoder.Decode(&params); err != nil {
h.httpError(w, "error parsing query parameters: "+err.Error(), pretty, http.StatusBadRequest)
h.httpError(rw, "error parsing query parameters: "+err.Error(), http.StatusBadRequest)
return
}

Expand All @@ -306,7 +311,7 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta.
}

if err != nil {
h.httpError(w, "error parsing json value: "+err.Error(), pretty, http.StatusBadRequest)
h.httpError(rw, "error parsing json value: "+err.Error(), http.StatusBadRequest)
return
}
}
Expand All @@ -317,7 +322,7 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta.
// Parse query from query string.
query, err := p.ParseQuery()
if err != nil {
h.httpError(w, "error parsing query: "+err.Error(), pretty, http.StatusBadRequest)
h.httpError(rw, "error parsing query: "+err.Error(), http.StatusBadRequest)
return
}

Expand All @@ -327,7 +332,7 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta.
if err, ok := err.(meta.ErrAuthorize); ok {
h.Logger.Printf("Unauthorized request | user: %q | query: %q | database %q\n", err.User, err.Query.String(), err.Database)
}
h.httpError(w, "error authorizing query: "+err.Error(), pretty, http.StatusUnauthorized)
h.httpError(rw, "error authorizing query: "+err.Error(), http.StatusUnauthorized)
return
}
}
Expand Down Expand Up @@ -365,8 +370,7 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta.
}

// Execute query.
w.Header().Add("Connection", "close")
w.Header().Add("Content-Type", "application/json")
rw.Header().Add("Connection", "close")
results := h.QueryExecutor.ExecuteQuery(query, influxql.ExecutionOptions{
Database: db,
ChunkSize: chunkSize,
Expand All @@ -378,7 +382,7 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta.
resp := Response{Results: make([]*influxql.Result, 0)}

// Status header is OK once this point is reached.
h.writeHeader(w, http.StatusOK)
h.writeHeader(rw, http.StatusOK)

// pull all results from the channel
rows := 0
Expand All @@ -395,12 +399,9 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta.

// Write out result immediately if chunked.
if chunked {
n, _ := w.Write(MarshalJSON(Response{
n, _ := rw.WriteResponse(Response{
Results: []*influxql.Result{r},
}, pretty))
if !pretty {
w.Write([]byte("\n"))
}
})
h.statMap.Add(statQueryRequestBytesTransmitted, int64(n))
w.(http.Flusher).Flush()
continue
Expand Down Expand Up @@ -455,7 +456,7 @@ func (h *Handler) serveQuery(w http.ResponseWriter, r *http.Request, user *meta.

// If it's not chunked we buffered everything in memory, so write it out
if !chunked {
n, _ := w.Write(MarshalJSON(resp, pretty))
n, _ := rw.WriteResponse(resp)
h.statMap.Add(statQueryRequestBytesTransmitted, int64(n))
}
}
Expand All @@ -471,23 +472,23 @@ func (h *Handler) serveWrite(w http.ResponseWriter, r *http.Request, user *meta.

database := r.URL.Query().Get("db")
if database == "" {
h.resultError(w, influxql.Result{Err: fmt.Errorf("database is required")}, http.StatusBadRequest)
h.httpError(w, "database is required", http.StatusBadRequest)
return
}

if di := h.MetaClient.Database(database); di == nil {
h.resultError(w, influxql.Result{Err: fmt.Errorf("database not found: %q", database)}, http.StatusNotFound)
h.httpError(w, fmt.Sprintf("database not found: %q", database), http.StatusNotFound)
return
}

if h.Config.AuthEnabled && user == nil {
h.resultError(w, influxql.Result{Err: fmt.Errorf("user is required to write to database %q", database)}, http.StatusUnauthorized)
h.httpError(w, fmt.Sprintf("user is required to write to database %q", database), http.StatusUnauthorized)
return
}

if h.Config.AuthEnabled {
if err := h.WriteAuthorizer.AuthorizeWrite(user.Name, database); err != nil {
h.resultError(w, influxql.Result{Err: fmt.Errorf("%q user is not authorized to write to database %q", user.Name, database)}, http.StatusUnauthorized)
h.httpError(w, fmt.Sprintf("%q user is not authorized to write to database %q", user.Name, database), http.StatusUnauthorized)
return
}
}
Expand All @@ -497,7 +498,7 @@ func (h *Handler) serveWrite(w http.ResponseWriter, r *http.Request, user *meta.
if r.Header.Get("Content-Encoding") == "gzip" {
b, err := gzip.NewReader(r.Body)
if err != nil {
h.resultError(w, influxql.Result{Err: err}, http.StatusBadRequest)
h.httpError(w, err.Error(), http.StatusBadRequest)
return
}
defer b.Close()
Expand All @@ -519,7 +520,7 @@ func (h *Handler) serveWrite(w http.ResponseWriter, r *http.Request, user *meta.
if h.Config.WriteTracing {
h.Logger.Print("Write handler unable to read bytes from request body")
}
h.resultError(w, influxql.Result{Err: err}, http.StatusBadRequest)
h.httpError(w, err.Error(), http.StatusBadRequest)
return
}
h.statMap.Add(statWriteRequestBytesReceived, int64(buf.Len()))
Expand All @@ -535,7 +536,7 @@ func (h *Handler) serveWrite(w http.ResponseWriter, r *http.Request, user *meta.
h.writeHeader(w, http.StatusOK)
return
}
h.resultError(w, influxql.Result{Err: parseError}, http.StatusBadRequest)
h.httpError(w, parseError.Error(), http.StatusBadRequest)
return
}

Expand All @@ -546,26 +547,26 @@ func (h *Handler) serveWrite(w http.ResponseWriter, r *http.Request, user *meta.
var err error
consistency, err = models.ParseConsistencyLevel(level)
if err != nil {
h.resultError(w, influxql.Result{Err: err}, http.StatusBadRequest)
h.httpError(w, err.Error(), http.StatusBadRequest)
return
}
}

// Write points.
if err := h.PointsWriter.WritePoints(database, r.URL.Query().Get("rp"), consistency, points); influxdb.IsClientError(err) {
h.statMap.Add(statPointsWrittenFail, int64(len(points)))
h.resultError(w, influxql.Result{Err: err}, http.StatusBadRequest)
h.httpError(w, err.Error(), http.StatusBadRequest)
return
} else if err != nil {
h.statMap.Add(statPointsWrittenFail, int64(len(points)))
h.resultError(w, influxql.Result{Err: err}, http.StatusInternalServerError)
h.httpError(w, err.Error(), http.StatusInternalServerError)
return
} else if parseError != nil {
// We wrote some of the points
h.statMap.Add(statPointsWrittenOK, int64(len(points)))
// The other points failed to parse which means the client sent invalid line protocol. We return a 400
// response code as well as the lines that failed to parse.
h.resultError(w, influxql.Result{Err: fmt.Errorf("partial write:\n%v", parseError)}, http.StatusBadRequest)
h.httpError(w, fmt.Sprintf("partial write:\n%v", parseError), http.StatusBadRequest)
return
}

Expand Down Expand Up @@ -617,22 +618,6 @@ func convertToEpoch(r *influxql.Result, epoch string) {
}
}

// MarshalJSON will marshal v to JSON. Pretty prints if pretty is true.
func MarshalJSON(v interface{}, pretty bool) []byte {
var b []byte
var err error
if pretty {
b, err = json.MarshalIndent(v, "", " ")
} else {
b, err = json.Marshal(v)
}

if err != nil {
return []byte(err.Error())
}
return b
}

// serveExpvar serves registered expvar information over HTTP.
func serveExpvar(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
Expand All @@ -649,23 +634,20 @@ func serveExpvar(w http.ResponseWriter, r *http.Request) {
}

// h.httpError writes an error to the client in a standard format.
func (h *Handler) httpError(w http.ResponseWriter, error string, pretty bool, code int) {
w.Header().Add("Content-Type", "application/json")
h.writeHeader(w, code)
func (h *Handler) httpError(w http.ResponseWriter, error string, code int) {
response := Response{Err: errors.New(error)}
var b []byte
if pretty {
b, _ = json.MarshalIndent(response, "", " ")
} else {
b, _ = json.Marshal(response)
if rw, ok := w.(ResponseWriter); ok {
h.writeHeader(w, code)
rw.WriteResponse(response)
return
}
w.Write(b)
}

func (h *Handler) resultError(w http.ResponseWriter, result influxql.Result, code int) {
// Default implementation if the response writer hasn't been replaced
// with our special response writer type.
w.Header().Add("Content-Type", "application/json")
h.writeHeader(w, code)
_ = json.NewEncoder(w).Encode(&result)
b, _ := json.Marshal(response)
w.Write(b)
}

// Filters and filter helpers
Expand Down Expand Up @@ -750,22 +732,22 @@ func authenticate(inner func(http.ResponseWriter, *http.Request, *meta.UserInfo)
creds, err := parseCredentials(r)
if err != nil {
h.statMap.Add(statAuthFail, 1)
h.httpError(w, err.Error(), false, http.StatusUnauthorized)
h.httpError(w, err.Error(), http.StatusUnauthorized)
return
}

switch creds.Method {
case UserAuthentication:
if creds.Username == "" {
h.statMap.Add(statAuthFail, 1)
h.httpError(w, "username required", false, http.StatusUnauthorized)
h.httpError(w, "username required", http.StatusUnauthorized)
return
}

user, err = h.MetaClient.Authenticate(creds.Username, creds.Password)
if err != nil {
h.statMap.Add(statAuthFail, 1)
h.httpError(w, "authorization failed", false, http.StatusUnauthorized)
h.httpError(w, "authorization failed", http.StatusUnauthorized)
return
}
case BearerAuthentication:
Expand All @@ -780,39 +762,39 @@ func authenticate(inner func(http.ResponseWriter, *http.Request, *meta.UserInfo)
// Parse and validate the token.
token, err := jwt.Parse(creds.Token, keyLookupFn)
if err != nil {
h.httpError(w, err.Error(), false, http.StatusUnauthorized)
h.httpError(w, err.Error(), http.StatusUnauthorized)
return
} else if !token.Valid {
h.httpError(w, "invalid token", false, http.StatusUnauthorized)
h.httpError(w, "invalid token", http.StatusUnauthorized)
return
}

// Make sure an expiration was set on the token.
if exp, ok := token.Claims["exp"].(float64); !ok || exp <= 0.0 {
h.httpError(w, "token expiration required", false, http.StatusUnauthorized)
h.httpError(w, "token expiration required", http.StatusUnauthorized)
return
}

// Get the username from the token.
username, ok := token.Claims["username"].(string)
if !ok {
h.httpError(w, "username in token must be a string", false, http.StatusUnauthorized)
h.httpError(w, "username in token must be a string", http.StatusUnauthorized)
return
} else if username == "" {
h.httpError(w, "token must contain a username", false, http.StatusUnauthorized)
h.httpError(w, "token must contain a username", http.StatusUnauthorized)
return
}

// Lookup user in the metastore.
if user, err = h.MetaClient.User(username); err != nil {
h.httpError(w, err.Error(), false, http.StatusUnauthorized)
h.httpError(w, err.Error(), http.StatusUnauthorized)
return
} else if user == nil {
h.httpError(w, meta.ErrUserNotFound.Error(), false, http.StatusUnauthorized)
h.httpError(w, meta.ErrUserNotFound.Error(), http.StatusUnauthorized)
return
}
default:
h.httpError(w, "unsupported authentication", false, http.StatusUnauthorized)
h.httpError(w, "unsupported authentication", http.StatusUnauthorized)
}

}
Expand Down Expand Up @@ -920,6 +902,13 @@ func (h *Handler) logging(inner http.Handler, name string) http.Handler {
})
}

func (h *Handler) responseWriter(inner http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w = NewResponseWriter(w, r)
inner.ServeHTTP(w, r)
})
}

func (h *Handler) recovery(inner http.Handler, name string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
Expand Down
Loading

0 comments on commit d2d4cba

Please sign in to comment.