diff --git a/mysql_mimic/session.py b/mysql_mimic/session.py index 63df2d1..9741dcc 100644 --- a/mysql_mimic/session.py +++ b/mysql_mimic/session.py @@ -32,7 +32,11 @@ ensure_info_schema, ) from mysql_mimic.constants import INFO_SCHEMA, KillKind -from mysql_mimic.variables_processor import SessionContext, VariablesProcessor +from mysql_mimic.variable_processor import ( + SessionContext, + VariableProcessor, + get_var_assignments, +) from mysql_mimic.utils import find_dbs from mysql_mimic.variables import ( Variables, @@ -280,29 +284,7 @@ async def _query_info_schema(self, expression: exp.Expression) -> AllowedResult: async def _set_var_middleware(self, q: Query) -> AllowedResult: """Handles any SET_VAR hints, which set system variables for a single statement""" - hints = q.expression.find_all(exp.Hint) - if not hints: - return await q.next() - - assignments = {} - - # Iterate in reverse order so higher SET_VAR hints get priority - for hint in reversed(list(hints)): - set_var_hint = None - - for e in hint.expressions: - if isinstance(e, exp.Func) and e.name == "SET_VAR": - set_var_hint = e - for eq in e.expressions: - assignments[eq.left.name] = expression_to_value(eq.right) - - if set_var_hint: - set_var_hint.pop() - - # Remove the hint entirely if SET_VAR was the only expression - if not hint.expressions: - hint.pop() - + assignments = get_var_assignments(q.expression) orig = {k: self.variables.get(k) for k in assignments} try: for k, v in assignments.items(): @@ -389,7 +371,7 @@ async def _begin_middleware(self, q: Query) -> AllowedResult: async def _replace_variables_middleware(self, q: Query) -> AllowedResult: """Replace session variables and information functions with their corresponding values""" - VariablesProcessor(self._session_context()).replace_variables(q.expression) + VariableProcessor(self._session_context()).replace_variables(q.expression) return await q.next() async def _static_query_middleware(self, q: Query) -> AllowedResult: @@ -523,9 +505,11 @@ def _show_errors(self, show: exp.Show) -> AllowedResult: def _session_context(self) -> SessionContext: return SessionContext( connection_id=self.connection.connection_id, - current_user=str(self.username), + external_user=self.variables.get("external_user"), + current_user=self.username or "", + version=self.variables.get("version"), variables=self.variables, - database=str(self.database), + database=self.database or "", timestamp=self.timestamp, ) diff --git a/mysql_mimic/variables_processor.py b/mysql_mimic/variable_processor.py similarity index 64% rename from mysql_mimic/variables_processor.py rename to mysql_mimic/variable_processor.py index 318001a..2ab4f6c 100644 --- a/mysql_mimic/variables_processor.py +++ b/mysql_mimic/variable_processor.py @@ -1,12 +1,27 @@ +from dataclasses import dataclass from datetime import datetime from sqlglot import expressions as exp -from mysql_mimic.intercept import value_to_expression +from mysql_mimic.intercept import value_to_expression, expression_to_value from mysql_mimic.variables import Variables +@dataclass class SessionContext: + """ + Contains properties of the current session relevant to setting system variables. + + Args: + connection_id: connection id for the session. + external_user: the username from the identity provider. + current_user: username of the authorized user. + version: MySQL version. + database: MySQL database name. + variables: dictionary of session variables. + timestamp: timestamp at the start of the current query. + """ + connection_id: int external_user: str current_user: str @@ -15,24 +30,44 @@ class SessionContext: variables: Variables timestamp: datetime - def __init__( - self, - connection_id: int, - current_user: str, - variables: Variables, - database: str, - timestamp: datetime, - ): - self.connection_id = connection_id - self.external_user = variables.get("external_user") - self.variables = variables - self.current_user = current_user - self.version = variables.get("version") - self.database = database - self.timestamp = timestamp - - -class VariablesProcessor: + +variable_constants = { + "CURRENT_USER", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_DATE", +} + + +def get_var_assignments(expression: exp.Expression) -> dict[str, str]: + """Handles any SET_VAR hints, which set system variables for a single statement""" + hints = expression.find_all(exp.Hint) + if not hints: + return {} + + assignments = {} + + # Iterate in reverse order so higher SET_VAR hints get priority + for hint in reversed(list(hints)): + set_var_hint = None + + for e in hint.expressions: + if isinstance(e, exp.Func) and e.name == "SET_VAR": + set_var_hint = e + for eq in e.expressions: + assignments[eq.left.name] = expression_to_value(eq.right) + + if set_var_hint: + set_var_hint.pop() + + # Remove the hint entirely if SET_VAR was the only expression + if not hint.expressions: + hint.pop() + + return assignments + + +class VariableProcessor: def __init__(self, session: SessionContext): self._session = session @@ -61,14 +96,9 @@ def __init__(self, session: SessionContext): "CURRENT_TIME": self._functions["CURTIME"], } ) - self._constants = { - "CURRENT_USER", - "CURRENT_TIME", - "CURRENT_TIMESTAMP", - "CURRENT_DATE", - } def replace_variables(self, expression: exp.Expression) -> None: + """Replaces certain system variables with information provided from the session context.""" if isinstance(expression, exp.Set): for setitem in expression.expressions: if isinstance(setitem.this, exp.Binary): @@ -93,7 +123,7 @@ def _transform(self, node: exp.Expression) -> exp.Expression: if func: value = func() new_node = value_to_expression(value) - elif isinstance(node, exp.Column) and node.sql() in self._constants: + elif isinstance(node, exp.Column) and node.sql() in variable_constants: value = self._functions[node.sql()]() new_node = value_to_expression(value) elif isinstance(node, exp.SessionParameter):