Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Supporting SAMPLE parsing #1566

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7629,6 +7629,72 @@ impl Display for JsonNullClause {
}
}

/// Sampling method named in a table sampling clause.
///
/// Each variant is rendered by its `Display` impl as the upper-case SQL
/// keyword of the same name.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum SamplingMethod {
/// The `BERNOULLI` keyword.
Bernoulli,
/// The `ROW` keyword.
Row,
/// The `SYSTEM` keyword.
System,
/// The `BLOCK` keyword.
Block,
}

/// Renders the sampling method as its upper-case SQL keyword.
impl Display for SamplingMethod {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the keyword first, then emit it with a single write.
        let keyword = match self {
            SamplingMethod::Bernoulli => "BERNOULLI",
            SamplingMethod::Row => "ROW",
            SamplingMethod::System => "SYSTEM",
            SamplingMethod::Block => "BLOCK",
        };
        f.write_str(keyword)
    }
}

/// Seed of a table sampling clause.
///
/// The variant records which keyword introduced the seed, so that the
/// `Display` impl can print the same keyword back.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum SampleSeed {
/// `SEED (<seed>)`
Seed(u32),
/// `REPEATABLE (<seed>)`
Repeatable(u32),
}

/// Renders the seed as `<KEYWORD> (<seed>)`, using the keyword the variant stands for.
impl Display for SampleSeed {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Both variants share one layout; only the leading keyword differs.
        let (keyword, value) = match self {
            SampleSeed::Seed(value) => ("SEED", value),
            SampleSeed::Repeatable(value) => ("REPEATABLE", value),
        };
        write!(f, "{} ({})", keyword, value)
    }
}

/// Table sampling clause.
///
/// Snowflake (and others) offer ways to sample rows from various tables in a query.
/// <https://docs.snowflake.com/en/sql-reference/constructs/sample>
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Sampling {
/// The method with which to do the sampling.
pub method: SamplingMethod,
/// Sample size. Can be an integer or a probability (decimal).
pub size: Value,
/// Whether the `ROWS` keyword is present.
pub rows: bool,
/// Optional `SEED (<seed>)` or `REPEATABLE (<seed>)` part, for deterministic sampling.
pub seed: Option<SampleSeed>,
}

/// Renders the clause as ` <method> (<size>[ ROWS])[ <seed>]`.
///
/// The output starts with a single space and does not contain an introducing
/// `SAMPLE` / `TABLESAMPLE` keyword, since the struct does not record one.
impl Display for Sampling {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `ROWS` qualifies the size and is only accepted by the parser before
        // the closing parenthesis, so it must be printed inside the
        // parentheses for the output to parse again.
        let rows = if self.rows { " ROWS" } else { "" };
        write!(f, " {} ({}{})", self.method, self.size, rows)?;
        if let Some(seed) = &self.seed {
            write!(f, " {}", seed)?;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
4 changes: 4 additions & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,15 @@ define_keywords!(
BEGIN,
BEGIN_FRAME,
BEGIN_PARTITION,
BERNOULLI,
BETWEEN,
BIGDECIMAL,
BIGINT,
BIGNUMERIC,
BINARY,
BINDING,
BLOB,
BLOCK,
BLOOMFILTER,
BOOL,
BOOLEAN,
Expand Down Expand Up @@ -668,6 +670,7 @@ define_keywords!(
RUN,
SAFE,
SAFE_CAST,
SAMPLE,
SAVEPOINT,
SCHEMA,
SCHEMAS,
Expand All @@ -677,6 +680,7 @@ define_keywords!(
SECOND,
SECRET,
SECURITY,
SEED,
SELECT,
SEMI,
SENSITIVE,
Expand Down
79 changes: 79 additions & 0 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12756,6 +12756,85 @@ impl<'a> Parser<'a> {
Ok(None)
}
}

fn parse_sampling(&mut self) -> Result<Sampling, ParserError> {
// Ensure the first keyword is SAMPLE or TABLESAMPLE
if self
.parse_one_of_keywords(&[Keyword::SAMPLE, Keyword::TABLESAMPLE])
.is_none()
{
return self.expected("Expected SAMPLE or TABLESAMPLE keyword", self.peek_token());
}

// Parse the sampling method
let method = self
.parse_one_of_keywords(&[
Keyword::BERNOULLI,
Keyword::ROW,
Keyword::SYSTEM,
Keyword::BLOCK,
])
.ok_or_else(|| {
ParserError::ParserError(format!("Expected one of BEROULLI, SYSTEM, ROW, BLOCK"))
})?;

// Parse common structure: (size) [ROWS] [SEED | REPEATABLE]
self.expect_token(&Token::LParen)?;
let size = self.parse_value()?;

// ROWS keyword is only valid for BERNOULLI | ROW sampling
let rows = match method {
Keyword::BERNOULLI | Keyword::ROW => {
self.parse_one_of_keywords(&[Keyword::ROWS]).is_some()
}
_ => false,
};
self.expect_token(&Token::RParen)?;

// Parse optional seed
let seed_keyword = self.parse_one_of_keywords(&[Keyword::REPEATABLE, Keyword::SEED]);
let seed = if let Some(keyword) = seed_keyword {
self.expect_token(&Token::LParen)?;
let seed_value = match self.parse_value()? {
Value::Number(n, _) => n
.parse::<u32>()
.map_err(|_| ParserError::ParserError(format!("Invalid seed value {}", n)))?,
_ => unreachable!(),
};
self.expect_token(&Token::RParen)?;

Some(match keyword {
Keyword::SEED => SampleSeed::Seed(seed_value),
Keyword::REPEATABLE => SampleSeed::Repeatable(seed_value),
_ => unreachable!(),
})
} else {
None
};

Ok(Sampling {
method: self.keyword_to_sampling_method(method)?,
size,
rows,
seed,
})
}

/// Maps a sampling-method keyword to the corresponding [`SamplingMethod`].
///
/// # Errors
///
/// Returns a [`ParserError`] for any keyword other than `BERNOULLI`, `ROW`,
/// `SYSTEM` or `BLOCK`.
fn keyword_to_sampling_method(
    &mut self,
    keyword: Keyword,
) -> Result<SamplingMethod, ParserError> {
    let method = match keyword {
        Keyword::BERNOULLI => SamplingMethod::Bernoulli,
        Keyword::ROW => SamplingMethod::Row,
        Keyword::SYSTEM => SamplingMethod::System,
        Keyword::BLOCK => SamplingMethod::Block,
        _ => {
            return Err(ParserError::ParserError(format!(
                "Unsupported keyword for sampling: {:?}",
                keyword
            )))
        }
    };
    Ok(method)
}
}

impl Word {
Expand Down
Loading