Add early stopping
kowaalczyk committed Apr 3, 2019
1 parent 24cecdb commit b7c4a46
Showing 1 changed file with 28 additions and 0 deletions.
spacy/cli/train.py: 28 additions & 0 deletions
@@ -35,6 +35,7 @@
pipeline=("Comma-separated names of pipeline components", "option", "p", str),
vectors=("Model to load vectors from", "option", "v", str),
n_iter=("Number of iterations", "option", "n", int),
early_stopping_iter=("Maximum number of training epochs without dev accuracy improvement", "option", "e", int),
n_examples=("Number of examples", "option", "ns", int),
use_gpu=("Use GPU", "option", "g", int),
version=("Model version", "option", "V", str),
@@ -74,6 +75,7 @@ def train(
pipeline="tagger,parser,ner",
vectors=None,
n_iter=30,
early_stopping_iter=None,
n_examples=0,
use_gpu=-1,
version="0.0.0",
@@ -222,6 +224,8 @@ def train(
msg.row(row_head, **row_settings)
msg.row(["-" * width for width in row_settings["widths"]], **row_settings)
try:
iter_since_best = 0
best_score = 0.
for i in range(n_iter):
train_docs = corpus.train_docs(
nlp, noise_level=noise_level, gold_preproc=gold_preproc, max_length=0
@@ -328,6 +332,18 @@
gpu_wps=gpu_wps,
)
msg.row(progress, **row_settings)
# early stopping
if early_stopping_iter is not None:
current_score = _score_for_model(meta)
if current_score < best_score:
iter_since_best += 1
else:
iter_since_best = 0
best_score = current_score
if iter_since_best >= early_stopping_iter:
msg.text(f"Early stopping, best iteration is: {i-iter_since_best}")
msg.text(f"Best score = {best_score}; Final iteration score = {current_score}")
break
finally:
with nlp.use_params(optimizer.averages):
final_model_path = output_path / "model-final"
@@ -337,6 +353,18 @@ def train(
best_model_path = _collate_best_model(meta, output_path, nlp.pipe_names)
msg.good("Created best model", best_model_path)

def _score_for_model(meta):
""" Returns mean score between tasks in pipeline that can be used for early stopping. """
mean_acc = list()
pipes = meta['pipeline']
acc = meta['accuracy']
if 'tagger' in pipes:
mean_acc.append(acc['tags_acc'])
if 'parser' in pipes:
mean_acc.append((acc['uas']+acc['las']) / 2)
if 'ner' in pipes:
mean_acc.append((acc['ents_p']+acc['ents_r']+acc['ents_f']) / 3)
return sum(mean_acc) / len(mean_acc)

@contextlib.contextmanager
def _create_progress_bar(total):
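
Note: the early-stopping check above keeps a patience counter (iter_since_best) that resets whenever the dev score reaches a new best. A minimal, self-contained sketch of that logic, using made-up per-epoch scores and early_stopping_iter=2 (not part of the commit, for illustration only):

    # Made-up dev scores for four epochs; epoch 1 is the best.
    scores = [85.0, 86.0, 84.0, 85.5]
    early_stopping_iter = 2
    iter_since_best = 0
    best_score = 0.0
    for i, current_score in enumerate(scores):
        if current_score < best_score:
            iter_since_best += 1
        else:
            iter_since_best = 0
            best_score = current_score
        if iter_since_best >= early_stopping_iter:
            print("Early stopping, best iteration is:", i - iter_since_best)  # prints 1
            break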
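
The keys read by _score_for_model are the per-component accuracy figures that the training loop stores in meta["accuracy"]. A worked example of the averaging, with made-up numbers for illustration only:

    meta = {
        "pipeline": ["tagger", "parser", "ner"],
        "accuracy": {
            "tags_acc": 92.0,                                # tagger -> 92.0
            "uas": 88.0, "las": 86.0,                        # parser -> (88.0 + 86.0) / 2 = 87.0
            "ents_p": 80.0, "ents_r": 78.0, "ents_f": 79.0,  # ner -> (80.0 + 78.0 + 79.0) / 3 = 79.0
        },
    }
    _score_for_model(meta)  # (92.0 + 87.0 + 79.0) / 3 == 86.0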

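
A hypothetical way to exercise the new parameter from Python (paths and values are made up; this assumes the existing positional arguments of train(), i.e. language, output directory, and training/dev data paths, are unchanged, and that plac exposes the option on the command line via the "e" shortcut declared above):

    from spacy.cli import train

    # Stop once the dev score has not improved for 5 consecutive epochs.
    train("en", "/tmp/model", "train.json", "dev.json",
          n_iter=100, early_stopping_iter=5)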