Skip to content

Commit

Permalink
feat(bindings/python): cursor object for PEP 249 (#548)
Browse files Browse the repository at this point in the history
  • Loading branch information
everpcpc authored Dec 27, 2024
1 parent b568749 commit 84871c0
Show file tree
Hide file tree
Showing 13 changed files with 459 additions and 15 deletions.
2 changes: 1 addition & 1 deletion bindings/nodejs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pnpm run build
const { Client } = require("databend-driver");

const client = new Client(
"databend+http://root:root@localhost:8000/?sslmode=disable",
"databend://root:root@localhost:8000/?sslmode=disable",
);
const conn = await client.getConn();

Expand Down
1 change: 1 addition & 0 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ chrono = { workspace = true }
databend-driver = { workspace = true, features = ["rustls", "flight-sql"] }
tokio-stream = { workspace = true }

csv = "1.3"
ctor = "0.2"
once_cell = "1.20"
pyo3 = { version = "0.23.3", features = ["abi3-py37", "chrono"] }
Expand Down
48 changes: 46 additions & 2 deletions bindings/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,40 @@ maturin develop

## Usage

### PEP 249 cursor object

```python
from databend_driver import BlockingDatabendClient

client = BlockingDatabendClient('databend://root:root@localhost:8000/?sslmode=disable')
cursor = client.cursor()

cursor.execute(
"""
CREATE TABLE test (
i64 Int64,
u64 UInt64,
f64 Float64,
s String,
s2 String,
d Date,
t DateTime
)
"""
)
cursor.execute("INSERT INTO test VALUES", (1, 1, 1.0, 'hello', 'world', '2021-01-01', '2021-01-01 00:00:00'))
cursor.execute("SELECT * FROM test")
rows = cursor.fetchall()
for row in rows:
print(row.values())
```

### Blocking

```python
from databend_driver import BlockingDatabendClient

client = BlockingDatabendClient('databend+http://root:root@localhost:8000/?sslmode=disable')
client = BlockingDatabendClient('databend://root:root@localhost:8000/?sslmode=disable')
conn = client.get_conn()
conn.exec(
"""
Expand All @@ -41,7 +69,7 @@ import asyncio
from databend_driver import AsyncDatabendClient

async def main():
client = AsyncDatabendClient('databend+http://root:root@localhost:8000/?sslmode=disable')
client = AsyncDatabendClient('databend://root:root@localhost:8000/?sslmode=disable')
conn = await client.get_conn()
await conn.exec(
"""
Expand Down Expand Up @@ -141,6 +169,7 @@ class AsyncDatabendConnection:
class BlockingDatabendClient:
def __init__(self, dsn: str): ...
def get_conn(self) -> BlockingDatabendConnection: ...
def cursor(self) -> BlockingDatabendCursor: ...
```

### BlockingDatabendConnection
Expand All @@ -156,11 +185,26 @@ class BlockingDatabendConnection:
def load_file(self, sql: str, file: str, format_option: dict, copy_options: dict = None) -> ServerStats: ...
```

### BlockingDatabendCursor

```python
class BlockingDatabendCursor:
def close(self) -> None: ...
def execute(self, operation: str, params: list[string] | tuple[string] = None) -> None | int: ...
def executemany(self, operation: str, params: list[list[string] | tuple[string]]) -> None | int: ...
def fetchone(self) -> Row: ...
def fetchall(self) -> list[Row]: ...
```

### Row

```python
class Row:
def values(self) -> tuple: ...
def __len__(self) -> int: ...
def __iter__(self) -> list: ...
def __dict__(self) -> dict: ...
def __getitem__(self, key: int | str) -> any: ...
```

### RowIterator
Expand Down
1 change: 1 addition & 0 deletions bindings/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ license = { text = "Apache-2.0" }
name = "databend-driver"
readme = "README.md"
requires-python = ">=3.7, < 3.14"
dynamic = ["version"]

[project.urls]
Repository = "https://github.com/databendlabs/bendsql"
Expand Down
186 changes: 186 additions & 0 deletions bindings/python/src/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ use std::collections::BTreeMap;
use std::path::Path;
use std::sync::Arc;

use pyo3::exceptions::{PyAttributeError, PyException};
use pyo3::prelude::*;
use pyo3::types::{PyList, PyTuple};
use tokio::sync::Mutex;
use tokio_stream::StreamExt;

use crate::types::{ConnectionInfo, DriverError, Row, RowIterator, ServerStats, VERSION};
use crate::utils::wait_for_future;
Expand All @@ -41,6 +45,14 @@ impl BlockingDatabendClient {
})?;
Ok(BlockingDatabendConnection(Arc::new(conn)))
}

pub fn cursor(&self, py: Python) -> PyResult<BlockingDatabendCursor> {
let this = self.0.clone();
let conn = wait_for_future(py, async move {
this.get_conn().await.map_err(DriverError::new)
})?;
Ok(BlockingDatabendCursor::new(conn))
}
}

#[pyclass(module = "databend_driver")]
Expand Down Expand Up @@ -142,3 +154,177 @@ impl BlockingDatabendConnection {
Ok(ServerStats::new(ret))
}
}

/// BlockingDatabendCursor is an object that follows PEP 249
/// https://peps.python.org/pep-0249/#cursor-objects
#[pyclass(module = "databend_driver")]
pub struct BlockingDatabendCursor {
conn: Arc<Box<dyn databend_driver::Connection>>,
rows: Option<Arc<Mutex<databend_driver::RowIterator>>>,
// buffer is used to store only the first row after execute
buffer: Vec<Row>,
}

impl BlockingDatabendCursor {
fn new(conn: Box<dyn databend_driver::Connection>) -> Self {
Self {
conn: Arc::new(conn),
rows: None,
buffer: Vec::new(),
}
}
}

impl BlockingDatabendCursor {
fn reset(&mut self) {
self.rows = None;
self.buffer.clear();
}
}

#[pymethods]
impl BlockingDatabendCursor {
pub fn close(&mut self, py: Python) -> PyResult<()> {
self.reset();
wait_for_future(py, async move {
self.conn.close().await.map_err(DriverError::new)
})?;
Ok(())
}

#[pyo3(signature = (operation, parameters=None))]
pub fn execute<'p>(
&'p mut self,
py: Python<'p>,
operation: String,
parameters: Option<Bound<'p, PyAny>>,
) -> PyResult<PyObject> {
if let Some(param) = parameters {
return self.executemany(py, operation, [param].to_vec());
}

self.reset();
let conn = self.conn.clone();
// fetch first row after execute
// then we could finish the query directly if there's no result
let (first, rows) = wait_for_future(py, async move {
let mut rows = conn.query_iter(&operation).await?;
let first = rows.next().await.transpose()?;
Ok::<_, databend_driver::Error>((first, rows))
})
.map_err(DriverError::new)?;
if let Some(first) = first {
self.buffer.push(Row::new(first));
}
self.rows = Some(Arc::new(Mutex::new(rows)));
Ok(py.None())
}

pub fn executemany<'p>(
&'p mut self,
py: Python<'p>,
operation: String,
parameters: Vec<Bound<'p, PyAny>>,
) -> PyResult<PyObject> {
self.reset();
let conn = self.conn.clone();
if let Some(param) = parameters.first() {
if param.downcast::<PyList>().is_ok() || param.downcast::<PyTuple>().is_ok() {
let bytes = format_csv(parameters)?;
let size = bytes.len() as u64;
let reader = Box::new(std::io::Cursor::new(bytes));
let stats = wait_for_future(py, async move {
conn.load_data(&operation, reader, size, None, None)
.await
.map_err(DriverError::new)
})?;
let result = stats.write_rows.into_pyobject(py)?;
return Ok(result.into());
} else {
return Err(PyAttributeError::new_err(
"Invalid parameter type, expected list or tuple",
));
}
}
Ok(py.None())
}

pub fn fetchone(&mut self, py: Python) -> PyResult<Option<Row>> {
if let Some(row) = self.buffer.pop() {
return Ok(Some(row));
}
match self.rows {
Some(ref rows) => {
match wait_for_future(py, async move { rows.lock().await.next().await }) {
Some(row) => Ok(Some(Row::new(row.map_err(DriverError::new)?))),
None => Ok(None),
}
}
None => Ok(None),
}
}

pub fn fetchall(&mut self, py: Python) -> PyResult<Vec<Row>> {
let mut result = self.buffer.drain(..).collect::<Vec<_>>();
match self.rows.take() {
Some(rows) => {
let fetched = wait_for_future(py, async move {
let mut rows = rows.lock().await;
let mut result = Vec::new();
while let Some(row) = rows.next().await {
result.push(row);
}
result
});
for row in fetched {
result.push(Row::new(row.map_err(DriverError::new)?));
}
Ok(result)
}
None => Ok(vec![]),
}
}
}

fn format_csv<'p>(parameters: Vec<Bound<'p, PyAny>>) -> PyResult<Vec<u8>> {
let mut wtr = csv::WriterBuilder::new().from_writer(vec![]);
for row in parameters {
let iter = row.try_iter()?;
let data = iter
.map(|v| match v {
Ok(v) => to_csv_field(v),
Err(e) => Err(e.into()),
})
.collect::<Result<Vec<_>, _>>()?;
wtr.write_record(data)
.map_err(|e| PyException::new_err(e.to_string()))
.unwrap();
}
let bytes = wtr
.into_inner()
.map_err(|e| PyException::new_err(e.to_string()))
.unwrap();
Ok(bytes)
}

fn to_csv_field(v: Bound<PyAny>) -> PyResult<String> {
match v.downcast::<PyAny>() {
Ok(v) => {
if let Ok(v) = v.extract::<String>() {
Ok(v)
} else if let Ok(v) = v.extract::<bool>() {
Ok(v.to_string())
} else if let Ok(v) = v.extract::<i64>() {
Ok(v.to_string())
} else if let Ok(v) = v.extract::<f64>() {
Ok(v.to_string())
} else {
Err(PyAttributeError::new_err(format!(
"Invalid parameter type for: {:?}, expected str, bool, int or float",
v
)))
}
}
Err(e) => Err(e.into()),
}
}
3 changes: 2 additions & 1 deletion bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mod utils;
use pyo3::prelude::*;

use crate::asyncio::{AsyncDatabendClient, AsyncDatabendConnection};
use crate::blocking::{BlockingDatabendClient, BlockingDatabendConnection};
use crate::blocking::{BlockingDatabendClient, BlockingDatabendConnection, BlockingDatabendCursor};
use crate::types::{ConnectionInfo, Field, Row, RowIterator, Schema, ServerStats};

#[pymodule]
Expand All @@ -29,6 +29,7 @@ fn _databend_driver(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<AsyncDatabendConnection>()?;
m.add_class::<BlockingDatabendClient>()?;
m.add_class::<BlockingDatabendConnection>()?;
m.add_class::<BlockingDatabendCursor>()?;
m.add_class::<ConnectionInfo>()?;
m.add_class::<Schema>()?;
m.add_class::<Row>()?;
Expand Down
49 changes: 48 additions & 1 deletion bindings/python/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::sync::Arc;

use chrono::{NaiveDate, NaiveDateTime};
use once_cell::sync::Lazy;
use pyo3::exceptions::{PyException, PyStopAsyncIteration, PyStopIteration};
use pyo3::exceptions::{PyAttributeError, PyException, PyStopAsyncIteration, PyStopIteration};
use pyo3::sync::GILOnceCell;
use pyo3::types::{PyBytes, PyDict, PyList, PyTuple, PyType};
use pyo3::{intern, IntoPyObjectExt};
Expand Down Expand Up @@ -162,6 +162,53 @@ impl Row {
let tuple = PyTuple::new(py, vals)?;
Ok(tuple)
}

pub fn __len__(&self) -> usize {
self.0.len()
}

pub fn __iter__<'p>(&'p self, py: Python<'p>) -> PyResult<Bound<'p, PyList>> {
let vals = self.0.values().iter().map(|v| Value(v.clone()));
let list = PyList::new(py, vals)?;
Ok(list.into_bound())
}

pub fn __dict__<'p>(&'p self, py: Python<'p>) -> PyResult<Bound<'p, PyDict>> {
let dict = PyDict::new(py);
let schema = self.0.schema();
for (field, value) in schema.fields().iter().zip(self.0.values()) {
dict.set_item(&field.name, Value(value.clone()))?;
}
Ok(dict.into_bound())
}

fn get_by_index(&self, idx: usize) -> PyResult<Value> {
Ok(Value(self.0.values()[idx].clone()))
}

fn get_by_field(&self, field: &str) -> PyResult<Value> {
let schema = self.0.schema();
let idx = schema
.fields()
.iter()
.position(|f| f.name == field)
.ok_or_else(|| {
PyException::new_err(format!("field '{}' not found in schema", field))
})?;
Ok(Value(self.0.values()[idx].clone()))
}

pub fn __getitem__<'p>(&'p self, key: Bound<'p, PyAny>) -> PyResult<Value> {
if let Ok(idx) = key.extract::<usize>() {
self.get_by_index(idx)
} else if let Ok(field) = key.extract::<String>() {
self.get_by_field(&field)
} else {
Err(PyAttributeError::new_err(
"key must be an integer or a string",
))
}
}
}

#[pyclass(module = "databend_driver")]
Expand Down
Loading

0 comments on commit 84871c0

Please sign in to comment.