-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from FederatedAI/develop-1.1.0
Develop 1.1.0
- Loading branch information
Showing
18 changed files
with
1,969 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,6 @@ | ||
# FATE-LLM | ||
The repo for federated fine-tuning for LLM. | ||
FATE-LLM is a framework to support federated training with large language models, it also provides multiple parameter-efficient fine-tuning strategies for industrial applications. | ||
|
||
### Quick Start | ||
- [Federated ChatGLM-6B Training](./doc/tutorial/ChatGLM-6B.ipynb) | ||
- [GPT-2 Training](./doc/tutorial/GPT2-example.ipynb) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
## Release 1.1.0 | ||
### Major Features and Improvements | ||
* Support Federated Training of ChatGLM-6B with parameter-efficient fine-tuning adapters: like Lora and P-Tuning V2 etc. | ||
* Integration of `peft`, which support many parameter-efficient adapters. |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# | ||
# Copyright 2019 The FATE Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# | ||
# Copyright 2019 The FATE Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
from federatedml.nn.dataset.base import Dataset | ||
import pandas as pd | ||
from transformers import AutoTokenizer | ||
|
||
|
||
PROMPT_TEMPLATE = "{prompt}" | ||
|
||
|
||
class GLMTokenizerDataset(Dataset): | ||
def __init__(self, truncation=True, text_max_length=256, | ||
tokenizer_name_or_path=None, | ||
padding=True, padding_side="right", pad_token=None, | ||
trust_remote_code=True, | ||
prompt_template=None, | ||
prompt_column="content", | ||
response_column="summary" | ||
): | ||
|
||
super(GLMTokenizerDataset, self).__init__() | ||
self.label = None | ||
self.tokenizer = None | ||
self.padding = padding | ||
self.truncation = truncation | ||
self.max_length = text_max_length | ||
self.tokenizer_name_or_path = tokenizer_name_or_path | ||
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path, trust_remote_code=trust_remote_code) | ||
self.tokenizer.padding_side = padding_side | ||
if pad_token is not None: | ||
self.tokenizer.add_special_tokens({'pad_token': pad_token}) | ||
|
||
self.prompt_template = prompt_template if prompt_template else PROMPT_TEMPLATE | ||
self.prompt_column = prompt_column | ||
self.response_column = response_column | ||
self._data = None | ||
|
||
def load(self, file_path): | ||
df = pd.read_json(file_path, lines=True) | ||
self._data = df.apply(self._process_data, axis=1) | ||
|
||
def _process_data(self, line): | ||
_prompt = line[self.prompt_column] | ||
_response = line[self.response_column] | ||
|
||
prompt = self.prompt_template.format_map(dict(prompt=_prompt)) | ||
prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) | ||
target_ids = self.tokenizer.encode(_response, add_special_tokens=False) | ||
|
||
if len(prompt_ids) > self.max_length - 1: | ||
prompt_ids = prompt_ids[: self.max_length - 1] | ||
if len(target_ids) > self.max_length - 2: | ||
target_ids = target_ids[: self.max_length - 2] | ||
|
||
input_ids = self.tokenizer.build_inputs_with_special_tokens(prompt_ids, target_ids) | ||
|
||
seq_length = input_ids.index(self.tokenizer.bos_token_id) | ||
labels = [-100] * seq_length + input_ids[seq_length:] | ||
|
||
return { | ||
"input_ids": input_ids, | ||
"labels": labels, | ||
} | ||
|
||
def get_vocab_size(self): | ||
return self.tokenizer.vocab_size | ||
|
||
def __getitem__(self, item): | ||
return self._data[item] | ||
|
||
def __len__(self): | ||
return len(self._data) | ||
|
||
def __repr__(self): | ||
return self.tokenizer.__repr__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# | ||
# Copyright 2019 The FATE Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
from federatedml.nn.dataset.base import Dataset | ||
import pandas as pd | ||
import torch as t | ||
from transformers import AutoTokenizer | ||
import os | ||
import numpy as np | ||
|
||
# avoid tokenizer parallelism | ||
os.environ["TOKENIZERS_PARALLELISM"] = "false" | ||
|
||
|
||
class TokenizerDataset(Dataset): | ||
""" | ||
A Dataset for some basic NLP Tasks, this dataset will automatically transform raw text into word indices | ||
using AutoTokenizer from transformers library, | ||
Parameters | ||
---------- | ||
truncation bool, truncate word sequence to 'text_max_length' | ||
text_max_length int, max length of word sequences | ||
tokenizer_name_or_path str, name of bert tokenizer(see transformers official for details) or path to local | ||
transformer tokenizer folder | ||
return_label bool, return label or not, this option is for host dataset, when running hetero-NN | ||
padding bool, whether to pad the word sequence to 'text_max_length' | ||
padding_side str, 'left' or 'right', where to pad the word sequence | ||
pad_token str, pad token, use this str as pad token, if None, use tokenizer.pad_token | ||
return_input_ids bool, whether to return input_ids or not, if False, return word_idx['input_ids'] | ||
""" | ||
|
||
def __init__(self, truncation=True, text_max_length=128, | ||
tokenizer_name_or_path="bert-base-uncased", | ||
return_label=True, padding=True, padding_side="right", pad_token=None, | ||
return_input_ids=True | ||
): | ||
|
||
super(TokenizerDataset, self).__init__() | ||
self.text = None | ||
self.word_idx = None | ||
self.label = None | ||
self.tokenizer = None | ||
self.sample_ids = None | ||
self.padding = padding | ||
self.truncation = truncation | ||
self.max_length = text_max_length | ||
self.with_label = return_label | ||
self.tokenizer_name_or_path = tokenizer_name_or_path | ||
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path) | ||
self.tokenizer.padding_side = padding_side | ||
self.return_input_ids = return_input_ids | ||
if pad_token is not None: | ||
self.tokenizer.add_special_tokens({'pad_token': pad_token}) | ||
|
||
def load(self, file_path): | ||
|
||
tokenizer = self.tokenizer | ||
self.text = pd.read_csv(file_path) | ||
text_list = list(self.text.text) | ||
|
||
self.word_idx = tokenizer( | ||
text_list, | ||
padding=self.padding, | ||
return_tensors='pt', | ||
truncation=self.truncation, | ||
max_length=self.max_length) | ||
|
||
if self.return_input_ids: | ||
self.word_idx = self.word_idx['input_ids'] | ||
|
||
if self.with_label: | ||
self.label = t.Tensor(self.text.label).detach().numpy() | ||
self.label = self.label.reshape((len(self.text), -1)) | ||
|
||
if 'id' in self.text: | ||
self.sample_ids = self.text['id'].values.tolist() | ||
|
||
def get_classes(self): | ||
return np.unique(self.label).tolist() | ||
|
||
def get_vocab_size(self): | ||
return self.tokenizer.vocab_size | ||
|
||
def get_sample_ids(self): | ||
return self.sample_ids | ||
|
||
def __getitem__(self, item): | ||
|
||
if self.return_input_ids: | ||
ret = self.word_idx[item] | ||
else: | ||
ret = {k: v[item] for k, v in self.word_idx.items()} | ||
|
||
if self.with_label: | ||
return ret, self.label[item] | ||
|
||
return ret | ||
|
||
def __len__(self): | ||
return len(self.text) | ||
|
||
def __repr__(self): | ||
return self.tokenizer.__repr__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# | ||
# Copyright 2019 The FATE Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# | ||
# Copyright 2019 The FATE Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
from transformers import AlbertConfig, AutoConfig | ||
from transformers import AlbertForSequenceClassification | ||
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM | ||
|
||
|
||
class Albert(PELLM): | ||
|
||
config_class = AlbertConfig | ||
model_loader = AlbertForSequenceClassification | ||
|
||
def __init__(self, config: dict = None, | ||
pretrained_path: str = None, | ||
peft_type: str = None, | ||
peft_config: dict = None, | ||
**kwargs | ||
) -> None: | ||
|
||
if pretrained_path is not None: | ||
self.check_config(pretain_path=pretrained_path) | ||
if config is None and pretrained_path is None: | ||
config = AlbertConfig().to_dict() # use default model setting | ||
super().__init__(config=config, pretrained_path=pretrained_path, | ||
peft_type=peft_type, peft_config=peft_config, **kwargs) | ||
|
||
def check_config(self, pretain_path): | ||
config = AutoConfig.from_pretrained(pretain_path) | ||
assert isinstance( | ||
config, AlbertConfig), 'The config of pretrained model must be AlbertConfig, but got {}'.format( | ||
type(config)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# | ||
# Copyright 2019 The FATE Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
from transformers import BartConfig, AutoConfig | ||
from transformers import BartForSequenceClassification | ||
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM | ||
|
||
|
||
class Bart(PELLM): | ||
config_class = BartConfig | ||
model_loader = BartForSequenceClassification | ||
|
||
def __init__(self, config: dict = None, | ||
pretrained_path: str = None, | ||
peft_type: str = None, | ||
peft_config: dict = None, | ||
**kwargs) -> None: | ||
|
||
if pretrained_path is not None: | ||
self.check_config(pretrain_path=pretrained_path) | ||
if config is None and pretrained_path is None: | ||
config = BartConfig().to_dict() | ||
super().__init__(config=config, pretrained_path=pretrained_path, | ||
peft_type=peft_type, peft_config=peft_config, **kwargs) | ||
|
||
def check_config(self, pretrain_path): | ||
config = AutoConfig.from_pretrained(pretrain_path) | ||
assert isinstance( | ||
config, BartConfig), 'The config of pretrained model must be BartConfig, but got {}'.format( | ||
type(config)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# | ||
# Copyright 2019 The FATE Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
from transformers import BertConfig, AutoConfig | ||
from transformers import BertForSequenceClassification | ||
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM | ||
|
||
|
||
class Bert(PELLM): | ||
config_class = BertConfig | ||
model_loader = BertForSequenceClassification | ||
|
||
def __init__(self, config: dict = None, | ||
pretrained_path: str = None, | ||
peft_type: str = None, | ||
peft_config: dict = None, | ||
**kwargs) -> None: | ||
|
||
if pretrained_path is not None: | ||
self.check_config(pretrain_path=pretrained_path) | ||
if config is None and pretrained_path is None: | ||
config = BertConfig().to_dict() | ||
super().__init__(config=config, pretrained_path=pretrained_path, | ||
peft_type=peft_type, peft_config=peft_config, **kwargs) | ||
|
||
def check_config(self, pretrain_path): | ||
config = AutoConfig.from_pretrained(pretrain_path) | ||
assert isinstance( | ||
config, BertConfig), 'The config of pretrained model must be BertConfig, but got {}'.format( | ||
type(config)) |
Oops, something went wrong.