Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mypy random fixes #3893

Merged
merged 3 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added backend/alembic/__init__.py
Empty file.
2 changes: 2 additions & 0 deletions backend/ee/onyx/server/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def prepare_authorization_request(
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
session: str

if connector == DocumentSource.SLACK:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
Expand Down Expand Up @@ -554,6 +555,7 @@ def handle_google_drive_oauth_callback(
)

session_json = session_json_bytes.decode("utf-8")
session: GoogleDriveOAuth.OAuthSession
try:
session = GoogleDriveOAuth.parse_session(session_json)

Expand Down
6 changes: 6 additions & 0 deletions backend/onyx/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ async def create(
referral_source=referral_source,
request=request,
)
user: User

async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
verify_email_is_invited(user_create.email)
Expand Down Expand Up @@ -368,6 +370,8 @@ async def oauth_callback(
"refresh_token": refresh_token,
}

user: User

try:
# Attempt to get user by OAuth account
user = await self.get_by_oauth_account(oauth_name, account_id)
Expand Down Expand Up @@ -1043,6 +1047,8 @@ async def api_key_dep(
if AUTH_TYPE == AuthType.DISABLED:
return None

user: User | None = None

hashed_api_key = get_hashed_api_key_from_request(request)
if not hashed_api_key:
raise HTTPException(status_code=401, detail="Missing API key")
Expand Down
3 changes: 2 additions & 1 deletion backend/onyx/background/celery/tasks/indexing/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,11 +586,12 @@ def connector_indexing_proxy_task(

# if the job is done, clean up and break
if job.done():
exit_code: int | None
try:
if job.status == "error":
ignore_exitcode = False

exit_code: int | None = None
exit_code = None
if job.process:
exit_code = job.process.exitcode

Expand Down
1 change: 1 addition & 0 deletions backend/onyx/background/celery/tasks/indexing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ def try_creating_indexing_task(
if not acquired:
return None

redis_connector_index: RedisConnectorIndex
try:
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
Expand Down
1 change: 1 addition & 0 deletions backend/onyx/background/celery/tasks/monitoring/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,7 @@ def cloud_check_alembic() -> bool | None:
revision_counts: dict[str, int] = {}
out_of_date_tenants: dict[str, str | None] = {}
top_revision: str = ""
tenant_ids: list[str] | list[None] = []

try:
# map each tenant_id to its revision
Expand Down
2 changes: 2 additions & 0 deletions backend/onyx/background/celery/tasks/shared/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def document_by_cc_pair_cleanup_task(
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
except Exception as ex:
e: Exception | None = None
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
Expand Down Expand Up @@ -247,6 +248,7 @@ def cloud_beat_task_generator(
return None

last_lock_time = time.monotonic()
tenant_ids: list[str] | list[None] = []

try:
tenant_ids = get_all_tenant_ids()
Expand Down
1 change: 1 addition & 0 deletions backend/onyx/background/celery/tasks/vespa/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,7 @@ def vespa_metadata_sync_task(
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
except Exception as ex:
e: Exception | None = None
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
Expand Down
5 changes: 4 additions & 1 deletion backend/onyx/background/indexing/run_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def _run_indexing(
callback=callback,
)

tracer: OnyxTracer
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = OnyxTracer()
Expand All @@ -255,6 +256,8 @@ def _run_indexing(
document_count = 0
chunk_count = 0
run_end_dt = None
tracer_counter: int

for ind, (window_start, window_end) in enumerate(
get_time_windows_for_index_attempt(
last_successful_run=datetime.fromtimestamp(
Expand All @@ -265,6 +268,7 @@ def _run_indexing(
):
cc_pair_loop: ConnectorCredentialPair | None = None
index_attempt_loop: IndexAttempt | None = None
tracer_counter = 0

try:
window_start = max(
Expand All @@ -289,7 +293,6 @@ def _run_indexing(
tenant_id=tenant_id,
)

tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
for doc_batch in connector_runner.run():
Expand Down
4 changes: 3 additions & 1 deletion backend/onyx/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.natural_language_processing.utils import get_tokenizer
Expand Down Expand Up @@ -349,7 +350,8 @@ def stream_chat_message_objects(
new_msg_req.chunks_above = 0
new_msg_req.chunks_below = 0

llm = None
llm: LLM

try:
user_id = user.id if user is not None else None

Expand Down
3 changes: 2 additions & 1 deletion backend/onyx/connectors/airtable/airtable_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,11 +369,12 @@ def load_from_state(self) -> GenerateDocumentsOutput:
# Process records in parallel batches using ThreadPoolExecutor
PARALLEL_BATCH_SIZE = 8
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
record_documents: list[Document] = []

# Process records in batches
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
record_documents: list[Document] = []
record_documents = []

with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit batch tasks
Expand Down
6 changes: 3 additions & 3 deletions backend/onyx/connectors/asana/asana_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,18 @@ def _get_tasks_for_project(
project = self.project_api.get_project(project_gid, opts={})
if project["archived"]:
logger.info(f"Skipping archived project: {project['name']} ({project_gid})")
return []
yield from []
if not project["team"] or not project["team"]["gid"]:
logger.info(
f"Skipping project without a team: {project['name']} ({project_gid})"
)
return []
yield from []
if project["privacy_setting"] == "private":
if self.team_gid and project["team"]["gid"] != self.team_gid:
logger.info(
f"Skipping private project not in configured team: {project['name']} ({project_gid})"
)
return []
yield from []
else:
logger.info(
f"Processing private project in configured team: {project['name']} ({project_gid})"
Expand Down
1 change: 1 addition & 0 deletions backend/onyx/connectors/google_utils/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def _get_google_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDriveService | GoogleDocsService | AdminService | GmailService:
service: Resource
if isinstance(creds, ServiceAccountCredentials):
creds = creds.with_subject(user_email)
service = build(service_name, service_version, credentials=creds)
Expand Down
1 change: 1 addition & 0 deletions backend/onyx/connectors/salesforce/doc_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def _clean_salesforce_dict(data: dict | list) -> dict | list:
elif isinstance(data, list):
filtered_list = []
for item in data:
filtered_item: dict | list
if isinstance(item, (dict, list)):
filtered_item = _clean_salesforce_dict(item)
# Only add non-empty dictionaries or lists
Expand Down
2 changes: 2 additions & 0 deletions backend/onyx/db/document_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ def insert_document_set(
group_ids=document_set_creation_request.groups or [],
)

new_document_set_row: DocumentSetDBModel
ds_cc_pairs: list[DocumentSet__ConnectorCredentialPair]
try:
new_document_set_row = DocumentSetDBModel(
name=document_set_creation_request.name,
Expand Down
2 changes: 1 addition & 1 deletion backend/onyx/file_processing/extract_file_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def extract_file_text(
f"Failed to process with Unstructured: {str(unstructured_error)}. Falling back to normal processing."
)
# Fall through to normal processing

final_extension: str
if file_name or extension:
if extension is not None:
final_extension = extension
Expand Down
2 changes: 2 additions & 0 deletions backend/onyx/indexing/chunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ def _create_chunk(
large_chunk_id=None,
)

section_link_text: str

for section_idx, section in enumerate(document.sections):
section_text = clean_text(section.text)
section_link_text = section.link or ""
Expand Down
6 changes: 5 additions & 1 deletion backend/onyx/onyxbot/slack/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import threading
import time
from collections.abc import Callable
from contextvars import Token
from threading import Event
from types import FrameType
from typing import Any
Expand Down Expand Up @@ -250,6 +251,8 @@ def acquire_tenants(self) -> None:
"""
all_tenants = get_all_tenant_ids()

token: Token[str]

# 1) Try to acquire locks for new tenants
for tenant_id in all_tenants:
if (
Expand Down Expand Up @@ -771,6 +774,7 @@ def process_message(
client=client.web_client, channel_id=channel
)

token: Token[str] | None = None
# Set the current tenant ID at the beginning for all DB calls within this thread
if client.tenant_id:
logger.info(f"Setting tenant ID to {client.tenant_id}")
Expand Down Expand Up @@ -825,7 +829,7 @@ def process_message(
if notify_no_answer:
apologize_for_fail(details, client)
finally:
if client.tenant_id:
if token:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


Expand Down
2 changes: 1 addition & 1 deletion backend/onyx/onyxbot/slack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def read_slack_thread(
message_type = MessageType.USER
else:
self_slack_bot_id = get_onyx_bot_slack_bot_id(client)

blocks: Any
if reply.get("user") == self_slack_bot_id:
# OnyxBot response
message_type = MessageType.ASSISTANT
Expand Down
3 changes: 2 additions & 1 deletion backend/onyx/redis/redis_connector_doc_perm_sync.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
from datetime import datetime
from typing import Any
from typing import cast
from uuid import uuid4

Expand Down Expand Up @@ -95,7 +96,7 @@ def fenced(self) -> bool:
@property
def payload(self) -> RedisConnectorPermissionSyncPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
fence_bytes = cast(Any, self.redis.get(self.fence_key))
if fence_bytes is None:
return None

Expand Down
3 changes: 2 additions & 1 deletion backend/onyx/redis/redis_connector_ext_group_sync.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from typing import cast

import redis
Expand Down Expand Up @@ -82,7 +83,7 @@ def fenced(self) -> bool:
@property
def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
fence_bytes = cast(Any, self.redis.get(self.fence_key))
if fence_bytes is None:
return None

Expand Down
3 changes: 2 additions & 1 deletion backend/onyx/redis/redis_connector_index.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from typing import cast
from uuid import uuid4

Expand Down Expand Up @@ -89,7 +90,7 @@ def fenced(self) -> bool:
@property
def payload(self) -> RedisConnectorIndexPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
fence_bytes = cast(Any, self.redis.get(self.fence_key))
if fence_bytes is None:
return None

Expand Down
2 changes: 2 additions & 0 deletions backend/onyx/server/manage/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ def bulk_invite_users(

tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
new_invited_emails = []
email: str

try:
for email in emails:
email_info = validate_email(email)
Expand Down
1 change: 1 addition & 0 deletions backend/scripts/chat_feedback_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def process_all_chat_feedback(onyx_url: str, api_key: str | None) -> None:
r_sessions = get_chat_sessions(onyx_url, headers, user_id)
logger.info(f"user={user_id} num_sessions={len(r_sessions.sessions)}")
for session in r_sessions.sessions:
s: ChatSessionSnapshot
try:
s = get_session_history(onyx_url, headers, session.id)
except requests.exceptions.HTTPError:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.models import Document


@pytest.fixture
Expand Down Expand Up @@ -41,6 +42,10 @@ def test_confluence_connector_basic(

assert len(doc_batch) == 3

page_within_a_page_doc: Document | None = None
page_doc: Document | None = None
txt_doc: Document | None = None

for doc in doc_batch:
if doc.semantic_identifier == "DailyConnectorTestSpace Home":
page_doc = doc
Expand All @@ -49,6 +54,7 @@ def test_confluence_connector_basic(
elif doc.semantic_identifier == "Page Within A Page":
page_within_a_page_doc = doc

assert page_within_a_page_doc is not None
assert page_within_a_page_doc.semantic_identifier == "Page Within A Page"
assert page_within_a_page_doc.primary_owners
assert page_within_a_page_doc.primary_owners[0].email == "hagen@danswer.ai"
Expand All @@ -62,6 +68,7 @@ def test_confluence_connector_basic(
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/pages/200769540/Page+Within+A+Page"
)

assert page_doc is not None
assert page_doc.semantic_identifier == "DailyConnectorTestSpace Home"
assert page_doc.metadata["labels"] == ["testlabel"]
assert page_doc.primary_owners
Expand All @@ -75,6 +82,7 @@ def test_confluence_connector_basic(
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview"
)

assert txt_doc is not None
assert txt_doc.semantic_identifier == "small-file.txt"
assert len(txt_doc.sections) == 1
assert txt_doc.sections[0].text == "small"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def test_docs_retrieval(

for doc in retrieved_docs:
id = doc.id
retrieved_primary_owner_emails: set[str | None] = set()
retrieved_secondary_owner_emails: set[str | None] = set()
if doc.primary_owners:
retrieved_primary_owner_emails = set(
[owner.email for owner in doc.primary_owners]
Expand Down
1 change: 1 addition & 0 deletions backend/tests/integration/common_utils/managers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def set_status(
target_status: bool,
user_performing_action: DATestUser,
) -> DATestUser:
url_substring: str
if target_status is True:
url_substring = "activate"
elif target_status is False:
Expand Down
Loading
Loading