diff --git a/README.md b/README.md index 7a2ecba..502819c 100644 --- a/README.md +++ b/README.md @@ -80,10 +80,11 @@ public int TiktokenSharp() + | Method | Job | Runtime | Mean | Error | StdDev | Gen0 | Allocated | |-------------- |--------- |--------- |----------:|---------:|---------:|----------:|-----------:| -| SharpToken | .NET 8.0 | .NET 8.0 | 112.86 ms | 0.712 ms | 0.595 ms | 2600.0000 | 23202285 B | -| TiktokenSharp | .NET 8.0 | .NET 8.0 | 99.40 ms | 0.179 ms | 0.149 ms | 9800.0000 | 82321296 B | +| SharpToken | .NET 8.0 | .NET 8.0 | 116.38 ms | 1.026 ms | 0.909 ms | 2000.0000 | 23201696 B | +| TiktokenSharp | .NET 8.0 | .NET 8.0 | 98.34 ms | 0.198 ms | 0.176 ms | 9833.3333 | 82321080 B | ## Update diff --git a/TiktokenSharp.Test/Program.cs b/TiktokenSharp.Test/Program.cs index 28a977d..b5e36eb 100644 --- a/TiktokenSharp.Test/Program.cs +++ b/TiktokenSharp.Test/Program.cs @@ -34,7 +34,7 @@ static void GPT4() Debug.Assert(i.IsEqualTo(new List() { 15339, 1917 })); Debug.Assert(tikToken.Decode(new List() { 15339, 1917 }) == "hello world"); - var c = tikToken.Encode("hello <|endoftext|>"); + var c = tikToken.Encode("hello <|endoftext|>", new HashSet() { "<|endoftext|>" }); Debug.Assert(c.IsEqualTo(new List() { 15339, 220, 100257 })); var t1 = tikToken.Encode("我很抱歉,我不能提供任何非法或不道德的建议。快速赚钱是不容易的,需要耐心、刻苦努力和经验。如果您想增加收入,请考虑增加工作时间、寻找其他业务机会、学习新技能或提高自己的价值等方法。请记住,通过合法而道德的方式来获得收入,才是长期稳定的解决方案。"); diff --git a/TiktokenSharp/CoreBPE.cs b/TiktokenSharp/CoreBPE.cs index def5056..28e46ef 100644 --- a/TiktokenSharp/CoreBPE.cs +++ b/TiktokenSharp/CoreBPE.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; using System.Linq; @@ -12,9 +13,12 @@ public class CoreBPE { private Dictionary _specialTokensEncoder { get; set; } - // TODO private max_token_value ?? private Dictionary, int> _encoder { get; set; } + // TODO Cache? + //private ConcurrentDictionary> _cache { get; set; } + //private MemoryCache _cache = MemoryCache.Default; + private Regex _specialRegex { get; set; } private Regex _regex { get; set; } @@ -37,6 +41,7 @@ public class CoreBPE public CoreBPE(Dictionary, int> encoder, Dictionary specialTokensEncoder, string pattern) { _encoder = encoder; + //_cache = new ConcurrentDictionary>(new ReadOnlyMemoryComparer()); _regex = new Regex(pattern, RegexOptions.Compiled); _specialRegex = new Regex(string.Join("|", specialTokensEncoder.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled); _specialTokensEncoder = specialTokensEncoder; @@ -61,16 +66,13 @@ public CoreBPE(Dictionary, int> encoder, Dictionary, int) EncodeNative(string text, HashSet allowedSpecial, HashSet disallowedSpecial) { - Regex specialRegex = _specialRegex; - Regex regex = _regex; var ret = new List(); - ReadOnlySpan textSpan = text.AsSpan(); + var textSpan = text.AsMemory(); int lastPieceTokenLen = 0; int currentIndex = 0; - var enumerator = specialRegex.EnumerateMatches(textSpan); - + var enumerator = _specialRegex.EnumerateMatches(textSpan.Span); while (currentIndex < text.Length) { @@ -78,13 +80,13 @@ public CoreBPE(Dictionary, int> encoder, Dictionary, int> encoder, Dictionary currentSpan = textSpan.Slice(currentIndex, nextMatchStart - currentIndex); - foreach (var match in regex.EnumerateMatches(currentSpan)) + //read only + + ReadOnlyMemory currentSpan = textSpan.Slice(currentIndex, nextMatchStart - currentIndex); + foreach (var match in _regex.EnumerateMatches(currentSpan.Span)) { - var piece = Encoding.UTF8.GetBytes(currentSpan.Slice(match.Index, match.Length).ToString()); + var charSpan = currentSpan.Slice(match.Index, match.Length); + //var byteSpan = ByteHelper.ConvertReadOnlyMemoryCharToByte(charSpan); + + var piece = Encoding.UTF8.GetBytes(charSpan.ToString()); //TODO remove ToString if (_encoder.TryGetValue(piece, out int token)) { lastPieceTokenLen = 1; @@ -104,9 +111,19 @@ public CoreBPE(Dictionary, int> encoder, Dictionary cacheToken)) + //{ + // ret.AddRange(cacheToken); + // continue; + //} + var tokens = BytePairEncoding.BytePairEncode(piece, _encoder); lastPieceTokenLen = tokens.Count; ret.AddRange(tokens); + + //_cache[piece] = tokens; } } @@ -116,7 +133,7 @@ public CoreBPE(Dictionary, int> encoder, Dictionary Encode(string text, HashSet allowedSpecial = null, HashSet disallowedSpecial = null) { +//#if NET7_0_OR_GREATER +// HashSet>? allowedSpecialMemory = null; +// HashSet>? disallowedSpecialMemory = null; + +//#else +// List>? allowedSpecialMemory = null; +// List>? disallowedSpecialMemory = null; +//#endif +// if (allowedSpecial != null) +// { +// allowedSpecialMemory = allowedSpecial +// .Select(str => (ReadOnlyMemory)str.AsMemory()) +//#if NET7_0_OR_GREATER +// .ToHashSet(); +//#else +// .ToList(); +//#endif +// } + +// if (disallowedSpecial != null ) +// { +// var disallowedSpecialMemcpy = disallowedSpecial +// .Select(str => (ReadOnlyMemory)str.AsMemory()) +//#if NET7_0_OR_GREATER +// .ToHashSet(); +//#else +// .ToList(); +//#endif +// } + + + + return _corePBE.EncodeNative(text, allowedSpecial, disallowedSpecial).Item1; } @@ -98,8 +139,5 @@ public string Decode(List tokens) } - - - } } diff --git a/TiktokenSharp/Utils/ByteHelper.cs b/TiktokenSharp/Utils/ByteHelper.cs index a41f3ca..ead3697 100644 --- a/TiktokenSharp/Utils/ByteHelper.cs +++ b/TiktokenSharp/Utils/ByteHelper.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.InteropServices; using System.Text; using System.Threading.Tasks; @@ -20,5 +21,13 @@ public static string ConvertByteListToString(List> byteList } return Encoding.UTF8.GetString(allBytes); } + + public static ReadOnlySpan ConvertReadOnlyMemoryCharToByte(ReadOnlyMemory charMemory) + { + var charSpan = charMemory.Span; + var bytes = MemoryMarshal.AsBytes(charSpan); + return bytes; + } + } }