Skip to content

Commit

Permalink
Make SweetNet resizable
Browse files Browse the repository at this point in the history
  • Loading branch information
Old-Shatterhand committed Jun 26, 2024
1 parent d7ebd4e commit ab73b62
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions glycowork/ml/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,18 @@ class SweetNet(torch.nn.Module):
| :-
| Returns batch-wise predictions
"""
def __init__(self, lib_size, num_classes = 1):
def __init__(self, lib_size, num_classes: int = 1, hidden_dim: int = 128):
super(SweetNet, self).__init__()

# Convolution operations on the graph
self.conv1 = GraphConv(128, 128)
self.conv2 = GraphConv(128, 128)
self.conv3 = GraphConv(128, 128)
self.conv1 = GraphConv(hidden_dim, hidden_dim)
self.conv2 = GraphConv(hidden_dim, hidden_dim)
self.conv3 = GraphConv(hidden_dim, hidden_dim)

# Node embedding
self.item_embedding = torch.nn.Embedding(num_embeddings = lib_size+1,
embedding_dim = 128)
self.item_embedding = torch.nn.Embedding(num_embeddings=lib_size+1, embedding_dim=hidden_dim)
# Fully connected part
self.lin1 = torch.nn.Linear(128, 1024)
self.lin1 = torch.nn.Linear(hidden_dim, 1024)
self.lin2 = torch.nn.Linear(1024, 128)
self.lin3 = torch.nn.Linear(128, num_classes)
self.bn1 = torch.nn.BatchNorm1d(1024)
Expand Down Expand Up @@ -313,24 +312,27 @@ def init_weights(model, mode = 'sparse', sparsity = 0.1):
print("This initialization option is not supported.")


def prep_model(model_type, num_classes, libr = None,
trained = False):
def prep_model(model_type, num_classes, libr=None, trained=False, hidden_dim: int = 128):
"""wrapper to instantiate model, initialize it, and put it on the GPU\n
| Arguments:
| :-
| model_type (string): string indicating the type of model
| num_classes (int): number of unique classes for classification
| libr (dict): dictionary of form glycoletter:index\n
| trained (bool): whether to use pretrained model; default:False
| hidden_dim (int): hidden dimension for the model (currently only for SweetNet); default:128\n
| Returns:
| :-
| Returns PyTorch model object
"""
if libr is None:
libr = lib
if model_type == 'SweetNet':
model = SweetNet(len(libr), num_classes = num_classes)
model = SweetNet(len(libr), num_classes=num_classes, hidden_dim=hidden_dim)
model = model.apply(lambda module: init_weights(module, mode = 'sparse'))
if trained:
if hidden_dim != 128:
raise ValueError("Hidden dimension must be 128 for pretrained model")
if not os.path.exists("SweetNet.pt"):
download_model("https://drive.google.com/file/d/1V4mMywfFW8tSmjLGbmKH_D8XbLoJnqqs/view?usp=sharing", local_path = "SweetNet.pt")
model.load_state_dict(torch.load("SweetNet.pt", map_location = device))
Expand Down

0 comments on commit ab73b62

Please sign in to comment.