Files
awoooi/apps/api/scripts/reembed_bge_m3.py
Your Name 8c4dc7a5a8
Some checks failed
Code Review / ai-code-review (push) Successful in 10s
CD Pipeline / tests (push) Successful in 1m5s
CD Pipeline / build-and-deploy (push) Failing after 10m6s
CD Pipeline / post-deploy-checks (push) Has been skipped
chore(rls): 新增 manual script gate 與 canary wave1
2026-05-12 20:23:27 +08:00

190 lines
6.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Re-embed Script: bge-m3:latest 1024 維重新嵌入
===============================================
遷移 embedding_bge_m3_1024.sql 後執行,重新嵌入:
1. rag_chunks (embedding IS NULL 的筆數)
2. playbook_embeddings (embedding IS NULL 的筆數)
用法:
cd apps/api
python scripts/reembed_bge_m3.py [--dry-run] [--batch 50]
前置條件:
1. embedding_bge_m3_1024.sql 已執行,schema 已升為 vector(1024)
2. GCP-A Ollama (34.143.170.20:11434) 可連線且有 bge-m3:latest
3. DATABASE_URL 環境變數已設定(或 .env 存在)
2026-05-04 ogt + Claude Sonnet 4.6: ADR-110 GCP-A Primary Embedding 升級
"""
from __future__ import annotations
import argparse
import asyncio
import os
import sys
from pathlib import Path
# 確保 src 在 import 路徑
sys.path.insert(0, str(Path(__file__).parent.parent))
import asyncpg
import httpx
import structlog
# structlog logger. NOTE: the name `logging` shadows the stdlib module name;
# it is kept because every function in this script logs through it.
logging = structlog.get_logger(__name__)
# Ollama base URL (GCP-A instance by default, overridable via env). Trailing
# slashes are stripped so joins like f"{OLLAMA_URL}/api/embeddings" never
# produce a "//api/..." path.
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://34.143.170.20:11434").rstrip("/")
EMBEDDING_MODEL = "bge-m3:latest"
# bge-m3 output dimension; must match the vector(1024) column type the
# migration creates.
EXPECTED_DIM = 1024
# Written to the `app.project_id` session setting before any query runs.
PROJECT_ID = os.getenv("AWOOOP_PROJECT_ID", "awoooi")
async def embed_text(
    client: httpx.AsyncClient,
    text: str,
    *,
    timeout: float = 60.0,
) -> list[float]:
    """Embed a single text with Ollama bge-m3 and return the vector.

    Args:
        client: HTTP client used for the POST request.
        text: Text to embed; must be non-empty.
        timeout: Per-request timeout in seconds (default 60, as before).

    Returns:
        The embedding as a list of ``EXPECTED_DIM`` floats.

    Raises:
        ValueError: If ``text`` is empty/None, or the returned vector does not
            have exactly ``EXPECTED_DIM`` dimensions.
        httpx.HTTPStatusError: If Ollama answers with an error status.
    """
    if not text:
        # Fail fast with a clear message instead of a round-trip that ends in
        # a misleading "got 0" dimension error (or a server-side error for a
        # null prompt).
        raise ValueError("embed_text: empty text cannot be embedded")
    resp = await client.post(
        f"{OLLAMA_URL}/api/embeddings",
        json={"model": EMBEDDING_MODEL, "prompt": text},
        timeout=timeout,
    )
    resp.raise_for_status()
    # `or []` also covers an explicit JSON null for "embedding", which would
    # otherwise raise TypeError on len().
    embedding = resp.json().get("embedding") or []
    if len(embedding) != EXPECTED_DIM:
        raise ValueError(f"bge-m3 維度錯誤: got {len(embedding)}, expected {EXPECTED_DIM}")
    return embedding
async def reembed_rag_chunks(
    conn: asyncpg.Connection,
    client: httpx.AsyncClient,
    batch_size: int,
    dry_run: bool,
) -> int:
    """Re-embed every ``rag_chunks`` row whose ``embedding`` is NULL.

    Rows are read in pages of ``batch_size`` using keyset pagination on ``id``
    (``id > last seen id``). The walk therefore visits each NULL row exactly
    once and always terminates, even when rows stay NULL (dry-run mode, or
    rows whose embedding failed), instead of stopping after a single page.

    Args:
        conn: Open asyncpg connection.
        client: HTTP client passed through to ``embed_text``.
        batch_size: Rows fetched per query; must be >= 1.
        dry_run: When True, embeddings are computed but not written back.

    Returns:
        Number of rows successfully embedded (and written unless ``dry_run``).

    Raises:
        ValueError: If ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    first_page_sql = (
        "SELECT id, content FROM rag_chunks "
        "WHERE embedding IS NULL ORDER BY id LIMIT $1"
    )
    next_page_sql = (
        "SELECT id, content FROM rag_chunks "
        "WHERE embedding IS NULL AND id > $2 ORDER BY id LIMIT $1"
    )

    done = 0
    scanned = 0
    last_id = None
    while True:
        if last_id is None:
            rows = await conn.fetch(first_page_sql, batch_size)
        else:
            rows = await conn.fetch(next_page_sql, batch_size, last_id)
        if not rows:
            break
        for row in rows:
            scanned += 1
            content = row["content"]
            if not content:
                # Nothing to embed; skip rather than spend a request on it.
                logging.warning("rag_chunk_skipped_empty", id=row["id"])
                continue
            try:
                vec = await embed_text(client, content)
                if not dry_run:
                    # pgvector text literal: "[v1,v2,...]".
                    vec_str = "[" + ",".join(f"{v:.8f}" for v in vec) + "]"
                    await conn.execute(
                        "UPDATE rag_chunks SET embedding = $1::vector WHERE id = $2",
                        vec_str, row["id"],
                    )
                done += 1
                if done % 10 == 0:
                    logging.info("rag_chunks_progress", done=done, scanned=scanned)
            except Exception as e:
                # Best-effort: log and move on; the row stays NULL for a re-run.
                logging.error("rag_chunk_embed_failed", id=row["id"], error=str(e))
        last_id = rows[-1]["id"]
        if len(rows) < batch_size:
            # Short page: no more NULL rows beyond last_id.
            break

    if scanned == 0:
        logging.info("rag_chunks_all_embedded")
    return done
async def reembed_playbook_embeddings(
conn: asyncpg.Connection,
client: httpx.AsyncClient,
batch_size: int,
dry_run: bool,
) -> int:
# playbook_embeddings 關聯 playbooks 表取原始內容
rows = await conn.fetch("""
SELECT pe.playbook_id, p.title, p.description, p.steps
FROM playbook_embeddings pe
JOIN playbooks p ON pe.playbook_id = p.id
WHERE pe.embedding IS NULL
ORDER BY pe.playbook_id
LIMIT $1
""", batch_size * 10)
if not rows:
logging.info("playbook_embeddings_all_embedded")
return 0
done = 0
for row in rows:
text_parts = [row["title"] or "", row["description"] or ""]
if row["steps"]:
if isinstance(row["steps"], list):
text_parts.extend(str(s) for s in row["steps"])
else:
text_parts.append(str(row["steps"]))
text = "\n".join(p for p in text_parts if p)
try:
vec = await embed_text(client, text)
if not dry_run:
vec_str = "[" + ",".join(f"{v:.8f}" for v in vec) + "]"
await conn.execute(
"UPDATE playbook_embeddings SET embedding = $1::vector WHERE playbook_id = $2",
vec_str, row["playbook_id"],
)
done += 1
if done % 10 == 0:
logging.info("playbook_embed_progress", done=done, total=len(rows))
except Exception as e:
logging.error("playbook_embed_failed", playbook_id=row["playbook_id"], error=str(e))
return done
async def main(dry_run: bool, batch_size: int) -> None:
database_url = os.getenv("DATABASE_URL")
if not database_url:
# 嘗試讀 .env
env_file = Path(__file__).parent.parent / ".env"
if env_file.exists():
for line in env_file.read_text().splitlines():
if line.startswith("DATABASE_URL="):
database_url = line.split("=", 1)[1].strip().strip('"\'')
break
if not database_url:
print("❌ DATABASE_URL 未設定,請設定環境變數或 .env 檔案", file=sys.stderr)
sys.exit(1)
if dry_run:
print("🔍 DRY RUN 模式 — 不會實際更新 DB")
async with httpx.AsyncClient() as http_client:
# 先驗證 bge-m3 可用且維度正確
print(f"🔗 驗證 GCP-A Ollama ({OLLAMA_URL}) bge-m3 連線...")
try:
test_vec = await embed_text(http_client, "連線測試")
print(f"✅ bge-m3 可用,維度 = {len(test_vec)}")
except Exception as e:
print(f"❌ bge-m3 連線失敗: {e}", file=sys.stderr)
sys.exit(1)
conn = await asyncpg.connect(database_url)
try:
await conn.execute("SELECT set_config('app.project_id', $1, FALSE)", PROJECT_ID)
# 統計待嵌入筆數
rag_null = await conn.fetchval("SELECT COUNT(*) FROM rag_chunks WHERE embedding IS NULL")
pb_null = await conn.fetchval("SELECT COUNT(*) FROM playbook_embeddings WHERE embedding IS NULL")
print(f"📊 待嵌入rag_chunks={rag_null}playbook_embeddings={pb_null}")
if rag_null == 0 and pb_null == 0:
print("✅ 所有向量已嵌入,無需重新處理")
return
rag_done = await reembed_rag_chunks(conn, http_client, batch_size, dry_run)
pb_done = await reembed_playbook_embeddings(conn, http_client, batch_size, dry_run)
print(f"{'[DRY RUN] ' if dry_run else ''}✅ 完成: rag_chunks={rag_done}, playbook_embeddings={pb_done}")
finally:
await conn.close()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Re-embed script for bge-m3 1024 維遷移")
    # Dry-run still calls Ollama for every pending row; it only skips DB writes.
    parser.add_argument("--dry-run", action="store_true", help="呼叫 Ollama 計算向量,但不寫 DB")
    parser.add_argument("--batch", type=int, default=50, help="每批次處理筆數")
    args = parser.parse_args()
    # Reject non-positive batch sizes up front (LIMIT 0 would silently do
    # nothing; a negative LIMIT is a PostgreSQL error).
    if args.batch < 1:
        parser.error("--batch 必須 >= 1")
    asyncio.run(main(dry_run=args.dry_run, batch_size=args.batch))