Skip to content

Commit

Permalink
Adding pl.Struct support for pl.Dataframe (#306)
Browse files Browse the repository at this point in the history
Adds the missing pl.Struct support when a pl.DataFrame is built from rows. Closes
#298
  • Loading branch information
Bidek56 authored Feb 3, 2025
1 parent 1b7a42f commit 3eda980
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 61 deletions.
96 changes: 95 additions & 1 deletion __tests__/dataframe.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ describe("dataframe", () => {
pl.Series("foo", [1, 2, 9], pl.Int16),
pl.Series("bar", [6, 2, 8], pl.Int16),
]);

test("dtypes", () => {
const expected = [pl.Float64, pl.String];
const actual = pl.DataFrame({ a: [1, 2, 3], b: ["a", "b", "c"] }).dtypes;
Expand Down Expand Up @@ -1318,6 +1317,101 @@ describe("dataframe", () => {
]);
expect(actual).toFrameEqual(expected);
});
// Builds a DataFrame from a single row object whose `attributes` value holds
// objects nested three levels deep (attributes -> att2 -> att3) and checks the
// printed form of the resulting `attributes` column.
test("df from JSON with multiple struct", () => {
const rows = [
{
id: 1,
name: "one",
attributes: {
b: false,
bb: true,
s: "one",
x: 1,
att2: { s: "two", y: 2, att3: { s: "three", y: 3 } },
},
},
];

// No schema is passed, so the column types are inferred from the row values.
const actual = pl.DataFrame(rows);
// Expected rendering: a 5-field struct whose last field is itself a struct
// containing another struct; the numeric fields are printed as floats.
const expected = `shape: (1,)
Series: 'attributes' [struct[5]]
[
{false,true,"one",1.0,{"two",2.0,{"three",3.0}}}
]`;
expect(actual.select("attributes").toSeries().toString()).toEqual(expected);
});
// Builds a DataFrame from three row objects that each carry a flat
// `attributes` object, first with an inferred schema and then with an
// explicit one, and checks schema, printed output and round-tripped values.
test("df from JSON with struct", () => {
const rows = [
{
id: 1,
name: "one",
attributes: { b: false, bb: true, s: "one", x: 1 },
},
{
id: 2,
name: "two",
attributes: { b: false, bb: true, s: "two", x: 2 },
},
{
id: 3,
name: "three",
attributes: { b: false, bb: true, s: "three", x: 3 },
},
];

// Case 1: no schema supplied. The expectation below pins the inferred types:
// numbers become Float64 and `attributes` becomes a 4-field struct.
let actual = pl.DataFrame(rows);
expect(actual.schema).toStrictEqual({
id: pl.Float64,
name: pl.String,
attributes: pl.Struct([
new pl.Field("b", pl.Bool),
new pl.Field("bb", pl.Bool),
new pl.Field("s", pl.String),
new pl.Field("x", pl.Float64),
]),
});

let expected = `shape: (3, 3)
┌─────┬───────┬──────────────────────────┐
│ id ┆ name ┆ attributes │
│ --- ┆ --- ┆ --- │
│ f64 ┆ str ┆ struct[4] │
╞═════╪═══════╪══════════════════════════╡
│ 1.0 ┆ one ┆ {false,true,"one",1.0} │
│ 2.0 ┆ two ┆ {false,true,"two",2.0} │
│ 3.0 ┆ three ┆ {false,true,"three",3.0} │
└─────┴───────┴──────────────────────────┘`;
expect(actual.toString()).toStrictEqual(expected);

// Case 2: explicit schema with integer types for `id` (Int32) and for the
// struct field `x` (Int16).
const schema = {
id: pl.Int32,
name: pl.String,
attributes: pl.Struct([
new pl.Field("b", pl.Bool),
new pl.Field("bb", pl.Bool),
new pl.Field("s", pl.String),
new pl.Field("x", pl.Int16),
]),
};
actual = pl.DataFrame(rows, { schema: schema });
// With the explicit schema the numeric values are expected to print as integers.
expected = `shape: (3, 3)
┌─────┬───────┬────────────────────────┐
│ id ┆ name ┆ attributes │
│ --- ┆ --- ┆ --- │
│ i32 ┆ str ┆ struct[4] │
╞═════╪═══════╪════════════════════════╡
│ 1 ┆ one ┆ {false,true,"one",1} │
│ 2 ┆ two ┆ {false,true,"two",2} │
│ 3 ┆ three ┆ {false,true,"three",3} │
└─────┴───────┴────────────────────────┘`;
expect(actual.toString()).toStrictEqual(expected);
// Round trip: values read back from the frame must match the input rows.
expect(actual.getColumn("name").toArray()).toEqual(
rows.map((e) => e["name"]),
);
expect(actual.getColumn("attributes").toArray()).toMatchObject(
rows.map((e) => e["attributes"]),
);
});
test("pivot", () => {
{
const df = pl.DataFrame({
Expand Down
179 changes: 119 additions & 60 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,27 +442,26 @@ pub fn from_rows(
infer_schema(pairs, infer_schema_length)
}
};
let len = rows.len();
let it: Vec<Row> = (0..len)
let it: Vec<Row> = (0..rows.len())
.into_iter()
.map(|idx| {
let obj = rows
.get::<Object>(idx as u32)
.unwrap_or(None)
.unwrap_or_else(|| env.create_object().unwrap());

Row(schema
.iter_fields()
.map(|fld| {
let dtype = fld.dtype().clone();
let key = fld.name();
if let Ok(unknown) = obj.get(key) {
let av = match unknown {
Some(unknown) => unsafe {
coerce_js_anyvalue(unknown, dtype).unwrap_or(AnyValue::Null)
},
None => AnyValue::Null,
};
av
let dtype: &DataType = fld.dtype();
let key: &PlSmallStr = fld.name();
if let Ok(unknown) = obj.get::<&polars::prelude::PlSmallStr, JsUnknown>(key) {
match unknown {
Some(unknown) => {
coerce_js_anyvalue(unknown, dtype.clone()).unwrap_or(AnyValue::Null)
}
_ => AnyValue::Null,
}
} else {
AnyValue::Null
}
Expand Down Expand Up @@ -1620,61 +1619,79 @@ fn obj_to_pairs(rows: &Array, len: usize) -> impl '_ + Iterator<Item = Vec<(Stri
let len = std::cmp::min(len, rows.len() as usize);
(0..len).map(move |idx| {
let obj = rows.get::<Object>(idx as u32).unwrap().unwrap();

let keys = Object::keys(&obj).unwrap();
keys.iter()
.map(|key| {
let value = obj.get::<_, napi::JsUnknown>(&key).unwrap_or(None);
let dtype = match value {
Some(val) => {
let ty = val.get_type().unwrap();
match ty {
ValueType::Boolean => DataType::Boolean,
ValueType::Number => DataType::Float64,
ValueType::String => DataType::String,
ValueType::Object => {
if val.is_array().unwrap() {
let arr: napi::JsObject = unsafe { val.cast() };
let len = arr.get_array_length().unwrap();

if len == 0 {
DataType::List(DataType::Null.into())
} else {
// dont compare too many items, as it could be expensive
let max_take = std::cmp::min(len as usize, 10);
let mut dtypes: Vec<DataType> =
Vec::with_capacity(len as usize);

for idx in 0..max_take {
let item: napi::JsUnknown =
arr.get_element(idx as u32).unwrap();
let ty = item.get_type().unwrap();
let dt: Wrap<DataType> = ty.into();
dtypes.push(dt.0)
}
let dtype = coerce_data_type(&dtypes);

DataType::List(dtype.into())
}
} else if val.is_date().unwrap() {
DataType::Datetime(TimeUnit::Milliseconds, None)
} else {
DataType::Struct(vec![])
}
}
ValueType::BigInt => DataType::UInt64,
_ => DataType::Null,
}
}
None => DataType::Null,
};
(key.to_owned(), dtype)
(key.to_owned(), obj_to_type(value))
})
.collect()
})
}

unsafe fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult<AnyValue<'a>> {
fn obj_to_type(value: Option<JsUnknown>) -> DataType {
match value {
Some(val) => {
let ty = val.get_type().unwrap();
match ty {
ValueType::Boolean => DataType::Boolean,
ValueType::Number => DataType::Float64,
ValueType::BigInt => DataType::UInt64,
ValueType::String => DataType::String,
ValueType::Object => {
if val.is_array().unwrap() {
let arr: napi::JsObject = unsafe { val.cast() };
let len = arr.get_array_length().unwrap();
if len == 0 {
DataType::List(DataType::Null.into())
} else {
// dont compare too many items, as it could be expensive
let max_take = std::cmp::min(len as usize, 10);
let mut dtypes: Vec<DataType> = Vec::with_capacity(len as usize);

for idx in 0..max_take {
let item: napi::JsUnknown = arr.get_element(idx as u32).unwrap();
let ty = item.get_type().unwrap();
let dt: Wrap<DataType> = ty.into();
dtypes.push(dt.0)
}
let dtype = coerce_data_type(&dtypes);

DataType::List(dtype.into())
}
} else if val.is_date().unwrap() {
DataType::Datetime(TimeUnit::Milliseconds, None)
} else {
let inner_val: napi::JsObject = unsafe { val.cast() };
let inner_keys = Object::keys(&inner_val).unwrap();
let mut fldvec: Vec<Field> = Vec::with_capacity(inner_keys.len() as usize);

inner_keys.iter().for_each(|key| {
let inner_val = inner_val.get::<_, napi::JsUnknown>(&key).unwrap();
let dtype = match inner_val.as_ref().unwrap().get_type().unwrap() {
ValueType::Boolean => DataType::Boolean,
ValueType::Number => DataType::Float64,
ValueType::BigInt => DataType::UInt64,
ValueType::String => DataType::String,
// determine struct type using a recursive func
ValueType::Object => obj_to_type(inner_val),
_ => DataType::Null,
};

let fld = Field::new(key.into(), dtype);
fldvec.push(fld);
});
DataType::Struct(fldvec)
}
}
_ => DataType::Null,
}
}
None => DataType::Null,
}
}

fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult<AnyValue<'a>> {
use DataType::*;
let vtype = val.get_type().unwrap();
match (vtype, dtype) {
Expand Down Expand Up @@ -1749,17 +1766,59 @@ unsafe fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult<An
}
(ValueType::Object, DataType::Datetime(_, _)) => {
if val.is_date()? {
let d: napi::JsDate = val.cast();
let d: napi::JsDate = unsafe { val.cast() };
let d = d.value_of()?;
Ok(AnyValue::Datetime(d as i64, TimeUnit::Milliseconds, None))
} else {
Ok(AnyValue::Null)
}
}
(ValueType::Object, DataType::List(_)) => {
let s = val.to_series();
let s = unsafe { val.to_series() };
Ok(AnyValue::List(s))
}
(ValueType::Object, DataType::Struct(fields)) => {
let number_of_fields: i8 = fields.len().try_into().map_err(|e| {
napi::Error::from_reason(format!(
"the number of `fields` cannot be larger than i8::MAX {e:?}"
))
})?;

let inner_val: napi::JsObject = unsafe { val.cast() };
let mut val_vec: Vec<polars::prelude::AnyValue<'_>> =
Vec::with_capacity(number_of_fields as usize);
fields.iter().for_each(|fld| {
let single_val = inner_val
.get::<_, napi::JsUnknown>(&fld.name)
.unwrap()
.unwrap();
let vv = match &fld.dtype {
DataType::Boolean => {
AnyValue::Boolean(single_val.coerce_to_bool().unwrap().get_value().unwrap())
}
DataType::String => AnyValue::from_js(single_val).expect("Expecting string"),
DataType::Int16 => AnyValue::Int16(
single_val.coerce_to_number().unwrap().get_int32().unwrap() as i16,
),
DataType::Int32 => {
AnyValue::Int32(single_val.coerce_to_number().unwrap().get_int32().unwrap())
}
DataType::Int64 => {
AnyValue::Int64(single_val.coerce_to_number().unwrap().get_int64().unwrap())
}
DataType::Float64 => AnyValue::Float64(
single_val.coerce_to_number().unwrap().get_double().unwrap(),
),
DataType::Struct(_) => {
coerce_js_anyvalue(single_val, fld.dtype.clone()).unwrap()
}
_ => AnyValue::Null,
};
val_vec.push(vv);
});

Ok(AnyValue::StructOwned(Box::new((val_vec, fields))))
}
_ => Ok(AnyValue::Null),
}
}

0 comments on commit 3eda980

Please sign in to comment.