Skip to content

Commit

Permalink
Update convilution
Browse files Browse the repository at this point in the history
  • Loading branch information
kzrnm committed Apr 4, 2021
1 parent e8882ed commit 2c99feb
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 58 deletions.
2 changes: 2 additions & 0 deletions Source/AtCoderLibrary/Math/Internal/InternalMath.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;

namespace AtCoder.Internal
{
Expand Down Expand Up @@ -116,6 +117,7 @@ public static (long, long) InvGCD(long a, long b)
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static long SafeMod(long x, long m)
{
x %= m;
Expand Down
115 changes: 57 additions & 58 deletions Source/AtCoderLibrary/Math/MathLib.Convolution.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Diagnostics;
using AtCoder.Internal;

namespace AtCoder
Expand All @@ -18,10 +19,7 @@ public static partial class MathLib
/// </remarks>
public static StaticModInt<TMod>[] Convolution<TMod>(StaticModInt<TMod>[] a, StaticModInt<TMod>[] b)
where TMod : struct, IStaticMod
{
var temp = Convolution((ReadOnlySpan<StaticModInt<TMod>>)a, b);
return temp.ToArray();
}
=> Convolution((ReadOnlySpan<StaticModInt<TMod>>)a, b);

/// <summary>
/// 畳み込みを mod <typeparamref name="TMod"/> で計算します。
Expand All @@ -34,52 +32,61 @@ public static StaticModInt<TMod>[] Convolution<TMod>(StaticModInt<TMod>[] a, Sta
/// <para>- 2^c | (<typeparamref name="TMod"/> - 1) かつ |<paramref name="a"/>| + |<paramref name="b"/>| - 1 ≤ 2^c なる c が存在する</para>
/// <para>計算量: O((|<paramref name="a"/>|+|<paramref name="b"/>|)log(|<paramref name="a"/>|+|<paramref name="b"/>|) + log<typeparamref name="TMod"/>)</para>
/// </remarks>
public static Span<StaticModInt<TMod>> Convolution<TMod>(ReadOnlySpan<StaticModInt<TMod>> a, ReadOnlySpan<StaticModInt<TMod>> b)
public static StaticModInt<TMod>[] Convolution<TMod>(ReadOnlySpan<StaticModInt<TMod>> a, ReadOnlySpan<StaticModInt<TMod>> b)
where TMod : struct, IStaticMod
{
var n = a.Length;
var m = b.Length;
if (n == 0 || m == 0)
{
return Array.Empty<StaticModInt<TMod>>();
}

if (Math.Min(n, m) <= 60)
{
return ConvolutionNaive(a, b);
}

int z = 1 << InternalBit.CeilPow2(n + m - 1);

var aTemp = new StaticModInt<TMod>[z];
a.CopyTo(aTemp);

var bTemp = new StaticModInt<TMod>[z];
b.CopyTo(bTemp);

return Convolution(aTemp.AsSpan(), bTemp.AsSpan(), n, m, z);
return ConvolutionFFT(a.ToArray(), b.ToArray());
}

private static Span<StaticModInt<TMod>> Convolution<TMod>(Span<StaticModInt<TMod>> a, Span<StaticModInt<TMod>> b, int n, int m, int z)
private static StaticModInt<TMod>[] ConvolutionFFT<TMod>(StaticModInt<TMod>[] a, StaticModInt<TMod>[] b)
where TMod : struct, IStaticMod
{
int n = a.Length, m = b.Length;
int z = 1 << InternalBit.CeilPow2(n + m - 1);
Array.Resize(ref a, z);
Butterfly<TMod>.Calculate(a);
Array.Resize(ref b, z);
Butterfly<TMod>.Calculate(b);

for (int i = 0; i < a.Length; i++)
{
a[i] *= b[i];
}

Butterfly<TMod>.CalculateInv(a);
var result = a.Slice(0, n + m - 1);
Array.Resize(ref a, n + m - 1);
var iz = new StaticModInt<TMod>(z).Inv();
foreach (ref var r in result)

for (int i = 0; i < a.Length; i++)
a[i] *= iz;

return a;
}
private static StaticModInt<TMod>[] ConvolutionNaive<TMod>(ReadOnlySpan<StaticModInt<TMod>> a, ReadOnlySpan<StaticModInt<TMod>> b)
where TMod : struct, IStaticMod
{
if (a.Length < b.Length)
{
r *= iz;
// ref 構造体のため型引数として使えない
var temp = a;
a = b;
b = temp;
}

return result;
var ans = new StaticModInt<TMod>[a.Length + b.Length - 1];
for (int i = 0; i < a.Length; i++)
{
for (int j = 0; j < b.Length; j++)
{
ans[i + j] += a[i] * b[j];
}
}

return ans;
}

/// <summary>
Expand All @@ -104,26 +111,30 @@ public static long[] ConvolutionLong(ReadOnlySpan<long> a, ReadOnlySpan<long> b)
return Array.Empty<long>();
}

const ulong Mod1 = 754974721;
const ulong Mod2 = 167772161;
const ulong Mod3 = 469762049;
const ulong Mod1 = 754974721; // 2^24
const ulong Mod2 = 167772161; // 2^25
const ulong Mod3 = 469762049; // 2^26
const ulong M2M3 = Mod2 * Mod3;
const ulong M1M3 = Mod1 * Mod3;
const ulong M1M2 = Mod1 * Mod2;
// (m1 * m2 * m3) % 2^64
const ulong M1M2M3 = Mod1 * Mod2 * Mod3;

ulong i1 = (ulong)InternalMath.InvGCD((long)M2M3, (long)Mod1).Item2;
ulong i2 = (ulong)InternalMath.InvGCD((long)M1M3, (long)Mod2).Item2;
ulong i3 = (ulong)InternalMath.InvGCD((long)M1M2, (long)Mod3).Item2;
const ulong i1 = 190329765;
const ulong i2 = 58587104;
const ulong i3 = 187290749;

Debug.Assert(i1 == (ulong)InternalMath.InvGCD((long)M2M3, (long)Mod1).Item2);
Debug.Assert(i2 == (ulong)InternalMath.InvGCD((long)M1M3, (long)Mod2).Item2);
Debug.Assert(i3 == (ulong)InternalMath.InvGCD((long)M1M2, (long)Mod3).Item2);

var c1 = Convolution<FFTMod1>(a, b);
var c2 = Convolution<FFTMod2>(a, b);
var c3 = Convolution<FFTMod3>(a, b);

var c = new long[n + m - 1];

Span<ulong> offset = stackalloc ulong[] { 0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3 };
//ReadOnlySpan<ulong> offset = stackalloc ulong[] { 0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3 };

for (int i = 0; i < c.Length; i++)
{
Expand All @@ -145,37 +156,25 @@ public static long[] ConvolutionLong(ReadOnlySpan<long> a, ReadOnlySpan<long> b)
// x - 3M' + (0 or 2B or 4B or 6B)
// のいずれかが成り立つ、らしい
// -> see atcoder/convolution.hpp
x -= offset[(int)(diff % offset.Length)];
switch (diff % 5)
{
case 2:
x -= M1M2M3;
break;
case 3:
x -= 2 * M1M2M3;
break;
case 4:
x -= 3 * M1M2M3;
break;
}
c[i] = (long)x;
}

return c;
}
}

private static StaticModInt<TMod>[] ConvolutionNaive<TMod>(ReadOnlySpan<StaticModInt<TMod>> a, ReadOnlySpan<StaticModInt<TMod>> b)
where TMod : struct, IStaticMod
{
if (a.Length < b.Length)
{
// ref 構造体のため型引数として使えない
var temp = a;
a = b;
b = temp;
}

var ans = new StaticModInt<TMod>[a.Length + b.Length - 1];
for (int i = 0; i < a.Length; i++)
{
for (int j = 0; j < b.Length; j++)
{
ans[i + j] += a[i] * b[j];
}
}

return ans;
}

private readonly struct FFTMod1 : IStaticMod
{
public uint Mod => 754974721;
Expand Down

0 comments on commit 2c99feb

Please sign in to comment.