# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: define loss functions of neural network
import math

import paddle
from paddle import _C_ops, fluid, in_dynamic_mode
from paddle.framework import core
from paddle.static.nn.control_flow import Assert
from paddle.utils import deprecated
from ...common_ops_import import Variable
from ...fluid.data_feeder import check_variable_and_dtype
from ...fluid.framework import _current_expected_place
from ...fluid.layer_helper import LayerHelper
from ...tensor.manipulation import reshape
__all__ = []

kIgnoreIndex = -100


def dice_loss(input, label, epsilon=0.00001, name=None):
r"""
    Dice loss for comparing the similarity between the input predictions and the label.
    This implementation is for binary classification, where the input is the sigmoid
    prediction of each pixel, and is usually used for segmentation tasks. The dice loss
    can be defined as the following equation:
.. math::
dice\_loss &= 1 - \frac{2 * intersection\_area}{total\_area} \\
&= \frac{(total\_area - intersection\_area) - intersection\_area}{total\_area} \\
&= \frac{(union\_area - intersection\_area)}{total\_area}
Parameters:
input (Tensor): Tensor, rank>=2, shape is :math:`[N_1, N_2, ..., N_k, D]`, where :math:`N_1` is
the batch_size, :math:`D` is the number of categories. It is usually the output
predictions of sigmoid activation. The data type can be float32 or float64.
        label (Tensor): Tensor, the ground truth with the same rank as input, shape is :math:`[N_1, N_2, ..., N_k, 1]`,
            where :math:`N_1` is the batch_size. The data type can be int32 or int64.
epsilon (float): The epsilon will be added to the numerator and denominator.
If both input and label are empty, it makes sure dice is 1.
Default: 0.00001
        name (str, optional): The default value is None.
            Normally there is no need for the user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
Returns:
        0-D Tensor, whose shape is [] and data type is the same as ``input``.
Example:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.randn((3,224,224,2))
label = paddle.randint(high=2, shape=(3,224,224,1))
predictions = F.softmax(x)
loss = F.dice_loss(input=predictions, label=label)
"""
assert input.dtype in (paddle.float32, paddle.float64)
assert label.dtype in (paddle.int32, paddle.int64)
assert (
len(input.shape) >= 2
), "The rank of input should be greater than or equal to 2."
assert len(input.shape) == len(label.shape), (
"The rank of input and label should be equal, "
"but received input: %d, label: %d."
% (len(input.shape), len(label.shape))
)
assert label.shape[-1] == 1, (
"The last dimension of label should be 1, "
"but received %d." % label.shape[-1]
)
assert (
input.shape[:-1] == label.shape[:-1]
), "All dimensions should be equal except the last one."
assert (
input.numel() > 0 and label.numel() > 0
), "Any dimension of input and label cannot be equal to 0."
label = paddle.squeeze(label, [-1])
label = paddle.nn.functional.one_hot(label, input.shape[-1])
reduce_dim = list(range(1, len(input.shape)))
inse = paddle.sum(input * label, axis=reduce_dim)
dice_denominator = paddle.sum(input, axis=reduce_dim) + paddle.sum(
label, axis=reduce_dim
)
dice_score = 1 - inse * 2 / (dice_denominator + epsilon)
return paddle.mean(dice_score)
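

# Illustrative usage sketch (not part of the original module): a toy
# segmentation batch run through dice_loss. The helper name
# `_dice_loss_example` and all shapes are assumptions for demonstration only.
def _dice_loss_example():
    # per-pixel class probabilities over 3 categories, [N, H, W, D]
    probs = paddle.nn.functional.softmax(paddle.randn((2, 8, 8, 3)), axis=-1)
    label = paddle.randint(high=3, shape=(2, 8, 8, 1))  # int64 class ids in [0, 3)
    return dice_loss(input=probs, label=label)  # 0-D mean dice loss
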
def log_loss(input, label, epsilon=1e-4, name=None):
r"""
**Negative Log Loss Layer**
This layer accepts input predictions and target label and returns the
negative log loss.
.. math::
Out = -label * \log{(input + \epsilon)}
- (1 - label) * \log{(1 - input + \epsilon)}
Args:
input (Tensor|list): A 2-D tensor with shape [N x 1], where N is the
batch size. This input is a probability computed
by the previous operator. Data type float32.
label (Tensor|list): The ground truth which is a 2-D tensor with
shape [N x 1], where N is the batch size.
Data type float32.
epsilon (float, optional): A small number for numerical stability. Default 1e-4.
        name (str|None): For detailed information, please refer to
            :ref:`api_guide_Name`. Usually there is no need to set name, and it is None by default.
Returns:
        Tensor, whose shape is [N, 1] and data type is float32.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
            label = paddle.randint(0, 2, (10, 1)).astype('float32')
            prob = paddle.rand((10, 1))
cost = F.log_loss(input=prob, label=label)
"""
if in_dynamic_mode():
return _C_ops.log_loss(input, label, epsilon)
helper = LayerHelper('log_loss', **locals())
check_variable_and_dtype(input, 'input', ['float32'], 'log_loss')
check_variable_and_dtype(label, 'label', ['float32'], 'log_loss')
loss = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='log_loss',
inputs={'Predicted': [input], 'Labels': [label]},
outputs={'Loss': [loss]},
attrs={'epsilon': epsilon},
)
return loss
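

# Illustrative sketch (demonstration-only helper, not part of the original
# module): log_loss expects probabilities in [0, 1] and binary float32 labels.
def _log_loss_example():
    prob = paddle.rand((10, 1))  # probabilities, e.g. from a sigmoid output
    label = paddle.randint(0, 2, (10, 1)).astype('float32')  # 0/1 targets
    return log_loss(input=prob, label=label)  # element-wise loss, shape [10, 1]
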
def fluid_softmax_with_cross_entropy(
logits,
label,
soft_label=False,
ignore_index=-100,
numeric_stable_mode=True,
return_softmax=False,
axis=-1,
):
r"""
This operator implements the cross entropy loss function with softmax. This function
combines the calculation of the softmax operation and the cross entropy loss function
to provide a more numerically stable gradient.
Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.
    When the attribute :attr:`soft_label` is set to :attr:`False`, this operator
    expects mutually exclusive hard labels: each sample in a batch is in exactly
    one class with a probability of 1.0, and each sample in the batch has a
    single label.
The equation is as follows:
1) Hard label (one-hot label, so every sample has exactly one class)
.. math::
        loss_j = -\text{logits}_{label_j} + \log\left(\sum_{i=0}^{K}\exp(\text{logits}_i)\right),\quad j = 1, ..., K
2) Soft label (each sample can have a distribution over all classes)
.. math::
        loss_j = -\sum_{i=0}^{K}\text{label}_i\left(\text{logits}_i - \log\left(\sum_{i=0}^{K}\exp(\text{logits}_i)\right)\right),\quad j = 1, ..., K
3) If :attr:`numeric_stable_mode` is :attr:`True`, softmax is calculated first by:
.. math::
        max_j &= \max_{i=0}^{K}{\text{logits}_i} \\
        log\_max\_sum_j &= \log\sum_{i=0}^{K}\exp(logits_i - max_j) \\
        softmax_j &= \exp(logits_j - max_j - {log\_max\_sum}_j)
and then cross entropy loss is calculated by softmax and label.
Args:
logits (Tensor): A multi-dimension ``Tensor`` , and the data type is float32 or float64. The input tensor of unscaled log probabilities.
label (Tensor): The ground truth ``Tensor`` , data type is the same
as the ``logits`` . If :attr:`soft_label` is set to :attr:`True`,
Label is a ``Tensor`` in the same shape with :attr:`logits`.
            If :attr:`soft_label` is set to :attr:`False`, Label is a ``Tensor``
            in the same shape with :attr:`logits` except that the size of dimension :attr:`axis` is 1.
        soft_label (bool, optional): A flag to indicate whether to interpret the given
            labels as soft labels. Default False.
ignore_index (int, optional): Specifies a target value that is ignored and does
not contribute to the input gradient. Only valid
if :attr:`soft_label` is set to :attr:`False`.
Default: kIgnoreIndex(-100).
        numeric_stable_mode (bool, optional): A flag to indicate whether to use a more
            numerically stable algorithm. Only valid when :attr:`soft_label` is
            :attr:`False` and GPU is used. When :attr:`soft_label` is :attr:`True`
            or CPU is used, the algorithm is always numerically stable. Note that
            the speed may be slower when using the stable algorithm. Default: True.
return_softmax (bool, optional): A flag indicating whether to return the softmax
along with the cross entropy loss. Default: False.
axis (int, optional): The index of dimension to perform softmax calculations. It
should be in range :math:`[-1, rank - 1]`, while :math:`rank`
is the rank of input :attr:`logits`. Default: -1.
Returns:
        ``Tensor`` or Tuple of two ``Tensor``: Returns the cross entropy loss if
        ``return_softmax`` is False, otherwise the tuple (loss, softmax). The softmax
        has the same shape as the input logits, and the cross entropy loss has the
        same shape as the input logits except that the size of dimension :attr:`axis` is 1.
Examples:
.. code-block:: python
import paddle
logits = paddle.to_tensor([0.4, 0.6, 0.9])
label = paddle.randint(high=2, shape=[1], dtype="int64")
out = paddle.nn.functional.softmax_with_cross_entropy(logits=logits, label=label)
print(out)
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [1.15328646])
"""
input_dims = len(list(logits.shape))
if input_dims == 0:
        raise ValueError('The dimension of input should be larger than zero!')
label_dims = len(list(label.shape))
    if input_dims - 1 != label_dims and input_dims != label_dims:
        raise ValueError(
            'Expected input_dims - 1 == label_dims or input_dims == label_dims '
            '(got input_dims {}, label_dims {})'.format(input_dims, label_dims)
        )
if input_dims - 1 == label_dims:
label = paddle.unsqueeze(label, axis=axis)
if in_dynamic_mode():
softmax, loss = _C_ops.cross_entropy_with_softmax(
logits,
label,
soft_label,
True,
numeric_stable_mode,
ignore_index,
axis,
)
if not return_softmax:
return loss
else:
return loss, softmax
else:
attrs = {
'soft_label': soft_label,
'ignore_index': ignore_index,
'numeric_stable_mode': numeric_stable_mode,
'axis': axis,
}
helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
outputs = {'Softmax': softmax, 'Loss': loss}
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': logits, 'Label': label},
outputs=outputs,
attrs=attrs,
)
if return_softmax:
return loss, softmax
return loss
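

# Illustrative sketch of the hard-label path (demonstration-only helper, not
# part of the original module): 2-D logits with one int64 class index per row.
def _softmax_with_cross_entropy_example():
    logits = paddle.randn((4, 10))  # unscaled scores for 10 classes
    label = paddle.randint(high=10, shape=(4, 1))  # hard labels, int64
    loss, softmax = fluid_softmax_with_cross_entropy(
        logits, label, return_softmax=True
    )
    return loss, softmax  # loss: [4, 1], softmax: [4, 10]
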
def npair_loss(anchor, positive, labels, l2_reg=0.002):
"""
Npair loss requires paired data. Npair loss has two parts: the first part is L2
regularizer on the embedding vector; the second part is cross entropy loss which
takes the similarity matrix of anchor and positive as logits.
For more information, please refer to:
`Improved Deep Metric Learning with Multi class N pair Loss Objective <http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf>`_
Args:
anchor(Tensor): embedding vector for the anchor image. shape=[batch_size, embedding_dims],
the data type is float32 or float64.
positive(Tensor): embedding vector for the positive image. shape=[batch_size, embedding_dims],
the data type is float32 or float64.
labels(Tensor): 1-D tensor. shape=[batch_size], the data type is float32 or float64 or int64.
l2_reg(float32): L2 regularization term on embedding vector, default: 0.002.
Returns:
A 0-D Tensor representing the npair loss, the data type is the same as anchor, the shape is [].
Examples:
.. code-block:: python
import paddle
DATATYPE = "float32"
anchor = paddle.rand(shape=(18, 6), dtype=DATATYPE)
positive = paddle.rand(shape=(18, 6), dtype=DATATYPE)
labels = paddle.rand(shape=(18,), dtype=DATATYPE)
npair_loss = paddle.nn.functional.npair_loss(anchor, positive, labels, l2_reg = 0.002)
print(npair_loss)
"""
    if anchor.size == 0:
        raise ValueError("The anchor tensor must not be empty.")
    if positive.size == 0:
        raise ValueError("The positive tensor must not be empty.")
check_variable_and_dtype(
anchor, 'anchor', ['float32', 'float64'], 'npair_loss'
)
    check_variable_and_dtype(
        positive, 'positive', ['float32', 'float64'], 'npair_loss'
    )
    check_variable_and_dtype(
        labels, 'labels', ['float32', 'float64', 'int64'], 'npair_loss'
    )
Beta = 0.25
batch_size = labels.shape[0]
labels = paddle.reshape(labels, shape=[batch_size, 1])
labels = paddle.tile(labels, repeat_times=[1, batch_size])
labels = paddle.equal(labels, paddle.transpose(labels, perm=[1, 0])).astype(
'float32'
)
labels = labels / paddle.sum(labels, axis=1, keepdim=True)
l2loss = paddle.mean(paddle.sum(paddle.square(anchor), 1)) + paddle.mean(
paddle.sum(paddle.square(positive), 1)
)
l2loss = l2loss * Beta * l2_reg
similarity_matrix = paddle.matmul(
anchor, positive, transpose_x=False, transpose_y=True
)
softmax_ce = fluid_softmax_with_cross_entropy(
logits=similarity_matrix, label=labels, soft_label=True
)
cross_entropy = paddle.sum(labels * softmax_ce, 0)
celoss = paddle.mean(cross_entropy)
return l2loss + celoss
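

# Illustrative sketch (demonstration-only helper, not part of the original
# module): anchors and positives come in aligned pairs, one class id per pair.
def _npair_loss_example():
    anchor = paddle.rand((8, 16))    # [batch_size, embedding_dims]
    positive = paddle.rand((8, 16))  # same shape as anchor
    labels = paddle.randint(high=4, shape=(8,)).astype('float32')  # pair class ids
    return npair_loss(anchor, positive, labels, l2_reg=0.002)  # 0-D loss
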
def square_error_cost(input, label):
r"""
This op accepts input predictions and target label and returns the
squared error cost.
    For an input prediction and a target label, the equation is:
.. math::
Out = (input - label)^2
Parameters:
input (Tensor): Input tensor, the data type should be float32.
label (Tensor): Label tensor, the data type should be float32.
Returns:
Tensor, The tensor storing the element-wise squared error
difference between input and label.
Examples:
.. code-block:: python
import paddle
input = paddle.to_tensor([1.1, 1.9])
label = paddle.to_tensor([1.0, 2.0])
output = paddle.nn.functional.square_error_cost(input, label)
print(output)
# [0.01, 0.01]
"""
if in_dynamic_mode():
minus_out = _C_ops.subtract(input, label)
square_out = _C_ops.square(minus_out)
return square_out
else:
check_variable_and_dtype(
input, "input", ['float32', 'float64'], 'square_error_cost'
)
check_variable_and_dtype(
label, "label", ['float32', 'float64'], 'square_error_cost'
)
helper = LayerHelper('square_error_cost', **locals())
minus_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='elementwise_sub',
inputs={'X': [input], 'Y': [label]},
outputs={'Out': [minus_out]},
)
square_out = helper.create_variable_for_type_inference(
dtype=input.dtype
)
helper.append_op(
type='square',
inputs={'X': [minus_out]},
outputs={'Out': [square_out]},
)
return square_out
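

# Illustrative sketch (demonstration-only helper, not part of the original
# module): square_error_cost returns the unreduced per-element loss; taking
# the mean yields the usual MSE scalar.
def _mse_example():
    pred = paddle.to_tensor([1.1, 1.9], dtype='float32')
    target = paddle.to_tensor([1.0, 2.0], dtype='float32')
    return paddle.mean(square_error_cost(pred, target))  # 0.01
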
def edit_distance(
input,
label,
normalized=True,
ignored_tokens=None,
input_length=None,
label_length=None,
):
"""
    This op computes the edit distance, also called the Levenshtein distance, between a batch of
    hypothesis strings and their references. It measures how dissimilar two strings are by counting
    the minimum number of operations required to transform one string into the other.
    The operations include insertion, deletion, and substitution.
For example, given hypothesis string A = "kitten" and reference
B = "sitting", A will be transformed into B
at least after two substitutions and one insertion:
"kitten" -> "sitten" -> "sittin" -> "sitting"
So the edit distance between A and B is 3.
    The input is a Tensor, and input_length and label_length should be provided as well.
    The `batch_size` of label should be the same as that of `input`.
    The output includes the edit distance between every pair of input and its related label, and the number of sequences.
    If Attr(normalized) is true, the edit distance will be divided by the length of the label.
Parameters:
input(Tensor): The input tensor, its rank should be equal to 2 and its data type should be int64.
label(Tensor): The label tensor, its rank should be equal to 2 and its data type should be int64.
        normalized(bool, default True): Indicates whether to normalize the edit distance.
ignored_tokens(list<int>, default None): Tokens that will be removed before
calculating edit distance.
input_length(Tensor): The length for each sequence in `input` if it's of Tensor type, it should have shape `(batch_size, )` and its data type should be int64.
label_length(Tensor): The length for each sequence in `label` if it's of Tensor type, it should have shape `(batch_size, )` and its data type should be int64.
        NOTE: To avoid unexpected results, the value of every element in input_length and label_length should be equal to the size of the second dimension of input and label respectively. For example, given the input [[1,2,3,4],[5,6,7,8],[9,10,11,12]] with shape [3, 4], the input_length should be [4, 4, 4].
Returns:
Tuple:
distance(Tensor): edit distance result, its data type is float32, and its shape is (batch_size, 1).
sequence_num(Tensor): sequence number, its data type is float32, and its shape is (1,).
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
input = paddle.to_tensor([[1,2,3],[4,5,6],[4,4,4],[1,1,1]], dtype='int64')
label = paddle.to_tensor([[1,3,4,1],[4,5,8,1],[7,7,7,1],[1,1,1,1]], dtype='int64')
input_len = paddle.to_tensor([3,3,3,3], dtype='int64')
label_len = paddle.to_tensor([4,4,4,4], dtype='int64')
distance, sequence_num = F.loss.edit_distance(input=input, label=label, input_length=input_len, label_length=label_len, normalized=False)
# print(distance)
# [[3.]
# [2.]
# [4.]
# [1.]]
# if set normalized to True
# [[0.75]
# [0.5 ]
# [1. ]
            # [0.25]]
#
# print(sequence_num)
# [4]
"""
helper = LayerHelper("edit_distance", **locals())
# remove some tokens from input and labels
if ignored_tokens is not None and len(ignored_tokens) > 0:
erased_input = helper.create_variable_for_type_inference(dtype="int64")
erased_label = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="sequence_erase",
inputs={"X": [input]},
outputs={"Out": [erased_input]},
attrs={"tokens": ignored_tokens},
)
input = erased_input
helper.append_op(
type="sequence_erase",
inputs={"X": [label]},
outputs={"Out": [erased_label]},
attrs={"tokens": ignored_tokens},
)
label = erased_label
if in_dynamic_mode():
return _C_ops.edit_distance(
input, label, input_length, label_length, normalized
)
check_variable_and_dtype(input, 'input', ['int64'], 'edit_distance')
check_variable_and_dtype(label, 'label', ['int64'], 'edit_distance')
this_inputs = {"Hyps": [input], "Refs": [label]}
if input_length is not None and label_length is not None:
this_inputs['HypsLength'] = [input_length]
this_inputs['RefsLength'] = [label_length]
# edit distance op
edit_distance_out = helper.create_variable_for_type_inference(dtype="int64")
sequence_num = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="edit_distance",
inputs=this_inputs,
outputs={"Out": [edit_distance_out], "SequenceNum": [sequence_num]},
attrs={"normalized": normalized},
)
return edit_distance_out, sequence_num
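

# Illustrative sketch (demonstration-only helper, not part of the original
# module): two batches of int64 token sequences compared without length
# normalization.
def _edit_distance_example():
    hyps = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype='int64')
    refs = paddle.to_tensor([[1, 3, 4, 1], [4, 5, 8, 1]], dtype='int64')
    hyp_len = paddle.to_tensor([3, 3], dtype='int64')
    ref_len = paddle.to_tensor([4, 4], dtype='int64')
    distance, seq_num = edit_distance(
        hyps, refs, normalized=False,
        input_length=hyp_len, label_length=ref_len,
    )
    return distance, seq_num  # distance: [2, 1], seq_num: [1]
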
def binary_cross_entropy(
input, label, weight=None, reduction='mean', name=None
):
"""
Measure the binary_cross_entropy loss between input predictions ``input``
and target labels ``label`` . The binary_cross_entropy loss can be described as:
If :attr:`weight` is set, the loss is:
.. math::
Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input))
If :attr:`weight` is None, the loss is:
.. math::
Out = -1 * (label * log(input) + (1 - label) * log(1 - input))
If :attr:`reduction` set to ``'none'``, the interface will return the original loss `Out`.
If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:
.. math::
Out = MEAN(Out)
If :attr:`reduction` set to ``'sum'``, the reduced sum loss is:
.. math::
Out = SUM(Out)
    Note that the input predictions ``input`` should always be the output of sigmoid, and the target labels ``label``
    should be numbers between 0 and 1.
Parameters:
        input (Tensor): The input predictions tensor. 2-D tensor with shape: [N, *],
            N is batch_size, `*` means number of additional dimensions. The ``input``
            should always be the output of sigmoid. Available dtype is float16, float32, float64.
label (Tensor): The target labels tensor. 2-D tensor with the same shape as
``input``. The target labels which values should be numbers between 0 and 1.
Available dtype is float16, float32, float64.
weight (Tensor, optional): A manual rescaling weight given to the loss of each
batch element. If given, has to be a Tensor of size nbatch and the data type
is float32, float64. Default is ``'None'``.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default is ``'mean'``.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
        Tensor. If ``reduction`` is ``'none'``, the shape of the output is
        the same as ``input``, otherwise the output is a scalar.
Examples:
.. code-block:: python
import paddle
input = paddle.to_tensor([0.5, 0.6, 0.7], 'float32')
label = paddle.to_tensor([1.0, 0.0, 1.0], 'float32')
output = paddle.nn.functional.binary_cross_entropy(input, label)
print(output) # 0.65537095
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in binary_cross_entropy should be 'sum', "
"'mean' or 'none', but received %s, which is not allowed."
% reduction
)
if in_dynamic_mode():
out = _C_ops.bce_loss(input, label)
if weight is not None:
out = _C_ops.multiply(out, weight, 'axis', -1)
if reduction == 'sum':
return _C_ops.sum(out, [], None, False)
elif reduction == 'mean':
return _C_ops.mean_all(out)
else:
return out
else:
check_variable_and_dtype(
input,
'input',
['float16', 'float32', 'float64'],
'binary_cross_entropy',
)
check_variable_and_dtype(
label,
'label',
['float16', 'float32', 'float64'],
'binary_cross_entropy',
)
sub_name = name if weight is None and reduction == 'none' else None
helper = LayerHelper("binary_cross_entropy", name=sub_name)
out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='bce_loss',
inputs={
'X': [input],
'Label': [label],
},
outputs={'Out': [out]},
)
if weight is not None:
if isinstance(weight, paddle.static.Variable):
weight_name = name if reduction == 'none' else None
out = paddle.multiply(out, weight, name=weight_name)
else:
raise ValueError(
"The weight is not a Tensor, please convert to Tensor."
)
if reduction == 'sum':
return paddle.sum(out, name=name)
elif reduction == 'mean':
return paddle.mean(out, name=name)
else:
return out
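

# Illustrative sketch (demonstration-only helper, not part of the original
# module): per-element weighting with 'none' reduction versus the default
# mean reduction.
def _binary_cross_entropy_example():
    prob = paddle.to_tensor([0.5, 0.6, 0.7], dtype='float32')  # sigmoid outputs
    label = paddle.to_tensor([1.0, 0.0, 1.0], dtype='float32')
    weight = paddle.to_tensor([1.0, 2.0, 1.0], dtype='float32')
    per_elem = binary_cross_entropy(prob, label, weight=weight, reduction='none')
    mean = binary_cross_entropy(prob, label)  # 0-D mean loss
    return per_elem, mean
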
def binary_cross_entropy_with_logits(
logit, label, weight=None, reduction='mean', pos_weight=None, name=None
):
r"""
Combine the sigmoid layer and the :ref:`api_nn_loss_BCELoss` layer.
This measures the element-wise probability error in classification tasks
in which each class is independent.
This can be thought of as predicting labels for a data-point, where labels
are not mutually exclusive. For example, a news article can be about
politics, technology or sports at the same time or none of these.
    First, calculate the loss function as follows:
.. math::
Out = -Labels * \log(\sigma(Logit)) - (1 - Labels) * \log(1 - \sigma(Logit))
We know that :math:`\sigma(Logit) = \frac{1}{1 + e^{-Logit}}`. By substituting this we get:
.. math::
Out = Logit - Logit * Labels + \log(1 + e^{-Logit})
For stability and to prevent overflow of :math:`e^{-Logit}` when Logit < 0,
we reformulate the loss as follows:
.. math::
        Out = \max(Logit, 0) - Logit * Labels + \log(1 + e^{-|Logit|})
    Then, if ``weight`` or ``pos_weight`` is not None, the corresponding
    weight tensor is multiplied onto the loss `Out`. The ``weight`` tensor attaches a different
    weight to every item in the batch, and ``pos_weight`` attaches a different
    weight to the positive label of each class.
Finally, apply reduce operation on the loss.
If :attr:`reduction` set to ``'none'``, will return the original loss `Out`.
If :attr:`reduction` set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`.
If :attr:`reduction` set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`.
Note that the target labels ``label`` should be numbers between 0 and 1.
Args:
        logit (Tensor): The input predictions tensor. 2-D tensor with shape: [N, *],
            N is batch_size, `*` means number of additional dimensions. The ``logit``
            is usually the output of a Linear layer. Available dtype is float32, float64.
label (Tensor): The target labels tensor. 2-D tensor with the same shape as
``logit``. The target labels which values should be numbers between 0 and 1.
Available dtype is float32, float64.
weight (Tensor, optional): A manual rescaling weight given to the loss of each
batch element. If given, it has to be a 1D Tensor whose size is `[N, ]`,
The data type is float32, float64. Default is ``'None'``.
        reduction (str, optional): Indicate how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default is ``'mean'``.
pos_weight (Tensor, optional): A weight of positive examples. Must be a vector
with length equal to the number of classes. The data type is float32, float64.
Default is ``'None'``.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
        Tensor. If ``reduction`` is ``'none'``, the shape of the output is
        the same as ``logit``, otherwise the output is a scalar.
Examples:
.. code-block:: python
import paddle
logit = paddle.to_tensor([5.0, 1.0, 3.0])
label = paddle.to_tensor([1.0, 0.0, 1.0])
output = paddle.nn.functional.binary_cross_entropy_with_logits(logit, label)
print(output) # 0.45618808
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in binary_cross_entropy_with_logits "
"should be 'sum', 'mean' or 'none', but received %s, which is not allowed."
% reduction
)
if in_dynamic_mode():
one = _C_ops.full(
[1],
float(1.0),
logit.dtype,
_current_expected_place(),
)
out = _C_ops.sigmoid_cross_entropy_with_logits(
logit, label, False, -100
)
if pos_weight is not None:
log_weight = _C_ops.add(
_C_ops.multiply(label, _C_ops.subtract(pos_weight, one)), one
)
out = _C_ops.multiply(out, log_weight)
if weight is not None:
out = _C_ops.multiply(out, weight)
if reduction == "sum":
return _C_ops.sum(out, [], None, False)
elif reduction == "mean":
return _C_ops.mean_all(out)
else:
return out
else:
check_variable_and_dtype(
logit,
'logit',
['float32', 'float64'],
'binary_cross_entropy_with_logits',
)
check_variable_and_dtype(
label,
'label',
['float32', 'float64'],
'binary_cross_entropy_with_logits',
)
sigmoid_name = None
if reduction == 'none' and pos_weight is None and weight is None:
sigmoid_name = name
helper = LayerHelper("sigmoid_cross_entropy_with_logits", **locals())
out = helper.create_variable_for_type_inference(dtype=logit.dtype)
helper.append_op(
type="sigmoid_cross_entropy_with_logits",
inputs={"X": logit, "Label": label},
attrs={"ignore_index": kIgnoreIndex, 'normalize': False},
outputs={"Out": out},
)
one = paddle.full(shape=[1], fill_value=1.0, dtype=logit.dtype)
if pos_weight is not None:
check_variable_and_dtype(
pos_weight,
'pos_weight',
['float32', 'float64'],
'binary_cross_entropy_with_logits',
)
log_weight = paddle.add(
paddle.multiply(label, paddle.subtract(pos_weight, one)), one
)
pos_weight_name = (
name if reduction == 'none' and weight is None else None
)
out = paddle.multiply(out, log_weight, name=pos_weight_name)
if weight is not None:
check_variable_and_dtype(
weight,
'weight',
['float32', 'float64'],
'binary_cross_entropy_with_logits',
)
weight_name = name if reduction == 'none' else None
out = paddle.multiply(out, weight, name=weight_name)
if reduction == "sum":
return paddle.sum(out, name=name)
elif reduction == "mean":
return paddle.mean(out, name=name)
return out
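

# Illustrative sketch (demonstration-only helper, not part of the original
# module): pos_weight up-weights the loss on positive labels, which helps
# with imbalanced binary targets.
def _binary_cross_entropy_with_logits_example():
    logit = paddle.to_tensor([5.0, 1.0, 3.0])
    label = paddle.to_tensor([1.0, 0.0, 1.0])
    pos_weight = paddle.to_tensor([2.0, 2.0, 2.0])  # one weight per class
    return binary_cross_entropy_with_logits(logit, label, pos_weight=pos_weight)
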
def hsigmoid_loss(
input,
label,
num_classes,
weight,
bias=None,
path_table=None,
path_code=None,
is_sparse=False,
name=None,
):
"""
    The hierarchical sigmoid organizes the classes into a complete binary tree to reduce the computational complexity
    and speed up model training, especially the training of language models.
    Each leaf node of the complete binary tree represents a class (word) and each non-leaf node acts as a binary classifier.
    For each class (word), there is a unique path from the root to the class itself; hsigmoid calculates the cost for each
    non-leaf node on the path and sums them up to get the total cost.
    Compared to softmax, hsigmoid can reduce the computational complexity from :math:`O(N)` to :math:`O(logN)`, where :math:`N`
    represents the number of classes or the size of the word dict.
The API supports default tree and custom tree. For the default tree, you can refer to `Hierarchical Probabilistic Neural
Network Language Model <http://www.iro.umontreal.ca/~lisa/pointeurs/hierarchical-nnlm-aistats05.pdf>`_.
    For a custom tree, you need to pass :attr:`path_table` and :attr:`path_code`, built with the following steps (take the language model as an example):

    1. Use a custom word dict to build a binary tree; each leaf node should be a word in the word dict.
    2. Create a dict mapping word_id -> path from the word to the root node; we call it path_table.
    3. Create a dict mapping word_id -> code of the path from the word to the root node; we call it path_code.
       A code is the label of each binary classifier along the path: 1 indicates true, 0 indicates false.
    4. Now each word has its path and the codes along the path, and you can pass a batch of paths and codes
       related to the same batch of inputs.
Parameters:
input (Tensor): A tensor with the shape [N, D], where N is the size of mini-batch,
and D is the feature size. Its data type supports float32 or float64.
label (Tensor): A tensor contains the labels of training data. Its shape is [N, 1]
and data type is int64.
        num_classes (int): The number of classes or the size of the word dict, must be greater than 2.
            If the default tree is used (`path_code` and `path_table` are None), `num_classes`
            should not be None. If the custom tree is used (`path_code` and `path_table` are not None),
            `num_classes` should be the number of non-leaf nodes, which indicates the number of
            classes used by the binary classifiers.
weight (Tensor): A tensor with shape (num_classes - 1, D), with the same data type as `input`.
bias (Tensor, optional): A tensor with shape (num_classes - 1, 1), with the same data type as `input`.
            If `bias` is None, no bias will be added. Default is None.
path_table (Tensor, optional): A tensor that stores each batch of samples' path from leaf to root
node, its shape is [N, L] and data type is int64, where L is the length of path. For each sample i,
path_table[i] is a np.array like structure and each element in this array is the indexes in parent
nodes' weight matrix. If `path_table` and `path_code` are None, the default tree will be used.
Default is None.
path_code (Tensor, optional): A tensor that stores each batch of samples' code of path from leaf
to root node, its shape is [N, L] and data type is int64, which is the same as :attr:`path_table`.
            Each code of a path consists of the codes of the nodes from the leaf to the root node. If `path_table` and
            `path_code` are None, the default tree will be used. Default is None.
        is_sparse (bool, optional): Whether to use sparse updating instead of dense updating. If `is_sparse` is True,
            the gradients of `weight` and `input` will be sparse. Default is False.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A tensor with the cost of hierarchical sigmoid, its shape is [N, 1] and data type is the same as `input`.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
paddle.set_device('cpu')
input = paddle.uniform([4, 3])
# [[0.45424712 -0.77296764 0.82943869] # random
# [0.85062802 0.63303483 0.35312140] # random
# [0.57170701 0.16627562 0.21588242] # random
# [0.27610803 -0.99303514 -0.17114788]] # random
label = paddle.to_tensor([0, 1, 4, 5])
num_classes = 5
weight=paddle.uniform([num_classes-1, 3])
# [[-0.64477652 0.24821866 -0.17456549] # random
# [-0.04635394 0.07473493 -0.25081766] # random
# [ 0.05986035 -0.12185556 0.45153677] # random
# [-0.66236806 0.91271877 -0.88088769]] # random
out=F.hsigmoid_loss(input, label, num_classes, weight)
# [[1.96709502]
# [2.40019274]
# [2.11009121]
# [1.92374969]]
"""
if num_classes < 2:
raise ValueError(f'Expected num_classes >= 2 (got {num_classes})')
if in_dynamic_mode():
out, _, _ = _C_ops.hsigmoid_loss(
input,
label,
weight,
bias,
path_table,
path_code,
num_classes,
is_sparse,
is_sparse,
)
return out
else:
check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'hsigmoid_loss'
)
check_variable_and_dtype(label, 'label', ['int64'], 'hsigmoid_loss')
check_variable_and_dtype(
weight, 'weight', ['float32', 'float64'], 'hsigmoid_loss'
)
if bias is not None:
check_variable_and_dtype(
bias, 'bias', ['float32', 'float64'], 'hsigmoid_loss'
)
if path_table is not None:
check_variable_and_dtype(
path_table, 'path_table', ['int64'], 'hsigmoid_loss'
)
if path_code is not None:
check_variable_and_dtype(
path_code, 'path_code', ['int64'], 'hsigmoid_loss'
)