Skip to content

Commit

Permalink
update cog predict
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed Sep 4, 2022
1 parent af75697 commit 8d2447a
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions cog_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
class Predictor(BasePredictor):

def setup(self):
os.makedirs('output', exist_ok=True)
# download weights
if not os.path.exists('gfpgan/weights/realesr-general-x4v3.pth'):
os.system(
Expand Down Expand Up @@ -69,9 +70,13 @@ def predict(
) -> Path:
print(img, version, scale)
try:
extension = os.path.splitext(os.path.basename(str(img)))[1]
img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED)
if len(img.shape) == 3 and img.shape[2] == 4:
img_mode = 'RGBA'
elif len(img.shape) == 2:
img_mode = None
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
else:
img_mode = None

Expand Down Expand Up @@ -120,16 +125,15 @@ def predict(
output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
except Exception as error:
print('wrong scale input.', error)

if img_mode == 'RGBA': # RGBA images should be saved in png format
extension = 'png'
else:
extension = 'jpg'
save_path = f'output/out.{extension}'
cv2.imwrite(save_path, output)
out_path = Path(tempfile.mkdtemp()) / 'output.png'
# save_path = f'output/out.{extension}'
# cv2.imwrite(save_path, output)
out_path = Path(tempfile.mkdtemp()) / f'out.{extension}'
cv2.imwrite(str(out_path), output)
except Exception as error:
print('global exception', error)
print('global exception: ', error)
finally:
clean_folder('output')
return out_path
Expand Down

1 comment on commit 8d2447a

@hassanmalik007
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DSC_1055

Please sign in to comment.