Skip to content

feat: implemented reduce over chunks #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 75 additions & 35 deletions laygo/transformers/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from functools import reduce
import itertools
from typing import Any
from typing import Literal
from typing import Self
from typing import Union
from typing import overload
Expand Down Expand Up @@ -343,42 +344,81 @@ def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -
# The context is now passed explicitly through the transformer chain.
yield from self.transformer(chunk, run_context)

def reduce[U](self, function: PipelineReduceFunction[U, Out], initial: U):
"""Reduce elements to a single value (terminal operation).

Args:
function: The reduction function. Can be context-aware.
initial: The initial value for the reduction.

Returns:
A function that executes the reduction when called with data.
"""

if is_context_aware_reduce(function):

def _reduce_with_context(data: Iterable[In], context: PipelineContext | None = None) -> Iterator[U]:
# The context for the run is determined here.
run_context = context or self.context

data_iterator = self(data, run_context)

def function_wrapper(acc: U, value: Out) -> U:
return function(acc, value, run_context)

yield reduce(function_wrapper, data_iterator, initial)

return _reduce_with_context

# Not context-aware, so we adapt the function to ignore the context.
def _reduce(data: Iterable[In], context: PipelineContext | None = None) -> Iterator[U]:
# The context for the run is determined here.
run_context = context or self.context

data_iterator = self(data, run_context)

yield reduce(function, data_iterator, initial) # type: ignore
@overload
def reduce[U](
    self,
    function: PipelineReduceFunction[U, Out],
    initial: U,
    *,
    per_chunk: Literal[True],
) -> "Transformer[In, U]":
    """Reduces each chunk to a single value (chainable operation).

    With ``per_chunk=True`` the result is a new ``Transformer`` that emits
    one reduced value per input chunk, so further operations can be chained
    after the reduction.
    """
    ...

return _reduce
@overload
def reduce[U](
    self,
    function: PipelineReduceFunction[U, Out],
    initial: U,
    *,
    per_chunk: Literal[False] = False,
) -> Callable[[Iterable[In], PipelineContext | None], Iterator[U]]:
    """Reduces the entire dataset to a single value (terminal operation).

    With ``per_chunk=False`` (the default) the result is a callable that,
    when invoked with data and an optional context, yields exactly one
    reduced value for the whole stream.
    """
    ...

def reduce[U](
self,
function: PipelineReduceFunction[U, Out],
initial: U,
*,
per_chunk: bool = False,
) -> Union["Transformer[In, U]", Callable[[Iterable[In], PipelineContext | None], Iterator[U]]]: # type: ignore
"""Reduces elements to a single value, either per-chunk or for the entire dataset."""
if per_chunk:
# --- Efficient "per-chunk" logic (chainable) ---

# The context-awareness check is now hoisted and executed only ONCE.
if is_context_aware_reduce(function):
# We define a specialized operation for the context-aware case.
def reduce_chunk_operation(chunk: list[Out], ctx: PipelineContext) -> list[U]:
if not chunk:
return []
# No check happens here; we know the function needs the context.
wrapper = lambda acc, val: function(acc, val, ctx) # noqa: E731, W291
return [reduce(wrapper, chunk, initial)]
else:
# We define a specialized, simpler operation for the non-aware case.
def reduce_chunk_operation(chunk: list[Out], ctx: PipelineContext) -> list[U]:
if not chunk:
return []
# No check happens here; the function is called directly.
return [reduce(function, chunk, initial)] # type: ignore

return self._pipe(reduce_chunk_operation)

# --- "Entire dataset" logic with `match` (terminal) ---
match is_context_aware_reduce(function):
case True:

def _reduce_with_context(data: Iterable[In], context: PipelineContext | None = None) -> Iterator[U]:
run_context = context or self.context
data_iterator = self(data, run_context)

def function_wrapper(acc, val):
return function(acc, val, run_context) # type: ignore

yield reduce(function_wrapper, data_iterator, initial)

return _reduce_with_context

case False:

def _reduce(data: Iterable[In], context: PipelineContext | None = None) -> Iterator[U]:
run_context = context or self.context
data_iterator = self(data, run_context)
yield reduce(function, data_iterator, initial) # type: ignore

return _reduce

def catch[U](
self,
Expand Down
79 changes: 77 additions & 2 deletions tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def test_basic_reduce(self):
"""Test reduce with sum operation."""
transformer = createTransformer(int)
reducer = transformer.reduce(lambda acc, x: acc + x, initial=0)
result = list(reducer([1, 2, 3, 4]))
result = list(reducer([1, 2, 3, 4], None))
assert result == [10]

def test_reduce_with_context(self):
Expand All @@ -280,9 +280,84 @@ def test_reduce_after_transformation(self):
"""Test reduce after map transformation."""
transformer = createTransformer(int).map(lambda x: x * 2)
reducer = transformer.reduce(lambda acc, x: acc + x, initial=0)
result = list(reducer([1, 2, 3]))
result = list(reducer([1, 2, 3], None))
assert result == [12] # [2, 4, 6] summed = 12

def test_reduce_per_chunk_basic(self):
    """Each chunk is summed independently when per_chunk=True."""
    pipeline = createTransformer(int, chunk_size=2).reduce(lambda acc, x: acc + x, initial=0, per_chunk=True)
    output = list(pipeline([1, 2, 3, 4, 5]))
    # Chunks of size 2: [1, 2] -> 3, [3, 4] -> 7, trailing [5] -> 5
    assert output == [3, 7, 5]

def test_reduce_per_chunk_with_context(self):
    """A context-aware reducer receives the pipeline context in per-chunk mode."""
    run_ctx = PipelineContext({"multiplier": 2})
    pipeline = createTransformer(int, chunk_size=2).reduce(
        lambda acc, x, ctx: acc + x * ctx["multiplier"], initial=0, per_chunk=True
    )
    output = list(pipeline([1, 2, 3], run_ctx))
    # Chunk [1, 2] -> 1*2 + 2*2 = 6; chunk [3] -> 3*2 = 6
    assert output == [6, 6]

def test_reduce_per_chunk_empty_chunks(self):
    """An empty input stream yields no per-chunk results at all."""
    pipeline = createTransformer(int, chunk_size=5).reduce(lambda acc, x: acc + x, initial=0, per_chunk=True)
    assert list(pipeline([])) == []

def test_reduce_per_chunk_single_element_chunks(self):
    """chunk_size=1 reduces every element on its own, seeded from `initial`."""
    pipeline = createTransformer(int, chunk_size=1).reduce(lambda acc, x: acc + x, initial=10, per_chunk=True)
    # One result per single-element chunk: 10+1, 10+2, 10+3
    assert list(pipeline([1, 2, 3])) == [11, 12, 13]

def test_reduce_per_chunk_chaining(self):
    """A per-chunk reduce remains chainable with further transformations."""
    pipeline = (
        createTransformer(int, chunk_size=2)
        .map(lambda x: x * 2)
        .reduce(lambda acc, x: acc + x, initial=0, per_chunk=True)
        .map(lambda x: x * 10)
    )
    # Doubling gives [2, 4, 6]; chunks [2, 4] -> 6 and [6] -> 6;
    # the trailing map scales both to 60.
    assert list(pipeline([1, 2, 3])) == [60, 60]

def test_reduce_per_chunk_different_chunk_sizes(self):
    """The chunk size controls how many elements each reduction covers."""
    numbers = [1, 2, 3, 4, 5, 6]

    summed_pairs = createTransformer(int, chunk_size=2).reduce(lambda acc, x: acc + x, initial=0, per_chunk=True)
    # [1,2] -> 3, [3,4] -> 7, [5,6] -> 11
    assert list(summed_pairs(numbers)) == [3, 7, 11]

    summed_triples = createTransformer(int, chunk_size=3).reduce(lambda acc, x: acc + x, initial=0, per_chunk=True)
    # [1,2,3] -> 6, [4,5,6] -> 15
    assert list(summed_triples(numbers)) == [6, 15]

def test_reduce_per_chunk_versus_terminal(self):
    """per_chunk toggles between a terminal callable and a chainable transformer."""
    numbers = [1, 2, 3, 4]

    # per_chunk=False: reduce() hands back a terminal callable over the whole stream.
    terminal_reducer = createTransformer(int, chunk_size=2).reduce(lambda acc, x: acc + x, initial=0, per_chunk=False)
    # Single value: the sum of all elements.
    assert list(terminal_reducer(numbers, None)) == [10]

    # per_chunk=True: reduce() stays a transformer, emitting one value per chunk.
    chunked_reducer = createTransformer(int, chunk_size=2).reduce(
        lambda acc, x: acc + x, initial=0, per_chunk=True
    )
    # [1,2] -> 3 and [3,4] -> 7.
    assert list(chunked_reducer(numbers)) == [3, 7]


class TestTransformerEdgeCases:
"""Test edge cases and boundary conditions."""
Expand Down
Loading