diff --git a/pgl/math.py b/pgl/math.py index aed60801..6e90795f 100644 --- a/pgl/math.py +++ b/pgl/math.py @@ -25,7 +25,7 @@ import paddle from paddle import _C_ops from paddle.common_ops_import import LayerHelper -from paddle.common_ops_import import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph +from paddle.common_ops_import import _non_static_mode from paddle.common_ops_import import check_variable_and_dtype from pgl.utils.op import get_index_from_counts @@ -35,30 +35,18 @@ def segment_pool(data, segment_ids, pool_type, name=None): Segment Operator. """ pool_type = pool_type.upper() - if in_dygraph_mode(): - out, tmp = _C_ops.final_state_segment_pool(data, segment_ids, - 'pooltype', pool_type) - return out - if _in_legacy_dygraph(): - out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', - pool_type) - return out - - check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") - check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), - "segment_pool") - - helper = LayerHelper("segment_pool", **locals()) - out = helper.create_variable_for_type_inference(dtype=data.dtype) - pool_ids = helper.create_variable_for_type_inference(dtype=data.dtype) - helper.append_op( - type="segment_pool", - inputs={"X": data, - "SegmentIds": segment_ids}, - outputs={"Out": out, - "SummedIds": pool_ids}, - attrs={"pooltype": pool_type}) - return out + if pool_type == "SUM": + return paddle.incubate.segment_sum(data, segment_ids, name) + elif pool_type == "MEAN": + return paddle.incubate.segment_mean(data, segment_ids, name) + elif pool_type == "MAX": + return paddle.incubate.segment_max(data, segment_ids, name) + elif pool_type == "MIN": + return paddle.incubate.segment_min(data, segment_ids, name) + else: + raise ValueError( + "We only support sum, mean, max, min pool types in segment_pool function." + ) def segment_sum(data, segment_ids, name=None): @@ -91,28 +79,7 @@ def segment_sum(data, segment_ids, name=None): """ - if paddle.__version__ >= '2.2.0' or paddle.__version__ == '0.0.0': - return paddle.incubate.segment_sum(data, segment_ids, name) - - if _non_static_mode(): - out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM") - return out - - check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") - check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), - "segment_pool") - - helper = LayerHelper("segment_sum", **locals()) - out = helper.create_variable_for_type_inference(dtype=data.dtype) - summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) - helper.append_op( - type="segment_pool", - inputs={"X": data, - "SegmentIds": segment_ids}, - outputs={"Out": out, - "SummedIds": summed_ids}, - attrs={"pooltype": "SUM"}) - return out + return paddle.incubate.segment_sum(data, segment_ids, name) def segment_mean(data, segment_ids, name=None): @@ -146,28 +113,7 @@ def segment_mean(data, segment_ids, name=None): #Outputs: [[2., 2., 2.], [4., 5., 6.]] """ - if paddle.__version__ >= '2.2.0' or paddle.__version__ == '0.0.0': - return paddle.incubate.segment_mean(data, segment_ids, name) - - if _non_static_mode(): - out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN") - return out - - check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") - check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), - "segment_pool") - - helper = LayerHelper("segment_mean", **locals()) - out = helper.create_variable_for_type_inference(dtype=data.dtype) - summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) - helper.append_op( - type="segment_pool", - inputs={"X": data, - "SegmentIds": segment_ids}, - outputs={"Out": out, - "SummedIds": summed_ids}, - attrs={"pooltype": "MEAN"}) - return out + return paddle.incubate.segment_mean(data, segment_ids, name) def segment_min(data, segment_ids, name=None): @@ -199,28 +145,7 @@ def segment_min(data, segment_ids, name=None): #Outputs: [[1., 2., 1.], [4., 5., 6.]] """ - if paddle.__version__ >= '2.2.0' or paddle.__version__ == '0.0.0': - return paddle.incubate.segment_min(data, segment_ids, name) - - if _non_static_mode(): - out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN") - return out - - check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") - check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), - "segment_pool") - - helper = LayerHelper("segment_min", **locals()) - out = helper.create_variable_for_type_inference(dtype=data.dtype) - summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) - helper.append_op( - type="segment_pool", - inputs={"X": data, - "SegmentIds": segment_ids}, - outputs={"Out": out, - "SummedIds": summed_ids}, - attrs={"pooltype": "MIN"}) - return out + return paddle.incubate.segment_min(data, segment_ids, name) def segment_max(data, segment_ids, name=None): @@ -253,28 +178,7 @@ def segment_max(data, segment_ids, name=None): #Outputs: [[3., 2., 3.], [4., 5., 6.]] """ - if paddle.__version__ >= '2.2.0' or paddle.__version__ == '0.0.0': - return paddle.incubate.segment_max(data, segment_ids, name) - - if _non_static_mode(): - out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX") - return out - - check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") - check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), - "segment_pool") - - helper = LayerHelper("segment_max", **locals()) - out = helper.create_variable_for_type_inference(dtype=data.dtype) - summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) - helper.append_op( - type="segment_pool", - inputs={"X": data, - "SegmentIds": segment_ids}, - outputs={"Out": out, - "SummedIds": summed_ids}, - attrs={"pooltype": "MAX"}) - return out + return paddle.incubate.segment_max(data, segment_ids, name) def segment_softmax(data, segment_ids):