Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added filter to exclude attachments with unsupported file extensions #3530

Merged
merged 2 commits into from
Dec 20, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 68 additions & 28 deletions backend/onyx/connectors/confluence/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,23 @@

_SLIM_DOC_BATCH_SIZE = 5000

_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"png",
"jpg",
"jpeg",
"gif",
"mp4",
"mov",
"mp3",
"wav",
]
_FULL_EXTENTION_FILTER_STRING = "".join(
[
f" and title!~'*.{extension}'"
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
]
)


class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
Expand All @@ -64,7 +81,7 @@ def __init__(
is_cloud: bool,
space: str = "",
page_id: str = "",
index_recursively: bool = True,
index_recursively: bool = False,
cql_query: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
Expand All @@ -83,22 +100,21 @@ def __init__(
self.wiki_base = wiki_base.rstrip("/")

# if nothing is provided, we will fetch all pages
cql_page_query = "type=page"
base_cql_page_query = "type=page"
if cql_query:
# if a cql_query is provided, we will use it to fetch the pages
cql_page_query = cql_query
base_cql_page_query = cql_query
elif page_id:
# if a cql_query is not provided, we will use the page_id to fetch the page
if index_recursively:
cql_page_query += f" and ancestor='{page_id}'"
base_cql_page_query += f" and ancestor='{page_id}'"
else:
cql_page_query += f" and id='{page_id}'"
base_cql_page_query += f" and id='{page_id}'"
elif space:
# if no cql_query or page_id is provided, we will use the space to fetch the pages
cql_page_query += f" and space='{quote(space)}'"
base_cql_page_query += f" and space='{quote(space)}'"

self.cql_page_query = cql_page_query
self.cql_time_filter = ""
self.base_cql_page_query = base_cql_page_query

self.cql_label_filter = ""
if labels_to_skip:
Expand Down Expand Up @@ -126,6 +142,33 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None
)
return None

def _construct_page_query(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> str:
page_query = self.base_cql_page_query + self.cql_label_filter

# Add time filters
if start:
formatted_start_time = datetime.fromtimestamp(
start, tz=self.timezone
).strftime("%Y-%m-%d %H:%M")
page_query += f" and lastmodified >= '{formatted_start_time}'"
if end:
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
page_query += f" and lastmodified <= '{formatted_end_time}'"

return page_query

def _construct_attachment_query(self, confluence_page_id: str) -> str:
attachment_query = f"type=attachment and container='{confluence_page_id}'"
attachment_query += self.cql_label_filter
attachment_query += _FULL_EXTENTION_FILTER_STRING
return attachment_query

def _get_comment_string_for_page_id(self, page_id: str) -> str:
comment_string = ""

Expand Down Expand Up @@ -205,11 +248,15 @@ def _convert_object_to_document(
metadata=doc_metadata,
)

def _fetch_document_batches(self) -> GenerateDocumentsOutput:
def _fetch_document_batches(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
doc_batch: list[Document] = []
confluence_page_ids: list[str] = []

page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
page_query = self._construct_page_query(start, end)
logger.debug(f"page_query: {page_query}")
# Fetch pages as Documents
for page in self.confluence_client.paginated_cql_retrieval(
Expand All @@ -228,11 +275,10 @@ def _fetch_document_batches(self) -> GenerateDocumentsOutput:

# Fetch attachments as Documents
for confluence_page_id in confluence_page_ids:
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
attachment_cql += self.cql_label_filter
attachment_query = self._construct_attachment_query(confluence_page_id)
# TODO: maybe should add time filter as well?
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_cql,
cql=attachment_query,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
doc = self._convert_object_to_document(attachment)
Expand All @@ -248,17 +294,12 @@ def _fetch_document_batches(self) -> GenerateDocumentsOutput:
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_document_batches()

def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
# Add time filters
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
return self._fetch_document_batches()
def poll_source(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
return self._fetch_document_batches(start, end)

def retrieve_all_slim_documents(
self,
Expand All @@ -269,7 +310,7 @@ def retrieve_all_slim_documents(

restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)

page_query = self.cql_page_query + self.cql_label_filter
page_query = self.base_cql_page_query + self.cql_label_filter
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
Expand All @@ -294,10 +335,9 @@ def retrieve_all_slim_documents(
perm_sync_data=page_perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
attachment_query = self._construct_attachment_query(page["id"])
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
cql=attachment_query,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
Expand Down
Loading