Skip to content

Commit

Permalink
Add Aja.Enum.sum_by/2 and Aja.Enum.product_by/2
Browse files Browse the repository at this point in the history
  • Loading branch information
sabiwara committed Dec 12, 2024
1 parent 11410e3 commit 874bf81
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 15 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## Dev

### Enhancements

- Add `Aja.Enum.sum_by/2` and `Aja.Enum.product_by/2`

## v0.7.2 (2024-10-31)

### Bug fixes
Expand Down
12 changes: 12 additions & 0 deletions bench/enum/product_by.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
list = Enum.map(1..1000, &(1 + rem(&1, 3)))
vector = Aja.Vector.new(list)

Benchee.run(%{
"Aja.Enum.product_by/2 (vector)" => fn -> Aja.Enum.product_by(vector, & &1) end,
"Enum.product_by/2 (vector)" => fn -> Enum.product_by(vector, & &1) end,
"Aja.Enum.reduce/3 (vector)" => fn -> Aja.Enum.reduce(vector, 1, &*/2) end,
"Enum.product_by/2 (list)" => fn -> Enum.product_by(list, & &1) end,
# for comparison:
"Enum.product(list)" => fn -> Enum.product(list) end,
"Aja.Enum.product(vector)" => fn -> Aja.Enum.product(vector) end
})
41 changes: 41 additions & 0 deletions bench/enum/product_by.results.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
Operating System: macOS
CPU Information: Apple M1
Number of Available Cores: 8
Available memory: 16 GB
Elixir 1.18.0-rc.0
Erlang 27.0
JIT enabled: true

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 5 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 42 s

Benchmarking Aja.Enum.product(vector) ...
Benchmarking Aja.Enum.product_by/2 (vector) ...
Benchmarking Aja.Enum.reduce/3 (vector) ...
Benchmarking Enum.product(list) ...
Benchmarking Enum.product_by/2 (list) ...
Benchmarking Enum.product_by/2 (vector) ...
Calculating statistics...
Formatting results...

Name ips average deviation median 99th %
Aja.Enum.product(vector) 90.68 K 11.03 μs ±22.83% 10.92 μs 12.79 μs
Enum.product(list) 85.90 K 11.64 μs ±14.89% 11.25 μs 14.63 μs
Aja.Enum.product_by/2 (vector) 80.01 K 12.50 μs ±16.89% 12 μs 18.63 μs
Enum.product_by/2 (list) 74.35 K 13.45 μs ±44.35% 12.83 μs 23.63 μs
Aja.Enum.reduce/3 (vector) 73.37 K 13.63 μs ±32.29% 13.25 μs 19.17 μs
Enum.product_by/2 (vector) 51.72 K 19.34 μs ±20.57% 18.67 μs 25.42 μs

Comparison:
Aja.Enum.product(vector) 90.68 K
Enum.product(list) 85.90 K - 1.06x slower +0.61 μs
Aja.Enum.product_by/2 (vector) 80.01 K - 1.13x slower +1.47 μs
Enum.product_by/2 (list) 74.35 K - 1.22x slower +2.42 μs
Aja.Enum.reduce/3 (vector) 73.37 K - 1.24x slower +2.60 μs
Enum.product_by/2 (vector) 51.72 K - 1.75x slower +8.31 μs
12 changes: 12 additions & 0 deletions bench/enum/sum_by.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
list = Enum.map(1..1000, &(1 + rem(&1, 3)))
vector = Aja.Vector.new(list)

Benchee.run(%{
"Aja.Enum.sum_by/2 (vector)" => fn -> Aja.Enum.sum_by(vector, & &1) end,
"Aja.Enum.sum_by/2 (list)" => fn -> Aja.Enum.sum_by(list, & &1) end,
"Enum.sum_by/2 (vector)" => fn -> Enum.sum_by(vector, & &1) end,
"Aja.Enum.reduce/3 (vector)" => fn -> Aja.Enum.reduce(vector, 0, &+/2) end,
"Enum.sum_by/2 (list)" => fn -> Enum.sum_by(list, & &1) end,
# for comparison:
":lists.sum/1" => fn -> :lists.sum(list) end
})
41 changes: 41 additions & 0 deletions bench/enum/sum_by.results.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
Operating System: macOS
CPU Information: Apple M1
Number of Available Cores: 8
Available memory: 16 GB
Elixir 1.18.0-rc.0
Erlang 27.0
JIT enabled: true

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 5 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 42 s

Benchmarking :lists.sum/1 ...
Benchmarking Aja.Enum.reduce/3 (vector) ...
Benchmarking Aja.Enum.sum_by/2 (list) ...
Benchmarking Aja.Enum.sum_by/2 (vector) ...
Benchmarking Enum.sum_by/2 (list) ...
Benchmarking Enum.sum_by/2 (vector) ...
Calculating statistics...
Formatting results...

Name ips average deviation median 99th %
:lists.sum/1 519.20 K 1.93 μs ±264.77% 1.88 μs 2.08 μs
Aja.Enum.sum_by/2 (vector) 393.70 K 2.54 μs ±1997.47% 2.33 μs 2.63 μs
Aja.Enum.sum_by/2 (list) 312.38 K 3.20 μs ±250.41% 3.21 μs 3.38 μs
Enum.sum_by/2 (list) 305.42 K 3.27 μs ±278.26% 3.21 μs 3.58 μs
Aja.Enum.reduce/3 (vector) 147.91 K 6.76 μs ±109.40% 6.71 μs 7.04 μs
Enum.sum_by/2 (vector) 121.08 K 8.26 μs ±79.69% 7.92 μs 12.92 μs

Comparison:
:lists.sum/1 519.20 K
Aja.Enum.sum_by/2 (vector) 393.70 K - 1.32x slower +0.61 μs
Aja.Enum.sum_by/2 (list) 312.38 K - 1.66x slower +1.28 μs
Enum.sum_by/2 (list) 305.42 K - 1.70x slower +1.35 μs
Aja.Enum.reduce/3 (vector) 147.91 K - 3.51x slower +4.83 μs
Enum.sum_by/2 (vector) 121.08 K - 4.29x slower +6.33 μs
74 changes: 59 additions & 15 deletions lib/enum.ex
Original file line number Diff line number Diff line change
Expand Up @@ -603,21 +603,46 @@ defmodule Aja.Enum do
Returns the sum of all elements.
Mirrors `Enum.sum/1` with higher performance for Aja structures.
Raises `ArithmeticError` if `enumerable` contains a non-numeric value.
"""
@spec sum(t(num)) :: num when num: number
def sum(enumerable) do
case H.try_get_raw_vec_or_list(enumerable) do
nil -> Enum.sum(enumerable)
list when is_list(list) -> :lists.sum(list)
vector -> RawVector.sum(vector)
end
end

@doc """
Maps and sums the given `enumerable` in one pass.
Mirrors `Enum.sum_by/2` with higher performance for Aja structures.
Raises `ArithmeticError` if `mapper` returns a non-numeric value.
"""
@spec sum_by(t(val), (val -> num)) :: num when val: value, num: number
def sum_by(enumerable, fun) do
case H.try_get_raw_vec_or_list(enumerable) do
nil ->
Enum.sum(enumerable)
# TODO use Enum.sum_by/1 for Elixir 1.18+
reduce(enumerable, 0, fn el, acc -> fun.(el) + acc end)

list when is_list(list) ->
:lists.sum(list)
sum_by_list(list, fun, 0)

vector ->
RawVector.sum(vector)
RawVector.sum_by(vector, fun)
end
end

defp sum_by_list([], _fun, acc), do: acc

defp sum_by_list([head | rest], fun, acc) do
sum_by_list(rest, fun, fun.(head) + acc)
end

@doc """
Returns the product of all elements in the `enumerable`.
Expand All @@ -627,31 +652,50 @@ defmodule Aja.Enum do
## Examples
iex> 1..5 |> Aja.Enum.product()
iex> Aja.Enum.product(1..5)
120
iex> [] |> Aja.Enum.product()
iex> Aja.Enum.product([])
1
"""
@spec product(t(num)) :: num when num: number
def product(enumerable) do
case H.try_get_raw_vec_or_list(enumerable) do
nil ->
# TODO use Enum.product/1 for Elixir 1.11
reduce(enumerable, 1, &*/2)
nil -> Enum.product(enumerable)
list when is_list(list) -> Enum.product(list)
vector -> RawVector.product(vector)
end
end

list when is_list(list) ->
product_list(list, 1)
@doc """
Maps and computes the product of the given `enumerable` in one pass.
vector ->
RawVector.product(vector)
Mirrors `Enum.product_by/2`.
Raises `ArithmeticError` if `mapper` returns a non-numeric value.
## Examples
iex> Aja.Enum.product_by(1..3, & &1 + 1)
24
iex> Aja.Enum.product_by([], & &1 + 1)
1
"""
@spec product_by(t(val), (val -> num)) :: num when val: value, num: number
def product_by(enumerable, fun) do
case H.try_get_raw_vec_or_list(enumerable) do
# TODO use Enum.product_by/1 for Elixir 1.18+
nil -> Enum.reduce(enumerable, 1, fn el, acc -> fun.(el) * acc end)
list when is_list(list) -> product_by_list(list, fun, 1)
vector -> RawVector.product_by(vector, fun)
end
end

defp product_list([], acc), do: acc
defp product_by_list([], _fun, acc), do: acc

defp product_list([head | rest], acc) do
product_list(rest, head * acc)
defp product_by_list([head | rest], fun, acc) do
product_by_list(rest, fun, fun.(head) * acc)
end

@doc """
Expand Down
10 changes: 10 additions & 0 deletions lib/vector/raw.ex
Original file line number Diff line number Diff line change
Expand Up @@ -531,11 +531,21 @@ defmodule Aja.Vector.Raw do
acc + arg
end

@spec sum_by(t(val), (val -> number)) :: number when val: value
C.def_foldl sum_by(arg, acc \\ 0, fun) do
acc + fun.(arg)
end

@spec product(t(number)) :: number
C.def_foldl product(arg, acc \\ 1) do
acc * arg
end

@spec product_by(t(number), (val -> number)) :: number when val: value
C.def_foldl product_by(arg, acc \\ 1, fun) do
acc * fun.(arg)
end

@spec count(t(val), (val -> as_boolean(term))) :: non_neg_integer when val: value
C.def_foldl count(arg, acc \\ 0, fun) do
if fun.(arg) do
Expand Down
12 changes: 12 additions & 0 deletions test/enum_prop_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,19 @@ defmodule Aja.Enum.PropTest do
assert ^sum_result = Aja.Enum.sum(stream) |> capture_error()
assert capture_error(Enum.sum(map_set)) === capture_error(Aja.Enum.sum(map_set))

assert ^sum_result = Aja.Enum.sum_by(list, & &1) |> capture_error()
assert ^sum_result = Aja.Enum.sum_by(vector, & &1) |> capture_error()
assert ^sum_result = Aja.Enum.sum_by(stream, & &1) |> capture_error()

product_result = Enum.reduce(list, 1, &(&2 * &1)) |> capture_error()
assert ^product_result = Aja.Enum.product(list) |> capture_error()
assert ^product_result = Aja.Enum.product(vector) |> capture_error()
assert ^product_result = Aja.Enum.product(stream) |> capture_error()

assert ^product_result = Aja.Enum.product_by(list, & &1) |> capture_error()
assert ^product_result = Aja.Enum.product_by(vector, & &1) |> capture_error()
assert ^product_result = Aja.Enum.product_by(stream, & &1) |> capture_error()

join_result = Enum.join(list, ",") |> capture_error()
assert ^join_result = Aja.Enum.join(list, ",") |> capture_error()
assert ^join_result = Aja.Enum.join(stream, ",") |> capture_error()
Expand Down Expand Up @@ -449,6 +457,10 @@ defmodule Aja.Enum.PropTest do
assert Enum.sum(list) === Aja.Enum.sum(list)
assert Enum.sum(list) === Aja.Enum.sum(vector)
assert Enum.sum(map_set) === Aja.Enum.sum(map_set)

assert Enum.sum(list) * 2 === Aja.Enum.sum_by(list, &(&1 * 2))
assert Enum.sum(list) * 2 === Aja.Enum.sum_by(vector, &(&1 * 2))
assert Enum.sum(map_set) * 2 === Aja.Enum.sum_by(map_set, &(&1 * 2))
end
end

Expand Down

0 comments on commit 874bf81

Please sign in to comment.