Skip to content

Commit

Permalink
Consolidate CLI into cnlpt command (#218)
Browse files Browse the repository at this point in the history
* consolidate cli

* change --model to --model-type, lazy load fastapi apps
  • Loading branch information
ianbulovic authored Nov 22, 2024
1 parent e04edbd commit 912713d
Show file tree
Hide file tree
Showing 21 changed files with 153 additions and 273 deletions.
14 changes: 7 additions & 7 deletions docker/Dockerfile.cpu
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,30 @@ ENTRYPOINT ["/bin/bash"]

FROM base as current
run python -c "import sys;sys.path.append('/home/docker');import model_download; model_download.current()"
ENTRYPOINT ["cnlpt_current_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt","rest", "--model-type", "current", "-p", "8000"]

FROM base as dtr
run python -c "import sys;sys.path.append('/home/docker');import model_download; model_download.dtr()"
ENTRYPOINT ["cnlpt_dtr_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "dtr", "-p", "8000"]

FROM base as event
run python -c "import sys;sys.path.append('/home/docker');import model_download; model_download.event()"
ENTRYPOINT ["cnlpt_event_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "event", "-p", "8000"]

FROM base as negation
run python -c "import sys;sys.path.append('/home/docker');import model_download; model_download.negation()"
ENTRYPOINT ["cnlpt_negation_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "negation", "-p", "8000"]

FROM base as termexists
run python -c "import sys;sys.path.append('/home/docker');import model_download; model_download.termexists()"
# Temporary fix, remove once the released pip package has the new model
run sed -i 's/sharpseed-termexists/termexists_pubmedbert_ssm/g' /usr/local/lib/python3.9/site-packages/cnlpt/api/termexists_rest.py
ENTRYPOINT ["cnlpt_termexists_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "termexists", "-p", "8000"]

FROM base as temporal
run python -c "import sys;sys.path.append('/home/docker');import model_download; model_download.temporal()"
ENTRYPOINT ["cnlpt_temporal_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "temporal", "-p", "8000"]

FROM base as timex
run python -c "import sys;sys.path.append('/home/docker');import model_download; model_download.timex()"
ENTRYPOINT ["cnlpt_timex_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "timex", "-p", "8000"]
20 changes: 10 additions & 10 deletions docker/Dockerfile.gpu
Original file line number Diff line number Diff line change
Expand Up @@ -16,42 +16,42 @@ ENTRYPOINT ["/bin/bash"]

FROM base as current
run /usr/bin/python3.9 -c "import sys;sys.path.append('/home/docker');import model_download; model_download.current()"
ENTRYPOINT ["cnlpt_current_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "current", "-p", "8000"]

FROM base as dtr
run /usr/bin/python3.9 -c "import sys;sys.path.append('/home/docker');import model_download; model_download.dtr()"
ENTRYPOINT ["cnlpt_dtr_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "dtr", "-p", "8000"]

FROM base as event
run /usr/bin/python3.9 -c "import sys;sys.path.append('/home/docker');import model_download; model_download.event()"
ENTRYPOINT ["cnlpt_event_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "event", "-p", "8000"]

FROM base as negation
run /usr/bin/python3.9 -c "import sys;sys.path.append('/home/docker');import model_download; model_download.negation()"
ENTRYPOINT ["cnlpt_negation_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "negation", "-p", "8000"]

FROM base as termexists
run /usr/bin/python3.9 -c "import sys;sys.path.append('/home/docker');import model_download; model_download.termexists()"
# Temporary fix, remove once the released pip package has the new model
run sed -i 's/sharpseed-termexists/termexists_pubmedbert_ssm/g' /usr/local/lib/python3.9/site-packages/cnlpt/api/termexists_rest.py
ENTRYPOINT ["cnlpt_termexists_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "termexists", "-p", "8000"]

FROM base as temporal
run /usr/bin/python3.9 -c "import sys;sys.path.append('/home/docker');import model_download; model_download.temporal()"
ENTRYPOINT ["cnlpt_temporal_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "temporal", "-p", "8000"]

FROM base as timex
run /usr/bin/python3.9 -c "import sys;sys.path.append('/home/docker');import model_download; model_download.timex()"
ENTRYPOINT ["cnlpt_timex_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "timex", "-p", "8000"]

FROM base as hier_local
ENV MODEL_PATH /opt/cnlp/model
ENTRYPOINT ["python3.9", "-m", "cnlpt.api.hier_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "hier", "-p", "8000"]

FROM base as hier
run /usr/bin/python3.9 -c "import sys;sys.path.append('/home/docker'); import model_download; print('$model_loc'); model_download.hier('$model_loc')"
ENTRYPOINT ["python3.9", "-m", "cnlpt.api.hier_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "hier", "-p", "8000"]

FROM base as cnn
ENV MODEL_PATH /opt/cnlp/model
ENTRYPOINT ["python3.9", "-m", "cnlpt.api.cnn_rest", "-p", "8000"]
ENTRYPOINT ["cnlpt", "rest", "--model-type", "cnn", "-p", "8000"]
10 changes: 1 addition & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ readme = "README.md"
requires-python = ">=3.9, <3.12"
dependencies = [
"anaforatools~=1.1.0",
"click>=8.1.7",
"datasets~=2.21.0",
"fastapi~=0.115.2",
"httpx>=0.27.2",
Expand Down Expand Up @@ -65,15 +66,6 @@ dev = [

[project.scripts]
"cnlpt" = "cnlpt.__main__:main"
"cnlpt_current_rest" = "cnlpt.api.current_rest:rest"
"cnlpt_dtr_rest" = "cnlpt.api.dtr_rest:rest"
"cnlpt_event_rest" = "cnlpt.api.event_rest:rest"
"cnlpt_negation_rest" = "cnlpt.api.negation_rest:rest"
"cnlpt_temporal_rest" = "cnlpt.api.temporal_rest:rest"
"cnlpt_termexists_rest" = "cnlpt.api.termexists_rest:rest"
"cnlpt_timex_rest" = "cnlpt.api.timex_rest:rest"
"cnlpt_hier_rest" = "cnlpt.api.hier_rest:rest"
"cnlpt_cnn_rest" = "cnlpt.api.cnn_rest:rest"

[tool.ruff]
target-version = "py39"
Expand Down
4 changes: 2 additions & 2 deletions src/cnlpt/__main__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from . import __version__
from ._cli.main import cli


def main():
print(__version__)
cli()


if __name__ == "__main__":
Expand Down
28 changes: 28 additions & 0 deletions src/cnlpt/_cli/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import click

from .. import __version__
from .rest import rest_command


@click.group(invoke_without_command=True)
@click.option(
    "--version",
    type=bool,
    is_flag=True,
    default=False,
    help="Print the cnlp_transformers version.",
)
@click.pass_context
def cli(ctx: click.Context, version: bool):
    # Root command group. Documented with comments rather than a docstring
    # because click would surface a docstring in the generated help text.
    #
    # ctx:     the click context for this invocation.
    # version: True when the --version flag was passed.
    if ctx.invoked_subcommand is None:
        # Called without a subcommand: report the version if asked,
        # otherwise show the group's help, then stop.
        if version:
            print(__version__)
        else:
            click.echo(ctx.get_help())
        ctx.exit()
    # When a subcommand was named, fall through so click dispatches to it.


# Register the "rest" subcommand on the root group.
cli.add_command(rest_command)
33 changes: 33 additions & 0 deletions src/cnlpt/_cli/rest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import click

from ..api import MODEL_TYPES, get_rest_app


@click.command("rest", context_settings={"show_default": True})
@click.option(
    "--model-type",
    type=click.Choice(MODEL_TYPES),
    required=True,
)
@click.option(
    "-h",
    "--host",
    type=str,
    default="0.0.0.0",
    help="Host address to serve the REST app.",
)
@click.option(
    "-p", "--port", type=int, default=8000, help="Port to serve the REST app."
)
@click.option(
    "--reload",
    type=bool,
    is_flag=True,
    default=False,
    help="Auto-reload the REST app.",
)
def rest_command(model_type: str, host: str, port: int, reload: bool):
    """Start a REST application from a model."""
    # NOTE: the docstring above is emitted by click as the command's help
    # text, so the detailed notes live in comments instead.
    #
    # model_type: one of MODEL_TYPES; selects which REST app is served.
    # host:       address to bind the server to.
    # port:       TCP port to listen on.
    # reload:     when True, ask uvicorn to auto-reload the app.

    # Imported lazily so that loading the CLI does not import the server.
    import uvicorn

    if reload:
        # uvicorn only supports reload when the application is given as an
        # import string; handing it an app object together with reload=True
        # makes it log a warning and exit instead of serving.
        app = f"cnlpt.api.{model_type}_rest:app"
    else:
        app = get_rest_app(model_type)

    uvicorn.run(app, host=host, port=port, reload=reload)
52 changes: 52 additions & 0 deletions src/cnlpt/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Model types that have a REST application in this package. Each entry
# ``name`` is served by the sibling module ``<name>_rest``, which exposes the
# application as ``app``.
MODEL_TYPES = (
    "cnn",
    "current",
    "dtr",
    "event",
    "hier",
    "negation",
    "temporal",
    "termexists",
    "timex",
)


def get_rest_app(model_type: str):
    """Return the REST application object for ``model_type``.

    The module backing the requested model type is imported only when this
    function is called, so the modules for the other model types are never
    imported as a side effect.

    Args:
        model_type: One of the names listed in ``MODEL_TYPES``.

    Returns:
        The ``app`` attribute of the ``<model_type>_rest`` module.

    Raises:
        ValueError: If ``model_type`` is not listed in ``MODEL_TYPES``.
    """
    # Validate first so that only the known modules can ever be imported and
    # MODEL_TYPES remains the single place where model types are declared.
    if model_type not in MODEL_TYPES:
        raise ValueError(f"unknown model type: {model_type}")

    from importlib import import_module

    # Relative import against this package, equivalent to
    # ``from .<model_type>_rest import app``.
    return import_module(f".{model_type}_rest", package=__package__).app
34 changes: 6 additions & 28 deletions src/cnlpt/api/cnn_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,9 @@
from transformers import AutoTokenizer, PreTrainedTokenizer

from ..BaselineModels import CnnSentenceClassifier
from .cnlp_rest import UnannotatedDocument, create_dataset, resolve_device
from .utils import UnannotatedDocument, create_dataset, resolve_device

MODEL_NAME = os.getenv("MODEL_PATH")
if MODEL_NAME is None:
sys.stderr.write("This REST container requires a MODEL_PATH environment variable\n")
sys.exit(-1)
device = os.getenv("MODEL_DEVICE", "auto")
device = resolve_device(device)

Expand All @@ -52,6 +49,11 @@
@asynccontextmanager
async def lifespan():
global model, tokenizer, conf_dict
if MODEL_NAME is None:
sys.stderr.write(
"This REST container requires a MODEL_PATH environment variable\n"
)
sys.exit(-1)
conf_file = join(MODEL_NAME, "config.json")
with open(conf_file) as fp:
conf_dict = json.load(fp)
Expand Down Expand Up @@ -95,27 +97,3 @@ async def process(doc: UnannotatedDocument):
# but i'm outputting them all, for transparency
out_probabilities = [str(prob) for prob in probabilities]
return {"result": result, "probabilities": out_probabilities}


def rest():
import argparse

parser = argparse.ArgumentParser(
description="Run the http server for serving CNN model outputs."
)
parser.add_argument(
"-p",
"--port",
type=int,
help="The port number to run the server on",
default=8000,
)
args = parser.parse_args()

import uvicorn

uvicorn.run("cnlpt.api.cnn_rest:app", host="0.0.0.0", port=args.port, reload=False)


if __name__ == "__main__":
rest()
27 changes: 1 addition & 26 deletions src/cnlpt/api/current_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from transformers import Trainer
from transformers.tokenization_utils import PreTrainedTokenizer

from .cnlp_rest import (
from .utils import (
EntityDocument,
create_dataset,
create_instance_string,
Expand Down Expand Up @@ -104,28 +104,3 @@ async def process(doc: EntityDocument):
)

return output


def rest():
import argparse

parser = argparse.ArgumentParser(description="Run the http server for current")
parser.add_argument(
"-p",
"--port",
type=int,
help="The port number to run the server on",
default=8000,
)

args = parser.parse_args()

import uvicorn

uvicorn.run(
"cnlpt.api.current_rest:app", host="0.0.0.0", port=args.port, reload=True
)


if __name__ == "__main__":
rest()
27 changes: 2 additions & 25 deletions src/cnlpt/api/dtr_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
from transformers import Trainer
from transformers.tokenization_utils import PreTrainedTokenizer

from .cnlp_rest import (
from .temporal_rest import OLD_DTR_LABEL_LIST
from .utils import (
EntityDocument,
create_dataset,
create_instance_string,
initialize_cnlpt_model,
)
from .temporal_rest import OLD_DTR_LABEL_LIST

MODEL_NAME = "tmills/tiny-dtr"
logger = logging.getLogger("DocTimeRel Processor with xtremedistil encoder")
Expand Down Expand Up @@ -104,26 +104,3 @@ async def process(doc: EntityDocument):
)

return output


def rest():
import argparse

parser = argparse.ArgumentParser(description="Run the http server for negation")
parser.add_argument(
"-p",
"--port",
type=int,
help="The port number to run the server on",
default=8000,
)

args = parser.parse_args()

import uvicorn

uvicorn.run("cnlpt.api.dtr_rest:app", host="0.0.0.0", port=args.port, reload=True)


if __name__ == "__main__":
rest()
Loading

0 comments on commit 912713d

Please sign in to comment.