Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jaewon-lee-b committed Jun 1, 2022
1 parent febe966 commit 67771ec
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 5 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ LTEW-RCAN|[Google Drive](https://drive.google.com/file/d/1XPxwop6Q5EZGi9pM392VC5

### **Asymmetric-scale SR**

**Train**: `python train_lte.py --config configs/train-div2k/train_rcan-lte.yaml --gpu 0`
**Train**: `CUDA_VISIBLE_DEVICES=0 python train_lte.py --config configs/train-div2k/train_rcan-lte.yaml --gpu 0`

**Test**: `bash ./scripts/test-benchmark-asym.sh save/_train_rcan-lte/epoch-last.pth 0`

### **Homography transformation**

**Train**: `python train_ltew.py --config configs/train-div2k/train_rrdb-lte-warp.yaml --gpu 0,1`
**Train**: `CUDA_VISIBLE_DEVICES=0,1 python train_ltew.py --config configs/train-div2k/train_rrdb-lte-warp.yaml --gpu 0,1`

**Test**: `bash ./scripts/test-benchmark-warp.sh ./save/_train_rrdb-lte-warp/epoch-last.pth 0`

Expand Down
3 changes: 0 additions & 3 deletions train_ltew.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,13 @@ def train(train_loader, model, optimizer, \
)
lr_crop = utils.quantize(lr_crop)
m = transform.compensate_offset(m, ix, iy)
##################################################################################

############### backward map (coord, cell) from Y space to X space ###############
gridx, mask = utils.gridy2gridx_homography(batch['coord'][0], \
hr.shape[-2], hr.shape[-1], lr_crop.shape[-2], lr_crop.shape[-1], m.cuda(), cpu=False) # backward map coord from Y to X

cell = utils.celly2cellx_homography(batch['cell'][0], \
hr.shape[-2], hr.shape[-1], lr_crop.shape[-2], lr_crop.shape[-1], m.cuda(), cpu=False) # backward map cell from Y to X
##################################################################################

######################## sample query point to train LTEW ########################
num_samples = lr_crop.shape[-2] * lr_crop.shape[-1]
Expand All @@ -137,7 +135,6 @@ def train(train_loader, model, optimizer, \
gridx = gridx[sample_lst].unsqueeze(0).expand(hr.shape[0], -1, -1) # coord sample
cell = cell[sample_lst].unsqueeze(0).expand(hr.shape[0], -1, -1) # cell sample
gt = (batch['gt'][:, sample_lst].cuda() - gt_sub) / gt_div # gt sample
##################################################################################

pred = model(lr_crop, gridx, cell)
loss = loss_fn(pred, gt)
Expand Down

0 comments on commit 67771ec

Please sign in to comment.