Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/Return SourceDocuments #274

Merged
merged 1 commit into from
Jun 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ import { BaseLanguageModel } from 'langchain/base_language'
import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
import { ConversationalRetrievalQAChain } from 'langchain/chains'
import { BaseRetriever } from 'langchain/schema'
import { AIChatMessage, BaseRetriever, HumanChatMessage } from 'langchain/schema'
import { BaseChatMemory, BufferMemory, ChatMessageHistory } from 'langchain/memory'
import { PromptTemplate } from 'langchain/prompts'

const default_qa_template = `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

Expand Down Expand Up @@ -47,6 +49,12 @@ class ConversationalRetrievalQAChain_Chains implements INode {
name: 'vectorStoreRetriever',
type: 'BaseRetriever'
},
{
label: 'Return Source Documents',
name: 'returnSourceDocuments',
type: 'boolean',
optional: true
},
{
label: 'System Message',
name: 'systemMessagePrompt',
Expand All @@ -56,6 +64,31 @@ class ConversationalRetrievalQAChain_Chains implements INode {
optional: true,
placeholder:
'I want you to act as a document that I am having a conversation with. Your name is "AI Assistant". You will provide me with answers from the given info. If the answer is not included, say exactly "Hmm, I am not sure." and stop after that. Refuse to answer any question not about the info. Never break character.'
},
{
label: 'Chain Option',
name: 'chainOption',
type: 'options',
options: [
{
label: 'MapReduceDocumentsChain',
name: 'map_reduce',
description:
'Suitable for QA tasks over larger documents and can run the preprocessing step in parallel, reducing the running time'
},
{
label: 'RefineDocumentsChain',
name: 'refine',
description: 'Suitable for QA tasks over a large number of documents.'
},
{
label: 'StuffDocumentsChain',
name: 'stuff',
description: 'Suitable for QA tasks over a small number of documents.'
}
],
additionalParams: true,
optional: true
}
]
}
Expand All @@ -64,44 +97,64 @@ class ConversationalRetrievalQAChain_Chains implements INode {
const model = nodeData.inputs?.model as BaseLanguageModel
const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as BaseRetriever
const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
const chainOption = nodeData.inputs?.chainOption as string

const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, {
const obj: any = {
verbose: process.env.DEBUG === 'true' ? true : false,
qaTemplate: systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template
})
qaChainOptions: {
type: 'stuff',
prompt: PromptTemplate.fromTemplate(systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template)
},
memory: new BufferMemory({
memoryKey: 'chat_history',
inputKey: 'question',
outputKey: 'text',
returnMessages: true
})
}
if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments
if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption }

const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, obj)
return chain
}

async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
const chain = nodeData.instance as ConversationalRetrievalQAChain
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
let model = nodeData.inputs?.model

// Temporary fix: https://github.com/hwchase17/langchainjs/issues/754
model.streaming = false
chain.questionGeneratorChain.llm = model

let chatHistory = ''
const obj = { question: input }

if (options && options.chatHistory) {
if (chain.memory && options && options.chatHistory) {
const chatHistory = []
const histories: IMessage[] = options.chatHistory
chatHistory = histories
.map((item) => {
return item.message
})
.join('')
}

const obj = {
question: input,
chat_history: chatHistory ? chatHistory : []
const memory = chain.memory as BaseChatMemory

for (const message of histories) {
if (message.type === 'apiMessage') {
chatHistory.push(new AIChatMessage(message.message))
} else if (message.type === 'userMessage') {
chatHistory.push(new HumanChatMessage(message.message))
}
}
memory.chatHistory = new ChatMessageHistory(chatHistory)
chain.memory = memory
}

if (options.socketIO && options.socketIOClientId) {
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, undefined, returnSourceDocuments)
const res = await chain.call(obj, [handler])
if (res.text && res.sourceDocuments) return res
return res?.text
} else {
const res = await chain.call(obj)
if (res.text && res.sourceDocuments) return res
return res?.text
}
}
Expand Down
2 changes: 1 addition & 1 deletion packages/components/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"faiss-node": "^0.2.1",
"form-data": "^4.0.0",
"graphql": "^16.6.0",
"langchain": "^0.0.84",
"langchain": "^0.0.91",
"linkifyjs": "^4.1.1",
"mammoth": "^1.5.1",
"moment": "^2.29.3",
Expand Down
2 changes: 1 addition & 1 deletion packages/components/src/Interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ export interface INode extends INodeProperties {
inputs?: INodeParams[]
output?: INodeOutputsValue[]
init?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<any>
run?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<string>
run?(nodeData: INodeData, input: string, options?: ICommonObject): Promise<string | ICommonObject>
}

export interface INodeData extends INodeProperties {
Expand Down
11 changes: 10 additions & 1 deletion packages/components/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import * as fs from 'fs'
import * as path from 'path'
import { BaseCallbackHandler } from 'langchain/callbacks'
import { Server } from 'socket.io'
import { ChainValues } from 'langchain/dist/schema'

export const numberOrExpressionRegex = '^(\\d+\\.?\\d*|{{.*}})$' //return true if string consists only numbers OR expression {{}}
export const notEmptyRegex = '(.|\\s)*\\S(.|\\s)*' //return true if string is not empty or blank
Expand Down Expand Up @@ -208,12 +209,14 @@ export class CustomChainHandler extends BaseCallbackHandler {
socketIO: Server
socketIOClientId = ''
skipK = 0 // Skip streaming for first K numbers of handleLLMStart
returnSourceDocuments = false

constructor(socketIO: Server, socketIOClientId: string, skipK?: number) {
constructor(socketIO: Server, socketIOClientId: string, skipK?: number, returnSourceDocuments?: boolean) {
super()
this.socketIO = socketIO
this.socketIOClientId = socketIOClientId
this.skipK = skipK ?? this.skipK
this.returnSourceDocuments = returnSourceDocuments ?? this.returnSourceDocuments
}

handleLLMStart() {
Expand All @@ -233,4 +236,10 @@ export class CustomChainHandler extends BaseCallbackHandler {
handleLLMEnd() {
this.socketIO.to(this.socketIOClientId).emit('end')
}

handleChainEnd(outputs: ChainValues): void | Promise<void> {
if (this.returnSourceDocuments) {
this.socketIO.to(this.socketIOClientId).emit('sourceDocuments', outputs?.sourceDocuments)
}
}
}
1 change: 1 addition & 0 deletions packages/server/src/Interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ export interface IChatMessage {
content: string
chatflowid: string
createdDate: Date
sourceDocuments: string
}

export interface IComponentNodes {
Expand Down
3 changes: 3 additions & 0 deletions packages/server/src/entity/ChatMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ export class ChatMessage implements IChatMessage {
@Column()
content: string

@Column({ nullable: true })
sourceDocuments: string

@CreateDateColumn()
createdDate: Date
}
2 changes: 1 addition & 1 deletion packages/server/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ export class App {
const basicAuthMiddleware = basicAuth({
users: { [username]: password }
})
const whitelistURLs = ['/api/v1/prediction/', '/api/v1/node-icon/']
const whitelistURLs = ['/api/v1/prediction/', '/api/v1/node-icon/', '/api/v1/chatflows-streaming']
this.app.use((req, res, next) => {
if (req.url.includes('/api/v1/')) {
whitelistURLs.some((url) => req.url.includes(url)) ? next() : basicAuthMiddleware(req, res, next)
Expand Down
17 changes: 16 additions & 1 deletion packages/ui/src/layout/MainLayout/Header/ProfileSection/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ import PerfectScrollbar from 'react-perfect-scrollbar'
import MainCard from 'ui-component/cards/MainCard'
import Transitions from 'ui-component/extended/Transitions'
import { BackdropLoader } from 'ui-component/loading/BackdropLoader'
import AboutDialog from 'ui-component/dialog/AboutDialog'

// assets
import { IconLogout, IconSettings, IconFileExport, IconFileDownload } from '@tabler/icons'
import { IconLogout, IconSettings, IconFileExport, IconFileDownload, IconInfoCircle } from '@tabler/icons'

// API
import databaseApi from 'api/database'
Expand All @@ -49,6 +50,7 @@ const ProfileSection = ({ username, handleLogout }) => {

const [open, setOpen] = useState(false)
const [loading, setLoading] = useState(false)
const [aboutDialogOpen, setAboutDialogOpen] = useState(false)

const anchorRef = useRef(null)
const uploadRef = useRef(null)
Expand Down Expand Up @@ -215,6 +217,18 @@ const ProfileSection = ({ username, handleLogout }) => {
</ListItemIcon>
<ListItemText primary={<Typography variant='body2'>Export Database</Typography>} />
</ListItemButton>
<ListItemButton
sx={{ borderRadius: `${customization.borderRadius}px` }}
onClick={() => {
setOpen(false)
setAboutDialogOpen(true)
}}
>
<ListItemIcon>
<IconInfoCircle stroke={1.5} size='1.3rem' />
</ListItemIcon>
<ListItemText primary={<Typography variant='body2'>About Flowise</Typography>} />
</ListItemButton>
{localStorage.getItem('username') && localStorage.getItem('password') && (
<ListItemButton
sx={{ borderRadius: `${customization.borderRadius}px` }}
Expand All @@ -237,6 +251,7 @@ const ProfileSection = ({ username, handleLogout }) => {
</Popper>
<input ref={uploadRef} type='file' hidden accept='.json' onChange={(e) => handleFileUpload(e)} />
<BackdropLoader open={loading} />
<AboutDialog show={aboutDialogOpen} onCancel={() => setAboutDialogOpen(false)} />
</>
)
}
Expand Down
85 changes: 85 additions & 0 deletions packages/ui/src/ui-component/dialog/AboutDialog.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import { createPortal } from 'react-dom'
import { useState, useEffect } from 'react'
import PropTypes from 'prop-types'
import { Dialog, DialogContent, DialogTitle, TableContainer, Table, TableHead, TableRow, TableCell, TableBody, Paper } from '@mui/material'
import moment from 'moment'
import axios from 'axios'

const fetchLatestVer = async ({ api }) => {
let apiReturn = await axios
.get(api)
.then(async function (response) {
return response.data
})
.catch(function (error) {
console.error(error)
})
return apiReturn
}

const AboutDialog = ({ show, onCancel }) => {
const portalElement = document.getElementById('portal')

const [data, setData] = useState({})

useEffect(() => {
if (show) {
const fetchData = async (api) => {
let response = await fetchLatestVer({ api })
setData(response)
}

fetchData('https://api.github.com/repos/FlowiseAI/Flowise/releases/latest')
}

// eslint-disable-next-line react-hooks/exhaustive-deps
}, [show])

const component = show ? (
<Dialog
onClose={onCancel}
open={show}
fullWidth
maxWidth='sm'
aria-labelledby='alert-dialog-title'
aria-describedby='alert-dialog-description'
>
<DialogTitle sx={{ fontSize: '1rem' }} id='alert-dialog-title'>
Flowise Version
</DialogTitle>
<DialogContent>
{data && (
<TableContainer component={Paper}>
<Table aria-label='simple table'>
<TableHead>
<TableRow>
<TableCell>Latest Version</TableCell>
<TableCell>Published At</TableCell>
</TableRow>
</TableHead>
<TableBody>
<TableRow sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
<TableCell component='th' scope='row'>
<a target='_blank' rel='noreferrer' href={data.html_url}>
{data.name}
</a>
</TableCell>
<TableCell>{moment(data.published_at).fromNow()}</TableCell>
</TableRow>
</TableBody>
</Table>
</TableContainer>
)}
</DialogContent>
</Dialog>
) : null

return createPortal(component, portalElement)
}

AboutDialog.propTypes = {
show: PropTypes.bool,
onCancel: PropTypes.func
}

export default AboutDialog
Loading