Skip to content

Commit

Permalink
fix: correct parsing of command line bool arg values
Browse files Browse the repository at this point in the history
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
  • Loading branch information
tjohnson31415 authored and dtrifiro committed Aug 14, 2024
1 parent 977103a commit 50b020b
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions src/vllm_tgis_adapter/tgis_utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@ def _to_env_var(arg_name: str) -> str:
return arg_name.upper().replace("-", "_")


def _bool_from_string(val: str) -> bool:
return val.lower().strip() == "true" or val == "1"


def _switch_action_default(action: argparse.Action) -> None:
"""Switch to using env var fallback for all args."""
env_val = os.environ.get(_to_env_var(action.dest))
if not env_val:
return

val: bool | str | int
if action.type is bool:
val = env_val.lower() == "true" or env_val == "1"
if action.type in [bool, _bool_from_string]:
val = _bool_from_string(env_val)
elif action.type is int:
val = int(env_val)
else:
Expand Down Expand Up @@ -111,9 +115,11 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
# map to tensor_parallel_size
parser.add_argument("--num-shard", type=int)
# TODO check boolean behaviour for env vars and defaults
parser.add_argument("--output-special-tokens", type=bool, default=False)
parser.add_argument(
"--default-include-stop-seqs", type=bool, default=True
"--output-special-tokens", type=_bool_from_string, default=False
)
parser.add_argument(
"--default-include-stop-seqs", type=_bool_from_string, default=True
) # TODO TBD
parser.add_argument("--grpc-port", type=int, default=8033)

Expand All @@ -135,9 +141,13 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument("--speculator-n-candidates", type=int)
parser.add_argument("--speculator-max-batch-size", type=int)
# allow re-enabling vllm native per-request logging
parser.add_argument("--enable-vllm-log-requests", type=bool, default=False)
parser.add_argument(
"--enable-vllm-log-requests", type=_bool_from_string, default=False
)
# set to true to disable producing prompt logprobs on all requests
parser.add_argument("--disable-prompt-logprobs", type=bool, default=False)
parser.add_argument(
"--disable-prompt-logprobs", type=_bool_from_string, default=False
)

# TODO check/add other args here

Expand Down

0 comments on commit 50b020b

Please sign in to comment.