feat(generics): cleaner generics
- add better generic types to traits
- add new trait implementation for an array of expressions
- update example
sjrusso8 committed Mar 22, 2024
1 parent 6f0fc4f commit 5fdeeaf
Showing 10 changed files with 202 additions and 114 deletions.
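
The core of the change is visible in the updated examples: reader and writer call sites now pass plain arrays instead of allocating `Vec<String>`. Below is a minimal sketch of the new calling convention, not taken verbatim from the repo; it assumes a Tokio runtime attribute on `main` and the sample CSV path used throughout the crate's examples.

```rust
use spark_connect_rs::functions as F;
use spark_connect_rs::{SparkSession, SparkSessionBuilder};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let spark: SparkSession = SparkSessionBuilder::default().build().await?;

    // Before this commit: .load(vec!["...".to_string()]); now a plain &str array works.
    let mut df = spark
        .read()
        .format("csv")
        .option("header", "True")
        .option("delimiter", ";")
        .load(["/opt/spark/examples/src/main/resources/people.csv"]);

    // An array of expressions also works, via the new ToVecExpr impl for [T; N].
    let mut selected = df.select([
        F::col("name"),
        F::col("age").cast("int").alias("age_int"),
    ]);

    selected.show(Some(5), None, None).await?;
    Ok(())
}
```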
2 changes: 1 addition & 1 deletion examples/delta.rs
@@ -14,7 +14,7 @@ use spark_connect_rs::dataframe::SaveMode;
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut spark: SparkSession = SparkSessionBuilder::default().build().await?;

let paths = vec!["/opt/spark/examples/src/main/resources/people.csv".to_string()];
let paths = ["/opt/spark/examples/src/main/resources/people.csv"];

let df = spark
.clone()
6 changes: 3 additions & 3 deletions examples/reader.rs
@@ -9,16 +9,16 @@ use spark_connect_rs::functions as F;
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let spark: SparkSession = SparkSessionBuilder::default().build().await?;

let paths = vec!["/opt/spark/examples/src/main/resources/people.csv".to_string()];
let path = ["/opt/spark/examples/src/main/resources/people.csv"];

let mut df = spark
.read()
.format("csv")
.option("header", "True")
.option("delimiter", ";")
.load(paths);
.load(path);

df.select(vec![
df.select([
F::col("name"),
F::col("age").cast("int").alias("age_int"),
(F::lit(3.0) + F::col("age").cast("int")).alias("addition"),
2 changes: 1 addition & 1 deletion examples/writer.rs
@@ -32,7 +32,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.read()
.format("csv")
.option("header", "true")
.load(vec![path.to_string()]);
.load([path]);

df.show(Some(10), None, None).await?;

31 changes: 14 additions & 17 deletions src/catalog.rs
@@ -54,7 +54,9 @@ impl Catalog {

/// Returns a list of catalogs in this session
#[allow(non_snake_case)]
pub async fn listCatalogs(&mut self, pattern: Option<String>) -> Vec<RecordBatch> {
pub async fn listCatalogs(&mut self, pattern: Option<&str>) -> Vec<RecordBatch> {
let pattern = pattern.map(|val| val.to_owned());

let cat_type = Some(spark::catalog::CatType::ListCatalogs(spark::ListCatalogs {
pattern,
}));
@@ -72,7 +74,9 @@ impl Catalog {

/// Returns a list of databases in this session
#[allow(non_snake_case)]
pub async fn listDatabases(&mut self, pattern: Option<String>) -> Vec<RecordBatch> {
pub async fn listDatabases(&mut self, pattern: Option<&str>) -> Vec<RecordBatch> {
let pattern = pattern.map(|val| val.to_owned());

let cat_type = Some(spark::catalog::CatType::ListDatabases(
spark::ListDatabases { pattern },
));
@@ -92,12 +96,12 @@
#[allow(non_snake_case)]
pub async fn listTables(
&mut self,
dbName: Option<String>,
pattern: Option<String>,
dbName: Option<&str>,
pattern: Option<&str>,
) -> Vec<RecordBatch> {
let cat_type = Some(spark::catalog::CatType::ListTables(spark::ListTables {
db_name: dbName,
pattern,
db_name: dbName.map(|db| db.to_owned()),
pattern: pattern.map(|val| val.to_owned()),
}));

let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
@@ -113,14 +117,10 @@

/// Returns a list of columns for the given tables/views in the specific database
#[allow(non_snake_case)]
pub async fn listColumns(
&mut self,
tableName: String,
dbName: Option<String>,
) -> Vec<RecordBatch> {
pub async fn listColumns(&mut self, tableName: &str, dbName: Option<&str>) -> Vec<RecordBatch> {
let cat_type = Some(spark::catalog::CatType::ListColumns(spark::ListColumns {
table_name: tableName,
db_name: dbName,
table_name: tableName.to_owned(),
db_name: dbName.map(|val| val.to_owned()),
}));

let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
@@ -207,10 +207,7 @@ mod tests {
.await
.unwrap();

let value = spark
.catalog()
.listDatabases(Some("*rust".to_string()))
.await;
let value = spark.catalog().listDatabases(Some("*rust")).await;

assert_eq!(4, value[0].num_columns());
assert_eq!(1, value[0].num_rows());
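
With the catalog methods now taking `Option<&str>`, callers pass a borrowed pattern instead of building a `String` at the call site. A short sketch of the updated call style, mirroring the `*rust` pattern from the test above; a reachable Spark Connect server and a Tokio runtime are assumed.

```rust
use spark_connect_rs::{SparkSession, SparkSessionBuilder};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut spark: SparkSession = SparkSessionBuilder::default().build().await?;

    // A &str literal coerces directly; no .to_string() allocation is needed anymore.
    let databases = spark.catalog().listDatabases(Some("*rust")).await;

    println!("matched {} database row(s)", databases[0].num_rows());
    Ok(())
}
```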
85 changes: 48 additions & 37 deletions src/dataframe.rs
@@ -1,7 +1,5 @@
//! DataFrame with Reader/Writer representation
use std::collections::HashMap;

use crate::column::Column;
use crate::errors::SparkError;
use crate::expressions::{ToExpr, ToFilterExpr, ToVecExpr};
@@ -159,44 +157,36 @@ impl DataFrame {
Some(data?.value(0))
}

#[allow(non_snake_case, dead_code)]
#[allow(non_snake_case)]
pub async fn createTempView(&mut self, name: &str) {
self.create_view_cmd(name.to_string(), false, false)
.await
.unwrap()
self.create_view_cmd(name, false, false).await.unwrap()
}

#[allow(non_snake_case, dead_code)]
#[allow(non_snake_case)]
pub async fn createGlobalTempView(&mut self, name: &str) {
self.create_view_cmd(name.to_string(), true, false)
.await
.unwrap()
self.create_view_cmd(name, true, false).await.unwrap()
}

#[allow(non_snake_case, dead_code)]
#[allow(non_snake_case)]
pub async fn createOrReplaceGlobalTempView(&mut self, name: &str) {
self.create_view_cmd(name.to_string(), true, true)
.await
.unwrap()
self.create_view_cmd(name, true, true).await.unwrap()
}

#[allow(non_snake_case, dead_code)]
#[allow(non_snake_case)]
pub async fn createOrReplaceTempView(&mut self, name: &str) {
self.create_view_cmd(name.to_string(), false, true)
.await
.unwrap()
self.create_view_cmd(name, false, true).await.unwrap()
}

async fn create_view_cmd(
&mut self,
name: String,
name: &str,
is_global: bool,
replace: bool,
) -> Result<(), SparkError> {
let command_type =
spark::command::CommandType::CreateDataframeView(spark::CreateDataFrameViewCommand {
input: Some(self.logical_plan.clone().relation),
name,
name: name.to_string(),
is_global,
replace,
});
@@ -357,7 +347,10 @@
}

#[allow(non_snake_case)]
pub fn freqItems(&mut self, cols: Vec<&str>, support: Option<f64>) -> DataFrame {
pub fn freqItems<'a, I>(&mut self, cols: I, support: Option<f64>) -> DataFrame
where
I: IntoIterator<Item = &'a str>,
{
DataFrame::new(
self.spark_session.clone(),
self.logical_plan.freqItems(cols, support),
@@ -455,7 +448,10 @@
}

#[allow(non_snake_case)]
pub fn orderBy(&mut self, cols: Vec<Column>) -> DataFrame {
pub fn orderBy<I>(&mut self, cols: I) -> DataFrame
where
I: IntoIterator<Item = Column>,
{
DataFrame::new(self.spark_session.clone(), self.logical_plan.sort(cols))
}

@@ -583,7 +579,10 @@
/// }
/// ```
#[allow(non_snake_case)]
pub fn selectExpr(&mut self, cols: Vec<&str>) -> DataFrame {
pub fn selectExpr<'a, I>(&mut self, cols: I) -> DataFrame
where
I: IntoIterator<Item = &'a str>,
{
DataFrame::new(
self.spark_session.clone(),
self.logical_plan.select_expr(cols),
@@ -692,7 +691,10 @@
}

#[allow(non_snake_case)]
pub fn toDF(&mut self, cols: Vec<&str>) -> DataFrame {
pub fn toDF<'a, I>(&mut self, cols: I) -> DataFrame
where
I: IntoIterator<Item = &'a str>,
{
DataFrame::new(self.spark_session.clone(), self.logical_plan.to_df(cols))
}

@@ -752,18 +754,27 @@
}

#[allow(non_snake_case)]
pub fn withColumns(&mut self, colMap: HashMap<&str, Column>) -> DataFrame {
pub fn withColumns<I, K>(&mut self, colMap: I) -> DataFrame
where
I: IntoIterator<Item = (K, Column)>,
K: ToString,
{
DataFrame::new(
self.spark_session.clone(),
self.logical_plan.withColumns(colMap),
)
}

/// Returns a new [DataFrame] by renaming multiple columns from a
/// `HashMap<String, String>` containing the `existing` as the key
/// and the `new` as the value.
/// an iterator containing key/value pairs with the key as the `existing`
/// column name and the value as the `new` column name.
#[allow(non_snake_case)]
pub fn withColumnsRenamed(&mut self, cols: HashMap<String, String>) -> DataFrame {
pub fn withColumnsRenamed<I, K, V>(&mut self, cols: I) -> DataFrame
where
I: IntoIterator<Item = (K, V)>,
K: AsRef<str>,
V: AsRef<str>,
{
DataFrame::new(
self.spark_session.clone(),
self.logical_plan.withColumnsRenamed(cols),
@@ -903,14 +914,14 @@ mod tests {
cols
);

let paths = vec!["/opt/spark/examples/src/main/resources/people.csv".to_string()];
let path = ["/opt/spark/examples/src/main/resources/people.csv"];

let cols = spark
.read()
.format("csv")
.option("header", "True")
.option("delimiter", ";")
.load(paths)
.load(path)
.columns()
.await;

@@ -1000,14 +1011,14 @@
async fn test_describe() {
let spark = setup().await;

let paths = vec!["/opt/spark/examples/src/main/resources/people.csv".to_string()];
let path = ["/opt/spark/examples/src/main/resources/people.csv"];

let mut df = spark
.read()
.format("csv")
.option("header", "True")
.option("delimiter", ";")
.load(paths);
.load(path);

let mut df = df
.select(col("age").cast("int").alias("age_int"))
@@ -1040,14 +1051,14 @@
async fn test_drop() {
let spark = setup().await;

let paths = vec!["/opt/spark/examples/src/main/resources/people.csv".to_string()];
let path = ["/opt/spark/examples/src/main/resources/people.csv"];

let mut df = spark
.read()
.format("csv")
.option("header", "True")
.option("delimiter", ";")
.load(paths);
.load(path);

let mut df = df.drop(vec!["age", "job"]);

@@ -1060,18 +1071,18 @@
async fn test_join() {
let spark = setup().await;

let paths = vec!["/opt/spark/examples/src/main/resources/people.csv".to_string()];
let path = ["/opt/spark/examples/src/main/resources/people.csv"];

let mut df = spark
.clone()
.read()
.format("csv")
.option("header", "True")
.option("delimiter", ";")
.load(paths)
.load(path)
.alias("df");

let mut df1 = spark
let df1 = spark
.clone()
.range(None, 1, 1, Some(1))
.select(vec![lit("Bob").alias("name"), lit(1).alias("id")])
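
Since `withColumns` and `withColumnsRenamed` are now bounded on `IntoIterator` of pairs, an array of tuples works at the call site where a `HashMap` was previously required (a `HashMap` still satisfies the bound). A minimal sketch under the same assumptions as the examples above; the `full_name` rename target is purely illustrative.

```rust
use spark_connect_rs::functions as F;
use spark_connect_rs::{SparkSession, SparkSessionBuilder};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let spark: SparkSession = SparkSessionBuilder::default().build().await?;

    let mut df = spark
        .read()
        .format("csv")
        .option("header", "True")
        .option("delimiter", ";")
        .load(["/opt/spark/examples/src/main/resources/people.csv"]);

    // An array of (name, Column) pairs satisfies the new withColumns bound.
    let mut df = df.withColumns([("age_int", F::col("age").cast("int"))]);

    // An array of (existing, new) pairs satisfies withColumnsRenamed;
    // "full_name" is only an illustrative target name.
    let mut df = df.withColumnsRenamed([("name", "full_name")]);

    df.show(Some(5), None, None).await?;
    Ok(())
}
```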
9 changes: 9 additions & 0 deletions src/expressions.rs
@@ -45,6 +45,15 @@
}
}

impl<const N: usize, T> ToVecExpr for [T; N]
where
T: ToExpr,
{
fn to_vec_expr(&self) -> Vec<spark::Expression> {
self.iter().map(|col| col.to_expr()).collect()
}
}

pub trait ToFilterExpr {
fn to_filter_expr(&self) -> Option<spark::Expression>;
}
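
The `[T; N]` impl above is what lets the updated examples write `df.select([...])`: any fixed-size array whose elements implement `ToExpr` now satisfies `ToVecExpr`. A self-contained sketch of the pattern follows, using toy stand-ins (`String` in place of `spark::Expression`, a hypothetical `Col` type) so it compiles in isolation; it is not the crate's actual code.

```rust
// Simplified stand-ins for the crate's ToExpr / ToVecExpr traits, using String
// instead of spark::Expression so the example runs on its own.
trait ToExpr {
    fn to_expr(&self) -> String;
}

trait ToVecExpr {
    fn to_vec_expr(&self) -> Vec<String>;
}

// The commit's pattern: a fixed-size array of expression-like values
// flattens into a Vec of expressions.
impl<const N: usize, T: ToExpr> ToVecExpr for [T; N] {
    fn to_vec_expr(&self) -> Vec<String> {
        self.iter().map(|col| col.to_expr()).collect()
    }
}

// Toy column type standing in for the crate's Column.
struct Col(&'static str);

impl ToExpr for Col {
    fn to_expr(&self) -> String {
        self.0.to_string()
    }
}

fn main() {
    // An array literal now satisfies ToVecExpr, which is what allows
    // calls like df.select([...]) in the updated examples.
    let exprs = [Col("name"), Col("age")].to_vec_expr();
    println!("{exprs:?}");
}
```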
8 changes: 4 additions & 4 deletions src/functions/mod.rs
@@ -507,14 +507,14 @@
async fn test_func_col_contains() {
let spark = setup().await;

let paths = vec!["/opt/spark/examples/src/main/resources/people.csv".to_string()];
let path = ["/opt/spark/examples/src/main/resources/people.csv"];

let mut df = spark
.read()
.format("csv")
.option("header", "True")
.option("delimiter", ";")
.load(paths);
.load(path);

let row = df
.filter(col("name").contains("e"))
@@ -536,14 +536,14 @@
async fn test_func_col_isin() {
let spark = setup().await;

let paths = vec!["/opt/spark/examples/src/main/resources/people.csv".to_string()];
let path = ["/opt/spark/examples/src/main/resources/people.csv"];

let mut df = spark
.read()
.format("csv")
.option("header", "True")
.option("delimiter", ";")
.load(paths);
.load(path);

let row = df
.filter(col("name").isin(vec!["Jorge", "Bob"]))
6 changes: 3 additions & 3 deletions src/lib.rs
@@ -154,14 +154,14 @@
async fn test_dataframe_read() {
let spark = setup().await;

let paths = vec!["/opt/spark/examples/src/main/resources/people.csv".to_string()];
let path = ["/opt/spark/examples/src/main/resources/people.csv"];

let mut df = spark
.read()
.format("csv")
.option("header", "True")
.option("delimiter", ";")
.load(paths);
.load(path);

let rows = df
.filter("age > 30")
@@ -197,7 +197,7 @@
.read()
.format("csv")
.option("header", "true")
.load(vec![path.to_string()]);
.load([path]);

let total: usize = df
.select(vec![col("range_id")])