Skip to content

Commit

Permalink
Minor Multiple class refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
jrrodri committed Mar 17, 2024
1 parent 82f70b0 commit bf8cef6
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 27 deletions.
4 changes: 1 addition & 3 deletions abraia/multiple.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def load_metadata(self, path):

def load_envi(self, path):
dest = self.cache_file(path)
raw = f"{dest.split('.')[0]}.raw"
if not os.path.exists(raw):
self.download_file(f"{path.split('.')[0]}.raw", raw)
raw = self.cache_file(f"{path.split('.')[0]}.raw")
return np.array(spectral.io.envi.open(dest, raw)[:, :, :])

def load_mat(self, path):
Expand Down
32 changes: 11 additions & 21 deletions abraia/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,13 @@
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# TODO: Remove with next version
def download_file(path):
dest = os.path.join(tempdir, path)
if not os.path.exists(dest):
os.makedirs(os.path.dirname(dest), exist_ok=True)
multiple.download_file(path, dest)
return dest
return multiple.cache_file(path)


def read_image(path):
dest = download_file(path)
dest = multiple.cache_file(path)
return Image.open(dest).convert('RGB')


Expand Down Expand Up @@ -81,7 +78,7 @@ def save_model(path, model, device='cpu'):


def load_model(path, class_names):
dest = download_file(path)
dest = multiple.cache_file(path)
model = create_model(class_names, pretrained=False)
model.load_state_dict(torch.load(dest))
return model
Expand All @@ -98,12 +95,14 @@ def export_onnx(path, model, device='cpu'):
multiple.upload_file(src, path)


# TODO: Remove with next version
def save_json(path, values):
multiple.save_file(path, json.dumps(values))
multiple.save_json(path, values)


# TODO: Remove with next version
def load_json(path):
return json.loads(multiple.load_file(path))
return multiple.load_json(path)


transform = transforms.Compose([
Expand All @@ -112,8 +111,7 @@ def load_json(path):
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
std=[0.229, 0.224, 0.225])
])


Expand Down Expand Up @@ -145,27 +143,22 @@ def train_model(model, dataloaders, criterion=None, optimizer=None, scheduler=No

running_loss = 0.0
running_corrects = 0

# Iterate over data.
# Iterate over data
for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device)

# zero the parameter gradients
optimizer.zero_grad()

# forward
# track history if only in train
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)

# backward + optimize only if in training phase
if phase == 'train':
loss.backward()
optimizer.step()

# statistics
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
Expand All @@ -174,20 +167,17 @@ def train_model(model, dataloaders, criterion=None, optimizer=None, scheduler=No

epoch_loss = running_loss / len(dataloaders[phase].dataset)
epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

# deep copy the model
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict())

print()

time_elapsed = time.time() - since
print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
print(f'Best val Acc: {best_acc:4f}')

# load best model weights
model.load_state_dict(best_model_wts)
return model
Expand Down
Binary file modified images/screenshot.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions scripts/abraia
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def input_files(src):


@click.group('abraia')
@click.version_option('0.13.2')
@click.version_option('0.13.3')
def cli():
"""Abraia CLI tool"""
pass
Expand All @@ -64,7 +64,7 @@ def configure():
@cli.command()
def info():
"""Show user account information"""
click.echo('abraia, version 0.13.2\n')
click.echo('abraia, version 0.13.3\n')
click.echo('Go to [' + click.style('https://abraia.me/console/', fg='green') + '] to see your account information\n')


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

setup(
name='abraia',
version='0.13.2',
version='0.13.3',
description='Abraia Multiple SDK',
long_description=long_description,
long_description_content_type='text/markdown',
Expand Down

0 comments on commit bf8cef6

Please sign in to comment.