diff --git a/utils/openai/chatgpt-api.js b/utils/openai/chatgpt-api.js
index 4b1139e..49c2e17 100644
--- a/utils/openai/chatgpt-api.js
+++ b/utils/openai/chatgpt-api.js
@@ -163,7 +163,7 @@ var ChatGPTAPI = /** @class */ (function () {
                             name: opts.name
                         };
                         latestQuestion = message;
-                        return [4 /*yield*/, this._buildMessages(text, role, opts)];
+                        return [4 /*yield*/, this._buildMessages(text, role, opts, completionParams)];
                     case 1:
                         _c = _d.sent(), messages = _c.messages, maxTokens = _c.maxTokens, numTokens = _c.numTokens;
                         result = {
@@ -368,11 +368,11 @@ var ChatGPTAPI = /** @class */ (function () {
         enumerable: false,
         configurable: true
     });
-    ChatGPTAPI.prototype._buildMessages = function (text, role, opts) {
+    ChatGPTAPI.prototype._buildMessages = function (text, role, opts, completionParams) {
         return __awaiter(this, void 0, void 0, function () {
-            var _a, systemMessage, parentMessageId, userLabel, assistantLabel, maxNumTokens, messages, systemMessageOffset, nextMessages, numTokens, prompt_1, nextNumTokensEstimate, isValidPrompt, parentMessage, parentMessageRole, maxTokens;
-            return __generator(this, function (_b) {
-                switch (_b.label) {
+            var _a, systemMessage, parentMessageId, userLabel, assistantLabel, maxNumTokens, messages, systemMessageOffset, nextMessages, functionToken, numTokens, _i, _b, func, _c, _d, _e, _f, key, _g, property, _h, _j, field, _k, _l, _m, _o, _p, enumElement, _q, prompt_1, nextNumTokensEstimate, _r, _s, m1, _t, isValidPrompt, parentMessage, parentMessageRole, maxTokens;
+            return __generator(this, function (_u) {
+                switch (_u.label) {
                     case 0:
                         _a = opts.systemMessage, systemMessage = _a === void 0 ? this._systemMessage : _a;
                         parentMessageId = opts.parentMessageId;
@@ -396,9 +396,85 @@ var ChatGPTAPI = /** @class */ (function () {
                             }
                         ])
                             : messages;
-                        numTokens = 0;
-                        _b.label = 1;
+                        functionToken = 0;
+                        numTokens = functionToken;
+                        _i = 0, _b = completionParams.functions || [];
+                        _u.label = 1;
                     case 1:
+                        if (!(_i < _b.length)) return [3 /*break*/, 19];
+                        func = _b[_i];
+                        _c = functionToken;
+                        return [4 /*yield*/, this._getTokenCount(func.name)];
+                    case 2:
+                        functionToken = _c + _u.sent();
+                        _d = functionToken;
+                        return [4 /*yield*/, this._getTokenCount(func.description)];
+                    case 3:
+                        functionToken = _d + _u.sent();
+                        if (!func.parameters.properties) return [3 /*break*/, 18];
+                        _e = 0, _f = Object.keys(func.parameters.properties);
+                        _u.label = 4;
+                    case 4:
+                        if (!(_e < _f.length)) return [3 /*break*/, 18];
+                        key = _f[_e];
+                        _g = functionToken;
+                        return [4 /*yield*/, this._getTokenCount(key)];
+                    case 5:
+                        functionToken = _g + _u.sent();
+                        property = func.parameters.properties[key];
+                        _h = 0, _j = Object.keys(property);
+                        _u.label = 6;
+                    case 6:
+                        if (!(_h < _j.length)) return [3 /*break*/, 17];
+                        field = _j[_h];
+                        _k = field;
+                        switch (_k) {
+                            case 'type': return [3 /*break*/, 7];
+                            case 'description': return [3 /*break*/, 9];
+                            case 'enum': return [3 /*break*/, 11];
+                        }
+                        return [3 /*break*/, 16];
+                    case 7:
+                        functionToken += 2;
+                        _l = functionToken;
+                        return [4 /*yield*/, this._getTokenCount(property.type)];
+                    case 8:
+                        functionToken = _l + _u.sent();
+                        return [3 /*break*/, 16];
+                    case 9:
+                        functionToken += 2;
+                        _m = functionToken;
+                        return [4 /*yield*/, this._getTokenCount(property.description)];
+                    case 10:
+                        functionToken = _m + _u.sent();
+                        return [3 /*break*/, 16];
+                    case 11:
+                        functionToken -= 3;
+                        _o = 0, _p = property.enum;
+                        _u.label = 12;
+                    case 12:
+                        if (!(_o < _p.length)) return [3 /*break*/, 15];
+                        enumElement = _p[_o];
+                        functionToken += 3;
+                        _q = functionToken;
+                        return [4 /*yield*/, this._getTokenCount(enumElement)];
+                    case 13:
+                        functionToken = _q + _u.sent();
+                        _u.label = 14;
+                    case 14:
+                        _o++;
+                        return [3 /*break*/, 12];
+                    case 15: return [3 /*break*/, 16];
+                    case 16:
+                        _h++;
+                        return [3 /*break*/, 6];
+                    case 17:
+                        _e++;
+                        return [3 /*break*/, 4];
+                    case 18:
+                        _i++;
+                        return [3 /*break*/, 1];
+                    case 19:
                         prompt_1 = nextMessages
                             .reduce(function (prompt, message) {
                             switch (message.role) {
@@ -407,32 +483,48 @@ var ChatGPTAPI = /** @class */ (function () {
                                 case 'user':
                                     return prompt.concat(["".concat(userLabel, ":\n").concat(message.content)]);
                                 case 'function':
-                                    return prompt.concat(["Function:\n".concat(message.content)]);
+                                    // leave function results out of the text prompt
+                                    return prompt;
                                 default:
                                     return message.content ? prompt.concat(["".concat(assistantLabel, ":\n").concat(message.content)]) : prompt;
                             }
                         }, [])
                             .join('\n\n');
                         return [4 /*yield*/, this._getTokenCount(prompt_1)];
-                    case 2:
-                        nextNumTokensEstimate = _b.sent();
-                        isValidPrompt = nextNumTokensEstimate <= maxNumTokens;
+                    case 20:
+                        nextNumTokensEstimate = _u.sent();
+                        _r = 0, _s = nextMessages
+                            .filter(function (m) { return m.function_call; });
+                        _u.label = 21;
+                    case 21:
+                        if (!(_r < _s.length)) return [3 /*break*/, 24];
+                        m1 = _s[_r];
+                        _t = nextNumTokensEstimate;
+                        return [4 /*yield*/, this._getTokenCount(JSON.stringify(m1.function_call) || '')];
+                    case 22:
+                        nextNumTokensEstimate = _t + _u.sent();
+                        _u.label = 23;
+                    case 23:
+                        _r++;
+                        return [3 /*break*/, 21];
+                    case 24:
+                        isValidPrompt = nextNumTokensEstimate + functionToken <= maxNumTokens;
                         if (prompt_1 && !isValidPrompt) {
-                            return [3 /*break*/, 5];
+                            return [3 /*break*/, 27];
                         }
                         messages = nextMessages;
                         numTokens = nextNumTokensEstimate;
                         if (!isValidPrompt) {
-                            return [3 /*break*/, 5];
+                            return [3 /*break*/, 27];
                         }
                         if (!parentMessageId) {
-                            return [3 /*break*/, 5];
+                            return [3 /*break*/, 27];
                         }
                         return [4 /*yield*/, this._getMessageById(parentMessageId)];
-                    case 3:
-                        parentMessage = _b.sent();
+                    case 25:
+                        parentMessage = _u.sent();
                         if (!parentMessage) {
-                            return [3 /*break*/, 5];
+                            return [3 /*break*/, 27];
                         }
                         parentMessageRole = parentMessage.role || 'user';
                         nextMessages = nextMessages.slice(0, systemMessageOffset).concat(__spreadArray([
@@ -440,16 +532,15 @@ var ChatGPTAPI = /** @class */ (function () {
                                 role: parentMessageRole,
                                 content: parentMessage.text,
                                 name: parentMessage.name,
-                                function_call: parentMessage.functionCall ? parentMessage.functionCall : undefined,
-
+                                function_call: parentMessage.functionCall ? parentMessage.functionCall : undefined
                             }
                         ], nextMessages.slice(systemMessageOffset), true));
                         parentMessageId = parentMessage.parentMessageId;
-                        _b.label = 4;
-                    case 4:
-                        if (true) return [3 /*break*/, 1];
-                        _b.label = 5;
-                    case 5:
+                        _u.label = 26;
+                    case 26:
+                        if (true) return [3 /*break*/, 19];
+                        _u.label = 27;
+                    case 27:
                         maxTokens = Math.max(1, Math.min(this._maxModelTokens - numTokens, this._maxResponseTokens));
                         return [2 /*return*/, { messages: messages, maxTokens: maxTokens, numTokens: numTokens }];
                 }
@@ -457,6 +548,7 @@ var ChatGPTAPI = /** @class */ (function () {
         });
     };
     ChatGPTAPI.prototype._getTokenCount = function (text) {
+        if (!text) return 0;
         return __awaiter(this, void 0, void 0, function () {
             return __generator(this, function (_a) {
                 // TODO: use a better fix in the tokenizer
diff --git a/utils/openai/chatgpt-api.ts b/utils/openai/chatgpt-api.ts
index 2643eb5..c7ffe76 100644
--- a/utils/openai/chatgpt-api.ts
+++ b/utils/openai/chatgpt-api.ts
@@ -7,7 +7,7 @@ import * as tokenizer from './tokenizer'
 import * as types from './types'
 import globalFetch from 'node-fetch'
 import { fetchSSE } from './fetch-sse'
-import {Role} from "./types";
+import {openai, Role} from "./types";
 
 const CHATGPT_MODEL = 'gpt-3.5-turbo-0613'
 
@@ -172,7 +172,8 @@ export class ChatGPTAPI {
     const { messages, maxTokens, numTokens } = await this._buildMessages(
       text,
       role,
-      opts
+      opts,
+      completionParams
     )
 
     const result: types.ChatMessage = {
@@ -378,7 +379,9 @@ export class ChatGPTAPI {
     this._apiOrg = apiOrg
   }
 
-  protected async _buildMessages(text: string, role: Role, opts: types.SendMessageOptions) {
+  protected async _buildMessages(text: string, role: Role, opts: types.SendMessageOptions, completionParams: Partial<
+    Omit<openai.CreateChatCompletionRequest, 'messages' | 'n' | 'stream'>
+  >) {
     const { systemMessage = this._systemMessage } = opts
     let { parentMessageId } = opts
 
@@ -405,8 +408,42 @@ export class ChatGPTAPI {
           }
         ])
       : messages
-    let numTokens = 0
+    let functionToken = 0
+
+    let numTokens = functionToken
+    for (const func of completionParams.functions || []) {
+      functionToken += await this._getTokenCount(func.name)
+      functionToken += await this._getTokenCount(func.description)
+      if (func.parameters.properties) {
+        for (let key of Object.keys(func.parameters.properties)) {
+          functionToken += await this._getTokenCount(key)
+          let property = func.parameters.properties[key]
+          for (let field of Object.keys(property)) {
+            switch (field) {
+              case 'type': {
+                functionToken += 2
+                functionToken += await this._getTokenCount(property.type)
+                break
+              }
+              case 'description': {
+                functionToken += 2
+                functionToken += await this._getTokenCount(property.description)
+                break
+              }
+              case 'enum': {
+                functionToken -= 3
+                for (let enumElement of property.enum) {
+                  functionToken += 3
+                  functionToken += await this._getTokenCount(enumElement)
+                }
+                break
+              }
+            }
+          }
+        }
+      }
+    }
 
     do {
       const prompt = nextMessages
         .reduce((prompt, message) => {
@@ -416,20 +453,26 @@ export class ChatGPTAPI {
           case 'user':
             return prompt.concat([`${userLabel}:\n${message.content}`])
           case 'function':
-            return prompt.concat([`Function:\n${message.content}`])
+            // leave function results out of the text prompt
+            return prompt
           default:
             return message.content
               ? prompt.concat([`${assistantLabel}:\n${message.content}`])
               : prompt
         }
       }, [] as string[])
       .join('\n\n')
 
-      const nextNumTokensEstimate = await this._getTokenCount(prompt)
-      const isValidPrompt = nextNumTokensEstimate <= maxNumTokens
+      let nextNumTokensEstimate = await this._getTokenCount(prompt)
+
+      for (const m1 of nextMessages
+        .filter(m => m.function_call)) {
+        nextNumTokensEstimate += await this._getTokenCount(JSON.stringify(m1.function_call) || '')
+      }
+
+      const isValidPrompt = nextNumTokensEstimate + functionToken <= maxNumTokens
 
       if (prompt && !isValidPrompt) {
         break
       }
-
       messages = nextMessages
       numTokens = nextNumTokensEstimate
 
@@ -472,6 +515,9 @@ export class ChatGPTAPI {
   }
 
   protected async _getTokenCount(text: string) {
+    if (!text) {
+      return 0
+    }
     // TODO: use a better fix in the tokenizer
     text = text.replace(/<\|endoftext\|>/g, '')
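
Note on the token accounting in this patch: OpenAI does not document how function definitions are serialized into the prompt, so the per-field offsets (+2 for a type or description field, -3 plus +3 per value for an enum) are an empirical approximation. The sketch below restates the same heuristic without the class and generator plumbing, as a minimal standalone estimate. The FunctionDef shape and the gpt-3-encoder tokenizer are assumptions standing in for the project's own openai types and async _getTokenCount; this is not the exact code in the patch.

// Standalone sketch of the function-definition overhead estimate.
// Assumes a synchronous tokenizer; the class method in the patch is async.
import { encode } from 'gpt-3-encoder'

// Hypothetical minimal shape of a function definition (stand-in for the
// project's openai function type).
interface FunctionDef {
  name: string
  description?: string
  parameters?: {
    properties?: Record<
      string,
      { type?: string; description?: string; enum?: string[] }
    >
  }
}

function getTokenCount(text?: string): number {
  // Mirrors the patched _getTokenCount guard: empty input counts as 0 tokens.
  return text ? encode(text).length : 0
}

function estimateFunctionTokens(functions: FunctionDef[]): number {
  let tokens = 0
  for (const func of functions) {
    tokens += getTokenCount(func.name)
    tokens += getTokenCount(func.description)
    for (const [key, property] of Object.entries(
      func.parameters?.properties ?? {}
    )) {
      tokens += getTokenCount(key)
      // Fixed offsets follow the heuristic in _buildMessages: +2 per type or
      // description field, -3 for an enum plus +3 per enum value.
      if (property.type) tokens += 2 + getTokenCount(property.type)
      if (property.description) tokens += 2 + getTokenCount(property.description)
      if (property.enum) {
        tokens -= 3
        for (const enumElement of property.enum) {
          tokens += 3 + getTokenCount(enumElement)
        }
      }
    }
  }
  return tokens
}

// Example: one hypothetical function with a single enum parameter.
const overhead = estimateFunctionTokens([
  {
    name: 'get_weather',
    description: 'Get the current weather for a city',
    parameters: {
      properties: {
        unit: {
          type: 'string',
          description: 'Temperature unit',
          enum: ['celsius', 'fahrenheit']
        }
      }
    }
  }
])
console.log(`estimated function-definition overhead: ${overhead} tokens`)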