diff --git a/LLama.Examples/Examples/BatchedExecutorGuidance.cs b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
index 29f6aa33c..6f3eceaba 100644
--- a/LLama.Examples/Examples/BatchedExecutorGuidance.cs
+++ b/LLama.Examples/Examples/BatchedExecutorGuidance.cs
@@ -79,7 +79,7 @@ await AnsiConsole
guidance.Prompt(g);
// Early exit if we reach the natural end of the guided sentence
- if (g == model.EndOfSentenceToken)
+ if (g == model.Tokens.EOS)
break;
// Update progress bar
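For reference, the special tokens previously exposed as separate properties on LLamaWeights (EndOfSentenceToken, NewlineToken, ...) are now grouped under the single Tokens property. A minimal sketch of the new call sites, assuming the example model path used elsewhere in this diff:

using LLama;
using LLama.Common;

var parameters = new ModelParams("Models/llama-2-7b-chat.Q3_K_S.gguf");
using var model = LLamaWeights.LoadFromFile(parameters);

var eos = model.Tokens.EOS;          // was model.EndOfSentenceToken
var newline = model.Tokens.Newline;  // was model.NewlineToken; nullable, hence Tokens.Newline!.Value later in this diff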
diff --git a/LLama.Examples/Examples/GetEmbeddings.cs b/LLama.Examples/Examples/GetEmbeddings.cs
index 9a816b054..1e10ba22b 100644
--- a/LLama.Examples/Examples/GetEmbeddings.cs
+++ b/LLama.Examples/Examples/GetEmbeddings.cs
@@ -9,7 +9,7 @@ public static void Run()
string modelPath = UserSettings.GetModelPath();
Console.ForegroundColor = ConsoleColor.DarkGray;
- var @params = new ModelParams(modelPath) { EmbeddingMode = true };
+ var @params = new ModelParams(modelPath) { Embeddings = true };
using var weights = LLamaWeights.LoadFromFile(@params);
var embedder = new LLamaEmbedder(weights, @params);
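The EmbeddingMode to Embeddings rename is mechanical; a minimal usage sketch of the embedder path under the new name (model path is the embedding model added to the test constants below, and the call is assumed to sit inside an async method):

using LLama;
using LLama.Common;

var @params = new ModelParams("Models/all-MiniLM-L12-v2.Q8_0.gguf")
{
    Embeddings = true // renamed from EmbeddingMode in this PR
};
using var weights = LLamaWeights.LoadFromFile(@params);
var embedder = new LLamaEmbedder(weights, @params);

// Returns a normalized float[] (see the LLamaEmbedder changes later in this diff)
float[] vector = await embedder.GetEmbeddings("The cat is cute");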
diff --git a/LLama.Examples/Examples/SemanticKernelMemory.cs b/LLama.Examples/Examples/SemanticKernelMemory.cs
index 1c9471d86..46c9a17d9 100644
--- a/LLama.Examples/Examples/SemanticKernelMemory.cs
+++ b/LLama.Examples/Examples/SemanticKernelMemory.cs
@@ -20,7 +20,7 @@ public static async Task Run()
var parameters = new ModelParams(modelPath)
{
Seed = seed,
- EmbeddingMode = true
+ Embeddings = true
};
using var model = LLamaWeights.LoadFromFile(parameters);
diff --git a/LLama.KernelMemory/BuilderExtensions.cs b/LLama.KernelMemory/BuilderExtensions.cs
index 474cf8be4..07770244b 100644
--- a/LLama.KernelMemory/BuilderExtensions.cs
+++ b/LLama.KernelMemory/BuilderExtensions.cs
@@ -84,7 +84,7 @@ public static IKernelMemoryBuilder WithLLamaSharpDefaults(this IKernelMemoryBuil
ContextSize = config?.ContextSize ?? 2048,
Seed = config?.Seed ?? 0,
GpuLayerCount = config?.GpuLayerCount ?? 20,
- EmbeddingMode = true,
+ Embeddings = true,
MainGpu = config?.MainGpu ?? 0,
SplitMode = config?.SplitMode ?? GPUSplitMode.None,
};
diff --git a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
index 4a089fe49..d8c366bcf 100644
--- a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
+++ b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
@@ -29,7 +29,7 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config)
this._config = config;
var @params = new ModelParams(_config.ModelPath)
{
- EmbeddingMode = true,
+ Embeddings = true,
MainGpu = _config.MainGpu,
SplitMode = _config.SplitMode
};
@@ -49,7 +49,7 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights we
this._config = config;
var @params = new ModelParams(_config.ModelPath)
{
- EmbeddingMode = true,
+ Embeddings = true,
MainGpu = _config.MainGpu,
SplitMode = _config.SplitMode
};
diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs
index b8350336a..7c897b781 100644
--- a/LLama.Unittest/BasicTest.cs
+++ b/LLama.Unittest/BasicTest.cs
@@ -15,7 +15,7 @@ public sealed class BasicTest
public BasicTest(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
- _params = new ModelParams(Constants.ModelPath)
+ _params = new ModelParams(Constants.GenerativeModelPath)
{
ContextSize = 2048
};
diff --git a/LLama.Unittest/BeamTests.cs b/LLama.Unittest/BeamTests.cs
index 83eb87d35..f4aa01abe 100644
--- a/LLama.Unittest/BeamTests.cs
+++ b/LLama.Unittest/BeamTests.cs
@@ -15,7 +15,7 @@ public sealed class BeamTests
public BeamTests(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
- _params = new ModelParams(Constants.ModelPath)
+ _params = new ModelParams(Constants.GenerativeModelPath)
{
ContextSize = 2048
};
diff --git a/LLama.Unittest/Constants.cs b/LLama.Unittest/Constants.cs
index 6e6324491..6e5e92c54 100644
--- a/LLama.Unittest/Constants.cs
+++ b/LLama.Unittest/Constants.cs
@@ -2,9 +2,11 @@
{
internal static class Constants
{
- public static string ModelPath = "Models/llama-2-7b-chat.Q3_K_S.gguf";
- public static string LLavaModelPath = "Models/llava-v1.6-mistral-7b.Q3_K_XS.gguf";
- public static string LLavaMmpPath = "Models/mmproj-model-f16.gguf";
- public static string LLavaImage = "Models/extreme-ironing-taxi-610x427.jpg";
+ public static readonly string GenerativeModelPath = "Models/llama-2-7b-chat.Q3_K_S.gguf";
+ public static readonly string EmbeddingModelPath = "Models/all-MiniLM-L12-v2.Q8_0.gguf";
+
+ public static readonly string LLavaModelPath = "Models/llava-v1.6-mistral-7b.Q3_K_XS.gguf";
+ public static readonly string LLavaMmpPath = "Models/mmproj-model-f16.gguf";
+ public static readonly string LLavaImage = "Models/extreme-ironing-taxi-610x427.jpg";
}
}
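Splitting the constants lets each test pick a model by purpose instead of reusing the chat model everywhere; the pattern used throughout the test changes below is roughly:

// Generation-oriented tests
var genParams = new ModelParams(Constants.GenerativeModelPath) { ContextSize = 2048 };

// Embedding tests use the small MiniLM model with the renamed Embeddings flag
var embParams = new ModelParams(Constants.EmbeddingModelPath) { Embeddings = true };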
diff --git a/LLama.Unittest/GrammarTest.cs b/LLama.Unittest/GrammarTest.cs
index c5fcf0ed3..1ab9dea61 100644
--- a/LLama.Unittest/GrammarTest.cs
+++ b/LLama.Unittest/GrammarTest.cs
@@ -12,7 +12,7 @@ public sealed class GrammarTest
public GrammarTest()
{
- _params = new ModelParams(Constants.ModelPath)
+ _params = new ModelParams(Constants.GenerativeModelPath)
{
ContextSize = 2048,
Seed = 92,
diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj
index d48b308e0..e7ef46231 100644
--- a/LLama.Unittest/LLama.Unittest.csproj
+++ b/LLama.Unittest/LLama.Unittest.csproj
@@ -31,6 +31,9 @@
+
+
+
@@ -43,6 +46,9 @@
+
+ PreserveNewest
+
PreserveNewest
diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs
index ab27d9887..fe247c6ed 100644
--- a/LLama.Unittest/LLamaContextTests.cs
+++ b/LLama.Unittest/LLamaContextTests.cs
@@ -11,7 +11,7 @@ public sealed class LLamaContextTests
public LLamaContextTests()
{
- var @params = new ModelParams(Constants.ModelPath)
+ var @params = new ModelParams(Constants.GenerativeModelPath)
{
ContextSize = 768,
};
diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs
index 9935fc863..379b8dc6e 100644
--- a/LLama.Unittest/LLamaEmbedderTests.cs
+++ b/LLama.Unittest/LLamaEmbedderTests.cs
@@ -1,5 +1,7 @@
using LLama.Common;
+using LLama.Native;
using Xunit.Abstractions;
+using Xunit.Sdk;
namespace LLama.Unittest;
@@ -12,11 +14,11 @@ public sealed class LLamaEmbedderTests
public LLamaEmbedderTests(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
- var @params = new ModelParams(Constants.ModelPath)
+ var @params = new ModelParams(Constants.EmbeddingModelPath)
{
ContextSize = 4096,
Threads = 5,
- EmbeddingMode = true,
+ Embeddings = true,
};
using var weights = LLamaWeights.LoadFromFile(@params);
_embedder = new(weights, @params);
@@ -38,8 +40,13 @@ private static float Dot(float[] a, float[] b)
public async Task EmbedCompare()
{
var cat = await _embedder.GetEmbeddings("The cat is cute");
+ Assert.DoesNotContain(float.NaN, cat);
+
var kitten = await _embedder.GetEmbeddings("The kitten is kawaii");
+ Assert.DoesNotContain(float.NaN, kitten);
+
var spoon = await _embedder.GetEmbeddings("The spoon is not real");
+ Assert.DoesNotContain(float.NaN, spoon);
_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
@@ -47,6 +54,11 @@ public async Task EmbedCompare()
var close = 1 - Dot(cat, kitten);
var far = 1 - Dot(cat, spoon);
+
+ _testOutputHelper.WriteLine("");
+ _testOutputHelper.WriteLine($"Cat.Kitten (Close): {close:F4}");
+ _testOutputHelper.WriteLine($"Cat.Spoon (Far): {far:F4}");
+
Assert.True(close < far);
}
}
\ No newline at end of file
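Since LLamaEmbedder normalizes each vector to unit length, Dot(a, b) is the cosine similarity and 1 - Dot(a, b) is a cosine distance, so the assertion simply says "cat is closer to kitten than to spoon". A sketch of that reasoning as a standalone (hypothetical) helper:

static float CosineDistance(float[] a, float[] b)
{
    float dot = 0;
    for (var i = 0; i < a.Length; i++)
        dot += a[i] * b[i];   // vectors are already unit length, so this is cos(theta)
    return 1 - dot;           // 0 = same direction, 2 = opposite
}

// Ordering checked by EmbedCompare:
//   CosineDistance(cat, kitten) < CosineDistance(cat, spoon)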
diff --git a/LLama.Unittest/LLavaWeightsTests.cs b/LLama.Unittest/LLavaWeightsTests.cs
index 9ca0a7246..e5df30732 100644
--- a/LLama.Unittest/LLavaWeightsTests.cs
+++ b/LLama.Unittest/LLavaWeightsTests.cs
@@ -14,7 +14,7 @@ public sealed class LLavaWeightTests
public LLavaWeightTests()
{
- var @params = new ModelParams(Constants.ModelPath)
+ var @params = new ModelParams(Constants.GenerativeModelPath)
{
// Llava models require a big context
ContextSize = 4096
diff --git a/LLama.Unittest/MemoryDisposalTests.cs b/LLama.Unittest/MemoryDisposalTests.cs
index 4ee976a85..e29ad46d5 100644
--- a/LLama.Unittest/MemoryDisposalTests.cs
+++ b/LLama.Unittest/MemoryDisposalTests.cs
@@ -7,7 +7,7 @@ public class MemoryDisposalTests
[Fact]
public void ModelDisposal()
{
- var @params = new ModelParams(Constants.ModelPath)
+ var @params = new ModelParams(Constants.GenerativeModelPath)
{
ContextSize = 2048
};
@@ -21,7 +21,7 @@ public void ModelDisposal()
[Fact]
public void ContextDisposal()
{
- var @params = new ModelParams(Constants.ModelPath)
+ var @params = new ModelParams(Constants.GenerativeModelPath)
{
ContextSize = 2048
};
diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs
index cfe499734..3ca8a76e2 100644
--- a/LLama.Unittest/StatelessExecutorTest.cs
+++ b/LLama.Unittest/StatelessExecutorTest.cs
@@ -15,7 +15,7 @@ public class StatelessExecutorTest
public StatelessExecutorTest(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
- _params = new ModelParams(Constants.ModelPath)
+ _params = new ModelParams(Constants.GenerativeModelPath)
{
ContextSize = 60,
Seed = 1754,
diff --git a/LLama.Unittest/StreamingTextDecoderTests.cs b/LLama.Unittest/StreamingTextDecoderTests.cs
index 680ca076f..7291b5d07 100644
--- a/LLama.Unittest/StreamingTextDecoderTests.cs
+++ b/LLama.Unittest/StreamingTextDecoderTests.cs
@@ -14,7 +14,7 @@ public class StreamingTextDecoderTests
public StreamingTextDecoderTests(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
- _params = new ModelParams(Constants.ModelPath);
+ _params = new ModelParams(Constants.GenerativeModelPath);
_model = LLamaWeights.LoadFromFile(_params);
}
diff --git a/LLama.Unittest/TokenTests.cs b/LLama.Unittest/TokenTests.cs
index e39df5f47..c11e3ae96 100644
--- a/LLama.Unittest/TokenTests.cs
+++ b/LLama.Unittest/TokenTests.cs
@@ -12,7 +12,7 @@ public sealed class TokenTests
public TokenTests()
{
- _params = new ModelParams(Constants.ModelPath)
+ _params = new ModelParams(Constants.GenerativeModelPath)
{
ContextSize = 2048
};
diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index f1a14a90b..fb30af400 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -29,9 +29,13 @@ public class ModelOptions
///
public int GpuLayerCount { get; set; } = 20;
+ public uint SeqMax { get; }
+
///
public uint Seed { get; set; } = 1686349486;
+ public bool Embeddings { get; }
+
///
public bool UseMemorymap { get; set; } = true;
@@ -57,7 +61,7 @@ public class ModelOptions
public uint BatchSize { get; set; } = 512;
///
- public bool EmbeddingMode { get; set; } = false;
+ public uint UBatchSize { get; set; } = 512;
///
public TensorSplitsCollection TensorSplits { get; set; } = new();
@@ -108,6 +112,6 @@ public class ModelOptions
public float DefragThreshold { get; set; }
///
- public bool DoPooling { get; set; }
+ public LLamaPoolingType PoolingType { get; set; }
}
}
\ No newline at end of file
diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs
index d8592a081..f56417169 100644
--- a/LLama/Abstractions/IContextParams.cs
+++ b/LLama/Abstractions/IContextParams.cs
@@ -14,20 +14,29 @@ public interface IContextParams
uint? ContextSize { get; }
///
- /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
+ /// maximum batch size that can be submitted at once (must be >=32 to use BLAS) (n_batch)
///
uint BatchSize { get; }
+ ///
+ /// Physical batch size
+ ///
+ uint UBatchSize { get; }
+
+ ///
+ /// max number of sequences (i.e. distinct states for recurrent models)
+ ///
+ uint SeqMax { get; }
+
///
/// Seed for the random number generator (seed)
///
uint Seed { get; }
///
- /// Whether to use embedding mode. (embedding) Note that if this is set to true,
- /// The LLamaModel won't produce text response anymore.
+ /// If true, extract embeddings (together with logits).
///
- bool EmbeddingMode { get; }
+ bool Embeddings { get; }
///
/// RoPE base frequency (null to fetch from the model)
@@ -105,7 +114,7 @@ public interface IContextParams
float DefragThreshold { get; }
///
- /// Whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
+ /// How to pool (sum) embedding results by sequence id (ignored if no pooling layer)
///
- bool DoPooling { get; }
+ LLamaPoolingType PoolingType { get; }
}
\ No newline at end of file
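The new UBatchSize, SeqMax, Embeddings and PoolingType members flow through ModelParams (next file) into the native LLamaContextParams. A hedged configuration sketch using the defaults this diff introduces:

using LLama.Common;
using LLama.Native;

var parameters = new ModelParams("Models/llama-2-7b-chat.Q3_K_S.gguf")
{
    ContextSize = 2048,
    BatchSize = 512,                            // logical max batch per llama_decode (n_batch)
    UBatchSize = 512,                           // physical batch size (n_ubatch)
    SeqMax = 1,                                 // distinct sequence states (n_seq_max)
    Embeddings = false,                         // extract embeddings together with logits when true
    PoolingType = LLamaPoolingType.Unspecified, // let the model decide how to pool
};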
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index 413e35f89..d1b5989b3 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -24,6 +24,9 @@ public record ModelParams
///
public int GpuLayerCount { get; set; } = 20;
+ ///
+ public uint SeqMax { get; set; } = 1;
+
///
public uint Seed { get; set; } = 0xFFFFFFFF;
@@ -52,7 +55,10 @@ public record ModelParams
public uint BatchSize { get; set; } = 512;
///
- public bool EmbeddingMode { get; set; }
+ public uint UBatchSize { get; set; } = 512;
+
+ ///
+ public bool Embeddings { get; set; }
///
public TensorSplitsCollection TensorSplits { get; set; } = new();
@@ -97,7 +103,7 @@ public record ModelParams
public float DefragThreshold { get; set; }
///
- public bool DoPooling { get; set; }
+ public LLamaPoolingType PoolingType { get; set; } = LLamaPoolingType.Unspecified;
///
public bool VocabOnly { get; set; }
diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs
index 53cae6e97..fa1a36dd9 100644
--- a/LLama/Extensions/IContextParamsExtensions.cs
+++ b/LLama/Extensions/IContextParamsExtensions.cs
@@ -20,11 +20,14 @@ public static class IContextParamsExtensions
///
public static void ToLlamaContextParams(this IContextParams @params, out LLamaContextParams result)
{
- result = NativeApi.llama_context_default_params();
+ result = LLamaContextParams.Default();
+
result.n_ctx = @params.ContextSize ?? 0;
result.n_batch = @params.BatchSize;
+ result.n_ubatch = @params.UBatchSize;
+ result.n_seq_max = @params.SeqMax;
result.seed = @params.Seed;
- result.embedding = @params.EmbeddingMode;
+ result.embeddings = @params.Embeddings;
result.rope_freq_base = @params.RopeFrequencyBase ?? 0;
result.rope_freq_scale = @params.RopeFrequencyScale ?? 0;
@@ -41,10 +44,13 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
result.cb_eval = IntPtr.Zero;
result.cb_eval_user_data = IntPtr.Zero;
+ result.abort_callback = IntPtr.Zero;
+ result.abort_callback_user_data = IntPtr.Zero;
+
result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
result.type_v = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
result.offload_kqv = !@params.NoKqvOffload;
- result.do_pooling = @params.DoPooling;
+ result.llama_pooling_type = @params.PoolingType;
result.n_threads = Threads(@params.Threads);
result.n_threads_batch = Threads(@params.BatchThreads);
diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs
index 69b9e288b..c7daa1351 100644
--- a/LLama/Extensions/IModelParamsExtensions.cs
+++ b/LLama/Extensions/IModelParamsExtensions.cs
@@ -28,7 +28,8 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
var disposer = new GroupDisposable();
- result = NativeApi.llama_model_default_params();
+ result = LLamaModelParams.Default();
+
result.main_gpu = @params.MainGpu;
result.split_mode = @params.SplitMode;
result.n_gpu_layers = @params.GpuLayerCount < 0 ? int.MaxValue : @params.GpuLayerCount;
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 9517965e6..e398982fe 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -152,6 +152,7 @@ public string DeTokenize(IReadOnlyList<LLamaToken> tokens)
return decoder.Read();
}
+ #region state load/save
///
/// Save the state to specified path.
///
@@ -163,7 +164,7 @@ public void SaveState(string filename)
File.Delete(filename);
// Estimate size of state to write to disk, this is always equal to or greater than the actual size
- var estimatedStateSize = (long)NativeApi.llama_get_state_size(NativeHandle);
+ var estimatedStateSize = checked((long)NativeHandle.GetStateSize());
// Map the file and write the bytes directly to it. This saves copying the bytes into a C# array
long writtenBytes;
@@ -174,8 +175,53 @@ public void SaveState(string filename)
{
byte* ptr = null;
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
- writtenBytes = (long)NativeApi.llama_copy_state_data(NativeHandle, ptr);
- view.SafeMemoryMappedViewHandle.ReleasePointer();
+ try
+ {
+ writtenBytes = (long)NativeHandle.GetState(ptr, (ulong)estimatedStateSize);
+ }
+ finally
+ {
+ view.SafeMemoryMappedViewHandle.ReleasePointer();
+ }
+ }
+ }
+
+ // Truncate the file to the actual size of data that was written
+ using (var fileStream = new FileStream(filename, FileMode.Open))
+ fileStream.SetLength(writtenBytes);
+ }
+
+ ///
+ /// Save the state of a particular sequence to specified path.
+ ///
+ ///
+ ///
+ public void SaveState(string filename, LLamaSeqId sequence)
+ {
+ // Delete that file before overwriting it
+ if (File.Exists(filename))
+ File.Delete(filename);
+
+ // Estimate size of state to write to disk, this is always equal to or greater than the actual size
+ var estimatedStateSize = checked((long)NativeHandle.GetStateSize(sequence));
+
+ // Map the file and write the bytes directly to it. This saves copying the bytes into a C# array
+ long writtenBytes;
+ using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, estimatedStateSize))
+ using (var view = file.CreateViewAccessor(0, estimatedStateSize))
+ {
+ unsafe
+ {
+ byte* ptr = null;
+ view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
+ try
+ {
+ writtenBytes = (long)NativeHandle.GetState(ptr, (ulong)estimatedStateSize, sequence);
+ }
+ finally
+ {
+ view.SafeMemoryMappedViewHandle.ReleasePointer();
+ }
}
}
@@ -187,7 +233,7 @@ public void SaveState(string filename)
///
/// Get the state data as an opaque handle, which can be loaded later using
///
- /// Use if you intend to save this state to disk.
+ /// Use SaveState if you intend to save this state to disk.
///
public State GetState()
{
@@ -198,7 +244,11 @@ public State GetState()
try
{
// Copy the state data into memory, discover the actual size required
- var actualSize = NativeHandle.GetState(memory, stateSize);
+ ulong actualSize;
+ unsafe
+ {
+ actualSize = NativeHandle.GetState((byte*)memory, stateSize);
+ }
// Shrink to size
memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
@@ -218,11 +268,48 @@ public State GetState()
}
}
+ ///
+ /// Get the state data as an opaque handle, which can be loaded later using
+ ///
+ /// Use SaveState if you intend to save this state to disk.
+ ///
+ public SequenceState GetState(LLamaSeqId sequence)
+ {
+ var stateSize = NativeHandle.GetStateSize(sequence);
+
+ // Allocate a chunk of memory large enough to hold the entire state
+ var memory = Marshal.AllocHGlobal((nint)stateSize);
+ try
+ {
+ // Copy the state data into memory, discover the actual size required
+ ulong actualSize;
+ unsafe
+ {
+ actualSize = NativeHandle.GetState((byte*)memory, stateSize, sequence);
+ }
+
+ // Shrink to size
+ memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
+
+ // Wrap memory in a "state"
+ var state = new SequenceState(memory, actualSize);
+
+ // Set memory to zero, to prevent it being freed in finally block
+ memory = IntPtr.Zero;
+
+ return state;
+ }
+ finally
+ {
+ if (memory != IntPtr.Zero)
+ Marshal.FreeHGlobal(memory);
+ }
+ }
+
///
/// Load the state from specified path.
///
///
- ///
public void LoadState(string filename)
{
// Map state file into memory and pass that pointer directly to `llama_set_state_data` to load from
@@ -233,8 +320,41 @@ public void LoadState(string filename)
{
byte* ptr = null;
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
- NativeApi.llama_set_state_data(NativeHandle, ptr);
- view.SafeMemoryMappedViewHandle.ReleasePointer();
+ try
+ {
+ NativeHandle.SetState(ptr);
+ }
+ finally
+ {
+ view.SafeMemoryMappedViewHandle.ReleasePointer();
+ }
+ }
+ }
+ }
+
+ ///
+ /// Load the state from specified path into a particular sequence
+ ///
+ ///
+ ///
+ public void LoadState(string filename, LLamaSeqId sequence)
+ {
+ // Map state file into memory and pass that pointer directly to `llama_set_state_data` to load from
+ using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Open, null))
+ using (var view = file.CreateViewAccessor())
+ {
+ unsafe
+ {
+ byte* ptr = null;
+ view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
+ try
+ {
+ NativeHandle.SetState(ptr, sequence);
+ }
+ finally
+ {
+ view.SafeMemoryMappedViewHandle.ReleasePointer();
+ }
}
}
}
@@ -248,10 +368,25 @@ public void LoadState(State state)
{
unsafe
{
- NativeHandle.SetState((byte*)state.DangerousGetHandle().ToPointer());
+ NativeHandle.SetState((byte*)state.DangerousGetHandle());
}
}
+ ///
+ /// Load the state from memory into a particular sequence
+ ///
+ ///
+ ///
+ ///
+ public void LoadState(SequenceState state, LLamaSeqId sequence)
+ {
+ unsafe
+ {
+ NativeHandle.SetState((byte*)state.DangerousGetHandle(), sequence);
+ }
+ }
+ #endregion
+
///
/// Sample a single token from this context, using the given sampling pipeline
///
@@ -357,8 +492,8 @@ public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable la
}
// Save the newline logit value
- var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
- var nl_logit = logits[(int)nl_token];
+ var nl_token = NativeHandle.ModelHandle.Tokens.Newline;
+ var nl_logit = logits[(int?)nl_token ?? 0];
// Convert logits into token candidates
var candidates_p = LLamaTokenDataArray.Create(logits);
@@ -371,7 +506,7 @@ public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable la
candidates_p.RepetitionPenalty(NativeHandle, last_n_array, repeatPenalty, alphaFrequency, alphaPresence);
// Restore newline token logit value if necessary
- if (!penalizeNL)
+ if (!penalizeNL && nl_token.HasValue)
{
var candidatesSpan = candidates_p.data.Span;
for (var i = 0; i < candidates_p.data.Length; i++)
@@ -417,12 +552,16 @@ public void Dispose()
}
///
- /// The state of this model, which can be reloaded later
+ /// The state of this context, which can be reloaded later
///
public class State
: SafeLLamaHandleBase
{
- private ulong _size;
+ private readonly ulong _size;
+ ///
+ /// Get the size in bytes of this state object
+ ///
+ public ulong Size => _size;
internal State(IntPtr memory, ulong size)
: base(memory, true)
@@ -441,6 +580,7 @@ protected override bool ReleaseHandle()
/// Convert this state to a byte array
///
///
+ [Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")]
public byte[] ToByteArray()
{
var bytes = new byte[_size];
@@ -453,6 +593,7 @@ public byte[] ToByteArray()
///
///
///
+ [Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")]
public static State FromByteArray(byte[] bytes)
{
var memory = Marshal.AllocHGlobal(bytes.Length);
@@ -460,5 +601,49 @@ public static State FromByteArray(byte[] bytes)
return new State(memory, (ulong)bytes.Length);
}
}
+
+ ///
+ /// The state of a single sequence, which can be reloaded later
+ ///
+ public class SequenceState
+ : SafeLLamaHandleBase
+ {
+ private readonly ulong _size;
+ ///
+ /// Get the size in bytes of this state object
+ ///
+ public ulong Size => _size;
+
+ internal SequenceState(IntPtr memory, ulong size)
+ : base(memory, true)
+ {
+ _size = size;
+ }
+
+ ///
+ protected override bool ReleaseHandle()
+ {
+ Marshal.FreeHGlobal(handle);
+ return true;
+ }
+
+ ///
+ /// Copy bytes to a destination pointer.
+ ///
+ /// Destination to write to
+ /// Length of the destination buffer
+ /// Offset from start of src to start copying from
+ /// Number of bytes written to destination
+ public unsafe ulong CopyTo(byte* dst, ulong length, ulong offset = 0)
+ {
+ var copy = Math.Min(length, _size - offset);
+
+ var src = (byte*)DangerousGetHandle();
+ src += offset;
+
+ Buffer.MemoryCopy(src, dst, length, copy);
+ return copy;
+ }
+ }
}
}
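To summarise the new surface area on LLamaContext in this file: whole-context state handling is unchanged, and per-sequence variants are added alongside it. A short usage sketch, assuming an existing context with sequence 0 populated:

// `context` is an existing LLamaContext; LLamaSeqId.Zero is sequence 0.

// Whole-context state (existing API, now with safer pointer release)
context.SaveState("everything.state");
context.LoadState("everything.state");

// New: persist/restore a single sequence to/from disk
context.SaveState("seq0.state", LLamaSeqId.Zero);
context.LoadState("seq0.state", LLamaSeqId.Zero);

// New: in-memory per-sequence state
using var seqState = context.GetState(LLamaSeqId.Zero);
context.LoadState(seqState, LLamaSeqId.Zero);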
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index 86ecceb51..13a3e1c27 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -32,7 +32,7 @@ public sealed class LLamaEmbedder
///
public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
- if (!@params.EmbeddingMode)
+ if (!@params.Embeddings)
throw new ArgumentException("Embeddings must be true", nameof(@params));
Context = weights.CreateContext(@params, logger);
@@ -75,7 +75,7 @@ public async Task<float[]> GetEmbeddings(string text, bool addBos, CancellationT
n_eval = batchSize;
batch.Clear();
- batch.AddRange(tokens.AsSpan(i, n_eval), n_past, LLamaSeqId.Zero, false);
+ batch.AddRange(tokens.AsSpan(i, n_eval), n_past, LLamaSeqId.Zero, true);
n_past += n_eval;
var returnCode = await Context.DecodeAsync(batch, cancellationToken);
@@ -97,9 +97,10 @@ public async Task GetEmbeddings(string text, bool addBos, CancellationT
private float[] GetEmbeddingsArray()
{
- var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
- if (embeddings == null)
+ var embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero);
+ if (embeddings.Length == 0)
return Array.Empty<float>();
+
return embeddings.ToArray();
}
@@ -111,6 +112,9 @@ private static void Normalize(Span<float> embeddings)
lengthSqr += value * value;
var length = (float)Math.Sqrt(lengthSqr);
+ if (length <= float.Epsilon)
+ return;
+
// Normalize
for (var i = 0; i < embeddings.Length; i++)
embeddings[i] /= length;
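Two behavioural changes here: embeddings are now read per sequence via llama_get_embeddings_seq, and Normalize refuses to divide by a near-zero norm, which is what the new NaN assertions in LLamaEmbedderTests guard against. The guarded normalization, shown in full as a sketch:

using System;

static void Normalize(Span<float> embeddings)
{
    // Euclidean norm of the vector
    var lengthSqr = 0.0;
    foreach (var value in embeddings)
        lengthSqr += value * value;
    var length = (float)Math.Sqrt(lengthSqr);

    // Do not divide by (near) zero - that would fill the vector with NaN/Infinity
    if (length <= float.Epsilon)
        return;

    for (var i = 0; i < embeddings.Length; i++)
        embeddings[i] /= length;
}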
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index c721726e8..c00ead4f3 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -135,7 +135,7 @@ public StatefulExecutorBase WithSessionFile(string filename)
{
_logger?.LogInformation($"[LLamaExecutor] Attempting to load saved session from {filename}");
var session_tokens = new LLamaToken[Context.ContextSize];
- if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, out var n_token_count_out))
+ if (!NativeApi.llama_state_load_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, out var n_token_count_out))
{
_logger?.LogError($"[LLamaExecutor] Failed to load session file {filename}");
throw new RuntimeError($"Failed to load session file {_pathSession}");
@@ -183,7 +183,7 @@ public StatefulExecutorBase WithSessionFile(string filename)
public void SaveSessionFile(string filename)
{
var session_token_array = _session_tokens.ToArray();
- NativeApi.llama_save_session_file(Context.NativeHandle, filename, session_token_array, (ulong)session_token_array.Length);
+ NativeApi.llama_state_save_file(Context.NativeHandle, filename, session_token_array, (ulong)session_token_array.Length);
}
///
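Only the native entry points change here (llama_load_session_file/llama_save_session_file become llama_state_load_file/llama_state_save_file); the executor-level API is untouched. A usage sketch, assuming an existing stateful executor:

// `executor` is any StatefulExecutorBase (e.g. InteractiveExecutor).
executor.WithSessionFile("chat.session"); // loads the session if the file already exists
// ... run some inference ...
executor.SaveSessionFile("chat.session"); // writes the tokens processed so far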
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 99d45e5a5..c3a9a420e 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -163,7 +163,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
}
}
- if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
+ if (_embeds.Count > 0 && _embeds.Last() == Context.NativeHandle.ModelHandle.Tokens.EOS)
{
args.WaitForInput = true;
}
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 055a5f13d..5acf4bd3e 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -21,7 +21,6 @@ namespace LLama
public class InteractiveExecutor : StatefulExecutorBase
{
private bool _is_prompt_run = true;
- private readonly LLamaToken _llama_token_newline;
// LLava
private int _EmbedImagePosition = -1;
@@ -36,13 +35,11 @@ public class InteractiveExecutor : StatefulExecutorBase
public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
: base(context, logger)
{
- _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle);
}
public InteractiveExecutor(LLamaContext context, LLavaWeights clipModel, ILogger? logger = null)
: base(context, clipModel, logger)
{
- _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle);
}
///
@@ -210,7 +207,7 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
return (true, Array.Empty<string>());
}
- if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
+ if (_embeds.Count > 0 && _embeds.Last() == Context.NativeHandle.ModelHandle.Tokens.EOS)
{
return (true, new[] { " [end of text]\n" });
}
@@ -308,9 +305,9 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
_last_n_tokens.Enqueue(id);
- if (id == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
+ if (id == Context.NativeHandle.ModelHandle.Tokens.EOS)
{
- id = _llama_token_newline;
+ id = Context.NativeHandle.ModelHandle.Tokens.Newline!.Value;
if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
{
var first_antiprompt = Context.Tokenize(args.Antiprompts[0], false);
diff --git a/LLama/LLamaQuantizer.cs b/LLama/LLamaQuantizer.cs
index 5410f55f9..4e541e920 100644
--- a/LLama/LLamaQuantizer.cs
+++ b/LLama/LLamaQuantizer.cs
@@ -20,7 +20,8 @@ public static class LLamaQuantizer
///
/// Whether the quantization is successful.
///
- public static bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1, bool allowRequantize = true, bool quantizeOutputTensor = false)
+ public static bool Quantize(
+ string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1, bool allowRequantize = true, bool quantizeOutputTensor = false)
{
if (!ValidateFtype(ftype))
{
@@ -28,11 +29,13 @@ public static bool Quantize(string srcFileName, string dstFilename, LLamaFtype f
$"to perform quantization.");
}
- var quantizeParams = NativeApi.llama_model_quantize_default_params();
+ var quantizeParams = LLamaModelQuantizeParams.Default();
quantizeParams.ftype = ftype;
quantizeParams.nthread = nthread;
quantizeParams.allow_requantize = allowRequantize;
quantizeParams.quantize_output_tensor = quantizeOutputTensor;
+ //todo: fill in other quantize params fields.
+
unsafe
{
return NativeApi.llama_model_quantize(srcFileName, dstFilename, &quantizeParams) == 0;
@@ -59,7 +62,7 @@ public static bool Quantize(string srcFileName, string dstFilename, string ftype
private static bool ValidateFtype(LLamaFtype ftype)
{
// Validation copies from here:
- // https://github.com/ggerganov/llama.cpp/blob/3ab8b3a92ede46df88bc5a2dfca3777de4a2b2b6/llama.cpp#L10965
+ // https://github.com/ggerganov/llama.cpp/blob/f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7/llama.cpp#L13450
switch (ftype)
{
@@ -95,6 +98,7 @@ private static bool ValidateFtype(LLamaFtype ftype)
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_XXS:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ1_S:
+ case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ1_M:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ4_NL:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ4_XS:
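The quantizer call site is unchanged apart from now going through LLamaModelQuantizeParams.Default() internally. A hedged example of the public entry point (file names are placeholders, and the ftype member is assumed to exist as in llama.cpp):

using LLama;
using LLama.Native;

var ok = LLamaQuantizer.Quantize(
    "model-f16.gguf",
    "model-q4_k_m.gguf",
    LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_K_M, // assumed enum member, mirroring llama.cpp
    nthread: -1,                          // -1 lets the native code pick the thread count
    allowRequantize: true,
    quantizeOutputTensor: false);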
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index f9d6ca5b2..487fe2935 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -124,7 +124,7 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
}
// Check if this is the EOS token
- if (id == _weights.EndOfSentenceToken)
+ if (id == _weights.Tokens.EOS)
break;
// Decode this token into text
diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs
index 0adb19875..2d8ea4d9b 100644
--- a/LLama/LLamaWeights.cs
+++ b/LLama/LLamaWeights.cs
@@ -41,24 +41,14 @@ public sealed class LLamaWeights
public ulong ParameterCount => NativeHandle.ParameterCount;
///
- /// Get the newline token for this model
- ///
- public LLamaToken NewlineToken => NativeApi.llama_token_nl(NativeHandle);
-
- ///
- /// Get the "end of sentence" token for this model
- ///
- public LLamaToken EndOfSentenceToken => NativeApi.llama_token_eos(NativeHandle);
-
- ///
- /// Get the "beginning of sentence" token for this model
+ /// Dimension of embedding vectors
///
- public LLamaToken BeginningOfSentenceToken => NativeApi.llama_token_bos(NativeHandle);
+ public int EmbeddingSize => NativeHandle.EmbeddingSize;
///
- /// Dimension of embedding vectors
+ /// Get the special tokens of this model
///
- public int EmbeddingSize => NativeHandle.EmbeddingSize;
+ public SafeLlamaModelHandle.ModelTokens Tokens => NativeHandle.Tokens;
///
/// All metadata keys in this model
diff --git a/LLama/Native/LLamaBeamView.cs b/LLama/Native/LLamaBeamView.cs
index e832eb620..dcd583ba3 100644
--- a/LLama/Native/LLamaBeamView.cs
+++ b/LLama/Native/LLamaBeamView.cs
@@ -10,7 +10,7 @@ namespace LLama.Native;
public struct LLamaBeamView
{
private unsafe LLamaToken* tokens;
- private nint n_tokens;
+ private nuint n_tokens;
///
/// Cumulative beam probability (renormalized relative to all beams)
diff --git a/LLama/Native/LLamaBeamsState.cs b/LLama/Native/LLamaBeamsState.cs
index f78c45b97..cb214aef3 100644
--- a/LLama/Native/LLamaBeamsState.cs
+++ b/LLama/Native/LLamaBeamsState.cs
@@ -19,7 +19,7 @@ public struct LLamaBeamsState
///
/// Number of elements in beam_views
///
- private nint n_beams;
+ private nuint n_beams;
///
/// Current max length of prefix tokens shared by all beams.
diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index c7e4a3287..8e3d7f74f 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -28,10 +28,20 @@ public struct LLamaContextParams
public uint n_ctx;
///
- /// prompt processing batch size
+ /// logical maximum batch size that can be submitted to llama_decode
///
public uint n_batch;
+ ///
+ /// physical maximum batch size
+ ///
+ public uint n_ubatch;
+
+ ///
+ /// max number of sequences (i.e. distinct states for recurrent models)
+ ///
+ public uint n_seq_max;
+
///
/// number of threads to use for generation
///
@@ -45,7 +55,12 @@ public struct LLamaContextParams
///
/// RoPE scaling type, from `enum llama_rope_scaling_type`
///
- public RopeScalingType rope_scaling_type;
+ public RopeScalingType rope_scaling_type;
+
+ ///
+ /// how to pool (sum) embedding results by sequence id (ignored if no pooling layer)
+ ///
+ public LLamaPoolingType llama_pooling_type;
///
/// RoPE base frequency, 0 = from model
@@ -87,11 +102,13 @@ public struct LLamaContextParams
///
public float defrag_threshold;
+ //todo: implement cb_eval callback support
///
/// ggml_backend_sched_eval_callback
///
public IntPtr cb_eval;
+ //todo: implement cb_eval callback support
///
/// User data passed into cb_eval
///
@@ -113,14 +130,14 @@ public struct LLamaContextParams
private sbyte _logits_all;
///
- /// embedding mode only
+ /// if true, extract embeddings (together with logits)
///
- public bool embedding
+ public bool embeddings
{
- readonly get => Convert.ToBoolean(_embedding);
- set => _embedding = Convert.ToSByte(value);
+ readonly get => Convert.ToBoolean(_embeddings);
+ set => _embeddings = Convert.ToSByte(value);
}
- private sbyte _embedding;
+ private sbyte _embeddings;
///
/// whether to offload the KQV ops (including the KV cache) to GPU
@@ -132,15 +149,29 @@ public bool offload_kqv
}
private sbyte _offload_kqv;
+ //todo: implement abort callback support
+ ///
+ /// ggml_abort_callback
+ ///
+ public IntPtr abort_callback;
+
+ //todo: implement abort callback support
///
- /// Whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
+ /// User data passed into abort_callback
///
- public bool do_pooling
+ public IntPtr abort_callback_user_data;
+
+ ///
+ /// Get the default LLamaContextParams
+ ///
+ ///
+ public static LLamaContextParams Default()
{
- readonly get => Convert.ToBoolean(_do_pooling);
- set => _do_pooling = Convert.ToSByte(value);
+ return llama_context_default_params();
+
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ static extern LLamaContextParams llama_context_default_params();
}
- private sbyte _do_pooling;
}
}
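The same Default() pattern is applied to LLamaModelParams and LLamaModelQuantizeParams later in this diff, replacing the three removed NativeApi.llama_*_default_params exports. Sketch of the new call sites:

// Replaces NativeApi.llama_context_default_params()
var ctxParams = LLamaContextParams.Default();
ctxParams.n_ctx = 2048;
ctxParams.n_batch = 512;
ctxParams.n_ubatch = 512;
ctxParams.n_seq_max = 1;
ctxParams.embeddings = false;
ctxParams.llama_pooling_type = LLamaPoolingType.Unspecified;

// Replaces llama_model_default_params() / llama_model_quantize_default_params()
var modelParams = LLamaModelParams.Default();
var quantParams = LLamaModelQuantizeParams.Default();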
diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs
index e943c6ff9..ae2702db2 100644
--- a/LLama/Native/LLamaFtype.cs
+++ b/LLama/Native/LLamaFtype.cs
@@ -166,6 +166,11 @@ public enum LLamaFtype
///
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30,
+ ///
+ /// except 1d tensors
+ ///
+ LLAMA_FTYPE_MOSTLY_IQ1_M = 31,
+
///
/// File type was not specified
///
diff --git a/LLama/Native/LLamaKvCacheView.cs b/LLama/Native/LLamaKvCacheView.cs
index 2b4087720..86169c602 100644
--- a/LLama/Native/LLamaKvCacheView.cs
+++ b/LLama/Native/LLamaKvCacheView.cs
@@ -28,7 +28,7 @@ public unsafe struct LLamaKvCacheView
// Maximum number of sequences that can exist in a cell. It's not an error
// if there are more sequences in a cell than this value, however they will
// not be visible in the view cells_sequences.
- int n_max_seq;
+ int n_seq_max;
// Number of tokens in the cache. For example, if there are two populated
// cells, the first with 1 sequence id in it and the second with 2 sequence
@@ -48,7 +48,7 @@ public unsafe struct LLamaKvCacheView
// Information for an individual cell.
LLamaKvCacheViewCell* cells;
- // The sequences for each cell. There will be n_max_seq items per cell.
+ // The sequences for each cell. There will be n_seq_max items per cell.
LLamaSeqId* cells_sequences;
}
@@ -118,10 +118,10 @@ public static partial class NativeApi
/// Create an empty KV cache view. (use only for debugging purposes)
///
///
- ///
+ ///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern LLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_max_seq);
+ public static extern LLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max);
///
/// Free a KV cache view. (use only for debugging purposes)
diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs
index aff2c514a..923b042c5 100644
--- a/LLama/Native/LLamaModelParams.cs
+++ b/LLama/Native/LLamaModelParams.cs
@@ -1,84 +1,96 @@
-using System;
-using System.Runtime.InteropServices;
-
-namespace LLama.Native
-{
- ///
- /// A C# representation of the llama.cpp `llama_model_params` struct
- ///
- [StructLayout(LayoutKind.Sequential)]
- public unsafe struct LLamaModelParams
- {
- ///
- /// // number of layers to store in VRAM
- ///
- public int n_gpu_layers;
-
- ///
- /// how to split the model across multiple GPUs
- ///
- public GPUSplitMode split_mode;
-
- ///
- /// the GPU that is used for scratch and small tensors
- ///
- public int main_gpu;
-
- ///
- /// how to split layers across multiple GPUs (size: )
- ///
+using System;
+using System.Runtime.InteropServices;
+
+namespace LLama.Native
+{
+ ///
+ /// A C# representation of the llama.cpp `llama_model_params` struct
+ ///
+ [StructLayout(LayoutKind.Sequential)]
+ public unsafe struct LLamaModelParams
+ {
+ ///
+ /// number of layers to store in VRAM
+ ///
+ public int n_gpu_layers;
+
+ ///
+ /// how to split the model across multiple GPUs
+ ///
+ public GPUSplitMode split_mode;
+
+ ///
+ /// the GPU that is used for scratch and small tensors
+ ///
+ public int main_gpu;
+
+ ///
+ /// how to split layers across multiple GPUs (size: )
+ ///
public float* tensor_split;
- ///
- /// called with a progress value between 0 and 1, pass NULL to disable. If the provided progress_callback
- /// returns true, model loading continues. If it returns false, model loading is immediately aborted.
- ///
-#if NETSTANDARD2_0
+ ///
+ /// called with a progress value between 0 and 1, pass NULL to disable. If the provided progress_callback
+ /// returns true, model loading continues. If it returns false, model loading is immediately aborted.
+ ///
+#if NETSTANDARD2_0
// this code is intended to be used when running LlamaSharp on NET Framework 4.8 (NET Standard 2.0)
// as NET Framework 4.8 does not play nice with the LlamaProgressCallback type
- public IntPtr progress_callback;
-#else
- public LlamaProgressCallback progress_callback;
-#endif
-
- ///
- /// context pointer passed to the progress callback
- ///
- public void* progress_callback_user_data;
-
- ///
- /// override key-value pairs of the model meta data
- ///
- public LLamaModelMetadataOverride* kv_overrides;
-
- ///
- /// only load the vocabulary, no weights
- ///
- public bool vocab_only
- {
- readonly get => Convert.ToBoolean(_vocab_only);
- set => _vocab_only = Convert.ToSByte(value);
- }
- private sbyte _vocab_only;
-
- ///
- /// use mmap if possible
- ///
- public bool use_mmap
- {
- readonly get => Convert.ToBoolean(_use_mmap);
- set => _use_mmap = Convert.ToSByte(value);
- }
- private sbyte _use_mmap;
-
- ///
- /// force system to keep model in RAM
- ///
- public bool use_mlock
- {
- readonly get => Convert.ToBoolean(_use_mlock);
- set => _use_mlock = Convert.ToSByte(value);
- }
- private sbyte _use_mlock;
- }
-}
+ public IntPtr progress_callback;
+#else
+ public LlamaProgressCallback progress_callback;
+#endif
+
+ ///
+ /// context pointer passed to the progress callback
+ ///
+ public void* progress_callback_user_data;
+
+ ///
+ /// override key-value pairs of the model meta data
+ ///
+ public LLamaModelMetadataOverride* kv_overrides;
+
+ ///
+ /// only load the vocabulary, no weights
+ ///
+ public bool vocab_only
+ {
+ readonly get => Convert.ToBoolean(_vocab_only);
+ set => _vocab_only = Convert.ToSByte(value);
+ }
+ private sbyte _vocab_only;
+
+ ///
+ /// use mmap if possible
+ ///
+ public bool use_mmap
+ {
+ readonly get => Convert.ToBoolean(_use_mmap);
+ set => _use_mmap = Convert.ToSByte(value);
+ }
+ private sbyte _use_mmap;
+
+ ///
+ /// force system to keep model in RAM
+ ///
+ public bool use_mlock
+ {
+ readonly get => Convert.ToBoolean(_use_mlock);
+ set => _use_mlock = Convert.ToSByte(value);
+ }
+ private sbyte _use_mlock;
+
+ ///
+ /// Create a LLamaModelParams with default values
+ ///
+ ///
+ public static LLamaModelParams Default()
+ {
+ return llama_model_default_params();
+
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ static extern LLamaModelParams llama_model_default_params();
+ }
+ }
+}
diff --git a/LLama/Native/LLamaModelQuantizeParams.cs b/LLama/Native/LLamaModelQuantizeParams.cs
index 34c1a9743..b2d37eb05 100644
--- a/LLama/Native/LLamaModelQuantizeParams.cs
+++ b/LLama/Native/LLamaModelQuantizeParams.cs
@@ -20,6 +20,16 @@ public struct LLamaModelQuantizeParams
///
public LLamaFtype ftype;
+ ///
+ /// output tensor type
+ ///
+ public GGMLType output_tensor_type;
+
+ ///
+ /// token embeddings tensor type
+ ///
+ public GGMLType token_embedding_type;
+
///
/// allow quantizing non-f32/f16 tensors
///
@@ -51,7 +61,7 @@ public bool only_copy
private sbyte _only_copy;
///
- /// disable k-quant mixtures and quantize all tensors to the same type
+ /// quantize all tensors to the default type
///
public bool pure
{
@@ -64,5 +74,22 @@ public bool pure
/// pointer to importance matrix data
///
public IntPtr imatrix;
+
+ ///
+ /// pointer to vector containing overrides
+ ///
+ public IntPtr kv_overrides;
+
+ ///
+ /// Create a LLamaModelQuantizeParams with default values
+ ///
+ ///
+ public static LLamaModelQuantizeParams Default()
+ {
+ return llama_model_quantize_default_params();
+
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ static extern LLamaModelQuantizeParams llama_model_quantize_default_params();
+ }
}
}
diff --git a/LLama/Native/LLamaPoolingType.cs b/LLama/Native/LLamaPoolingType.cs
index 13bdfe6c2..cd928849f 100644
--- a/LLama/Native/LLamaPoolingType.cs
+++ b/LLama/Native/LLamaPoolingType.cs
@@ -6,6 +6,7 @@
/// llama_pooling_type
public enum LLamaPoolingType
{
+ Unspecified = -1,
None = 0,
Mean = 1,
CLS = 2,
diff --git a/LLama/Native/LLamaVocabType.cs b/LLama/Native/LLamaVocabType.cs
index 6620f5101..d3cdf2dc3 100644
--- a/LLama/Native/LLamaVocabType.cs
+++ b/LLama/Native/LLamaVocabType.cs
@@ -6,7 +6,23 @@
/// llama_vocab_type
public enum LLamaVocabType
{
- SentencePiece = 0,
- BytePairEncoding = 1,
- WordPiece = 2,
+ ///
+ /// For models without vocab
+ ///
+ None = 0,
+
+ ///
+ /// LLaMA tokenizer based on byte-level BPE with byte fallback
+ ///
+ SentencePiece = 1,
+
+ ///
+ /// GPT-2 tokenizer based on byte-level BPE
+ ///
+ BytePairEncoding = 2,
+
+ ///
+ /// BERT tokenizer based on WordPiece
+ ///
+ WordPiece = 3,
}
\ No newline at end of file
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 41c1809e0..6f8d142e3 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -29,27 +29,6 @@ public static void llama_empty_call()
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern long llama_max_devices();
- ///
- /// Create a LLamaModelParams with default values
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern LLamaModelParams llama_model_default_params();
-
- ///
- /// Create a LLamaContextParams with default values
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern LLamaContextParams llama_context_default_params();
-
- ///
- /// Create a LLamaModelQuantizeParams with default values
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern LLamaModelQuantizeParams llama_model_quantize_default_params();
-
///
/// Check if memory mapping is supported
///
@@ -72,8 +51,9 @@ public static void llama_empty_call()
public static extern bool llama_supports_gpu_offload();
///
- /// Initialize the llama + ggml backend
- /// Call once at the start of the program
+ /// Initialize the llama + ggml backend. Call once at the start of the program.
+ ///
+ /// This is private because LLamaSharp automatically calls it, and it's only valid to call it once!
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_backend_init();
@@ -87,42 +67,6 @@ public static void llama_empty_call()
//[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
//public static extern void llama_numa_init(ggml_numa_strategy numa);
- ///
- /// Sets the current rng seed.
- ///
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_set_rng_seed(SafeLLamaContextHandle ctx, uint seed);
-
- ///
- /// Returns the maximum size in bytes of the state (rng, logits, embedding
- /// and kv_cache) - will often be smaller after compacting tokens
- ///
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern ulong llama_get_state_size(SafeLLamaContextHandle ctx);
-
- ///
- /// Copies the state to the specified destination address.
- /// Destination needs to have allocated enough memory.
- ///
- ///
- ///
- /// the number of bytes copied
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern unsafe ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest);
-
- ///
- /// Set the state reading from the specified address
- ///
- ///
- ///
- /// the number of bytes read
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern unsafe ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src);
-
///
/// Load session file
///
@@ -133,7 +77,7 @@ public static void llama_empty_call()
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out);
+ public static extern bool llama_state_load_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out);
///
/// Save session file
@@ -144,63 +88,97 @@ public static void llama_empty_call()
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern bool llama_save_session_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count);
+ public static extern bool llama_state_save_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, LLamaToken token);
+ public static extern unsafe nuint llama_state_seq_save_file(SafeLLamaContextHandle ctx, string filepath, LLamaSeqId seq_id, LLamaToken* tokens, nuint n_token_count);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern float llama_token_get_score(SafeLlamaModelHandle model, LLamaToken token);
+ public static extern unsafe nuint llama_state_seq_load_file(SafeLLamaContextHandle ctx, string filepath, LLamaSeqId dest_seq_id, LLamaToken* tokens_out, nuint n_token_capacity, out nuint n_token_count_out);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern LLamaTokenType llama_token_get_type(SafeLlamaModelHandle model, LLamaToken token);
+ public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, LLamaToken token);
///
- /// Get the size of the context window for the model for this context
+ /// Set whether to use causal attention or not. If set to true, the model will only attend to the past tokens
///
- ///
- ///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern uint llama_n_ctx(SafeLLamaContextHandle ctx);
+ public static extern void llama_set_causal_attn(SafeLLamaContextHandle ctx, bool causal_attn);
///
- /// Get the batch size for this context
+ /// Set abort callback
///
- ///
- ///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern uint llama_n_batch(SafeLLamaContextHandle ctx);
+ public static extern void llama_set_abort_callback(SafeLLamaContextHandle ctx, IntPtr /* ggml_abort_callback */ abort_callback, IntPtr abort_callback_data);
///
- /// Token logits obtained from the last call to llama_decode
- /// The logits for the last token are stored in the last row
- /// Can be mutated in order to change the probabilities of the next token.
- /// Rows: n_tokens
- /// Cols: n_vocab
+ /// Wait until all computations are finished. This is automatically done when using any of the functions to obtain computation results
+ /// and is not necessary to call it explicitly in most cases.
///
///
- ///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern unsafe float* llama_get_logits(SafeLLamaContextHandle ctx);
+ public static extern void llama_synchronize(SafeLLamaContextHandle ctx);
+
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern float llama_token_get_score(SafeLlamaModelHandle model, LLamaToken token);
+
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern LLamaTokenType llama_token_get_type(SafeLlamaModelHandle model, LLamaToken token);
///
- /// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
+ /// Get the n_seq_max for this context
///
///
- ///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern unsafe float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);
+ public static extern uint llama_n_seq_max(SafeLLamaContextHandle ctx);
///
- /// Get the embeddings for the ith sequence. Equivalent to: llama_get_embeddings(ctx) + i*n_embd
+ /// Get the embeddings for a specific sequence.
+ /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
///
///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern unsafe float* llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i);
+ public static Span<float> llama_get_embeddings_seq(SafeLLamaContextHandle ctx, LLamaSeqId id)
+ {
+ unsafe
+ {
+ var ptr = llama_get_embeddings_seq_native(ctx, id);
+ if (ptr == null)
+ return Array.Empty<float>();
+
+ return new Span<float>(ptr, ctx.EmbeddingSize);
+ }
+
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings_seq")]
+ static extern unsafe float* llama_get_embeddings_seq_native(SafeLLamaContextHandle ctx, LLamaSeqId id);
+ }
+
+ ///
+ /// Get the embeddings for the ith sequence.
+ /// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+ ///
+ ///
+ public static Span<float> llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i)
+ {
+ unsafe
+ {
+ var ptr = llama_get_embeddings_ith_native(ctx, i);
+ if (ptr == null)
+ return Array.Empty<float>();
+
+ return new Span<float>(ptr, ctx.EmbeddingSize);
+ }
+
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings_ith")]
+ static extern unsafe float* llama_get_embeddings_ith_native(SafeLLamaContextHandle ctx, int i);
+ }
///
- /// Get the embeddings for the input
+ /// Get all output token embeddings.
+ /// When pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, the embeddings for which
+ /// llama_batch.logits[i] != 0 are stored contiguously in the order they have appeared in the batch.
+ /// shape: [n_outputs*n_embd]
+ /// Otherwise, returns an empty span.
///
///
///
@@ -209,6 +187,9 @@ public static Span<float> llama_get_embeddings(SafeLLamaContextHandle ctx)
unsafe
{
var ptr = llama_get_embeddings_native(ctx);
+ if (ptr == null)
+ return Array.Empty<float>();
+
return new Span<float>(ptr, ctx.EmbeddingSize);
}
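Both wrappers above now surface a missing result as an empty span rather than a null pointer, so callers can branch on Length. A consumption sketch, assuming an existing SafeLLamaContextHandle ctx created with embeddings enabled:

var pooled = NativeApi.llama_get_embeddings_seq(ctx, LLamaSeqId.Zero);
if (pooled.Length == 0)
{
    // No pooled embeddings available (e.g. pooling disabled or nothing decoded yet)
}
else
{
    // Copy out before the next decode reuses the native buffer
    float[] result = pooled.ToArray();
}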
@@ -230,28 +211,7 @@ public static Span<float> llama_get_embeddings(SafeLLamaContextHandle ctx)
/// The size of the allocated buffer
/// The total number of bytes of the formatted prompt. If it is larger than the size of the buffer, you may need to re-alloc it and then re-apply the template.
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_chat_apply_template")]
- public static extern unsafe int llama_chat_apply_template(SafeLlamaModelHandle model, char* tmpl, LLamaChatMessage* chat, nint n_msg, bool add_ass, char* buf, int length);
-
- ///
- /// Get the "Beginning of sentence" token
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern LLamaToken llama_token_bos(SafeLlamaModelHandle model);
-
- ///
- /// Get the "End of sentence" token
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern LLamaToken llama_token_eos(SafeLlamaModelHandle model);
-
- ///
- /// Get the "new line" token
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern LLamaToken llama_token_nl(SafeLlamaModelHandle model);
+ public static extern unsafe int llama_chat_apply_template(SafeLlamaModelHandle model, char* tmpl, LLamaChatMessage* chat, nuint n_msg, bool add_ass, char* buf, int length);
///
/// Returns -1 if unknown, 1 for true or 0 for false.
@@ -267,34 +227,6 @@ public static Span llama_get_embeddings(SafeLLamaContextHandle ctx)
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_add_eos_token(SafeLlamaModelHandle model);
- ///
- /// codellama infill tokens, Beginning of infill prefix
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_token_prefix(SafeLlamaModelHandle model);
-
- ///
- /// codellama infill tokens, Beginning of infill middle
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_token_middle(SafeLlamaModelHandle model);
-
- ///
- /// codellama infill tokens, Beginning of infill suffix
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_token_suffix(SafeLlamaModelHandle model);
-
- ///
- /// codellama infill tokens, End of infill middle
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_token_eot(SafeLlamaModelHandle model);
-
///
/// Print out timing information for this context
///
@@ -345,13 +277,13 @@ public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken ll
///
///
///
- ///
- /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
+ ///
+ /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
/// Returns the number of tokens on success, no more than n_max_tokens.
/// Returns a negative number on failure - the number of tokens that would have been returned
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern unsafe int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, LLamaToken* tokens, int n_max_tokens, bool add_bos, bool special);
+ public static extern unsafe int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, LLamaToken* tokens, int n_max_tokens, bool add_special, bool parse_special);
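// Usage sketch for the renamed tokenizer flags (illustrative). Assumes `model` is a loaded
// SafeLlamaModelHandle; the buffer bound and UTF-8 handling are simplifications, only the
// llama_tokenize signature is taken from this diff.
using System.Text;
using LLama.Native;

static unsafe LLamaToken[] Tokenize(SafeLlamaModelHandle model, string text, bool addSpecial, bool parseSpecial)
{
    var bytes = Encoding.UTF8.GetBytes(text);
    var tokens = new LLamaToken[bytes.Length + 8]; // generous bound for a short prompt

    fixed (byte* textPtr = bytes)
    fixed (LLamaToken* tokenPtr = tokens)
    {
        var count = NativeApi.llama_tokenize(model, textPtr, bytes.Length, tokenPtr, tokens.Length, addSpecial, parseSpecial);
        if (count < 0)
            throw new InvalidOperationException($"Token buffer too small, {-count} tokens required");
        return tokens.AsSpan(0, count).ToArray();
    }
}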
///
/// Register a callback to receive llama log messages
@@ -377,8 +309,9 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
///
///
///
+ /// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_kv_cache_seq_rm(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1);
+ public static extern bool llama_kv_cache_seq_rm(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1);
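// Usage sketch for the new bool result (illustrative). Assumes `ctx` is a valid context
// handle and that LLamaSeqId/LLamaPos convert from int as they do elsewhere in this code base.
using LLama.Native;

static void RemoveRange(SafeLLamaContextHandle ctx, LLamaSeqId seq, int start, int end)
{
    // Removing part of a sequence can fail; removing a whole sequence never does,
    // per the doc comment above.
    if (!NativeApi.llama_kv_cache_seq_rm(ctx, seq, start, end))
        throw new InvalidOperationException("Could not remove a partial sequence from the KV cache");
}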
///
/// Copy all tokens that belong to the specified sequence to another sequence
@@ -441,23 +374,6 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaPos llama_kv_cache_seq_pos_max(SafeLLamaContextHandle ctx, LLamaSeqId seq);
- ///
- /// Defragment the KV cache. This will be applied:
- /// - lazily on next llama_decode()
- /// - explicitly with llama_kv_cache_update()
- ///
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern LLamaPos llama_kv_cache_defrag(SafeLLamaContextHandle ctx);
-
- ///
- /// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_kv_cache_update(SafeLLamaContextHandle ctx);
-
///
/// Allocates a batch of tokens on the heap
/// Each token can be assigned up to n_seq_max sequence ids
@@ -481,31 +397,47 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
public static extern void llama_batch_free(LLamaNativeBatch batch);
///
+ /// Apply a loaded control vector to a llama_context, or if data is NULL, clear
+ /// the currently loaded vector.
+ /// n_embd should be the size of a single layer's control, and data should point
+ /// to an n_embd x n_layers buffer starting from layer 1.
+ /// il_start and il_end are the layer range the vector should apply to (both inclusive)
+ /// See llama_control_vector_load in common to load a control vector.
///
///
- ///
- /// Positive return values does not mean a fatal error, but rather a warning:
- /// - 0: success
- /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
- /// - < 0: error
- ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_decode(SafeLLamaContextHandle ctx, LLamaNativeBatch batch);
+ public static extern unsafe int llama_control_vector_apply(SafeLLamaContextHandle ctx, float* data, nuint len, int n_embd, int il_start, int il_end);
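// Usage sketch for the new control-vector binding (illustrative). The buffer layout and
// layer range follow the doc comment above; treating a non-zero return as failure is an
// assumption, as is everything about how `data` was produced.
using LLama.Native;

static unsafe void ApplyControlVector(SafeLLamaContextHandle ctx, float[] data, int nEmbd, int firstLayer, int lastLayer)
{
    fixed (float* ptr = data)
    {
        // `data` is an n_embd x n_layers buffer starting from layer 1; passing null
        // instead would clear the currently loaded vector.
        var rc = NativeApi.llama_control_vector_apply(ctx, ptr, (nuint)data.Length, nEmbd, firstLayer, lastLayer);
        if (rc != 0)
            throw new InvalidOperationException($"llama_control_vector_apply failed (rc={rc})");
    }
}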
///
- /// Set the number of threads used for decoding
+ /// Build a split GGUF final path for this chunk.
+ /// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf"
///
- ///
- /// n_threads is the number of threads used for generation (single token)
- /// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
- ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// Returns the split_path length.
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
-
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern LLamaVocabType llama_vocab_type(SafeLlamaModelHandle model);
+ public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no, int split_count);
+ ///
+ /// Extract the path prefix from the split_path if and only if the split_no and split_count match.
+ /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0"
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// Returns the split_prefix length.
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern LLamaRopeType llama_rope_type(SafeLlamaModelHandle model);
+ public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no, int split_count);
}
}
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 2f881fa5d..c9e959a09 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -22,7 +22,7 @@ public sealed class SafeLLamaContextHandle
///
/// Total number of tokens in the context
///
- public uint ContextSize => NativeApi.llama_n_ctx(this);
+ public uint ContextSize => llama_n_ctx(this);
///
/// Dimension of embedding vectors
@@ -32,7 +32,12 @@ public sealed class SafeLLamaContextHandle
///
/// Get the maximum batch size for this context
///
- public uint BatchSize => NativeApi.llama_n_batch(this);
+ public uint BatchSize => llama_n_batch(this);
+
+ ///
+ /// Get the physical maximum batch size for this context
+ ///
+ public uint UBatchSize => llama_n_ubatch(this);
///
/// Get the model which this context is using
@@ -127,8 +132,157 @@ static SafeLLamaContextHandle()
///
///
private unsafe delegate bool GgmlAbortCallback(void* data);
+
+ ///
+ ///
+ ///
+ ///
+ /// Positive return values do not mean a fatal error, but rather a warning:
+ /// - 0: success
+ /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+ /// - < 0: error
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int llama_decode(SafeLLamaContextHandle ctx, LLamaNativeBatch batch);
+
+ ///
+ /// Set the number of threads used for decoding
+ ///
+ ///
+ /// n_threads is the number of threads used for generation (single token)
+ /// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern void llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
+
+ ///
+ /// Token logits obtained from the last call to llama_decode
+ /// The logits for the last token are stored in the last row
+ /// Can be mutated in order to change the probabilities of the next token.
+ /// Rows: n_tokens
+ /// Cols: n_vocab
+ ///
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern unsafe float* llama_get_logits(SafeLLamaContextHandle ctx);
+
+ ///
+ /// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
+ ///
+ ///
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern unsafe float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);
+
+ ///
+ /// Get the size of the context window for the model for this context
+ ///
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern uint llama_n_ctx(SafeLLamaContextHandle ctx);
+
+ ///
+ /// Get the batch size for this context
+ ///
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern uint llama_n_batch(SafeLLamaContextHandle ctx);
+
+ ///
+ /// Get the ubatch size for this context
+ ///
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern uint llama_n_ubatch(SafeLLamaContextHandle ctx);
+
+ ///
+ /// Sets the current rng seed.
+ ///
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern void llama_set_rng_seed(SafeLLamaContextHandle ctx, uint seed);
+
+ ///
+ /// Returns the maximum size in bytes of the state (rng, logits, embedding
+ /// and kv_cache) - will often be smaller after compacting tokens
+ ///
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern ulong llama_state_get_size(SafeLLamaContextHandle ctx);
+
+ ///
+ /// Copies the state to the specified destination address.
+ /// Destination needs to have allocated enough memory.
+ ///
+ ///
+ ///
+ /// the number of bytes copied
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern unsafe ulong llama_state_get_data(SafeLLamaContextHandle ctx, byte* dest);
+
+ ///
+ /// Set the state reading from the specified address
+ ///
+ ///
+ ///
+ /// the number of bytes read
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern unsafe ulong llama_state_set_data(SafeLLamaContextHandle ctx, byte* src);
+
+ ///
+ /// Get the exact size needed to copy the KV cache of a single sequence
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern nuint llama_state_seq_get_size(SafeLLamaContextHandle ctx, LLamaSeqId seq_id);
+
+ ///
+ /// Copy the KV cache of a single sequence into the specified buffer
+ ///
+ ///
+ ///
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern unsafe nuint llama_state_seq_get_data(SafeLLamaContextHandle ctx, byte* dst, LLamaSeqId seq_id);
+
+ ///
+ /// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// - Positive: Ok
+ /// - Zero: Failed to load
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern unsafe nuint llama_state_seq_set_data(SafeLLamaContextHandle ctx, byte* src, LLamaSeqId dest_seq_id);
+
+ ///
+ /// Defragment the KV cache. This will be applied:
+ /// - lazily on next llama_decode()
+ /// - explicitly with llama_kv_cache_update()
+ ///
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern void llama_kv_cache_defrag(SafeLLamaContextHandle ctx);
+
+ ///
+ /// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern void llama_kv_cache_update(SafeLLamaContextHandle ctx);
#endregion
-
+
///
/// Token logits obtained from the last call to llama_decode
/// The logits for the last token are stored in the last row
@@ -143,7 +297,7 @@ public Span<float> GetLogits()
unsafe
{
- var logits = NativeApi.llama_get_logits(this);
+ var logits = llama_get_logits(this);
return new Span<float>(logits, model.VocabCount);
}
}
@@ -159,7 +313,7 @@ public Span<float> GetLogitsIth(int i)
unsafe
{
- var logits = NativeApi.llama_get_logits_ith(this, i);
+ var logits = llama_get_logits_ith(this, i);
return new Span<float>(logits, model.VocabCount);
}
}
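// Usage sketch (illustrative): the spans returned by GetLogits/GetLogitsIth can be mutated
// in place to bias sampling, as the native doc comment notes. `tokenIndex` addresses a
// token within the last decoded batch; the helper name and the banning idea are illustrative.
using LLama.Native;

static void BanToken(SafeLLamaContextHandle ctx, int tokenIndex, int bannedTokenId)
{
    var logits = ctx.GetLogitsIth(tokenIndex);
    logits[bannedTokenId] = float.NegativeInfinity; // make this candidate unsamplable
}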
@@ -216,7 +370,7 @@ public DecodeResult Decode(LLamaBatch batch)
{
lock (GlobalInferenceLock)
using (batch.ToNativeBatch(out var nb))
- return (DecodeResult)NativeApi.llama_decode(this, nb);
+ return (DecodeResult)llama_decode(this, nb);
}
///
@@ -260,19 +414,17 @@ public DecodeResult Decode(LLamaBatch batch)
///
public ulong GetStateSize()
{
- return NativeApi.llama_get_state_size(this);
+ return llama_state_get_size(this);
}
///
- /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer.
+ /// Get the size of the KV cache for a single sequence ID, when saved as bytes
///
- /// Destination to write to
- /// Number of bytes available to write to in dest (check required size with `GetStateSize()`)
- /// The number of bytes written to dest
- /// Thrown if dest is too small
- public unsafe ulong GetState(byte* dest, ulong size)
+ ///
+ ///
+ public ulong GetStateSize(LLamaSeqId sequence)
{
- return GetState(new IntPtr(dest), size);
+ return llama_state_seq_get_size(this, sequence);
}
///
@@ -282,7 +434,7 @@ public unsafe ulong GetState(byte* dest, ulong size)
/// Number of bytes available to write to in dest (check required size with `GetStateSize()`)
/// The number of bytes written to dest
/// Thrown if dest is too small
- public ulong GetState(IntPtr dest, ulong size)
+ public unsafe ulong GetState(byte* dest, ulong size)
{
var required = GetStateSize();
if (size < required)
@@ -290,10 +442,26 @@ public ulong GetState(IntPtr dest, ulong size)
unsafe
{
- return NativeApi.llama_copy_state_data(this, (byte*)dest.ToPointer());
+ return llama_state_get_data(this, dest);
}
}
+ ///
+ /// Get the raw state of a single sequence from this context, encoded as bytes. Data is written into the `dest` pointer.
+ ///
+ /// Destination to write to
+ /// Number of bytes available to write to in dest (check required size with `GetStateSize()`)
+ /// The sequence to get state data for
+ /// The number of bytes written to dest
+ public unsafe ulong GetState(byte* dest, ulong size, LLamaSeqId sequence)
+ {
+ var required = GetStateSize(sequence);
+ if (size < required)
+ throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}");
+
+ return llama_state_seq_get_data(this, dest, sequence);
+ }
+
///
/// Set the raw state of this context
///
@@ -301,20 +469,18 @@ public ulong GetState(IntPtr dest, ulong size)
/// Number of bytes read from the src pointer
public unsafe ulong SetState(byte* src)
{
- return SetState(new IntPtr(src));
+ return llama_state_set_data(this, src);
}
///
- /// Set the raw state of this context
+ /// Set the raw state of a single sequence
///
/// The pointer to read the state from
+ /// Sequence ID to set
/// Number of bytes read from the src pointer
- public ulong SetState(IntPtr src)
+ public unsafe ulong SetState(byte* src, LLamaSeqId sequence)
{
- unsafe
- {
- return NativeApi.llama_set_state_data(this, (byte*)src.ToPointer());
- }
+ return llama_state_seq_set_data(this, src, sequence);
}
#endregion
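// Usage sketch for the new per-sequence state overloads (illustrative). Assumes `ctx` is a
// SafeLLamaContextHandle and `seq` identifies a sequence that exists in the KV cache;
// helper names and buffer handling are illustrative.
using LLama.Native;

static unsafe byte[] SaveSequence(SafeLLamaContextHandle ctx, LLamaSeqId seq)
{
    var buffer = new byte[ctx.GetStateSize(seq)];
    fixed (byte* ptr = buffer)
    {
        var written = ctx.GetState(ptr, (ulong)buffer.Length, seq);
        return buffer.AsSpan(0, (int)written).ToArray();
    }
}

static unsafe void RestoreSequence(SafeLLamaContextHandle ctx, byte[] state, LLamaSeqId seq)
{
    fixed (byte* ptr = state)
        ctx.SetState(ptr, seq);
}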
@@ -324,7 +490,7 @@ public ulong SetState(IntPtr src)
///
public void SetSeed(uint seed)
{
- NativeApi.llama_set_rng_seed(this, seed);
+ llama_set_rng_seed(this, seed);
}
///
@@ -334,10 +500,29 @@ public void SetSeed(uint seed)
/// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
public void SetThreads(uint threads, uint threadsBatch)
{
- NativeApi.llama_set_n_threads(this, threads, threadsBatch);
+ llama_set_n_threads(this, threads, threadsBatch);
}
#region KV Cache Management
+ ///
+ /// Apply KV cache updates (such as K-shifts, defragmentation, etc.)
+ ///
+ public void KvCacheUpdate()
+ {
+ llama_kv_cache_update(this);
+ }
+
+ ///
+ /// Defragment the KV cache. This will be applied:
+ /// - lazily on next llama_decode()
+ /// - explicitly with llama_kv_cache_update()
+ ///
+ ///
+ public void KvCacheDefrag()
+ {
+ llama_kv_cache_defrag(this);
+ }
+
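// Usage sketch (illustrative): the intended call order for the two new wrappers.
// Defragmentation is only scheduled by KvCacheDefrag and applied either lazily on the next
// decode or explicitly by KvCacheUpdate, per the doc comments above.
using LLama.Native;

static void CompactKvCache(SafeLLamaContextHandle ctx)
{
    ctx.KvCacheDefrag();   // schedule defragmentation
    ctx.KvCacheUpdate();   // apply pending KV cache updates (K-shifts, defragmentation, ...)
}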
///
/// Get a new KV cache view that can be used to debug the KV cache
///
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index acaee849a..31b73dfd4 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -22,6 +22,16 @@ public sealed class SafeLlamaModelHandle
///
public int VocabCount => llama_n_vocab(this);
+ ///
+ /// Get the vocabulary type for this model
+ ///
+ public LLamaVocabType VocabType => llama_vocab_type(this);
+
+ ///
+ /// Get the rope (positional embedding) type for this model
+ ///
+ public LLamaRopeType RopeType => llama_rope_type(this);
+
///
/// Total number of tokens in the context
///
@@ -47,6 +57,11 @@ public sealed class SafeLlamaModelHandle
///
public ulong ParameterCount => llama_model_n_params(this);
+ ///
+ /// Get the number of layers in this model
+ ///
+ public int LayerCount => llama_n_layers(this);
+
///
/// Get a description of this model
///
@@ -74,6 +89,13 @@ public string Description
///
public int MetadataCount => llama_model_meta_count(this);
+ private ModelTokens? _tokens;
+
+ ///
+ /// Get the special tokens of this model
+ ///
+ public ModelTokens Tokens => _tokens ??= new ModelTokens(this);
+
///
protected override bool ReleaseHandle()
{
@@ -105,7 +127,7 @@ public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelPara
#region native API
static SafeLlamaModelHandle()
{
- // This ensures that `NativeApi` has been loaded before calling the two native methods below
+ // Ensure that `NativeApi` has been loaded
NativeApi.llama_empty_call();
}
@@ -132,7 +154,7 @@ static SafeLlamaModelHandle()
///
/// Returns 0 on success
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, float scale, string? path_base_model, int n_threads);
+ private static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, float scale, string? path_base_model, int n_threads);
///
/// Frees all allocated memory associated with a model
@@ -210,6 +232,12 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model,
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_n_vocab(SafeLlamaModelHandle model);
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern LLamaVocabType llama_vocab_type(SafeLlamaModelHandle model);
+
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern LLamaRopeType llama_rope_type(SafeLlamaModelHandle model);
+
///
/// Get the size of the context window for the model
///
@@ -226,6 +254,14 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model,
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_n_embd(SafeLlamaModelHandle model);
+ ///
+ /// Get the number of layers in this model
+ ///
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int llama_n_layers(SafeLlamaModelHandle model);
+
///
/// Get a string describing the model type
///
@@ -259,6 +295,69 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model,
///
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern float llama_rope_freq_scale_train(SafeLlamaModelHandle model);
+
+ ///
+ /// Get the "Beginning of sentence" token
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern LLamaToken llama_token_bos(SafeLlamaModelHandle model);
+
+ ///
+ /// Get the "End of sentence" token
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern LLamaToken llama_token_eos(SafeLlamaModelHandle model);
+
+ ///
+ /// Get the "classification" token
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern LLamaToken llama_token_cls(SafeLlamaModelHandle model);
+
+ ///
+ /// Get the "sentence separator" token
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern LLamaToken llama_token_sep(SafeLlamaModelHandle model);
+
+ ///
+ /// Get the "new line" token
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern LLamaToken llama_token_nl(SafeLlamaModelHandle model);
+
+ ///
+ /// codellama infill tokens, Beginning of infill prefix
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int llama_token_prefix(SafeLlamaModelHandle model);
+
+ ///
+ /// codellama infill tokens, Beginning of infill middle
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int llama_token_middle(SafeLlamaModelHandle model);
+
+ ///
+ /// codellama infill tokens, Beginning of infill suffix
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int llama_token_suffix(SafeLlamaModelHandle model);
+
+ ///
+ /// codellama infill tokens, End of infill middle
+ ///
+ ///
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int llama_token_eot(SafeLlamaModelHandle model);
#endregion
#region LoRA
@@ -273,6 +372,14 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model,
///
public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, int? threads = null)
{
+ // Try to open the model file, this will check:
+ // - File exists (automatically throws FileNotFoundException)
+ // - File is readable (explicit check)
+ // This provides better error messages than llama.cpp, which would throw an access violation exception in both cases.
+ using (var fs = new FileStream(lora, FileMode.Open))
+ if (!fs.CanRead)
+ throw new InvalidOperationException($"LoRA file '{lora}' is not readable");
+
var err = llama_model_apply_lora_from_file(
this,
lora,
@@ -282,7 +389,7 @@ public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null
);
if (err != 0)
- throw new RuntimeError("Failed to apply lora adapter.");
+ throw new RuntimeError($"Failed to apply lora adapter (err={err}).");
}
#endregion
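// Usage sketch for the stricter LoRA path (illustrative). Assumes `model` is a loaded
// SafeLlamaModelHandle, that RuntimeError lives in LLama.Exceptions, and that the adapter
// path is a placeholder. The up-front FileStream check turns a missing or unreadable
// adapter into a managed exception instead of a native access violation, and failures now
// report the native error code.
using LLama.Exceptions;
using LLama.Native;

static void TryApplyLora(SafeLlamaModelHandle model)
{
    try
    {
        model.ApplyLoraFromFile("adapters/style-lora.gguf", scale: 0.8f);
    }
    catch (IOException ex)
    {
        Console.WriteLine($"LoRA adapter could not be opened: {ex.Message}");
    }
    catch (RuntimeError ex)
    {
        Console.WriteLine(ex.Message); // e.g. "Failed to apply lora adapter (err=1)."
    }
}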
@@ -451,5 +558,68 @@ internal IReadOnlyDictionary<string, string> ReadMetadata()
return result;
}
#endregion
+
+ ///
+ /// Get tokens for a model
+ ///
+ public class ModelTokens
+ {
+ private readonly SafeLlamaModelHandle _model;
+
+ internal ModelTokens(SafeLlamaModelHandle model)
+ {
+ _model = model;
+ }
+
+ private static LLamaToken? Normalize(LLamaToken token)
+ {
+ return token == -1 ? null : token;
+ }
+
+ ///
+ /// Get the Beginning of Sentence token for this model
+ ///
+ public LLamaToken? BOS => Normalize(llama_token_bos(_model));
+
+ ///
+ /// Get the End of Sentence token for this model
+ ///
+ public LLamaToken? EOS => Normalize(llama_token_eos(_model));
+
+ ///
+ /// Get the newline token for this model
+ ///
+ public LLamaToken? Newline => Normalize(llama_token_nl(_model));
+
+ ///
+ /// Get the classification token for this model
+ ///
+ public LLamaToken? CLS => Normalize(llama_token_cls(_model));
+
+ ///
+ /// Get the sentence separator token for this model
+ ///
+ public LLamaToken? SEP => Normalize(llama_token_sep(_model));
+
+ ///
+ /// Codellama beginning of infill prefix
+ ///
+ public LLamaToken? InfillPrefix => Normalize(llama_token_prefix(_model));
+
+ ///
+ /// Codellama beginning of infill middle
+ ///
+ public LLamaToken? InfillMiddle => Normalize(llama_token_middle(_model));
+
+ ///
+ /// Codellama beginning of infill suffix
+ ///
+ public LLamaToken? InfillSuffix => Normalize(llama_token_suffix(_model));
+
+ ///
+ /// Codellama end of infill middle
+ ///
+ public LLamaToken? EOT => Normalize(llama_token_eot(_model));
+ }
}
}
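// Usage sketch for the nullable special-token accessors (illustrative). Models that lack a
// given special token now surface it as null rather than -1, so callers check HasValue
// before comparing; assumes `model` is a loaded SafeLlamaModelHandle and `token` a freshly
// sampled LLamaToken.
using LLama.Native;

static bool IsEndOfGeneration(SafeLlamaModelHandle model, LLamaToken token)
{
    var eos = model.Tokens.EOS;            // null if the model defines no EOS token
    return eos.HasValue && token == eos.Value;
}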
diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs
index 5a9ef16cf..33806f5f9 100644
--- a/LLama/Sampling/DefaultSamplingPipeline.cs
+++ b/LLama/Sampling/DefaultSamplingPipeline.cs
@@ -133,18 +133,21 @@ protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx,
private static (int, float) GetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
- var nlToken = NativeApi.llama_token_nl(ctx.ModelHandle);
+ var nlToken = ctx.ModelHandle.Tokens.Newline;
- // Try using the ID as an index
- if (candidates.data.Span[(int)nlToken].id == nlToken)
- return ((int)nlToken, candidates.data.Span[(int)nlToken].logit);
-
- // Exhaustive search
- var span = candidates.data.Span;
- for (var i = 0; i < span.Length; i++)
+ if (nlToken.HasValue)
{
- if (span[i].id == nlToken)
- return (i, span[i].logit);
+ // Try using the ID as an index
+ if (candidates.data.Span[(int)nlToken].id == nlToken)
+ return ((int)nlToken, candidates.data.Span[(int)nlToken].logit);
+
+ // Exhaustive search
+ var span = candidates.data.Span;
+ for (var i = 0; i < span.Length; i++)
+ {
+ if (span[i].id == nlToken)
+ return (i, span[i].logit);
+ }
}
return (-1, 0);
@@ -152,7 +155,9 @@ private static (int, float) GetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTok
private static void SetNewlineLogit(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int indexHint, float logit)
{
- var nlToken = NativeApi.llama_token_nl(ctx.ModelHandle);
+ var nlToken = ctx.ModelHandle.Tokens.Newline;
+ if (!nlToken.HasValue)
+ return;
// Try checking the index where we found it last time. It might not be there if `RepetitionPenalty` changed order
if (indexHint >= 0 && candidates.data.Span[indexHint].id == nlToken)
diff --git a/LLama/runtimes/deps/avx/libllama.dll b/LLama/runtimes/deps/avx/libllama.dll
index f09813ae6..3b3b3fae3 100644
Binary files a/LLama/runtimes/deps/avx/libllama.dll and b/LLama/runtimes/deps/avx/libllama.dll differ
diff --git a/LLama/runtimes/deps/avx/libllama.so b/LLama/runtimes/deps/avx/libllama.so
index 6929f1a4a..08dcee0f7 100644
Binary files a/LLama/runtimes/deps/avx/libllama.so and b/LLama/runtimes/deps/avx/libllama.so differ
diff --git a/LLama/runtimes/deps/avx/libllava_shared.so b/LLama/runtimes/deps/avx/libllava_shared.so
index 5868e2ed5..1c8adfcb7 100644
Binary files a/LLama/runtimes/deps/avx/libllava_shared.so and b/LLama/runtimes/deps/avx/libllava_shared.so differ
diff --git a/LLama/runtimes/deps/avx/llama.dll b/LLama/runtimes/deps/avx/llama.dll
index f09813ae6..3b3b3fae3 100644
Binary files a/LLama/runtimes/deps/avx/llama.dll and b/LLama/runtimes/deps/avx/llama.dll differ
diff --git a/LLama/runtimes/deps/avx/llava_shared.dll b/LLama/runtimes/deps/avx/llava_shared.dll
index 546da7588..e08a474b6 100644
Binary files a/LLama/runtimes/deps/avx/llava_shared.dll and b/LLama/runtimes/deps/avx/llava_shared.dll differ
diff --git a/LLama/runtimes/deps/avx2/libllama.dll b/LLama/runtimes/deps/avx2/libllama.dll
index 481be2352..bb8e5c48b 100644
Binary files a/LLama/runtimes/deps/avx2/libllama.dll and b/LLama/runtimes/deps/avx2/libllama.dll differ
diff --git a/LLama/runtimes/deps/avx2/libllama.so b/LLama/runtimes/deps/avx2/libllama.so
index fd2548fd7..e7c27f79f 100644
Binary files a/LLama/runtimes/deps/avx2/libllama.so and b/LLama/runtimes/deps/avx2/libllama.so differ
diff --git a/LLama/runtimes/deps/avx2/libllava_shared.so b/LLama/runtimes/deps/avx2/libllava_shared.so
index 18ca3617d..f9bbdf272 100644
Binary files a/LLama/runtimes/deps/avx2/libllava_shared.so and b/LLama/runtimes/deps/avx2/libllava_shared.so differ
diff --git a/LLama/runtimes/deps/avx2/llama.dll b/LLama/runtimes/deps/avx2/llama.dll
index 481be2352..bb8e5c48b 100644
Binary files a/LLama/runtimes/deps/avx2/llama.dll and b/LLama/runtimes/deps/avx2/llama.dll differ
diff --git a/LLama/runtimes/deps/avx2/llava_shared.dll b/LLama/runtimes/deps/avx2/llava_shared.dll
index f877c590f..6b4ad9c13 100644
Binary files a/LLama/runtimes/deps/avx2/llava_shared.dll and b/LLama/runtimes/deps/avx2/llava_shared.dll differ
diff --git a/LLama/runtimes/deps/avx512/libllama.dll b/LLama/runtimes/deps/avx512/libllama.dll
index 9f3030289..fcbc052eb 100644
Binary files a/LLama/runtimes/deps/avx512/libllama.dll and b/LLama/runtimes/deps/avx512/libllama.dll differ
diff --git a/LLama/runtimes/deps/avx512/libllama.so b/LLama/runtimes/deps/avx512/libllama.so
index 6700ef7cd..a0044ad66 100644
Binary files a/LLama/runtimes/deps/avx512/libllama.so and b/LLama/runtimes/deps/avx512/libllama.so differ
diff --git a/LLama/runtimes/deps/avx512/libllava_shared.so b/LLama/runtimes/deps/avx512/libllava_shared.so
index 8f2a9677f..d0c76ef13 100644
Binary files a/LLama/runtimes/deps/avx512/libllava_shared.so and b/LLama/runtimes/deps/avx512/libllava_shared.so differ
diff --git a/LLama/runtimes/deps/avx512/llama.dll b/LLama/runtimes/deps/avx512/llama.dll
index 9f3030289..fcbc052eb 100644
Binary files a/LLama/runtimes/deps/avx512/llama.dll and b/LLama/runtimes/deps/avx512/llama.dll differ
diff --git a/LLama/runtimes/deps/avx512/llava_shared.dll b/LLama/runtimes/deps/avx512/llava_shared.dll
index e0cfbe44a..8d643cb1d 100644
Binary files a/LLama/runtimes/deps/avx512/llava_shared.dll and b/LLama/runtimes/deps/avx512/llava_shared.dll differ
diff --git a/LLama/runtimes/deps/clblast/libllama.so b/LLama/runtimes/deps/clblast/libllama.so
index 9b5f87900..b6bff999a 100644
Binary files a/LLama/runtimes/deps/clblast/libllama.so and b/LLama/runtimes/deps/clblast/libllama.so differ
diff --git a/LLama/runtimes/deps/clblast/libllava_shared.so b/LLama/runtimes/deps/clblast/libllava_shared.so
index 764e7266d..6f63d183a 100644
Binary files a/LLama/runtimes/deps/clblast/libllava_shared.so and b/LLama/runtimes/deps/clblast/libllava_shared.so differ
diff --git a/LLama/runtimes/deps/clblast/llama.dll b/LLama/runtimes/deps/clblast/llama.dll
index a08951358..055b24a84 100644
Binary files a/LLama/runtimes/deps/clblast/llama.dll and b/LLama/runtimes/deps/clblast/llama.dll differ
diff --git a/LLama/runtimes/deps/clblast/llava_shared.dll b/LLama/runtimes/deps/clblast/llava_shared.dll
index e4a51d0ba..349ec89e5 100644
Binary files a/LLama/runtimes/deps/clblast/llava_shared.dll and b/LLama/runtimes/deps/clblast/llava_shared.dll differ
diff --git a/LLama/runtimes/deps/cu11.7.1/libllama.so b/LLama/runtimes/deps/cu11.7.1/libllama.so
index ef9baa519..1f1a79a0f 100644
Binary files a/LLama/runtimes/deps/cu11.7.1/libllama.so and b/LLama/runtimes/deps/cu11.7.1/libllama.so differ
diff --git a/LLama/runtimes/deps/cu11.7.1/libllava_shared.so b/LLama/runtimes/deps/cu11.7.1/libllava_shared.so
index 7ad6a066e..47cba9b13 100644
Binary files a/LLama/runtimes/deps/cu11.7.1/libllava_shared.so and b/LLama/runtimes/deps/cu11.7.1/libllava_shared.so differ
diff --git a/LLama/runtimes/deps/cu11.7.1/llama.dll b/LLama/runtimes/deps/cu11.7.1/llama.dll
index 22cd79574..a1cd82b25 100644
Binary files a/LLama/runtimes/deps/cu11.7.1/llama.dll and b/LLama/runtimes/deps/cu11.7.1/llama.dll differ
diff --git a/LLama/runtimes/deps/cu11.7.1/llava_shared.dll b/LLama/runtimes/deps/cu11.7.1/llava_shared.dll
index a5d1c514a..00e9794e6 100644
Binary files a/LLama/runtimes/deps/cu11.7.1/llava_shared.dll and b/LLama/runtimes/deps/cu11.7.1/llava_shared.dll differ
diff --git a/LLama/runtimes/deps/cu12.1.0/libllama.so b/LLama/runtimes/deps/cu12.1.0/libllama.so
index ac66c69f8..39b09e6b3 100644
Binary files a/LLama/runtimes/deps/cu12.1.0/libllama.so and b/LLama/runtimes/deps/cu12.1.0/libllama.so differ
diff --git a/LLama/runtimes/deps/cu12.1.0/libllava_shared.so b/LLama/runtimes/deps/cu12.1.0/libllava_shared.so
index 166633a80..ce830a7da 100644
Binary files a/LLama/runtimes/deps/cu12.1.0/libllava_shared.so and b/LLama/runtimes/deps/cu12.1.0/libllava_shared.so differ
diff --git a/LLama/runtimes/deps/cu12.1.0/llama.dll b/LLama/runtimes/deps/cu12.1.0/llama.dll
index b12c7776d..09a87cb20 100644
Binary files a/LLama/runtimes/deps/cu12.1.0/llama.dll and b/LLama/runtimes/deps/cu12.1.0/llama.dll differ
diff --git a/LLama/runtimes/deps/cu12.1.0/llava_shared.dll b/LLama/runtimes/deps/cu12.1.0/llava_shared.dll
index fdef226c3..597733ed3 100644
Binary files a/LLama/runtimes/deps/cu12.1.0/llava_shared.dll and b/LLama/runtimes/deps/cu12.1.0/llava_shared.dll differ
diff --git a/LLama/runtimes/deps/libllama.dll b/LLama/runtimes/deps/libllama.dll
index bd256c0a8..deb86e0df 100644
Binary files a/LLama/runtimes/deps/libllama.dll and b/LLama/runtimes/deps/libllama.dll differ
diff --git a/LLama/runtimes/deps/libllama.so b/LLama/runtimes/deps/libllama.so
index 66176da16..85dce3430 100644
Binary files a/LLama/runtimes/deps/libllama.so and b/LLama/runtimes/deps/libllama.so differ
diff --git a/LLama/runtimes/deps/libllava_shared.so b/LLama/runtimes/deps/libllava_shared.so
index 1c2c17144..f41a9c670 100644
Binary files a/LLama/runtimes/deps/libllava_shared.so and b/LLama/runtimes/deps/libllava_shared.so differ
diff --git a/LLama/runtimes/deps/llama.dll b/LLama/runtimes/deps/llama.dll
index bd256c0a8..deb86e0df 100644
Binary files a/LLama/runtimes/deps/llama.dll and b/LLama/runtimes/deps/llama.dll differ
diff --git a/LLama/runtimes/deps/llava_shared.dll b/LLama/runtimes/deps/llava_shared.dll
index d1aafcad9..10b057945 100644
Binary files a/LLama/runtimes/deps/llava_shared.dll and b/LLama/runtimes/deps/llava_shared.dll differ
diff --git a/LLama/runtimes/deps/osx-arm64/ggml-metal.metal b/LLama/runtimes/deps/osx-arm64/ggml-metal.metal
index 74a5e0b03..9a29f57a3 100644
--- a/LLama/runtimes/deps/osx-arm64/ggml-metal.metal
+++ b/LLama/runtimes/deps/osx-arm64/ggml-metal.metal
@@ -1,3 +1,7 @@
+#define GGML_COMMON_DECL_METAL
+#define GGML_COMMON_IMPL_METAL
+#include "ggml-common.h"
+
#include <metal_stdlib>
using namespace metal;
@@ -6,46 +10,11 @@ using namespace metal;
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
-#define QK4_0 32
-#define QR4_0 2
-typedef struct {
- half d; // delta
- uint8_t qs[QK4_0 / 2]; // nibbles / quants
-} block_q4_0;
-
-#define QK4_1 32
-typedef struct {
- half d; // delta
- half m; // min
- uint8_t qs[QK4_1 / 2]; // nibbles / quants
-} block_q4_1;
-
-#define QK5_0 32
-typedef struct {
- half d; // delta
- uint8_t qh[4]; // 5-th bit of quants
- uint8_t qs[QK5_0 / 2]; // nibbles / quants
-} block_q5_0;
-
-#define QK5_1 32
-typedef struct {
- half d; // delta
- half m; // min
- uint8_t qh[4]; // 5-th bit of quants
- uint8_t qs[QK5_1 / 2]; // nibbles / quants
-} block_q5_1;
-
-#define QK8_0 32
-typedef struct {
- half d; // delta
- int8_t qs[QK8_0]; // quants
-} block_q8_0;
-
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
enum ggml_sort_order {
- GGML_SORT_ASC,
- GGML_SORT_DESC,
+ GGML_SORT_ORDER_ASC,
+ GGML_SORT_ORDER_DESC,
};
// general-purpose kernel for addition, multiplication and division of two tensors
@@ -1959,11 +1928,56 @@ kernel void kernel_pad_f32(
}
}
+kernel void kernel_arange_f32(
+ device char * dst,
+ constant int64_t & ne0,
+ constant float & start,
+ constant float & step,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ device float * dst_ptr = (device float *) dst;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ dst_ptr[i0] = start + step * i0;
+ }
+}
+
+kernel void kernel_timestep_embedding_f32(
+ device const char * src0,
+ device char * dst,
+ constant uint64_t & nb1,
+ constant int & dim,
+ constant int & max_period,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ int i = tgpig.x;
+ device float * embed_data = (device float *)(dst + i*nb1);
+
+ int half_ = dim / 2;
+ for (int j = tpitg.x; j < half_; j += ntg.x) {
+ float timestep = ((device float *)src0)[i];
+ float freq = (float)exp(-log((float)max_period) * j / half_);
+ float arg = timestep * freq;
+ embed_data[j ] = cos(arg);
+ embed_data[j + half_] = sin(arg);
+ }
+
+ if (dim % 2 != 0 && tpitg.x == 0) {
+ embed_data[dim] = 0.f;
+ }
+}
+
// bitonic sort implementation following the CUDA kernels as reference
typedef void (argsort_t)(
- device const float * x,
- device int32_t * dst,
- constant int64_t & ncols,
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ constant int64_t & ncols_pad,
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
@@ -1972,33 +1986,42 @@ kernel void kernel_argsort_f32_i32(
device const float * x,
device int32_t * dst,
constant int64_t & ncols,
+ constant int64_t & ncols_pad,
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]) {
// bitonic sort
int col = tpitg[0];
int row = tgpig[1];
- if (col >= ncols) return;
+ if (col >= ncols_pad) return;
- device const float * x_row = x + row * ncols;
- device int32_t * dst_row = dst + row * ncols;
+ device const float * x_row = x + row * ncols;
+ threadgroup int32_t * dst_row = shared_values;
// initialize indices
- if (col < ncols) {
- dst_row[col] = col;
- }
+ dst_row[col] = col;
+
threadgroup_barrier(mem_flags::mem_threadgroup);
- for (int k = 2; k <= ncols; k *= 2) {
+ for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+ if (dst_row[col] >= ncols ||
+ (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+ ) {
SWAP(dst_row[col], dst_row[ixj]);
}
} else {
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+ if (dst_row[ixj] >= ncols ||
+ (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+ ) {
SWAP(dst_row[col], dst_row[ixj]);
}
}
@@ -2006,10 +2029,15 @@ kernel void kernel_argsort_f32_i32(
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
+
+ // copy the result to dst without the padding
+ if (col < ncols) {
+ dst[row * ncols + col] = dst_row[col];
+ }
}
-template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
-template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
+template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
kernel void kernel_leaky_relu_f32(
device const float * src0,
@@ -2376,6 +2404,242 @@ kernel void kernel_cpy_f32_q4_1(
}
}
+kernel void kernel_cpy_f32_q5_0(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
+
+ device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK5_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -16;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK5_0].d = d;
+
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_0/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK5_0/2 + j]*id;
+
+ const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+ const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+ dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+ }
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+ for (int j = 0; j < 4; ++j) {
+ dst_data[i00/QK5_0].qh[j] = qh8[j];
+ }
+ }
+}
+
+kernel void kernel_cpy_f32_q5_1(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
+
+ device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float max = src[0];
+ float min = src[0];
+
+ for (int j = 1; j < QK5_1; j++) {
+ const float v = src[j];
+ min = v < min ? v : min;
+ max = v > max ? v : max;
+ }
+
+ const float d = (max - min) / 31;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK5_1].d = d;
+ dst_data[i00/QK5_1].m = min;
+
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_1/2; ++j) {
+ const float x0 = (src[0 + j] - min)*id;
+ const float x1 = (src[QK5_1/2 + j] - min)*id;
+
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+ dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+ }
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+ for (int j = 0; j < 4; ++j) {
+ dst_data[i00/QK5_1].qh[j] = qh8[j];
+ }
+ }
+}
+
+static inline int best_index_int8(int n, constant float * val, float x) {
+ if (x <= val[0]) return 0;
+ if (x >= val[n-1]) return n-1;
+ int ml = 0, mu = n-1;
+ while (mu-ml > 1) {
+ int mav = (ml+mu)/2;
+ if (x < val[mav]) mu = mav; else ml = mav;
+ }
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
+constexpr constant static float kvalues_iq4nl_f[16] = {
+ -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
+};
+
+kernel void kernel_cpy_f32_iq4_nl(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
+
+ device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK4_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / kvalues_iq4nl_f[0];
+ const float id = d ? 1.0f/d : 0.0f;
+
+ float sumqx = 0, sumq2 = 0;
+ for (int j = 0; j < QK4_NL/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK4_NL/2 + j]*id;
+
+ const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+ const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+
+ dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
+
+ const float v0 = kvalues_iq4nl_f[xi0];
+ const float v1 = kvalues_iq4nl_f[xi1];
+ const float w0 = src[0 + j]*src[0 + j];
+ const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
+ sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
+ sumq2 += w0*v0*v0 + w1*v1*v1;
+
+ }
+
+ dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
+
+ }
+}
+
kernel void kernel_concat(
device const char * src0,
device const char * src1,
@@ -2432,147 +2696,6 @@ kernel void kernel_concat(
}
}
-//============================================ k-quants ======================================================
-
-#ifndef QK_K
-#define QK_K 256
-#else
-static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
-#endif
-
-#if QK_K == 256
-#define K_SCALE_SIZE 12
-#else
-#define K_SCALE_SIZE 4
-#endif
-
-typedef struct {
- uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
- uint8_t qs[QK_K/4]; // quants
- half d; // super-block scale for quantized scales
- half dmin; // super-block scale for quantized mins
-} block_q2_K;
-// 84 bytes / block
-
-typedef struct {
- uint8_t hmask[QK_K/8]; // quants - high bit
- uint8_t qs[QK_K/4]; // quants - low 2 bits
-#if QK_K == 64
- uint8_t scales[2];
-#else
- uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
-#endif
- half d; // super-block scale
-} block_q3_K;
-
-#if QK_K == 64
-typedef struct {
- half d[2]; // super-block scales/mins
- uint8_t scales[2];
- uint8_t qs[QK_K/2]; // 4-bit quants
-} block_q4_K;
-#else
-typedef struct {
- half d; // super-block scale for quantized scales
- half dmin; // super-block scale for quantized mins
- uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
- uint8_t qs[QK_K/2]; // 4--bit quants
-} block_q4_K;
-#endif
-
-#if QK_K == 64
-typedef struct {
- half d; // super-block scales/mins
- int8_t scales[QK_K/16]; // 8-bit block scales
- uint8_t qh[QK_K/8]; // quants, high bit
- uint8_t qs[QK_K/2]; // quants, low 4 bits
-} block_q5_K;
-#else
-typedef struct {
- half d; // super-block scale for quantized scales
- half dmin; // super-block scale for quantized mins
- uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
- uint8_t qh[QK_K/8]; // quants, high bit
- uint8_t qs[QK_K/2]; // quants, low 4 bits
-} block_q5_K;
-// 176 bytes / block
-#endif
-
-typedef struct {
- uint8_t ql[QK_K/2]; // quants, lower 4 bits
- uint8_t qh[QK_K/4]; // quants, upper 2 bits
- int8_t scales[QK_K/16]; // scales, quantized with 8 bits
- half d; // super-block scale
-} block_q6_K;
-// 210 bytes / block
-
-typedef struct {
- half d;
- uint16_t qs[QK_K/8];
-} block_iq2_xxs;
-// 66 bytes / block for QK_K = 256, so 2.0625 bpw
-
-typedef struct {
- half d;
- uint16_t qs[QK_K/8];
- uint8_t scales[QK_K/32];
-} block_iq2_xs;
-// 74 bytes / block for QK_K = 256, so 2.3125 bpw
-
-// 2.5625 bpw quants
-typedef struct {
- half d;
- uint8_t qs[QK_K/4];
- uint8_t qh[QK_K/32];
- uint8_t scales[QK_K/32];
-} block_iq2_s;
-
-typedef struct {
- half d;
- uint8_t qs[3*QK_K/8];
-} block_iq3_xxs;
-// 98 bytes / block for QK_K = 256, so 3.0625 bpw
-
-// 3.4375 bpw
-#if QK_K == 64
-#define IQ3S_N_SCALE 2
-#else
-#define IQ3S_N_SCALE QK_K/64
-#endif
-typedef struct {
- half d;
- uint8_t qs[QK_K/4];
- uint8_t qh[QK_K/32];
- uint8_t signs[QK_K/8];
- uint8_t scales[IQ3S_N_SCALE];
-} block_iq3_s;
-
-typedef struct {
- half d;
- uint8_t qs[QK_K/8];
- uint8_t scales[QK_K/16];
-} block_iq1_s;
-
-// Non-linear quants
-#define QK4_NL 32
-typedef struct {
- half d;
- uint8_t qs[QK4_NL/2];
-} block_iq4_nl;
-
-#if QK_K == 64
-#define block_iq4_xs block_iq4_nl
-#else
-typedef struct {
- half d;
- uint16_t scales_h;
- uint8_t scales_l[QK_K/64];
- uint8_t qs[QK_K/2];
-} block_iq4_xs;
-#endif
-
-//====================================== dot products =========================
-
void kernel_mul_mv_q2_K_f32_impl(
device const void * src0,
device const float * src1,
@@ -3595,710 +3718,6 @@ kernel void kernel_mul_mv_q6_K_f32(
// ======================= "True" 2-bit
-constexpr constant static uint64_t iq2xxs_grid[256] = {
- 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
- 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
- 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
- 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
- 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
- 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
- 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
- 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
- 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
- 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
- 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
- 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
- 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
- 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
- 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
- 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
- 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
- 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
- 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
- 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
- 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
- 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
- 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
- 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
- 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
- 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
- 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
- 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
- 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
- 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
- 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
- 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
- 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
- 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
- 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
- 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
- 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
- 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
- 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
- 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
- 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
- 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
- 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
- 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
- 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
- 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
- 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
- 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
- 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
- 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
- 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
- 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
- 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
- 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
- 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
- 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
- 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
- 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
- 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
- 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
- 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
- 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
- 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
- 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
-};
-
-constexpr constant static uint64_t iq2xs_grid[512] = {
- 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
- 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
- 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
- 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
- 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
- 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
- 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
- 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
- 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
- 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
- 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
- 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
- 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
- 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
- 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
- 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
- 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
- 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
- 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
- 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
- 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
- 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
- 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
- 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
- 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
- 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
- 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
- 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
- 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
- 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
- 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
- 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
- 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
- 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
- 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
- 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
- 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
- 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
- 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
- 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
- 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
- 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
- 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
- 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
- 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
- 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
- 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
- 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
- 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
- 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
- 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
- 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
- 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
- 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
- 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
- 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
- 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
- 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
- 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
- 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
- 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
- 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
- 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
- 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
- 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
- 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
- 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
- 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
- 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
- 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
- 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
- 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
- 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
- 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
- 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
- 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
- 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
- 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
- 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
- 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
- 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
- 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
- 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
- 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
- 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
- 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
- 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
- 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
- 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
- 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
- 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
- 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
- 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
- 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
- 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
- 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
- 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
- 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
- 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
- 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
- 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
- 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
- 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
- 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
- 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
- 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
- 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
- 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
- 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
- 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
- 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
- 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
- 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
- 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
- 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
- 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
- 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
- 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
- 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
- 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
- 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
- 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
- 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
- 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
- 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
- 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
- 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
- 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
-};
-
-constexpr constant static uint64_t iq2s_grid[1024] = {
- 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
- 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
- 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
- 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
- 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
- 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
- 0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
- 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
- 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
- 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
- 0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
- 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
- 0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
- 0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
- 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
- 0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
- 0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
- 0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
- 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
- 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
- 0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
- 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
- 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
- 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
- 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
- 0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
- 0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
- 0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
- 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
- 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
- 0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
- 0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
- 0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
- 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
- 0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
- 0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
- 0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
- 0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
- 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
- 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
- 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
- 0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
- 0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
- 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
- 0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
- 0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
- 0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
- 0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
- 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
- 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
- 0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
- 0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
- 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
- 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
- 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
- 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
- 0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
- 0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
- 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
- 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
- 0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
- 0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
- 0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
- 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
- 0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
- 0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
- 0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
- 0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
- 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
- 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
- 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
- 0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
- 0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
- 0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
- 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
- 0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
- 0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
- 0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
- 0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
- 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
- 0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
- 0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
- 0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
- 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
- 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
- 0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
- 0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
- 0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
- 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
- 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
- 0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
- 0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
- 0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
- 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
- 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
- 0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
- 0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
- 0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
- 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
- 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
- 0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
- 0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,
- 0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,
- 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,
- 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
- 0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,
- 0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,
- 0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,
- 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,
- 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
- 0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,
- 0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,
- 0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,
- 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,
- 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
- 0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,
- 0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,
- 0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
- 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,
- 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
- 0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,
- 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,
- 0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,
- 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,
- 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
- 0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,
- 0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,
- 0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,
- 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,
- 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
- 0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,
- 0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819,
- 0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,
- 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,
- 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
- 0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,
- 0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,
- 0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,
- 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,
- 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
- 0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,
- 0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,
- 0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,
- 0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,
- 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
- 0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,
- 0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,
- 0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,
- 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,
- 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
- 0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,
- 0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,
- 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,
- 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,
- 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
- 0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,
- 0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,
- 0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,
- 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,
- 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
- 0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,
- 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,
- 0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,
- 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,
- 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
- 0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,
- 0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,
- 0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,
- 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,
- 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
- 0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,
- 0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,
- 0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,
- 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,
- 0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
- 0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,
- 0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,
- 0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,
- 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819,
- 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
- 0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,
- 0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,
- 0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,
- 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,
- 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
- 0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,
- 0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,
- 0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,
- 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,
- 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
- 0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,
- 0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,
- 0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,
- 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,
- 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
- 0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,
- 0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,
- 0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,
- 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,
- 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
- 0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,
- 0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,
- 0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,
- 0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,
- 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
- 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
- 0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,
- 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,
- 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,
- 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
- 0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,
- 0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,
- 0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,
- 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,
- 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
- 0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,
- 0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,
- 0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,
- 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,
- 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
- 0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,
- 0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,
- 0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,
- 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,
- 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
- 0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919,
- 0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,
- 0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,
- 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,
- 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
- 0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,
- 0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,
- 0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,
- 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,
- 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
- 0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,
- 0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,
- 0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,
- 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,
- 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
- 0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,
- 0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,
- 0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,
- 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,
- 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
- 0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,
- 0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,
- 0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,
- 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,
- 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
- 0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,
- 0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,
- 0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,
- 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,
- 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
- 0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
-};
-
-constexpr constant static uint32_t iq3xxs_grid[256] = {
- 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
- 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
- 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
- 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
- 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
- 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
- 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
- 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
- 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
- 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
- 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
- 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
- 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
- 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
- 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
- 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
- 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
- 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
- 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
- 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
- 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
- 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
- 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
- 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
- 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
- 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
- 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
- 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
- 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
- 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
- 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
- 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
-};
-
-constexpr constant static uint32_t iq3xs_grid[512] = {
- 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
- 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
- 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
- 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
- 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
- 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
- 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
- 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
- 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
- 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
- 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
- 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
- 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
- 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
- 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
- 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
- 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
- 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
- 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
- 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
- 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
- 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
- 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
- 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
- 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
- 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
- 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
- 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
- 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
- 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
- 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
- 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
- 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
- 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
- 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
- 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
- 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
- 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
- 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
- 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
- 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
- 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
- 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
- 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
- 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
- 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
- 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
- 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
- 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
- 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
- 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
- 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
- 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
- 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
- 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
- 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
- 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
- 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
- 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
- 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
- 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
- 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
- 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
- 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
-};
-
-#define NGRID_IQ1S 512
-constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
- 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
- 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
- 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
- 0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
- 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
- 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
- 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
- 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
- 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
- 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
- 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
- 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
- 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
- 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
- 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
- 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
- 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
- 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
- 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
- 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
- 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
- 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
- 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
- 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
- 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
- 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
- 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
- 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
- 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
- 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
- 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
- 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
- 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
- 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
- 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
- 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
- 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
- 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
- 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
- 0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
- 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
- 0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
- 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
- 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
- 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
- 0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
- 0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
- 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
- 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
- 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
- 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
- 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
- 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
- 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
- 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
- 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
- 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
- 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
- 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
- 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
- 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
- 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
- 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
- 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
- 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
- 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
- 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
- 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
- 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
- 0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
- 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
- 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
- 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
- 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
- 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
- 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
- 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
- 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
- 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
- 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
- 0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
- 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
- 0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
- 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
- 0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
- 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
- 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
- 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
- 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
- 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
- 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
- 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
- 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
- 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
- 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
- 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
- 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
- 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
- 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
- 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
- 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
- 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
- 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
- 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
- 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
- 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
- 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
- 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
- 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
- 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
- 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
- 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
- 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
- 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
- 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
- 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
- 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
- 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
- 0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
- 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
- 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
- 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
- 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
- 0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
- 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
- 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
- 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
- 0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
-};
-
-constexpr constant static uint8_t ksigns_iq2xs[128] = {
- 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
- 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
- 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
- 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
- 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
- 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
- 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
- 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
-};
-
-constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
-
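A side note on the ksigns_iq2xs table removed just above (an editorial observation, not something the patch states): each byte keeps the 7-bit sign pattern in its low bits and uses bit 7 as a parity bit, so every entry has an even popcount. A minimal C++20 sketch checking that property against the first row of the table:

    #include <bit>
    #include <cassert>
    #include <cstdint>

    int main() {
        const uint8_t first_row[8] = {0, 129, 130, 3, 132, 5, 6, 135}; // from the removed table
        for (uint32_t i = 0; i < 128; ++i) {
            // low 7 bits = sign pattern i, bit 7 set so the byte has even parity
            const uint8_t entry = uint8_t(i | ((std::popcount(i) & 1) ? 0x80u : 0x00u));
            assert(std::popcount(unsigned(entry)) % 2 == 0);
            if (i < 8) assert(entry == first_row[i]);
        }
        return 0;
    }
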
void kernel_mul_mv_iq2_xxs_f32_impl(
device const void * src0,
device const float * src1,
@@ -4742,7 +4161,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
{
int nval = 8;
int pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i];
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
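For orientation (editorial context, not part of the patch): the hunk above only swaps the source table of the cooperative copy into threadgroup memory from iq3xs_grid to iq3s_grid; the copy pattern itself is unchanged. Assuming two simdgroups of 32 threads per threadgroup (N_SIMDGROUP elsewhere in this file) and nval = 8, each thread writes 8 consecutive entries and the 64 threads cover a 512-entry grid exactly once. A small C++ sketch of that coverage:

    #include <cassert>
    #include <vector>

    int main() {
        const int nval = 8;            // entries copied per thread, as in the kernel
        const int n_simdgroups = 2;    // assumed threadgroup configuration
        const int simd_width = 32;
        std::vector<int> hits(512, 0); // slots of the shared grid copy

        for (int sgitg = 0; sgitg < n_simdgroups; ++sgitg)
            for (int tiisg = 0; tiisg < simd_width; ++tiisg) {
                const int pos = (32*sgitg + tiisg)*nval;
                for (int i = 0; i < nval; ++i)
                    ++hits[pos + i];   // mirrors values[pos + i] = iq3s_grid[pos + i]
            }

        for (int h : hits) assert(h == 1); // every slot written exactly once
        return 0;
    }
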
@@ -4769,12 +4188,14 @@ void kernel_mul_mv_iq3_s_f32_impl(
for (int row = 0; row < N_DST; row++) {
const float db = dh[0];
- const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf));
+ const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
float2 sum = {0};
for (int l = 0; l < 4; ++l) {
- const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
- const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
+ const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
+ const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
+ const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
+ const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
for (int j = 0; j < 4; ++j) {
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
@@ -4795,7 +4216,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
}
}
}
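Two remarks on the hunk above (editorial context, not claims made by the patch). First, the new two-table lookup is an equivalent rewrite of the old 9-bit index: bit 2l (resp. 2l+1) of qh used to be OR-ed in as bit 8 of the grid index, and now the same bit selects the upper or lower 256-entry half of the shared table. Second, since (0.5f + s)*0.5f == (1 + 2*s)/4, the new per-group multiplier together with the dropped final *0.5f is exactly 4x the old one, presumably absorbed by the block scale in the updated IQ3_S format. A C++ sketch of the indexing equivalence, using kmask_iq2xs = {1, 2, 4, ..., 128} as in the table above:

    #include <cassert>
    #include <cstdint>

    int main() {
        uint32_t values[512];
        for (uint32_t i = 0; i < 512; ++i) values[i] = 0x01010101u*i; // stand-in grid contents

        const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};

        for (uint32_t qh = 0; qh < 256; ++qh)
            for (uint32_t qs = 0; qs < 256; ++qs)
                for (int l = 0; l < 4; ++l) {
                    // old form: fold bit 2l of qh into bit 8 of the table index
                    const uint32_t old_idx = qs | ((qh << (8 - 2*l)) & 256);
                    // new form: the same bit picks the upper or lower half of the table
                    const uint32_t *table = (qh & kmask_iq2xs[2*l]) ? values + 256 : values;
                    assert(values[old_idx] == table[qs]);
                }
        return 0;
    }
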
@@ -4994,48 +4415,53 @@ void kernel_mul_mv_iq1_s_f32_impl(
device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
- float yl[16];
+ float yl[32];
float sumf[N_DST]={0.f}, all_sum;
const int nb32 = nb * (QK_K / 32);
- const int ix = tiisg/2;
- const int il = tiisg%2;
+ const int ix = tiisg;
- device const float * y4 = y + 32 * ix + 16 * il;
+ device const float * y4 = y + 32 * ix;
- for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
- for (int i = 0; i < 16; ++i) {
+ float sumy = 0;
+ for (int i = 0; i < 32; ++i) {
yl[i] = y4[i];
+ sumy += yl[i];
}
const int ibl = ib32 / (QK_K / 32);
const int ib = ib32 % (QK_K / 32);
device const block_iq1_s * xr = x + ibl;
- device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
- device const uint8_t * sc = xr->scales + 2 * ib + il;
- device const half * dh = &xr->d;
+ device const uint8_t * qs = xr->qs + 4 * ib;
+ device const uint16_t * qh = xr->qh + ib;
+ device const half * dh = &xr->d;
for (int row = 0; row < N_DST; row++) {
- constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
- constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
+ constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
+ constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
- float2 sum = {0};
- for (int j = 0; j < 8; ++j) {
- sum[0] += yl[j+ 0] * grid1[j];
- sum[1] += yl[j+ 8] * grid2[j];
+ float sum = 0;
+ for (int j = 0; j < 4; ++j) {
+ sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+ + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
+ + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+ + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
}
- sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
+ sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
dh += nb*sizeof(block_iq1_s)/2;
qs += nb*sizeof(block_iq1_s);
- sc += nb*sizeof(block_iq1_s);
+ qh += nb*sizeof(block_iq1_s)/2;
}
- y4 += 16 * 32;
+ y4 += 32 * 32;
}
for (int row = 0; row < N_DST; ++row) {
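Editorial note on the rewritten iq1_s loop above: the kernel now accumulates sumy because the additive shift applied to every unpacked grid value in a 32-wide group (-1 - IQ1S_DELTA or -1 + IQ1S_DELTA, selected by bit 15 of qh) is constant within the group, so it can be factored out of the dot product as delta*sumy. A small C++ sketch of that identity, with IQ1S_DELTA assumed to be 0.125f and made-up inputs:

    #include <cassert>
    #include <cmath>

    int main() {
        const float IQ1S_DELTA = 0.125f;     // assumed value of the constant
        float y[32], g[32];                  // activations and unpacked grid nibbles (0..2)
        for (int i = 0; i < 32; ++i) { y[i] = 0.01f*i - 0.1f; g[i] = float(i % 3); }

        const bool  sign  = true;            // stands in for qh[0] & 0x8000
        const float delta = sign ? -1.0f - IQ1S_DELTA : -1.0f + IQ1S_DELTA;

        // direct form: shift every grid value before the dot product
        float direct = 0.0f;
        for (int i = 0; i < 32; ++i) direct += y[i]*(g[i] + delta);

        // factored form used by the kernel: dot(y, g) + delta*sum(y)
        float dot = 0.0f, sumy = 0.0f;
        for (int i = 0; i < 32; ++i) { dot += y[i]*g[i]; sumy += y[i]; }

        assert(std::fabs(direct - (dot + delta*sumy)) < 1e-5f);
        return 0;
    }
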
@@ -5046,9 +4472,113 @@ void kernel_mul_mv_iq1_s_f32_impl(
}
}
-constexpr constant static float kvalues_iq4nl_f[16] = {
- -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
-};
+void kernel_mul_mv_iq1_m_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+#if QK_K != 64
+ iq1m_scale_t scale;
+#endif
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ float4 sumy = {0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+ yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
+ yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
+ yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq1_m * xr = x + ibl;
+ device const uint8_t * qs = xr->qs + 4 * ib;
+ device const uint8_t * qh = xr->qh + 2 * ib;
+ device const uint16_t * sc = (device const uint16_t *)xr->scales;
+
+ for (int row = 0; row < N_DST; row++) {
+
+#if QK_K != 64
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+#endif
+
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+ constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
+ constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
+
+ float2 sum = {0.f};
+ for (int j = 0; j < 4; ++j) {
+ sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+ + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
+ sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+ + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
+ }
+ const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+ const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+#if QK_K == 64
+ const float d = (float) *((device const half *)(sc - 1));
+ sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
+ (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
+#else
+ sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
+ (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
+#endif
+
+ sc += nb*sizeof(block_iq1_m)/2;
+ qs += nb*sizeof(block_iq1_m);
+ qh += nb*sizeof(block_iq1_m);
+ }
+
+ y4 += 32 * 32;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ }
+ }
+}
void kernel_mul_mv_iq4_nl_f32_impl(
device const void * src0,
@@ -5267,6 +4797,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
}
+[[host_name("kernel_mul_mv_iq1_m_f32")]]
+kernel void kernel_mul_mv_iq1_m_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
kernel void kernel_mul_mv_iq4_nl_f32(
device const void * src0,
@@ -5685,15 +5243,15 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
device const uint8_t * qs = xb->qs + 8*ib32;
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
const uint8_t qh = xb->qh[ib32] >> 4*il;
- const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f;
- constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256)));
- constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256)));
+ const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
+ constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
}
- grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256)));
- grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256)));
+ grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
+ grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
for (int i = 0; i < 4; ++i) {
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
@@ -5722,16 +5280,53 @@ void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 &
template <typename type4x4>
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const int ib32 = il/2;
+ il = il%2;
const float d = xb->d;
- device const uint8_t * qs = xb->qs + 2*il;
- device const uint8_t * sc = xb->scales + il;
- const float dl1 = d * (2*(sc[0] & 7) + 1);
- const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1);
- constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
- constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
- for (int i = 0; i < 8; ++i) {
- reg[i/4+0][i%4] = dl1 * grid1[i];
- reg[i/4+2][i%4] = dl2 * grid2[i];
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+ device const uint16_t * qh = xb->qh;
+ const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
+ const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
+ const uint16_t h = qh[ib32] >> 6*il;
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
+ for (int i = 0; i < 4; ++i) {
+ reg[0][i] = dl * (grid1[i] & 0xf) + ml;
+ reg[1][i] = dl * (grid1[i] >> 4) + ml;
+ reg[2][i] = dl * (grid2[i] & 0xf) + ml;
+ reg[3][i] = dl * (grid2[i] >> 4) + ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const int ib32 = il/2;
+ il = il%2;
+ device const uint16_t * sc = (device const uint16_t *)xb->scales;
+#if QK_K == 64
+ const float d = xb->d;
+#else
+ iq1m_scale_t scale;
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+ const float d = scale.f16;
+#endif
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+ device const uint8_t * qh = xb->qh + 2*ib32 + il;
+#if QK_K == 64
+ const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
+#else
+ const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
+#endif
+ const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+ const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+ for (int i = 0; i < 4; ++i) {
+ reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
+ reg[1][i] = dl * (grid1[i] >> 4) + ml1;
+ reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
+ reg[3][i] = dl * (grid2[i] >> 4) + ml2;
}
}
@@ -6042,7 +5637,7 @@ template
kernel void kernel_mul_mm_id(
- device const uchar * ids,
+ device const uchar * src0s,
device const uchar * src1,
device float * dst,
+ device const uchar * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
@@ -6225,29 +5821,21 @@ kernel void kernel_mul_mm_id(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const uchar * src00,
- device const uchar * src01,
- device const uchar * src02,
- device const uchar * src03,
- device const uchar * src04,
- device const uchar * src05,
- device const uchar * src06,
- device const uchar * src07,
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
// expert id
const int32_t id = tgpig.z/(ne12*ne13);
+ device const uchar * src0 = src0s + id*nb02;
tgpig.z = tgpig.z%(ne12*ne13);
// row indices of src1 for expert id
- int64_t _ne1 = 0;
- short src1ids[512];
+ threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192);
+ int64_t _ne1 = 0;
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
src1ids[_ne1++] = i1;
@@ -6255,7 +5843,7 @@ kernel void kernel_mul_mm_id(
}
kernel_mul_mm_id_impl(
- src0s[id],
+ src0,
src1,
src1ids,
dst,
@@ -6319,6 +5907,7 @@ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_r
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows;
#if QK_K == 64
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows;
@@ -6349,24 +5938,25 @@ typedef void (mat_mm_t)(
threadgroup uchar *,
uint3, uint, uint);
-template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm;
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm;
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm;
-template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm;
-template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm;
-template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm;
-template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm;
-template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm;
-template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm;
-template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm;
-template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm;
-template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm;
#if QK_K == 64
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm;
@@ -6379,9 +5969,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
//
typedef void (mat_mm_id_t)(
- device const uchar * ids,
+ device const uchar * src0s,
device const uchar * src1,
device float * dst,
+ device const uchar * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
@@ -6398,35 +5989,28 @@ typedef void (mat_mm_id_t)(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const uchar * src00,
- device const uchar * src01,
- device const uchar * src02,
- device const uchar * src03,
- device const uchar * src04,
- device const uchar * src05,
- device const uchar * src06,
- device const uchar * src07,
threadgroup uchar *,
uint3, uint, uint);
-template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
-template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
#if QK_K == 64
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
@@ -6440,9 +6024,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
[[host_name("kernel_mul_mv_id_f32_f32")]]
kernel void kernel_mul_mv_id_f32_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -6463,28 +6048,19 @@ kernel void kernel_mul_mv_id_f32_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_f32_f32_impl(
- src0[id],
+ src0,
src1 + bid*nb11,
dst + bid*ne0,
ne00,
@@ -6509,9 +6085,10 @@ kernel void kernel_mul_mv_id_f32_f32(
[[host_name("kernel_mul_mv_id_f16_f32")]]
kernel void kernel_mul_mv_id_f16_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -6532,28 +6109,19 @@ kernel void kernel_mul_mv_id_f16_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_f16_f32_impl(
- src0[id],
+ src0,
src1 + bid*nb11,
dst + bid*ne0,
ne00,
@@ -6578,9 +6146,10 @@ kernel void kernel_mul_mv_id_f16_f32(
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
kernel void kernel_mul_mv_id_q8_0_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -6601,28 +6170,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_q8_0_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -6641,9 +6201,10 @@ kernel void kernel_mul_mv_id_q8_0_f32(
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
kernel void kernel_mul_mv_id_q4_0_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -6664,28 +6225,19 @@ kernel void kernel_mul_mv_id_q4_0_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
mul_vec_q_n_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -6704,9 +6256,10 @@ kernel void kernel_mul_mv_id_q4_0_f32(
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
kernel void kernel_mul_mv_id_q4_1_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -6727,28 +6280,19 @@ kernel void kernel_mul_mv_id_q4_1_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
mul_vec_q_n_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -6767,9 +6311,10 @@ kernel void kernel_mul_mv_id_q4_1_f32(
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
kernel void kernel_mul_mv_id_q5_0_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -6790,28 +6335,19 @@ kernel void kernel_mul_mv_id_q5_0_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
mul_vec_q_n_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -6830,9 +6366,10 @@ kernel void kernel_mul_mv_id_q5_0_f32(
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
kernel void kernel_mul_mv_id_q5_1_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -6853,28 +6390,19 @@ kernel void kernel_mul_mv_id_q5_1_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
mul_vec_q_n_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -6893,9 +6421,10 @@ kernel void kernel_mul_mv_id_q5_1_f32(
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
kernel void kernel_mul_mv_id_q2_K_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -6916,28 +6445,19 @@ kernel void kernel_mul_mv_id_q2_K_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_q2_K_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -6956,9 +6476,10 @@ kernel void kernel_mul_mv_id_q2_K_f32(
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
kernel void kernel_mul_mv_id_q3_K_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -6979,28 +6500,19 @@ kernel void kernel_mul_mv_id_q3_K_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_q3_K_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -7019,9 +6531,10 @@ kernel void kernel_mul_mv_id_q3_K_f32(
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
kernel void kernel_mul_mv_id_q4_K_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -7042,28 +6555,19 @@ kernel void kernel_mul_mv_id_q4_K_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_q4_K_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -7082,9 +6586,10 @@ kernel void kernel_mul_mv_id_q4_K_f32(
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
kernel void kernel_mul_mv_id_q5_K_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -7105,28 +6610,19 @@ kernel void kernel_mul_mv_id_q5_K_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_q5_K_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -7145,9 +6641,10 @@ kernel void kernel_mul_mv_id_q5_K_f32(
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
kernel void kernel_mul_mv_id_q6_K_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -7168,28 +6665,19 @@ kernel void kernel_mul_mv_id_q6_K_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_q6_K_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -7208,9 +6696,10 @@ kernel void kernel_mul_mv_id_q6_K_f32(
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
kernel void kernel_mul_mv_id_iq2_xxs_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -7231,29 +6720,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_iq2_xxs_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -7273,9 +6753,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
kernel void kernel_mul_mv_id_iq2_xs_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -7296,29 +6777,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_iq2_xs_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -7338,9 +6810,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
kernel void kernel_mul_mv_id_iq3_xxs_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -7361,29 +6834,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_iq3_xxs_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -7403,9 +6867,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
kernel void kernel_mul_mv_id_iq3_s_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -7426,29 +6891,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_iq3_s_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -7468,9 +6924,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
kernel void kernel_mul_mv_id_iq2_s_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -7491,29 +6948,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_iq2_s_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -7533,9 +6981,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
kernel void kernel_mul_mv_id_iq1_s_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -7556,28 +7005,74 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_iq1_s_f32_impl(
- src0[id],
+ src0,
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
+kernel void kernel_mul_mv_id_iq1_m_f32(
+ device const char * src0s,
+ device const char * src1,
+ device float * dst,
+ device const char * ids,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
+
+ kernel_mul_mv_iq1_m_f32_impl(
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -7596,9 +7091,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
kernel void kernel_mul_mv_id_iq4_nl_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -7619,29 +7115,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
threadgroup float * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_iq4_nl_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
@@ -7661,9 +7148,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
kernel void kernel_mul_mv_id_iq4_xs_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -7684,33 +7172,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
threadgroup float * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
#if QK_K == 64
kernel_mul_mv_iq4_nl_f32_impl(
#else
kernel_mul_mv_iq4_xs_f32_impl(
#endif
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
diff --git a/LLama/runtimes/deps/osx-arm64/libllama.dylib b/LLama/runtimes/deps/osx-arm64/libllama.dylib
index 87295f843..b474697df 100644
Binary files a/LLama/runtimes/deps/osx-arm64/libllama.dylib and b/LLama/runtimes/deps/osx-arm64/libllama.dylib differ
diff --git a/LLama/runtimes/deps/osx-arm64/libllava_shared.dylib b/LLama/runtimes/deps/osx-arm64/libllava_shared.dylib
index 84ff71671..c47c71ed0 100644
Binary files a/LLama/runtimes/deps/osx-arm64/libllava_shared.dylib and b/LLama/runtimes/deps/osx-arm64/libllava_shared.dylib differ
diff --git a/LLama/runtimes/deps/osx-x64/libllama.dylib b/LLama/runtimes/deps/osx-x64/libllama.dylib
index df9c7280c..bbd84c9c2 100644
Binary files a/LLama/runtimes/deps/osx-x64/libllama.dylib and b/LLama/runtimes/deps/osx-x64/libllama.dylib differ
diff --git a/LLama/runtimes/deps/osx-x64/libllava_shared.dylib b/LLama/runtimes/deps/osx-x64/libllava_shared.dylib
index ac9bd1eca..c6b265ff4 100644
Binary files a/LLama/runtimes/deps/osx-x64/libllava_shared.dylib and b/LLama/runtimes/deps/osx-x64/libllava_shared.dylib differ
diff --git a/llama.cpp b/llama.cpp
index 3ab8b3a92..f7001ccc5 160000
--- a/llama.cpp
+++ b/llama.cpp
@@ -1 +1 @@
-Subproject commit 3ab8b3a92ede46df88bc5a2dfca3777de4a2b2b6
+Subproject commit f7001ccc5aa359fcf41bba19d1c99c3d25c9bcc7