Skip to content

Commit

Permalink
ensure consistency with ewizapp code
Browse files Browse the repository at this point in the history
  • Loading branch information
spodgorny9 committed Nov 8, 2024
1 parent df5d12d commit 35b8c6d
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions elm/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ class EnergyWizardPostgres(EnergyWizardBase):
}
"""Optional mappings for weird azure names to tiktoken/openai names."""

DEFAULT_META_COLS = ('title', 'url', 'authors', 'year', 'category', 'id')
DEFAULT_META_COLS = ['title', 'url', 'nrel_id', 'id']
"""Default columns to retrieve for metadata"""

def __init__(self, db_host, db_port, db_name,
Expand All @@ -431,7 +431,8 @@ def __init__(self, db_host, db_port, db_name,
value is sqrt(n_lists).
meta_columns : list
List of metadata columns to retrieve from database. Default
query returns title and url.
query returns title, url, nrel_id, and id. nrel_id and id are
necessary to correctly format references.
cursor : psycopg2.extensions.cursor
PostgreSQL database cursor used to execute queries.
boto_client: botocore.client.BedrockRuntime
Expand All @@ -450,7 +451,7 @@ def __init__(self, db_host, db_port, db_name,
self.psycopg2 = try_import('psycopg2')

if meta_columns is None:
self.meta_columns = ['title', 'url', 'nrel_id', 'id']
self.meta_columns = self.DEFAULT_META_COLS
else:
self.meta_columns = meta_columns

Expand Down Expand Up @@ -482,14 +483,15 @@ def __init__(self, db_host, db_port, db_name,
access_key = os.getenv('AWS_ACCESS_KEY_ID')
secret_key = os.getenv('AWS_SECRET_ACCESS_KEY')
session_token = os.getenv('AWS_SESSION_TOKEN')
assert access_key is not None, "Must set AWS_ACCESS_KEY_ID!"
assert secret_key is not None, ("Must set AWS_SECRET_ACCESS_KEY!")
assert session_token is not None, "Must set AWS_SESSION_TOKEN!"
self._aws_client = boto3.client(service_name='bedrock-runtime',
region_name='us-west-2',
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
aws_session_token=session_token)
if access_key and secret_key and session_token:
self._aws_client = boto3.client(service_name='bedrock-runtime',
region_name='us-west-2',
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
aws_session_token=session_token)
else:
self._aws_client = boto3.client(service_name='bedrock-runtime',
region_name='us-west-2')
else:
self._aws_client = boto_client

Expand Down

0 comments on commit 35b8c6d

Please sign in to comment.