dlimbs_algs_multi - attempted definition for when LIMBS is a multiple of TPI (greater than TPI) #27

Open · wants to merge 2 commits into master
16 changes: 16 additions & 0 deletions .vscode/c_cpp_properties.json
@@ -0,0 +1,16 @@
{
"configurations": [
{
"name": "Win32",
"includePath": [
"${workspaceFolder}/**"
],
"defines": [
"_DEBUG",
"UNICODE",
"_UNICODE"
]
}
],
"version": 4
}
218 changes: 218 additions & 0 deletions include/cgbn/core/dispatch_dlimbs.cu
@@ -469,4 +469,222 @@ class dispatch_dlimbs_t<core, dlimbs_algs_full> {
}
};

template<class core>
class dispatch_dlimbs_t<core, dlimbs_algs_multi> {
public:
static const uint32_t TPI=core::TPI;
static const uint32_t LIMBS=core::LIMBS;
static const uint32_t DLIMBS=core::DLIMBS;
static const uint32_t LIMB_OFFSET=DLIMBS*TPI-LIMBS;

// These algorithms are used when LIMBS >= TPI. They are almost the same as the half/full size
// algorithms, with a few tweaks here and there.

__device__ __forceinline__ static void dlimbs_approximate(uint32_t approx[DLIMBS], const uint32_t denom[DLIMBS]) {
uint32_t sync=core::sync_mask(), group_thread=threadIdx.x & TPI-1;
uint32_t x, d0, d1, x0, x1, x2, est, a, h, l;
uint32_t rem=(LIMBS % TPI==0) ? 0 : TPI-(LIMBS % TPI);
// rem plays the role of TPI-LIMBS from the half/full variants. When TPI < LIMBS it is either 0
// (LIMBS a multiple of TPI) or a value between 1 and TPI-1 (LIMBS not a multiple of TPI).
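// Worked examples: TPI=32, LIMBS=64 (BITS=2048) gives rem=0; TPI=32, LIMBS=33 (BITS=1056,
// where the 32 in 1056/32 is the bit width of a limb, not TPI) gives rem=32-(33%32)=31.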
int32_t c, top;

// computes (beta^2 - 1) / denom - beta, where beta=1<<32*LIMBS
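// Illustration with a single 32-bit limb (beta=2^32): for denom=0xC0000000,
// (beta^2-1)/denom - beta = 0x55555555, i.e. the fractional part of beta/denom scaled by
// beta -- the reciprocal estimate consumed by the division loop below.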

x=0xFFFFFFFF-denom[0];

d1=__shfl_sync(sync, denom[0], TPI-1, TPI);
d0=__shfl_sync(sync, denom[0], TPI-2, TPI);

approx[0]=0;
a=uapprox(d1);

#pragma nounroll
for(int32_t thread=LIMBS-1; thread>=0; thread--) {
x0=__shfl_sync(sync, x, TPI-3, TPI);
x1=__shfl_sync(sync, x, TPI-2, TPI);
x2=__shfl_sync(sync, x, TPI-1, TPI);
est=udiv(x0, x1, x2, d0, d1, a);

l=madlo_cc(est, denom[0], 0);
h=madhic(est, denom[0], 0);

x=sub_cc(x, h);
c=subc(0, 0); // thread TPI-1 is zero

top=__shfl_sync(sync, x, TPI-1, TPI);
x=__shfl_up_sync(sync, x, 1, TPI);
c=__shfl_sync(sync, c, threadIdx.x-1, TPI);
x=(group_thread==0) ? 0xFFFFFFFF : x;

x=sub_cc(x, l);
c=subc(c, 0);

if(top+core::resolve_sub(c, x)<0) {
// means a correction is required, should be very rare
x=add_cc(x, denom[0]);
c=addc(0, 0);
core::fast_propagate_add(c, x);
est--;
}
//approx[0]=(group_thread==thread+TPI-LIMBS) ? est : approx[0];
approx[0]=(group_thread==thread+rem) ? est : approx[0];
}
}

__device__ __forceinline__ static uint32_t dlimbs_sqrt_rem_wide(uint32_t s[DLIMBS], uint32_t r[DLIMBS], const uint32_t lo[DLIMBS], const uint32_t hi[DLIMBS]) {
uint32_t sync=core::sync_mask(), group_thread=threadIdx.x & TPI-1;
uint32_t x, x0, x1, t0, t1, divisor, approx, p, q, c, low;
uint32_t rem=(LIMBS % TPI==0) ? 0 : TPI-(LIMBS % TPI);   // same role as in dlimbs_approximate

// computes s=sqrt(x), r=x-s^2, where x=(hi<<32*LIMBS) + lo
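// For illustration: x=90 gives s=9 and r=90-81=9; the word returned via x1 is the high
// part of the remainder.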

low=lo[0];
x=hi[0];
/* The original half/full guard does not carry over: TPI != LIMBS (written below as TPI ^ LIMBS)
   is always true in the multi case, and threadIdx.x-TPI+LIMBS would then exceed threadIdx.x.
if(TPI ^ LIMBS) {
  low=__shfl_sync(sync, low, threadIdx.x-TPI+LIMBS, TPI);
  x=((int32_t)group_thread>=(int32_t)(TPI-LIMBS)) ? x : low;   // use casts to silence warning
}*/
// Alternative approach: fetch the matching low word directly.
t0=__shfl_sync(sync, lo[0], threadIdx.x+LIMBS, TPI);
x=hi[0] | t0;

x0=__shfl_sync(sync, x, TPI-2, TPI);
x1=__shfl_sync(sync, x, TPI-1, TPI);

divisor=usqrt(x0, x1);
approx=uapprox(divisor);

t0=madlo_cc(divisor, divisor, 0);
t1=madhic(divisor, divisor, 0);
x0=sub_cc(x0, t0);
x1=subc(x1, t1);

x=(group_thread==TPI-1) ? low : x;
x=__shfl_sync(sync, x, threadIdx.x-1, TPI);
x=(group_thread==TPI-1) ? x0 : x;
s[0]=(group_thread==TPI-1) ? divisor+divisor : 0; // silent 1 at the top of s

#pragma nounroll
#pragma nounroll
for(int32_t index=TPI-2;index>=0;index--) {   // TPI < LIMBS here, so the half/full start of TPI-LIMBS would be negative (e.g. BITS=2048, LIMBS=64, TPI=32); count down from TPI-2 instead
x0=__shfl_sync(sync, x, TPI-1, TPI);
q=usqrt_div(x0, x1, divisor, approx);
s[0]=(group_thread==index) ? q : s[0];

p=madhi(q, s[0], 0);
x=sub_cc(x, p);
c=subc(0, 0);
core::fast_propagate_sub(c, x);

x1=__shfl_sync(sync, x, TPI-1, TPI)-q; // we subtract q because of the silent 1 at the top of s
t0=__shfl_sync(sync, low, index, TPI);
x=__shfl_up_sync(sync, x, 1, TPI);
x=(group_thread==0) ? t0 : x;

p=madlo(q, s[0], 0);
x=sub_cc(x, p);
c=subc(0, 0);
x1-=core::fast_propagate_sub(c, x);

while((int32_t)x1<0) {
x1++;
q--;

// correction step: add q and s
x=add_cc(x, (group_thread==index) ? q : 0);
c=addc(0, 0);
x=add_cc(x, s[0]);
c=addc(c, 0);

x1+=core::resolve_add(c, x);

// update s
s[0]=(group_thread==index) ? q : s[0];
}
s[0]=(group_thread==index+1) ? s[0]+(q>>31) : s[0];
s[0]=(group_thread==index) ? q+q : s[0];
}
t0=__shfl_down_sync(sync, s[0], 1, TPI);
t0=(group_thread==TPI-1) ? 1 : t0;
s[0]=uright_wrap(s[0], t0, 1);
r[0]=x;
return x1;
}

__device__ __forceinline__ static void dlimbs_div_estimate(uint32_t q[DLIMBS], const uint32_t x[DLIMBS], const uint32_t approx[DLIMBS]) {
uint32_t sync=core::sync_mask(), group_thread=threadIdx.x & TPI-1;
uint32_t t, c;
uint32_t rem=(LIMBS % TPI==0) ? 0 : TPI-(LIMBS % TPI);   // same role as in dlimbs_approximate
uint64_t w;

// computes q=(x*approx>>32*LIMBS) + x + 3
// q=min(q, (1<<32*LIMBS)-1);
//
// Notes: leaves junk in lower words of q
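// Sketch of why this works: approx was built as (beta^2-1)/denom - beta, so
// x*approx/beta + x ~= x*(approx+beta)/beta ~= x*beta/denom; the +3 covers the truncation
// error and the clamp below caps the estimate at beta-1.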

w=0;
#pragma unroll
for(int32_t index=0; index<LIMBS; index++) {
//t=__shfl_sync(sync, x[0], TPI-LIMBS+index, TPI);
t=__shfl_sync(sync, x[0], rem+index, TPI);
w=mad_wide(t, approx[0], w);
t=__shfl_sync(sync, ulow(w), threadIdx.x+1, TPI);
// t=(group_thread==TPI-1) ? 0 : t;
t=(((group_thread+1) & (TPI-1))==0) ? 0 : t;   // note the parentheses: == binds tighter than &, so the unparenthesized form masked with (TPI-1)==0 and never cleared t
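// Since TPI is a power of two, group_thread+1 is divisible by TPI exactly when its low
// log2(TPI) bits are zero; with TPI=32 the mask is 0x1F and only group_thread==31 clears t,
// matching the commented-out group_thread==TPI-1 test above.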
w=(w>>32)+t;
}

// increase the estimate by 3
//t=(group_thread==TPI-LIMBS) ? 3 : 0;
t=(group_thread == rem) ? 3 : 0;
w=w + t + x[0];

q[0]=ulow(w);
c=uhigh(w);
if(core::resolve_add(c, q[0])!=0)
q[0]=0xFFFFFFFF;
}

__device__ __forceinline__ static void dlimbs_sqrt_estimate(uint32_t q[DLIMBS], uint32_t top, const uint32_t x[DLIMBS], const uint32_t approx[DLIMBS]) {
uint32_t sync=core::sync_mask(), group_thread=threadIdx.x & TPI-1;
uint32_t t, high, low;
uint32_t rem=(LIMBS % TPI==0) ? 0 : TPI-(LIMBS % TPI);   // same role as in dlimbs_approximate
uint64_t w;

// computes:
// 1. num=((top<<32*LIMBS) + x) / 2
// 2. q=(num*approx>>32*LIMBS) + num + 4
// 3. q=min(q, (1<<32*LIMBS)-1);
//
// Note: Leaves junk in lower words of q
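// As in dlimbs_div_estimate above: num*approx/beta + num ~= num*beta/denom, so q is a
// quotient-style estimate built from num; the +4 and the clamp below absorb the bit dropped
// by the shift and the truncation of the wide multiply.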

// shift x right by 1 bit. Fill high bit with top.
t=__shfl_down_sync(sync, x[0], 1, TPI);
t=(group_thread==TPI-1) ? top : t;
low=uright_wrap(x[0], t, 1);

// If we're an exact multiple of the size, the low limb may need clearing. This was required
// in the half case (LIMBS half of TPI) and already unnecessary in the full case; it is not
// clear whether the multi case needs it, so it stays disabled for now:
// if(LIMBS % TPI==0)
//   low=(group_thread>=LIMBS) ? low : 0;

// estimate is in low
w=0;
#pragma unroll
for(int32_t index=0;index<LIMBS;index++) {
t=__shfl_sync(sync, low, rem+index, TPI);
w=mad_wide(t, approx[0], w);
t=__shfl_down_sync(sync, ulow(w), 1, TPI);
// t=(group_thread==TPI-1) ? 0 : t;
//t=((group_thread+1) % TPI == 0) ? 0 : t;
t=(((group_thread+1) & (TPI-1))==0) ? 0 : t;   // same power-of-two mask trick and precedence fix as in dlimbs_div_estimate above
w = (w>>32)+t;
}

// increase the estimate by 4 -- because we might have cleared low bit, estimate can be off by 4
t = (group_thread == rem) ? 4 : 0;
w = w + t + low;

low=ulow(w);
high=uhigh(w);
if(core::resolve_add(high, low)!=0) {
low=0xFFFFFFFF;
}

q[0]=low;
}
};

} /* namespace cgbn */
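
A minimal host-side sketch (illustration only, not part of this diff; the helper name and the
sample (TPI, LIMBS) pairs are made up for the example) that checks the rem formula used
throughout the patch:

#include <cstdio>
#include <cstdint>

// rem as computed in dlimbs_approximate, dlimbs_sqrt_rem_wide and the estimate routines:
// 0 when LIMBS is a multiple of TPI, otherwise TPI - (LIMBS % TPI).
static uint32_t rem_for(uint32_t tpi, uint32_t limbs) {
  return (limbs % tpi==0) ? 0 : tpi-(limbs % tpi);
}

int main() {
  uint32_t cases[][2]={{32, 64}, {32, 33}, {32, 32}};   // pairs taken from the comments above
  for(const auto &c : cases)
    printf("TPI=%u LIMBS=%u rem=%u\n", c[0], c[1], rem_for(c[0], c[1]));
  return 0;   // expected output: rem=0, rem=31, rem=0
}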