diff --git a/parsons/redash/redash.py b/parsons/redash/redash.py index 47167319dd..d68642a781 100644 --- a/parsons/redash/redash.py +++ b/parsons/redash/redash.py @@ -52,6 +52,12 @@ def __init__( if user_api_key: self.session.headers.update({"Authorization": f"Key {user_api_key}"}) + def _catch_runtime_error(self, res): + if res.status_code != 200: + raise RuntimeError( + f"Error. Status code: {res.status_code}. Reason: {res.reason}" + ) + def _poll_job(self, session, job, query_id): start_secs = time.time() while job["status"] not in (3, 4): @@ -84,6 +90,63 @@ def _poll_job(self, session, job, query_id): "Redash Query {} failed: {}".format(query_id, job["error"]) ) + def get_data_source(self, data_source_id): + """ + Get a data source. + + `Args:` + data_source_id: int or str + ID of data source. + `Returns`: + Data source json object + """ + res = self.session.get(f"{self.base_url}/api/data_sources/{data_source_id}") + self._catch_runtime_error(res) + return res.json() + + def update_data_source( + self, data_source_id, name, type, dbName, host, password, port, user + ): + """ + Update a data source. + + `Args:` + data_source_id: str or int + ID of data source. + name: str + Name of data source. + type: str + Type of data source. + dbname: str + Database name of data source. + host: str + Host of data source. + password: str + Password of data source. + port: int or str + Port of data source. + user: str + Username of data source. + `Returns:` + ``None`` + """ + self._catch_runtime_error( + self.session.post( + f"{self.base_url}/api/data_sources/{data_source_id}", + json={ + "name": name, + "type": type, + "options": { + "dbname": dbName, + "host": host, + "password": password, + "port": port, + "user": user, + }, + }, + ) + ) + def get_fresh_query_results(self, query_id=None, params=None): """ Make a fresh query result and get back the CSV http response object back diff --git a/test/test_redash.py b/test/test_redash.py index d1aed8870c..a81a23882f 100644 --- a/test/test_redash.py +++ b/test/test_redash.py @@ -13,11 +13,43 @@ class TestRedash(unittest.TestCase): mock_data = "foo,bar\n1,2\n3,4" + mock_data_source = { + "id": 1, + "name": "Data Source 1", + "type": "redshift", + "options": { + "dbname": "db_name", + "host": "host.example.com", + "password": "--------", + "port": 5439, + "user": "username", + }, + } mock_result = Table([("foo", "bar"), ("1", "2"), ("3", "4")]) def setUp(self): self.redash = Redash(BASE_URL, API_KEY) + @requests_mock.Mocker() + def test_get_data_source(self, m): + m.get(f"{BASE_URL}/api/data_sources/1", json=self.mock_data_source) + assert self.redash.get_data_source(1) == self.mock_data_source + + @requests_mock.Mocker() + def test_update_data_source(self, m): + m.post(f"{BASE_URL}/api/data_sources/1", json=self.mock_data_source) + self.redash.update_data_source( + 1, + "Data Source 1", + "redshift", + "db_name", + "host.example.com", + "password", + 5439, + "username", + ) + assert m.call_count == 1 + @requests_mock.Mocker() def test_cached_query(self, m): redash = Redash(BASE_URL) # no user_api_key