Question about the number of hidden_state dimensions #1366

Open
EmeryBAI opened this issue Feb 26, 2025 · 0 comments


@EmeryBAI

I am using QwenLM 7B and reading the hidden_state from the generation output. Its shape is (39, 29, 1, 5, d_model). I can see that 39 matches the length of outputs.sequences[0], 29 is the number of transformer blocks, and d_model is the embedding dimension of each token, but I don't know what the 5 means. My code is below; a short inspection sketch follows after it.

import os
import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn as nn
from multiprocessing import Process, Queue
from latentnetwork.utils import log_visualize, now, match_star

class ModelAPI(nn.Module):
    def __init__(self, 
                layer_num: int = None, 
                model_path: str = None, 
                recruitment: dict = None, 
                agent_index: int = None):
        super().__init__()
        self.model_path = model_path
        self.agent_index = agent_index
        self.num_gpus = torch.cuda.device_count()
        self.device = torch.device(f"cuda:{agent_index % self.num_gpus}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map={"": self.device},
            torch_dtype=torch.float16 if 'cuda' in str(self.device) else torch.float32
        ).eval()

        self._init_role_config(layer_num, recruitment)

    def _init_role_config(self, layer_num, recruitment):
        if layer_num == -1:
            # aggregate
            self.role_prompts = recruitment["role_prompts"]["Aggregator"]
        else:
            self.layer_name = recruitment["layer_names"][layer_num]
            self.layer_prompt = recruitment["layer_prompts"][self.layer_name]
            self.layer_roles = recruitment["layer_roles"][self.layer_name]
            agent_role = self.layer_roles[
                self.agent_index if self.agent_index < len(self.layer_roles) else -1
            ]
            self.role_prompts = recruitment["role_prompts"][agent_role]

    def get_input_embeddings(self, text):
        """
        Get the token embeddings for the input text.
        :param text: input text
        :return: input token embeddings (torch.Tensor)
        """
        inputs = self.tokenizer(text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.device)
        embedding_layer = self.model.get_input_embeddings()
        input_embeddings = embedding_layer(input_ids)
        return input_embeddings
    
    def embeddings_to_text(self, concatenated_embeddings):
        """
        Convert embedding vectors back to text.
        :param concatenated_embeddings: input embeddings (torch.Tensor) of shape [batch_size, sequence_length, embedding_dim]
        :return: list of decoded text strings
        """
        # Get the vocabulary embedding matrix
        embedding_layer = self.model.get_input_embeddings()
        embedding_weights = embedding_layer.weight  # shape [vocab_size, embedding_dim]

        # Flatten the embeddings for processing
        batch_size, sequence_length, embedding_dim = concatenated_embeddings.shape
        concatenated_embeddings = concatenated_embeddings.view(-1, embedding_dim)  # [batch_size * sequence_length, embedding_dim]

        # Compute the cosine similarity between each embedding and every embedding in the vocabulary
        similarities = torch.nn.functional.cosine_similarity(
            concatenated_embeddings.unsqueeze(1),  # [batch_size * sequence_length, 1, embedding_dim]
            embedding_weights.unsqueeze(0),        # [1, vocab_size, embedding_dim]
            dim=-1                                 # similarity along the embedding_dim axis
        )  # result shape [batch_size * sequence_length, vocab_size]

        # Find the most similar token ID for each embedding
        nearest_token_ids = similarities.argmax(dim=-1)  # [batch_size * sequence_length]

        # Convert the token IDs back to text
        nearest_token_ids = nearest_token_ids.view(batch_size, sequence_length)  # restore shape [batch_size, sequence_length]
        decoded_texts = [
            self.tokenizer.decode(token_ids, skip_special_tokens=False)
            for token_ids in nearest_token_ids
        ]

        return decoded_texts


    @torch.no_grad()  # disable gradient computation
    def generate(self, input_data, max_length=2000):
        """
        Generate text or a latent representation.
        :param input_data: input data, either a text string or a tensor to concatenate
        :param max_length: maximum generation length
        :return: the generated text and/or hidden states
        """
        # Build the input text uniformly for both paths
        text = self.layer_prompt + self.role_prompts + (input_data if isinstance(input_data, str) else "")
        text_with_eos = text + self.tokenizer.eos_token
        
        # Tokenize the input text
        inputs = self.tokenizer(text_with_eos, return_tensors="pt").to(self.device)
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        # Get input embeddings
        input_embeddings = self.model.get_input_embeddings()(input_ids)
        
        if isinstance(input_data, torch.Tensor):
            # If the input is a tensor, run the latent path
            if input_data.device != input_embeddings.device:
                input_data = input_data.to(input_embeddings.device)

            # Concatenate input_embeddings and input_data along the sequence dimension
            # print(input_embeddings.shape, input_data.shape)
            concatenated_embeddings = torch.cat((input_embeddings, input_data), dim=1)
            # concatenated_text = self.embeddings_to_text(concatenated_embeddings)

            # Pass inputs_embeds and leave input_ids and attention_mask as None
            outputs = self.model.generate(
                inputs_embeds=concatenated_embeddings,
                attention_mask=None,
                max_length=max_length,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                num_beams=5,
                temperature=0.7,
                early_stopping=True,
                output_hidden_states=True,  # return hidden states
                output_logits=True,
                return_dict_in_generate=True
            )
        else:
            # If the input is text, use input_ids and attention_mask directly
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                num_beams=5,
                temperature=0.7,
                early_stopping=True,
                return_dict_in_generate=True
            )

        # Decode the generated text
        generated_text = self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
        
        if isinstance(input_data, torch.Tensor):
            # For the latent path, return the generated text and the hidden states
            # hidden_state = outputs.hidden_states[-2]  # Second-to-last hidden state
            hidden_state = outputs.hidden_states[-1]
            logits = outputs.logits
            # last_hidden_state = hidden_state[:, -input_data.shape[1]:, :]
            # last_hidden_state = outputs.last_hidden_state
            return generated_text, hidden_state
        else:
            # For plain text generation, return only the generated text
            return generated_text, None

    def text_aggregate(self, text_list, unit_num):
        """
        Aggregate texts using LLM, grouping by 'unit_num' and repeating until one text remains.
        
        :param text_list: List of initial texts
        :param unit_num: Number of texts per group during each aggregation step
        :return: The final aggregated text
        """
        def aggregate_group(group):
            """
            Aggregate a single group of texts using LLM with retry logic.
            
            :param group: List of texts to aggregate
            :return: Aggregated text
            """
            group_text = ''.join(f'The {i}th text is: {text}' for i, text in enumerate(group))
            log_visualize(f"Aggregating group: {group_text}")
            
            retry = 0
            while retry < 3:
                try:
                    aggregated_text = self.generate(group_text)[0]  # Use generate method to aggregate texts
                    aggregated_text = match_star(aggregated_text)
                    if aggregated_text:
                        return aggregated_text
                    else:
                        retry += 1
                except Exception as e:
                    log_visualize(f"Retry {retry + 1} due to error: {e}")
                    retry += 1
            
            raise RuntimeError("Aggregation failed after multiple retries")

        current_texts = text_list.copy()  # Copy to avoid modifying original list

        while len(current_texts) > 1:
            # Step 1: Group current texts into sublists with a maximum size of 'unit_num'
            groups = [current_texts[i:i + unit_num] for i in range(0, len(current_texts), unit_num)]

            # Step 2: Aggregate each group
            aggregated_texts = []
            for group in groups:
                if len(group) == 1:
                    # Skip aggregation for single text groups
                    aggregated_texts.extend(group)
                else:
                    # Aggregate using LLM
                    aggregated_text = aggregate_group(group)
                    aggregated_texts.append(aggregated_text)

            # Update current texts for next iteration
            current_texts = aggregated_texts

        return current_texts[0]
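
For reference, here is a minimal inspection sketch, assuming the standard Hugging Face transformers generate API used above (the checkpoint name is only a placeholder). With return_dict_in_generate=True, outputs.hidden_states is a tuple with one entry per generation step; each entry is a tuple over layers (embeddings plus every transformer block); and each element is a tensor of shape (batch_size * num_beams, step length, d_model). Printing these lengths and shapes shows which axis each number in a reported shape corresponds to.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen-7B"  # placeholder checkpoint; any causal LM is nested the same way
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=8,
        num_beams=5,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )

# Nesting: (generation steps) -> (num_layers + 1) -> tensor
print(len(outputs.hidden_states))           # number of generation steps
print(len(outputs.hidden_states[0]))        # num_layers + 1 (embeddings + each block)
print(outputs.hidden_states[0][0].shape)    # (batch_size * num_beams, prompt length, d_model)
print(outputs.hidden_states[-1][0].shape)   # (batch_size * num_beams, 1, d_model)

Comparing these printed values against the reported (39, 29, 1, 5, d_model) makes it possible to match each number to one of these axes.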