diff --git a/LLama.Examples/Examples/BatchedExecutorGuidance.cs b/LLama.Examples/Examples/BatchedExecutorGuidance.cs index 29f6aa33c..6f3eceaba 100644 --- a/LLama.Examples/Examples/BatchedExecutorGuidance.cs +++ b/LLama.Examples/Examples/BatchedExecutorGuidance.cs @@ -79,7 +79,7 @@ await AnsiConsole guidance.Prompt(g); // Early exit if we reach the natural end of the guided sentence - if (g == model.EndOfSentenceToken) + if (g == model.Tokens.EOS) break; // Update progress bar diff --git a/LLama.Examples/Examples/GetEmbeddings.cs b/LLama.Examples/Examples/GetEmbeddings.cs index 9a816b054..1e10ba22b 100644 --- a/LLama.Examples/Examples/GetEmbeddings.cs +++ b/LLama.Examples/Examples/GetEmbeddings.cs @@ -9,7 +9,7 @@ public static void Run() string modelPath = UserSettings.GetModelPath(); Console.ForegroundColor = ConsoleColor.DarkGray; - var @params = new ModelParams(modelPath) { EmbeddingMode = true }; + var @params = new ModelParams(modelPath) { Embeddings = true }; using var weights = LLamaWeights.LoadFromFile(@params); var embedder = new LLamaEmbedder(weights, @params); diff --git a/LLama.Examples/Examples/SemanticKernelMemory.cs b/LLama.Examples/Examples/SemanticKernelMemory.cs index 1c9471d86..46c9a17d9 100644 --- a/LLama.Examples/Examples/SemanticKernelMemory.cs +++ b/LLama.Examples/Examples/SemanticKernelMemory.cs @@ -20,7 +20,7 @@ public static async Task Run() var parameters = new ModelParams(modelPath) { Seed = seed, - EmbeddingMode = true + Embeddings = true }; using var model = LLamaWeights.LoadFromFile(parameters); diff --git a/LLama.KernelMemory/BuilderExtensions.cs b/LLama.KernelMemory/BuilderExtensions.cs index 474cf8be4..07770244b 100644 --- a/LLama.KernelMemory/BuilderExtensions.cs +++ b/LLama.KernelMemory/BuilderExtensions.cs @@ -84,7 +84,7 @@ public static IKernelMemoryBuilder WithLLamaSharpDefaults(this IKernelMemoryBuil ContextSize = config?.ContextSize ?? 2048, Seed = config?.Seed ?? 0, GpuLayerCount = config?.GpuLayerCount ?? 20, - EmbeddingMode = true, + Embeddings = true, MainGpu = config?.MainGpu ?? 0, SplitMode = config?.SplitMode ?? 
GPUSplitMode.None, }; diff --git a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs index 4a089fe49..d8c366bcf 100644 --- a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs +++ b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs @@ -29,7 +29,7 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config) this._config = config; var @params = new ModelParams(_config.ModelPath) { - EmbeddingMode = true, + Embeddings = true, MainGpu = _config.MainGpu, SplitMode = _config.SplitMode }; @@ -49,7 +49,7 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights we this._config = config; var @params = new ModelParams(_config.ModelPath) { - EmbeddingMode = true, + Embeddings = true, MainGpu = _config.MainGpu, SplitMode = _config.SplitMode }; diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs index b8350336a..7c897b781 100644 --- a/LLama.Unittest/BasicTest.cs +++ b/LLama.Unittest/BasicTest.cs @@ -15,7 +15,7 @@ public sealed class BasicTest public BasicTest(ITestOutputHelper testOutputHelper) { _testOutputHelper = testOutputHelper; - _params = new ModelParams(Constants.ModelPath) + _params = new ModelParams(Constants.GenerativeModelPath) { ContextSize = 2048 }; diff --git a/LLama.Unittest/BeamTests.cs b/LLama.Unittest/BeamTests.cs index 83eb87d35..f4aa01abe 100644 --- a/LLama.Unittest/BeamTests.cs +++ b/LLama.Unittest/BeamTests.cs @@ -15,7 +15,7 @@ public sealed class BeamTests public BeamTests(ITestOutputHelper testOutputHelper) { _testOutputHelper = testOutputHelper; - _params = new ModelParams(Constants.ModelPath) + _params = new ModelParams(Constants.GenerativeModelPath) { ContextSize = 2048 }; diff --git a/LLama.Unittest/Constants.cs b/LLama.Unittest/Constants.cs index 6e6324491..6e5e92c54 100644 --- a/LLama.Unittest/Constants.cs +++ b/LLama.Unittest/Constants.cs @@ -2,9 +2,11 @@ { internal static class Constants { - public static string ModelPath = "Models/llama-2-7b-chat.Q3_K_S.gguf"; - public static string LLavaModelPath = "Models/llava-v1.6-mistral-7b.Q3_K_XS.gguf"; - public static string LLavaMmpPath = "Models/mmproj-model-f16.gguf"; - public static string LLavaImage = "Models/extreme-ironing-taxi-610x427.jpg"; + public static readonly string GenerativeModelPath = "Models/llama-2-7b-chat.Q3_K_S.gguf"; + public static readonly string EmbeddingModelPath = "Models/all-MiniLM-L12-v2.Q8_0.gguf"; + + public static readonly string LLavaModelPath = "Models/llava-v1.6-mistral-7b.Q3_K_XS.gguf"; + public static readonly string LLavaMmpPath = "Models/mmproj-model-f16.gguf"; + public static readonly string LLavaImage = "Models/extreme-ironing-taxi-610x427.jpg"; } } diff --git a/LLama.Unittest/GrammarTest.cs b/LLama.Unittest/GrammarTest.cs index c5fcf0ed3..1ab9dea61 100644 --- a/LLama.Unittest/GrammarTest.cs +++ b/LLama.Unittest/GrammarTest.cs @@ -12,7 +12,7 @@ public sealed class GrammarTest public GrammarTest() { - _params = new ModelParams(Constants.ModelPath) + _params = new ModelParams(Constants.GenerativeModelPath) { ContextSize = 2048, Seed = 92, diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj index d48b308e0..e7ef46231 100644 --- a/LLama.Unittest/LLama.Unittest.csproj +++ b/LLama.Unittest/LLama.Unittest.csproj @@ -31,6 +31,9 @@ + + + @@ -43,6 +46,9 @@ + + PreserveNewest + PreserveNewest diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs index ab27d9887..fe247c6ed 100644 --- 
a/LLama.Unittest/LLamaContextTests.cs +++ b/LLama.Unittest/LLamaContextTests.cs @@ -11,7 +11,7 @@ public sealed class LLamaContextTests public LLamaContextTests() { - var @params = new ModelParams(Constants.ModelPath) + var @params = new ModelParams(Constants.GenerativeModelPath) { ContextSize = 768, }; diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs index 9935fc863..379b8dc6e 100644 --- a/LLama.Unittest/LLamaEmbedderTests.cs +++ b/LLama.Unittest/LLamaEmbedderTests.cs @@ -1,5 +1,7 @@ using LLama.Common; +using LLama.Native; using Xunit.Abstractions; +using Xunit.Sdk; namespace LLama.Unittest; @@ -12,11 +14,11 @@ public sealed class LLamaEmbedderTests public LLamaEmbedderTests(ITestOutputHelper testOutputHelper) { _testOutputHelper = testOutputHelper; - var @params = new ModelParams(Constants.ModelPath) + var @params = new ModelParams(Constants.EmbeddingModelPath) { ContextSize = 4096, Threads = 5, - EmbeddingMode = true, + Embeddings = true, }; using var weights = LLamaWeights.LoadFromFile(@params); _embedder = new(weights, @params); @@ -38,8 +40,13 @@ private static float Dot(float[] a, float[] b) public async Task EmbedCompare() { var cat = await _embedder.GetEmbeddings("The cat is cute"); + Assert.DoesNotContain(float.NaN, cat); + var kitten = await _embedder.GetEmbeddings("The kitten is kawaii"); + Assert.DoesNotContain(float.NaN, kitten); + var spoon = await _embedder.GetEmbeddings("The spoon is not real"); + Assert.DoesNotContain(float.NaN, spoon); _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]"); _testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]"); @@ -47,6 +54,11 @@ public async Task EmbedCompare() var close = 1 - Dot(cat, kitten); var far = 1 - Dot(cat, spoon); + + _testOutputHelper.WriteLine(""); + _testOutputHelper.WriteLine($"Cat.Kitten (Close): {close:F4}"); + _testOutputHelper.WriteLine($"Cat.Spoon (Far): {far:F4}"); + Assert.True(close < far); } } \ No newline at end of file diff --git a/LLama.Unittest/LLavaWeightsTests.cs b/LLama.Unittest/LLavaWeightsTests.cs index 9ca0a7246..e5df30732 100644 --- a/LLama.Unittest/LLavaWeightsTests.cs +++ b/LLama.Unittest/LLavaWeightsTests.cs @@ -14,7 +14,7 @@ public sealed class LLavaWeightTests public LLavaWeightTests() { - var @params = new ModelParams(Constants.ModelPath) + var @params = new ModelParams(Constants.GenerativeModelPath) { // Llava models requires big context ContextSize = 4096 diff --git a/LLama.Unittest/MemoryDisposalTests.cs b/LLama.Unittest/MemoryDisposalTests.cs index 4ee976a85..e29ad46d5 100644 --- a/LLama.Unittest/MemoryDisposalTests.cs +++ b/LLama.Unittest/MemoryDisposalTests.cs @@ -7,7 +7,7 @@ public class MemoryDisposalTests [Fact] public void ModelDisposal() { - var @params = new ModelParams(Constants.ModelPath) + var @params = new ModelParams(Constants.GenerativeModelPath) { ContextSize = 2048 }; @@ -21,7 +21,7 @@ public void ModelDisposal() [Fact] public void ContextDisposal() { - var @params = new ModelParams(Constants.ModelPath) + var @params = new ModelParams(Constants.GenerativeModelPath) { ContextSize = 2048 }; diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs index cfe499734..3ca8a76e2 100644 --- a/LLama.Unittest/StatelessExecutorTest.cs +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -15,7 +15,7 @@ public class StatelessExecutorTest public StatelessExecutorTest(ITestOutputHelper testOutputHelper) { _testOutputHelper 
= testOutputHelper; - _params = new ModelParams(Constants.ModelPath) + _params = new ModelParams(Constants.GenerativeModelPath) { ContextSize = 60, Seed = 1754, diff --git a/LLama.Unittest/StreamingTextDecoderTests.cs b/LLama.Unittest/StreamingTextDecoderTests.cs index 680ca076f..7291b5d07 100644 --- a/LLama.Unittest/StreamingTextDecoderTests.cs +++ b/LLama.Unittest/StreamingTextDecoderTests.cs @@ -14,7 +14,7 @@ public class StreamingTextDecoderTests public StreamingTextDecoderTests(ITestOutputHelper testOutputHelper) { _testOutputHelper = testOutputHelper; - _params = new ModelParams(Constants.ModelPath); + _params = new ModelParams(Constants.GenerativeModelPath); _model = LLamaWeights.LoadFromFile(_params); } diff --git a/LLama.Unittest/TokenTests.cs b/LLama.Unittest/TokenTests.cs index e39df5f47..c11e3ae96 100644 --- a/LLama.Unittest/TokenTests.cs +++ b/LLama.Unittest/TokenTests.cs @@ -12,7 +12,7 @@ public sealed class TokenTests public TokenTests() { - _params = new ModelParams(Constants.ModelPath) + _params = new ModelParams(Constants.GenerativeModelPath) { ContextSize = 2048 }; diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index f1a14a90b..fb30af400 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -29,9 +29,13 @@ public class ModelOptions /// public int GpuLayerCount { get; set; } = 20; + public uint SeqMax { get; } + /// public uint Seed { get; set; } = 1686349486; + public bool Embeddings { get; } + /// public bool UseMemorymap { get; set; } = true; @@ -57,7 +61,7 @@ public class ModelOptions public uint BatchSize { get; set; } = 512; /// - public bool EmbeddingMode { get; set; } = false; + public uint UBatchSize { get; set; } = 512; /// public TensorSplitsCollection TensorSplits { get; set; } = new(); @@ -108,6 +112,6 @@ public class ModelOptions public float DefragThreshold { get; set; } /// - public bool DoPooling { get; set; } + public LLamaPoolingType PoolingType { get; set; } } } \ No newline at end of file diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs index d8592a081..f56417169 100644 --- a/LLama/Abstractions/IContextParams.cs +++ b/LLama/Abstractions/IContextParams.cs @@ -14,20 +14,29 @@ public interface IContextParams uint? ContextSize { get; } /// - /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch) + /// maximum batch size that can be submitted at once (must be >=32 to use BLAS) (n_batch) /// uint BatchSize { get; } + /// + /// Physical batch size + /// + uint UBatchSize { get; } + + /// + /// max number of sequences (i.e. distinct states for recurrent models) + /// + uint SeqMax { get; } + /// /// Seed for the random number generator (seed) /// uint Seed { get; } /// - /// Whether to use embedding mode. (embedding) Note that if this is set to true, - /// The LLamaModel won't produce text response anymore. + /// If true, extract embeddings (together with logits). 
/// - bool EmbeddingMode { get; } + bool Embeddings { get; } /// /// RoPE base frequency (null to fetch from the model) @@ -105,7 +114,7 @@ public interface IContextParams float DefragThreshold { get; } /// - /// Whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) + /// How to pool (sum) embedding results by sequence id (ignored if no pooling layer) /// - bool DoPooling { get; } + LLamaPoolingType PoolingType { get; } } \ No newline at end of file diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index 413e35f89..d1b5989b3 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -24,6 +24,9 @@ public record ModelParams /// public int GpuLayerCount { get; set; } = 20; + /// + public uint SeqMax { get; set; } = 1; + /// public uint Seed { get; set; } = 0xFFFFFFFF; @@ -52,7 +55,10 @@ public record ModelParams public uint BatchSize { get; set; } = 512; /// - public bool EmbeddingMode { get; set; } + public uint UBatchSize { get; set; } = 512; + + /// + public bool Embeddings { get; set; } /// public TensorSplitsCollection TensorSplits { get; set; } = new(); @@ -97,7 +103,7 @@ public record ModelParams public float DefragThreshold { get; set; } /// - public bool DoPooling { get; set; } + public LLamaPoolingType PoolingType { get; set; } = LLamaPoolingType.Unspecified; /// public bool VocabOnly { get; set; } diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs index 53cae6e97..fa1a36dd9 100644 --- a/LLama/Extensions/IContextParamsExtensions.cs +++ b/LLama/Extensions/IContextParamsExtensions.cs @@ -20,11 +20,14 @@ public static class IContextParamsExtensions /// public static void ToLlamaContextParams(this IContextParams @params, out LLamaContextParams result) { - result = NativeApi.llama_context_default_params(); + result = LLamaContextParams.Default(); + result.n_ctx = @params.ContextSize ?? 0; result.n_batch = @params.BatchSize; + result.n_ubatch = @params.UBatchSize; + result.n_seq_max = @params.SeqMax; result.seed = @params.Seed; - result.embedding = @params.EmbeddingMode; + result.embeddings = @params.Embeddings; result.rope_freq_base = @params.RopeFrequencyBase ?? 0; result.rope_freq_scale = @params.RopeFrequencyScale ?? 0; @@ -41,10 +44,13 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo result.cb_eval = IntPtr.Zero; result.cb_eval_user_data = IntPtr.Zero; + result.abort_callback = IntPtr.Zero; + result.abort_callback_user_data = IntPtr.Zero; + result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16; result.type_k = @params.TypeV ?? GGMLType.GGML_TYPE_F16; result.offload_kqv = !@params.NoKqvOffload; - result.do_pooling = @params.DoPooling; + result.llama_pooling_type = @params.PoolingType; result.n_threads = Threads(@params.Threads); result.n_threads_batch = Threads(@params.BatchThreads); diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs index 69b9e288b..c7daa1351 100644 --- a/LLama/Extensions/IModelParamsExtensions.cs +++ b/LLama/Extensions/IModelParamsExtensions.cs @@ -28,7 +28,8 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam var disposer = new GroupDisposable(); - result = NativeApi.llama_model_default_params(); + result = LLamaModelParams.Default(); + result.main_gpu = @params.MainGpu; result.split_mode = @params.SplitMode; result.n_gpu_layers = @params.GpuLayerCount < 0 ? 
int.MaxValue : @params.GpuLayerCount; diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 9517965e6..e398982fe 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -152,6 +152,7 @@ public string DeTokenize(IReadOnlyList tokens) return decoder.Read(); } + #region state load/save /// /// Save the state to specified path. /// @@ -163,7 +164,7 @@ public void SaveState(string filename) File.Delete(filename); // Estimate size of state to write to disk, this is always equal to or greater than the actual size - var estimatedStateSize = (long)NativeApi.llama_get_state_size(NativeHandle); + var estimatedStateSize = checked((long)NativeHandle.GetStateSize()); // Map the file and write the bytes directly to it. This saves copying the bytes into a C# array long writtenBytes; @@ -174,8 +175,53 @@ public void SaveState(string filename) { byte* ptr = null; view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr); - writtenBytes = (long)NativeApi.llama_copy_state_data(NativeHandle, ptr); - view.SafeMemoryMappedViewHandle.ReleasePointer(); + try + { + writtenBytes = (long)NativeHandle.GetState(ptr, (ulong)estimatedStateSize); + } + finally + { + view.SafeMemoryMappedViewHandle.ReleasePointer(); + } + } + } + + // Truncate the file to the actual size of data that was written + using (var fileStream = new FileStream(filename, FileMode.Open)) + fileStream.SetLength(writtenBytes); + } + + /// + /// Save the state of a particular sequence to specified path. + /// + /// + /// + public void SaveState(string filename, LLamaSeqId sequence) + { + // Delete that file before overwriting it + if (File.Exists(filename)) + File.Delete(filename); + + // Estimate size of state to write to disk, this is always equal to or greater than the actual size + var estimatedStateSize = checked((long)NativeHandle.GetStateSize(sequence)); + + // Map the file and write the bytes directly to it. This saves copying the bytes into a C# array + long writtenBytes; + using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, estimatedStateSize)) + using (var view = file.CreateViewAccessor(0, estimatedStateSize)) + { + unsafe + { + byte* ptr = null; + view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr); + try + { + writtenBytes = (long)NativeHandle.GetState(ptr, (ulong)estimatedStateSize, sequence); + } + finally + { + view.SafeMemoryMappedViewHandle.ReleasePointer(); + } } } @@ -187,7 +233,7 @@ public void SaveState(string filename) /// /// Get the state data as an opaque handle, which can be loaded later using /// - /// Use if you intend to save this state to disk. + /// Use if you intend to save this state to disk. /// public State GetState() { @@ -198,7 +244,11 @@ public State GetState() try { // Copy the state data into memory, discover the actual size required - var actualSize = NativeHandle.GetState(memory, stateSize); + ulong actualSize; + unsafe + { + actualSize = NativeHandle.GetState((byte*)memory, stateSize); + } // Shrink to size memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize); @@ -218,11 +268,48 @@ public State GetState() } } + /// + /// Get the state data as an opaque handle, which can be loaded later using + /// + /// Use if you intend to save this state to disk. 
+ /// + public SequenceState GetState(LLamaSeqId sequence) + { + var stateSize = NativeHandle.GetStateSize(sequence); + + // Allocate a chunk of memory large enough to hold the entire state + var memory = Marshal.AllocHGlobal((nint)stateSize); + try + { + // Copy the state data into memory, discover the actual size required + ulong actualSize; + unsafe + { + actualSize = NativeHandle.GetState((byte*)memory, stateSize, sequence); + } + + // Shrink to size + memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize); + + // Wrap memory in a "state" + var state = new SequenceState(memory, actualSize); + + // Set memory to zero, to prevent it being freed in finally block + memory = IntPtr.Zero; + + return state; + } + finally + { + if (memory != IntPtr.Zero) + Marshal.FreeHGlobal(memory); + } + } + /// /// Load the state from specified path. /// /// - /// public void LoadState(string filename) { // Map state file into memory and pass that pointer directly to `llama_set_state_data` to load from @@ -233,8 +320,41 @@ public void LoadState(string filename) { byte* ptr = null; view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr); - NativeApi.llama_set_state_data(NativeHandle, ptr); - view.SafeMemoryMappedViewHandle.ReleasePointer(); + try + { + NativeHandle.SetState(ptr); + } + finally + { + view.SafeMemoryMappedViewHandle.ReleasePointer(); + } + } + } + } + + /// + /// Load the state from specified path into a particular sequence + /// + /// + /// + public void LoadState(string filename, LLamaSeqId sequence) + { + // Map state file into memory and pass that pointer directly to `llama_set_state_data` to load from + using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Open, null)) + using (var view = file.CreateViewAccessor()) + { + unsafe + { + byte* ptr = null; + view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr); + try + { + NativeHandle.SetState(ptr, sequence); + } + finally + { + view.SafeMemoryMappedViewHandle.ReleasePointer(); + } } } } @@ -248,10 +368,25 @@ public void LoadState(State state) { unsafe { - NativeHandle.SetState((byte*)state.DangerousGetHandle().ToPointer()); + NativeHandle.SetState((byte*)state.DangerousGetHandle()); } } + /// + /// Load the state from memory into a particular sequence + /// + /// + /// + /// + public void LoadState(SequenceState state, LLamaSeqId sequence) + { + unsafe + { + NativeHandle.SetState((byte*)state.DangerousGetHandle(), sequence); + } + } + #endregion + /// /// Sample a single token from this context, using the given sampling pipeline /// @@ -357,8 +492,8 @@ public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable la } // Save the newline logit value - var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle); - var nl_logit = logits[(int)nl_token]; + var nl_token = NativeHandle.ModelHandle.Tokens.Newline; + var nl_logit = logits[(int?)nl_token ?? 
0]; // Convert logits into token candidates var candidates_p = LLamaTokenDataArray.Create(logits); @@ -371,7 +506,7 @@ public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable la candidates_p.RepetitionPenalty(NativeHandle, last_n_array, repeatPenalty, alphaFrequency, alphaPresence); // Restore newline token logit value if necessary - if (!penalizeNL) + if (!penalizeNL && nl_token.HasValue) { var candidatesSpan = candidates_p.data.Span; for (var i = 0; i < candidates_p.data.Length; i++) @@ -417,12 +552,16 @@ public void Dispose() } /// - /// The state of this model, which can be reloaded later + /// The state of this context, which can be reloaded later /// public class State : SafeLLamaHandleBase { - private ulong _size; + private readonly ulong _size; + /// + /// Get the size in bytes of this state object + /// + public ulong Size => _size; internal State(IntPtr memory, ulong size) : base(memory, true) @@ -441,6 +580,7 @@ protected override bool ReleaseHandle() /// Convert this state to a byte array /// /// + [Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")] public byte[] ToByteArray() { var bytes = new byte[_size]; @@ -453,6 +593,7 @@ public byte[] ToByteArray() /// /// /// + [Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")] public static State FromByteArray(byte[] bytes) { var memory = Marshal.AllocHGlobal(bytes.Length); @@ -460,5 +601,49 @@ public static State FromByteArray(byte[] bytes) return new State(memory, (ulong)bytes.Length); } } + + /// + /// The state of a single sequence, which can be reloaded later + /// + public class SequenceState + : SafeLLamaHandleBase + { + private readonly ulong _size; + /// + /// Get the size in bytes of this state object + /// + public ulong Size => _size; + + internal SequenceState(IntPtr memory, ulong size) + : base(memory, true) + { + _size = size; + } + + /// + protected override bool ReleaseHandle() + { + Marshal.FreeHGlobal(handle); + return true; + } + + /// + /// Copy bytes to a desintation pointer. + /// + /// Destination to write to + /// Length of the destination buffer + /// Offset from start of src to start copying from + /// Number of bytes written to destination + public unsafe ulong CopyTo(byte* dst, ulong length, ulong offset = 0) + { + var copy = Math.Min(length, _size - offset); + + var src = (byte*)DangerousGetHandle(); + src += offset; + + Buffer.MemoryCopy(src, dst, length, copy); + return copy; + } + } } } diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index 86ecceb51..13a3e1c27 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -32,7 +32,7 @@ public sealed class LLamaEmbedder /// public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? 
logger = null) { - if (!@params.EmbeddingMode) + if (!@params.Embeddings) throw new ArgumentException("EmbeddingMode must be true", nameof(@params)); Context = weights.CreateContext(@params, logger); @@ -75,7 +75,7 @@ public async Task GetEmbeddings(string text, bool addBos, CancellationT n_eval = batchSize; batch.Clear(); - batch.AddRange(tokens.AsSpan(i, n_eval), n_past, LLamaSeqId.Zero, false); + batch.AddRange(tokens.AsSpan(i, n_eval), n_past, LLamaSeqId.Zero, true); n_past += n_eval; var returnCode = await Context.DecodeAsync(batch, cancellationToken); @@ -97,9 +97,10 @@ public async Task GetEmbeddings(string text, bool addBos, CancellationT private float[] GetEmbeddingsArray() { - var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle); - if (embeddings == null) + var embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero); + if (embeddings.Length == 0) return Array.Empty(); + return embeddings.ToArray(); } @@ -111,6 +112,9 @@ private static void Normalize(Span embeddings) lengthSqr += value * value; var length = (float)Math.Sqrt(lengthSqr); + if (length <= float.Epsilon) + return; + // Normalize for (var i = 0; i < embeddings.Length; i++) embeddings[i] /= length; diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index c721726e8..c00ead4f3 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -135,7 +135,7 @@ public StatefulExecutorBase WithSessionFile(string filename) { _logger?.LogInformation($"[LLamaExecutor] Attempting to load saved session from {filename}"); var session_tokens = new LLamaToken[Context.ContextSize]; - if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, out var n_token_count_out)) + if (!NativeApi.llama_state_load_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, out var n_token_count_out)) { _logger?.LogError($"[LLamaExecutor] Failed to load session file {filename}"); throw new RuntimeError($"Failed to load session file {_pathSession}"); @@ -183,7 +183,7 @@ public StatefulExecutorBase WithSessionFile(string filename) public void SaveSessionFile(string filename) { var session_token_array = _session_tokens.ToArray(); - NativeApi.llama_save_session_file(Context.NativeHandle, filename, session_token_array, (ulong)session_token_array.Length); + NativeApi.llama_state_save_file(Context.NativeHandle, filename, session_token_array, (ulong)session_token_array.Length); } /// diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 99d45e5a5..c3a9a420e 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -163,7 +163,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args) } } - if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle)) + if (_embeds.Count > 0 && _embeds.Last() == Context.NativeHandle.ModelHandle.Tokens.EOS) { args.WaitForInput = true; } diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 055a5f13d..5acf4bd3e 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -21,7 +21,6 @@ namespace LLama public class InteractiveExecutor : StatefulExecutorBase { private bool _is_prompt_run = true; - private readonly LLamaToken _llama_token_newline; // LLava private int _EmbedImagePosition = -1; @@ -36,13 +35,11 @@ public class InteractiveExecutor : StatefulExecutorBase public 
InteractiveExecutor(LLamaContext context, ILogger? logger = null) : base(context, logger) { - _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle); } public InteractiveExecutor(LLamaContext context, LLavaWeights clipModel, ILogger? logger = null) : base(context, clipModel, logger) { - _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle); } /// @@ -210,7 +207,7 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru return (true, Array.Empty()); } - if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle)) + if (_embeds.Count > 0 && _embeds.Last() == Context.NativeHandle.ModelHandle.Tokens.EOS) { return (true, new[] { " [end of text]\n" }); } @@ -308,9 +305,9 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta _last_n_tokens.Enqueue(id); - if (id == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle)) + if (id == Context.NativeHandle.ModelHandle.Tokens.EOS) { - id = _llama_token_newline; + id = Context.NativeHandle.ModelHandle.Tokens.Newline!.Value; if (args.Antiprompts is not null && args.Antiprompts.Count > 0) { var first_antiprompt = Context.Tokenize(args.Antiprompts[0], false); diff --git a/LLama/LLamaQuantizer.cs b/LLama/LLamaQuantizer.cs index 5410f55f9..4e541e920 100644 --- a/LLama/LLamaQuantizer.cs +++ b/LLama/LLamaQuantizer.cs @@ -20,7 +20,8 @@ public static class LLamaQuantizer /// /// Whether the quantization is successful. /// - public static bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1, bool allowRequantize = true, bool quantizeOutputTensor = false) + public static bool Quantize( + string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1, bool allowRequantize = true, bool quantizeOutputTensor = false) { if (!ValidateFtype(ftype)) { @@ -28,11 +29,13 @@ public static bool Quantize(string srcFileName, string dstFilename, LLamaFtype f $"to perform quantization."); } - var quantizeParams = NativeApi.llama_model_quantize_default_params(); + var quantizeParams = LLamaModelQuantizeParams.Default(); quantizeParams.ftype = ftype; quantizeParams.nthread = nthread; quantizeParams.allow_requantize = allowRequantize; quantizeParams.quantize_output_tensor = quantizeOutputTensor; + //todo: fill in other quantize params fields. 
+ unsafe { return NativeApi.llama_model_quantize(srcFileName, dstFilename, &quantizeParams) == 0; @@ -59,7 +62,7 @@ public static bool Quantize(string srcFileName, string dstFilename, string ftype private static bool ValidateFtype(LLamaFtype ftype) { // Validation copies from here: - // https://github.com/ggerganov/llama.cpp/blob/3ab8b3a92ede46df88bc5a2dfca3777de4a2b2b6/llama.cpp#L10965 + // https://github.com/ggerganov/llama.cpp/blob/f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7/llama.cpp#L13450 switch (ftype) { @@ -95,6 +98,7 @@ private static bool ValidateFtype(LLamaFtype ftype) case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_XXS: case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ1_S: + case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ1_M: case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ4_NL: case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ4_XS: diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index f9d6ca5b2..487fe2935 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -124,7 +124,7 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams } // Check if this is the EOS token - if (id == _weights.EndOfSentenceToken) + if (id == _weights.Tokens.EOS) break; // Decode this token into text diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 0adb19875..2d8ea4d9b 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -41,24 +41,14 @@ public sealed class LLamaWeights public ulong ParameterCount => NativeHandle.ParameterCount; /// - /// Get the newline token for this model - /// - public LLamaToken NewlineToken => NativeApi.llama_token_nl(NativeHandle); - - /// - /// Get the "end of sentence" token for this model - /// - public LLamaToken EndOfSentenceToken => NativeApi.llama_token_eos(NativeHandle); - - /// - /// Get the "beginning of sentence" token for this model + /// Dimension of embedding vectors /// - public LLamaToken BeginningOfSentenceToken => NativeApi.llama_token_bos(NativeHandle); + public int EmbeddingSize => NativeHandle.EmbeddingSize; /// - /// Dimension of embedding vectors + /// Get the special tokens of this model /// - public int EmbeddingSize => NativeHandle.EmbeddingSize; + public SafeLlamaModelHandle.ModelTokens Tokens => NativeHandle.Tokens; /// /// All metadata keys in this model diff --git a/LLama/Native/LLamaBeamView.cs b/LLama/Native/LLamaBeamView.cs index e832eb620..dcd583ba3 100644 --- a/LLama/Native/LLamaBeamView.cs +++ b/LLama/Native/LLamaBeamView.cs @@ -10,7 +10,7 @@ namespace LLama.Native; public struct LLamaBeamView { private unsafe LLamaToken* tokens; - private nint n_tokens; + private nuint n_tokens; /// /// Cumulative beam probability (renormalized relative to all beams) diff --git a/LLama/Native/LLamaBeamsState.cs b/LLama/Native/LLamaBeamsState.cs index f78c45b97..cb214aef3 100644 --- a/LLama/Native/LLamaBeamsState.cs +++ b/LLama/Native/LLamaBeamsState.cs @@ -19,7 +19,7 @@ public struct LLamaBeamsState /// /// Number of elements in beam_views /// - private nint n_beams; + private nuint n_beams; /// /// Current max length of prefix tokens shared by all beams. 
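The hunks above rename `EmbeddingMode` to `Embeddings`, replace the `DoPooling` flag with a `PoolingType` enum, and move the per-model special tokens (`EndOfSentenceToken`, `NewlineToken`) behind `LLamaWeights.Tokens`. A minimal sketch of what calling code looks like after these renames, based only on the API surface shown in this diff; the model path is the unit-test model from `LLama.Unittest/Constants.cs` and the variable names are illustrative, not part of the patch:

```csharp
using LLama;
using LLama.Common;
using LLama.Native;

// `Embeddings` replaces the old `EmbeddingMode` flag; `PoolingType` replaces `DoPooling`.
var parameters = new ModelParams("Models/llama-2-7b-chat.Q3_K_S.gguf")
{
    ContextSize = 2048,
    Embeddings = false,                         // was: EmbeddingMode = false
    PoolingType = LLamaPoolingType.Unspecified, // was: DoPooling = false
};

using var weights = LLamaWeights.LoadFromFile(parameters);

// Special tokens now live on the Tokens accessor instead of dedicated properties.
var eos = weights.Tokens.EOS;         // was: weights.EndOfSentenceToken
var newline = weights.Tokens.Newline; // was: weights.NewlineToken; may be null for some vocabularies
```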
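The `LLamaContext` changes earlier in this diff add per-sequence counterparts to the whole-context state methods: `SaveState`, `LoadState` and `GetState` overloads that take a `LLamaSeqId`, backed by the new `llama_state_seq_*` natives. A hedged sketch of how those overloads might be used, assuming `CreateContext` accepts the same `ModelParams` with an optional logger; the file name is invented for illustration:

```csharp
using LLama;
using LLama.Common;
using LLama.Native;

var parameters = new ModelParams("Models/llama-2-7b-chat.Q3_K_S.gguf") { ContextSize = 2048 };
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

// Persist a single sequence rather than the whole context state...
context.SaveState("sequence0_state.bin", LLamaSeqId.Zero);

// ...and restore that file back into a sequence later.
context.LoadState("sequence0_state.bin", LLamaSeqId.Zero);

// The in-memory variant returns a SequenceState handle instead of touching disk.
using var state = context.GetState(LLamaSeqId.Zero);
context.LoadState(state, LLamaSeqId.Zero);
```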
diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs index c7e4a3287..8e3d7f74f 100644 --- a/LLama/Native/LLamaContextParams.cs +++ b/LLama/Native/LLamaContextParams.cs @@ -28,10 +28,20 @@ public struct LLamaContextParams public uint n_ctx; /// - /// prompt processing batch size + /// logical maximum batch size that can be submitted to llama_decode /// public uint n_batch; + /// + /// physical maximum batch size + /// + public uint n_ubatch; + + /// + /// max number of sequences (i.e. distinct states for recurrent models) + /// + public uint n_seq_max; + /// /// number of threads to use for generation /// @@ -45,7 +55,12 @@ public struct LLamaContextParams /// /// RoPE scaling type, from `enum llama_rope_scaling_type` /// - public RopeScalingType rope_scaling_type; + public RopeScalingType rope_scaling_type; + + /// + /// whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) + /// + public LLamaPoolingType llama_pooling_type; /// /// RoPE base frequency, 0 = from model @@ -87,11 +102,13 @@ public struct LLamaContextParams /// public float defrag_threshold; + //todo: implement cb_eval callback support /// /// ggml_backend_sched_eval_callback /// public IntPtr cb_eval; + //todo: implement cb_eval callback support /// /// User data passed into cb_eval /// @@ -113,14 +130,14 @@ public struct LLamaContextParams private sbyte _logits_all; /// - /// embedding mode only + /// if true, extract embeddings (together with logits) /// - public bool embedding + public bool embeddings { - readonly get => Convert.ToBoolean(_embedding); - set => _embedding = Convert.ToSByte(value); + readonly get => Convert.ToBoolean(_embeddings); + set => _embeddings = Convert.ToSByte(value); } - private sbyte _embedding; + private sbyte _embeddings; /// /// whether to offload the KQV ops (including the KV cache) to GPU @@ -132,15 +149,29 @@ public bool offload_kqv } private sbyte _offload_kqv; + //todo: implement abort callback support + /// + /// ggml_abort_callback + /// + public IntPtr abort_callback; + + //todo: implement abort callback support /// - /// Whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) + /// User data passed into abort_callback /// - public bool do_pooling + public IntPtr abort_callback_user_data; + + /// + /// Get the default LLamaContextParams + /// + /// + public static LLamaContextParams Default() { - readonly get => Convert.ToBoolean(_do_pooling); - set => _do_pooling = Convert.ToSByte(value); + return llama_context_default_params(); + + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + static extern LLamaContextParams llama_context_default_params(); } - private sbyte _do_pooling; } } diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs index e943c6ff9..ae2702db2 100644 --- a/LLama/Native/LLamaFtype.cs +++ b/LLama/Native/LLamaFtype.cs @@ -166,6 +166,11 @@ public enum LLamaFtype /// LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, + /// + /// except 1d tensors + /// + LLAMA_FTYPE_MOSTLY_IQ1_M = 31, + /// /// File type was not specified /// diff --git a/LLama/Native/LLamaKvCacheView.cs b/LLama/Native/LLamaKvCacheView.cs index 2b4087720..86169c602 100644 --- a/LLama/Native/LLamaKvCacheView.cs +++ b/LLama/Native/LLamaKvCacheView.cs @@ -28,7 +28,7 @@ public unsafe struct LLamaKvCacheView // Maximum number of sequences that can exist in a cell. 
It's not an error // if there are more sequences in a cell than this value, however they will // not be visible in the view cells_sequences. - int n_max_seq; + int n_seq_max; // Number of tokens in the cache. For example, if there are two populated // cells, the first with 1 sequence id in it and the second with 2 sequence @@ -48,7 +48,7 @@ public unsafe struct LLamaKvCacheView // Information for an individual cell. LLamaKvCacheViewCell* cells; - // The sequences for each cell. There will be n_max_seq items per cell. + // The sequences for each cell. There will be n_seq_max items per cell. LLamaSeqId* cells_sequences; } @@ -118,10 +118,10 @@ public static partial class NativeApi /// Create an empty KV cache view. (use only for debugging purposes) /// /// - /// + /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern LLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_max_seq); + public static extern LLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max); /// /// Free a KV cache view. (use only for debugging purposes) diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs index aff2c514a..923b042c5 100644 --- a/LLama/Native/LLamaModelParams.cs +++ b/LLama/Native/LLamaModelParams.cs @@ -1,84 +1,96 @@ -using System; -using System.Runtime.InteropServices; - -namespace LLama.Native -{ - /// - /// A C# representation of the llama.cpp `llama_model_params` struct - /// - [StructLayout(LayoutKind.Sequential)] - public unsafe struct LLamaModelParams - { - /// - /// // number of layers to store in VRAM - /// - public int n_gpu_layers; - - /// - /// how to split the model across multiple GPUs - /// - public GPUSplitMode split_mode; - - /// - /// the GPU that is used for scratch and small tensors - /// - public int main_gpu; - - /// - /// how to split layers across multiple GPUs (size: ) - /// +using System; +using System.Runtime.InteropServices; + +namespace LLama.Native +{ + /// + /// A C# representation of the llama.cpp `llama_model_params` struct + /// + [StructLayout(LayoutKind.Sequential)] + public unsafe struct LLamaModelParams + { + /// + /// // number of layers to store in VRAM + /// + public int n_gpu_layers; + + /// + /// how to split the model across multiple GPUs + /// + public GPUSplitMode split_mode; + + /// + /// the GPU that is used for scratch and small tensors + /// + public int main_gpu; + + /// + /// how to split layers across multiple GPUs (size: ) + /// public float* tensor_split; - /// - /// called with a progress value between 0 and 1, pass NULL to disable. If the provided progress_callback - /// returns true, model loading continues. If it returns false, model loading is immediately aborted. - /// -#if NETSTANDARD2_0 + /// + /// called with a progress value between 0 and 1, pass NULL to disable. If the provided progress_callback + /// returns true, model loading continues. If it returns false, model loading is immediately aborted. 
+ /// +#if NETSTANDARD2_0 // this code is intended to be used when running LlamaSharp on NET Framework 4.8 (NET Standard 2.0) // as NET Framework 4.8 does not play nice with the LlamaProgressCallback type - public IntPtr progress_callback; -#else - public LlamaProgressCallback progress_callback; -#endif - - /// - /// context pointer passed to the progress callback - /// - public void* progress_callback_user_data; - - /// - /// override key-value pairs of the model meta data - /// - public LLamaModelMetadataOverride* kv_overrides; - - /// - /// only load the vocabulary, no weights - /// - public bool vocab_only - { - readonly get => Convert.ToBoolean(_vocab_only); - set => _vocab_only = Convert.ToSByte(value); - } - private sbyte _vocab_only; - - /// - /// use mmap if possible - /// - public bool use_mmap - { - readonly get => Convert.ToBoolean(_use_mmap); - set => _use_mmap = Convert.ToSByte(value); - } - private sbyte _use_mmap; - - /// - /// force system to keep model in RAM - /// - public bool use_mlock - { - readonly get => Convert.ToBoolean(_use_mlock); - set => _use_mlock = Convert.ToSByte(value); - } - private sbyte _use_mlock; - } -} + public IntPtr progress_callback; +#else + public LlamaProgressCallback progress_callback; +#endif + + /// + /// context pointer passed to the progress callback + /// + public void* progress_callback_user_data; + + /// + /// override key-value pairs of the model meta data + /// + public LLamaModelMetadataOverride* kv_overrides; + + /// + /// only load the vocabulary, no weights + /// + public bool vocab_only + { + readonly get => Convert.ToBoolean(_vocab_only); + set => _vocab_only = Convert.ToSByte(value); + } + private sbyte _vocab_only; + + /// + /// use mmap if possible + /// + public bool use_mmap + { + readonly get => Convert.ToBoolean(_use_mmap); + set => _use_mmap = Convert.ToSByte(value); + } + private sbyte _use_mmap; + + /// + /// force system to keep model in RAM + /// + public bool use_mlock + { + readonly get => Convert.ToBoolean(_use_mlock); + set => _use_mlock = Convert.ToSByte(value); + } + private sbyte _use_mlock; + + /// + /// Create a LLamaModelParams with default values + /// + /// + public static LLamaModelParams Default() + { + return llama_model_default_params(); + + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + static extern LLamaModelParams llama_model_default_params(); + } + } +} diff --git a/LLama/Native/LLamaModelQuantizeParams.cs b/LLama/Native/LLamaModelQuantizeParams.cs index 34c1a9743..b2d37eb05 100644 --- a/LLama/Native/LLamaModelQuantizeParams.cs +++ b/LLama/Native/LLamaModelQuantizeParams.cs @@ -20,6 +20,16 @@ public struct LLamaModelQuantizeParams /// public LLamaFtype ftype; + /// + /// output tensor type + /// + public GGMLType output_tensor_type; + + /// + /// itoken embeddings tensor type + /// + public GGMLType token_embedding_type; + /// /// allow quantizing non-f32/f16 tensors /// @@ -51,7 +61,7 @@ public bool only_copy private sbyte _only_copy; /// - /// disable k-quant mixtures and quantize all tensors to the same type + /// quantize all tensors to the default type /// public bool pure { @@ -64,5 +74,22 @@ public bool pure /// pointer to importance matrix data /// public IntPtr imatrix; + + /// + /// pointer to vector containing overrides + /// + public IntPtr kv_overrides; + + /// + /// Create a LLamaModelQuantizeParams with default values + /// + /// + public static LLamaModelQuantizeParams Default() + { + return llama_model_quantize_default_params(); + + 
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + static extern LLamaModelQuantizeParams llama_model_quantize_default_params(); + } } } diff --git a/LLama/Native/LLamaPoolingType.cs b/LLama/Native/LLamaPoolingType.cs index 13bdfe6c2..cd928849f 100644 --- a/LLama/Native/LLamaPoolingType.cs +++ b/LLama/Native/LLamaPoolingType.cs @@ -6,6 +6,7 @@ /// llama_pooling_type public enum LLamaPoolingType { + Unspecified = -1, None = 0, Mean = 1, CLS = 2, diff --git a/LLama/Native/LLamaVocabType.cs b/LLama/Native/LLamaVocabType.cs index 6620f5101..d3cdf2dc3 100644 --- a/LLama/Native/LLamaVocabType.cs +++ b/LLama/Native/LLamaVocabType.cs @@ -6,7 +6,23 @@ /// llama_vocab_type public enum LLamaVocabType { - SentencePiece = 0, - BytePairEncoding = 1, - WordPiece = 2, + /// + /// For models without vocab + /// + None = 0, + + /// + /// LLaMA tokenizer based on byte-level BPE with byte fallback + /// + SentencePiece = 1, + + /// + /// GPT-2 tokenizer based on byte-level BPE + /// + BytePairEncoding = 2, + + /// + /// BERT tokenizer based on WordPiece + /// + WordPiece = 3, } \ No newline at end of file diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 41c1809e0..6f8d142e3 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -29,27 +29,6 @@ public static void llama_empty_call() [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern long llama_max_devices(); - /// - /// Create a LLamaModelParams with default values - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern LLamaModelParams llama_model_default_params(); - - /// - /// Create a LLamaContextParams with default values - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern LLamaContextParams llama_context_default_params(); - - /// - /// Create a LLamaModelQuantizeParams with default values - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern LLamaModelQuantizeParams llama_model_quantize_default_params(); - /// /// Check if memory mapping is supported /// @@ -72,8 +51,9 @@ public static void llama_empty_call() public static extern bool llama_supports_gpu_offload(); /// - /// Initialize the llama + ggml backend - /// Call once at the start of the program + /// Initialize the llama + ggml backend. Call once at the start of the program. + /// + /// This is private because LLamaSharp automatically calls it, and it's only valid to call it once! /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] private static extern void llama_backend_init(); @@ -87,42 +67,6 @@ public static void llama_empty_call() //[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] //public static extern void llama_numa_init(ggml_numa_strategy numa); - /// - /// Sets the current rng seed. - /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_set_rng_seed(SafeLLamaContextHandle ctx, uint seed); - - /// - /// Returns the maximum size in bytes of the state (rng, logits, embedding - /// and kv_cache) - will often be smaller after compacting tokens - /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern ulong llama_get_state_size(SafeLLamaContextHandle ctx); - - /// - /// Copies the state to the specified destination address. 
- /// Destination needs to have allocated enough memory. - /// - /// - /// - /// the number of bytes copied - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest); - - /// - /// Set the state reading from the specified address - /// - /// - /// - /// the number of bytes read - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src); - /// /// Load session file /// @@ -133,7 +77,7 @@ public static void llama_empty_call() /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out); + public static extern bool llama_state_load_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out); /// /// Save session file @@ -144,63 +88,97 @@ public static void llama_empty_call() /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern bool llama_save_session_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count); + public static extern bool llama_state_save_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count); [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, LLamaToken token); + public static extern unsafe nuint llama_state_seq_save_file(SafeLLamaContextHandle ctx, string filepath, LLamaSeqId seq_id, LLamaToken* tokens, nuint n_token_count); [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern float llama_token_get_score(SafeLlamaModelHandle model, LLamaToken token); + public static extern unsafe nuint llama_state_seq_load_file(SafeLLamaContextHandle ctx, string filepath, LLamaSeqId dest_seq_id, LLamaToken* tokens_out, nuint n_token_capacity, out nuint n_token_count_out); [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern LLamaTokenType llama_token_get_type(SafeLlamaModelHandle model, LLamaToken token); + public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, LLamaToken token); /// - /// Get the size of the context window for the model for this context + /// Set whether to use causal attention or not. 
If set to true, the model will only attend to the past tokens /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern uint llama_n_ctx(SafeLLamaContextHandle ctx); + public static extern void llama_set_causal_attn(SafeLLamaContextHandle ctx, bool causal_attn); /// - /// Get the batch size for this context + /// Set abort callback /// - /// - /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern uint llama_n_batch(SafeLLamaContextHandle ctx); + public static extern void llama_set_abort_callback(SafeLLamaContextHandle ctx, IntPtr /* ggml_abort_callback */ abort_callback, IntPtr abort_callback_data); /// - /// Token logits obtained from the last call to llama_decode - /// The logits for the last token are stored in the last row - /// Can be mutated in order to change the probabilities of the next token.
- /// Rows: n_tokens
- /// Cols: n_vocab + /// Wait until all computations are finished. This is done automatically when using any of the functions that obtain computation results, + /// so it is usually not necessary to call it explicitly. ///
/// - /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe float* llama_get_logits(SafeLLamaContextHandle ctx); + public static extern void llama_synchronize(SafeLlamaModelHandle ctx); + + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern float llama_token_get_score(SafeLlamaModelHandle model, LLamaToken token); + + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern LLamaTokenType llama_token_get_type(SafeLlamaModelHandle model, LLamaToken token); /// - /// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab + /// Get the n_seq_max for this context /// /// - /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i); + public static extern uint llama_n_seq_max(SafeLLamaContextHandle ctx); /// - /// Get the embeddings for the ith sequence. Equivalent to: llama_get_embeddings(ctx) + i*n_embd + /// Get the embeddings for the a specific sequence. + /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd /// /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe float* llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i); + public static Span llama_get_embeddings_seq(SafeLLamaContextHandle ctx, LLamaSeqId id) + { + unsafe + { + var ptr = llama_get_embeddings_seq_native(ctx, id); + if (ptr == null) + return Array.Empty(); + + return new Span(ptr, ctx.EmbeddingSize); + } + + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings_seq")] + static extern unsafe float* llama_get_embeddings_seq_native(SafeLLamaContextHandle ctx, LLamaSeqId id); + } + + /// + /// Get the embeddings for the ith sequence. + /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd + /// + /// + public static Span llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i) + { + unsafe + { + var ptr = llama_get_embeddings_ith_native(ctx, i); + if (ptr == null) + return Array.Empty(); + + return new Span(ptr, ctx.EmbeddingSize); + } + + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings_ith")] + static extern unsafe float* llama_get_embeddings_ith_native(SafeLLamaContextHandle ctx, int i); + } /// - /// Get the embeddings for the input + /// Get all output token embeddings. + /// When pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, the embeddings for which + /// llama_batch.logits[i] != 0 are stored contiguously in the order they have appeared in the batch. + /// shape: [n_outputs*n_embd] + /// Otherwise, returns an empty span. /// /// /// @@ -209,6 +187,9 @@ public static Span llama_get_embeddings(SafeLLamaContextHandle ctx) unsafe { var ptr = llama_get_embeddings_native(ctx); + if (ptr == null) + return Array.Empty(); + return new Span(ptr, ctx.EmbeddingSize); } @@ -230,28 +211,7 @@ public static Span llama_get_embeddings(SafeLLamaContextHandle ctx) /// The size of the allocated buffer /// The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. 
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings")] - public static extern unsafe int llama_chat_apply_template(SafeLlamaModelHandle model, char* tmpl, LLamaChatMessage* chat, nint n_msg, bool add_ass, char* buf, int length); - - /// - /// Get the "Beginning of sentence" token - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern LLamaToken llama_token_bos(SafeLlamaModelHandle model); - - /// - /// Get the "End of sentence" token - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern LLamaToken llama_token_eos(SafeLlamaModelHandle model); - - /// - /// Get the "new line" token - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern LLamaToken llama_token_nl(SafeLlamaModelHandle model); + public static extern unsafe int llama_chat_apply_template(SafeLlamaModelHandle model, char* tmpl, LLamaChatMessage* chat, nuint n_msg, bool add_ass, char* buf, int length); /// /// Returns -1 if unknown, 1 for true or 0 for false. @@ -267,34 +227,6 @@ public static Span llama_get_embeddings(SafeLLamaContextHandle ctx) [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_add_eos_token(SafeLlamaModelHandle model); - /// - /// codellama infill tokens, Beginning of infill prefix - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_token_prefix(SafeLlamaModelHandle model); - - /// - /// codellama infill tokens, Beginning of infill middle - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_token_middle(SafeLlamaModelHandle model); - - /// - /// codellama infill tokens, Beginning of infill suffix - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_token_suffix(SafeLlamaModelHandle model); - - /// - /// codellama infill tokens, End of infill middle - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_token_eot(SafeLlamaModelHandle model); - /// /// Print out timing information for this context /// @@ -345,13 +277,13 @@ public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken ll /// /// /// - /// - /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space. + /// + /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space. /// Returns the number of tokens on success, no more than n_max_tokens. 
/// Returns a negative number on failure - the number of tokens that would have been returned /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, LLamaToken* tokens, int n_max_tokens, bool add_bos, bool special); + public static extern unsafe int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, LLamaToken* tokens, int n_max_tokens, bool add_special, bool parse_special); /// /// Register a callback to receive llama log messages @@ -377,8 +309,9 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback) /// /// /// + /// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_kv_cache_seq_rm(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1); + public static extern bool llama_kv_cache_seq_rm(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1); /// /// Copy all tokens that belong to the specified sequence to another sequence @@ -441,23 +374,6 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback) [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern LLamaPos llama_kv_cache_seq_pos_max(SafeLLamaContextHandle ctx, LLamaSeqId seq); - /// - /// Defragment the KV cache. This will be applied: - /// - lazily on next llama_decode() - /// - explicitly with llama_kv_cache_update() - /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern LLamaPos llama_kv_cache_defrag(SafeLLamaContextHandle ctx); - - /// - /// Apply the KV cache updates (such as K-shifts, defragmentation, etc.) - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_kv_cache_update(SafeLLamaContextHandle ctx); - /// /// Allocates a batch of tokens on the heap /// Each token can be assigned up to n_seq_max sequence ids @@ -481,31 +397,47 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback) public static extern void llama_batch_free(LLamaNativeBatch batch); /// + /// Apply a loaded control vector to a llama_context, or if data is NULL, clear + /// the currently loaded vector. + /// n_embd should be the size of a single layer's control, and data should point + /// to an n_embd x n_layers buffer starting from layer 1. + /// il_start and il_end are the layer range the vector should apply to (both inclusive) + /// See llama_control_vector_load in common to load a control vector. /// /// - /// - /// Positive return values does not mean a fatal error, but rather a warning:
- /// - 0: success
- /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
- /// - < 0: error
- ///
+ /// + /// + /// + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_decode(SafeLLamaContextHandle ctx, LLamaNativeBatch batch); + public static extern unsafe int llama_control_vector_apply(SafeLLamaContextHandle ctx, float* data, nuint len, int n_embd, int il_start, int il_end); /// - /// Set the number of threads used for decoding + /// Build a split GGUF final path for this chunk. + /// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf" /// - /// - /// n_threads is the number of threads used for generation (single token) - /// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) - /// + /// + /// + /// + /// + /// + /// Returns the split_path length. [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch); - - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern LLamaVocabType llama_vocab_type(SafeLlamaModelHandle model); + public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no, int split_count); + /// + /// Extract the path prefix from the split_path if and only if the split_no and split_count match. + /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0" + /// + /// + /// + /// + /// + /// + /// Returns the split_prefix length. [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern LLamaRopeType llama_rope_type(SafeLlamaModelHandle model); + public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no, int split_count); } } diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 2f881fa5d..c9e959a09 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -22,7 +22,7 @@ public sealed class SafeLLamaContextHandle /// /// Total number of tokens in the context /// - public uint ContextSize => NativeApi.llama_n_ctx(this); + public uint ContextSize => llama_n_ctx(this); /// /// Dimension of embedding vectors @@ -32,7 +32,12 @@ public sealed class SafeLLamaContextHandle /// /// Get the maximum batch size for this context /// - public uint BatchSize => NativeApi.llama_n_batch(this); + public uint BatchSize => llama_n_batch(this); + + /// + /// Get the physical maximum batch size for this context + /// + public uint UBatchSize => llama_n_ubatch(this); /// /// Get the model which this context is using @@ -127,8 +132,157 @@ static SafeLLamaContextHandle() /// /// private unsafe delegate bool GgmlAbortCallback(void* data); + + /// + /// + /// + /// + /// Positive return values does not mean a fatal error, but rather a warning:
+ /// - 0: success
+ /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+ /// - < 0: error
+ ///
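These are the same return codes that `Decode(LLamaBatch)` surfaces through the `DecodeResult` enum further down, so the "no KV slot" case can be treated as retryable rather than fatal. A hedged sketch of that handling (the member names `DecodeResult.Ok` and `DecodeResult.NoKvSlot` are assumed from the existing enum, not defined in this diff):

```csharp
using System;
using LLama.Native;

internal static class DecodeExample
{
    // Hedged sketch: `ctx` and `batch` are assumed to be an initialised context and an
    // already-filled LLamaBatch. Enum member names are an assumption of this sketch.
    public static bool TryDecode(SafeLLamaContextHandle ctx, LLamaBatch batch)
    {
        var result = ctx.Decode(batch);

        // 0: success
        if (result == DecodeResult.Ok)
            return true;

        // 1: no KV slot - not fatal; the caller can shrink the batch or grow the context and retry.
        if (result == DecodeResult.NoKvSlot)
            return false;

        // < 0: hard error
        throw new InvalidOperationException($"llama_decode failed: {result}");
    }
}
```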
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern int llama_decode(SafeLLamaContextHandle ctx, LLamaNativeBatch batch); + + /// + /// Set the number of threads used for decoding + /// + /// + /// n_threads is the number of threads used for generation (single token) + /// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern void llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch); + + /// + /// Token logits obtained from the last call to llama_decode + /// The logits for the last token are stored in the last row + /// Can be mutated in order to change the probabilities of the next token.
+ /// Rows: n_tokens
+ /// Cols: n_vocab + ///
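Because `GetLogits()` and `GetLogitsIth()` hand back a writable span over this n_vocab-wide row, a caller can bias or ban tokens in place before sampling. A short illustrative sketch, assuming `ctx` has just decoded a batch (the helper name is hypothetical):

```csharp
using System;
using LLama.Native;

internal static class LogitsExample
{
    // Minimal sketch: suppress a single token by writing negative infinity
    // into the logits row for the last decoded token.
    public static void BanToken(SafeLLamaContextHandle ctx, LLamaToken banned)
    {
        // One float per vocabulary entry (Cols: n_vocab).
        Span<float> logits = ctx.GetLogits();

        // The span aliases native memory, so the write is visible to the sampler.
        logits[(int)banned] = float.NegativeInfinity;
    }
}
```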
+ /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe float* llama_get_logits(SafeLLamaContextHandle ctx); + + /// + /// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab + /// + /// + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i); + + /// + /// Get the size of the context window for the model for this context + /// + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern uint llama_n_ctx(SafeLLamaContextHandle ctx); + + /// + /// Get the batch size for this context + /// + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern uint llama_n_batch(SafeLLamaContextHandle ctx); + + /// + /// Get the ubatch size for this context + /// + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern uint llama_n_ubatch(SafeLLamaContextHandle ctx); + + /// + /// Sets the current rng seed. + /// + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern void llama_set_rng_seed(SafeLLamaContextHandle ctx, uint seed); + + /// + /// Returns the maximum size in bytes of the state (rng, logits, embedding + /// and kv_cache) - will often be smaller after compacting tokens + /// + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern ulong llama_state_get_size(SafeLLamaContextHandle ctx); + + /// + /// Copies the state to the specified destination address. + /// Destination needs to have allocated enough memory. + /// + /// + /// + /// the number of bytes copied + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe ulong llama_state_get_data(SafeLLamaContextHandle ctx, byte* dest); + + /// + /// Set the state reading from the specified address + /// + /// + /// + /// the number of bytes read + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe ulong llama_state_set_data(SafeLLamaContextHandle ctx, byte* src); + + /// + /// Get the exact size needed to copy the KV cache of a single sequence + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern nuint llama_state_seq_get_size(SafeLLamaContextHandle ctx, LLamaSeqId seq_id); + + /// + /// Copy the KV cache of a single sequence into the specified buffer + /// + /// + /// + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe nuint llama_state_seq_get_data(SafeLLamaContextHandle ctx, byte* dst, LLamaSeqId seq_id); + + /// + /// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence + /// + /// + /// + /// + /// + /// - Positive: Ok + /// - Zero: Failed to load + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern unsafe nuint llama_state_seq_set_data(SafeLLamaContextHandle ctx, byte* src, LLamaSeqId dest_seq_id); + + /// + /// Defragment the KV cache. 
This will be applied: + /// - lazily on next llama_decode() + /// - explicitly with llama_kv_cache_update() + /// + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern void llama_kv_cache_defrag(SafeLLamaContextHandle ctx); + + /// + /// Apply the KV cache updates (such as K-shifts, defragmentation, etc.) + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void llama_kv_cache_update(SafeLLamaContextHandle ctx); #endregion - + /// /// Token logits obtained from the last call to llama_decode /// The logits for the last token are stored in the last row @@ -143,7 +297,7 @@ public Span GetLogits() unsafe { - var logits = NativeApi.llama_get_logits(this); + var logits = llama_get_logits(this); return new Span(logits, model.VocabCount); } } @@ -159,7 +313,7 @@ public Span GetLogitsIth(int i) unsafe { - var logits = NativeApi.llama_get_logits_ith(this, i); + var logits = llama_get_logits_ith(this, i); return new Span(logits, model.VocabCount); } } @@ -216,7 +370,7 @@ public DecodeResult Decode(LLamaBatch batch) { lock (GlobalInferenceLock) using (batch.ToNativeBatch(out var nb)) - return (DecodeResult)NativeApi.llama_decode(this, nb); + return (DecodeResult)llama_decode(this, nb); } /// @@ -260,19 +414,17 @@ public DecodeResult Decode(LLamaBatch batch) /// public ulong GetStateSize() { - return NativeApi.llama_get_state_size(this); + return llama_state_get_size(this); } /// - /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer. + /// Get the size of the KV cache for a single sequence ID, when saved as bytes /// - /// Destination to write to - /// Number of bytes available to write to in dest (check required size with `GetStateSize()`) - /// The number of bytes written to dest - /// Thrown if dest is too small - public unsafe ulong GetState(byte* dest, ulong size) + /// + /// + public ulong GetStateSize(LLamaSeqId sequence) { - return GetState(new IntPtr(dest), size); + return llama_state_seq_get_size(this, sequence); } /// @@ -282,7 +434,7 @@ public unsafe ulong GetState(byte* dest, ulong size) /// Number of bytes available to write to in dest (check required size with `GetStateSize()`) /// The number of bytes written to dest /// Thrown if dest is too small - public ulong GetState(IntPtr dest, ulong size) + public unsafe ulong GetState(byte* dest, ulong size) { var required = GetStateSize(); if (size < required) @@ -290,10 +442,26 @@ public ulong GetState(IntPtr dest, ulong size) unsafe { - return NativeApi.llama_copy_state_data(this, (byte*)dest.ToPointer()); + return llama_state_get_data(this, dest); } } + /// + /// Get the raw state of a single sequence from this context, encoded as bytes. Data is written into the `dest` pointer. 
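Together with `GetStateSize(LLamaSeqId)` above and the sequence-aware `SetState` overload below, this allows a single conversation to be snapshotted out of a multi-sequence context. A minimal round-trip sketch, assuming `ctx` and `seq` are supplied by the caller:

```csharp
using System;
using LLama.Native;

internal static class SequenceStateExample
{
    // Hedged sketch: copy one sequence's state into a managed buffer and restore it later.
    public static unsafe byte[] SaveSequence(SafeLLamaContextHandle ctx, LLamaSeqId seq)
    {
        var size = ctx.GetStateSize(seq);
        var buffer = new byte[size];

        fixed (byte* ptr = buffer)
        {
            // Returns the number of bytes actually written; trim the buffer to match.
            var written = ctx.GetState(ptr, size, seq);
            if (written < size)
                Array.Resize(ref buffer, checked((int)written));
        }

        return buffer;
    }

    public static unsafe void RestoreSequence(SafeLLamaContextHandle ctx, LLamaSeqId seq, byte[] state)
    {
        fixed (byte* ptr = state)
            ctx.SetState(ptr, seq);
    }
}
```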
+ /// + /// Destination to write to + /// Number of bytes available to write to in dest (check required size with `GetStateSize()`) + /// The sequence to get state data for + /// The number of bytes written to dest + public unsafe ulong GetState(byte* dest, ulong size, LLamaSeqId sequence) + { + var required = GetStateSize(sequence); + if (size < required) + throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}"); + + return llama_state_seq_get_data(this, dest, sequence); + } + /// /// Set the raw state of this context /// @@ -301,20 +469,18 @@ public ulong GetState(IntPtr dest, ulong size) /// Number of bytes read from the src pointer public unsafe ulong SetState(byte* src) { - return SetState(new IntPtr(src)); + return llama_state_set_data(this, src); } /// - /// Set the raw state of this context + /// Set the raw state of a single sequence /// /// The pointer to read the state from + /// Sequence ID to set /// Number of bytes read from the src pointer - public ulong SetState(IntPtr src) + public unsafe ulong SetState(byte* src, LLamaSeqId sequence) { - unsafe - { - return NativeApi.llama_set_state_data(this, (byte*)src.ToPointer()); - } + return llama_state_seq_set_data(this, src, sequence); } #endregion @@ -324,7 +490,7 @@ public ulong SetState(IntPtr src) /// public void SetSeed(uint seed) { - NativeApi.llama_set_rng_seed(this, seed); + llama_set_rng_seed(this, seed); } /// @@ -334,10 +500,29 @@ public void SetSeed(uint seed) /// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) public void SetThreads(uint threads, uint threadsBatch) { - NativeApi.llama_set_n_threads(this, threads, threadsBatch); + llama_set_n_threads(this, threads, threadsBatch); } #region KV Cache Management + /// + /// Apply KV cache updates (such as K-shifts, defragmentation, etc.) + /// + public void KvCacheUpdate() + { + llama_kv_cache_update(this); + } + + /// + /// Defragment the KV cache. This will be applied: + /// - lazily on next llama_decode() + /// - explicitly with llama_kv_cache_update() + /// + /// + public void KvCacheDefrag() + { + llama_kv_cache_defrag(this); + } + /// /// Get a new KV cache view that can be used to debug the KV cache /// diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index acaee849a..31b73dfd4 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -22,6 +22,16 @@ public sealed class SafeLlamaModelHandle /// public int VocabCount => llama_n_vocab(this); + /// + /// Get the vocabulary type for this model + /// + public LLamaVocabType VocabType => llama_vocab_type(this); + + /// + /// Get the rope (positional embedding) type for this model + /// + public LLamaRopeType RopeType => llama_rope_type(this); + /// /// Total number of tokens in the context /// @@ -47,6 +57,11 @@ public sealed class SafeLlamaModelHandle /// public ulong ParameterCount => llama_model_n_params(this); + /// + /// Get the number of layers in this model + /// + public int LayerCount => llama_n_embd(this); + /// /// Get a description of this model /// @@ -74,6 +89,13 @@ public string Description /// public int MetadataCount => llama_model_meta_count(this); + private ModelTokens? 
_tokens; + + /// + /// Get the special tokens of this model + /// + public ModelTokens Tokens => _tokens ??= new ModelTokens(this); + /// protected override bool ReleaseHandle() { @@ -105,7 +127,7 @@ public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelPara #region native API static SafeLlamaModelHandle() { - // This ensures that `NativeApi` has been loaded before calling the two native methods below + // Ensure that `NativeApi` has been loaded NativeApi.llama_empty_call(); } @@ -132,7 +154,7 @@ static SafeLlamaModelHandle() /// /// Returns 0 on success [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, float scale, string? path_base_model, int n_threads); + private static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, float scale, string? path_base_model, int n_threads); /// /// Frees all allocated memory associated with a model @@ -210,6 +232,12 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] private static extern int llama_n_vocab(SafeLlamaModelHandle model); + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern LLamaVocabType llama_vocab_type(SafeLlamaModelHandle model); + + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern LLamaRopeType llama_rope_type(SafeLlamaModelHandle model); + /// /// Get the size of the context window for the model /// @@ -226,6 +254,14 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] private static extern int llama_n_embd(SafeLlamaModelHandle model); + /// + /// Get the number of layers in this model + /// + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern int llama_n_layers(SafeLlamaModelHandle model); + /// /// Get a string describing the model type /// @@ -259,6 +295,69 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, /// [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] private static extern float llama_rope_freq_scale_train(SafeLlamaModelHandle model); + + /// + /// Get the "Beginning of sentence" token + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern LLamaToken llama_token_bos(SafeLlamaModelHandle model); + + /// + /// Get the "End of sentence" token + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern LLamaToken llama_token_eos(SafeLlamaModelHandle model); + + /// + /// Get the "classification" token + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern LLamaToken llama_token_cls(SafeLlamaModelHandle model); + + /// + /// Get the "sentence separator" token + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern LLamaToken llama_token_sep(SafeLlamaModelHandle model); + + /// + /// Get the "new line" token + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern LLamaToken 
llama_token_nl(SafeLlamaModelHandle model); + + /// + /// codellama infill tokens, Beginning of infill prefix + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern int llama_token_prefix(SafeLlamaModelHandle model); + + /// + /// codellama infill tokens, Beginning of infill middle + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern int llama_token_middle(SafeLlamaModelHandle model); + + /// + /// codellama infill tokens, Beginning of infill suffix + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern int llama_token_suffix(SafeLlamaModelHandle model); + + /// + /// codellama infill tokens, End of infill middle + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern int llama_token_eot(SafeLlamaModelHandle model); #endregion #region LoRA @@ -273,6 +372,14 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, /// public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, int? threads = null) { + // Try to open the model file, this will check: + // - File exists (automatically throws FileNotFoundException) + // - File is readable (explicit check) + // This provides better error messages that llama.cpp, which would throw an access violation exception in both cases. + using (var fs = new FileStream(lora, FileMode.Open)) + if (!fs.CanRead) + throw new InvalidOperationException($"LoRA file '{lora}' is not readable"); + var err = llama_model_apply_lora_from_file( this, lora, @@ -282,7 +389,7 @@ public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null ); if (err != 0) - throw new RuntimeError("Failed to apply lora adapter."); + throw new RuntimeError($"Failed to apply lora adapter (err={err})."); } #endregion @@ -451,5 +558,68 @@ internal IReadOnlyDictionary ReadMetadata() return result; } #endregion + + /// + /// Get tokens for a model + /// + public class ModelTokens + { + private readonly SafeLlamaModelHandle _model; + + internal ModelTokens(SafeLlamaModelHandle model) + { + _model = model; + } + + private static LLamaToken? Normalize(LLamaToken token) + { + return token == -1 ? null : token; + } + + /// + /// Get the Beginning of Sentence token for this model + /// + public LLamaToken? BOS => Normalize(llama_token_bos(_model)); + + /// + /// Get the End of Sentence token for this model + /// + public LLamaToken? EOS => Normalize(llama_token_eos(_model)); + + /// + /// Get the newline token for this model + /// + public LLamaToken? Newline => Normalize(llama_token_nl(_model)); + + /// + /// Get the classification token for this model + /// + public LLamaToken? CLS => Normalize(llama_token_cls(_model)); + + /// + /// Get the sentence separator token for this model + /// + public LLamaToken? SEP => Normalize(llama_token_sep(_model)); + + /// + /// Codellama beginning of infill prefix + /// + public LLamaToken? InfillPrefix => Normalize(llama_token_prefix(_model)); + + /// + /// Codellama beginning of infill middle + /// + public LLamaToken? InfillMiddle => Normalize(llama_token_middle(_model)); + + /// + /// Codellama beginning of infill suffix + /// + public LLamaToken? InfillSuffix => Normalize(llama_token_suffix(_model)); + + /// + /// Codellama end of infill middle + /// + public LLamaToken? 
EOT => Normalize(llama_token_eot(_model)); + } } } diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index 5a9ef16cf..33806f5f9 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -133,18 +133,21 @@ protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, private static (int, float) GetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) { - var nlToken = NativeApi.llama_token_nl(ctx.ModelHandle); + var nlToken = ctx.ModelHandle.Tokens.Newline; - // Try using the ID as an index - if (candidates.data.Span[(int)nlToken].id == nlToken) - return ((int)nlToken, candidates.data.Span[(int)nlToken].logit); - - // Exhaustive search - var span = candidates.data.Span; - for (var i = 0; i < span.Length; i++) + if (nlToken.HasValue) { - if (span[i].id == nlToken) - return (i, span[i].logit); + // Try using the ID as an index + if (candidates.data.Span[(int)nlToken].id == nlToken) + return ((int)nlToken, candidates.data.Span[(int)nlToken].logit); + + // Exhaustive search + var span = candidates.data.Span; + for (var i = 0; i < span.Length; i++) + { + if (span[i].id == nlToken) + return (i, span[i].logit); + } } return (-1, 0); @@ -152,7 +155,9 @@ private static (int, float) GetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTok private static void SetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int indexHint, float logit) { - var nlToken = NativeApi.llama_token_nl(ctx.ModelHandle); + var nlToken = ctx.ModelHandle.Tokens.Newline; + if (!nlToken.HasValue) + return; // Try checking the index where we found it last time. It might not be there if `RepetitionPenalty` changed order if (indexHint >= 0 && candidates.data.Span[indexHint].id == nlToken) diff --git a/LLama/runtimes/deps/avx/libllama.dll b/LLama/runtimes/deps/avx/libllama.dll index f09813ae6..3b3b3fae3 100644 Binary files a/LLama/runtimes/deps/avx/libllama.dll and b/LLama/runtimes/deps/avx/libllama.dll differ diff --git a/LLama/runtimes/deps/avx/libllama.so b/LLama/runtimes/deps/avx/libllama.so index 6929f1a4a..08dcee0f7 100644 Binary files a/LLama/runtimes/deps/avx/libllama.so and b/LLama/runtimes/deps/avx/libllama.so differ diff --git a/LLama/runtimes/deps/avx/libllava_shared.so b/LLama/runtimes/deps/avx/libllava_shared.so index 5868e2ed5..1c8adfcb7 100644 Binary files a/LLama/runtimes/deps/avx/libllava_shared.so and b/LLama/runtimes/deps/avx/libllava_shared.so differ diff --git a/LLama/runtimes/deps/avx/llama.dll b/LLama/runtimes/deps/avx/llama.dll index f09813ae6..3b3b3fae3 100644 Binary files a/LLama/runtimes/deps/avx/llama.dll and b/LLama/runtimes/deps/avx/llama.dll differ diff --git a/LLama/runtimes/deps/avx/llava_shared.dll b/LLama/runtimes/deps/avx/llava_shared.dll index 546da7588..e08a474b6 100644 Binary files a/LLama/runtimes/deps/avx/llava_shared.dll and b/LLama/runtimes/deps/avx/llava_shared.dll differ diff --git a/LLama/runtimes/deps/avx2/libllama.dll b/LLama/runtimes/deps/avx2/libllama.dll index 481be2352..bb8e5c48b 100644 Binary files a/LLama/runtimes/deps/avx2/libllama.dll and b/LLama/runtimes/deps/avx2/libllama.dll differ diff --git a/LLama/runtimes/deps/avx2/libllama.so b/LLama/runtimes/deps/avx2/libllama.so index fd2548fd7..e7c27f79f 100644 Binary files a/LLama/runtimes/deps/avx2/libllama.so and b/LLama/runtimes/deps/avx2/libllama.so differ diff --git a/LLama/runtimes/deps/avx2/libllava_shared.so b/LLama/runtimes/deps/avx2/libllava_shared.so index 
18ca3617d..f9bbdf272 100644 Binary files a/LLama/runtimes/deps/avx2/libllava_shared.so and b/LLama/runtimes/deps/avx2/libllava_shared.so differ diff --git a/LLama/runtimes/deps/avx2/llama.dll b/LLama/runtimes/deps/avx2/llama.dll index 481be2352..bb8e5c48b 100644 Binary files a/LLama/runtimes/deps/avx2/llama.dll and b/LLama/runtimes/deps/avx2/llama.dll differ diff --git a/LLama/runtimes/deps/avx2/llava_shared.dll b/LLama/runtimes/deps/avx2/llava_shared.dll index f877c590f..6b4ad9c13 100644 Binary files a/LLama/runtimes/deps/avx2/llava_shared.dll and b/LLama/runtimes/deps/avx2/llava_shared.dll differ diff --git a/LLama/runtimes/deps/avx512/libllama.dll b/LLama/runtimes/deps/avx512/libllama.dll index 9f3030289..fcbc052eb 100644 Binary files a/LLama/runtimes/deps/avx512/libllama.dll and b/LLama/runtimes/deps/avx512/libllama.dll differ diff --git a/LLama/runtimes/deps/avx512/libllama.so b/LLama/runtimes/deps/avx512/libllama.so index 6700ef7cd..a0044ad66 100644 Binary files a/LLama/runtimes/deps/avx512/libllama.so and b/LLama/runtimes/deps/avx512/libllama.so differ diff --git a/LLama/runtimes/deps/avx512/libllava_shared.so b/LLama/runtimes/deps/avx512/libllava_shared.so index 8f2a9677f..d0c76ef13 100644 Binary files a/LLama/runtimes/deps/avx512/libllava_shared.so and b/LLama/runtimes/deps/avx512/libllava_shared.so differ diff --git a/LLama/runtimes/deps/avx512/llama.dll b/LLama/runtimes/deps/avx512/llama.dll index 9f3030289..fcbc052eb 100644 Binary files a/LLama/runtimes/deps/avx512/llama.dll and b/LLama/runtimes/deps/avx512/llama.dll differ diff --git a/LLama/runtimes/deps/avx512/llava_shared.dll b/LLama/runtimes/deps/avx512/llava_shared.dll index e0cfbe44a..8d643cb1d 100644 Binary files a/LLama/runtimes/deps/avx512/llava_shared.dll and b/LLama/runtimes/deps/avx512/llava_shared.dll differ diff --git a/LLama/runtimes/deps/clblast/libllama.so b/LLama/runtimes/deps/clblast/libllama.so index 9b5f87900..b6bff999a 100644 Binary files a/LLama/runtimes/deps/clblast/libllama.so and b/LLama/runtimes/deps/clblast/libllama.so differ diff --git a/LLama/runtimes/deps/clblast/libllava_shared.so b/LLama/runtimes/deps/clblast/libllava_shared.so index 764e7266d..6f63d183a 100644 Binary files a/LLama/runtimes/deps/clblast/libllava_shared.so and b/LLama/runtimes/deps/clblast/libllava_shared.so differ diff --git a/LLama/runtimes/deps/clblast/llama.dll b/LLama/runtimes/deps/clblast/llama.dll index a08951358..055b24a84 100644 Binary files a/LLama/runtimes/deps/clblast/llama.dll and b/LLama/runtimes/deps/clblast/llama.dll differ diff --git a/LLama/runtimes/deps/clblast/llava_shared.dll b/LLama/runtimes/deps/clblast/llava_shared.dll index e4a51d0ba..349ec89e5 100644 Binary files a/LLama/runtimes/deps/clblast/llava_shared.dll and b/LLama/runtimes/deps/clblast/llava_shared.dll differ diff --git a/LLama/runtimes/deps/cu11.7.1/libllama.so b/LLama/runtimes/deps/cu11.7.1/libllama.so index ef9baa519..1f1a79a0f 100644 Binary files a/LLama/runtimes/deps/cu11.7.1/libllama.so and b/LLama/runtimes/deps/cu11.7.1/libllama.so differ diff --git a/LLama/runtimes/deps/cu11.7.1/libllava_shared.so b/LLama/runtimes/deps/cu11.7.1/libllava_shared.so index 7ad6a066e..47cba9b13 100644 Binary files a/LLama/runtimes/deps/cu11.7.1/libllava_shared.so and b/LLama/runtimes/deps/cu11.7.1/libllava_shared.so differ diff --git a/LLama/runtimes/deps/cu11.7.1/llama.dll b/LLama/runtimes/deps/cu11.7.1/llama.dll index 22cd79574..a1cd82b25 100644 Binary files a/LLama/runtimes/deps/cu11.7.1/llama.dll and b/LLama/runtimes/deps/cu11.7.1/llama.dll differ diff 
--git a/LLama/runtimes/deps/cu11.7.1/llava_shared.dll b/LLama/runtimes/deps/cu11.7.1/llava_shared.dll index a5d1c514a..00e9794e6 100644 Binary files a/LLama/runtimes/deps/cu11.7.1/llava_shared.dll and b/LLama/runtimes/deps/cu11.7.1/llava_shared.dll differ diff --git a/LLama/runtimes/deps/cu12.1.0/libllama.so b/LLama/runtimes/deps/cu12.1.0/libllama.so index ac66c69f8..39b09e6b3 100644 Binary files a/LLama/runtimes/deps/cu12.1.0/libllama.so and b/LLama/runtimes/deps/cu12.1.0/libllama.so differ diff --git a/LLama/runtimes/deps/cu12.1.0/libllava_shared.so b/LLama/runtimes/deps/cu12.1.0/libllava_shared.so index 166633a80..ce830a7da 100644 Binary files a/LLama/runtimes/deps/cu12.1.0/libllava_shared.so and b/LLama/runtimes/deps/cu12.1.0/libllava_shared.so differ diff --git a/LLama/runtimes/deps/cu12.1.0/llama.dll b/LLama/runtimes/deps/cu12.1.0/llama.dll index b12c7776d..09a87cb20 100644 Binary files a/LLama/runtimes/deps/cu12.1.0/llama.dll and b/LLama/runtimes/deps/cu12.1.0/llama.dll differ diff --git a/LLama/runtimes/deps/cu12.1.0/llava_shared.dll b/LLama/runtimes/deps/cu12.1.0/llava_shared.dll index fdef226c3..597733ed3 100644 Binary files a/LLama/runtimes/deps/cu12.1.0/llava_shared.dll and b/LLama/runtimes/deps/cu12.1.0/llava_shared.dll differ diff --git a/LLama/runtimes/deps/libllama.dll b/LLama/runtimes/deps/libllama.dll index bd256c0a8..deb86e0df 100644 Binary files a/LLama/runtimes/deps/libllama.dll and b/LLama/runtimes/deps/libllama.dll differ diff --git a/LLama/runtimes/deps/libllama.so b/LLama/runtimes/deps/libllama.so index 66176da16..85dce3430 100644 Binary files a/LLama/runtimes/deps/libllama.so and b/LLama/runtimes/deps/libllama.so differ diff --git a/LLama/runtimes/deps/libllava_shared.so b/LLama/runtimes/deps/libllava_shared.so index 1c2c17144..f41a9c670 100644 Binary files a/LLama/runtimes/deps/libllava_shared.so and b/LLama/runtimes/deps/libllava_shared.so differ diff --git a/LLama/runtimes/deps/llama.dll b/LLama/runtimes/deps/llama.dll index bd256c0a8..deb86e0df 100644 Binary files a/LLama/runtimes/deps/llama.dll and b/LLama/runtimes/deps/llama.dll differ diff --git a/LLama/runtimes/deps/llava_shared.dll b/LLama/runtimes/deps/llava_shared.dll index d1aafcad9..10b057945 100644 Binary files a/LLama/runtimes/deps/llava_shared.dll and b/LLama/runtimes/deps/llava_shared.dll differ diff --git a/LLama/runtimes/deps/osx-arm64/ggml-metal.metal b/LLama/runtimes/deps/osx-arm64/ggml-metal.metal index 74a5e0b03..9a29f57a3 100644 --- a/LLama/runtimes/deps/osx-arm64/ggml-metal.metal +++ b/LLama/runtimes/deps/osx-arm64/ggml-metal.metal @@ -1,3 +1,7 @@ +#define GGML_COMMON_DECL_METAL +#define GGML_COMMON_IMPL_METAL +#include "ggml-common.h" + #include using namespace metal; @@ -6,46 +10,11 @@ using namespace metal; #define MIN(x, y) ((x) < (y) ? 
(x) : (y)) #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } -#define QK4_0 32 -#define QR4_0 2 -typedef struct { - half d; // delta - uint8_t qs[QK4_0 / 2]; // nibbles / quants -} block_q4_0; - -#define QK4_1 32 -typedef struct { - half d; // delta - half m; // min - uint8_t qs[QK4_1 / 2]; // nibbles / quants -} block_q4_1; - -#define QK5_0 32 -typedef struct { - half d; // delta - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_0 / 2]; // nibbles / quants -} block_q5_0; - -#define QK5_1 32 -typedef struct { - half d; // delta - half m; // min - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_1 / 2]; // nibbles / quants -} block_q5_1; - -#define QK8_0 32 -typedef struct { - half d; // delta - int8_t qs[QK8_0]; // quants -} block_q8_0; - #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 enum ggml_sort_order { - GGML_SORT_ASC, - GGML_SORT_DESC, + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, }; // general-purpose kernel for addition, multiplication and division of two tensors @@ -1959,11 +1928,56 @@ kernel void kernel_pad_f32( } } +kernel void kernel_arange_f32( + device char * dst, + constant int64_t & ne0, + constant float & start, + constant float & step, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + device float * dst_ptr = (device float *) dst; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = start + step * i0; + } +} + +kernel void kernel_timestep_embedding_f32( + device const char * src0, + device char * dst, + constant uint64_t & nb1, + constant int & dim, + constant int & max_period, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + int i = tgpig.x; + device float * embed_data = (device float *)(dst + i*nb1); + + int half_ = dim / 2; + for (int j = tpitg.x; j < half_; j += ntg.x) { + float timestep = ((device float *)src0)[i]; + float freq = (float)exp(-log((float)max_period) * j / half_); + float arg = timestep * freq; + embed_data[j ] = cos(arg); + embed_data[j + half_] = sin(arg); + } + + if (dim % 2 != 0 && tpitg.x == 0) { + embed_data[dim] = 0.f; + } +} + // bitonic sort implementation following the CUDA kernels as reference typedef void (argsort_t)( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]); @@ -1972,33 +1986,42 @@ kernel void kernel_argsort_f32_i32( device const float * x, device int32_t * dst, constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]) { // bitonic sort int col = tpitg[0]; int row = tgpig[1]; - if (col >= ncols) return; + if (col >= ncols_pad) return; - device const float * x_row = x + row * ncols; - device int32_t * dst_row = dst + row * ncols; + device const float * x_row = x + row * ncols; + threadgroup int32_t * dst_row = shared_values; // initialize indices - if (col < ncols) { - dst_row[col] = col; - } + dst_row[col] = col; + threadgroup_barrier(mem_flags::mem_threadgroup); - for (int k = 2; k <= ncols; k *= 2) { + for (int k = 2; k <= ncols_pad; k *= 2) { for (int j = k / 2; j > 0; j 
/= 2) { int ixj = col ^ j; if (ixj > col) { if ((col & k) == 0) { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { SWAP(dst_row[col], dst_row[ixj]); } } else { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { SWAP(dst_row[col], dst_row[ixj]); } } @@ -2006,10 +2029,15 @@ kernel void kernel_argsort_f32_i32( threadgroup_barrier(mem_flags::mem_threadgroup); } } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } } -template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; -template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; kernel void kernel_leaky_relu_f32( device const float * src0, @@ -2376,6 +2404,242 @@ kernel void kernel_cpy_f32_q4_1( } } +kernel void kernel_cpy_f32_q5_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0; + + device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK5_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 
1.0f/d : 0.0f; + + dst_data[i00/QK5_0].d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK5_0/2 + j]*id; + + const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); + + dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_0].qh[j] = qh8[j]; + } + } +} + +kernel void kernel_cpy_f32_q5_1( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1; + + device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float max = src[0]; + float min = src[0]; + + for (int j = 1; j < QK5_1; j++) { + const float v = src[j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK5_1].d = d; + dst_data[i00/QK5_1].m = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK5_1/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_1].qh[j] = qh8[j]; + } + } +} + +static inline int best_index_int8(int n, constant float * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? 
mu-1 : mu; +} + +constexpr constant static float kvalues_iq4nl_f[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +kernel void kernel_cpy_f32_iq4_nl( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL; + + device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / kvalues_iq4nl_f[0]; + const float id = d ? 1.0f/d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_NL/2 + j]*id; + + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); + + dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4); + + const float v0 = kvalues_iq4nl_f[xi0]; + const float v1 = kvalues_iq4nl_f[xi1]; + const float w0 = src[0 + j]*src[0 + j]; + const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; + sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + + } + + dst_data[i00/QK4_NL].d = sumq2 > 0 ? 
sumqx/sumq2 : d; + + } +} + kernel void kernel_concat( device const char * src0, device const char * src1, @@ -2432,147 +2696,6 @@ kernel void kernel_concat( } } -//============================================ k-quants ====================================================== - -#ifndef QK_K -#define QK_K 256 -#else -static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); -#endif - -#if QK_K == 256 -#define K_SCALE_SIZE 12 -#else -#define K_SCALE_SIZE 4 -#endif - -typedef struct { - uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits - uint8_t qs[QK_K/4]; // quants - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins -} block_q2_K; -// 84 bytes / block - -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits -#if QK_K == 64 - uint8_t scales[2]; -#else - uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits -#endif - half d; // super-block scale -} block_q3_K; - -#if QK_K == 64 -typedef struct { - half d[2]; // super-block scales/mins - uint8_t scales[2]; - uint8_t qs[QK_K/2]; // 4-bit quants -} block_q4_K; -#else -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -#endif - -#if QK_K == 64 -typedef struct { - half d; // super-block scales/mins - int8_t scales[QK_K/16]; // 8-bit block scales - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -#else -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -// 176 bytes / block -#endif - -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - half d; // super-block scale -} block_q6_K; -// 210 bytes / block - -typedef struct { - half d; - uint16_t qs[QK_K/8]; -} block_iq2_xxs; -// 66 bytes / block for QK_K = 256, so 2.0625 bpw - -typedef struct { - half d; - uint16_t qs[QK_K/8]; - uint8_t scales[QK_K/32]; -} block_iq2_xs; -// 74 bytes / block for QK_K = 256, so 2.3125 bpw - -// 2.5625 bpw quants -typedef struct { - half d; - uint8_t qs[QK_K/4]; - uint8_t qh[QK_K/32]; - uint8_t scales[QK_K/32]; -} block_iq2_s; - -typedef struct { - half d; - uint8_t qs[3*QK_K/8]; -} block_iq3_xxs; -// 98 bytes / block for QK_K = 256, so 3.0625 bpw - -// 3.4375 bpw -#if QK_K == 64 -#define IQ3S_N_SCALE 2 -#else -#define IQ3S_N_SCALE QK_K/64 -#endif -typedef struct { - half d; - uint8_t qs[QK_K/4]; - uint8_t qh[QK_K/32]; - uint8_t signs[QK_K/8]; - uint8_t scales[IQ3S_N_SCALE]; -} block_iq3_s; - -typedef struct { - half d; - uint8_t qs[QK_K/8]; - uint8_t scales[QK_K/16]; -} block_iq1_s; - -// Non-linear quants -#define QK4_NL 32 -typedef struct { - half d; - uint8_t qs[QK4_NL/2]; -} block_iq4_nl; - -#if QK_K == 64 -#define block_iq4_xs block_iq4_nl -#else -typedef struct { - half d; - uint16_t scales_h; - uint8_t scales_l[QK_K/64]; - uint8_t qs[QK_K/2]; -} block_iq4_xs; -#endif - -//====================================== dot products ========================= - void kernel_mul_mv_q2_K_f32_impl( device const void * src0, device const 
float * src1, @@ -3595,710 +3718,6 @@ kernel void kernel_mul_mv_q6_K_f32( // ======================= "True" 2-bit -constexpr constant static uint64_t iq2xxs_grid[256] = { - 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, - 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, - 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, - 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819, - 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b, - 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808, - 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08, - 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b, - 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819, - 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08, - 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, - 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08, - 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808, - 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808, - 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919, - 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819, - 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08, - 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908, - 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819, - 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808, - 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808, - 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908, - 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808, - 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08, - 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819, - 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819, - 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819, - 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908, - 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19, - 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819, - 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b, - 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808, - 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908, - 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08, - 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08, - 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908, - 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819, - 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808, - 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808, - 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19, - 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819, - 0x19082b0808081908, 
0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, - 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b, - 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08, - 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808, - 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908, - 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b, - 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819, - 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08, - 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08, - 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808, - 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b, - 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b, - 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908, - 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819, - 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808, - 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908, - 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b, - 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808, - 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b, - 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b, - 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808, - 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19, - 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, -}; - -constexpr constant static uint64_t iq2xs_grid[512] = { - 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, - 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, - 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, - 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, - 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, - 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808, - 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819, - 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819, - 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, - 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b, - 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b, - 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908, - 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908, - 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919, - 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808, - 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919, - 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908, - 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, - 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, - 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 
0x0808190819082b08, - 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808, - 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808, - 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819, - 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908, - 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819, - 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808, - 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b, - 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819, - 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819, - 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808, - 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908, - 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19, - 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b, - 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b, - 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919, - 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808, - 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819, - 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819, - 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b, - 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908, - 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808, - 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819, - 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808, - 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, - 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808, - 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808, - 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908, - 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908, - 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808, - 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b, - 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819, - 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, - 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908, - 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808, - 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908, - 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919, - 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08, - 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19, - 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b, - 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b, - 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808, - 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08, - 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b, - 
0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908, - 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b, - 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908, - 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, - 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808, - 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808, - 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08, - 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819, - 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919, - 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808, - 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808, - 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819, - 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819, - 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908, - 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908, - 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b, - 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908, - 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908, - 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908, - 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808, - 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, - 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819, - 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819, - 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808, - 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b, - 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819, - 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819, - 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08, - 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808, - 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19, - 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919, - 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, - 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19, - 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b, - 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808, - 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b, - 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b, - 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, - 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b, - 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808, - 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819, - 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808, - 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808, - 0x2b08082b2b080808, 
0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08, - 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b, - 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19, - 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08, - 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919, - 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08, - 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08, - 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908, - 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908, - 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b, - 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908, - 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808, - 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b, - 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808, - 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808, - 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19, - 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08, - 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808, - 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b, - 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808, - 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b, - 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b, -}; - -constexpr constant static uint64_t iq2s_grid[1024] = { - 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, - 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, - 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, - 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, - 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, - 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b, - 0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919, - 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808, - 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908, - 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b, - 0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908, - 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08, - 0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19, - 0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819, - 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919, - 0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b, - 0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, - 0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908, - 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919, - 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908, - 0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 
0x080819080808192b, - 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919, - 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b, - 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, - 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908, - 0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b, - 0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b, - 0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08, - 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, - 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819, - 0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808, - 0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908, - 0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b, - 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908, - 0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08, - 0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819, - 0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808, - 0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08, - 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819, - 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b, - 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908, - 0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919, - 0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b, - 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919, - 0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808, - 0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819, - 0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919, - 0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919, - 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808, - 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819, - 0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b, - 0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908, - 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, - 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, - 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919, - 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b, - 0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919, - 0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b, - 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819, - 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919, - 0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908, - 0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b, - 0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908, - 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b, - 
0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908, - 0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08, - 0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908, - 0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819, - 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819, - 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808, - 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08, - 0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19, - 0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819, - 0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808, - 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819, - 0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919, - 0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808, - 0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19, - 0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08, - 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b, - 0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908, - 0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808, - 0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819, - 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908, - 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819, - 0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808, - 0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808, - 0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819, - 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908, - 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08, - 0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819, - 0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b, - 0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b, - 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08, - 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19, - 0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819, - 0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919, - 0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908, - 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808, - 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808, - 0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908, - 0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808, - 0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08, - 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08, - 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908, - 0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919, - 0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808, - 0x082b19081908082b, 
0x082b190819081919, 0x082b190819082b08, 0x082b190819190819, - 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908, - 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08, - 0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819, - 0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808, - 0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808, - 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819, - 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808, - 0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908, - 0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b, - 0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, - 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, - 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b, - 0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808, - 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b, - 0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19, - 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819, - 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08, - 0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b, - 0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908, - 0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b, - 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b, - 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919, - 0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808, - 0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819, - 0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908, - 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08, - 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08, - 0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819, - 0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919, - 0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908, - 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b, - 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908, - 0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b, - 0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908, - 0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08, - 0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819, - 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808, - 0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819, - 0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919, - 0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808, - 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808, - 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08, - 0x1908192b08190819, 0x1908192b08191908, 
0x1908192b082b0808, 0x1908192b19080819, - 0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919, - 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808, - 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819, - 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919, - 0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808, - 0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b, - 0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908, - 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808, - 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908, - 0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b, - 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908, - 0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b, - 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908, - 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b, - 0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908, - 0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08, - 0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908, - 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b, - 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908, - 0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08, - 0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819, - 0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919, - 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808, - 0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19, - 0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b, - 0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919, - 0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808, - 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819, - 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908, - 0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919, - 0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808, - 0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808, - 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b, - 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919, - 0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808, - 0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b, - 0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808, - 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919, - 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b, - 0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08, - 0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919, - 0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808, - 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 
0x192b08190808082b, - 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908, - 0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808, - 0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808, - 0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808, - 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908, - 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808, - 0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808, - 0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b, - 0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908, - 0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808, - 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808, - 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819, - 0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919, - 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b, - 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808, - 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819, - 0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b, - 0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908, - 0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08, - 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908, - 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919, - 0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819, - 0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908, - 0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b, - 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808, - 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819, - 0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908, - 0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919, - 0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808, - 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808, - 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808, - 0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919, - 0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908, - 0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908, - 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08, - 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819, - 0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b, - 0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808, - 0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819, - 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908, - 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819, - 0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808, - 0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808, - 
0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b, - 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908, - 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808, - 0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908, - 0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819, - 0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819, - 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808, - 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b, - 0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b, - 0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819, - 0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b, - 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b, - 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b, - 0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819, - 0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19, - 0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819, - 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908, - 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808, - 0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b, -}; - -constexpr constant static uint32_t iq3xxs_grid[256] = { - 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, - 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, - 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, - 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, - 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, - 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, - 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, - 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, - 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, - 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, - 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, - 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, - 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, - 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, - 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, - 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, - 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, - 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, - 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, - 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 
0x1c142c14, 0x1c143e14, - 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, - 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, - 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, - 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, - 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, - 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, - 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, - 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, - 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, - 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, - 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, - 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, -}; - -constexpr constant static uint32_t iq3xs_grid[512] = { - 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14, - 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414, - 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24, - 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c, - 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c, - 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34, - 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c, - 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414, - 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c, - 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404, - 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434, - 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c, - 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404, - 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414, - 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414, - 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404, - 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c, - 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c, - 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404, - 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e, - 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14, - 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c, - 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424, - 0x0c34143e, 0x0c342424, 0x0c342434, 
0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c, - 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c, - 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e, - 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e, - 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e, - 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424, - 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e, - 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424, - 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404, - 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c, - 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e, - 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c, - 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c, - 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c, - 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404, - 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04, - 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c, - 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414, - 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c, - 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c, - 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424, - 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c, - 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c, - 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414, - 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c, - 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e, - 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04, - 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424, - 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14, - 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34, - 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c, - 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434, - 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c, - 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424, - 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24, - 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24, - 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 
0x3e040404, 0x3e040424, 0x3e04043e, - 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c, - 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c, - 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c, - 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, -}; - -#define NGRID_IQ1S 512 -constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = { - 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, - 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01, - 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100, - 0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00, - 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101, - 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100, - 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00, - 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff, - 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000, - 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000, - 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001, - 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff, - 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01, - 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001, - 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00, - 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001, - 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100, - 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000, - 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000, - 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000, - 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff, - 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff, - 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01, - 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100, - 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff, - 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000, - 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101, - 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff, - 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff, - 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001, - 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01, - 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101, - 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100, - 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00, - 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001, - 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff, - 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000, 
- 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000, - 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100, - 0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100, - 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01, - 0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff, - 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101, - 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000, - 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff, - 0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000, - 0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff, - 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00, - 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101, - 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000, - 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000, - 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000, - 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100, - 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000, - 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001, - 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff, - 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000, - 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000, - 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000, - 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000, - 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff, - 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000, - 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001, - 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01, - 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100, - 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000, - 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00, - 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100, - 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000, - 0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001, - 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00, - 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff, - 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100, - 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff, - 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000, - 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff, - 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff, - 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00, - 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001, - 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001, - 0x0001ff0100ffffff, 
0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01, - 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000, - 0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101, - 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00, - 0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100, - 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101, - 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101, - 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000, - 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff, - 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff, - 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101, - 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff, - 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101, - 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001, - 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff, - 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff, - 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01, - 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff, - 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100, - 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001, - 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00, - 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff, - 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff, - 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000, - 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000, - 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101, - 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001, - 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000, - 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101, - 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000, - 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001, - 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000, - 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100, - 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000, - 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000, - 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100, - 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff, - 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff, - 0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00, - 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101, - 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000, - 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00, - 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000, - 0x01010000000100ff, 0x010100000101ff01, 
0x01010001ffff0000, 0x01010001ff00ffff, - 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101, - 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff, - 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00, - 0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff, -}; - -constexpr constant static uint8_t ksigns_iq2xs[128] = { - 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, - 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, - 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175, - 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63, - 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207, - 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95, - 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111, - 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, -}; - -constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128}; - void kernel_mul_mv_iq2_xxs_f32_impl( device const void * src0, device const float * src1, @@ -4742,7 +4161,7 @@ void kernel_mul_mv_iq3_s_f32_impl( { int nval = 8; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i]; + for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -4769,12 +4188,14 @@ void kernel_mul_mv_iq3_s_f32_impl( for (int row = 0; row < N_DST; row++) { const float db = dh[0]; - const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf)); + const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); float2 sum = {0}; for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); + const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; + const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? 
values + 256 : values; + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); for (int j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); @@ -4795,7 +4216,7 @@ void kernel_mul_mv_iq3_s_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; } } } @@ -4994,48 +4415,53 @@ void kernel_mul_mv_iq1_s_f32_impl( device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[16]; + float yl[32]; float sumf[N_DST]={0.f}, all_sum; const int nb32 = nb * (QK_K / 32); - const int ix = tiisg/2; - const int il = tiisg%2; + const int ix = tiisg; - device const float * y4 = y + 32 * ix + 16 * il; + device const float * y4 = y + 32 * ix; - for (int ib32 = ix; ib32 < nb32; ib32 += 16) { + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - for (int i = 0; i < 16; ++i) { + float sumy = 0; + for (int i = 0; i < 32; ++i) { yl[i] = y4[i]; + sumy += yl[i]; } const int ibl = ib32 / (QK_K / 32); const int ib = ib32 % (QK_K / 32); device const block_iq1_s * xr = x + ibl; - device const uint8_t * qs = xr->qs + 4 * ib + 2 * il; - device const uint8_t * sc = xr->scales + 2 * ib + il; - device const half * dh = &xr->d; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint16_t * qh = xr->qh + ib; + device const half * dh = &xr->d; for (int row = 0; row < N_DST; row++) { - constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5))); - constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1))); + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); - float2 sum = {0}; - for (int j = 0; j < 8; ++j) { - sum[0] += yl[j+ 0] * grid1[j]; - sum[1] += yl[j+ 8] * grid2[j]; + float sum = 0; + for (int j = 0; j < 4; ++j) { + sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) + + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); } - sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1)); + sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? 
-1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); dh += nb*sizeof(block_iq1_s)/2; qs += nb*sizeof(block_iq1_s); - sc += nb*sizeof(block_iq1_s); + qh += nb*sizeof(block_iq1_s)/2; } - y4 += 16 * 32; + y4 += 32 * 32; } for (int row = 0; row < N_DST; ++row) { @@ -5046,9 +4472,113 @@ void kernel_mul_mv_iq1_s_f32_impl( } } -constexpr constant static float kvalues_iq4nl_f[16] = { - -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f -}; +void kernel_mul_mv_iq1_m_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + +#if QK_K != 64 + iq1m_scale_t scale; +#endif + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + float4 sumy = {0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+24]; sumy[3] += yl[i+24]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq1_m * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + 2 * ib; + device const uint16_t * sc = (device const uint16_t *)xr->scales; + + for (int row = 0; row < N_DST; row++) { + +#if QK_K != 64 + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); +#endif + + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); + + float2 sum = {0.f}; + for (int j = 0; j < 4; ++j) { + sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); + sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); + } + const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? 
-1 - IQ1M_DELTA : -1 + IQ1M_DELTA); +#if QK_K == 64 + const float d = (float) *((device const half *)(sc - 1)); + sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) + + (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1)); +#else + sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); +#endif + + sc += nb*sizeof(block_iq1_m)/2; + qs += nb*sizeof(block_iq1_m); + qh += nb*sizeof(block_iq1_m); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} void kernel_mul_mv_iq4_nl_f32_impl( device const void * src0, @@ -5267,6 +4797,34 @@ kernel void kernel_mul_mv_iq1_s_f32( kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); } +[[host_name("kernel_mul_mv_iq1_m_f32")]] +kernel void kernel_mul_mv_iq1_m_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + [[host_name("kernel_mul_mv_iq4_nl_f32")]] kernel void kernel_mul_mv_iq4_nl_f32( device const void * src0, @@ -5685,15 +5243,15 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & device const uint8_t * qs = xb->qs + 8*ib32; device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; const uint8_t qh = xb->qh[ib32] >> 4*il; - const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f; - constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256))); - constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256))); + const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); + constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); + constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); for (int i = 0; i < 4; ++i) { reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); } - grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256))); - grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256))); + grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); + grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); for (int i = 0; i < 4; ++i) { reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); @@ -5722,16 +5280,53 @@ void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & template void dequantize_iq1_s(device const block_iq1_s * xb, short 
il, thread type4x4 & reg) { // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; const float d = xb->d; - device const uint8_t * qs = xb->qs + 2*il; - device const uint8_t * sc = xb->scales + il; - const float dl1 = d * (2*(sc[0] & 7) + 1); - const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1); - constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5))); - constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1))); - for (int i = 0; i < 8; ++i) { - reg[i/4+0][i%4] = dl1 * grid1[i]; - reg[i/4+2][i%4] = dl2 * grid2[i]; + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint16_t * qh = xb->qh; + const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1); + const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA); + const uint16_t h = qh[ib32] >> 6*il; + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml; + reg[1][i] = dl * (grid1[i] >> 4) + ml; + reg[2][i] = dl * (grid2[i] & 0xf) + ml; + reg[3][i] = dl * (grid2[i] >> 4) + ml; + } +} + +template +void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + device const uint16_t * sc = (device const uint16_t *)xb->scales; +#if QK_K == 64 + const float d = xb->d; +#else + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const float d = scale.f16; +#endif + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * qh = xb->qh + 2*ib32 + il; +#if QK_K == 64 + const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1); +#else + const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1); +#endif + const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float ml2 = dl * (qh[0] & 0x80 ? 
-1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml1; + reg[1][i] = dl * (grid1[i] >> 4) + ml1; + reg[2][i] = dl * (grid2[i] & 0xf) + ml2; + reg[3][i] = dl * (grid2[i] >> 4) + ml2; } } @@ -6042,7 +5637,7 @@ template kernel void kernel_mul_mm_id( - device const uchar * ids, + device const uchar * src0s, device const uchar * src1, device float * dst, + device const uchar * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, @@ -6225,29 +5821,21 @@ kernel void kernel_mul_mm_id( constant uint & r2, constant uint & r3, constant int & idx, - device const uchar * src00, - device const uchar * src01, - device const uchar * src02, - device const uchar * src03, - device const uchar * src04, - device const uchar * src05, - device const uchar * src06, - device const uchar * src07, threadgroup uchar * shared_memory [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; // expert id const int32_t id = tgpig.z/(ne12*ne13); + device const uchar * src0 = src0s + id*nb02; tgpig.z = tgpig.z%(ne12*ne13); // row indices of src1 for expert id - int64_t _ne1 = 0; - short src1ids[512]; + threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192); + int64_t _ne1 = 0; for (int64_t i1 = 0; i1 < ne1; i1++) { if (((device int32_t *) (ids + i1*nbi1))[idx] == id) { src1ids[_ne1++] = i1; @@ -6255,7 +5843,7 @@ kernel void kernel_mul_mm_id( } kernel_mul_mm_id_impl( - src0s[id], + src0, src1, src1ids, dst, @@ -6319,6 +5907,7 @@ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_r template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows; #if QK_K == 64 template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; @@ -6349,24 +5938,25 @@ typedef void (mat_mm_t)( threadgroup uchar *, uint3, uint, uint); -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t 
kernel_mul_mm; +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; #if QK_K == 64 template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; @@ -6379,9 +5969,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m // typedef void (mat_mm_id_t)( - device const uchar * ids, + device const uchar * src0s, device const uchar * src1, device float * dst, + device const uchar * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, @@ -6398,35 +5989,28 @@ typedef void (mat_mm_id_t)( constant uint & r2, constant uint & r3, constant int & idx, - device const uchar * src00, - device const uchar * src01, - device const uchar * src02, - device const uchar * src03, - device const uchar * src04, - device const uchar * src05, - device const uchar * src06, - device const uchar * src07, threadgroup uchar *, uint3, uint, uint); -template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; 
+template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; #if QK_K == 64 template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; @@ -6440,9 +6024,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel [[host_name("kernel_mul_mv_id_f32_f32")]] kernel void kernel_mul_mv_id_f32_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6463,28 +6048,19 @@ kernel void kernel_mul_mv_id_f32_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_f32_f32_impl( - src0[id], + src0, src1 + bid*nb11, dst + bid*ne0, ne00, @@ -6509,9 +6085,10 @@ kernel void kernel_mul_mv_id_f32_f32( [[host_name("kernel_mul_mv_id_f16_f32")]] kernel void kernel_mul_mv_id_f16_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6532,28 
+6109,19 @@ kernel void kernel_mul_mv_id_f16_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_f16_f32_impl( - src0[id], + src0, src1 + bid*nb11, dst + bid*ne0, ne00, @@ -6578,9 +6146,10 @@ kernel void kernel_mul_mv_id_f16_f32( [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel void kernel_mul_mv_id_q8_0_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6601,28 +6170,19 @@ kernel void kernel_mul_mv_id_q8_0_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q8_0_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6641,9 +6201,10 @@ kernel void kernel_mul_mv_id_q8_0_f32( [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel void kernel_mul_mv_id_q4_0_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6664,28 +6225,19 @@ kernel void kernel_mul_mv_id_q4_0_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; mul_vec_q_n_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6704,9 +6256,10 @@ kernel void kernel_mul_mv_id_q4_0_f32( [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel void kernel_mul_mv_id_q4_1_f32( - device 
const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6727,28 +6280,19 @@ kernel void kernel_mul_mv_id_q4_1_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; mul_vec_q_n_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6767,9 +6311,10 @@ kernel void kernel_mul_mv_id_q4_1_f32( [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel void kernel_mul_mv_id_q5_0_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6790,28 +6335,19 @@ kernel void kernel_mul_mv_id_q5_0_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; mul_vec_q_n_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6830,9 +6366,10 @@ kernel void kernel_mul_mv_id_q5_0_f32( [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel void kernel_mul_mv_id_q5_1_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6853,28 +6390,19 @@ kernel void kernel_mul_mv_id_q5_1_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; mul_vec_q_n_f32_impl( - src0[id], + 
src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6893,9 +6421,10 @@ kernel void kernel_mul_mv_id_q5_1_f32( [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel void kernel_mul_mv_id_q2_K_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6916,28 +6445,19 @@ kernel void kernel_mul_mv_id_q2_K_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q2_K_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -6956,9 +6476,10 @@ kernel void kernel_mul_mv_id_q2_K_f32( [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel void kernel_mul_mv_id_q3_K_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -6979,28 +6500,19 @@ kernel void kernel_mul_mv_id_q3_K_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q3_K_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7019,9 +6531,10 @@ kernel void kernel_mul_mv_id_q3_K_f32( [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel void kernel_mul_mv_id_q4_K_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7042,28 +6555,19 @@ kernel void kernel_mul_mv_id_q4_K_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, 
src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q4_K_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7082,9 +6586,10 @@ kernel void kernel_mul_mv_id_q4_K_f32( [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel void kernel_mul_mv_id_q5_K_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7105,28 +6610,19 @@ kernel void kernel_mul_mv_id_q5_K_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q5_K_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7145,9 +6641,10 @@ kernel void kernel_mul_mv_id_q5_K_f32( [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel void kernel_mul_mv_id_q6_K_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7168,28 +6665,19 @@ kernel void kernel_mul_mv_id_q6_K_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_q6_K_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7208,9 +6696,10 @@ kernel void kernel_mul_mv_id_q6_K_f32( [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel void kernel_mul_mv_id_iq2_xxs_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7231,29 +6720,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * 
src07, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq2_xxs_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7273,9 +6753,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32( [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel void kernel_mul_mv_id_iq2_xs_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7296,29 +6777,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq2_xs_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7338,9 +6810,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32( [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel void kernel_mul_mv_id_iq3_xxs_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7361,29 +6834,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq3_xxs_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7403,9 +6867,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32( [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel void kernel_mul_mv_id_iq3_s_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device 
const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7426,29 +6891,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq3_s_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7468,9 +6924,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32( [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel void kernel_mul_mv_id_iq2_s_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7491,29 +6948,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq2_s_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7533,9 +6981,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32( [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel void kernel_mul_mv_id_iq1_s_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7556,28 +7005,74 @@ kernel void kernel_mul_mv_id_iq1_s_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; 
kernel_mul_mv_iq1_s_f32_impl( - src0[id], + src0, + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_iq1_m_f32")]] +kernel void kernel_mul_mv_id_iq1_m_f32( + device const char * src0s, + device const char * src1, + device float * dst, + device const char * ids, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; + + kernel_mul_mv_iq1_m_f32_impl( + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7596,9 +7091,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32( [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel void kernel_mul_mv_id_iq4_nl_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7619,29 +7115,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup float * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; kernel_mul_mv_iq4_nl_f32_impl( - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, @@ -7661,9 +7148,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32( [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel void kernel_mul_mv_id_iq4_xs_f32( - device const char * ids, + device const char * src0s, device const char * src1, device float * dst, + device const char * ids, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -7684,33 +7172,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32( constant uint & r2, constant uint & r3, constant int & idx, - device const char * src00, - device const char * src01, - device const char * src02, - device const char * src03, - device const char * src04, - device const char * src05, - device const char * src06, - device const char * src07, threadgroup float * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], 
uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; - const int64_t bid = tgpig.z/(ne12*ne13); tgpig.z = tgpig.z%(ne12*ne13); const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + device const char * src0 = src0s + id*nb02; #if QK_K == 64 kernel_mul_mv_iq4_nl_f32_impl( #else kernel_mul_mv_iq4_xs_f32_impl( #endif - src0[id], + src0, (device const float *) (src1 + bid*nb11), dst + bid*ne0, ne00, diff --git a/LLama/runtimes/deps/osx-arm64/libllama.dylib b/LLama/runtimes/deps/osx-arm64/libllama.dylib index 87295f843..b474697df 100644 Binary files a/LLama/runtimes/deps/osx-arm64/libllama.dylib and b/LLama/runtimes/deps/osx-arm64/libllama.dylib differ diff --git a/LLama/runtimes/deps/osx-arm64/libllava_shared.dylib b/LLama/runtimes/deps/osx-arm64/libllava_shared.dylib index 84ff71671..c47c71ed0 100644 Binary files a/LLama/runtimes/deps/osx-arm64/libllava_shared.dylib and b/LLama/runtimes/deps/osx-arm64/libllava_shared.dylib differ diff --git a/LLama/runtimes/deps/osx-x64/libllama.dylib b/LLama/runtimes/deps/osx-x64/libllama.dylib index df9c7280c..bbd84c9c2 100644 Binary files a/LLama/runtimes/deps/osx-x64/libllama.dylib and b/LLama/runtimes/deps/osx-x64/libllama.dylib differ diff --git a/LLama/runtimes/deps/osx-x64/libllava_shared.dylib b/LLama/runtimes/deps/osx-x64/libllava_shared.dylib index ac9bd1eca..c6b265ff4 100644 Binary files a/LLama/runtimes/deps/osx-x64/libllava_shared.dylib and b/LLama/runtimes/deps/osx-x64/libllava_shared.dylib differ diff --git a/llama.cpp b/llama.cpp index 3ab8b3a92..f7001ccc5 160000 --- a/llama.cpp +++ b/llama.cpp @@ -1 +1 @@ -Subproject commit 3ab8b3a92ede46df88bc5a2dfca3777de4a2b2b6 +Subproject commit f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7
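Note on the new IQ1_M support above (kernel_mul_mv_iq1_m_f32_impl, its kernel_mul_mv_iq1_m_f32 wrapper, dequantize_iq1_m, and the extra get_rows / mul_mm / mul_mm_id instantiations): for QK_K == 256 the kernels rebuild a per-superblock fp16 scale from the top nibble of each of the four 16-bit words in block_iq1_m.scales, then derive a small per-group scale of the form 2*s + 1 and a shift of +/- IQ1M_DELTA chosen by bits of qh. Below is a minimal host-side C++ sketch of just the scale reassembly, mirroring the bit packing in dequantize_iq1_m; the function name and the half_to_float callback are placeholders, not anything that exists in the diff.

    #include <cstdint>

    // Rebuild the fp16 superblock scale of an IQ1_M block (QK_K == 256 layout).
    // sc points at the four 16-bit words of block_iq1_m.scales.
    float iq1m_superblock_scale(const uint16_t sc[4],
                                float (*half_to_float)(uint16_t)) {
        // Same packing as the Metal kernel: the top 4 bits of each word, in order.
        const uint16_t u16 = (uint16_t)((sc[0] >> 12)
                                      | ((sc[1] >>  8) & 0x00f0)
                                      | ((sc[2] >>  4) & 0x0f00)
                                      |  (sc[3]        & 0xf000));
        return half_to_float(u16);
    }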
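dequantize_iq3_s is touched as well: the lookup table is renamed from iq3xs_grid to iq3s_grid and the per-group scale encoding changes from d * (0.5 + s) * 0.5 to d * (1 + 2*s), where s is the packed 4-bit scale. The two formulas differ by exactly a factor of 4, so this presumably only stays correct because the bumped llama.cpp commit changes the CPU-side IQ3_S quantizer to match; the snippet below just restates the before/after arithmetic for reference and is not part of the diff.

    // IQ3_S per-group scale, s = packed 4-bit scale (0..15), d = block scale.
    // before: dl = d * (0.5f + s) * 0.5f;   // == d * (0.25f + 0.5f * s)
    // after:  dl = d * (1 + 2 * s);         // exactly 4x the old value for the same s
    float iq3s_scale(float d, int s) { return d * (1 + 2 * s); }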
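dequantize_iq1_s also changes shape: grid entries now come from iq1s_grid_gpu, where each byte packs two 4-bit values, and every group of 32 weights gets a constant shift of dl * (-1 +/- IQ1S_DELTA) picked by bit 15 of that group's qh word. The sketch below shows how the 8 weights addressed by one grid pointer unpack, following the reg[0][i] / reg[1][i] loop in the kernel; the function name, out, dl and ml are illustrative stand-ins for the kernel's reg matrix, group scale and group shift.

    #include <cstdint>

    // Expand 8 weights from 4 packed grid bytes: the 4 low nibbles first,
    // then the 4 high nibbles, as in dequantize_iq1_s above.
    void expand_iq1_nibbles(const uint8_t grid[4], float dl, float ml, float out[8]) {
        for (int i = 0; i < 4; ++i) {
            out[i]     = dl * (grid[i] & 0xf) + ml;  // low nibble of byte i
            out[4 + i] = dl * (grid[i] >> 4)  + ml;  // high nibble of byte i
        }
    }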
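The broadest change in this sync is how the *_id (mixture-of-experts) kernels locate the selected expert. Every kernel_mul_mv_id_* and kernel_mul_mm_id variant used to take eight separate expert pointers (src00 .. src07) and index src0[id], which hard-capped the expert count at eight; they now take a single src0s buffer together with the ids tensor and compute the expert's base address as a byte offset. A minimal C++ sketch of the addressing as it appears in the kernel_mul_mv_id_* variants, reusing the kernel argument names (src0s, ids, nbi1, nb02, idx); the helper itself is illustrative and does not exist in the diff.

    #include <cstdint>

    // Pick the weight matrix of the expert chosen for one src1 row.
    const char * select_expert(const char * src0s,  // all expert matrices, back to back
                               const char * ids,    // expert-id rows, one per src1 row
                               uint64_t     nbi1,   // byte stride between id rows
                               uint64_t     nb02,   // byte size of one expert matrix
                               int64_t      bid,    // src1 row / batch index
                               int          idx) {  // which top-k slot to read
        const int32_t id = ((const int32_t *)(ids + bid * nbi1))[idx];
        return src0s + id * nb02;
    }

kernel_mul_mm_id additionally moves its src1ids scratch list out of a fixed 512-entry stack array into threadgroup memory starting at shared_memory + 8192, so the dispatching host code presumably has to reserve that much extra threadgroup space.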