-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathFashion_Image_Generator_Algo.py
125 lines (107 loc) · 4.31 KB
/
Fashion_Image_Generator_Algo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#############################################################################
# Group Project: Fashion Image Generation #
# #
# This is the Algorithm program for the Project #
# #
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! #
# !! Please read the README.md file carefully before running !! #
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! #
# #
#############################################################################
import functools

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torchvision.utils import make_grid
# Define the Generator and Discriminator
class Generator(nn.Module):
    """Label-conditioned generator: (noise, class label) -> 1x28x28 image.

    The class label is embedded into a 100-dim vector and multiplied
    element-wise with the noise. The product is upsampled by transposed
    convolutions to 1x64x64. A linear layer then maps it to 28*28 values,
    and Tanh bounds them to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # One learned 100-dim vector per class (10 classes); same width as the noise.
        self.label_embed = nn.Embedding(10, 100)

        # First stage: 100x1x1 -> 512x4x4 (no BatchNorm here, unlike later stages).
        layers = [
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
        ]
        # Three identical stride-2 stages, each doubling H and W:
        # 512x4x4 -> 256x8x8 -> 128x16x16 -> 64x32x32
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64)):
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.3),
            ]
        # Output head: 64x32x32 -> 1x64x64 -> flat 4096 -> 784 (= 1x28x28) -> Tanh.
        layers += [
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Flatten(),
            nn.Linear(1 * 64 * 64, 1 * 28 * 28),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, X, label):
        """Generate one image per (noise, label) pair.

        Args:
            X: noise tensor. It is multiplied element-wise with the (N, 100)
                label embedding, so presumably shape (N, 100). TODO confirm with callers.
            label: integer class indices in [0, 10).

        Returns:
            Tensor of shape (N, 1, 28, 28) with values in [-1, 1].
        """
        conditioned = X * self.label_embed(label)
        flat_images = self.model(conditioned.view(-1, 100, 1, 1))
        return flat_images.view(-1, 1, 28, 28)
class Discriminator(nn.Module):
    """Two-headed discriminator for 1x28x28 images.

    A shared convolutional trunk reduces the input to 128x7x7 features. Two
    heads then read those features:
      * ``D_layer``: one Sigmoid score per image, returned with shape (N,).
      * ``class_layer``: log-probabilities over 11 classes, returned with
        shape (N, 11). The 11th class (index 10) is marked 'fake' by the
        original author.
    """

    def __init__(self):
        super().__init__()

        def downsample(in_ch, out_ch):
            # A stride-2 4x4 conv with padding 1 halves H and W.
            return [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, True),
                nn.Dropout2d(0.5),
            ]

        # 1x28x28 -> 64x14x14 -> 128x7x7
        self.model = nn.Sequential(*downsample(1, 64), *downsample(64, 128))
        # A 7x7 kernel with no padding collapses the 7x7 map to 1x1.
        self.D_layer = nn.Sequential(
            nn.Conv2d(128, 1, 7, 1, 0, bias=False),
            nn.Sigmoid(),
        )
        self.class_layer = nn.Sequential(
            nn.Conv2d(128, 11, 7, 1, 0, bias=False),  # 11th label: 'fake'
            nn.LogSoftmax(dim=1),
        )

    def forward(self, X):
        """Score a batch of images.

        Args:
            X: image batch. The layer comments imply shape (N, 1, 28, 28).
                TODO confirm; other sizes break the 7x7 heads.

        Returns:
            Tuple ``(dis, cla)``: the Sigmoid scores with shape (N,), and the
            class log-probabilities with shape (N, 11).
        """
        features = self.model(X)
        return self.D_layer(features).view(-1), self.class_layer(features).view(-1, 11)
# Human-readable class names, indexed by the integer label passed to the
# generator (0-9). The ordering appears to follow the Fashion-MNIST class
# order; confirm against the training data.
labelMap = [
    'T-Shirt',     # 0
    'Trouser',     # 1
    'Pullover',    # 2
    'Dress',       # 3
    'Coat',        # 4
    'Sandal',      # 5
    'Shirt',       # 6
    'Sneaker',     # 7
    'Bag',         # 8
    'Ankle boot',  # 9
]
@functools.lru_cache(maxsize=4)
def _load_generator(load_path, device):
    """Load the pickled generator from ``load_path`` onto ``device`` in eval mode.

    Cached per (load_path, device), so repeated generateFromLabel calls do not
    re-read the checkpoint from disk each time.
    """
    # The checkpoint is a whole pickled nn.Module, not a state_dict, so
    # weights_only=False is required. PyTorch >= 2.6 defaults to
    # weights_only=True, which rejects such files.
    # SECURITY: this unpickles arbitrary Python objects. Only load trusted files.
    model = torch.load(load_path, map_location=device, weights_only=False)
    model.eval()
    return model


# generate N images from a label
def generateFromLabel(inLabel, inNum, *, load_path='ACGAN6_G_30_1.pth', device='cpu'):
    """Generate ``inNum`` images of class ``inLabel`` tiled into a single row.

    Args:
        inLabel: class index in [0, 10). The generator's label embedding has 10 rows.
        inNum: number of images to generate; must be >= 1.
        load_path: path to the pickled generator checkpoint. Defaults to the
            file name the original code hard-coded.
        device: device (str or torch.device) to load the model on and run it.

    Returns:
        numpy array holding the ``make_grid`` image permuted to channels-last
        (H, W, C), ready for ``plt.imshow``.

    Raises:
        ValueError: if ``inNum`` < 1 or ``inLabel`` is outside [0, 10).
    """
    if inNum < 1:
        raise ValueError(f'inNum must be >= 1, got {inNum}')
    if not 0 <= inLabel < 10:
        raise ValueError(f'inLabel must be in [0, 10), got {inLabel}')

    dev = torch.device(device)
    generator = _load_generator(str(load_path), str(dev))

    noise = torch.randn(inNum, 100, device=dev)
    labels = torch.full((inNum,), int(inLabel), dtype=torch.long, device=dev)
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        images = generator(noise, labels).view(inNum, 1, 28, 28)

    grid = make_grid(images, nrow=inNum, normalize=True)
    return grid.permute(1, 2, 0).cpu().numpy()
# Manual smoke test: generate one row of images for a single class and show it.
if __name__ == '__main__':
    label_idx = 9   # index into labelMap
    count = 10      # number of images in the row
    strip = generateFromLabel(label_idx, count)
    _, axes = plt.subplots(figsize=(count, count))
    axes.imshow(strip)
    axes.set_title(f'Generated [{labelMap[label_idx]}] Fashion Images')
    axes.axis('off')
    plt.show()