Update code base to support pre-computed math.log
howl-anderson committed Sep 2, 2018
1 parent 618b5bb commit 22ff836
Showing 7 changed files with 120 additions and 30 deletions.
12 changes: 10 additions & 2 deletions MicroHMM/hmm.py
@@ -1,7 +1,9 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import math
import pickle
import pathlib
from typing import List, Union, Tuple

from MicroHMM.viterbi import Viterbi

@@ -23,6 +25,8 @@ def __init__(self, A=None, B=None, vocabulary=None):
self.state_observation_pair = {} # count of pair state and emission observation

def train_one_line(self, list_of_word_tag_pair):
# type: (List[Union[List[str], Tuple[str, str]]]) -> None

"""
train the model from one line of data
:param list_of_word_tag_pair: list of (word, tag) tuples
@@ -89,18 +93,22 @@ def do_train(self):
if previous_state not in self.A:
self.A[previous_state] = {}

self.A[previous_state][state] = bigram_probability
self.A[previous_state][state] = math.log(bigram_probability)

# compute emission probability
# NOTE: using dict.get() to avoid an exception when the start state has no emissions
emission_local_storage = self.state_observation_pair.get(previous_state, {})
for word, word_count in emission_local_storage.items():
emission_probability = word_count / previous_state_count

if previous_state not in self.B:
self.B[previous_state] = {}

self.B[previous_state][word] = word_count / previous_state_count
self.B[previous_state][word] = math.log(emission_probability)

def predict(self, word_list, output_graphml_file=None):
# type: (List[str], Union[str, None]) -> List[Tuple[str, str]]

if not self.A:  # self.A doubles as a training flag: empty means the model is not trained yet
self.do_train()

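The two replacements above are the heart of the commit: A and B now store math.log(probability) once at training time, so prediction can add scores instead of multiplying raw probabilities and no longer calls math.log on every lookup. A minimal sketch of the idea, using hypothetical counts rather than the repo's data structures:

import math

# Hypothetical transition counts, for illustration only
transition_count = {"<start>": {"B": 2, "M": 1}}
state_count = {"<start>": 3}

A = {}
for previous_state, counts in transition_count.items():
    A[previous_state] = {
        state: math.log(count / state_count[previous_state])
        for state, count in counts.items()
    }

# Prediction then combines scores by addition:
# path_probability = transition_probability + state_observation_likelihood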
23 changes: 6 additions & 17 deletions MicroHMM/viterbi.py
@@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-

import math

import networkx as nx
@@ -18,7 +17,7 @@ def __init__(self, A, B, vocabulary, start_state='<start>', end_state='<end>', v

# Trick for a more robust model
# TODO: why does using a number smaller than average_emission_probability achieve better performance?
self.very_small_probability = very_small_probability
self.very_small_probability = math.log(very_small_probability)

# create networkx graph
self.G = nx.Graph()
@@ -44,13 +43,9 @@ def _do_predict(self, word_list):
# path_probability = transition_probability * state_observation_likelihood

# using log(probability) as the score to prevent numerical underflow
transition_probability = math.log(
self.A[self.start_state].get(state, self.very_small_probability)
)
transition_probability = self.A[self.start_state].get(state, self.very_small_probability)

state_observation_likelihood = math.log(
self.B[state].get(word, self.very_small_probability)
)
state_observation_likelihood = self.B[state].get(word, self.very_small_probability)

path_probability = transition_probability + state_observation_likelihood

@@ -88,13 +83,9 @@ def _do_predict(self, word_list):
# path_probability = previous_path_probability * transition_probability * state_observation_likelihood

# using log(probability) as the score to prevent numerical underflow
transition_probability = math.log(
self.A[i].get(state, self.very_small_probability)
)
transition_probability = self.A[i].get(state, self.very_small_probability)

state_observation_likelihood = math.log(
self.B[state].get(word, self.very_small_probability)
)
state_observation_likelihood = self.B[state].get(word, self.very_small_probability)

path_probability = previous_path_probability + transition_probability + state_observation_likelihood

@@ -138,9 +129,7 @@ def _do_predict(self, word_list):
# path_probability = previous_path_probability * transition_probability

# using log(probability) as the score to prevent numerical underflow
transition_probability = math.log(
self.A[i].get(self.end_state, self.very_small_probability)
)
transition_probability = self.A[i].get(self.end_state, self.very_small_probability)

path_probability = previous_path_probability + transition_probability

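Because A and B now hold log-probabilities, the out-of-vocabulary fallback has to live on the same scale, which is why __init__ applies math.log to very_small_probability once up front; mixing a raw-probability default into log-space sums would silently corrupt path scores. A small sketch of that invariant, with a hypothetical fallback value:

import math

raw_fallback = 1e-10                    # hypothetical value, not the repo's default
log_fallback = math.log(raw_fallback)   # converted once, as __init__ now does

A = {"<start>": {"B": math.log(0.5)}}

# Hit and miss both yield log-probabilities, so additions stay consistent:
known = A["<start>"].get("B", log_fallback)    # about -0.693
unknown = A["<start>"].get("E", log_fallback)  # about -23.026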
68 changes: 58 additions & 10 deletions makefile
@@ -1,14 +1,33 @@
.PHONY: build
build:
python ./setup.py sdist bdist_wheel
.PHONY: clean clean-test clean-pyc clean-build docs help
.DEFAULT_GOAL := help

.PHONY: upload_test
upload_test:
twine upload --repository-url https://test.pypi.org/legacy/ dist/*
define BROWSER_PYSCRIPT
import os, webbrowser, sys

.PHONY: upload
upload:
twine upload dist/*
try:
from urllib import pathname2url
except:
from urllib.request import pathname2url

webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1])))
endef
export BROWSER_PYSCRIPT

define PRINT_HELP_PYSCRIPT
import re, sys

for line in sys.stdin:
match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line)
if match:
target, help = match.groups()
print("%-20s %s" % (target, help))
endef
export PRINT_HELP_PYSCRIPT

BROWSER := python -c "$$BROWSER_PYSCRIPT"

help:
@python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)

clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts

@@ -31,6 +50,32 @@ clean-test: ## remove test and coverage artifacts
rm -fr htmlcov/
rm -fr .pytest_cache

lint: ## check style with flake8
flake8 MicroHMM tests

test: ## run tests quickly with the default Python
py.test

test-all: ## run tests on every Python version with tox
tox

coverage: ## check code coverage quickly with the default Python
coverage run --source MicroHMM -m pytest
coverage report -m
coverage html
$(BROWSER) htmlcov/index.html

docs: ## generate Sphinx HTML documentation, including API docs
rm -f docs/MicroHMM.rst
rm -f docs/modules.rst
sphinx-apidoc -o docs/ MicroHMM
$(MAKE) -C docs clean
$(MAKE) -C docs html
$(BROWSER) docs/_build/html/index.html

servedocs: docs ## compile the docs watching for changes
watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D .

release: dist ## package and upload a release
twine upload dist/*

@@ -39,6 +84,9 @@ dist: clean ## builds source and wheel package
python setup.py bdist_wheel
ls -l dist

install: clean ## install the package to the active Python's site-packages
python setup.py install

.PHONY: update_minor_version
update_minor_version:
punch --part minor
@@ -49,4 +97,4 @@ update_patch_version:

.PHONY: update_major_version
update_major_version:
punch --part major
punch --part major
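The new help target pipes the makefile through PRINT_HELP_PYSCRIPT, so every target annotated with a trailing ## comment is listed by make help; the doubled $$ in the regex is make's escape for a literal $. Run standalone, the same matching logic behaves like this (the sample line is one of the targets above):

import re

line = "test: ## run tests quickly with the default Python"
match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$', line)
if match:
    target, help_text = match.groups()
    print("%-20s %s" % (target, help_text))  # prints the aligned help entry for "test"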
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1 +1,3 @@
networkx==2.1
pathlib;python_version<"3.4"
pytest
5 changes: 4 additions & 1 deletion requirements_dev.txt
@@ -1 +1,4 @@
pytest==3.5.1
pytest==3.5.1

# for pytest helper
pytest-helpers-namespace
19 changes: 19 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,19 @@
# -*- coding: utf-8 -*-

pytest_plugins = ['helpers_namespace']

import pytest


@pytest.helpers.register
def train_test_cases():
basic_test_case = [('A', 'a'), ('B', 'b')]

test_cases = [
]

for i in range(1, 4):
test_cases.append([basic_test_case] * i)

# return a list of lists
return test_cases
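Restated as a single expression (equivalent to the loop above), each parametrized case is a list of one to three identical training lines:

basic_test_case = [('A', 'a'), ('B', 'b')]
cases = [[basic_test_case] * i for i in range(1, 4)]

assert cases[0] == [[('A', 'a'), ('B', 'b')]]  # one training line
assert len(cases[2]) == 3                      # three identical training lines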
21 changes: 21 additions & 0 deletions tests/test_hmm_train.py
@@ -0,0 +1,21 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest

from MicroHMM.hmm import HMMModel


@pytest.mark.parametrize("train_data", pytest.helpers.train_test_cases())
def test_hmm_train(train_data):
first_train_data = train_data[0]

hmm_model = HMMModel()

for single_train_data in train_data:
hmm_model.train_one_line(single_train_data)

hmm_model.do_train()

result = hmm_model.predict([i[0] for i in first_train_data])

assert first_train_data == result
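A minimal usage sketch mirroring the test: train on one line of (word, tag) pairs, precompute the model, and predict the same words back.

from MicroHMM.hmm import HMMModel

model = HMMModel()
model.train_one_line([("A", "a"), ("B", "b")])
model.do_train()                   # precomputes the log-probabilities
print(model.predict(["A", "B"]))   # expected: [('A', 'a'), ('B', 'b')]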
