[SPARK-45179][PYTHON] Increase Numpy minimum version to 1.21
### What changes were proposed in this pull request?
Increase the minimum supported NumPy version to 1.21.

### Why are the changes needed?

- According to the [release history](https://pypi.org/project/numpy/#history), NumPy 1.15 was released about 5 years ago, while the last maintenance release in the 1.21 line came out about a year ago.
- With 1.21 as the minimum version, we can discard all NumPy version checks in PySpark (see the sketch after this list).
- The minimum supported pandas version, `pandas==1.4.4`, itself already requires `numpy>=1.21.0`.
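
For illustration, this is the shape of the guard being discarded (a minimal sketch of the recurring pattern, not a verbatim call site from the patch):

```python
from distutils.version import LooseVersion  # the deprecated module these checks rely on

import numpy as np

# Before: uses of numpy.typing had to be version-guarded, since the module
# only became available around NumPy 1.20/1.21.
if LooseVersion(np.__version__) >= LooseVersion("1.21"):
    import numpy.typing as ntp  # noqa: F401

# After: with numpy>=1.21 enforced at install time, the import is unconditional.
import numpy.typing as ntp  # noqa: F401
```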

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Searched for remaining NumPy version checks with `ag`; the four hits below are exactly the guards this patch removes:

```
(spark_dev_310) ➜  spark git:(master) ag --py 'numpy\.__version' python
(spark_dev_310) ➜  spark git:(master)
(spark_dev_310) ➜  spark git:(master) ag --py 'np\.__version' python
python/pyspark/ml/image.py
231:        if LooseVersion(np.__version__) >= LooseVersion("1.9"):

python/pyspark/pandas/typedef/typehints.py
152:    if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):

python/pyspark/pandas/tests/test_typedef.py
365:            if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):

python/pyspark/pandas/tests/computation/test_apply_func.py
257:        if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):
```

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #42944 from zhengruifeng/bump_min_np_ver.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
zhengruifeng committed Sep 15, 2023
1 parent 0d1f43c commit 82ed1c9
Showing 6 changed files with 11 additions and 21 deletions.
2 changes: 1 addition & 1 deletion python/docs/source/getting_started/install.rst

```diff
@@ -158,7 +158,7 @@ Package Supported version Note
 `py4j` >=0.10.9.7 Required
 `pandas` >=1.4.4 Required for pandas API on Spark and Spark Connect; Optional for Spark SQL
 `pyarrow` >=4.0.0 Required for pandas API on Spark and Spark Connect; Optional for Spark SQL
-`numpy` >=1.15 Required for pandas API on Spark and MLLib DataFrame-based API; Optional for Spark SQL
+`numpy` >=1.21 Required for pandas API on Spark and MLLib DataFrame-based API; Optional for Spark SQL
 `grpcio` >=1.48,<1.57 Required for Spark Connect
 `grpcio-status` >=1.48,<1.57 Required for Spark Connect
 `googleapis-common-protos` ==1.56.4 Required for Spark Connect
```
10 changes: 1 addition & 9 deletions python/pyspark/ml/image.py

```diff
@@ -28,7 +28,6 @@
 from typing import Any, Dict, List, NoReturn, Optional, cast
 
 import numpy as np
-from distutils.version import LooseVersion
 
 from pyspark import SparkContext
 from pyspark.sql.types import Row, StructType, _create_row, _parse_datatype_json_string
@@ -225,14 +224,7 @@ def toImage(self, array: np.ndarray, origin: str = "") -> Row:
         else:
             raise ValueError("Invalid number of channels")
 
-        # Running `bytearray(numpy.array([1]))` fails in specific Python versions
-        # with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3.
-        # Here, it avoids it by converting it to bytes.
-        if LooseVersion(np.__version__) >= LooseVersion("1.9"):
-            data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
-        else:
-            # Numpy prior to 1.9 don't have `tobytes` method.
-            data = bytearray(array.astype(dtype=np.uint8).ravel())
+        data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
 
         # Creating new Row with _create_row(), because Row(name = value, ... )
         # orders fields by name, which conflicts with expected schema order
```
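
The removed branch targeted NumPy < 1.9, far below even the old 1.15 floor, so the `tobytes()` path was already the only one ever taken. A minimal standalone sketch of the surviving path (the array shape is illustrative, not taken from the patch):

```python
import numpy as np

# height x width x channels, as toImage expects; the values are arbitrary.
array = np.zeros((4, 3, 3), dtype=np.float64)

# The now-unconditional conversion: cast to uint8, flatten, copy to bytes.
data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
assert len(data) == array.size  # one byte per element
```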
3 changes: 1 addition & 2 deletions python/pyspark/pandas/tests/computation/test_apply_func.py

```diff
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 from datetime import datetime
-from distutils.version import LooseVersion
 import sys
 import unittest
 from typing import List
@@ -254,7 +253,7 @@ def identify3(x) -> ps.DataFrame[float, [int, List[int]]]:
         self.assert_eq(actual, pdf)
 
         # For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+
-        if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):
+        if sys.version_info >= (3, 8):
             import numpy.typing as ntp
 
             psdf = ps.from_pandas(pdf)
```
3 changes: 1 addition & 2 deletions python/pyspark/pandas/tests/test_typedef.py

```diff
@@ -19,7 +19,6 @@
 import unittest
 import datetime
 import decimal
-from distutils.version import LooseVersion
 from typing import List
 
 import pandas
@@ -362,7 +361,7 @@ def test_as_spark_type_pandas_on_spark_dtype(self):
         )
 
         # For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+
-        if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):
+        if sys.version_info >= (3, 8):
             import numpy.typing as ntp
 
             self.assertEqual(
```
3 changes: 1 addition & 2 deletions python/pyspark/pandas/typedef/typehints.py

```diff
@@ -23,7 +23,6 @@
 import sys
 import typing
 from collections.abc import Iterable
-from distutils.version import LooseVersion
 from inspect import isclass
 from typing import Any, Callable, Generic, List, Tuple, Union, Type, get_type_hints
 
@@ -149,7 +148,7 @@ def as_spark_type(
     - Python3's typing system
     """
     # For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+
-    if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):
+    if sys.version_info >= (3, 8):
         if (
             hasattr(tpe, "__origin__")
             and tpe.__origin__ is np.ndarray  # type: ignore[union-attr]
```
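
With the guard reduced to a Python-version check, the `numpy.typing` branch in `as_spark_type` is reachable whenever PySpark runs at all, since Python 3.8+ is already required. A rough sketch of what that branch matches; the printed results are my reading of the surrounding code, not captured output:

```python
import numpy as np
import numpy.typing as ntp

from pyspark.pandas.typedef.typehints import as_spark_type

# ntp.NDArray[np.int64].__origin__ is np.ndarray, which is exactly what the
# `tpe.__origin__ is np.ndarray` test in as_spark_type looks for.
print(as_spark_type(np.int64))               # expected: LongType()
print(as_spark_type(ntp.NDArray[np.int64]))  # expected: ArrayType(LongType(), True)
```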
11 changes: 6 additions & 5 deletions python/setup.py

```diff
@@ -131,6 +131,7 @@ def _supports_symlinks():
 # binary format protocol with the Java version, see ARROW_HOME/format/* for specifications.
 # Also don't forget to update python/docs/source/getting_started/install.rst.
 _minimum_pandas_version = "1.4.4"
+_minimum_numpy_version = "1.21"
 _minimum_pyarrow_version = "4.0.0"
 _minimum_grpc_version = "1.56.0"
 _minimum_googleapis_common_protos_version = "1.56.4"
@@ -307,25 +308,25 @@ def run(self):
         # if you're updating the versions or dependencies.
         install_requires=["py4j==0.10.9.7"],
         extras_require={
-            "ml": ["numpy>=1.15"],
-            "mllib": ["numpy>=1.15"],
+            "ml": ["numpy>=%s" % _minimum_numpy_version],
+            "mllib": ["numpy>=%s" % _minimum_numpy_version],
             "sql": [
                 "pandas>=%s" % _minimum_pandas_version,
                 "pyarrow>=%s" % _minimum_pyarrow_version,
-                "numpy>=1.15",
+                "numpy>=%s" % _minimum_numpy_version,
             ],
             "pandas_on_spark": [
                 "pandas>=%s" % _minimum_pandas_version,
                 "pyarrow>=%s" % _minimum_pyarrow_version,
-                "numpy>=1.15",
+                "numpy>=%s" % _minimum_numpy_version,
             ],
             "connect": [
                 "pandas>=%s" % _minimum_pandas_version,
                 "pyarrow>=%s" % _minimum_pyarrow_version,
                 "grpcio>=%s" % _minimum_grpc_version,
                 "grpcio-status>=%s" % _minimum_grpc_version,
                 "googleapis-common-protos>=%s" % _minimum_googleapis_common_protos_version,
-                "numpy>=1.15",
+                "numpy>=%s" % _minimum_numpy_version,
             ],
         },
         python_requires=">=3.8",
```
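
One practical consequence of routing every extra through `_minimum_numpy_version`: installs such as `pip install "pyspark[sql]"` or `pip install "pyspark[connect]"` all resolve to `numpy>=1.21`, and the NumPy floor can no longer drift between extras the next time it is bumped.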
