[WIP] refactor: init some experimental refactoring. #362

Closed · wants to merge 1 commit
12 changes: 12 additions & 0 deletions LLama/Abstractions/IInferenceParams.cs
@@ -1,7 +1,9 @@
using System.Collections.Generic;
using LLama.Common;
+using LLama.Control;
using LLama.Native;
using LLama.Sampling;
+using LLama.Transform;

namespace LLama.Abstractions
{
@@ -114,5 +116,15 @@ public interface IInferenceParams
/// Set a custom sampling pipeline to use. <b>If this is set, all other sampling parameters are ignored!</b>
/// </summary>
ISamplingPipeline? SamplingPipeline { get; set; }

/// <summary>
/// Set a custom generation control to use. <b>If this is set, antiprompts will be ignored!</b>
/// </summary>
IGenerationControl GenerationControl { get; set; }
[Review comment by Member] Suggested change:
-IGenerationControl GenerationControl { get; set; }
+IGenerationControl? GenerationControl { get; set; }


/// <summary>
/// Set a custom tokenizer to use.
/// </summary>
ITokenizer Tokenizer { get; set; }
[Review comment by Member] Suggested change:
-ITokenizer Tokenizer { get; set; }
+ITokenizer? Tokenizer { get; set; }

}
}
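For context, both hooks are opted into through InferenceParams. A minimal usage sketch, assuming the DefaultGenerationControl and DefaultTokenizer types added elsewhere in this PR (all values are placeholders):

using System.Collections.Generic;
using System.Text;
using LLama.Common;
using LLama.Control;
using LLama.Transform;

// Both properties default to the built-in implementations, so setting
// them explicitly is only needed when plugging in custom behaviour.
var inferenceParams = new InferenceParams
{
    AntiPrompts = new List<string> { "User:" },
    GenerationControl = new DefaultGenerationControl(),
    Tokenizer = new DefaultTokenizer(Encoding.UTF8)
};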
4 changes: 2 additions & 2 deletions LLama/Abstractions/ILLamaExecutor.cs
@@ -18,8 +18,8 @@ public interface ILLamaExecutor
/// </summary>
/// <param name="text">Your prompt</param>
/// <param name="inferenceParams">Any additional parameters</param>
/// <param name="token">A cancellation token.</param>
/// <param name="cancellationToken">A cancellation token.</param>
/// <returns></returns>
IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default);
}
}
9 changes: 9 additions & 0 deletions LLama/Common/InferenceParams.cs
@@ -3,6 +3,9 @@
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;
+using LLama.Control;
+using LLama.Transform;
+using System.Text;

namespace LLama.Common
{
@@ -80,6 +83,12 @@ public record InferenceParams

/// <inheritdoc />
public ISamplingPipeline? SamplingPipeline { get; set; }

+/// <inheritdoc />
+public IGenerationControl GenerationControl { get; set; } = new DefaultGenerationControl();

+/// <inheritdoc />
+public ITokenizer Tokenizer { get; set; } = new DefaultTokenizer(Encoding.UTF8);
}

/// <summary>
@@ -1,7 +1,7 @@
using System;
using System.Collections.Generic;

-namespace LLama
+namespace LLama.Control
{
/// <summary>
/// AntipromptProcessor keeps track of past tokens looking for any set Anti-Prompts
42 changes: 42 additions & 0 deletions LLama/Control/DefaultGenerationControl.cs
@@ -0,0 +1,42 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Control
{
/// <summary>
/// The default generation control in LLamaSharp, using antiprompts. This class should not be inherited.
[Review comment by Member] Suggested change:
-/// The default generation control in LLamaSharp, using antiprompts. This class should not be inherited.
+/// The default generation control in LLamaSharp, using antiprompts.

[Review comment by Member] It's sealed, so it's not possible to extend this class.
/// <b>Note that this class has state. The previous outputs feeded to it will affect its control.</b>
[Review comment by Member] Suggested change:
-/// <b>Note that this class has state. The previous outputs feeded to it will affect its control.</b>
+/// <b>Note that this class has state. The previous outputs fed to it will affect its output.</b>
/// If you use it in a session, please don't reuse it for another session unless you intend to do so.
/// </summary>
public sealed class DefaultGenerationControl: IGenerationControl
{
private AntipromptProcessor _antipromptProcessor;

/// <summary>
/// <inheritdoc/>
/// </summary>
public DefaultGenerationControl()
{
_antipromptProcessor = new AntipromptProcessor();
}

/// <summary>
/// <inheritdoc/>
/// </summary>
public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
{
_antipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts);
return _antipromptProcessor.Add(lastOutputText);
}

/// <summary>
/// <inheritdoc/>
/// </summary>
public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
{
return false;
[Review comment by Member] Should this be returning false?
}
}
}
31 changes: 31 additions & 0 deletions LLama/Control/IGenerationControl.cs
@@ -0,0 +1,31 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Control
{
/// <summary>
/// Control the text generation of LLama Executors.
/// </summary>
public interface IGenerationControl
{
/// <summary>
/// Use the last output text to determine if the generation should stop.
/// </summary>
/// <param name="context"></param>
/// <param name="inferenceParams"></param>
/// <param name="lastOutputText"></param>
/// <returns></returns>
bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText);

/// <summary>
/// Use the last output ids to determine if the generation should stop.
/// </summary>
/// <param name="context"></param>
/// <param name="inferenceParams"></param>
/// <param name="lastOutputIds"></param>
/// <returns></returns>
bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds);
}
}
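To illustrate the extension point, here is a hypothetical control that stops generation once a chunk budget is spent. MaxChunksControl is an invented name and not part of this PR; it is only a sketch against the interface above:

using System.Collections.Generic;
using LLama;
using LLama.Abstractions;
using LLama.Control;

// Hypothetical: stop after a fixed number of streamed text chunks.
public sealed class MaxChunksControl : IGenerationControl
{
    private readonly int _maxChunks;
    private int _seen;

    public MaxChunksControl(int maxChunks)
    {
        _maxChunks = maxChunks;
    }

    public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
    {
        // Stateful, like DefaultGenerationControl: each call counts one chunk.
        _seen++;
        return _seen >= _maxChunks;
    }

    public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
    {
        // This example only tracks text chunks, not token ids.
        return false;
    }
}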
1 change: 1 addition & 0 deletions LLama/LLamaExecutorBase.cs
@@ -2,6 +2,7 @@
using LLama.Common;
using LLama.Exceptions;
using LLama.Native;
+using LLama.Transform;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Generic;
Expand Down
49 changes: 27 additions & 22 deletions LLama/LLamaStatelessExecutor.cs
@@ -8,6 +8,7 @@
using System.Threading.Tasks;
using LLama.Native;
using LLama.Sampling;
+using LLama.Control;
using Microsoft.Extensions.Logging;

namespace LLama
@@ -49,69 +50,73 @@ public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger?
/// <inheritdoc />
public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
-// Ensure the context from last time is disposed (it always hould be)
+// Ensure the context from last time is disposed (it always should be)
if (!Context.NativeHandle.IsClosed)
Context.Dispose();

// Create an inference context which will be disposed when this method exits
using var context = _weights.CreateContext(_params, _logger);
Context = context;

+await foreach(var item in InferAsync(prompt, Context, inferenceParams, cancellationToken))
+{
+yield return item;
+}
+}

+public static async IAsyncEnumerable<string> InferAsync(string prompt, LLamaContext context, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+{

// Sanity check inference params
inferenceParams ??= new InferenceParams();
-if (inferenceParams.TokensKeep > Context.ContextSize)
-throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");

-// Create decoders for the token stream
-var decoder = new StreamingTokenDecoder(Context);
-var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts);
+if (inferenceParams.TokensKeep > context.ContextSize)
+throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({context.ContextSize})");

// Keep track of the last N tokens emitted
-var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
+var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? context.ContextSize : inferenceParams.RepeatLastTokensCount);
var lastTokens = new List<llama_token>(repeat_last_n);
for (var i = 0; i < repeat_last_n; i++)
lastTokens.Add(0);

// Tokenize the prompt
-var tokens = Context.Tokenize(prompt).ToList();
+var tokens = inferenceParams.Tokenizer.Tokenize(context, prompt).ToList();
lastTokens.AddRange(tokens);
var n_past = 1 + tokens.Count;

// Evaluate the prompt
-await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
+await Task.Run(() => { context.Eval(tokens, 1); }, cancellationToken)
.ConfigureAwait(false);

// Begin loop, evaluating one token at a time
var mu = (float?)null;
var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
-for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
+for (var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
{
llama_token id;
if (inferenceParams.SamplingPipeline is not null)
{
-id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+id = inferenceParams.SamplingPipeline.Sample(context.NativeHandle, context.NativeHandle.GetLogits(), lastTokens);
}
else
{
// Penalize the generated tokens by various penalties
-var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+var tokenDataArray = context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

// Sample a single token
-id = Context.Sample(
+id = context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
}

// Decode this token into text
-decoder.Add(id);
-var decoded = decoder.Read();
+var decoded = inferenceParams.Tokenizer.Detokenize(context, id);
yield return decoded;

-// Check if any of the antiprompts have been generated
-if (antiprocessor.Add(decoded))
+// Check if the generation should stop
+if (inferenceParams.GenerationControl.ShouldStopGeneration(context, inferenceParams, decoded))
break;

lastTokens.Add(id);
@@ -120,19 +125,19 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams

// when run out of context
// based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
-if (n_past + tokens.Count >= Context.ContextSize)
+if (n_past + tokens.Count >= context.ContextSize)
{
var n_left = n_past - inferenceParams.TokensKeep - 1;
var n_discard = n_left / 2;

-NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
-NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+NativeApi.llama_kv_cache_seq_rm(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
+NativeApi.llama_kv_cache_seq_shift(context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);

n_past -= n_discard;
}

// ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
-n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
+n_past = await Task.Run(() => context.Eval(tokens, n_past), cancellationToken)
.ConfigureAwait(false);
}
}
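Taken together, the refactored executor is driven the same way as before. A usage sketch, assuming the standard LLamaSharp loading pattern (the model path and parameter values are placeholders):

using System;
using LLama;
using LLama.Common;

var modelParams = new ModelParams("path/to/model.gguf");
using var weights = LLamaWeights.LoadFromFile(modelParams);
var executor = new StatelessExecutor(weights, modelParams);

// A custom Tokenizer or GenerationControl could be supplied here too.
var inferenceParams = new InferenceParams { MaxTokens = 64 };
await foreach (var chunk in executor.InferAsync("Q: What is a tokenizer? A:", inferenceParams))
    Console.Write(chunk);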
31 changes: 31 additions & 0 deletions LLama/TextCompletion.cs
@@ -0,0 +1,31 @@
using LLama.Abstractions;
using LLama.Common;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;

namespace LLama
{
/// <summary>
/// A class to execute text completion task.
/// </summary>
public class TextCompletion
{
public string Execute(string prompt, IInferenceParams? inferenceParams = null)
{
throw new NotImplementedException();
}

public ChatHistory Execute(ChatHistory prompt, IInferenceParams? inferenceParams = null)
{
throw new NotImplementedException();
}

public async IAsyncEnumerable<string> StreamingExecute(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}
}
}
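The class above is all stubs, so its intended behaviour is still open. One plausible shape for the streaming path, purely illustrative and assuming it wraps an ILLamaExecutor (TextCompletionSketch is an invented name, not what the PR implements):

using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using LLama.Abstractions;

public sealed class TextCompletionSketch
{
    private readonly ILLamaExecutor _executor;

    public TextCompletionSketch(ILLamaExecutor executor)
    {
        _executor = executor;
    }

    // Pass chunks straight through from the underlying executor.
    public async IAsyncEnumerable<string> StreamingExecute(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        await foreach (var chunk in _executor.InferAsync(prompt, inferenceParams, cancellationToken))
            yield return chunk;
    }
}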
53 changes: 53 additions & 0 deletions LLama/Transform/DefaultTokenizer.cs
@@ -0,0 +1,53 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Transform
{
/// <summary>
/// The default tokenizer of LLamaSharp. This class should not be inherited.
[Review comment by Member] Suggested change:
-/// The default tokenizer of LLamaSharp. This class should not be inherited.
+/// The default tokenizer of LLamaSharp.
/// <b>Note that this class has state. The previous outputs feeded to it will affect its control.</b>
[Review comment by Member] Suggested change:
-/// <b>Note that this class has state. The previous outputs feeded to it will affect its control.</b>
+/// <b>Note that this class has state. The previous outputs fed to it will affect its output.</b>
/// If you use it in a session, please don't reuse it for another session unless you intend to do so.
/// </summary>
public sealed class DefaultTokenizer: ITokenizer
{
private Encoding _encoding;
private StreamingTokenDecoder _tokenDecoder;

/// <summary>
/// Initialize a new tokenizer with the specified encoding.
/// </summary>
/// <param name="encoding"></param>
public DefaultTokenizer(Encoding encoding)
{
_encoding = encoding;
_tokenDecoder = new StreamingTokenDecoder(encoding);
}

/// <summary>
/// <inheritdoc/>
/// </summary>
public IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false)
[Review comment by Member] I think it would be better to accept a LLamaWeights in the constructor, instead of a LLamaContext in the Tokenize/Detokenize methods. That simplifies usage and allows you to use the tokenizer without creating an entire context (which is quite memory hungry).

[Reply by Collaborator (Author)] I'll change it, thank you!

[Reply by Collaborator (Author)] Is there already a way to tokenize without a context? I didn't find such a method.

[Reply by Member (martindevans, Dec 15, 2023)]

LLamaWeights w = your_weights;
w.NativeHandle.Tokenize("a string");

Should do it. There should really be higher-level wrappers for tokenization in LLamaWeights, so that you don't have to access the NativeHandle, but I haven't built them yet.

[Reply by Collaborator (Author)] I see, I mistook the one in NativeApi that takes a context parameter for the lowest-level API. I'll add a wrapper for it with LLamaWeights. :)
{
return context.Tokenize(text, addBos, special);
}

/// <summary>
/// <inheritdoc/>
/// </summary>
public string Detokenize(LLamaContext context, int token)
{
_tokenDecoder.Add(token, context.NativeHandle.ModelHandle);
return _tokenDecoder.Read();
}

/// <summary>
/// <inheritdoc/>
/// </summary>
public string Detokenize(LLamaContext context, IEnumerable<int> tokens)
{
_tokenDecoder.AddRange(tokens, context.NativeHandle.ModelHandle);
return _tokenDecoder.Read();
}
}
}
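Following the review discussion above, a tokenizer bound to LLamaWeights rather than LLamaContext might look roughly like this. WeightsTokenizer is an invented name, and the exact NativeHandle.Tokenize signature is an assumption based on the reviewer's hint, not verified API:

using System.Collections.Generic;
using System.Text;
using LLama;

// Sketch: tokenize against the model weights, so no context is required.
public sealed class WeightsTokenizer
{
    private readonly LLamaWeights _weights;
    private readonly StreamingTokenDecoder _decoder;
    private readonly Encoding _encoding;

    public WeightsTokenizer(LLamaWeights weights, Encoding encoding)
    {
        _weights = weights;
        _encoding = encoding;
        _decoder = new StreamingTokenDecoder(encoding);
    }

    public IEnumerable<int> Tokenize(string text, bool addBos = true, bool special = false)
    {
        // Assumed signature, per the reviewer's w.NativeHandle.Tokenize(...) hint.
        return _weights.NativeHandle.Tokenize(text, addBos, special, _encoding);
    }

    public string Detokenize(int token)
    {
        // LLamaWeights.NativeHandle is the model handle the decoder needs.
        _decoder.Add(token, _weights.NativeHandle);
        return _decoder.Read();
    }
}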
15 changes: 15 additions & 0 deletions LLama/Transform/ITokenizer.cs
@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Transform
{
public interface ITokenizer
{
IEnumerable<int> Tokenize(LLamaContext context, string text, bool addBos = true, bool special = false);

string Detokenize(LLamaContext context, int token);

string Detokenize(LLamaContext context, IEnumerable<int> tokens);
}
}