Files
awoooi/apps/api/scripts/reembed_bge_m3.py
Your Name b4055c5915
Some checks failed
Code Review / ai-code-review (push) Successful in 57s
run-migration / migrate (push) Failing after 44s
feat(embedding): ADR-110 升級 bge-m3:latest 1024 維向量
GCP-A (34.143.170.20) 無 nomic-embed-text,改用 bge-m3:latest(專用
多語言 embedding 模型),產生 1024 維向量。

變更:
- embedding_service.py: 加入 bge-m3:latest=1024 維到 MODEL_DIMENSIONS,
  預設模型改為 bge-m3:latest,更新文件說明
- playbook_embedding_repository.py + interfaces.py: 更新維度說明
- migrations/embedding_bge_m3_1024.sql: pgvector schema 遷移
  rag_chunks + playbook_embeddings vector(768) → vector(1024)
- scripts/reembed_bge_m3.py: 遷移後重新嵌入現有資料的 script

遷移步驟:
  1. 執行 embedding_bge_m3_1024.sql(清空現有 768 維向量,變更維度)
  2. 執行 python scripts/reembed_bge_m3.py 重新嵌入

2026-05-04 ogt + Claude Sonnet 4.6

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-04 11:18:20 +08:00

188 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Re-embed Script: bge-m3:latest 1024 維重新嵌入
===============================================
遷移 embedding_bge_m3_1024.sql 後執行,重新嵌入:
1. rag_chunks(embedding IS NULL 的筆數)
2. playbook_embeddings(embedding IS NULL 的筆數)
用法:
cd apps/api
python scripts/reembed_bge_m3.py [--dry-run] [--batch 50]
前置條件:
1. embedding_bge_m3_1024.sql 已執行(schema 已升為 vector(1024))
2. GCP-A Ollama (34.143.170.20:11434) 可連線且有 bge-m3:latest
3. DATABASE_URL 環境變數已設定(或 .env 存在)
2026-05-04 ogt + Claude Sonnet 4.6: ADR-110 GCP-A Primary Embedding 升級
"""
from __future__ import annotations
import argparse
import asyncio
import os
import sys
from pathlib import Path
# Put the directory one level above this script's folder on sys.path so that
# project modules can be imported (original note: make sure `src` is importable).
sys.path.insert(0, str(Path(__file__).parent.parent))
import asyncpg
import httpx
import structlog
# Structured logger for this script.
# NOTE(review): the name `logging` is the same as the stdlib module name; the
# stdlib module is not imported in the visible part of this file.
logging = structlog.get_logger(__name__)
# Base URL of the Ollama server; can be overridden with the OLLAMA_URL env var.
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://34.143.170.20:11434")
# Model name sent to Ollama with every embedding request.
EMBEDDING_MODEL = "bge-m3:latest"
# Vector length that embed_text() requires from the model's response.
EXPECTED_DIM = 1024
async def embed_text(client: httpx.AsyncClient, text: str) -> list[float]:
    """Request an embedding for one piece of text from the Ollama server.

    Sends ``text`` as the ``prompt`` of a POST to ``{OLLAMA_URL}/api/embeddings``
    using ``EMBEDDING_MODEL`` and a 60 second timeout.

    Args:
        client: HTTP client used to send the request.
        text: The text to embed.

    Returns:
        The value of the ``embedding`` field of the JSON response.

    Raises:
        ValueError: If the returned vector (or the empty default used when the
            field is missing) does not have ``EXPECTED_DIM`` entries.
        Exception: Whatever ``raise_for_status()`` raises for an error HTTP
            status is propagated unchanged.
    """
    endpoint = f"{OLLAMA_URL}/api/embeddings"
    payload = {"model": EMBEDDING_MODEL, "prompt": text}
    response = await client.post(endpoint, json=payload, timeout=60.0)
    response.raise_for_status()

    body = response.json()
    vector = body.get("embedding", [])
    actual_dim = len(vector)
    # Guard against a wrong model/schema pairing before anything is written.
    if actual_dim != EXPECTED_DIM:
        raise ValueError(f"bge-m3 維度錯誤: got {actual_dim}, expected {EXPECTED_DIM}")
    return vector
async def reembed_rag_chunks(
    conn: asyncpg.Connection,
    client: httpx.AsyncClient,
    batch_size: int,
    dry_run: bool,
) -> int:
    """Re-embed every ``rag_chunks`` row whose ``embedding`` is NULL.

    Rows are read in pages of ``batch_size * 10`` ordered by ``id`` until no
    page is left. Previously only the first page was processed, so larger
    tables were left partially NULL. Paging is keyset-based (``id`` greater
    than the last id already seen) so that rows which fail to embed, and all
    rows in dry-run mode (where nothing is written), are never fetched twice.

    Args:
        conn: Open asyncpg connection used for the SELECT and UPDATE statements.
        client: HTTP client handed to ``embed_text``.
        batch_size: Page size factor; each SELECT asks for ``batch_size * 10`` rows.
        dry_run: When true, embeddings are still requested but no UPDATE is issued.

    Returns:
        Number of rows for which embedding (and, unless ``dry_run``, the
        UPDATE) completed without raising. Per-row failures are logged and
        skipped rather than propagated.
    """
    page_limit = batch_size * 10
    first_page_sql = (
        "SELECT id, content FROM rag_chunks "
        "WHERE embedding IS NULL ORDER BY id LIMIT $1"
    )
    next_page_sql = (
        "SELECT id, content FROM rag_chunks "
        "WHERE embedding IS NULL AND id > $2 ORDER BY id LIMIT $1"
    )

    done = 0
    seen = 0
    last_id = None
    while True:
        # The first page has no lower bound; later pages resume after last_id.
        if last_id is None:
            rows = await conn.fetch(first_page_sql, page_limit)
        else:
            rows = await conn.fetch(next_page_sql, page_limit, last_id)
        if not rows:
            break

        seen += len(rows)
        last_id = rows[-1]["id"]
        for row in rows:
            try:
                vec = await embed_text(client, row["content"])
                if not dry_run:
                    # The vector is sent as its text form and cast to `vector` in SQL.
                    vec_str = "[" + ",".join(f"{v:.8f}" for v in vec) + "]"
                    await conn.execute(
                        "UPDATE rag_chunks SET embedding = $1::vector WHERE id = $2",
                        vec_str, row["id"],
                    )
                done += 1
                if done % 10 == 0:
                    # `total` is the number of rows fetched so far, not the table total.
                    logging.info("rag_chunks_progress", done=done, total=seen)
            except Exception as e:
                # Best-effort: one bad row must not abort the whole migration.
                logging.error("rag_chunk_embed_failed", id=row["id"], error=str(e))

        # A short page means the table is exhausted; skip the extra empty query.
        if len(rows) < page_limit:
            break

    if seen == 0:
        logging.info("rag_chunks_all_embedded")
    return done
async def reembed_playbook_embeddings(
    conn: asyncpg.Connection,
    client: httpx.AsyncClient,
    batch_size: int,
    dry_run: bool,
) -> int:
    """Re-embed every ``playbook_embeddings`` row whose ``embedding`` is NULL.

    The text to embed is rebuilt from the joined ``playbooks`` row: title,
    description and steps, joined with newlines (empty parts are dropped).

    Rows are read in pages of ``batch_size * 10`` ordered by ``playbook_id``
    until no page is left. Previously only the first page was processed.
    Paging is keyset-based (``playbook_id`` greater than the last one seen) so
    that failed rows, and all rows in dry-run mode, are never fetched twice.

    Args:
        conn: Open asyncpg connection used for the SELECT and UPDATE statements.
        client: HTTP client handed to ``embed_text``.
        batch_size: Page size factor; each SELECT asks for ``batch_size * 10`` rows.
        dry_run: When true, embeddings are still requested but no UPDATE is issued.

    Returns:
        Number of rows for which embedding (and, unless ``dry_run``, the
        UPDATE) completed without raising. Per-row failures are logged and
        skipped rather than propagated.
    """
    page_limit = batch_size * 10
    # playbook_embeddings is joined to playbooks to recover the source content.
    first_page_sql = """
        SELECT pe.playbook_id, p.title, p.description, p.steps
        FROM playbook_embeddings pe
        JOIN playbooks p ON pe.playbook_id = p.id
        WHERE pe.embedding IS NULL
        ORDER BY pe.playbook_id
        LIMIT $1
    """
    next_page_sql = """
        SELECT pe.playbook_id, p.title, p.description, p.steps
        FROM playbook_embeddings pe
        JOIN playbooks p ON pe.playbook_id = p.id
        WHERE pe.embedding IS NULL AND pe.playbook_id > $2
        ORDER BY pe.playbook_id
        LIMIT $1
    """

    done = 0
    seen = 0
    last_id = None
    while True:
        # The first page has no lower bound; later pages resume after last_id.
        if last_id is None:
            rows = await conn.fetch(first_page_sql, page_limit)
        else:
            rows = await conn.fetch(next_page_sql, page_limit, last_id)
        if not rows:
            break

        seen += len(rows)
        last_id = rows[-1]["playbook_id"]
        for row in rows:
            text_parts = [row["title"] or "", row["description"] or ""]
            if row["steps"]:
                # steps may arrive as a list or as some other value; both are stringified.
                if isinstance(row["steps"], list):
                    text_parts.extend(str(s) for s in row["steps"])
                else:
                    text_parts.append(str(row["steps"]))
            text = "\n".join(p for p in text_parts if p)
            try:
                vec = await embed_text(client, text)
                if not dry_run:
                    # The vector is sent as its text form and cast to `vector` in SQL.
                    vec_str = "[" + ",".join(f"{v:.8f}" for v in vec) + "]"
                    await conn.execute(
                        "UPDATE playbook_embeddings SET embedding = $1::vector WHERE playbook_id = $2",
                        vec_str, row["playbook_id"],
                    )
                done += 1
                if done % 10 == 0:
                    # `total` is the number of rows fetched so far, not the table total.
                    logging.info("playbook_embed_progress", done=done, total=seen)
            except Exception as e:
                # Best-effort: one bad row must not abort the whole migration.
                logging.error("playbook_embed_failed", playbook_id=row["playbook_id"], error=str(e))

        # A short page means the table is exhausted; skip the extra empty query.
        if len(rows) < page_limit:
            break

    if seen == 0:
        logging.info("playbook_embeddings_all_embedded")
    return done
async def main(dry_run: bool, batch_size: int) -> None:
    """Run the re-embedding migration end to end.

    Steps:
      1. Resolve ``DATABASE_URL`` from the environment, falling back to a
         ``DATABASE_URL=`` line in the ``.env`` file one level above this
         script's directory. Exits with status 1 if neither is available.
      2. Probe the embedding endpoint once via ``embed_text``; exits with
         status 1 if the probe raises.
      3. Count rows with NULL embeddings in both tables and return early when
         there is nothing to do.
      4. Re-embed ``rag_chunks`` and then ``playbook_embeddings``.
      5. Unless ``dry_run``, recount the NULL rows and warn on stderr if any
         remain, so a partially failed run is not reported as a clean success.

    Args:
        dry_run: Passed through to the re-embed functions; when true no UPDATE
            is issued and the final recount is skipped.
        batch_size: Passed through to the re-embed functions.
    """
    database_url = os.getenv("DATABASE_URL")
    if not database_url:
        # Fall back to reading the .env file.
        env_file = Path(__file__).parent.parent / ".env"
        if env_file.exists():
            # Read as UTF-8 explicitly instead of the platform default encoding.
            for line in env_file.read_text(encoding="utf-8").splitlines():
                if line.startswith("DATABASE_URL="):
                    database_url = line.split("=", 1)[1].strip().strip('"\'')
                    break
    if not database_url:
        print("❌ DATABASE_URL 未設定,請設定環境變數或 .env 檔案", file=sys.stderr)
        sys.exit(1)

    if dry_run:
        print("🔍 DRY RUN 模式 — 不會實際更新 DB")

    async with httpx.AsyncClient() as http_client:
        # Verify up front that the model is reachable and returns the expected dimension.
        print(f"🔗 驗證 GCP-A Ollama ({OLLAMA_URL}) bge-m3 連線...")
        try:
            test_vec = await embed_text(http_client, "連線測試")
            print(f"✅ bge-m3 可用,維度 = {len(test_vec)}")
        except Exception as e:
            print(f"❌ bge-m3 連線失敗: {e}", file=sys.stderr)
            sys.exit(1)

        conn = await asyncpg.connect(database_url)
        try:
            # Count the rows still waiting for an embedding.
            rag_null = await conn.fetchval("SELECT COUNT(*) FROM rag_chunks WHERE embedding IS NULL")
            pb_null = await conn.fetchval("SELECT COUNT(*) FROM playbook_embeddings WHERE embedding IS NULL")
            # Separators added so the two counters no longer run together in the output.
            print(f"📊 待嵌入: rag_chunks={rag_null}, playbook_embeddings={pb_null}")
            if rag_null == 0 and pb_null == 0:
                print("✅ 所有向量已嵌入,無需重新處理")
                return

            rag_done = await reembed_rag_chunks(conn, http_client, batch_size, dry_run)
            pb_done = await reembed_playbook_embeddings(conn, http_client, batch_size, dry_run)
            print(f"{'[DRY RUN] ' if dry_run else ''}✅ 完成: rag_chunks={rag_done}, playbook_embeddings={pb_done}")

            if not dry_run:
                # Per-row failures are only logged by the workers, so re-check the
                # tables and surface anything that is still NULL.
                rag_left = await conn.fetchval("SELECT COUNT(*) FROM rag_chunks WHERE embedding IS NULL")
                pb_left = await conn.fetchval("SELECT COUNT(*) FROM playbook_embeddings WHERE embedding IS NULL")
                if rag_left or pb_left:
                    print(
                        f"⚠️ 仍有未嵌入資料: rag_chunks={rag_left}, playbook_embeddings={pb_left}",
                        file=sys.stderr,
                    )
        finally:
            await conn.close()
if __name__ == "__main__":
    # CLI entry point: parse the flags, validate them, then run the async main().
    parser = argparse.ArgumentParser(description="Re-embed script for bge-m3 1024 維遷移")
    parser.add_argument("--dry-run", action="store_true", help="只統計,不寫 DB")
    parser.add_argument("--batch", type=int, default=50, help="每批次處理筆數")
    args = parser.parse_args()
    # --batch feeds a SQL LIMIT: 0 would silently process nothing and a negative
    # value is rejected by PostgreSQL, so refuse both here with a usage error.
    if args.batch < 1:
        parser.error("--batch 必須為正整數")
    asyncio.run(main(dry_run=args.dry_run, batch_size=args.batch))