feat: use pickKey for UX
tak-bro committed Jul 24, 2024
1 parent 95a3bc7 commit 2440b32
Showing 20 changed files with 56 additions and 170 deletions.
10 changes: 0 additions & 10 deletions README.md
@@ -163,15 +163,6 @@ aipick -m "Why is the sky blue?"
 aipick --message <s> # or -m <s>
 ```
 
-##### `--generate` or `-g`
-- Number of responses to generate (Warning: generating multiple costs more) (default: **1**)
-
-```sh
-aipick --generate <i> # or -g <i>
-```
-
-> Warning: this uses more tokens, meaning it costs more.
-
 ##### `--systemPrompt` or `-s`
 - System prompt to let users fine-tune prompt
 
@@ -242,7 +233,6 @@ aipick config set OPENAI.key=<your-api-key> GEMINI.temperature=3
 | `OLLAMA_HOST` | `http://localhost:11434` | The Ollama Host |
 | `OLLAMA_TIMEOUT` | `100_000` ms | Request timeout for the Ollama |
 | `locale` | `en` | Locale for the generated commit messages |
-| `generate` | `1` | Number of commit messages to generate |
 | `type` | `conventional` | Type of commit message to generate |
 | `proxy` | N/A | Set a HTTP/HTTPS proxy to use for requests(only **OpenAI**) |
 | `timeout` | `10_000` ms | Network request timeout |
2 changes: 1 addition & 1 deletion package.json
@@ -72,7 +72,7 @@
 "formdata-node": "^6.0.3",
 "groq-sdk": "^0.4.0",
 "inquirer": "9.2.8",
-"inquirer-reactive-list-prompt": "^1.0.8",
+"inquirer-reactive-list-prompt": "^1.0.9",
 "ollama": "^0.5.6",
 "ora": "^8.0.1",
 "readline": "^1.3.0",
8 changes: 4 additions & 4 deletions pnpm-lock.yaml

Some generated files are not rendered by default.

7 changes: 1 addition & 6 deletions src/cli.ts
@@ -12,11 +12,6 @@ cli(
         name: 'aipick',
         version,
         flags: {
-            generate: {
-                type: Number,
-                description: 'Number of responses to generate (Warning: generating multiple costs more) (default: 1)',
-                alias: 'g',
-            },
             message: {
                 type: String,
                 description: 'Message to ask to AI',
@@ -38,7 +33,7 @@ cli(
         ignoreArgv: type => type === 'unknown-flag' || type === 'argument',
     },
     argv => {
-        aipick(argv.flags.generate, argv.flags.message, argv.flags.systemPrompt, rawArgv);
+        aipick(argv.flags.message, argv.flags.systemPrompt, rawArgv);
     },
     rawArgv
 );
3 changes: 1 addition & 2 deletions src/commands/aipick.ts
@@ -11,7 +11,7 @@ import { KnownError, handleCliError } from '../utils/error.js';
 
 const consoleManager = new ConsoleManager();
 
-export default async (generate: number | undefined, message: string | undefined, systemPrompt: string | undefined, rawArgv: string[]) =>
+export default async (message: string | undefined, systemPrompt: string | undefined, rawArgv: string[]) =>
     (async () => {
         consoleManager.printTitle();
 
@@ -21,7 +21,6 @@ export default async (generate: number | undefined, message: string | undefined,
 
         const config = await getConfig(
             {
-                generate: generate as number,
                 systemPrompt: systemPrompt?.toString() as string,
             },
             rawArgv
4 changes: 3 additions & 1 deletion src/managers/reactive-prompt.manager.ts
@@ -30,9 +30,11 @@ export class ReactivePromptManager {
             emptyMessage: `⚠ ${emptyResponses}`,
             loop: false,
             showDescription,
-            descPageSize: 10,
+            descPageSize: 15,
             choices$: this.choices$,
             loader$: this.loader$,
+            // @ts-ignore ignore
+            pickKey: 'short',
         });
     }
 
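For context on the commit's headline change: choices passed to `inquirer-reactive-list-prompt` carry several keys (`name`, `short`, `value`, `description`), and the new `pickKey: 'short'` option (not yet in the package's type definitions, hence the `@ts-ignore`) tells the prompt which key to echo back once a choice is confirmed. A minimal sketch of the choice shape, with the `pickKey` behavior described in comments as an assumption inferred from this commit's intent:

```ts
import { ReactiveListChoice } from 'inquirer-reactive-list-prompt';

// Shape of the choices the AI services below produce (values are illustrative).
const choice: ReactiveListChoice = {
    name: 'Anthropic The sky appears blue because...', // long label with service prefix, shown in the list
    short: 'The sky appears blue because...',          // concise title, no service prefix
    value: 'The sky appears blue because of Rayleigh scattering.', // full text returned to the caller
    description: 'The sky appears blue because of Rayleigh scattering.',
    isError: false,
};

// Assumed effect: with pickKey: 'short', the prompt echoes choice.short after
// selection instead of the longer choice.name, which is the UX improvement
// the commit title refers to.
```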
60 changes: 6 additions & 54 deletions src/services/ai/ai.service.ts
@@ -2,17 +2,13 @@ import { ReactiveListChoice } from 'inquirer-reactive-list-prompt';
 import { Observable, of } from 'rxjs';
 
 import { ModelConfig, ModelName } from '../../utils/config.js';
+import { getFirstWordsFrom } from '../../utils/utils.js';
 
 export interface AIResponse {
     title: string;
     value: string;
 }
 
-export interface RawAIResponse {
-    summary: string;
-    description?: string;
-}
-
 export interface AIServiceParams {
     config: ModelConfig<ModelName>;
     userMessage: string;
@@ -56,57 +52,13 @@ export abstract class AIService {
         });
     };
 
-    protected sanitizeResponse(generatedText: string, maxCount: number, ignoreBody: boolean): AIResponse[] {
+    protected sanitizeResponse(generatedText: string, ignoreBody: boolean): AIResponse[] {
         try {
-            const rawResponses: RawAIResponse[] = JSON.parse(generatedText);
-            const filtedResponses = rawResponses.map((data: RawAIResponse) => {
-                if (ignoreBody) {
-                    return {
-                        title: `${data.summary}`,
-                        value: `${data.summary}`,
-                    };
-                }
-                return {
-                    title: `${data.summary}`,
-                    value: `${data.summary}${data.description ? `\n\n${data.description}` : ''}`,
-                };
-            });
-
-            if (filtedResponses.length > maxCount) {
-                return filtedResponses.slice(0, maxCount);
-            }
-            return filtedResponses;
+            const title = `${getFirstWordsFrom(generatedText)}...`;
+            const value = generatedText;
+            return [{ title, value }];
         } catch (error) {
-            const jsonPattern = /\[[\s\S]*?\]/;
-            try {
-                const jsonMatch = generatedText.match(jsonPattern);
-                if (!jsonMatch) {
-                    // No valid JSON array found in the response
-                    return [];
-                }
-                const jsonStr = jsonMatch[0];
-                const rawResponses: RawAIResponse[] = JSON.parse(jsonStr);
-                const filtedResponses = rawResponses.map((data: RawAIResponse) => {
-                    if (ignoreBody) {
-                        return {
-                            title: `${data.summary}`,
-                            value: `${data.summary}`,
-                        };
-                    }
-                    return {
-                        title: `${data.summary}`,
-                        value: `${data.summary}${data.description ? `\n\n${data.description}` : ''}`,
-                    };
-                });
-
-                if (filtedResponses.length > maxCount) {
-                    return filtedResponses.slice(0, maxCount);
-                }
-                return filtedResponses;
-            } catch (e) {
-                // Error parsing JSON
-                return [];
-            }
+            return [];
         }
     }
 }
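`sanitizeResponse` no longer parses the completion as a JSON array of `{ summary, description }` objects; the raw text becomes the value, and the title is built from its first few words. `getFirstWordsFrom` is imported from `src/utils/utils.js` but its implementation is not part of this diff; a plausible sketch, assuming it simply returns the first `count` whitespace-separated words (the signature and default count are guesses):

```ts
// Hypothetical reconstruction of the helper in src/utils/utils.ts (not shown
// in this diff). Assumed behavior: first `count` words of the input string.
export const getFirstWordsFrom = (input: string, count = 5): string =>
    input.trim().split(/\s+/).slice(0, count).join(' ');

// getFirstWordsFrom('The sky appears blue because of Rayleigh scattering')
// => 'The sky appears blue because'
```

Note that the new signature keeps the `ignoreBody` parameter even though the single-response path shown here no longer reads it.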
6 changes: 3 additions & 3 deletions src/services/ai/anthropic.service.ts
@@ -36,6 +36,7 @@ export class AnthropicService extends AIService {
             concatMap(messages => from(messages)),
             map(data => ({
                 name: `${this.serviceName} ${data.title}`,
+                short: data.title,
                 value: data.value,
                 description: data.value,
                 isError: false,
@@ -47,11 +48,10 @@
     private async generateResponses(): Promise<AIResponse[]> {
         try {
             const userMessage = this.params.userMessage;
-            const { generate, systemPrompt, systemPromptPath, logging, temperature } = this.params.config;
+            const { systemPrompt, systemPromptPath, logging, temperature } = this.params.config;
 
             const promptOptions: PromptOptions = {
                 ...DEFAULT_PROMPT_OPTIONS,
-                generate,
                 userMessage,
                 systemPrompt,
                 systemPromptPath,
@@ -73,7 +73,7 @@
             const result: Anthropic.Message = await this.anthropic.messages.create(params);
             const completion = result.content.map(({ text }) => text).join('');
             logging && createLogResponse('Anthropic', userMessage, generatedSystemPrompt, completion);
-            return this.sanitizeResponse(completion, generate, this.params.config.ignoreBody);
+            return this.sanitizeResponse(completion, this.params.config.ignoreBody);
         } catch (error) {
             const errorAsAny = error as any;
             if (errorAsAny.code === 'ENOTFOUND') {
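The service files below all repeat the pattern visible above: each provider's `AIResponse[]` promise is flattened into a stream of choices, and the added `short` field carries the bare title without the `serviceName` prefix. Distilled into a sketch (not a verbatim excerpt; rxjs 7-style imports and the helper name are assumptions):

```ts
import { Observable, from, concatMap, map } from 'rxjs';
import { ReactiveListChoice } from 'inquirer-reactive-list-prompt';

interface AIResponse {
    title: string;
    value: string;
}

// Shared mapping applied in every provider service touched by this commit.
const toChoices$ = (serviceName: string, responses: Promise<AIResponse[]>): Observable<ReactiveListChoice> =>
    from(responses).pipe(
        concatMap(messages => from(messages)), // emit each AIResponse individually
        map(data => ({
            name: `${serviceName} ${data.title}`, // long list label with provider prefix
            short: data.title,                    // echoed after selection via pickKey: 'short'
            value: data.value,
            description: data.value,
            isError: false,
        }))
    );
```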
6 changes: 3 additions & 3 deletions src/services/ai/codestral.service.ts
@@ -33,6 +33,7 @@ export class CodestralService extends AIService {
             concatMap(messages => from(messages)),
             map(data => ({
                 name: `${this.serviceName} ${data.title}`,
+                short: data.title,
                 value: data.value,
                 description: data.value,
                 isError: false,
@@ -43,10 +44,9 @@
     private async generateResponses(): Promise<AIResponse[]> {
         try {
             const userMessage = this.params.userMessage;
-            const { generate, systemPrompt, systemPromptPath, logging, temperature } = this.params.config;
+            const { systemPrompt, systemPromptPath, logging, temperature } = this.params.config;
             const promptOptions: PromptOptions = {
                 ...DEFAULT_PROMPT_OPTIONS,
-                generate,
                 userMessage,
                 systemPrompt,
                 systemPromptPath,
@@ -55,7 +55,7 @@
             this.checkAvailableModels();
             const chatResponse = await this.createChatCompletions(generatedSystemPrompt, userMessage);
             logging && createLogResponse('Codestral', userMessage, generatedSystemPrompt, chatResponse);
-            return this.sanitizeResponse(chatResponse, generate, this.params.config.ignoreBody);
+            return this.sanitizeResponse(chatResponse, this.params.config.ignoreBody);
         } catch (error) {
             const errorAsAny = error as any;
             if (errorAsAny.code === 'ENOTFOUND') {
6 changes: 3 additions & 3 deletions src/services/ai/cohere.service.ts
@@ -30,6 +30,7 @@ export class CohereService extends AIService {
             concatMap(messages => from(messages)),
             map(data => ({
                 name: `${this.serviceName} ${data.title}`,
+                short: data.title,
                 value: data.value,
                 description: data.value,
                 isError: false,
@@ -41,10 +42,9 @@
     private async generateResponses(): Promise<AIResponse[]> {
         try {
             const userMessage = this.params.userMessage;
-            const { generate, systemPrompt, systemPromptPath, logging, temperature } = this.params.config;
+            const { systemPrompt, systemPromptPath, logging, temperature } = this.params.config;
             const promptOptions: PromptOptions = {
                 ...DEFAULT_PROMPT_OPTIONS,
-                generate,
                 userMessage,
                 systemPrompt,
                 systemPromptPath,
@@ -62,7 +62,7 @@
             });
 
             logging && createLogResponse('Cohere', userMessage, generatedSystemPrompt, prediction.text);
-            return this.sanitizeResponse(prediction.text, generate, this.params.config.ignoreBody);
+            return this.sanitizeResponse(prediction.text, this.params.config.ignoreBody);
         } catch (error) {
             const errorAsAny = error as any;
             if (errorAsAny instanceof CohereTimeoutError) {
6 changes: 3 additions & 3 deletions src/services/ai/gemini.service.ts
@@ -28,6 +28,7 @@ export class GeminiService extends AIService {
             concatMap(messages => from(messages)),
             map(data => ({
                 name: `${this.serviceName} ${data.title}`,
+                short: data.title,
                 value: data.value,
                 description: data.value,
                 isError: false,
@@ -39,11 +40,10 @@
     private async generateResponses(): Promise<AIResponse[]> {
         try {
             const userMessage = this.params.userMessage;
-            const { generate, systemPrompt, systemPromptPath, logging, temperature } = this.params.config;
+            const { systemPrompt, systemPromptPath, logging, temperature } = this.params.config;
             const maxTokens = this.params.config['max-tokens'];
             const promptOptions: PromptOptions = {
                 ...DEFAULT_PROMPT_OPTIONS,
-                generate,
                 userMessage,
                 systemPrompt,
                 systemPromptPath,
@@ -63,7 +63,7 @@
             const completion = response.text();
 
             logging && createLogResponse('Gemini', userMessage, generatedSystemPrompt, completion);
-            return this.sanitizeResponse(completion, generate, this.params.config.ignoreBody);
+            return this.sanitizeResponse(completion, this.params.config.ignoreBody);
         } catch (error) {
             const errorAsAny = error as any;
             if (errorAsAny.code === 'ENOTFOUND') {
6 changes: 3 additions & 3 deletions src/services/ai/groq.service.ts
@@ -28,6 +28,7 @@ export class GroqService extends AIService {
             concatMap(messages => from(messages)),
             map(data => ({
                 name: `${this.serviceName} ${data.title}`,
+                short: data.title,
                 value: data.value,
                 description: data.value,
                 isError: false,
@@ -39,11 +40,10 @@
     private async generateResponses(): Promise<AIResponse[]> {
         try {
             const userMessage = this.params.userMessage;
-            const { generate, systemPrompt, systemPromptPath, logging, temperature } = this.params.config;
+            const { systemPrompt, systemPromptPath, logging, temperature } = this.params.config;
             const maxTokens = this.params.config['max-tokens'];
             const promptOptions: PromptOptions = {
                 ...DEFAULT_PROMPT_OPTIONS,
-                generate,
                 userMessage,
                 systemPrompt,
                 systemPromptPath,
@@ -73,7 +73,7 @@
 
             const result = chatCompletion.choices[0].message.content || '';
             logging && createLogResponse('Groq', userMessage, generatedSystemPrompt, result);
-            return this.sanitizeResponse(result, generate, this.params.config.ignoreBody);
+            return this.sanitizeResponse(result, this.params.config.ignoreBody);
         } catch (error) {
             throw error as any;
         }
6 changes: 3 additions & 3 deletions src/services/ai/hugging-face.service.ts
@@ -81,6 +81,7 @@ export class HuggingFaceService extends AIService {
             concatMap(messages => from(messages)),
             map(data => ({
                 name: `${this.serviceName} ${data.title}`,
+                short: data.title,
                 value: data.value,
                 description: data.value,
                 isError: false,
@@ -94,10 +95,9 @@
             await this.intialize();
 
             const userMessage = this.params.userMessage;
-            const { generate, systemPrompt, systemPromptPath, logging, temperature } = this.params.config;
+            const { systemPrompt, systemPromptPath, logging, temperature } = this.params.config;
             const promptOptions: PromptOptions = {
                 ...DEFAULT_PROMPT_OPTIONS,
-                generate,
                 userMessage,
                 systemPrompt,
                 systemPromptPath,
@@ -110,7 +110,7 @@
             // await this.deleteConversation(conversation.id);
 
             logging && createLogResponse('HuggingFace', userMessage, generatedSystemPrompt, response);
-            return this.sanitizeResponse(response, generate, this.params.config.ignoreBody);
+            return this.sanitizeResponse(response, this.params.config.ignoreBody);
         } catch (error) {
             const errorAsAny = error as any;
             if (errorAsAny.code === 'ENOTFOUND') {
6 changes: 3 additions & 3 deletions src/services/ai/mistral.service.ts
@@ -63,6 +63,7 @@ export class MistralService extends AIService {
             concatMap(messages => from(messages)),
             map(data => ({
                 name: `${this.serviceName} ${data.title}`,
+                short: data.title,
                 value: data.value,
                 description: data.value,
                 isError: false,
@@ -74,10 +75,9 @@
     private async generateMessages(): Promise<AIResponse[]> {
         try {
             const userMessage = this.params.userMessage;
-            const { generate, systemPrompt, systemPromptPath, logging, temperature } = this.params.config;
+            const { systemPrompt, systemPromptPath, logging, temperature } = this.params.config;
             const promptOptions: PromptOptions = {
                 ...DEFAULT_PROMPT_OPTIONS,
-                generate,
                 userMessage,
                 systemPrompt,
                 systemPromptPath,
@@ -87,7 +87,7 @@
             await this.checkAvailableModels();
             const chatResponse = await this.createChatCompletions(generatedSystemPrompt, userMessage);
             logging && createLogResponse('MistralAI', userMessage, generatedSystemPrompt, chatResponse);
-            return this.sanitizeResponse(chatResponse, generate, this.params.config.ignoreBody);
+            return this.sanitizeResponse(chatResponse, this.params.config.ignoreBody);
         } catch (error) {
             const errorAsAny = error as any;
             if (errorAsAny.code === 'ENOTFOUND') {