-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.diff
51 lines (47 loc) · 2.83 KB
/
main.diff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
diff --git "a/C:\\dev\\jade\\chatglm_cpp2\\main.cpp" b/main.cpp
index 68a63cc..9979be5 100644
--- "a/C:\\dev\\jade\\chatglm_cpp2\\main.cpp"
+++ b/main.cpp
@@ -207,11 +207,13 @@ int main(int argc, char** argv) {
}
model->reshape(shapes);
- ov::preprocess::PrePostProcessor p3(model);
- p3.input("input_ids").tensor().set_element_type(ov::element::i32); // cast to the type of tokenyzer's output
- p3.input("attention_mask").tensor().set_element_type(ov::element::i32);
- p3.input("position_ids").tensor().set_element_type(ov::element::i32);
- model = p3.build();
+ {
+ ov::preprocess::PrePostProcessor p3(model);
+ p3.input("input_ids").tensor().set_element_type(ov::element::i32); // cast to the type of tokenyzer's output
+ p3.input("attention_mask").tensor().set_element_type(ov::element::i32);
+ p3.input("position_ids").tensor().set_element_type(ov::element::i32);
+ model = p3.build();
+ }
//ov::InferRequest ireq = core.compile_model(model, "CPU", { ov::cache_dir("llm-cache") }).create_infer_request();
//ov::InferRequest ireq = core.compile_model(model, "CPU").create_infer_request();
startTime = Time::now();
@@ -222,6 +224,9 @@ int main(int argc, char** argv) {
ov::InferRequest ireq;
int32_t out_token;
+ auto model_inputs = model->inputs();
+ model = nullptr;
+
for (std::string input_text : sentences) {
std::cout << " #### sentence: index " << input_text << std::endl;
@@ -239,7 +244,7 @@ int main(int argc, char** argv) {
duration_ms = get_duration_ms_until_now(startTime);
std::cout << "Get Tokenizer id took " << duration_ms << " ms" << std::endl;
- for (const ov::Output<ov::Node>& input : model->inputs()) {
+ for (const ov::Output<ov::Node>& input : model_inputs) {
for (const std::string& name : input.get_names()) {
if (name.rfind("past_key_values", 0) == 0) {
ireq.get_tensor(input).set_shape(input.get_partial_shape().get_min_shape());
@@ -290,7 +295,7 @@ int main(int argc, char** argv) {
constexpr int32_t SPECIAL_EOS_TOKEN = 2; // There's no way to extract the value from the tokenizer for now
while (out_token != SPECIAL_EOS_TOKEN) {
startTime = Time::now();
- for (const ov::Output<ov::Node>& input : model->inputs()) {
+ for (const ov::Output<ov::Node>& input : model_inputs) {
for (const std::string& name : input.get_names()) {
if (name.rfind("past_key_values", 0) == 0) {
ireq.set_tensor(input, ireq.get_tensor("present" + name.substr(15)));