Skip to content

Commit

Permalink
maxwell ft
Browse files Browse the repository at this point in the history
  • Loading branch information
teowu committed Jan 7, 2024
1 parent ca43543 commit af4904e
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 66 deletions.
1 change: 1 addition & 0 deletions playground/data/ft/maxwell/test_split_1.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions playground/data/ft/maxwell/train_split_1.json

Large diffs are not rendered by default.

137 changes: 71 additions & 66 deletions q_align/train/train_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,74 +574,79 @@ def next_rand(self):

def __getitem__(self, i) -> Dict[str, torch.Tensor]:
while True:
sources = self.list_data_dict[i]
if isinstance(i, int):
sources = [sources]
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
if 'image' in sources[0]:
image_file = self.list_data_dict[i]['image']

image_folder = self.data_args.image_folder
processor = self.data_args.image_processor
from pathlib import Path
#if not Path(os.path.join(image_folder, image_file)).exists():
# i = self.next_rand()
# continue
if isinstance(image_file, list):
# Multiple Images as Input
try:
image = [Image.open(os.path.join(image_folder, imfile)).convert('RGB') for imfile in image_file]
except Exception as ex:
print(ex)
i = self.next_rand()
continue
if self.data_args.image_aspect_ratio == 'pad':
image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image]
image = processor.preprocess(image, return_tensors='pt')['pixel_values']
try:
sources = self.list_data_dict[i]
if isinstance(i, int):
sources = [sources]
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
if 'image' in sources[0]:
image_file = self.list_data_dict[i]['image']

image_folder = self.data_args.image_folder
processor = self.data_args.image_processor
from pathlib import Path
#if not Path(os.path.join(image_folder, image_file)).exists():
# i = self.next_rand()
# continue
if isinstance(image_file, list):
# Multiple Images as Input
try:
image = [Image.open(os.path.join(image_folder, imfile)).convert('RGB') for imfile in image_file]
except Exception as ex:
print(ex)
i = self.next_rand()
continue
if self.data_args.image_aspect_ratio == 'pad':
image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image]
image = processor.preprocess(image, return_tensors='pt')['pixel_values']
else:
image = processor.preprocess(image, return_tensors='pt')['pixel_values']
elif os.path.join(image_folder, image_file).endswith("mp4"):
# Video as Input
image = load_video(os.path.join(image_folder, image_file))
if self.data_args.image_aspect_ratio == 'pad':
image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image]
image = processor.preprocess(image, return_tensors='pt')['pixel_values']
else:
image = processor.preprocess(image, return_tensors='pt')['pixel_values']
else:
image = processor.preprocess(image, return_tensors='pt')['pixel_values']
elif os.path.join(image_folder, image_file).endswith("mp4"):
# Video as Input
image = load_video(os.path.join(image_folder, image_file))
if self.data_args.image_aspect_ratio == 'pad':
image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image]
image = processor.preprocess(image, return_tensors='pt')['pixel_values']
else:
image = processor.preprocess(image, return_tensors='pt')['pixel_values']
try:
image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
except Exception as ex:
print(ex)
i = self.next_rand()
continue
if self.data_args.image_aspect_ratio == 'pad':
image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors='pt')['pixel_values']
else:
image = processor.preprocess(image, return_tensors='pt')['pixel_values']
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args)
else:
try:
image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
except Exception as ex:
print(ex)
i = self.next_rand()
continue
if self.data_args.image_aspect_ratio == 'pad':
image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors='pt')['pixel_values']
else:
image = processor.preprocess(image, return_tensors='pt')['pixel_values']
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args)
else:

sources = copy.deepcopy([e["conversations"] for e in sources])
data_dict = preprocess(
sources,
self.tokenizer,
has_image=('image' in self.list_data_dict[i]))
if isinstance(i, int):
data_dict = dict(input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0])

# image exist in the data
if 'image' in self.list_data_dict[i]:
data_dict['image'] = image
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
crop_size = self.data_args.image_processor.crop_size
data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
return data_dict

sources = copy.deepcopy([e["conversations"] for e in sources])
data_dict = preprocess(
sources,
self.tokenizer,
has_image=('image' in self.list_data_dict[i]))
if isinstance(i, int):
data_dict = dict(input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0])

# image exist in the data
if 'image' in self.list_data_dict[i]:
data_dict['image'] = image
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
crop_size = self.data_args.image_processor.crop_size
data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
return data_dict
except Exception as ex:
print(ex)
i = self.next_rand()
continue


@dataclass
Expand Down
40 changes: 40 additions & 0 deletions scripts/maxwell-officialsplit-lora.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/bin/bash
# LoRA fine-tuning on the MaxWell split file(s) under playground/data/ft/maxwell.
# Launches one deepspeed run of q_align/train/train_mem.py per split index and
# writes each run to its own ./q-align-maxwell-lora-<i> output directory.

# Value passed to --model_name_or_path for every run.
LOAD='q-future/one-align'

# Only split 1 is trained; widen the seq range to iterate over more splits.
for i in $(seq 1 1)
do
    echo "Split $i"
    DATA_FILE="playground/data/ft/maxwell/train_split_${i}.json"
    # Every continued line keeps a space before the trailing backslash: the
    # shell deletes backslash-newline outright, so without that space the
    # separation between arguments would rely on the next line's indentation.
    deepspeed --master_port 25801 q_align/train/train_mem.py \
        --deepspeed ./scripts/zero3.json \
        --lora_enable True --visual_abstractor_lr 2e-5 \
        --model_name_or_path "$LOAD" \
        --version v1 \
        --data_path "$DATA_FILE" \
        --image_folder ../datasets/MaxWell \
        --image_aspect_ratio pad \
        --group_by_modality_length True \
        --bf16 True \
        --output_dir "./q-align-maxwell-lora-${i}" \
        --num_train_epochs 5 \
        --per_device_train_batch_size 4 \
        --per_device_eval_batch_size 4 \
        --gradient_accumulation_steps 8 \
        --evaluation_strategy "no" \
        --save_strategy "steps" \
        --save_steps 800 \
        --save_total_limit 3 \
        --learning_rate 2e-4 \
        --weight_decay 0. \
        --warmup_ratio 0.03 \
        --lr_scheduler_type "cosine" \
        --logging_steps 1 \
        --tf32 True \
        --model_max_length 2048 \
        --gradient_checkpointing True \
        --tune_visual_abstractor True \
        --freeze_vision_model False \
        --dataloader_num_workers 4 \
        --lazy_preprocess True \
        --report_to wandb
done

0 comments on commit af4904e

Please sign in to comment.