-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path07.entropy.py
35 lines (27 loc) · 1.03 KB
/
07.entropy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import matplotlib.pyplot as plt
import numpy as np
def gini(p):
    """Return the Gini impurity of a two-class node.

    ``p`` is the probability of one class; the other class has probability
    ``1 - p``. The impurity is ``sum_k p_k * (1 - p_k)`` over both classes.
    Works on a scalar or element-wise on a NumPy array.
    """
    q = 1 - p  # probability of the complementary class
    return p * q + q * (1 - q)
def entropy(p):
    """Return the binary Shannon entropy (in bits) for class probability ``p``.

    Computes ``-p*log2(p) - (1-p)*log2(1-p)``. Uses the standard limit
    ``0 * log2(0) = 0``, so ``p = 0`` and ``p = 1`` give ``0`` instead of NaN.
    Accepts a scalar or an array-like; arrays are processed element-wise.
    """
    p = np.asarray(p, dtype=float)
    q = 1 - p
    # Substitute 1 inside the log where the probability is exactly 0: the
    # term then evaluates to 0 (the limit) without divide/invalid warnings.
    # Every other input, including out-of-range values, is left untouched.
    return (-p * np.log2(np.where(p == 0, 1.0, p))
            - q * np.log2(np.where(q == 0, 1.0, q)))
def error(p):
    """Return the binary misclassification error ``1 - max(p, 1 - p)``.

    ``p`` is the probability of one class. Accepts a scalar or a NumPy array;
    for an array the error is computed per element.
    """
    # np.maximum is element-wise; np.max over [p, 1 - p] would reduce an
    # array argument to a single scalar.
    return 1 - np.maximum(p, 1 - p)
# Class-1 probabilities sampled on [0, 1) in steps of 0.01.
x = np.arange(0, 1, 0.01)

# Skip p == 0 (None leaves a gap in the plot) so log2(0) is never evaluated.
ent = [entropy(p) if p != 0 else None for p in x]
# Entropy halved: its peak of 1.0 at p = 0.5 becomes 0.5, the same peak
# height as Gini impurity and misclassification error.
sc_ent = [e * 0.5 if e is not None else None for e in ent]
err = [error(p) for p in x]

fig, ax = plt.subplots()

data = [ent, sc_ent, gini(x), err]
titles = ['Entropy', 'Entropy (scaled)', 'Gini impurity', 'Misclassification error']
lines = ['-', '-', '--', '-.']
colors = ['black', 'lightgray', 'red', 'green']
for values, label, style, color in zip(data, titles, lines, colors):
    ax.plot(x, values, label=label, linestyle=style, lw=2, color=color)

# One legend row above the axes, one column per curve.
ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15),
          ncol=len(titles), fancybox=True, shadow=False)
# Dashed guides at 0.5 (peak of Gini and misclassification error) and
# 1.0 (peak of unscaled entropy).
ax.axhline(y=0.5, linewidth=1, color='k', linestyle='--')
ax.axhline(y=1.0, linewidth=1, color='k', linestyle='--')
ax.set_ylim([0, 1.1])
ax.set_xlabel('p(i=1)')
ax.set_ylabel('impurity index')
plt.show()