// Chat service that matches free-text visit requests against a YAML database
// of visit reasons, using an LLM for keyword extraction and disambiguation.
package main
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"strings"
	"text/template"
	"time"
	"unicode"

	"github.com/gin-gonic/gin"
	"gopkg.in/yaml.v3"
)
|
|
|
|
// Procedure represents a single billable procedure attached to a visit reason.
type Procedure struct {
	Name string `yaml:"name" json:"name"` // human-readable procedure name
	// Price is a whole number; the currency/unit is not defined in this file —
	// TODO confirm against db.yaml.
	Price       int `yaml:"price" json:"price"`
	DurationMin int `yaml:"duration_minutes" json:"duration_minutes"` // expected duration in minutes
}
|
|
|
|
// Reason represents a visit reason entry loaded from db.yaml.
type Reason struct {
	ID     string `yaml:"id" json:"id"`         // stable identifier; the LLM disambiguation step returns one of these
	Reason string `yaml:"reason" json:"reason"` // human-readable description of the visit reason
	// Keywords are matched (case-insensitively on the DB side) against
	// keywords extracted from the user's message; see findCandidates.
	Keywords   []string    `yaml:"keywords" json:"keywords"`
	Procedures []Procedure `yaml:"procedures" json:"procedures"` // procedures performed for this reason
	Notes      string      `yaml:"notes" json:"notes,omitempty"` // optional free-text notes shown to the user
}
|
|
|
|
// reasonsDB holds all visit reasons, populated once at startup by loadYAMLDB.
var reasonsDB []Reason
|
|
|
|
func loadYAMLDB(path string) error {
|
|
data, err := ioutil.ReadFile(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return yaml.Unmarshal(data, &reasonsDB)
|
|
}
|
|
|
|
// ChatRequest represents the incoming chat message posted to /chat.
type ChatRequest struct {
	Message string `json:"message"` // raw user message to match against the reasons DB
}
|
|
|
|
// ChatResponse represents the response sent to the frontend. When no reason
// matches, Match is null and all other fields are omitted.
type ChatResponse struct {
	Match      *string     `json:"match"`                // matched Reason.ID, or null when nothing matched
	Procedures []Procedure `json:"procedures,omitempty"` // procedures for the matched reason
	// NOTE(review): omitempty on the totals means a legitimate total of 0 is
	// dropped from the JSON — confirm the frontend treats absent as 0.
	TotalPrice    int    `json:"total_price,omitempty"`
	TotalDuration int    `json:"total_duration,omitempty"` // sum of procedure durations, minutes
	Notes         string `json:"notes,omitempty"`          // free-text notes from the matched reason
}
|
|
|
|
// LLMClient abstracts LLM completion API calls (OpenAI-compatible endpoints).
type LLMClient struct {
	APIKey  string // bearer token; when empty, no Authorization header is sent
	BaseURL string // completions endpoint; empty falls back to the public OpenAI URL
}
|
|
|
|
// Config holds all prompts and settings loaded from config.yaml.
type Config struct {
	LLM struct {
		// ExtractKeywordsPrompt is a Go text/template expecting {{.Message}}.
		ExtractKeywordsPrompt string `yaml:"extract_keywords_prompt"`
		// DisambiguatePrompt is a Go text/template expecting {{.Entries}}
		// (JSON-encoded candidates) and {{.Message}}.
		DisambiguatePrompt string `yaml:"disambiguate_prompt"`
	} `yaml:"llm"`
}
|
|
|
|
// appConfig holds prompts/settings, populated once at startup by loadConfig.
var appConfig Config
|
|
|
|
func loadConfig(path string) error {
|
|
data, err := ioutil.ReadFile(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return yaml.Unmarshal(data, &appConfig)
|
|
}
|
|
|
|
// ExtractKeywords calls LLM to extract keywords from user message
|
|
func (llm *LLMClient) ExtractKeywords(ctx context.Context, message string) ([]string, error) {
|
|
prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
|
|
if err != nil {
|
|
log.Printf("[CONFIG] Failed to render ExtractKeywords prompt: %v", err)
|
|
return nil, err
|
|
}
|
|
log.Printf("[LLM] ExtractKeywords prompt: %q", prompt)
|
|
resp, err := llm.openAICompletion(ctx, prompt)
|
|
log.Printf("[LLM] ExtractKeywords response: %q, err: %v", resp, err)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var keywords []string
|
|
if err := json.Unmarshal([]byte(resp), &keywords); err == nil {
|
|
return keywords, nil
|
|
}
|
|
// fallback: try splitting by comma
|
|
for _, k := range bytes.Split([]byte(resp), []byte{','}) {
|
|
kw := strings.TrimSpace(string(k))
|
|
if kw != "" {
|
|
keywords = append(keywords, kw)
|
|
}
|
|
}
|
|
return keywords, nil
|
|
}
|
|
|
|
// DisambiguateBestMatch calls LLM to pick best match from candidates
|
|
func (llm *LLMClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Reason) (string, error) {
|
|
entries, _ := json.Marshal(candidates)
|
|
prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
|
|
if err != nil {
|
|
log.Printf("[CONFIG] Failed to render Disambiguate prompt: %v", err)
|
|
return "", err
|
|
}
|
|
log.Printf("[LLM] DisambiguateBestMatch prompt: %q", prompt)
|
|
resp, err := llm.openAICompletion(ctx, prompt)
|
|
log.Printf("[LLM] DisambiguateBestMatch response: %q, err: %v", resp, err)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
id := strings.TrimSpace(resp)
|
|
if id == "none" || id == "null" {
|
|
return "", nil
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
// openAICompletion is a minimal OpenAI API call (text-davinci-003 or gpt-3.5-turbo-instruct)
|
|
func (llm *LLMClient) openAICompletion(ctx context.Context, prompt string) (string, error) {
|
|
apiURL := llm.BaseURL
|
|
if apiURL == "" {
|
|
apiURL = "https://api.openai.com/v1/completions"
|
|
}
|
|
log.Printf("[LLM] openAICompletion POST %s | prompt: %q", apiURL, prompt)
|
|
body := map[string]interface{}{
|
|
"model": "text-davinci-003",
|
|
"prompt": prompt,
|
|
"max_tokens": 64,
|
|
"temperature": 0,
|
|
}
|
|
jsonBody, _ := json.Marshal(body)
|
|
req, _ := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonBody))
|
|
if llm.APIKey != "" {
|
|
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
log.Printf("[LLM] openAICompletion error: %v", err)
|
|
return "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
var result struct {
|
|
Choices []struct {
|
|
Text string `json:"text"`
|
|
} `json:"choices"`
|
|
}
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
log.Printf("[LLM] openAICompletion decode error: %v", err)
|
|
return "", err
|
|
}
|
|
if len(result.Choices) == 0 {
|
|
log.Printf("[LLM] openAICompletion: no choices returned")
|
|
return "", nil
|
|
}
|
|
log.Printf("[LLM] openAICompletion: got text: %q", result.Choices[0].Text)
|
|
return result.Choices[0].Text, nil
|
|
}
|
|
|
|
// naiveKeywordExtract splits a message into a deduplicated set of lowercase
// words (placeholder for an LLM call). Order of the result is unspecified
// (map iteration order).
//
// The original delimiter predicate relied on operator precedence
// (r < 'a' || (r > 'z' && r < 'á') || r > 'ű') and admitted non-letter code
// points that fall inside the 'á'..'ű' range, such as '÷' (U+00F7).
// unicode.IsLetter keeps every letter (ASCII and accented) and rejects
// digits and symbols.
func naiveKeywordExtract(msg string) []string {
	// TODO: Replace with LLM call
	seen := make(map[string]struct{})
	split := strings.FieldsFunc(strings.ToLower(msg), func(r rune) bool {
		return !unicode.IsLetter(r)
	})
	for _, w := range split {
		seen[w] = struct{}{}
	}
	res := make([]string, 0, len(seen))
	for w := range seen {
		res = append(res, w)
	}
	return res
}
|
|
|
|
// findCandidates returns reasons with overlapping keywords
|
|
func findCandidates(keywords []string) []Reason {
|
|
kwSet := make(map[string]struct{})
|
|
for _, k := range keywords {
|
|
kwSet[k] = struct{}{}
|
|
}
|
|
var candidates []Reason
|
|
for _, r := range reasonsDB {
|
|
for _, k := range r.Keywords {
|
|
if _, ok := kwSet[strings.ToLower(k)]; ok {
|
|
candidates = append(candidates, r)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return candidates
|
|
}
|
|
|
|
// sumProcedures calculates total price and duration
|
|
func sumProcedures(procs []Procedure) (int, int) {
|
|
totalPrice := 0
|
|
totalDuration := 0
|
|
for _, p := range procs {
|
|
totalPrice += p.Price
|
|
totalDuration += p.DurationMin
|
|
}
|
|
return totalPrice, totalDuration
|
|
}
|
|
|
|
// renderPrompt parses tmplStr as a Go text/template and executes it with
// data, returning the rendered text. Parse and execution errors are returned
// unchanged.
func renderPrompt(tmplStr string, data any) (string, error) {
	tpl, err := template.New("").Parse(tmplStr)
	if err != nil {
		return "", err
	}
	var rendered bytes.Buffer
	err = tpl.Execute(&rendered, data)
	if err != nil {
		return "", err
	}
	return rendered.String(), nil
}
|
|
|
|
// main loads config.yaml and db.yaml, then serves a single POST /chat
// endpoint on :8080 that matches the user's message to a visit reason.
func main() {
	// Config and DB are both required at startup; abort on any failure.
	if err := loadConfig("config.yaml"); err != nil {
		log.Fatalf("Failed to load config.yaml: %v", err)
	}
	log.Printf("Loaded config: %+v", appConfig)
	if err := loadYAMLDB("db.yaml"); err != nil {
		log.Fatalf("Failed to load db.yaml: %v", err)
	}
	fmt.Printf("Loaded %d reasons from db.yaml\n", len(reasonsDB))

	// LLM credentials/endpoint come from the environment; both may be empty
	// (an empty BaseURL falls back to the public OpenAI endpoint inside
	// openAICompletion).
	llm := &LLMClient{
		APIKey:  os.Getenv("OPENAI_API_KEY"),
		BaseURL: os.Getenv("OPENAI_BASE_URL"), // e.g. http://localhost:1234/v1/completions
	}
	r := gin.Default()
	// POST /chat: extract keywords, shortlist candidate reasons, let the LLM
	// pick the best one, and respond with its procedures plus totals. Every
	// failure along the pipeline degrades to a "no match" (Match: null)
	// response with HTTP 200 rather than an error status.
	r.POST("/chat", func(c *gin.Context) {
		var req ChatRequest
		if err := c.ShouldBindJSON(&req); err != nil {
			log.Printf("[ERROR] Invalid request: %v", err)
			c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
			return
		}
		ctx := c.Request.Context()
		keywords, err := llm.ExtractKeywords(ctx, req.Message)
		candidates := findCandidates(keywords)
		bestID := ""
		// Only disambiguate when extraction succeeded and produced at least
		// one candidate.
		if len(candidates) > 0 && err == nil {
			bestID, err = llm.DisambiguateBestMatch(ctx, req.Message, candidates)
		}
		logRequest(req, keywords, candidates, bestID, err)
		// Any LLM error, empty keyword set, empty shortlist, or "no pick"
		// answer means no match.
		if err != nil || len(keywords) == 0 || len(candidates) == 0 || bestID == "" {
			c.JSON(http.StatusOK, ChatResponse{Match: nil})
			return
		}
		// Resolve the LLM-chosen ID back to a candidate entry.
		var best *Reason
		for i := range candidates {
			if candidates[i].ID == bestID {
				best = &candidates[i]
				break
			}
		}
		// The LLM may hallucinate an ID outside the candidate list; treat
		// that as no match too.
		if best == nil {
			c.JSON(http.StatusOK, ChatResponse{Match: nil})
			return
		}
		totalPrice, totalDuration := sumProcedures(best.Procedures)
		log.Printf("[TRACE] Responding with match: %q, totalPrice: %d, totalDuration: %d, notes: %q", best.ID, totalPrice, totalDuration, best.Notes)
		c.JSON(http.StatusOK, ChatResponse{
			Match:         &best.ID,
			Procedures:    best.Procedures,
			TotalPrice:    totalPrice,
			TotalDuration: totalDuration,
			Notes:         best.Notes,
		})
	})

	// NOTE(review): r.Run's error is ignored, so a failed bind (e.g. port in
	// use) exits silently — consider log.Fatal(r.Run(":8080")).
	r.Run(":8080")
}
|
|
|
|
// logRequest logs incoming chat requests and extracted info
|
|
func logRequest(req ChatRequest, keywords []string, candidates []Reason, bestID string, err error) {
|
|
log.Printf("[TRACE] %s | message: %q | keywords: %v | candidates: %v | bestID: %q | err: %v",
|
|
time.Now().Format(time.RFC3339), req.Message, keywords, getCandidateIDs(candidates), bestID, err)
|
|
}
|
|
|
|
func getCandidateIDs(candidates []Reason) []string {
|
|
ids := make([]string, len(candidates))
|
|
for i, c := range candidates {
|
|
ids[i] = c.ID
|
|
}
|
|
return ids
|
|
}
|