From 3a860f15c82b6602a3e5255e6c17e9e26ac3ca59 Mon Sep 17 00:00:00 2001
From: iusztinpaul
Date: Wed, 11 Oct 2023 10:21:15 +0300
Subject: [PATCH] feat: Deploy financial assistant as a RESTful API
---
.vscode/launch.json | 2 +-
modules/financial_bot/.beamignore | 5 +
modules/financial_bot/Makefile | 34 +++-
.../financial_bot/financial_bot/__init__.py | 49 ++++++
.../financial_bot/financial_bot/constants.py | 2 +-
.../financial_bot/langchain_bot.py | 7 +-
modules/financial_bot/logging.yaml | 27 +++
modules/financial_bot/poetry.lock | 164 +++++++++++++++++-
modules/financial_bot/pyproject.toml | 1 +
modules/financial_bot/requirements.txt | 99 +++++++++++
modules/financial_bot/tools/bot.py | 96 ++++++++++
modules/financial_bot/tools/run_chain.py | 25 ---
modules/training_pipeline/.beamignore | 3 -
.../training_pipeline/tools/inference_run.py | 2 +-
14 files changed, 473 insertions(+), 43 deletions(-)
create mode 100644 modules/financial_bot/.beamignore
create mode 100644 modules/financial_bot/logging.yaml
create mode 100644 modules/financial_bot/requirements.txt
create mode 100644 modules/financial_bot/tools/bot.py
delete mode 100644 modules/financial_bot/tools/run_chain.py
diff --git a/.vscode/launch.json b/.vscode/launch.json
index 241648f..cf50888 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -123,7 +123,7 @@
"name": "Financial Bot [Dev]",
"type": "python",
"request": "launch",
- "module": "tools.run_chain",
+ "module": "tools.bot",
"justMyCode": false,
"cwd": "${workspaceFolder}/modules/financial_bot",
},
diff --git a/modules/financial_bot/.beamignore b/modules/financial_bot/.beamignore
new file mode 100644
index 0000000..46ee623
--- /dev/null
+++ b/modules/financial_bot/.beamignore
@@ -0,0 +1,5 @@
+model_cache/*
+results/*
+output/*
+logs/*
+.ruff_cache/
diff --git a/modules/financial_bot/Makefile b/modules/financial_bot/Makefile
index 2728f90..f326b8c 100644
--- a/modules/financial_bot/Makefile
+++ b/modules/financial_bot/Makefile
@@ -1,4 +1,4 @@
-### Install ###
+# === Install ===
install:
@echo "Installing financial bot..."
@@ -25,10 +25,38 @@ add_dev:
run:
@echo "Running financial_bot..."
- poetry run python -m tools.run_chain
+ poetry run python -m tools.bot
-### PEP 8 ###
+# === Beam ===
+
+# Regenerate requirements.txt from the Poetry project, without hashes.
+export_requirements:
+	@echo "Exporting requirements..."
+	rm -f requirements.txt
+	poetry export -f requirements.txt --output requirements.txt --without-hashes
+
+# Re-export requirements.txt, then invoke tools/bot.py:run through `beam run` with a JSON payload.
+run_beam: export_requirements
+ @echo "Running financial_bot on Beam..."
+ BEAM_IGNORE_IMPORTS_OFF=true beam run ./tools/bot.py:run -d '{"env_file_path": "env", "model_cache_dir": "./model_cache"}'
+
+# Re-export requirements.txt, then deploy tools/bot.py:run through `beam deploy`.
+deploy_beam: export_requirements
+ @echo "Deploying financial_bot on Beam..."
+ BEAM_IGNORE_IMPORTS_OFF=true beam deploy ./tools/bot.py:run
+
+# POST a sample question to the deployed app; expects DEPLOYMENT_ID and TOKEN to be set.
+call_restful_api:
+	curl -X POST --compressed 'https://apps.beam.cloud/${DEPLOYMENT_ID}' \
+		-H 'Accept: */*' \
+		-H 'Accept-Encoding: gzip, deflate' \
+		-H 'Authorization: Basic ${TOKEN}' \
+		-H 'Connection: keep-alive' \
+		-H 'Content-Type: application/json' \
+		-d '{"about_me": "I am a student and I have some money that I want to invest.", "question": "Should I consider investing in stocks from the Tech Sector?"}'
+
+# === Formatting & Linting ===
# Be sure to install the dev dependencies first #
lint_check:
diff --git a/modules/financial_bot/financial_bot/__init__.py b/modules/financial_bot/financial_bot/__init__.py
index e69de29..fb07bd3 100644
--- a/modules/financial_bot/financial_bot/__init__.py
+++ b/modules/financial_bot/financial_bot/__init__.py
@@ -0,0 +1,49 @@
+import logging
+import logging.config
+import os
+import yaml
+
+from dotenv import load_dotenv, find_dotenv
+from pathlib import Path
+
+
+logger = logging.getLogger(__name__)
+
+
+def initialize(logging_config_path: str = "logging.yaml", env_file_path: str = ".env"):
+ logger.info("Initializing logger...")
+ try:
+ initialize_logger(config_path=logging_config_path)
+ except FileNotFoundError:
+ logger.warning(
+ f"No logging configuration file found at: {logging_config_path}. Setting logging level to INFO."
+ )
+ logging.basicConfig(level=logging.INFO)
+
+ logger.info("Initializing env vars...")
+ if env_file_path is None:
+ env_file_path = find_dotenv(raise_error_if_not_found=True, usecwd=False)
+
+ logger.info(f"Loading environment variables from: {env_file_path}")
+ found_env_file = load_dotenv(env_file_path, verbose=True, override=True)
+ if found_env_file is False:
+ raise RuntimeError(f"Could not find environment file at: {env_file_path}")
+
+
+def initialize_logger(
+ config_path: str = "logging.yaml", logs_dir_name: str = "logs"
+) -> logging.Logger:
+ """Initialize logger from a YAML config file."""
+
+ # Create logs directory.
+ config_path_parent = Path(config_path).parent
+ logs_dir = config_path_parent / logs_dir_name
+ logs_dir.mkdir(parents=True, exist_ok=True)
+
+ with open(config_path, "rt") as f:
+ config = yaml.safe_load(f.read())
+
+ # Make sure that existing logger will still work.
+ config["disable_existing_loggers"] = False
+
+ logging.config.dictConfig(config)
diff --git a/modules/financial_bot/financial_bot/constants.py b/modules/financial_bot/financial_bot/constants.py
index 3f94ca0..50d52b7 100644
--- a/modules/financial_bot/financial_bot/constants.py
+++ b/modules/financial_bot/financial_bot/constants.py
@@ -19,4 +19,4 @@
SYSTEM_MESSAGE = "You are a financial expert. Based on the context I provide, respond in a helpful manner"
# === Misc ===
-DEBUG = True
+DEBUG = False
diff --git a/modules/financial_bot/financial_bot/langchain_bot.py b/modules/financial_bot/financial_bot/langchain_bot.py
index df5b119..1c09eb5 100644
--- a/modules/financial_bot/financial_bot/langchain_bot.py
+++ b/modules/financial_bot/financial_bot/langchain_bot.py
@@ -1,4 +1,5 @@
import logging
+from pathlib import Path
from langchain import chains
from langchain.memory import ConversationBufferMemory
@@ -18,12 +19,16 @@ def __init__(
self,
llm_model_id: str = constants.LLM_MODEL_ID,
llm_lora_model_id: str = constants.LLM_QLORA_CHECKPOINT,
+ model_cache_dir: Path = constants.CACHE_DIR,
debug: bool = constants.DEBUG,
):
self._qdrant_client = build_qdrant_client()
self._embd_model = EmbeddingModelSingleton()
self._llm_agent = build_huggingface_pipeline(
- llm_model_id=llm_model_id, llm_lora_model_id=llm_lora_model_id, debug=debug
+ llm_model_id=llm_model_id,
+ llm_lora_model_id=llm_lora_model_id,
+ cache_dir=model_cache_dir,
+ debug=debug,
)
self.finbot_chain = self.build_chain()
diff --git a/modules/financial_bot/logging.yaml b/modules/financial_bot/logging.yaml
new file mode 100644
index 0000000..3b1b6a4
--- /dev/null
+++ b/modules/financial_bot/logging.yaml
@@ -0,0 +1,27 @@
+version: 1
+
+formatters:
+ simple:
+ format: "%(asctime)s - %(levelname)s - %(message)s"
+handlers:
+ console:
+ class: logging.StreamHandler
+ level: DEBUG
+ formatter: simple
+ stream: ext://sys.stdout
+ file_info:
+ class: logging.FileHandler
+ formatter: simple
+ level: INFO
+ filename: logs/info.log
+ mode: 'a'
+ file_error:
+ class: logging.FileHandler
+ formatter: simple
+ level: ERROR
+ filename: logs/error.log
+ mode: 'a'
+root:
+ level: INFO
+ handlers: [console, file_info, file_error]
+
\ No newline at end of file
diff --git a/modules/financial_bot/poetry.lock b/modules/financial_bot/poetry.lock
index 69d8e7b..ccd5db2 100644
--- a/modules/financial_bot/poetry.lock
+++ b/modules/financial_bot/poetry.lock
@@ -200,6 +200,26 @@ docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-
tests = ["attrs[tests-no-zope]", "zope-interface"]
tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
+[[package]]
+name = "beam-sdk"
+version = "0.14.6"
+description = ""
+optional = false
+python-versions = ">=3.7,<4.0"
+files = [
+ {file = "beam_sdk-0.14.6-py3-none-any.whl", hash = "sha256:3ca3f62bc0fd15585f70a71b380ee6dfd15cd96f2ecbba08f58946ded6d174c0"},
+ {file = "beam_sdk-0.14.6.tar.gz", hash = "sha256:dd93e932ea75808e6be0d20ac892c40f3fbcfb625b17e2fdfe4a1e3acee308ff"},
+]
+
+[package.dependencies]
+croniter = ">=1.3.7,<2.0.0"
+importlib-metadata = "5.2.0"
+Jinja2 = ">=3.1.2,<4.0.0"
+marshmallow = "3.18.0"
+marshmallow-dataclass = ">=8.5.9,<9.0.0"
+typeguard = ">=2.13.3,<3.0.0"
+validators = ">=0.20.0,<0.21.0"
+
[[package]]
name = "bitsandbytes"
version = "0.41.1"
@@ -416,6 +436,20 @@ files = [
[package.dependencies]
six = "*"
+[[package]]
+name = "croniter"
+version = "1.4.1"
+description = "croniter provides iteration for datetime object with cron like format"
+optional = false
+python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+ {file = "croniter-1.4.1-py2.py3-none-any.whl", hash = "sha256:9595da48af37ea06ec3a9f899738f1b2c1c13da3c38cea606ef7cd03ea421128"},
+ {file = "croniter-1.4.1.tar.gz", hash = "sha256:1a6df60eacec3b7a0aa52a8f2ef251ae3dd2a7c7c8b9874e73e791636d55a361"},
+]
+
+[package.dependencies]
+python-dateutil = "*"
+
[[package]]
name = "dataclasses-json"
version = "0.5.14"
@@ -431,6 +465,17 @@ files = [
marshmallow = ">=3.18.0,<4.0.0"
typing-inspect = ">=0.4.0,<1"
+[[package]]
+name = "decorator"
+version = "5.1.1"
+description = "Decorators for Humans"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"},
+ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
+]
+
[[package]]
name = "dulwich"
version = "0.21.6"
@@ -1024,6 +1069,25 @@ files = [
{file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"},
]
+[[package]]
+name = "importlib-metadata"
+version = "5.2.0"
+description = "Read metadata from Python packages"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "importlib_metadata-5.2.0-py3-none-any.whl", hash = "sha256:0eafa39ba42bf225fc00e67f701d71f85aead9f878569caf13c3724f704b970f"},
+ {file = "importlib_metadata-5.2.0.tar.gz", hash = "sha256:404d48d62bba0b7a77ff9d405efd91501bef2e67ff4ace0bed40a0cf28c3c7cd"},
+]
+
+[package.dependencies]
+zipp = ">=0.5"
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+perf = ["ipython"]
+testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"]
+
[[package]]
name = "jinja2"
version = "3.1.2"
@@ -1214,24 +1278,48 @@ files = [
[[package]]
name = "marshmallow"
-version = "3.20.1"
+version = "3.18.0"
description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
optional = false
-python-versions = ">=3.8"
+python-versions = ">=3.7"
files = [
- {file = "marshmallow-3.20.1-py3-none-any.whl", hash = "sha256:684939db93e80ad3561392f47be0230743131560a41c5110684c16e21ade0a5c"},
- {file = "marshmallow-3.20.1.tar.gz", hash = "sha256:5d2371bbe42000f2b3fb5eaa065224df7d8f8597bc19a1bbfa5bfe7fba8da889"},
+ {file = "marshmallow-3.18.0-py3-none-any.whl", hash = "sha256:35e02a3a06899c9119b785c12a22f4cda361745d66a71ab691fd7610202ae104"},
+ {file = "marshmallow-3.18.0.tar.gz", hash = "sha256:6804c16114f7fce1f5b4dadc31f4674af23317fcc7f075da21e35c1a35d781f7"},
]
[package.dependencies]
packaging = ">=17.0"
[package.extras]
-dev = ["flake8 (==6.0.0)", "flake8-bugbear (==23.7.10)", "mypy (==1.4.1)", "pre-commit (>=2.4,<4.0)", "pytest", "pytz", "simplejson", "tox"]
-docs = ["alabaster (==0.7.13)", "autodocsumm (==0.2.11)", "sphinx (==7.0.1)", "sphinx-issues (==3.0.1)", "sphinx-version-warning (==1.1.2)"]
-lint = ["flake8 (==6.0.0)", "flake8-bugbear (==23.7.10)", "mypy (==1.4.1)", "pre-commit (>=2.4,<4.0)"]
+dev = ["flake8 (==5.0.4)", "flake8-bugbear (==22.9.11)", "mypy (==0.971)", "pre-commit (>=2.4,<3.0)", "pytest", "pytz", "simplejson", "tox"]
+docs = ["alabaster (==0.7.12)", "autodocsumm (==0.2.9)", "sphinx (==5.1.1)", "sphinx-issues (==3.0.1)", "sphinx-version-warning (==1.1.2)"]
+lint = ["flake8 (==5.0.4)", "flake8-bugbear (==22.9.11)", "mypy (==0.971)", "pre-commit (>=2.4,<3.0)"]
tests = ["pytest", "pytz", "simplejson"]
+[[package]]
+name = "marshmallow-dataclass"
+version = "8.6.0"
+description = "Python library to convert dataclasses into marshmallow schemas."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "marshmallow_dataclass-8.6.0-py3-none-any.whl", hash = "sha256:7885e9b5f5287b64573b174d69334fd20de1628001a4fa2adc8e75be5196755e"},
+ {file = "marshmallow_dataclass-8.6.0.tar.gz", hash = "sha256:a21f4d050a1d24249fd43aa56c7e4aea4b6454e049dc2c5f1496f479e30bf5d7"},
+]
+
+[package.dependencies]
+marshmallow = ">=3.13.0,<4.0"
+typing-extensions = {version = ">=4.2.0", markers = "python_version < \"3.11\" and python_version >= \"3.7\""}
+typing-inspect = ">=0.8.0,<1.0"
+
+[package.extras]
+dev = ["marshmallow (>=3.18.0,<4.0)", "marshmallow-enum", "pre-commit (>=2.17,<3.0)", "pytest (>=5.4)", "pytest-mypy-plugins (>=1.2.0)", "sphinx", "typeguard (>=2.4.1,<4.0.0)"]
+docs = ["sphinx"]
+enum = ["marshmallow (>=3.18.0,<4.0)", "marshmallow-enum"]
+lint = ["pre-commit (>=2.17,<3.0)"]
+tests = ["pytest (>=5.4)", "pytest-mypy-plugins (>=1.2.0)"]
+union = ["typeguard (>=2.4.1,<4.0.0)"]
+
[[package]]
name = "mdurl"
version = "0.1.2"
@@ -1661,6 +1749,20 @@ toml = ["toml"]
tomli = ["tomli", "tomli-w"]
yaml = ["ruamel.yaml (>=0.17)"]
+[[package]]
+name = "python-dateutil"
+version = "2.8.2"
+description = "Extensions to the standard Python datetime module"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
+files = [
+ {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
+ {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
+]
+
+[package.dependencies]
+six = ">=1.5"
+
[[package]]
name = "python-dotenv"
version = "1.0.0"
@@ -2645,6 +2747,21 @@ torchhub = ["filelock", "huggingface-hub (>=0.15.1,<1.0)", "importlib-metadata",
video = ["av (==9.2.0)", "decord (==0.6.0)"]
vision = ["Pillow (<10.0.0)"]
+[[package]]
+name = "typeguard"
+version = "2.13.3"
+description = "Run-time type checker for Python"
+optional = false
+python-versions = ">=3.5.3"
+files = [
+ {file = "typeguard-2.13.3-py3-none-any.whl", hash = "sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1"},
+ {file = "typeguard-2.13.3.tar.gz", hash = "sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4"},
+]
+
+[package.extras]
+doc = ["sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"]
+test = ["mypy", "pytest", "typing-extensions"]
+
[[package]]
name = "types-requests"
version = "2.31.0.2"
@@ -2712,6 +2829,22 @@ brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
+[[package]]
+name = "validators"
+version = "0.20.0"
+description = "Python Data Validation for Humans™."
+optional = false
+python-versions = ">=3.4"
+files = [
+ {file = "validators-0.20.0.tar.gz", hash = "sha256:24148ce4e64100a2d5e267233e23e7afeb55316b47d30faae7eb6e7292bc226a"},
+]
+
+[package.dependencies]
+decorator = ">=3.4.0"
+
+[package.extras]
+test = ["flake8 (>=2.4.0)", "isort (>=4.2.2)", "pytest (>=2.2.3)"]
+
[[package]]
name = "websocket-client"
version = "1.3.3"
@@ -2910,7 +3043,22 @@ files = [
idna = ">=2.0"
multidict = ">=4.0"
+[[package]]
+name = "zipp"
+version = "3.17.0"
+description = "Backport of pathlib-compatible object wrapper for zip files"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"},
+ {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"},
+]
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
+
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.12"
-content-hash = "f9550d7d7bfafd0b554a40891f61d90c55838b225886a2030efed5736ccffb61"
+content-hash = "db28575fa8aa58ca6edd45f643a9801e4759979ecdbcdf5fb2b18901e2409569"
diff --git a/modules/financial_bot/pyproject.toml b/modules/financial_bot/pyproject.toml
index 121818a..5a95d49 100644
--- a/modules/financial_bot/pyproject.toml
+++ b/modules/financial_bot/pyproject.toml
@@ -21,6 +21,7 @@ fire = "^0.5.0"
comet-llm = "^1.1.0"
bitsandbytes = "^0.41.1"
torch = "2.0.1"
+beam-sdk = "0.14.6"
[tool.poetry.group.dev.dependencies]
black = "^23.7.0"
diff --git a/modules/financial_bot/requirements.txt b/modules/financial_bot/requirements.txt
new file mode 100644
index 0000000..4770ee9
--- /dev/null
+++ b/modules/financial_bot/requirements.txt
@@ -0,0 +1,99 @@
+accelerate==0.21.0 ; python_version >= "3.10" and python_version < "3.12"
+aiohttp==3.8.5 ; python_version >= "3.10" and python_version < "3.12"
+aiosignal==1.3.1 ; python_version >= "3.10" and python_version < "3.12"
+anyio==4.0.0 ; python_version >= "3.10" and python_version < "3.12"
+async-timeout==4.0.3 ; python_version >= "3.10" and python_version < "3.12"
+attrs==23.1.0 ; python_version >= "3.10" and python_version < "3.12"
+beam-sdk==0.14.6 ; python_version >= "3.10" and python_version < "3.12"
+bitsandbytes==0.41.1 ; python_version >= "3.10" and python_version < "3.12"
+certifi==2023.7.22 ; python_version >= "3.10" and python_version < "3.12"
+charset-normalizer==3.2.0 ; python_version >= "3.10" and python_version < "3.12"
+colorama==0.4.6 ; python_version >= "3.10" and python_version < "3.12" and platform_system == "Windows"
+comet-llm==1.3.0 ; python_version >= "3.10" and python_version < "3.12"
+comet-ml==3.33.10 ; python_version >= "3.10" and python_version < "3.12"
+configobj==5.0.8 ; python_version >= "3.10" and python_version < "3.12"
+croniter==1.4.1 ; python_version >= "3.10" and python_version < "3.12"
+dataclasses-json==0.5.14 ; python_version >= "3.10" and python_version < "3.12"
+decorator==5.1.1 ; python_version >= "3.10" and python_version < "3.12"
+dulwich==0.21.6 ; python_version >= "3.10" and python_version < "3.12"
+einops==0.6.1 ; python_version >= "3.10" and python_version < "3.12"
+everett[ini]==3.1.0 ; python_version >= "3.10" and python_version < "3.12"
+exceptiongroup==1.1.3 ; python_version >= "3.10" and python_version < "3.11"
+filelock==3.12.3 ; python_version >= "3.10" and python_version < "3.12"
+fire==0.5.0 ; python_version >= "3.10" and python_version < "3.12"
+flatten-dict==0.4.2 ; python_version >= "3.10" and python_version < "3.12"
+frozenlist==1.4.0 ; python_version >= "3.10" and python_version < "3.12"
+fsspec==2023.9.0 ; python_version >= "3.10" and python_version < "3.12"
+greenlet==2.0.2 ; python_version >= "3.10" and python_version < "3.12" and (platform_machine == "aarch64" or platform_machine == "ppc64le" or platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "win32" or platform_machine == "WIN32")
+grpcio-tools==1.58.0 ; python_version >= "3.10" and python_version < "3.12"
+grpcio==1.58.0 ; python_version >= "3.10" and python_version < "3.12"
+h11==0.14.0 ; python_version >= "3.10" and python_version < "3.12"
+h2==4.1.0 ; python_version >= "3.10" and python_version < "3.12"
+hpack==4.0.0 ; python_version >= "3.10" and python_version < "3.12"
+httpcore==0.17.3 ; python_version >= "3.10" and python_version < "3.12"
+httpx[http2]==0.24.1 ; python_version >= "3.10" and python_version < "3.12"
+huggingface-hub==0.16.4 ; python_version >= "3.10" and python_version < "3.12"
+hyperframe==6.0.1 ; python_version >= "3.10" and python_version < "3.12"
+idna==3.4 ; python_version >= "3.10" and python_version < "3.12"
+importlib-metadata==5.2.0 ; python_version >= "3.10" and python_version < "3.12"
+jinja2==3.1.2 ; python_version >= "3.10" and python_version < "3.12"
+jsonschema-specifications==2023.7.1 ; python_version >= "3.10" and python_version < "3.12"
+jsonschema==4.19.0 ; python_version >= "3.10" and python_version < "3.12"
+langchain==0.0.285 ; python_version >= "3.10" and python_version < "3.12"
+langsmith==0.0.35 ; python_version >= "3.10" and python_version < "3.12"
+markdown-it-py==3.0.0 ; python_version >= "3.10" and python_version < "3.12"
+markupsafe==2.1.3 ; python_version >= "3.10" and python_version < "3.12"
+marshmallow-dataclass==8.6.0 ; python_version >= "3.10" and python_version < "3.12"
+marshmallow==3.18.0 ; python_version >= "3.10" and python_version < "3.12"
+mdurl==0.1.2 ; python_version >= "3.10" and python_version < "3.12"
+mpmath==1.3.0 ; python_version >= "3.10" and python_version < "3.12"
+multidict==6.0.4 ; python_version >= "3.10" and python_version < "3.12"
+mypy-extensions==1.0.0 ; python_version >= "3.10" and python_version < "3.12"
+networkx==3.1 ; python_version >= "3.10" and python_version < "3.12"
+numexpr==2.8.5 ; python_version >= "3.10" and python_version < "3.12"
+numpy==1.25.2 ; python_version >= "3.10" and python_version < "3.12"
+packaging==23.1 ; python_version >= "3.10" and python_version < "3.12"
+peft==0.4.0 ; python_version >= "3.10" and python_version < "3.12"
+protobuf==4.24.3 ; python_version >= "3.10" and python_version < "3.12"
+psutil==5.9.5 ; python_version >= "3.10" and python_version < "3.12"
+pydantic==1.10.12 ; python_version >= "3.10" and python_version < "3.12"
+pygments==2.16.1 ; python_version >= "3.10" and python_version < "3.12"
+python-box==6.1.0 ; python_version >= "3.10" and python_version < "3.12"
+python-dateutil==2.8.2 ; python_version >= "3.10" and python_version < "3.12"
+python-dotenv==1.0.0 ; python_version >= "3.10" and python_version < "3.12"
+pyyaml==6.0.1 ; python_version >= "3.10" and python_version < "3.12"
+qdrant-client==1.1.1 ; python_version >= "3.10" and python_version < "3.12"
+referencing==0.30.2 ; python_version >= "3.10" and python_version < "3.12"
+regex==2023.8.8 ; python_version >= "3.10" and python_version < "3.12"
+requests-toolbelt==1.0.0 ; python_version >= "3.10" and python_version < "3.12"
+requests==2.31.0 ; python_version >= "3.10" and python_version < "3.12"
+rich==13.5.2 ; python_version >= "3.10" and python_version < "3.12"
+rpds-py==0.10.2 ; python_version >= "3.10" and python_version < "3.12"
+safetensors==0.3.3 ; python_version >= "3.10" and python_version < "3.12"
+scipy==1.11.2 ; python_version >= "3.10" and python_version < "3.12"
+semantic-version==2.10.0 ; python_version >= "3.10" and python_version < "3.12"
+sentry-sdk==1.30.0 ; python_version >= "3.10" and python_version < "3.12"
+setuptools==68.2.0 ; python_version >= "3.10" and python_version < "3.12"
+simplejson==3.19.1 ; python_version >= "3.10" and python_version < "3.12"
+six==1.16.0 ; python_version >= "3.10" and python_version < "3.12"
+sniffio==1.3.0 ; python_version >= "3.10" and python_version < "3.12"
+sqlalchemy==2.0.20 ; python_version >= "3.10" and python_version < "3.12"
+sympy==1.12 ; python_version >= "3.10" and python_version < "3.12"
+tenacity==8.2.3 ; python_version >= "3.10" and python_version < "3.12"
+termcolor==2.3.0 ; python_version >= "3.10" and python_version < "3.12"
+tokenizers==0.13.3 ; python_version >= "3.10" and python_version < "3.12"
+torch==2.0.1 ; python_version >= "3.10" and python_version < "3.12"
+tqdm==4.66.1 ; python_version >= "3.10" and python_version < "3.12"
+transformers==4.33.1 ; python_version >= "3.10" and python_version < "3.12"
+typeguard==2.13.3 ; python_version >= "3.10" and python_version < "3.12"
+types-requests==2.31.0.2 ; python_version >= "3.10" and python_version < "3.12"
+types-urllib3==1.26.25.14 ; python_version >= "3.10" and python_version < "3.12"
+typing-extensions==4.7.1 ; python_version >= "3.10" and python_version < "3.12"
+typing-inspect==0.9.0 ; python_version >= "3.10" and python_version < "3.12"
+urllib3==1.26.16 ; python_version >= "3.10" and python_version < "3.12"
+validators==0.20.0 ; python_version >= "3.10" and python_version < "3.12"
+websocket-client==1.3.3 ; python_version >= "3.10" and python_version < "3.12"
+wrapt==1.15.0 ; python_version >= "3.10" and python_version < "3.12"
+wurlitzer==3.0.3 ; python_version >= "3.10" and python_version < "3.12"
+yarl==1.9.2 ; python_version >= "3.10" and python_version < "3.12"
+zipp==3.17.0 ; python_version >= "3.10" and python_version < "3.12"
diff --git a/modules/financial_bot/tools/bot.py b/modules/financial_bot/tools/bot.py
new file mode 100644
index 0000000..0c91495
--- /dev/null
+++ b/modules/financial_bot/tools/bot.py
@@ -0,0 +1,96 @@
+import logging
+from pathlib import Path
+
+import fire
+from beam import App, Image, Runtime, Volume, VolumeType
+
+financial_bot = App(
+ name="financial_bot",
+ runtime=Runtime(
+ cpu=4,
+ memory="64Gi",
+ gpu="T4",
+ image=Image(python_version="python3.10", python_packages="requirements.txt"),
+ ),
+ volumes=[
+ Volume(
+ path="./model_cache", name="model_cache", volume_type=VolumeType.Persistent
+ ),
+ ],
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def load_models(
+ env_file_path: str = ".env",
+ logging_config_path: str = "logging.yaml",
+ model_cache_dir: str = "./model_cache",
+):
+ from financial_bot import initialize
+
+ # Be sure to initialize the environment variables before importing any other modules.
+ initialize(logging_config_path=logging_config_path, env_file_path=env_file_path)
+
+ from financial_bot import utils
+ from financial_bot.langchain_bot import FinancialBot
+
+ logger.info("#" * 100)
+ utils.log_available_gpu_memory()
+ utils.log_available_ram()
+ logger.info("#" * 100)
+
+ bot = FinancialBot(model_cache_dir=Path(model_cache_dir))
+
+ return bot
+
+
+@financial_bot.rest_api(keep_warm_seconds=300, loader=load_models)
+def run(**inputs):
+ from financial_bot import utils
+
+ logger.info("#" * 100)
+ utils.log_available_gpu_memory()
+ utils.log_available_ram()
+ logger.info("#" * 100)
+
+ # TODO: Check how memory is persisted between requests.
+ bot = inputs["context"]
+ input_payload = {
+ "about_me": inputs["about_me"],
+ "question": inputs["question"],
+ }
+ response = bot.answer(**input_payload)
+
+ return response
+
+
+@financial_bot.rest_api(keep_warm_seconds=300, loader=load_models)
+def run_dev(**inputs):
+ from financial_bot import utils
+
+ logger.info("#" * 100)
+ utils.log_available_gpu_memory()
+ utils.log_available_ram()
+ logger.info("#" * 100)
+
+ bot = inputs["context"]
+
+ input_payload = {
+ "about_me": "I'm a student and I have some money that I want to invest.",
+ "question": "Should I consider investing in stocks from the Tech Sector?",
+ }
+ response = bot.answer(**input_payload)
+ print(response)
+
+ next_question = "What about the Energy Sector?"
+ input_payload["question"] = next_question
+ response = bot.answer(**input_payload)
+ print(response)
+
+ return response
+
+
+if __name__ == "__main__":
+ fire.Fire(run)
diff --git a/modules/financial_bot/tools/run_chain.py b/modules/financial_bot/tools/run_chain.py
deleted file mode 100644
index 475db9a..0000000
--- a/modules/financial_bot/tools/run_chain.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import dotenv
-import fire
-
-from financial_bot.langchain_bot import FinancialBot
-
-dotenv.load_dotenv()
-
-
-def main():
- bot = FinancialBot()
- input_payload = {
- "about_me": "I'm a student and I have some money that I want to invest.",
- "question": "Should I consider investing in stocks from the Tech Sector?",
- }
- response = bot.answer(**input_payload)
- print(response)
-
- next_question = "What about the Energy Sector?"
- input_payload["question"] = next_question
- response = bot.answer(**input_payload)
- print(response)
-
-
-if __name__ == "__main__":
- fire.Fire(main)
diff --git a/modules/training_pipeline/.beamignore b/modules/training_pipeline/.beamignore
index 5d52966..7cb0b9f 100644
--- a/modules/training_pipeline/.beamignore
+++ b/modules/training_pipeline/.beamignore
@@ -1,9 +1,6 @@
-# Running artifacts
model_cache/*
results/*
output/*
logs/*
.ruff_cache/
-
-# Datasets
dataset/*
\ No newline at end of file
diff --git a/modules/training_pipeline/tools/inference_run.py b/modules/training_pipeline/tools/inference_run.py
index f10dae5..8a60fba 100644
--- a/modules/training_pipeline/tools/inference_run.py
+++ b/modules/training_pipeline/tools/inference_run.py
@@ -10,7 +10,7 @@
runtime=Runtime(
cpu=4,
memory="64Gi",
- gpu="A10G",
+ gpu="T4",
image=Image(python_version="python3.10", python_packages="requirements.txt"),
),
volumes=[