From 2f0f2875e464616f5406d3ccdd7510e4ddc2ef21 Mon Sep 17 00:00:00 2001 From: Dan Kortschak Date: Wed, 7 Dec 2022 11:09:00 +1030 Subject: [PATCH] lib: add HTTPWithContext cel.EnvOption This passes the provided context.Context to all network requests. This is an addition to the HEAD, GET and POST methods and a completion of the intended behaviour for do_request. --- lib/http.go | 43 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/lib/http.go b/lib/http.go index 0711a6a..6865089 100644 --- a/lib/http.go +++ b/lib/http.go @@ -251,6 +251,12 @@ import ( // line=25&page=2" // func HTTP(client *http.Client, limit *rate.Limiter) cel.EnvOption { + return HTTPWithContext(context.Background(), client, limit) +} + +// HTTP returns a cel.EnvOption to configure extended functions for HTTP +// requests that include a context.Context in network requests. +func HTTPWithContext(ctx context.Context, client *http.Client, limit *rate.Limiter) cel.EnvOption { if client == nil { client = http.DefaultClient } @@ -260,12 +266,14 @@ func HTTP(client *http.Client, limit *rate.Limiter) cel.EnvOption { return cel.Lib(httpLib{ client: client, limit: limit, + ctx: ctx, }) } type httpLib struct { client *http.Client limit *rate.Limiter + ctx context.Context } func (httpLib) CompileOptions() []cel.EnvOption { @@ -468,7 +476,7 @@ func (l httpLib) doHead(arg ref.Val) ref.Val { if err != nil { return types.NewErr("%s", err) } - resp, err := l.client.Head(string(url)) + resp, err := l.head(url) if err != nil { return types.NewErr("%s", err) } @@ -479,6 +487,14 @@ func (l httpLib) doHead(arg ref.Val) ref.Val { return types.DefaultTypeAdapter.NativeToValue(rm) } +func (l httpLib) head(url types.String) (*http.Response, error) { + req, err := http.NewRequestWithContext(l.ctx, http.MethodHead, string(url), nil) + if err != nil { + return nil, err + } + return l.client.Do(req) +} + func (l httpLib) doGet(arg ref.Val) ref.Val { url, ok := arg.(types.String) if !ok { @@ -488,7 +504,7 @@ func (l httpLib) doGet(arg ref.Val) ref.Val { if err != nil { return types.NewErr("%s", err) } - resp, err := l.client.Get(string(url)) + resp, err := l.get(url) if err != nil { return types.NewErr("%s", err) } @@ -499,6 +515,14 @@ func (l httpLib) doGet(arg ref.Val) ref.Val { return types.DefaultTypeAdapter.NativeToValue(rm) } +func (l httpLib) get(url types.String) (*http.Response, error) { + req, err := http.NewRequestWithContext(l.ctx, http.MethodGet, string(url), nil) + if err != nil { + return nil, err + } + return l.client.Do(req) +} + func newGetRequest(url ref.Val) ref.Val { return newRequestBody(types.String("GET"), url) } @@ -532,7 +556,7 @@ func (l httpLib) doPost(args ...ref.Val) ref.Val { if err != nil { return types.NewErr("%s", err) } - resp, err := l.client.Post(string(url), string(content), body) + resp, err := l.post(url, content, body) if err != nil { return types.NewErr("%s", err) } @@ -543,6 +567,15 @@ func (l httpLib) doPost(args ...ref.Val) ref.Val { return types.DefaultTypeAdapter.NativeToValue(rm) } +func (l httpLib) post(url, content types.String, body io.Reader) (*http.Response, error) { + req, err := http.NewRequestWithContext(l.ctx, http.MethodPost, string(url), body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", string(content)) + return l.client.Do(req) +} + func newPostRequest(args ...ref.Val) ref.Val { if len(args) != 3 { return types.NewErr("no such overload for post request") @@ -703,8 +736,8 @@ func (l httpLib) doRequest(arg ref.Val) ref.Val { return types.NewErr("%s", err) } // Recover the context lost during serialisation to JSON. - req = req.WithContext(context.Background()) - err = l.limit.Wait(context.TODO()) + req = req.WithContext(l.ctx) + err = l.limit.Wait(l.ctx) if err != nil { return types.NewErr("%s", err) }