forked from nxvipin/Ocrn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
grinder.py
137 lines (111 loc) · 3.73 KB
/
grinder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from ocrn import dataset as ds
from ocrn import feature as ft
from ocrn import neuralnet as nn
from PIL import Image
from pybrain.tools.xml.networkwriter import NetworkWriter
from pybrain.tools.xml.networkreader import NetworkReader
import numpy
import pickle
import os
import socket
import md5
import datetime
# Locate the Ocrn checkout directory (the one that contains the data/ tree).
hostname = socket.gethostname()
# True only on the machine whose hostname is "Plutonium".
joey = hostname == "Plutonium"
if joey:
    _default_ocrn_path = "/Users/josefdlange/Projects/Expresso/Ocrn/"
else:
    _default_ocrn_path = "/home/dg/Ocrn/"
# An OCRN_PATH environment variable, when set and non-empty, overrides the
# per-host default so the module can run outside those two machines.
# Joining with "" guarantees a trailing separator, which code that appends
# 'data/inputdata' directly to OCRN_PATH depends on.
OCRN_PATH = os.path.join(os.environ.get("OCRN_PATH") or _default_ocrn_path, "")
class Grinder:
    """
    Class to encapsulate the text-recognition neural network, based off swvist's implementation.

    More details at http://github.com/expresso-math/Ocrn

    Important member variables:
    -- neural_network: The neural network. ocrn.neuralnet type.
    -- data_set: The data set. ocrn.dataset type
    """

    def __init__(self):
        """
        Create a new Grinder object with a fresh network and data set.
        """
        self.reset()

    def load_dataset(self, file_path=None):
        """
        Load the dataset from file and hand its training data to the network.

        file_path -- value passed to data_set.generateDataSetFromFile; when
                     omitted, a one-element list holding
                     OCRN_PATH + 'data/inputdata' is used.

        Progress and failures are reported on stdout; nothing is returned.
        """
        if file_path is None:
            # Built per call so the default list is never shared between calls.
            file_path = [OCRN_PATH + 'data/inputdata']
        if not self.data_set.generateDataSetFromFile(file_path):
            # Couldn't load data set.
            print("Something is really broken, since this method always returns 1.")
            return
        print("Successfully loaded dataset from file.")
        if self.neural_network.loadTrainingData(self.data_set.getTrainingDataset()):
            print("Successfully loaded Training Data from data_set.")
        else:
            # Couldn't load training data from data set.
            print("There was an error loading the training data into the neural network.")

    def train_to_convergence(self):
        """
        Delegate to neural_network.teachUntilConvergence().
        """
        self.neural_network.teachUntilConvergence()

    def train(self, maxEpochs=100):
        """
        Delegate to neural_network.teach(maxEpochs); maxEpochs defaults to 100.
        """
        self.neural_network.teach(maxEpochs)

    def train_loop(self, n):
        """
        Delegate to neural_network.teach(n); like train() but n is required.
        """
        self.neural_network.teach(n)

    def guess(self, image_file_path):
        """
        Run the network on the image stored at image_file_path.

        Returns str(unichr(result)) for the network's activation result, with
        any result that is not >= 0 replaced by 0. Returns None when
        image_file_path is empty or None.
        """
        if not image_file_path:
            # The original fell through to an unbound `result` here and
            # crashed with UnboundLocalError.
            return None
        feature_vector = ft.feature.getImageFeatureVector(image_file_path)
        result = self.neural_network.activate(feature_vector)
        if not result >= 0:
            result = 0
        return str(unichr(result))

    def guess_on_image(self, image):
        """
        Convert an image to a 1-bit BMP under data/testdata and guess on it.

        image -- whatever PIL's Image.open accepts.

        The BMP is named after the MD5 hex digest of the current timestamp
        and is left on disk. The guess is printed and returned; None is
        returned (and nothing printed) when image is empty or None.
        """
        if not image:
            # Same unbound-`result` crash as guess() in the original.
            return None
        image = Image.open(image)
        stamp = md5.new(str(datetime.datetime.now())).hexdigest()
        pathname = OCRN_PATH + "/data/testdata/" + stamp + ".bmp"
        image.convert("1").save(pathname, "BMP")
        result = self.guess(pathname)
        print(result)
        return result

    def generateDataSetFromRoaster(self, dataTuple):
        """
        Takes a tuple of (imageData, asciiVal), saves every image in
        imageData as a 1-bit BMP under data/trainingdata/ and appends one
        `path:asciiVal` line per image to data/inputdata.
        """
        imageData, asciiVal = dataTuple
        # Read the count once up front instead of once per image.
        trainCount = self.getTrainingCount()
        # `with` closes the file even if converting or saving an image raises.
        with open(OCRN_PATH + "/data/inputdata", "a") as datafile:
            for index, image in enumerate(imageData, trainCount):
                pathname = OCRN_PATH + "/data/trainingdata/" + str(index) + ".bmp"
                image.convert("1").save(pathname, "BMP")
                datafile.write(pathname + ":" + str(asciiVal) + "\n")

    def getTrainingCount(self):
        """
        Gets the number of entries (lines) in the data/inputdata file.

        Returns 0 when the file cannot be opened. The file is read directly
        rather than through `wc -l`, so no shell or external tool is needed;
        a final line without a trailing newline is counted as well.
        """
        # Will need to change this to relative path later.
        try:
            with open(OCRN_PATH + "/data/inputdata") as datafile:
                return sum(1 for _ in datafile)
        except IOError:
            return 0

    def reset(self):
        """
        Replace the network and the data set with freshly constructed ones.
        """
        self.neural_network = nn.neuralnet(100, 80, 1)
        self.data_set = ds.dataset(100, 1)
def main():
    """
    Build a Grinder and dump it, its network and its data set to stdout.
    """
    grinder = Grinder()
    print(grinder)
    network = grinder.neural_network
    print(network)
    print(network.nnet)
    dataset = grinder.data_set
    print(dataset)
    print(dataset.DS)


# Run the dump only when executed as a script, not when imported.
if __name__ == "__main__":
    main()