From 3a860f15c82b6602a3e5255e6c17e9e26ac3ca59 Mon Sep 17 00:00:00 2001 From: iusztinpaul Date: Wed, 11 Oct 2023 10:21:15 +0300 Subject: [PATCH] feat: Deploy financial assistant as a RESTful API --- .vscode/launch.json | 2 +- modules/financial_bot/.beamignore | 5 + modules/financial_bot/Makefile | 34 +++- .../financial_bot/financial_bot/__init__.py | 49 ++++++ .../financial_bot/financial_bot/constants.py | 2 +- .../financial_bot/langchain_bot.py | 7 +- modules/financial_bot/logging.yaml | 27 +++ modules/financial_bot/poetry.lock | 164 +++++++++++++++++- modules/financial_bot/pyproject.toml | 1 + modules/financial_bot/requirements.txt | 99 +++++++++++ modules/financial_bot/tools/bot.py | 96 ++++++++++ modules/financial_bot/tools/run_chain.py | 25 --- modules/training_pipeline/.beamignore | 3 - .../training_pipeline/tools/inference_run.py | 2 +- 14 files changed, 473 insertions(+), 43 deletions(-) create mode 100644 modules/financial_bot/.beamignore create mode 100644 modules/financial_bot/logging.yaml create mode 100644 modules/financial_bot/requirements.txt create mode 100644 modules/financial_bot/tools/bot.py delete mode 100644 modules/financial_bot/tools/run_chain.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 241648f..cf50888 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -123,7 +123,7 @@ "name": "Financial Bot [Dev]", "type": "python", "request": "launch", - "module": "tools.run_chain", + "module": "tools.bot", "justMyCode": false, "cwd": "${workspaceFolder}/modules/financial_bot", }, diff --git a/modules/financial_bot/.beamignore b/modules/financial_bot/.beamignore new file mode 100644 index 0000000..46ee623 --- /dev/null +++ b/modules/financial_bot/.beamignore @@ -0,0 +1,5 @@ +model_cache/* +results/* +output/* +logs/* +.ruff_cache/ diff --git a/modules/financial_bot/Makefile b/modules/financial_bot/Makefile index 2728f90..f326b8c 100644 --- a/modules/financial_bot/Makefile +++ b/modules/financial_bot/Makefile @@ -1,4 +1,4 @@ -### Install ### +# === Install === install: @echo "Installing financial bot..." @@ -25,10 +25,38 @@ add_dev: run: @echo "Running financial_bot..." - poetry run python -m tools.run_chain + poetry run python -m tools.bot -### PEP 8 ### +# === Beam === + +export_requirements: + @echo "Exporting requirements..." + + if [ -f requirements.txt ]; then rm requirements.txt; fi + poetry export -f requirements.txt --output requirements.txt --without-hashes + +run_beam: export_requirements + @echo "Running financial_bot on Beam..." + + BEAM_IGNORE_IMPORTS_OFF=true beam run ./tools/bot.py:run -d '{"env_file_path": "env", "model_cache_dir": "./model_cache"}' + +deploy_beam: export_requirements + @echo "Deploying financial_bot on Beam..." + + BEAM_IGNORE_IMPORTS_OFF=true beam deploy ./tools/bot.py:run + +call_restful_api: + curl -X POST \ + --compressed 'https://apps.beam.cloud/${DEPLOYMENT_ID}' \ + -H 'Accept: */*' \ + -H 'Accept-Encoding: gzip, deflate' \ + -H 'Authorization: Basic ${TOKEN}' \ + -H 'Connection: keep-alive' \ + -H 'Content-Type: application/json' \ + -d '{"about_me": "I am a student and I have some money that I want to invest.", "question": "Should I consider investing in stocks from the Tech Sector?"}' + +# === Formatting & Linting === # Be sure to install the dev dependencies first # lint_check: diff --git a/modules/financial_bot/financial_bot/__init__.py b/modules/financial_bot/financial_bot/__init__.py index e69de29..fb07bd3 100644 --- a/modules/financial_bot/financial_bot/__init__.py +++ b/modules/financial_bot/financial_bot/__init__.py @@ -0,0 +1,49 @@ +import logging +import logging.config +import os +import yaml + +from dotenv import load_dotenv, find_dotenv +from pathlib import Path + + +logger = logging.getLogger(__name__) + + +def initialize(logging_config_path: str = "logging.yaml", env_file_path: str = ".env"): + logger.info("Initializing logger...") + try: + initialize_logger(config_path=logging_config_path) + except FileNotFoundError: + logger.warning( + f"No logging configuration file found at: {logging_config_path}. Setting logging level to INFO." + ) + logging.basicConfig(level=logging.INFO) + + logger.info("Initializing env vars...") + if env_file_path is None: + env_file_path = find_dotenv(raise_error_if_not_found=True, usecwd=False) + + logger.info(f"Loading environment variables from: {env_file_path}") + found_env_file = load_dotenv(env_file_path, verbose=True, override=True) + if found_env_file is False: + raise RuntimeError(f"Could not find environment file at: {env_file_path}") + + +def initialize_logger( + config_path: str = "logging.yaml", logs_dir_name: str = "logs" +) -> logging.Logger: + """Initialize logger from a YAML config file.""" + + # Create logs directory. + config_path_parent = Path(config_path).parent + logs_dir = config_path_parent / logs_dir_name + logs_dir.mkdir(parents=True, exist_ok=True) + + with open(config_path, "rt") as f: + config = yaml.safe_load(f.read()) + + # Make sure that existing logger will still work. + config["disable_existing_loggers"] = False + + logging.config.dictConfig(config) diff --git a/modules/financial_bot/financial_bot/constants.py b/modules/financial_bot/financial_bot/constants.py index 3f94ca0..50d52b7 100644 --- a/modules/financial_bot/financial_bot/constants.py +++ b/modules/financial_bot/financial_bot/constants.py @@ -19,4 +19,4 @@ SYSTEM_MESSAGE = "You are a financial expert. Based on the context I provide, respond in a helpful manner" # === Misc === -DEBUG = True +DEBUG = False diff --git a/modules/financial_bot/financial_bot/langchain_bot.py b/modules/financial_bot/financial_bot/langchain_bot.py index df5b119..1c09eb5 100644 --- a/modules/financial_bot/financial_bot/langchain_bot.py +++ b/modules/financial_bot/financial_bot/langchain_bot.py @@ -1,4 +1,5 @@ import logging +from pathlib import Path from langchain import chains from langchain.memory import ConversationBufferMemory @@ -18,12 +19,16 @@ def __init__( self, llm_model_id: str = constants.LLM_MODEL_ID, llm_lora_model_id: str = constants.LLM_QLORA_CHECKPOINT, + model_cache_dir: Path = constants.CACHE_DIR, debug: bool = constants.DEBUG, ): self._qdrant_client = build_qdrant_client() self._embd_model = EmbeddingModelSingleton() self._llm_agent = build_huggingface_pipeline( - llm_model_id=llm_model_id, llm_lora_model_id=llm_lora_model_id, debug=debug + llm_model_id=llm_model_id, + llm_lora_model_id=llm_lora_model_id, + cache_dir=model_cache_dir, + debug=debug, ) self.finbot_chain = self.build_chain() diff --git a/modules/financial_bot/logging.yaml b/modules/financial_bot/logging.yaml new file mode 100644 index 0000000..3b1b6a4 --- /dev/null +++ b/modules/financial_bot/logging.yaml @@ -0,0 +1,27 @@ +version: 1 + +formatters: + simple: + format: "%(asctime)s - %(levelname)s - %(message)s" +handlers: + console: + class: logging.StreamHandler + level: DEBUG + formatter: simple + stream: ext://sys.stdout + file_info: + class: logging.FileHandler + formatter: simple + level: INFO + filename: logs/info.log + mode: 'a' + file_error: + class: logging.FileHandler + formatter: simple + level: ERROR + filename: logs/error.log + mode: 'a' +root: + level: INFO + handlers: [console, file_info, file_error] + \ No newline at end of file diff --git a/modules/financial_bot/poetry.lock b/modules/financial_bot/poetry.lock index 69d8e7b..ccd5db2 100644 --- a/modules/financial_bot/poetry.lock +++ b/modules/financial_bot/poetry.lock @@ -200,6 +200,26 @@ docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib- tests = ["attrs[tests-no-zope]", "zope-interface"] tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +[[package]] +name = "beam-sdk" +version = "0.14.6" +description = "" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "beam_sdk-0.14.6-py3-none-any.whl", hash = "sha256:3ca3f62bc0fd15585f70a71b380ee6dfd15cd96f2ecbba08f58946ded6d174c0"}, + {file = "beam_sdk-0.14.6.tar.gz", hash = "sha256:dd93e932ea75808e6be0d20ac892c40f3fbcfb625b17e2fdfe4a1e3acee308ff"}, +] + +[package.dependencies] +croniter = ">=1.3.7,<2.0.0" +importlib-metadata = "5.2.0" +Jinja2 = ">=3.1.2,<4.0.0" +marshmallow = "3.18.0" +marshmallow-dataclass = ">=8.5.9,<9.0.0" +typeguard = ">=2.13.3,<3.0.0" +validators = ">=0.20.0,<0.21.0" + [[package]] name = "bitsandbytes" version = "0.41.1" @@ -416,6 +436,20 @@ files = [ [package.dependencies] six = "*" +[[package]] +name = "croniter" +version = "1.4.1" +description = "croniter provides iteration for datetime object with cron like format" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "croniter-1.4.1-py2.py3-none-any.whl", hash = "sha256:9595da48af37ea06ec3a9f899738f1b2c1c13da3c38cea606ef7cd03ea421128"}, + {file = "croniter-1.4.1.tar.gz", hash = "sha256:1a6df60eacec3b7a0aa52a8f2ef251ae3dd2a7c7c8b9874e73e791636d55a361"}, +] + +[package.dependencies] +python-dateutil = "*" + [[package]] name = "dataclasses-json" version = "0.5.14" @@ -431,6 +465,17 @@ files = [ marshmallow = ">=3.18.0,<4.0.0" typing-inspect = ">=0.4.0,<1" +[[package]] +name = "decorator" +version = "5.1.1" +description = "Decorators for Humans" +optional = false +python-versions = ">=3.5" +files = [ + {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, + {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, +] + [[package]] name = "dulwich" version = "0.21.6" @@ -1024,6 +1069,25 @@ files = [ {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, ] +[[package]] +name = "importlib-metadata" +version = "5.2.0" +description = "Read metadata from Python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "importlib_metadata-5.2.0-py3-none-any.whl", hash = "sha256:0eafa39ba42bf225fc00e67f701d71f85aead9f878569caf13c3724f704b970f"}, + {file = "importlib_metadata-5.2.0.tar.gz", hash = "sha256:404d48d62bba0b7a77ff9d405efd91501bef2e67ff4ace0bed40a0cf28c3c7cd"}, +] + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +perf = ["ipython"] +testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"] + [[package]] name = "jinja2" version = "3.1.2" @@ -1214,24 +1278,48 @@ files = [ [[package]] name = "marshmallow" -version = "3.20.1" +version = "3.18.0" description = "A lightweight library for converting complex datatypes to and from native Python datatypes." optional = false -python-versions = ">=3.8" +python-versions = ">=3.7" files = [ - {file = "marshmallow-3.20.1-py3-none-any.whl", hash = "sha256:684939db93e80ad3561392f47be0230743131560a41c5110684c16e21ade0a5c"}, - {file = "marshmallow-3.20.1.tar.gz", hash = "sha256:5d2371bbe42000f2b3fb5eaa065224df7d8f8597bc19a1bbfa5bfe7fba8da889"}, + {file = "marshmallow-3.18.0-py3-none-any.whl", hash = "sha256:35e02a3a06899c9119b785c12a22f4cda361745d66a71ab691fd7610202ae104"}, + {file = "marshmallow-3.18.0.tar.gz", hash = "sha256:6804c16114f7fce1f5b4dadc31f4674af23317fcc7f075da21e35c1a35d781f7"}, ] [package.dependencies] packaging = ">=17.0" [package.extras] -dev = ["flake8 (==6.0.0)", "flake8-bugbear (==23.7.10)", "mypy (==1.4.1)", "pre-commit (>=2.4,<4.0)", "pytest", "pytz", "simplejson", "tox"] -docs = ["alabaster (==0.7.13)", "autodocsumm (==0.2.11)", "sphinx (==7.0.1)", "sphinx-issues (==3.0.1)", "sphinx-version-warning (==1.1.2)"] -lint = ["flake8 (==6.0.0)", "flake8-bugbear (==23.7.10)", "mypy (==1.4.1)", "pre-commit (>=2.4,<4.0)"] +dev = ["flake8 (==5.0.4)", "flake8-bugbear (==22.9.11)", "mypy (==0.971)", "pre-commit (>=2.4,<3.0)", "pytest", "pytz", "simplejson", "tox"] +docs = ["alabaster (==0.7.12)", "autodocsumm (==0.2.9)", "sphinx (==5.1.1)", "sphinx-issues (==3.0.1)", "sphinx-version-warning (==1.1.2)"] +lint = ["flake8 (==5.0.4)", "flake8-bugbear (==22.9.11)", "mypy (==0.971)", "pre-commit (>=2.4,<3.0)"] tests = ["pytest", "pytz", "simplejson"] +[[package]] +name = "marshmallow-dataclass" +version = "8.6.0" +description = "Python library to convert dataclasses into marshmallow schemas." +optional = false +python-versions = ">=3.6" +files = [ + {file = "marshmallow_dataclass-8.6.0-py3-none-any.whl", hash = "sha256:7885e9b5f5287b64573b174d69334fd20de1628001a4fa2adc8e75be5196755e"}, + {file = "marshmallow_dataclass-8.6.0.tar.gz", hash = "sha256:a21f4d050a1d24249fd43aa56c7e4aea4b6454e049dc2c5f1496f479e30bf5d7"}, +] + +[package.dependencies] +marshmallow = ">=3.13.0,<4.0" +typing-extensions = {version = ">=4.2.0", markers = "python_version < \"3.11\" and python_version >= \"3.7\""} +typing-inspect = ">=0.8.0,<1.0" + +[package.extras] +dev = ["marshmallow (>=3.18.0,<4.0)", "marshmallow-enum", "pre-commit (>=2.17,<3.0)", "pytest (>=5.4)", "pytest-mypy-plugins (>=1.2.0)", "sphinx", "typeguard (>=2.4.1,<4.0.0)"] +docs = ["sphinx"] +enum = ["marshmallow (>=3.18.0,<4.0)", "marshmallow-enum"] +lint = ["pre-commit (>=2.17,<3.0)"] +tests = ["pytest (>=5.4)", "pytest-mypy-plugins (>=1.2.0)"] +union = ["typeguard (>=2.4.1,<4.0.0)"] + [[package]] name = "mdurl" version = "0.1.2" @@ -1661,6 +1749,20 @@ toml = ["toml"] tomli = ["tomli", "tomli-w"] yaml = ["ruamel.yaml (>=0.17)"] +[[package]] +name = "python-dateutil" +version = "2.8.2" +description = "Extensions to the standard Python datetime module" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, + {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, +] + +[package.dependencies] +six = ">=1.5" + [[package]] name = "python-dotenv" version = "1.0.0" @@ -2645,6 +2747,21 @@ torchhub = ["filelock", "huggingface-hub (>=0.15.1,<1.0)", "importlib-metadata", video = ["av (==9.2.0)", "decord (==0.6.0)"] vision = ["Pillow (<10.0.0)"] +[[package]] +name = "typeguard" +version = "2.13.3" +description = "Run-time type checker for Python" +optional = false +python-versions = ">=3.5.3" +files = [ + {file = "typeguard-2.13.3-py3-none-any.whl", hash = "sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1"}, + {file = "typeguard-2.13.3.tar.gz", hash = "sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4"}, +] + +[package.extras] +doc = ["sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["mypy", "pytest", "typing-extensions"] + [[package]] name = "types-requests" version = "2.31.0.2" @@ -2712,6 +2829,22 @@ brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +[[package]] +name = "validators" +version = "0.20.0" +description = "Python Data Validation for Humans™." +optional = false +python-versions = ">=3.4" +files = [ + {file = "validators-0.20.0.tar.gz", hash = "sha256:24148ce4e64100a2d5e267233e23e7afeb55316b47d30faae7eb6e7292bc226a"}, +] + +[package.dependencies] +decorator = ">=3.4.0" + +[package.extras] +test = ["flake8 (>=2.4.0)", "isort (>=4.2.2)", "pytest (>=2.2.3)"] + [[package]] name = "websocket-client" version = "1.3.3" @@ -2910,7 +3043,22 @@ files = [ idna = ">=2.0" multidict = ">=4.0" +[[package]] +name = "zipp" +version = "3.17.0" +description = "Backport of pathlib-compatible object wrapper for zip files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"}, + {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] + [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "f9550d7d7bfafd0b554a40891f61d90c55838b225886a2030efed5736ccffb61" +content-hash = "db28575fa8aa58ca6edd45f643a9801e4759979ecdbcdf5fb2b18901e2409569" diff --git a/modules/financial_bot/pyproject.toml b/modules/financial_bot/pyproject.toml index 121818a..5a95d49 100644 --- a/modules/financial_bot/pyproject.toml +++ b/modules/financial_bot/pyproject.toml @@ -21,6 +21,7 @@ fire = "^0.5.0" comet-llm = "^1.1.0" bitsandbytes = "^0.41.1" torch = "2.0.1" +beam-sdk = "0.14.6" [tool.poetry.group.dev.dependencies] black = "^23.7.0" diff --git a/modules/financial_bot/requirements.txt b/modules/financial_bot/requirements.txt new file mode 100644 index 0000000..4770ee9 --- /dev/null +++ b/modules/financial_bot/requirements.txt @@ -0,0 +1,99 @@ +accelerate==0.21.0 ; python_version >= "3.10" and python_version < "3.12" +aiohttp==3.8.5 ; python_version >= "3.10" and python_version < "3.12" +aiosignal==1.3.1 ; python_version >= "3.10" and python_version < "3.12" +anyio==4.0.0 ; python_version >= "3.10" and python_version < "3.12" +async-timeout==4.0.3 ; python_version >= "3.10" and python_version < "3.12" +attrs==23.1.0 ; python_version >= "3.10" and python_version < "3.12" +beam-sdk==0.14.6 ; python_version >= "3.10" and python_version < "3.12" +bitsandbytes==0.41.1 ; python_version >= "3.10" and python_version < "3.12" +certifi==2023.7.22 ; python_version >= "3.10" and python_version < "3.12" +charset-normalizer==3.2.0 ; python_version >= "3.10" and python_version < "3.12" +colorama==0.4.6 ; python_version >= "3.10" and python_version < "3.12" and platform_system == "Windows" +comet-llm==1.3.0 ; python_version >= "3.10" and python_version < "3.12" +comet-ml==3.33.10 ; python_version >= "3.10" and python_version < "3.12" +configobj==5.0.8 ; python_version >= "3.10" and python_version < "3.12" +croniter==1.4.1 ; python_version >= "3.10" and python_version < "3.12" +dataclasses-json==0.5.14 ; python_version >= "3.10" and python_version < "3.12" +decorator==5.1.1 ; python_version >= "3.10" and python_version < "3.12" +dulwich==0.21.6 ; python_version >= "3.10" and python_version < "3.12" +einops==0.6.1 ; python_version >= "3.10" and python_version < "3.12" +everett[ini]==3.1.0 ; python_version >= "3.10" and python_version < "3.12" +exceptiongroup==1.1.3 ; python_version >= "3.10" and python_version < "3.11" +filelock==3.12.3 ; python_version >= "3.10" and python_version < "3.12" +fire==0.5.0 ; python_version >= "3.10" and python_version < "3.12" +flatten-dict==0.4.2 ; python_version >= "3.10" and python_version < "3.12" +frozenlist==1.4.0 ; python_version >= "3.10" and python_version < "3.12" +fsspec==2023.9.0 ; python_version >= "3.10" and python_version < "3.12" +greenlet==2.0.2 ; python_version >= "3.10" and python_version < "3.12" and (platform_machine == "aarch64" or platform_machine == "ppc64le" or platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "win32" or platform_machine == "WIN32") +grpcio-tools==1.58.0 ; python_version >= "3.10" and python_version < "3.12" +grpcio==1.58.0 ; python_version >= "3.10" and python_version < "3.12" +h11==0.14.0 ; python_version >= "3.10" and python_version < "3.12" +h2==4.1.0 ; python_version >= "3.10" and python_version < "3.12" +hpack==4.0.0 ; python_version >= "3.10" and python_version < "3.12" +httpcore==0.17.3 ; python_version >= "3.10" and python_version < "3.12" +httpx[http2]==0.24.1 ; python_version >= "3.10" and python_version < "3.12" +huggingface-hub==0.16.4 ; python_version >= "3.10" and python_version < "3.12" +hyperframe==6.0.1 ; python_version >= "3.10" and python_version < "3.12" +idna==3.4 ; python_version >= "3.10" and python_version < "3.12" +importlib-metadata==5.2.0 ; python_version >= "3.10" and python_version < "3.12" +jinja2==3.1.2 ; python_version >= "3.10" and python_version < "3.12" +jsonschema-specifications==2023.7.1 ; python_version >= "3.10" and python_version < "3.12" +jsonschema==4.19.0 ; python_version >= "3.10" and python_version < "3.12" +langchain==0.0.285 ; python_version >= "3.10" and python_version < "3.12" +langsmith==0.0.35 ; python_version >= "3.10" and python_version < "3.12" +markdown-it-py==3.0.0 ; python_version >= "3.10" and python_version < "3.12" +markupsafe==2.1.3 ; python_version >= "3.10" and python_version < "3.12" +marshmallow-dataclass==8.6.0 ; python_version >= "3.10" and python_version < "3.12" +marshmallow==3.18.0 ; python_version >= "3.10" and python_version < "3.12" +mdurl==0.1.2 ; python_version >= "3.10" and python_version < "3.12" +mpmath==1.3.0 ; python_version >= "3.10" and python_version < "3.12" +multidict==6.0.4 ; python_version >= "3.10" and python_version < "3.12" +mypy-extensions==1.0.0 ; python_version >= "3.10" and python_version < "3.12" +networkx==3.1 ; python_version >= "3.10" and python_version < "3.12" +numexpr==2.8.5 ; python_version >= "3.10" and python_version < "3.12" +numpy==1.25.2 ; python_version >= "3.10" and python_version < "3.12" +packaging==23.1 ; python_version >= "3.10" and python_version < "3.12" +peft==0.4.0 ; python_version >= "3.10" and python_version < "3.12" +protobuf==4.24.3 ; python_version >= "3.10" and python_version < "3.12" +psutil==5.9.5 ; python_version >= "3.10" and python_version < "3.12" +pydantic==1.10.12 ; python_version >= "3.10" and python_version < "3.12" +pygments==2.16.1 ; python_version >= "3.10" and python_version < "3.12" +python-box==6.1.0 ; python_version >= "3.10" and python_version < "3.12" +python-dateutil==2.8.2 ; python_version >= "3.10" and python_version < "3.12" +python-dotenv==1.0.0 ; python_version >= "3.10" and python_version < "3.12" +pyyaml==6.0.1 ; python_version >= "3.10" and python_version < "3.12" +qdrant-client==1.1.1 ; python_version >= "3.10" and python_version < "3.12" +referencing==0.30.2 ; python_version >= "3.10" and python_version < "3.12" +regex==2023.8.8 ; python_version >= "3.10" and python_version < "3.12" +requests-toolbelt==1.0.0 ; python_version >= "3.10" and python_version < "3.12" +requests==2.31.0 ; python_version >= "3.10" and python_version < "3.12" +rich==13.5.2 ; python_version >= "3.10" and python_version < "3.12" +rpds-py==0.10.2 ; python_version >= "3.10" and python_version < "3.12" +safetensors==0.3.3 ; python_version >= "3.10" and python_version < "3.12" +scipy==1.11.2 ; python_version >= "3.10" and python_version < "3.12" +semantic-version==2.10.0 ; python_version >= "3.10" and python_version < "3.12" +sentry-sdk==1.30.0 ; python_version >= "3.10" and python_version < "3.12" +setuptools==68.2.0 ; python_version >= "3.10" and python_version < "3.12" +simplejson==3.19.1 ; python_version >= "3.10" and python_version < "3.12" +six==1.16.0 ; python_version >= "3.10" and python_version < "3.12" +sniffio==1.3.0 ; python_version >= "3.10" and python_version < "3.12" +sqlalchemy==2.0.20 ; python_version >= "3.10" and python_version < "3.12" +sympy==1.12 ; python_version >= "3.10" and python_version < "3.12" +tenacity==8.2.3 ; python_version >= "3.10" and python_version < "3.12" +termcolor==2.3.0 ; python_version >= "3.10" and python_version < "3.12" +tokenizers==0.13.3 ; python_version >= "3.10" and python_version < "3.12" +torch==2.0.1 ; python_version >= "3.10" and python_version < "3.12" +tqdm==4.66.1 ; python_version >= "3.10" and python_version < "3.12" +transformers==4.33.1 ; python_version >= "3.10" and python_version < "3.12" +typeguard==2.13.3 ; python_version >= "3.10" and python_version < "3.12" +types-requests==2.31.0.2 ; python_version >= "3.10" and python_version < "3.12" +types-urllib3==1.26.25.14 ; python_version >= "3.10" and python_version < "3.12" +typing-extensions==4.7.1 ; python_version >= "3.10" and python_version < "3.12" +typing-inspect==0.9.0 ; python_version >= "3.10" and python_version < "3.12" +urllib3==1.26.16 ; python_version >= "3.10" and python_version < "3.12" +validators==0.20.0 ; python_version >= "3.10" and python_version < "3.12" +websocket-client==1.3.3 ; python_version >= "3.10" and python_version < "3.12" +wrapt==1.15.0 ; python_version >= "3.10" and python_version < "3.12" +wurlitzer==3.0.3 ; python_version >= "3.10" and python_version < "3.12" +yarl==1.9.2 ; python_version >= "3.10" and python_version < "3.12" +zipp==3.17.0 ; python_version >= "3.10" and python_version < "3.12" diff --git a/modules/financial_bot/tools/bot.py b/modules/financial_bot/tools/bot.py new file mode 100644 index 0000000..0c91495 --- /dev/null +++ b/modules/financial_bot/tools/bot.py @@ -0,0 +1,96 @@ +import logging +from pathlib import Path + +import fire +from beam import App, Image, Runtime, Volume, VolumeType + +financial_bot = App( + name="financial_bot", + runtime=Runtime( + cpu=4, + memory="64Gi", + gpu="T4", + image=Image(python_version="python3.10", python_packages="requirements.txt"), + ), + volumes=[ + Volume( + path="./model_cache", name="model_cache", volume_type=VolumeType.Persistent + ), + ], +) + + +logger = logging.getLogger(__name__) + + +def load_models( + env_file_path: str = ".env", + logging_config_path: str = "logging.yaml", + model_cache_dir: str = "./model_cache", +): + from financial_bot import initialize + + # Be sure to initialize the environment variables before importing any other modules. + initialize(logging_config_path=logging_config_path, env_file_path=env_file_path) + + from financial_bot import utils + from financial_bot.langchain_bot import FinancialBot + + logger.info("#" * 100) + utils.log_available_gpu_memory() + utils.log_available_ram() + logger.info("#" * 100) + + bot = FinancialBot(model_cache_dir=Path(model_cache_dir)) + + return bot + + +@financial_bot.rest_api(keep_warm_seconds=300, loader=load_models) +def run(**inputs): + from financial_bot import utils + + logger.info("#" * 100) + utils.log_available_gpu_memory() + utils.log_available_ram() + logger.info("#" * 100) + + # TODO: Check how memory is persisted between requests. + bot = inputs["context"] + input_payload = { + "about_me": inputs["about_me"], + "question": inputs["question"], + } + response = bot.answer(**input_payload) + + return response + + +@financial_bot.rest_api(keep_warm_seconds=300, loader=load_models) +def run_dev(**inputs): + from financial_bot import utils + + logger.info("#" * 100) + utils.log_available_gpu_memory() + utils.log_available_ram() + logger.info("#" * 100) + + bot = inputs["context"] + + input_payload = { + "about_me": "I'm a student and I have some money that I want to invest.", + "question": "Should I consider investing in stocks from the Tech Sector?", + } + response = bot.answer(**input_payload) + print(response) + + next_question = "What about the Energy Sector?" + input_payload["question"] = next_question + response = bot.answer(**input_payload) + print(response) + + return response + + +if __name__ == "__main__": + fire.Fire(run) diff --git a/modules/financial_bot/tools/run_chain.py b/modules/financial_bot/tools/run_chain.py deleted file mode 100644 index 475db9a..0000000 --- a/modules/financial_bot/tools/run_chain.py +++ /dev/null @@ -1,25 +0,0 @@ -import dotenv -import fire - -from financial_bot.langchain_bot import FinancialBot - -dotenv.load_dotenv() - - -def main(): - bot = FinancialBot() - input_payload = { - "about_me": "I'm a student and I have some money that I want to invest.", - "question": "Should I consider investing in stocks from the Tech Sector?", - } - response = bot.answer(**input_payload) - print(response) - - next_question = "What about the Energy Sector?" - input_payload["question"] = next_question - response = bot.answer(**input_payload) - print(response) - - -if __name__ == "__main__": - fire.Fire(main) diff --git a/modules/training_pipeline/.beamignore b/modules/training_pipeline/.beamignore index 5d52966..7cb0b9f 100644 --- a/modules/training_pipeline/.beamignore +++ b/modules/training_pipeline/.beamignore @@ -1,9 +1,6 @@ -# Running artifacts model_cache/* results/* output/* logs/* .ruff_cache/ - -# Datasets dataset/* \ No newline at end of file diff --git a/modules/training_pipeline/tools/inference_run.py b/modules/training_pipeline/tools/inference_run.py index f10dae5..8a60fba 100644 --- a/modules/training_pipeline/tools/inference_run.py +++ b/modules/training_pipeline/tools/inference_run.py @@ -10,7 +10,7 @@ runtime=Runtime( cpu=4, memory="64Gi", - gpu="A10G", + gpu="T4", image=Image(python_version="python3.10", python_packages="requirements.txt"), ), volumes=[