diff --git a/internal/node/node.go b/internal/node/node.go index 6ffe29bf5..e94cfbb4d 100644 --- a/internal/node/node.go +++ b/internal/node/node.go @@ -194,7 +194,7 @@ func NewBus(cfg BusConfig, dir string, seed types.PrivateKey, l *zap.Logger) (ht func NewWorker(cfg config.Worker, b worker.Bus, seed types.PrivateKey, l *zap.Logger) (http.Handler, ShutdownFn, error) { workerKey := blake2b.Sum256(append([]byte("worker"), seed...)) - w, err := worker.New(workerKey, cfg.ID, b, cfg.ContractLockTimeout, cfg.BusFlushInterval, cfg.DownloadOverdriveTimeout, cfg.UploadOverdriveTimeout, cfg.DownloadMaxOverdrive, cfg.DownloadMaxMemory, cfg.UploadMaxMemory, cfg.UploadMaxOverdrive, cfg.AllowPrivateIPs, l) + w, err := worker.New(workerKey, cfg.ID, b, cfg.ContractLockTimeout, cfg.BusFlushInterval, cfg.DownloadOverdriveTimeout, cfg.UploadOverdriveTimeout, cfg.DownloadMaxOverdrive, cfg.UploadMaxOverdrive, cfg.DownloadMaxMemory, cfg.UploadMaxMemory, cfg.AllowPrivateIPs, l) if err != nil { return nil, nil, err } diff --git a/worker/download.go b/worker/download.go index 462a2292d..288db7f7c 100644 --- a/worker/download.go +++ b/worker/download.go @@ -132,7 +132,7 @@ func (w *worker) initDownloadManager(maxMemory, maxOverdrive uint64, overdriveTi panic("download manager already initialized") // developer error } - mm := newMemoryManager(logger, maxMemory) + mm := newMemoryManager(logger.Named("memorymanager"), maxMemory) w.downloadManager = newDownloadManager(w.shutdownCtx, w, mm, w.bus, maxOverdrive, overdriveTimeout, logger) } diff --git a/worker/downloader_test.go b/worker/downloader_test.go index c1d860c24..cbb48132c 100644 --- a/worker/downloader_test.go +++ b/worker/downloader_test.go @@ -7,11 +7,15 @@ import ( ) func TestDownloaderStopped(t *testing.T) { - w := newMockWorker() - h := w.addHost() - w.dl.refreshDownloaders(w.contracts()) + w := newTestWorker(t) + hosts := w.addHosts(1) - dl := w.dl.downloaders[h.PublicKey()] + // convenience variables + dm := w.downloadManager + h := hosts[0] + + dm.refreshDownloaders(w.contracts()) + dl := w.downloadManager.downloaders[h.PublicKey()] dl.Stop() req := sectorDownloadReq{ diff --git a/worker/gouging.go b/worker/gouging.go index 36963e24a..19ae177aa 100644 --- a/worker/gouging.go +++ b/worker/gouging.go @@ -63,7 +63,7 @@ func GougingCheckerFromContext(ctx context.Context, criticalMigration bool) (Gou return gc(criticalMigration) } -func WithGougingChecker(ctx context.Context, cs consensusState, gp api.GougingParams) context.Context { +func WithGougingChecker(ctx context.Context, cs ConsensusState, gp api.GougingParams) context.Context { return context.WithValue(ctx, keyGougingChecker, func(criticalMigration bool) (GougingChecker, error) { consensusState, err := cs.ConsensusState(ctx) if err != nil { diff --git a/worker/host.go b/worker/host.go index 43e0891af..ac8925872 100644 --- a/worker/host.go +++ b/worker/host.go @@ -35,10 +35,6 @@ type ( HostManager interface { Host(hk types.PublicKey, fcid types.FileContractID, siamuxAddr string) Host } - - HostStore interface { - Host(ctx context.Context, hostKey types.PublicKey) (hostdb.HostInfo, error) - } ) type ( diff --git a/worker/host_test.go b/worker/host_test.go index 87d35fb36..3703ce810 100644 --- a/worker/host_test.go +++ b/worker/host_test.go @@ -4,18 +4,130 @@ import ( "bytes" "context" "errors" + "io" + "sync" "testing" + "time" rhpv2 "go.sia.tech/core/rhp/v2" + rhpv3 "go.sia.tech/core/rhp/v3" "go.sia.tech/core/types" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/hostdb" + "lukechampine.com/frand" ) +type ( + testHost struct { + *hostMock + *contractMock + hptFn func() hostdb.HostPriceTable + } + + testHostManager struct { + t *testing.T + + mu sync.Mutex + hosts map[types.PublicKey]*testHost + } +) + +func newTestHostManager(t *testing.T) *testHostManager { + return &testHostManager{t: t, hosts: make(map[types.PublicKey]*testHost)} +} + +func (hm *testHostManager) Host(hk types.PublicKey, fcid types.FileContractID, siamuxAddr string) Host { + hm.mu.Lock() + defer hm.mu.Unlock() + + if _, ok := hm.hosts[hk]; !ok { + hm.t.Fatal("host not found") + } + return hm.hosts[hk] +} + +func (hm *testHostManager) addHost(hosts ...*testHost) { + hm.mu.Lock() + defer hm.mu.Unlock() + for _, h := range hosts { + hm.hosts[h.hk] = h + } +} + +func newTestHost(h *hostMock, c *contractMock) *testHost { + return newTestHostCustom(h, c, func() hostdb.HostPriceTable { return newTestHostPriceTable(time.Now().Add(time.Minute)) }) +} + +func newTestHostCustom(h *hostMock, c *contractMock, hptFn func() hostdb.HostPriceTable) *testHost { + return &testHost{ + hostMock: h, + contractMock: c, + hptFn: hptFn, + } +} + +func newTestHostPriceTable(expiry time.Time) hostdb.HostPriceTable { + var uid rhpv3.SettingsID + frand.Read(uid[:]) + + return hostdb.HostPriceTable{ + HostPriceTable: rhpv3.HostPriceTable{UID: uid, Validity: time.Minute}, + Expiry: expiry, + } +} + +func (h *testHost) PublicKey() types.PublicKey { + return h.hk +} + +func (h *testHost) DownloadSector(ctx context.Context, w io.Writer, root types.Hash256, offset, length uint32, overpay bool) error { + sector, exist := h.Sector(root) + if !exist { + return errSectorNotFound + } + if offset+length > rhpv2.SectorSize { + return errSectorOutOfBounds + } + _, err := w.Write(sector[offset : offset+length]) + return err +} + +func (h *testHost) UploadSector(ctx context.Context, sector *[rhpv2.SectorSize]byte, rev types.FileContractRevision) (types.Hash256, error) { + return h.AddSector(sector), nil +} + +func (h *testHost) FetchRevision(ctx context.Context, fetchTimeout time.Duration) (rev types.FileContractRevision, _ error) { + h.mu.Lock() + defer h.mu.Unlock() + rev = h.rev + return rev, nil +} + +func (h *testHost) FetchPriceTable(ctx context.Context, rev *types.FileContractRevision) (hostdb.HostPriceTable, error) { + return h.hptFn(), nil +} + +func (h *testHost) FundAccount(ctx context.Context, balance types.Currency, rev *types.FileContractRevision) error { + return nil +} + +func (h *testHost) RenewContract(ctx context.Context, rrr api.RHPRenewRequest) (_ rhpv2.ContractRevision, _ []types.Transaction, _ types.Currency, err error) { + return rhpv2.ContractRevision{}, nil, types.ZeroCurrency, nil +} + +func (h *testHost) SyncAccount(ctx context.Context, rev *types.FileContractRevision) error { + return nil +} + func TestHost(t *testing.T) { - h := newMockHost(types.PublicKey{1}) - h.c = newMockContract(h.hk, types.FileContractID{1}) - sector, root := newMockSector() + // create test host + h := newTestHost( + newHostMock(types.PublicKey{1}), + newContractMock(types.PublicKey{1}, types.FileContractID{1}), + ) // upload the sector + sector, root := newTestSector() uploaded, err := h.UploadSector(context.Background(), sector, types.FileContractRevision{}) if err != nil { t.Fatal(err) diff --git a/worker/mocks_test.go b/worker/mocks_test.go index 2490941af..63d52096a 100644 --- a/worker/mocks_test.go +++ b/worker/mocks_test.go @@ -5,222 +5,398 @@ import ( "encoding/json" "errors" "fmt" - "io" + "math/big" "sync" "time" rhpv2 "go.sia.tech/core/rhp/v2" rhpv3 "go.sia.tech/core/rhp/v3" "go.sia.tech/core/types" + "go.sia.tech/renterd/alerts" "go.sia.tech/renterd/api" "go.sia.tech/renterd/hostdb" "go.sia.tech/renterd/object" - "go.uber.org/zap" - "lukechampine.com/frand" + "go.sia.tech/renterd/webhooks" ) -type ( - mockContract struct { - rev types.FileContractRevision - metadata api.ContractMetadata +var _ AccountStore = (*accountsMock)(nil) - mu sync.Mutex - sectors map[types.Hash256]*[rhpv2.SectorSize]byte - } +type accountsMock struct{} - mockContractStore struct { - mu sync.Mutex - locks map[types.FileContractID]*sync.Mutex - } +func (*accountsMock) Accounts(context.Context) ([]api.Account, error) { + return nil, nil +} - mockHost struct { - hk types.PublicKey +func (*accountsMock) AddBalance(context.Context, rhpv3.Account, types.PublicKey, *big.Int) error { + return nil +} - mu sync.Mutex - c *mockContract +func (*accountsMock) LockAccount(context.Context, rhpv3.Account, types.PublicKey, bool, time.Duration) (api.Account, uint64, error) { + return api.Account{}, 0, nil +} - hpt hostdb.HostPriceTable - hptBlockChan chan struct{} - } +func (*accountsMock) UnlockAccount(context.Context, rhpv3.Account, uint64) error { + return nil +} - mockHostManager struct { - mu sync.Mutex - hosts map[types.PublicKey]*mockHost - } +func (*accountsMock) ResetDrift(context.Context, rhpv3.Account) error { + return nil +} - mockMemory struct{} - mockMemoryManager struct { - memBlockChan chan struct{} - } +func (*accountsMock) SetBalance(context.Context, rhpv3.Account, types.PublicKey, *big.Int) error { + return nil +} - mockObjectStore struct { - mu sync.Mutex - objects map[string]map[string]object.Object - partials map[string]mockPackedSlab - bufferIDCntr uint // allows marking packed slabs as uploaded - } +func (*accountsMock) ScheduleSync(context.Context, rhpv3.Account, types.PublicKey) error { + return nil +} - mockPackedSlab struct { - parameterKey string // ([minshards]-[totalshards]-[contractset]) - bufferID uint - slabKey object.EncryptionKey - data []byte - } +var _ alerts.Alerter = (*alerterMock)(nil) - mockWorker struct { - cs *mockContractStore - hm *mockHostManager - mm *mockMemoryManager - os *mockObjectStore +type alerterMock struct{} - dl *downloadManager - ul *uploadManager +func (*alerterMock) RegisterAlert(context.Context, alerts.Alert) error { return nil } +func (*alerterMock) DismissAlerts(context.Context, ...types.Hash256) error { return nil } +func (*alerterMock) DismissAllAlerts(context.Context) error { return nil } - mu sync.Mutex - hkCntr uint - fcidCntr uint - } -) +var _ ConsensusState = (*chainMock)(nil) -var ( - _ ContractStore = (*mockContractStore)(nil) - _ Host = (*mockHost)(nil) - _ HostManager = (*mockHostManager)(nil) - _ Memory = (*mockMemory)(nil) - _ MemoryManager = (*mockMemoryManager)(nil) - _ ObjectStore = (*mockObjectStore)(nil) -) +type chainMock struct{} -var ( - errBucketNotFound = errors.New("bucket not found") - errContractNotFound = errors.New("contract not found") - errObjectNotFound = errors.New("object not found") - errSlabNotFound = errors.New("slab not found") - errSectorOutOfBounds = errors.New("sector out of bounds") -) +func (c *chainMock) ConsensusState(ctx context.Context) (api.ConsensusState, error) { + return api.ConsensusState{}, nil +} -type ( - mockHosts []*mockHost - mockContracts []*mockContract -) +var _ Bus = (*busMock)(nil) -func (hosts mockHosts) contracts() mockContracts { - contracts := make([]*mockContract, len(hosts)) - for i, host := range hosts { - contracts[i] = host.c - } - return contracts +type busMock struct { + *alerterMock + *accountsMock + *chainMock + *contractLockerMock + *contractStoreMock + *hostStoreMock + *objectStoreMock + *settingStoreMock + *syncerMock + *walletMock + *webhookBroadcasterMock } -func (contracts mockContracts) metadata() []api.ContractMetadata { - metadata := make([]api.ContractMetadata, len(contracts)) - for i, contract := range contracts { - metadata[i] = contract.metadata +func newBusMock(cs *contractStoreMock, hs *hostStoreMock, os *objectStoreMock) *busMock { + return &busMock{ + alerterMock: &alerterMock{}, + accountsMock: &accountsMock{}, + chainMock: &chainMock{}, + contractLockerMock: newContractLockerMock(), + contractStoreMock: cs, + hostStoreMock: hs, + objectStoreMock: os, + settingStoreMock: &settingStoreMock{}, + syncerMock: &syncerMock{}, + walletMock: &walletMock{}, + webhookBroadcasterMock: &webhookBroadcasterMock{}, } - return metadata } -func (m *mockMemory) Release() {} -func (m *mockMemory) ReleaseSome(uint64) {} +type contractMock struct { + rev types.FileContractRevision + metadata api.ContractMetadata -func (mm *mockMemoryManager) Limit(amt uint64) (MemoryManager, error) { - return &mockMemoryManager{}, nil + mu sync.Mutex + sectors map[types.Hash256]*[rhpv2.SectorSize]byte } -func (mm *mockMemoryManager) Status() api.MemoryStatus { return api.MemoryStatus{} } -func (mm *mockMemoryManager) AcquireMemory(ctx context.Context, amt uint64) Memory { - if mm.memBlockChan != nil { - <-mm.memBlockChan + +func newContractMock(hk types.PublicKey, fcid types.FileContractID) *contractMock { + return &contractMock{ + metadata: api.ContractMetadata{ + ID: fcid, + HostKey: hk, + WindowStart: 0, + WindowEnd: 10, + }, + rev: types.FileContractRevision{ParentID: fcid}, + sectors: make(map[types.Hash256]*[rhpv2.SectorSize]byte), } - return &mockMemory{} } -func newMockContractStore() *mockContractStore { - return &mockContractStore{ +func (c *contractMock) AddSector(sector *[rhpv2.SectorSize]byte) (root types.Hash256) { + root = rhpv2.SectorRoot(sector) + c.mu.Lock() + c.sectors[root] = sector + c.mu.Unlock() + return +} + +func (c *contractMock) Sector(root types.Hash256) (sector *[rhpv2.SectorSize]byte, found bool) { + c.mu.Lock() + sector, found = c.sectors[root] + c.mu.Unlock() + return +} + +var _ ContractLocker = (*contractLockerMock)(nil) + +type contractLockerMock struct { + mu sync.Mutex + locks map[types.FileContractID]*sync.Mutex +} + +func newContractLockerMock() *contractLockerMock { + return &contractLockerMock{ locks: make(map[types.FileContractID]*sync.Mutex), } } -func (cs *mockContractStore) AcquireContract(ctx context.Context, fcid types.FileContractID, priority int, d time.Duration) (lockID uint64, err error) { +func (cs *contractLockerMock) AcquireContract(_ context.Context, fcid types.FileContractID, _ int, _ time.Duration) (uint64, error) { cs.mu.Lock() defer cs.mu.Unlock() - if lock, ok := cs.locks[fcid]; !ok { - return 0, errContractNotFound - } else { - lock.Lock() + lock, exists := cs.locks[fcid] + if !exists { + cs.locks[fcid] = new(sync.Mutex) + lock = cs.locks[fcid] } + + lock.Lock() return 0, nil } -func (cs *mockContractStore) ReleaseContract(ctx context.Context, fcid types.FileContractID, lockID uint64) (err error) { +func (cs *contractLockerMock) ReleaseContract(_ context.Context, fcid types.FileContractID, _ uint64) error { cs.mu.Lock() defer cs.mu.Unlock() - if lock, ok := cs.locks[fcid]; !ok { - return errContractNotFound - } else { - lock.Unlock() - } + cs.locks[fcid].Unlock() + delete(cs.locks, fcid) return nil } -func (cs *mockContractStore) KeepaliveContract(ctx context.Context, fcid types.FileContractID, lockID uint64, d time.Duration) (err error) { +func (*contractLockerMock) KeepaliveContract(context.Context, types.FileContractID, uint64, time.Duration) error { return nil } -func (os *mockContractStore) RenewedContract(ctx context.Context, renewedFrom types.FileContractID) (api.ContractMetadata, error) { - return api.ContractMetadata{}, api.ErrContractNotFound +var _ ContractStore = (*contractStoreMock)(nil) + +type contractStoreMock struct { + mu sync.Mutex + contracts map[types.FileContractID]*contractMock + hosts2fcid map[types.PublicKey]types.FileContractID + fcidCntr uint } -func newMockObjectStore() *mockObjectStore { - os := &mockObjectStore{ - objects: make(map[string]map[string]object.Object), - partials: make(map[string]mockPackedSlab), +func newContractStoreMock() *contractStoreMock { + return &contractStoreMock{ + contracts: make(map[types.FileContractID]*contractMock), + hosts2fcid: make(map[types.PublicKey]types.FileContractID), } - os.objects[testBucket] = make(map[string]object.Object) - return os } -func (cs *mockContractStore) addContract(c *mockContract) { +func (*contractStoreMock) RenewedContract(context.Context, types.FileContractID) (api.ContractMetadata, error) { + return api.ContractMetadata{}, nil +} + +func (*contractStoreMock) Contract(context.Context, types.FileContractID) (api.ContractMetadata, error) { + return api.ContractMetadata{}, nil +} + +func (*contractStoreMock) ContractSize(context.Context, types.FileContractID) (api.ContractSize, error) { + return api.ContractSize{}, nil +} + +func (*contractStoreMock) ContractRoots(context.Context, types.FileContractID) ([]types.Hash256, []types.Hash256, error) { + return nil, nil, nil +} + +func (cs *contractStoreMock) Contracts(context.Context, api.ContractsOpts) (metadatas []api.ContractMetadata, _ error) { cs.mu.Lock() defer cs.mu.Unlock() - cs.locks[c.metadata.ID] = new(sync.Mutex) + for _, c := range cs.contracts { + metadatas = append(metadatas, c.metadata) + } + return } -func (os *mockObjectStore) AddMultipartPart(ctx context.Context, bucket, path, contractSet, ETag, uploadID string, partNumber int, slices []object.SlabSlice) (err error) { +func (cs *contractStoreMock) addContract(hk types.PublicKey) *contractMock { + cs.mu.Lock() + defer cs.mu.Unlock() + + fcid := cs.newFileContractID() + cs.contracts[fcid] = newContractMock(hk, fcid) + cs.hosts2fcid[hk] = fcid + return cs.contracts[fcid] +} + +func (cs *contractStoreMock) renewContract(hk types.PublicKey) (*contractMock, error) { + cs.mu.Lock() + defer cs.mu.Unlock() + + curr := cs.hosts2fcid[hk] + c := cs.contracts[curr] + if c == nil { + return nil, errors.New("host does not have a contract to renew") + } + delete(cs.contracts, curr) + + renewal := newContractMock(hk, cs.newFileContractID()) + renewal.metadata.RenewedFrom = c.metadata.ID + renewal.metadata.WindowStart = c.metadata.WindowEnd + renewal.metadata.WindowEnd = renewal.metadata.WindowStart + (c.metadata.WindowEnd - c.metadata.WindowStart) + cs.contracts[renewal.metadata.ID] = renewal + cs.hosts2fcid[hk] = renewal.metadata.ID + return renewal, nil +} + +func (cs *contractStoreMock) newFileContractID() types.FileContractID { + cs.fcidCntr++ + return types.FileContractID{byte(cs.fcidCntr)} +} + +var errSectorOutOfBounds = errors.New("sector out of bounds") + +type hostMock struct { + hk types.PublicKey + hi hostdb.HostInfo +} + +func newHostMock(hk types.PublicKey) *hostMock { + return &hostMock{ + hk: hk, + hi: hostdb.HostInfo{Host: hostdb.Host{PublicKey: hk, Scanned: true}}, + } +} + +var _ HostStore = (*hostStoreMock)(nil) + +type hostStoreMock struct { + mu sync.Mutex + hosts map[types.PublicKey]*hostMock + hkCntr uint +} + +func newHostStoreMock() *hostStoreMock { + return &hostStoreMock{hosts: make(map[types.PublicKey]*hostMock)} +} + +func (hs *hostStoreMock) Host(ctx context.Context, hostKey types.PublicKey) (hostdb.HostInfo, error) { + hs.mu.Lock() + defer hs.mu.Unlock() + + h, ok := hs.hosts[hostKey] + if !ok { + return hostdb.HostInfo{}, api.ErrHostNotFound + } + return h.hi, nil +} + +func (hs *hostStoreMock) RecordHostScans(ctx context.Context, scans []hostdb.HostScan) error { + return nil +} + +func (hs *hostStoreMock) RecordPriceTables(ctx context.Context, priceTableUpdate []hostdb.PriceTableUpdate) error { return nil } -func (os *mockObjectStore) AddUploadingSector(ctx context.Context, uID api.UploadID, id types.FileContractID, root types.Hash256) error { +func (hs *hostStoreMock) RecordContractSpending(ctx context.Context, records []api.ContractSpendingRecord) error { return nil } -func (os *mockObjectStore) TrackUpload(ctx context.Context, uID api.UploadID) error { return nil } +func (hs *hostStoreMock) addHost() *hostMock { + hs.mu.Lock() + defer hs.mu.Unlock() -func (os *mockObjectStore) FinishUpload(ctx context.Context, uID api.UploadID) error { return nil } + hs.hkCntr++ + hk := types.PublicKey{byte(hs.hkCntr)} + hs.hosts[hk] = newHostMock(hk) + return hs.hosts[hk] +} -func (os *mockObjectStore) DeleteHostSector(ctx context.Context, hk types.PublicKey, root types.Hash256) error { +var ( + _ MemoryManager = (*memoryManagerMock)(nil) + _ Memory = (*memoryMock)(nil) +) + +type ( + memoryMock struct{} + memoryManagerMock struct{ memBlockChan chan struct{} } +) + +func (m *memoryMock) Release() {} +func (m *memoryMock) ReleaseSome(uint64) {} + +func (mm *memoryManagerMock) Limit(amt uint64) (MemoryManager, error) { + return &memoryManagerMock{}, nil +} + +func (mm *memoryManagerMock) Status() api.MemoryStatus { return api.MemoryStatus{} } + +func (mm *memoryManagerMock) AcquireMemory(ctx context.Context, amt uint64) Memory { + if mm.memBlockChan != nil { + <-mm.memBlockChan + } + return &memoryMock{} +} + +var _ ObjectStore = (*objectStoreMock)(nil) + +type ( + objectStoreMock struct { + mu sync.Mutex + objects map[string]map[string]object.Object + partials map[string]packedSlabMock + bufferIDCntr uint // allows marking packed slabs as uploaded + } + + packedSlabMock struct { + parameterKey string // ([minshards]-[totalshards]-[contractset]) + bufferID uint + slabKey object.EncryptionKey + data []byte + } +) + +func newObjectStoreMock(bucket string) *objectStoreMock { + os := &objectStoreMock{ + objects: make(map[string]map[string]object.Object), + partials: make(map[string]packedSlabMock), + } + os.objects[bucket] = make(map[string]object.Object) + return os +} + +func (os *objectStoreMock) AddMultipartPart(ctx context.Context, bucket, path, contractSet, ETag, uploadID string, partNumber int, slices []object.SlabSlice) (err error) { return nil } -func (os *mockObjectStore) DeleteObject(ctx context.Context, bucket, path string, opts api.DeleteObjectOptions) error { +func (os *objectStoreMock) AddUploadingSector(ctx context.Context, uID api.UploadID, id types.FileContractID, root types.Hash256) error { return nil } -func (os *mockObjectStore) AddObject(ctx context.Context, bucket, path, contractSet string, o object.Object, opts api.AddObjectOptions) error { +func (os *objectStoreMock) TrackUpload(ctx context.Context, uID api.UploadID) error { return nil } + +func (os *objectStoreMock) FinishUpload(ctx context.Context, uID api.UploadID) error { return nil } + +func (os *objectStoreMock) DeleteHostSector(ctx context.Context, hk types.PublicKey, root types.Hash256) error { + return nil +} + +func (os *objectStoreMock) DeleteObject(ctx context.Context, bucket, path string, opts api.DeleteObjectOptions) error { + return nil +} + +func (os *objectStoreMock) AddObject(ctx context.Context, bucket, path, contractSet string, o object.Object, opts api.AddObjectOptions) error { os.mu.Lock() defer os.mu.Unlock() // check if the bucket exists if _, exists := os.objects[bucket]; !exists { - return errBucketNotFound + return api.ErrBucketNotFound } os.objects[bucket][path] = o return nil } -func (os *mockObjectStore) AddPartialSlab(ctx context.Context, data []byte, minShards, totalShards uint8, contractSet string) (slabs []object.SlabSlice, slabBufferMaxSizeSoftReached bool, err error) { +func (os *objectStoreMock) AddPartialSlab(ctx context.Context, data []byte, minShards, totalShards uint8, contractSet string) (slabs []object.SlabSlice, slabBufferMaxSizeSoftReached bool, err error) { os.mu.Lock() defer os.mu.Unlock() @@ -239,7 +415,7 @@ func (os *mockObjectStore) AddPartialSlab(ctx context.Context, data []byte, minS } // update store - os.partials[ec.String()] = mockPackedSlab{ + os.partials[ec.String()] = packedSlabMock{ parameterKey: fmt.Sprintf("%d-%d-%v", minShards, totalShards, contractSet), bufferID: os.bufferIDCntr, slabKey: ec, @@ -250,18 +426,18 @@ func (os *mockObjectStore) AddPartialSlab(ctx context.Context, data []byte, minS return []object.SlabSlice{ss}, false, nil } -func (os *mockObjectStore) Object(ctx context.Context, bucket, path string, opts api.GetObjectOptions) (api.ObjectsResponse, error) { +func (os *objectStoreMock) Object(ctx context.Context, bucket, path string, opts api.GetObjectOptions) (api.ObjectsResponse, error) { os.mu.Lock() defer os.mu.Unlock() // check if the bucket exists if _, exists := os.objects[bucket]; !exists { - return api.ObjectsResponse{}, errBucketNotFound + return api.ObjectsResponse{}, api.ErrBucketNotFound } // check if the object exists if _, exists := os.objects[bucket][path]; !exists { - return api.ObjectsResponse{}, errObjectNotFound + return api.ObjectsResponse{}, api.ErrObjectNotFound } // clone to ensure the store isn't unwillingly modified @@ -278,13 +454,13 @@ func (os *mockObjectStore) Object(ctx context.Context, bucket, path string, opts }}, nil } -func (os *mockObjectStore) FetchPartialSlab(ctx context.Context, key object.EncryptionKey, offset, length uint32) ([]byte, error) { +func (os *objectStoreMock) FetchPartialSlab(ctx context.Context, key object.EncryptionKey, offset, length uint32) ([]byte, error) { os.mu.Lock() defer os.mu.Unlock() packedSlab, exists := os.partials[key.String()] if !exists { - return nil, errSlabNotFound + return nil, api.ErrSlabNotFound } if offset+length > uint32(len(packedSlab.data)) { return nil, errors.New("offset out of bounds") @@ -293,7 +469,7 @@ func (os *mockObjectStore) FetchPartialSlab(ctx context.Context, key object.Encr return packedSlab.data[offset : offset+length], nil } -func (os *mockObjectStore) Slab(ctx context.Context, key object.EncryptionKey) (slab object.Slab, err error) { +func (os *objectStoreMock) Slab(ctx context.Context, key object.EncryptionKey) (slab object.Slab, err error) { os.mu.Lock() defer os.mu.Unlock() @@ -304,12 +480,12 @@ func (os *mockObjectStore) Slab(ctx context.Context, key object.EncryptionKey) ( return } } - err = errSlabNotFound + err = api.ErrSlabNotFound }) return } -func (os *mockObjectStore) UpdateSlab(ctx context.Context, s object.Slab, contractSet string) error { +func (os *objectStoreMock) UpdateSlab(ctx context.Context, s object.Slab, contractSet string) error { os.mu.Lock() defer os.mu.Unlock() @@ -325,7 +501,7 @@ func (os *mockObjectStore) UpdateSlab(ctx context.Context, s object.Slab, contra return nil } -func (os *mockObjectStore) PackedSlabsForUpload(ctx context.Context, lockingDuration time.Duration, minShards, totalShards uint8, set string, limit int) (pss []api.PackedSlab, _ error) { +func (os *objectStoreMock) PackedSlabsForUpload(ctx context.Context, lockingDuration time.Duration, minShards, totalShards uint8, set string, limit int) (pss []api.PackedSlab, _ error) { os.mu.Lock() defer os.mu.Unlock() @@ -342,7 +518,7 @@ func (os *mockObjectStore) PackedSlabsForUpload(ctx context.Context, lockingDura return } -func (os *mockObjectStore) MarkPackedSlabsUploaded(ctx context.Context, slabs []api.UploadedPackedSlab) error { +func (os *objectStoreMock) MarkPackedSlabsUploaded(ctx context.Context, slabs []api.UploadedPackedSlab) error { os.mu.Lock() defer os.mu.Unlock() @@ -367,7 +543,15 @@ func (os *mockObjectStore) MarkPackedSlabsUploaded(ctx context.Context, slabs [] return nil } -func (os *mockObjectStore) forEachObject(fn func(bucket, path string, o object.Object)) { +func (os *objectStoreMock) Bucket(_ context.Context, bucket string) (api.Bucket, error) { + return api.Bucket{}, nil +} + +func (os *objectStoreMock) MultipartUpload(ctx context.Context, uploadID string) (resp api.MultipartUpload, err error) { + return api.MultipartUpload{}, nil +} + +func (os *objectStoreMock) forEachObject(fn func(bucket, path string, o object.Object)) { for bucket, objects := range os.objects { for path, object := range objects { fn(bucket, path, object) @@ -375,220 +559,58 @@ func (os *mockObjectStore) forEachObject(fn func(bucket, path string, o object.O } } -func newMockHost(hk types.PublicKey) *mockHost { - return &mockHost{ - hk: hk, - hpt: newTestHostPriceTable(time.Now().Add(time.Minute)), - } -} +var _ SettingStore = (*settingStoreMock)(nil) -func (h *mockHost) PublicKey() types.PublicKey { return h.hk } +type settingStoreMock struct{} -func (h *mockHost) DownloadSector(ctx context.Context, w io.Writer, root types.Hash256, offset, length uint32, overpay bool) error { - sector, exist := h.contract().sector(root) - if !exist { - return errSectorNotFound - } - if offset+length > rhpv2.SectorSize { - return errSectorOutOfBounds - } - _, err := w.Write(sector[offset : offset+length]) - return err +func (*settingStoreMock) GougingParams(context.Context) (api.GougingParams, error) { + return api.GougingParams{}, nil } -func (h *mockHost) UploadSector(ctx context.Context, sector *[rhpv2.SectorSize]byte, rev types.FileContractRevision) (types.Hash256, error) { - return h.contract().addSector(sector), nil +func (*settingStoreMock) UploadParams(context.Context) (api.UploadParams, error) { + return api.UploadParams{}, nil } -func (h *mockHost) FetchRevision(ctx context.Context, fetchTimeout time.Duration) (rev types.FileContractRevision, _ error) { - h.mu.Lock() - defer h.mu.Unlock() - rev = h.c.rev - return -} +var _ Syncer = (*syncerMock)(nil) -func (h *mockHost) FetchPriceTable(ctx context.Context, rev *types.FileContractRevision) (hostdb.HostPriceTable, error) { - <-h.hptBlockChan - return h.hpt, nil -} +type syncerMock struct{} -func (h *mockHost) FundAccount(ctx context.Context, balance types.Currency, rev *types.FileContractRevision) error { +func (*syncerMock) BroadcastTransaction(context.Context, []types.Transaction) error { return nil } -func (h *mockHost) RenewContract(ctx context.Context, rrr api.RHPRenewRequest) (_ rhpv2.ContractRevision, _ []types.Transaction, _ types.Currency, err error) { - return rhpv2.ContractRevision{}, nil, types.ZeroCurrency, nil -} - -func (h *mockHost) SyncAccount(ctx context.Context, rev *types.FileContractRevision) error { - return nil +func (*syncerMock) SyncerPeers(context.Context) ([]string, error) { + return nil, nil } -func (h *mockHost) contract() (c *mockContract) { - h.mu.Lock() - c = h.c - h.mu.Unlock() +var _ Wallet = (*walletMock)(nil) - if c == nil { - panic("host does not have a contract") - } - return -} +type walletMock struct{} -func newMockContract(hk types.PublicKey, fcid types.FileContractID) *mockContract { - return &mockContract{ - metadata: api.ContractMetadata{ - ID: fcid, - HostKey: hk, - WindowStart: 0, - WindowEnd: 10, - }, - rev: types.FileContractRevision{ParentID: fcid}, - sectors: make(map[types.Hash256]*[rhpv2.SectorSize]byte), - } -} - -func (c *mockContract) addSector(sector *[rhpv2.SectorSize]byte) (root types.Hash256) { - root = rhpv2.SectorRoot(sector) - c.mu.Lock() - c.sectors[root] = sector - c.mu.Unlock() - return -} - -func (c *mockContract) sector(root types.Hash256) (sector *[rhpv2.SectorSize]byte, found bool) { - c.mu.Lock() - sector, found = c.sectors[root] - c.mu.Unlock() - return -} - -func newMockHostManager() *mockHostManager { - return &mockHostManager{ - hosts: make(map[types.PublicKey]*mockHost), - } -} - -func (hm *mockHostManager) Host(hk types.PublicKey, fcid types.FileContractID, siamuxAddr string) Host { - hm.mu.Lock() - defer hm.mu.Unlock() - - if _, ok := hm.hosts[hk]; !ok { - panic("host not found") - } - return hm.hosts[hk] -} - -func (hm *mockHostManager) newHost(hk types.PublicKey) *mockHost { - hm.mu.Lock() - defer hm.mu.Unlock() - - if _, ok := hm.hosts[hk]; ok { - panic("host already exists") - } - - hm.hosts[hk] = newMockHost(hk) - return hm.hosts[hk] -} - -func (hm *mockHostManager) host(hk types.PublicKey) *mockHost { - hm.mu.Lock() - defer hm.mu.Unlock() - return hm.hosts[hk] -} - -func newMockSector() (*[rhpv2.SectorSize]byte, types.Hash256) { - var sector [rhpv2.SectorSize]byte - frand.Read(sector[:]) - return §or, rhpv2.SectorRoot(§or) -} - -func newMockWorker() *mockWorker { - cs := newMockContractStore() - hm := newMockHostManager() - os := newMockObjectStore() - mm := &mockMemoryManager{} - - return &mockWorker{ - cs: cs, - hm: hm, - mm: mm, - os: os, - - dl: newDownloadManager(context.Background(), hm, mm, os, 0, 0, zap.NewNop().Sugar()), - ul: newUploadManager(context.Background(), hm, mm, os, cs, 0, 0, time.Minute, zap.NewNop().Sugar()), - } -} - -func (w *mockWorker) addHosts(n int) { - for i := 0; i < n; i++ { - w.addHost() - } -} - -func (w *mockWorker) addHost() *mockHost { - host := w.hm.newHost(w.newHostKey()) - w.formContract(host) - return host +func (*walletMock) WalletDiscard(context.Context, types.Transaction) error { + return nil } -func (w *mockWorker) formContract(host *mockHost) *mockContract { - if host.c != nil { - panic("host already has contract, use renew") - } - host.c = newMockContract(host.hk, w.newFileContractID()) - w.cs.addContract(host.c) - return host.c +func (*walletMock) WalletFund(context.Context, *types.Transaction, types.Currency, bool) ([]types.Hash256, []types.Transaction, error) { + return nil, nil, nil } -func (w *mockWorker) renewContract(hk types.PublicKey) *mockContract { - host := w.hm.host(hk) - if host == nil { - panic("host not found") - } else if host.c == nil { - panic("host does not have a contract to renew") - } - - curr := host.c.metadata - update := newMockContract(host.hk, w.newFileContractID()) - update.metadata.RenewedFrom = curr.ID - update.metadata.WindowStart = curr.WindowEnd - update.metadata.WindowEnd = update.metadata.WindowStart + (curr.WindowEnd - curr.WindowStart) - host.c = update - - w.cs.addContract(host.c) - return host.c +func (*walletMock) WalletPrepareForm(context.Context, types.Address, types.PublicKey, types.Currency, types.Currency, types.PublicKey, rhpv2.HostSettings, uint64) ([]types.Transaction, error) { + return nil, nil } -func (w *mockWorker) contracts() (metadatas []api.ContractMetadata) { - for _, h := range w.hm.hosts { - metadatas = append(metadatas, h.c.metadata) - } - return +func (*walletMock) WalletPrepareRenew(context.Context, types.FileContractRevision, types.Address, types.Address, types.PrivateKey, types.Currency, types.Currency, rhpv3.HostPriceTable, uint64, uint64, uint64) (api.WalletPrepareRenewResponse, error) { + return api.WalletPrepareRenewResponse{}, nil } -func (w *mockWorker) newHostKey() (hk types.PublicKey) { - w.mu.Lock() - defer w.mu.Unlock() - w.hkCntr++ - hk = types.PublicKey{byte(w.hkCntr)} - return +func (*walletMock) WalletSign(context.Context, *types.Transaction, []types.Hash256, types.CoveredFields) error { + return nil } -func (w *mockWorker) newFileContractID() (fcid types.FileContractID) { - w.mu.Lock() - defer w.mu.Unlock() - w.fcidCntr++ - fcid = types.FileContractID{byte(w.fcidCntr)} - return -} +var _ webhooks.Broadcaster = (*webhookBroadcasterMock)(nil) -func newTestHostPriceTable(expiry time.Time) hostdb.HostPriceTable { - var uid rhpv3.SettingsID - frand.Read(uid[:]) +type webhookBroadcasterMock struct{} - return hostdb.HostPriceTable{ - HostPriceTable: rhpv3.HostPriceTable{UID: uid, Validity: time.Minute}, - Expiry: expiry, - } +func (*webhookBroadcasterMock) BroadcastAction(context.Context, webhooks.Event) error { + return nil } diff --git a/worker/pricetables_test.go b/worker/pricetables_test.go index 115abcd31..b053fa40f 100644 --- a/worker/pricetables_test.go +++ b/worker/pricetables_test.go @@ -3,90 +3,53 @@ package worker import ( "context" "errors" - "sync" "testing" "time" - "go.sia.tech/core/types" "go.sia.tech/renterd/hostdb" ) -var ( - errHostNotFound = errors.New("host not found") -) - -var ( - _ HostStore = (*mockHostStore)(nil) -) - -type mockHostStore struct { - mu sync.Mutex - hosts map[types.PublicKey]hostdb.HostInfo -} - -func (mhs *mockHostStore) Host(ctx context.Context, hostKey types.PublicKey) (hostdb.HostInfo, error) { - mhs.mu.Lock() - defer mhs.mu.Unlock() - - h, ok := mhs.hosts[hostKey] - if !ok { - return hostdb.HostInfo{}, errHostNotFound - } - return h, nil -} - -func newMockHostStore(hosts []*hostdb.HostInfo) *mockHostStore { - hs := &mockHostStore{hosts: make(map[types.PublicKey]hostdb.HostInfo)} - for _, h := range hosts { - hs.hosts[h.PublicKey] = *h - } - return hs -} - func TestPriceTables(t *testing.T) { - // create two price tables, a valid one and one that expired - expiredPT := newTestHostPriceTable(time.Now()) - validPT := newTestHostPriceTable(time.Now().Add(time.Minute)) - - // create host manager - hm := newMockHostManager() + // create host & contract stores + hs := newHostStoreMock() + cs := newContractStoreMock() - // create a mock host that has a valid price table - hk1 := types.PublicKey{1} - h1 := hm.newHost(hk1) - h1.hpt = validPT + // create host manager & price table + hm := newTestHostManager(t) + pts := newPriceTables(hm, hs) - // create a hostdb entry for that host that returns the expired price table - hdb1 := &hostdb.HostInfo{ - Host: hostdb.Host{ - PublicKey: hk1, - PriceTable: expiredPT, - Scanned: true, - }, - } + // create host & contract mock + h := hs.addHost() + c := cs.addContract(h.hk) - // create host store - hs := newMockHostStore([]*hostdb.HostInfo{hdb1}) + // expire its price table + h.hi.PriceTable = newTestHostPriceTable(time.Now()) - // create price tables - pts := newPriceTables(hm, hs) + // manage the host, make sure fetching the price table blocks + fetchPTBlockChan := make(chan struct{}) + validPT := newTestHostPriceTable(time.Now().Add(time.Minute)) + hm.addHost(newTestHostCustom(h, c, func() hostdb.HostPriceTable { + <-fetchPTBlockChan + return validPT + })) - // fetch the price table in a goroutine, make it blocking - h1.hptBlockChan = make(chan struct{}) - go pts.fetch(context.Background(), hk1, nil) + // trigger a fetch to make it block + go pts.fetch(context.Background(), h.hk, nil) time.Sleep(50 * time.Millisecond) - // fetch it again but with a canceled context to avoid blocking indefinitely, the error will indicate we were blocking on a price table update + // fetch it again but with a canceled context to avoid blocking + // indefinitely, the error will indicate we were blocking on a price table + // update ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err := pts.fetch(ctx, hk1, nil) + _, err := pts.fetch(ctx, h.hk, nil) if !errors.Is(err, errPriceTableUpdateTimedOut) { t.Fatal("expected errPriceTableUpdateTimedOut, got", err) } // unblock and assert we receive a valid price table - close(h1.hptBlockChan) - update, err := pts.fetch(context.Background(), hk1, nil) + close(fetchPTBlockChan) + update, err := pts.fetch(context.Background(), h.hk, nil) if err != nil { t.Fatal(err) } else if update.UID != validPT.UID { @@ -95,8 +58,9 @@ func TestPriceTables(t *testing.T) { // refresh the price table on the host, update again, assert we receive the // same price table as it hasn't expired yet - h1.hpt = newTestHostPriceTable(time.Now().Add(time.Minute)) - update, err = pts.fetch(context.Background(), hk1, nil) + refreshedPT := newTestHostPriceTable(time.Now().Add(time.Minute)) + h.hi.PriceTable = refreshedPT + update, err = pts.fetch(context.Background(), h.hk, nil) if err != nil { t.Fatal(err) } else if update.UID != validPT.UID { @@ -104,13 +68,13 @@ func TestPriceTables(t *testing.T) { } // manually expire the price table - pts.priceTables[hk1].hpt.Expiry = time.Now() + pts.priceTables[h.hk].hpt.Expiry = time.Now() // fetch it again and assert we updated the price table - update, err = pts.fetch(context.Background(), hk1, nil) + update, err = pts.fetch(context.Background(), h.hk, nil) if err != nil { t.Fatal(err) - } else if update.UID != h1.hpt.UID { + } else if update.UID != refreshedPT.UID { t.Fatal("price table mismatch") } } diff --git a/worker/upload.go b/worker/upload.go index 72c65bf07..6048661e3 100644 --- a/worker/upload.go +++ b/worker/upload.go @@ -39,6 +39,7 @@ type ( hm HostManager mm MemoryManager os ObjectStore + cl ContractLocker cs ContractStore logger *zap.SugaredLogger @@ -148,8 +149,8 @@ func (w *worker) initUploadManager(maxMemory, maxOverdrive uint64, overdriveTime panic("upload manager already initialized") // developer error } - mm := newMemoryManager(logger, maxMemory) - w.uploadManager = newUploadManager(w.shutdownCtx, w, mm, w.bus, w.bus, maxOverdrive, overdriveTimeout, w.contractLockingDuration, logger) + mm := newMemoryManager(logger.Named("memorymanager"), maxMemory) + w.uploadManager = newUploadManager(w.shutdownCtx, w, mm, w.bus, w.bus, w.bus, maxOverdrive, overdriveTimeout, w.contractLockingDuration, logger) } func (w *worker) upload(ctx context.Context, r io.Reader, contracts []api.ContractMetadata, up uploadParameters, opts ...UploadOption) (_ string, err error) { @@ -314,11 +315,12 @@ func (w *worker) uploadPackedSlab(ctx context.Context, rs api.RedundancySettings return nil } -func newUploadManager(ctx context.Context, hm HostManager, mm MemoryManager, os ObjectStore, cs ContractStore, maxOverdrive uint64, overdriveTimeout time.Duration, contractLockDuration time.Duration, logger *zap.SugaredLogger) *uploadManager { +func newUploadManager(ctx context.Context, hm HostManager, mm MemoryManager, os ObjectStore, cl ContractLocker, cs ContractStore, maxOverdrive uint64, overdriveTimeout time.Duration, contractLockDuration time.Duration, logger *zap.SugaredLogger) *uploadManager { return &uploadManager{ hm: hm, mm: mm, os: os, + cl: cl, cs: cs, logger: logger, @@ -336,9 +338,10 @@ func newUploadManager(ctx context.Context, hm HostManager, mm MemoryManager, os } } -func (mgr *uploadManager) newUploader(os ObjectStore, cs ContractStore, hm HostManager, c api.ContractMetadata) *uploader { +func (mgr *uploadManager) newUploader(os ObjectStore, cl ContractLocker, cs ContractStore, hm HostManager, c api.ContractMetadata) *uploader { return &uploader{ os: os, + cl: cl, cs: cs, hm: hm, logger: mgr.logger, @@ -751,7 +754,7 @@ func (mgr *uploadManager) refreshUploaders(contracts []api.ContractMetadata, bh // add missing uploaders for _, c := range contracts { if _, exists := existing[c.ID]; !exists && bh < c.WindowEnd { - uploader := mgr.newUploader(mgr.os, mgr.cs, mgr.hm, c) + uploader := mgr.newUploader(mgr.os, mgr.cl, mgr.cs, mgr.hm, c) refreshed = append(refreshed, uploader) go uploader.Start() } diff --git a/worker/upload_test.go b/worker/upload_test.go index 0b9f6b28b..efd8247b5 100644 --- a/worker/upload_test.go +++ b/worker/upload_test.go @@ -14,26 +14,22 @@ import ( "lukechampine.com/frand" ) -const ( - testBucket = "testbucket" - testContractSet = "testcontractset" -) - var ( + testContractSet = "testcontractset" testRedundancySettings = api.RedundancySettings{MinShards: 2, TotalShards: 6} ) func TestUpload(t *testing.T) { - // mock worker - w := newMockWorker() + // create test worker + w := newTestWorker(t) // add hosts to worker w.addHosts(testRedundancySettings.TotalShards * 2) // convenience variables os := w.os - dl := w.dl - ul := w.ul + dl := w.downloadManager + ul := w.uploadManager // create test data data := make([]byte, 128) @@ -115,7 +111,7 @@ func TestUpload(t *testing.T) { // try and upload into a bucket that does not exist params.bucket = "doesnotexist" _, _, err = ul.Upload(context.Background(), bytes.NewReader(data), w.contracts(), params, lockingPriorityUpload) - if !errors.Is(err, errBucketNotFound) { + if !errors.Is(err, api.ErrBucketNotFound) { t.Fatal("expected bucket not found error", err) } @@ -129,17 +125,17 @@ func TestUpload(t *testing.T) { } func TestUploadPackedSlab(t *testing.T) { - // mock worker - w := newMockWorker() + // create test worker + w := newTestWorker(t) // add hosts to worker w.addHosts(testRedundancySettings.TotalShards * 2) // convenience variables os := w.os - mm := w.mm - dl := w.dl - ul := w.ul + mm := w.ulmm + dl := w.downloadManager + ul := w.uploadManager // create test data data := make([]byte, 128) @@ -215,17 +211,17 @@ func TestUploadPackedSlab(t *testing.T) { } func TestUploadShards(t *testing.T) { - // mock worker - w := newMockWorker() + // create test worker + w := newTestWorker(t) // add hosts to worker w.addHosts(testRedundancySettings.TotalShards * 2) // convenience variables os := w.os - mm := w.mm - dl := w.dl - ul := w.ul + mm := w.ulmm + dl := w.downloadManager + ul := w.uploadManager // create test data data := make([]byte, 128) @@ -334,16 +330,16 @@ func TestUploadShards(t *testing.T) { } func TestRefreshUploaders(t *testing.T) { - // mock worker - w := newMockWorker() + // create test worker + w := newTestWorker(t) // add hosts to worker w.addHosts(testRedundancySettings.TotalShards) // convenience variables - ul := w.ul - hm := w.hm + ul := w.uploadManager cs := w.cs + hm := w.hm // create test data data := make([]byte, 128) @@ -356,7 +352,7 @@ func TestRefreshUploaders(t *testing.T) { // upload data contracts := w.contracts() - _, _, err := ul.Upload(context.Background(), bytes.NewReader(data), contracts, params, lockingPriorityUpload) + _, err := w.upload(context.Background(), bytes.NewReader(data), contracts, params) if err != nil { t.Fatal(err) } @@ -373,7 +369,7 @@ func TestRefreshUploaders(t *testing.T) { // remove the host from the second contract c2 := contracts[1] delete(hm.hosts, c2.HostKey) - delete(cs.locks, c2.ID) + delete(cs.contracts, c2.ID) // add a new host/contract hNew := w.addHost() @@ -389,7 +385,7 @@ func TestRefreshUploaders(t *testing.T) { var added, renewed int for _, ul := range ul.uploaders { switch ul.ContractID() { - case hNew.c.metadata.ID: + case hNew.metadata.ID: added++ case c1Renewed.metadata.ID: renewed++ @@ -410,7 +406,7 @@ func TestRefreshUploaders(t *testing.T) { // manually add a request to the queue of one of the uploaders we're about to expire responseChan := make(chan sectorUploadResp, 1) for _, ul := range ul.uploaders { - if ul.fcid == hNew.c.metadata.ID { + if ul.fcid == hNew.metadata.ID { ul.mu.Lock() ul.queue = append(ul.queue, §orUploadReq{responseChan: responseChan, sector: §orUpload{ctx: context.Background()}}) ul.mu.Unlock() @@ -436,17 +432,17 @@ func TestRefreshUploaders(t *testing.T) { } func TestUploadRegression(t *testing.T) { - // mock worker - w := newMockWorker() + // create test worker + w := newTestWorker(t) // add hosts to worker w.addHosts(testRedundancySettings.TotalShards) // convenience variables - mm := w.mm os := w.os - ul := w.ul - dl := w.dl + mm := w.ulmm + ul := w.uploadManager + dl := w.downloadManager // create test data data := make([]byte, 128) diff --git a/worker/uploader.go b/worker/uploader.go index dcff27eaf..3791f8b27 100644 --- a/worker/uploader.go +++ b/worker/uploader.go @@ -27,6 +27,7 @@ type ( uploader struct { os ObjectStore cs ContractStore + cl ContractLocker hm HostManager logger *zap.SugaredLogger @@ -200,13 +201,13 @@ func (u *uploader) execute(req *sectorUploadReq) (types.Hash256, time.Duration, u.mu.Unlock() // acquire contract lock - lockID, err := u.cs.AcquireContract(req.sector.ctx, fcid, req.contractLockPriority, req.contractLockDuration) + lockID, err := u.cl.AcquireContract(req.sector.ctx, fcid, req.contractLockPriority, req.contractLockDuration) if err != nil { return types.Hash256{}, 0, err } // defer the release - lock := newContractLock(u.shutdownCtx, fcid, lockID, req.contractLockDuration, u.cs, u.logger) + lock := newContractLock(u.shutdownCtx, fcid, lockID, req.contractLockDuration, u.cl, u.logger) defer func() { ctx, cancel := context.WithTimeout(u.shutdownCtx, 10*time.Second) lock.Release(ctx) diff --git a/worker/uploader_test.go b/worker/uploader_test.go index 7217cbaab..514d17aab 100644 --- a/worker/uploader_test.go +++ b/worker/uploader_test.go @@ -8,11 +8,13 @@ import ( ) func TestUploaderStopped(t *testing.T) { - w := newMockWorker() - w.addHost() - w.ul.refreshUploaders(w.contracts(), 1) + w := newTestWorker(t) + w.addHosts(1) - ul := w.ul.uploaders[0] + um := w.uploadManager + um.refreshUploaders(w.contracts(), 1) + + ul := um.uploaders[0] ul.Stop(errors.New("test")) req := sectorUploadReq{ diff --git a/worker/worker.go b/worker/worker.go index 17faca7cb..b2d0fdc9d 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -69,44 +69,21 @@ func NewClient(address, password string) *Client { type ( Bus interface { alerts.Alerter - consensusState + ConsensusState webhooks.Broadcaster AccountStore + ContractLocker ContractStore + HostStore ObjectStore + SettingStore - BroadcastTransaction(ctx context.Context, txns []types.Transaction) error - SyncerPeers(ctx context.Context) (resp []string, err error) - - Contract(ctx context.Context, id types.FileContractID) (api.ContractMetadata, error) - ContractSize(ctx context.Context, id types.FileContractID) (api.ContractSize, error) - ContractRoots(ctx context.Context, id types.FileContractID) ([]types.Hash256, []types.Hash256, error) - Contracts(ctx context.Context, opts api.ContractsOpts) ([]api.ContractMetadata, error) - - RecordHostScans(ctx context.Context, scans []hostdb.HostScan) error - RecordPriceTables(ctx context.Context, priceTableUpdate []hostdb.PriceTableUpdate) error - RecordContractSpending(ctx context.Context, records []api.ContractSpendingRecord) error - - Host(ctx context.Context, hostKey types.PublicKey) (hostdb.HostInfo, error) - - GougingParams(ctx context.Context) (api.GougingParams, error) - UploadParams(ctx context.Context) (api.UploadParams, error) - - Object(ctx context.Context, bucket, path string, opts api.GetObjectOptions) (api.ObjectsResponse, error) - DeleteObject(ctx context.Context, bucket, path string, opts api.DeleteObjectOptions) error - MultipartUpload(ctx context.Context, uploadID string) (resp api.MultipartUpload, err error) - PackedSlabsForUpload(ctx context.Context, lockingDuration time.Duration, minShards, totalShards uint8, set string, limit int) ([]api.PackedSlab, error) - - WalletDiscard(ctx context.Context, txn types.Transaction) error - WalletFund(ctx context.Context, txn *types.Transaction, amount types.Currency, useUnconfirmedTxns bool) ([]types.Hash256, []types.Transaction, error) - WalletPrepareForm(ctx context.Context, renterAddress types.Address, renterKey types.PublicKey, renterFunds, hostCollateral types.Currency, hostKey types.PublicKey, hostSettings rhpv2.HostSettings, endHeight uint64) (txns []types.Transaction, err error) - WalletPrepareRenew(ctx context.Context, revision types.FileContractRevision, hostAddress, renterAddress types.Address, renterKey types.PrivateKey, renterFunds, minNewCollateral types.Currency, pt rhpv3.HostPriceTable, endHeight, windowSize, expectedStorage uint64) (api.WalletPrepareRenewResponse, error) - WalletSign(ctx context.Context, txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) error - - Bucket(_ context.Context, bucket string) (api.Bucket, error) + Syncer + Wallet } + // An AccountStore manages ephemaral accounts state. AccountStore interface { Accounts(ctx context.Context) ([]api.Account, error) AddBalance(ctx context.Context, id rhpv3.Account, hk types.PublicKey, amt *big.Int) error @@ -120,11 +97,21 @@ type ( } ContractStore interface { - ContractLocker - + Contract(ctx context.Context, id types.FileContractID) (api.ContractMetadata, error) + ContractSize(ctx context.Context, id types.FileContractID) (api.ContractSize, error) + ContractRoots(ctx context.Context, id types.FileContractID) ([]types.Hash256, []types.Hash256, error) + Contracts(ctx context.Context, opts api.ContractsOpts) ([]api.ContractMetadata, error) RenewedContract(ctx context.Context, renewedFrom types.FileContractID) (api.ContractMetadata, error) } + HostStore interface { + RecordHostScans(ctx context.Context, scans []hostdb.HostScan) error + RecordPriceTables(ctx context.Context, priceTableUpdate []hostdb.PriceTableUpdate) error + RecordContractSpending(ctx context.Context, records []api.ContractSpendingRecord) error + + Host(ctx context.Context, hostKey types.PublicKey) (hostdb.HostInfo, error) + } + ObjectStore interface { // NOTE: used for download DeleteHostSector(ctx context.Context, hk types.PublicKey, root types.Hash256) error @@ -140,9 +127,34 @@ type ( MarkPackedSlabsUploaded(ctx context.Context, slabs []api.UploadedPackedSlab) error TrackUpload(ctx context.Context, uID api.UploadID) error UpdateSlab(ctx context.Context, s object.Slab, contractSet string) error + + // NOTE: used by worker + Bucket(_ context.Context, bucket string) (api.Bucket, error) + Object(ctx context.Context, bucket, path string, opts api.GetObjectOptions) (api.ObjectsResponse, error) + DeleteObject(ctx context.Context, bucket, path string, opts api.DeleteObjectOptions) error + MultipartUpload(ctx context.Context, uploadID string) (resp api.MultipartUpload, err error) + PackedSlabsForUpload(ctx context.Context, lockingDuration time.Duration, minShards, totalShards uint8, set string, limit int) ([]api.PackedSlab, error) } - consensusState interface { + SettingStore interface { + GougingParams(ctx context.Context) (api.GougingParams, error) + UploadParams(ctx context.Context) (api.UploadParams, error) + } + + Syncer interface { + BroadcastTransaction(ctx context.Context, txns []types.Transaction) error + SyncerPeers(ctx context.Context) (resp []string, err error) + } + + Wallet interface { + WalletDiscard(ctx context.Context, txn types.Transaction) error + WalletFund(ctx context.Context, txn *types.Transaction, amount types.Currency, useUnconfirmedTxns bool) ([]types.Hash256, []types.Transaction, error) + WalletPrepareForm(ctx context.Context, renterAddress types.Address, renterKey types.PublicKey, renterFunds, hostCollateral types.Currency, hostKey types.PublicKey, hostSettings rhpv2.HostSettings, endHeight uint64) (txns []types.Transaction, err error) + WalletPrepareRenew(ctx context.Context, revision types.FileContractRevision, hostAddress, renterAddress types.Address, renterKey types.PrivateKey, renterFunds, minNewCollateral types.Currency, pt rhpv3.HostPriceTable, endHeight, windowSize, expectedStorage uint64) (api.WalletPrepareRenewResponse, error) + WalletSign(ctx context.Context, txn *types.Transaction, toSign []types.Hash256, cf types.CoveredFields) error + } + + ConsensusState interface { ConsensusState(ctx context.Context) (api.ConsensusState, error) } ) @@ -183,7 +195,9 @@ func (w *worker) deriveRenterKey(hostKey types.PublicKey) types.PrivateKey { // A worker talks to Sia hosts to perform contract and storage operations within // a renterd system. type worker struct { - alerts alerts.Alerter + alerts alerts.Alerter + contracts ContractStore + allowPrivateIPs bool id string bus Bus @@ -1274,7 +1288,7 @@ func (w *worker) stateHandlerGET(jc jape.Context) { } // New returns an HTTP handler that serves the worker API. -func New(masterKey [32]byte, id string, b Bus, contractLockingDuration, busFlushInterval, downloadOverdriveTimeout, uploadOverdriveTimeout time.Duration, downloadMaxOverdrive, downloadMaxMemory, uploadMaxMemory, uploadMaxOverdrive uint64, allowPrivateIPs bool, l *zap.Logger) (*worker, error) { +func New(masterKey [32]byte, id string, b Bus, contractLockingDuration, busFlushInterval, downloadOverdriveTimeout, uploadOverdriveTimeout time.Duration, downloadMaxOverdrive, uploadMaxOverdrive, downloadMaxMemory, uploadMaxMemory uint64, allowPrivateIPs bool, l *zap.Logger) (*worker, error) { if contractLockingDuration == 0 { return nil, errors.New("contract lock duration must be positive") } @@ -1294,6 +1308,7 @@ func New(masterKey [32]byte, id string, b Bus, contractLockingDuration, busFlush return nil, errors.New("uploadMaxMemory cannot be 0") } + l = l.Named("worker").Named(id) ctx, cancel := context.WithCancel(context.Background()) w := &worker{ alerts: alerts.WithOrigin(b, fmt.Sprintf("worker.%s", id)), @@ -1302,7 +1317,7 @@ func New(masterKey [32]byte, id string, b Bus, contractLockingDuration, busFlush id: id, bus: b, masterKey: masterKey, - logger: l.Sugar().Named("worker").Named(id), + logger: l.Sugar(), startTime: time.Now(), uploadingPackedSlabs: make(map[string]bool), shutdownCtx: ctx, @@ -1313,8 +1328,8 @@ func New(masterKey [32]byte, id string, b Bus, contractLockingDuration, busFlush w.initPriceTables() w.initTransportPool() - w.initDownloadManager(downloadMaxMemory, downloadMaxOverdrive, downloadOverdriveTimeout, l.Sugar().Named("downloadmanager")) - w.initUploadManager(uploadMaxMemory, uploadMaxOverdrive, uploadOverdriveTimeout, l.Sugar().Named("uploadmanager")) + w.initDownloadManager(downloadMaxMemory, downloadMaxOverdrive, downloadOverdriveTimeout, l.Named("downloadmanager").Sugar()) + w.initUploadManager(uploadMaxMemory, uploadMaxOverdrive, uploadOverdriveTimeout, l.Named("uploadmanager").Sugar()) w.initContractSpendingRecorder(busFlushInterval) return w, nil diff --git a/worker/worker_test.go b/worker/worker_test.go new file mode 100644 index 000000000..ed510b71c --- /dev/null +++ b/worker/worker_test.go @@ -0,0 +1,114 @@ +package worker + +import ( + "context" + "testing" + "time" + + rhpv2 "go.sia.tech/core/rhp/v2" + "go.sia.tech/core/types" + "go.sia.tech/renterd/api" + "go.uber.org/zap" + "golang.org/x/crypto/blake2b" + "lukechampine.com/frand" +) + +// TODO: the fact we have to override the host and memory managers after +// initialising the worker shows our dependency injection isn't quite right + +type ( + testWorker struct { + t *testing.T + *worker + + cs *contractStoreMock + os *objectStoreMock + hs *hostStoreMock + + dlmm *memoryManagerMock + ulmm *memoryManagerMock + + hm *testHostManager + } +) + +const testBucket = "testbucket" + +func newTestWorker(t *testing.T) *testWorker { + // create bus dependencies + cs := newContractStoreMock() + os := newObjectStoreMock(testBucket) + hs := newHostStoreMock() + + // create worker dependencies + b := newBusMock(cs, hs, os) + dlmm := &memoryManagerMock{} + ulmm := &memoryManagerMock{} + + // create worker + w, err := New(blake2b.Sum256([]byte("testwork")), "test", b, time.Second, time.Second, time.Second, time.Second, 0, 0, 1, 1, false, zap.NewNop()) + if err != nil { + t.Fatal(err) + } + + // override managers + hm := newTestHostManager(t) + w.priceTables.hm = hm + w.downloadManager.hm = hm + w.downloadManager.mm = dlmm + w.uploadManager.hm = hm + w.uploadManager.mm = ulmm + + return &testWorker{ + t, + w, + cs, + os, + hs, + ulmm, + dlmm, + hm, + } +} + +func (w *testWorker) addHosts(n int) (added []*testHost) { + for i := 0; i < n; i++ { + added = append(added, w.addHost()) + } + return +} + +func (w *testWorker) addHost() *testHost { + h := w.hs.addHost() + c := w.cs.addContract(h.hk) + host := newTestHost(h, c) + w.hm.addHost(host) + return host +} + +func (w *testWorker) contracts() []api.ContractMetadata { + metadatas, err := w.cs.Contracts(context.Background(), api.ContractsOpts{}) + if err != nil { + w.t.Fatal(err) + } + return metadatas +} + +func (w *testWorker) renewContract(hk types.PublicKey) *contractMock { + h := w.hm.hosts[hk] + if h == nil { + w.t.Fatal("host not found") + } + + renewal, err := w.cs.renewContract(hk) + if err != nil { + w.t.Fatal(err) + } + return renewal +} + +func newTestSector() (*[rhpv2.SectorSize]byte, types.Hash256) { + var sector [rhpv2.SectorSize]byte + frand.Read(sector[:]) + return §or, rhpv2.SectorRoot(§or) +}