vetrag/repository.go

245 lines
7.6 KiB
Go

package main
import (
	"context"
	"errors"
	"net/url"
	"os"
	"time"

	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/sirupsen/logrus"
)
// ChatInteraction represents the persisted metadata of one chat
// request/response (the raw JSON was moved to the chat_llm_raw table).
//
// Every field corresponds to the same-named column of the chat_interactions
// table created by ensureSchema.
type ChatInteraction struct {
	// CorrelationID is stored in the NOT NULL correlation_id column.
	// NOTE(review): presumably the same value is used for the matching
	// RawLLMEvent rows — confirm against the caller.
	CorrelationID string `json:"correlation_id"`
	// UserMessage is stored in the NOT NULL user_message column.
	UserMessage string `json:"user_message"`
	// Translate is read back as "" when the translate column is NULL.
	Translate string `json:"translate"`
	// Animal is read back as "" when the animal column is NULL.
	Animal string `json:"animal"`
	// Keywords maps to the keywords TEXT[] NOT NULL column.
	Keywords []string `json:"keywords"`
	// BestVisitID is written as NULL when empty and read back as "".
	BestVisitID string `json:"best_visit_id"`
	// TotalPrice maps to the total_price INT column (unit not visible here — TODO confirm).
	TotalPrice int `json:"total_price"`
	// TotalDuration maps to the total_duration INT column (unit not visible here — TODO confirm).
	TotalDuration int `json:"total_duration"`
	// CreatedAt is not part of the INSERT; the database default now() fills
	// it, so it is only meaningful on records returned by a listing.
	CreatedAt time.Time `json:"created_at"`
}
// ChatRepositoryAPI defines the persistence operations for chat interactions
// and their raw LLM events. PGChatRepository is the implementation in this file.
//
//go:generate mockgen -destination=mock_repo.go -package=main . ChatRepositoryAPI
type ChatRepositoryAPI interface {
	// SaveChatInteraction stores a single chat interaction record.
	SaveChatInteraction(ctx context.Context, rec ChatInteraction) error
	// ListChatInteractions returns one page of stored interactions selected
	// by limit and offset.
	ListChatInteractions(ctx context.Context, limit, offset int) ([]ChatInteraction, error)
	// SaveLLMRawEvent stores the raw payload of one LLM exchange phase under
	// the given correlation ID.
	SaveLLMRawEvent(ctx context.Context, correlationID, phase, raw string) error
	// ListLLMRawEvents returns one page of the raw events recorded for
	// correlationID, selected by limit and offset.
	ListLLMRawEvents(ctx context.Context, correlationID string, limit, offset int) ([]RawLLMEvent, error)
}
// RawLLMEvent represents a stored raw LLM exchange phase, i.e. one row of the
// chat_llm_raw table.
type RawLLMEvent struct {
	// CorrelationID is the key ListLLMRawEvents filters on.
	CorrelationID string `json:"correlation_id"`
	// Phase is stored in the NOT NULL phase column.
	// NOTE(review): the set of valid phase names is not visible in this file.
	Phase string `json:"phase"`
	// RawJSON holds the raw_json column; a NULL value is read back as "".
	RawJSON string `json:"raw_json"`
	// CreatedAt is filled by the database default now() on insert.
	CreatedAt time.Time `json:"created_at"`
}
// PGChatRepository is a PostgreSQL implementation of the chat persistence
// operations using pgxpool. Its exported methods are no-ops on a nil receiver
// or a nil pool.
type PGChatRepository struct {
	// pool is the connection pool every query runs on; released by Close.
	pool *pgxpool.Pool
}
// NewPGChatRepository opens a pgx connection pool for dsn and makes sure the
// required tables exist before handing the repository back.
//
// An empty dsn returns (nil, nil). The exported methods of PGChatRepository
// accept a nil receiver, so the nil pointer can still be used as a no-op
// repository.
//
// A DSN that cannot be parsed, a pool that cannot be created, or a failing
// schema setup is reported as an error; in the last case the freshly opened
// pool is closed first.
func NewPGChatRepository(ctx context.Context, dsn string) (*PGChatRepository, error) {
	if len(dsn) == 0 {
		return nil, nil
	}

	poolCfg, err := pgxpool.ParseConfig(dsn)
	if err != nil {
		return nil, err
	}

	pool, err := pgxpool.NewWithConfig(ctx, poolCfg)
	if err != nil {
		return nil, err
	}

	repo := &PGChatRepository{pool: pool}
	schemaErr := repo.ensureSchema(ctx)
	if schemaErr == nil {
		return repo, nil
	}

	// Do not leak the pool when the schema could not be prepared.
	pool.Close()
	return nil, schemaErr
}
// ensureSchema creates the chat_interactions and chat_llm_raw tables when they
// are missing and then drops the legacy raw JSON columns from
// chat_interactions.
//
// A failing CREATE TABLE aborts with that error. The legacy column drops are
// best effort: a failure is logged at debug level and otherwise ignored.
func (r *PGChatRepository) ensureSchema(ctx context.Context) error {
	const createInteractions = `CREATE TABLE IF NOT EXISTS chat_interactions (
id BIGSERIAL PRIMARY KEY,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
correlation_id TEXT NOT NULL,
user_message TEXT NOT NULL,
translate TEXT,
animal TEXT,
keywords TEXT[] NOT NULL,
best_visit_id TEXT,
total_price INT,
total_duration INT
);`
	const createRaw = `CREATE TABLE IF NOT EXISTS chat_llm_raw (
id BIGSERIAL PRIMARY KEY,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
correlation_id TEXT NOT NULL,
phase TEXT NOT NULL,
raw_json TEXT
);`

	// Mandatory statements, executed in order; the first failure wins.
	for _, stmt := range []string{createInteractions, createRaw} {
		if _, err := r.pool.Exec(ctx, stmt); err != nil {
			return err
		}
	}

	// Legacy column cleanup (errors are ignored on purpose).
	legacyDrops := [...]string{
		"ALTER TABLE chat_interactions DROP COLUMN IF EXISTS raw_keywords_json",
		"ALTER TABLE chat_interactions DROP COLUMN IF EXISTS raw_disambig_json",
	}
	for i := range legacyDrops {
		_, err := r.pool.Exec(ctx, legacyDrops[i])
		if err == nil {
			continue
		}
		logrus.WithError(err).Debug("drop legacy column failed (ignored)")
	}
	return nil
}
// SaveChatInteraction inserts rec into the chat_interactions table.
//
// A nil receiver or nil pool is a no-op that returns nil. The insert is
// bounded by a 3 second timeout derived from ctx. created_at is not written;
// the column default fills it. An empty BestVisitID is stored as NULL.
//
// On failure the error is logged as a warning and returned to the caller.
func (r *PGChatRepository) SaveChatInteraction(ctx context.Context, rec ChatInteraction) error {
	if r == nil || r.pool == nil {
		return nil
	}

	// pgx encodes a nil slice as SQL NULL, which the NOT NULL keywords column
	// rejects. Store "no keywords" as an empty array instead.
	keywords := rec.Keywords
	if keywords == nil {
		keywords = []string{}
	}

	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()
	_, err := r.pool.Exec(ctx, `INSERT INTO chat_interactions
(correlation_id, user_message, translate, animal, keywords, best_visit_id, total_price, total_duration)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8)`,
		rec.CorrelationID, rec.UserMessage, rec.Translate, rec.Animal, keywords, nullIfEmpty(rec.BestVisitID), rec.TotalPrice, rec.TotalDuration)
	if err != nil {
		logrus.WithError(err).Warn("failed to persist chat interaction")
	}
	return err
}
// ListChatInteractions returns one page of stored chat interactions, newest
// first.
//
// A limit outside 1..500 falls back to 50 and a negative offset is treated as
// 0. A nil receiver or nil pool yields an empty result. The returned slice is
// never nil on success, so it always serialises as a JSON array.
//
// The query is bounded by a 3 second timeout derived from ctx. When it fails
// with PostgreSQL error 42P01 (undefined table) or 42703 (undefined column),
// ensureSchema is run once and the query is retried; if the repair itself
// fails the original query error is returned.
func (r *PGChatRepository) ListChatInteractions(ctx context.Context, limit, offset int) ([]ChatInteraction, error) {
	if r == nil || r.pool == nil {
		return []ChatInteraction{}, nil
	}
	if limit <= 0 || limit > 500 {
		limit = 50
	}
	if offset < 0 {
		offset = 0
	}
	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	// total_price and total_duration are nullable columns scanned into plain
	// ints, so NULL must be mapped to 0. The id tie-breaker keeps LIMIT/OFFSET
	// pages stable when several rows share a created_at value.
	qry := `SELECT correlation_id, user_message, COALESCE(translate,'') as translate, COALESCE(animal,'') as animal, keywords, COALESCE(best_visit_id,'') as best_visit_id, COALESCE(total_price,0) as total_price, COALESCE(total_duration,0) as total_duration, created_at
FROM chat_interactions ORDER BY created_at DESC, id DESC LIMIT $1 OFFSET $2`
	rows, err := r.pool.Query(ctx, qry, limit, offset)
	if err != nil {
		// errors.As also matches a PgError that has been wrapped on its way up.
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && (pgErr.Code == "42P01" || pgErr.Code == "42703") {
			logrus.WithError(err).Warn("listing: attempting schema repair")
			// The repair is detached from the request context so the request
			// deadline does not cut it short, but it is still bounded so a
			// stuck DDL statement cannot block this call forever.
			repairCtx, cancelRepair := context.WithTimeout(context.Background(), 5*time.Second)
			repairErr := r.ensureSchema(repairCtx)
			cancelRepair()
			if repairErr == nil {
				rows, err = r.pool.Query(ctx, qry, limit, offset)
			}
		}
	}
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	out := []ChatInteraction{}
	for rows.Next() {
		var rec ChatInteraction
		if err := rows.Scan(&rec.CorrelationID, &rec.UserMessage, &rec.Translate, &rec.Animal, &rec.Keywords, &rec.BestVisitID, &rec.TotalPrice, &rec.TotalDuration, &rec.CreatedAt); err != nil {
			return nil, err
		}
		out = append(out, rec)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
// SaveLLMRawEvent stores the raw payload of one LLM exchange phase in the
// chat_llm_raw table.
//
// A nil receiver or nil pool is a no-op that returns nil. The insert is
// bounded by a 2 second timeout derived from ctx. On failure the error is
// logged as a warning and returned to the caller.
func (r *PGChatRepository) SaveLLMRawEvent(ctx context.Context, correlationID, phase, raw string) error {
	if r == nil {
		return nil
	}
	if r.pool == nil {
		return nil
	}

	const insertRaw = `INSERT INTO chat_llm_raw (correlation_id, phase, raw_json) VALUES ($1,$2,$3)`

	execCtx, release := context.WithTimeout(ctx, 2*time.Second)
	defer release()

	if _, err := r.pool.Exec(execCtx, insertRaw, correlationID, phase, raw); err != nil {
		logrus.WithError(err).Warn("failed to persist raw llm event")
		return err
	}
	return nil
}
// ListLLMRawEvents returns one page of the raw LLM events recorded for
// correlationID, oldest first.
//
// A limit outside 1..500 falls back to 50 and a negative offset is treated as
// 0. A nil receiver, a nil pool or an empty correlationID yields an empty
// result. The returned slice is never nil on success, so it always serialises
// as a JSON array. A NULL raw_json is returned as "".
//
// The query is bounded by a 3 second timeout derived from ctx.
func (r *PGChatRepository) ListLLMRawEvents(ctx context.Context, correlationID string, limit, offset int) ([]RawLLMEvent, error) {
	if r == nil || r.pool == nil || correlationID == "" {
		return []RawLLMEvent{}, nil
	}
	if limit <= 0 || limit > 500 {
		limit = 50
	}
	if offset < 0 {
		offset = 0
	}
	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	// The id tie-breaker keeps LIMIT/OFFSET pages stable when several events
	// share a created_at value.
	rows, err := r.pool.Query(ctx, `SELECT correlation_id, phase, COALESCE(raw_json,'') as raw_json, created_at FROM chat_llm_raw WHERE correlation_id=$1 ORDER BY created_at ASC, id ASC LIMIT $2 OFFSET $3`, correlationID, limit, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	out := []RawLLMEvent{}
	for rows.Next() {
		var ev RawLLMEvent
		if err := rows.Scan(&ev.CorrelationID, &ev.Phase, &ev.RawJSON, &ev.CreatedAt); err != nil {
			return nil, err
		}
		out = append(out, ev)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
// Close shuts down the underlying connection pool. It does nothing when the
// receiver or its pool is nil.
func (r *PGChatRepository) Close() {
	if r == nil || r.pool == nil {
		return
	}
	r.pool.Close()
}
// nullIfEmpty converts s into a query argument: the empty string becomes an
// untyped nil, any other value is passed through unchanged.
func nullIfEmpty(s string) interface{} {
	if len(s) > 0 {
		return s
	}
	return nil
}
// buildDefaultDSN returns the PostgreSQL connection URL to use.
//
// DATABASE_URL wins when it is set and is returned verbatim. Otherwise the URL
// is assembled from PGHOST, PGPORT, PGUSER, PGPASSWORD, PGDATABASE and
// PGSSLMODE, falling back to localhost, 5432, postgres, no password, vetrag
// and disable respectively.
//
// The URL is built with net/url so that reserved characters in the user name,
// password, database name or sslmode (for example '@', ':', '/', '?' or '#')
// are percent-encoded instead of corrupting the DSN. For values without such
// characters the result is identical to plain concatenation.
func buildDefaultDSN() string {
	if dsn := os.Getenv("DATABASE_URL"); dsn != "" {
		return dsn
	}

	user := envOr("PGUSER", "postgres")
	// Leave the password part (and its ':' separator) out entirely when no
	// password is configured.
	credentials := url.User(user)
	if pass := os.Getenv("PGPASSWORD"); pass != "" {
		credentials = url.UserPassword(user, pass)
	}

	query := url.Values{}
	query.Set("sslmode", envOr("PGSSLMODE", "disable"))

	dsn := url.URL{
		Scheme:   "postgres",
		User:     credentials,
		Host:     envOr("PGHOST", "localhost") + ":" + envOr("PGPORT", "5432"),
		Path:     "/" + envOr("PGDATABASE", "vetrag"),
		RawQuery: query.Encode(),
	}
	return dsn.String()
}
// envOr returns the value of environment variable k, or def when the variable
// is unset or empty.
func envOr(k, def string) string {
	v := os.Getenv(k)
	if v == "" {
		return def
	}
	return v
}