Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add SetClientBehavior method to allow users can select proxy client's do behavior #19

Merged
merged 10 commits into from
May 31, 2024
44 changes: 44 additions & 0 deletions proxy_client_behavior.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package reverseproxy

import "time"

type clientBehaviorType int

const (
do clientBehaviorType = iota
doDeadline
doRedirects
doTimeout
)

type clientBehavior struct {
clientBehaviorType clientBehaviorType
param interface{}
}

func ClientDo() clientBehavior {
return clientBehavior{
clientBehaviorType: do,
}
}

func ClientDoRedirects(param int) clientBehavior {
return clientBehavior{
clientBehaviorType: doRedirects,
param: param,
}
}

func ClientDoDeadline(param time.Time) clientBehavior {
return clientBehavior{
clientBehaviorType: doDeadline,
param: param,
}
}

func ClientDoTimeout(param time.Duration) clientBehavior {
return clientBehavior{
clientBehaviorType: doTimeout,
param: param,
}
}
43 changes: 31 additions & 12 deletions reverse_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"reflect"
"strings"
"sync"
"time"
"unsafe"

"github.com/cloudwego/hertz/pkg/app"
Expand All @@ -45,6 +46,8 @@ import (
type ReverseProxy struct {
client *client.Client

clientBehavior clientBehavior

// target is set as a reverse proxy address
Target string

Expand Down Expand Up @@ -105,7 +108,6 @@ var hopHeaders = []string{
// To rewrite Host headers, use ReverseProxy directly with a custom
// director policy.
//
// Note: if no config.ClientOption is passed it will use the default global client.Client instance.
// When passing config.ClientOption it will initialize a local client.Client instance.
// Using ReverseProxy.SetClient if there is need for shared customized client.Client instance.
func NewSingleHostReverseProxy(target string, options ...config.ClientOption) (*ReverseProxy, error) {
Expand All @@ -116,13 +118,11 @@ func NewSingleHostReverseProxy(target string, options ...config.ClientOption) (*
req.Header.SetHostBytes(req.URI().Host())
},
}
if len(options) != 0 {
c, err := client.NewClient(options...)
if err != nil {
return nil, err
}
r.client = c
c, err := client.NewClient(options...)
if err != nil {
return nil, err
}
r.client = c
return r, nil
}

Expand Down Expand Up @@ -275,11 +275,8 @@ func (r *ReverseProxy) ServeHTTP(c context.Context, ctx *app.RequestContext) {
req.Header.Add("X-Forwarded-For", ip)
}
}
fn := client.Do
if r.client != nil {
fn = r.client.Do
}
err := fn(c, req, resp)

err := r.doClientBehavior(c, req, resp)
if err != nil {
hlog.CtxErrorf(c, "HERTZ: Client request error: %#v", err.Error())
r.getErrorHandler()(ctx, err)
Expand Down Expand Up @@ -345,13 +342,35 @@ func (r *ReverseProxy) SetSaveOriginResHeader(b bool) {
r.saveOriginResHeader = b
}

func (r *ReverseProxy) SetClientBehavior(cb clientBehavior) {
r.clientBehavior = cb
}

func (r *ReverseProxy) getErrorHandler() func(c *app.RequestContext, err error) {
if r.errorHandler != nil {
return r.errorHandler
}
return r.defaultErrorHandler
}

func (r *ReverseProxy) doClientBehavior(ctx context.Context, req *protocol.Request, resp *protocol.Response) error {
var err error
switch r.clientBehavior.clientBehaviorType {
case doDeadline:
deadline := r.clientBehavior.param.(time.Time)
err = r.client.DoDeadline(ctx, req, resp, deadline)
case doRedirects:
maxRedirectsCount := r.clientBehavior.param.(int)
err = r.client.DoRedirects(ctx, req, resp, maxRedirectsCount)
case doTimeout:
timeout := r.clientBehavior.param.(time.Duration)
err = r.client.DoTimeout(ctx, req, resp, timeout)
default:
err = r.client.Do(ctx, req, resp)
}
return err
}

// b2s converts byte slice to a string without memory allocation.
// See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ .
//
Expand Down
1 change: 1 addition & 0 deletions reverse_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ func TestReverseProxySaveRespHeader(t *testing.T) {

proxy, err := NewSingleHostReverseProxy("http://127.0.0.1:9997/proxy")
proxy.SetSaveOriginResHeader(true)
proxy.SetClientBehavior(ClientDoRedirects(2))
if err != nil {
t.Errorf("proxy error: %v", err)
}
Expand Down
Loading