Skip to content

Commit a655b3c

Browse files
authored
Update gather with sparse_grad set to false (#404)
* Updated `gather` with `sparse_grad` set to false * Fixed Clippy warnings & updated changelog * Pinned ort version
1 parent 107fb21 commit a655b3c

File tree

11 files changed

+18
-18
lines changed

11 files changed

+18
-18
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file. The format
44
## [Unreleased]
55
## Fixed
66
- (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Add new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering m-grams spanning multiple sentences).
7+
- Improved MPS device compatibility by setting the `sparse_grad` flag to false for `gather` operations
78

89
## [0.21.0] - 2023-06-03
910
## Added

Cargo.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ regex = "1.6"
8787
cached-path = { version = "0.6", default-features = false, optional = true }
8888
dirs = { version = "4", optional = true }
8989
lazy_static = { version = "1", optional = true }
90-
ort = {version="1.14.8", optional = true, default-features = false, features = ["half"]}
90+
ort = {version="~1.14.8", optional = true, default-features = false, features = ["half"]}
9191
ndarray = {version="0.15", optional = true}
9292

9393
[dev-dependencies]
@@ -99,4 +99,4 @@ torch-sys = "0.13.0"
9999
tempfile = "3"
100100
itertools = "0.10"
101101
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
102-
ort = {version="1.14.8", features = ["load-dynamic"]}
102+
ort = {version="~1.14.8", features = ["load-dynamic"]}

src/models/bart/bart_model.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
357357
let output = input_ids.empty_like().to_kind(Kind::Int64);
358358
output
359359
.select(1, 0)
360-
.copy_(&input_ids.gather(1, &index_eos, true).squeeze());
360+
.copy_(&input_ids.gather(1, &index_eos, false).squeeze());
361361
output
362362
.slice(1, 1, *output.size().last().unwrap(), 1)
363363
.copy_(&input_ids.slice(1, 0, *output.size().last().unwrap() - 1, 1));

src/models/deberta/attention.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ impl DebertaDisentangledSelfAttention {
192192
let c2p_att = c2p_att.gather(
193193
-1,
194194
&self.c2p_dynamic_expand(&c2p_pos, query_layer, &relative_pos),
195-
true,
195+
false,
196196
);
197197
score = score + c2p_att;
198198
}
@@ -213,15 +213,15 @@ impl DebertaDisentangledSelfAttention {
213213
.gather(
214214
-1,
215215
&self.p2c_dynamic_expand(&p2c_pos, query_layer, key_layer),
216-
true,
216+
false,
217217
)
218218
.transpose(-1, -2);
219219
if query_layer_size[1] != key_layer_size[1] {
220220
let pos_index = relative_pos.select(3, 0).unsqueeze(-1);
221221
p2c_att = p2c_att.gather(
222222
-2,
223223
&self.pos_dynamic_expand(&pos_index, &p2c_att, key_layer),
224-
true,
224+
false,
225225
);
226226
}
227227
score = score + p2c_att;

src/models/deberta_v2/attention.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ impl DebertaV2DisentangledSelfAttention {
156156
],
157157
true,
158158
),
159-
true,
159+
false,
160160
);
161161
score = score + c2p_att / scale;
162162
Some(c2p_pos)
@@ -189,7 +189,7 @@ impl DebertaV2DisentangledSelfAttention {
189189
[query_layer.size()[0], key_layer_size[1], key_layer_size[1]],
190190
true,
191191
),
192-
true,
192+
false,
193193
)
194194
.transpose(-1, -2);
195195
score = score + p2c_att / scale;
@@ -211,7 +211,7 @@ impl DebertaV2DisentangledSelfAttention {
211211
],
212212
true,
213213
),
214-
true,
214+
false,
215215
);
216216
score = score + p2p_att;
217217
}

src/models/distilbert/distilbert_model.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,8 @@ impl DistilBertModel {
192192
P: Borrow<nn::Path<'p>>,
193193
{
194194
let p = p.borrow() / "distilbert";
195-
let embeddings = DistilBertEmbedding::new(p.borrow() / "embeddings", config);
196-
let transformer = Transformer::new(p.borrow() / "transformer", config);
195+
let embeddings = DistilBertEmbedding::new(&p / "embeddings", config);
196+
let transformer = Transformer::new(p / "transformer", config);
197197
DistilBertModel {
198198
embeddings,
199199
transformer,

src/models/mbart/mbart_model.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
162162
- 1;
163163
output
164164
.select(1, 0)
165-
.copy_(&input_ids.gather(1, &index_eos, true).squeeze());
165+
.copy_(&input_ids.gather(1, &index_eos, false).squeeze());
166166
output
167167
.slice(1, 1, *output.size().last().unwrap(), 1)
168168
.copy_(&input_ids.slice(1, 0, *output.size().last().unwrap() - 1, 1));

src/models/prophetnet/attention.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ impl ProphetNetNgramAttention {
712712
]);
713713

714714
rel_pos_embeddings
715-
.gather(1, &predict_relative_position_buckets, true)
715+
.gather(1, &predict_relative_position_buckets, false)
716716
.view([
717717
self.ngram,
718718
batch_size * self.num_attention_heads,

src/models/reformer/attention_utils.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,8 @@ pub fn reverse_sort(
173173
let expanded_undo_sort_indices = undo_sorted_bucket_idx
174174
.unsqueeze(-1)
175175
.expand(out_vectors.size().as_slice(), true);
176-
let out_vectors = out_vectors.gather(2, &expanded_undo_sort_indices, true);
177-
let logits = logits.gather(2, undo_sorted_bucket_idx, true);
176+
let out_vectors = out_vectors.gather(2, &expanded_undo_sort_indices, false);
177+
let logits = logits.gather(2, undo_sorted_bucket_idx, false);
178178
(out_vectors, logits)
179179
}
180180

src/pipelines/generation_utils.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,7 @@ pub(crate) mod private_generation_utils {
964964
prev_scores.push(
965965
next_token_logits
966966
.log_softmax(-1, next_token_logits.kind())
967-
.gather(1, &next_token.reshape([-1, 1]), true)
967+
.gather(1, &next_token.reshape([-1, 1]), false)
968968
.squeeze()
969969
.masked_fill(&finished_mask, 0),
970970
);

src/pipelines/zero_shot_classification.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ use crate::{
121121
bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
122122
resources::RemoteResource,
123123
};
124-
use std::ops::Deref;
125124
use tch::kind::Kind::{Bool, Float};
126125
use tch::nn::VarStore;
127126
use tch::{no_grad, Device, Kind, Tensor};
@@ -698,7 +697,7 @@ impl ZeroShotClassificationModel {
698697
.flat_map(|input| {
699698
label_sentences
700699
.iter()
701-
.map(move |label_sentence| (input.deref(), label_sentence.as_str()))
700+
.map(move |label_sentence| (*input, label_sentence.as_str()))
702701
})
703702
.collect::<Vec<(&str, &str)>>();
704703

0 commit comments

Comments
 (0)