Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for OAuth connectors that require user input #3571

Merged
merged 5 commits into from
Jan 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion backend/onyx/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,6 @@
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")

# Egnyte specific configs
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")

Expand Down
45 changes: 33 additions & 12 deletions backend/onyx/connectors/egnyte/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from typing import IO
from urllib.parse import quote

from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
from pydantic import Field

from onyx.configs.app_configs import EGNYTE_CLIENT_ID
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
from onyx.configs.app_configs import INDEX_BATCH_SIZE
Expand Down Expand Up @@ -124,6 +125,15 @@ def _process_egnyte_file(


class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
class AdditionalOauthKwargs(OAuthConnector.AdditionalOauthKwargs):
egnyte_domain: str = Field(
title="Egnyte Domain",
description=(
"The domain for the Egnyte instance "
"(e.g. 'company' for company.egnyte.com)"
),
)

def __init__(
self,
folder_path: str | None = None,
Expand All @@ -139,15 +149,20 @@ def oauth_id(cls) -> DocumentSource:
return DocumentSource.EGNYTE

@classmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
def oauth_authorization_url(
cls,
base_domain: str,
state: str,
additional_kwargs: dict[str, str],
) -> str:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")

oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)

callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
return (
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
f"?client_id={EGNYTE_CLIENT_ID}"
f"&redirect_uri={callback_uri}"
f"&scope=Egnyte.filesystem"
Expand All @@ -156,17 +171,23 @@ def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
)

@classmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
def oauth_code_to_token(
cls,
base_domain: str,
code: str,
additional_kwargs: dict[str, str],
) -> dict[str, Any]:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_CLIENT_SECRET:
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")

oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)

# Exchange code for token
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
url = f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")

data = {
"client_id": EGNYTE_CLIENT_ID,
"client_secret": EGNYTE_CLIENT_SECRET,
Expand All @@ -191,7 +212,7 @@ def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:

token_data = response.json()
return {
"domain": EGNYTE_BASE_DOMAIN,
"domain": oauth_kwargs.egnyte_domain,
"access_token": token_data["access_token"],
}

Expand All @@ -215,7 +236,7 @@ def _get_files_list(
"list_content": True,
}

url_encoded_path = quote(path or "", safe="")
url_encoded_path = quote(path or "")
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
response = request_with_retries(
method="GET", url=url, headers=headers, params=params
Expand Down Expand Up @@ -271,7 +292,7 @@ def _process_files(
headers = {
"Authorization": f"Bearer {self.access_token}",
}
url_encoded_path = quote(file["path"], safe="")
url_encoded_path = quote(file["path"])
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
response = request_with_retries(
method="GET",
Expand Down
20 changes: 18 additions & 2 deletions backend/onyx/connectors/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from collections.abc import Iterator
from typing import Any

from pydantic import BaseModel

from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
Expand Down Expand Up @@ -66,19 +68,33 @@ def retrieve_all_slim_documents(


class OAuthConnector(BaseConnector):
class AdditionalOauthKwargs(BaseModel):
# if overridden, all fields should be str type
pass

@classmethod
@abc.abstractmethod
def oauth_id(cls) -> DocumentSource:
raise NotImplementedError

@classmethod
@abc.abstractmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
def oauth_authorization_url(
cls,
base_domain: str,
state: str,
additional_kwargs: dict[str, str],
) -> str:
raise NotImplementedError

@classmethod
@abc.abstractmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
def oauth_code_to_token(
cls,
base_domain: str,
code: str,
additional_kwargs: dict[str, str],
) -> dict[str, Any]:
raise NotImplementedError


Expand Down
8 changes: 6 additions & 2 deletions backend/onyx/connectors/linear/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def oauth_id(cls) -> DocumentSource:
return DocumentSource.LINEAR

@classmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
def oauth_authorization_url(
cls, base_domain: str, state: str, additional_kwargs: dict[str, str]
) -> str:
if not LINEAR_CLIENT_ID:
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")

Expand All @@ -92,7 +94,9 @@ def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
)

@classmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
def oauth_code_to_token(
cls, base_domain: str, code: str, additional_kwargs: dict[str, str]
) -> dict[str, Any]:
data = {
"code": code,
"redirect_uri": get_oauth_callback_uri(
Expand Down
107 changes: 98 additions & 9 deletions backend/onyx/server/documents/standard_oauth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import uuid
from typing import Annotated
from typing import cast
Expand All @@ -6,7 +7,9 @@
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session

from onyx.auth.users import current_user
Expand All @@ -28,6 +31,8 @@

_OAUTH_STATE_KEY_FMT = "oauth_state:{state}"
_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes
_DESIRED_RETURN_URL_KEY = "desired_return_url"
_ADDITIONAL_KWARGS_KEY = "additional_kwargs"

# Cache for OAuth connectors, populated at module load time
_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {}
Expand All @@ -51,12 +56,36 @@ def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
_discover_oauth_connectors()


def _get_additional_kwargs(
request: Request, connector_cls: type[OAuthConnector], args_to_ignore: list[str]
) -> dict[str, str]:
# get additional kwargs from request
# e.g. anything except for desired_return_url
additional_kwargs_dict = {
k: v for k, v in request.query_params.items() if k not in args_to_ignore
}
try:
# validate
connector_cls.AdditionalOauthKwargs(**additional_kwargs_dict)
except ValidationError:
raise HTTPException(
status_code=400,
detail=(
f"Invalid additional kwargs. Got {additional_kwargs_dict}, expected "
f"{connector_cls.AdditionalOauthKwargs.model_json_schema()}"
),
)

return additional_kwargs_dict


class AuthorizeResponse(BaseModel):
redirect_url: str


@router.get("/authorize/{source}")
def oauth_authorize(
request: Request,
source: DocumentSource,
desired_return_url: Annotated[str | None, Query()] = None,
_: User = Depends(current_user),
Expand All @@ -71,19 +100,32 @@ def oauth_authorize(
connector_cls = oauth_connectors[source]
base_url = WEB_DOMAIN

# get additional kwargs from request
# e.g. anything except for desired_return_url
additional_kwargs = _get_additional_kwargs(
request, connector_cls, ["desired_return_url"]
)

# store state in redis
if not desired_return_url:
desired_return_url = f"{base_url}/admin/connectors/{source}?step=0"
redis_client = get_redis_client(tenant_id=tenant_id)
state = str(uuid.uuid4())
redis_client.set(
_OAUTH_STATE_KEY_FMT.format(state=state),
desired_return_url,
json.dumps(
{
_DESIRED_RETURN_URL_KEY: desired_return_url,
_ADDITIONAL_KWARGS_KEY: additional_kwargs,
}
),
ex=_OAUTH_STATE_EXPIRATION_SECONDS,
)

return AuthorizeResponse(
redirect_url=connector_cls.oauth_authorization_url(base_url, state)
redirect_url=connector_cls.oauth_authorization_url(
base_url, state, additional_kwargs
)
)


Expand All @@ -110,15 +152,18 @@ def oauth_callback(

# get state from redis
redis_client = get_redis_client(tenant_id=tenant_id)
original_url_bytes = cast(
oauth_state_bytes = cast(
bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
)
if not original_url_bytes:
if not oauth_state_bytes:
raise HTTPException(status_code=400, detail="Invalid OAuth state")
original_url = original_url_bytes.decode("utf-8")
oauth_state = json.loads(oauth_state_bytes.decode("utf-8"))

desired_return_url = cast(str, oauth_state[_DESIRED_RETURN_URL_KEY])
additional_kwargs = cast(dict[str, str], oauth_state[_ADDITIONAL_KWARGS_KEY])

base_url = WEB_DOMAIN
token_info = connector_cls.oauth_code_to_token(base_url, code)
token_info = connector_cls.oauth_code_to_token(base_url, code, additional_kwargs)

# Create a new credential with the token info
credential_data = CredentialBase(
Expand All @@ -136,8 +181,52 @@ def oauth_callback(

return CallbackResponse(
redirect_url=(
f"{original_url}?credentialId={credential.id}"
if "?" not in original_url
else f"{original_url}&credentialId={credential.id}"
f"{desired_return_url}?credentialId={credential.id}"
if "?" not in desired_return_url
else f"{desired_return_url}&credentialId={credential.id}"
)
)


class OAuthAdditionalKwargDescription(BaseModel):
name: str
display_name: str
description: str


class OAuthDetails(BaseModel):
oauth_enabled: bool
additional_kwargs: list[OAuthAdditionalKwargDescription]


@router.get("/details/{source}")
def oauth_details(
source: DocumentSource,
_: User = Depends(current_user),
) -> OAuthDetails:
oauth_connectors = _discover_oauth_connectors()

if source not in oauth_connectors:
return OAuthDetails(
oauth_enabled=False,
additional_kwargs=[],
)

connector_cls = oauth_connectors[source]

additional_kwarg_descriptions = []
for key, value in connector_cls.AdditionalOauthKwargs.model_json_schema()[
"properties"
].items():
additional_kwarg_descriptions.append(
OAuthAdditionalKwargDescription(
name=key,
display_name=value.get("title", key),
description=value.get("description", ""),
)
)

return OAuthDetails(
oauth_enabled=True,
additional_kwargs=additional_kwarg_descriptions,
)
1 change: 0 additions & 1 deletion deployment/docker_compose/docker-compose.dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ services:
# Egnyte OAuth Configs
- EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
- EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
- EGNYTE_BASE_DOMAIN=${EGNYTE_BASE_DOMAIN:-}
- EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
# Celery Configs (defaults are set in the supervisord.conf file.
# prefer doing that to have one source of defaults)
Expand Down
Loading
Loading