fix: fix several issues with smart mode (function calling)

Count the tokens of function definitions and function_call payloads when
building messages, pass completionParams through to _buildMessages, exclude
function messages from the prompt estimate, and make _getTokenCount return 0
for empty input.

ikechan8370 2023-07-01 15:48:01 +08:00
parent 36e592e773
commit 9838459446
2 changed files with 170 additions and 32 deletions

@@ -163,7 +163,7 @@ var ChatGPTAPI = /** @class */ (function () {
name: opts.name
};
latestQuestion = message;
return [4 /*yield*/, this._buildMessages(text, role, opts)];
return [4 /*yield*/, this._buildMessages(text, role, opts, completionParams)];
case 1:
_c = _d.sent(), messages = _c.messages, maxTokens = _c.maxTokens, numTokens = _c.numTokens;
result = {
@@ -368,11 +368,11 @@ var ChatGPTAPI = /** @class */ (function () {
enumerable: false,
configurable: true
});
ChatGPTAPI.prototype._buildMessages = function (text, role, opts) {
ChatGPTAPI.prototype._buildMessages = function (text, role, opts, completionParams) {
return __awaiter(this, void 0, void 0, function () {
var _a, systemMessage, parentMessageId, userLabel, assistantLabel, maxNumTokens, messages, systemMessageOffset, nextMessages, numTokens, prompt_1, nextNumTokensEstimate, isValidPrompt, parentMessage, parentMessageRole, maxTokens;
return __generator(this, function (_b) {
switch (_b.label) {
var _a, systemMessage, parentMessageId, userLabel, assistantLabel, maxNumTokens, messages, systemMessageOffset, nextMessages, functionToken, numTokens, _i, _b, func, _c, _d, _e, _f, key, _g, property, _h, _j, field, _k, _l, _m, _o, _p, enumElement, _q, prompt_1, nextNumTokensEstimate, _r, _s, m1, _t, isValidPrompt, parentMessage, parentMessageRole, maxTokens;
return __generator(this, function (_u) {
switch (_u.label) {
case 0:
_a = opts.systemMessage, systemMessage = _a === void 0 ? this._systemMessage : _a;
parentMessageId = opts.parentMessageId;
@@ -396,9 +396,85 @@ var ChatGPTAPI = /** @class */ (function () {
}
])
: messages;
numTokens = 0;
_b.label = 1;
functionToken = 0;
numTokens = 0;
_i = 0, _b = completionParams.functions || [];
_u.label = 1;
case 1:
if (!(_i < _b.length)) return [3 /*break*/, 19];
func = _b[_i];
_c = functionToken;
return [4 /*yield*/, this._getTokenCount(func.name)];
case 2:
functionToken = _c + _u.sent();
_d = functionToken;
return [4 /*yield*/, this._getTokenCount(func.description)];
case 3:
functionToken = _d + _u.sent();
if (!(func.parameters && func.parameters.properties)) return [3 /*break*/, 18];
_e = 0, _f = Object.keys(func.parameters.properties);
_u.label = 4;
case 4:
if (!(_e < _f.length)) return [3 /*break*/, 18];
key = _f[_e];
_g = functionToken;
return [4 /*yield*/, this._getTokenCount(key)];
case 5:
functionToken = _g + _u.sent();
property = func.parameters.properties[key];
_h = 0, _j = Object.keys(property);
_u.label = 6;
case 6:
if (!(_h < _j.length)) return [3 /*break*/, 17];
field = _j[_h];
_k = field;
switch (_k) {
case 'type': return [3 /*break*/, 7];
case 'description': return [3 /*break*/, 9];
case 'enum': return [3 /*break*/, 11];
}
return [3 /*break*/, 16];
case 7:
functionToken += 2;
_l = functionToken;
return [4 /*yield*/, this._getTokenCount(property.type)];
case 8:
functionToken = _l + _u.sent();
return [3 /*break*/, 16];
case 9:
functionToken += 2;
_m = functionToken;
return [4 /*yield*/, this._getTokenCount(property.description)];
case 10:
functionToken = _m + _u.sent();
return [3 /*break*/, 16];
case 11:
functionToken -= 3;
_o = 0, _p = property.enum;
_u.label = 12;
case 12:
if (!(_o < _p.length)) return [3 /*break*/, 15];
enumElement = _p[_o];
functionToken += 3;
_q = functionToken;
return [4 /*yield*/, this._getTokenCount(enumElement)];
case 13:
functionToken = _q + _u.sent();
_u.label = 14;
case 14:
_o++;
return [3 /*break*/, 12];
case 15: return [3 /*break*/, 16];
case 16:
_h++;
return [3 /*break*/, 6];
case 17:
_e++;
return [3 /*break*/, 4];
case 18:
_i++;
return [3 /*break*/, 1];
case 19:
prompt_1 = nextMessages
.reduce(function (prompt, message) {
switch (message.role) {
@@ -407,32 +483,48 @@ var ChatGPTAPI = /** @class */ (function () {
case 'user':
return prompt.concat(["".concat(userLabel, ":\n").concat(message.content)]);
case 'function':
return prompt.concat(["Function:\n".concat(message.content)]);
// function messages are intentionally excluded from the prompt token estimate
return prompt;
default:
return message.content ? prompt.concat(["".concat(assistantLabel, ":\n").concat(message.content)]) : prompt;
}
}, [])
.join('\n\n');
return [4 /*yield*/, this._getTokenCount(prompt_1)];
case 2:
nextNumTokensEstimate = _b.sent();
isValidPrompt = nextNumTokensEstimate <= maxNumTokens;
case 20:
nextNumTokensEstimate = _u.sent();
_r = 0, _s = nextMessages
.filter(function (m) { return m.function_call; });
_u.label = 21;
case 21:
if (!(_r < _s.length)) return [3 /*break*/, 24];
m1 = _s[_r];
_t = nextNumTokensEstimate;
return [4 /*yield*/, this._getTokenCount(JSON.stringify(m1.function_call) || '')];
case 22:
nextNumTokensEstimate = _t + _u.sent();
_u.label = 23;
case 23:
_r++;
return [3 /*break*/, 21];
case 24:
isValidPrompt = nextNumTokensEstimate + functionToken <= maxNumTokens;
if (prompt_1 && !isValidPrompt) {
return [3 /*break*/, 5];
return [3 /*break*/, 27];
}
messages = nextMessages;
numTokens = nextNumTokensEstimate;
if (!isValidPrompt) {
return [3 /*break*/, 5];
return [3 /*break*/, 27];
}
if (!parentMessageId) {
return [3 /*break*/, 5];
return [3 /*break*/, 27];
}
return [4 /*yield*/, this._getMessageById(parentMessageId)];
case 3:
parentMessage = _b.sent();
case 25:
parentMessage = _u.sent();
if (!parentMessage) {
return [3 /*break*/, 5];
return [3 /*break*/, 27];
}
parentMessageRole = parentMessage.role || 'user';
nextMessages = nextMessages.slice(0, systemMessageOffset).concat(__spreadArray([
@@ -440,16 +532,15 @@ var ChatGPTAPI = /** @class */ (function () {
role: parentMessageRole,
content: parentMessage.text,
name: parentMessage.name,
function_call: parentMessage.functionCall ? parentMessage.functionCall : undefined,
function_call: parentMessage.functionCall ? parentMessage.functionCall : undefined
}
], nextMessages.slice(systemMessageOffset), true));
parentMessageId = parentMessage.parentMessageId;
_b.label = 4;
case 4:
if (true) return [3 /*break*/, 1];
_b.label = 5;
case 5:
_u.label = 26;
case 26:
if (true) return [3 /*break*/, 19];
_u.label = 27;
case 27:
maxTokens = Math.max(1, Math.min(this._maxModelTokens - numTokens, this._maxResponseTokens));
return [2 /*return*/, { messages: messages, maxTokens: maxTokens, numTokens: numTokens }];
}
@@ -457,6 +548,7 @@ var ChatGPTAPI = /** @class */ (function () {
});
};
ChatGPTAPI.prototype._getTokenCount = function (text) {
if (!text) return Promise.resolve(0);
return __awaiter(this, void 0, void 0, function () {
return __generator(this, function (_a) {
// TODO: use a better fix in the tokenizer

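The transpiled __generator state machine above (cases 1 through 18) is hard to read, so here is a minimal TypeScript sketch of the function-definition token estimate it implements; the readable source of the same change follows in the second file. This is a sketch under assumptions, not the library's API: getTokenCount stands in for the class's private _getTokenCount, and the fixed offsets (+2, -3, +3) simply mirror the diff's approximation of how OpenAI serializes function schemas into the prompt.

type FunctionDef = {
  name: string
  description?: string
  parameters?: {
    properties?: Record<string, { type?: string; description?: string; enum?: string[] }>
  }
}

// Sketch only: estimates how many prompt tokens a list of function
// definitions consumes, mirroring the loop added in this commit.
async function estimateFunctionTokens(
  functions: FunctionDef[],
  getTokenCount: (text: string) => Promise<number> // stand-in for _getTokenCount
): Promise<number> {
  let total = 0
  for (const func of functions) {
    total += await getTokenCount(func.name)
    if (func.description) total += await getTokenCount(func.description)
    for (const [key, property] of Object.entries(func.parameters?.properties ?? {})) {
      total += await getTokenCount(key)
      if (property.type) {
        total += 2 // fixed overhead for a "type" field
        total += await getTokenCount(property.type)
      }
      if (property.description) {
        total += 2 // fixed overhead for a "description" field
        total += await getTokenCount(property.description)
      }
      if (property.enum) {
        total -= 3 // an enum list replaces part of the default serialization
        for (const enumElement of property.enum) {
          total += 3
          total += await getTokenCount(enumElement)
        }
      }
    }
  }
  return total
}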
@@ -7,7 +7,7 @@ import * as tokenizer from './tokenizer'
import * as types from './types'
import globalFetch from 'node-fetch'
import { fetchSSE } from './fetch-sse'
import {Role} from "./types";
import {openai, Role} from "./types";
const CHATGPT_MODEL = 'gpt-3.5-turbo-0613'
@@ -172,7 +172,8 @@ export class ChatGPTAPI {
const { messages, maxTokens, numTokens } = await this._buildMessages(
text,
role,
opts
opts,
completionParams
)
const result: types.ChatMessage = {
@@ -378,7 +379,9 @@ export class ChatGPTAPI {
this._apiOrg = apiOrg
}
protected async _buildMessages(text: string, role: Role, opts: types.SendMessageOptions) {
protected async _buildMessages(text: string, role: Role, opts: types.SendMessageOptions, completionParams: Partial<
Omit<openai.CreateChatCompletionRequest, 'messages' | 'n' | 'stream'>
>) {
const { systemMessage = this._systemMessage } = opts
let { parentMessageId } = opts
@@ -405,8 +408,42 @@ export class ChatGPTAPI {
}
])
: messages
let numTokens = 0
let functionToken = 0
let numTokens = 0
// estimate the tokens consumed by the function definitions themselves
for (const func of completionParams.functions || []) {
functionToken += await this._getTokenCount(func.name)
functionToken += await this._getTokenCount(func.description)
if (func.parameters?.properties) {
for (let key of Object.keys(func.parameters.properties)) {
functionToken += await this._getTokenCount(key)
let property = func.parameters.properties[key]
for (let field of Object.keys(property)) {
switch (field) {
case 'type': {
functionToken += 2
functionToken += await this._getTokenCount(property.type)
break
}
case 'description': {
functionToken += 2
functionToken += await this._getTokenCount(property.description)
break
}
case 'enum': {
functionToken -= 3
for (let enumElement of property.enum) {
functionToken += 3
functionToken += await this._getTokenCount(enumElement)
}
break
}
}
}
}
}
}
do {
const prompt = nextMessages
.reduce((prompt, message) => {
@@ -416,20 +453,26 @@ export class ChatGPTAPI {
case 'user':
return prompt.concat([`${userLabel}:\n${message.content}`])
case 'function':
return prompt.concat([`Function:\n${message.content}`])
// function messages are intentionally excluded from the prompt token estimate
return prompt
default:
return message.content ? prompt.concat([`${assistantLabel}:\n${message.content}`]) : prompt
}
}, [] as string[])
.join('\n\n')
const nextNumTokensEstimate = await this._getTokenCount(prompt)
const isValidPrompt = nextNumTokensEstimate <= maxNumTokens
let nextNumTokensEstimate = await this._getTokenCount(prompt)
// function_call payloads are not part of the prompt text, so count them separately
for (const m1 of nextMessages.filter((m) => m.function_call)) {
nextNumTokensEstimate += await this._getTokenCount(JSON.stringify(m1.function_call) || '')
}
const isValidPrompt = nextNumTokensEstimate + functionToken <= maxNumTokens
if (prompt && !isValidPrompt) {
break
}
messages = nextMessages
numTokens = nextNumTokensEstimate
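
For reference, the function_call filter in the loop above matches assistant turns shaped like the following sketch (the name and arguments are illustrative, not taken from this commit); it is the JSON.stringify of that function_call object whose token count is added to the estimate.

const assistantTurn = {
  role: 'assistant',
  content: null, // assistant messages that call a function carry no text content
  function_call: {
    name: 'get_weather', // hypothetical function name
    arguments: '{"city":"Beijing"}' // a JSON string, per the OpenAI chat API
  }
}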
@@ -472,6 +515,9 @@ export class ChatGPTAPI {
}
protected async _getTokenCount(text: string) {
if (!text) {
return 0
}
// TODO: use a better fix in the tokenizer
text = text.replace(/<\|endoftext\|>/g, '')
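
Taken together, these changes make function definitions and function_call payloads count against the context budget when history is trimmed. A hedged usage sketch follows; the package name, the function definition, and the model choice are illustrative assumptions, with sendMessage accepting per-call completionParams as in the upstream chatgpt package.

import { ChatGPTAPI } from 'chatgpt' // assumption: this fork is consumed like the upstream package

const api = new ChatGPTAPI({ apiKey: process.env.OPENAI_API_KEY! })

// After this commit, the schema tokens of this function definition are
// included when _buildMessages trims history to fit the model's context.
const res = await api.sendMessage('What is the weather in Beijing?', {
  completionParams: {
    model: 'gpt-3.5-turbo-0613',
    functions: [
      {
        name: 'get_weather', // hypothetical function, for illustration only
        description: 'Look up the current weather for a city',
        parameters: {
          type: 'object',
          properties: {
            city: { type: 'string', description: 'City name' },
            unit: { type: 'string', enum: ['celsius', 'fahrenheit'] }
          },
          required: ['city']
        }
      }
    ]
  }
})
console.log(res.text)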