# train.yaml
# lightning.pytorch==2.1.0dev
#
# Training configuration in Lightning CLI layout: each top-level key is one
# CLI section, and `class_path` / `init_args` pairs select a class and its
# constructor arguments.
#
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation had been stripped. The `trainer` section follows the
# Lightning Trainer/ModelCheckpoint argument names; the placement of keys
# inside `model` and `data` is inferred from key order and should be confirmed
# against the project's model and data-module signatures.
seed_everything: 42

trainer:
  # logger:
  #   class_path: CustomWandbLogger
  #   init_args:
  #     save_dir: logs/sparse-cross-encoder
  #     project: sparse-cross-encoder
  callbacks:
    class_path: ModelCheckpoint
    init_args:
      # Same step count as val_check_interval below.
      every_n_train_steps: 5000
  max_steps: 100000
  val_check_interval: 5000

model:
  model_name_or_path: cross-encoder/ms-marco-MiniLM-L-6-v2
  config:
    class_path: SparseCrossEncoderConfig
    init_args:
      max_position_embeddings: 512
      attention_window_size: 4
      # Presumably each <a>_<b>_attention flag toggles attention from <a>
      # tokens to <b> tokens -- TODO confirm in SparseCrossEncoderConfig.
      cls_query_attention: true
      cls_doc_attention: true
      query_cls_attention: false
      query_doc_attention: false
      doc_cls_attention: true
      doc_query_attention: true
  # NOTE(review): assumed to be model-level arguments (siblings of `config`).
  loss_function:
    class_path: MarginMSE
    init_args:
      reduction: mean
  compile_model: true

data:
  class_path: SparseCrossEncoderDataModule
  init_args:
    ir_dataset_path: msmarco-passage/train/kd-docpairs
    # NOTE(review): assumed to be data-module init_args -- confirm.
    truncate: true
    # Same value as model.config.init_args.max_position_embeddings.
    max_length: 512
    batch_size: 32

optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 7.0e-06

lr_scheduler:
  class_path: ConstantSchedulerWithWarmup
  init_args:
    num_warmup_steps: 1000