diff --git a/.gitignore b/.gitignore
index 80a3da6..98703f4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,6 @@ data/
 utils/processors
 utils/tools
 utils/triggers
+memory.md
+resources/simple
+memory.db
diff --git a/apps/bym.js b/apps/bym.js
index 07c3f37..b68854c 100644
--- a/apps/bym.js
+++ b/apps/bym.js
@@ -3,7 +3,9 @@ import { Chaite } from 'chaite'
 import { intoUserMessage, toYunzai } from '../utils/message.js'
 import common from '../../../lib/common/common.js'
 import { getGroupContextPrompt } from '../utils/group.js'
-import {formatTimeToBeiJing} from '../utils/common.js'
+import { formatTimeToBeiJing } from '../utils/common.js'
+import { extractTextFromUserMessage, processUserMemory } from '../models/memory/userMemoryManager.js'
+import { buildMemoryPrompt } from '../models/memory/prompt.js'
 
 export class bym extends plugin {
   constructor () {
@@ -83,6 +85,7 @@ export class bym extends plugin {
       toggleMode: ChatGPTConfig.basic.toggleMode,
       togglePrefix: ChatGPTConfig.basic.togglePrefix
     })
+    const userText = extractTextFromUserMessage(userMessage) || e.msg || ''
     // 伪人不记录历史
     // sendMessageOption.disableHistoryRead = true
     // sendMessageOption.disableHistorySave = true
@@ -98,9 +101,29 @@ export class bym extends plugin {
         this.reply(forwardElement)
       }
     }
+    const systemSegments = []
+    if (sendMessageOption.systemOverride) {
+      systemSegments.push(sendMessageOption.systemOverride)
+    }
+    if (userText) {
+      const memoryPrompt = await buildMemoryPrompt({
+        userId: e.sender.user_id + '',
+        groupId: e.isGroup ? e.group_id + '' : null,
+        queryText: userText
+      })
+      if (memoryPrompt) {
+        systemSegments.push(memoryPrompt)
+        logger.debug(`[Memory] bym memory prompt: ${memoryPrompt}`)
+      }
+    }
     if (ChatGPTConfig.llm.enableGroupContext && e.isGroup) {
       const contextPrompt = await getGroupContextPrompt(e, ChatGPTConfig.llm.groupContextLength)
-      sendMessageOption.systemOverride = sendMessageOption.systemOverride ? sendMessageOption.systemOverride + '\n' + contextPrompt : contextPrompt
+      if (contextPrompt) {
+        systemSegments.push(contextPrompt)
+      }
+    }
+    if (systemSegments.length > 0) {
+      sendMessageOption.systemOverride = systemSegments.join('\n\n')
     }
     // 发送
     const response = await Chaite.getInstance().sendMessage(userMessage, e, {
@@ -120,5 +143,13 @@ export class bym extends plugin {
         await e.reply(forwardElement, false, { recallMsg: recall ? 10 : 0 })
       }
     }
+    await processUserMemory({
+      event: e,
+      userMessage,
+      userText,
+      conversationId: sendMessageOption.conversationId,
+      assistantContents: response.contents,
+      assistantMessageId: response.id
+    })
   }
 }
diff --git a/apps/chat.js b/apps/chat.js
index 48e353d..b6eca32 100644
--- a/apps/chat.js
+++ b/apps/chat.js
@@ -2,7 +2,9 @@ import Config from '../config/config.js'
 import { Chaite, SendMessageOption } from 'chaite'
 import { getPreset, intoUserMessage, toYunzai } from '../utils/message.js'
 import { YunzaiUserState } from '../models/chaite/storage/lowdb/user_state_storage.js'
-import { getGroupContextPrompt, getGroupHistory } from '../utils/group.js'
+import { getGroupContextPrompt } from '../utils/group.js'
+import { buildMemoryPrompt } from '../models/memory/prompt.js'
+import { extractTextFromUserMessage, processUserMemory } from '../models/memory/userMemoryManager.js'
 import * as crypto from 'node:crypto'
 
 export class Chat extends plugin {
@@ -11,7 +13,8 @@ export class Chat extends plugin {
       name: 'ChatGPT-Plugin对话',
       dsc: 'ChatGPT-Plugin对话',
       event: 'message',
-      priority: 500,
+      // 应🥑要求降低优先级
+      priority: 555500,
       rule: [
         {
           reg: '^[^#][sS]*',
@@ -63,12 +66,34 @@
       toggleMode: Config.basic.toggleMode,
       togglePrefix: Config.basic.togglePrefix
     })
+    const userText = extractTextFromUserMessage(userMessage) || e.msg || ''
     sendMessageOptions.conversationId = state?.current?.conversationId
     sendMessageOptions.parentMessageId = state?.current?.messageId || state?.conversations.find(c => c.id === sendMessageOptions.conversationId)?.lastMessageId
+    const systemSegments = []
+    const baseSystem = sendMessageOptions.systemOverride || preset.sendMessageOption?.systemOverride || ''
+    if (baseSystem) {
+      systemSegments.push(baseSystem)
+    }
+    if (userText) {
+      const memoryPrompt = await buildMemoryPrompt({
+        userId: e.sender.user_id + '',
+        groupId: e.isGroup ? e.group_id + '' : null,
+        queryText: userText
+      })
+      if (memoryPrompt) {
+        systemSegments.push(memoryPrompt)
+        logger.debug(`[Memory] memory prompt: ${memoryPrompt}`)
+      }
+    }
     const enableGroupContext = (preset.groupContext === 'use_system' || !preset.groupContext) ? Config.llm.enableGroupContext : (preset.groupContext === 'enabled')
     if (enableGroupContext && e.isGroup) {
       const contextPrompt = await getGroupContextPrompt(e, Config.llm.groupContextLength)
-      sendMessageOptions.systemOverride = sendMessageOptions.systemOverride ? sendMessageOptions.systemOverride + '\n' + contextPrompt : (preset.sendMessageOption.systemOverride + contextPrompt)
+      if (contextPrompt) {
+        systemSegments.push(contextPrompt)
+      }
+    }
+    if (systemSegments.length > 0) {
+      sendMessageOptions.systemOverride = systemSegments.join('\n\n')
     }
     const response = await Chaite.getInstance().sendMessage(userMessage, e, {
       ...sendMessageOptions,
@@ -95,5 +120,13 @@
     for (let forwardElement of forward) {
       this.reply(forwardElement)
     }
+    await processUserMemory({
+      event: e,
+      userMessage,
+      userText,
+      conversationId: sendMessageOptions.conversationId,
+      assistantContents: response.contents,
+      assistantMessageId: response.id
+    })
   }
 }
diff --git a/apps/memory.js b/apps/memory.js
new file mode 100644
index 0000000..5fbc3cc
--- /dev/null
+++ b/apps/memory.js
@@ -0,0 +1,224 @@
+import Config from '../config/config.js'
+import { GroupMessageCollector } from '../models/memory/collector.js'
+import { memoryService } from '../models/memory/service.js'
+import common from '../../../lib/common/common.js'
+
+const collector = new GroupMessageCollector()
+
+function isGroupManager (e) {
+  if (e.isMaster) {
+    return true
+  }
+  if (!e.member) {
+    return false
+  }
+  if (typeof e.member.is_admin !== 'undefined') {
+    return e.member.is_admin || e.member.is_owner
+  }
+  if (typeof e.member.role !== 'undefined') {
+    return ['admin', 'owner'].includes(e.member.role)
+  }
+  return false
+}
+
+export class MemoryManager extends plugin {
+  constructor () {
+    const cmdPrefix = Config.basic.commandPrefix || '#chatgpt'
+    super({
+      name: 'ChatGPT-Plugin记忆系统',
+      dsc: '处理记忆系统相关的采集与管理',
+      event: 'message',
+      priority: 550,
+      rule: [
+        // {
+        //   reg: '[\\s\\S]+',
+        //   fnc: 'collect',
+        //   log: false
+        // },
+        {
+          reg: '^#?(我的)?记忆$',
+          fnc: 'showUserMemory'
+        },
+        {
+          reg: '^#?他的记忆$',
+          fnc: 'showTargetUserMemory'
+        },
+        {
+          reg: '^#?(删除|清除)(我的)?记忆\\s*(\\d+)$',
+          fnc: 'deleteUserMemory'
+        },
+        {
+          reg: '^#?(本群|群)记忆$',
+          fnc: 'showGroupMemory'
+        },
+        {
+          reg: '^#?(删除|移除)群记忆\\s*(\\d+)$',
+          fnc: 'deleteGroupMemory'
+        },
+        {
+          reg: `^${cmdPrefix}记忆列表$`,
+          fnc: 'adminMemoryOverview',
+          permission: 'master'
+        }
+      ]
+    })
+
+    // 兼容miao和trss,气死了
+    let task = {
+      name: 'ChatGPT-群记忆轮询',
+      cron: '*/1 * * * *',
+      fnc: this.pollHistoryTask.bind(this),
+      log: false
+    }
+    this.task = [task]
+
+  }
+
+  async collect (e) {
+    collector.push(e)
+    return false
+  }
+
+  async showUserMemory (e) {
+    if (!memoryService.isUserMemoryEnabled(e.sender.user_id)) {
+      await e.reply('私人记忆未开启或您未被授权。')
+      return false
+    }
+    const memories = memoryService.listUserMemories(e.sender.user_id, e.isGroup ? e.group_id : null)
+
+    if (!memories.length) {
+      await e.reply('🧠 您的记忆:\n暂无记录~')
+      return true
+    }
+
+    const msgs = memories.map(item =>
+      `${item.id}. ${item.value}(更新时间:${item.updated_at})`
+    )
+
+    const forwardMsg = await common.makeForwardMsg(e, ['🧠 您的记忆:', ...msgs], '私人记忆列表')
+    await e.reply(forwardMsg)
+    return true
+  }
+
+  async showTargetUserMemory (e) {
+    if (!e.isGroup) {
+      await e.reply('该指令仅可在群聊中使用。')
+      return false
+    }
+
+    const at = e.at || (e.message?.find(m => m.type === 'at')?.qq)
+    if (!at) {
+      await e.reply('请@要查询的用户。')
+      return false
+    }
+
+    if (!memoryService.isUserMemoryEnabled(at)) {
+      await e.reply('该用户未开启私人记忆或未被授权。')
+      return false
+    }
+
+    const memories = memoryService.listUserMemories(at, e.group_id)
+
+    if (!memories.length) {
+      await e.reply('🧠 TA的记忆:\n暂无记录~')
+      return true
+    }
+
+    const msgs = memories.map(item =>
+      `${item.id}. 
${item.value}(更新时间:${item.updated_at})` + ) + + const forwardMsg = await common.makeForwardMsg(e, ['🧠 TA的记忆:', ...msgs], 'TA的记忆列表') + await e.reply(forwardMsg) + return true + } + + async deleteUserMemory (e) { + const match = e.msg.match(/(\d+)$/) + if (!match) { + return false + } + const memoryId = Number(match[1]) + if (!memoryId) { + return false + } + if (!memoryService.isUserMemoryEnabled(e.sender.user_id)) { + await e.reply('私人记忆未开启或您未被授权。') + return false + } + const success = memoryService.deleteUserMemory(memoryId, e.sender.user_id) + await e.reply(success ? '已删除指定记忆。' : '未找到对应的记忆条目。') + return success + } + + async showGroupMemory (e) { + if (!e.isGroup) { + await e.reply('该指令仅可在群聊中使用。') + return false + } + if (!memoryService.isGroupMemoryEnabled(e.group_id)) { + await e.reply('本群尚未开启记忆功能。') + return false + } + await collector.flush(e.group_id) + const facts = memoryService.listGroupFacts(e.group_id) + + if (!facts.length) { + await e.reply('📚 本群记忆:\n暂无群记忆。') + return true + } + + const msgs = facts.map(item => { + const topic = item.topic ? `【${item.topic}】` : '' + return `${item.id}. ${topic}${item.fact}` + }) + + const forwardMsg = await common.makeForwardMsg(e, ['📚 本群记忆:', ...msgs], '群记忆列表') + await e.reply(forwardMsg) + return true + } + + async deleteGroupMemory (e) { + if (!e.isGroup) { + await e.reply('该指令仅可在群聊中使用。') + return false + } + if (!memoryService.isGroupMemoryEnabled(e.group_id)) { + await e.reply('本群尚未开启记忆功能。') + return false + } + if (!isGroupManager(e)) { + await e.reply('仅限主人或群管理员管理群记忆。') + return false + } + await collector.flush(e.group_id) + const match = e.msg.match(/(\d+)$/) + if (!match) { + return false + } + const factId = Number(match[1]) + if (!factId) { + return false + } + const success = memoryService.deleteGroupFact(e.group_id, factId) + await e.reply(success ? '已删除群记忆。' : '未找到对应的群记忆。') + return success + } + + async adminMemoryOverview (e) { + const enabledGroups = (Config.memory?.group?.enabledGroups || []).map(String) + const groupLines = enabledGroups.length ? enabledGroups.join(', ') : '暂无' + const userStatus = Config.memory?.user?.enable ? '已启用' : '未启用' + await e.reply(`记忆系统概览:\n- 群记忆开关:${Config.memory?.group?.enable ? 
'已启用' : '未启用'}\n- 已启用群:${groupLines}\n- 私人记忆:${userStatus}`) + return true + } + + async pollHistoryTask () { + try { + await collector.tickHistoryPolling() + } catch (err) { + logger.error('[Memory] scheduled history poll failed:', err) + } + return false + } +} diff --git a/config/config.js b/config/config.js index acf8f6d..e93bec1 100644 --- a/config/config.js +++ b/config/config.js @@ -184,6 +184,123 @@ class ChatGPTConfig { storage: 'sqlite' } + /** + * 记忆系统配置 + * @type {{ + * database: string, + * vectorDimensions: number, + * group: { + * enable: boolean, + * enabledGroups: string[], + * extractionModel: string, + * extractionPresetId: string, + * minMessageCount: number, + * maxMessageWindow: number, + * retrievalMode: 'vector' | 'keyword' | 'hybrid', + * hybridPrefer: 'vector-first' | 'keyword-first', + * historyPollInterval: number, + * historyBatchSize: number, + * promptHeader: string, + * promptItemTemplate: string, + * promptFooter: string, + * extractionSystemPrompt: string, + * extractionUserPrompt: string, + * vectorMaxDistance: number, + * textMaxBm25Score: number, + * maxFactsPerInjection: number, + * minImportanceForInjection: number + * }, + * user: { + * enable: boolean, + * whitelist: string[], + * blacklist: string[], + * extractionModel: string, + * extractionPresetId: string, + * maxItemsPerInjection: number, + * maxRelevantItemsPerQuery: number, + * minImportanceForInjection: number, + * promptHeader: string, + * promptItemTemplate: string, + * promptFooter: string, + * extractionSystemPrompt: string, + * extractionUserPrompt: string + * }, + * extensions: { + * simple: { + * enable: boolean, + * libraryPath: string, + * dictPath: string, + * useJieba: boolean + * } + * } + * }} + */ + memory = { + database: 'data/memory.db', + vectorDimensions: 1536, + group: { + enable: false, + enabledGroups: [], + extractionModel: '', + extractionPresetId: '', + minMessageCount: 80, + maxMessageWindow: 300, + retrievalMode: 'hybrid', + hybridPrefer: 'vector-first', + historyPollInterval: 300, + historyBatchSize: 120, + promptHeader: '# 以下是一些该群聊中可能相关的事实,你可以参考,但不要主动透露这些事实。', + promptItemTemplate: '- ${fact}${topicSuffix}${timeSuffix}', + promptFooter: '', + extractionSystemPrompt: `You are a knowledge extraction assistant that specialises in summarising long-term facts from group chat transcripts. +Read the provided conversation and identify statements that should be stored as long-term knowledge for the group. +Return a JSON array. Each element must contain: +{ + "fact": 事实内容,必须完整包含事件的各个要素而不能是简单的短语(比如谁参与了事件、做了什么事情、背景时间是什么)(同一件事情尽可能整合为同一条而非拆分,以便利于检索), + "topic": 主题关键词,字符串,如 "活动"、"成员信息", + "importance": 一个介于0和1之间的小数,数值越大表示越重要, + "source_message_ids": 原始消息ID数组, + "source_messages": 对应原始消息的简要摘录或合并文本, + "involved_users": 出现或相关的用户ID数组 +} +Only include meaningful, verifiable group-specific information that is useful for future conversations. Do not record incomplete information. Do not include general knowledge or unrelated facts. 
Do not wrap the JSON array in code fences.`, + extractionUserPrompt: `以下是群聊中的一些消息,请根据系统说明提取值得长期记忆的事实,以JSON数组形式返回,不要输出额外说明。 + +\${messages}`, + vectorMaxDistance: 0, + textMaxBm25Score: 0, + maxFactsPerInjection: 5, + minImportanceForInjection: 0.3 + }, + user: { + enable: false, + whitelist: [], + blacklist: [], + extractionModel: '', + extractionPresetId: '', + maxItemsPerInjection: 5, + maxRelevantItemsPerQuery: 3, + minImportanceForInjection: 0, + promptHeader: '# 用户画像', + promptItemTemplate: '- ${value}${timeSuffix}', + promptFooter: '', + extractionSystemPrompt: `You are an assistant that extracts long-term personal preferences or persona details about a user. +Given a conversation snippet between the user and the bot, identify durable information such as preferences, nicknames, roles, speaking style, habits, or other facts that remain valid over time. +Return a JSON array of **strings**, and nothing else, without any other characters including \`\`\` or \`\`\`json. Each string must be a short sentence (in the same language as the conversation) describing one piece of long-term memory. Do not include keys, JSON objects, or additional metadata. Ignore temporary topics or uncertain information.`, + extractionUserPrompt: `下面是用户与机器人的对话,请根据系统提示提取可长期记忆的个人信息。 + +\${messages}` + }, + extensions: { + simple: { + enable: false, + libraryPath: '', + dictPath: '', + useJieba: false + } + } + } + constructor () { this.version = '3.0.0' this.watcher = null @@ -336,21 +453,13 @@ class ChatGPTConfig { ? JSON.parse(content) : yaml.load(content) - // 只更新存在的配置项 + // 处理加载的配置并和默认值合并 if (loadedConfig) { - Object.keys(loadedConfig).forEach(key => { - if (key === 'version' || key === 'basic' || key === 'bym' || key === 'llm' || - key === 'management' || key === 'chaite') { - if (typeof loadedConfig[key] === 'object' && loadedConfig[key] !== null) { - // 对象的合并 - if (!this[key]) this[key] = {} - Object.assign(this[key], loadedConfig[key]) - } else { - // 基本类型直接赋值 - this[key] = loadedConfig[key] - } - } - }) + const mergeResult = this._mergeConfig(loadedConfig) + if (mergeResult.changed) { + logger?.debug?.('[Config] merged new defaults into persisted config; scheduling save') + this._triggerSave('code') + } } logger.debug('Config loaded successfully') @@ -359,6 +468,68 @@ class ChatGPTConfig { } } + _mergeConfig (loadedConfig) { + let changed = false + + const mergeInto = (target, source) => { + if (!source || typeof source !== 'object') { + return target + } + if (!target || typeof target !== 'object') { + target = Array.isArray(source) ? [] : {} + } + const result = Array.isArray(source) ? [] : { ...target } + + if (Array.isArray(source)) { + return source.slice() + } + + const targetKeys = target && typeof target === 'object' + ? 
Object.keys(target) + : [] + for (const key of targetKeys) { + if (!Object.prototype.hasOwnProperty.call(source, key)) { + changed = true + } + } + + for (const key of Object.keys(source)) { + const sourceValue = source[key] + const targetValue = target[key] + if (sourceValue && typeof sourceValue === 'object' && !Array.isArray(sourceValue)) { + result[key] = mergeInto(targetValue, sourceValue) + } else { + if (targetValue === undefined || targetValue !== sourceValue) { + changed = true + } + result[key] = sourceValue + } + } + return result + } + + const sections = ['version', 'basic', 'bym', 'llm', 'management', 'chaite', 'memory'] + for (const key of sections) { + const loadedValue = loadedConfig[key] + if (loadedValue === undefined) { + continue + } + if (typeof loadedValue === 'object' && loadedValue !== null) { + const merged = mergeInto(this[key], loadedValue) + if (merged !== this[key]) { + this[key] = merged + } + } else { + if (this[key] !== loadedValue) { + changed = true + } + this[key] = loadedValue + } + } + + return { changed } + } + // 合并触发保存,防抖处理 _triggerSave (origin) { // 清除之前的定时器 @@ -366,20 +537,18 @@ class ChatGPTConfig { clearTimeout(this._saveTimer) } - // 记录保存来源 - this._saveOrigin = origin || 'code' - - // 设置定时器延迟保存 + const originLabel = origin || 'code' + this._saveOrigin = originLabel this._saveTimer = setTimeout(() => { - this.saveToFile() - // 保存完成后延迟一下再清除来源标记 - setTimeout(() => { - this._saveOrigin = null - }, 100) + this.saveToFile(originLabel) + this._saveOrigin = null }, 200) } - saveToFile () { + saveToFile (origin = 'code') { + if (origin !== 'code') { + this._saveOrigin = 'external' + } logger.debug('Saving config to file...') try { const config = { @@ -388,7 +557,8 @@ class ChatGPTConfig { bym: this.bym, llm: this.llm, management: this.management, - chaite: this.chaite + chaite: this.chaite, + memory: this.memory } const content = this.configPath.endsWith('.json') @@ -408,7 +578,8 @@ class ChatGPTConfig { bym: this.bym, llm: this.llm, management: this.management, - chaite: this.chaite + chaite: this.chaite, + memory: this.memory } } } diff --git a/models/chaite/cloud.js b/models/chaite/cloud.js index 5aaf1e3..ac07ca1 100644 --- a/models/chaite/cloud.js +++ b/models/chaite/cloud.js @@ -3,8 +3,6 @@ import { ChannelsManager, ChatPresetManager, DefaultChannelLoadBalancer, - GeminiClient, - OpenAIClient, ProcessorsManager, RAGManager, ToolManager, @@ -34,6 +32,8 @@ import { checkMigrate } from './storage/sqlite/migrate.js' import { SQLiteHistoryManager } from './storage/sqlite/history_manager.js' import SQLiteTriggerStorage from './storage/sqlite/trigger_storage.js' import LowDBTriggerStorage from './storage/lowdb/trigger_storage,.js' +import { createChaiteVectorizer } from './vectorizer.js' +import { MemoryRouter, authenticateMemoryRequest } from '../memory/router.js' /** * 认证,以便共享上传 @@ -49,77 +49,13 @@ export async function authCloud (apiKey = ChatGPTConfig.chaite.cloudApiKey) { } } -/** - * - * @param {import('chaite').Channel} channel - * @returns {Promise} - */ -async function getIClientByChannel (channel) { - await channel.ready() - switch (channel.adapterType) { - case 'openai': { - return new OpenAIClient(channel.options) - } - case 'gemini': { - return new GeminiClient(channel.options) - } - case 'claude': { - throw new Error('claude doesn\'t support embedding') - } - } -} - /** * 初始化RAG管理器 * @param {string} model * @param {number} dimensions */ export async function initRagManager (model, dimensions) { - const vectorizer = new class { - async 
textToVector (text) { - const channels = await Chaite.getInstance().getChannelsManager().getChannelByModel(model) - if (channels.length === 0) { - throw new Error('No channel found for model: ' + model) - } - const channel = channels[0] - const client = await getIClientByChannel(channel) - const result = await client.getEmbedding(text, { - model, - dimensions - }) - return result.embeddings[0] - } - - /** - * - * @param {string[]} texts - * @returns {Promise[]>} - */ - async batchTextToVector (texts) { - const availableChannels = (await Chaite.getInstance().getChannelsManager().getAllChannels()).filter(c => c.models.includes(model)) - if (availableChannels.length === 0) { - throw new Error('No channel found for model: ' + model) - } - const channels = await Chaite.getInstance().getChannelsManager().getChannelsByModel(model, texts.length) - /** - * @type {import('chaite').IClient[]} - */ - const clients = await Promise.all(channels.map(({ channel }) => getIClientByChannel(channel))) - const results = [] - let startIndex = 0 - for (let i = 0; i < channels.length; i++) { - const { quantity } = channels[i] - const textsSlice = texts.slice(startIndex, startIndex + quantity) - const embeddings = await clients[i].getEmbedding(textsSlice, { - model, - dimensions - }) - results.push(...embeddings.embeddings) - startIndex += quantity - } - return results - } - }() + const vectorizer = createChaiteVectorizer(model, dimensions) const vectorDBPath = path.resolve('./plugins/chatgpt-plugin', ChatGPTConfig.chaite.dataDir, 'vector_index') if (!fs.existsSync(vectorDBPath)) { fs.mkdirSync(vectorDBPath, { recursive: true }) @@ -246,7 +182,9 @@ export async function initChaite () { chaite.getGlobalConfig().setPort(ChatGPTConfig.chaite.port) chaite.getGlobalConfig().setDebug(ChatGPTConfig.basic.debug) logger.info('Chaite.RAGManager 初始化完成') - chaite.runApiServer() + chaite.runApiServer(app => { + app.use('/api/memory', authenticateMemoryRequest, MemoryRouter) + }) } function deepMerge (target, source) { diff --git a/models/chaite/vectorizer.js b/models/chaite/vectorizer.js new file mode 100644 index 0000000..60b8f60 --- /dev/null +++ b/models/chaite/vectorizer.js @@ -0,0 +1,89 @@ +import { Chaite, ChaiteContext, GeminiClient, OpenAIClient } from 'chaite' + +async function getIClientByChannel (channel) { + await channel.ready() + const baseLogger = global.logger || console + if (channel.options?.setLogger) { + channel.options.setLogger(baseLogger) + } + const context = new ChaiteContext(baseLogger) + context.setChaite(Chaite.getInstance()) + switch (channel.adapterType) { + case 'openai': + return new OpenAIClient(channel.options, context) + case 'gemini': + return new GeminiClient(channel.options, context) + case 'claude': + throw new Error('claude does not support embedding') + default: + throw new Error(`Unsupported adapter ${channel.adapterType}`) + } +} + +async function resolveChannelForModel (model) { + const manager = Chaite.getInstance().getChannelsManager() + const channels = await manager.getChannelByModel(model) + if (channels.length === 0) { + throw new Error('No channel found for model: ' + model) + } + return channels[0] +} + +export async function getClientForModel (model) { + const channel = await resolveChannelForModel(model) + const client = await getIClientByChannel(channel) + return { client, channel } +} + +/** + * 创建一个基于Chaite渠道的向量器 + * @param {string} model + * @param {number} dimensions + * @returns {{ textToVector: (text: string) => Promise, batchTextToVector: (texts: string[]) => 
Promise }} + */ +export function createChaiteVectorizer (model, dimensions) { + return { + async textToVector (text) { + const { client } = await getClientForModel(model) + const options = { model } + if (Number.isFinite(dimensions) && dimensions > 0) { + options.dimensions = dimensions + } + const result = await client.getEmbedding(text, options) + return result.embeddings[0] + }, + async batchTextToVector (texts) { + const manager = Chaite.getInstance().getChannelsManager() + const channels = await manager.getChannelsByModel(model, texts.length) + if (channels.length === 0) { + throw new Error('No channel found for model: ' + model) + } + const clients = await Promise.all(channels.map(({ channel }) => getIClientByChannel(channel))) + const results = [] + let startIndex = 0 + for (let i = 0; i < channels.length; i++) { + const { quantity } = channels[i] + const slice = texts.slice(startIndex, startIndex + quantity) + const options = { model } + if (Number.isFinite(dimensions) && dimensions > 0) { + options.dimensions = dimensions + } + const embeddings = await clients[i].getEmbedding(slice, options) + results.push(...embeddings.embeddings) + startIndex += quantity + } + return results + } + } +} + +export async function embedTexts (texts, model, dimensions) { + if (!texts || texts.length === 0) { + return [] + } + const vectorizer = createChaiteVectorizer(model, dimensions) + if (texts.length === 1) { + return [await vectorizer.textToVector(texts[0])] + } + return await vectorizer.batchTextToVector(texts) +} diff --git a/models/memory/collector.js b/models/memory/collector.js new file mode 100644 index 0000000..d81acc1 --- /dev/null +++ b/models/memory/collector.js @@ -0,0 +1,633 @@ +import * as crypto from 'node:crypto' +import ChatGPTConfig from '../../config/config.js' +import { extractGroupFacts } from './extractor.js' +import { memoryService } from './service.js' +import { getBotFramework } from '../../utils/bot.js' +import { ICQQGroupContextCollector, TRSSGroupContextCollector } from '../../utils/group.js' +import { groupHistoryCursorStore } from './groupHistoryCursorStore.js' + +const DEFAULT_MAX_WINDOW = 300 // seconds +const DEFAULT_HISTORY_BATCH = 120 +const MAX_RECENT_IDS = 200 + +function nowSeconds () { + return Math.floor(Date.now() / 1000) +} + +function normaliseGroupId (groupId) { + return groupId === null || groupId === undefined ? 
null : String(groupId) +} + +function shouldIgnoreMessage (e) { + if (!e || !e.message) { + return true + } + if (e.sender?.user_id && e.sender.user_id === e.bot?.uin) { + return true + } + if (e.isPrivate) { + return true + } + const text = e.msg?.trim() + if (!text) { + return true + } + if (text.startsWith('#')) { + return true + } + const prefix = ChatGPTConfig.basic?.togglePrefix + if (prefix && text.startsWith(prefix)) { + return true + } + return false +} + +function extractPlainText (e) { + if (e.msg) { + return e.msg.trim() + } + if (Array.isArray(e.message)) { + return e.message + .filter(item => item.type === 'text') + .map(item => item.text || '') + .join('') + .trim() + } + return '' +} + +function extractHistoryText (chat) { + if (!chat) { + return '' + } + if (typeof chat.raw_message === 'string') { + const trimmed = chat.raw_message.trim() + if (trimmed) { + return trimmed + } + } + if (typeof chat.msg === 'string') { + const trimmed = chat.msg.trim() + if (trimmed) { + return trimmed + } + } + if (Array.isArray(chat.message)) { + const merged = chat.message + .filter(item => item && item.type === 'text') + .map(item => item.text || '') + .join('') + .trim() + if (merged) { + return merged + } + } + if (typeof chat.text === 'string') { + const trimmed = chat.text.trim() + if (trimmed) { + return trimmed + } + } + return '' +} + +function toPositiveInt (value, fallback = 0) { + const num = Number(value) + if (Number.isFinite(num) && num > 0) { + return Math.floor(num) + } + return fallback +} + +function normalizeTimestamp (value) { + if (value === null || value === undefined) { + return 0 + } + const num = Number(value) + if (!Number.isFinite(num) || num <= 0) { + return 0 + } + if (num > 1e12) { + return Math.floor(num) + } + return Math.floor(num * 1000) +} + +function resolveMessageIdCandidate (source) { + if (!source) { + return '' + } + const candidates = [ + source.message_id, + source.messageId, + source.msg_id, + source.seq, + source.messageSeq, + source.id + ] + for (const candidate of candidates) { + if (candidate || candidate === 0) { + const str = String(candidate).trim() + if (str) { + return str + } + } + } + return '' +} + +function resolveUserId (source) { + if (!source) { + return '' + } + const candidates = [ + source.user_id, + source.uid, + source.userId, + source.uin, + source.id, + source.qq + ] + for (const candidate of candidates) { + if (candidate || candidate === 0) { + const str = String(candidate).trim() + if (str) { + return str + } + } + } + return '' +} + +function resolveNickname (source) { + if (!source) { + return '' + } + const candidates = [ + source.card, + source.nickname, + source.name, + source.remark + ] + for (const candidate of candidates) { + if (typeof candidate === 'string') { + const trimmed = candidate.trim() + if (trimmed) { + return trimmed + } + } + } + return '' +} + +export class GroupMessageCollector { + constructor () { + this.buffers = new Map() + this.processing = new Set() + this.groupStates = new Map() + this.lastPollAt = 0 + this.polling = false + this.selfIds = null + } + + get groupConfig () { + return ChatGPTConfig.memory?.group || {} + } + + get historyBatchSize () { + const config = this.groupConfig + const configured = toPositiveInt(config.historyBatchSize, 0) + if (configured > 0) { + return configured + } + const minCount = toPositiveInt(config.minMessageCount, 80) + return Math.max(minCount, DEFAULT_HISTORY_BATCH) + } + + get historyPollIntervalMs () { + const config = this.groupConfig + const configured = 
Number(config.historyPollInterval) + if (Number.isFinite(configured) && configured > 0) { + return Math.floor(configured) * 1000 + } + if (configured === 0) { + return 0 + } + const fallbackSeconds = Math.max(toPositiveInt(config.maxMessageWindow, DEFAULT_MAX_WINDOW), DEFAULT_MAX_WINDOW) + return fallbackSeconds * 1000 + } + + async tickHistoryPolling (force = false) { + const intervalMs = this.historyPollIntervalMs + if (intervalMs <= 0) { + return + } + if (!force) { + const now = Date.now() + if (this.lastPollAt && (now - this.lastPollAt) < intervalMs) { + return + } + } else { + this.refreshSelfIds() + } + await this.runHistoryPoll() + } + + async runHistoryPoll () { + if (this.polling) { + return + } + this.polling = true + try { + logger.info('[Memory] start group history poll') + await this.pollGroupHistories() + } catch (err) { + logger.error('[Memory] group history poll execution failed:', err) + } finally { + this.lastPollAt = Date.now() + this.polling = false + } + } + + async pollGroupHistories () { + const config = this.groupConfig + if (!config.enable) { + return + } + const groupIds = (config.enabledGroups || []) + .map(normaliseGroupId) + .filter(Boolean) + if (groupIds.length === 0) { + return + } + this.refreshSelfIds() + const framework = getBotFramework() + for (const groupId of groupIds) { + if (!memoryService.isGroupMemoryEnabled(groupId)) { + continue + } + const collector = framework === 'trss' + ? new TRSSGroupContextCollector() + : new ICQQGroupContextCollector() + try { + const added = await this.collectHistoryForGroup(collector, groupId) + if (added > 0) { + logger.debug(`[Memory] history poll buffered ${added} messages, group=${groupId}`) + } + } catch (err) { + logger.warn(`[Memory] failed to poll history for group=${groupId}:`, err) + } + } + } + + async collectHistoryForGroup (collector, groupId) { + const limit = this.historyBatchSize + if (!limit) { + return 0 + } + let chats = [] + try { + chats = await collector.collect(undefined, groupId, 0, limit) + } catch (err) { + logger.warn(`[Memory] failed to collect history for group=${groupId}:`, err) + return 0 + } + if (!Array.isArray(chats) || chats.length === 0) { + return 0 + } + const messages = [] + for (const chat of chats) { + const payload = this.transformHistoryMessage(groupId, chat) + if (payload) { + messages.push(payload) + } + } + if (!messages.length) { + return 0 + } + messages.sort((a, b) => normalizeTimestamp(a.timestamp) - normalizeTimestamp(b.timestamp)) + let queued = 0 + for (const payload of messages) { + if (this.queueMessage(groupId, payload)) { + queued++ + } + } + return queued + } + + transformHistoryMessage (groupId, chat) { + const text = extractHistoryText(chat) + if (!text) { + return null + } + if (text.startsWith('#')) { + return null + } + const prefix = ChatGPTConfig.basic?.togglePrefix + if (prefix && text.startsWith(prefix)) { + return null + } + const sender = chat?.sender || {} + const userId = resolveUserId(sender) || resolveUserId(chat) + if (this.isBotSelfId(userId)) { + return null + } + return { + message_id: resolveMessageIdCandidate(chat), + user_id: userId, + nickname: resolveNickname(sender) || resolveNickname(chat), + text, + timestamp: chat?.time ?? chat?.timestamp ?? chat?.message_time ?? 
Date.now() + } + } + + queueMessage (groupId, rawPayload) { + if (!rawPayload || !rawPayload.text) { + return false + } + const state = this.getGroupState(groupId) + const messageId = this.ensureMessageId(rawPayload) + const timestampMs = normalizeTimestamp(rawPayload.timestamp) + const buffer = this.getBuffer(groupId) + const payload = { + message_id: messageId, + user_id: rawPayload.user_id ? String(rawPayload.user_id) : '', + nickname: rawPayload.nickname ? String(rawPayload.nickname) : '', + text: rawPayload.text, + timestamp: timestampMs || Date.now() + } + const messageKey = this.resolveMessageKey(payload, messageId, timestampMs) + if (this.shouldSkipMessage(state, timestampMs, messageKey, payload.message_id)) { + return false + } + this.updateGroupState(groupId, state, timestampMs, messageKey, payload.message_id) + buffer.messages.push(payload) + logger.debug(`[Memory] buffered group message, group=${groupId}, buffer=${buffer.messages.length}`) + this.tryTriggerFlush(groupId, buffer) + return true + } + + ensureMessageId (payload) { + const direct = payload?.message_id ? String(payload.message_id).trim() : '' + if (direct) { + return direct + } + const fallback = resolveMessageIdCandidate(payload) + if (fallback) { + return fallback + } + return crypto.randomUUID() + } + + resolveMessageKey (payload, messageId, timestampMs) { + if (messageId) { + return messageId + } + const parts = [ + timestampMs || '', + payload?.user_id || '', + (payload?.text || '').slice(0, 32) + ] + return parts.filter(Boolean).join(':') + } + + getGroupState (groupId) { + let state = this.groupStates.get(groupId) + if (!state) { + const cursor = groupHistoryCursorStore.getCursor(groupId) + const lastTimestamp = Number(cursor?.last_timestamp) || 0 + const lastMessageId = cursor?.last_message_id || null + state = { + lastTimestamp, + lastMessageId, + recentIds: new Set() + } + if (lastMessageId) { + state.recentIds.add(lastMessageId) + } + this.groupStates.set(groupId, state) + } + return state + } + + shouldSkipMessage (state, timestampMs, messageKey, messageId) { + if (!state) { + return false + } + if (messageId && state.lastMessageId && messageId === state.lastMessageId) { + return true + } + if (timestampMs && timestampMs < state.lastTimestamp) { + return true + } + if (timestampMs && timestampMs === state.lastTimestamp && messageKey && state.recentIds.has(messageKey)) { + return true + } + if (!timestampMs && messageKey && state.recentIds.has(messageKey)) { + return true + } + return false + } + + updateGroupState (groupId, state, timestampMs, messageKey, messageId) { + const hasTimestamp = Number.isFinite(timestampMs) && timestampMs > 0 + if (!hasTimestamp) { + if (messageKey) { + state.recentIds.add(messageKey) + if (state.recentIds.size > MAX_RECENT_IDS) { + const ids = Array.from(state.recentIds).slice(-MAX_RECENT_IDS) + state.recentIds = new Set(ids) + } + } + if (messageId) { + state.lastMessageId = String(messageId) + groupHistoryCursorStore.updateCursor(groupId, { + lastMessageId: state.lastMessageId, + lastTimestamp: state.lastTimestamp || null + }) + } + return + } + + if (timestampMs > state.lastTimestamp) { + state.lastTimestamp = timestampMs + state.recentIds = messageKey ? 
new Set([messageKey]) : new Set() + } else if (timestampMs === state.lastTimestamp && messageKey) { + state.recentIds.add(messageKey) + if (state.recentIds.size > MAX_RECENT_IDS) { + const ids = Array.from(state.recentIds).slice(-MAX_RECENT_IDS) + state.recentIds = new Set(ids) + } + } + + if (messageId) { + state.lastMessageId = String(messageId) + } + + groupHistoryCursorStore.updateCursor(groupId, { + lastMessageId: state.lastMessageId || null, + lastTimestamp: state.lastTimestamp || timestampMs + }) + } + + getBuffer (groupId) { + let buffer = this.buffers.get(groupId) + if (!buffer) { + buffer = { + messages: [], + lastFlushAt: nowSeconds() + } + this.buffers.set(groupId, buffer) + } + return buffer + } + + tryTriggerFlush (groupId, buffer) { + const config = this.groupConfig + const minCount = config.minMessageCount || 50 + const maxWindow = config.maxMessageWindow || DEFAULT_MAX_WINDOW + const shouldFlushByCount = buffer.messages.length >= minCount + const shouldFlushByTime = buffer.messages.length > 0 && (nowSeconds() - buffer.lastFlushAt) >= maxWindow + logger.debug(`[Memory] try trigger flush, group=${groupId}, count=${buffer.messages.length}, lastFlushAt=${buffer.lastFlushAt}, shouldFlushByCount=${shouldFlushByCount}, shouldFlushByTime=${shouldFlushByTime}`) + if (shouldFlushByCount || shouldFlushByTime) { + logger.info(`[Memory] trigger group fact extraction, group=${groupId}, count=${buffer.messages.length}, reason=${shouldFlushByCount ? 'count' : 'timeout'}`) + this.flush(groupId).catch(err => logger.error('Failed to flush group memory:', err)) + } + } + + push (e) { + const groupId = normaliseGroupId(e.group_id || e.group?.group_id) + if (!memoryService.isGroupMemoryEnabled(groupId)) { + return + } + if (shouldIgnoreMessage(e)) { + return + } + const text = extractPlainText(e) + if (!text) { + return + } + this.addSelfId(e.bot?.uin) + const messageId = e.message_id || e.seq || crypto.randomUUID() + logger.debug(`[Memory] collect group message, group=${groupId}, user=${e.sender?.user_id}, buffer=${(this.buffers.get(groupId)?.messages.length || 0) + 1}`) + this.queueMessage(groupId, { + message_id: messageId, + user_id: String(e.sender?.user_id || ''), + nickname: e.sender?.card || e.sender?.nickname || '', + text, + timestamp: e.time || Date.now() + }) + } + + async flush (groupId) { + if (this.processing.has(groupId)) { + return + } + const buffer = this.buffers.get(groupId) + if (!buffer || buffer.messages.length === 0) { + return + } + this.processing.add(groupId) + try { + const messages = buffer.messages + this.buffers.set(groupId, { + messages: [], + lastFlushAt: nowSeconds() + }) + logger.debug(`[Memory] flushing group buffer, group=${groupId}, messages=${messages.length}`) + const simplified = messages.map(msg => ({ + message_id: msg.message_id, + user_id: msg.user_id, + nickname: msg.nickname, + text: msg.text + })) + const factCandidates = await extractGroupFacts(simplified) + if (factCandidates.length === 0) { + logger.debug(`[Memory] group fact extraction returned empty, group=${groupId}`) + return + } + const messageMap = new Map(messages.map(msg => [msg.message_id, msg.text])) + const enrichedFacts = factCandidates.map(fact => { + if (!fact.source_message_ids && fact.sourceMessages) { + fact.source_message_ids = fact.sourceMessages + } + let ids = [] + if (Array.isArray(fact.source_message_ids)) { + ids = fact.source_message_ids.map(id => String(id)) + } else if (typeof fact.source_message_ids === 'string') { + ids = fact.source_message_ids.split(',').map(id => 
id.trim()).filter(Boolean) + } + if (!fact.source_messages && ids.length > 0) { + const summary = ids + .map(id => messageMap.get(id) || '') + .filter(Boolean) + .join('\n') + fact.source_messages = summary + } + fact.source_message_ids = ids + if (!fact.involved_users || !Array.isArray(fact.involved_users)) { + fact.involved_users = [] + } else { + fact.involved_users = fact.involved_users.map(id => String(id)) + } + return fact + }) + const saved = await memoryService.saveGroupFacts(groupId, enrichedFacts) + logger.info(`[Memory] saved ${saved.length} group facts for group=${groupId}`) + } finally { + this.processing.delete(groupId) + } + } + + addSelfId (uin) { + if (uin === null || uin === undefined) { + return + } + const str = String(uin) + if (!str) { + return + } + if (!this.selfIds) { + this.selfIds = new Set() + } + this.selfIds.add(str) + } + + refreshSelfIds () { + this.selfIds = this.collectSelfIds() + } + + collectSelfIds () { + const ids = new Set() + try { + const botGlobal = global.Bot + if (botGlobal?.bots && typeof botGlobal.bots === 'object') { + for (const bot of Object.values(botGlobal.bots)) { + if (bot?.uin) { + ids.add(String(bot.uin)) + } + } + } + if (botGlobal?.uin) { + ids.add(String(botGlobal.uin)) + } + } catch (err) { + logger?.debug?.('[Memory] failed to collect bot self ids: %o', err) + } + return ids + } + + isBotSelfId (userId) { + if (userId === null || userId === undefined) { + return false + } + const str = String(userId) + if (!str) { + return false + } + if (!this.selfIds || this.selfIds.size === 0) { + this.refreshSelfIds() + } + return this.selfIds?.has(str) || false + } +} diff --git a/models/memory/database.js b/models/memory/database.js new file mode 100644 index 0000000..63fbe3d --- /dev/null +++ b/models/memory/database.js @@ -0,0 +1,755 @@ +import Database from 'better-sqlite3' +import * as sqliteVec from 'sqlite-vec' +import fs from 'fs' +import path from 'path' +import ChatGPTConfig from '../../config/config.js' + +const META_VECTOR_DIM_KEY = 'group_vec_dimension' +const META_VECTOR_MODEL_KEY = 'group_vec_model' +const META_GROUP_TOKENIZER_KEY = 'group_memory_tokenizer' +const META_USER_TOKENIZER_KEY = 'user_memory_tokenizer' +const TOKENIZER_DEFAULT = 'unicode61' +const SIMPLE_MATCH_SIMPLE = 'simple_query' +const SIMPLE_MATCH_JIEBA = 'jieba_query' +const PLUGIN_ROOT = path.resolve('./plugins/chatgpt-plugin') + +let dbInstance = null +let cachedVectorDimension = null +let cachedVectorModel = null +let userMemoryFtsConfig = { + tokenizer: TOKENIZER_DEFAULT, + matchQuery: null +} +let groupMemoryFtsConfig = { + tokenizer: TOKENIZER_DEFAULT, + matchQuery: null +} +const simpleExtensionState = { + requested: false, + enabled: false, + loaded: false, + error: null, + libraryPath: '', + dictPath: '', + tokenizer: TOKENIZER_DEFAULT, + matchQuery: null +} + +function resolveDbPath () { + const relativePath = ChatGPTConfig.memory?.database || 'data/memory.db' + return path.resolve('./plugins/chatgpt-plugin', relativePath) +} + +export function resolvePluginPath (targetPath) { + if (!targetPath) { + return '' + } + if (path.isAbsolute(targetPath)) { + return targetPath + } + return path.resolve(PLUGIN_ROOT, targetPath) +} + +export function toPluginRelativePath (absolutePath) { + if (!absolutePath) { + return '' + } + return path.relative(PLUGIN_ROOT, absolutePath) +} + +function resolvePreferredDimension () { + const { memory, llm } = ChatGPTConfig + if (memory?.vectorDimensions && memory.vectorDimensions > 0) { + return memory.vectorDimensions + } 
+ if (llm?.dimensions && llm.dimensions > 0) { + return llm.dimensions + } + return 1536 +} + +function ensureDirectory (filePath) { + const dir = path.dirname(filePath) + if (!fs.existsSync(dir)) { + fs.mkdirSync(dir, { recursive: true }) + } +} + +function ensureMetaTable (db) { + db.exec(` + CREATE TABLE IF NOT EXISTS memory_meta ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ) + `) +} + +function getMetaValue (db, key) { + const stmt = db.prepare('SELECT value FROM memory_meta WHERE key = ?') + const row = stmt.get(key) + return row ? row.value : null +} + +function setMetaValue (db, key, value) { + db.prepare(` + INSERT INTO memory_meta (key, value) + VALUES (?, ?) + ON CONFLICT(key) DO UPDATE SET value = excluded.value + `).run(key, value) +} + +function resetSimpleState (overrides = {}) { + simpleExtensionState.loaded = false + simpleExtensionState.error = null + simpleExtensionState.tokenizer = TOKENIZER_DEFAULT + simpleExtensionState.matchQuery = null + Object.assign(simpleExtensionState, overrides) + userMemoryFtsConfig = { + tokenizer: TOKENIZER_DEFAULT, + matchQuery: null + } + groupMemoryFtsConfig = { + tokenizer: TOKENIZER_DEFAULT, + matchQuery: null + } +} + +function sanitiseRawFtsInput (input) { + if (!input) { + return '' + } + const trimmed = String(input).trim() + if (!trimmed) { + return '' + } + const replaced = trimmed + .replace(/["'`]+/g, ' ') + .replace(/\u3000/g, ' ') + .replace(/[^\p{L}\p{N}\u4E00-\u9FFF\u3040-\u30FF\uAC00-\uD7AF\u1100-\u11FF\s]+/gu, ' ') + const collapsed = replaced.replace(/\s+/g, ' ').trim() + return collapsed || trimmed +} + +function isSimpleLibraryFile (filename) { + return /(^libsimple.*\.(so|dylib|dll)$)|(^simple\.(so|dylib|dll)$)/i.test(filename) +} + +function findSimpleLibrary (startDir) { + const stack = [startDir] + while (stack.length > 0) { + const dir = stack.pop() + if (!dir || !fs.existsSync(dir)) { + continue + } + const entries = fs.readdirSync(dir, { withFileTypes: true }) + for (const entry of entries) { + const fullPath = path.join(dir, entry.name) + if (entry.isDirectory()) { + stack.push(fullPath) + } else if (entry.isFile() && isSimpleLibraryFile(entry.name)) { + return fullPath + } + } + } + return '' +} + +function locateDictPathNear (filePath) { + if (!filePath) { + return '' + } + let currentDir = path.dirname(filePath) + for (let depth = 0; depth < 5 && currentDir && currentDir !== path.dirname(currentDir); depth++) { + const dictCandidate = path.join(currentDir, 'dict') + if (fs.existsSync(dictCandidate) && fs.statSync(dictCandidate).isDirectory()) { + return dictCandidate + } + currentDir = path.dirname(currentDir) + } + return '' +} + +function discoverSimplePaths () { + const searchRoots = [ + path.join(PLUGIN_ROOT, 'resources/simple'), + path.join(PLUGIN_ROOT, 'resources'), + path.join(PLUGIN_ROOT, 'lib/simple'), + PLUGIN_ROOT + ] + for (const root of searchRoots) { + if (!root || !fs.existsSync(root)) { + continue + } + const lib = findSimpleLibrary(root) + if (lib) { + const dictCandidate = locateDictPathNear(lib) + return { + libraryPath: toPluginRelativePath(lib) || lib, + dictPath: dictCandidate ? 
(toPluginRelativePath(dictCandidate) || dictCandidate) : '' + } + } + } + return { libraryPath: '', dictPath: '' } +} + +function applySimpleExtension (db) { + const config = ChatGPTConfig.memory?.extensions?.simple || {} + simpleExtensionState.requested = Boolean(config.enable) + simpleExtensionState.enabled = Boolean(config.enable) + simpleExtensionState.libraryPath = config.libraryPath || '' + simpleExtensionState.dictPath = config.dictPath || '' + if (!config.enable) { + logger?.debug?.('[Memory] simple tokenizer disabled via config') + resetSimpleState({ requested: false, enabled: false }) + return + } + if (!simpleExtensionState.libraryPath) { + const detected = discoverSimplePaths() + if (detected.libraryPath) { + simpleExtensionState.libraryPath = detected.libraryPath + simpleExtensionState.dictPath = detected.dictPath + config.libraryPath = detected.libraryPath + if (detected.dictPath) { + config.dictPath = detected.dictPath + } + } + } + const resolvedLibraryPath = resolvePluginPath(config.libraryPath) + if (!resolvedLibraryPath || !fs.existsSync(resolvedLibraryPath)) { + logger?.warn?.('[Memory] simple tokenizer library missing:', resolvedLibraryPath || '(empty path)') + resetSimpleState({ + requested: true, + enabled: true, + error: `Simple extension library not found at ${resolvedLibraryPath || '(empty path)'}` + }) + return + } + try { + logger?.info?.('[Memory] loading simple tokenizer extension from', resolvedLibraryPath) + db.loadExtension(resolvedLibraryPath) + if (config.useJieba) { + const resolvedDict = resolvePluginPath(config.dictPath) + if (resolvedDict && fs.existsSync(resolvedDict)) { + try { + logger?.debug?.('[Memory] configuring simple tokenizer jieba dict:', resolvedDict) + db.prepare('select jieba_dict(?)').get(resolvedDict) + } catch (err) { + logger?.warn?.('Failed to register jieba dict for simple extension:', err) + } + } else { + logger?.warn?.('Simple extension jieba dict path missing:', resolvedDict) + } + } + const tokenizer = config.useJieba ? 'simple_jieba' : 'simple' + const matchQuery = config.useJieba ? 
SIMPLE_MATCH_JIEBA : SIMPLE_MATCH_SIMPLE + simpleExtensionState.loaded = true + simpleExtensionState.error = null + simpleExtensionState.tokenizer = tokenizer + simpleExtensionState.matchQuery = matchQuery + logger?.info?.('[Memory] simple tokenizer initialised, tokenizer=%s, matchQuery=%s', tokenizer, matchQuery) + userMemoryFtsConfig = { + tokenizer, + matchQuery + } + groupMemoryFtsConfig = { + tokenizer, + matchQuery + } + return + } catch (error) { + logger?.error?.('Failed to load simple extension:', error) + resetSimpleState({ + requested: true, + enabled: true, + error: `Failed to load simple extension: ${error?.message || error}` + }) + } +} + +function loadSimpleExtensionForCleanup (db) { + if (!ChatGPTConfig.memory.extensions) { + ChatGPTConfig.memory.extensions = {} + } + if (!ChatGPTConfig.memory.extensions.simple) { + ChatGPTConfig.memory.extensions.simple = { + enable: false, + libraryPath: '', + dictPath: '', + useJieba: false + } + } + const config = ChatGPTConfig.memory.extensions.simple + let libraryPath = config.libraryPath || '' + let dictPath = config.dictPath || '' + if (!libraryPath) { + const detected = discoverSimplePaths() + libraryPath = detected.libraryPath + if (detected.dictPath && !dictPath) { + dictPath = detected.dictPath + } + if (libraryPath) { + ChatGPTConfig.memory.extensions.simple = ChatGPTConfig.memory.extensions.simple || {} + ChatGPTConfig.memory.extensions.simple.libraryPath = libraryPath + if (dictPath) { + ChatGPTConfig.memory.extensions.simple.dictPath = dictPath + } + } + } + const resolvedLibraryPath = resolvePluginPath(libraryPath) + if (!resolvedLibraryPath || !fs.existsSync(resolvedLibraryPath)) { + logger?.warn?.('[Memory] cleanup requires simple extension but library missing:', resolvedLibraryPath || '(empty path)') + return false + } + try { + logger?.info?.('[Memory] temporarily loading simple extension for cleanup tasks') + db.loadExtension(resolvedLibraryPath) + const useJieba = Boolean(config.useJieba) + if (useJieba) { + const resolvedDict = resolvePluginPath(dictPath) + if (resolvedDict && fs.existsSync(resolvedDict)) { + try { + db.prepare('select jieba_dict(?)').get(resolvedDict) + } catch (err) { + logger?.warn?.('Failed to set jieba dict during cleanup:', err) + } + } + } + return true + } catch (error) { + logger?.error?.('Failed to load simple extension for cleanup:', error) + return false + } +} + +function ensureGroupFactsTable (db) { + ensureMetaTable(db) + db.exec(` + CREATE TABLE IF NOT EXISTS group_facts ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + group_id TEXT NOT NULL, + fact TEXT NOT NULL, + topic TEXT, + importance REAL DEFAULT 0.5, + source_message_ids TEXT, + source_messages TEXT, + involved_users TEXT, + created_at TEXT DEFAULT (datetime('now')) + ) + `) + db.exec(` + CREATE UNIQUE INDEX IF NOT EXISTS idx_group_facts_unique + ON group_facts(group_id, fact) + `) + db.exec(` + CREATE INDEX IF NOT EXISTS idx_group_facts_group + ON group_facts(group_id, importance DESC, created_at DESC) + `) + ensureGroupFactsFtsTable(db) +} + +function ensureGroupHistoryCursorTable (db) { + ensureMetaTable(db) + db.exec(` + CREATE TABLE IF NOT EXISTS group_history_cursor ( + group_id TEXT PRIMARY KEY, + last_message_id TEXT, + last_timestamp INTEGER + ) + `) +} + +function ensureUserMemoryTable (db) { + ensureMetaTable(db) + db.exec(` + CREATE TABLE IF NOT EXISTS user_memory ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL, + group_id TEXT, + key TEXT NOT NULL, + value TEXT NOT NULL, + importance REAL DEFAULT 0.5, 
+ source_message_id TEXT, + created_at TEXT DEFAULT (datetime('now')), + updated_at TEXT DEFAULT (datetime('now')) + ) + `) + db.exec(` + CREATE UNIQUE INDEX IF NOT EXISTS idx_user_memory_key + ON user_memory(user_id, coalesce(group_id, ''), key) + `) + db.exec(` + CREATE INDEX IF NOT EXISTS idx_user_memory_group + ON user_memory(group_id) + `) + db.exec(` + CREATE INDEX IF NOT EXISTS idx_user_memory_user + ON user_memory(user_id) + `) + ensureUserMemoryFtsTable(db) +} + +function dropGroupFactsFtsArtifacts (db) { + try { + db.exec(` + DROP TRIGGER IF EXISTS group_facts_ai; + DROP TRIGGER IF EXISTS group_facts_ad; + DROP TRIGGER IF EXISTS group_facts_au; + DROP TABLE IF EXISTS group_facts_fts; + `) + } catch (err) { + if (String(err?.message || '').includes('no such tokenizer')) { + const loaded = loadSimpleExtensionForCleanup(db) + if (loaded) { + db.exec(` + DROP TRIGGER IF EXISTS group_facts_ai; + DROP TRIGGER IF EXISTS group_facts_ad; + DROP TRIGGER IF EXISTS group_facts_au; + DROP TABLE IF EXISTS group_facts_fts; + `) + } else { + logger?.warn?.('[Memory] Falling back to raw schema cleanup for group_facts_fts') + try { + db.exec('PRAGMA writable_schema = ON;') + db.exec(`DELETE FROM sqlite_master WHERE name IN ('group_facts_ai','group_facts_ad','group_facts_au','group_facts_fts');`) + } finally { + db.exec('PRAGMA writable_schema = OFF;') + } + } + } else { + throw err + } + } +} + +function createGroupFactsFts (db, tokenizer) { + logger?.info?.('[Memory] creating group_facts_fts with tokenizer=%s', tokenizer) + db.exec(` + CREATE VIRTUAL TABLE group_facts_fts + USING fts5( + fact, + topic, + content = 'group_facts', + content_rowid = 'id', + tokenize = '${tokenizer}' + ) + `) + db.exec(` + CREATE TRIGGER group_facts_ai AFTER INSERT ON group_facts BEGIN + INSERT INTO group_facts_fts(rowid, fact, topic) + VALUES (new.id, new.fact, coalesce(new.topic, '')); + END; + `) + db.exec(` + CREATE TRIGGER group_facts_ad AFTER DELETE ON group_facts BEGIN + INSERT INTO group_facts_fts(group_facts_fts, rowid, fact, topic) + VALUES ('delete', old.id, old.fact, coalesce(old.topic, '')); + END; + `) + db.exec(` + CREATE TRIGGER group_facts_au AFTER UPDATE ON group_facts BEGIN + INSERT INTO group_facts_fts(group_facts_fts, rowid, fact, topic) + VALUES ('delete', old.id, old.fact, coalesce(old.topic, '')); + INSERT INTO group_facts_fts(rowid, fact, topic) + VALUES (new.id, new.fact, coalesce(new.topic, '')); + END; + `) + try { + db.exec(`INSERT INTO group_facts_fts(group_facts_fts) VALUES ('rebuild')`) + } catch (err) { + logger?.debug?.('Group facts FTS rebuild skipped:', err?.message || err) + } +} + +function ensureGroupFactsFtsTable (db) { + const desiredTokenizer = groupMemoryFtsConfig.tokenizer || TOKENIZER_DEFAULT + const storedTokenizer = getMetaValue(db, META_GROUP_TOKENIZER_KEY) + const tableExists = db.prepare(` + SELECT name FROM sqlite_master + WHERE type = 'table' AND name = 'group_facts_fts' + `).get() + if (storedTokenizer && storedTokenizer !== desiredTokenizer) { + dropGroupFactsFtsArtifacts(db) + } else if (!storedTokenizer && tableExists) { + // Unknown tokenizer, drop to ensure consistency. 
+ dropGroupFactsFtsArtifacts(db) + } + const existsAfterDrop = db.prepare(` + SELECT name FROM sqlite_master + WHERE type = 'table' AND name = 'group_facts_fts' + `).get() + if (!existsAfterDrop) { + createGroupFactsFts(db, desiredTokenizer) + setMetaValue(db, META_GROUP_TOKENIZER_KEY, desiredTokenizer) + logger?.info?.('[Memory] group facts FTS initialised with tokenizer=%s', desiredTokenizer) + } +} + +function dropUserMemoryFtsArtifacts (db) { + try { + db.exec(` + DROP TRIGGER IF EXISTS user_memory_ai; + DROP TRIGGER IF EXISTS user_memory_ad; + DROP TRIGGER IF EXISTS user_memory_au; + DROP TABLE IF EXISTS user_memory_fts; + `) + } catch (err) { + if (String(err?.message || '').includes('no such tokenizer')) { + const loaded = loadSimpleExtensionForCleanup(db) + if (loaded) { + db.exec(` + DROP TRIGGER IF EXISTS user_memory_ai; + DROP TRIGGER IF EXISTS user_memory_ad; + DROP TRIGGER IF EXISTS user_memory_au; + DROP TABLE IF EXISTS user_memory_fts; + `) + } else { + logger?.warn?.('[Memory] Falling back to raw schema cleanup for user_memory_fts') + try { + db.exec('PRAGMA writable_schema = ON;') + db.exec(`DELETE FROM sqlite_master WHERE name IN ('user_memory_ai','user_memory_ad','user_memory_au','user_memory_fts');`) + } finally { + db.exec('PRAGMA writable_schema = OFF;') + } + } + } else { + throw err + } + } +} + +function createUserMemoryFts (db, tokenizer) { + logger?.info?.('[Memory] creating user_memory_fts with tokenizer=%s', tokenizer) + db.exec(` + CREATE VIRTUAL TABLE user_memory_fts + USING fts5( + value, + content = 'user_memory', + content_rowid = 'id', + tokenize = '${tokenizer}' + ) + `) + db.exec(` + CREATE TRIGGER user_memory_ai AFTER INSERT ON user_memory BEGIN + INSERT INTO user_memory_fts(rowid, value) + VALUES (new.id, new.value); + END; + `) + db.exec(` + CREATE TRIGGER user_memory_ad AFTER DELETE ON user_memory BEGIN + INSERT INTO user_memory_fts(user_memory_fts, rowid, value) + VALUES ('delete', old.id, old.value); + END; + `) + db.exec(` + CREATE TRIGGER user_memory_au AFTER UPDATE ON user_memory BEGIN + INSERT INTO user_memory_fts(user_memory_fts, rowid, value) + VALUES ('delete', old.id, old.value); + INSERT INTO user_memory_fts(rowid, value) + VALUES (new.id, new.value); + END; + `) + try { + db.exec(`INSERT INTO user_memory_fts(user_memory_fts) VALUES ('rebuild')`) + } catch (err) { + logger?.debug?.('User memory FTS rebuild skipped:', err?.message || err) + } +} + +function ensureUserMemoryFtsTable (db) { + const desiredTokenizer = userMemoryFtsConfig.tokenizer || TOKENIZER_DEFAULT + const storedTokenizer = getMetaValue(db, META_USER_TOKENIZER_KEY) + const tableExists = db.prepare(` + SELECT name FROM sqlite_master + WHERE type = 'table' AND name = 'user_memory_fts' + `).get() + if (storedTokenizer && storedTokenizer !== desiredTokenizer) { + dropUserMemoryFtsArtifacts(db) + } else if (!storedTokenizer && tableExists) { + dropUserMemoryFtsArtifacts(db) + } + const existsAfterDrop = db.prepare(` + SELECT name FROM sqlite_master + WHERE type = 'table' AND name = 'user_memory_fts' + `).get() + if (!existsAfterDrop) { + createUserMemoryFts(db, desiredTokenizer) + setMetaValue(db, META_USER_TOKENIZER_KEY, desiredTokenizer) + logger?.info?.('[Memory] user memory FTS initialised with tokenizer=%s', desiredTokenizer) + } +} + +function createVectorTable (db, dimension) { + if (!dimension || dimension <= 0) { + throw new Error(`Invalid vector dimension for table creation: ${dimension}`) + } + db.exec(`CREATE VIRTUAL TABLE vec_group_facts USING vec0(embedding 
float[${dimension}])`) +} + +function ensureVectorTable (db) { + ensureMetaTable(db) + if (cachedVectorDimension !== null) { + return cachedVectorDimension + } + const preferredDimension = resolvePreferredDimension() + const stored = getMetaValue(db, META_VECTOR_DIM_KEY) + const storedModel = getMetaValue(db, META_VECTOR_MODEL_KEY) + const currentModel = ChatGPTConfig.llm?.embeddingModel || '' + const tableExists = Boolean(db.prepare(` + SELECT name FROM sqlite_master + WHERE type = 'table' AND name = 'vec_group_facts' + `).get()) + + const parseDimension = value => { + if (!value && value !== 0) return 0 + const parsed = parseInt(String(value), 10) + return Number.isFinite(parsed) && parsed > 0 ? parsed : 0 + } + + const storedDimension = parseDimension(stored) + let dimension = storedDimension + let tablePresent = tableExists + + let needsTableReset = false + if (tableExists && storedDimension <= 0) { + needsTableReset = true + } + + if (needsTableReset && tableExists) { + try { + db.exec('DROP TABLE IF EXISTS vec_group_facts') + tablePresent = false + dimension = 0 + } catch (err) { + logger?.warn?.('[Memory] failed to drop vec_group_facts during dimension change:', err) + } + } + +if (!tablePresent) { + if (dimension <= 0) { + dimension = parseDimension(preferredDimension) + } + if (dimension > 0) { + try { + createVectorTable(db, dimension) + tablePresent = true + setMetaValue(db, META_VECTOR_MODEL_KEY, currentModel) + setMetaValue(db, META_VECTOR_DIM_KEY, String(dimension)) + cachedVectorDimension = dimension + cachedVectorModel = currentModel + return cachedVectorDimension + } catch (err) { + logger?.error?.('[Memory] failed to (re)create vec_group_facts table:', err) + dimension = 0 + } + } + } + + if (tablePresent && storedDimension > 0) { + cachedVectorDimension = storedDimension + cachedVectorModel = storedModel || currentModel + return cachedVectorDimension + } + + // At this point we failed to determine a valid dimension, set metadata to 0 to avoid loops. 
+ setMetaValue(db, META_VECTOR_MODEL_KEY, currentModel) + setMetaValue(db, META_VECTOR_DIM_KEY, '0') + cachedVectorDimension = 0 + cachedVectorModel = currentModel + return cachedVectorDimension +} +export function resetVectorTableDimension (dimension) { + if (!Number.isFinite(dimension) || dimension <= 0) { + throw new Error(`Invalid vector dimension: ${dimension}`) + } + const db = getMemoryDatabase() + try { + db.exec('DROP TABLE IF EXISTS vec_group_facts') + } catch (err) { + logger?.warn?.('[Memory] failed to drop vec_group_facts:', err) + } + createVectorTable(db, dimension) + setMetaValue(db, META_VECTOR_DIM_KEY, dimension.toString()) + const model = ChatGPTConfig.llm?.embeddingModel || '' + setMetaValue(db, META_VECTOR_MODEL_KEY, model) + cachedVectorDimension = dimension + cachedVectorModel = model +} + +function migrate (db) { + ensureGroupFactsTable(db) + ensureGroupHistoryCursorTable(db) + ensureUserMemoryTable(db) + ensureVectorTable(db) +} + +export function getUserMemoryFtsConfig () { + return { ...userMemoryFtsConfig } +} + +export function getGroupMemoryFtsConfig () { + return { ...groupMemoryFtsConfig } +} + +export function getSimpleExtensionState () { + return { ...simpleExtensionState } +} + +export function sanitiseFtsQueryInput (query, ftsConfig) { + if (!query) { + return '' + } + if (ftsConfig?.matchQuery) { + return String(query).trim() + } + return sanitiseRawFtsInput(query) +} + +export function getMemoryDatabase () { + if (dbInstance) { + return dbInstance + } + const dbPath = resolveDbPath() + ensureDirectory(dbPath) + logger?.info?.('[Memory] opening memory database at %s', dbPath) + dbInstance = new Database(dbPath) + sqliteVec.load(dbInstance) + resetSimpleState({ + requested: false, + enabled: false + }) + applySimpleExtension(dbInstance) + migrate(dbInstance) + logger?.info?.('[Memory] memory database init completed (simple loaded=%s)', simpleExtensionState.loaded) + return dbInstance +} + +export function getVectorDimension () { + const currentModel = ChatGPTConfig.llm?.embeddingModel || '' + if (cachedVectorModel && cachedVectorModel !== currentModel) { + cachedVectorDimension = null + cachedVectorModel = null + } + if (cachedVectorDimension !== null) { + return cachedVectorDimension + } + const db = getMemoryDatabase() + return ensureVectorTable(db) +} + +export function resetCachedDimension () { + cachedVectorDimension = null + cachedVectorModel = null +} + +export function resetMemoryDatabaseInstance () { + if (dbInstance) { + try { + dbInstance.close() + } catch (error) { + console.warn('Failed to close memory database:', error) + } + } + dbInstance = null + cachedVectorDimension = null + cachedVectorModel = null +} diff --git a/models/memory/extractor.js b/models/memory/extractor.js new file mode 100644 index 0000000..fb6b1bb --- /dev/null +++ b/models/memory/extractor.js @@ -0,0 +1,306 @@ +import { SendMessageOption, Chaite } from 'chaite' +import ChatGPTConfig from '../../config/config.js' +import { getClientForModel } from '../chaite/vectorizer.js' + +function collectTextFromResponse (response) { + if (!response?.contents) { + return '' + } + return response.contents + .filter(content => content.type === 'text') + .map(content => content.text || '') + .join('\n') + .trim() +} + +function parseJSON (text) { + if (!text) { + return null + } + const trimmed = text.trim() + const codeBlockMatch = trimmed.match(/^```(?:json)?\s*([\s\S]*?)\s*```$/i) + const payload = codeBlockMatch ? 
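A minimal sketch, not part of the patch, of how the exports of models/memory/database.js (completed above) fit together; the relative import path and the count query are illustrative assumptions.

// Illustrative only. getMemoryDatabase() opens memory.db, loads sqlite-vec (and the
// optional simple tokenizer extension) and runs the migrations defined above;
// getVectorDimension() returns the recorded embedding dimension, or 0 if none has
// been determined yet; resetMemoryDatabaseInstance() closes and drops the cached handle.
import { getMemoryDatabase, getVectorDimension, resetMemoryDatabaseInstance } from './models/memory/database.js'

const db = getMemoryDatabase()
const dimension = getVectorDimension()
const { n } = db.prepare('SELECT count(*) AS n FROM group_facts').get()
console.log(`embedding dimension=${dimension}, stored group facts=${n}`)

resetMemoryDatabaseInstance() // e.g. after the database path or extension config changes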
codeBlockMatch[1] : trimmed + try { + return JSON.parse(payload) + } catch (err) { + logger.warn('Failed to parse JSON from memory extractor response:', text) + return null + } +} + +function formatEntry (entry) { + let str = '' + try { + if (typeof entry === 'string') { + str = entry + } else { + str = JSON.stringify(entry) + } + } catch (err) { + str = String(entry) + } + const limit = 200 + return str.length > limit ? str.slice(0, limit) + '…' : str +} + +function injectMessagesIntoTemplate (template, body) { + if (!template || typeof template !== 'string') { + return body + } + const placeholders = ['${messages}', '{messages}', '{{messages}}'] + let result = template + let replaced = false + for (const placeholder of placeholders) { + if (result.includes(placeholder)) { + result = result.split(placeholder).join(body) + replaced = true + } + } + if (!replaced) { + const trimmed = result.trim() + if (!trimmed) { + return body + } + if (/\n\s*$/.test(result)) { + return `${result}${body}` + } + return `${result}\n${body}` + } + return result +} + +async function resolvePresetSendMessageOption (presetId, scope) { + if (!presetId) { + return null + } + try { + const chaite = Chaite.getInstance?.() + if (!chaite) { + logger.warn(`[Memory] ${scope} extraction preset ${presetId} configured but Chaite is not initialized`) + return null + } + const presetManager = chaite.getChatPresetManager?.() + if (!presetManager) { + logger.warn(`[Memory] ${scope} extraction preset ${presetId} configured but preset manager unavailable`) + return null + } + const preset = await presetManager.getInstance(presetId) + if (!preset) { + logger.warn(`[Memory] ${scope} extraction preset ${presetId} not found`) + return null + } + logger.debug(`[Memory] using ${scope} extraction preset ${presetId}`) + return { + preset, + sendMessageOption: JSON.parse(JSON.stringify(preset.sendMessageOption || {})) + } + } catch (err) { + logger.error(`[Memory] failed to load ${scope} extraction preset ${presetId}:`, err) + return null + } +} + +function resolveGroupExtractionPrompts (presetSendMessageOption) { + const config = ChatGPTConfig.memory?.group || {} + const system = config.extractionSystemPrompt || presetSendMessageOption?.systemOverride || `You are a knowledge extraction assistant that specialises in summarising long-term facts from group chat transcripts. +Read the provided conversation and identify statements that should be stored as long-term knowledge for the group. +Return a JSON array. Each element must contain: +{ + "fact": 事实内容,必须完整包含事件的各个要素而不能是简单的短语(比如谁参与了事件、做了什么事情、背景时间是什么)(同一件事情尽可能整合为同一条而非拆分,以便利于检索), + "topic": 主题关键词,字符串,如 "活动"、"成员信息", + "importance": 一个介于0和1之间的小数,数值越大表示越重要, + "source_message_ids": 原始消息ID数组, + "source_messages": 对应原始消息的简要摘录或合并文本, + "involved_users": 出现或相关的用户ID数组 +} +Only include meaningful, verifiable group-specific information that is useful for future conversations. Do not record incomplete information. Do not include general knowledge or unrelated facts. 
Do not wrap the JSON array in code fences.` + const userTemplate = config.extractionUserPrompt || `以下是群聊中的一些消息,请根据系统说明提取值得长期记忆的事实,以JSON数组形式返回,不要输出额外说明。 + +\${messages}` + return { system, userTemplate } +} + +function buildGroupUserPrompt (messages, template) { + const joined = messages.map(msg => { + const sender = msg.nickname || msg.user_id || '未知用户' + return `${sender}: ${msg.text}` + }).join('\n') + return injectMessagesIntoTemplate(template, joined) +} + +function buildExistingMemorySection (existingMemories = []) { + if (!existingMemories || existingMemories.length === 0) { + return '当前没有任何已知的长期记忆。' + } + const lines = existingMemories.map((item, idx) => `${idx + 1}. ${item}`) + return `以下是关于用户的已知长期记忆,请在提取新记忆时参考,避免重复已有事实,并在信息变更时更新描述:\n${lines.join('\n')}` +} + +function resolveUserExtractionPrompts (existingMemories = [], presetSendMessageOption) { + const config = ChatGPTConfig.memory?.user || {} + const systemTemplate = config.extractionSystemPrompt || presetSendMessageOption?.systemOverride || `You are an assistant that extracts long-term personal preferences or persona details about a user. +Given a conversation snippet between the user and the bot, identify durable information such as preferences, nicknames, roles, speaking style, habits, or other facts that remain valid over time. +Return a JSON array of **strings**, and nothing else, without any other characters including \`\`\` or \`\`\`json. Each string must be a short sentence (in the same language as the conversation) describing one piece of long-term memory. Do not include keys, JSON objects, or additional metadata. Ignore temporary topics or uncertain information.` + const userTemplate = config.extractionUserPrompt || `下面是用户与机器人的对话,请根据系统提示提取可长期记忆的个人信息。 + +\${messages}` + return { + system: `${systemTemplate} + +${buildExistingMemorySection(existingMemories)}`, + userTemplate + } +} + +function buildUserPrompt (messages, template) { + const body = messages.map(msg => { + const prefix = msg.role === 'assistant' ? '机器人' : (msg.nickname || msg.user_id || '用户') + return `${prefix}: ${msg.text}` + }).join('\n') + return injectMessagesIntoTemplate(template, body) +} + +async function callModel ({ prompt, systemPrompt, model, maxToken = 4096, temperature = 0.2, sendMessageOption }) { + const options = sendMessageOption + ? JSON.parse(JSON.stringify(sendMessageOption)) + : {} + options.model = model || options.model + if (!options.model) { + throw new Error('No model available for memory extraction call') + } + const resolvedModel = options.model + const { client } = await getClientForModel(resolvedModel) + const response = await client.sendMessage({ + role: 'user', + content: [ + { + type: 'text', + text: prompt + } + ] + }, SendMessageOption.create({ + ...options, + model: options.model, + temperature: options.temperature ?? temperature, + maxToken: options.maxToken ?? maxToken, + systemOverride: systemPrompt ?? 
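For reference, the group extraction system prompt above asks the model for a JSON array whose elements carry the keys fact, topic, importance, source_message_ids, source_messages and involved_users; a purely hypothetical element (all values invented) would look like:

// Hypothetical example of one element the group extractor is prompted to return.
// The field names come from the system prompt above; every value here is invented.
const exampleGroupFact = {
  fact: '群友小明在5月20日组织了市中心的线下聚餐,共有8人报名参加',
  topic: '活动',
  importance: 0.7,
  source_message_ids: ['1234567', '1234570'],
  source_messages: '小明: 5月20日市中心聚餐,想来的在群里报名',
  involved_users: ['10001']
}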
options.systemOverride, + disableHistoryRead: true, + disableHistorySave: true, + stream: false + })) + return collectTextFromResponse(response) +} + +function resolveGroupExtractionModel (presetSendMessageOption) { + const config = ChatGPTConfig.memory?.group + if (config?.extractionModel) { + return config.extractionModel + } + if (presetSendMessageOption?.model) { + return presetSendMessageOption.model + } + if (ChatGPTConfig.llm?.defaultModel) { + return ChatGPTConfig.llm.defaultModel + } + return '' +} + +function resolveUserExtractionModel (presetSendMessageOption) { + const config = ChatGPTConfig.memory?.user + if (config?.extractionModel) { + return config.extractionModel + } + if (presetSendMessageOption?.model) { + return presetSendMessageOption.model + } + if (ChatGPTConfig.llm?.defaultModel) { + return ChatGPTConfig.llm.defaultModel + } + return '' +} + +export async function extractGroupFacts (messages) { + if (!messages || messages.length === 0) { + return [] + } + const groupConfig = ChatGPTConfig.memory?.group || {} + const presetInfo = await resolvePresetSendMessageOption(groupConfig.extractionPresetId, 'group') + const presetOptions = presetInfo?.sendMessageOption + const model = resolveGroupExtractionModel(presetOptions) + if (!model) { + logger.warn('No model configured for group memory extraction') + return [] + } + try { + const prompts = resolveGroupExtractionPrompts(presetOptions) + logger.debug(`[Memory] start group fact extraction, messages=${messages.length}, model=${model}${presetInfo?.preset ? `, preset=${presetInfo.preset.id}` : ''}`) + const text = await callModel({ + prompt: buildGroupUserPrompt(messages, prompts.userTemplate), + systemPrompt: prompts.system, + model, + sendMessageOption: presetOptions + }) + const parsed = parseJSON(text) + if (Array.isArray(parsed)) { + logger.info(`[Memory] extracted ${parsed.length} group facts`) + parsed.slice(0, 10).forEach((item, idx) => { + logger.debug(`[Memory] group fact[${idx}] ${formatEntry(item)}`) + }) + return parsed + } + logger.debug('[Memory] group fact extraction returned non-array content') + return [] + } catch (err) { + logger.error('Failed to extract group facts:', err) + return [] + } +} + +export async function extractUserMemories (messages, existingMemories = []) { + if (!messages || messages.length === 0) { + return [] + } + const userConfig = ChatGPTConfig.memory?.user || {} + const presetInfo = await resolvePresetSendMessageOption(userConfig.extractionPresetId, 'user') + const presetOptions = presetInfo?.sendMessageOption + const model = resolveUserExtractionModel(presetOptions) + if (!model) { + logger.warn('No model configured for user memory extraction') + return [] + } + try { + const prompts = resolveUserExtractionPrompts(existingMemories, presetOptions) + logger.debug(`[Memory] start user memory extraction, snippets=${messages.length}, existing=${existingMemories.length}, model=${model}${presetInfo?.preset ? 
`, preset=${presetInfo.preset.id}` : ''}`) + const text = await callModel({ + prompt: buildUserPrompt(messages, prompts.userTemplate), + systemPrompt: prompts.system, + model, + sendMessageOption: presetOptions + }) + const parsed = parseJSON(text) + if (Array.isArray(parsed)) { + const sentences = parsed.map(item => { + if (typeof item === 'string') { + return item.trim() + } + if (item && typeof item === 'object') { + const possible = item.sentence || item.text || item.value || item.fact + if (possible) { + return String(possible).trim() + } + } + return '' + }).filter(Boolean) + logger.info(`[Memory] extracted ${sentences.length} user memories`) + sentences.slice(0, 10).forEach((item, idx) => { + logger.debug(`[Memory] user memory[${idx}] ${formatEntry(item)}`) + }) + return sentences + } + logger.debug('[Memory] user memory extraction returned non-array content') + return [] + } catch (err) { + logger.error('Failed to extract user memories:', err) + return [] + } +} diff --git a/models/memory/groupHistoryCursorStore.js b/models/memory/groupHistoryCursorStore.js new file mode 100644 index 0000000..b570e82 --- /dev/null +++ b/models/memory/groupHistoryCursorStore.js @@ -0,0 +1,61 @@ +import { getMemoryDatabase } from './database.js' + +function normaliseGroupId (groupId) { + if (groupId === null || groupId === undefined) { + return null + } + const str = String(groupId).trim() + return str || null +} + +export class GroupHistoryCursorStore { + constructor (db = getMemoryDatabase()) { + this.resetDatabase(db) + } + + resetDatabase (db = getMemoryDatabase()) { + this.db = db + this.selectStmt = this.db.prepare(` + SELECT last_message_id, last_timestamp + FROM group_history_cursor + WHERE group_id = ? + `) + this.upsertStmt = this.db.prepare(` + INSERT INTO group_history_cursor (group_id, last_message_id, last_timestamp) + VALUES (@group_id, @last_message_id, @last_timestamp) + ON CONFLICT(group_id) DO UPDATE SET + last_message_id = excluded.last_message_id, + last_timestamp = excluded.last_timestamp + `) + } + + ensureDb () { + if (!this.db || this.db.open === false) { + logger?.debug?.('[Memory] refreshing group history cursor database connection') + this.resetDatabase() + } + return this.db + } + + getCursor (groupId) { + const gid = normaliseGroupId(groupId) + if (!gid) return null + this.ensureDb() + return this.selectStmt.get(gid) || null + } + + updateCursor (groupId, { lastMessageId = null, lastTimestamp = null } = {}) { + const gid = normaliseGroupId(groupId) + if (!gid) return false + this.ensureDb() + const payload = { + group_id: gid, + last_message_id: lastMessageId ? String(lastMessageId) : null, + last_timestamp: (typeof lastTimestamp === 'number' && Number.isFinite(lastTimestamp)) ? Math.floor(lastTimestamp) : null + } + this.upsertStmt.run(payload) + return true + } +} + +export const groupHistoryCursorStore = new GroupHistoryCursorStore() diff --git a/models/memory/groupMemoryStore.js b/models/memory/groupMemoryStore.js new file mode 100644 index 0000000..459c3c0 --- /dev/null +++ b/models/memory/groupMemoryStore.js @@ -0,0 +1,515 @@ +import { getMemoryDatabase, getVectorDimension, getGroupMemoryFtsConfig, resetVectorTableDimension, sanitiseFtsQueryInput } from './database.js' +import ChatGPTConfig from '../../config/config.js' +import { embedTexts } from '../chaite/vectorizer.js' + +function toJSONString (value) { + if (!value) { + return '[]' + } + if (Array.isArray(value)) { + return JSON.stringify(value) + } + return typeof value === 'string' ? 
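A small usage sketch for the GroupHistoryCursorStore defined above, assuming it is driven by whatever job polls group history; the ids and timestamp are placeholders.

// Illustrative only: read where the last extraction stopped, then advance the
// cursor after a batch has been processed. Values are placeholders.
import { groupHistoryCursorStore } from './models/memory/groupHistoryCursorStore.js'

const groupId = '123456789'
const cursor = groupHistoryCursorStore.getCursor(groupId) // null on first run
console.log('resume after message:', cursor?.last_message_id ?? 'none')

groupHistoryCursorStore.updateCursor(groupId, {
  lastMessageId: '987654321',
  lastTimestamp: Math.floor(Date.now() / 1000) // finite numbers are floored before storage
})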
value : JSON.stringify(value) +} + +function toVectorBuffer (vector) { + if (!vector) { + return null + } + if (vector instanceof Float32Array) { + return Buffer.from(vector.buffer) + } + if (ArrayBuffer.isView(vector)) { + return Buffer.from(new Float32Array(vector).buffer) + } + return Buffer.from(new Float32Array(vector).buffer) +} + +function normaliseEmbeddingVector (vector) { + if (!vector) { + return null + } + if (Array.isArray(vector)) { + return vector + } + if (ArrayBuffer.isView(vector)) { + return Array.from(vector) + } + if (typeof vector === 'object') { + if (Array.isArray(vector.embedding)) { + return vector.embedding + } + if (ArrayBuffer.isView(vector.embedding)) { + return Array.from(vector.embedding) + } + if (Array.isArray(vector.vector)) { + return vector.vector + } + if (ArrayBuffer.isView(vector.vector)) { + return Array.from(vector.vector) + } + } + return null +} + +function normaliseGroupId (groupId) { + return groupId === null || groupId === undefined ? null : String(groupId) +} + +export class GroupMemoryStore { + constructor (db = getMemoryDatabase()) { + this.resetDatabase(db) + } + + resetDatabase (db = getMemoryDatabase()) { + this.db = db + this.insertFactStmt = this.db.prepare(` + INSERT INTO group_facts (group_id, fact, topic, importance, source_message_ids, source_messages, involved_users) + VALUES (@group_id, @fact, @topic, @importance, @source_message_ids, @source_messages, @involved_users) + ON CONFLICT(group_id, fact) DO UPDATE SET + topic = excluded.topic, + importance = excluded.importance, + source_message_ids = excluded.source_message_ids, + source_messages = excluded.source_messages, + involved_users = excluded.involved_users, + created_at = CASE + WHEN excluded.importance > group_facts.importance THEN datetime('now') + ELSE group_facts.created_at + END + `) + this.prepareVectorStatements() + this.loadFactByIdStmt = this.db.prepare('SELECT * FROM group_facts WHERE id = ?') + } + + prepareVectorStatements () { + try { + this.deleteVecStmt = this.db.prepare('DELETE FROM vec_group_facts WHERE rowid = ?') + this.insertVecStmt = this.db.prepare('INSERT INTO vec_group_facts(rowid, embedding) VALUES (?, ?)') + } catch (err) { + this.deleteVecStmt = null + this.insertVecStmt = null + logger?.debug?.('[Memory] vector table not ready, postpone statement preparation') + } + } + + ensureDb () { + if (!this.db || this.db.open === false) { + logger?.debug?.('[Memory] refreshing group memory database connection') + this.resetDatabase() + } + return this.db + } + + get embeddingModel () { + return ChatGPTConfig.llm?.embeddingModel || '' + } + + get retrievalMode () { + const mode = ChatGPTConfig.memory?.group?.retrievalMode || 'hybrid' + const lowered = String(mode).toLowerCase() + if (['vector', 'keyword', 'hybrid'].includes(lowered)) { + return lowered + } + return 'hybrid' + } + + get hybridPrefer () { + const prefer = ChatGPTConfig.memory?.group?.hybridPrefer || 'vector-first' + return prefer === 'keyword-first' ? 
'keyword-first' : 'vector-first' + } + + isVectorEnabled () { + return Boolean(this.embeddingModel) + } + + get vectorDistanceThreshold () { + const value = Number(ChatGPTConfig.memory?.group?.vectorMaxDistance) + if (Number.isFinite(value) && value > 0) { + return value + } + return null + } + + get bm25Threshold () { + const value = Number(ChatGPTConfig.memory?.group?.textMaxBm25Score) + if (Number.isFinite(value) && value > 0) { + return value + } + return null + } + + async saveFacts (groupId, facts) { + if (!facts || facts.length === 0) { + return [] + } + this.ensureDb() + const normGroupId = normaliseGroupId(groupId) + const filteredFacts = facts + .map(f => { + const rawFact = typeof f.fact === 'string' ? f.fact : (Array.isArray(f.fact) ? f.fact.join(' ') : String(f.fact || '')) + const rawTopic = typeof f.topic === 'string' ? f.topic : (f.topic === undefined || f.topic === null ? '' : String(f.topic)) + const rawSourceMessages = f.source_messages ?? f.sourceMessages ?? '' + const sourceMessages = Array.isArray(rawSourceMessages) + ? rawSourceMessages.map(item => (item === null || item === undefined) ? '' : String(item)).filter(Boolean).join('\n') + : (typeof rawSourceMessages === 'string' ? rawSourceMessages : String(rawSourceMessages || '')) + return { + fact: rawFact.trim(), + topic: rawTopic.trim(), + importance: typeof f.importance === 'number' ? f.importance : Number(f.importance) || 0.5, + source_message_ids: toJSONString(f.source_message_ids || f.sourceMessages), + source_messages: sourceMessages, + involved_users: toJSONString(f.involved_users || f.involvedUsers || []) + } + }) + .filter(item => item.fact) + + if (filteredFacts.length === 0) { + return [] + } + + let vectors = [] + let tableDimension = getVectorDimension() || 0 + const configuredDimension = Number(ChatGPTConfig.llm?.dimensions || 0) + if (this.isVectorEnabled()) { + try { + const preferredDimension = configuredDimension > 0 + ? configuredDimension + : (tableDimension > 0 ? tableDimension : undefined) + vectors = await embedTexts(filteredFacts.map(f => f.fact), this.embeddingModel, preferredDimension) + vectors = vectors.map(normaliseEmbeddingVector) + const mismatchVector = vectors.find(vec => { + if (!vec) return false + if (Array.isArray(vec)) return vec.length > 0 + if (ArrayBuffer.isView(vec) && typeof vec.length === 'number') { + return vec.length > 0 + } + return false + }) + const actualDimension = mismatchVector ? mismatchVector.length : 0 + if (actualDimension && actualDimension !== tableDimension) { + const expectedDimension = tableDimension || preferredDimension || configuredDimension || 'unknown' + logger.warn(`[Memory] embedding dimension mismatch, expected=${expectedDimension}, actual=${actualDimension}. 
Recreating vector table.`) + try { + resetVectorTableDimension(actualDimension) + this.prepareVectorStatements() + tableDimension = actualDimension + } catch (resetErr) { + logger.error('Failed to reset vector table dimension:', resetErr) + vectors = [] + } + } else if (actualDimension && tableDimension <= 0) { + try { + resetVectorTableDimension(actualDimension) + this.prepareVectorStatements() + tableDimension = actualDimension + } catch (resetErr) { + logger.error('Failed to initialise vector table dimension:', resetErr) + vectors = [] + } + } + } catch (err) { + logger.error('Failed to embed group facts:', err) + vectors = [] + } + } + + const transaction = this.db.transaction((items, vectorList) => { + const saved = [] + for (let i = 0; i < items.length; i++) { + const payload = { + group_id: normGroupId, + ...items[i] + } + const info = this.insertFactStmt.run(payload) + let factId = Number(info.lastInsertRowid) + if (!factId) { + const existing = this.db.prepare('SELECT id FROM group_facts WHERE group_id = ? AND fact = ?').get(normGroupId, payload.fact) + factId = existing?.id + } + factId = Number.parseInt(String(factId ?? ''), 10) + if (!Number.isSafeInteger(factId)) { + logger.warn('[Memory] skip fact vector upsert due to invalid fact id', factId) + continue + } + if (!factId) { + continue + } + if (Array.isArray(vectorList) && vectorList[i]) { + if (!this.deleteVecStmt || !this.insertVecStmt) { + this.prepareVectorStatements() + } + if (!this.deleteVecStmt || !this.insertVecStmt) { + logger.warn('[Memory] vector table unavailable, skip vector upsert') + continue + } + try { + const vector = normaliseEmbeddingVector(vectorList[i]) + if (!vector) { + continue + } + let embeddingArray + if (ArrayBuffer.isView(vector)) { + if (vector instanceof Float32Array) { + embeddingArray = vector + } else { + embeddingArray = new Float32Array(vector.length) + for (let idx = 0; idx < vector.length; idx++) { + embeddingArray[idx] = Number(vector[idx]) + } + } + } else { + embeddingArray = Float32Array.from(vector) + } + const rowId = BigInt(factId) + logger.debug(`[Memory] upserting vector for fact ${factId}, rowIdType=${typeof rowId}`) + this.deleteVecStmt.run(rowId) + this.insertVecStmt.run(rowId, embeddingArray) + } catch (error) { + logger.error(`Failed to upsert vector for fact ${factId}:`, error) + } + } + saved.push(this.loadFactByIdStmt.get(factId)) + } + return saved + }) + + return transaction(filteredFacts, vectors) + } + + listFacts (groupId, limit = 50, offset = 0) { + return this.db.prepare(` + SELECT * FROM group_facts + WHERE group_id = ? + ORDER BY importance DESC, created_at DESC + LIMIT ? OFFSET ? + `).all(normaliseGroupId(groupId), limit, offset) + } + + deleteFact (groupId, factId) { + this.ensureDb() + const normGroupId = normaliseGroupId(groupId) + const fact = this.db.prepare('SELECT id FROM group_facts WHERE id = ? 
AND group_id = ?').get(factId, normGroupId) + if (!fact) { + return false + } + this.db.prepare('DELETE FROM group_facts WHERE id = ?').run(factId) + try { + this.deleteVecStmt.run(BigInt(factId)) + } catch (err) { + logger?.warn?.(`[Memory] failed to delete vector for fact ${factId}:`, err) + } + return true + } + + async vectorSearch (groupId, queryText, limit) { + this.ensureDb() + if (!this.isVectorEnabled()) { + return [] + } + try { + let tableDimension = getVectorDimension() || 0 + if (!tableDimension || tableDimension <= 0) { + logger.debug('[Memory] vector table dimension unavailable, attempting to infer from embedding') + } + const requestedDimension = tableDimension > 0 ? tableDimension : undefined + const [embedding] = await embedTexts([queryText], this.embeddingModel, requestedDimension) + if (!embedding) { + return [] + } + const embeddingVector = ArrayBuffer.isView(embedding) ? embedding : Float32Array.from(embedding) + const actualDimension = embeddingVector.length + if (!actualDimension) { + logger.debug('[Memory] vector search skipped: empty embedding returned') + return [] + } + if (tableDimension > 0 && actualDimension !== tableDimension) { + logger.warn(`[Memory] vector dimension mismatch detected during search, table=${tableDimension}, embedding=${actualDimension}. Rebuilding vector table.`) + try { + resetVectorTableDimension(actualDimension) + this.prepareVectorStatements() + tableDimension = actualDimension + } catch (resetErr) { + logger.error('Failed to reset vector table dimension during search:', resetErr) + return [] + } + logger.info('[Memory] vector table rebuilt; old vectors must be regenerated before vector search can return results') + return [] + } else if (tableDimension <= 0 && actualDimension > 0) { + try { + resetVectorTableDimension(actualDimension) + this.prepareVectorStatements() + tableDimension = actualDimension + } catch (resetErr) { + logger.error('Failed to initialise vector table dimension during search:', resetErr) + return [] + } + } + const rows = this.db.prepare(` + SELECT gf.*, vec_group_facts.distance AS distance + FROM vec_group_facts + JOIN group_facts gf ON gf.id = vec_group_facts.rowid + WHERE gf.group_id = ? + AND vec_group_facts.embedding MATCH ? + AND vec_group_facts.k = ${limit} + ORDER BY distance ASC + `).all(groupId, embeddingVector) + const threshold = this.vectorDistanceThreshold + if (!threshold) { + return rows + } + return rows.filter(row => typeof row?.distance === 'number' && row.distance <= threshold) + } catch (err) { + logger.warn('Vector search failed for group memory:', err) + return [] + } + } + + textSearch (groupId, queryText, limit) { + this.ensureDb() + if (!queryText || !queryText.trim()) { + return [] + } + const originalQuery = queryText.trim() + const ftsConfig = getGroupMemoryFtsConfig() + const matchQueryParam = sanitiseFtsQueryInput(originalQuery, ftsConfig) + const results = [] + const seen = new Set() + if (matchQueryParam) { + const matchExpression = ftsConfig.matchQuery ? `${ftsConfig.matchQuery}(?)` : '?' + try { + const rows = this.db.prepare(` + SELECT gf.*, bm25(group_facts_fts) AS bm25_score + FROM group_facts_fts + JOIN group_facts gf ON gf.id = group_facts_fts.rowid + WHERE gf.group_id = ? + AND group_facts_fts MATCH ${matchExpression} + ORDER BY bm25_score ASC + LIMIT ? 
+ `).all(groupId, matchQueryParam, limit) + for (const row of rows) { + const bm25Threshold = this.bm25Threshold + if (bm25Threshold) { + const score = Number(row?.bm25_score) + if (!Number.isFinite(score) || score > bm25Threshold) { + continue + } + row.bm25_score = score + } + if (row && !seen.has(row.id)) { + results.push(row) + seen.add(row.id) + } + } + } catch (err) { + logger.warn('Text search failed for group memory:', err) + } + } else { + logger.debug('[Memory] group memory text search skipped MATCH due to empty query after sanitisation') + } + + if (results.length < limit) { + try { + const likeRows = this.db.prepare(` + SELECT * + FROM group_facts + WHERE group_id = ? + AND instr(fact, ?) > 0 + ORDER BY importance DESC, created_at DESC + LIMIT ? + `).all(groupId, originalQuery, Math.max(limit * 2, limit)) + for (const row of likeRows) { + if (row && !seen.has(row.id)) { + results.push(row) + seen.add(row.id) + if (results.length >= limit) { + break + } + } + } + } catch (err) { + logger.warn('LIKE fallback failed for group memory:', err) + } + } + + return results.slice(0, limit) + } + + importanceFallback (groupId, limit, minImportance, excludeIds = []) { + this.ensureDb() + const ids = excludeIds.filter(Boolean) + const notInClause = ids.length ? `AND id NOT IN (${ids.map(() => '?').join(',')})` : '' + const stmt = this.db.prepare(` + SELECT * FROM group_facts + WHERE group_id = ? + AND importance >= ? + ${notInClause} + ORDER BY importance DESC, created_at DESC + LIMIT ? + `) + const params = [groupId, minImportance] + if (ids.length) { + params.push(...ids) + } + params.push(limit) + return stmt.all(...params) + } + + /** + * 获取相关群记忆,支持向量/文本/混合检索 + * @param {string} groupId + * @param {string} queryText + * @param {{ limit?: number, minImportance?: number }} options + * @returns {Promise>} + */ + async queryRelevantFacts (groupId, queryText, options = {}) { + const { limit = 5, minImportance = 0 } = options + const normGroupId = normaliseGroupId(groupId) + if (!queryText) { + return this.listFacts(normGroupId, limit) + } + + const mode = this.retrievalMode + const combined = [] + const seen = new Set() + const append = rows => { + for (const row of rows) { + if (!row || seen.has(row.id)) { + continue + } + combined.push(row) + seen.add(row.id) + if (combined.length >= limit) { + break + } + } + } + + const preferVector = this.hybridPrefer !== 'keyword-first' + + if (mode === 'vector' || mode === 'hybrid') { + const vectorRows = await this.vectorSearch(normGroupId, queryText, limit) + if (mode === 'vector') { + append(vectorRows) + } else if (preferVector) { + append(vectorRows) + if (combined.length < limit) { + append(this.textSearch(normGroupId, queryText, limit)) + } + } else { + append(this.textSearch(normGroupId, queryText, limit)) + if (combined.length < limit) { + append(vectorRows) + } + } + } else if (mode === 'keyword') { + append(this.textSearch(normGroupId, queryText, limit)) + } + + if (combined.length < limit) { + const fallback = this.importanceFallback(normGroupId, limit - combined.length, minImportance, Array.from(seen)) + append(fallback) + } + + return combined.slice(0, limit) + } +} diff --git a/models/memory/prompt.js b/models/memory/prompt.js new file mode 100644 index 0000000..7093e26 --- /dev/null +++ b/models/memory/prompt.js @@ -0,0 +1,128 @@ +import ChatGPTConfig from '../../config/config.js' +import { memoryService } from './service.js' + +function renderTemplate (template, context = {}) { + if (!template) { + return '' + } + return 
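A sketch, assuming the memory database is already initialised, of how queryRelevantFacts above behaves: in 'hybrid' mode with hybridPrefer 'vector-first' it takes vector hits first, tops up with keyword (FTS/LIKE) hits, and finally falls back to high-importance facts.

// Illustrative call of GroupMemoryStore.queryRelevantFacts as defined above;
// the group id, query text and thresholds are example values.
import { GroupMemoryStore } from './models/memory/groupMemoryStore.js'

const store = new GroupMemoryStore()
const facts = await store.queryRelevantFacts('123456789', '下周的聚餐定在哪里', {
  limit: 5,
  minImportance: 0.3
})
for (const row of facts) {
  console.log(row.importance, row.fact)
}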
template.replace(/\$\{(\w+)\}/g, (_, key) => { + const value = context[key] + return value === undefined || value === null ? '' : String(value) + }) +} + +function formatUserMemories (memories, config) { + if (!memories.length) { + return '' + } + const headerTemplate = config.promptHeader ?? '# 用户画像' + const itemTemplate = config.promptItemTemplate ?? '- ${value}' + const footerTemplate = config.promptFooter ?? '' + const segments = [] + const header = renderTemplate(headerTemplate, { count: memories.length }) + if (header) { + segments.push(header) + } + memories.forEach((item, index) => { + const timestamp = item.updated_at || item.created_at || '' + const timeSuffix = timestamp ? `(记录时间:${timestamp})` : '' + const context = { + index, + order: index + 1, + value: item.value || '', + importance: item.importance ?? '', + sourceMessageId: item.source_message_id || '', + sourceId: item.source_message_id || '', + groupId: item.group_id || '', + createdAt: item.created_at || '', + updatedAt: item.updated_at || '', + timestamp, + time: timestamp, + timeSuffix + } + const line = renderTemplate(itemTemplate, context) + if (line) { + segments.push(line) + } + }) + const footer = renderTemplate(footerTemplate, { count: memories.length }) + if (footer) { + segments.push(footer) + } + return segments.join('\n') +} + +function formatGroupFacts (facts, config) { + if (!facts.length) { + return '' + } + const headerTemplate = config.promptHeader ?? '# 群聊长期记忆' + const itemTemplate = config.promptItemTemplate ?? '- ${fact}${topicSuffix}' + const footerTemplate = config.promptFooter ?? '' + const segments = [] + const header = renderTemplate(headerTemplate, { count: facts.length }) + if (header) { + segments.push(header) + } + facts.forEach((item, index) => { + const topicSuffix = item.topic ? `(${item.topic})` : '' + const timestamp = item.updated_at || item.created_at || '' + const timeSuffix = timestamp ? `(记录时间:${timestamp})` : '' + const context = { + index, + order: index + 1, + fact: item.fact || '', + topic: item.topic || '', + topicSuffix, + importance: item.importance ?? '', + createdAt: item.created_at || '', + updatedAt: item.updated_at || '', + timestamp, + time: timestamp, + timeSuffix, + distance: item.distance ?? '', + bm25: item.bm25_score ?? '', + sourceMessages: item.source_messages || '', + sourceMessageIds: item.source_message_ids || '' + } + const line = renderTemplate(itemTemplate, context) + if (line) { + segments.push(line) + } + }) + const footer = renderTemplate(footerTemplate, { count: facts.length }) + if (footer) { + segments.push(footer) + } + return segments.join('\n') +} + +export async function buildMemoryPrompt ({ userId, groupId, queryText }) { + const segments = [] + const userConfig = ChatGPTConfig.memory?.user || {} + const groupConfig = ChatGPTConfig.memory?.group || {} + if (memoryService.isUserMemoryEnabled(userId)) { + const totalLimit = userConfig.maxItemsPerInjection || 5 + const searchLimit = Math.min(userConfig.maxRelevantItemsPerQuery || totalLimit, totalLimit) + const userMemories = memoryService.queryUserMemories(userId, groupId, queryText, { + totalLimit, + searchLimit, + minImportance: userConfig.minImportanceForInjection ?? 
0 + }) + const userSegment = formatUserMemories(userMemories, userConfig) + if (userSegment) { + segments.push(userSegment) + } + } + if (groupId && memoryService.isGroupMemoryEnabled(groupId)) { + const facts = await memoryService.queryGroupFacts(groupId, queryText, { + limit: groupConfig.maxFactsPerInjection || 5, + minImportance: groupConfig.minImportanceForInjection || 0 + }) + const groupSegment = formatGroupFacts(facts, groupConfig) + if (groupSegment) { + segments.push(groupSegment) + } + } + return segments.join('\n\n').trim() +} diff --git a/models/memory/router.js b/models/memory/router.js new file mode 100644 index 0000000..0395013 --- /dev/null +++ b/models/memory/router.js @@ -0,0 +1,726 @@ +import express from 'express' +import fs from 'fs' +import os from 'os' +import path from 'path' +import https from 'https' +import { pipeline } from 'stream' +import { promisify } from 'util' +let AdmZip +try { + AdmZip = (await import('adm-zip')).default +} catch (e) { + logger.warn('Failed to load AdmZip, maybe you need to install it manually:', e) +} +import { execSync } from "child_process" +import { + Chaite, + ChaiteResponse, + FrontEndAuthHandler +} from 'chaite' +import ChatGPTConfig from '../../config/config.js' +import { memoryService } from './service.js' +import { + resetCachedDimension, + resetMemoryDatabaseInstance, + getSimpleExtensionState, + resolvePluginPath, + toPluginRelativePath, + resetVectorTableDimension +} from './database.js' + +const streamPipeline = promisify(pipeline) + +const SIMPLE_DOWNLOAD_BASE_URL = 'https://github.com/wangfenjin/simple/releases/latest/download' +const SIMPLE_ASSET_MAP = { + 'linux-x64': 'libsimple-linux-ubuntu-latest.zip', + 'linux-arm64': 'libsimple-linux-ubuntu-24.04-arm.zip', + 'linux-arm': 'libsimple-linux-ubuntu-24.04-arm.zip', + 'darwin-x64': 'libsimple-osx-x64.zip', + 'darwin-arm64': 'libsimple-osx-x64.zip', + 'win32-x64': 'libsimple-windows-x64.zip', + 'win32-ia32': 'libsimple-windows-x86.zip', + 'win32-arm64': 'libsimple-windows-arm64.zip' +} +const DEFAULT_SIMPLE_INSTALL_DIR = 'resources/simple' + +export function authenticateMemoryRequest (req, res, next) { + const bearer = req.header('Authorization') || '' + const token = bearer.replace(/^Bearer\s+/i, '').trim() + if (!token) { + res.status(401).json({ message: 'Access denied, token missing' }) + return + } + try { + const authKey = Chaite.getInstance()?.getGlobalConfig()?.getAuthKey() + if (authKey && FrontEndAuthHandler.validateJWT(authKey, token)) { + next() + return + } + res.status(401).json({ message: 'Invalid token' }) + } catch (error) { + res.status(401).json({ message: 'Invalid token format' }) + } +} + +function parsePositiveInt (value, fallback) { + const num = Number(value) + return Number.isInteger(num) && num >= 0 ? num : fallback +} + +function parseNumber (value, fallback) { + const num = Number(value) + return Number.isFinite(num) ? 
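The formatters in models/memory/prompt.js above fill ${...} placeholders from per-item context objects; a hypothetical memory.user template configuration using only the placeholder names built in formatUserMemories could look like this (where this configuration lives is not shown in the patch):

// Hypothetical template values for ChatGPTConfig.memory.user; the placeholders
// count, order, value and timeSuffix are context keys produced by formatUserMemories.
const userPromptTemplates = {
  promptHeader: '# 用户画像(共${count}条)',
  promptItemTemplate: '${order}. ${value}${timeSuffix}',
  promptFooter: '以上画像仅供参考,如与当前对话冲突,以用户最新说法为准。'
}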
num : fallback +} + +function toStringArray (value) { + if (!Array.isArray(value)) { + return [] + } + return value + .map(item => { + if (item === undefined || item === null) { + return null + } + return String(item).trim() + }) + .filter(item => item) +} + +function parseOptionalStringParam (value) { + if (Array.isArray(value)) { + value = value[0] + } + if (value === undefined || value === null) { + return null + } + const trimmed = String(value).trim() + if (!trimmed || trimmed.toLowerCase() === 'null' || trimmed.toLowerCase() === 'undefined') { + return null + } + return trimmed +} + +function detectAssetKey (platform, arch) { + const normalizedArch = arch === 'arm64' ? 'arm64' : (arch === 'arm' ? 'arm' : (arch === 'ia32' ? 'ia32' : 'x64')) + const key = `${platform}-${normalizedArch}` + if (SIMPLE_ASSET_MAP[key]) { + return key + } + if (platform === 'darwin' && SIMPLE_ASSET_MAP['darwin-x64']) { + return 'darwin-x64' + } + if (platform === 'linux' && SIMPLE_ASSET_MAP['linux-x64']) { + return 'linux-x64' + } + if (platform === 'win32' && SIMPLE_ASSET_MAP['win32-x64']) { + return 'win32-x64' + } + return null +} + +function resolveSimpleAsset (requestedKey, requestedAsset) { + if (requestedAsset) { + return { + key: requestedKey || 'custom', + asset: requestedAsset + } + } + if (requestedKey && SIMPLE_ASSET_MAP[requestedKey]) { + return { + key: requestedKey, + asset: SIMPLE_ASSET_MAP[requestedKey] + } + } + const autoKey = detectAssetKey(process.platform, process.arch) + if (autoKey && SIMPLE_ASSET_MAP[autoKey]) { + return { key: autoKey, asset: SIMPLE_ASSET_MAP[autoKey] } + } + return { key: null, asset: null } +} + +function ensureDirectoryExists (dirPath) { + if (!fs.existsSync(dirPath)) { + fs.mkdirSync(dirPath, { recursive: true }) + } +} + +async function downloadToFile (url, destination, redirectCount = 0) { + if (redirectCount > 5) { + throw new Error('Too many redirects while downloading extension') + } + await new Promise((resolve, reject) => { + const request = https.get(url, { + headers: { + 'User-Agent': 'chatgpt-plugin-memory-extension-downloader' + } + }, async res => { + if (res.statusCode && res.statusCode >= 300 && res.statusCode < 400 && res.headers.location) { + res.resume() + try { + await downloadToFile(res.headers.location, destination, redirectCount + 1) + resolve() + } catch (err) { + reject(err) + } + return + } + if (res.statusCode !== 200) { + reject(new Error(`Failed to download extension: HTTP ${res.statusCode}`)) + res.resume() + return + } + const fileStream = fs.createWriteStream(destination) + streamPipeline(res, fileStream).then(resolve).catch(reject) + }) + request.on('error', error => reject(error)) + }) +} + +function removeDirectoryIfExists (dirPath) { + if (fs.existsSync(dirPath)) { + fs.rmSync(dirPath, { recursive: true, force: true }) + } +} + +function findLibraryFile (rootDir) { + const entries = fs.readdirSync(rootDir, { withFileTypes: true }) + for (const entry of entries) { + const fullPath = path.join(rootDir, entry.name) + if (entry.isDirectory()) { + const found = findLibraryFile(fullPath) + if (found) { + return found + } + } else if (/simple\.(so|dylib|dll)$/i.test(entry.name) || /^libsimple/i.test(entry.name)) { + return fullPath + } + } + return null +} + +function findDictDirectory (rootDir) { + const directDictPath = path.join(rootDir, 'dict') + if (fs.existsSync(directDictPath) && fs.statSync(directDictPath).isDirectory()) { + return directDictPath + } + const entries = fs.readdirSync(rootDir, { withFileTypes: true }) + for 
(const entry of entries) { + if (entry.isDirectory()) { + const match = findDictDirectory(path.join(rootDir, entry.name)) + if (match) { + return match + } + } + } + return null +} + +async function downloadSimpleExtensionArchive ({ assetKey, assetName, targetDir }) { + if (!assetName) { + throw new Error('Simple extension asset name is required.') + } + const downloadUrl = `${SIMPLE_DOWNLOAD_BASE_URL}/${assetName}` + const tempFile = path.join(os.tmpdir(), `libsimple-${Date.now()}-${Math.random().toString(16).slice(2)}.zip`) + ensureDirectoryExists(path.dirname(tempFile)) + await downloadToFile(downloadUrl, tempFile) + removeDirectoryIfExists(targetDir) + ensureDirectoryExists(targetDir) + if (AdmZip) { + try { + const zip = new AdmZip(tempFile) + zip.extractAllTo(targetDir, true) + } finally { + if (fs.existsSync(tempFile)) { + fs.unlinkSync(tempFile) + } + } + } else { + // 尝试使用 unzip 命令解压 + try { + execSync(`unzip "${tempFile}" -d "${targetDir}"`, { stdio: 'inherit' }) + } catch (error) { + throw new Error(`Failed to extract zip file: ${error.message}. Please install adm-zip manually: pnpm i`) + } finally { + if (fs.existsSync(tempFile)) { + fs.unlinkSync(tempFile) + } + } + } + + const libraryFile = findLibraryFile(targetDir) + if (!libraryFile) { + throw new Error('Downloaded extension package does not contain libsimple library.') + } + const dictDir = findDictDirectory(targetDir) + if (!ChatGPTConfig.memory.extensions) { + ChatGPTConfig.memory.extensions = {} + } + if (!ChatGPTConfig.memory.extensions.simple) { + ChatGPTConfig.memory.extensions.simple = { + enable: false, + libraryPath: '', + dictPath: '', + useJieba: false + } + } + const relativeLibraryPath = toPluginRelativePath(libraryFile) + const relativeDictPath = dictDir ? toPluginRelativePath(dictDir) : '' + ChatGPTConfig.memory.extensions.simple.libraryPath = relativeLibraryPath + ChatGPTConfig.memory.extensions.simple.dictPath = relativeDictPath + return { + assetKey, + assetName, + installDir: toPluginRelativePath(targetDir), + libraryPath: relativeLibraryPath, + dictPath: ChatGPTConfig.memory.extensions.simple.dictPath + } +} + +function updateMemoryConfig (payload = {}) { + const current = ChatGPTConfig.memory || {} + const previousDatabase = current.database + const previousDimension = current.vectorDimensions + + const nextConfig = { + ...current, + group: { + ...(current.group || {}) + }, + user: { + ...(current.user || {}) + }, + extensions: { + ...(current.extensions || {}), + simple: { + ...(current.extensions?.simple || {}) + } + } + } + const previousSimpleConfig = JSON.stringify(current.extensions?.simple || {}) + + if (Object.prototype.hasOwnProperty.call(payload, 'database') && typeof payload.database === 'string') { + nextConfig.database = payload.database.trim() + } + if (Object.prototype.hasOwnProperty.call(payload, 'vectorDimensions')) { + const dimension = parsePositiveInt(payload.vectorDimensions, current.vectorDimensions || 1536) + if (dimension > 0) { + nextConfig.vectorDimensions = dimension + } + } + + if (payload.group && typeof payload.group === 'object') { + const incomingGroup = payload.group + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'enable')) { + nextConfig.group.enable = Boolean(incomingGroup.enable) + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'enabledGroups')) { + nextConfig.group.enabledGroups = toStringArray(incomingGroup.enabledGroups) + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'extractionModel') && typeof incomingGroup.extractionModel 
=== 'string') { + nextConfig.group.extractionModel = incomingGroup.extractionModel.trim() + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'extractionPresetId') && typeof incomingGroup.extractionPresetId === 'string') { + nextConfig.group.extractionPresetId = incomingGroup.extractionPresetId.trim() + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'minMessageCount')) { + nextConfig.group.minMessageCount = parsePositiveInt(incomingGroup.minMessageCount, nextConfig.group.minMessageCount || 0) + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'maxMessageWindow')) { + nextConfig.group.maxMessageWindow = parsePositiveInt(incomingGroup.maxMessageWindow, nextConfig.group.maxMessageWindow || 0) + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'retrievalMode')) { + const mode = String(incomingGroup.retrievalMode || '').toLowerCase() + if (['vector', 'keyword', 'hybrid'].includes(mode)) { + nextConfig.group.retrievalMode = mode + } + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'hybridPrefer')) { + const prefer = String(incomingGroup.hybridPrefer || '').toLowerCase() + if (prefer === 'keyword-first') { + nextConfig.group.hybridPrefer = 'keyword-first' + } else if (prefer === 'vector-first') { + nextConfig.group.hybridPrefer = 'vector-first' + } + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'historyPollInterval')) { + nextConfig.group.historyPollInterval = parsePositiveInt(incomingGroup.historyPollInterval, + nextConfig.group.historyPollInterval || 0) + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'historyBatchSize')) { + nextConfig.group.historyBatchSize = parsePositiveInt(incomingGroup.historyBatchSize, + nextConfig.group.historyBatchSize || 0) + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'promptHeader') && typeof incomingGroup.promptHeader === 'string') { + nextConfig.group.promptHeader = incomingGroup.promptHeader + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'promptItemTemplate') && typeof incomingGroup.promptItemTemplate === 'string') { + nextConfig.group.promptItemTemplate = incomingGroup.promptItemTemplate + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'promptFooter') && typeof incomingGroup.promptFooter === 'string') { + nextConfig.group.promptFooter = incomingGroup.promptFooter + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'vectorMaxDistance')) { + const distance = parseNumber(incomingGroup.vectorMaxDistance, + nextConfig.group.vectorMaxDistance ?? 0) + nextConfig.group.vectorMaxDistance = distance + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'textMaxBm25Score')) { + const bm25 = parseNumber(incomingGroup.textMaxBm25Score, + nextConfig.group.textMaxBm25Score ?? 0) + nextConfig.group.textMaxBm25Score = bm25 + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'maxFactsPerInjection')) { + nextConfig.group.maxFactsPerInjection = parsePositiveInt(incomingGroup.maxFactsPerInjection, + nextConfig.group.maxFactsPerInjection || 0) + } + if (Object.prototype.hasOwnProperty.call(incomingGroup, 'minImportanceForInjection')) { + const importance = parseNumber(incomingGroup.minImportanceForInjection, + nextConfig.group.minImportanceForInjection ?? 
0) + nextConfig.group.minImportanceForInjection = importance + } + } + + if (payload.user && typeof payload.user === 'object') { + const incomingUser = payload.user + if (Object.prototype.hasOwnProperty.call(incomingUser, 'enable')) { + nextConfig.user.enable = Boolean(incomingUser.enable) + } + if (Object.prototype.hasOwnProperty.call(incomingUser, 'whitelist')) { + nextConfig.user.whitelist = toStringArray(incomingUser.whitelist) + } + if (Object.prototype.hasOwnProperty.call(incomingUser, 'blacklist')) { + nextConfig.user.blacklist = toStringArray(incomingUser.blacklist) + } + if (Object.prototype.hasOwnProperty.call(incomingUser, 'extractionModel') && typeof incomingUser.extractionModel === 'string') { + nextConfig.user.extractionModel = incomingUser.extractionModel.trim() + } + if (Object.prototype.hasOwnProperty.call(incomingUser, 'extractionPresetId') && typeof incomingUser.extractionPresetId === 'string') { + nextConfig.user.extractionPresetId = incomingUser.extractionPresetId.trim() + } + if (Object.prototype.hasOwnProperty.call(incomingUser, 'maxItemsPerInjection')) { + nextConfig.user.maxItemsPerInjection = parsePositiveInt(incomingUser.maxItemsPerInjection, + nextConfig.user.maxItemsPerInjection || 0) + } + if (Object.prototype.hasOwnProperty.call(incomingUser, 'maxRelevantItemsPerQuery')) { + nextConfig.user.maxRelevantItemsPerQuery = parsePositiveInt(incomingUser.maxRelevantItemsPerQuery, + nextConfig.user.maxRelevantItemsPerQuery || 0) + } + if (Object.prototype.hasOwnProperty.call(incomingUser, 'minImportanceForInjection')) { + const importance = parseNumber(incomingUser.minImportanceForInjection, + nextConfig.user.minImportanceForInjection ?? 0) + nextConfig.user.minImportanceForInjection = importance + } + if (Object.prototype.hasOwnProperty.call(incomingUser, 'promptHeader') && typeof incomingUser.promptHeader === 'string') { + nextConfig.user.promptHeader = incomingUser.promptHeader + } + if (Object.prototype.hasOwnProperty.call(incomingUser, 'promptItemTemplate') && typeof incomingUser.promptItemTemplate === 'string') { + nextConfig.user.promptItemTemplate = incomingUser.promptItemTemplate + } + if (Object.prototype.hasOwnProperty.call(incomingUser, 'promptFooter') && typeof incomingUser.promptFooter === 'string') { + nextConfig.user.promptFooter = incomingUser.promptFooter + } + } + + if (payload.extensions && typeof payload.extensions === 'object' && !Array.isArray(payload.extensions)) { + const incomingExtensions = payload.extensions + if (incomingExtensions.simple && typeof incomingExtensions.simple === 'object' && !Array.isArray(incomingExtensions.simple)) { + const incomingSimple = incomingExtensions.simple + if (Object.prototype.hasOwnProperty.call(incomingSimple, 'enable')) { + nextConfig.extensions.simple.enable = Boolean(incomingSimple.enable) + } + if (Object.prototype.hasOwnProperty.call(incomingSimple, 'libraryPath') && typeof incomingSimple.libraryPath === 'string') { + nextConfig.extensions.simple.libraryPath = incomingSimple.libraryPath.trim() + } + if (Object.prototype.hasOwnProperty.call(incomingSimple, 'dictPath') && typeof incomingSimple.dictPath === 'string') { + nextConfig.extensions.simple.dictPath = incomingSimple.dictPath.trim() + } + if (Object.prototype.hasOwnProperty.call(incomingSimple, 'useJieba')) { + nextConfig.extensions.simple.useJieba = Boolean(incomingSimple.useJieba) + } + } else if (Object.prototype.hasOwnProperty.call(incomingExtensions, 'simple')) { + logger.warn('[Memory] Unexpected value for extensions.simple, ignoring:', 
incomingExtensions.simple) + } + } + + ChatGPTConfig.memory.database = nextConfig.database + ChatGPTConfig.memory.vectorDimensions = nextConfig.vectorDimensions + if (!ChatGPTConfig.memory.group) ChatGPTConfig.memory.group = {} + if (!ChatGPTConfig.memory.user) ChatGPTConfig.memory.user = {} + if (!ChatGPTConfig.memory.extensions) ChatGPTConfig.memory.extensions = {} + if (!ChatGPTConfig.memory.extensions.simple) { + ChatGPTConfig.memory.extensions.simple = { + enable: false, + libraryPath: '', + dictPath: '', + useJieba: false + } + } + Object.assign(ChatGPTConfig.memory.group, nextConfig.group) + Object.assign(ChatGPTConfig.memory.user, nextConfig.user) + Object.assign(ChatGPTConfig.memory.extensions.simple, nextConfig.extensions.simple) + + if (nextConfig.vectorDimensions !== previousDimension) { + resetCachedDimension() + const targetDimension = Number(nextConfig.vectorDimensions) + if (Number.isFinite(targetDimension) && targetDimension > 0) { + try { + resetVectorTableDimension(targetDimension) + } catch (err) { + logger?.error?.('[Memory] failed to apply vector dimension change:', err) + } + } + } + const currentSimpleConfig = JSON.stringify(ChatGPTConfig.memory.extensions?.simple || {}) + + if (nextConfig.database !== previousDatabase) { + resetMemoryDatabaseInstance() + } else if (currentSimpleConfig !== previousSimpleConfig) { + resetMemoryDatabaseInstance() + } + + if (typeof ChatGPTConfig._triggerSave === 'function') { + ChatGPTConfig._triggerSave('memory') + } + + return ChatGPTConfig.memory +} + +export const MemoryRouter = (() => { + const router = express.Router() + + router.get('/config', (_req, res) => { + res.status(200).json(ChaiteResponse.ok(ChatGPTConfig.memory)) + }) + + router.post('/config', (req, res) => { + try { + const updated = updateMemoryConfig(req.body || {}) + res.status(200).json(ChaiteResponse.ok(updated)) + } catch (error) { + logger.error('Failed to update memory config:', error) + res.status(500).json(ChaiteResponse.fail(null, 'Failed to update memory config')) + } + }) + + router.get('/group/:groupId/facts', (req, res) => { + const { groupId } = req.params + const limit = parsePositiveInt(req.query.limit, 50) + const offset = parsePositiveInt(req.query.offset, 0) + try { + const facts = memoryService.listGroupFacts(groupId, limit, offset) + res.status(200).json(ChaiteResponse.ok(facts)) + } catch (error) { + logger.error('Failed to fetch group facts:', error) + res.status(500).json(ChaiteResponse.fail(null, 'Failed to fetch group facts')) + } + }) + + router.get('/extensions/simple/status', (_req, res) => { + try { + logger?.debug?.('[Memory] simple extension status requested') + const state = getSimpleExtensionState() + const simpleConfig = ChatGPTConfig.memory?.extensions?.simple || {} + const libraryPath = simpleConfig.libraryPath || state.libraryPath || '' + const dictPath = simpleConfig.dictPath || state.dictPath || '' + const resolvedLibraryPath = libraryPath ? resolvePluginPath(libraryPath) : '' + const resolvedDictPath = dictPath ? resolvePluginPath(dictPath) : '' + res.status(200).json(ChaiteResponse.ok({ + ...state, + enabled: Boolean(simpleConfig.enable), + libraryPath, + dictPath, + platform: process.platform, + arch: process.arch, + resolvedLibraryPath, + libraryExists: resolvedLibraryPath ? fs.existsSync(resolvedLibraryPath) : false, + resolvedDictPath, + dictExists: resolvedDictPath ? 
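A hypothetical request body for the POST /config route above, limited to fields that updateMemoryConfig actually reads; the mount prefix, port and token are assumptions, since the patch does not show where MemoryRouter is attached.

// Hypothetical configuration update. Every key shown is inspected by
// updateMemoryConfig above; the URL, port and token are placeholders.
const token = '<JWT issued by the chaite frontend auth>'
const payload = {
  vectorDimensions: 1024,
  group: {
    enable: true,
    enabledGroups: ['123456789'],
    retrievalMode: 'hybrid',      // 'vector' | 'keyword' | 'hybrid'
    hybridPrefer: 'vector-first', // or 'keyword-first'
    maxFactsPerInjection: 5,
    minImportanceForInjection: 0.3
  },
  user: {
    enable: true,
    maxItemsPerInjection: 5
  },
  extensions: {
    simple: { enable: true, useJieba: false }
  }
}
// Assuming Node 18+ global fetch and the router mounted under /memory:
await fetch('http://127.0.0.1:3000/memory/config', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}` },
  body: JSON.stringify(payload)
})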
fs.existsSync(resolvedDictPath) : false + })) + } catch (error) { + logger.error('Failed to read simple extension status:', error) + res.status(500).json(ChaiteResponse.fail(null, 'Failed to read simple extension status')) + } + }) + + router.post('/extensions/simple/download', async (req, res) => { + const { assetKey, assetName, installDir } = req.body || {} + try { + const resolvedAsset = resolveSimpleAsset(assetKey, assetName) + if (!resolvedAsset.asset) { + res.status(400).json(ChaiteResponse.fail(null, '无法确定当前平台的扩展文件,请手动指定 assetName。')) + return + } + logger?.info?.('[Memory] downloading simple extension asset=%s (key=%s)', resolvedAsset.asset, resolvedAsset.key) + const targetRelativeDir = installDir || path.join(DEFAULT_SIMPLE_INSTALL_DIR, resolvedAsset.key || 'downloaded') + const targetDir = resolvePluginPath(targetRelativeDir) + const result = await downloadSimpleExtensionArchive({ + assetKey: resolvedAsset.key || assetKey || 'custom', + assetName: resolvedAsset.asset, + targetDir + }) + resetMemoryDatabaseInstance() + logger?.info?.('[Memory] simple extension downloaded and memory DB scheduled for reload') + res.status(200).json(ChaiteResponse.ok({ + ...result, + assetName: resolvedAsset.asset, + assetKey: resolvedAsset.key || assetKey || 'custom' + })) + } catch (error) { + logger.error('Failed to download simple extension:', error) + res.status(500).json(ChaiteResponse.fail(null, error?.message || 'Failed to download simple extension')) + } + }) + + router.post('/group/:groupId/facts', async (req, res) => { + const { groupId } = req.params + const facts = Array.isArray(req.body?.facts) ? req.body.facts : [] + if (facts.length === 0) { + res.status(400).json(ChaiteResponse.fail(null, 'facts is required')) + return + } + try { + const saved = await memoryService.saveGroupFacts(groupId, facts) + res.status(200).json(ChaiteResponse.ok(saved)) + } catch (error) { + logger.error('Failed to save group facts:', error) + res.status(500).json(ChaiteResponse.fail(null, 'Failed to save group facts')) + } + }) + + router.post('/group/:groupId/query', async (req, res) => { + const { groupId } = req.params + const { query, limit, minImportance } = req.body || {} + if (!query || typeof query !== 'string') { + res.status(400).json(ChaiteResponse.fail(null, 'query is required')) + return + } + try { + const facts = await memoryService.queryGroupFacts(groupId, query, { + limit: parsePositiveInt(limit, undefined), + minImportance: minImportance !== undefined ? 
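+ // Example request body (limit and minImportance are optional; when omitted,
+ // queryGroupFacts falls back to maxFactsPerInjection / minImportanceForInjection
+ // from ChatGPTConfig.memory.group):
+ //   { "query": "birthday plans", "limit": 5, "minImportance": 0.5 }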
parseNumber(minImportance, undefined) : undefined + }) + res.status(200).json(ChaiteResponse.ok(facts)) + } catch (error) { + logger.error('Failed to query group memory:', error) + res.status(500).json(ChaiteResponse.fail(null, 'Failed to query group memory')) + } + }) + + router.delete('/group/:groupId/facts/:factId', (req, res) => { + const { groupId, factId } = req.params + try { + const removed = memoryService.deleteGroupFact(groupId, factId) + if (!removed) { + res.status(404).json(ChaiteResponse.fail(null, 'Fact not found')) + return + } + res.status(200).json(ChaiteResponse.ok({ removed })) + } catch (error) { + logger.error('Failed to delete group fact:', error) + res.status(500).json(ChaiteResponse.fail(null, 'Failed to delete group fact')) + } + }) + + router.get('/user/memories', (req, res) => { + const userId = parseOptionalStringParam(req.query.userId) + const groupId = parseOptionalStringParam(req.query.groupId) + const limit = parsePositiveInt(req.query.limit, 50) + const offset = parsePositiveInt(req.query.offset, 0) + try { + const memories = memoryService.listUserMemories(userId, groupId, limit, offset) + res.status(200).json(ChaiteResponse.ok(memories)) + } catch (error) { + logger.error('Failed to fetch user memories:', error) + res.status(500).json(ChaiteResponse.fail(null, 'Failed to fetch user memories')) + } + }) + + router.get('/user/:userId/memories', (req, res) => { + const { userId } = req.params + const groupId = req.query.groupId ?? null + const limit = parsePositiveInt(req.query.limit, 50) + const offset = parsePositiveInt(req.query.offset, 0) + try { + const memories = memoryService.listUserMemories(userId, groupId, limit, offset) + res.status(200).json(ChaiteResponse.ok(memories)) + } catch (error) { + logger.error('Failed to fetch user memories:', error) + res.status(500).json(ChaiteResponse.fail(null, 'Failed to fetch user memories')) + } + }) + + router.post('/user/:userId/query', (req, res) => { + const { userId } = req.params + const groupId = req.body?.groupId ?? req.query.groupId ?? null + const query = req.body?.query + const totalLimit = parsePositiveInt(req.body?.totalLimit, undefined) + const searchLimit = parsePositiveInt(req.body?.searchLimit, undefined) + const minImportance = req.body?.minImportance !== undefined + ? parseNumber(req.body.minImportance, undefined) + : undefined + if (!query || typeof query !== 'string') { + res.status(400).json(ChaiteResponse.fail(null, 'query is required')) + return + } + try { + const memories = memoryService.queryUserMemories(userId, groupId, query, { + totalLimit, + searchLimit, + minImportance + }) + res.status(200).json(ChaiteResponse.ok(memories)) + } catch (error) { + logger.error('Failed to query user memory:', error) + res.status(500).json(ChaiteResponse.fail(null, 'Failed to query user memory')) + } + }) + + router.post('/user/:userId/memories', (req, res) => { + const { userId } = req.params + const groupId = req.body?.groupId ?? null + const memories = Array.isArray(req.body?.memories) ? 
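+ // Entries may be plain strings or objects; normalisePersonalMemory in the service layer
+ // accepts value / text / fact / sentence fields. Example body:
+ //   { "groupId": "123456", "memories": [
+ //       "prefers concise replies",
+ //       { "value": "birthday is 1 March", "importance": 0.8 } ] }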
req.body.memories : [] + if (memories.length === 0) { + res.status(400).json(ChaiteResponse.fail(null, 'memories is required')) + return + } + try { + const updated = memoryService.upsertUserMemories(userId, groupId, memories) + res.status(200).json(ChaiteResponse.ok({ updated })) + } catch (error) { + logger.error('Failed to upsert user memories:', error) + res.status(500).json(ChaiteResponse.fail(null, 'Failed to upsert user memories')) + } + }) + + router.delete('/user/:userId/memories/:memoryId', (req, res) => { + const { userId, memoryId } = req.params + try { + const removed = memoryService.deleteUserMemory(memoryId, userId) + if (!removed) { + res.status(404).json(ChaiteResponse.fail(null, 'Memory not found')) + return + } + res.status(200).json(ChaiteResponse.ok({ removed })) + } catch (error) { + logger.error('Failed to delete user memory:', error) + res.status(500).json(ChaiteResponse.fail(null, 'Failed to delete user memory')) + } + }) + + router.delete('/memories/:memoryId', (req, res) => { + const { memoryId } = req.params + try { + const removed = memoryService.deleteUserMemory(memoryId) + if (!removed) { + res.status(404).json(ChaiteResponse.fail(null, 'Memory not found')) + return + } + res.status(200).json(ChaiteResponse.ok({ removed })) + } catch (error) { + logger.error('Failed to delete memory:', error) + res.status(500).json(ChaiteResponse.fail(null, 'Failed to delete memory')) + } + }) + + return router +})() diff --git a/models/memory/service.js b/models/memory/service.js new file mode 100644 index 0000000..b9eb315 --- /dev/null +++ b/models/memory/service.js @@ -0,0 +1,194 @@ +import ChatGPTConfig from '../../config/config.js' +import { getMemoryDatabase } from './database.js' +import { GroupMemoryStore } from './groupMemoryStore.js' +import { UserMemoryStore } from './userMemoryStore.js' + +function normaliseId (id) { + if (id === null || id === undefined) { + return '' + } + return String(id) +} + +function formatEntry (entry) { + let str = '' + try { + str = JSON.stringify(entry) + } catch (err) { + str = String(entry) + } + const limit = 200 + return str.length > limit ? str.slice(0, limit) + '…' : str +} + +function normalisePersonalMemory (entry) { + if (!entry) return null + let text = '' + let importance = typeof entry?.importance === 'number' ? entry.importance : 0.6 + let sourceId = entry?.source_message_id ? 
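+ // Accepts strings or objects; illustrative mappings performed by this function:
+ //   'likes cats'                                 -> { text: 'likes cats', importance: 0.6, sourceId: null }
+ //   { fact: 'vegetarian', importance: 0.9 }      -> { text: 'vegetarian', importance: 0.9, sourceId: null }
+ //   { value: ['tea', 'coffee'], importance: -1 } -> { text: 'tea, coffee', importance: 0.6, sourceId: null }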
String(entry.source_message_id) : null + if (typeof entry === 'string') { + text = entry.trim() + } else if (typeof entry === 'object') { + const value = entry.value || entry.text || entry.fact || entry.sentence + if (Array.isArray(value)) { + text = value.join(', ').trim() + } else if (value) { + text = String(value).trim() + } + if (entry.importance !== undefined) { + importance = Number(entry.importance) + } + if (entry.source_message_id) { + sourceId = String(entry.source_message_id) + } + } + if (!text) { + return null + } + if (Number.isNaN(importance) || importance <= 0) { + importance = 0.6 + } + return { text, importance, sourceId } +} + +class MemoryService { + constructor () { + const db = getMemoryDatabase() + this.groupStore = new GroupMemoryStore(db) + this.userStore = new UserMemoryStore(db) + } + + isGroupMemoryEnabled (groupId) { + const config = ChatGPTConfig.memory?.group + if (!config?.enable) { + return false + } + const enabledGroups = (config.enabledGroups || []).map(normaliseId) + if (enabledGroups.length === 0) { + return false + } + return enabledGroups.includes(normaliseId(groupId)) + } + + isUserMemoryEnabled (userId) { + const config = ChatGPTConfig.memory?.user + if (!config?.enable) { + return false + } + const uid = normaliseId(userId) + const whitelist = (config.whitelist || []).map(normaliseId).filter(Boolean) + const blacklist = (config.blacklist || []).map(normaliseId).filter(Boolean) + if (whitelist.length > 0) { + return whitelist.includes(uid) + } + if (blacklist.length > 0) { + return !blacklist.includes(uid) + } + return true + } + + async saveGroupFacts (groupId, facts) { + if (!this.isGroupMemoryEnabled(groupId)) { + return [] + } + try { + const saved = await this.groupStore.saveFacts(groupId, facts) + if (saved.length > 0) { + logger.info(`[Memory] group=${groupId} stored ${saved.length} facts`) + saved.slice(0, 10).forEach((item, idx) => { + logger.debug(`[Memory] group stored fact[${idx}] ${formatEntry(item)}`) + }) + } + return saved + } catch (err) { + logger.error('Failed to save group facts:', err) + return [] + } + } + + async queryGroupFacts (groupId, queryText, options = {}) { + if (!this.isGroupMemoryEnabled(groupId)) { + return [] + } + const { maxFactsPerInjection = 5, minImportanceForInjection = 0 } = ChatGPTConfig.memory?.group || {} + const limit = options.limit || maxFactsPerInjection + const minImportance = options.minImportance ?? minImportanceForInjection + try { + return await this.groupStore.queryRelevantFacts(groupId, queryText, { limit, minImportance }) + } catch (err) { + logger.error('Failed to query group memory:', err) + return [] + } + } + + listGroupFacts (groupId, limit = 50, offset = 0) { + return this.groupStore.listFacts(groupId, limit, offset) + } + + deleteGroupFact (groupId, factId) { + return this.groupStore.deleteFact(groupId, factId) + } + + upsertUserMemories (userId, groupId, memories) { + if (!this.isUserMemoryEnabled(userId)) { + return 0 + } + try { + const prepared = (memories || []) + .map(normalisePersonalMemory) + .filter(item => item && item.text) + .map(item => ({ + value: item.text, + importance: item.importance, + source_message_id: item.sourceId + })) + if (prepared.length === 0) { + return 0 + } + const changed = this.userStore.upsertMemories(userId, groupId, prepared) + if (changed > 0) { + logger.info(`[Memory] user=${userId} updated ${changed} personal memories${groupId ? 
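+ // groupId is null for private-chat memories; such rows are stored without a group scope
+ // and are still returned for group-scoped queries via the
+ // `group_id = ? OR group_id IS NULL` branches in UserMemoryStore.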
` in group=${groupId}` : ''}`) + prepared.slice(0, 10).forEach((item, idx) => { + logger.debug(`[Memory] user memory upsert[${idx}] ${formatEntry(item)}`) + }) + } + return changed + } catch (err) { + logger.error('Failed to upsert user memories:', err) + return 0 + } + } + + queryUserMemories (userId, groupId = null, queryText = '', options = {}) { + if (!this.isUserMemoryEnabled(userId)) { + return [] + } + const userConfig = ChatGPTConfig.memory?.user || {} + const totalLimit = options.totalLimit ?? userConfig.maxItemsPerInjection ?? 5 + const searchLimit = options.searchLimit ?? userConfig.maxRelevantItemsPerQuery ?? totalLimit + const minImportance = options.minImportance ?? userConfig.minImportanceForInjection ?? 0 + if (!totalLimit || totalLimit <= 0) { + return [] + } + try { + return this.userStore.queryMemories(userId, groupId, queryText, { + limit: searchLimit, + fallbackLimit: totalLimit, + minImportance + }) + } catch (err) { + logger.error('Failed to query user memories:', err) + return [] + } + } + + listUserMemories (userId, groupId = null, limit = 50, offset = 0) { + return this.userStore.listUserMemories(userId, groupId, limit, offset) + } + + deleteUserMemory (memoryId, userId = null) { + return this.userStore.deleteMemoryById(memoryId, userId) + } +} + +export const memoryService = new MemoryService() diff --git a/models/memory/userMemoryManager.js b/models/memory/userMemoryManager.js new file mode 100644 index 0000000..18dd1a3 --- /dev/null +++ b/models/memory/userMemoryManager.js @@ -0,0 +1,129 @@ +import { Chaite } from 'chaite' +import * as crypto from 'node:crypto' +import { extractUserMemories } from './extractor.js' +import { memoryService } from './service.js' + +const USER_MEMORY_CONTEXT_LIMIT = 6 + +export function extractTextFromContents (contents) { + if (!Array.isArray(contents)) { + return '' + } + return contents + .filter(item => item && item.type === 'text') + .map(item => item.text || '') + .join('\n') + .trim() +} + +export function extractTextFromUserMessage (userMessage) { + if (!userMessage?.content) { + return '' + } + return userMessage.content + .filter(item => item.type === 'text') + .map(item => item.text || '') + .join('\n') + .trim() +} + +function normaliseMemoriesInput (memories, sourceId) { + return (memories || []).map(mem => { + if (typeof mem === 'string') { + return { + value: mem, + source_message_id: sourceId + } + } + if (mem && typeof mem === 'object') { + const cloned = { ...mem } + if (!cloned.source_message_id && sourceId) { + cloned.source_message_id = sourceId + } + if (!cloned.value && cloned.fact) { + cloned.value = cloned.fact + } + if (!cloned.value && cloned.text) { + cloned.value = cloned.text + } + return cloned + } + return { + value: String(mem), + source_message_id: sourceId + } + }) +} + +export async function processUserMemory ({ event, userMessage, userText, conversationId, assistantContents, assistantMessageId }) { + const e = event + if (!memoryService.isUserMemoryEnabled(e.sender.user_id)) { + return + } + const snippets = [] + const userMessageId = e.message_id || e.seq || userMessage?.id || crypto.randomUUID() + const senderName = e.sender?.card || e.sender?.nickname || String(e.sender?.user_id || '') + + try { + const historyManager = Chaite.getInstance()?.getHistoryManager?.() + if (historyManager && conversationId) { + const history = await historyManager.getHistory(null, conversationId) + const filtered = (history || []) + .filter(msg => ['user', 'assistant'].includes(msg.role)) + .map(msg => ({ + role: 
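+ // Each snippet handed to extractUserMemories has this shape (illustrative values):
+ //   { role: 'user', text: 'my birthday is next Tuesday', nickname: 'Alice', message_id: '123' }
+ // The history is trimmed to the last USER_MEMORY_CONTEXT_LIMIT * 2 snippets a few lines below.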
msg.role, + text: extractTextFromContents(msg.content), + nickname: msg.role === 'user' ? senderName : '机器人', + message_id: msg.id + })) + .filter(item => item.text) + if (filtered.length > 0) { + const limited = filtered.slice(-USER_MEMORY_CONTEXT_LIMIT * 2) + snippets.push(...limited) + } + } + } catch (err) { + logger.warn('Failed to collect user memory context:', err) + } + + if (assistantContents) { + const assistantText = extractTextFromContents(assistantContents) + if (assistantText) { + snippets.push({ + role: 'assistant', + text: assistantText, + nickname: '机器人', + message_id: assistantMessageId || crypto.randomUUID() + }) + } + } + + if (userText && !snippets.some(item => item.message_id === userMessageId)) { + snippets.push({ + role: 'user', + text: userText, + nickname: senderName, + message_id: userMessageId + }) + } + + if (snippets.length === 0) { + return + } + + const existingRecords = memoryService.listUserMemories(e.sender.user_id, e.isGroup ? e.group_id : null, 50) + const existingTexts = existingRecords.map(record => record.value).filter(Boolean) + const memories = await extractUserMemories(snippets, existingTexts) + if (!memories || memories.length === 0) { + return + } + + const enriched = normaliseMemoriesInput(memories, userMessageId) + memoryService.upsertUserMemories( + e.sender.user_id, + e.isGroup ? e.group_id : null, + enriched + ) +} + +export { USER_MEMORY_CONTEXT_LIMIT } diff --git a/models/memory/userMemoryStore.js b/models/memory/userMemoryStore.js new file mode 100644 index 0000000..9511bae --- /dev/null +++ b/models/memory/userMemoryStore.js @@ -0,0 +1,335 @@ +import { getMemoryDatabase, getUserMemoryFtsConfig, sanitiseFtsQueryInput } from './database.js' +import { md5 } from '../../utils/common.js' + +function normaliseId (value) { + if (value === null || value === undefined) { + return null + } + const str = String(value).trim() + if (!str || str.toLowerCase() === 'null' || str.toLowerCase() === 'undefined') { + return null + } + return str +} + +function toMemoryPayload (entry) { + if (entry === null || entry === undefined) { + return null + } + if (typeof entry === 'string') { + const text = entry.trim() + return text ? { value: text, importance: 0.5 } : null + } + if (typeof entry === 'object') { + const rawValue = entry.value ?? entry.text ?? entry.fact ?? '' + const value = typeof rawValue === 'string' ? rawValue.trim() : String(rawValue || '').trim() + if (!value) { + return null + } + const importance = typeof entry.importance === 'number' ? entry.importance : 0.5 + const sourceId = entry.source_message_id ? String(entry.source_message_id) : null + const providedKey = entry.key ? String(entry.key).trim() : '' + return { + value, + importance, + source_message_id: sourceId, + providedKey + } + } + const value = String(entry).trim() + return value ? 
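+ // Primitive entries (numbers, booleans) are stringified and kept with the default importance.
+ // Deduplication then goes through deriveKey below, e.g. (illustrative):
+ //   deriveKey('likes cats')          -> 'fact:' + md5('likes cats')
+ //   deriveKey('likes cats', 'hobby') -> 'hobby'   (an explicit key wins)
+ // The ON CONFLICT(user_id, coalesce(group_id, ''), key) target in upsertStmt assumes a matching
+ // unique index over that expression, expected to be created in database.js.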
{ value, importance: 0.5 } : null +} + +function deriveKey (value, providedKey = '') { + const trimmedProvided = providedKey?.trim?.() || '' + if (trimmedProvided) { + return trimmedProvided + } + if (!value) { + return null + } + return `fact:${md5(String(value))}` +} + +function stripKey (row) { + if (!row || typeof row !== 'object') { + return row + } + const { key, ...rest } = row + return rest +} + +function appendRows (target, rows, seen) { + if (!Array.isArray(rows)) { + return + } + for (const row of rows) { + if (!row || seen.has(row.id)) { + continue + } + target.push(stripKey(row)) + seen.add(row.id) + } +} + +export class UserMemoryStore { + constructor (db = getMemoryDatabase()) { + this.resetDatabase(db) + } + + resetDatabase (db = getMemoryDatabase()) { + this.db = db + this.upsertStmt = this.db.prepare(` + INSERT INTO user_memory (user_id, group_id, key, value, importance, source_message_id, created_at, updated_at) + VALUES (@user_id, @group_id, @key, @value, @importance, @source_message_id, datetime('now'), datetime('now')) + ON CONFLICT(user_id, coalesce(group_id, ''), key) DO UPDATE SET + value = excluded.value, + importance = excluded.importance, + source_message_id = excluded.source_message_id, + updated_at = datetime('now') + `) + } + + ensureDb () { + if (!this.db || this.db.open === false) { + logger?.debug?.('[Memory] refreshing user memory database connection') + this.resetDatabase() + } + return this.db + } + + upsertMemories (userId, groupId, memories) { + if (!memories || memories.length === 0) { + return 0 + } + this.ensureDb() + const normUserId = normaliseId(userId) + const normGroupId = normaliseId(groupId) + const prepared = (memories || []) + .map(toMemoryPayload) + .filter(item => item && item.value) + .map(item => { + const key = deriveKey(item.value, item.providedKey) + if (!key) { + return null + } + return { + user_id: normUserId, + group_id: normGroupId, + key, + value: String(item.value), + importance: typeof item.importance === 'number' ? item.importance : 0.5, + source_message_id: item.source_message_id ? String(item.source_message_id) : null + } + }) + .filter(Boolean) + if (!prepared.length) { + return 0 + } + const transaction = this.db.transaction(items => { + let changes = 0 + for (const item of items) { + const info = this.upsertStmt.run(item) + changes += info.changes + } + return changes + }) + return transaction(prepared) + } + + listUserMemories (userId = null, groupId = null, limit = 50, offset = 0) { + this.ensureDb() + const normUserId = normaliseId(userId) + const normGroupId = normaliseId(groupId) + const params = [] + let query = ` + SELECT * FROM user_memory + WHERE 1 = 1 + ` + if (normUserId) { + query += ' AND user_id = ?' + params.push(normUserId) + } + if (normGroupId) { + if (normUserId) { + query += ' AND (group_id = ? OR group_id IS NULL)' + } else { + query += ' AND group_id = ?' + } + params.push(normGroupId) + } + query += ` + ORDER BY importance DESC, updated_at DESC + LIMIT ? OFFSET ? + ` + params.push(limit, offset) + const rows = this.db.prepare(query).all(...params) + return rows.map(stripKey) + } + + deleteMemoryById (memoryId, userId = null) { + this.ensureDb() + if (userId) { + const result = this.db.prepare('DELETE FROM user_memory WHERE id = ? 
AND user_id = ?').run(memoryId, normaliseId(userId)) + return result.changes > 0 + } + const result = this.db.prepare('DELETE FROM user_memory WHERE id = ?').run(memoryId) + return result.changes > 0 + } + + listRecentMemories (userId, groupId = null, limit = 50, excludeIds = [], minImportance = 0) { + this.ensureDb() + const normUserId = normaliseId(userId) + const normGroupId = normaliseId(groupId) + const filteredExclude = (excludeIds || []).filter(Boolean) + const params = [normUserId] + let query = ` + SELECT * FROM user_memory + WHERE user_id = ? + AND importance >= ? + ` + params.push(minImportance) + if (normGroupId) { + query += ' AND (group_id = ? OR group_id IS NULL)' + params.push(normGroupId) + } + if (filteredExclude.length) { + query += ` AND id NOT IN (${filteredExclude.map(() => '?').join(',')})` + params.push(...filteredExclude) + } + query += ` + ORDER BY updated_at DESC + LIMIT ? + ` + params.push(limit) + return this.db.prepare(query).all(...params).map(stripKey) + } + + textSearch (userId, groupId = null, queryText, limit = 5, excludeIds = []) { + if (!queryText || !queryText.trim()) { + return [] + } + this.ensureDb() + const normUserId = normaliseId(userId) + const normGroupId = normaliseId(groupId) + const filteredExclude = (excludeIds || []).filter(Boolean) + const originalQuery = queryText.trim() + const ftsConfig = getUserMemoryFtsConfig() + const matchQueryParam = sanitiseFtsQueryInput(originalQuery, ftsConfig) + const results = [] + const seen = new Set(filteredExclude) + if (matchQueryParam) { + const matchExpression = ftsConfig.matchQuery ? `${ftsConfig.matchQuery}(?)` : '?' + const params = [normUserId, matchQueryParam] + let query = ` + SELECT um.*, bm25(user_memory_fts) AS bm25_score + FROM user_memory_fts + JOIN user_memory um ON um.id = user_memory_fts.rowid + WHERE um.user_id = ? + AND user_memory_fts MATCH ${matchExpression} + ` + if (normGroupId) { + query += ' AND (um.group_id = ? OR um.group_id IS NULL)' + params.push(normGroupId) + } + if (filteredExclude.length) { + query += ` AND um.id NOT IN (${filteredExclude.map(() => '?').join(',')})` + params.push(...filteredExclude) + } + query += ` + ORDER BY bm25_score ASC, um.updated_at DESC + LIMIT ? + ` + params.push(limit) + try { + const ftsRows = this.db.prepare(query).all(...params) + appendRows(results, ftsRows, seen) + } catch (err) { + logger?.warn?.('User memory text search failed:', err) + } + } else { + logger?.debug?.('[Memory] user memory text search skipped MATCH due to empty query after sanitisation') + } + + if (results.length < limit) { + const likeParams = [normUserId, originalQuery] + let likeQuery = ` + SELECT um.* + FROM user_memory um + WHERE um.user_id = ? + AND instr(um.value, ?) > 0 + ` + if (normGroupId) { + likeQuery += ' AND (um.group_id = ? OR um.group_id IS NULL)' + likeParams.push(normGroupId) + } + if (filteredExclude.length) { + likeQuery += ` AND um.id NOT IN (${filteredExclude.map(() => '?').join(',')})` + likeParams.push(...filteredExclude) + } + likeQuery += ` + ORDER BY um.importance DESC, um.updated_at DESC + LIMIT ? 
+ ` + likeParams.push(Math.max(limit * 2, limit)) + try { + const likeRows = this.db.prepare(likeQuery).all(...likeParams) + appendRows(results, likeRows, seen) + } catch (err) { + logger?.warn?.('User memory LIKE search failed:', err) + } + } + + return results.slice(0, limit) + } + + queryMemories (userId, groupId = null, queryText = '', options = {}) { + const normUserId = normaliseId(userId) + if (!normUserId) { + return [] + } + this.ensureDb() + const { + limit = 3, + fallbackLimit, + minImportance = 0 + } = options + const totalLimit = Math.max(0, fallbackLimit ?? limit ?? 0) + if (totalLimit === 0) { + return [] + } + const searchLimit = limit > 0 ? Math.min(limit, totalLimit) : totalLimit + const results = [] + const seen = new Set() + const append = rows => { + for (const row of rows || []) { + if (!row || seen.has(row.id)) { + continue + } + results.push(row) + seen.add(row.id) + if (results.length >= totalLimit) { + break + } + } + } + + if (queryText && searchLimit > 0) { + const searched = this.textSearch(userId, groupId, queryText, searchLimit) + append(searched) + } + + if (results.length < totalLimit) { + const recent = this.listRecentMemories( + userId, + groupId, + Math.max(totalLimit * 2, totalLimit), + Array.from(seen), + minImportance + ) + append(recent) + } + + return results.slice(0, totalLimit) + } +} diff --git a/package.json b/package.json index 62933ed..7a4e110 100644 --- a/package.json +++ b/package.json @@ -1,16 +1,21 @@ { "name": "chatgpt-plugin", - "version": "3.0.0-beta.1", + "version": "3.0.0", "type": "module", "author": "ikechan8370", "dependencies": { - "chaite": "^1.4.0", + "better-sqlite3": "^9.4.3", + "adm-zip": "^0.5.10", + "chaite": "^1.8.2", "js-yaml": "^4.1.0", "keyv": "^5.3.1", "keyv-file": "^5.1.2", "lowdb": "^7.0.1", - "vectra": "^0.9.0", - "sqlite3": "^5.1.6" + "sqlite-vec": "^0.1.7-alpha.2", + "vectra": "^0.9.0" + }, + "peerDependencies": { + "sqlite3": ">=5.1.6" }, "pnpm": {} } diff --git a/utils/group.js b/utils/group.js index 6815914..876e76c 100644 --- a/utils/group.js +++ b/utils/group.js @@ -75,6 +75,9 @@ export class TRSSGroupContextCollector extends GroupContextCollector { * @returns {Promise>} */ async collect (bot = Bot, groupId, start = 0, length = 20) { + if (!bot) { + return [] + } const group = bot.pickGroup(groupId) let chats = await group.getChatHistory(start, length) try {
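+ // With the `if (!bot) return []` guard above, callers that cannot resolve a Bot instance
+ // get an empty history instead of a TypeError from bot.pickGroup(groupId).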