Skip to content

Commit

Permalink
feat: run inferences in an isolate using IsolateInterpreter (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
luiscib3r committed May 18, 2023
1 parent b77887e commit 972decf
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 29 deletions.
66 changes: 45 additions & 21 deletions example/style_transfer/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ class _HomeState extends State<Home> {
'assets/magenta/magenta_arbitrary-image-stylization-v1-256_int8_transfer_1.tflite';

late final Interpreter predictionInterpreter;
late final IsolateInterpreter predictionIsolateInterpreter;
late final Interpreter transferInterpreter;
late final IsolateInterpreter transferIsolateInterpreter;

final imagePicker = ImagePicker();
String? imagePath;
Expand All @@ -56,6 +58,13 @@ class _HomeState extends State<Home> {
loadModels();
}

@override
void dispose() {
predictionIsolateInterpreter.close();
transferIsolateInterpreter.close();
super.dispose();
}

// Clean old results when press some take picture button
void cleanResult() {
imagePath = null;
Expand Down Expand Up @@ -95,11 +104,17 @@ class _HomeState extends State<Home> {
options: predictionOptions,
);

predictionIsolateInterpreter =
IsolateInterpreter(address: predictionInterpreter.address);

transferInterpreter = await Interpreter.fromAsset(
transferModelPath,
options: transferOptions,
);

transferIsolateInterpreter =
IsolateInterpreter(address: transferInterpreter.address);

setState(() {});

log('Interpreters loaded successfully');
Expand Down Expand Up @@ -198,7 +213,7 @@ class _HomeState extends State<Home> {
];

// Run prediction inference
predictionInterpreter.run(predictionInput, predictionOutput);
await predictionIsolateInterpreter.run(predictionInput, predictionOutput);

// [1, 384, 384, 3]
final transferOutput = [
Expand All @@ -216,7 +231,7 @@ class _HomeState extends State<Home> {
];

// Run transfer inference
transferInterpreter.runForMultipleInputs(
await transferIsolateInterpreter.runForMultipleInputs(
transferInput,
{0: transferOutput},
);
Expand Down Expand Up @@ -263,25 +278,34 @@ class _HomeState extends State<Home> {
alignment: Alignment.center,
children: [
if (imagePath != null)
Stack(
children: [
Padding(
padding: const EdgeInsets.all(24),
child: imageResult != null
? Image.memory(imageResult!)
: Image.file(File(imagePath!)),
),
if (stylePath != null)
Positioned(
top: 0,
right: 0,
child: Image.asset(
stylePath!,
height: 48,
),
),
],
)
StreamBuilder(
stream: transferIsolateInterpreter.stateChanges,
builder: (context, state) {
return Stack(
fit: StackFit.expand,
children: [
Padding(
padding: const EdgeInsets.all(24),
child: imageResult != null
? Image.memory(imageResult!)
: Image.file(File(imagePath!)),
),
if (stylePath != null)
Positioned(
top: 0,
right: 0,
child: Image.asset(
stylePath!,
height: 48,
),
),
if (state.data == IsolateInterpreterState.loading)
const Center(
child: CircularProgressIndicator(),
),
],
);
})
else
Padding(
padding: const EdgeInsets.all(8.0),
Expand Down
20 changes: 12 additions & 8 deletions lib/src/interpreter.dart
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,21 @@ class Interpreter {

/// Run for multiple inputs and outputs
void runForMultipleInputs(List<Object> inputs, Map<int, Object> outputs) {
if (inputs.isEmpty) {
throw ArgumentError('Input error: Inputs should not be null or empty.');
}
if (outputs.isEmpty) {
throw ArgumentError('Input error: Outputs should not be null or empty.');
}
runInference(inputs);
var outputTensors = getOutputTensors();
for (var i = 0; i < outputTensors.length; i++) {
outputTensors[i].copyTo(outputs[i]!);
}
}

/// Just run inference
void runInference(List<Object> inputs) {
if (inputs.isEmpty) {
throw ArgumentError('Input error: Inputs should not be null or empty.');
}

var inputTensors = getInputTensors();

Expand All @@ -205,11 +214,6 @@ class Interpreter {
invoke();
_lastNativeInferenceDurationMicroSeconds =
DateTime.now().microsecondsSinceEpoch - inferenceStartNanos;

var outputTensors = getOutputTensors();
for (var i = 0; i < outputTensors.length; i++) {
outputTensors[i].copyTo(outputs[i]!);
}
}

/// Gets all input tensors associated with the model.
Expand Down
140 changes: 140 additions & 0 deletions lib/src/isolate_interpreter.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import 'dart:async';
import 'dart:isolate';

import 'package:tflite_flutter/tflite_flutter.dart';

class IsolateInterpreter {
IsolateInterpreter({
required this.address,
this.debugName = 'TfLiteInterpreterIsolate',
}) {
_init();
}

final int address;
final String debugName;

final ReceivePort _receivePort = ReceivePort();
late final SendPort _sendPort;
late final Isolate _isolate;

final StreamController<IsolateInterpreterState> _stateChanges =
StreamController.broadcast();
late final StreamSubscription _stateSubscription;
Stream<IsolateInterpreterState> get stateChanges => _stateChanges.stream;
IsolateInterpreterState __state = IsolateInterpreterState.idle;
IsolateInterpreterState get state => __state;
set _state(IsolateInterpreterState value) {
__state = value;
if (!_stateChanges.isClosed) {
_stateChanges.add(__state);
}
}

Future<void> _init() async {
_isolate = await Isolate.spawn(
_mainIsolate,
_receivePort.sendPort,
debugName: debugName,
);

_stateSubscription = _receivePort.listen((state) {
if (state is SendPort) {
_sendPort = state;
}

if (state is IsolateInterpreterState) {
_state = state;
}
});
}

static Future<void> _mainIsolate(SendPort sendPort) async {
final port = ReceivePort();

sendPort.send(port.sendPort);

await for (final _IsolateInterpreterData data in port) {
final interpreter = Interpreter.fromAddress(data.address);
sendPort.send(IsolateInterpreterState.loading);
interpreter.runInference(data.inputs);
sendPort.send(IsolateInterpreterState.idle);
}
}

/// Run for single input and output
Future<void> run(Object input, Object output) {
var map = <int, Object>{};
map[0] = output;

return runForMultipleInputs([input], map);
}

/// Run for multiple inputs and outputs
Future<void> runForMultipleInputs(
List<Object> inputs,
Map<int, Object> outputs,
) async {
if (state == IsolateInterpreterState.loading) return;
_state = IsolateInterpreterState.loading;

final data = _IsolateInterpreterData(
address: address,
inputs: inputs,
);

_sendPort.send(data);
await _wait();

final interpreter = Interpreter.fromAddress(address);
final outputTensors = interpreter.getOutputTensors();
for (var i = 0; i < outputTensors.length; i++) {
outputTensors[i].copyTo(outputs[i]!);
}
}

Future<void> _wait() async {
if (state == IsolateInterpreterState.loading) {
await for (final state in stateChanges) {
if (state == IsolateInterpreterState.idle) break;
}
}
}

Future<void> close() async {
await _stateSubscription.cancel();
await _stateChanges.close();
_isolate.kill();
}
}

enum IsolateInterpreterState {
idle,
loading,
}

class _IsolateInterpreterData {
_IsolateInterpreterData({
required this.address,
required this.inputs,
});

final int address;
final List<Object> inputs;
}
1 change: 1 addition & 0 deletions lib/tflite_flutter.dart
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export 'src/delegates/metal_delegate.dart';
export 'src/delegates/xnnpack_delegate.dart';
export 'src/interpreter.dart';
export 'src/interpreter_options.dart';
export 'src/isolate_interpreter.dart';
export 'src/quanitzation_params.dart';
export 'src/tensor.dart';
export 'src/util/byte_conversion_utils.dart';
Expand Down

0 comments on commit 972decf

Please sign in to comment.