diff --git a/src/commands/chitchatCommand.service.ts b/src/commands/chitchatCommand.service.ts index 8420a55..98607ed 100644 --- a/src/commands/chitchatCommand.service.ts +++ b/src/commands/chitchatCommand.service.ts @@ -3,7 +3,6 @@ import { ConfigService } from '@nestjs/config'; import { OnEvent } from '@nestjs/event-emitter'; import { stripIndent } from 'common-tags'; import { Message } from 'discord.js'; -import { concat } from 'rxjs'; import { EnvironmentVariables } from '../config/validate'; import { DISCORD_EVENTS } from '../discord/constants'; @@ -13,6 +12,19 @@ import { OpenAIModerationService } from '../openai/openai-moderation.service'; import { CommandsService } from './commands.service'; import { CommandService } from './types/CommandService'; +const pick = , K extends keyof T>( + obj: T, + keys: K[], +) => { + const ret = {} as Pick; + + keys.forEach((key) => { + ret[key] = obj[key]; + }); + + return ret; +}; + @Injectable() export class ChitchatCommandService implements CommandService { constructor( @@ -60,7 +72,7 @@ export class ChitchatCommandService implements CommandService { } } - private buildReplyChainMessageHistory(replyChain: Message[]) { + public buildReplyChainMessageHistory(replyChain: Message[]) { return replyChain.reverse().map((m) => { const memberId = m.member?.id; @@ -86,12 +98,6 @@ export class ChitchatCommandService implements CommandService { return message.split('\n===\n'); } - private async getPromptMessageContext(message: Message) { - const replyChain = await this.discordService.fetchReplyChain(message); - - return this.buildReplyChainMessageHistory(replyChain); - } - private async handleChitchatMessage(message: Message) { if (this.isGptEnabled) { return this.handleGptChitchat(message); @@ -126,8 +132,10 @@ export class ChitchatCommandService implements CommandService { const chatRequestMessage = this.createUserChatMessageFromDiscordMessage(message); - const replyChainMessageHistory = await this.getPromptMessageContext( - message, + const replyChain = await this.discordService.fetchReplyChain(message); + + const replyChainMessageHistory = await this.buildReplyChainMessageHistory( + replyChain, ); const userCompletionContent = replyChainMessageHistory @@ -146,13 +154,39 @@ export class ChitchatCommandService implements CommandService { const channel = await message.channel.fetch(); + const participantDetails = [ + ...new Set([ + ...replyChain.map((m) => m.member || m.author), + message.member || message.author, + ]), + ].map((author) => + pick(author as unknown as Record, [ + 'id', + 'username', + 'displayName', + ]), + ); + const messageChain = [ this.aiCompletionService.createSystemMessage(this.preamble), - stripIndent` - NAME: ${this.discordService.username} - ID: ${this.discordService.userId} - CHANNEL INFO: ${channel.toJSON()} - `, + this.aiCompletionService.createSystemMessage( + stripIndent` + ASSISTANT USER DETAILS: ${JSON.stringify( + pick( + this.discordService.user as unknown as Record, + ['id', 'name'], + ), + )} + CHANNEL DETAILS: ${JSON.stringify( + pick(channel as unknown as Record, [ + 'id', + 'name', + 'type', + ]), + )} + PARTICIPANT DETAILS: ${JSON.stringify(participantDetails)} + `, + ), ...replyChainMessageHistory, chatRequestMessage, this.aiCompletionService.createSystemMessage( @@ -166,13 +200,13 @@ export class ChitchatCommandService implements CommandService { const responseMessages = this.getCompletionResponseMessages(response); - return this.handleResponseMessages(message, responseMessages); + await this.handleResponseMessages(message, responseMessages); } catch (e) { if (e.response?.status === HttpStatus.TOO_MANY_REQUESTS) { - return message.reply('Out of credits... Please insert token.'); + await message.reply('Out of credits... Please insert token.'); + } else { + await message.reply('That one hurt my brain..'); } - - return message.reply('That one hurt my brain..'); } } } diff --git a/src/discord/discord.service.ts b/src/discord/discord.service.ts index f2105ea..61d7bd7 100644 --- a/src/discord/discord.service.ts +++ b/src/discord/discord.service.ts @@ -40,6 +40,10 @@ export class DiscordService { return this.discordClient.user?.username ?? 'ph8'; } + public get user() { + return this.discordClient.user; + } + public async fetchReplyChain(message: Message): Promise { const replyChain: Message[] = []; let m = message;