From a45fcb7ab75846b115457ea6a9282ef109f363bc Mon Sep 17 00:00:00 2001 From: Asghar Ghorbani Date: Sat, 7 Dec 2024 10:53:46 +0100 Subject: [PATCH] [Feat] add edit message and retry generation (#128) * feat: add edit/copy/regenerate menu to the chat bubbles * feat: regenerate a response with the same or another model * feat: add edit previous message * chore: add menu, menu item and submenu components * chore: refactor theme colors --- .../external/react-native-haptic-feedback.js | 5 + __mocks__/stores/chatSessionStore.ts | 3 + __mocks__/stores/modelStore.ts | 126 ++- ios/Podfile.lock | 25 + jest.config.js | 2 + jest/setup.ts | 2 + package.json | 1 + src/api/__tests__/hf.test.ts | 77 ++ .../__tests__/BottomSheetSearchbar.test.tsx | 74 ++ src/components/Bubble/Bubble.tsx | 24 +- src/components/Bubble/styles.ts | 2 +- src/components/ChatInput/ChatInput.tsx | 187 ++++ .../__tests__/ChatInput.test.tsx} | 20 +- src/components/ChatInput/index.ts | 1 + src/components/ChatInput/styles.ts | 63 ++ src/components/ChatView/ChatView.tsx | 892 +++++++++++------- .../ChatView/__tests__/ChatView.test.tsx | 5 +- src/components/ChatView/styles.ts | 7 +- src/components/Input/Input.tsx | 130 --- src/components/Input/index.ts | 1 - src/components/Input/styles.ts | 25 - src/components/MarkdownView/MarkdownView.tsx | 25 +- src/components/Menu/Menu.tsx | 78 ++ src/components/Menu/MenuContext.tsx | 5 + src/components/Menu/MenuItem/MenuItem.tsx | 172 ++++ .../Menu/MenuItem/__tests__/MenuItem.test.tsx | 64 ++ src/components/Menu/MenuItem/index.ts | 1 + src/components/Menu/MenuItem/styles.ts | 49 + src/components/Menu/SubMenu/SubMenu.tsx | 69 ++ .../Menu/SubMenu/__tests__/SubMenu.test.tsx | 49 + src/components/Menu/SubMenu/index.ts | 1 + src/components/Menu/SubMenu/styles.ts | 17 + src/components/Menu/__tests__/Menu.test.tsx | 35 + src/components/Menu/index.ts | 2 + src/components/Menu/styles.ts | 34 + src/components/Message/Message.tsx | 62 +- .../ModelsHeaderRight/ModelsHeaderRight.tsx | 51 +- src/components/ModelsHeaderRight/styles.ts | 23 +- src/components/TextMessage/TextMessage.tsx | 1 + src/components/index.ts | 3 +- src/hooks/__tests__/useChatSession.test.ts | 69 +- src/hooks/__tests__/useMessageActions.test.ts | 227 +++++ src/hooks/index.ts | 1 + src/hooks/useChatSession.ts | 63 +- src/hooks/useMessageActions.ts | 93 ++ src/screens/ChatScreen/ChatScreen.tsx | 26 +- .../__tests__/ModelNotLoadedMessage.test.tsx | 28 +- .../ChatScreen/__tests__/ChatScreen.test.tsx | 20 +- .../__tests__/ModelAccordion.test.tsx | 27 +- .../ModelsScreen/ModelCard/ModelCard.tsx | 2 +- .../ModelCard/__tests__/ModelCard.test.tsx | 5 +- src/store/ChatSessionStore.ts | 100 +- src/store/ModelStore.ts | 24 +- src/utils/colorUtils.ts | 79 ++ src/utils/theme.ts | 381 +++++--- src/utils/types.ts | 120 ++- yarn.lock | 5 + 57 files changed, 2730 insertions(+), 953 deletions(-) create mode 100644 __mocks__/external/react-native-haptic-feedback.js create mode 100644 src/api/__tests__/hf.test.ts create mode 100644 src/components/BottomSheetSearchbar/__tests__/BottomSheetSearchbar.test.tsx create mode 100644 src/components/ChatInput/ChatInput.tsx rename src/components/{Input/__tests__/Input.test.tsx => ChatInput/__tests__/ChatInput.test.tsx} (97%) create mode 100644 src/components/ChatInput/index.ts create mode 100644 src/components/ChatInput/styles.ts delete mode 100644 src/components/Input/Input.tsx delete mode 100644 src/components/Input/index.ts delete mode 100644 src/components/Input/styles.ts create mode 100644 src/components/Menu/Menu.tsx create mode 100644 src/components/Menu/MenuContext.tsx create mode 100644 src/components/Menu/MenuItem/MenuItem.tsx create mode 100644 src/components/Menu/MenuItem/__tests__/MenuItem.test.tsx create mode 100644 src/components/Menu/MenuItem/index.ts create mode 100644 src/components/Menu/MenuItem/styles.ts create mode 100644 src/components/Menu/SubMenu/SubMenu.tsx create mode 100644 src/components/Menu/SubMenu/__tests__/SubMenu.test.tsx create mode 100644 src/components/Menu/SubMenu/index.ts create mode 100644 src/components/Menu/SubMenu/styles.ts create mode 100644 src/components/Menu/__tests__/Menu.test.tsx create mode 100644 src/components/Menu/index.ts create mode 100644 src/components/Menu/styles.ts create mode 100644 src/hooks/__tests__/useMessageActions.test.ts create mode 100644 src/hooks/useMessageActions.ts create mode 100644 src/utils/colorUtils.ts diff --git a/__mocks__/external/react-native-haptic-feedback.js b/__mocks__/external/react-native-haptic-feedback.js new file mode 100644 index 0000000..2a52509 --- /dev/null +++ b/__mocks__/external/react-native-haptic-feedback.js @@ -0,0 +1,5 @@ +const ReactNativeHapticFeedback = { + trigger: jest.fn(), +}; + +export default ReactNativeHapticFeedback; diff --git a/__mocks__/stores/chatSessionStore.ts b/__mocks__/stores/chatSessionStore.ts index e8ef070..f7b424d 100644 --- a/__mocks__/stores/chatSessionStore.ts +++ b/__mocks__/stores/chatSessionStore.ts @@ -18,6 +18,9 @@ export const mockChatSessionStore = { createNewSession: jest.fn(), updateMessage: jest.fn(), updateMessageToken: jest.fn(), + exitEditMode: jest.fn(), + enterEditMode: jest.fn(), + removeMessagesFromId: jest.fn(), }; Object.defineProperty(mockChatSessionStore, 'currentSessionMessages', { diff --git a/__mocks__/stores/modelStore.ts b/__mocks__/stores/modelStore.ts index aea647a..8b589e7 100644 --- a/__mocks__/stores/modelStore.ts +++ b/__mocks__/stores/modelStore.ts @@ -1,41 +1,91 @@ +import {computed, makeAutoObservable, ObservableMap} from 'mobx'; + import {modelsList} from '../../jest/fixtures/models'; -export const mockModelStore = { - models: modelsList, - n_context: 1024, - MIN_CONTEXT_SIZE: 200, - useAutoRelease: true, - useMetal: false, - n_gpu_layers: 50, - activeModelId: undefined as string | undefined, - setNContext: jest.fn(), - updateUseAutoRelease: jest.fn(), - updateUseMetal: jest.fn(), - setNGPULayers: jest.fn(), - refreshDownloadStatuses: jest.fn(), - addLocalModel: jest.fn(), - resetModels: jest.fn(), - initContext: jest.fn().mockResolvedValue(Promise.resolve()), - checkSpaceAndDownload: jest.fn(), - getDownloadProgress: jest.fn(), - manualReleaseContext: jest.fn(), - setActiveModel(modelId: string) { +import {Model} from '../../src/utils/types'; + +class MockModelStore { + models = modelsList; + n_context = 1024; + MIN_CONTEXT_SIZE = 200; + useAutoRelease = true; + useMetal = false; + n_gpu_layers = 50; + activeModelId: string | undefined; + inferencing = false; + isStreaming = false; + downloadJobs = new ObservableMap(); + + refreshDownloadStatuses: jest.Mock; + addLocalModel: jest.Mock; + setNContext: jest.Mock; + updateUseAutoRelease: jest.Mock; + updateUseMetal: jest.Mock; + setNGPULayers: jest.Mock; + resetModels: jest.Mock; + initContext: jest.Mock; + lastUsedModelId: any; + checkSpaceAndDownload: jest.Mock; + getDownloadProgress: jest.Mock; + manualReleaseContext: jest.Mock; + + constructor() { + makeAutoObservable(this, { + refreshDownloadStatuses: false, + addLocalModel: false, + setNContext: false, + updateUseAutoRelease: false, + updateUseMetal: false, + setNGPULayers: false, + resetModels: false, + initContext: false, + checkSpaceAndDownload: false, + getDownloadProgress: false, + manualReleaseContext: false, + lastUsedModel: computed, + activeModel: computed, + isDownloading: computed, + }); + this.refreshDownloadStatuses = jest.fn(); + this.addLocalModel = jest.fn(); + this.setNContext = jest.fn(); + this.updateUseAutoRelease = jest.fn(); + this.updateUseMetal = jest.fn(); + this.setNGPULayers = jest.fn(); + this.resetModels = jest.fn(); + this.initContext = jest.fn().mockResolvedValue(Promise.resolve()); + this.checkSpaceAndDownload = jest.fn(); + this.getDownloadProgress = jest.fn(); + this.manualReleaseContext = jest.fn(); + } + + setActiveModel = (modelId: string) => { this.activeModelId = modelId; - }, -}; -Object.defineProperty(mockModelStore, 'lastUsedModel', { - get: jest.fn(() => undefined), - configurable: true, -}); -Object.defineProperty(mockModelStore, 'isDownloading', { - get: jest.fn(() => () => false), - configurable: true, -}); -Object.defineProperty(mockModelStore, 'activeModel', { - get: jest.fn(() => - mockModelStore.models.find( - model => model.id === mockModelStore.activeModelId, - ), - ), - configurable: true, -}); + }; + + setInferencing = (value: boolean) => { + this.inferencing = value; + }; + + setIsStreaming = (value: boolean) => { + this.isStreaming = value; + }; + + get lastUsedModel(): Model | undefined { + return this.lastUsedModelId + ? this.models.find(m => m.id === this.lastUsedModelId) + : undefined; + } + + get isDownloading() { + return (modelId: string) => { + return this.downloadJobs.has(modelId); + }; + } + + get activeModel() { + return this.models.find(model => model.id === this.activeModelId); + } +} + +export const mockModelStore = new MockModelStore(); diff --git a/ios/Podfile.lock b/ios/Podfile.lock index 756e094..65ba00e 100644 --- a/ios/Podfile.lock +++ b/ios/Podfile.lock @@ -1657,6 +1657,27 @@ PODS: - ReactCommon/turbomodule/bridging - ReactCommon/turbomodule/core - Yoga + - RNReactNativeHapticFeedback (2.3.3): + - DoubleConversion + - glog + - hermes-engine + - RCT-Folly (= 2024.01.01.00) + - RCTRequired + - RCTTypeSafety + - React-Core + - React-debug + - React-Fabric + - React-featureflags + - React-graphics + - React-ImageManager + - React-NativeModulesApple + - React-RCTFabric + - React-rendererdebug + - React-utils + - ReactCodegen + - ReactCommon/turbomodule/bridging + - ReactCommon/turbomodule/core + - Yoga - RNReanimated (3.16.3): - DoubleConversion - glog @@ -1873,6 +1894,7 @@ DEPENDENCIES: - "RNCPicker (from `../node_modules/@react-native-picker/picker`)" - RNDeviceInfo (from `../node_modules/react-native-device-info`) - RNGestureHandler (from `../node_modules/react-native-gesture-handler`) + - RNReactNativeHapticFeedback (from `../node_modules/react-native-haptic-feedback`) - RNReanimated (from `../node_modules/react-native-reanimated`) - RNScreens (from `../node_modules/react-native-screens`) - RNSVG (from `../node_modules/react-native-svg`) @@ -2041,6 +2063,8 @@ EXTERNAL SOURCES: :path: "../node_modules/react-native-device-info" RNGestureHandler: :path: "../node_modules/react-native-gesture-handler" + RNReactNativeHapticFeedback: + :path: "../node_modules/react-native-haptic-feedback" RNReanimated: :path: "../node_modules/react-native-reanimated" RNScreens: @@ -2131,6 +2155,7 @@ SPEC CHECKSUMS: RNCPicker: d8662eb6615e3401acb590c44b97b2af3beb1e53 RNDeviceInfo: ae26ae45db3f9937f038a284bcd0a1db8d70db96 RNGestureHandler: 5b24d10761754ad271b714e536c457fd89b17c54 + RNReactNativeHapticFeedback: 00ba111b82aa266bb3ee1aa576831c2ea9a9dfad RNReanimated: 929c26a706dfe1af8feee9f2cf78004394e4dd04 RNScreens: e21c8d32fe97737ecc30f1f21e7b6f69f341a1f5 RNSVG: 6a529f4faed8be4ebfb00f1a29e25cb046d95e61 diff --git a/jest.config.js b/jest.config.js index 3d4b14e..2e5e1d8 100644 --- a/jest.config.js +++ b/jest.config.js @@ -39,5 +39,7 @@ module.exports = { '/__mocks__/external/react-native-document-picker.js', '@dr.pogodin/react-native-fs': '/__mocks__/external/@dr.pogodin/react-native-fs.js', + 'react-native-haptic-feedback': + '/__mocks__/external/react-native-haptic-feedback.js', }, }; diff --git a/jest/setup.ts b/jest/setup.ts index 418b3ee..0c87733 100644 --- a/jest/setup.ts +++ b/jest/setup.ts @@ -3,6 +3,8 @@ import mockClipboard from '@react-native-clipboard/clipboard/jest/clipboard-mock import 'react-native-gesture-handler/jestSetup'; +jest.mock('react-native-haptic-feedback'); + // Mock react-native-reanimated //require('react-native-reanimated').setUpTests(); jest.mock('react-native-reanimated', () => { diff --git a/package.json b/package.json index 8adf13f..5f12ac2 100644 --- a/package.json +++ b/package.json @@ -51,6 +51,7 @@ "react-native-document-picker": "^9.1.2", "react-native-gesture-handler": "^2.20.2", "react-native-get-random-values": "^1.11.0", + "react-native-haptic-feedback": "^2.3.3", "react-native-image-viewing": "^0.2.2", "react-native-linear-gradient": "^2.8.3", "react-native-marked": "^6.0.4", diff --git a/src/api/__tests__/hf.test.ts b/src/api/__tests__/hf.test.ts new file mode 100644 index 0000000..e37d09f --- /dev/null +++ b/src/api/__tests__/hf.test.ts @@ -0,0 +1,77 @@ +import axios from 'axios'; +import {fetchGGUFSpecs, fetchModelFilesDetails, fetchModels} from '../hf'; + +jest.mock('axios'); +const mockedAxios = axios as jest.Mocked; + +describe('fetchModels', () => { + it('should fetch models with basic parameters', async () => { + const mockResponse = { + data: [{id: 'model1'}], + headers: {link: 'next-page-link'}, + }; + mockedAxios.get.mockResolvedValueOnce(mockResponse); + + const result = await fetchModels({search: 'test'}); + + expect(mockedAxios.get).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ + params: expect.objectContaining({search: 'test'}), + }), + ); + expect(result).toEqual({ + models: [{id: 'model1'}], + nextLink: 'next-page-link', + }); + }); + + it('should handle missing pagination link', async () => { + const mockResponse = { + data: [{id: 'model1'}], + headers: {}, + }; + mockedAxios.get.mockResolvedValueOnce(mockResponse); + + const result = await fetchModels({}); + expect(result.nextLink).toBeNull(); + }); +}); + +describe('API error handling', () => { + it('should handle network errors in fetchModels', async () => { + const error = new Error('Network error'); + mockedAxios.get.mockRejectedValueOnce(error); + + await expect(fetchModels({})).rejects.toThrow('Network error'); + }); + + it('should handle non-ok responses in fetchModelFilesDetails', async () => { + global.fetch = jest.fn().mockResolvedValueOnce({ + ok: false, + statusText: 'Not Found', + }); + + await expect(fetchModelFilesDetails('model1')).rejects.toThrow( + 'Error fetching model files: Not Found', + ); + }); +}); + +describe('fetchGGUFSpecs', () => { + it('should parse GGUF specs correctly', async () => { + const mockSpecs = { + gguf: { + params: 7, + type: 'f16', + }, + }; + global.fetch = jest.fn().mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockSpecs), + }); + + const result = await fetchGGUFSpecs('model1'); + expect(result).toEqual(mockSpecs); + }); +}); diff --git a/src/components/BottomSheetSearchbar/__tests__/BottomSheetSearchbar.test.tsx b/src/components/BottomSheetSearchbar/__tests__/BottomSheetSearchbar.test.tsx new file mode 100644 index 0000000..8c9e23e --- /dev/null +++ b/src/components/BottomSheetSearchbar/__tests__/BottomSheetSearchbar.test.tsx @@ -0,0 +1,74 @@ +import React from 'react'; +import {render, fireEvent} from '@testing-library/react-native'; +import {BottomSheetSearchbar} from '../BottomSheetSearchbar'; +import {useBottomSheetInternal} from '@gorhom/bottom-sheet'; + +jest.mock('@gorhom/bottom-sheet', () => ({ + useBottomSheetInternal: jest.fn(), +})); + +describe('BottomSheetSearchbar', () => { + const mockShouldHandleKeyboardEvents = {value: false}; + + beforeEach(() => { + (useBottomSheetInternal as jest.Mock).mockReturnValue({ + shouldHandleKeyboardEvents: mockShouldHandleKeyboardEvents, + }); + }); + + it('should handle focus event correctly', () => { + const onFocus = jest.fn(); + const {getByTestId} = render( + , + ); + + fireEvent(getByTestId('searchbar'), 'focus'); + + expect(mockShouldHandleKeyboardEvents.value).toBe(true); + expect(onFocus).toHaveBeenCalled(); + }); + + it('should handle blur event correctly', () => { + const onBlur = jest.fn(); + const {getByTestId} = render( + , + ); + + fireEvent(getByTestId('searchbar'), 'blur'); + + expect(mockShouldHandleKeyboardEvents.value).toBe(false); + expect(onBlur).toHaveBeenCalled(); + }); + + it('should reset keyboard events flag on unmount', () => { + const {unmount} = render(); + + unmount(); + + expect(mockShouldHandleKeyboardEvents.value).toBe(false); + }); + + it('should forward props to Searchbar component', () => { + const placeholder = 'Search...'; + const value = 'test'; + const onChangeText = jest.fn(); + + const {getByPlaceholderText} = render( + , + ); + + const searchbar = getByPlaceholderText(placeholder); + expect(searchbar.props.value).toBe(value); + + fireEvent.changeText(searchbar, 'new value'); + expect(onChangeText).toHaveBeenCalledWith('new value'); + }); +}); diff --git a/src/components/Bubble/Bubble.tsx b/src/components/Bubble/Bubble.tsx index 75725e6..742637a 100644 --- a/src/components/Bubble/Bubble.tsx +++ b/src/components/Bubble/Bubble.tsx @@ -1,10 +1,11 @@ import type {ReactNode} from 'react'; import React, {useContext} from 'react'; -import {View, TouchableOpacity} from 'react-native'; +import {View, TouchableOpacity, Animated} from 'react-native'; import {Text} from 'react-native-paper'; import Clipboard from '@react-native-clipboard/clipboard'; import Icon from 'react-native-vector-icons/MaterialCommunityIcons'; +import ReactNativeHapticFeedback from 'react-native-haptic-feedback'; import {useTheme} from '../../hooks'; @@ -13,15 +14,22 @@ import {styles} from './styles'; import {UserContext} from '../../utils'; import {MessageType} from '../../utils/types'; +const hapticOptions = { + enableVibrateFallback: true, + ignoreAndroidSystemSettings: false, +}; + export const Bubble = ({ child, message, // eslint-disable-next-line @typescript-eslint/no-unused-vars nextMessageInGroup, + scale = new Animated.Value(1), }: { child: ReactNode; message: MessageType.Any; nextMessageInGroup: boolean; + scale?: Animated.Value; }) => { const theme = useTheme(); const user = useContext(UserContext); @@ -41,15 +49,19 @@ export const Bubble = ({ const copyToClipboard = () => { if (message.type === 'text') { + ReactNativeHapticFeedback.trigger('impactLight', hapticOptions); Clipboard.setString(message.text.trim()); } }; - const Container = false //copyable - ? TouchableOpacity - : View; return ( - + {child} {timings && ( @@ -61,6 +73,6 @@ export const Bubble = ({ {timings && {timingsString}} )} - + ); }; diff --git a/src/components/Bubble/styles.ts b/src/components/Bubble/styles.ts index 2e6452e..f7ddede 100644 --- a/src/components/Bubble/styles.ts +++ b/src/components/Bubble/styles.ts @@ -18,7 +18,7 @@ export const styles = ({ backgroundColor: !currentUserIsAuthor || message.type === 'image' ? 'transparent' //theme.colors.secondary - : theme.colors.surfaceVariant, + : theme.colors.authorBubbleBackground, borderBottomLeftRadius: currentUserIsAuthor || roundBorder ? theme.borders.messageBorderRadius diff --git a/src/components/ChatInput/ChatInput.tsx b/src/components/ChatInput/ChatInput.tsx new file mode 100644 index 0000000..fd9b74a --- /dev/null +++ b/src/components/ChatInput/ChatInput.tsx @@ -0,0 +1,187 @@ +import * as React from 'react'; +import {TextInput, TextInputProps, View, Animated} from 'react-native'; + +import {observer} from 'mobx-react'; +import {IconButton, Text} from 'react-native-paper'; + +import {useTheme} from '../../hooks'; + +import {createStyles} from './styles'; + +import {chatSessionStore} from '../../store'; + +import {MessageType} from '../../utils/types'; +import {L10nContext, unwrap, UserContext} from '../../utils'; + +import { + AttachmentButton, + AttachmentButtonAdditionalProps, + CircularActivityIndicator, + CircularActivityIndicatorProps, + SendButton, + StopButton, +} from '..'; + +export interface ChatInputTopLevelProps { + /** Whether attachment is uploading. Will replace attachment button with a + * {@link CircularActivityIndicator}. Since we don't have libraries for + * managing media in dependencies we have no way of knowing if + * something is uploading so you need to set this manually. */ + isAttachmentUploading?: boolean; + /** Whether the AI is currently streaming tokens */ + isStreaming?: boolean; + /** @see {@link AttachmentButtonProps.onPress} */ + onAttachmentPress?: () => void; + /** Will be called on {@link SendButton} tap. Has {@link MessageType.PartialText} which can + * be transformed to {@link MessageType.Text} and added to the messages list. */ + onSendPress: (message: MessageType.PartialText) => void; + onStopPress?: () => void; + onCancelEdit?: () => void; + isStopVisible?: boolean; + /** Controls the visibility behavior of the {@link SendButton} based on the + * `TextInput` state. Defaults to `editing`. */ + sendButtonVisibilityMode?: 'always' | 'editing'; + textInputProps?: TextInputProps; +} + +export interface ChatInputAdditionalProps { + attachmentButtonProps?: AttachmentButtonAdditionalProps; + attachmentCircularActivityIndicatorProps?: CircularActivityIndicatorProps; +} + +export type ChatInputProps = ChatInputTopLevelProps & ChatInputAdditionalProps; + +/** Bottom bar input component with a text input, attachment and + * send buttons inside. By default hides send button when text input is empty. */ +export const ChatInput = observer( + ({ + attachmentButtonProps, + attachmentCircularActivityIndicatorProps, + isAttachmentUploading, + isStreaming = false, + onAttachmentPress, + onSendPress, + onStopPress, + onCancelEdit, + isStopVisible, + sendButtonVisibilityMode, + textInputProps, + }: ChatInputProps) => { + const l10n = React.useContext(L10nContext); + const theme = useTheme(); + const user = React.useContext(UserContext); + const inputRef = React.useRef(null); + const editBarHeight = React.useRef(new Animated.Value(0)).current; + + // Use `defaultValue` if provided + const [text, setText] = React.useState(textInputProps?.defaultValue ?? ''); + const isEditMode = chatSessionStore.isEditMode; + + const styles = createStyles({theme, isEditMode}); + + const value = textInputProps?.value ?? text; + + React.useEffect(() => { + if (isEditMode) { + // Animate edit bar height + Animated.spring(editBarHeight, { + toValue: 28, + useNativeDriver: false, + friction: 8, + }).start(); + // Focus input + inputRef.current?.focus(); + } else { + Animated.spring(editBarHeight, { + toValue: 0, + useNativeDriver: false, + friction: 8, + }).start(); + onCancelEdit?.(); + } + }, [isEditMode, editBarHeight, onCancelEdit]); + + const handleChangeText = (newText: string) => { + setText(newText); + textInputProps?.onChangeText?.(newText); + }; + + const handleSend = () => { + const trimmedValue = value.trim(); + if (trimmedValue) { + onSendPress({text: trimmedValue, type: 'text'}); + setText(''); + } + }; + + const handleCancel = () => { + setText(''); + onCancelEdit?.(); + }; + + const isSendButtonVisible = + !isStreaming && + !isStopVisible && + user && + (sendButtonVisibilityMode === 'always' || value.trim()); + + return ( + + + {isEditMode && ( + + + Editing message + + + + )} + + {user && + (isAttachmentUploading ? ( + + ) : ( + !!onAttachmentPress && ( + + ) + ))} + + {isSendButtonVisible ? : null} + {isStopVisible && } + + + + ); + }, +); diff --git a/src/components/Input/__tests__/Input.test.tsx b/src/components/ChatInput/__tests__/ChatInput.test.tsx similarity index 97% rename from src/components/Input/__tests__/Input.test.tsx rename to src/components/ChatInput/__tests__/ChatInput.test.tsx index 103b1af..49d188b 100644 --- a/src/components/Input/__tests__/Input.test.tsx +++ b/src/components/ChatInput/__tests__/ChatInput.test.tsx @@ -5,7 +5,7 @@ import {ScrollView} from 'react-native'; import {user} from '../../../../jest/fixtures'; import {l10n} from '../../../utils/l10n'; import {UserContext} from '../../../utils'; -import {Input} from '../Input'; +import {ChatInput} from '../ChatInput'; const renderScrollable = () => ; @@ -15,7 +15,7 @@ describe('input', () => { const onSendPress = jest.fn(); const {getByPlaceholderText, getByLabelText} = render( - { const onSendPress = jest.fn(); const {getByPlaceholderText, getByLabelText} = render( - { const onChangeText = jest.fn(newValue => { rerender( - { }); const {getByPlaceholderText, getByLabelText, rerender} = render( - { const onChangeText = jest.fn(); const {getByPlaceholderText, getByLabelText} = render( - { const value = 'value'; const {getByPlaceholderText, getByLabelText} = render( - { const defaultValue = 'defaultValue'; const {getByPlaceholderText, getByLabelText} = render( - { const onSendPress = jest.fn(); const {getByLabelText} = render( - { const onSendPress = jest.fn(); const {getByTestId} = render( - + StyleSheet.create({ + container: { + alignItems: 'center', + flexDirection: 'row', + }, + input: { + ...theme.fonts.inputTextStyle, + color: theme.colors.inverseOnSurface, + flex: 1, + maxHeight: 150, + paddingVertical: 0, + }, + marginRight: { + marginRight: 16, + }, + inputContainer: { + flex: 1, + flexDirection: 'row', + alignItems: 'flex-end', + borderRadius: 12, + overflow: 'hidden', + }, + editBar: { + position: 'absolute', + top: 0, + left: 0, + right: 0, + backgroundColor: theme.colors.surfaceVariant, + flexDirection: 'row', + alignItems: 'center', + justifyContent: 'space-between', + paddingHorizontal: 12, + borderTopLeftRadius: 12, + borderTopRightRadius: 12, + borderBottomWidth: 1, + borderBottomColor: theme.colors.outlineVariant, + }, + editBarText: { + color: theme.colors.onSurfaceVariant, + }, + editBarButton: { + margin: 0, + }, + inputRow: { + flex: 1, + flexDirection: 'row', + alignItems: 'flex-end', + paddingHorizontal: 24, + paddingVertical: 20, + marginTop: isEditMode ? 28 : 0, + }, + }); diff --git a/src/components/ChatView/ChatView.tsx b/src/components/ChatView/ChatView.tsx index b7aafc2..8462bea 100644 --- a/src/components/ChatView/ChatView.tsx +++ b/src/components/ChatView/ChatView.tsx @@ -12,26 +12,19 @@ import { } from 'react-native'; import dayjs from 'dayjs'; +import {observer} from 'mobx-react'; import calendar from 'dayjs/plugin/calendar'; import {oneOf} from '@flyerhq/react-native-link-preview'; import {useSafeAreaInsets} from 'react-native-safe-area-context'; import {useComponentSize} from '../KeyboardAccessoryView/hooks'; -import {LoadingBubble} from '../LoadingBubble'; -import {usePrevious, useTheme} from '../../hooks'; +import {usePrevious, useTheme, useMessageActions} from '../../hooks'; -import {styles} from './styles'; import ImageView from './ImageView'; -import { - Message, - MessageTopLevelProps, - KeyboardAccessoryView, - CircularActivityIndicator, - Input, - InputAdditionalProps, - InputTopLevelProps, -} from '..'; +import {createStyles} from './styles'; + +import {chatSessionStore, modelStore} from '../../store'; import {l10n} from '../../utils/l10n'; import {MessageType, User} from '../../utils/types'; @@ -43,6 +36,18 @@ import { UserContext, } from '../../utils'; +import { + Message, + MessageTopLevelProps, + KeyboardAccessoryView, + CircularActivityIndicator, + ChatInput, + ChatInputAdditionalProps, + ChatInputTopLevelProps, + Menu, + LoadingBubble, +} from '..'; + // Untestable /* istanbul ignore next */ const animate = () => { @@ -51,7 +56,7 @@ const animate = () => { dayjs.extend(calendar); -export type ChatTopLevelProps = InputTopLevelProps & MessageTopLevelProps; +export type ChatTopLevelProps = ChatInputTopLevelProps & MessageTopLevelProps; export interface ChatProps extends ChatTopLevelProps { /** Allows you to replace the default Input widget e.g. if you want to create a channel view. */ @@ -77,7 +82,7 @@ export interface ChatProps extends ChatTopLevelProps { /** Use this to enable `LayoutAnimation`. Experimental on Android (same as React Native). */ enableAnimation?: boolean; flatListProps?: Partial>; - inputProps?: InputAdditionalProps; + inputProps?: ChatInputAdditionalProps; /** Used for pagination (infinite scroll) together with {@link ChatProps.onEndReached}. * When true, indicates that there are no more pages to load and * pagination will not be triggered. */ @@ -108,357 +113,554 @@ export interface ChatProps extends ChatTopLevelProps { user: User; } +// Add these types at the top of the file with other imports +type MenuItem = { + label: string; + onPress?: () => void; + icon?: string; + disabled: boolean; + submenu?: SubMenuItem[]; +}; + +type SubMenuItem = { + label: string; + onPress: () => void; + disabled?: boolean; + width?: number; +}; + /** Entry component, represents the complete chat */ -export const ChatView = ({ - customBottomComponent, - customDateHeaderText, - dateFormat, - disableImageGallery, - emptyState, - enableAnimation, - flatListProps, - inputProps, - isAttachmentUploading, - isLastPage, - isStopVisible, - isStreaming = false, - isThinking = false, - l10nOverride, - locale = 'en', - messages, - onAttachmentPress, - onEndReached, - onMessageLongPress, - onMessagePress, - onPreviewDataFetched, - onSendPress, - onStopPress, - renderBubble, - renderCustomMessage, - renderFileMessage, - renderImageMessage, - renderTextMessage, - sendButtonVisibilityMode = 'editing', - showUserAvatars = false, - showUserNames = false, - textInputProps, - timeFormat, - usePreviewData = true, - user, -}: ChatProps) => { - const theme = useTheme(); - - const { - container, - emptyComponentContainer, - emptyComponentTitle, - flatList, - flatListContentContainer, - footer, - footerLoadingPage, - keyboardAccessoryView, - } = styles({theme}); - - const {onLayout, size} = useComponentSize(); - const animationRef = React.useRef(false); - const list = React.useRef>(null); - const insets = useSafeAreaInsets(); - const [isImageViewVisible, setIsImageViewVisible] = React.useState(false); - const [isNextPageLoading, setNextPageLoading] = React.useState(false); - const [imageViewIndex, setImageViewIndex] = React.useState(0); - const [stackEntry, setStackEntry] = React.useState({}); - - const l10nValue = React.useMemo( - () => ({...l10n[locale], ...unwrap(l10nOverride)}), - [l10nOverride, locale], - ); - - const {chatMessages, gallery} = calculateChatMessages(messages, user, { +export const ChatView = observer( + ({ + customBottomComponent, customDateHeaderText, dateFormat, - showUserNames, + disableImageGallery, + emptyState, + enableAnimation, + flatListProps, + inputProps, + isAttachmentUploading, + isLastPage, + isStopVisible, + isStreaming = false, + isThinking = false, + l10nOverride, + locale = 'en', + messages, + onAttachmentPress, + onEndReached, + onMessageLongPress: externalOnMessageLongPress, + onMessagePress, + onPreviewDataFetched, + onSendPress, + onStopPress, + renderBubble, + renderCustomMessage, + renderFileMessage, + renderImageMessage, + renderTextMessage, + sendButtonVisibilityMode = 'editing', + showUserAvatars = false, + showUserNames = false, + textInputProps, timeFormat, - }); - - const previousChatMessages = usePrevious(chatMessages); - - React.useEffect(() => { - if ( - chatMessages[0]?.type !== 'dateHeader' && - chatMessages[0]?.id !== previousChatMessages?.[0]?.id && - chatMessages[0]?.author?.id === user.id - ) { - list.current?.scrollToOffset({ - animated: true, - offset: 0, + usePreviewData = true, + user, + }: ChatProps) => { + const theme = useTheme(); + const [inputText, setInputText] = React.useState(''); + + const wrappedOnSendPress = React.useCallback( + async (message: MessageType.PartialText) => { + if (chatSessionStore.isEditMode) { + chatSessionStore.commitEdit(); + } + onSendPress(message); + setInputText(''); + }, + [onSendPress], + ); + + const handleCancelEdit = React.useCallback(() => { + setInputText(''); + chatSessionStore.exitEditMode(); + }, []); + + const {handleCopy, handleEdit, handleTryAgain, handleTryAgainWith} = + useMessageActions({ + user, + messages, + handleSendPress: wrappedOnSendPress, + setInputText, }); - } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [chatMessages]); - React.useEffect(() => { - initLocale(locale); - }, [locale]); + const styles = createStyles({theme}); + + const {onLayout, size} = useComponentSize(); + const animationRef = React.useRef(false); + const list = React.useRef>(null); + const insets = useSafeAreaInsets(); + const [isImageViewVisible, setIsImageViewVisible] = React.useState(false); + const [isNextPageLoading, setNextPageLoading] = React.useState(false); + const [imageViewIndex, setImageViewIndex] = React.useState(0); + const [stackEntry, setStackEntry] = React.useState({}); + + const l10nValue = React.useMemo( + () => ({...l10n[locale], ...unwrap(l10nOverride)}), + [l10nOverride, locale], + ); + + const {chatMessages, gallery} = calculateChatMessages(messages, user, { + customDateHeaderText, + dateFormat, + showUserNames, + timeFormat, + }); - // Untestable - /* istanbul ignore next */ - if (animationRef.current && enableAnimation) { - InteractionManager.runAfterInteractions(animate); - } + const previousChatMessages = usePrevious(chatMessages); + + React.useEffect(() => { + if ( + chatMessages[0]?.type !== 'dateHeader' && + chatMessages[0]?.id !== previousChatMessages?.[0]?.id && + chatMessages[0]?.author?.id === user.id + ) { + list.current?.scrollToOffset({ + animated: true, + offset: 0, + }); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [chatMessages]); + + React.useEffect(() => { + initLocale(locale); + }, [locale]); - React.useEffect(() => { // Untestable /* istanbul ignore next */ if (animationRef.current && enableAnimation) { InteractionManager.runAfterInteractions(animate); - } else { - animationRef.current = true; } - }, [enableAnimation, messages]); - const handleEndReached = React.useCallback( - // Ignoring because `scroll` event for some reason doesn't trigger even basic - // `onEndReached`, impossible to test. - // TODO: Verify again later + React.useEffect(() => { + // Untestable + /* istanbul ignore next */ + if (animationRef.current && enableAnimation) { + InteractionManager.runAfterInteractions(animate); + } else { + animationRef.current = true; + } + }, [enableAnimation, messages]); + + const handleEndReached = React.useCallback( + // Ignoring because `scroll` event for some reason doesn't trigger even basic + // `onEndReached`, impossible to test. + // TODO: Verify again later + /* istanbul ignore next */ + async ({distanceFromEnd}: {distanceFromEnd: number}) => { + if ( + !onEndReached || + isLastPage || + distanceFromEnd <= 0 || + messages.length === 0 || + isNextPageLoading + ) { + return; + } + + setNextPageLoading(true); + await onEndReached?.(); + setNextPageLoading(false); + }, + [isLastPage, isNextPageLoading, messages.length, onEndReached], + ); + + const handleImagePress = React.useCallback( + (message: MessageType.Image) => { + setImageViewIndex( + gallery.findIndex( + image => image.id === message.id && image.uri === message.uri, + ), + ); + setIsImageViewVisible(true); + setStackEntry( + StatusBar.pushStackEntry({ + barStyle: 'light-content', + animated: true, + }), + ); + }, + [gallery], + ); + + const handleMessagePress = React.useCallback( + (message: MessageType.Any) => { + if (message.type === 'image' && !disableImageGallery) { + handleImagePress(message); + } + onMessagePress?.(message); + }, + [disableImageGallery, handleImagePress, onMessagePress], + ); + + // TODO: Tapping on a close button results in the next warning: + // `An update to ImageViewing inside a test was not wrapped in act(...).` /* istanbul ignore next */ - async ({distanceFromEnd}: {distanceFromEnd: number}) => { - if ( - !onEndReached || - isLastPage || - distanceFromEnd <= 0 || - messages.length === 0 || - isNextPageLoading - ) { - return; + const handleRequestClose = () => { + setIsImageViewVisible(false); + StatusBar.popStackEntry(stackEntry); + }; + + const keyExtractor = React.useCallback( + ({id}: MessageType.DerivedAny) => id, + [], + ); + + const [menuVisible, setMenuVisible] = React.useState(false); + const [menuPosition, setMenuPosition] = React.useState({x: 0, y: 0}); + const [selectedMessage, setSelectedMessage] = + React.useState(null); + + const handleMessageLongPress = React.useCallback( + (message: MessageType.Any, event: any) => { + if (message.type !== 'text') { + externalOnMessageLongPress?.(message); + return; + } + + const {pageX, pageY} = event.nativeEvent; + setMenuPosition({x: pageX, y: pageY}); + setSelectedMessage(message); + setMenuVisible(true); + externalOnMessageLongPress?.(message); + }, + [externalOnMessageLongPress], + ); + + const handleMenuDismiss = React.useCallback(() => { + setMenuVisible(false); + setSelectedMessage(null); + }, []); + + const menuItems = React.useMemo((): MenuItem[] => { + if (!selectedMessage || selectedMessage.type !== 'text') { + return []; } - setNextPageLoading(true); - await onEndReached?.(); - setNextPageLoading(false); - }, - [isLastPage, isNextPageLoading, messages.length, onEndReached], - ); - - const handleImagePress = React.useCallback( - (message: MessageType.Image) => { - setImageViewIndex( - gallery.findIndex( - image => image.id === message.id && image.uri === message.uri, - ), - ); - setIsImageViewVisible(true); - setStackEntry( - StatusBar.pushStackEntry({ - barStyle: 'light-content', - animated: true, - }), - ); - }, - [gallery], - ); - - const handleMessagePress = React.useCallback( - (message: MessageType.Any) => { - if (message.type === 'image' && !disableImageGallery) { - handleImagePress(message); + const isAuthor = selectedMessage.author.id === user.id; + const hasActiveModel = modelStore.activeModelId !== undefined; + const models = modelStore.availableModels || []; + + const baseItems: MenuItem[] = [ + { + label: 'Copy', + onPress: () => { + handleCopy(selectedMessage); + handleMenuDismiss(); + }, + icon: 'content-copy', + disabled: false, + }, + ]; + + if (!isAuthor) { + baseItems.push({ + label: 'Regenerate', + onPress: () => { + handleTryAgain(selectedMessage); + handleMenuDismiss(); + }, + icon: 'refresh', + disabled: !hasActiveModel, + }); + + baseItems.push({ + label: 'Regenerate with', + icon: 'chevron-right', + disabled: false, + submenu: models.map(model => ({ + label: model.name, + width: Math.min(300, size.width), + onPress: () => { + handleTryAgainWith(model.id, selectedMessage); + handleMenuDismiss(); + }, + })), + }); } - onMessagePress?.(message); - }, - [disableImageGallery, handleImagePress, onMessagePress], - ); - - // TODO: Tapping on a close button results in the next warning: - // `An update to ImageViewing inside a test was not wrapped in act(...).` - /* istanbul ignore next */ - const handleRequestClose = () => { - setIsImageViewVisible(false); - StatusBar.popStackEntry(stackEntry); - }; - - const keyExtractor = React.useCallback( - ({id}: MessageType.DerivedAny) => id, - [], - ); - - const renderItem = React.useCallback( - ({item: message}: {item: MessageType.DerivedAny; index: number}) => { - const messageWidth = - showUserAvatars && - message.type !== 'dateHeader' && - message.author.id !== user.id - ? Math.floor(Math.min(size.width * 0.9, 440)) - : Math.floor(Math.min(size.width * 0.92, 440)); - - const roundBorder = - message.type !== 'dateHeader' && message.nextMessageInGroup; - const showAvatar = - message.type !== 'dateHeader' && !message.nextMessageInGroup; - const showName = message.type !== 'dateHeader' && message.showName; - const showStatus = message.type !== 'dateHeader' && message.showStatus; - - return ( - - ); - }, - [ - enableAnimation, - handleMessagePress, - onMessageLongPress, - onPreviewDataFetched, - renderBubble, - renderCustomMessage, - renderFileMessage, - renderImageMessage, - renderTextMessage, - showUserAvatars, - size.width, - usePreviewData, + + if (isAuthor) { + baseItems.push({ + label: 'Edit', + onPress: () => { + handleEdit(selectedMessage); + handleMenuDismiss(); + }, + icon: 'pencil', + disabled: !hasActiveModel, + }); + } + + return baseItems; + }, [ + selectedMessage, user.id, - ], - ); - - const renderListEmptyComponent = React.useCallback( - () => ( - - {oneOf( - emptyState, - - {l10nValue.emptyChatPlaceholder} - , - )()} - - ), - [emptyComponentContainer, emptyComponentTitle, emptyState, l10nValue], - ); - - const renderListFooterComponent = React.useCallback( - () => - // Impossible to test, see `handleEndReached` function - /* istanbul ignore next */ - isNextPageLoading ? ( - - + handleCopy, + handleTryAgain, + handleTryAgainWith, + handleEdit, + handleMenuDismiss, + size.width, + ]); + + const renderMenuItem = React.useCallback( + (item: MenuItem, index: number) => { + if (item.submenu) { + return ( + + {index > 0 && } + ( + + {subIndex > 0 && } + + + ), + )} + /> + + ); + } + + return ( + + {index > 0 && } + + + ); + }, + [styles.menu], + ); + + const renderMessage = React.useCallback( + ({item: message}: {item: MessageType.DerivedAny; index: number}) => { + const messageWidth = + showUserAvatars && + message.type !== 'dateHeader' && + message.author.id !== user.id + ? Math.floor(Math.min(size.width * 0.9, 440)) + : Math.floor(Math.min(size.width * 0.92, 440)); + + const roundBorder = + message.type !== 'dateHeader' && message.nextMessageInGroup; + const showAvatar = + message.type !== 'dateHeader' && !message.nextMessageInGroup; + const showName = message.type !== 'dateHeader' && message.showName; + const showStatus = message.type !== 'dateHeader' && message.showStatus; + + return ( + + ); + }, + [ + enableAnimation, + handleMessageLongPress, + handleMessagePress, + onPreviewDataFetched, + renderBubble, + renderCustomMessage, + renderFileMessage, + renderImageMessage, + renderTextMessage, + showUserAvatars, + size.width, + usePreviewData, + user.id, + ], + ); + + const renderListEmptyComponent = React.useCallback( + () => ( + + {oneOf( + emptyState, + + {l10nValue.emptyChatPlaceholder} + , + )()} - ) : ( - ), - [footer, footerLoadingPage, isNextPageLoading, theme.colors.primary], - ); - - const renderListHeaderComponent = React.useCallback( - () => (isThinking ? : null), - [isThinking], - ); - - const renderScrollable = React.useCallback( - (panHandlers: GestureResponderHandlers) => ( - - ), - [ - chatMessages, - flatList, - flatListContentContainer, - flatListProps, - handleEndReached, - insets.bottom, - keyExtractor, - renderItem, - renderListEmptyComponent, - renderListFooterComponent, - renderListHeaderComponent, - ], - ); - - return ( - - {/**/} - - - {customBottomComponent ? ( - <> - <>{renderScrollable({})} - <>{customBottomComponent()} - - ) : ( - - + // Impossible to test, see `handleEndReached` function + /* istanbul ignore next */ + isNextPageLoading ? ( + + + + ) : ( + + ), + [ + isNextPageLoading, + styles.footerLoadingPage, + styles.footer, + theme.colors.primary, + ], + ); + + const renderListHeaderComponent = React.useCallback( + () => (isThinking ? : null), + [isThinking], + ); + + const renderScrollable = React.useCallback( + (panHandlers: GestureResponderHandlers) => ( + + ), + [ + chatMessages, + styles.flatList, + styles.flatListContentContainer, + flatListProps, + handleEndReached, + insets.bottom, + keyExtractor, + renderMessage, + renderListEmptyComponent, + renderListFooterComponent, + renderListHeaderComponent, + ], + ); + + return ( + + + + {customBottomComponent ? ( + <> + <>{renderScrollable({})} + <>{customBottomComponent()} + + ) : ( + - - )} - - - - {/**/} - - ); -}; + style: styles.keyboardAccessoryView, + }}> + + + )} + + + {menuItems.map(renderMenuItem)} + + + + + ); + }, +); diff --git a/src/components/ChatView/__tests__/ChatView.test.tsx b/src/components/ChatView/__tests__/ChatView.test.tsx index 73079b5..c07a27d 100644 --- a/src/components/ChatView/__tests__/ChatView.test.tsx +++ b/src/components/ChatView/__tests__/ChatView.test.tsx @@ -61,7 +61,7 @@ describe('chat', () => { }, ]; const onSendPress = jest.fn(); - const {getByLabelText} = render( + const {getByLabelText, getByPlaceholderText} = render( { user={user} />, ); + const textInput = getByPlaceholderText(l10n.en.inputPlaceholder); + fireEvent.changeText(textInput, 'text'); + const button = getByLabelText(l10n.en.sendButtonAccessibilityLabel); fireEvent.press(button); expect(onSendPress).toHaveBeenCalledWith({text: 'text', type: 'text'}); diff --git a/src/components/ChatView/styles.ts b/src/components/ChatView/styles.ts index 434c4a8..f4b22de 100644 --- a/src/components/ChatView/styles.ts +++ b/src/components/ChatView/styles.ts @@ -1,7 +1,7 @@ import {Platform, StyleSheet} from 'react-native'; import {Theme} from '../../utils/types'; -export const styles = ({theme}: {theme: Theme}) => +export const createStyles = ({theme}: {theme: Theme}) => StyleSheet.create({ container: { backgroundColor: theme.colors.background, @@ -39,8 +39,11 @@ export const styles = ({theme}: {theme: Theme}) => height: 4, }, keyboardAccessoryView: { - backgroundColor: theme.colors.onBackground, + backgroundColor: theme.colors.primary, borderTopLeftRadius: theme.borders.inputBorderRadius, borderTopRightRadius: theme.borders.inputBorderRadius, }, + menu: { + width: 170, + }, }); diff --git a/src/components/Input/Input.tsx b/src/components/Input/Input.tsx deleted file mode 100644 index 90bff50..0000000 --- a/src/components/Input/Input.tsx +++ /dev/null @@ -1,130 +0,0 @@ -import * as React from 'react'; -import {TextInput, TextInputProps, View} from 'react-native'; - -import {useTheme} from '../../hooks'; - -import {styles} from './styles'; - -import {MessageType} from '../../utils/types'; -import {L10nContext, unwrap, UserContext} from '../../utils'; - -import { - AttachmentButton, - AttachmentButtonAdditionalProps, - CircularActivityIndicator, - CircularActivityIndicatorProps, - SendButton, - StopButton, -} from '..'; - -export interface InputTopLevelProps { - /** Whether attachment is uploading. Will replace attachment button with a - * {@link CircularActivityIndicator}. Since we don't have libraries for - * managing media in dependencies we have no way of knowing if - * something is uploading so you need to set this manually. */ - isAttachmentUploading?: boolean; - /** Whether the AI is currently streaming tokens */ - isStreaming?: boolean; - /** @see {@link AttachmentButtonProps.onPress} */ - onAttachmentPress?: () => void; - /** Will be called on {@link SendButton} tap. Has {@link MessageType.PartialText} which can - * be transformed to {@link MessageType.Text} and added to the messages list. */ - onSendPress: (message: MessageType.PartialText) => void; - onStopPress?: () => void; - isStopVisible?: boolean; - /** Controls the visibility behavior of the {@link SendButton} based on the - * `TextInput` state. Defaults to `editing`. */ - sendButtonVisibilityMode?: 'always' | 'editing'; - textInputProps?: TextInputProps; -} - -export interface InputAdditionalProps { - attachmentButtonProps?: AttachmentButtonAdditionalProps; - attachmentCircularActivityIndicatorProps?: CircularActivityIndicatorProps; -} - -export type InputProps = InputTopLevelProps & InputAdditionalProps; - -/** Bottom bar input component with a text input, attachment and - * send buttons inside. By default hides send button when text input is empty. */ -export const Input = ({ - attachmentButtonProps, - attachmentCircularActivityIndicatorProps, - isAttachmentUploading, - isStreaming = false, - onAttachmentPress, - onSendPress, - onStopPress, - isStopVisible, - sendButtonVisibilityMode, - textInputProps, -}: InputProps) => { - const l10n = React.useContext(L10nContext); - const theme = useTheme(); - const user = React.useContext(UserContext); - const {container, input, marginRight} = styles({theme}); - - // Use `defaultValue` if provided - const [text, setText] = React.useState(textInputProps?.defaultValue ?? ''); - - const value = textInputProps?.value ?? text; - - const handleChangeText = (newText: string) => { - // Track local state in case `onChangeText` is provided and `value` is not - setText(newText); - textInputProps?.onChangeText?.(newText); - }; - - const handleSend = () => { - const trimmedValue = value.trim(); - - // Impossible to test since button is not visible when value is empty. - // Additional check for the keyboard input. - /* istanbul ignore next */ - if (trimmedValue) { - onSendPress({text: trimmedValue, type: 'text'}); - setText(''); - } - }; - - const isSendButtonVisible = - !isStreaming && - !isStopVisible && - user && - (sendButtonVisibilityMode === 'always' || value.trim()); - - return ( - - {user && - (isAttachmentUploading ? ( - - ) : ( - !!onAttachmentPress && ( - - ) - ))} - - {isSendButtonVisible ? : null} - {isStopVisible && } - - ); -}; diff --git a/src/components/Input/index.ts b/src/components/Input/index.ts deleted file mode 100644 index ba9fe7e..0000000 --- a/src/components/Input/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from './Input'; diff --git a/src/components/Input/styles.ts b/src/components/Input/styles.ts deleted file mode 100644 index 82a19d1..0000000 --- a/src/components/Input/styles.ts +++ /dev/null @@ -1,25 +0,0 @@ -import {StyleSheet} from 'react-native'; - -import {Theme} from '../../utils/types'; - -export const styles = ({theme}: {theme: Theme}) => - StyleSheet.create({ - container: { - alignItems: 'center', - flexDirection: 'row', - paddingHorizontal: 24, - paddingVertical: 20, - }, - input: { - ...theme.fonts.inputTextStyle, - color: theme.colors.inverseOnSurface, - flex: 1, - maxHeight: 100, - // Fixes default paddings for Android - paddingBottom: 0, - paddingTop: 0, - }, - marginRight: { - marginRight: 16, - }, - }); diff --git a/src/components/MarkdownView/MarkdownView.tsx b/src/components/MarkdownView/MarkdownView.tsx index bd516d2..46ccc3f 100644 --- a/src/components/MarkdownView/MarkdownView.tsx +++ b/src/components/MarkdownView/MarkdownView.tsx @@ -18,31 +18,27 @@ interface MarkdownViewProps { markdownText: string; maxMessageWidth: number; //isComplete: boolean; // indicating if message is complete + selectable?: boolean; } export const MarkdownView: React.FC = React.memo( - ({markdownText, maxMessageWidth}) => { + ({markdownText, maxMessageWidth, selectable = false}) => { const _maxWidth = maxMessageWidth; const theme = useTheme(); const tagsStyles = useMemo(() => createTagsStyles(theme), [theme]); - const defaultTextProps = useMemo(() => ({selectable: true}), []); + const defaultTextProps = useMemo( + () => ({ + selectable, + userSelect: selectable ? 'text' : 'none', + }), + [selectable], + ); const systemFonts = useMemo(() => defaultSystemFonts, []); const contentWidth = useMemo(() => _maxWidth, [_maxWidth]); - //if (!isComplete) { - // // During streaming, use Text component - // return ( - // - // - // {markdownText} - // - // - // ); - //} - const htmlContent = useMemo(() => marked(markdownText), [markdownText]); const source = useMemo(() => ({html: htmlContent}), [htmlContent]); @@ -65,5 +61,6 @@ export const MarkdownView: React.FC = React.memo( (prevProps, nextProps) => prevProps.markdownText === nextProps.markdownText && //prevProps.isComplete === nextProps.isComplete && - prevProps.maxMessageWidth === nextProps.maxMessageWidth, + prevProps.maxMessageWidth === nextProps.maxMessageWidth && + prevProps.selectable === nextProps.selectable, ); diff --git a/src/components/Menu/Menu.tsx b/src/components/Menu/Menu.tsx new file mode 100644 index 0000000..6dd5558 --- /dev/null +++ b/src/components/Menu/Menu.tsx @@ -0,0 +1,78 @@ +import React, {useState} from 'react'; + +import { + Divider, + Menu as PaperMenu, + MenuProps as PaperMenuProps, +} from 'react-native-paper'; + +import {useTheme} from '../../hooks'; + +import {createStyles} from './styles'; +import {MenuItem, MenuItemProps} from './MenuItem'; + +const Separator = () => { + const theme = useTheme(); + const styles = createStyles(theme); + return ; +}; + +const GroupSeparator = () => { + const theme = useTheme(); + const styles = createStyles(theme); + return ( + + ); +}; + +export interface MenuProps extends Omit { + selectable?: boolean; +} + +export const Menu: React.FC & { + Item: typeof MenuItem; + GroupSeparator: typeof GroupSeparator; + Separator: typeof Separator; +} = ({children, selectable = false, ...menuProps}) => { + const theme = useTheme(); + const styles = createStyles(theme); + const [hasActiveSubmenu, setHasActiveSubmenu] = useState(false); + + const handleSubmenuOpen = () => setHasActiveSubmenu(true); + const handleSubmenuClose = () => setHasActiveSubmenu(false); + + return ( + + {React.Children.map(children, child => + React.isValidElement(child) + ? React.cloneElement(child, { + onSubmenuOpen: handleSubmenuOpen, + onSubmenuClose: handleSubmenuClose, + selectable, + }) + : child, + )} + + ); +}; + +Menu.Item = MenuItem; +Menu.GroupSeparator = GroupSeparator; +Menu.Separator = Separator; diff --git a/src/components/Menu/MenuContext.tsx b/src/components/Menu/MenuContext.tsx new file mode 100644 index 0000000..e593608 --- /dev/null +++ b/src/components/Menu/MenuContext.tsx @@ -0,0 +1,5 @@ +import React from 'react'; + +export const MenuContext = React.createContext<{selectable: boolean}>({ + selectable: false, +}); diff --git a/src/components/Menu/MenuItem/MenuItem.tsx b/src/components/Menu/MenuItem/MenuItem.tsx new file mode 100644 index 0000000..27c5363 --- /dev/null +++ b/src/components/Menu/MenuItem/MenuItem.tsx @@ -0,0 +1,172 @@ +import {View, Animated} from 'react-native'; +import React, {useRef, useState, useEffect} from 'react'; +import {StyleProp, TextStyle, ViewStyle} from 'react-native'; + +import {Menu as PaperMenu, Icon} from 'react-native-paper'; +import {MenuItemProps as PaperMenuItemProps} from 'react-native-paper'; +import {IconSource} from 'react-native-paper/lib/typescript/components/Icon'; + +import {SubMenu} from '../SubMenu/SubMenu'; + +import {useTheme} from '../../../hooks'; + +import {createStyles} from './styles'; + +export interface MenuItemProps + extends Omit { + label: string; + labelStyle?: StyleProp; + danger?: boolean; + style?: StyleProp; + isGroupLabel?: boolean; + icon?: IconSource; + selected?: boolean; + submenu?: React.ReactNode[]; + onSubmenuOpen?: () => void; + onSubmenuClose?: () => void; + selectable?: boolean; +} + +export const MenuItem: React.FC = ({ + label, + danger, + style, + labelStyle, + isGroupLabel, + icon, + selected, + leadingIcon, + trailingIcon, + submenu, + onSubmenuOpen, + onSubmenuClose, + selectable = false, + ...menuItemProps +}) => { + const [isSubmenuOpen, setIsSubmenuOpen] = useState(false); + const [submenuPosition, setSubmenuPosition] = useState({x: 0, y: 0}); + const itemRef = useRef(null); + const fadeAnim = useRef(new Animated.Value(1)).current; + + const theme = useTheme(); + + useEffect(() => { + Animated.timing(fadeAnim, { + toValue: isSubmenuOpen ? 0.6 : 1, + duration: 200, + useNativeDriver: true, + }).start(); + }, [fadeAnim, isSubmenuOpen]); + + const styles = createStyles(theme); + + const renderLeadingIcon = props => { + if (!selectable && !leadingIcon) { + return undefined; + } + + return ( + + {selected && } + {leadingIcon && + (typeof leadingIcon === 'function' ? ( + leadingIcon({...props, size: 18}) + ) : ( + + ))} + + ); + }; + + const renderTrailingIcon = props => ( + + {trailingIcon ? ( + typeof trailingIcon === 'function' ? ( + trailingIcon({...props, size: 18}) + ) : ( + + ) + ) : icon ? ( + + ) : null} + + ); + + const renderSubmenuIcon = () => ( + + + + ); + + const handlePress = (e: any) => { + if (submenu) { + itemRef.current?.measure((x, y, width, height, pageX, pageY) => { + const willOpen = !isSubmenuOpen; + setSubmenuPosition({x: pageX + width, y: pageY + height}); + setIsSubmenuOpen(willOpen); + if (willOpen) { + onSubmenuOpen?.(); + } else { + onSubmenuClose?.(); + } + }); + } else { + menuItemProps.onPress?.(e); + } + }; + + return ( + + + {submenu && ( + { + setIsSubmenuOpen(false); + onSubmenuClose?.(); + }} + anchorPosition={submenuPosition}> + {submenu} + + )} + + ); +}; diff --git a/src/components/Menu/MenuItem/__tests__/MenuItem.test.tsx b/src/components/Menu/MenuItem/__tests__/MenuItem.test.tsx new file mode 100644 index 0000000..e135bfa --- /dev/null +++ b/src/components/Menu/MenuItem/__tests__/MenuItem.test.tsx @@ -0,0 +1,64 @@ +import React from 'react'; +import {MenuItem} from '../MenuItem'; +import {useTheme} from '../../../../hooks'; +import {fireEvent, render} from '../../../../../jest/test-utils'; + +describe('MenuItem', () => { + beforeEach(() => { + (useTheme as jest.Mock).mockReturnValue({ + colors: { + menuText: '#000000', + menuDangerText: '#FF0000', + menuBackgroundActive: '#E0E0E0', + }, + fonts: { + bodySmall: {}, + }, + }); + }); + + it('renders basic menu item correctly', () => { + const onPress = jest.fn(); + const {getByText} = render( + , + ); + + expect(getByText('Test Item')).toBeTruthy(); + }); + + it('handles press events', () => { + const onPress = jest.fn(); + const {getByText} = render( + , + ); + + fireEvent.press(getByText('Test Item')); + expect(onPress).toHaveBeenCalled(); + }); + + it('renders leading icon when provided', () => { + const {UNSAFE_getByProps} = render( + {}} />, + ); + + expect(UNSAFE_getByProps({source: 'check'})).toBeTruthy(); + }); + + it('renders trailing icon when provided', () => { + const {UNSAFE_getByProps} = render( + {}} />, + ); + + expect(UNSAFE_getByProps({source: 'close'})).toBeTruthy(); + }); + + it('handles disabled state correctly', () => { + const onPress = jest.fn(); + const {getByText} = render( + , + ); + + fireEvent.press(getByText('Test Item')); + expect(onPress).not.toHaveBeenCalled(); + }); +}); diff --git a/src/components/Menu/MenuItem/index.ts b/src/components/Menu/MenuItem/index.ts new file mode 100644 index 0000000..2b5e263 --- /dev/null +++ b/src/components/Menu/MenuItem/index.ts @@ -0,0 +1 @@ +export * from './MenuItem'; diff --git a/src/components/Menu/MenuItem/styles.ts b/src/components/Menu/MenuItem/styles.ts new file mode 100644 index 0000000..0ea4aef --- /dev/null +++ b/src/components/Menu/MenuItem/styles.ts @@ -0,0 +1,49 @@ +import {StyleSheet} from 'react-native'; +import {Theme} from '../../../utils/types'; + +export const createStyles = (theme: Theme) => + StyleSheet.create({ + container: { + height: 30, + backgroundColor: 'transparent', + flexDirection: 'row', + alignItems: 'center', + paddingHorizontal: 12, + }, + containerWithLeading: { + paddingLeft: 8, + }, + leadingContainer: { + flexDirection: 'row', + alignItems: 'center', + justifyContent: 'flex-start', + }, + contentContainer: { + flex: 1, + flexDirection: 'row', + alignItems: 'center', + justifyContent: 'flex-start', + marginLeft: 0, + }, + label: { + ...theme.fonts.bodySmall, + textAlign: 'left', + paddingLeft: 0, + }, + labelDisabled: { + opacity: 0.5, + }, + itemDisabled: { + opacity: 0.5, + }, + trailingContainer: { + alignItems: 'flex-end', + }, + groupLabel: { + paddingTop: 12, + opacity: 0.5, + }, + activeParent: { + backgroundColor: theme.colors.menuBackgroundActive, + }, + }); diff --git a/src/components/Menu/SubMenu/SubMenu.tsx b/src/components/Menu/SubMenu/SubMenu.tsx new file mode 100644 index 0000000..c99b0c2 --- /dev/null +++ b/src/components/Menu/SubMenu/SubMenu.tsx @@ -0,0 +1,69 @@ +import React from 'react'; + +import {Menu as PaperMenu} from 'react-native-paper'; + +import {useTheme} from '../../../hooks'; + +import {createStyles} from './styles'; + +interface SubMenuProps { + visible: boolean; + onDismiss: () => void; + children: React.ReactNode; + anchorPosition?: {x: number; y: number}; +} + +export const SubMenu: React.FC = ({ + visible, + onDismiss, + children, + anchorPosition, +}) => { + const theme = useTheme(); + const styles = createStyles(theme); + + return ( + + {children} + + ); +}; + +/** + * SubMenu component for nested menu items. + * + * Usage example: + * ```tsx + * , + * // Nested submenu + * , + * , + * ]} + * />, + * , + * ]} + * /> + * ``` + * + * Features: + * - Supports infinite nesting of submenus + * - Parent menu dims when submenu is open + * - Maintains consistent styling with parent menu + */ diff --git a/src/components/Menu/SubMenu/__tests__/SubMenu.test.tsx b/src/components/Menu/SubMenu/__tests__/SubMenu.test.tsx new file mode 100644 index 0000000..8368d7e --- /dev/null +++ b/src/components/Menu/SubMenu/__tests__/SubMenu.test.tsx @@ -0,0 +1,49 @@ +import React from 'react'; +import {render} from '../../../../../jest/test-utils'; +import {SubMenu} from '../SubMenu'; +import {MenuItem} from '../../MenuItem'; + +describe('SubMenu', () => { + it('renders when visible', () => { + const {getByText} = render( + {}} + anchorPosition={{x: 100, y: 100}}> + {}} /> + , + ); + + expect(getByText('SubMenu Item')).toBeTruthy(); + }); + + it('does not render when not visible', () => { + const {queryByText} = render( + {}} + anchorPosition={{x: 100, y: 100}}> + {}} /> + , + ); + + expect(queryByText('SubMenu Item')).toBeNull(); + }); + + it('handles multiple menu items', () => { + const {getByText} = render( + {}} + anchorPosition={{x: 100, y: 100}}> + {}} /> + {}} /> + {}} /> + , + ); + + expect(getByText('Item 1')).toBeTruthy(); + expect(getByText('Item 2')).toBeTruthy(); + expect(getByText('Item 3')).toBeTruthy(); + }); +}); diff --git a/src/components/Menu/SubMenu/index.ts b/src/components/Menu/SubMenu/index.ts new file mode 100644 index 0000000..aa4cda5 --- /dev/null +++ b/src/components/Menu/SubMenu/index.ts @@ -0,0 +1 @@ +export * from './SubMenu'; diff --git a/src/components/Menu/SubMenu/styles.ts b/src/components/Menu/SubMenu/styles.ts new file mode 100644 index 0000000..c4a2cf7 --- /dev/null +++ b/src/components/Menu/SubMenu/styles.ts @@ -0,0 +1,17 @@ +import {StyleSheet} from 'react-native'; + +import {Theme} from '../../../utils/types'; + +export const createStyles = (theme: Theme) => + StyleSheet.create({ + menu: { + //minWidth: 220, + marginTop: 0, + marginLeft: 0, + }, + content: { + paddingVertical: 6, + backgroundColor: theme.colors.menuBackground, + borderRadius: 15, + }, + }); diff --git a/src/components/Menu/__tests__/Menu.test.tsx b/src/components/Menu/__tests__/Menu.test.tsx new file mode 100644 index 0000000..427d65a --- /dev/null +++ b/src/components/Menu/__tests__/Menu.test.tsx @@ -0,0 +1,35 @@ +import React from 'react'; +import {render} from '../../../../jest/test-utils'; +import {Menu} from '../Menu'; + +describe('Menu', () => { + it('renders menu items correctly', () => { + const {getByText} = render( + {}} anchor={undefined}> + {}} /> + {}} /> + , + ); + + expect(getByText('Item 1')).toBeTruthy(); + expect(getByText('Item 2')).toBeTruthy(); + }); + + it('renders separators correctly', () => { + const {UNSAFE_getAllByType} = render( + {}} anchor={undefined}> + {}} /> + + {}} /> + + {}} /> + , + ); + + const separators = UNSAFE_getAllByType(Menu.Separator); + const groupSeparators = UNSAFE_getAllByType(Menu.GroupSeparator); + + expect(separators).toHaveLength(1); + expect(groupSeparators).toHaveLength(1); + }); +}); diff --git a/src/components/Menu/index.ts b/src/components/Menu/index.ts new file mode 100644 index 0000000..2574f97 --- /dev/null +++ b/src/components/Menu/index.ts @@ -0,0 +1,2 @@ +export * from './Menu'; +export * from './MenuItem'; diff --git a/src/components/Menu/styles.ts b/src/components/Menu/styles.ts new file mode 100644 index 0000000..189f282 --- /dev/null +++ b/src/components/Menu/styles.ts @@ -0,0 +1,34 @@ +import {StyleSheet} from 'react-native'; + +import {Theme} from '../../utils/types'; + +export const createStyles = (theme: Theme) => + StyleSheet.create({ + menu: { + shadowColor: 'rgba(0, 0, 0, 0.05)', + shadowRadius: 70, + shadowOffset: {width: 0, height: 0}, + elevation: 5, + }, + menuWithSubmenu: { + elevation: 0, + shadowOpacity: 0, + }, + content: { + paddingVertical: 4, + backgroundColor: theme.colors.menuBackground, + borderRadius: 15, + }, + contentWithSubmenu: { + backgroundColor: theme.colors.menuBackgroundDimmed, + }, + groupSeparator: { + height: 6, + flexShrink: 0, + backgroundColor: 'transparent', + }, + separator: { + //height: 1, + backgroundColor: theme.colors.menuSeparator, + }, + }); diff --git a/src/components/Message/Message.tsx b/src/components/Message/Message.tsx index 7cc8a56..2c06b88 100644 --- a/src/components/Message/Message.tsx +++ b/src/components/Message/Message.tsx @@ -1,24 +1,34 @@ import * as React from 'react'; -import {Pressable, Text, View} from 'react-native'; +import {Pressable, Text, View, Animated} from 'react-native'; -import {useTheme} from '../../hooks'; import {oneOf} from '@flyerhq/react-native-link-preview'; +import ReactNativeHapticFeedback from 'react-native-haptic-feedback'; + +import {useTheme} from '../../hooks'; import styles from './styles'; -import {Avatar} from '../Avatar'; -import {StatusIcon} from '../StatusIcon'; -import {FileMessage} from '../FileMessage'; -import {ImageMessage} from '../ImageMessage'; -import {TextMessage, TextMessageTopLevelProps} from '../TextMessage'; +import { + Avatar, + StatusIcon, + FileMessage, + ImageMessage, + TextMessage, + TextMessageTopLevelProps, +} from '..'; import {MessageType} from '../../utils/types'; import {excludeDerivedMessageProps, UserContext} from '../../utils'; +const hapticOptions = { + enableVibrateFallback: true, + ignoreAndroidSystemSettings: false, +}; + export interface MessageTopLevelProps extends TextMessageTopLevelProps { /** Called when user makes a long press on any message */ - onMessageLongPress?: (message: MessageType.Any) => void; + onMessageLongPress?: (message: MessageType.Any, event?: any) => void; /** Called when user taps on any message */ - onMessagePress?: (message: MessageType.Any) => void; + onMessagePress?: (message: MessageType.Any, event?: any) => void; /** Customize the default bubble using this function. `child` is a content * you should render inside your bubble, `message` is a current message * (contains `author` inside) and `nextMessageInGroup` allows you to see @@ -28,6 +38,7 @@ export interface MessageTopLevelProps extends TextMessageTopLevelProps { child: React.ReactNode; message: MessageType.Any; nextMessageInGroup: boolean; + scale?: Animated.Value; }) => React.ReactNode; /** Render a custom message inside predefined bubble */ renderCustomMessage?: ( @@ -89,6 +100,7 @@ export const Message = React.memo( }: MessageProps) => { const user = React.useContext(UserContext); const theme = useTheme(); + const scaleAnim = React.useRef(new Animated.Value(1)).current; const currentUserIsAuthor = message.type !== 'dateHeader' && user?.id === message.author.id; @@ -101,6 +113,24 @@ export const Message = React.memo( theme, }); + const handlePressIn = () => { + Animated.spring(scaleAnim, { + toValue: 1.03, + friction: 8, + tension: 100, + useNativeDriver: true, + }).start(); + }; + + const handlePressOut = () => { + Animated.spring(scaleAnim, { + toValue: 1, + friction: 8, + tension: 100, + useNativeDriver: true, + }).start(); + }; + if (message.type === 'dateHeader') { return ( @@ -127,6 +157,7 @@ export const Message = React.memo( child, message: excludeDerivedMessageProps(message), nextMessageInGroup: roundBorder, + scale: scaleAnim, }); }; @@ -197,10 +228,15 @@ export const Message = React.memo( }} /> - onMessageLongPress?.(excludeDerivedMessageProps(message)) - } - onPress={() => onMessagePress?.(excludeDerivedMessageProps(message))} + onLongPress={event => { + ReactNativeHapticFeedback.trigger('impactLight', hapticOptions); + onMessageLongPress?.(excludeDerivedMessageProps(message), event); + }} + onPress={event => { + onMessagePress?.(excludeDerivedMessageProps(message), event); + }} + onPressIn={handlePressIn} + onPressOut={handlePressOut} style={pressable}> {renderBubbleContainer()} diff --git a/src/components/ModelsHeaderRight/ModelsHeaderRight.tsx b/src/components/ModelsHeaderRight/ModelsHeaderRight.tsx index 47744e0..f25a6b1 100644 --- a/src/components/ModelsHeaderRight/ModelsHeaderRight.tsx +++ b/src/components/ModelsHeaderRight/ModelsHeaderRight.tsx @@ -2,13 +2,11 @@ import {Image, View} from 'react-native'; import React, {useContext, useState} from 'react'; import {observer} from 'mobx-react'; -import {Menu, IconButton, Divider} from 'react-native-paper'; +import {IconButton} from 'react-native-paper'; import iconHF from '../../assets/icon-hf.png'; import iconHFLight from '../../assets/icon-hf-light.png'; -import {useTheme} from '../../hooks'; - import {createStyles} from './styles'; import {ModelsResetDialog} from '../ModelsResetDialog'; @@ -16,6 +14,8 @@ import {modelStore, uiStore} from '../../store'; import {L10nContext} from '../../utils'; +import {Menu} from '..'; + export const ModelsHeaderRight = observer(() => { const [menuVisible, setMenuVisible] = useState(false); const [resetDialogVisible, setResetDialogVisible] = useState(false); @@ -23,8 +23,7 @@ export const ModelsHeaderRight = observer(() => { const l10n = useContext(L10nContext); - const theme = useTheme(); - const styles = createStyles(theme); + const styles = createStyles(); const filters = uiStore.pageStates.modelsScreen.filters; const setFilters = (value: string[]) => { @@ -63,7 +62,7 @@ export const ModelsHeaderRight = observer(() => { setMenuVisible(false)} - contentStyle={styles.menuContent} + selectable anchor={ { /> }> {/* Filter section */} - + ( + icon={({size}) => ( )} onPress={() => toggleFilter('hf')} - title={l10n.menuTitleHf} - titleStyle={styles.menuItem} - trailingIcon={filters.includes('hf') ? 'check' : undefined} + label={l10n.menuTitleHf} + selected={filters.includes('hf')} + style={styles.menuItem} /> toggleFilter('downloaded')} - title={l10n.menuTitleDownloaded} - titleStyle={styles.menuItem} - trailingIcon={filters.includes('downloaded') ? 'check' : undefined} + label={l10n.menuTitleDownloaded} + selected={filters.includes('downloaded')} + style={styles.menuItem} /> {/* View section */} - + toggleFilter('grouped')} - title={l10n.menuTitleGrouped} - titleStyle={styles.menuItem} - trailingIcon={filters.includes('grouped') ? 'check' : undefined} + label={l10n.menuTitleGrouped} + selected={filters.includes('grouped')} + style={styles.menuItem} /> {/* Actions section */} - + { setMenuVisible(false); showResetDialog(); }} - title={l10n.menuTitleReset} - titleStyle={styles.menuItem} + label={l10n.menuTitleReset} + style={styles.menuItem} /> diff --git a/src/components/ModelsHeaderRight/styles.ts b/src/components/ModelsHeaderRight/styles.ts index ceda605..8463def 100644 --- a/src/components/ModelsHeaderRight/styles.ts +++ b/src/components/ModelsHeaderRight/styles.ts @@ -1,7 +1,6 @@ import {StyleSheet} from 'react-native'; -import {Theme} from '../../utils/types'; -export const createStyles = (theme: Theme) => +export const createStyles = () => StyleSheet.create({ container: { flexDirection: 'row', @@ -12,25 +11,7 @@ export const createStyles = (theme: Theme) => margin: 0, marginHorizontal: 4, }, - menuContent: { - borderRadius: 8, - elevation: 8, - marginTop: 8, - minWidth: 220, - backgroundColor: theme.colors.surface, - }, - menuSection: { - fontSize: 11, - opacity: 0.7, - fontWeight: 'bold', - paddingHorizontal: 16, - paddingVertical: 4, - }, menuItem: { - fontSize: 13, - }, - divider: { - marginVertical: 4, - opacity: 0.2, + width: 220, }, }); diff --git a/src/components/TextMessage/TextMessage.tsx b/src/components/TextMessage/TextMessage.tsx index a021913..3ad0813 100644 --- a/src/components/TextMessage/TextMessage.tsx +++ b/src/components/TextMessage/TextMessage.tsx @@ -157,6 +157,7 @@ export const TextMessage = ({ {/*Platform.OS === 'ios' ? ( diff --git a/src/components/index.ts b/src/components/index.ts index 539078b..91f3e25 100644 --- a/src/components/index.ts +++ b/src/components/index.ts @@ -3,15 +3,16 @@ export * from './AttachmentButton'; export * from './Avatar'; export * from './BottomSheetSearchbar'; export * from './Bubble'; +export * from './ChatInput'; export * from './ChatView'; export * from './CircularActivityIndicator'; export * from './FileMessage'; export * from './HeaderRight'; export * from './ImageMessage'; -export * from './Input'; export * from './KeyboardAccessoryView'; export * from './LoadingBubble'; export * from './MarkdownView'; +export * from './Menu'; export * from './Message'; export * from './ModelsHeaderRight'; export * from './ModelsResetDialog'; diff --git a/src/hooks/__tests__/useChatSession.test.ts b/src/hooks/__tests__/useChatSession.test.ts index 5fa6efd..4cf74a4 100644 --- a/src/hooks/__tests__/useChatSession.test.ts +++ b/src/hooks/__tests__/useChatSession.test.ts @@ -40,13 +40,7 @@ describe('useChatSession', () => { it('should send a message and update the chat session', async () => { const {result} = renderHook(() => - useChatSession( - modelStore.context, - {current: null}, - [], - textMessage.author, - mockAssistant, - ), + useChatSession({current: null}, textMessage.author, mockAssistant), ); await act(async () => { @@ -58,14 +52,9 @@ describe('useChatSession', () => { }); it('should handle model not loaded scenario', async () => { + modelStore.context = undefined; const {result} = renderHook(() => - useChatSession( - undefined, - {current: null}, - [], - textMessage.author, - assistant, - ), + useChatSession({current: null}, textMessage.author, assistant), ); await act(async () => { @@ -92,13 +81,7 @@ describe('useChatSession', () => { } const {result} = renderHook(() => - useChatSession( - modelStore.context, - {current: null}, - [], - textMessage.author, - mockAssistant, - ), + useChatSession({current: null}, textMessage.author, mockAssistant), ); await act(async () => { @@ -127,13 +110,7 @@ describe('useChatSession', () => { } const {result} = renderHook(() => - useChatSession( - modelStore.context, - {current: null}, - [], - textMessage.author, - mockAssistant, - ), + useChatSession({current: null}, textMessage.author, mockAssistant), ); await act(async () => { @@ -173,13 +150,7 @@ describe('useChatSession', () => { it('should reset the conversation', () => { const {result} = renderHook(() => - useChatSession( - modelStore.context, - {current: null}, - [], - textMessage.author, - mockAssistant, - ), + useChatSession({current: null}, textMessage.author, mockAssistant), ); result.current.handleResetConversation(); @@ -194,13 +165,7 @@ describe('useChatSession', () => { it('should not stop completion when inferencing is false', () => { const {result} = renderHook(() => - useChatSession( - modelStore.context, - {current: null}, - [], - textMessage.author, - mockAssistant, - ), + useChatSession({current: null}, textMessage.author, mockAssistant), ); result.current.handleStopPress(); @@ -221,13 +186,7 @@ describe('useChatSession', () => { } const {result} = renderHook(() => - useChatSession( - modelStore.context, - {current: null}, - [], - textMessage.author, - mockAssistant, - ), + useChatSession({current: null}, textMessage.author, mockAssistant), ); const sendPromise = act(async () => { @@ -237,13 +196,13 @@ describe('useChatSession', () => { await act(async () => { await new Promise(resolve => setTimeout(resolve, 0)); }); - expect(result.current.inferencing).toBe(true); + expect(modelStore.inferencing).toBe(true); await act(async () => { resolveCompletion!({timings: {total: 100}, usage: {}}); await sendPromise; }); - expect(result.current.inferencing).toBe(false); + expect(modelStore.inferencing).toBe(false); }); test.each([ @@ -273,13 +232,7 @@ describe('useChatSession', () => { modelStore.setActiveModel(testModel.id); const {result} = renderHook(() => - useChatSession( - modelStore.context, - {current: null}, - [], - textMessage.author, - mockAssistant, - ), + useChatSession({current: null}, textMessage.author, mockAssistant), ); await act(async () => { diff --git a/src/hooks/__tests__/useMessageActions.test.ts b/src/hooks/__tests__/useMessageActions.test.ts new file mode 100644 index 0000000..f6c17b9 --- /dev/null +++ b/src/hooks/__tests__/useMessageActions.test.ts @@ -0,0 +1,227 @@ +import Clipboard from '@react-native-clipboard/clipboard'; +import {renderHook, act} from '@testing-library/react-hooks'; + +import {textMessage, user} from '../../../jest/fixtures'; +import {createModel} from '../../../jest/fixtures/models'; + +import {useMessageActions} from '../useMessageActions'; + +import {chatSessionStore, modelStore} from '../../store'; + +jest.mock('@react-native-clipboard/clipboard', () => ({ + setString: jest.fn(), +})); + +describe('useMessageActions', () => { + const mockSetInputText = jest.fn(); + const mockHandleSendPress = jest.fn(); + const messages = [ + { + ...textMessage, + id: '1', + text: 'Hello', + author: user, + }, + { + ...textMessage, + id: '2', + text: 'Hi there', + author: {id: 'assistant'}, + }, + ]; + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('copies message text to clipboard', () => { + const {result} = renderHook(() => + useMessageActions({ + user, + messages, + handleSendPress: mockHandleSendPress, + setInputText: mockSetInputText, + }), + ); + + act(() => { + result.current.handleCopy({ + ...textMessage, + text: 'Copy this text', + type: 'text', + }); + }); + + expect(Clipboard.setString).toHaveBeenCalledWith('Copy this text'); + }); + + it('enters edit mode for user message', () => { + const {result} = renderHook(() => + useMessageActions({ + user, + messages, + handleSendPress: mockHandleSendPress, + setInputText: mockSetInputText, + }), + ); + + const userMessage = { + ...textMessage, + id: 'test-id', + text: 'Edit this message', + author: user, + type: 'text' as const, + }; + + act(() => { + result.current.handleEdit(userMessage); + }); + + expect(chatSessionStore.enterEditMode).toHaveBeenCalledWith('test-id'); + expect(mockSetInputText).toHaveBeenCalledWith('Edit this message'); + }); + + it('does not enter edit mode for assistant message', () => { + const {result} = renderHook(() => + useMessageActions({ + user, + messages, + handleSendPress: mockHandleSendPress, + setInputText: mockSetInputText, + }), + ); + + const assistantMessage = { + ...textMessage, + author: {id: 'assistant'}, + type: 'text' as const, + }; + + act(() => { + result.current.handleEdit(assistantMessage); + }); + + expect(chatSessionStore.enterEditMode).not.toHaveBeenCalled(); + expect(mockSetInputText).not.toHaveBeenCalled(); + }); + + describe('handleTryAgain', () => { + it('resubmits user message', async () => { + const {result} = renderHook(() => + useMessageActions({ + user, + messages, + handleSendPress: mockHandleSendPress, + setInputText: mockSetInputText, + }), + ); + + const userMessage = { + ...textMessage, + id: '1', + text: 'Try again with this', + author: user, + type: 'text' as const, + }; + + await act(async () => { + await result.current.handleTryAgain(userMessage); + }); + + expect(chatSessionStore.removeMessagesFromId).toHaveBeenCalledWith( + '1', + true, + ); + expect(mockHandleSendPress).toHaveBeenCalledWith({ + text: 'Try again with this', + type: 'text', + }); + }); + + it('resubmits last user message when retrying assistant message', async () => { + const _messages = [ + { + ...textMessage, + id: '2', + text: 'Assistant response', + author: {id: 'assistant'}, + type: 'text' as const, + }, + { + ...textMessage, + id: '1', + text: 'User message', + author: user, + type: 'text' as const, + }, + ]; + + const {result} = renderHook(() => + useMessageActions({ + user, + messages: _messages, + handleSendPress: mockHandleSendPress, + setInputText: mockSetInputText, + }), + ); + + await act(async () => { + await result.current.handleTryAgain(_messages[1]); + }); + + expect(chatSessionStore.removeMessagesFromId).toHaveBeenCalledWith( + '1', + true, + ); + expect(mockHandleSendPress).toHaveBeenCalledWith({ + text: 'User message', + type: 'text', + }); + }); + }); + + describe('handleTryAgainWith', () => { + it('uses current model if model ID matches', async () => { + const {result} = renderHook(() => + useMessageActions({ + user, + messages, + handleSendPress: mockHandleSendPress, + setInputText: mockSetInputText, + }), + ); + + modelStore.activeModelId = 'model-1'; + + await act(async () => { + await result.current.handleTryAgainWith('model-1', messages[0]); + }); + + expect(modelStore.initContext).not.toHaveBeenCalled(); + expect(chatSessionStore.removeMessagesFromId).toHaveBeenCalled(); + expect(mockHandleSendPress).toHaveBeenCalled(); + }); + + it('initializes new model if model ID differs', async () => { + const {result} = renderHook(() => + useMessageActions({ + user, + messages, + handleSendPress: mockHandleSendPress, + setInputText: mockSetInputText, + }), + ); + + modelStore.activeModelId = 'model-1'; + modelStore.models = [createModel({id: 'model-2', name: 'Model 2'})]; + + await act(async () => { + await result.current.handleTryAgainWith('model-2', messages[0]); + }); + + expect(modelStore.initContext).toHaveBeenCalled(); + expect(chatSessionStore.removeMessagesFromId).toHaveBeenCalled(); + expect(mockHandleSendPress).toHaveBeenCalled(); + }); + }); +}); diff --git a/src/hooks/index.ts b/src/hooks/index.ts index 1c0d3b7..015ff2c 100644 --- a/src/hooks/index.ts +++ b/src/hooks/index.ts @@ -2,5 +2,6 @@ export * from './usePrevious'; export * from './useTheme'; export * from './useChatSession'; export * from './useMemoryCheck'; +export * from './useMessageActions'; export * from './useMoveScroll'; export * from './useStorageCheck'; diff --git a/src/hooks/useChatSession.ts b/src/hooks/useChatSession.ts index 181bfe2..6aec239 100644 --- a/src/hooks/useChatSession.ts +++ b/src/hooks/useChatSession.ts @@ -1,7 +1,5 @@ -import React, {useRef, useCallback, useState} from 'react'; - +import React, {useRef, useCallback} from 'react'; import {toJS} from 'mobx'; -import {LlamaContext} from '@pocketpalai/llama.rn'; import throttle from 'lodash.throttle'; import {randId} from '../utils'; @@ -12,17 +10,13 @@ import {MessageType, User} from '../utils/types'; import {applyChatTemplate, convertToChatMessages} from '../utils/chat'; export const useChatSession = ( - context: LlamaContext | undefined, currentMessageInfo: React.MutableRefObject<{ createdAt: number; id: string; } | null>, - messages: MessageType.Any[], user: User, assistant: User, ) => { - const [inferencing, setInferencing] = useState(false); - const [isStreaming, setIsStreaming] = useState(false); const l10n = React.useContext(L10nContext); const conversationIdRef = useRef(randId()); @@ -31,20 +25,18 @@ export const useChatSession = ( const updateInterval = 150; // Interval for flushing token buffer (in ms) // Function to flush the token buffer and update the chat message - const flushTokenBuffer = useCallback( - (createdAt: number, id: string) => { - if (tokenBufferRef.current.length > 0 && context) { - chatSessionStore.updateMessageToken( - {token: tokenBufferRef.current}, - createdAt, - id, - context, - ); - tokenBufferRef.current = ''; // Reset the token buffer - } - }, - [context], - ); + const flushTokenBuffer = useCallback((createdAt: number, id: string) => { + const context = modelStore.context; + if (tokenBufferRef.current.length > 0 && context) { + chatSessionStore.updateMessageToken( + {token: tokenBufferRef.current}, + createdAt, + id, + context, + ); + tokenBufferRef.current = ''; // Reset the token buffer + } + }, []); // Throttled version of flushTokenBuffer to prevent excessive updates const throttledFlushTokenBuffer = throttle( @@ -71,6 +63,7 @@ export const useChatSession = ( }; const handleSendPress = async (message: MessageType.PartialText) => { + const context = modelStore.context; if (!context) { addSystemMessage(l10n.modelNotLoaded); return; @@ -89,8 +82,8 @@ export const useChatSession = ( }, }; addMessage(textMessage); - setInferencing(true); - setIsStreaming(false); + modelStore.setInferencing(true); + modelStore.setIsStreaming(false); const id = randId(); const createdAt = Date.now(); @@ -107,7 +100,9 @@ export const useChatSession = ( : []), ...convertToChatMessages([ textMessage, - ...messages.filter(msg => msg.id !== textMessage.id), + ...chatSessionStore.currentSessionMessages.filter( + msg => msg.id !== textMessage.id, + ), ]), ]; @@ -124,11 +119,10 @@ export const useChatSession = ( {...completionParams, prompt}, data => { if (data.token && currentMessageInfo.current) { - if (!isStreaming) { - setIsStreaming(true); + if (!modelStore.isStreaming) { + modelStore.setIsStreaming(true); } tokenBufferRef.current += data.token; - // Avoid variable shadowing by using properties directly throttledFlushTokenBuffer( currentMessageInfo.current.createdAt, currentMessageInfo.current.id, @@ -151,11 +145,11 @@ export const useChatSession = ( chatSessionStore.updateMessage(id, { metadata: {timings: result.timings, copyable: true}, }); - setInferencing(false); - setIsStreaming(false); + modelStore.setInferencing(false); + modelStore.setIsStreaming(false); } catch (error) { - setInferencing(false); - setIsStreaming(false); + modelStore.setInferencing(false); + modelStore.setIsStreaming(false); const errorMessage = (error as Error).message; if (errorMessage.includes('network')) { // TODO: This can be removed. We don't use network for chat. @@ -172,7 +166,8 @@ export const useChatSession = ( }; const handleStopPress = () => { - if (inferencing && context) { + const context = modelStore.context; + if (modelStore.inferencing && context) { context.stopCompletion(); } if ( @@ -184,13 +179,13 @@ export const useChatSession = ( currentMessageInfo.current.id, ); } + modelStore.setInferencing(false); + modelStore.setIsStreaming(false); }; return { handleSendPress, handleResetConversation, handleStopPress, - inferencing, - isStreaming, }; }; diff --git a/src/hooks/useMessageActions.ts b/src/hooks/useMessageActions.ts new file mode 100644 index 0000000..afad942 --- /dev/null +++ b/src/hooks/useMessageActions.ts @@ -0,0 +1,93 @@ +import {useCallback} from 'react'; + +import Clipboard from '@react-native-clipboard/clipboard'; + +import {chatSessionStore, modelStore} from '../store'; + +import {MessageType, User} from '../utils/types'; + +interface UseMessageActionsProps { + user: User; + messages: MessageType.Any[]; + handleSendPress: (message: MessageType.PartialText) => Promise; + setInputText?: (text: string) => void; +} + +export const useMessageActions = ({ + user, + messages, + handleSendPress, + setInputText, +}: UseMessageActionsProps) => { + const handleCopy = useCallback((message: MessageType.Text) => { + if (message.type === 'text') { + Clipboard.setString(message.text.trim()); + } + }, []); + + const handleEdit = useCallback( + async (message: MessageType.Text) => { + if (message.type !== 'text' || message.author.id !== user.id) { + return; + } + + // Enter edit mode and set input text + chatSessionStore.enterEditMode(message.id); + setInputText?.(message.text); + }, + [setInputText, user.id], + ); + + const handleTryAgain = useCallback( + async (message: MessageType.Text) => { + if (message.type !== 'text') { + return; + } + + // If it's the user's message, resubmit it + if (message.author.id === user.id) { + // Remove all messages from this point (inclusive) + const messageText = message.text; + chatSessionStore.removeMessagesFromId(message.id, true); + await handleSendPress({text: messageText, type: 'text'}); + } else { + // If it's the assistant's message, find and resubmit the last user message + const messageIndex = messages.findIndex(msg => msg.id === message.id); + const previousMessage = messages + .slice(messageIndex + 1) + .find(msg => msg.author.id === user.id && msg.type === 'text') as + | MessageType.Text + | undefined; + + if (previousMessage && previousMessage.text) { + const messageText = previousMessage.text; + chatSessionStore.removeMessagesFromId(previousMessage.id, true); + await handleSendPress({text: messageText, type: 'text'}); + } + } + }, + [messages, handleSendPress, user.id], + ); + + const handleTryAgainWith = useCallback( + async (modelId: string, message: MessageType.Text) => { + if (modelId === modelStore.activeModelId) { + await handleTryAgain(message); + return; + } + const model = modelStore.models.find(m => m.id === modelId); + if (model) { + await modelStore.initContext(model); + await handleTryAgain(message); + } + }, + [handleTryAgain], + ); + + return { + handleCopy, + handleEdit, + handleTryAgain, + handleTryAgainWith, + }; +}; diff --git a/src/screens/ChatScreen/ChatScreen.tsx b/src/screens/ChatScreen/ChatScreen.tsx index 7f153e3..832191a 100644 --- a/src/screens/ChatScreen/ChatScreen.tsx +++ b/src/screens/ChatScreen/ChatScreen.tsx @@ -19,52 +19,56 @@ const renderBubble = ({ child, message, nextMessageInGroup, + scale, }: { child: ReactNode; message: MessageType.Any; nextMessageInGroup: boolean; + scale?: any; }) => ( ); export const ChatScreen: React.FC = observer(() => { - const context = modelStore.context; const currentMessageInfo = useRef<{createdAt: number; id: string} | null>( null, ); const l10n = React.useContext(L10nContext); - const messages = chatSessionStore.currentSessionMessages; - const {handleSendPress, handleStopPress, inferencing, isStreaming} = - useChatSession(context, currentMessageInfo, messages, user, assistant); + const {handleSendPress, handleStopPress} = useChatSession( + currentMessageInfo, + user, + assistant, + ); // Show loading bubble only during the thinking phase (inferencing but not streaming) - const isThinking = inferencing && !isStreaming; + const isThinking = modelStore.inferencing && !modelStore.isStreaming; return ( : undefined } renderBubble={renderBubble} - messages={messages} + messages={chatSessionStore.currentSessionMessages} onSendPress={handleSendPress} onStopPress={handleStopPress} user={user} - isStopVisible={inferencing} + isStopVisible={modelStore.inferencing} isThinking={isThinking} - isStreaming={isStreaming} + isStreaming={modelStore.isStreaming} sendButtonVisibilityMode="editing" textInputProps={{ - editable: !!context, - placeholder: !context + editable: !!modelStore.context, + placeholder: !modelStore.context ? modelStore.isContextLoading ? l10n.loadingModel : l10n.modelNotLoaded diff --git a/src/screens/ChatScreen/ModelNotLoadedMessage/__tests__/ModelNotLoadedMessage.test.tsx b/src/screens/ChatScreen/ModelNotLoadedMessage/__tests__/ModelNotLoadedMessage.test.tsx index 3607d49..a7d5672 100644 --- a/src/screens/ChatScreen/ModelNotLoadedMessage/__tests__/ModelNotLoadedMessage.test.tsx +++ b/src/screens/ChatScreen/ModelNotLoadedMessage/__tests__/ModelNotLoadedMessage.test.tsx @@ -8,7 +8,7 @@ import {ModelNotLoadedMessage} from '../ModelNotLoadedMessage'; import {modelStore} from '../../../../store'; import {l10n} from '../../../../utils/l10n'; -import {basicModel} from '../../../../../jest/fixtures/models'; +import {basicModel, modelsList} from '../../../../../jest/fixtures/models'; const Drawer = createDrawerNavigator(); const mockNavigate = jest.fn(); @@ -37,13 +37,8 @@ const customRender = (ui, {...renderOptions} = {}) => { describe('ModelNotLoadedMessage', () => { beforeEach(() => { jest.clearAllMocks(); - - // When use Object.defineProperty, since it doesn't count as a mock function - // jest.clearAllMocks() won't affect it, hence the need to explicitly undefine it. - Object.defineProperty(modelStore, 'lastUsedModel', { - get: jest.fn(() => undefined), - }); - + modelStore.models = modelsList; + modelStore.lastUsedModelId = undefined; (modelStore.initContext as jest.Mock).mockReset(); }); @@ -53,10 +48,7 @@ describe('ModelNotLoadedMessage', () => { }); it('renders correctly when last used model exists', () => { - Object.defineProperty(modelStore, 'lastUsedModel', { - get: jest.fn(() => basicModel), - }); - + modelStore.lastUsedModelId = modelStore.models[0].id; const {getByText} = customRender(); expect(getByText(l10n.en.readyToChat)).toBeTruthy(); @@ -72,9 +64,7 @@ describe('ModelNotLoadedMessage', () => { }); it('loads last used model when available', async () => { - Object.defineProperty(modelStore, 'lastUsedModel', { - get: jest.fn(() => basicModel), - }); + modelStore.lastUsedModelId = basicModel.id; (modelStore.initContext as jest.Mock).mockResolvedValue(undefined); const {getByText} = customRender(); @@ -87,9 +77,7 @@ describe('ModelNotLoadedMessage', () => { }); it('handles model loading error correctly', async () => { - Object.defineProperty(modelStore, 'lastUsedModel', { - get: jest.fn(() => basicModel), - }); + modelStore.lastUsedModelId = basicModel.id; const mockError = new Error('Failed to load model'); (modelStore.initContext as jest.Mock).mockRejectedValue(mockError); @@ -110,9 +98,7 @@ describe('ModelNotLoadedMessage', () => { }); it('updates last used model state on mount', async () => { - Object.defineProperty(modelStore, 'lastUsedModel', { - get: jest.fn(() => basicModel), - }); + modelStore.lastUsedModelId = basicModel.id; const {getByText} = customRender(); diff --git a/src/screens/ChatScreen/__tests__/ChatScreen.test.tsx b/src/screens/ChatScreen/__tests__/ChatScreen.test.tsx index 29df178..4ce29cb 100644 --- a/src/screens/ChatScreen/__tests__/ChatScreen.test.tsx +++ b/src/screens/ChatScreen/__tests__/ChatScreen.test.tsx @@ -175,11 +175,25 @@ describe('ChatScreen', () => { fireEvent.changeText(input, 'Hello, AI!'); }); - const sendButton = getByTestId('send-button'); - fireEvent.press(sendButton); + await act(async () => { + const sendButton = getByTestId('send-button'); + fireEvent.press(sendButton); + modelStore.setInferencing(true); // since mock doesn't really set inferencing + }); + + await waitFor( + () => { + expect(getByTestId('stop-button')).toBeTruthy(); + }, + { + timeout: 1000, + }, + ); const stopButton = getByTestId('stop-button'); - fireEvent.press(stopButton); + await act(async () => { + fireEvent.press(stopButton); + }); expect(modelStore.context?.stopCompletion).toHaveBeenCalled(); }); diff --git a/src/screens/ModelsScreen/ModelAccordion/__tests__/ModelAccordion.test.tsx b/src/screens/ModelsScreen/ModelAccordion/__tests__/ModelAccordion.test.tsx index a37deac..660abaf 100644 --- a/src/screens/ModelsScreen/ModelAccordion/__tests__/ModelAccordion.test.tsx +++ b/src/screens/ModelsScreen/ModelAccordion/__tests__/ModelAccordion.test.tsx @@ -11,9 +11,7 @@ import {modelStore} from '../../../../store'; describe('ModelAccordion', () => { beforeEach(() => { - Object.defineProperty(modelStore, 'activeModel', { - get: jest.fn(() => undefined), - }); + modelStore.activeModelId = undefined; }); it('renders the accordion with correct title and children', () => { @@ -79,18 +77,19 @@ describe('ModelAccordion', () => { }); it('applies active group styles when activeModel matches group type', () => { - const group = {type: 'Model Group1'}; - Object.defineProperty(modelStore, 'activeModel', { - get: jest.fn(() => ({type: 'Model Group1'})), - }); + const activeModel = modelStore.models[0]; + modelStore.activeModelId = activeModel.id; const {getByTestId} = render( - + , ); - const accordion = getByTestId('model-accordion-Model Group1').parent; + const accordion = getByTestId(`model-accordion-${activeModel.type}`).parent; expect(accordion?.props.style).toEqual( // Wow, this is a mess. @@ -106,17 +105,17 @@ describe('ModelAccordion', () => { }); it('applies default theme styles when activeModel does not match group type', () => { - const group = {type: 'Model Group'}; - Object.defineProperty(modelStore, 'activeModel', { - get: jest.fn(() => ({type: 'Different Group'})), - }); + const group = {type: 'Model Group blah blah'}; + modelStore.activeModelId = modelStore.models[0].id; const {getByTestId} = render( , ); - const accordion = getByTestId('model-accordion-Model Group').parent; + const accordion = getByTestId( + 'model-accordion-Model Group blah blah', + ).parent; expect(accordion?.props.style).toEqual( expect.arrayContaining([ diff --git a/src/screens/ModelsScreen/ModelCard/ModelCard.tsx b/src/screens/ModelsScreen/ModelCard/ModelCard.tsx index 52ac51d..32fea62 100644 --- a/src/screens/ModelsScreen/ModelCard/ModelCard.tsx +++ b/src/screens/ModelsScreen/ModelCard/ModelCard.tsx @@ -297,7 +297,7 @@ export const ModelCard: React.FC = observer( {model.downloadSpeed && ( diff --git a/src/screens/ModelsScreen/ModelCard/__tests__/ModelCard.test.tsx b/src/screens/ModelsScreen/ModelCard/__tests__/ModelCard.test.tsx index 6979505..57a695d 100644 --- a/src/screens/ModelsScreen/ModelCard/__tests__/ModelCard.test.tsx +++ b/src/screens/ModelsScreen/ModelCard/__tests__/ModelCard.test.tsx @@ -136,9 +136,8 @@ describe('ModelCard', () => { expect(queryByTestId('download-progress-bar')).toBeNull(); }); - Object.defineProperty(modelStore, 'isDownloading', { - get: jest.fn(() => () => true), - }); + modelStore.downloadJobs.set(downloadingModel.id, true); + rerender(); await waitFor(() => { diff --git a/src/store/ChatSessionStore.ts b/src/store/ChatSessionStore.ts index 1f5dff9..76ef715 100644 --- a/src/store/ChatSessionStore.ts +++ b/src/store/ChatSessionStore.ts @@ -1,9 +1,10 @@ -import {makeAutoObservable, runInAction} from 'mobx'; +import {LlamaContext} from '@pocketpalai/llama.rn'; import * as RNFS from '@dr.pogodin/react-native-fs'; +import {makeAutoObservable, runInAction} from 'mobx'; import {format, isToday, isYesterday} from 'date-fns'; -import {MessageType} from '../utils/types'; -import {LlamaContext} from '@pocketpalai/llama.rn'; + import {assistant} from '../utils/chat'; +import {MessageType} from '../utils/types'; const NEW_SESSION_TITLE = 'New Session'; const TITLE_LIMIT = 40; @@ -22,6 +23,8 @@ interface SessionGroup { class ChatSessionStore { sessions: SessionMetaData[] = []; activeSessionId: string | null = null; + isEditMode: boolean = false; + editingMessageId: string | null = null; constructor() { makeAutoObservable(this); @@ -67,12 +70,14 @@ class ChatSessionStore { resetActiveSession() { runInAction(() => { + this.exitEditMode(); this.activeSessionId = null; }); } setActiveSession(sessionId: string) { runInAction(() => { + this.exitEditMode(); this.activeSessionId = sessionId; }); } @@ -110,6 +115,14 @@ class ChatSessionStore { if (this.activeSessionId) { const session = this.sessions.find(s => s.id === this.activeSessionId); if (session) { + if (this.isEditMode && this.editingMessageId) { + const messageIndex = session.messages.findIndex( + msg => msg.id === this.editingMessageId, + ); + if (messageIndex >= 0) { + return session.messages.slice(messageIndex + 1); + } + } return session.messages; } } @@ -274,6 +287,87 @@ class ChatSessionStore { return orderedGroups; } + + /** + * Enters edit mode for a specific message + */ + enterEditMode(messageId: string): void { + if (this.activeSessionId) { + const session = this.sessions.find(s => s.id === this.activeSessionId); + if (session) { + const messageIndex = session.messages.findIndex( + msg => msg.id === messageId, + ); + if (messageIndex >= 0) { + runInAction(() => { + this.isEditMode = true; + this.editingMessageId = messageId; + }); + } + } + } + } + + /** + * Exits edit mode without making changes + */ + exitEditMode(): void { + runInAction(() => { + this.isEditMode = false; + this.editingMessageId = null; + }); + } + + /** + * Commits the edit by actually removing messages after the edited message + */ + commitEdit(): void { + if (this.activeSessionId && this.editingMessageId) { + const session = this.sessions.find(s => s.id === this.activeSessionId); + if (session) { + const messageIndex = session.messages.findIndex( + msg => msg.id === this.editingMessageId, + ); + if (messageIndex >= 0) { + runInAction(() => { + session.messages = session.messages.slice(messageIndex + 1); + this.isEditMode = false; + this.editingMessageId = null; + this.saveSessionsMetadata(); + }); + } + } + } + } + + /** + * Removes messages from the current active session starting from a specific message ID. + * If includeMessage is true, the message with the given ID is also removed. + * + * @param messageId - The ID of the message to start removal from. + * @param includeMessage - Whether to include the message with the given ID in the removal. + */ + removeMessagesFromId( + messageId: string, + includeMessage: boolean = true, + ): void { + if (this.activeSessionId) { + const session = this.sessions.find(s => s.id === this.activeSessionId); + if (session) { + const messageIndex = session.messages.findIndex( + msg => msg.id === messageId, + ); + if (messageIndex >= 0) { + runInAction(() => { + session.messages = session.messages.slice( + includeMessage ? messageIndex + 1 : messageIndex, + ); + this.saveSessionsMetadata(); + }); + } + } + } + } } export const chatSessionStore = new ChatSessionStore(); diff --git a/src/store/ModelStore.ts b/src/store/ModelStore.ts index 4089977..0305648 100644 --- a/src/store/ModelStore.ts +++ b/src/store/ModelStore.ts @@ -9,6 +9,7 @@ import {computed, makeAutoObservable, ObservableMap, runInAction} from 'mobx'; import {CompletionParams, LlamaContext, initLlama} from '@pocketpalai/llama.rn'; import {uiStore} from './UIStore'; +import {chatSessionStore} from './ChatSessionStore'; import {defaultModels, MODEL_LIST_VERSION} from './defaultModels'; import {deepMerge, formatBytes, hasEnoughSpace, hfAsModel} from '../utils'; @@ -46,6 +47,9 @@ class ModelStore { MIN_CONTEXT_SIZE = 200; + inferencing: boolean = false; + isStreaming: boolean = false; + constructor() { makeAutoObservable(this, {activeModel: computed}); makePersistable(this, { @@ -593,6 +597,7 @@ class ModelStore { releaseContext = async () => { console.log('attempt to release'); + chatSessionStore.exitEditMode(); if (!this.context) { return Promise.resolve('No context to release'); } @@ -828,7 +833,6 @@ class ModelStore { const ctxtTemplate = (ctx.model as any)?.metadata?.[ 'tokenizer.chat_template' ]; - console.log('ctxtTemplate: ', ctxtTemplate); if (ctxtTemplate) { const contextStops = stops.filter(stop => ctxtTemplate.includes(stop)); stopTokens.push(...contextStops); @@ -860,6 +864,24 @@ class ModelStore { // Continue execution - stop token update is not critical } } + + get availableModels(): Model[] { + return this.models.filter( + model => + // Include models that are either local or downloaded + model.isLocal || + model.origin === ModelOrigin.LOCAL || + model.isDownloaded, + ); + } + + setInferencing(value: boolean) { + this.inferencing = value; + } + + setIsStreaming(value: boolean) { + this.isStreaming = value; + } } export const modelStore = new ModelStore(); diff --git a/src/utils/colorUtils.ts b/src/utils/colorUtils.ts new file mode 100644 index 0000000..b8eb51e --- /dev/null +++ b/src/utils/colorUtils.ts @@ -0,0 +1,79 @@ +/** + * Converts a hex color to RGBA + */ +export const hexToRGBA = (hex: string, alpha: number = 1): string => { + // Remove the hash if it exists + hex = hex.replace('#', ''); + + // Parse the hex values + const r = parseInt(hex.substring(0, 2), 16); + const g = parseInt(hex.substring(2, 4), 16); + const b = parseInt(hex.substring(4, 6), 16); + + // Return the rgba string + return `rgba(${r}, ${g}, ${b}, ${alpha})`; +}; + +/** + * Applies opacity to a color (works with both hex and rgba) + */ +export const withOpacity = (color: string, opacity: number): string => { + if (color.startsWith('rgba')) { + // If it's already rgba, just modify the opacity + return color.replace(/[\d.]+\)$/g, `${opacity})`); + } + return hexToRGBA(color, opacity); +}; + +/** + * Determines if a color is light or dark + */ +export const isLightColor = (color: string): boolean => { + let r: number, g: number, b: number; + + if (color.startsWith('#')) { + const hex = color.replace('#', ''); + r = parseInt(hex.substring(0, 2), 16); + g = parseInt(hex.substring(2, 4), 16); + b = parseInt(hex.substring(4, 6), 16); + } else if (color.startsWith('rgba')) { + const matches = color.match(/rgba?\((\d+),\s*(\d+),\s*(\d+)/); + if (!matches) { + return true; + } + [, r, g, b] = matches.map(Number); + } else { + return true; // Default to light for unknown formats + } + + // Calculate relative luminance + const luminance = (0.299 * r + 0.587 * g + 0.114 * b) / 255; + return luminance > 0.5; +}; + +/** + * Gets a contrasting color (black or white) based on background + */ +export const getContrastColor = (backgroundColor: string): string => { + return isLightColor(backgroundColor) ? '#000000' : '#FFFFFF'; +}; + +/** + * MD3 state layer opacity values + */ +export const stateLayerOpacity = { + hover: 0.08, + focus: 0.12, + pressed: 0.12, + dragged: 0.16, +} as const; + +/** + * Creates a state layer color based on the type of state + */ +export const createStateLayer = ( + baseColor: string, + state: keyof typeof stateLayerOpacity, +): string => { + return withOpacity(baseColor, stateLayerOpacity[state]); +}; diff --git a/src/utils/theme.ts b/src/utils/theme.ts index 33cf2b8..a09cea3 100644 --- a/src/utils/theme.ts +++ b/src/utils/theme.ts @@ -3,153 +3,258 @@ import { DefaultTheme as PaperLightTheme, } from 'react-native-paper'; -import {Colors, Theme} from './types'; +import {MD3BaseColors, SemanticColors, Theme} from './types'; +import {withOpacity, stateLayerOpacity} from './colorUtils'; -import {getThemeColorsAsArray} from '.'; - -const lightColors: Colors = { - ...PaperLightTheme.colors, - primary: '#6200ee', //PaperLightTheme.colors.primary, - accent: '#03dac4', - outlineVariant: '#a1a1a1', - receivedMessageDocumentIcon: PaperLightTheme.colors.primary, - sentMessageDocumentIcon: PaperLightTheme.colors.onSurface, - userAvatarImageBackground: 'transparent', - userAvatarNameColors: getThemeColorsAsArray(PaperLightTheme), - searchBarBackground: 'rgba(118, 118, 128, 0.12)', // iOS light mode searchbar +// MD3 key colors (seed colors) +const md3BaseColors: Partial = { + primary: '#111111', + secondary: '#3669F5', + tertiary: '#018786', + error: '#B3261E', }; -export const lightTheme: Theme = { - ...PaperLightTheme, - borders: { - inputBorderRadius: 20, - messageBorderRadius: 20, - }, - colors: lightColors, - fonts: { - ...PaperLightTheme.fonts, - dateDividerTextStyle: { - color: lightColors.onSurface, - fontSize: 12, - fontWeight: '800', - lineHeight: 16, - opacity: 0.4, - }, - emptyChatPlaceholderTextStyle: { - color: lightColors.onSurface, - fontSize: 16, - fontWeight: '500', - lineHeight: 24, - }, - inputTextStyle: { - fontSize: 16, - fontWeight: '500', - lineHeight: 24, - }, - receivedMessageBodyTextStyle: { - color: lightColors.onPrimary, - fontSize: 16, - fontWeight: '500', - lineHeight: 24, - }, - receivedMessageCaptionTextStyle: { - color: lightColors.onSurfaceVariant, - fontSize: 12, - fontWeight: '500', - lineHeight: 16, - }, - receivedMessageLinkDescriptionTextStyle: { - color: lightColors.onPrimary, - fontSize: 14, - fontWeight: '400', - lineHeight: 20, - }, - receivedMessageLinkTitleTextStyle: { - color: lightColors.onPrimary, - fontSize: 16, - fontWeight: '800', - lineHeight: 22, - }, - sentMessageBodyTextStyle: { - color: lightColors.onSurface, - fontSize: 16, - fontWeight: '500', - lineHeight: 24, - }, - sentMessageCaptionTextStyle: { - color: lightColors.onSurfaceVariant, - fontSize: 12, - fontWeight: '500', - lineHeight: 16, - }, - sentMessageLinkDescriptionTextStyle: { - color: lightColors.onSurface, - fontSize: 14, - fontWeight: '400', - lineHeight: 20, - }, - sentMessageLinkTitleTextStyle: { - color: lightColors.onSurface, - fontSize: 16, - fontWeight: '800', - lineHeight: 22, - }, - userAvatarTextStyle: { - color: lightColors.onSurface, - fontSize: 12, - fontWeight: '800', - lineHeight: 16, - }, - userNameTextStyle: { - fontSize: 12, - fontWeight: '800', - lineHeight: 16, - }, - }, - insets: { - messageInsetsHorizontal: 20, - messageInsetsVertical: 10, - }, +const createBaseColors = (isDark: boolean): MD3BaseColors => { + const baseTheme = isDark ? MD3DarkTheme : PaperLightTheme; + + if (isDark) { + return { + ...baseTheme.colors, + primary: '#DADDE6', + onPrimary: '#44464C', + primaryContainer: '#5B5E66', + onPrimaryContainer: '#DEE0E6', + secondary: '#95ABE6', + onSecondary: '#11214C', + secondaryContainer: '#162C66', + onSecondaryContainer: '#ADBCE6', + tertiary: '#80E6E4', + onTertiary: '#014C4C', + tertiaryContainer: '#016665', + onTertiaryContainer: '#9EE6E5', + error: '#E69490', + onError: '#4C100D', + errorContainer: '#661511', + onErrorContainer: '#E6ACA9', + background: '#333333', + onBackground: '#e5e5e6', + surface: '#333333', + onSurface: '#e5e5e6', + surfaceVariant: '#646466', + onSurfaceVariant: '#e3e4e6', + outline: '#b0b1b3', + outlineVariant: '#a1a1a1', + // Additional required MD3 colors + surfaceDisabled: withOpacity('#333333', 0.12), + onSurfaceDisabled: withOpacity('#e5e5e6', 0.38), + inverseSurface: '#e5e5e6', + inverseOnSurface: '#333333', + inversePrimary: '#5B5E66', + scrim: 'rgba(0, 0, 0, 0.25)', + }; + } + + return { + ...baseTheme.colors, + primary: md3BaseColors.primary!, + onPrimary: '#FFFFFF', + primaryContainer: '#DEE0E6', + onPrimaryContainer: '#2D2F33', + secondary: md3BaseColors.secondary!, + onSecondary: '#FFFFFF', + secondaryContainer: '#ADBCE6', + onSecondaryContainer: '#0B1633', + tertiary: md3BaseColors.tertiary!, + onTertiary: '#FFFFFF', + tertiaryContainer: '#9EE6E5', + onTertiaryContainer: '#013332', + error: md3BaseColors.error!, + onError: '#FFFFFF', + errorContainer: '#E6ACA9', + onErrorContainer: '#330B09', + background: '#fcfcfc', + onBackground: '#333333', + surface: '#fcfcfc', + onSurface: '#333333', + surfaceVariant: '#e4e4e6', + onSurfaceVariant: '#646466', + outline: '#969799', + outlineVariant: '#a1a1a1', + // Additional required MD3 colors + surfaceDisabled: withOpacity('#fcfcfc', 0.12), + onSurfaceDisabled: withOpacity('#333333', 0.38), + inverseSurface: '#333333', + inverseOnSurface: '#fcfcfc', + inversePrimary: '#DEE0E6', + scrim: 'rgba(0, 0, 0, 0.25)', + }; }; -const darkColors: Colors = { - ...MD3DarkTheme.colors, - primary: '#bb86fc', - accent: '#03dac6', - outlineVariant: '#a1a1a1', - receivedMessageDocumentIcon: MD3DarkTheme.colors.primary, - sentMessageDocumentIcon: MD3DarkTheme.colors.onSurface, +const createSemanticColors = ( + baseColors: MD3BaseColors, + isDark: boolean, +): SemanticColors => ({ + // Surface variants + surfaceContainerHighest: isDark + ? withOpacity(baseColors.surface, 0.22) + : withOpacity(baseColors.primary, 0.05), + surfaceContainerHigh: isDark + ? withOpacity(baseColors.surface, 0.16) + : withOpacity(baseColors.primary, 0.03), + surfaceContainer: isDark + ? withOpacity(baseColors.surface, 0.12) + : withOpacity(baseColors.primary, 0.02), + surfaceContainerLow: isDark + ? withOpacity(baseColors.surface, 0.08) + : withOpacity(baseColors.primary, 0.01), + surfaceContainerLowest: isDark + ? withOpacity(baseColors.surface, 0.04) + : baseColors.surface, + surfaceDim: isDark + ? withOpacity(baseColors.surface, 0.06) + : withOpacity(baseColors.primary, 0.06), + surfaceBright: isDark + ? withOpacity(baseColors.surface, 0.24) + : baseColors.surface, + + // Interactive states + stateLayerOpacity: 0.12, + hoverStateOpacity: stateLayerOpacity.hover, + pressedStateOpacity: stateLayerOpacity.pressed, + draggedStateOpacity: stateLayerOpacity.dragged, + focusStateOpacity: stateLayerOpacity.focus, + + // Menu specific + menuBackground: baseColors.surface, + menuBackgroundDimmed: withOpacity(baseColors.surface, 0.9), + menuBackgroundActive: withOpacity(baseColors.primary, 0.08), + menuSeparator: withOpacity(baseColors.primary, 0.5), + menuGroupSeparator: isDark + ? withOpacity('#FFFFFF', 0.08) + : withOpacity('#000000', 0.08), + menuText: baseColors.onSurface, + menuDangerText: baseColors.error, + + // Message specific + authorBubbleBackground: isDark + ? 'rgba(255, 255, 255, 0.03)' + : 'rgba(0, 0, 0, 0.03)', + receivedMessageDocumentIcon: baseColors.primary, + sentMessageDocumentIcon: baseColors.onSurface, userAvatarImageBackground: 'transparent', - userAvatarNameColors: getThemeColorsAsArray(MD3DarkTheme), - searchBarBackground: 'rgba(28, 28, 30, 0.92)', // iOS dark mode searchbar -}; + userAvatarNameColors: [ + baseColors.primary, + baseColors.secondary, + baseColors.tertiary, + baseColors.error, + ], + searchBarBackground: isDark + ? 'rgba(28, 28, 30, 0.92)' + : 'rgba(118, 118, 128, 0.12)', +}); -export const darkTheme: Theme = { - ...MD3DarkTheme, - borders: lightTheme.borders, - colors: darkColors, - fonts: { - ...lightTheme.fonts, - dateDividerTextStyle: { - ...lightTheme.fonts.dateDividerTextStyle, - color: MD3DarkTheme.colors.onSurface, - }, - receivedMessageBodyTextStyle: { - ...lightTheme.fonts.receivedMessageBodyTextStyle, - color: MD3DarkTheme.colors.onPrimary, +const createTheme = (isDark: boolean): Theme => { + const baseTheme = isDark ? MD3DarkTheme : PaperLightTheme; + const baseColors = createBaseColors(isDark); + const semanticColors = createSemanticColors(baseColors, isDark); + + return { + ...baseTheme, + colors: { + ...baseColors, + ...semanticColors, }, - receivedMessageCaptionTextStyle: { - ...lightTheme.fonts.receivedMessageCaptionTextStyle, - color: MD3DarkTheme.colors.onSurfaceVariant, + borders: { + inputBorderRadius: 20, + messageBorderRadius: 15, }, - receivedMessageLinkDescriptionTextStyle: { - ...lightTheme.fonts.receivedMessageLinkDescriptionTextStyle, - color: MD3DarkTheme.colors.onPrimary, + fonts: { + ...baseTheme.fonts, + dateDividerTextStyle: { + color: baseColors.onSurface, + fontSize: 12, + fontWeight: '800', + lineHeight: 16, + opacity: 0.4, + }, + emptyChatPlaceholderTextStyle: { + color: baseColors.onSurface, + fontSize: 16, + fontWeight: '500', + lineHeight: 24, + }, + inputTextStyle: { + fontSize: 16, + fontWeight: '500', + lineHeight: 24, + }, + receivedMessageBodyTextStyle: { + color: baseColors.onPrimary, + fontSize: 16, + fontWeight: '500', + lineHeight: 24, + }, + receivedMessageCaptionTextStyle: { + color: baseColors.onSurfaceVariant, + fontSize: 12, + fontWeight: '500', + lineHeight: 16, + }, + receivedMessageLinkDescriptionTextStyle: { + color: baseColors.onPrimary, + fontSize: 14, + fontWeight: '400', + lineHeight: 20, + }, + receivedMessageLinkTitleTextStyle: { + color: baseColors.onPrimary, + fontSize: 16, + fontWeight: '800', + lineHeight: 22, + }, + sentMessageBodyTextStyle: { + color: baseColors.onSurface, + fontSize: 16, + fontWeight: '500', + lineHeight: 24, + }, + sentMessageCaptionTextStyle: { + color: baseColors.onSurfaceVariant, + fontSize: 12, + fontWeight: '500', + lineHeight: 16, + }, + sentMessageLinkDescriptionTextStyle: { + color: baseColors.onSurface, + fontSize: 14, + fontWeight: '400', + lineHeight: 20, + }, + sentMessageLinkTitleTextStyle: { + color: baseColors.onSurface, + fontSize: 16, + fontWeight: '800', + lineHeight: 22, + }, + userAvatarTextStyle: { + color: baseColors.onSurface, + fontSize: 12, + fontWeight: '800', + lineHeight: 16, + }, + userNameTextStyle: { + fontSize: 12, + fontWeight: '800', + lineHeight: 16, + }, }, - receivedMessageLinkTitleTextStyle: { - ...lightTheme.fonts.receivedMessageLinkTitleTextStyle, - color: MD3DarkTheme.colors.onPrimary, + insets: { + messageInsetsHorizontal: 20, + messageInsetsVertical: 10, }, - }, - insets: lightTheme.insets, + icons: {}, + }; }; + +export const lightTheme = createTheme(false); +export const darkTheme = createTheme(true); diff --git a/src/utils/types.ts b/src/utils/types.ts index c6eed5f..abef520 100644 --- a/src/utils/types.ts +++ b/src/utils/types.ts @@ -1,5 +1,5 @@ import * as React from 'react'; -import {ColorValue, ImageURISource, TextStyle} from 'react-native'; +import {ImageURISource, TextStyle} from 'react-native'; import {MD3Theme} from 'react-native-paper'; import {TemplateConfig} from 'chat-formatter'; @@ -131,35 +131,39 @@ export interface Size { width: number; } -export interface Colors extends MD3Colors { - accent: string; +export interface MD3BaseColors extends MD3Colors { + primary: string; + onPrimary: string; + primaryContainer: string; + onPrimaryContainer: string; + secondary: string; + onSecondary: string; + secondaryContainer: string; + onSecondaryContainer: string; + tertiary: string; + onTertiary: string; + tertiaryContainer: string; + onTertiaryContainer: string; + error: string; + onError: string; + errorContainer: string; + onErrorContainer: string; + background: string; + onBackground: string; + surface: string; + onSurface: string; + surfaceVariant: string; + onSurfaceVariant: string; + outline: string; outlineVariant: string; - receivedMessageDocumentIcon: string; - sentMessageDocumentIcon: string; - userAvatarImageBackground: string; - userAvatarNameColors: ColorValue[]; - searchBarBackground: string; -} - -export interface Typescale extends MD3Typescale { - dateDividerTextStyle?: TextStyle; // Optional custom styles - emptyChatPlaceholderTextStyle?: TextStyle; - inputTextStyle?: TextStyle; - receivedMessageBodyTextStyle?: TextStyle; - receivedMessageCaptionTextStyle?: TextStyle; - receivedMessageLinkDescriptionTextStyle?: TextStyle; - receivedMessageLinkTitleTextStyle?: TextStyle; - sentMessageBodyTextStyle?: TextStyle; - sentMessageCaptionTextStyle?: TextStyle; - sentMessageLinkDescriptionTextStyle?: TextStyle; - sentMessageLinkTitleTextStyle?: TextStyle; - userAvatarTextStyle?: TextStyle; - userNameTextStyle?: TextStyle; -} -export interface ThemeBorders { - inputBorderRadius: number; - messageBorderRadius: number; + // Additional MD3 required colors + surfaceDisabled: string; + onSurfaceDisabled: string; + inverseSurface: string; + inverseOnSurface: string; + inversePrimary: string; + scrim: string; } export interface ThemeIcons { @@ -172,15 +176,71 @@ export interface ThemeIcons { sendingIcon?: () => React.ReactNode; } +export interface SemanticColors { + // Surface variants + surfaceContainerHighest: string; + surfaceContainerHigh: string; + surfaceContainer: string; + surfaceContainerLow: string; + surfaceContainerLowest: string; + surfaceDim: string; + surfaceBright: string; + + // Interactive states + stateLayerOpacity: number; + hoverStateOpacity: number; + pressedStateOpacity: number; + draggedStateOpacity: number; + focusStateOpacity: number; + + // Menu specific + menuBackground: string; + menuBackgroundDimmed: string; + menuBackgroundActive: string; + menuSeparator: string; + menuGroupSeparator: string; + menuText: string; + menuDangerText: string; + + // Message specific + authorBubbleBackground: string; + receivedMessageDocumentIcon: string; + sentMessageDocumentIcon: string; + userAvatarImageBackground: string; + userAvatarNameColors: string[]; + searchBarBackground: string; +} + +export interface ThemeBorders { + inputBorderRadius: number; + messageBorderRadius: number; +} + +export interface ThemeFonts extends MD3Typescale { + dateDividerTextStyle: TextStyle; + emptyChatPlaceholderTextStyle: TextStyle; + inputTextStyle: TextStyle; + receivedMessageBodyTextStyle: TextStyle; + receivedMessageCaptionTextStyle: TextStyle; + receivedMessageLinkDescriptionTextStyle: TextStyle; + receivedMessageLinkTitleTextStyle: TextStyle; + sentMessageBodyTextStyle: TextStyle; + sentMessageCaptionTextStyle: TextStyle; + sentMessageLinkDescriptionTextStyle: TextStyle; + sentMessageLinkTitleTextStyle: TextStyle; + userAvatarTextStyle: TextStyle; + userNameTextStyle: TextStyle; +} + export interface ThemeInsets { messageInsetsHorizontal: number; messageInsetsVertical: number; } export interface Theme extends MD3Theme { - colors: Colors; - fonts: Typescale; + colors: MD3BaseColors & SemanticColors; borders: ThemeBorders; + fonts: ThemeFonts; insets: ThemeInsets; icons?: ThemeIcons; } diff --git a/yarn.lock b/yarn.lock index abd3a35..1a1f4da 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6673,6 +6673,11 @@ react-native-get-random-values@^1.11.0: dependencies: fast-base64-decode "^1.0.0" +react-native-haptic-feedback@^2.3.3: + version "2.3.3" + resolved "https://registry.yarnpkg.com/react-native-haptic-feedback/-/react-native-haptic-feedback-2.3.3.tgz#88b6876e91399a69bd1b551fe1681b2f3dc1214e" + integrity sha512-svS4D5PxfNv8o68m9ahWfwje5NqukM3qLS48+WTdhbDkNUkOhP9rDfDSRHzlhk4zq+ISjyw95EhLeh8NkKX5vQ== + react-native-image-viewing@^0.2.2: version "0.2.2" resolved "https://registry.yarnpkg.com/react-native-image-viewing/-/react-native-image-viewing-0.2.2.tgz#fb26e57d7d3d9ce4559a3af3d244387c0367242b"