diff --git a/main.go b/main.go index 7098f81..ae5b871 100644 --- a/main.go +++ b/main.go @@ -28,7 +28,7 @@ type Options struct { StopAfter string `long:"stop-after" description:"Stop after N requests per endpoint, N can be a number or duration."` Concurrency int `long:"concurrency" description:"Concurrent requests per endpoint" default:"1"` - //Source string `long:"source" description:"Where to get requests from (options: stdin-jsons, ethspam)" default:"stdin-jsons"` // Someday: stdin-tcpdump, file://foo.json, ws://remote-endpoint + //Source string `long:"source" description:"Where requests come from (options: stdin-post, stdin-get)" default:"stdin-jsons"` // Someday: stdin-tcpdump, file://foo.json, ws://remote-endpoint // TODO: Specify additional headers/configs per-endpoint (e.g. auth headers) // TODO: Periodic reporting for long-running tests? @@ -86,7 +86,7 @@ func main() { }(abort) if err := run(ctx, options); err != nil { - exit(2, "error during run: %s", err) + exit(2, "error during run: %s\n", err) } } @@ -136,7 +136,7 @@ func run(ctx context.Context, options Options) error { g, ctx := errgroup.WithContext(ctx) // responses is closed when clients are shut down - responses := make(chan Response, options.Concurrency*2) + responses := make(chan Response, options.Concurrency*4) // Launch clients clients, err := NewClients(options.Args.Endpoints, options.Concurrency, timeout) @@ -181,6 +181,10 @@ func pump(ctx context.Context, r io.Reader, clients Clients, stopAfter int) erro defer clients.Finalize() scanner := bufio.NewScanner(r) + // Some lines are really long, let's allocate a big fat megabyte for lines. + buf := make([]byte, 1024*1024) + scanner.Buffer(buf, cap(buf)) + n := 0 for scanner.Scan() { select { diff --git a/transport.go b/transport.go index 2adfeb7..37f5fe6 100644 --- a/transport.go +++ b/transport.go @@ -7,29 +7,54 @@ import ( "io/ioutil" "net/http" "net/url" + "path" + "strings" "time" ) +// NewTransport creates a transport that supports the given endpoint. The +// endpoint is a URI with a scheme and an optional mode, for example +// "https+get://infura.io/". func NewTransport(endpoint string, timeout time.Duration) (Transport, error) { url, err := url.Parse(endpoint) if err != nil { return nil, err } - switch url.Scheme { + scheme, mode := url.Scheme, "" + if parts := strings.Split(scheme, "+"); len(parts) > 1 { + scheme, mode = parts[0], parts[1] + } + var t Transport + switch scheme { case "http", "https": - return &httpTransport{ + url.Scheme = scheme + t = &httpTransport{ Client: http.Client{Timeout: timeout}, - endpoint: endpoint, - }, nil + endpoint: url.String(), + } case "ws", "wss": // TODO: Implement - return &websocketTransport{}, nil + t = &websocketTransport{} case "noop": - return &noopTransport{}, nil + t = &noopTransport{} + default: + return nil, fmt.Errorf("unsupported transport: %s", scheme) + } + + if mode == "" { + return t, nil } - return nil, fmt.Errorf("unsupported transport: %s", url.Scheme) + if modalTransport, ok := t.(Modal); ok { + return t, modalTransport.Mode(mode) + } + return nil, fmt.Errorf("transport is not modal: %s", scheme) } +// Modal is a type of Transport that has multiple modes for interpreting the +// payloads sent to it. Not all transports support modes. +type Modal interface { + Mode(string) error +} type Transport interface { // TODO: Add context? // TODO: Should this be: Do(Request) (Response, error)? @@ -39,11 +64,43 @@ type Transport interface { type httpTransport struct { http.Client - endpoint string + contentType string + endpoint string + + getHost string + getPath string +} + +func (t *httpTransport) Mode(m string) error { + switch strings.ToLower(m) { + case "post": + t.getHost = "" + case "get": + url, err := url.Parse(t.endpoint) + if err != nil { + return err + } + t.getPath = url.Path + if t.getPath == "" { + t.getPath = "/" + } + url.Path = "" + t.getHost = url.String() + default: + return fmt.Errorf("invalid mode for http transport: %s", m) + } + return nil } func (t *httpTransport) Send(body []byte) ([]byte, error) { - resp, err := t.Client.Post(t.endpoint, "", bytes.NewReader(body)) + var resp *http.Response + var err error + if t.getHost != "" { + url := t.getHost + path.Join(t.getPath, string(body)) + resp, err = t.Client.Get(url) + } else { + resp, err = t.Client.Post(t.endpoint, t.contentType, bytes.NewReader(body)) + } if err != nil { return nil, err }