296 lines
9.6 KiB
Go
296 lines
9.6 KiB
Go
package main
|
|
|
|
import (
	"context"
	"errors"
	"net"
	"net/url"
	"os"
	"time"

	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/sirupsen/logrus"
)
|
|
|
|
// ChatInteraction is the metadata persisted for a single chat
// request/response exchange in the chat_interactions table.
//
// The raw LLM JSON payloads are not part of this record; they are stored
// separately in chat_llm_raw and share this record's CorrelationID.
//
// Field order and JSON tags are part of the serialized API shape and must
// not be rearranged.
type ChatInteraction struct {
	// CorrelationID links this record to its raw LLM events.
	CorrelationID string `json:"correlation_id"`
	UserMessage   string `json:"user_message"`
	// Translate is "" when the stored column is NULL.
	Translate string `json:"translate"`
	// Animal is "" when the stored column is NULL.
	Animal string `json:"animal"`
	// Keywords is written as an empty array when nil.
	Keywords []string `json:"keywords"`
	// BestVisitID "" is written as SQL NULL and read back as "".
	BestVisitID   string `json:"best_visit_id"`
	TotalPrice    int    `json:"total_price"`
	TotalDuration int    `json:"total_duration"`
	// CreatedAt is read from the table; the insert does not set it
	// (presumably a column default — confirm against the migrations).
	CreatedAt time.Time `json:"created_at"`
}
|
|
|
|
// ChatRepositoryAPI defines persistence operations
|
|
//
|
|
//go:generate mockgen -destination=mock_repo.go -package=main . ChatRepositoryAPI
|
|
type ChatRepositoryAPI interface {
|
|
SaveChatInteraction(ctx context.Context, rec ChatInteraction) error
|
|
ListChatInteractions(ctx context.Context, limit, offset int) ([]ChatInteraction, error)
|
|
SaveLLMRawEvent(ctx context.Context, correlationID, phase, raw string) error
|
|
ListLLMRawEvents(ctx context.Context, correlationID string, limit, offset int) ([]RawLLMEvent, error)
|
|
SaveKnowledgeModel(ctx context.Context, text string) error
|
|
ListKnowledgeModels(ctx context.Context, limit, offset int) ([]knowledgeModelMeta, error)
|
|
GetKnowledgeModelText(ctx context.Context, id int64) (string, error)
|
|
GetUserByUsername(ctx context.Context, username string) (*User, error)
|
|
}
|
|
|
|
// RawLLMEvent is one stored phase of a raw LLM exchange, read from the
// chat_llm_raw table.
//
// Events belonging to the same chat share the CorrelationID of the
// corresponding ChatInteraction.
type RawLLMEvent struct {
	CorrelationID string `json:"correlation_id"`
	// Phase is the free-form label the caller supplied when saving.
	Phase string `json:"phase"`
	// RawJSON is the payload as stored; "" when the column is NULL.
	RawJSON   string    `json:"raw_json"`
	CreatedAt time.Time `json:"created_at"`
}
|
|
|
|
// knowledgeModelMeta is the lightweight listing view of a knowledge model
// snapshot: its row ID and creation time, without the (potentially large)
// snapshot text.
//
// The type is unexported but appears in the package-internal
// ChatRepositoryAPI interface, which is fine because both live in package main.
type knowledgeModelMeta struct {
	ID        int64     `json:"id"`
	CreatedAt time.Time `json:"created_at"`
}
|
|
|
|
// PGChatRepository is a PostgreSQL implementation using pgxpool
|
|
type PGChatRepository struct {
|
|
pool *pgxpool.Pool
|
|
}
|
|
|
|
// NewPGChatRepository creates a new repository if dsn provided, returns nil if empty dsn
|
|
func NewPGChatRepository(ctx context.Context, dsn string) (*PGChatRepository, error) {
|
|
if dsn == "" {
|
|
return nil, nil
|
|
}
|
|
cfg, err := pgxpool.ParseConfig(dsn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
p, err := pgxpool.NewWithConfig(ctx, cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
r := &PGChatRepository{pool: p}
|
|
// Schema migration will be handled by a migration tool
|
|
return r, nil
|
|
}
|
|
|
|
// SaveChatInteraction inserts a record
|
|
func (r *PGChatRepository) SaveChatInteraction(ctx context.Context, rec ChatInteraction) error {
|
|
if r == nil || r.pool == nil {
|
|
return nil
|
|
}
|
|
// Ensure keywords is not nil to satisfy NOT NULL constraint
|
|
if rec.Keywords == nil {
|
|
rec.Keywords = []string{}
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
|
defer cancel()
|
|
_, err := r.pool.Exec(ctx, `INSERT INTO chat_interactions
|
|
(correlation_id, user_message, translate, animal, keywords, best_visit_id, total_price, total_duration)
|
|
VALUES ($1,$2,$3,$4,$5,$6,$7,$8)`,
|
|
rec.CorrelationID, rec.UserMessage, rec.Translate, rec.Animal, rec.Keywords, nullIfEmpty(rec.BestVisitID), rec.TotalPrice, rec.TotalDuration)
|
|
if err != nil {
|
|
logrus.WithError(err).Warn("failed to persist chat interaction")
|
|
}
|
|
return err
|
|
}
|
|
|
|
// ListChatInteractions retrieves records with pagination
|
|
func (r *PGChatRepository) ListChatInteractions(ctx context.Context, limit, offset int) ([]ChatInteraction, error) {
|
|
if r == nil || r.pool == nil {
|
|
return []ChatInteraction{}, nil
|
|
}
|
|
if limit <= 0 || limit > 500 {
|
|
limit = 50
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
|
defer cancel()
|
|
qry := `SELECT correlation_id, user_message, COALESCE(translate,'') as translate, COALESCE(animal,'') as animal, keywords, COALESCE(best_visit_id,'') as best_visit_id, total_price, total_duration, created_at
|
|
FROM chat_interactions ORDER BY created_at DESC LIMIT $1 OFFSET $2`
|
|
rows, err := r.pool.Query(ctx, qry, limit, offset)
|
|
if err != nil {
|
|
if pgErr, ok := err.(*pgconn.PgError); ok && (pgErr.Code == "42P01" || pgErr.Code == "42703") {
|
|
logrus.WithError(err).Warn("listing: schema missing or column not found; please run migrations")
|
|
return nil, err
|
|
}
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var out []ChatInteraction
|
|
for rows.Next() {
|
|
var rec ChatInteraction
|
|
if err := rows.Scan(&rec.CorrelationID, &rec.UserMessage, &rec.Translate, &rec.Animal, &rec.Keywords, &rec.BestVisitID, &rec.TotalPrice, &rec.TotalDuration, &rec.CreatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, rec)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
// SaveLLMRawEvent inserts a raw event record
|
|
func (r *PGChatRepository) SaveLLMRawEvent(ctx context.Context, correlationID, phase, raw string) error {
|
|
if r == nil || r.pool == nil {
|
|
return nil
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
|
defer cancel()
|
|
_, err := r.pool.Exec(ctx, `INSERT INTO chat_llm_raw (correlation_id, phase, raw_json) VALUES ($1,$2,$3)`, correlationID, phase, raw)
|
|
if err != nil {
|
|
logrus.WithError(err).Warn("failed to persist raw llm event")
|
|
}
|
|
return err
|
|
}
|
|
|
|
// ListLLMRawEvents retrieves raw LLM events with pagination
|
|
func (r *PGChatRepository) ListLLMRawEvents(ctx context.Context, correlationID string, limit, offset int) ([]RawLLMEvent, error) {
|
|
if r == nil || r.pool == nil || correlationID == "" {
|
|
return []RawLLMEvent{}, nil
|
|
}
|
|
if limit <= 0 || limit > 500 {
|
|
limit = 50
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
|
defer cancel()
|
|
rows, err := r.pool.Query(ctx, `SELECT correlation_id, phase, COALESCE(raw_json,'') as raw_json, created_at FROM chat_llm_raw WHERE correlation_id=$1 ORDER BY created_at ASC LIMIT $2 OFFSET $3`, correlationID, limit, offset)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var out []RawLLMEvent
|
|
for rows.Next() {
|
|
var ev RawLLMEvent
|
|
if err := rows.Scan(&ev.CorrelationID, &ev.Phase, &ev.RawJSON, &ev.CreatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, ev)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
// SaveKnowledgeModel saves a knowledge model snapshot
|
|
func (p *PGChatRepository) SaveKnowledgeModel(ctx context.Context, text string) error {
|
|
if p == nil || p.pool == nil {
|
|
return nil
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
|
defer cancel()
|
|
_, err := p.pool.Exec(ctx, `INSERT INTO knowledgeModel (knowledge_text) VALUES ($1)`, text)
|
|
if err != nil {
|
|
logrus.WithError(err).Warn("failed to persist knowledge model")
|
|
}
|
|
return err
|
|
}
|
|
|
|
// ListKnowledgeModels lists knowledge model metadata
|
|
func (p *PGChatRepository) ListKnowledgeModels(ctx context.Context, limit, offset int) ([]knowledgeModelMeta, error) {
|
|
if p == nil || p.pool == nil {
|
|
return []knowledgeModelMeta{}, nil
|
|
}
|
|
if limit <= 0 || limit > 500 {
|
|
limit = 100
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
|
defer cancel()
|
|
rows, err := p.pool.Query(ctx, `SELECT id, created_at FROM knowledgeModel ORDER BY created_at DESC LIMIT $1 OFFSET $2`, limit, offset)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var out []knowledgeModelMeta
|
|
for rows.Next() {
|
|
var meta knowledgeModelMeta
|
|
if err := rows.Scan(&meta.ID, &meta.CreatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, meta)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
// GetKnowledgeModelText retrieves the text of a knowledge model by ID
|
|
func (p *PGChatRepository) GetKnowledgeModelText(ctx context.Context, id int64) (string, error) {
|
|
if p == nil || p.pool == nil {
|
|
return "", nil
|
|
}
|
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
|
defer cancel()
|
|
var text string
|
|
err := p.pool.QueryRow(ctx, `SELECT knowledge_text FROM knowledgeModel WHERE id = $1`, id).Scan(&text)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return text, nil
|
|
}
|
|
|
|
// GetUserByUsername fetches a user by username
|
|
func (r *PGChatRepository) GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
|
row := r.pool.QueryRow(ctx, "SELECT id, username, password_hash, created_at FROM users WHERE username=$1", username)
|
|
var u User
|
|
err := row.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.CreatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &u, nil
|
|
}
|
|
|
|
// CountUsers returns the number of users in the users table
|
|
func (r *PGChatRepository) CountUsers(count *int) error {
|
|
return r.pool.QueryRow(context.Background(), "SELECT COUNT(*) FROM users").Scan(count)
|
|
}
|
|
|
|
// CreateUser inserts a new user with username and password hash
|
|
func (r *PGChatRepository) CreateUser(username, passwordHash string) error {
|
|
_, err := r.pool.Exec(context.Background(), "INSERT INTO users (username, password_hash) VALUES ($1, $2)", username, passwordHash)
|
|
return err
|
|
}
|
|
|
|
// Close releases pool resources
|
|
func (r *PGChatRepository) Close() {
|
|
if r != nil && r.pool != nil {
|
|
r.pool.Close()
|
|
}
|
|
}
|
|
|
|
// nullIfEmpty converts s into a query argument: a non-empty s is passed
// through unchanged (as a string), while the empty string becomes an untyped
// nil, which pgx sends as SQL NULL.
func nullIfEmpty(s string) interface{} {
	if s != "" {
		return s
	}
	return nil
}
|
|
|
|
// buildDefaultDSN returns the PostgreSQL connection URL to use.
//
// A non-empty DATABASE_URL is returned verbatim. Otherwise a postgres:// URL
// is assembled from PGHOST, PGPORT, PGUSER, PGPASSWORD, PGDATABASE and
// PGSSLMODE. The defaults are localhost, 5432, postgres, no password, vetrag
// and disable, respectively.
//
// The URL is built with net/url rather than string concatenation, so that
// user names, passwords and database names containing reserved characters
// (such as '@', ':' or '/') are percent-escaped instead of corrupting the URL.
// IPv6 literal hosts (with or without surrounding brackets) are bracketed via
// net.JoinHostPort. Plain alphanumeric values yield the usual
// postgres://user[:pass]@host:port/db?sslmode=... form.
func buildDefaultDSN() string {
	if dsn := os.Getenv("DATABASE_URL"); dsn != "" {
		return dsn
	}

	host := envOr("PGHOST", "localhost")
	port := envOr("PGPORT", "5432")
	user := envOr("PGUSER", "postgres")
	pass := os.Getenv("PGPASSWORD")
	db := envOr("PGDATABASE", "vetrag")
	ssl := envOr("PGSSLMODE", "disable")

	// Accept an already-bracketed IPv6 literal; JoinHostPort adds brackets itself.
	if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}

	userinfo := url.User(user)
	if pass != "" {
		userinfo = url.UserPassword(user, pass)
	}

	dsn := url.URL{
		Scheme:   "postgres",
		User:     userinfo,
		Host:     net.JoinHostPort(host, port),
		Path:     "/" + db,
		RawQuery: url.Values{"sslmode": {ssl}}.Encode(),
	}
	return dsn.String()
}

// envOr returns the value of the environment variable k, or def when k is
// unset or set to the empty string.
func envOr(k, def string) string {
	v := os.Getenv(k)
	if v == "" {
		return def
	}
	return v
}
|