MSM - supporting all window sizes #534
Conversation
Fixed precomputation tests in Rust and Go. The precomputation function interface has changed as a result.
wrappers/golang/core/msm.go
Outdated
```diff
 	AreScalarsOnDevice bool

 	/// True if scalars are in Montgomery form and false otherwise. Default value: true.
 	AreScalarsMontgomeryForm bool

-	arePointsOnDevice bool
+	ArePointsOnDevice bool

 	/// True if coordinates of points are in Montgomery form and false otherwise. Default value: true.
 	ArePointsMontgomeryForm bool

-	areResultsOnDevice bool
+	AreResultsOnDevice bool
```
These should remain private, as the XXXCheck functions should manipulate them internally based on the slice arguments. It's less error-prone for the user.
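A minimal sketch of the pattern being asked for, with hypothetical names rather than the actual icicle wrapper types: the check helper reads IsOnDevice() from the slice arguments and sets the private flags itself, so the caller never touches them.

```go
package core

// deviceAwareSlice is an illustrative stand-in for the wrapper's slice types,
// which expose IsOnDevice().
type deviceAwareSlice interface {
	Len() int
	IsOnDevice() bool
}

// msmConfig mirrors the flags under discussion, kept private so only the
// check helpers can set them.
type msmConfig struct {
	areScalarsOnDevice bool
	arePointsOnDevice  bool
	areResultsOnDevice bool
}

// msmCheck derives the device flags from the slice arguments instead of
// relying on user-set booleans, which is the less error-prone pattern
// the comment describes.
func msmCheck(scalars, points, results deviceAwareSlice, cfg *msmConfig) {
	cfg.areScalarsOnDevice = scalars.IsOnDevice()
	cfg.arePointsOnDevice = points.IsOnDevice()
	cfg.areResultsOnDevice = results.IsOnDevice()
}
```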
```go
e = msm.PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
cfg.PrecomputeFactor = precomputeFactor
cfg.PointsSize = int32(points.Len())
cfg.ArePointsOnDevice = points.IsOnDevice()
```
Similar to the previous comment, this should be checked and updated in core.PrecomputeBasesCheck, which is called in msm.PrecomputeBases. Same for cfg.PointsSize: it should be private and updated in core.PrecomputeBasesCheck.
OK, will fix. What about the Rust test? Is it important there too?
Yes, there as well. PrecomputeFactor is fine to update manually since it is public, but points_size is not public, so it shouldn't be updated manually.
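A sketch of that split on the Go side, again with hypothetical stand-in types (the real helper is core.PrecomputeBasesCheck): the public PrecomputeFactor is assigned by the caller, while the private size and device flag are derived from the points slice inside the check.

```go
package core

// hostOrDeviceSlice is an illustrative stand-in for the wrapper's slice types.
type hostOrDeviceSlice interface {
	Len() int
	IsOnDevice() bool
}

type precomputeConfig struct {
	PrecomputeFactor  int32 // public: the caller sets this directly
	pointsSize        int32 // private: set by the check helper below
	arePointsOnDevice bool  // private: set by the check helper below
}

// precomputeBasesCheck stands in for core.PrecomputeBasesCheck in this sketch:
// it updates the private fields from the slice, so tests never touch them.
func precomputeBasesCheck(points hostOrDeviceSlice, cfg *precomputeConfig) {
	cfg.pointsSize = int32(points.Len())
	cfg.arePointsOnDevice = points.IsOnDevice()
}
```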
```diff
 {
   const int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  if (tid >= num_of_threads) { return; }
+  if (tid >= nof_threads) { return; }
```
We should be consistent on formatting; see line 99.

```diff
-  if (tid >= nof_threads) { return; }
+  if (tid >= nof_threads) return;
```
```c++
0,          // bitsize
10,         // large_bucket_factor
batch_size, // batch_size
false,      // are_scalars_on_device
```
Why don't you use scalar_d if you already copied them to the device? Is that intentional? If so, there's no need to allocate scalars_d and copy.
It's just a test file; I copy to have all options available. I don't use scalar_d because in ZPrize you need to include the scalar transfer time.
```c++
@@ -726,62 +745,90 @@ namespace msm {
      NUM_BLOCKS = (nof_bms_in_batch + NUM_THREADS - 1) / NUM_THREADS;
      big_triangle_sum_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(buckets, final_results, nof_bms_in_batch, c);
    } else {
      // the recursive reduction algorithm works with 2 types of reduction that can run on parallel streams
```
Any chance you can move all this reduction logic into a 'reduction_phase()' function? Thanks.
In that case we should also have scalar_splitting_phase(), sorting_phase(), accumulation_phase(), and so on. If we want this refactoring, let's do it in a different PR.
icicle/src/msm/msm.cu
Outdated
```diff
-  bool are_bases_on_device,
-  device_context::DeviceContext& ctx,
-  A* output_bases)
+cudaError_t precompute_msm_bases(A* bases, int msm_size, MSMConfig& config, A* output_bases)
```
After discussing with the team, in order to prevent a breaking change right now, we think it's best to:
- create a second function that will call the old function after computing c
- add a deprecation comment on the old function
Actually, why do we need to call the old function? I will just give my new precompute function a different name and keep the old one as deprecated.
Just so you don't implement it twice.
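A sketch of the deprecation pattern being settled on, shown in Go with hypothetical names and signatures for brevity (the function actually under discussion is the C++ precompute_msm_bases): the old entry point stays as a deprecated thin wrapper that forwards to the new one, so nothing is implemented twice.

```go
package msm

// config is a hypothetical stand-in for MSMConfig in this sketch.
type config struct {
	PrecomputeFactor int
	C                int
}

// PrecomputeMSMBases is the new, config-driven entry point.
func PrecomputeMSMBases(bases []uint64, msmSize int, cfg *config, outputBases []uint64) error {
	// ... the actual precomputation would live here ...
	return nil
}

// Deprecated: use PrecomputeMSMBases. Kept so existing callers keep working;
// it only forwards to the new function, so the logic exists in one place.
func PrecomputeBases(bases []uint64, precomputeFactor int, outputBases []uint64) error {
	cfg := &config{PrecomputeFactor: precomputeFactor}
	return PrecomputeMSMBases(bases, len(bases), cfg, outputBases)
}
```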
This PR enables using MSM with any value of c.
Note: the default c isn't necessarily optimal; the user is expected to choose the c and precomputation factor that give the best results for their case.
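As a rough guide for that choice (standard bucket-method arithmetic, not something this PR specifies): a window size c splits a b-bit scalar into roughly ceil(b/c) windows with on the order of 2^c buckets per window, so larger c means fewer windows but much more bucket memory, and precomputation further trades memory for fewer windows to accumulate. A small sketch of the trade-off:

```go
package main

import (
	"fmt"
	"math"
)

// windows returns how many c-bit windows a bitsize-bit scalar splits into.
func windows(bitsize, c int) int { return (bitsize + c - 1) / c }

// bucketsPerWindow is the rough bucket count per window for window size c.
func bucketsPerWindow(c int) float64 { return math.Pow(2, float64(c)) }

func main() {
	const bitsize = 254 // e.g. a BN254 scalar field element
	for _, c := range []int{10, 13, 16, 20} {
		fmt.Printf("c=%2d: windows=%2d, buckets per window≈%.0f\n",
			c, windows(bitsize, c), bucketsPerWindow(c))
	}
}
```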