From 0746c0981f82b7946eb0205a50ebdf03bb36db27 Mon Sep 17 00:00:00 2001 From: Justin Date: Mon, 30 Sep 2024 13:57:59 -0400 Subject: [PATCH] change to WithReadSessionProject option in anticipation of https://github.com/googleapis/google-cloud-go/pull/10932 --- bigquery/storage_client.go | 20 ++++++---------- bigquery/storage_integration_test.go | 36 ++++++++++++++++++++++++++++ bigquery/storage_iterator.go | 6 ++--- bigquery/table.go | 17 +++++++------ 4 files changed, 56 insertions(+), 23 deletions(-) diff --git a/bigquery/storage_client.go b/bigquery/storage_client.go index afb60bd74371..8c48ed308ca1 100644 --- a/bigquery/storage_client.go +++ b/bigquery/storage_client.go @@ -95,7 +95,7 @@ func (c *readClient) close() error { } // sessionForTable establishes a new session to fetch from a table using the Storage API -func (c *readClient) sessionForTable(ctx context.Context, table *Table, ordered bool, useClientProject bool) (*readSession, error) { +func (c *readClient) sessionForTable(ctx context.Context, table *Table, rsProjectID string, ordered bool) (*readSession, error) { tableID, err := table.Identifier(StorageAPIResourceID) if err != nil { return nil, err @@ -107,17 +107,11 @@ func (c *readClient) sessionForTable(ctx context.Context, table *Table, ordered settings.maxStreamCount = 1 } - // configure where the read session is created - readSessionProjectID := table.ProjectID - if useClientProject { - readSessionProjectID = c.projectID - } - rs := &readSession{ ctx: ctx, - readSessionProjectID: readSessionProjectID, table: table, tableID: tableID, + projectID: rsProjectID, settings: settings, readRowsFunc: c.rawClient.ReadRows, createReadSessionFunc: c.rawClient.CreateReadSession, @@ -129,10 +123,10 @@ func (c *readClient) sessionForTable(ctx context.Context, table *Table, ordered type readSession struct { settings readClientSettings - ctx context.Context - readSessionProjectID string - table *Table - tableID string + ctx context.Context + table *Table + tableID string + projectID string bqSession *storagepb.ReadSession @@ -150,7 +144,7 @@ func (rs *readSession) start() error { } createReadSessionRequest := &storagepb.CreateReadSessionRequest{ - Parent: fmt.Sprintf("projects/%s", rs.readSessionProjectID), + Parent: fmt.Sprintf("projects/%s", rs.projectID), ReadSession: &storagepb.ReadSession{ Table: rs.tableID, DataFormat: storagepb.DataFormat_ARROW, diff --git a/bigquery/storage_integration_test.go b/bigquery/storage_integration_test.go index cbc9b5afd51b..d1fac66add5b 100644 --- a/bigquery/storage_integration_test.go +++ b/bigquery/storage_integration_test.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "strings" "testing" "time" @@ -87,6 +88,41 @@ func TestIntegration_StorageReadEmptyResultSet(t *testing.T) { } } +func TestIntegration_StorageReadClientProject(t *testing.T) { + if client == nil { + t.Skip("Integration tests skipped") + } + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + table := storageOptimizedClient.Dataset("usa_names").Table("usa_1910_current") + table.ProjectID = "bigquery-public-data" + + it := table.Read(ctx) + _, err := countIteratorRows(it) + if err != nil { + t.Fatal(err) + } + if !it.IsAccelerated() { + t.Fatal("expected storage api to be used") + } + + session := it.arrowIterator.(*storageArrowIterator).rs + expectedPrefix := fmt.Sprintf("projects/%s", storageOptimizedClient.projectID) + if !strings.HasPrefix(session.bqSession.Name, expectedPrefix) { + t.Fatalf("expected read session to have prefix %q: but found %s:", expectedPrefix, session.bqSession.Name) + } + + it = table.Read(ctx, WithReadSessionProject("bigquery-public-data")) + _, err = countIteratorRows(it) + if err != nil { + t.Fatal(err) + } + if it.IsAccelerated() { + t.Fatal("expected storage api to not be used") + } +} + func TestIntegration_StorageReadFromSources(t *testing.T) { if client == nil { t.Skip("Integration tests skipped") diff --git a/bigquery/storage_iterator.go b/bigquery/storage_iterator.go index f5da5def3ec5..3704fc31d74b 100644 --- a/bigquery/storage_iterator.go +++ b/bigquery/storage_iterator.go @@ -47,12 +47,12 @@ type storageArrowIterator struct { var _ ArrowIterator = &storageArrowIterator{} -func newStorageRowIteratorFromTable(ctx context.Context, table *Table, ordered, useClientProject bool) (*RowIterator, error) { +func newStorageRowIteratorFromTable(ctx context.Context, table *Table, rsProject string, ordered bool) (*RowIterator, error) { md, err := table.Metadata(ctx) if err != nil { return nil, err } - rs, err := table.c.rc.sessionForTable(ctx, table, ordered, useClientProject) + rs, err := table.c.rc.sessionForTable(ctx, table, rsProject, ordered) if err != nil { return nil, err } @@ -95,7 +95,7 @@ func newStorageRowIteratorFromJob(ctx context.Context, j *Job) (*RowIterator, er return newStorageRowIteratorFromJob(ctx, lastJob) } ordered := query.HasOrderedResults(qcfg.Q) - return newStorageRowIteratorFromTable(ctx, qcfg.Dst, ordered, false) + return newStorageRowIteratorFromTable(ctx, qcfg.Dst, job.projectID, ordered) } func resolveLastChildSelectJob(ctx context.Context, job *Job) (*Job, error) { diff --git a/bigquery/table.go b/bigquery/table.go index 012d58aba790..56058f46daf5 100644 --- a/bigquery/table.go +++ b/bigquery/table.go @@ -968,17 +968,16 @@ func (t *Table) Delete(ctx context.Context) (err error) { } type tableReadOption struct { - useClientProject bool + readSessionProject string } // TableReadOption allows requests to alter the behavior of reading from a table. type TableReadOption func(*tableReadOption) -// WithClientProject allows the read session to be created from the client project -// when reading from the table, instead of the table's project. -func WithClientProject() TableReadOption { +// WithReadSessionProject allows to create the read session with the specified project that has the necessary permissions to do so. +func WithReadSessionProject(project string) TableReadOption { return func(tro *tableReadOption) { - tro.useClientProject = true + tro.readSessionProject = project } } @@ -988,13 +987,17 @@ func (t *Table) Read(ctx context.Context, opts ...TableReadOption) *RowIterator } func (t *Table) read(ctx context.Context, pf pageFetcher, opts ...TableReadOption) *RowIterator { - tro := &tableReadOption{useClientProject: false} + tro := &tableReadOption{} for _, o := range opts { o(tro) } + if tro.readSessionProject == "" { + tro.readSessionProject = t.c.projectID + } + if t.c.isStorageReadAvailable() { - it, err := newStorageRowIteratorFromTable(ctx, t, false, tro.useClientProject) + it, err := newStorageRowIteratorFromTable(ctx, t, tro.readSessionProject, false) if err == nil { return it }