Skip to content

Commit

Permalink
feat: Haskell support added, Bugs fixed in the case of Azure_OpenAI (#28
Browse files Browse the repository at this point in the history
)

1. Fixed several bugs in the Azure_OpenAI code path.
2. Added support for the Haskell language.
3. Added Haskell support to the text_splitter module of
LangChain. Pull request for that change:
langchain-ai/langchain#16191
  • Loading branch information
Nisarg1112 authored Jan 18, 2024
1 parent 4f928c0 commit f6f3907
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 9 deletions.
9 changes: 6 additions & 3 deletions codeqai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def create_config():
deployment_answer = inquirer.prompt(questions)
if deployment_answer and deployment_answer["deployment"]:
config["model-deployment"] = deployment_answer["deployment"]
config["chat-model"] = deployment_answer["deployment"]

elif config["llm-host"] == LlmHost.LLAMACPP.value:
questions = [
Expand Down Expand Up @@ -181,9 +182,11 @@ def create_config():
),
]

answersChatmodel = inquirer.prompt(questions)
if answersChatmodel and answersChatmodel["chat-model"]:
config["chat-model"] = answersChatmodel["chat-model"]
# Check if "chat-model" is already present in the case of Azure_OpenAI
if "chat-model" not in config:
answersChatmodel = inquirer.prompt(questions)
if answersChatmodel and answersChatmodel["chat-model"]:
config["chat-model"] = answersChatmodel["chat-model"]

save_config(config)

Expand Down
1 change: 1 addition & 0 deletions codeqai/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Language(Enum):
OBJECTIVE_C = "objective_c"
SCALA = "scala"
LUA = "lua"
HASKELL = "haskell"
UNKNOWN = "unknown"


Expand Down
7 changes: 4 additions & 3 deletions codeqai/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import inquirer
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.llms import AzureOpenAI, LlamaCpp, Ollama
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.llms import LlamaCpp, Ollama

from codeqai import utils
from codeqai.constants import LlmHost
Expand All @@ -19,7 +19,8 @@ def __init__(self, llm_host: LlmHost, chat_model: str, deployment=None):
temperature=0.9, max_tokens=2048, model=chat_model
)
elif llm_host == LlmHost.AZURE_OPENAI and deployment:
self.chat_model = AzureOpenAI(
self.chat_model = AzureChatOpenAI(
openai_api_base=os.getenv("OPENAI_API_BASE"),
temperature=0.9,
max_tokens=2048,
deployment_name=deployment,
Expand Down
1 change: 1 addition & 0 deletions codeqai/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,5 @@ def get_commit_hash(file_path):
".yml",
".rst",
".md",
".hs",
]
14 changes: 12 additions & 2 deletions codeqai/treesitter/treesitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ def __init__(
self,
name: "str | bytes | None",
doc_comment: "str | None",
method_source_code: "str | None",
node: tree_sitter.Node,
):
self.name = name
self.doc_comment = doc_comment
self.method_source_code = node.text.decode()
self.method_source_code = method_source_code or node.text.decode()
self.node = node


Expand Down Expand Up @@ -46,7 +47,7 @@ def parse(self, file_bytes: bytes) -> list[TreesitterMethodNode]:
method_name = self._query_method_name(method["method"])
doc_comment = method["doc_comment"]
result.append(
TreesitterMethodNode(method_name, doc_comment, method["method"])
TreesitterMethodNode(method_name, doc_comment, None, method["method"])
)
return result

Expand All @@ -62,6 +63,15 @@ def _query_all_methods(
and node.prev_named_sibling.type == self.doc_comment_identifier
):
doc_comment_node = node.prev_named_sibling.text.decode()
else:
# added for haskell purpose.
if node.prev_named_sibling and node.prev_named_sibling.type == "signature":
prev_node = node.prev_named_sibling
if (
prev_node.prev_named_sibling
and prev_node.prev_named_sibling.type == self.doc_comment_identifier
):
doc_comment_node = prev_node.prev_named_sibling.text.decode()
methods.append({"method": node, "doc_comment": doc_comment_node})
else:
for child in node.children:
Expand Down
79 changes: 79 additions & 0 deletions codeqai/treesitter/treesitter_hs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import tree_sitter
from typing import List, Dict


from codeqai.constants import Language
from codeqai.treesitter.treesitter import (Treesitter,
TreesitterMethodNode)
from codeqai.treesitter.treesitter_registry import TreesitterRegistry

class TreesitterHaskell(Treesitter):
    """Extract Haskell functions from a tree-sitter parse tree.

    Every ``function`` node is treated as one equation of a function, and a
    ``signature`` node directly preceding it as that function's type
    signature. Consecutive entries that resolve to the same name are combined
    into a single method whose source is the signature (if any) followed by
    all of its equations.
    """

    def __init__(self):
        # This class reads the identifiers back as
        # method_declaration_identifier ("function"), method_name_identifier
        # ("variable") and doc_comment_identifier ("comment") -- presumably
        # in this positional order; confirm against Treesitter.__init__.
        super().__init__(Language.HASKELL, "function", "variable", "comment")

    def parse(self, file_bytes: bytes) -> list[TreesitterMethodNode]:
        """Parse ``file_bytes`` and return one node per Haskell function.

        Args:
            file_bytes: Raw contents of a Haskell source file.

        Returns:
            One ``TreesitterMethodNode`` per function found. When a function
            consists of several nodes (a signature plus equations, or several
            equations), its source code is the text of those nodes joined
            with newlines; otherwise ``None`` is passed as the source code.
        """
        self.tree = self.parser.parse(file_bytes)
        result = []
        for method in self._query_all_methods(self.tree.root_node):
            head = method["method"]
            clauses = method.get("clauses", [])
            source_code = None
            if clauses:
                # The head node's own text only covers the signature (or the
                # first equation), so the remaining pieces are stitched on.
                pieces = [head.text.decode()]
                pieces.extend(clause.text.decode() for clause in clauses)
                source_code = "\n".join(pieces)
            result.append(
                TreesitterMethodNode(
                    self._query_method_name(head),
                    method["doc_comment"],
                    source_code,
                    head,
                )
            )
        return result

    def _query_all_methods(
        self,
        node: tree_sitter.Node,
    ):
        """Collect function entries from ``node`` and its descendants.

        Args:
            node: Root of the subtree to search.

        Returns:
            A list of dicts with the keys ``"method"`` (the signature node
            when one directly precedes the function, else the function node),
            ``"doc_comment"`` (decoded comment text or ``None``) and
            ``"clauses"`` (further nodes belonging to the same function, in
            source order). The tree-sitter nodes themselves are not modified.
        """
        methods: List[dict] = []
        if node.type == self.method_declaration_identifier:
            doc_comment = None
            clauses = []
            head = node
            sibling = node.prev_named_sibling
            if sibling is not None and sibling.type == self.doc_comment_identifier:
                doc_comment = sibling.text.decode()
            elif sibling is not None and sibling.type == "signature":
                # With a signature in between, the doc comment (if any) sits
                # in front of the signature rather than the equation.
                before_signature = sibling.prev_named_sibling
                if (
                    before_signature is not None
                    and before_signature.type == self.doc_comment_identifier
                ):
                    doc_comment = before_signature.text.decode()
                # The signature becomes the head of the method and the
                # equation is recorded as its first clause.
                clauses.append(node)
                head = sibling
            methods.append(
                {"method": head, "doc_comment": doc_comment, "clauses": clauses}
            )
            return methods

        for child in node.children:
            for entry in self._query_all_methods(child):
                if methods and self._is_same_function(methods[-1], entry):
                    # Another equation of the function collected last: fold
                    # it into that entry instead of reporting it separately.
                    merged = methods[-1]["clauses"]
                    merged.append(entry["method"])
                    merged.extend(entry.get("clauses", []))
                else:
                    methods.append(entry)
        return methods

    def _is_same_function(self, first: dict, second: dict) -> bool:
        """Return True when both entries resolve to the same, known name."""
        first_name = self._query_method_name(first["method"])
        if first_name is None:
            # Unresolved names must not match each other; otherwise unrelated
            # definitions would be merged.
            return False
        return first_name == self._query_method_name(second["method"])

    def _query_method_name(self, node: tree_sitter.Node):
        """Return the name of a signature or function node, else ``None``.

        The name is the decoded text of the first child whose type equals
        ``self.method_name_identifier``.
        """
        if node.type not in ("signature", self.method_declaration_identifier):
            return None
        for child in node.children:
            if child.type == self.method_name_identifier:
                return child.text.decode()
        return None


# Register TreesitterHaskell under Language.HASKELL in the registry.
# This is a module-level statement, so it runs when this module is imported.
TreesitterRegistry.register_treesitter(Language.HASKELL, TreesitterHaskell)
2 changes: 1 addition & 1 deletion codeqai/treesitter/treesitter_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def parse(self, file_bytes: bytes) -> list[TreesitterMethodNode]:
for method in methods:
method_name = self._query_method_name(method)
doc_comment = self._query_doc_comment(method)
result.append(TreesitterMethodNode(method_name, doc_comment, method))
result.append(TreesitterMethodNode(method_name, doc_comment, None, method))
return result

def _query_method_name(self, node: tree_sitter.Node):
Expand Down
3 changes: 3 additions & 0 deletions codeqai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def get_programming_language(file_extension: str) -> Language:
".cpp": Language.CPP,
".c": Language.C,
".cs": Language.C_SHARP,
".hs": Language.HASKELL,
}
return language_mapping.get(file_extension, Language.UNKNOWN)

Expand Down Expand Up @@ -69,6 +70,8 @@ def get_langchain_language(language: Language) -> text_splitter.Language:
return text_splitter.Language.CPP
elif language == Language.C_SHARP:
return text_splitter.Language.CSHARP
elif language == Language.HASKELL:
return text_splitter.Language.HASKELL # PR for Haskell support in text_splitter module - https://github.com/langchain-ai/langchain/pull/16191
else:
return text_splitter.Language.UNKNOWN

Expand Down

0 comments on commit f6f3907

Please sign in to comment.