vetrag/main.go

305 lines
8.7 KiB
Go

package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"strings"
"text/template"
"time"
"github.com/gin-gonic/gin"
"gopkg.in/yaml.v3"
)
// Procedure is one procedure belonging to a visit Reason. It is loaded from
// db.yaml and sent back to the frontend in ChatResponse.Procedures.
type Procedure struct {
// Name is the procedure's display name.
Name string `yaml:"name" json:"name"`
// Price is added up into ChatResponse.TotalPrice by sumProcedures.
// NOTE(review): the currency/unit is not visible in this file — confirm.
Price int `yaml:"price" json:"price"`
// DurationMin is the procedure length in minutes (YAML key
// duration_minutes); added up into ChatResponse.TotalDuration.
DurationMin int `yaml:"duration_minutes" json:"duration_minutes"`
}
// Reason is one visit-reason entry of the database loaded from db.yaml.
type Reason struct {
// ID identifies the entry. The disambiguation LLM call is expected to answer
// with one of these, and it is returned to the frontend as ChatResponse.Match.
ID string `yaml:"id" json:"id"`
// Reason is the visit-reason text; it is part of the candidate JSON that is
// rendered into the disambiguation prompt.
Reason string `yaml:"reason" json:"reason"`
// Keywords are compared by findCandidates (which lowercases them) against
// the keywords extracted from a chat message.
Keywords []string `yaml:"keywords" json:"keywords"`
// Procedures are returned, together with their summed price and duration,
// when this entry is the chosen match.
Procedures []Procedure `yaml:"procedures" json:"procedures"`
// Notes is free text returned with a match; omitted from JSON when empty.
Notes string `yaml:"notes" json:"notes,omitempty"`
}
// reasonsDB is the in-memory visit-reason database, filled by loadYAMLDB.
var reasonsDB []Reason

// loadYAMLDB reads the YAML file at path and decodes its contents into
// reasonsDB. A read failure is returned as-is; otherwise the result of
// yaml.Unmarshal is returned.
func loadYAMLDB(path string) error {
	raw, readErr := ioutil.ReadFile(path)
	if readErr == nil {
		return yaml.Unmarshal(raw, &reasonsDB)
	}
	return readErr
}
// ChatRequest is the JSON body accepted by POST /chat.
type ChatRequest struct {
// Message is the user's free-text chat message. The field carries no
// binding:"required" tag, so a body without it binds to "".
Message string `json:"message"`
}
// ChatResponse is the JSON body returned by POST /chat. When no reason
// matched, only Match is populated (nil, encoded as "match": null) and every
// omitempty field is left out.
type ChatResponse struct {
// Match is the ID of the chosen Reason, or nil when nothing matched.
Match *string `json:"match"`
// Procedures are the matched reason's procedures.
Procedures []Procedure `json:"procedures,omitempty"`
// TotalPrice is the sum of the procedures' prices (see sumProcedures).
// NOTE(review): omitempty drops this field when the sum is 0 — confirm the
// frontend treats a missing total_price as zero.
TotalPrice int `json:"total_price,omitempty"`
// TotalDuration is the sum of the procedures' durations in minutes; like
// TotalPrice it is omitted when 0.
TotalDuration int `json:"total_duration,omitempty"`
// Notes are the matched reason's notes, omitted when empty.
Notes string `json:"notes,omitempty"`
}
// LLMClient holds the connection settings for the text-completion endpoint
// used for keyword extraction and disambiguation (see openAICompletion).
type LLMClient struct {
// APIKey, when non-empty, is sent as "Authorization: Bearer <APIKey>".
APIKey string
// BaseURL is the full URL the completion request is POSTed to (not just a
// host); when empty, https://api.openai.com/v1/completions is used.
BaseURL string
}
// Config is the application configuration decoded from config.yaml.
type Config struct {
// LLM holds the prompt templates. Both are Go text/template strings that
// are rendered by renderPrompt before being sent to the completion endpoint.
LLM struct {
// ExtractKeywordsPrompt is rendered with {{.Message}} (the user's message).
ExtractKeywordsPrompt string `yaml:"extract_keywords_prompt"`
// DisambiguatePrompt is rendered with {{.Entries}} (the candidate reasons
// marshalled as a JSON array) and {{.Message}}.
DisambiguatePrompt string `yaml:"disambiguate_prompt"`
} `yaml:"llm"`
}
// appConfig holds the configuration loaded by loadConfig.
var appConfig Config

// loadConfig reads the YAML file at path and decodes it into appConfig.
// A read failure is returned as-is; otherwise the result of yaml.Unmarshal
// is returned.
func loadConfig(path string) error {
	contents, readErr := ioutil.ReadFile(path)
	if readErr == nil {
		return yaml.Unmarshal(contents, &appConfig)
	}
	return readErr
}
// ExtractKeywords asks the LLM to extract search keywords from a user message.
//
// The prompt is appConfig.LLM.ExtractKeywordsPrompt rendered with
// {{.Message}}. The completion is parsed by parseKeywordList: a JSON array of
// strings is preferred (even when the model surrounds it with other text),
// otherwise the text is split on commas and newlines. Errors from rendering
// the prompt or from the completion call are returned unchanged.
func (llm *LLMClient) ExtractKeywords(ctx context.Context, message string) ([]string, error) {
	prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
	if err != nil {
		log.Printf("[CONFIG] Failed to render ExtractKeywords prompt: %v", err)
		return nil, err
	}
	log.Printf("[LLM] ExtractKeywords prompt: %q", prompt)
	resp, err := llm.openAICompletion(ctx, prompt)
	log.Printf("[LLM] ExtractKeywords response: %q, err: %v", resp, err)
	if err != nil {
		return nil, err
	}
	return parseKeywordList(resp), nil
}

// parseKeywordList converts a raw completion into a clean keyword list.
func parseKeywordList(resp string) []string {
	text := strings.TrimSpace(resp)
	// Completion models often wrap the array in prose ("Keywords: [...]"), so
	// decode the outermost [...] span. Decoding into a fresh slice means a
	// failed attempt cannot leak partially decoded entries into the fallback.
	if start, end := strings.Index(text, "["), strings.LastIndex(text, "]"); start >= 0 && end > start {
		var arr []string
		if err := json.Unmarshal([]byte(text[start:end+1]), &arr); err == nil {
			return cleanKeywords(arr)
		}
	}
	// Fallback: a comma- or newline-separated list, possibly a JSON array cut
	// short by max_tokens (stray brackets/quotes are stripped below).
	return cleanKeywords(strings.FieldsFunc(text, func(r rune) bool {
		return r == ',' || r == '\n'
	}))
}

// cleanKeywords trims whitespace, quotes and brackets from every entry and
// drops entries that end up empty.
func cleanKeywords(raw []string) []string {
	out := make([]string, 0, len(raw))
	for _, k := range raw {
		if kw := strings.TrimSpace(strings.Trim(strings.TrimSpace(k), "\"'`[]")); kw != "" {
			out = append(out, kw)
		}
	}
	return out
}
// DisambiguateBestMatch asks the LLM which candidate reason best fits message.
//
// The candidates are marshalled to JSON and rendered, together with the
// message, into appConfig.LLM.DisambiguatePrompt ({{.Entries}}, {{.Message}}).
// The returned ID is the completion's first line with surrounding whitespace,
// quotes and backticks removed; an answer of "none" or "null" (any case)
// yields "" with a nil error. The ID is NOT checked against candidates —
// callers must do that.
func (llm *LLMClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Reason) (string, error) {
	entries, err := json.Marshal(candidates)
	if err != nil {
		return "", fmt.Errorf("encoding disambiguation candidates: %w", err)
	}
	prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
	if err != nil {
		log.Printf("[CONFIG] Failed to render Disambiguate prompt: %v", err)
		return "", err
	}
	log.Printf("[LLM] DisambiguateBestMatch prompt: %q", prompt)
	resp, err := llm.openAICompletion(ctx, prompt)
	log.Printf("[LLM] DisambiguateBestMatch response: %q, err: %v", resp, err)
	if err != nil {
		return "", err
	}
	return parseBestMatchID(resp), nil
}

// parseBestMatchID extracts a reason ID from a raw completion; "" means no match.
func parseBestMatchID(resp string) string {
	// Models sometimes append an explanation after the ID; keep the first line.
	line, _, _ := strings.Cut(strings.TrimSpace(resp), "\n")
	id := strings.TrimSpace(strings.Trim(strings.TrimSpace(line), "\"'`"))
	if strings.EqualFold(id, "none") || strings.EqualFold(id, "null") {
		return ""
	}
	return id
}
// openAICompletion is a minimal OpenAI API call (text-davinci-003 or gpt-3.5-turbo-instruct)
func (llm *LLMClient) openAICompletion(ctx context.Context, prompt string) (string, error) {
apiURL := llm.BaseURL
if apiURL == "" {
apiURL = "https://api.openai.com/v1/completions"
}
log.Printf("[LLM] openAICompletion POST %s | prompt: %q", apiURL, prompt)
body := map[string]interface{}{
"model": "text-davinci-003",
"prompt": prompt,
"max_tokens": 64,
"temperature": 0,
}
jsonBody, _ := json.Marshal(body)
req, _ := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonBody))
if llm.APIKey != "" {
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
log.Printf("[LLM] openAICompletion error: %v", err)
return "", err
}
defer resp.Body.Close()
var result struct {
Choices []struct {
Text string `json:"text"`
} `json:"choices"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
log.Printf("[LLM] openAICompletion decode error: %v", err)
return "", err
}
if len(result.Choices) == 0 {
log.Printf("[LLM] openAICompletion: no choices returned")
return "", nil
}
log.Printf("[LLM] openAICompletion: got text: %q", result.Choices[0].Text)
return result.Choices[0].Text, nil
}
// naiveKeywordExtract lowercases msg and returns its distinct words in order
// of first appearance. It is a non-LLM placeholder; nothing in the visible
// source calls it.
//
// A word is a maximal run of runes in 'a'..'z' or 'á'..'ű' (U+00E1..U+0171,
// which covers the Hungarian accented letters but also a few non-letters such
// as '÷'); every other rune, including digits, separates words.
func naiveKeywordExtract(msg string) []string {
	// TODO: Replace with LLM call
	isSeparator := func(r rune) bool {
		return !((r >= 'a' && r <= 'z') || (r >= 'á' && r <= 'ű'))
	}
	seen := make(map[string]struct{})
	// Collect into a slice instead of ranging over the map so the result
	// order is deterministic.
	words := make([]string, 0)
	for _, w := range strings.FieldsFunc(strings.ToLower(msg), isSeparator) {
		if _, dup := seen[w]; dup {
			continue
		}
		seen[w] = struct{}{}
		words = append(words, w)
	}
	return words
}
// findCandidates returns, in reasonsDB order, every reason sharing at least
// one keyword with keywords. Matching is case-insensitive and ignores
// surrounding whitespace on BOTH sides (previously only the DB side was
// lowercased, so an LLM keyword like "Vaccination" never matched). Empty
// keywords never match. Returns nil when nothing matches.
func findCandidates(keywords []string) []Reason {
	wanted := make(map[string]struct{}, len(keywords))
	for _, k := range keywords {
		if nk := normalizeKeyword(k); nk != "" {
			wanted[nk] = struct{}{}
		}
	}
	if len(wanted) == 0 {
		return nil
	}
	var candidates []Reason
	for _, r := range reasonsDB {
		for _, k := range r.Keywords {
			if _, ok := wanted[normalizeKeyword(k)]; ok {
				candidates = append(candidates, r)
				break // one shared keyword is enough; don't add r twice
			}
		}
	}
	return candidates
}

// normalizeKeyword canonicalizes a keyword for matching: trimmed and lowercased.
func normalizeKeyword(k string) string {
	return strings.ToLower(strings.TrimSpace(k))
}
// sumProcedures totals the prices and durations (minutes) of procs and
// returns them as (totalPrice, totalDuration); an empty slice yields (0, 0).
func sumProcedures(procs []Procedure) (int, int) {
	var price, minutes int
	for i := range procs {
		price, minutes = price+procs[i].Price, minutes+procs[i].DurationMin
	}
	return price, minutes
}
// renderPrompt parses tmplStr as a text/template, executes it against data,
// and returns the rendered text. On a parse or execution error it returns ""
// together with that error.
func renderPrompt(tmplStr string, data any) (string, error) {
	var rendered strings.Builder
	parsed, err := template.New("").Parse(tmplStr)
	if err == nil {
		err = parsed.Execute(&rendered, data)
	}
	if err != nil {
		return "", err
	}
	return rendered.String(), nil
}
// main loads config.yaml and db.yaml from the working directory, builds an
// LLMClient from OPENAI_API_KEY / OPENAI_BASE_URL, and serves POST /chat on
// :8080. It exits via log.Fatalf if loading fails or the server stops.
func main() {
	if err := loadConfig("config.yaml"); err != nil {
		log.Fatalf("Failed to load config.yaml: %v", err)
	}
	log.Printf("Loaded config: %+v", appConfig)
	if err := loadYAMLDB("db.yaml"); err != nil {
		log.Fatalf("Failed to load db.yaml: %v", err)
	}
	fmt.Printf("Loaded %d reasons from db.yaml\n", len(reasonsDB))
	llm := &LLMClient{
		APIKey:  os.Getenv("OPENAI_API_KEY"),
		BaseURL: os.Getenv("OPENAI_BASE_URL"), // e.g. http://localhost:1234/v1/completions
	}
	r := gin.Default()
	r.POST("/chat", chatHandler(llm))

	// gin's r.Run starts a server with no timeouts and its error was ignored
	// (a busy port exited with status 0); use an explicit http.Server instead.
	srv := &http.Server{
		Addr:              ":8080",
		Handler:           r,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		// A request makes up to two sequential LLM calls; leave room for both.
		WriteTimeout: 3 * time.Minute,
		IdleTimeout:  2 * time.Minute,
	}
	log.Printf("Listening on %s", srv.Addr)
	if err := srv.ListenAndServe(); err != nil {
		log.Fatalf("HTTP server stopped: %v", err)
	}
}

// chatHandler returns the POST /chat handler. It extracts keywords from the
// message, shortlists matching reasons, asks the LLM to pick one, and replies
// with that reason's procedures and totals. Invalid JSON gets 400; any other
// failure along the way yields 200 with {"match": null}.
func chatHandler(llm *LLMClient) gin.HandlerFunc {
	return func(c *gin.Context) {
		var req ChatRequest
		if err := c.ShouldBindJSON(&req); err != nil {
			log.Printf("[ERROR] Invalid request: %v", err)
			c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
			return
		}
		ctx := c.Request.Context()
		keywords, err := llm.ExtractKeywords(ctx, req.Message)
		candidates := findCandidates(keywords)
		bestID := ""
		if err == nil && len(candidates) > 0 {
			bestID, err = llm.DisambiguateBestMatch(ctx, req.Message, candidates)
		}
		logRequest(req, keywords, candidates, bestID, err)
		if err != nil || bestID == "" {
			c.JSON(http.StatusOK, ChatResponse{Match: nil})
			return
		}
		// The LLM's answer is untrusted: only accept IDs among the candidates.
		best := findReasonByID(candidates, bestID)
		if best == nil {
			c.JSON(http.StatusOK, ChatResponse{Match: nil})
			return
		}
		totalPrice, totalDuration := sumProcedures(best.Procedures)
		log.Printf("[TRACE] Responding with match: %q, totalPrice: %d, totalDuration: %d, notes: %q", best.ID, totalPrice, totalDuration, best.Notes)
		c.JSON(http.StatusOK, ChatResponse{
			Match:         &best.ID,
			Procedures:    best.Procedures,
			TotalPrice:    totalPrice,
			TotalDuration: totalDuration,
			Notes:         best.Notes,
		})
	}
}

// findReasonByID returns a pointer into candidates for the entry with the
// given ID, or nil when none has it.
func findReasonByID(candidates []Reason, id string) *Reason {
	for i := range candidates {
		if candidates[i].ID == id {
			return &candidates[i]
		}
	}
	return nil
}
// logRequest writes one "[TRACE]" log line summarizing a chat request: an
// RFC 3339 timestamp, the message, the extracted keywords, the candidate
// IDs, the chosen ID and the error (if any).
func logRequest(req ChatRequest, keywords []string, candidates []Reason, bestID string, err error) {
	stamp := time.Now().Format(time.RFC3339)
	candidateIDs := getCandidateIDs(candidates)
	log.Printf("[TRACE] %s | message: %q | keywords: %v | candidates: %v | bestID: %q | err: %v",
		stamp, req.Message, keywords, candidateIDs, bestID, err)
}
// getCandidateIDs returns the IDs of candidates in the same order. The result
// is always non-nil (empty for no candidates).
func getCandidateIDs(candidates []Reason) []string {
	ids := make([]string, 0, len(candidates))
	for _, cand := range candidates {
		ids = append(ids, cand.ID)
	}
	return ids
}