Skip to content

Commit

Permalink
Add the "U" series rules in ruff (pyupgrade) (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
gbaian10 authored Oct 7, 2024
1 parent 1dfce9b commit 48a4ca2
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 21 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ line-length = 100
target-version = "py310"

[tool.ruff.lint]
extend-select = ["I"]
extend-select = ["I", "U"]

[tool.pytest_env]
OPENAI_API_KEY = "sk-fake-openai-key"
9 changes: 4 additions & 5 deletions src/agent/llama_guard.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
from enum import Enum
from typing import List

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
from langchain_core.prompts import PromptTemplate
Expand All @@ -16,7 +15,7 @@ class SafetyAssessment(Enum):

class LlamaGuardOutput(BaseModel):
safety_assessment: SafetyAssessment = Field(description="The safety assessment of the content.")
unsafe_categories: List[str] = Field(
unsafe_categories: list[str] = Field(
description="If content is unsafe, the list of unsafe categories.", default=[]
)

Expand Down Expand Up @@ -86,22 +85,22 @@ def __init__(self):
)
self.prompt = PromptTemplate.from_template(llama_guard_instructions)

def _compile_prompt(self, role: str, messages: List[AnyMessage]) -> str:
def _compile_prompt(self, role: str, messages: list[AnyMessage]) -> str:
    """Render the Llama Guard prompt for *role* over the chat transcript.

    Only messages of type "ai" or "human" are kept; each becomes an
    "Agent: ..." or "User: ..." line, and entries are joined by blank lines
    before being substituted into ``self.prompt``.
    """
    labels = {"ai": "Agent", "human": "User"}
    rendered = []
    for message in messages:
        if message.type in ["ai", "human"]:
            rendered.append(f"{labels[message.type]}: {message.content}")
    history = "\n\n".join(rendered)
    return self.prompt.format(role=role, conversation_history=history)

def invoke(self, role: str, messages: List[AnyMessage]) -> LlamaGuardOutput:
def invoke(self, role: str, messages: list[AnyMessage]) -> LlamaGuardOutput:
    """Run a synchronous Llama Guard safety check on the conversation.

    Short-circuits to a SAFE assessment when no model is configured;
    otherwise compiles the prompt, queries the model, and parses its reply.
    """
    if self.model is None:
        return LlamaGuardOutput(safety_assessment=SafetyAssessment.SAFE)
    prompt_text = self._compile_prompt(role, messages)
    response = self.model.invoke([HumanMessage(content=prompt_text)])
    return parse_llama_guard_output(response.content)

async def ainvoke(self, role: str, messages: List[AnyMessage]) -> LlamaGuardOutput:
async def ainvoke(self, role: str, messages: list[AnyMessage]) -> LlamaGuardOutput:
if self.model is None:
return LlamaGuardOutput(safety_assessment=SafetyAssessment.SAFE)
compiled_prompt = self._compile_prompt(role, messages)
Expand Down
5 changes: 3 additions & 2 deletions src/client/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
from typing import Any, AsyncGenerator, Dict, Generator
from collections.abc import AsyncGenerator, Generator
from typing import Any

import httpx

Expand Down Expand Up @@ -204,7 +205,7 @@ async def astream(
yield parsed

async def acreate_feedback(
self, run_id: str, key: str, score: float, kwargs: Dict[str, Any] = {}
self, run_id: str, key: str, score: float, kwargs: dict[str, Any] = {}
):
"""
Create a feedback record for a run.
Expand Down
14 changes: 7 additions & 7 deletions src/schema/schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Literal, Union
from typing import Any, Literal

from langchain_core.messages import (
AIMessage,
Expand All @@ -12,10 +12,10 @@
from pydantic import BaseModel, Field


def convert_message_content_to_string(content: Union[str, List[Union[str, Dict]]]) -> str:
def convert_message_content_to_string(content: str | list[str | dict]) -> str:
if isinstance(content, str):
return content
text: List[str] = []
text: list[str] = []
for content_item in content:
if isinstance(content_item, str):
text.append(content_item)
Expand Down Expand Up @@ -56,7 +56,7 @@ class StreamInput(UserInput):
class AgentResponse(BaseModel):
"""Response from the agent when called via /invoke."""

message: Dict[str, Any] = Field(
message: dict[str, Any] = Field(
description="Final response from the agent, as a serialized LangChain message.",
examples=[
{
Expand All @@ -80,7 +80,7 @@ class ChatMessage(BaseModel):
description="Content of the message.",
examples=["Hello, world!"],
)
tool_calls: List[ToolCall] = Field(
tool_calls: list[ToolCall] = Field(
description="Tool calls in the message.",
default=[],
)
Expand All @@ -94,7 +94,7 @@ class ChatMessage(BaseModel):
default=None,
examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
)
original: Dict[str, Any] = Field(
original: dict[str, Any] = Field(
description="Original LangChain message in serialized form.",
default={},
)
Expand Down Expand Up @@ -164,7 +164,7 @@ class Feedback(BaseModel):
description="Feedback score.",
examples=[0.8],
)
kwargs: Dict[str, Any] = Field(
kwargs: dict[str, Any] = Field(
description="Additional feedback kwargs, passed to LangSmith.",
default={},
examples=[{"comment": "In-line human feedback"}],
Expand Down
9 changes: 5 additions & 4 deletions src/service/service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
import os
import warnings
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, List, Tuple, Union
from typing import Any
from uuid import uuid4

from fastapi import FastAPI, HTTPException, Request, Response
Expand Down Expand Up @@ -43,7 +44,7 @@ async def check_auth_header(request: Request, call_next):
return await call_next(request)


def _parse_input(user_input: UserInput) -> Tuple[Dict[str, Any], str]:
def _parse_input(user_input: UserInput) -> tuple[dict[str, Any], str]:
run_id = uuid4()
thread_id = user_input.thread_id or str(uuid4())
input_message = ChatMessage(type="human", content=user_input.message)
Expand All @@ -58,8 +59,8 @@ def _parse_input(user_input: UserInput) -> Tuple[Dict[str, Any], str]:


def _remove_tool_calls(
content: Union[str, List[Union[str, Dict]]],
) -> Union[str, List[Union[str, Dict]]]:
content: str | list[str | dict],
) -> str | list[str | dict]:
"""Remove tool calls from content."""
if isinstance(content, str):
return content
Expand Down
4 changes: 2 additions & 2 deletions src/streamlit_app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import os
from typing import AsyncGenerator, List
from collections.abc import AsyncGenerator

import streamlit as st
from streamlit.runtime.scriptrunner import get_script_run_ctx
Expand Down Expand Up @@ -95,7 +95,7 @@ def architecture_dialog():
# Draw existing messages
if "messages" not in st.session_state:
st.session_state.messages = []
messages: List[ChatMessage] = st.session_state.messages
messages: list[ChatMessage] = st.session_state.messages

if len(messages) == 0:
WELCOME = "Hello! I'm an AI-powered research assistant with web search and a calculator. I may take a few seconds to boot up when you send your first message. Ask me anything!"
Expand Down

0 comments on commit 48a4ca2

Please sign in to comment.