Skip to content

Commit

Permalink
http: Initialize http.Request with ctx so cancelation interrupts the …
Browse files Browse the repository at this point in the history
…request (#321)

* Add test to verify canceled context.Context aborts http request

* http: Initialize http.Request with ctx so cancelation interrupts the request
  • Loading branch information
zeisss authored May 12, 2021
1 parent 9e42df5 commit ad4c48e
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
6 changes: 3 additions & 3 deletions get_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error {
u.RawQuery = q.Encode()

// Get the URL
req, err := http.NewRequest("GET", u.String(), nil)
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -176,7 +176,7 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error {
// We first make a HEAD request so we can check
// if the server supports range queries. If the server/URL doesn't
// support HEAD requests, we just fall back to GET.
req, err := http.NewRequest("HEAD", src.String(), nil)
req, err := http.NewRequestWithContext(ctx, "HEAD", src.String(), nil)
if err != nil {
return err
}
Expand All @@ -203,7 +203,7 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error {
}
}

req, err = http.NewRequest("GET", src.String(), nil)
req, err = http.NewRequestWithContext(ctx, "GET", src.String(), nil)
if err != nil {
return err
}
Expand Down
49 changes: 49 additions & 0 deletions get_http_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package getter

import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
Expand Down Expand Up @@ -404,6 +405,42 @@ func TestHttpGetter_cleanhttp(t *testing.T) {
}
}

func TestHttpGetter__RespectsContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel immediately

ln := testHttpServer(t)

var u url.URL
u.Scheme = "http"
u.Host = ln.Addr().String()
u.Path = "/file"
dst := tempDir(t)

rt := hookableHTTPRoundTripper{
before: func(req *http.Request) {
err := req.Context().Err()
if !errors.Is(err, context.Canceled) {
t.Fatalf("Expected http.Request with canceled.Context, got: %v", err)
}
},
RoundTripper: http.DefaultTransport,
}

g := new(HttpGetter)
g.client = &Client{
Ctx: ctx,
}
g.Client = &http.Client{
Transport: &rt,
}

err := g.Get(dst, &u)
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled, got: %v", err)
}
}

func testHttpServer(t *testing.T) net.Listener {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
Expand Down Expand Up @@ -531,3 +568,15 @@ machine %s
login foo
password bar
`

type hookableHTTPRoundTripper struct {
before func(req *http.Request)
http.RoundTripper
}

func (m *hookableHTTPRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if m.before != nil {
m.before(req)
}
return m.RoundTripper.RoundTrip(req)
}

0 comments on commit ad4c48e

Please sign in to comment.