Cannot load pretrained sentence retrieval model #36

Open
Martin36 opened this issue Dec 15, 2021 · 1 comment

@Martin36

When trying to load the pretrained ESIM model for sentence retrieval I get the following error:

Exception has occurred: NotFoundError
Key encode_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/bias not found in checkpoint
	 [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

Caused by op 'save/RestoreV2', defined at:
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/martin/.vscode/extensions/ms-python.python-2021.12.1559732655/pythonFiles/lib/python/debugpy/__main__.py", line 45, in <module>
    cli.main()
  File "/home/martin/.vscode/extensions/ms-python.python-2021.12.1559732655/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 444, in main
    run()
  File "/home/martin/.vscode/extensions/ms-python.python-2021.12.1559732655/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 285, in run_file
    runpy.run_path(target_as_str, run_name=compat.force_str("__main__"))
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/martin/fever-2018-team-athene/src/athene/retrieval/sentences/sentence_retrieval.py", line 246, in <module>
    main(model="esim")
  File "/home/martin/fever-2018-team-athene/src/athene/retrieval/sentences/sentence_retrieval.py", line 193, in main
    clf.restore_model(os.path.join(model_store_dir, "best_model.ckpt"))
  File "/home/martin/fever-2018-team-athene/src/athene/retrieval/sentences/deep_models/ESIM.py", line 438, in restore_model
    self._construct_graph()
  File "/home/martin/fever-2018-team-athene/src/athene/retrieval/sentences/deep_models/ESIM.py", line 211, in _construct_graph
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1338, in __init__
    self.build()
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1347, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1384, in _build
    build_save=build_save, build_restore=build_restore)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 835, in _build_internal
    restore_sequentially, reshape)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 472, in _AddRestoreOps
    restore_sequentially)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 886, in bulk_restore
    return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1463, in restore_v2
    shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
  File "/home/martin/anaconda3/envs/team-athene/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1718, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

NotFoundError (see above for traceback): Key encode_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/bias not found in checkpoint
	 [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

During handling of the above exception, another exception occurred:

  File "/home/martin/fever-2018-team-athene/src/athene/retrieval/sentences/deep_models/ESIM.py", line 447, in restore_model
    self._saver.restore(self._session, path)
  File "/home/martin/fever-2018-team-athene/src/athene/retrieval/sentences/sentence_retrieval.py", line 193, in main
    clf.restore_model(os.path.join(model_store_dir, "best_model.ckpt"))
  File "/home/martin/fever-2018-team-athene/src/athene/retrieval/sentences/sentence_retrieval.py", line 246, in <module>
    main(model="esim")

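I suspect this is caused by a mismatch between the variable names in tf.GraphKeys.TRAINABLE_VARIABLES and the names stored in the .ckpt file: the graph uses the scopes encode_rnn/birnn and infer_rnn/birnn, while the checkpoint uses h_encode_rnn, s_endode_rnn, h_infer_rnn and s_infer_rnn.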

These are the trainable variables in the graph:

Trainable variables:  [<tf.Variable 'encode_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/kernel:0' shape=(428, 512) dtype=float32_ref>, <tf.Variable 'encode_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'infer_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/kernel:0' shape=(1152, 512) dtype=float32_ref>, <tf.Variable 'infer_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/bias:0' shape=(512,) dtype=float32_ref>, <tf.Variable 'dense/kernel:0' shape=(1024, 256) dtype=float32_ref>, <tf.Variable 'dense/bias:0' shape=(256,) dtype=float32_ref>, <tf.Variable 'dense_1/kernel:0' shape=(256, 1) dtype=float32_ref>, <tf.Variable 'dense_1/bias:0' shape=(1,) dtype=float32_ref>]

And these are the variables in the .ckpt file:

Variables found in checkpoint file:  [('dense/bias', [256]), ('dense/kernel', [1024, 256]), ('dense_1/bias', [1]), ('dense_1/kernel', [256, 1]), ('h_encode_rnn/bidirectional_rnn/fw/basic_lstm_cell/bias', [512]), ('h_encode_rnn/bidirectional_rnn/fw/basic_lstm_cell/kernel', [428, 512]), ('h_infer_rnn/bidirectional_rnn/fw/basic_lstm_cell/bias', [512]), ('h_infer_rnn/bidirectional_rnn/fw/basic_lstm_cell/kernel', [1152, 512]), ('s_endode_rnn/bidirectional_rnn/fw/basic_lstm_cell/bias', [512]), ('s_endode_rnn/bidirectional_rnn/fw/basic_lstm_cell/kernel', [428, 512]), ('s_infer_rnn/bidirectional_rnn/fw/basic_lstm_cell/bias', [512]), ('s_infer_rnn/bidirectional_rnn/fw/basic_lstm_cell/kernel', [1152, 512])]
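To make the mismatch explicit, the two listings can be compared as sets. The sketch below is plain Python with the names copied by hand from the output above (the `:0` suffix is stripped from the graph names); the variable names `graph_vars` and `checkpoint_vars` are illustrative and not part of the repository.

```python
# Graph variable names, copied from the "Trainable variables" listing (":0" removed).
graph_vars = {
    "encode_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/kernel",
    "encode_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/bias",
    "infer_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/kernel",
    "infer_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/bias",
    "dense/kernel",
    "dense/bias",
    "dense_1/kernel",
    "dense_1/bias",
}

# Checkpoint keys, copied from the "Variables found in checkpoint file" listing.
# "s_endode_rnn" is spelled exactly as it appears in the checkpoint.
checkpoint_vars = {
    "dense/bias",
    "dense/kernel",
    "dense_1/bias",
    "dense_1/kernel",
    "h_encode_rnn/bidirectional_rnn/fw/basic_lstm_cell/bias",
    "h_encode_rnn/bidirectional_rnn/fw/basic_lstm_cell/kernel",
    "h_infer_rnn/bidirectional_rnn/fw/basic_lstm_cell/bias",
    "h_infer_rnn/bidirectional_rnn/fw/basic_lstm_cell/kernel",
    "s_endode_rnn/bidirectional_rnn/fw/basic_lstm_cell/bias",
    "s_endode_rnn/bidirectional_rnn/fw/basic_lstm_cell/kernel",
    "s_infer_rnn/bidirectional_rnn/fw/basic_lstm_cell/bias",
    "s_infer_rnn/bidirectional_rnn/fw/basic_lstm_cell/kernel",
}

# Names the Saver asks for but the checkpoint does not contain.
missing_from_checkpoint = sorted(graph_vars - checkpoint_vars)
# Keys stored in the checkpoint that no graph variable refers to.
unused_in_checkpoint = sorted(checkpoint_vars - graph_vars)
# Names present on both sides (only the dense layers).
shared = sorted(graph_vars & checkpoint_vars)

for name in missing_from_checkpoint:
    print("missing from checkpoint:", name)
for name in unused_in_checkpoint:
    print("unused in checkpoint:  ", name)
```

Only the four dense-layer variables match. All four LSTM variables in the graph are absent from the checkpoint, which instead holds separate `h_` and `s_` copies of the encoder and inference LSTMs.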

I am using the model located in model/esim_0/sentence_retrieval_ensemble/model1.

Does anyone have any idea of how to fix this problem?
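One possible workaround, offered only as an untested sketch, is to restore with an explicit name map rather than the default variable names. The helper `to_checkpoint_name` below is hypothetical. It assumes the graph scope `encode_rnn/birnn/...` corresponds to either the `h_` or the `s_` copy in the checkpoint, which the listings alone do not establish. Because the checkpoint stores two separate sets of LSTM weights and the current graph has only one, the restored model may not behave like the original even if loading succeeds.

```python
def to_checkpoint_name(graph_name, prefix):
    """Map a graph variable name to a candidate checkpoint key.

    Hypothetical helper: `prefix` is "h_" or "s_", chosen by the caller.
    """
    # Split "encode_rnn/birnn/bidirectional_rnn/..." around the extra "birnn" scope.
    scope, _, rest = graph_name.partition("/birnn/")
    if not rest:
        # No "birnn" scope (e.g. the dense layers): the name already matches.
        return graph_name
    ckpt_scope = prefix + scope
    # The checkpoint spells the sentence encoder scope "s_endode_rnn".
    if ckpt_scope == "s_encode_rnn":
        ckpt_scope = "s_endode_rnn"
    return ckpt_scope + "/" + rest


print(to_checkpoint_name(
    "encode_rnn/birnn/bidirectional_rnn/fw/basic_lstm_cell/bias", "h_"))
print(to_checkpoint_name("dense/kernel", "h_"))
```

In TensorFlow 1.x, `tf.train.Saver` accepts a dict of `{checkpoint_key: variable}` as its `var_list`, so a map such as `{to_checkpoint_name(v.op.name, "h_"): v for v in tf.trainable_variables()}` could be passed to it in place of the plain list built in `_construct_graph`. Whether `h_` or `s_` is the right choice, if either, would need confirmation from the maintainers.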

@Martin36
Author

When I train the model from scratch and then load it, it seems to work.
