-
Notifications
You must be signed in to change notification settings - Fork 10
/
cross_entropy.py
30 lines (20 loc) · 1001 Bytes
/
cross_entropy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import tensorflow as tf
def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
    """Class-balanced sigmoid cross-entropy loss.

    As in `Holistically-Nested Edge Detection <http://arxiv.org/abs/1504.06375>`_.
    Positive (rare) pixels are up-weighted by the negative/positive ratio so the
    loss is not dominated by the abundant background class. Computed via
    ``tf.nn.weighted_cross_entropy_with_logits`` for numerical stability.

    :param logits: pre-sigmoid predictions, any shape.
    :param label: ground truth in {0, 1}, same shape as ``logits``.
    :param name: name for the returned tensor.
    :returns: a scalar tensor, the class-balanced cross-entropy loss
        (0 if the batch contains no positive labels).
    """
    y = tf.cast(label, tf.float32)
    count_neg = tf.reduce_sum(1. - y)  # number of 0s in y
    count_pos = tf.reduce_sum(y)       # number of 1s in y (typically << count_neg)
    beta = count_neg / (count_neg + count_pos)
    # weight applied to positive examples: count_neg / count_pos
    pos_weight = beta / (1 - beta)
    # BUG FIX: the TF signature is (labels, logits, pos_weight); the original
    # positional call passed logits as labels and y as logits.
    cost = tf.nn.weighted_cross_entropy_with_logits(
        labels=y, logits=logits, pos_weight=pos_weight)
    cost = tf.reduce_mean(cost * (1 - beta))
    # If there are no positives, pos_weight is inf and cost is NaN; return 0
    # instead (matches the upstream tensorpack HED implementation).
    return tf.where(tf.equal(count_pos, 0.0), 0.0, cost, name=name)