Tokenizers dispatch in vectorizers #53

Merged
merged 16 commits
Jun 7, 2019
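
For orientation, a minimal sketch of the API after this PR, assuming the crate paths shown in the diff below: the vectorizers are now generic over the tokenizer type and take a tokenizer instance at construction instead of building a `RegexpTokenizer` internally. The example documents and printed output are illustrative only.

```rust
extern crate vtext;

use vtext::tokenize::{RegexpTokenizer, VTextTokenizer};
use vtext::vectorize::{CountVectorizer, HashingVectorizer};

fn main() {
    let documents = vec![
        String::from("The Moon is an astronomical body"),
        String::from("The sky is blue"),
    ];

    // Any tokenizer satisfying the `Tokenizer + Sync` bound can be passed in.
    let tokenizer = VTextTokenizer::new("en");
    let mut count_vect = CountVectorizer::new(tokenizer);
    let x_counts = count_vect.fit_transform(&documents);

    // The same constructor pattern works for HashingVectorizer.
    let tokenizer = RegexpTokenizer::new("\\b\\w\\w+\\b".to_string());
    let hash_vect = HashingVectorizer::new(tokenizer);
    let x_hashed = hash_vect.fit_transform(&documents);

    // Both return sparse CSR matrices of document-term counts.
    println!("{:?} {:?}", x_counts.indptr(), x_hashed.indptr());
}
```

The Python wrappers below keep the previous behaviour by hard-coding a `RegexpTokenizer` with the default token pattern.
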
11 changes: 7 additions & 4 deletions python/src/lib.rs
@@ -59,15 +59,17 @@ fn result_to_csr(py: Python, x: CsMat<i32>) -> PyResult<PyCsrArray> {

#[pyclass]
pub struct _HashingVectorizerWrapper {
inner: vtext::vectorize::HashingVectorizer,
inner: vtext::vectorize::HashingVectorizer<vtext::tokenize::RegexpTokenizer>,
}

#[pymethods]
impl _HashingVectorizerWrapper {
#[new]
#[args(n_jobs = 1)]
fn new(obj: &PyRawObject, n_jobs: usize) {
let estimator = vtext::vectorize::HashingVectorizer::new().n_jobs(n_jobs);
let tokenizer = vtext::tokenize::RegexpTokenizer::new("\\b\\w\\w+\\b".to_string());
let estimator = vtext::vectorize::HashingVectorizer::new(tokenizer).n_jobs(n_jobs);

obj.init(_HashingVectorizerWrapper { inner: estimator });
}

@@ -84,15 +86,16 @@ impl _HashingVectorizerWrapper {

#[pyclass]
pub struct _CountVectorizerWrapper {
inner: vtext::vectorize::CountVectorizer,
inner: vtext::vectorize::CountVectorizer<vtext::tokenize::RegexpTokenizer>,
}

#[pymethods]
impl _CountVectorizerWrapper {
#[new]
#[args(n_jobs = 1)]
fn new(obj: &PyRawObject, n_jobs: usize) {
let estimator = vtext::vectorize::CountVectorizer::new().n_jobs(n_jobs);
let tokenizer = vtext::tokenize::RegexpTokenizer::new("\\b\\w\\w+\\b".to_string());
let estimator = vtext::vectorize::CountVectorizer::new(tokenizer).n_jobs(n_jobs);
obj.init(_CountVectorizerWrapper { inner: estimator });
}

7 changes: 4 additions & 3 deletions src/tokenize/mod.rs
@@ -63,6 +63,7 @@ pub trait Tokenizer: fmt::Debug {

/// Regular expression tokenizer
///
#[derive(Clone)]
pub struct RegexpTokenizer {
pub pattern: String,
regexp: Regex,
@@ -98,7 +99,7 @@ impl fmt::Debug for RegexpTokenizer {
/// ## References
///
/// * [Unicode® Standard Annex #29](http://www.unicode.org/reports/tr29/)
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct UnicodeSegmentTokenizer {
pub word_bounds: bool,
}
@@ -135,7 +136,7 @@ impl Tokenizer for UnicodeSegmentTokenizer {
/// ## References
///
/// * [Unicode® Standard Annex #29](http://www.unicode.org/reports/tr29/)
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct VTextTokenizer {
pub lang: String,
}
@@ -270,7 +271,7 @@ impl Tokenizer for VTextTokenizer {
}

/// Character tokenizer
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct CharacterTokenizer {
pub window_size: usize,
}
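
The `Clone` derives added above let a single tokenizer be configured once and reused across several vectorizers, which is how the new dispatch test at the bottom of this diff exercises each tokenizer. A brief sketch of that usage, assuming the same crate paths as in the tests:

```rust
use vtext::tokenize::RegexpTokenizer;
use vtext::vectorize::{CountVectorizer, HashingVectorizer};

fn main() {
    // One tokenizer, configured once, handed to both vectorizers.
    let tokenizer = RegexpTokenizer::new("\\b\\w\\w+\\b".to_string());

    let count_vect = CountVectorizer::new(tokenizer.clone());
    let hash_vect = HashingVectorizer::new(tokenizer);

    // Suppress unused-variable warnings in this standalone sketch.
    let _ = (count_vect, hash_vect);
}
```
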
45 changes: 20 additions & 25 deletions src/vectorize/mod.rs
@@ -12,19 +12,22 @@ This module allows computing a sparse document term matrix from a text corpus.
```rust
extern crate vtext;

use vtext::tokenize::{VTextTokenizer,Tokenizer};
use vtext::vectorize::CountVectorizer;

let documents = vec![
String::from("Some text input"),
String::from("Another line"),
];

let mut vectorizer = CountVectorizer::new();
let tokenizer = VTextTokenizer::new("en");

let mut vectorizer = CountVectorizer::new(tokenizer);
let X = vectorizer.fit_transform(&documents);
// returns a sparse CSR matrix with document-term counts
*/

use crate::math::CSRArray;
use crate::tokenize;
use crate::tokenize::Tokenizer;
use hashbrown::{HashMap, HashSet};
use itertools::sorted;
@@ -87,25 +90,25 @@ fn _sum_duplicates(tf: &mut CSRArray, indices_local: &[i32], nnz: &mut usize) {
}

#[derive(Debug)]
pub struct CountVectorizer {
pub struct CountVectorizer<T> {
lowercase: bool,
token_pattern: String,
_n_jobs: usize,
tokenizer: T,
// vocabulary uses i32 indices, to avoid memory copies when converting
// to sparse CSR arrays in Python with scipy.sparse
pub vocabulary: HashMap<String, i32>,
_n_jobs: usize,
}

pub enum Vectorizer {}

impl CountVectorizer {
impl<T: Tokenizer + Sync> CountVectorizer<T> {
/// Initialize a CountVectorizer estimator
pub fn new() -> Self {
pub fn new(tokenizer: T) -> Self {
CountVectorizer {
lowercase: true,
token_pattern: String::from(TOKEN_PATTERN_DEFAULT),
vocabulary: HashMap::with_capacity_and_hasher(1000, Default::default()),
_n_jobs: 1,
tokenizer,
}
}

@@ -124,14 +127,12 @@ impl CountVectorizer {
///
/// This lists the vocabulary
pub fn fit(&mut self, X: &[String]) -> () {
let tokenizer = tokenize::RegexpTokenizer::new(TOKEN_PATTERN_DEFAULT.to_string());

let tokenize = |X: &[String]| -> HashSet<String> {
let mut _vocab: HashSet<String> = HashSet::with_capacity(1000);

for doc in X {
let doc = doc.to_ascii_lowercase();
let tokens = tokenizer.tokenize(&doc);
let tokens = self.tokenizer.tokenize(&doc);

for token in tokens {
if !_vocab.contains(token) {
@@ -176,14 +177,12 @@ impl CountVectorizer {

let mut nnz: usize = 0;

let tokenizer = tokenize::RegexpTokenizer::new(TOKEN_PATTERN_DEFAULT.to_string());

let tokenize_map = |doc: &str| -> Vec<i32> {
// Closure that tokenizes a document and returns vocabulary indices for each token

let mut indices_local: Vec<i32> = Vec::with_capacity(10);

for token in tokenizer.tokenize(doc) {
for token in self.tokenizer.tokenize(doc) {
if let Some(_id) = self.vocabulary.get(token) {
indices_local.push(*_id)
};
@@ -239,14 +238,12 @@ impl CountVectorizer {
let mut nnz: usize = 0;
let mut indices_local: Vec<i32> = Vec::new();

let tokenizer = tokenize::RegexpTokenizer::new(TOKEN_PATTERN_DEFAULT.to_string());

let pipe = X.iter().map(|doc| doc.to_ascii_lowercase());

let mut vocabulary_size: i32 = 0;

for document in pipe {
let tokens = tokenizer.tokenize(&document);
let tokens = self.tokenizer.tokenize(&document);

indices_local.clear();

@@ -277,23 +274,23 @@ impl CountVectorizer {
}

#[derive(Debug)]
pub struct HashingVectorizer {
pub struct HashingVectorizer<T> {
lowercase: bool,
token_pattern: String,
tokenizer: T,
n_features: u64,
_n_jobs: usize,
thread_pool: Option<rayon::ThreadPool>,
}

impl HashingVectorizer {
impl<T: Tokenizer + Sync> HashingVectorizer<T> {
/// Create a new HashingVectorizer estimator
pub fn new() -> Self {
pub fn new(tokenizer: T) -> Self {
HashingVectorizer {
lowercase: true,
token_pattern: String::from(TOKEN_PATTERN_DEFAULT),
n_features: 1048576,
_n_jobs: 1,
thread_pool: None,
tokenizer,
}
}

@@ -336,14 +333,12 @@ impl HashingVectorizer {

let mut nnz: usize = 0;

let tokenizer = tokenize::RegexpTokenizer::new(TOKEN_PATTERN_DEFAULT.to_string());

let tokenize_hash = |doc: &str| -> Vec<i32> {
// Closure that tokenizes a document and returns hash indices for each token

let mut indices_local: Vec<i32> = Vec::with_capacity(10);

for token in tokenizer.tokenize(doc) {
for token in self.tokenizer.tokenize(doc) {
// set the RNG seeds to get reproducible hashing
let hash = seahash::hash_seeded(token.as_bytes(), 1, 1000, 200, 89);
let hash = (hash % self.n_features) as i32;
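
For reference, the hashing step above maps each token to a fixed column index so the `HashingVectorizer` never stores a vocabulary. A standalone sketch of that mapping, reusing the seeds and the default `n_features` of 1,048,576 from the diff (the helper name is hypothetical):

```rust
extern crate seahash;

/// Map a token to a column index in a fixed-size feature space.
/// The seeds mirror the ones hard-coded in the transform above.
fn token_to_column(token: &str, n_features: u64) -> i32 {
    let hash = seahash::hash_seeded(token.as_bytes(), 1, 1000, 200, 89);
    (hash % n_features) as i32
}

fn main() {
    // The same token always lands in the same column, regardless of the corpus.
    let col = token_to_column("moon", 1_048_576);
    println!("column index for \"moon\": {}", col);
}
```
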
48 changes: 38 additions & 10 deletions src/vectorize/tests.rs
@@ -4,14 +4,17 @@
// <http://apache.org/licenses/LICENSE-2.0>. This file may not be copied,
// modified, or distributed except according to those terms.

use crate::tokenize::*;
use crate::vectorize::*;

#[test]
fn test_count_vectorizer_simple() {
// Example 1
let tokenizer = RegexpTokenizer::new("\\b\\w+\\w\\b".to_string());

let documents = vec!["cat dog cat".to_string()];
let mut vect = CountVectorizer::new();
let mut vect = CountVectorizer::new(tokenizer.clone());

let X = vect.fit_transform(&documents);
assert_eq!(X.to_dense(), array![[2, 1]]);

@@ -21,8 +24,7 @@ fn test_count_vectorizer_simple() {
"The sky sky sky is blue".to_string(),
];
let X_ref = array![[0, 1, 0, 1, 1, 2], [1, 0, 1, 0, 3, 1]];

let mut vect = CountVectorizer::new();
let mut vect = CountVectorizer::new(tokenizer);

let X = vect.fit_transform(&documents);
assert_eq!(X.to_dense().shape(), X_ref.shape());
@@ -38,7 +40,9 @@
fn test_vectorize_empty_countvectorizer() {
let documents = vec!["some tokens".to_string(), "".to_string()];

let mut vect = CountVectorizer::new();
let tokenizer = RegexpTokenizer::new("\\b\\w+\\w\\b".to_string());

let mut vect = CountVectorizer::new(tokenizer);
vect.fit_transform(&documents);

vect.fit(&documents);
@@ -48,21 +52,23 @@
#[test]
fn test_vectorize_empty_hashingvectorizer() {
let documents = vec!["some tokens".to_string(), "".to_string()];
let tokenizer = RegexpTokenizer::new("\\b\\w+\\w\\b".to_string());

let vect = HashingVectorizer::new();
let vect = HashingVectorizer::new(tokenizer);
vect.fit_transform(&documents);

vect.transform(&documents);
}

#[test]
fn test_count_vectorizer_fit_transform() {
let tokenizer = RegexpTokenizer::new("\\b\\w+\\w\\b".to_string());
for documents in &[vec!["cat dog cat".to_string()]] {
let mut vect = CountVectorizer::new();
let mut vect = CountVectorizer::new(tokenizer.clone());
vect.fit(&documents);
let X = vect.transform(&documents);

let mut vect2 = CountVectorizer::new();
let mut vect2 = CountVectorizer::new(tokenizer.clone());
let X2 = vect2.fit_transform(&documents);
assert_eq!(vect.vocabulary, vect2.vocabulary);
println!("{:?}", vect.vocabulary);
@@ -87,7 +93,9 @@ fn test_hashing_vectorizer_simple() {
String::from("The sky is blue"),
];

let vect = HashingVectorizer::new();
let tokenizer = VTextTokenizer::new("en");

let vect = HashingVectorizer::new(tokenizer);
let vect = vect.fit(&documents);
let X = vect.transform(&documents);
assert_eq!(X.indptr(), &[0, 4, 8]);
@@ -116,17 +124,37 @@
fn test_empty_dataset() {
let documents: Vec<String> = vec![];

let mut vectorizer = CountVectorizer::new();
let tokenizer = VTextTokenizer::new("en");
let mut vectorizer = CountVectorizer::new(tokenizer.clone());

let X = vectorizer.fit_transform(&documents);
assert_eq!(X.data(), &[]);
assert_eq!(X.indices(), &[]);
assert_eq!(X.indptr(), &[0]);

let vectorizer = HashingVectorizer::new();
let vectorizer = HashingVectorizer::new(tokenizer);

let X = vectorizer.fit_transform(&documents);
assert_eq!(X.data(), &[]);
assert_eq!(X.indices(), &[]);
assert_eq!(X.indptr(), &[0]);
}

#[test]
fn test_dispatch_tokenizer() {
let tokenizer = VTextTokenizer::new("en");
CountVectorizer::new(tokenizer.clone());
HashingVectorizer::new(tokenizer);

let tokenizer = UnicodeSegmentTokenizer::new(false);
CountVectorizer::new(tokenizer.clone());
HashingVectorizer::new(tokenizer);

let tokenizer = RegexpTokenizer::new("\\b\\w+\\w\\b".to_string());
CountVectorizer::new(tokenizer.clone());
HashingVectorizer::new(tokenizer);

let tokenizer = CharacterTokenizer::new(4);
CountVectorizer::new(tokenizer.clone());
HashingVectorizer::new(tokenizer);
}