Skip to content

Commit

Permalink
Merge pull request #2 from mucahittekin/main
Browse files Browse the repository at this point in the history
feat: Added image create method
  • Loading branch information
gtokman authored Nov 3, 2023
2 parents 1a85c51 + a6e3a46 commit ecdcf6c
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import com.aallam.openai.api.chat.ChatCompletionRequest
import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.api.chat.ChatRole
import com.aallam.openai.api.http.Timeout
import com.aallam.openai.api.image.ImageCreation
import com.aallam.openai.api.image.ImageSize
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.OpenAI
import com.aallam.openai.client.OpenAIConfig
Expand All @@ -26,6 +28,7 @@ import kotlinx.coroutines.flow.map
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.json.JSONObject
import java.util.Date
import java.util.HashMap
import kotlin.time.Duration.Companion.seconds

Expand Down Expand Up @@ -209,6 +212,31 @@ class ReactNativeOpenaiModule(reactContext: ReactApplicationContext) :
}
}

@ReactMethod
public fun imageCreate(input: ReadableMap,promise: Promise){
val prompt = input.getString("prompt") as String;
val n = if (input.hasKey("n")) input.getInt("n") else null
val size = if (input.hasKey("n")) input.getString("size") else null

runBlocking {
job = scope.launch {
var imageResult = openAIClient?.imageURL(creation = ImageCreation(prompt,n, ImageSize(size ?: "512x512")))
val map = mapOf(
"created" to Date().time,
"data" to (imageResult?.map {
mapOf(
"url" to it.url
)
} ?: emptyList())
)
val toReadableMap = Arguments.makeNativeMap(map)
promise.resolve(toReadableMap)
}

}

}

private fun toList(array: ReadableArray?): List<String> {
val list = mutableListOf<String>()
if (array != null) {
Expand Down
66 changes: 56 additions & 10 deletions example/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ import * as React from 'react';

import {
Animated,
Button,
Image,
Keyboard,
Platform,
SafeAreaView,
ScrollView,
StyleSheet,
Text,
TextInput,
View,
useColorScheme,
} from 'react-native';
import OpenAI from 'react-native-openai';
Expand Down Expand Up @@ -50,7 +53,7 @@ export default function App() {
showSubscription.remove();
hideSubscription.remove();
};
}, []);
}, [yPosition]);

React.useEffect(() => {
openAI.chat.addListener('onChatMessageReceived', (payload) => {
Expand All @@ -68,6 +71,10 @@ export default function App() {
};
}, [openAI]);

const [mode, setMode] = React.useState<'text' | 'image'>('text');

const [images, setImages] = React.useState<string[]>([]);

return (
<SafeAreaView
style={[
Expand Down Expand Up @@ -99,16 +106,24 @@ export default function App() {
return;
}
setResult('');
setImages([]);
console.log(e.nativeEvent.text);
openAI.chat.stream({
messages: [
{
role: 'user',
content: e.nativeEvent.text,
},
],
model: 'gpt-3.5-turbo',
});
if (mode === 'text') {
openAI.chat.stream({
messages: [
{
role: 'user',
content: e.nativeEvent.text,
},
],
model: 'gpt-3.5-turbo',
});
} else {
const result = await openAI.image.create({
prompt: e.nativeEvent.text,
});
setImages(result.data.map((image) => image.url));
}
}}
style={[
styles.input,
Expand All @@ -119,6 +134,22 @@ export default function App() {
]}
/>
</Animated.View>
<View style={{ flexDirection: 'row' }}>
<Button
title="Text"
color={mode === 'text' ? 'darkblue' : undefined}
onPress={() => {
setMode('text');
}}
/>
<Button
title="Image"
color={mode === 'image' ? 'darkblue' : undefined}
onPress={() => {
setMode('image');
}}
/>
</View>
<ScrollView
style={{ width: '100%', backgroundColor: 'transparent' }}
contentContainerStyle={{
Expand All @@ -136,6 +167,21 @@ export default function App() {
>
Result: {result}
</Text>
{images.map((image) => (
<View
key={image}
style={{
width: '100%',
height: 300,
backgroundColor: 'transparent',
}}
>
<Image
source={{ uri: image }}
style={{ width: '100%', height: '100%' }}
/>
</View>
))}
</ScrollView>
</SafeAreaView>
);
Expand Down
2 changes: 1 addition & 1 deletion ios/OpenAIKit/Image/Image.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ public struct Image {
public let url: String
}

extension Image: Decodable {}
extension Image: Codable {}

extension Image {
public enum Size: String {
Expand Down
9 changes: 9 additions & 0 deletions ios/OpenAIKit/Image/ImageInput.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import Foundation

struct ImageInput: Codable {
let prompt: String
var n: Int?
var size: Image.Size?
var user: String?

}
4 changes: 3 additions & 1 deletion ios/OpenAIKit/Image/ImageResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ public struct ImageResponse {
public let data: [Image]
}

extension ImageResponse: Decodable {}
extension ImageResponse: Codable {

}
4 changes: 4 additions & 0 deletions ios/ReactNativeOpenai.mm
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ @interface RCT_EXTERN_MODULE(ReactNativeOpenai, RCTEventEmitter)
RCT_EXTERN_METHOD(create:(NSDictionary *)input withResolver:(RCTPromiseResolveBlock)resolve
withRejecter:(RCTPromiseRejectBlock)reject)

// Image
RCT_EXTERN_METHOD(imageCreate:(NSDictionary *)input withResolver:(RCTPromiseResolveBlock)resolve
withRejecter:(RCTPromiseRejectBlock)reject)


+ (BOOL)requiresMainQueueSetup
{
Expand Down
26 changes: 26 additions & 0 deletions ios/ReactNativeOpenai.swift
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,30 @@ extension ReactNativeOpenai {
}
}
}

// MARK: - Image
@objc(imageCreate:withResolver:withRejecter:)
public func imageCreate(input: NSDictionary,resolve: @escaping RCTPromiseResolveBlock, reject: @escaping RCTPromiseRejectBlock){
Task {
do {
let decoded = try DictionaryDecoder().decode(ImageInput.self, from: input)
let imageResult = try await openAIClient.images.create(
prompt: decoded.prompt,
n:decoded.n ?? 1,
size:decoded.size ?? .fiveTwelve,
user:decoded.user
)

if let payload = String(data: try JSONEncoder().encode(imageResult), encoding: .utf8) {
resolve(payload)
} else {
reject("error", "error", nil)
}

} catch {
reject("error", "error", error)
}
}
}

}
39 changes: 39 additions & 0 deletions src/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ class OpenAI {
module = NativeModules.ReactNativeOpenai;
private bridge: NativeEventEmitter;
public chat: Chat;
public image: Image;

public constructor(config: Config) {
this.bridge = new NativeEventEmitter(this.module);
this.module.initialize(config);
this.chat = new Chat(this.module, this.bridge);
this.image = new Image(this.module);
}
}

Expand Down Expand Up @@ -112,6 +114,25 @@ namespace ChatModels {
};
}

namespace ImageModels {
type ImageSize = '256x256' | '512x512' | '1024x1024';

export type ImageInput = {
prompt: string;
n?: number;
size?: ImageSize;
};

type Image = {
url: string;
};

export type ImageOutput = {
created: Date;
data: Image[];
};
}

class Chat {
private bridge: NativeEventEmitter;
private module: any;
Expand Down Expand Up @@ -150,4 +171,22 @@ class Chat {
}
}

class Image {
private module: any;

public constructor(module: any) {
this.module = module;
}

public async create(
input: ImageModels.ImageInput
): Promise<ImageModels.ImageOutput> {
const result = await this.module.imageCreate(input);
if (Platform.OS === 'ios') {
return JSON.parse(result);
}
return result;
}
}

export default OpenAI;

0 comments on commit ecdcf6c

Please sign in to comment.