Skip to content

Commit

Permalink
[Feat] add edit message and retry generation (#128)
Browse files Browse the repository at this point in the history
* feat: add edit/copy/regenerate menu to the chat bubbles

* feat: regenerate a response with the same or another model 

* feat: add edit previous message

* chore: add menu, menu item and submenu components

* chore: refactor theme colors
  • Loading branch information
a-ghorbani authored Dec 7, 2024
1 parent 01941c0 commit a45fcb7
Show file tree
Hide file tree
Showing 57 changed files with 2,730 additions and 953 deletions.
5 changes: 5 additions & 0 deletions __mocks__/external/react-native-haptic-feedback.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
const ReactNativeHapticFeedback = {
trigger: jest.fn(),
};

export default ReactNativeHapticFeedback;
3 changes: 3 additions & 0 deletions __mocks__/stores/chatSessionStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ export const mockChatSessionStore = {
createNewSession: jest.fn(),
updateMessage: jest.fn(),
updateMessageToken: jest.fn(),
exitEditMode: jest.fn(),
enterEditMode: jest.fn(),
removeMessagesFromId: jest.fn(),
};

Object.defineProperty(mockChatSessionStore, 'currentSessionMessages', {
Expand Down
126 changes: 88 additions & 38 deletions __mocks__/stores/modelStore.ts
Original file line number Diff line number Diff line change
@@ -1,41 +1,91 @@
import {computed, makeAutoObservable, ObservableMap} from 'mobx';

import {modelsList} from '../../jest/fixtures/models';

export const mockModelStore = {
models: modelsList,
n_context: 1024,
MIN_CONTEXT_SIZE: 200,
useAutoRelease: true,
useMetal: false,
n_gpu_layers: 50,
activeModelId: undefined as string | undefined,
setNContext: jest.fn(),
updateUseAutoRelease: jest.fn(),
updateUseMetal: jest.fn(),
setNGPULayers: jest.fn(),
refreshDownloadStatuses: jest.fn(),
addLocalModel: jest.fn(),
resetModels: jest.fn(),
initContext: jest.fn().mockResolvedValue(Promise.resolve()),
checkSpaceAndDownload: jest.fn(),
getDownloadProgress: jest.fn(),
manualReleaseContext: jest.fn(),
setActiveModel(modelId: string) {
import {Model} from '../../src/utils/types';

class MockModelStore {
models = modelsList;
n_context = 1024;
MIN_CONTEXT_SIZE = 200;
useAutoRelease = true;
useMetal = false;
n_gpu_layers = 50;
activeModelId: string | undefined;
inferencing = false;
isStreaming = false;
downloadJobs = new ObservableMap();

refreshDownloadStatuses: jest.Mock;
addLocalModel: jest.Mock;
setNContext: jest.Mock;
updateUseAutoRelease: jest.Mock;
updateUseMetal: jest.Mock;
setNGPULayers: jest.Mock;
resetModels: jest.Mock;
initContext: jest.Mock;
lastUsedModelId: any;
checkSpaceAndDownload: jest.Mock;
getDownloadProgress: jest.Mock;
manualReleaseContext: jest.Mock;

constructor() {
makeAutoObservable(this, {
refreshDownloadStatuses: false,
addLocalModel: false,
setNContext: false,
updateUseAutoRelease: false,
updateUseMetal: false,
setNGPULayers: false,
resetModels: false,
initContext: false,
checkSpaceAndDownload: false,
getDownloadProgress: false,
manualReleaseContext: false,
lastUsedModel: computed,
activeModel: computed,
isDownloading: computed,
});
this.refreshDownloadStatuses = jest.fn();
this.addLocalModel = jest.fn();
this.setNContext = jest.fn();
this.updateUseAutoRelease = jest.fn();
this.updateUseMetal = jest.fn();
this.setNGPULayers = jest.fn();
this.resetModels = jest.fn();
this.initContext = jest.fn().mockResolvedValue(Promise.resolve());
this.checkSpaceAndDownload = jest.fn();
this.getDownloadProgress = jest.fn();
this.manualReleaseContext = jest.fn();
}

setActiveModel = (modelId: string) => {
this.activeModelId = modelId;
},
};
Object.defineProperty(mockModelStore, 'lastUsedModel', {
get: jest.fn(() => undefined),
configurable: true,
});
Object.defineProperty(mockModelStore, 'isDownloading', {
get: jest.fn(() => () => false),
configurable: true,
});
Object.defineProperty(mockModelStore, 'activeModel', {
get: jest.fn(() =>
mockModelStore.models.find(
model => model.id === mockModelStore.activeModelId,
),
),
configurable: true,
});
};

setInferencing = (value: boolean) => {
this.inferencing = value;
};

setIsStreaming = (value: boolean) => {
this.isStreaming = value;
};

get lastUsedModel(): Model | undefined {
return this.lastUsedModelId
? this.models.find(m => m.id === this.lastUsedModelId)
: undefined;
}

get isDownloading() {
return (modelId: string) => {
return this.downloadJobs.has(modelId);
};
}

get activeModel() {
return this.models.find(model => model.id === this.activeModelId);
}
}

export const mockModelStore = new MockModelStore();
25 changes: 25 additions & 0 deletions ios/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1657,6 +1657,27 @@ PODS:
- ReactCommon/turbomodule/bridging
- ReactCommon/turbomodule/core
- Yoga
- RNReactNativeHapticFeedback (2.3.3):
- DoubleConversion
- glog
- hermes-engine
- RCT-Folly (= 2024.01.01.00)
- RCTRequired
- RCTTypeSafety
- React-Core
- React-debug
- React-Fabric
- React-featureflags
- React-graphics
- React-ImageManager
- React-NativeModulesApple
- React-RCTFabric
- React-rendererdebug
- React-utils
- ReactCodegen
- ReactCommon/turbomodule/bridging
- ReactCommon/turbomodule/core
- Yoga
- RNReanimated (3.16.3):
- DoubleConversion
- glog
Expand Down Expand Up @@ -1873,6 +1894,7 @@ DEPENDENCIES:
- "RNCPicker (from `../node_modules/@react-native-picker/picker`)"
- RNDeviceInfo (from `../node_modules/react-native-device-info`)
- RNGestureHandler (from `../node_modules/react-native-gesture-handler`)
- RNReactNativeHapticFeedback (from `../node_modules/react-native-haptic-feedback`)
- RNReanimated (from `../node_modules/react-native-reanimated`)
- RNScreens (from `../node_modules/react-native-screens`)
- RNSVG (from `../node_modules/react-native-svg`)
Expand Down Expand Up @@ -2041,6 +2063,8 @@ EXTERNAL SOURCES:
:path: "../node_modules/react-native-device-info"
RNGestureHandler:
:path: "../node_modules/react-native-gesture-handler"
RNReactNativeHapticFeedback:
:path: "../node_modules/react-native-haptic-feedback"
RNReanimated:
:path: "../node_modules/react-native-reanimated"
RNScreens:
Expand Down Expand Up @@ -2131,6 +2155,7 @@ SPEC CHECKSUMS:
RNCPicker: d8662eb6615e3401acb590c44b97b2af3beb1e53
RNDeviceInfo: ae26ae45db3f9937f038a284bcd0a1db8d70db96
RNGestureHandler: 5b24d10761754ad271b714e536c457fd89b17c54
RNReactNativeHapticFeedback: 00ba111b82aa266bb3ee1aa576831c2ea9a9dfad
RNReanimated: 929c26a706dfe1af8feee9f2cf78004394e4dd04
RNScreens: e21c8d32fe97737ecc30f1f21e7b6f69f341a1f5
RNSVG: 6a529f4faed8be4ebfb00f1a29e25cb046d95e61
Expand Down
2 changes: 2 additions & 0 deletions jest.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,7 @@ module.exports = {
'<rootDir>/__mocks__/external/react-native-document-picker.js',
'@dr.pogodin/react-native-fs':
'<rootDir>/__mocks__/external/@dr.pogodin/react-native-fs.js',
'react-native-haptic-feedback':
'<rootDir>/__mocks__/external/react-native-haptic-feedback.js',
},
};
2 changes: 2 additions & 0 deletions jest/setup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import mockClipboard from '@react-native-clipboard/clipboard/jest/clipboard-mock

import 'react-native-gesture-handler/jestSetup';

jest.mock('react-native-haptic-feedback');

// Mock react-native-reanimated
//require('react-native-reanimated').setUpTests();
jest.mock('react-native-reanimated', () => {
Expand Down
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"react-native-document-picker": "^9.1.2",
"react-native-gesture-handler": "^2.20.2",
"react-native-get-random-values": "^1.11.0",
"react-native-haptic-feedback": "^2.3.3",
"react-native-image-viewing": "^0.2.2",
"react-native-linear-gradient": "^2.8.3",
"react-native-marked": "^6.0.4",
Expand Down
77 changes: 77 additions & 0 deletions src/api/__tests__/hf.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import axios from 'axios';
import {fetchGGUFSpecs, fetchModelFilesDetails, fetchModels} from '../hf';

jest.mock('axios');
const mockedAxios = axios as jest.Mocked<typeof axios>;

describe('fetchModels', () => {
it('should fetch models with basic parameters', async () => {
const mockResponse = {
data: [{id: 'model1'}],
headers: {link: 'next-page-link'},
};
mockedAxios.get.mockResolvedValueOnce(mockResponse);

const result = await fetchModels({search: 'test'});

expect(mockedAxios.get).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({
params: expect.objectContaining({search: 'test'}),
}),
);
expect(result).toEqual({
models: [{id: 'model1'}],
nextLink: 'next-page-link',
});
});

it('should handle missing pagination link', async () => {
const mockResponse = {
data: [{id: 'model1'}],
headers: {},
};
mockedAxios.get.mockResolvedValueOnce(mockResponse);

const result = await fetchModels({});
expect(result.nextLink).toBeNull();
});
});

describe('API error handling', () => {
it('should handle network errors in fetchModels', async () => {
const error = new Error('Network error');
mockedAxios.get.mockRejectedValueOnce(error);

await expect(fetchModels({})).rejects.toThrow('Network error');
});

it('should handle non-ok responses in fetchModelFilesDetails', async () => {
global.fetch = jest.fn().mockResolvedValueOnce({
ok: false,
statusText: 'Not Found',
});

await expect(fetchModelFilesDetails('model1')).rejects.toThrow(
'Error fetching model files: Not Found',
);
});
});

describe('fetchGGUFSpecs', () => {
it('should parse GGUF specs correctly', async () => {
const mockSpecs = {
gguf: {
params: 7,
type: 'f16',
},
};
global.fetch = jest.fn().mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockSpecs),
});

const result = await fetchGGUFSpecs('model1');
expect(result).toEqual(mockSpecs);
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import React from 'react';
import {render, fireEvent} from '@testing-library/react-native';
import {BottomSheetSearchbar} from '../BottomSheetSearchbar';
import {useBottomSheetInternal} from '@gorhom/bottom-sheet';

jest.mock('@gorhom/bottom-sheet', () => ({
useBottomSheetInternal: jest.fn(),
}));

describe('BottomSheetSearchbar', () => {
const mockShouldHandleKeyboardEvents = {value: false};

beforeEach(() => {
(useBottomSheetInternal as jest.Mock).mockReturnValue({
shouldHandleKeyboardEvents: mockShouldHandleKeyboardEvents,
});
});

it('should handle focus event correctly', () => {
const onFocus = jest.fn();
const {getByTestId} = render(
<BottomSheetSearchbar
testID="searchbar"
onFocus={onFocus}
value="test"
/>,
);

fireEvent(getByTestId('searchbar'), 'focus');

expect(mockShouldHandleKeyboardEvents.value).toBe(true);
expect(onFocus).toHaveBeenCalled();
});

it('should handle blur event correctly', () => {
const onBlur = jest.fn();
const {getByTestId} = render(
<BottomSheetSearchbar testID="searchbar" onBlur={onBlur} value="test" />,
);

fireEvent(getByTestId('searchbar'), 'blur');

expect(mockShouldHandleKeyboardEvents.value).toBe(false);
expect(onBlur).toHaveBeenCalled();
});

it('should reset keyboard events flag on unmount', () => {
const {unmount} = render(<BottomSheetSearchbar value="test" />);

unmount();

expect(mockShouldHandleKeyboardEvents.value).toBe(false);
});

it('should forward props to Searchbar component', () => {
const placeholder = 'Search...';
const value = 'test';
const onChangeText = jest.fn();

const {getByPlaceholderText} = render(
<BottomSheetSearchbar
placeholder={placeholder}
value={value}
onChangeText={onChangeText}
/>,
);

const searchbar = getByPlaceholderText(placeholder);
expect(searchbar.props.value).toBe(value);

fireEvent.changeText(searchbar, 'new value');
expect(onChangeText).toHaveBeenCalledWith('new value');
});
});
Loading

0 comments on commit a45fcb7

Please sign in to comment.