Skip to content

Latest commit

 

History

History
104 lines (100 loc) · 4.26 KB

File metadata and controls

104 lines (100 loc) · 4.26 KB

准备工作

# 下载项目
git clone https://github.com/yunjey/stargan.git
cd stargan
git checkout 30867d6f85a3bb99c38ae075de651004747c42d4
# 下载预训练模型
bash download.sh pretrained-celeba-128x128
# 下载数据集
bash download.sh celeba

第一步:转换前代码预处理

  1. 规避使用TensorBoard,在config处设置不使用tensorboard,具体添加代码如下:
...
parser.add_argument('--lr_update_step', type=int, default=1000)
config = parser.parse_args()
# 第5行为添加不使用tensorboard的相关代码
config.use_tensorboard = False
print(config)
main(config)

第二步:转换

cd ../
x2paddle --convert_torch_project --project_dir=stargan --save_dir=paddle_project --pretrain_model=stargan/stargan_celeba_128/models/

【注意】此示例中的pretrain_model是训练后的PyTorch模型,转换后则为PaddlePaddle训练后的模型,用户可修改转换后代码将其作为预训练模型,也可直接用于预测。

第三步:转换后代码后处理

需要修改的文件位于paddle_project文件夹中,其中文件命名与原始stargan文件夹中文件命名一致。

  1. DataLoader的num_workers设置为0,在config处设置强制设置num_workers,具体添加代码如下:
...
parser.add_argument('--lr_update_step', type=int, default=1000)
config = parser.parse_args()
config.use_tensorboard = False
# 第6行添加设置num_workers为0
config.num_workers = 0
print(config)
main(config)
  1. 修改自定义Dataset中的__getitem__的返回值,将Tensor修改为numpy,修改代码如下:
...
class CelebA(data.Dataset):
    ...
    def __getitem__(self, index):
        """Return one image and its corresponding attribute label."""
        dataset = (self.train_dataset if self.mode == 'train' else self.
                test_dataset)
        filename, label = dataset[index]
        image = Image.open(os.path.join(self.image_dir, filename))
        # return self.transform(image), torch2paddle.create_float32_tensor(label)
        # 将原来的return替换为如下12-17行
        out1 = self.transform(image)
        if isinstance(out1, paddle.Tensor):
            out1 = out1.numpy()
        out2 = torch2paddle.create_float32_tensor(label)
        if isinstance(out2, paddle.Tensor):
            out2 = out2.numpy()
        return out1, out2
    ...
  1. Tensor对比操作中对Tensor进行判断,判断是否为bool型,如果为bool类型需要强制转换,修改代码如下:
...
class Solver(object):
    ...
    def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None):
        ...
        for i in range(c_dim):
            if dataset == 'CelebA':
                c_trg = c_org.clone()
                if i in hair_color_indices:  
                    c_trg[:, i] = 1
                    for j in hair_color_indices:
                        if j != i:
                            c_trg[:, j] = 0
            else:
                # 如果为非int型,需要强转为int32,
                # 在18-22行实现
                # c_trg[:, i] = (c_trg[:, i] == 0)
                c_trg = c_trg.cast("int32")
                c_trg_tmp = paddle.zeros_like(c_trg)
                paddle.assign(c_trg, c_trg_tmp)
                c_trg_tmp = c_trg_tmp.cast("bool")
                c_trg_tmp[:, i] = c_trg[:, i] == 0
                c_trg = c_trg_tmp 
            ...
        ...
    ...
...

运行训练代码

cd paddle_project
python main.py --mode train --dataset CelebA --image_size 128 --c_dim 5 --sample_dir stargan_celeba/samples --log_dir stargan_celeba/logs --model_save_dir stargan_celeba/models --result_dir stargan_celeba/results --selected_attrs Black_Hair Blond_Hair Brown_Hair Male Young --celeba_image_dir ./data/celeba/images --attr_path ./data/celeba/list_attr_celeba.txt

转换后的代码可在这里进行查看。