vetrag/main.go

236 lines
6.0 KiB
Go

package main
import (
	"bytes"
	"context"
	"fmt"
	"html/template"
	"io/ioutil"
	"net/http"
	"os"
	"sort"
	"strings"
	texttemplate "text/template"
	"unicode"

	"github.com/gin-gonic/gin"
	"github.com/sirupsen/logrus"
	"gopkg.in/yaml.v3"
)
// Procedure is one step listed under a visit reason's "procedures" key.
// Its price and duration are summed per matched reason by sumProcedures.
type Procedure struct {
	Name        string `yaml:"name" json:"name"`
	Price       int    `yaml:"price" json:"price"`
	DurationMin int    `yaml:"duration_minutes" json:"duration_minutes"` // minutes, per the tag name
}
// Reason is one visit-reason entry of the YAML reasons database: a stable ID,
// a description, the keywords findCandidates matches chat keywords against,
// the procedures the reason implies, and optional free-form notes.
type Reason struct {
	ID         string      `yaml:"id" json:"id"`
	Reason     string      `yaml:"reason" json:"reason"`
	Keywords   []string    `yaml:"keywords" json:"keywords"`
	Procedures []Procedure `yaml:"procedures" json:"procedures"`
	Notes      string      `yaml:"notes" json:"notes,omitempty"`
}
// reasonsDB holds the visit reasons loaded by loadYAMLDB.
var reasonsDB []Reason

// loadYAMLDB reads the YAML file at path and replaces reasonsDB with the list
// of reasons it contains.
//
// The document is decoded into a local slice and only assigned on success:
// yaml.Unmarshal can return an error after partially filling its target, and
// decoding straight into the global would leave a half-loaded database behind.
//
// Returns a wrapped error if the file cannot be read or parsed.
func loadYAMLDB(path string) error {
	data, err := ioutil.ReadFile(path)
	if err != nil {
		// The underlying *PathError already names the file.
		return fmt.Errorf("reading reasons database: %w", err)
	}
	var reasons []Reason
	if err := yaml.Unmarshal(data, &reasons); err != nil {
		return fmt.Errorf("parsing reasons database %q: %w", path, err)
	}
	reasonsDB = reasons
	return nil
}
// ChatRequest is the JSON body bound by the POST /chat handler.
type ChatRequest struct {
	// Message is the chat text handed to keyword extraction and disambiguation.
	Message string `json:"message"`
}
// ChatResponse is the JSON body returned by POST /chat.
//
// Match has no omitempty, so a nil pointer ("no match") is encoded as an
// explicit `"match": null`. Every other field is dropped from the JSON when
// it holds its zero value.
type ChatResponse struct {
	Match         *string     `json:"match"`
	Procedures    []Procedure `json:"procedures,omitempty"`
	TotalPrice    int         `json:"total_price,omitempty"`
	TotalDuration int         `json:"total_duration,omitempty"`
	Notes         string      `json:"notes,omitempty"`
}
// Config mirrors the application's YAML config file. Only the "llm" section
// is modelled. Its two strings are prompts, presumably rendered by the LLM
// client defined elsewhere (TODO confirm against that file).
type Config struct {
	LLM struct {
		ExtractKeywordsPrompt string `yaml:"extract_keywords_prompt"`
		DisambiguatePrompt    string `yaml:"disambiguate_prompt"`
	} `yaml:"llm"`
}
// appConfig holds the settings loaded by loadConfig.
var appConfig Config

// loadConfig reads the YAML file at path and replaces appConfig with its
// contents.
//
// Decoding goes into a local Config that is assigned only on success, so a
// failed or partial yaml.Unmarshal cannot leave appConfig half-overwritten.
//
// Returns a wrapped error if the file cannot be read or parsed.
func loadConfig(path string) error {
	data, err := ioutil.ReadFile(path)
	if err != nil {
		// The underlying *PathError already names the file.
		return fmt.Errorf("reading config: %w", err)
	}
	var cfg Config
	if err := yaml.Unmarshal(data, &cfg); err != nil {
		return fmt.Errorf("parsing config %q: %w", path, err)
	}
	appConfig = cfg
	return nil
}
// naiveKeywordExtract is a placeholder for LLM keyword extraction. It
// lower-cases msg, splits it on every rune that is not a Unicode letter, and
// returns the distinct words in sorted order. The result is never nil; an
// empty or letter-free msg yields an empty slice.
//
// unicode.IsLetter is used instead of a hand-picked rune range. The range
// U+00E1..U+0171 wrongly accepted the division sign '÷' and rejected letters
// such as 'à' or 'ß'. Sorting makes the output independent of map iteration
// order, which keeps logs and tests deterministic.
func naiveKeywordExtract(msg string) []string {
	// TODO: Replace with LLM call
	notLetter := func(r rune) bool { return !unicode.IsLetter(r) }

	seen := make(map[string]struct{})
	res := make([]string, 0)
	for _, w := range strings.FieldsFunc(strings.ToLower(msg), notLetter) {
		if _, dup := seen[w]; dup {
			continue
		}
		seen[w] = struct{}{}
		res = append(res, w)
	}
	sort.Strings(res)
	return res
}
// findCandidates returns, in reasonsDB order, every reason that shares at
// least one keyword with keywords. It returns nil when nothing matches.
//
// Matching is case-insensitive and ignores surrounding whitespace on both
// sides. In main the keywords come from llm.ExtractKeywords and are not
// guaranteed to be normalised; previously only the database side was
// lower-cased, so "Vomiting" never matched "vomiting". Empty keywords are
// ignored so blank entries cannot match each other.
func findCandidates(keywords []string) []Reason {
	wanted := make(map[string]struct{}, len(keywords))
	for _, k := range keywords {
		if k = strings.ToLower(strings.TrimSpace(k)); k != "" {
			wanted[k] = struct{}{}
		}
	}

	var candidates []Reason
	if len(wanted) == 0 {
		return candidates
	}
	for _, r := range reasonsDB {
		for _, k := range r.Keywords {
			if _, ok := wanted[strings.ToLower(strings.TrimSpace(k))]; ok {
				candidates = append(candidates, r)
				break // one shared keyword is enough; avoid duplicate entries
			}
		}
	}
	return candidates
}
// sumProcedures totals the Price and DurationMin fields of procs and returns
// them as (total price, total duration). An empty or nil slice gives (0, 0).
func sumProcedures(procs []Procedure) (int, int) {
	var price, minutes int
	for i := range procs {
		price += procs[i].Price
		minutes += procs[i].DurationMin
	}
	return price, minutes
}
// renderPrompt parses tmplStr as a text/template, executes it against data,
// and returns the rendered text. Parse and execution errors are returned
// unchanged with an empty string.
//
// Prompts are plain text for an LLM, so text/template is used instead of the
// file's html/template import. html/template contextually escapes every
// interpolated value ("dog's" -> "dog&#39;s", "<" -> "&lt;"), which silently
// corrupts the user message and reason text placed into the prompt.
func renderPrompt(tmplStr string, data any) (string, error) {
	tmpl, err := texttemplate.New("prompt").Parse(tmplStr)
	if err != nil {
		return "", err
	}
	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, data); err != nil {
		return "", err
	}
	return buf.String(), nil
}
// uiTemplate is the parsed HTML page that main serves at GET /.
var uiTemplate *template.Template

// loadUITemplate parses the html/template file at path and, only when parsing
// succeeds, stores it in uiTemplate. On failure uiTemplate keeps its previous
// value and the parse error is returned.
func loadUITemplate(path string) error {
	parsed, err := template.ParseFiles(path)
	if err == nil {
		uiTemplate = parsed
	}
	return err
}
// main loads config.yaml, db.yaml and ui.html from the working directory,
// builds the LLM client from OPENAI_API_KEY / OPENAI_BASE_URL, and serves
// on :8080:
//
//	GET  /      the rendered ui.html page
//	POST /chat  matches a ChatRequest against reasonsDB and replies with a
//	            ChatResponse (Match is null when nothing matches)
//
// Startup failures and a server that stops with an error both end the
// process via logrus.Fatalf.
func main() {
	logrus.SetFormatter(&logrus.TextFormatter{FullTimestamp: true})
	logrus.SetLevel(logrus.InfoLevel)
	if err := loadConfig("config.yaml"); err != nil {
		logrus.Fatalf("Failed to load config.yaml: %v", err)
	}
	logrus.Infof("Loaded config: %+v", appConfig)
	if err := loadYAMLDB("db.yaml"); err != nil {
		logrus.Fatalf("Failed to load db.yaml: %v", err)
	}
	fmt.Printf("Loaded %d reasons from db.yaml\n", len(reasonsDB))
	if err := loadUITemplate("ui.html"); err != nil {
		logrus.Fatalf("Failed to load ui.html: %v", err)
	}
	llm := &LLMClient{
		APIKey:  os.Getenv("OPENAI_API_KEY"),
		BaseURL: os.Getenv("OPENAI_BASE_URL"),
	}

	r := gin.Default()
	r.GET("/", func(c *gin.Context) {
		// Render into a buffer first: writing straight to c.Writer would commit
		// a 200 status before a template error could be detected, and the
		// error itself was previously discarded.
		var page bytes.Buffer
		if err := uiTemplate.Execute(&page, nil); err != nil {
			logrus.WithError(err).Error("Failed to render ui.html")
			c.String(http.StatusInternalServerError, "Internal server error")
			return
		}
		c.Data(http.StatusOK, "text/html; charset=utf-8", page.Bytes())
	})
	r.POST("/chat", func(c *gin.Context) {
		var req ChatRequest
		if err := c.ShouldBindJSON(&req); err != nil {
			logrus.WithError(err).Error("Invalid request")
			c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
			return
		}
		// Derive from the request context so a client disconnect cancels the
		// LLM calls, and cancel leftover work when the handler returns
		// (assuming LLMClient honours ctx — TODO confirm).
		ctx, cancel := context.WithCancel(c.Request.Context())
		defer cancel()

		keywords, err := llm.ExtractKeywords(ctx, req.Message)
		candidates := findCandidates(keywords)
		bestID := ""
		if len(candidates) > 0 && err == nil {
			bestID, err = llm.DisambiguateBestMatch(ctx, req.Message, candidates)
		}
		logRequest(req, keywords, candidates, bestID, err)
		if err != nil || len(keywords) == 0 || len(candidates) == 0 || bestID == "" {
			c.JSON(http.StatusOK, ChatResponse{Match: nil})
			return
		}

		// The disambiguator returns an ID; it must be one of the candidates.
		var best *Reason
		for i := range candidates {
			if candidates[i].ID == bestID {
				best = &candidates[i]
				break
			}
		}
		if best == nil {
			c.JSON(http.StatusOK, ChatResponse{Match: nil})
			return
		}

		totalPrice, totalDuration := sumProcedures(best.Procedures)
		logrus.WithFields(logrus.Fields{
			"match":          best.ID,
			"total_price":    totalPrice,
			"total_duration": totalDuration,
			"notes":          best.Notes,
		}).Info("Responding with match")
		c.JSON(http.StatusOK, ChatResponse{
			Match:         &best.ID,
			Procedures:    best.Procedures,
			TotalPrice:    totalPrice,
			TotalDuration: totalDuration,
			Notes:         best.Notes,
		})
	})

	if err := r.Run(":8080"); err != nil {
		logrus.Fatalf("HTTP server stopped: %v", err)
	}
}
// logRequest writes one Info-level "Chat request trace" entry for a /chat call.
// The entry records the raw message, the extracted keywords, the IDs of the
// candidate reasons, the chosen ID (possibly empty) and err (possibly nil).
func logRequest(req ChatRequest, keywords []string, candidates []Reason, bestID string, err error) {
	fields := logrus.Fields{
		"message":    req.Message,
		"keywords":   keywords,
		"candidates": getCandidateIDs(candidates),
		"bestID":     bestID,
		"err":        err,
	}
	logrus.WithFields(fields).Info("Chat request trace")
}
// getCandidateIDs returns the ID of each candidate, preserving order. The
// result is never nil: an empty input yields an empty slice.
func getCandidateIDs(candidates []Reason) []string {
	ids := make([]string, 0, len(candidates))
	for i := range candidates {
		ids = append(ids, candidates[i].ID)
	}
	return ids
}