Skip to content

Commit

Permalink
Fix transform when using cuML HDBSCAN (#1960)
Browse files Browse the repository at this point in the history
  • Loading branch information
beckernick authored May 7, 2024
1 parent 1aa73b3 commit 30d6fcf
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
6 changes: 3 additions & 3 deletions bertopic/cluster/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import hdbscan
import numpy as np


def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
""" Function used to select the HDBSCAN-like model for generating
predictions and probabilities.
Expand Down Expand Up @@ -51,8 +51,8 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):

str_type_model = str(type(model)).lower()
if "cuml" in str_type_model and "hdbscan" in str_type_model:
from cuml.cluster.hdbscan.prediction import approximate_predict
probabilities = approximate_predict(model, embeddings)
from cuml.cluster import hdbscan as cuml_hdbscan
probabilities = cuml_hdbscan.membership_vector(model, embeddings)
return probabilities

return None
Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,17 @@ def online_topic_model(documents, document_embeddings, embedding_model):
topics.extend(model.topics_)
model.topics_ = topics
return model

@pytest.fixture(scope="session")
def cuml_base_topic_model(documents, document_embeddings, embedding_model):
    """Session-scoped BERTopic model fitted with cuML's UMAP and HDBSCAN.

    cuML is imported lazily inside the fixture, so importing this module
    does not itself require cuML.

    Arguments:
        documents: The documents the model is fitted on.
        document_embeddings: Embeddings passed to `fit` together with
                             the documents.
        embedding_model: The embedding model handed to BERTopic.

    Returns:
        The fitted BERTopic instance.
    """
    from cuml.cluster import HDBSCAN as cuml_hdbscan
    from cuml.manifold import UMAP as cuml_umap

    # Build the GPU dimensionality-reduction and clustering steps up front.
    reducer = cuml_umap(n_components=5, n_neighbors=5, random_state=42)
    clusterer = cuml_hdbscan(min_cluster_size=3, prediction_data=True)

    topic_model = BERTopic(
        embedding_model=embedding_model,
        calculate_probabilities=True,
        umap_model=reducer,
        hdbscan_model=clusterer,
    )
    topic_model.fit(documents, document_embeddings)
    return topic_model
16 changes: 15 additions & 1 deletion tests/test_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
import pytest
from bertopic import BERTopic

def cuml_available():
    """Report whether the optional cuML package can be imported.

    Returns:
        True if `import cuml` succeeds, False if it raises ImportError.
    """
    try:
        # Imported only to probe availability; the module itself is unused.
        import cuml  # noqa: F401
    except ImportError:
        return False
    return True

@pytest.mark.parametrize(
'model',
Expand All @@ -14,7 +20,10 @@
('online_topic_model'),
('supervised_topic_model'),
('representation_topic_model'),
('zeroshot_topic_model')
('zeroshot_topic_model'),
pytest.param(
"cuml_base_topic_model", marks=pytest.mark.skipif(not cuml_available(), reason="cuML not available")
),
])
def test_full_model(model, documents, request):
""" Tests the entire pipeline in one go. This serves as a sanity check to see if the default
Expand All @@ -26,6 +35,11 @@ def test_full_model(model, documents, request):
if model == "base_topic_model":
topic_model.save("model_dir", serialization="pytorch", save_ctfidf=True, save_embedding_model="sentence-transformers/all-MiniLM-L6-v2")
topic_model = BERTopic.load("model_dir")

if model == "cuml_base_topic_model":
assert "cuml" in str(type(topic_model.umap_model)).lower()
assert "cuml" in str(type(topic_model.hdbscan_model)).lower()

topics = topic_model.topics_

for topic in set(topics):
Expand Down

0 comments on commit 30d6fcf

Please sign in to comment.