| |
| |
| |
| |
|
|
| import os |
| import re |
| import sqlite3 |
| from pathlib import Path |
| from uuid import uuid4 |
|
|
| import gradio as gr |
| from langchain.agents import create_agent |
| from langchain.tools import tool |
| from langgraph.checkpoint.memory import InMemorySaver |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Runtime configuration, read from the environment once at import time.
# OPENAI_API_KEY is None when the variable is unset; the model name and the
# database location fall back to the defaults given here.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MODEL_NAME = os.environ.get("MODEL_NAME", "openai:gpt-5.4")
DATABASE_PATH = Path(
    os.environ.get("DATABASE_PATH", "data/Chinook_Sqlite.sqlite")
)
|
|
|
|
| |
| |
| |
|
|
def resolve_database_path() -> Path:
    """
    Locate the SQLite database file to use.

    The configured DATABASE_PATH (default: data/Chinook_Sqlite.sqlite, or the
    DATABASE_PATH environment variable, e.g. set in Hugging Face Spaces) is
    tried first, followed by a short list of common Chinook file names.

    Returns:
        The first candidate path that exists.

    Raises:
        FileNotFoundError: If none of the candidates exists.
    """
    fallback_names = (
        "Chinook_Sqlite.sqlite",
        "chinook.db",
        "Chinook.db",
        "data/chinook.db",
        "data/Chinook.db",
    )

    # The configured path always wins; fallbacks are checked in listed order.
    candidates = [DATABASE_PATH, *map(Path, fallback_names)]
    found = next(
        (candidate for candidate in candidates if candidate.exists()),
        None,
    )

    if found is None:
        raise FileNotFoundError(
            "SQLite database file was not found. "
            "Upload your database file or set DATABASE_PATH in Hugging Face Variables."
        )

    return found
|
|
|
|
# Resolved once at import time; resolve_database_path() raises
# FileNotFoundError here if no database file can be found.
DB_PATH: Path = resolve_database_path()
|
|
|
|
def get_database_schema(db_path: Path) -> str:
    """
    Describe every user table of a SQLite database as plain text.

    The result is injected into the system prompt so the agent knows the
    database structure.

    Args:
        db_path: Location of the SQLite file to inspect.

    Returns:
        For each table (ordered by name) a "Table: <name>" heading preceded
        by an empty line, followed by one "- <column>: <type>" line per
        column. "PRIMARY KEY" and "NOT NULL" are appended in parentheses
        where they apply.

    Raises:
        sqlite3.Error: If the file cannot be opened or queried.
    """
    schema_lines = []

    conn = sqlite3.connect(db_path)
    try:
        table_rows = conn.execute(
            """
            SELECT name
            FROM sqlite_master
            WHERE type = 'table'
            AND name NOT LIKE 'sqlite_%'
            ORDER BY name;
            """
        ).fetchall()

        for (table,) in table_rows:
            schema_lines.append(f"\nTable: {table}")

            # PRAGMA arguments cannot be bound as parameters, so the name is
            # embedded as a double-quoted identifier (inner quotes doubled).
            # This keeps names with spaces or punctuation working.
            quoted_table = '"' + table.replace('"', '""') + '"'
            columns = conn.execute(
                f"PRAGMA table_info({quoted_table});"
            ).fetchall()

            # table_info rows: (cid, name, type, notnull, dflt_value, pk)
            for _, name, col_type, notnull, _, pk in columns:
                flags = []
                if pk:
                    flags.append("PRIMARY KEY")
                if notnull:
                    flags.append("NOT NULL")

                flag_text = f" ({', '.join(flags)})" if flags else ""
                schema_lines.append(f"- {name}: {col_type}{flag_text}")
    finally:
        # Close the connection even when one of the queries fails.
        conn.close()

    return "\n".join(schema_lines)
|
|
|
|
# Schema text computed once at import time from the resolved database file;
# it is embedded into SYSTEM_PROMPT further down in this module.
DATABASE_SCHEMA: str = get_database_schema(DB_PATH)
|
|
|
|
# Opening fence with an optional (case-insensitive) "sql" tag, and the
# closing fence at the very end of the text.
_FENCE_OPEN = re.compile(r"^```(?:sql)?", re.IGNORECASE)
_FENCE_CLOSE = re.compile(r"```$")


def strip_sql_code_fences(query: str) -> str:
    """
    Unwrap SQL that the model returned inside a markdown code fence.

    Text that does not start with a fence (after trimming whitespace) is
    returned trimmed but otherwise untouched. For fenced text the opening
    fence, an optional "sql" tag, and a trailing closing fence are removed.
    """
    text = query.strip()

    # Nothing to unwrap unless the text opens with a fence.
    if not text.startswith("```"):
        return text

    without_open = _FENCE_OPEN.sub("", text).strip()
    return _FENCE_CLOSE.sub("", without_open).strip()
|
|
|
|
# Spans that are not executable SQL text: string literals, quoted
# identifiers and comments. They are matched in ONE left-to-right pass so a
# comment marker inside a string literal (or a quote inside a comment) cannot
# be used to hide a keyword from the checks below.
_SQL_NON_CODE = re.compile(
    r"'(?:[^']|'')*'"          # 'string literal' ('' escapes a quote)
    r'|"(?:[^"]|"")*"'         # "quoted identifier"
    r"|`(?:[^`]|``)*`"         # `quoted identifier`
    r"|\[[^\]]*\]"             # [quoted identifier]
    r"|--[^\n]*"               # line comment
    r"|/\*.*?(?:\*/|\Z)",      # block comment (unterminated: to end of text)
    re.DOTALL,
)
_SQL_WORD = re.compile(r"[a-z_][a-z0-9_]*")
_REPLACE_INTO = re.compile(r"\breplace\s+into\b")
_READ_ONLY_STARTS = frozenset({"select", "with", "pragma", "explain"})
_WRITE_KEYWORDS = frozenset(
    {
        "insert",
        "update",
        "delete",
        "drop",
        "alter",
        "create",
        "truncate",
        "attach",
        "detach",
        "vacuum",
        "reindex",
    }
)


def is_read_only_sql(query: str) -> bool:
    """
    Best-effort check that a query only reads from the database.

    Allows a single statement starting with SELECT, WITH, PRAGMA, or EXPLAIN.
    Blocks INSERT, UPDATE, DELETE, DROP, ALTER, CREATE, REPLACE INTO,
    TRUNCATE, ATTACH, DETACH, VACUUM and REINDEX anywhere in the statement,
    PRAGMA assignments ("PRAGMA name = value"), and several statements
    chained with ";".

    Keywords are matched as whole words regardless of the whitespace around
    them, and text inside string literals, quoted identifiers and comments is
    ignored. A column such as "last_update", a literal such as 'update me'
    or a call to the replace() function therefore does not cause a rejection.

    This is a textual filter, not a SQL parser. Treat it as a first line of
    defence only.

    Args:
        query: SQL text, optionally wrapped in a markdown code fence.

    Returns:
        True if the statement looks read-only, False otherwise.
    """
    code = _SQL_NON_CODE.sub(" ", strip_sql_code_fences(query)).lower()

    # A trailing ";" is fine; one followed by more SQL is a second statement.
    code = code.strip().rstrip("; \t\r\n")
    if ";" in code:
        return False

    first_word = _SQL_WORD.match(code)
    if first_word is None or first_word.group() not in _READ_ONLY_STARTS:
        return False

    if _WRITE_KEYWORDS.intersection(_SQL_WORD.findall(code)):
        return False

    # "replace" alone is also a harmless scalar function, so only the
    # statement form "REPLACE INTO" is rejected ("INSERT OR REPLACE" is
    # already caught by "insert").
    if _REPLACE_INTO.search(code):
        return False

    # "PRAGMA name = value" changes a setting instead of reading it.
    if first_word.group() == "pragma" and "=" in code:
        return False

    return True
|
|
|
|
def rows_to_markdown(columns, rows, max_rows: int = 50) -> str:
    """
    Convert SQL rows to a Markdown table for readable chatbot output.

    Args:
        columns: Column names, used as the table header.
        rows: Sequence of row tuples; only the first ``max_rows`` are rendered.
        max_rows: Maximum number of body rows to include.

    Returns:
        The Markdown table, or a short notice when ``rows`` is empty.
    """
    if not rows:
        return "Query executed successfully, but returned no rows."

    def clean_cell(value) -> str:
        """Render one value so it cannot break the table layout."""
        if value is None:
            return ""
        text = str(value)
        # Any kind of line break would end the table row early.
        for line_break in ("\r\n", "\r", "\n"):
            text = text.replace(line_break, " ")
        return text.replace("|", "\\|")

    def render_row(cells) -> str:
        return "| " + " | ".join(clean_cell(cell) for cell in cells) + " |"

    # Header cells need the same escaping as data cells: SQLite names an
    # unaliased expression column after its text, e.g. "a || b".
    header = render_row(columns)
    separator = "| " + " | ".join(["---"] * len(columns)) + " |"
    body_lines = [render_row(row) for row in rows[:max_rows]]

    return "\n".join([header, separator] + body_lines)
|
|
|
|
| |
| |
| |
|
|
@tool
def execute_sql(query: str) -> str:
    """
    Execute a read-only SQLite SQL query against the Chinook database.

    Use this tool when the user asks analytical questions that require database access.
    Only SELECT, WITH, PRAGMA, and EXPLAIN queries are allowed.
    """
    # NOTE(review): the docstring wording above is left unchanged on purpose;
    # the @tool decorator presumably exposes it to the model as the tool
    # description, which would make it part of the prompt - confirm.

    # Row cap for the rendered table; the full row count is still reported.
    max_rows = 50

    query = strip_sql_code_fences(query)

    if not is_read_only_sql(query):
        return (
            "Blocked for safety. Only read-only SQL is allowed. "
            "Please use SELECT, WITH, PRAGMA, or EXPLAIN queries."
        )

    # Errors are returned as text instead of raised, so the agent can read
    # the message and retry with a corrected query.
    try:
        # Open the file read-only at the SQLite level. The textual check
        # above is best effort only; with mode=ro the engine itself refuses
        # any write that slips through.
        read_only_uri = f"{DB_PATH.resolve().as_uri()}?mode=ro"
        conn = sqlite3.connect(read_only_uri, uri=True)
        try:
            cursor = conn.execute(query)
            rows = cursor.fetchall()
            description = cursor.description
        finally:
            # Release the connection even when the query fails.
            conn.close()

        columns = [column[0] for column in description] if description else []

        if not columns:
            return "Query executed successfully."

        result_table = rows_to_markdown(columns, rows, max_rows)

        if len(rows) > max_rows:
            result_table += (
                f"\n\nShowing first {max_rows} rows out of {len(rows)} rows."
            )

        return result_table

    except Exception as e:
        return f"SQL execution error: {str(e)}"
|
|
|
|
| |
| |
| |
|
|
# System prompt handed to the agent. It is an f-string, so the schema text in
# DATABASE_SCHEMA is embedded once, when this module is imported.
SYSTEM_PROMPT = f"""
You are a helpful SQL data analyst for the Chinook SQLite database.

Your job:
- Understand the user's business/data question.
- Write correct SQLite queries.
- Use the execute_sql tool to query the database.
- Explain the result clearly and concisely.
- For follow-up questions, use the conversation memory.

Important rules:
- Use only read-only SQL.
- Never modify the database.
- Prefer clear SQL with explicit table joins.
- When useful, explain the SQL logic briefly.
- If the user asks a vague question, make a reasonable interpretation and proceed.
- If the database does not contain enough information, say that clearly.

Available database schema:
{DATABASE_SCHEMA}
"""
|
|
|
|
| |
| |
| |
| |
| |
|
|
# Checkpoint store handed to the agent below.
# NOTE(review): InMemorySaver presumably keeps conversation state in process
# memory only (lost on restart) - confirm against the LangGraph docs.
checkpointer = InMemorySaver()


# The agent: the configured model, execute_sql as its only tool, the system
# prompt built above, and the checkpointer for per-thread conversation state.
sql_agent_with_memory = create_agent(
    model=MODEL_NAME,
    tools=[execute_sql],
    system_prompt=SYSTEM_PROMPT,
    checkpointer=checkpointer,
)
|
|
|
|
| |
| |
| |
|
|
def content_to_text(content):
    """
    Turn LangChain message content into text that can be displayed.

    Strings are returned unchanged. For a list, each element contributes one
    part and the parts are joined with newlines. A dict element contributes
    its "text" value if present, otherwise the string form of its "content"
    value, otherwise the string form of the dict itself; any other element
    contributes its string form. All other content is passed through str().
    """

    def part_text(part):
        # "text" takes precedence over "content" inside a dict part.
        if not isinstance(part, dict):
            return str(part)
        if "text" in part:
            return part["text"]
        if "content" in part:
            return str(part["content"])
        return str(part)

    if isinstance(content, str):
        return content

    if not isinstance(content, list):
        return str(content)

    return "\n".join(part_text(part) for part in content)
|
|
|
|
def create_thread_id():
    """
    Build a new, random thread identifier.

    Reusing a thread_id continues the same LangGraph memory; a new thread_id
    starts a fresh conversation.
    """
    prefix = "dds-sql-agent-"
    return prefix + str(uuid4())
|
|
|
|
def normalize_history_to_messages(history):
    """
    Reduce chat history to the messages format Gradio expects:
    [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}
    ]

    Entries that are not dicts, lack a "role" or "content" key, or carry a
    role other than "user"/"assistant" are dropped. The content of every
    kept entry is converted with content_to_text(). None yields [].
    """
    if history is None:
        return []

    return [
        {
            "role": entry["role"],
            "content": content_to_text(entry["content"]),
        }
        for entry in history
        if isinstance(entry, dict)
        and "role" in entry
        and "content" in entry
        and entry["role"] in ("user", "assistant")
    ]
|
|
|
|
| |
| |
| |
|
|
def chat_with_sql_agent(message, history, thread_id):
    """
    Handle one user message coming from the Gradio UI.

    The history is returned in messages format (role/content dicts) without
    passing type="messages" to gr.Chatbot, because some Gradio 6 runtimes
    expect messages but do not accept the type argument.

    Args:
        message: Text from the question textbox (may be None or blank).
        history: Current chatbot history as delivered by Gradio.
        thread_id: Memory thread of this session; a new one is created when
            it is empty.

    Returns:
        A tuple ``(history, textbox_value, thread_id)``: the updated history,
        an empty string that clears the textbox, and the thread id to keep
        in session state.
    """
    history = normalize_history_to_messages(history)
    thread_id = thread_id or create_thread_id()
    user_message = (message or "").strip()

    # Ignore blank submissions before anything else, so they never add an
    # empty user bubble - also not when the API key is missing.
    if not user_message:
        return history, "", thread_id

    if not OPENAI_API_KEY:
        assistant_message = (
            "OPENAI_API_KEY is missing. In Hugging Face Spaces, go to "
            "Settings → Variables and Secrets → New Secret, then add:\n\n"
            "`OPENAI_API_KEY = your_openai_api_key`"
        )
    else:
        try:
            # The thread_id selects which stored conversation the agent
            # continues.
            result = sql_agent_with_memory.invoke(
                {"messages": [{"role": "user", "content": user_message}]},
                config={"configurable": {"thread_id": thread_id}},
            )
            # The last message of the run is shown as the agent's reply.
            assistant_message = content_to_text(result["messages"][-1].content)
        except Exception as e:
            # UI boundary: show the failure in the chat instead of raising.
            assistant_message = (
                "\n"
                "Something went wrong while running the SQL agent.\n"
                "\n"
                "Error:\n"
                "\n"
                "```text\n"
                f"{e!s}\n"
                "```\n"
                "\n"
                "Check:\n"
                "1. OPENAI_API_KEY is set in Hugging Face Secrets.\n"
                "2. MODEL_NAME is available in your OpenAI account.\n"
                f"3. The SQLite database file exists at: `{DB_PATH}`\n"
            )

    updated_history = history + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_message},
    ]

    return updated_history, "", thread_id
|
|
|
|
def reset_chat():
    """
    Start over: return an empty chat history together with a brand-new
    thread id, so the next message begins a fresh memory thread.
    """
    empty_history = []
    fresh_thread_id = create_thread_id()
    return empty_history, fresh_thread_id
|
|
|
|
def example_question(question: str) -> str:
    """
    Return the given example question unchanged.

    Used as the click handler of the example buttons: the returned text is
    written into the question textbox.
    """
    return question
|
|
|
|
| |
| |
| |
|
|
# Page-level CSS: centres the main column and styles small notes.
custom_css = """
#main-container {
max-width: 1100px;
margin: 0 auto;
}

.dds-note {
font-size: 0.95rem;
opacity: 0.85;
}
"""


# Example prompts offered as one-click buttons below the chat.
example_questions = (
    "Which customer spent the most money?",
    "Show total sales by country.",
    "Which genre has the most tracks?",
    "What are the top-selling tracks?",
)


with gr.Blocks(title="DDS SQL Agent", css=custom_css) as demo:
    # Start every browser session WITHOUT a thread id. Calling
    # create_thread_id() here would run once while the UI is built, handing
    # the same id - and with it the same agent memory - to every visitor.
    # chat_with_sql_agent() creates a per-session id on the first message.
    thread_id_state = gr.State(value=None)

    with gr.Column(elem_id="main-container"):
        gr.Markdown(
            "\n"
            "# DDS SQL Agent with Memory\n"
            "\n"
            "Ask questions about the Chinook SQLite database.\n"
            "The agent can generate SQL, execute read-only queries, and remember follow-up questions in the same session.\n"
            "\n"
            f"**Model:** `{MODEL_NAME}`\n"
            f"**Database:** `{DB_PATH}`\n"
        )

        # Setup hint, rendered only when the key is missing at build time.
        if not OPENAI_API_KEY:
            gr.Markdown(
                "\n"
                "> **Setup needed:** `OPENAI_API_KEY` is not set.\n"
                "> Add it in Hugging Face Spaces under **Settings → Variables and Secrets → New Secret**.\n"
            )

        chatbot = gr.Chatbot(
            value=[],
            height=560,
            label="SQL Agent Chat",
            placeholder="Ask a question about the database...",
        )

        with gr.Row():
            user_input = gr.Textbox(
                placeholder="Example: Which customer spent the most money?",
                label="Your question",
                scale=8,
            )

            submit_btn = gr.Button(
                "Ask",
                scale=1,
                variant="primary",
            )

        with gr.Row():
            clear_btn = gr.Button("New Chat / Reset Memory")

        gr.Markdown("### Example questions")

        with gr.Row():
            example_buttons = [gr.Button(text) for text in example_questions]

        # Keep the individual button names available at module level.
        ex1, ex2, ex3, ex4 = example_buttons

    # Clicking an example copies its text into the question textbox.
    for example_btn, example_text in zip(example_buttons, example_questions):
        example_btn.click(
            example_question,
            inputs=[gr.State(example_text)],
            outputs=[user_input],
        )

    # The "Ask" button and pressing Enter in the textbox do the same thing.
    chat_event = {
        "fn": chat_with_sql_agent,
        "inputs": [user_input, chatbot, thread_id_state],
        "outputs": [chatbot, user_input, thread_id_state],
    }
    submit_btn.click(**chat_event)
    user_input.submit(**chat_event)

    clear_btn.click(
        fn=reset_chat,
        inputs=[],
        outputs=[chatbot, thread_id_state],
    )
|
|
|
|
| |
| |
| |
|
|
# Launch the app (with queuing enabled) only when this file is run as a
# script, not when it is imported.
if __name__ == "__main__":
    demo.queue().launch()
|
|