From 62ebe5a5d7178c0fa13a27b26e7ef4ac572d60d6 Mon Sep 17 00:00:00 2001 From: Abduragim Shtanchaev Date: Tue, 25 Apr 2023 16:00:27 +0300 Subject: [PATCH] added test data and model for lstm wihout hidden states initialization --- .../dnn/onnx/data/input_lstm_init_h0_c0_0.npy | Bin 0 -> 256 bytes .../dnn/onnx/data/input_lstm_init_h0_c0_1.npy | Bin 0 -> 160 bytes .../dnn/onnx/data/input_lstm_init_h0_c0_2.npy | Bin 0 -> 160 bytes .../dnn/onnx/data/output_lstm_init_h0_c0.npy | Bin 0 -> 192 bytes testdata/dnn/onnx/generate_onnx_models.py | 26 ++++++++++++++++++ testdata/dnn/onnx/models/lstm_init_h0_c0.onnx | Bin 0 -> 3919 bytes 6 files changed, 26 insertions(+) create mode 100644 testdata/dnn/onnx/data/input_lstm_init_h0_c0_0.npy create mode 100644 testdata/dnn/onnx/data/input_lstm_init_h0_c0_1.npy create mode 100644 testdata/dnn/onnx/data/input_lstm_init_h0_c0_2.npy create mode 100644 testdata/dnn/onnx/data/output_lstm_init_h0_c0.npy create mode 100644 testdata/dnn/onnx/models/lstm_init_h0_c0.onnx diff --git a/testdata/dnn/onnx/data/input_lstm_init_h0_c0_0.npy b/testdata/dnn/onnx/data/input_lstm_init_h0_c0_0.npy new file mode 100644 index 0000000000000000000000000000000000000000..9a566490840e539e69364b9ffc168cf3c6f7c5c3 GIT binary patch literal 256 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$7ItqqBWTvU3P^&-|;41bL+!t|Wg?;P6EV~Ow=Gq?GAY$il61vZIgU$X; zyJpxMw>sIUSBl!t>|M3@{tRaOrf+8Zd{SKPgN$SB-Y&dv|La`&{{M6BZ5=*q+6!lX zwiV>$*>~52{eTwpT>Gqbjr&yj*zB(>UbD{){=ENrCe!}E20r^==IQU>^ft?`=4G-y E04@<%rvLx| literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/data/input_lstm_init_h0_c0_1.npy b/testdata/dnn/onnx/data/input_lstm_init_h0_c0_1.npy new file mode 100644 index 0000000000000000000000000000000000000000..0374ce93d16645bc6d9ea784aeb69fb517ccb7c3 GIT binary patch literal 160 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-ItoB!p{b)#t3V#$`lxHWUno_5Un}>@eOeRj_m|eB+cO@u+V_)3Y`^oL Gl>Gp=#VSJp literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/data/input_lstm_init_h0_c0_2.npy b/testdata/dnn/onnx/data/input_lstm_init_h0_c0_2.npy new file mode 100644 index 0000000000000000000000000000000000000000..0a6f757f60193ef21ada7fabd4a4f2486e6c1c9d GIT binary patch literal 160 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-ItoB!p{b)#t3V#$vY0o)-mS*X{v5x`eo>jrdyh4V+k3~p+&8I+%fTga G+I|40ktr_# literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/data/output_lstm_init_h0_c0.npy b/testdata/dnn/onnx/data/output_lstm_init_h0_c0.npy new file mode 100644 index 0000000000000000000000000000000000000000..6ff51f60b228de14e940b47646ecd6bed37114f4 GIT binary patch literal 192 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$7Itqq53Kp6=3bhL40j^)K9PO5W^09l7@O_`>H!Hh0lLYM4;_CJaC_T2z qYMQyv*XEjCh@7_VfiIKxeF<&0GiE-%_wr)yefa@%?K(t%?gao}m^)Me literal 0 HcmV?d00001 diff --git a/testdata/dnn/onnx/generate_onnx_models.py b/testdata/dnn/onnx/generate_onnx_models.py index 50141f4fe..7c23a7674 100644 --- a/testdata/dnn/onnx/generate_onnx_models.py +++ b/testdata/dnn/onnx/generate_onnx_models.py @@ -1145,7 +1145,33 @@ def forward(self, x): lstm = LSTM(features, hidden, batch, bidirectional=True) save_data_and_model("lstm_bidirectional", input, lstm) +class LSTM_hidden_state_inputs(nn.Module): + def __init__(self, features, hidden, batch, num_layers=1, bidirectional=False): + super(LSTM_hidden_state_inputs, self).__init__() + self.lstm = nn.LSTM(features, hidden, num_layers, bidirectional=bidirectional) + + def forward(self, x, h, c): + return self.lstm(x, (h, c))[0] + +batch = 1 +features = 16 +hidden = 8 +seq_len = 2 +num_layers = 1 +bidirectional = False + +lstm = LSTM_hidden_state_inputs( + features, + hidden, + batch, + num_layers=num_layers, + bidirectional=bidirectional +) +input = torch.randn(seq_len, batch, features) +h0 = torch.randn(num_layers + int(bidirectional), batch, hidden) +c0 = torch.randn(num_layers + int(bidirectional), batch, hidden) +save_data_and_model_multy_inputs("lstm_init_h0_c0", lstm, input, h0, c0, export_params=True) class HiddenLSTM(nn.Module): def __init__(self, input_size, hidden_size, num_layers=1, is_bidirectional=False): diff --git a/testdata/dnn/onnx/models/lstm_init_h0_c0.onnx b/testdata/dnn/onnx/models/lstm_init_h0_c0.onnx new file mode 100644 index 0000000000000000000000000000000000000000..09b44aa6296ce80955cc6dd138732c6d4f3df7a7 GIT binary patch literal 3919 zcmZ`+2{e^!_c!KrjH!rR2`MT*~yMNEx>v{IOe+&_6k-)IvfFSQJ^1?cr zOEi}lTu@|Qi5 zTWQ%pamE4uJA*y_gFXJDl9%`~pv?FSq9M-|4)xp?;-ewK;Fl7yWbpI-`Q-8oF)f+@ zwr}SD$09ADr!6o3CzIoj5Fej#A7zpMy+}{}qs|^%Il&qyr5OASC5Ds~ztN8+wv3eO zNysA2dB?zS#~WJ7@q;yYzS9ohOSCpp3l~fzg2JZ~^zjfNZ#rc$s2~n3KnFRVy|_8_v8$k< z>ngq5(@k}zPRCN=H8d6mSj%oXqtt<9$T#0X)CweEi{}X6 zM)R_d2taGKO8e%3&_cK68z>pMiMmH~Sga?xlb0~lYX;Kt6`c-lJ?u6hpfwr)$ojLWUG z#dQh@Gh%4WsY;YQG)(nPD@o7DI}&Hh2Tc~u^yxiSII=wj)7jT3OXNCxVPzKPeMrSu zMt7+2^%CMS@qs$iNJyOVo_0noz~ub}_?6jTKe1&t#^00%+o$s|kuQ)w2;sv!*Arnw zSr!_y<;ZgrJK)xGsex)ddcIPFu{upq(#U{&8%N2jy;|@jNeU+Jog+==(U9+Wleo{3 z#w$bW;MueQTh?V#md!3QxO+b&C0%Dr4itlcixPeHpbB=d_L5u7Sf2mU2$<=V1-tu7 zXyR=S*qsM-eH(;6+#BrLfhPu?)m-elnE^>T^*lkD5|Y?bi^JyI>8Mm36b&1YoX4sV zbj%QBZBJw0Z5QGaB8Qfvn?W)w9lWKi(M+ZQ!fcW;bn8O2*ZLVeevO6Q8o{7)Qx{h} z_rss{kArzsB{k|u0P!+c_*CZu9MB^fKaey3P z^K+6=ac~)~bY2dL4^6=0mvdx}aXTGSb%(CBZW^|zotD~7Q2T0SxKgM~|QS0@CD197MrGz(q^Zh%dk zt?1&P2))CRaE+smmqtI5+Vw@4WOEQ`n;|xOWP*K2BQK!dkZqvLglNuOB3x7ew=$DJ zcK1wt+nUb0f5({Qv^w!hWTbJ^IF}yVdzT)%90zxdFS1LVn#pFz33|~*0|tu)Vfh|W z(1|}s+={t)uOkjNtUrTWt0RcYOf?W+RfC&9&LKS=iX^=;1M^PQg0`aseiz(_CU=Ff zp{#|aXH^ay-9+-?t1%RgID;x@5%!ey5&iLT+Hk29cWgTYzkSw0O^X-Q*Y^ZDF``5- zW26#>zP~BS|%nECOPR)jZPHm{^AEQ!cns7V&2yhY#>!-PVpmXz6 zAWhMMgk@?GtDHoXjgiELgQ>VER~+EyXdD#^#gWCQcv{U0@cv5)Z1jsqliDn3vSDI@ zYZbQ4vceddVyvQuG!&*lL#P^BS^Gfk@eg!W)JX_Q4usNid#v#?WW9JCixzi%iLZMO z<{S@)xbso?+-VJ}c}GB!?!cuq!AMwSwGex|Z87`NA#(qPGn7AyfpPU?DEv^Bm+YPc z7f(-zbkPzlv6)F%evibybCcXA;?Q8(L=VtTdhd}32!G1Q(W8p+BKskkzc&}=b_wBD z&t*7$RU-W6=Td*AeuS+4&nSCSK^%ORW5Rb$7O42^1ADj>i%iDoPjaoKqD3CJT(81u zR=dGthA4V1Z)5izP@!SB{OR;vl^{HCIZhRMN%C3wV0ZK>@ALIiEas?#-vu^rrz4Ba z8=b6AgD<$`tI{vp$v7sPh^hnp*c%do4llfrTgPT|eVS>1^eZCRSc3@-LAdTr242fd zKu6Uu_^_#uE@NBOZ#WVL!3z>0J*pJ4&0InMaw;)$Y$8!2!JrmDLCkfI!eaw5%xNnD zuCFluCpsEi8JTc-C#7o>qOdvWDBOKJc~92}&}F6-_=?W}pZ%(UXM=oEWdBS^>8%7= zzE#-2LKWP;)WWR(#qgxb8=iINg6O$a2>!Bza6Hu7oZ2iae+H*O`&aY)>rG+lamQZNNO24Ux>l zgu_w;u_>iQ=u<3Rwh{ii%*4t06vhWkS4ur3HZVOPA?g7$SqC{bnw8YdZl z3>Nw#a=7RfK%l|E(tXcgh*N_BcSC)P3>I^p{s_di$(-270iRKBO#bseSn^c z$cCZEc6dZQ6z@qgN%ziOv?95W)~GIks-^zu{c9jCcNYY_u^DbmawQ$LV5iw`$030* zeE&&^9R9)rM$9p?T+0W<26ABXilc*v8|iFxK^MS1Y+sFS z7CVu}sD<_Sn@K`}4e2bOAepC^12pVA*^ITb2!kruvL-vt3QFGx^%>wZ9e_21@E}51|9>m#Bv1GeN zGTdlZ#4w3a7+2Ay$13mAh_uD1Dp5$kU)Dh5*YY$n%mWu~S&d;2;!vl3keu084LWD4 zG2L_n?_TsFNRcT;Nv~>5Dc^w-9}|ho)mgYme>cqC!o_n6hLCtg1UpQ2(U+W`4enQ% zkp>|}u=*5?;YF#`K(GW?DU?8_lMCAK4kQ+^8$>dTpvqi=l*N=l+5A$pweDc$t?f3* zF)D>vVS*L0(O}@PntV&Iq8`-%Ndp8U6rPcmm`qZvoJ9M68{$cCUk2Xi%whdpaq@a< z4PKsEMi0N@;vV3`P42N6Y+DQ|>SOeL%WROA5hK^<BI8mnP?vU(is3fgBh^CRSrv1Z?e<|PqRL6P{9hb$)79qQ~R9@u<&sc zEh`J4Nl#6vgDoG?$Pk!p4u=114ceBG0&bSkq<_>J&g|1gJU2o`L-%4)OAXcy+0pO2 zZu4wHWYKNx67|*h0>glH*haKbKA{BW#}p89e=E)LJc~6M*-)$U3w71c1f9mddi!

g9p>XJ$3k|(yY=X<*hb81<6$uOI9VwgK>}q1@TgiXwhIb? ze4PlWOwL>rroRyf5h-+XbAzFG1*ot(4_0P5(3loQ$WSYSq{w`FHV GN&FW{1n#u} literal 0 HcmV?d00001