// vetrag/chat_service.go
//
// 139 lines
// 3.6 KiB
// Go

package main
import (
"context"
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
// ChatService bundles the dependencies used to answer a chat request:
// an LLM client and a database of visit reasons.
type ChatService struct {
	// LLM is the client used for keyword extraction and for picking the
	// best match among candidate reasons.
	LLM *LLMClient

	// reasonsDB is queried for candidate reasons and for a reason by ID.
	reasonsDB ReasonDB
}
// NewChatService returns a ChatService that uses llm as its LLM client and
// db as its reasons database. Neither argument is validated or copied.
func NewChatService(llm *LLMClient, db ReasonDB) *ChatService {
	svc := new(ChatService)
	svc.LLM = llm
	svc.reasonsDB = db
	return svc
}
// HandleChat is the gin handler for a chat request.
//
// It parses the JSON body into a ChatRequest, asks the LLM for keywords,
// looks up the best matching visit reason, and writes a ChatResponse.
//
// Responses:
//   - 400 when the body cannot be bound (written by parseRequest).
//   - 200 with a nil Match when keyword extraction or the reason lookup
//     fails, or when nothing matches.
//   - 200 with the matched reason otherwise.
func (cs *ChatService) HandleChat(c *gin.Context) {
	req, err := cs.parseRequest(c)
	if err != nil {
		// parseRequest has already logged the error and written the 400.
		return
	}

	// Use the request's context so that LLM calls are cancelled when the
	// client disconnects. It is read only after a successful bind, which
	// guarantees c.Request is non-nil.
	ctx := c.Request.Context()

	keywordsResponse, err := cs.getKeywordsResponse(ctx, req.Message)
	if err != nil {
		cs.logChat(req, keywordsResponse, nil, "", err)
		c.JSON(http.StatusOK, ChatResponse{Match: nil})
		return
	}

	keywordList := cs.keywordsToStrings(keywordsResponse["keyword"])
	best, err := cs.findVisitReason(ctx, req, keywordList)
	if err != nil {
		// A failed lookup is reported to the client as "no match", but it
		// must not disappear silently.
		logrus.WithError(err).Warn("Finding visit reason failed")
		best = nil
	}

	c.JSON(http.StatusOK, cs.buildResponse(best))
}
// parseRequest binds the JSON request body into a ChatRequest.
//
// On a binding failure it logs the error, writes a 400 response with an
// "Invalid request" error payload to c, and returns the bind error; the
// caller should stop handling the request in that case.
func (cs *ChatService) parseRequest(c *gin.Context) (ChatRequest, error) {
	var payload ChatRequest

	bindErr := c.ShouldBindJSON(&payload)
	if bindErr == nil {
		return payload, nil
	}

	logrus.WithError(bindErr).Error("Invalid request")
	c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
	return payload, bindErr
}
// getKeywordsResponse forwards message to the LLM client's ExtractKeywords
// and returns its result and error unchanged.
func (cs *ChatService) getKeywordsResponse(ctx context.Context, message string) (map[string]interface{}, error) {
	return cs.LLM.ExtractKeywords(ctx, message)
}
// keywordsToStrings converts a loosely typed keyword value into []string.
//
// A []string is returned as-is (not copied). For a []interface{}, the
// string elements are collected in order and all other elements are
// skipped. Any other input, including nil, yields a nil slice.
func (cs *ChatService) keywordsToStrings(kwIface interface{}) []string {
	if direct, ok := kwIface.([]string); ok {
		return direct
	}

	raw, ok := kwIface.([]interface{})
	if !ok {
		return nil
	}

	var out []string
	for idx := range raw {
		if text, isString := raw[idx].(string); isString {
			out = append(out, text)
		}
	}
	return out
}
// findVisitReason resolves the visit reason that best matches req.Message.
//
// It asks the reasons database for candidates matching keywordList, lets
// the LLM choose the best candidate ID for the message, and then loads
// that reason by ID.
//
// It returns (nil, nil) when there are no candidates or the LLM returns an
// empty ID, so callers can tell "no match" apart from a failure. A non-nil
// error wraps the failure of the candidate lookup, the disambiguation
// call, or the lookup by ID.
func (cs *ChatService) findVisitReason(ctx context.Context, req ChatRequest, keywordList []string) (*Reason, error) {
	logrus.WithFields(logrus.Fields{
		"keywords": keywordList,
		"message":  req.Message,
	}).Info("Finding visit reason candidates")

	candidates, err := cs.reasonsDB.findCandidates(keywordList)
	logrus.WithFields(logrus.Fields{
		"candidates": candidates,
		"error":      err,
	}).Info("Candidate reasons found")
	if err != nil {
		return nil, fmt.Errorf("findCandidates: %w", err)
	}
	if len(candidates) == 0 {
		// Nothing to disambiguate, and no ID to look up.
		return nil, nil
	}

	bestID, err := cs.LLM.DisambiguateBestMatch(ctx, req.Message, candidates)
	logrus.WithFields(logrus.Fields{
		"bestID": bestID,
		"error":  err,
	}).Info("Disambiguated best match")
	if err != nil {
		// Previously this error was overwritten by the findById call below
		// and never surfaced.
		return nil, fmt.Errorf("DisambiguateBestMatch: %w", err)
	}
	if bestID == "" {
		// The LLM did not select a candidate; skip the lookup of an empty ID.
		return nil, nil
	}

	reason, err := cs.reasonsDB.findById(bestID)
	if err != nil {
		return nil, fmt.Errorf("findById: %w", err)
	}
	return &reason, nil
}
// buildResponse converts the selected reason into a ChatResponse and logs
// the result.
//
// A nil best produces a ChatResponse with a nil Match. Otherwise the
// response carries the reason's ID, procedures and notes, together with
// the two totals returned by sumProcedures for those procedures.
func (cs *ChatService) buildResponse(best *Reason) ChatResponse {
	var (
		out     ChatResponse
		outcome = "Build response: no match"
	)

	if best != nil {
		price, duration := sumProcedures(best.Procedures)
		out = ChatResponse{
			Match:         &best.ID,
			Procedures:    best.Procedures,
			TotalPrice:    price,
			TotalDuration: duration,
			Notes:         best.Notes,
		}
		outcome = "Build response: match found"
	}

	logrus.WithFields(logrus.Fields{"response": out}).Info(outcome)
	return out
}
// logChat records a chat interaction.
//
// It always passes its arguments to logRequest. If candidates is non-nil
// and bestID is non-empty, it also logs a "Responding with match" entry
// for the first candidate whose ID equals bestID, including the totals
// returned by sumProcedures and the candidate's notes.
func (cs *ChatService) logChat(req ChatRequest, keywords map[string]interface{}, candidates []Reason, bestID string, err error) {
	logRequest(req, keywords, candidates, bestID, err)

	if candidates == nil || bestID == "" {
		return
	}

	for idx := range candidates {
		chosen := &candidates[idx]
		if chosen.ID != bestID {
			continue
		}

		price, duration := sumProcedures(chosen.Procedures)
		logrus.WithFields(logrus.Fields{
			"match":          chosen.ID,
			"total_price":    price,
			"total_duration": duration,
			"notes":          chosen.Notes,
		}).Info("Responding with match")
		// Only the first matching candidate is reported.
		return
	}
}