Skip to content

Commit

Permalink
Refactor training classify
Browse files Browse the repository at this point in the history
  • Loading branch information
jrrodri committed Sep 25, 2024
1 parent 3ea92fd commit 1ad71dd
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 32 deletions.
55 changes: 26 additions & 29 deletions abraia/training/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import onnx
import torch
import torchvision
from torchvision import models, transforms
from torchvision import models, transforms, datasets

import os
import copy
Expand Down Expand Up @@ -64,13 +64,6 @@ def create_model(class_names, pretrained=True):
return model


def load_model(path, class_names):
dest = multiple.cache_file(path)
model = create_model(class_names, pretrained=False)
model.load_state_dict(torch.load(dest))
return model


transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
Expand Down Expand Up @@ -189,6 +182,7 @@ class Model:
def __init__(self):
self.imgsz = 224
self.input_shape = [1, 3, self.imgsz, self.imgsz]
self.model_name = 'resnet18'

def create_dataset(self, dataset):
# Data augmentation and normalization for training
Expand All @@ -207,33 +201,36 @@ def create_dataset(self, dataset):
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
image_datasets = {x: Dataset(os.path.join(dataset), data_transforms[x]) for x in ['train', 'val']}
# image_datasets = {x: Dataset(os.path.join(dataset, x), data_transforms[x]) for x in ['train', 'val']}
# image_datasets = {x: Dataset(os.path.join(dataset), data_transforms[x]) for x in ['train', 'val']}
image_datasets = {x: datasets.ImageFolder(os.path.join(dataset, x), transform=data_transforms[x]) for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=8, shuffle=True, num_workers=4) for x in ['train', 'val']}
classes = image_datasets['train'].classes
return dataloaders, classes

def train_model(self, dataloaders, classes):
def train(self, dataset, epochs=25):
dataloaders, classes = self.create_dataset(dataset)
model_conv = create_model(classes)
model = train_model(model_conv, dataloaders, num_epochs=25)
return model

def save_model(self, model, model_name, dataset, classes, device='cpu'):
imgsz = 224
model.to(device)
model_path = f"{dataset}/{model_name}.onnx"
src = os.path.join(tempdir, model_path)
dummy_input = torch.randn(1, 3, imgsz, imgsz)
os.makedirs(os.path.dirname(src), exist_ok=True)
torch.onnx.export(model, dummy_input, src, export_params=True, opset_version=10, do_constant_folding=True, input_names=['input'], output_names=['output'])
onnx_model = onnx.load(src)
self.classes = classes
self.model = train_model(model_conv, dataloaders, num_epochs=epochs)
return self.model

def save(self, dataset, classes, device='cpu'):
self.model.to(device)
model_src = os.path.join(tempdir, f"{dataset}/{self.model_name}.onnx")
os.makedirs(os.path.dirname(model_src), exist_ok=True)
dummy_input = torch.randn(1, 3, self.imgsz, self.imgsz)
torch.onnx.export(self.model, dummy_input, model_src, export_params=True, opset_version=10, do_constant_folding=True, input_names=['input'], output_names=['output'])
onnx_model = onnx.load(model_src)
onnx.checker.check_model(onnx_model)
multiple.upload_file(src, model_path)
multiple.save_json(f"{dataset}/{model_name}.json", {'inputShape': self.input_shape, 'classes': classes})
multiple.upload_file(model_src, f"{dataset}/{self.model_name}.onnx")
multiple.save_json(f"{dataset}/{self.model_name}.json", {'inputShape': self.input_shape, 'classes': classes})

def run_model(self, model, im):
def run(self, im):
input_tensor = transform(im)
input_batch = input_tensor.unsqueeze(0)
output = model(input_batch)
pred = torch.nn.functional.softmax(output)
return pred
output = self.model(input_batch)
pred = torch.softmax(output)
idx = int(pred.argmax())
confidence = float(pred[idx])
label = self.classes[idx]
return [{'label': label, 'confidence': confidence}]
6 changes: 3 additions & 3 deletions abraia/training/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, task, model_type='yolov8n'):
self.model_name = model_name
self.task = task

def train(self, dataset, batch=32, epochs=100, imgsz=640):
def train(self, dataset, epochs=100, batch=32, imgsz=640):
data = f"{dataset}" if self.task == 'classify' else f"{dataset}/data.yaml"
results = self.model.train(data=data, batch=batch, epochs=epochs, imgsz=imgsz)
metrics = self.model.val(data=data)
Expand All @@ -41,9 +41,9 @@ def save(self, dataset, classes, imgsz=640, device="cpu"):
abraia.upload_file(model_src, f"{dataset}/{self.model_name}.onnx")
abraia.save_json(f"{dataset}/{self.model_name}.json", {'task': self.task, 'inputShape': [1, 3, imgsz, imgsz], 'classes': classes})

def run(self, src):
def run(self, im):
objects = []
results = self.model.predict(src, verbose=False)[0]
results = self.model.predict(im, verbose=False)[0]
if results:
for box, mask in zip(results.boxes, results.masks):
class_id = int(box.cls)
Expand Down

0 comments on commit 1ad71dd

Please sign in to comment.