Skip to content

Commit

Permalink
added test data from LSTM layout feature
Browse files Browse the repository at this point in the history
 + some refactoring of data generator
  • Loading branch information
Abdurrahheem committed May 15, 2023
1 parent e478135 commit f06b101
Show file tree
Hide file tree
Showing 11 changed files with 17 additions and 11 deletions.
Binary file added testdata/dnn/onnx/data/input_lstm_layout_0_0.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_lstm_layout_0_1.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_lstm_layout_0_2.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_lstm_layout_1_0.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_lstm_layout_1_1.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/input_lstm_layout_1_2.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/output_lstm_layout_0.npy
Binary file not shown.
Binary file added testdata/dnn/onnx/data/output_lstm_layout_1.npy
Binary file not shown.
28 changes: 17 additions & 11 deletions testdata/dnn/onnx/generate_onnx_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,10 +1299,13 @@ def __init__(self, **params: Any) -> None:
params[R] = weight_scale * np.ones(
(1, number_of_gates * hidden_size, hidden_size)
).astype(np.float32)

params[B] = np.ones((1, 2 * number_of_gates * hidden_size)).astype(np.float32)

params[H_0] = np.ones((1, batch_size, hidden_size)).astype(np.float32)
params[C_0] = np.ones((1, batch_size, hidden_size)).astype(np.float32)
if H_0 not in params and C_0 not in params:
params[H_0] = np.ones((1, batch_size, hidden_size)).astype(np.float32)
params[C_0] = np.ones((1, batch_size, hidden_size)).astype(np.float32)

params[P] = weight_scale * np.ones((1, number_of_peepholes * hidden_size)).astype(
np.float32)

Expand All @@ -1317,22 +1320,22 @@ def __init__(self, **params: Any) -> None:
b = (
params[B]
if B in params
else np.ones(2 * number_of_gates * hidden_size, dtype=np.float32)
else np.ones(1, 2 * number_of_gates * hidden_size, dtype=np.float32)
)
p = (
params[P]
if P in params
else np.ones(number_of_peepholes * hidden_size, dtype=np.float32)
else np.ones(1, number_of_peepholes * hidden_size, dtype=np.float32)
)
h_0 = (
params[H_0]
if H_0 in params
else np.ones((batch_size, hidden_size), dtype=np.float32)
else np.ones((1, batch_size, hidden_size), dtype=np.float32)
)
c_0 = (
params[C_0]
if C_0 in params
else np.ones((batch_size, hidden_size), dtype=np.float32)
else np.ones((1, batch_size, hidden_size), dtype=np.float32)
)

self.X = x
Expand Down Expand Up @@ -1375,7 +1378,6 @@ def step(self) -> Tuple[np.ndarray, np.ndarray]:
H_t = self.H_0
C_t = self.C_0
for x in np.split(self.X, self.X.shape[0], axis=0):
print(x.shape, self.W.shape)
gates = (
np.dot(x, np.transpose(self.W))
+ np.dot(H_t, np.transpose(self.R))
Expand All @@ -1391,7 +1393,6 @@ def step(self) -> Tuple[np.ndarray, np.ndarray]:
h_list.append(H)
H_t = H
C_t = C
print(H.shape)

concatenated = np.concatenate(h_list)
if self.num_directions == 1:
Expand Down Expand Up @@ -1483,14 +1484,12 @@ def save_model_and_data_lstm_layout(lstm, layout, x, hx, cx, basename):
onnx.checker.check_model(m)

for i, data in enumerate(inputs):
print(data.shape)
input_files = os.path.join("data", "input_" + basename + f"_{str(i)}.npy")
data = data.astype(np.float32)
np.save(input_files, np.ascontiguousarray(data.data))

output_file = os.path.join("data", "output_" + basename + ".npy")
Y_h = data.astype(np.float32)
print(Y_h.shape)
np.save(output_file, np.ascontiguousarray(Y_h.data))


Expand All @@ -1510,7 +1509,14 @@ def save_model_and_data_lstm_layout(lstm, layout, x, hx, cx, basename):
else:
x = np.ones((seq_length, batch_size, input_size)).astype(np.float32)

lstm = LayoutLSTM(X=x, layout=layout, hidden_size=hidden_size, input_size=input_size)
lstm = LayoutLSTM(
X=x,
initial_h=hx,
initial_c=cx,
layout=layout,
hidden_size=hidden_size,
input_size=input_size
)

save_model_and_data_lstm_layout(lstm, layout, x, hx, cx, f"lstm_layout_{str(layout)}")

Expand Down
Binary file added testdata/dnn/onnx/models/lstm_layout_0.onnx
Binary file not shown.
Binary file added testdata/dnn/onnx/models/lstm_layout_1.onnx
Binary file not shown.

0 comments on commit f06b101

Please sign in to comment.