Add kmeans clustering based on ray #1080
base: main
Conversation
Don't we want the result to include cluster membership? In the summarization case, we could use kmeans/clustering for topic discovery, select a few elements from each cluster, and use those to summarize. What I would need is membership info.
Technically, this is k-means, and I am talking about clustering.
The GitHub comment feature seems not to work; your comment could not be replied to in the dialogue. Anyway, the PR already shows how the clustering method works. Basically, you do a similar trick: assign each row a cluster id with a map function by comparing it against the centroids, and then either sort or aggregate for your purpose. In your summarization case, you might chain an aggregate function to select N entries for each group, as sketched below.
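A minimal sketch of that trick, assuming `assigned` is a Ray Dataset whose rows already carry a "cluster" id from the map step; `select_n` is a hypothetical helper, not part of the PR:

```python
import pandas as pd

def select_n(group: pd.DataFrame, n: int = 3) -> pd.DataFrame:
    # Keep the first n members of each cluster as its representatives.
    return group.head(n)

representatives = (
    assigned.groupby("cluster")
    .map_groups(lambda g: select_n(g, n=3), batch_format="pandas")
    .take_all()
)
# `representatives` now holds a few rows per topic cluster, ready to feed
# into a summarization prompt.
```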
Can you fix linting and unit tests?
Any idea how this performs? I don't need a full perf analysis; just curious if it feels reasonable on 100 data points, or 1000?
lib/sycamore/sycamore/docset.py
```
@@ -903,6 +906,28 @@ def map(self, f: Callable[[Document], Document], **resource_args) -> "DocSet":
        mapping = Map(self.plan, f=f, **resource_args)
        return DocSet(self.context, mapping)

    def kmeans(self, K: int, iterations: int, init_mode: str = "random", epsilon: float = 1e-4):
```
I get that it's different, but I do think we should have a method that returns a DocSet with an extra field indicating the cluster that each row is assigned to. I know that it is just an extra map, but I expect it's how people would want to use it in practice, and it makes things more chainable.
Yes, sure, I would add one.
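For illustration, one hypothetical shape such a chainable method could take; the name `assign_clusters`, the `centroids` argument, and the use of `doc.embedding` and `doc.properties` are assumptions here, not the PR's actual API:

```python
import numpy as np

def assign_clusters(self, centroids) -> "DocSet":
    # Hypothetical chainable method: returns a new DocSet in which every
    # document carries the index of its nearest centroid as a property.
    def tag(doc: Document) -> Document:
        dists = [np.linalg.norm(np.asarray(doc.embedding) - c) for c in centroids]
        doc.properties["cluster"] = int(np.argmin(dists))
        return doc

    return self.map(tag)
```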
```
        Args:
            K: the number of centroids
            iterations: the maximum number of iterations to run before convergence
```
Is there a reasonable default for this? I at least wouldn't know what a good value to pick would be.
Spark uses 20; we could follow the same, but it should really be a tuning process.
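For concreteness, a hypothetical call following the signature in the diff above; the values are illustrative, not recommendations:

```python
# Assumes `docset` is a DocSet whose documents already carry embeddings.
# K=10 and iterations=20 (Spark's default) are arbitrary example values.
centroids = docset.kmeans(K=10, iterations=20)
```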
How hard would it be to add a local mode version of this? I guess it would require a local-mode aggregate. That would be convenient, though I confess I'm not sure how important it is.
Yes, would fix.
Would try to come up with one.
This generally includes three steps: 1. materialize each document's embedding; 2. initialize centroids randomly; 3. iterate the k-means process until convergence, based on the Ray Dataset map-groups and aggregate operators. The resulting centroids can be used for downstream work. A sketch of these steps follows.
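A rough, self-contained sketch of steps 2 and 3 on a plain Ray Dataset (step 1, materializing the embeddings, is assumed to have happened upstream); the helper names and the "embedding" column are illustrative and not taken from the PR:

```python
import numpy as np
import pandas as pd
import ray

def assign_cluster(row: dict, centroids: list) -> dict:
    # Tag the row with the index of its nearest centroid (Euclidean distance).
    dists = [np.linalg.norm(np.asarray(row["embedding"]) - c) for c in centroids]
    row["cluster"] = int(np.argmin(dists))
    return row

def recenter(group: pd.DataFrame) -> pd.DataFrame:
    # Collapse one cluster's members into its new centroid (the member mean).
    mean = np.stack(group["embedding"].to_numpy()).mean(axis=0)
    return pd.DataFrame({"cluster": [int(group["cluster"].iloc[0])], "centroid": [mean]})

def kmeans(ds: "ray.data.Dataset", K: int, iterations: int = 20, epsilon: float = 1e-4):
    # Step 2 (init_mode="random"): seed the centroids from the first K rows.
    centroids = [np.asarray(r["embedding"]) for r in ds.take(K)]
    for _ in range(iterations):
        # Step 3a (map): tag every row with the id of its nearest centroid.
        assigned = ds.map(lambda row, c=centroids: assign_cluster(row, c))
        # Step 3b (group + aggregate): recompute each centroid as the mean
        # of its members.
        rows = (
            assigned.groupby("cluster")
            .map_groups(recenter, batch_format="pandas")
            .take_all()
        )
        new = list(centroids)  # clusters that lost all members keep their old centroid
        for r in rows:
            new[r["cluster"]] = np.asarray(r["centroid"])
        shift = max(np.linalg.norm(a - b) for a, b in zip(centroids, new))
        centroids = new
        if shift < epsilon:  # converged: no centroid moved more than epsilon
            break
    return centroids
```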