diff --git a/Text2SQL/Dockerfile b/Text2SQL/Dockerfile
new file mode 100644
index 000000000..3bdabe8aa
--- /dev/null
+++ b/Text2SQL/Dockerfile
@@ -0,0 +1,17 @@
+# Use an official Python runtime as a base image
+FROM python:3.11-slim
+
+# Set the working directory inside the container
+WORKDIR /app
+
+# Install required dependencies
+RUN pip install websockets transformers torch tensorflow-cpu
+
+# Copy the Python WebSocket server code into the container
+COPY server.py /app/
+
+# Expose the port the WebSocket server will listen on
+EXPOSE 8765
+
+# Command to run the WebSocket server
+CMD ["python", "server.py"]
diff --git a/Text2SQL/server.py b/Text2SQL/server.py
new file mode 100644
index 000000000..00fc3c14f
--- /dev/null
+++ b/Text2SQL/server.py
@@ -0,0 +1,114 @@
+import re
+import time
+import asyncio
+import websockets
+import logging
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+# Load the model and tokenizer
+model_path = 'gaussalgo/T5-LM-Large-text2sql-spider'
+model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+# Database schema (unchanged, used for model input)
+schema = """
+"USA_OEE"
+ "timestamp" STRING,
+ "device_name" STRING,
+ "Quality" FLOAT,
+ "Performance" FLOAT,
+ "Availability" FLOAT,
+ "OEE" FLOAT,
+ foreign_key:
+ primary key: "timestamp"
+"""
+
+# Table names and column names
+table_names = ["USA_OEE"] # Table names in the schema
+column_names = ["timestamp", "device_name", "Quality", "Performance", "Availability", "OEE"] # Column names in the schema
+
+# Function to add double quotations to table and column names in the SQL query
+def add_double_quotations(sql_query, table_names, column_names):
+ """
+ Add double quotations to table and column names in the SQL query.
+ :param sql_query: Input SQL query string
+ :param table_names: List of table names
+ :param column_names: List of column names
+ :return: Formatted SQL query
+ """
+ # Create a mapping for tables and columns
+ table_map = {table.lower(): f'public."{table}"' for table in table_names}
+ column_map = {col.lower(): f'"{col}"' for col in column_names}
+
+ # Define a regex pattern to identify table and column names
+ identifier_pattern = r'\b\w+\b'
+
+ # Replace table and column names using the maps
+ def replace_identifiers(match):
+ identifier = match.group(0)
+ if identifier.lower() in table_map:
+ return table_map[identifier.lower()]
+ elif identifier.lower() in column_map:
+ return column_map[identifier.lower()]
+ return identifier # Return the original if not found
+
+ # Apply the regex pattern to the SQL query
+ formatted_query = re.sub(identifier_pattern, replace_identifiers, sql_query)
+ return formatted_query
+
+# Function to generate SQL query from the question using the transformer model
+def generate_sql_query(question):
+ # Combine question with schema
+ input_text = " ".join(["Question: ", question, "Schema:", schema])
+
+ try:
+ # Start the timer
+ start_time = time.time()
+
+ # Tokenize the input and generate the output
+ model_inputs = tokenizer(input_text, return_tensors="pt")
+ outputs = model.generate(**model_inputs, max_length=512)
+
+ # Stop the timer
+ end_time = time.time()
+
+ # Decode and return the SQL query
+ output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ generated_sql = output_text[0]
+
+ # Add double quotations to table and column names
+ formatted_sql = add_double_quotations(generated_sql, table_names, column_names)
+
+ # Print the time taken (for logging)
+ print(f"Time taken: {end_time - start_time:.2f} seconds\n")
+ return formatted_sql
+ except Exception as e:
+ return f"An error occurred: {e}"
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+
+# WebSocket handler that processes questions and returns SQL
+async def echo(websocket):
+ logging.info(f"New connection from {websocket.remote_address}")
+ try:
+ async for message in websocket:
+ logging.info(f"Received message: {message}")
+ # Call the function to generate SQL query from the question
+ sql_query = generate_sql_query(message)
+ await websocket.send(sql_query)
+ except websockets.exceptions.ConnectionClosed as e:
+ logging.error(f"Connection closed: {e}")
+
+# WebSocket server function
+async def main():
+ # Create the WebSocket server
+ server = await websockets.serve(echo, "0.0.0.0", 8765)
+ logging.info("WebSocket Server running on ws://0.0.0.0:8765")
+
+ # Keep the server running indefinitely
+ await server.wait_closed()
+
+if __name__ == "__main__":
+ # Run the WebSocket server
+ asyncio.run(main())
diff --git a/docker-compose.yml b/docker-compose.yml
index 32355dbad..e8fac24af 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -214,6 +214,17 @@ services:
healthcheck:
disable: true
+ text2sql:
+ build:
+ context: ./Text2Sql # Directory where the Text-to-SQL Dockerfile and app are located
+ container_name: websocket-server
+ ports:
+ - "8765:8765"
+ depends_on:
+ - db # Connects to your PostgreSQL database
+ networks:
+ - default # Ensures it uses the same network as other containers
+
superset-tests-worker:
build:
<<: *common-build
@@ -239,6 +250,8 @@ services:
volumes: *superset-volumes
healthcheck:
test: ["CMD-SHELL", "celery inspect ping -A superset.tasks.celery_app:app -d celery@$$HOSTNAME"]
+
+
volumes:
superset_home:
diff --git a/superset-frontend/src/pages/Dashboard/index.tsx b/superset-frontend/src/pages/Dashboard/index.tsx
index 22326b3bc..11643cc2b 100644
--- a/superset-frontend/src/pages/Dashboard/index.tsx
+++ b/superset-frontend/src/pages/Dashboard/index.tsx
@@ -23,6 +23,7 @@ import { UserWithPermissionsAndRoles } from 'src/types/bootstrapTypes';
import { t } from '@superset-ui/core';
import './index.css';
import { DashboardPage } from 'src/dashboard/containers/DashboardPage';
+import ChatBOT from './ChatBOT';
import AlertList from '../AlertReportList';
import { addDangerToast, addSuccessToast } from 'src/components/MessageToasts/actions';
@@ -174,7 +175,7 @@ const DashboardRoute: FC = () => {
user={currentUser}
/>
) : activeButton === 'ChatBot' ? (
-
+
) : activeButton === 'Analytics' ? (
) : (