diff --git a/__tests__/io.test.ts b/__tests__/io.test.ts index 4e7ba909..29fd8b4b 100644 --- a/__tests__/io.test.ts +++ b/__tests__/io.test.ts @@ -274,6 +274,18 @@ describe("parquet", () => { const df = pl.scanParquet(parquetpath, { nRows: 4 }).collectSync(); expect(df.shape).toEqual({ height: 4, width: 4 }); }); + + test("writeParquet with decimals", async () => { + const df = pl.DataFrame([ + pl.Series("decimal", [1n, 2n, 3n], pl.Decimal()), + pl.Series("u32", [1, 2, 3], pl.UInt32), + pl.Series("str", ["a", "b", "c"]), + ]); + + const buf = df.writeParquet(); + const newDF = pl.readParquet(buf); + expect(newDF).toFrameEqual(df); + }); }); describe("ipc", () => { diff --git a/__tests__/series.test.ts b/__tests__/series.test.ts index cf739390..28f8fcea 100644 --- a/__tests__/series.test.ts +++ b/__tests__/series.test.ts @@ -1,6 +1,5 @@ /* eslint-disable newline-per-chained-call */ import pl from "@polars"; -import { InvalidOperationError } from "../polars/error"; import Chance from "chance"; describe("from lists", () => { @@ -186,6 +185,32 @@ describe("typedArrays", () => { const actual = pl.Series(float64Array).toTypedArray(); expect(JSON.stringify(actual)).toEqual(JSON.stringify(float64Array)); }); + + test("decimal", () => { + const expected = [1n, 2n, 3n]; + const expectedDtype = pl.Decimal(10, 2); + const actual = pl.Series("", expected, expectedDtype); + expect(actual.dtype).toEqual(expectedDtype); + try { + actual.toArray(); + } catch (e: any) { + expect(e.message).toContain( + "Decimal is not a supported type in javascript, please convert to string or number before collecting to js", + ); + } + }); + + test("fixed list", () => { + const expectedDtype = pl.FixedSizeList(pl.Float32, 3); + const expected = [ + [1, 2, 3], + [4, 5, 6], + ]; + const actual = pl.Series("", expected, expectedDtype); + expect(actual.dtype).toEqual(expectedDtype); + const actualValues = actual.toArray(); + expect(actualValues).toEqual(expected); + }); }); describe("series", () => { const chance = new Chance(); diff --git a/polars/datatypes/datatype.ts b/polars/datatypes/datatype.ts index 7fd2bb45..d9d882bc 100644 --- a/polars/datatypes/datatype.ts +++ b/polars/datatypes/datatype.ts @@ -81,6 +81,11 @@ export abstract class DataType { return new Categorical(); } + /** Decimal type */ + public static Decimal(precision?: number, scale?: number): DataType { + return new Decimal(precision, scale); + } + /** * Calendar date and time type * @param timeUnit any of 'ms' | 'ns' | 'us' @@ -186,6 +191,39 @@ export class String extends DataType {} export class Categorical extends DataType {} +export class Decimal extends DataType { + private precision: number | null; + private scale: number | null; + constructor(precision?: number, scale?: number) { + super(); + this.precision = precision ?? null; + this.scale = scale ?? null; + } + override get inner() { + return [this.precision, this.scale]; + } + override equals(other: DataType): boolean { + if (other.variant === this.variant) { + return ( + this.precision === (other as Decimal).precision && + this.scale === (other as Decimal).scale + ); + } + return false; + } + + override toJSON() { + return { + [this.identity]: { + [this.variant]: { + precision: this.precision, + scale: this.scale, + }, + }, + }; + } +} + /** * Datetime type */ @@ -234,10 +272,6 @@ export class FixedSizeList extends DataType { super(); } - override get variant() { - return "FixedSizeList"; - } - override get inner(): [DataType, number] { return [this.__inner, this.listSize]; } @@ -349,6 +383,7 @@ export namespace DataType { export type Object = import(".").Object_; export type Null = import(".").Null; export type Struct = import(".").Struct; + export type Decimal = import(".").Decimal; /** * deserializes a datatype from the serde output of rust polars `DataType` * @param dtype dtype object diff --git a/polars/index.ts b/polars/index.ts index 85666e96..5fc24470 100644 --- a/polars/index.ts +++ b/polars/index.ts @@ -1,7 +1,8 @@ +import { Decimal } from "./datatypes/datatype"; import * as series from "./series"; import * as df from "./dataframe"; import { DataType, Field as _field } from "./datatypes"; -export { DataType } from "./datatypes"; +export * from "./datatypes"; import * as func from "./functions"; import * as io from "./io"; import * as cfg from "./cfg"; @@ -111,6 +112,7 @@ export namespace pl { export type Object = import("./datatypes").Object_; export type Null = import("./datatypes").Null; export type Struct = import("./datatypes").Struct; + export type Decimal = import("./datatypes").Decimal; export const Categorical = DataType.Categorical; export const Int8 = DataType.Int8; @@ -137,6 +139,7 @@ export namespace pl { export const Object = DataType.Object; export const Null = DataType.Null; export const Struct = DataType.Struct; + export const Decimal = DataType.Decimal; /** * Run SQL queries against DataFrame/LazyFrame data. diff --git a/polars/internals/construction.ts b/polars/internals/construction.ts index 2babf7de..7425877a 100644 --- a/polars/internals/construction.ts +++ b/polars/internals/construction.ts @@ -158,6 +158,13 @@ export function arrayToJsSeries( return df.toStruct(name); } + + if (dtype?.variant === "Decimal") { + if (typeof firstValue !== "bigint") { + throw new Error("Decimal type can only be constructed from BigInt"); + } + return pli.JsSeries.newAnyvalue(name, values, dtype, strict); + } if (firstValue instanceof Date) { series = pli.JsSeries.newOptDate(name, values, strict); } else { diff --git a/src/conversion.rs b/src/conversion.rs index 89d5734e..5d4415b3 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -67,6 +67,7 @@ impl ToSeries for JsUnknown { Series::new("", v) } } + impl ToNapiValue for Wrap<&Series> { unsafe fn to_napi_value(napi_env: sys::napi_env, val: Self) -> napi::Result { let s = val.0; @@ -101,6 +102,7 @@ impl ToNapiValue for Wrap<&Series> { } } } + impl<'a> ToNapiValue for Wrap> { unsafe fn to_napi_value(env: sys::napi_env, val: Self) -> Result { match val.0 { @@ -152,7 +154,16 @@ impl<'a> ToNapiValue for Wrap> { AnyValue::Time(v) => i64::to_napi_value(env, v), AnyValue::List(ser) => Wrap::<&Series>::to_napi_value(env, Wrap(&ser)), ref av @ AnyValue::Struct(_, _, flds) => struct_dict(env, av._iter_struct_av(), flds), - _ => todo!(), + AnyValue::Array(ser, _) => Wrap::<&Series>::to_napi_value(env, Wrap(&ser)), + AnyValue::Enum(_, _, _) => todo!(), + AnyValue::Object(_) => todo!(), + AnyValue::ObjectOwned(_) => todo!(), + AnyValue::StructOwned(_) => todo!(), + AnyValue::Binary(_) => todo!(), + AnyValue::BinaryOwned(_) => todo!(), + AnyValue::Decimal(_, _) => { + Err(napi::Error::from_reason("Decimal is not a supported type in javascript, please convert to string or number before collecting to js")) + } } } } @@ -679,6 +690,12 @@ impl FromNapiValue for Wrap { } DataType::Struct(fldvec) } + "Decimal" => { + let inner = obj.get::<_, Array>("inner")?.unwrap(); // [precision, scale] + let precision = inner.get::>(0)?.unwrap().map(|x| x as usize); + let scale = inner.get::>(1)?.unwrap().map(|x| x as usize); + DataType::Decimal(precision, scale) + } tp => panic!("Type {} not implemented in str_to_polarstype", tp), }; Ok(Wrap(dtype)) @@ -963,6 +980,27 @@ impl ToNapiValue for Wrap { Object::to_napi_value(env, obj) } + DataType::Array(dtype, size) => { + let env_ctx = Env::from_raw(env); + let mut obj = env_ctx.create_object()?; + let wrapped = Wrap(*dtype); + let mut inner_arr = env_ctx.create_array(2)?; + inner_arr.set(0, wrapped)?; + inner_arr.set(1, size as u32)?; + obj.set("variant", "FixedSizeList")?; + obj.set("inner", inner_arr)?; + Object::to_napi_value(env, obj) + } + DataType::Decimal(precision, scale) => { + let env_ctx = Env::from_raw(env); + let mut obj = env_ctx.create_object()?; + let mut inner_arr = env_ctx.create_array(2)?; + inner_arr.set(0, precision.map(|p| p as u32))?; + inner_arr.set(1, scale.map(|s| s as u32))?; + obj.set("variant", "Decimal")?; + obj.set("inner", inner_arr)?; + Object::to_napi_value(env, obj) + } _ => { todo!() } diff --git a/src/dataframe.rs b/src/dataframe.rs index a0340197..8f44480d 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -90,8 +90,7 @@ pub struct ReadCsvOptions { fn mmap_reader_to_df<'a>( csv: impl MmapBytesReader + 'a, options: ReadCsvOptions, -) -> napi::Result -{ +) -> napi::Result { let null_values = options.null_values.map(|w| w.0); let row_count = options.row_count.map(RowIndex::from); let projection = options @@ -598,12 +597,15 @@ impl JsDataFrame { let df = self .df - .join(&other.df, left_on, right_on, + .join( + &other.df, + left_on, + right_on, JoinArgs { how: how, suffix: suffix, ..Default::default() - } + }, ) .map_err(JsPolarsErr::from)?; Ok(JsDataFrame::new(df)) diff --git a/src/lazy/dsl.rs b/src/lazy/dsl.rs index ab2dd313..3cd522ed 100644 --- a/src/lazy/dsl.rs +++ b/src/lazy/dsl.rs @@ -955,16 +955,22 @@ impl JsExpr { .into() } #[napi(catch_unwind)] - pub fn replace(&self, old: &JsExpr, new: &JsExpr, default: Option<&JsExpr>, return_dtype: Option>) -> JsExpr { + pub fn replace( + &self, + old: &JsExpr, + new: &JsExpr, + default: Option<&JsExpr>, + return_dtype: Option>, + ) -> JsExpr { self.inner - .clone() - .replace( - old.inner.clone(), - new.inner.clone(), - default.map(|e| e.inner.clone()), - return_dtype.map(|dt| dt.0), - ) - .into() + .clone() + .replace( + old.inner.clone(), + new.inner.clone(), + default.map(|e| e.inner.clone()), + return_dtype.map(|dt| dt.0), + ) + .into() } #[napi(catch_unwind)] pub fn year(&self) -> JsExpr { diff --git a/src/series.rs b/src/series.rs index 904e4cad..be4851e7 100644 --- a/src/series.rs +++ b/src/series.rs @@ -194,6 +194,22 @@ impl JsSeries { .into_series() .into()) } + + #[napi(factory, catch_unwind)] + pub fn new_anyvalue( + name: String, + values: Vec>, + dtype: Wrap, + strict: bool, + ) -> napi::Result { + let values = values.into_iter().map(|v| v.0).collect::>(); + + let s = Series::from_any_values_and_dtype(&name, &values, &dtype.0, strict) + .map_err(JsPolarsErr::from)?; + + Ok(s.into()) + } + #[napi(factory, catch_unwind)] pub fn new_list(name: String, values: Array, dtype: Wrap) -> napi::Result { use crate::list_construction::js_arr_to_list; @@ -993,8 +1009,15 @@ impl JsSeries { // Ok(ca.into_series().into()) // } #[napi(catch_unwind)] - pub fn to_dummies(&self, separator: Option<&str>, drop_first: bool) -> napi::Result { - let df = self.series.to_dummies(separator, drop_first).map_err(JsPolarsErr::from)?; + pub fn to_dummies( + &self, + separator: Option<&str>, + drop_first: bool, + ) -> napi::Result { + let df = self + .series + .to_dummies(separator, drop_first) + .map_err(JsPolarsErr::from)?; Ok(df.into()) } #[napi(catch_unwind)]