From 53db610622d6fefff1508fcfb55d43848d877e5e Mon Sep 17 00:00:00 2001 From: P S Solanki Date: Sun, 17 Apr 2022 20:55:51 +0530 Subject: [PATCH] allows disabling enum enforcer on all client methods. (#303) * allowed disabling enum enforcer on all client methods. tests and docs modified as per changes Adds relevant test cases * reduced all lines to under 79 * formatting changes again * removed additional assertions from test cases --- .gitignore | 1 + docs/client.rst | 9 +- tda/auth.py | 69 +++++++++-- tests/auth_test.py | 287 ++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 334 insertions(+), 32 deletions(-) diff --git a/.gitignore b/.gitignore index 0489431..809e08b 100644 --- a/.gitignore +++ b/.gitignore @@ -104,6 +104,7 @@ venv/ ENV/ env.bak/ venv.bak/ +.idea # Spyder project settings .spyderproject diff --git a/docs/client.rst b/docs/client.rst index 73399d1..fbfe37f 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -133,7 +133,14 @@ Creating a New Client +++++++++++++++++++++ 99.9% of users should not create their own clients, and should instead follow -the instructions outlined in :ref:`auth`. For those brave enough to build their +the instructions outlined in :ref:`auth`. + +For users who want to disable the strict enum type checking on http client, +just pass ``enforce_enums=False`` in any of the client creation functions +described in :ref:`auth`. Just note that for most users, it is advised they +stick with the default behavior. + +For those brave enough to build their own, the constructor looks like this: .. automethod:: tda.client.Client.__init__ diff --git a/tda/auth.py b/tda/auth.py index df777f0..b77deea 100644 --- a/tda/auth.py +++ b/tda/auth.py @@ -70,7 +70,7 @@ def _register_token_redactions(token): register_redactions(token) -def client_from_token_file(token_path, api_key, asyncio=False): +def client_from_token_file(token_path, api_key, asyncio=False, enforce_enums=True): ''' Returns a session from an existing token file. The session will perform an auth refresh as needed. It will also update the token on disk whenever @@ -82,16 +82,25 @@ def client_from_token_file(token_path, api_key, asyncio=False): :func:`~tda.auth.easy_client` to create one. :param api_key: Your TD Ameritrade application's API key, also known as the client ID. + :param asyncio: If set to ``True``, this will enable async support allowing + the client to be used in an async environment. Defaults to + ``False`` + :param enforce_enums: Set it to ``False`` to disable the enum checks on ALL + the client methods. Only do it if you know you really + need it. For most users, it is advised to use enums + to avoid errors. ''' load = __token_loader(token_path) return client_from_access_functions( - api_key, load, __update_token(token_path), asyncio=asyncio) + api_key, load, __update_token(token_path), asyncio=asyncio, + enforce_enums=enforce_enums) def __fetch_and_register_token_from_redirect( - oauth, redirected_url, api_key, token_path, token_write_func, asyncio): + oauth, redirected_url, api_key, token_path, token_write_func, asyncio, + enforce_enums=True): token = oauth.fetch_token( TOKEN_ENDPOINT, authorization_response=redirected_url, @@ -132,7 +141,7 @@ async def oauth_client_update_token(t, *args, **kwargs): auto_refresh_url=TOKEN_ENDPOINT, auto_refresh_kwargs={'client_id': api_key}, update_token=oauth_client_update_token), - token_metadata=metadata_manager) + token_metadata=metadata_manager, enforce_enums=enforce_enums) class RedirectTimeoutError(Exception): @@ -263,7 +272,8 @@ def ensure_refresh_token_update( # TODO: Raise an exception when passing both token_path and token_write_func def client_from_login_flow(webdriver, api_key, redirect_url, token_path, redirect_wait_time_seconds=0.1, max_waits=3000, - asyncio=False, token_write_func=None): + asyncio=False, token_write_func=None, + enforce_enums=True): ''' Uses the webdriver to perform an OAuth webapp login flow and creates a client wrapped around the resulting token. The client will be configured to @@ -281,6 +291,14 @@ def client_from_login_flow(webdriver, api_key, redirect_url, token_path, :param token_path: Path to which the new token will be written. If the token file already exists, it will be overwritten with a new one. Updated tokens will be written to this path as well. + + :param asyncio: If set to ``True``, this will enable async support allowing + the client to be used in an async environment. Defaults to + ``False`` + :param enforce_enums: Set it to ``False`` to disable the enum checks on ALL + the client methods. Only do it if you know you really + need it. For most users, it is advised to use enums + to avoid errors. ''' get_logger().info('Creating new token with redirect URL \'%s\' ' + 'and token path \'%s\'', redirect_url, token_path) @@ -328,11 +346,12 @@ def client_from_login_flow(webdriver, api_key, redirect_url, token_path, return __fetch_and_register_token_from_redirect( oauth, current_url, api_key, token_path, token_write_func, - asyncio) + asyncio, enforce_enums=enforce_enums) def client_from_manual_flow(api_key, redirect_url, token_path, - asyncio=False, token_write_func=None): + asyncio=False, token_write_func=None, + enforce_enums=True): ''' Walks the user through performing an OAuth login flow by manually copy-pasting URLs, and returns a client wrapped around the resulting token. @@ -351,6 +370,13 @@ def client_from_manual_flow(api_key, redirect_url, token_path, :param token_path: Path to which the new token will be written. If the token file already exists, it will be overwritten with a new one. Updated tokens will be written to this path as well. + :param asyncio: If set to ``True``, this will enable async support allowing + the client to be used in an async environment. Defaults to + ``False`` + :param enforce_enums: Set it to ``False`` to disable the enum checks on ALL + the client methods. Only do it if you know you really + need it. For most users, it is advised to use enums + to avoid errors. ''' get_logger().info('Creating new token with redirect URL \'%s\' ' + 'and token path \'%s\'', redirect_url, token_path) @@ -397,11 +423,11 @@ def client_from_manual_flow(api_key, redirect_url, token_path, return __fetch_and_register_token_from_redirect( oauth, redirected_url, api_key, token_path, token_write_func, - asyncio) + asyncio, enforce_enums=enforce_enums) def easy_client(api_key, redirect_uri, token_path, webdriver_func=None, - asyncio=False): + asyncio=False, enforce_enums=True): '''Convenient wrapper around :func:`client_from_login_flow` and :func:`client_from_token_file`. If ``token_path`` exists, loads the token from it. Otherwise open a login flow to fetch a new token. Returns a client @@ -426,11 +452,19 @@ def easy_client(api_key, redirect_uri, token_path, webdriver_func=None, :param webdriver_func: Function that returns a webdriver for use in fetching a new token. Will only be called if the token file cannot be found. + :param asyncio: If set to ``True``, this will enable async support allowing + the client to be used in an async environment. Defaults to + ``False`` + :param enforce_enums: Set it to ``False`` to disable the enum checks on ALL + the client methods. Only do it if you know you really + need it. For most users, it is advised to use enums + to avoid errors. ''' logger = get_logger() if os.path.isfile(token_path): - c = client_from_token_file(token_path, api_key, asyncio=asyncio) + c = client_from_token_file(token_path, api_key, asyncio=asyncio, + enforce_enums=enforce_enums) logger.info( 'Returning client loaded from token file \'%s\'', token_path) return c @@ -440,7 +474,8 @@ def easy_client(api_key, redirect_uri, token_path, webdriver_func=None, if webdriver_func is not None: with webdriver_func() as driver: c = client_from_login_flow( - driver, api_key, redirect_uri, token_path, asyncio=asyncio) + driver, api_key, redirect_uri, token_path, asyncio=asyncio, + enforce_enums=enforce_enums) logger.info( 'Returning client fetched using webdriver, writing' + 'token to \'%s\'', token_path) @@ -451,7 +486,8 @@ def easy_client(api_key, redirect_uri, token_path, webdriver_func=None, def client_from_access_functions(api_key, token_read_func, - token_write_func, asyncio=False): + token_write_func, asyncio=False, + enforce_enums=True): ''' Returns a session from an existing token file, using the accessor methods to read and write the token. This is an advanced method for users who do not @@ -478,6 +514,13 @@ def client_from_access_functions(api_key, token_read_func, called whenever the token is updated, such as when it is refreshed. See the above-mentioned example for what parameters this method takes. + :param asyncio: If set to ``True``, this will enable async support allowing + the client to be used in an async environment. Defaults to + ``False`` + :param enforce_enums: Set it to ``False`` to disable the enum checks on ALL + the client methods. Only do it if you know you really + need it. For most users, it is advised to use enums + to avoid errors. ''' token = token_read_func() @@ -510,4 +553,4 @@ async def oauth_client_update_token(t, *args, **kwargs): token=token, token_endpoint=TOKEN_ENDPOINT, update_token=oauth_client_update_token), - token_metadata=metadata) + token_metadata=metadata, enforce_enums=enforce_enums) diff --git a/tests/auth_test.py b/tests/auth_test.py index 66460ae..093d67b 100644 --- a/tests/auth_test.py +++ b/tests/auth_test.py @@ -49,7 +49,8 @@ def test_pickle_loads(self, async_session, sync_session, client): self.assertEqual('returned client', auth.client_from_token_file(self.pickle_path, API_KEY)) - client.assert_called_once_with(API_KEY, _, token_metadata=_) + client.assert_called_once_with(API_KEY, _, token_metadata=_, + enforce_enums=_) sync_session.assert_called_once_with( API_KEY, token=self.token, @@ -67,7 +68,8 @@ def test_json_loads(self, async_session, sync_session, client): self.assertEqual('returned client', auth.client_from_token_file(self.json_path, API_KEY)) - client.assert_called_once_with(API_KEY, _, token_metadata=_) + client.assert_called_once_with(API_KEY, _, token_metadata=_, + enforce_enums=_) sync_session.assert_called_once_with( API_KEY, token=self.token, @@ -96,7 +98,6 @@ def test_update_token_updates_token( 'token': updated_token, }) - @no_duplicates @patch('tda.auth.Client') @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) @@ -109,17 +110,45 @@ def test_api_key_is_normalized(self, async_session, sync_session, client): self.assertEqual('returned client', auth.client_from_token_file(self.json_path, 'API_KEY')) client.assert_called_once_with( - 'API_KEY@AMER.OAUTHAP', _, token_metadata=_) + 'API_KEY@AMER.OAUTHAP', _, token_metadata=_, enforce_enums=_) sync_session.assert_called_once_with( 'API_KEY@AMER.OAUTHAP', token=self.token, token_endpoint=_, update_token=_) + @no_duplicates + @patch('tda.auth.Client') + @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) + @patch('tda.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) + def test_enforce_enums_being_disabled(self, async_session, sync_session, client): + self.write_token() -class ClientFromAccessFunctionsTest(unittest.TestCase): + client.return_value = 'returned client' + + self.assertEqual('returned client', + auth.client_from_token_file(self.json_path, API_KEY, + enforce_enums=False)) + client.assert_called_once_with(API_KEY, _, token_metadata=_, + enforce_enums=False) + + @no_duplicates + @patch('tda.auth.Client') + @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) + @patch('tda.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) + def test_enforce_enums_being_enabled(self, async_session, sync_session, client): + self.write_token() + + client.return_value = 'returned client' + + self.assertEqual('returned client', + auth.client_from_token_file(self.json_path, API_KEY)) + client.assert_called_once_with(API_KEY, _, token_metadata=_, + enforce_enums=True) +class ClientFromAccessFunctionsTest(unittest.TestCase): + @no_duplicates @patch('tda.auth.Client') @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) @@ -132,6 +161,7 @@ def test_success_with_write_func_legacy_token( token_read_func.return_value = token token_writes = [] + def token_write_func(token): token_writes.append(token) @@ -159,7 +189,6 @@ def token_write_func(token): 'token': token, }], token_writes) - @no_duplicates @patch('tda.auth.Client') @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) @@ -175,6 +204,7 @@ def test_success_with_write_func_metadata_aware_token( token_read_func.return_value = token token_writes = [] + def token_write_func(token): token_writes.append(token) @@ -199,6 +229,57 @@ def token_write_func(token): update_token(token['token']) self.assertEqual([token], token_writes) + @no_duplicates + @patch('tda.auth.Client') + @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) + @patch('tda.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) + def test_success_with_enforce_enums_disabled( + self, async_session, sync_session, client): + token = {'token': 'yes'} + + token_read_func = MagicMock() + token_read_func.return_value = token + + token_writes = [] + + def token_write_func(token): + token_writes.append(token) + + client.return_value = 'returned client' + self.assertEqual('returned client', + auth.client_from_access_functions( + 'API_KEY@AMER.OAUTHAP', + token_read_func, + token_write_func, enforce_enums=False)) + + client.assert_called_once_with('API_KEY@AMER.OAUTHAP', _, token_metadata=_, + enforce_enums=False) + + @no_duplicates + @patch('tda.auth.Client') + @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) + @patch('tda.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) + def test_success_with_enforce_enums_enabled( + self, async_session, sync_session, client): + token = {'token': 'yes'} + + token_read_func = MagicMock() + token_read_func.return_value = token + + token_writes = [] + + def token_write_func(token): + token_writes.append(token) + + client.return_value = 'returned client' + self.assertEqual('returned client', + auth.client_from_access_functions( + 'API_KEY@AMER.OAUTHAP', + token_read_func, + token_write_func)) + + client.assert_called_once_with('API_KEY@AMER.OAUTHAP', _, token_metadata=_, + enforce_enums=True) REDIRECT_URL = 'https://redirect.url.com' @@ -211,7 +292,6 @@ def setUp(self): self.json_path = os.path.join(self.tmp_dir.name, 'token.json') self.token = {'token': 'yes'} - @no_duplicates @patch('tda.auth.Client') @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) @@ -335,7 +415,6 @@ def test_normalize_api_key(self, async_session, sync_session, client): 'API_KEY@AMER.OAUTHAP', sync_session.call_args[0][0]) - @no_duplicates @patch('tda.auth.Client') @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) @@ -359,7 +438,6 @@ def test_unexpected_redirect_url(self, async_session, sync_session, client): self.json_path, redirect_wait_time_seconds=0.0) - @no_duplicates @patch('tda.auth.Client') @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) @@ -394,7 +472,6 @@ def test_default_token_write_func( 'token': self.token }, json.load(f)) - @no_duplicates @patch('tda.auth.Client') @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) @@ -413,6 +490,7 @@ def test_custom_token_write_func(self, async_session, sync_session, client): client.return_value = 'returned client' token_writes = [] + def dummy_token_write_func(token): token_writes.append(token) @@ -432,6 +510,61 @@ def dummy_token_write_func(token): 'token': self.token }], token_writes) + @no_duplicates + @patch('tda.auth.Client') + @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) + @patch('tda.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) + @patch('time.time', unittest.mock.MagicMock(return_value=MOCK_NOW)) + def test_enforce_enums_disabled( + self, async_session, sync_session, client): + AUTH_URL = 'https://auth.url.com' + + sync_session.return_value = sync_session + sync_session.create_authorization_url.return_value = AUTH_URL, None + sync_session.fetch_token.return_value = self.token + + webdriver = MagicMock() + webdriver.current_url = REDIRECT_URL + '/token_params' + + client.return_value = 'returned client' + + self.assertEqual('returned client', + auth.client_from_login_flow( + webdriver, API_KEY, REDIRECT_URL, + self.json_path, + redirect_wait_time_seconds=0.0, + enforce_enums=False)) + + client.assert_called_once_with(API_KEY, _, token_metadata=_, + enforce_enums=False) + + @no_duplicates + @patch('tda.auth.Client') + @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) + @patch('tda.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) + @patch('time.time', unittest.mock.MagicMock(return_value=MOCK_NOW)) + def test_enforce_enums_enabled( + self, async_session, sync_session, client): + AUTH_URL = 'https://auth.url.com' + + sync_session.return_value = sync_session + sync_session.create_authorization_url.return_value = AUTH_URL, None + sync_session.fetch_token.return_value = self.token + + webdriver = MagicMock() + webdriver.current_url = REDIRECT_URL + '/token_params' + + client.return_value = 'returned client' + + self.assertEqual('returned client', + auth.client_from_login_flow( + webdriver, API_KEY, REDIRECT_URL, + self.json_path, + redirect_wait_time_seconds=0.0)) + + client.assert_called_once_with(API_KEY, _, token_metadata=_, + enforce_enums=True) + class ClientFromManualFlow(unittest.TestCase): @@ -440,7 +573,6 @@ def setUp(self): self.json_path = os.path.join(self.tmp_dir.name, 'token.json') self.token = {'token': 'yes'} - @no_duplicates @patch('tda.auth.Client') @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) @@ -495,7 +627,6 @@ def test_normalize_api_key( 'API_KEY@AMER.OAUTHAP', sync_session.call_args[0][0]) - @no_duplicates @patch('tda.auth.Client') @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) @@ -517,6 +648,7 @@ def test_custom_token_write_func( prompt_func.return_value = 'http://redirect.url.com/?data' token_writes = [] + def dummy_token_write_func(token): token_writes.append(token) @@ -535,7 +667,6 @@ def dummy_token_write_func(token): 'token': self.token }], token_writes) - @no_duplicates @patch('tda.auth.Client') @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) @@ -568,6 +699,55 @@ def test_print_warning_on_http_redirect_uri( print_func.assert_any_call(AnyStringWith('will transmit data over HTTP')) + @no_duplicates + @patch('tda.auth.Client') + @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) + @patch('tda.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) + @patch('tda.auth.prompt') + @patch('time.time', unittest.mock.MagicMock(return_value=MOCK_NOW)) + def test_enforce_enums_disabled( + self, prompt_func, async_session, sync_session, client): + AUTH_URL = 'https://auth.url.com' + + sync_session.return_value = sync_session + sync_session.create_authorization_url.return_value = AUTH_URL, None + sync_session.fetch_token.return_value = self.token + + client.return_value = 'returned client' + prompt_func.return_value = 'http://redirect.url.com/?data' + + self.assertEqual('returned client', + auth.client_from_manual_flow( + API_KEY, REDIRECT_URL, self.json_path, + enforce_enums=False)) + + client.assert_called_once_with(API_KEY, _, token_metadata=_, + enforce_enums=False) + + @no_duplicates + @patch('tda.auth.Client') + @patch('tda.auth.OAuth2Client', new_callable=MockOAuthClient) + @patch('tda.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient) + @patch('tda.auth.prompt') + @patch('time.time', unittest.mock.MagicMock(return_value=MOCK_NOW)) + def test_enforce_enums_enabled( + self, prompt_func, async_session, sync_session, client): + AUTH_URL = 'https://auth.url.com' + + sync_session.return_value = sync_session + sync_session.create_authorization_url.return_value = AUTH_URL, None + sync_session.fetch_token.return_value = self.token + + client.return_value = 'returned client' + prompt_func.return_value = 'http://redirect.url.com/?data' + + self.assertEqual('returned client', + auth.client_from_manual_flow( + API_KEY, REDIRECT_URL, self.json_path)) + + client.assert_called_once_with(API_KEY, _, token_metadata=_, + enforce_enums=True) + class EasyClientTest(unittest.TestCase): @@ -620,6 +800,79 @@ def test_no_token_file_with_wd_func( webdriver_func.assert_called_once() client_from_login_flow.assert_called_once() + @no_duplicates + @patch('tda.auth.client_from_login_flow') + @patch('tda.auth.client_from_token_file') + def test_no_token_file_with_wd_func_with_enums_disabled( + self, + client_from_token_file, + client_from_login_flow): + client_from_token_file.side_effect = SystemExit() + client_from_login_flow.return_value = 'returned client' + webdriver_func = MagicMock() + + self.assertEquals('returned client', + auth.easy_client( + API_KEY, REDIRECT_URL, self.json_path, + webdriver_func=webdriver_func, enforce_enums=False)) + + webdriver_func.assert_called_once() + client_from_login_flow.assert_called_once_with(_, API_KEY, REDIRECT_URL, + self.json_path, + asyncio=False, + enforce_enums=False) + + @no_duplicates + @patch('tda.auth.client_from_login_flow') + @patch('tda.auth.client_from_token_file') + def test_no_token_file_with_wd_func_with_enums_enabled( + self, + client_from_token_file, + client_from_login_flow): + client_from_token_file.side_effect = SystemExit() + client_from_login_flow.return_value = 'returned client' + webdriver_func = MagicMock() + + self.assertEquals('returned client', + auth.easy_client( + API_KEY, REDIRECT_URL, self.json_path, + webdriver_func=webdriver_func)) + + webdriver_func.assert_called_once() + client_from_login_flow.assert_called_once_with(_, API_KEY, REDIRECT_URL, + self.json_path, + asyncio=False, + enforce_enums=True) + + @no_duplicates + @patch('tda.auth.client_from_token_file') + def test_token_file_with_enums_disabled(self, client_from_token_file): + self.write_token() + + webdriver_func = MagicMock() + client_from_token_file.return_value = self.token + + self.assertEquals(self.token, + auth.easy_client(API_KEY, REDIRECT_URL, self.json_path, + enforce_enums=False)) + client_from_token_file.assert_called_once_with(self.json_path, API_KEY, + asyncio=False, + enforce_enums=False) + + @no_duplicates + @patch('tda.auth.client_from_token_file') + def test_token_file_with_enums_enabled(self, client_from_token_file): + self.write_token() + + webdriver_func = MagicMock() + client_from_token_file.return_value = self.token + + self.assertEquals(self.token, + auth.easy_client(API_KEY, REDIRECT_URL, self.json_path)) + client_from_token_file.assert_called_once_with(self.json_path, API_KEY, + asyncio=False, + enforce_enums=True) + class TokenMetadataTest(unittest.TestCase): @@ -666,27 +919,25 @@ def test_correct_suffix(self): import logging logging.getLogger('tda.auth').warning('dummy') - self.assertEqual('API_KEY@AMER.OAUTHAP', + self.assertEqual('API_KEY@AMER.OAUTHAP', auth._normalize_api_key('API_KEY@AMER.OAUTHAP')) self.assertEqual(['WARNING:tda.auth:dummy'], log.output) - @no_duplicates def test_no_suffix(self): with self.assertLogs('tda.auth') as log: - self.assertEqual('API_KEY@AMER.OAUTHAP', + self.assertEqual('API_KEY@AMER.OAUTHAP', auth._normalize_api_key('API_KEY')) self.assertEqual([ 'INFO:tda.auth:Appending @AMER.OAUTHAP to API key' ], log.output) - @no_duplicates def test_invalid_suffix(self): with self.assertLogs('tda.auth') as log: - self.assertEqual('API_KEY@AMER.OAUTHAP', + self.assertEqual('API_KEY@AMER.OAUTHAP', auth._normalize_api_key('API_KEY@AMER')) self.assertEqual([