Skip to content

Commit

Permalink
Refactor some training code
Browse files Browse the repository at this point in the history
  • Loading branch information
jrrodri committed Sep 25, 2024
1 parent 893f657 commit 3ea92fd
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 32 deletions.
9 changes: 7 additions & 2 deletions abraia/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,22 @@ def move_file(self, old_path, new_path):
resp = resp.json()
return file_path(resp['file'], self.userid)

def download_file(self, path, dest=''):
def download_file(self, path, dest='', cache=False):
url = f"{API_URL}/files/{self.userid}/{path}"
if cache and dest == '':
dest = os.path.join(tempdir, path)
if not os.path.exists(dest):
os.makedirs(os.path.dirname(dest), exist_ok=True)
resp = requests.get(url, stream=True, auth=self.auth)
if resp.status_code != 200:
raise APIError(resp.text, resp.status_code)
if not dest:
return BytesIO(resp.content)
with open(dest, 'wb') as f:
f.write(resp.content)
return dest

# TODO: Merge in download_file: cache = True
# TODO: Replaced with download_file
def cache_file(self, path):
dest = os.path.join(tempdir, path)
if not os.path.exists(dest):
Expand Down
21 changes: 14 additions & 7 deletions abraia/training/classify.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import print_function, division
from ..client import Abraia, tempdir
from ..multiple import Multiple, tempdir

import onnx
import torch
Expand All @@ -15,21 +15,21 @@
from PIL import Image


abraia = Abraia()
multiple = Multiple()


torch.backends.cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def read_image(path):
dest = abraia.cache_file(path)
dest = multiple.cache_file(path)
return Image.open(dest).convert('RGB')


class Dataset(torch.utils.data.Dataset):
def __init__(self, root_dir, transform=None, target_transform=None):
paths, labels = abraia.load_dataset(root_dir)
paths, labels = multiple.load_dataset(root_dir)
self.paths = paths
self.labels = labels
self.root_dir = root_dir
Expand Down Expand Up @@ -65,7 +65,7 @@ def create_model(class_names, pretrained=True):


def load_model(path, class_names):
dest = abraia.cache_file(path)
dest = multiple.cache_file(path)
model = create_model(class_names, pretrained=False)
model.load_state_dict(torch.load(dest))
return model
Expand Down Expand Up @@ -228,5 +228,12 @@ def save_model(self, model, model_name, dataset, classes, device='cpu'):
torch.onnx.export(model, dummy_input, src, export_params=True, opset_version=10, do_constant_folding=True, input_names=['input'], output_names=['output'])
onnx_model = onnx.load(src)
onnx.checker.check_model(onnx_model)
abraia.upload_file(src, model_path)
abraia.save_json(f"{dataset}/{model_name}.json", {'inputShape': self.input_shape, 'classes': classes})
multiple.upload_file(src, model_path)
multiple.save_json(f"{dataset}/{model_name}.json", {'inputShape': self.input_shape, 'classes': classes})

def run_model(self, model, im):
input_tensor = transform(im)
input_batch = input_tensor.unsqueeze(0)
output = model(input_batch)
pred = torch.nn.functional.softmax(output)
return pred
43 changes: 21 additions & 22 deletions abraia/training/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,35 +24,34 @@ def build_model_name(model_name, task):


class Model:
def __init__(self):
pass

def train_model(self, dataset, task, batch=32, epochs=100, imgsz=640):
model_name = build_model_name('yolov8n', task)
model = YOLO(f"{model_name}.pt", verbose=False)
data = f"{dataset}" if task == 'classify' else f"{dataset}/data.yaml"
results = model.train(data=data, batch=batch, epochs=epochs, imgsz=imgsz)
metrics = model.val(data=data)
return model, model_name


def save_model(self, model, model_name, dataset, task, classes, imgsz=640):
model_src = model.export(format="onnx", device="cpu")
abraia.upload_file(model_src, f"{dataset}/{model_name}.onnx")
abraia.save_json(f"{dataset}/{model_name}.json", {'task': task, 'inputShape': [1, 3, imgsz, imgsz], 'classes': classes})


def run_model(self, model, src, task='segment'):
def __init__(self, task, model_type='yolov8n'):
model_name = build_model_name(model_type, task)
self.model = YOLO(f"{model_name}.pt", verbose=False)
self.model_name = model_name
self.task = task

def train(self, dataset, batch=32, epochs=100, imgsz=640):
data = f"{dataset}" if self.task == 'classify' else f"{dataset}/data.yaml"
results = self.model.train(data=data, batch=batch, epochs=epochs, imgsz=imgsz)
metrics = self.model.val(data=data)
return metrics

def save(self, dataset, classes, imgsz=640, device="cpu"):
model_src = self.model.export(format="onnx", device=device)
abraia.upload_file(model_src, f"{dataset}/{self.model_name}.onnx")
abraia.save_json(f"{dataset}/{self.model_name}.json", {'task': self.task, 'inputShape': [1, 3, imgsz, imgsz], 'classes': classes})

def run(self, src):
objects = []
results = model.predict(src, verbose=False)[0]
results = self.model.predict(src, verbose=False)[0]
if results:
for box, mask in zip(results.boxes, results.masks):
class_id = int(box.cls)
label = results.names[class_id]
confidence = float(box.conf)
x1, y1, x2, y2 = box.xyxy.squeeze().tolist()
object = {'label': label, 'confidence': confidence, 'color': get_color(class_id), 'box': [x1, y1, x2 - x1, y2 - y1]}
if task == 'segment':
object = {'label': label, 'confidence': confidence, 'box': [x1, y1, x2 - x1, y2 - y1]}
if self.task == 'segment':
object['polygon'] = [(x, y) for x, y in mask.xy[0]]
objects.append(object)
return objects
Binary file modified images/screenshot.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
author_email='jorge@abraiasoftware.com',
license='MIT',
zip_safe=False,
packages=find_packages(),
packages=find_packages(exclude=['tests']),
tests_require=['pytest'],
setup_requires=['setuptools>=38.6.0', 'pytest-runner'],
scripts=['scripts/abraia', 'scripts/abraia.bat'],
Expand Down

0 comments on commit 3ea92fd

Please sign in to comment.