Fix for num_workers=0 causing issues in losses after 1 epoch #5384

Merged
@@ -364,6 +364,7 @@ def collate_fn(self, batch, tp_workers=0):
     def pad_batch_and_build_loss_mask(self, input_ids, batch_max, answer_starts):
         """ Pad input_ids in batch to max batch length while building loss mask """
         batch_loss_masks = []
+        padded_input_ids = []
         for ids, answer_start_idx in zip(input_ids, answer_starts):
             if answer_start_idx is not None:
                 # Loss mask where answer tokens are 1.0 and all other tokens are 0.0
@@ -375,17 +376,19 @@ def pad_batch_and_build_loss_mask(self, input_ids, batch_max, answer_starts):
             # Pad to max length
             input_length = len(ids)
             padding_length = batch_max - input_length
-            ids.extend([self.pad_token_id] * padding_length)
+            pad_extend = [self.pad_token_id] * padding_length
+            ids = ids + pad_extend
+            padded_input_ids.append(ids)

             # Account for padding in loss mask
             loss_mask.extend([0.0] * padding_length)
             batch_loss_masks.append(torch.tensor(loss_mask, dtype=torch.float))

         # Make into torch tensors
-        input_ids = torch.tensor(input_ids, dtype=torch.long)
+        padded_input_ids = torch.tensor(padded_input_ids, dtype=torch.long)
         batch_loss_masks = torch.stack(batch_loss_masks)

-        return input_ids, batch_loss_masks
+        return padded_input_ids, batch_loss_masks

     def inference_collate_fn(self, batch):
         """
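Why the change matters, as a minimal standalone sketch (not NeMo code): with num_workers=0 the collate_fn runs in the main process and receives references to the lists the dataset stores, so the old ids.extend(...) call appended pad tokens to the cached examples themselves. The cached samples then no longer match the lengths and answer offsets computed from them on later epochs, which is what surfaced as broken losses after the first epoch; with num_workers > 0 the mutation happens on per-worker copies of the dataset, so only num_workers=0 was affected. The ToyDataset, buggy_collate, and fixed_collate names below are hypothetical and exist only to show the difference between in-place padding and building a new list, as the PR does with padded_input_ids.

import torch
from torch.utils.data import DataLoader, Dataset

PAD_ID = 0  # hypothetical pad token id


class ToyDataset(Dataset):
    """Hypothetical dataset that caches its examples as Python lists."""

    def __init__(self):
        # These lists are returned by reference from __getitem__ on every epoch.
        self.examples = [[1, 2, 3], [4, 5, 6, 7]]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


def buggy_collate(batch):
    # Mirrors the old behaviour: pads the sample lists in place.
    batch_max = max(len(ids) for ids in batch)
    for ids in batch:
        ids.extend([PAD_ID] * (batch_max - len(ids)))  # mutates the cached list
    return torch.tensor(batch, dtype=torch.long)


def fixed_collate(batch):
    # Mirrors the fix: builds new padded lists and leaves the cache untouched.
    batch_max = max(len(ids) for ids in batch)
    padded = [ids + [PAD_ID] * (batch_max - len(ids)) for ids in batch]
    return torch.tensor(padded, dtype=torch.long)


dataset = ToyDataset()
loader = DataLoader(dataset, batch_size=2, num_workers=0, collate_fn=buggy_collate)
for _ in loader:
    pass
print(dataset.examples)  # [[1, 2, 3, 0], [4, 5, 6, 7]] -- the cached example was mutated

dataset = ToyDataset()
loader = DataLoader(dataset, batch_size=2, num_workers=0, collate_fn=fixed_collate)
for _ in loader:
    pass
print(dataset.examples)  # [[1, 2, 3], [4, 5, 6, 7]] -- the cache stays intact

After one pass with buggy_collate the stored examples are already padded, while fixed_collate keeps them at their original lengths across epochs.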
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import enum
 import json

 import torch