@@ -40,7 +40,8 @@ def train(
40
40
41
41
if attribute_doc is not None :
42
42
example .set (
43
- MESSAGE_TOKENS_NAMES [attribute ], self .tokenize (attribute_doc )
43
+ MESSAGE_TOKENS_NAMES [attribute ],
44
+ self .tokenize (attribute_doc , attribute ),
44
45
)
45
46
46
47
def get_doc (self , message : Message , attribute : Text ) -> "Doc" :
@@ -49,10 +50,12 @@ def get_doc(self, message: Message, attribute: Text) -> "Doc":
49
50
def process (self , message : Message , ** kwargs : Any ) -> None :
50
51
message .set (
51
52
MESSAGE_TOKENS_NAMES [MESSAGE_TEXT_ATTRIBUTE ],
52
- self .tokenize (self .get_doc (message , MESSAGE_TEXT_ATTRIBUTE )),
53
+ self .tokenize (
54
+ self .get_doc (message , MESSAGE_TEXT_ATTRIBUTE ), MESSAGE_TEXT_ATTRIBUTE
55
+ ),
53
56
)
54
57
55
- def tokenize (self , doc : "Doc" ) -> List [Token ]:
58
+ def tokenize (self , doc : "Doc" , attribute : Text ) -> List [Token ]:
56
59
tokens = [Token (t .text , t .idx ) for t in doc ]
57
- self .add_cls_token (tokens )
60
+ self .add_cls_token (tokens , attribute )
58
61
return tokens
0 commit comments