forked from facebookresearch/LayerSkip
-
Notifications
You must be signed in to change notification settings - Fork 0
/
arguments.py
55 lines (46 loc) · 1.47 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union
import transformers
from self_speculation.generator_base import (
GenerationConfig,
)
raw_types = Union[str, float, int, Dict, List, Tuple]
@dataclass
class Arguments:
model: str
model_args: Optional[str] = None
seed: Optional[int] = 42
output_dir: str = "./logs"
# Source: https://github.com/EleutherAI/lm-evaluation-harness/blob/a9eaaf46f1e246e5ce090e37f2f99fe1cfe5a919/lm_eval/utils.py
def simple_parse_args_string(args_string):
"""
Parses something like
args1=val1,arg2=val2
Into a dictionary
"""
args_string = args_string.strip()
if not args_string:
return {}
arg_list = [arg for arg in args_string.split(",") if arg]
args_dict = {
k: handle_arg_string(v) for k, v in [arg.split("=") for arg in arg_list]
}
return args_dict
# Source: https://github.com/EleutherAI/lm-evaluation-harness/blob/a9eaaf46f1e246e5ce090e37f2f99fe1cfe5a919/lm_eval/utils.py
def handle_arg_string(arg):
if arg.lower() == "true":
return True
elif arg.lower() == "false":
return False
elif arg.isnumeric():
return int(arg)
try:
return float(arg)
except ValueError:
return arg