Skip to content

Commit

Permalink
test(python/adbc_driver_flightsql): test incremental execution
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Feb 29, 2024
1 parent b29e4a9 commit 92220c2
Show file tree
Hide file tree
Showing 6 changed files with 335 additions and 12 deletions.
6 changes: 6 additions & 0 deletions docs/source/driver/status.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ Update Queries
:header-rows: 1

* - Driver
- Incremental Queries
- Partitioned Data
- Parameterized Queries
- Prepared Statements
Expand All @@ -161,29 +162,34 @@ Update Queries
- Y
- Y
- Y
- Y

* - Flight SQL (Java)
- N
- Y
- Y
- Y
- Y
- Y

* - JDBC
- N/A
- N/A
- Y
- Y
- Y
- Y

* - PostgreSQL
- N/A
- N/A
- Y
- Y
- Y
- Y

* - SQLite
- N/A
- N/A
- Y
- Y
Expand Down
103 changes: 102 additions & 1 deletion go/adbc/driver/flightsql/cmd/testserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"os"
"strconv"
"strings"
"sync"

"github.com/apache/arrow/go/v16/arrow"
"github.com/apache/arrow/go/v16/arrow/array"
Expand All @@ -45,6 +46,9 @@ import (

// ExampleServer is the Flight SQL server implemented by this test binary. It
// embeds flightsql.BaseServer and adds the bookkeeping used by the
// PollFlightInfo handlers.
type ExampleServer struct {
	flightsql.BaseServer

	// mu is locked by the PollFlightInfo handlers before they touch
	// pollingStatus.
	mu sync.Mutex
	// pollingStatus is a per-query countdown keyed by the prepared statement
	// handle (as a string). PollFlightInfoPreparedStatement sets it to 5 and
	// each PollFlightInfo call decrements it.
	pollingStatus map[string]int
}

func StatusWithDetail(code codes.Code, message string, details ...proto.Message) error {
Expand Down Expand Up @@ -120,6 +124,103 @@ func (srv *ExampleServer) GetFlightInfoStatement(ctx context.Context, cmd flight
}, nil
}

// PollFlightInfo advances the simulated long-running query identified by desc
// by one step and reports the results available so far.
//
// desc.Cmd must be a serialized wrapperspb.StringValue whose value is a key
// registered in srv.pollingStatus (PollFlightInfoPreparedStatement hands out
// such descriptors). Every call consumes one step of the query's countdown,
// even when the call then fails, so a retry after an error observes the next
// step. The response carries one endpoint per completed step, includes the
// schema only once fewer than 3 steps remain, and has a nil descriptor on the
// final step to signal that there is nothing left to poll.
//
// Errors: InvalidArgument if desc.Cmd cannot be decoded, NotFound if no query
// is registered under the decoded key (including one that already finished),
// and Unavailable for the deliberate mid-query failure of "error_poll_later".
func (srv *ExampleServer) PollFlightInfo(ctx context.Context, desc *flight.FlightDescriptor) (*flight.PollInfo, error) {
	// totalSteps must equal the initial countdown stored by
	// PollFlightInfoPreparedStatement.
	const totalSteps = 5

	srv.mu.Lock()
	defer srv.mu.Unlock()

	var val wrapperspb.StringValue
	if err := proto.Unmarshal(desc.GetCmd(), &val); err != nil {
		return nil, StatusWithDetail(codes.InvalidArgument, "invalid poll descriptor: "+err.Error())
	}

	// Reject descriptors we never issued (or whose query already completed)
	// instead of driving the countdown negative, which would report more than
	// totalSteps endpoints and a progress above 1.0.
	remaining, ok := srv.pollingStatus[val.Value]
	if !ok {
		return nil, StatusWithDetail(codes.NotFound, "no query in progress for descriptor: "+val.Value)
	}
	remaining--
	if remaining > 0 {
		srv.pollingStatus[val.Value] = remaining
	} else {
		// The query is finished; drop the entry so the map does not grow
		// without bound.
		delete(srv.pollingStatus, val.Value)
	}

	// Injected failure: the step has already been consumed above, so the
	// client's retry proceeds with the following step.
	if val.Value == "error_poll_later" && remaining == 3 {
		return nil, StatusWithDetail(codes.Unavailable, "expected error (PollFlightInfo)")
	}

	ticket, err := flightsql.CreateStatementQueryTicket([]byte(val.Value))
	if err != nil {
		return nil, err
	}

	// One endpoint per completed step; all of them use the same ticket bytes.
	endpoints := make([]*flight.FlightEndpoint, totalSteps-remaining)
	for i := range endpoints {
		endpoints[i] = &flight.FlightEndpoint{Ticket: &flight.Ticket{Ticket: ticket}}
	}

	// The schema is withheld during the first steps of the query.
	var schema []byte
	if remaining < 3 {
		schema = flight.SerializeSchema(arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil), srv.Alloc)
	}

	// A nil descriptor tells the client that the query is complete.
	next := desc
	if remaining == 0 {
		next = nil
	}

	return &flight.PollInfo{
		Info: &flight.FlightInfo{
			Schema:           schema,
			Endpoint:         endpoints,
			FlightDescriptor: next,
			TotalRecords:     -1,
			TotalBytes:       -1,
		},
		FlightDescriptor: next,
		Progress:         proto.Float64(1.0 - float64(remaining)/totalSteps),
	}, nil
}

// PollFlightInfoPreparedStatement handles the first poll for a prepared
// statement. The statement handle selects the scenario:
//
//   - "error_poll": fail with InvalidArgument carrying two string details.
//   - "finish_immediately": return a complete result (schema, one endpoint,
//     progress 1.0) and no descriptor for further polling.
//   - anything else: register the handle in srv.pollingStatus with a countdown
//     of 5 and return an empty result (progress 0.0) together with a
//     descriptor wrapping the handle, to be used for subsequent polls.
func (srv *ExampleServer) PollFlightInfoPreparedStatement(ctx context.Context, query flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.PollInfo, error) {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	handle := query.GetPreparedStatementHandle()
	name := string(handle)

	if name == "error_poll" {
		return nil, StatusWithDetail(
			codes.InvalidArgument,
			"expected error (PollFlightInfo)",
			wrapperspb.String("detail1"),
			wrapperspb.String("detail2"),
		)
	}

	if name == "finish_immediately" {
		ticket, err := flightsql.CreateStatementQueryTicket(handle)
		if err != nil {
			return nil, err
		}
		intsSchema := arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
		info := &flight.FlightInfo{
			Schema:           flight.SerializeSchema(intsSchema, srv.Alloc),
			Endpoint:         []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: ticket}}},
			FlightDescriptor: desc,
			TotalRecords:     -1,
			TotalBytes:       -1,
		}
		// No descriptor: the query is already complete.
		return &flight.PollInfo{
			Info:             info,
			FlightDescriptor: nil,
			Progress:         proto.Float64(1.0),
		}, nil
	}

	// Wrap the handle so that follow-up polls can recover it from the
	// descriptor's command bytes.
	pollCmd, err := proto.Marshal(wrapperspb.String(name))
	if err != nil {
		return nil, err
	}

	// Start the countdown for this query.
	srv.pollingStatus[name] = 5

	pending := &flight.FlightInfo{
		Schema:           nil,
		Endpoint:         []*flight.FlightEndpoint{},
		FlightDescriptor: desc,
		TotalRecords:     -1,
		TotalBytes:       -1,
	}
	return &flight.PollInfo{
		Info: pending,
		FlightDescriptor: &flight.FlightDescriptor{
			Type: flight.DescriptorCMD,
			Cmd:  pollCmd,
		},
		Progress: proto.Float64(0.0),
	}, nil
}

func (srv *ExampleServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) {
log.Printf("DoGetPreparedStatement: %v", cmd.GetPreparedStatementHandle())
switch string(cmd.GetPreparedStatementHandle()) {
Expand Down Expand Up @@ -226,7 +327,7 @@ func main() {

flag.Parse()

srv := &ExampleServer{}
srv := &ExampleServer{pollingStatus: make(map[string]int)}
srv.Alloc = memory.DefaultAllocator

server := flight.NewServerWithMiddleware(nil)
Expand Down
204 changes: 204 additions & 0 deletions python/adbc_driver_flightsql/tests/test_incremental.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import re

import google.protobuf.any_pb2 as any_pb2
import google.protobuf.wrappers_pb2 as wrappers_pb2
import pyarrow
import pytest

import adbc_driver_manager
from adbc_driver_manager import StatementOptions

SCHEMA = pyarrow.schema([("ints", "int32")])


def test_incremental_error(test_dbapi) -> None:
    """An error on the initial poll exposes its details; the cursor stays usable."""
    expected_message = re.escape("[FlightSQL] expected error (PollFlightInfo)")

    def detail_text(payload):
        # Each detail is a serialized Any wrapping a StringValue.
        wrapper = any_pb2.Any()
        wrapper.ParseFromString(payload)
        text = wrappers_pb2.StringValue()
        wrapper.Unpack(text)
        return text.value

    with test_dbapi.cursor() as cur:
        cur.adbc_statement.set_options(**{StatementOptions.INCREMENTAL.value: "true"})

        with pytest.raises(
            test_dbapi.ProgrammingError, match=expected_message
        ) as exc_info:
            cur.adbc_execute_partitions("error_poll")

        found = {detail_text(detail) for _, detail in exc_info.value.details}
        assert found == {"detail1", "detail2"}

        # After an error, we can execute a different query.
        partitions, schema = cur.adbc_execute_partitions("finish_immediately")
        assert len(partitions) == 1
        assert schema == SCHEMA
        progress = cur.adbc_statement.get_option_float(StatementOptions.PROGRESS.value)
        assert progress == pytest.approx(1.0)


def test_incremental_error_poll(test_dbapi) -> None:
    """An error raised by a later poll can be retried and the query completes."""
    query = "error_poll_later"
    expected_message = re.escape("[FlightSQL] expected error (PollFlightInfo)")

    with test_dbapi.cursor() as cur:
        cur.adbc_statement.set_options(**{StatementOptions.INCREMENTAL.value: "true"})

        def current_progress():
            return cur.adbc_statement.get_option_float(
                StatementOptions.PROGRESS.value
            )

        # First batch of partitions arrives before the schema is known.
        partitions, schema = cur.adbc_execute_partitions(query)
        assert len(partitions) == 1
        assert schema is None
        assert current_progress() == pytest.approx(0.2)

        # An error can be retried.
        with pytest.raises(
            test_dbapi.OperationalError, match=expected_message
        ) as excinfo:
            partitions, schema = cur.adbc_execute_partitions(query)
        assert excinfo.value.status_code == adbc_driver_manager.AdbcStatusCode.IO

        # Remaining polls: (new partitions, reported progress).
        for expected_count, expected_progress in ((2, 0.6), (1, 0.8), (1, 1.0)):
            partitions, schema = cur.adbc_execute_partitions(query)
            assert len(partitions) == expected_count
            assert schema == SCHEMA
            assert current_progress() == pytest.approx(expected_progress)

        # The query is finished: nothing more to fetch.
        partitions, _ = cur.adbc_execute_partitions(query)
        assert partitions == []


def test_incremental_immediately(test_dbapi) -> None:
    """A query that completes on the first poll yields one batch, then nothing."""
    query = "finish_immediately"

    with test_dbapi.cursor() as cur:
        cur.adbc_statement.set_options(**{StatementOptions.INCREMENTAL.value: "true"})

        first_batch, schema = cur.adbc_execute_partitions(query)
        assert len(first_batch) == 1
        assert schema == SCHEMA
        progress = cur.adbc_statement.get_option_float(StatementOptions.PROGRESS.value)
        assert progress == pytest.approx(1.0)

        drained, schema = cur.adbc_execute_partitions(query)
        assert drained == []

        # reuse for a new query
        first_batch, schema = cur.adbc_execute_partitions(query)
        assert len(first_batch) == 1
        drained, schema = cur.adbc_execute_partitions(query)
        assert drained == []


def test_incremental_query(test_dbapi) -> None:
    """Poll a five-step query to completion, checking schema and progress."""
    query = "SELECT 1"
    disable_error = (
        "[Flight SQL] Cannot disable incremental execution "
        "while a query is in progress"
    )

    with test_dbapi.cursor() as cur:
        cur.adbc_statement.set_options(**{StatementOptions.INCREMENTAL.value: "true"})

        def current_progress():
            return cur.adbc_statement.get_option_float(
                StatementOptions.PROGRESS.value
            )

        partitions, schema = cur.adbc_execute_partitions(query)
        assert len(partitions) == 1
        assert schema is None
        assert current_progress() == pytest.approx(0.2)

        # Incremental mode cannot be switched off mid-query.
        with pytest.raises(
            test_dbapi.ProgrammingError, match=re.escape(disable_error)
        ) as excinfo:
            cur.adbc_statement.set_options(
                **{StatementOptions.INCREMENTAL.value: "false"}
            )
        assert (
            excinfo.value.status_code
            == adbc_driver_manager.AdbcStatusCode.INVALID_STATE
        )

        # Remaining polls: (expected schema, reported progress). Each poll
        # yields exactly one new partition.
        remaining_polls = (
            (None, 0.4),
            (SCHEMA, 0.6),
            (SCHEMA, 0.8),
            (SCHEMA, 1.0),
        )
        for expected_schema, expected_progress in remaining_polls:
            partitions, schema = cur.adbc_execute_partitions(query)
            assert len(partitions) == 1
            if expected_schema is None:
                assert schema is None
            else:
                assert schema == expected_schema
            assert current_progress() == pytest.approx(expected_progress)

        # Query complete: no partitions, schema still reported, progress reset.
        partitions, schema = cur.adbc_execute_partitions(query)
        assert len(partitions) == 0
        assert schema == SCHEMA
        assert current_progress() == 0.0

        # With no query in progress, incremental mode can be disabled again.
        cur.adbc_statement.set_options(**{StatementOptions.INCREMENTAL.value: "false"})
4 changes: 3 additions & 1 deletion python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ class AdbcStatement(_AdbcHandle):
def bind_stream(self, *args, **kwargs) -> Any: ...
def cancel(self) -> None: ...
def close(self) -> None: ...
def execute_partitions(self, *args, **kwargs) -> Any: ...
# Takes no arguments and returns a 3-tuple of (list of bytes, optional
# ArrowSchemaHandle, int).
# NOTE(review): presumably (serialized partition descriptors, result schema if
# known, rows affected) -- confirm against the implementation.
def execute_partitions(
    self,
) -> Tuple[List[bytes], Optional[ArrowSchemaHandle], int]: ...
def execute_query(self, *args, **kwargs) -> Any: ...
def execute_schema(self) -> "ArrowSchemaHandle": ...
def execute_update(self, *args, **kwargs) -> Any: ...
Expand Down
Loading

0 comments on commit 92220c2

Please sign in to comment.