Add ruff format check; Fix existing formatting
yoomlam committed Apr 4, 2024
1 parent 159b643 commit 1b30652
Showing 10 changed files with 219 additions and 94 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci-linter-ruff.yml
@@ -6,3 +6,6 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
- uses: chartboost/ruff-action@v1
with:
args: 'format --check --diff'
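The second chartboost/ruff-action step reuses the same action but passes formatter arguments, so CI now fails whenever a file is not ruff-formatted. Assuming the standard ruff CLI, the equivalent local check is ruff format --check --diff, which prints the would-be changes without rewriting any files.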
2 changes: 2 additions & 0 deletions 01-resource-referral/debugging.py
@@ -6,6 +6,7 @@
from langchain_core.runnables import RunnableLambda
from langchain_core.prompt_values import PromptValue


def stacktrace():
traceback.print_stack()

@@ -20,6 +21,7 @@ def debug_here(local_vars):

def debug_runnable(prefix: str):
"""Useful to see output/input between Runnables in a LangChain"""

def debug_chainlink(x):
print(f"{prefix if prefix else 'DEBUG_CHAINLINK'}")
if isinstance(x, PromptValue):
12 changes: 5 additions & 7 deletions 01-resource-referral/langgraph-workflow.py
@@ -4,6 +4,7 @@
from dotenv import main
import operator
from typing import TypedDict, Annotated, Sequence

# import os
import graphviz # type: ignore

@@ -32,7 +33,6 @@ class WorkflowState(TypedDict):


class MyWorkflow:

def __init__(self, model_name: str, tools: list):
main.load_dotenv()
self.graph = self._init_graph()
@@ -42,7 +42,6 @@ def __init__(self, model_name: str, tools: list):
# tool_executor will be used to call the tool specified by the LLM in the llm_chain
self.tool_executor = ToolExecutor(tools)


def _init_graph(self):
graph = StateGraph(WorkflowState)
graph.add_node("decision_node", self.check_for_final_answer)
@@ -92,7 +91,7 @@ def _draw_graph(self, graph: StateGraph):

# Determines next node to call
def decide_next_node(self, state):
print("\nNEXT_EDGE") # , json.dumps(state, indent=2))
print("\nNEXT_EDGE") # , json.dumps(state, indent=2))

if state["final_answer"]:
return END
@@ -120,7 +119,7 @@ def check_for_final_answer(self, state):
print("\nHAS_FINAL_ANSWER node: Waiting for more responses")

def _got_responses_from_all_tools(self, state):
expected_tools = [ "spreadsheet", "211_api" ]
expected_tools = ["spreadsheet", "211_api"]
return all(key in state["tool_responses"] for key in expected_tools)

def run_llms(self, state):
@@ -148,7 +147,7 @@ def llm_spreadsheet_query(self, state):
llm_response = self.invoke_user_message("query_spreadsheet", user_message)
return {"messages": [llm_response]}

def invoke_user_message(self, tool, user_message):
def invoke_user_message(self, tool, user_message):
return self.llm_chain[tool].invoke(user_message)

def call_211_tool(self, state):
@@ -235,5 +234,4 @@ def merge_results(self, state):
final_state = runnable_graph.invoke(inputs)
# print("\nFINAL_STATE", type(final_state), final_state)
print("\nFINAL_ANSWER")
print(final_state['final_answer'])

print(final_state["final_answer"])
83 changes: 61 additions & 22 deletions 01-resource-referral/my_tools.py
@@ -11,15 +11,18 @@
from debugging import debug_runnable

TWO_ONE_ONE_BASE_SEARCH_ENDPOINT = "https://api.211.org/search/v1/api"


# TODO: adjust parameters and get LLM to set correct parameters
@tool
def call_211_api(city: str, service_type:str | list[str]) -> str:
def call_211_api(city: str, service_type: str | list[str]) -> str:
"""Calls National 211 API for the given city and service type, such as 'Consumer Services'"""
print(f"211 args: city={city}; service_type={service_type}")

return directly_call_211_api(city, service_type)

def directly_call_211_api(city:str, keyword:str | list[str]) -> str:

def directly_call_211_api(city: str, keyword: str | list[str]) -> str:
if isinstance(keyword, str):
return get_services_from_211(city, keyword)
if isinstance(keyword, list):
@@ -32,19 +35,19 @@ def directly_call_211_api(city:str, keyword:str | list[str]) -> str:
raise ValueError(f"Invalid keyword type: {type(keyword)}")


def get_services_from_211(city:str, keyword:str | list[str]):
def get_services_from_211(city: str, keyword: str | list[str]):
location_endpoint = f"{TWO_ONE_ONE_BASE_SEARCH_ENDPOINT}/Search/Keyword?Keyword={keyword}&Location={city}&Top=10&OrderBy=Relevance&SearchMode=Any&IncludeStateNationalRecords=true&ReturnTaxonomyTermsIfNoResults=false"

TWO_ONE_ONE_API_KEY = os.environ.get('TWO_ONE_ONE_API_KEY')
TWO_ONE_ONE_API_KEY = os.environ.get("TWO_ONE_ONE_API_KEY")

headers = {
'Accept': 'application/json',
'Api-Key': TWO_ONE_ONE_API_KEY,
"Accept": "application/json",
"Api-Key": TWO_ONE_ONE_API_KEY,
}

location_search = requests.get(location_endpoint, headers=headers)
# From Search: /api/Filters/ServiceAreas?StateProvince=MI, returns []
try:
try:
first_result = location_search.json()["results"][0]["document"]
# difficult to find param {location_id}, location_id returns dataowner
location_id = first_result["idLocation"]
@@ -60,12 +63,16 @@ def get_services_from_211(city:str, keyword:str | list[str]):
print("Failed to get services at location")
return "[]"


# Check for csv file
csv_file = "nyc_referral_csv.csv"
if not os.path.exists(csv_file):
print(f"Optionally download {csv_file} from google drive: https://drive.google.com/file/d/1YHgJvZCDF5VtTO-AQ4-I3_1jGzrcOHjY/view?usp=sharing")
print(
f"Optionally download {csv_file} from google drive: https://drive.google.com/file/d/1YHgJvZCDF5VtTO-AQ4-I3_1jGzrcOHjY/view?usp=sharing"
)
input(f"Press Enter to continue without {csv_file}...")


@tool
def query_spreadsheet(city: str, service_type: str | list[str]) -> str:
"""Search spreadsheet for support resources given the city and service type, such as 'Food Assistance'."""
@@ -76,20 +83,24 @@ def query_spreadsheet(city: str, service_type: str | list[str]) -> str:

# base implementation
df = pandas.read_csv(csv_file)
separated_locations = city.split(',')
separated_locations = city.split(",")
city_to_search = separated_locations[0]
query = service_type if isinstance(service_type, str) else '|'.join(service_type)
query = service_type if isinstance(service_type, str) else "|".join(service_type)
print(query)

results = df.query(f'needs.str.contains("{query}", case=False) & counties_served.str.contains("{city_to_search}", case=False)', engine='python')
results = df.query(
f'needs.str.contains("{query}", case=False) & counties_served.str.contains("{city_to_search}", case=False)',
engine="python",
)
if results.to_numpy().size == 0:
return "[]"
csv_with_header_to_json = results.replace(np.nan, None).to_dict('records')

csv_with_header_to_json = results.replace(np.nan, None).to_dict("records")
dict_json = json.dumps(csv_with_header_to_json, indent=2)

return dict_json


@tool
def merge_json_results(user_query, result_211_api, result_spreadsheet):
"""Merge JSON results from 211 API and spreadsheet"""
@@ -104,17 +115,32 @@ def merge_json_results(user_query, result_211_api, result_spreadsheet):

relevance_agent = create_relevance_agent()
# A long list of resources confuses the LLM, so sample only 40
sampled_resources : list[dict] = random.sample(list(deduplicated_dict.values()), min(40, len(deduplicated_dict)))
sampled_resources: list[dict] = random.sample(
list(deduplicated_dict.values()), min(40, len(deduplicated_dict))
)
formatted_resources = _format_for_prompt(sampled_resources)
prioritized_list = relevance_agent.invoke({"resources": formatted_resources, "user_query": user_query})
prioritized_list = relevance_agent.invoke(
{"resources": formatted_resources, "user_query": user_query}
)
return prioritized_list


ALTERNATIVE_KEYS = [
# from spreadsheet
"id", "alternate_name", "url", "website", "email", "tax_id",
"id",
"alternate_name",
"url",
"website",
"email",
"tax_id",
# from 211
"idService", "idOrganization", "name", "alternateName"
]
"idService",
"idOrganization",
"name",
"alternateName",
]


def _merge_and_deduplicate(dict_listA, dict_listB):
"""Merge 2 list of objects, removing duplicates based on object's 'name'.
If 'name' is not present, use one of ALTERNATIVE_KEYS to deduplicate."""
@@ -130,9 +156,12 @@ def _merge_and_deduplicate(dict_listA, dict_listB):
deduplicated_dict[obj["name"]] = obj
continue

deduplicated_dict[obj["name"]] = _merge_objects(deduplicated_dict[obj["name"]], obj)
deduplicated_dict[obj["name"]] = _merge_objects(
deduplicated_dict[obj["name"]], obj
)
return deduplicated_dict


def _merge_objects(objA: dict, objB: dict):
"""Merge 2 objects, concatenating values if same key are in both objects"""
merged_obj = {}
@@ -141,7 +170,10 @@ def _merge_objects(objA: dict, objB: dict):
merged_obj[key] = ";; ".join(set(values_list))
return merged_obj


APPROVED_RESOURCE_NAMES = ["Alpena CAO"]


def _filter_approved(deduplicated_dict):
"""Filter collection to only include approved resource names"""
for key in list(deduplicated_dict):
@@ -150,10 +182,16 @@ def _filter_approved(deduplicated_dict):
del deduplicated_dict[key]
return deduplicated_dict


def _format_for_prompt(resources: dict):
"""Format resources for prompt"""
return "\n".join([f"- {resource['name']} ({resource['phone']}): provides {resource.get('needs')} for counties {resource.get('counties_served')}. {resource.get('description', '')}"
for resource in resources])
return "\n".join(
[
f"- {resource['name']} ({resource['phone']}): provides {resource.get('needs')} for counties {resource.get('counties_served')}. {resource.get('description', '')}"
for resource in resources
]
)


def create_relevance_agent():
return (
@@ -162,6 +200,7 @@ def create_relevance_agent():
| create_llm(model_name="openhermes", settings={"temperature": 0, "top_p": 0.8})
)


def _agent_prompt_template():
template = """You are a helpful automated agent that filters and prioritizes benefits services. \
Downselect to less than 10 services total and prioritize the following list of services based on the user's query. \
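As context for the reformatted helpers above, here is a minimal, self-contained sketch of the merge-and-deduplicate behavior their docstrings describe: merging two lists of resource dicts by "name" and joining clashing values with ";; ". It is illustrative only; the sample data is made up and the repository's own implementation sits in the collapsed portions of the diff.

# Illustrative sketch only; mirrors the docstrings of _merge_and_deduplicate and
# _merge_objects above, not the repository's exact implementation.
def merge_objects_sketch(obj_a: dict, obj_b: dict) -> dict:
    merged = {}
    for key in obj_a.keys() | obj_b.keys():
        values = [str(o[key]) for o in (obj_a, obj_b) if o.get(key) is not None]
        # Deduplicate values that appear in both objects before joining.
        merged[key] = ";; ".join(dict.fromkeys(values))
    return merged


def merge_and_deduplicate_sketch(list_a: list[dict], list_b: list[dict]) -> dict:
    deduplicated = {}
    for obj in list_a + list_b:
        # Assumes every object has a "name"; the real code falls back to ALTERNATIVE_KEYS.
        name = obj["name"]
        deduplicated[name] = merge_objects_sketch(deduplicated.get(name, {}), obj)
    return deduplicated


print(merge_and_deduplicate_sketch(
    [{"name": "Alpena CAO", "phone": "555-0100"}],
    [{"name": "Alpena CAO", "phone": "555-0100", "needs": "Food Assistance"}],
))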
7 changes: 4 additions & 3 deletions 02-household-queries/api_helpers.py
@@ -7,9 +7,10 @@
# Fetches data from Guru, currently not used as we're pulling data from static json files
# Eventually we would like to pull the latest updated data from the Guru API
GURU_ENDPOINT = "https://api.getguru.com/api/v1/"


def get_guru_data():
url = f"{GURU_ENDPOINT}cards/3fbff9c4-56a8-4561-a7d1-09727f1b4703"
headers = {
'Authorization': os.environ.get('GURU_TOKEN')}
headers = {"Authorization": os.environ.get("GURU_TOKEN")}
response = requests.request("GET", url, headers=headers)
return response.json()
return response.json()