From 01d0b299e737c77b14e584eddea9ffba05242dbd Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 14 Oct 2024 17:08:14 +0200 Subject: [PATCH 1/3] current updates --- bindings/python/src/processors.rs | 121 ++++++++++++++++++++++---- tokenizers/src/processors/sequence.rs | 8 ++ tokenizers/src/processors/template.rs | 56 +++++++++++- 3 files changed, 166 insertions(+), 19 deletions(-) diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 1d8e8dfac..71d7d6710 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -1,10 +1,11 @@ use std::convert::TryInto; use std::sync::Arc; +use std::sync::RwLock; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; - +use std::ops::DerefMut; use crate::encoding::PyEncoding; use crate::error::ToPyResult; use serde::{Deserialize, Serialize}; @@ -30,17 +31,17 @@ use tokenizers as tk; #[derive(Clone, Deserialize, Serialize)] #[serde(transparent)] pub struct PyPostProcessor { - pub processor: Arc, + pub processor: Arc>, } impl PyPostProcessor { - pub fn new(processor: Arc) -> Self { + pub fn new(processor: Arc>) -> Self { PyPostProcessor { processor } } pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult { let base = self.clone(); - Ok(match self.processor.as_ref() { + Ok(match self.processor.read().unwrap().clone() { PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?.into_py(py), PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))?.into_py(py), PostProcessorWrapper::Roberta(_) => { @@ -56,7 +57,7 @@ impl PyPostProcessor { impl PostProcessor for PyPostProcessor { fn added_tokens(&self, is_pair: bool) -> usize { - self.processor.added_tokens(is_pair) + self.processor.read().unwrap().added_tokens(is_pair) } fn process_encodings( @@ -64,7 +65,7 @@ impl PostProcessor for PyPostProcessor { encodings: Vec, add_special_tokens: bool, ) -> tk::Result> { - self.processor + self.processor.read().unwrap() .process_encodings(encodings, add_special_tokens) } } @@ -72,7 +73,7 @@ impl PostProcessor for PyPostProcessor { #[pymethods] impl PyPostProcessor { fn __getstate__(&self, py: Python) -> PyResult { - let data = serde_json::to_string(self.processor.as_ref()).map_err(|e| { + let data = serde_json::to_string(&self.processor).map_err(|e| { exceptions::PyException::new_err(format!( "Error while attempting to pickle PostProcessor: {}", e @@ -106,7 +107,7 @@ impl PyPostProcessor { /// :obj:`int`: The number of tokens to add #[pyo3(text_signature = "(self, is_pair)")] fn num_special_tokens_to_add(&self, is_pair: bool) -> usize { - self.processor.added_tokens(is_pair) + self.processor.read().unwrap().added_tokens(is_pair) } /// Post-process the given encodings, generating the final one @@ -131,7 +132,7 @@ impl PyPostProcessor { pair: Option<&PyEncoding>, add_special_tokens: bool, ) -> PyResult { - let final_encoding = ToPyResult(self.processor.process( + let final_encoding = ToPyResult(self.processor.read().unwrap().process( encoding.encoding.clone(), pair.map(|e| e.encoding.clone()), add_special_tokens, @@ -151,6 +152,42 @@ impl PyPostProcessor { } } +macro_rules! getter { + ($self: ident, $variant: ident, $($name: tt)+) => {{ + let super_ = $self.as_ref(); + if let PostProcessorWrapper::$variant(ref post) = *super_.processor.read().unwrap() { + let output = post.$($name)+; + return format!("{:?}", output) + } else { + unreachable!() + } + }}; +} + +macro_rules! setter { + ($self: ident, $variant: ident, $name: ident, $value: expr) => {{ + let super_ = $self; + if let PostProcessorWrapper::$variant(ref mut post) = super_.processor.as_ref() { + post.$name = $value; + } + }}; + ($self: ident, $variant: ident, @$name: ident, $value: expr) => {{ + let super_ = &$self.as_ref(); + match &super_.processor.as_ref() { + PostProcessorWrapper::$variant(post_variant) => post_variant.$name($value), + _ => unreachable!(), + } + + { + if let Some(PostProcessorWrapper::$variant(post_variant)) = + Arc::get_mut(&mut super_.processor) + { + post_variant.$name($value); + } + }; + };}; +} + /// This post-processor takes care of adding the special tokens needed by /// a Bert model: /// @@ -172,7 +209,7 @@ impl PyBertProcessing { fn new(sep: (String, u32), cls: (String, u32)) -> (Self, PyPostProcessor) { ( PyBertProcessing {}, - PyPostProcessor::new(Arc::new(BertProcessing::new(sep, cls).into())), + PyPostProcessor::new(Arc::new(RwLock::new(BertProcessing::new(sep, cls).into()))), ) } @@ -222,7 +259,7 @@ impl PyRobertaProcessing { .add_prefix_space(add_prefix_space); ( PyRobertaProcessing {}, - PyPostProcessor::new(Arc::new(proc.into())), + PyPostProcessor::new(Arc::new(RwLock::new(proc.into()))), ) } @@ -257,7 +294,7 @@ impl PyByteLevel { ( PyByteLevel {}, - PyPostProcessor::new(Arc::new(byte_level.into())), + PyPostProcessor::new(Arc::new(RwLock::new(byte_level.into()))), ) } } @@ -421,9 +458,43 @@ impl PyTemplateProcessing { Ok(( PyTemplateProcessing {}, - PyPostProcessor::new(Arc::new(processor.into())), + PyPostProcessor::new(Arc::new(RwLock::new(processor.into()))), )) } + + #[getter] + fn get_single(self_: PyRef) -> String{ + getter!(self_, Template, get_single()) + } + + #[setter] + fn set_single(self_:PyRefMut, single: PyTemplate) { + let template: Template = Template::from(single); + + let super_ = &self_.into_super(); + + // Acquire a write lock on the processor + let binding = super_.processor.clone(); // Clone the Arc + let mut write_lock = match binding.write() { // Make this mutable + Ok(lock) => lock, + Err(e) => { + eprintln!("Failed to acquire write lock: {:?}", e); + return; // Handle lock acquisition failure appropriately + } + }; + + // Use deref_mut to get a mutable reference and match against the PostProcessorWrapper type + match write_lock.deref_mut() { + PostProcessorWrapper::Template(value) => { + println!("Created template single : {template:?}"); + value.set_single(template.clone()); + }, + _ => { + eprintln!("Processor is not of type PostProcessorWrapper::Template"); + } + } + + } } /// Sequence Processor @@ -441,19 +512,33 @@ impl PySequence { let mut processors: Vec = Vec::with_capacity(processors_py.len()); for n in processors_py.iter() { let processor: PyRef = n.extract().unwrap(); - let processor = processor.processor.as_ref(); + let processor = processor.processor.write().unwrap(); processors.push(processor.clone()); } let sequence_processor = Sequence::new(processors); ( PySequence {}, - PyPostProcessor::new(Arc::new(PostProcessorWrapper::Sequence(sequence_processor))), + PyPostProcessor::new(Arc::new(RwLock::new(PostProcessorWrapper::Sequence(sequence_processor)))), ) } fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> { PyTuple::new_bound(py, [PyList::empty_bound(py)]) } + + fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult> { + match &self_.as_ref().processor.read().unwrap().clone() { + PostProcessorWrapper::Sequence(inner) => match inner.get(index) { + Some(item) => PyPostProcessor::new(Arc::new(RwLock::new(item.clone()))).get_as_subtype(py), + _ => Err(PyErr::new::( + "Index not found", + )), + }, + _ => Err(PyErr::new::( + "This processor is not a Sequence, it does not support __getitem__", + )), + } + } } /// Processors Module @@ -481,9 +566,9 @@ mod test { #[test] fn get_subtype() { Python::with_gil(|py| { - let py_proc = PyPostProcessor::new(Arc::new( + let py_proc = PyPostProcessor::new(Arc::new(RwLock::new( BertProcessing::new(("SEP".into(), 0), ("CLS".into(), 1)).into(), - )); + ))); let py_bert = py_proc.get_as_subtype(py).unwrap(); assert_eq!( "BertProcessing", @@ -499,7 +584,7 @@ mod test { let rs_processing_ser = serde_json::to_string(&rs_processing).unwrap(); let rs_wrapper_ser = serde_json::to_string(&rs_wrapper).unwrap(); - let py_processing = PyPostProcessor::new(Arc::new(rs_wrapper)); + let py_processing = PyPostProcessor::new(Arc::new(RwLock::new(rs_wrapper))); let py_ser = serde_json::to_string(&py_processing).unwrap(); assert_eq!(py_ser, rs_processing_ser); assert_eq!(py_ser, rs_wrapper_ser); diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs index 66c670ad8..8d273252a 100644 --- a/tokenizers/src/processors/sequence.rs +++ b/tokenizers/src/processors/sequence.rs @@ -13,6 +13,14 @@ impl Sequence { pub fn new(processors: Vec) -> Self { Self { processors } } + + pub fn get(&self, index: usize) -> Option<& PostProcessorWrapper> { + self.processors.get(index as usize) + } + + pub fn get_mut(&mut self, index: usize) -> Option<&mut PostProcessorWrapper> { + self.processors.get_mut(index) + } } impl PostProcessor for Sequence { diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index 7f1fed54d..8c9e88145 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -338,7 +338,7 @@ impl From> for Tokens { #[builder(build_fn(validate = "Self::validate"))] pub struct TemplateProcessing { #[builder(try_setter, default = "\"$0\".try_into().unwrap()")] - single: Template, + pub single: Template, #[builder(try_setter, default = "\"$A:0 $B:1\".try_into().unwrap()")] pair: Template, #[builder(setter(skip), default = "self.default_added(true)")] @@ -351,6 +351,60 @@ pub struct TemplateProcessing { special_tokens: Tokens, } + +impl TemplateProcessing { + // Getter for `single` + pub fn get_single(& self) -> Template { + self.single.clone() + } + + // Setter for `single` + pub fn set_single(&mut self, single: Template) { + println!("Setting single to: {:?}", single); // Debugging output + self.single = single; + } + + // Getter for `pair` + pub fn get_pair(&self) -> &Template { + &self.pair + } + + // Setter for `pair` + pub fn set_pair(&mut self, pair: Template) { + self.pair = pair; + } + + // Getter for `added_single` + pub fn get_added_single(&self) -> usize { + self.added_single + } + + // Setter for `added_single` + pub fn set_added_single(&mut self, added_single: usize) { + self.added_single = added_single; + } + + // Getter for `added_pair` + pub fn get_added_pair(&self) -> usize { + self.added_pair + } + + // Setter for `added_pair` + pub fn set_added_pair(&mut self, added_pair: usize) { + self.added_pair = added_pair; + } + + // Getter for `special_tokens` + pub fn get_special_tokens(&self) -> &Tokens { + &self.special_tokens + } + + // Setter for `special_tokens` + pub fn set_special_tokens(&mut self, special_tokens: Tokens) { + self.special_tokens = special_tokens; + } +} + impl From<&str> for TemplateProcessingBuilderError { fn from(e: &str) -> Self { e.to_string().into() From 1d67a76087852b81e73c6d8bbc9bfe3fd60c58d8 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 14 Oct 2024 19:41:04 +0200 Subject: [PATCH 2/3] simplify --- bindings/python/src/processors.rs | 29 +++++------------------------ 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 71d7d6710..1a59d4a20 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -468,32 +468,13 @@ impl PyTemplateProcessing { } #[setter] - fn set_single(self_:PyRefMut, single: PyTemplate) { + fn set_single(self_: PyRef, single: PyTemplate) { let template: Template = Template::from(single); - - let super_ = &self_.into_super(); - - // Acquire a write lock on the processor - let binding = super_.processor.clone(); // Clone the Arc - let mut write_lock = match binding.write() { // Make this mutable - Ok(lock) => lock, - Err(e) => { - eprintln!("Failed to acquire write lock: {:?}", e); - return; // Handle lock acquisition failure appropriately - } + let super_ = self_.as_ref(); + let mut wrapper = super_.processor.write().unwrap(); + if let PostProcessorWrapper::Template(ref mut post) = *wrapper { + post.set_single(template.into()); }; - - // Use deref_mut to get a mutable reference and match against the PostProcessorWrapper type - match write_lock.deref_mut() { - PostProcessorWrapper::Template(value) => { - println!("Created template single : {template:?}"); - value.set_single(template.clone()); - }, - _ => { - eprintln!("Processor is not of type PostProcessorWrapper::Template"); - } - } - } } From 488a57001cd763640b8b62c1db507a017425c8a3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 14 Oct 2024 22:19:05 +0200 Subject: [PATCH 3/3] set_item works, but `tokenizer._tokenizer.post_processor[1].single = ["$0", ""]` does not ! --- bindings/python/src/processors.rs | 53 +++++++++++++++++++++++++-- tokenizers/src/processors/sequence.rs | 12 ++++++ tokenizers/src/processors/template.rs | 1 + 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index 1a59d4a20..3c873a447 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -34,6 +34,17 @@ pub struct PyPostProcessor { pub processor: Arc>, } +impl From for PyPostProcessor +where + I: Into, +{ + fn from(processor: I) -> Self { + PyPostProcessor { + processor: Arc::new(RwLock::new(processor.into())), // Wrap the PostProcessorWrapper in Arc> + } + } +} + impl PyPostProcessor { pub fn new(processor: Arc>) -> Self { PyPostProcessor { processor } @@ -508,9 +519,21 @@ impl PySequence { } fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult> { - match &self_.as_ref().processor.read().unwrap().clone() { - PostProcessorWrapper::Sequence(inner) => match inner.get(index) { - Some(item) => PyPostProcessor::new(Arc::new(RwLock::new(item.clone()))).get_as_subtype(py), + + let super_ = self_.as_ref(); + let mut wrapper = super_.processor.write().unwrap(); + // if let PostProcessorWrapper::Sequence(ref mut post) = *wrapper { + // match post.get(index) { + // Some(item) => PyPostProcessor::new(Arc::clone(item)).get_as_subtype(py), + // _ => Err(PyErr::new::( + // "Index not found", + // )), + // } + // } + + match *wrapper { + PostProcessorWrapper::Sequence(ref mut inner) => match inner.get_mut(index) { + Some(item) => PyPostProcessor::new(Arc::new(RwLock::new(item.to_owned()))).get_as_subtype(py), _ => Err(PyErr::new::( "Index not found", )), @@ -520,6 +543,30 @@ impl PySequence { )), } } + + fn __setitem__(self_: PyRefMut<'_, Self>, py: Python<'_>, index: usize, value: PyRef<'_, PyPostProcessor>) -> PyResult<()> { + let super_ = self_.as_ref(); + let mut wrapper = super_.processor.write().unwrap(); + let value = value.processor.read().unwrap().clone(); + match *wrapper { + PostProcessorWrapper::Sequence(ref mut inner) => { + // Convert the Py into the appropriate Rust type + // Ensure we can set an item at the given index + if index < inner.get_processors().len() { + inner.set_mut(index, value); // Assuming you want to wrap the new item in Arc + + Ok(()) + } else { + Err(PyErr::new::( + "Index out of bounds", + )) + } + }, + _ => Err(PyErr::new::( + "This processor is not a Sequence, it does not support __setitem__", + )), + } + } } /// Processors Module diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs index 8d273252a..b9fdbb4dd 100644 --- a/tokenizers/src/processors/sequence.rs +++ b/tokenizers/src/processors/sequence.rs @@ -21,6 +21,18 @@ impl Sequence { pub fn get_mut(&mut self, index: usize) -> Option<&mut PostProcessorWrapper> { self.processors.get_mut(index) } + + pub fn set_mut(&mut self, index: usize, post_proc: PostProcessorWrapper) { + self.processors[index as usize] = post_proc; + } + + pub fn get_processors(&self) -> &[PostProcessorWrapper] { + &self.processors + } + + pub fn get_processors_mut(&mut self) -> &mut [PostProcessorWrapper] { + &mut self.processors + } } impl PostProcessor for Sequence { diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index 8c9e88145..be9df6b36 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -362,6 +362,7 @@ impl TemplateProcessing { pub fn set_single(&mut self, single: Template) { println!("Setting single to: {:?}", single); // Debugging output self.single = single; + println!("Single is now {:?}", self.single); } // Getter for `pair`