diff --git a/src/ast/dml.rs b/src/ast/dml.rs index c0e58e21a..439213282 100644 --- a/src/ast/dml.rs +++ b/src/ast/dml.rs @@ -487,6 +487,90 @@ pub struct Insert { pub insert_alias: Option, } +impl Display for Insert { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // Start building the insert statement. + if self.replace_into { + write!(f, "REPLACE INTO ")?; + } else { + if self.ignore { + write!(f, "INSERT IGNORE ")?; + } else if let Some(priority) = &self.priority { + write!(f, "INSERT {} ", priority)?; + } else { + write!(f, "INSERT ")?; + } + + if self.into { + write!(f, "INTO ")?; + } + if self.table { + write!(f, "TABLE ")?; + } + } + + // Write table name and alias + write!(f, "{}", self.table_name)?; + if let Some(alias) = &self.table_alias { + write!(f, " AS {}", alias)?; + } + + // Write columns if there are any + if !self.columns.is_empty() { + let cols = self + .columns + .iter() + .map(|col| col.to_string()) + .collect::>() + .join(", "); + write!(f, " ({})", cols)?; + } + + // Write partitioned insert (Hive) + if let Some(partitions) = &self.partitioned { + let parts = partitions + .iter() + .map(|p| p.to_string()) + .collect::>() + .join(", "); + write!(f, " PARTITION ({})", parts)?; + } + + // Write after columns (Hive) + if !self.after_columns.is_empty() { + let after_cols = self + .after_columns + .iter() + .map(|col| col.to_string()) + .collect::>() + .join(", "); + write!(f, " ({})", after_cols)?; + } + + // Write the source query if it exists + if let Some(source) = &self.source { + write!(f, " {}", source)?; + } + + // Write ON conflict handling for Sqlite, MySQL, etc. + if let Some(on_conflict) = &self.on { + write!(f, " {}", on_conflict)?; + } + + // Write RETURNING clause if present + if let Some(returning) = &self.returning { + let returns = returning + .iter() + .map(|r| r.to_string()) + .collect::>() + .join(", "); + write!(f, " RETURNING {}", returns)?; + } + + Ok(()) + } +} + /// DELETE statement. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -507,3 +591,65 @@ pub struct Delete { /// LIMIT (MySQL) pub limit: Option, } + +impl Display for Delete { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "DELETE ")?; + + // Handle multi-table DELETE if present + if !self.tables.is_empty() { + let tables = self.tables + .iter() + .map(|t| t.to_string()) + .collect::>() + .join(", "); + write!(f, "{} ", tables)?; + } + + // The FromTable includes the `FROM` keyword. + write!(f, "{} ", self.from)?; + + // USING clause (if present) + if let Some(using) = &self.using { + let uses = using + .iter() + .map(|tab| tab.to_string()) + .collect::>() + .join(", "); + write!(f, "USING {} ", uses)?; + } + + // WHERE clause (if present) + if let Some(sel) = &self.selection { + write!(f, "WHERE {} ", sel)?; + } + + // RETURNING clause (if present) + if let Some(ret) = &self.returning { + let rets = ret + .iter() + .map(|col| col.to_string()) + .collect::>() + .join(", "); + write!(f, "RETURNING {} ", rets)?; + } + + // ORDER BY clause (if present) + if !self.order_by.is_empty() { + let order_by = self + .order_by + .iter() + .map(|ob| ob.to_string()) + .collect::>() + .join(", "); + write!(f, "ORDER BY {} ", order_by)?; + } + + // LIMIT clause (if present) + if let Some(limit) = &self.limit { + write!(f, "LIMIT {}", limit)?; + } + + Ok(()) + } +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 480442b1e..8459f7d91 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2121,6 +2121,35 @@ pub enum FromTable { WithoutKeyword(Vec), } +impl Display for FromTable { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FromTable::WithFromKeyword(tables) => { + write!( + f, + "FROM {}", + tables + .iter() + .map(|t| t.to_string()) + .collect::>() + .join(", ") + ) + } + FromTable::WithoutKeyword(tables) => { + write!( + f, + "{}", + tables + .iter() + .map(|t| t.to_string()) + .collect::>() + .join(", ") + ) + } + } + } +} + /// A top-level statement (SELECT, INSERT, CREATE, etc.) #[allow(clippy::large_enum_variant)] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]