Skip to content

Commit

Permalink
Load the DeepSpeech model once at startup instead of on every request
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremyandrews committed Jan 10, 2020
1 parent 0880bd8 commit e5247c9
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 73 deletions.
18 changes: 12 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ struct Configuration {
store: bool,
}

fn audio_to_text(config_data: web::Data<Mutex<Configuration>>, base64_audio: String) -> impl Responder {
fn audio_to_text(config_data: web::Data<Mutex<Configuration>>, deepspeech_data: web::Data<Mutex<speech::KakaiaDeepSpeech>>, base64_audio: String) -> impl Responder {
let config = config_data.lock().unwrap();
let mut kakaia_deepspeech = deepspeech_data.lock().unwrap();

// Load audio.bytes from String
let audio_bytes = match base64::decode(&base64_audio) {
Expand Down Expand Up @@ -73,7 +74,7 @@ fn audio_to_text(config_data: web::Data<Mutex<Configuration>>, base64_audio: Str
}

// Convert audio file to text.
let (message, extension) = speech::convert_audio_to_text(audio_file);
let converted: speech::AudioAsText = kakaia_deepspeech.convert_audio_to_text(audio_file);

// Optionally store a copy of the audio and text
if config.store {
Expand All @@ -84,7 +85,7 @@ fn audio_to_text(config_data: web::Data<Mutex<Configuration>>, base64_audio: Str
let hour = now.format("%H");
let minute = now.format("%M");
let second = now.format("%S");
let mut buffer = match std::fs::File::create(format!("{}/audio-{}-{}-{}.{}", archive_directory, &hour, &minute, &second, extension)) {
let mut buffer = match std::fs::File::create(format!("{}/audio-{}-{}-{}.{}", archive_directory, &hour, &minute, &second, converted.extension)) {
Ok(b) => b,
Err(e) => {
// @TODO: deal with this gracefully
Expand All @@ -111,7 +112,7 @@ fn audio_to_text(config_data: web::Data<Mutex<Configuration>>, base64_audio: Str
return "error archiving text conversion of audio file".to_string();
}
};
match writeln!(buffer, "{}", &message) {
match writeln!(buffer, "{}", &converted.text) {
Ok(_) => (),
Err(e) => eprintln!("failed to archive text conversion of audio file: {}", e),
}
Expand All @@ -123,20 +124,25 @@ fn audio_to_text(config_data: web::Data<Mutex<Configuration>>, base64_audio: Str
}

// Debug output for now:
println!("{}", &message);
println!("{}", &converted.text);
// Return text
format!("{}\n", message)
format!("{}\n", converted.text)
}

fn main() {
// @TODO: do we really need three copies of this?
// Configuration structure for server configuration
let config_server = Configuration::from_args();
// Configuration structure for client configuration
let config_web = config_server.clone();
// Configuration structure for client process
let config_data = web::Data::new(Mutex::new(config_server.clone()));
let deepspeech_data = web::Data::new(Mutex::new(speech::KakaiaDeepSpeech::new()));

let server = HttpServer::new(move || {
App::new()
.register_data(config_data.clone())
.register_data(deepspeech_data.clone())
.service(
web::resource("/convert/audio/text").data(
String::configure(|cfg| {
Expand Down
164 changes: 97 additions & 67 deletions src/speech.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::path::Path;
use std::env;
use std::path::Path;

use audrey::read::Reader;
use audrey::sample::interpolate::{Converter, Linear};
Expand All @@ -14,79 +14,109 @@ const VALID_WORD_COUNT_WEIGHT :f32 = 1.85;
// The model has been trained on this specific sample rate.
const SAMPLE_RATE :u32 = 16_000;

pub fn convert_audio_to_text(audio_file: std::fs::File) -> (String, String) {
// Read audio from temporary file.
let mut reader = match Reader::new(&audio_file) {
Ok(r) => r,
Err(e) => {
return (format!("failed to load audio file ({:?}): {}", audio_file, e), "unknown".to_string());
}
};
/// The result of a speech-to-text conversion: the recognized text plus
/// the file extension of the audio container it was decoded from (or
/// "unknown" when decoding failed).
#[derive(Debug, Clone, PartialEq)]
pub struct AudioAsText {
    pub text: String,
    pub extension: String,
}

let desc = reader.description();
// Validate the audio file.
let mut errors: Vec<String> = vec![];
if desc.channel_count() != 1 {
let error = format!("audio file must have exactly 1 track, not {}", desc.channel_count());
eprintln!("{}", &error);
errors.push(error);
}
if desc.sample_rate() != SAMPLE_RATE {
let error = format!("audio sample rate must be {}, not {}", SAMPLE_RATE, desc.sample_rate());
eprintln!("{}", &error);
errors.push(error);
}
if errors.len() > 0 {
return (format!("{:?}\n", errors), "unknown".to_string());
}
// Wraps the loaded DeepSpeech model so it can be created once at startup
// and shared across request handlers (behind a Mutex in main.rs).
pub struct KakaiaDeepSpeech {
    // The acoustic model; loading it is expensive, which is the point of
    // holding it here instead of re-loading per request.
    model: deepspeech::Model,
}
// SAFETY: NOTE(review) — this asserts the wrapped deepspeech::Model may be
// moved between threads. It appears sound only because all access goes
// through a Mutex; confirm the underlying native library has no
// thread-affinity requirements before relying on this.
unsafe impl Send for KakaiaDeepSpeech {}

impl KakaiaDeepSpeech {
pub fn new() -> Self {
const DEEPSPEECH_MODELS_ENV: &str = "DEEPSPEECH_MODELS";
let model_dir = match env::var(DEEPSPEECH_MODELS_ENV) {
Ok(d) => d,
Err(_) => {
let default_dir = env::current_dir().unwrap().join("models/");
eprintln!("DeepSpeechModel: {} isn't set, defaulting to {:?}", DEEPSPEECH_MODELS_ENV, default_dir);
default_dir.to_str().unwrap().to_string()
}
};

let deepspeech_models_env = "DEEPSPEECH_MODELS";
let deepspeech_models_dir = match env::var(deepspeech_models_env) {
Ok(v) => v,
Err(_) => {
let default_dir = env::current_dir().unwrap().join("models/");
eprintln!("{} isn't set, defaulting to {:?}", deepspeech_models_env, default_dir);
default_dir.to_str().unwrap().to_string()
// Run the speech to text algorithm
let dir_path = Path::new(&model_dir);
let mut deepspeech_model = match Model::load_from_files(&dir_path.join("output_graph.pb"), BEAM_WIDTH) {
Ok(m) => m,
Err(_) => {
eprintln!("FATAL ERROR, {:?} is an invalid models path", dir_path);
std::process::exit(1);
}
};
deepspeech_model.enable_decoder_with_lm(&dir_path.join("lm.binary"), &dir_path.join("trie"), LM_WEIGHT, VALID_WORD_COUNT_WEIGHT);

KakaiaDeepSpeech {
model: deepspeech_model,
}
};
}

pub fn convert_audio_to_text(&mut self, audio_file: std::fs::File) -> AudioAsText {
// Read audio from temporary file.
let mut reader = match Reader::new(&audio_file) {
Ok(r) => r,
Err(e) => {
return AudioAsText {
text: format!("failed to load audio file ({:?}): {}", audio_file, e),
extension: "unknown".to_string(),
}
}
};

let dir_path = Path::new(&deepspeech_models_dir);
let mut m = match Model::load_from_files(&dir_path.join("output_graph.pb"), BEAM_WIDTH) {
Ok(m) => m,
Err(_) => {
eprintln!("FATAL ERROR, {:?} is an invalid models path", dir_path);
std::process::exit(1);
let desc = reader.description();
// Validate the audio file.
let mut errors: Vec<String> = vec![];
if desc.channel_count() != 1 {
let error = format!("audio file must have exactly 1 track, not {}", desc.channel_count());
eprintln!("{}", &error);
errors.push(error);
}
};
m.enable_decoder_with_lm(&dir_path.join("lm.binary"), &dir_path.join("trie"), LM_WEIGHT, VALID_WORD_COUNT_WEIGHT);
if desc.sample_rate() != SAMPLE_RATE {
let error = format!("audio sample rate must be {}, not {}", SAMPLE_RATE, desc.sample_rate());
eprintln!("{}", &error);
errors.push(error);
}
if errors.len() > 0 {
return AudioAsText {
text: format!("{:?}", errors),
extension: "unknown".to_string(),
}
}

// Obtain the buffer of samples
let audio_buffer :Vec<_> = if desc.sample_rate() == SAMPLE_RATE {
reader.samples().map(|s| s.unwrap()).collect()
} else {
// We need to interpolate to the target sample rate
let interpolator = Linear::new([0i16], [0]);
let conv = Converter::from_hz_to_hz(
from_iter(reader.samples::<i16>().map(|s| [s.unwrap()])),
interpolator,
desc.sample_rate() as f64,
SAMPLE_RATE as f64);
conv.until_exhausted().map(|v| v[0]).collect()
};

// Obtain the buffer of samples
let audio_buf :Vec<_> = if desc.sample_rate() == SAMPLE_RATE {
reader.samples().map(|s| s.unwrap()).collect()
} else {
// We need to interpolate to the target sample rate
let interpolator = Linear::new([0i16], [0]);
let conv = Converter::from_hz_to_hz(
from_iter(reader.samples::<i16>().map(|s| [s.unwrap()])),
interpolator,
desc.sample_rate() as f64,
SAMPLE_RATE as f64);
conv.until_exhausted().map(|v| v[0]).collect()
};
let extension = match desc.format() {
audrey::Format::Flac => "flac".to_string(),
audrey::Format::OggVorbis => "ogg".to_string(),
audrey::Format::Wav => "wav".to_string(),
audrey::Format::CafAlac => "caf".to_string(),
};

let extension = match desc.format() {
audrey::Format::Flac => "flac".to_string(),
audrey::Format::OggVorbis => "ogg".to_string(),
audrey::Format::Wav => "wav".to_string(),
audrey::Format::CafAlac => "caf".to_string(),
};
let text = match self.model.speech_to_text(audio_buffer.as_slice()) {
Ok(t) => t,
Err(e) => {
// @TODO: handle this gracefully
eprintln!("Unexpected error converting audio to text: {}", e);
"Unexpected error: failed to convert audio to text".to_string()
}
};

// Run the speech to text algorithm
match m.speech_to_text(&audio_buf) {
Ok(text) => (text, extension),
Err(e) => {
eprintln!("error converting speech to text: {}", e);
("ERROR".to_string(), extension)
AudioAsText {
text: text,
extension: extension,
}
}
}
}

0 comments on commit e5247c9

Please sign in to comment.