Skip to content

Commit

Permalink
Allow sorting plots by custom palette
Browse files Browse the repository at this point in the history
  • Loading branch information
mariya committed Nov 21, 2024
1 parent 7fcccf1 commit 8180386
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions src/conformist/prediction_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,12 @@ def visualize_class_counts(self):
# show the plot
plt.savefig(f'{self.output_dir}/class_counts.png', bbox_inches='tight')

def _sort_class_names_by_palette(self, class_names, custom_color_palette):
if isinstance(custom_color_palette, dict):
class_order = list(custom_color_palette.keys())
class_names = sorted(class_names, key=lambda x: class_order.index(x))
return class_names

def visualize_class_counts_by_dataset(self,
primary_class_only=False,
custom_color_palette=None):
Expand All @@ -243,11 +249,8 @@ def visualize_class_counts_by_dataset(self,

# Get unique class names from ccs
class_names = ccs.index.get_level_values(1).unique().sort_values()

# If color palette specified, sort class_names in same order
if isinstance(custom_color_palette, dict):
class_order = list(custom_color_palette.keys())
class_names = sorted(class_names, key=lambda x: class_order.index(x))
class_names = self._sort_class_names_by_palette(class_names,
custom_color_palette)

# Create a dictionary to map each class to a color
class_to_color = self._class_colors(
Expand Down Expand Up @@ -419,21 +422,28 @@ def visualize_prediction_stripplot(self,
for i in range(0, num_classes, 2):
ax.axhspan(i - 0.5, i + 0.5, facecolor='#eeeeee', alpha=0.5)

class_names = new_df['True class'].unique()
class_names = self._sort_class_names_by_palette(class_names,
custom_color_palette)

sns.stripplot(data=new_df,
x='Softmax score',
y='True class',
hue='Predicted class',
jitter=0.5,
alpha=0.75,
dodge=True,
palette=self._class_colors(),
palette=self._class_colors(
custom_color_palette=custom_color_palette),
size=4,
ax=ax)
ax=ax,
order=class_names)

# Create custom legend handles
class_to_color = self._class_colors(
custom_color_palette=custom_color_palette)
legend_handles = [Patch(color=class_to_color[cls], label=cls) for cls in new_df['Predicted class'].unique()]

legend_handles = [Patch(color=class_to_color[cls], label=cls) for cls in class_names]

# Position the legend to the right of the plot with bars instead of dots
plt.legend(handles=legend_handles, title="Predicted Classes", bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
Expand Down

0 comments on commit 8180386

Please sign in to comment.