-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AMP OP&Test] add fp16 test for linspace #52161
Merged
Merged
Changes from 9 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
6c5933a
add fp16 test for linspace
sljlp 6f257cc
fix
sljlp 4ee1117
update
sljlp 6fa57a5
update
sljlp 03ba1a7
updatew
sljlp 2fcd1eb
Merge branch 'add_fp16_for_linspace' of https://github.com/sljlp/Padd…
sljlp b928fa5
uodate for bf16
sljlp fc09b74
update
sljlp 952a527
update
sljlp d7b63cb
update
sljlp File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,7 @@ | |
import unittest | ||
|
||
import numpy as np | ||
from eager_op_test import OpTest, paddle_static_guard | ||
from eager_op_test import OpTest, convert_float_to_uint16, paddle_static_guard | ||
|
||
import paddle | ||
from paddle import fluid | ||
|
@@ -26,56 +26,117 @@ class TestLinspaceOpCommonCase(OpTest): | |
def setUp(self): | ||
self.op_type = "linspace" | ||
self.python_api = paddle.linspace | ||
dtype = 'float32' | ||
self._set_dtype() | ||
self._set_data() | ||
self.attrs = {'dtype': self.attr_dtype} | ||
|
||
def _set_dtype(self): | ||
self.dtype = "float32" | ||
self.attr_dtype = int(core.VarDesc.VarType.FP32) | ||
|
||
def _set_data(self): | ||
self.outputs = {'Out': np.arange(0, 11).astype(self.dtype)} | ||
self.inputs = { | ||
'Start': np.array([0]).astype(dtype), | ||
'Stop': np.array([10]).astype(dtype), | ||
'Start': np.array([0]).astype(self.dtype), | ||
'Stop': np.array([10]).astype(self.dtype), | ||
'Num': np.array([11]).astype('int32'), | ||
} | ||
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} | ||
|
||
self.outputs = {'Out': np.arange(0, 11).astype(dtype)} | ||
|
||
def test_check_output(self): | ||
self.check_output() | ||
|
||
|
||
class TestLinspaceOpReverseCase(OpTest): | ||
def setUp(self): | ||
self.op_type = "linspace" | ||
self.python_api = paddle.linspace | ||
dtype = 'float32' | ||
class TestLinspaceOpReverseCase(TestLinspaceOpCommonCase): | ||
def _set_data(self): | ||
self.inputs = { | ||
'Start': np.array([10]).astype(dtype), | ||
'Stop': np.array([0]).astype(dtype), | ||
'Start': np.array([10]).astype(self.dtype), | ||
'Stop': np.array([0]).astype(self.dtype), | ||
'Num': np.array([11]).astype('int32'), | ||
} | ||
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} | ||
|
||
self.outputs = {'Out': np.arange(10, -1, -1).astype(dtype)} | ||
self.outputs = {'Out': np.arange(10, -1, -1).astype(self.dtype)} | ||
|
||
def test_check_output(self): | ||
self.check_output() | ||
|
||
|
||
class TestLinspaceOpNumOneCase(OpTest): | ||
def setUp(self): | ||
self.op_type = "linspace" | ||
self.python_api = paddle.linspace | ||
dtype = 'float32' | ||
class TestLinspaceOpNumOneCase(TestLinspaceOpCommonCase): | ||
def _set_data(self): | ||
self.inputs = { | ||
'Start': np.array([10]).astype(dtype), | ||
'Stop': np.array([0]).astype(dtype), | ||
'Start': np.array([10]).astype(self.dtype), | ||
'Stop': np.array([0]).astype(self.dtype), | ||
'Num': np.array([1]).astype('int32'), | ||
} | ||
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} | ||
|
||
self.outputs = {'Out': np.array(10, dtype=dtype)} | ||
self.outputs = {'Out': np.array(10, dtype=self.dtype)} | ||
|
||
def test_check_output(self): | ||
self.check_output() | ||
|
||
|
||
class TestLinspaceOpCommonCaseFP16(TestLinspaceOpCommonCase): | ||
def _set_dtype(self): | ||
self.dtype = np.float16 | ||
self.attr_dtype = int(core.VarDesc.VarType.FP16) | ||
|
||
|
||
class TestLinspaceOpReverseCaseFP16(TestLinspaceOpReverseCase): | ||
def _set_dtype(self): | ||
self.dtype = np.float16 | ||
self.attr_dtype = int(core.VarDesc.VarType.FP16) | ||
|
||
|
||
class TestLinspaceOpNumOneCaseFP16(TestLinspaceOpNumOneCase): | ||
def _set_dtype(self): | ||
self.dtype = np.float16 | ||
self.attr_dtype = int(core.VarDesc.VarType.FP16) | ||
|
||
|
||
@unittest.skipIf( | ||
not core.is_compiled_with_cuda() | ||
or not core.is_bfloat16_supported(core.CUDAPlace(0)), | ||
'not supported bf16', | ||
) | ||
class TestLinspaceOpCommonCaseBF16(TestLinspaceOpCommonCaseFP16): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 直接继承的TestLinspaceOpCommonCase,里面调用了check_output,当前kernel没有注册CPU的bf16支持,所以会跑到CPU的place上,然后报错 |
||
def _set_dtype(self): | ||
self.dtype = np.uint16 | ||
self.attr_dtype = int(core.VarDesc.VarType.BF16) | ||
|
||
def _set_data(self): | ||
self.outputs = { | ||
'Out': convert_float_to_uint16(np.arange(0, 11).astype("float32")) | ||
} | ||
self.inputs = { | ||
'Start': convert_float_to_uint16(np.array([0]).astype("float32")), | ||
'Stop': convert_float_to_uint16(np.array([10]).astype("float32")), | ||
'Num': np.array([11]).astype('int32'), | ||
} | ||
|
||
|
||
class TestLinspaceOpReverseCaseBF16(TestLinspaceOpCommonCaseBF16): | ||
def _set_data(self): | ||
self.inputs = { | ||
'Start': convert_float_to_uint16(np.array([10]).astype("float32")), | ||
'Stop': convert_float_to_uint16(np.array([0]).astype("float32")), | ||
'Num': np.array([11]).astype('int32'), | ||
} | ||
self.outputs = { | ||
'Out': convert_float_to_uint16( | ||
np.arange(10, -1, -1).astype("float32") | ||
) | ||
} | ||
|
||
|
||
class TestLinspaceOpNumOneCaseBF16(TestLinspaceOpCommonCaseBF16): | ||
def _set_data(self): | ||
self.inputs = { | ||
'Start': convert_float_to_uint16(np.array([10]).astype("float32")), | ||
'Stop': convert_float_to_uint16(np.array([0]).astype("float32")), | ||
'Num': np.array([1]).astype('int32'), | ||
} | ||
self.outputs = { | ||
'Out': convert_float_to_uint16(np.array(10, dtype="float32")) | ||
} | ||
|
||
|
||
class TestLinspaceAPI(unittest.TestCase): | ||
def test_variable_input1(self): | ||
with paddle_static_guard(): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是不是多删了self.attrs,下同
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.attrs定义写在 _set_dtype 里了