Skip to content

Commit

Permalink
Better handling for connection errors in the std interface
Browse files Browse the repository at this point in the history
- check for EOF/EPIPE errors in the major functions of
  clickhouse_std.go and explicitly return driver.ErrBadConn
  when we catch them

- add debug logging of most error conditions to the std interface,
  following the existing pattern for error handling

Signed-off-by: Nathan J. Mehl <n@oden.io>
  • Loading branch information
n-oden committed Jan 25, 2023
1 parent 0c79b0f commit 1ede73c
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 17 deletions.
115 changes: 98 additions & 17 deletions clickhouse_std.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,23 @@ import (
"database/sql/driver"
"errors"
"fmt"
"github.com/ClickHouse/clickhouse-go/v2/lib/column"
ldriver "github.com/ClickHouse/clickhouse-go/v2/lib/driver"
"io"
"log"
"os"
"reflect"
"strings"
"sync/atomic"

"github.com/ClickHouse/clickhouse-go/v2/lib/column"
ldriver "github.com/ClickHouse/clickhouse-go/v2/lib/driver"
"syscall"
)

var globalConnID int64

type stdConnOpener struct {
err error
opt *Options
err error
opt *Options
debugf func(format string, v ...interface{})
}

func (o *stdConnOpener) Driver() driver.Driver {
Expand All @@ -45,6 +48,7 @@ func (o *stdConnOpener) Driver() driver.Driver {

func (o *stdConnOpener) Connect(ctx context.Context) (_ driver.Conn, err error) {
if o.err != nil {
o.debugf("[connect] opener error: %v\n", o.err)
return nil, o.err
}
var (
Expand Down Expand Up @@ -77,9 +81,16 @@ func (o *stdConnOpener) Connect(ctx context.Context) (_ driver.Conn, err error)
num = (int(connID) + i) % len(o.opt.Addr)
}
if conn, err = dialFunc(ctx, o.opt.Addr[num], connID, o.opt); err == nil {
var debugf = func(format string, v ...interface{}) {}
if o.opt.Debug {
debugf = log.New(os.Stdout, fmt.Sprintf("[clickhouse-std][conn=%d][%s] ", num, o.opt.Addr[num]), 0).Printf
}
return &stdDriver{
conn: conn,
conn: conn,
debugf: debugf,
}, nil
} else {
o.debugf("[connect] error connecting to %s on connection %d: %v\n", o.opt.Addr[num], connID, err)
}
}
return nil, err
Expand All @@ -89,18 +100,34 @@ func init() {
sql.Register("clickhouse", &stdDriver{})
}

// isFatalError returns true if the error class indicates that the
// db connection is no longer usable and should be marked bad
func isFatalError(err error) bool {
if errors.Is(err, io.EOF) || errors.Is(err, syscall.EPIPE) {
return true
}
return false
}

func Connector(opt *Options) driver.Connector {
if opt == nil {
opt = &Options{}
}

o := opt.setDefaults()

var debugf = func(format string, v ...interface{}) {}
if o.Debug {
debugf = log.New(os.Stdout, fmt.Sprintf("[clickhouse-std][opener] "), 0).Printf
}
return &stdConnOpener{
opt: o,
opt: o,
debugf: debugf,
}
}

func OpenDB(opt *Options) *sql.DB {
var debugf = func(format string, v ...interface{}) {}
if opt == nil {
opt = &Options{}
}
Expand All @@ -114,14 +141,19 @@ func OpenDB(opt *Options) *sql.DB {
if opt.ConnMaxLifetime > 0 {
settings = append(settings, "SetConnMaxLifetime")
}
if opt.Debug {
debugf = log.New(os.Stdout, fmt.Sprintf("[clickhouse-std][opener] "), 0).Printf
}
if len(settings) != 0 {
return sql.OpenDB(&stdConnOpener{
err: fmt.Errorf("cannot connect. invalid settings. use %s (see https://pkg.go.dev/database/sql)", strings.Join(settings, ",")),
err: fmt.Errorf("cannot connect. invalid settings. use %s (see https://pkg.go.dev/database/sql)", strings.Join(settings, ",")),
debugf: debugf,
})
}
o := opt.setDefaults()
return sql.OpenDB(&stdConnOpener{
opt: o,
opt: o,
debugf: debugf,
})
}

Expand All @@ -138,11 +170,13 @@ type stdConnect interface {
type stdDriver struct {
conn stdConnect
commit func() error
debugf func(format string, v ...interface{})
}

func (std *stdDriver) Open(dsn string) (_ driver.Conn, err error) {
var opt Options
if err := opt.fromDSN(dsn); err != nil {
std.debugf("Open dsn error: %v\n", err)
return nil, err
}
o := opt.setDefaults()
Expand All @@ -151,6 +185,7 @@ func (std *stdDriver) Open(dsn string) (_ driver.Conn, err error) {

func (std *stdDriver) ResetSession(ctx context.Context) error {
if std.conn.isBad() {
std.debugf("Resetting session because connection is bad")
return driver.ErrBadConn
}
return nil
Expand All @@ -167,7 +202,16 @@ func (std *stdDriver) Commit() error {
defer func() {
std.commit = nil
}()
return std.commit()

if err := std.commit(); err != nil {
if isFatalError(err) {
std.debugf("Commit got EOF error: resetting connection")
return driver.ErrBadConn
}
std.debugf("Commit error: %v\n", err)
return err
}
return nil
}

func (std *stdDriver) Rollback() error {
Expand All @@ -186,18 +230,29 @@ func (std *stdDriver) ExecContext(ctx context.Context, query string, args []driv
return driver.RowsAffected(0), std.conn.asyncInsert(ctx, query, options.async.wait)
}
if err := std.conn.exec(ctx, query, rebind(args)...); err != nil {
if isFatalError(err) {
std.debugf("ExecContext got a fatal error, resetting connection: %v\n", err)
return nil, driver.ErrBadConn
}
std.debugf("ExecContext error: %v\n", err)
return nil, err
}
return driver.RowsAffected(0), nil
}

func (std *stdDriver) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
r, err := std.conn.query(ctx, func(*connect, error) {}, query, rebind(args)...)
if isFatalError(err) {
std.debugf("QueryContext got a fatal error, resetting connection: %v\n", err)
return nil, driver.ErrBadConn
}
if err != nil {
std.debugf("QueryContext error: %v\n", err)
return nil, err
}
return &stdRows{
rows: r,
rows: r,
debugf: std.debugf,
}, nil
}

Expand All @@ -208,18 +263,34 @@ func (std *stdDriver) Prepare(query string) (driver.Stmt, error) {
func (std *stdDriver) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
batch, err := std.conn.prepareBatch(ctx, query, func(*connect, error) {})
if err != nil {
if isFatalError(err) {
std.debugf("PrepareContext got a fatal error, resetting connection: %v\n", err)
}
std.debugf("PrepareContext error: %v\n", err)
return nil, err
}
std.commit = batch.Send
return &stdBatch{
batch: batch,
batch: batch,
debugf: std.debugf,
}, nil
}

func (std *stdDriver) Close() error { return std.conn.close() }
func (std *stdDriver) Close() error {
err := std.conn.close()
if err != nil {
if isFatalError(err) {
std.debugf("Close got a fatal error, resetting connection: %v\n", err)
return driver.ErrBadConn
}
std.debugf("Close error: %v\n", err)
}
return err
}

type stdBatch struct {
batch ldriver.Batch
batch ldriver.Batch
debugf func(format string, v ...interface{})
}

func (s *stdBatch) NumInput() int { return -1 }
Expand All @@ -229,6 +300,7 @@ func (s *stdBatch) Exec(args []driver.Value) (driver.Result, error) {
values = append(values, v)
}
if err := s.batch.Append(values...); err != nil {
s.debugf("[batch][exec] append error: %v", err)
return nil, err
}
return driver.RowsAffected(0), nil
Expand All @@ -249,7 +321,8 @@ func (s *stdBatch) Query(args []driver.Value) (driver.Rows, error) {
func (s *stdBatch) Close() error { return nil }

type stdRows struct {
rows *rows
rows *rows
debugf func(format string, v ...interface{})
}

func (r *stdRows) Columns() []string {
Expand Down Expand Up @@ -284,9 +357,11 @@ func (r *stdRows) ColumnTypePrecisionScale(idx int) (precision, scale int64, ok

func (r *stdRows) Next(dest []driver.Value) error {
if len(r.rows.block.Columns) != len(dest) {
err := fmt.Errorf("expected %d destination arguments in Next, not %d", len(r.rows.block.Columns), len(dest))
r.debugf("Next length error: %v\n", err)
return &OpError{
Op: "Next",
Err: fmt.Errorf("expected %d destination arguments in Next, not %d", len(r.rows.block.Columns), len(dest)),
Err: err,
}
}
if r.rows.Next() {
Expand All @@ -296,6 +371,7 @@ func (r *stdRows) Next(dest []driver.Value) error {
case driver.Valuer:
v, err := value.Value()
if err != nil {
r.debugf("Next row error: %v\n", err)
return err
}
dest[i] = v
Expand All @@ -306,6 +382,7 @@ func (r *stdRows) Next(dest []driver.Value) error {
return nil
}
if err := r.rows.Err(); err != nil {
r.debugf("Next rows error: %v\n", err)
return err
}
return io.EOF
Expand All @@ -327,5 +404,9 @@ func (r *stdRows) NextResultSet() error {
}

func (r *stdRows) Close() error {
return r.rows.Close()
err := r.rows.Close()
if err != nil {
r.debugf("Rows Close error: %v\n", err)
}
return err
}
8 changes: 8 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"fmt"
"github.com/ClickHouse/clickhouse-go/v2/resources"
"github.com/pkg/errors"
"io"
"log"
"net"
"os"
Expand Down Expand Up @@ -230,6 +231,11 @@ func (c *connect) sendData(block *proto.Block, name string) error {
if errors.Is(err, syscall.EPIPE) {
c.debugf("[send data] pipe is broken, closing connection")
c.closed = true
} else if errors.Is(err, io.EOF) {
c.debugf("[send data] unexpected EOF, closing connection")
c.closed = true
} else {
c.debugf("[send data] unexpected error: %v", err)
}
return err
}
Expand All @@ -241,6 +247,7 @@ func (c *connect) sendData(block *proto.Block, name string) error {

func (c *connect) readData(packet byte, compressible bool) (*proto.Block, error) {
if _, err := c.reader.Str(); err != nil {
c.debugf("[read data] str error: %v", err)
return nil, err
}
if compressible && c.compression != CompressionNone {
Expand All @@ -249,6 +256,7 @@ func (c *connect) readData(packet byte, compressible bool) (*proto.Block, error)
}
block := proto.Block{Timezone: c.server.Timezone}
if err := block.Decode(c.reader, c.revision); err != nil {
c.debugf("[read data] decode error: %v", err)
return nil, err
}
block.Packet = packet
Expand Down
3 changes: 3 additions & 0 deletions conn_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func (c *connect) query(ctx context.Context, release func(*connect, error), quer
)

if err != nil {
c.debugf("[bindQuery] error: %v", err)
release(c, err)
return nil, err
}
Expand All @@ -54,6 +55,7 @@ func (c *connect) query(ctx context.Context, release func(*connect, error), quer
init, err := c.firstBlock(ctx, onProcess)

if err != nil {
c.debugf("[query] first block error: %v", err)
release(c, err)
return nil, err
}
Expand All @@ -73,6 +75,7 @@ func (c *connect) query(ctx context.Context, release func(*connect, error), quer
}
err := c.process(ctx, onProcess)
if err != nil {
c.debugf("[query] process error: %v", err)
errors <- err
}
close(stream)
Expand Down

0 comments on commit 1ede73c

Please sign in to comment.