RMSProp optimization support for sparse tensors #464

Closed
fabiencro opened this issue Dec 10, 2015 · 19 comments
Labels
stat:awaiting response (Status - Awaiting response from author), type:feature (Feature requests)

Comments

@fabiencro

It seems that tf.nce_loss is not compatible with the RMSProp, Adagrad, and Momentum optimizers (while SGD, Adam, and FTRL work fine).

When using rmsprop, I get this error:

    optimizer = tf.train.RMSPropOptimizer(learning_rate = learning_rate, decay = rms_prop_decay).minimize(nce_loss)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/optimizer.py", line 167, in minimize
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/optimizer.py", line 256, in apply_gradients
    update_ops.append(self._apply_sparse(grad, var))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/rmsprop.py", line 81, in _apply_sparse
    raise NotImplementedError()
NotImplementedError

When using adagrad or momentum, I get this error:

    optimizer = tf.train.MomentumOptimizer(learning_rate, learning_momentum).minimize(nce_loss)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/optimizer.py", line 167, in minimize
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/optimizer.py", line 256, in apply_gradients
    update_ops.append(self._apply_sparse(grad, var))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/momentum.py", line 51, in _apply_sparse
    self._momentum_tensor, use_locking=self._use_locking).op
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/gen_training_ops.py", line 237, in sparse_apply_momentum
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 633, in apply_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1712, in create_op
    set_shapes_for_outputs(ret)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1417, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/training_ops.py", line 111, in _SparseApplyMomentumShape
    tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/tensor_shape.py", line 481, in merge_with
    self.assert_same_rank(other)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/tensor_shape.py", line 524, in assert_same_rank
    "Shapes %s and %s must have the same rank" % (self, other))
ValueError: Shapes TensorShape([Dimension(128), Dimension(11), Dimension(192)]) and TensorShape([Dimension(None), Dimension(192)]) must have the same rank

Is that expected?
The exact same code works perfectly fine with the Adam or SGD optimizers, so I do not think I made a mistake when constructing the graph.

@vrv vrv changed the title RMSProp, adagrad and momentum optimization do not work with tf.nce_loss RMSProp, adagrad and momentum optimization support for sparse tensors Dec 10, 2015
@vrv

vrv commented Dec 10, 2015

Based on reading the code, I think it is expected: there is not yet support for SparseTensors with those three optimizers, since the _apply_sparse() function isn't implemented. Turning this into a feature request.
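
For anyone who needs RMSProp with these losses before _apply_sparse() is implemented, one possible (if slow) workaround is to convert the sparse gradients to dense tensors so that the dense update path is taken. A minimal sketch, assuming tf.convert_to_tensor can densify an IndexedSlices, and using a made-up helper name:

import tensorflow as tf

def minimize_with_dense_grads(optimizer, loss):
    # Densify any IndexedSlices gradient before applying it, so the
    # optimizer never reaches its unimplemented _apply_sparse() path.
    grads_and_vars = optimizer.compute_gradients(loss)
    densified = []
    for grad, var in grads_and_vars:
        if isinstance(grad, tf.IndexedSlices):
            # Scatter the slices back into a full dense gradient tensor.
            grad = tf.convert_to_tensor(grad)
        densified.append((grad, var))
    return optimizer.apply_gradients(densified)

# train_op = minimize_with_dense_grads(
#     tf.train.RMSPropOptimizer(learning_rate, decay = 0.9), nce_loss)

This forfeits the memory and speed benefits of sparse updates, so it is only a stopgap.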

@fabiencro
Author

OK, thank you. But since Adam works for this, there should not be any major issue preventing RMSProp from working as well, right? (I would be quite interested in using RMSProp with nce_loss.)

Also, if this is expected, I think you should document it. I do not think it is mentioned in the docs that only 3 out of the 6 default optimizers support sparse updates. The doc for nce_loss (and for sampled softmax as well, I suppose) could also mention it.

@vrv

vrv commented Dec 11, 2015

Assigning to someone who knows more about this part of the codebase

@fabiencro
Author

After checking bug #505, it seems that there are actually two bugs going on here. One has to do with tf.reshape + Adagrad/Momentum, and the other has to do with RMSPropOptimizer not handling sparse gradient updates (which happens when using embeddings or a sampled loss, I guess). Here is a self-contained example demonstrating the RMSProp error (just a slight modification of the one I posted for bug #505):

import numpy as np
import tensorflow as tf

def device_for_node(n):
    if n.type == "MatMul":
        return "/gpu:1"
    else:
        return "/cpu:0"

minibatch_size = 128
hidden_size = 64
embedding_size = 256
input_layer_size = 3
vocab_size_input = 32
vocab_size_output = 64
nce_num_sampled = 16
learning_rate = 0.1

dummy_input = np.zeros((minibatch_size, input_layer_size), dtype = np.int32)
dummy_target = np.zeros((minibatch_size, 1), dtype = np.int32)

input_layer_flattened_size = input_layer_size * embedding_size

graph = tf.Graph()

with graph.as_default():
    with graph.device(device_for_node):
        input_layer = tf.placeholder(tf.int32, shape = (minibatch_size, input_layer_size), name = "input_layer")       
        ref_input = tf.placeholder(tf.int32, shape = (minibatch_size, 1), name = "ref_input")

        # Parameters

        input_embeddings = tf.Variable(tf.random_normal([vocab_size_input, embedding_size]), name = "i_embeddings")

        Wh_i = tf.Variable(tf.random_normal((input_layer_flattened_size, hidden_size), stddev = 0.2), name = "Wh_i")
        bh_i = tf.Variable(tf.random_normal((hidden_size,), stddev = 0.2), name = "bh_i")

        Wh_o = tf.Variable(tf.random_normal((vocab_size_output, hidden_size), stddev = 0.2), name = "Wh_o")
        bh_o = tf.Variable(tf.random_normal((vocab_size_output,), stddev = 0.2), name = "bh_o")

        # Layers

        i_embedded = tf.nn.embedding_lookup(input_embeddings, input_layer)
        i_embedded_flattened = tf.reshape(i_embedded, 
                        (
                         (minibatch_size if minibatch_size is not None else -1), 
                         input_layer_flattened_size ) 
                        )

        h = tf.tanh(tf.matmul(i_embedded_flattened, Wh_i) + bh_i)
        nce_loss = tf.reduce_mean(tf.nn.nce_loss(Wh_o, bh_o, h, ref_input, nce_num_sampled, 
                                                     num_classes = vocab_size_output, name = "nce"))
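        # With the sparse embedding gradient, the next line fails at graph
        # construction time: RMSPropOptimizer._apply_sparse raises NotImplementedError.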
        optimizer = tf.train.RMSPropOptimizer(learning_rate, decay = 0.9).minimize(nce_loss)

        init_op = tf.initialize_all_variables()


with tf.Session(graph=graph) as session:
    session.run(init_op)

    feed_dict = {input_layer : dummy_input, ref_input : dummy_target}
    _, loss_val = session.run([optimizer, nce_loss], feed_dict=feed_dict)


@mrry
Contributor

mrry commented Dec 16, 2015

I think the issue here is related to a bug in array_grad._GatherGrad, and I have a fix in the works. (It seems to be broken when computing gradients for an embedding lookup/gather with a >1-D indices input.)

@mrry mrry assigned mrry and unassigned rafaljozefowicz Dec 16, 2015
@fabiencro
Author

Thank you for fixing this :-) However, will this also fix the RMSProp error? It seems to be of a slightly different nature...

@mrry
Contributor

mrry commented Dec 16, 2015

Indeed, it will only fix the issues with adagrad and momentum (and other optimizers that support sparse data using embeddings).

@mrry mrry assigned rafaljozefowicz and unassigned mrry Dec 16, 2015
@rafaljozefowicz

What kind of behavior would you expect from sparse tensors in RMSProp?
There are two options I think:

  • Ignore the momentum terms for the embedding rows that are not present in the current batch
  • Apply momentum terms to the whole embedding (AFAIR, that's what we do for Adam)

The first might not be correct (though I'm not sure), and the second is very slow. What I typically do now is split my variables into two groups and train the dense part with any optimizer I want and the rest with GradientDescent or AdaGrad, roughly as in the sketch below.
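
A minimal sketch of that split, reusing tf and the variable names from the reproduction above (the grouping is illustrative; note that the NCE output weights and biases also receive IndexedSlices gradients, since tf.nn.nce_loss looks them up by sampled id, which matches the traceback at the top):

# Dense variables can use RMSProp; variables with IndexedSlices gradients
# go to AdaGrad, which already supports sparse updates.
dense_vars = [Wh_i, bh_i]
sparse_vars = [input_embeddings, Wh_o, bh_o]

dense_step = tf.train.RMSPropOptimizer(learning_rate, decay = 0.9).minimize(
    nce_loss, var_list = dense_vars)
sparse_step = tf.train.AdagradOptimizer(learning_rate).minimize(
    nce_loss, var_list = sparse_vars)

# A single training op that runs both updates.
train_op = tf.group(dense_step, sparse_step)

Calling minimize() twice builds the backward pass twice; a single compute_gradients() call followed by two apply_gradients() calls would avoid that.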

mrry added a commit that referenced this issue Dec 17, 2015
The gradient function was previously generating an invalid
IndexedSlices, whereby `IndexedSlices.indices` tensor was not a
vector. This change reshapes the indices and gradient so that they can
correctly be interpreted as an IndexedSlices and applied to the
embedding variable.

Added a multi-dimensional gradient test in embedding_ops_test.py.

Fixes #505. Partially addresses #464.
Change: 110364370
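
To illustrate the shapes involved in that fix (a toy numpy sketch, not the actual gradient code): with 2-D lookup indices, the incoming gradient for the embedding is 3-D, while a well-formed IndexedSlices needs a flat index vector with matching 2-D values; the rank mismatch in the Momentum traceback above comes from skipping that flattening.

import numpy as np

# Shapes taken from the Momentum traceback above: indices [128, 11], width 192.
batch, seq, dim = 128, 11, 192
indices = np.zeros((batch, seq), dtype = np.int32)             # 2-D lookup indices
grad_values = np.ones((batch, seq, dim), dtype = np.float32)   # gradient of the gathered rows

# A well-formed IndexedSlices wants a vector of row indices and 2-D values,
# so both are flattened over the leading dimensions.
flat_indices = indices.reshape(-1)            # shape (1408,)
flat_values = grad_values.reshape(-1, dim)    # shape (1408, 192)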
@fabiencro
Author

It is quite possible that I do not fully understand the issue, but I was not expecting RMSProp to be much more difficult to use than Adagrad. I will try to briefly state how I see things (and please excuse me if I write something obviously stupid).

As for the two options you gave, I take it that you are discussing the update of the running average of the squared gradient, right? "ignore the momentum" would consist in not updating the running average for dimensions for which the gradient is zero due to sparsity. And "apply momentum for the whole embedding" would be to actually update the running average of all dimensions at each iteration.

Then I think there is a third option. You could keep track of the last iteration in which the running average was updated for each dimension.

Let us call last(d) the iteration step at which the running average of dimension d was last updated. Then, when you want to update dimension d at iteration step i, you can update the running average of the squared gradient of d with rms(d) = rms(d) * 0.9^(i - last(d)) * 0.9 + 0.1 * grad_i(d)^2, and then set last(d) = i.

The 0.9^(i - last(d)) factor accounts for the decay of the running average since the last time we saw a non-zero gradient for dimension d (the gradient was null for i - last(d) steps). The values of last(d) and rms(d) only need to be updated when there is a non-null gradient for dimension d, so the updates should be efficient in a sparse context.
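
A toy numpy sketch of this deferred-decay idea, applied per embedding row rather than per scalar dimension and following the formula above (all names and the epsilon term are illustrative, not an actual TensorFlow implementation):

import numpy as np

decay, lr, eps = 0.9, 0.1, 1e-8
num_rows, dim = 1000, 64
var = np.zeros((num_rows, dim), dtype = np.float32)
rms = np.zeros((num_rows, dim), dtype = np.float32)  # running average of squared gradients
last = np.zeros(num_rows, dtype = np.int64)          # step at which each row was last touched

def lazy_rmsprop_update(rows, grads, step):
    # Catch up on the decay missed while these rows had zero gradient,
    # then do the usual RMSProp accumulator update and parameter step.
    catch_up = decay ** (step - last[rows])
    rms[rows] = rms[rows] * catch_up[:, None] * decay + (1 - decay) * grads ** 2
    last[rows] = step
    var[rows] -= lr * grads / np.sqrt(rms[rows] + eps)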

@rafaljozefowicz

You're right, the third option you suggested is probably strictly better than the second one. That's not how it's implemented in Adam, though (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/adam.py#L138)

For Adagrad, we're only updating non-zero elements and I think this is a common way to do it (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/training_ops.cc#L381)
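
For comparison, a toy numpy sketch of that per-row sparse Adagrad behaviour (an approximation, not the linked C++ kernel): only the rows that actually received a gradient touch their accumulator and their parameters.

import numpy as np

lr = 0.1
num_rows, dim = 1000, 64
var = np.random.randn(num_rows, dim).astype(np.float32)
accum = np.full((num_rows, dim), 0.1, dtype = np.float32)  # initial accumulator value

def sparse_adagrad_update(rows, grads):
    # Only the touched rows accumulate squared gradients and move.
    accum[rows] += grads ** 2
    var[rows] -= lr * grads / np.sqrt(accum[rows])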

@fabiencro
Author

Yes, for Adagrad it is enough to not update the squared gradient sum (and keep track of the global total number of updates).

Would you consider implementing this? I could probably do it myself at some point when I have time by looking at the existing code for RMSProp and Adam. But given my lack of time and the annoyance of getting a Corporate Contributor Agreement, I would probably not contribute it anytime soon...

In any case, if nothing is changed, I think the docs should mention that Adam is slower than Adagrad on sparse updates, and that RMSProp is not compatible with them.

@mrry mrry changed the title RMSProp, adagrad and momentum optimization support for sparse tensors RMSProp optimization support for sparse tensors Mar 18, 2016
@mrry
Contributor

mrry commented Mar 18, 2016

Renaming this bug since the AdaGrad and Momentum optimizers should now work.

@lucaswiser
Contributor

Hey, what is the status of this?

@girving
Contributor

girving commented Jun 6, 2016

I'm going to mark this contributions welcome, since I don't know of anyone working on it.

@girving girving added stat:contribution welcome Status - Contributions welcome triaged labels Jun 6, 2016
@aselle aselle removed the triaged label Jul 28, 2016
@aselle
Contributor

aselle commented Feb 7, 2018

@fabiencro, is this still an issue? If not, or if it is not important, I will close it in a few days.

@aselle aselle added stat:awaiting response Status - Awaiting response from author and removed stat:contribution welcome Status - Contributions welcome labels Feb 7, 2018
@tensorflowbutler
Member

Nagging Awaiting Response: It has been 14 days with no activity and the awaiting response label was assigned. Is this still an issue?


@tensorflowbutler
Member

It has been 14 days with no activity and the awaiting response label was assigned. Is this still an issue?

@asimshankar
Contributor

Closing as per previous comment: #464 (comment)
