#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import DataClassType
# 1. NOTE: Fields with no default value set will be transformed
# into required arguments within the HuggingFace argument parser
# 2. NOTE: Enum-type objects will be transformed into choices
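# Supported model identifiers, each mapped to a pinned Hugging Face Hub commit
# hash that is used as the default model_revision (see ModelArguments below).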
MODELS = {
"bert-base-uncased": "5546055f03398095e385d7dc625e636cc8910bf2",
"roberta-base": "ff46155979338ff8063cdad90908b498ab91b181",
"nlpaueb/legal-bert-base-uncased": "15b570cbf88259610b082a167dacc190124f60f6",
"saibo/legal-roberta-base": "e0d78f4e064ff27621d61fa2320c79addb528d81",
"mukund/privbert": "48228b4661fa8252bdb39ca44a4d9758f6b37f88",
}
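# Supported fine-tuning tasks; "all" is expected to select every benchmark task.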
TASKS = [
"opp_115",
"piextract",
"policy_detection",
"policy_ie_a",
"policy_ie_b",
"policy_qa",
"privacy_qa",
"all",
]


@dataclass
class ExperimentArguments:
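    """Arguments controlling random-seed iterations and result summarization."""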
random_seed_iterations: int = field(
default=10, metadata={"help": "Number of random seed iterations to run"}
)
do_summarize: bool = field(
default=False, metadata={"help": "Summarize over all random seeds"}
)


@dataclass
class ModelArguments:
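    """Arguments specifying the pretrained model, tokenizer and related settings."""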
model_name_or_path: str = field(
metadata={
"help": "Path to pretrained model or model identifier from "
"huggingface.co/models"
}
)
config_name: Optional[str] = field(
default=None,
metadata={
"help": "Pretrained config name or path if not the same as "
"model_name_or_path"
},
)
model_revision: Optional[str] = field(
default=None,
metadata={
"help": "The specific model version to use (can be a branch name, "
"tag name or commit id)."
},
)
tokenizer_name: Optional[str] = field(
default=None,
metadata={
"help": "Pretrained tokenizer name or path if not the same as "
"model_name_or_path"
},
)
cache_dir: Optional[str] = field(
default=None,
metadata={
"help": "Where do you want to store the pretrained models downloaded "
"from huggingface.co"
},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={
"help": "Whether to use one of the fast tokenizer (backed by the "
"tokenizers library) or not"
},
)
early_stopping_patience: Optional[int] = field(
default=None, metadata={"help": "Early stopping patience value"}
)
do_clean: bool = field(
default=False, metadata={"help": "Clean all old checkpoints after training"}
)

    def __post_init__(self):
assert self.model_name_or_path in MODELS, (
f"Model '{self.model_name_or_path}' is not supported, "
f"please select model from {MODELS}"
)
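        # Default to the pinned commit hash for the selected model.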
if self.model_revision is None:
self.model_revision = MODELS[self.model_name_or_path]


@dataclass
class DataArguments:
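    """Arguments controlling the task, input data and preprocessing."""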
task: str = field(metadata={"help": "The name of the task for fine-tuning"})
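    # The default data_dir resolves to the repository-level "data" directory,
    # relative to this file.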
data_dir: str = field(
default=os.path.relpath(
os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
),
metadata={"help": "Path to directory containing task input data"},
)
overwrite_cache: bool = field(
default=False,
metadata={
"help": "Overwrite the cached training, evaluation and prediction sets"
},
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_seq_length: int = field(
default=512,
metadata={
"help": (
"The maximum total input sequence length after tokenization. "
"Sequences longer than this will be truncated, sequences "
"shorter will be padded."
)
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
"help": (
"Whether to pad all samples to `max_seq_length`. "
"If False, will pad the samples dynamically when "
"batching to the maximum length in the batch "
"(which can be faster on GPU but will be slower on TPU)."
)
},
)
doc_stride: int = field(
default=128,
metadata={
"help": "When splitting up a long document into chunks, "
"how much stride to take between chunks."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the "
"number of training examples to this value if set"
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the "
"number of evaluation examples to this value if set"
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the "
"number of prediction examples to this value if set"
},
)
n_best_size: int = field(
default=20,
metadata={
"help": "PolicyQA: The total number of n-best predictions to generate when "
"looking for an answer."
},
)
max_answer_length: int = field(
default=30,
metadata={
"help": (
"PolicyQA: The maximum length of an answer that can be generated. "
"This is needed because the start and end predictions "
"are not conditioned on one another."
)
},
)

    def __post_init__(self):
assert os.path.isdir(self.data_dir), f"{self.data_dir} is not a valid directory"
assert (
self.task in TASKS
), f"Task '{self.task}' is not supported, please select task from {TASKS}"


def get_parser() -> HfArgumentParser:
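    """Build an HfArgumentParser over the argument dataclasses defined above."""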
return HfArgumentParser(
(
DataClassType(DataArguments),
DataClassType(ModelArguments),
DataClassType(TrainingArguments),
DataClassType(ExperimentArguments),
)
)
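

# Example usage (illustrative sketch; the actual entry point may differ):
#
#     data_args, model_args, training_args, experiment_args = (
#         get_parser().parse_args_into_dataclasses()
#     )
#
# parse_args_into_dataclasses is HfArgumentParser's method for turning the
# parsed command-line arguments into instances of the dataclasses above, in
# the order they were registered.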