# keras-self-defined-triplet-loss.py
"""
Adapted from https://kexue.fm/archives/4493.
Kept here for quick reference: a self-defined triplet (hinge) loss in Keras.
"""
from keras.layers import Input,Embedding,LSTM,Dense,Lambda
from keras.models import Model
import keras.backend as K
from keras.layers import dot
# Hyper-parameters.
word_size = 10000    # vocabulary size; the Embedding accepts ids in [0, word_size)
nb_features = 128    # embedding dimension
encode_size = 64     # LSTM output size = length of each sentence encoding
margin = 0.1         # hinge margin by which the right answer must win

# Layers shared by all branches: questions and both answers are encoded by
# the very same Embedding and LSTM weights.
# NOTE(review): mask_zero is left at its default, so any padding ids are fed
# to the LSTM like ordinary tokens — confirm this is intended.
embedding = Embedding(input_dim=word_size, output_dim=nb_features)
lstm = LSTM(units=encode_size)
def encoder(input_):
    """Encode a batch of token-id sequences into fixed-size vectors.

    Applies the module-level ``embedding`` layer and then the module-level
    ``lstm`` layer. Because those layer objects are created once, every call
    reuses (shares) the same weights.

    Args:
        input_: Keras tensor of token ids, e.g. an ``Input(shape=(None,))``.

    Returns:
        The LSTM's last output, with ``encode_size`` features per sample
        (``return_sequences`` is not enabled on ``lstm``).
    """
    return lstm(embedding(input_))
# Three inputs: the question, a matching answer and a non-matching answer,
# each a variable-length sequence of token ids.
q_input = Input(shape=(None,))
a_right = Input(shape=(None,))
a_wrong = Input(shape=(None,))

# All three go through the same shared Embedding + LSTM encoder.
q_encoded = encoder(q_input)
a_right_encode = encoder(a_right)
a_wrong_encode = encoder(a_wrong)

# Only the question branch gets an extra linear projection.
q_encode_dense = Dense(encode_size)(q_encoded)

# normalize=True turns the dot product into a cosine similarity.
right_cos = dot([q_encode_dense, a_right_encode], axes=-1, normalize=True)
wrong_cos = dot([q_encode_dense, a_wrong_encode], axes=-1, normalize=True)

# Triplet / hinge loss: max(0, margin + cos(q, wrong) - cos(q, right)).
# It reaches zero once the right answer beats the wrong one by `margin`.
triplet_loss = Lambda(
    lambda cos_pair: K.relu(margin + cos_pair[0] - cos_pair[1])
)([wrong_cos, right_cos])

# The model's output already is the per-sample loss, so the training loss
# just returns the prediction and ignores whatever targets are supplied.
model_train = Model(inputs=[q_input, a_right, a_wrong], outputs=triplet_loss)
model_train.compile(optimizer="adam", loss=lambda y_true, y_pred: y_pred)
def train_triplet(q, a1, a2, y=None, epochs=10, **fit_kwargs):
    """Fit ``model_train`` on (question, right answer, wrong answer) triples.

    The compiled loss returns ``y_pred`` and never reads ``y_true``. A target
    array is still passed to ``fit``, and it only needs the shape
    ``(len(q), 1)``. When ``y`` is omitted, a zero array of that shape is
    used.

    Args:
        q: question token-id sequences, one row per sample.
        a1: matching ("right") answer token ids, aligned row-by-row with ``q``.
        a2: non-matching ("wrong") answer token ids, aligned with ``q``.
        y: optional dummy targets of shape ``(len(q), 1)``; values are ignored.
        epochs: number of training epochs.
        **fit_kwargs: forwarded unchanged to ``model_train.fit``
            (e.g. ``batch_size``).

    Returns:
        Whatever ``model_train.fit`` returns (a Keras ``History`` object).
    """
    import numpy as np

    if y is None:
        # Only the shape matters; the identity loss never reads the values.
        y = np.zeros((len(q), 1))
    return model_train.fit([q, a1, a2], y, epochs=epochs, **fit_kwargs)


if __name__ == "__main__":
    import numpy as np

    # Toy placeholder data (seeded for reproducibility) so the script runs
    # end to end. Replace with real, equally padded token-id arrays.
    rng = np.random.RandomState(0)
    n_samples, seq_len = 32, 20
    q = rng.randint(1, word_size, size=(n_samples, seq_len))
    a1 = rng.randint(1, word_size, size=(n_samples, seq_len))
    a2 = rng.randint(1, word_size, size=(n_samples, seq_len))
    train_triplet(q, a1, a2, epochs=10)