Skip to content

Commit

Permalink
Use python script for generating list.eval and list.train.
Browse files Browse the repository at this point in the history
This decreases dependency on `bc` program and solves the problem with EOL on Windows.
  • Loading branch information
zdenop authored and stweil committed Mar 23, 2024
1 parent 9331445 commit b72e01c
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 16 deletions.
17 changes: 1 addition & 16 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -194,22 +194,7 @@ $(OUTPUT_DIR):

$(OUTPUT_DIR)/list.eval \
$(OUTPUT_DIR)/list.train: $(ALL_LSTMF) | $(OUTPUT_DIR)
@total=$$(wc -l < $(ALL_LSTMF)); \
train=$$(echo "$$total * $(RATIO_TRAIN) / 1" | bc); \
test "$$train" = "0" && \
echo "Error: missing ground truth for training" && exit 1; \
eval=$$(echo "$$total - $$train" | bc); \
test "$$eval" = "0" && \
echo "Error: missing ground truth for evaluation" && exit 1; \
set -x; \
head -n "$$train" $(ALL_LSTMF) > "$(OUTPUT_DIR)/list.train" && \
tail -n "$$eval" $(ALL_LSTMF) > "$(OUTPUT_DIR)/list.eval"
ifeq (Windows_NT, $(OS))
dos2unix "$(ALL_LSTMF)"
dos2unix "$(OUTPUT_DIR)/list.train"
dos2unix "$(OUTPUT_DIR)/list.eval"
endif

$(PY_CMD) generate_eval_train.py $(ALL_LSTMF) $(RATIO_TRAIN)

ifdef START_MODEL
$(DATA_DIR)/$(START_MODEL)/$(MODEL_NAME).lstm-unicharset:
Expand Down
39 changes: 39 additions & 0 deletions generate_eval_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pathlib
import sys


def split_file(input_file, ratio):
"""
Splits a text file into list.train and list.eval with lines ratio.
"""
if not isinstance(input_file, pathlib.Path):
input_file = pathlib.Path(input_file)
if not input_file.exists():
print(f"'{input_file}' not exists!")
return False
lines = input_file.read_text().splitlines()

split_point = int(ratio * len(lines))
output_dir = input_file.resolve().parent
train_list = pathlib.Path(output_dir, 'list.train')
eval_list = pathlib.Path(output_dir, 'list.eval')

with open(train_list, 'w', newline='\n') as f1, open(
eval_list, 'w', newline='\n'
) as f2:
f1.writelines(lines[:split_point])
f2.writelines(lines[split_point:])
return True


ratio = 0.95
input_file = None
if len(sys.argv) > 1:
input_file = sys.argv[1]
if len(sys.argv) > 2:
ratio = float(sys.argv[2])

split_file(input_file, ratio)

0 comments on commit b72e01c

Please sign in to comment.