Skip to content

Commit

Permalink
fix escaped wildcard query on wildcard field
Browse files Browse the repository at this point in the history
Signed-off-by: gesong.samuel <gesong.samuel@bytedance.com>
  • Loading branch information
gesong.samuel committed Sep 5, 2024
1 parent 2f1e209 commit e32c7e5
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -430,22 +430,27 @@ public Query wildcardQuery(String value, MultiTermQuery.RewriteMethod method, bo
finalValue = value;
}
Predicate<String> matchPredicate;
if (value.contains("?")) {
Automaton automaton = WildcardQuery.toAutomaton(new Term(name(), finalValue));
CompiledAutomaton compiledAutomaton = new CompiledAutomaton(automaton);
Automaton automaton = WildcardQuery.toAutomaton(new Term(name(), finalValue));
CompiledAutomaton compiledAutomaton = new CompiledAutomaton(automaton);
if (compiledAutomaton.type == CompiledAutomaton.AUTOMATON_TYPE.SINGLE) {
// when type equals SINGLE, #compiledAutomaton.runAutomaton is null
matchPredicate = s -> {
if (caseInsensitive) {
s = s.toLowerCase(Locale.ROOT);
}
BytesRef valueBytes = BytesRefs.toBytesRef(s);
return compiledAutomaton.runAutomaton.run(valueBytes.bytes, valueBytes.offset, valueBytes.length);
return s.equals(finalValue);
};
} else if (compiledAutomaton.type == CompiledAutomaton.AUTOMATON_TYPE.ALL) {
return existsQuery(context);
} else if (compiledAutomaton.type == CompiledAutomaton.AUTOMATON_TYPE.NONE) {
return new MatchNoDocsQuery("Wildcard expression matches nothing");
} else {
matchPredicate = s -> {
if (caseInsensitive) {
s = s.toLowerCase(Locale.ROOT);
}
return Regex.simpleMatch(finalValue, s);
BytesRef valueBytes = BytesRefs.toBytesRef(s);
return compiledAutomaton.runAutomaton.run(valueBytes.bytes, valueBytes.offset, valueBytes.length);
};
}

Expand All @@ -468,22 +473,30 @@ public Query wildcardQuery(String value, MultiTermQuery.RewriteMethod method, bo
// Package-private for testing
static Set<String> getRequiredNGrams(String value) {
Set<String> terms = new HashSet<>();

if (value.isEmpty()) {
return terms;
}

int pos = 0;
String rawSequence = null;
String currentSequence = null;
if (!value.startsWith("?") && !value.startsWith("*")) {
// Can add prefix term
currentSequence = getNonWildcardSequence(value, 0);
rawSequence = getNonWildcardSequence(value, 0);
currentSequence = performEscape(rawSequence);
if (currentSequence.length() == 1) {
terms.add(new String(new char[] { 0, currentSequence.charAt(0) }));
terms.add(new String(new char[]{0, currentSequence.charAt(0)}));
} else {
terms.add(new String(new char[] { 0, currentSequence.charAt(0), currentSequence.charAt(1) }));
terms.add(new String(new char[]{0, currentSequence.charAt(0), currentSequence.charAt(1)}));
}
} else {
pos = findNonWildcardSequence(value, pos);
currentSequence = getNonWildcardSequence(value, pos);
rawSequence = getNonWildcardSequence(value, pos);
}
while (pos < value.length()) {
boolean isEndOfValue = pos + currentSequence.length() == value.length();
boolean isEndOfValue = pos + rawSequence.length() == value.length();
currentSequence = performEscape(rawSequence);
if (!currentSequence.isEmpty() && currentSequence.length() < 3 && !isEndOfValue && pos > 0) {
// If this is a prefix or suffix of length < 3, then we already have a longer token including the anchor.
terms.add(currentSequence);
Expand All @@ -495,23 +508,24 @@ static Set<String> getRequiredNGrams(String value) {
if (isEndOfValue) {
// This is the end of the input. We can attach a suffix anchor.
if (currentSequence.length() == 1) {
terms.add(new String(new char[] { currentSequence.charAt(0), 0 }));
terms.add(new String(new char[]{currentSequence.charAt(0), 0}));
} else {
char a = currentSequence.charAt(currentSequence.length() - 2);
char b = currentSequence.charAt(currentSequence.length() - 1);
terms.add(new String(new char[] { a, b, 0 }));
terms.add(new String(new char[]{a, b, 0}));
}
}
pos = findNonWildcardSequence(value, pos + currentSequence.length());
currentSequence = getNonWildcardSequence(value, pos);
pos = findNonWildcardSequence(value, pos + rawSequence.length());
rawSequence = getNonWildcardSequence(value, pos);
}
return terms;
}

private static String getNonWildcardSequence(String value, int startFrom) {
for (int i = startFrom; i < value.length(); i++) {
char c = value.charAt(i);
if (c == '?' || c == '*') {
if ((c == '?' || c == '*') &&
(i == 0 || value.charAt(i - 1) != '\\')) {
return value.substring(startFrom, i);
}
}
Expand All @@ -529,6 +543,22 @@ private static int findNonWildcardSequence(String value, int startFrom) {
return value.length();
}

private static String performEscape(String str) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < str.length(); i++) {
if (str.charAt(i) == '\\' && (i + 1) < str.length()) {
char c = str.charAt(i + 1);
if (c == '*' || c == '?') {
i++;
}
}
sb.append(str.charAt(i));
}
assert !sb.toString().contains("\\*");
assert !sb.toString().contains("\\?");
return sb.toString();
}

@Override
public Query regexpQuery(
String value,
Expand Down Expand Up @@ -616,10 +646,10 @@ private static Query regexpToQuery(String fieldName, RegExp regExp) {
query = builder.build();
} else if ((regExp.kind == RegExp.Kind.REGEXP_REPEAT_MIN || regExp.kind == RegExp.Kind.REGEXP_REPEAT_MINMAX)
&& regExp.min > 0) {
return regexpToQuery(fieldName, regExp.exp1);
} else {
return new MatchAllDocsQuery();
}
return regexpToQuery(fieldName, regExp.exp1);
} else {
return new MatchAllDocsQuery();
}
if (query.clauses().size() == 1) {
return query.iterator().next().getQuery();
} else if (query.clauses().size() == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,38 @@ public void testWildcardQuery() {
);
}

public void testEscapedWildcardQuery() {
MappedFieldType ft = new WildcardFieldMapper.WildcardFieldType("field");
Set<String> expectedTerms = new HashSet<>();
expectedTerms.add(prefixAnchored("*"));
expectedTerms.add(suffixAnchored("*"));

BooleanQuery.Builder builder = new BooleanQuery.Builder();
for (String term : expectedTerms) {
builder.add(new TermQuery(new Term("field", term)), BooleanClause.Occur.FILTER);
}

assertEquals(
new WildcardFieldMapper.WildcardMatchingQuery("field", builder.build(), "\\**\\*"),
ft.wildcardQuery("\\**\\*", null, null)
);

assertEquals(
new WildcardFieldMapper.WildcardMatchingQuery("field", builder.build(), "\\*"),
ft.wildcardQuery("\\*", null, null)
);

expectedTerms.remove(suffixAnchored("*"));
builder = new BooleanQuery.Builder();
for (String term : expectedTerms) {
builder.add(new TermQuery(new Term("field", term)), BooleanClause.Occur.FILTER);
}
assertEquals(
new WildcardFieldMapper.WildcardMatchingQuery("field", builder.build(), "\\**"),
ft.wildcardQuery("\\**", null, null)
);
}

public void testMultipleWildcardsInQuery() {
final String pattern = "a?cd*efg?h";
MappedFieldType ft = new WildcardFieldMapper.WildcardFieldType("field");
Expand Down

0 comments on commit e32c7e5

Please sign in to comment.