diff --git a/clickhouse_connect/driver/client.py b/clickhouse_connect/driver/client.py index 5afc4a97..572e8544 100644 --- a/clickhouse_connect/driver/client.py +++ b/clickhouse_connect/driver/client.py @@ -47,13 +47,11 @@ def _apply_settings(self, settings: Dict[str, Any] = None): def _validate_settings(self, settings: Dict[str, Any]): validated = {} for key, value in settings.items(): - setting_def = self.server_settings.get(key) - if setting_def is None: - logger.warning('Setting %s is not valid, ignoring', key) - continue - if setting_def.readonly: - logger.warning('Setting %s is read only, ignoring', key) - continue + if 'session' not in key: + setting_def = self.server_settings.get(key) + if setting_def is None or setting_def.readonly: + logger.warning('Setting %s is not valid or read only, ignoring', key) + continue validated[key] = value return validated @@ -105,7 +103,7 @@ def exec_query(self, query: str, settings: Optional[Dict] = None, use_none: bool :return: QueryResult of data and metadata returned by ClickHouse """ - def command(self, cmd: str, parameters=None, use_database:bool = True, settings: Dict[str, str] = None) \ + def command(self, cmd: str, parameters=None, use_database: bool = True, settings: Dict[str, str] = None) \ -> Union[str, int, Sequence[str]]: """ Client method that returns a single value instead of a result set @@ -123,7 +121,7 @@ def command(self, cmd: str, parameters=None, use_database:bool = True, settings: return self.exec_command(cmd, use_database, settings) @abstractmethod - def exec_command(self, cmd, use_database: bool = True, settings:Dict[str, str] = None) -> Union[ + def exec_command(self, cmd, use_database: bool = True, settings: Dict[str, str] = None) -> Union[ str, int, Sequence[str]]: """ Subclass implementation of the client query function diff --git a/clickhouse_connect/driver/httpclient.py b/clickhouse_connect/driver/httpclient.py index 7914154e..4859c94c 100644 --- a/clickhouse_connect/driver/httpclient.py +++ b/clickhouse_connect/driver/httpclient.py @@ -103,16 +103,15 @@ def __init__(self, interface: str, host: str, port: int, username: str, password self.session = session self.connect_timeout = connect_timeout self.read_timeout = send_receive_timeout - self.common_settings = {} super().__init__(database=database, query_limit=query_limit, uri=self.url, settings=kwargs) def _apply_settings(self, settings: Dict[str, Any] = None): valid_settings = self._validate_settings(settings) for key, value in valid_settings.items(): if isinstance(value, bool): - self.common_settings[key] = '1' if value else '0' + self.session.params[key] = '1' if value else '0' else: - self.common_settings[key] = str(value) + self.session.params[key] = str(value) def _format_query(self, query: str) -> str: query = query.strip() @@ -127,8 +126,7 @@ def exec_query(self, query: str, settings: Optional[Dict] = None, use_none: bool See BaseClient doc_string for this method """ headers = {'Content-Type': 'text/plain'} - params = self.common_settings.copy() - params['database'] = self.database + params = {'database': self.database } if settings: params.update(settings) if columns_only_re.search(query): @@ -160,9 +158,8 @@ def data_insert(self, table: str, column_names: Sequence[str], data: Sequence[Se See BaseClient doc_string for this method """ headers = {'Content-Type': 'application/octet-stream'} - params = self.common_settings.copy() - params['query'] = f"INSERT INTO {table} ({', '.join(column_names)}) FORMAT {self.write_format}" - params['database'] = self.database + params = {'query': f"INSERT INTO {table} ({', '.join(column_names)}) FORMAT {self.write_format}", + 'database': self.database} if settings: params.update(settings) insert_block = self.build_insert(data, column_types=column_types, column_names=column_names, column_oriented=column_oriented) @@ -174,8 +171,7 @@ def exec_command(self, cmd, use_database: bool = True, settings: Optional[Dict] See BaseClient doc_string for this method """ headers = {'Content-Type': 'text/plain'} - params = self.common_settings.copy() - params['query'] = cmd + params = {'query': cmd} if use_database: params['database'] = self.database if settings: diff --git a/setup.py b/setup.py index 2f657c7c..148eeb8e 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ def run_setup(try_c: bool = True): 'superset.db_engine_specs': ['clickhousedb=clickhouse_connect.cc_superset.engine:ClickHouseEngineSpec'] }, classifiers=[ - 'Development Status :: 3 - Alpha', + 'Development Status :: 4 - Beta', 'Intended Audience :: Developers', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python :: 3.7', diff --git a/tests/conftest.py b/tests/conftest.py index be16ced8..68ff07cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,3 +18,4 @@ def pytest_addoption(parser): parser.addoption('--test-db', help='Test database, will not be cleaned up') parser.addoption('--tls', default=False, action='store_true') parser.addoption('--no-tls', dest='tls', action='store_false') + parser.addoption('--local', default=False, action='store_true') diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 2100b936..75cdbd47 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -21,6 +21,7 @@ class TestConfig(NamedTuple): use_docker: bool test_database: str cleanup: bool + local: bool @property def cloud(self): @@ -43,12 +44,13 @@ def test_config_fixture(request) -> Iterator[TestConfig]: username = request.config.getoption('username') password = request.config.getoption('password') cleanup = request.config.getoption('cleanup') + local = request.config.getoption('local') test_database = request.config.getoption('test_db', None) if test_database: cleanup = False else: test_database = 'cc_test' - yield TestConfig(interface, host, port, username, password, use_docker, test_database, cleanup) + yield TestConfig(interface, host, port, username, password, use_docker, test_database, cleanup, local) @fixture(scope='session', name='test_db') @@ -75,7 +77,7 @@ def test_client_fixture(test_config: TestConfig, test_db: str) -> Iterator[Clien while True: tries += 1 try: - driver = create_client(interface=test_config.interface, + client = create_client(interface=test_config.interface, host=test_config.host, port=test_config.port, username=test_config.username, @@ -86,9 +88,9 @@ def test_client_fixture(test_config: TestConfig, test_db: str) -> Iterator[Clien raise Exception('Failed to connect to ClickHouse server after 30 seconds') from ex sleep(1) if test_db != 'default': - driver.command(f'CREATE DATABASE IF NOT EXISTS {test_db}', use_database=False) - driver.database = test_db - yield driver + client.command(f'CREATE DATABASE IF NOT EXISTS {test_db}', use_database=False) + client.database = test_db + yield client if test_config.use_docker: down_result = run_cmd(['docker-compose', '-f', compose_file, 'down', '-v']) if down_result[0]: diff --git a/tests/integration_tests/test_client.py b/tests/integration_tests/test_client.py index a6b5ada2..d7d897e0 100644 --- a/tests/integration_tests/test_client.py +++ b/tests/integration_tests/test_client.py @@ -1,8 +1,11 @@ from decimal import Decimal +from time import sleep +from clickhouse_connect import create_client from clickhouse_connect.driver.client import Client from clickhouse_connect.driver.options import HAS_NUMPY, HAS_PANDAS from clickhouse_connect.driver.query import QueryResult +from tests.integration_tests.conftest import TestConfig def test_query(test_client: Client): @@ -32,6 +35,28 @@ def test_decimal_conv(test_client: Client, test_table_engine: str): assert result == [(5, -182, 55.2), (57238478234, 77, -29.5773)] +def test_session_params(test_config: TestConfig): + client = create_client(interface=test_config.interface, + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + session_id='TEST_SESSION_ID') + result = client.exec_query('SELECT number FROM system.numbers LIMIT 5', + settings={'query_id': 'test_session_params'}).result_set + assert len(result) == 5 + if test_config.local: + sleep(10) # Allow the log entries to flush to tables + result = client.exec_query( + "SELECT session_id, user FROM system.session_log WHERE session_id = 'TEST_SESSION_ID' AND " + + 'event_time > now() - 30').result_set + assert result[0] == ('TEST_SESSION_ID', test_config.username) + result = client.exec_query( + "SELECT query_id, user FROM system.query_log WHERE query_id = 'test_session_params' AND " + + 'event_time > now() - 30').result_set + assert result[0] == ('test_session_params', test_config.username) + + def test_numpy(test_client: Client): if HAS_NUMPY: np_array = test_client.query_np('SELECT * FROM system.tables')