Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Voice commands using phrase embeddings #40

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions commands.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
[
{
"name": "restart",
"threshold": 0.5,
"phrases": [
{
"content": "Start over",
"embedding": [],
"similarity": 0
},
{
"content": "Restart the conversation",
"embedding": [],
"similarity": 0
},
{
"content": "From the top",
"embedding": [],
"similarity": 0
},
{
"content": "Wait, stop, go back",
"embedding": [],
"similarity": 0
},
{
"content": "No no no",
"embedding": [],
"similarity": 0
},
{
"content": "From the beginning",
"embedding": [],
"similarity": 0
},
{
"content": "Take it from the top",
"embedding": [],
"similarity": 0
}
]
}
]
31 changes: 31 additions & 0 deletions embed.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import fs from 'fs';
import util from 'util';
import { CommandEmbedding } from './src/talk';
import { llamaEmbed } from './src/depedenciesLibrary/llm';

const readFileAsync = util.promisify(fs.readFile);
const writeFileAsync = util.promisify(fs.writeFile);

const COMMAND_DEF_PATH = 'commands.json';
const EMBEDDINGS_PATH = 'embeddings.json';
const llamaServerUrl = 'http://127.0.0.1:8080';

// Read the command definitions from COMMAND_DEF_PATH, ask the llama server
// for an embedding of every phrase, and persist the enriched commands to
// EMBEDDINGS_PATH.
const embedCommands = async (): Promise<void> => {
  const commandData = await readFileAsync(COMMAND_DEF_PATH, 'utf8');
  const commands: CommandEmbedding[] = JSON.parse(commandData);
  // Each phrase is embedded independently, so issue all requests in
  // parallel instead of awaiting them one at a time in nested loops.
  await Promise.all(
    commands.flatMap((command) =>
      command.phrases.map(async (phrase) => {
        phrase.embedding = await llamaEmbed(llamaServerUrl, phrase.content);
      })
    )
  );
  const jsonStr = JSON.stringify(commands);

  try {
    await writeFileAsync(EMBEDDINGS_PATH, jsonStr, 'utf8');
    // Log success only when the write actually succeeded; previously this
    // message was printed even after a write error.
    console.log(`Embedding data written to ${EMBEDDINGS_PATH} successfully.`);
  } catch (err) {
    console.error('Error writing JSON to file:', err);
  }
}

// Surface read/parse/embedding failures instead of dying with an
// unhandled promise rejection.
embedCommands().catch((err) => {
  console.error('Failed to embed commands:', err);
  process.exitCode = 1;
});
1 change: 1 addition & 0 deletions embeddings.json

Large diffs are not rendered by default.

13 changes: 11 additions & 2 deletions index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ import { spawn } from 'child_process';
import readline from 'readline';
import config from './config.json';
const { whisperModelPath, audioListenerScript } = config;
import { talk } from './src/talk';
import embeddings from './embeddings.json';
import { talk, checkTranscriptionForCommand, CommandType, CommandEmbedding } from './src/talk';
const fs = require('fs');
const path = require('path');

Expand Down Expand Up @@ -189,12 +190,15 @@ const transcriptionEventHandler = async (event: AudioBytesEvent) => {

// TODO: Wait for 1s, because whisper bindings currently throw out if not enough audio passed in
// Therefore fix whisper
let transcription = '';
if (!transcriptionMutex && joinedBuffer.length > ONE_SECOND) {
transcriptionMutex = true;
globalWhisperPromise = whisper.whisperInferenceOnBytes(joinedBuffer);
const rawTranscription = await globalWhisperPromise;
// Remove transcription artifacts like (wind howling)
const transcription = rawTranscription.replace(/\s*\[[^\]]*\]\s*|\s*\([^)]*\)\s*/g, '');
transcription = rawTranscription.replace(/\s*\[[^\]]*\]\s*|\s*\([^)]*\)\s*/g, '');
// Trim starting whitespace
transcription = transcription.trimStart();
const transcriptionEvent: TranscriptionEvent = {
timestamp: Number(Date.now()),
eventType: 'transcription',
Expand All @@ -207,6 +211,11 @@ const transcriptionEventHandler = async (event: AudioBytesEvent) => {
newEventHandler(transcriptionEvent);
transcriptionMutex = false;
}
const command: CommandType = await checkTranscriptionForCommand(llamaServerUrl, embeddings as CommandEmbedding[], transcription);
if (command === 'restart') {
console.log('===== RESTART =====');
// TODO restart the conversation
}
}

const cutTranscriptionEventHandler = async (event: TranscriptionEvent) => {
Expand Down
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"license": "MIT",
"scripts": {
"start": "npx ts-node ./index.ts",
"embed": "npx ts-node ./embed.ts",
"test-voice": "npx ts-node ./tests/voice.test.ts"
},
"dependencies": {
Expand Down
11 changes: 11 additions & 0 deletions src/depedenciesLibrary/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,14 @@ export const llamaInvoke = (prompt: string, input: string, llamaServerUrl: strin
});
});
}

// Request an embedding vector for `content` from the llama server's
// /embedding endpoint and return it.
export const llamaEmbed = async (llamaServerUrl: string, content: string): Promise<number[]> => {
  const { data } = await axios.post(`${llamaServerUrl}/embedding`, { content });
  return data.embedding;
}
62 changes: 60 additions & 2 deletions src/talk.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
import { playAudioFile, generateAudio } from './depedenciesLibrary/voice'
import { llamaInvoke } from './depedenciesLibrary/llm';
import { llamaInvoke, llamaEmbed } from './depedenciesLibrary/llm';

// The set of voice commands the assistant recognizes; 'continue' means
// "no command matched, keep the conversation going".
export type CommandType = 'continue' | 'restart';

// One example utterance for a command, with its precomputed embedding and
// the similarity score assigned when matched against a transcription.
interface Phrase {
  content: string;
  embedding: number[];
  similarity: number;
}

// A command plus its trigger phrases; a transcription triggers the command
// when its best phrase similarity exceeds `threshold`.
export interface CommandEmbedding {
  name: CommandType;
  threshold: number;
  phrases: Phrase[];
}

// Talk: Greedily generate audio while completing an LLM inference
export const talk = async (prompt: string, input: string, llamaServerUrl: string, personaConfig:string, sentenceCallback: (sentence: string) => void): Promise<string> => {
Expand Down Expand Up @@ -28,4 +42,48 @@ export const talk = async (prompt: string, input: string, llamaServerUrl: string
await promisesChain;
return response;

}
}

const cosineSimilarity = (A: number[], B: number[]): number => {
if ((!A.length) || (!B.length) || (A.length !== B.length)) {
throw new Error('Invalid vectors');
}
let dotProduct = 0;
let magA = 0;
let magB = 0;

for (let i=0; i<A.length; i++) {
dotProduct += A[i] * B[i];
magA += A[i] * A[i];
magB += B[i] * B[i];
}
magA = Math.sqrt(magA);
magB = Math.sqrt(magB);

return dotProduct / (magA * magB);
Comment on lines +47 to +63
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh what? cosine similarity is just dot product operation? :^)

}

// Check the transcription for a match to the command embeddings.
// Embeds the transcription once, then returns the first command whose
// best phrase similarity exceeds that command's threshold; returns
// 'continue' when nothing matches or the transcription is empty.
export const checkTranscriptionForCommand = async (llamaServerUrl: string, commandEmbeddings: CommandEmbedding[], transcription: string): Promise<CommandType> => {
  if (transcription.length) {
    // Remove punctuation from the end
    transcription = transcription.replace(/[^\w\s]*$/, "");
    const transcriptionEmbedding = await llamaEmbed(llamaServerUrl, transcription);
    for (const command of commandEmbeddings) {
      // Score every phrase, recording the similarity on the phrase as
      // before, but track the best score directly instead of sorting the
      // caller's array in place (sorting reordered it as a side effect and
      // crashed on a command with no phrases when reading phrases[0]).
      let bestSimilarity = -Infinity;
      for (const phrase of command.phrases) {
        phrase.similarity = cosineSimilarity(phrase.embedding, transcriptionEmbedding);
        if (phrase.similarity > bestSimilarity) {
          bestSimilarity = phrase.similarity;
        }
      }
      if (bestSimilarity > command.threshold) {
        return command.name;
      }
    }
  }
  return 'continue';
}