From 5dadbcb53d894471f80ce0c0f54c1a8a338033b7 Mon Sep 17 00:00:00 2001 From: Jonas Hess Date: Thu, 31 Oct 2024 20:00:43 +0100 Subject: [PATCH 1/2] feat: improve auto throughput & logging --- README.md | 3 ++ app_http_handlers.go | 17 +++++----- app_llm.go | 11 +++--- go.mod | 3 +- go.sum | 3 ++ jobs.go | 34 +++++++++++++------ main.go | 81 +++++++++++++++++++++++++++++++++----------- paperless.go | 13 ++++--- 8 files changed, 113 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index ba2ec68..c829818 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,7 @@ services: OLLAMA_HOST: 'http://host.docker.internal:11434' # If using Ollama VISION_LLM_PROVIDER: 'ollama' # Optional, for OCR VISION_LLM_MODEL: 'minicpm-v' # Optional, for OCR + LOG_LEVEL: 'info' # Optional or 'debug', 'warn', 'error' volumes: - ./prompts:/app/prompts # Mount the prompts directory ports: @@ -122,6 +123,7 @@ If you prefer to run the application manually: -e LLM_LANGUAGE='English' \ -e VISION_LLM_PROVIDER='ollama' \ -e VISION_LLM_MODEL='minicpm-v' \ + -e LOG_LEVEL='info' \ -v $(pwd)/prompts:/app/prompts \ # Mount the prompts directory -p 8080:8080 \ paperless-gpt @@ -142,6 +144,7 @@ If you prefer to run the application manually: | `OLLAMA_HOST` | The URL of the Ollama server (e.g., `http://host.docker.internal:11434`). Useful if using Ollama. Default is `http://127.0.0.1:11434`. | No | | `VISION_LLM_PROVIDER` | The vision LLM provider to use for OCR (`openai` or `ollama`). | No | | `VISION_LLM_MODEL` | The model name to use for OCR (e.g., `minicpm-v`). | No | +| `LOG_LEVEL` | The log level for the application (`info`, `debug`, `warn`, `error`). Default is `info`. | No | **Note:** When using Ollama, ensure that the Ollama server is running and accessible from the paperless-gpt container. 
diff --git a/app_http_handlers.go b/app_http_handlers.go index c7ab6af..c9c243f 100644 --- a/app_http_handlers.go +++ b/app_http_handlers.go @@ -2,7 +2,6 @@ package main import ( "fmt" - "log" "net/http" "os" "strconv" @@ -59,7 +58,7 @@ func updatePromptsHandler(c *gin.Context) { titleTemplate = t err = os.WriteFile("prompts/title_prompt.tmpl", []byte(req.TitleTemplate), 0644) if err != nil { - log.Printf("Failed to write title_prompt.tmpl: %v", err) + log.Errorf("Failed to write title_prompt.tmpl: %v", err) } } @@ -73,7 +72,7 @@ func updatePromptsHandler(c *gin.Context) { tagTemplate = t err = os.WriteFile("prompts/tag_prompt.tmpl", []byte(req.TagTemplate), 0644) if err != nil { - log.Printf("Failed to write tag_prompt.tmpl: %v", err) + log.Errorf("Failed to write tag_prompt.tmpl: %v", err) } } @@ -87,7 +86,7 @@ func (app *App) getAllTagsHandler(c *gin.Context) { tags, err := app.Client.GetAllTags(ctx) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error fetching tags: %v", err)}) - log.Printf("Error fetching tags: %v", err) + log.Errorf("Error fetching tags: %v", err) return } @@ -101,7 +100,7 @@ func (app *App) documentsHandler(c *gin.Context) { documents, err := app.Client.GetDocumentsByTags(ctx, []string{manualTag}) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error fetching documents: %v", err)}) - log.Printf("Error fetching documents: %v", err) + log.Errorf("Error fetching documents: %v", err) return } @@ -115,14 +114,14 @@ func (app *App) generateSuggestionsHandler(c *gin.Context) { var suggestionRequest GenerateSuggestionsRequest if err := c.ShouldBindJSON(&suggestionRequest); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)}) - log.Printf("Invalid request payload: %v", err) + log.Errorf("Invalid request payload: %v", err) return } results, err := app.generateDocumentSuggestions(ctx, suggestionRequest) if err != nil { 
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error processing documents: %v", err)}) - log.Printf("Error processing documents: %v", err) + log.Errorf("Error processing documents: %v", err) return } @@ -135,14 +134,14 @@ func (app *App) updateDocumentsHandler(c *gin.Context) { var documents []DocumentSuggestion if err := c.ShouldBindJSON(&documents); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)}) - log.Printf("Invalid request payload: %v", err) + log.Errorf("Invalid request payload: %v", err) return } err := app.Client.UpdateDocuments(ctx, documents) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error updating documents: %v", err)}) - log.Printf("Error updating documents: %v", err) + log.Errorf("Error updating documents: %v", err) return } diff --git a/app_llm.go b/app_llm.go index 15fb79d..b7228f6 100644 --- a/app_llm.go +++ b/app_llm.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "fmt" - "log" "strings" "sync" @@ -30,7 +29,7 @@ func (app *App) getSuggestedTags(ctx context.Context, content string, suggestedT } prompt := promptBuffer.String() - log.Printf("Tag suggestion prompt: %s", prompt) + log.Debugf("Tag suggestion prompt: %s", prompt) completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{ { @@ -119,7 +118,7 @@ func (app *App) getSuggestedTitle(ctx context.Context, content string) (string, prompt := promptBuffer.String() - log.Printf("Title suggestion prompt: %s", prompt) + log.Debugf("Title suggestion prompt: %s", prompt) completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{ { @@ -183,7 +182,7 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque mu.Lock() errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err)) mu.Unlock() - log.Printf("Error processing document %d: %v", documentID, err) + log.Errorf("Error processing document %d: %v", documentID, err) 
return } } @@ -194,7 +193,7 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque mu.Lock() errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err)) mu.Unlock() - log.Printf("Error generating tags for document %d: %v", documentID, err) + log.Errorf("Error generating tags for document %d: %v", documentID, err) return } } @@ -206,6 +205,7 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque } // Titles if suggestionRequest.GenerateTitles { + log.Printf("Suggested title for document %d: %s", documentID, suggestedTitle) suggestion.SuggestedTitle = suggestedTitle } else { suggestion.SuggestedTitle = doc.Title @@ -213,6 +213,7 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque // Tags if suggestionRequest.GenerateTags { + log.Printf("Suggested tags for document %d: %v", documentID, suggestedTags) suggestion.SuggestedTags = suggestedTags } else { suggestion.SuggestedTags = removeTagFromList(doc.Tags, manualTag) diff --git a/go.mod b/go.mod index 6825cd4..babd910 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,8 @@ require ( github.com/Masterminds/sprig/v3 v3.2.3 github.com/gen2brain/go-fitz v1.24.14 github.com/gin-gonic/gin v1.10.0 + github.com/google/uuid v1.6.0 + github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.9.0 github.com/tmc/langchaingo v0.1.12 golang.org/x/sync v0.7.0 @@ -29,7 +31,6 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/goccy/go-json v0.10.2 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/huandu/xstrings v1.3.3 // indirect github.com/imdario/mergo v0.3.13 // indirect github.com/json-iterator/go v1.1.12 // indirect diff --git a/go.sum b/go.sum index d4dfaab..76d584c 100644 --- a/go.sum +++ b/go.sum @@ -77,6 +77,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb 
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -123,6 +125,7 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/jobs.go b/jobs.go index a2635a4..bc58b82 100644 --- a/jobs.go +++ b/jobs.go @@ -3,7 +3,6 @@ package main import ( "context" "fmt" - "log" "os" "sort" "strings" @@ -11,6 +10,7 @@ import ( "time" "github.com/google/uuid" + "github.com/sirupsen/logrus" ) // Job represents an OCR job @@ -31,13 +31,25 @@ type JobStore struct { } var ( + logger = logrus.New() + jobStore = &JobStore{ jobs: make(map[string]*Job), } jobQueue = make(chan *Job, 100) // Buffered channel with capacity of 100 jobs - logger = 
log.New(os.Stdout, "OCR_JOB: ", log.LstdFlags) ) +func init() { + + // Initialize logger + logger.SetOutput(os.Stdout) + logger.SetFormatter(&logrus.TextFormatter{ + FullTimestamp: true, + }) + logger.SetLevel(logrus.InfoLevel) + logger.WithField("prefix", "OCR_JOB") +} + func generateJobID() string { return uuid.New().String() } @@ -47,7 +59,7 @@ func (store *JobStore) addJob(job *Job) { defer store.Unlock() job.PagesDone = 0 // Initialize PagesDone to 0 store.jobs[job.ID] = job - logger.Printf("Job added: %v", job) + logger.Infof("Job added: %v", job) } func (store *JobStore) getJob(jobID string) (*Job, bool) { @@ -82,7 +94,7 @@ func (store *JobStore) updateJobStatus(jobID, status, result string) { job.Result = result } job.UpdatedAt = time.Now() - logger.Printf("Job status updated: %v", job) + logger.Infof("Job status updated: %v", job) } } @@ -92,16 +104,16 @@ func (store *JobStore) updatePagesDone(jobID string, pagesDone int) { if job, exists := store.jobs[jobID]; exists { job.PagesDone = pagesDone job.UpdatedAt = time.Now() - logger.Printf("Job pages done updated: %v", job) + logger.Infof("Job pages done updated: %v", job) } } func startWorkerPool(app *App, numWorkers int) { for i := 0; i < numWorkers; i++ { go func(workerID int) { - logger.Printf("Worker %d started", workerID) + logger.Infof("Worker %d started", workerID) for job := range jobQueue { - logger.Printf("Worker %d processing job: %s", workerID, job.ID) + logger.Infof("Worker %d processing job: %s", workerID, job.ID) processJob(app, job) } }(i) @@ -116,7 +128,7 @@ func processJob(app *App, job *Job) { // Download images of the document imagePaths, err := app.Client.DownloadDocumentAsImages(ctx, job.DocumentID) if err != nil { - logger.Printf("Error downloading document images for job %s: %v", job.ID, err) + logger.Infof("Error downloading document images for job %s: %v", job.ID, err) jobStore.updateJobStatus(job.ID, "failed", fmt.Sprintf("Error downloading document images: %v", err)) return } @@ 
-125,14 +137,14 @@ func processJob(app *App, job *Job) { for i, imagePath := range imagePaths { imageContent, err := os.ReadFile(imagePath) if err != nil { - logger.Printf("Error reading image file for job %s: %v", job.ID, err) + logger.Errorf("Error reading image file for job %s: %v", job.ID, err) jobStore.updateJobStatus(job.ID, "failed", fmt.Sprintf("Error reading image file: %v", err)) return } ocrText, err := app.doOCRViaLLM(ctx, imageContent) if err != nil { - logger.Printf("Error performing OCR for job %s: %v", job.ID, err) + logger.Errorf("Error performing OCR for job %s: %v", job.ID, err) jobStore.updateJobStatus(job.ID, "failed", fmt.Sprintf("Error performing OCR: %v", err)) return } @@ -146,5 +158,5 @@ func processJob(app *App, job *Job) { // Update job status and result jobStore.updateJobStatus(job.ID, "completed", fullOcrText) - logger.Printf("Job completed: %s", job.ID) + logger.Infof("Job completed: %s", job.ID) } diff --git a/main.go b/main.go index 234b5a5..5c8d3ef 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package main import ( "context" "fmt" - "log" "net/http" "os" "path/filepath" @@ -14,6 +13,7 @@ import ( "github.com/Masterminds/sprig/v3" "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/ollama" "github.com/tmc/langchaingo/llms/openai" @@ -21,6 +21,11 @@ import ( // Global Variables and Constants var ( + + // Logger + log = logrus.New() + + // Environment Variables paperlessBaseURL = os.Getenv("PAPERLESS_BASE_URL") paperlessAPIToken = os.Getenv("PAPERLESS_API_TOKEN") openaiAPIKey = os.Getenv("OPENAI_API_KEY") @@ -30,6 +35,7 @@ var ( llmModel = os.Getenv("LLM_MODEL") visionLlmProvider = os.Getenv("VISION_LLM_PROVIDER") visionLlmModel = os.Getenv("VISION_LLM_MODEL") + logLevel = strings.ToLower(os.Getenv("LOG_LEVEL")) // Templates titleTemplate *template.Template @@ -75,6 +81,9 @@ func main() { // Validate Environment Variables validateEnvVars() + // Initialize 
logrus logger + initLogger() + // Initialize PaperlessClient client := NewPaperlessClient(paperlessBaseURL, paperlessAPIToken) @@ -102,25 +111,28 @@ func main() { // Start background process for auto-tagging go func() { - - minBackoffDuration := time.Second + minBackoffDuration := 10 * time.Second maxBackoffDuration := time.Hour pollingInterval := 10 * time.Second backoffDuration := minBackoffDuration for { - if err := app.processAutoTagDocuments(); err != nil { - log.Printf("Error in processAutoTagDocuments: %v", err) + processedCount, err := app.processAutoTagDocuments() + if err != nil { + log.Errorf("Error in processAutoTagDocuments: %v", err) time.Sleep(backoffDuration) backoffDuration *= 2 // Exponential backoff if backoffDuration > maxBackoffDuration { - log.Printf("Repeated errors in processAutoTagDocuments detected. Setting backoff to %v", maxBackoffDuration) + log.Warnf("Repeated errors in processAutoTagDocuments detected. Setting backoff to %v", maxBackoffDuration) backoffDuration = maxBackoffDuration } } else { backoffDuration = minBackoffDuration } - time.Sleep(pollingInterval) + + if processedCount == 0 { + time.Sleep(pollingInterval) + } } }() @@ -168,12 +180,34 @@ func main() { numWorkers := 1 // Number of workers to start startWorkerPool(app, numWorkers) - log.Println("Server started on port :8080") + log.Infoln("Server started on port :8080") if err := router.Run(":8080"); err != nil { log.Fatalf("Failed to run server: %v", err) } } +func initLogger() { + switch logLevel { + case "debug": + log.SetLevel(logrus.DebugLevel) + case "info": + log.SetLevel(logrus.InfoLevel) + case "warn": + log.SetLevel(logrus.WarnLevel) + case "error": + log.SetLevel(logrus.ErrorLevel) + default: + log.SetLevel(logrus.InfoLevel) + if logLevel != "" { + log.Fatalf("Invalid log level: '%s'.", logLevel) + } + } + + log.SetFormatter(&logrus.TextFormatter{ + FullTimestamp: true, + }) +} + func isOcrEnabled() bool { return visionLlmModel != "" && visionLlmProvider != "" } 
@@ -192,28 +226,37 @@ func validateEnvVars() { log.Fatal("Please set the LLM_PROVIDER environment variable.") } + if visionLlmProvider != "" && visionLlmProvider != "openai" && visionLlmProvider != "ollama" { + log.Fatal("Please set the LLM_PROVIDER environment variable to 'openai' or 'ollama'.") + } + if llmModel == "" { log.Fatal("Please set the LLM_MODEL environment variable.") } - if llmProvider == "openai" && openaiAPIKey == "" { + if (llmProvider == "openai" || visionLlmProvider == "openai") && openaiAPIKey == "" { log.Fatal("Please set the OPENAI_API_KEY environment variable for OpenAI provider.") } } // processAutoTagDocuments handles the background auto-tagging of documents -func (app *App) processAutoTagDocuments() error { +func (app *App) processAutoTagDocuments() (int, error) { ctx := context.Background() documents, err := app.Client.GetDocumentsByTags(ctx, []string{autoTag}) if err != nil { - return fmt.Errorf("error fetching documents with autoTag: %w", err) + return 0, fmt.Errorf("error fetching documents with autoTag: %w", err) } if len(documents) == 0 { - return nil // No documents to process + log.Debugf("No documents with tag %s found", autoTag) + return 0, nil // No documents to process } + log.Debugf("Found at least %d remaining documents with tag %s", len(documents), autoTag) + + documents = documents[:1] // Process only one document at a time + suggestionRequest := GenerateSuggestionsRequest{ Documents: documents, GenerateTitles: true, @@ -222,15 +265,15 @@ func (app *App) processAutoTagDocuments() error { suggestions, err := app.generateDocumentSuggestions(ctx, suggestionRequest) if err != nil { - return fmt.Errorf("error generating suggestions: %w", err) + return 0, fmt.Errorf("error generating suggestions: %w", err) } err = app.Client.UpdateDocuments(ctx, suggestions) if err != nil { - return fmt.Errorf("error updating documents: %w", err) + return 0, fmt.Errorf("error updating documents: %w", err) } - return nil + return len(documents), 
nil } // removeTagFromList removes a specific tag from a list of tags @@ -268,7 +311,7 @@ func loadTemplates() { titleTemplatePath := filepath.Join(promptsDir, "title_prompt.tmpl") titleTemplateContent, err := os.ReadFile(titleTemplatePath) if err != nil { - log.Printf("Could not read %s, using default template: %v", titleTemplatePath, err) + log.Errorf("Could not read %s, using default template: %v", titleTemplatePath, err) titleTemplateContent = []byte(defaultTitleTemplate) if err := os.WriteFile(titleTemplatePath, titleTemplateContent, os.ModePerm); err != nil { log.Fatalf("Failed to write default title template to disk: %v", err) @@ -283,7 +326,7 @@ func loadTemplates() { tagTemplatePath := filepath.Join(promptsDir, "tag_prompt.tmpl") tagTemplateContent, err := os.ReadFile(tagTemplatePath) if err != nil { - log.Printf("Could not read %s, using default template: %v", tagTemplatePath, err) + log.Errorf("Could not read %s, using default template: %v", tagTemplatePath, err) tagTemplateContent = []byte(defaultTagTemplate) if err := os.WriteFile(tagTemplatePath, tagTemplateContent, os.ModePerm); err != nil { log.Fatalf("Failed to write default tag template to disk: %v", err) @@ -298,7 +341,7 @@ func loadTemplates() { ocrTemplatePath := filepath.Join(promptsDir, "ocr_prompt.tmpl") ocrTemplateContent, err := os.ReadFile(ocrTemplatePath) if err != nil { - log.Printf("Could not read %s, using default template: %v", ocrTemplatePath, err) + log.Errorf("Could not read %s, using default template: %v", ocrTemplatePath, err) ocrTemplateContent = []byte(defaultOcrPrompt) if err := os.WriteFile(ocrTemplatePath, ocrTemplateContent, os.ModePerm); err != nil { log.Fatalf("Failed to write default OCR template to disk: %v", err) @@ -355,7 +398,7 @@ func createVisionLLM() (llms.Model, error) { ollama.WithServerURL(host), ) default: - log.Printf("No Vision LLM provider created: %s", visionLlmProvider) + log.Infoln("Vision LLM not enabled") return nil, nil } } diff --git a/paperless.go 
b/paperless.go index b477324..18ccf52 100644 --- a/paperless.go +++ b/paperless.go @@ -7,7 +7,6 @@ import ( "fmt" "image/jpeg" "io" - "log" "net/http" "os" "path/filepath" @@ -223,7 +222,7 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum // Fetch all available tags availableTags, err := c.GetAllTags(ctx) if err != nil { - log.Printf("Error fetching available tags: %v", err) + log.Errorf("Error fetching available tags: %v", err) return err } @@ -249,7 +248,7 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum } newTags = append(newTags, tagID) } else { - log.Printf("Tag '%s' does not exist in paperless-ngx, skipping.", tagName) + log.Warnf("Tag '%s' does not exist in paperless-ngx, skipping.", tagName) } } @@ -262,7 +261,7 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum if suggestedTitle != "" { updatedFields["title"] = suggestedTitle } else { - log.Printf("No valid title found for document %d, skipping.", documentID) + log.Warnf("No valid title found for document %d, skipping.", documentID) } // Suggested Content @@ -274,7 +273,7 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum // Marshal updated fields to JSON jsonData, err := json.Marshal(updatedFields) if err != nil { - log.Printf("Error marshalling JSON for document %d: %v", documentID, err) + log.Errorf("Error marshalling JSON for document %d: %v", documentID, err) return err } @@ -282,14 +281,14 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum path := fmt.Sprintf("api/documents/%d/", documentID) resp, err := c.Do(ctx, "PATCH", path, bytes.NewBuffer(jsonData)) if err != nil { - log.Printf("Error updating document %d: %v", documentID, err) + log.Errorf("Error updating document %d: %v", documentID, err) return err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) - log.Printf("Error updating 
document %d: %d, %s", documentID, resp.StatusCode, string(bodyBytes)) + log.Errorf("Error updating document %d: %d, %s", documentID, resp.StatusCode, string(bodyBytes)) return fmt.Errorf("error updating document %d: %d, %s", documentID, resp.StatusCode, string(bodyBytes)) } From 16281be6d353780c84037faff02678c09c3a7fff Mon Sep 17 00:00:00 2001 From: Jonas Hess Date: Thu, 31 Oct 2024 22:37:35 +0100 Subject: [PATCH 2/2] feat: auto create missing correspondents --- README.md | 25 ++++--- app_http_handlers.go | 2 +- app_llm.go | 74 ++++++++++++++++++ main.go | 84 ++++++++++++++++----- paperless.go | 174 +++++++++++++++++++++++++++++++++++-------- paperless_test.go | 4 +- types.go | 36 +++++++-- 7 files changed, 327 insertions(+), 72 deletions(-) diff --git a/README.md b/README.md index c829818..f81241b 100644 --- a/README.md +++ b/README.md @@ -133,18 +133,19 @@ If you prefer to run the application manually: ### Environment Variables -| Variable | Description | Required | -|-----------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------|----------| -| `PAPERLESS_BASE_URL` | The base URL of your paperless-ngx instance (e.g., `http://paperless-ngx:8000`). | Yes | -| `PAPERLESS_API_TOKEN` | API token for accessing paperless-ngx. You can generate one in the paperless-ngx admin interface. | Yes | -| `LLM_PROVIDER` | The LLM provider to use (`openai` or `ollama`). | Yes | -| `LLM_MODEL` | The model name to use (e.g., `gpt-4o`, `gpt-3.5-turbo`, `llama2`). | Yes | -| `OPENAI_API_KEY` | Your OpenAI API key. Required if using OpenAI as the LLM provider. | Cond. | -| `LLM_LANGUAGE` | The likely language of your documents (e.g., `English`, `German`). Default is `English`. | No | -| `OLLAMA_HOST` | The URL of the Ollama server (e.g., `http://host.docker.internal:11434`). Useful if using Ollama. Default is `http://127.0.0.1:11434`. 
| No | -| `VISION_LLM_PROVIDER` | The vision LLM provider to use for OCR (`openai` or `ollama`). | No | -| `VISION_LLM_MODEL` | The model name to use for OCR (e.g., `minicpm-v`). | No | -| `LOG_LEVEL` | The log level for the application (`info`, `debug`, `warn`, `error`). Default is `info`. | No | +| Variable | Description | Required | +|----------------------------|----------------------------------------------------------------------------------------------------------------------------------------|----------| +| `PAPERLESS_BASE_URL` | The base URL of your paperless-ngx instance (e.g., `http://paperless-ngx:8000`). | Yes | +| `PAPERLESS_API_TOKEN` | API token for accessing paperless-ngx. You can generate one in the paperless-ngx admin interface. | Yes | +| `LLM_PROVIDER` | The LLM provider to use (`openai` or `ollama`). | Yes | +| `LLM_MODEL` | The model name to use (e.g., `gpt-4o`, `gpt-3.5-turbo`, `llama2`). | Yes | +| `OPENAI_API_KEY` | Your OpenAI API key. Required if using OpenAI as the LLM provider. | Cond. | +| `LLM_LANGUAGE` | The likely language of your documents (e.g., `English`, `German`). Default is `English`. | No | +| `OLLAMA_HOST` | The URL of the Ollama server (e.g., `http://host.docker.internal:11434`). Useful if using Ollama. Default is `http://127.0.0.1:11434`. | No | +| `VISION_LLM_PROVIDER` | The vision LLM provider to use for OCR (`openai` or `ollama`). | No | +| `VISION_LLM_MODEL` | The model name to use for OCR (e.g., `minicpm-v`). | No | +| `LOG_LEVEL` | The log level for the application (`info`, `debug`, `warn`, `error`). Default is `info`. | No | +| `CORRESPONDENT_BLACK_LIST` | A comma-separated list of names to exclude from the correspondents suggestions. Example: `John Doe, Jane Smith`. | No | **Note:** When using Ollama, ensure that the Ollama server is running and accessible from the paperless-gpt container. 
diff --git a/app_http_handlers.go b/app_http_handlers.go index c9c243f..5ca600b 100644 --- a/app_http_handlers.go +++ b/app_http_handlers.go @@ -97,7 +97,7 @@ func (app *App) getAllTagsHandler(c *gin.Context) { func (app *App) documentsHandler(c *gin.Context) { ctx := c.Request.Context() - documents, err := app.Client.GetDocumentsByTags(ctx, []string{manualTag}) + documents, err := app.Client.GetDocumentsByTags(ctx, []string{manualTag}, 25) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error fetching documents: %v", err)}) log.Errorf("Error fetching documents: %v", err) diff --git a/app_llm.go b/app_llm.go index b7228f6..ec2cb59 100644 --- a/app_llm.go +++ b/app_llm.go @@ -10,6 +10,46 @@ import ( "github.com/tmc/langchaingo/llms" ) +// getSuggestedCorrespondent generates a suggested correspondent for a document using the LLM +func (app *App) getSuggestedCorrespondent(ctx context.Context, content string, suggestedTitle string, availableCorrespondents []string, correspondentBlackList []string) (string, error) { + likelyLanguage := getLikelyLanguage() + + templateMutex.RLock() + defer templateMutex.RUnlock() + + var promptBuffer bytes.Buffer + err := correspondentTemplate.Execute(&promptBuffer, map[string]interface{}{ + "Language": likelyLanguage, + "AvailableCorrespondents": availableCorrespondents, + "BlackList": correspondentBlackList, + "Title": suggestedTitle, + "Content": content, + }) + if err != nil { + return "", fmt.Errorf("error executing correspondent template: %v", err) + } + + prompt := promptBuffer.String() + log.Debugf("Correspondent suggestion prompt: %s", prompt) + + completion, err := app.LLM.GenerateContent(ctx, []llms.MessageContent{ + { + Parts: []llms.ContentPart{ + llms.TextContent{ + Text: prompt, + }, + }, + Role: llms.ChatMessageTypeHuman, + }, + }) + if err != nil { + return "", fmt.Errorf("error getting response from LLM: %v", err) + } + + response := strings.TrimSpace(completion.Choices[0].Content) + 
return response, nil +} + // getSuggestedTags generates suggested tags for a document using the LLM func (app *App) getSuggestedTags(ctx context.Context, content string, suggestedTitle string, availableTags []string) ([]string, error) { likelyLanguage := getLikelyLanguage() @@ -154,6 +194,18 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque availableTagNames = append(availableTagNames, tagName) } + // Prepare a list of document correspondents + availableCorrespondentsMap, err := app.Client.GetAllCorrespondents(ctx) + if err != nil { + return nil, fmt.Errorf("failed to fetch available correspondents: %v", err) + } + + // Prepare a list of correspondent names + availableCorrespondentNames := make([]string, 0, len(availableCorrespondentsMap)) + for correspondentName := range availableCorrespondentsMap { + availableCorrespondentNames = append(availableCorrespondentNames, correspondentName) + } + documents := suggestionRequest.Documents documentSuggestions := []DocumentSuggestion{} @@ -175,6 +227,7 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque var suggestedTitle string var suggestedTags []string + var suggestedCorrespondent string if suggestionRequest.GenerateTitles { suggestedTitle, err = app.getSuggestedTitle(ctx, content) @@ -198,6 +251,18 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, suggestionReque } } + if suggestionRequest.GenerateCorrespondents { + suggestedCorrespondent, err = app.getSuggestedCorrespondent(ctx, content, suggestedTitle, availableCorrespondentNames, correspondentBlackList) + if err != nil { + mu.Lock() + errorsList = append(errorsList, fmt.Errorf("Document %d: %v", documentID, err)) + mu.Unlock() + log.Errorf("Error generating correspondents for document %d: %v", documentID, err) + return + } + + } + mu.Lock() suggestion := DocumentSuggestion{ ID: documentID, @@ -218,6 +283,15 @@ func (app *App) generateDocumentSuggestions(ctx context.Context, 
suggestionReque } else { suggestion.SuggestedTags = removeTagFromList(doc.Tags, manualTag) } + + // Correspondents + if suggestionRequest.GenerateCorrespondents { + log.Printf("Suggested correspondent for document %d: %s", documentID, suggestedCorrespondent) + suggestion.SuggestedCorrespondent = suggestedCorrespondent + } else { + suggestion.SuggestedCorrespondent = "" + } + documentSuggestions = append(documentSuggestions, suggestion) mu.Unlock() log.Printf("Document %d processed successfully.", documentID) diff --git a/main.go b/main.go index 5c8d3ef..506ecdf 100644 --- a/main.go +++ b/main.go @@ -26,22 +26,24 @@ var ( log = logrus.New() // Environment Variables - paperlessBaseURL = os.Getenv("PAPERLESS_BASE_URL") - paperlessAPIToken = os.Getenv("PAPERLESS_API_TOKEN") - openaiAPIKey = os.Getenv("OPENAI_API_KEY") - manualTag = "paperless-gpt" - autoTag = "paperless-gpt-auto" - llmProvider = os.Getenv("LLM_PROVIDER") - llmModel = os.Getenv("LLM_MODEL") - visionLlmProvider = os.Getenv("VISION_LLM_PROVIDER") - visionLlmModel = os.Getenv("VISION_LLM_MODEL") - logLevel = strings.ToLower(os.Getenv("LOG_LEVEL")) + paperlessBaseURL = os.Getenv("PAPERLESS_BASE_URL") + paperlessAPIToken = os.Getenv("PAPERLESS_API_TOKEN") + openaiAPIKey = os.Getenv("OPENAI_API_KEY") + manualTag = "paperless-gpt" + autoTag = "paperless-gpt-auto" + llmProvider = os.Getenv("LLM_PROVIDER") + llmModel = os.Getenv("LLM_MODEL") + visionLlmProvider = os.Getenv("VISION_LLM_PROVIDER") + visionLlmModel = os.Getenv("VISION_LLM_MODEL") + logLevel = strings.ToLower(os.Getenv("LOG_LEVEL")) + correspondentBlackList = strings.Split(os.Getenv("CORRESPONDENT_BLACK_LIST"), ",") // Templates - titleTemplate *template.Template - tagTemplate *template.Template - ocrTemplate *template.Template - templateMutex sync.RWMutex + titleTemplate *template.Template + tagTemplate *template.Template + correspondentTemplate *template.Template + ocrTemplate *template.Template + templateMutex sync.RWMutex // Default templates 
defaultTitleTemplate = `I will provide you with the content of a document that has been partially read by OCR (so it may contain errors). @@ -65,6 +67,34 @@ Content: Please concisely select the {{.Language}} tags from the list above that best describe the document. Be very selective and only choose the most relevant tags since too many tags will make the document less discoverable. +` + + defaultCorrespondentTemplate = `I will provide you with the content of a document. Your task is to suggest a correspondent that is most relevant to the document. + +Correspondents are the senders of documents that reach you. In the other direction, correspondents are the recipients of documents that you send. +In Paperless-ngx we can imagine correspondents as virtual drawers in which all documents of a person or company are stored. With just one click, we can find all the documents assigned to a specific correspondent. +Try to suggest a correspondent, either from the example list or come up with a new correspondent. + +Respond only with a correspondent, without any additional information! + +Be sure to choose a correspondent that is most relevant to the document. +Try to avoid any legal or financial suffixes like "GmbH" or "AG" in the correspondent name. For example use "Microsoft" instead of "Microsoft Ireland Operations Limited" or "Amazon" instead of "Amazon EU S.a.r.l.". + +If you can't find a suitable correspondent, you can respond with "Unknown". + +Example Correspondents: +{{.AvailableCorrespondents | join ", "}} + +List of Correspondents with Blacklisted Names. Please avoid these correspondents or variations of their names: +{{.BlackList | join ", "}} + +Title of the document: +{{.Title}} + +The content is likely in {{.Language}}. + +Document Content: +{{.Content}} ` defaultOcrPrompt = `Just transcribe the text in this image and preserve the formatting and layout (high quality OCR). Do that for ALL the text in the image. Be thorough and pay attention. 
This is very important. The image is from a text document so be sure to continue until the bottom of the page. Thanks a lot! You tend to forget about some text in the image so please focus! Use markdown format.` @@ -243,7 +273,7 @@ func validateEnvVars() { func (app *App) processAutoTagDocuments() (int, error) { ctx := context.Background() - documents, err := app.Client.GetDocumentsByTags(ctx, []string{autoTag}) + documents, err := app.Client.GetDocumentsByTags(ctx, []string{autoTag}, 1) if err != nil { return 0, fmt.Errorf("error fetching documents with autoTag: %w", err) } @@ -255,12 +285,11 @@ func (app *App) processAutoTagDocuments() (int, error) { log.Debugf("Found at least %d remaining documents with tag %s", len(documents), autoTag) - documents = documents[:1] // Process only one document at a time - suggestionRequest := GenerateSuggestionsRequest{ - Documents: documents, - GenerateTitles: true, - GenerateTags: true, + Documents: documents, + GenerateTitles: true, + GenerateTags: true, + GenerateCorrespondents: true, } suggestions, err := app.generateDocumentSuggestions(ctx, suggestionRequest) @@ -337,6 +366,21 @@ func loadTemplates() { log.Fatalf("Failed to parse tag template: %v", err) } + // Load correspondent template + correspondentTemplatePath := filepath.Join(promptsDir, "correspondent_prompt.tmpl") + correspondentTemplateContent, err := os.ReadFile(correspondentTemplatePath) + if err != nil { + log.Errorf("Could not read %s, using default template: %v", correspondentTemplatePath, err) + correspondentTemplateContent = []byte(defaultCorrespondentTemplate) + if err := os.WriteFile(correspondentTemplatePath, correspondentTemplateContent, os.ModePerm); err != nil { + log.Fatalf("Failed to write default correspondent template to disk: %v", err) + } + } + correspondentTemplate, err = template.New("correspondent").Funcs(sprig.FuncMap()).Parse(string(correspondentTemplateContent)) + if err != nil { + log.Fatalf("Failed to parse correspondent template: %v", 
err) + } + // Load OCR template ocrTemplatePath := filepath.Join(promptsDir, "ocr_prompt.tmpl") ocrTemplateContent, err := os.ReadFile(ocrTemplatePath) diff --git a/paperless.go b/paperless.go index 18ccf52..2f26057 100644 --- a/paperless.go +++ b/paperless.go @@ -39,29 +39,29 @@ func NewPaperlessClient(baseURL, apiToken string) *PaperlessClient { } // Do method to make requests to the Paperless-NGX API -func (c *PaperlessClient) Do(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { - url := fmt.Sprintf("%s/%s", c.BaseURL, strings.TrimLeft(path, "/")) +func (client *PaperlessClient) Do(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { + url := fmt.Sprintf("%s/%s", client.BaseURL, strings.TrimLeft(path, "/")) req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { return nil, err } - req.Header.Set("Authorization", fmt.Sprintf("Token %s", c.APIToken)) + req.Header.Set("Authorization", fmt.Sprintf("Token %s", client.APIToken)) // Set Content-Type if body is present if body != nil { req.Header.Set("Content-Type", "application/json") } - return c.HTTPClient.Do(req) + return client.HTTPClient.Do(req) } // GetAllTags retrieves all tags from the Paperless-NGX API -func (c *PaperlessClient) GetAllTags(ctx context.Context) (map[string]int, error) { +func (client *PaperlessClient) GetAllTags(ctx context.Context) (map[string]int, error) { tagIDMapping := make(map[string]int) path := "api/tags/" for path != "" { - resp, err := c.Do(ctx, "GET", path, nil) + resp, err := client.Do(ctx, "GET", path, nil) if err != nil { return nil, err } @@ -92,8 +92,8 @@ func (c *PaperlessClient) GetAllTags(ctx context.Context) (map[string]int, error // Extract relative path from the Next URL if tagsResponse.Next != "" { nextURL := tagsResponse.Next - if strings.HasPrefix(nextURL, c.BaseURL) { - nextURL = strings.TrimPrefix(nextURL, c.BaseURL+"/") + if strings.HasPrefix(nextURL, client.BaseURL) { + nextURL 
= strings.TrimPrefix(nextURL, client.BaseURL+"/") } path = nextURL } else { @@ -105,15 +105,15 @@ func (c *PaperlessClient) GetAllTags(ctx context.Context) (map[string]int, error } // GetDocumentsByTags retrieves documents that match the specified tags -func (c *PaperlessClient) GetDocumentsByTags(ctx context.Context, tags []string) ([]Document, error) { +func (client *PaperlessClient) GetDocumentsByTags(ctx context.Context, tags []string, pageSize int) ([]Document, error) { tagQueries := make([]string, len(tags)) for i, tag := range tags { tagQueries[i] = fmt.Sprintf("tag:%s", tag) } searchQuery := strings.Join(tagQueries, " ") - path := fmt.Sprintf("api/documents/?query=%s", urlEncode(searchQuery)) + path := fmt.Sprintf("api/documents/?query=%s&page_size=%d", urlEncode(searchQuery), pageSize) - resp, err := c.Do(ctx, "GET", path, nil) + resp, err := client.Do(ctx, "GET", path, nil) if err != nil { return nil, err } @@ -130,7 +130,7 @@ func (c *PaperlessClient) GetDocumentsByTags(ctx context.Context, tags []string) return nil, err } - allTags, err := c.GetAllTags(ctx) + allTags, err := client.GetAllTags(ctx) if err != nil { return nil, err } @@ -159,9 +159,9 @@ func (c *PaperlessClient) GetDocumentsByTags(ctx context.Context, tags []string) } // DownloadPDF downloads the PDF file of the specified document -func (c *PaperlessClient) DownloadPDF(ctx context.Context, document Document) ([]byte, error) { +func (client *PaperlessClient) DownloadPDF(ctx context.Context, document Document) ([]byte, error) { path := fmt.Sprintf("api/documents/%d/download/", document.ID) - resp, err := c.Do(ctx, "GET", path, nil) + resp, err := client.Do(ctx, "GET", path, nil) if err != nil { return nil, err } @@ -175,9 +175,9 @@ func (c *PaperlessClient) DownloadPDF(ctx context.Context, document Document) ([ return io.ReadAll(resp.Body) } -func (c *PaperlessClient) GetDocument(ctx context.Context, documentID int) (Document, error) { +func (client *PaperlessClient) GetDocument(ctx 
context.Context, documentID int) (Document, error) { path := fmt.Sprintf("api/documents/%d/", documentID) - resp, err := c.Do(ctx, "GET", path, nil) + resp, err := client.Do(ctx, "GET", path, nil) if err != nil { return Document{}, err } @@ -194,7 +194,7 @@ func (c *PaperlessClient) GetDocument(ctx context.Context, documentID int) (Docu return Document{}, err } - allTags, err := c.GetAllTags(ctx) + allTags, err := client.GetAllTags(ctx) if err != nil { return Document{}, err } @@ -218,14 +218,32 @@ func (c *PaperlessClient) GetDocument(ctx context.Context, documentID int) (Docu } // UpdateDocuments updates the specified documents with suggested changes -func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []DocumentSuggestion) error { +func (client *PaperlessClient) UpdateDocuments(ctx context.Context, documents []DocumentSuggestion) error { // Fetch all available tags - availableTags, err := c.GetAllTags(ctx) + availableTags, err := client.GetAllTags(ctx) if err != nil { log.Errorf("Error fetching available tags: %v", err) return err } + documentsContainSuggestedCorrespondent := false + for _, document := range documents { + if document.SuggestedCorrespondent != "" { + documentsContainSuggestedCorrespondent = true + break + } + } + + availableCorrespondents := make(map[string]int) + if documentsContainSuggestedCorrespondent { + availableCorrespondents, err = client.GetAllCorrespondents(ctx) + if err != nil { + log.Errorf("Error fetching available correspondents: %v", + err) + return err + } + } + for _, document := range documents { documentID := document.ID @@ -248,12 +266,27 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum } newTags = append(newTags, tagID) } else { - log.Warnf("Tag '%s' does not exist in paperless-ngx, skipping.", tagName) + log.Errorf("Suggested tag '%s' does not exist in paperless-ngx, skipping.", tagName) } } - updatedFields["tags"] = newTags + // Map suggested correspondent names to IDs 
+ if document.SuggestedCorrespondent != "" { + if correspondentID, exists := availableCorrespondents[document.SuggestedCorrespondent]; exists { + updatedFields["correspondent"] = correspondentID + } else { + newCorrespondent := instantiateCorrespondent(document.SuggestedCorrespondent) + newCorrespondentID, err := client.CreateCorrespondent(context.Background(), newCorrespondent) + if err != nil { + log.Errorf("Error creating correspondent with name %s: %v\n", document.SuggestedCorrespondent, err) + return err + } + log.Infof("Created correspondent with name %s and ID %d\n", document.SuggestedCorrespondent, newCorrespondentID) + updatedFields["correspondent"] = newCorrespondentID + } + } + suggestedTitle := document.SuggestedTitle if len(suggestedTitle) > 128 { suggestedTitle = suggestedTitle[:128] @@ -279,7 +312,7 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum // Send the update request using the generic Do method path := fmt.Sprintf("api/documents/%d/", documentID) - resp, err := c.Do(ctx, "PATCH", path, bytes.NewBuffer(jsonData)) + resp, err := client.Do(ctx, "PATCH", path, bytes.NewBuffer(jsonData)) if err != nil { log.Errorf("Error updating document %d: %v", documentID, err) return err @@ -299,9 +332,9 @@ func (c *PaperlessClient) UpdateDocuments(ctx context.Context, documents []Docum } // DownloadDocumentAsImages downloads the PDF file of the specified document and converts it to images -func (c *PaperlessClient) DownloadDocumentAsImages(ctx context.Context, documentId int) ([]string, error) { +func (client *PaperlessClient) DownloadDocumentAsImages(ctx context.Context, documentId int) ([]string, error) { // Create a directory named after the document ID - docDir := filepath.Join(c.GetCacheFolder(), fmt.Sprintf("/document-%d", documentId)) + docDir := filepath.Join(client.GetCacheFolder(), fmt.Sprintf("/document-%d", documentId)) if _, err := os.Stat(docDir); os.IsNotExist(err) { err = os.MkdirAll(docDir, 0755) if err != 
nil { @@ -326,7 +359,7 @@ func (c *PaperlessClient) DownloadDocumentAsImages(ctx context.Context, document // Proceed with downloading and converting the document to images path := fmt.Sprintf("api/documents/%d/download/", documentId) - resp, err := c.Do(ctx, "GET", path, nil) + resp, err := client.Do(ctx, "GET", path, nil) if err != nil { return nil, err } @@ -418,14 +451,97 @@ func (c *PaperlessClient) DownloadDocumentAsImages(ctx context.Context, document } // GetCacheFolder returns the cache folder for the PaperlessClient -func (c *PaperlessClient) GetCacheFolder() string { - if c.CacheFolder == "" { - c.CacheFolder = filepath.Join(os.TempDir(), "paperless-gpt") +func (client *PaperlessClient) GetCacheFolder() string { + if client.CacheFolder == "" { + client.CacheFolder = filepath.Join(os.TempDir(), "paperless-gpt") } - return c.CacheFolder + return client.CacheFolder } // urlEncode encodes a string for safe URL usage func urlEncode(s string) string { return strings.ReplaceAll(s, " ", "+") } + +// instantiateCorrespondent creates a new Correspondent object with default values +func instantiateCorrespondent(name string) Correspondent { + return Correspondent{ + Name: name, + MatchingAlgorithm: 0, + Match: "", + IsInsensitive: true, + Owner: nil, + } +} + +// CreateCorrespondent creates a new correspondent in Paperless-NGX +func (client *PaperlessClient) CreateCorrespondent(ctx context.Context, correspondent Correspondent) (int, error) { + url := "api/correspondents/" + + // Marshal the correspondent data to JSON + jsonData, err := json.Marshal(correspondent) + if err != nil { + return 0, err + } + + // Send the POST request + resp, err := client.Do(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return 0, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + bodyBytes, _ := io.ReadAll(resp.Body) + return 0, fmt.Errorf("error creating correspondent: %d, %s", resp.StatusCode, string(bodyBytes)) + } + + // Decode the 
response body to get the ID of the created correspondent + var createdCorrespondent struct { + ID int `json:"id"` + } + err = json.NewDecoder(resp.Body).Decode(&createdCorrespondent) + if err != nil { + return 0, err + } + + return createdCorrespondent.ID, nil +} + +// CorrespondentResponse represents the response structure for correspondents +type CorrespondentResponse struct { + Results []struct { + ID int `json:"id"` + Name string `json:"name"` + } `json:"results"` +} + +// GetAllCorrespondents retrieves all correspondents from the Paperless-NGX API +func (client *PaperlessClient) GetAllCorrespondents(ctx context.Context) (map[string]int, error) { + correspondentIDMapping := make(map[string]int) + path := "api/correspondents/?page_size=9999" + + resp, err := client.Do(ctx, "GET", path, nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("error fetching correspondents: %d, %s", resp.StatusCode, string(bodyBytes)) + } + + var correspondentsResponse CorrespondentResponse + + err = json.NewDecoder(resp.Body).Decode(&correspondentsResponse) + if err != nil { + return nil, err + } + + for _, correspondent := range correspondentsResponse.Results { + correspondentIDMapping[correspondent.Name] = correspondent.ID + } + + return correspondentIDMapping, nil +} diff --git a/paperless_test.go b/paperless_test.go index 6c70b9d..959adaa 100644 --- a/paperless_test.go +++ b/paperless_test.go @@ -203,7 +203,7 @@ func TestGetDocumentsByTags(t *testing.T) { // Set mock responses env.setMockResponse("/api/documents/", func(w http.ResponseWriter, r *http.Request) { // Verify query parameters - expectedQuery := "query=tag:tag1+tag:tag2" + expectedQuery := "query=tag:tag1+tag:tag2&page_size=25" assert.Equal(t, expectedQuery, r.URL.RawQuery) w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(documentsResponse) @@ -216,7 +216,7 @@ func TestGetDocumentsByTags(t 
*testing.T) { ctx := context.Background() tags := []string{"tag1", "tag2"} - documents, err := env.client.GetDocumentsByTags(ctx, tags) + documents, err := env.client.GetDocumentsByTags(ctx, tags, 25) require.NoError(t, err) expectedDocuments := []Document{ diff --git a/types.go b/types.go index 72238c0..40cbf6f 100644 --- a/types.go +++ b/types.go @@ -67,16 +67,36 @@ type Document struct { // GenerateSuggestionsRequest is the request payload for generating suggestions for /generate-suggestions endpoint type GenerateSuggestionsRequest struct { - Documents []Document `json:"documents"` - GenerateTitles bool `json:"generate_titles,omitempty"` - GenerateTags bool `json:"generate_tags,omitempty"` + Documents []Document `json:"documents"` + GenerateTitles bool `json:"generate_titles,omitempty"` + GenerateTags bool `json:"generate_tags,omitempty"` + GenerateCorrespondents bool `json:"generate_correspondents,omitempty"` } // DocumentSuggestion is the response payload for /generate-suggestions endpoint and the request payload for /update-documents endpoint (as an array) type DocumentSuggestion struct { - ID int `json:"id"` - OriginalDocument Document `json:"original_document"` - SuggestedTitle string `json:"suggested_title,omitempty"` - SuggestedTags []string `json:"suggested_tags,omitempty"` - SuggestedContent string `json:"suggested_content,omitempty"` + ID int `json:"id"` + OriginalDocument Document `json:"original_document"` + SuggestedTitle string `json:"suggested_title,omitempty"` + SuggestedTags []string `json:"suggested_tags,omitempty"` + SuggestedContent string `json:"suggested_content,omitempty"` + SuggestedCorrespondent string `json:"suggested_correspondent,omitempty"` +} + +type Correspondent struct { + Name string `json:"name"` + MatchingAlgorithm int `json:"matching_algorithm"` + Match string `json:"match"` + IsInsensitive bool `json:"is_insensitive"` + Owner *int `json:"owner"` + SetPermissions struct { + View struct { + Users []int `json:"users"` + Groups 
[]int `json:"groups"` + } `json:"view"` + Change struct { + Users []int `json:"users"` + Groups []int `json:"groups"` + } `json:"change"` + } `json:"set_permissions"` }