Model File Manager #789

Open · wants to merge 11 commits into master · changes shown from 1 commit
29 changes: 17 additions & 12 deletions LLama.Unittest/Model/ModelCacheTests.cs
@@ -1,7 +1,7 @@
using LLama.Common;
using LLama.Model;

namespace LLama.Unittest;
namespace LLama.Unittest.Model;

public class ModelManagerTests
{
@@ -108,14 +108,16 @@ public async void LoadModel_LoadsAndCaches()
var modelToLoad = TestableModelManager.ModelFileList
.First(f => f.FileName.Contains("llama-2-7b"));

var model = await TestableModelManager.LoadModel(modelToLoad.FilePath, null!);

Assert.Single(TestableModelManager.GetLoadedModels());

var model = await TestableModelManager.LoadModel(modelToLoad.FilePath);
var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel);
Assert.True(isLoaded);

// unload
// unload the newly acquired model even though it was cached
Assert.True(TestableModelManager.UnloadModel(model.ModelName));
//cachedModel.Dispose(); // this does effectively nothing

// unload "original"
//model.Dispose();
Assert.True(TestableModelManager.UnloadModel(model.ModelName));

Assert.Throws<ObjectDisposedException>(() =>
@@ -135,7 +137,6 @@ public async void LoadModel_AlreadyLoaded_ReturnsFromCache()
var model = await TestableModelManager.LoadModel(modelToLoad.FilePath);
Assert.NotNull(model);
Assert.Equal("LLaMA v2", model.ModelName);
Assert.Single(TestableModelManager.GetLoadedModels());
var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel);
Assert.True(isLoaded);
Assert.NotNull(cachedModel);
@@ -153,16 +154,20 @@ public async void TryGetLoadedModel_AlreadyDisposed_ReturnsFalse()
{
Assert.NotNull(model);
Assert.Equal("LLaMA v2", model.ModelName);
Assert.Single(TestableModelManager.GetLoadedModels());
var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel);
Assert.True(isLoaded);
Assert.NotNull(cachedModel);
Assert.Equal("LLaMA v2", cachedModel.ModelName);
} // end scope, dispose model

// Model is now disposed
var isDispoedLoaded = TestableModelManager.TryGetLoadedModel("LLaMA v2", out var disposedModel);
Assert.False(isDispoedLoaded);
// unload from the last check
Assert.True(TestableModelManager.UnloadModel("LLaMA v2"));

} // end scope: Dispose is called on the model, but since the cache holds it, it should stick around until unloaded
Assert.True(TestableModelManager.UnloadModel("LLaMA v2"));

// Model is still loaded due to cache
var isDisposedLoaded = TestableModelManager.TryGetLoadedModel("LLaMA v2", out var disposedModel);
Assert.False(isDisposedLoaded);
Assert.Null(disposedModel);
}
}
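
Editorial note: the revised tests hinge on the cache's new reference-counting behavior, where every acquisition (the initial LoadModel plus each TryGetLoadedModel hit) must be balanced by its own unload. A minimal sketch of that lifecycle, assuming a cache seeded with a directory containing the test model (the directory and file paths below are illustrative, not part of the API):

```csharp
using LLama.Model;

var cache = new ModelCache(["/path/to/models"]); // hypothetical directory

// First load reads the file and caches the native handle.
var model = await cache.LoadModel("/path/to/models/llama-2-7b.gguf");

// A cache hit hands back a wrapper over the same native handle,
// taking an additional reference.
cache.TryGetLoadedModel(model.ModelName, out var cached);

// Each acquisition needs a matching release before the native
// model is actually freed.
cache.UnloadModel(model.ModelName); // releases the cache-hit reference
cache.UnloadModel(model.ModelName); // releases the original reference
```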
35 changes: 24 additions & 11 deletions LLama/LLamaWeights.cs
@@ -71,6 +71,23 @@
Metadata = weights.ReadMetadata();
}

/// <summary>
/// Create from a "shared" handle
/// </summary>
/// <param name="handle"></param>
/// <returns></returns>
public static LLamaWeights FromSafeModelHandle(SafeLlamaModelHandle handle)
{
var model = new LLamaWeights(handle);

// Increment the model reference count while this weight exists.
// DangerousAddRef throws if it fails, so there is no need to check "success"
var success = false;
handle.DangerousAddRef(ref success);

return model;
}

/// <summary>
/// Load weights into memory
/// </summary>
@@ -79,19 +96,19 @@
public static LLamaWeights LoadFromFile(IModelParams @params)
{
using var pin = @params.ToLlamaModelParams(out var lparams);
var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);

foreach (var adapter in @params.LoraAdapters)
{
if (string.IsNullOrEmpty(adapter.Path))
continue;
if (adapter.Scale <= 0)
if (string.IsNullOrEmpty(adapter.Path) || adapter.Scale <= 0)
{
continue;
}

weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase);
model.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase);
}

return new LLamaWeights(weights);
return new LLamaWeights(model);
}

/// <summary>
@@ -113,7 +130,7 @@

// Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
// slightly smaller range to allow some space for reporting LoRA loading too.
var modelLoadProgressRange = 1f;

[CI annotation on line 133 — GitHub Actions / Test (linux-release, osx-release): The variable 'modelLoadProgressRange' is assigned but its value is never used]
if (loraAdapters.Length > 0)
modelLoadProgressRange = 0.9f;

@@ -133,11 +150,7 @@
if (internalCallback != null && !internalCallback(progress, ctx))
return false;

// Check the cancellation token
if (token.IsCancellationRequested)
return false;

return true;
return !token.IsCancellationRequested;
};
}
#endif
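
Editorial note: FromSafeModelHandle pairs each new wrapper with one DangerousAddRef, so the native model stays alive until every wrapper over it has been released. A sketch of the sharing pattern, assuming a handle taken from an existing LLamaWeights instance ('original' is illustrative) and assuming LLamaWeights.Dispose releases its handle reference, as SafeHandle disposal normally does:

```csharp
// 'original' owns the first reference; 'shared' takes another over
// the same native handle via FromSafeModelHandle.
var shared = LLamaWeights.FromSafeModelHandle(original.NativeHandle);
using (shared)
{
    // Both wrappers are valid here and point at the same native model.
}
// 'original' is still usable; only the shared wrapper's reference was released.
```

Note also that llama.cpp's native progress callback returns true to continue loading and false to abort, which is why the cancellation check in the callback above must be `return !token.IsCancellationRequested;`.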
10 changes: 3 additions & 7 deletions LLama/Model/IModelCache.cs
@@ -60,7 +60,8 @@ public interface IModelCache : IDisposable

// Model Load and Unload
/// <summary>
/// Load a model file to be used for infernce
/// Load a model file to be used for inference
/// The caller assumes responsibility for disposing this model
/// </summary>
/// <param name="modelPath"></param>
/// <param name="modelConfigurator"></param>
@@ -86,15 +87,10 @@ public Task<LLamaWeights> LoadModel(string modelPath,

/// <summary>
/// Attempt to get a model that's expected to be loaded
/// The caller assumes responsibility for the lifetime of the model at this point if it exists in the cache
/// </summary>
/// <param name="modeId"></param>
/// <param name="model"></param>
/// <returns></returns>
public bool TryGetLoadedModel(string modelId, out LLamaWeights model);

/// <summary>
/// Currently loaded models
/// </summary>
/// <returns></returns>
public IEnumerable<LLamaWeights> GetLoadedModels();
}
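
Editorial note: with GetLoadedModels removed and ownership documented on LoadModel and TryGetLoadedModel, every wrapper the cache hands out is caller-owned. A sketch of a consumer honoring that contract (directory and file paths are illustrative):

```csharp
IModelCache cache = new ModelCache(["/path/to/models"]); // hypothetical directory
using var weights = await cache.LoadModel("/path/to/models/llama-2-7b.gguf");

if (cache.TryGetLoadedModel(weights.ModelName, out var again))
{
    using (again)
    {
        // 'again' is a second caller-owned wrapper over the same
        // cached native handle.
    }
}
```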
77 changes: 43 additions & 34 deletions LLama/Model/ModelCache.cs
@@ -6,6 +6,7 @@
using System.Threading;
using System.Threading.Tasks;
using LLama.Common;
using LLama.Native;

namespace LLama.Model;

@@ -25,7 +26,7 @@ public class ModelCache : IModelCache
private readonly Dictionary<string, IEnumerable<ModelFileMetadata>> _availableModels = [];

// model id/alias, to loaded model
private readonly Dictionary<string, LLamaWeights> _loadedModelCache = [];
private readonly Dictionary<string, SafeLlamaModelHandle> _loadedModelCache = [];

/// <summary>
/// Create a new model manager that seeds available models from the given directory list
@@ -36,6 +37,15 @@ public ModelCache(string[] directories)
GetModelsFromDirectories(directories);
}

/// <inheritdoc />
public IEnumerable<ModelFileMetadata> ModelFileList
=> _availableModels.SelectMany(x => x.Value);

/// <inheritdoc />
public IEnumerable<string> ModelDirectories
=> _availableModels.Keys;

#region Directories
private void GetModelsFromDirectories(params string[] dirs)
{
foreach (var dir in dirs)
@@ -78,13 +88,6 @@ private void GetModelsFromDirectories(params string[] dirs)
}
}

/// <inheritdoc />
public IEnumerable<ModelFileMetadata> ModelFileList
=> _availableModels.SelectMany(x => x.Value);
/// <inheritdoc />
public IEnumerable<string> ModelDirectories
=> _availableModels.Keys;

/// <inheritdoc />
public void AddDirectory(string directory)
{
@@ -111,6 +114,7 @@ public IEnumerable<ModelFileMetadata> GetAvailableModelsFromDirectory(string dir
? dirModels
: [];
}
#endregion Directories

/// <inheritdoc />
public bool TryGetModelFileMetadata(string fileName, out ModelFileMetadata modelMeta)
@@ -120,25 +124,13 @@ public bool TryGetModelFileMetadata(string fileName, out ModelFileMetadata model
return modelMeta != null;
}

/// <inheritdoc />
public IEnumerable<LLamaWeights> GetLoadedModels()
{
return _loadedModelCache.Values;
}

/// <inheritdoc />
public bool TryGetLoadedModel(string modelId, out LLamaWeights model)
{
var isCached = _loadedModelCache.TryGetValue(modelId, out model!);

// Externally disposed, act like it's not in here
if (isCached && model.NativeHandle.IsClosed)
{
_ = _loadedModelCache.Remove(modelId);
isCached = false;
model = null!;
}

var isCached = _loadedModelCache.TryGetValue(modelId, out var handle);
model = isCached
? LLamaWeights.FromSafeModelHandle(handle)
: null!;
return isCached;
}

@@ -152,7 +144,6 @@ public async Task<LLamaWeights> LoadModel(string modelPath,
if (!string.IsNullOrEmpty(modelId)
&& TryGetLoadedModel(modelId, out var loadedModel))
{
Trace.TraceWarning($"Model {modelId} already loaded");
return loadedModel;
}

@@ -162,43 +153,61 @@

// load and cache
var model = await LLamaWeights.LoadFromFileAsync(modelParams, cancellationToken);

// Check if it's already cached, if so use that and dispose of this
// TODO: Consider the case where the alias is different but the underlying model file is the same
if (string.IsNullOrWhiteSpace(modelId))
{
modelId = model.ModelName;

// Check if cached again with alias
// TODO: Consider the case where the alias is different but the underlying model file is the same
if (TryGetLoadedModel(modelId, out loadedModel))
{
model.Dispose();
return loadedModel;
}
}
_loadedModelCache.Add(modelId, model);

// Increment the model reference count while this model exists (newly created)
// DangerousAddRef throws if it fails, so there is no need to check "success"
// Do this here since we're passing this to the caller to own and it's not done as part of the normal weight creation
var refSuccess = false;
model.NativeHandle.DangerousAddRef(ref refSuccess);

_loadedModelCache.Add(modelId, model.NativeHandle);
return model;
}

#region Unload
/// <inheritdoc />
public bool UnloadModel(string modelId)
{
if (TryGetLoadedModel(modelId, out var model))
if (_loadedModelCache.TryGetValue(modelId, out var handle))
{
model.Dispose();
return _loadedModelCache.Remove(modelId);
// Decrement refcount on model
handle.DangerousRelease();
handle.Dispose();
if (handle.IsClosed || handle.IsInvalid)
{
return _loadedModelCache.Remove(modelId);
}
return true;
}
return false;
}

/// <inheritdoc />
public void UnloadAllModels()
{
foreach (var model in _loadedModelCache.Values)
foreach (var handle in _loadedModelCache.Values)
{
model.Dispose();
handle.DangerousRelease();
handle.Dispose();
}
_loadedModelCache.Clear();
}

#endregion

#region Dispose
/// <inheritdoc />
public void Dispose()
@@ -208,7 +217,7 @@ public void Dispose()
}

/// <summary>
/// Unload all models when called explicity via dispose
/// Unload all models when called explicitly via dispose
/// </summary>
/// <param name="disposing">Whether or not this call is made explicitly(true) or via GC</param>
protected virtual void Dispose(bool disposing)
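
Editorial note: because TryGetLoadedModel now materializes a fresh LLamaWeights from the cached native handle on every call, successive lookups return distinct wrapper objects over the same native model, and each must be released independently. A small sketch of that behavior, assuming a cache that has already loaded a model named "LLaMA v2":

```csharp
using System.Diagnostics;

cache.TryGetLoadedModel("LLaMA v2", out var a);
cache.TryGetLoadedModel("LLaMA v2", out var b);

Debug.Assert(!ReferenceEquals(a, b));                          // distinct wrappers
Debug.Assert(ReferenceEquals(a.NativeHandle, b.NativeHandle)); // same native model

a.Dispose(); // release each wrapper's reference independently
b.Dispose();
```

UnloadModel only evicts the cache entry once the handle reports closed or invalid, so a cache still holding a live reference keeps the entry (and the native model) alive until the count drains.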