@@ -168,7 +168,7 @@ def configure(model, architecture, output_path=None, max_sequence_length=-1, bat
 def train(model, embeddings_name=None, architecture=None, transformer=None, input_path=None,
           output_path=None, features_indices=None, max_sequence_length=-1, batch_size=-1, max_epoch=-1,
           use_ELMo=False, incremental=False, input_model_path=None, patience=-1, learning_rate=None, early_stop=None, multi_gpu=False,
-          wandb_config=None):
+          report_to_wandb=False):
 
     print('Loading data...')
     if input_path == None:
@@ -216,7 +216,7 @@ def train(model, embeddings_name=None, architecture=None, transformer=None, inpu
                     early_stop=early_stop,
                     patience=patience,
                     learning_rate=learning_rate,
-                    report_to_wandb=wandb_config
+                    report_to_wandb=report_to_wandb
     )
 
     if incremental:
@@ -243,9 +243,10 @@ def train(model, embeddings_name=None, architecture=None, transformer=None, inpu
 # split data, train a GROBID model and evaluate it
 def train_eval(model, embeddings_name=None, architecture='BidLSTM_CRF', transformer=None,
                input_path=None, output_path=None, fold_count=1,
-               features_indices=None, max_sequence_length=-1, batch_size=-1, max_epoch=-1,
+               features_indices=None, max_sequence_length=-1, batch_size=-1, max_epoch=-1,
                use_ELMo=False, incremental=False, input_model_path=None, patience=-1,
-               learning_rate=None, early_stop=None, multi_gpu=False, wandb_config=None):
+               learning_rate=None, early_stop=None, multi_gpu=False,
+               report_to_wandb=False):
 
     print('Loading data...')
     if input_path is None:
@@ -274,11 +275,12 @@ def train_eval(model, embeddings_name=None, architecture='BidLSTM_CRF', transfor
                                                     use_ELMo,
                                                     patience,
                                                     early_stop)
+
     model = Sequence(model_name, architecture=architecture, embeddings_name=embeddings_name,
                     max_sequence_length=max_sequence_length, recurrent_dropout=0.50, batch_size=batch_size,
                     learning_rate=learning_rate, max_epoch=max_epoch, early_stop=early_stop, patience=patience,
                     use_ELMo=use_ELMo, fold_number=fold_count, multiprocessing=multiprocessing,
-                    features_indices=features_indices, transformer_name=transformer, report_to_wandb=wandb_config)
+                    features_indices=features_indices, transformer_name=transformer, report_to_wandb=report_to_wandb)
 
     if incremental:
         if input_model_path != None:
@@ -310,7 +312,7 @@ def train_eval(model, embeddings_name=None, architecture='BidLSTM_CRF', transfor
 
 
 # split data, train a GROBID model and evaluate it
-def eval_(model, input_path=None, architecture='BidLSTM_CRF', use_ELMo=False):
+def eval_(model, input_path=None, architecture='BidLSTM_CRF', use_ELMo=False, report_to_wandb=False):
     print('Loading data...')
     if input_path is None:
         # it should never be the case
@@ -331,7 +333,7 @@ def eval_(model, input_path=None, architecture='BidLSTM_CRF', use_ELMo=False):
     start_time = time.time()
 
     # load the model
-    model = Sequence(model_name)
+    model = Sequence(model_name, report_to_wandb=report_to_wandb)
     model.load()
 
     # evaluation
@@ -478,12 +480,6 @@ class Tasks:
         # default word embeddings
         embeddings_name = "glove-840B"
 
-    wandb_config = None
-    if wandb:
-        wandb_config = {
-            "project": "delft-grobidTagger"
-        }
-
     if action == Tasks.TRAIN:
         train(
             model,
@@ -502,7 +498,7 @@ class Tasks:
             max_epoch=max_epoch,
             early_stop=early_stop,
             multi_gpu=multi_gpu,
-            wandb_config=wandb_config
+            report_to_wandb=wandb
         )
 
     if action == Tasks.EVAL:
@@ -516,23 +512,23 @@ class Tasks:
     if action == Tasks.TRAIN_EVAL:
         if args.fold_count < 1:
             raise ValueError("fold-count should be equal or more than 1")
-        train_eval(model,
-                embeddings_name=embeddings_name,
-                architecture=architecture,
-                transformer=transformer,
-                input_path=input_path,
-                output_path=output,
-                fold_count=args.fold_count,
-                max_sequence_length=max_sequence_length,
-                batch_size=batch_size,
-                use_ELMo=use_ELMo,
-                incremental=incremental,
-                input_model_path=input_model_path,
-                learning_rate=learning_rate,
-                max_epoch=max_epoch,
-                early_stop=early_stop,
-                multi_gpu=multi_gpu,
-                wandb_config=wandb_config)
+        train_eval(model,
+                embeddings_name=embeddings_name,
+                architecture=architecture,
+                transformer=transformer,
+                input_path=input_path,
+                output_path=output,
+                fold_count=args.fold_count,
+                max_sequence_length=max_sequence_length,
+                batch_size=batch_size,
+                use_ELMo=use_ELMo,
+                incremental=incremental,
+                input_model_path=input_model_path,
+                learning_rate=learning_rate,
+                max_epoch=max_epoch,
+                early_stop=early_stop,
+                multi_gpu=multi_gpu,
+                report_to_wandb=wandb)
 
     if action == Tasks.TAG:
         someTexts = []
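Net effect of the commit: the ad-hoc wandb_config dict is dropped in favour of a plain report_to_wandb boolean, threaded from the CLI's wandb flag through train, train_eval and eval_ into Sequence. A minimal sketch of the resulting call pattern, assuming a DELFT checkout with this patch applied; the model name and data paths below are illustrative placeholders, not taken from the diff:

# Hypothetical usage sketch of the post-patch API; model name and paths are assumptions.
from grobidTagger import train, eval_

# train() now takes a boolean instead of the removed
# {"project": "delft-grobidTagger"} dict; the wandb project is
# presumably chosen downstream, where wandb is initialised.
train("date",
      embeddings_name="glove-840B",
      architecture="BidLSTM_CRF",
      input_path="data/sequenceLabelling/grobid/date/train.xml",  # placeholder
      report_to_wandb=True)

# eval_() gained the same flag, so standalone evaluation can also report to W&B.
eval_("date",
      input_path="data/sequenceLabelling/grobid/date/test.xml",  # placeholder
      report_to_wandb=True)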