Skip to content

Commit

Permalink
str kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
Mayank Mishra authored and Mayank Mishra committed Aug 17, 2022
1 parent f3385f2 commit 8f25200
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions inference/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import copy
import json
import math
from typing import Any, List

Expand Down Expand Up @@ -40,12 +41,8 @@ def get_argument_parser() -> argparse.ArgumentParser:
choices=["bf16", "fp16"], help="dtype for model")
group.add_argument(
"--generate_kwargs",
type=dict,
default={
"min_length": 100,
"max_new_tokens": 100,
"do_sample": False
},
type=str,
default='{"min_length": 100, "max_new_tokens": 100, "do_sample": False}',
help="generate parameters. look at https://huggingface.co/docs/transformers/v4.21.1/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate to see the supported parameters"
)

Expand All @@ -55,6 +52,7 @@ def get_argument_parser() -> argparse.ArgumentParser:
def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
args = parser.parse_args()
args.dtype = get_torch_dtype(args.dtype)
args.generate_kwargs = json.loads(args.generate_kwargs)
return args


Expand Down

0 comments on commit 8f25200

Please sign in to comment.