Skip to content

Commit

Permalink
🎲 Support for setting the random seed.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexjc committed May 25, 2020
1 parent fa3398b commit d976cfc
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/texturize/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
|_| |_|\___|\__,_|_| \__,_|_| \__\___/_/\_\\__|\__,_|_| |_/___\___|
Usage:
texturize SOURCE... [--size=WxH] [--output=FILE] [--seed=SEED] [--device=DEVICE]
[--octaves=O] [--precision=P] [--iterations=I]
texturize --help
Expand All @@ -18,6 +19,7 @@
Options:
SOURCE Path to source image to use as texture.
-s WxH, --size=WxH Output resolution as WIDTHxHEIGHT. [default: 640x480]
--seed=SEED Configure the random number generation.
--device=DEVICE Hardware to use, either "cpu" or "cuda".
--octaves=O Number of octaves to process. [default: 5]
--precision=P Set the quality for the optimization. [default: 1e-4]
Expand Down Expand Up @@ -96,7 +98,7 @@ def _prepare_gram(self, features):


def get_all_layers(critics):
"""Determine the minimal list of layer features that needs to be extracted from the image.
"""Determine the minimal list of features that needs to be extracted from the image.
"""
layers = set(itertools.chain.from_iterable(c.get_layers() for c in critics))
return sorted(list(layers))
Expand Down Expand Up @@ -323,6 +325,13 @@ def main():
# Scan all the files based on the patterns specified.
files = itertools.chain.from_iterable(glob.glob(s) for s in config["SOURCE"])
for filename in files:
# If there's a random seed, use it for all images.
if config["--seed"] is not None:
seed = int(config["--seed"])
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# By default, disable autograd until the core optimization loop.
with torch.no_grad():
try:
run(config, filename)
Expand Down

0 comments on commit d976cfc

Please sign in to comment.