diff --git a/spacy/tests/parser/test_add_label.py b/spacy/tests/parser/test_add_label.py index 31bfbe56d2d..7f19ab4551d 100644 --- a/spacy/tests/parser/test_add_label.py +++ b/spacy/tests/parser/test_add_label.py @@ -8,7 +8,8 @@ from spacy.gold import GoldParse from spacy.vocab import Vocab from spacy.tokens import Doc -from spacy.pipeline import DependencyParser +from spacy.pipeline import DependencyParser, EntityRecognizer +from spacy.util import fix_random_seed @pytest.fixture @@ -19,29 +20,29 @@ def vocab(): @pytest.fixture def parser(vocab): parser = DependencyParser(vocab) - parser.cfg["token_vector_width"] = 8 - parser.cfg["hidden_width"] = 30 - parser.cfg["hist_size"] = 0 + return parser + + +def test_init_parser(parser): + pass + + +def _train_parser(parser): + fix_random_seed(1) parser.add_label("left") parser.begin_training([], **parser.cfg) sgd = Adam(NumpyOps(), 0.001) for i in range(10): losses = {} - doc = Doc(vocab, words=["a", "b", "c", "d"]) + doc = Doc(parser.vocab, words=["a", "b", "c", "d"]) gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"]) parser.update([doc], [gold], sgd=sgd, losses=losses) return parser -def test_init_parser(parser): - pass - - -# TODO: This is flakey, because it depends on what the parser first learns. -# TODO: This now seems to be implicated in segfaults. Not sure what's up! -@pytest.mark.skip def test_add_label(parser): + parser = _train_parser(parser) doc = Doc(parser.vocab, words=["a", "b", "c", "d"]) doc = parser(doc) assert doc[0].head.i == 1 @@ -69,3 +70,16 @@ def test_add_label(parser): doc = parser(doc) assert doc[0].dep_ == "right" assert doc[2].dep_ == "left" + + +@pytest.mark.xfail +def test_add_label_deserializes_correctly(): + ner1 = EntityRecognizer(Vocab()) + ner1.add_label("C") + ner1.add_label("B") + ner1.add_label("A") + ner1.begin_training([]) + ner2 = EntityRecognizer(Vocab()).from_bytes(ner1.to_bytes()) + assert ner1.moves.n_moves == ner2.moves.n_moves + for i in range(ner1.moves.n_moves): + assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i)