Skip to content

Commit

Permalink
💫 Fix class mismap on parser deserializing (closes #3433) (#3470)
Browse files Browse the repository at this point in the history
v2.1 introduced a regression when deserializing the parser after
parser.add_label() had been called. The code around the class mapping is
pretty confusing currently, as it was written to accommodate backwards
model compatibility. It needs to be revised when the models are next
retrained.

Closes #3433
  • Loading branch information
honnibal authored Mar 23, 2019
1 parent 444a3ab commit d9a07a7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 20 deletions.
9 changes: 5 additions & 4 deletions spacy/syntax/nn_parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -574,11 +574,12 @@ cdef class Parser:
cfg.setdefault('min_action_freq', 30)
actions = self.moves.get_actions(gold_parses=get_gold_tuples(),
min_freq=cfg.get('min_action_freq', 30))
previous_labels = dict(self.moves.labels)
for action, labels in self.moves.labels.items():
actions.setdefault(action, {})
for label, freq in labels.items():
if label not in actions[action]:
actions[action][label] = freq
self.moves.initialize_actions(actions)
for action, label_freqs in previous_labels.items():
for label in label_freqs:
self.moves.add_action(action, label)
cfg.setdefault('token_vector_width', 96)
if self.model is True:
self.model, cfg = self.Model(self.moves.n_moves, **cfg)
Expand Down
17 changes: 1 addition & 16 deletions spacy/tests/parser/test_add_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _train_parser(parser):
parser.begin_training([], **parser.cfg)
sgd = Adam(NumpyOps(), 0.001)

for i in range(10):
for i in range(5):
losses = {}
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
Expand All @@ -43,21 +43,7 @@ def _train_parser(parser):

def test_add_label(parser):
parser = _train_parser(parser)
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
doc = parser(doc)
assert doc[0].head.i == 1
assert doc[0].dep_ == "left"
assert doc[1].head.i == 1
assert doc[2].head.i == 3
assert doc[2].head.i == 3
parser.add_label("right")
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
doc = parser(doc)
assert doc[0].head.i == 1
assert doc[0].dep_ == "left"
assert doc[1].head.i == 1
assert doc[2].head.i == 3
assert doc[2].head.i == 3
sgd = Adam(NumpyOps(), 0.001)
for i in range(10):
losses = {}
Expand All @@ -72,7 +58,6 @@ def test_add_label(parser):
assert doc[2].dep_ == "left"


@pytest.mark.xfail
def test_add_label_deserializes_correctly():
ner1 = EntityRecognizer(Vocab())
ner1.add_label("C")
Expand Down

0 comments on commit d9a07a7

Please sign in to comment.