Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

delay filter init; remove *args #1369

Merged
merged 4 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions lm_eval/api/filter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List
from typing import Callable, Iterable, List, Union

from lm_eval.api.instance import Instance

Expand All @@ -14,13 +14,13 @@ class Filter(ABC):

"""

def __init__(self, *args, **kwargs) -> None:
def __init__(self, **kwargs) -> None:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""

@abstractmethod
def apply(self, resps, docs):
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
Expand All @@ -40,15 +40,16 @@ class FilterEnsemble:
"""

name: str
filters: List[Filter]
filters: List[Callable[[], Filter]]

def apply(self, instances: List[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
resps, docs = list(resps), list(docs)

for f in self.filters:
# apply filters in sequence
resps = f.apply(resps, docs)
f_init = f()
resps = f_init.apply(resps, docs)
baberabb marked this conversation as resolved.
Show resolved Hide resolved

# add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
Expand Down
12 changes: 5 additions & 7 deletions lm_eval/filters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List
from typing import List, Union
from functools import partial

from lm_eval.api.filter import FilterEnsemble
from . import selection
Expand All @@ -22,7 +23,7 @@
}


def get_filter(filter_name):
def get_filter(filter_name: str) -> Union[type, str]:
if filter_name in FILTER_REGISTRY:
return FILTER_REGISTRY[filter_name]
else:
Expand All @@ -37,11 +38,8 @@ def build_filter_ensemble(
"""
filters = []
for function, kwargs in components:
if kwargs is None:
f = get_filter(function)()
else:
# create a filter given its name in the registry
f = get_filter(function)(**kwargs) # TODO: pass kwargs to filters properly
# create a filter given its name in the registry
f = partial(get_filter(function), **kwargs)
# add the filter as a pipeline step
filters.append(f)

Expand Down
6 changes: 4 additions & 2 deletions lm_eval/filters/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ def apply(self, resps, docs):


class TakeKFilter(Filter):
def __init__(self, *args, **kwargs) -> None:
def __init__(self, **kwargs) -> None:
self.k = kwargs.pop("k")

super().__init__(*args, **kwargs)
super().__init__(**kwargs)

def apply(self, resps, docs):
# need resp to be subscriptable to check below
resps = list(resps)
# check we have at least k responses per doc, else we can't take the first k
assert (
len(resps[0]) >= self.k
Expand Down
Loading