Skip to content

Commit

Permalink
Clean up getReader
Browse files Browse the repository at this point in the history
  • Loading branch information
DylanTinianov committed Nov 15, 2024
1 parent ca7c982 commit e671b69
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 29 deletions.
32 changes: 19 additions & 13 deletions pkg/solana/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,10 @@ func TestGetState(t *testing.T) {
}))
defer mockServer.Close()

reader := func() (client.AccountReader, error) { return testSetupReader(t, mockServer.URL), nil }
reader := testSetupReader(t, mockServer.URL)
getReader := func() (client.AccountReader, error) { return reader, nil }
// happy path does not error (actual state decoding handled in types_test)
_, _, err := GetState(context.TODO(), reader, solana.PublicKey{}, "")
_, _, err := GetState(context.TODO(), getReader, solana.PublicKey{}, "")
require.NoError(t, err)
}

Expand All @@ -133,18 +134,19 @@ func TestGetLatestTransmission(t *testing.T) {
}))
defer mockServer.Close()

reader := func() (client.AccountReader, error) { return testSetupReader(t, mockServer.URL), nil }
a, _, err := GetLatestTransmission(context.TODO(), reader, solana.PublicKey{}, "")
reader := testSetupReader(t, mockServer.URL)
getReader := func() (client.AccountReader, error) { return reader, nil }
a, _, err := GetLatestTransmission(context.TODO(), getReader, solana.PublicKey{}, "")
assert.NoError(t, err)
assert.Equal(t, expectedTime, a.Timestamp)
assert.Equal(t, expectedAns, a.Data.String())

// fail if returned transmission header is too short
_, _, err = GetLatestTransmission(context.TODO(), reader, solana.PublicKey{}, "")
_, _, err = GetLatestTransmission(context.TODO(), getReader, solana.PublicKey{}, "")
assert.Error(t, err)

// fail if returned transmission is too short
_, _, err = GetLatestTransmission(context.TODO(), reader, solana.PublicKey{}, "")
_, _, err = GetLatestTransmission(context.TODO(), getReader, solana.PublicKey{}, "")
assert.Error(t, err)
}

Expand All @@ -167,14 +169,16 @@ func TestCache(t *testing.T) {
w.Write(testTransmissionsResponse(t, body, 0)) //nolint:errcheck
}))

reader := func() (client.Reader, error) { return testSetupReader(t, mockServer.URL), nil }
reader := testSetupReader(t, mockServer.URL)
getReader := func() (client.Reader, error) { return reader, nil }
getAccountReader := func() (client.AccountReader, error) { return reader, nil }

lggr := logger.Test(t)
stateCache := NewStateCache(
solana.MustPublicKeyFromBase58("11111111111111111111111111111111"),
"test-chain-id",
config.NewDefault(),
reader,
getReader,
lggr,
)
require.NoError(t, stateCache.Start(ctx))
Expand All @@ -189,7 +193,7 @@ func TestCache(t *testing.T) {
solana.MustPublicKeyFromBase58("11111111111111111111111111111112"),
"test-chain-id",
config.NewDefault(),
reader,
getAccountReader,
lggr,
)
require.NoError(t, transmissionsCache.Start(ctx))
Expand Down Expand Up @@ -223,17 +227,19 @@ func TestNilPointerHandling(t *testing.T) {
defer mockServer.Close()

errString := "nil pointer returned in "
reader := func() (client.AccountReader, error) { return testSetupReader(t, mockServer.URL), nil }

reader := testSetupReader(t, mockServer.URL)
getReader := func() (client.AccountReader, error) { return reader, nil }

// fail on get state query
_, _, err := GetState(context.TODO(), reader, solana.PublicKey{}, "")
_, _, err := GetState(context.TODO(), getReader, solana.PublicKey{}, "")
assert.EqualError(t, err, errString+"GetState.GetAccountInfoWithOpts")

// fail on transmissions header query
_, _, err = GetLatestTransmission(context.TODO(), reader, solana.PublicKey{}, "")
_, _, err = GetLatestTransmission(context.TODO(), getReader, solana.PublicKey{}, "")
assert.EqualError(t, err, errString+"GetLatestTransmission.GetAccountInfoWithOpts.Header")

passFirst = true // allow proper response for header query, fail on transmission
_, _, err = GetLatestTransmission(context.TODO(), reader, solana.PublicKey{}, "")
_, _, err = GetLatestTransmission(context.TODO(), getReader, solana.PublicKey{}, "")
assert.EqualError(t, err, errString+"GetLatestTransmission.GetAccountInfoWithOpts.Transmission")
}
3 changes: 1 addition & 2 deletions pkg/solana/client/multinode/poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ func (p *Poller[T]) Err() <-chan error {
}

func (p *Poller[T]) pollingLoop(ctx context.Context) {
tickerCfg := services.TickerConfig{Initial: 0, JitterPct: services.DefaultJitter}
ticker := tickerCfg.NewTicker(p.pollingInterval)
ticker := services.NewTicker(p.pollingInterval)
defer ticker.Stop()

for {
Expand Down
4 changes: 2 additions & 2 deletions pkg/solana/config_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

type ConfigTracker struct {
stateCache *StateCache
reader GetReader
getReader GetReader
}

func (c *ConfigTracker) Notify() <-chan struct{} {
Expand Down Expand Up @@ -73,7 +73,7 @@ func (c *ConfigTracker) LatestConfig(ctx context.Context, changedInBlock uint64)

// LatestBlockHeight returns the height of the most recent block in the chain.
func (c *ConfigTracker) LatestBlockHeight(ctx context.Context) (blockHeight uint64, err error) {
reader, err := c.reader()
reader, err := c.getReader()
if err != nil {
return 0, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/solana/config_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestLatestBlockHeight(t *testing.T) {

ctx := context.Background()
c := &ConfigTracker{
reader: func() (client.Reader, error) { return testSetupReader(t, mockServer.URL), nil },
getReader: func() (client.Reader, error) { return testSetupReader(t, mockServer.URL), nil },
}

h, err := c.LatestBlockHeight(ctx)
Expand Down
7 changes: 3 additions & 4 deletions pkg/solana/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ func (r *Relayer) NewMedianProvider(ctx context.Context, rargs relaytypes.RelayA
}

cfg := configWatcher.chain.Config()
transmissionsCache := NewTransmissionsCache(transmissionsID, relayConfig.ChainID, cfg, configWatcher.chain.Reader, r.lggr)
getReader := func() (client.AccountReader, error) { return configWatcher.chain.Reader() }
transmissionsCache := NewTransmissionsCache(transmissionsID, relayConfig.ChainID, cfg, getReader, r.lggr)
return &medianProvider{
configProvider: configWatcher,
transmissionsCache: transmissionsCache,
Expand Down Expand Up @@ -187,8 +188,6 @@ func (r *Relayer) NewAutomationProvider(ctx context.Context, rargs relaytypes.Re

var _ relaytypes.ConfigProvider = &configProvider{}

type GetReader func() (client.Reader, error)

type configProvider struct {
services.StateMachine
chainID string
Expand Down Expand Up @@ -231,7 +230,7 @@ func newConfigProvider(_ context.Context, lggr logger.Logger, chain Chain, args
storeProgramID: storeProgramID,
stateCache: stateCache,
offchainConfigDigester: offchainConfigDigester,
configTracker: &ConfigTracker{stateCache: stateCache, reader: chain.Reader},
configTracker: &ConfigTracker{stateCache: stateCache, getReader: chain.Reader},
chain: chain,
}, nil
}
Expand Down
9 changes: 5 additions & 4 deletions pkg/solana/state_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,20 @@ type StateCache struct {
*client.Cache[State]
}

type GetReader func() (client.Reader, error)
type GetAccountReader func() (client.AccountReader, error)

func NewStateCache(stateID solana.PublicKey, chainID string, cfg config.Config, reader GetReader, lggr logger.Logger) *StateCache {
func NewStateCache(stateID solana.PublicKey, chainID string, cfg config.Config, getReader GetReader, lggr logger.Logger) *StateCache {
name := "ocr2_median_state"
getter := func(ctx context.Context) (State, uint64, error) {
getAccountReader := func() (client.AccountReader, error) { return reader() }
getAccountReader := func() (client.AccountReader, error) { return getReader() }
return GetState(ctx, getAccountReader, stateID, cfg.Commitment())
}
return &StateCache{client.NewCache(name, stateID, chainID, cfg, getter, logger.With(lggr, "cache", name))}
}

func GetState(ctx context.Context, reader GetAccountReader, account solana.PublicKey, commitment rpc.CommitmentType) (State, uint64, error) {
r, err := reader()
func GetState(ctx context.Context, getReader GetAccountReader, account solana.PublicKey, commitment rpc.CommitmentType) (State, uint64, error) {
r, err := getReader()
if err != nil {
return State{}, 0, fmt.Errorf("failed to get reader: %w", err)
}
Expand Down
5 changes: 2 additions & 3 deletions pkg/solana/transmissions_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ type TransmissionsCache struct {
*client.Cache[Answer]
}

func NewTransmissionsCache(transmissionsID solana.PublicKey, chainID string, cfg config.Config, getReader GetReader, lggr logger.Logger) *TransmissionsCache {
func NewTransmissionsCache(transmissionsID solana.PublicKey, chainID string, cfg config.Config, getReader GetAccountReader, lggr logger.Logger) *TransmissionsCache {
name := "ocr2_median_transmissions"
getter := func(ctx context.Context) (Answer, uint64, error) {
getAccountReader := func() (client.AccountReader, error) { return getReader() }
return GetLatestTransmission(ctx, getAccountReader, transmissionsID, cfg.Commitment())
return GetLatestTransmission(ctx, getReader, transmissionsID, cfg.Commitment())
}
return &TransmissionsCache{client.NewCache(name, transmissionsID, chainID, cfg, getter, logger.With(lggr, "cache", name))}
}
Expand Down

0 comments on commit e671b69

Please sign in to comment.