Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Permission Syncing for Salesforce #3551

Merged
merged 20 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion backend/ee/onyx/access/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
from ee.onyx.db.external_perm import fetch_external_groups_for_user
from ee.onyx.db.user_group import fetch_user_groups_for_documents
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.external_permissions.post_query_censoring import (
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from onyx.access.access import (
_get_access_for_documents as get_access_for_documents_without_groups,
)
from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_groups
from onyx.access.models import DocumentAccess
from onyx.access.utils import prefix_external_group
from onyx.access.utils import prefix_user_group
from onyx.db.document import get_document_sources
from onyx.db.document import get_documents_by_ids
from onyx.db.models import User

Expand Down Expand Up @@ -52,9 +57,20 @@ def _get_access_for_documents(
)
doc_id_map = {doc.id: doc for doc in documents}

# Get all sources in one batch
doc_id_to_source_map = get_document_sources(
db_session=db_session,
document_ids=document_ids,
)

access_map = {}
for document_id, non_ee_access in non_ee_access_dict.items():
document = doc_id_map[document_id]
source = doc_id_to_source_map.get(document_id)
is_only_censored = (
source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
and source not in DOC_PERMISSIONS_FUNC_MAP
)

ext_u_emails = (
set(document.external_user_emails)
Expand All @@ -70,7 +86,11 @@ def _get_access_for_documents(

# If the document is determined to be "public" externally (through a SYNC connector)
# then it's given the same access level as if it were marked public within Onyx
is_public_anywhere = document.is_public or non_ee_access.is_public
# If its censored, then it's public anywhere during the search and then permissions are
# applied after the search
is_public_anywhere = (
document.is_public or non_ee_access.is_public or is_only_censored
)

# To avoid collisions of group namings between connectors, they need to be prefixed
access_map[document_id] = DocumentAccess(
Expand Down
19 changes: 19 additions & 0 deletions backend/ee/onyx/db/external_perm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from onyx.configs.constants import DocumentSource
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger

logger = setup_logger()
Expand Down Expand Up @@ -106,3 +107,21 @@ def fetch_external_groups_for_user(
User__ExternalUserGroupId.user_id == user_id
)
).all()


def fetch_external_groups_for_user_email_and_group_ids(
db_session: Session,
user_email: str,
group_ids: list[str],
) -> list[User__ExternalUserGroupId]:
user = get_user_by_email(db_session=db_session, email=user_email)
if user is None:
return []
user_id = user.id
user_ext_groups = db_session.scalars(
select(User__ExternalUserGroupId).where(
User__ExternalUserGroupId.user_id == user_id,
User__ExternalUserGroupId.external_user_group_id.in_(group_ids),
)
).all()
return list(user_ext_groups)
84 changes: 84 additions & 0 deletions backend/ee/onyx/external_permissions/post_query_censoring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from collections.abc import Callable

from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.external_permissions.salesforce.postprocessing import (
censor_salesforce_chunks,
)
from onyx.configs.constants import DocumentSource
from onyx.context.search.pipeline import InferenceChunk
from onyx.db.engine import get_session_context_manager
from onyx.db.models import User
from onyx.utils.logger import setup_logger

logger = setup_logger()

DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION: dict[
DocumentSource,
# list of chunks to be censored and the user email. returns censored chunks
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the general inputs might change if we add more sources as needed

Callable[[list[InferenceChunk], str], list[InferenceChunk]],
] = {
DocumentSource.SALESFORCE: censor_salesforce_chunks,
}


def _get_all_censoring_enabled_sources() -> set[DocumentSource]:
"""
Returns the set of sources that have censoring enabled.
This is based on if the access_type is set to sync and the connector
source is included in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION.

NOTE: This means if there is a source has a single cc_pair that is sync,
all chunks for that source will be censored, even if the connector that
indexed that chunk is not sync. This was done to avoid getting the cc_pair
for every single chunk.
"""
with get_session_context_manager() as db_session:
enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session)
return {
cc_pair.connector.source
for cc_pair in enabled_sync_connectors
if cc_pair.connector.source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
}


# NOTE: This is only called if ee is enabled.
def _post_query_chunk_censoring(
chunks: list[InferenceChunk],
user: User | None,
) -> list[InferenceChunk]:
"""
This function checks all chunks to see if they need to be sent to a censoring
function. If they do, it sends them to the censoring function and returns the
censored chunks. If they don't, it returns the original chunks.
"""
if user is None:
# if user is None, permissions are not enforced
return chunks

chunks_to_keep = []
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}

sources_to_censor = _get_all_censoring_enabled_sources()
for chunk in chunks:
# Separate out chunks that require permission post-processing by source
if chunk.source_type in sources_to_censor:
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
else:
chunks_to_keep.append(chunk)

# For each source, filter out the chunks using the permission
# check function for that source
# TODO: Use a threadpool/multiprocessing to process the sources in parallel
for source, chunks_for_source in chunks_to_process.items():
censor_chunks_for_source = DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[source]
try:
censored_chunks = censor_chunks_for_source(chunks_for_source, user.email)
except Exception as e:
logger.exception(
f"Failed to censor chunks for source {source} so throwing out all"
f" chunks for this source and continuing: {e}"
)
continue
chunks_to_keep.extend(censored_chunks)

return chunks_to_keep
226 changes: 226 additions & 0 deletions backend/ee/onyx/external_permissions/salesforce/postprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
import time

from ee.onyx.db.external_perm import fetch_external_groups_for_user_email_and_group_ids
from ee.onyx.external_permissions.salesforce.utils import (
get_any_salesforce_client_for_doc_id,
)
from ee.onyx.external_permissions.salesforce.utils import get_objects_access_for_user_id
from ee.onyx.external_permissions.salesforce.utils import (
get_salesforce_user_id_from_email,
)
from onyx.configs.app_configs import BLURB_SIZE
from onyx.context.search.models import InferenceChunk
from onyx.db.engine import get_session_context_manager
from onyx.utils.logger import setup_logger

logger = setup_logger()


# Types
ChunkKey = tuple[str, int] # (doc_id, chunk_id)
ContentRange = tuple[int, int | None] # (start_index, end_index) None means to the end


# NOTE: Used for testing timing
def _get_dummy_object_access_map(
object_ids: set[str], user_email: str, chunks: list[InferenceChunk]
) -> dict[str, bool]:
time.sleep(0.15)
# return {object_id: True for object_id in object_ids}
import random

return {object_id: random.choice([True, False]) for object_id in object_ids}


def _get_objects_access_for_user_email_from_salesforce(
object_ids: set[str],
user_email: str,
chunks: list[InferenceChunk],
) -> dict[str, bool] | None:
"""
This function wraps the salesforce call as we may want to change how this
is done in the future. (E.g. replace it with the above function)
"""
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
# but subsequent queries for this source are essentially instant
first_doc_id = chunks[0].document_id
with get_session_context_manager() as db_session:
salesforce_client = get_any_salesforce_client_for_doc_id(
db_session, first_doc_id
)

# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
# but subsequent queries by the same user are essentially instant
start_time = time.time()
user_id = get_salesforce_user_id_from_email(salesforce_client, user_email)
end_time = time.time()
logger.info(
f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
)
if user_id is None:
return None

# This is the only query that is not cached in the function
# so it takes 0.1-0.2 seconds total
object_id_to_access = get_objects_access_for_user_id(
salesforce_client, user_id, list(object_ids)
)
return object_id_to_access


def _extract_salesforce_object_id_from_url(url: str) -> str:
return url.split("/")[-1]


def _get_object_ranges_for_chunk(
chunk: InferenceChunk,
) -> dict[str, list[ContentRange]]:
"""
Given a chunk, return a dictionary of salesforce object ids and the content ranges
for that object id in the current chunk
"""
if chunk.source_links is None:
return {}

object_ranges: dict[str, list[ContentRange]] = {}
end_index = None
descending_source_links = sorted(
chunk.source_links.items(), key=lambda x: x[0], reverse=True
)
for start_index, url in descending_source_links:
object_id = _extract_salesforce_object_id_from_url(url)
if object_id not in object_ranges:
object_ranges[object_id] = []
object_ranges[object_id].append((start_index, end_index))
end_index = start_index
return object_ranges


def _create_empty_censored_chunk(uncensored_chunk: InferenceChunk) -> InferenceChunk:
"""
Create a copy of the unfiltered chunk where potentially sensitive content is removed
to be added later if the user has access to each of the sub-objects
"""
empty_censored_chunk = InferenceChunk(
**uncensored_chunk.model_dump(),
)
empty_censored_chunk.content = ""
empty_censored_chunk.blurb = ""
empty_censored_chunk.source_links = {}
return empty_censored_chunk


def _update_censored_chunk(
censored_chunk: InferenceChunk,
uncensored_chunk: InferenceChunk,
content_range: ContentRange,
) -> InferenceChunk:
"""
Update the filtered chunk with the content and source links from the unfiltered chunk using the content ranges
"""
start_index, end_index = content_range

# Update the content of the filtered chunk
permitted_content = uncensored_chunk.content[start_index:end_index]
permitted_section_start_index = len(censored_chunk.content)
censored_chunk.content = permitted_content + censored_chunk.content

# Update the source links of the filtered chunk
if uncensored_chunk.source_links is not None:
if censored_chunk.source_links is None:
censored_chunk.source_links = {}
link_content = uncensored_chunk.source_links[start_index]
censored_chunk.source_links[permitted_section_start_index] = link_content

# Update the blurb of the filtered chunk
censored_chunk.blurb = censored_chunk.content[:BLURB_SIZE]

return censored_chunk


# TODO: Generalize this to other sources
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would suggest just following along with the tests to understand this lol

def censor_salesforce_chunks(
chunks: list[InferenceChunk],
user_email: str,
# This is so we can provide a mock access map for testing
access_map: dict[str, bool] | None = None,
) -> list[InferenceChunk]:
# object_id -> list[((doc_id, chunk_id), (start_index, end_index))]
object_to_content_map: dict[str, list[tuple[ChunkKey, ContentRange]]] = {}

# (doc_id, chunk_id) -> chunk
uncensored_chunks: dict[ChunkKey, InferenceChunk] = {}

# keep track of all object ids that we have seen to make it easier to get
# the access for these object ids
object_ids: set[str] = set()

for chunk in chunks:
chunk_key = (chunk.document_id, chunk.chunk_id)
# create a dictionary to quickly look up the unfiltered chunk
uncensored_chunks[chunk_key] = chunk

# for each chunk, get a dictionary of object ids and the content ranges
# for that object id in the current chunk
object_ranges_for_chunk = _get_object_ranges_for_chunk(chunk)
for object_id, ranges in object_ranges_for_chunk.items():
object_ids.add(object_id)
for start_index, end_index in ranges:
object_to_content_map.setdefault(object_id, []).append(
(chunk_key, (start_index, end_index))
)

# This is so we can provide a mock access map for testing
if access_map is None:
access_map = _get_objects_access_for_user_email_from_salesforce(
object_ids=object_ids,
user_email=user_email,
chunks=chunks,
)
if access_map is None:
# If the user is not found in Salesforce, access_map will be None
# so we should just return an empty list because no chunks will be
# censored
return []

censored_chunks: dict[ChunkKey, InferenceChunk] = {}
for object_id, content_list in object_to_content_map.items():
# if the user does not have access to the object, or the object is not in the
# access_map, do not include its content in the filtered chunks
if not access_map.get(object_id, False):
continue

# if we got this far, the user has access to the object so we can create or update
# the filtered chunk(s) for this object
# NOTE: we only create a censored chunk if the user has access to some
# part of the chunk
for chunk_key, content_range in content_list:
if chunk_key not in censored_chunks:
censored_chunks[chunk_key] = _create_empty_censored_chunk(
uncensored_chunks[chunk_key]
)

uncensored_chunk = uncensored_chunks[chunk_key]
censored_chunk = _update_censored_chunk(
censored_chunk=censored_chunks[chunk_key],
uncensored_chunk=uncensored_chunk,
content_range=content_range,
)
censored_chunks[chunk_key] = censored_chunk

return list(censored_chunks.values())


# NOTE: This is not used anywhere.
def _get_objects_access_for_user_email(
object_ids: set[str], user_email: str
) -> dict[str, bool]:
with get_session_context_manager() as db_session:
external_groups = fetch_external_groups_for_user_email_and_group_ids(
db_session=db_session,
user_email=user_email,
# Maybe make a function that adds a salesforce prefix to the group ids
group_ids=list(object_ids),
)
external_group_ids = {group.external_user_group_id for group in external_groups}
return {group_id: group_id in external_group_ids for group_id in object_ids}
Loading
Loading