diff --git a/backend/ee/onyx/access/access.py b/backend/ee/onyx/access/access.py index 684b575ac43..558699d6170 100644 --- a/backend/ee/onyx/access/access.py +++ b/backend/ee/onyx/access/access.py @@ -3,6 +3,10 @@ from ee.onyx.db.external_perm import fetch_external_groups_for_user from ee.onyx.db.user_group import fetch_user_groups_for_documents from ee.onyx.db.user_group import fetch_user_groups_for_user +from ee.onyx.external_permissions.post_query_censoring import ( + DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION, +) +from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP from onyx.access.access import ( _get_access_for_documents as get_access_for_documents_without_groups, ) @@ -10,6 +14,7 @@ from onyx.access.models import DocumentAccess from onyx.access.utils import prefix_external_group from onyx.access.utils import prefix_user_group +from onyx.db.document import get_document_sources from onyx.db.document import get_documents_by_ids from onyx.db.models import User @@ -52,9 +57,20 @@ def _get_access_for_documents( ) doc_id_map = {doc.id: doc for doc in documents} + # Get all sources in one batch + doc_id_to_source_map = get_document_sources( + db_session=db_session, + document_ids=document_ids, + ) + access_map = {} for document_id, non_ee_access in non_ee_access_dict.items(): document = doc_id_map[document_id] + source = doc_id_to_source_map.get(document_id) + is_only_censored = ( + source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION + and source not in DOC_PERMISSIONS_FUNC_MAP + ) ext_u_emails = ( set(document.external_user_emails) @@ -70,7 +86,11 @@ # If the document is determined to be "public" externally (through a SYNC connector) # then it's given the same access level as if it were marked public within Onyx - is_public_anywhere = document.is_public or non_ee_access.is_public + # If it's censored, it's treated as public during the search and permissions are + # applied after the search + is_public_anywhere = ( + document.is_public or non_ee_access.is_public or is_only_censored + ) # To avoid collisions of group namings between connectors, they need to be prefixed access_map[document_id] = DocumentAccess( diff --git a/backend/ee/onyx/db/external_perm.py b/backend/ee/onyx/db/external_perm.py index 97039f36d13..16de8bb4110 100644 --- a/backend/ee/onyx/db/external_perm.py +++ b/backend/ee/onyx/db/external_perm.py @@ -10,6 +10,7 @@ from onyx.configs.constants import DocumentSource from onyx.db.models import User__ExternalUserGroupId from onyx.db.users import batch_add_ext_perm_user_if_not_exists +from onyx.db.users import get_user_by_email from onyx.utils.logger import setup_logger logger = setup_logger() @@ -106,3 +107,21 @@ def fetch_external_groups_for_user( User__ExternalUserGroupId.user_id == user_id ) ).all() + + +def fetch_external_groups_for_user_email_and_group_ids( + db_session: Session, + user_email: str, + group_ids: list[str], +) -> list[User__ExternalUserGroupId]: + user = get_user_by_email(db_session=db_session, email=user_email) + if user is None: + return [] + user_id = user.id + user_ext_groups = db_session.scalars( + select(User__ExternalUserGroupId).where( + User__ExternalUserGroupId.user_id == user_id, + User__ExternalUserGroupId.external_user_group_id.in_(group_ids), + ) + ).all() + return list(user_ext_groups) diff --git a/backend/ee/onyx/external_permissions/post_query_censoring.py b/backend/ee/onyx/external_permissions/post_query_censoring.py new file mode 100644 index 00000000000..4d25643eb7e --- /dev/null
+++ b/backend/ee/onyx/external_permissions/post_query_censoring.py @@ -0,0 +1,84 @@ +from collections.abc import Callable + +from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs +from ee.onyx.external_permissions.salesforce.postprocessing import ( + censor_salesforce_chunks, +) +from onyx.configs.constants import DocumentSource +from onyx.context.search.pipeline import InferenceChunk +from onyx.db.engine import get_session_context_manager +from onyx.db.models import User +from onyx.utils.logger import setup_logger + +logger = setup_logger() + +DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION: dict[ + DocumentSource, + # takes the list of chunks to be censored and the user email; returns the censored chunks + Callable[[list[InferenceChunk], str], list[InferenceChunk]], +] = { + DocumentSource.SALESFORCE: censor_salesforce_chunks, +} + + +def _get_all_censoring_enabled_sources() -> set[DocumentSource]: + """ + Returns the set of sources that have censoring enabled. + This is based on whether the access_type is set to sync and the connector + source is included in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION. + + NOTE: This means that if a source has even a single cc_pair with sync access, + all chunks for that source will be censored, even if the connector that + indexed a given chunk is not sync. This was done to avoid looking up the + cc_pair for every single chunk. + """ + with get_session_context_manager() as db_session: + enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session) + return { + cc_pair.connector.source + for cc_pair in enabled_sync_connectors + if cc_pair.connector.source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION + } + + +# NOTE: This is only called if ee is enabled. +def _post_query_chunk_censoring( + chunks: list[InferenceChunk], + user: User | None, +) -> list[InferenceChunk]: + """ + This function checks each chunk to see if it needs to be sent to a censoring + function. Chunks that do are sent to the censoring function for their source and + replaced with the censored results; chunks that don't are returned unchanged.
+ """ + if user is None: + # if user is None, permissions are not enforced + return chunks + + chunks_to_keep = [] + chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {} + + sources_to_censor = _get_all_censoring_enabled_sources() + for chunk in chunks: + # Separate out chunks that require permission post-processing by source + if chunk.source_type in sources_to_censor: + chunks_to_process.setdefault(chunk.source_type, []).append(chunk) + else: + chunks_to_keep.append(chunk) + + # For each source, filter out the chunks using the permission + # check function for that source + # TODO: Use a threadpool/multiprocessing to process the sources in parallel + for source, chunks_for_source in chunks_to_process.items(): + censor_chunks_for_source = DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[source] + try: + censored_chunks = censor_chunks_for_source(chunks_for_source, user.email) + except Exception as e: + logger.exception( + f"Failed to censor chunks for source {source} so throwing out all" + f" chunks for this source and continuing: {e}" + ) + continue + chunks_to_keep.extend(censored_chunks) + + return chunks_to_keep diff --git a/backend/ee/onyx/external_permissions/salesforce/postprocessing.py b/backend/ee/onyx/external_permissions/salesforce/postprocessing.py new file mode 100644 index 00000000000..58480aa24a6 --- /dev/null +++ b/backend/ee/onyx/external_permissions/salesforce/postprocessing.py @@ -0,0 +1,226 @@ +import time + +from ee.onyx.db.external_perm import fetch_external_groups_for_user_email_and_group_ids +from ee.onyx.external_permissions.salesforce.utils import ( + get_any_salesforce_client_for_doc_id, +) +from ee.onyx.external_permissions.salesforce.utils import get_objects_access_for_user_id +from ee.onyx.external_permissions.salesforce.utils import ( + get_salesforce_user_id_from_email, +) +from onyx.configs.app_configs import BLURB_SIZE +from onyx.context.search.models import InferenceChunk +from onyx.db.engine import get_session_context_manager +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +# Types +ChunkKey = tuple[str, int] # (doc_id, chunk_id) +ContentRange = tuple[int, int | None] # (start_index, end_index) None means to the end + + +# NOTE: Used for testing timing +def _get_dummy_object_access_map( + object_ids: set[str], user_email: str, chunks: list[InferenceChunk] +) -> dict[str, bool]: + time.sleep(0.15) + # return {object_id: True for object_id in object_ids} + import random + + return {object_id: random.choice([True, False]) for object_id in object_ids} + + +def _get_objects_access_for_user_email_from_salesforce( + object_ids: set[str], + user_email: str, + chunks: list[InferenceChunk], +) -> dict[str, bool] | None: + """ + This function wraps the salesforce call as we may want to change how this + is done in the future. (E.g. 
replace it with the above function) + """ + # This is cached in the function so the first query takes an extra 0.1-0.3 seconds + # but subsequent queries for this source are essentially instant + first_doc_id = chunks[0].document_id + with get_session_context_manager() as db_session: + salesforce_client = get_any_salesforce_client_for_doc_id( + db_session, first_doc_id + ) + + # This is cached in the function so the first query takes an extra 0.1-0.3 seconds + # but subsequent queries by the same user are essentially instant + start_time = time.time() + user_id = get_salesforce_user_id_from_email(salesforce_client, user_email) + end_time = time.time() + logger.info( + f"Time taken to get Salesforce user ID: {end_time - start_time} seconds" + ) + if user_id is None: + return None + + # This is the only query that is not cached in the function + # so it takes 0.1-0.2 seconds total + object_id_to_access = get_objects_access_for_user_id( + salesforce_client, user_id, list(object_ids) + ) + return object_id_to_access + + +def _extract_salesforce_object_id_from_url(url: str) -> str: + return url.split("/")[-1] + + +def _get_object_ranges_for_chunk( + chunk: InferenceChunk, +) -> dict[str, list[ContentRange]]: + """ + Given a chunk, return a dictionary of salesforce object ids and the content ranges + for that object id in the current chunk + """ + if chunk.source_links is None: + return {} + + object_ranges: dict[str, list[ContentRange]] = {} + end_index = None + descending_source_links = sorted( + chunk.source_links.items(), key=lambda x: x[0], reverse=True + ) + for start_index, url in descending_source_links: + object_id = _extract_salesforce_object_id_from_url(url) + if object_id not in object_ranges: + object_ranges[object_id] = [] + object_ranges[object_id].append((start_index, end_index)) + end_index = start_index + return object_ranges + + +def _create_empty_censored_chunk(uncensored_chunk: InferenceChunk) -> InferenceChunk: + """ + Create a copy of the unfiltered chunk where potentially sensitive content is removed + to be added later if the user has access to each of the sub-objects + """ + empty_censored_chunk = InferenceChunk( + **uncensored_chunk.model_dump(), + ) + empty_censored_chunk.content = "" + empty_censored_chunk.blurb = "" + empty_censored_chunk.source_links = {} + return empty_censored_chunk + + +def _update_censored_chunk( + censored_chunk: InferenceChunk, + uncensored_chunk: InferenceChunk, + content_range: ContentRange, +) -> InferenceChunk: + """ + Update the filtered chunk with the content and source links from the unfiltered chunk using the content ranges + """ + start_index, end_index = content_range + + # Update the content of the filtered chunk + permitted_content = uncensored_chunk.content[start_index:end_index] + permitted_section_start_index = len(censored_chunk.content) + censored_chunk.content = permitted_content + censored_chunk.content + + # Update the source links of the filtered chunk + if uncensored_chunk.source_links is not None: + if censored_chunk.source_links is None: + censored_chunk.source_links = {} + link_content = uncensored_chunk.source_links[start_index] + censored_chunk.source_links[permitted_section_start_index] = link_content + + # Update the blurb of the filtered chunk + censored_chunk.blurb = censored_chunk.content[:BLURB_SIZE] + + return censored_chunk + + +# TODO: Generalize this to other sources +def censor_salesforce_chunks( + chunks: list[InferenceChunk], + user_email: str, + # This is so we can provide a mock access map for testing 
+ access_map: dict[str, bool] | None = None, +) -> list[InferenceChunk]: + # object_id -> list[((doc_id, chunk_id), (start_index, end_index))] + object_to_content_map: dict[str, list[tuple[ChunkKey, ContentRange]]] = {} + + # (doc_id, chunk_id) -> chunk + uncensored_chunks: dict[ChunkKey, InferenceChunk] = {} + + # keep track of all object ids that we have seen to make it easier to get + # the access for these object ids + object_ids: set[str] = set() + + for chunk in chunks: + chunk_key = (chunk.document_id, chunk.chunk_id) + # create a dictionary to quickly look up the unfiltered chunk + uncensored_chunks[chunk_key] = chunk + + # for each chunk, get a dictionary of object ids and the content ranges + # for that object id in the current chunk + object_ranges_for_chunk = _get_object_ranges_for_chunk(chunk) + for object_id, ranges in object_ranges_for_chunk.items(): + object_ids.add(object_id) + for start_index, end_index in ranges: + object_to_content_map.setdefault(object_id, []).append( + (chunk_key, (start_index, end_index)) + ) + + # This is so we can provide a mock access map for testing + if access_map is None: + access_map = _get_objects_access_for_user_email_from_salesforce( + object_ids=object_ids, + user_email=user_email, + chunks=chunks, + ) + if access_map is None: + # If the user is not found in Salesforce, access_map will be None + # so we should just return an empty list because no chunks will be + # censored + return [] + + censored_chunks: dict[ChunkKey, InferenceChunk] = {} + for object_id, content_list in object_to_content_map.items(): + # if the user does not have access to the object, or the object is not in the + # access_map, do not include its content in the filtered chunks + if not access_map.get(object_id, False): + continue + + # if we got this far, the user has access to the object so we can create or update + # the filtered chunk(s) for this object + # NOTE: we only create a censored chunk if the user has access to some + # part of the chunk + for chunk_key, content_range in content_list: + if chunk_key not in censored_chunks: + censored_chunks[chunk_key] = _create_empty_censored_chunk( + uncensored_chunks[chunk_key] + ) + + uncensored_chunk = uncensored_chunks[chunk_key] + censored_chunk = _update_censored_chunk( + censored_chunk=censored_chunks[chunk_key], + uncensored_chunk=uncensored_chunk, + content_range=content_range, + ) + censored_chunks[chunk_key] = censored_chunk + + return list(censored_chunks.values()) + + +# NOTE: This is not used anywhere. 
+def _get_objects_access_for_user_email( + object_ids: set[str], user_email: str +) -> dict[str, bool]: + with get_session_context_manager() as db_session: + external_groups = fetch_external_groups_for_user_email_and_group_ids( + db_session=db_session, + user_email=user_email, + # Maybe make a function that adds a salesforce prefix to the group ids + group_ids=list(object_ids), + ) + external_group_ids = {group.external_user_group_id for group in external_groups} + return {group_id: group_id in external_group_ids for group_id in object_ids} diff --git a/backend/ee/onyx/external_permissions/salesforce/utils.py b/backend/ee/onyx/external_permissions/salesforce/utils.py new file mode 100644 index 00000000000..289e14e37e2 --- /dev/null +++ b/backend/ee/onyx/external_permissions/salesforce/utils.py @@ -0,0 +1,174 @@ +from simple_salesforce import Salesforce +from sqlalchemy.orm import Session + +from onyx.connectors.salesforce.sqlite_functions import get_user_id_by_email +from onyx.connectors.salesforce.sqlite_functions import init_db +from onyx.connectors.salesforce.sqlite_functions import NULL_ID_STRING +from onyx.connectors.salesforce.sqlite_functions import update_email_to_id_table +from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id +from onyx.db.document import get_cc_pairs_for_document +from onyx.utils.logger import setup_logger + +logger = setup_logger() + +_ANY_SALESFORCE_CLIENT: Salesforce | None = None + + +def get_any_salesforce_client_for_doc_id( + db_session: Session, doc_id: str +) -> Salesforce: + """ + We create a salesforce client for the first cc_pair for the first doc_id where + salesforce censoring is enabled. After that we just cache and reuse the same + client for all queries. + + We do this to reduce the number of postgres queries we make at query time. + + This may be problematic if multiple cc_pairs are used for Salesforce. + E.g. there are 2 different credential sets for 2 different Salesforce cc_pairs + but only one has access to the permission information needed for the query. + """ + global _ANY_SALESFORCE_CLIENT + if _ANY_SALESFORCE_CLIENT is None: + cc_pairs = get_cc_pairs_for_document(db_session, doc_id) + first_cc_pair = cc_pairs[0] + credential_json = first_cc_pair.credential.credential_json + _ANY_SALESFORCE_CLIENT = Salesforce( + username=credential_json["sf_username"], + password=credential_json["sf_password"], + security_token=credential_json["sf_security_token"], + ) + return _ANY_SALESFORCE_CLIENT + + +def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None: + query = f"SELECT Id FROM User WHERE Email = '{user_email}'" + result = sf_client.query(query) + if len(result["records"]) == 0: + return None + return result["records"][0]["Id"] + + +# This contains only the user_ids that we have found in Salesforce. +# If we don't know their user_id, we don't store anything in the cache. +_CACHED_SF_EMAIL_TO_ID_MAP: dict[str, str] = {} + + +def get_salesforce_user_id_from_email( + sf_client: Salesforce, + user_email: str, +) -> str | None: + """ + We cache this so we don't have to query Salesforce on every search query; Salesforce + user IDs never change. + Memory usage is fine because we just store 2 small strings per user. + + If the email is not in the cache, we check the local salesforce database for the info. + If the user is not found in the local salesforce database, we query Salesforce. + Whatever we get back from Salesforce is added to the database.
+ If no user_id is found, we add a NULL_ID_STRING to the database for that email so + we don't query Salesforce again (which is slow) but we still check the local salesforce + database every query until a user id is found. This is acceptable because the query time + is quite fast. + If a user_id is created in Salesforce, it will be added to the local salesforce database + next time the connector is run. Then that value will be found in this function and cached. + + NOTE: The first time this runs, it may be slow (around 0.1-0.3 seconds) if the email + hasn't already been added to the local salesforce database. + If it's cached or stored in the local salesforce database, it's fast (<0.001 seconds). + """ + global _CACHED_SF_EMAIL_TO_ID_MAP + if user_email in _CACHED_SF_EMAIL_TO_ID_MAP: + if _CACHED_SF_EMAIL_TO_ID_MAP[user_email] is not None: + return _CACHED_SF_EMAIL_TO_ID_MAP[user_email] + + db_exists = True + try: + # Check if the user is already in the database + user_id = get_user_id_by_email(user_email) + except Exception: + init_db() + try: + user_id = get_user_id_by_email(user_email) + except Exception as e: + logger.error(f"Error checking if user is in database: {e}") + user_id = None + db_exists = False + + # If no entry is found in the database (indicated by user_id being None)... + if user_id is None: + # ...query Salesforce and store the result in the database + user_id = _query_salesforce_user_id(sf_client, user_email) + if db_exists: + update_email_to_id_table(user_email, user_id) + return user_id + elif user_id is None: + return None + elif user_id == NULL_ID_STRING: + return None + # If the found user_id is real, cache it + _CACHED_SF_EMAIL_TO_ID_MAP[user_email] = user_id + return user_id + + +_MAX_RECORD_IDS_PER_QUERY = 200 + + +def get_objects_access_for_user_id( + salesforce_client: Salesforce, + user_id: str, + record_ids: list[str], +) -> dict[str, bool]: + """ + Salesforce has a limit of 200 record ids per query, so we just truncate + the list of record ids to 200. We only ever retrieve 50 chunks at a time, + so this should be fine (it's unlikely that all 50 retrieved chunks each + contain 4 or more unique objects). + If we decide this isn't acceptable, we can use multiple queries, but they + should be run in parallel so query time doesn't get too long. + """ + truncated_record_ids = record_ids[:_MAX_RECORD_IDS_PER_QUERY] + record_ids_str = "'" + "','".join(truncated_record_ids) + "'" + access_query = f""" + SELECT RecordId, HasReadAccess + FROM UserRecordAccess + WHERE RecordId IN ({record_ids_str}) + AND UserId = '{user_id}' + """ + result = salesforce_client.query_all(access_query) + return {record["RecordId"]: record["HasReadAccess"] for record in result["records"]} + + +_CC_PAIR_ID_SALESFORCE_CLIENT_MAP: dict[int, Salesforce] = {} +_DOC_ID_TO_CC_PAIR_ID_MAP: dict[str, int] = {} + + +# NOTE: This is not used anywhere. +def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Salesforce: + """ + Uses a document id to get the cc_pair that indexed that document and uses the credentials + for that cc_pair to create a Salesforce client. + Problems: + - There may be multiple cc_pairs for a document, and we don't know which one to use. + - right now we just use the first one + - Building a new Salesforce client for each document is slow. + - Memory usage could be an issue as we build these dictionaries.
+ """ + if doc_id not in _DOC_ID_TO_CC_PAIR_ID_MAP: + cc_pairs = get_cc_pairs_for_document(db_session, doc_id) + first_cc_pair = cc_pairs[0] + _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id] = first_cc_pair.id + + cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id] + if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP: + cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + if cc_pair is None: + raise ValueError(f"CC pair {cc_pair_id} not found") + credential_json = cc_pair.credential.credential_json + _CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce( + username=credential_json["sf_username"], + password=credential_json["sf_password"], + security_token=credential_json["sf_security_token"], + ) + + return _CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] diff --git a/backend/ee/onyx/external_permissions/sync_params.py b/backend/ee/onyx/external_permissions/sync_params.py index 7b45720f71b..1669dee6a05 100644 --- a/backend/ee/onyx/external_permissions/sync_params.py +++ b/backend/ee/onyx/external_permissions/sync_params.py @@ -8,6 +8,9 @@ from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync +from ee.onyx.external_permissions.post_query_censoring import ( + DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION, +) from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync from onyx.access.models import DocExternalAccess from onyx.configs.constants import DocumentSource @@ -71,4 +74,7 @@ def check_if_valid_sync_source(source_type: DocumentSource) -> bool: - return source_type in DOC_PERMISSIONS_FUNC_MAP + return ( + source_type in DOC_PERMISSIONS_FUNC_MAP + or source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION + ) diff --git a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py index faa9240f508..028a9e45df0 100644 --- a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py @@ -16,6 +16,9 @@ from ee.onyx.db.document import upsert_document_external_perms from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP +from ee.onyx.external_permissions.sync_params import ( + DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION, +) from onyx.access.models import DocExternalAccess from onyx.background.celery.apps.app_base import task_logger from onyx.configs.app_configs import JOB_TIMEOUT @@ -286,6 +289,8 @@ def connector_permission_sync_generator_task( doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type) if doc_sync_func is None: + if source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION: + return None raise ValueError( f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}" ) diff --git a/backend/onyx/connectors/salesforce/connector.py b/backend/onyx/connectors/salesforce/connector.py index 6ada66387f4..4aa9b67a884 100644 --- a/backend/onyx/connectors/salesforce/connector.py +++ b/backend/onyx/connectors/salesforce/connector.py @@ -4,34 +4,29 @@ from simple_salesforce import Salesforce from onyx.configs.app_configs import INDEX_BATCH_SIZE -from onyx.configs.constants import DocumentSource -from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc from onyx.connectors.interfaces import GenerateDocumentsOutput 
from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnector -from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import SlimDocument -from onyx.connectors.salesforce.doc_conversion import extract_section +from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc +from onyx.connectors.salesforce.doc_conversion import ID_PREFIX from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type -from onyx.connectors.salesforce.sqlite_functions import get_child_ids from onyx.connectors.salesforce.sqlite_functions import get_record from onyx.connectors.salesforce.sqlite_functions import init_db from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv -from onyx.connectors.salesforce.utils import SalesforceObject from onyx.utils.logger import setup_logger logger = setup_logger() _DEFAULT_PARENT_OBJECT_TYPES = ["Account"] -_ID_PREFIX = "SALESFORCE_" class SalesforceConnector(LoadConnector, PollConnector, SlimConnector): @@ -65,46 +60,6 @@ def sf_client(self) -> Salesforce: raise ConnectorMissingCredentialError("Salesforce") return self._sf_client - def _extract_primary_owners( - self, sf_object: SalesforceObject - ) -> list[BasicExpertInfo] | None: - object_dict = sf_object.data - if not (last_modified_by_id := object_dict.get("LastModifiedById")): - return None - if not (last_modified_by := get_record(last_modified_by_id)): - return None - if not (last_modified_by_name := last_modified_by.data.get("Name")): - return None - primary_owners = [BasicExpertInfo(display_name=last_modified_by_name)] - return primary_owners - - def _convert_object_instance_to_document( - self, sf_object: SalesforceObject - ) -> Document: - object_dict = sf_object.data - salesforce_id = object_dict["Id"] - onyx_salesforce_id = f"{_ID_PREFIX}{salesforce_id}" - base_url = f"https://{self.sf_client.sf_instance}" - extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"]) - extracted_semantic_identifier = object_dict.get("Name", "Unknown Object") - - sections = [extract_section(sf_object, base_url)] - for id in get_child_ids(sf_object.id): - if not (child_object := get_record(id)): - continue - sections.append(extract_section(child_object, base_url)) - - doc = Document( - id=onyx_salesforce_id, - sections=sections, - source=DocumentSource.SALESFORCE, - semantic_identifier=extracted_semantic_identifier, - doc_updated_at=extracted_doc_updated_at, - primary_owners=self._extract_primary_owners(sf_object), - metadata={}, - ) - return doc - def _fetch_from_salesforce( self, start: SecondsSinceUnixEpoch | None = None, @@ -126,6 +81,9 @@ def _fetch_from_salesforce( f"Found {len(child_types)} child types for {parent_object_type}" ) + # Always want to make sure user is grabbed for permissioning purposes + all_object_types.add("User") + logger.info(f"Found total of {len(all_object_types)} object types to fetch") logger.debug(f"All object types: {all_object_types}") @@ -169,9 +127,6 @@ def _fetch_from_salesforce( logger.debug( f"Added 
{len(new_ids)} new/updated records for {object_type}" ) - # Remove the csv file after it has been used - # to successfully update the db - os.remove(csv_path) logger.info(f"Found {len(updated_ids)} total updated records") logger.info( @@ -196,7 +151,10 @@ def _fetch_from_salesforce( continue docs_to_yield.append( - self._convert_object_instance_to_document(parent_object) + convert_sf_object_to_doc( + sf_object=parent_object, + sf_instance=self.sf_client.sf_instance, + ) ) docs_processed += 1 @@ -225,7 +183,7 @@ def retrieve_all_slim_documents( query_result = self.sf_client.query_all(query) doc_metadata_list.extend( SlimDocument( - id=f"{_ID_PREFIX}{instance_dict.get('Id', '')}", + id=f"{ID_PREFIX}{instance_dict.get('Id', '')}", perm_sync_data={}, ) for instance_dict in query_result["records"] diff --git a/backend/onyx/connectors/salesforce/doc_conversion.py b/backend/onyx/connectors/salesforce/doc_conversion.py index 908b39e80a4..e6acaf2e0cc 100644 --- a/backend/onyx/connectors/salesforce/doc_conversion.py +++ b/backend/onyx/connectors/salesforce/doc_conversion.py @@ -1,8 +1,18 @@ import re -from collections import OrderedDict +from onyx.configs.constants import DocumentSource +from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc +from onyx.connectors.models import BasicExpertInfo +from onyx.connectors.models import Document from onyx.connectors.models import Section +from onyx.connectors.salesforce.sqlite_functions import get_child_ids +from onyx.connectors.salesforce.sqlite_functions import get_record from onyx.connectors.salesforce.utils import SalesforceObject +from onyx.utils.logger import setup_logger + +logger = setup_logger() + +ID_PREFIX = "SALESFORCE_" # All of these types of keys are handled by specific fields in the doc # conversion process (E.g. URLs) or are not useful for the user (E.g. UUIDs) @@ -103,54 +113,72 @@ def _extract_dict_text(raw_dict: dict) -> str: return natural_language_for_dict -def extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section: +def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section: return Section( text=_extract_dict_text(salesforce_object.data), link=f"{base_url}/{salesforce_object.id}", ) -def _field_value_is_child_object(field_value: dict) -> bool: - """ - Checks if the field value is a child object. - """ - return ( - isinstance(field_value, OrderedDict) - and "records" in field_value.keys() - and isinstance(field_value["records"], list) - and len(field_value["records"]) > 0 - and "Id" in field_value["records"][0].keys() +def _extract_primary_owners( + sf_object: SalesforceObject, +) -> list[BasicExpertInfo] | None: + object_dict = sf_object.data + if not (last_modified_by_id := object_dict.get("LastModifiedById")): + logger.warning(f"No LastModifiedById found for {sf_object.id}") + return None + if not (last_modified_by := get_record(last_modified_by_id)): + logger.warning(f"No LastModifiedBy found for {last_modified_by_id}") + return None + + user_data = last_modified_by.data + expert_info = BasicExpertInfo( + first_name=user_data.get("FirstName"), + last_name=user_data.get("LastName"), + email=user_data.get("Email"), + display_name=user_data.get("Name"), ) - -def _extract_sections(salesforce_object: dict, base_url: str) -> list[Section]: - """ - This goes through the salesforce_object and extracts the top level fields as a Section. - It also goes through the child objects and extracts them as Sections. 
- """ - top_level_dict = {} - - child_object_sections = [] - for field_name, field_value in salesforce_object.items(): - # If the field value is not a child object, add it to the top level dict - # to turn into text for the top level section - if not _field_value_is_child_object(field_value): - top_level_dict[field_name] = field_value + # Check if all fields are None + if all( + value is None + for value in [ + expert_info.first_name, + expert_info.last_name, + expert_info.email, + expert_info.display_name, + ] + ): + logger.warning(f"No identifying information found for user {user_data}") + return None + + return [expert_info] + + +def convert_sf_object_to_doc( + sf_object: SalesforceObject, + sf_instance: str, +) -> Document: + object_dict = sf_object.data + salesforce_id = object_dict["Id"] + onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}" + base_url = f"https://{sf_instance}" + extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"]) + extracted_semantic_identifier = object_dict.get("Name", "Unknown Object") + + sections = [_extract_section(sf_object, base_url)] + for id in get_child_ids(sf_object.id): + if not (child_object := get_record(id)): continue - - # If the field value is a child object, extract the child objects and add them as sections - for record in field_value["records"]: - child_object_id = record["Id"] - child_object_sections.append( - Section( - text=f"Child Object(s): {field_name}\n{_extract_dict_text(record)}", - link=f"{base_url}/{child_object_id}", - ) - ) - - top_level_id = salesforce_object["Id"] - top_level_section = Section( - text=_extract_dict_text(top_level_dict), - link=f"{base_url}/{top_level_id}", + sections.append(_extract_section(child_object, base_url)) + + doc = Document( + id=onyx_salesforce_id, + sections=sections, + source=DocumentSource.SALESFORCE, + semantic_identifier=extracted_semantic_identifier, + doc_updated_at=extracted_doc_updated_at, + primary_owners=_extract_primary_owners(sf_object), + metadata={}, ) - return [top_level_section, *child_object_sections] + return doc diff --git a/backend/onyx/connectors/salesforce/salesforce_calls.py b/backend/onyx/connectors/salesforce/salesforce_calls.py index f569b28b873..858c240b3ec 100644 --- a/backend/onyx/connectors/salesforce/salesforce_calls.py +++ b/backend/onyx/connectors/salesforce/salesforce_calls.py @@ -77,25 +77,28 @@ def _get_all_queryable_fields_of_sf_type( object_description = _get_sf_type_object_json(sf_client, sf_type) fields: list[dict[str, Any]] = object_description["fields"] valid_fields: set[str] = set() - compound_field_names: set[str] = set() + field_names_to_remove: set[str] = set() for field in fields: if compound_field_name := field.get("compoundFieldName"): - compound_field_names.add(compound_field_name) + # We do want to get name fields even if they are compound + if not field.get("nameField"): + field_names_to_remove.add(compound_field_name) if field.get("type", "base64") == "base64": continue if field_name := field.get("name"): valid_fields.add(field_name) - return list(valid_fields - compound_field_names) + return list(valid_fields - field_names_to_remove) -def _check_if_object_type_is_empty(sf_client: Salesforce, sf_type: str) -> bool: +def _check_if_object_type_is_empty( + sf_client: Salesforce, sf_type: str, time_filter: str +) -> bool: """ - Send a small query to check if the object type is empty so we don't - perform extra bulk queries + Use the rest api to check to make sure the query will result in a non-empty response """ try: - query = 
f"SELECT Count() FROM {sf_type} LIMIT 1" + query = f"SELECT Count() FROM {sf_type}{time_filter} LIMIT 1" result = sf_client.query(query) if result["totalSize"] == 0: return False @@ -134,7 +137,7 @@ def _bulk_retrieve_from_salesforce( sf_type: str, time_filter: str, ) -> tuple[str, list[str] | None]: - if not _check_if_object_type_is_empty(sf_client, sf_type): + if not _check_if_object_type_is_empty(sf_client, sf_type, time_filter): return sf_type, None if existing_csvs := _check_for_existing_csvs(sf_type): diff --git a/backend/onyx/connectors/salesforce/sqlite_functions.py b/backend/onyx/connectors/salesforce/sqlite_functions.py index eb8d72ba4f3..029b4e1238d 100644 --- a/backend/onyx/connectors/salesforce/sqlite_functions.py +++ b/backend/onyx/connectors/salesforce/sqlite_functions.py @@ -40,20 +40,20 @@ def get_db_connection( def init_db() -> None: """Initialize the SQLite database with required tables if they don't exist.""" - if os.path.exists(get_sqlite_db_path()): - return - # Create database directory if it doesn't exist os.makedirs(os.path.dirname(get_sqlite_db_path()), exist_ok=True) with get_db_connection("EXCLUSIVE") as conn: cursor = conn.cursor() - # Enable WAL mode for better concurrent access and write performance - cursor.execute("PRAGMA journal_mode=WAL") - cursor.execute("PRAGMA synchronous=NORMAL") - cursor.execute("PRAGMA temp_store=MEMORY") - cursor.execute("PRAGMA cache_size=-2000000") # Use 2GB memory for cache + db_exists = os.path.exists(get_sqlite_db_path()) + + if not db_exists: + # Enable WAL mode for better concurrent access and write performance + cursor.execute("PRAGMA journal_mode=WAL") + cursor.execute("PRAGMA synchronous=NORMAL") + cursor.execute("PRAGMA temp_store=MEMORY") + cursor.execute("PRAGMA cache_size=-2000000") # Use 2GB memory for cache # Main table for storing Salesforce objects cursor.execute( @@ -90,49 +90,69 @@ def init_db() -> None: """ ) - # Always recreate indexes to ensure they exist - cursor.execute("DROP INDEX IF EXISTS idx_object_type") - cursor.execute("DROP INDEX IF EXISTS idx_parent_id") - cursor.execute("DROP INDEX IF EXISTS idx_child_parent") - cursor.execute("DROP INDEX IF EXISTS idx_object_type_id") - cursor.execute("DROP INDEX IF EXISTS idx_relationship_types_lookup") - - # Create covering indexes for common queries + # Create a table for User email to ID mapping if it doesn't exist cursor.execute( + """ + CREATE TABLE IF NOT EXISTS user_email_map ( + email TEXT PRIMARY KEY, + user_id TEXT, -- Nullable to allow for users without IDs + FOREIGN KEY (user_id) REFERENCES salesforce_objects(id) + ) WITHOUT ROWID + """ + ) + + # Create indexes if they don't exist (SQLite ignores IF NOT EXISTS for indexes) + def create_index_if_not_exists(index_name: str, create_statement: str) -> None: + cursor.execute( + f"SELECT name FROM sqlite_master WHERE type='index' AND name='{index_name}'" + ) + if not cursor.fetchone(): + cursor.execute(create_statement) + + create_index_if_not_exists( + "idx_object_type", """ CREATE INDEX idx_object_type ON salesforce_objects(object_type, id) WHERE object_type IS NOT NULL - """ + """, ) - cursor.execute( + create_index_if_not_exists( + "idx_parent_id", """ CREATE INDEX idx_parent_id ON relationships(parent_id, child_id) - """ + """, ) - cursor.execute( + create_index_if_not_exists( + "idx_child_parent", """ CREATE INDEX idx_child_parent ON relationships(child_id) WHERE child_id IS NOT NULL - """ + """, ) - # New composite index for fast parent type lookups - cursor.execute( + create_index_if_not_exists( + 
"idx_relationship_types_lookup", """ CREATE INDEX idx_relationship_types_lookup ON relationship_types(parent_type, child_id, parent_id) - """ + """, ) # Analyze tables to help query planner cursor.execute("ANALYZE relationships") cursor.execute("ANALYZE salesforce_objects") cursor.execute("ANALYZE relationship_types") + cursor.execute("ANALYZE user_email_map") + + # If database already existed but user_email_map needs to be populated + cursor.execute("SELECT COUNT(*) FROM user_email_map") + if cursor.fetchone()[0] == 0: + _update_user_email_map(conn) conn.commit() @@ -203,7 +223,27 @@ def _update_relationship_tables( raise -def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str]: +def _update_user_email_map(conn: sqlite3.Connection) -> None: + """Update the user_email_map table with current User objects. + Called internally by update_sf_db_with_csv when User objects are updated. + """ + cursor = conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO user_email_map (email, user_id) + SELECT json_extract(data, '$.Email'), id + FROM salesforce_objects + WHERE object_type = 'User' + AND json_extract(data, '$.Email') IS NOT NULL + """ + ) + + +def update_sf_db_with_csv( + object_type: str, + csv_download_path: str, + delete_csv_after_use: bool = True, +) -> list[str]: """Update the SF DB with a CSV file using SQLite storage.""" updated_ids = [] @@ -249,8 +289,17 @@ def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str] _update_relationship_tables(conn, id, parent_ids) updated_ids.append(id) + # If we're updating User objects, update the email map + if object_type == "User": + _update_user_email_map(conn) + conn.commit() + if delete_csv_after_use: + # Remove the csv file after it has been used + # to successfully update the db + os.remove(csv_download_path) + return updated_ids @@ -329,6 +378,9 @@ def get_affected_parent_ids_by_type( cursor = conn.cursor() for batch_ids in updated_ids_batches: + batch_ids = list(set(batch_ids) - updated_parent_ids) + if not batch_ids: + continue id_placeholders = ",".join(["?" for _ in batch_ids]) for parent_type in parent_types: @@ -384,3 +436,40 @@ def has_at_least_one_object_of_type(object_type: str) -> bool: ) count = cursor.fetchone()[0] return count > 0 + + +# NULL_ID_STRING is used to indicate that the user ID was queried but not found +# As opposed to None because it has yet to be queried at all +NULL_ID_STRING = "N/A" + + +def get_user_id_by_email(email: str) -> str | None: + """Get the Salesforce User ID for a given email address. 
+ + Args: + email: The email address to look up + + Returns: + The user_id stored for the email, or None if the email is not in the table. + Note that the stored value may be NULL_ID_STRING if a previous Salesforce + lookup found no user for this email. + """ + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT user_id FROM user_email_map WHERE email = ?", (email,)) + result = cursor.fetchone() + if result is None: + return None + return result[0] + + +def update_email_to_id_table(email: str, id: str | None) -> None: + """Update the email to ID map table with a new email and ID.""" + id_to_use = id or NULL_ID_STRING + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "INSERT OR REPLACE INTO user_email_map (email, user_id) VALUES (?, ?)", + (email, id_to_use), + ) + conn.commit() diff --git a/backend/onyx/context/search/pipeline.py b/backend/onyx/context/search/pipeline.py index be7e288799a..c6f8d8cbea8 100644 --- a/backend/onyx/context/search/pipeline.py +++ b/backend/onyx/context/search/pipeline.py @@ -37,6 +37,7 @@ from onyx.utils.threadpool_concurrency import FunctionCall from onyx.utils.threadpool_concurrency import run_functions_in_parallel from onyx.utils.timing import log_function_time +from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop logger = setup_logger() @@ -163,6 +164,17 @@ def _get_sections(self) -> list[InferenceSection]: # These chunks are ordered, deduped, and contain no large chunks retrieved_chunks = self._get_chunks() + # If ee is enabled, censor the chunk sections based on user access + # Otherwise, this is a no-op and the retrieved chunks are returned unchanged + censored_chunks = fetch_ee_implementation_or_noop( + "onyx.external_permissions.post_query_censoring", + "_post_query_chunk_censoring", + retrieved_chunks, + )( + chunks=retrieved_chunks, + user=self.user, + ) + above = self.search_query.chunks_above below = self.search_query.chunks_below @@ -175,7 +187,7 @@ def _get_sections(self) -> list[InferenceSection]: seen_document_ids = set() # This preserves the ordering since the chunks are retrieved in score order - for chunk in retrieved_chunks: + for chunk in censored_chunks: if chunk.document_id not in seen_document_ids: seen_document_ids.add(chunk.document_id) chunk_requests.append( @@ -225,7 +237,7 @@ def _get_sections(self) -> list[InferenceSection]: # This maintains the original chunks ordering. Note, we cannot simply sort by score here # as reranking flow may wipe the scores for a lot of the chunks.
doc_chunk_ranges_map = defaultdict(list) - for chunk in retrieved_chunks: + for chunk in censored_chunks: # The list of ranges for each document is ordered by score doc_chunk_ranges_map[chunk.document_id].append( ChunkRange( @@ -274,11 +286,11 @@ def _get_sections(self) -> list[InferenceSection]: # In case of failed parallel calls to Vespa, at least we should have the initial retrieved chunks doc_chunk_ind_to_chunk.update( - {(chunk.document_id, chunk.chunk_id): chunk for chunk in retrieved_chunks} + {(chunk.document_id, chunk.chunk_id): chunk for chunk in censored_chunks} ) # Build the surroundings for all of the initial retrieved chunks - for chunk in retrieved_chunks: + for chunk in censored_chunks: start_ind = max(0, chunk.chunk_id - above) end_ind = chunk.chunk_id + below diff --git a/backend/onyx/db/document.py b/backend/onyx/db/document.py index 4ebcc2c7d0d..7f11b64d824 100644 --- a/backend/onyx/db/document.py +++ b/backend/onyx/db/document.py @@ -20,10 +20,12 @@ from sqlalchemy.sql.expression import null from onyx.configs.constants import DEFAULT_BOOST +from onyx.configs.constants import DocumentSource from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.feedback import delete_document_feedback_for_documents__no_commit +from onyx.db.models import Connector from onyx.db.models import ConnectorCredentialPair from onyx.db.models import Credential from onyx.db.models import Document as DbDocument @@ -626,6 +628,60 @@ def get_document( return doc +def get_cc_pairs_for_document( + db_session: Session, + document_id: str, +) -> list[ConnectorCredentialPair]: + stmt = ( + select(ConnectorCredentialPair) + .join( + DocumentByConnectorCredentialPair, + and_( + DocumentByConnectorCredentialPair.connector_id + == ConnectorCredentialPair.connector_id, + DocumentByConnectorCredentialPair.credential_id + == ConnectorCredentialPair.credential_id, + ), + ) + .where(DocumentByConnectorCredentialPair.id == document_id) + ) + return list(db_session.execute(stmt).scalars().all()) + + +def get_document_sources( + db_session: Session, + document_ids: list[str], +) -> dict[str, DocumentSource]: + """Gets the sources for a list of document IDs. + Returns a dictionary mapping document ID to its source. + If a document has multiple sources (multiple CC pairs), returns the first one found. 
+ """ + stmt = ( + select( + DocumentByConnectorCredentialPair.id, + Connector.source, + ) + .join( + ConnectorCredentialPair, + and_( + DocumentByConnectorCredentialPair.connector_id + == ConnectorCredentialPair.connector_id, + DocumentByConnectorCredentialPair.credential_id + == ConnectorCredentialPair.credential_id, + ), + ) + .join( + Connector, + ConnectorCredentialPair.connector_id == Connector.id, + ) + .where(DocumentByConnectorCredentialPair.id.in_(document_ids)) + .distinct() + ) + + results = db_session.execute(stmt).all() + return {doc_id: source for doc_id, source in results} + + def fetch_chunk_counts_for_documents( document_ids: list[str], db_session: Session, diff --git a/backend/onyx/indexing/models.py b/backend/onyx/indexing/models.py index e64cb9ae6b7..e536428282b 100644 --- a/backend/onyx/indexing/models.py +++ b/backend/onyx/indexing/models.py @@ -23,11 +23,13 @@ class ChunkEmbedding(BaseModel): class BaseChunk(BaseModel): chunk_id: int - blurb: str # The first sentence(s) of the first Section of the chunk + # The first sentence(s) of the first Section of the chunk + blurb: str content: str # Holds the link and the offsets into the raw Chunk text source_links: dict[int, str] | None - section_continuation: bool # True if this Chunk's start is not at the start of a Section + # True if this Chunk's start is not at the start of a Section + section_continuation: bool class DocAwareChunk(BaseChunk): diff --git a/backend/tests/unit/ee/onyx/external_permissions/salesforce/test_postprocessing.py b/backend/tests/unit/ee/onyx/external_permissions/salesforce/test_postprocessing.py new file mode 100644 index 00000000000..8b7a668210b --- /dev/null +++ b/backend/tests/unit/ee/onyx/external_permissions/salesforce/test_postprocessing.py @@ -0,0 +1,196 @@ +from datetime import datetime + +from ee.onyx.external_permissions.salesforce.postprocessing import ( + censor_salesforce_chunks, +) +from onyx.configs.app_configs import BLURB_SIZE +from onyx.configs.constants import DocumentSource +from onyx.context.search.models import InferenceChunk + + +def create_test_chunk( + doc_id: str, + chunk_id: int, + content: str, + source_links: dict[int, str] | None, +) -> InferenceChunk: + return InferenceChunk( + document_id=doc_id, + chunk_id=chunk_id, + blurb=content[:BLURB_SIZE], + content=content, + source_links=source_links, + section_continuation=False, + source_type=DocumentSource.SALESFORCE, + semantic_identifier="test_chunk", + title="Test Chunk", + boost=1, + recency_bias=1.0, + score=None, + hidden=False, + metadata={}, + match_highlights=[], + updated_at=datetime.now(), + ) + + +def test_validate_salesforce_access_single_object() -> None: + """Test filtering when chunk has a single Salesforce object reference""" + section = "This is a test document about a Salesforce object." 
+ test_content = section + test_chunk = create_test_chunk( + doc_id="doc1", + chunk_id=1, + content=test_content, + source_links={0: "https://salesforce.com/object1"}, + ) + + # Test when user has access + filtered_chunks = censor_salesforce_chunks( + chunks=[test_chunk], + user_email="test@example.com", + access_map={"object1": True}, + ) + assert len(filtered_chunks) == 1 + assert filtered_chunks[0].content == test_content + + # Test when user doesn't have access + filtered_chunks = censor_salesforce_chunks( + chunks=[test_chunk], + user_email="test@example.com", + access_map={"object1": False}, + ) + assert len(filtered_chunks) == 0 + + +def test_validate_salesforce_access_multiple_objects() -> None: + """Test filtering when chunk has multiple Salesforce object references""" + section1 = "First part about object1. " + section2 = "Second part about object2. " + section3 = "Third part about object3." + + test_content = section1 + section2 + section3 + section1_end = len(section1) + section2_end = section1_end + len(section2) + + test_chunk = create_test_chunk( + doc_id="doc1", + chunk_id=1, + content=test_content, + source_links={ + 0: "https://salesforce.com/object1", + section1_end: "https://salesforce.com/object2", + section2_end: "https://salesforce.com/object3", + }, + ) + + # Test when user has access to all objects + filtered_chunks = censor_salesforce_chunks( + chunks=[test_chunk], + user_email="test@example.com", + access_map={ + "object1": True, + "object2": True, + "object3": True, + }, + ) + assert len(filtered_chunks) == 1 + assert filtered_chunks[0].content == test_content + + # Test when user has access to some objects + filtered_chunks = censor_salesforce_chunks( + chunks=[test_chunk], + user_email="test@example.com", + access_map={ + "object1": True, + "object2": False, + "object3": True, + }, + ) + assert len(filtered_chunks) == 1 + assert section1 in filtered_chunks[0].content + assert section2 not in filtered_chunks[0].content + assert section3 in filtered_chunks[0].content + + # Test when user has no access + filtered_chunks = censor_salesforce_chunks( + chunks=[test_chunk], + user_email="test@example.com", + access_map={ + "object1": False, + "object2": False, + "object3": False, + }, + ) + assert len(filtered_chunks) == 0 + + +def test_validate_salesforce_access_multiple_chunks() -> None: + """Test filtering when there are multiple chunks with different access patterns""" + section1 = "Content about object1" + section2 = "Content about object2" + + chunk1 = create_test_chunk( + doc_id="doc1", + chunk_id=1, + content=section1, + source_links={0: "https://salesforce.com/object1"}, + ) + chunk2 = create_test_chunk( + doc_id="doc1", + chunk_id=2, + content=section2, + source_links={0: "https://salesforce.com/object2"}, + ) + + # Test mixed access + filtered_chunks = censor_salesforce_chunks( + chunks=[chunk1, chunk2], + user_email="test@example.com", + access_map={ + "object1": True, + "object2": False, + }, + ) + assert len(filtered_chunks) == 1 + assert filtered_chunks[0].chunk_id == 1 + assert section1 in filtered_chunks[0].content + + +def test_validate_salesforce_access_no_source_links() -> None: + """Test handling of chunks with no source links""" + section = "Content with no source links" + test_chunk = create_test_chunk( + doc_id="doc1", + chunk_id=1, + content=section, + source_links=None, + ) + + filtered_chunks = censor_salesforce_chunks( + chunks=[test_chunk], + user_email="test@example.com", + access_map={}, + ) + assert len(filtered_chunks) == 0 + + +def 
test_validate_salesforce_access_blurb_update() -> None: + """Test that blurbs are properly updated based on permitted content""" + section = "First part about object1. " + long_content = section * 20 # Make it longer than BLURB_SIZE + test_chunk = create_test_chunk( + doc_id="doc1", + chunk_id=1, + content=long_content, + source_links={0: "https://salesforce.com/object1"}, + ) + + filtered_chunks = censor_salesforce_chunks( + chunks=[test_chunk], + user_email="test@example.com", + access_map={"object1": True}, + ) + assert len(filtered_chunks) == 1 + assert len(filtered_chunks[0].blurb) <= BLURB_SIZE + assert filtered_chunks[0].blurb.startswith(section) diff --git a/web/src/lib/connectors/AutoSyncOptionFields.tsx b/web/src/lib/connectors/AutoSyncOptionFields.tsx index 812288ba876..64b1725161f 100644 --- a/web/src/lib/connectors/AutoSyncOptionFields.tsx +++ b/web/src/lib/connectors/AutoSyncOptionFields.tsx @@ -15,4 +15,5 @@ export const autoSyncConfigBySource: Record< google_drive: {}, gmail: {}, slack: {}, + salesforce: {}, }; diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index cab013985de..38cd6742a34 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -343,6 +343,7 @@ export const validAutoSyncSources = [ ValidSources.GoogleDrive, ValidSources.Gmail, ValidSources.Slack, + ValidSources.Salesforce, ] as const; // Create a type from the array elements
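
Illustrative sketch (not part of the change set above): the post-query censoring hook is pluggable, and a source only needs an entry in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION mapping it to a (chunks, user_email) -> chunks callable. The sketch below shows roughly what registering another source could look like; censor_confluence_chunks and its permission check are hypothetical placeholders rather than existing Onyx functions, and Confluence is used purely as an example source.

# Hypothetical example of wiring another source into post-query censoring.
# Only DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION, DocumentSource, and InferenceChunk
# come from the change set above; everything else here is a placeholder.
from ee.onyx.external_permissions.post_query_censoring import (
    DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk


def _user_can_read_chunk(chunk: InferenceChunk, user_email: str) -> bool:
    # Placeholder permission check; a real implementation would consult the
    # source's permission API, as censor_salesforce_chunks does for Salesforce.
    return True


def censor_confluence_chunks(
    chunks: list[InferenceChunk], user_email: str
) -> list[InferenceChunk]:
    # Drop (or trim) any chunk the user is not allowed to see.
    return [chunk for chunk in chunks if _user_can_read_chunk(chunk, user_email)]


# Registering the callable is the only wiring needed: _post_query_chunk_censoring
# dispatches to it for any auto-sync cc_pair of this source, and
# check_if_valid_sync_source() in sync_params.py then accepts the source as valid.
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[DocumentSource.CONFLUENCE] = censor_confluence_chunks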