Skip to content

Commit

Permalink
remove redundant computation in Categorical.probs (#42114)
Browse files: browse the repository at this point in the history
  • Loading branch information
Feiyu Chan authored Apr 25, 2022
1 parent 6553a9d commit 9a0bfec
Showing 1 changed file with 16 additions and 35 deletions.
51 changes: 16 additions & 35 deletions python/paddle/distribution/categorical.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -115,6 +115,8 @@ def __init__(self, logits, name=None):
self.logits = self._to_tensor(logits)[0]
if self.dtype != convert_dtype(self.logits.dtype):
self.logits = tensor.cast(self.logits, dtype=self.dtype)
dist_sum = paddle.sum(self.logits, axis=-1, keepdim=True)
self._prob = self.logits / dist_sum

def sample(self, shape):
"""Generate samples of the specified shape.
Expand Down Expand Up @@ -297,42 +299,21 @@ def probs(self, value):
"""
name = self.name + '_probs'

dist_sum = paddle.sum(self.logits, axis=-1, keepdim=True)
prob = self.logits / dist_sum

shape = list(prob.shape)
value_shape = list(value.shape)
if len(shape) == 1:
num_value_in_one_dist = np.prod(value_shape)
index_value = paddle.reshape(value, [num_value_in_one_dist, 1])
index = index_value
if len(self._prob.shape) == 1: # batch_shape is empty
return paddle.gather(
self._prob, value.reshape(
[-1], name=name), name=name).reshape(
value.shape, name=name)
else:
num_dist = np.prod(shape[:-1])
num_value_in_one_dist = value_shape[-1]
prob = paddle.reshape(prob, [num_dist, shape[-1]])
if len(value_shape) == 1:
value = nn.expand(value, [num_dist])
value_shape = shape[:-1] + value_shape
index_value = paddle.reshape(value, [num_dist, -1, 1])
if shape[:-1] != value_shape[:-1]:
raise ValueError(
"shape of value {} must match shape of logits {}".format(
str(value_shape[:-1]), str(shape[:-1])))

index_prefix = paddle.unsqueeze(
arange(
num_dist, dtype=index_value.dtype), axis=-1)
index_prefix = nn.expand(index_prefix, [1, num_value_in_one_dist])
index_prefix = paddle.unsqueeze(index_prefix, axis=-1)

if index_value.dtype != index_prefix.dtype:
tensor.cast(index_prefix, dtype=index_value.dtype)
index = concat([index_prefix, index_value], axis=-1)

# value is the category index to search for the corresponding probability.
select_prob = gather_nd(prob, index)
return paddle.reshape(select_prob, value_shape, name=name)
if len(value.shape) == 1:
return paddle.take_along_axis(
self._prob,
paddle.reshape(
value, (len(self._prob.shape) - 1) * [1] + [-1],
name=name),
axis=-1)
else:
return paddle.take_along_axis(self._prob, value, axis=-1)

def log_prob(self, value):
"""Log probabilities of the given category. Refer to ``probs`` method.
Expand Down

0 comments on commit 9a0bfec

Please sign in to comment.