Skip to content

Commit

Permalink
🐛 Source Salesforce: have clear error when stream preparation fails (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
maxi297 authored Apr 15, 2024
1 parent e902607 commit 36749fb
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ data:
connectorSubtype: api
connectorType: source
definitionId: b117307c-14b6-41aa-9422-947e34922962
dockerImageTag: 2.5.1
dockerImageTag: 2.5.2
dockerRepository: airbyte/source-salesforce
documentationUrl: https://docs.airbyte.com/integrations/sources/salesforce
githubIssueLabel: source-salesforce
Expand Down
6 changes: 3 additions & 3 deletions airbyte-integrations/connectors/source-salesforce/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = [ "poetry-core>=1.0.0",]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
version = "2.5.1"
version = "2.5.2"
name = "source-salesforce"
description = "Source implementation for Salesforce."
authors = [ "Airbyte <contact@airbyte.io>",]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import requests # type: ignore[import]
from airbyte_cdk.models import ConfiguredAirbyteCatalog
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_protocol.models import FailureType
from airbyte_protocol.models import FailureType, StreamDescriptor
from requests import adapters as request_adapters
from requests.exceptions import HTTPError, RequestException # type: ignore[import]

Expand Down Expand Up @@ -374,7 +374,14 @@ def load_schema(name: str, stream_options: Mapping[str, Any]) -> Tuple[str, Opti
):
if err:
self.logger.error(f"Loading error of the {stream_name} schema: {err}")
continue
# Without schema information, the source can't determine the type of stream to instantiate and there might be issues
# related to property chunking
raise AirbyteTracedException(
message=f"Schema could not be extracted for stream {stream_name}. Please retry later.",
internal_message=str(err),
failure_type=FailureType.system_error,
stream_descriptor=StreamDescriptor(name=stream_name),
)
stream_schemas[stream_name] = schema
return stream_schemas

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,19 @@
import json
import urllib.parse
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Optional
from unittest import TestCase

import freezegun
from airbyte_cdk.sources.source import TState
from airbyte_cdk.test.catalog_builder import CatalogBuilder
from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput, read
from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse
from airbyte_cdk.test.mock_http.request import ANY_QUERY_PARAMS
from airbyte_cdk.test.state_builder import StateBuilder
from airbyte_protocol.models import ConfiguredAirbyteCatalog, SyncMode
from airbyte_protocol.models import SyncMode
from config_builder import ConfigBuilder
from integration.utils import given_stream
from integration.utils import create_base_url, given_authentication, given_stream, read
from salesforce_describe_response_builder import SalesforceDescribeResponseBuilder
from source_salesforce import SourceSalesforce
from source_salesforce.streams import LOOKBACK_SECONDS

_A_FIELD_NAME = "a_field"
_ACCESS_TOKEN = "an_access_token"
_API_VERSION = "v57.0"
_CLIENT_ID = "a_client_id"
_CLIENT_SECRET = "a_client_secret"
_CURSOR_FIELD = "SystemModstamp"
Expand All @@ -33,38 +26,7 @@
_REFRESH_TOKEN = "a_refresh_token"
_STREAM_NAME = "a_stream_name"

_BASE_URL = f"{_INSTANCE_URL}/services/data/{_API_VERSION}"


def _catalog(sync_mode: SyncMode) -> ConfiguredAirbyteCatalog:
return CatalogBuilder().with_stream(_STREAM_NAME, sync_mode).build()


def _source(catalog: ConfiguredAirbyteCatalog, config: Dict[str, Any], state: Optional[TState]) -> SourceSalesforce:
return SourceSalesforce(catalog, config, state)


def _read(
sync_mode: SyncMode,
config_builder: Optional[ConfigBuilder] = None,
state_builder: Optional[StateBuilder] = None,
expecting_exception: bool = False
) -> EntrypointOutput:
catalog = _catalog(sync_mode)
config = config_builder.build() if config_builder else ConfigBuilder().build()
state = state_builder.build() if state_builder else StateBuilder().build()
return read(_source(catalog, config, state), config, catalog, state, expecting_exception)


def _given_authentication(http_mocker: HttpMocker, client_id: str, client_secret: str, refresh_token: str) -> None:
http_mocker.post(
HttpRequest(
"https://login.salesforce.com/services/oauth2/token",
query_params=ANY_QUERY_PARAMS,
body=f"grant_type=refresh_token&client_id={client_id}&client_secret={client_secret}&refresh_token={refresh_token}"
),
HttpResponse(json.dumps({"access_token": _ACCESS_TOKEN, "instance_url": _INSTANCE_URL})),
)
_BASE_URL = create_base_url(_INSTANCE_URL)


def _create_field(name: str, _type: Optional[str] = None) -> Dict[str, Any]:
Expand Down Expand Up @@ -93,7 +55,7 @@ def setUp(self) -> None:

@HttpMocker()
def test_when_read_then_create_job_and_extract_records_from_result(self, http_mocker: HttpMocker) -> None:
_given_authentication(http_mocker, _CLIENT_ID, _CLIENT_SECRET, _REFRESH_TOKEN)
given_authentication(http_mocker, _CLIENT_ID, _CLIENT_SECRET, _REFRESH_TOKEN, _INSTANCE_URL)
given_stream(http_mocker, _BASE_URL, _STREAM_NAME, SalesforceDescribeResponseBuilder().field(_A_FIELD_NAME))
http_mocker.post(
HttpRequest(f"{_BASE_URL}/jobs/query", body=json.dumps({"operation": "queryAll", "query": "SELECT a_field FROM a_stream_name", "contentType": "CSV", "columnDelimiter": "COMMA", "lineEnding": "LF"})),
Expand All @@ -108,6 +70,6 @@ def test_when_read_then_create_job_and_extract_records_from_result(self, http_mo
HttpResponse(f"{_A_FIELD_NAME}\nfield_value"),
)

output = _read(SyncMode.full_refresh, self._config)
output = read(_STREAM_NAME, SyncMode.full_refresh, self._config)

assert len(output.records) == 1
Original file line number Diff line number Diff line change
Expand Up @@ -3,69 +3,31 @@
import json
import urllib.parse
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Optional
from unittest import TestCase

import freezegun
from airbyte_cdk.sources.source import TState
from airbyte_cdk.test.catalog_builder import CatalogBuilder
from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput, read
from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse
from airbyte_cdk.test.mock_http.request import ANY_QUERY_PARAMS
from airbyte_cdk.test.state_builder import StateBuilder
from airbyte_protocol.models import ConfiguredAirbyteCatalog, SyncMode
from airbyte_protocol.models import SyncMode
from config_builder import ConfigBuilder
from integration.utils import given_stream
from integration.utils import create_base_url, given_authentication, given_stream, read
from salesforce_describe_response_builder import SalesforceDescribeResponseBuilder
from source_salesforce import SourceSalesforce
from source_salesforce.api import UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS
from source_salesforce.streams import LOOKBACK_SECONDS

_A_FIELD_NAME = "a_field"
_ACCESS_TOKEN = "an_access_token"
_API_VERSION = "v57.0"
_CLIENT_ID = "a_client_id"
_CLIENT_SECRET = "a_client_secret"
_CURSOR_FIELD = "SystemModstamp"
_INSTANCE_URL = "https://instance.salesforce.com"
_BASE_URL = f"{_INSTANCE_URL}/services/data/{_API_VERSION}"
_BASE_URL = create_base_url(_INSTANCE_URL)
_LOOKBACK_WINDOW = timedelta(seconds=LOOKBACK_SECONDS)
_NOW = datetime.now(timezone.utc)
_REFRESH_TOKEN = "a_refresh_token"
_STREAM_NAME = UNSUPPORTED_BULK_API_SALESFORCE_OBJECTS[0]


def _catalog(sync_mode: SyncMode) -> ConfiguredAirbyteCatalog:
return CatalogBuilder().with_stream(_STREAM_NAME, sync_mode).build()


def _source(catalog: ConfiguredAirbyteCatalog, config: Dict[str, Any], state: Optional[TState]) -> SourceSalesforce:
return SourceSalesforce(catalog, config, state)


def _read(
sync_mode: SyncMode,
config_builder: Optional[ConfigBuilder] = None,
state_builder: Optional[StateBuilder] = None,
expecting_exception: bool = False
) -> EntrypointOutput:
catalog = _catalog(sync_mode)
config = config_builder.build() if config_builder else ConfigBuilder().build()
state = state_builder.build() if state_builder else StateBuilder().build()
return read(_source(catalog, config, state), config, catalog, state, expecting_exception)


def _given_authentication(http_mocker: HttpMocker, client_id: str, client_secret: str, refresh_token: str) -> None:
http_mocker.post(
HttpRequest(
"https://login.salesforce.com/services/oauth2/token",
query_params=ANY_QUERY_PARAMS,
body=f"grant_type=refresh_token&client_id={client_id}&client_secret={client_secret}&refresh_token={refresh_token}"
),
HttpResponse(json.dumps({"access_token": _ACCESS_TOKEN, "instance_url": _INSTANCE_URL})),
)


def _create_field(name: str, _type: Optional[str] = None) -> Dict[str, Any]:
return {"name": name, "type": _type if _type else "string"}

Expand All @@ -92,17 +54,17 @@ def setUp(self) -> None:

@HttpMocker()
def test_given_error_on_fetch_chunk_when_read_then_retry(self, http_mocker: HttpMocker) -> None:
_given_authentication(http_mocker, _CLIENT_ID, _CLIENT_SECRET, _REFRESH_TOKEN)
given_authentication(http_mocker, _CLIENT_ID, _CLIENT_SECRET, _REFRESH_TOKEN, _INSTANCE_URL)
given_stream(http_mocker, _BASE_URL, _STREAM_NAME, SalesforceDescribeResponseBuilder().field(_A_FIELD_NAME))
http_mocker.get(
HttpRequest(f"{_INSTANCE_URL}/services/data/{_API_VERSION}/queryAll?q=SELECT+{_A_FIELD_NAME}+FROM+{_STREAM_NAME}+"),
HttpRequest(f"{_BASE_URL}/queryAll?q=SELECT+{_A_FIELD_NAME}+FROM+{_STREAM_NAME}+"),
[
HttpResponse("", status_code=406),
HttpResponse(json.dumps({"records": [{"a_field": "a_value"}]})),
]
)

output = _read(SyncMode.full_refresh, self._config)
output = read(_STREAM_NAME, SyncMode.full_refresh, self._config)

assert len(output.records) == 1

Expand All @@ -115,7 +77,7 @@ def setUp(self) -> None:
self._http_mocker = HttpMocker()
self._http_mocker.__enter__()

_given_authentication(self._http_mocker, _CLIENT_ID, _CLIENT_SECRET, _REFRESH_TOKEN)
given_authentication(self._http_mocker, _CLIENT_ID, _CLIENT_SECRET, _REFRESH_TOKEN, _INSTANCE_URL)
given_stream(self._http_mocker, _BASE_URL, _STREAM_NAME, SalesforceDescribeResponseBuilder().field(_A_FIELD_NAME).field(_CURSOR_FIELD, "datetime"))

def tearDown(self) -> None:
Expand All @@ -128,11 +90,11 @@ def test_given_no_state_when_read_then_start_sync_from_start(self) -> None:
self._config.stream_slice_step("P30D").start_date(start)

self._http_mocker.get(
HttpRequest(f"{_INSTANCE_URL}/services/data/{_API_VERSION}/queryAll?q=SELECT+{_A_FIELD_NAME},{_CURSOR_FIELD}+FROM+{_STREAM_NAME}+WHERE+SystemModstamp+%3E%3D+{start_format_url}+AND+SystemModstamp+%3C+{_to_url(_NOW)}"),
HttpRequest(f"{_BASE_URL}/queryAll?q=SELECT+{_A_FIELD_NAME},{_CURSOR_FIELD}+FROM+{_STREAM_NAME}+WHERE+SystemModstamp+%3E%3D+{start_format_url}+AND+SystemModstamp+%3C+{_to_url(_NOW)}"),
HttpResponse(json.dumps({"records": [{"a_field": "a_value"}]})),
)

_read(SyncMode.incremental, self._config, StateBuilder().with_stream_state(_STREAM_NAME, {}))
read(_STREAM_NAME, SyncMode.incremental, self._config, StateBuilder().with_stream_state(_STREAM_NAME, {}))

# then HTTP requests are performed

Expand All @@ -141,11 +103,11 @@ def test_given_sequential_state_when_read_then_migrate_to_partitioned_state(self
start = _calculate_start_time(_NOW - timedelta(days=10))
self._config.stream_slice_step("P30D").start_date(start)
self._http_mocker.get(
HttpRequest(f"{_INSTANCE_URL}/services/data/{_API_VERSION}/queryAll?q=SELECT+{_A_FIELD_NAME},{_CURSOR_FIELD}+FROM+{_STREAM_NAME}+WHERE+SystemModstamp+%3E%3D+{_to_url(cursor_value - _LOOKBACK_WINDOW)}+AND+SystemModstamp+%3C+{_to_url(_NOW)}"),
HttpRequest(f"{_BASE_URL}/queryAll?q=SELECT+{_A_FIELD_NAME},{_CURSOR_FIELD}+FROM+{_STREAM_NAME}+WHERE+SystemModstamp+%3E%3D+{_to_url(cursor_value - _LOOKBACK_WINDOW)}+AND+SystemModstamp+%3C+{_to_url(_NOW)}"),
HttpResponse(json.dumps({"records": [{"a_field": "a_value"}]})),
)

output = _read(SyncMode.incremental, self._config, StateBuilder().with_stream_state(_STREAM_NAME, {_CURSOR_FIELD: cursor_value.isoformat(timespec="milliseconds")}))
output = read(_STREAM_NAME, SyncMode.incremental, self._config, StateBuilder().with_stream_state(_STREAM_NAME, {_CURSOR_FIELD: cursor_value.isoformat(timespec="milliseconds")}))

assert output.most_recent_state.stream_state.dict() == {"state_type": "date-range", "slices": [{"start": _to_partitioned_datetime(start), "end": _to_partitioned_datetime(_NOW)}]}

Expand All @@ -166,15 +128,15 @@ def test_given_partitioned_state_when_read_then_sync_missing_partitions_and_upda
self._config.stream_slice_step("P30D").start_date(start)

self._http_mocker.get(
HttpRequest(f"{_INSTANCE_URL}/services/data/{_API_VERSION}/queryAll?q=SELECT+{_A_FIELD_NAME},{_CURSOR_FIELD}+FROM+{_STREAM_NAME}+WHERE+SystemModstamp+%3E%3D+{_to_url(missing_chunk[0])}+AND+SystemModstamp+%3C+{_to_url(missing_chunk[1])}"),
HttpRequest(f"{_BASE_URL}/queryAll?q=SELECT+{_A_FIELD_NAME},{_CURSOR_FIELD}+FROM+{_STREAM_NAME}+WHERE+SystemModstamp+%3E%3D+{_to_url(missing_chunk[0])}+AND+SystemModstamp+%3C+{_to_url(missing_chunk[1])}"),
HttpResponse(json.dumps({"records": [{"a_field": "a_value"}]})),
)
self._http_mocker.get(
HttpRequest(f"{_INSTANCE_URL}/services/data/{_API_VERSION}/queryAll?q=SELECT+{_A_FIELD_NAME},{_CURSOR_FIELD}+FROM+{_STREAM_NAME}+WHERE+SystemModstamp+%3E%3D+{_to_url(most_recent_state_value - _LOOKBACK_WINDOW)}+AND+SystemModstamp+%3C+{_to_url(_NOW)}"),
HttpRequest(f"{_BASE_URL}/queryAll?q=SELECT+{_A_FIELD_NAME},{_CURSOR_FIELD}+FROM+{_STREAM_NAME}+WHERE+SystemModstamp+%3E%3D+{_to_url(most_recent_state_value - _LOOKBACK_WINDOW)}+AND+SystemModstamp+%3C+{_to_url(_NOW)}"),
HttpResponse(json.dumps({"records": [{"a_field": "a_value"}]})),
)

output = _read(SyncMode.incremental, self._config, state)
output = read(_STREAM_NAME, SyncMode.incremental, self._config, state)

# the start is granular to the second hence why we have `000` in terms of milliseconds
assert output.most_recent_state.stream_state.dict() == {"state_type": "date-range", "slices": [{"start": _to_partitioned_datetime(start), "end": _to_partitioned_datetime(_NOW)}]}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

import json
from unittest import TestCase

import pytest
from airbyte_cdk.test.catalog_builder import CatalogBuilder
from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse
from airbyte_cdk.test.state_builder import StateBuilder
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
from airbyte_protocol.models import FailureType, SyncMode
from config_builder import ConfigBuilder
from integration.utils import create_base_url, given_authentication, given_stream
from salesforce_describe_response_builder import SalesforceDescribeResponseBuilder
from source_salesforce import SourceSalesforce

_CLIENT_ID = "a_client_id"
_CLIENT_SECRET = "a_client_secret"
_FIELD_NAME = "a_field_name"
_INSTANCE_URL = "https://instance.salesforce.com"
_REFRESH_TOKEN = "a_refresh_token"
_STREAM_NAME = "StreamName"

_BASE_URL = create_base_url(_INSTANCE_URL)


class StreamGenerationTest(TestCase):

def setUp(self) -> None:
self._config = ConfigBuilder().client_id(_CLIENT_ID).client_secret(_CLIENT_SECRET).refresh_token(_REFRESH_TOKEN).build()
self._source = SourceSalesforce(
CatalogBuilder().with_stream(_STREAM_NAME, SyncMode.full_refresh).build(),
self._config,
StateBuilder().build()
)

self._http_mocker = HttpMocker()
self._http_mocker.__enter__()

def tearDown(self) -> None:
self._http_mocker.__exit__(None, None, None)

def test_given_transient_error_fetching_schema_when_streams_then_retry(self) -> None:
given_authentication(self._http_mocker, _CLIENT_ID, _CLIENT_SECRET, _REFRESH_TOKEN, _INSTANCE_URL)
self._http_mocker.get(
HttpRequest(f"{_BASE_URL}/sobjects"),
HttpResponse(json.dumps({"sobjects": [{"name": _STREAM_NAME, "queryable": True}]})),
)
self._http_mocker.get(
HttpRequest(f"{_BASE_URL}/sobjects/{_STREAM_NAME}/describe"),
[
HttpResponse("", status_code=406),
SalesforceDescribeResponseBuilder().field("a_field_name").build()
]
)

streams = self._source.streams(self._config)

assert len(streams) == 2 # _STREAM_NAME and Describe which is always added
assert _FIELD_NAME in next(filter(lambda stream: stream.name == _STREAM_NAME, streams)).get_json_schema()["properties"]

def test_given_errors_fetching_schema_when_streams_then_raise_exception(self) -> None:
given_authentication(self._http_mocker, _CLIENT_ID, _CLIENT_SECRET, _REFRESH_TOKEN, _INSTANCE_URL)
self._http_mocker.get(
HttpRequest(f"{_BASE_URL}/sobjects"),
HttpResponse(json.dumps({"sobjects": [{"name": _STREAM_NAME, "queryable": True}]})),
)
self._http_mocker.get(
HttpRequest(f"{_BASE_URL}/sobjects/{_STREAM_NAME}/describe"),
HttpResponse("", status_code=406),
)

with pytest.raises(AirbyteTracedException) as exception:
self._source.streams(self._config)

assert exception.value.failure_type == FailureType.system_error
Loading

0 comments on commit 36749fb

Please sign in to comment.