Skip to content

Commit

Permalink
Ignore Seaborn plot warnings (#3576)
Browse files Browse the repository at this point in the history
* Ignore Seaborn plot warnings

* Update plots.py

* Update metrics.py
  • Loading branch information
glenn-jocher authored Jun 10, 2021
1 parent 4695ca8 commit 095197b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
9 changes: 6 additions & 3 deletions utils/metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Model validation metrics

import warnings
from pathlib import Path

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -167,9 +168,11 @@ def plot(self, save_dir='', names=()):
fig = plt.figure(figsize=(12, 9), tight_layout=True)
sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
xticklabels=names + ['background FP'] if labels else "auto",
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
xticklabels=names + ['background FP'] if labels else "auto",
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
fig.axes[0].set_xlabel('True')
fig.axes[0].set_ylabel('Predicted')
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
Expand Down
8 changes: 4 additions & 4 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import seaborn as sn
import torch
import yaml
from PIL import Image, ImageDraw, ImageFont
Expand Down Expand Up @@ -291,7 +291,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])

# seaborn correlogram
sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
plt.close()

Expand All @@ -306,8 +306,8 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
ax[0].set_xticklabels(names, rotation=90, fontsize=10)
else:
ax[0].set_xlabel('classes')
sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)

# rectangles
labels[:, 1:3] = 0.5 # center
Expand Down

0 comments on commit 095197b

Please sign in to comment.