Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move checkConcretization for reshapes #2363

Merged
merged 9 commits into from
Jun 8, 2024

Conversation

jacobhinkle
Copy link
Collaborator

@jacobhinkle jacobhinkle commented Jun 7, 2024

The included test is the small one provided by @jjsjann123 in #2359 and it's actually tougher than the original repro. It necessitates either removing the check that concretized squeezed extents are constant 1 or to concretize Resized to broadcast extents as constant 1 so that we can evaluate max(0, min(i0, 1)) as oneVal() without calling simplifyExpr. I went with removing the check, which means in this example we have broadcast dimension with a dynamic shape like max(0, min(i0, 1)). Since we're concretizing to Broadcast, we know that dimension is not zero; if it were then we'd concretize to Iteration and SqueezeOp::checkConcretization would fail the IterType check. Still, I don't love that the expression cannot be simplified so it appears in the kernel (i9 and i10):

__global__ void nvfuser_pointwise_f0_c1_r0_g1(Tensor<float, 3, 3> T0, Tensor<float, 2, 2> T4) {
  nvfuser_index_t i0;
  i0 = ((nvfuser_index_t)threadIdx.x) + (128LL * ((nvfuser_index_t)blockIdx.x));
  Tensor<float, 3, 3> s1;
  s1.data = T0.data;
  s1.logical_size = T0.logical_size;
  s1.alloc_stride = T0.alloc_stride;
  Array<nvfuser_index_t, 3, 1> a2;
  a2 = s1.logical_size;
  nvfuser_index_t i3;
  i3 = a2[2LL];
  nvfuser_index_t i4;
  i4 = max(0LL, (min(4LL, i3)));
  nvfuser_index_t i5;
  i5 = min(i3, 4LL);
  nvfuser_index_t i6;
  i6 = max(0LL, i5);
  Array<nvfuser_index_t, 3, 1> a7;
  a7 = s1.logical_size;
  nvfuser_index_t i8;
  i8 = a7[0LL];
  nvfuser_index_t i9;
  i9 = min(i8, 1LL);
  nvfuser_index_t i10;
  i10 = max(0LL, i9);
  Array<nvfuser_index_t, 3, 1> a11;
  a11 = s1.logical_size;
  nvfuser_index_t i12;
  i12 = a11[1LL];
  nvfuser_index_t i13;
  i13 = (max(0LL, (min(2LL, i12)))) * i4;
  nvfuser_index_t i14;
  i14 = i0 % i13;
  nvfuser_index_t i15;
  i15 = min(i12, 2LL);
  nvfuser_index_t i16;
  i16 = max(0LL, i15);
  if ((i0 < i13)) {
    float T1[1LL];
    T1[0LL] = 0LL;
    T1[0LL]
       = T0[((((i3 * i12) * (i0 / i13)) + (i3 * (i14 / i4))) + (i14 % i4))];
    float T5[1LL];
    T5[0LL]
       = T1[0LL];
    T4[i0]
       = T5[0LL];
  }
}

If you look closely though, i10 is not used so it will be DCEd anyway. Still, it might be nice to concretize broadcast extents to 1 just to clean up these expressions if they appear downstream. I tried that hastily but ran into some issues so I'll leave it for another PR.

Fixes #2359

This test is actually tougher than the big repro on #2359. It
necessitate either removing the check altogether or the path I took
which is to concretize Resized to broadcast extents as constant 1 so
that we can evaluate max(0, min(i0, 1)) as 1 without calling
simplifyExpr. A less invasive solution would be to remove the extent
check in `SqueezeOp::checkConcretization`.

We could also just remove `Expr::checkConcretization` and
`checkConcretizedUses`. They are only used for SqueezeOp currently and
are not adding much value anyway probably.
@jacobhinkle
Copy link
Collaborator Author

!build --diff

@jacobhinkle jacobhinkle marked this pull request as ready for review June 7, 2024 11:34
@jacobhinkle
Copy link
Collaborator Author

Confirmed this fixes the thunder benchmark bug. This runs:

cd /opt/pytorch/lightning-thunder
git fetch origin
git switch wjy/sharded
pytest thunder/benchmarks/targets.py -k test_nanogpt_block_grad[thunder] -s

We actually _can_ squeeze expanded dimensions, which is how we compute
reductions of expanded dims.
Comment on lines -1412 to -1413
NVF_CHECK(
!new_id->hasExpandedExtent(), "Can not squeeze expanded dimension(s).");
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is outdated as of #1679 which allowed squeezing expanded dimensions.

@jacobhinkle
Copy link
Collaborator Author

!build

Copy link
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@jjsjann123
Copy link
Collaborator

Looks like CI is down?! cc'ing @xwang233

@xwang233
Copy link
Collaborator

xwang233 commented Jun 7, 2024

Looks like CI is down?! cc'ing @xwang233

Seems like a machine issue. I restarted the failing jobs.

@jacobhinkle
Copy link
Collaborator Author

!build

@wujingyue wujingyue merged commit b7e3694 into main Jun 8, 2024
35 of 37 checks passed
@wujingyue wujingyue deleted the squeeze_reshape_check_concretization branch June 8, 2024 21:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Squeezed IterDomain ?S536{1} must concretize to IterType::Broadcast but found ?S536{1}.
5 participants