Commit

Fix use of send, sync
EricLBuehler committed Apr 3, 2024
1 parent 8f96feb commit 675d3d2
Showing 6 changed files with 30 additions and 35 deletions.
1 change: 0 additions & 1 deletion candle-lora/Cargo.toml
@@ -16,7 +16,6 @@ candle-core.workspace = true
 candle-nn.workspace = true
 either.workspace = true
 thiserror.workspace = true
-trc.workspace = true
 
 [features]
 cuda = ["candle-core/cuda", "candle-nn/cuda"]

8 changes: 4 additions & 4 deletions candle-lora/src/lib.rs
@@ -216,7 +216,7 @@ pub trait Saveable {
 }
 
 /// Any layer that is linear-like.
-pub trait LinearLayerLike: Module + Debug + Saveable {
+pub trait LinearLayerLike: Module + Debug + Saveable + Send + Sync {
     fn weight(&self) -> &Tensor;
     fn bias(&self) -> Option<&Tensor>;
     fn shape(&self) -> &Shape;
@@ -241,7 +241,7 @@ impl LinearLayerLike for Linear {
 }
 
 /// Any layer that is conv1d-like.
-pub trait Conv1dLayerLike: Module + Debug + Saveable {
+pub trait Conv1dLayerLike: Module + Debug + Saveable + Send + Sync {
     fn weight(&self) -> &Tensor;
     fn bias(&self) -> Option<&Tensor>;
     fn config(&self) -> &Conv1dConfig;
@@ -266,7 +266,7 @@ impl Conv1dLayerLike for Conv1d {
 }
 
 /// Any layer that is conv2d-like.
-pub trait Conv2dLayerLike: Module + Debug + Saveable {
+pub trait Conv2dLayerLike: Module + Debug + Saveable + Send + Sync {
     fn weight(&self) -> &Tensor;
     fn bias(&self) -> Option<&Tensor>;
     fn config(&self) -> &Conv2dConfig;
@@ -291,7 +291,7 @@ impl Conv2dLayerLike for Conv2d {
 }
 
 /// Any layer that is embedding-like.
-pub trait EmbeddingLayerLike: Module + Debug + Saveable {
+pub trait EmbeddingLayerLike: Module + Debug + Saveable + Send + Sync {
     fn embeddings(&self) -> &Tensor;
     fn hidden_size(&self) -> usize;
 }

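The `Send + Sync` supertraits added above are what make boxed or `Arc`-wrapped layer trait objects shareable across threads. Below is a minimal sketch of the mechanism; the `Layer` trait and `Scale` struct are stand-ins for illustration, not types from this crate:

```rust
use std::sync::Arc;
use std::thread;

// Stand-in for a trait like LinearLayerLike. The `Send + Sync` supertraits
// mirror the bounds added in this commit.
trait Layer: Send + Sync {
    fn forward(&self, x: f32) -> f32;
}

struct Scale(f32);

impl Layer for Scale {
    fn forward(&self, x: f32) -> f32 {
        self.0 * x
    }
}

fn main() {
    let layer: Arc<dyn Layer> = Arc::new(Scale(2.0));

    // Arc<dyn Layer> is Send only because `dyn Layer: Send + Sync`;
    // without the supertraits this spawn would fail to compile.
    let handle = thread::spawn({
        let layer = Arc::clone(&layer);
        move || layer.forward(3.0)
    });

    assert_eq!(handle.join().unwrap(), 6.0);
}
```

Placing the bound on the trait itself means every implementor, and every container holding `dyn` objects of these traits, is thread-safe by construction.
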
15 changes: 7 additions & 8 deletions candle-lora/src/loraconv1d.rs
@@ -1,9 +1,8 @@
-use std::{collections::HashMap, ops::Mul};
+use std::{collections::HashMap, ops::Mul, sync::Arc};
 
 use candle_core::{Module, Result, Tensor};
 use candle_nn::{init, Conv1d, Conv1dConfig, Dropout, VarBuilder};
 use either::Either;
-use trc::Trc;
 
 use crate::{
     frozenconv::FrozenConv1d, Conv1dLayerLike, LoraConfig, Merge, MergeError, MergeErrorOrError,
@@ -12,11 +11,11 @@ use crate::{
 
 #[derive(Debug, Clone)]
 pub struct LoraConv1d {
-    old: Trc<FrozenConv1d>,
+    old: Arc<FrozenConv1d>,
     a: Tensor,
     b: Tensor,
     scale: Option<f64>,
-    dropout: Option<Trc<Dropout>>,
+    dropout: Option<Arc<Dropout>>,
     merged: bool,
     prefix: String,
     id: usize,
@@ -66,15 +65,15 @@ impl LoraConv1d {
         )?;
 
         Ok(LoraConv1d {
-            old: Trc::new(FrozenConv1d::new_from_conv1d(old)?),
+            old: Arc::new(FrozenConv1d::new_from_conv1d(old)?),
             a,
             b,
             scale: if config.rank > 0 {
                 Some(config.alpha / config.rank as f64)
             } else {
                 None
             },
-            dropout: config.dropout.map(|x| Trc::new(Dropout::new(x))),
+            dropout: config.dropout.map(|x| Arc::new(Dropout::new(x))),
             merged: false,
             prefix: vb.prefix(),
             id,
@@ -101,7 +100,7 @@ impl Merge for LoraConv1d {
         if self.merged {
             Err(Either::Left(MergeError::AlreadyMerged))
         } else {
-            self.old = Trc::new(
+            self.old = Arc::new(
                 FrozenConv1d::new(
                     &(self.old.weight() + self.get_delta_weight()?).map_err(Either::Right)?,
                     self.old.bias(),
@@ -118,7 +117,7 @@
         if !self.merged {
             Err(Either::Left(MergeError::NotMerged))
         } else {
-            self.old = Trc::new(
+            self.old = Arc::new(
                 FrozenConv1d::new(
                     &(self.old.weight() - self.get_delta_weight()?).map_err(Either::Right)?,
                     self.old.bias(),

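`Arc<T>` is `Send + Sync` only when `T` itself is, so swapping the previously used `trc::Trc` pointer for the standard-library `Arc` lets fields such as `old: Arc<FrozenConv1d>` satisfy the new bounds. A rough sketch of the compile-time check, using a hypothetical `Frozen` struct rather than the crate's frozen layers:

```rust
use std::sync::Arc;

// Hypothetical frozen layer; Vec<f32> is Send + Sync, so Frozen is too.
struct Frozen {
    weight: Vec<f32>,
}

// Compiles only if T can be shared and sent between threads.
fn assert_send_sync<T: Send + Sync>() {}

fn main() {
    let frozen = Arc::new(Frozen { weight: vec![1.0, 2.0] });

    // Arc<Frozen> inherits Send + Sync from Frozen; the same reasoning
    // applies to the Arc-wrapped fields introduced in this commit.
    assert_send_sync::<Arc<Frozen>>();

    println!("frozen weights: {:?}", frozen.weight);
}
```
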
15 changes: 7 additions & 8 deletions candle-lora/src/loraconv2d.rs
@@ -1,9 +1,8 @@
-use std::{collections::HashMap, ops::Mul};
+use std::{collections::HashMap, ops::Mul, sync::Arc};
 
 use candle_core::{Module, Result, Tensor};
 use candle_nn::{init, Conv2d, Conv2dConfig, Dropout, VarBuilder};
 use either::Either;
-use trc::Trc;
 
 use crate::{
     frozenconv::FrozenConv2d, Conv2dLayerLike, LoraConfig, Merge, MergeError, MergeErrorOrError,
@@ -12,11 +11,11 @@ use crate::{
 
 #[derive(Debug, Clone)]
 pub struct LoraConv2d {
-    old: Trc<FrozenConv2d>,
+    old: Arc<FrozenConv2d>,
     a_conv: Conv2d,
     b_conv: Conv2d,
     scale: Option<f64>,
-    dropout: Option<Trc<Dropout>>,
+    dropout: Option<Arc<Dropout>>,
     merged: bool,
     prefix: String,
     id: usize,
@@ -78,15 +77,15 @@ impl LoraConv2d {
         );
 
         Ok(LoraConv2d {
-            old: Trc::new(FrozenConv2d::new_from_conv2d(old)?),
+            old: Arc::new(FrozenConv2d::new_from_conv2d(old)?),
            a_conv,
            b_conv,
            scale: if config.rank > 0 {
                Some(config.alpha / config.rank as f64)
            } else {
                None
            },
-            dropout: config.dropout.map(|x| Trc::new(Dropout::new(x))),
+            dropout: config.dropout.map(|x| Arc::new(Dropout::new(x))),
            merged: false,
            prefix: vb.prefix(),
            id,
@@ -141,7 +140,7 @@ impl Merge for LoraConv2d {
         if self.merged {
             Err(Either::Left(MergeError::AlreadyMerged))
         } else {
-            self.old = Trc::new(
+            self.old = Arc::new(
                 FrozenConv2d::new(
                     &(self.old.weight() + self.get_delta_weight()?).map_err(Either::Right)?,
                     self.old.bias(),
@@ -158,7 +157,7 @@
         if !self.merged {
             Err(Either::Left(MergeError::NotMerged))
         } else {
-            self.old = Trc::new(
+            self.old = Arc::new(
                 FrozenConv2d::new(
                     &(self.old.weight() - self.get_delta_weight()?).map_err(Either::Right)?,
                     self.old.bias(),

11 changes: 5 additions & 6 deletions candle-lora/src/loraembed.rs
@@ -1,9 +1,8 @@
-use std::{collections::HashMap, ops::Mul};
+use std::{collections::HashMap, ops::Mul, sync::Arc};
 
 use candle_core::{Module, Result, Tensor};
 use candle_nn::{init, Embedding, Init, VarBuilder};
 use either::Either;
-use trc::Trc;
 
 use crate::{
     frozenembed::FrozenEmbedding, EmbeddingLayerLike, LoraConfig, Merge, MergeError,
@@ -12,7 +11,7 @@ use crate::{
 
 #[derive(Debug, Clone)]
 pub struct LoraEmbedding {
-    old: Trc<FrozenEmbedding>,
+    old: Arc<FrozenEmbedding>,
     embed_a: Embedding,
     a: Tensor,
     b: Tensor,
@@ -65,7 +64,7 @@ impl LoraEmbedding {
         let embed_a = Embedding::new(a_t.clone(), a_t.dim(1)?);
 
         Ok(LoraEmbedding {
-            old: Trc::new(FrozenEmbedding::new_from_embed(old)?),
+            old: Arc::new(FrozenEmbedding::new_from_embed(old)?),
             embed_a,
             a,
             b,
@@ -94,7 +93,7 @@ impl Merge for LoraEmbedding {
         if self.merged {
             Err(Either::Left(MergeError::AlreadyMerged))
         } else {
-            self.old = Trc::new(
+            self.old = Arc::new(
                 FrozenEmbedding::new(
                     &(self.embeddings() + self.get_delta_weight()?.transpose(0, 1))
                         .map_err(Either::Right)?,
@@ -111,7 +110,7 @@
         if !self.merged {
             Err(Either::Left(MergeError::NotMerged))
         } else {
-            self.old = Trc::new(
+            self.old = Arc::new(
                 FrozenEmbedding::new(
                     &(self.embeddings() - self.get_delta_weight()?.transpose(0, 1))
                         .map_err(Either::Right)?,

15 changes: 7 additions & 8 deletions candle-lora/src/loralinear.rs
@@ -1,9 +1,8 @@
-use std::{collections::HashMap, ops::Mul};
+use std::{collections::HashMap, ops::Mul, sync::Arc};
 
 use candle_core::{Module, Result, Shape, Tensor};
 use candle_nn::{init, Dropout, Linear, VarBuilder};
 use either::Either;
-use trc::Trc;
 
 use crate::{
     frozenlinear::FrozenLinear, LinearLayerLike, LoraConfig, Merge, MergeError, MergeErrorOrError,
@@ -12,11 +11,11 @@ use crate::{
 
 #[derive(Debug, Clone)]
 pub struct LoraLinear {
-    old: Trc<FrozenLinear>,
+    old: Arc<FrozenLinear>,
     ff_a: Linear,
     ff_b: Linear,
     scale: Option<f64>,
-    dropout: Option<Trc<Dropout>>,
+    dropout: Option<Arc<Dropout>>,
     merged: bool,
     prefix: String,
     id: usize,
@@ -58,15 +57,15 @@ impl LoraLinear {
         )?;
 
         Ok(LoraLinear {
-            old: Trc::new(FrozenLinear::new_from_linear(old)?),
+            old: Arc::new(FrozenLinear::new_from_linear(old)?),
             ff_a: Linear::new(a, None),
             ff_b: Linear::new(b, None),
             scale: if config.rank > 0 {
                 Some(config.alpha / config.rank as f64)
             } else {
                 None
             },
-            dropout: config.dropout.map(|x| Trc::new(Dropout::new(x))),
+            dropout: config.dropout.map(|x| Arc::new(Dropout::new(x))),
             merged: false,
             prefix: vb.prefix(),
             id,
@@ -91,7 +90,7 @@ impl Merge for LoraLinear {
         if self.merged {
             Err(Either::Left(MergeError::AlreadyMerged))
         } else {
-            self.old = Trc::new(
+            self.old = Arc::new(
                 FrozenLinear::new(
                     (self.old.weight() + self.get_delta_weight()?).map_err(Either::Right)?,
                     self.old.bias().cloned(),
@@ -107,7 +106,7 @@
         if !self.merged {
             Err(Either::Left(MergeError::NotMerged))
         } else {
-            self.old = Trc::new(
+            self.old = Arc::new(
                 FrozenLinear::new(
                     (self.old.weight() - self.get_delta_weight()?).map_err(Either::Right)?,
                     self.old.bias().cloned(),
