### **Enhancement: Support for `extra_info` in Reward Calculation** (#266)

#### **Summary**  
This update adds an optional `extra_info` parameter to the reward
computation path, allowing users to pass additional contextual
information when calculating rewards and making scoring more adaptable
to different datasets.

#### **Changes Made**  
- **Updated `_default_compute_score`** to accept an optional `extra_info`
argument (defaulting to `None` so existing callers are unaffected):
  ```python
  def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None):
  ```
- **Modified the reward manager (`naive.py`)** to read `extra_info` from
`data_item.non_tensor_batch` (falling back to `None` if the field is
absent) and pass it through to `compute_score` (a sketch of a custom
scorer that consumes this field follows this list):
  ```python
  extra_info = data_item.non_tensor_batch.get('extra_info', None)
  score = self.compute_score(
      data_source=data_source,
      solution_str=sequences_str,
      ground_truth=ground_truth,
      extra_info=extra_info,
  )
  ```
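For illustration, here is a minimal sketch of a user-defined scoring
function that takes advantage of the new argument. It is not part of this
commit; the function name and the `extra_info` key (`split`) are
hypothetical, chosen to match the dictionary form recommended in the docs.

```python
def my_compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """Toy scorer: substring match, stricter on a hypothetical test split."""
    extra_info = extra_info or {}  # extra_info may be None if the dataset omits it
    prediction = solution_str.strip()
    if extra_info.get('split') == 'test':
        # e.g. require an exact match on held-out data
        return float(prediction == ground_truth.strip())
    # otherwise accept the reference answer anywhere in the generated text
    return float(ground_truth.strip() in prediction)
```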
  
#### **Why This Change?**  
- Some datasets require additional context beyond `data_source`,
`solution_str`, and `ground_truth` for accurate reward computation.
- The new `extra_info` field allows users to pass custom metadata,
ideally in dictionary form, as described in the [official
documentation](https://verl.readthedocs.io/en/latest/preparation/prepare_data.html)
(an illustrative dataset row is sketched after this list).
- This change maintains compatibility with existing dataset processing
scripts, as they already include the `extra_info` field.
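As a concrete illustration of that dictionary form, below is roughly what
a preprocessed dataset row might look like. The field names follow the
data-preparation documentation; the helper name and the placeholder values
are made up for this example.

```python
def make_row(question: str, answer: str, idx: int) -> dict:
    """Hypothetical GSM8K-style row; extra_info carries per-sample metadata."""
    return {
        'data_source': 'openai/gsm8k',
        'prompt': [{'role': 'user', 'content': question}],
        'ability': 'math',
        'reward_model': {'style': 'rule', 'ground_truth': answer},
        # this metadata is what now reaches the scoring function
        'extra_info': {'split': 'train', 'index': idx},
    }
```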

#### **Impact**  
- **Improved flexibility**: Users can now pass additional contextual
information, making reward computation more adaptable to different
datasets.
- **Backward compatibility**: `extra_info` defaults to `None` and is read
with `.get()`, so datasets without the field keep working, and all example
datasets already include `extra_info` (see the call sketch after this list).
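To make the compatibility point concrete, here is a small sketch of calling
the default scorer with and without the new argument (the solution and
ground-truth strings below are made up; the import path is the module
changed in this commit):

```python
from verl.utils.reward_score import _default_compute_score

solution_str = "Step-by-step reasoning... #### 72"
ground_truth = "72"

# Existing call sites keep working: extra_info simply defaults to None.
score_old = _default_compute_score('openai/gsm8k', solution_str, ground_truth)

# New call sites can forward per-sample metadata alongside the other fields.
score_new = _default_compute_score('openai/gsm8k', solution_str, ground_truth,
                                   extra_info={'split': 'train', 'index': 0})
```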

maksimstw authored Feb 17, 2025
1 parent 0c32cf7 commit f0e5bdf
Showing 2 changed files with 4 additions and 1 deletion.
2 changes: 1 addition & 1 deletion verl/utils/reward_score/__init__.py

```diff
@@ -14,7 +14,7 @@
 # from . import gsm8k, math, prime_math, prime_code


-def _default_compute_score(data_source, solution_str, ground_truth):
+def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None):
     if data_source == 'openai/gsm8k':
         from . import gsm8k
         res = gsm8k.compute_score(solution_str, ground_truth)
```
3 changes: 3 additions & 0 deletions verl/workers/reward_manager/naive.py

```diff
@@ -59,10 +59,13 @@ def __call__(self, data: DataProto):

             data_source = data_item.non_tensor_batch['data_source']

+            extra_info = data_item.non_tensor_batch.get('extra_info', None)
+
             score = self.compute_score(
                 data_source=data_source,
                 solution_str=sequences_str,
                 ground_truth=ground_truth,
+                extra_info=extra_info,
             )
             reward_tensor[i, valid_response_length - 1] = score
```
