Skip to content

Commit

Permalink
feat: add Options
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Jan 18, 2025
1 parent 4cbb15a commit 815e11a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 40 deletions.
8 changes: 8 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ type DBStats struct {

type DBOption func(db *DB)

func WithOptions(opts ...DBOption) DBOption {
return func(db *DB) {
for _, opt := range opts {
opt(db)
}
}
}

func WithDiscardUnknownColumns() DBOption {
return func(db *DB) {
db.flags = db.flags.Set(discardUnknownColumns)
Expand Down
85 changes: 45 additions & 40 deletions driver/pgdriver/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func newDefaultConfig() *Config {
host := env("PGHOST", "localhost")
port := env("PGPORT", "5432")

cfg := &Config{
conf := &Config{
Network: "tcp",
Addr: net.JoinHostPort(host, port),
DialTimeout: 5 * time.Second,
Expand All @@ -63,52 +63,57 @@ func newDefaultConfig() *Config {
WriteTimeout: 5 * time.Second,
}

cfg.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
conf.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
netDialer := &net.Dialer{
Timeout: cfg.DialTimeout,
Timeout: conf.DialTimeout,
KeepAlive: 5 * time.Minute,
}
return netDialer.DialContext(ctx, network, addr)
}

return cfg
return conf
}

type Option func(cfg *Config)
type Option func(conf *Config)

// Deprecated. Use Option instead.
type DriverOption = Option
func WithOptions(opts ...Option) Option {
return func(conf *Config) {
for _, opt := range opts {
opt(conf)
}
}
}

func WithNetwork(network string) Option {
if network == "" {
panic("network is empty")
}
return func(cfg *Config) {
cfg.Network = network
return func(conf *Config) {
conf.Network = network
}
}

func WithAddr(addr string) Option {
if addr == "" {
panic("addr is empty")
}
return func(cfg *Config) {
cfg.Addr = addr
return func(conf *Config) {
conf.Addr = addr
}
}

func WithTLSConfig(tlsConfig *tls.Config) Option {
return func(cfg *Config) {
cfg.TLSConfig = tlsConfig
return func(conf *Config) {
conf.TLSConfig = tlsConfig
}
}

func WithInsecure(on bool) Option {
return func(cfg *Config) {
return func(conf *Config) {
if on {
cfg.TLSConfig = nil
conf.TLSConfig = nil
} else {
cfg.TLSConfig = &tls.Config{InsecureSkipVerify: true}
conf.TLSConfig = &tls.Config{InsecureSkipVerify: true}
}
}
}
Expand All @@ -117,81 +122,81 @@ func WithUser(user string) Option {
if user == "" {
panic("user is empty")
}
return func(cfg *Config) {
cfg.User = user
return func(conf *Config) {
conf.User = user
}
}

func WithPassword(password string) Option {
return func(cfg *Config) {
cfg.Password = password
return func(conf *Config) {
conf.Password = password
}
}

func WithDatabase(database string) Option {
if database == "" {
panic("database is empty")
}
return func(cfg *Config) {
cfg.Database = database
return func(conf *Config) {
conf.Database = database
}
}

func WithApplicationName(appName string) Option {
return func(cfg *Config) {
cfg.AppName = appName
return func(conf *Config) {
conf.AppName = appName
}
}

func WithConnParams(params map[string]interface{}) Option {
return func(cfg *Config) {
cfg.ConnParams = params
return func(conf *Config) {
conf.ConnParams = params
}
}

func WithTimeout(timeout time.Duration) Option {
return func(cfg *Config) {
cfg.DialTimeout = timeout
cfg.ReadTimeout = timeout
cfg.WriteTimeout = timeout
return func(conf *Config) {
conf.DialTimeout = timeout
conf.ReadTimeout = timeout
conf.WriteTimeout = timeout
}
}

func WithDialTimeout(dialTimeout time.Duration) Option {
return func(cfg *Config) {
cfg.DialTimeout = dialTimeout
return func(conf *Config) {
conf.DialTimeout = dialTimeout
}
}

func WithReadTimeout(readTimeout time.Duration) Option {
return func(cfg *Config) {
cfg.ReadTimeout = readTimeout
return func(conf *Config) {
conf.ReadTimeout = readTimeout
}
}

func WithWriteTimeout(writeTimeout time.Duration) Option {
return func(cfg *Config) {
cfg.WriteTimeout = writeTimeout
return func(conf *Config) {
conf.WriteTimeout = writeTimeout
}
}

// WithResetSessionFunc configures a function that is called prior to executing
// a query on a connection that has been used before.
// If the func returns driver.ErrBadConn, the connection is discarded.
func WithResetSessionFunc(fn func(context.Context, *Conn) error) Option {
return func(cfg *Config) {
cfg.ResetSessionFunc = fn
return func(conf *Config) {
conf.ResetSessionFunc = fn
}
}

func WithDSN(dsn string) Option {
return func(cfg *Config) {
return func(conf *Config) {
opts, err := parseDSN(dsn)
if err != nil {
panic(err)
}
for _, opt := range opts {
opt(cfg)
opt(conf)
}
}
}
Expand Down

0 comments on commit 815e11a

Please sign in to comment.