From 410cb3fcaaad0a17ab7fb5beec8b5dbc05726c33 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Wed, 3 Apr 2024 05:42:19 -0400 Subject: [PATCH 1/4] Fix layer swapping on Box --- candle-lora-macro/README.md | 20 +-- candle-lora-macro/examples/linear.rs | 8 +- candle-lora-macro/src/lib.rs | 196 ++++++++++++------------ candle-lora-transformers/src/bert.rs | 2 +- candle-lora-transformers/src/bigcode.rs | 4 +- candle-lora-transformers/src/dinov2.rs | 4 +- candle-lora-transformers/src/falcon.rs | 4 +- candle-lora/tests/conv1d.rs | 8 +- candle-lora/tests/conv1d_merged.rs | 4 +- candle-lora/tests/conv2d.rs | 8 +- candle-lora/tests/conv2d_merged.rs | 8 +- candle-lora/tests/embed.rs | 8 +- candle-lora/tests/embed_merged.rs | 8 +- candle-lora/tests/linear.rs | 8 +- candle-lora/tests/linear_merged.rs | 8 +- 15 files changed, 162 insertions(+), 136 deletions(-) diff --git a/candle-lora-macro/README.md b/candle-lora-macro/README.md index e4e753b..c6fb3f7 100644 --- a/candle-lora-macro/README.md +++ b/candle-lora-macro/README.md @@ -4,19 +4,19 @@ This library makes using [`candle-lora`](https://github.com/EricLBuehler/candle- `candle-lora-macro` exports 2 macros: `AutoLoraConvert` and `replace_layer_fields`. The `AutoLoraConvert` derive macro automatically creates a method `get_lora_model`, when called which selects and swaps all supported layers for their LoRA counterparts. This method is the equivalent of `peft`'s `get_peft_model` method, and modifies the model in place. It expects all -layers of the supported types to be a `dyn` type, that is `Box`. +layers of the supported types to be a `dyn` type: `Arc`. **Therefore the type wrapping the layer must be `Arc`.** In addition, `AutoLoraConvert` also defines a method `get_merged_lora_model` which does everything `get_lora_model` does, but also merges the weights of the LoRA layers to improve inference performance. To further automate the process of using `candle-lora`, `candle-lora-macro` also provides an attribute macro called `replace_layer_fields`. -`replace_layer_fields` swaps out the concrete types for `dyn` types. If this macro is not added to the model structs, be sure to change the member types to `Box`. +`replace_layer_fields` swaps out the concrete types for `dyn` types. If this macro is not added to the model structs, be sure to change the member types to `Arc`. 
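For illustration only (the struct and field names below are invented for this example, not taken from the crate), a model declared without the attribute macro wraps each layer in an `Arc<dyn ...>` trait object by hand and still derives `AutoLoraConvert`:

```rust
use std::sync::Arc;

use candle_lora::{Conv2dLayerLike, LinearLayerLike};
use candle_lora_macro::AutoLoraConvert;

#[derive(AutoLoraConvert, Debug)]
struct MyModel {
    // Layer fields must be `Arc<dyn ...LayerLike>` so that the generated
    // `get_lora_model` can swap them for their LoRA counterparts in place.
    proj: Arc<dyn LinearLayerLike>,
    conv: Arc<dyn Conv2dLayerLike>,
    // Non-layer fields are left untouched.
    scale: f32,
}
```

With `#[replace_layer_fields]` applied, the same struct can instead be written with plain `Linear` and `Conv2d` fields and the macro performs the rewrite described in the list below.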
`replace_layer_fields` is able to swap: -- `Linear` to `Box` -- `Conv1d` to `Box` -- `Conv2d` to `Box` -- `Embedding` to `Box` -- `Option` to `Option>` -- `Option` to `Option>` -- `Option` to `Option>` -- `Option` to `Option>` \ No newline at end of file +- `Linear` to `Arc` +- `Conv1d` to `Arc` +- `Conv2d` to `Arc` +- `Embedding` to `Arc` +- `Option` to `Option>` +- `Option` to `Option>` +- `Option` to `Option>` +- `Option` to `Option>` \ No newline at end of file diff --git a/candle-lora-macro/examples/linear.rs b/candle-lora-macro/examples/linear.rs index 490402f..03fe91c 100644 --- a/candle-lora-macro/examples/linear.rs +++ b/candle-lora-macro/examples/linear.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use candle_core::{DType, Device, Module, Result, Tensor}; use candle_lora::{LinearLayerLike, LoraConfig, LoraLinearConfig}; use candle_lora_macro::{replace_layer_fields, AutoLoraConvert}; @@ -6,7 +8,7 @@ use candle_nn::{init, Linear, VarBuilder, VarMap}; #[replace_layer_fields] #[derive(AutoLoraConvert, Debug)] struct Model { - a: Box, + a: Arc, b: i32, } @@ -32,7 +34,7 @@ fn main() { .unwrap(); let mut model = Model { - a: Box::new(Linear::new(layer_weight.clone(), None)), + a: Arc::new(Linear::new(layer_weight.clone(), None)), b: 1, }; @@ -49,6 +51,8 @@ fn main() { None, ); + println!("{:?}", model.a); + let dummy_image = Tensor::zeros((10, 10), DType::F32, &device).unwrap(); //Test the model diff --git a/candle-lora-macro/src/lib.rs b/candle-lora-macro/src/lib.rs index fb6b239..40e47a4 100644 --- a/candle-lora-macro/src/lib.rs +++ b/candle-lora-macro/src/lib.rs @@ -9,121 +9,127 @@ use syn::{ pub fn replace_layer_fields(_args: TokenStream1, input: TokenStream1) -> TokenStream1 { let mut ast = parse_macro_input!(input as DeriveInput); match &mut ast.data { - Data::Struct(ref mut struct_data) => match &mut struct_data.fields { - Fields::Named(fields) => { - for field in fields.named.iter_mut() { - let mut f = None; - let ident = field.ident.clone().unwrap(); - let ty = field.ty.clone(); - if let Type::Path(path) = ty { - if path.path.segments.len() == 1 { - match path - .path - .segments - .first() - .unwrap() - .ident - .to_string() - .as_str() - { - "Linear" => { - if let Visibility::Public(_) = field.vis { - f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc)).unwrap()); - } else { - f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc)).unwrap()); + Data::Struct(ref mut struct_data) => { + match &mut struct_data.fields { + Fields::Named(fields) => { + for field in fields.named.iter_mut() { + let mut f = None; + let ident = field.ident.clone().unwrap(); + let ty = field.ty.clone(); + if let Type::Path(path) = ty { + if path.path.segments.len() == 1 { + match path + .path + .segments + .first() + .unwrap() + .ident + .to_string() + .as_str() + { + "Linear" => { + if let Visibility::Public(_) = field.vis { + f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc)).unwrap()); + } else { + f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc)).unwrap()); + } } - } - "Conv1d" => { - if let Visibility::Public(_) = field.vis { - f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc)).unwrap()); - } else { - f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc)).unwrap()); + "Conv1d" => { + if let Visibility::Public(_) = field.vis { + f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc)).unwrap()); + } else { + f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc)).unwrap()); 
+ } } - } - "Conv2d" => { - if let Visibility::Public(_) = field.vis { - f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc)).unwrap()); - } else { - f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc)).unwrap()); + "Conv2d" => { + if let Visibility::Public(_) = field.vis { + f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc)).unwrap()); + } else { + f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc)).unwrap()); + } } - } - "Embedding" => { - if let Visibility::Public(_) = field.vis { - f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc)).unwrap()); - } else { - f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc)).unwrap()); + "Embedding" => { + if let Visibility::Public(_) = field.vis { + f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Arc)).unwrap()); + } else { + f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Arc)).unwrap()); + } } - } - "Option" => { - if let PathArguments::AngleBracketed(bracketed) = - &path.path.segments.first().unwrap().arguments - { - if bracketed.args.len() == 1 { - if let GenericArgument::Type(Type::Path(tp)) = - bracketed.args.first().unwrap() - { - if tp.path.segments.len() == 1 { - match tp - .path - .segments - .first() - .unwrap() - .ident - .to_string() - .as_str() - { - "Linear" => { - if let Visibility::Public(_) = field.vis - { - f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option>)).unwrap()); - } else { - f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option>)).unwrap()); + "Option" => { + if let PathArguments::AngleBracketed(bracketed) = + &path.path.segments.first().unwrap().arguments + { + if bracketed.args.len() == 1 { + if let GenericArgument::Type(Type::Path(tp)) = + bracketed.args.first().unwrap() + { + if tp.path.segments.len() == 1 { + match tp + .path + .segments + .first() + .unwrap() + .ident + .to_string() + .as_str() + { + "Linear" => { + if let Visibility::Public(_) = + field.vis + { + f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option>)).unwrap()); + } else { + f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option>)).unwrap()); + } } - } - "Conv1d" => { - if let Visibility::Public(_) = field.vis - { - f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option>)).unwrap()); - } else { - f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option>)).unwrap()); + "Conv1d" => { + if let Visibility::Public(_) = + field.vis + { + f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option>)).unwrap()); + } else { + f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option>)).unwrap()); + } } - } - "Conv2d" => { - if let Visibility::Public(_) = field.vis - { - f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option>)).unwrap()); - } else { - f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option>)).unwrap()); + "Conv2d" => { + if let Visibility::Public(_) = + field.vis + { + f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option>)).unwrap()); + } else { + f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option>)).unwrap()); + } } - } - "Embedding" => { - if let Visibility::Public(_) = field.vis - { - f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option>)).unwrap()); - } else { - f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option>)).unwrap()); + "Embedding" => { + if let Visibility::Public(_) = 
+ field.vis + { + f = Some(syn::Field::parse_named.parse2(quote::quote!(pub #ident: Option>)).unwrap()); + } else { + f = Some(syn::Field::parse_named.parse2(quote::quote!(#ident: Option>)).unwrap()); + } } + _ => {} } - _ => {} } } } } } + _ => {} } - _ => {} } } - } - if let Some(f) = f { - *field = f; + if let Some(f) = f { + *field = f; + } } } + _ => { + panic!("Named fields are required.") + } } - _ => { - panic!("Named fields are required.") - } - }, + } _ => { panic!("Cannot swap fields of non struct!"); } diff --git a/candle-lora-transformers/src/bert.rs b/candle-lora-transformers/src/bert.rs index ae55198..11b7d16 100644 --- a/candle-lora-transformers/src/bert.rs +++ b/candle-lora-transformers/src/bert.rs @@ -268,7 +268,7 @@ struct BertEmbedding { } impl Deref for BertEmbedding { - type Target = Arc; + type Target = Arc; fn deref(&self) -> &Self::Target { &self.inner diff --git a/candle-lora-transformers/src/bigcode.rs b/candle-lora-transformers/src/bigcode.rs index b9aa43d..a76626d 100644 --- a/candle-lora-transformers/src/bigcode.rs +++ b/candle-lora-transformers/src/bigcode.rs @@ -16,7 +16,7 @@ struct CustomLinear { } impl Deref for CustomLinear { - type Target = Arc; + type Target = Arc; fn deref(&self) -> &Self::Target { &self.inner @@ -30,7 +30,7 @@ struct CustomEmbedding { } impl Deref for CustomEmbedding { - type Target = Arc; + type Target = Arc; fn deref(&self) -> &Self::Target { &self.inner diff --git a/candle-lora-transformers/src/dinov2.rs b/candle-lora-transformers/src/dinov2.rs index c39b91a..608b889 100644 --- a/candle-lora-transformers/src/dinov2.rs +++ b/candle-lora-transformers/src/dinov2.rs @@ -28,7 +28,7 @@ struct DinoLinear { } impl Deref for DinoLinear { - type Target = Arc; + type Target = Arc; fn deref(&self) -> &Self::Target { &self.inner @@ -290,7 +290,7 @@ struct DinoConv2d { } impl Deref for DinoConv2d { - type Target = Arc; + type Target = Arc; fn deref(&self) -> &Self::Target { &self.inner diff --git a/candle-lora-transformers/src/falcon.rs b/candle-lora-transformers/src/falcon.rs index c7b3453..c3254fb 100644 --- a/candle-lora-transformers/src/falcon.rs +++ b/candle-lora-transformers/src/falcon.rs @@ -219,7 +219,7 @@ struct AttentionDense { } impl Deref for AttentionQKV { - type Target = Arc; + type Target = Arc; fn deref(&self) -> &Self::Target { &self.query_key_value @@ -227,7 +227,7 @@ impl Deref for AttentionQKV { } impl Deref for AttentionDense { - type Target = Arc; + type Target = Arc; fn deref(&self) -> &Self::Target { &self.dense diff --git a/candle-lora/tests/conv1d.rs b/candle-lora/tests/conv1d.rs index 248a65b..86c2ea1 100644 --- a/candle-lora/tests/conv1d.rs +++ b/candle-lora/tests/conv1d.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use candle_lora::{LoraConfig, SelectedLayersBuilder}; use candle_nn::VarBuilder; @@ -16,7 +18,7 @@ fn conv1d() -> candle_core::Result<()> { #[derive(Debug)] struct Model { - conv: Box, + conv: Arc, } impl Module for Model { @@ -29,7 +31,7 @@ fn conv1d() -> candle_core::Result<()> { fn insert_new(&mut self, new: NewLayers) { for (name, conv) in new.conv1d { match name { - ModelLayers::Conv => self.conv = Box::new(conv), + ModelLayers::Conv => self.conv = Arc::new(conv), } } } @@ -56,7 +58,7 @@ fn conv1d() -> candle_core::Result<()> { )?; let mut model = Model { - conv: Box::new(Conv1d::new( + conv: Arc::new(Conv1d::new( conv_weight.clone(), Some(conv_bias.clone()), Conv1dConfig::default(), diff --git a/candle-lora/tests/conv1d_merged.rs b/candle-lora/tests/conv1d_merged.rs index 096fb9d..6d5e47a 100644 --- 
a/candle-lora/tests/conv1d_merged.rs +++ b/candle-lora/tests/conv1d_merged.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use candle_lora::{LoraConfig, Merge, SelectedLayersBuilder}; use candle_nn::VarBuilder; @@ -16,7 +18,7 @@ fn conv1d() -> candle_core::Result<()> { #[derive(Debug)] struct Model { - conv: Box, + conv: Arc, } impl Module for Model { diff --git a/candle-lora/tests/conv2d.rs b/candle-lora/tests/conv2d.rs index 7f41069..62a1167 100644 --- a/candle-lora/tests/conv2d.rs +++ b/candle-lora/tests/conv2d.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use candle_lora::{LoraConfig, LoraConv2dConfig, SelectedLayersBuilder}; use candle_nn::VarBuilder; @@ -16,7 +18,7 @@ fn conv2d() -> candle_core::Result<()> { #[derive(Debug)] struct Model { - conv: Box, + conv: Arc, } impl Module for Model { @@ -29,7 +31,7 @@ fn conv2d() -> candle_core::Result<()> { fn insert_new(&mut self, new: NewLayers) { for (name, conv) in new.conv2d { match name { - ModelLayers::Conv => self.conv = Box::new(conv), + ModelLayers::Conv => self.conv = Arc::new(conv), } } } @@ -67,7 +69,7 @@ fn conv2d() -> candle_core::Result<()> { )?; let mut model = Model { - conv: Box::new(Conv2d::new( + conv: Arc::new(Conv2d::new( conv_weight.clone(), Some(conv_bias.clone()), cfg, diff --git a/candle-lora/tests/conv2d_merged.rs b/candle-lora/tests/conv2d_merged.rs index 93b7fdb..9542f2d 100644 --- a/candle-lora/tests/conv2d_merged.rs +++ b/candle-lora/tests/conv2d_merged.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use candle_lora::{LoraConfig, LoraConv2dConfig, Merge, SelectedLayersBuilder}; use candle_nn::VarBuilder; @@ -16,7 +18,7 @@ fn conv2d() -> candle_core::Result<()> { #[derive(Debug)] struct Model { - conv: Box, + conv: Arc, } impl Module for Model { @@ -31,7 +33,7 @@ fn conv2d() -> candle_core::Result<()> { match name { ModelLayers::Conv => { conv.merge_weights().unwrap(); - self.conv = Box::new(conv) + self.conv = Arc::new(conv) } } } @@ -70,7 +72,7 @@ fn conv2d() -> candle_core::Result<()> { )?; let mut model = Model { - conv: Box::new(Conv2d::new( + conv: Arc::new(Conv2d::new( conv_weight.clone(), Some(conv_bias.clone()), cfg, diff --git a/candle-lora/tests/embed.rs b/candle-lora/tests/embed.rs index aa5d268..5c84325 100644 --- a/candle-lora/tests/embed.rs +++ b/candle-lora/tests/embed.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use candle_lora::{LoraConfig, LoraEmbeddingConfig, SelectedLayersBuilder}; use candle_nn::VarBuilder; @@ -16,7 +18,7 @@ fn embed() -> candle_core::Result<()> { #[derive(Debug)] struct Model { - embed: Box, + embed: Arc, } impl Module for Model { @@ -29,7 +31,7 @@ fn embed() -> candle_core::Result<()> { fn insert_new(&mut self, new: NewLayers) { for (name, embed) in new.embed { match name { - ModelLayers::Embed => self.embed = Box::new(embed), + ModelLayers::Embed => self.embed = Arc::new(embed), } } } @@ -52,7 +54,7 @@ fn embed() -> candle_core::Result<()> { )?; let mut model = Model { - embed: Box::new(Embedding::new(embed_weight, hidden_size)), + embed: Arc::new(Embedding::new(embed_weight, hidden_size)), }; let dummy_image = Tensor::zeros((2, 4), DType::U32, &device)?; diff --git a/candle-lora/tests/embed_merged.rs b/candle-lora/tests/embed_merged.rs index 2efa353..0661e4b 100644 --- a/candle-lora/tests/embed_merged.rs +++ b/candle-lora/tests/embed_merged.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use candle_lora::{LoraConfig, LoraEmbeddingConfig, Merge, SelectedLayersBuilder}; use candle_nn::VarBuilder; @@ -16,7 +18,7 @@ fn embed() -> candle_core::Result<()> { #[derive(Debug)] struct Model { - embed: 
Box, + embed: Arc, } impl Module for Model { @@ -31,7 +33,7 @@ fn embed() -> candle_core::Result<()> { match name { ModelLayers::Embed => { embed.merge_weights().unwrap(); - self.embed = Box::new(embed) + self.embed = Arc::new(embed) } } } @@ -54,7 +56,7 @@ fn embed() -> candle_core::Result<()> { )?; let mut model = Model { - embed: Box::new(Embedding::new(embed_weight, hidden_size)), + embed: Arc::new(Embedding::new(embed_weight, hidden_size)), }; let dummy_image = Tensor::zeros((2, 4), DType::U32, &device)?; diff --git a/candle-lora/tests/linear.rs b/candle-lora/tests/linear.rs index 76a8f85..6532596 100644 --- a/candle-lora/tests/linear.rs +++ b/candle-lora/tests/linear.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use candle_lora::{LoraConfig, NewLayers, SelectedLayersBuilder}; use candle_nn::VarBuilder; @@ -16,7 +18,7 @@ fn single_linear() -> candle_core::Result<()> { #[derive(Debug)] struct Model { - layer: Box, + layer: Arc, } impl Module for Model { @@ -29,7 +31,7 @@ fn single_linear() -> candle_core::Result<()> { fn insert_new(&mut self, new: NewLayers) { for (name, linear) in new.linear { match name { - ModelLayers::Layer => self.layer = Box::new(linear), + ModelLayers::Layer => self.layer = Arc::new(linear), } } } @@ -49,7 +51,7 @@ fn single_linear() -> candle_core::Result<()> { )?; let mut model = Model { - layer: Box::new(Linear::new(layer_weight.clone(), None)), + layer: Arc::new(Linear::new(layer_weight.clone(), None)), }; let dummy_image = Tensor::zeros((10, 10), DType::F32, &device)?; diff --git a/candle-lora/tests/linear_merged.rs b/candle-lora/tests/linear_merged.rs index 0da78ef..1c4449d 100644 --- a/candle-lora/tests/linear_merged.rs +++ b/candle-lora/tests/linear_merged.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use candle_lora::{LoraConfig, Merge, NewLayers, SelectedLayersBuilder}; use candle_nn::VarBuilder; @@ -16,7 +18,7 @@ fn linear() -> candle_core::Result<()> { #[derive(Debug)] struct Model { - layer: Box, + layer: Arc, } impl Module for Model { @@ -31,7 +33,7 @@ fn linear() -> candle_core::Result<()> { match name { ModelLayers::Layer => { linear.merge_weights().unwrap(); - self.layer = Box::new(linear) + self.layer = Arc::new(linear) } } } @@ -52,7 +54,7 @@ fn linear() -> candle_core::Result<()> { )?; let mut model = Model { - layer: Box::new(Linear::new(layer_weight.clone(), None)), + layer: Arc::new(Linear::new(layer_weight.clone(), None)), }; let dummy_image = Tensor::zeros((10, 10), DType::F32, &device)?; From ec14269a72e81ec1b04c2783f4459b8f1fa828e1 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Wed, 3 Apr 2024 05:43:37 -0400 Subject: [PATCH 2/4] Fix layer swapping on Box --- candle-lora/tests/conv1d_merged.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/candle-lora/tests/conv1d_merged.rs b/candle-lora/tests/conv1d_merged.rs index 6d5e47a..0c61b06 100644 --- a/candle-lora/tests/conv1d_merged.rs +++ b/candle-lora/tests/conv1d_merged.rs @@ -33,7 +33,7 @@ fn conv1d() -> candle_core::Result<()> { match name { ModelLayers::Conv => { conv.merge_weights().unwrap(); - self.conv = Box::new(conv) + self.conv = Arc::new(conv) } } } @@ -61,7 +61,7 @@ fn conv1d() -> candle_core::Result<()> { )?; let mut model = Model { - conv: Box::new(Conv1d::new( + conv: Arc::new(Conv1d::new( conv_weight.clone(), Some(conv_bias.clone()), Conv1dConfig::default(), From 8f96feb60116e5827c3a1c5b9926b8d5344920ed Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Wed, 3 Apr 2024 06:57:45 -0400 Subject: [PATCH 3/4] Implement weight retrieval method --- README.md | 5 
+++ candle-lora-macro/examples/linear.rs | 2 +- candle-lora-macro/src/lib.rs | 45 ++++++++++++++++++++++++++ candle-lora-transformers/src/llama.rs | 7 ++++ candle-lora-transformers/src/resnet.rs | 2 +- candle-lora/src/frozenconv.rs | 16 ++++++++- candle-lora/src/frozenembed.rs | 10 +++++- candle-lora/src/frozenlinear.rs | 10 +++++- candle-lora/src/lib.rs | 36 ++++++++++++++++++--- candle-lora/src/loraconv1d.rs | 20 +++++++++++- candle-lora/src/loraconv2d.rs | 20 +++++++++++- candle-lora/src/loraembed.rs | 21 ++++++++++-- candle-lora/src/loralinear.rs | 20 +++++++++++- 13 files changed, 200 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index d977b8c..762e1b7 100644 --- a/README.md +++ b/README.md @@ -54,5 +54,10 @@ transformers have been converted: To use a LoRA transformer, simply replace the model from `candle-transformers` with its counterpart in `candle-lora-transformers`! +## Saving and loading +`candle_lora` supports retrieving weights for LoRA adapters via the `get_tensors` method, defined automatically in `#[auto_layer_convert]`. This function is meant to be used with `candle_core::safetensors::save()`. To load, simply load the `VarBuilder` and pass that to `get_lora_model`. + +`candle_lora`'s weight naming is not compatible with `peft` yet. + ## Resources `candle-lora`'s LoRA conversion implementations are based on HuggingFace's [`peft`](https://github.com/huggingface/peft/tree/main) library. See the original paper [here](https://arxiv.org/pdf/2106.09685.pdf), as well as Microsoft's [implementation](https://github.com/microsoft/LoRA). \ No newline at end of file diff --git a/candle-lora-macro/examples/linear.rs b/candle-lora-macro/examples/linear.rs index 03fe91c..58a2cc3 100644 --- a/candle-lora-macro/examples/linear.rs +++ b/candle-lora-macro/examples/linear.rs @@ -51,7 +51,7 @@ fn main() { None, ); - println!("{:?}", model.a); + dbg!(model.get_tensors()); let dummy_image = Tensor::zeros((10, 10), DType::F32, &device).unwrap(); diff --git a/candle-lora-macro/src/lib.rs b/candle-lora-macro/src/lib.rs index 40e47a4..2f9b6fe 100644 --- a/candle-lora-macro/src/lib.rs +++ b/candle-lora-macro/src/lib.rs @@ -303,6 +303,15 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 { }];); } + let mut linear_get = TokenStream::new(); + if !linear_fields.is_empty() { + quote_into::quote_into!(linear_get += [#{ + for (namei,_) in linear_fields.iter() { + quote_into::quote_into!(linear_get += (self.#namei.get_tensors(&mut output)),) + } + }];); + } + let mut conv1d_stream = TokenStream::new(); if !conv1d_fields.is_empty() { quote_into::quote_into!(conv1d_stream += [#{ @@ -312,6 +321,15 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 { }];); } + let mut conv1d_get = TokenStream::new(); + if !conv1d_fields.is_empty() { + quote_into::quote_into!(conv1d_get += [#{ + for (namei,_) in conv1d_fields.iter() { + quote_into::quote_into!(conv1d_get += (self.#namei.get_tensors(&mut output)),) + } + }];); + } + let mut conv2d_stream = TokenStream::new(); if !conv2d_fields.is_empty() { quote_into::quote_into!(conv2d_stream += [#{ @@ -321,6 +339,15 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 { }];); } + let mut conv2d_get = TokenStream::new(); + if !conv2d_fields.is_empty() { + quote_into::quote_into!(conv2d_get += [#{ + for (namei,_) in conv2d_fields.iter() { + quote_into::quote_into!(conv2d_get += (self.#namei.get_tensors(&mut output)),) + } + }];); + } + let mut embed_stream = TokenStream::new(); if !embed_fields.is_empty() { 
quote_into::quote_into!(embed_stream += [#{ @@ -330,6 +357,15 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 { }];); } + let mut embed_get = TokenStream::new(); + if !embed_fields.is_empty() { + quote_into::quote_into!(embed_get += [#{ + for (namei,_) in embed_fields.iter() { + quote_into::quote_into!(embed_get += (self.#namei.get_tensors(&mut output)),) + } + }];); + } + let mut linear_stream_assign = TokenStream::new(); if !linear_fields.is_empty() { quote_into::quote_into!(linear_stream_assign += [#{ @@ -653,6 +689,15 @@ pub fn auto_lora_convert(tokens: TokenStream1) -> TokenStream1 { #conv2d_merge_option1_stream_assign #embed_merge_option1_stream_assign } + + pub fn get_tensors(&self) -> ::std::collections::HashMap { + let mut output = ::std::collections::HashMap::new(); + #linear_get + #conv1d_get + #conv2d_get + #embed_get + output + } } } diff --git a/candle-lora-transformers/src/llama.rs b/candle-lora-transformers/src/llama.rs index c3c218e..a8c13c5 100644 --- a/candle-lora-transformers/src/llama.rs +++ b/candle-lora-transformers/src/llama.rs @@ -3,6 +3,7 @@ use candle_core::{DType, Device, IndexOp, Result, Tensor, D}; use candle_lora::{ EmbeddingLayerLike, LinearLayerLike, LoraConfig, LoraEmbeddingConfig, LoraLinearConfig, + Saveable, }; use candle_lora_macro::{replace_layer_fields, AutoLoraConvert}; use candle_nn::{Embedding, Module, VarBuilder}; @@ -103,6 +104,12 @@ impl Module for LlamaLinear { } } +impl Saveable for LlamaLinear { + fn get_tensors(&self, _accum: &mut HashMap) { + unimplemented!() + } +} + impl LinearLayerLike for LlamaLinear { fn bias(&self) -> Option<&Tensor> { self.inner.bias() diff --git a/candle-lora-transformers/src/resnet.rs b/candle-lora-transformers/src/resnet.rs index 434e5b0..5602a5c 100644 --- a/candle-lora-transformers/src/resnet.rs +++ b/candle-lora-transformers/src/resnet.rs @@ -3,7 +3,7 @@ //! See "Deep Residual Learning for Image Recognition" He et al. 2015 //! -use candle_core::{Module, Result, D}; +use candle_core::{Module, Result, Tensor, D}; use candle_lora::{Conv2dLayerLike, LoraConfig, LoraConv2dConfig}; use candle_lora_macro::{replace_layer_fields, AutoLoraConvert}; use candle_nn::{batch_norm, VarBuilder}; diff --git a/candle-lora/src/frozenconv.rs b/candle-lora/src/frozenconv.rs index e1d4919..b4e528f 100644 --- a/candle-lora/src/frozenconv.rs +++ b/candle-lora/src/frozenconv.rs @@ -1,7 +1,9 @@ +use std::collections::HashMap; + use candle_core::{Module, Result, Tensor}; use candle_nn::{Conv1d, Conv1dConfig, Conv2d, Conv2dConfig}; -use crate::{Conv1dLayerLike, Conv2dLayerLike}; +use crate::{Conv1dLayerLike, Conv2dLayerLike, Saveable}; /// Conv1d, but with a `new` implementation that ensures the weights are detached (frozen). 
#[derive(Debug)] @@ -42,6 +44,12 @@ impl Module for FrozenConv1d { } } +impl Saveable for FrozenConv1d { + fn get_tensors(&self, _accum: &mut HashMap) { + unimplemented!("Saving not supported for frozen layers, only for candle_lora layers."); + } +} + impl Conv1dLayerLike for FrozenConv1d { fn config(&self) -> &Conv1dConfig { self.conv.config() @@ -93,6 +101,12 @@ impl Module for FrozenConv2d { } } +impl Saveable for FrozenConv2d { + fn get_tensors(&self, _accum: &mut HashMap) { + unimplemented!("Saving not supported for frozen layers, only for candle_lora layers."); + } +} + impl Conv2dLayerLike for FrozenConv2d { fn config(&self) -> &Conv2dConfig { self.conv.config() diff --git a/candle-lora/src/frozenembed.rs b/candle-lora/src/frozenembed.rs index 9c77412..0341c30 100644 --- a/candle-lora/src/frozenembed.rs +++ b/candle-lora/src/frozenembed.rs @@ -1,7 +1,9 @@ +use std::collections::HashMap; + use candle_core::{Result, Tensor}; use candle_nn::Embedding; -use crate::EmbeddingLayerLike; +use crate::{EmbeddingLayerLike, Saveable}; /// Embedding, but with a `new` implementation that ensures the embeddings are detached (frozen). #[derive(Debug)] @@ -27,6 +29,12 @@ impl crate::Module for FrozenEmbedding { } } +impl Saveable for FrozenEmbedding { + fn get_tensors(&self, _accum: &mut HashMap) { + unimplemented!("Saving not supported for frozen layers, only for candle_lora layers."); + } +} + impl EmbeddingLayerLike for FrozenEmbedding { fn embeddings(&self) -> &Tensor { self.embed.embeddings() diff --git a/candle-lora/src/frozenlinear.rs b/candle-lora/src/frozenlinear.rs index b241688..a6b869a 100644 --- a/candle-lora/src/frozenlinear.rs +++ b/candle-lora/src/frozenlinear.rs @@ -1,7 +1,9 @@ +use std::collections::HashMap; + use candle_core::{Module, Result, Shape, Tensor}; use candle_nn::Linear; -use crate::LinearLayerLike; +use crate::{LinearLayerLike, Saveable}; /// Linear, but with a `new` implementation that ensures the weight and/or biases are detached (frozen). #[derive(Debug)] @@ -27,6 +29,12 @@ impl Module for FrozenLinear { } } +impl Saveable for FrozenLinear { + fn get_tensors(&self, _accum: &mut HashMap) { + unimplemented!("Saving not supported for frozen layers, only for candle_lora layers."); + } +} + impl LinearLayerLike for FrozenLinear { fn bias(&self) -> Option<&Tensor> { self.linear.bias() diff --git a/candle-lora/src/lib.rs b/candle-lora/src/lib.rs index 12c013a..e3a0c41 100644 --- a/candle-lora/src/lib.rs +++ b/candle-lora/src/lib.rs @@ -211,13 +211,23 @@ pub struct NewLayers { pub embed: HashMap, } +pub trait Saveable { + fn get_tensors(&self, accum: &mut HashMap); +} + /// Any layer that is linear-like. -pub trait LinearLayerLike: Module + Debug { +pub trait LinearLayerLike: Module + Debug + Saveable { fn weight(&self) -> &Tensor; fn bias(&self) -> Option<&Tensor>; fn shape(&self) -> &Shape; } +impl Saveable for Linear { + fn get_tensors(&self, _accum: &mut HashMap) { + unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers."); + } +} + impl LinearLayerLike for Linear { fn weight(&self) -> &Tensor { self.weight() @@ -231,12 +241,18 @@ impl LinearLayerLike for Linear { } /// Any layer that is conv1d-like. 
-pub trait Conv1dLayerLike: Module + Debug { +pub trait Conv1dLayerLike: Module + Debug + Saveable { fn weight(&self) -> &Tensor; fn bias(&self) -> Option<&Tensor>; fn config(&self) -> &Conv1dConfig; } +impl Saveable for Conv1d { + fn get_tensors(&self, _accum: &mut HashMap) { + unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers."); + } +} + impl Conv1dLayerLike for Conv1d { fn config(&self) -> &Conv1dConfig { self.config() @@ -250,12 +266,18 @@ impl Conv1dLayerLike for Conv1d { } /// Any layer that is conv2d-like. -pub trait Conv2dLayerLike: Module + Debug { +pub trait Conv2dLayerLike: Module + Debug + Saveable { fn weight(&self) -> &Tensor; fn bias(&self) -> Option<&Tensor>; fn config(&self) -> &Conv2dConfig; } +impl Saveable for Conv2d { + fn get_tensors(&self, _accum: &mut HashMap) { + unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers."); + } +} + impl Conv2dLayerLike for Conv2d { fn config(&self) -> &Conv2dConfig { self.config() @@ -269,11 +291,17 @@ impl Conv2dLayerLike for Conv2d { } /// Any layer that is embedding-like. -pub trait EmbeddingLayerLike: Module + Debug { +pub trait EmbeddingLayerLike: Module + Debug + Saveable { fn embeddings(&self) -> &Tensor; fn hidden_size(&self) -> usize; } +impl Saveable for Embedding { + fn get_tensors(&self, _accum: &mut HashMap) { + unimplemented!("Saving not supported for candle_nn layers, only for candle_lora layers."); + } +} + impl EmbeddingLayerLike for Embedding { fn embeddings(&self) -> &Tensor { self.embeddings() diff --git a/candle-lora/src/loraconv1d.rs b/candle-lora/src/loraconv1d.rs index 9bd6332..ea7e579 100644 --- a/candle-lora/src/loraconv1d.rs +++ b/candle-lora/src/loraconv1d.rs @@ -1,4 +1,4 @@ -use std::ops::Mul; +use std::{collections::HashMap, ops::Mul}; use candle_core::{Module, Result, Tensor}; use candle_nn::{init, Conv1d, Conv1dConfig, Dropout, VarBuilder}; @@ -7,6 +7,7 @@ use trc::Trc; use crate::{ frozenconv::FrozenConv1d, Conv1dLayerLike, LoraConfig, Merge, MergeError, MergeErrorOrError, + Saveable, }; #[derive(Debug, Clone)] @@ -17,6 +18,8 @@ pub struct LoraConv1d { scale: Option, dropout: Option>, merged: bool, + prefix: String, + id: usize, } #[derive(Clone, Debug)] @@ -73,6 +76,8 @@ impl LoraConv1d { }, dropout: config.dropout.map(|x| Trc::new(Dropout::new(x))), merged: false, + prefix: vb.prefix(), + id, }) } } @@ -155,6 +160,19 @@ impl Module for LoraConv1d { } } +impl Saveable for LoraConv1d { + fn get_tensors(&self, accum: &mut HashMap) { + accum.insert( + self.prefix.clone() + &format!("a{}.weight", self.id), + self.a.clone(), + ); + accum.insert( + self.prefix.clone() + &format!("b{}.weight", self.id), + self.b.clone(), + ); + } +} + impl Conv1dLayerLike for LoraConv1d { fn config(&self) -> &Conv1dConfig { self.old.config() diff --git a/candle-lora/src/loraconv2d.rs b/candle-lora/src/loraconv2d.rs index 09e682b..0eb699e 100644 --- a/candle-lora/src/loraconv2d.rs +++ b/candle-lora/src/loraconv2d.rs @@ -1,4 +1,4 @@ -use std::ops::Mul; +use std::{collections::HashMap, ops::Mul}; use candle_core::{Module, Result, Tensor}; use candle_nn::{init, Conv2d, Conv2dConfig, Dropout, VarBuilder}; @@ -7,6 +7,7 @@ use trc::Trc; use crate::{ frozenconv::FrozenConv2d, Conv2dLayerLike, LoraConfig, Merge, MergeError, MergeErrorOrError, + Saveable, }; #[derive(Debug, Clone)] @@ -17,6 +18,8 @@ pub struct LoraConv2d { scale: Option, dropout: Option>, merged: bool, + prefix: String, + id: usize, } #[derive(Clone, Debug)] @@ -85,6 +88,8 @@ impl LoraConv2d { }, 
dropout: config.dropout.map(|x| Trc::new(Dropout::new(x))), merged: false, + prefix: vb.prefix(), + id, }) } } @@ -189,6 +194,19 @@ impl Module for LoraConv2d { } } +impl Saveable for LoraConv2d { + fn get_tensors(&self, accum: &mut HashMap) { + accum.insert( + self.prefix.clone() + &format!("a{}.weight", self.id), + self.a_conv.weight().clone(), + ); + accum.insert( + self.prefix.clone() + &format!("b{}.weight", self.id), + self.b_conv.weight().clone(), + ); + } +} + impl Conv2dLayerLike for LoraConv2d { fn config(&self) -> &Conv2dConfig { self.old.config() diff --git a/candle-lora/src/loraembed.rs b/candle-lora/src/loraembed.rs index 1350bad..bff4446 100644 --- a/candle-lora/src/loraembed.rs +++ b/candle-lora/src/loraembed.rs @@ -1,4 +1,4 @@ -use std::ops::Mul; +use std::{collections::HashMap, ops::Mul}; use candle_core::{Module, Result, Tensor}; use candle_nn::{init, Embedding, Init, VarBuilder}; @@ -7,7 +7,7 @@ use trc::Trc; use crate::{ frozenembed::FrozenEmbedding, EmbeddingLayerLike, LoraConfig, Merge, MergeError, - MergeErrorOrError, + MergeErrorOrError, Saveable, }; #[derive(Debug, Clone)] @@ -18,6 +18,8 @@ pub struct LoraEmbedding { b: Tensor, scale: Option, merged: bool, + prefix: String, + id: usize, } #[derive(Clone, Debug)] @@ -73,6 +75,8 @@ impl LoraEmbedding { None }, merged: false, + prefix: vb.prefix(), + id, }) } } @@ -135,6 +139,19 @@ impl Module for LoraEmbedding { } } +impl Saveable for LoraEmbedding { + fn get_tensors(&self, accum: &mut HashMap) { + accum.insert( + self.prefix.clone() + &format!("a{}.weight", self.id), + self.a.clone(), + ); + accum.insert( + self.prefix.clone() + &format!("b{}.weight", self.id), + self.b.clone(), + ); + } +} + impl EmbeddingLayerLike for LoraEmbedding { fn embeddings(&self) -> &Tensor { self.old.embeddings() diff --git a/candle-lora/src/loralinear.rs b/candle-lora/src/loralinear.rs index 2a1e05f..0605b55 100644 --- a/candle-lora/src/loralinear.rs +++ b/candle-lora/src/loralinear.rs @@ -1,4 +1,4 @@ -use std::ops::Mul; +use std::{collections::HashMap, ops::Mul}; use candle_core::{Module, Result, Shape, Tensor}; use candle_nn::{init, Dropout, Linear, VarBuilder}; @@ -7,6 +7,7 @@ use trc::Trc; use crate::{ frozenlinear::FrozenLinear, LinearLayerLike, LoraConfig, Merge, MergeError, MergeErrorOrError, + Saveable, }; #[derive(Debug, Clone)] @@ -17,6 +18,8 @@ pub struct LoraLinear { scale: Option, dropout: Option>, merged: bool, + prefix: String, + id: usize, } #[derive(Clone, Debug)] @@ -65,6 +68,8 @@ impl LoraLinear { }, dropout: config.dropout.map(|x| Trc::new(Dropout::new(x))), merged: false, + prefix: vb.prefix(), + id, }) } } @@ -137,6 +142,19 @@ impl Module for LoraLinear { } } +impl Saveable for LoraLinear { + fn get_tensors(&self, accum: &mut HashMap) { + accum.insert( + self.prefix.clone() + &format!("a{}.weight", self.id), + self.ff_a.weight().clone(), + ); + accum.insert( + self.prefix.clone() + &format!("b{}.weight", self.id), + self.ff_b.weight().clone(), + ); + } +} + impl LinearLayerLike for LoraLinear { fn bias(&self) -> Option<&Tensor> { self.old.bias() From 675d3d228ca79ca13c0555b2c4fc88feaf009acd Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Wed, 3 Apr 2024 07:21:11 -0400 Subject: [PATCH 4/4] Fix use of send, sync --- candle-lora/Cargo.toml | 1 - candle-lora/src/lib.rs | 8 ++++---- candle-lora/src/loraconv1d.rs | 15 +++++++-------- candle-lora/src/loraconv2d.rs | 15 +++++++-------- candle-lora/src/loraembed.rs | 11 +++++------ candle-lora/src/loralinear.rs | 15 +++++++-------- 6 files changed, 30 insertions(+), 35 
deletions(-) diff --git a/candle-lora/Cargo.toml b/candle-lora/Cargo.toml index 8b0a601..f738af7 100644 --- a/candle-lora/Cargo.toml +++ b/candle-lora/Cargo.toml @@ -16,7 +16,6 @@ candle-core.workspace = true candle-nn.workspace = true either.workspace = true thiserror.workspace = true -trc.workspace = true [features] cuda = ["candle-core/cuda", "candle-nn/cuda"] \ No newline at end of file diff --git a/candle-lora/src/lib.rs b/candle-lora/src/lib.rs index e3a0c41..1695051 100644 --- a/candle-lora/src/lib.rs +++ b/candle-lora/src/lib.rs @@ -216,7 +216,7 @@ pub trait Saveable { } /// Any layer that is linear-like. -pub trait LinearLayerLike: Module + Debug + Saveable { +pub trait LinearLayerLike: Module + Debug + Saveable + Send + Sync { fn weight(&self) -> &Tensor; fn bias(&self) -> Option<&Tensor>; fn shape(&self) -> &Shape; @@ -241,7 +241,7 @@ impl LinearLayerLike for Linear { } /// Any layer that is conv1d-like. -pub trait Conv1dLayerLike: Module + Debug + Saveable { +pub trait Conv1dLayerLike: Module + Debug + Saveable + Send + Sync { fn weight(&self) -> &Tensor; fn bias(&self) -> Option<&Tensor>; fn config(&self) -> &Conv1dConfig; @@ -266,7 +266,7 @@ impl Conv1dLayerLike for Conv1d { } /// Any layer that is conv2d-like. -pub trait Conv2dLayerLike: Module + Debug + Saveable { +pub trait Conv2dLayerLike: Module + Debug + Saveable + Send + Sync { fn weight(&self) -> &Tensor; fn bias(&self) -> Option<&Tensor>; fn config(&self) -> &Conv2dConfig; @@ -291,7 +291,7 @@ impl Conv2dLayerLike for Conv2d { } /// Any layer that is embedding-like. -pub trait EmbeddingLayerLike: Module + Debug + Saveable { +pub trait EmbeddingLayerLike: Module + Debug + Saveable + Send + Sync { fn embeddings(&self) -> &Tensor; fn hidden_size(&self) -> usize; } diff --git a/candle-lora/src/loraconv1d.rs b/candle-lora/src/loraconv1d.rs index ea7e579..ab3eab8 100644 --- a/candle-lora/src/loraconv1d.rs +++ b/candle-lora/src/loraconv1d.rs @@ -1,9 +1,8 @@ -use std::{collections::HashMap, ops::Mul}; +use std::{collections::HashMap, ops::Mul, sync::Arc}; use candle_core::{Module, Result, Tensor}; use candle_nn::{init, Conv1d, Conv1dConfig, Dropout, VarBuilder}; use either::Either; -use trc::Trc; use crate::{ frozenconv::FrozenConv1d, Conv1dLayerLike, LoraConfig, Merge, MergeError, MergeErrorOrError, @@ -12,11 +11,11 @@ use crate::{ #[derive(Debug, Clone)] pub struct LoraConv1d { - old: Trc, + old: Arc, a: Tensor, b: Tensor, scale: Option, - dropout: Option>, + dropout: Option>, merged: bool, prefix: String, id: usize, @@ -66,7 +65,7 @@ impl LoraConv1d { )?; Ok(LoraConv1d { - old: Trc::new(FrozenConv1d::new_from_conv1d(old)?), + old: Arc::new(FrozenConv1d::new_from_conv1d(old)?), a, b, scale: if config.rank > 0 { @@ -74,7 +73,7 @@ impl LoraConv1d { } else { None }, - dropout: config.dropout.map(|x| Trc::new(Dropout::new(x))), + dropout: config.dropout.map(|x| Arc::new(Dropout::new(x))), merged: false, prefix: vb.prefix(), id, @@ -101,7 +100,7 @@ impl Merge for LoraConv1d { if self.merged { Err(Either::Left(MergeError::AlreadyMerged)) } else { - self.old = Trc::new( + self.old = Arc::new( FrozenConv1d::new( &(self.old.weight() + self.get_delta_weight()?).map_err(Either::Right)?, self.old.bias(), @@ -118,7 +117,7 @@ impl Merge for LoraConv1d { if !self.merged { Err(Either::Left(MergeError::NotMerged)) } else { - self.old = Trc::new( + self.old = Arc::new( FrozenConv1d::new( &(self.old.weight() - self.get_delta_weight()?).map_err(Either::Right)?, self.old.bias(), diff --git a/candle-lora/src/loraconv2d.rs 
b/candle-lora/src/loraconv2d.rs index 0eb699e..8022497 100644 --- a/candle-lora/src/loraconv2d.rs +++ b/candle-lora/src/loraconv2d.rs @@ -1,9 +1,8 @@ -use std::{collections::HashMap, ops::Mul}; +use std::{collections::HashMap, ops::Mul, sync::Arc}; use candle_core::{Module, Result, Tensor}; use candle_nn::{init, Conv2d, Conv2dConfig, Dropout, VarBuilder}; use either::Either; -use trc::Trc; use crate::{ frozenconv::FrozenConv2d, Conv2dLayerLike, LoraConfig, Merge, MergeError, MergeErrorOrError, @@ -12,11 +11,11 @@ use crate::{ #[derive(Debug, Clone)] pub struct LoraConv2d { - old: Trc, + old: Arc, a_conv: Conv2d, b_conv: Conv2d, scale: Option, - dropout: Option>, + dropout: Option>, merged: bool, prefix: String, id: usize, @@ -78,7 +77,7 @@ impl LoraConv2d { ); Ok(LoraConv2d { - old: Trc::new(FrozenConv2d::new_from_conv2d(old)?), + old: Arc::new(FrozenConv2d::new_from_conv2d(old)?), a_conv, b_conv, scale: if config.rank > 0 { @@ -86,7 +85,7 @@ impl LoraConv2d { } else { None }, - dropout: config.dropout.map(|x| Trc::new(Dropout::new(x))), + dropout: config.dropout.map(|x| Arc::new(Dropout::new(x))), merged: false, prefix: vb.prefix(), id, @@ -141,7 +140,7 @@ impl Merge for LoraConv2d { if self.merged { Err(Either::Left(MergeError::AlreadyMerged)) } else { - self.old = Trc::new( + self.old = Arc::new( FrozenConv2d::new( &(self.old.weight() + self.get_delta_weight()?).map_err(Either::Right)?, self.old.bias(), @@ -158,7 +157,7 @@ impl Merge for LoraConv2d { if !self.merged { Err(Either::Left(MergeError::NotMerged)) } else { - self.old = Trc::new( + self.old = Arc::new( FrozenConv2d::new( &(self.old.weight() - self.get_delta_weight()?).map_err(Either::Right)?, self.old.bias(), diff --git a/candle-lora/src/loraembed.rs b/candle-lora/src/loraembed.rs index bff4446..5545873 100644 --- a/candle-lora/src/loraembed.rs +++ b/candle-lora/src/loraembed.rs @@ -1,9 +1,8 @@ -use std::{collections::HashMap, ops::Mul}; +use std::{collections::HashMap, ops::Mul, sync::Arc}; use candle_core::{Module, Result, Tensor}; use candle_nn::{init, Embedding, Init, VarBuilder}; use either::Either; -use trc::Trc; use crate::{ frozenembed::FrozenEmbedding, EmbeddingLayerLike, LoraConfig, Merge, MergeError, @@ -12,7 +11,7 @@ use crate::{ #[derive(Debug, Clone)] pub struct LoraEmbedding { - old: Trc, + old: Arc, embed_a: Embedding, a: Tensor, b: Tensor, @@ -65,7 +64,7 @@ impl LoraEmbedding { let embed_a = Embedding::new(a_t.clone(), a_t.dim(1)?); Ok(LoraEmbedding { - old: Trc::new(FrozenEmbedding::new_from_embed(old)?), + old: Arc::new(FrozenEmbedding::new_from_embed(old)?), embed_a, a, b, @@ -94,7 +93,7 @@ impl Merge for LoraEmbedding { if self.merged { Err(Either::Left(MergeError::AlreadyMerged)) } else { - self.old = Trc::new( + self.old = Arc::new( FrozenEmbedding::new( &(self.embeddings() + self.get_delta_weight()?.transpose(0, 1)) .map_err(Either::Right)?, @@ -111,7 +110,7 @@ impl Merge for LoraEmbedding { if !self.merged { Err(Either::Left(MergeError::NotMerged)) } else { - self.old = Trc::new( + self.old = Arc::new( FrozenEmbedding::new( &(self.embeddings() - self.get_delta_weight()?.transpose(0, 1)) .map_err(Either::Right)?, diff --git a/candle-lora/src/loralinear.rs b/candle-lora/src/loralinear.rs index 0605b55..6ce626d 100644 --- a/candle-lora/src/loralinear.rs +++ b/candle-lora/src/loralinear.rs @@ -1,9 +1,8 @@ -use std::{collections::HashMap, ops::Mul}; +use std::{collections::HashMap, ops::Mul, sync::Arc}; use candle_core::{Module, Result, Shape, Tensor}; use candle_nn::{init, Dropout, Linear, VarBuilder}; use 
either::Either; -use trc::Trc; use crate::{ frozenlinear::FrozenLinear, LinearLayerLike, LoraConfig, Merge, MergeError, MergeErrorOrError, @@ -12,11 +11,11 @@ use crate::{ #[derive(Debug, Clone)] pub struct LoraLinear { - old: Trc, + old: Arc, ff_a: Linear, ff_b: Linear, scale: Option, - dropout: Option>, + dropout: Option>, merged: bool, prefix: String, id: usize, @@ -58,7 +57,7 @@ impl LoraLinear { )?; Ok(LoraLinear { - old: Trc::new(FrozenLinear::new_from_linear(old)?), + old: Arc::new(FrozenLinear::new_from_linear(old)?), ff_a: Linear::new(a, None), ff_b: Linear::new(b, None), scale: if config.rank > 0 { @@ -66,7 +65,7 @@ impl LoraLinear { } else { None }, - dropout: config.dropout.map(|x| Trc::new(Dropout::new(x))), + dropout: config.dropout.map(|x| Arc::new(Dropout::new(x))), merged: false, prefix: vb.prefix(), id, @@ -91,7 +90,7 @@ impl Merge for LoraLinear { if self.merged { Err(Either::Left(MergeError::AlreadyMerged)) } else { - self.old = Trc::new( + self.old = Arc::new( FrozenLinear::new( (self.old.weight() + self.get_delta_weight()?).map_err(Either::Right)?, self.old.bias().cloned(), @@ -107,7 +106,7 @@ impl Merge for LoraLinear { if !self.merged { Err(Either::Left(MergeError::NotMerged)) } else { - self.old = Trc::new( + self.old = Arc::new( FrozenLinear::new( (self.old.weight() - self.get_delta_weight()?).map_err(Either::Right)?, self.old.bias().cloned(),
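A closing note on the `Saving and loading` section added in patch 3: the sketch below shows the save path end to end, under the assumption that `Model` is the struct from `candle-lora-macro/examples/linear.rs` (so `get_tensors` is generated by `AutoLoraConvert`) and that `get_lora_model` has already been called on it. The `save_adapters` helper and the output path are illustrative, not part of these patches.

```rust
use std::collections::HashMap;

use candle_core::Tensor;

/// Persist only the LoRA A/B matrices collected by the macro-generated
/// `get_tensors`, using the sink the README names for this purpose.
/// `Model` here is the example struct from candle-lora-macro/examples/linear.rs.
fn save_adapters(model: &Model, path: &str) -> candle_core::Result<()> {
    // Keys follow the scheme used by the `Saveable` impls in this series:
    // "<VarBuilder prefix>a<id>.weight" and "<VarBuilder prefix>b<id>.weight".
    let tensors: HashMap<String, Tensor> = model.get_tensors();
    candle_core::safetensors::save(&tensors, path)
}
```

Loading then follows the README: load a `VarBuilder` over the saved file and pass it to `get_lora_model`.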