-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
rixwew
committed
Jan 24, 2019
1 parent
6e21874
commit 7d96f84
Showing
13 changed files
with
540 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[submodule "examples/question-answering/dataset"] | ||
path = examples/question-answering/dataset | ||
url = https://github.com/shuzi/insuranceQA.git |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# Question answering Example | ||
|
||
Question answering implementation is based on paper LSTM-based Deep Learning Models | ||
for Non-factoid Answer Selection - Tan, dos Santos, Xiang and Zhou. | ||
|
||
## Requirement | ||
|
||
* pytorch 1.0 | ||
* numpy | ||
* gensim | ||
* elasticsearch | ||
|
||
## Download insurance qa data and train model | ||
|
||
```bash | ||
bash prepare.sh | ||
python train.py | ||
``` | ||
|
||
InsuranceQA Version1 top1 precision result | ||
|
||
| Model | Validation | Test1 | Test2 | | ||
|:---------------------------------|-----------:|------:|------:| | ||
| QA-LSTM basic-model, max pooling(100 epoch) | 62.2 | 63.8 | 58.8 | | ||
| QA-LSTM basic-model, max pooling(paper) | 64.3 | 63.1 | 58.0 | | ||
|
||
|
||
## Search answers using elasticsearch plugin | ||
|
||
```bash | ||
export PYTHONPATH=$PATH_TO_SCRIPT_DIR/lib:$PYTHONPATH | ||
python search_example.py --question "Can a Non us citizen get Life Insurance" | ||
--result_size 5 | ||
``` | ||
|
||
```json | ||
{ | ||
"took": 36, | ||
"timed_out": false, | ||
"_shards": { | ||
"total": 5, | ||
"successful": 5, | ||
"skipped": 0, | ||
"failed": 0 | ||
}, | ||
"hits": { | ||
"total": 870, | ||
"max_score": null, | ||
"hits": [ | ||
{ | ||
"_index": "answers", | ||
"_type": "answer", | ||
"_id": "o1i1f2gBaJEWlukYG7sK", | ||
"_score": 0.5443098, | ||
"_source": { | ||
"description": "a non citizen can get life insurance with most company if they have a green card or an H-1b work visa some company do require the applicant be a US citizen before allow them get a life insurance policy and some will only allow green card but not work visa contact an agent find out which company will work for your situation" | ||
}, | ||
"sort": [ | ||
0.5443098 | ||
] | ||
}, | ||
{ | ||
"_index": "answers", | ||
"_type": "answer", | ||
"_id": "81i2f2gBaJEWlukYacAb", | ||
"_score": 0.7198508, | ||
"_source": { | ||
"description": "yes there be absolutely no requirement a person be a citizen buy life insurance each company make its own decision on requirement but citizenship be not 1 them so long as you be in the country legally you can buy life insurance different ID be require different carrier but rest assure if your age and health warrant it you can buy life insurance on yourself here in the USA love help thank you Gary Lane" | ||
}, | ||
"sort": [ | ||
0.7198508 | ||
] | ||
}, | ||
{ | ||
"_index": "answers", | ||
"_type": "answer", | ||
"_id": "0Fiyf2gBaJEWlukYdLDC", | ||
"_score": 0.75013983, | ||
"_source": { | ||
"description": "you do not have be a citizen obtain life insurance US life insurer require the propose insured must be a permanent resident of the US that mean a US citizen or a non US citizen who be a lawful permanent US resident ( green card or on certain visa type the applicant will also need have the means pay premium and have a demonstrable life insurance need i.e. generate earn income or asset protect here some insurer have develop foreign national program that can also work in situation where established US interest and tie exist plus meet some additional criterion citizen of some country may not be eligible it can be a complex area of field underwriting so much so that our firm have develop a special questionnaire help shop for coverage be sure work with a life insurance professional with experience in this area" | ||
}, | ||
"sort": [ | ||
0.75013983 | ||
] | ||
}, | ||
{ | ||
"_index": "answers", | ||
"_type": "answer", | ||
"_id": "9Fi2f2gBaJEWlukYacBr", | ||
"_score": 0.75358534, | ||
"_source": { | ||
"description": "yes a non US citizen can get life insurance with many American company it be up to the discretion of each company as to what type of citizenship or residency they will accept a green card be usually ok and many company will accept a work visa as qualification for apply for life insurance in the US get life insurance in the US as a non US citizen however almost always require have a residence in the United States" | ||
}, | ||
"sort": [ | ||
0.75358534 | ||
] | ||
}, | ||
{ | ||
"_index": "answers", | ||
"_type": "answer", | ||
"_id": "Nliwf2gBaJEWlukYIKcx", | ||
"_score": 0.7816857, | ||
"_source": { | ||
"description": "almost anyone can get life insurance the only people who can not get life insurance those who have serious health problem who fall outside the age guideline guarantee issue those who do not have any income at all even they may able to get a policy with a cap on the face amount in the us those who do not have citizenship a green card work visa" | ||
}, | ||
"sort": [ | ||
0.7816857 | ||
] | ||
} | ||
] | ||
} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import collections | ||
|
||
import numpy | ||
import torch.utils.data | ||
|
||
|
||
class Vocab(object): | ||
|
||
def __init__(self, vocab_path, lexicon, unk_surf='<UNK>', thresh=5): | ||
self.vid2surf = dict() | ||
lexicon = {vocab_id for vocab_id, count in lexicon.items() if count >= thresh} | ||
with open(vocab_path, encoding='utf-8') as f: | ||
for _line in f: | ||
vocab_id, surf = _line.rstrip().split('\t') | ||
if vocab_id in lexicon: | ||
self.vid2surf[vocab_id] = surf | ||
self.vid2wid = {vocab_id: i + 1 for i, vocab_id in enumerate(self.vid2surf)} | ||
self.wid2surf = {self.vid2wid.get(vid): surf for vid, surf in self.vid2surf.items()} | ||
self.unk_surf = unk_surf | ||
self.unk_word_id = len(self.vid2wid) + 1 | ||
|
||
def surfaces(self, vocab_ids): | ||
return [self.vid2surf.get(vocab_id, self.unk_surf) for vocab_id in vocab_ids] | ||
|
||
def word_ids(self, vocab_ids): | ||
return [self.vid2wid.get(vocab_id, self.unk_word_id) for vocab_id in vocab_ids] | ||
|
||
def __len__(self): | ||
return len(self.vid2surf) + 1 | ||
|
||
|
||
class AnswerData(object): | ||
|
||
def __init__(self, data_path): | ||
self.answers = dict() | ||
self.lexicon = list() | ||
with open(data_path, encoding='utf-8') as f: | ||
for _line in f: | ||
answer_id, answer = _line.rstrip().split('\t') | ||
vocab_ids = answer.split(' ') | ||
self.answers[int(answer_id)] = vocab_ids | ||
self.lexicon.extend(vocab_ids) | ||
self.lexicon = collections.Counter(self.lexicon) | ||
|
||
|
||
class QaData(object): | ||
|
||
def __init__(self, data_path): | ||
self.questions = list() | ||
self.positive = list() | ||
self.negative = list() | ||
self.lexicon = list() | ||
with open(data_path, encoding='utf-8') as f: | ||
for _line in f: | ||
values = _line.rstrip().split('\t') | ||
if len(values) == 2: | ||
question, answer_ids = values | ||
positive_ids = list(map(int, answer_ids.split(' '))) | ||
negative_ids = list() | ||
elif len(values) == 3: | ||
answer_ids, question, pool = values | ||
positive_ids = list(map(int, answer_ids.split(' '))) | ||
negative_ids = list(filter(lambda x: x not in set(positive_ids), | ||
map(int, pool.split(' ')))) | ||
else: | ||
continue | ||
vocab_ids = question.split(' ') | ||
self.questions.append(vocab_ids) | ||
self.lexicon.extend(vocab_ids) | ||
self.positive.append(positive_ids) | ||
self.negative.append(negative_ids) | ||
self.lexicon = collections.Counter(self.lexicon) | ||
|
||
|
||
class InsuranceQaDataset(torch.utils.data.Dataset): | ||
|
||
def __init__(self, question_data, answer_data, vocab, max_length=200): | ||
self.vocab = vocab | ||
self.positive = question_data.positive | ||
self.negative = question_data.negative | ||
self.questions = list(map(self.vocab.word_ids, question_data.questions)) | ||
self.answer_map = dict() | ||
for answer_id, vids in answer_data.answers.items(): | ||
self.answer_map[answer_id] = self.vocab.word_ids(vids[:max_length]) | ||
self.answers = list(self.answer_map.values()) | ||
|
||
def __len__(self): | ||
return len(self.questions) | ||
|
||
def __getitem__(self, index): | ||
question, positive_ids, negative_ids = \ | ||
self.questions[index], self.positive[index], self.negative[index] | ||
positive = self.answer_map[positive_ids[numpy.random.randint(len(positive_ids))]] | ||
if len(negative_ids) > 0: | ||
negative = self.answer_map[negative_ids[numpy.random.randint(len(negative_ids))]] | ||
else: | ||
negative = self.answers[numpy.random.randint(len(self.answers))] | ||
return torch.LongTensor(question), \ | ||
torch.LongTensor(positive), \ | ||
torch.LongTensor(negative) | ||
|
||
def get_qa_entry(self, index): | ||
question, positive_ids, negative_ids = \ | ||
self.questions[index], self.positive[index], self.negative[index] | ||
positives = [self.answer_map[positive_id] for positive_id in positive_ids] | ||
negatives = [self.answer_map[negative_id] for negative_id in negative_ids] | ||
return question, positives, negatives | ||
|
||
@classmethod | ||
def collate(cls, batch): | ||
qs, ps, ns = zip(*batch) | ||
return torch.nn.utils.rnn.pad_sequence(qs, batch_first=True), \ | ||
torch.nn.utils.rnn.pad_sequence(ps, batch_first=True), \ | ||
torch.nn.utils.rnn.pad_sequence(ns, batch_first=True), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import torch | ||
|
||
|
||
class QaLoss(torch.nn.Module): | ||
|
||
def __init__(self, margin): | ||
super().__init__() | ||
self.margin = margin | ||
|
||
def forward(self, question, positive, negative): | ||
""" | ||
max {0, margin - cosine(q, a+) + cosine(q, a-)} | ||
""" | ||
positive_sim = (question * positive).sum(1, keepdim=True) | ||
negative_sim = (question * negative).sum(1, keepdim=True) | ||
zeros = positive_sim.data.new_zeros(*positive_sim.shape) | ||
loss = torch.cat((zeros, negative_sim - positive_sim + self.margin), dim=1) | ||
loss, _ = torch.max(loss, dim=1) | ||
return torch.mean(loss) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import torch | ||
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence | ||
|
||
|
||
class SentenceEncoder(torch.nn.Module): | ||
|
||
def __init__(self, embedding_weights, hidden_size): | ||
super().__init__() | ||
embedding_weights = torch.FloatTensor(embedding_weights) | ||
self.embedding = torch.nn.Embedding.from_pretrained(embedding_weights) | ||
self.rnn = torch.nn.LSTM(embedding_weights.shape[-1], hidden_size, | ||
batch_first=True, bidirectional=True) | ||
|
||
def forward(self, x): | ||
lengths = (-x.data.eq(0).long() + 1).sum(1) | ||
_, idx_sort = torch.sort(lengths, dim=0, descending=True) | ||
_, idx_unsort = torch.sort(idx_sort, dim=0) | ||
x = x.index_select(0, idx_sort) | ||
lengths = lengths.index_select(0, idx_sort) | ||
x = self.embedding(x) | ||
x = pack_padded_sequence(x, lengths, batch_first=True) | ||
x, *_ = self.rnn(x) | ||
x, _ = pad_packed_sequence(x, batch_first=True, padding_value=float('-inf')) | ||
x, _ = torch.max(x, dim=1) | ||
norm = x.norm(p=2, dim=1, keepdim=True) | ||
x = x.div(norm) | ||
x = x.index_select(0, idx_unsort) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
#!/bin/bash | ||
|
||
# download and unzip insurance qa dataset | ||
git submodule update --recursive | ||
|
||
# download pretrained word2vec model | ||
curl -O https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz | ||
gzip -d GoogleNews-vectors-negative300.bin.gz |
Oops, something went wrong.