From 30d6fcf029368263331b838a726a3d531627ac5d Mon Sep 17 00:00:00 2001 From: Nick Becker Date: Tue, 7 May 2024 10:14:18 -0400 Subject: [PATCH] Fix `transform` when using cuML HDBSCAN (#1960) --- bertopic/cluster/_utils.py | 6 +++--- tests/conftest.py | 14 ++++++++++++++ tests/test_bertopic.py | 16 +++++++++++++++- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/bertopic/cluster/_utils.py b/bertopic/cluster/_utils.py index 355a53f6..4e1805cc 100644 --- a/bertopic/cluster/_utils.py +++ b/bertopic/cluster/_utils.py @@ -1,7 +1,7 @@ import hdbscan import numpy as np - + def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): """ Function used to select the HDBSCAN-like model for generating predictions and probabilities. @@ -51,8 +51,8 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): str_type_model = str(type(model)).lower() if "cuml" in str_type_model and "hdbscan" in str_type_model: - from cuml.cluster.hdbscan.prediction import approximate_predict - probabilities = approximate_predict(model, embeddings) + from cuml.cluster import hdbscan as cuml_hdbscan + probabilities = cuml_hdbscan.membership_vector(model, embeddings) return probabilities return None diff --git a/tests/conftest.py b/tests/conftest.py index 23c68a0d..24e44f95 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -136,3 +136,17 @@ def online_topic_model(documents, document_embeddings, embedding_model): topics.extend(model.topics_) model.topics_ = topics return model + +@pytest.fixture(scope="session") +def cuml_base_topic_model(documents, document_embeddings, embedding_model): + from cuml.cluster import HDBSCAN as cuml_hdbscan + from cuml.manifold import UMAP as cuml_umap + + model = BERTopic( + embedding_model=embedding_model, + calculate_probabilities=True, + umap_model=cuml_umap(n_components=5, n_neighbors=5, random_state=42), + hdbscan_model=cuml_hdbscan(min_cluster_size=3, prediction_data=True), + ) + model.fit(documents, document_embeddings) + return model diff --git a/tests/test_bertopic.py b/tests/test_bertopic.py index 88be7457..5d4bfac8 100644 --- a/tests/test_bertopic.py +++ b/tests/test_bertopic.py @@ -2,6 +2,12 @@ import pytest from bertopic import BERTopic +def cuml_available(): + try: + import cuml + return True + except ImportError: + return False @pytest.mark.parametrize( 'model', @@ -14,7 +20,10 @@ ('online_topic_model'), ('supervised_topic_model'), ('representation_topic_model'), - ('zeroshot_topic_model') + ('zeroshot_topic_model'), + pytest.param( + "cuml_base_topic_model", marks=pytest.mark.skipif(not cuml_available(), reason="cuML not available") + ), ]) def test_full_model(model, documents, request): """ Tests the entire pipeline in one go. This serves as a sanity check to see if the default @@ -26,6 +35,11 @@ def test_full_model(model, documents, request): if model == "base_topic_model": topic_model.save("model_dir", serialization="pytorch", save_ctfidf=True, save_embedding_model="sentence-transformers/all-MiniLM-L6-v2") topic_model = BERTopic.load("model_dir") + + if model == "cuml_base_topic_model": + assert "cuml" in str(type(topic_model.umap_model)).lower() + assert "cuml" in str(type(topic_model.hdbscan_model)).lower() + topics = topic_model.topics_ for topic in set(topics):