Skip to content

Commit

Permalink
change segment_pool
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Aug 2, 2022
1 parent 560ec76 commit c009f77
Showing 1 changed file with 12 additions and 20 deletions.
32 changes: 12 additions & 20 deletions pgl/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,18 @@ def segment_pool(data, segment_ids, pool_type, name=None):
Segment Operator.
"""
pool_type = pool_type.upper()
if _non_static_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype',
pool_type)
return out

check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")

helper = LayerHelper("segment_pool", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
pool_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": pool_ids},
attrs={"pooltype": pool_type})
return out
if pool_type == "SUM":
return paddle.incubate.segment_sum(data, segment_ids, name)
elif pool_type == "MEAN":
return paddle.incubate.segment_mean(data, segment_ids, name)
elif pool_type == "MAX":
return paddle.incubate.segment_max(data, segment_ids, name)
elif pool_type == "MIN":
return paddle.incubate.segment_min(data, segment_ids, name)
else:
raise ValueError(
"We only support sum, mean, max, min pool types in segment_pool function."
)


def segment_sum(data, segment_ids, name=None):
Expand Down

0 comments on commit c009f77

Please sign in to comment.