fix wandb parameters
lfoppiano committed Nov 6, 2024
1 parent 9292e2a commit 0d7f3cf
Showing 1 changed file with 27 additions and 31 deletions.
58 changes: 27 additions & 31 deletions delft/applications/grobidTagger.py
@@ -168,7 +168,7 @@ def configure(model, architecture, output_path=None, max_sequence_length=-1, bat
def train(model, embeddings_name=None, architecture=None, transformer=None, input_path=None,
output_path=None, features_indices=None, max_sequence_length=-1, batch_size=-1, max_epoch=-1,
use_ELMo=False, incremental=False, input_model_path=None, patience=-1, learning_rate=None, early_stop=None, multi_gpu=False,
- wandb_config=None):
+ report_to_wandb=False):

print('Loading data...')
if input_path == None:
@@ -216,7 +216,7 @@ def train(model, embeddings_name=None, architecture=None, transformer=None, inpu
early_stop=early_stop,
patience=patience,
learning_rate=learning_rate,
- report_to_wandb=wandb_config
+ report_to_wandb=report_to_wandb
)

if incremental:
@@ -243,9 +243,10 @@ def train(model, embeddings_name=None, architecture=None, transformer=None, inpu
# split data, train a GROBID model and evaluate it
def train_eval(model, embeddings_name=None, architecture='BidLSTM_CRF', transformer=None,
input_path=None, output_path=None, fold_count=1,
- features_indices=None, max_sequence_length=-1, batch_size=-1, max_epoch=-1,
+ features_indices=None, max_sequence_length=-1, batch_size=-1, max_epoch=-1,
use_ELMo=False, incremental=False, input_model_path=None, patience=-1,
- learning_rate=None, early_stop=None, multi_gpu=False, wandb_config=None):
+ learning_rate=None, early_stop=None, multi_gpu=False,
+ report_to_wandb=False):

print('Loading data...')
if input_path is None:
@@ -274,11 +275,12 @@ def train_eval(model, embeddings_name=None, architecture='BidLSTM_CRF', transfor
use_ELMo,
patience,
early_stop)

model = Sequence(model_name, architecture=architecture, embeddings_name=embeddings_name,
max_sequence_length=max_sequence_length, recurrent_dropout=0.50, batch_size=batch_size,
learning_rate=learning_rate, max_epoch=max_epoch, early_stop=early_stop, patience=patience,
use_ELMo=use_ELMo, fold_number=fold_count, multiprocessing=multiprocessing,
- features_indices=features_indices, transformer_name=transformer, report_to_wandb=wandb_config)
+ features_indices=features_indices, transformer_name=transformer, report_to_wandb=report_to_wandb)

if incremental:
if input_model_path != None:
@@ -310,7 +312,7 @@ def train_eval(model, embeddings_name=None, architecture='BidLSTM_CRF', transfor


# split data, train a GROBID model and evaluate it
- def eval_(model, input_path=None, architecture='BidLSTM_CRF', use_ELMo=False):
+ def eval_(model, input_path=None, architecture='BidLSTM_CRF', use_ELMo=False, report_to_wandb=False):
print('Loading data...')
if input_path is None:
# it should never be the case
@@ -331,7 +333,7 @@ def eval_(model, input_path=None, architecture='BidLSTM_CRF', use_ELMo=False):
start_time = time.time()

# load the model
- model = Sequence(model_name)
+ model = Sequence(model_name, report_to_wandb=report_to_wandb)
model.load()

# evaluation
@@ -478,12 +480,6 @@ class Tasks:
# default word embeddings
embeddings_name = "glove-840B"

- wandb_config = None
- if wandb:
- wandb_config = {
- "project": "delft-grobidTagger"
- }

if action == Tasks.TRAIN:
train(
model,
@@ -502,7 +498,7 @@ class Tasks:
max_epoch=max_epoch,
early_stop=early_stop,
multi_gpu=multi_gpu,
- wandb_config=wandb_config
+ report_to_wandb=wandb
)

if action == Tasks.EVAL:
@@ -516,23 +512,23 @@ class Tasks:
if action == Tasks.TRAIN_EVAL:
if args.fold_count < 1:
raise ValueError("fold-count should be equal or more than 1")
- train_eval(model,
- embeddings_name=embeddings_name,
- architecture=architecture,
- transformer=transformer,
- input_path=input_path,
- output_path=output,
- fold_count=args.fold_count,
- max_sequence_length=max_sequence_length,
- batch_size=batch_size,
- use_ELMo=use_ELMo,
- incremental=incremental,
- input_model_path=input_model_path,
- learning_rate=learning_rate,
- max_epoch=max_epoch,
- early_stop=early_stop,
- multi_gpu=multi_gpu,
- wandb_config=wandb_config)
+ train_eval(model,
+ embeddings_name=embeddings_name,
+ architecture=architecture,
+ transformer=transformer,
+ input_path=input_path,
+ output_path=output,
+ fold_count=args.fold_count,
+ max_sequence_length=max_sequence_length,
+ batch_size=batch_size,
+ use_ELMo=use_ELMo,
+ incremental=incremental,
+ input_model_path=input_model_path,
+ learning_rate=learning_rate,
+ max_epoch=max_epoch,
+ early_stop=early_stop,
+ multi_gpu=multi_gpu,
+ report_to_wandb=wandb)

if action == Tasks.TAG:
someTexts = []
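For orientation, a minimal usage sketch of the updated entry points after this change (not part of the commit): the model name "citation" and the data paths are illustrative assumptions, while the parameter name report_to_wandb (replacing the former wandb_config dict), the default embeddings "glove-840B", and the architecture "BidLSTM_CRF" are taken from the diff above.

from delft.applications.grobidTagger import train, eval_

# Train a GROBID tagging model; W&B reporting is now toggled with a plain
# boolean forwarded to Sequence, instead of a wandb_config dict built here.
train(
    "citation",                            # hypothetical GROBID model name
    embeddings_name="glove-840B",
    architecture="BidLSTM_CRF",
    input_path="data/citation.train.xml",  # hypothetical training data path
    report_to_wandb=True,                  # was: wandb_config={"project": "delft-grobidTagger"}
)

# Evaluate a saved model; eval_ now accepts the same flag and passes it on
# to Sequence when loading the model.
eval_(
    "citation",
    input_path="data/citation.eval.xml",   # hypothetical evaluation data path
    report_to_wandb=True,
)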
