INFO · info-20251219-003
Memory System Technical Design (Gemini Embedding + pgvector)
- Date: 2024-12-19
- Type: Design proposal
- Source: Technical design
- Confidence: 8/10
- Tags: #Memory #Embedding #pgvector #Gemini #TechnicalDesign
Overview
A complete technical design for a Memory system built on Gemini Embedding + PostgreSQL pgvector.
1. Data Model Design
1.1 Memory Type Definitions
| Type | Description | When generated | Lifecycle |
|---|---|---|---|
| raw | Raw conversation turns | End of each turn | Kept N days / rolling cleanup |
| summary | Conversation summary | End of conversation / periodic aggregation | Kept long-term |
| facts | Key facts | Extracted by LLM | Kept long-term, can be updated |
| preferences | User preferences | Identified by LLM | Kept long-term, can be overwritten |
1.2 Field Design per Type
raw (raw conversation turn)
{
  "session_id": "session ID",
  "role": "user/assistant",
  "content": "original content",
  "turn_index": 0,
  "intent": "recognized intent",
  "timestamp": "2024-12-19T10:00:00Z"
}
summary (conversation summary)
{
  "session_id": "session ID",
  "summary": "The user asked about their schedule and confirmed tomorrow's 3pm meeting",
  "key_topics": ["schedule", "meeting"],
  "sentiment": "neutral",
  "action_items": ["remind about tomorrow's meeting"]
}
facts (key facts)
{
  "fact": "The user's boss is Mr. Zhang",
  "category": "relationship",
  "confidence": 0.9,
  "source_session": "session ID",
  "first_seen": "2024-12-19",
  "last_confirmed": "2024-12-19",
  "times_mentioned": 3
}
preferences (user preferences)
{
  "preference": "Prefers a concise reply style",
  "category": "communication_style",
  "strength": 0.8,
  "evidence": ["repeatedly asked for short answers", "skipped detailed explanations"]
}
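For reference, a minimal sketch of how these metadata payloads could be typed on the application side (assuming Python's typing.TypedDict; the class names below simply mirror the JSON examples and are otherwise hypothetical):

from typing import List, TypedDict

class RawMetadata(TypedDict):
    role: str          # "user" or "assistant"
    turn_index: int
    intent: str        # recognized intent, may be empty

class SummaryMetadata(TypedDict):
    key_topics: List[str]
    sentiment: str     # "positive" / "neutral" / "negative"
    action_items: List[str]

class FactMetadata(TypedDict):
    category: str      # personal_info / relationship / work / ...
    confidence: float  # 0-1
    source_session: str

class PreferenceMetadata(TypedDict):
    category: str      # communication_style / time / tool / ...
    strength: float    # 0-1
    evidence: List[str]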
2. Database Design
2.1 Table Structure
-- Enable extensions
CREATE EXTENSION IF NOT EXISTS vector;
CREATE EXTENSION IF NOT EXISTS pg_trgm; -- for fuzzy text search
-- Main table: memories
CREATE TABLE memories (
    id BIGSERIAL PRIMARY KEY,
    user_id VARCHAR(64) NOT NULL,
    type VARCHAR(16) NOT NULL CHECK (type IN ('raw', 'summary', 'facts', 'preferences')),
    -- Content
    content TEXT NOT NULL,
    content_hash VARCHAR(64),          -- used for deduplication
    -- Vector
    embedding vector(768),
    -- Metadata
    metadata JSONB NOT NULL DEFAULT '{}',
    -- Relations
    session_id VARCHAR(64),
    source_memory_id BIGINT REFERENCES memories(id),
    -- Weights and status
    importance FLOAT DEFAULT 0.5,      -- importance score, 0-1
    access_count INT DEFAULT 0,        -- number of times accessed
    is_active BOOLEAN DEFAULT true,    -- whether the row is still valid
    -- Timestamps
    created_at TIMESTAMPTZ DEFAULT NOW(),
    updated_at TIMESTAMPTZ DEFAULT NOW(),
    expires_at TIMESTAMPTZ,            -- expiry time (used by the raw type)
    last_accessed_at TIMESTAMPTZ
);
-- Indexes
CREATE INDEX idx_memories_user_type ON memories(user_id, type);
CREATE INDEX idx_memories_user_active ON memories(user_id, is_active) WHERE is_active = true;
CREATE INDEX idx_memories_session ON memories(session_id) WHERE session_id IS NOT NULL;
CREATE INDEX idx_memories_expires ON memories(expires_at) WHERE expires_at IS NOT NULL;
CREATE INDEX idx_memories_content_hash ON memories(user_id, content_hash);
CREATE INDEX idx_memories_metadata ON memories USING gin(metadata);
-- Vector index (HNSW, better query performance)
CREATE INDEX idx_memories_embedding ON memories
    USING hnsw (embedding vector_cosine_ops)
    WITH (m = 16, ef_construction = 64);
-- updated_at trigger
CREATE OR REPLACE FUNCTION update_updated_at()
RETURNS TRIGGER AS $$
BEGIN
    NEW.updated_at = NOW();
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER memories_updated_at
    BEFORE UPDATE ON memories
    FOR EACH ROW
    EXECUTE FUNCTION update_updated_at();
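As a quick sanity check of the schema and HNSW index, a minimal query sketch (assuming psycopg2 and an already-populated table; the DSN and user id are placeholders, and hnsw.ef_search is pgvector's session setting that trades recall for speed):

import psycopg2

conn = psycopg2.connect("dbname=memory user=postgres")        # placeholder DSN
query_vec = [0.0] * 768                                        # in practice: a Gemini query embedding
vec_literal = "[" + ",".join(str(x) for x in query_vec) + "]"  # pgvector text format

with conn.cursor() as cur:
    cur.execute("SET hnsw.ef_search = 100;")  # higher = better recall, slower queries
    cur.execute(
        """
        SELECT id, type, content, 1 - (embedding <=> %s::vector) AS similarity
        FROM memories
        WHERE user_id = %s AND is_active = true
        ORDER BY embedding <=> %s::vector
        LIMIT 5
        """,
        (vec_literal, "user-123", vec_literal),
    )
    for row in cur.fetchall():
        print(row)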
2.2 Partitioning Strategy (optional, for large data volumes)
-- Partition by user (if the user base is large).
-- Note: on a partitioned table every unique constraint must include the partition key,
-- so copy the structure without indexes and add a composite primary key instead.
CREATE TABLE memories_partitioned (
    LIKE memories INCLUDING DEFAULTS INCLUDING CONSTRAINTS INCLUDING STORAGE
) PARTITION BY HASH (user_id);
ALTER TABLE memories_partitioned ADD PRIMARY KEY (id, user_id);
-- Create the partitions
CREATE TABLE memories_p0 PARTITION OF memories_partitioned FOR VALUES WITH (MODULUS 4, REMAINDER 0);
CREATE TABLE memories_p1 PARTITION OF memories_partitioned FOR VALUES WITH (MODULUS 4, REMAINDER 1);
CREATE TABLE memories_p2 PARTITION OF memories_partitioned FOR VALUES WITH (MODULUS 4, REMAINDER 2);
CREATE TABLE memories_p3 PARTITION OF memories_partitioned FOR VALUES WITH (MODULUS 4, REMAINDER 3);
-- Recreate the regular and HNSW indexes on the partitioned parent; they propagate to each partition.
3. Embedding Service
3.1 Wrapper Class
import google.generativeai as genai
from typing import List, Optional
from tenacity import retry, stop_after_attempt, wait_exponential
import hashlib
import asyncio
from dataclasses import dataclass
from enum import Enum

class TaskType(Enum):
    RETRIEVAL_DOCUMENT = "retrieval_document"   # used when storing
    RETRIEVAL_QUERY = "retrieval_query"         # used when querying
    SEMANTIC_SIMILARITY = "semantic_similarity"
    CLASSIFICATION = "classification"

@dataclass
class EmbeddingResult:
    text: str
    embedding: List[float]
    model: str
    dimensions: int

class GeminiEmbeddingService:
    def __init__(self, api_key: str, model: str = "models/text-embedding-004"):
        genai.configure(api_key=api_key)
        self.model = model
        self.dimensions = 768
        self._cache = {}  # simple in-memory cache

    def _get_cache_key(self, text: str, task_type: TaskType) -> str:
        return hashlib.md5(f"{text}:{task_type.value}".encode()).hexdigest()

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10))
    def embed(self, text: str, task_type: TaskType = TaskType.RETRIEVAL_DOCUMENT) -> EmbeddingResult:
        """Embed a single text."""
        cache_key = self._get_cache_key(text, task_type)
        if cache_key in self._cache:
            return self._cache[cache_key]
        # Preprocess the text
        text = self._preprocess(text)
        result = genai.embed_content(
            model=self.model,
            content=text,
            task_type=task_type.value
        )
        embedding_result = EmbeddingResult(
            text=text,
            embedding=result['embedding'],
            model=self.model,
            dimensions=len(result['embedding'])
        )
        self._cache[cache_key] = embedding_result
        return embedding_result

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10))
    def embed_batch(self, texts: List[str], task_type: TaskType = TaskType.RETRIEVAL_DOCUMENT) -> List[EmbeddingResult]:
        """Batch embedding (Gemini accepts a list of contents)."""
        texts = [self._preprocess(t) for t in texts]
        result = genai.embed_content(
            model=self.model,
            content=texts,
            task_type=task_type.value
        )
        return [
            EmbeddingResult(
                text=text,
                embedding=emb,
                model=self.model,
                dimensions=len(emb)
            )
            for text, emb in zip(texts, result['embedding'])
        ]

    def _preprocess(self, text: str) -> str:
        """Text preprocessing."""
        # Collapse extra whitespace
        text = ' '.join(text.split())
        # Truncate overly long text (Gemini input limit)
        max_chars = 10000
        if len(text) > max_chars:
            text = text[:max_chars]
        return text

    async def embed_async(self, text: str, task_type: TaskType = TaskType.RETRIEVAL_DOCUMENT) -> EmbeddingResult:
        """Async wrapper around the blocking embed call."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.embed, text, task_type)
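A minimal usage sketch (the GEMINI_API_KEY environment variable name is an assumption; adapt to your key management):

import os

service = GeminiEmbeddingService(api_key=os.environ["GEMINI_API_KEY"])

doc = service.embed("The user's boss is Mr. Zhang", TaskType.RETRIEVAL_DOCUMENT)
query = service.embed("Who is the user's boss?", TaskType.RETRIEVAL_QUERY)
print(doc.dimensions, len(query.embedding))  # expected: 768 768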
3.2 Embedding Strategies
class MemoryEmbeddingStrategy:
    """Embedding strategy per Memory type."""

    def __init__(self, embedding_service: GeminiEmbeddingService):
        self.service = embedding_service

    def embed_raw(self, content: str, metadata: dict) -> List[float]:
        """Raw turn: include role and intent."""
        role = metadata.get('role', 'unknown')
        intent = metadata.get('intent', '')
        enriched = f"[{role}] {content}"
        if intent:
            enriched += f" (intent: {intent})"
        return self.service.embed(enriched, TaskType.RETRIEVAL_DOCUMENT).embedding

    def embed_summary(self, content: str, metadata: dict) -> List[float]:
        """Summary: emphasize the topics."""
        topics = metadata.get('key_topics', [])
        enriched = content
        if topics:
            enriched = f"Topics: {', '.join(topics)}. {content}"
        return self.service.embed(enriched, TaskType.RETRIEVAL_DOCUMENT).embedding

    def embed_fact(self, content: str, metadata: dict) -> List[float]:
        """Fact: include the category."""
        category = metadata.get('category', '')
        enriched = content
        if category:
            enriched = f"[{category}] {content}"
        return self.service.embed(enriched, TaskType.RETRIEVAL_DOCUMENT).embedding

    def embed_preference(self, content: str, metadata: dict) -> List[float]:
        """Preference: mark it explicitly as a preference."""
        enriched = f"User preference: {content}"
        return self.service.embed(enriched, TaskType.RETRIEVAL_DOCUMENT).embedding

    def embed_query(self, query: str) -> List[float]:
        """Query vector: use retrieval_query."""
        return self.service.embed(query, TaskType.RETRIEVAL_QUERY).embedding
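For illustration, a short usage sketch (the example strings are placeholders) showing that stored documents and queries are embedded with different task types:

strategy = MemoryEmbeddingStrategy(service)

fact_vec = strategy.embed_fact(
    "The user's boss is Mr. Zhang", {"category": "relationship"}
)
query_vec = strategy.embed_query("Who is the user's boss?")
# fact_vec uses RETRIEVAL_DOCUMENT, query_vec uses RETRIEVAL_QUERY;
# the asymmetric task types are what Gemini recommends for retrieval use cases.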
4. Memory Storage Service
import json
import hashlib
from datetime import datetime, timedelta, timezone
from typing import List, Optional, Dict, Any
from dataclasses import dataclass
from enum import Enum
import psycopg2
from psycopg2.extras import RealDictCursor

class MemoryType(Enum):
    RAW = "raw"
    SUMMARY = "summary"
    FACTS = "facts"
    PREFERENCES = "preferences"

@dataclass
class Memory:
    id: Optional[int]
    user_id: str
    type: MemoryType
    content: str
    embedding: Optional[List[float]]
    metadata: Dict[str, Any]
    importance: float = 0.5
    session_id: Optional[str] = None
    created_at: Optional[datetime] = None

@dataclass
class RecallResult:
    memory: Memory
    similarity: float

class MemoryStore:
    def __init__(self, db_connection, embedding_service: GeminiEmbeddingService):
        self.conn = db_connection
        self.embedding_service = embedding_service
        self.strategy = MemoryEmbeddingStrategy(embedding_service)
        # Per-type expiry configuration
        self.ttl_config = {
            MemoryType.RAW: timedelta(days=30),
            MemoryType.SUMMARY: None,   # never expires
            MemoryType.FACTS: None,
            MemoryType.PREFERENCES: None,
        }
    def _get_content_hash(self, content: str) -> str:
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    def _to_vector_literal(self, embedding: List[float]) -> str:
        # pgvector text format, e.g. '[0.1,0.2,...]'; alternatively call
        # pgvector.psycopg2.register_vector(conn) and pass the list directly.
        return '[' + ','.join(str(x) for x in embedding) + ']'

    def _get_embedding(self, memory_type: MemoryType, content: str, metadata: dict) -> List[float]:
        """Pick the embedding strategy based on the memory type."""
        if memory_type == MemoryType.RAW:
            return self.strategy.embed_raw(content, metadata)
        elif memory_type == MemoryType.SUMMARY:
            return self.strategy.embed_summary(content, metadata)
        elif memory_type == MemoryType.FACTS:
            return self.strategy.embed_fact(content, metadata)
        elif memory_type == MemoryType.PREFERENCES:
            return self.strategy.embed_preference(content, metadata)

    def save(self, memory: Memory, dedupe: bool = True) -> int:
        """Save a single Memory."""
        content_hash = self._get_content_hash(memory.content)
        # Deduplication check
        if dedupe:
            existing = self._find_duplicate(memory.user_id, memory.type, content_hash)
            if existing:
                # Bump the access count and timestamp instead of inserting again
                self._update_access(existing['id'])
                return existing['id']
        # Generate the embedding
        embedding = self._get_embedding(memory.type, memory.content, memory.metadata)
        # Compute the expiry time
        expires_at = None
        ttl = self.ttl_config.get(memory.type)
        if ttl:
            expires_at = datetime.now(timezone.utc) + ttl
        with self.conn.cursor() as cur:
            cur.execute("""
                INSERT INTO memories
                    (user_id, type, content, content_hash, embedding, metadata,
                     importance, session_id, expires_at)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                RETURNING id
            """, (
                memory.user_id,
                memory.type.value,
                memory.content,
                content_hash,
                self._to_vector_literal(embedding),
                json.dumps(memory.metadata),
                memory.importance,
                memory.session_id,
                expires_at
            ))
            memory_id = cur.fetchone()[0]
        self.conn.commit()
        return memory_id
    def save_batch(self, memories: List[Memory]) -> List[int]:
        """Save a batch of memories."""
        # Generate all embeddings in one batch call
        # (note: this path embeds the plain content, without the per-type enrichment used in save()).
        texts = [m.content for m in memories]
        embeddings = self.embedding_service.embed_batch(texts)
        ids = []
        with self.conn.cursor() as cur:
            for memory, emb_result in zip(memories, embeddings):
                content_hash = self._get_content_hash(memory.content)
                expires_at = None
                ttl = self.ttl_config.get(memory.type)
                if ttl:
                    expires_at = datetime.now(timezone.utc) + ttl
                cur.execute("""
                    INSERT INTO memories
                        (user_id, type, content, content_hash, embedding, metadata,
                         importance, session_id, expires_at)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                    RETURNING id
                """, (
                    memory.user_id,
                    memory.type.value,
                    memory.content,
                    content_hash,
                    self._to_vector_literal(emb_result.embedding),
                    json.dumps(memory.metadata),
                    memory.importance,
                    memory.session_id,
                    expires_at
                ))
                ids.append(cur.fetchone()[0])
        self.conn.commit()
        return ids

    def _find_duplicate(self, user_id: str, memory_type: MemoryType, content_hash: str) -> Optional[dict]:
        with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute("""
                SELECT id, content FROM memories
                WHERE user_id = %s AND type = %s AND content_hash = %s AND is_active = true
                LIMIT 1
            """, (user_id, memory_type.value, content_hash))
            return cur.fetchone()

    def _update_access(self, memory_id: int):
        with self.conn.cursor() as cur:
            cur.execute("""
                UPDATE memories
                SET access_count = access_count + 1, last_accessed_at = NOW()
                WHERE id = %s
            """, (memory_id,))
        self.conn.commit()
    def recall(
        self,
        user_id: str,
        query: str,
        top_k: int = 10,
        types: Optional[List[MemoryType]] = None,
        min_similarity: float = 0.5,
        time_decay: bool = True
    ) -> List[RecallResult]:
        """Recall memories relevant to the query."""
        query_embedding = self.strategy.embed_query(query)
        # Optional type filter (parameterized)
        type_filter = ""
        params = [self._to_vector_literal(query_embedding), user_id]
        if types:
            type_filter = "AND type = ANY(%s)"
            params.append([t.value for t in types])
        params.extend([min_similarity, top_k * 2])
        # Base query: fetch extra candidates, then rerank in Python
        sql = f"""
            WITH ranked AS (
                SELECT
                    id, user_id, type, content, metadata, importance,
                    session_id, created_at, access_count,
                    1 - (embedding <=> %s::vector) AS similarity
                FROM memories
                WHERE user_id = %s
                  AND is_active = true
                  AND (expires_at IS NULL OR expires_at > NOW())
                  {type_filter}
            )
            SELECT * FROM ranked
            WHERE similarity >= %s
            ORDER BY similarity DESC
            LIMIT %s
        """
        with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(sql, tuple(params))
            rows = cur.fetchall()
        # Rerank (time decay and importance weighting)
        results = []
        now = datetime.now(timezone.utc)
        for row in rows:
            memory = Memory(
                id=row['id'],
                user_id=row['user_id'],
                type=MemoryType(row['type']),
                content=row['content'],
                embedding=None,
                metadata=row['metadata'],
                importance=row['importance'],
                session_id=row['session_id'],
                created_at=row['created_at']
            )
            # Compute the final score
            score = row['similarity']
            if time_decay:
                # Time decay: newer memories score higher
                age_days = (now - row['created_at']).days
                decay = 1.0 / (1.0 + age_days * 0.01)
                score *= decay
            # Importance weighting
            score *= (0.5 + row['importance'] * 0.5)
            # Small boost for frequently accessed memories
            score *= (1.0 + min(row['access_count'], 10) * 0.02)
            results.append(RecallResult(memory=memory, similarity=score))
        # Sort by final score
        results.sort(key=lambda x: x.similarity, reverse=True)
        # Record the access
        for r in results[:top_k]:
            self._update_access(r.memory.id)
        return results[:top_k]
    def recall_by_type(
        self,
        user_id: str,
        memory_type: MemoryType,
        limit: int = 20
    ) -> List[Memory]:
        """Fetch the most recent memories of one type (no vector search)."""
        with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute("""
                SELECT id, user_id, type, content, metadata, importance, session_id, created_at
                FROM memories
                WHERE user_id = %s AND type = %s AND is_active = true
                ORDER BY created_at DESC
                LIMIT %s
            """, (user_id, memory_type.value, limit))
            rows = cur.fetchall()
        return [
            Memory(
                id=row['id'],
                user_id=row['user_id'],
                type=MemoryType(row['type']),
                content=row['content'],
                embedding=None,
                metadata=row['metadata'],
                importance=row['importance'],
                session_id=row['session_id'],
                created_at=row['created_at']
            )
            for row in rows
        ]
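A minimal end-to-end usage sketch of the store (the DSN, env var, and user id are placeholders; error handling omitted):

import os
import psycopg2

conn = psycopg2.connect("dbname=memory user=postgres")  # placeholder DSN
service = GeminiEmbeddingService(api_key=os.environ["GEMINI_API_KEY"])
store = MemoryStore(conn, service)

fact = Memory(
    id=None,
    user_id="user-123",
    type=MemoryType.FACTS,
    content="The user's boss is Mr. Zhang",
    embedding=None,
    metadata={"category": "relationship", "confidence": 0.9},
    importance=0.9,
)
fact_id = store.save(fact)

for r in store.recall("user-123", "Who is the user's boss?", top_k=5):
    print(f"{r.similarity:.3f}  [{r.memory.type.value}] {r.memory.content}")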
5. Memory Extraction Service (LLM)
from typing import List
import json

class MemoryExtractor:
    """Extract memories from a conversation using an LLM."""

    def __init__(self, llm_client):
        self.llm = llm_client

    def extract_facts(self, conversation: List[dict]) -> List[dict]:
        """Extract key facts from the conversation."""
        prompt = f"""Extract key facts about the user from the conversation below.
Only extract facts that are stated explicitly; do not speculate.

Conversation:
{json.dumps(conversation, ensure_ascii=False, indent=2)}

Return a JSON array where each fact contains:
- fact: the fact statement
- category: one of personal_info/relationship/work/preference/habit/other
- confidence: confidence score, 0-1

Return only the JSON, nothing else."""
        response = self.llm.generate(prompt)
        return self._parse_json_array(response)

    def extract_preferences(self, conversation: List[dict]) -> List[dict]:
        """Identify user preferences from the conversation."""
        prompt = f"""Identify the user's preferences and habits from the conversation below.
Pay attention to: communication style, time preferences, tool preferences, etc.

Conversation:
{json.dumps(conversation, ensure_ascii=False, indent=2)}

Return a JSON array where each preference contains:
- preference: the preference statement
- category: one of communication_style/time/tool/content/other
- strength: strength score, 0-1
- evidence: supporting evidence

Return only the JSON, nothing else."""
        response = self.llm.generate(prompt)
        return self._parse_json_array(response)

    def generate_summary(self, conversation: List[dict]) -> dict:
        """Generate a concise summary of the conversation."""
        prompt = f"""Write a concise summary of the conversation below.

Conversation:
{json.dumps(conversation, ensure_ascii=False, indent=2)}

Return a JSON object with:
- summary: the summary (2-3 sentences)
- key_topics: list of topics
- sentiment: positive/neutral/negative
- action_items: list of action items (if any)

Return only the JSON, nothing else."""
        response = self.llm.generate(prompt)
        return self._parse_json(response)

    def _parse_json_array(self, text: str) -> list:
        try:
            start = text.find('[')
            end = text.rfind(']') + 1
            if start >= 0 and end > start:
                return json.loads(text[start:end])
        except json.JSONDecodeError:
            pass
        return []

    def _parse_json(self, text: str) -> dict:
        try:
            start = text.find('{')
            end = text.rfind('}') + 1
            if start >= 0 and end > start:
                return json.loads(text[start:end])
        except json.JSONDecodeError:
            pass
        return {}
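The llm_client interface above (a single generate(prompt) -> str method) is assumed rather than defined; a minimal adapter sketch backed by the Gemini API could look like this (the model name is an assumption):

import google.generativeai as genai

class GeminiLLMClient:
    """Minimal adapter exposing the generate() interface MemoryExtractor expects."""

    def __init__(self, api_key: str, model: str = "gemini-1.5-flash"):
        genai.configure(api_key=api_key)
        self._model = genai.GenerativeModel(model)

    def generate(self, prompt: str) -> str:
        # Return the plain text of the first candidate response.
        return self._model.generate_content(prompt).text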
6. Pipeline Integration
class MemoryPipeline:
    """Memory processing pipeline."""

    def __init__(self, memory_store: MemoryStore, extractor: MemoryExtractor):
        self.store = memory_store
        self.extractor = extractor

    def process_conversation(self, user_id: str, session_id: str, conversation: List[dict]):
        """Process a completed conversation."""
        # 1. Save the raw turns
        for i, turn in enumerate(conversation):
            raw_memory = Memory(
                id=None,
                user_id=user_id,
                type=MemoryType.RAW,
                content=turn['content'],
                embedding=None,
                metadata={
                    'role': turn['role'],
                    'turn_index': i,
                    'intent': turn.get('intent', '')
                },
                session_id=session_id
            )
            self.store.save(raw_memory)
        # 2. Extract and save facts
        facts = self.extractor.extract_facts(conversation)
        for fact_data in facts:
            fact_memory = Memory(
                id=None,
                user_id=user_id,
                type=MemoryType.FACTS,
                content=fact_data.get('fact', ''),
                embedding=None,
                metadata={
                    'category': fact_data.get('category', 'other'),
                    'confidence': fact_data.get('confidence', 0.5),
                    'source_session': session_id
                },
                importance=fact_data.get('confidence', 0.5),
                session_id=session_id
            )
            self.store.save(fact_memory, dedupe=True)
        # 3. Identify and save preferences
        preferences = self.extractor.extract_preferences(conversation)
        for pref_data in preferences:
            pref_memory = Memory(
                id=None,
                user_id=user_id,
                type=MemoryType.PREFERENCES,
                content=pref_data.get('preference', ''),
                embedding=None,
                metadata={
                    'category': pref_data.get('category', 'other'),
                    'strength': pref_data.get('strength', 0.5),
                    'evidence': pref_data.get('evidence', [])
                },
                importance=pref_data.get('strength', 0.5),
                session_id=session_id
            )
            self.store.save(pref_memory, dedupe=True)
        # 4. Generate and save the summary
        summary_data = self.extractor.generate_summary(conversation)
        if summary_data.get('summary'):
            summary_memory = Memory(
                id=None,
                user_id=user_id,
                type=MemoryType.SUMMARY,
                content=summary_data['summary'],
                embedding=None,
                metadata={
                    'key_topics': summary_data.get('key_topics', []),
                    'sentiment': summary_data.get('sentiment', 'neutral'),
                    'action_items': summary_data.get('action_items', [])
                },
                importance=0.7,
                session_id=session_id
            )
            self.store.save(summary_memory)

    def get_context_for_conversation(self, user_id: str, current_query: str) -> str:
        """Assemble relevant context for the current conversation."""
        # 1. Recall relevant memories
        relevant = self.store.recall(
            user_id=user_id,
            query=current_query,
            top_k=10,
            min_similarity=0.4
        )
        # 2. Fetch the most recent preferences
        preferences = self.store.recall_by_type(user_id, MemoryType.PREFERENCES, limit=5)
        # 3. Assemble the context
        context_parts = []
        if preferences:
            prefs_text = "\n".join([f"- {p.content}" for p in preferences])
            context_parts.append(f"User preferences:\n{prefs_text}")
        if relevant:
            # Group by type
            facts = [r for r in relevant if r.memory.type == MemoryType.FACTS]
            summaries = [r for r in relevant if r.memory.type == MemoryType.SUMMARY]
            if facts:
                facts_text = "\n".join([f"- {r.memory.content}" for r in facts[:5]])
                context_parts.append(f"Relevant facts:\n{facts_text}")
            if summaries:
                sum_text = "\n".join([f"- {r.memory.content}" for r in summaries[:3]])
                context_parts.append(f"Conversation summaries:\n{sum_text}")
        return "\n\n".join(context_parts)
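A minimal wiring sketch for the pipeline (the DSN, env var, ids, and GeminiLLMClient are placeholders carried over from the earlier sketches; the sample conversation is illustrative):

import os
import psycopg2

conn = psycopg2.connect("dbname=memory user=postgres")  # placeholder DSN
service = GeminiEmbeddingService(api_key=os.environ["GEMINI_API_KEY"])
store = MemoryStore(conn, service)
extractor = MemoryExtractor(GeminiLLMClient(os.environ["GEMINI_API_KEY"]))
pipeline = MemoryPipeline(store, extractor)

conversation = [
    {"role": "user", "content": "Remind me about tomorrow's 3pm meeting with Mr. Zhang", "intent": "schedule"},
    {"role": "assistant", "content": "Done. I'll remind you before the 3pm meeting tomorrow."},
]
pipeline.process_conversation("user-123", "session-001", conversation)

print(pipeline.get_context_for_conversation("user-123", "When is my meeting with my boss?"))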
7. Maintenance Tasks
class MemoryMaintenance:
    """Memory maintenance jobs."""

    def __init__(self, conn):
        self.conn = conn

    def cleanup_expired(self) -> int:
        """Deactivate expired memories."""
        with self.conn.cursor() as cur:
            cur.execute("""
                UPDATE memories SET is_active = false
                WHERE expires_at IS NOT NULL AND expires_at < NOW() AND is_active = true
            """)
            count = cur.rowcount
        self.conn.commit()
        return count

    def merge_duplicate_facts(self, user_id: str):
        """Merge semantically duplicate facts."""
        # Idea: find near-duplicates via vector similarity, keep the earliest one,
        # and merge times_mentioned (see the sketch below this class).
        pass

    def update_importance_scores(self, user_id: str):
        """Raise importance scores based on access frequency."""
        with self.conn.cursor() as cur:
            cur.execute("""
                UPDATE memories
                SET importance = LEAST(1.0, importance + access_count * 0.01)
                WHERE user_id = %s AND access_count > 0
            """, (user_id,))
        self.conn.commit()
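A possible implementation sketch for merge_duplicate_facts, assuming the schema above (the 0.9 cosine-similarity threshold and the standalone-function form are assumptions):

def merge_duplicate_facts(conn, user_id: str, similarity_threshold: float = 0.9) -> int:
    """Deactivate newer facts that are near-duplicates of an older fact.

    Keeps the earliest fact, accumulates times_mentioned into its metadata,
    and returns how many facts were merged away.
    """
    merged = 0
    with conn.cursor() as cur:
        # Self-join facts of one user; a.id < b.id keeps the older row (a).
        cur.execute("""
            SELECT a.id AS keep_id, b.id AS drop_id
            FROM memories a
            JOIN memories b
              ON a.user_id = b.user_id
             AND a.type = 'facts' AND b.type = 'facts'
             AND a.id < b.id
             AND a.is_active AND b.is_active
             AND 1 - (a.embedding <=> b.embedding) >= %s
            WHERE a.user_id = %s
        """, (similarity_threshold, user_id))
        for keep_id, drop_id in cur.fetchall():
            cur.execute("""
                UPDATE memories
                SET metadata = jsonb_set(
                        metadata,
                        '{times_mentioned}',
                        to_jsonb(COALESCE((metadata->>'times_mentioned')::int, 1) + 1)
                    )
                WHERE id = %s
            """, (keep_id,))
            cur.execute("UPDATE memories SET is_active = false WHERE id = %s", (drop_id,))
            merged += 1
    conn.commit()
    return merged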
Demo Acceptance Criteria
| Stage | Test item | Pass criteria |
|---|---|---|
| Basic | Embedding generation | 768-dim vector, latency < 500 ms |
| Basic | Single save | Write succeeds, deduplication works |
| Basic | Vector retrieval | Relevant content has similarity > 0.7 |
| Advanced | Batch save | 10 items in < 2 s |
| Advanced | Mixed recall | Type filters applied correctly |
| Advanced | Reranking | Time decay and importance weighting take effect |
| Integration | Pipeline | Conversation → raw/summary/facts/preferences |
| Integration | Context assembly | Returned format is correct |
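A smoke-test sketch for the "Basic" rows above (pytest, the test DSN, the env var, and the user id are assumptions; thresholds mirror the table and earlier sketches):

import os
import time
import psycopg2
import pytest

@pytest.fixture(scope="module")
def store():
    conn = psycopg2.connect("dbname=memory_test user=postgres")  # placeholder test DSN
    service = GeminiEmbeddingService(api_key=os.environ["GEMINI_API_KEY"])
    return MemoryStore(conn, service)

def test_embedding_dimensions_and_latency(store):
    start = time.monotonic()
    result = store.embedding_service.embed("hello world")
    assert result.dimensions == 768
    assert time.monotonic() - start < 0.5  # < 500 ms

def test_save_dedupe_and_recall(store):
    mem = Memory(
        id=None, user_id="test-user", type=MemoryType.FACTS,
        content="The user's boss is Mr. Zhang", embedding=None,
        metadata={"category": "relationship"},
    )
    first_id = store.save(mem)
    second_id = store.save(mem)  # same content -> deduplicated
    assert first_id == second_id
    results = store.recall("test-user", "Who is the user's boss?", top_k=3)
    assert results and results[0].memory.content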
Related
- Trigger rules: -
- To verify: validate the feasibility of this design during actual development