-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathparser.py
159 lines (142 loc) · 4.58 KB
/
parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import DataClassType
# 1. NOTE: Fields with no default value set will be transformed into
# required arguments within the HuggingFace argument parser
# 2. NOTE: Enum-type objects will be transformed into choices
# Pretrained checkpoints accepted for ModelArguments.model_name_or_path
# (membership is validated in ModelArguments.__post_init__).
MODELS = [
    "sentence-transformers/all-mpnet-base-v2",
    "bert-base-uncased",
    "nlpaueb/legal-bert-base-uncased",
    "mukund/privbert",
]
# Task names accepted for DataArguments.task (validated in
# DataArguments.__post_init__); "all" presumably runs every task —
# TODO confirm against the training driver.
TASKS = [
    "opp_115",
    "piextract",
    "policy_detection",
    "policy_ie_a",
    "policy_ie_b",
    "policy_qa",
    "privacy_qa",
    "all",
]
@dataclass
class ExperimentArguments:
    """Settings for multi-seed experiment runs and result aggregation."""

    # How many distinct random seeds to train/evaluate with.
    random_seed_iterations: int = field(
        default=5,
        metadata={"help": "Number of random seed iterations to run"},
    )
    # Whether to aggregate metrics across all of the seeded runs.
    do_summarize: bool = field(
        default=False,
        metadata={"help": "Summarize over all random seeds"},
    )
@dataclass
class ModelArguments:
    """Arguments selecting which pretrained model/config/tokenizer to use.

    Raises:
        ValueError: if ``model_name_or_path`` is not one of the supported
            checkpoints listed in ``MODELS``.
    """

    # No default: becomes a required CLI argument under HfArgumentParser.
    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from "
            "huggingface.co/models"
        }
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as "
            "model_name_or_path"
        },
    )
    model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, "
            "tag name or commit id)."
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as "
            "model_name_or_path"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded "
            "from huggingface.co"
        },
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={
            "help": "Whether to use one of the fast tokenizer (backed by the "
            "tokenizers library) or not"
        },
    )
    early_stopping_patience: int = field(
        default=5, metadata={"help": "Early stopping patience value"}
    )
    do_clean: bool = field(
        default=False, metadata={"help": "Clean all old checkpoints after training"}
    )

    def __post_init__(self) -> None:
        # Raise a real exception instead of using `assert`: assertions are
        # stripped under `python -O`, which would silently skip validation.
        if self.model_name_or_path not in MODELS:
            raise ValueError(
                f"Model '{self.model_name_or_path}' is not supported, "
                f"please select model from {MODELS}"
            )
@dataclass
class DataArguments:
    """Arguments selecting the fine-tuning task and its input data.

    Raises:
        ValueError: if ``data_dir`` is not an existing directory or
            ``task`` is not one of the names listed in ``TASKS``.
    """

    # No default: becomes a required CLI argument under HfArgumentParser.
    task: str = field(metadata={"help": "The name of the task for fine-tuning"})
    # Default resolves to the sibling "data" directory next to this file's
    # parent directory, evaluated once at class-definition time.
    data_dir: str = field(
        default=os.path.relpath(
            os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
        ),
        metadata={"help": "Path to directory containing task input data"},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={
            "help": "Overwrite the cached training, evaluation and prediction sets"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the "
            "number of training examples to this value if set"
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the "
            "number of evaluation examples to this value if set"
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the "
            "number of prediction examples to this value if set"
        },
    )

    def __post_init__(self) -> None:
        # Raise real exceptions instead of using `assert`: assertions are
        # stripped under `python -O`, which would silently skip validation.
        if not os.path.isdir(self.data_dir):
            raise ValueError(f"{self.data_dir} is not a valid directory")
        if self.task not in TASKS:
            raise ValueError(
                f"Task '{self.task}' is not supported, please select task from {TASKS}"
            )
def get_parser() -> HfArgumentParser:
    """Build the CLI parser combining all argument groups.

    Returns:
        An ``HfArgumentParser`` over ``DataArguments``, ``ModelArguments``,
        ``TrainingArguments`` and ``ExperimentArguments``, parsed into
        namespaces in that order.
    """
    # HfArgumentParser expects the dataclass *types* themselves.
    # `DataClassType` is a typing NewType alias in transformers.hf_argparser;
    # calling it is a runtime no-op (it returns its argument unchanged) and
    # is misleading, so pass the classes directly.
    return HfArgumentParser(
        (
            DataArguments,
            ModelArguments,
            TrainingArguments,
            ExperimentArguments,
        )
    )