-
Notifications
You must be signed in to change notification settings - Fork 16.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Langchain_Community: SQL LanguageParser (#28430)
## Description (This PR has contributions from @khushiDesai, @ashvini8, and @ssumaiyaahmed). This PR addresses **Issue #11229**, which requests SQL support in document parsing. The support is integrated into the generic TreeSitter parsing library, allowing LangChain users to easily load SQL codebases into smaller, manageable "documents." This pull request adds a new `SQLSegmenter` class, which provides the SQL integration. ## Issue **Issue #11229**: Add support for a variety of languages to LanguageParser ## Testing We created a file `test_sql.py` with several tests to ensure that `SQLSegmenter` is functional. Below are the tests we added: - `def test_is_valid`: Checks SQL validity. - `def test_extract_functions_classes`: Extracts individual SQL statements. - `def test_simplify_code`: Simplifies SQL code with comments. --------- Co-authored-by: Syeda Sumaiya Ahmed <114104419+ssumaiyaahmed@users.noreply.github.com> Co-authored-by: ashvini hunagund <97271381+ashvini8@users.noreply.github.com> Co-authored-by: Khushi Desai <khushi.desai@advantawitty.com> Co-authored-by: Khushi Desai <59741309+khushiDesai@users.noreply.github.com> Co-authored-by: ccurme <chester.curme@gmail.com>
- Loading branch information
1 parent
a7f2148
commit 26bdf40
Showing
3 changed files
with
131 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
65 changes: 65 additions & 0 deletions
65
libs/community/langchain_community/document_loaders/parsers/language/sql.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from typing import TYPE_CHECKING | ||
|
||
from langchain_community.document_loaders.parsers.language.tree_sitter_segmenter import ( # noqa: E501 | ||
TreeSitterSegmenter, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from tree_sitter import Language | ||
|
||
# Tree-sitter query returned by SQLSegmenter.get_chunk_query(). It captures | ||
# CREATE TABLE, SELECT, INSERT, UPDATE and DELETE statement nodes, each under | ||
# its own capture name (@create, @select, @insert, @update, @delete). | ||
CHUNK_QUERY = """ | ||
[ | ||
(create_table_statement) @create | ||
(select_statement) @select | ||
(insert_statement) @insert | ||
(update_statement) @update | ||
(delete_statement) @delete | ||
] | ||
""" | ||
|
||
|
||
class SQLSegmenter(TreeSitterSegmenter):
    """Code segmenter for SQL.

    Uses Tree-sitter to split SQL source into the statements selected by
    ``CHUNK_QUERY`` (CREATE TABLE, SELECT, INSERT, UPDATE and DELETE) and
    to render a simplified, comment-only summary of those statements.
    """

    def get_language(self) -> "Language":
        """Return the SQL language grammar for Tree-sitter."""
        # Imported inside the method so that importing this module does not
        # itself require ``tree_sitter_languages``.
        from tree_sitter_languages import get_language

        return get_language("sql")

    def get_chunk_query(self) -> str:
        """Return the Tree-sitter query used to segment SQL code."""
        return CHUNK_QUERY

    def extract_functions_classes(self) -> list[str]:
        """Extract the SQL statements matched by the chunk query.

        Returns:
            The chunks produced by the parent implementation, stripped of
            surrounding whitespace. A semicolon is appended to any
            statement that does not already end with one, so every
            returned statement is terminated consistently.
        """
        statements = []
        for chunk in super().extract_functions_classes():
            stmt = chunk.strip()
            statements.append(stmt if stmt.endswith(";") else f"{stmt};")
        return statements

    def simplify_code(self) -> str:
        """Summarize the extracted SQL statements as SQL line comments.

        Each statement becomes a ``-- Code for: <statement>`` comment.
        When a statement spans several lines, every continuation line is
        commented as well, so the result never contains live SQL
        fragments. Blank lines inside a statement are dropped.

        Returns:
            The comment lines joined with newlines; an empty string when
            no statement was extracted.
        """
        comment_lines = []
        for stmt in self.extract_functions_classes():
            # ``stmt`` always ends with ";" and is therefore never empty.
            first_line, *continuation = stmt.splitlines()
            comment_lines.append(self.make_line_comment(f"Code for: {first_line}"))
            comment_lines.extend(
                self.make_line_comment(line) for line in continuation if line.strip()
            )
        return "\n".join(comment_lines)

    def make_line_comment(self, text: str) -> str:
        """Return ``text`` formatted as a SQL line comment (``-- text``)."""
        return f"-- {text}"
61 changes: 61 additions & 0 deletions
61
libs/community/tests/unit_tests/document_loaders/parsers/language/test_sql.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import unittest | ||
|
||
import pytest | ||
|
||
from langchain_community.document_loaders.parsers.language.sql import SQLSegmenter | ||
|
||
|
||
@pytest.mark.requires("tree_sitter", "tree_sitter_languages") | ||
class TestSQLSegmenter(unittest.TestCase): | ||
"""Unit tests for the SQLSegmenter class.""" | ||
|
||
def setUp(self) -> None: | ||
"""Set up example code and expected results for testing.""" | ||
# One statement of each kind (CREATE TABLE, SELECT, INSERT, UPDATE, | ||
# DELETE) plus a standalone SQL comment line. | ||
self.example_code = """ | ||
CREATE TABLE users (id INT, name TEXT); | ||
-- A select query | ||
SELECT id, name FROM users WHERE id = 1; | ||
INSERT INTO users (id, name) VALUES (2, 'Alice'); | ||
UPDATE users SET name = 'Bob' WHERE id = 2; | ||
DELETE FROM users WHERE id = 2; | ||
""" | ||
|
||
# One "-- Code for: ..." line per statement, in source order; the | ||
# standalone "-- A select query" comment is not expected in the output. | ||
self.expected_simplified_code = ( | ||
"-- Code for: CREATE TABLE users (id INT, name TEXT);\n" | ||
"-- Code for: SELECT id, name FROM users WHERE id = 1;\n" | ||
"-- Code for: INSERT INTO users (id, name) VALUES (2, 'Alice');\n" | ||
"-- Code for: UPDATE users SET name = 'Bob' WHERE id = 2;\n" | ||
"-- Code for: DELETE FROM users WHERE id = 2;" | ||
) | ||
|
||
# The statements themselves, in source order, each ending with a | ||
# semicolon and without surrounding whitespace. | ||
self.expected_extracted_code = [ | ||
"CREATE TABLE users (id INT, name TEXT);", | ||
"SELECT id, name FROM users WHERE id = 1;", | ||
"INSERT INTO users (id, name) VALUES (2, 'Alice');", | ||
"UPDATE users SET name = 'Bob' WHERE id = 2;", | ||
"DELETE FROM users WHERE id = 2;", | ||
] | ||
|
||
def test_is_valid(self) -> None: | ||
"""Test the validity of SQL code.""" | ||
# Valid SQL code should return True | ||
self.assertTrue(SQLSegmenter("SELECT * FROM test").is_valid()) | ||
# Invalid code (non-SQL text) should return False | ||
self.assertFalse(SQLSegmenter("random text").is_valid()) | ||
|
||
def test_extract_functions_classes(self) -> None: | ||
"""Test extracting SQL statements from code.""" | ||
segmenter = SQLSegmenter(self.example_code) | ||
extracted_code = segmenter.extract_functions_classes() | ||
# Verify the extracted code matches expected SQL statements | ||
self.assertEqual(extracted_code, self.expected_extracted_code) | ||
|
||
def test_simplify_code(self) -> None: | ||
"""Test simplifying SQL code into commented descriptions.""" | ||
segmenter = SQLSegmenter(self.example_code) | ||
simplified_code = segmenter.simplify_code() | ||
# Verify the simplified output is exactly the expected comment lines | ||
self.assertEqual(simplified_code, self.expected_simplified_code) |