Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update TextCatBOW to use the fixed SparseLinear layer #13149

Merged
merged 4 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions spacy/cli/templates/quickstart_training.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,9 @@ grad_factor = 1.0
@layers = "reduce_mean.v1"

[components.textcat.model.linear_model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = true
length = 262144
ngram_size = 1
no_output_layer = false

Expand Down Expand Up @@ -308,8 +309,9 @@ grad_factor = 1.0
@layers = "reduce_mean.v1"

[components.textcat_multilabel.model.linear_model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = false
length = 262144
ngram_size = 1
no_output_layer = false

Expand Down Expand Up @@ -542,14 +544,15 @@ nO = null
width = ${components.tok2vec.model.encode.width}

[components.textcat.model.linear_model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = true
length = 262144
ngram_size = 1
no_output_layer = false

{% else -%}
[components.textcat.model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = true
ngram_size = 1
no_output_layer = false
Expand All @@ -570,15 +573,17 @@ nO = null
width = ${components.tok2vec.model.encode.width}

[components.textcat_multilabel.model.linear_model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = false
length = 262144
ngram_size = 1
no_output_layer = false

{% else -%}
[components.textcat_multilabel.model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = false
length = 262144
ngram_size = 1
no_output_layer = false
{%- endif %}
Expand Down
1 change: 1 addition & 0 deletions spacy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,7 @@ class Errors(metaclass=ErrorsWithCodes):
"predicted docs when training {component}.")
E1055 = ("The 'replace_listener' callback expects {num_params} parameters, "
"but only callbacks with one or three parameters are supported")
E1056 = ("The `TextCatBOW` architecture expects a length of at least 1, was {length}.")


# Deprecated model shortcuts, only used in errors and warnings
Expand Down
46 changes: 43 additions & 3 deletions spacy/ml/models/textcat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import List, Optional, cast
from typing import List, Optional, Tuple, cast

from thinc.api import (
Dropout,
Expand All @@ -12,6 +12,7 @@
Relu,
Softmax,
SparseLinear,
SparseLinear_v2,
chain,
clone,
concatenate,
Expand All @@ -25,9 +26,10 @@
)
from thinc.layers.chain import init as init_chain
from thinc.layers.resizable import resize_linear_weighted, resize_model
from thinc.types import Floats2d
from thinc.types import ArrayXd, Floats2d

from ...attrs import ORTH
from ...errors import Errors
from ...tokens import Doc
from ...util import registry
from ..extract_ngrams import extract_ngrams
Expand Down Expand Up @@ -95,10 +97,48 @@ def build_bow_text_classifier(
ngram_size: int,
no_output_layer: bool,
nO: Optional[int] = None,
) -> Model[List[Doc], Floats2d]:
return _build_bow_text_classifier(
exclusive_classes=exclusive_classes,
ngram_size=ngram_size,
no_output_layer=no_output_layer,
nO=nO,
sparse_linear=SparseLinear(nO=nO),
)


@registry.architectures("spacy.TextCatBOW.v3")
def build_bow_text_classifier_v3(
exclusive_classes: bool,
ngram_size: int,
no_output_layer: bool,
length: int = 262144,
svlandeg marked this conversation as resolved.
Show resolved Hide resolved
nO: Optional[int] = None,
) -> Model[List[Doc], Floats2d]:
if length < 1:
raise ValueError(Errors.E1056.format(length=length))

# Find k such that 2**(k-1) < length <= 2**k.
svlandeg marked this conversation as resolved.
Show resolved Hide resolved
length = 2 ** (length - 1).bit_length()

return _build_bow_text_classifier(
exclusive_classes=exclusive_classes,
ngram_size=ngram_size,
no_output_layer=no_output_layer,
nO=nO,
sparse_linear=SparseLinear_v2(nO=nO, length=length),
)


def _build_bow_text_classifier(
exclusive_classes: bool,
ngram_size: int,
no_output_layer: bool,
sparse_linear: Model[Tuple[ArrayXd, ArrayXd, ArrayXd], ArrayXd],
nO: Optional[int] = None,
) -> Model[List[Doc], Floats2d]:
fill_defaults = {"b": 0, "W": 0}
with Model.define_operators({">>": chain}):
sparse_linear = SparseLinear(nO=nO)
output_layer = None
if not no_output_layer:
fill_defaults["b"] = NEG_VALUE
Expand Down
6 changes: 4 additions & 2 deletions spacy/pipeline/textcat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,19 @@
depth = 2

[model.linear_model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = true
length = 262144
ngram_size = 1
no_output_layer = false
"""
DEFAULT_SINGLE_TEXTCAT_MODEL = Config().from_str(single_label_default_config)["model"]

single_label_bow_config = """
[model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = true
length = 262144
ngram_size = 1
no_output_layer = false
"""
Expand Down
5 changes: 3 additions & 2 deletions spacy/pipeline/textcat_multilabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,17 @@
depth = 2

[model.linear_model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = false
length = 262144
ngram_size = 1
no_output_layer = false
"""
DEFAULT_MULTI_TEXTCAT_MODEL = Config().from_str(multi_label_default_config)["model"]

multi_label_bow_config = """
[model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = false
ngram_size = 1
no_output_layer = false
Expand Down
2 changes: 1 addition & 1 deletion spacy/tests/pipeline/test_pipe_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def test_pipe_class_component_model():
"@architectures": "spacy.TextCatEnsemble.v2",
"tok2vec": DEFAULT_TOK2VEC_MODEL,
"linear_model": {
"@architectures": "spacy.TextCatBOW.v2",
"@architectures": "spacy.TextCatBOW.v3",
"exclusive_classes": False,
"ngram_size": 1,
"no_output_layer": False,
Expand Down
31 changes: 18 additions & 13 deletions spacy/tests/pipeline/test_textcat.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def test_implicit_label(name, get_examples):
@pytest.mark.parametrize(
"name,textcat_config",
[
# BOW
# BOW V1
("textcat", {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
("textcat", {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
Expand Down Expand Up @@ -451,11 +451,11 @@ def test_no_resize(name, textcat_config):
@pytest.mark.parametrize(
"name,textcat_config",
[
# BOW
("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
# BOW V3
("textcat", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
("textcat", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
# CNN
("textcat", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
Expand All @@ -480,11 +480,11 @@ def test_resize(name, textcat_config):
@pytest.mark.parametrize(
"name,textcat_config",
[
# BOW
("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
# BOW v3
("textcat", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
("textcat", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
danieldk marked this conversation as resolved.
Show resolved Hide resolved
# CNN
("textcat", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
Expand Down Expand Up @@ -693,9 +693,14 @@ def test_overfitting_IO_multi():
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False}),
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True}),
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True}),
# BOW V3
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}),
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False}),
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True}),
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True}),
# ENSEMBLE V2
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}}),
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "ngram_size": 5, "no_output_layer": False}}),
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}}),
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "ngram_size": 5, "no_output_layer": False}}),
# CNN V2
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
Expand Down
4 changes: 2 additions & 2 deletions spacy/tests/test_cli_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_project_push_pull(project_dir):

def test_find_function_valid():
# example of architecture in main code base
function = "spacy.TextCatBOW.v2"
function = "spacy.TextCatBOW.v3"
result = CliRunner().invoke(app, ["find-function", function, "-r", "architectures"])
assert f"Found registered function '{function}'" in result.stdout
assert "textcat.py" in result.stdout
Expand All @@ -257,7 +257,7 @@ def test_find_function_valid():

def test_find_function_invalid():
# invalid registry
function = "spacy.TextCatBOW.v2"
function = "spacy.TextCatBOW.v3"
registry = "foobar"
result = CliRunner().invoke(
app, ["find-function", function, "--registry", registry]
Expand Down
3 changes: 2 additions & 1 deletion spacy/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,9 @@ def test_util_dot_section():
factory = "textcat"

[components.textcat.model]
@architectures = "spacy.TextCatBOW.v2"
@architectures = "spacy.TextCatBOW.v3"
exclusive_classes = true
length = 262144
ngram_size = 1
no_output_layer = false
"""
Expand Down
Loading
Loading