Skip to content

Commit

Permalink
Add support for embeddings
Browse files Browse the repository at this point in the history
Signed-off-by: Tomas Slusny <slusnucky@gmail.com>
  • Loading branch information
deathbeam committed Mar 3, 2024
1 parent 6d3992a commit b726025
Show file tree
Hide file tree
Showing 9 changed files with 515 additions and 90 deletions.
2 changes: 2 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[*.lua]
indent_size = 2
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ Also see [here](/lua/CopilotChat/config.lua):
system_prompt = prompts.COPILOT_INSTRUCTIONS, -- System prompt to use
model = 'gpt-4', -- GPT model to use
temperature = 0.1, -- GPT temperature
context = 'manual', -- Context to use, 'buffers', 'buffer' or 'manual'
debug = false, -- Enable debug logging
show_user_selection = true, -- Shows user selection in chat
show_system_prompt = false, -- Shows system prompt in chat
Expand Down Expand Up @@ -219,7 +220,7 @@ Also see [here](/lua/CopilotChat/config.lua):
mappings = {
close = 'q',
reset = '<C-l>',
complete_after_slash = '<Tab>',
complete = '<Tab>',
submit_prompt = '<CR>',
accept_diff = '<C-y>',
show_diff = '<C-d>',
Expand Down
7 changes: 4 additions & 3 deletions lua/CopilotChat/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ local select = require('CopilotChat.select')
---@class CopilotChat.config.mappings
---@field close string?
---@field reset string?
---@field complete_after_slash string?
---@field complete string?
---@field submit_prompt string?
---@field accept_diff string?
---@field show_diff string?
Expand All @@ -59,8 +59,9 @@ local select = require('CopilotChat.select')
---@field mappings CopilotChat.config.mappings?
return {
system_prompt = prompts.COPILOT_INSTRUCTIONS, -- System prompt to use
model = 'gpt-4', -- GPT model to use
model = 'gpt-4', -- GPT model to use, 'gpt-3.5-turbo' or 'gpt-4'
temperature = 0.1, -- GPT temperature
context = 'manual', -- Context to use, 'buffers', 'buffer' or 'manual'
debug = false, -- Enable debug logging
show_user_selection = true, -- Shows user selection in chat
show_system_prompt = false, -- Shows system prompt in chat
Expand Down Expand Up @@ -114,7 +115,7 @@ return {
mappings = {
close = 'q',
reset = '<C-l>',
complete_after_slash = '<Tab>',
complete = '<Tab>',
submit_prompt = '<CR>',
accept_diff = '<C-y>',
show_diff = '<C-d>',
Expand Down
163 changes: 163 additions & 0 deletions lua/CopilotChat/context.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
local log = require('plenary.log')

local M = {}

local outline_types = {
'local_function',
'function_item',
'arrow_function',
'function_definition',
'function_declaration',
'method_definition',
'method_declaration',
'constructor_declaration',
'class_definition',
'class_declaration',
'interface_definition',
'interface_declaration',
'type_alias_declaration',
'import_statement',
'import_from_statement',
}

local function spatial_distance_cosine(a, b)
local dot_product = 0
local magnitude_a = 0
local magnitude_b = 0
for i = 1, #a do
dot_product = dot_product + a[i] * b[i]
magnitude_a = magnitude_a + a[i] * a[i]
magnitude_b = magnitude_b + b[i] * b[i]
end
magnitude_a = math.sqrt(magnitude_a)
magnitude_b = math.sqrt(magnitude_b)
return dot_product / (magnitude_a * magnitude_b)
end

local function data_ranked_by_relatedness(query, data, top_n)
local scores = {}
for i, item in pairs(data) do
scores[i] = { index = i, score = spatial_distance_cosine(item.embedding, query.embedding) }
end
table.sort(scores, function(a, b)
return a.score > b.score
end)
local result = {}
for i = 1, math.min(top_n, #scores) do
local srt = scores[i]
table.insert(result, vim.tbl_extend('keep', data[srt.index], { score = srt.score }))
end
return result
end

function M.build_outline(bufnr)
local ft = vim.bo[bufnr].filetype
local name = vim.api.nvim_buf_get_name(bufnr)
local parser = vim.treesitter.get_parser(bufnr, ft)
if not parser then
return
end

local root = parser:parse()[1]:root()
local function_signatures = {}
local depth = 0

local function get_functions(node)
local is_func = vim.tbl_contains(outline_types, node:type())
local start_row, start_col, end_row, end_col = node:range()
if is_func then
depth = depth + 1
local start_line = vim.api.nvim_buf_get_lines(bufnr, start_row, start_row + 1, false)[1]
local signature_start =
vim.api.nvim_buf_get_text(bufnr, start_row, start_col, start_row, #start_line, {})[1]
table.insert(function_signatures, string.rep(' ', depth) .. vim.trim(signature_start))
if start_row ~= end_row then
table.insert(function_signatures, string.rep(' ', depth + 1) .. '...')
end
end
for child in node:iter_children() do
get_functions(child)
end
if is_func then
if start_row ~= end_row then
local signature_end =
vim.trim(vim.api.nvim_buf_get_text(bufnr, end_row, 0, end_row, end_col, {})[1])
if #signature_end <= 3 then
table.insert(function_signatures, string.rep(' ', depth) .. signature_end)
end
end
depth = depth - 1
end
end

get_functions(root)
local content = table.concat(function_signatures, '\n')
if content == '' then
return
end

return {
content = table.concat(function_signatures, '\n'),
filename = name,
filetype = ft,
}
end

function M.find_for_query(copilot, context, prompt, selection, bufnr, on_done)
local outline = {}

if context == 'buffers' then
outline = vim.tbl_map(
function(b)
return M.build_outline(b)
end,
vim.tbl_filter(function(b)
return vim.api.nvim_buf_is_loaded(b) and vim.fn.buflisted(b) == 1
end, vim.api.nvim_list_bufs())
)
elseif context == 'buffer' then
table.insert(outline, M.build_outline(bufnr))
end

if #outline == 0 then
on_done({})
return
end

local filetype = selection.filetype or vim.bo[bufnr].filetype
local filename = selection.filename or vim.api.nvim_buf_get_name(bufnr)

copilot:embed(outline, {
on_done = function(out)
log.debug(string.format('Got %s embeddings', #out))

if #out == 0 then
on_done({})
return
end

copilot:embed({
{
prompt = prompt,
content = selection.lines,
filename = filename,
filetype = filetype,
},
}, {
on_done = function(query_out)
local query = query_out[1]
log.debug('Prompt:', query.prompt)
log.debug('Content:', query.content)
local data = data_ranked_by_relatedness(query, out, 10)
log.debug('Ranked data:', #data)
for i, item in ipairs(data) do
log.debug(string.format('%s: %s - %s', i, item.score, item.filename))
end
on_done(data)
end,
})
end,
})
end

return M
Loading

0 comments on commit b726025

Please sign in to comment.