Skip to content

Commit

Permalink
Add json as allowable file type to copy_s3 (#844)
Browse files Browse the repository at this point in the history
* Add json as allowable file type to copy_s3

* Add json to copy statement

* debug statement

* Trying to see what the table looks like

* debug

* More debugging

* debug

* Line delimited

* new line

* missing quotes

* debug

* Remove debug statements

* hiding creds for now

* Print redacted query

* debug statement

* debug

* debug

* redacted isn't working

* debug

* add delimiter to json

* new line

* single quotes

* remove debug

* My changes

* logs

* log

* more logs

* more logs

* more logs

* Removing logs

* Removing more logs
  • Loading branch information
KasiaHinkson authored Sep 7, 2023
1 parent 5d64a68 commit fce51b0
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.7
FROM --platform=linux/amd64 python:3.7

####################
## Selenium setup ##
Expand Down
4 changes: 4 additions & 0 deletions parsons/databases/redshift/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def copy_s3(
bucket_region=None,
strict_length=True,
template_table=None,
line_delimited=False,
):
"""
Copy a file from s3 to Redshift.
Expand Down Expand Up @@ -414,6 +415,8 @@ def copy_s3(
local_path = s3.get_file(bucket, key)
if data_type == "csv":
tbl = Table.from_csv(local_path, delimiter=csv_delimiter)
elif data_type == "json":
tbl = Table.from_json(local_path, line_delimited=line_delimited)
else:
raise TypeError("Invalid data type provided")

Expand All @@ -433,6 +436,7 @@ def copy_s3(
logger.info(f"{table_name} created.")

# Copy the table
logger.info(f"Data type is {data_type}")
copy_sql = self.copy_statement(
table_name,
bucket,
Expand Down
12 changes: 4 additions & 8 deletions parsons/databases/redshift/rs_copy_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@


class RedshiftCopyTable(object):

aws_access_key_id = None
aws_secret_access_key = None
iam_role = None
Expand Down Expand Up @@ -42,8 +41,9 @@ def copy_statement(
aws_secret_access_key=None,
compression=None,
bucket_region=None,
json_option="auto",
):

logger.info(f"Data type is {data_type}")
# Source / Destination
source = f"s3://{bucket}/{key}"

Expand Down Expand Up @@ -101,6 +101,8 @@ def copy_statement(
# Data Type
if data_type == "csv":
sql += f"csv delimiter '{csv_delimiter}' \n"
elif data_type == "json":
sql += f"json '{json_option}' \n"
else:
raise TypeError("Invalid data type specified.")

Expand All @@ -112,7 +114,6 @@ def copy_statement(
return sql

def get_creds(self, aws_access_key_id, aws_secret_access_key):

if aws_access_key_id and aws_secret_access_key:
# When we have credentials, then we don't need to set them again
pass
Expand All @@ -122,19 +123,16 @@ def get_creds(self, aws_access_key_id, aws_secret_access_key):
return f"credentials 'aws_iam_role={self.iam_role}'\n"

elif self.aws_access_key_id and self.aws_secret_access_key:

aws_access_key_id = self.aws_access_key_id
aws_secret_access_key = self.aws_secret_access_key

elif (
"AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ
):

aws_access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
aws_secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]

else:

s3 = S3(use_env_token=self.use_env_token)
creds = s3.aws.session.get_credentials()
aws_access_key_id = creds.access_key
Expand All @@ -151,7 +149,6 @@ def temp_s3_copy(
aws_secret_access_key=None,
csv_encoding="utf-8",
):

if not self.s3_temp_bucket:
raise KeyError(
(
Expand Down Expand Up @@ -184,6 +181,5 @@ def temp_s3_copy(
return key

def temp_s3_delete(self, key):

if key:
self.s3.remove_file(self.s3_temp_bucket, key)
4 changes: 0 additions & 4 deletions parsons/databases/redshift/rs_create_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def create_statement(
columntypes=None,
strict_length=True,
):

# Warn the user if they don't provide a DIST key or a SORT key
self._log_key_warning(distkey=distkey, sortkey=sortkey, method="copy")

Expand Down Expand Up @@ -144,7 +143,6 @@ def vc_max(self, mapping, columns):
# Set the varchar width of a column to the maximum

for c in columns:

try:
idx = mapping["headers"].index(c)
mapping["longest"][idx] = self.VARCHAR_MAX
Expand All @@ -156,13 +154,11 @@ def vc_max(self, mapping, columns):
return mapping["longest"]

def vc_trunc(self, mapping):

return [
self.VARCHAR_MAX if c > self.VARCHAR_MAX else c for c in mapping["longest"]
]

def vc_validate(self, mapping):

return [1 if c == 0 else c for c in mapping["longest"]]

def create_sql(self, table_name, mapping, distkey=None, sortkey=None):
Expand Down

0 comments on commit fce51b0

Please sign in to comment.