dask_ml.model_selection.GridSearchCV errors for keras model #534

Open
MikeChenfu opened this issue Jul 30, 2019 · 9 comments
@MikeChenfu

I am trying to fit a Keras model with dask_ml.model_selection.GridSearchCV. If I do not set up a client, it works fine. However, I get errors when I have two dask workers; it seems to be unable to deserialize something.
I would appreciate any suggestions about this problem.
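
For context (this snippet is not from the report itself): the difference between the two setups could be sketched as below, assuming grid is the GridSearchCV object from the demo code further down in this thread and scheduler_address is the address of a scheduler with two workers.

from dask.distributed import Client

grid.fit(x, y)                       # no client: the search runs on the local scheduler and works

client = Client(scheduler_address)   # attach a distributed scheduler with two workers
grid.fit(x, y)                       # now fit fails while deserializing the Keras model (traceback below)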

distributed.protocol.pickle - INFO - Failed to deserialize b'\x80\x04\x95b\x05\x00\x00\x00\x00\x00\x00\x8c\x1bkeras.wrappers.scikit_learn\x94\x8c\x0fKerasClassifier\x94\x93\x94)\x81\x94}

ValueError                                Traceback (most recent call last)
~/.local/lib/python3.7/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1112             subfeed_t = self.graph.as_graph_element(
-> 1113                 subfeed, allow_tensor=True, allow_operation=False)
   1114           except Exception as e:

~/.local/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
   3795     with self._lock:
-> 3796       return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
   3797 

~/.local/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
   3874       if obj.graph is not self:
-> 3875         raise ValueError("Tensor %s is not an element of this graph." % obj)
   3876       return obj

ValueError: Tensor Tensor("Placeholder:0", shape=(160, 640), dtype=float32) is not an element of this graph.

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<timed exec> in <module>

/conda/lib/python3.7/site-packages/dask_ml/model_selection/_search.py in fit(self, X, y, groups, **fit_params)
   1293             dsk, keys = build_refit_graph(estimator, X, y, best_params, fit_params)
   1294 
-> 1295             out = scheduler(dsk, keys, num_workers=n_jobs)
   1296             self.best_estimator_ = out[0]
   1297 

/conda/lib/python3.7/site-packages/distributed/client.py in get(self, dsk, keys, restrictions, loose_restrictions, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
   2525                     should_rejoin = False
   2526             try:
-> 2527                 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
   2528             finally:
   2529                 for f in futures.values():

/conda/lib/python3.7/site-packages/distributed/client.py in gather(self, futures, errors, direct, asynchronous)
   1821                 direct=direct,
   1822                 local_worker=local_worker,
-> 1823                 asynchronous=asynchronous,
   1824             )
   1825 

/conda/lib/python3.7/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    761         else:
    762             return sync(
--> 763                 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    764             )
    765 

/conda/lib/python3.7/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    330             e.wait(10)
    331     if error[0]:
--> 332         six.reraise(*error[0])
    333     else:
    334         return result[0]

/conda/lib/python3.7/site-packages/six.py in reraise(tp, value, tb)
    691             if value.__traceback__ is not tb:
    692                 raise value.with_traceback(tb)
--> 693             raise value
    694         finally:
    695             value = None

/conda/lib/python3.7/site-packages/distributed/utils.py in f()
    315             if callback_timeout is not None:
    316                 future = gen.with_timeout(timedelta(seconds=callback_timeout), future)
--> 317             result[0] = yield future
    318         except Exception as exc:
    319             error[0] = sys.exc_info()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

/conda/lib/python3.7/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
   1705                 else:
   1706                     self._gather_future = future
-> 1707                 response = yield future
   1708 
   1709             if response["status"] == "error":

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

/conda/lib/python3.7/site-packages/distributed/client.py in _gather_remote(self, direct, local_worker)
   1758 
   1759             else:  # ask scheduler to gather data for us
-> 1760                 response = yield self.scheduler.gather(keys=keys)
   1761         finally:
   1762             self._gather_semaphore.release()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

/conda/lib/python3.7/site-packages/distributed/core.py in send_recv_from_rpc(**kwargs)
    739             name, comm.name = comm.name, "ConnectionPool." + key
    740             try:
--> 741                 result = yield send_recv(comm=comm, op=key, **kwargs)
    742             finally:
    743                 self.pool.reuse(self.addr, comm)

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

/conda/lib/python3.7/site-packages/distributed/core.py in send_recv(comm, reply, serializers, deserializers, **kwargs)
    533         yield comm.write(msg, serializers=serializers, on_error="raise")
    534         if reply:
--> 535             response = yield comm.read(deserializers=deserializers)
    536         else:
    537             response = None

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

/conda/lib/python3.7/site-packages/distributed/comm/tcp.py in read(self, deserializers)
    216             try:
    217                 msg = yield from_frames(
--> 218                     frames, deserialize=self.deserialize, deserializers=deserializers
    219                 )
    220             except EOFError:

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

/conda/lib/python3.7/site-packages/distributed/comm/utils.py in from_frames(frames, deserialize, deserializers)
     81 
     82     if deserialize and size > FRAME_OFFLOAD_THRESHOLD:
---> 83         res = yield offload(_from_frames)
     84     else:
     85         res = _from_frames()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/conda/lib/python3.7/concurrent/futures/_base.py in result(self, timeout)
    423                 raise CancelledError()
    424             elif self._state == FINISHED:
--> 425                 return self.__get_result()
    426 
    427             self._condition.wait(timeout)

/conda/lib/python3.7/concurrent/futures/_base.py in __get_result(self)
    382     def __get_result(self):
    383         if self._exception:
--> 384             raise self._exception
    385         else:
    386             return self._result

/conda/lib/python3.7/concurrent/futures/thread.py in run(self)
     55 
     56         try:
---> 57             result = self.fn(*self.args, **self.kwargs)
     58         except BaseException as exc:
     59             self.future.set_exception(exc)

/conda/lib/python3.7/site-packages/distributed/comm/utils.py in _from_frames()
     69         try:
     70             return protocol.loads(
---> 71                 frames, deserialize=deserialize, deserializers=deserializers
     72             )
     73         except EOFError:

/conda/lib/python3.7/site-packages/distributed/protocol/core.py in loads(frames, deserialize, deserializers)
    124                     fs = decompress(head, fs)
    125                 fs = merge_frames(head, fs)
--> 126                 value = _deserialize(head, fs, deserializers=deserializers)
    127             else:
    128                 value = Serialized(head, fs)

/conda/lib/python3.7/site-packages/distributed/protocol/serialize.py in deserialize(header, frames, deserializers)
    188         )
    189     dumps, loads, wants_context = families[name]
--> 190     return loads(header, frames)
    191 
    192 

/conda/lib/python3.7/site-packages/distributed/protocol/serialize.py in pickle_loads(header, frames)
     62 
     63 def pickle_loads(header, frames):
---> 64     return pickle.loads(b"".join(frames))
     65 
     66 

/conda/lib/python3.7/site-packages/distributed/protocol/pickle.py in loads(x)
     59 def loads(x):
     60     try:
---> 61         return pickle.loads(x)
     62     except Exception:
     63         logger.info("Failed to deserialize %s", x[:10000], exc_info=True)

/conda/lib/python3.7/site-packages/keras/engine/network.py in __setstate__(self, state)
   1264 
   1265     def __setstate__(self, state):
-> 1266         model = saving.unpickle_model(state)
   1267         self.__dict__.update(model.__dict__)
   1268 

/conda/lib/python3.7/site-packages/keras/engine/saving.py in unpickle_model(state)
    433 def unpickle_model(state):
    434     f = h5dict(state, mode='r')
--> 435     return _deserialize_model(f)
    436 
    437 

/conda/lib/python3.7/site-packages/keras/engine/saving.py in _deserialize_model(f, custom_objects, compile)
    285                              ' elements.')
    286         weight_value_tuples += zip(symbolic_weights, weight_values)
--> 287     K.batch_set_value(weight_value_tuples)
    288 
    289     if compile:

/conda/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py in batch_set_value(tuples)
   2468             assign_ops.append(assign_op)
   2469             feed_dict[assign_placeholder] = value
-> 2470         get_session().run(assign_ops, feed_dict=feed_dict)
   2471 
   2472 

~/.local/lib/python3.7/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    948     try:
    949       result = self._run(None, fetches, feed_dict, options_ptr,
--> 950                          run_metadata_ptr)
    951       if run_metadata:
    952         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~/.local/lib/python3.7/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1114           except Exception as e:
   1115             raise TypeError(
-> 1116                 'Cannot interpret feed_dict key as Tensor: ' + e.args[0])
   1117 
   1118           if isinstance(subfeed_val, ops.Tensor):

TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(160, 640), dtype=float32) is not an element of this graph.



@TomAugspurger
Member

I believe I ran into a similar issue here: https://gist.github.com/TomAugspurger/33efb49efe611701ef122f577d0e0430

It seems to be difficult to serialize & deserialize a Keras estimator backed by Tensorflow when there are multiple processes / threads.

So I guess that

I appreciate if anyone has suggestions about this problem.

applies to me too!
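
Not from the thread, but one way to probe this outside of dask is a plain pickle round trip of the fitted wrapper; the traceback above fails on the deserializing side, where the restored weights point at tensors from a different TensorFlow graph. A rough sketch, assuming the create_model function and training data from the issue:

import pickle

est = KerasClassifier(build_fn=create_model)
est.fit(x, y, epochs=1, batch_size=512, verbose=0)

blob = pickle.dumps(est)        # serializing usually succeeds (Keras pickles the model via HDF5)
restored = pickle.loads(blob)   # may raise "... is not an element of this graph" when a
restored.predict(x)             # different default graph/session is already active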

@MikeChenfu
Author

I checked the previous post: dask/dask-searchcv#69
@TomAugspurger @mrocklin the Keras model may not be able to run in multiple threads. However, it can run on multiple GPUs.

@TomAugspurger
Member

@MikeChenfu does that mean it should work fine with multiple processes, but a single thread per process?

@MikeChenfu
Author

Yes, it seems to work while training the model, but the problem occurs when finalizing the work.

Here is my demo code.

# Workers started with: dask-worker $ip --nprocs 2 --nthreads 1
from keras.wrappers.scikit_learn import KerasClassifier
from dask_ml.model_selection import GridSearchCV

# create_model builds and compiles the Keras model; x and y are the training data
model = KerasClassifier(build_fn=create_model, verbose=1)
optimizers = ['rmsprop', 'adam']                 # defined but not part of param_grid below
init = ['glorot_uniform', 'normal', 'uniform']   # defined but not part of param_grid below
epochs = [100]
batches = [512]
param_grid = dict(epochs=epochs, batch_size=batches)
grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=2)
grid_result = grid.fit(x, y)
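
The create_model function referenced above is not included in the issue. For reference only, a minimal sketch of such a build function might look like the following; the layer sizes, loss, and 640-wide input are assumptions, loosely based on the (160, 640) placeholder in the traceback:

from keras.models import Sequential
from keras.layers import Dense

def create_model(optimizer='adam', init='glorot_uniform'):
    # Hypothetical architecture -- the real model is not shown in the issue.
    model = Sequential()
    model.add(Dense(64, input_dim=640, kernel_initializer=init, activation='relu'))
    model.add(Dense(1, kernel_initializer=init, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model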

The training progress output is shown below.

Epoch 97/100
7725/7725 [==============================] - 1s 116us/step - loss: 0.1316 - acc: 0.9764 - categorical_crossentropy: 0.0638
Epoch 98/100
7725/7725 [==============================] - 1s 113us/step - loss: 0.1346 - acc: 0.9763 - categorical_crossentropy: 0.0675
Epoch 99/100
7725/7725 [==============================] - 1s 112us/step - loss: 0.1363 - acc: 0.9736 - categorical_crossentropy: 0.0696
Epoch 100/100
7725/7725 [==============================] - 1s 116us/step - loss: 0.1358 - acc: 0.9746 - categorical_crossentropy: 0.0691
7724/7724 [==============================] - 1s 185us/step - loss: 0.1268 - acc: 0.9727 - categorical_crossentropy: 0.0716

@MikeChenfu
Author

I suspect the problem is not related to multiple processes. Occasionally it works several times in a row with two dask-workers, and then it fails again. I have also run it with a single dask-worker and get the same problem.

@MikeChenfu
Author

When it works, I get the following warnings:

WARNING: Logging before flag parsing goes to stderr.
W0814 20:00:21.673532 139838015661824 deprecation_wrapper.py:119] From /conda/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0814 20:00:21.701011 139838015661824 deprecation_wrapper.py:119] From /conda/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

W0814 20:00:21.714846 139838015661824 deprecation_wrapper.py:119] From /conda/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:131: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0814 20:00:21.715923 139838015661824 deprecation_wrapper.py:119] From /conda/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:133: The name tf.placeholder_with_default is deprecated. Please use tf.compat.v1.placeholder_with_default instead.

W0814 20:00:21.724133 139838015661824 deprecation.py:506] From /conda/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
W0814 20:00:21.778590 139838015661824 deprecation_wrapper.py:119] From /conda/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.

W0814 20:00:24.539319 139838015661824 deprecation_wrapper.py:119] From /conda/lib/python3.7/site-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.

W0814 20:00:24.551965 139838015661824 deprecation.py:323] From /root/.local/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

@TomAugspurger
Member

TomAugspurger commented Jan 16, 2020 via email

@bw4sz

bw4sz commented Jan 16, 2020

For those interested, I have an example of using a trained Keras model for prediction with dask.

dask/distributed#2333

@stsievert
Member

There's support for Keras serialization now in SciKeras, which brings a scikit-learn API to Keras. This is mentioned explicitly in the documentation at https://ml.dask.org/keras.html.

We're trying to merge serialization support upstream in Tensorflow: tensorflow/tensorflow#39609, tensorflow/community#286
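
For reference only (not code from the comment above): a rough sketch of the SciKeras route, assuming the scikeras package is installed, create_model returns a compiled Keras model, and the current SciKeras API where the build function is passed as model=. SciKeras wrappers are designed to be serializable, so they can be shipped to dask workers.

from scikeras.wrappers import KerasClassifier
from dask_ml.model_selection import GridSearchCV

model = KerasClassifier(model=create_model, verbose=0)
param_grid = {"epochs": [100], "batch_size": [512]}

grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=2)
grid_result = grid.fit(x, y)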
