
Why do I still get this error even though I have changed all the paths: Traceback (most recent call last): File "/root/autodl-tmp/model/XrayGLM/VisualGLM-6B-main/finetune_visualglm.py", line 9, in <module> from sat.model.finetune.lora2 import LoraMixin ModuleNotFoundError: No module named 'sat.model.finetune.lora2' [2024-12-25 00:31:28,139] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 1338 #188

Open
zxdpro opened this issue Dec 24, 2024 · 0 comments

zxdpro commented Dec 24, 2024

```python
import os
import torch
import argparse

from sat import mpu, get_args, get_tokenizer
from sat.training.deepspeed_training import training_main
from model import VisualGLMModel
from sat.model.finetune import PTuningV2Mixin
from sat.model.finetune.lora2 import LoraMixin


class FineTuneVisualGLMModel(VisualGLMModel):
    def __init__(self, args, transformer=None, **kw_args):
        super().__init__(args, transformer=transformer, **kw_args)
        if args.use_ptuning:
            self.add_mixin("ptuning", PTuningV2Mixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.pre_seq_len))
        if args.use_lora:
            self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range), reinit=True)
            # self.get_mixin("eva").model.glm_proj = replace_linear_with_lora(self.get_mixin("eva").model.glm_proj, LoraLinear, args.lora_rank)
        elif args.use_qlora:
            self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range, qlora=True), reinit=True)
        self.args = args

    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('VisualGLM-finetune', 'VisualGLM finetune Configurations')
        group.add_argument('--pre_seq_len', type=int, default=8)
        group.add_argument('--lora_rank', type=int, default=10)
        group.add_argument('--use_ptuning', action="store_true")
        group.add_argument('--use_lora', action="store_true")
        group.add_argument('--use_qlora', action="store_true")
        group.add_argument('--layer_range', nargs='+', type=int, default=None)
        return super().add_model_specific_args(parser)

    def disable_untrainable_params(self):
        enable = []
        if self.args.use_ptuning:
            enable.extend(['ptuning'])
        if self.args.use_lora or self.args.use_qlora:
            enable.extend(['matrix_A', 'matrix_B'])
        for n, p in self.named_parameters():
            flag = False
            for e in enable:
                if e.lower() in n.lower():
                    flag = True
                    break
            if not flag:
                p.requires_grad_(False)
            else:
                print(n)

def get_batch(data_iterator, args, timers):
    # Items and their type.
    keys = ['input_ids', 'labels']
    datatype = torch.int64

    # Broadcast data.
    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    timers('data loader').stop()
    data_b = mpu.broadcast_data(keys, data, datatype)
    data_i = mpu.broadcast_data(['image'], data, torch.float32)
    # Unpack.
    tokens = data_b['input_ids'].long()
    labels = data_b['labels'].long()
    img = data_i['image']
    if args.fp16:
        img = img.half()

    return tokens, labels, img, data['pre_image']


from torch.nn import CrossEntropyLoss


def forward_step(data_iterator, model, args, timers):
    """Forward step."""

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, image, pre_image = get_batch(
        data_iterator, args, timers)
    timers('batch generator').stop()

    logits = model(input_ids=tokens, image=image, pre_image=pre_image)[0]
    dtype = logits.dtype
    lm_logits = logits.to(torch.float32)

    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    lm_logits = lm_logits.to(dtype)
    loss = loss.to(dtype)
    return loss, {'loss': loss}

from model.blip2 import BlipImageEvalProcessor
from torch.utils.data import Dataset
import json
from PIL import Image


class FewShotDataset(Dataset):
    def __init__(self, path, processor, tokenizer, args):
        max_seq_length = args.max_source_length + args.max_target_length
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.images = []
        self.input_ids = []
        self.labels = []
        for item in data:
            image = processor(Image.open(item['img']).convert('RGB'))
            # Source layout: "<img>" tag, image placeholder tokens, "</img>" plus the text prompt.
            input0 = tokenizer.encode("<img>", add_special_tokens=False)
            input1 = [tokenizer.pad_token_id] * args.image_length
            input2 = tokenizer.encode("</img>问:"+item['prompt']+"\n答:", add_special_tokens=False)
            a_ids = sum([input0, input1, input2], [])
            b_ids = tokenizer.encode(text=item['label'], add_special_tokens=False)
            if len(a_ids) > args.max_source_length - 1:
                a_ids = a_ids[: args.max_source_length - 1]
            if len(b_ids) > args.max_target_length - 2:
                b_ids = b_ids[: args.max_target_length - 2]
            pre_image = len(input0)
            input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)

            context_length = input_ids.index(tokenizer.bos_token_id)
            mask_position = context_length - 1
            labels = [-100] * context_length + input_ids[mask_position+1:]

            pad_len = max_seq_length - len(input_ids)
            input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
            labels = labels + [tokenizer.pad_token_id] * pad_len
            if args.ignore_pad_token_for_loss:
                labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
            self.images.append(image)
            self.input_ids.append(input_ids)
            self.labels.append(labels)
        self.pre_image = pre_image

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {
            "image": self.images[idx],
            "input_ids": self.input_ids[idx],
            "labels": self.labels[idx],
            "pre_image": self.pre_image
        }


def create_dataset_function(path, args):
    tokenizer = get_tokenizer(args)
    image_processor = BlipImageEvalProcessor(224)

    dataset = FewShotDataset(path, image_processor, tokenizer, args)
    return dataset

if __name__ == '__main__':
    py_parser = argparse.ArgumentParser(add_help=False)
    py_parser.add_argument('--max_source_length', type=int)
    py_parser.add_argument('--max_target_length', type=int)
    py_parser.add_argument('--ignore_pad_token_for_loss', type=bool, default=True)
    # py_parser.add_argument('--old_checkpoint', action="store_true")
    py_parser.add_argument('--source_prefix', type=str, default="")
    py_parser = FineTuneVisualGLMModel.add_model_specific_args(py_parser)
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    args.device = 'cpu'

    model_type = '/root/autodl-tmp/model/XrayGLM/VisualGLM-6B-main/visualglm-6b'
    model, args = FineTuneVisualGLMModel.from_pretrained(model_type, args)
    if torch.cuda.is_available():
        model = model.to('cuda')
        args.device = 'cuda'
    tokenizer = get_tokenizer(args)
    label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id

    def data_collator(examples):
        for example in examples:
            example['input_ids'] = torch.tensor(example['input_ids'], dtype=torch.long)
            example['labels'] = torch.tensor(example['labels'], dtype=torch.long)
        ret = {
            'input_ids': torch.stack([example['input_ids'] for example in examples]),
            'labels': torch.stack([example['labels'] for example in examples]),
            'image': torch.stack([example['image'] for example in examples]),
            'pre_image': example['pre_image']
        }
        return ret

    training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=create_dataset_function, collate_fn=data_collator)
```
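
The traceback in the title points at the library, not at any local path: sat.model.finetune.lora2 lives inside the installed SwissArmyTransformer (sat) package, so a ModuleNotFoundError here usually means either that the installed release predates the lora2 module or that a local directory named sat is shadowing the real package. A minimal diagnostic sketch to tell these apart (standard library plus the sat package only; nothing here is specific to XrayGLM):

```python
# Minimal sketch: check which "sat" is actually being imported and whether it
# ships the sat.model.finetune.lora2 module that finetune_visualglm.py needs.
import importlib.util

import sat

# Where the imported package lives; a path inside the project tree would mean
# a local folder is shadowing the installed SwissArmyTransformer.
print("sat imported from:", getattr(sat, "__file__", "unknown"))
print("sat version:", getattr(sat, "__version__", "unknown"))

# None here means the installed release has no lora2 module at all, so the
# import will fail no matter which local paths are changed.
print("lora2 available:", importlib.util.find_spec("sat.model.finetune.lora2") is not None)
```

If lora2 turns out to be missing, upgrading the package is the usual fix: for example `pip install -U SwissArmyTransformer`, or better, pinning the version listed in the repository's requirements.txt.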