Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(repository): cache engine-relevant methods #270

Merged
merged 9 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@ github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY=
github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
Expand Down Expand Up @@ -389,22 +387,16 @@ go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 h1:aFJWCqJMNjENlcleuuOkGAPH82y0yULBScfXcIEdS24=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1/go.mod h1:sEGXWArGqc3tVa+ekntsN65DmVbVeW+7lTKTjZF3/Fo=
go.opentelemetry.io/otel v1.21.0 h1:hzLeKBZEL7Okw2mGzZ0cc4k/A7Fta0uoPgaJCr8fsFc=
go.opentelemetry.io/otel v1.21.0/go.mod h1:QZzNPQPm1zLX4gZK4cMi+71eaorMSGT3A4znnUvNNEo=
go.opentelemetry.io/otel v1.23.1 h1:Za4UzOqJYS+MUczKI320AtqZHZb7EqxO00jAHE0jmQY=
go.opentelemetry.io/otel v1.23.1/go.mod h1:Td0134eafDLcTS4y+zQ26GE8u3dEuRBiBCTUIRHaikA=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0 h1:cl5P5/GIfFh4t6xyruOgJP5QiA1pw4fYYdv6nc6CBWw=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0/go.mod h1:zgBdWWAu7oEEMC06MMKc5NLbA/1YDXV1sMpSqEeLQLg=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.21.0 h1:tIqheXEFWAZ7O8A7m+J0aPTmpJN3YQ7qetUAdkkkKpk=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.21.0/go.mod h1:nUeKExfxAQVbiVFn32YXpXZZHZ61Cc3s3Rn1pDBGAb0=
go.opentelemetry.io/otel/metric v1.21.0 h1:tlYWfeo+Bocx5kLEloTjbcDwBuELRrIFxwdQ36PlJu4=
go.opentelemetry.io/otel/metric v1.21.0/go.mod h1:o1p3CA8nNHW8j5yuQLdc1eeqEaPfzug24uvsyIEJRWM=
go.opentelemetry.io/otel/metric v1.23.1 h1:PQJmqJ9u2QaJLBOELl1cxIdPcpbwzbkjfEyelTl2rlo=
go.opentelemetry.io/otel/metric v1.23.1/go.mod h1:mpG2QPlAfnK8yNhNJAxDZruU9Y1/HubbC+KyH8FaCWI=
go.opentelemetry.io/otel/sdk v1.21.0 h1:FTt8qirL1EysG6sTQRZ5TokkU8d0ugCj8htOgThZXQ8=
go.opentelemetry.io/otel/sdk v1.21.0/go.mod h1:Nna6Yv7PWTdgJHVRD9hIYywQBRx7pbox6nwBnZIxl/E=
go.opentelemetry.io/otel/trace v1.21.0 h1:WD9i5gzvoUPuXIXH24ZNBudiarZDKuekPqi/E8fpfLc=
go.opentelemetry.io/otel/trace v1.21.0/go.mod h1:LGbsEB0f9LGjN+OZaQQ26sohbOmiMR+BaslueVtS/qQ=
go.opentelemetry.io/otel/trace v1.23.1 h1:4LrmmEd8AU2rFvU1zegmvqW7+kWarxtNOPyeL6HmYY8=
go.opentelemetry.io/otel/trace v1.23.1/go.mod h1:4IpnpJFwr1mo/6HL8XIPJaE9y0+u1KcVmuW7dwFSVrI=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
Expand Down
66 changes: 65 additions & 1 deletion internal/auth/token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package token_test

import (
"fmt"
"os"
"testing"

"github.com/google/uuid"
Expand All @@ -16,7 +17,7 @@ import (
"github.com/hatchet-dev/hatchet/internal/testutils"
)

func TestCreateTenantToken(t *testing.T) {
func TestCreateTenantToken(t *testing.T) { // make sure no cache is used for tests
testutils.RunTestWithDatabase(t, func(conf *database.Config) error {
jwtManager := getJWTManager(t, conf)

Expand Down Expand Up @@ -56,6 +57,8 @@ func TestCreateTenantToken(t *testing.T) {
}

func TestRevokeTenantToken(t *testing.T) {
_ = os.Setenv("CACHE_DURATION", "0")

testutils.RunTestWithDatabase(t, func(conf *database.Config) error {
jwtManager := getJWTManager(t, conf)

Expand Down Expand Up @@ -106,12 +109,73 @@ func TestRevokeTenantToken(t *testing.T) {
// validate the token again
_, err = jwtManager.ValidateTenantToken(token)

// error as the token was revoked
assert.Error(t, err)

return nil
})
}

func TestRevokeTenantTokenCache(t *testing.T) {
_ = os.Setenv("CACHE_DURATION", "60s")

testutils.RunTestWithDatabase(t, func(conf *database.Config) error {
jwtManager := getJWTManager(t, conf)

tenantId := uuid.New().String()

// create the tenant
slugSuffix, err := encryption.GenerateRandomBytes(8)

if err != nil {
t.Fatal(err.Error())
}

_, err = conf.Repository.Tenant().CreateTenant(&repository.CreateTenantOpts{
ID: &tenantId,
Name: "test-tenant",
Slug: fmt.Sprintf("test-tenant-%s", slugSuffix),
})

if err != nil {
t.Fatal(err.Error())
}

token, err := jwtManager.GenerateTenantToken(tenantId, "test token")

if err != nil {
t.Fatal(err.Error())
}

// validate the token
_, err = jwtManager.ValidateTenantToken(token)

assert.NoError(t, err)

// revoke the token
apiTokens, err := conf.Repository.APIToken().ListAPITokensByTenant(tenantId)

if err != nil {
t.Fatal(err.Error())
}

assert.Len(t, apiTokens, 1)
err = conf.Repository.APIToken().RevokeAPIToken(apiTokens[0].ID)

if err != nil {
t.Fatal(err.Error())
}

// validate the token again
_, err = jwtManager.ValidateTenantToken(token)

// no error as it is cached
assert.NoError(t, err)

return nil
})
}

func getJWTManager(t *testing.T, conf *database.Config) token.JWTManager {
t.Helper()

Expand Down
130 changes: 130 additions & 0 deletions internal/cache/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package cache

import (
"sync"
"time"
)

// item represents a cache item with a value and an expiration time.
type item[V any] struct {
value V
expiry time.Time
}

// isExpired checks if the cache item has expired.
func (i item[V]) isExpired() bool {
return time.Now().After(i.expiry)
}

// TTLCache is a generic cache implementation with support for time-to-live
// (TTL) expiration.
type TTLCache[K comparable, V any] struct {
items map[K]item[V] // The map storing cache items.
mu sync.Mutex // Mutex for controlling concurrent access to the cache.
stop chan interface{} // Channel to stop the goroutine that removes expired items.
}

// NewTTL creates a new TTLCache instance and starts a goroutine to periodically
// remove expired items every 5 seconds.
func NewTTL[K comparable, V any]() *TTLCache[K, V] {
c := &TTLCache[K, V]{
items: make(map[K]item[V]),
stop: make(chan interface{}),
}

go func() {
// Create a new ticker to remove expired items every 5 seconds.
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()

for {
select {
case <-c.stop:
return
case <-ticker.C:
c.mu.Lock()

// Iterate over the cache items and delete expired ones.
for key, item := range c.items {
if item.isExpired() {
delete(c.items, key)
}
}

c.mu.Unlock()
}
}
}()

return c
}

func (c *TTLCache[K, V]) Stop() {
close(c.stop)
}

// Set adds a new item to the cache with the specified key, value, and
// time-to-live (TTL).
func (c *TTLCache[K, V]) Set(key K, value V, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()

c.items[key] = item[V]{
value: value,
expiry: time.Now().Add(ttl),
}
}

// Get retrieves the value associated with the given key from the cache.
func (c *TTLCache[K, V]) Get(key K) (V, bool) {
c.mu.Lock()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to add a benchmark test to see how many Gets/second we can handle. Don't imagine this will be a bottleneck anytime soon but would still be good to have a sense.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, tracked via #282

defer c.mu.Unlock()

item, found := c.items[key]
if !found {
// If the key is not found, return the zero value for V and false.
return item.value, false
}

if item.isExpired() {
// If the item has expired, remove it from the cache and return the
// value and false.
delete(c.items, key)
return item.value, false
}

// Otherwise return the value and true.
return item.value, true
}

// Remove removes the item with the specified key from the cache.
func (c *TTLCache[K, V]) Remove(key K) {
c.mu.Lock()
defer c.mu.Unlock()

// Delete the item with the given key from the cache.
delete(c.items, key)
}

// Pop removes and returns the item with the specified key from the cache.
func (c *TTLCache[K, V]) Pop(key K) (V, bool) {
c.mu.Lock()
defer c.mu.Unlock()

item, found := c.items[key]
if !found {
// If the key is not found, return the zero value for V and false.
return item.value, false
}

// If the key is found, delete the item from the cache.
delete(c.items, key)

if item.isExpired() {
// If the item has expired, return the value and false.
return item.value, false
}

// Otherwise return the value and true.
return item.value, true
}
6 changes: 6 additions & 0 deletions internal/config/database/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package database

import (
"time"

"github.com/spf13/viper"

"github.com/hatchet-dev/hatchet/internal/config/shared"
Expand All @@ -20,6 +22,8 @@ type ConfigFile struct {
Logger shared.LoggerConfigFile `mapstructure:"logger" json:"logger,omitempty"`

LogQueries bool `mapstructure:"logQueries" json:"logQueries,omitempty" default:"false"`

CacheDuration time.Duration `mapstructure:"cacheDuration" json:"cacheDuration,omitempty" default:"60s"`
}

type SeedConfigFile struct {
Expand Down Expand Up @@ -51,6 +55,8 @@ func BindAllEnv(v *viper.Viper) {
_ = v.BindEnv("sslMode", "DATABASE_POSTGRES_SSL_MODE")
_ = v.BindEnv("logQueries", "DATABASE_LOG_QUERIES")

_ = v.BindEnv("cacheDuration", "CACHE_DURATION")

_ = v.BindEnv("seed.adminEmail", "ADMIN_EMAIL")
_ = v.BindEnv("seed.adminPassword", "ADMIN_PASSWORD")
_ = v.BindEnv("seed.adminName", "ADMIN_NAME")
Expand Down
23 changes: 13 additions & 10 deletions internal/config/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ import (
"path/filepath"
"strings"

"github.com/exaring/otelpgx"
pgxzero "github.com/jackc/pgx-zerolog"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/tracelog"

pgxzero "github.com/jackc/pgx-zerolog"

"github.com/hatchet-dev/hatchet/internal/auth/cookie"
"github.com/hatchet-dev/hatchet/internal/auth/oauth"
"github.com/hatchet-dev/hatchet/internal/auth/token"
Expand All @@ -27,15 +27,14 @@ import (
"github.com/hatchet-dev/hatchet/internal/integrations/vcs/github"
"github.com/hatchet-dev/hatchet/internal/logger"
"github.com/hatchet-dev/hatchet/internal/msgqueue/rabbitmq"
"github.com/hatchet-dev/hatchet/internal/repository/cache"
"github.com/hatchet-dev/hatchet/internal/repository/prisma"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/services/ingestor"
"github.com/hatchet-dev/hatchet/internal/validator"
"github.com/hatchet-dev/hatchet/pkg/client"
"github.com/hatchet-dev/hatchet/pkg/errors"
"github.com/hatchet-dev/hatchet/pkg/errors/sentry"

"github.com/exaring/otelpgx"
)

// LoadDatabaseConfigFile loads the database config file via viper
Expand Down Expand Up @@ -121,11 +120,10 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile) (res *database.Con
cf.PostgresSSLMode,
)

os.Setenv("DATABASE_URL", databaseUrl)
// TODO db.WithDatasourceURL(databaseUrl) is not working
_ = os.Setenv("DATABASE_URL", databaseUrl)

c := db.NewClient(
// db.WithDatasourceURL(databaseUrl),
)
c := db.NewClient()

if err := c.Prisma.Connect(); err != nil {
return nil, err
Expand All @@ -152,9 +150,14 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile) (res *database.Con
return nil, fmt.Errorf("could not connect to database: %w", err)
}

ch := cache.New(cf.CacheDuration)

return &database.Config{
Disconnect: c.Prisma.Disconnect,
Repository: prisma.NewPrismaRepository(c, pool, prisma.WithLogger(&l)),
Disconnect: func() error {
ch.Stop()
return c.Prisma.Disconnect()
},
Repository: prisma.NewPrismaRepository(c, pool, prisma.WithLogger(&l), prisma.WithCache(ch)),
Seed: cf.Seed,
}, nil
}
Expand Down
Loading
Loading