diff --git a/vision_transformer/main.py b/vision_transformer/main.py index 15fd20c640..d215156127 100644 --- a/vision_transformer/main.py +++ b/vision_transformer/main.py @@ -49,9 +49,9 @@ def __init__(self, args): # Linear projection self.LinearProjection = nn.Linear(self.input_size, self.latent_size) # Class token - self.class_token = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size)).to(self.device) + self.class_token = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size).to(self.device)) # Positional embedding - self.pos_embedding = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size)).to(self.device) + self.pos_embedding = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size).to(self.device)) def forward(self, input_data): input_data = input_data.to(self.device)