diff --git a/python/src/lib.rs b/python/src/lib.rs
index 1c139dd..55a72ea 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -59,7 +59,7 @@ fn result_to_csr(py: Python, x: CsMat) -> PyResult {
 
 #[pyclass]
 pub struct _HashingVectorizerWrapper {
-    inner: vtext::vectorize::HashingVectorizer,
+    inner: vtext::vectorize::HashingVectorizer<vtext::tokenize::RegexpTokenizer>,
 }
 
 #[pymethods]
@@ -67,7 +67,9 @@ impl _HashingVectorizerWrapper {
     #[new]
     #[args(n_jobs = 1)]
     fn new(obj: &PyRawObject, n_jobs: usize) {
-        let estimator = vtext::vectorize::HashingVectorizer::new().n_jobs(n_jobs);
+        let tokenizer = vtext::tokenize::RegexpTokenizer::new("\\b\\w\\w+\\b".to_string());
+        let estimator = vtext::vectorize::HashingVectorizer::new(tokenizer).n_jobs(n_jobs);
+
         obj.init(_HashingVectorizerWrapper { inner: estimator });
     }
 
@@ -84,7 +86,7 @@ impl _HashingVectorizerWrapper {
 
 #[pyclass]
 pub struct _CountVectorizerWrapper {
-    inner: vtext::vectorize::CountVectorizer,
+    inner: vtext::vectorize::CountVectorizer<vtext::tokenize::RegexpTokenizer>,
 }
 
 #[pymethods]
@@ -92,7 +94,8 @@ impl _CountVectorizerWrapper {
     #[new]
     #[args(n_jobs = 1)]
     fn new(obj: &PyRawObject, n_jobs: usize) {
-        let estimator = vtext::vectorize::CountVectorizer::new().n_jobs(n_jobs);
+        let tokenizer = vtext::tokenize::RegexpTokenizer::new("\\b\\w\\w+\\b".to_string());
+        let estimator = vtext::vectorize::CountVectorizer::new(tokenizer).n_jobs(n_jobs);
 
         obj.init(_CountVectorizerWrapper { inner: estimator });
     }
diff --git a/src/tokenize/mod.rs b/src/tokenize/mod.rs
index a074331..fa570cf 100644
--- a/src/tokenize/mod.rs
+++ b/src/tokenize/mod.rs
@@ -63,6 +63,7 @@ pub trait Tokenizer: fmt::Debug {
 
 /// Regular expression tokenizer
 ///
+#[derive(Clone)]
 pub struct RegexpTokenizer {
     pub pattern: String,
     regexp: Regex,
@@ -98,7 +99,7 @@ impl fmt::Debug for RegexpTokenizer {
 /// ## References
 ///
 /// * [Unicode® Standard Annex #29](http://www.unicode.org/reports/tr29/)
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct UnicodeSegmentTokenizer {
     pub word_bounds: bool,
 }
@@ -135,7 +136,7 @@ impl Tokenizer for UnicodeSegmentTokenizer {
 /// ## References
 ///
 /// * [Unicode® Standard Annex #29](http://www.unicode.org/reports/tr29/)
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct VTextTokenizer {
     pub lang: String,
 }
@@ -270,7 +271,7 @@ impl Tokenizer for VTextTokenizer {
 }
 
 /// Character tokenizer
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct CharacterTokenizer {
     pub window_size: usize,
 }
diff --git a/src/vectorize/mod.rs b/src/vectorize/mod.rs
index 1f751bd..afa57fc 100644
--- a/src/vectorize/mod.rs
+++ b/src/vectorize/mod.rs
@@ -12,19 +12,22 @@ This module allows computing a sparse document term matrix from a text corpus.
 
 ```rust
 extern crate vtext;
+use vtext::tokenize::{VTextTokenizer,Tokenizer};
 use vtext::vectorize::CountVectorizer;
+
 let documents = vec![
     String::from("Some text input"),
     String::from("Another line"),
 ];
-let mut vectorizer = CountVectorizer::new();
+let tokenizer = VTextTokenizer::new("en");
+
+let mut vectorizer = CountVectorizer::new(tokenizer);
 let X = vectorizer.fit_transform(&documents);
 // returns a sparse CSR matrix with document-terms counts
 ```
 */
 
 use crate::math::CSRArray;
-use crate::tokenize;
 use crate::tokenize::Tokenizer;
 use hashbrown::{HashMap, HashSet};
 use itertools::sorted;
@@ -87,25 +90,25 @@ fn _sum_duplicates(tf: &mut CSRArray, indices_local: &[i32], nnz: &mut usize) {
 }
 
 #[derive(Debug)]
-pub struct CountVectorizer {
+pub struct CountVectorizer<T: Tokenizer> {
     lowercase: bool,
-    token_pattern: String,
-    _n_jobs: usize,
+    tokenizer: T,
     // vocabulary uses i32 indices, to avoid memory copies when converting
     // to sparse CSR arrays in Python with scipy.sparse
     pub vocabulary: HashMap<String, i32>,
+    _n_jobs: usize,
 }
 
 pub enum Vectorizer {}
 
-impl CountVectorizer {
+impl<T: Tokenizer> CountVectorizer<T> {
     /// Initialize a CountVectorizer estimator
-    pub fn new() -> Self {
+    pub fn new(tokenizer: T) -> Self {
         CountVectorizer {
             lowercase: true,
-            token_pattern: String::from(TOKEN_PATTERN_DEFAULT),
             vocabulary: HashMap::with_capacity_and_hasher(1000, Default::default()),
             _n_jobs: 1,
+            tokenizer,
         }
     }
 
@@ -124,14 +127,12 @@ impl CountVectorizer {
     ///
     /// This lists the vocabulary
    pub fn fit(&mut self, X: &[String]) -> () {
-        let tokenizer = tokenize::RegexpTokenizer::new(TOKEN_PATTERN_DEFAULT.to_string());
-
        let tokenize = |X: &[String]| -> HashSet<String> {
            let mut _vocab: HashSet<String> = HashSet::with_capacity(1000);
 
            for doc in X {
                let doc = doc.to_ascii_lowercase();
-                let tokens = tokenizer.tokenize(&doc);
+                let tokens = self.tokenizer.tokenize(&doc);
 
                for token in tokens {
                    if !_vocab.contains(token) {
@@ -176,14 +177,12 @@ impl CountVectorizer {
 
        let mut nnz: usize = 0;
 
-        let tokenizer = tokenize::RegexpTokenizer::new(TOKEN_PATTERN_DEFAULT.to_string());
-
        let tokenize_map = |doc: &str| -> Vec<i32> {
            // Closure to tokenize a document and returns hash indices for each token
            let mut indices_local: Vec<i32> = Vec::with_capacity(10);
 
-            for token in tokenizer.tokenize(doc) {
+            for token in self.tokenizer.tokenize(doc) {
                if let Some(_id) = self.vocabulary.get(token) {
                    indices_local.push(*_id)
                };
@@ -239,14 +238,12 @@ impl CountVectorizer {
        let mut nnz: usize = 0;
        let mut indices_local: Vec<i32> = Vec::new();
 
-        let tokenizer = tokenize::RegexpTokenizer::new(TOKEN_PATTERN_DEFAULT.to_string());
-
        let pipe = X.iter().map(|doc| doc.to_ascii_lowercase());
 
        let mut vocabulary_size: i32 = 0;
 
        for document in pipe {
-            let tokens = tokenizer.tokenize(&document);
+            let tokens = self.tokenizer.tokenize(&document);
 
            indices_local.clear();
@@ -277,23 +274,23 @@ impl CountVectorizer {
 }
 
 #[derive(Debug)]
-pub struct HashingVectorizer {
+pub struct HashingVectorizer<T: Tokenizer> {
     lowercase: bool,
-    token_pattern: String,
+    tokenizer: T,
     n_features: u64,
     _n_jobs: usize,
     thread_pool: Option<rayon::ThreadPool>,
 }
 
-impl HashingVectorizer {
+impl<T: Tokenizer> HashingVectorizer<T> {
     /// Create a new HashingVectorizer estimator
-    pub fn new() -> Self {
+    pub fn new(tokenizer: T) -> Self {
         HashingVectorizer {
             lowercase: true,
-            token_pattern: String::from(TOKEN_PATTERN_DEFAULT),
             n_features: 1048576,
             _n_jobs: 1,
             thread_pool: None,
+            tokenizer,
         }
     }
 
@@ -336,14 +333,12 @@ impl HashingVectorizer {
 
        let mut nnz: usize = 0;
 
-        let tokenizer = tokenize::RegexpTokenizer::new(TOKEN_PATTERN_DEFAULT.to_string());
-
        let tokenize_hash = |doc: &str| -> Vec<i32> {
            // Closure to tokenize a document and returns hash indices for each token
            let mut indices_local: Vec<i32> = Vec::with_capacity(10);
 
-            for token in tokenizer.tokenize(doc) {
+            for token in self.tokenizer.tokenize(doc) {
                // set the RNG seeds to get reproducible hashing
                let hash = seahash::hash_seeded(token.as_bytes(), 1, 1000, 200, 89);
                let hash = (hash % self.n_features) as i32;
diff --git a/src/vectorize/tests.rs b/src/vectorize/tests.rs
index 390e901..48ec33c 100644
--- a/src/vectorize/tests.rs
+++ b/src/vectorize/tests.rs
@@ -4,14 +4,17 @@
 // . This file may not be copied,
 // modified, or distributed except according to those terms.
 
+use crate::tokenize::*;
 use crate::vectorize::*;
 
 #[test]
 fn test_count_vectorizer_simple() {
     // Example 1
+    let tokenizer = RegexpTokenizer::new("\\b\\w+\\w\\b".to_string());
     let documents = vec!["cat dog cat".to_string()];
 
-    let mut vect = CountVectorizer::new();
+    let mut vect = CountVectorizer::new(tokenizer.clone());
+
     let X = vect.fit_transform(&documents);
     assert_eq!(X.to_dense(), array![[2, 1]]);
 
@@ -21,8 +24,7 @@ fn test_count_vectorizer_simple() {
         "The sky sky sky is blue".to_string(),
     ];
     let X_ref = array![[0, 1, 0, 1, 1, 2], [1, 0, 1, 0, 3, 1]];
-
-    let mut vect = CountVectorizer::new();
+    let mut vect = CountVectorizer::new(tokenizer);
     let X = vect.fit_transform(&documents);
 
     assert_eq!(X.to_dense().shape(), X_ref.shape());
@@ -38,7 +40,9 @@
 fn test_vectorize_empty_countvectorizer() {
     let documents = vec!["some tokens".to_string(), "".to_string()];
 
-    let mut vect = CountVectorizer::new();
+    let tokenizer = RegexpTokenizer::new("\\b\\w+\\w\\b".to_string());
+
+    let mut vect = CountVectorizer::new(tokenizer);
 
     vect.fit_transform(&documents);
     vect.fit(&documents);
@@ -48,8 +52,9 @@
 #[test]
 fn test_vectorize_empty_hashingvectorizer() {
     let documents = vec!["some tokens".to_string(), "".to_string()];
+    let tokenizer = RegexpTokenizer::new("\\b\\w+\\w\\b".to_string());
 
-    let vect = HashingVectorizer::new();
+    let vect = HashingVectorizer::new(tokenizer);
 
     vect.fit_transform(&documents);
     vect.transform(&documents);
@@ -57,12 +62,13 @@
 fn test_count_vectorizer_fit_transform() {
+    let tokenizer = RegexpTokenizer::new("\\b\\w+\\w\\b".to_string());
     for documents in &[vec!["cat dog cat".to_string()]] {
-        let mut vect = CountVectorizer::new();
+        let mut vect = CountVectorizer::new(tokenizer.clone());
         vect.fit(&documents);
         let X = vect.transform(&documents);
-        let mut vect2 = CountVectorizer::new();
+        let mut vect2 = CountVectorizer::new(tokenizer.clone());
         let X2 = vect2.fit_transform(&documents);
         assert_eq!(vect.vocabulary, vect2.vocabulary);
         println!("{:?}", vect.vocabulary);
@@ -87,7 +93,9 @@ fn test_hashing_vectorizer_simple() {
         String::from("The sky is blue"),
     ];
 
-    let vect = HashingVectorizer::new();
+    let tokenizer = VTextTokenizer::new("en");
+
+    let vect = HashingVectorizer::new(tokenizer);
     let vect = vect.fit(&documents);
     let X = vect.transform(&documents);
     assert_eq!(X.indptr(), &[0, 4, 8]);
@@ -116,17 +124,37 @@
 fn test_empty_dataset() {
     let documents: Vec<String> = vec![];
 
-    let mut vectorizer = CountVectorizer::new();
+    let tokenizer = VTextTokenizer::new("en");
+    let mut vectorizer = CountVectorizer::new(tokenizer.clone());
 
     let X = vectorizer.fit_transform(&documents);
     assert_eq!(X.data(), &[]);
     assert_eq!(X.indices(), &[]);
     assert_eq!(X.indptr(), &[0]);
 
-    let vectorizer = HashingVectorizer::new();
+    let vectorizer = HashingVectorizer::new(tokenizer);
 
     let X = vectorizer.fit_transform(&documents);
     assert_eq!(X.data(), &[]);
     assert_eq!(X.indices(), &[]);
     assert_eq!(X.indptr(), &[0]);
 }
+
+#[test]
+fn test_dispatch_tokenizer() {
+    let tokenizer = VTextTokenizer::new("en");
+    CountVectorizer::new(tokenizer.clone());
+    HashingVectorizer::new(tokenizer);
+
+    let tokenizer = UnicodeSegmentTokenizer::new(false);
+    CountVectorizer::new(tokenizer.clone());
+    HashingVectorizer::new(tokenizer);
+
+    let tokenizer = RegexpTokenizer::new("\\b\\w+\\w\\b".to_string());
+    CountVectorizer::new(tokenizer.clone());
+    HashingVectorizer::new(tokenizer);
+
+    let tokenizer = CharacterTokenizer::new(4);
+    CountVectorizer::new(tokenizer.clone());
+    HashingVectorizer::new(tokenizer);
+}
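
For reference, here is a minimal sketch of how the refactored API is used once this patch is applied. It only re-combines calls that appear in the updated module docs and tests above (`VTextTokenizer::new`, `CountVectorizer::new(tokenizer)`, `HashingVectorizer::new(tokenizer)`, `fit_transform`, `indptr`); the `main` wrapper and the final assertion are illustrative, not part of the patch.

```rust
extern crate vtext;

use vtext::tokenize::{Tokenizer, VTextTokenizer};
use vtext::vectorize::{CountVectorizer, HashingVectorizer};

fn main() {
    let documents = vec![
        String::from("Some text input"),
        String::from("Another line"),
    ];

    // Any tokenizer implementing the `Tokenizer` trait can be passed in;
    // the vectorizers no longer hard-code a RegexpTokenizer internally.
    let tokenizer = VTextTokenizer::new("en");

    // Count-based vectorization with the chosen tokenizer.
    let mut vectorizer = CountVectorizer::new(tokenizer.clone());
    let x = vectorizer.fit_transform(&documents);

    // The same tokenizer (cloneable thanks to the new derives) also works
    // with the hashing variant.
    let hasher = HashingVectorizer::new(tokenizer);
    let x_hashed = hasher.fit_transform(&documents);

    // Both return sparse CSR matrices: one indptr entry per document plus one.
    assert_eq!(x.indptr().len(), documents.len() + 1);
    assert_eq!(x_hashed.indptr().len(), documents.len() + 1);
}
```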