diff --git a/README.md b/README.md
index ab1de60..90f98cf 100644
--- a/README.md
+++ b/README.md
@@ -33,7 +33,7 @@ cargo add --git https://github.com/EricLBuehler/candle-lora.git candle-lora cand
 
 `candle-lora-macro` makes using `candle-lora` as simple as adding 2 macros to your model structs and calling a method! It is inspired by the simplicity of the Python `peft` library's `get_peft_model` method.
 
-Together, these macros mean that `candle-lora` can be added to any `candle` model with minimal code changes! To see an example of the benefits, compare the example below (or [here](candle-lora-examples/examples/linear_macro.rs)) to [this](candle-lora-examples/examples/linear.rs), equivalent example. See a precise diff [here](candle-lora-examples/examples/macro_diff.txt).
+Together, these macros mean that `candle-lora` can be added to any `candle` model with minimal code changes!
 
 ## LoRA transformers
 See transformers from Candle which have LoRA integrated [here](candle-lora-transformers/examples/). Currently, the following
diff --git a/candle-lora-examples/examples/macro_diff.txt b/candle-lora-examples/examples/macro_diff.txt
deleted file mode 100644
index f6efb6e..0000000
--- a/candle-lora-examples/examples/macro_diff.txt
+++ /dev/null
@@ -1,80 +0,0 @@
---- examples/linear.rs 2023-09-17 17:20:10.914252466 -0400
-+++ examples/linear_macro.rs 2023-09-17 17:20:10.914252466 -0400
-@@ -1,19 +1,12 @@
--use std::{collections::HashMap, hash::Hash};
-+use candle_core::{DType, Device, Module, Result, Tensor};
-+use candle_lora::{LinearLayerLike, LoraConfig, LoraLinearConfig};
-+use candle_lora_macro::{replace_layer_fields, AutoLoraConvert};
-+use candle_nn::{init, Linear, VarBuilder, VarMap};
- 
--use candle_core::{DType, Device, Result, Tensor};
--use candle_lora::{
--    LinearLayerLike, Lora, LoraConfig, LoraLinearConfig, NewLayers, SelectedLayersBuilder,
--};
--use candle_nn::{init, Linear, Module, VarBuilder, VarMap};
--
--#[derive(PartialEq, Eq, Hash)]
--enum ModelLayers {
--    Layer,
--}
--
--#[derive(Debug)]
-+#[replace_layer_fields]
-+#[derive(AutoLoraConvert, Debug)]
- struct Model {
--    layer: Box<dyn LinearLayerLike>,
-+    layer: Linear,
- }
- 
- impl Module for Model {
-@@ -22,16 +15,6 @@
-     }
- }
- 
--impl Model {
--    fn insert_new(&mut self, new: NewLayers<ModelLayers>) {
--        for (name, linear) in new.linear {
--            match name {
--                ModelLayers::Layer => self.layer = Box::new(linear),
--            }
--        }
--    }
--}
--
- fn main() {
-     let device = Device::Cpu;
-     let dtype = DType::F32;
-@@ -51,23 +34,21 @@
-         layer: Box::new(Linear::new(layer_weight.clone(), None)),
-     };
- 
--    let mut linear_layers = HashMap::new();
--    linear_layers.insert(ModelLayers::Layer, &*model.layer);
--    let selected = SelectedLayersBuilder::new()
--        .add_linear_layers(linear_layers, LoraLinearConfig::new(10, 10))
--        .build();
--
-     let varmap = VarMap::new();
-     let vb = VarBuilder::from_varmap(&varmap, dtype, &device);
- 
-     let loraconfig = LoraConfig::new(1, 1., None);
--
--    let new_layers = Lora::convert_model(selected, loraconfig, &vb);
--
--    model.insert_new(new_layers);
-+    model.get_lora_model(
-+        loraconfig,
-+        &vb,
-+        Some(LoraLinearConfig::new(10, 10)),
-+        None,
-+        None,
-+        None,
-+    );
- 
-     let dummy_image = Tensor::zeros((10, 10), DType::F32, &device).unwrap();
- 
--    let lora_output = model.forward(&dummy_image).unwrap();
--    println!("Output: {lora_output:?}");
-+    let digit = model.forward(&dummy_image).unwrap();
-+    println!("Output: {digit:?}");
- }