diff --git a/python/tflite_micro/signal/BUILD b/python/tflite_micro/signal/BUILD index e5164759934..a0f9ae146b6 100644 --- a/python/tflite_micro/signal/BUILD +++ b/python/tflite_micro/signal/BUILD @@ -15,6 +15,7 @@ cc_library( name = "ops_lib", visibility = [":signal_friends"], deps = [ + ":delay_op_cc", ":fft_ops_cc", ":window_op_cc", ], @@ -29,11 +30,35 @@ py_library( srcs_version = "PY3", visibility = ["//python/tflite_micro/signal/utils:__subpackages__"], deps = [ + ":delay_op", ":fft_ops", ":window_op", ], ) +py_tflm_signal_library( + name = "delay_op", + srcs = ["ops/delay_op.py"], + cc_op_defs = ["//signal/tensorflow_core/ops:delay_op"], + cc_op_kernels = [ + "//signal/tensorflow_core/kernels:delay_kernel", + ], +) + +py_test( + name = "delay_op_test", + size = "small", + srcs = ["ops/delay_op_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":delay_op", + "//python/tflite_micro/signal/utils:util", + requirement("numpy"), + requirement("tensorflow-cpu"), + ], +) + py_tflm_signal_library( name = "fft_ops", srcs = ["ops/fft_ops.py"], diff --git a/python/tflite_micro/signal/ops/delay_op.py b/python/tflite_micro/signal/ops/delay_op.py new file mode 100644 index 00000000000..c7508049ab3 --- /dev/null +++ b/python/tflite_micro/signal/ops/delay_op.py @@ -0,0 +1,37 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Use overlap add op in python.""" + +import tensorflow as tf +from tflite_micro.python.tflite_micro.signal.utils import util + +gen_delay_op = util.load_custom_op('delay_op.so') + + +def _delay_wrapper(delay_fn, default_name): + """Wrapper around gen_delay_op.delay*.""" + + def _delay(input_tensor, delay_length, name=default_name): + with tf.name_scope(name) as name: + input_tensor = tf.convert_to_tensor(input_tensor, dtype=tf.int16) + return delay_fn(input_tensor, delay_length=delay_length, name=name) + + return _delay + + +# TODO(b/286250473): change back name after name clash resolved +delay = _delay_wrapper(gen_delay_op.signal_delay, "signal_delay") + +tf.no_gradient("signal_delay") diff --git a/python/tflite_micro/signal/ops/delay_op_test.py b/python/tflite_micro/signal/ops/delay_op_test.py new file mode 100644 index 00000000000..66b033fc977 --- /dev/null +++ b/python/tflite_micro/signal/ops/delay_op_test.py @@ -0,0 +1,85 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for delay op.""" + +import numpy as np +import tensorflow as tf + +from tflite_micro.python.tflite_micro.signal.ops import delay_op +from tflite_micro.python.tflite_micro.signal.utils import util + + +class DelayOpTest(tf.test.TestCase): + + def TestHelper(self, input_signal, delay_length, frame_size): + inner_dim_size = input_signal.shape[-1] + input_signal_rank = len(input_signal.shape) + frame_num = int(np.ceil((inner_dim_size + delay_length) / frame_size)) + # We need to continue feeding the op with zeros until the delay line is + # flushed. Pad the input signal to a multiple of frame_size. + padded_size = frame_num * frame_size + pad_size = int(padded_size - inner_dim_size) + # Axes to pass to np.pad. All axes have no padding except the innermost one. + pad_outer_axes = np.zeros([input_signal_rank - 1, 2], dtype=int) + pad_input_signal = np.vstack([pad_outer_axes, [0, pad_size]]) + input_signal_padded = np.pad(input_signal, pad_input_signal) + delay_exp_signal = np.vstack( + [pad_outer_axes, [delay_length, pad_size - delay_length]]) + delay_exp = np.pad(input_signal, delay_exp_signal) + delay_out = np.zeros(input_signal_padded.shape) + + in_frame_shape = input_signal.shape[:-1] + (frame_size, ) + func = tf.function(delay_op.delay) + concrete_function = func.get_concrete_function(tf.TensorSpec( + in_frame_shape, dtype=tf.int16), + delay_length=delay_length) + interpreter = util.get_tflm_interpreter(concrete_function, func) + + for i in range(frame_num): + in_frame = input_signal_padded[..., i * frame_size:(i + 1) * frame_size] + # TFLM + interpreter.set_input(in_frame, 0) + interpreter.invoke() + out_frame_tflm = interpreter.get_output(0) + # TF + out_frame = self.evaluate( + delay_op.delay(in_frame, delay_length=delay_length)) + delay_out[..., i * frame_size:(i + 1) * frame_size] = out_frame + self.assertAllEqual(out_frame, out_frame_tflm) + self.assertAllEqual(delay_out, delay_exp) + + def testFrameLargerThanDelay(self): + self.TestHelper(np.arange(0, 30, dtype=np.int16), 7, 10) + + def testFrameSmallerThanDelay(self): + self.TestHelper(np.arange(0, 70, dtype=np.int16), 21, 3) + + def testZeroDelay(self): + self.TestHelper(np.arange(0, 20, dtype=np.int16), 0, 3) + + def testNegativeDelay(self): + with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)): + self.TestHelper(np.arange(1, 20, dtype=np.int16), -21, 3) + + def testMultiDimensionalDelay(self): + input_signal = np.reshape(np.arange(0, 120, dtype=np.int16), [2, 3, 20]) + self.TestHelper(input_signal, 4, 6) + input_signal = np.reshape(np.arange(0, 72, dtype=np.int16), + [2, 2, 3, 3, 2]) + self.TestHelper(input_signal, 7, 3) + + +if __name__ == '__main__': + tf.test.main() diff --git a/signal/tensorflow_core/kernels/BUILD b/signal/tensorflow_core/kernels/BUILD index 6745d9c991f..bea8eefb3b9 100644 --- a/signal/tensorflow_core/kernels/BUILD +++ b/signal/tensorflow_core/kernels/BUILD @@ -5,6 +5,15 @@ package( licenses = ["notice"], ) +tflm_signal_kernel_library( + name = "delay_kernel", + srcs = ["delay_kernel.cc"], + deps = [ + "//signal/src:circular_buffer", + "@tensorflow_cc_deps//:cc_library", + ], +) + tflm_signal_kernel_library( name = "fft_kernel", srcs = ["fft_kernels.cc"], diff --git a/signal/tensorflow_core/kernels/delay_kernel.cc b/signal/tensorflow_core/kernels/delay_kernel.cc new file mode 100644 index 00000000000..8c5c505e8e0 --- /dev/null +++ b/signal/tensorflow_core/kernels/delay_kernel.cc @@ -0,0 +1,94 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "signal/src/circular_buffer.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace signal { + +class DelayOp : public tensorflow::OpKernel { + public: + explicit DelayOp(tensorflow::OpKernelConstruction* context) + : tensorflow::OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("delay_length", &delay_length_)); + initialized_ = false; + } + + ~DelayOp() {} + + void Compute(tensorflow::OpKernelContext* context) override { + const tensorflow::Tensor& input_tensor = context->input(0); + if (!initialized_) { + frame_size_ = input_tensor.flat_inner_dims().dimensions().at(1); + outer_dims_ = input_tensor.flat_inner_dims().dimensions().at(0); + + state_tensors_.resize(outer_dims_); + circular_buffers_.resize(outer_dims_); + + // Calculate the capacity of the circular buffer. + size_t capacity = frame_size_ + delay_length_; + size_t state_size = + tflite::tflm_signal::CircularBufferGetNeededMemory(capacity); + for (int i = 0; i < outer_dims_; i++) { + OP_REQUIRES_OK( + context, + context->allocate_temp( + DT_INT8, TensorShape({static_cast(state_size)}), + &state_tensors_[i])); + int8_t* state_ = state_tensors_[i].flat().data(); + circular_buffers_[i] = tflite::tflm_signal::CircularBufferInit( + capacity, state_, state_size); + tflite::tflm_signal::CircularBufferWriteZeros(circular_buffers_[i], + delay_length_); + } + initialized_ = true; + } + + TensorShape output_shape = input_tensor.shape(); + tensorflow::Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, output_shape, &output_tensor)); + + for (int dim_index = 0, sample_index = 0; dim_index < outer_dims_; + dim_index++, sample_index += frame_size_) { + tflite::tflm_signal::CircularBufferWrite( + circular_buffers_[dim_index], + &input_tensor.flat().data()[sample_index], frame_size_); + tflite::tflm_signal::CircularBufferGet( + circular_buffers_[dim_index], frame_size_, + &(reinterpret_cast(output_tensor->data()))[sample_index]); + tflite::tflm_signal::CircularBufferDiscard(circular_buffers_[dim_index], + frame_size_); + } + } + + private: + bool initialized_; + int frame_size_; + int delay_length_; + int outer_dims_; + std::vector state_tensors_; + std::vector circular_buffers_; +}; + +// TODO(b/286250473): change back name after name clash resolved +REGISTER_KERNEL_BUILDER(Name("SignalDelay").Device(tensorflow::DEVICE_CPU), + DelayOp); + +} // namespace signal +} // namespace tensorflow diff --git a/signal/tensorflow_core/ops/BUILD b/signal/tensorflow_core/ops/BUILD index 8c24f22816f..436dde54c7d 100644 --- a/signal/tensorflow_core/ops/BUILD +++ b/signal/tensorflow_core/ops/BUILD @@ -5,6 +5,14 @@ package( licenses = ["notice"], ) +tflm_signal_kernel_library( + name = "delay_op", + srcs = ["delay_op.cc"], + deps = [ + "@tensorflow_cc_deps//:cc_library", + ], +) + tflm_signal_kernel_library( name = "fft_ops", srcs = ["fft_ops.cc"], diff --git a/signal/tensorflow_core/ops/delay_op.cc b/signal/tensorflow_core/ops/delay_op.cc new file mode 100644 index 00000000000..dfee3d2450b --- /dev/null +++ b/signal/tensorflow_core/ops/delay_op.cc @@ -0,0 +1,58 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +using tensorflow::shape_inference::InferenceContext; +using tensorflow::shape_inference::ShapeHandle; + +namespace tensorflow { +namespace signal { + +Status DelayShape(InferenceContext* c) { + ShapeHandle out; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &out)); + c->set_output(0, out); + return OkStatus(); +} + +// TODO(b/286250473): change back name after name clash resolved +REGISTER_OP("SignalDelay") + .Attr("delay_length: int >= 0") + .Input("input: int16") + .Output("output: int16") + .SetShapeFn(DelayShape) + .Doc(R"doc( +Delay the innermost dimension of input signal by delay_length samples. + +For example, assuming an input signal of 10 samples, +[1 2 3 4 5 6 7 8 9 0] +If we input the signal to a delay op configured with delay_length=3, the op +will produce the following output: +[0 0 0 1 2 3 4 5 6 7] +To retrieve the remainder of the input signal, call the delay op again with +zeros as input: +[0 0 0 0 0 0 0 0 0 0] +to get the output: +[8 9 0 0 0 0 0 0 0 0] + +input: A multidimensional input signal. +output: An output signal of the same shape as the input signal. The innermost + dimension is delayed by delay_length samples. +)doc"); + +} // namespace signal +} // namespace tensorflow