diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 177340703c1..a6a4769013f 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -574,11 +574,12 @@ cdef class Parser: cfg.setdefault('min_action_freq', 30) actions = self.moves.get_actions(gold_parses=get_gold_tuples(), min_freq=cfg.get('min_action_freq', 30)) - previous_labels = dict(self.moves.labels) + for action, labels in self.moves.labels.items(): + actions.setdefault(action, {}) + for label, freq in labels.items(): + if label not in actions[action]: + actions[action][label] = freq self.moves.initialize_actions(actions) - for action, label_freqs in previous_labels.items(): - for label in label_freqs: - self.moves.add_action(action, label) cfg.setdefault('token_vector_width', 96) if self.model is True: self.model, cfg = self.Model(self.moves.n_moves, **cfg) diff --git a/spacy/tests/parser/test_add_label.py b/spacy/tests/parser/test_add_label.py index 7f19ab4551d..45a51ac8e2f 100644 --- a/spacy/tests/parser/test_add_label.py +++ b/spacy/tests/parser/test_add_label.py @@ -33,7 +33,7 @@ def _train_parser(parser): parser.begin_training([], **parser.cfg) sgd = Adam(NumpyOps(), 0.001) - for i in range(10): + for i in range(5): losses = {} doc = Doc(parser.vocab, words=["a", "b", "c", "d"]) gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"]) @@ -43,21 +43,7 @@ def _train_parser(parser): def test_add_label(parser): parser = _train_parser(parser) - doc = Doc(parser.vocab, words=["a", "b", "c", "d"]) - doc = parser(doc) - assert doc[0].head.i == 1 - assert doc[0].dep_ == "left" - assert doc[1].head.i == 1 - assert doc[2].head.i == 3 - assert doc[2].head.i == 3 parser.add_label("right") - doc = Doc(parser.vocab, words=["a", "b", "c", "d"]) - doc = parser(doc) - assert doc[0].head.i == 1 - assert doc[0].dep_ == "left" - assert doc[1].head.i == 1 - assert doc[2].head.i == 3 - assert doc[2].head.i == 3 sgd = Adam(NumpyOps(), 0.001) for i in range(10): losses = {} @@ -72,7 +58,6 @@ def test_add_label(parser): assert doc[2].dep_ == "left" -@pytest.mark.xfail def test_add_label_deserializes_correctly(): ner1 = EntityRecognizer(Vocab()) ner1.add_label("C")