Skip to content

Commit

Permalink
[TPU[Mosaic] Fix missing sfences in smem DMAs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 724545200
  • Loading branch information
Google-ML-Automation committed Feb 8, 2025
1 parent 2890357 commit 5633ed6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
17 changes: 10 additions & 7 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ class LoweringRuleContext:
avals_in: Sequence[jax_core.AbstractValue]
avals_out: Sequence[jax_core.AbstractValue]
block_shapes: Sequence[tuple[int | pallas_core.Mapped, ...] | None]

replace = dataclasses.replace

@property
Expand Down Expand Up @@ -3362,22 +3361,26 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree,
device_id = _device_id_to_logical(ctx, device_id, device_id_type)
tpu.enqueue_dma(src_ref, dst_ref, sem, source_semaphore=src_sem,
device_id=device_id)

return []
lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule


def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
device_id_type: tpu_primitives.DeviceIdType):
del device_id_type
(_, _, ref, transforms, sem, sem_transforms, _, _, _) = tree_util.tree_unflatten(
tree, args)
(_, _, ref_aval, _, sem_aval, _, _, _, _) = tree_util.tree_unflatten(
tree, ctx.avals_in)
(src, src_transforms, dst, transforms, sem, sem_transforms, _, _, _) = (
tree_util.tree_unflatten(tree, args)
)
(src_aval, _, dst_aval, _, sem_aval, _, _, _, _) = tree_util.tree_unflatten(
tree, ctx.avals_in
)
block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
ref_block_shape = block_shapes[2]
ref, _ = _transform_ref(ref, ref_aval.dtype, ref_block_shape, transforms)
src, _ = _transform_ref(src, src_aval.dtype, src_aval.shape, src_transforms)
dst, _ = _transform_ref(dst, dst_aval.dtype, ref_block_shape, transforms)
sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms)
tpu.wait_dma(sem, ref)
tpu.wait_dma(sem, src, dst)
return []
lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule

Expand Down
3 changes: 2 additions & 1 deletion jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,8 @@ def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> {
def TPU_WaitDMAOp : TPU_Op<"wait_dma"> {
let arguments = (ins
MemRefOf<[TPU_DMASemaphoreType]>:$semaphore,
AnyMemRef:$ref
AnyMemRef:$src,
AnyMemRef:$dst
);
let hasVerifier = 1;
}
Expand Down

0 comments on commit 5633ed6

Please sign in to comment.