Not able to load DCGAN pre-trained weights #16

Open
harshitaseth opened this issue May 1, 2019 · 9 comments

Comments

@harshitaseth

Hi @alex-sage,
I am trying to generate a logo dataset using the pre-trained DCGAN and WGAN weights. When I run main.py (dcgan), it gives a tensor shape mismatch error:
Assign requires shapes of both tensors to match. lhs shape= [5,5,3,64] rhs shape= [5,5,3,456]
[[Node: save/Assign_47 = Assign[T=DT_FLOAT, _class=["loc:@generator/g_h4/w"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](generator/g_h4/w, save/RestoreV2:47)]]
I have also tried changing a few parameters so that the model's dimensions match the checkpoint's, but it is still not working.
Can you please provide the exact configuration with which you trained your model?
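A mismatch like this can be narrowed down by comparing every variable shape in the freshly built graph against the shapes stored in the checkpoint, instead of reading one failing `Assign` at a time. A minimal sketch, assuming both sides have already been collected into plain `name -> shape` dicts; the function name is hypothetical and not part of logo-gen:

```python
from typing import Dict, List, Sequence, Tuple

Shape = Sequence[int]


def find_shape_mismatches(
    graph_shapes: Dict[str, Shape],
    ckpt_shapes: Dict[str, Shape],
) -> Tuple[Dict[str, Tuple[List[int], List[int]]], List[str], List[str]]:
    """Compare variable shapes of the current graph with those in a checkpoint.

    Returns (mismatched, missing_in_ckpt, unused_in_ckpt):
      mismatched      -- name -> (graph shape, checkpoint shape) where they differ
      missing_in_ckpt -- graph variables the checkpoint does not contain
      unused_in_ckpt  -- checkpoint variables the graph does not define
    """
    mismatched = {
        name: (list(shape), list(ckpt_shapes[name]))
        for name, shape in graph_shapes.items()
        if name in ckpt_shapes and list(shape) != list(ckpt_shapes[name])
    }
    missing = sorted(set(graph_shapes) - set(ckpt_shapes))
    unused = sorted(set(ckpt_shapes) - set(graph_shapes))
    return mismatched, missing, unused


if __name__ == "__main__":
    # Shapes from the error above: lhs is the variable in the current graph,
    # rhs is the value stored in the checkpoint.
    graph = {"generator/g_h4/w": [5, 5, 3, 64]}
    ckpt = {"generator/g_h4/w": [5, 5, 3, 456]}
    print(find_shape_mismatches(graph, ckpt))
    # -> ({'generator/g_h4/w': ([5, 5, 3, 64], [5, 5, 3, 456])}, [], [])
```

With TF 1.x, the checkpoint side could presumably be filled from `tf.train.NewCheckpointReader(path).get_variable_to_shape_map()` and the graph side from `{v.op.name: v.get_shape().as_list() for v in tf.global_variables()}` after the model is constructed; this has not been tried against the logo-gen checkpoints.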

harshitaseth changed the title from "Not able to load DCGAN weights" to "Not able to load DCGAN pre-trained weights" on May 1, 2019
@alex-sage
Owner

alex-sage commented May 2, 2019

Hi @harshitaseth,
This error is very probably caused by an incompatible (too recent) TensorFlow version.
This is the configuration I used for WGAN, which should work for the DCGAN network too. The most important part is TF version 1.3.0.

backports.functools-lru-cache==1.5
backports.shutil-get-terminal-size==1.0.0
backports.weakref==1.0rc1
bleach==1.5.0
cycler==0.10.0
decorator==4.3.2
enum34==1.1.6
funcsigs==1.0.2
h5py==2.9.0
html5lib==0.9999999
ipython==5.8.0
ipython-genutils==0.2.0
kiwisolver==1.0.1
Markdown==3.0.1
matplotlib==2.2.3
mock==2.0.0
numpy==1.16.1
pathlib2==2.3.3
pbr==5.1.2
pexpect==4.6.0
pickleshare==0.7.5
Pillow==5.4.1
pkg-resources==0.0.0
prompt-toolkit==1.0.15
protobuf==3.6.1
ptyprocess==0.6.0
Pygments==2.3.1
pyparsing==2.3.1
python-dateutil==2.8.0
pytz==2018.9
scandir==1.9.0
scipy==1.2.1
simplegeneric==0.8.1
six==1.12.0
subprocess32==3.5.3
tensorflow==1.3.0
tensorflow-gpu==1.3.0
tensorflow-tensorboard==0.1.8
tqdm==4.31.1
traitlets==4.3.2
wcwidth==0.1.7
Werkzeug==0.14.1
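To see which of these pins an environment actually violates, the list above can be saved as a requirements file and diffed against the output of `pip freeze` from the environment under test. A minimal sketch; the file names and the script itself are hypothetical, and it only handles exact `name==version` pins:

```python
import re
import sys
from typing import Dict, Optional, Tuple


def normalize(name: str) -> str:
    """Normalize a project name (PEP 503 style) so 'Pillow' and 'pillow' compare equal."""
    return re.sub(r"[-_.]+", "-", name).lower()


def parse_pins(text: str) -> Dict[str, str]:
    """Parse 'name==version' lines, as printed by `pip freeze`, into a dict."""
    pins = {}
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()
        if "==" not in line:
            continue  # skip blank lines, comments and unpinned entries
        name, version = line.split("==", 1)
        pins[normalize(name.strip())] = version.strip()
    return pins


def diff_pins(
    wanted: Dict[str, str], installed: Dict[str, str]
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return name -> (wanted version, installed version or None) for each mismatch."""
    return {
        name: (version, installed.get(name))
        for name, version in wanted.items()
        if installed.get(name) != version
    }


if __name__ == "__main__" and len(sys.argv) == 3:
    # Usage: python check_pins.py requirements.txt current.txt
    # where current.txt was produced by `pip freeze > current.txt`.
    with open(sys.argv[1]) as f:
        wanted = parse_pins(f.read())
    with open(sys.argv[2]) as f:
        installed = parse_pins(f.read())
    for name, (want, have) in sorted(diff_pins(wanted, installed).items()):
        print("{}: want {}, have {}".format(name, want, have or "not installed"))
```

The most important line to check in the result is `tensorflow`, since a `1.13.x` there would explain the errors in this thread.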

Alternatively, you can follow the requirements of the original DCGAN repo, which will definitely work:

  • Python 2.7 or Python 3.3+
  • TensorFlow 0.12.1
  • SciPy
  • Pillow
  • (Optional) moviepy (for visualization)
  • (Optional) Align&Cropped Images.zip: Large-scale CelebFaces Dataset

@rookiexiao123

I installed the same packages you used, but it didn't work. My goal is to generate an unlimited number of logos from input text. I tried to run DCGAN, and when I run python main.py, it reports these errors:

```
[*] Reading checkpoints...
go oning!!!
WARNING:tensorflow:The saved meta_graph is possibly from an older release:
'model_variables' collection should be of type 'byte_list', but instead is of type 'node_list'.
W1216 11:08:57.212315 140132085733120 meta_graph.py:935] The saved meta_graph is possibly from an older release:
'model_variables' collection should be of type 'byte_list', but instead is of type 'node_list'.
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
W1216 11:08:57.214524 140132085733120 deprecation.py:323] From /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from checkpoint/LLD_64_64_64/DCGAN.model-132002
I1216 11:08:57.215079 140132085733120 saver.py:1280] Restoring parameters from checkpoint/LLD_64_64_64/DCGAN.model-132002
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1356, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [8192,1] rhs shape= [18084,1]
[[{{node save/Assign_20}}]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1286, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [8192,1] rhs shape= [18084,1]
[[node save/Assign_20 (defined at /home/xhz/work/infor/logo-gen-master/dcgan/model.py:158) ]]

Errors may have originated from an input operation.
Input Source operations connected to node save/Assign_20:
discriminator/d_h3_lin/Matrix (defined at /home/xhz/work/infor/logo-gen-master/dcgan/ops.py:68)

Original stack trace for 'save/Assign_20':
  File "main.py", line 130, in <module>
    tf.app.run()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/usr/local/lib/python3.5/dist-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.5/dist-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "main.py", line 109, in main
    y_dim=FLAGS.y_dim)
  File "/home/xhz/work/infor/logo-gen-master/dcgan/model.py", line 158, in __init__
    self.saver = tf.train.Saver()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 825, in __init__
    self.build()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 837, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 875, in _build
    build_restore=build_restore)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 508, in _build_internal
    restore_sequentially, reshape)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 350, in _AddRestoreOps
    assign_ops.append(saveable.restore(saveable_tensors, shapes))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saving/saveable_object_util.py", line 72, in restore
    self.op.get_shape().is_fully_defined())
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/state_ops.py", line 227, in assign
    validate_shape=validate_shape)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_state_ops.py", line 66, in assign
    use_locking=use_locking, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3616, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2005, in __init__
    self._traceback = tf_stack.extract_stack()

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "main.py", line 130, in <module>
    tf.app.run()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/usr/local/lib/python3.5/dist-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.5/dist-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "main.py", line 114, in main
    if not dcgan.load(FLAGS.checkpoint_dir):
  File "/home/xhz/work/infor/logo-gen-master/dcgan/model.py", line 437, in load
    self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1322, in restore
    err, "a mismatch between the current graph and the graph")
tensorflow.python.framework.errors_impl.InvalidArgumentError: Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Assign requires shapes of both tensors to match. lhs shape= [8192,1] rhs shape= [18084,1]
[[node save/Assign_20 (defined at /home/xhz/work/infor/logo-gen-master/dcgan/model.py:158) ]]

Errors may have originated from an input operation.
Input Source operations connected to node save/Assign_20:
discriminator/d_h3_lin/Matrix (defined at /home/xhz/work/infor/logo-gen-master/dcgan/ops.py:68)

Original stack trace for 'save/Assign_20':
  File "main.py", line 130, in <module>
    tf.app.run()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/usr/local/lib/python3.5/dist-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.5/dist-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "main.py", line 109, in main
    y_dim=FLAGS.y_dim)
  File "/home/xhz/work/infor/logo-gen-master/dcgan/model.py", line 158, in __init__
    self.saver = tf.train.Saver()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 825, in __init__
    self.build()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 837, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 875, in _build
    build_restore=build_restore)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 508, in _build_internal
    restore_sequentially, reshape)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 350, in _AddRestoreOps
    assign_ops.append(saveable.restore(saveable_tensors, shapes))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saving/saveable_object_util.py", line 72, in restore
    self.op.get_shape().is_fully_defined())
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/state_ops.py", line 227, in assign
    validate_shape=validate_shape)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_state_ops.py", line 66, in assign
    use_locking=use_locking, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3616, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2005, in __init__
    self._traceback = tf_stack.extract_stack()
```
I hope I can get your help. If I can't get it working, I'm considering training my own model.

@alex-sage
Owner

This very much looks like you're still using the wrong TensorFlow version. Are you sure you're actually using 1.3.0 and not, e.g., 1.13.0? Even the deprecation warning right at the beginning of your output suggests that your version is newer than the required one, since a quick Google search suggests this deprecation was introduced around version 1.13.
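Because 1.3.0 and 1.13.0 are easy to confuse (plain string comparison even orders "1.13.1" before "1.3.0"), a small guard that compares numeric version tuples can make the mistake fail loudly. A minimal sketch, assuming it is called with `tf.__version__`; the function names are hypothetical and not part of logo-gen:

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '1.13.1' into (1, 13, 1); any suffix such as 'rc1' is ignored."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError("unrecognized version string: {!r}".format(version))
    return tuple(int(part) for part in match.group(0).split("."))


def require_version(actual: str, required: str = "1.3.0") -> None:
    """Raise RuntimeError unless major.minor of `actual` equals that of `required`."""
    if parse_version(actual)[:2] != parse_version(required)[:2]:
        raise RuntimeError(
            "expected TensorFlow {}, but {} is being imported".format(required, actual)
        )
```

For example, near the top of main.py one could call `require_version(tf.__version__)` after `import tensorflow as tf`, and print `tf.__file__` to see which installation is actually being imported.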

@rookiexiao123

Thanks, I'm sure that TF is 1.3.0. I'm confused, too. Then I tried to retrain.

@alex-sage
Owner

Hmm... For the DCGAN part you can also try TF 0.12.1, since that's the version used by the original author of the TensorFlow DCGAN code I built upon. I thought it should work with v1.3, but maybe I'm mistaken (earlier I hadn't fully realized that you're working with DCGAN, not WGAN).

@rookiexiao123

OK, maybe I can try WGAN. I read in the paper that WGAN achieves a higher score.

@alex-sage
Owner

Yes, I would definitely recommend using WGAN. The generated logos look much better, it's more stable and easier to train, and my code base is also a bit more sophisticated for the WGAN part since I used it more in the end.

@rookiexiao123

rookiexiao123 commented Dec 16, 2019 via email

@alex-sage
Owner

You could probably try to condition the generator on those keywords (in the same way I conditioned it on the data clusters). This was also an idea I mentioned in the paper as future work.
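One common way to condition a GAN generator on a discrete label, as in conditional GANs, is to concatenate a one-hot encoding of the label to the noise vector z before it enters the generator (and typically to feed the same label to the discriminator too). A minimal sketch of that input construction, assuming each logo is assigned exactly one keyword index; whether this matches how logo-gen conditions on clusters internally is an assumption, and the names are hypothetical:

```python
from typing import List, Sequence


def one_hot(index: int, size: int) -> List[float]:
    """One-hot vector of length `size` with 1.0 at position `index`."""
    if not 0 <= index < size:
        raise ValueError("index {} out of range for size {}".format(index, size))
    vec = [0.0] * size
    vec[index] = 1.0
    return vec


def conditioned_input(z: Sequence[float], keyword_index: int, num_keywords: int) -> List[float]:
    """Concatenate noise z with a one-hot keyword label, i.e. build [z; y]."""
    return list(z) + one_hot(keyword_index, num_keywords)
```

With a 100-dimensional z and 5 keywords, the generator input would become 105-dimensional, so the generator's first layer has to be sized for the longer vector.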

Whether or not this works of course strongly depends on how well your keywords fit the data. You can have a look at the provided metadata (especially for the high-res images in lld-logo, which were collected from Twitter) to help you find some appropriate keywords.
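A first pass at finding candidate keywords could simply count how many metadata entries mention each word. A minimal sketch, assuming the relevant metadata text (e.g. a profile description per logo) has already been extracted into a list of strings; the actual metadata fields, the stop-word list and the function name are all assumptions:

```python
import re
from collections import Counter
from typing import Iterable, List, Tuple

# Hypothetical, deliberately small stop-word list; extend as needed.
STOPWORDS = {"the", "and", "of", "a", "an", "to", "in", "for", "is", "on", "with", "we", "our"}


def top_keywords(descriptions: Iterable[str], k: int = 10) -> List[Tuple[str, int]]:
    """Return the k words that appear in the most descriptions, with their counts."""
    counts = Counter()
    for text in descriptions:
        words = set(re.findall(r"[a-z]+", text.lower())) - STOPWORDS
        counts.update(words)  # each word counts at most once per description
    return counts.most_common(k)
```

Counting each word once per description (document frequency) keeps a single long, repetitive description from dominating the ranking.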
