Skip to content

Commit

Permalink
Pass context and logger to db non-Transact methods
Browse files Browse the repository at this point in the history
- Context allows cancellation
- Logger allow logging times
  • Loading branch information
arielshaqed committed Aug 25, 2020
1 parent 1d967af commit 05a497f
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 11 deletions.
6 changes: 3 additions & 3 deletions catalog/cataloger_retention.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ func (c *cataloger) QueryEntriesToExpire(ctx context.Context, repositoryName str
// An object may have been deduped onto several branches with different names
// and will have multiple entries; it can only be remove once it expires from
// all of those.
rows, err := c.db.Queryx(dedupedQuery, args...)
rows, err := c.db.WithContext(ctx).Queryx(dedupedQuery, args...)
if err != nil {
return nil, fmt.Errorf("running query: %w", err)
}
Expand Down Expand Up @@ -303,7 +303,7 @@ func (c *cataloger) MarkObjectsForDeletion(ctx context.Context, repositoryName s
// TODO(ariels): This query is difficult to chunk. One way: Perform the inner SELECT
// once into a temporary table, then in a separate transaction chunk the UPDATE by
// dedup_id (this is not yet the real deletion).
result, err := c.db.Exec(`
result, err := c.db.WithContext(ctx).Exec(`
UPDATE catalog_object_dedup SET deleting=true
WHERE repository_id IN (SELECT id FROM catalog_repositories WHERE name = $1) AND
physical_address IN (
Expand Down Expand Up @@ -341,7 +341,7 @@ func (s *StringRows) Read() (string, error) {
// TODO(ariels): Process in chunks. Can store the inner physical_address query in a table for
// the duration.
func (c *cataloger) DeleteOrUnmarkObjectsForDeletion(ctx context.Context, repositoryName string) (StringRows, error) {
rows, err := c.db.Queryx(`
rows, err := c.db.WithContext(ctx).Queryx(`
WITH ids AS (SELECT id repository_id FROM catalog_repositories WHERE name = $1),
update_result AS (
UPDATE catalog_object_dedup SET deleting=all_expired
Expand Down
95 changes: 87 additions & 8 deletions db/database.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"context"
"database/sql"
"io"
"strconv"
Expand All @@ -24,38 +25,116 @@ type Database interface {
Transact(fn TxFunc, opts ...TxOpt) (interface{}, error)
Metadata() (map[string]string, error)
Stats() sql.DBStats
WithContext(ctx context.Context) Database
WithLogger(logger logging.Logger) Database
}

type QueryOptions struct {
logger logging.Logger
ctx context.Context
}

type SqlxDatabase struct {
db *sqlx.DB
db *sqlx.DB
queryOptions *QueryOptions
}

func NewSqlxDatabase(db *sqlx.DB) *SqlxDatabase {
return &SqlxDatabase{db: db}
}

func (d *SqlxDatabase) getLogger() logging.Logger {
if d.queryOptions != nil {
return d.queryOptions.logger
}
return logging.Default()
}

func (d *SqlxDatabase) getContext() context.Context {
if d.queryOptions != nil {
return d.queryOptions.ctx
}
return context.Background()
}

func (d *SqlxDatabase) WithContext(ctx context.Context) Database {
return &SqlxDatabase{
db: d.db,
queryOptions: &QueryOptions{
logger: d.getLogger(),
ctx: ctx,
},
}
}

func (d *SqlxDatabase) WithLogger(logger logging.Logger) Database {
return &SqlxDatabase{
db: d.db,
queryOptions: &QueryOptions{
logger: logger,
ctx: d.getContext(),
},
}
}

func (d *SqlxDatabase) Close() error {
return d.db.Close()
}

func (d *SqlxDatabase) Get(dest interface{}, query string, args ...interface{}) error {
return d.db.Get(dest, query, args...)
// reportFinish computes the duration since starts and logs a "done" report if that duration is
// long enough.
func (d *SqlxDatabase) reportFinish(err *error, fields logging.Fields, start time.Time) {
duration := time.Since(start)
if duration > 100*time.Millisecond {
d.getLogger().WithFields(fields).WithError(*err).WithField("duration", duration).Info("database done")
}
}

func (d *SqlxDatabase) Get(dest interface{}, query string, args ...interface{}) (err error) {
start := time.Now()
defer d.reportFinish(&err, logging.Fields{
"type": "get",
"query": query,
"args": args,
}, start)
return d.db.GetContext(d.getContext(), dest, query, args...)
}

func (d *SqlxDatabase) Queryx(query string, args ...interface{}) (*Rows, error) {
return d.db.Queryx(query, args...)
func (d *SqlxDatabase) Queryx(query string, args ...interface{}) (rows *Rows, err error) {
start := time.Now()
defer d.reportFinish(&err, logging.Fields{
"type": "start query",
"query": query,
"args": args,
}, start)
return d.db.QueryxContext(d.getContext(), query, args...)
}

func (d *SqlxDatabase) Exec(query string, args ...interface{}) (int64, error) {
res, err := d.db.Exec(query, args...)
func (d *SqlxDatabase) Exec(query string, args ...interface{}) (count int64, err error) {
start := time.Now()
defer d.reportFinish(&err, logging.Fields{
"type": "exec",
"query": query,
"args": args,
}, start)
res, err := d.db.ExecContext(d.getContext(), query, args...)
if err != nil {
return 0, err
}
return res.RowsAffected()
}

func (d *SqlxDatabase) Transact(fn TxFunc, opts ...TxOpt) (interface{}, error) {
func (d *SqlxDatabase) getTxOptions() *TxOptions {
options := DefaultTxOptions()
if d.queryOptions != nil {
options.logger = d.queryOptions.logger
options.ctx = d.queryOptions.ctx
}
return options
}

func (d *SqlxDatabase) Transact(fn TxFunc, opts ...TxOpt) (interface{}, error) {
options := d.getTxOptions()
for _, opt := range opts {
opt(options)
}
Expand Down

0 comments on commit 05a497f

Please sign in to comment.