Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add model "INF Outcome Reward Model" #208

Merged
merged 3 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion rewardbench/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
MixtralForCausalLM,
T5ForConditionalGeneration,
)

from .inform import INFORMForSequenceClassification
from .armorm import ArmoRMPipeline
from .beaver import BeaverCostPipeline, BeaverPipeline, LlamaForScore
from .betterpairrm import BetterPairRMPipeline
Expand Down Expand Up @@ -222,6 +222,14 @@
"custom_dialogue": False,
"model_type": "Seq. Classifier",
},
"infly/INF-ORM-Llama3.1-70B": {
"model_builder": INFORMForSequenceClassification.from_pretrained,
"pipeline_builder": RewardBenchPipeline,
"quantized": False,
"custom_dialogue": False,
"model_type": "Seq. Classifier",
"torch_dtype": torch.bfloat16,
},
}

DPO_MODEL_CONFIG = {
Expand Down
87 changes: 87 additions & 0 deletions rewardbench/models/inform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2024 AllenAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
import torch
import torch.nn as nn
from transformers import LlamaPreTrainedModel, LlamaModel
from transformers.modeling_outputs import SequenceClassifierOutputWithPast


class INFORMForSequenceClassification(LlamaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = LlamaModel(config)
self.score = nn.Sequential(
nn.Linear(config.hidden_size, config.hidden_size),
nn.ReLU(),
nn.Linear(config.hidden_size, 1)
)

# Initialize weights and apply final processing
self.post_init()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):

transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)

if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]

if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1

pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

loss = None
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
Loading