Softcluster #19

Merged
merged 11 commits
Feb 26, 2025
Binary file modified docs/imgs/cluster_workflow.png
61 changes: 49 additions & 12 deletions sdcat/cluster/cluster.py
@@ -13,6 +13,7 @@

import seaborn as sns
import numpy as np
import hdbscan
from umap import UMAP
from hdbscan import HDBSCAN
from sklearn.metrics.pairwise import cosine_similarity
@@ -111,8 +112,13 @@ def _run_hdbscan_assign(
:param out_path: The output path to save the clustering artifacts to
:return: The average similarity score for each cluster, exemplar_df, cluster ids, cluster means, and coverage
"""
info(f'Clustering using HDBSCAN using alpha {alpha} cluster_selection_epsilon {cluster_selection_epsilon} '
f'min_samples {min_samples} use_tsne {use_tsne} ...')
info(f'Clustering using HDBSCAN with: \n'
f'alpha {alpha} \n'
f'cluster_selection_epsilon {cluster_selection_epsilon} \n'
f'min_samples {min_samples} \n'
f'min_cluster_size {min_cluster_size} \n'
f'cluster_selection_method {cluster_selection_method} \n'
f'use_tsne {use_tsne} ...')

# Remove any existing cluster images in the output_path
for c in out_path.parent.rglob(f'{prefix}_*cluster*.png'):
@@ -160,6 +166,7 @@ def _run_hdbscan_assign(
labels = scan.fit_predict(x)
else:
scan = HDBSCAN(
prediction_data=True,
metric='l2',
allow_single_cluster=True,
min_cluster_size=min_cluster_size,
@@ -221,14 +228,42 @@ def _run_hdbscan_assign(
clustered = labels >= 0
coverage = np.sum(clustered) / num_samples
if coverage < 1.0:
# Reassign based on the closest distance to exemplar
for i, label in enumerate(labels):
if label == -1:
similarity_scores = cosine_similarity(image_emb[i].reshape(1, -1), exemplar_emb)
closest_match_index = np.argmax(similarity_scores)
# Only reassign if the similarity score is above the threshold
if similarity_scores[0][closest_match_index] >= min_similarity:
labels[i] = closest_match_index
mixed_points = []
if cluster_selection_method == 'leaf': # Only tested with leaf; eom fails
clusterer = scan.fit(x)

# Credit to hdbscan docs https://hdbscan.readthedocs.io/en/latest/soft_clustering.html
def top_two_probs_diff(probs):
sorted_probs = np.sort(probs)
return sorted_probs[-1] - sorted_probs[-2]

# Get the soft cluster assignments
soft_clusters = hdbscan.all_points_membership_vectors(clusterer)
# Compute the differences between the top two probabilities
diffs = np.array([top_two_probs_diff(x) for x in soft_clusters])
mean_diffs = np.mean(diffs)
std_diffs = np.std(diffs)
mean_cluster_probs = np.mean(np.max(soft_clusters, axis=1))
std_cluster_probs = np.std(np.max(soft_clusters, axis=1))
info(f'Mean cluster probability: {mean_cluster_probs:.4f} std {std_cluster_probs:.4f}')
info(f'Difference between top two probabilities: {mean_diffs:.4f} std {std_diffs:.4f}')
cut_off_diff = mean_diffs + 2 * std_diffs
# Select out the indices that have a small difference, and a larger total probability
mixed_points = np.where((diffs < cut_off_diff) & (np.sum(soft_clusters, axis=1) > 0.6))[0]
else:
warn('Only leaf method is supported for soft clustering')

if len(mixed_points) > 0:
reassign_labels = mixed_points
else:
reassign_labels = np.where(labels == -1)[0]
# Reassign based on the soft clustering only if the point is very similar to an exemplar
for idx in reassign_labels:
similarity_scores = cosine_similarity(image_emb[idx].reshape(1, -1), exemplar_emb)
closest_match_index = np.argmax(similarity_scores)
# Only reassign if the similarity score is above the threshold
if similarity_scores[0][closest_match_index] >= min_similarity:
labels[idx] = closest_match_index

clusters = [[] for _ in range(len(unique_clusters))]

@@ -324,7 +359,9 @@ def cluster_vits(
use_tsne: bool = False,
skip_visualization: bool = False,
remove_bad_images: bool = False,
roi: bool = False) -> pd.DataFrame:
roi: bool = False,
batch_size: int = 32
) -> pd.DataFrame:
""" Cluster the crops using the VITS embeddings.
:param prefix: A unique prefix to save artifacts from clustering
:param model: The model to use for clustering
@@ -392,7 +429,7 @@ def cluster_vits(
# Skip the embedding extraction if all the embeddings are cached
if num_cached != len(images):
debug(f'Extracted embeddings from {len(images)} images using model {model}...')
compute_norm_embedding(model, images, device)
compute_norm_embedding(model, images, device, batch_size)

# Fetch the cached embeddings
debug('Fetching embeddings ...')
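
The reassignment step above relies on hdbscan's soft clustering support, which is why the clusterer is now built with prediction_data=True. Below is a minimal, self-contained sketch of the same idea; the 0.6 membership-mass threshold and the mean + 2*std cut-off mirror the diff, while the toy make_blobs data and variable names are illustrative only.

import hdbscan
import numpy as np
from sklearn.datasets import make_blobs

# Toy data standing in for the image embeddings (illustrative only)
x, _ = make_blobs(n_samples=200, centers=4, n_features=16, random_state=0)

# prediction_data=True is required by all_points_membership_vectors()
clusterer = hdbscan.HDBSCAN(min_cluster_size=5, cluster_selection_method='leaf',
                            prediction_data=True).fit(x)

# Per-point membership probabilities for every cluster (soft clustering)
soft_clusters = hdbscan.all_points_membership_vectors(clusterer)

# A point is "mixed" when its two strongest memberships are nearly equal
# (assumes at least two clusters were found)
sorted_probs = np.sort(soft_clusters, axis=1)
diffs = sorted_probs[:, -1] - sorted_probs[:, -2]
cut_off = diffs.mean() + 2 * diffs.std()
mixed_points = np.where((diffs < cut_off) & (soft_clusters.sum(axis=1) > 0.6))[0]
print(f'{len(mixed_points)} of {len(x)} points flagged for reassignment')

In the diff, those flagged points (or, when none are found, the HDBSCAN noise points) are only moved into a cluster when their cosine similarity to that cluster's exemplar clears min_similarity.
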
10 changes: 6 additions & 4 deletions sdcat/cluster/commands.py
@@ -32,11 +32,12 @@
@common_args.cluster_selection_epsilon
@common_args.cluster_selection_method
@common_args.min_cluster_size
@common_args.batch_size
@click.option('--det-dir', help='Input folder(s) with raw detection results', multiple=True, required=True)
@click.option('--save-dir', help='Output directory to save clustered detection results', required=True)
@click.option('--device', help='Device to use, e.g. cpu or cuda:0', type=str, default='cpu')
@click.option('--use-vits', help='Set to use the predictions from the vits cluster model', is_flag=True)
def run_cluster_det(det_dir, save_dir, device, use_vits, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, start_image, end_image, use_tsne, skip_visualization):
def run_cluster_det(det_dir, save_dir, device, use_vits, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, batch_size, start_image, end_image, use_tsne, skip_visualization):
config = cfg.Config(config_ini)
max_area = int(config('cluster', 'max_area'))
min_area = int(config('cluster', 'min_area'))
@@ -258,7 +259,7 @@ def is_day(utc_dt):
df_cluster = cluster_vits(prefix, model, df, save_dir, alpha, cluster_selection_epsilon, cluster_selection_method,
min_similarity, min_cluster_size, min_samples, device, use_tsne=use_tsne,
skip_visualization=skip_visualization, roi=False, use_vits=use_vits,
remove_bad_images=remove_bad_images)
remove_bad_images=remove_bad_images, batch_size=batch_size)

# Merge the results with the original DataFrame
df.update(df_cluster)
@@ -277,11 +278,12 @@ def is_day(utc_dt):
@common_args.cluster_selection_epsilon
@common_args.cluster_selection_method
@common_args.min_cluster_size
@common_args.batch_size
@click.option('--roi-dir', help='Input folder(s) with raw ROI images', multiple=True, required=True)
@click.option('--save-dir', help='Output directory to save clustered detection results', required=True)
@click.option('--device', help='Device to use, e.g. cpu or cuda:0', type=str)
@click.option('--use-vits', help='Set to use the predictions from the vits cluster model', is_flag=True)
def run_cluster_roi(roi_dir, save_dir, device, use_vits, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, use_tsne, skip_visualization):
def run_cluster_roi(roi_dir, save_dir, device, use_vits, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, batch_size, use_tsne, skip_visualization):
config = cfg.Config(config_ini)
min_samples = int(config('cluster', 'min_samples'))
alpha = alpha if alpha else float(config('cluster', 'alpha'))
@@ -372,7 +374,7 @@ def run_cluster_roi(roi_dir, save_dir, device, use_vits, config_ini, alpha, clus
min_similarity, min_cluster_size, min_samples, device,
use_tsne=use_tsne, use_vits=use_vits,
skip_visualization=skip_visualization, roi=True,
remove_bad_images=remove_bad_images)
remove_bad_images=remove_bad_images, batch_size=batch_size)

# Merge the results with the original DataFrame
df.update(df_cluster)
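
Both commands pick up the new --batch-size flag through the shared decorator in sdcat/common_args.py and simply forward the value to cluster_vits. A rough sketch of that reusable-option pattern follows; the run_cluster_demo command and its body are hypothetical, only the decorator style comes from the diff.

import click

# Reusable option object, mirroring sdcat.common_args.batch_size
batch_size = click.option('--batch-size', type=int, default=32,
                          help='Batch size for processing images. Default is 32')

@click.command()
@batch_size
@click.option('--save-dir', required=True, help='Output directory')
def run_cluster_demo(batch_size, save_dir):
    """Hypothetical command showing how the shared option reaches the clustering call."""
    click.echo(f'Would call cluster_vits(..., batch_size={batch_size}) and write to {save_dir}')

if __name__ == '__main__':
    run_cluster_demo()
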
13 changes: 6 additions & 7 deletions sdcat/cluster/embedding.py
@@ -117,15 +117,13 @@ def encode_image(filename):
return keep


def compute_embedding_vits(vit:ViTWrapper, images: list):
def compute_embedding_vits(vit:ViTWrapper, images: list, batch_size:int=32):
"""
Compute the embedding for the given images using the given model
:param vit: Wrapper for the ViT model
:param images: List of image filenames
:param model_name: Name of the model (i.e. google/vit-base-patch16-224, dinov2_vits16, etc.)
:param device: Device to use for the computation (cpu or cuda:0, cuda:1, etc.)
:param batch_size: Number of images to process in a batch
"""
batch_size = 32
model_name = vit.model_name

# Batch process the images
@@ -146,13 +144,14 @@ def compute_embedding_vits(vit:ViTWrapper, images: list):
err(f'Error processing {batch}: {e}')


def compute_norm_embedding(model_name: str, images: list, device: str = "cpu"):
def compute_norm_embedding(model_name: str, images: list, device: str = "cpu", batch_size: int = 32):
"""
Compute the embedding for a list of images and save them to disk.
Args:
:param images: List of image paths
:param model_name: Name of the model to use for the embedding generation
:param device: Device to use for the computation (cpu or cuda:0, cuda:1, etc.)
:param batch_size: Number of images to process in a batch
Returns:

"""
@@ -164,14 +163,14 @@ def compute_norm_embedding(model_name: str, images: list, device: str = "cpu"):

# If using a GPU, skip the parallel CPU processing
if torch.cuda.is_available():
compute_embedding_vits(vit_wrapper, images)
compute_embedding_vits(vit_wrapper, images, batch_size)
else:
# Use a pool of processes to speed up the embedding generation 20 images at a time on each process
num_processes = min(multiprocessing.cpu_count(), len(images) // 20)
num_processes = max(1, num_processes)
info(f'Using {num_processes} processes to compute {len(images)} embeddings 20 at a time ...')
with multiprocessing.Pool(num_processes) as pool:
args = [(vit_wrapper, images[i:i + 20]) for i in range(0, len(images), 20)]
args = [(vit_wrapper, images[i:i + 20], batch_size) for i in range(0, len(images), 20)]
pool.starmap(compute_embedding_vits, args)


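
On CPU, the embeddings are computed by a process pool: each worker gets a 20-image slice and the new batch_size is forwarded through starmap. A simplified sketch of that fan-out, with a stand-in worker in place of the real ViT wrapper and hypothetical file names, might look like this.

import multiprocessing

def embed_chunk(model_name: str, images: list, batch_size: int = 32):
    # Stand-in for compute_embedding_vits: walk the chunk in batches of batch_size
    for i in range(0, len(images), batch_size):
        batch = images[i:i + batch_size]
        print(f'{model_name}: embedding {len(batch)} images')

if __name__ == '__main__':
    images = [f'crop_{i}.png' for i in range(95)]   # hypothetical crops
    chunk = 20                                      # images handed to each worker process
    num_processes = max(1, min(multiprocessing.cpu_count(), len(images) // chunk))
    args = [('facebook/dino-vits8', images[i:i + chunk], 32) for i in range(0, len(images), chunk)]
    with multiprocessing.Pool(num_processes) as pool:
        pool.starmap(embed_chunk, args)
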
8 changes: 5 additions & 3 deletions sdcat/common_args.py
@@ -32,15 +32,17 @@

cluster_selection_method = click.option('--cluster-selection-method',
type=str,
default='leaf',
help='Method for selecting the optimal number of clusters. '
help='Method for selecting the optimal number of clusters. '
'Default is leaf. Options are leaf and eom')

min_cluster_size = click.option('--min-cluster-size',
type=int,
help='The minimum number of samples in a group for that group to be considered a cluster. '
'Default is 2. Increase for less conservative clustering, e.g. 5, 15')

batch_size = click.option('--batch-size',
type=int,
default=32,
help='Batch size for processing images. Default is 32')
use_tsne = click.option('--use-tsne',
is_flag=True,
help='Use t-SNE for dimensionality reduction. Default is False')
5 changes: 2 additions & 3 deletions sdcat/config/config.ini
@@ -17,10 +17,10 @@ remove_bad_images = False
min_saliency = 30
# Alpha is a parameter that controls the linkage. Don't change it unless you know what you are doing.
# See https://hdbscan.readthedocs.io/en/latest/parameter_selection.html
alpha = 0.92
alpha = 0.7
# Epsilon is a parameter that controls the linkage. Don't change it unless you know what you are doing.
# Increasing this will make the clustering more conservative
cluster_selection_epsilon = 0.0
cluster_selection_epsilon = 0.2
# The method used to select clusters from the condensed tree. leaf is the most conservative; eom is the most aggressive
cluster_selection_method = leaf
# The minimum number of samples in a group for that group to be
@@ -46,7 +46,6 @@ min_similarity = 0.70
model = google/vit-base-patch16-224
;model = facebook/dino-vits8
;model = facebook/dino-vits16
;model = google/vit-base-patch16-224-in21k
;model = MBARI-org/mbari-uav-vit-b-16

[detect]
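
Tying the config changes back to the clustering call: commands.py reads these values through cfg.Config and hands them to HDBSCAN. A minimal sketch of that wiring is below, using the standard-library configparser instead of sdcat's own Config wrapper and assuming the [cluster] section keeps the keys shown above (plus min_samples and min_cluster_size, which commands.py also reads).

import configparser
from hdbscan import HDBSCAN

config = configparser.ConfigParser()
config.read('sdcat/config/config.ini')
cluster = config['cluster']

# metric, allow_single_cluster and prediction_data mirror _run_hdbscan_assign in the diff
scan = HDBSCAN(
    prediction_data=True,
    metric='l2',
    allow_single_cluster=True,
    alpha=float(cluster['alpha']),                                           # 0.7
    cluster_selection_epsilon=float(cluster['cluster_selection_epsilon']),   # 0.2
    cluster_selection_method=cluster['cluster_selection_method'],            # leaf
    min_cluster_size=int(cluster['min_cluster_size']),
    min_samples=int(cluster['min_samples']),
)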