Skip to content

Commit

Permalink
Fix #1866 (#1898)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaartenGr authored Apr 18, 2024
1 parent de7376d commit 0cf7d9c
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3241,37 +3241,37 @@ def merge_models(cls, models, min_similarity: float = .7, embedding_model=None):

# Merge Topic Representations
new_topics_dict = {}
new_topic_val = max_topic + 1
for index, new_topic in enumerate(new_topics):
new_topic_val = max_topic + index + 1
new_topics_dict[new_topic] = new_topic_val
merged_topics["topic_representations"][str(new_topic_val)] = selected_topics["topic_representations"][str(new_topic)]
merged_topics["topic_labels"][str(new_topic_val)] = selected_topics["topic_labels"][str(new_topic)]

# Add new aspects
if selected_topics["topic_aspects"]:
aspects_1 = set(merged_topics["topic_aspects"].keys())
aspects_2 = set(selected_topics["topic_aspects"].keys())
aspects_diff = aspects_2.difference(aspects_1)
if aspects_diff:
for aspect in aspects_diff:
merged_topics["topic_aspects"][aspect] = {}

# If the original model does not have topic aspects but the to be added model does
if not merged_topics.get("topic_aspects"):
merged_topics["topic_aspects"] = selected_topics["topic_aspects"]

# If they both contain topic aspects, add to the existing set of aspects
else:
for aspect, values in selected_topics["topic_aspects"].items():
merged_topics["topic_aspects"][aspect][str(new_topic_val)] = values[str(new_topic)]
for new_topic in new_topics:
if new_topic != -1:
max_topic += 1
new_topics_dict[new_topic] = max_topic
merged_topics["topic_representations"][str(max_topic)] = selected_topics["topic_representations"][str(new_topic)]
merged_topics["topic_labels"][str(max_topic)] = selected_topics["topic_labels"][str(new_topic)]

# Add new aspects
if selected_topics["topic_aspects"]:
aspects_1 = set(merged_topics["topic_aspects"].keys())
aspects_2 = set(selected_topics["topic_aspects"].keys())
aspects_diff = aspects_2.difference(aspects_1)
if aspects_diff:
for aspect in aspects_diff:
merged_topics["topic_aspects"][aspect] = {}

# If the original model does not have topic aspects but the to be added model does
if not merged_topics.get("topic_aspects"):
merged_topics["topic_aspects"] = selected_topics["topic_aspects"]

# If they both contain topic aspects, add to the existing set of aspects
else:
for aspect, values in selected_topics["topic_aspects"].items():
merged_topics["topic_aspects"][aspect][str(max_topic)] = values[str(new_topic)]

# Add new embeddings
new_tensors = tensors[new_topic + selected_topics["_outliers"]]
merged_tensors = np.vstack([merged_tensors, new_tensors])
# Add new embeddings
new_tensors = tensors[new_topic + selected_topics["_outliers"]]
merged_tensors = np.vstack([merged_tensors, new_tensors])

# Topic Mapper
merged_topics["topic_mapper"] = TopicMapper(list(range(-1, new_topic_val+1, 1))).mappings_
merged_topics["topic_mapper"] = TopicMapper(list(range(-1, max_topic+1, 1))).mappings_

# Find similar topics and re-assign those from the new models
sims_idx = np.argmax(sim_matrix, axis=1)
Expand Down

0 comments on commit 0cf7d9c

Please sign in to comment.