From 2bc5866ef00db68dbd938265a9002a3b512b997e Mon Sep 17 00:00:00 2001 From: Anton Kozlov Date: Fri, 4 Nov 2022 17:07:25 +0000 Subject: [PATCH] Allow optional HTTP headers; For HTTPS connections, pass username and password as headers --- clickhouse_options.go | 3 ++- conn_http.go | 57 ++++++++++++++++++++++++++----------------- conn_http_query.go | 9 +++++-- 3 files changed, 44 insertions(+), 25 deletions(-) diff --git a/clickhouse_options.go b/clickhouse_options.go index cf412dd6c0..50acf4d5b8 100644 --- a/clickhouse_options.go +++ b/clickhouse_options.go @@ -129,7 +129,8 @@ type Options struct { MaxIdleConns int // default 5 ConnMaxLifetime time.Duration // default 1 hour ConnOpenStrategy ConnOpenStrategy - BlockBufferSize uint8 // default 2 - can be overwritten on query + HttpHeaders map[string]string // set additional headers on HTTP requests + BlockBufferSize uint8 // default 2 - can be overwritten on query scheme string ReadTimeout time.Duration diff --git a/conn_http.go b/conn_http.go index e547c05fb8..d3732f6a26 100644 --- a/conn_http.go +++ b/conn_http.go @@ -149,12 +149,22 @@ func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpCon Host: addr, } - if len(opt.Auth.Username) > 0 { + headers := make(map[string]string) + for k, v := range opt.HttpHeaders { + headers[k] = v + } + + if opt.TLS == nil && len(opt.Auth.Username) > 0 { if len(opt.Auth.Password) > 0 { u.User = url.UserPassword(opt.Auth.Username, opt.Auth.Password) } else { u.User = url.User(opt.Auth.Username) } + } else if opt.TLS != nil && len(opt.Auth.Username) > 0 { + headers["X-Clickhouse-User"] = opt.Auth.Username + if len(opt.Auth.Password) > 0 { + headers["X-Clickhouse-Key"] = opt.Auth.Password + } } query := u.Query() @@ -195,12 +205,13 @@ func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpCon client: &http.Client{ Transport: t, }, - url: u, - buffer: new(chproto.Buffer), - compression: opt.Compression.Method, - blockCompressor: compress.NewWriter(), - compressionPool: compressionPool, - blockBufferSize: opt.BlockBufferSize, + url: u, + buffer: new(chproto.Buffer), + compression: opt.Compression.Method, + blockCompressor: compress.NewWriter(), + compressionPool: compressionPool, + blockBufferSize: opt.BlockBufferSize, + additionalHttpHeaders: headers, } location, err := conn.readTimeZone(ctx) if err != nil { @@ -220,25 +231,27 @@ func dialHttp(ctx context.Context, addr string, num int, opt *Options) (*httpCon client: &http.Client{ Transport: t, }, - url: u, - buffer: new(chproto.Buffer), - compression: opt.Compression.Method, - blockCompressor: compress.NewWriter(), - compressionPool: compressionPool, - location: location, - blockBufferSize: opt.BlockBufferSize, + url: u, + buffer: new(chproto.Buffer), + compression: opt.Compression.Method, + blockCompressor: compress.NewWriter(), + compressionPool: compressionPool, + location: location, + blockBufferSize: opt.BlockBufferSize, + additionalHttpHeaders: headers, }, nil } type httpConnect struct { - url *url.URL - client *http.Client - location *time.Location - buffer *chproto.Buffer - compression CompressionMethod - blockCompressor *compress.Writer - compressionPool Pool[HTTPReaderWriter] - blockBufferSize uint8 + url *url.URL + client *http.Client + location *time.Location + buffer *chproto.Buffer + compression CompressionMethod + blockCompressor *compress.Writer + compressionPool Pool[HTTPReaderWriter] + blockBufferSize uint8 + additionalHttpHeaders map[string]string } func (h *httpConnect) isBad() bool { diff --git a/conn_http_query.go b/conn_http_query.go index a5c301f6ff..fbff92ea2e 100644 --- a/conn_http_query.go +++ b/conn_http_query.go @@ -21,10 +21,11 @@ import ( "bytes" "context" "errors" - chproto "github.com/ClickHouse/ch-go/proto" - "github.com/ClickHouse/clickhouse-go/v2/lib/proto" "io" "strings" + + chproto "github.com/ClickHouse/ch-go/proto" + "github.com/ClickHouse/clickhouse-go/v2/lib/proto" ) // release is ignored, because http used by std with empty release function @@ -43,6 +44,10 @@ func (h *httpConnect) query(ctx context.Context, release func(*connect, error), headers["Accept-Encoding"] = h.compression.String() } + for k, v := range h.additionalHttpHeaders { + headers[k] = v + } + res, err := h.sendQuery(ctx, strings.NewReader(query), &options, headers) if err != nil { return nil, err