Skip to content

Commit

Permalink
feat: add dataset distillation extraction (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
fynnfluegge authored Feb 15, 2025
1 parent 9a03df7 commit 6629556
Show file tree
Hide file tree
Showing 30 changed files with 4,255 additions and 1,712 deletions.
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

<div align="center">

Search your codebase semantically or chat with it from cli. Keep the vector database superfast up to date to the latest code changes.
Generate datasets from code for finetuning, search your codebase semantically or chat with your code from cli. Keep the vector database superfast up to date to the latest code changes.
100% local support without any dataleaks.
Built with [langchain](https://github.com/langchain-ai/langchain), [treesitter](https://github.com/tree-sitter/tree-sitter), [sentence-transformers](https://github.com/UKPLab/sentence-transformers), [instructor-embedding](https://github.com/xlang-ai/instructor-embedding),
[faiss](https://github.com/facebookresearch/faiss), [lama.cpp](https://github.com/ggerganov/llama.cpp), [Ollama](https://github.com/jmorganca/ollama), [Streamlit](https://github.com/streamlit/streamlit).
Expand All @@ -19,6 +19,8 @@ Built with [langchain](https://github.com/langchain-ai/langchain), [treesitter](

## ✨ Features

- 🗒️ &nbsp;Finetuning dataset generation
- export in Alpaca, conversational, instruction or completionn format
- 🔎 &nbsp;Semantic code search
- 💬 &nbsp;GPT-like chat with your codebase
- ⚙️ &nbsp;Synchronize vector store and latest code changes with ease
Expand All @@ -32,6 +34,19 @@ Built with [langchain](https://github.com/langchain-ai/langchain), [treesitter](
## 🚀 Usage

#### Export finetuning dataset from codebase in conversational format:
```
codeqai dataset
```
Export in different format like Alpaca with:
```
codeqai dataset --format alpaca
```
Export dataset with model distillation
```
codeqai dataset --distillation doc
```

#### Start semantic search:

```
Expand Down
48 changes: 39 additions & 9 deletions codeqai/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from codeqai.bootstrap import bootstrap
from codeqai.cache import create_cache_dir, get_cache_path, save_vector_cache
from codeqai.config import create_config, get_config_path, load_config
from codeqai.constants import EmbeddingsModel, LlmHost
from codeqai.constants import DistillationMode, EmbeddingsModel, LlmHost
from codeqai.dataset_extractor import DatasetExtractor
from codeqai.embeddings import Embeddings
from codeqai.vector_store import VectorStore

Expand Down Expand Up @@ -81,20 +82,33 @@ def run():
"chat",
"configure",
"sync",
"export-dataset (experimental)",
"dataset",
],
help="Action to perform. 'search' will semantically search the codebase. 'chat' will chat with the codebase.",
help="Action to perform. 'app' to start the streamlit app, 'search' to search the codebase, "
+ "'chat' to chat with the model, 'configure' to start config wizard, "
+ "'sync' to sync the vector store with the current git checkout, 'dataset' to export a dataset for model distillation.",
)
parser.add_argument(
"--distillation",
action="store_true",
help="Use model distillation for finetuning dataset extraction.",
type=DistillationMode,
default=DistillationMode.NONE,
help="Use model distillation for finetuning dataset extraction. Default is None."
+ "Supported modes are, 'full', 'doc', 'code'.\n"
+ "doc - Extracts only documentation for distillation.\n"
+ "code - Extracts will chunk code blocks with inlined comments for distillation.\n"
+ "full - Uses both doc and code mode",
)
parser.add_argument(
"--format",
type=str,
default="Conversational",
help="Format of the finetuning dataset. Supported formats are Conversational and Alpaca. Default is Conversational format.",
default="conversational",
help="Format of the finetuning dataset. Supported formats are conversational and alpaca. Default is Conversational format.",
)
parser.add_argument(
"--max-tokens",
type=int,
default=1024,
help="Token limit per code block for distillation dataset extraction. Default is 1024.",
)
args = parser.parse_args()

Expand Down Expand Up @@ -149,10 +163,26 @@ def run():
),
)

if args.action == "extract-dataset":
if args.action == "dataset":
print(args.distillation)
spinner = yaspin(
text=f"Parsing codebase for {args.format} dataset export...",
color="green",
)
spinner.start()
repo_name = repo.repo_name()
files = repo.load_files()
documents = codeparser.parse_code_files_for_finetuning(files)
documents = codeparser.parse_code_files_for_finetuning(
files, args.max_tokens, spinner
)
dateset_extractor = DatasetExtractor(
args.format,
args.distillation,
documents,
config,
args.max_tokens,
)
dateset_extractor.export()
exit()

# check if faiss.index exists
Expand Down
11 changes: 11 additions & 0 deletions codeqai/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@


def bootstrap(config, repo_name, embeddings_model=None):
"""
Initializes the necessary components for the application.
Args:
config (dict): Configuration dictionary containing settings for embeddings and LLM.
repo_name (str): The name of the repository.
embeddings_model (Embeddings, optional): Pre-initialized embeddings model. Defaults to None.
Returns:
tuple: A tuple containing the vector store, memory, and QA chain.
"""
if embeddings_model is None:
embeddings_model = Embeddings(
model=EmbeddingsModel[config["embeddings"].upper().replace("-", "_")],
Expand Down
31 changes: 31 additions & 0 deletions codeqai/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ def to_json(self):


def load_vector_cache(filename) -> Dict[str, VectorCache]:
"""
Loads a vector cache from a JSON file.
Args:
filename (str): The name of the file containing the vector cache.
Returns:
Dict[str, VectorCache]: A dictionary where the keys are strings and the values are VectorCache objects.
"""
with open(
get_cache_path() + "/" + filename, "r", encoding="utf-8"
) as vector_cache_file:
Expand All @@ -38,13 +47,29 @@ def load_vector_cache(filename) -> Dict[str, VectorCache]:


def save_vector_cache(vector_cache, filename):
"""
Saves a vector cache to a JSON file.
Args:
vector_cache (Dict[str, VectorCache]): A dictionary where the keys are strings and the values are VectorCache objects.
filename (str): The name of the file to save the vector cache to.
"""
with open(
get_cache_path() + "/" + filename, "w", encoding="utf-8"
) as vector_cache_file:
json.dump(vector_cache, default=VectorCache.to_json, fp=vector_cache_file)


def get_cache_path():
"""
Returns the cache directory path based on the operating system.
Returns:
str: The path to the cache directory.
Raises:
NotImplementedError: If the operating system is not supported.
"""
system = platform.system()

if system == "Linux" or system == "Darwin":
Expand All @@ -60,6 +85,12 @@ def get_cache_path():


def create_cache_dir():
"""
Creates the cache directory if it does not already exist.
This function checks if the cache directory exists at the path returned by get_cache_path().
If the directory does not exist, it creates the directory and any necessary parent directories.
"""
if not os.path.exists(get_cache_path()):
path = Path(get_cache_path())
path.mkdir(parents=True, exist_ok=True)
54 changes: 51 additions & 3 deletions codeqai/codeparser.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
import ast
import os

import inquirer
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from yaspin import yaspin

from codeqai import repo, utils
from codeqai.constants import Language
from codeqai.treesitter.treesitter import Treesitter, TreesitterMethodNode


def parse_code_files_for_db(code_files: list[str]) -> list[Document]:
"""
Parses a list of code files and returns a list of Document objects for database storage.
Args:
code_files (list[str]): List of paths to code files to be parsed.
Returns:
list[Document]: List of Document objects containing parsed code information.
"""
documents = []
code_splitter = None
for code_file in code_files:
Expand Down Expand Up @@ -60,7 +70,21 @@ def parse_code_files_for_db(code_files: list[str]) -> list[Document]:
return documents


def parse_code_files_for_finetuning(code_files: list[str]) -> list[dict]:
def parse_code_files_for_finetuning(
code_files: list[str], max_tokens, spinner
) -> list[dict]:
"""
Parses a list of code files for fine-tuning and returns a list of dictionaries containing method information.
Args:
code_files (list[str]): List of paths to code files to be parsed.
max_tokens (int): Maximum number of tokens allowed for output.
Returns:
list[dict]: List of dictionaries containing method information, including method name, code, description, and language.
"""
input_tokens = 0
output_tokens = 0
documents = []
for code_file in code_files:
with open(code_file, "r", encoding="utf-8") as file:
Expand All @@ -84,10 +108,34 @@ def parse_code_files_for_finetuning(code_files: list[str]) -> list[dict]:
)

document = {
"method_name": node.name,
"code": method_source_code,
"description": node.doc_comment,
"language": programming_language,
"language": programming_language.value,
}
documents.append(document)

if node.doc_comment is not None:
input_tokens += utils.count_tokens(node.doc_comment)
output_tokens += max_tokens

spinner.stop()

print(f"Estimated input tokens for distillation needed: {input_tokens}.")
print(f"Maximum output tokens for distillation nedeed: {output_tokens}.")
questions = [
inquirer.Confirm(
"confirm",
message="Proceed?",
default=True,
),
]

confirm = inquirer.prompt(questions)

if confirm and confirm["confirm"]:
pass
else:
exit()

return documents
39 changes: 39 additions & 0 deletions codeqai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@


def get_config_path():
"""
Returns the configuration file path based on the operating system.
This function determines the appropriate configuration directory based on the operating system
and constructs the full path to the configuration file.
Returns:
str: The path to the configuration file.
Raises:
NotImplementedError: If the operating system is not supported.
"""
system = platform.system()

if system == "Linux" or system == "Darwin":
Expand All @@ -25,17 +37,44 @@ def get_config_path():


def load_config():
"""
Loads the configuration from the configuration file.
This function reads the configuration file specified by get_config_path() and parses its content
using the YAML parser.
Returns:
dict: The configuration dictionary loaded from the file.
"""
with open(get_config_path(), "r", encoding="utf-8") as config_file:
config = yaml.safe_load(config_file)
return config


def save_config(config):
"""
Saves the configuration to the configuration file.
Args:
config (dict): The configuration dictionary to be saved.
This function writes the provided configuration dictionary to the configuration file specified by get_config_path()
using the YAML format.
"""
with open(get_config_path(), "w", encoding="utf-8") as config_file:
yaml.dump(config, config_file, default_flow_style=False)


def create_config():
"""
Creates a new configuration interactively by prompting the user for input.
This function prompts the user with a series of questions to configure the embeddings model and LLM host.
Based on the user's responses, it constructs a configuration dictionary and saves it to the configuration file.
Returns:
dict: The configuration dictionary created based on user input.
"""
os.makedirs(os.path.dirname(get_config_path()), exist_ok=True)

questions = [
Expand Down
13 changes: 11 additions & 2 deletions codeqai/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,14 @@ class LlmHost(Enum):


class DatasetFormat(Enum):
ALPACA = "Alpaca"
CONVERSATIONAL = "Conversational"
ALPACA = "alpaca"
CONVERSATIONAL = "conversational"
INSTRUCTION = "instruction"
COMPLETION = "completion"


class DistillationMode(Enum):
NONE = "none"
FULL = "full"
DOCUMENTATION = "doc"
CODE = "code"
Loading

0 comments on commit 6629556

Please sign in to comment.