From 1ce7287eff653639a439fa1b57bb40428ded8cf7 Mon Sep 17 00:00:00 2001 From: Martin Hutchinson Date: Mon, 15 Apr 2024 09:58:43 +0100 Subject: [PATCH] Fix CT client upload to be safe against no-op POSTs (#1424) Also updated all literal strings for the defined constants. This would be a great lint check. --- internal/witness/client/http/witness_client.go | 4 ++-- jsonclient/client.go | 4 ++++ trillian/ctfe/handlers_test.go | 18 +++++++++--------- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/internal/witness/client/http/witness_client.go b/internal/witness/client/http/witness_client.go index a04f55ff23..1a836d26d4 100644 --- a/internal/witness/client/http/witness_client.go +++ b/internal/witness/client/http/witness_client.go @@ -43,7 +43,7 @@ func (w Witness) GetLatestSTH(ctx context.Context, logID string) ([]byte, error) if err != nil { return nil, fmt.Errorf("failed to parse URL: %v", err) } - req, err := http.NewRequest("GET", u.String(), nil) + req, err := http.NewRequest(http.MethodGet, u.String(), nil) if err != nil { return nil, fmt.Errorf("failed to create request: %v", err) } @@ -75,7 +75,7 @@ func (w Witness) Update(ctx context.Context, logID string, sth []byte, proof [][ if err != nil { return nil, fmt.Errorf("failed to parse URL: %v", err) } - req, err := http.NewRequest("PUT", u.String(), bytes.NewReader(reqBody)) + req, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewReader(reqBody)) if err != nil { return nil, fmt.Errorf("failed to create request: %v", err) } diff --git a/jsonclient/client.go b/jsonclient/client.go index a95f19ccdc..7e7cb7279e 100644 --- a/jsonclient/client.go +++ b/jsonclient/client.go @@ -248,6 +248,10 @@ func (c *JSONClient) PostAndParse(ctx context.Context, path string, req, rsp int } return nil, nil, err } + if httpRsp.Request.Method != http.MethodPost { + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Redirections#permanent_redirections + return nil, nil, fmt.Errorf("POST request to %q was converted to %s request to %q", fullURI, httpRsp.Request.Method, httpRsp.Request.URL) + } if httpRsp.StatusCode == http.StatusOK { if err = json.Unmarshal(body, &rsp); err != nil { diff --git a/trillian/ctfe/handlers_test.go b/trillian/ctfe/handlers_test.go index abf485ca14..bc786409a0 100644 --- a/trillian/ctfe/handlers_test.go +++ b/trillian/ctfe/handlers_test.go @@ -311,7 +311,7 @@ func TestGetRoots(t *testing.T) { defer info.mockCtrl.Finish() handler := AppHandler{Info: info.li, Handler: getRoots, Name: "GetRoots", Method: http.MethodGet} - req, err := http.NewRequest("GET", "http://example.com/ct/v1/get-roots", nil) + req, err := http.NewRequest(http.MethodGet, "http://example.com/ct/v1/get-roots", nil) if err != nil { t.Fatalf("Failed to create request: %v", err) } @@ -420,7 +420,7 @@ func TestAddChainWhitespace(t *testing.T) { recorder := httptest.NewRecorder() handler := AppHandler{Info: info.li, Handler: addChain, Name: "AddChain", Method: http.MethodPost} - req, err := http.NewRequest("POST", "http://example.com/ct/v1/add-chain", strings.NewReader(test.body)) + req, err := http.NewRequest(http.MethodPost, "http://example.com/ct/v1/add-chain", strings.NewReader(test.body)) if err != nil { t.Fatalf("Failed to create POST request: %v", err) } @@ -775,7 +775,7 @@ func TestGetSTH(t *testing.T) { srReq.ChargeTo = &trillian.ChargeTo{User: []string{test.wantQuotaUser}} } info.client.EXPECT().GetLatestSignedLogRoot(deadlineMatcher(), cmpMatcher{srReq}).Return(test.rpcRsp, test.rpcErr) - req, err := http.NewRequest("GET", "http://example.com/ct/v1/get-sth", nil) + req, err := http.NewRequest(http.MethodGet, "http://example.com/ct/v1/get-sth", nil) if err != nil { t.Errorf("Failed to create request: %v", err) return @@ -1010,7 +1010,7 @@ func TestGetEntries(t *testing.T) { info.setRemoteQuotaUser(test.wantQuotaUser) handler := AppHandler{Info: info.li, Handler: getEntries, Name: "GetEntries", Method: http.MethodGet} path := fmt.Sprintf("/ct/v1/get-entries?%s", test.req) - req, err := http.NewRequest("GET", path, nil) + req, err := http.NewRequest(http.MethodGet, path, nil) if err != nil { t.Errorf("Failed to create request: %v", err) continue @@ -1187,7 +1187,7 @@ func TestGetEntriesRanges(t *testing.T) { } path := fmt.Sprintf("/ct/v1/get-entries?start=%d&end=%d", test.start, test.end) - req, err := http.NewRequest("GET", path, nil) + req, err := http.NewRequest(http.MethodGet, path, nil) if err != nil { t.Fatalf("Failed to create request: %v", err) } @@ -1444,7 +1444,7 @@ func TestGetProofByHash(t *testing.T) { for _, test := range tests { info.setRemoteQuotaUser(test.wantQuotaUser) - req, err := http.NewRequest("GET", fmt.Sprintf("/ct/v1/proof-by-hash?%s", test.req), nil) + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/ct/v1/proof-by-hash?%s", test.req), nil) if err != nil { t.Errorf("Failed to create request: %v", err) continue @@ -1792,7 +1792,7 @@ func TestGetSTHConsistency(t *testing.T) { for _, test := range tests { info.setRemoteQuotaUser(test.wantQuotaUser) - req, err := http.NewRequest("GET", fmt.Sprintf("/ct/v1/get-sth-consistency?%s", test.req), nil) + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/ct/v1/get-sth-consistency?%s", test.req), nil) if err != nil { t.Errorf("Failed to create request: %v", err) continue @@ -2126,7 +2126,7 @@ func TestGetEntryAndProof(t *testing.T) { for _, test := range tests { info.setRemoteQuotaUser(test.wantQuotaUser) - req, err := http.NewRequest("GET", fmt.Sprintf("/ct/v1/get-entry-and-proof?%s", test.req), nil) + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/ct/v1/get-entry-and-proof?%s", test.req), nil) if err != nil { t.Errorf("Failed to create request: %v", err) continue @@ -2250,7 +2250,7 @@ func makeAddChainRequest(t *testing.T, li *logInfo, body io.Reader) *httptest.Re func makeAddChainRequestInternal(t *testing.T, handler AppHandler, path string, body io.Reader) *httptest.ResponseRecorder { t.Helper() - req, err := http.NewRequest("POST", fmt.Sprintf("http://example.com/ct/v1/%s", path), body) + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://example.com/ct/v1/%s", path), body) if err != nil { t.Fatalf("Failed to create POST request: %v", err) }