Skip to content

Commit

Permalink
Added: AI model for Chatbot
Browse files Browse the repository at this point in the history
  • Loading branch information
kalyan540 committed Dec 25, 2024
1 parent b243937 commit 94eccf1
Showing 4 changed files with 146 additions and 1 deletion.
17 changes: 17 additions & 0 deletions Text2SQL/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Use an official Python runtime as a base image
FROM python:3.11-slim

# Flush stdout/stderr immediately so the server's print() and logging output
# appear in `docker logs` without buffering delay
ENV PYTHONUNBUFFERED=1

# Set the working directory inside the container
WORKDIR /app

# Install required dependencies.
# --no-cache-dir keeps pip's download cache out of the image layer.
# NOTE(review): server.py only requests PyTorch tensors (return_tensors="pt");
# tensorflow-cpu looks unused - confirm before removing it.
RUN pip install --no-cache-dir websockets transformers torch tensorflow-cpu

# Copy the Python WebSocket server code into the container
COPY server.py /app/

# Expose the port the WebSocket server will listen on
EXPOSE 8765

# Command to run the WebSocket server
CMD ["python", "server.py"]
114 changes: 114 additions & 0 deletions Text2SQL/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import re
import time
import asyncio
import websockets
import logging
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the model and tokenizer
# Both are created once, at import time, and used as module-level globals by
# generate_sql_query().
# NOTE(review): presumably from_pretrained fetches the checkpoint on first use,
# so container start-up may be slow - confirm how the weights are cached.
model_path = 'gaussalgo/T5-LM-Large-text2sql-spider'
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Database schema (unchanged, used for model input)
# This text is joined verbatim onto every question in generate_sql_query(), so
# its exact wording is part of the model input - edit with care.
schema = """
"USA_OEE"
"timestamp" STRING,
"device_name" STRING,
"Quality" FLOAT,
"Performance" FLOAT,
"Availability" FLOAT,
"OEE" FLOAT,
foreign_key:
primary key: "timestamp"
"""

# Table names and column names
# Canonical spellings passed to add_double_quotations() to re-quote the
# identifiers that appear in the generated SQL.
table_names = ["USA_OEE"] # Table names in the schema
column_names = ["timestamp", "device_name", "Quality", "Performance", "Availability", "OEE"] # Column names in the schema

# Function to add double quotations to table and column names in the SQL query
def add_double_quotations(sql_query, table_names, column_names):
    """
    Add double quotations to table and column names in the SQL query.

    Bare words that match a known table (case-insensitively) are replaced by
    ``public."<Table>"`` and bare words that match a known column by
    ``"<Column>"``, both using the canonical spelling from the given lists.
    Table names take precedence when a word matches both.

    Text that is already quoted is left untouched: single-quoted string
    literals (so a value such as ``'OEE'`` is not rewritten into
    ``'"OEE"'``) and double-quoted identifiers (so ``"Quality"`` is not
    quoted a second time).

    :param sql_query: Input SQL query string
    :param table_names: List of table names
    :param column_names: List of column names
    :return: Formatted SQL query
    """
    # Create a mapping from lower-cased name to its quoted canonical form
    table_map = {table.lower(): f'public."{table}"' for table in table_names}
    column_map = {col.lower(): f'"{col}"' for col in column_names}

    # Alternation order matters: quoted spans are consumed as whole tokens
    # first, so the words inside them are never seen as bare identifiers.
    token_pattern = re.compile(
        r"'(?:[^']|'')*'"    # single-quoted string literal ('' = escaped quote)
        r'|"(?:[^"]|"")*"'   # already double-quoted identifier
        r"|\b\w+\b"          # bare word (keyword, identifier, number)
    )

    def replace_token(match):
        token = match.group(0)
        if token[0] in "'\"":
            return token  # quoted text is passed through verbatim
        key = token.lower()
        if key in table_map:
            return table_map[key]
        if key in column_map:
            return column_map[key]
        return token  # Return the original if not found

    return token_pattern.sub(replace_token, sql_query)

# Function to generate SQL query from the question using the transformer model
def generate_sql_query(question):
    """
    Translate a natural-language question into a SQL query string.

    The question is combined with the module-level ``schema`` text, run
    through the module-level ``tokenizer`` and ``model``, and the decoded
    output is post-processed with ``add_double_quotations``.

    :param question: The user's question as text
    :return: The formatted SQL query, or the string
             ``"An error occurred: <details>"`` if generation fails
    """
    # Combine question with schema
    input_text = " ".join(["Question: ", question, "Schema:", schema])

    try:
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments
        start_time = time.perf_counter()

        # Tokenize the input and generate the output
        model_inputs = tokenizer(input_text, return_tensors="pt")
        outputs = model.generate(**model_inputs, max_length=512)

        elapsed = time.perf_counter() - start_time

        # Decode the first (only) generated sequence
        output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        generated_sql = output_text[0]

        # Add double quotations to table and column names
        formatted_sql = add_double_quotations(generated_sql, table_names, column_names)
    except Exception as e:
        # Boundary handler: the caller forwards whatever string is returned,
        # so report the failure as text but keep the traceback in the logs.
        logging.exception("Failed to generate SQL for question: %r", question)
        return f"An error occurred: {e}"

    # Report the generation time through logging, like the rest of the file
    logging.info("Time taken: %.2f seconds", elapsed)
    return formatted_sql

# Set up logging
# INFO level is required for the logging.info() calls in this file to be emitted.
logging.basicConfig(level=logging.INFO)

# WebSocket handler that processes questions and returns SQL
async def echo(websocket):
    """
    Serve one WebSocket connection: answer each incoming message with SQL.

    Every message received on ``websocket`` is passed to
    ``generate_sql_query`` and the resulting string is sent back on the same
    connection. A ``ConnectionClosed`` raised while receiving or sending is
    logged and ends the handler.

    :param websocket: The connection object supplied by the websockets server
    """
    logging.info(f"New connection from {websocket.remote_address}")
    try:
        async for message in websocket:
            logging.info(f"Received message: {message}")
            # Model inference is synchronous and CPU-bound. Running it in a
            # worker thread keeps the event loop free, so other connections
            # and keepalive pings are still serviced while SQL is generated.
            sql_query = await asyncio.to_thread(generate_sql_query, message)
            await websocket.send(sql_query)
    except websockets.exceptions.ConnectionClosed as e:
        logging.error(f"Connection closed: {e}")

# WebSocket server function
async def main():
    """
    Start the WebSocket server on 0.0.0.0:8765 and run until it is closed.

    The server is used as an async context manager so that it is closed
    cleanly if this coroutine is cancelled (for example on Ctrl-C under
    ``asyncio.run``) instead of being abandoned while still listening.
    """
    # Create the WebSocket server; it is closed automatically on exit
    async with websockets.serve(echo, "0.0.0.0", 8765) as server:
        logging.info("WebSocket Server running on ws://0.0.0.0:8765")

        # Keep the server running indefinitely
        await server.wait_closed()

# Script entry point: only start the server when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
# Run the WebSocket server
asyncio.run(main())
13 changes: 13 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -214,6 +214,17 @@ services:
healthcheck:
disable: true

text2sql:
  build:
    # Must match the directory name exactly (Text2SQL/): paths are
    # case-sensitive on Linux hosts, so "./Text2Sql" fails to resolve there.
    context: ./Text2SQL # Directory where the Text-to-SQL Dockerfile and app are located
  container_name: websocket-server
  ports:
    - "8765:8765"
  depends_on:
    - db # Connects to your PostgreSQL database
  networks:
    - default # Ensures it uses the same network as other containers

superset-tests-worker:
build:
<<: *common-build
@@ -239,6 +250,8 @@ services:
volumes: *superset-volumes
healthcheck:
test: ["CMD-SHELL", "celery inspect ping -A superset.tasks.celery_app:app -d celery@$$HOSTNAME"]



volumes:
superset_home:
3 changes: 2 additions & 1 deletion superset-frontend/src/pages/Dashboard/index.tsx
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ import { UserWithPermissionsAndRoles } from 'src/types/bootstrapTypes';
import { t } from '@superset-ui/core';
import './index.css';
import { DashboardPage } from 'src/dashboard/containers/DashboardPage';
import ChatBOT from './ChatBOT';
import AlertList from '../AlertReportList';
import { addDangerToast, addSuccessToast } from 'src/components/MessageToasts/actions';

@@ -174,7 +175,7 @@ const DashboardRoute: FC = () => {
user={currentUser}
/>
) : activeButton === 'ChatBot' ? (
<DashboardPage idOrSlug={'15'} />
<ChatBOT />
) : activeButton === 'Analytics' ? (
<DashboardPage idOrSlug={'15'} />
) : (

0 comments on commit 94eccf1

Please sign in to comment.