vetrag/sentence_embeddings.go

155 lines
5.0 KiB
Go

package main
import (
"context"
"os"
"regexp"
"strings"
"time"
"github.com/sirupsen/logrus"
)
// sentenceSplitRegex matches one sentence-like unit at a time. A unit is either
// a run of characters other than '.', '!', '?' and newline followed by a single
// terminator ('.', '!' or '?'), or such a run that reaches the end of a line
// without a terminator (multi-line mode is on, so '$' matches at line ends).
// Compiled once at package initialisation; MustCompile panics on a bad pattern.
var sentenceSplitRegex = regexp.MustCompile(`(?m)(?:[^.!?\n]+[.!?]|[^.!?\n]+$)`)
// envDuration reads a duration from the environment variable key and returns
// def when the variable is unset, empty, or cannot be interpreted.
//
// Accepted forms, after trimming surrounding whitespace:
//   - anything time.ParseDuration understands, e.g. "45s", "1m30s", "500ms";
//   - a bare non-negative number such as "45" or "2.5", taken as seconds.
//
// The bare-number form exists because operators commonly configure timeouts as
// plain seconds, and time.ParseDuration alone rejects a value without a unit,
// which would otherwise make the setting silently fall back to def.
func envDuration(key string, def time.Duration) time.Duration {
	v := strings.TrimSpace(os.Getenv(key))
	if v == "" {
		return def
	}
	if d, err := time.ParseDuration(v); err == nil {
		return d
	}
	// Only digits and dots left: treat the value as a number of seconds.
	// Mixed forms such as "1h30" are deliberately not guessed at.
	if strings.Trim(v, "0123456789.") == "" {
		if d, err := time.ParseDuration(v + "s"); err == nil {
			return d
		}
	}
	return def
}
// startSentenceEmbeddingBackfill launches a background goroutine that walks
// every visit in vdb, splits the visit text into sentences, and stores a
// (visit_id, sentence, translated, embedding) record for each sentence the
// repository does not already report as present.
//
// The call itself returns immediately. It does nothing (apart from an info log)
// when repo, llm or vdb is nil, or when SENTENCE_BACKFILL_DISABLE is set to "1"
// or "true" (case-insensitive).
//
// Per sentence the goroutine:
//  1. skips it when the trimmed text is shorter than 3 bytes;
//  2. asks the repository whether a record already exists (2s timeout); a
//     failed check is logged and processing continues;
//  3. translates it to English, retrying up to 3 times with a per-attempt
//     timeout from TRANSLATE_TIMEOUT (default 45s); if translation fails or
//     comes back empty, the original sentence is used instead;
//  4. computes an embedding of that text, retrying up to 3 times with a
//     per-attempt timeout from EMBEDDING_TIMEOUT (default 45s); a sentence
//     whose embedding cannot be obtained is counted as a failure and skipped;
//  5. inserts the record (5s timeout); a failed insert is logged and counted
//     as a failure;
//  6. sleeps 50ms to throttle load on the LLM backend and database.
//
// A summary with the counters and elapsed time is logged when the walk ends.
//
// NOTE(review): an earlier comment stated that duplicates are also prevented
// by a unique index with ON CONFLICT DO NOTHING inside the insert; that is not
// visible from this function - confirm against the repository implementation.
func startSentenceEmbeddingBackfill(repo *PGChatRepository, llm LLMClientAPI, vdb *VisitDB) {
	if repo == nil || llm == nil || vdb == nil {
		logrus.Info("Sentence embedding backfill skipped (missing repo, llm or vdb)")
		return
	}
	if disable := strings.ToLower(os.Getenv("SENTENCE_BACKFILL_DISABLE")); disable == "1" || disable == "true" {
		logrus.Info("Sentence embedding backfill disabled via SENTENCE_BACKFILL_DISABLE env var")
		return
	}

	const (
		maxTranslateAttempts = 3
		maxEmbeddingAttempts = 3
	)
	translateTimeout := envDuration("TRANSLATE_TIMEOUT", 45*time.Second)
	embeddingTimeout := envDuration("EMBEDDING_TIMEOUT", 45*time.Second)

	go func() {
		start := time.Now()
		logrus.WithFields(logrus.Fields{"translateTimeout": translateTimeout, "embeddingTimeout": embeddingTimeout}).Info("Sentence embedding backfill started")

		// processed counts every extracted sentence, including the ones that
		// are later dropped for being too short.
		processed := 0
		inserted := 0
		skippedExisting := 0
		skippedDueToFailures := 0

		for _, visit := range vdb.visitsDB { // visitsDB accessible within package
			if strings.TrimSpace(visit.Visit) == "" {
				continue
			}
			sentences := extractSentences(visit.Visit)
			for _, s := range sentences {
				processed++
				trimmed := strings.TrimSpace(s)
				if len(trimmed) < 3 {
					continue
				}

				// Existence check before any LLM calls, so already-stored
				// sentences cost one cheap query instead of two model calls.
				existsCtx, existsCancel := context.WithTimeout(context.Background(), 2*time.Second)
				exists, err := repo.ExistsSentenceEmbedding(existsCtx, visit.ID, trimmed)
				existsCancel()
				if err != nil {
					// Best effort: a failed check does not stop the sentence
					// from being processed.
					logrus.WithError(err).Warnf("Exists check failed visit=%s sentence=%q", visit.ID, trimmed)
				} else if exists {
					skippedExisting++
					continue
				}

				// Translation with retry/backoff.
				var translated string
				translateErr := retry(maxTranslateAttempts, 0, func(at int) error {
					ctx, cancel := context.WithTimeout(context.Background(), translateTimeout)
					defer cancel()
					resp, err := llm.TranslateToEnglish(ctx, trimmed)
					if err != nil {
						logrus.WithError(err).Warnf("Translate attempt=%d failed visit=%s sentence=%q", at+1, visit.ID, trimmed)
						return err
					}
					translated = strings.TrimSpace(resp)
					return nil
				})
				if translateErr != nil || translated == "" {
					translated = trimmed // fallback: keep the original language
				}

				// Embedding with retry/backoff. This runs on the translated
				// text, or on the original sentence when translation failed.
				var emb []float64
				embErr := retry(maxEmbeddingAttempts, 0, func(at int) error {
					ctx, cancel := context.WithTimeout(context.Background(), embeddingTimeout)
					defer cancel()
					vec, err := llm.GetEmbeddings(ctx, translated)
					if err != nil {
						logrus.WithError(err).Warnf("Embeddings attempt=%d failed visit=%s sentence=%q", at+1, visit.ID, trimmed)
						return err
					}
					emb = vec
					return nil
				})
				if embErr != nil {
					skippedDueToFailures++
					continue
				}

				persistCtx, pcancel := context.WithTimeout(context.Background(), 5*time.Second)
				insertErr := repo.InsertSentenceEmbedding(persistCtx, visit.ID, trimmed, translated, emb)
				pcancel()
				if insertErr != nil {
					// Surface write failures instead of dropping them, so the
					// summary below accounts for every attempted sentence.
					logrus.WithError(insertErr).Warnf("Insert failed visit=%s sentence=%q", visit.ID, trimmed)
					skippedDueToFailures++
				} else {
					inserted++
				}

				// Throttle (configurable?)
				time.Sleep(50 * time.Millisecond)
			}
		}
		logrus.Infof("Sentence embedding backfill complete processed=%d inserted=%d skipped_existing=%d skipped_failures=%d elapsed=%s", processed, inserted, skippedExisting, skippedDueToFailures, time.Since(start))
	}()
}
// retry calls fn until it succeeds or attempts calls have been made, sleeping
// between consecutive calls with a delay that doubles each time.
//
// fn receives the zero-based attempt index. The first delay is base, or 200ms
// when base is zero or negative; no sleep follows the final attempt.
//
// It returns nil as soon as fn returns nil, otherwise the error from the last
// call. When attempts is zero or negative fn is never called and nil is
// returned.
func retry(attempts int, base time.Duration, fn func(attempt int) error) error {
	if attempts < 1 {
		return nil
	}
	delay := base
	if delay <= 0 {
		delay = 200 * time.Millisecond
	}
	final := attempts - 1
	for n := 0; ; n++ {
		failure := fn(n)
		if failure == nil {
			return nil
		}
		if n == final {
			// Out of attempts: report the most recent error without sleeping.
			return failure
		}
		time.Sleep(delay)
		delay <<= 1 // exponential growth for the next round
	}
}
// extractSentences breaks text into sentence-like pieces using
// sentenceSplitRegex.
//
// Newlines are flattened to spaces first, so a sentence may span several
// lines. Each piece ends at the first '.', '!' or '?' (the terminator is kept)
// or at the end of the text; this means a terminator inside a token, such as
// the dot in a decimal number or an abbreviation, also ends a piece. Pieces
// are whitespace-trimmed and empty ones are dropped. The result is nil when
// nothing remains.
func extractSentences(text string) []string {
	flat := strings.ReplaceAll(text, "\n", " ")

	var sentences []string
	for _, candidate := range sentenceSplitRegex.FindAllString(flat, -1) {
		cleaned := strings.TrimSpace(candidate)
		if cleaned == "" {
			continue
		}
		sentences = append(sentences, cleaned)
	}
	return sentences
}