From 9b3eb81014dd3be8c8934942900422eaea9f0de5 Mon Sep 17 00:00:00 2001 From: kmckiern Date: Fri, 2 Sep 2022 04:46:31 -0700 Subject: [PATCH] if learning rate is a tensor, get item (float) (#18861) --- src/transformers/trainer_pt_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 57103b50d5a039..7baa7a46e95932 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -837,6 +837,8 @@ def _get_learning_rate(self): if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.4") else self.lr_scheduler.get_lr()[0] ) + if torch.is_tensor(last_lr): + last_lr = last_lr.item() return last_lr