-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathnlu_model.py
32 lines (24 loc) · 1.05 KB
/
nlu_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import logging
import pprint
from rasa.nlu.training_data import load_data
from rasa.nlu import config
from rasa.nlu.model import Trainer
from rasa.nlu.model import Interpreter
from rasa.nlu.test import run_evaluation
# Log file used by both train_nlu and run_nlu via logging.basicConfig;
# it is a relative path, so it is created in the current working directory.
logfile = "nlu_model.log"
def train_nlu(data_path, configs, model_path, eval_data_path=None):
    """Train a Rasa NLU model, persist it, and run an evaluation on it.

    Args:
        data_path: Path to the NLU training data file.
        configs: Path to the NLU pipeline configuration file, passed to
            ``config.load``.
        model_path: Directory under which ``Trainer.persist`` stores the model.
        eval_data_path: Optional path to held-out data for evaluation.
            Defaults to ``data_path``, reproducing the original behaviour of
            evaluating on the training data itself (which yields an
            optimistic score, since the model has seen every example).

    Returns:
        The model directory returned by ``Trainer.persist``.
    """
    logging.basicConfig(filename=logfile, level=logging.DEBUG)
    training_data = load_data(data_path)
    trainer = Trainer(config.load(configs))
    trainer.train(training_data)
    # A fixed model name gives a stable output directory instead of a
    # timestamped one, so the trained model can be loaded from a known path.
    model_directory = trainer.persist(model_path, fixed_model_name='nlu')
    if eval_data_path is None:
        eval_data_path = data_path
    run_evaluation(eval_data_path, model_directory)
    return model_directory
# Sample utterances parsed by run_nlu when no queries are supplied.
# A tuple (immutable) so the default cannot be mutated between calls.
_DEFAULT_QUERIES = (
    "Share some latest news around the world?",
    "What is going on in technology?",
    "What is going on in education?",
)


def run_nlu(nlu_path, queries=None):
    """Load a persisted NLU model and pretty-print its parse of each query.

    Args:
        nlu_path: Directory of a persisted NLU model, passed to
            ``Interpreter.load``.
        queries: Iterable of utterances to parse, or a single string.
            Defaults to three news-related sample questions.

    Returns:
        A list with the ``interpreter.parse`` result for each query, in
        input order.
    """
    logging.basicConfig(filename=logfile, level=logging.DEBUG)
    interpreter = Interpreter.load(nlu_path)
    if queries is None:
        queries = _DEFAULT_QUERIES
    elif isinstance(queries, str):
        # Treat a bare string as one utterance rather than iterating its
        # characters.
        queries = (queries,)
    results = []
    for query in queries:
        result = interpreter.parse(query)
        pprint.pprint(result)
        results.append(result)
    return results
if __name__ == '__main__':
    # Imported locally because it is only needed when run as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description="Optionally train the NLU model, then parse sample queries with it.")
    parser.add_argument(
        '--train', action='store_true',
        help="train from ./data/nlu.md with nlu_config.yml into ./models before running")
    args = parser.parse_args()
    if args.train:
        train_nlu('./data/nlu.md', 'nlu_config.yml', './models')
    run_nlu('./models/nlu')