From 2644d270bf8dfebb22b91c0713a8be9d1c208d10 Mon Sep 17 00:00:00 2001 From: Philip Sampaio Date: Thu, 18 Jul 2024 18:27:15 -0300 Subject: [PATCH] Attempt to fix "pow/2" after Polars changes --- lib/explorer/backend/lazy_series.ex | 11 ++++++++++- lib/explorer/data_frame.ex | 4 ++-- lib/explorer/series.ex | 12 +++++++++++- test/explorer/data_frame_test.exs | 26 ++++++++++++-------------- test/explorer/series_test.exs | 16 ++++++++-------- 5 files changed, 43 insertions(+), 26 deletions(-) diff --git a/lib/explorer/backend/lazy_series.ex b/lib/explorer/backend/lazy_series.ex index 5bb147bad..9738f53bb 100644 --- a/lib/explorer/backend/lazy_series.ex +++ b/lib/explorer/backend/lazy_series.ex @@ -166,7 +166,7 @@ defmodule Explorer.Backend.LazySeries do @comparison_operations [:equal, :not_equal, :greater, :greater_equal, :less, :less_equal] - @basic_arithmetic_operations [:add, :subtract, :multiply, :divide, :pow] + @basic_arithmetic_operations [:add, :subtract, :multiply, :divide] @other_arithmetic_operations [:quotient, :remainder] @aggregation_operations [ @@ -453,6 +453,15 @@ defmodule Explorer.Backend.LazySeries do end end + @impl true + def pow(dtype, %Series{} = left, %Series{} = right) do + # Cast from the main module is needed because we may be seeing a series from another backend. + args = [data!(Explorer.Series.cast(left, dtype)), data!(right)] + data = new(:pow, args, dtype, aggregations?(args)) + + Backend.Series.new(data, dtype) + end + for op <- @other_arithmetic_operations do @impl true def unquote(op)(left, right) do diff --git a/lib/explorer/data_frame.ex b/lib/explorer/data_frame.ex index c1fcbc024..88d1f1001 100644 --- a/lib/explorer/data_frame.ex +++ b/lib/explorer/data_frame.ex @@ -2774,11 +2774,11 @@ defmodule Explorer.DataFrame do You can overwrite existing columns as well: iex> df = Explorer.DataFrame.new(a: ["a", "b", "c"], b: [1, 2, 3]) - iex> Explorer.DataFrame.mutate_with(df, &[b: Explorer.Series.pow(&1["b"], 2)]) + iex> Explorer.DataFrame.mutate_with(df, &[b: Explorer.Series.add(&1["b"], 2)]) #Explorer.DataFrame< Polars[3 x 2] a string ["a", "b", "c"] - b s64 [1, 4, 9] + b s64 [3, 4, 5] > It's possible to "reuse" a variable for different computations: diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index 8145b85cb..1bcb3d259 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -3598,6 +3598,9 @@ defmodule Explorer.Series do sizes are series, the series must have the same size or at last one of them must have size of 1. + In case the expoent is a signed integer number or series, + the resultant series will be of `{:f, 64}` dtype. + ## Supported dtypes * floats: #{Shared.inspect_dtypes(@float_dtypes, backsticks: true)} @@ -3614,6 +3617,13 @@ defmodule Explorer.Series do iex> s = [2, 4, 6] |> Explorer.Series.from_list() iex> Explorer.Series.pow(s, 3) + #Explorer.Series< + Polars[3] + f64 [8.0, 64.0, 216.0] + > + + iex> s = [2, 4, 6] |> Explorer.Series.from_list() + iex> Explorer.Series.pow(s, Explorer.Series.from_list([3], dtype: :u32)) #Explorer.Series< Polars[3] s64 [8, 64, 216] @@ -3657,7 +3667,7 @@ defmodule Explorer.Series do defp cast_to_pow({:f, l}, {:f, r}), do: {:f, max(l, r)} defp cast_to_pow({:f, l}, {n, _}) when K.in(n, [:u, :s]), do: {:f, l} defp cast_to_pow({n, _}, {:f, r}) when K.in(n, [:u, :s]), do: {:f, r} - defp cast_to_pow({n, _}, {:s, _}) when K.in(n, [:u, :s]), do: {:s, 64} + defp cast_to_pow({n, _}, {:s, _}) when K.in(n, [:u, :s]), do: {:f, 64} defp cast_to_pow(_, _), do: nil @doc """ diff --git a/test/explorer/data_frame_test.exs b/test/explorer/data_frame_test.exs index 777174c57..ace2554a6 100644 --- a/test/explorer/data_frame_test.exs +++ b/test/explorer/data_frame_test.exs @@ -299,7 +299,7 @@ defmodule Explorer.DataFrameTest do df = DF.new(a: [1, 2, 3, 4, 5, 6, 5], b: [9, 8, 7, 6, 5, 4, 3]) message = - "expecting the function to return a boolean LazySeries, but instead it returned a LazySeries of type {:s, 64}" + "expecting the function to return a boolean LazySeries, but instead it returned a LazySeries of type {:f, 64}" assert_raise ArgumentError, message, fn -> DF.filter_with(df, fn ldf -> @@ -948,7 +948,7 @@ defmodule Explorer.DataFrameTest do calc2: [-1, 0, 2], calc3: [2, 4, 8], calc4: [0.5, 1.0, 2.0], - calc5: [1, 4, 16], + calc5: [1.0, 4.0, 16.0], calc6: [0, 1, 2], calc7: [1, 0, 0], calc8: [:nan, :nan, :nan], @@ -964,7 +964,7 @@ defmodule Explorer.DataFrameTest do "calc2" => {:s, 64}, "calc3" => {:s, 64}, "calc4" => {:f, 64}, - "calc5" => {:s, 64}, + "calc5" => {:f, 64}, "calc6" => {:s, 64}, "calc7" => {:s, 64}, "calc8" => {:f, 64}, @@ -985,7 +985,6 @@ defmodule Explorer.DataFrameTest do calc3: multiply(2, a), calc4: divide(2, a), calc5: pow(2, a), - calc5_1: pow(2.0, a), calc6: quotient(2, a), calc7: remainder(2, a) ) @@ -996,8 +995,7 @@ defmodule Explorer.DataFrameTest do calc2: [1, 0, -2], calc3: [2, 4, 8], calc4: [2.0, 1.0, 0.5], - calc5: [2, 4, 16], - calc5_1: [2.0, 4.0, 16.0], + calc5: [2.0, 4.0, 16.0], calc6: [2, 1, 0], calc7: [0, 0, 2] } @@ -1008,8 +1006,7 @@ defmodule Explorer.DataFrameTest do "calc2" => {:s, 64}, "calc3" => {:s, 64}, "calc4" => {:f, 64}, - "calc5" => {:s, 64}, - "calc5_1" => {:f, 64}, + "calc5" => {:f, 64}, "calc6" => {:s, 64}, "calc7" => {:s, 64} } @@ -1017,6 +1014,7 @@ defmodule Explorer.DataFrameTest do test "adds some columns with arithmetic operations on (lazy series, series)" do df = DF.new(a: [1, 2, 4]) + # TODO: check remainder and quotient in case they have a u32 on the right side. series = Explorer.Series.from_list([2, 1, 2]) df1 = @@ -1036,7 +1034,7 @@ defmodule Explorer.DataFrameTest do calc2: [-1, 1, 2], calc3: [2, 2, 8], calc4: [0.5, 2.0, 2.0], - calc5: [1, 2, 16], + calc5: [1.0, 2.0, 16.0], calc6: [0, 2, 2], calc7: [1, 0, 0] } @@ -1047,7 +1045,7 @@ defmodule Explorer.DataFrameTest do "calc2" => {:s, 64}, "calc3" => {:s, 64}, "calc4" => {:f, 64}, - "calc5" => {:s, 64}, + "calc5" => {:f, 64}, "calc6" => {:s, 64}, "calc7" => {:s, 64} } @@ -1074,7 +1072,7 @@ defmodule Explorer.DataFrameTest do calc2: [-1, 1, 2], calc3: [2, 2, 8], calc4: [0.5, 2.0, 2.0], - calc5: [1, 2, 16], + calc5: [1.0, 2.0, 16.0], calc6: [0, 2, 2], calc7: [1, 0, 0] } @@ -1085,7 +1083,7 @@ defmodule Explorer.DataFrameTest do "calc2" => {:s, 64}, "calc3" => {:s, 64}, "calc4" => {:f, 64}, - "calc5" => {:s, 64}, + "calc5" => {:f, 64}, "calc6" => {:s, 64}, "calc7" => {:s, 64} } @@ -1114,7 +1112,7 @@ defmodule Explorer.DataFrameTest do calc2: [19, 38, 57], calc3: [3, 4, 3], calc4: [2.0, :infinity, 7.5], - calc5: [1, 4, 3], + calc5: [1.0, 4.0, 3.0], calc6: [2, nil, 7], calc7: [0, nil, 4] } @@ -1128,7 +1126,7 @@ defmodule Explorer.DataFrameTest do "calc2" => {:s, 64}, "calc3" => {:s, 64}, "calc4" => {:f, 64}, - "calc5" => {:s, 64}, + "calc5" => {:f, 64}, "calc6" => {:s, 64}, "calc7" => {:s, 64} } diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index 9e2143fe8..8b2d6fa42 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -2290,8 +2290,8 @@ defmodule Explorer.SeriesTest do result = Series.pow(base, power) - assert result.dtype == {:s, 64} - assert Series.to_list(result) == [1, 4, 3] + assert result.dtype == {:f, 64} + assert Series.to_list(result) == [1.0, 4.0, 3.0] end end @@ -2315,8 +2315,8 @@ defmodule Explorer.SeriesTest do result = Series.pow(base, power) - assert result.dtype == {:s, 64} - assert Series.to_list(result) === [1, 4, 3] + assert result.dtype == {:f, 64} + assert Series.to_list(result) === [1.0, 4.0, 3.0] end end @@ -2392,13 +2392,13 @@ defmodule Explorer.SeriesTest do result = Series.pow(s1, s2) - assert result.dtype == {:s, 64} + assert result.dtype == {:f, 64} assert Series.to_list(result) == [1, nil, 3] end test "pow of an integer series that contains nil with an integer series" do s1 = Series.from_list([1, nil, 3]) - s2 = Series.from_list([3, 2, 1]) + s2 = Series.from_list([3, 2, 1], dtype: :u32) result = Series.pow(s1, s2) @@ -2408,7 +2408,7 @@ defmodule Explorer.SeriesTest do test "pow of an integer series that contains nil with an integer series also with nil" do s1 = Series.from_list([1, nil, 3]) - s2 = Series.from_list([3, nil, 1]) + s2 = Series.from_list([3, nil, 1], dtype: :u32) result = Series.pow(s1, s2) @@ -2421,7 +2421,7 @@ defmodule Explorer.SeriesTest do result = Series.pow(s1, 2) - assert result.dtype == {:s, 64} + assert result.dtype == {:f, 64} assert Series.to_list(result) == [1, 4, 9] end