feat: 初步支持function call(WIP)

This commit is contained in:
ikechan8370 2023-06-23 01:09:12 +08:00
parent 4a4dceec18
commit 97b3acbf3b
24 changed files with 13607 additions and 841 deletions

View file

@ -7,8 +7,9 @@ import * as tokenizer from './tokenizer'
import * as types from './types'
import globalFetch from 'node-fetch'
import { fetchSSE } from './fetch-sse'
import {Role} from "./types";
const CHATGPT_MODEL = 'gpt-3.5-turbo'
const CHATGPT_MODEL = 'gpt-3.5-turbo-0613'
const USER_LABEL_DEFAULT = 'User'
const ASSISTANT_LABEL_DEFAULT = 'ChatGPT'
@ -136,7 +137,8 @@ export class ChatGPTAPI {
*/
async sendMessage(
text: string,
opts: types.SendMessageOptions = {}
opts: types.SendMessageOptions = {},
role: Role = 'user',
): Promise<types.ChatMessage> {
const {
parentMessageId,
@ -157,17 +159,19 @@ export class ChatGPTAPI {
}
const message: types.ChatMessage = {
role: 'user',
role,
id: messageId,
conversationId,
parentMessageId,
text
text,
name: opts.name
}
const latestQuestion = message
const { messages, maxTokens, numTokens } = await this._buildMessages(
text,
role,
opts
)
@ -176,7 +180,8 @@ export class ChatGPTAPI {
id: uuidv4(),
conversationId,
parentMessageId: messageId,
text: ''
text: undefined,
functionCall: undefined
}
const responseP = new Promise<types.ChatMessage>(
@ -228,9 +233,20 @@ export class ChatGPTAPI {
if (response.choices?.length) {
const delta = response.choices[0].delta
result.delta = delta.content
if (delta?.content) result.text += delta.content
if (delta.function_call) {
if (delta.function_call.name) {
result.functionCall = {
name: delta.function_call.name,
arguments: delta.function_call.arguments
}
} else {
result.functionCall.arguments = result.functionCall.arguments || '' + delta.function_call.arguments
}
} else {
result.delta = delta.content
if (delta?.content) result.text += delta.content
}
if (delta.role) {
result.role = delta.role
}
@ -278,7 +294,11 @@ export class ChatGPTAPI {
if (response?.choices?.length) {
const message = response.choices[0].message
result.text = message.content
if (message.content) {
result.text = message.content
} else if (message.function_call) {
result.functionCall = message.function_call
}
if (message.role) {
result.role = message.role
}
@ -358,7 +378,7 @@ export class ChatGPTAPI {
this._apiOrg = apiOrg
}
protected async _buildMessages(text: string, opts: types.SendMessageOptions) {
protected async _buildMessages(text: string, role: Role, opts: types.SendMessageOptions) {
const { systemMessage = this._systemMessage } = opts
let { parentMessageId } = opts
@ -379,7 +399,7 @@ export class ChatGPTAPI {
let nextMessages = text
? messages.concat([
{
role: 'user',
role,
content: text,
name: opts.name
}
@ -395,11 +415,13 @@ export class ChatGPTAPI {
return prompt.concat([`Instructions:\n${message.content}`])
case 'user':
return prompt.concat([`${userLabel}:\n${message.content}`])
case 'function':
return prompt.concat([`Function:\n${message.content}`])
default:
return prompt.concat([`${assistantLabel}:\n${message.content}`])
return message.content ? prompt.concat([`${assistantLabel}:\n${message.content}`]) : prompt
}
}, [] as string[])
.join('\n\n')
.join('\n\n')
const nextNumTokensEstimate = await this._getTokenCount(prompt)
const isValidPrompt = nextNumTokensEstimate <= maxNumTokens
@ -430,7 +452,8 @@ export class ChatGPTAPI {
{
role: parentMessageRole,
content: parentMessage.text,
name: parentMessage.name
name: parentMessage.name,
function_call: parentMessage.functionCall ? parentMessage.functionCall : undefined
},
...nextMessages.slice(systemMessageOffset)
])