
tokenizer save_pretrained can not handle non-string value in dtype #33304

Closed
jiaweihhuang opened this issue Sep 4, 2024 · 4 comments
Labels: bug, Usage (General questions about the library)

Comments


jiaweihhuang commented Sep 4, 2024

System Info

python3.10
transformers 4.36.2
torch 2.1.2
torchaudio 2.1.2
torchvision 0.16.2

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import T5Tokenizer
import torch
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", torch_dtype=torch.bfloat16)
tokenizer.save_pretrained('./')

Expected behavior

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/jiawhuang/miniconda3/envs/rlhflow/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2430, in save_pretrained
    out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
  File "/home/jiawhuang/miniconda3/envs/rlhflow/lib/python3.10/json/__init__.py", line 238, in dumps
    **kw).encode(obj)
  File "/home/jiawhuang/miniconda3/envs/rlhflow/lib/python3.10/json/encoder.py", line 201, in encode
    chunks = list(chunks)
  File "/home/jiawhuang/miniconda3/envs/rlhflow/lib/python3.10/json/encoder.py", line 431, in _iterencode
    yield from _iterencode_dict(o, _current_indent_level)
  File "/home/jiawhuang/miniconda3/envs/rlhflow/lib/python3.10/json/encoder.py", line 405, in _iterencode_dict
    yield from chunks
  File "/home/jiawhuang/miniconda3/envs/rlhflow/lib/python3.10/json/encoder.py", line 438, in _iterencode
    o = _default(o)
  File "/home/jiawhuang/miniconda3/envs/rlhflow/lib/python3.10/json/encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type dtype is not JSON serializable

Explanation

My conjecture is that when I load the tokenizer with torch_dtype=torch.bfloat16, the dtype is stored on the tokenizer as a torch.bfloat16 object, and this value is not handled when the tokenizer is saved.
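
For reference, the failure reproduces in isolation: json.dumps has no encoder for torch.dtype objects. A default= hook that stringifies unknown objects (a hypothetical workaround, not something transformers does) avoids the TypeError:

    import json
    import torch

    config = {"torch_dtype": torch.bfloat16}  # what ends up in the tokenizer config dict

    # Fails exactly like save_pretrained does:
    #   json.dumps(config)  -> TypeError: Object of type dtype is not JSON serializable

    # Hypothetical fix: stringify anything json cannot encode natively
    print(json.dumps(config, default=str))  # {"torch_dtype": "torch.bfloat16"}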

@LysandreJik
Member

Hello! Is the torch_dtype argument documented somewhere? It doesn't seem to me like a feature that's supported.

@jiaweihhuang
Author

jiaweihhuang commented Sep 7, 2024

Hello, I took a deeper look. It seems there is no argument named torch_dtype, but tokenizer.init_kwargs has an entry with the key torch_dtype.
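
This can be checked directly after loading (a minimal sketch, assuming the reproduction above):

    from transformers import T5Tokenizer
    import torch

    tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", torch_dtype=torch.bfloat16)
    # from_pretrained does not consume torch_dtype, so it is kept verbatim in init_kwargs
    print(type(tokenizer.init_kwargs["torch_dtype"]))  # <class 'torch.dtype'>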

I found it is possible to avoid the TypeError by modifying the original _save function of the DPOTrainer class as follows. The only modification is under the if self.tokenizer is not None: branch: I set self.tokenizer.init_kwargs['torch_dtype'] to None before saving and restore the original value afterwards.

Maybe you have a smarter way to fix it.

    # Override of DPOTrainer._save; `Optional` must be imported from `typing` at module scope.
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        import os

        import safetensors.torch
        import torch
        from transformers import PreTrainedModel
        from transformers.modeling_utils import unwrap_model
        from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available, logging

        logger = logging.get_logger(__name__)
        if is_peft_available():
            from peft import PeftModel
        TRAINING_ARGS_NAME = "training_args.bin"

        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()

            if isinstance(unwrap_model(self.model), supported_classes):
                unwrap_model(self.model).save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )

        if self.tokenizer is not None:
            # Workaround: init_kwargs['torch_dtype'] holds a torch.dtype, which is not
            # JSON serializable; blank it out for the save and restore it afterwards.
            org_dtype = self.tokenizer.init_kwargs['torch_dtype']
            self.tokenizer.init_kwargs['torch_dtype'] = None
            self.tokenizer.save_pretrained(output_dir)
            self.tokenizer.init_kwargs['torch_dtype'] = org_dtype

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
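
A lighter-weight alternative (a sketch under the same assumption, i.e. that torch_dtype is the only offending entry) is to drop the key from init_kwargs once, right after loading the tokenizer, instead of overriding _save:

    # Hypothetical one-off fix: remove the non-serializable entry before any save
    tokenizer.init_kwargs.pop("torch_dtype", None)
    tokenizer.save_pretrained("./")  # now succeeds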


github-actions bot commented Oct 5, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ArthurZucker
Collaborator

This is related to the DPO trainer in trl? In that case the PR should go there! Closing, as this is a wrong usage of the tokenizer's from_pretrained 🤗

ArthurZucker added the Usage (General questions about the library) label on Oct 5, 2024