From f82cfb4eefc4b5946f18f0afaaba00d9b44b1971 Mon Sep 17 00:00:00 2001 From: Luca Steeb Date: Fri, 22 Mar 2024 00:09:59 +0700 Subject: [PATCH] feat(repository): cache engine-relevant methods (#270) --- go.sum | 8 -- internal/auth/token/token_test.go | 66 +++++++++++- internal/cache/cache.go | 130 +++++++++++++++++++++++ internal/config/database/config.go | 6 ++ internal/config/loader/loader.go | 23 ++-- internal/repository/cache/cache.go | 61 +++++++++++ internal/repository/prisma/api_token.go | 13 ++- internal/repository/prisma/repository.go | 22 +++- internal/repository/prisma/tenant.go | 13 ++- 9 files changed, 311 insertions(+), 31 deletions(-) create mode 100644 internal/cache/cache.go create mode 100644 internal/repository/cache/cache.go diff --git a/go.sum b/go.sum index 4e2cdf5aa..fae2dcde4 100644 --- a/go.sum +++ b/go.sum @@ -105,8 +105,6 @@ github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY= -github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= @@ -389,22 +387,16 @@ go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 h1:aFJWCqJMNjENlcleuuOkGAPH82y0yULBScfXcIEdS24= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1/go.mod h1:sEGXWArGqc3tVa+ekntsN65DmVbVeW+7lTKTjZF3/Fo= -go.opentelemetry.io/otel v1.21.0 h1:hzLeKBZEL7Okw2mGzZ0cc4k/A7Fta0uoPgaJCr8fsFc= -go.opentelemetry.io/otel v1.21.0/go.mod h1:QZzNPQPm1zLX4gZK4cMi+71eaorMSGT3A4znnUvNNEo= go.opentelemetry.io/otel v1.23.1 h1:Za4UzOqJYS+MUczKI320AtqZHZb7EqxO00jAHE0jmQY= go.opentelemetry.io/otel v1.23.1/go.mod h1:Td0134eafDLcTS4y+zQ26GE8u3dEuRBiBCTUIRHaikA= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0 h1:cl5P5/GIfFh4t6xyruOgJP5QiA1pw4fYYdv6nc6CBWw= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0/go.mod h1:zgBdWWAu7oEEMC06MMKc5NLbA/1YDXV1sMpSqEeLQLg= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.21.0 h1:tIqheXEFWAZ7O8A7m+J0aPTmpJN3YQ7qetUAdkkkKpk= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.21.0/go.mod h1:nUeKExfxAQVbiVFn32YXpXZZHZ61Cc3s3Rn1pDBGAb0= -go.opentelemetry.io/otel/metric v1.21.0 h1:tlYWfeo+Bocx5kLEloTjbcDwBuELRrIFxwdQ36PlJu4= -go.opentelemetry.io/otel/metric v1.21.0/go.mod h1:o1p3CA8nNHW8j5yuQLdc1eeqEaPfzug24uvsyIEJRWM= go.opentelemetry.io/otel/metric v1.23.1 h1:PQJmqJ9u2QaJLBOELl1cxIdPcpbwzbkjfEyelTl2rlo= go.opentelemetry.io/otel/metric v1.23.1/go.mod h1:mpG2QPlAfnK8yNhNJAxDZruU9Y1/HubbC+KyH8FaCWI= go.opentelemetry.io/otel/sdk v1.21.0 h1:FTt8qirL1EysG6sTQRZ5TokkU8d0ugCj8htOgThZXQ8= go.opentelemetry.io/otel/sdk v1.21.0/go.mod h1:Nna6Yv7PWTdgJHVRD9hIYywQBRx7pbox6nwBnZIxl/E= -go.opentelemetry.io/otel/trace v1.21.0 h1:WD9i5gzvoUPuXIXH24ZNBudiarZDKuekPqi/E8fpfLc= -go.opentelemetry.io/otel/trace v1.21.0/go.mod h1:LGbsEB0f9LGjN+OZaQQ26sohbOmiMR+BaslueVtS/qQ= go.opentelemetry.io/otel/trace v1.23.1 h1:4LrmmEd8AU2rFvU1zegmvqW7+kWarxtNOPyeL6HmYY8= go.opentelemetry.io/otel/trace v1.23.1/go.mod h1:4IpnpJFwr1mo/6HL8XIPJaE9y0+u1KcVmuW7dwFSVrI= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= diff --git a/internal/auth/token/token_test.go b/internal/auth/token/token_test.go index cd78a67b4..e46573c2a 100644 --- a/internal/auth/token/token_test.go +++ b/internal/auth/token/token_test.go @@ -4,6 +4,7 @@ package token_test import ( "fmt" + "os" "testing" "github.com/google/uuid" @@ -16,7 +17,7 @@ import ( "github.com/hatchet-dev/hatchet/internal/testutils" ) -func TestCreateTenantToken(t *testing.T) { +func TestCreateTenantToken(t *testing.T) { // make sure no cache is used for tests testutils.RunTestWithDatabase(t, func(conf *database.Config) error { jwtManager := getJWTManager(t, conf) @@ -56,6 +57,8 @@ func TestCreateTenantToken(t *testing.T) { } func TestRevokeTenantToken(t *testing.T) { + _ = os.Setenv("CACHE_DURATION", "0") + testutils.RunTestWithDatabase(t, func(conf *database.Config) error { jwtManager := getJWTManager(t, conf) @@ -106,12 +109,73 @@ func TestRevokeTenantToken(t *testing.T) { // validate the token again _, err = jwtManager.ValidateTenantToken(token) + // error as the token was revoked assert.Error(t, err) return nil }) } +func TestRevokeTenantTokenCache(t *testing.T) { + _ = os.Setenv("CACHE_DURATION", "60s") + + testutils.RunTestWithDatabase(t, func(conf *database.Config) error { + jwtManager := getJWTManager(t, conf) + + tenantId := uuid.New().String() + + // create the tenant + slugSuffix, err := encryption.GenerateRandomBytes(8) + + if err != nil { + t.Fatal(err.Error()) + } + + _, err = conf.Repository.Tenant().CreateTenant(&repository.CreateTenantOpts{ + ID: &tenantId, + Name: "test-tenant", + Slug: fmt.Sprintf("test-tenant-%s", slugSuffix), + }) + + if err != nil { + t.Fatal(err.Error()) + } + + token, err := jwtManager.GenerateTenantToken(tenantId, "test token") + + if err != nil { + t.Fatal(err.Error()) + } + + // validate the token + _, err = jwtManager.ValidateTenantToken(token) + + assert.NoError(t, err) + + // revoke the token + apiTokens, err := conf.Repository.APIToken().ListAPITokensByTenant(tenantId) + + if err != nil { + t.Fatal(err.Error()) + } + + assert.Len(t, apiTokens, 1) + err = conf.Repository.APIToken().RevokeAPIToken(apiTokens[0].ID) + + if err != nil { + t.Fatal(err.Error()) + } + + // validate the token again + _, err = jwtManager.ValidateTenantToken(token) + + // no error as it is cached + assert.NoError(t, err) + + return nil + }) +} + func getJWTManager(t *testing.T, conf *database.Config) token.JWTManager { t.Helper() diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 000000000..a2250b085 --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,130 @@ +package cache + +import ( + "sync" + "time" +) + +// item represents a cache item with a value and an expiration time. +type item[V any] struct { + value V + expiry time.Time +} + +// isExpired checks if the cache item has expired. +func (i item[V]) isExpired() bool { + return time.Now().After(i.expiry) +} + +// TTLCache is a generic cache implementation with support for time-to-live +// (TTL) expiration. +type TTLCache[K comparable, V any] struct { + items map[K]item[V] // The map storing cache items. + mu sync.Mutex // Mutex for controlling concurrent access to the cache. + stop chan interface{} // Channel to stop the goroutine that removes expired items. +} + +// NewTTL creates a new TTLCache instance and starts a goroutine to periodically +// remove expired items every 5 seconds. +func NewTTL[K comparable, V any]() *TTLCache[K, V] { + c := &TTLCache[K, V]{ + items: make(map[K]item[V]), + stop: make(chan interface{}), + } + + go func() { + // Create a new ticker to remove expired items every 5 seconds. + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case <-c.stop: + return + case <-ticker.C: + c.mu.Lock() + + // Iterate over the cache items and delete expired ones. + for key, item := range c.items { + if item.isExpired() { + delete(c.items, key) + } + } + + c.mu.Unlock() + } + } + }() + + return c +} + +func (c *TTLCache[K, V]) Stop() { + close(c.stop) +} + +// Set adds a new item to the cache with the specified key, value, and +// time-to-live (TTL). +func (c *TTLCache[K, V]) Set(key K, value V, ttl time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + + c.items[key] = item[V]{ + value: value, + expiry: time.Now().Add(ttl), + } +} + +// Get retrieves the value associated with the given key from the cache. +func (c *TTLCache[K, V]) Get(key K) (V, bool) { + c.mu.Lock() + defer c.mu.Unlock() + + item, found := c.items[key] + if !found { + // If the key is not found, return the zero value for V and false. + return item.value, false + } + + if item.isExpired() { + // If the item has expired, remove it from the cache and return the + // value and false. + delete(c.items, key) + return item.value, false + } + + // Otherwise return the value and true. + return item.value, true +} + +// Remove removes the item with the specified key from the cache. +func (c *TTLCache[K, V]) Remove(key K) { + c.mu.Lock() + defer c.mu.Unlock() + + // Delete the item with the given key from the cache. + delete(c.items, key) +} + +// Pop removes and returns the item with the specified key from the cache. +func (c *TTLCache[K, V]) Pop(key K) (V, bool) { + c.mu.Lock() + defer c.mu.Unlock() + + item, found := c.items[key] + if !found { + // If the key is not found, return the zero value for V and false. + return item.value, false + } + + // If the key is found, delete the item from the cache. + delete(c.items, key) + + if item.isExpired() { + // If the item has expired, return the value and false. + return item.value, false + } + + // Otherwise return the value and true. + return item.value, true +} diff --git a/internal/config/database/config.go b/internal/config/database/config.go index b35a25cdd..f06f5d9df 100644 --- a/internal/config/database/config.go +++ b/internal/config/database/config.go @@ -1,6 +1,8 @@ package database import ( + "time" + "github.com/spf13/viper" "github.com/hatchet-dev/hatchet/internal/config/shared" @@ -20,6 +22,8 @@ type ConfigFile struct { Logger shared.LoggerConfigFile `mapstructure:"logger" json:"logger,omitempty"` LogQueries bool `mapstructure:"logQueries" json:"logQueries,omitempty" default:"false"` + + CacheDuration time.Duration `mapstructure:"cacheDuration" json:"cacheDuration,omitempty" default:"60s"` } type SeedConfigFile struct { @@ -51,6 +55,8 @@ func BindAllEnv(v *viper.Viper) { _ = v.BindEnv("sslMode", "DATABASE_POSTGRES_SSL_MODE") _ = v.BindEnv("logQueries", "DATABASE_LOG_QUERIES") + _ = v.BindEnv("cacheDuration", "CACHE_DURATION") + _ = v.BindEnv("seed.adminEmail", "ADMIN_EMAIL") _ = v.BindEnv("seed.adminPassword", "ADMIN_PASSWORD") _ = v.BindEnv("seed.adminName", "ADMIN_NAME") diff --git a/internal/config/loader/loader.go b/internal/config/loader/loader.go index b54bb2b8a..83187b745 100644 --- a/internal/config/loader/loader.go +++ b/internal/config/loader/loader.go @@ -10,11 +10,11 @@ import ( "path/filepath" "strings" + "github.com/exaring/otelpgx" + pgxzero "github.com/jackc/pgx-zerolog" "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/tracelog" - pgxzero "github.com/jackc/pgx-zerolog" - "github.com/hatchet-dev/hatchet/internal/auth/cookie" "github.com/hatchet-dev/hatchet/internal/auth/oauth" "github.com/hatchet-dev/hatchet/internal/auth/token" @@ -27,6 +27,7 @@ import ( "github.com/hatchet-dev/hatchet/internal/integrations/vcs/github" "github.com/hatchet-dev/hatchet/internal/logger" "github.com/hatchet-dev/hatchet/internal/msgqueue/rabbitmq" + "github.com/hatchet-dev/hatchet/internal/repository/cache" "github.com/hatchet-dev/hatchet/internal/repository/prisma" "github.com/hatchet-dev/hatchet/internal/repository/prisma/db" "github.com/hatchet-dev/hatchet/internal/services/ingestor" @@ -34,8 +35,6 @@ import ( "github.com/hatchet-dev/hatchet/pkg/client" "github.com/hatchet-dev/hatchet/pkg/errors" "github.com/hatchet-dev/hatchet/pkg/errors/sentry" - - "github.com/exaring/otelpgx" ) // LoadDatabaseConfigFile loads the database config file via viper @@ -121,11 +120,10 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile) (res *database.Con cf.PostgresSSLMode, ) - os.Setenv("DATABASE_URL", databaseUrl) + // TODO db.WithDatasourceURL(databaseUrl) is not working + _ = os.Setenv("DATABASE_URL", databaseUrl) - c := db.NewClient( - // db.WithDatasourceURL(databaseUrl), - ) + c := db.NewClient() if err := c.Prisma.Connect(); err != nil { return nil, err @@ -152,9 +150,14 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile) (res *database.Con return nil, fmt.Errorf("could not connect to database: %w", err) } + ch := cache.New(cf.CacheDuration) + return &database.Config{ - Disconnect: c.Prisma.Disconnect, - Repository: prisma.NewPrismaRepository(c, pool, prisma.WithLogger(&l)), + Disconnect: func() error { + ch.Stop() + return c.Prisma.Disconnect() + }, + Repository: prisma.NewPrismaRepository(c, pool, prisma.WithLogger(&l), prisma.WithCache(ch)), Seed: cf.Seed, }, nil } diff --git a/internal/repository/cache/cache.go b/internal/repository/cache/cache.go new file mode 100644 index 000000000..114d78ff9 --- /dev/null +++ b/internal/repository/cache/cache.go @@ -0,0 +1,61 @@ +package cache + +import ( + "time" + + "github.com/hatchet-dev/hatchet/internal/cache" +) + +type Cacheable interface { + // Set sets a value in the cache with the given key + Set(key string, value interface{}) + + // Get gets a value from the cache with the given key + Get(key string) (interface{}, bool) + + // Stop stops the cache and clears any goroutines + Stop() +} + +type Cache struct { + cache *cache.TTLCache[string, interface{}] + expiration time.Duration +} + +func (c *Cache) Set(key string, value interface{}) { + c.cache.Set(key, value, c.expiration) +} + +func (c *Cache) Get(key string) (interface{}, bool) { + return c.cache.Get(key) +} + +func (c *Cache) Stop() { + c.cache.Stop() +} + +func New(duration time.Duration) *Cache { + if duration == 0 { + // consider a duration of 0 a very short expiry instead of no expiry + duration = 1 * time.Millisecond + } + return &Cache{ + expiration: duration, + cache: cache.NewTTL[string, interface{}](), + } +} + +func MakeCacheable[T any](cache Cacheable, id string, f func() (*T, error)) (*T, error) { + if v, ok := cache.Get(id); ok { + return v.(*T), nil + } + + v, err := f() + if err != nil { + return nil, err + } + + cache.Set(id, v) + + return v, nil +} diff --git a/internal/repository/prisma/api_token.go b/internal/repository/prisma/api_token.go index efade3f62..a0213e798 100644 --- a/internal/repository/prisma/api_token.go +++ b/internal/repository/prisma/api_token.go @@ -5,6 +5,7 @@ import ( "time" "github.com/hatchet-dev/hatchet/internal/repository" + "github.com/hatchet-dev/hatchet/internal/repository/cache" "github.com/hatchet-dev/hatchet/internal/repository/prisma/db" "github.com/hatchet-dev/hatchet/internal/validator" ) @@ -12,19 +13,23 @@ import ( type apiTokenRepository struct { client *db.PrismaClient v validator.Validator + cache cache.Cacheable } -func NewAPITokenRepository(client *db.PrismaClient, v validator.Validator) repository.APITokenRepository { +func NewAPITokenRepository(client *db.PrismaClient, v validator.Validator, cache cache.Cacheable) repository.APITokenRepository { return &apiTokenRepository{ client: client, v: v, + cache: cache, } } func (a *apiTokenRepository) GetAPITokenById(id string) (*db.APITokenModel, error) { - return a.client.APIToken.FindUnique( - db.APIToken.ID.Equals(id), - ).Exec(context.Background()) + return cache.MakeCacheable[db.APITokenModel](a.cache, id, func() (*db.APITokenModel, error) { + return a.client.APIToken.FindUnique( + db.APIToken.ID.Equals(id), + ).Exec(context.Background()) + }) } func (a *apiTokenRepository) CreateAPIToken(opts *repository.CreateAPITokenOpts) (*db.APITokenModel, error) { diff --git a/internal/repository/prisma/repository.go b/internal/repository/prisma/repository.go index 19c2d9b35..f4d847460 100644 --- a/internal/repository/prisma/repository.go +++ b/internal/repository/prisma/repository.go @@ -1,10 +1,13 @@ package prisma import ( + "time" + "github.com/jackc/pgx/v5/pgxpool" "github.com/rs/zerolog" "github.com/hatchet-dev/hatchet/internal/repository" + "github.com/hatchet-dev/hatchet/internal/repository/cache" "github.com/hatchet-dev/hatchet/internal/repository/prisma/db" "github.com/hatchet-dev/hatchet/internal/validator" ) @@ -34,8 +37,9 @@ type prismaRepository struct { type PrismaRepositoryOpt func(*PrismaRepositoryOpts) type PrismaRepositoryOpts struct { - v validator.Validator - l *zerolog.Logger + v validator.Validator + l *zerolog.Logger + cache cache.Cacheable } func defaultPrismaRepositoryOpts() *PrismaRepositoryOpts { @@ -56,6 +60,12 @@ func WithLogger(l *zerolog.Logger) PrismaRepositoryOpt { } } +func WithCache(cache cache.Cacheable) PrismaRepositoryOpt { + return func(opts *PrismaRepositoryOpts) { + opts.cache = cache + } +} + func NewPrismaRepository(client *db.PrismaClient, pool *pgxpool.Pool, fs ...PrismaRepositoryOpt) repository.Repository { opts := defaultPrismaRepositoryOpts() @@ -66,11 +76,15 @@ func NewPrismaRepository(client *db.PrismaClient, pool *pgxpool.Pool, fs ...Pris newLogger := opts.l.With().Str("service", "database").Logger() opts.l = &newLogger + if opts.cache == nil { + opts.cache = cache.New(1 * time.Millisecond) + } + return &prismaRepository{ - apiToken: NewAPITokenRepository(client, opts.v), + apiToken: NewAPITokenRepository(client, opts.v, opts.cache), event: NewEventRepository(client, pool, opts.v, opts.l), log: NewLogRepository(client, pool, opts.v, opts.l), - tenant: NewTenantRepository(client, opts.v), + tenant: NewTenantRepository(client, opts.v, opts.cache), tenantInvite: NewTenantInviteRepository(client, opts.v), workflow: NewWorkflowRepository(client, pool, opts.v, opts.l), workflowRun: NewWorkflowRunRepository(client, pool, opts.v, opts.l), diff --git a/internal/repository/prisma/tenant.go b/internal/repository/prisma/tenant.go index 3ffe9c143..c32240893 100644 --- a/internal/repository/prisma/tenant.go +++ b/internal/repository/prisma/tenant.go @@ -4,6 +4,7 @@ import ( "context" "github.com/hatchet-dev/hatchet/internal/repository" + "github.com/hatchet-dev/hatchet/internal/repository/cache" "github.com/hatchet-dev/hatchet/internal/repository/prisma/db" "github.com/hatchet-dev/hatchet/internal/validator" ) @@ -11,12 +12,14 @@ import ( type tenantRepository struct { client *db.PrismaClient v validator.Validator + cache cache.Cacheable } -func NewTenantRepository(client *db.PrismaClient, v validator.Validator) repository.TenantRepository { +func NewTenantRepository(client *db.PrismaClient, v validator.Validator, cache cache.Cacheable) repository.TenantRepository { return &tenantRepository{ client: client, v: v, + cache: cache, } } @@ -37,9 +40,11 @@ func (r *tenantRepository) ListTenants() ([]db.TenantModel, error) { } func (r *tenantRepository) GetTenantByID(id string) (*db.TenantModel, error) { - return r.client.Tenant.FindUnique( - db.Tenant.ID.Equals(id), - ).Exec(context.Background()) + return cache.MakeCacheable[db.TenantModel](r.cache, id, func() (*db.TenantModel, error) { + return r.client.Tenant.FindUnique( + db.Tenant.ID.Equals(id), + ).Exec(context.Background()) + }) } func (r *tenantRepository) GetTenantBySlug(slug string) (*db.TenantModel, error) {