Skip to content

Commit

Permalink
Merge pull request #24580 from dfm:fix-ffi-test-segfault
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 691062859
  • Loading branch information
Google-ML-Automation committed Oct 29, 2024
2 parents eff6cb4 + 1785479 commit c67cf51
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tests/extend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import os
import sys
import unittest
from functools import partial

Expand All @@ -35,6 +34,7 @@
from jax._src import xla_bridge
from jax._src.interpreters import mlir
from jax._src.layout import DeviceLocalLayout
from jax._src.lib import lapack
from jax._src.lib.mlir.dialects import hlo
from jax._src.lax import linalg as lax_linalg_internal

Expand Down Expand Up @@ -261,10 +261,6 @@ def testFfiCallBatching(self, shape, vmap_method):

@jtu.run_on_devices("gpu", "cpu")
def testVectorizedDeprecation(self):
if sys.version_info.major == 3 and sys.version_info.minor == 13:
# TODO(b/376025274): Remove the skip once the bug is fixed.
raise unittest.SkipTest("Crashes on Python 3.13")

x = self.rng().randn(3, 5, 4).astype(np.float32)
with self.assertWarns(DeprecationWarning):
ffi_call_geqrf(x, vectorized=True)
Expand Down Expand Up @@ -332,6 +328,9 @@ def fun(x):


def ffi_call_geqrf(x, **kwargs):
if jtu.test_device_matches(["cpu"]):
lapack._lapack.initialize()

assert x.dtype == np.float32
ndim = x.ndim
x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2)
Expand Down

0 comments on commit c67cf51

Please sign in to comment.