-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patharguments.py
68 lines (51 loc) · 1.97 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    ``model_name_or_path`` has no default and must always be supplied; every
    other field falls back to the default declared below.

    NOTE(review): the ``help`` metadata suggests these dataclasses are consumed
    by a type-driven command-line parser (e.g. transformers' HfArgumentParser)
    -- confirm against the training script.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )

    # The three loss weights are annotated as floats rather than ints: a
    # parser that derives the argument type from the annotation would
    # otherwise reject fractional weights such as 0.5. The defaults stay
    # numerically equal to the previous integer default of 1.
    alpha: Optional[float] = field(
        default=1.0, metadata={"help": "The weight of Similarity loss."}
    )
    beta: Optional[float] = field(
        default=1.0, metadata={"help": "The weight of Classification loss."}
    )
    gamma: Optional[float] = field(
        default=1.0, metadata={"help": "The weight of Cosine Embedding loss."}
    )

    momentum_rate: Optional[float] = field(
        default=0.0, metadata={"help": "The rate of Momentum update."}
    )
    centroid: Optional[bool] = field(
        default=False, metadata={"help": "If use centroid"}
    )
    # Name of the classification loss; how the string is interpreted is up to
    # the consuming code, which is not part of this file.
    cls_loss: Optional[str] = field(
        default='CrossEntropyLoss', metadata={"help": "If use BCELoss else CrossEntropyLoss"}
    )
@dataclass
class DataTrainingArguments:
    """
    Options describing the data fed to the model for training and evaluation.

    Only ``dataset_name`` is required; the remaining fields carry the
    defaults declared below.
    """

    # --- dataset selection -------------------------------------------------
    dataset_name: str = field(
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )

    # --- sampling ----------------------------------------------------------
    subsampling_rate: float = field(
        default=1.0,
        metadata={"help": "The rate of subsampling, default is 1.0 (Not subsampling)."},
    )
    stratified_sampling: bool = field(
        default=False,
        metadata={"help": "If stratified sampling else random sampling."},
    )

    # --- input construction ------------------------------------------------
    with_example: bool = field(
        default=True,
        metadata={"help": "Semantic label with one sentence example."},
    )
    text_max_length: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum total input sequence length after tokenization."},
    )

    # --- preprocessing -----------------------------------------------------
    num_proc: Optional[int] = field(
        default=8,
        metadata={"help": "The num proc in dataset.map()."},
    )