Commit 0d7f3cf
fix wandb parameters
1 parent 9292e2a

1 file changed: delft/applications/grobidTagger.py (+27 −31)
@@ -168,7 +168,7 @@ def configure(model, architecture, output_path=None, max_sequence_length=-1, bat
 def train(model, embeddings_name=None, architecture=None, transformer=None, input_path=None,
           output_path=None, features_indices=None, max_sequence_length=-1, batch_size=-1, max_epoch=-1,
           use_ELMo=False, incremental=False, input_model_path=None, patience=-1, learning_rate=None, early_stop=None, multi_gpu=False,
-          wandb_config=None):
+          report_to_wandb=False):
 
     print('Loading data...')
     if input_path == None:
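This first hunk swaps the dict-valued `wandb_config` parameter for a plain boolean, `report_to_wandb`, the same keyword that `Sequence` accepts further down. A minimal before/after sketch of a caller; the model name and abbreviated arguments are illustrative, while the project dict is the one removed later in this diff:

    # before this commit: callers built a config dict
    train("citation", embeddings_name="glove-840B",
          wandb_config={"project": "delft-grobidTagger"})

    # after this commit: callers pass a boolean flag
    train("citation", embeddings_name="glove-840B",
          report_to_wandb=True)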
@@ -216,7 +216,7 @@ def train(model, embeddings_name=None, architecture=None, transformer=None, inpu
         early_stop=early_stop,
         patience=patience,
         learning_rate=learning_rate,
-        report_to_wandb=wandb_config
+        report_to_wandb=report_to_wandb
     )
 
     if incremental:
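Inside `train` the keyword was already `report_to_wandb`; the bug was feeding it the dict (or `None`). Since any non-empty dict is truthy, a boolean-typed flag downstream would presumably still switch logging on while the dict's payload, such as the project name, went unread. A self-contained illustration of that mismatch (not DeLFT code):

    # illustrative only: a boolean flag fed a dict "works" by truthiness,
    # but the dict's contents are never read
    def trainer(report_to_wandb=False):
        if report_to_wandb:
            print("logging to wandb")

    trainer(report_to_wandb={"project": "delft-grobidTagger"})  # truthy, project ignored
    trainer(report_to_wandb=True)                               # what the signature now expects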
@@ -243,9 +243,10 @@ def train(model, embeddings_name=None, architecture=None, transformer=None, inpu
 # split data, train a GROBID model and evaluate it
 def train_eval(model, embeddings_name=None, architecture='BidLSTM_CRF', transformer=None,
                input_path=None, output_path=None, fold_count=1,
-               features_indices=None, max_sequence_length=-1, batch_size=-1, max_epoch=-1, 
+               features_indices=None, max_sequence_length=-1, batch_size=-1, max_epoch=-1,
                use_ELMo=False, incremental=False, input_model_path=None, patience=-1,
-               learning_rate=None, early_stop=None, multi_gpu=False, wandb_config=None):
+               learning_rate=None, early_stop=None, multi_gpu=False,
+               report_to_wandb=False):
 
     print('Loading data...')
     if input_path is None:
@@ -274,11 +275,12 @@ def train_eval(model, embeddings_name=None, architecture='BidLSTM_CRF', transfor
                                         use_ELMo,
                                         patience,
                                         early_stop)
+
     model = Sequence(model_name, architecture=architecture, embeddings_name=embeddings_name,
                      max_sequence_length=max_sequence_length, recurrent_dropout=0.50, batch_size=batch_size,
                      learning_rate=learning_rate, max_epoch=max_epoch, early_stop=early_stop, patience=patience,
                      use_ELMo=use_ELMo, fold_number=fold_count, multiprocessing=multiprocessing,
-                     features_indices=features_indices, transformer_name=transformer, report_to_wandb=wandb_config)
+                     features_indices=features_indices, transformer_name=transformer, report_to_wandb=report_to_wandb)
 
     if incremental:
         if input_model_path != None:
@@ -310,7 +312,7 @@ def train_eval(model, embeddings_name=None, architecture='BidLSTM_CRF', transfor
 
 
 # split data, train a GROBID model and evaluate it
-def eval_(model, input_path=None, architecture='BidLSTM_CRF', use_ELMo=False):
+def eval_(model, input_path=None, architecture='BidLSTM_CRF', use_ELMo=False, report_to_wandb=False):
     print('Loading data...')
     if input_path is None:
         # it should never be the case
@@ -331,7 +333,7 @@ def eval_(model, input_path=None, architecture='BidLSTM_CRF', use_ELMo=False):
     start_time = time.time()
 
     # load the model
-    model = Sequence(model_name)
+    model = Sequence(model_name, report_to_wandb=report_to_wandb)
     model.load()
 
     # evaluation
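`eval_` gains the same flag and now forwards it when constructing `Sequence` ahead of `model.load()`. A hypothetical call, with an illustrative model name and data path:

    eval_("date",
          input_path="data/sequenceLabelling/grobid/date/date.eval",
          architecture="BidLSTM_CRF",
          report_to_wandb=True)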
@@ -478,12 +480,6 @@ class Tasks:
     # default word embeddings
     embeddings_name = "glove-840B"
 
-    wandb_config = None
-    if wandb:
-        wandb_config = {
-            "project": "delft-grobidTagger"
-        }
-
     if action == Tasks.TRAIN:
         train(
             model,
@@ -502,7 +498,7 @@ class Tasks:
             max_epoch=max_epoch,
             early_stop=early_stop,
             multi_gpu=multi_gpu,
-            wandb_config=wandb_config
+            report_to_wandb=wandb
         )
 
     if action == Tasks.EVAL:
@@ -516,23 +512,23 @@ class Tasks:
     if action == Tasks.TRAIN_EVAL:
         if args.fold_count < 1:
             raise ValueError("fold-count should be equal or more than 1")
-        train_eval(model,
-                   embeddings_name=embeddings_name,
-                   architecture=architecture,
-                   transformer=transformer,
-                   input_path=input_path,
-                   output_path=output,
-                   fold_count=args.fold_count,
-                   max_sequence_length=max_sequence_length,
-                   batch_size=batch_size,
-                   use_ELMo=use_ELMo,
-                   incremental=incremental,
-                   input_model_path=input_model_path,
-                   learning_rate=learning_rate,
-                   max_epoch=max_epoch,
-                   early_stop=early_stop,
-                   multi_gpu=multi_gpu,
-                   wandb_config=wandb_config)
+        train_eval(model,
+                   embeddings_name=embeddings_name,
+                   architecture=architecture,
+                   transformer=transformer,
+                   input_path=input_path,
+                   output_path=output,
+                   fold_count=args.fold_count,
+                   max_sequence_length=max_sequence_length,
+                   batch_size=batch_size,
+                   use_ELMo=use_ELMo,
+                   incremental=incremental,
+                   input_model_path=input_model_path,
+                   learning_rate=learning_rate,
+                   max_epoch=max_epoch,
+                   early_stop=early_stop,
+                   multi_gpu=multi_gpu,
+                   report_to_wandb=wandb)
 
     if action == Tasks.TAG:
         someTexts = []
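One side effect worth noting: the removed block was the only place in this file that pinned the wandb project name ("delft-grobidTagger"), so unless `Sequence` sets a project itself (not visible in this diff), runs fall back to wandb's defaults. Those defaults can be overridden with wandb's standard `WANDB_PROJECT` environment variable; the `--wandb` flag and argument order below are assumed from the script's `wandb` boolean and its usual model/action positionals:

    export WANDB_PROJECT=delft-grobidTagger
    python3 delft/applications/grobidTagger.py citation train --wandb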
