Bug fix - Limit val batches set to 1.0 (NVIDIA#5023)
* Bug fix

Signed-off-by: shanmugamr1992 <shanmugamr1992@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Addressed Sandeep's comments

* Fixing limit val batches support in bert

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixing limit val batches support in bert

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: shanmugamr1992 <shanmugamr1992@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sandeep Subramanian <sandeep.subramanian.1@umontreal.ca>
3 people authored and titu1994 committed Oct 6, 2022
1 parent 9a6725e commit 1a689f7
Showing 4 changed files with 27 additions and 3 deletions.
@@ -99,7 +99,7 @@ def build_train_valid_test_datasets(
         trainer,
         data_prefix["train"],
         data_impl,
-        train_valid_test_num_samples[0],
+        int(train_valid_test_num_samples[0]),
         seq_length,
         seed,
         skip_warmup,
@@ -111,7 +111,7 @@ def build_train_valid_test_datasets(
         trainer,
         data_prefix["validation"],
         data_impl,
-        train_valid_test_num_samples[1],
+        int(train_valid_test_num_samples[1]),
         seq_length,
         seed,
         skip_warmup,
@@ -123,7 +123,7 @@ def build_train_valid_test_datasets(
         trainer,
         data_prefix["test"],
         data_impl,
-        train_valid_test_num_samples[2],
+        int(train_valid_test_num_samples[2]),
         seq_length,
         seed,
         skip_warmup,
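Why the cast is needed: once `limit_val_batches` is a float, the eval-iteration count, and therefore the requested validation sample count, becomes a float, while the dataset builders expect integers. A minimal sketch with made-up numbers (not values from this PR):

```python
# Sketch with hypothetical values showing how a fractional limit_val_batches
# turns the requested validation sample count into a float.
max_train_steps = 1000
val_check_interval = 100
limit_val_batches = 0.5      # fraction of the validation set
global_batch_size = 256

eval_iters = (max_train_steps // val_check_interval + 1) * limit_val_batches
num_val_samples = eval_iters * global_batch_size
print(num_val_samples)       # 1408.0 -> a float, unusable as a sample count

# The dataset utility now casts before using the value, mirroring the diff above:
num_val_samples = int(num_val_samples)
print(num_val_samples)       # 1408
```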
@@ -227,6 +227,8 @@ def process_batch(self, batch):
 
     def _build_train_valid_test_datasets(self):
         logging.info('Building Bert datasets.')
+        if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float):
+            raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")
         global_batch_size = self.trainer.world_size * self.cfg.micro_batch_size / self.cfg.tensor_model_parallel_size
         # Compute trianing micro-batch steps: total_global_batch_steps x grad_acumms_per_global_batch
         max_train_steps = self.trainer.max_steps * self.trainer.accumulate_grad_batches
@@ -238,6 +240,12 @@ def _build_train_valid_test_datasets(self):
             eval_iters * global_batch_size,
             test_iters * global_batch_size,
         ]
+
+        if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float):
+            train_valid_test_num_samples[
+                1
+            ] = 1 # This is to make sure we only have one epoch on every validation iteration
+
         self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets(
             cfg=self.cfg,
             trainer=self.trainer,
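Taken together, the two added blocks behave roughly like the standalone sketch below. The helper name and signature are illustrative only; the models inline this logic rather than calling a shared utility.

```python
# Illustrative restatement of the added guard; not an actual NeMo helper.
from typing import List, Union


def limit_val_batches_check(
    limit_val_batches: Union[int, float], train_valid_test_num_samples: List[int]
) -> List[int]:
    # Reject fractional values above 1.0: they are neither a valid fraction
    # nor an integer batch count.
    if isinstance(limit_val_batches, float) and limit_val_batches > 1.0:
        raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")
    # For a fraction (<= 1.0), request a single validation sample so the
    # validation dataset covers one epoch per validation iteration.
    if isinstance(limit_val_batches, float) and limit_val_batches <= 1.0:
        train_valid_test_num_samples[1] = 1
    return train_valid_test_num_samples
```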
@@ -506,6 +506,8 @@ def process_global_batch(self, global_batch):
 
     def build_train_valid_test_datasets(self):
         logging.info('Building GPT datasets.')
+        if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float):
+            raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")
         global_batch_size = self.cfg.global_batch_size
         max_train_steps = self.trainer.max_steps
         eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
@@ -516,6 +518,12 @@ def build_train_valid_test_datasets(self):
             eval_iters * global_batch_size,
             test_iters * global_batch_size,
         ]
+
+        if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float):
+            train_valid_test_num_samples[
+                1
+            ] = 1 # This is to make sure we only have one epoch on every validation iteration
+
         self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets(
             cfg=self.cfg,
             trainer=self.trainer,
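For context, `limit_val_batches` keeps its usual PyTorch Lightning meaning: an integer is a number of validation batches, a float in [0.0, 1.0] is a fraction of the validation set. The values below are illustrative, not configurations from this PR.

```python
# Illustrative trainer settings and how the new guard treats them.
ok_fraction = {"limit_val_batches": 0.25}  # float <= 1.0: fraction of the validation set
ok_batches = {"limit_val_batches": 50}     # int: run 50 validation batches
bad_value = {"limit_val_batches": 2.0}     # float > 1.0: these models now raise ValueError
```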
@@ -147,6 +147,8 @@ def add_special_tokens_to_tokenizer(cls, tokenizer, tokenizer_cfg, dataset_type=
 
     def build_train_valid_test_datasets(self):
         logging.info(f'Building {self.model_name} datasets.')
+        if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float):
+            raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")
         global_batch_size = self._cfg.global_batch_size
         eval_iters = (self.trainer.max_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
         test_iters = self.trainer.limit_test_batches
@@ -155,6 +157,12 @@ def build_train_valid_test_datasets(self):
             eval_iters * global_batch_size,
             test_iters * global_batch_size,
         ]
+
+        if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float):
+            train_valid_test_num_samples[
+                1
+            ] = 1 # This is to make sure we only have one epoch on every validation iteration
+
         self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets(
             cfg=self._cfg,
             trainer=self.trainer,
