Improved resumption
AG committed Sep 26, 2024
1 parent f6d1ebb commit 06781bc
Showing 2 changed files with 115 additions and 89 deletions.
Binary file modified backup/processing_progress.txt
204 changes: 115 additions & 89 deletions train.py
@@ -1,119 +1,145 @@
# Import necessary modules
import argparse  # Command-line argument parsing
import asyncio
import gc
import logging
import os
import pickle
import shutil
import signal
import sys

import datasets
from tqdm import tqdm

from lib.constants import PRUNE_FREQUENCY, TARGET_DICTIONARY_COUNT, TOTAL_WORD_COUNT
from lib.create_dictionary import create_batch
from lib.finish_filing import main as finish_filing
from lib.merge_batches import main as merge_batches
from lib.process_context_words import main as process_context_words
from lib.process_predictive_words import main as process_predictive_words

# Global flag set when an interrupt (SIGINT) has been caught, so the script can exit gracefully
interrupted = False

def signal_handler(sig, frame):
    global interrupted
    interrupted = True
    print("Graceful exit request received.")

# Register the signal handler for graceful exit
signal.signal(signal.SIGINT, signal_handler)

DEFAULT_TREE_STORE = {}

async def save_position(progress_file, current_position, word_count, tree_store):
    # Every now and then save our progress.
    print(f"Saving the current position of {current_position}")

    # Save the current progress (file position)
    with open(progress_file, 'w') as f:
        f.write(f"{current_position},{word_count}")

    print(f"Passed {PRUNE_FREQUENCY} positions. Time to optimize before continuing...")
    # TODO This was causing too many problems.
    # await create_batch(tree_store, TARGET_DICTIONARY_COUNT)
    return DEFAULT_TREE_STORE
DEFAULT_TREE_STORE = {}

async def load_progress(progress_file):
    """Try to load progress using the new method (state_dict) or fall back to the old method."""
    if os.path.exists(progress_file):
        try:
            # Try to load the progress as a pickle object (new method)
            with open(progress_file, 'rb') as f:
                state_dict, word_count = pickle.load(f)
            print(f"Loaded progress using state_dict with word count {word_count}")
            return state_dict, word_count
        except (pickle.UnpicklingError, EOFError):
            # Fall back to the old plain-text format if pickle loading fails
            with open(progress_file, 'r') as f:
                start_position_str, word_length_str = f.read().strip().split(',')
            start_position = int(start_position_str)
            word_count = int(word_length_str)
            print(f"Loaded progress using old method from position {start_position} with word count {word_count}")
            return start_position, word_count
    return None, 0  # No progress file found

async def save_position(progress_file, dataset, word_count, tree_store):
    """Always save progress using the new state_dict method."""
    # Always create the state_dict, even if resuming from an old format
    state_dict = dataset.state_dict()
    with open(progress_file, 'wb') as f:
        pickle.dump((state_dict, word_count), f)
    print(f"Saved state_dict and word count {word_count}")
    await create_batch(tree_store, TARGET_DICTIONARY_COUNT)

    return DEFAULT_TREE_STORE
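# Illustrative sketch of the resumable-streaming pattern used above, assuming a
# datasets release in which a streaming IterableDataset exposes state_dict() and
# load_state_dict(); the names ds and checkpoint are only for illustration:
#
#     ds = datasets.load_dataset('oscar-corpus/OSCAR-2201', language='en',
#                                split='train', streaming=True)
#     for i, row in enumerate(ds):
#         if i == 1000:
#             checkpoint = ds.state_dict()   # snapshot of the stream position
#             break
#     # ...after a restart, recreate the dataset the same way, then:
#     ds.load_state_dict(checkpoint)         # resume near where the snapshot was taken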

async def main(retain=False):
    tree_store = DEFAULT_TREE_STORE
    if not retain and os.path.exists('training'):
        shutil.rmtree('training')
        print("Previous training data cleared.")

    training_path = 'training'
    os.makedirs(training_path, exist_ok=True)

    # Load dataset from Hugging Face datasets
    datasets.logging.set_verbosity(datasets.logging.WARNING)
    logging.getLogger('fsspec').setLevel(logging.WARNING)
    logging.getLogger('urllib3').setLevel(logging.WARNING)
    dataset = datasets.load_dataset('oscar-corpus/OSCAR-2201', language='en', split='train', streaming=True, trust_remote_code=True)

    # Initialize start_position and word_count to 0
    start_position = 0
    word_count = 0

    # Check if the --retain flag is used and if the progress file exists
    if retain and os.path.exists('training/processing_progress.txt'):
        with open('training/processing_progress.txt', 'r') as f:
            # Read the line and split it by the comma to get both values
            start_position_str, word_length_str = f.read().strip().split(',')
        # Convert the string values to integers
        start_position = int(start_position_str)
        word_count = int(word_length_str)
        print(f"Resuming from position {start_position} with {word_count} total words processed.")

    pbar = tqdm(total=TOTAL_WORD_COUNT, unit='word', desc="Processing dataset", position=1)
    pbar.update(word_count)
    for i, entry in enumerate(dataset.skip(start_position)):
        if i + start_position < start_position:
            pbar.display(f"Skipping ahead from {i + start_position} to {start_position}", 1)
            continue  # Skip to the saved position
        text = entry['text']  # Extract text from dataset entry
        words = text.split()

        pbar.update(len(words))

        # Replace reserved characters as before
        words = [word.replace("score", "\\sscore") for word in words]
        words = [word.replace("prediction", "\\sprediction") for word in words]

        # Process words three at a time with shifting window
        for j in range(len(words) - 2):
            word_count += 1
            if interrupted:
                print("Script will terminate when done.")
                sys.exit(0)

            context_words = process_context_words(words, j)
            predictive_words = process_predictive_words(words, j)

            if not predictive_words:
                continue

            tree_store = finish_filing(tree_store, context_words, predictive_words)

            if (word_count + 1) % PRUNE_FREQUENCY == 0:
                # Save position and prune every PRUNE_FREQUENCY entries
                tree_store = await save_position('training/processing_progress.txt', i + start_position + 1, word_count, tree_store)
                gc.collect()

            if (word_count + 1) % (PRUNE_FREQUENCY * 100) == 0:
                await merge_batches()

    await create_batch(tree_store, TARGET_DICTIONARY_COUNT)
    tree_store = DEFAULT_TREE_STORE
    training_path = 'training'

    # Clear previous training data if not retaining
    if not retain and os.path.exists(training_path):
        shutil.rmtree(training_path)
        print("Previous training data cleared.")

    os.makedirs(training_path, exist_ok=True)

    # Load dataset from Hugging Face datasets
    datasets.logging.set_verbosity(datasets.logging.WARNING)
    logging.getLogger('fsspec').setLevel(logging.WARNING)
    logging.getLogger('urllib3').setLevel(logging.WARNING)
    dataset = datasets.load_dataset('oscar-corpus/OSCAR-2201', language='en', split='train', streaming=True, trust_remote_code=True)

    word_count = 0
    start_position = 0
    state_dict = None

    # Load previous progress (either old or new format)
    if retain:
        state_dict, word_count = await load_progress('training/processing_progress.txt')
        if isinstance(state_dict, dict):
            dataset.load_state_dict(state_dict)
        else:
            # Resume using the old method; skip to the saved position but save with state_dict from now on
            start_position = state_dict if isinstance(state_dict, int) else 0

    # Initialize progress bar
    pbar = tqdm(total=TOTAL_WORD_COUNT, unit='word', desc="Processing dataset", position=1)
    pbar.update(word_count)

    # Process the dataset
    for i, entry in enumerate(dataset.skip(start_position)):
        if interrupted:
            print("Script will terminate when done.")
            sys.exit(0)

        # Extract text and process words
        text = entry['text']
        words = text.split()

        # Update the progress bar with the number of words processed
        pbar.update(len(words))

        # Replace reserved characters
        words = [word.replace("score", "\\sscore") for word in words]
        words = [word.replace("prediction", "\\sprediction") for word in words]

        # Process words three at a time with a shifting window
        for j in range(len(words) - 2):
            word_count += 1

            # Get context and predictive words
            context_words = process_context_words(words, j)
            predictive_words = process_predictive_words(words, j)

            if not predictive_words:
                continue

            # File the words
            tree_store = finish_filing(tree_store, context_words, predictive_words)

            # Save position and prune periodically
            if (word_count + 1) % PRUNE_FREQUENCY == 0:
                tree_store = await save_position('training/processing_progress.txt', dataset, word_count, tree_store)
                gc.collect()

            # Merge batches periodically; silenced for now because it was creating too many problems.
            # if (word_count + 1) % (PRUNE_FREQUENCY * 100) == 0:
            #     await merge_batches()

    # Final batch creation after processing is complete
    await create_batch(tree_store, TARGET_DICTIONARY_COUNT)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Training script with position retain functionality.')
    parser.add_argument('--retain', action='store_true', help='Retain and resume from last saved position.')
    args = parser.parse_args()
    asyncio.run(main(retain=args.retain))

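For reference, a minimal usage note (assuming the script is invoked as python train.py with the lib/ helpers importable): a run without flags clears any previous training/ directory and starts fresh, while python train.py --retain resumes from training/processing_progress.txt, loading the pickled state_dict when present and falling back to the old position,word_count text format otherwise.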