Skip to content

Commit

Permalink
Added docstrings, included variable replacement, other cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
cecily_carver committed Jan 14, 2025
1 parent 4d37316 commit dd8ac8e
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 53 deletions.
38 changes: 11 additions & 27 deletions mysql_mimic/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
ensure_info_schema,
)
from mysql_mimic.constants import INFO_SCHEMA, KillKind
from mysql_mimic.variables_processor import SessionContext, VariablesProcessor
from mysql_mimic.variable_processor import (
SessionContext,
VariableProcessor,
get_var_assignments,
)
from mysql_mimic.utils import find_dbs
from mysql_mimic.variables import (
Variables,
Expand Down Expand Up @@ -280,29 +284,7 @@ async def _query_info_schema(self, expression: exp.Expression) -> AllowedResult:

async def _set_var_middleware(self, q: Query) -> AllowedResult:
"""Handles any SET_VAR hints, which set system variables for a single statement"""
hints = q.expression.find_all(exp.Hint)
if not hints:
return await q.next()

assignments = {}

# Iterate in reverse order so higher SET_VAR hints get priority
for hint in reversed(list(hints)):
set_var_hint = None

for e in hint.expressions:
if isinstance(e, exp.Func) and e.name == "SET_VAR":
set_var_hint = e
for eq in e.expressions:
assignments[eq.left.name] = expression_to_value(eq.right)

if set_var_hint:
set_var_hint.pop()

# Remove the hint entirely if SET_VAR was the only expression
if not hint.expressions:
hint.pop()

assignments = get_var_assignments(q.expression)
orig = {k: self.variables.get(k) for k in assignments}
try:
for k, v in assignments.items():
Expand Down Expand Up @@ -389,7 +371,7 @@ async def _begin_middleware(self, q: Query) -> AllowedResult:

async def _replace_variables_middleware(self, q: Query) -> AllowedResult:
"""Replace session variables and information functions with their corresponding values"""
VariablesProcessor(self._session_context()).replace_variables(q.expression)
VariableProcessor(self._session_context()).replace_variables(q.expression)
return await q.next()

async def _static_query_middleware(self, q: Query) -> AllowedResult:
Expand Down Expand Up @@ -523,9 +505,11 @@ def _show_errors(self, show: exp.Show) -> AllowedResult:
def _session_context(self) -> SessionContext:
return SessionContext(
connection_id=self.connection.connection_id,
current_user=str(self.username),
external_user=self.variables.get("external_user"),
current_user=self.username or "",
version=self.variables.get("version"),
variables=self.variables,
database=str(self.database),
database=self.database or "",
timestamp=self.timestamp,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
from dataclasses import dataclass
from datetime import datetime

from sqlglot import expressions as exp

from mysql_mimic.intercept import value_to_expression
from mysql_mimic.intercept import value_to_expression, expression_to_value
from mysql_mimic.variables import Variables


@dataclass
class SessionContext:
"""
Contains properties of the current session relevant to setting system variables.
Args:
connection_id: connection id for the session.
external_user: the username from the identity provider.
current_user: username of the authorized user.
version: MySQL version.
database: MySQL database name.
variables: dictionary of session variables.
timestamp: timestamp at the start of the current query.
"""

connection_id: int
external_user: str
current_user: str
Expand All @@ -15,24 +30,44 @@ class SessionContext:
variables: Variables
timestamp: datetime

def __init__(
self,
connection_id: int,
current_user: str,
variables: Variables,
database: str,
timestamp: datetime,
):
self.connection_id = connection_id
self.external_user = variables.get("external_user")
self.variables = variables
self.current_user = current_user
self.version = variables.get("version")
self.database = database
self.timestamp = timestamp


class VariablesProcessor:

variable_constants = {
"CURRENT_USER",
"CURRENT_TIME",
"CURRENT_TIMESTAMP",
"CURRENT_DATE",
}


def get_var_assignments(expression: exp.Expression) -> dict[str, str]:
"""Handles any SET_VAR hints, which set system variables for a single statement"""
hints = expression.find_all(exp.Hint)
if not hints:
return {}

assignments = {}

# Iterate in reverse order so higher SET_VAR hints get priority
for hint in reversed(list(hints)):
set_var_hint = None

for e in hint.expressions:
if isinstance(e, exp.Func) and e.name == "SET_VAR":
set_var_hint = e
for eq in e.expressions:
assignments[eq.left.name] = expression_to_value(eq.right)

if set_var_hint:
set_var_hint.pop()

# Remove the hint entirely if SET_VAR was the only expression
if not hint.expressions:
hint.pop()

return assignments


class VariableProcessor:

def __init__(self, session: SessionContext):
self._session = session
Expand Down Expand Up @@ -61,14 +96,9 @@ def __init__(self, session: SessionContext):
"CURRENT_TIME": self._functions["CURTIME"],
}
)
self._constants = {
"CURRENT_USER",
"CURRENT_TIME",
"CURRENT_TIMESTAMP",
"CURRENT_DATE",
}

def replace_variables(self, expression: exp.Expression) -> None:
"""Replaces certain system variables with information provided from the session context."""
if isinstance(expression, exp.Set):
for setitem in expression.expressions:
if isinstance(setitem.this, exp.Binary):
Expand All @@ -93,7 +123,7 @@ def _transform(self, node: exp.Expression) -> exp.Expression:
if func:
value = func()
new_node = value_to_expression(value)
elif isinstance(node, exp.Column) and node.sql() in self._constants:
elif isinstance(node, exp.Column) and node.sql() in variable_constants:
value = self._functions[node.sql()]()
new_node = value_to_expression(value)
elif isinstance(node, exp.SessionParameter):
Expand Down

0 comments on commit dd8ac8e

Please sign in to comment.