Skip to content

Commit

Permalink
feat(repository): cache engine-relevant methods (#270)
Browse files Browse the repository at this point in the history
  • Loading branch information
steebchen authored Mar 21, 2024
1 parent 617a306 commit f82cfb4
Show file tree
Hide file tree
Showing 9 changed files with 311 additions and 31 deletions.
8 changes: 0 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@ github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY=
github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
Expand Down Expand Up @@ -389,22 +387,16 @@ go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 h1:aFJWCqJMNjENlcleuuOkGAPH82y0yULBScfXcIEdS24=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1/go.mod h1:sEGXWArGqc3tVa+ekntsN65DmVbVeW+7lTKTjZF3/Fo=
go.opentelemetry.io/otel v1.21.0 h1:hzLeKBZEL7Okw2mGzZ0cc4k/A7Fta0uoPgaJCr8fsFc=
go.opentelemetry.io/otel v1.21.0/go.mod h1:QZzNPQPm1zLX4gZK4cMi+71eaorMSGT3A4znnUvNNEo=
go.opentelemetry.io/otel v1.23.1 h1:Za4UzOqJYS+MUczKI320AtqZHZb7EqxO00jAHE0jmQY=
go.opentelemetry.io/otel v1.23.1/go.mod h1:Td0134eafDLcTS4y+zQ26GE8u3dEuRBiBCTUIRHaikA=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0 h1:cl5P5/GIfFh4t6xyruOgJP5QiA1pw4fYYdv6nc6CBWw=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0/go.mod h1:zgBdWWAu7oEEMC06MMKc5NLbA/1YDXV1sMpSqEeLQLg=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.21.0 h1:tIqheXEFWAZ7O8A7m+J0aPTmpJN3YQ7qetUAdkkkKpk=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.21.0/go.mod h1:nUeKExfxAQVbiVFn32YXpXZZHZ61Cc3s3Rn1pDBGAb0=
go.opentelemetry.io/otel/metric v1.21.0 h1:tlYWfeo+Bocx5kLEloTjbcDwBuELRrIFxwdQ36PlJu4=
go.opentelemetry.io/otel/metric v1.21.0/go.mod h1:o1p3CA8nNHW8j5yuQLdc1eeqEaPfzug24uvsyIEJRWM=
go.opentelemetry.io/otel/metric v1.23.1 h1:PQJmqJ9u2QaJLBOELl1cxIdPcpbwzbkjfEyelTl2rlo=
go.opentelemetry.io/otel/metric v1.23.1/go.mod h1:mpG2QPlAfnK8yNhNJAxDZruU9Y1/HubbC+KyH8FaCWI=
go.opentelemetry.io/otel/sdk v1.21.0 h1:FTt8qirL1EysG6sTQRZ5TokkU8d0ugCj8htOgThZXQ8=
go.opentelemetry.io/otel/sdk v1.21.0/go.mod h1:Nna6Yv7PWTdgJHVRD9hIYywQBRx7pbox6nwBnZIxl/E=
go.opentelemetry.io/otel/trace v1.21.0 h1:WD9i5gzvoUPuXIXH24ZNBudiarZDKuekPqi/E8fpfLc=
go.opentelemetry.io/otel/trace v1.21.0/go.mod h1:LGbsEB0f9LGjN+OZaQQ26sohbOmiMR+BaslueVtS/qQ=
go.opentelemetry.io/otel/trace v1.23.1 h1:4LrmmEd8AU2rFvU1zegmvqW7+kWarxtNOPyeL6HmYY8=
go.opentelemetry.io/otel/trace v1.23.1/go.mod h1:4IpnpJFwr1mo/6HL8XIPJaE9y0+u1KcVmuW7dwFSVrI=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
Expand Down
66 changes: 65 additions & 1 deletion internal/auth/token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package token_test

import (
"fmt"
"os"
"testing"

"github.com/google/uuid"
Expand All @@ -16,7 +17,7 @@ import (
"github.com/hatchet-dev/hatchet/internal/testutils"
)

func TestCreateTenantToken(t *testing.T) {
func TestCreateTenantToken(t *testing.T) { // make sure no cache is used for tests
testutils.RunTestWithDatabase(t, func(conf *database.Config) error {
jwtManager := getJWTManager(t, conf)

Expand Down Expand Up @@ -56,6 +57,8 @@ func TestCreateTenantToken(t *testing.T) {
}

func TestRevokeTenantToken(t *testing.T) {
_ = os.Setenv("CACHE_DURATION", "0")

testutils.RunTestWithDatabase(t, func(conf *database.Config) error {
jwtManager := getJWTManager(t, conf)

Expand Down Expand Up @@ -106,12 +109,73 @@ func TestRevokeTenantToken(t *testing.T) {
// validate the token again
_, err = jwtManager.ValidateTenantToken(token)

// error as the token was revoked
assert.Error(t, err)

return nil
})
}

func TestRevokeTenantTokenCache(t *testing.T) {
_ = os.Setenv("CACHE_DURATION", "60s")

testutils.RunTestWithDatabase(t, func(conf *database.Config) error {
jwtManager := getJWTManager(t, conf)

tenantId := uuid.New().String()

// create the tenant
slugSuffix, err := encryption.GenerateRandomBytes(8)

if err != nil {
t.Fatal(err.Error())
}

_, err = conf.Repository.Tenant().CreateTenant(&repository.CreateTenantOpts{
ID: &tenantId,
Name: "test-tenant",
Slug: fmt.Sprintf("test-tenant-%s", slugSuffix),
})

if err != nil {
t.Fatal(err.Error())
}

token, err := jwtManager.GenerateTenantToken(tenantId, "test token")

if err != nil {
t.Fatal(err.Error())
}

// validate the token
_, err = jwtManager.ValidateTenantToken(token)

assert.NoError(t, err)

// revoke the token
apiTokens, err := conf.Repository.APIToken().ListAPITokensByTenant(tenantId)

if err != nil {
t.Fatal(err.Error())
}

assert.Len(t, apiTokens, 1)
err = conf.Repository.APIToken().RevokeAPIToken(apiTokens[0].ID)

if err != nil {
t.Fatal(err.Error())
}

// validate the token again
_, err = jwtManager.ValidateTenantToken(token)

// no error as it is cached
assert.NoError(t, err)

return nil
})
}

func getJWTManager(t *testing.T, conf *database.Config) token.JWTManager {
t.Helper()

Expand Down
130 changes: 130 additions & 0 deletions internal/cache/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package cache

import (
"sync"
"time"
)

// item represents a cache item with a value and an expiration time.
type item[V any] struct {
value V
expiry time.Time
}

// isExpired checks if the cache item has expired.
func (i item[V]) isExpired() bool {
return time.Now().After(i.expiry)
}

// TTLCache is a generic cache implementation with support for time-to-live
// (TTL) expiration.
type TTLCache[K comparable, V any] struct {
items map[K]item[V] // The map storing cache items.
mu sync.Mutex // Mutex for controlling concurrent access to the cache.
stop chan interface{} // Channel to stop the goroutine that removes expired items.
}

// NewTTL creates a new TTLCache instance and starts a goroutine to periodically
// remove expired items every 5 seconds.
func NewTTL[K comparable, V any]() *TTLCache[K, V] {
c := &TTLCache[K, V]{
items: make(map[K]item[V]),
stop: make(chan interface{}),
}

go func() {
// Create a new ticker to remove expired items every 5 seconds.
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()

for {
select {
case <-c.stop:
return
case <-ticker.C:
c.mu.Lock()

// Iterate over the cache items and delete expired ones.
for key, item := range c.items {
if item.isExpired() {
delete(c.items, key)
}
}

c.mu.Unlock()
}
}
}()

return c
}

func (c *TTLCache[K, V]) Stop() {
close(c.stop)
}

// Set adds a new item to the cache with the specified key, value, and
// time-to-live (TTL).
func (c *TTLCache[K, V]) Set(key K, value V, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()

c.items[key] = item[V]{
value: value,
expiry: time.Now().Add(ttl),
}
}

// Get retrieves the value associated with the given key from the cache.
func (c *TTLCache[K, V]) Get(key K) (V, bool) {
c.mu.Lock()
defer c.mu.Unlock()

item, found := c.items[key]
if !found {
// If the key is not found, return the zero value for V and false.
return item.value, false
}

if item.isExpired() {
// If the item has expired, remove it from the cache and return the
// value and false.
delete(c.items, key)
return item.value, false
}

// Otherwise return the value and true.
return item.value, true
}

// Remove removes the item with the specified key from the cache.
func (c *TTLCache[K, V]) Remove(key K) {
c.mu.Lock()
defer c.mu.Unlock()

// Delete the item with the given key from the cache.
delete(c.items, key)
}

// Pop removes and returns the item with the specified key from the cache.
func (c *TTLCache[K, V]) Pop(key K) (V, bool) {
c.mu.Lock()
defer c.mu.Unlock()

item, found := c.items[key]
if !found {
// If the key is not found, return the zero value for V and false.
return item.value, false
}

// If the key is found, delete the item from the cache.
delete(c.items, key)

if item.isExpired() {
// If the item has expired, return the value and false.
return item.value, false
}

// Otherwise return the value and true.
return item.value, true
}
6 changes: 6 additions & 0 deletions internal/config/database/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package database

import (
"time"

"github.com/spf13/viper"

"github.com/hatchet-dev/hatchet/internal/config/shared"
Expand All @@ -20,6 +22,8 @@ type ConfigFile struct {
Logger shared.LoggerConfigFile `mapstructure:"logger" json:"logger,omitempty"`

LogQueries bool `mapstructure:"logQueries" json:"logQueries,omitempty" default:"false"`

CacheDuration time.Duration `mapstructure:"cacheDuration" json:"cacheDuration,omitempty" default:"60s"`
}

type SeedConfigFile struct {
Expand Down Expand Up @@ -51,6 +55,8 @@ func BindAllEnv(v *viper.Viper) {
_ = v.BindEnv("sslMode", "DATABASE_POSTGRES_SSL_MODE")
_ = v.BindEnv("logQueries", "DATABASE_LOG_QUERIES")

_ = v.BindEnv("cacheDuration", "CACHE_DURATION")

_ = v.BindEnv("seed.adminEmail", "ADMIN_EMAIL")
_ = v.BindEnv("seed.adminPassword", "ADMIN_PASSWORD")
_ = v.BindEnv("seed.adminName", "ADMIN_NAME")
Expand Down
23 changes: 13 additions & 10 deletions internal/config/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ import (
"path/filepath"
"strings"

"github.com/exaring/otelpgx"
pgxzero "github.com/jackc/pgx-zerolog"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/tracelog"

pgxzero "github.com/jackc/pgx-zerolog"

"github.com/hatchet-dev/hatchet/internal/auth/cookie"
"github.com/hatchet-dev/hatchet/internal/auth/oauth"
"github.com/hatchet-dev/hatchet/internal/auth/token"
Expand All @@ -27,15 +27,14 @@ import (
"github.com/hatchet-dev/hatchet/internal/integrations/vcs/github"
"github.com/hatchet-dev/hatchet/internal/logger"
"github.com/hatchet-dev/hatchet/internal/msgqueue/rabbitmq"
"github.com/hatchet-dev/hatchet/internal/repository/cache"
"github.com/hatchet-dev/hatchet/internal/repository/prisma"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/services/ingestor"
"github.com/hatchet-dev/hatchet/internal/validator"
"github.com/hatchet-dev/hatchet/pkg/client"
"github.com/hatchet-dev/hatchet/pkg/errors"
"github.com/hatchet-dev/hatchet/pkg/errors/sentry"

"github.com/exaring/otelpgx"
)

// LoadDatabaseConfigFile loads the database config file via viper
Expand Down Expand Up @@ -121,11 +120,10 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile) (res *database.Con
cf.PostgresSSLMode,
)

os.Setenv("DATABASE_URL", databaseUrl)
// TODO db.WithDatasourceURL(databaseUrl) is not working
_ = os.Setenv("DATABASE_URL", databaseUrl)

c := db.NewClient(
// db.WithDatasourceURL(databaseUrl),
)
c := db.NewClient()

if err := c.Prisma.Connect(); err != nil {
return nil, err
Expand All @@ -152,9 +150,14 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile) (res *database.Con
return nil, fmt.Errorf("could not connect to database: %w", err)
}

ch := cache.New(cf.CacheDuration)

return &database.Config{
Disconnect: c.Prisma.Disconnect,
Repository: prisma.NewPrismaRepository(c, pool, prisma.WithLogger(&l)),
Disconnect: func() error {
ch.Stop()
return c.Prisma.Disconnect()
},
Repository: prisma.NewPrismaRepository(c, pool, prisma.WithLogger(&l), prisma.WithCache(ch)),
Seed: cf.Seed,
}, nil
}
Expand Down
Loading

0 comments on commit f82cfb4

Please sign in to comment.