Skip to content

Commit

Permalink
add google provider
Browse files Browse the repository at this point in the history
  • Loading branch information
xvnpw committed Dec 20, 2024
1 parent c7304db commit ab3b1c2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions ai_security_analyzer/file_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ai_security_analyzer.documents import DocumentFilter, DocumentProcessor
from ai_security_analyzer.llms import LLMProvider
from ai_security_analyzer.markdowns import MarkdownMermaidValidator
from ai_security_analyzer.utils import get_total_tokens
from ai_security_analyzer.utils import get_response_content, get_total_tokens

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -99,7 +99,7 @@ def _create_initial_draft(self, state: AgentState, llm: Any, use_system_message:
response = llm.invoke(messages)
document_tokens = get_total_tokens(response)
return {
"sec_repo_doc": response.content,
"sec_repo_doc": get_response_content(response),
"document_tokens": document_tokens,
}
except Exception as e:
Expand All @@ -117,7 +117,7 @@ def _refine_draft(self, state: AgentState, llm: Any, use_system_message: bool):
response = llm.invoke(messages)
document_tokens = state.get("document_tokens", 0) + get_total_tokens(response)
return {
"sec_repo_doc": response.content,
"sec_repo_doc": get_response_content(response),
"document_tokens": document_tokens,
"current_refinement_count": state.get("current_refinement_count", 0) + 1,
}
Expand Down Expand Up @@ -177,7 +177,7 @@ def _editor(self, state: AgentState, llm: BaseChatModel, use_system_message: boo
response = llm.invoke(messages)
document_tokens = state.get("document_tokens", 0) + get_total_tokens(response)
return {
"sec_repo_doc": response.content,
"sec_repo_doc": get_response_content(response),
"sec_repo_doc_validation_error": "",
"editor_turns_count": state.get("editor_turns_count", 0) + 1,
"document_tokens": document_tokens,
Expand Down
8 changes: 4 additions & 4 deletions ai_security_analyzer/full_dir_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ai_security_analyzer.llms import LLMProvider
from ai_security_analyzer.loaders import RepoDirectoryLoader
from ai_security_analyzer.markdowns import MarkdownMermaidValidator
from ai_security_analyzer.utils import get_total_tokens
from ai_security_analyzer.utils import get_response_content, get_total_tokens

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -144,7 +144,7 @@ def _create_initial_draft( # type: ignore[no-untyped-def]
response = llm.invoke(messages)
document_tokens = get_total_tokens(response)
return {
"sec_repo_doc": response.content,
"sec_repo_doc": get_response_content(response),
"processed_docs_count": len(first_batch),
"document_tokens": document_tokens,
}
Expand Down Expand Up @@ -180,7 +180,7 @@ def _update_draft_with_new_docs( # type: ignore[no-untyped-def]
response = llm.invoke(messages)
document_tokens = state.get("document_tokens", 0) + get_total_tokens(response)
return {
"sec_repo_doc": response.content,
"sec_repo_doc": get_response_content(response),
"processed_docs_count": processed_count + len(next_batch),
"document_tokens": document_tokens,
}
Expand Down Expand Up @@ -239,7 +239,7 @@ def _editor(self, state: AgentState, llm: BaseChatModel, use_system_message: boo
response = llm.invoke(messages)
document_tokens = state.get("document_tokens", 0) + get_total_tokens(response)
return {
"sec_repo_doc": response.content,
"sec_repo_doc": get_response_content(response),
"sec_repo_doc_validation_error": "",
"editor_turns_count": state.get("editor_turns_count", 0) + 1,
"document_tokens": document_tokens,
Expand Down

0 comments on commit ab3b1c2

Please sign in to comment.