-
Notifications
You must be signed in to change notification settings - Fork 608
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added support for complex numbers in python bindings #12872
Conversation
runtime/bindings/python/hal.cc
Outdated
case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64: | ||
dtype_code = "complex64"; | ||
break; | ||
case IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128: | ||
dtype_code = "complex128"; | ||
break; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you have a source for these codes? The reference linked above the switch statement (https://docs.python.org/3/c-api/arg.html#numbers) only lists
D (complex) [Py_complex *]
Convert a C Py_complex structure to a Python complex number.
I see a few other pages like https://numpy.org/doc/stable/user/basics.types.html#array-types-and-conversions-between-types but they don't quite look like what goes here (generally single character codes).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the py::dtype(dtype_code); to get the return value can either accept a code matching to a datatype, or you can just give the name of the datatype. I couldn't find a code for complex64, so I just gave the full name, and I did the same for complex128 for consistency
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, https://numpy.org/doc/stable/reference/arrays.dtypes.html
for key in numpy.sctypeDict.keys():
if "complex" in str(key):
print(key)
# complex64
# complex128
# complex256
# complex_
# singlecomplex
# longcomplex
# complex
dt = numpy.dtype("complex64")
print(dt.type)
print(dt.kind)
print(dt.char)
print(dt.num)
dt = numpy.dtype("complex128")
print(dt.type)
print(dt.kind)
print(dt.char)
print(dt.num)
# <class 'numpy.complex64'>
# c
# F
# 14
# <class 'numpy.complex128'>
# c
# D
# 15
Can you add that link and change the TODO text a bit near the top of this switch statement?
// See: https://docs.python.org/3/c-api/arg.html#numbers
// TODO: Handle dtypes that do not map to a code (i.e. fp16).
I'm also wondering what sort of test coverage we could add for this (probably some tests that use complex numbers in programs, run via the python APIs?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll make those changes and work on some test cases
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ahh, good timing with this, I'm tracking down an issue on Windows where that function was buggy (due to how platforms implement integer types). I'm putting together a PR that makes some related changes (like adding that reference link and swapping "i" -> "int32")
Fixes #11080. The int64 and uint64 test cases here were failing on Windows as the element type mapping was routing via the code `l`, which is a "C long int" - not an explicitly 64 bit type. This changes the mapping to always use the explicit "type strings" (any string in `numpy.sctypeDict.keys()`, [shown in this gist](https://gist.github.com/ScottTodd/ec1f7906e9c644eb47f74280d6c26229)). Relates to #12872
This will need to be rebased now that #12880 is merged. |
ed5093b
to
2a68742
Compare
I rebased and added some test cases, does this look good to merge? |
Yep, thanks! I triggered the CI. Looks like yapf wants you to remove trailing spaces: https://github.com/openxla/iree/actions/runs/4610947536/jobs/8150064584?pr=12872#step:6:15 (you can also configure your editor to trim tailing spaces on save) |
2a68742
to
4e88d30
Compare
Fixed, and it looks like all the tests passed |
Fixes #11080. The int64 and uint64 test cases here were failing on Windows as the element type mapping was routing via the code `l`, which is a "C long int" - not an explicitly 64 bit type. This changes the mapping to always use the explicit "type strings" (any string in `numpy.sctypeDict.keys()`, [shown in this gist](https://gist.github.com/ScottTodd/ec1f7906e9c644eb47f74280d6c26229)). Relates to #12872
added support for complex numbers in python bindings Co-authored-by: Elias Joseph <elias@nod-labs.com>
Fixes iree-org#11080. The int64 and uint64 test cases here were failing on Windows as the element type mapping was routing via the code `l`, which is a "C long int" - not an explicitly 64 bit type. This changes the mapping to always use the explicit "type strings" (any string in `numpy.sctypeDict.keys()`, [shown in this gist](https://gist.github.com/ScottTodd/ec1f7906e9c644eb47f74280d6c26229)). Relates to iree-org#12872
added support for complex numbers in python bindings Co-authored-by: Elias Joseph <elias@nod-labs.com>
added support for complex numbers in python bindings