Skip to content

Commit

Permalink
pass the encoding format while generating schema
Browse files Browse the repository at this point in the history
  • Loading branch information
sgandhi1311 committed Oct 25, 2023
1 parent e9a1977 commit a4174d5
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 12 deletions.
2 changes: 1 addition & 1 deletion tap_sftp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def do_discover(config):
if encoding_format != "utf-8":
if not is_valid_encoding(encoding_format):
raise Exception("Unknown Encoding - {}. Enter the valid encoding format".format(encoding_format))
streams = discover_streams(config)
streams = discover_streams(config, encoding_format)
if not streams:
raise Exception("No streams found")
catalog = {"streams": streams}
Expand Down
8 changes: 4 additions & 4 deletions tap_sftp/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@

LOGGER= singer.get_logger()

def discover_streams(config):
def discover_streams(config, encoding_format):
streams = []

conn = client.connection(config)
prefix = format(config.get("user_dir", "./"))

tables = json.loads(config['tables'])
for table_spec in tables:
schema, stream_md = get_schema(conn, table_spec)
schema, stream_md = get_schema(conn, table_spec, encoding_format)

streams.append(
{
Expand All @@ -38,9 +38,9 @@ def discover_streams(config):
interval=10,
jitter=None)
# generate schema
def get_schema(conn, table_spec):
def get_schema(conn, table_spec, encoding_format):
LOGGER.info('Sampling records to determine table JSON schema "%s".', table_spec['table_name'])
schema = json_schema.get_schema_for_table(conn, table_spec)
schema = json_schema.get_schema_for_table(conn, table_spec, encoding_format)
stream_md = metadata.get_standard_metadata(schema,
key_properties=table_spec.get('key_properties'),
replication_method='INCREMENTAL')
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/test_encoding_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_do_discover_valid_encoding(
mock_is_valid_encoding.assert_not_called()
else:
mock_is_valid_encoding.assert_called_with("latin_1")
mock_discover_streams.assert_called_with(config)
mock_discover_streams.assert_called_with(config, encoding_format)
self.assertEqual(captured_output, sys.stdout) # Ensure sys.stdout is restored

@patch("tap_sftp.is_valid_encoding", return_value=False)
Expand Down
7 changes: 4 additions & 3 deletions tests/unittests/test_permission_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tap_sftp.sync as sync
import paramiko

DEFAULT_ENCODING_FORMAT = "utf-8"
@mock.patch("tap_sftp.client.SFTPConnection.sftp")
@mock.patch("tap_sftp.client.LOGGER.warn")
class TestPermissionError(unittest.TestCase):
Expand Down Expand Up @@ -55,7 +56,7 @@ def test_no_error_during_sync(self, mocked_get_row_iterators, mocked_stats, mock

conn = client.SFTPConnection("10.0.0.1", "username", port="22")

rows_synced = sync.sync_file(conn, {"filepath": "/root_dir/file.csv.gz", "last_modified": "2020-01-01"}, None, {"key_properties": ["id"], "delimiter": ","})
rows_synced = sync.sync_file(conn, {"filepath": "/root_dir/file.csv.gz", "last_modified": "2020-01-01"}, None, {"key_properties": ["id"], "delimiter": ","}, encoding_format=DEFAULT_ENCODING_FORMAT)
# Check whether "csv.get_row_iterators" was called: if it was, no error occurred;
# if it was not, an error occurred and the function returned from the except block.
self.assertEquals(1, mocked_get_row_iterators.call_count)
Expand All @@ -68,7 +69,7 @@ def test_permisison_error_during_sync(self, mocked_get_row_iterators, mocked_log

conn = client.SFTPConnection("10.0.0.1", "username", port="22")

rows_synced = sync.sync_file(conn, {"filepath": "/root_dir/file.csv.gz", "last_modified": "2020-01-01"}, None, {"key_properties": ["id"], "delimiter": ","})
rows_synced = sync.sync_file(conn, {"filepath": "/root_dir/file.csv.gz", "last_modified": "2020-01-01"}, None, {"key_properties": ["id"], "delimiter": ","}, encoding_format=DEFAULT_ENCODING_FORMAT)
# Check whether "csv.get_row_iterators" was called: if it was, no error occurred;
# if it was not, an error occurred and the function returned from the except block.
self.assertEquals(0, mocked_get_row_iterators.call_count)
Expand All @@ -82,7 +83,7 @@ def test_oserror_during_sync(self, mocked_get_row_iterators, mocked_logger, mock

conn = client.SFTPConnection("10.0.0.1", "username", port="22")

rows_synced = sync.sync_file(conn, {"filepath": "/root_dir/file.csv.gz", "last_modified": "2020-01-01"}, None, {"key_properties": ["id"], "delimiter": ","})
rows_synced = sync.sync_file(conn, {"filepath": "/root_dir/file.csv.gz", "last_modified": "2020-01-01"}, None, {"key_properties": ["id"], "delimiter": ","}, encoding_format=DEFAULT_ENCODING_FORMAT)
# Check whether "csv.get_row_iterators" was called: if it was, no error occurred;
# if it was not, an error occurred and the function returned from the except block.
self.assertEquals(0, mocked_get_row_iterators.call_count)
Expand Down
7 changes: 4 additions & 3 deletions tests/unittests/test_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,14 @@ def test_timeout_backoff__get_schema(self, mocked_get_schema_for_table, mocked_g
"""
# mock 'get_schema_for_table' and raise 'socket.timeout' error
mocked_get_schema_for_table.side_effect = socket.timeout

encoding_format = "utf-8"
table_spec = {
"table_name": "test"
}
before_time = datetime.now()
with self.assertRaises(socket.timeout):
# function call
discover.get_schema("test_conn", table_spec)
discover.get_schema("test_conn", table_spec, encoding_format)
after_time = datetime.now()

# verify that the tap backs off for 60 seconds
Expand Down Expand Up @@ -187,11 +187,12 @@ def test_timeout_backoff__sync_file(self, mocked_get_row_iterators, mocked_get_f
file = {
"filepath": "/root/file.csv"
}
encoding_format = "utf-8"
# create connection
conn = client.connection(config=config)
with self.assertRaises(socket.timeout):
# function call
sync.sync_file(conn=conn, f=file, stream="test_stream", table_spec=table_spec)
sync.sync_file(conn=conn, f=file, stream="test_stream", table_spec=table_spec, encoding_format=encoding_format)

# verify that the tap backs off 5 times
self.assertEquals(mocked_get_row_iterators.call_count, 5)
Expand Down

0 comments on commit a4174d5

Please sign in to comment.