fix: memory retrieval bug

ikechan8370 2025-11-05 15:32:29 +08:00
parent 0550e6d492
commit 9a5fe1d610
4 changed files with 170 additions and 81 deletions


@@ -114,6 +114,22 @@ function resetSimpleState (overrides = {}) {
   }
 }
 
+function sanitiseRawFtsInput (input) {
+  if (!input) {
+    return ''
+  }
+  const trimmed = String(input).trim()
+  if (!trimmed) {
+    return ''
+  }
+  const replaced = trimmed
+    .replace(/["'`]+/g, ' ')
+    .replace(/\u3000/g, ' ')
+    .replace(/[^\p{L}\p{N}\u4E00-\u9FFF\u3040-\u30FF\uAC00-\uD7AF\u1100-\u11FF\s]+/gu, ' ')
+  const collapsed = replaced.replace(/\s+/g, ' ').trim()
+  return collapsed || trimmed
+}
+
 function isSimpleLibraryFile (filename) {
   return /(^libsimple.*\.(so|dylib|dll)$)|(^simple\.(so|dylib|dll)$)/i.test(filename)
 }
@@ -644,6 +660,16 @@ export function getSimpleExtensionState () {
   return { ...simpleExtensionState }
 }
 
+export function sanitiseFtsQueryInput (query, ftsConfig) {
+  if (!query) {
+    return ''
+  }
+  if (ftsConfig?.matchQuery) {
+    return String(query).trim()
+  }
+  return sanitiseRawFtsInput(query)
+}
+
 export function getMemoryDatabase () {
   if (dbInstance) {
     return dbInstance
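
The two helpers above are the heart of the fix: quotes and punctuation inside an FTS5 MATCH argument are parsed as query syntax, so a raw chat message could make the whole retrieval statement error out. The query text is therefore scrubbed before it reaches MATCH, except when a dedicated match function is configured. A minimal sketch of the intended behaviour (the sample inputs and the simple_query match function are illustrative assumptions, not taken from this commit):

import { sanitiseFtsQueryInput } from './database.js'

// No match wrapper configured: quotes and punctuation collapse to spaces,
// while alphanumeric and CJK text is preserved.
sanitiseFtsQueryInput('What\'s "Alice" up to?', { matchQuery: null })
// -> 'What s Alice up to'

// A match wrapper such as libsimple's simple_query() is configured: the input
// is only trimmed, because the wrapper tokenises the text itself.
sanitiseFtsQueryInput('Alice 的爱好?', { matchQuery: 'simple_query' })
// -> 'Alice 的爱好?'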


@@ -1,4 +1,4 @@
-import { SendMessageOption } from 'chaite'
+import { SendMessageOption, Chaite } from 'chaite'
 import ChatGPTConfig from '../../config/config.js'
 import { getClientForModel } from '../chaite/vectorizer.js'
@@ -43,9 +43,40 @@ function formatEntry (entry) {
   return str.length > limit ? str.slice(0, limit) + '…' : str
 }
 
-function resolveGroupExtractionPrompts () {
+async function resolvePresetSendMessageOption (presetId, scope) {
+  if (!presetId) {
+    return null
+  }
+  try {
+    const chaite = Chaite.getInstance?.()
+    if (!chaite) {
+      logger.warn(`[Memory] ${scope} extraction preset ${presetId} configured but Chaite is not initialized`)
+      return null
+    }
+    const presetManager = chaite.getChatPresetManager?.()
+    if (!presetManager) {
+      logger.warn(`[Memory] ${scope} extraction preset ${presetId} configured but preset manager unavailable`)
+      return null
+    }
+    const preset = await presetManager.getInstance(presetId)
+    if (!preset) {
+      logger.warn(`[Memory] ${scope} extraction preset ${presetId} not found`)
+      return null
+    }
+    logger.debug(`[Memory] using ${scope} extraction preset ${presetId}`)
+    return {
+      preset,
+      sendMessageOption: JSON.parse(JSON.stringify(preset.sendMessageOption || {}))
+    }
+  } catch (err) {
+    logger.error(`[Memory] failed to load ${scope} extraction preset ${presetId}:`, err)
+    return null
+  }
+}
+
+function resolveGroupExtractionPrompts (presetSendMessageOption) {
   const config = ChatGPTConfig.memory?.group || {}
-  const system = config.extractionSystemPrompt || `You are a knowledge extraction assistant that specialises in summarising long-term facts from group chat transcripts.
+  const system = config.extractionSystemPrompt || presetSendMessageOption?.systemOverride || `You are a knowledge extraction assistant that specialises in summarising long-term facts from group chat transcripts.
 Read the provided conversation and identify statements that should be stored as long-term knowledge for the group.
 Return a JSON array. Each element must contain:
 {
@@ -79,9 +110,9 @@ function buildExistingMemorySection (existingMemories = []) {
   return `以下是关于用户的已知长期记忆,请在提取新记忆时参考,避免重复已有事实,并在信息变更时更新描述:\n${lines.join('\n')}`
 }
 
-function resolveUserExtractionPrompts (existingMemories = []) {
+function resolveUserExtractionPrompts (existingMemories = [], presetSendMessageOption) {
   const config = ChatGPTConfig.memory?.user || {}
-  const systemTemplate = config.extractionSystemPrompt || `You are an assistant that extracts long-term personal preferences or persona details about a user.
+  const systemTemplate = config.extractionSystemPrompt || presetSendMessageOption?.systemOverride || `You are an assistant that extracts long-term personal preferences or persona details about a user.
 Given a conversation snippet between the user and the bot, identify durable information such as preferences, nicknames, roles, speaking style, habits, or other facts that remain valid over time.
 Return a JSON array of **strings**, and nothing else, without any other characters including \`\`\` or \`\`\`json. Each string must be a short sentence (in the same language as the conversation) describing one piece of long-term memory. Do not include keys, JSON objects, or additional metadata. Ignore temporary topics or uncertain information.`
 const userTemplate = config.extractionUserPrompt || `下面是用户与机器人的对话,请根据系统提示提取可长期记忆的个人信息。
@@ -103,8 +134,16 @@ function buildUserPrompt (messages, template) {
   return template.replace('${messages}', body)
 }
 
-async function callModel ({ prompt, systemPrompt, model, maxToken = 4096, temperature = 0.2 }) {
-  const { client } = await getClientForModel(model)
+async function callModel ({ prompt, systemPrompt, model, maxToken = 4096, temperature = 0.2, sendMessageOption }) {
+  const options = sendMessageOption
+    ? JSON.parse(JSON.stringify(sendMessageOption))
+    : {}
+  options.model = model || options.model
+  if (!options.model) {
+    throw new Error('No model available for memory extraction call')
+  }
+  const resolvedModel = options.model
+  const { client } = await getClientForModel(resolvedModel)
   const response = await client.sendMessage({
     role: 'user',
     content: [
@@ -114,10 +153,11 @@ async function callModel ({ prompt, systemPrompt, model, maxToken = 4096, temper
       }
     ]
   }, SendMessageOption.create({
-    model,
-    // temperature,
-    maxToken,
-    systemOverride: systemPrompt,
+    ...options,
+    model: options.model,
+    temperature: options.temperature ?? temperature,
+    maxToken: options.maxToken ?? maxToken,
+    systemOverride: systemPrompt ?? options.systemOverride,
     disableHistoryRead: true,
     disableHistorySave: true,
     stream: false
@@ -125,44 +165,54 @@ async function callModel ({ prompt, systemPrompt, model, maxToken = 4096, temper
   return collectTextFromResponse(response)
 }
 
-function resolveGroupExtractionModel () {
+function resolveGroupExtractionModel (presetSendMessageOption) {
   const config = ChatGPTConfig.memory?.group
   if (config?.extractionModel) {
     return config.extractionModel
   }
+  if (presetSendMessageOption?.model) {
+    return presetSendMessageOption.model
+  }
   if (ChatGPTConfig.llm?.defaultModel) {
     return ChatGPTConfig.llm.defaultModel
   }
-  return ChatGPTConfig.llm?.embeddingModel || ''
+  return ''
 }
 
-function resolveUserExtractionModel () {
+function resolveUserExtractionModel (presetSendMessageOption) {
   const config = ChatGPTConfig.memory?.user
   if (config?.extractionModel) {
     return config.extractionModel
   }
+  if (presetSendMessageOption?.model) {
+    return presetSendMessageOption.model
+  }
   if (ChatGPTConfig.llm?.defaultModel) {
     return ChatGPTConfig.llm.defaultModel
  }
-  return ChatGPTConfig.llm?.embeddingModel || ''
+  return ''
 }
 
 export async function extractGroupFacts (messages) {
   if (!messages || messages.length === 0) {
     return []
   }
-  const model = resolveGroupExtractionModel()
+  const groupConfig = ChatGPTConfig.memory?.group || {}
+  const presetInfo = await resolvePresetSendMessageOption(groupConfig.extractionPresetId, 'group')
+  const presetOptions = presetInfo?.sendMessageOption
+  const model = resolveGroupExtractionModel(presetOptions)
   if (!model) {
     logger.warn('No model configured for group memory extraction')
     return []
   }
   try {
-    const prompts = resolveGroupExtractionPrompts()
-    logger.debug(`[Memory] start group fact extraction, messages=${messages.length}, model=${model}`)
+    const prompts = resolveGroupExtractionPrompts(presetOptions)
+    logger.debug(`[Memory] start group fact extraction, messages=${messages.length}, model=${model}${presetInfo?.preset ? `, preset=${presetInfo.preset.id}` : ''}`)
     const text = await callModel({
       prompt: buildGroupUserPrompt(messages, prompts.userTemplate),
       systemPrompt: prompts.system,
-      model
+      model,
+      sendMessageOption: presetOptions
     })
     const parsed = parseJSON(text)
     if (Array.isArray(parsed)) {
@@ -184,18 +234,22 @@ export async function extractUserMemories (messages, existingMemories = []) {
   if (!messages || messages.length === 0) {
     return []
   }
-  const model = resolveUserExtractionModel()
+  const userConfig = ChatGPTConfig.memory?.user || {}
+  const presetInfo = await resolvePresetSendMessageOption(userConfig.extractionPresetId, 'user')
+  const presetOptions = presetInfo?.sendMessageOption
+  const model = resolveUserExtractionModel(presetOptions)
   if (!model) {
     logger.warn('No model configured for user memory extraction')
     return []
   }
   try {
-    const prompts = resolveUserExtractionPrompts(existingMemories)
-    logger.debug(`[Memory] start user memory extraction, snippets=${messages.length}, existing=${existingMemories.length}, model=${model}`)
+    const prompts = resolveUserExtractionPrompts(existingMemories, presetOptions)
+    logger.debug(`[Memory] start user memory extraction, snippets=${messages.length}, existing=${existingMemories.length}, model=${model}${presetInfo?.preset ? `, preset=${presetInfo.preset.id}` : ''}`)
     const text = await callModel({
       prompt: buildUserPrompt(messages, prompts.userTemplate),
      systemPrompt: prompts.system,
-      model
+      model,
+      sendMessageOption: presetOptions
     })
     const parsed = parseJSON(text)
     if (Array.isArray(parsed)) {
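
Extraction can now be driven by a Chaite chat preset: memory.group.extractionPresetId / memory.user.extractionPresetId point at a preset whose sendMessageOption seeds the request, an explicitly configured extractionModel still takes precedence, and the old fallback to the embedding model is gone (presumably because an embedding model cannot answer a chat-completion prompt). A standalone sketch of the resulting precedence, mirroring what callModel builds above; the function name and sample values are invented for illustration:

// Precedence: explicit config model > preset model > global default. The preset's
// temperature/maxToken win over the call-site defaults, while an explicitly passed
// system prompt wins over the preset's systemOverride.
function resolveExtractionSettings ({ configModel, presetOption = {}, defaultModel, systemPrompt, temperature = 0.2, maxToken = 4096 }) {
  const model = configModel || presetOption.model || defaultModel || ''
  if (!model) throw new Error('No model available for memory extraction call')
  return {
    model,
    temperature: presetOption.temperature ?? temperature,
    maxToken: presetOption.maxToken ?? maxToken,
    systemOverride: systemPrompt ?? presetOption.systemOverride
  }
}

// Preset supplies model and temperature, config supplies nothing:
console.log(resolveExtractionSettings({
  presetOption: { model: 'preset-model', temperature: 0.7 },
  defaultModel: 'default-model',
  systemPrompt: 'extract long-term facts'
}))
// -> { model: 'preset-model', temperature: 0.7, maxToken: 4096, systemOverride: 'extract long-term facts' }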


@@ -1,4 +1,4 @@
-import { getMemoryDatabase, getVectorDimension, getGroupMemoryFtsConfig, resetVectorTableDimension } from './database.js'
+import { getMemoryDatabase, getVectorDimension, getGroupMemoryFtsConfig, resetVectorTableDimension, sanitiseFtsQueryInput } from './database.js'
 import ChatGPTConfig from '../../config/config.js'
 import { embedTexts } from '../chaite/vectorizer.js'
@@ -342,37 +342,42 @@ export class GroupMemoryStore {
     if (!queryText || !queryText.trim()) {
       return []
     }
-    const trimmedQuery = queryText.trim()
+    const originalQuery = queryText.trim()
     const ftsConfig = getGroupMemoryFtsConfig()
-    const matchExpression = ftsConfig.matchQuery ? `${ftsConfig.matchQuery}(?)` : '?'
+    const matchQueryParam = sanitiseFtsQueryInput(originalQuery, ftsConfig)
     const results = []
     const seen = new Set()
-    try {
-      const rows = this.db.prepare(`
-        SELECT gf.*, bm25(group_facts_fts) AS bm25_score
-        FROM group_facts_fts
-        JOIN group_facts gf ON gf.id = group_facts_fts.rowid
-        WHERE gf.group_id = ?
-          AND group_facts_fts MATCH ${matchExpression}
-        ORDER BY bm25_score ASC
-        LIMIT ?
-      `).all(groupId, trimmedQuery, limit)
-      for (const row of rows) {
-        const bm25Threshold = this.bm25Threshold
-        if (bm25Threshold) {
-          const score = Number(row?.bm25_score)
-          if (!Number.isFinite(score) || score > bm25Threshold) {
-            continue
-          }
-          row.bm25_score = score
-        }
-        if (row && !seen.has(row.id)) {
-          results.push(row)
-          seen.add(row.id)
-        }
-      }
-    } catch (err) {
-      logger.warn('Text search failed for group memory:', err)
+    if (matchQueryParam) {
+      const matchExpression = ftsConfig.matchQuery ? `${ftsConfig.matchQuery}(?)` : '?'
+      try {
+        const rows = this.db.prepare(`
+          SELECT gf.*, bm25(group_facts_fts) AS bm25_score
+          FROM group_facts_fts
+          JOIN group_facts gf ON gf.id = group_facts_fts.rowid
+          WHERE gf.group_id = ?
+            AND group_facts_fts MATCH ${matchExpression}
+          ORDER BY bm25_score ASC
+          LIMIT ?
+        `).all(groupId, matchQueryParam, limit)
+        for (const row of rows) {
+          const bm25Threshold = this.bm25Threshold
+          if (bm25Threshold) {
+            const score = Number(row?.bm25_score)
+            if (!Number.isFinite(score) || score > bm25Threshold) {
+              continue
+            }
+            row.bm25_score = score
+          }
+          if (row && !seen.has(row.id)) {
+            results.push(row)
+            seen.add(row.id)
+          }
+        }
+      } catch (err) {
+        logger.warn('Text search failed for group memory:', err)
+      }
+    } else {
+      logger.debug('[Memory] group memory text search skipped MATCH due to empty query after sanitisation')
     }
 
     if (results.length < limit) {
@@ -384,7 +389,7 @@ export class GroupMemoryStore {
           AND instr(fact, ?) > 0
         ORDER BY importance DESC, created_at DESC
         LIMIT ?
-      `).all(groupId, trimmedQuery, Math.max(limit * 2, limit))
+      `).all(groupId, originalQuery, Math.max(limit * 2, limit))
       for (const row of likeRows) {
         if (row && !seen.has(row.id)) {
           results.push(row)
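
Retrieval in the group store now keeps two versions of the query text: the sanitised string is bound to the FTS5 MATCH stage (and that stage is skipped entirely when sanitisation leaves nothing usable), while the original trimmed text is still bound verbatim to the instr() substring fallback, so phrases containing punctuation remain findable. A small illustration with made-up input; the relative import assumes the store sits next to database.js as in this commit:

import { sanitiseFtsQueryInput } from './database.js'

const originalQuery = '"Alice" 喜欢什么?'
const matchQueryParam = sanitiseFtsQueryInput(originalQuery, { matchQuery: null })
// matchQueryParam -> 'Alice 喜欢什么'      (bound to: AND group_facts_fts MATCH ?)
// originalQuery   -> '"Alice" 喜欢什么?'  (bound to: AND instr(fact, ?) > 0)
// If matchQueryParam were empty, the MATCH stage is skipped and only the
// instr() fallback runs.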


@@ -1,4 +1,4 @@
-import { getMemoryDatabase, getUserMemoryFtsConfig } from './database.js'
+import { getMemoryDatabase, getUserMemoryFtsConfig, sanitiseFtsQueryInput } from './database.js'
 import { md5 } from '../../utils/common.js'
 
 function normaliseId (value) {
@@ -213,42 +213,46 @@ export class UserMemoryStore {
     const normUserId = normaliseId(userId)
     const normGroupId = normaliseId(groupId)
     const filteredExclude = (excludeIds || []).filter(Boolean)
-    const trimmedQuery = queryText.trim()
+    const originalQuery = queryText.trim()
     const ftsConfig = getUserMemoryFtsConfig()
-    const matchExpression = ftsConfig.matchQuery ? `${ftsConfig.matchQuery}(?)` : '?'
-    const params = [normUserId]
-    let query = `
-      SELECT um.*, bm25(user_memory_fts) AS bm25_score
-      FROM user_memory_fts
-      JOIN user_memory um ON um.id = user_memory_fts.rowid
-      WHERE um.user_id = ?
-        AND user_memory_fts MATCH ${matchExpression}
-    `
-    params.push(trimmedQuery)
-    if (normGroupId) {
-      query += ' AND (um.group_id = ? OR um.group_id IS NULL)'
-      params.push(normGroupId)
-    }
-    if (filteredExclude.length) {
-      query += ` AND um.id NOT IN (${filteredExclude.map(() => '?').join(',')})`
-      params.push(...filteredExclude)
-    }
-    query += `
-      ORDER BY bm25_score ASC, um.updated_at DESC
-      LIMIT ?
-    `
-    params.push(limit)
+    const matchQueryParam = sanitiseFtsQueryInput(originalQuery, ftsConfig)
     const results = []
     const seen = new Set(filteredExclude)
-    try {
-      const ftsRows = this.db.prepare(query).all(...params)
-      appendRows(results, ftsRows, seen)
-    } catch (err) {
-      logger?.warn?.('User memory text search failed:', err)
+    if (matchQueryParam) {
+      const matchExpression = ftsConfig.matchQuery ? `${ftsConfig.matchQuery}(?)` : '?'
+      const params = [normUserId, matchQueryParam]
+      let query = `
+        SELECT um.*, bm25(user_memory_fts) AS bm25_score
+        FROM user_memory_fts
+        JOIN user_memory um ON um.id = user_memory_fts.rowid
+        WHERE um.user_id = ?
+          AND user_memory_fts MATCH ${matchExpression}
+      `
+      if (normGroupId) {
+        query += ' AND (um.group_id = ? OR um.group_id IS NULL)'
+        params.push(normGroupId)
+      }
+      if (filteredExclude.length) {
+        query += ` AND um.id NOT IN (${filteredExclude.map(() => '?').join(',')})`
+        params.push(...filteredExclude)
+      }
+      query += `
+        ORDER BY bm25_score ASC, um.updated_at DESC
+        LIMIT ?
+      `
+      params.push(limit)
+      try {
+        const ftsRows = this.db.prepare(query).all(...params)
+        appendRows(results, ftsRows, seen)
+      } catch (err) {
+        logger?.warn?.('User memory text search failed:', err)
+      }
+    } else {
+      logger?.debug?.('[Memory] user memory text search skipped MATCH due to empty query after sanitisation')
     }
 
     if (results.length < limit) {
-      const likeParams = [normUserId, trimmedQuery]
+      const likeParams = [normUserId, originalQuery]
       let likeQuery = `
         SELECT um.*
         FROM user_memory um