TypeError: GenerationMixin._extract_past_from_model_output() got an unexpected keyword argument 'standardize_cache_format' with transformers==4.44.0 #181

Open
7801943 opened this issue Aug 16, 2024 · 2 comments

7801943 commented Aug 16, 2024

System Info

transformers==4.44.0

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts and tasks

Reproduction

basic_demo/web_demo.py
Because the code in transformers 4.44.0 has changed, this raises TypeError: GenerationMixin._extract_past_from_model_output() got an unexpected keyword argument 'standardize_cache_format'.

In transformers 4.44.0, utils.py looks like this:

    def _extract_past_from_model_output(self, outputs: ModelOutput):
        past_key_values = None
        cache_name = "past_key_values"
        if "past_key_values" in outputs:
            past_key_values = outputs.past_key_values
        elif "mems" in outputs:
            past_key_values = outputs.mems
        elif "past_buckets_states" in outputs:
            past_key_values = outputs.past_buckets_states
        elif "cache_params" in outputs:
            past_key_values = outputs.cache_params
            cache_name = "cache_params"

        return cache_name, past_key_values

In transformers 4.42.4, utils.py looked like this:

    def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
        past_key_values = None
        cache_name = "past_key_values"
        if "past_key_values" in outputs:
            past_key_values = outputs.past_key_values
        elif "mems" in outputs:
            past_key_values = outputs.mems
        elif "past_buckets_states" in outputs:
            past_key_values = outputs.past_buckets_states
        elif "cache_params" in outputs:
            past_key_values = outputs.cache_params
            cache_name = "cache_params"

        # Bloom fix: standardizes the cache format when requested
        if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
            batch_size = outputs.logits.shape[0]
            past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size)
        return cache_name, past_key_values
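Comparing the two listings, both versions return a `(cache_name, past_key_values)` tuple; the breaking change is that 4.44.0 no longer accepts the `standardize_cache_format` keyword. A minimal sketch of a version-tolerant call, assuming that inspecting the method's signature with Python's `inspect` module is an acceptable way to choose the call form. `extract_past_compat`, `OldStyleModel`, `NewStyleModel` and `fake_outputs` are hypothetical names for illustration only, and a plain dict stands in for a real `ModelOutput`.

```python
import inspect
from typing import Any, Callable, Tuple


def extract_past_compat(
    extract_fn: Callable[..., Tuple[str, Any]],
    outputs: Any,
    standardize_cache_format: bool = False,
) -> Any:
    """Call an `_extract_past_from_model_output`-style method under either signature.

    Assumption: the method returns a (cache_name, past_key_values) tuple, as in
    both listings above; only past_key_values is returned here.
    """
    # inspect.signature on a bound method already omits `self`.
    params = inspect.signature(extract_fn).parameters
    if "standardize_cache_format" in params:
        # Older signature (e.g. 4.42.4): the keyword is still accepted.
        _, past_key_values = extract_fn(
            outputs, standardize_cache_format=standardize_cache_format
        )
    else:
        # Newer signature (e.g. 4.44.0): passing the keyword would raise TypeError.
        _, past_key_values = extract_fn(outputs)
    return past_key_values


# Hypothetical stand-ins mimicking the two signatures; a dict replaces ModelOutput.
class OldStyleModel:
    def _extract_past_from_model_output(self, outputs, standardize_cache_format=False):
        return "past_key_values", outputs["past_key_values"]


class NewStyleModel:
    def _extract_past_from_model_output(self, outputs):
        return "past_key_values", outputs["past_key_values"]


fake_outputs = {"past_key_values": ("k", "v")}
for model in (OldStyleModel(), NewStyleModel()):
    print(type(model).__name__, extract_past_compat(model._extract_past_from_model_output, fake_outputs))
# OldStyleModel ('k', 'v')
# NewStyleModel ('k', 'v')
```

In CogVLM's generation code this would correspond to something like `model_kwargs["past_key_values"] = extract_past_compat(self._extract_past_from_model_output, outputs, standardize_cache_format)`; that usage is untested against the real model files. When the keyword is unsupported the flag is simply dropped, matching the 4.44.0 listing, which has no cache-standardization step.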

Expected behavior

Changing line 710 of modeling_cogvlm.py from this:

        # update past_key_values
        # transformers==4.42.4
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )

to this:

        # transformers==4.44.0 
        _, model_kwargs["past_key_values"] = self._extract_past_from_model_output(outputs)

makes it work correctly.

zRzRzRzRzRzRzR self-assigned this Aug 17, 2024
zRzRzRzRzRzRzR (Member) commented:

Downgrade to 4.40.2; we will update the model files later.

YakiChen commented:

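With 4.40.2 you get an error because it does not satisfy the transformers>=4.41.2,<=4.45.0 requirement; installing 4.41.2 works instead. It took two days, but this finally solved it.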
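The requirement quoted above is an inclusive range, so 4.40.2 falls below the lower bound while 4.41.2 sits exactly on it. A minimal stdlib-only sketch of that range check, assuming plain numeric release strings; `parse_version` and `in_required_range` are hypothetical helpers, not part of transformers or this project.

```python
def parse_version(v: str) -> tuple:
    """Turn a plain release string such as '4.41.2' into (4, 41, 2).

    Simplification: pre-release suffixes like '4.45.0.dev0' are not handled.
    """
    return tuple(int(part) for part in v.split("."))


def in_required_range(v: str, low: str = "4.41.2", high: str = "4.45.0") -> bool:
    """True if low <= v <= high, i.e. v satisfies transformers>=4.41.2,<=4.45.0."""
    # Tuples compare element by element, so (4, 40, 2) < (4, 41, 2).
    return parse_version(low) <= parse_version(v) <= parse_version(high)


for candidate in ("4.40.2", "4.41.2", "4.42.4", "4.44.0"):
    print(candidate, in_required_range(candidate))
# 4.40.2 False
# 4.41.2 True
# 4.42.4 True
# 4.44.0 True
```

Note that 4.44.0 passes this range check yet still hits the TypeError reported in this issue, so the declared range alone does not guarantee compatibility. The installed version can be read with `importlib.metadata.version("transformers")`, and `packaging.specifiers.SpecifierSet` handles full PEP 440 version strings if suffixes matter.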
