Skip to content

Commit

Permalink
feat: add WebSocket server implementation with SQL query generation
Browse files Browse the repository at this point in the history
  • Loading branch information
kalyan540 committed Dec 25, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent ae1cee6 commit 0e03b07
Showing 2 changed files with 131 additions and 0 deletions.
17 changes: 17 additions & 0 deletions Text2Sql/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Use an official Python runtime as a base image
FROM python:3.11-slim

# Set the working directory inside the container
WORKDIR /app

# Install required dependencies
RUN pip install websockets transformers torch tensorflow-cpu

# Copy the Python WebSocket server code into the container
COPY server.py /app/

# Expose the port the WebSocket server will listen on
EXPOSE 8765

# Command to run the WebSocket server
CMD ["python", "server.py"]
114 changes: 114 additions & 0 deletions Text2Sql/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import re
import time
import asyncio
import websockets
import logging
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the model and tokenizer
model_path = 'gaussalgo/T5-LM-Large-text2sql-spider'
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Database schema (unchanged, used for model input)
schema = """
"USA_OEE"
"timestamp" STRING,
"device_name" STRING,
"Quality" FLOAT,
"Performance" FLOAT,
"Availability" FLOAT,
"OEE" FLOAT,
foreign_key:
primary key: "timestamp"
"""

# Table names and column names
table_names = ["USA_OEE"] # Table names in the schema
column_names = ["timestamp", "device_name", "Quality", "Performance", "Availability", "OEE"] # Column names in the schema

# Function to add double quotations to table and column names in the SQL query
def add_double_quotations(sql_query, table_names, column_names):
"""
Add double quotations to table and column names in the SQL query.
:param sql_query: Input SQL query string
:param table_names: List of table names
:param column_names: List of column names
:return: Formatted SQL query
"""
# Create a mapping for tables and columns
table_map = {table.lower(): f'public."{table}"' for table in table_names}
column_map = {col.lower(): f'"{col}"' for col in column_names}

# Define a regex pattern to identify table and column names
identifier_pattern = r'\b\w+\b'

# Replace table and column names using the maps
def replace_identifiers(match):
identifier = match.group(0)
if identifier.lower() in table_map:
return table_map[identifier.lower()]
elif identifier.lower() in column_map:
return column_map[identifier.lower()]
return identifier # Return the original if not found

# Apply the regex pattern to the SQL query
formatted_query = re.sub(identifier_pattern, replace_identifiers, sql_query)
return formatted_query

# Function to generate SQL query from the question using the transformer model
def generate_sql_query(question):
# Combine question with schema
input_text = " ".join(["Question: ", question, "Schema:", schema])

try:
# Start the timer
start_time = time.time()

# Tokenize the input and generate the output
model_inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(**model_inputs, max_length=512)

# Stop the timer
end_time = time.time()

# Decode and return the SQL query
output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
generated_sql = output_text[0]

# Add double quotations to table and column names
formatted_sql = add_double_quotations(generated_sql, table_names, column_names)

# Print the time taken (for logging)
print(f"Time taken: {end_time - start_time:.2f} seconds\n")
return formatted_sql
except Exception as e:
return f"An error occurred: {e}"

# Set up logging
logging.basicConfig(level=logging.INFO)

# WebSocket handler that processes questions and returns SQL
async def echo(websocket):
logging.info(f"New connection from {websocket.remote_address}")
try:
async for message in websocket:
logging.info(f"Received message: {message}")
# Call the function to generate SQL query from the question
sql_query = generate_sql_query(message)
await websocket.send(sql_query)
except websockets.exceptions.ConnectionClosed as e:
logging.error(f"Connection closed: {e}")

# WebSocket server function
async def main():
# Create the WebSocket server
server = await websockets.serve(echo, "0.0.0.0", 8765)
logging.info("WebSocket Server running on ws://0.0.0.0:8765")

# Keep the server running indefinitely
await server.wait_closed()

if __name__ == "__main__":
# Run the WebSocket server
asyncio.run(main())

0 comments on commit 0e03b07

Please sign in to comment.