# One token of generated SQL.  Quoted spans are matched first so that their
# contents are consumed as a unit and never seen by the bare-word branch.
_SQL_TOKEN_RE = re.compile(
    r"'(?:[^']|'')*'"     # single-quoted string literal ('' is an escaped quote)
    r'|"(?:[^"]|"")*"'    # identifier that is already double-quoted
    r"|\b\w+\b"           # bare word: keyword, identifier or number
)


def add_double_quotations(sql_query, table_names, column_names):
    """
    Add double quotations to table and column names in the SQL query.

    Bare words are compared case-insensitively against the known names.  A
    table name is rewritten to ``public."<Table>"`` and a column name to
    ``"<Column>"``, using the spelling given in the lists.  When a word is
    both a table and a column name, the table form wins.

    Single-quoted string literals and identifiers that are already
    double-quoted are copied through unchanged, so a value such as ``'OEE'``
    is not turned into ``'"OEE"'`` and ``"OEE"`` is not quoted twice.

    :param sql_query: Input SQL query string
    :param table_names: List of table names
    :param column_names: List of column names
    :return: Formatted SQL query
    """
    # Lower-cased lookup tables so matching ignores the model's casing.
    table_map = {table.lower(): f'public."{table}"' for table in table_names}
    column_map = {col.lower(): f'"{col}"' for col in column_names}

    def replace_identifiers(match):
        token = match.group(0)
        # Quoted literal or quoted identifier: leave exactly as generated.
        if token[0] in "'\"":
            return token
        key = token.lower()
        if key in table_map:
            return table_map[key]
        # Fall back to the original word when it is not a known column.
        return column_map.get(key, token)

    return _SQL_TOKEN_RE.sub(replace_identifiers, sql_query)


def generate_sql_query(question):
    """
    Generate a SQL query for a natural-language question.

    The question is combined with the module-level ``schema`` text, passed
    through the module-level ``tokenizer`` and ``model``, and the first decoded
    sequence is post-processed with ``add_double_quotations`` using the
    module-level ``table_names`` and ``column_names``.

    :param question: The question to translate
    :return: The formatted SQL query, or the string
        ``"An error occurred: <exception>"`` if any step fails
    """
    try:
        # Built inside the try so that a non-string question is reported
        # through the same error string as every other failure.
        # The prompt layout is model input; it is kept exactly as it was.
        input_text = " ".join(["Question: ", question, "Schema:", schema])

        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments.
        start_time = time.perf_counter()
        model_inputs = tokenizer(input_text, return_tensors="pt")
        outputs = model.generate(**model_inputs, max_length=512)
        elapsed = time.perf_counter() - start_time

        # Only the first decoded sequence is used.
        output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        generated_sql = output_text[0]

        formatted_sql = add_double_quotations(generated_sql, table_names, column_names)
    except Exception as e:
        # Callers receive the error as text; keep the traceback in the log.
        logging.exception("SQL generation failed")
        return f"An error occurred: {e}"

    logging.info("Time taken: %.2f seconds", elapsed)
    return formatted_sql
# The model and tokenizer are shared module-level objects.  Inference runs in
# a worker thread (see echo), and this lock keeps at most one inference in
# flight at a time, which is what the original blocking call guaranteed.
# NOTE(review): created at import time; relies on Python 3.10+ where
# asyncio.Lock is not bound to an event loop at construction (the Dockerfile
# pins python:3.11).
_INFERENCE_LOCK = asyncio.Lock()


# WebSocket handler that processes questions and returns SQL
async def echo(websocket):
    """
    Serve one WebSocket connection.

    Every incoming message is treated as a question, passed to
    ``generate_sql_query``, and the returned string is sent back on the same
    connection.  The loop ends when the client disconnects.

    :param websocket: The connection object supplied by ``websockets.serve``
    """
    logging.info("New connection from %s", websocket.remote_address)
    try:
        async for message in websocket:
            logging.info("Received message: %s", message)
            # generate_sql_query is a blocking call.  Running it in a worker
            # thread keeps the event loop free to service other connections
            # while the model is generating.
            async with _INFERENCE_LOCK:
                sql_query = await asyncio.to_thread(generate_sql_query, message)
            await websocket.send(sql_query)
    except websockets.exceptions.ConnectionClosed as e:
        logging.error("Connection closed: %s", e)


# WebSocket server function
async def main():
    """
    Start the WebSocket server on 0.0.0.0:8765 and run until it is closed.
    """
    # Create the WebSocket server
    server = await websockets.serve(echo, "0.0.0.0", 8765)
    logging.info("WebSocket Server running on ws://0.0.0.0:8765")

    # Keep the server running indefinitely
    await server.wait_closed()


if __name__ == "__main__":
    # Run the WebSocket server
    asyncio.run(main())
'src/dashboard/containers/DashboardPage'; +import ChatBOT from './ChatBOT'; import AlertList from '../AlertReportList'; import { addDangerToast, addSuccessToast } from 'src/components/MessageToasts/actions'; @@ -174,7 +175,7 @@ const DashboardRoute: FC = () => { user={currentUser} /> ) : activeButton === 'ChatBot' ? ( - + ) : activeButton === 'Analytics' ? ( ) : (