diff --git a/Makefile b/Makefile index 57dff69e..858c5b57 100644 --- a/Makefile +++ b/Makefile @@ -194,22 +194,7 @@ $(OUTPUT_DIR): $(OUTPUT_DIR)/list.eval \ $(OUTPUT_DIR)/list.train: $(ALL_LSTMF) | $(OUTPUT_DIR) - @total=$$(wc -l < $(ALL_LSTMF)); \ - train=$$(echo "$$total * $(RATIO_TRAIN) / 1" | bc); \ - test "$$train" = "0" && \ - echo "Error: missing ground truth for training" && exit 1; \ - eval=$$(echo "$$total - $$train" | bc); \ - test "$$eval" = "0" && \ - echo "Error: missing ground truth for evaluation" && exit 1; \ - set -x; \ - head -n "$$train" $(ALL_LSTMF) > "$(OUTPUT_DIR)/list.train" && \ - tail -n "$$eval" $(ALL_LSTMF) > "$(OUTPUT_DIR)/list.eval" -ifeq (Windows_NT, $(OS)) - dos2unix "$(ALL_LSTMF)" - dos2unix "$(OUTPUT_DIR)/list.train" - dos2unix "$(OUTPUT_DIR)/list.eval" -endif - + $(PY_CMD) generate_eval_train.py $(ALL_LSTMF) $(RATIO_TRAIN) ifdef START_MODEL $(DATA_DIR)/$(START_MODEL)/$(MODEL_NAME).lstm-unicharset: diff --git a/generate_eval_train.py b/generate_eval_train.py new file mode 100644 index 00000000..ffb9330d --- /dev/null +++ b/generate_eval_train.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import pathlib +import sys + + +def split_file(input_file, ratio): + """ + Splits a text file into list.train and list.eval with lines ratio. + """ + if not isinstance(input_file, pathlib.Path): + input_file = pathlib.Path(input_file) + if not input_file.exists(): + print(f"'{input_file}' not exists!") + return False + lines = input_file.read_text().splitlines() + + split_point = int(ratio * len(lines)) + output_dir = input_file.resolve().parent + train_list = pathlib.Path(output_dir, 'list.train') + eval_list = pathlib.Path(output_dir, 'list.eval') + + with open(train_list, 'w', newline='\n') as f1, open( + eval_list, 'w', newline='\n' + ) as f2: + f1.writelines(lines[:split_point]) + f2.writelines(lines[split_point:]) + return True + + +ratio = 0.95 +input_file = None +if len(sys.argv) > 1: + input_file = sys.argv[1] +if len(sys.argv) > 2: + ratio = float(sys.argv[2]) + +split_file(input_file, ratio)