Skip to content

Commit

Permalink
add cosine_similarity_3d
Browse files Browse the repository at this point in the history
  • Loading branch information
georgypv committed Jan 30, 2024
1 parent 7a6ff76 commit beac7a1
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "polars_coord_transforms"
version = "0.7.0"
version = "0.8.0"
edition = "2021"

[lib]
Expand Down
8 changes: 8 additions & 0 deletions polars_coord_transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ def cosine_similarity_2d(self, other: pl.Expr) -> pl.Expr:
is_elementwise=True,
args=[other,]
)

def cosine_similarity_3d(self, other: pl.Expr) -> pl.Expr:
    """Build an expression that calls the ``cosine_similarity_3d`` plugin symbol.

    The wrapped expression (``self._expr``) is passed to the plugin as the
    first input and ``other`` as the second. The plugin is registered as
    elementwise.

    Args:
        other: Expression supplying the second input of the plugin call.

    Returns:
        The expression produced by ``register_plugin``.
    """
    plugin_expr = self._expr.register_plugin(
        lib=lib,
        symbol="cosine_similarity_3d",
        is_elementwise=True,
        args=[other],
    )
    return plugin_expr


class CoordTransformExpr(pl.Expr):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
version = "0.7.0"
version = "0.8.0"
authors = [
{name="Georgy Popov"}
]
Expand Down
13 changes: 13 additions & 0 deletions src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@ pub fn cosine_similarity_2d_elementwise(x1: f64, y1: f64, x2: f64, y2: f64) -> f
let magnitude1 = (x1.powi(2) + y1.powi(2)).powf(0.5);
let magnitude2 = (x2.powi(2) + y2.powi(2)).powf(0.5);

let res = if magnitude1 == 0.0 || magnitude2 == 0.0 {
0.0
} else {
dot_product / (magnitude1*magnitude2)
};
res
}

pub fn cosine_similarity_3d_elementwise(x1: f64, y1: f64, z1: f64, x2: f64, y2: f64, z2: f64) -> f64 {
let dot_product = (x1*x2) + (y1*y2) + (z1*z2);
let magnitude1 = (x1.powi(2) + y1.powi(2) + z1.powi(2)).powf(0.5);
let magnitude2 = (x2.powi(2) + y2.powi(2) + z2.powi(2)).powf(0.5);

let res = if magnitude1 == 0.0 || magnitude2 == 0.0 {
0.0
} else {
Expand Down
46 changes: 37 additions & 9 deletions src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,34 @@ fn euclidean_2d(inputs: &[Series]) -> PolarsResult<Series> {
}


/// Computes a row-wise 3D Euclidean distance between two struct columns.
///
/// `inputs[0]` and `inputs[1]` are read as struct series; their coordinate
/// components are obtained through `unpack_xyz` and viewed as `f64` arrays.
/// Each row is passed to `euclidean_3d_elementwise`, and the results are
/// collected into a `Float64` series named `"distance"`.
///
/// # Errors
///
/// Propagates the error from `struct_()` or `f64()` when an input or one of
/// its components does not have the expected type.
///
/// # Panics
///
/// Panics if any of the six coordinates in a row is null, or if fewer than
/// two inputs are supplied.
#[polars_expr(output_type=Float64)]
fn euclidean_3d(inputs: &[Series]) -> PolarsResult<Series> {
    let lhs: &StructChunked = inputs[0].struct_()?;
    let rhs: &StructChunked = inputs[1].struct_()?;

    let (lhs_x, lhs_y, lhs_z) = unpack_xyz(lhs, false);
    let (rhs_x, rhs_y, rhs_z) = unpack_xyz(rhs, false);

    // Walk all six coordinate columns in lockstep, one tuple per row.
    let rows = izip!(
        lhs_x.f64()?,
        lhs_y.f64()?,
        lhs_z.f64()?,
        rhs_x.f64()?,
        rhs_y.f64()?,
        rhs_z.f64()?
    );

    let distances: ChunkedArray<Float64Type> = rows
        .map(|row| match row {
            (Some(x1), Some(y1), Some(z1), Some(x2), Some(y2), Some(z2)) => {
                euclidean_3d_elementwise(x1, y1, z1, x2, y2, z2)
            }
            // A null in any component makes the distance undefined for that row.
            _ => panic!("Unable to find euclidean distance!"),
        })
        .collect_ca_with_dtype("distance", DataType::Float64);

    Ok(distances.into_series())
}


#[polars_expr(output_type=Float64)]
fn cosine_similarity_2d(inputs: &[Series]) -> PolarsResult<Series> {
let ca1: &StructChunked = inputs[0].struct_()?;
Expand Down Expand Up @@ -622,7 +650,7 @@ fn cosine_similarity_2d(inputs: &[Series]) -> PolarsResult<Series> {


#[polars_expr(output_type=Float64)]
fn euclidean_3d(inputs: &[Series]) -> PolarsResult<Series> {
fn cosine_similarity_3d(inputs: &[Series]) -> PolarsResult<Series> {
let ca1: &StructChunked = inputs[0].struct_()?;
let ca2: &StructChunked = inputs[1].struct_()?;

Expand All @@ -631,20 +659,20 @@ fn euclidean_3d(inputs: &[Series]) -> PolarsResult<Series> {

let iter = izip!(
x1.f64()?,
y1.f64()?,
z1.f64()?,
y1.f64()?,
z1.f64()?,
x2.f64()?,
y2.f64()?,
z2.f64()?
).into_iter().map(
y2.f64()?,
z2.f64()?,
).into_iter().map(
|(x1_op, y1_op, z1_op, x2_op, y2_op, z2_op)| {
match (x1_op, y1_op, z1_op, x2_op, y2_op, z2_op) {
(Some(x1), Some(y1), Some(z1), Some(x2), Some(y2), Some(z2),) => euclidean_3d_elementwise(x1, y1, z1, x2, y2, z2),
_ => panic!("Unable to find euclidean distance!")
(Some(x1), Some(y1), Some(z1), Some(x2), Some(y2), Some(z2)) => cosine_similarity_3d_elementwise(x1, y1, z1, x2, y2, z2),
_ => panic!("Unable to find cosine similarity!")
}
});

let out_ca: ChunkedArray<Float64Type> = iter.collect_ca_with_dtype("distance", DataType::Float64);
let out_ca: ChunkedArray<Float64Type> = iter.collect_ca_with_dtype("cosine_similarity", DataType::Float64);
Ok(out_ca.into_series())

}

0 comments on commit beac7a1

Please sign in to comment.