Skip to content

Commit

Permalink
fix data dtype for amp training
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangting2020 committed Apr 26, 2023
1 parent 0af4680 commit afc301f
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion ppcls/static/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,14 @@ def build(config,
mode = "Train" if is_train else "Eval"
use_mix = "batch_transform_ops" in config["DataLoader"][mode][
"dataset"]
data_dtype = "float32"
if 'AMP' in config and config["AMP"]["level"] == 'O2':
data_dtype = "float16"
feeds = create_feeds(
config["Global"]["image_shape"],
use_mix,
class_num=class_num,
dtype="float32")
dtype=data_dtype)

# build model
# data_format should be assigned in arch-dict
Expand Down

0 comments on commit afc301f

Please sign in to comment.