test
This commit is contained in:
parent
8b615def81
commit
c6b3639109
|
|
@ -0,0 +1,108 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LLMClient abstracts LLM API calls
|
||||||
|
type LLMClient struct {
|
||||||
|
APIKey string
|
||||||
|
BaseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractKeywords calls LLM to extract keywords from user message
|
||||||
|
func (llm *LLMClient) ExtractKeywords(ctx context.Context, message string) ([]string, error) {
|
||||||
|
prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
|
||||||
|
resp, err := llm.openAICompletion(ctx, prompt)
|
||||||
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var keywords []string
|
||||||
|
if err := json.Unmarshal([]byte(resp), &keywords); err == nil {
|
||||||
|
return keywords, nil
|
||||||
|
}
|
||||||
|
// fallback: try splitting by comma
|
||||||
|
for _, k := range bytes.Split([]byte(resp), []byte{','}) {
|
||||||
|
kw := strings.TrimSpace(string(k))
|
||||||
|
if kw != "" {
|
||||||
|
keywords = append(keywords, kw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return keywords, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisambiguateBestMatch calls LLM to pick best match from candidates
|
||||||
|
func (llm *LLMClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Reason) (string, error) {
|
||||||
|
entries, _ := json.Marshal(candidates)
|
||||||
|
prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt")
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")
|
||||||
|
resp, err := llm.openAICompletion(ctx, prompt)
|
||||||
|
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
id := strings.TrimSpace(resp)
|
||||||
|
if id == "none" || id == "null" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// openAICompletion is a minimal OpenAI API call (text-davinci-003 or gpt-3.5-turbo-instruct)
|
||||||
|
func (llm *LLMClient) openAICompletion(ctx context.Context, prompt string) (string, error) {
|
||||||
|
apiURL := llm.BaseURL
|
||||||
|
if apiURL == "" {
|
||||||
|
apiURL = "https://api.openai.com/v1/completions"
|
||||||
|
}
|
||||||
|
logrus.WithFields(logrus.Fields{"api_url": apiURL, "prompt": prompt}).Info("[LLM] openAICompletion POST")
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"model": "text-davinci-003",
|
||||||
|
"prompt": prompt,
|
||||||
|
"max_tokens": 64,
|
||||||
|
"temperature": 0,
|
||||||
|
}
|
||||||
|
jsonBody, _ := json.Marshal(body)
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonBody))
|
||||||
|
if llm.APIKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("[LLM] openAICompletion error")
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
var result struct {
|
||||||
|
Choices []struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"choices"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
|
logrus.WithError(err).Error("[LLM] openAICompletion decode error")
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if len(result.Choices) == 0 {
|
||||||
|
logrus.Warn("[LLM] openAICompletion: no choices returned")
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
logrus.WithField("text", result.Choices[0].Text).Info("[LLM] openAICompletion: got text")
|
||||||
|
return result.Choices[0].Text, nil
|
||||||
|
}
|
||||||
102
main.go
102
main.go
|
|
@ -3,7 +3,6 @@ package main
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
|
@ -56,12 +55,6 @@ type ChatResponse struct {
|
||||||
Notes string `json:"notes,omitempty"`
|
Notes string `json:"notes,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// LLMClient abstracts LLM API calls
|
|
||||||
type LLMClient struct {
|
|
||||||
APIKey string
|
|
||||||
BaseURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Config holds all prompts and settings
|
// Config holds all prompts and settings
|
||||||
type Config struct {
|
type Config struct {
|
||||||
LLM struct {
|
LLM struct {
|
||||||
|
|
@ -80,97 +73,6 @@ func loadConfig(path string) error {
|
||||||
return yaml.Unmarshal(data, &appConfig)
|
return yaml.Unmarshal(data, &appConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExtractKeywords calls LLM to extract keywords from user message
|
|
||||||
func (llm *LLMClient) ExtractKeywords(ctx context.Context, message string) ([]string, error) {
|
|
||||||
prompt, err := renderPrompt(appConfig.LLM.ExtractKeywordsPrompt, map[string]string{"Message": message})
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).Error("[CONFIG] Failed to render ExtractKeywords prompt")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
logrus.WithField("prompt", prompt).Info("[LLM] ExtractKeywords prompt")
|
|
||||||
resp, err := llm.openAICompletion(ctx, prompt)
|
|
||||||
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] ExtractKeywords response")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var keywords []string
|
|
||||||
if err := json.Unmarshal([]byte(resp), &keywords); err == nil {
|
|
||||||
return keywords, nil
|
|
||||||
}
|
|
||||||
// fallback: try splitting by comma
|
|
||||||
for _, k := range bytes.Split([]byte(resp), []byte{','}) {
|
|
||||||
kw := strings.TrimSpace(string(k))
|
|
||||||
if kw != "" {
|
|
||||||
keywords = append(keywords, kw)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return keywords, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DisambiguateBestMatch calls LLM to pick best match from candidates
|
|
||||||
func (llm *LLMClient) DisambiguateBestMatch(ctx context.Context, message string, candidates []Reason) (string, error) {
|
|
||||||
entries, _ := json.Marshal(candidates)
|
|
||||||
prompt, err := renderPrompt(appConfig.LLM.DisambiguatePrompt, map[string]string{"Entries": string(entries), "Message": message})
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).Error("[CONFIG] Failed to render Disambiguate prompt")
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
logrus.WithField("prompt", prompt).Info("[LLM] DisambiguateBestMatch prompt")
|
|
||||||
resp, err := llm.openAICompletion(ctx, prompt)
|
|
||||||
logrus.WithFields(logrus.Fields{"response": resp, "err": err}).Info("[LLM] DisambiguateBestMatch response")
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
id := strings.TrimSpace(resp)
|
|
||||||
if id == "none" || id == "null" {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
return id, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// openAICompletion is a minimal OpenAI API call (text-davinci-003 or gpt-3.5-turbo-instruct)
|
|
||||||
func (llm *LLMClient) openAICompletion(ctx context.Context, prompt string) (string, error) {
|
|
||||||
apiURL := llm.BaseURL
|
|
||||||
if apiURL == "" {
|
|
||||||
apiURL = "https://api.openai.com/v1/completions"
|
|
||||||
}
|
|
||||||
logrus.WithFields(logrus.Fields{"api_url": apiURL, "prompt": prompt}).Info("[LLM] openAICompletion POST")
|
|
||||||
body := map[string]interface{}{
|
|
||||||
"model": "text-davinci-003",
|
|
||||||
"prompt": prompt,
|
|
||||||
"max_tokens": 64,
|
|
||||||
"temperature": 0,
|
|
||||||
}
|
|
||||||
jsonBody, _ := json.Marshal(body)
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonBody))
|
|
||||||
if llm.APIKey != "" {
|
|
||||||
req.Header.Set("Authorization", "Bearer "+llm.APIKey)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
client := &http.Client{}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).Error("[LLM] openAICompletion error")
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
var result struct {
|
|
||||||
Choices []struct {
|
|
||||||
Text string `json:"text"`
|
|
||||||
} `json:"choices"`
|
|
||||||
}
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
||||||
logrus.WithError(err).Error("[LLM] openAICompletion decode error")
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if len(result.Choices) == 0 {
|
|
||||||
logrus.Warn("[LLM] openAICompletion: no choices returned")
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
logrus.WithField("text", result.Choices[0].Text).Info("[LLM] openAICompletion: got text")
|
|
||||||
return result.Choices[0].Text, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// naiveKeywordExtract splits message into lowercase words (placeholder for LLM)
|
// naiveKeywordExtract splits message into lowercase words (placeholder for LLM)
|
||||||
func naiveKeywordExtract(msg string) []string {
|
func naiveKeywordExtract(msg string) []string {
|
||||||
// TODO: Replace with LLM call
|
// TODO: Replace with LLM call
|
||||||
|
|
@ -257,7 +159,7 @@ func main() {
|
||||||
}
|
}
|
||||||
llm := &LLMClient{
|
llm := &LLMClient{
|
||||||
APIKey: os.Getenv("OPENAI_API_KEY"),
|
APIKey: os.Getenv("OPENAI_API_KEY"),
|
||||||
BaseURL: os.Getenv("OPENAI_BASE_URL"), // e.g. http://localhost:1234/v1/completions
|
BaseURL: os.Getenv("OPENAI_BASE_URL"),
|
||||||
}
|
}
|
||||||
r := gin.Default()
|
r := gin.Default()
|
||||||
r.GET("/", func(c *gin.Context) {
|
r.GET("/", func(c *gin.Context) {
|
||||||
|
|
@ -271,7 +173,7 @@ func main() {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx := c.Request.Context()
|
ctx := context.Background()
|
||||||
keywords, err := llm.ExtractKeywords(ctx, req.Message)
|
keywords, err := llm.ExtractKeywords(ctx, req.Message)
|
||||||
candidates := findCandidates(keywords)
|
candidates := findCandidates(keywords)
|
||||||
bestID := ""
|
bestID := ""
|
||||||
|
|
|
||||||
2
run.sh
2
run.sh
|
|
@ -1,3 +1,3 @@
|
||||||
export OPENAI_BASE_URL=http://localhost:1234/v1/completions
|
export OPENAI_BASE_URL=http://localhost:1234/v1/completions
|
||||||
export OPENAI_API_KEY=sk-no-key-needed # (if LM Studio doesn't require a real key)
|
export OPENAI_API_KEY=sk-no-key-needed # (if LM Studio doesn't require a real key)
|
||||||
go run main.go
|
go run .
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue