From 5b4d41d3f5ddc77dd6a0d347b07883ff6238e3db Mon Sep 17 00:00:00 2001 From: Barak Amar Date: Sun, 8 Nov 2020 16:39:05 +0200 Subject: [PATCH] DB scanner options by value with option to update AdditionalWhere --- catalog/db_branch_scanner.go | 12 +++++++----- catalog/db_branch_scanner_test.go | 8 ++++---- catalog/db_diff_scanner.go | 12 ++++++------ catalog/db_lineage_scanner.go | 21 +++++++++++++++------ catalog/db_lineage_scanner_test.go | 7 +++++-- catalog/db_scanner.go | 1 + 6 files changed, 38 insertions(+), 23 deletions(-) diff --git a/catalog/db_branch_scanner.go b/catalog/db_branch_scanner.go index aa86010110b..d30aca7f2f7 100644 --- a/catalog/db_branch_scanner.go +++ b/catalog/db_branch_scanner.go @@ -27,16 +27,14 @@ type DBBranchScanner struct { value *DBScannerEntry } -func NewDBBranchScanner(tx db.Tx, branchID int64, commitID CommitID, opts *DBScannerOptions) *DBBranchScanner { +func NewDBBranchScanner(tx db.Tx, branchID int64, commitID CommitID, opts DBScannerOptions) *DBBranchScanner { s := &DBBranchScanner{ tx: tx, branchID: branchID, idx: 0, commitID: commitID, - } - if opts != nil { - s.opts = *opts - s.after = opts.After + opts: opts, + after: opts.After, } if s.opts.BufferSize == 0 { s.opts.BufferSize = DBScannerDefaultBufferSize @@ -48,6 +46,10 @@ func NewDBBranchScanner(tx db.Tx, branchID int64, commitID CommitID, opts *DBSca return s } +func (s *DBBranchScanner) SetAdditionalWhere(part sq.Sqlizer) { + s.opts.AdditionalWhere = part +} + func getRelevantCommitsCondition(tx db.Tx, branchID int64, commitID CommitID) (string, error) { var branchMaxCommitID CommitID var commits []string diff --git a/catalog/db_branch_scanner_test.go b/catalog/db_branch_scanner_test.go index 6f7aa15bc2e..29f104f2e38 100644 --- a/catalog/db_branch_scanner_test.go +++ b/catalog/db_branch_scanner_test.go @@ -33,7 +33,7 @@ func TestDBBranchScanner(t *testing.T) { t.Run("empty", func(t *testing.T) { _, _ = conn.Transact(func(tx db.Tx) (interface{}, error) { - scanner := NewDBBranchScanner(tx, branchID, UncommittedID, &DBScannerOptions{BufferSize: 64}) + scanner := NewDBBranchScanner(tx, branchID, UncommittedID, DBScannerOptions{BufferSize: 64}) firstNext := scanner.Next() if firstNext { t.Fatalf("first entry should be false") @@ -54,7 +54,7 @@ func TestDBBranchScanner(t *testing.T) { t.Run("additional_fields", func(t *testing.T) { _, _ = conn.Transact(func(tx db.Tx) (interface{}, error) { - scanner := NewDBBranchScanner(tx, branchID, UncommittedID, &DBScannerOptions{ + scanner := NewDBBranchScanner(tx, branchID, UncommittedID, DBScannerOptions{ AdditionalFields: []string{"checksum", "physical_address"}, }) testedSomething := false @@ -81,7 +81,7 @@ func TestDBBranchScanner(t *testing.T) { t.Run("additional_where", func(t *testing.T) { _, _ = conn.Transact(func(tx db.Tx) (interface{}, error) { p := fmt.Sprintf("Obj-%04d", numberOfObjects-5) - scanner := NewDBBranchScanner(tx, branchID, UncommittedID, &DBScannerOptions{ + scanner := NewDBBranchScanner(tx, branchID, UncommittedID, DBScannerOptions{ AdditionalWhere: sq.Expr("path=?", p), }) var ent *DBScannerEntry @@ -106,7 +106,7 @@ func TestDBBranchScanner(t *testing.T) { _, _ = conn.Transact(func(tx db.Tx) (interface{}, error) { branchID, err := getBranchID(tx, repository, branchName, LockTypeNone) testutil.MustDo(t, "get branch ID", err) - scanner := NewDBBranchScanner(tx, branchID, UncommittedID, &DBScannerOptions{BufferSize: bufSize}) + scanner := NewDBBranchScanner(tx, branchID, UncommittedID, DBScannerOptions{BufferSize: bufSize}) i := 0 for scanner.Next() { o := scanner.Value() diff --git a/catalog/db_diff_scanner.go b/catalog/db_diff_scanner.go index c956491fb73..ba093502f05 100644 --- a/catalog/db_diff_scanner.go +++ b/catalog/db_diff_scanner.go @@ -86,8 +86,8 @@ func (s *DiffScanner) diffFromParent(tx db.Tx, params doDiffParams) (*DiffScanne After: params.After, AdditionalFields: prepareDiffAdditionalFields(params.AdditionalFields), } - s.leftScanner = NewDBLineageScanner(tx, params.LeftBranchID, CommittedID, &scannerOpts) - s.rightScanner = NewDBLineageScanner(tx, params.RightBranchID, UncommittedID, &scannerOpts) + s.leftScanner = NewDBLineageScanner(tx, params.LeftBranchID, CommittedID, scannerOpts) + s.rightScanner = NewDBLineageScanner(tx, params.RightBranchID, UncommittedID, scannerOpts) s.childLineage, err = getLineage(tx, params.RightBranchID, UncommittedID) if err != nil { return nil, err @@ -107,8 +107,8 @@ func (s *DiffScanner) diffFromChild(tx db.Tx, params doDiffParams) (*DiffScanner if err != nil { return nil, err } - s.leftScanner = NewDBBranchScanner(tx, params.LeftBranchID, CommittedID, &scannerOpts) - s.rightScanner = NewDBLineageScanner(tx, params.RightBranchID, UncommittedID, &scannerOpts) + s.leftScanner = NewDBBranchScanner(tx, params.LeftBranchID, CommittedID, scannerOpts) + s.rightScanner = NewDBLineageScanner(tx, params.RightBranchID, UncommittedID, scannerOpts) return s, nil } @@ -120,8 +120,8 @@ func (s *DiffScanner) diffSameBranch(tx db.Tx, params doDiffParams) (*DiffScanne After: params.After, AdditionalFields: prepareDiffAdditionalFields(params.AdditionalFields), } - s.leftScanner = NewDBLineageScanner(tx, params.LeftBranchID, params.LeftCommitID, &scannerOpts) - s.rightScanner = NewDBLineageScanner(tx, params.RightBranchID, params.RightCommitID, &scannerOpts) + s.leftScanner = NewDBLineageScanner(tx, params.LeftBranchID, params.LeftCommitID, scannerOpts) + s.rightScanner = NewDBLineageScanner(tx, params.RightBranchID, params.RightCommitID, scannerOpts) return s, nil } diff --git a/catalog/db_lineage_scanner.go b/catalog/db_lineage_scanner.go index 0340b60877d..b8ba6fd35f8 100644 --- a/catalog/db_lineage_scanner.go +++ b/catalog/db_lineage_scanner.go @@ -3,6 +3,8 @@ package catalog import ( "fmt" + sq "github.com/Masterminds/squirrel" + "github.com/treeverse/lakefs/db" ) @@ -17,18 +19,25 @@ type DBLineageScanner struct { opts DBScannerOptions } -func NewDBLineageScanner(tx db.Tx, branchID int64, commitID CommitID, opts *DBScannerOptions) *DBLineageScanner { +func NewDBLineageScanner(tx db.Tx, branchID int64, commitID CommitID, opts DBScannerOptions) *DBLineageScanner { s := &DBLineageScanner{ tx: tx, branchID: branchID, commitID: commitID, - } - if opts != nil { - s.opts = *opts + opts: opts, } return s } +func (s *DBLineageScanner) SetAdditionalWhere(part sq.Sqlizer) { + s.opts.AdditionalWhere = part + if s.scanners != nil { + for _, scanner := range s.scanners { + scanner.SetAdditionalWhere(part) + } + } +} + func (s *DBLineageScanner) Next() bool { if s.ended { return false @@ -98,9 +107,9 @@ func (s *DBLineageScanner) ensureBranchScanners() bool { return false } s.scanners = make([]*DBBranchScanner, len(lineage)+1) - s.scanners[0] = NewDBBranchScanner(s.tx, s.branchID, s.commitID, &s.opts) + s.scanners[0] = NewDBBranchScanner(s.tx, s.branchID, s.commitID, s.opts) for i, bl := range lineage { - s.scanners[i+1] = NewDBBranchScanner(s.tx, bl.BranchID, bl.CommitID, &s.opts) + s.scanners[i+1] = NewDBBranchScanner(s.tx, bl.BranchID, bl.CommitID, s.opts) } for _, branchScanner := range s.scanners { if branchScanner.Next() { diff --git a/catalog/db_lineage_scanner_test.go b/catalog/db_lineage_scanner_test.go index 66837183745..ec1444b0cc8 100644 --- a/catalog/db_lineage_scanner_test.go +++ b/catalog/db_lineage_scanner_test.go @@ -31,7 +31,7 @@ func TestDBLineageScanner(t *testing.T) { branchName := "b" + strconv.Itoa(branchNo) branchID, err := getBranchID(tx, repository, branchName, LockTypeNone) testutil.MustDo(t, "get branch id", err) - scanner := NewDBLineageScanner(tx, branchID, UncommittedID, &DBScannerOptions{BufferSize: bufSize}) + scanner := NewDBLineageScanner(tx, branchID, UncommittedID, DBScannerOptions{BufferSize: bufSize}) for i := 0; scanner.Next(); i++ { o := scanner.Value() if o == nil { @@ -73,7 +73,10 @@ func TestDBLineageScanner(t *testing.T) { // test reading committed and uncommitted data const bufSize = 8 - scannerOpts := &DBScannerOptions{BufferSize: bufSize, After: "Obj-0003"} + scannerOpts := DBScannerOptions{ + BufferSize: bufSize, + After: "Obj-0003", + } testCatalogerCreateEntry(t, ctx, c, repository, "b1", "Obj-0004", nil, "sd1") _, _ = conn.Transact(func(tx db.Tx) (interface{}, error) { lineageScannerB1U := NewDBLineageScanner(tx, b1BranchID, UncommittedID, scannerOpts) diff --git a/catalog/db_scanner.go b/catalog/db_scanner.go index c8c9cdc5209..eff805104ba 100644 --- a/catalog/db_scanner.go +++ b/catalog/db_scanner.go @@ -13,6 +13,7 @@ type DBScanner interface { Next() bool Value() *DBScannerEntry Err() error + SetAdditionalWhere(s sq.Sqlizer) } func ScanDBEntriesUntil(s DBScanner, p string) error {