179 lines
4.2 KiB
Go
179 lines
4.2 KiB
Go
package workers
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"regexp"
|
|
"sync/atomic"
|
|
|
|
"git.savin.nyc/alex/go-receipt-tracker/bus"
|
|
"git.savin.nyc/alex/go-receipt-tracker/config"
|
|
"git.savin.nyc/alex/go-receipt-tracker/utils"
|
|
"github.com/google/generative-ai-go/genai"
|
|
"google.golang.org/api/option"
|
|
)
|
|
|
|
type Gemini struct {
|
|
id string
|
|
model string
|
|
opts *config.Parser
|
|
client *genai.Client
|
|
subscriptions map[string]chan bus.Event //
|
|
bus *bus.Bus
|
|
log *slog.Logger
|
|
cancel context.CancelFunc //
|
|
end uint32 // ensure the close methods are only called once
|
|
}
|
|
|
|
func NewGemini(cfg *config.Parser, b *bus.Bus) *Gemini {
|
|
ai := &Gemini{
|
|
id: cfg.Type, // "gemini",
|
|
model: cfg.Model, // [ "gemini-1.5-flash", "gemini-2.0-flash-preview-02-05" ]
|
|
opts: cfg,
|
|
subscriptions: make(map[string]chan bus.Event),
|
|
bus: b,
|
|
}
|
|
// model := client.GenerativeModel("gemini-1.5-flash")
|
|
|
|
return ai
|
|
}
|
|
|
|
func (ai *Gemini) ID() string {
|
|
return ai.id
|
|
}
|
|
|
|
func (ai *Gemini) Type() string {
|
|
return "parser"
|
|
}
|
|
|
|
func (ai *Gemini) Init(log *slog.Logger) error {
|
|
ai.log = log
|
|
|
|
ctx := context.Background()
|
|
client, err := genai.NewClient(ctx, option.WithAPIKey(ai.opts.ApiKey))
|
|
if err != nil {
|
|
log.Error("cannot", "error", err)
|
|
}
|
|
ai.client = client
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ai *Gemini) Serve() {
|
|
if atomic.LoadUint32(&ai.end) == 1 {
|
|
return
|
|
}
|
|
|
|
ai.subscribe("parser:" + ai.ID())
|
|
// ai.subscribe("parser:*")
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
ai.cancel = cancel
|
|
|
|
go ai.eventLoop(ctx)
|
|
}
|
|
|
|
func (ai *Gemini) OneTime() error {
|
|
return nil
|
|
}
|
|
|
|
func (ai *Gemini) Stop() {
|
|
ai.client.Close()
|
|
}
|
|
|
|
func (ai *Gemini) Close() {
|
|
}
|
|
|
|
func (ai *Gemini) subscribe(chn string) error {
|
|
s, err := ai.bus.Subscribe(chn, ai.Type()+":"+ai.ID())
|
|
if err != nil {
|
|
ai.log.Error("couldn't subscribe to a channel", "channel", chn, "error", err.Error())
|
|
return err
|
|
}
|
|
ai.subscriptions[chn] = s
|
|
|
|
return nil
|
|
}
|
|
|
|
// eventLoop loops forever
|
|
func (ai *Gemini) eventLoop(ctx context.Context) {
|
|
ai.log.Debug(ai.ID() + " communication event loop started")
|
|
defer ai.log.Debug(ai.ID() + " communication event loop halted")
|
|
|
|
for {
|
|
for chn, ch := range ai.subscriptions {
|
|
select {
|
|
case event := <-ch:
|
|
switch event.Payload.(type) {
|
|
case *bus.Message:
|
|
ai.log.Debug("got a new message to a channel", "channel", chn, "worker type", ai.Type(), "worker id", ai.ID())
|
|
|
|
msg := event.Payload.(*bus.Message)
|
|
res, err := ai.recognize(event.Payload.(*bus.Message).Image.Filename)
|
|
if err != nil {
|
|
ai.log.Error("got an error from parser ("+ai.ID()+")", "error", err)
|
|
}
|
|
|
|
msg.Image.Parsed[ai.ID()] = res
|
|
err = ai.bus.Publish("processor:receipt_add", msg)
|
|
if err != nil {
|
|
ai.log.Error("couldn't publish to a channel", "channel", "processor:receipt_add", "error", err.Error())
|
|
}
|
|
}
|
|
case <-ctx.Done():
|
|
ai.log.Info("stopping " + ai.ID() + " communication event loop")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ai *Gemini) recognize(img string) (res string, err error) {
|
|
ctx := context.Background()
|
|
|
|
r := regexp.MustCompile(`^\x60{3}(?:json)\n([\s\S]*?)\n*\x60{3}\n*$`)
|
|
|
|
file, err := ai.client.UploadFileFromPath(ctx, img, nil)
|
|
if err != nil {
|
|
ai.log.Error("cannot upload a receipt file", "Messages error: %v\n", err)
|
|
return "", err
|
|
}
|
|
defer ai.client.DeleteFile(ctx, file.Name)
|
|
|
|
gotFile, err := ai.client.GetFile(ctx, file.Name)
|
|
if err != nil {
|
|
ai.log.Error("cannot get uploaded file name", "error", err)
|
|
}
|
|
ai.log.Debug("successfully uploaded file", "filename", gotFile.Name)
|
|
|
|
model := ai.client.GenerativeModel("gemini-1.5-flash")
|
|
resp, err := model.GenerateContent(ctx,
|
|
genai.FileData{URI: file.URI},
|
|
genai.Text(config.Request))
|
|
if err != nil {
|
|
ai.log.Error("cannot recognize a receipt", "Messages error: %v\n", err)
|
|
}
|
|
|
|
for _, cand := range resp.Candidates {
|
|
if cand.Content != nil {
|
|
for _, part := range cand.Content.Parts {
|
|
res += fmt.Sprint(part)
|
|
}
|
|
}
|
|
}
|
|
|
|
if r.MatchString(res) {
|
|
res = r.FindStringSubmatch(res)[1]
|
|
}
|
|
|
|
if !utils.IsJSON(res) {
|
|
ai.log.Error("Gemini returned not valid JSON", "%+v", res)
|
|
return "", err
|
|
}
|
|
|
|
ai.log.Debug("recognition result", "worker", ai.ID(), "output", res)
|
|
|
|
return res, nil
|
|
}
|