Skip to content

Commit

Permalink
Use keyword_pattern for is_symbol check
Browse files Browse the repository at this point in the history
  • Loading branch information
joshbode committed Jul 6, 2024
1 parent a110e12 commit ccd371a
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 7 deletions.
3 changes: 2 additions & 1 deletion lua/cmp/entry.lua
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ end
---@param input string
---@param matching_config cmp.MatchingConfig
---@return { score: integer, matches: table[] }
entry.match = function(self, input, matching_config)
entry.match = function(self, input, matching_config, keyword_pattern)
return self.match_cache:ensure(input .. ':' .. (self.resolved_completion_item and '1' or '0' .. ':') .. (matching_config.disallow_fuzzy_matching and '1' or '0') .. ':' .. (matching_config.disallow_partial_fuzzy_matching and '1' or '0') .. ':' .. (matching_config.disallow_partial_matching and '1' or '0') .. ':' .. (matching_config.disallow_prefix_unmatching and '1' or '0') .. ':' .. (matching_config.disallow_symbol_nonprefix_matching and '1' or '0'), function()
local option = {
disallow_fuzzy_matching = matching_config.disallow_fuzzy_matching,
Expand All @@ -377,6 +377,7 @@ entry.match = function(self, input, matching_config)
self:get_word(),
self:get_completion_item().label,
},
keyword_pattern = keyword_pattern,
}

local score, matches, filter_text, _
Expand Down
18 changes: 14 additions & 4 deletions lua/cmp/matcher.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
local char = require('cmp.utils.char')
local pattern = require('cmp.utils.pattern')

local matcher = {}

Expand Down Expand Up @@ -81,7 +82,7 @@ end
---Match entry
---@param input string
---@param word string
---@param option { synonyms: string[], disallow_fullfuzzy_matching: boolean, disallow_fuzzy_matching: boolean, disallow_partial_fuzzy_matching: boolean, disallow_partial_matching: boolean, disallow_prefix_unmatching: boolean, disallow_symbol_nonprefix_matching: boolean }
---@param option { synonyms: string[], disallow_fullfuzzy_matching: boolean, disallow_fuzzy_matching: boolean, disallow_partial_fuzzy_matching: boolean, disallow_partial_matching: boolean, disallow_prefix_unmatching: boolean, disallow_symbol_nonprefix_matching: boolean, keyword_pattern: string|nil }
---@return integer, table
matcher.match = function(input, word, option)
option = option or {}
Expand All @@ -103,6 +104,15 @@ matcher.match = function(input, word, option)
end
end

local is_symbol
if option.keyword_pattern ~= nil then
is_symbol = function(byte)
return pattern.matchstr(option.keyword_pattern, string.char(byte)) == nil
end
else
is_symbol = char.is_symbol
end

-- Gather matched regions
local matches = {}
local input_start_index = 1
Expand All @@ -111,7 +121,7 @@ matcher.match = function(input, word, option)
local word_bound_index = 1
local no_symbol_match = false
while input_end_index <= #input and word_index <= #word do
local m = matcher.find_match_region(input, input_start_index, input_end_index, word, word_index)
local m = matcher.find_match_region(input, input_start_index, input_end_index, word, word_index, is_symbol)
if m and input_end_index <= m.input_match_end then
m.index = word_bound_index
input_start_index = m.input_match_start + 1
Expand Down Expand Up @@ -277,7 +287,7 @@ matcher.fuzzy = function(input, word, matches, option)
end

--- find_match_region
matcher.find_match_region = function(input, input_start_index, input_end_index, word, word_index)
matcher.find_match_region = function(input, input_start_index, input_end_index, word, word_index, is_symbol)
-- determine input position ( woroff -> word_offset )
while input_start_index < input_end_index do
if char.match(string.byte(input, input_end_index), string.byte(word, word_index)) then
Expand Down Expand Up @@ -309,7 +319,7 @@ matcher.find_match_region = function(input, input_start_index, input_end_index,
strict_count = strict_count + (c1 == c2 and 1 or 0)
match_count = match_count + 1
word_offset = word_offset + 1
no_symbol_match = no_symbol_match or char.is_symbol(c1)
no_symbol_match = no_symbol_match or is_symbol(c1)
else
-- Match end (partial region)
if input_match_start ~= -1 then
Expand Down
15 changes: 15 additions & 0 deletions lua/cmp/matcher_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@ describe('matcher', function()
assert.is.truthy(matcher.match('fmodify', 'fnamemodify', config.matching) >= 1)
assert.is.truthy(matcher.match('candlesingle', 'candle#accept#single', config.matching) >= 1)

local options = {
keyword_pattern = [[ \w\+ ]],
}
assert.is.truthy(matcher.match('ab', 'a_b_c', options) > matcher.match('ac', 'a_b_c', options))
assert.is.truthy(matcher.match('a_b', 'a_b_c', options) > matcher.match('ab', 'a_b_c', options))
assert.is.truthy(matcher.match('a_b/c', 'a_b/c', options) > matcher.match('a/c', 'a_b/c', options))

assert.is.truthy(matcher.match('bora', 'border-radius') >= 1)
assert.is.truthy(matcher.match('woroff', 'word_offset') >= 1)
assert.is.truthy(matcher.match('call', 'call') > matcher.match('call', 'condition_all'))
assert.is.truthy(matcher.match('Buffer', 'Buffer') > matcher.match('Buffer', 'buffer'))
assert.is.truthy(matcher.match('luacon', 'lua_context') > matcher.match('luacon', 'LuaContext'))
assert.is.truthy(matcher.match('fmodify', 'fnamemodify') >= 1)
assert.is.truthy(matcher.match('candlesingle', 'candle#accept#single') >= 1)

assert.is.truthy(matcher.match('vi', 'void#', config.matching) >= 1)
assert.is.truthy(matcher.match('vo', 'void#', config.matching) >= 1)
assert.is.truthy(matcher.match('var_', 'var_dump', config.matching) >= 1)
Expand Down
2 changes: 1 addition & 1 deletion lua/cmp/source.lua
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ source.get_entries = function(self, ctx)
inputs[o] = string.sub(ctx.cursor_before_line, o)
end

local match = e:match(inputs[o], matching_config)
local match = e:match(inputs[o], matching_config, self:get_keyword_pattern())
e.score = match.score
e.exact = false
if e.score >= 1 then
Expand Down
1 change: 0 additions & 1 deletion lua/cmp/view/ghost_text_view.lua
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
local config = require('cmp.config')
local misc = require('cmp.utils.misc')
local snippet = require('cmp.utils.snippet')
local str = require('cmp.utils.str')
local api = require('cmp.utils.api')
local types = require('cmp.types')

Expand Down

0 comments on commit ccd371a

Please sign in to comment.