Skip to content

Commit

Permalink
CI: fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Aug 7, 2024
1 parent de02988 commit edbcd00
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions jax/_src/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,10 +919,10 @@ def _ensure_4d(t):
if dims_to_add > 0:
return jnp.expand_dims(t, axis=tuple(range(dims_to_add)))
return t
query = _ensure_4d(query)
key = _ensure_4d(key)
value = _ensure_4d(value)

query_arr = _ensure_4d(query)
key_arr = _ensure_4d(key)
value_arr = _ensure_4d(value)
bias = _ensure_4d(bias) if bias is not None else None
mask = _ensure_4d(mask) if mask is not None else None

Expand All @@ -933,15 +933,15 @@ def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None:
if shape[i] != -1 and t.shape[i] != shape[i]:
raise ValueError(f"{name} shape should be {shape}: but got {t.shape}")

B, S, K, H = key.shape
_check_has_shape(value, [B, S, K, H], 'value')
_check_has_shape(query, [B, -1, -1, H], 'query')
if query.shape[-2] % K != 0:
B, S, K, H = key_arr.shape
_check_has_shape(value_arr, [B, S, K, H], 'value')
_check_has_shape(query_arr, [B, -1, -1, H], 'query')
if query_arr.shape[-2] % K != 0:
raise ValueError(f"The number of query heads must to a multiple of "
f"key/value heads, but got {query.shape[-2]} vs {K}")
if not (query.dtype == key.dtype == value.dtype):
f"key/value heads, but got {query_arr.shape[-2]} vs {K}")
if not (query_arr.dtype == key_arr.dtype == value_arr.dtype):
raise ValueError(f"query/key/value should have the same shape, but got "
f"{query.shape} vs {key.shape} vs {value.shape}.")
f"{query_arr.shape} vs {key_arr.shape} vs {value_arr.shape}.")
if mask is not None and mask.dtype != jnp.bool_ and mask.ndim != 4:
raise ValueError(f"Mask must be a 4D boolean tensor, but got "
f"rank={mask.ndim}, dtype={mask.dtype}.")
Expand All @@ -953,18 +953,18 @@ def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None:
match implementation:
case 'xla':
out = _dot_product_attention_xla(
query, key, value, bias, mask, is_causal=is_causal, scale=scale_val,
query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, scale=scale_val,
)
case 'cudnn':
mask_type = MaskType.CAUSAL if is_causal else MaskType.NO_MASK
out = cudnn_dot_product_attention(
query, key, value, bias, mask, scale=scale_val, mask_type=mask_type
query_arr, key_arr, value_arr, bias, mask, scale=scale_val, mask_type=mask_type
)
case None:
# TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select
# best backend.
out = _dot_product_attention_xla(
query, key, value, bias, mask, is_causal=is_causal, scale=scale_val,
query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, scale=scale_val,
)
case _:
raise ValueError(f"Unsupported implementation option: {implementation}")
Expand Down

0 comments on commit edbcd00

Please sign in to comment.