Skip to content
This repository has been archived by the owner on Apr 4, 2023. It is now read-only.

Introduce the depth method on FilterCondition #421

Merged
merged 5 commits into from
Dec 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions filter-parser/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,24 @@ pub enum FilterCondition<'a> {
}

impl<'a> FilterCondition<'a> {
/// Returns the first token found at the specified depth, `None` if no token at this depth.
pub fn token_at_depth(&self, depth: usize) -> Option<&Token> {
match self {
FilterCondition::Condition { fid, .. } if depth == 0 => Some(fid),
FilterCondition::Or(left, right) => {
let depth = depth.saturating_sub(1);
right.token_at_depth(depth).or_else(|| left.token_at_depth(depth))
}
FilterCondition::And(left, right) => {
let depth = depth.saturating_sub(1);
right.token_at_depth(depth).or_else(|| left.token_at_depth(depth))
}
FilterCondition::GeoLowerThan { point: [point, _], .. } if depth == 0 => Some(point),
FilterCondition::GeoGreaterThan { point: [point, _], .. } if depth == 0 => Some(point),
_ => None,
}
}

pub fn negate(self) -> FilterCondition<'a> {
use FilterCondition::*;

Expand Down Expand Up @@ -584,4 +602,10 @@ pub mod tests {
assert!(filter.starts_with(expected), "Filter `{:?}` was supposed to return the following error:\n{}\n, but instead returned\n{}\n.", input, expected, filter);
}
}

#[test]
fn depth() {
    // Six conditions chained by five `OR`s: a token must be reported at depth 5.
    let input = "account_ids=1 OR account_ids=2 OR account_ids=3 OR account_ids=4 OR account_ids=5 OR account_ids=6";
    let parsed = FilterCondition::parse(input).unwrap();
    let deepest = parsed.token_at_depth(5);
    assert!(deepest.is_some());
}
}
52 changes: 52 additions & 0 deletions milli/src/search/facet/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ use crate::heed_codec::facet::{
};
use crate::{distance_between_two_points, CboRoaringBitmapCodec, FieldId, Index, Result};

/// The maximum nesting depth accepted for a filter AST.
///
/// A filter is rejected with `FilterError::TooDeep` when
/// `FilterCondition::token_at_depth(MAX_FILTER_DEPTH)` finds a token.
const MAX_FILTER_DEPTH: usize = 2000;

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Filter<'a> {
condition: FilterCondition<'a>,
Expand All @@ -27,6 +30,7 @@ enum FilterError<'a> {
BadGeoLat(f64),
BadGeoLng(f64),
Reserved(&'a str),
TooDeep,
InternalError,
}
// Empty impl: `FilterError` only relies on the methods `std::error::Error` provides by default.
impl<'a> std::error::Error for FilterError<'a> {}
Expand All @@ -40,6 +44,10 @@ impl<'a> Display for FilterError<'a> {
attribute,
filterable,
),
Self::TooDeep => write!(f,
"Too many filter conditions, can't process more than {} filters.",
MAX_FILTER_DEPTH
),
Self::Reserved(keyword) => write!(
f,
"`{}` is a reserved keyword and thus can't be used as a filter expression.",
Expand Down Expand Up @@ -108,6 +116,10 @@ impl<'a> Filter<'a> {
}
}

if let Some(token) = ands.as_ref().and_then(|fc| fc.token_at_depth(MAX_FILTER_DEPTH)) {
return Err(token.as_external_error(FilterError::TooDeep).into());
}

Ok(ands.map(|ands| Self { condition: ands }))
}

Expand All @@ -116,6 +128,11 @@ impl<'a> Filter<'a> {
Ok(fc) => Ok(fc),
Err(e) => Err(Error::UserError(UserError::InvalidFilter(e.to_string()))),
}?;

if let Some(token) = condition.token_at_depth(MAX_FILTER_DEPTH) {
return Err(token.as_external_error(FilterError::TooDeep).into());
}

Ok(Self { condition })
}
}
Expand Down Expand Up @@ -419,6 +436,8 @@ impl<'a> From<FilterCondition<'a>> for Filter<'a> {

#[cfg(test)]
mod tests {
use std::fmt::Write;

use big_s::S;
use either::Either;
use heed::EnvOpenOptions;
Expand Down Expand Up @@ -586,4 +605,37 @@ mod tests {
"Bad longitude `180.000001`. Longitude must be contained between -180 and 180 degrees."
));
}

/// Checks that `Filter::from_str` rejects a filter made of a very long chain of `OR`s
/// with the "Too many filter conditions" error.
#[test]
fn filter_depth() {
    // Number of `account_ids=N` conditions chained together with `OR`.
    const CONDITION_COUNT: usize = 14361;

    let path = tempfile::tempdir().unwrap();
    let mut options = EnvOpenOptions::new();
    options.map_size(10 * 1024 * 1024); // 10 MB
    let index = Index::new(options, &path).unwrap();

    // Declare `account_ids` as both a searchable and a filterable field.
    let mut wtxn = index.write_txn().unwrap();
    let mut builder = Settings::new(&mut wtxn, &index);
    builder.set_searchable_fields(vec![S("account_ids")]);
    builder.set_filterable_fields(hashset! { S("account_ids") });
    builder.execute(|_| ()).unwrap();
    wtxn.commit().unwrap();

    // Generate one big filter string (a few hundred KB) with too many ORs.
    // The longest possible chunk is used to size the buffer up front.
    let longest_chunk = "account_ids=14361 OR ";
    let mut filter_string = String::with_capacity(longest_chunk.len() * CONDITION_COUNT);
    for i in 1..=CONDITION_COUNT {
        // Writing into a `String` cannot fail, so the `fmt::Result` is deliberately ignored.
        let _ = write!(&mut filter_string, "account_ids={}", i);
        if i != CONDITION_COUNT {
            filter_string.push_str(" OR ");
        }
    }

    let error = Filter::from_str(&filter_string).unwrap_err();
    assert!(error.to_string().starts_with("Too many filter conditions"), "{}", error);
}
}