fix: 修复openai client由于计算token引起的错误

This commit is contained in:
ikechan8370 2023-07-01 21:50:26 +08:00
parent 285afad993
commit 80b320ff3a
2 changed files with 123 additions and 102 deletions

View file

@ -373,12 +373,13 @@ var ChatGPTAPI = /** @class */ (function () {
configurable: true configurable: true
}); });
ChatGPTAPI.prototype._buildMessages = function (text, role, opts, completionParams) { ChatGPTAPI.prototype._buildMessages = function (text, role, opts, completionParams) {
var _a, _b;
return __awaiter(this, void 0, void 0, function () { return __awaiter(this, void 0, void 0, function () {
var _a, systemMessage, parentMessageId, userLabel, assistantLabel, maxNumTokens, messages, systemMessageOffset, nextMessages, functionToken, numTokens, _i, _b, func, _c, _d, _e, _f, key, _g, property, _h, _j, field, _k, _l, _m, _o, _p, enumElement, _q, prompt_1, nextNumTokensEstimate, _r, _s, m1, _t, isValidPrompt, parentMessage, parentMessageRole, maxTokens; var _c, systemMessage, parentMessageId, userLabel, assistantLabel, maxNumTokens, messages, systemMessageOffset, nextMessages, functionToken, numTokens, _i, _d, func, _e, _f, _g, _h, key, _j, property, _k, _l, field, _m, _o, _p, _q, _r, enumElement, _s, _t, _u, string, _v, prompt_1, nextNumTokensEstimate, _w, _x, m1, _y, isValidPrompt, parentMessage, parentMessageRole, maxTokens;
return __generator(this, function (_u) { return __generator(this, function (_z) {
switch (_u.label) { switch (_z.label) {
case 0: case 0:
_a = opts.systemMessage, systemMessage = _a === void 0 ? this._systemMessage : _a; _c = opts.systemMessage, systemMessage = _c === void 0 ? this._systemMessage : _c;
parentMessageId = opts.parentMessageId; parentMessageId = opts.parentMessageId;
userLabel = USER_LABEL_DEFAULT; userLabel = USER_LABEL_DEFAULT;
assistantLabel = ASSISTANT_LABEL_DEFAULT; assistantLabel = ASSISTANT_LABEL_DEFAULT;
@ -402,83 +403,100 @@ var ChatGPTAPI = /** @class */ (function () {
: messages; : messages;
functionToken = 0; functionToken = 0;
numTokens = functionToken; numTokens = functionToken;
_i = 0, _b = completionParams.functions; if (!completionParams.functions) return [3 /*break*/, 23];
_u.label = 1; _i = 0, _d = completionParams.functions;
_z.label = 1;
case 1: case 1:
if (!(_i < _b.length)) return [3 /*break*/, 19]; if (!(_i < _d.length)) return [3 /*break*/, 23];
func = _b[_i]; func = _d[_i];
_c = functionToken; _e = functionToken;
return [4 /*yield*/, this._getTokenCount(func.name)]; return [4 /*yield*/, this._getTokenCount(func === null || func === void 0 ? void 0 : func.name)];
case 2: case 2:
functionToken = _c + _u.sent(); functionToken = _e + _z.sent();
_d = functionToken; _f = functionToken;
return [4 /*yield*/, this._getTokenCount(func.description)]; return [4 /*yield*/, this._getTokenCount(func === null || func === void 0 ? void 0 : func.description)];
case 3: case 3:
functionToken = _d + _u.sent(); functionToken = _f + _z.sent();
if (!func.parameters.properties) return [3 /*break*/, 18]; if (!((_a = func === null || func === void 0 ? void 0 : func.parameters) === null || _a === void 0 ? void 0 : _a.properties)) return [3 /*break*/, 18];
_e = 0, _f = Object.keys(func.parameters.properties); _g = 0, _h = Object.keys(func.parameters.properties);
_u.label = 4; _z.label = 4;
case 4: case 4:
if (!(_e < _f.length)) return [3 /*break*/, 18]; if (!(_g < _h.length)) return [3 /*break*/, 18];
key = _f[_e]; key = _h[_g];
_g = functionToken; _j = functionToken;
return [4 /*yield*/, this._getTokenCount(key)]; return [4 /*yield*/, this._getTokenCount(key)];
case 5: case 5:
functionToken = _g + _u.sent(); functionToken = _j + _z.sent();
property = func.parameters.properties[key]; property = func.parameters.properties[key];
_h = 0, _j = Object.keys(property); _k = 0, _l = Object.keys(property);
_u.label = 6; _z.label = 6;
case 6: case 6:
if (!(_h < _j.length)) return [3 /*break*/, 17]; if (!(_k < _l.length)) return [3 /*break*/, 17];
field = _j[_h]; field = _l[_k];
_k = field; _m = field;
switch (_k) { switch (_m) {
case 'type': return [3 /*break*/, 7]; case 'type': return [3 /*break*/, 7];
case 'description': return [3 /*break*/, 9]; case 'description': return [3 /*break*/, 9];
case 'field': return [3 /*break*/, 11]; case 'enum': return [3 /*break*/, 11];
} }
return [3 /*break*/, 16]; return [3 /*break*/, 16];
case 7: case 7:
functionToken += 2; functionToken += 2;
_l = functionToken; _o = functionToken;
return [4 /*yield*/, this._getTokenCount(property.type)]; return [4 /*yield*/, this._getTokenCount(property === null || property === void 0 ? void 0 : property.type)];
case 8: case 8:
functionToken = _l + _u.sent(); functionToken = _o + _z.sent();
return [3 /*break*/, 16]; return [3 /*break*/, 16];
case 9: case 9:
functionToken += 2; functionToken += 2;
_m = functionToken; _p = functionToken;
return [4 /*yield*/, this._getTokenCount(property.description)]; return [4 /*yield*/, this._getTokenCount(property === null || property === void 0 ? void 0 : property.description)];
case 10: case 10:
functionToken = _m + _u.sent(); functionToken = _p + _z.sent();
return [3 /*break*/, 16]; return [3 /*break*/, 16];
case 11: case 11:
functionToken -= 3; functionToken -= 3;
_o = 0, _p = property.enum; _q = 0, _r = property === null || property === void 0 ? void 0 : property.enum;
_u.label = 12; _z.label = 12;
case 12: case 12:
if (!(_o < _p.length)) return [3 /*break*/, 15]; if (!(_q < _r.length)) return [3 /*break*/, 15];
enumElement = _p[_o]; enumElement = _r[_q];
functionToken += 3; functionToken += 3;
_q = functionToken; _s = functionToken;
return [4 /*yield*/, this._getTokenCount(enumElement)]; return [4 /*yield*/, this._getTokenCount(enumElement)];
case 13: case 13:
functionToken = _q + _u.sent(); functionToken = _s + _z.sent();
_u.label = 14; _z.label = 14;
case 14: case 14:
_o++; _q++;
return [3 /*break*/, 12]; return [3 /*break*/, 12];
case 15: return [3 /*break*/, 16]; case 15: return [3 /*break*/, 16];
case 16: case 16:
_h++; _k++;
return [3 /*break*/, 6]; return [3 /*break*/, 6];
case 17: case 17:
_e++; _g++;
return [3 /*break*/, 4]; return [3 /*break*/, 4];
case 18: case 18:
if (!((_b = func === null || func === void 0 ? void 0 : func.parameters) === null || _b === void 0 ? void 0 : _b.required)) return [3 /*break*/, 22];
_t = 0, _u = func.parameters.required;
_z.label = 19;
case 19:
if (!(_t < _u.length)) return [3 /*break*/, 22];
string = _u[_t];
functionToken += 2;
_v = functionToken;
return [4 /*yield*/, this._getTokenCount(string)];
case 20:
functionToken = _v + _z.sent();
_z.label = 21;
case 21:
_t++;
return [3 /*break*/, 19];
case 22:
_i++; _i++;
return [3 /*break*/, 1]; return [3 /*break*/, 1];
case 19: case 23:
prompt_1 = nextMessages prompt_1 = nextMessages
.reduce(function (prompt, message) { .reduce(function (prompt, message) {
switch (message.role) { switch (message.role) {
@ -495,40 +513,40 @@ var ChatGPTAPI = /** @class */ (function () {
}, []) }, [])
.join('\n\n'); .join('\n\n');
return [4 /*yield*/, this._getTokenCount(prompt_1)]; return [4 /*yield*/, this._getTokenCount(prompt_1)];
case 20:
nextNumTokensEstimate = _u.sent();
_r = 0, _s = nextMessages
.filter(function (m) { return m.function_call; });
_u.label = 21;
case 21:
if (!(_r < _s.length)) return [3 /*break*/, 24];
m1 = _s[_r];
_t = nextNumTokensEstimate;
return [4 /*yield*/, this._getTokenCount(JSON.stringify(m1.function_call) || '')];
case 22:
nextNumTokensEstimate = _t + _u.sent();
_u.label = 23;
case 23:
_r++;
return [3 /*break*/, 21];
case 24: case 24:
nextNumTokensEstimate = _z.sent();
_w = 0, _x = nextMessages
.filter(function (m) { return m.function_call; });
_z.label = 25;
case 25:
if (!(_w < _x.length)) return [3 /*break*/, 28];
m1 = _x[_w];
_y = nextNumTokensEstimate;
return [4 /*yield*/, this._getTokenCount(JSON.stringify(m1.function_call) || '')];
case 26:
nextNumTokensEstimate = _y + _z.sent();
_z.label = 27;
case 27:
_w++;
return [3 /*break*/, 25];
case 28:
isValidPrompt = nextNumTokensEstimate + functionToken <= maxNumTokens; isValidPrompt = nextNumTokensEstimate + functionToken <= maxNumTokens;
if (prompt_1 && !isValidPrompt) { if (prompt_1 && !isValidPrompt) {
return [3 /*break*/, 27]; return [3 /*break*/, 31];
} }
messages = nextMessages; messages = nextMessages;
numTokens = nextNumTokensEstimate + functionToken; numTokens = nextNumTokensEstimate + functionToken;
if (!isValidPrompt) { if (!isValidPrompt) {
return [3 /*break*/, 27]; return [3 /*break*/, 31];
} }
if (!parentMessageId) { if (!parentMessageId) {
return [3 /*break*/, 27]; return [3 /*break*/, 31];
} }
return [4 /*yield*/, this._getMessageById(parentMessageId)]; return [4 /*yield*/, this._getMessageById(parentMessageId)];
case 25: case 29:
parentMessage = _u.sent(); parentMessage = _z.sent();
if (!parentMessage) { if (!parentMessage) {
return [3 /*break*/, 27]; return [3 /*break*/, 31];
} }
parentMessageRole = parentMessage.role || 'user'; parentMessageRole = parentMessage.role || 'user';
nextMessages = nextMessages.slice(0, systemMessageOffset).concat(__spreadArray([ nextMessages = nextMessages.slice(0, systemMessageOffset).concat(__spreadArray([
@ -540,11 +558,11 @@ var ChatGPTAPI = /** @class */ (function () {
} }
], nextMessages.slice(systemMessageOffset), true)); ], nextMessages.slice(systemMessageOffset), true));
parentMessageId = parentMessage.parentMessageId; parentMessageId = parentMessage.parentMessageId;
_u.label = 26; _z.label = 30;
case 26: case 30:
if (true) return [3 /*break*/, 19]; if (true) return [3 /*break*/, 23];
_u.label = 27; _z.label = 31;
case 27: case 31:
maxTokens = Math.max(1, Math.min(this._maxModelTokens - numTokens, this._maxResponseTokens)); maxTokens = Math.max(1, Math.min(this._maxModelTokens - numTokens, this._maxResponseTokens));
return [2 /*return*/, { messages: messages, maxTokens: maxTokens, numTokens: numTokens }]; return [2 /*return*/, { messages: messages, maxTokens: maxTokens, numTokens: numTokens }];
} }

View file

@ -415,44 +415,47 @@ export class ChatGPTAPI {
let functionToken = 0 let functionToken = 0
let numTokens = functionToken let numTokens = functionToken
for (const func of completionParams.functions) { if (completionParams.functions) {
functionToken += await this._getTokenCount(func.name) for (const func of completionParams.functions) {
functionToken += await this._getTokenCount(func.description) functionToken += await this._getTokenCount(func?.name)
if (func.parameters.properties) { functionToken += await this._getTokenCount(func?.description)
for (let key of Object.keys(func.parameters.properties)) { if (func?.parameters?.properties) {
functionToken += await this._getTokenCount(key) for (let key of Object.keys(func.parameters.properties)) {
let property = func.parameters.properties[key] functionToken += await this._getTokenCount(key)
for (let field of Object.keys(property)) { let property = func.parameters.properties[key]
switch (field) { for (let field of Object.keys(property)) {
case 'type': { switch (field) {
functionToken += 2 case 'type': {
functionToken += await this._getTokenCount(property.type) functionToken += 2
break functionToken += await this._getTokenCount(property?.type)
} break
case 'description': { }
functionToken += 2 case 'description': {
functionToken += await this._getTokenCount(property.description) functionToken += 2
break functionToken += await this._getTokenCount(property?.description)
} break
case 'field': { }
functionToken -= 3 case 'enum': {
for (let enumElement of property.enum) { functionToken -= 3
functionToken += 3 for (let enumElement of property?.enum) {
functionToken += await this._getTokenCount(enumElement) functionToken += 3
functionToken += await this._getTokenCount(enumElement)
}
break
} }
break
} }
} }
} }
} }
} if (func?.parameters?.required) {
if (func.parameters.required) { for (let string of func.parameters.required) {
for (let string of func.parameters.required) { functionToken += 2
functionToken += 2 functionToken += await this._getTokenCount(string)
functionToken += await this._getTokenCount(string) }
} }
} }
} }
do { do {
const prompt = nextMessages const prompt = nextMessages
.reduce((prompt, message) => { .reduce((prompt, message) => {