# -*- coding: utf-8 -*-
import tensorflow as tf
from capsule_masked import Capsule
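
# Joint slot-filling / intent-detection model (TF 1.x API, uses tf.contrib):
# a bidirectional LSTM encoder feeds word states into a slot Capsule layer,
# slot capsules are routed up to intent capsules, and, when re_routing is
# enabled, the predicted intent is fed back to refine the slot routing.
# As inferred from the calls below, Capsule (from capsule_masked) is built as
# Capsule(output_num, output_dim, ...) and called on
# (inputs, input_count_or_lengths), returning
# (capsules, routing_weight, routing_logits).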


def build_model(input_data, input_size, sequence_length, slot_size, intent_size, intent_dim, layer_size, embed_dim,
                num_rnn=1, isTraining=True, iter_slot=2, iter_intent=2, re_routing=True):
    # Stacked forward/backward LSTM cells for the bidirectional encoder.
    cell_fw_list = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(layer_size) for _ in range(num_rnn)])
    cell_bw_list = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(layer_size) for _ in range(num_rnn)])
    if isTraining:
        # Dropout on the cell inputs/outputs is applied at training time only.
        cell_fw_list = tf.contrib.rnn.DropoutWrapper(cell_fw_list, input_keep_prob=0.8,
                                                     output_keep_prob=0.8)
        cell_bw_list = tf.contrib.rnn.DropoutWrapper(cell_bw_list, input_keep_prob=0.8,
                                                     output_keep_prob=0.8)

    # Token embedding table, learned from scratch with Xavier initialization.
    embedding = tf.get_variable('embedding', [input_size, embed_dim],
                                initializer=tf.contrib.layers.xavier_initializer())
    inputs = tf.nn.embedding_lookup(embedding, input_data)
    with tf.variable_scope('slot_capsule'):
        # Bidirectional LSTM encoder; H holds the concatenated forward/backward
        # hidden states for every time step.
        H, _, _ = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
            [cell_fw_list],
            [cell_bw_list],
            inputs=inputs,
            sequence_length=sequence_length,
            dtype=tf.float32)
        # First routing pass from word states to slot capsules (no re-routing yet).
        sc = Capsule(slot_size, layer_size, reuse=tf.AUTO_REUSE, iter_num=iter_slot, wrr_dim=(layer_size, intent_dim))
        slot_capsule, routing_weight, routing_logits = sc(H, sequence_length, re_routing=False)
    with tf.variable_scope('slot_proj'):
        # Per-token slot logits, flattened to [-1, slot_size] for the loss.
        slot_p = tf.reshape(routing_logits, [-1, slot_size])
    with tf.variable_scope('intent_capsule'):
        # Route the slot capsules up to intent capsules.
        intent_capsule, intent_routing_weight, _ = Capsule(intent_size, intent_dim, reuse=tf.AUTO_REUSE,
                                                           iter_num=iter_intent)(slot_capsule, slot_size)
    with tf.variable_scope('intent_proj'):
        intent = intent_capsule
    outputs = [slot_p, intent, routing_weight, intent_routing_weight]
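
    # Re-routing: select the intent capsule with the largest norm as the
    # predicted intent, then feed it back into a second slot-routing pass so
    # the slot assignments can be refined by the predicted intent.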
    if re_routing:
        # One-hot mask over intent capsules, selecting the one with the
        # largest activation norm, tiled across the capsule dimension.
        pred_intent_index_onehot = tf.one_hot(tf.argmax(tf.norm(intent_capsule, axis=-1), axis=-1), intent_size)
        pred_intent_index_onehot = tf.tile(tf.expand_dims(pred_intent_index_onehot, 2),
                                           [1, 1, tf.shape(intent_capsule)[2]])
        # Keep only the winning intent capsule and reshape it for the slot layer.
        intent_capsule_max = tf.reduce_sum(tf.multiply(intent_capsule, tf.cast(pred_intent_index_onehot, tf.float32)),
                                           axis=1,
                                           keepdims=False)
        caps_ihat = tf.expand_dims(tf.expand_dims(intent_capsule_max, 1), 3)
        # Second slot-routing pass, sharing variables with the first one.
        with tf.variable_scope('slot_capsule', reuse=True):
            slot_capsule_new, routing_weight_new, routing_logits_new = sc(H, sequence_length, caps_ihat=caps_ihat,
                                                                          re_routing=True)
        with tf.variable_scope('slot_proj', reuse=True):
            slot_p_new = tf.reshape(routing_logits_new, [-1, slot_size])
        outputs = [slot_p_new, intent, routing_weight_new, intent_routing_weight]
    return outputs
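

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not from the original repo): a minimal way to
# wire build_model into a TF 1.x graph. The placeholder names and the
# vocabulary/label sizes below are hypothetical assumptions; adjust them to
# your own dataset.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    input_data = tf.placeholder(tf.int32, [None, None], name='input_data')      # [batch, time] token ids
    sequence_length = tf.placeholder(tf.int32, [None], name='sequence_length')  # per-example lengths
    slot_p, intent, routing_weight, intent_routing_weight = build_model(
        input_data, input_size=10000, sequence_length=sequence_length,
        slot_size=120, intent_size=21, intent_dim=32, layer_size=64, embed_dim=100)
    # slot_p holds slot logits reshaped to [-1, slot_size]; intent holds the
    # intent capsules, whose norms give the intent prediction.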