
Add kmeans clustering based on ray #1080

Open · wants to merge 1 commit into main
Conversation

bohou-aryn (Collaborator)

This generally includes three steps:

  1. materialize each document's embedding
  2. initialize centroids randomly
  3. iterate the k-means process until convergence; this is built on the Ray Dataset map-group and aggregate operators (see the sketch below)

The resulting centroids can be used for downstream work.
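For illustration only, here is the same three-step loop as a single-process NumPy sketch. The PR itself runs the assignment and aggregation steps on a Ray Dataset; the function name, seed, and defaults below are assumptions, not the PR's API:

```python
import numpy as np

def kmeans_sketch(embeddings: np.ndarray, K: int, iterations: int, epsilon: float = 1e-4):
    # 2. initialize centroids randomly from the data points
    rng = np.random.default_rng(0)
    centroids = embeddings[rng.choice(len(embeddings), K, replace=False)]
    for _ in range(iterations):
        # assign each embedding to its nearest centroid (the "map" step)
        dists = np.linalg.norm(embeddings[:, None, :] - centroids[None, :, :], axis=2)
        assignments = dists.argmin(axis=1)
        # recompute centroids as per-cluster means (the "aggregate" step)
        new_centroids = np.stack([
            embeddings[assignments == k].mean(axis=0) if (assignments == k).any() else centroids[k]
            for k in range(K)
        ])
        # 3. stop once the centroids move less than epsilon
        shift = np.linalg.norm(new_centroids - centroids)
        centroids = new_centroids
        if shift < epsilon:
            break
    return centroids
```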

@austintlee (Contributor)

Don't we want the result of .kmeans() to be a DocSet? I'm not sure what to do with the resulting array of vectors.

In the summarization case, we could use kmeans/clustering for topic discovery and select a few elements from each cluster and use those to summarize. What I would need is membership info.

@austintlee (Contributor)

Technically, this is k-means, and I am talking about clustering.

@bohou-aryn (Collaborator, Author)

Technically, this is k-means, and I am talking about clustering.

The GitHub comment thread doesn't seem to be working; I couldn't reply to your comment inline. Anyway, the PR already shows how the clustering method works. Basically, you use a similar trick: assign each row a cluster id with a map function that compares against the centroids, then sort or aggregate for your purpose. In your summarization case, you might chain an aggregate function to select N entries from each group, as sketched below.
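A hedged sketch of that assign-then-aggregate trick, in plain Python over dicts with an "embedding" field; both function names are illustrative, not part of the PR:

```python
import numpy as np

def assign_cluster(doc: dict, centroids: np.ndarray) -> dict:
    # map step: tag each document with the id of its nearest centroid
    dists = np.linalg.norm(centroids - np.asarray(doc["embedding"]), axis=1)
    doc["cluster_id"] = int(dists.argmin())
    return doc

def select_n_per_cluster(docs: list[dict], n: int) -> list[dict]:
    # aggregate step: keep up to n documents from each cluster, e.g. for summarization
    picked: dict[int, list[dict]] = {}
    for doc in docs:
        group = picked.setdefault(doc["cluster_id"], [])
        if len(group) < n:
            group.append(doc)
    return [d for group in picked.values() for d in group]
```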

@bsowell (Contributor) left a comment


Can you fix linting and unit tests?

Any idea on how this performs? I don't need a full perf analysis, but I'm just curious whether it feels reasonable on 100 data points, or 1,000?

@@ -903,6 +906,28 @@ def map(self, f: Callable[[Document], Document], **resource_args) -> "DocSet":
mapping = Map(self.plan, f=f, **resource_args)
return DocSet(self.context, mapping)

def kmeans(self, K: int, iterations: int, init_mode: str = "random", epsilon: float = 1e-4):
Contributor:

I get that it's different, but I do think we should have a method that returns a DocSet with an extra field indicating the cluster that each row is assigned to (as sketched below). I know that it is just an extra map, but I expect it's how people would want to use it in practice, and it makes things more chainable.

Collaborator Author:

Yes, sure, I will add one.
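Such a method might look like the following minimal sketch. It reuses DocSet.map (shown in the diff above); the function name with_cluster_ids is hypothetical, and the embedding attribute and properties dict on Document are assumptions on my part:

```python
import numpy as np

def with_cluster_ids(docset, centroids: np.ndarray):
    # one extra map over the DocSet, tagging each Document with the id of
    # its nearest centroid; field/property names here are assumptions
    def tag(doc):
        emb = np.asarray(doc.embedding)
        doc.properties["cluster_id"] = int(np.linalg.norm(centroids - emb, axis=1).argmin())
        return doc
    return docset.map(tag)
```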


Args:
K: the number of centroids
iterations: the maximum number of iterations to run before convergence
Contributor:

Is there a reasonable default for this? I at least wouldn't know what a good value to pick would be.

Collaborator Author:

Spark uses 20; we could follow the same default, but choosing it should really be a tuning process.
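If that default were adopted, the signature might become (illustrative only, not a committed change):

```python
def kmeans(self, K: int, iterations: int = 20, init_mode: str = "random", epsilon: float = 1e-4):
    ...
```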

@bsowell (Contributor) commented Dec 20, 2024

How hard would it be to add a local mode version of this? I guess it would require a local-mode aggregate. That would be convenient, though I confess I'm not sure how important it is.

@bohou-aryn (Collaborator, Author)

Can you fix linting and unit tests?

Any idea on how this performs? I don't need a full perf analysis, but I'm just curious whether it feels reasonable on 100 data points, or 1,000?

Yes, I will fix those.

@bohou-aryn (Collaborator, Author)

How hard would it be to add a local mode version of this? I guess it would require a local-mode aggregate. That would be convenient, though I confess I'm not sure how important it is.

I will try to come up with one.

This generally includes three steps:
1. materialize each document's embedding
2. initialize centroids randomly
3. iterate the k-means process until convergence; this is based on the Ray
   Dataset map-group and aggregate operators.

The resulting centroids can be used for downstream work.