Hotfix/update reward event #71

Merged 2 commits on Nov 13, 2023

Changes from all commits:
prompting/validators/reward/blacklist.py (2 additions, 1 deletion)

@@ -32,7 +32,8 @@

 @dataclass
 class BlacklistRewardEvent(BaseRewardEvent):
-    is_filter_model: bool = True
+    matched_ngram: str = None
+    significance_score: float = None


 class Blacklist(BaseRewardModel):

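For context, the two new fields let a blacklist event report which n-gram tripped the filter and how significant the match was. Below is a minimal sketch of constructing such an event, using a trimmed stand-in for BaseRewardEvent and made-up values (the real call sites are outside this diff):

from dataclasses import dataclass


@dataclass
class BaseRewardEvent:  # trimmed stand-in; the real class is in reward.py
    reward: float = 1.0
    normalized_reward: float = None


@dataclass
class BlacklistRewardEvent(BaseRewardEvent):
    matched_ngram: str = None
    significance_score: float = None


# Hypothetical hit: record what matched and zero out the reward.
event = BlacklistRewardEvent(
    reward=0.0,
    matched_ngram="example banned phrase",
    significance_score=2.5,
)
print(event)
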
prompting/validators/reward/diversity.py (1 addition, 2 deletions)

@@ -51,8 +51,7 @@ def mean_pooling(model_output, attention_mask):
 @dataclass
 class DiversityRewardEvent(BaseRewardEvent):
     historic: float = None
-    batch: float = None
-    is_filter_model: bool = True
+    batch: float = None


 class DiversityRewardModel(BaseRewardModel):

prompting/validators/reward/nsfw.py (1 addition, 2 deletions)

@@ -26,8 +26,7 @@

 @dataclass
 class NSFWRewardEvent(BaseRewardEvent):
-    score: float = None
-    is_filter_model: bool = True
+    score: float = None


 class NSFWRewardModel(BaseRewardModel):

prompting/validators/reward/relevance.py (1 addition, 2 deletions)

@@ -51,8 +51,7 @@ def mean_pooling(model_output, attention_mask):
 @dataclass
 class RelevanceRewardEvent(BaseRewardEvent):
     bert_score: float = None
-    mpnet_score: float = None
-    is_filter_model: bool = True
+    mpnet_score: float = None


 class RelevanceRewardModel(BaseRewardModel):

prompting/validators/reward/reward.py (9 additions, 19 deletions)

@@ -25,27 +25,17 @@

 @dataclass
 class BaseRewardEvent:
-    reward: float = None
-    normalized_reward: float = None
-    is_filter_model: bool = False
+    reward: float = 1.0
+    normalized_reward: float = None

     @staticmethod
     def parse_reward_events(reward_events) -> List[dict]:
         """Parse each reward event and ensure that values are not None."""

-        parsed_events = {f.name: [] for f in fields(reward_events[0])}
-        for i, event in enumerate(reward_events):
-            for field in fields(event):
-                value = getattr(event, field.name)
-
-                # Ensure that the reward is not None.
-                if field.name == 'reward' and value in (None, torch.nan, torch.inf):
-                    bt.logging.warning(f"Reward for {event.__class__.__name__} index {i} is {value}, setting to {event.is_filter_model}")
-                    value = 1 if event.is_filter_model else 0
-
-                parsed_events[field.name].append(value)
-
-        return parsed_events
+        field_names = [field.name for field in fields(reward_events[0])]
+        reward_events = [
+            asdict(reward_event).values() for reward_event in reward_events
+        ]
+        reward_event = dict(zip(field_names, list(zip(*reward_events))))
+        return reward_event


 class BaseRewardModel:
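
Two things change in the hunk above: BaseRewardEvent.reward now defaults to 1.0 instead of None, and parse_reward_events becomes a plain transpose over asdict dumps rather than a loop that patched None/NaN rewards via is_filter_model. A minimal runnable sketch of what the new parsing returns, using a trimmed two-field event (expected output shown as a comment):

from dataclasses import dataclass, asdict, fields


@dataclass
class BaseRewardEvent:
    reward: float = 1.0
    normalized_reward: float = None

    @staticmethod
    def parse_reward_events(reward_events):
        # One dict key per field, each holding that field's values
        # across all events, in field-definition order.
        field_names = [field.name for field in fields(reward_events[0])]
        reward_events = [
            asdict(reward_event).values() for reward_event in reward_events
        ]
        return dict(zip(field_names, list(zip(*reward_events))))


events = [BaseRewardEvent(reward=0.5), BaseRewardEvent()]  # second keeps the 1.0 default
print(BaseRewardEvent.parse_reward_events(events))
# {'reward': (0.5, 1.0), 'normalized_reward': (None, None)}
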
@@ -181,7 +171,7 @@ def apply(
         # Warns unexpected behavior for rewards
         if torch.isnan(filled_rewards_normalized).any():
             bt.logging.warning(f"The tensor from {self.name} contains NaN values: {filled_rewards_normalized}")
-
+        filled_rewards_normalized.nan_to_num_(nan=0.0)

         # Return the filled rewards.
         return filled_rewards_normalized, reward_events
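
The added nan_to_num_ call is the in-place torch.Tensor method: after logging the warning, any NaN entries are zeroed so downstream consumers never see NaN. A standalone sketch of the behavior:

import torch

rewards = torch.tensor([0.2, float("nan"), 0.8])
if torch.isnan(rewards).any():
    print(f"tensor contains NaN values: {rewards}")
rewards.nan_to_num_(nan=0.0)  # in-place replacement: NaN -> 0.0
print(rewards)  # tensor([0.2000, 0.0000, 0.8000])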