Skip to content

Commit

Permalink
Merge pull request #3893 from onyx-dot-app/mypy_random
Browse files: browse the repository at this point in the history
Mypy random fixes
  • Loading branch information
rkuo-danswer authored Feb 5, 2025
2 parents c0271a9 + ec0c655 commit 5854b39
Show file tree
Hide file tree
Showing 28 changed files with 61 additions and 13 deletions.
Empty file added backend/alembic/__init__.py
Empty file.
2 changes: 2 additions & 0 deletions backend/ee/onyx/server/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def prepare_authorization_request(
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
session: str

if connector == DocumentSource.SLACK:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
Expand Down Expand Up @@ -554,6 +555,7 @@ def handle_google_drive_oauth_callback(
)

session_json = session_json_bytes.decode("utf-8")
session: GoogleDriveOAuth.OAuthSession
try:
session = GoogleDriveOAuth.parse_session(session_json)

Expand Down
6 changes: 6 additions & 0 deletions backend/onyx/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ async def create(
referral_source=referral_source,
request=request,
)
user: User

async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
verify_email_is_invited(user_create.email)
Expand Down Expand Up @@ -368,6 +370,8 @@ async def oauth_callback(
"refresh_token": refresh_token,
}

user: User

try:
# Attempt to get user by OAuth account
user = await self.get_by_oauth_account(oauth_name, account_id)
Expand Down Expand Up @@ -1043,6 +1047,8 @@ async def api_key_dep(
if AUTH_TYPE == AuthType.DISABLED:
return None

user: User | None = None

hashed_api_key = get_hashed_api_key_from_request(request)
if not hashed_api_key:
raise HTTPException(status_code=401, detail="Missing API key")
Expand Down
3 changes: 2 additions & 1 deletion backend/onyx/background/celery/tasks/indexing/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,11 +586,12 @@ def connector_indexing_proxy_task(

# if the job is done, clean up and break
if job.done():
exit_code: int | None
try:
if job.status == "error":
ignore_exitcode = False

exit_code: int | None = None
exit_code = None
if job.process:
exit_code = job.process.exitcode

Expand Down
1 change: 1 addition & 0 deletions backend/onyx/background/celery/tasks/indexing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ def try_creating_indexing_task(
if not acquired:
return None

redis_connector_index: RedisConnectorIndex
try:
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
Expand Down
1 change: 1 addition & 0 deletions backend/onyx/background/celery/tasks/monitoring/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,7 @@ def cloud_check_alembic() -> bool | None:
revision_counts: dict[str, int] = {}
out_of_date_tenants: dict[str, str | None] = {}
top_revision: str = ""
tenant_ids: list[str] | list[None] = []

try:
# map each tenant_id to its revision
Expand Down
2 changes: 2 additions & 0 deletions backend/onyx/background/celery/tasks/shared/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def document_by_cc_pair_cleanup_task(
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
except Exception as ex:
e: Exception | None = None
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
Expand Down Expand Up @@ -247,6 +248,7 @@ def cloud_beat_task_generator(
return None

last_lock_time = time.monotonic()
tenant_ids: list[str] | list[None] = []

try:
tenant_ids = get_all_tenant_ids()
Expand Down
1 change: 1 addition & 0 deletions backend/onyx/background/celery/tasks/vespa/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,7 @@ def vespa_metadata_sync_task(
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
except Exception as ex:
e: Exception | None = None
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
Expand Down
5 changes: 4 additions & 1 deletion backend/onyx/background/indexing/run_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def _run_indexing(
callback=callback,
)

tracer: OnyxTracer
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = OnyxTracer()
Expand All @@ -255,6 +256,8 @@ def _run_indexing(
document_count = 0
chunk_count = 0
run_end_dt = None
tracer_counter: int

for ind, (window_start, window_end) in enumerate(
get_time_windows_for_index_attempt(
last_successful_run=datetime.fromtimestamp(
Expand All @@ -265,6 +268,7 @@ def _run_indexing(
):
cc_pair_loop: ConnectorCredentialPair | None = None
index_attempt_loop: IndexAttempt | None = None
tracer_counter = 0

try:
window_start = max(
Expand All @@ -289,7 +293,6 @@ def _run_indexing(
tenant_id=tenant_id,
)

tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
for doc_batch in connector_runner.run():
Expand Down
4 changes: 3 additions & 1 deletion backend/onyx/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.natural_language_processing.utils import get_tokenizer
Expand Down Expand Up @@ -349,7 +350,8 @@ def stream_chat_message_objects(
new_msg_req.chunks_above = 0
new_msg_req.chunks_below = 0

llm = None
llm: LLM

try:
user_id = user.id if user is not None else None

Expand Down
3 changes: 2 additions & 1 deletion backend/onyx/connectors/airtable/airtable_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,11 +369,12 @@ def load_from_state(self) -> GenerateDocumentsOutput:
# Process records in parallel batches using ThreadPoolExecutor
PARALLEL_BATCH_SIZE = 8
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
record_documents: list[Document] = []

# Process records in batches
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
record_documents: list[Document] = []
record_documents = []

with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit batch tasks
Expand Down
6 changes: 3 additions & 3 deletions backend/onyx/connectors/asana/asana_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,18 @@ def _get_tasks_for_project(
project = self.project_api.get_project(project_gid, opts={})
if project["archived"]:
logger.info(f"Skipping archived project: {project['name']} ({project_gid})")
return []
yield from []
if not project["team"] or not project["team"]["gid"]:
logger.info(
f"Skipping project without a team: {project['name']} ({project_gid})"
)
return []
yield from []
if project["privacy_setting"] == "private":
if self.team_gid and project["team"]["gid"] != self.team_gid:
logger.info(
f"Skipping private project not in configured team: {project['name']} ({project_gid})"
)
return []
yield from []
else:
logger.info(
f"Processing private project in configured team: {project['name']} ({project_gid})"
Expand Down
1 change: 1 addition & 0 deletions backend/onyx/connectors/google_utils/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def _get_google_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDriveService | GoogleDocsService | AdminService | GmailService:
service: Resource
if isinstance(creds, ServiceAccountCredentials):
creds = creds.with_subject(user_email)
service = build(service_name, service_version, credentials=creds)
Expand Down
1 change: 1 addition & 0 deletions backend/onyx/connectors/salesforce/doc_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def _clean_salesforce_dict(data: dict | list) -> dict | list:
elif isinstance(data, list):
filtered_list = []
for item in data:
filtered_item: dict | list
if isinstance(item, (dict, list)):
filtered_item = _clean_salesforce_dict(item)
# Only add non-empty dictionaries or lists
Expand Down
2 changes: 2 additions & 0 deletions backend/onyx/db/document_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ def insert_document_set(
group_ids=document_set_creation_request.groups or [],
)

new_document_set_row: DocumentSetDBModel
ds_cc_pairs: list[DocumentSet__ConnectorCredentialPair]
try:
new_document_set_row = DocumentSetDBModel(
name=document_set_creation_request.name,
Expand Down
2 changes: 1 addition & 1 deletion backend/onyx/file_processing/extract_file_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def extract_file_text(
f"Failed to process with Unstructured: {str(unstructured_error)}. Falling back to normal processing."
)
# Fall through to normal processing

final_extension: str
if file_name or extension:
if extension is not None:
final_extension = extension
Expand Down
2 changes: 2 additions & 0 deletions backend/onyx/indexing/chunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ def _create_chunk(
large_chunk_id=None,
)

section_link_text: str

for section_idx, section in enumerate(document.sections):
section_text = clean_text(section.text)
section_link_text = section.link or ""
Expand Down
6 changes: 5 additions & 1 deletion backend/onyx/onyxbot/slack/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import threading
import time
from collections.abc import Callable
from contextvars import Token
from threading import Event
from types import FrameType
from typing import Any
Expand Down Expand Up @@ -250,6 +251,8 @@ def acquire_tenants(self) -> None:
"""
all_tenants = get_all_tenant_ids()

token: Token[str]

# 1) Try to acquire locks for new tenants
for tenant_id in all_tenants:
if (
Expand Down Expand Up @@ -771,6 +774,7 @@ def process_message(
client=client.web_client, channel_id=channel
)

token: Token[str] | None = None
# Set the current tenant ID at the beginning for all DB calls within this thread
if client.tenant_id:
logger.info(f"Setting tenant ID to {client.tenant_id}")
Expand Down Expand Up @@ -825,7 +829,7 @@ def process_message(
if notify_no_answer:
apologize_for_fail(details, client)
finally:
if client.tenant_id:
if token:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


Expand Down
2 changes: 1 addition & 1 deletion backend/onyx/onyxbot/slack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def read_slack_thread(
message_type = MessageType.USER
else:
self_slack_bot_id = get_onyx_bot_slack_bot_id(client)

blocks: Any
if reply.get("user") == self_slack_bot_id:
# OnyxBot response
message_type = MessageType.ASSISTANT
Expand Down
3 changes: 2 additions & 1 deletion backend/onyx/redis/redis_connector_doc_perm_sync.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
from datetime import datetime
from typing import Any
from typing import cast
from uuid import uuid4

Expand Down Expand Up @@ -96,7 +97,7 @@ def fenced(self) -> bool:
@property
def payload(self) -> RedisConnectorPermissionSyncPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
fence_bytes = cast(Any, self.redis.get(self.fence_key))
if fence_bytes is None:
return None

Expand Down
3 changes: 2 additions & 1 deletion backend/onyx/redis/redis_connector_ext_group_sync.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from typing import cast

import redis
Expand Down Expand Up @@ -82,7 +83,7 @@ def fenced(self) -> bool:
@property
def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
fence_bytes = cast(Any, self.redis.get(self.fence_key))
if fence_bytes is None:
return None

Expand Down
3 changes: 2 additions & 1 deletion backend/onyx/redis/redis_connector_index.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from typing import cast
from uuid import uuid4

Expand Down Expand Up @@ -91,7 +92,7 @@ def fenced(self) -> bool:
@property
def payload(self) -> RedisConnectorIndexPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
fence_bytes = cast(Any, self.redis.get(self.fence_key))
if fence_bytes is None:
return None

Expand Down
2 changes: 2 additions & 0 deletions backend/onyx/server/manage/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ def bulk_invite_users(

tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
new_invited_emails = []
email: str

try:
for email in emails:
email_info = validate_email(email)
Expand Down
1 change: 1 addition & 0 deletions backend/scripts/chat_feedback_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def process_all_chat_feedback(onyx_url: str, api_key: str | None) -> None:
r_sessions = get_chat_sessions(onyx_url, headers, user_id)
logger.info(f"user={user_id} num_sessions={len(r_sessions.sessions)}")
for session in r_sessions.sessions:
s: ChatSessionSnapshot
try:
s = get_session_history(onyx_url, headers, session.id)
except requests.exceptions.HTTPError:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.models import Document


@pytest.fixture
Expand Down Expand Up @@ -41,6 +42,10 @@ def test_confluence_connector_basic(

assert len(doc_batch) == 3

page_within_a_page_doc: Document | None = None
page_doc: Document | None = None
txt_doc: Document | None = None

for doc in doc_batch:
if doc.semantic_identifier == "DailyConnectorTestSpace Home":
page_doc = doc
Expand All @@ -49,6 +54,7 @@ def test_confluence_connector_basic(
elif doc.semantic_identifier == "Page Within A Page":
page_within_a_page_doc = doc

assert page_within_a_page_doc is not None
assert page_within_a_page_doc.semantic_identifier == "Page Within A Page"
assert page_within_a_page_doc.primary_owners
assert page_within_a_page_doc.primary_owners[0].email == "hagen@danswer.ai"
Expand All @@ -62,6 +68,7 @@ def test_confluence_connector_basic(
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/pages/200769540/Page+Within+A+Page"
)

assert page_doc is not None
assert page_doc.semantic_identifier == "DailyConnectorTestSpace Home"
assert page_doc.metadata["labels"] == ["testlabel"]
assert page_doc.primary_owners
Expand All @@ -75,6 +82,7 @@ def test_confluence_connector_basic(
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview"
)

assert txt_doc is not None
assert txt_doc.semantic_identifier == "small-file.txt"
assert len(txt_doc.sections) == 1
assert txt_doc.sections[0].text == "small"
Expand Down
2 changes: 2 additions & 0 deletions backend/tests/daily/connectors/gmail/test_gmail_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def test_docs_retrieval(

for doc in retrieved_docs:
id = doc.id
retrieved_primary_owner_emails: set[str | None] = set()
retrieved_secondary_owner_emails: set[str | None] = set()
if doc.primary_owners:
retrieved_primary_owner_emails = set(
[owner.email for owner in doc.primary_owners]
Expand Down
1 change: 1 addition & 0 deletions backend/tests/integration/common_utils/managers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def set_status(
target_status: bool,
user_performing_action: DATestUser,
) -> DATestUser:
url_substring: str
if target_status is True:
url_substring = "activate"
elif target_status is False:
Expand Down
Loading

0 comments on commit 5854b39

Please sign in to comment.