Skip to content

Commit

Permalink
clone original torch_layers
Browse files Browse the repository at this point in the history
  • Loading branch information
flowerthrower committed Apr 26, 2024
1 parent 48522fe commit 911ad02
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/mqt/predictor/ml/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import functional
from torch_geometric.nn import (
AttentionalAggregation,
Sequential,
Expand Down Expand Up @@ -45,10 +45,10 @@ def __init__(self, **kwargs: object) -> None:
self.edge_embedding = nn.Embedding(self.num_edge_categories, self.edge_embedding_dim)

if self.node_embedding_dim and self.node_embedding_dim == 1: # one-hot encoding
self.node_embedding = lambda x: F.one_hot(x, num_classes=self.num_node_categories).float()
self.node_embedding = lambda x: functional.one_hot(x, num_classes=self.num_node_categories).float()
self.node_embedding_dim = self.num_node_categories
if self.edge_embedding_dim and self.edge_embedding_dim == 1 and not self.zx: # one-hot encoding
self.edge_embedding = lambda x: F.one_hot(x, num_classes=self.num_edge_categories).float()
self.edge_embedding = lambda x: functional.one_hot(x, num_classes=self.num_edge_categories).float()
self.edge_embedding_dim = self.num_edge_categories

# hidden dimension accounting for multi-head concatenation
Expand Down
6 changes: 5 additions & 1 deletion src/mqt/predictor/rl/torch_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import torch as th
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, NatureCNN, is_image_space
from torch.nn import functional

logger = logging.getLogger("mqt-predictor")
Expand Down Expand Up @@ -44,6 +44,7 @@ def __init__(
cnn_output_dim: int = 256,
normalized_image: bool = False,
) -> None:
# TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
super().__init__(observation_space, features_dim=1)

extractors: Dict[str, nn.Module] = {}
Expand All @@ -56,6 +57,9 @@ def __init__(
total_concat_size += cnn_output_dim
elif key.startswith("graph"):
graph_observation_space.append(subspace)
elif is_image_space(subspace, normalized_image=normalized_image):
extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim, normalized_image=normalized_image)
total_concat_size += cnn_output_dim
else:
# The observation key is a vector, flatten it if needed
extractors[key] = nn.Flatten()
Expand Down

0 comments on commit 911ad02

Please sign in to comment.