316 lines
9.9 KiB
Go
316 lines
9.9 KiB
Go
package main
|
|
|
|
import (
	"context"
	"net"
	"net/url"
	"os"
	"time"

	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/sirupsen/logrus"
)
|
|
|
|
// ChatInteraction is the metadata persisted for one chat request/response
// exchange in the chat_interactions table. Raw LLM JSON payloads are not part
// of this record; they are stored separately in chat_llm_raw.
type ChatInteraction struct {
	CorrelationID string    `json:"correlation_id"` // links the row to its chat_llm_raw events
	UserMessage   string    `json:"user_message"`
	Translate     string    `json:"translate"`
	Animal        string    `json:"animal"`
	Keywords      []string  `json:"keywords"`
	BestVisitID   string    `json:"best_visit_id"` // an empty value is persisted as NULL
	TotalPrice    int       `json:"total_price"`
	TotalDuration int       `json:"total_duration"`
	CreatedAt     time.Time `json:"created_at"` // filled from the column default on insert
}
|
|
|
|
// ChatRepositoryAPI defines persistence operations
|
|
//
|
|
//go:generate mockgen -destination=mock_repo.go -package=main . ChatRepositoryAPI
|
|
type ChatRepositoryAPI interface {
|
|
SaveChatInteraction(ctx context.Context, rec ChatInteraction) error
|
|
ListChatInteractions(ctx context.Context, limit, offset int) ([]ChatInteraction, error)
|
|
SaveLLMRawEvent(ctx context.Context, correlationID, phase, raw string) error
|
|
ListLLMRawEvents(ctx context.Context, correlationID string, limit, offset int) ([]RawLLMEvent, error)
|
|
SaveKnowledgeModel(ctx context.Context, text string) error
|
|
ListKnowledgeModels(ctx context.Context, limit, offset int) ([]knowledgeModelMeta, error)
|
|
GetKnowledgeModelText(ctx context.Context, id int64) (string, error)
|
|
}
|
|
|
|
// RawLLMEvent is one row of chat_llm_raw: the raw JSON exchanged with the LLM
// during a single phase of the chat identified by CorrelationID.
type RawLLMEvent struct {
	CorrelationID string    `json:"correlation_id"`
	Phase         string    `json:"phase"`
	RawJSON       string    `json:"raw_json"` // NULL in the database is read back as ""
	CreatedAt     time.Time `json:"created_at"`
}
|
|
|
|
// knowledgeModelMeta is the listing view of a knowledgeModel row: its id and
// creation time, without the knowledge_text body. The type is unexported, so
// although ChatRepositoryAPI mentions it, only code inside package main can
// name it.
type knowledgeModelMeta struct {
	ID        int64     `json:"id"`
	CreatedAt time.Time `json:"created_at"`
}
|
|
|
|
// PGChatRepository is a PostgreSQL implementation using pgxpool
|
|
type PGChatRepository struct {
|
|
pool *pgxpool.Pool
|
|
}
|
|
|
|
// NewPGChatRepository creates a new repository if dsn provided, returns nil if empty dsn
|
|
func NewPGChatRepository(ctx context.Context, dsn string) (*PGChatRepository, error) {
|
|
if dsn == "" {
|
|
return nil, nil
|
|
}
|
|
cfg, err := pgxpool.ParseConfig(dsn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
p, err := pgxpool.NewWithConfig(ctx, cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
r := &PGChatRepository{pool: p}
|
|
if err := r.ensureSchema(ctx); err != nil {
|
|
p.Close()
|
|
return nil, err
|
|
}
|
|
return r, nil
|
|
}
|
|
|
|
// ensureSchema creates/adjusts tables. Drops legacy raw columns.
|
|
func (r *PGChatRepository) ensureSchema(ctx context.Context) error {
|
|
ddlInteractions := `CREATE TABLE IF NOT EXISTS chat_interactions (
|
|
id BIGSERIAL PRIMARY KEY,
|
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
|
correlation_id TEXT NOT NULL,
|
|
user_message TEXT NOT NULL,
|
|
translate TEXT,
|
|
animal TEXT,
|
|
keywords TEXT[] NOT NULL,
|
|
best_visit_id TEXT,
|
|
total_price INT,
|
|
total_duration INT
|
|
);`
|
|
if _, err := r.pool.Exec(ctx, ddlInteractions); err != nil {
|
|
return err
|
|
}
|
|
ddlRaw := `CREATE TABLE IF NOT EXISTS chat_llm_raw (
|
|
id BIGSERIAL PRIMARY KEY,
|
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
|
correlation_id TEXT NOT NULL,
|
|
phase TEXT NOT NULL,
|
|
raw_json TEXT
|
|
);`
|
|
if _, err := r.pool.Exec(ctx, ddlRaw); err != nil {
|
|
return err
|
|
}
|
|
// Add knowledgeModel table
|
|
ddlKnowledgeModel := `CREATE TABLE IF NOT EXISTS knowledgeModel (
|
|
id BIGSERIAL PRIMARY KEY,
|
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
|
knowledge_text TEXT NOT NULL
|
|
);`
|
|
if _, err := r.pool.Exec(ctx, ddlKnowledgeModel); err != nil {
|
|
return err
|
|
}
|
|
// Legacy column cleanup (ignore errors)
|
|
for _, drop := range []string{
|
|
"ALTER TABLE chat_interactions DROP COLUMN IF EXISTS raw_keywords_json",
|
|
"ALTER TABLE chat_interactions DROP COLUMN IF EXISTS raw_disambig_json",
|
|
} {
|
|
if _, err := r.pool.Exec(ctx, drop); err != nil {
|
|
logrus.WithError(err).Debug("drop legacy column failed (ignored)")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SaveChatInteraction inserts a record
|
|
func (r *PGChatRepository) SaveChatInteraction(ctx context.Context, rec ChatInteraction) error {
|
|
if r == nil || r.pool == nil {
|
|
return nil
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
|
defer cancel()
|
|
_, err := r.pool.Exec(ctx, `INSERT INTO chat_interactions
|
|
(correlation_id, user_message, translate, animal, keywords, best_visit_id, total_price, total_duration)
|
|
VALUES ($1,$2,$3,$4,$5,$6,$7,$8)`,
|
|
rec.CorrelationID, rec.UserMessage, rec.Translate, rec.Animal, rec.Keywords, nullIfEmpty(rec.BestVisitID), rec.TotalPrice, rec.TotalDuration)
|
|
if err != nil {
|
|
logrus.WithError(err).Warn("failed to persist chat interaction")
|
|
}
|
|
return err
|
|
}
|
|
|
|
// ListChatInteractions retrieves records with pagination
|
|
func (r *PGChatRepository) ListChatInteractions(ctx context.Context, limit, offset int) ([]ChatInteraction, error) {
|
|
if r == nil || r.pool == nil {
|
|
return []ChatInteraction{}, nil
|
|
}
|
|
if limit <= 0 || limit > 500 {
|
|
limit = 50
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
|
defer cancel()
|
|
qry := `SELECT correlation_id, user_message, COALESCE(translate,'') as translate, COALESCE(animal,'') as animal, keywords, COALESCE(best_visit_id,'') as best_visit_id, total_price, total_duration, created_at
|
|
FROM chat_interactions ORDER BY created_at DESC LIMIT $1 OFFSET $2`
|
|
rows, err := r.pool.Query(ctx, qry, limit, offset)
|
|
if err != nil {
|
|
if pgErr, ok := err.(*pgconn.PgError); ok && (pgErr.Code == "42P01" || pgErr.Code == "42703") {
|
|
logrus.WithError(err).Warn("listing: attempting schema repair")
|
|
if r.ensureSchema(context.Background()) == nil {
|
|
rows, err = r.pool.Query(ctx, qry, limit, offset)
|
|
}
|
|
}
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var out []ChatInteraction
|
|
for rows.Next() {
|
|
var rec ChatInteraction
|
|
if err := rows.Scan(&rec.CorrelationID, &rec.UserMessage, &rec.Translate, &rec.Animal, &rec.Keywords, &rec.BestVisitID, &rec.TotalPrice, &rec.TotalDuration, &rec.CreatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, rec)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
// SaveLLMRawEvent inserts a raw event record
|
|
func (r *PGChatRepository) SaveLLMRawEvent(ctx context.Context, correlationID, phase, raw string) error {
|
|
if r == nil || r.pool == nil {
|
|
return nil
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
|
defer cancel()
|
|
_, err := r.pool.Exec(ctx, `INSERT INTO chat_llm_raw (correlation_id, phase, raw_json) VALUES ($1,$2,$3)`, correlationID, phase, raw)
|
|
if err != nil {
|
|
logrus.WithError(err).Warn("failed to persist raw llm event")
|
|
}
|
|
return err
|
|
}
|
|
|
|
// ListLLMRawEvents retrieves raw LLM events with pagination
|
|
func (r *PGChatRepository) ListLLMRawEvents(ctx context.Context, correlationID string, limit, offset int) ([]RawLLMEvent, error) {
|
|
if r == nil || r.pool == nil || correlationID == "" {
|
|
return []RawLLMEvent{}, nil
|
|
}
|
|
if limit <= 0 || limit > 500 {
|
|
limit = 50
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
|
defer cancel()
|
|
rows, err := r.pool.Query(ctx, `SELECT correlation_id, phase, COALESCE(raw_json,'') as raw_json, created_at FROM chat_llm_raw WHERE correlation_id=$1 ORDER BY created_at ASC LIMIT $2 OFFSET $3`, correlationID, limit, offset)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var out []RawLLMEvent
|
|
for rows.Next() {
|
|
var ev RawLLMEvent
|
|
if err := rows.Scan(&ev.CorrelationID, &ev.Phase, &ev.RawJSON, &ev.CreatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, ev)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
// SaveKnowledgeModel inserts a new knowledgeModel entry
|
|
func (r *PGChatRepository) SaveKnowledgeModel(ctx context.Context, text string) error {
|
|
if r == nil || r.pool == nil {
|
|
return nil
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
|
defer cancel()
|
|
_, err := r.pool.Exec(ctx, `INSERT INTO knowledgeModel (knowledge_text) VALUES ($1)`, text)
|
|
return err
|
|
}
|
|
|
|
// ListKnowledgeModels returns a list of knowledgeModel metadata (id, created_at)
|
|
func (r *PGChatRepository) ListKnowledgeModels(ctx context.Context, limit, offset int) ([]knowledgeModelMeta, error) {
|
|
if r == nil || r.pool == nil {
|
|
return nil, nil
|
|
}
|
|
if limit <= 0 || limit > 1000 {
|
|
limit = 100
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
|
defer cancel()
|
|
rows, err := r.pool.Query(ctx, `SELECT id, created_at FROM knowledgeModel ORDER BY created_at DESC LIMIT $1 OFFSET $2`, limit, offset)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var out []knowledgeModelMeta
|
|
for rows.Next() {
|
|
var s knowledgeModelMeta
|
|
if err := rows.Scan(&s.ID, &s.CreatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, s)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
// GetKnowledgeModelText returns the knowledge_text for a given id
|
|
func (r *PGChatRepository) GetKnowledgeModelText(ctx context.Context, id int64) (string, error) {
|
|
if r == nil || r.pool == nil {
|
|
return "", nil
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
|
defer cancel()
|
|
var text string
|
|
err := r.pool.QueryRow(ctx, `SELECT knowledge_text FROM knowledgeModel WHERE id=$1`, id).Scan(&text)
|
|
return text, err
|
|
}
|
|
|
|
// Close releases pool resources
|
|
func (r *PGChatRepository) Close() {
|
|
if r != nil && r.pool != nil {
|
|
r.pool.Close()
|
|
}
|
|
}
|
|
|
|
// nullIfEmpty returns an untyped nil for the empty string, so the value can be
// bound as SQL NULL. Any other string is returned unchanged.
func nullIfEmpty(s string) interface{} {
	if s != "" {
		return s
	}
	return nil
}
|
|
|
|
// buildDefaultDSN returns the PostgreSQL connection string to use.
//
// If DATABASE_URL is set, it is returned verbatim. Otherwise a postgres:// URL
// is built from the libpq-style variables: PGHOST (default "localhost"),
// PGPORT ("5432"), PGUSER ("postgres"), PGPASSWORD (optional, omitted when
// empty), PGDATABASE ("vetrag") and PGSSLMODE ("disable").
//
// The URL is built with net/url, so each component is percent-encoded. A user,
// password or database name containing reserved characters such as '@', ':'
// or '/' therefore still yields a parseable DSN. net.JoinHostPort also
// brackets IPv6 host literals.
func buildDefaultDSN() string {
	if dsn := os.Getenv("DATABASE_URL"); dsn != "" {
		return dsn
	}

	user := envOr("PGUSER", "postgres")
	userInfo := url.User(user)
	if pass := os.Getenv("PGPASSWORD"); pass != "" {
		userInfo = url.UserPassword(user, pass)
	}

	dsn := url.URL{
		Scheme:   "postgres",
		User:     userInfo,
		Host:     net.JoinHostPort(envOr("PGHOST", "localhost"), envOr("PGPORT", "5432")),
		Path:     "/" + envOr("PGDATABASE", "vetrag"),
		RawQuery: url.Values{"sslmode": {envOr("PGSSLMODE", "disable")}}.Encode(),
	}
	return dsn.String()
}

// envOr returns the value of environment variable k, or def when k is unset
// or set to the empty string.
func envOr(k, def string) string {
	if v := os.Getenv(k); v != "" {
		return v
	}
	return def
}
|