Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix segment ops #445

Merged
merged 2 commits into from
Aug 3, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 17 additions & 113 deletions pgl/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import paddle
from paddle import _C_ops
from paddle.common_ops_import import LayerHelper
from paddle.common_ops_import import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
from paddle.common_ops_import import _non_static_mode
from paddle.common_ops_import import check_variable_and_dtype
from pgl.utils.op import get_index_from_counts

Expand All @@ -35,30 +35,18 @@ def segment_pool(data, segment_ids, pool_type, name=None):
Segment Operator.
"""
pool_type = pool_type.upper()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

segment_pool 改成 if else,调用其他segment op

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if in_dygraph_mode():
out, tmp = _C_ops.final_state_segment_pool(data, segment_ids,
'pooltype', pool_type)
return out
if _in_legacy_dygraph():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype',
pool_type)
return out

check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")

helper = LayerHelper("segment_pool", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
pool_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": pool_ids},
attrs={"pooltype": pool_type})
return out
if pool_type == "SUM":
return paddle.incubate.segment_sum(data, segment_ids, name)
elif pool_type == "MEAN":
return paddle.incubate.segment_mean(data, segment_ids, name)
elif pool_type == "MAX":
return paddle.incubate.segment_max(data, segment_ids, name)
elif pool_type == "MIN":
return paddle.incubate.segment_min(data, segment_ids, name)
else:
raise ValueError(
"We only support sum, mean, max, min pool types in segment_pool function."
)


def segment_sum(data, segment_ids, name=None):
Expand Down Expand Up @@ -91,28 +79,7 @@ def segment_sum(data, segment_ids, name=None):

"""

if paddle.__version__ >= '2.2.0' or paddle.__version__ == '0.0.0':
return paddle.incubate.segment_sum(data, segment_ids, name)

if _non_static_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM")
return out

check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")

helper = LayerHelper("segment_sum", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "SUM"})
return out
return paddle.incubate.segment_sum(data, segment_ids, name)


def segment_mean(data, segment_ids, name=None):
Expand Down Expand Up @@ -146,28 +113,7 @@ def segment_mean(data, segment_ids, name=None):
#Outputs: [[2., 2., 2.], [4., 5., 6.]]

"""
if paddle.__version__ >= '2.2.0' or paddle.__version__ == '0.0.0':
return paddle.incubate.segment_mean(data, segment_ids, name)

if _non_static_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN")
return out

check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")

helper = LayerHelper("segment_mean", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "MEAN"})
return out
return paddle.incubate.segment_mean(data, segment_ids, name)


def segment_min(data, segment_ids, name=None):
Expand Down Expand Up @@ -199,28 +145,7 @@ def segment_min(data, segment_ids, name=None):
#Outputs: [[1., 2., 1.], [4., 5., 6.]]

"""
if paddle.__version__ >= '2.2.0' or paddle.__version__ == '0.0.0':
return paddle.incubate.segment_min(data, segment_ids, name)

if _non_static_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN")
return out

check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")

helper = LayerHelper("segment_min", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "MIN"})
return out
return paddle.incubate.segment_min(data, segment_ids, name)


def segment_max(data, segment_ids, name=None):
Expand Down Expand Up @@ -253,28 +178,7 @@ def segment_max(data, segment_ids, name=None):
#Outputs: [[3., 2., 3.], [4., 5., 6.]]

"""
if paddle.__version__ >= '2.2.0' or paddle.__version__ == '0.0.0':
return paddle.incubate.segment_max(data, segment_ids, name)

if _non_static_mode():
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX")
return out

check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")

helper = LayerHelper("segment_max", **locals())
out = helper.create_variable_for_type_inference(dtype=data.dtype)
summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
helper.append_op(
type="segment_pool",
inputs={"X": data,
"SegmentIds": segment_ids},
outputs={"Out": out,
"SummedIds": summed_ids},
attrs={"pooltype": "MAX"})
return out
return paddle.incubate.segment_max(data, segment_ids, name)


def segment_softmax(data, segment_ids):
Expand Down