diff --git a/__tests__/dataframe.test.ts b/__tests__/dataframe.test.ts index 20474719..8d745373 100644 --- a/__tests__/dataframe.test.ts +++ b/__tests__/dataframe.test.ts @@ -7,7 +7,6 @@ describe("dataframe", () => { pl.Series("foo", [1, 2, 9], pl.Int16), pl.Series("bar", [6, 2, 8], pl.Int16), ]); - test("dtypes", () => { const expected = [pl.Float64, pl.String]; const actual = pl.DataFrame({ a: [1, 2, 3], b: ["a", "b", "c"] }).dtypes; @@ -1318,6 +1317,101 @@ describe("dataframe", () => { ]); expect(actual).toFrameEqual(expected); }); + test("df from JSON with multiple struct", () => { + const rows = [ + { + id: 1, + name: "one", + attributes: { + b: false, + bb: true, + s: "one", + x: 1, + att2: { s: "two", y: 2, att3: { s: "three", y: 3 } }, + }, + }, + ]; + + const actual = pl.DataFrame(rows); + const expected = `shape: (1,) +Series: 'attributes' [struct[5]] +[ + {false,true,"one",1.0,{"two",2.0,{"three",3.0}}} +]`; + expect(actual.select("attributes").toSeries().toString()).toEqual(expected); + }); + test("df from JSON with struct", () => { + const rows = [ + { + id: 1, + name: "one", + attributes: { b: false, bb: true, s: "one", x: 1 }, + }, + { + id: 2, + name: "two", + attributes: { b: false, bb: true, s: "two", x: 2 }, + }, + { + id: 3, + name: "three", + attributes: { b: false, bb: true, s: "three", x: 3 }, + }, + ]; + + let actual = pl.DataFrame(rows); + expect(actual.schema).toStrictEqual({ + id: pl.Float64, + name: pl.String, + attributes: pl.Struct([ + new pl.Field("b", pl.Bool), + new pl.Field("bb", pl.Bool), + new pl.Field("s", pl.String), + new pl.Field("x", pl.Float64), + ]), + }); + + let expected = `shape: (3, 3) +┌─────┬───────┬──────────────────────────┐ +│ id ┆ name ┆ attributes │ +│ --- ┆ --- ┆ --- │ +│ f64 ┆ str ┆ struct[4] │ +╞═════╪═══════╪══════════════════════════╡ +│ 1.0 ┆ one ┆ {false,true,"one",1.0} │ +│ 2.0 ┆ two ┆ {false,true,"two",2.0} │ +│ 3.0 ┆ three ┆ {false,true,"three",3.0} │ +└─────┴───────┴──────────────────────────┘`; + expect(actual.toString()).toStrictEqual(expected); + + const schema = { + id: pl.Int32, + name: pl.String, + attributes: pl.Struct([ + new pl.Field("b", pl.Bool), + new pl.Field("bb", pl.Bool), + new pl.Field("s", pl.String), + new pl.Field("x", pl.Int16), + ]), + }; + actual = pl.DataFrame(rows, { schema: schema }); + expected = `shape: (3, 3) +┌─────┬───────┬────────────────────────┐ +│ id ┆ name ┆ attributes │ +│ --- ┆ --- ┆ --- │ +│ i32 ┆ str ┆ struct[4] │ +╞═════╪═══════╪════════════════════════╡ +│ 1 ┆ one ┆ {false,true,"one",1} │ +│ 2 ┆ two ┆ {false,true,"two",2} │ +│ 3 ┆ three ┆ {false,true,"three",3} │ +└─────┴───────┴────────────────────────┘`; + expect(actual.toString()).toStrictEqual(expected); + expect(actual.getColumn("name").toArray()).toEqual( + rows.map((e) => e["name"]), + ); + expect(actual.getColumn("attributes").toArray()).toMatchObject( + rows.map((e) => e["attributes"]), + ); + }); test("pivot", () => { { const df = pl.DataFrame({ diff --git a/src/dataframe.rs b/src/dataframe.rs index 948b3119..6bd9e4d6 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -442,27 +442,26 @@ pub fn from_rows( infer_schema(pairs, infer_schema_length) } }; - let len = rows.len(); - let it: Vec = (0..len) + let it: Vec = (0..rows.len()) .into_iter() .map(|idx| { let obj = rows .get::(idx as u32) .unwrap_or(None) .unwrap_or_else(|| env.create_object().unwrap()); + Row(schema .iter_fields() .map(|fld| { - let dtype = fld.dtype().clone(); - let key = fld.name(); - if let Ok(unknown) = obj.get(key) { - let av = match unknown { - Some(unknown) => unsafe { - coerce_js_anyvalue(unknown, dtype).unwrap_or(AnyValue::Null) - }, - None => AnyValue::Null, - }; - av + let dtype: &DataType = fld.dtype(); + let key: &PlSmallStr = fld.name(); + if let Ok(unknown) = obj.get::<&polars::prelude::PlSmallStr, JsUnknown>(key) { + match unknown { + Some(unknown) => { + coerce_js_anyvalue(unknown, dtype.clone()).unwrap_or(AnyValue::Null) + } + _ => AnyValue::Null, + } } else { AnyValue::Null } @@ -1620,61 +1619,79 @@ fn obj_to_pairs(rows: &Array, len: usize) -> impl '_ + Iterator(idx as u32).unwrap().unwrap(); - let keys = Object::keys(&obj).unwrap(); keys.iter() .map(|key| { let value = obj.get::<_, napi::JsUnknown>(&key).unwrap_or(None); - let dtype = match value { - Some(val) => { - let ty = val.get_type().unwrap(); - match ty { - ValueType::Boolean => DataType::Boolean, - ValueType::Number => DataType::Float64, - ValueType::String => DataType::String, - ValueType::Object => { - if val.is_array().unwrap() { - let arr: napi::JsObject = unsafe { val.cast() }; - let len = arr.get_array_length().unwrap(); - - if len == 0 { - DataType::List(DataType::Null.into()) - } else { - // dont compare too many items, as it could be expensive - let max_take = std::cmp::min(len as usize, 10); - let mut dtypes: Vec = - Vec::with_capacity(len as usize); - - for idx in 0..max_take { - let item: napi::JsUnknown = - arr.get_element(idx as u32).unwrap(); - let ty = item.get_type().unwrap(); - let dt: Wrap = ty.into(); - dtypes.push(dt.0) - } - let dtype = coerce_data_type(&dtypes); - - DataType::List(dtype.into()) - } - } else if val.is_date().unwrap() { - DataType::Datetime(TimeUnit::Milliseconds, None) - } else { - DataType::Struct(vec![]) - } - } - ValueType::BigInt => DataType::UInt64, - _ => DataType::Null, - } - } - None => DataType::Null, - }; - (key.to_owned(), dtype) + (key.to_owned(), obj_to_type(value)) }) .collect() }) } -unsafe fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult> { +fn obj_to_type(value: Option) -> DataType { + match value { + Some(val) => { + let ty = val.get_type().unwrap(); + match ty { + ValueType::Boolean => DataType::Boolean, + ValueType::Number => DataType::Float64, + ValueType::BigInt => DataType::UInt64, + ValueType::String => DataType::String, + ValueType::Object => { + if val.is_array().unwrap() { + let arr: napi::JsObject = unsafe { val.cast() }; + let len = arr.get_array_length().unwrap(); + if len == 0 { + DataType::List(DataType::Null.into()) + } else { + // dont compare too many items, as it could be expensive + let max_take = std::cmp::min(len as usize, 10); + let mut dtypes: Vec = Vec::with_capacity(len as usize); + + for idx in 0..max_take { + let item: napi::JsUnknown = arr.get_element(idx as u32).unwrap(); + let ty = item.get_type().unwrap(); + let dt: Wrap = ty.into(); + dtypes.push(dt.0) + } + let dtype = coerce_data_type(&dtypes); + + DataType::List(dtype.into()) + } + } else if val.is_date().unwrap() { + DataType::Datetime(TimeUnit::Milliseconds, None) + } else { + let inner_val: napi::JsObject = unsafe { val.cast() }; + let inner_keys = Object::keys(&inner_val).unwrap(); + let mut fldvec: Vec = Vec::with_capacity(inner_keys.len() as usize); + + inner_keys.iter().for_each(|key| { + let inner_val = inner_val.get::<_, napi::JsUnknown>(&key).unwrap(); + let dtype = match inner_val.as_ref().unwrap().get_type().unwrap() { + ValueType::Boolean => DataType::Boolean, + ValueType::Number => DataType::Float64, + ValueType::BigInt => DataType::UInt64, + ValueType::String => DataType::String, + // determine struct type using a recursive func + ValueType::Object => obj_to_type(inner_val), + _ => DataType::Null, + }; + + let fld = Field::new(key.into(), dtype); + fldvec.push(fld); + }); + DataType::Struct(fldvec) + } + } + _ => DataType::Null, + } + } + None => DataType::Null, + } +} + +fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult> { use DataType::*; let vtype = val.get_type().unwrap(); match (vtype, dtype) { @@ -1749,7 +1766,7 @@ unsafe fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult { if val.is_date()? { - let d: napi::JsDate = val.cast(); + let d: napi::JsDate = unsafe { val.cast() }; let d = d.value_of()?; Ok(AnyValue::Datetime(d as i64, TimeUnit::Milliseconds, None)) } else { @@ -1757,9 +1774,51 @@ unsafe fn coerce_js_anyvalue<'a>(val: JsUnknown, dtype: DataType) -> JsResult { - let s = val.to_series(); + let s = unsafe { val.to_series() }; Ok(AnyValue::List(s)) } + (ValueType::Object, DataType::Struct(fields)) => { + let number_of_fields: i8 = fields.len().try_into().map_err(|e| { + napi::Error::from_reason(format!( + "the number of `fields` cannot be larger than i8::MAX {e:?}" + )) + })?; + + let inner_val: napi::JsObject = unsafe { val.cast() }; + let mut val_vec: Vec> = + Vec::with_capacity(number_of_fields as usize); + fields.iter().for_each(|fld| { + let single_val = inner_val + .get::<_, napi::JsUnknown>(&fld.name) + .unwrap() + .unwrap(); + let vv = match &fld.dtype { + DataType::Boolean => { + AnyValue::Boolean(single_val.coerce_to_bool().unwrap().get_value().unwrap()) + } + DataType::String => AnyValue::from_js(single_val).expect("Expecting string"), + DataType::Int16 => AnyValue::Int16( + single_val.coerce_to_number().unwrap().get_int32().unwrap() as i16, + ), + DataType::Int32 => { + AnyValue::Int32(single_val.coerce_to_number().unwrap().get_int32().unwrap()) + } + DataType::Int64 => { + AnyValue::Int64(single_val.coerce_to_number().unwrap().get_int64().unwrap()) + } + DataType::Float64 => AnyValue::Float64( + single_val.coerce_to_number().unwrap().get_double().unwrap(), + ), + DataType::Struct(_) => { + coerce_js_anyvalue(single_val, fld.dtype.clone()).unwrap() + } + _ => AnyValue::Null, + }; + val_vec.push(vv); + }); + + Ok(AnyValue::StructOwned(Box::new((val_vec, fields)))) + } _ => Ok(AnyValue::Null), } }