#!/usr/bin/python3
# ocroline-pred -- executable script, forked from NVlabs/ocropus3-ocroline.
import os
import os.path
import argparse
from pylab import *
from torch import nn
from dlinputs import gopen, paths, filters
import ocroline
# Default every image plot to a grayscale colormap and turn on pylab's
# interactive mode (both names come from ``from pylab import *``).
rc("image", cmap="gray")
ion()

# Search path handed to ``paths.find_file`` when locating the model; the
# MODELS environment variable replaces the built-in default.
model_path = os.environ.get(
    "MODELS", ".:/usr/local/share/ocroline:/usr/share/ocroline")
default_model = "line2-000003330-004377.pt"

# The text is passed as ``description``: the first positional parameter of
# ArgumentParser is ``prog``, which would place the sentence in the usage
# line instead of describing the tool.
parser = argparse.ArgumentParser(
    description="recognize text lines with a trained line recognizer")
parser.add_argument("-m", "--model", default=default_model,
                    help="load model")
parser.add_argument("-b", "--batchsize", type=int, default=1,
                    help="batch size passed to the batching filter")
parser.add_argument("-D", "--makesource", default=None,
                    help="Python file to execute at startup "
                         "(e.g. to redefine make_source)")
parser.add_argument("-P", "--makepipeline", default=None,
                    help="Python file to execute at startup "
                         "(e.g. to redefine make_pipeline)")
parser.add_argument("-i", "--invert", action="store_true",
                    help="invert images after normalization (1 - image)")
parser.add_argument("--display", type=int, default=0,
                    help="show every N-th input batch; 0 disables display")
parser.add_argument("input",
                    help="input opened with gopen.open_source")
parser.add_argument("output", nargs="?",
                    help="optional output opened with gopen.open_sink")
args = parser.parse_args()

# Plain-dict copy of the parsed options.
ARGS = dict(vars(args))
def make_source():
    """Open the input named on the command line.

    Returns the result of ``gopen.open_source`` for ``args.input``.  A file
    passed with ``--makesource`` is executed at module level further down
    and may rebind this function.
    """
    location = args.input
    return gopen.open_source(location)
def normalize_line_image(image, invert=False):
    """Turn an image into a contrast-stretched array with a trailing axis.

    A 3-D input is averaged over its last axis; a singleton axis is then
    appended, so the result always has three dimensions.  Values are shifted
    so the minimum is 0 and scaled so the maximum is 1.

    Unlike in-place ``-=`` / ``/=``, this never modifies the caller's array,
    accepts integer images (in-place true division on an integer array
    raises a casting error), and leaves a constant image at all zeros
    instead of dividing 0 by 0.

    Args:
        image: array-like with 2 or 3 dimensions.
        invert: when true, return ``1 - image`` after scaling.

    Returns:
        A new floating-point array of shape ``image.shape[:2] + (1,)``.
        Floating inputs keep their dtype; other inputs become float64.
    """
    image = np.asarray(image)
    assert image.ndim in [2, 3], "expected a 2-D or 3-D image"
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float64)
    if image.ndim == 3:
        image = np.mean(image, 2)
    image = np.expand_dims(image, 2)
    # Out-of-place arithmetic: expand_dims returns a view, so in-place
    # updates would write through to the input array.
    image = image - np.amin(image)
    peak = np.amax(image)
    if peak > 0:
        image = image / peak
    if invert:
        image = 1 - image
    return image


def make_pipeline():
    """Build the preprocessing pipeline applied to the input source.

    The pipeline renames the image field to ``input``, normalizes it with
    ``normalize_line_image`` and groups samples into batches of
    ``args.batchsize``.  A file passed with ``--makepipeline`` is executed
    at module level further down and may rebind this function.

    Returns:
        The callable produced by ``filters.compose``.
    """
    def fixdepth(image):
        # args.invert is looked up on every call rather than captured once.
        return normalize_line_image(image, invert=args.invert)

    return filters.compose(
        # NOTE(review): presumably stores whichever of these fields is
        # present under "input" -- confirm against dlinputs.filters.rename.
        filters.rename(input="line.png png jpeg jpg"),
        filters.map(input=fixdepth),
        filters.batched(args.batchsize, combine_tensors=False))
# Optional user hooks: each file is compiled and executed in this module's
# namespace, so it can rebind make_source / make_pipeline defined above.
# SECURITY: this runs arbitrary Python from the given path -- only pass
# files you trust.
for _hook_path in (args.makesource, args.makepipeline):
    if _hook_path:
        # ``with`` closes the file handle once the code has been read.
        with open(_hook_path, "rb") as _hook_file:
            _hook_code = compile(_hook_file.read(), _hook_path, 'exec')
        exec(_hook_code)
def pixels_to_batch(x):
    """Flatten a 4-D tensor so that each spatial position becomes one row.

    The size of ``x`` is unpacked as (b, d, h, w); dimension 1 is moved to
    the end and the result is viewed as a 2-D tensor of shape (b*h*w, d).
    """
    batch, depth, height, width = x.size()
    depth_last = x.permute(0, 2, 3, 1)
    rows = batch * height * width
    return depth_last.contiguous().view(rows, depth)
class PixelsToBatch(nn.Module):
    """``nn.Module`` wrapper around ``pixels_to_batch``."""

    def forward(self, x):
        """Return ``pixels_to_batch(x)``."""
        flattened = pixels_to_batch(x)
        return flattened
# Resolve the model file first so that a missing model is reported before
# any input or output stream has been opened.
mname = paths.find_file(model_path, args.model)
if mname is None:
    # Explicit exit instead of ``assert``: assertions are removed under -O,
    # which would let a None path reach the recognizer.
    raise SystemExit(
        "model not found: {} (searched {})".format(args.model, model_path))

# Input stream with the preprocessing pipeline applied.
source = make_source()
pipeline = make_pipeline()
source = pipeline(source)

# ``sink`` only exists when an output argument was given.
if args.output:
    sink = gopen.open_sink(args.output)

print("loading", mname)
rec = ocroline.LineRecognizer(mname)
print(rec.model)
def display_batch(image, output):
    """Plot the first element of an input batch and of an output batch.

    The current figure is cleared; ``image`` is drawn in subplot 121 and
    ``output`` in subplot 122, each skipped when it is ``None``.  Both are
    indexed as ``[0, :, :, 0]``, so they must be 4-D
    (presumably batch, height, width, channel -- confirm with callers).
    """
    clf()
    for position, batch in ((121, image), (122, output)):
        if batch is None:
            continue
        subplot(position)
        imshow(batch[0, :, :, 0], vmin=0, vmax=1)
    draw()
    # A near-zero-timeout ginput call, kept from the original; presumably it
    # lets the GUI process events so the figure is actually repainted.
    ginput(1, 1e-3)
# Main loop: recognize every batch from the source and print the results,
# one per line, to standard output.
#
# NOTE(review): nothing is ever written to ``sink``; it is only opened and
# closed.  Confirm whether recognized text was meant to be stored there.
try:
    for i, sample in enumerate(source):
        image = sample["input"]
        output = rec.recognize_batch(image)
        if args.display > 0:
            if i % args.display == 0:
                # Show the first image of every N-th batch.  The indexing
                # assumes a 4-D array -- TODO confirm this holds with
                # combine_tensors=False in the batching filter.
                clf()
                imshow(image[0, :, :, 0], vmin=0, vmax=1)
                draw()
                ginput(1, 1e-3)
            # NOTE(review): placement at this nesting level is a
            # reconstruction; confirm it should run on every iteration
            # while display is enabled.
            waitforbuttonpress(0.0001)
        for line in output:
            print(line)
finally:
    # Close the sink even when recognition fails part-way through.
    if args.output:
        sink.close()