-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #46 from MichiganCOG/dev
Dev
- Loading branch information
Showing
34 changed files
with
2,636 additions
and
159 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,4 @@ runs/* | |
models/HGC3D | ||
*.json | ||
pbs/* | ||
*.pt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import torch | ||
try: | ||
from .abstract_datasets import DetectionDataset | ||
except: | ||
from abstract_datasets import DetectionDataset | ||
import cv2 | ||
import os | ||
import numpy as np | ||
import json | ||
try: | ||
import datasets.preprocessing_transforms as pt | ||
except: | ||
import preprocessing_transforms as pt | ||
|
||
class DHF1K(DetectionDataset): | ||
def __init__(self, *args, **kwargs): | ||
super(DHF1K, self).__init__(*args, **kwargs) | ||
|
||
# Get model object in case preprocessing other than default is used | ||
self.model_object = kwargs['model_obj'] | ||
self.load_type = kwargs['load_type'] | ||
|
||
print(self.load_type) | ||
if self.load_type=='train': | ||
self.transforms = kwargs['model_obj'].train_transforms | ||
|
||
else: | ||
self.transforms = kwargs['model_obj'].test_transforms | ||
|
||
|
||
|
||
|
||
def __getitem__(self, idx): | ||
vid_info = self.samples[idx] | ||
|
||
|
||
base_path = vid_info['base_path'] | ||
vid_size = vid_info['frame_size'] | ||
|
||
input_data = [] | ||
map_data = [] | ||
bin_data = [] | ||
|
||
for frame_ind in range(len(vid_info['frames'])): | ||
frame = vid_info['frames'][frame_ind] | ||
frame_path = frame['img_path'] | ||
map_path = frame['map_path'] | ||
bin_path = frame['bin_path'] | ||
|
||
# Load frame, convert to RGB from BGR and normalize from 0 to 1 | ||
input_data.append(cv2.imread(os.path.join(base_path, frame_path))[...,::-1]/255.) | ||
|
||
# Load frame, Normalize from 0 to 1 | ||
# All frame channels have repeated values | ||
map_data.append(cv2.imread(map_path)/255.) | ||
bin_data.append(cv2.imread(bin_path)/255.) | ||
|
||
|
||
|
||
vid_data = self.transforms(input_data) | ||
|
||
# Annotations must be resized in the loss/metric | ||
map_data = torch.Tensor(map_data) | ||
bin_data = torch.Tensor(bin_data) | ||
|
||
# Permute the PIL dimensions (Frame, Height, Width, Chan) to pytorch (Chan, frame, height, width) | ||
vid_data = vid_data.permute(3, 0, 1, 2) | ||
map_data = map_data.permute(3, 0, 1, 2) | ||
bin_data = bin_data.permute(3, 0, 1, 2) | ||
# All channels are repeated so remove the unnecessary channels | ||
map_data = map_data[0].unsqueeze(0) | ||
bin_data = bin_data[0].unsqueeze(0) | ||
|
||
|
||
ret_dict = dict() | ||
ret_dict['data'] = vid_data | ||
|
||
annot_dict = dict() | ||
annot_dict['map'] = map_data | ||
annot_dict['bin'] = bin_data | ||
annot_dict['input_shape'] = vid_data.size() | ||
annot_dict['name'] = base_path | ||
ret_dict['annots'] = annot_dict | ||
|
||
return ret_dict | ||
|
||
|
||
if __name__=='__main__': | ||
|
||
class tts(): | ||
def __call__(self, x): | ||
return pt.ToTensorClip()(x) | ||
class debug_model(): | ||
def __init__(self): | ||
self.train_transforms = tts() | ||
|
||
|
||
json_path = '/path/to/DHF1K' #### Change this when testing #### | ||
|
||
|
||
dataset = DHF1K(model_obj=debug_model(), json_path=json_path, load_type='train', clip_length=16, clip_offset=0, clip_stride=1, num_clips=0, random_offset=0, resize_shape=0, crop_shape=0, crop_type='Center', final_shape=0, batch_size=1) | ||
train_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False) | ||
|
||
|
||
import matplotlib.pyplot as plt | ||
for x in enumerate(train_loader): | ||
dat = x[1]['data'][0,:,0].permute(1,2,0).numpy() | ||
bin = x[1]['annots']['bin'][0,:,0].permute(1,2,0).numpy().repeat(3,axis=2) | ||
map = x[1]['annots']['map'][0,:,0].permute(1,2,0).numpy().repeat(3, axis=2) | ||
img = np.concatenate([dat,bin,map], axis=0) | ||
plt.imshow(img) | ||
plt.show() | ||
import pdb; pdb.set_trace() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import torch | ||
from .abstract_datasets import RecognitionDataset | ||
from PIL import Image | ||
import cv2 | ||
import os | ||
import numpy as np | ||
from torchvision import transforms | ||
|
||
class KTH(RecognitionDataset): | ||
def __init__(self, *args, **kwargs): | ||
""" | ||
Initialize KTH class | ||
Args: | ||
load_type (String): Select training or testing set | ||
resize_shape (Int): [Int, Int] Array indicating desired height and width to resize input | ||
crop_shape (Int): [Int, Int] Array indicating desired height and width to crop input | ||
final_shape (Int): [Int, Int] Array indicating desired height and width of input to deep network | ||
preprocess (String): Keyword to select different preprocessing types | ||
Return: | ||
None | ||
""" | ||
super(KTH, self).__init__(*args, **kwargs) | ||
|
||
self.load_type = kwargs['load_type'] | ||
self.resize_shape = kwargs['resize_shape'] | ||
self.crop_shape = kwargs['crop_shape'] | ||
self.final_shape = kwargs['final_shape'] | ||
self.preprocess = kwargs['preprocess'] | ||
|
||
if self.load_type=='train': | ||
self.transforms = kwargs['model_obj'].train_transforms | ||
|
||
else: | ||
self.transforms = kwargs['model_obj'].test_transforms | ||
|
||
|
||
def __getitem__(self, idx): | ||
vid_info = self.samples[idx] | ||
base_path = vid_info['base_path'] | ||
|
||
input_data = [] | ||
|
||
vid_length = len(vid_info['frames']) | ||
vid_data = np.zeros((vid_length, self.final_shape[0], self.final_shape[1], 3))-1 | ||
labels = np.zeros((vid_length))-1 | ||
input_data = [] | ||
|
||
for frame_ind in range(len(vid_info['frames'])): | ||
frame_path = os.path.join(base_path, vid_info['frames'][frame_ind]['img_path']) | ||
|
||
for frame_labels in vid_info['frames'][frame_ind]['actions']: | ||
labels[frame_ind] = frame_labels['action_class'] | ||
|
||
# Load frame image data and preprocess image accordingly | ||
input_data.append(cv2.imread(frame_path)[...,::-1]/1.) | ||
|
||
|
||
# Preprocess data | ||
vid_data = self.transforms(input_data) | ||
labels = torch.from_numpy(labels).float() | ||
|
||
# Permute the PIL dimensions (Frame, Height, Width, Chan) to pytorch (Chan, frame, height, width) | ||
vid_data = vid_data.permute(3, 0, 1, 2) | ||
|
||
ret_dict = dict() | ||
ret_dict['data'] = vid_data | ||
|
||
annot_dict = dict() | ||
annot_dict['labels'] = labels | ||
|
||
ret_dict['annots'] = annot_dict | ||
|
||
return ret_dict | ||
|
||
|
||
#dataset = HMDB51(json_path='/z/dat/HMDB51', dataset_type='train', clip_length=100, num_clips=0) | ||
#dat = dataset.__getitem__(0) | ||
#import pdb; pdb.set_trace() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.