diff --git a/README.MD b/README.MD index a988105a..481d4665 100644 --- a/README.MD +++ b/README.MD @@ -2,7 +2,7 @@ ## 1. Introduction -**I've update the code to support both Python2 and Python3, PyTorch 0.4. If you want the old version code please checkout branch [v0.3]()** +**I've update the code to support both Python2 and Python3, PyTorch 0.4. If you want the old version code please checkout branch [v0.3](https://github.com/chenyuntc/simple-faster-rcnn-pytorch/tree/0.3)** This project is a **Simplified** Faster R-CNN implementation based on [chainercv](https://github.com/chainer/chainercv) and other [projects](#acknowledgement) . It aims to: @@ -54,7 +54,7 @@ requires PyTorch >=0.4 - install PyTorch >=0.4 with GPU (code are GPU-only), refer to [official website](http://pytorch.org) -- install cupy, you can install via `pip install cupy-cuda80` or(cupy-cuda90,cupy-cuda91). +- install cupy, you can install via `pip install cupy-cuda80` or(cupy-cuda90,cupy-cuda91, etc). - install other dependencies: `pip install -r requirements.txt ` @@ -113,9 +113,6 @@ See [demo.ipynb](https://github.com/chenyuntc/simple-faster-rcnn-pytorch/blob/ma 4. modify `voc_data_dir` cfg item in `utils/config.py`, or pass it to program using argument like `--voc-data-dir=/path/to/VOCdevkit/VOC2007/` . -#### COCO - -TBD ### 5.2 Prepare caffe-pretrained vgg16 diff --git a/model/region_proposal_network.py b/model/region_proposal_network.py index 33474a92..d0ea993b 100644 --- a/model/region_proposal_network.py +++ b/model/region_proposal_network.py @@ -112,9 +112,8 @@ def forward(self, x, img_size, scale=1.): rpn_locs = rpn_locs.permute(0, 2, 3, 1).contiguous().view(n, -1, 4) rpn_scores = self.score(h) rpn_scores = rpn_scores.permute(0, 2, 3, 1).contiguous() - rpn_softmax_scores = F.softmax(rpn_scores, dim=3) - rpn_fg_scores = \ - rpn_softmax_scores.view(n, hh, ww, n_anchor, 2)[:, :, :, :, 1].contiguous() + rpn_softmax_scores = F.softmax(rpn_scores.view(n, hh, ww, n_anchor, 2), dim=4) + rpn_fg_scores = rpn_softmax_scores[:, :, :, :, 1].contiguous() rpn_fg_scores = rpn_fg_scores.view(n, -1) rpn_scores = rpn_scores.view(n, -1, 2)