Comparison with tract #50

Closed

igor-yusupov opened this issue Feb 4, 2024 · 13 comments

Comments

@igor-yusupov
Contributor

igor-yusupov commented Feb 4, 2024

In general I've noticed that many models run faster with rten than with tract (I think because tract runs inference in a single thread), but I noticed an interesting thing with transformers.

If the sequence length is short, for example 10 tokens, then inference with tract on my machine is 3 times faster. As the sequence length increases, rten starts to overtake tract. For example, if I submit 1000 tokens, rten is 2 times faster.

But since the decoder outputs each token separately during inference, the difference becomes noticeable regardless of the sequence size.

Code example:

use rten::{Input, Model, NodeId};
use rten_tensor::prelude::*;
use rten_tensor::NdTensorView;
use std::fs;
use std::time::Instant;
use tract_ndarray::{Array, ArrayView, Dim, Dimension, Ix};
use tract_onnx::prelude::*;

const TOKEN_SIZE: i32 = 1;

fn as_ndtensor_view<'a, T, const N: usize>(
    view: ArrayView<'a, T, Dim<[Ix; N]>>,
) -> Option<NdTensorView<'a, T, N>>
where
    Dim<[Ix; N]>: Dimension,
{
    view.to_slice().map(|slice| {
        let shape: [usize; N] = view.shape().try_into().unwrap();
        NdTensorView::from_data(shape, slice)
    })
}

fn rten_inference() {
    let start_time = Instant::now();
    let encoder = Model::load(&fs::read("encoder.rten").unwrap()).unwrap();
    let end_time = Instant::now();
    let elapsed_time = end_time.duration_since(start_time);
    println!(
        "Warmup time rten: {} milliseconds ⚡",
        elapsed_time.as_millis()
    );

    let tokens: Vec<i32> = (0..TOKEN_SIZE).collect();
    let src_size = tokens.len();

    let src = as_ndtensor_view(Array::from_shape_vec((1, src_size), tokens).unwrap().view())
        .unwrap()
        .to_tensor();
    let src_mask = Array::<i32, _>::from_elem((src_size, src_size), 0);
    let src_mask = as_ndtensor_view(src_mask.view()).unwrap().to_tensor();

    let src_id = encoder.node_id("src").unwrap();
    let src_mask_id = encoder.node_id("src_mask").unwrap();
    let memory_id = encoder.node_id("memory").unwrap();

    let inputs: Vec<(NodeId, Input)> = vec![
        (src_id, src.view().into()),
        (src_mask_id, src_mask.view().into()),
    ];

    let start_time = Instant::now();
    encoder.run_n(&inputs, [memory_id], None).unwrap();
    let end_time = Instant::now();
    let elapsed_time = end_time.duration_since(start_time);
    println!(
        "Inference time rten: {} milliseconds ⚡",
        elapsed_time.as_millis()
    );
}

fn inference_tract() {
    let start_time = Instant::now();
    let encoder = tract_onnx::onnx()
        .model_for_path("encoder.onnx")
        .unwrap()
        .with_output_fact(0, InferenceFact::default())
        .unwrap()
        .into_optimized()
        .unwrap()
        .into_runnable()
        .unwrap();
    let end_time = Instant::now();
    let elapsed_time = end_time.duration_since(start_time);
    println!(
        "Warmup time tract: {} milliseconds ⚡",
        elapsed_time.as_millis()
    );

    let tokens: Vec<i64> = (0..TOKEN_SIZE as i64).collect();
    let src_size = tokens.len();

    let src: Tensor = Array::from_shape_vec((1, src_size), tokens).unwrap().into();
    let src_mask: Tensor = Array::<bool, _>::from_elem((src_size, src_size), false).into();
    let inputs = tvec!(src.clone().into(), src_mask.into());

    let start_time = Instant::now();
    encoder.run(inputs).unwrap()[0]
        .to_array_view::<f32>()
        .unwrap()
        .to_owned();
    let end_time = Instant::now();
    let elapsed_time = end_time.duration_since(start_time);
    println!(
        "Inference time tract: {} milliseconds ⚡",
        elapsed_time.as_millis()
    );
}

fn main() {
    inference_tract();
    rten_inference();
}

I can attach the weights if needed.

@robertknight
Owner

But since the decoder outputs each token separately during inference, the difference becomes noticeable regardless of the sequence size.

Is this using the same encoder.rten file as #46, produced by torch.nn.Transformer?

Assuming that is the case, using the RTEN_TIMING environment variable (see the profiling docs) with the by-shape=1 option, I can see that most of the time is spent in batched matrix multiplications:

Graph run of 401 ops finished in 15.917ms
MatMul           12.04ms (75.62%)

    Shape                      Count  Mean (ms)  Total (ms)  ns/input elem
    -------------------------  -----  ---------  ----------  -------------
    [10, 1, 512], [512, 1536]  4      1.664      6.655       2.102
    [10, 1, 512], [512, 512]   8      0.540      4.319       2.020
    [8, 10, 64], [8, 64, 10]   4      0.164      0.657       16.040
    [8, 10, 10], [8, 10, 64]   4      0.102      0.406       17.145

For the first two shape combinations, which take up most of the time, the ordering of the first two dimensions turns the operation from a batched matrix x matrix multiplication into a batched vector x matrix multiplication, and that is not handled very efficiently at the moment.

Looking at the model inputs, I see that the dim order is (sequence length, batch), which I presume corresponds to batch_first=False in PyTorch. I think most of the transformer models I'd tested with up to this point used batch-first, so you'd end up with the "10"- and "1"-sized dims swapped around in the timings above.

There are definitely optimizations possible in RTen here. In the interim, using batch_first=True in torch.nn.Transformer might work better, if you plan to use a batch size of 1 for inference.

robertknight added a commit that referenced this issue Feb 4, 2024
Batched matrix multiplication is handled by prepacking one or neither of
the inputs, depending on how often each is re-used, and then performing
one `gemm` call per matrix in the output shape.

This can be inefficient if the matrices in the batch end up being small in one
or both dimensions, for example if one of the matrices is a vector. In that
case it can be better to reshape the inputs so that instead of many
low-arithmetic intensity `gemm` calls, a single higher-arithmetic intensity
call is performed. The output is then reshaped to restore the batch dimensions.

See #50
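
A minimal sketch of this reshape idea, written against the ndarray crate rather than RTen's internal gemm (the function name and shapes here are illustrative assumptions, not project code):

use ndarray::{Array2, Array3};

/// Multiply a batch of [1, K] row vectors by a shared [K, N] matrix by
/// collapsing the batch into a single [B, K] x [K, N] product.
fn batched_vec_matmul(a: &Array3<f32>, b: &Array2<f32>) -> Array3<f32> {
    let (batch, m, k) = a.dim();
    assert_eq!(m, 1, "this sketch assumes each LHS matrix is a single row vector");
    let n = b.dim().1;
    // Collapse the batch and row dimensions: [B, 1, K] -> [B, K].
    let a2 = a.view().into_shape((batch * m, k)).unwrap();
    // One higher arithmetic-intensity matmul instead of B tiny ones.
    let out = a2.dot(b); // [B, N]
    // Restore the batch dimension: [B, N] -> [B, 1, N].
    out.into_shape((batch, m, n)).unwrap()
}

fn main() {
    let a = Array3::<f32>::ones((8, 1, 64));
    let b = Array2::<f32>::ones((64, 32));
    assert_eq!(batched_vec_matmul(&a, &b).dim(), (8, 1, 32));
}
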
@igor-yusupov
Contributor Author

Thank you for the answer! I will try comparing with batch_first=True.

robertknight added a commit that referenced this issue Feb 5, 2024
@igor-yusupov
Contributor Author

I tried setting batch_first=True and it seems that the speed is the same:

MatMul           9.91ms (74.24%)

    Shape                      Count  Mean (ms)  Total (ms)  ns/input elem  
    -------------------------  -----  ---------  ----------  -------------  
    [10, 1, 512], [512, 1536]  4      1.912      7.649       2.416          
    [1, 10, 512], [512, 512]   8      0.209      1.672       0.782          
    [8, 10, 64], [8, 64, 10]   4      0.080      0.320       7.812          
    [8, 10, 10], [8, 10, 64]   4      0.067      0.268       11.318 

I also changed the order of batch_size and sequence_length.

@robertknight
Owner

I tried setting batch_first=True and it seems that the speed is the same:

Internally it seems the matrix multiplications are still being done in the same order and devolving to vector x matrix. The draft in #51 should improve this once it lands.

@igor-yusupov
Contributor Author

I'll look forward to it, thank you!

robertknight added a commit that referenced this issue Feb 6, 2024
Batched matrix multiplication was handled by prepacking one or neither of
the inputs, depending on how often each is re-used, and then performing
one `gemm` call per matrix in the output shape.

This can be inefficient if one of the matrices passed to a gemm call ends up
being small in one or both dimensions. For example in [1], the LHS / "A" input
is a vector. In the case where the "A" input is a batch and the "B" input is a
single matrix, the "A" input can be reshaped so a single gemm call can be used,
with the output reshaped afterwards to restore the batch dimensions.

In addition to the strategy, add a simple benchmark for different input shapes.

[1] #50
robertknight added a commit that referenced this issue Feb 6, 2024
Batched matrix multiplication was handled by prepacking one or neither of
the inputs, depending on how often each is re-used, and then performing
one `gemm` call per matrix in the output shape.

This can be inefficient if the LHS input has a small number of rows. For example in
[1], the LHS / "A" input is a row vector. In the case where the "A" input is a
batch and the "B" input is a single matrix, the "A" input can be reshaped so a
single gemm call can be used, with the output reshaped afterwards to restore the
batch dimensions.

Implement this alternate approach and add a simple benchmark for batched matmul.

[1] #50
@robertknight
Owner

#51 has been merged, which should improve performance in this case.

@igor-yusupov
Contributor Author

igor-yusupov commented Feb 7, 2024

Thank you so much! Yeah, it works faster now, but it seems that tract is still faster. On my laptop the difference is literally 5-10 milliseconds when encoding/decoding 1 token, but since each token is generated separately in a loop during decoding, the difference accumulates and, depending on the length of the sequence, can increase significantly.

But encoding >= 10 tokens with rten is now already faster.

In general, I use the Transformer model to translate texts, and at the moment tract works about 1.5 times faster on average. But maybe it's because of my implementation, as I divide the text into sentences and translate the sentences in parallel... I'll try to look deeper into what might be causing the difference. But in the example above you can see the 1.5x difference if TOKEN_SIZE = 1 is specified (I updated the example).

@robertknight
Owner

But maybe it's because of my implementation, as I divide the text into sentences and translate the sentences in parallel

Are you running the model with a single sentence at a time (i.e. batch size = 1) or with a batch of sentences? If you have many sentences, what I would recommend is grouping the sentences into batches of approximately equal length, with appropriate padding and masks, and calling the model once per batch. Each call to run the model has some overhead, and passing in a batch of inputs amortizes this overhead over the batch.
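
For illustration, a minimal sketch of building such a padded batch and mask in plain Rust, independent of either library; the PAD_ID value and the 1 = real token / 0 = padding mask convention are assumptions:

const PAD_ID: i32 = 0; // hypothetical padding token id

/// Pad `sentences` to the length of the longest one and return
/// (flat token buffer, flat mask buffer, batch, max_len) in (batch, seq) order.
/// The mask is 1 for real tokens and 0 for padding.
fn pad_batch(sentences: &[Vec<i32>]) -> (Vec<i32>, Vec<i32>, usize, usize) {
    let batch = sentences.len();
    let max_len = sentences.iter().map(|s| s.len()).max().unwrap_or(0);
    let mut tokens = vec![PAD_ID; batch * max_len];
    let mut mask = vec![0i32; batch * max_len];
    for (row, sent) in sentences.iter().enumerate() {
        let start = row * max_len;
        tokens[start..start + sent.len()].copy_from_slice(sent);
        for m in &mut mask[start..start + sent.len()] {
            *m = 1;
        }
    }
    (tokens, mask, batch, max_len)
}

fn main() {
    let sentences = vec![vec![5, 7, 9], vec![5, 7, 9, 11, 13]];
    let (tokens, mask, batch, max_len) = pad_batch(&sentences);
    println!("batch={batch} max_len={max_len} tokens={tokens:?} mask={mask:?}");
}

The flat buffers could then be wrapped into a (batch, max_len) tensor in the same way as the NdTensorView::from_data calls in the examples above.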

@igor-yusupov
Contributor Author

Yes, but if, for example, there are only 2 sentences of different lengths, it seems time will be wasted decoding the padding of one of the sentences. I think the same approach would be used in Python if it had the same ability to parallelize code as Rust.
And I also planned to move away from that approach because the context between sentences is lost that way.

I checked the decoder inference and it also runs slower than with tract when batch_size = 1. I can attach the decoder weights and an inference example if you want.

I prepared example code that runs inference with all of the model's parts:
encoder: https://www.dropbox.com/scl/fi/6kw23cavybwhqkiy4580m/encoder.rten?rlkey=8q9bq0z1tzbj0ajtz0i56nb0q&dl=0
decoder: https://www.dropbox.com/scl/fi/ls7jzvfyiek8qohbs9k16/decoder.rten?rlkey=suj51i4ij2kqknwlqpafxc0dy&dl=0
generator: https://www.dropbox.com/scl/fi/61d2ifjdtya45ad9897x8/generator.rten?rlkey=5nzscdraea27hbd6qeweu6dak&dl=0

use ndarray::{Array, ArrayView, Axis, Dim, Dimension, Ix, StrideShape};
use rten::{Input, Model, NodeId};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView};
use std::fs;

const TOKEN_SIZE: i32 = 1;

fn as_ndtensor_view<'a, T, const N: usize>(
    view: ArrayView<'a, T, Dim<[Ix; N]>>,
) -> Option<NdTensorView<'a, T, N>>
where
    Dim<[Ix; N]>: Dimension,
{
    view.to_slice().map(|slice| {
        let shape: [usize; N] = view.shape().try_into().unwrap();
        NdTensorView::from_data(shape, slice)
    })
}

fn as_array_view<'a, T, const N: usize>(
    view: NdTensorView<'a, T, N>,
) -> Option<ArrayView<'a, T, Dim<[Ix; N]>>>
where
    Dim<[Ix; N]>: Dimension,
    [usize; N]: Into<StrideShape<Dim<[Ix; N]>>>,
{
    view.data()
        .map(|data| ArrayView::from_shape(view.shape(), data).unwrap())
}

fn rten_inference() {
    let encoder = Model::load(&fs::read("encoder.rten").unwrap()).unwrap();
    let decoder = Model::load(&fs::read("decoder.rten").unwrap()).unwrap();
    let generator = Model::load(&fs::read("generator.rten").unwrap()).unwrap();

    let tokens: Vec<i32> = (0..TOKEN_SIZE).collect();
    let src_size = tokens.len();

    let src = as_ndtensor_view(Array::from_shape_vec((1, src_size), tokens).unwrap().view())
        .unwrap()
        .to_tensor();
    let src_mask = Array::<i32, _>::from_elem((src_size, src_size), 0);
    let src_mask = as_ndtensor_view(src_mask.view()).unwrap().to_tensor();

    let src_id = encoder.node_id("src").unwrap();
    let src_mask_id = encoder.node_id("src_mask").unwrap();
    let memory_id = encoder.node_id("memory").unwrap();

    let inputs: Vec<(NodeId, Input)> = vec![
        (src_id, src.view().into()),
        (src_mask_id, src_mask.view().into()),
    ];

    let [memory] = encoder.run_n(&inputs, [memory_id], None).unwrap();
    let memory: NdTensor<f32, 3> = memory.try_into().unwrap();

    let start_symbol = 1;

    let memory_id = decoder.node_id("memory").unwrap();
    let ys_id = decoder.node_id("ys").unwrap();
    let tgt_mask_id = decoder.node_id("tgt_mask").unwrap();

    let ys = Array::from_elem((1, 1), start_symbol as i32);

    let tgt_mask_size = *ys.shape().get(0).unwrap() as usize;
    let tgt_mask =
        as_ndtensor_view(
            Array::<i32, _>::from_shape_fn((tgt_mask_size, tgt_mask_size), |(i, j)| {
                if i < j {
                    1
                } else {
                    0
                }
            })
            .view(),
        )
        .unwrap()
        .to_tensor();

    let ys_tensor = as_ndtensor_view(ys.view()).unwrap().to_tensor();
    let inputs_decoder: Vec<(NodeId, Input)> = vec![
        (ys_id, ys_tensor.view().into()),
        (memory_id, memory.view().into()),
        (tgt_mask_id, tgt_mask.view().into()),
    ];
    let outs_id = decoder.node_id("outs").unwrap();
    let [out] = decoder.run_n(&inputs_decoder, [outs_id], None).unwrap();

    let out: NdTensor<f32, 3> = out.try_into().unwrap();
    let out = as_array_view(out.view()).unwrap();
    let out = out.index_axis(Axis(0), out.shape()[0] - 1);
    let out = as_ndtensor_view(out.view()).unwrap().to_tensor();
    let inputs_generator = out;
    generator
        .run_one(inputs_generator.view().into(), None)
        .unwrap();
}

fn main() {
    rten_inference();
}

By the way, rten now initializes models much faster than tract, and if inference could reach the same speed when batch_size=1 for the decoder and generator, that would be awesome!

I can give more examples and more information about the decoder and generator if needed.

@robertknight
Owner

Yes, but if, for example, there are only 2 sentences of different lengths, it seems time will be wasted decoding the padding of one of the sentences. I think the same approach would be used in Python if it had the same ability to parallelize code as Rust.

It is true that computation on padding is wasted, but depending on how much padding there is, the benefits of batching can still outweigh this. What I do in Ocrs for recognizing lines of text of varying width is to choose a threshold and form batches such that every image in the batch has no more than P padding elements. Then the batches are processed in parallel.
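
A rough sketch of that grouping strategy (not the Ocrs implementation; the threshold value and the greedy policy are assumptions):

const MAX_PADDING: usize = 16; // the "P" threshold from the text; the value is an assumption

/// Group sequence lengths into batches such that no member of a batch needs
/// more than MAX_PADDING padding elements when padded to the batch maximum.
fn form_batches(mut lengths: Vec<usize>) -> Vec<Vec<usize>> {
    lengths.sort_unstable();
    let mut batches: Vec<Vec<usize>> = Vec::new();
    let mut current: Vec<usize> = Vec::new();
    for len in lengths {
        // Lengths arrive in ascending order, so `len` is the batch maximum and
        // the first (shortest) item in `current` carries the most padding.
        let worst_padding = current.first().map(|&min| len - min).unwrap_or(0);
        if !current.is_empty() && worst_padding > MAX_PADDING {
            batches.push(std::mem::take(&mut current));
        }
        current.push(len);
    }
    if !current.is_empty() {
        batches.push(current);
    }
    batches
}

fn main() {
    let batches = form_batches(vec![12, 90, 15, 30, 33, 95, 14]);
    println!("{batches:?}"); // [[12, 14, 15], [30, 33], [90, 95]]
}
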

In any case, I agree that the batch_size=1 case is important, as you can't always use batching.

I can give more examples and more information about the decoder and generator if needed.

Yes, that would be helpful. It would help to understand at a high level what the processing pipeline / inference loop looks like, and what typical/realistic input shapes are for each of the models. I'm familiar with the standard transformer encoder-decoder where you do something like:

  1. Tokenize sentence (using eg. BPE)
  2. Feed sentence through encoder
  3. Feed encoder outputs + start token into decoder.
  4. Sample next token from decoder outputs (eg. with argmax)
  5. If the next token is end-of-sentence, break; otherwise run the decoder again with the next token as input

Is that what you are doing here or is it different? I notice that the decoder doesn't have KV-cache inputs, so presumably you're feeding it with a sequence that grows longer for each iteration of the loop?
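
For reference, a schematic sketch of the greedy decode loop in steps 1-5 above; encode, decode_step and argmax_last are hypothetical stand-ins for the encoder, decoder and generator calls from the earlier examples, and the BOS/EOS ids are assumptions:

const BOS: i32 = 1; // assumed start-of-sentence id
const EOS: i32 = 2; // assumed end-of-sentence id
const MAX_LEN: usize = 128;

fn greedy_decode(
    encode: impl Fn(&[i32]) -> Vec<f32>,
    decode_step: impl Fn(&[f32], &[i32]) -> Vec<f32>,
    argmax_last: impl Fn(&[f32]) -> i32,
    src_tokens: &[i32],
) -> Vec<i32> {
    // Steps 1-2: the tokenized sentence goes through the encoder once.
    let memory = encode(src_tokens);
    // Step 3: start from the BOS token. Without a KV cache the whole `ys`
    // sequence is re-fed to the decoder on every iteration.
    let mut ys = vec![BOS];
    while ys.len() < MAX_LEN {
        let logits = decode_step(memory.as_slice(), ys.as_slice());
        // Step 4: pick the next token (argmax over the last position).
        let next = argmax_last(logits.as_slice());
        ys.push(next);
        // Step 5: stop at end-of-sentence.
        if next == EOS {
            break;
        }
    }
    ys
}

fn main() {
    // Dummy closures just to exercise the loop shape; a real pipeline would
    // call the encoder, decoder and generator models here.
    let out = greedy_decode(
        |_src| vec![0.0; 8],
        |_mem, ys| vec![0.0; ys.len()],
        |_logits| EOS,
        &[5, 7, 9],
    );
    assert_eq!(out, vec![BOS, EOS]);
}
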

@igor-yusupov
Contributor Author

Yes, I use a Transformer exactly as you described. You can find the translate and greedy_decode functions here: https://pytorch.org/tutorials/beginner/translation_transformer.html

I notice that the decoder doesn't have KV-cache inputs

Yeah, I think I should update the model architecture to be able to store the cache from the attention mechanism while decoding the sequence.
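
Conceptually, such a per-layer cache could look roughly like the sketch below; the struct, its fields, and the shapes are assumptions for illustration, not an existing rten or tract API:

/// Per-layer key/value cache sketch: keys/values computed for earlier tokens
/// are kept so each new decoding step only attends over one new position.
struct LayerKvCache {
    /// Cached keys, flattened from (num_heads, seq_len, head_dim).
    keys: Vec<f32>,
    /// Cached values, same layout as `keys`.
    values: Vec<f32>,
    seq_len: usize,
}

impl LayerKvCache {
    fn new() -> Self {
        Self { keys: Vec::new(), values: Vec::new(), seq_len: 0 }
    }

    /// Append the keys/values computed for the newest token.
    fn append(&mut self, new_keys: &[f32], new_values: &[f32]) {
        self.keys.extend_from_slice(new_keys);
        self.values.extend_from_slice(new_values);
        self.seq_len += 1;
    }
}

fn main() {
    let mut cache = LayerKvCache::new();
    // e.g. 8 heads x 64-dim keys/values for one new token.
    cache.append(&[0.0; 8 * 64], &[0.0; 8 * 64]);
    assert_eq!(cache.seq_len, 1);
}
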

I'll also try to dig into the ocrs repository, thanks a lot!

Btw, it would be awesome to be able to use quantized models; I attached weights in #42.

@igor-yusupov
Contributor Author

I looked at the code in nn.Transformer, and it seems they don't use a key/value cache: they don't run the encoder outputs through the linear layer, but use them directly in the attention mechanism.
https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/transformer.py#L909

So I will try running with a batch of sentences.

@robertknight
Owner

robertknight commented Apr 12, 2024

Thanks for the input on this. The latest RTen releases include various optimizations for transformer decoders, including:

  • Special case matrix multiplication for vector-matrix products
  • Model::partial_run API to pre-evaluate parts of the model graph which don't change across timesteps (mostly benefits encoder + decoder models, not encoder-only or decoder-only)
  • Faster Transpose
  • Reduced overhead for each Graph::run call

See the changes for the 0.6.0 and 0.7.0 releases for details.
