Commit: Making progress on fingerprinting

AG committed Mar 5, 2024
1 parent e122c99 commit 8b804ea
Showing 2 changed files with 129 additions and 76 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -3,6 +3,16 @@ A demonstration of predictive text without an LLM, using permy.link

[Check it out](https://adamgrant.info/tiny-predictive-text)

## Usage
In script.js, uncomment only the dictionary size you want to use. The larger the dictionary, the larger the file and the longer the load time.

```javascript
import { dictionary } from './dictionary-10K.js';
// import { dictionary } from './dictionary-25K.js';
// import { dictionary } from './dictionary-100K.js';
// import { dictionary } from './dictionary-250K.js';
```

## Training

No GPUs, OS requirements, or NVIDIA libraries needed. I run this on my MacBook Pro with the included version of Python.
195 changes: 119 additions & 76 deletions train-fingerprints.py
@@ -12,6 +12,7 @@
PRUNE_FREQUENCY = 1 * 1000 * 1000  # Prune after this many document positions
CHUNK_SIZE = 1024 # 1KB per chunk
TARGET_DICTIONARY_COUNT = 250 * 1000
CONTEXT_WORD_LENGTH = 10

# Define a flag to indicate when an interrupt has been caught
interrupted = False
@@ -32,19 +33,12 @@ def save_trie_store(trie_store):
pickle.dump(trie_store, f, protocol=pickle.HIGHEST_PROTOCOL)
print("trie_fingerprint_store saved due to interruption.")

DEFAULT_TRIE_STORE = {"fingerprints": {}, "scores": {}}
# Design:
#
# (Scores)
# {
#   "what_i_mean": 12, ... (Number of times we found this predictive phrase)
# }
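For orientation, here is a hypothetical snapshot of what the store might hold after a few documents (keys, counts, and values are made up):

```python
# Illustrative only; the slug, phrase, and numbers are invented.
trie_store = {
    "fingerprints": {
        "what_i_mean": {
            "phrase": "no matter how you look at it",  # first context phrase seen
            "vcr": 0.72,
            "wld": {"<=3": 4.0, "4": 2.0, "5": 3.0, "6": 1.0, ">=7": 0.0},
            "uwr": 0.98,
            "inst": 12,
        },
    },
    "scores": {"what_i_mean": 12},  # times the predictive phrase was filed
}
```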

def load_trie_store():
try:
@@ -57,49 +51,109 @@ def load_trie_store():
def _slugify(text):
return slugify(text, separator="_")

def vowel_to_consonant_ratio(phrase, existing_vcr, instances):
    vowels = "aeiouAEIOU"
    consonants = "bcdfghjklmnpqrstvwxyzBCDFGHJKLMNPQRSTVWXYZ"
    vowel_count = sum(1 for char in phrase if char in vowels)
    consonant_count = sum(1 for char in phrase if char in consonants)
    new_vcr = vowel_count / max(1, consonant_count)  # Avoid division by zero

    # Fold the new ratio into the running average over all instances
    if existing_vcr is not None:
        average_vcr = ((existing_vcr * (instances - 1)) + new_vcr) / instances
    else:
        average_vcr = new_vcr

    return average_vcr
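For intuition, a sketch of how the running average behaves across two sightings (the phrases and values are made-up examples):

```python
# First sighting: no stored value yet, so the fresh ratio is kept as-is.
vcr = vowel_to_consonant_ratio("see you then", None, 1)
# 5 vowels / 5 consonants ('y' counts as a consonant) -> 1.0

# Second sighting folds the new ratio into the stored average.
vcr = vowel_to_consonant_ratio("be right back", vcr, 2)
# new ratio 3 / 8 = 0.375; average ((1.0 * 1) + 0.375) / 2 = 0.6875
```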

def word_length_distribution(string, existing_wld, instances):
    words = string.split()
    # Initialize the new distribution with the specific length categories
    new_distribution = {'<=3': 0, '4': 0, '5': 0, '6': 0, '>=7': 0}

    # Count the words falling into each category
    for word in words:
        length = len(word)
        if length <= 3:
            new_distribution['<=3'] += 1
        elif length == 4:
            new_distribution['4'] += 1
        elif length == 5:
            new_distribution['5'] += 1
        elif length == 6:
            new_distribution['6'] += 1
        else:  # length >= 7
            new_distribution['>=7'] += 1

    # Fold the new counts into the running average over all instances
    if existing_wld is not None:
        for category, count in new_distribution.items():
            if category in existing_wld:
                existing_wld[category] = ((existing_wld[category] * (instances - 1)) + count) / instances
            else:
                existing_wld[category] = count / instances
        average_wld = existing_wld
    else:
        average_wld = {category: count / instances for category, count in new_distribution.items()}

    return average_wld
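A quick worked example of the length buckets (hypothetical phrase):

```python
dist = word_length_distribution("what do you mean by that", None, 1)
# 'what'(4) 'do'(2) 'you'(3) 'mean'(4) 'by'(2) 'that'(4)
# -> {'<=3': 3.0, '4': 3.0, '5': 0.0, '6': 0.0, '>=7': 0.0}
```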

def unique_word_ratio(string, existing_uwr, instances):
    words = string.split()
    unique_words = len(set(words))
    total_words = len(words)
    new_uwr = unique_words / max(1, total_words)  # Avoid division by zero

    # Fold the new ratio into the running average over all instances
    if existing_uwr is not None:
        average_uwr = ((existing_uwr * (instances - 1)) + new_uwr) / instances
    else:
        average_uwr = new_uwr

    return average_uwr
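And a worked example for the ratio (hypothetical phrase):

```python
uwr = unique_word_ratio("the cat and the hat", None, 1)
# 5 words, 4 unique ('the' repeats) -> 4 / 5 = 0.8
```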

# Load or initialize a fingerprint entry from the store
def load_trie(trie_store, predictive_slug):
    # Look up the entry for this predictive slug, defaulting to an empty dict
    return trie_store["fingerprints"].get(predictive_slug, {})

# Process the values for the trie
def process_trie(trie, actual_phrase):
    # Look at the incoming trie. If it has existing values we should process them
    # with the ones we compute on the incoming context words. (Avg over inst)
    # Design:
    # (Fingerprints)
    # {
    #   "what_i_mean": {
    #     "phrase": "what I mean", (Everything below is about the 10* words found before this)
    #     "vcr": 0.72, (Vowel to consonant ratio)
    #     "wld": {"<=3": 4, "4": 2, "5": 3, "6": 1, ">=7": 0}, (Word length distribution by length bucket)
    #     "uwr": 0.98, (Unique word ratio)
    #     "inst": 4 (Number of instances found)
    #   },
    # }
    #
    # *On average, an English sentence is 15-20 words long, so 13 (10+3) length is reasonable.
    _actual_phrase = trie.get("phrase", actual_phrase)
    instances = trie.get("inst", 0) + 1
    # Compute the statistics on the incoming phrase so each new sighting
    # actually moves the running averages (the stored phrase stays as first seen)
    trie.update({
        "phrase": _actual_phrase,
        "vcr": vowel_to_consonant_ratio(actual_phrase, trie.get("vcr", None), instances),
        "wld": word_length_distribution(actual_phrase, trie.get("wld", None), instances),
        "uwr": unique_word_ratio(actual_phrase, trie.get("uwr", None), instances),
        "inst": instances
    })
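A sketch of how an entry evolves over two sightings, assuming the functions above (the contexts are made up):

```python
entry = {}
process_trie(entry, "the cat sat on the mat")
# entry["inst"] == 1; stats seeded from the first context (uwr = 5/6 ≈ 0.833)

process_trie(entry, "a dog barked at the door")
# entry["inst"] == 2; uwr becomes ((0.833 * 1) + 1.0) / 2 ≈ 0.917
```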

def save_trie(trie_store, trie, predictive_slug):
    # Assign the entry back to the fingerprints store under its slug
    trie_store["fingerprints"][predictive_slug] = trie

def update_scores(trie_store, predictive_slug):
    if predictive_slug not in trie_store['scores']:
        trie_store['scores'][predictive_slug] = 0
    trie_store['scores'][predictive_slug] += 1

def save_position(progress_file, current_position, trie_store):
# Every now and then save our progress.
@@ -181,61 +235,50 @@ def main():
save_position(progress_file, current_position, trie_store)
prune_position_marker = current_position

# Process words with a context window of ten words at a time
for i in range(len(words) - (CONTEXT_WORD_LENGTH - 1)):
    context_words = words[i:i + CONTEXT_WORD_LENGTH]
    predictive_words = []

    # Determine predictive words, starting right after the context window, up to three words
    for j in range(i + CONTEXT_WORD_LENGTH, min(i + CONTEXT_WORD_LENGTH + 3, len(words))):
        word = words[j]
        # Punctuation allowed inside a word vs. punctuation that ends a phrase
        internal_punctuation = {"'", "-"}
        additional_punctuation = {"“", "”", "–", "—"}
        ending_punctuation = (set(string.punctuation) | additional_punctuation) - internal_punctuation

        # Strip ending punctuation from the word
        cleaned_word = ''.join(char for char in word if char not in ending_punctuation)

        # A word that carried ending punctuation closes the predictive phrase
        if cleaned_word != word or any(char in ending_punctuation for char in word):
            predictive_words.append(cleaned_word)
            break
        else:
            predictive_words.append(cleaned_word)

    if not predictive_words:
        continue

    finish_filing(trie_store, context_words, predictive_words)
flatten_to_dictionary(trie_store, TARGET_DICTIONARY_COUNT)
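As a concrete illustration of the windowing above (hypothetical word list):

```python
# With CONTEXT_WORD_LENGTH = 10, the first iteration (i = 0) takes the first
# ten words as context and scans the next three as predictive candidates,
# stopping early if one carries ending punctuation.
words = "so that is exactly what i was trying to say when you asked earlier".split()
context_words = words[0:10]            # 'so' ... 'say'
predictive_candidates = words[10:13]   # ['when', 'you', 'asked']
```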

def finish_filing(trie_store, context_words, predictive_words):
    # Slugify the context and predictive words
    context_slug = _slugify('_'.join(context_words))
    predictive_slug = _slugify('_'.join(predictive_words))
    actual_phrase = " ".join(context_words)

    # Get or create the dict entry for this predictive slug
    trie = load_trie(trie_store, predictive_slug)

    # With that entry, process the properties of the context words
    process_trie(trie, actual_phrase)

    # Save the updated entry back to the store
    save_trie(trie_store, trie, predictive_slug)

    # Update the occurrence count for this predictive slug
    update_scores(trie_store, predictive_slug)
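Putting the pieces together, a minimal sketch of one filing pass under the structures above (the input is made up):

```python
trie_store = {"fingerprints": {}, "scores": {}}
context = ["no", "matter", "how", "you", "look", "at", "it", "or", "what", "anyone"]
predictive = ["says", "about", "it"]

finish_filing(trie_store, context, predictive)
# trie_store["scores"]["says_about_it"] == 1
# trie_store["fingerprints"]["says_about_it"]["inst"] == 1
```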

# Check if the script is being run directly and call the main function
if __name__ == "__main__":
