Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support updating template processors #1652

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 131 additions & 18 deletions bindings/python/src/processors.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::convert::TryInto;
use std::sync::Arc;
use std::sync::RwLock;

use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;

use std::ops::DerefMut;
use crate::encoding::PyEncoding;
use crate::error::ToPyResult;
use serde::{Deserialize, Serialize};
Expand All @@ -30,17 +31,28 @@ use tokenizers as tk;
#[derive(Clone, Deserialize, Serialize)]
#[serde(transparent)]
pub struct PyPostProcessor {
pub processor: Arc<PostProcessorWrapper>,
pub processor: Arc<RwLock<PostProcessorWrapper>>,
}

/// Builds a `PyPostProcessor` from any value convertible into a
/// `PostProcessorWrapper`. The converted wrapper is placed behind a fresh
/// `Arc<RwLock<_>>`, so the result does not share state with anything else.
impl<I> From<I> for PyPostProcessor
where
    I: Into<PostProcessorWrapper>,
{
    fn from(processor: I) -> Self {
        let wrapper: PostProcessorWrapper = processor.into();
        let shared = Arc::new(RwLock::new(wrapper));
        PyPostProcessor { processor: shared }
    }
}

impl PyPostProcessor {
pub fn new(processor: Arc<PostProcessorWrapper>) -> Self {
pub fn new(processor: Arc<RwLock<PostProcessorWrapper>>) -> Self {
PyPostProcessor { processor }
}

pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
let base = self.clone();
Ok(match self.processor.as_ref() {
Ok(match self.processor.read().unwrap().clone() {
PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?.into_py(py),
PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))?.into_py(py),
PostProcessorWrapper::Roberta(_) => {
Expand All @@ -56,23 +68,23 @@ impl PyPostProcessor {

impl PostProcessor for PyPostProcessor {
fn added_tokens(&self, is_pair: bool) -> usize {
self.processor.added_tokens(is_pair)
self.processor.read().unwrap().added_tokens(is_pair)
}

fn process_encodings(
&self,
encodings: Vec<Encoding>,
add_special_tokens: bool,
) -> tk::Result<Vec<Encoding>> {
self.processor
self.processor.read().unwrap()
.process_encodings(encodings, add_special_tokens)
}
}

#[pymethods]
impl PyPostProcessor {
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(self.processor.as_ref()).map_err(|e| {
let data = serde_json::to_string(&self.processor).map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to pickle PostProcessor: {}",
e
Expand Down Expand Up @@ -106,7 +118,7 @@ impl PyPostProcessor {
/// :obj:`int`: The number of tokens to add
#[pyo3(text_signature = "(self, is_pair)")]
fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
self.processor.added_tokens(is_pair)
self.processor.read().unwrap().added_tokens(is_pair)
}

/// Post-process the given encodings, generating the final one
Expand All @@ -131,7 +143,7 @@ impl PyPostProcessor {
pair: Option<&PyEncoding>,
add_special_tokens: bool,
) -> PyResult<PyEncoding> {
let final_encoding = ToPyResult(self.processor.process(
let final_encoding = ToPyResult(self.processor.read().unwrap().process(
encoding.encoding.clone(),
pair.map(|e| e.encoding.clone()),
add_special_tokens,
Expand All @@ -151,6 +163,42 @@ impl PyPostProcessor {
}
}

// Reads a value out of the `$variant` post-processor stored in the base
// `PyPostProcessor` of `$self` and *returns from the enclosing function* with
// the value's `Debug` rendering as a `String`.
//
// `$($name)+` is spliced after `inner.`, so it may be a field name or a method
// call such as `get_single()`. The read lock is held while the value is
// formatted. Hitting a different variant is treated as a programming error.
macro_rules! getter {
    ($self: ident, $variant: ident, $($name: tt)+) => {{
        let base = $self.as_ref();
        let guard = base.processor.read().unwrap();
        match *guard {
            PostProcessorWrapper::$variant(ref inner) => {
                return format!("{:?}", inner.$($name)+)
            }
            _ => unreachable!(),
        }
    }};
}

// Mutates the `$variant` post-processor stored in the base `PyPostProcessor`
// of `$self`, under the processor's write lock.
//
// Two forms are accepted:
//   setter!(self_, Variant, field, value)    assigns `value` to `post.field`
//   setter!(self_, Variant, @method, value)  calls `post.method(value)`
//
// `$value` is evaluated *before* the lock is taken, so an expression that
// itself reads the processor cannot re-lock on the same thread. If the
// wrapped processor is a different variant the macro does nothing.
macro_rules! setter {
    ($self: ident, $variant: ident, $name: ident, $value: expr) => {{
        let value = $value;
        let super_ = $self.as_ref();
        let mut guard = super_.processor.write().unwrap();
        if let PostProcessorWrapper::$variant(ref mut post) = *guard {
            post.$name = value;
        }
    }};
    ($self: ident, $variant: ident, @$name: ident, $value: expr) => {{
        let value = $value;
        let super_ = $self.as_ref();
        let mut guard = super_.processor.write().unwrap();
        if let PostProcessorWrapper::$variant(ref mut post) = *guard {
            post.$name(value);
        }
    }};
}

/// This post-processor takes care of adding the special tokens needed by
/// a Bert model:
///
Expand All @@ -172,7 +220,7 @@ impl PyBertProcessing {
fn new(sep: (String, u32), cls: (String, u32)) -> (Self, PyPostProcessor) {
(
PyBertProcessing {},
PyPostProcessor::new(Arc::new(BertProcessing::new(sep, cls).into())),
PyPostProcessor::new(Arc::new(RwLock::new(BertProcessing::new(sep, cls).into()))),
)
}

Expand Down Expand Up @@ -222,7 +270,7 @@ impl PyRobertaProcessing {
.add_prefix_space(add_prefix_space);
(
PyRobertaProcessing {},
PyPostProcessor::new(Arc::new(proc.into())),
PyPostProcessor::new(Arc::new(RwLock::new(proc.into()))),
)
}

Expand Down Expand Up @@ -257,7 +305,7 @@ impl PyByteLevel {

(
PyByteLevel {},
PyPostProcessor::new(Arc::new(byte_level.into())),
PyPostProcessor::new(Arc::new(RwLock::new(byte_level.into()))),
)
}
}
Expand Down Expand Up @@ -421,9 +469,24 @@ impl PyTemplateProcessing {

Ok((
PyTemplateProcessing {},
PyPostProcessor::new(Arc::new(processor.into())),
PyPostProcessor::new(Arc::new(RwLock::new(processor.into()))),
))
}

/// The template applied to single sequences, as a string.
///
/// The string is the `Debug` rendering of the value returned by
/// `get_single()` on the wrapped `Template` processor; the `getter!` macro
/// performs the formatting and the early `return`. The wrapped processor is
/// expected to be the `Template` variant; any other variant reaches
/// `unreachable!()`.
#[getter]
fn get_single(self_: PyRef<Self>) -> String{
getter!(self_, Template, get_single())
}

/// Replaces the template applied to single sequences.
///
/// The incoming `PyTemplate` is converted to a `Template` and stored through
/// `set_single` under the processor's write lock. Nothing happens when the
/// wrapped processor is not the `Template` variant.
#[setter]
fn set_single(self_: PyRef<Self>, single: PyTemplate) {
    let new_single: Template = single.into();
    let base = self_.as_ref();
    let mut guard = base.processor.write().unwrap();
    if let PostProcessorWrapper::Template(inner) = &mut *guard {
        inner.set_single(new_single);
    }
}
}

/// Sequence Processor
Expand All @@ -441,19 +504,69 @@ impl PySequence {
let mut processors: Vec<PostProcessorWrapper> = Vec::with_capacity(processors_py.len());
for n in processors_py.iter() {
let processor: PyRef<PyPostProcessor> = n.extract().unwrap();
let processor = processor.processor.as_ref();
let processor = processor.processor.write().unwrap();
processors.push(processor.clone());
}
let sequence_processor = Sequence::new(processors);
(
PySequence {},
PyPostProcessor::new(Arc::new(PostProcessorWrapper::Sequence(sequence_processor))),
PyPostProcessor::new(Arc::new(RwLock::new(PostProcessorWrapper::Sequence(sequence_processor)))),
)
}

/// Returns the positional arguments as a one-element tuple holding an empty
/// list.
///
/// NOTE(review): presumably consumed by pickle to call the constructor with
/// an empty processor list before state is restored — confirm.
fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> {
    let no_processors = PyList::empty_bound(py);
    PyTuple::new_bound(py, [no_processors])
}

/// Returns the processor stored at `index`, as its concrete Python subtype.
///
/// The element is cloned into its own `Arc<RwLock<_>>`, so the returned
/// object is a snapshot: mutating it does not change the sequence. Use
/// `__setitem__` to write a processor back.
///
/// Raises `IndexError` when `index` is out of range or when the wrapped
/// processor is not a `Sequence`.
fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
    let super_ = self_.as_ref();
    // Indexing only reads the sequence, so a shared lock is enough and does
    // not block concurrent readers.
    let wrapper = super_.processor.read().unwrap();

    match *wrapper {
        PostProcessorWrapper::Sequence(ref inner) => match inner.get(index) {
            Some(item) => {
                PyPostProcessor::new(Arc::new(RwLock::new(item.clone()))).get_as_subtype(py)
            }
            None => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
                "Index not found",
            )),
        },
        _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
            "This processor is not a Sequence, it does not support __getitem__",
        )),
    }
}

/// Replaces the processor stored at `index` with a copy of `value`.
///
/// Raises `IndexError` when `index` is out of range or when the wrapped
/// processor is not a `Sequence`.
fn __setitem__(self_: PyRefMut<'_, Self>, py: Python<'_>, index: usize, value: PyRef<'_, PyPostProcessor>) -> PyResult<()> {
    // Snapshot the incoming processor first. The temporary read guard is
    // released at the end of this statement, i.e. before the write lock below
    // is taken: `value` may share its `Arc<RwLock<_>>` with `self_`, and
    // locking the same `RwLock` twice on one thread can deadlock or panic.
    let replacement = value.processor.read().unwrap().clone();

    let super_ = self_.as_ref();
    let mut wrapper = super_.processor.write().unwrap();
    match *wrapper {
        PostProcessorWrapper::Sequence(ref mut inner) => match inner.get_mut(index) {
            Some(slot) => {
                *slot = replacement;
                Ok(())
            }
            None => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
                "Index out of bounds",
            )),
        },
        _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
            "This processor is not a Sequence, it does not support __setitem__",
        )),
    }
}
}

/// Processors Module
Expand Down Expand Up @@ -481,9 +594,9 @@ mod test {
#[test]
fn get_subtype() {
Python::with_gil(|py| {
let py_proc = PyPostProcessor::new(Arc::new(
let py_proc = PyPostProcessor::new(Arc::new(RwLock::new(
BertProcessing::new(("SEP".into(), 0), ("CLS".into(), 1)).into(),
));
)));
let py_bert = py_proc.get_as_subtype(py).unwrap();
assert_eq!(
"BertProcessing",
Expand All @@ -499,7 +612,7 @@ mod test {
let rs_processing_ser = serde_json::to_string(&rs_processing).unwrap();
let rs_wrapper_ser = serde_json::to_string(&rs_wrapper).unwrap();

let py_processing = PyPostProcessor::new(Arc::new(rs_wrapper));
let py_processing = PyPostProcessor::new(Arc::new(RwLock::new(rs_wrapper)));
let py_ser = serde_json::to_string(&py_processing).unwrap();
assert_eq!(py_ser, rs_processing_ser);
assert_eq!(py_ser, rs_wrapper_ser);
Expand Down
20 changes: 20 additions & 0 deletions tokenizers/src/processors/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,26 @@ impl Sequence {
/// Creates a `Sequence` that stores the given processors, in the order
/// supplied.
pub fn new(processors: Vec<PostProcessorWrapper>) -> Self {
Self { processors }
}

/// Returns a shared reference to the processor at `index`, or `None` when
/// `index` is out of bounds.
pub fn get(&self, index: usize) -> Option<&PostProcessorWrapper> {
    let processors: &[PostProcessorWrapper] = &self.processors;
    processors.get(index)
}

/// Returns a mutable reference to the processor at `index`, or `None` when
/// `index` is out of bounds.
pub fn get_mut(&mut self, index: usize) -> Option<&mut PostProcessorWrapper> {
    let processors: &mut [PostProcessorWrapper] = &mut self.processors;
    processors.get_mut(index)
}

/// Overwrites the processor at `index` with `post_proc`.
///
/// # Panics
///
/// Panics if `index` is out of bounds. Check the length via
/// `get_processors()` first, or use `get_mut` for a non-panicking lookup.
pub fn set_mut(&mut self, index: usize, post_proc: PostProcessorWrapper) {
    let slot = &mut self.processors[index];
    *slot = post_proc;
}

/// Returns the stored processors as a shared slice, in sequence order.
pub fn get_processors(&self) -> &[PostProcessorWrapper] {
    self.processors.as_slice()
}

/// Returns the stored processors as a mutable slice. Elements can be
/// modified or replaced in place, but the slice cannot grow or shrink.
pub fn get_processors_mut(&mut self) -> &mut [PostProcessorWrapper] {
    self.processors.as_mut_slice()
}
}

impl PostProcessor for Sequence {
Expand Down
57 changes: 56 additions & 1 deletion tokenizers/src/processors/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ impl From<HashMap<String, SpecialToken>> for Tokens {
#[builder(build_fn(validate = "Self::validate"))]
pub struct TemplateProcessing {
#[builder(try_setter, default = "\"$0\".try_into().unwrap()")]
single: Template,
pub single: Template,
#[builder(try_setter, default = "\"$A:0 $B:1\".try_into().unwrap()")]
pair: Template,
#[builder(setter(skip), default = "self.default_added(true)")]
Expand All @@ -351,6 +351,61 @@ pub struct TemplateProcessing {
special_tokens: Tokens,
}


/// Accessors for reading and replacing the individual fields of a
/// `TemplateProcessing`.
///
/// Every setter stores the given value verbatim and touches no other field.
///
/// NOTE(review): replacing `single`, `pair` or `special_tokens` does not
/// update `added_single` / `added_pair`. Confirm whether callers are expected
/// to keep those counts in sync via `set_added_single` / `set_added_pair`.
impl TemplateProcessing {
    /// Returns a clone of the template used for single sequences.
    pub fn get_single(&self) -> Template {
        self.single.clone()
    }

    /// Replaces the template used for single sequences.
    pub fn set_single(&mut self, single: Template) {
        // Library code must not write to stdout; the debugging `println!`
        // calls that used to surround this assignment were removed.
        self.single = single;
    }

    /// Returns a reference to the template used for sequence pairs.
    pub fn get_pair(&self) -> &Template {
        &self.pair
    }

    /// Replaces the template used for sequence pairs.
    pub fn set_pair(&mut self, pair: Template) {
        self.pair = pair;
    }

    /// Returns the stored `added_single` count.
    pub fn get_added_single(&self) -> usize {
        self.added_single
    }

    /// Overwrites the stored `added_single` count.
    pub fn set_added_single(&mut self, added_single: usize) {
        self.added_single = added_single;
    }

    /// Returns the stored `added_pair` count.
    pub fn get_added_pair(&self) -> usize {
        self.added_pair
    }

    /// Overwrites the stored `added_pair` count.
    pub fn set_added_pair(&mut self, added_pair: usize) {
        self.added_pair = added_pair;
    }

    /// Returns a reference to the special tokens known to this processor.
    pub fn get_special_tokens(&self) -> &Tokens {
        &self.special_tokens
    }

    /// Replaces the special tokens known to this processor.
    pub fn set_special_tokens(&mut self, special_tokens: Tokens) {
        self.special_tokens = special_tokens;
    }
}

impl From<&str> for TemplateProcessingBuilderError {
fn from(e: &str) -> Self {
e.to_string().into()
Expand Down
Loading