Skip to content

Commit

Permalink
Add support for embeddings
Browse files Browse the repository at this point in the history
Signed-off-by: Tomas Slusny <slusnucky@gmail.com>
  • Loading branch information
deathbeam committed Mar 2, 2024
1 parent 5422646 commit 07b089d
Show file tree
Hide file tree
Showing 5 changed files with 418 additions and 68 deletions.
2 changes: 2 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[*.lua]
indent_size = 2
169 changes: 169 additions & 0 deletions lua/CopilotChat/context.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
local log = require('plenary.log')

local M = {}

local function_types = {
'function_definition',
'function_declaration',
'local_function',
'method_definition',
'method_declaration',
'constructor_declaration',
'function_item',
'arrow_function',
}

local function spatial_distance_cosine(a, b)
local dot_product = 0
local magnitude_a = 0
local magnitude_b = 0
for i = 1, #a do
dot_product = dot_product + a[i] * b[i]
magnitude_a = magnitude_a + a[i] * a[i]
magnitude_b = magnitude_b + b[i] * b[i]
end
magnitude_a = math.sqrt(magnitude_a)
magnitude_b = math.sqrt(magnitude_b)
return dot_product / (magnitude_a * magnitude_b)
end

local function data_ranked_by_relatedness(query, data, top_n)
local scores = {}
for i, item in pairs(data) do
scores[i] = { index = i, score = spatial_distance_cosine(item.embedding, query.embedding) }
end
table.sort(scores, function(a, b)
return a.score > b.score
end)
local result = {}
for i = 1, math.min(top_n, #scores) do
local srt = scores[i]
table.insert(result, vim.tbl_extend('keep', data[srt.index], { score = srt.score }))
end
return result
end

local function collect_inputs(bufnr)
local ft = vim.bo[bufnr].filetype
local name = vim.api.nvim_buf_get_name(bufnr)
local parser = vim.treesitter.get_parser(bufnr, ft)
if not parser then
return
end

local root = parser:parse()[1]:root()
local function get_functions(node, result)
if vim.tbl_contains(function_types, node:type()) then
table.insert(result, node)
return
end
for child in node:iter_children() do
get_functions(child, result)
end
end

local functions = {}
get_functions(root, functions)

local inputs = {}
for _, node in ipairs(functions) do
local start_row, start_col, end_row, end_col = node:range()
local lines = vim.api.nvim_buf_get_lines(bufnr, start_row, end_row + 1, false)
local body = table.concat(lines, '\n')

if vim.trim(body) ~= '' then
table.insert(inputs, {
content = body,
filename = name,
filetype = ft,
})
end
end

return inputs
end

local function build_outline(bufnr)
local ft = vim.bo[bufnr].filetype
local name = vim.api.nvim_buf_get_name(bufnr)
local parser = vim.treesitter.get_parser(bufnr, ft)
if not parser then
return
end

local root = parser:parse()[1]:root()
local function_signatures = {}
local depth = 0

local function get_functions(node)
local is_func = vim.tbl_contains(function_types, node:type())
local start_row, start_col, end_row, end_col = node:range()
if is_func then
depth = depth + 1
local start_line = vim.api.nvim_buf_get_lines(bufnr, start_row, start_row + 1, false)[1]
local signature_start =
vim.api.nvim_buf_get_text(bufnr, start_row, start_col, start_row, #start_line, {})[1]
table.insert(function_signatures, string.rep(' ', depth) .. vim.trim(signature_start))
if start_row ~= end_row then
table.insert(function_signatures, string.rep(' ', depth + 1) .. '...')
end
end
for child in node:iter_children() do
get_functions(child)
end
if is_func then
if start_row ~= end_row then
local signature_end = vim.api.nvim_buf_get_text(bufnr, end_row, 0, end_row, end_col, {})[1]
table.insert(function_signatures, string.rep(' ', depth) .. vim.trim(signature_end))
end
depth = depth - 1
end
end

get_functions(root)
return {
content = table.concat(function_signatures, '\n'),
filename = name,
filetype = ft,
}
end

function M.find_for_query(copilot, selection, bufnr, on_done)
local filetype = selection.filetype or vim.bo[bufnr].filetype
local filename = selection.filename or vim.api.nvim_buf_get_name(bufnr)

local outline = build_outline(bufnr)
vim.print(outline.content)

local inputs = collect_inputs(bufnr)
if #inputs > 0 then
copilot:embed(inputs, {
on_done = function(out)
log.debug(string.format('Got %s embeddings for file %s', #out, filename))

copilot:embed({
{
content = selection.lines,
filename = filename,
filetype = filetype,
},
}, {
on_done = function(query_out)
local query = query_out[1]
log.debug('Query:', query.content)
local data = data_ranked_by_relatedness(query, out, 10)
log.debug('Ranked data:', #data)
for i, item in ipairs(data) do
log.debug(
string.format('%s: %s - %s', i, item.score, vim.split(item.content, '\n')[1])
)
end
on_done(data)
end,
})
end,
})
end
end

return M
Loading

0 comments on commit 07b089d

Please sign in to comment.