Skip to content

Commit 89c0ba1

Browse files
committed
CLN: Remove entry points for models, metrics. Fixes #601
Remove since they are basically unused by anyone but us, but require maintenance, test, etc. Squash commits: - Remove vak/entry_points.py - Remove import of entry_points from vak/__init__.py - Remove model and metric entry points from pyproject.toml - Rewrite models/models.py to not use entry points - Fix use of models.find in config/validators.py - Fix use of find in config/models.py - Remove entry point test in tests/test_models/test_teenytweetynet.py - Fix entry point test in tests/test_models/test_tweetynet.py
1 parent 66436e9 commit 89c0ba1

9 files changed

+36
-95
lines changed

pyproject.toml

+1-10
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,10 @@ Documentation = "https://vak.readthedocs.io"
6868
[project.scripts]
6969
vak = 'vak.__main__:main'
7070

71-
[project.entry-points."vak.models"]
72-
TeenyTweetyNetModel = 'vak.models.teenytweetynet:TeenyTweetyNet'
73-
TweetyNetModel = 'vak.models.tweetynet:TweetyNet'
74-
75-
[project.entry-points."vak.metrics"]
76-
Accuracy = 'vak.metrics.Accuracy'
77-
Levenshtein = 'vak.metrics.Levenshtein'
78-
SegmentErrorRate = 'vak.metrics.SegmentErrorRate'
79-
8071
[tool.flit.sdist]
8172
exclude = [
8273
"tests/data_for_tests"
8374
]
8475

8576
[tool.pytest.ini_options]
86-
filterwarnings = ["ignore:::.*torch.utils.tensorboard",]
77+
filterwarnings = ["ignore:::.*torch.utils.tensorboard",]

src/vak/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
curvefit,
2121
datasets,
2222
device,
23-
entry_points,
2423
files,
2524
io,
2625
labeled_timebins,
@@ -54,7 +53,6 @@
5453
"csv",
5554
"datasets",
5655
"device",
57-
"entry_points",
5856
"files",
5957
"io",
6058
"labeled_timebins",

src/vak/config/models.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,12 @@ def map_from_config_dict(config_dict, model_names):
4949
# to avoid circular dependencies
5050
# (user would be unable to import models in other packages
5151
# if the module in the other package needed to `import vak`)
52-
MODELS = {model_name: model_builder for model_name, model_builder in models.find()}
52+
MODEL_NAMES = list(models.models.BUILTIN_MODELS.keys())
5353
for model_name in model_names:
54-
if model_name not in MODELS:
55-
# try appending 'Model' to name
56-
tmp_model_name = f"{model_name}Model"
57-
if tmp_model_name not in MODELS:
58-
raise ValueError(
59-
f"Did not find an installed model named {model_name} or {tmp_model_name}. "
60-
f"Installed models are: {list(MODELS.keys())}"
61-
)
54+
if model_name not in MODEL_NAMES:
55+
raise ValueError(
56+
f"Invalid model name: {model_name}.\nValid model names are: {MODEL_NAMES}"
57+
)
6258

6359
# now see if we can find corresponding sections in config
6460
sections = list(config_dict.keys())

src/vak/config/validators.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ def is_a_file(instance, attribute, value):
2626

2727

2828
def is_valid_model_name(instance, attribute, value):
29-
MODEL_NAMES = [model_name for model_name, model_builder in models.find()]
29+
MODEL_NAMES = list(models.models.BUILTIN_MODELS.keys())
3030
for model_name in value:
31-
if model_name not in MODEL_NAMES and f"{model_name}Model" not in MODEL_NAMES:
31+
if model_name not in MODEL_NAMES:
3232
raise ValueError(
33-
f"Model {model_name} not found when importing installed models."
33+
f"Invalid model name: {model_name}.\nValid model names are: {MODEL_NAMES}"
3434
)
3535

3636

@@ -91,7 +91,7 @@ def are_sections_valid(config_dict, toml_path=None):
9191
f"Please use just one command besides `prep` per .toml configuration file"
9292
)
9393

94-
MODEL_NAMES = [model_name for model_name, model_builder in models.find()]
94+
MODEL_NAMES = list(models.models.BUILTIN_MODELS.keys())
9595
# add model names to valid sections so users can define model config in sections
9696
valid_sections = VALID_SECTIONS + MODEL_NAMES
9797
for section in sections:

src/vak/entry_points.py

-8
This file was deleted.

src/vak/models/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
definition,
55
)
66
from .base import Model
7-
from .models import find, from_model_config_map
7+
from .models import from_model_config_map
88
from .teenytweetynet import TeenyTweetyNet
99
from .tweetynet import TweetyNet
1010
from .windowed_frame_classification_model import WindowedFrameClassificationModel
@@ -14,7 +14,6 @@
1414
"base",
1515
"decorator",
1616
"definition",
17-
"find",
1817
"from_model_config_map",
1918
"Model",
2019
"TeenyTweetyNet",

src/vak/models/models.py

+25-40
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,25 @@
1-
"""module that contains helper function to load models
1+
"""Helper function to load models"""
2+
from __future__ import annotations
23

3-
Models in separate packages should make themselves available to vak by including
4-
'vak.models' in the entry_points of their setup.py file.
4+
from .tweetynet import TweetyNet
5+
from .teenytweetynet import TeenyTweetyNet
56

6-
For example, if you had a package `grunet` containing a model
7-
that was instantiated by a function `GRUnet`,
8-
then that package would include the following in its setup.py file:
97

10-
setup(
11-
...
12-
entry_points={'vak.models': 'GRUnet = grunet:GRUnet'},
13-
...
14-
)
8+
# TODO: Replace constant with decorator that registers models, https://github.com/vocalpy/vak/issues/623
9+
BUILTIN_MODELS = {
10+
'TweetyNet': TweetyNet,
11+
'TeenyTweetyNet': TeenyTweetyNet
12+
}
1513

16-
For more detail on entry points in Python, see:
17-
https://packaging.python.org/guides/creating-and-discovering-plugins/#using-package-metadata
18-
https://setuptools.readthedocs.io/en/latest/setuptools.html#dynamic-discovery-of-services-and-plugins
19-
https://amir.rachum.com/blog/2017/07/28/python-entry-points/
20-
"""
21-
from .. import entry_points
14+
MODEL_NAMES = list(BUILTIN_MODELS.keys())
2215

23-
MODELS_ENTRY_POINT = "vak.models"
2416

25-
26-
def find():
27-
"""find installed vak.models
28-
29-
returns generator that yields model name and function for loading
30-
"""
31-
for entrypoint in entry_points._iter(MODELS_ENTRY_POINT):
32-
yield entrypoint.name, entrypoint.load()
33-
34-
35-
def from_model_config_map(model_config_map,
17+
def from_model_config_map(model_config_map: dict[str: dict],
3618
# TODO: move num_classes / input_shape into model configs
37-
num_classes,
38-
input_shape,
39-
labelmap):
40-
"""get models that are ready to train, given their names and configurations.
19+
num_classes: int,
20+
input_shape: tuple[int, int, int],
21+
labelmap: dict) -> dict:
22+
"""Get models that are ready to train, given their names and configurations.
4123
4224
Given a dictionary that maps model names to configurations,
4325
along with the number of classes they should be trained to discriminate and their input shape,
@@ -66,7 +48,7 @@ def from_model_config_map(model_config_map,
6648
models_map : dict
6749
where keys are model names and values are instances of the models, ready for training
6850
"""
69-
MODELS = {model_name: model_builder for model_name, model_builder in find()}
51+
import vak.models
7052

7153
models_map = {}
7254
for model_name, model_config in model_config_map.items():
@@ -77,12 +59,15 @@ def from_model_config_map(model_config_map,
7759
num_classes=num_classes,
7860
input_shape=input_shape,
7961
)
62+
8063
try:
81-
model = MODELS[model_name].from_config(config=model_config, labelmap=labelmap)
82-
except KeyError:
83-
model = MODELS[f"{model_name}Model"].from_config(
84-
config=model_config,
85-
labelmap=labelmap
86-
)
64+
model_class = getattr(vak.models, model_name)
65+
except AttributeError as e:
66+
raise ValueError(
67+
f"Invalid model name: '{model_name}'.\nValid model names are: {MODEL_NAMES}"
68+
) from e
69+
70+
model = model_class.from_config(config=model_config, labelmap=labelmap)
8771
models_map[model_name] = model
72+
8873
return models_map

tests/test_models/test_teenytweetynet.py

-10
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,6 @@
88
from .test_tweetynet import TEST_INIT_ARGVALS
99

1010

11-
def test_installed():
12-
# makes sure the entry point loads properly
13-
if "vak" in sys.modules:
14-
sys.modules.pop("vak")
15-
import vak
16-
17-
models = [name for name, class_ in sorted(vak.models.find())]
18-
assert "TeenyTweetyNetModel" in models
19-
20-
2111
class TestTeenyTweetyNet:
2212
def test_model_is_decorated(self):
2313
assert issubclass(vak.models.TeenyTweetyNet,

tests/test_models/test_tweetynet.py

-10
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,6 @@
3131
TEST_INIT_ARGVALS = itertools.product(LABELMAPS, INPUT_SHAPES)
3232

3333

34-
def test_installed():
35-
# makes sure the entry point loads properly
36-
if "vak" in sys.modules:
37-
sys.modules.pop("vak")
38-
import vak
39-
40-
models = [name for name, class_ in sorted(vak.models.find())]
41-
assert "TweetyNetModel" in models
42-
43-
4434
class TestTweetyNet:
4535
def test_model_is_decorated(self):
4636
assert issubclass(vak.models.TweetyNet,

0 commit comments

Comments
 (0)