diff --git a/executor/builder.go b/executor/builder.go index 6005d4ed9e29a..5e2804ea1bf5d 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -4564,20 +4564,20 @@ func getPhysicalTableID(t table.Table) int64 { return t.Meta().ID } -func getPhysicalTableEngine(t table.Table) (int64, string) { +func getPhysicalTableEngine(t table.Table) string { if p, ok := t.(table.PhysicalTable); ok { pid := p.GetPhysicalID() pi := t.Meta().GetPartitionInfo() if pi == nil { - return 0, "" + return "" } for _, pd := range pi.Definitions { if pd.ID == pid { - return pd.ID, pd.Engine + return pd.Engine } } } - return 0, "" + return "" } func getFeedbackStatsTableID(ctx sessionctx.Context, t table.Table) int64 { diff --git a/executor/table_reader.go b/executor/table_reader.go index 6a08189580823..4fa2d6079f85b 100644 --- a/executor/table_reader.go +++ b/executor/table_reader.go @@ -19,7 +19,6 @@ import ( "fmt" "sort" - awsathena "github.com/aws/aws-sdk-go/service/athena" "github.com/opentracing/opentracing-go" "github.com/pingcap/tidb/distsql" "github.com/pingcap/tidb/expression" @@ -116,7 +115,7 @@ type TableReaderExecutor struct { extraPIDColumnIndex offsetOptional AWSQueryInfo *plannercore.RestoreData - awsQueryResult *awsathena.ResultSet + awsQueryResult *athena.QueryResult } // offsetOptional may be a positive integer, or invalid. @@ -145,10 +144,10 @@ func (e *TableReaderExecutor) Open(ctx context.Context) error { e.memTracker = memory.NewTracker(e.id, -1) e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker) - pid, storeType := getPhysicalTableEngine(e.table) + storeType := getPhysicalTableEngine(e.table) if storeType == kv.AWSS3Engine { e.storeType = kv.AwsS3 - return e.fetchResultFromAws(pid) + return e.fetchResultFromAws() } var err error @@ -237,11 +236,14 @@ func (e *TableReaderExecutor) Next(ctx context.Context, req *chunk.Chunk) error } req.Reset() if e.storeType == kv.AwsS3 { - result := e.awsQueryResult - e.awsQueryResult = nil - if result == nil { + if e.awsQueryResult == nil { return nil } + result, err := e.awsQueryResult.GetResultSet() + e.awsQueryResult = nil + if err != nil || result == nil { + return err + } stmtCtx := e.ctx.GetSessionVars().StmtCtx for rowIdx, row := range result.Rows { if rowIdx == 0 { @@ -385,19 +387,15 @@ func (e *TableReaderExecutor) buildKVReqSeparately(ctx context.Context, ranges [ return kvReqs, nil } -func (e *TableReaderExecutor) fetchResultFromAws(pid int64) error { +func (e *TableReaderExecutor) fetchResultFromAws() error { query := e.AWSQueryInfo.String() logutil.BgLogger().Info(fmt.Sprintf("[aws query] %v", query)) cli, err := athena.CreateCli("us-west-2") - if err != nil { - return nil - } - result, err := athena.QueryTableData(cli, "test", query) if err != nil { return err } - e.awsQueryResult = result - return nil + e.awsQueryResult, err = athena.StartExecSQL(cli, "test", query) + return err } func (e *TableReaderExecutor) buildKVReq(ctx context.Context, ranges []*ranger.Range) (*kv.Request, error) { diff --git a/interval/athena/athena.go b/interval/athena/athena.go index b5a4408624b43..924fd29806e5a 100644 --- a/interval/athena/athena.go +++ b/interval/athena/athena.go @@ -147,12 +147,12 @@ func DropDatabaseAndAllTables(cli *athena.Athena, db string) error { return DropDatabase(cli, db) } -func QueryTableData(cli *athena.Athena, db, query string) (*athena.ResultSet, error) { - result, err := ExecSQL(cli, db, query) +func ExecSQL(cli *athena.Athena, db, query string) (*athena.ResultSet, error) { + rs, err := StartExecSQL(cli, db, query) if err != nil { return nil, err } - return result, nil + return rs.GetResultSet() } const ( @@ -163,33 +163,33 @@ const ( var defaultDB = "default" -func ExecSQL(cli *athena.Athena, db, query string) (*athena.ResultSet, error) { - if db == "" { - db = defaultDB - } - var s athena.StartQueryExecutionInput - s.SetQueryString(query) - - var q athena.QueryExecutionContext - q.SetDatabase(db) - s.SetQueryExecutionContext(&q) - - var r athena.ResultConfiguration - r.SetOutputLocation("s3://athena-query-result-chenshuang-dev3") - s.SetResultConfiguration(&r) - - result, err := cli.StartQueryExecution(&s) +func StartExecSQL(cli *athena.Athena, db, query string) (*QueryResult, error) { + result, err := startSQLExecution(cli, db, query) if err != nil { return nil, err } + return &QueryResult{ + cli: cli, + query: query, + execOutput: result, + }, nil +} +type QueryResult struct { + cli *athena.Athena + query string + execOutput *athena.StartQueryExecutionOutput +} + +func (rs *QueryResult) GetResultSet() (*athena.ResultSet, error) { var qri athena.GetQueryExecutionInput - qri.SetQueryExecutionId(*result.QueryExecutionId) + qri.SetQueryExecutionId(*rs.execOutput.QueryExecutionId) + var err error var qrop *athena.GetQueryExecutionOutput var state, reason string for { - qrop, err = cli.GetQueryExecutionWithContext(context.Background(), &qri) + qrop, err = rs.cli.GetQueryExecutionWithContext(context.Background(), &qri) if err != nil { return nil, err } @@ -204,18 +204,35 @@ func ExecSQL(cli *athena.Athena, db, query string) (*athena.ResultSet, error) { } if state != QuerySucceeded { - return nil, fmt.Errorf("execute %v, detail: %v, sql: %v", state, reason, query) + return nil, fmt.Errorf("execute %v, detail: %v, sql: %v", state, reason, rs.query) } var ip athena.GetQueryResultsInput - ip.SetQueryExecutionId(*result.QueryExecutionId) - - op, err := cli.GetQueryResults(&ip) + ip.SetQueryExecutionId(*rs.execOutput.QueryExecutionId) + op, err := rs.cli.GetQueryResults(&ip) if err != nil { return nil, err } return op.ResultSet, nil } +func startSQLExecution(cli *athena.Athena, db, query string) (*athena.StartQueryExecutionOutput, error) { + if db == "" { + db = defaultDB + } + var s athena.StartQueryExecutionInput + s.SetQueryString(query) + + var q athena.QueryExecutionContext + q.SetDatabase(db) + s.SetQueryExecutionContext(&q) + + var r athena.ResultConfiguration + r.SetOutputLocation("s3://athena-query-result-chenshuang-dev3") + s.SetResultConfiguration(&r) + + return cli.StartQueryExecution(&s) +} + type DDLEngine struct{} func buildCreateTableSQL(table string, s3BucketName string, tbInfo *model.TableInfo) string {