Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Decimal type #226

Merged
merged 9 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion __tests__/series.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
/* eslint-disable newline-per-chained-call */
import pl from "@polars";
import { InvalidOperationError } from "../polars/error";
import Chance from "chance";

describe("from lists", () => {
Expand Down Expand Up @@ -186,6 +185,32 @@ describe("typedArrays", () => {
const actual = pl.Series(float64Array).toTypedArray();
expect(JSON.stringify(actual)).toEqual(JSON.stringify(float64Array));
});

test("decimal", () => {
  const expected = [1n, 2n, 3n];
  const expectedDtype = pl.Decimal(10, 2);
  const actual = pl.Series("", expected, expectedDtype);
  expect(actual.dtype).toEqual(expectedDtype);
  // Collecting a Decimal series to JS values is expected to fail.
  // `toThrow` makes the test fail when no error is raised; the previous
  // try/catch only asserted inside `catch` and passed silently otherwise.
  expect(() => actual.toArray()).toThrow(
    "Decimal is not a supported type in javascript, please convert to string or number before collecting to js",
  );
});

test("fixed list", () => {
  // Two rows of three values each, matching the declared width of 3.
  const rows = [
    [1, 2, 3],
    [4, 5, 6],
  ];
  const dtype = pl.FixedSizeList(pl.Float32, 3);
  const series = pl.Series("", rows, dtype);
  // The dtype must be kept as given and the values must round-trip unchanged.
  expect(series.dtype).toEqual(dtype);
  expect(series.toArray()).toEqual(rows);
});
});
describe("series", () => {
const chance = new Chance();
Expand Down
37 changes: 37 additions & 0 deletions polars/datatypes/datatype.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ export abstract class DataType {
return new Categorical();
}

/**
 * Decimal type, parameterised by a precision and a scale.
 * @param precision passed to the `Decimal` constructor as its precision
 * @param scale passed to the `Decimal` constructor as its scale
 */
public static Decimal(precision: number, scale: number): DataType {
  const decimal = new Decimal(precision, scale);
  return decimal;
}

/**
* Calendar date and time type
* @param timeUnit any of 'ms' | 'ns' | 'us'
Expand Down Expand Up @@ -186,6 +190,38 @@ export class String extends DataType {}

export class Categorical extends DataType {}

/**
 * Decimal type, parameterised by a precision and a scale.
 */
export class Decimal extends DataType {
  constructor(
    private precision: number,
    private scale: number,
  ) {
    super();
  }

  /** The parameters of this type as `[precision, scale]`. */
  override get inner() {
    const { precision, scale } = this;
    return [precision, scale];
  }

  /** Equal when the variant, the precision and the scale all match. */
  override equals(other: DataType): boolean {
    if (other.variant !== this.variant) {
      return false;
    }
    const that = other as Decimal;
    return this.precision === that.precision && this.scale === that.scale;
  }

  /** Serialises to `{ [identity]: { [variant]: { precision, scale } } }`. */
  override toJSON() {
    const params = { precision: this.precision, scale: this.scale };
    return { [this.identity]: { [this.variant]: params } };
  }
}

/**
* Datetime type
*/
Expand Down Expand Up @@ -349,6 +385,7 @@ export namespace DataType {
export type Object = import(".").Object_;
export type Null = import(".").Null;
export type Struct = import(".").Struct;
export type Decimal = import(".").Decimal;
/**
* deserializes a datatype from the serde output of rust polars `DataType`
* @param dtype dtype object
Expand Down
9 changes: 9 additions & 0 deletions polars/datatypes/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,15 @@ const POLARS_TYPE_TO_CONSTRUCTOR: Record<string, any> = {
List(name, values, _strict, dtype) {
return pli.JsSeries.newList(name, values, dtype);
},
// Decimal series: unpack `[precision, scale]` from the dtype before
// handing everything to the native constructor.
Decimal(name, values, strict, dtype) {
  const precision = dtype.inner[0];
  const scale = dtype.inner[1];
  return pli.JsSeries.newDecimal(name, values, strict, precision, scale);
},
};

/** @ignore */
Expand Down
5 changes: 4 additions & 1 deletion polars/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { Decimal } from "./datatypes/datatype";
import * as series from "./series";
import * as df from "./dataframe";
import { DataType, Field as _field } from "./datatypes";
export { DataType } from "./datatypes";
export * from "./datatypes";
import * as func from "./functions";
import * as io from "./io";
import * as cfg from "./cfg";
Expand Down Expand Up @@ -111,6 +112,7 @@ export namespace pl {
export type Object = import("./datatypes").Object_;
export type Null = import("./datatypes").Null;
export type Struct = import("./datatypes").Struct;
export type Decimal = import("./datatypes").Decimal;

export const Categorical = DataType.Categorical;
export const Int8 = DataType.Int8;
Expand All @@ -137,6 +139,7 @@ export namespace pl {
export const Object = DataType.Object;
export const Null = DataType.Null;
export const Struct = DataType.Struct;
export const Decimal = DataType.Decimal;

/**
* Run SQL queries against DataFrame/LazyFrame data.
Expand Down
7 changes: 7 additions & 0 deletions polars/internals/construction.ts
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ export function arrayToJsSeries(

return df.toStruct(name);
}

// Decimal series are built through the generic any-value constructor.
// Only `firstValue` is type-checked here: it must be a bigint, otherwise
// construction is rejected up front.
if (dtype?.variant === "Decimal") {
if (typeof firstValue !== "bigint") {
throw new Error("Decimal type can only be constructed from BigInt");
}
return pli.JsSeries.newAnyvalue(name, values, dtype, strict);
}
if (firstValue instanceof Date) {
series = pli.JsSeries.newOptDate(name, values, strict);
} else {
Expand Down
40 changes: 39 additions & 1 deletion src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ impl ToSeries for JsUnknown {
Series::new("", v)
}
}

impl ToNapiValue for Wrap<&Series> {
unsafe fn to_napi_value(napi_env: sys::napi_env, val: Self) -> napi::Result<sys::napi_value> {
let s = val.0;
Expand Down Expand Up @@ -101,6 +102,7 @@ impl ToNapiValue for Wrap<&Series> {
}
}
}

impl<'a> ToNapiValue for Wrap<AnyValue<'a>> {
unsafe fn to_napi_value(env: sys::napi_env, val: Self) -> Result<sys::napi_value> {
match val.0 {
Expand Down Expand Up @@ -152,7 +154,16 @@ impl<'a> ToNapiValue for Wrap<AnyValue<'a>> {
AnyValue::Time(v) => i64::to_napi_value(env, v),
AnyValue::List(ser) => Wrap::<&Series>::to_napi_value(env, Wrap(&ser)),
ref av @ AnyValue::Struct(_, _, flds) => struct_dict(env, av._iter_struct_av(), flds),
_ => todo!(),
// Fixed-size arrays go through the same `Wrap<&Series>` conversion as lists.
AnyValue::Array(ser, _) => Wrap::<&Series>::to_napi_value(env, Wrap(&ser)),
// The variants below are not converted yet; `todo!()` panics if one is reached.
AnyValue::Enum(_, _, _) => todo!(),
AnyValue::Object(_) => todo!(),
AnyValue::ObjectOwned(_) => todo!(),
AnyValue::StructOwned(_) => todo!(),
AnyValue::Binary(_) => todo!(),
AnyValue::BinaryOwned(_) => todo!(),
// Decimals are reported as an error instead of being converted to a JS value.
AnyValue::Decimal(_, _) => {
Err(napi::Error::from_reason("Decimal is not a supported type in javascript, please convert to string or number before collecting to js"))
}
}
}
}
Expand Down Expand Up @@ -679,6 +690,12 @@ impl FromNapiValue for Wrap<DataType> {
}
DataType::Struct(fldvec)
}
"Decimal" => {
let inner = obj.get::<_, Array>("inner")?.unwrap(); // [precision, scale]
let precision = inner.get::<i32>(0)?.unwrap();
let scale = inner.get::<i32>(1)?.unwrap();
DataType::Decimal(Some(precision as usize), Some(scale as usize))
}
tp => panic!("Type {} not implemented in str_to_polarstype", tp),
};
Ok(Wrap(dtype))
Expand Down Expand Up @@ -963,6 +980,27 @@ impl ToNapiValue for Wrap<DataType> {

Object::to_napi_value(env, obj)
}
// Fixed-size arrays are exposed to JS as
// `{ variant: "FixedSizeList", inner: [<inner dtype>, <size>] }`.
DataType::Array(dtype, size) => {
let env_ctx = Env::from_raw(env);
let mut obj = env_ctx.create_object()?;
let wrapped = Wrap(*dtype);
let mut inner_arr = env_ctx.create_array(2)?;
inner_arr.set(0, wrapped)?;
// NOTE(review): `as u32` truncates if `size` does not fit in 32 bits.
inner_arr.set(1, size as u32)?;
obj.set("variant", "FixedSizeList")?;
obj.set("inner", inner_arr)?;
Object::to_napi_value(env, obj)
}
// Decimals are exposed to JS as
// `{ variant: "Decimal", inner: [<precision>, <scale>] }`.
DataType::Decimal(precision, scale) => {
let env_ctx = Env::from_raw(env);
let mut obj = env_ctx.create_object()?;
let mut inner_arr = env_ctx.create_array(2)?;
// Both values are optional; a `None` presumably reaches JS as null/undefined
// (TODO confirm against the napi `Option` conversion).
inner_arr.set(0, precision.map(|p| p as u32))?;
inner_arr.set(1, scale.map(|s| s as u32))?;
obj.set("variant", "Decimal")?;
obj.set("inner", inner_arr)?;
Object::to_napi_value(env, obj)
}
_ => {
todo!()
}
Expand Down
10 changes: 6 additions & 4 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ pub struct ReadCsvOptions {
fn mmap_reader_to_df<'a>(
csv: impl MmapBytesReader + 'a,
options: ReadCsvOptions,
) -> napi::Result<JsDataFrame>
{
) -> napi::Result<JsDataFrame> {
let null_values = options.null_values.map(|w| w.0);
let row_count = options.row_count.map(RowIndex::from);
let projection = options
Expand Down Expand Up @@ -598,12 +597,15 @@ impl JsDataFrame {

let df = self
.df
.join(&other.df, left_on, right_on,
.join(
&other.df,
left_on,
right_on,
JoinArgs {
how: how,
suffix: suffix,
..Default::default()
}
},
)
.map_err(JsPolarsErr::from)?;
Ok(JsDataFrame::new(df))
Expand Down
24 changes: 15 additions & 9 deletions src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,16 +955,22 @@ impl JsExpr {
.into()
}
#[napi(catch_unwind)]
pub fn replace(&self, old: &JsExpr, new: &JsExpr, default: Option<&JsExpr>, return_dtype: Option<Wrap<DataType>>) -> JsExpr {
pub fn replace(
&self,
old: &JsExpr,
new: &JsExpr,
default: Option<&JsExpr>,
return_dtype: Option<Wrap<DataType>>,
) -> JsExpr {
self.inner
.clone()
.replace(
old.inner.clone(),
new.inner.clone(),
default.map(|e| e.inner.clone()),
return_dtype.map(|dt| dt.0),
)
.into()
.clone()
.replace(
old.inner.clone(),
new.inner.clone(),
default.map(|e| e.inner.clone()),
return_dtype.map(|dt| dt.0),
)
.into()
}
#[napi(catch_unwind)]
pub fn year(&self) -> JsExpr {
Expand Down
27 changes: 25 additions & 2 deletions src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,22 @@ impl JsSeries {
.into_series()
.into())
}

#[napi(factory, catch_unwind)]
pub fn new_anyvalue(
name: String,
values: Vec<Wrap<AnyValue>>,
dtype: Wrap<DataType>,
strict: bool,
) -> napi::Result<JsSeries> {
let values = values.into_iter().map(|v| v.0).collect::<Vec<_>>();

let s = Series::from_any_values_and_dtype(&name, &values, &dtype.0, strict)
.map_err(JsPolarsErr::from)?;

Ok(s.into())
}

#[napi(factory, catch_unwind)]
pub fn new_list(name: String, values: Array, dtype: Wrap<DataType>) -> napi::Result<JsSeries> {
use crate::list_construction::js_arr_to_list;
Expand Down Expand Up @@ -993,8 +1009,15 @@ impl JsSeries {
// Ok(ca.into_series().into())
// }
#[napi(catch_unwind)]
pub fn to_dummies(&self, separator: Option<&str>, drop_first: bool) -> napi::Result<JsDataFrame> {
let df = self.series.to_dummies(separator, drop_first).map_err(JsPolarsErr::from)?;
pub fn to_dummies(
&self,
separator: Option<&str>,
drop_first: bool,
) -> napi::Result<JsDataFrame> {
let df = self
.series
.to_dummies(separator, drop_first)
.map_err(JsPolarsErr::from)?;
Ok(df.into())
}
#[napi(catch_unwind)]
Expand Down