diff --git a/inference/utils/utils.py b/inference/utils/utils.py
index 1d8c39c2c..c29c8b7d9 100644
--- a/inference/utils/utils.py
+++ b/inference/utils/utils.py
@@ -1,5 +1,6 @@
 import argparse
 import copy
+import json
 import math
 from typing import Any, List
 
@@ -40,12 +41,8 @@ def get_argument_parser() -> argparse.ArgumentParser:
                        choices=["bf16", "fp16"],
                        help="dtype for model")
     group.add_argument(
         "--generate_kwargs",
-        type=dict,
-        default={
-            "min_length": 100,
-            "max_new_tokens": 100,
-            "do_sample": False
-        },
+        type=str,
+        default='{"min_length": 100, "max_new_tokens": 100, "do_sample": false}',
         help="generate parameters. look at https://huggingface.co/docs/transformers/v4.21.1/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate to see the supported parameters"
     )
 
@@ -55,6 +52,7 @@ def get_argument_parser() -> argparse.ArgumentParser:
 
 def get_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
     args = parser.parse_args()
     args.dtype = get_torch_dtype(args.dtype)
+    args.generate_kwargs = json.loads(args.generate_kwargs)
     return args
 