Skip to content

Commit

Permalink
Merge pull request #445 from DesmonDay/delete_fluid
Browse files Browse the repository at this point in the history
fix segment ops
  • Loading branch information
Yelrose authored Aug 3, 2022
2 parents 3f31bdf + c009f77 commit fa901df
Showing 1 changed file with 17 additions and 113 deletions.
130 changes: 17 additions & 113 deletions pgl/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import paddle
from paddle import _C_ops
from paddle.common_ops_import import LayerHelper
from paddle.common_ops_import import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
from paddle.common_ops_import import _non_static_mode
from paddle.common_ops_import import check_variable_and_dtype
from pgl.utils.op import get_index_from_counts

Expand All @@ -35,30 +35,18 @@ def segment_pool(data, segment_ids, pool_type, name=None):
Segment Operator.
"""
pool_type = pool_type.upper()
if in_dygraph_mode():
out, tmp = _C_ops.final_state_segment_pool(data, segment_ids,
'pooltype', pool_type)
return out
if _in_legacy_dygraph():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype',
pool_type)
return out

check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")

helper = LayerHelper("segment_pool", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
pool_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": pool_ids},
attrs={"pooltype": pool_type})
return out
if pool_type == "SUM":
return paddle.incubate.segment_sum(data, segment_ids, name)
elif pool_type == "MEAN":
return paddle.incubate.segment_mean(data, segment_ids, name)
elif pool_type == "MAX":
return paddle.incubate.segment_max(data, segment_ids, name)
elif pool_type == "MIN":
return paddle.incubate.segment_min(data, segment_ids, name)
else:
raise ValueError(
"We only support sum, mean, max, min pool types in segment_pool function."
)


def segment_sum(data, segment_ids, name=None):
Expand Down Expand Up @@ -91,28 +79,7 @@ def segment_sum(data, segment_ids, name=None):
"""

if paddle.__version__ >= '2.2.0' or paddle.__version__ == '0.0.0':
return paddle.incubate.segment_sum(data, segment_ids, name)

if _non_static_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM")
return out

check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")

helper = LayerHelper("segment_sum", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "SUM"})
return out
return paddle.incubate.segment_sum(data, segment_ids, name)


def segment_mean(data, segment_ids, name=None):
Expand Down Expand Up @@ -146,28 +113,7 @@ def segment_mean(data, segment_ids, name=None):
#Outputs: [[2., 2., 2.], [4., 5., 6.]]
"""
if paddle.__version__ >= '2.2.0' or paddle.__version__ == '0.0.0':
return paddle.incubate.segment_mean(data, segment_ids, name)

if _non_static_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN")
return out

check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")

helper = LayerHelper("segment_mean", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "MEAN"})
return out
return paddle.incubate.segment_mean(data, segment_ids, name)


def segment_min(data, segment_ids, name=None):
Expand Down Expand Up @@ -199,28 +145,7 @@ def segment_min(data, segment_ids, name=None):
#Outputs: [[1., 2., 1.], [4., 5., 6.]]
"""
if paddle.__version__ >= '2.2.0' or paddle.__version__ == '0.0.0':
return paddle.incubate.segment_min(data, segment_ids, name)

if _non_static_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN")
return out

check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")

helper = LayerHelper("segment_min", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "MIN"})
return out
return paddle.incubate.segment_min(data, segment_ids, name)


def segment_max(data, segment_ids, name=None):
Expand Down Expand Up @@ -253,28 +178,7 @@ def segment_max(data, segment_ids, name=None):
#Outputs: [[3., 2., 3.], [4., 5., 6.]]
"""
if paddle.__version__ >= '2.2.0' or paddle.__version__ == '0.0.0':
return paddle.incubate.segment_max(data, segment_ids, name)

if _non_static_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX")
return out

check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")

helper = LayerHelper("segment_max", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "MAX"})
return out
return paddle.incubate.segment_max(data, segment_ids, name)


def segment_softmax(data, segment_ids):
Expand Down

0 comments on commit fa901df

Please sign in to comment.