diff --git a/cmd/db-integrity-check/main.go b/cmd/db-integrity-check/main.go deleted file mode 100644 index 4b74aee88..000000000 --- a/cmd/db-integrity-check/main.go +++ /dev/null @@ -1,32 +0,0 @@ -// +build !js - -// package db-integrity-check is an executable that can be used to check -// the integrity of the database used internally by 0x Mesh. -package main - -import ( - "log" - - "github.com/0xProject/0x-mesh/db" - "github.com/plaid/go-envvar/envvar" -) - -type envVars struct { - // DatabaseDir is the directory where the database files are persisted. - DatabaseDir string `envvar:"DATABASE_DIR" default:"0x_mesh/db"` -} - -func main() { - env := envVars{} - if err := envvar.Parse(&env); err != nil { - log.Fatal(err) - } - database, err := db.Open(env.DatabaseDir) - if err != nil { - log.Fatal(err) - } - if err := database.CheckIntegrity(); err != nil { - log.Fatal(err) - } - log.Print("Integrity check passed ✓") -} diff --git a/cmd/mesh/main.go b/cmd/mesh/main.go index c6f1e0f7d..afb68dc39 100644 --- a/cmd/mesh/main.go +++ b/cmd/mesh/main.go @@ -38,13 +38,13 @@ func main() { log.WithField("error", err.Error()).Fatal("could not parse environment variables") } - // Start core.App. - app, err := core.New(coreConfig) + // Initialize core.App. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + app, err := core.New(ctx, coreConfig) if err != nil { log.WithField("error", err.Error()).Fatal("could not initialize app") } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() // Below, we will start several independent goroutines. 
We use separate // channels to communicate errors and a waitgroup to wait for all goroutines @@ -55,7 +55,7 @@ func main() { wg.Add(1) go func() { defer wg.Done() - if err := app.Start(ctx); err != nil { + if err := app.Start(); err != nil { coreErrChan <- err } }() diff --git a/cmd/mesh/rpc_handler.go b/cmd/mesh/rpc_handler.go index 60a3e4e76..66ec4ce9f 100644 --- a/cmd/mesh/rpc_handler.go +++ b/cmd/mesh/rpc_handler.go @@ -18,6 +18,7 @@ import ( "github.com/0xProject/0x-mesh/rpc" "github.com/0xProject/0x-mesh/zeroex" "github.com/0xProject/0x-mesh/zeroex/ordervalidator" + "github.com/ethereum/go-ethereum/common" ethrpc "github.com/ethereum/go-ethereum/rpc" peerstore "github.com/libp2p/go-libp2p-peerstore" log "github.com/sirupsen/logrus" @@ -60,11 +61,10 @@ func instantiateServer(ctx context.Context, app *core.App, rpcAddr string) *rpc. } // GetOrders is called when an RPC client calls GetOrders. -func (handler *rpcHandler) GetOrders(page, perPage int, snapshotID string) (result *types.GetOrdersResponse, err error) { +func (handler *rpcHandler) GetOrders(perPage int, minOrderHashHex string) (result *types.GetOrdersResponse, err error) { log.WithFields(map[string]interface{}{ - "page": page, - "perPage": perPage, - "snapshotID": snapshotID, + "perPage": perPage, + "minOrderHashHex": minOrderHashHex, }).Debug("received GetOrders request via RPC") // Catch panics, log stack trace and return RPC error message defer func() { @@ -82,7 +82,8 @@ func (handler *rpcHandler) GetOrders(page, perPage int, snapshotID string) (resu err = errors.New("method handler crashed in GetOrders RPC call (check logs for stack trace)") } }() - getOrdersResponse, err := handler.app.GetOrders(page, perPage, snapshotID) + var minOrderHash = common.HexToHash(minOrderHashHex) + getOrdersResponse, err := handler.app.GetOrders(perPage, minOrderHash) if err != nil { if _, ok := err.(core.ErrSnapshotNotFound); ok { return nil, err diff --git a/common/types/types.go b/common/types/types.go index 
afc2da5ee..fcf37789a 100644 --- a/common/types/types.go +++ b/common/types/types.go @@ -11,6 +11,7 @@ import ( "github.com/0xProject/0x-mesh/zeroex" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" + "github.com/ethereum/go-ethereum/core/types" ) // Stats is the return value for core.GetStats. Also used in the browser and RPC @@ -42,9 +43,8 @@ type LatestBlock struct { // GetOrdersResponse is the return value for core.GetOrders. Also used in the // browser and RPC interface. type GetOrdersResponse struct { - SnapshotID string `json:"snapshotID"` - SnapshotTimestamp time.Time `json:"snapshotTimestamp"` - OrdersInfos []*OrderInfo `json:"ordersInfos"` + Timestamp time.Time `json:"timestamp"` + OrdersInfos []*OrderInfo `json:"ordersInfos"` } // AddOrdersOpts is a set of options for core.AddOrders. Also used in the @@ -96,3 +96,84 @@ func (o *OrderInfo) UnmarshalJSON(data []byte) error { } return nil } + +type OrderWithMetadata struct { + Hash common.Hash `json:"hash"` + ChainID *big.Int `json:"chainID"` + ExchangeAddress common.Address `json:"exchangeAddress"` + MakerAddress common.Address `json:"makerAddress"` + MakerAssetData []byte `json:"makerAssetData"` + MakerFeeAssetData []byte `json:"makerFeeAssetData"` + MakerAssetAmount *big.Int `json:"makerAssetAmount"` + MakerFee *big.Int `json:"makerFee"` + TakerAddress common.Address `json:"takerAddress"` + TakerAssetData []byte `json:"takerAssetData"` + TakerFeeAssetData []byte `json:"takerFeeAssetData"` + TakerAssetAmount *big.Int `json:"takerAssetAmount"` + TakerFee *big.Int `json:"takerFee"` + SenderAddress common.Address `json:"senderAddress"` + FeeRecipientAddress common.Address `json:"feeRecipientAddress"` + ExpirationTimeSeconds *big.Int `json:"expirationTimeSeconds"` + Salt *big.Int `json:"salt"` + Signature []byte `json:"signature"` + FillableTakerAssetAmount *big.Int `json:"fillableTakerAssetAmount"` + LastUpdated time.Time `json:"lastUpdated"` + // Was this order flagged 
for removal? Due to the possibility of block-reorgs, instead + // of immediately removing an order when FillableTakerAssetAmount becomes 0, we instead + // flag it for removal. After this order isn't updated for X time and has IsRemoved = true, + // the order can be permanently deleted. + IsRemoved bool `json:"isRemoved"` + // IsPinned indicates whether or not the order is pinned. Pinned orders are + // not removed from the database unless they become unfillable. + IsPinned bool `json:"isPinned"` + // JSON-encoded list of assetdatas contained in MakerAssetData. For non-MAP + // orders, the list contains only one element which is equal to MakerAssetData. + // For MAP orders, it contains each component assetdata. + ParsedMakerAssetData []*SingleAssetData `json:"parsedMakerAssetData"` + // Same as ParsedMakerAssetData but for MakerFeeAssetData instead of MakerAssetData. + ParsedMakerFeeAssetData []*SingleAssetData `json:"parsedMakerFeeAssetData"` +} + +func (order OrderWithMetadata) SignedOrder() *zeroex.SignedOrder { + return &zeroex.SignedOrder{ + Order: zeroex.Order{ + ChainID: order.ChainID, + ExchangeAddress: order.ExchangeAddress, + MakerAddress: order.MakerAddress, + MakerAssetData: order.MakerAssetData, + MakerFeeAssetData: order.MakerFeeAssetData, + MakerAssetAmount: order.MakerAssetAmount, + MakerFee: order.MakerFee, + TakerAddress: order.TakerAddress, + TakerAssetData: order.TakerAssetData, + TakerFeeAssetData: order.TakerFeeAssetData, + TakerAssetAmount: order.TakerAssetAmount, + TakerFee: order.TakerFee, + SenderAddress: order.SenderAddress, + FeeRecipientAddress: order.FeeRecipientAddress, + ExpirationTimeSeconds: order.ExpirationTimeSeconds, + Salt: order.Salt, + }, + Signature: order.Signature, + } +} + +type SingleAssetData struct { + Address common.Address `json:"address"` + TokenID *big.Int `json:"tokenID"` +} + +type MiniHeader struct { + Hash common.Hash `json:"hash"` + Parent common.Hash `json:"parent"` + Number *big.Int `json:"number"` + 
Timestamp time.Time `json:"timestamp"` + Logs []types.Log `json:"logs"` +} + +type Metadata struct { + EthereumChainID int + MaxExpirationTime *big.Int + EthRPCRequestsSentInCurrentUTCDay int + StartOfCurrentUTCDay time.Time +} diff --git a/common/types/types_js.go b/common/types/types_js.go index 86587d849..183af691b 100644 --- a/common/types/types_js.go +++ b/common/types/types_js.go @@ -5,6 +5,8 @@ package types import ( "encoding/json" "syscall/js" + + "github.com/0xProject/0x-mesh/packages/browser/go/jsutil" ) func (r GetOrdersResponse) JSValue() js.Value { @@ -48,3 +50,8 @@ func (s Stats) JSValue() js.Value { "ethRPCRateLimitExpiredRequests": s.EthRPCRateLimitExpiredRequests, }) } + +func (o OrderWithMetadata) JSValue() js.Value { + value, _ := jsutil.InefficientlyConvertToJS(o) + return value +} diff --git a/constants/constants.go b/constants/constants.go index 54e67457d..06eab4b24 100644 --- a/constants/constants.go +++ b/constants/constants.go @@ -105,3 +105,8 @@ var ( const ParityFilterUnknownBlock = "One of the blocks specified in filter (fromBlock, toBlock or blockHash) cannot be found" const GethFilterUnknownBlock = "unknown block" + +var ( + ZRXAssetData = common.Hex2Bytes("f47261b0000000000000000000000000871dd7c2b4b25e1aa18728e9d5f2af4c4e431f5c") + WETHAssetData = common.Hex2Bytes("f47261b00000000000000000000000000b1ba0af832d7c05fd64161e0db78e85978e8082") +) diff --git a/core/core.go b/core/core.go index b76c614bb..e5120f151 100644 --- a/core/core.go +++ b/core/core.go @@ -22,11 +22,8 @@ import ( "github.com/0xProject/0x-mesh/ethereum/blockwatch" "github.com/0xProject/0x-mesh/ethereum/ethrpcclient" "github.com/0xProject/0x-mesh/ethereum/ratelimit" - "github.com/0xProject/0x-mesh/ethereum/simplestack" - "github.com/0xProject/0x-mesh/expirationwatch" "github.com/0xProject/0x-mesh/keys" "github.com/0xProject/0x-mesh/loghooks" - "github.com/0xProject/0x-mesh/meshdb" "github.com/0xProject/0x-mesh/orderfilter" "github.com/0xProject/0x-mesh/p2p" 
"github.com/0xProject/0x-mesh/zeroex" @@ -38,7 +35,6 @@ import ( "github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/rpc" - "github.com/google/uuid" p2pcrypto "github.com/libp2p/go-libp2p-core/crypto" peer "github.com/libp2p/go-libp2p-core/peer" peerstore "github.com/libp2p/go-libp2p-peerstore" @@ -74,11 +70,16 @@ const ( // within the core package. Intended for testing purposes. type privateConfig struct { paginationSubprotocolPerPage int + paginationSubprotocols []ordersyncSubprotocolFactory } func defaultPrivateConfig() privateConfig { return privateConfig{ paginationSubprotocolPerPage: 500, + paginationSubprotocols: []ordersyncSubprotocolFactory{ + NewFilteredPaginationSubprotocolV1, + NewFilteredPaginationSubprotocolV0, + }, } } @@ -191,31 +192,23 @@ type Config struct { EthereumRPCClient ethclient.RPCClient `envvar:"-"` } -type snapshotInfo struct { - Snapshot *db.Snapshot - CreatedAt time.Time - ExpirationTimestamp time.Time -} - type App struct { - config Config - privateConfig privateConfig - peerID peer.ID - privKey p2pcrypto.PrivKey - node *p2p.Node - chainID int - blockWatcher *blockwatch.Watcher - orderWatcher *orderwatch.Watcher - orderValidator *ordervalidator.OrderValidator - orderFilter *orderfilter.Filter - snapshotExpirationWatcher *expirationwatch.Watcher - muIdToSnapshotInfo sync.Mutex - idToSnapshotInfo map[string]snapshotInfo - ethRPCRateLimiter ratelimit.RateLimiter - ethRPCClient ethrpcclient.Client - db *meshdb.MeshDB - ordersyncService *ordersync.Service - contractAddresses *ethereum.ContractAddresses + ctx context.Context + config Config + privateConfig privateConfig + peerID peer.ID + privKey p2pcrypto.PrivKey + node *p2p.Node + chainID int + blockWatcher *blockwatch.Watcher + orderWatcher *orderwatch.Watcher + orderValidator *ordervalidator.OrderValidator + orderFilter *orderfilter.Filter + ethRPCRateLimiter ratelimit.RateLimiter + ethRPCClient ethrpcclient.Client + db 
*db.DB + ordersyncService *ordersync.Service + contractAddresses *ethereum.ContractAddresses // started is closed to signal that the App has been started. Some methods // will block until after the App is started. @@ -224,11 +217,11 @@ type App struct { var setupLoggerOnce = &sync.Once{} -func New(config Config) (*App, error) { - return newWithPrivateConfig(config, defaultPrivateConfig()) +func New(ctx context.Context, config Config) (*App, error) { + return newWithPrivateConfig(ctx, config, defaultPrivateConfig()) } -func newWithPrivateConfig(config Config, pConfig privateConfig) (*App, error) { +func newWithPrivateConfig(ctx context.Context, config Config, pConfig privateConfig) (*App, error) { // Configure logger // TODO(albrow): Don't use global variables for log settings. setupLoggerOnce.Do(func() { @@ -281,14 +274,13 @@ func newWithPrivateConfig(config Config, pConfig privateConfig) (*App, error) { } // Initialize db - databasePath := filepath.Join(config.DataDir, "db") - meshDB, err := meshdb.New(databasePath, contractAddresses) + database, err := newDB(ctx, config) if err != nil { return nil, err } // Initialize metadata and check stored chain id (if any). - metadata, err := initMetadata(config.EthereumChainID, meshDB) + metadata, err := initMetadata(config.EthereumChainID, database) if err != nil { return nil, err } @@ -300,7 +292,7 @@ func newWithPrivateConfig(config Config, pConfig privateConfig) (*App, error) { } else { clock := clock.New() var err error - ethRPCRateLimiter, err = ratelimit.New(config.EthereumRPCMaxRequestsPer24HrUTC, config.EthereumRPCMaxRequestsPerSecond, meshDB, clock) + ethRPCRateLimiter, err = ratelimit.New(config.EthereumRPCMaxRequestsPer24HrUTC, config.EthereumRPCMaxRequestsPerSecond, database, clock) if err != nil { return nil, err } @@ -333,38 +325,9 @@ func newWithPrivateConfig(config Config, pConfig privateConfig) (*App, error) { return nil, err } - // Remove any old mini headers that might be lingering in the database. 
- // See https://github.com/0xProject/0x-mesh/issues/667 and https://github.com/0xProject/0x-mesh/pull/716 - // We need to leave this in place becuase: - // - // 1. It is still necessary for anyone upgrading from older versions to >= 9.0.1 in the future. - // 2. There's still a chance there are old MiniHeaders in the database (e.g. due to a sudden - // unexpected shut down). - // - totalMiniHeaders, err := meshDB.MiniHeaders.Count() - if err != nil { - return nil, err - } - miniHeadersToRemove := totalMiniHeaders - meshDB.MiniHeaderRetentionLimit - if miniHeadersToRemove > 0 { - log.WithFields(log.Fields{ - "numHeadersToRemove": miniHeadersToRemove, - "totalHeadersStored": totalMiniHeaders, - }).Warn("Removing outdated block headers in database (this can take a while)") - } - err = meshDB.PruneMiniHeadersAboveRetentionLimit() - if err != nil { - return nil, err - } - topics := orderwatch.GetRelevantTopics() - miniHeaders, err := meshDB.FindAllMiniHeadersSortedByNumber() - if err != nil { - return nil, err - } - stack := simplestack.New(meshDB.MiniHeaderRetentionLimit, miniHeaders) blockWatcherConfig := blockwatch.Config{ - Stack: stack, + DB: database, PollingInterval: config.BlockPollingInterval, WithLogs: true, Topics: topics, @@ -385,7 +348,7 @@ func newWithPrivateConfig(config Config, pConfig privateConfig) (*App, error) { // Initialize order watcher (but don't start it yet). orderWatcher, err := orderwatch.New(orderwatch.Config{ - MeshDB: meshDB, + DB: database, BlockWatcher: blockWatcher, OrderValidator: orderValidator, ChainID: config.EthereumChainID, @@ -403,26 +366,22 @@ func newWithPrivateConfig(config Config, pConfig privateConfig) (*App, error) { return nil, fmt.Errorf("invalid custom order filter: %s", err.Error()) } - // Initialize remaining fields. 
- snapshotExpirationWatcher := expirationwatch.New() - app := &App{ - started: make(chan struct{}), - config: config, - privateConfig: pConfig, - privKey: privKey, - peerID: peerID, - chainID: config.EthereumChainID, - blockWatcher: blockWatcher, - orderWatcher: orderWatcher, - orderValidator: orderValidator, - orderFilter: orderFilter, - snapshotExpirationWatcher: snapshotExpirationWatcher, - idToSnapshotInfo: map[string]snapshotInfo{}, - ethRPCRateLimiter: ethRPCRateLimiter, - ethRPCClient: ethClient, - db: meshDB, - contractAddresses: &contractAddresses, + ctx: ctx, + started: make(chan struct{}), + config: config, + privateConfig: pConfig, + privKey: privKey, + peerID: peerID, + chainID: config.EthereumChainID, + blockWatcher: blockWatcher, + orderWatcher: orderWatcher, + orderValidator: orderValidator, + orderFilter: orderFilter, + ethRPCRateLimiter: ethRPCRateLimiter, + ethRPCClient: ethClient, + db: database, + contractAddresses: &contractAddresses, } log.WithFields(map[string]interface{}{ @@ -497,16 +456,16 @@ func initPrivateKey(path string) (p2pcrypto.PrivKey, error) { return nil, err } -func initMetadata(chainID int, meshDB *meshdb.MeshDB) (*meshdb.Metadata, error) { - metadata, err := meshDB.GetMetadata() +func initMetadata(chainID int, database *db.DB) (*types.Metadata, error) { + metadata, err := database.GetMetadata() if err != nil { - if _, ok := err.(db.NotFoundError); ok { + if err == db.ErrNotFound { // No stored metadata found (first startup) - metadata = &meshdb.Metadata{ + metadata = &types.Metadata{ EthereumChainID: chainID, MaxExpirationTime: constants.UnlimitedExpirationTime, } - if err := meshDB.SaveMetadata(metadata); err != nil { + if err := database.SaveMetadata(metadata); err != nil { return nil, err } return metadata, nil @@ -523,7 +482,7 @@ func initMetadata(chainID int, meshDB *meshdb.MeshDB) (*meshdb.Metadata, error) return metadata, nil } -func (app *App) Start(ctx context.Context) error { +func (app *App) Start() error { // Get 
the publish topics depending on our custom order filter. publishTopics, err := getPublishTopics(app.config.EthereumChainID, *app.contractAddresses, app.orderFilter) if err != nil { @@ -532,7 +491,7 @@ func (app *App) Start(ctx context.Context) error { // Create a child context so that we can preemptively cancel if there is an // error. - innerCtx, cancel := context.WithCancel(ctx) + innerCtx, cancel := context.WithCancel(app.ctx) defer cancel() // Below, we will start several independent goroutines. We use separate @@ -540,17 +499,6 @@ func (app *App) Start(ctx context.Context) error { // to exit. wg := &sync.WaitGroup{} - // Close the database when the context is canceled. - wg.Add(1) - go func() { - defer wg.Done() - defer func() { - log.Debug("closing app.db") - }() - <-innerCtx.Done() - app.db.Close() - }() - // Start rateLimiter ethRPCRateLimiterErrChan := make(chan error, 1) wg.Add(1) @@ -562,29 +510,6 @@ func (app *App) Start(ctx context.Context) error { ethRPCRateLimiterErrChan <- app.ethRPCRateLimiter.Start(innerCtx, rateLimiterCheckpointInterval) }() - // Set up the snapshot expiration watcher pruning logic - wg.Add(1) - go func() { - defer wg.Done() - defer func() { - log.Debug("closing snapshot expiration watcher") - }() - ticker := time.NewTicker(expirationPollingInterval) - for { - select { - case <-innerCtx.Done(): - return - case now := <-ticker.C: - expiredSnapshots := app.snapshotExpirationWatcher.Prune(now) - for _, expiredSnapshot := range expiredSnapshots { - app.muIdToSnapshotInfo.Lock() - delete(app.idToSnapshotInfo, expiredSnapshot.ID) - app.muIdToSnapshotInfo.Unlock() - } - } - } - }() - // Start the order watcher. 
orderWatcherErrChan := make(chan error, 1) wg.Add(1) @@ -641,7 +566,7 @@ func (app *App) Start(ctx context.Context) error { // so that Mesh does not validate any orders at outdated block heights isCaughtUp := app.IsCaughtUpToLatestBlock(innerCtx) if !isCaughtUp { - if err := app.orderWatcher.WaitForAtLeastOneBlockToBeProcessed(ctx); err != nil { + if err := app.orderWatcher.WaitForAtLeastOneBlockToBeProcessed(innerCtx); err != nil { return err } } @@ -689,8 +614,9 @@ func (app *App) Start(ctx context.Context) error { } // Register and start ordersync service. - ordersyncSubprotocols := []ordersync.Subprotocol{ - NewFilteredPaginationSubprotocol(app, app.privateConfig.paginationSubprotocolPerPage), + var ordersyncSubprotocols []ordersync.Subprotocol + for _, subprotocolFactory := range app.privateConfig.paginationSubprotocols { + ordersyncSubprotocols = append(ordersyncSubprotocols, subprotocolFactory(app, app.privateConfig.paginationSubprotocolPerPage)) } app.ordersyncService = ordersync.New(innerCtx, app.node, ordersyncSubprotocols) orderSyncErrChan := make(chan error, 1) @@ -852,11 +778,21 @@ func (e ErrPerPageZero) Error() string { return "perPage cannot be zero" } -// GetOrders retrieves paginated orders from the Mesh DB at a specific snapshot in time. Passing an empty -// string as `snapshotID` creates a new snapshot and returns the first set of results. To fetch all orders, -// continue to make requests supplying the `snapshotID` returned from the first request. After 1 minute of not -// received further requests referencing a specific snapshot, the snapshot expires and can no longer be used. -func (app *App) GetOrders(page, perPage int, snapshotID string) (*types.GetOrdersResponse, error) { +// GetOrders retrieves perPage orders from the database with an order hash greater than +// minOrderHash (exclusive). The orders in the response are sorted by hash. In order to +// paginate through all orders: +// +// 1. First call GetOrders with an empty minOrderHash. 
+// 2. On subsequent calls, use the maximum hash of the orders from the previous response as the next minOrderHash. +// 3. When no orders are returned, pagination is complete. +// +// When following this process, GetOrders offers the following guarantees: +// +// 1. Any order that was present before pagination started *and* was present after pagination ended will be included in a response. +// 2. No order will be included in more than one response. +// 3. Orders that were added or deleted during pagination may or may not be included in a response. +// +func (app *App) GetOrders(perPage int, minOrderHash common.Hash) (*types.GetOrdersResponse, error) { <-app.started if perPage <= 0 { @@ -864,66 +800,43 @@ func (app *App) GetOrders(page, perPage int, snapshotID string) (*types.GetOrder } ordersInfos := []*types.OrderInfo{} - var snapshot *db.Snapshot - var createdAt time.Time - if snapshotID == "" { - // Create a new snapshot - snapshotID = uuid.New().String() - var err error - snapshot, err = app.db.Orders.GetSnapshot() - if err != nil { - return nil, err - } - createdAt = time.Now().UTC() - expirationTimestamp := time.Now().Add(1 * time.Minute) - app.snapshotExpirationWatcher.Add(expirationTimestamp, snapshotID) - app.muIdToSnapshotInfo.Lock() - app.idToSnapshotInfo[snapshotID] = snapshotInfo{ - Snapshot: snapshot, - CreatedAt: createdAt, - ExpirationTimestamp: expirationTimestamp, - } - app.muIdToSnapshotInfo.Unlock() - } else { - // Try and find an existing snapshot - app.muIdToSnapshotInfo.Lock() - info, ok := app.idToSnapshotInfo[snapshotID] - if !ok { - app.muIdToSnapshotInfo.Unlock() - return nil, ErrSnapshotNotFound{id: snapshotID} - } - snapshot = info.Snapshot - createdAt = info.CreatedAt - // Reset the snapshot's expiry - app.snapshotExpirationWatcher.Remove(info.ExpirationTimestamp, snapshotID) - expirationTimestamp := time.Now().Add(1 * time.Minute) - app.snapshotExpirationWatcher.Add(expirationTimestamp, snapshotID) - app.idToSnapshotInfo[snapshotID] = 
snapshotInfo{ - Snapshot: snapshot, - CreatedAt: createdAt, - ExpirationTimestamp: expirationTimestamp, - } - app.muIdToSnapshotInfo.Unlock() - } - - notRemovedFilter := app.db.Orders.IsRemovedIndex.ValueFilter([]byte{0}) - var selectedOrders []*meshdb.Order - err := snapshot.NewQuery(notRemovedFilter).Offset(page * perPage).Max(perPage).Run(&selectedOrders) + query := &db.OrderQuery{ + Filters: []db.OrderFilter{ + { + Field: db.OFIsRemoved, + Kind: db.Equal, + Value: false, + }, + { + Field: db.OFHash, + Kind: db.Greater, + Value: minOrderHash, + }, + }, + Sort: []db.OrderSort{ + { + Field: db.OFHash, + Direction: db.Ascending, + }, + }, + Limit: uint(perPage), + } + + orders, err := app.db.FindOrders(query) if err != nil { return nil, err } - for _, order := range selectedOrders { + for _, order := range orders { ordersInfos = append(ordersInfos, &types.OrderInfo{ OrderHash: order.Hash, - SignedOrder: order.SignedOrder, + SignedOrder: order.SignedOrder(), FillableTakerAssetAmount: order.FillableTakerAssetAmount, }) } getOrdersResponse := &types.GetOrdersResponse{ - SnapshotID: snapshotID, - SnapshotTimestamp: createdAt, - OrdersInfos: ordersInfos, + Timestamp: time.Now(), + OrdersInfos: ordersInfos, } return getOrdersResponse, nil @@ -1053,24 +966,43 @@ func (app *App) AddPeer(peerInfo peerstore.PeerInfo) error { func (app *App) GetStats() (*types.Stats, error) { <-app.started - latestBlockHeader, err := app.db.FindLatestMiniHeader() + var latestBlock types.LatestBlock + latestMiniHeader, err := app.db.GetLatestMiniHeader() if err != nil { - return nil, err - } - latestBlock := types.LatestBlock{ - Number: int(latestBlockHeader.Number.Int64()), - Hash: latestBlockHeader.Hash, + if err != db.ErrNotFound { + // ErrNotFound is okay. For any other error, return it. 
+ return nil, err + } } - notRemovedFilter := app.db.Orders.IsRemovedIndex.ValueFilter([]byte{0}) - numOrders, err := app.db.Orders.NewQuery(notRemovedFilter).Count() + if latestMiniHeader != nil { + latestBlock.Number = int(latestMiniHeader.Number.Int64()) + latestBlock.Hash = latestMiniHeader.Hash + } + numOrders, err := app.db.CountOrders(&db.OrderQuery{ + Filters: []db.OrderFilter{ + { + Field: db.OFIsRemoved, + Kind: db.Equal, + Value: false, + }, + }, + }) if err != nil { return nil, err } - numOrdersIncludingRemoved, err := app.db.Orders.Count() + numOrdersIncludingRemoved, err := app.db.CountOrders(nil) if err != nil { return nil, err } - numPinnedOrders, err := app.db.CountPinnedOrders() + numPinnedOrders, err := app.db.CountOrders(&db.OrderQuery{ + Filters: []db.OrderFilter{ + { + Field: db.OFIsPinned, + Kind: db.Equal, + Value: true, + }, + }, + }) if err != nil { return nil, err } @@ -1148,9 +1080,10 @@ func (app *App) SubscribeToOrderEvents(sink chan<- []*zeroex.OrderEvent) event.S // IsCaughtUpToLatestBlock returns whether or not the latest block stored by Mesh corresponds // to the latest block retrieved from it's Ethereum RPC endpoint func (app *App) IsCaughtUpToLatestBlock(ctx context.Context) bool { - latestBlockStored, err := app.db.FindLatestMiniHeader() + latestStoredBlock, err := app.db.GetLatestMiniHeader() if err != nil { - if _, ok := err.(meshdb.MiniHeaderCollectionEmptyError); ok { + if err == db.ErrNotFound { + // This just means there are no MiniHeaders stored. 
return false } log.WithFields(map[string]interface{}{ @@ -1158,14 +1091,14 @@ func (app *App) IsCaughtUpToLatestBlock(ctx context.Context) bool { }).Warn("failed to fetch the latest miniHeader from DB") return false } - latestBlock, err := app.ethRPCClient.HeaderByNumber(ctx, nil) + latestRPCBlock, err := app.ethRPCClient.HeaderByNumber(ctx, nil) if err != nil { log.WithFields(map[string]interface{}{ "err": err.Error(), }).Warn("failed to fetch the latest block header via Ethereum RPC") return false } - return latestBlock.Number.Cmp(latestBlockStored.Number) == 0 + return latestRPCBlock.Number.Cmp(latestStoredBlock.Number) == 0 } func parseAndValidateCustomContractAddresses(chainID int, encodedContractAddresses string) (ethereum.ContractAddresses, error) { diff --git a/core/core_test.go b/core/core_test.go index 993a1ef5a..aaf6cc0e1 100644 --- a/core/core_test.go +++ b/core/core_test.go @@ -5,20 +5,22 @@ package core import ( "context" "flag" + "fmt" "sync" "testing" "time" "github.com/0xProject/0x-mesh/constants" + "github.com/0xProject/0x-mesh/db" "github.com/0xProject/0x-mesh/ethereum" - "github.com/0xProject/0x-mesh/meshdb" "github.com/0xProject/0x-mesh/scenario" "github.com/0xProject/0x-mesh/scenario/orderopts" "github.com/0xProject/0x-mesh/zeroex" "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/common" ethrpc "github.com/ethereum/go-ethereum/rpc" "github.com/google/uuid" - "github.com/libp2p/go-libp2p-core/peer" + peer "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -31,23 +33,22 @@ const ( ordersyncWaitTime = 2 * time.Second ) -var contractAddresses = ethereum.GanacheAddresses - func TestEthereumChainDetection(t *testing.T) { - meshDB, err := meshdb.New("/tmp/meshdb_testing/"+uuid.New().String(), contractAddresses) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + database, err := db.New(ctx, db.TestOptions()) 
require.NoError(t, err) - defer meshDB.Close() // simulate starting up on mainnet - _, err = initMetadata(1, meshDB) + _, err = initMetadata(1, database) require.NoError(t, err) // simulate restart on same chain - _, err = initMetadata(1, meshDB) + _, err = initMetadata(1, database) require.NoError(t, err) // should error when attempting to start on different chain - _, err = initMetadata(2, meshDB) + _, err = initMetadata(2, database) assert.Error(t, err) } @@ -77,13 +78,13 @@ func TestConfigChainIDAndRPCMatchDetection(t *testing.T) { MaxOrdersInStorage: 100000, CustomOrderFilter: "{}", } - app, err := New(config) + app, err := New(ctx, config) require.NoError(t, err) wg.Add(1) go func() { defer wg.Done() - err := app.Start(ctx) + err := app.Start() require.Error(t, err) require.Contains(t, err.Error(), "ChainID mismatch") }() @@ -92,11 +93,11 @@ func TestConfigChainIDAndRPCMatchDetection(t *testing.T) { wg.Wait() } -func newTestApp(t *testing.T) *App { - return newTestAppWithPrivateConfig(t, defaultPrivateConfig()) +func newTestApp(t *testing.T, ctx context.Context) *App { + return newTestAppWithPrivateConfig(t, ctx, defaultPrivateConfig()) } -func newTestAppWithPrivateConfig(t *testing.T, pConfig privateConfig) *App { +func newTestAppWithPrivateConfig(t *testing.T, ctx context.Context, pConfig privateConfig) *App { dataDir := "/tmp/test_node/" + uuid.New().String() config := Config{ Verbosity: 2, @@ -115,7 +116,7 @@ func newTestAppWithPrivateConfig(t *testing.T, pConfig privateConfig) *App { MaxOrdersInStorage: 100000, CustomOrderFilter: "{}", } - app, err := newWithPrivateConfig(config, pConfig) + app, err := newWithPrivateConfig(ctx, config, pConfig) require.NoError(t, err) return app } @@ -146,6 +147,8 @@ func init() { } func TestRepeatedAppInitialization(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() dataDir := "/tmp/test_node/" + uuid.New().String() config := Config{ Verbosity: 2, @@ -165,10 +168,9 @@ func 
TestRepeatedAppInitialization(t *testing.T) { CustomOrderFilter: "{}", CustomContractAddresses: `{"exchange":"0x48bacb9266a570d521063ef5dd96e61686dbe788","devUtils":"0x38ef19fdf8e8415f18c307ed71967e19aac28ba1","erc20Proxy":"0x1dc4c1cefef38a777b15aa20260a54e584b16c48","erc721Proxy":"0x1d7022f5b17d2f8b695918fb48fa1089c9f85401","erc1155Proxy":"0x64517fa2b480ba3678a2a3c0cf08ef7fd4fad36f"}`, } - app, err := New(config) + _, err := New(ctx, config) require.NoError(t, err) - app.db.Close() - _, err = New(config) + _, err = New(ctx, config) require.NoError(t, err) } @@ -177,100 +179,140 @@ func TestOrderSync(t *testing.T) { t.Skip("Serial tests (tests which cannot run in parallel) are disabled. You can enable them with the --serial flag") } - teardownSubTest := setupSubTest(t) - defer teardownSubTest(t) - - // Set up two Mesh nodes. originalNode starts with some orders. newNode enters - // the network without any orders. - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - defer cancel() - wg := &sync.WaitGroup{} - - perPage := 10 - pConfig := privateConfig{ - paginationSubprotocolPerPage: perPage, + testCases := []ordersyncTestCase{ + { + name: "FilteredPaginationSubprotocol version 0", + pConfig: privateConfig{ + paginationSubprotocolPerPage: 10, + paginationSubprotocols: []ordersyncSubprotocolFactory{ + NewFilteredPaginationSubprotocolV0, + }, + }, + }, + { + name: "FilteredPaginationSubprotocol version 1", + pConfig: privateConfig{ + paginationSubprotocolPerPage: 10, + paginationSubprotocols: []ordersyncSubprotocolFactory{ + NewFilteredPaginationSubprotocolV1, + }, + }, + }, + { + name: "FilteredPaginationSubprotocol version 1 and version 0", + pConfig: privateConfig{ + paginationSubprotocolPerPage: 10, + paginationSubprotocols: []ordersyncSubprotocolFactory{ + NewFilteredPaginationSubprotocolV1, + NewFilteredPaginationSubprotocolV0, + }, + }, + }, } - originalNode := newTestAppWithPrivateConfig(t, pConfig) - wg.Add(1) - go func() { - defer 
wg.Done() - if err := originalNode.Start(ctx); err != nil && err != context.Canceled { - // context.Canceled is expected. For any other error, fail the test. - require.NoError(t, err) - } - }() + for i, testCase := range testCases { + testCaseName := fmt.Sprintf("%s (test case %d)", testCase.name, i) + t.Run(testCaseName, runOrdersyncTestCase(t, testCase)) + } +} - // Manually add some orders to originalNode. - orderOptions := scenario.OptionsForAll(orderopts.SetupMakerState(true)) - originalOrders := scenario.NewSignedTestOrdersBatch(t, perPage*3+1, orderOptions) +type ordersyncTestCase struct { + name string + pConfig privateConfig +} - // We have to wait for latest block to be processed by the Mesh node. - time.Sleep(blockProcessingWaitTime) +func runOrdersyncTestCase(t *testing.T, testCase ordersyncTestCase) func(t *testing.T) { + return func(t *testing.T) { + teardownSubTest := setupSubTest(t) + defer teardownSubTest(t) + + // Set up two Mesh nodes. originalNode starts with some orders. newNode enters + // the network without any orders. + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + wg := &sync.WaitGroup{} + originalNode := newTestAppWithPrivateConfig(t, ctx, testCase.pConfig) + wg.Add(1) + go func() { + defer wg.Done() + if err := originalNode.Start(); err != nil && err != context.Canceled { + // context.Canceled is expected. For any other error, fail the test. + require.NoError(t, err) + } + }() - results, err := originalNode.orderWatcher.ValidateAndStoreValidOrders(ctx, originalOrders, true, constants.TestChainID) - require.NoError(t, err) - require.Empty(t, results.Rejected, "tried to add orders but some were invalid: \n%s\n", spew.Sdump(results)) + // Manually add some orders to originalNode. 
+ orderOptions := scenario.OptionsForAll(orderopts.SetupMakerState(true)) + numOrders := testCase.pConfig.paginationSubprotocolPerPage*3 + 1 + originalOrders := scenario.NewSignedTestOrdersBatch(t, numOrders, orderOptions) - newNode := newTestApp(t) - wg.Add(1) - go func() { - defer wg.Done() - if err := newNode.Start(ctx); err != nil && err != context.Canceled { - // context.Canceled is expected. For any other error, fail the test. - require.NoError(t, err) - } - }() - <-newNode.started - - orderEventsChan := make(chan []*zeroex.OrderEvent) - orderEventsSub := newNode.SubscribeToOrderEvents(orderEventsChan) - defer orderEventsSub.Unsubscribe() - - // Connect the two nodes *after* adding orders to one of them. This should - // trigger the ordersync protocol. - err = originalNode.AddPeer(peer.AddrInfo{ - ID: newNode.node.ID(), - Addrs: newNode.node.Multiaddrs(), - }) - require.NoError(t, err) + // We have to wait for latest block to be processed by the Mesh node. + time.Sleep(blockProcessingWaitTime) - // Wait for newNode to get the orders via ordersync. - receivedAddedEvents := []*zeroex.OrderEvent{} -OrderEventLoop: - for { - select { - case <-ctx.Done(): - t.Fatalf("timed out waiting for %d order added events (received %d so far)", len(originalOrders), len(receivedAddedEvents)) - case orderEvents := <-orderEventsChan: - for _, orderEvent := range orderEvents { - if orderEvent.EndState == zeroex.ESOrderAdded { - receivedAddedEvents = append(receivedAddedEvents, orderEvent) - } + results, err := originalNode.orderWatcher.ValidateAndStoreValidOrders(ctx, originalOrders, true, constants.TestChainID) + require.NoError(t, err) + require.Empty(t, results.Rejected, "tried to add orders but some were invalid: \n%s\n", spew.Sdump(results)) + + newNode := newTestApp(t, ctx) + wg.Add(1) + go func() { + defer wg.Done() + if err := newNode.Start(); err != nil && err != context.Canceled { + // context.Canceled is expected. For any other error, fail the test. 
+ require.NoError(t, err) } - if len(receivedAddedEvents) >= len(originalOrders) { - break OrderEventLoop + }() + <-newNode.started + + orderEventsChan := make(chan []*zeroex.OrderEvent) + orderEventsSub := newNode.SubscribeToOrderEvents(orderEventsChan) + defer orderEventsSub.Unsubscribe() + + // Connect the two nodes *after* adding orders to one of them. This should + // trigger the ordersync protocol. + err = originalNode.AddPeer(peer.AddrInfo{ + ID: newNode.node.ID(), + Addrs: newNode.node.Multiaddrs(), + }) + require.NoError(t, err) + + // Wait for newNode to get the orders via ordersync. + receivedAddedEvents := []*zeroex.OrderEvent{} + OrderEventLoop: + for { + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for %d order added events (received %d so far)", len(originalOrders), len(receivedAddedEvents)) + case orderEvents := <-orderEventsChan: + for _, orderEvent := range orderEvents { + if orderEvent.EndState == zeroex.ESOrderAdded { + receivedAddedEvents = append(receivedAddedEvents, orderEvent) + } + } + if len(receivedAddedEvents) >= len(originalOrders) { + break OrderEventLoop + } } } - } - // Test that the orders are actually in the database and are returned by - // GetOrders. - newNodeOrdersResp, err := newNode.GetOrders(0, len(originalOrders), "") - require.NoError(t, err) - assert.Len(t, newNodeOrdersResp.OrdersInfos, len(originalOrders), "new node should have %d orders", len(originalOrders)) - for _, expectedOrder := range originalOrders { - orderHash, err := expectedOrder.ComputeOrderHash() + // Test that the orders are actually in the database and are returned by + // GetOrders. 
+ newNodeOrdersResp, err := newNode.GetOrders(len(originalOrders), common.Hash{}) require.NoError(t, err) - expectedOrder.ResetHash() - var dbOrder meshdb.Order - require.NoError(t, newNode.db.Orders.FindByID(orderHash.Bytes(), &dbOrder)) - actualOrder := dbOrder.SignedOrder - assert.Equal(t, expectedOrder, actualOrder, "correct order was not stored in new node database") - } + assert.Len(t, newNodeOrdersResp.OrdersInfos, len(originalOrders), "new node should have %d orders", len(originalOrders)) + for _, expectedOrder := range originalOrders { + orderHash, err := expectedOrder.ComputeOrderHash() + require.NoError(t, err) + expectedOrder.ResetHash() + dbOrder, err := newNode.db.GetOrder(orderHash) + require.NoError(t, err) + actualOrder := dbOrder.SignedOrder() + assert.Equal(t, expectedOrder, actualOrder, "correct order was not stored in new node database") + } - // Wait for nodes to exit without error. - cancel() - wg.Wait() + // Wait for nodes to exit without error. + cancel() + wg.Wait() + } } func setupSubTest(t *testing.T) func(t *testing.T) { diff --git a/core/new_db.go b/core/new_db.go new file mode 100644 index 000000000..5ae4d8df8 --- /dev/null +++ b/core/new_db.go @@ -0,0 +1,19 @@ +// +build !js + +package core + +import ( + "context" + "path/filepath" + + "github.com/0xProject/0x-mesh/db" +) + +func newDB(ctx context.Context, config Config) (*db.DB, error) { + databasePath := filepath.Join(config.DataDir, "sqlite-db", "db.sqlite") + return db.New(ctx, &db.Options{ + DriverName: "sqlite3", + DataSourceName: databasePath, + MaxOrders: config.MaxOrdersInStorage, + }) +} diff --git a/core/new_db_js.go b/core/new_db_js.go new file mode 100644 index 000000000..a09b8e1a3 --- /dev/null +++ b/core/new_db_js.go @@ -0,0 +1,19 @@ +// +build js,wasm + +package core + +import ( + "context" + "path/filepath" + + "github.com/0xProject/0x-mesh/db" +) + +func newDB(ctx context.Context, config Config) (*db.DB, error) { + databasePath := filepath.Join(config.DataDir, 
"mesh_dexie_db") + return db.New(ctx, &db.Options{ + DriverName: "dexie", + DataSourceName: databasePath, + MaxOrders: config.MaxOrdersInStorage, + }) +} diff --git a/core/ordersync_subprotocols.go b/core/ordersync_subprotocols.go index cd9ee2864..80589b183 100644 --- a/core/ordersync_subprotocols.go +++ b/core/ordersync_subprotocols.go @@ -2,6 +2,7 @@ package core import ( "context" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -9,76 +10,93 @@ import ( "github.com/0xProject/0x-mesh/core/ordersync" "github.com/0xProject/0x-mesh/orderfilter" "github.com/0xProject/0x-mesh/zeroex" + "github.com/ethereum/go-ethereum/common" log "github.com/sirupsen/logrus" ) -// Ensure that FilteredPaginationSubProtocol implements the Subprotocol interface. -var _ ordersync.Subprotocol = (*FilteredPaginationSubProtocol)(nil) +// ordersyncSubprotocolFactory is a function that can be used to create an ordersync.Subprotocol. +// Note(albrow): Using a factory here allows us to specify which subprotocols we want to use before +// the app is fully initialized. Factory functions won't actually be called until the app is done +// initializing. +type ordersyncSubprotocolFactory func(app *App, perPage int) ordersync.Subprotocol -// FilteredPaginationSubProtocol is an ordersync subprotocol which returns all orders by +// Ensure that FilteredPaginationSubProtocolV0 implements the Subprotocol interface. +var _ ordersync.Subprotocol = (*FilteredPaginationSubProtocolV0)(nil) + +// FilteredPaginationSubProtocolV0 is an ordersync subprotocol which returns all orders by // paginating through them. It involves sending multiple requests until pagination is -// finished and all orders have been returned. -type FilteredPaginationSubProtocol struct { +// finished and all orders have been returned. Version 0 of the subprotocol is deprecated +// but included for backwards-compatibility. 
+type FilteredPaginationSubProtocolV0 struct { app *App orderFilter *orderfilter.Filter perPage int } -// NewFilteredPaginationSubprotocol creates and returns a new FilteredPaginationSubprotocol +// NewFilteredPaginationSubprotocolV0 creates and returns a new FilteredPaginationSubprotocolV0 // which will respond with perPage orders for each individual request/response. -func NewFilteredPaginationSubprotocol(app *App, perPage int) *FilteredPaginationSubProtocol { - return &FilteredPaginationSubProtocol{ +func NewFilteredPaginationSubprotocolV0(app *App, perPage int) ordersync.Subprotocol { + return &FilteredPaginationSubProtocolV0{ app: app, orderFilter: app.orderFilter, perPage: perPage, } } -// FilteredPaginationRequestMetadata is the request metadata for the -// FilteredPaginationSubProtocol. It keeps track of the current page and SnapshotID, +// FilteredPaginationRequestMetadataV0 is the request metadata for the +// FilteredPaginationSubProtocolV0. It keeps track of the current page and SnapshotID, // which is expected to be an empty string on the first request. -type FilteredPaginationRequestMetadata struct { +type FilteredPaginationRequestMetadataV0 struct { Page int `json:"page"` SnapshotID string `json:"snapshotID"` } -// FilteredPaginationResponseMetadata is the response metadata for the -// FilteredPaginationSubProtocol. It keeps track of the current page and SnapshotID. -type FilteredPaginationResponseMetadata struct { +// FilteredPaginationResponseMetadataV0 is the response metadata for the +// FilteredPaginationSubProtocolV0. It keeps track of the current page and SnapshotID. 
+type FilteredPaginationResponseMetadataV0 struct { Page int `json:"page"` SnapshotID string `json:"snapshotID"` } -// Name returns the name of the FilteredPaginationSubProtocol -func (p *FilteredPaginationSubProtocol) Name() string { +// Name returns the name of the FilteredPaginationSubProtocolV0 +func (p *FilteredPaginationSubProtocolV0) Name() string { return "/pagination-with-filter/version/0" } // HandleOrderSyncRequest returns the orders for one page, based on the page number // and snapshotID corresponding to the given request. This is // the implementation for the "provider" side of the subprotocol. -func (p *FilteredPaginationSubProtocol) HandleOrderSyncRequest(ctx context.Context, req *ordersync.Request) (*ordersync.Response, error) { - var metadata *FilteredPaginationRequestMetadata +func (p *FilteredPaginationSubProtocolV0) HandleOrderSyncRequest(ctx context.Context, req *ordersync.Request) (*ordersync.Response, error) { + var metadata *FilteredPaginationRequestMetadataV0 if req.Metadata == nil { // Default metadata for the first request. - metadata = &FilteredPaginationRequestMetadata{ + metadata = &FilteredPaginationRequestMetadataV0{ Page: 0, SnapshotID: "", } } else { var ok bool - metadata, ok = req.Metadata.(*FilteredPaginationRequestMetadata) + metadata, ok = req.Metadata.(*FilteredPaginationRequestMetadataV0) if !ok { - return nil, fmt.Errorf("FilteredPaginationSubProtocol received request with wrong metadata type (got %T)", req.Metadata) + return nil, fmt.Errorf("FilteredPaginationSubProtocolV0 received request with wrong metadata type (got %T)", req.Metadata) } } + // Note(albrow): This version of Mesh does not support database snapshots. Instead, we use the SnapshotID + // field as minOrderHash. 
+ var currentMinOrderHash common.Hash + if metadata.SnapshotID != "" { + if err := validateHexHash(metadata.SnapshotID); err != nil { + return nil, fmt.Errorf("FilteredPaginationSubProtocolV0 could not decode snapshotID (%q) as hex: %s", metadata.SnapshotID, err.Error()) + } + currentMinOrderHash = common.HexToHash(metadata.SnapshotID) + } + // It's possible that none of the orders in the current page match the filter. // We don't want to respond with zero orders, so keep iterating until we find // at least some orders that match the filter. filteredOrders := []*zeroex.SignedOrder{} - var snapshotID string - currentPage := metadata.Page + var nextMinOrderHash common.Hash for { select { case <-ctx.Done(): @@ -86,15 +104,15 @@ func (p *FilteredPaginationSubProtocol) HandleOrderSyncRequest(ctx context.Conte default: } // Get the orders for this page. - ordersResp, err := p.app.GetOrders(currentPage, p.perPage, metadata.SnapshotID) + ordersResp, err := p.app.GetOrders(p.perPage, currentMinOrderHash) if err != nil { return nil, err } - snapshotID = ordersResp.SnapshotID if len(ordersResp.OrdersInfos) == 0 { // No more orders left. break } + nextMinOrderHash = ordersResp.OrdersInfos[len(ordersResp.OrdersInfos)-1].OrderHash // Filter the orders for this page. for _, orderInfo := range ordersResp.OrdersInfos { if matches, err := p.orderFilter.MatchOrder(orderInfo.SignedOrder); err != nil { @@ -106,7 +124,7 @@ func (p *FilteredPaginationSubProtocol) HandleOrderSyncRequest(ctx context.Conte if len(filteredOrders) == 0 { // If none of the orders for this page match the filter, we continue // on to the next page. 
- currentPage += 1 + currentMinOrderHash = nextMinOrderHash continue } else { break @@ -116,9 +134,10 @@ func (p *FilteredPaginationSubProtocol) HandleOrderSyncRequest(ctx context.Conte return &ordersync.Response{ Orders: filteredOrders, Complete: len(filteredOrders) == 0, - Metadata: &FilteredPaginationResponseMetadata{ - Page: currentPage, - SnapshotID: snapshotID, + Metadata: &FilteredPaginationResponseMetadataV0{ + // Note(albrow): Page isn't actually used. Included for backwards compatibility only. + Page: metadata.Page + 1, + SnapshotID: nextMinOrderHash.Hex(), }, }, nil } @@ -126,13 +145,13 @@ func (p *FilteredPaginationSubProtocol) HandleOrderSyncRequest(ctx context.Conte // HandleOrderSyncResponse handles the orders for one page by validating them, storing them // in the database, and firing the appropriate events. It also returns the next request to // be sent. This is the implementation for the "requester" side of the subprotocol. -func (p *FilteredPaginationSubProtocol) HandleOrderSyncResponse(ctx context.Context, res *ordersync.Response) (*ordersync.Request, error) { +func (p *FilteredPaginationSubProtocolV0) HandleOrderSyncResponse(ctx context.Context, res *ordersync.Response) (*ordersync.Request, error) { if res.Metadata == nil { - return nil, errors.New("FilteredPaginationSubProtocol received response with nil metadata") + return nil, errors.New("FilteredPaginationSubProtocolV0 received response with nil metadata") } - metadata, ok := res.Metadata.(*FilteredPaginationResponseMetadata) + metadata, ok := res.Metadata.(*FilteredPaginationResponseMetadataV0) if !ok { - return nil, fmt.Errorf("FilteredPaginationSubProtocol received response with wrong metadata type (got %T)", res.Metadata) + return nil, fmt.Errorf("FilteredPaginationSubProtocolV0 received response with wrong metadata type (got %T)", res.Metadata) } filteredOrders := []*zeroex.SignedOrder{} for _, order := range res.Orders { @@ -165,25 +184,229 @@ func (p *FilteredPaginationSubProtocol) 
HandleOrderSyncResponse(ctx context.Cont } return &ordersync.Request{ - Metadata: &FilteredPaginationRequestMetadata{ + Metadata: &FilteredPaginationRequestMetadataV0{ Page: metadata.Page + 1, SnapshotID: metadata.SnapshotID, }, }, nil } -func (p *FilteredPaginationSubProtocol) ParseRequestMetadata(metadata json.RawMessage) (interface{}, error) { - var parsed FilteredPaginationRequestMetadata +func (p *FilteredPaginationSubProtocolV0) ParseRequestMetadata(metadata json.RawMessage) (interface{}, error) { + var parsed FilteredPaginationRequestMetadataV0 if err := json.Unmarshal(metadata, &parsed); err != nil { return nil, err } return &parsed, nil } -func (p *FilteredPaginationSubProtocol) ParseResponseMetadata(metadata json.RawMessage) (interface{}, error) { - var parsed FilteredPaginationResponseMetadata +func (p *FilteredPaginationSubProtocolV0) ParseResponseMetadata(metadata json.RawMessage) (interface{}, error) { + var parsed FilteredPaginationResponseMetadataV0 if err := json.Unmarshal(metadata, &parsed); err != nil { return nil, err } return &parsed, nil } + +// Ensure that FilteredPaginationSubProtocolV1 implements the Subprotocol interface. +var _ ordersync.Subprotocol = (*FilteredPaginationSubProtocolV1)(nil) + +// FilteredPaginationSubProtocolV1 is an ordersync subprotocol which returns all orders by +// paginating through them. It involves sending multiple requests until pagination is +// finished and all orders have been returned. Version 1 was implemented in +// https://github.com/0xProject/0x-mesh/pull/793 after changing the database implementation +// from LevelDB to SQL and Dexie.js/IndexedDB. +type FilteredPaginationSubProtocolV1 struct { + app *App + orderFilter *orderfilter.Filter + perPage int +} + +// NewFilteredPaginationSubprotocolV1 creates and returns a new FilteredPaginationSubprotocolV1 +// which will respond with perPage orders for each individual request/response. 
+func NewFilteredPaginationSubprotocolV1(app *App, perPage int) ordersync.Subprotocol { + return &FilteredPaginationSubProtocolV1{ + app: app, + orderFilter: app.orderFilter, + perPage: perPage, + } +} + +// FilteredPaginationRequestMetadataV1 is the request metadata for the +// FilteredPaginationSubProtocolV1. It keeps track of the current +// minOrderHash, which is expected to be an empty string on the first request. +type FilteredPaginationRequestMetadataV1 struct { + MinOrderHash common.Hash `json:"minOrderHash"` +} + +// FilteredPaginationResponseMetadataV1 is the response metadata for the +// FilteredPaginationSubProtocolV1. It contains the minOrderHash to use for +// the next request. +type FilteredPaginationResponseMetadataV1 struct { + NextMinOrderHash common.Hash `json:"nextMinOrderHash"` +} + +// Name returns the name of the FilteredPaginationSubProtocolV1 +func (p *FilteredPaginationSubProtocolV1) Name() string { + return "/pagination-with-filter/version/1" +} + +// HandleOrderSyncRequest returns the orders for one page, based on the page number +// and snapshotID corresponding to the given request. This is +// the implementation for the "provider" side of the subprotocol. +func (p *FilteredPaginationSubProtocolV1) HandleOrderSyncRequest(ctx context.Context, req *ordersync.Request) (*ordersync.Response, error) { + var metadata *FilteredPaginationRequestMetadataV1 + if req.Metadata == nil { + // Default metadata for the first request. + metadata = &FilteredPaginationRequestMetadataV1{ + MinOrderHash: common.Hash{}, + } + } else { + var ok bool + metadata, ok = req.Metadata.(*FilteredPaginationRequestMetadataV1) + if !ok { + return nil, fmt.Errorf("FilteredPaginationSubProtocolV1 received request with wrong metadata type (got %T)", req.Metadata) + } + } + + // It's possible that none of the orders in the current page match the filter. 
+ // We don't want to respond with zero orders, so keep iterating until we find + // at least some orders that match the filter. + filteredOrders := []*zeroex.SignedOrder{} + currentMinOrderHash := metadata.MinOrderHash + nextMinOrderHash := common.Hash{} + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + // Get the orders for this page. + ordersResp, err := p.app.GetOrders(p.perPage, currentMinOrderHash) + if err != nil { + return nil, err + } + if len(ordersResp.OrdersInfos) == 0 { + // No more orders left. + break + } + nextMinOrderHash = ordersResp.OrdersInfos[len(ordersResp.OrdersInfos)-1].OrderHash + // Filter the orders for this page. + for _, orderInfo := range ordersResp.OrdersInfos { + if matches, err := p.orderFilter.MatchOrder(orderInfo.SignedOrder); err != nil { + return nil, err + } else if matches { + filteredOrders = append(filteredOrders, orderInfo.SignedOrder) + } + } + if len(filteredOrders) == 0 { + // If none of the orders for this page match the filter, we continue + // on to the next page. + currentMinOrderHash = nextMinOrderHash + continue + } else { + break + } + } + + return &ordersync.Response{ + Orders: filteredOrders, + Complete: len(filteredOrders) == 0, + Metadata: &FilteredPaginationResponseMetadataV1{ + NextMinOrderHash: nextMinOrderHash, + }, + }, nil +} + +// HandleOrderSyncResponse handles the orders for one page by validating them, storing them +// in the database, and firing the appropriate events. It also returns the next request to +// be sent. This is the implementation for the "requester" side of the subprotocol. 
+func (p *FilteredPaginationSubProtocolV1) HandleOrderSyncResponse(ctx context.Context, res *ordersync.Response) (*ordersync.Request, error) { + if res.Metadata == nil { + return nil, errors.New("FilteredPaginationSubProtocolV1 received response with nil metadata") + } + _, ok := res.Metadata.(*FilteredPaginationResponseMetadataV1) + if !ok { + return nil, fmt.Errorf("FilteredPaginationSubProtocolV1 received response with wrong metadata type (got %T)", res.Metadata) + } + filteredOrders := []*zeroex.SignedOrder{} + for _, order := range res.Orders { + if matches, err := p.orderFilter.MatchOrder(order); err != nil { + return nil, err + } else if matches { + filteredOrders = append(filteredOrders, order) + } else if !matches { + p.app.handlePeerScoreEvent(res.ProviderID, psReceivedOrderDoesNotMatchFilter) + } + } + validationResults, err := p.app.orderWatcher.ValidateAndStoreValidOrders(ctx, filteredOrders, false, p.app.chainID) + if err != nil { + return nil, err + } + for _, acceptedOrderInfo := range validationResults.Accepted { + if acceptedOrderInfo.IsNew { + log.WithFields(map[string]interface{}{ + "orderHash": acceptedOrderInfo.OrderHash.Hex(), + "from": res.ProviderID.Pretty(), + "protocol": "ordersync", + }).Info("received new valid order from peer") + log.WithFields(map[string]interface{}{ + "order": acceptedOrderInfo.SignedOrder, + "orderHash": acceptedOrderInfo.OrderHash.Hex(), + "from": res.ProviderID.Pretty(), + "protocol": "ordersync", + }).Trace("all fields for new valid order received from peer") + } + } + + // Calculate the next min order hash to send in our next request. + // This is equal to the maximum order hash we have received so far. 
+ var nextMinOrderHash common.Hash + if len(res.Orders) > 0 { + hash, err := res.Orders[len(res.Orders)-1].ComputeOrderHash() + if err != nil { + return nil, err + } + nextMinOrderHash = hash + } + return &ordersync.Request{ + Metadata: &FilteredPaginationRequestMetadataV1{ + MinOrderHash: nextMinOrderHash, + }, + }, nil +} + +func (p *FilteredPaginationSubProtocolV1) ParseRequestMetadata(metadata json.RawMessage) (interface{}, error) { + var parsed FilteredPaginationRequestMetadataV1 + if err := json.Unmarshal(metadata, &parsed); err != nil { + return nil, err + } + return &parsed, nil +} + +func (p *FilteredPaginationSubProtocolV1) ParseResponseMetadata(metadata json.RawMessage) (interface{}, error) { + var parsed FilteredPaginationResponseMetadataV1 + if err := json.Unmarshal(metadata, &parsed); err != nil { + return nil, err + } + return &parsed, nil +} + +// validateHexHash returns an error if s is not a valid hex hash. It supports +// encodings with or without the "0x" prefix. +// Note(albrow) This is based on unexported code in go-ethereum. +func validateHexHash(s string) error { + if has0xPrefix(s) { + s = s[2:] + } + if len(s)%2 == 1 { + s = "0" + s + } + _, err := hex.DecodeString(s) + return err +} + +// has0xPrefix returns true if the given hex string starts with "0x" +// Note(albrow) This is copied from go-ethereum, where it is unexported. 
+func has0xPrefix(str string) bool { + return len(str) >= 2 && str[0] == '0' && (str[1] == 'x' || str[1] == 'X') +} diff --git a/db/batch.go b/db/batch.go deleted file mode 100644 index 5aefffec6..000000000 --- a/db/batch.go +++ /dev/null @@ -1,44 +0,0 @@ -package db - -import ( - "github.com/syndtr/goleveldb/leveldb" - "github.com/syndtr/goleveldb/leveldb/iterator" - "github.com/syndtr/goleveldb/leveldb/opt" - "github.com/syndtr/goleveldb/leveldb/util" -) - -type readerWithBatchWriter struct { - reader dbReader - batch *leveldb.Batch -} - -func newReaderWithBatchWriter(reader dbReader) *readerWithBatchWriter { - return &readerWithBatchWriter{ - reader: reader, - batch: &leveldb.Batch{}, - } -} - -var _ dbReadWriter = &readerWithBatchWriter{} - -func (readWriter *readerWithBatchWriter) Get(key []byte, ro *opt.ReadOptions) ([]byte, error) { - return readWriter.reader.Get(key, ro) -} - -func (readWriter *readerWithBatchWriter) NewIterator(slice *util.Range, ro *opt.ReadOptions) iterator.Iterator { - return readWriter.reader.NewIterator(slice, ro) -} - -func (readWriter *readerWithBatchWriter) Has(key []byte, ro *opt.ReadOptions) (bool, error) { - return readWriter.reader.Has(key, ro) -} - -func (readWriter *readerWithBatchWriter) Delete(key []byte, wo *opt.WriteOptions) error { - readWriter.batch.Delete(key) - return nil -} - -func (readWriter *readerWithBatchWriter) Put(key, value []byte, wo *opt.WriteOptions) error { - readWriter.batch.Put(key, value) - return nil -} diff --git a/db/col_info.go b/db/col_info.go deleted file mode 100644 index 9c58d7635..000000000 --- a/db/col_info.go +++ /dev/null @@ -1,83 +0,0 @@ -package db - -import ( - "fmt" - "reflect" - "sync" -) - -// colInfo is a set of information/metadata about a collection. -type colInfo struct { - db *DB - name string - modelType reflect.Type - indexes []*Index - // indexMut protects the indexes slice. 
- indexMut sync.RWMutex - // writeMut is used by transactions to prevent other goroutines from writing - // until the transaction is committed or discarded. Needs to be a pointer so - // that copies of this colInfo retain the same writeLock. - writeMut *sync.Mutex -} - -// copy returns a copy of the colInfo. Any changes made to the original (e.g. -// adding a new index) will not affect the copy. The copy and the original share -// the same writeMut. -func (info *colInfo) copy() *colInfo { - info.indexMut.RLock() - indexes := make([]*Index, len(info.indexes)) - copy(indexes, info.indexes) - info.indexMut.RUnlock() - return &colInfo{ - db: info.db, - name: info.name, - modelType: info.modelType, - indexes: indexes, - writeMut: info.writeMut, - } -} - -func (info *colInfo) prefix() []byte { - return []byte(fmt.Sprintf("model:%s", escape([]byte(info.name)))) -} - -// countKey returns the key used to store a count of the number of models in the -// collection. -func (info *colInfo) countKey() []byte { - return []byte(fmt.Sprintf("count:%s", escape([]byte(info.name)))) -} - -func (info *colInfo) primaryKeyForModel(model Model) []byte { - return info.primaryKeyForID(model.ID()) -} - -func (info *colInfo) primaryKeyForID(id []byte) []byte { - return []byte(fmt.Sprintf("%s:%s", info.prefix(), escape(id))) -} - -func (info *colInfo) primaryKeyForIDWithoutEscape(id []byte) []byte { - return []byte(fmt.Sprintf("%s:%s", info.prefix(), id)) -} - -func (info *colInfo) checkModelType(model Model) error { - actualType := reflect.TypeOf(model) - if info.modelType != actualType { - if actualType.Kind() == reflect.Ptr { - if info.modelType == actualType.Elem() { - // Pointers to the expected type are allowed here. 
- return nil - } - } - return fmt.Errorf("for %q collection: incorrect type for model (expected %s but got %s)", info.name, info.modelType, actualType) - } - return nil -} - -func (info *colInfo) checkModelsType(models interface{}) error { - expectedType := reflect.PtrTo(reflect.SliceOf(info.modelType)) - actualType := reflect.TypeOf(models) - if expectedType != actualType { - return fmt.Errorf("for %q collection: incorrect type for models (expected %s but got %s)", info.name, expectedType, actualType) - } - return nil -} diff --git a/db/collection.go b/db/collection.go deleted file mode 100644 index ecfbf2648..000000000 --- a/db/collection.go +++ /dev/null @@ -1,124 +0,0 @@ -package db - -import ( - "fmt" - "reflect" - "sync" - - "github.com/syndtr/goleveldb/leveldb" -) - -// Collection represents a set of a specific type of model. -type Collection struct { - info *colInfo - ldb *leveldb.DB -} - -// NewCollection creates and returns a new collection with the given name and -// model type. You should create exactly one collection for each model type. The -// collection should typically be created once at the start of your application -// and re-used. NewCollection returns an error if a collection has already been -// created with the given name for this db. -func (db *DB) NewCollection(name string, typ Model) (*Collection, error) { - col := &Collection{ - info: &colInfo{ - db: db, - name: name, - modelType: reflect.TypeOf(typ), - writeMut: &sync.Mutex{}, - }, - ldb: db.ldb, - } - db.colLock.Lock() - defer db.colLock.Unlock() - for _, existingCol := range db.collections { - if existingCol.info.name == name { - return nil, fmt.Errorf("a collection with the name %q already exists", name) - } - } - db.collections = append(db.collections, col) - return col, nil -} - -// Name returns the name of the collection. -func (c *Collection) Name() string { - return c.info.name -} - -// FindByID finds the model with the given ID and scans the results into the -// given model. 
As in the Unmarshal and Decode methods in the encoding/json -// package, model must be settable via reflect. Typically, this means you should -// pass in a pointer. -func (c *Collection) FindByID(id []byte, model Model) error { - return findByID(c.info, c.ldb, id, model) -} - -// FindAll finds all models for the collection and scans the results into the -// given models. models should be a pointer to an empty slice of a concrete -// model type (e.g. *[]myModelType). -func (c *Collection) FindAll(models interface{}) error { - return findAll(c.info, c.ldb, models) -} - -// Count returns the number of models in the collection. -func (c *Collection) Count() (int, error) { - return count(c.info, c.ldb) -} - -// Insert inserts the given model into the database. It returns an error if a -// model with the same id already exists. -func (c *Collection) Insert(model Model) error { - txn := c.OpenTransaction() - if err := insertWithTransaction(c.info, txn.readWriter, model); err != nil { - _ = txn.Discard() - return err - } - txn.updateInternalCount(1) - if err := txn.Commit(); err != nil { - _ = txn.Discard() - return err - } - return nil -} - -// Update updates an existing model in the database. It returns an error if the -// given model doesn't already exist. -func (c *Collection) Update(model Model) error { - txn := c.OpenTransaction() - if err := updateWithTransaction(c.info, txn.readWriter, model); err != nil { - _ = txn.Discard() - return err - } - if err := txn.Commit(); err != nil { - _ = txn.Discard() - return err - } - return nil -} - -// Delete deletes the model with the given ID from the database. It returns an -// error if the model doesn't exist in the database. 
-func (c *Collection) Delete(id []byte) error { - txn := c.OpenTransaction() - if err := deleteWithTransaction(c.info, txn.readWriter, id); err != nil { - _ = txn.Discard() - return err - } - txn.updateInternalCount(-1) - if err := txn.Commit(); err != nil { - _ = txn.Discard() - return err - } - return nil -} - -// New Query creates and returns a new query with the given filter. By default, -// a query will return all models that match the filter in ascending byte order -// according to their index values. The query offers methods that can be used to -// change this (e.g. Reverse and Max). The query is lazily executed, i.e. it -// does not actually touch the database until they are run. In general, queries -// have a runtime of O(N) where N is the number of models that are returned by -// the query, but using some features may significantly change this. -func (c *Collection) NewQuery(filter *Filter) *Query { - return newQuery(c.info, c.ldb, filter) -} diff --git a/db/collection_benchmark_test.go b/db/collection_benchmark_test.go deleted file mode 100644 index dcc4b0613..000000000 --- a/db/collection_benchmark_test.go +++ /dev/null @@ -1,175 +0,0 @@ -package db - -import ( - "fmt" - "strconv" - "testing" - - "github.com/stretchr/testify/require" -) - -func BenchmarkInsert(b *testing.B) { - db := newTestDB(b) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(b, err) - b.ResetTimer() - for i := 0; i < b.N; i++ { - b.StopTimer() - model := &testModel{ - Name: fmt.Sprintf("person_%d", i), - Age: i, - } - b.StartTimer() - err := col.Insert(model) - b.StopTimer() - require.NoError(b, err) - b.StartTimer() - } -} - -func BenchmarkFindByIDHot(b *testing.B) { - db := newTestDB(b) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(b, err) - model := &testModel{ - Name: "foo", - Age: 42, - } - require.NoError(b, col.Insert(model)) - b.ResetTimer() - for i := 0; i < b.N; i++ { - var found 
testModel - err := col.FindByID(model.ID(), &found) - b.StopTimer() - require.NoError(b, err) - b.StartTimer() - } -} - -func BenchmarkFindByIDCold(b *testing.B) { - db := newTestDB(b) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(b, err) - b.ResetTimer() - for i := 0; i < b.N; i++ { - b.StopTimer() - model := &testModel{ - Name: fmt.Sprintf("person_%d", i), - Age: i, - } - require.NoError(b, col.Insert(model)) - b.StartTimer() - var found testModel - err := col.FindByID(model.ID(), &found) - b.StopTimer() - require.NoError(b, err) - b.StartTimer() - } -} - -func BenchmarkUpdateHot(b *testing.B) { - db := newTestDB(b) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(b, err) - original := &testModel{ - Name: "person_0", - Age: 0, - } - require.NoError(b, col.Insert(original)) - b.ResetTimer() - for i := 0; i < b.N; i++ { - b.StopTimer() - updated := &testModel{ - Name: original.Name, - Age: i + 1, - } - b.StartTimer() - err := col.Update(updated) - b.StopTimer() - require.NoError(b, err) - b.StartTimer() - } -} - -func BenchmarkUpdateCold(b *testing.B) { - db := newTestDB(b) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(b, err) - b.ResetTimer() - for i := 0; i < b.N; i++ { - b.StopTimer() - original := &testModel{ - Name: fmt.Sprintf("person_%d", i), - Age: i, - } - require.NoError(b, col.Insert(original)) - updated := &testModel{ - Name: original.Name, - Age: i + 1, - } - b.StartTimer() - err := col.Update(updated) - b.StopTimer() - require.NoError(b, err) - b.StartTimer() - } -} - -func BenchmarkFindAll100(b *testing.B) { - benchmarkFindAll(b, 100) -} - -func BenchmarkFindAll1000(b *testing.B) { - benchmarkFindAll(b, 1000) -} - -func benchmarkFindAll(b *testing.B, count int) { - b.Helper() - db := newTestDB(b) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(b, err) - expected := 
[]*testModel{} - for i := 0; i < count; i++ { - model := &testModel{ - Name: "person_%d" + strconv.Itoa(i), - Age: i, - } - require.NoError(b, col.Insert(model)) - expected = append(expected, model) - } - b.ResetTimer() - for i := 0; i < b.N; i++ { - var actual []*testModel - err := col.FindAll(&actual) - b.StopTimer() - require.NoError(b, err) - b.StartTimer() - } -} - -func BenchmarkDelete(b *testing.B) { - db := newTestDB(b) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(b, err) - b.ResetTimer() - for i := 0; i < b.N; i++ { - b.StopTimer() - model := &testModel{ - Name: "person_%d" + strconv.Itoa(i), - Age: i, - } - require.NoError(b, col.Insert(model)) - b.StartTimer() - err := col.Delete(model.ID()) - b.StopTimer() - require.NoError(b, err) - b.StartTimer() - } -} diff --git a/db/collection_test.go b/db/collection_test.go deleted file mode 100644 index 685bd17c2..000000000 --- a/db/collection_test.go +++ /dev/null @@ -1,203 +0,0 @@ -package db - -import ( - "fmt" - "strconv" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNewCollection(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - _, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - _, err = db.NewCollection("people", &testModel{}) - require.Error(t, err, "Expected an error when creating new collection with the same name") -} - -func TestInsert(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - expected := &testModel{ - Name: "foo", - Age: 42, - } - require.NoError(t, col.Insert(expected)) - exists, err := db.ldb.Has([]byte("model:people:foo"), nil) - require.NoError(t, err) - assert.True(t, exists, "Model not stored in database at the expected key") -} - -func TestFindByID(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err 
:= db.NewCollection("people", &testModel{}) - require.NoError(t, err) - expected := &testModel{ - Name: "foo", - Age: 42, - } - require.NoError(t, col.Insert(expected)) - actual := &testModel{} - require.NoError(t, col.FindByID(expected.ID(), actual)) - assert.Equal(t, expected, actual) -} - -func TestUpdate(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - original := &testModel{ - Name: "foo", - Age: 42, - } - require.NoError(t, col.Insert(original)) - updated := &testModel{ - Name: "foo", - Age: 43, - } - require.NoError(t, col.Update(updated)) - actual := &testModel{} - require.NoError(t, col.FindByID(original.ID(), actual)) - assert.Equal(t, updated, actual) -} - -func TestFindAll(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - expected := []*testModel{} - for i := 0; i < 5; i++ { - model := &testModel{ - Name: "Person_" + strconv.Itoa(i), - Age: i, - } - require.NoError(t, col.Insert(model)) - expected = append(expected, model) - } - var actual []*testModel - require.NoError(t, col.FindAll(&actual)) - assert.Equal(t, expected, actual) -} - -func TestCount(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - - // Insert some test models and make sure Count is equal to the number of - // models inserted. - expected := []*testModel{} - for i := 0; i < 5; i++ { - model := &testModel{ - Name: "Person_" + strconv.Itoa(i), - Age: i, - } - require.NoError(t, col.Insert(model)) - expected = append(expected, model) - } - { - actualCount, err := col.Count() - require.NoError(t, err) - assert.Equal(t, len(expected), actualCount, "Count returned wrong results") - } - - // Delete a model and then check that the count decremented by 1. 
- require.NoError(t, col.Delete(expected[0].ID())) - { - actualCount, err := col.Count() - require.NoError(t, err) - assert.Equal(t, len(expected)-1, actualCount, "Count returned wrong results") - } - - // Delete all remaining models and check that the countKey was deleted. - for _, model := range expected[1:] { - require.NoError(t, col.Delete(model.ID())) - } - { - actualCount, err := col.Count() - require.NoError(t, err) - assert.Equal(t, 0, actualCount, "Count returned wrong results") - countKeyExists, err := col.ldb.Has(col.info.countKey(), nil) - require.NoError(t, err) - require.False(t, countKeyExists, "expected countKey to be deleted but it was not") - } -} - -func TestDelete(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - col.AddIndex("age", func(m Model) []byte { - return []byte(fmt.Sprint(m.(*testModel).Age)) - }) - model := &testModel{ - Name: "foo", - Age: 42, - } - require.NoError(t, col.Insert(model)) - require.NoError(t, col.Delete(model.ID())) - { - exists, err := db.ldb.Has([]byte("model:people:foo"), nil) - require.NoError(t, err) - assert.False(t, exists, "Primary key should not be stored in database after calling Delete") - } - { - exists, err := db.ldb.Has([]byte("index:people:age:42:foo"), nil) - require.NoError(t, err) - assert.False(t, exists, "Index should not be stored in database after calling Delete") - } -} - -func TestDeleteAfterUpdate(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - col.AddIndex("age", func(m Model) []byte { - return []byte(fmt.Sprint(m.(*testModel).Age)) - }) - model := &testModel{ - Name: "foo", - Age: 42, - } - require.NoError(t, col.Insert(model)) - updated := &testModel{ - Name: "foo", - Age: 43, - } - require.NoError(t, col.Update(updated)) - require.NoError(t, col.Delete(model.ID())) - { - exists, err 
:= db.ldb.Has([]byte("model:people:foo"), nil) - require.NoError(t, err) - assert.False(t, exists, "Primary key should not be stored in database after calling Delete") - } - { - exists, err := db.ldb.Has([]byte("index:people:age:42:foo"), nil) - require.NoError(t, err) - assert.False(t, exists, "Old index should not be stored in database after calling Delete") - } - { - exists, err := db.ldb.Has([]byte("index:people:age:43:foo"), nil) - require.NoError(t, err) - assert.False(t, exists, "Updated index should not be stored in database after calling Delete") - } -} diff --git a/db/common.go b/db/common.go new file mode 100644 index 000000000..98ad4243e --- /dev/null +++ b/db/common.go @@ -0,0 +1,339 @@ +package db + +import ( + "errors" + "fmt" + "math/big" + "time" + + "github.com/0xProject/0x-mesh/common/types" + "github.com/0xProject/0x-mesh/ethereum" + "github.com/0xProject/0x-mesh/zeroex" + "github.com/ethereum/go-ethereum/common" + "github.com/gibson042/canonicaljson-go" +) + +const ( + // The default miniHeaderRetentionLimit used by Mesh. This default only gets overwritten in tests. + defaultMiniHeaderRetentionLimit = 20 + // The maximum MiniHeaders to query per page when deleting MiniHeaders + miniHeadersMaxPerPage = 1000 + // The amount of time to wait before timing out when connecting to the database for the first time. 
+ connectTimeout = 10 * time.Second +) + +var ( + ErrDBFilledWithPinnedOrders = errors.New("the database is full of pinned orders; no orders can be removed in order to make space") + ErrMetadataAlreadyExists = errors.New("metadata already exists in the database (use UpdateMetadata instead?)") + ErrNotFound = errors.New("could not find existing model or row in database") + ErrClosed = errors.New("database is already closed") +) + +type Database interface { + AddOrders(orders []*types.OrderWithMetadata) (added []*types.OrderWithMetadata, removed []*types.OrderWithMetadata, err error) + GetOrder(hash common.Hash) (*types.OrderWithMetadata, error) + FindOrders(opts *OrderQuery) ([]*types.OrderWithMetadata, error) + CountOrders(opts *OrderQuery) (int, error) + DeleteOrder(hash common.Hash) error + DeleteOrders(opts *OrderQuery) ([]*types.OrderWithMetadata, error) + UpdateOrder(hash common.Hash, updateFunc func(existingOrder *types.OrderWithMetadata) (updatedOrder *types.OrderWithMetadata, err error)) error + AddMiniHeaders(miniHeaders []*types.MiniHeader) (added []*types.MiniHeader, removed []*types.MiniHeader, err error) + GetMiniHeader(hash common.Hash) (*types.MiniHeader, error) + FindMiniHeaders(opts *MiniHeaderQuery) ([]*types.MiniHeader, error) + DeleteMiniHeader(hash common.Hash) error + DeleteMiniHeaders(opts *MiniHeaderQuery) ([]*types.MiniHeader, error) + GetMetadata() (*types.Metadata, error) + SaveMetadata(metadata *types.Metadata) error + UpdateMetadata(updateFunc func(oldmetadata *types.Metadata) (newMetadata *types.Metadata)) error +} + +type Options struct { + DriverName string `json:"driverName"` + DataSourceName string `json:"dataSourceName"` + MaxOrders int `json:"maxOrders"` + MaxMiniHeaders int `json:"maxMiniHeaders"` +} + +func parseOptions(opts *Options) *Options { + finalOpts := defaultOptions() + if opts == nil { + return finalOpts + } + if opts.DataSourceName != "" { + finalOpts.DataSourceName = opts.DataSourceName + } + if opts.MaxOrders != 0 
{ + finalOpts.MaxOrders = opts.MaxOrders + } + if opts.MaxMiniHeaders != 0 { + finalOpts.MaxMiniHeaders = opts.MaxMiniHeaders + } + return finalOpts +} + +type SortDirection string + +const ( + Ascending SortDirection = "ASC" + Descending SortDirection = "DESC" +) + +type FilterKind string + +const ( + Equal FilterKind = "=" + NotEqual FilterKind = "!=" + Less FilterKind = "<" + Greater FilterKind = ">" + LessOrEqual FilterKind = "<=" + GreaterOrEqual FilterKind = ">=" + Contains FilterKind = "CONTAINS" +) + +type OrderField string + +const ( + OFHash OrderField = "hash" + OFChainID OrderField = "chainID" + OFExchangeAddress OrderField = "exchangeAddress" + OFMakerAddress OrderField = "makerAddress" + OFMakerAssetData OrderField = "makerAssetData" + OFMakerFeeAssetData OrderField = "makerFeeAssetData" + OFMakerAssetAmount OrderField = "makerAssetAmount" + OFMakerFee OrderField = "makerFee" + OFTakerAddress OrderField = "takerAddress" + OFTakerAssetData OrderField = "takerAssetData" + OFTakerFeeAssetData OrderField = "takerFeeAssetData" + OFTakerAssetAmount OrderField = "takerAssetAmount" + OFTakerFee OrderField = "takerFee" + OFSenderAddress OrderField = "senderAddress" + OFFeeRecipientAddress OrderField = "feeRecipientAddress" + OFExpirationTimeSeconds OrderField = "expirationTimeSeconds" + OFSalt OrderField = "salt" + OFSignature OrderField = "signature" + OFLastUpdated OrderField = "lastUpdated" + OFFillableTakerAssetAmount OrderField = "fillableTakerAssetAmount" + OFIsRemoved OrderField = "isRemoved" + OFIsPinned OrderField = "isPinned" + OFParsedMakerAssetData OrderField = "parsedMakerAssetData" + OFParsedMakerFeeAssetData OrderField = "parsedMakerFeeAssetData" +) + +type OrderQuery struct { + Filters []OrderFilter `json:"filters"` + Sort []OrderSort `json:"sort"` + Limit uint `json:"limit"` + Offset uint `json:"offset"` +} + +type OrderSort struct { + Field OrderField `json:"field"` + Direction SortDirection `json:"direction"` +} + +type OrderFilter struct { 
+ Field OrderField `json:"field"` + Kind FilterKind `json:"kind"` + Value interface{} `json:"value"` +} + +// MakerAssetIncludesTokenAddressAndTokenID is a helper method which returns a filter that will match orders +// that include the token address and token ID in MakerAssetData. +func MakerAssetIncludesTokenAddressAndTokenID(tokenAddress common.Address, tokenID *big.Int) OrderFilter { + return assetDataIncludesTokenAddressAndTokenID(OFParsedMakerAssetData, tokenAddress, tokenID) +} + +// MakerFeeAssetIncludesTokenAddressAndTokenID is a helper method which returns a filter that will match orders +// that include the token address and token ID in MakerFeeAssetData. +func MakerFeeAssetIncludesTokenAddressAndTokenID(tokenAddress common.Address, tokenID *big.Int) OrderFilter { + return assetDataIncludesTokenAddressAndTokenID(OFParsedMakerFeeAssetData, tokenAddress, tokenID) +} + +// MakerAssetIncludesTokenAddress is a helper method which returns a filter that will match orders +// that include the token address (and any token id, including null) in MakerAssetData. +func MakerAssetIncludesTokenAddress(tokenAddress common.Address) OrderFilter { + return assetDataIncludesTokenAddress(OFParsedMakerAssetData, tokenAddress) +} + +// MakerFeeAssetIncludesTokenAddress is a helper method which returns a filter that will match orders +// that include the token address (and any token id, including null) in MakerFeeAssetData. 
+func MakerFeeAssetIncludesTokenAddress(tokenAddress common.Address) OrderFilter { + return assetDataIncludesTokenAddress(OFParsedMakerFeeAssetData, tokenAddress) +} + +func assetDataIncludesTokenAddress(field OrderField, tokenAddress common.Address) OrderFilter { + tokenAddressJSON, err := canonicaljson.Marshal(tokenAddress) + if err != nil { + // big.Int and common.Address types should never return an error when marshaling to JSON + panic(err) + } + filterValue := fmt.Sprintf(`"address":%s`, tokenAddressJSON) + return OrderFilter{ + Field: field, + Kind: Contains, + Value: filterValue, + } +} + +type MiniHeaderField string + +const ( + MFHash MiniHeaderField = "hash" + MFParent MiniHeaderField = "parent" + MFNumber MiniHeaderField = "number" + MFTimestamp MiniHeaderField = "timestamp" + MFLogs MiniHeaderField = "logs" +) + +type MiniHeaderQuery struct { + Filters []MiniHeaderFilter `json:"filters"` + Sort []MiniHeaderSort `json:"sort"` + Limit uint `json:"limit"` + Offset uint `json:"offset"` +} + +type MiniHeaderSort struct { + Field MiniHeaderField `json:"field"` + Direction SortDirection `json:"direction"` +} + +type MiniHeaderFilter struct { + Field MiniHeaderField `json:"field"` + Kind FilterKind `json:"kind"` + Value interface{} `json:"value"` +} + +// GetLatestMiniHeader is a helper method for getting the latest MiniHeader. +// It returns ErrNotFound if there are no MiniHeaders in the database. 
+func (db *DB) GetLatestMiniHeader() (*types.MiniHeader, error) { + latestMiniHeaders, err := db.FindMiniHeaders(&MiniHeaderQuery{ + Sort: []MiniHeaderSort{ + { + Field: MFNumber, + Direction: Descending, + }, + }, + Limit: 1, + }) + if err != nil { + return nil, err + } + if len(latestMiniHeaders) == 0 { + return nil, ErrNotFound + } + return latestMiniHeaders[0], nil +} + +func ParseContractAddressesAndTokenIdsFromAssetData(assetDataDecoder *zeroex.AssetDataDecoder, assetData []byte, contractAddresses ethereum.ContractAddresses) ([]*types.SingleAssetData, error) { + if len(assetData) == 0 { + return []*types.SingleAssetData{}, nil + } + singleAssetDatas := []*types.SingleAssetData{} + + assetDataName, err := assetDataDecoder.GetName(assetData) + if err != nil { + return nil, err + } + switch assetDataName { + case "ERC20Token": + var decodedAssetData zeroex.ERC20AssetData + err := assetDataDecoder.Decode(assetData, &decodedAssetData) + if err != nil { + return nil, err + } + a := &types.SingleAssetData{ + Address: decodedAssetData.Address, + } + singleAssetDatas = append(singleAssetDatas, a) + case "ERC721Token": + var decodedAssetData zeroex.ERC721AssetData + err := assetDataDecoder.Decode(assetData, &decodedAssetData) + if err != nil { + return nil, err + } + a := &types.SingleAssetData{ + Address: decodedAssetData.Address, + TokenID: decodedAssetData.TokenId, + } + singleAssetDatas = append(singleAssetDatas, a) + case "ERC1155Assets": + var decodedAssetData zeroex.ERC1155AssetData + err := assetDataDecoder.Decode(assetData, &decodedAssetData) + if err != nil { + return nil, err + } + for _, id := range decodedAssetData.Ids { + a := &types.SingleAssetData{ + Address: decodedAssetData.Address, + TokenID: id, + } + singleAssetDatas = append(singleAssetDatas, a) + } + case "StaticCall": + var decodedAssetData zeroex.StaticCallAssetData + err := assetDataDecoder.Decode(assetData, &decodedAssetData) + if err != nil { + return nil, err + } + // NOTE(jalextowle): As 
of right now, none of the supported staticcalls + // have important information in the StaticCallData. We choose not to add + // `singleAssetData` because it would not be used. + case "MultiAsset": + var decodedAssetData zeroex.MultiAssetData + err := assetDataDecoder.Decode(assetData, &decodedAssetData) + if err != nil { + return nil, err + } + for _, assetData := range decodedAssetData.NestedAssetData { + as, err := ParseContractAddressesAndTokenIdsFromAssetData(assetDataDecoder, assetData, contractAddresses) + if err != nil { + return nil, err + } + singleAssetDatas = append(singleAssetDatas, as...) + } + case "ERC20Bridge": + var decodedAssetData zeroex.ERC20BridgeAssetData + err := assetDataDecoder.Decode(assetData, &decodedAssetData) + if err != nil { + return nil, err + } + tokenAddress := decodedAssetData.TokenAddress + // TODO(albrow): Update orderwatcher to account for this instead of storing + // it in the database. This would mean we can remove contractAddresses as an + // argument and simplify the implementation. Maybe even have the db package + // handle parsing asset data automatically. + // HACK(fabio): Despite Chai ERC20Bridge orders encoding the Dai address as + // the tokenAddress, we actually want to react to the Chai token's contract + // events, so we actually return it instead. 
+ if decodedAssetData.BridgeAddress == contractAddresses.ChaiBridge { + tokenAddress = contractAddresses.ChaiToken + } + a := &types.SingleAssetData{ + Address: tokenAddress, + } + singleAssetDatas = append(singleAssetDatas, a) + default: + return nil, fmt.Errorf("unrecognized assetData type name found: %s", assetDataName) + } + return singleAssetDatas, nil +} + +func checkOrderQuery(query *OrderQuery) error { + if query == nil { + return nil + } + if query.Offset != 0 && query.Limit == 0 { + return errors.New("can't use Offset without Limit") + } + return nil +} + +func checkMiniHeaderQuery(query *MiniHeaderQuery) error { + if query == nil { + return nil + } + if query.Offset != 0 && query.Limit == 0 { + return errors.New("can't use Offset without Limit") + } + return nil +} diff --git a/db/db.go b/db/db.go deleted file mode 100644 index dabee114a..000000000 --- a/db/db.go +++ /dev/null @@ -1,44 +0,0 @@ -package db - -import ( - "sync" - - "github.com/syndtr/goleveldb/leveldb" -) - -// Note about the implementation: -// -// There are two types of keys used. A "primary key" is the main key for a -// particular model. It's value is the encoded data for that model. The format -// for a primary key is: `model::`. -// -// An "index key" is used in queries to find models with specific indexed -// values. The format for an index key is: -// `index::::`. Unlike primary -// keys, index keys have no values and don't store any actual data. Instead, the -// primary key can be extracted from an index key and then used to look up the -// data for the corresponding model. - -// Model is any type which can be inserted and retrieved from the database. The -// only requirement is an ID method. Because the db package uses reflect to -// encode/decode models, only exported struct fields will be saved and retrieved -// from the database. -type Model interface { - // ID returns a unique identifier for this model. - ID() []byte -} - -// DB is the top-level Database. 
-type DB struct { - ldb *leveldb.DB - globalWriteLock sync.RWMutex - collections []*Collection - colLock sync.Mutex -} - -// Close closes the database. It is not safe to call Close if there are any -// other methods that have not yet returned. It is safe to call Close multiple -// times. -func (db *DB) Close() error { - return db.ldb.Close() -} diff --git a/db/db_test.go b/db/db_test.go index a1eaddcb5..e479404d8 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -1,31 +1,1897 @@ package db import ( + "bytes" + "context" + "encoding/json" + "fmt" + "math/big" + "math/rand" + "sort" + "strings" "testing" + "time" - "github.com/google/uuid" + "github.com/0xProject/0x-mesh/common/types" + "github.com/0xProject/0x-mesh/constants" + "github.com/0xProject/0x-mesh/ethereum" + "github.com/0xProject/0x-mesh/zeroex" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/math" + ethtypes "github.com/ethereum/go-ethereum/core/types" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -type testModel struct { - Name string - Age int - Nicknames []string +var contractAddresses = ethereum.GanacheAddresses + +func TestAddOrders(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + numOrders := 10 + orders := []*types.OrderWithMetadata{} + for i := 0; i < numOrders; i++ { + orders = append(orders, newTestOrder()) + } + + { + added, removed, err := db.AddOrders(orders) + require.NoError(t, err) + assert.Len(t, removed, 0, "Expected no orders to be removed") + assertOrderSlicesAreUnsortedEqual(t, orders, added) + } + { + added, removed, err := db.AddOrders(orders) + require.NoError(t, err) + assert.Len(t, removed, 0, "Expected no orders to be removed") + assert.Len(t, added, 0, "Expected no orders to be added (they should already exist)") + } +} + +func TestGetOrder(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := 
newTestDB(t, ctx) + + added, _, err := db.AddOrders([]*types.OrderWithMetadata{newTestOrder()}) + require.NoError(t, err) + originalOrder := added[0] + + foundOrder, err := db.GetOrder(originalOrder.Hash) + require.NoError(t, err) + require.NotNil(t, foundOrder, "found order should not be nil") + assertOrdersAreEqual(t, originalOrder, foundOrder) + + _, err = db.GetOrder(common.Hash{}) + assert.EqualError(t, err, ErrNotFound.Error(), "calling GetOrder with a hash that doesn't exist should return ErrNotFound") +} + +func TestUpdateOrder(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + err := db.UpdateOrder(common.Hash{}, func(existingOrder *types.OrderWithMetadata) (*types.OrderWithMetadata, error) { + return existingOrder, nil + }) + assert.EqualError(t, err, ErrNotFound.Error(), "calling UpdateOrder with a hash that doesn't exist should return ErrNotFound") + + // Note(albrow): We create more than one order to make sure that + // UpdateOrder only updates one of them and does not affect the + // others. 
+ numOrders := 3 + originalOrders := []*types.OrderWithMetadata{} + for i := 0; i < numOrders; i++ { + originalOrders = append(originalOrders, newTestOrder()) + } + _, _, err = db.AddOrders(originalOrders) + require.NoError(t, err) + + orderToUpdate := originalOrders[0] + updatedFillableAmount := big.NewInt(12345) + err = db.UpdateOrder(orderToUpdate.Hash, func(existingOrder *types.OrderWithMetadata) (*types.OrderWithMetadata, error) { + updatedOrder := existingOrder + updatedOrder.FillableTakerAssetAmount = updatedFillableAmount + return updatedOrder, nil + }) + require.NoError(t, err) + + expectedOrders := originalOrders + expectedOrders[0].FillableTakerAssetAmount = updatedFillableAmount + foundOrders, err := db.FindOrders(nil) + require.NoError(t, err) + assertOrderSlicesAreUnsortedEqual(t, expectedOrders, foundOrders) +} + +func TestFindOrders(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + numOrders := 10 + originalOrders := []*types.OrderWithMetadata{} + for i := 0; i < numOrders; i++ { + originalOrders = append(originalOrders, newTestOrder()) + } + _, _, err := db.AddOrders(originalOrders) + require.NoError(t, err) + + foundOrders, err := db.FindOrders(nil) + require.NoError(t, err) + assertOrderSlicesAreUnsortedEqual(t, originalOrders, foundOrders) +} + +func TestFindOrdersSort(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + // Create some test orders with carefully chosen MakerAssetAmount + // and TakerAssetAmount values for testing sorting. + numOrders := 5 + originalOrders := []*types.OrderWithMetadata{} + for i := 0; i < numOrders; i++ { + order := newTestOrder() + order.MakerAssetAmount = big.NewInt(int64(i)) + // It's important for some orders to have the same TakerAssetAmount + // so that we can test secondary sorts (sorting on more than one + // field). 
+ if i%2 == 0 { + order.TakerAssetAmount = big.NewInt(100) + } else { + order.TakerAssetAmount = big.NewInt(200) + } + originalOrders = append(originalOrders, order) + } + _, _, err := db.AddOrders(originalOrders) + require.NoError(t, err) + + testCases := []findOrdersSortTestCase{ + { + sortOpts: []OrderSort{ + { + Field: OFMakerAssetAmount, + Direction: Ascending, + }, + }, + less: lessByMakerAssetAmountAsc, + }, + { + sortOpts: []OrderSort{ + { + Field: OFMakerAssetAmount, + Direction: Descending, + }, + }, + less: lessByMakerAssetAmountDesc, + }, + { + sortOpts: []OrderSort{ + { + Field: OFTakerAssetAmount, + Direction: Ascending, + }, + { + Field: OFMakerAssetAmount, + Direction: Ascending, + }, + }, + less: lessByTakerAssetAmountAscAndMakerAssetAmountAsc, + }, + { + sortOpts: []OrderSort{ + { + Field: OFTakerAssetAmount, + Direction: Descending, + }, + { + Field: OFMakerAssetAmount, + Direction: Descending, + }, + }, + less: lessByTakerAssetAmountDescAndMakerAssetAmountDesc, + }, + } + for i, testCase := range testCases { + testCaseName := fmt.Sprintf("test case %d", i) + t.Run(testCaseName, runFindOrdersSortTestCase(t, db, originalOrders, testCase)) + } +} + +type findOrdersSortTestCase struct { + sortOpts []OrderSort + less func([]*types.OrderWithMetadata) func(i, j int) bool +} + +func runFindOrdersSortTestCase(t *testing.T, db *DB, originalOrders []*types.OrderWithMetadata, testCase findOrdersSortTestCase) func(t *testing.T) { + return func(t *testing.T) { + expectedOrders := make([]*types.OrderWithMetadata, len(originalOrders)) + copy(expectedOrders, originalOrders) + sort.Slice(expectedOrders, testCase.less(expectedOrders)) + findOpts := &OrderQuery{ + Sort: testCase.sortOpts, + } + foundOrders, err := db.FindOrders(findOpts) + require.NoError(t, err) + assertOrderSlicesAreEqual(t, expectedOrders, foundOrders) + } +} + +func TestFindOrdersLimitAndOffset(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := 
newTestDB(t, ctx) + + numOrders := 10 + originalOrders := []*types.OrderWithMetadata{} + for i := 0; i < numOrders; i++ { + originalOrders = append(originalOrders, newTestOrder()) + } + _, _, err := db.AddOrders(originalOrders) + require.NoError(t, err) + sortOrdersByHash(originalOrders) + + testCases := []findOrdersLimitAndOffsetTestCase{ + { + limit: 0, + offset: 0, + expectedOrders: originalOrders, + }, + { + limit: 3, + offset: 0, + expectedOrders: originalOrders[:3], + }, + { + limit: 0, + offset: 3, + expectedError: "can't use Offset without Limit", + }, + { + limit: 10, + offset: 3, + expectedOrders: originalOrders[3:], + }, + { + limit: 4, + offset: 3, + expectedOrders: originalOrders[3:7], + }, + { + limit: 10, + offset: 10, + expectedOrders: []*types.OrderWithMetadata{}, + }, + } + for i, testCase := range testCases { + testCaseName := fmt.Sprintf("test case %d", i) + t.Run(testCaseName, runFindOrdersLimitAndOffsetTestCase(t, db, originalOrders, testCase)) + } +} + +type findOrdersLimitAndOffsetTestCase struct { + limit uint + offset uint + expectedOrders []*types.OrderWithMetadata + expectedError string +} + +func runFindOrdersLimitAndOffsetTestCase(t *testing.T, db *DB, originalOrders []*types.OrderWithMetadata, testCase findOrdersLimitAndOffsetTestCase) func(t *testing.T) { + return func(t *testing.T) { + findOpts := &OrderQuery{ + Sort: []OrderSort{ + { + Field: OFHash, + Direction: Ascending, + }, + }, + Limit: testCase.limit, + Offset: testCase.offset, + } + + foundOrders, err := db.FindOrders(findOpts) + if testCase.expectedError != "" { + require.Error(t, err, "expected an error but got nil") + assert.Contains(t, err.Error(), testCase.expectedError, "wrong error message") + } else { + require.NoError(t, err) + assertOrderSlicesAreEqual(t, testCase.expectedOrders, foundOrders) + } + } +} + +func TestFindOrdersFilter(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + _, testCases := 
makeOrderFilterTestCases(t, db) + + for i, testCase := range testCases { + testCaseName := fmt.Sprintf("%s (test case %d)", testCase.name, i) + t.Run(testCaseName, runFindOrdersFilterTestCase(t, db, testCase)) + } +} + +func TestFindOrdersFilterSortLimitAndOffset(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + storedOrders := createAndStoreOrdersForFilterTests(t, db) + + query := &OrderQuery{ + Filters: []OrderFilter{ + { + Field: OFMakerAssetAmount, + Kind: GreaterOrEqual, + Value: big.NewInt(3), + }, + }, + Sort: []OrderSort{ + { + Field: OFMakerAssetAmount, + Direction: Ascending, + }, + }, + Limit: 3, + Offset: 2, + } + expectedOrders := storedOrders[5:8] + actualOrders, err := db.FindOrders(query) + require.NoError(t, err) + assertOrderSlicesAreEqual(t, expectedOrders, actualOrders) +} + +func runFindOrdersFilterTestCase(t *testing.T, db *DB, testCase orderFilterTestCase) func(t *testing.T) { + return func(t *testing.T) { + findOpts := &OrderQuery{ + Filters: testCase.filters, + } + foundOrders, err := db.FindOrders(findOpts) + require.NoError(t, err) + assertOrderSlicesAreUnsortedEqual(t, testCase.expectedMatchingOrders, foundOrders) + } +} + +func TestCountOrdersFilter(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + _, testCases := makeOrderFilterTestCases(t, db) + + for i, testCase := range testCases { + testCaseName := fmt.Sprintf("%s (test case %d)", testCase.name, i) + t.Run(testCaseName, runCountOrdersFilterTestCase(t, db, testCase)) + } +} + +func runCountOrdersFilterTestCase(t *testing.T, db *DB, testCase orderFilterTestCase) func(t *testing.T) { + return func(t *testing.T) { + opts := &OrderQuery{ + Filters: testCase.filters, + } + + count, err := db.CountOrders(opts) + require.NoError(t, err) + require.Equal(t, len(testCase.expectedMatchingOrders), count, "wrong number of orders") + } +} + +func 
TestDeleteOrder(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + added, _, err := db.AddOrders([]*types.OrderWithMetadata{newTestOrder()}) + require.NoError(t, err) + originalOrder := added[0] + require.NoError(t, db.DeleteOrder(originalOrder.Hash)) + + foundOrders, err := db.FindOrders(nil) + require.NoError(t, err) + assert.Empty(t, foundOrders, "expected no orders remaining in the database") +} + +func TestDeleteOrdersLimitAndOffset(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + // Create orders with increasing makerAssetAmount. + // - orders[0].MakerAssetAmount = 0 + // - orders[1].MakerAssetAmount = 1 + // - etc. + numOrders := 10 + originalOrders := []*types.OrderWithMetadata{} + for i := 0; i < numOrders; i++ { + testOrder := newTestOrder() + testOrder.MakerAssetAmount = big.NewInt(int64(i)) + originalOrders = append(originalOrders, testOrder) + } + _, _, err := db.AddOrders(originalOrders) + require.NoError(t, err) + + // Call DeleteOrders and make sure the return value is what we expect. + deletedOrders, err := db.DeleteOrders(&OrderQuery{ + Sort: []OrderSort{ + { + Field: OFMakerAssetAmount, + Direction: Ascending, + }, + }, + Offset: 3, + Limit: 4, + }) + require.NoError(t, err) + assertOrderSlicesAreEqual(t, originalOrders[3:7], deletedOrders) + + // Call FindOrders to check that the remaining orders in the db are + // what we expect. 
+ expectedRemainingOrders := append( + safeSubsliceOrders(originalOrders, 0, 3), + safeSubsliceOrders(originalOrders, 7, 10)..., + ) + actualRemainingOrders, err := db.FindOrders(nil) + require.NoError(t, err) + assertOrderSlicesAreUnsortedEqual(t, expectedRemainingOrders, actualRemainingOrders) } -func (tm *testModel) ID() []byte { - return []byte(tm.Name) +func TestDeleteOrdersFilter(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + storedOrders, testCases := makeOrderFilterTestCases(t, db) + for i, testCase := range testCases { + testCaseName := fmt.Sprintf("%s (test case %d)", testCase.name, i) + t.Run(testCaseName, runDeleteOrdersFilterTestCase(t, db, storedOrders, testCase)) + } } -func newTestDB(t require.TestingT) *DB { - db, err := Open("/tmp/leveldb_testing/" + uuid.New().String()) +func runDeleteOrdersFilterTestCase(t *testing.T, db *DB, originalOrders []*types.OrderWithMetadata, testCase orderFilterTestCase) func(t *testing.T) { + return func(t *testing.T) { + defer func() { + // After each case, reset the state of the database by re-adding the original orders. + _, _, err := db.AddOrders(originalOrders) + require.NoError(t, err) + }() + + deleteOpts := &OrderQuery{ + Filters: testCase.filters, + } + deletedOrders, err := db.DeleteOrders(deleteOpts) + assertOrderSlicesAreUnsortedEqual(t, testCase.expectedMatchingOrders, deletedOrders) + + // Figure out which orders should still remain in the database, then + // call FindOrders and make sure we get back what we expect. 
+ expectedRemainingOrders := []*types.OrderWithMetadata{} + for _, order := range originalOrders { + shouldBeRemaining := true + for _, remainingOrder := range testCase.expectedMatchingOrders { + if order.Hash.Hex() == remainingOrder.Hash.Hex() { + shouldBeRemaining = false + break + } + } + if shouldBeRemaining { + expectedRemainingOrders = append(expectedRemainingOrders, order) + } + } + require.NoError(t, err) + actualRemainingOrders, err := db.FindOrders(nil) + require.NoError(t, err) + assertOrderSlicesAreUnsortedEqual(t, expectedRemainingOrders, actualRemainingOrders) + } +} + +func TestAddMiniHeaders(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + dbOpts := TestOptions() + db, err := New(ctx, dbOpts) + require.NoError(t, err) + + numMiniHeaders := dbOpts.MaxMiniHeaders + miniHeaders := []*types.MiniHeader{} + for i := 0; i < numMiniHeaders; i++ { + // It's important to note that each miniHeader has a increasing + // blockNumber. Later we will add more miniHeaders with higher numbers. + miniHeader := newTestMiniHeader() + miniHeader.Number = big.NewInt(int64(i)) + miniHeaders = append(miniHeaders, miniHeader) + } + { + added, removed, err := db.AddMiniHeaders(miniHeaders) + require.NoError(t, err) + assert.Len(t, removed, 0, "Expected no miniHeaders to be removed") + assertMiniHeaderSlicesAreUnsortedEqual(t, miniHeaders, added) + } + { + added, removed, err := db.AddMiniHeaders(miniHeaders) + require.NoError(t, err) + assert.Len(t, removed, 0, "Expected no miniHeaders to be removed") + assert.Len(t, added, 0, "Expected no miniHeaders to be added (they should already exist)") + } + + // Create 10 more mini headers with higher block numbers. + miniHeadersWithHigherBlockNumbers := []*types.MiniHeader{} + for i := dbOpts.MaxMiniHeaders; i < dbOpts.MaxMiniHeaders+10; i++ { + // It's important to note that each miniHeader has a increasing + // blockNumber. Later will add more miniHeaders with higher numbers. 
+ miniHeader := newTestMiniHeader() + miniHeader.Number = big.NewInt(int64(i)) + miniHeadersWithHigherBlockNumbers = append(miniHeadersWithHigherBlockNumbers, miniHeader) + } + { + added, removed, err := db.AddMiniHeaders(miniHeadersWithHigherBlockNumbers) + require.NoError(t, err) + assertMiniHeaderSlicesAreUnsortedEqual(t, miniHeadersWithHigherBlockNumbers, added) + assertMiniHeaderSlicesAreUnsortedEqual(t, miniHeaders[:10], removed) + } +} + +func TestGetMiniHeader(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + added, _, err := db.AddMiniHeaders([]*types.MiniHeader{newTestMiniHeader()}) require.NoError(t, err) + originalMiniHeader := added[0] + + foundMiniHeader, err := db.GetMiniHeader(originalMiniHeader.Hash) + require.NoError(t, err) + assertMiniHeadersAreEqual(t, originalMiniHeader, foundMiniHeader) + + _, err = db.GetMiniHeader(common.Hash{}) + assert.EqualError(t, err, ErrNotFound.Error(), "calling GetMiniHeader with a hash that doesn't exist should return ErrNotFound") +} + +func TestGetLatestMiniHeader(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + numMiniHeaders := 3 + storedMiniHeaders := []*types.MiniHeader{} + for i := 0; i < numMiniHeaders; i++ { + miniHeader := newTestMiniHeader() + miniHeader.Number = big.NewInt(int64(i)) + storedMiniHeaders = append(storedMiniHeaders, miniHeader) + } + _, _, err := db.AddMiniHeaders(storedMiniHeaders) + require.NoError(t, err) + + foundMiniHeader, err := db.GetLatestMiniHeader() + require.NoError(t, err) + assertMiniHeadersAreEqual(t, storedMiniHeaders[2], foundMiniHeader) +} + +func TestFindMiniHeaders(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + numMiniHeaders := 10 + originalMiniHeaders := []*types.MiniHeader{} + for i := 0; i < numMiniHeaders; i++ { + originalMiniHeaders = 
append(originalMiniHeaders, newTestMiniHeader()) + } + _, _, err := db.AddMiniHeaders(originalMiniHeaders) + require.NoError(t, err) + + foundMiniHeaders, err := db.FindMiniHeaders(nil) + require.NoError(t, err) + assertMiniHeaderSlicesAreUnsortedEqual(t, originalMiniHeaders, foundMiniHeaders) +} + +func TestFindMiniHeadersSort(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + // Create some test miniHeaders with carefully chosen Number and Timestamp + // values for testing sorting. + numMiniHeaders := 5 + originalMiniHeaders := []*types.MiniHeader{} + for i := 0; i < numMiniHeaders; i++ { + miniHeader := newTestMiniHeader() + miniHeader.Number = big.NewInt(int64(i)) + // It's important for some miniHeaders to have the same Timestamp + // so that we can test secondary sorts (sorting on more than one + // field). + if i%2 == 0 { + miniHeader.Timestamp = time.Unix(717793653, 0) + } else { + miniHeader.Timestamp = time.Unix(1588194484, 0) + } + originalMiniHeaders = append(originalMiniHeaders, miniHeader) + } + _, _, err := db.AddMiniHeaders(originalMiniHeaders) + require.NoError(t, err) + + testCases := []findMiniHeadersSortTestCase{ + { + sortOpts: []MiniHeaderSort{ + { + Field: MFNumber, + Direction: Ascending, + }, + }, + less: lessByNumberAsc, + }, + { + sortOpts: []MiniHeaderSort{ + { + Field: MFNumber, + Direction: Descending, + }, + }, + less: lessByNumberDesc, + }, + { + sortOpts: []MiniHeaderSort{ + { + Field: MFTimestamp, + Direction: Ascending, + }, + { + Field: MFNumber, + Direction: Ascending, + }, + }, + less: lessByTimestampAscAndNumberAsc, + }, + { + sortOpts: []MiniHeaderSort{ + { + Field: MFTimestamp, + Direction: Descending, + }, + { + Field: MFNumber, + Direction: Descending, + }, + }, + less: lessByTimestampDescAndNumberDesc, + }, + } + for i, testCase := range testCases { + testCaseName := fmt.Sprintf("test case %d", i) + t.Run(testCaseName, runFindMiniHeadersSortTestCase(t, 
db, originalMiniHeaders, testCase)) + } +} + +type findMiniHeadersSortTestCase struct { + sortOpts []MiniHeaderSort + less func([]*types.MiniHeader) func(i, j int) bool +} + +func runFindMiniHeadersSortTestCase(t *testing.T, db *DB, originalMiniHeaders []*types.MiniHeader, testCase findMiniHeadersSortTestCase) func(t *testing.T) { + return func(t *testing.T) { + expectedMiniHeaders := make([]*types.MiniHeader, len(originalMiniHeaders)) + copy(expectedMiniHeaders, originalMiniHeaders) + sort.Slice(expectedMiniHeaders, testCase.less(expectedMiniHeaders)) + findOpts := &MiniHeaderQuery{ + Sort: testCase.sortOpts, + } + foundMiniHeaders, err := db.FindMiniHeaders(findOpts) + require.NoError(t, err) + assertMiniHeaderSlicesAreEqual(t, expectedMiniHeaders, foundMiniHeaders) + } +} + +func TestFindMiniHeadersLimitAndOffset(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + numMiniHeaders := 10 + originalMiniHeaders := []*types.MiniHeader{} + for i := 0; i < numMiniHeaders; i++ { + originalMiniHeaders = append(originalMiniHeaders, newTestMiniHeader()) + } + _, _, err := db.AddMiniHeaders(originalMiniHeaders) + require.NoError(t, err) + sortMiniHeadersByHash(originalMiniHeaders) + + testCases := []findMiniHeadersLimitAndOffsetTestCase{ + { + limit: 0, + offset: 0, + expectedMiniHeaders: originalMiniHeaders, + }, + { + limit: 3, + offset: 0, + expectedMiniHeaders: originalMiniHeaders[:3], + }, + { + limit: 0, + offset: 3, + expectedError: "can't use Offset without Limit", + }, + { + limit: 10, + offset: 3, + expectedMiniHeaders: originalMiniHeaders[3:], + }, + { + limit: 4, + offset: 3, + expectedMiniHeaders: originalMiniHeaders[3:7], + }, + { + limit: 10, + offset: 10, + expectedMiniHeaders: []*types.MiniHeader{}, + }, + } + for i, testCase := range testCases { + testCaseName := fmt.Sprintf("test case %d", i) + t.Run(testCaseName, runFindMiniHeadersLimitAndOffsetTestCase(t, db, originalMiniHeaders, 
testCase)) + } +} + +type findMiniHeadersLimitAndOffsetTestCase struct { + limit uint + offset uint + expectedMiniHeaders []*types.MiniHeader + expectedError string +} + +func runFindMiniHeadersLimitAndOffsetTestCase(t *testing.T, db *DB, originalMiniHeaders []*types.MiniHeader, testCase findMiniHeadersLimitAndOffsetTestCase) func(t *testing.T) { + return func(t *testing.T) { + findOpts := &MiniHeaderQuery{ + Sort: []MiniHeaderSort{ + { + Field: MFHash, + Direction: Ascending, + }, + }, + Limit: testCase.limit, + Offset: testCase.offset, + } + + foundMiniHeaders, err := db.FindMiniHeaders(findOpts) + if testCase.expectedError != "" { + require.Error(t, err, "expected an error but got nil") + assert.Contains(t, err.Error(), testCase.expectedError, "wrong error message") + } else { + require.NoError(t, err) + assertMiniHeaderSlicesAreEqual(t, testCase.expectedMiniHeaders, foundMiniHeaders) + } + } +} + +func TestFindMiniHeadersFilter(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + _, testCases := makeMiniHeaderFilterTestCases(t, db) + for i, testCase := range testCases { + testCaseName := fmt.Sprintf("%s (test case %d)", testCase.name, i) + t.Run(testCaseName, runFindMiniHeadersFilterTestCase(t, db, testCase)) + } +} + +func runFindMiniHeadersFilterTestCase(t *testing.T, db *DB, testCase miniHeaderFilterTestCase) func(t *testing.T) { + return func(t *testing.T) { + findOpts := &MiniHeaderQuery{ + Filters: testCase.filters, + } + + foundMiniHeaders, err := db.FindMiniHeaders(findOpts) + require.NoError(t, err) + assertMiniHeaderSlicesAreUnsortedEqual(t, testCase.expectedMatchingMiniHeaders, foundMiniHeaders) + } +} + +func TestDeleteMiniHeader(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + added, _, err := db.AddMiniHeaders([]*types.MiniHeader{newTestMiniHeader()}) + require.NoError(t, err) + originalMiniHeader := added[0] 
+ require.NoError(t, db.DeleteMiniHeader(originalMiniHeader.Hash)) + + foundMiniHeaders, err := db.FindMiniHeaders(nil) + require.NoError(t, err) + assert.Empty(t, foundMiniHeaders, "expected no miniHeaders remaining in the database") +} + +func TestDeleteMiniHeadersLimitAndOffset(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + // Create miniHeaders with increasing Number. + // - miniHeaders[0].Number = 0 + // - miniHeaders[1].Number = 1 + // - etc. + numMiniHeaders := 10 + originalMiniHeaders := []*types.MiniHeader{} + for i := 0; i < numMiniHeaders; i++ { + testMiniHeader := newTestMiniHeader() + testMiniHeader.Number = big.NewInt(int64(i)) + originalMiniHeaders = append(originalMiniHeaders, testMiniHeader) + } + _, _, err := db.AddMiniHeaders(originalMiniHeaders) + require.NoError(t, err) + + // Call DeleteMiniHeaders and make sure the return value is what we expect. + deletedMiniHeaders, err := db.DeleteMiniHeaders(&MiniHeaderQuery{ + Sort: []MiniHeaderSort{ + { + Field: MFNumber, + Direction: Ascending, + }, + }, + Offset: 3, + Limit: 4, + }) + require.NoError(t, err) + assertMiniHeaderSlicesAreEqual(t, originalMiniHeaders[3:7], deletedMiniHeaders) + + // Call FindMiniHeaders to check that the remaining orders in the db are + // what we expect. 
+ expectedRemainingMiniHeaders := append( + safeSubsliceMiniHeaders(originalMiniHeaders, 0, 3), + safeSubsliceMiniHeaders(originalMiniHeaders, 7, 10)..., + ) + actualRemainingMiniHeaders, err := db.FindMiniHeaders(nil) + require.NoError(t, err) + assertMiniHeaderSlicesAreUnsortedEqual(t, expectedRemainingMiniHeaders, actualRemainingMiniHeaders) +} + +func TestDeleteMiniHeadersFilter(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + storedMiniHeaders, testCases := makeMiniHeaderFilterTestCases(t, db) + + for i, testCase := range testCases { + testCaseName := fmt.Sprintf("%s (test case %d)", testCase.name, i) + t.Run(testCaseName, runDeleteMiniHeadersFilterTestCase(t, db, storedMiniHeaders, testCase)) + } +} + +func runDeleteMiniHeadersFilterTestCase(t *testing.T, db *DB, storedMiniHeaders []*types.MiniHeader, testCase miniHeaderFilterTestCase) func(t *testing.T) { + return func(t *testing.T) { + defer func() { + // After each case, reset the state of the database by re-adding the original miniHeaders. + _, _, err := db.AddMiniHeaders(storedMiniHeaders) + require.NoError(t, err) + }() + + findOpts := &MiniHeaderQuery{ + Filters: testCase.filters, + } + deletedMiniHeaders, err := db.DeleteMiniHeaders(findOpts) + require.NoError(t, err) + assertMiniHeaderSlicesAreUnsortedEqual(t, testCase.expectedMatchingMiniHeaders, deletedMiniHeaders) + + // Calculate expected remaining miniheaders and make sure that each one is still + // in the database. 
+ expectedRemainingMiniHeaders := []*types.MiniHeader{} + for _, miniHeader := range storedMiniHeaders { + shouldBeRemaining := true + for _, remainingMiniHeader := range testCase.expectedMatchingMiniHeaders { + if miniHeader.Hash.Hex() == remainingMiniHeader.Hash.Hex() { + shouldBeRemaining = false + break + } + } + if shouldBeRemaining { + expectedRemainingMiniHeaders = append(expectedRemainingMiniHeaders, miniHeader) + } + } + + remainingMiniHeaders, err := db.FindMiniHeaders(nil) + require.NoError(t, err) + assertMiniHeaderSlicesAreUnsortedEqual(t, expectedRemainingMiniHeaders, remainingMiniHeaders) + } +} + +func TestSaveMetadata(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + err := db.SaveMetadata(newTestMetadata()) + require.NoError(t, err) + + // Attempting to save metadata when it already exists in the database should return an error. + err = db.SaveMetadata(newTestMetadata()) + assert.EqualError(t, err, ErrMetadataAlreadyExists.Error()) +} + +func TestGetMetadata(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + _, err := db.GetMetadata() + assert.EqualError(t, err, ErrNotFound.Error(), "calling GetMetadata when it hasn't been saved yet should return ErrNotFound") + + originalMetadata := newTestMetadata() + err = db.SaveMetadata(originalMetadata) + require.NoError(t, err) + + foundMetadata, err := db.GetMetadata() + require.NoError(t, err) + require.NotNil(t, foundMetadata, "found order should not be nil") + assertMetadatasAreEqual(t, originalMetadata, foundMetadata) +} + +func TestUpdateMetadata(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db := newTestDB(t, ctx) + + err := db.UpdateMetadata(func(existingMetadata *types.Metadata) *types.Metadata { + return existingMetadata + }) + assert.EqualError(t, err, ErrNotFound.Error(), "calling UpdateMetadata when it hasn't 
been saved yet should return ErrNotFound") + + originalMetadata := newTestMetadata() + err = db.SaveMetadata(originalMetadata) + require.NoError(t, err) + + updatedMaxExpirationTime := originalMetadata.MaxExpirationTime.Add(originalMetadata.MaxExpirationTime, big.NewInt(500)) + err = db.UpdateMetadata(func(existingMetadata *types.Metadata) *types.Metadata { + updatedMetadata := existingMetadata + updatedMetadata.MaxExpirationTime = updatedMaxExpirationTime + return updatedMetadata + }) + + expectedMetadata := originalMetadata + expectedMetadata.MaxExpirationTime = updatedMaxExpirationTime + foundMetadata, err := db.GetMetadata() + require.NoError(t, err) + assertMetadatasAreEqual(t, expectedMetadata, foundMetadata) +} + +func TestParseContractAddressesAndTokenIdsFromAssetData(t *testing.T) { + assetDataDecoder := zeroex.NewAssetDataDecoder() + // ERC20 AssetData + erc20AssetData := common.Hex2Bytes("f47261b000000000000000000000000038ae374ecf4db50b0ff37125b591a04997106a32") + parsedAssetData, err := ParseContractAddressesAndTokenIdsFromAssetData(assetDataDecoder, erc20AssetData, contractAddresses) + require.NoError(t, err) + assert.Len(t, parsedAssetData, 1) + expectedAddress := common.HexToAddress("0x38ae374ecf4db50b0ff37125b591a04997106a32") + assert.Equal(t, expectedAddress, parsedAssetData[0].Address) + var expectedTokenID *big.Int = nil + assert.Equal(t, expectedTokenID, parsedAssetData[0].TokenID) + + // ERC721 AssetData + erc721AssetData := common.Hex2Bytes("025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001") + parsedAssetData, err = ParseContractAddressesAndTokenIdsFromAssetData(assetDataDecoder, erc721AssetData, contractAddresses) + require.NoError(t, err) + assert.Equal(t, 1, len(parsedAssetData)) + expectedAddress = common.HexToAddress("0x1dC4c1cEFEF38a777b15aA20260a54E584b16C48") + assert.Equal(t, expectedAddress, parsedAssetData[0].Address) + expectedTokenID = 
big.NewInt(1) + assert.Equal(t, expectedTokenID, parsedAssetData[0].TokenID) + + // Multi AssetData + multiAssetData := common.Hex2Bytes("94cfcdd7000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000004600000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000024f47261b00000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c48000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000044025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000x94cfcdd7000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000004600000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000024f47261b00000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c48000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000044025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c4800000000000000000000000000000000000000000000
0000000000000000000100000000000000000000000000000000000000000000000000000000") + parsedAssetData, err = ParseContractAddressesAndTokenIdsFromAssetData(assetDataDecoder, multiAssetData, contractAddresses) + require.NoError(t, err) + assert.Equal(t, 2, len(parsedAssetData)) + expectedParsedAssetData := []*types.SingleAssetData{ + { + Address: common.HexToAddress("0x1dc4c1cefef38a777b15aa20260a54e584b16c48"), + }, + { + Address: common.HexToAddress("0x1dc4c1cefef38a777b15aa20260a54e584b16c48"), + TokenID: big.NewInt(1), + }, + } + for i, singleAssetData := range parsedAssetData { + expectedSingleAssetData := expectedParsedAssetData[i] + assert.Equal(t, expectedSingleAssetData.Address, singleAssetData.Address) + assert.Equal(t, expectedSingleAssetData.TokenID, singleAssetData.TokenID) + } +} + +func newTestDB(t *testing.T, ctx context.Context) *DB { + db, err := New(ctx, TestOptions()) + require.NoError(t, err) + count, err := db.CountOrders(nil) + require.NoError(t, err) + require.Equal(t, count, 0, "there should be no orders stored in a brand new database") return db } -func TestOpen(t *testing.T) { - t.Parallel() - db, err := Open("/tmp/leveldb_testing") +// newTestOrder returns a new order with a random hash that is ready to insert +// into the database. Some computed fields (e.g. hash, signature) may not be +// correct, so the order will not pass 0x validation. 
+func newTestOrder() *types.OrderWithMetadata { + return &types.OrderWithMetadata{ + Hash: common.BigToHash(big.NewInt(int64(rand.Int()))), + ChainID: big.NewInt(constants.TestChainID), + MakerAddress: constants.GanacheAccount1, + TakerAddress: constants.NullAddress, + SenderAddress: constants.NullAddress, + FeeRecipientAddress: constants.NullAddress, + MakerAssetData: constants.ZRXAssetData, + MakerFeeAssetData: constants.NullBytes, + TakerAssetData: constants.WETHAssetData, + TakerFeeAssetData: constants.NullBytes, + Salt: big.NewInt(int64(time.Now().Nanosecond())), + MakerFee: big.NewInt(0), + TakerFee: big.NewInt(0), + MakerAssetAmount: math.MaxBig256, + TakerAssetAmount: big.NewInt(42), + ExpirationTimeSeconds: big.NewInt(time.Now().Add(24 * time.Hour).Unix()), + ExchangeAddress: contractAddresses.Exchange, + Signature: []byte{1, 2, 255, 255}, + LastUpdated: time.Now(), + FillableTakerAssetAmount: big.NewInt(42), + IsRemoved: false, + IsPinned: true, + ParsedMakerAssetData: []*types.SingleAssetData{ + { + Address: constants.GanacheDummyERC721TokenAddress, + TokenID: big.NewInt(10), + }, + { + Address: constants.GanacheDummyERC721TokenAddress, + TokenID: big.NewInt(20), + }, + { + Address: constants.GanacheDummyERC721TokenAddress, + TokenID: big.NewInt(30), + }, + }, + ParsedMakerFeeAssetData: []*types.SingleAssetData{ + { + Address: constants.GanacheDummyERC1155MintableAddress, + TokenID: big.NewInt(567), + }, + }, + } +} + +func newTestMiniHeader() *types.MiniHeader { + return &types.MiniHeader{ + Hash: common.BigToHash(big.NewInt(int64(rand.Int()))), + Parent: common.BigToHash(big.NewInt(int64(rand.Int()))), + Number: big.NewInt(int64(rand.Int())), + Timestamp: time.Now(), + Logs: newTestEventLogs(), + } +} + +func newTestEventLogs() []ethtypes.Log { + return []ethtypes.Log{ + { + Address: common.HexToAddress("0x21ab6c9fac80c59d401b37cb43f81ea9dde7fe34"), + Topics: []common.Hash{ + 
common.HexToHash("0xddf252ad1be2c89b69c2b068fc378daa952ba7f163c4a11628f55a4df523b3ef"), + common.HexToHash("0x0000000000000000000000004d8a4aa1f304f9632cf3877473445d85c577fe5d"), + common.HexToHash("0x0000000000000000000000004bdd0d16cfa18e33860470fc4d65c6f5cee60959"), + }, + Data: common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000337ad34c0"), + BlockNumber: 30, + TxHash: common.HexToHash("0xd9bb5f9e888ee6f74bedcda811c2461230f247c205849d6f83cb6c3925e54586"), + TxIndex: 0, + BlockHash: common.HexToHash("0x6bbf9b6e836207ab25379c20e517a89090cbbaf8877746f6ed7fb6820770816b"), + Index: 0, + Removed: false, + }, + { + Address: common.HexToAddress("0x21ab6c9fac80c59d401b37cb43f81ea9dde7fe34"), + Topics: []common.Hash{ + common.HexToHash("0xddf252ad1be2c89b69c2b068fc378daa952ba7f163c4a11628f55a4df523b3ef"), + common.HexToHash("0x0000000000000000000000004d8a4aa1f304f9632cf3877473445d85c577fe5d"), + common.HexToHash("0x0000000000000000000000004bdd0d16cfa18e33860470fc4d65c6f5cee60959"), + }, + Data: common.Hex2Bytes("00000000000000000000000000000000000000000000000000000000deadbeef"), + BlockNumber: 31, + TxHash: common.HexToHash("0xd9bb5f9e888ee6f74bedcda811c2461230f247c205849d6f83cb6c3925e54586"), + TxIndex: 1, + BlockHash: common.HexToHash("0x6bbf9b6e836207ab25379c20e517a89090cbbaf8877746f6ed7fb6820770816b"), + Index: 2, + Removed: true, + }, + } +} + +func newTestMetadata() *types.Metadata { + return &types.Metadata{ + EthereumChainID: 42, + MaxExpirationTime: big.NewInt(12345), + EthRPCRequestsSentInCurrentUTCDay: 1337, + StartOfCurrentUTCDay: time.Date(1992, time.September, 29, 8, 0, 0, 0, time.UTC), + } +} + +type orderFilterTestCase struct { + name string + filters []OrderFilter + expectedMatchingOrders []*types.OrderWithMetadata +} + +func createAndStoreOrdersForFilterTests(t *testing.T, db *DB) []*types.OrderWithMetadata { + // Create some test orders with very specific characteristics to make it easier to write tests. 
+ // - Both MakerAssetAmount and TakerAssetAmount will be 0, 1, 2, etc. + // - MakerAssetData will be 'a', 'b', 'c', etc. + // - ParsedMakerAssetData will always be for the ERC721Dummy contract, and each will contain + // two token ids: (0, 1), (0, 11), (0, 21), (0, 31) etc. + numOrders := 10 + storedOrders := []*types.OrderWithMetadata{} + for i := 0; i < numOrders; i++ { + order := newTestOrder() + order.MakerAssetAmount = big.NewInt(int64(i)) + order.TakerAssetAmount = big.NewInt(int64(i)) + order.MakerAssetData = []byte{97 + byte(i)} + parsedMakerAssetData := []*types.SingleAssetData{ + { + Address: constants.GanacheDummyERC721TokenAddress, + TokenID: big.NewInt(0), + }, + { + Address: constants.GanacheDummyERC721TokenAddress, + TokenID: big.NewInt(int64(i)*10 + 1), + }, + } + order.ParsedMakerAssetData = parsedMakerAssetData + storedOrders = append(storedOrders, order) + } + _, _, err := db.AddOrders(storedOrders) + require.NoError(t, err) + return storedOrders +} + +func makeOrderFilterTestCases(t *testing.T, db *DB) ([]*types.OrderWithMetadata, []orderFilterTestCase) { + storedOrders := createAndStoreOrdersForFilterTests(t, db) + testCases := []orderFilterTestCase{ + { + name: "no filter", + filters: []OrderFilter{}, + expectedMatchingOrders: storedOrders, + }, + { + name: "IsRemoved = false", + filters: []OrderFilter{ + { + Field: OFIsRemoved, + Kind: Equal, + Value: false, + }, + }, + expectedMatchingOrders: storedOrders, + }, + { + name: "MakerAddress = Address1", + filters: []OrderFilter{ + { + Field: OFMakerAddress, + Kind: Equal, + Value: constants.GanacheAccount1, + }, + }, + expectedMatchingOrders: storedOrders, + }, + + // Filter on MakerAssetAmount (type BigInt/NUMERIC) + { + name: "MakerAssetAmount = 5", + filters: []OrderFilter{ + { + Field: OFMakerAssetAmount, + Kind: Equal, + Value: big.NewInt(5), + }, + }, + expectedMatchingOrders: storedOrders[5:6], + }, + { + name: "MakerAssetAmount != 5", + filters: []OrderFilter{ + { + Field: 
OFMakerAssetAmount, + Kind: NotEqual, + Value: big.NewInt(5), + }, + }, + expectedMatchingOrders: append(safeSubsliceOrders(storedOrders, 0, 5), safeSubsliceOrders(storedOrders, 6, 10)...), + }, + { + name: "MakerAssetAmount < 5", + filters: []OrderFilter{ + { + Field: OFMakerAssetAmount, + Kind: Less, + Value: big.NewInt(5), + }, + }, + expectedMatchingOrders: storedOrders[:5], + }, + { + name: "MakerAssetAmount > 5", + filters: []OrderFilter{ + { + Field: OFMakerAssetAmount, + Kind: Greater, + Value: big.NewInt(5), + }, + }, + expectedMatchingOrders: storedOrders[6:], + }, + { + name: "MakerAssetAmount <= 5", + filters: []OrderFilter{ + { + Field: OFMakerAssetAmount, + Kind: LessOrEqual, + Value: big.NewInt(5), + }, + }, + expectedMatchingOrders: storedOrders[:6], + }, + { + name: "MakerAssetAmount >= 5", + filters: []OrderFilter{ + { + Field: OFMakerAssetAmount, + Kind: GreaterOrEqual, + Value: big.NewInt(5), + }, + }, + expectedMatchingOrders: storedOrders[5:], + }, + { + name: "MakerAssetAmount < 10^76", + filters: []OrderFilter{ + { + Field: OFMakerAssetAmount, + Kind: Less, + Value: math.BigPow(10, 76), + }, + }, + expectedMatchingOrders: storedOrders, + }, + + // Filter on MakerAssetData (type []byte/TEXT) + { + name: "MakerAssetData = f", + filters: []OrderFilter{ + { + Field: OFMakerAssetData, + Kind: Equal, + Value: []byte("f"), + }, + }, + expectedMatchingOrders: storedOrders[5:6], + }, + { + name: "MakerAssetData != f", + filters: []OrderFilter{ + { + Field: OFMakerAssetData, + Kind: NotEqual, + Value: []byte("f"), + }, + }, + expectedMatchingOrders: append(safeSubsliceOrders(storedOrders, 0, 5), safeSubsliceOrders(storedOrders, 6, 10)...), + }, + { + name: "MakerAssetData < f", + filters: []OrderFilter{ + { + Field: OFMakerAssetData, + Kind: Less, + Value: []byte("f"), + }, + }, + expectedMatchingOrders: storedOrders[:5], + }, + { + name: "MakerAssetData > f", + filters: []OrderFilter{ + { + Field: OFMakerAssetData, + Kind: Greater, + Value: 
[]byte("f"), + }, + }, + expectedMatchingOrders: storedOrders[6:], + }, + { + name: "MakerAssetData <= f", + filters: []OrderFilter{ + { + Field: OFMakerAssetData, + Kind: LessOrEqual, + Value: []byte("f"), + }, + }, + expectedMatchingOrders: storedOrders[:6], + }, + { + name: "MakerAssetData >= f", + filters: []OrderFilter{ + { + Field: OFMakerAssetData, + Kind: GreaterOrEqual, + Value: []byte("f"), + }, + }, + expectedMatchingOrders: storedOrders[5:], + }, + + // Filter on ParsedMakerAssetData (type ParsedAssetData/TEXT) + { + name: "ParsedMakerAssetData CONTAINS query that matches all", + filters: []OrderFilter{ + { + Field: OFParsedMakerAssetData, + Kind: Contains, + Value: fmt.Sprintf(`"address":"%s","tokenID":"0"`, strings.ToLower(constants.GanacheDummyERC721TokenAddress.Hex())), + }, + }, + expectedMatchingOrders: storedOrders, + }, + { + name: "ParsedMakerAssetData CONTAINS query that matches one", + filters: []OrderFilter{ + { + Field: OFParsedMakerAssetData, + Kind: Contains, + Value: fmt.Sprintf(`"address":"%s","tokenID":"51"`, strings.ToLower(constants.GanacheDummyERC721TokenAddress.Hex())), + }, + }, + expectedMatchingOrders: storedOrders[5:6], + }, + { + name: "ParsedMakerAssetData CONTAINS with helper method query that matches all", + filters: []OrderFilter{ + MakerAssetIncludesTokenAddressAndTokenID(constants.GanacheDummyERC721TokenAddress, big.NewInt(0)), + }, + expectedMatchingOrders: storedOrders, + }, + { + name: "ParsedMakerAssetData CONTAINS with helper method query that matches one", + filters: []OrderFilter{ + MakerAssetIncludesTokenAddressAndTokenID(constants.GanacheDummyERC721TokenAddress, big.NewInt(51)), + }, + expectedMatchingOrders: storedOrders[5:6], + }, + { + name: "ParsedMakerFeeAssetData CONTAINS with helper method query that matches all", + filters: []OrderFilter{ + MakerFeeAssetIncludesTokenAddressAndTokenID(constants.GanacheDummyERC1155MintableAddress, big.NewInt(567)), + }, + expectedMatchingOrders: storedOrders, + }, + + // 
Combining two or more filters + { + name: "MakerAssetAmount >= 3 AND MakerAssetData < h", + filters: []OrderFilter{ + { + Field: OFMakerAssetAmount, + Kind: GreaterOrEqual, + Value: big.NewInt(3), + }, + { + Field: OFMakerAssetData, + Kind: Less, + Value: []byte("h"), + }, + }, + expectedMatchingOrders: storedOrders[3:7], + }, + { + name: "MakerAssetAmount >= 3 AND MakerAssetData < h AND TakerAssetAmount != 5", + filters: []OrderFilter{ + { + Field: OFMakerAssetAmount, + Kind: GreaterOrEqual, + Value: big.NewInt(3), + }, + { + Field: OFMakerAssetData, + Kind: Less, + Value: []byte("h"), + }, + { + Field: OFTakerAssetAmount, + Kind: NotEqual, + Value: big.NewInt(5), + }, + }, + expectedMatchingOrders: append(safeSubsliceOrders(storedOrders, 3, 5), safeSubsliceOrders(storedOrders, 6, 7)...), + }, + } + + return storedOrders, testCases +} + +type miniHeaderFilterTestCase struct { + name string + filters []MiniHeaderFilter + expectedMatchingMiniHeaders []*types.MiniHeader +} + +func makeMiniHeaderFilterTestCases(t *testing.T, db *DB) ([]*types.MiniHeader, []miniHeaderFilterTestCase) { + // Create some test miniheaders with very specific characteristics to make it easier to write tests. + // - Number will be 0, 1, 2, etc. + // - Timestamp will be 0, 100, 200, etc. seconds since Unix Epoch + // - Each log in Logs will have BlockNumber set to 0, 1, 2, etc. 
+ numMiniHeaders := 10 + storedMiniHeaders := []*types.MiniHeader{} + for i := 0; i < numMiniHeaders; i++ { + miniHeader := newTestMiniHeader() + miniHeader.Number = big.NewInt(int64(i)) + miniHeader.Timestamp = time.Unix(int64(i)*100, 0) + for i := range miniHeader.Logs { + miniHeader.Logs[i].BlockNumber = miniHeader.Number.Uint64() + } + storedMiniHeaders = append(storedMiniHeaders, miniHeader) + } + _, _, err := db.AddMiniHeaders(storedMiniHeaders) require.NoError(t, err) - require.NoError(t, db.Close()) + + testCases := []miniHeaderFilterTestCase{ + { + name: "no filter", + filters: []MiniHeaderFilter{}, + expectedMatchingMiniHeaders: storedMiniHeaders, + }, + + // Filter on Number (type BigInt/NUMERIC) + { + name: "Number = 5", + filters: []MiniHeaderFilter{ + { + Field: MFNumber, + Kind: Equal, + Value: big.NewInt(5), + }, + }, + expectedMatchingMiniHeaders: storedMiniHeaders[5:6], + }, + { + name: "Number != 5", + filters: []MiniHeaderFilter{ + { + Field: MFNumber, + Kind: NotEqual, + Value: big.NewInt(5), + }, + }, + expectedMatchingMiniHeaders: append(safeSubsliceMiniHeaders(storedMiniHeaders, 0, 5), safeSubsliceMiniHeaders(storedMiniHeaders, 6, 10)...), + }, + { + name: "Number < 5", + filters: []MiniHeaderFilter{ + { + Field: MFNumber, + Kind: Less, + Value: big.NewInt(5), + }, + }, + expectedMatchingMiniHeaders: storedMiniHeaders[:5], + }, + { + name: "Number > 5", + filters: []MiniHeaderFilter{ + { + Field: MFNumber, + Kind: Greater, + Value: big.NewInt(5), + }, + }, + expectedMatchingMiniHeaders: storedMiniHeaders[6:], + }, + { + name: "Number <= 5", + filters: []MiniHeaderFilter{ + { + Field: MFNumber, + Kind: LessOrEqual, + Value: big.NewInt(5), + }, + }, + expectedMatchingMiniHeaders: storedMiniHeaders[:6], + }, + { + name: "Number >= 5", + filters: []MiniHeaderFilter{ + { + Field: MFNumber, + Kind: GreaterOrEqual, + Value: big.NewInt(5), + }, + }, + expectedMatchingMiniHeaders: storedMiniHeaders[5:], + }, + { + name: "Number < 10^76", + filters: 
[]MiniHeaderFilter{ + { + Field: MFNumber, + Kind: Less, + Value: math.BigPow(10, 76), + }, + }, + expectedMatchingMiniHeaders: storedMiniHeaders, + }, + + // Filter on Timestamp (type time.Time/TIMESTAMP) + { + name: "Timestamp = 500", + filters: []MiniHeaderFilter{ + { + Field: MFTimestamp, + Kind: Equal, + Value: time.Unix(500, 0), + }, + }, + expectedMatchingMiniHeaders: storedMiniHeaders[5:6], + }, + { + name: "Timestamp != 500", + filters: []MiniHeaderFilter{ + { + Field: MFTimestamp, + Kind: NotEqual, + Value: time.Unix(500, 0), + }, + }, + expectedMatchingMiniHeaders: append(safeSubsliceMiniHeaders(storedMiniHeaders, 0, 5), safeSubsliceMiniHeaders(storedMiniHeaders, 6, 10)...), + }, + { + name: "Timestamp < 500", + filters: []MiniHeaderFilter{ + { + Field: MFTimestamp, + Kind: Less, + Value: time.Unix(500, 0), + }, + }, + expectedMatchingMiniHeaders: storedMiniHeaders[:5], + }, + { + name: "Timestamp > 500", + filters: []MiniHeaderFilter{ + { + Field: MFTimestamp, + Kind: Greater, + Value: time.Unix(500, 0), + }, + }, + expectedMatchingMiniHeaders: storedMiniHeaders[6:], + }, + { + name: "Timestamp <= 500", + filters: []MiniHeaderFilter{ + { + Field: MFTimestamp, + Kind: LessOrEqual, + Value: time.Unix(500, 0), + }, + }, + expectedMatchingMiniHeaders: storedMiniHeaders[:6], + }, + { + name: "Timestamp >= 500", + filters: []MiniHeaderFilter{ + { + Field: MFTimestamp, + Kind: GreaterOrEqual, + Value: time.Unix(500, 0), + }, + }, + expectedMatchingMiniHeaders: storedMiniHeaders[5:], + }, + + // Filter on Logs (type ParsedAssetData/TEXT) + { + name: "Logs CONTAINS query that matches all", + filters: []MiniHeaderFilter{ + { + Field: MFLogs, + Kind: Contains, + Value: `"address":"0x21ab6c9fac80c59d401b37cb43f81ea9dde7fe34"`, + }, + }, + expectedMatchingMiniHeaders: storedMiniHeaders, + }, + { + name: "Logs CONTAINS query that matches one", + filters: []MiniHeaderFilter{ + { + Field: MFLogs, + Kind: Contains, + Value: `"blockNumber":"0x5"`, + }, + }, + 
expectedMatchingMiniHeaders: storedMiniHeaders[5:6], + }, + + // Combining two or more filters + { + name: "Number >= 3 AND Timestamp < 700", + filters: []MiniHeaderFilter{ + { + Field: MFNumber, + Kind: GreaterOrEqual, + Value: big.NewInt(3), + }, + { + Field: MFTimestamp, + Kind: Less, + Value: time.Unix(700, 0), + }, + }, + expectedMatchingMiniHeaders: storedMiniHeaders[3:7], + }, + { + name: "Number >= 3 AND Timestamp < 700 AND Number != 5", + filters: []MiniHeaderFilter{ + { + Field: MFNumber, + Kind: GreaterOrEqual, + Value: big.NewInt(3), + }, + { + Field: MFTimestamp, + Kind: Less, + Value: time.Unix(700, 0), + }, + { + Field: MFNumber, + Kind: NotEqual, + Value: big.NewInt(5), + }, + }, + expectedMatchingMiniHeaders: append(safeSubsliceMiniHeaders(storedMiniHeaders, 3, 5), safeSubsliceMiniHeaders(storedMiniHeaders, 6, 7)...), + }, + } + + return storedMiniHeaders, testCases +} + +// safeSubsliceOrders returns a (shallow) subslice of orders without modifying +// the original slice. Uses the same semantics as slice expressions: low is +// inclusive, hi is exclusive. The returned slice still contains pointers, it +// just doesn't use the same underlying array.
+func safeSubsliceOrders(orders []*types.OrderWithMetadata, low, hi int) []*types.OrderWithMetadata { + result := make([]*types.OrderWithMetadata, hi-low) + for i := low; i < hi; i++ { + result[i-low] = orders[i] + } + return result +} + +func sortOrdersByHash(orders []*types.OrderWithMetadata) { + sort.SliceStable(orders, func(i, j int) bool { + return bytes.Compare(orders[i].Hash.Bytes(), orders[j].Hash.Bytes()) == -1 + }) +} + +func lessByMakerAssetAmountAsc(orders []*types.OrderWithMetadata) func(i, j int) bool { + return func(i, j int) bool { + return orders[i].MakerAssetAmount.Cmp(orders[j].MakerAssetAmount) == -1 + } +} + +func lessByMakerAssetAmountDesc(orders []*types.OrderWithMetadata) func(i, j int) bool { + return func(i, j int) bool { + return orders[i].MakerAssetAmount.Cmp(orders[j].MakerAssetAmount) == 1 + } +} + +func lessByTakerAssetAmountAscAndMakerAssetAmountAsc(orders []*types.OrderWithMetadata) func(i, j int) bool { + return func(i, j int) bool { + switch orders[i].TakerAssetAmount.Cmp(orders[j].TakerAssetAmount) { + case -1: + // Less + return true + case 1: + // Greater + return false + default: + // Equal. In this case we use MakerAssetAmount as a secondary sort + // (i.e. a tie-breaker) + return orders[i].MakerAssetAmount.Cmp(orders[j].MakerAssetAmount) == -1 + } + } +} + +func lessByTakerAssetAmountDescAndMakerAssetAmountDesc(orders []*types.OrderWithMetadata) func(i, j int) bool { + return func(i, j int) bool { + switch orders[i].TakerAssetAmount.Cmp(orders[j].TakerAssetAmount) { + case -1: + // Less + return false + case 1: + // Greater + return true + default: + // Equal. In this case we use MakerAssetAmount as a secondary sort + // (i.e. 
a tie-breaker) + return orders[i].MakerAssetAmount.Cmp(orders[j].MakerAssetAmount) == 1 + } + } +} + +func assertOrderSlicesAreEqual(t *testing.T, expected, actual []*types.OrderWithMetadata) { + assert.Equal(t, len(expected), len(actual), "wrong number of orders") + for i, expectedOrder := range expected { + if i >= len(actual) { + break + } + actualOrder := actual[i] + assertOrdersAreEqual(t, expectedOrder, actualOrder) + } + if t.Failed() { + expectedJSON, err := json.MarshalIndent(expected, "", " ") + require.NoError(t, err) + actualJSON, err := json.MarshalIndent(actual, "", " ") + require.NoError(t, err) + t.Logf("\nexpected:\n%s\n\n", string(expectedJSON)) + t.Logf("\nactual:\n%s\n\n", string(actualJSON)) + assert.Equal(t, string(expectedJSON), string(actualJSON)) + } +} + +func assertOrderSlicesAreUnsortedEqual(t *testing.T, expected, actual []*types.OrderWithMetadata) { + // Make a copy of the given orders so we don't mess up the original when sorting them. + expectedCopy := make([]*types.OrderWithMetadata, len(expected)) + copy(expectedCopy, expected) + sortOrdersByHash(expectedCopy) + actualCopy := make([]*types.OrderWithMetadata, len(actual)) + copy(actualCopy, actual) + sortOrdersByHash(actualCopy) + assertOrderSlicesAreEqual(t, expectedCopy, actualCopy) +} + +func assertOrdersAreEqual(t *testing.T, expected, actual *types.OrderWithMetadata) { + if expected.LastUpdated.Equal(actual.LastUpdated) { + // HACK(albrow): In this case, the two values represent the same time. + // This is what we care about, but the assert package might consider + // them unequal if some internal fields are different (there are + // different ways of representing the same time). As a workaround, + // we manually set actual.LastUpdated. + actual.LastUpdated = expected.LastUpdated + } else { + assert.Equal(t, expected.LastUpdated, actual.LastUpdated, "order.LastUpdated was not equal") + } + // We can compare the rest of the fields normally. 
+ assert.Equal(t, expected, actual) +} + +// safeSubsliceMiniHeaders returns a (shallow) subslice of mini headers without +// modifying the original slice. Uses the same semantics as slice expressions: +// low is inclusive, hi is exclusive. The returned slice still contains +// pointers, it just doesn't use the same underlying array. +func safeSubsliceMiniHeaders(miniHeaders []*types.MiniHeader, low, hi int) []*types.MiniHeader { + result := make([]*types.MiniHeader, hi-low) + for i := low; i < hi; i++ { + result[i-low] = miniHeaders[i] + } + return result +} + +func sortMiniHeadersByHash(miniHeaders []*types.MiniHeader) { + sort.SliceStable(miniHeaders, func(i, j int) bool { + return bytes.Compare(miniHeaders[i].Hash.Bytes(), miniHeaders[j].Hash.Bytes()) == -1 + }) +} + +func lessByNumberAsc(miniHeaders []*types.MiniHeader) func(i, j int) bool { + return func(i, j int) bool { + return miniHeaders[i].Number.Cmp(miniHeaders[j].Number) == -1 + } +} + +func lessByNumberDesc(miniHeaders []*types.MiniHeader) func(i, j int) bool { + return func(i, j int) bool { + return miniHeaders[i].Number.Cmp(miniHeaders[j].Number) == 1 + } +} + +func lessByTimestampAscAndNumberAsc(miniHeaders []*types.MiniHeader) func(i, j int) bool { + return func(i, j int) bool { + switch { + case miniHeaders[i].Timestamp.Before(miniHeaders[j].Timestamp): + // Less + return true + case miniHeaders[i].Timestamp.After(miniHeaders[j].Timestamp): + // Greater + return false + default: + // Equal. In this case we use Number as a secondary sort + // (i.e. a tie-breaker) + return miniHeaders[i].Number.Cmp(miniHeaders[j].Number) == -1 + } + } +} + +func lessByTimestampDescAndNumberDesc(miniHeaders []*types.MiniHeader) func(i, j int) bool { + return func(i, j int) bool { + switch { + case miniHeaders[i].Timestamp.Before(miniHeaders[j].Timestamp): + // Less + return false + case miniHeaders[i].Timestamp.After(miniHeaders[j].Timestamp): + // Greater + return true + default: + // Equal. 
In this case we use Number as a secondary sort + // (i.e. a tie-breaker) + return miniHeaders[i].Number.Cmp(miniHeaders[j].Number) == 1 + } + } +} + +func assertMiniHeaderSlicesAreEqual(t *testing.T, expected, actual []*types.MiniHeader) { + assert.Len(t, actual, len(expected), "wrong number of miniheaders") + for i, expectedMiniHeader := range expected { + if i >= len(actual) { + break + } + actualMiniHeader := actual[i] + assertMiniHeadersAreEqual(t, expectedMiniHeader, actualMiniHeader) + } + if t.Failed() { + expectedJSON, err := json.MarshalIndent(expected, "", " ") + require.NoError(t, err) + actualJSON, err := json.MarshalIndent(actual, "", " ") + require.NoError(t, err) + t.Logf("\nexpected:\n%s\n\n", string(expectedJSON)) + t.Logf("\nactual:\n%s\n\n", string(actualJSON)) + assert.Equal(t, string(expectedJSON), string(actualJSON)) + } +} + +func assertMiniHeaderSlicesAreUnsortedEqual(t *testing.T, expected, actual []*types.MiniHeader) { + // Make a copy of the given mini headers so we don't mess up the original when sorting them. + expectedCopy := make([]*types.MiniHeader, len(expected)) + copy(expectedCopy, expected) + sortMiniHeadersByHash(expectedCopy) + actualCopy := make([]*types.MiniHeader, len(actual)) + copy(actualCopy, actual) + sortMiniHeadersByHash(actualCopy) + assertMiniHeaderSlicesAreEqual(t, expectedCopy, actualCopy) +} + +func assertMiniHeadersAreEqual(t *testing.T, expected, actual *types.MiniHeader) { + if expected.Timestamp.Equal(actual.Timestamp) { + // HACK(albrow): In this case, the two values represent the same time. + // This is what we care about, but the assert package might consider + // them unequal if some internal fields are different (there are + // different ways of representing the same time). As a workaround, + // we manually set actual.Timestamp.
+ actual.Timestamp = expected.Timestamp + } else { + assert.Equal(t, expected.Timestamp, actual.Timestamp, "miniHeader.Timestamp was not equal") + } + // We can compare the rest of the fields normally. + assert.Equal(t, expected, actual) +} + +func assertMetadatasAreEqual(t *testing.T, expected, actual *types.Metadata) { + if expected.StartOfCurrentUTCDay.Equal(actual.StartOfCurrentUTCDay) { + // HACK(albrow): In this case, the two values represent the same time. + // This is what we care about, but the assert package might consider + // them unequal if some internal fields are different (there are + // different ways of representing the same time). As a workaround, + // we manually set actual.StartOfCurrentUTCDay. + actual.StartOfCurrentUTCDay = expected.StartOfCurrentUTCDay + } else { + assert.Equal(t, expected.StartOfCurrentUTCDay, actual.StartOfCurrentUTCDay, "metadata.StartOfCurrentUTCDay was not equal") + } + // We can compare the rest of the fields normally. + assert.Equal(t, expected, actual) } diff --git a/db/dexie_implementation.go b/db/dexie_implementation.go new file mode 100644 index 000000000..6a51dc41c --- /dev/null +++ b/db/dexie_implementation.go @@ -0,0 +1,466 @@ +// +build js,wasm + +package db + +import ( + "context" + "errors" + "fmt" + "math/big" + "path/filepath" + "syscall/js" + + "github.com/0xProject/0x-mesh/common/types" + "github.com/0xProject/0x-mesh/db/dexietypes" + "github.com/0xProject/0x-mesh/packages/browser/go/jsutil" + "github.com/ethereum/go-ethereum/common" + "github.com/gibson042/canonicaljson-go" + "github.com/google/uuid" +) + +var _ Database = (*DB)(nil) + +type DB struct { + ctx context.Context + dexie js.Value + opts *Options +} + +func TestOptions() *Options { + return &Options{ + DriverName: "dexie", + DataSourceName: filepath.Join("mesh_testing", uuid.New().String()), + MaxOrders: 100, + MaxMiniHeaders: 20, + } +} + +func defaultOptions() *Options { + return &Options{ + DriverName: "dexie", + DataSourceName: 
"mesh_dexie_database", + MaxOrders: 100000, + MaxMiniHeaders: 20, + } +} + +// New creates a new connection to the database. The connection will be automatically closed +// when the given context is canceled. +func New(ctx context.Context, opts *Options) (database *DB, err error) { + if opts != nil && opts.DriverName != "dexie" { + return nil, fmt.Errorf(`unexpected driver name for js/wasm: %q (only "dexie" is supported)`, opts.DriverName) + } + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + newDexieDatabase := js.Global().Get("__mesh_dexie_newDatabase__") + if jsutil.IsNullOrUndefined(newDexieDatabase) { + return nil, errors.New("could not detect Dexie.js") + } + opts = parseOptions(opts) + dexie := newDexieDatabase.Invoke(opts) + + // Automatically close the database connection when the context is canceled. + go func() { + select { + case <-ctx.Done(): + _ = dexie.Call("close") + } + }() + + return &DB{ + ctx: ctx, + dexie: dexie, + opts: opts, + }, nil +} + +func (db *DB) AddOrders(orders []*types.OrderWithMetadata) (added []*types.OrderWithMetadata, removed []*types.OrderWithMetadata, err error) { + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + jsOrders, err := jsutil.InefficientlyConvertToJS(dexietypes.OrdersFromCommonType(orders)) + if err != nil { + return nil, nil, err + } + jsResult, err := jsutil.AwaitPromiseContext(db.ctx, db.dexie.Call("addOrdersAsync", jsOrders)) + if err != nil { + return nil, nil, convertJSError(err) + } + jsAdded := jsResult.Get("added") + var dexieAdded []*dexietypes.Order + if err := jsutil.InefficientlyConvertFromJS(jsAdded, &dexieAdded); err != nil { + return nil, nil, err + } + jsRemoved := jsResult.Get("removed") + var dexieRemoved []*dexietypes.Order + if err := jsutil.InefficientlyConvertFromJS(jsRemoved, &dexieRemoved); err != nil { + return nil, nil, err + } + return dexietypes.OrdersToCommonType(dexieAdded), 
dexietypes.OrdersToCommonType(dexieRemoved), nil +} + +func (db *DB) GetOrder(hash common.Hash) (order *types.OrderWithMetadata, err error) { + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + jsOrder, err := jsutil.AwaitPromiseContext(db.ctx, db.dexie.Call("getOrderAsync", hash.Hex())) + if err != nil { + return nil, convertJSError(err) + } + var dexieOrder dexietypes.Order + if err := jsutil.InefficientlyConvertFromJS(jsOrder, &dexieOrder); err != nil { + return nil, err + } + return dexietypes.OrderToCommonType(&dexieOrder), nil +} + +func (db *DB) FindOrders(query *OrderQuery) (orders []*types.OrderWithMetadata, err error) { + if err := checkOrderQuery(query); err != nil { + return nil, err + } + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + query = formatOrderQuery(query) + jsOrders, err := jsutil.AwaitPromiseContext(db.ctx, db.dexie.Call("findOrdersAsync", query)) + if err != nil { + return nil, convertJSError(err) + } + var dexieOrders []*dexietypes.Order + if err := jsutil.InefficientlyConvertFromJS(jsOrders, &dexieOrders); err != nil { + return nil, err + } + return dexietypes.OrdersToCommonType(dexieOrders), nil +} + +func (db *DB) CountOrders(query *OrderQuery) (count int, err error) { + if err := checkOrderQuery(query); err != nil { + return 0, err + } + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + query = formatOrderQuery(query) + jsCount, err := jsutil.AwaitPromiseContext(db.ctx, db.dexie.Call("countOrdersAsync", query)) + if err != nil { + return 0, convertJSError(err) + } + return jsCount.Int(), nil +} + +func (db *DB) DeleteOrder(hash common.Hash) (err error) { + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + _, jsErr := jsutil.AwaitPromiseContext(db.ctx, db.dexie.Call("deleteOrderAsync", hash.Hex())) + if jsErr != nil { + return convertJSError(jsErr) + } + return nil +} + +func (db *DB) DeleteOrders(query 
*OrderQuery) (deletedOrders []*types.OrderWithMetadata, err error) { + if err := checkOrderQuery(query); err != nil { + return nil, err + } + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + query = formatOrderQuery(query) + jsOrders, err := jsutil.AwaitPromiseContext(db.ctx, db.dexie.Call("deleteOrdersAsync", query)) + if err != nil { + return nil, convertJSError(err) + } + var dexieOrders []*dexietypes.Order + if err := jsutil.InefficientlyConvertFromJS(jsOrders, &dexieOrders); err != nil { + return nil, err + } + return dexietypes.OrdersToCommonType(dexieOrders), nil +} + +func (db *DB) UpdateOrder(hash common.Hash, updateFunc func(existingOrder *types.OrderWithMetadata) (updatedOrder *types.OrderWithMetadata, err error)) (err error) { + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + jsUpdateFunc := js.FuncOf(func(_ js.Value, args []js.Value) interface{} { + jsExistingOrder := args[0] + var dexieExistingOrder dexietypes.Order + if err := jsutil.InefficientlyConvertFromJS(jsExistingOrder, &dexieExistingOrder); err != nil { + panic(err) + } + orderToUpdate, err := updateFunc(dexietypes.OrderToCommonType(&dexieExistingOrder)) + if err != nil { + panic(err) + } + dexieOrderToUpdate := dexietypes.OrderFromCommonType(orderToUpdate) + jsOrderToUpdate, err := jsutil.InefficientlyConvertToJS(dexieOrderToUpdate) + if err != nil { + panic(err) + } + return jsOrderToUpdate + }) + defer jsUpdateFunc.Release() + _, jsErr := jsutil.AwaitPromiseContext(db.ctx, db.dexie.Call("updateOrderAsync", hash.Hex(), jsUpdateFunc)) + if jsErr != nil { + return convertJSError(jsErr) + } + return nil +} + +func (db *DB) AddMiniHeaders(miniHeaders []*types.MiniHeader) (added []*types.MiniHeader, removed []*types.MiniHeader, err error) { + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + jsMiniHeaders, err := jsutil.InefficientlyConvertToJS(dexietypes.MiniHeadersFromCommonType(miniHeaders)) 
+ if err != nil { + return nil, nil, err + } + jsResult, err := jsutil.AwaitPromiseContext(db.ctx, db.dexie.Call("addMiniHeadersAsync", jsMiniHeaders)) + if err != nil { + return nil, nil, convertJSError(err) + } + jsAdded := jsResult.Get("added") + var dexieAdded []*dexietypes.MiniHeader + if err := jsutil.InefficientlyConvertFromJS(jsAdded, &dexieAdded); err != nil { + return nil, nil, err + } + jsRemoved := jsResult.Get("removed") + var dexieRemoved []*dexietypes.MiniHeader + if err := jsutil.InefficientlyConvertFromJS(jsRemoved, &dexieRemoved); err != nil { + return nil, nil, err + } + return dexietypes.MiniHeadersToCommonType(dexieAdded), dexietypes.MiniHeadersToCommonType(dexieRemoved), nil +} + +func (db *DB) GetMiniHeader(hash common.Hash) (miniHeader *types.MiniHeader, err error) { + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + jsMiniHeader, err := jsutil.AwaitPromiseContext(db.ctx, db.dexie.Call("getMiniHeaderAsync", hash.Hex())) + if err != nil { + return nil, convertJSError(err) + } + var dexieMiniHeader dexietypes.MiniHeader + if err := jsutil.InefficientlyConvertFromJS(jsMiniHeader, &dexieMiniHeader); err != nil { + return nil, err + } + return dexietypes.MiniHeaderToCommonType(&dexieMiniHeader), nil +} + +func (db *DB) FindMiniHeaders(query *MiniHeaderQuery) (miniHeaders []*types.MiniHeader, err error) { + if err := checkMiniHeaderQuery(query); err != nil { + return nil, err + } + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + query = formatMiniHeaderQuery(query) + jsMiniHeaders, err := jsutil.AwaitPromiseContext(db.ctx, db.dexie.Call("findMiniHeadersAsync", query)) + if err != nil { + return nil, convertJSError(err) + } + var dexieMiniHeaders []*dexietypes.MiniHeader + if err := jsutil.InefficientlyConvertFromJS(jsMiniHeaders, &dexieMiniHeaders); err != nil { + return nil, err + } + return dexietypes.MiniHeadersToCommonType(dexieMiniHeaders), nil +} + +func (db *DB) 
DeleteMiniHeader(hash common.Hash) (err error) { + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + _, jsErr := jsutil.AwaitPromiseContext(db.ctx, db.dexie.Call("deleteMiniHeaderAsync", hash.Hex())) + if jsErr != nil { + return convertJSError(jsErr) + } + return nil +} + +func (db *DB) DeleteMiniHeaders(query *MiniHeaderQuery) (deleted []*types.MiniHeader, err error) { + if err := checkMiniHeaderQuery(query); err != nil { + return nil, err + } + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + query = formatMiniHeaderQuery(query) + jsMiniHeaders, err := jsutil.AwaitPromiseContext(db.ctx, db.dexie.Call("deleteMiniHeadersAsync", query)) + if err != nil { + return nil, convertJSError(err) + } + var dexieMiniHeaders []*dexietypes.MiniHeader + if err := jsutil.InefficientlyConvertFromJS(jsMiniHeaders, &dexieMiniHeaders); err != nil { + return nil, err + } + return dexietypes.MiniHeadersToCommonType(dexieMiniHeaders), nil +} + +func (db *DB) GetMetadata() (metadata *types.Metadata, err error) { + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + jsMetadata, err := jsutil.AwaitPromiseContext(db.ctx, db.dexie.Call("getMetadataAsync")) + if err != nil { + return nil, convertJSError(err) + } + var dexieMetadata dexietypes.Metadata + if err := jsutil.InefficientlyConvertFromJS(jsMetadata, &dexieMetadata); err != nil { + return nil, err + } + return dexietypes.MetadataToCommonType(&dexieMetadata), nil +} + +func (db *DB) SaveMetadata(metadata *types.Metadata) (err error) { + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + dexieMetadata := dexietypes.MetadataFromCommonType(metadata) + jsMetadata, err := jsutil.InefficientlyConvertToJS(dexieMetadata) + if err != nil { + return err + } + _, err = jsutil.AwaitPromiseContext(db.ctx, db.dexie.Call("saveMetadataAsync", jsMetadata)) + if err != nil { + return convertJSError(err) + } + return nil +} + 
+func (db *DB) UpdateMetadata(updateFunc func(oldmetadata *types.Metadata) (newMetadata *types.Metadata)) (err error) { + defer func() { + if r := recover(); r != nil { + err = recoverError(r) + } + }() + jsUpdateFunc := js.FuncOf(func(_ js.Value, args []js.Value) interface{} { + jsExistingMetadata := args[0] + var dexieExistingMetadata dexietypes.Metadata + if err := jsutil.InefficientlyConvertFromJS(jsExistingMetadata, &dexieExistingMetadata); err != nil { + panic(err) + } + metadataToUpdate := updateFunc(dexietypes.MetadataToCommonType(&dexieExistingMetadata)) + dexieMetadataToUpdate := dexietypes.MetadataFromCommonType(metadataToUpdate) + jsMetadataToUpdate, err := jsutil.InefficientlyConvertToJS(dexieMetadataToUpdate) + if err != nil { + panic(err) + } + return jsMetadataToUpdate + }) + defer jsUpdateFunc.Release() + _, jsErr := jsutil.AwaitPromiseContext(db.ctx, db.dexie.Call("updateMetadataAsync", jsUpdateFunc)) + if jsErr != nil { + return convertJSError(jsErr) + } + return nil +} + +func recoverError(e interface{}) error { + switch e := e.(type) { + case error: + return e + case string: + return errors.New(e) + default: + return fmt.Errorf("unexpected JavaScript error: (%T) %v", e, e) + } +} + +func convertJSError(e error) error { + switch e := e.(type) { + case js.Error: + if jsutil.IsNullOrUndefined(e.Value) { + return e + } + if jsutil.IsNullOrUndefined(e.Value.Get("message")) { + return e + } + switch e.Value.Get("message").String() { + // TOOD(albrow): Handle more error messages here + case ErrNotFound.Error(): + return ErrNotFound + case ErrMetadataAlreadyExists.Error(): + return ErrMetadataAlreadyExists + case ErrDBFilledWithPinnedOrders.Error(): + return ErrDBFilledWithPinnedOrders + } + } + return e +} + +func formatOrderQuery(query *OrderQuery) *OrderQuery { + if query == nil { + return nil + } + for i, filter := range query.Filters { + query.Filters[i].Value = convertFilterValue(filter.Value) + } + return query +} + +func 
formatMiniHeaderQuery(query *MiniHeaderQuery) *MiniHeaderQuery { + if query == nil { + return nil + } + for i, filter := range query.Filters { + query.Filters[i].Value = convertFilterValue(filter.Value) + } + return query +} + +func convertFilterValue(value interface{}) interface{} { + switch v := value.(type) { + case *big.Int: + return dexietypes.NewSortedBigInt(v) + case bool: + return dexietypes.BoolToUint8(v) + } + return value +} + +func assetDataIncludesTokenAddressAndTokenID(field OrderField, tokenAddress common.Address, tokenID *big.Int) OrderFilter { + filterValueJSON, err := canonicaljson.Marshal(dexietypes.SingleAssetData{ + Address: tokenAddress, + TokenID: dexietypes.NewBigInt(tokenID), + }) + if err != nil { + // big.Int and common.Address types should never return an error when marshaling to JSON + panic(err) + } + return OrderFilter{ + Field: field, + Kind: Contains, + Value: string(filterValueJSON), + } +} diff --git a/db/dexietypes/dexietypes.go b/db/dexietypes/dexietypes.go new file mode 100644 index 000000000..587a1dff5 --- /dev/null +++ b/db/dexietypes/dexietypes.go @@ -0,0 +1,378 @@ +package dexietypes + +// Note(albrow): Could be optimized if needed by more directly converting between +// Go types and JavaScript types instead of using jsutil.IneffecientlyConvertX. + +import ( + "encoding/json" + "fmt" + "math/big" + "strconv" + "time" + + "github.com/0xProject/0x-mesh/common/types" + "github.com/ethereum/go-ethereum/common" + ethmath "github.com/ethereum/go-ethereum/common/math" + ethtypes "github.com/ethereum/go-ethereum/core/types" + "github.com/gibson042/canonicaljson-go" +) + +// BigInt is a wrapper around *big.Int that implements the json.Marshaler +// and json.Unmarshaler interfaces in a way that is compatible with Dexie.js +// but *does not* pad with zeroes and *does not* retain sort order. 
+type BigInt struct { + *big.Int +} + +func NewBigInt(v *big.Int) *BigInt { + return &BigInt{ + Int: v, + } +} + +func BigIntFromString(v string) (*BigInt, error) { + bigInt, ok := ethmath.ParseBig256(v) + if !ok { + return nil, fmt.Errorf("dexietypes: could not convert %q to BigInt", v) + } + return NewBigInt(bigInt), nil +} + +func BigIntFromInt64(v int64) *BigInt { + return NewBigInt(big.NewInt(v)) +} + +func (i *BigInt) MarshalJSON() ([]byte, error) { + if i == nil || i.Int == nil { + return json.Marshal(nil) + } + return json.Marshal(i.Int.String()) +} + +func (i *BigInt) UnmarshalJSON(data []byte) error { + unqouted, err := strconv.Unquote(string(data)) + if err != nil { + return fmt.Errorf("could not unmarshal JSON data into dexietypes.BigInt: %s", string(data)) + } + bigInt, ok := ethmath.ParseBig256(unqouted) + if !ok { + return fmt.Errorf("could not unmarshal JSON data into dexietypes.BigInt: %s", string(data)) + } + i.Int = bigInt + return nil +} + +// SortedBigInt is a wrapper around *big.Int that implements the json.Marshaler +// and json.Unmarshaler interfaces in a way that is compatible with Dexie.js and +// retains sort order by padding with zeroes. +type SortedBigInt struct { + *big.Int +} + +func NewSortedBigInt(v *big.Int) *SortedBigInt { + return &SortedBigInt{ + Int: v, + } +} + +func SortedBigIntFromString(v string) (*SortedBigInt, error) { + bigInt, ok := ethmath.ParseBig256(v) + if !ok { + return nil, fmt.Errorf("dexietypes: could not convert %q to BigInt", v) + } + return NewSortedBigInt(bigInt), nil +} + +func SortedBigIntFromInt64(v int64) *SortedBigInt { + return NewSortedBigInt(big.NewInt(v)) +} + +func (i *SortedBigInt) MarshalJSON() ([]byte, error) { + if i == nil || i.Int == nil { + return json.Marshal(nil) + } + // Note(albrow), strings in Dexie.js are sorted in alphanumerical order, not + // numerical order. In order to sort by numerical order, we need to pad with + // zeroes. 
The maximum length of an unsigned 256 bit integer is 80, so we + // pad with zeroes such that the length of the number is always 80. + return json.Marshal(fmt.Sprintf("%080s", i.Int.String())) +} + +func (i *SortedBigInt) UnmarshalJSON(data []byte) error { + unqouted, err := strconv.Unquote(string(data)) + if err != nil { + return fmt.Errorf("could not unmarshal JSON data into dexietypes.BigInt: %s", string(data)) + } + bigInt, ok := ethmath.ParseBig256(unqouted) + if !ok { + return fmt.Errorf("could not unmarshal JSON data into dexietypes.BigInt: %s", string(data)) + } + i.Int = bigInt + return nil +} + +type SingleAssetData struct { + Address common.Address `json:"address"` + TokenID *BigInt `json:"tokenID"` +} + +// ParsedAssetData is a wrapper around []*SingleAssetData that implements the +// sql.Valuer and sql.Scanner interfaces. +type ParsedAssetData []*SingleAssetData + +// Order is the SQL database representation a 0x order along with some relevant metadata. +type Order struct { + Hash common.Hash `json:"hash"` + ChainID *SortedBigInt `json:"chainID"` + ExchangeAddress common.Address `json:"exchangeAddress"` + MakerAddress common.Address `json:"makerAddress"` + MakerAssetData []byte `json:"makerAssetData"` + MakerFeeAssetData []byte `json:"makerFeeAssetData"` + MakerAssetAmount *SortedBigInt `json:"makerAssetAmount"` + MakerFee *SortedBigInt `json:"makerFee"` + TakerAddress common.Address `json:"takerAddress"` + TakerAssetData []byte `json:"takerAssetData"` + TakerFeeAssetData []byte `json:"takerFeeAssetData"` + TakerAssetAmount *SortedBigInt `json:"takerAssetAmount"` + TakerFee *SortedBigInt `json:"takerFee"` + SenderAddress common.Address `json:"senderAddress"` + FeeRecipientAddress common.Address `json:"feeRecipientAddress"` + ExpirationTimeSeconds *SortedBigInt `json:"expirationTimeSeconds"` + Salt *SortedBigInt `json:"salt"` + Signature []byte `json:"signature"` + LastUpdated time.Time `json:"lastUpdated"` + FillableTakerAssetAmount *SortedBigInt 
`json:"fillableTakerAssetAmount"` + IsRemoved uint8 `json:"isRemoved"` + IsPinned uint8 `json:"isPinned"` + ParsedMakerAssetData string `json:"parsedMakerAssetData"` + ParsedMakerFeeAssetData string `json:"parsedMakerFeeAssetData"` +} + +type MiniHeader struct { + Hash common.Hash `json:"hash"` + Parent common.Hash `json:"parent"` + Number *SortedBigInt `json:"number"` + Timestamp time.Time `json:"timestamp"` + Logs string `json:"logs"` +} + +type Metadata struct { + EthereumChainID int `json:"ethereumChainID"` + MaxExpirationTime *SortedBigInt `json:"maxExpirationTime"` + EthRPCRequestsSentInCurrentUTCDay int `json:"ethRPCRequestsSentInCurrentUTCDay"` + StartOfCurrentUTCDay time.Time `json:"startOfCurrentUTCDay"` +} + +func OrderToCommonType(order *Order) *types.OrderWithMetadata { + if order == nil { + return nil + } + return &types.OrderWithMetadata{ + Hash: order.Hash, + ChainID: order.ChainID.Int, + ExchangeAddress: order.ExchangeAddress, + MakerAddress: order.MakerAddress, + MakerAssetData: order.MakerAssetData, + MakerFeeAssetData: order.MakerFeeAssetData, + MakerAssetAmount: order.MakerAssetAmount.Int, + MakerFee: order.MakerFee.Int, + TakerAddress: order.TakerAddress, + TakerAssetData: order.TakerAssetData, + TakerFeeAssetData: order.TakerFeeAssetData, + TakerAssetAmount: order.TakerAssetAmount.Int, + TakerFee: order.TakerFee.Int, + SenderAddress: order.SenderAddress, + FeeRecipientAddress: order.FeeRecipientAddress, + ExpirationTimeSeconds: order.ExpirationTimeSeconds.Int, + Salt: order.Salt.Int, + Signature: order.Signature, + FillableTakerAssetAmount: order.FillableTakerAssetAmount.Int, + LastUpdated: order.LastUpdated, + IsRemoved: order.IsRemoved == 1, + IsPinned: order.IsPinned == 1, + ParsedMakerAssetData: ParsedAssetDataToCommonType(order.ParsedMakerAssetData), + ParsedMakerFeeAssetData: ParsedAssetDataToCommonType(order.ParsedMakerFeeAssetData), + } +} + +func OrderFromCommonType(order *types.OrderWithMetadata) *Order { + if order == nil { + 
return nil + } + return &Order{ + Hash: order.Hash, + ChainID: NewSortedBigInt(order.ChainID), + ExchangeAddress: order.ExchangeAddress, + MakerAddress: order.MakerAddress, + MakerAssetData: order.MakerAssetData, + MakerFeeAssetData: order.MakerFeeAssetData, + MakerAssetAmount: NewSortedBigInt(order.MakerAssetAmount), + MakerFee: NewSortedBigInt(order.MakerFee), + TakerAddress: order.TakerAddress, + TakerAssetData: order.TakerAssetData, + TakerFeeAssetData: order.TakerFeeAssetData, + TakerAssetAmount: NewSortedBigInt(order.TakerAssetAmount), + TakerFee: NewSortedBigInt(order.TakerFee), + SenderAddress: order.SenderAddress, + FeeRecipientAddress: order.FeeRecipientAddress, + ExpirationTimeSeconds: NewSortedBigInt(order.ExpirationTimeSeconds), + Salt: NewSortedBigInt(order.Salt), + Signature: order.Signature, + LastUpdated: order.LastUpdated, + FillableTakerAssetAmount: NewSortedBigInt(order.FillableTakerAssetAmount), + IsRemoved: BoolToUint8(order.IsRemoved), + IsPinned: BoolToUint8(order.IsPinned), + ParsedMakerAssetData: ParsedAssetDataFromCommonType(order.ParsedMakerAssetData), + ParsedMakerFeeAssetData: ParsedAssetDataFromCommonType(order.ParsedMakerFeeAssetData), + } +} + +func OrdersToCommonType(orders []*Order) []*types.OrderWithMetadata { + result := make([]*types.OrderWithMetadata, len(orders)) + for i, order := range orders { + result[i] = OrderToCommonType(order) + } + return result +} + +func OrdersFromCommonType(orders []*types.OrderWithMetadata) []*Order { + result := make([]*Order, len(orders)) + for i, order := range orders { + result[i] = OrderFromCommonType(order) + } + return result +} + +func ParsedAssetDataToCommonType(parsedAssetData string) []*types.SingleAssetData { + if parsedAssetData == "" { + return nil + } + var dexieAssetDatas []*SingleAssetData + _ = json.Unmarshal([]byte(parsedAssetData), &dexieAssetDatas) + result := make([]*types.SingleAssetData, len(dexieAssetDatas)) + for i, singleAssetData := range dexieAssetDatas { + result[i] = 
SingleAssetDataToCommonType(singleAssetData) + } + return result +} + +func ParsedAssetDataFromCommonType(parsedAssetData []*types.SingleAssetData) string { + dexieAssetDatas := ParsedAssetData(make([]*SingleAssetData, len(parsedAssetData))) + for i, singleAssetData := range parsedAssetData { + dexieAssetDatas[i] = SingleAssetDataFromCommonType(singleAssetData) + } + jsonAssetDatas, _ := canonicaljson.Marshal(dexieAssetDatas) + return string(jsonAssetDatas) +} + +func SingleAssetDataToCommonType(singleAssetData *SingleAssetData) *types.SingleAssetData { + if singleAssetData == nil { + return nil + } + var tokenID *big.Int + if singleAssetData.TokenID != nil { + tokenID = singleAssetData.TokenID.Int + } + return &types.SingleAssetData{ + Address: singleAssetData.Address, + TokenID: tokenID, + } +} + +func SingleAssetDataFromCommonType(singleAssetData *types.SingleAssetData) *SingleAssetData { + if singleAssetData == nil { + return nil + } + var tokenID *BigInt + if singleAssetData.TokenID != nil { + tokenID = NewBigInt(singleAssetData.TokenID) + } + return &SingleAssetData{ + Address: singleAssetData.Address, + TokenID: tokenID, + } +} + +func MiniHeaderToCommonType(miniHeader *MiniHeader) *types.MiniHeader { + if miniHeader == nil { + return nil + } + return &types.MiniHeader{ + Hash: miniHeader.Hash, + Parent: miniHeader.Parent, + Number: miniHeader.Number.Int, + Timestamp: miniHeader.Timestamp, + Logs: EventLogsToCommonType(miniHeader.Logs), + } +} + +func MiniHeaderFromCommonType(miniHeader *types.MiniHeader) *MiniHeader { + if miniHeader == nil { + return nil + } + return &MiniHeader{ + Hash: miniHeader.Hash, + Parent: miniHeader.Parent, + Number: NewSortedBigInt(miniHeader.Number), + Timestamp: miniHeader.Timestamp, + Logs: EventLogsFromCommonType(miniHeader.Logs), + } +} + +func MiniHeadersToCommonType(miniHeaders []*MiniHeader) []*types.MiniHeader { + result := make([]*types.MiniHeader, len(miniHeaders)) + for i, miniHeader := range miniHeaders { + result[i] 
= MiniHeaderToCommonType(miniHeader) + } + return result +} + +func MiniHeadersFromCommonType(miniHeaders []*types.MiniHeader) []*MiniHeader { + result := make([]*MiniHeader, len(miniHeaders)) + for i, miniHeader := range miniHeaders { + result[i] = MiniHeaderFromCommonType(miniHeader) + } + return result +} + +func EventLogsToCommonType(eventLogs string) []ethtypes.Log { + var result []ethtypes.Log + _ = json.Unmarshal([]byte(eventLogs), &result) + return result +} + +func EventLogsFromCommonType(eventLogs []ethtypes.Log) string { + result, _ := json.Marshal(eventLogs) + return string(result) +} + +func MetadataToCommonType(metadata *Metadata) *types.Metadata { + if metadata == nil { + return nil + } + return &types.Metadata{ + EthereumChainID: metadata.EthereumChainID, + MaxExpirationTime: metadata.MaxExpirationTime.Int, + EthRPCRequestsSentInCurrentUTCDay: metadata.EthRPCRequestsSentInCurrentUTCDay, + StartOfCurrentUTCDay: metadata.StartOfCurrentUTCDay, + } +} + +func MetadataFromCommonType(metadata *types.Metadata) *Metadata { + if metadata == nil { + return nil + } + return &Metadata{ + EthereumChainID: metadata.EthereumChainID, + MaxExpirationTime: NewSortedBigInt(metadata.MaxExpirationTime), + EthRPCRequestsSentInCurrentUTCDay: metadata.EthRPCRequestsSentInCurrentUTCDay, + StartOfCurrentUTCDay: metadata.StartOfCurrentUTCDay, + } +} + +func BoolToUint8(b bool) uint8 { + if b { + return 1 + } + return 0 +} diff --git a/db/errors.go b/db/errors.go deleted file mode 100644 index 4ac8da196..000000000 --- a/db/errors.go +++ /dev/null @@ -1,26 +0,0 @@ -package db - -import ( - "encoding/hex" - "fmt" -) - -// NotFoundError is returned whenever a model with a specific ID should be found -// in the database but it is not. 
-type NotFoundError struct { - ID []byte -} - -func (e NotFoundError) Error() string { - return fmt.Sprintf("could not find model with the given ID: %s", hex.EncodeToString(e.ID)) -} - -// AlreadyExistsError is returned whenever a model with a specific ID should not -// already exists in the database but it does. -type AlreadyExistsError struct { - ID []byte -} - -func (e AlreadyExistsError) Error() string { - return fmt.Sprintf("model already exists with the given ID: %s", hex.EncodeToString(e.ID)) -} diff --git a/db/escape.go b/db/escape.go deleted file mode 100644 index b83be10d8..000000000 --- a/db/escape.go +++ /dev/null @@ -1,58 +0,0 @@ -package db - -import ( - "bufio" - "bytes" - "encoding/hex" - - log "github.com/sirupsen/logrus" -) - -// escape replaces ':' with '\c' and '\' with '\\'. -func escape(value []byte) []byte { - escaped := []byte{} - for _, b := range value { - switch b { - case ':': - escaped = append(escaped, ([]byte{'\\', 'c'})...) - case '\\': - escaped = append(escaped, ([]byte{'\\', b})...) - default: - escaped = append(escaped, b) - } - } - return escaped -} - -// unescape is the inverse of escape. -func unescape(value []byte) ([]byte, error) { - reader := bufio.NewReader(bytes.NewBuffer(value)) - unescaped := []byte{} - for { - b, err := reader.ReadByte() - if err != nil { - // Assume io.EOF error indicating we reached the end of the value. - break - } - if b == '\\' { - next, err := reader.ReadByte() - if err != nil { - // This is only possible if the value was not escaped properly. Should - // never happen. 
- log.WithFields(log.Fields{ - "error": err.Error(), - "value": hex.Dump(value), - }).Error("unexpected error in unescape") - return nil, err - } - if next == 'c' { - unescaped = append(unescaped, ':') - } else { - unescaped = append(unescaped, next) - } - } else { - unescaped = append(unescaped, b) - } - } - return unescaped, nil -} diff --git a/db/escape_test.go b/db/escape_test.go deleted file mode 100644 index 72876a022..000000000 --- a/db/escape_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package db - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var trickyByteValues = [][]byte{ - []byte(":"), - []byte(`\`), - []byte("::"), - []byte(`\\`), - []byte(`\:`), - []byte(`:\`), - []byte(`\\:`), - []byte(`::\`), - []byte(`\:\:`), - []byte(`:\:\`), - []byte(`:\\`), - []byte(`\::`), - []byte(`::\\`), - []byte(`\\::`), -} - -func TestEscapeUnescape(t *testing.T) { - t.Parallel() - for _, expected := range trickyByteValues { - actual, err := unescape(escape(expected)) - require.NoError(t, err) - assert.Equal(t, expected, actual) - } -} - -func TestFindWithValueWithEscape(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - ageIndex := col.AddIndex("age", func(m Model) []byte { - // Note: We add the ':' to the index value to try and trip up the escaping - // algorithm. - return []byte(fmt.Sprintf(":%d:", m.(*testModel).Age)) - }) - models := make([]*testModel, len(trickyByteValues)) - // Use the trickyByteValues as the names for each model. 
- for i, name := range trickyByteValues { - models[i] = &testModel{ - Name: string(name), - Age: i, - } - } - for i, expected := range models { - require.NoError(t, col.Insert(expected), "testModel %d", i) - actual := []*testModel{} - query := col.NewQuery(ageIndex.ValueFilter([]byte(fmt.Sprintf(":%d:", expected.Age)))) - require.NoError(t, query.Run(&actual), "testModel %d", i) - require.Len(t, actual, 1, "testModel %d", i) - assert.Equal(t, expected, actual[0]) - } -} diff --git a/db/global_transaction.go b/db/global_transaction.go deleted file mode 100644 index 16bdb2e88..000000000 --- a/db/global_transaction.go +++ /dev/null @@ -1,171 +0,0 @@ -package db - -import ( - "sync" -) - -// GlobalTransaction is an atomic database transaction across all collections -// which can be used to guarantee consistency. -type GlobalTransaction struct { - db *DB - mut sync.Mutex - batchWriter dbBatchWriter - readWriter *readerWithBatchWriter - committed bool - discarded bool - // internalCounts keeps track of the number of models inserted/deleted within - // the transaction for each collection. An Insert increments the count and - // a Delete decrements it. When the transaction is committed, the - // internal count is added to the current count for each collection. - internalCounts map[*Collection]int -} - -// OpenGlobalTransaction opens and returns a new global transaction. While the -// transaction is open, no other state changes (e.g. Insert, Update, or Delete) -// can be made to the database (but concurrent reads are still allowed). This -// includes all collections. -// -// No new collections can be created while the global transaction is open. -// Calling NewCollection while the transaction is open will block until the -// transaction is committed or discarded. 
-// -// Transactions are atomic, meaning that either: -// -// (1) The transaction will succeed and *all* queued operations will be -// applied, or -// (2) the transaction will fail or be discarded, in which case *none* of -// the queued operations will be applied. -// -// The transaction must be closed once done, either by committing or discarding -// the transaction. No changes will be made to the database state until the -// transaction is committed. -func (db *DB) OpenGlobalTransaction() *GlobalTransaction { - // Note we acquire a Lock on the global write mutex. We're not really a - // "writer" but we behave like one in the context of an RWMutex. Up to one - // write lock for each collection can be held, or one global write lock can be - // held at any given time. - db.colLock.Lock() - db.globalWriteLock.Lock() - return &GlobalTransaction{ - db: db, - batchWriter: db.ldb, - readWriter: newReaderWithBatchWriter(db.ldb), - internalCounts: map[*Collection]int{}, - } -} - -// checkState acquires a lock on txn.mut and then calls unsafeCheckState. -func (txn *GlobalTransaction) checkState() error { - txn.mut.Lock() - defer txn.mut.Unlock() - return txn.unsafeCheckState() -} - -// unsafeCheckState checks the state of the transaction, assuming the caller has -// already acquired a lock. It returns an error if the transaction has already -// been committed or discarded. -func (txn *GlobalTransaction) unsafeCheckState() error { - if txn.discarded { - return ErrDiscarded - } else if txn.committed { - return ErrCommitted - } - return nil -} - -// Commit commits the transaction. If error is not nil, then the transaction is -// discarded. A new transaction must be created if you wish to retry the -// operations. -// -// Other methods should not be called after transaction has been committed. 
-func (txn *GlobalTransaction) Commit() error { - txn.mut.Lock() - defer txn.mut.Unlock() - if err := txn.unsafeCheckState(); err != nil { - return err - } - // Right before we commit, we need to update the count for each collection - // that was touched. - for col, internalCount := range txn.internalCounts { - if err := updateCountWithTransaction(col.info, txn.readWriter, int(internalCount)); err != nil { - _ = txn.Discard() - return err - } - } - if err := txn.batchWriter.Write(txn.readWriter.batch, nil); err != nil { - _ = txn.Discard() - return err - } - txn.committed = true - txn.db.globalWriteLock.Unlock() - txn.db.colLock.Unlock() - return nil -} - -// Discard discards the transaction. -// -// Other methods should not be called after transaction has been discarded. -// However, it is safe to call Discard multiple times. -func (txn *GlobalTransaction) Discard() error { - txn.mut.Lock() - defer txn.mut.Unlock() - if txn.committed { - return ErrCommitted - } - if txn.discarded { - return nil - } - txn.discarded = true - txn.db.globalWriteLock.Unlock() - txn.db.colLock.Unlock() - return nil -} - -// Insert queues an operation to insert the given model into the given -// collection. It returns an error if a model with the same id already exists. -// The model will not actually be inserted until the transaction is committed. -func (txn *GlobalTransaction) Insert(col *Collection, model Model) error { - if err := txn.checkState(); err != nil { - return err - } - if err := insertWithTransaction(col.info, txn.readWriter, model); err != nil { - return err - } - txn.updateInternalCount(col, 1) - return nil -} - -// Update queues an operation to update an existing model in the given -// collection. It returns an error if the given model doesn't already exist. The -// model will not actually be updated until the transaction is committed. 
-func (txn *GlobalTransaction) Update(col *Collection, model Model) error { - if err := txn.checkState(); err != nil { - return err - } - return updateWithTransaction(col.info, txn.readWriter, model) -} - -// Delete queues an operation to delete the model with the given ID from the -// given collection. It returns an error if the model doesn't exist in the -// database. The model will not actually be deleted until the transaction is -// committed. -func (txn *GlobalTransaction) Delete(col *Collection, id []byte) error { - if err := txn.checkState(); err != nil { - return err - } - if err := deleteWithTransaction(col.info, txn.readWriter, id); err != nil { - return err - } - txn.updateInternalCount(col, -1) - return nil -} - -func (txn *GlobalTransaction) updateInternalCount(col *Collection, diff int) { - txn.mut.Lock() - defer txn.mut.Unlock() - if existingCount, found := txn.internalCounts[col]; found { - txn.internalCounts[col] = existingCount + diff - } else { - txn.internalCounts[col] = diff - } -} diff --git a/db/global_transaction_test.go b/db/global_transaction_test.go deleted file mode 100644 index 941b22f51..000000000 --- a/db/global_transaction_test.go +++ /dev/null @@ -1,349 +0,0 @@ -package db - -import ( - "strconv" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestGlobalTransaction(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col0, err := db.NewCollection("people0", &testModel{}) - require.NoError(t, err) - col1, err := db.NewCollection("people1", &testModel{}) - require.NoError(t, err) - - // beforeTxnOpen is a set of testModels inserted before the transaction is opened. - beforeTxnOpen := []*testModel{} - for i := 0; i < 5; i++ { - model := &testModel{ - Name: "ExpectedPerson_" + strconv.Itoa(i), - Age: i, - } - require.NoError(t, col0.Insert(model)) - beforeTxnOpen = append(beforeTxnOpen, model) - } - - // Open a global transaction. 
- txn := db.OpenGlobalTransaction() - defer func() { - err := txn.Discard() - if err != nil && err != ErrCommitted { - t.Error(err) - } - }() - - // The WaitGroup will be used to wait for all goroutines to finish. - wg := &sync.WaitGroup{} - - // Any models we add to col0 after opening the transaction should not affect - // the database state until after is committed. - outsideTransaction := []*testModel{} - wg.Add(1) - go func() { - defer wg.Done() - for i := 0; i < 5; i++ { - model := &testModel{ - Name: "OutsideTransaction_" + strconv.Itoa(i), - Age: i, - } - require.NoError(t, col0.Insert(model)) - outsideTransaction = append(outsideTransaction, model) - } - }() - - // Any models we add to col0 within the transaction should not affect - // the database state until after it is committed. - insideTransaction := []*testModel{} - for i := 0; i < 5; i++ { - model := &testModel{ - Name: "InsideTransaction_" + strconv.Itoa(i), - Age: i, - } - require.NoError(t, txn.Insert(col0, model)) - insideTransaction = append(insideTransaction, model) - } - - // Any models we add to col1 after opening the transaction should not affect - // the database state until after it is committed. - wg.Add(1) - go func() { - defer wg.Done() - for i := 0; i < 5; i++ { - model := &testModel{ - Name: "OtherPerson_" + strconv.Itoa(i), - Age: i, - } - require.NoError(t, col1.Insert(model)) - } - }() - - // Any models we delete after opening the transaction should not affect - // the database state until after it is committed. - idToDelete := beforeTxnOpen[2].ID() - wg.Add(1) - go func(idToDelete []byte) { - defer wg.Done() - require.NoError(t, col0.Delete(idToDelete)) - }(idToDelete) - - // Attempting to add a new collection should block until after the transaction - // is committed/discarded. We use two channels to determine the order in which - // the two events occurred. - // commitSignal is fired right before the transaction is committed. 
- commitSignal := make(chan struct{}) - // newCollectionSignal is fired after the new collection has been created. - newCollectionSignal := make(chan struct{}) - - wg.Add(1) - go func() { - defer wg.Done() - _, err := db.NewCollection("people2", &testModel{}) - require.NoError(t, err) - // Signal that the new collection was created. - close(newCollectionSignal) - }() - - select { - case <-time.After(transactionExclusionTestTimeout): - // Expected outcome. Exit from select. - break - case <-newCollectionSignal: - t.Error("new collection was created before the transaction was committed") - return - } - - // Make sure that col0 only contains models that were created before the - // transaction was opened. - var actual []*testModel - require.NoError(t, col0.FindAll(&actual)) - assert.Equal(t, beforeTxnOpen, actual) - - // Make sure that col1 doesn't contain any models (since they were created - // after the transaction was opened). - actualCount, err := col1.Count() - require.NoError(t, err) - assert.Equal(t, 0, actualCount) - - // Signal that we are about to commit the transaction, then commit it. - close(commitSignal) - require.NoError(t, txn.Commit()) - - // Wait for any goroutines to finish. - wg.Wait() - - // Check that all the models are now written. - // TODO(albrow): Fix bug with Count and transactions, then we can use Count - // instead of FindAll here. - var existingModels []*testModel - require.NoError(t, col0.FindAll(&existingModels)) - assert.Len(t, existingModels, 14) - - col1PostTxnCount, err := col1.Count() - require.NoError(t, err) - assert.Equal(t, 5, col1PostTxnCount) -} - -func TestGlobalTransactionCount(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - - // insertedBeforeTransaction is a set of testModels inserted before the - // transaction is opened. 
- insertedBeforeTransaction := []*testModel{} - for i := 0; i < 10; i++ { - model := &testModel{ - Name: "Before_Transaction_" + strconv.Itoa(i), - Age: i, - } - require.NoError(t, col.Insert(model)) - insertedBeforeTransaction = append(insertedBeforeTransaction, model) - } - - // Open a global transaction. - txn := db.OpenGlobalTransaction() - defer func() { - err := txn.Discard() - if err != nil && err != ErrCommitted { - t.Error(err) - } - }() - - // Insert some models inside the transaction. - for i := 0; i < 7; i++ { - model := &testModel{ - Name: "Inside_Transaction_" + strconv.Itoa(i), - Age: i, - } - require.NoError(t, txn.Insert(col, model)) - } - - // The WaitGroup will be used to wait for all goroutines to finish. - wg := &sync.WaitGroup{} - - // Insert some models outside the transaction. - wg.Add(1) - go func() { - defer wg.Done() - for i := 0; i < 4; i++ { - model := &testModel{ - Name: "Outside_Transaction_" + strconv.Itoa(i), - Age: 42, - } - require.NoError(t, col.Insert(model)) - } - }() - - // Delete some models inside of the transaction. - idsToDeleteInside := [][]byte{ - insertedBeforeTransaction[0].ID(), - insertedBeforeTransaction[1].ID(), - insertedBeforeTransaction[2].ID(), - } - for _, id := range idsToDeleteInside { - require.NoError(t, txn.Delete(col, id)) - } - - // Delete some models outside of the transaction. - idsToDeleteOutside := [][]byte{ - insertedBeforeTransaction[3].ID(), - insertedBeforeTransaction[4].ID(), - } - wg.Add(1) - go func() { - defer wg.Done() - for _, id := range idsToDeleteOutside { - require.NoError(t, col.Delete(id)) - } - }() - - // Make sure that prior to commiting the transaction, Count only includes the - // models inserted/deleted before the transaction was open. - expectedPreCommitCount := 10 - actualPreCommitCount, err := col.Count() - require.NoError(t, err) - assert.Equal(t, expectedPreCommitCount, actualPreCommitCount) - - // Commit the transaction. 
- require.NoError(t, txn.Commit()) - - // Wait for any goroutines to finish. - wg.Wait() - - // Make sure that after commiting the transaction, Count includes the models - // inserted/deleted in the transaction and outside of the transaction. - // 10 before transaction. - // +7 inserted inside transaction - // +4 inserted outside transaction - // -3 deleted inside transaction - // -2 deleted outside transaction - // = 16 total - expectedPostCommitCount := 16 - actualPostCommitCount, err := col.Count() - require.NoError(t, err) - assert.Equal(t, expectedPostCommitCount, actualPostCommitCount) -} - -// TestGlobalTransactionExclusion is designed to test whether a global -// transaction has exclusive write access for all collections while open. -func TestGlobalTransactionExclusion(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col0, err := db.NewCollection("people0", &testModel{}) - require.NoError(t, err) - col1, err := db.NewCollection("people1", &testModel{}) - require.NoError(t, err) - - txn := db.OpenGlobalTransaction() - defer func() { - _ = txn.Discard() - }() - - // newGlobalTxnOpenSignal is fired when a new global transaction is opened. - newGlobalTxnOpenSignal := make(chan struct{}) - // col0TxnOpenSignal is fired when a transaction on col0 is opened. - col0TxnOpenSignal := make(chan struct{}) - // col1TxnOpenSignal is fired when a transaction on col1 is opened. 
- col1TxnOpenSignal := make(chan struct{}) - - wg := &sync.WaitGroup{} - wg.Add(1) - go func() { - defer wg.Done() - txn := col0.OpenTransaction() - close(col0TxnOpenSignal) - defer func() { - _ = txn.Discard() - }() - }() - - wg.Add(1) - go func() { - defer wg.Done() - txn := col1.OpenTransaction() - close(col1TxnOpenSignal) - defer func() { - _ = txn.Discard() - }() - }() - - wg.Add(1) - go func() { - defer wg.Done() - txn := db.OpenGlobalTransaction() - close(newGlobalTxnOpenSignal) - defer func() { - _ = txn.Discard() - }() - }() - - select { - case <-time.After(transactionExclusionTestTimeout): - // Expected outcome. Return from the goroutine. - return - case <-newGlobalTxnOpenSignal: - t.Error("a new global transaction was opened before the first was committed/discarded") - case <-col0TxnOpenSignal: - t.Error("a new transaction was opened on col0 before the global transaction was committed/discarded") - case <-col1TxnOpenSignal: - t.Error("a new transaction was opened on col1 before the global transaction was committed/discarded") - } - - // Discard the first global transaction. - require.NoError(t, txn.Discard()) - - // Check that col0 and col1 transactions are opened. - wasCol0TxnOpened := false - wasCol1TxnOpened := false - wasNewGlobalTxnOpened := false - txnOpenTimeout := time.After(transactionExclusionTestTimeout) - for { - if wasCol0TxnOpened && wasCol1TxnOpened && wasNewGlobalTxnOpened { - // All three transactions were opened. Break the for loop. - break - } - select { - case <-txnOpenTimeout: - t.Fatalf("timed out waiting for one or more transactions to open (tx0: %t, txn1: %t, global: %t)", wasCol0TxnOpened, wasCol1TxnOpened, wasNewGlobalTxnOpened) - case <-col0TxnOpenSignal: - wasCol0TxnOpened = true - case <-col1TxnOpenSignal: - wasCol1TxnOpened = true - case <-newGlobalTxnOpenSignal: - wasNewGlobalTxnOpened = true - } - } - - // Wait for all goroutines to exit. 
- wg.Wait() -} diff --git a/db/index.go b/db/index.go deleted file mode 100644 index e5a1019c0..000000000 --- a/db/index.go +++ /dev/null @@ -1,76 +0,0 @@ -package db - -import ( - "fmt" - "strings" -) - -// Index can be used to search for specific values or specific ranges of values -// for a collection. -type Index struct { - colInfo *colInfo - name string - getter func(m Model) [][]byte -} - -// AddIndex creates and returns a new index. name is an arbitrary, unique name -// for the index. getter is a function that accepts a model and returns the -// value for this particular index. For example, if you wanted to add an index -// on a struct field, getter should return the value of that field. After -// AddIndex is called, any new models in this collection that are inserted will -// be indexed. Any models inserted prior to calling AddIndex will *not* be -// indexed. Note that in order to function correctly, indexes must be based on -// data that is actually saved to the database (e.g. exported struct fields). -func (c *Collection) AddIndex(name string, getter func(Model) []byte) *Index { - // Internally, all indexes are treated as MultiIndexes. We wrap the given - // getter function so that it returns [][]byte instead of just []byte. - wrappedGetter := func(model Model) [][]byte { - return [][]byte{getter(model)} - } - return c.AddMultiIndex(name, wrappedGetter) -} - -// AddMultiIndex is like AddIndex but has the ability to index multiple values -// for the same model. For methods like FindWithRange and FindWithValue, the -// model will be included in the results if *any* of the values returned by the -// getter function satisfy the constraints. It is useful for representing -// one-to-many relationships. Any models inserted prior to calling AddMultiIndex -// will *not* be indexed. Note that in order to function correctly, indexes must -// be based on data that is actually saved to the database (e.g. exported struct fields). 
-func (c *Collection) AddMultiIndex(name string, getter func(Model) [][]byte) *Index { - c.info.indexMut.Lock() - defer c.info.indexMut.Unlock() - index := &Index{ - colInfo: c.info, - name: name, - getter: getter, - } - c.info.indexes = append(c.info.indexes, index) - return index -} - -// Name returns the name of the index. -func (index *Index) Name() string { - return index.name -} - -func (index *Index) prefix() []byte { - return []byte(fmt.Sprintf("index:%s:%s", index.colInfo.name, index.name)) -} - -func (index *Index) keysForModel(model Model) [][]byte { - values := index.getter(model) - indexKeys := make([][]byte, len(values)) - for i, value := range values { - indexKeys[i] = []byte(fmt.Sprintf("%s:%s:%s", index.prefix(), escape(value), escape(model.ID()))) - } - return indexKeys -} - -// primaryKeyFromIndexKey extracts and returns the primary key from the given index -// key. -func (index *Index) primaryKeyFromIndexKey(key []byte) []byte { - pkAndVal := strings.TrimPrefix(string(key), string(index.prefix())) - split := strings.Split(pkAndVal, ":") - return index.colInfo.primaryKeyForIDWithoutEscape([]byte(split[2])) -} diff --git a/db/index_test.go b/db/index_test.go deleted file mode 100644 index c2c7724d1..000000000 --- a/db/index_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package db - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestInsertWithIndex(t *testing.T) { - t.Parallel() - db := newTestDB(t) - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - col.AddIndex("age", func(m Model) []byte { - return []byte(fmt.Sprint(m.(*testModel).Age)) - }) - model := &testModel{ - Name: "foo", - Age: 42, - } - require.NoError(t, col.Insert(model)) - exists, err := db.ldb.Has([]byte("index:people:age:42:foo"), nil) - require.NoError(t, err) - assert.True(t, exists, "Index not stored in database at the expected key") -} - -func TestUpdateWithIndex(t *testing.T) { 
- t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - col.AddIndex("age", func(m Model) []byte { - return []byte(fmt.Sprint(m.(*testModel).Age)) - }) - model := &testModel{ - Name: "foo", - Age: 42, - } - require.NoError(t, col.Insert(model)) - updated := &testModel{ - Name: "foo", - Age: 43, - } - require.NoError(t, col.Update(updated)) - oldKeyExists, err := db.ldb.Has([]byte("index:people:age:42:foo"), nil) - require.NoError(t, err) - assert.False(t, oldKeyExists, "Old index was still stored after update") - updatedKeyExists, err := db.ldb.Has([]byte("index:people:age:43:foo"), nil) - require.NoError(t, err) - assert.True(t, updatedKeyExists, "Index not stored in database at the updated key") -} diff --git a/db/integrity_check.go b/db/integrity_check.go deleted file mode 100644 index 88164ae15..000000000 --- a/db/integrity_check.go +++ /dev/null @@ -1,99 +0,0 @@ -package db - -import ( - "encoding/json" - "fmt" - "reflect" - - "github.com/syndtr/goleveldb/leveldb" - "github.com/syndtr/goleveldb/leveldb/util" -) - -func (db *DB) CheckIntegrity() error { - db.colLock.Lock() - defer db.colLock.Unlock() - for _, col := range db.collections { - if err := db.checkCollectionIntegrity(col); err != nil { - return err - } - } - return nil -} - -func (db *DB) checkCollectionIntegrity(col *Collection) error { - col.info.indexMut.RLock() - defer col.info.indexMut.RUnlock() - - snapshot, err := col.GetSnapshot() - if err != nil { - return err - } - defer snapshot.Release() - - slice := util.BytesPrefix([]byte(fmt.Sprintf("%s:", col.info.prefix()))) - iter := snapshot.snapshot.NewIterator(slice, nil) - defer iter.Release() - for iter.Next() { - // Check that the model data can be unmarshaled into the expected type. 
- data := iter.Value() - modelVal := reflect.New(col.info.modelType) - if err := json.Unmarshal(data, modelVal.Interface()); err != nil { - return fmt.Errorf("integritiy check failed for collection %s: could not unmarshal model data for primary key %s: %s", col.Name(), iter.Key(), err.Error()) - } - model := modelVal.Elem().Interface().(Model) - - // Check that the index entries exist for this model. - for _, index := range col.info.indexes { - indexKeys := index.keysForModel(model) - for _, indexKey := range indexKeys { - indexKeyExists, err := snapshot.snapshot.Has(indexKey, nil) - if err != nil { - return err - } - if !indexKeyExists { - return fmt.Errorf("integritiy check failed for index %s.%s: indexKey %s does not exist", col.Name(), index.Name(), indexKey) - } - } - } - } - if err := iter.Error(); err != nil { - return err - } - - // Check the integrity of each index. - for _, index := range col.info.indexes { - if err := db.checkIndexIntegrity(snapshot, col, index); err != nil { - return err - } - } - - return nil -} - -// checkIndexIntegrity checks that each key in the index corresponds to model -// data that exists and is valid (can be unmarshaled into a model of the -// expected type). 
-func (db *DB) checkIndexIntegrity(snapshot *Snapshot, col *Collection, index *Index) error { - slice := util.BytesPrefix([]byte(fmt.Sprintf("%s:", index.prefix()))) - iter := snapshot.snapshot.NewIterator(slice, nil) - defer iter.Release() - for iter.Next() { - pk := index.primaryKeyFromIndexKey(iter.Key()) - data, err := snapshot.snapshot.Get(pk, nil) - if err != nil { - if err == leveldb.ErrNotFound { - return fmt.Errorf("integritiy check failed for index %s.%s: key exists in index but could not find corresponding model data for primary key: %s", col.Name(), index.Name(), pk) - } else { - return err - } - } - modelVal := reflect.New(col.info.modelType) - if err := json.Unmarshal(data, modelVal.Interface()); err != nil { - return fmt.Errorf("integritiy check failed for index %s.%s: could not unmarshal model data: %s", col.Name(), index.Name(), err.Error()) - } - } - if err := iter.Error(); err != nil { - return err - } - return nil -} diff --git a/db/integrity_check_test.go b/db/integrity_check_test.go deleted file mode 100644 index eeb1c7f16..000000000 --- a/db/integrity_check_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package db - -import ( - "fmt" - "strconv" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestIntegrityCheckPass(t *testing.T) { - t.Parallel() - db, _, _, _ := setUpIntegrityCheckTest(t) - - // We didn't break anything so the integrity check should pass - require.NoError(t, db.CheckIntegrity()) -} - -func TestIntegrityCheckInvalidModelData(t *testing.T) { - t.Parallel() - db, col, models, _ := setUpIntegrityCheckTest(t) - defer db.Close() - - // Manually break integrity by storing invalid model data. 
- keyToChange := col.info.primaryKeyForModel(models[0]) - require.NoError(t, db.ldb.Put(keyToChange, []byte("invalid data"), nil)) - expectedError := "integritiy check failed for collection people: could not unmarshal model data for primary key model:people:Person_0: invalid character 'i' looking for beginning of value" - require.EqualError(t, db.CheckIntegrity(), expectedError) -} - -func TestIntegrityCheckIndexKeyWithoutModelData(t *testing.T) { - t.Parallel() - db, col, models, _ := setUpIntegrityCheckTest(t) - defer db.Close() - - // Manually break integrity by deleting a primary key. - keyToDelete := col.info.primaryKeyForModel(models[0]) - require.NoError(t, db.ldb.Delete(keyToDelete, nil)) - expectedError := "integritiy check failed for index people.age: key exists in index but could not find corresponding model data for primary key: model:people:Person_0" - require.EqualError(t, db.CheckIntegrity(), expectedError) -} - -func TestIntegrityCheckModelNotIndexed(t *testing.T) { - t.Parallel() - db, _, models, ageIndex := setUpIntegrityCheckTest(t) - defer db.Close() - - // Manually break integrity by deleting an index key. 
- keyToDelete := ageIndex.keysForModel(models[0])[0] - require.NoError(t, db.ldb.Delete(keyToDelete, nil)) - expectedError := "integritiy check failed for index people.age: indexKey index:people:age:0:Person_0 does not exist" - require.EqualError(t, db.CheckIntegrity(), expectedError) -} - -func setUpIntegrityCheckTest(t *testing.T) (*DB, *Collection, []*testModel, *Index) { - db := newTestDB(t) - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - ageIndex := col.AddIndex("age", func(m Model) []byte { - return []byte(fmt.Sprint(m.(*testModel).Age)) - }) - - // Insert some test models - models := []*testModel{} - for i := 0; i < 5; i++ { - model := &testModel{ - Name: "Person_" + strconv.Itoa(i), - Age: i, - } - require.NoError(t, col.Insert(model)) - models = append(models, model) - } - - return db, col, models, ageIndex -} diff --git a/db/interfaces.go b/db/interfaces.go deleted file mode 100644 index eb023d968..000000000 --- a/db/interfaces.go +++ /dev/null @@ -1,32 +0,0 @@ -package db - -import ( - "github.com/syndtr/goleveldb/leveldb" - "github.com/syndtr/goleveldb/leveldb/opt" -) - -// dbReader is an interface that encapsulates read-only functionality. -type dbReader interface { - leveldb.Reader - Has(key []byte, ro *opt.ReadOptions) (bool, error) -} - -// dbWriter is an interface that encapsulates write/update functionality. 
-type dbWriter interface { - Delete(key []byte, wo *opt.WriteOptions) error - Put(key, value []byte, wo *opt.WriteOptions) error -} - -type dbBatchWriter interface { - Write(batch *leveldb.Batch, ro *opt.WriteOptions) error -} - -type dbReadWriter interface { - dbReader - dbWriter -} - -type dbReadBatchWriter interface { - dbReadWriter - dbBatchWriter -} diff --git a/db/open.go b/db/open.go deleted file mode 100644 index c04bd3c42..000000000 --- a/db/open.go +++ /dev/null @@ -1,17 +0,0 @@ -// +build !js - -package db - -import "github.com/syndtr/goleveldb/leveldb" - -// Open creates a new database using the given file path for permanent storage. -// It is not safe to have multiple DBs using the same file path. -func Open(path string) (*DB, error) { - ldb, err := leveldb.OpenFile(path, nil) - if err != nil { - return nil, err - } - return &DB{ - ldb: ldb, - }, nil -} diff --git a/db/open_js.go b/db/open_js.go deleted file mode 100644 index eb578bcba..000000000 --- a/db/open_js.go +++ /dev/null @@ -1,77 +0,0 @@ -// +build js,wasm - -package db - -import ( - "errors" - "syscall/js" - "time" - - "github.com/0xProject/0x-mesh/packages/browser/go/jsutil" - log "github.com/sirupsen/logrus" - "github.com/syndtr/goleveldb/leveldb" - "github.com/syndtr/goleveldb/leveldb/storage" -) - -const ( - // browserFSLoadCheckInterval is frequently to check whether browserFS is - // loaded. - browserFSLoadCheckInterval = 50 * time.Millisecond - // browserFSLoadTimeout is how long to wait for BrowserFS to finish loading - // before giving up. - browserFSLoadTimeout = 5 * time.Second -) - -// Open creates a new database for js/wasm environments. -func Open(path string) (*DB, error) { - // The global willLoadBrowserFS variable indicates whether browserFS will be - // loaded. browserFS has to be explicitly loaded in by JavaScript (and - // typically Webpack) and can't be loaded here. 
- if willLoadBrowserFS := js.Global().Get("willLoadBrowserFS"); !jsutil.IsNullOrUndefined(willLoadBrowserFS) && willLoadBrowserFS.Bool() == true { - return openBrowserFSDB(path) - } - // If browserFS is not going to be loaded, fallback to using an in-memory - // database. - return openInMemoryDB() -} - -func openInMemoryDB() (*DB, error) { - log.Warn("BrowserFS not detected. Using in-memory databse.") - ldb, err := leveldb.Open(storage.NewMemStorage(), nil) - if err != nil { - return nil, err - } - return &DB{ - ldb: ldb, - }, nil -} - -func openBrowserFSDB(path string) (*DB, error) { - log.Info("BrowserFS detected. Using BrowserFS-backed databse.") - // Wait for browserFS to load. - // - // HACK(albrow): We do this by checking for the global browserFS - // variable. This is definitely a bit of a hack and wastes some CPU resources, - // but it is also extremely reliable. Given that we have a chicken and egg - // problem with both Wasm and JavaScript code loading and executing at the - // same time, it is difficult to match this level of reliability with something - // like callback functions or events. 
- start := time.Now() - for { - if time.Since(start) >= browserFSLoadTimeout { - return nil, errors.New("timed out waiting for BrowserFS to load") - } - if !jsutil.IsNullOrUndefined(js.Global().Get("browserFS")) { - log.Info("BrowserFS finished loading") - break - } - time.Sleep(browserFSLoadCheckInterval) - } - ldb, err := leveldb.OpenFile(path, nil) - if err != nil { - return nil, err - } - return &DB{ - ldb: ldb, - }, nil -} diff --git a/db/operations.go b/db/operations.go deleted file mode 100644 index 7e4b9bf27..000000000 --- a/db/operations.go +++ /dev/null @@ -1,235 +0,0 @@ -package db - -import ( - "encoding/json" - "errors" - "fmt" - "reflect" - "strconv" - - "github.com/syndtr/goleveldb/leveldb" - "github.com/syndtr/goleveldb/leveldb/iterator" - "github.com/syndtr/goleveldb/leveldb/util" -) - -func findByID(info *colInfo, reader dbReader, id []byte, model Model) error { - if err := info.checkModelType(model); err != nil { - return err - } - pk := info.primaryKeyForID(id) - data, err := reader.Get(pk, nil) - if err != nil { - if err == leveldb.ErrNotFound { - return NotFoundError{ID: id} - } - return err - } - return json.Unmarshal(data, model) -} - -func findAll(info *colInfo, reader dbReader, models interface{}) error { - prefixRange := util.BytesPrefix([]byte(fmt.Sprintf("%s:", info.prefix()))) - iter := reader.NewIterator(prefixRange, nil) - return findWithIterator(info, iter, models) -} - -func findWithIterator(info *colInfo, iter iterator.Iterator, models interface{}) error { - defer iter.Release() - if err := info.checkModelsType(models); err != nil { - return err - } - modelsVal := reflect.ValueOf(models).Elem() - for iter.Next() && iter.Error() == nil { - // We assume that each value in the iterator is the encoded data for some - // model. 
- data := iter.Value() - model := reflect.New(info.modelType) - if err := json.Unmarshal(data, model.Interface()); err != nil { - return err - } - modelsVal.Set(reflect.Append(modelsVal, model.Elem())) - } - return nil -} - -// findExistingModelByPrimaryKeyWithTransaction gets the latest data for the -// given primary key. Useful in cases where the given model may be out of date -// with what is currently stored in the database. It *doesn't* discard the -// transaction if there is an error. -func findExistingModelByPrimaryKeyWithTransaction(info *colInfo, readWriter dbReadWriter, primaryKey []byte) (Model, error) { - data, err := readWriter.Get(primaryKey, nil) - if err != nil { - return nil, err - } - // Use reflect to create a new reference for the model type. - modelRef := reflect.New(info.modelType).Interface() - if err := json.Unmarshal(data, modelRef); err != nil { - return nil, err - } - model := reflect.ValueOf(modelRef).Elem().Interface().(Model) - return model, nil -} - -func insertWithTransaction(info *colInfo, readWriter dbReadWriter, model Model) error { - if len(model.ID()) == 0 { - return errors.New("can't insert model with empty ID") - } - if err := info.checkModelType(model); err != nil { - return err - } - data, err := json.Marshal(model) - if err != nil { - return err - } - pk := info.primaryKeyForModel(model) - if exists, err := readWriter.Has(pk, nil); err != nil { - return err - } else if exists { - return AlreadyExistsError{ID: model.ID()} - } - if err := readWriter.Put(pk, data, nil); err != nil { - return err - } - if err := saveIndexesWithTransaction(info, readWriter, model); err != nil { - return err - } - return nil -} - -func updateWithTransaction(info *colInfo, readWriter dbReadWriter, model Model) error { - if len(model.ID()) == 0 { - return errors.New("can't update model with empty ID") - } - if err := info.checkModelType(model); err != nil { - return err - } - - // Check if the model already exists and return an error if not. 
- pk := info.primaryKeyForModel(model) - if exists, err := readWriter.Has(pk, nil); err != nil { - return err - } else if !exists { - return NotFoundError{ID: model.ID()} - } - - // Get the existing data for the model and delete any (now outdated) indexes. - existingModel, err := findExistingModelByPrimaryKeyWithTransaction(info, readWriter, pk) - if err != nil { - return err - } - if err := deleteIndexesWithTransaction(info, readWriter, existingModel); err != nil { - return err - } - - // Save the new data and add the new indexes. - newData, err := json.Marshal(model) - if err != nil { - return err - } - if err := readWriter.Put(pk, newData, nil); err != nil { - return err - } - if err := saveIndexesWithTransaction(info, readWriter, model); err != nil { - return err - } - return nil -} - -func deleteWithTransaction(info *colInfo, readWriter dbReadWriter, id []byte) error { - if len(id) == 0 { - return errors.New("can't delete model with empty ID") - } - - // We need to get the latest data because the given model might be out of sync - // with the actual data in the database. - pk := info.primaryKeyForID(id) - latest, err := findExistingModelByPrimaryKeyWithTransaction(info, readWriter, pk) - if err != nil { - if err == leveldb.ErrNotFound { - return NotFoundError{ID: id} - } - return err - } - - // Delete the primary key. - if err := readWriter.Delete(pk, nil); err != nil { - return err - } - - // Delete any index entries. - if err := deleteIndexesWithTransaction(info, readWriter, latest); err != nil { - return err - } - - return nil -} - -func saveIndexesWithTransaction(info *colInfo, readWriter dbReadWriter, model Model) error { - info.indexMut.RLock() - defer info.indexMut.RUnlock() - for _, index := range info.indexes { - keys := index.keysForModel(model) - for _, key := range keys { - if err := readWriter.Put(key, nil, nil); err != nil { - return err - } - } - } - return nil -} - -// deleteIndexesForModel deletes any indexes computed from the given model. 
It -// *doesn't* discard the transaction if there is an error. -func deleteIndexesWithTransaction(info *colInfo, readWriter dbReadWriter, model Model) error { - info.indexMut.RLock() - defer info.indexMut.RUnlock() - for _, index := range info.indexes { - keys := index.keysForModel(model) - for _, key := range keys { - if err := readWriter.Delete(key, nil); err != nil { - return err - } - } - } - return nil -} - -func count(info *colInfo, reader dbReader) (int, error) { - encodedCount, err := reader.Get(info.countKey(), nil) - if err != nil { - if err == leveldb.ErrNotFound { - // If countKey doesn't exist, assume no models have been inserted and - // return a count of 0. - return 0, nil - } - return 0, err - } - count, err := decodeInt(encodedCount) - if err != nil { - return 0, err - } - return count, nil -} - -func updateCountWithTransaction(info *colInfo, readWriter dbReadWriter, diff int) error { - existingCount, err := count(info, readWriter) - if err != nil { - return err - } - newCount := existingCount + diff - if newCount == 0 { - return readWriter.Delete(info.countKey(), nil) - } else { - return readWriter.Put(info.countKey(), encodeInt(newCount), nil) - } -} - -func encodeInt(i int) []byte { - // TODO(albrow): Could potentially be optimized. - return []byte(strconv.Itoa(i)) -} - -func decodeInt(b []byte) (int, error) { - // TODO(albrow): Could potentially be optimized. - return strconv.Atoi(string(b)) -} diff --git a/db/query.go b/db/query.go deleted file mode 100644 index 08df238a7..000000000 --- a/db/query.go +++ /dev/null @@ -1,214 +0,0 @@ -package db - -import ( - "encoding/json" - "fmt" - "reflect" - - "github.com/syndtr/goleveldb/leveldb" - - "github.com/albrow/stringset" - "github.com/syndtr/goleveldb/leveldb/iterator" - "github.com/syndtr/goleveldb/leveldb/util" -) - -// Query is used to return certain results from the database. 
-type Query struct { - colInfo *colInfo - reader dbReader - filter *Filter - max int - offset int - reverse bool -} - -// Filter determines which models to return in the query and what order to -// return them in. -type Filter struct { - index *Index - slice *util.Range -} - -func newQuery(colInfo *colInfo, reader dbReader, filter *Filter) *Query { - return &Query{ - colInfo: colInfo, - reader: reader, - filter: filter, - } -} - -// Max causes the query to only return up to max results. It is the analog of -// the LIMIT keyword in SQL: -// https://www.postgresql.org/docs/current/queries-limit.html -func (q *Query) Max(max int) *Query { - q.max = max - return q -} - -// Reverse causes the query to return models in descending byte order according -// to their index values instead of the default (ascending byte order). -func (q *Query) Reverse() *Query { - q.reverse = true - return q -} - -// Offset causes the query to skip offset models when iterating through models -// that match the query. Note that queries which use an offset have a runtime -// of O(max(K, offset) + N), where N is the number of models returned by the -// query and K is the total number of keys in the corresponding index. Queries -// with a high offset can take a long time to run, regardless of the number of -// models returned. This is due to limitations of the underlying database. -// Offset is the analog of the OFFSET keyword in SQL: -// https://www.postgresql.org/docs/current/queries-limit.html -func (q *Query) Offset(offset int) *Query { - q.offset = offset - return q -} - -// ValueFilter returns a Filter which will match all models with an index value -// equal to the given value. -func (index *Index) ValueFilter(val []byte) *Filter { - prefix := []byte(fmt.Sprintf("%s:%s:", index.prefix(), escape(val))) - return &Filter{ - index: index, - slice: util.BytesPrefix(prefix), - } -} - -// RangeFilter returns a Filter which will match all models with an index value -// >= start and < limit. 
-func (index *Index) RangeFilter(start []byte, limit []byte) *Filter { - startWithPrefix := []byte(fmt.Sprintf("%s:%s", index.prefix(), escape(start))) - limitWithPrefix := []byte(fmt.Sprintf("%s:%s", index.prefix(), escape(limit))) - slice := &util.Range{Start: startWithPrefix, Limit: limitWithPrefix} - return &Filter{ - index: index, - slice: slice, - } -} - -// PrefixFilter returns a Filter which will match all models with an index value -// that starts with the given prefix. -func (index *Index) PrefixFilter(prefix []byte) *Filter { - keyPrefix := []byte(fmt.Sprintf("%s:%s", index.prefix(), escape(prefix))) - return &Filter{ - index: index, - slice: util.BytesPrefix(keyPrefix), - } -} - -// All returns a Filter which will match all models. It is useful for when you -// want to retrieve models in sorted order without excluding any of them. -func (index *Index) All() *Filter { - return index.PrefixFilter([]byte{}) -} - -// Run runs the query and scans the results into models. models should be a -// pointer to an empty slice of a concrete model type (e.g. *[]myModelType). It -// returns an error if models is the wrong type or there was a problem reading -// from the database. It does not return an error if no models match the query. -func (q *Query) Run(models interface{}) error { - if err := q.colInfo.checkModelsType(models); err != nil { - return err - } - - iter := q.reader.NewIterator(q.filter.slice, nil) - defer iter.Release() - if q.reverse { - return q.getModelsWithIteratorReverse(iter, models) - } - return q.getModelsWithIteratorForward(iter, models) -} - -// Count returns the number of unique models that match the query. It does not -// return an error if no models match the query. Note that this method *does* -// respect q.Max. If the number of models that match the filter is greater than -// q.Max, it will stop counting and return q.Max. 
-func (q *Query) Count() (int, error) { - iter := q.reader.NewIterator(q.filter.slice, nil) - defer iter.Release() - pkSet := stringset.New() - for i := 0; iter.Next() && iter.Error() == nil; i++ { - if i < q.offset { - continue - } - pk := q.filter.index.primaryKeyFromIndexKey(iter.Key()) - pkSet.Add(string(pk)) - if q.max != 0 && len(pkSet) >= q.max { - break - } - } - if iter.Error() != nil { - return 0, iter.Error() - } - return len(pkSet), nil -} - -func (q *Query) getModelsWithIteratorForward(iter iterator.Iterator, models interface{}) error { - // MultiIndexes can result in the same model being included more than once. To - // prevent this, we keep track of the primaryKeys we have already seen using - // pkSet. - pkSet := stringset.New() - modelsVal := reflect.ValueOf(models).Elem() - for i := 0; iter.Next() && iter.Error() == nil; i++ { - if i < q.offset { - continue - } - if err := q.getAndAppendModelIfUnique(q.filter.index, pkSet, iter.Key(), modelsVal); err != nil { - return err - } - if q.max != 0 && modelsVal.Len() >= q.max { - return iter.Error() - } - } - return iter.Error() -} - -func (q *Query) getModelsWithIteratorReverse(iter iterator.Iterator, models interface{}) error { - pkSet := stringset.New() - modelsVal := reflect.ValueOf(models).Elem() - // Move the iterator to the last key and then iterate backwards by calling - // Prev instead of Next for each iteration of the for loop. 
- iter.Last() - iter.Next() - for i := 0; iter.Prev() && iter.Error() == nil; i++ { - if i < q.offset { - continue - } - if err := q.getAndAppendModelIfUnique(q.filter.index, pkSet, iter.Key(), modelsVal); err != nil { - return err - } - if q.max != 0 && modelsVal.Len() >= q.max { - return iter.Error() - } - } - return iter.Error() -} - -func (q *Query) getAndAppendModelIfUnique(index *Index, pkSet stringset.Set, key []byte, modelsVal reflect.Value) error { - // We assume that each key in the iterator consists of an index prefix, the - // value for a particular model, and the model ID. We can extract a primary - // key from this key and use it to get the encoded data for the model - // itself. - pk := index.primaryKeyFromIndexKey(key) - if pkSet.Contains(string(pk)) { - return nil - } - pkSet.Add(string(pk)) - data, err := q.reader.Get(pk, nil) - if err == leveldb.ErrNotFound || data == nil { - // It is possible that a separate goroutine deleted the model while we were - // iterating through the keys in the index. This is not considered an error. - // We simply don't include this model in the final results. 
- return nil - } - if err != nil { - return err - } - model := reflect.New(q.colInfo.modelType) - if err := json.Unmarshal(data, model.Interface()); err != nil { - return err - } - modelsVal.Set(reflect.Append(modelsVal, model.Elem())) - return nil -} diff --git a/db/query_benchmark_test.go b/db/query_benchmark_test.go deleted file mode 100644 index 4eb1b9b5d..000000000 --- a/db/query_benchmark_test.go +++ /dev/null @@ -1,167 +0,0 @@ -package db - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/require" -) - -const defaultTargetNickname = "target" - -func setupQueryBenchmark(b *testing.B) (db *DB, col *Collection, nicknameIndex *Index) { - b.Helper() - db = newTestDB(b) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(b, err) - nicknameIndex = col.AddMultiIndex("nicknames", func(m Model) [][]byte { - person := m.(*testModel) - indexValues := make([][]byte, len(person.Nicknames)) - for i, nickname := range person.Nicknames { - indexValues[i] = []byte(nickname) - } - return indexValues - }) - return db, col, nicknameIndex -} - -func insertModelsForQueryBenchmark(b *testing.B, col *Collection, targetNickname string, targetCount int, otherCount int) { - txn := col.OpenTransaction() - - defer func() { - _ = txn.Discard() - }() - // Insert targetCount models with nickname = targetNickname - for i := 0; i < targetCount; i++ { - model := &testModel{ - Name: fmt.Sprintf("person_%d", i), - Age: i, - Nicknames: []string{targetNickname}, - } - require.NoError(b, txn.Insert(model)) - } - // Insert otherCount with nickname != targetnickName - for i := 0; i < otherCount; i++ { - model := &testModel{ - Name: fmt.Sprintf("person_%d", i), - Age: i, - Nicknames: []string{fmt.Sprintf("not_%s_%d", targetNickname, i)}, - } - require.NoError(b, txn.Insert(model)) - } - require.NoError(b, txn.Commit()) -} - -func benchmarkQueryFind(b *testing.B, targetCount int, total int) { - benchmarkQueryFindWithMaxAndOffset(b, targetCount, 
total, 0, 0) -} - -func BenchmarkQueryFind100OutOf100(b *testing.B) { - benchmarkQueryFind(b, 100, 100) -} - -func BenchmarkQueryFind100OutOf1000(b *testing.B) { - benchmarkQueryFind(b, 100, 1000) -} - -func BenchmarkQueryFind100OutOf10000(b *testing.B) { - benchmarkQueryFind(b, 100, 10000) -} - -func BenchmarkQueryFind1000OutOf1000(b *testing.B) { - benchmarkQueryFind(b, 1000, 1000) -} - -func BenchmarkQueryFind1000OutOf10000(b *testing.B) { - benchmarkQueryFind(b, 1000, 10000) -} - -func BenchmarkQueryFind10000OutOf10000(b *testing.B) { - benchmarkQueryFind(b, 10000, 10000) -} - -func benchmarkQueryCount(b *testing.B, targetCount int, total int) { - db, col, nicknameIndex := setupQueryBenchmark(b) - defer db.Close() - insertModelsForQueryBenchmark(b, col, defaultTargetNickname, targetCount, total-targetCount) - b.ResetTimer() - for i := 0; i < b.N; i++ { - query := col.NewQuery(nicknameIndex.ValueFilter([]byte(defaultTargetNickname))) - _, err := query.Count() - b.StopTimer() - require.NoError(b, err) - b.StartTimer() - } -} - -func BenchmarkQueryCount100OutOf100(b *testing.B) { - benchmarkQueryCount(b, 100, 100) -} - -func BenchmarkQueryCount100OutOf1000(b *testing.B) { - benchmarkQueryCount(b, 100, 1000) -} - -func BenchmarkQueryCount100OutOf10000(b *testing.B) { - benchmarkQueryCount(b, 100, 10000) -} - -func BenchmarkQueryCount1000OutOf1000(b *testing.B) { - benchmarkQueryCount(b, 1000, 1000) -} - -func BenchmarkQueryCount1000OutOf10000(b *testing.B) { - benchmarkQueryCount(b, 1000, 10000) -} - -func BenchmarkQueryCount10000OutOf10000(b *testing.B) { - benchmarkQueryCount(b, 10000, 10000) -} - -func benchmarkQueryFindWithMaxAndOffset(b *testing.B, targetCount int, total int, max int, offset int) { - db, col, nicknameIndex := setupQueryBenchmark(b) - defer db.Close() - insertModelsForQueryBenchmark(b, col, defaultTargetNickname, targetCount, total-targetCount) - b.ResetTimer() - for i := 0; i < b.N; i++ { - query := 
col.NewQuery(nicknameIndex.ValueFilter([]byte(defaultTargetNickname))) - var actual []*testModel - err := query.Max(max).Offset(offset).Run(&actual) - b.StopTimer() - require.NoError(b, err) - b.StartTimer() - } -} - -func BenchmarkQueryFind1000OutOf10000Max100Offset0(b *testing.B) { - benchmarkQueryFindWithMaxAndOffset(b, 1000, 10000, 100, 0) -} - -func BenchmarkQueryFind1000OutOf10000Max100Offset100(b *testing.B) { - benchmarkQueryFindWithMaxAndOffset(b, 1000, 10000, 100, 100) -} - -func BenchmarkQueryFind1000OutOf10000Max100Offset900(b *testing.B) { - benchmarkQueryFindWithMaxAndOffset(b, 1000, 10000, 100, 900) -} - -func BenchmarkQueryFind1000OutOf10000Max100Offset1000(b *testing.B) { - benchmarkQueryFindWithMaxAndOffset(b, 1000, 10000, 100, 1000) -} - -func BenchmarkQueryFind10000OutOf10000Max1000Offset0(b *testing.B) { - benchmarkQueryFindWithMaxAndOffset(b, 10000, 10000, 1000, 0) -} - -func BenchmarkQueryFind10000OutOf10000Max1000Offset1000(b *testing.B) { - benchmarkQueryFindWithMaxAndOffset(b, 10000, 10000, 1000, 1000) -} - -func BenchmarkQueryFind10000OutOf10000Max1000Offset9000(b *testing.B) { - benchmarkQueryFindWithMaxAndOffset(b, 10000, 10000, 1000, 9000) -} - -func BenchmarkQueryFind10000OutOf10000Max1000Offset10000(b *testing.B) { - benchmarkQueryFindWithMaxAndOffset(b, 10000, 10000, 1000, 10000) -} diff --git a/db/query_test.go b/db/query_test.go deleted file mode 100644 index 959ef3cb4..000000000 --- a/db/query_test.go +++ /dev/null @@ -1,327 +0,0 @@ -package db - -import ( - "fmt" - "math" - "strconv" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestQueryWithValue(t *testing.T) { - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - - ageIndex := col.AddIndex("age", func(m Model) []byte { - return []byte(fmt.Sprint(m.(*testModel).Age)) - }) - - // expected is a set of testModels with Age = 42 - expected := []*testModel{} 
- for i := 0; i < 5; i++ { - model := &testModel{ - Name: "ExpectedPerson_" + strconv.Itoa(i), - Age: 42, - } - require.NoError(t, col.Insert(model)) - expected = append(expected, model) - } - - // We also insert some other models with a different age. - for i := 0; i < 5; i++ { - model := &testModel{ - Name: "OtherPerson_" + strconv.Itoa(i), - Age: i, - } - require.NoError(t, col.Insert(model)) - } - - // Save one more model with an Age that is a prefix of the target age. - model := &testModel{ - Name: "PersonWithPrefixAge", - Age: 420, - } - require.NoError(t, col.Insert(model)) - - filter := ageIndex.ValueFilter([]byte("42")) - testQueryWithFilter(t, col, filter, expected) -} - -func TestQueryWithRange(t *testing.T) { - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - - ageIndex := col.AddIndex("age", func(m Model) []byte { - return []byte(fmt.Sprint(m.(*testModel).Age)) - }) - - all := []*testModel{} - for i := 0; i < 5; i++ { - model := &testModel{ - Name: "Person_" + strconv.Itoa(i), - Age: i, - } - require.NoError(t, col.Insert(model)) - all = append(all, model) - } - // expected is the set of people with 1 <= age < 4 - expected := all[1:4] - filter := ageIndex.RangeFilter([]byte("1"), []byte("4")) - testQueryWithFilter(t, col, filter, expected) -} - -func TestQueryWithPrefix(t *testing.T) { - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - - ageIndex := col.AddIndex("age", func(m Model) []byte { - return []byte(fmt.Sprint(m.(*testModel).Age)) - }) - - // expected is a set of testModels with an age that starts with "2" - expected := []*testModel{ - { - Name: "ExpectedPerson_0", - Age: 2021, - }, - { - Name: "ExpectedPerson_1", - Age: 22, - }, - { - Name: "ExpectedPerson_2", - Age: 250, - }, - } - for _, model := range expected { - require.NoError(t, col.Insert(model)) - } - - // We also insert some other models with 
different ages. - excluded := []*testModel{ - { - Name: "ExcludedPerson_0", - Age: 40, - }, - { - Name: "ExcludedPerson_1", - Age: 41, - }, - { - Name: "ExcludedPerson_2", - Age: 42, - }, - } - for _, model := range excluded { - require.NoError(t, col.Insert(model)) - } - { - filter := ageIndex.PrefixFilter([]byte("2")) - testQueryWithFilter(t, col, filter, expected) - } - { - // An empty prefix should return all models. - all := append(expected, excluded...) - filter := ageIndex.PrefixFilter([]byte{}) - testQueryWithFilter(t, col, filter, all) - } -} - -func TestFindWithValueWithMultiIndex(t *testing.T) { - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - nicknameIndex := col.AddMultiIndex("nicknames", func(m Model) [][]byte { - person := m.(*testModel) - indexValues := make([][]byte, len(person.Nicknames)) - for i, nickname := range person.Nicknames { - indexValues[i] = []byte(nickname) - } - return indexValues - }) - - // expected is a set of testModels that include the nickname "Bob" - expected := []*testModel{ - { - Name: "ExpectedPerson_0", - Age: 42, - Nicknames: []string{"Bob", "Jim", "John"}, - }, - { - Name: "ExpectedPerson_1", - Age: 43, - Nicknames: []string{"Alice", "Bob", "Emily"}, - }, - { - Name: "ExpectedPerson_2", - Age: 44, - Nicknames: []string{"Bob", "No one"}, - }, - } - for _, model := range expected { - require.NoError(t, col.Insert(model)) - } - - // We also insert some other models with different nicknames. 
- excluded := []*testModel{ - { - Name: "ExcludedPerson_0", - Age: 42, - Nicknames: []string{"Bill", "Jim", "John"}, - }, - { - Name: "ExcludedPerson_1", - Age: 43, - Nicknames: []string{"Alice", "Jane", "Emily"}, - }, - { - Name: "ExcludedPerson_2", - Age: 44, - Nicknames: []string{"Nemo", "No one"}, - }, - } - for _, model := range excluded { - require.NoError(t, col.Insert(model)) - } - - filter := nicknameIndex.ValueFilter([]byte("Bob")) - testQueryWithFilter(t, col, filter, expected) -} - -func TestFindWithRangeWithMultiIndex(t *testing.T) { - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - nicknameIndex := col.AddMultiIndex("nicknames", func(m Model) [][]byte { - person := m.(*testModel) - indexValues := make([][]byte, len(person.Nicknames)) - for i, nickname := range person.Nicknames { - indexValues[i] = []byte(nickname) - } - return indexValues - }) - - // expected is a set of testModels that include at least one nickname that - // satisfies "B" <= nickname < "E" - expected := []*testModel{ - { - Name: "ExpectedPerson_0", - Age: 42, - Nicknames: []string{"Alice", "Beth", "Emily"}, - }, - { - Name: "ExpectedPerson_1", - Age: 43, - Nicknames: []string{"Bob", "Charles", "Dan"}, - }, - { - Name: "ExpectedPerson_2", - Age: 44, - Nicknames: []string{"James", "Darell"}, - }, - } - for _, model := range expected { - require.NoError(t, col.Insert(model)) - } - - // We also insert some other models with different nicknames. 
- excluded := []*testModel{ - { - Name: "ExcludedPerson_0", - Age: 42, - Nicknames: []string{"Allen", "Jim", "John"}, - }, - { - Name: "ExcludedPerson_1", - Age: 43, - Nicknames: []string{"Sophia", "Jane", "Emily"}, - }, - { - Name: "ExcludedPerson_2", - Age: 44, - Nicknames: []string{"Nemo", "No one"}, - }, - } - for _, model := range excluded { - require.NoError(t, col.Insert(model)) - } - - filter := nicknameIndex.RangeFilter([]byte("B"), []byte("E")) - testQueryWithFilter(t, col, filter, expected) -} - -// testQueryWithFilter runs a comprehensive set of queries based on the given -// filter and checks that the results are always what we expect. -func testQueryWithFilter(t *testing.T, col *Collection, filter *Filter, expected []*testModel) { - reverseExpected := reverseSlice(expected) - // safeMax is min(2, len(expected)) to account for the fact that expected may - // have length shorter than 2. - var safeMax = int(math.Min(2, float64(len(expected)))) - - // Each test case covers slight variations of the same query. 
- testCases := []struct { - query *Query - expected []*testModel - }{ - { - query: col.NewQuery(filter), - expected: expected, - }, - { - query: col.NewQuery(filter).Reverse(), - expected: reverseExpected, - }, - { - query: col.NewQuery(filter).Max(safeMax), - expected: expected[:safeMax], - }, - { - query: col.NewQuery(filter).Reverse().Max(safeMax), - expected: reverseExpected[:safeMax], - }, - { - query: col.NewQuery(filter).Offset(1), - expected: expected[1:], - }, - { - query: col.NewQuery(filter).Offset(1).Max(safeMax - 1), - expected: expected[1:safeMax], - }, - { - query: col.NewQuery(filter).Offset(1).Reverse(), - expected: reverseExpected[1:], - }, - { - query: col.NewQuery(filter).Offset(1).Max(safeMax - 1).Reverse(), - expected: reverseExpected[1:safeMax], - }, - } - - for i, tc := range testCases { - var actual []*testModel - require.NoError(t, tc.query.Run(&actual), "test case %d", i) - assert.Equal(t, tc.expected, actual, "test case %d", i) - actualCount, err := tc.query.Count() - require.NoError(t, err, "test case %d", i) - assert.Equal(t, len(tc.expected), actualCount, "test case %d", i) - } -} - -func reverseSlice(s []*testModel) []*testModel { - reversed := make([]*testModel, len(s)) - copy(reversed, s) - for left, right := 0, len(reversed)-1; left < right; left, right = left+1, right-1 { - reversed[left], reversed[right] = reversed[right], reversed[left] - } - return reversed -} diff --git a/db/snapshot.go b/db/snapshot.go deleted file mode 100644 index 3ec47bedb..000000000 --- a/db/snapshot.go +++ /dev/null @@ -1,64 +0,0 @@ -package db - -import ( - "github.com/syndtr/goleveldb/leveldb" -) - -// Snapshot is a frozen, read-only snapshot of a DB state at a particular point -// in time. -type Snapshot struct { - colInfo *colInfo - snapshot *leveldb.Snapshot -} - -// GetSnapshot returns a latest snapshot of the underlying DB. The content of -// snapshot are guaranteed to be consistent. 
The snapshot must be released after -// use, by calling Release method. -func (c *Collection) GetSnapshot() (*Snapshot, error) { - snapshot, err := c.ldb.GetSnapshot() - if err != nil { - return nil, err - } - return &Snapshot{ - colInfo: c.info.copy(), - snapshot: snapshot, - }, nil -} - -// Release releases the snapshot. This will not release any ongoing queries, -// which will still finish unless the database is closed. Other methods should -// not be called after the snapshot has been released. -func (s *Snapshot) Release() { - s.snapshot.Release() -} - -// FindByID finds the model with the given ID and scans the results into the -// given model. As in the Unmarshal and Decode methods in the encoding/json -// package, model must be settable via reflect. Typically, this means you should -// pass in a pointer. -func (s *Snapshot) FindByID(id []byte, model Model) error { - return findByID(s.colInfo, s.snapshot, id, model) -} - -// FindAll finds all models for the collection and scans the results into the -// given models. models should be a pointer to an empty slice of a concrete -// model type (e.g. *[]myModelType). -func (s *Snapshot) FindAll(models interface{}) error { - return findAll(s.colInfo, s.snapshot, models) -} - -// Count returns the number of models in the collection. -func (s *Snapshot) Count() (int, error) { - return count(s.colInfo, s.snapshot) -} - -// New Query creates and returns a new query with the given filter. By default, -// a query will return all models that match the filter in ascending byte order -// according to their index values. The query offers methods that can be used to -// change this (e.g. Reverse and Max). The query is lazily executed, i.e. it -// does not actually touch the database until they are run. In general, queries -// have a runtime of O(N) where N is the number of models that are returned by -// the query, but using some features may significantly change this. 
-func (s *Snapshot) NewQuery(filter *Filter) *Query { - return newQuery(s.colInfo, s.snapshot, filter) -} diff --git a/db/snapshot_benchmark_test.go b/db/snapshot_benchmark_test.go deleted file mode 100644 index daf9482d9..000000000 --- a/db/snapshot_benchmark_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package db - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/require" -) - -func BenchmarkGetSnapshot(b *testing.B) { - db := newTestDB(b) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(b, err) - for i := 0; i < 1000; i++ { - model := &testModel{ - Name: fmt.Sprintf("person_%d", i), - Age: i, - } - require.NoError(b, col.Insert(model)) - } - b.ResetTimer() - for i := 0; i < b.N; i++ { - snapshot, err := col.GetSnapshot() - b.StopTimer() - require.NoError(b, err) - snapshot.Release() - b.StartTimer() - } -} - -func BenchmarkSnapshotFindByID(b *testing.B) { - db := newTestDB(b) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(b, err) - model := &testModel{ - Name: "foo", - Age: 42, - } - require.NoError(b, col.Insert(model)) - snapshot, err := col.GetSnapshot() - require.NoError(b, err) - defer snapshot.Release() - b.ResetTimer() - for i := 0; i < b.N; i++ { - var found testModel - err := snapshot.FindByID(model.ID(), &found) - b.StopTimer() - require.NoError(b, err) - b.StartTimer() - } -} diff --git a/db/snapshot_test.go b/db/snapshot_test.go deleted file mode 100644 index df50c1227..000000000 --- a/db/snapshot_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package db - -import ( - "fmt" - "strconv" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestSnapshot(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - - ageIndex := col.AddIndex("age", func(m Model) []byte { - return []byte(fmt.Sprint(m.(*testModel).Age)) - }) - 
- // expected is a set of testModels with Age = 42 - expected := []*testModel{} - for i := 0; i < 5; i++ { - model := &testModel{ - Name: "ExpectedPerson_" + strconv.Itoa(i), - Age: 42, - } - require.NoError(t, col.Insert(model)) - expected = append(expected, model) - } - - // Take a snapshot. - snapshot, err := col.GetSnapshot() - require.NoError(t, err) - defer snapshot.Release() - - // Any models we add after taking the snapshot should not affect the query. - for i := 0; i < 5; i++ { - model := &testModel{ - Name: "OtherPerson_" + strconv.Itoa(i), - Age: 42, - } - require.NoError(t, col.Insert(model)) - } - - // Any models we delete after taking the snapshot should not affect the query. - for _, model := range expected { - require.NoError(t, col.Delete(model.ID())) - } - - // Any new indexes we add should not affect indexes in the snapshot. - col.AddIndex("name", func(m Model) []byte { - return []byte(m.(*testModel).Name) - }) - assert.Equal(t, []*Index{ageIndex}, snapshot.colInfo.indexes) - - // Make sure that the query only return results that match the state at the - // time the snapshot was taken. 
- filter := ageIndex.ValueFilter([]byte("42")) - query := snapshot.NewQuery(filter) - var actual []*testModel - require.NoError(t, query.Run(&actual)) - assert.Equal(t, expected, actual) - actualCount, err := snapshot.Count() - require.NoError(t, err) - assert.Equal(t, len(expected), actualCount) -} diff --git a/db/sql_implementation.go b/db/sql_implementation.go new file mode 100644 index 000000000..8e74e9158 --- /dev/null +++ b/db/sql_implementation.go @@ -0,0 +1,797 @@ +// +build !js + +package db + +import ( + "context" + "database/sql" + "errors" + "fmt" + "math/big" + "os" + "path/filepath" + + "github.com/0xProject/0x-mesh/common/types" + "github.com/0xProject/0x-mesh/db/sqltypes" + "github.com/ethereum/go-ethereum/common" + "github.com/gibson042/canonicaljson-go" + "github.com/google/uuid" + "github.com/ido50/sqlz" + "github.com/jmoiron/sqlx" + _ "github.com/mattn/go-sqlite3" +) + +var _ Database = (*DB)(nil) + +// DB instantiates the DB connection and creates all the collections used by the application +type DB struct { + ctx context.Context + sqldb *sqlz.DB + opts *Options +} + +func defaultOptions() *Options { + return &Options{ + DriverName: "sqlite3", + DataSourceName: "0x_mesh/db/db.sqlite", + MaxOrders: 100000, + MaxMiniHeaders: 20, + } +} + +// TestOptions returns a set of options suitable for testing. +func TestOptions() *Options { + return &Options{ + DriverName: "sqlite3", + DataSourceName: filepath.Join("/tmp", "mesh_testing", uuid.New().String(), "db.sqlite"), + MaxOrders: 100, + MaxMiniHeaders: 20, + } +} + +// New creates a new connection to the database. The connection will be automatically closed +// when the given context is canceled. 
+func New(ctx context.Context, opts *Options) (*DB, error) {
+	opts = parseOptions(opts)
+
+	connectCtx, cancel := context.WithTimeout(ctx, connectTimeout)
+	defer cancel()
+
+	// Ensure the parent directory for the database file exists. Use
+	// errors.Is so a wrapped fs/os ErrExist (e.g. inside *os.PathError) is
+	// treated as benign; a direct == comparison against os.ErrExist could
+	// never match the error os.MkdirAll actually returns.
+	if err := os.MkdirAll(filepath.Dir(opts.DataSourceName), os.ModePerm); err != nil && !errors.Is(err, os.ErrExist) {
+		return nil, err
+	}
+
+	sqldb, err := sqlx.ConnectContext(connectCtx, opts.DriverName, opts.DataSourceName)
+	if err != nil {
+		return nil, err
+	}
+
+	// Automatically close the database connection when the context is canceled.
+	// A plain receive is equivalent to (and simpler than) a single-case select.
+	go func() {
+		<-ctx.Done()
+		_ = sqldb.Close()
+	}()
+
+	db := &DB{
+		ctx:   ctx,
+		sqldb: sqlz.Newx(sqldb),
+		opts:  opts,
+	}
+	if err := db.migrate(); err != nil {
+		return nil, err
+	}
+
+	return db, nil
+}
+
+// TODO(albrow): Use a proper migration tool. We don't technically need this
+// now but it will be necessary if we ever change the database schema.
+// Note(albrow): If needed, we can optimize this by adding indexes to the
+// orders and miniHeaders tables.
+const schema = ` +CREATE TABLE IF NOT EXISTS orders ( + hash TEXT UNIQUE NOT NULL, + chainID TEXT NOT NULL, + exchangeAddress TEXT NOT NULL, + makerAddress TEXT NOT NULL, + makerAssetData TEXT NOT NULL, + makerFeeAssetData TEXT NOT NULL, + makerAssetAmount TEXT NOT NULL, + makerFee TEXT NOT NULL, + takerAddress TEXT NOT NULL, + takerAssetData TEXT NOT NULL, + takerFeeAssetData TEXT NOT NULL, + takerAssetAmount TEXT NOT NULL, + takerFee TEXT NOT NULL, + senderAddress TEXT NOT NULL, + feeRecipientAddress TEXT NOT NULL, + expirationTimeSeconds TEXT NOT NULL, + salt TEXT NOT NULL, + signature TEXT NOT NULL, + lastUpdated DATETIME NOT NULL, + fillableTakerAssetAmount TEXT NOT NULL, + isRemoved BOOLEAN NOT NULL, + isPinned BOOLEAN NOT NULL, + parsedMakerAssetData TEXT NOT NULL, + parsedMakerFeeAssetData TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS miniHeaders ( + hash TEXT UNIQUE NOT NULL, + number TEXT UNIQUE NOT NULL, + parent TEXT NOT NULL, + timestamp DATETIME NOT NULL, + logs TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS metadata ( + ethereumChainID BIGINT NOT NULL, + maxExpirationTime TEXT NOT NULL, + ethRPCRequestsSentInCurrentUTCDay BIGINT NOT NULL, + startOfCurrentUTCDay DATETIME NOT NULL +); +` + +// Note(albrow): If needed, we can optimize this by using prepared +// statements for inserts instead of just a string. 
+const insertOrderQuery = `INSERT INTO orders ( + hash, + chainID, + exchangeAddress, + makerAddress, + makerAssetData, + makerFeeAssetData, + makerAssetAmount, + makerFee, + takerAddress, + takerAssetData, + takerFeeAssetData, + takerAssetAmount, + takerFee, + senderAddress, + feeRecipientAddress, + expirationTimeSeconds, + salt, + signature, + lastUpdated, + fillableTakerAssetAmount, + isRemoved, + isPinned, + parsedMakerAssetData, + parsedMakerFeeAssetData +) VALUES ( + :hash, + :chainID, + :exchangeAddress, + :makerAddress, + :makerAssetData, + :makerFeeAssetData, + :makerAssetAmount, + :makerFee, + :takerAddress, + :takerAssetData, + :takerFeeAssetData, + :takerAssetAmount, + :takerFee, + :senderAddress, + :feeRecipientAddress, + :expirationTimeSeconds, + :salt, + :signature, + :lastUpdated, + :fillableTakerAssetAmount, + :isRemoved, + :isPinned, + :parsedMakerAssetData, + :parsedMakerFeeAssetData +) ON CONFLICT DO NOTHING +` + +const updateOrderQuery = `UPDATE orders SET + chainID = :chainID, + exchangeAddress = :exchangeAddress, + makerAddress = :makerAddress, + makerAssetData = :makerAssetData, + makerFeeAssetData = :makerFeeAssetData, + makerAssetAmount = :makerAssetAmount, + makerFee = :makerFee, + takerAddress = :takerAddress, + takerAssetData = :takerAssetData, + takerFeeAssetData = :takerFeeAssetData, + takerAssetAmount = :takerAssetAmount, + takerFee = :takerFee, + senderAddress = :senderAddress, + feeRecipientAddress = :feeRecipientAddress, + expirationTimeSeconds = :expirationTimeSeconds, + salt = :salt, + signature = :signature, + lastUpdated = :lastUpdated, + fillableTakerAssetAmount = :fillableTakerAssetAmount, + isRemoved = :isRemoved, + isPinned = :isPinned, + parsedMakerAssetData = :parsedMakerAssetData, + parsedMakerFeeAssetData = :parsedMakerFeeAssetData +WHERE orders.hash = :hash +` + +const insertMiniHeaderQuery = `INSERT INTO miniHeaders ( + hash, + parent, + number, + timestamp, + logs +) VALUES ( + :hash, + :parent, + :number, + 
:timestamp, + :logs +) ON CONFLICT DO NOTHING` + +const insertMetadataQuery = `INSERT INTO metadata ( + ethereumChainID, + maxExpirationTime, + ethRPCRequestsSentInCurrentUTCDay, + startOfCurrentUTCDay +) VALUES ( + :ethereumChainID, + :maxExpirationTime, + :ethRPCRequestsSentInCurrentUTCDay, + :startOfCurrentUTCDay +)` + +const updateMetadataQuery = `UPDATE metadata SET + ethereumChainID = :ethereumChainID, + maxExpirationTime = :maxExpirationTime, + ethRPCRequestsSentInCurrentUTCDay = :ethRPCRequestsSentInCurrentUTCDay, + startOfCurrentUTCDay = :startOfCurrentUTCDay +` + +func (db *DB) migrate() error { + _, err := db.sqldb.ExecContext(db.ctx, schema) + return convertErr(err) +} + +func (db *DB) AddOrders(orders []*types.OrderWithMetadata) (added []*types.OrderWithMetadata, removed []*types.OrderWithMetadata, err error) { + defer func() { + err = convertErr(err) + }() + txn, err := db.sqldb.BeginTxx(db.ctx, nil) + if err != nil { + return nil, nil, err + } + defer func() { + _ = txn.Rollback() + }() + + for _, order := range orders { + result, err := txn.NamedExecContext(db.ctx, insertOrderQuery, sqltypes.OrderFromCommonType(order)) + if err != nil { + return nil, nil, err + } + affected, err := result.RowsAffected() + if err != nil { + return nil, nil, err + } + if affected > 0 { + added = append(added, order) + } + } + if err := txn.Commit(); err != nil { + return nil, nil, err + } + + // TODO(albrow): Remove orders with longest expiration time. 
+ return added, nil, nil +} + +func (db *DB) GetOrder(hash common.Hash) (order *types.OrderWithMetadata, err error) { + defer func() { + err = convertErr(err) + }() + var foundOrder sqltypes.Order + if err := db.sqldb.GetContext(db.ctx, &foundOrder, "SELECT * FROM orders WHERE hash = $1", hash); err != nil { + return nil, err + } + return sqltypes.OrderToCommonType(&foundOrder), nil +} + +func (db *DB) FindOrders(query *OrderQuery) (orders []*types.OrderWithMetadata, err error) { + defer func() { + err = convertErr(err) + }() + if err := checkOrderQuery(query); err != nil { + return nil, err + } + stmt, err := addOptsToSelectOrdersQuery(db.sqldb.Select("*").From("orders"), query) + if err != nil { + return nil, err + } + var foundOrders []*sqltypes.Order + if err := stmt.GetAllContext(db.ctx, &foundOrders); err != nil { + return nil, err + } + return sqltypes.OrdersToCommonType(foundOrders), nil +} + +func (db *DB) CountOrders(query *OrderQuery) (count int, err error) { + defer func() { + err = convertErr(err) + }() + if err := checkOrderQuery(query); err != nil { + return 0, err + } + stmt, err := addOptsToSelectOrdersQuery(db.sqldb.Select("COUNT(*)").From("orders"), query) + if err != nil { + return 0, err + } + gotCount, err := stmt.GetCount() + if err != nil { + return 0, err + } + return int(gotCount), nil +} + +type Selector interface { + Select(cols ...string) *sqlz.SelectStmt +} + +func addOptsToSelectOrdersQuery(stmt *sqlz.SelectStmt, opts *OrderQuery) (*sqlz.SelectStmt, error) { + if opts == nil { + return stmt, nil + } + + ordering := orderingFromOrderSortOpts(opts.Sort) + if len(ordering) != 0 { + stmt.OrderBy(ordering...) + } + if opts.Limit != 0 { + stmt.Limit(int64(opts.Limit)) + } + if opts.Offset != 0 { + stmt.Offset(int64(opts.Offset)) + } + whereConditions, err := whereConditionsFromOrderFilterOpts(opts.Filters) + if err != nil { + return nil, err + } + if len(whereConditions) != 0 { + stmt.Where(whereConditions...) 
+ } + + return stmt, nil +} + +func orderingFromOrderSortOpts(sortOpts []OrderSort) []sqlz.SQLStmt { + ordering := []sqlz.SQLStmt{} + for _, sortOpt := range sortOpts { + if sortOpt.Direction == Ascending { + ordering = append(ordering, sqlz.Asc(string(sortOpt.Field))) + } else { + ordering = append(ordering, sqlz.Desc(string(sortOpt.Field))) + } + } + return ordering +} + +func whereConditionsFromOrderFilterOpts(filterOpts []OrderFilter) ([]sqlz.WhereCondition, error) { + whereConditions := make([]sqlz.WhereCondition, len(filterOpts)) + for i, filterOpt := range filterOpts { + value := convertFilterValue(filterOpt.Value) + switch filterOpt.Kind { + case Equal: + whereConditions[i] = sqlz.Eq(string(filterOpt.Field), value) + case NotEqual: + whereConditions[i] = sqlz.Not(sqlz.Eq(string(filterOpt.Field), value)) + case Less: + whereConditions[i] = sqlz.Lt(string(filterOpt.Field), value) + case Greater: + whereConditions[i] = sqlz.Gt(string(filterOpt.Field), value) + case LessOrEqual: + whereConditions[i] = sqlz.Lte(string(filterOpt.Field), value) + case GreaterOrEqual: + whereConditions[i] = sqlz.Gte(string(filterOpt.Field), value) + case Contains: + // Note(albrow): If needed, we can optimize this so it is easier to index. + // LIKE queries are notoriously slow. 
+ whereConditions[i] = sqlz.Like(string(filterOpt.Field), fmt.Sprintf("%%%s%%", value)) + default: + return nil, fmt.Errorf("db.FindOrder: unknown FilterOpt.Kind: %s", filterOpt.Kind) + } + } + return whereConditions, nil +} + +func (db *DB) DeleteOrder(hash common.Hash) error { + if _, err := db.sqldb.ExecContext(db.ctx, "DELETE FROM orders WHERE hash = $1", hash); err != nil { + return convertErr(err) + } + return nil +} + +func (db *DB) DeleteOrders(query *OrderQuery) (deleted []*types.OrderWithMetadata, err error) { + defer func() { + err = convertErr(err) + }() + if err := checkOrderQuery(query); err != nil { + return nil, err + } + // HACK(albrow): sqlz doesn't support ORDER BY, LIMIT, and OFFSET + // for DELETE statements. It also doesn't support RETURNING. As a + // workaround, we do a SELECT and DELETE inside a transaction. + var ordersToDelete []*sqltypes.Order + err = db.sqldb.TransactionalContext(db.ctx, nil, func(txn *sqlz.Tx) error { + stmt, err := addOptsToSelectOrdersQuery(txn.Select("*").From("orders"), query) + if err != nil { + return err + } + if err := stmt.GetAllContext(db.ctx, &ordersToDelete); err != nil { + return err + } + for _, order := range ordersToDelete { + _, err := txn.DeleteFrom("orders").Where(sqlz.Eq(string(OFHash), order.Hash)).ExecContext(db.ctx) + if err != nil { + return err + } + } + return nil + }) + if err != nil { + return nil, err + } + return sqltypes.OrdersToCommonType(ordersToDelete), nil +} + +func (db *DB) UpdateOrder(hash common.Hash, updateFunc func(existingOrder *types.OrderWithMetadata) (updatedOrder *types.OrderWithMetadata, err error)) (err error) { + defer func() { + err = convertErr(err) + }() + if updateFunc == nil { + return errors.New("db.UpdateOrders: updateFunc cannot be nil") + } + + txn, err := db.sqldb.BeginTxx(db.ctx, nil) + if err != nil { + return err + } + defer func() { + _ = txn.Rollback() + }() + + var existingOrder sqltypes.Order + if err := txn.GetContext(db.ctx, &existingOrder, "SELECT * 
FROM orders WHERE hash = $1", hash); err != nil { + if err == sql.ErrNoRows { + return ErrNotFound + } + return err + } + + commonOrder := sqltypes.OrderToCommonType(&existingOrder) + commonUpdatedOrder, err := updateFunc(commonOrder) + if err != nil { + return fmt.Errorf("db.UpdateOrders: updateFunc returned error") + } + updatedOrder := sqltypes.OrderFromCommonType(commonUpdatedOrder) + _, err = txn.NamedExecContext(db.ctx, updateOrderQuery, updatedOrder) + if err != nil { + return err + } + return txn.Commit() +} + +func (db *DB) AddMiniHeaders(miniHeaders []*types.MiniHeader) (added []*types.MiniHeader, removed []*types.MiniHeader, err error) { + defer func() { + err = convertErr(err) + }() + var miniHeadersToRemove []*sqltypes.MiniHeader + err = db.sqldb.TransactionalContext(db.ctx, nil, func(txn *sqlz.Tx) error { + for _, miniHeader := range miniHeaders { + result, err := txn.NamedExecContext(db.ctx, insertMiniHeaderQuery, sqltypes.MiniHeaderFromCommonType(miniHeader)) + if err != nil { + return err + } + affected, err := result.RowsAffected() + if err != nil { + return err + } + if affected > 0 { + added = append(added, miniHeader) + } + } + + // HACK(albrow): sqlz doesn't support ORDER BY, LIMIT, and OFFSET + // for DELETE statements. It also doesn't support RETURNING. As a + // workaround, we do a SELECT and DELETE inside a transaction. + trimQuery := txn.Select("*").From("miniHeaders").OrderBy(sqlz.Desc(string(MFNumber))).Limit(99999999999).Offset(int64(db.opts.MaxMiniHeaders)) + if err := trimQuery.GetAllContext(db.ctx, &miniHeadersToRemove); err != nil { + return err + } + for _, miniHeader := range miniHeadersToRemove { + _, err := txn.DeleteFrom("miniHeaders").Where(sqlz.Eq(string(MFHash), miniHeader.Hash)).ExecContext(db.ctx) + if err != nil { + return err + } + } + return nil + }) + if err != nil { + return nil, nil, err + } + + // Because of how the above code is written, a single miniHeader could exist + // in both added and removed sets. 
We should remove such miniHeaders from both + // sets in this case. + addedMap := map[common.Hash]*types.MiniHeader{} + removedMap := map[common.Hash]*sqltypes.MiniHeader{} + for _, a := range added { + addedMap[a.Hash] = a + } + for _, r := range miniHeadersToRemove { + removedMap[r.Hash] = r + } + dedupedAdded := []*types.MiniHeader{} + dedupedRemoved := []*sqltypes.MiniHeader{} + for _, a := range added { + if _, wasRemoved := removedMap[a.Hash]; !wasRemoved { + dedupedAdded = append(dedupedAdded, a) + } + } + for _, r := range miniHeadersToRemove { + if _, wasAdded := addedMap[r.Hash]; !wasAdded { + dedupedRemoved = append(dedupedRemoved, r) + } + } + + return dedupedAdded, sqltypes.MiniHeadersToCommonType(dedupedRemoved), nil +} + +func (db *DB) GetMiniHeader(hash common.Hash) (miniHeader *types.MiniHeader, err error) { + defer func() { + err = convertErr(err) + }() + var sqlMiniHeader sqltypes.MiniHeader + if err := db.sqldb.GetContext(db.ctx, &sqlMiniHeader, "SELECT * FROM miniHeaders WHERE hash = $1", hash); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, err + } + return sqltypes.MiniHeaderToCommonType(&sqlMiniHeader), nil +} + +func (db *DB) FindMiniHeaders(query *MiniHeaderQuery) (miniHeaders []*types.MiniHeader, err error) { + defer func() { + err = convertErr(err) + }() + stmt, err := findMiniHeadersQueryFromOpts(db.sqldb, query) + if err != nil { + return nil, err + } + var sqlMiniHeaders []*sqltypes.MiniHeader + if err := stmt.GetAllContext(db.ctx, &sqlMiniHeaders); err != nil { + return nil, err + } + return sqltypes.MiniHeadersToCommonType(sqlMiniHeaders), nil +} + +func findMiniHeadersQueryFromOpts(selector Selector, query *MiniHeaderQuery) (*sqlz.SelectStmt, error) { + stmt := selector.Select("*").From("miniHeaders") + if query == nil { + return stmt, nil + } + + ordering := orderingFromMiniHeaderSortOpts(query.Sort) + if len(ordering) != 0 { + stmt.OrderBy(ordering...) 
+ } + if query.Limit != 0 { + stmt.Limit(int64(query.Limit)) + } + if query.Offset != 0 { + if query.Limit == 0 { + return nil, errors.New("db.FindMiniHeaders: can't use Offset without Limit") + } + stmt.Offset(int64(query.Offset)) + } + whereConditions, err := whereConditionsFromMiniHeaderFilterOpts(query.Filters) + if err != nil { + return nil, err + } + if len(whereConditions) != 0 { + stmt.Where(whereConditions...) + } + + return stmt, nil +} + +func orderingFromMiniHeaderSortOpts(sortOpts []MiniHeaderSort) []sqlz.SQLStmt { + ordering := []sqlz.SQLStmt{} + for _, sortOpt := range sortOpts { + if sortOpt.Direction == Ascending { + ordering = append(ordering, sqlz.Asc(string(sortOpt.Field))) + } else { + ordering = append(ordering, sqlz.Desc(string(sortOpt.Field))) + } + } + return ordering +} + +func whereConditionsFromMiniHeaderFilterOpts(filterOpts []MiniHeaderFilter) ([]sqlz.WhereCondition, error) { + whereConditions := make([]sqlz.WhereCondition, len(filterOpts)) + for i, filterOpt := range filterOpts { + value := convertFilterValue(filterOpt.Value) + switch filterOpt.Kind { + case Equal: + whereConditions[i] = sqlz.Eq(string(filterOpt.Field), value) + case NotEqual: + whereConditions[i] = sqlz.Not(sqlz.Eq(string(filterOpt.Field), value)) + case Less: + whereConditions[i] = sqlz.Lt(string(filterOpt.Field), value) + case Greater: + whereConditions[i] = sqlz.Gt(string(filterOpt.Field), value) + case LessOrEqual: + whereConditions[i] = sqlz.Lte(string(filterOpt.Field), value) + case GreaterOrEqual: + whereConditions[i] = sqlz.Gte(string(filterOpt.Field), value) + case Contains: + // Note(albrow): If needed, we can optimize this so it is easier to index. + // LIKE queries are notoriously slow. 
+ whereConditions[i] = sqlz.Like(string(filterOpt.Field), fmt.Sprintf("%%%s%%", value)) + default: + return nil, fmt.Errorf("db.FindMiniHeaders: unknown FilterOpt.Kind: %s", filterOpt.Kind) + } + } + return whereConditions, nil +} + +func (db *DB) DeleteMiniHeader(hash common.Hash) error { + if _, err := db.sqldb.ExecContext(db.ctx, "DELETE FROM miniHeaders WHERE hash = $1", hash); err != nil { + return convertErr(err) + } + return nil +} + +func (db *DB) DeleteMiniHeaders(query *MiniHeaderQuery) (deleted []*types.MiniHeader, err error) { + defer func() { + err = convertErr(err) + }() + // HACK(albrow): sqlz doesn't support ORDER BY, LIMIT, and OFFSET + // for DELETE statements. It also doesn't support RETURNING. As a + // workaround, we do a SELECT and DELETE inside a transaction. + var miniHeadersToDelete []*sqltypes.MiniHeader + err = db.sqldb.TransactionalContext(db.ctx, nil, func(tx *sqlz.Tx) error { + stmt, err := findMiniHeadersQueryFromOpts(tx, query) + if err != nil { + return err + } + if err := stmt.GetAllContext(db.ctx, &miniHeadersToDelete); err != nil { + return err + } + for _, miniHeader := range miniHeadersToDelete { + _, err := tx.DeleteFrom("miniHeaders").Where(sqlz.Eq(string(MFHash), miniHeader.Hash)).ExecContext(db.ctx) + if err != nil { + return err + } + } + return nil + }) + if err != nil { + return nil, err + } + return sqltypes.MiniHeadersToCommonType(miniHeadersToDelete), nil +} + +// GetMetadata returns the metadata (or db.ErrNotFound if no metadata has been saved). +func (db *DB) GetMetadata() (*types.Metadata, error) { + var metadata sqltypes.Metadata + if err := db.sqldb.GetContext(db.ctx, &metadata, "SELECT * FROM metadata LIMIT 1"); err != nil { + return nil, convertErr(err) + } + return sqltypes.MetadataToCommonType(&metadata), nil +} + +// SaveMetadata inserts the metadata into the database, overwriting any existing +// metadata. It returns ErrMetadataAlreadyExists if the metadata has already been +// saved in the database. 
+func (db *DB) SaveMetadata(metadata *types.Metadata) (err error) {
+	defer func() {
+		err = convertErr(err)
+	}()
+	// Both the existence check and the insert must go through txn (not
+	// db.sqldb) so they execute atomically inside the transaction;
+	// otherwise a concurrent SaveMetadata could slip between the COUNT
+	// and the INSERT.
+	err = db.sqldb.TransactionalContext(db.ctx, nil, func(txn *sqlz.Tx) error {
+		query := txn.Select("COUNT(*)").From("metadata")
+		count, err := query.GetCount()
+		if err != nil {
+			return err
+		}
+		if count != 0 {
+			return ErrMetadataAlreadyExists
+		}
+		_, err = txn.NamedExecContext(db.ctx, insertMetadataQuery, sqltypes.MetadataFromCommonType(metadata))
+		return err
+	})
+	return err
+}
+
+// UpdateMetadata updates the metadata in the database via a transaction. It
+// accepts a callback function which will be provided with the old metadata and
+// should return the new metadata to save.
+func (db *DB) UpdateMetadata(updateFunc func(oldmetadata *types.Metadata) (newMetadata *types.Metadata)) (err error) {
+	defer func() {
+		err = convertErr(err)
+	}()
+	if updateFunc == nil {
+		return errors.New("db.UpdateMetadata: updateFunc cannot be nil")
+	}
+
+	txn, err := db.sqldb.BeginTxx(db.ctx, nil)
+	if err != nil {
+		return err
+	}
+	// Rollback is a no-op if the transaction was already committed.
+	defer func() {
+		_ = txn.Rollback()
+	}()
+
+	var existingMetadata sqltypes.Metadata
+	if err := txn.GetContext(db.ctx, &existingMetadata, "SELECT * FROM metadata LIMIT 1"); err != nil {
+		if err == sql.ErrNoRows {
+			return ErrNotFound
+		}
+		return err
+	}
+
+	commonMetadata := sqltypes.MetadataToCommonType(&existingMetadata)
+	commonUpdatedMetadata := updateFunc(commonMetadata)
+	updatedMetadata := sqltypes.MetadataFromCommonType(commonUpdatedMetadata)
+	_, err = txn.NamedExecContext(db.ctx, updateMetadataQuery, updatedMetadata)
+	if err != nil {
+		return err
+	}
+	return txn.Commit()
+}
+
+// convertFilterValue converts filter values of type *big.Int to a
+// sort-order-preserving SQL representation; all other values pass through
+// unchanged.
+func convertFilterValue(value interface{}) interface{} {
+	switch v := value.(type) {
+	case *big.Int:
+		return sqltypes.NewSortedBigInt(v)
+	}
+	return value
+}
+
+// assetDataIncludesTokenAddressAndTokenID builds an OrderFilter that matches
+// orders whose parsed asset data contains the given token address and ID.
+func assetDataIncludesTokenAddressAndTokenID(field OrderField, tokenAddress common.Address, tokenID *big.Int) OrderFilter {
+	filterValueJSON, err := 
canonicaljson.Marshal(sqltypes.SingleAssetData{ + Address: tokenAddress, + TokenID: sqltypes.NewBigInt(tokenID), + }) + if err != nil { + // big.Int and common.Address types should never return an error when marshaling to JSON + panic(err) + } + return OrderFilter{ + Field: field, + Kind: Contains, + Value: string(filterValueJSON), + } +} + +// convertErr converts from SQL specific errors to common error types. +func convertErr(err error) error { + if err == nil { + return nil + } + // Check if the error matches known exported values. + switch err { + case sql.ErrNoRows: + return ErrNotFound + case sql.ErrConnDone: + return ErrClosed + } + // As a fallback, check the error string for errors which have no + // exported type/value. + switch err.Error() { + case "sql: database is closed": + return ErrClosed + } + return err +} diff --git a/db/sqltypes/sqltypes.go b/db/sqltypes/sqltypes.go new file mode 100644 index 000000000..0f2d02a97 --- /dev/null +++ b/db/sqltypes/sqltypes.go @@ -0,0 +1,479 @@ +// +build !js + +package sqltypes + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "math" + "math/big" + "strconv" + "time" + + "github.com/0xProject/0x-mesh/common/types" + "github.com/ethereum/go-ethereum/common" + ethmath "github.com/ethereum/go-ethereum/common/math" + ethtypes "github.com/ethereum/go-ethereum/core/types" + "github.com/gibson042/canonicaljson-go" +) + +// BigInt is a wrapper around *big.Int that implements the sql.Valuer +// and sql.Scanner interfaces and *does not* retain sort order. 
+type BigInt struct { + *big.Int +} + +func NewBigInt(v *big.Int) *BigInt { + return &BigInt{ + Int: v, + } +} + +func BigIntFromString(v string) (*BigInt, error) { + bigInt, ok := ethmath.ParseBig256(v) + if !ok { + return nil, fmt.Errorf("dexietypes: could not convert %q to BigInt", v) + } + return NewBigInt(bigInt), nil +} + +func BigIntFromInt64(v int64) *BigInt { + return NewBigInt(big.NewInt(v)) +} + +func (i *BigInt) Value() (driver.Value, error) { + if i == nil || i.Int == nil { + return nil, nil + } + return i.Int.String(), nil +} + +func (i *BigInt) Scan(value interface{}) error { + if value == nil { + i = nil + return nil + } + switch v := value.(type) { + case int64: + i.Int = big.NewInt(v) + case float64: + if math.Trunc(v) != v { + // float64 may be used by the database driver to represent 0 or any other + // whole number. This is okay as long as v is a whole number, i.e. does not + // have anything after the decimal point. If this is not the case we return + // an error. + return fmt.Errorf("could not scan non-whole number float64 value %v into sqltypes.BigInt", value) + } + i.Int, _ = big.NewFloat(v).Int(big.NewInt(0)) + case string: + parsed, ok := ethmath.ParseBig256(v) + if !ok { + return fmt.Errorf("could not scan string value %q into sqltypes.BigInt", v) + } + i.Int = parsed + default: + return fmt.Errorf("could not scan type %T into sqltypes.BigInt", value) + } + + return nil +} + +func (i *BigInt) MarshalJSON() ([]byte, error) { + if i == nil || i.Int == nil { + return json.Marshal(nil) + } + return json.Marshal(i.Int.String()) +} + +func (i *BigInt) UnmarshalJSON(data []byte) error { + unqouted, err := strconv.Unquote(string(data)) + if err != nil { + return fmt.Errorf("could not unmarshal JSON data into dexietypes.BigInt: %s", string(data)) + } + bigInt, ok := ethmath.ParseBig256(unqouted) + if !ok { + return fmt.Errorf("could not unmarshal JSON data into dexietypes.BigInt: %s", string(data)) + } + i.Int = bigInt + return nil +} + +// 
SortedBigInt is a wrapper around *big.Int that implements the sql.Valuer +// and sql.Scanner interfaces and retains sort order by padding with zeroes. +type SortedBigInt struct { + *big.Int +} + +func NewSortedBigInt(v *big.Int) *SortedBigInt { + return &SortedBigInt{ + Int: v, + } +} + +func SortedBigIntFromString(v string) (*SortedBigInt, error) { + bigInt, ok := ethmath.ParseBig256(v) + if !ok { + return nil, fmt.Errorf("dexietypes: could not convert %q to BigInt", v) + } + return NewSortedBigInt(bigInt), nil +} + +func SortedBigIntFromInt64(v int64) *SortedBigInt { + return NewSortedBigInt(big.NewInt(v)) +} + +func (i *SortedBigInt) Value() (driver.Value, error) { + if i == nil || i.Int == nil { + return nil, nil + } + // Note(albrow), strings in SQL are sorted in alphanumerical order, not + // numerical order. In order to sort by numerical order, we need to pad with + // zeroes. The maximum length of an unsigned 256 bit integer is 80, so we + // pad with zeroes such that the length of the number is always 80. + return fmt.Sprintf("%080s", i.Int.String()), nil +} + +func (i *SortedBigInt) Scan(value interface{}) error { + if value == nil { + i = nil + return nil + } + switch v := value.(type) { + case int64: + i.Int = big.NewInt(v) + case float64: + if math.Trunc(v) != v { + // float64 may be used by the database driver to represent 0 or any other + // whole number. This is okay as long as v is a whole number, i.e. does not + // have anything after the decimal point. If this is not the case we return + // an error. 
+ return fmt.Errorf("could not scan non-whole number float64 value %v into sqltypes.BigInt", value) + } + i.Int, _ = big.NewFloat(v).Int(big.NewInt(0)) + case string: + parsed, ok := ethmath.ParseBig256(v) + if !ok { + return fmt.Errorf("could not scan string value %q into sqltypes.BigInt", v) + } + i.Int = parsed + default: + return fmt.Errorf("could not scan type %T into sqltypes.BigInt", value) + } + + return nil +} + +func (i *SortedBigInt) MarshalJSON() ([]byte, error) { + if i == nil || i.Int == nil { + return json.Marshal(nil) + } + // Note(albrow), strings in Dexie.js are sorted in alphanumerical order, not + // numerical order. In order to sort by numerical order, we need to pad with + // zeroes. The maximum length of an unsigned 256 bit integer is 80, so we + // pad with zeroes such that the length of the number is always 80. + return json.Marshal(fmt.Sprintf("%080s", i.Int.String())) +} + +func (i *SortedBigInt) UnmarshalJSON(data []byte) error { + unqouted, err := strconv.Unquote(string(data)) + if err != nil { + return fmt.Errorf("could not unmarshal JSON data into dexietypes.BigInt: %s", string(data)) + } + bigInt, ok := ethmath.ParseBig256(unqouted) + if !ok { + return fmt.Errorf("could not unmarshal JSON data into dexietypes.BigInt: %s", string(data)) + } + i.Int = bigInt + return nil +} + +type SingleAssetData struct { + Address common.Address `json:"address"` + TokenID *BigInt `json:"tokenID"` +} + +// ParsedAssetData is a wrapper around []*SingleAssetData that implements the +// sql.Valuer and sql.Scanner interfaces. 
+type ParsedAssetData []*SingleAssetData + +func (s *ParsedAssetData) Value() (driver.Value, error) { + if s == nil { + return nil, nil + } + return canonicaljson.Marshal(s) +} + +func (s *ParsedAssetData) Scan(value interface{}) error { + if value == nil { + s = nil + return nil + } + switch v := value.(type) { + case []byte: + return json.Unmarshal(v, s) + case string: + return json.Unmarshal([]byte(v), s) + default: + return fmt.Errorf("could not scan type %T into EventLogs", value) + } +} + +// Order is the SQL database representation a 0x order along with some relevant metadata. +type Order struct { + Hash common.Hash `db:"hash"` + ChainID *SortedBigInt `db:"chainID"` + ExchangeAddress common.Address `db:"exchangeAddress"` + MakerAddress common.Address `db:"makerAddress"` + MakerAssetData []byte `db:"makerAssetData"` + MakerFeeAssetData []byte `db:"makerFeeAssetData"` + MakerAssetAmount *SortedBigInt `db:"makerAssetAmount"` + MakerFee *SortedBigInt `db:"makerFee"` + TakerAddress common.Address `db:"takerAddress"` + TakerAssetData []byte `db:"takerAssetData"` + TakerFeeAssetData []byte `db:"takerFeeAssetData"` + TakerAssetAmount *SortedBigInt `db:"takerAssetAmount"` + TakerFee *SortedBigInt `db:"takerFee"` + SenderAddress common.Address `db:"senderAddress"` + FeeRecipientAddress common.Address `db:"feeRecipientAddress"` + ExpirationTimeSeconds *SortedBigInt `db:"expirationTimeSeconds"` + Salt *SortedBigInt `db:"salt"` + Signature []byte `db:"signature"` + LastUpdated time.Time `db:"lastUpdated"` + FillableTakerAssetAmount *SortedBigInt `db:"fillableTakerAssetAmount"` + IsRemoved bool `db:"isRemoved"` + IsPinned bool `db:"isPinned"` + ParsedMakerAssetData *ParsedAssetData `db:"parsedMakerAssetData"` + ParsedMakerFeeAssetData *ParsedAssetData `db:"parsedMakerFeeAssetData"` +} + +// EventLogs is a wrapper around []*ethtypes.Log that implements the +// sql.Valuer and sql.Scanner interfaces. 
+type EventLogs struct { + Logs []ethtypes.Log +} + +func NewEventLogs(logs []ethtypes.Log) *EventLogs { + return &EventLogs{ + Logs: logs, + } +} + +func (e *EventLogs) Value() (driver.Value, error) { + if e == nil { + return nil, nil + } + logsJSON, err := canonicaljson.Marshal(e.Logs) + if err != nil { + return nil, err + } + return logsJSON, err +} + +func (e *EventLogs) Scan(value interface{}) error { + if value == nil { + e = nil + return nil + } + switch v := value.(type) { + case []byte: + return json.Unmarshal(v, &e.Logs) + case string: + return json.Unmarshal([]byte(v), &e.Logs) + default: + return fmt.Errorf("could not scan type %T into EventLogs", value) + } +} + +type MiniHeader struct { + Hash common.Hash `db:"hash"` + Parent common.Hash `db:"parent"` + Number *SortedBigInt `db:"number"` + Timestamp time.Time `db:"timestamp"` + Logs *EventLogs `db:"logs"` +} + +type Metadata struct { + EthereumChainID int `db:"ethereumChainID"` + MaxExpirationTime *SortedBigInt `db:"maxExpirationTime"` + EthRPCRequestsSentInCurrentUTCDay int `db:"ethRPCRequestsSentInCurrentUTCDay"` + StartOfCurrentUTCDay time.Time `db:"startOfCurrentUTCDay"` +} + +func OrderToCommonType(order *Order) *types.OrderWithMetadata { + if order == nil { + return nil + } + return &types.OrderWithMetadata{ + Hash: order.Hash, + ChainID: order.ChainID.Int, + ExchangeAddress: order.ExchangeAddress, + MakerAddress: order.MakerAddress, + MakerAssetData: order.MakerAssetData, + MakerFeeAssetData: order.MakerFeeAssetData, + MakerAssetAmount: order.MakerAssetAmount.Int, + MakerFee: order.MakerFee.Int, + TakerAddress: order.TakerAddress, + TakerAssetData: order.TakerAssetData, + TakerFeeAssetData: order.TakerFeeAssetData, + TakerAssetAmount: order.TakerAssetAmount.Int, + TakerFee: order.TakerFee.Int, + SenderAddress: order.SenderAddress, + FeeRecipientAddress: order.FeeRecipientAddress, + ExpirationTimeSeconds: order.ExpirationTimeSeconds.Int, + Salt: order.Salt.Int, + Signature: order.Signature, + 
FillableTakerAssetAmount: order.FillableTakerAssetAmount.Int, + LastUpdated: order.LastUpdated, + IsRemoved: order.IsRemoved, + IsPinned: order.IsPinned, + ParsedMakerAssetData: ParsedAssetDataToCommonType(order.ParsedMakerAssetData), + ParsedMakerFeeAssetData: ParsedAssetDataToCommonType(order.ParsedMakerFeeAssetData), + } +} + +func OrderFromCommonType(order *types.OrderWithMetadata) *Order { + if order == nil { + return nil + } + return &Order{ + Hash: order.Hash, + ChainID: NewSortedBigInt(order.ChainID), + ExchangeAddress: order.ExchangeAddress, + MakerAddress: order.MakerAddress, + MakerAssetData: order.MakerAssetData, + MakerFeeAssetData: order.MakerFeeAssetData, + MakerAssetAmount: NewSortedBigInt(order.MakerAssetAmount), + MakerFee: NewSortedBigInt(order.MakerFee), + TakerAddress: order.TakerAddress, + TakerAssetData: order.TakerAssetData, + TakerFeeAssetData: order.TakerFeeAssetData, + TakerAssetAmount: NewSortedBigInt(order.TakerAssetAmount), + TakerFee: NewSortedBigInt(order.TakerFee), + SenderAddress: order.SenderAddress, + FeeRecipientAddress: order.FeeRecipientAddress, + ExpirationTimeSeconds: NewSortedBigInt(order.ExpirationTimeSeconds), + Salt: NewSortedBigInt(order.Salt), + Signature: order.Signature, + LastUpdated: order.LastUpdated, + FillableTakerAssetAmount: NewSortedBigInt(order.FillableTakerAssetAmount), + IsRemoved: order.IsRemoved, + IsPinned: order.IsPinned, + ParsedMakerAssetData: ParsedAssetDataFromCommonType(order.ParsedMakerAssetData), + ParsedMakerFeeAssetData: ParsedAssetDataFromCommonType(order.ParsedMakerFeeAssetData), + } +} + +func OrdersToCommonType(orders []*Order) []*types.OrderWithMetadata { + result := make([]*types.OrderWithMetadata, len(orders)) + for i, order := range orders { + result[i] = OrderToCommonType(order) + } + return result +} + +func ParsedAssetDataToCommonType(parsedAssetData *ParsedAssetData) []*types.SingleAssetData { + if parsedAssetData == nil || len(*parsedAssetData) == 0 { + return nil + } + 
assetDataSlice := []*SingleAssetData(*parsedAssetData) + result := make([]*types.SingleAssetData, len(assetDataSlice)) + for i, singleAssetData := range assetDataSlice { + result[i] = SingleAssetDataToCommonType(singleAssetData) + } + return result +} + +func ParsedAssetDataFromCommonType(parsedAssetData []*types.SingleAssetData) *ParsedAssetData { + result := ParsedAssetData(make([]*SingleAssetData, len(parsedAssetData))) + for i, singleAssetData := range parsedAssetData { + result[i] = SingleAssetDataFromCommonType(singleAssetData) + } + return &result +} + +func SingleAssetDataToCommonType(singleAssetData *SingleAssetData) *types.SingleAssetData { + if singleAssetData == nil { + return nil + } + var tokenID *big.Int + if singleAssetData.TokenID != nil { + tokenID = singleAssetData.TokenID.Int + } + return &types.SingleAssetData{ + Address: singleAssetData.Address, + TokenID: tokenID, + } +} + +func SingleAssetDataFromCommonType(singleAssetData *types.SingleAssetData) *SingleAssetData { + if singleAssetData == nil { + return nil + } + var tokenID *BigInt + if singleAssetData.TokenID != nil { + tokenID = NewBigInt(singleAssetData.TokenID) + } + return &SingleAssetData{ + Address: singleAssetData.Address, + TokenID: tokenID, + } +} + +func MiniHeaderToCommonType(miniHeader *MiniHeader) *types.MiniHeader { + if miniHeader == nil { + return nil + } + return &types.MiniHeader{ + Hash: miniHeader.Hash, + Parent: miniHeader.Parent, + Number: miniHeader.Number.Int, + Timestamp: miniHeader.Timestamp, + Logs: miniHeader.Logs.Logs, + } +} + +func MiniHeaderFromCommonType(miniHeader *types.MiniHeader) *MiniHeader { + if miniHeader == nil { + return nil + } + return &MiniHeader{ + Hash: miniHeader.Hash, + Parent: miniHeader.Parent, + Number: NewSortedBigInt(miniHeader.Number), + Timestamp: miniHeader.Timestamp, + Logs: NewEventLogs(miniHeader.Logs), + } +} + +func MiniHeadersToCommonType(miniHeaders []*MiniHeader) []*types.MiniHeader { + result := make([]*types.MiniHeader, 
len(miniHeaders)) + for i, miniHeader := range miniHeaders { + result[i] = MiniHeaderToCommonType(miniHeader) + } + return result +} + +func MetadataToCommonType(metadata *Metadata) *types.Metadata { + if metadata == nil { + return nil + } + return &types.Metadata{ + EthereumChainID: metadata.EthereumChainID, + MaxExpirationTime: metadata.MaxExpirationTime.Int, + EthRPCRequestsSentInCurrentUTCDay: metadata.EthRPCRequestsSentInCurrentUTCDay, + StartOfCurrentUTCDay: metadata.StartOfCurrentUTCDay, + } +} + +func MetadataFromCommonType(metadata *types.Metadata) *Metadata { + if metadata == nil { + return nil + } + return &Metadata{ + EthereumChainID: metadata.EthereumChainID, + MaxExpirationTime: NewSortedBigInt(metadata.MaxExpirationTime), + EthRPCRequestsSentInCurrentUTCDay: metadata.EthRPCRequestsSentInCurrentUTCDay, + StartOfCurrentUTCDay: metadata.StartOfCurrentUTCDay, + } +} diff --git a/db/transaction.go b/db/transaction.go deleted file mode 100644 index e46294f17..000000000 --- a/db/transaction.go +++ /dev/null @@ -1,203 +0,0 @@ -package db - -import ( - "errors" - "fmt" - "sync" - "sync/atomic" - - "github.com/albrow/stringset" -) - -var ( - ErrDiscarded = errors.New("transaction has already been discarded") - ErrCommitted = errors.New("transaction has already been committed") -) - -// ConflictingOperationsError is returned when two conflicting operations are attempted within the same -// transaction -type ConflictingOperationsError struct { - operation string -} - -func (e ConflictingOperationsError) Error() string { - return fmt.Sprintf("error on %s: cannot perform more than one operation on the same model within a transaction", e.operation) -} - -// Transaction is an atomic database transaction for a single collection which -// can be used to guarantee consistency. 
-type Transaction struct { - db *DB - mut sync.Mutex - colInfo *colInfo - batchWriter dbBatchWriter - readWriter *readerWithBatchWriter - committed bool - discarded bool - // internalCount keeps track of the number of models inserted/deleted within - // the transaction. An Insert increments internalCount and a Delete decrements - // it. When the transaction is committed, internalCount is added to the - // current count. - internalCount int64 - // affectedIDs keeps track of the model ids that are affected by this - // transaction. - affectedIDs stringset.Set -} - -// OpenTransaction opens and returns a new transaction for the collection. While -// the transaction is open, no other state changes (e.g. Insert, Update, or -// Delete) can be made to the collection (but concurrent reads are still -// allowed). -// -// Transactions are atomic, meaning that either: -// -// (1) The transaction will succeed and *all* queued operations will be -// applied, or -// (2) the transaction will fail or be discarded, in which case *none* of -// the queued operations will be applied. -// -// The transaction must be closed once done, either by committing or discarding -// the transaction. No changes will be made to the database state until the -// transaction is committed. -func (c *Collection) OpenTransaction() *Transaction { - // Note we acquire an RLock on the global write mutex. We're not really a - // "reader" but we behave like one in the context of an RWMutex. Up to one - // write lock for each collection can be held, or one global write lock can be - // held at any given time. - c.info.db.globalWriteLock.RLock() - c.info.writeMut.Lock() - return &Transaction{ - db: c.info.db, - colInfo: c.info.copy(), - batchWriter: c.ldb, - readWriter: newReaderWithBatchWriter(c.ldb), - affectedIDs: stringset.New(), - } -} - -// checkState acquires a lock on txn.mut and then calls unsafeCheckState. 
-func (txn *Transaction) checkState() error { - txn.mut.Lock() - defer txn.mut.Unlock() - return txn.unsafeCheckState() -} - -// unsafeCheckState checks the state of the transaction, assuming the caller has -// already acquired a lock. It returns an error if the transaction has already -// been committed or discarded. -func (txn *Transaction) unsafeCheckState() error { - if txn.discarded { - return ErrDiscarded - } else if txn.committed { - return ErrCommitted - } - return nil -} - -// Commit commits the transaction. If error is not nil, then the transaction is -// discarded. A new transaction must be created if you wish to retry the -// operations. -// -// Other methods should not be called after transaction has been committed. -func (txn *Transaction) Commit() error { - txn.mut.Lock() - defer txn.mut.Unlock() - if err := txn.unsafeCheckState(); err != nil { - return err - } - // Right before we commit, we need to update the count with txn.internalCount. - if err := updateCountWithTransaction(txn.colInfo, txn.readWriter, int(txn.internalCount)); err != nil { - _ = txn.Discard() - return err - } - if err := txn.batchWriter.Write(txn.readWriter.batch, nil); err != nil { - _ = txn.Discard() - return err - } - txn.committed = true - txn.colInfo.writeMut.Unlock() - txn.db.globalWriteLock.RUnlock() - return nil -} - -// Discard discards the transaction. -// -// Other methods should not be called after transaction has been discarded. -// However, it is safe to call Discard multiple times. -func (txn *Transaction) Discard() error { - txn.mut.Lock() - defer txn.mut.Unlock() - if txn.committed { - return ErrCommitted - } - if txn.discarded { - return nil - } - txn.discarded = true - txn.colInfo.writeMut.Unlock() - txn.db.globalWriteLock.RUnlock() - return nil -} - -// Insert queues an operation to insert the given model into the database. It -// returns an error if a model with the same id already exists. 
The model will -// not actually be inserted until the transaction is committed. -func (txn *Transaction) Insert(model Model) error { - txn.mut.Lock() - defer txn.mut.Unlock() - if err := txn.unsafeCheckState(); err != nil { - return err - } - if txn.affectedIDs.Contains(string(model.ID())) { - return ConflictingOperationsError{operation: "insert"} - } - if err := insertWithTransaction(txn.colInfo, txn.readWriter, model); err != nil { - return err - } - txn.updateInternalCount(1) - txn.affectedIDs.Add(string(model.ID())) - return nil -} - -// Update queues an operation to update an existing model in the database. It -// returns an error if the given model doesn't already exist. The model will -// not actually be updated until the transaction is committed. -func (txn *Transaction) Update(model Model) error { - txn.mut.Lock() - defer txn.mut.Unlock() - if err := txn.unsafeCheckState(); err != nil { - return err - } - if txn.affectedIDs.Contains(string(model.ID())) { - return ConflictingOperationsError{operation: "update"} - } - if err := updateWithTransaction(txn.colInfo, txn.readWriter, model); err != nil { - return err - } - txn.affectedIDs.Add(string(model.ID())) - return nil -} - -// Delete queues an operation to delete the model with the given ID from the -// database. It returns an error if the model doesn't exist in the database. The -// model will not actually be deleted until the transaction is committed. 
-func (txn *Transaction) Delete(id []byte) error { - txn.mut.Lock() - defer txn.mut.Unlock() - if err := txn.unsafeCheckState(); err != nil { - return err - } - if txn.affectedIDs.Contains(string(id)) { - return ConflictingOperationsError{operation: "delete"} - } - if err := deleteWithTransaction(txn.colInfo, txn.readWriter, id); err != nil { - return err - } - txn.updateInternalCount(-1) - txn.affectedIDs.Add(string(id)) - return nil -} - -func (txn *Transaction) updateInternalCount(diff int64) { - atomic.AddInt64(&txn.internalCount, diff) -} diff --git a/db/transaction_benchmark_test.go b/db/transaction_benchmark_test.go deleted file mode 100644 index 666960fe7..000000000 --- a/db/transaction_benchmark_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package db - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/require" -) - -func BenchmarkTransactionInsert100(b *testing.B) { - benchmarkTransactionInsert(b, 100) -} - -func BenchmarkTransactionInsert1000(b *testing.B) { - benchmarkTransactionInsert(b, 1000) -} - -func BenchmarkTransactionInsert10000(b *testing.B) { - benchmarkTransactionInsert(b, 10000) -} - -func benchmarkTransactionInsert(b *testing.B, count int) { - b.Helper() - db := newTestDB(b) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(b, err) - b.ResetTimer() - for i := 0; i < b.N; i++ { - txn := col.OpenTransaction() - defer func() { - _ = txn.Discard() - }() - for j := 0; j < count; j++ { - model := &testModel{ - Name: fmt.Sprintf("person_%d_%d", i, j), - Age: j, - } - err := txn.Insert(model) - b.StopTimer() - require.NoError(b, err) - b.StartTimer() - } - err := txn.Commit() - b.StopTimer() - require.NoError(b, err) - b.StartTimer() - } -} diff --git a/db/transaction_test.go b/db/transaction_test.go deleted file mode 100644 index 195648701..000000000 --- a/db/transaction_test.go +++ /dev/null @@ -1,389 +0,0 @@ -package db - -import ( - "fmt" - "strconv" - "sync" - "testing" - "time" - - 
"github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// transactionExclusionTestTimeout is used in transaction exclusion tests to -// timeout while waiting for one transaction to open. -const transactionExclusionTestTimeout = 500 * time.Millisecond - -func TestTransaction(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - - ageIndex := col.AddIndex("age", func(m Model) []byte { - return []byte(fmt.Sprint(m.(*testModel).Age)) - }) - - // expected is a set of testModels with Age = 42 - expected := []*testModel{} - for i := 0; i < 5; i++ { - model := &testModel{ - Name: "ExpectedPerson_" + strconv.Itoa(i), - Age: 42, - } - require.NoError(t, col.Insert(model)) - expected = append(expected, model) - } - - // Open a transaction. - txn := col.OpenTransaction() - defer func() { - err := txn.Discard() - if err != nil && err != ErrCommitted { - t.Error(err) - } - }() - - // The WaitGroup will be used to wait for all goroutines to finish. - wg := &sync.WaitGroup{} - - // Any models we add after opening the transaction should not affect the query. - wg.Add(1) - go func() { - defer wg.Done() - for i := 0; i < 5; i++ { - model := &testModel{ - Name: "OtherPerson_" + strconv.Itoa(i), - Age: 42, - } - require.NoError(t, col.Insert(model)) - } - }() - - // Any models we delete after opening the transaction should not affect the query. - idToDelete := expected[2].ID() - wg.Add(1) - go func(idToDelete []byte) { - defer wg.Done() - require.NoError(t, col.Delete(idToDelete)) - }(idToDelete) - - // Any new indexes we add should not affect indexes in the transaction. - col.AddIndex("name", func(m Model) []byte { - return []byte(m.(*testModel).Name) - }) - assert.Equal(t, []*Index{ageIndex}, txn.colInfo.indexes) - - // Make sure that the query only return results that match the state inside - // the transaction. 
- filter := ageIndex.ValueFilter([]byte("42")) - query := col.NewQuery(filter) - var actual []*testModel - require.NoError(t, query.Run(&actual)) - assert.Equal(t, expected, actual) - - // Commit the transaction. - require.NoError(t, txn.Commit()) - - // Wait for any goroutines to finish. - wg.Wait() -} - -func TestTransactionCount(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - - // insertedBeforeTransaction is a set of testModels inserted before the - // transaction is opened. - insertedBeforeTransaction := []*testModel{} - for i := 0; i < 10; i++ { - model := &testModel{ - Name: "Before_Transaction_" + strconv.Itoa(i), - Age: i, - } - require.NoError(t, col.Insert(model)) - insertedBeforeTransaction = append(insertedBeforeTransaction, model) - } - - // Open a transaction. - txn := col.OpenTransaction() - defer func() { - err := txn.Discard() - if err != nil && err != ErrCommitted { - t.Error(err) - } - }() - - // Insert some models inside the transaction. - for i := 0; i < 7; i++ { - model := &testModel{ - Name: "Inside_Transaction_" + strconv.Itoa(i), - Age: i, - } - require.NoError(t, txn.Insert(model)) - } - - // The WaitGroup will be used to wait for all goroutines to finish. - wg := &sync.WaitGroup{} - - // Insert some models outside the transaction. - wg.Add(1) - go func() { - defer wg.Done() - for i := 0; i < 4; i++ { - model := &testModel{ - Name: "Outside_Transaction_" + strconv.Itoa(i), - Age: 42, - } - require.NoError(t, col.Insert(model)) - } - }() - - // Delete some models inside of the transaction. - idsToDeleteInside := [][]byte{ - insertedBeforeTransaction[0].ID(), - insertedBeforeTransaction[1].ID(), - insertedBeforeTransaction[2].ID(), - } - for _, id := range idsToDeleteInside { - require.NoError(t, txn.Delete(id)) - } - - // Delete some models outside of the transaction. 
- idsToDeleteOutside := [][]byte{ - insertedBeforeTransaction[3].ID(), - insertedBeforeTransaction[4].ID(), - } - wg.Add(1) - go func() { - defer wg.Done() - for _, id := range idsToDeleteOutside { - require.NoError(t, col.Delete(id)) - } - }() - - // Make sure that prior to commiting the transaction, Count only includes the - // models inserted/deleted before the transaction was open. - expectedPreCommitCount := 10 - actualPreCommitCount, err := col.Count() - require.NoError(t, err) - assert.Equal(t, expectedPreCommitCount, actualPreCommitCount) - - // Commit the transaction. - require.NoError(t, txn.Commit()) - - // Wait for any goroutines to finish. - wg.Wait() - - // Make sure that after commiting the transaction, Count includes the models - // inserted/deleted in the transaction and outside of the transaction. - // 10 before transaction. - // +7 inserted inside transaction - // +4 inserted outside transaction - // -3 deleted inside transaction - // -2 deleted outside transaction - // = 16 total - expectedPostCommitCount := 16 - actualPostCommitCount, err := col.Count() - require.NoError(t, err) - assert.Equal(t, expectedPostCommitCount, actualPostCommitCount) -} - -// TestTransactionExclusion is designed to test whether a collection-based -// transaction has exclusive write access for the collection while open. -func TestTransactionExclusion(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col0, err := db.NewCollection("people0", &testModel{}) - require.NoError(t, err) - col1, err := db.NewCollection("people1", &testModel{}) - require.NoError(t, err) - - txn := col0.OpenTransaction() - defer func() { - _ = txn.Discard() - }() - - // col0TxnOpenSignal is fired when a transaction on col0 is opened. - col0TxnOpenSignal := make(chan struct{}) - // col1TxnOpenSignal is fired when a transaction on col1 is opened. 
- col1TxnOpenSignal := make(chan struct{}) - - wg := &sync.WaitGroup{} - wg.Add(1) - go func() { - defer wg.Done() - txn := col0.OpenTransaction() - close(col0TxnOpenSignal) - defer func() { - _ = txn.Discard() - }() - }() - - wg.Add(1) - go func() { - defer wg.Done() - txn := col1.OpenTransaction() - close(col1TxnOpenSignal) - defer func() { - _ = txn.Discard() - }() - }() - - select { - case <-col1TxnOpenSignal: - // This is expected. Continue the test. - break - case <-time.After(transactionExclusionTestTimeout): - t.Fatal("timed out waiting for col1 transaction to open") - case <-col0TxnOpenSignal: - t.Error("a new transaction was opened on col0 before the first transaction was committed/discarded") - } - - require.NoError(t, txn.Discard()) - - select { - case <-col0TxnOpenSignal: - // This is expected. Continue the test. - break - case <-time.After(transactionExclusionTestTimeout): - t.Fatal("timed out waiting for second col0 transaction to open") - } - - wg.Wait() -} - -func TestTransactionDeleteThenInsertSameModel(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - - // Create a collection and insert one model. - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - model := &testModel{ - Name: "ExpectedPerson", - Age: 42, - } - require.NoError(t, col.Insert(model)) - - // Delete and then insert the same model in a single transaction. 
- txn := col.OpenTransaction() - defer func() { - err := txn.Discard() - if err != nil && err != ErrCommitted { - t.Error(err) - } - }() - require.NoError(t, txn.Delete(model.ID())) - err = txn.Insert(model) - assert.Error(t, err) - assert.Equal(t, ConflictingOperationsError{operation: "insert"}, err, "wrong error") -} - -func TestTransactionInsertThenDeleteSameModel(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - model := &testModel{ - Name: "ExpectedPerson", - Age: 42, - } - - // Insert and then delete the same model in a single transaction. - txn := col.OpenTransaction() - defer func() { - err := txn.Discard() - if err != nil && err != ErrCommitted { - t.Error(err) - } - }() - require.NoError(t, txn.Insert(model)) - err = txn.Delete(model.ID()) - assert.Error(t, err) - assert.Equal(t, ConflictingOperationsError{operation: "delete"}, err, "wrong error") -} - -func TestTransactionInsertThenInsertSameModel(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - model := &testModel{ - Name: "ExpectedPerson", - Age: 42, - } - - // Insert the same model twice in the same transaction. - txn := col.OpenTransaction() - defer func() { - err := txn.Discard() - if err != nil && err != ErrCommitted { - t.Error(err) - } - }() - require.NoError(t, txn.Insert(model)) - err = txn.Insert(model) - assert.Error(t, err) - assert.Equal(t, ConflictingOperationsError{operation: "insert"}, err, "wrong error") -} - -func TestTransactionDeleteThenDeleteSameModel(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - model := &testModel{ - Name: "ExpectedPerson", - Age: 42, - } - require.NoError(t, col.Insert(model)) - - // Delete the same model twice in the same transaction. 
- txn := col.OpenTransaction() - defer func() { - err := txn.Discard() - if err != nil && err != ErrCommitted { - t.Error(err) - } - }() - require.NoError(t, txn.Delete(model.ID())) - err = txn.Delete(model.ID()) - assert.Error(t, err) - assert.Equal(t, ConflictingOperationsError{operation: "delete"}, err, "wrong error") -} - -func TestTransactionInsertThenUpdateSameModel(t *testing.T) { - t.Parallel() - db := newTestDB(t) - defer db.Close() - col, err := db.NewCollection("people", &testModel{}) - require.NoError(t, err) - model := &testModel{ - Name: "ExpectedPerson", - Age: 42, - } - - // Insert and then update the same model within the same transaction. - txn := col.OpenTransaction() - defer func() { - err := txn.Discard() - if err != nil && err != ErrCommitted { - t.Error(err) - } - }() - require.NoError(t, txn.Insert(model)) - err = txn.Update(model) - assert.Error(t, err) - assert.Equal(t, ConflictingOperationsError{operation: "update"}, err, "wrong error") -} diff --git a/db/types_js.go b/db/types_js.go new file mode 100644 index 000000000..d50350dcd --- /dev/null +++ b/db/types_js.go @@ -0,0 +1,30 @@ +// +build js,wasm + +package db + +import ( + "syscall/js" + + "github.com/0xProject/0x-mesh/packages/browser/go/jsutil" +) + +func (opts *Options) JSValue() js.Value { + value, _ := jsutil.InefficientlyConvertToJS(opts) + return value +} + +func (query *OrderQuery) JSValue() js.Value { + if query == nil { + return js.Null() + } + value, _ := jsutil.InefficientlyConvertToJS(query) + return value +} + +func (query *MiniHeaderQuery) JSValue() js.Value { + if query == nil { + return js.Null() + } + value, _ := jsutil.InefficientlyConvertToJS(query) + return value +} diff --git a/ethereum/blockwatch/block_watcher.go b/ethereum/blockwatch/block_watcher.go index 2f0d4cc8a..efd30c79e 100644 --- a/ethereum/blockwatch/block_watcher.go +++ b/ethereum/blockwatch/block_watcher.go @@ -9,14 +9,14 @@ import ( "sync" "time" + "github.com/0xProject/0x-mesh/common/types" 
"github.com/0xProject/0x-mesh/constants" - "github.com/0xProject/0x-mesh/ethereum/miniheader" + "github.com/0xProject/0x-mesh/db" "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" + ethtypes "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/event" log "github.com/sirupsen/logrus" - "github.com/syndtr/goleveldb/leveldb" ) // go-ethereum client `ethereum.NotFound` error type message @@ -50,19 +50,7 @@ const ( // Event describes a block event emitted by a Watcher type Event struct { Type EventType - BlockHeader *miniheader.MiniHeader -} - -// Stack defines the interface a stack must implement in order to be used by -// OrderWatcher for block header storage -type Stack interface { - Pop() (*miniheader.MiniHeader, error) - Push(*miniheader.MiniHeader) error - Peek() (*miniheader.MiniHeader, error) - PeekAll() ([]*miniheader.MiniHeader, error) - Clear() error - Checkpoint() (int, error) - Reset(int) error + BlockHeader *types.MiniHeader } // TooMayBlocksBehindError is an error returned if the BlockWatcher has fallen too many blocks behind @@ -78,7 +66,7 @@ func (e TooMayBlocksBehindError) Error() string { // Config holds some configuration options for an instance of BlockWatcher. type Config struct { - Stack Stack + DB *db.DB PollingInterval time.Duration WithLogs bool Topics []common.Hash @@ -89,7 +77,7 @@ type Config struct { // supplied stack) handling block re-orgs and network disruptions gracefully. It can be started from // any arbitrary block height, and will emit both block added and removed events. 
type Watcher struct { - stack Stack + stack *Stack client Client blockFeed event.Feed blockScope event.SubscriptionScope // Subscription scope tracking current live listeners @@ -105,7 +93,7 @@ type Watcher struct { func New(config Config) *Watcher { return &Watcher{ pollingInterval: config.PollingInterval, - stack: config.Stack, + stack: NewStack(config.DB), client: config.Client, withLogs: config.WithLogs, topics: config.Topics, @@ -179,7 +167,7 @@ func (w *Watcher) Watch(ctx context.Context) error { // Sync immediately when `Watch()` is called instead of waiting for the // first Ticker tick if err := w.SyncToLatestBlock(); err != nil { - if err == leveldb.ErrClosed { + if err == db.ErrClosed { // We can't continue if the database is closed. Stop the watcher and // return an error. return err @@ -200,7 +188,7 @@ func (w *Watcher) Watch(ctx context.Context) error { return nil case <-ticker.C: if err := w.SyncToLatestBlock(); err != nil { - if err == leveldb.ErrClosed { + if err == db.ErrClosed { // We can't continue if the database is closed. Stop the watcher and // return an error. ticker.Stop() @@ -241,7 +229,7 @@ func (w *Watcher) SyncToLatestBlock() error { w.syncToLatestBlockMu.Lock() defer w.syncToLatestBlockMu.Unlock() - checkpointID, err := w.stack.Checkpoint() + existingMiniHeaders, err := w.stack.PeekAll() if err != nil { return err } @@ -314,21 +302,21 @@ func (w *Watcher) SyncToLatestBlock() error { return syncErr } if w.shouldRevertChanges(lastStoredHeader, allEvents) { - if err := w.stack.Reset(checkpointID); err != nil { + // TODO(albrow): Take another look at this. Maybe extract to method. 
+ if err := w.stack.Clear(); err != nil { return err } - } else { - _, err = w.stack.Checkpoint() - if err != nil { + if _, _, err := w.stack.db.AddMiniHeaders(existingMiniHeaders); err != nil { return err } + } else { w.blockFeed.Send(allEvents) } return syncErr } -func (w *Watcher) shouldRevertChanges(lastStoredHeader *miniheader.MiniHeader, events []*Event) bool { +func (w *Watcher) shouldRevertChanges(lastStoredHeader *types.MiniHeader, events []*Event) bool { if len(events) == 0 || lastStoredHeader == nil { return false } @@ -339,7 +327,7 @@ func (w *Watcher) shouldRevertChanges(lastStoredHeader *miniheader.MiniHeader, e return newLatestHeader.Number.Cmp(lastStoredHeader.Number) <= 0 } -func (w *Watcher) buildCanonicalChain(nextHeader *miniheader.MiniHeader, events []*Event) ([]*Event, error) { +func (w *Watcher) buildCanonicalChain(nextHeader *types.MiniHeader, events []*Event) ([]*Event, error) { latestHeader, err := w.stack.Peek() if err != nil { return nil, err @@ -394,7 +382,7 @@ func (w *Watcher) buildCanonicalChain(nextHeader *miniheader.MiniHeader, events return events, nil } -func (w *Watcher) addLogs(header *miniheader.MiniHeader) (*miniheader.MiniHeader, error) { +func (w *Watcher) addLogs(header *types.MiniHeader) (*types.MiniHeader, error) { if !w.withLogs { return header, nil } @@ -443,20 +431,20 @@ func (w *Watcher) getMissedEventsToBackfill(ctx context.Context, blocksElapsed i // Create the block events from all the logs found by grouping // them into blockHeaders - hashToBlockHeader := map[common.Hash]*miniheader.MiniHeader{} + hashToBlockHeader := map[common.Hash]*types.MiniHeader{} for _, log := range logs { - blockHeader, ok := hashToBlockHeader[log.BlockHash] - if !ok { + blockHeader, found := hashToBlockHeader[log.BlockHash] + if !found { blockNumber := big.NewInt(0).SetUint64(log.BlockNumber) header, err := w.client.HeaderByNumber(blockNumber) if err != nil { return events, err } - blockHeader = &miniheader.MiniHeader{ + blockHeader = 
&types.MiniHeader{ Hash: log.BlockHash, Parent: header.Parent, Number: blockNumber, - Logs: []types.Log{}, + Logs: []ethtypes.Log{}, Timestamp: header.Timestamp, } hashToBlockHeader[log.BlockHash] = blockHeader @@ -478,7 +466,7 @@ func (w *Watcher) getMissedEventsToBackfill(ctx context.Context, blocksElapsed i type logRequestResult struct { From int To int - Logs []types.Log + Logs []ethtypes.Log Err error } @@ -490,7 +478,7 @@ const getLogsRequestChunkSize = 3 // the next batch of requests to be sent. If an error is encountered in a batch, all subsequent // batch requests are not sent. Instead, it returns all the logs it found up until the error was // encountered, along with the block number after which no further logs were retrieved. -func (w *Watcher) getLogsInBlockRange(ctx context.Context, from, to int) ([]types.Log, int) { +func (w *Watcher) getLogsInBlockRange(ctx context.Context, from, to int) ([]ethtypes.Log, int) { blockRanges := w.getSubBlockRanges(from, to, maxBlocksInGetLogsQuery) numChunks := 0 @@ -512,7 +500,7 @@ func (w *Watcher) getLogsInBlockRange(ctx context.Context, from, to int) ([]type didAPreviousRequestFail := false furthestBlockProcessed := from - 1 - allLogs := []types.Log{} + allLogs := []ethtypes.Log{} for i := 0; i < numChunks; i++ { // Add one to the semaphore chan. If it already has a value, the chunk blocks here until one frees up. 
@@ -542,13 +530,13 @@ func (w *Watcher) getLogsInBlockRange(ctx context.Context, from, to int) ([]type From: b.FromBlock, To: b.ToBlock, Err: errors.New("context was canceled"), - Logs: []types.Log{}, + Logs: []ethtypes.Log{}, } return default: } - logs, err := w.filterLogsRecurisively(b.FromBlock, b.ToBlock, []types.Log{}) + logs, err := w.filterLogsRecurisively(b.FromBlock, b.ToBlock, []ethtypes.Log{}) if err != nil { log.WithFields(map[string]interface{}{ "error": err, @@ -634,7 +622,7 @@ func (w *Watcher) getSubBlockRanges(from, to, rangeSize int) []*blockRange { const infuraTooManyResultsErrMsg = "query returned more than 10000 results" -func (w *Watcher) filterLogsRecurisively(from, to int, allLogs []types.Log) ([]types.Log, error) { +func (w *Watcher) filterLogsRecurisively(from, to int, allLogs []ethtypes.Log) ([]ethtypes.Log, error) { log.WithFields(map[string]interface{}{ "from": from, "to": to, @@ -689,7 +677,7 @@ func (w *Watcher) filterLogsRecurisively(from, to int, allLogs []types.Log) ([]t } // getAllRetainedBlocks returns the blocks retained in-memory by the Watcher. 
-func (w *Watcher) getAllRetainedBlocks() ([]*miniheader.MiniHeader, error) { +func (w *Watcher) getAllRetainedBlocks() ([]*types.MiniHeader, error) { return w.stack.PeekAll() } diff --git a/ethereum/blockwatch/block_watcher_test.go b/ethereum/blockwatch/block_watcher_test.go index aa44142c6..db3fe925e 100644 --- a/ethereum/blockwatch/block_watcher_test.go +++ b/ethereum/blockwatch/block_watcher_test.go @@ -11,10 +11,10 @@ import ( "testing" "time" - "github.com/0xProject/0x-mesh/ethereum/miniheader" - "github.com/0xProject/0x-mesh/ethereum/simplestack" + "github.com/0xProject/0x-mesh/common/types" + "github.com/0xProject/0x-mesh/db" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" + ethtypes "github.com/ethereum/go-ethereum/core/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -27,18 +27,27 @@ var config = Config{ var ( basicFakeClientFixture = "testdata/fake_client_basic_fixture.json" - blockRetentionLimit = 10 - startMiniHeaders = []*miniheader.MiniHeader{} ) +func dbOptions() *db.Options { + options := db.TestOptions() + // For the block watcher tests we set MaxMiniHeaders to 10. 
+ options.MaxMiniHeaders = 10 + return options +} + func TestWatcher(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + database, err := db.New(ctx, dbOptions()) + require.NoError(t, err) fakeClient, err := newFakeClient("testdata/fake_client_block_poller_fixtures.json") require.NoError(t, err) // Polling interval unused because we hijack the ticker for this test require.NoError(t, err) - config.Stack = simplestack.New(blockRetentionLimit, startMiniHeaders) config.Client = fakeClient + config.DB = database watcher := New(config) // Having a buffer of 1 unblocks the below for-loop without resorting to a goroutine @@ -58,13 +67,13 @@ func TestWatcher(t *testing.T) { retainedBlocks, err := watcher.getAllRetainedBlocks() require.NoError(t, err) expectedRetainedBlocks := fakeClient.ExpectedRetainedBlocks() - assert.Equal(t, expectedRetainedBlocks, retainedBlocks, scenarioLabel) + assert.Equal(t, expectedRetainedBlocks, retainedBlocks, fmt.Sprintf("%s (timestep: %d)", scenarioLabel, i)) expectedEvents := fakeClient.GetEvents() if len(expectedEvents) != 0 { select { case gotEvents := <-events: - assert.Equal(t, expectedEvents, gotEvents, scenarioLabel) + assert.Equal(t, expectedEvents, gotEvents, fmt.Sprintf("%s (timestep: %d)", scenarioLabel, i)) case <-time.After(3 * time.Second): t.Fatal("Timed out waiting for Events channel to deliver expected events") @@ -80,17 +89,20 @@ func TestWatcher(t *testing.T) { } func TestWatcherStartStop(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + database, err := db.New(ctx, dbOptions()) + require.NoError(t, err) fakeClient, err := newFakeClient(basicFakeClientFixture) require.NoError(t, err) require.NoError(t, err) - config.Stack = simplestack.New(blockRetentionLimit, startMiniHeaders) config.Client = fakeClient + config.DB = database watcher := New(config) // Start the watcher in a goroutine. 
We use the done channel to signal when // watcher.Watch returns. - ctx, cancel := context.WithCancel(context.Background()) done := make(chan struct{}) defer cancel() go func() { @@ -120,61 +132,61 @@ type blockRangeChunksTestCase struct { func TestGetSubBlockRanges(t *testing.T) { rangeSize := 6 testCases := []blockRangeChunksTestCase{ - blockRangeChunksTestCase{ + { from: 10, to: 10, expectedBlockRanges: []*blockRange{ - &blockRange{ + { FromBlock: 10, ToBlock: 10, }, }, }, - blockRangeChunksTestCase{ + { from: 10, to: 16, expectedBlockRanges: []*blockRange{ - &blockRange{ + { FromBlock: 10, ToBlock: 15, }, - &blockRange{ + { FromBlock: 16, ToBlock: 16, }, }, }, - blockRangeChunksTestCase{ + { from: 10, to: 22, expectedBlockRanges: []*blockRange{ - &blockRange{ + { FromBlock: 10, ToBlock: 15, }, - &blockRange{ + { FromBlock: 16, ToBlock: 21, }, - &blockRange{ + { FromBlock: 22, ToBlock: 22, }, }, }, - blockRangeChunksTestCase{ + { from: 10, to: 24, expectedBlockRanges: []*blockRange{ - &blockRange{ + { FromBlock: 10, ToBlock: 15, }, - &blockRange{ + { FromBlock: 16, ToBlock: 21, }, - &blockRange{ + { FromBlock: 22, ToBlock: 24, }, @@ -182,11 +194,14 @@ func TestGetSubBlockRanges(t *testing.T) { }, } - fakeClient, err := newFakeClient(basicFakeClientFixture) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + database, err := db.New(ctx, dbOptions()) require.NoError(t, err) + fakeClient, err := newFakeClient(basicFakeClientFixture) require.NoError(t, err) - config.Stack = simplestack.New(blockRetentionLimit, startMiniHeaders) config.Client = fakeClient + config.DB = database watcher := New(config) for _, testCase := range testCases { @@ -196,110 +211,111 @@ func TestGetSubBlockRanges(t *testing.T) { } func TestFastSyncToLatestBlockLessThan128Missed(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + database, err := db.New(ctx, dbOptions()) + require.NoError(t, err) // Fixture will return block 132 
as the tip of the chain (127 blocks from block 5) fakeClient, err := newFakeClient("testdata/fake_client_fast_sync_fixture.json") require.NoError(t, err) require.NoError(t, err) // Add block number 5 as the last block seen by BlockWatcher - lastBlockSeen := &miniheader.MiniHeader{ + lastBlockSeen := &types.MiniHeader{ Number: big.NewInt(5), Hash: common.HexToHash("0x293b9ea024055a3e9eddbf9b9383dc7731744111894af6aa038594dc1b61f87f"), Parent: common.HexToHash("0x26b13ac89500f7fcdd141b7d1b30f3a82178431eca325d1cf10998f9d68ff5ba"), Timestamp: time.Now(), } - config.Stack = simplestack.New(blockRetentionLimit, startMiniHeaders) - - err = config.Stack.Push(lastBlockSeen) - require.NoError(t, err) - + config.DB = database config.Client = fakeClient watcher := New(config) + err = watcher.stack.Push(lastBlockSeen) + require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() blocksElapsed, err := watcher.FastSyncToLatestBlock(ctx) require.NoError(t, err) assert.Equal(t, 127, blocksElapsed) // Check that block 132 is now in the DB, and block 5 was removed. 
- headers, err := config.Stack.PeekAll() + headers, err := watcher.stack.PeekAll() require.NoError(t, err) require.Len(t, headers, 1) assert.Equal(t, big.NewInt(132), headers[0].Number) } func TestFastSyncToLatestBlockMoreThanOrExactly128Missed(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + database, err := db.New(ctx, dbOptions()) + require.NoError(t, err) // Fixture will return block 133 as the tip of the chain (128 blocks from block 5) fakeClient, err := newFakeClient("testdata/fake_client_reset_fixture.json") require.NoError(t, err) require.NoError(t, err) // Add block number 5 as the last block seen by BlockWatcher - lastBlockSeen := &miniheader.MiniHeader{ + lastBlockSeen := &types.MiniHeader{ Number: big.NewInt(5), Hash: common.HexToHash("0x293b9ea024055a3e9eddbf9b9383dc7731744111894af6aa038594dc1b61f87f"), Parent: common.HexToHash("0x26b13ac89500f7fcdd141b7d1b30f3a82178431eca325d1cf10998f9d68ff5ba"), Timestamp: time.Now(), } - config.Stack = simplestack.New(blockRetentionLimit, startMiniHeaders) - - err = config.Stack.Push(lastBlockSeen) - require.NoError(t, err) - + config.DB = database config.Client = fakeClient watcher := New(config) + err = watcher.stack.Push(lastBlockSeen) + require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() blocksElapsed, err := watcher.FastSyncToLatestBlock(ctx) require.NoError(t, err) assert.Equal(t, 128, blocksElapsed) // Check that all blocks have been removed from BlockWatcher - headers, err := config.Stack.PeekAll() + headers, err := watcher.stack.PeekAll() require.NoError(t, err) require.Len(t, headers, 0) } func TestFastSyncToLatestBlockNoneMissed(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + database, err := db.New(ctx, dbOptions()) + require.NoError(t, err) // Fixture will return block 5 as the tip of the chain fakeClient, err := newFakeClient("testdata/fake_client_basic_fixture.json") 
require.NoError(t, err) require.NoError(t, err) // Add block number 5 as the last block seen by BlockWatcher - lastBlockSeen := &miniheader.MiniHeader{ + lastBlockSeen := &types.MiniHeader{ Number: big.NewInt(5), Hash: common.HexToHash("0x293b9ea024055a3e9eddbf9b9383dc7731744111894af6aa038594dc1b61f87f"), Parent: common.HexToHash("0x26b13ac89500f7fcdd141b7d1b30f3a82178431eca325d1cf10998f9d68ff5ba"), Timestamp: time.Now(), } - config.Stack = simplestack.New(blockRetentionLimit, startMiniHeaders) - - err = config.Stack.Push(lastBlockSeen) - require.NoError(t, err) - + config.DB = database config.Client = fakeClient watcher := New(config) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + err = watcher.stack.Push(lastBlockSeen) + require.NoError(t, err) + blocksElapsed, err := watcher.FastSyncToLatestBlock(ctx) require.NoError(t, err) assert.Equal(t, blocksElapsed, 0) // Check that block 5 is still in the DB - headers, err := config.Stack.PeekAll() + headers, err := watcher.stack.PeekAll() require.NoError(t, err) require.Len(t, headers, 1) assert.Equal(t, big.NewInt(5), headers[0].Number) } -var logStub = types.Log{ +var logStub = ethtypes.Log{ Address: common.HexToAddress("0x21ab6c9fac80c59d401b37cb43f81ea9dde7fe34"), Topics: []common.Hash{ common.HexToHash("0xddf252ad1be2c89b69c2b068fc378daa952ba7f163c4a11628f55a4df523b3ef"), @@ -321,91 +337,91 @@ type filterLogsRecusivelyTestCase struct { Label string rangeToFilterLogsResponse map[string]filterLogsResponse Err error - Logs []types.Log + Logs []ethtypes.Log } func TestFilterLogsRecursively(t *testing.T) { from := 10 to := 20 testCases := []filterLogsRecusivelyTestCase{ - filterLogsRecusivelyTestCase{ + { Label: "HAPPY_PATH", rangeToFilterLogsResponse: map[string]filterLogsResponse{ "10-20": filterLogsResponse{ - Logs: []types.Log{ + Logs: []ethtypes.Log{ logStub, }, }, }, - Logs: []types.Log{logStub}, + Logs: []ethtypes.Log{logStub}, }, - filterLogsRecusivelyTestCase{ + { Label: 
"TOO_MANY_RESULTS_INFURA_ERROR", rangeToFilterLogsResponse: map[string]filterLogsResponse{ - "10-20": filterLogsResponse{ + "10-20": { Err: errors.New(infuraTooManyResultsErrMsg), }, - "10-15": filterLogsResponse{ - Logs: []types.Log{ + "10-15": { + Logs: []ethtypes.Log{ logStub, }, }, - "16-20": filterLogsResponse{ - Logs: []types.Log{ + "16-20": { + Logs: []ethtypes.Log{ logStub, }, }, }, - Logs: []types.Log{logStub, logStub}, + Logs: []ethtypes.Log{logStub, logStub}, }, - filterLogsRecusivelyTestCase{ + { Label: "TOO_MANY_RESULTS_INFURA_ERROR_DEEPER_RECURSION", rangeToFilterLogsResponse: map[string]filterLogsResponse{ - "10-20": filterLogsResponse{ + "10-20": { Err: errors.New(infuraTooManyResultsErrMsg), }, - "10-15": filterLogsResponse{ - Logs: []types.Log{ + "10-15": { + Logs: []ethtypes.Log{ logStub, }, }, - "16-20": filterLogsResponse{ + "16-20": { Err: errors.New(infuraTooManyResultsErrMsg), }, - "16-18": filterLogsResponse{ - Logs: []types.Log{ + "16-18": { + Logs: []ethtypes.Log{ logStub, }, }, - "19-20": filterLogsResponse{ - Logs: []types.Log{ + "19-20": { + Logs: []ethtypes.Log{ logStub, }, }, }, - Logs: []types.Log{logStub, logStub, logStub}, + Logs: []ethtypes.Log{logStub, logStub, logStub}, }, - filterLogsRecusivelyTestCase{ + { Label: "TOO_MANY_RESULTS_INFURA_ERROR_DEEPER_RECURSION_FAILURE", rangeToFilterLogsResponse: map[string]filterLogsResponse{ - "10-20": filterLogsResponse{ + "10-20": { Err: errors.New(infuraTooManyResultsErrMsg), }, - "10-15": filterLogsResponse{ - Logs: []types.Log{ + "10-15": { + Logs: []ethtypes.Log{ logStub, }, }, - "16-20": filterLogsResponse{ + "16-20": { Err: errUnexpected, }, }, Err: errUnexpected, }, - filterLogsRecusivelyTestCase{ + { Label: "UNEXPECTED_ERROR", rangeToFilterLogsResponse: map[string]filterLogsResponse{ - "10-20": filterLogsResponse{ + "10-20": { Err: errUnexpected, }, }, @@ -413,15 +429,18 @@ func TestFilterLogsRecursively(t *testing.T) { }, } - config.Stack = simplestack.New(blockRetentionLimit, 
startMiniHeaders) - for _, testCase := range testCases { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + database, err := db.New(ctx, dbOptions()) + require.NoError(t, err) fakeLogClient, err := newFakeLogClient(testCase.rangeToFilterLogsResponse) require.NoError(t, err) config.Client = fakeLogClient + config.DB = database watcher := New(config) - logs, err := watcher.filterLogsRecurisively(from, to, []types.Log{}) + logs, err := watcher.filterLogsRecurisively(from, to, []ethtypes.Log{}) require.Equal(t, testCase.Err, err, testCase.Label) require.Equal(t, testCase.Logs, logs, testCase.Label) assert.Equal(t, len(testCase.rangeToFilterLogsResponse), fakeLogClient.Count()) @@ -433,7 +452,7 @@ type logsInBlockRangeTestCase struct { From int To int RangeToFilterLogsResponse map[string]filterLogsResponse - Logs []types.Log + Logs []ethtypes.Log FurthestBlockProcessed int } @@ -447,12 +466,12 @@ func TestGetLogsInBlockRange(t *testing.T) { To: to, RangeToFilterLogsResponse: map[string]filterLogsResponse{ aRange(from, to): filterLogsResponse{ - Logs: []types.Log{ + Logs: []ethtypes.Log{ logStub, }, }, }, - Logs: []types.Log{logStub}, + Logs: []ethtypes.Log{logStub}, FurthestBlockProcessed: to, }, logsInBlockRangeTestCase{ @@ -461,17 +480,17 @@ func TestGetLogsInBlockRange(t *testing.T) { To: from + maxBlocksInGetLogsQuery + 10, RangeToFilterLogsResponse: map[string]filterLogsResponse{ aRange(from, from+maxBlocksInGetLogsQuery-1): filterLogsResponse{ - Logs: []types.Log{ + Logs: []ethtypes.Log{ logStub, }, }, aRange(from+maxBlocksInGetLogsQuery, from+maxBlocksInGetLogsQuery+10): filterLogsResponse{ - Logs: []types.Log{ + Logs: []ethtypes.Log{ logStub, }, }, }, - Logs: []types.Log{logStub, logStub}, + Logs: []ethtypes.Log{logStub, logStub}, FurthestBlockProcessed: from + maxBlocksInGetLogsQuery + 10, }, logsInBlockRangeTestCase{ @@ -485,17 +504,17 @@ func TestGetLogsInBlockRange(t *testing.T) { Err: errUnexpected, }, 
aRange(from+maxBlocksInGetLogsQuery, from+(maxBlocksInGetLogsQuery*2)-1): filterLogsResponse{ - Logs: []types.Log{ + Logs: []ethtypes.Log{ logStub, }, }, aRange(from+(maxBlocksInGetLogsQuery*2), from+(maxBlocksInGetLogsQuery*3)-1): filterLogsResponse{ - Logs: []types.Log{ + Logs: []ethtypes.Log{ logStub, }, }, }, - Logs: []types.Log{}, + Logs: []ethtypes.Log{}, FurthestBlockProcessed: from - 1, }, logsInBlockRangeTestCase{ @@ -504,26 +523,26 @@ func TestGetLogsInBlockRange(t *testing.T) { To: from + maxBlocksInGetLogsQuery + 10, RangeToFilterLogsResponse: map[string]filterLogsResponse{ aRange(from, from+maxBlocksInGetLogsQuery-1): filterLogsResponse{ - Logs: []types.Log{ + Logs: []ethtypes.Log{ logStub, }, }, aRange(from+maxBlocksInGetLogsQuery, from+maxBlocksInGetLogsQuery+10): filterLogsResponse{ Err: errUnexpected, }}, - Logs: []types.Log{logStub}, + Logs: []ethtypes.Log{logStub}, FurthestBlockProcessed: from + maxBlocksInGetLogsQuery - 1, }, } - config.Stack = simplestack.New(blockRetentionLimit, startMiniHeaders) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - for _, testCase := range testCases { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + database, err := db.New(ctx, dbOptions()) + require.NoError(t, err) fakeLogClient, err := newFakeLogClient(testCase.RangeToFilterLogsResponse) require.NoError(t, err) + config.DB = database config.Client = fakeLogClient watcher := New(config) diff --git a/ethereum/blockwatch/client.go b/ethereum/blockwatch/client.go index d7ada54fa..a323cbad5 100644 --- a/ethereum/blockwatch/client.go +++ b/ethereum/blockwatch/client.go @@ -7,14 +7,14 @@ import ( "math/big" "time" + "github.com/0xProject/0x-mesh/common/types" "github.com/0xProject/0x-mesh/constants" "github.com/0xProject/0x-mesh/ethereum/ethrpcclient" - "github.com/0xProject/0x-mesh/ethereum/miniheader" "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" 
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/math" - "github.com/ethereum/go-ethereum/core/types" + ethtypes "github.com/ethereum/go-ethereum/core/types" ) var ( @@ -25,9 +25,9 @@ var ( // Client defines the methods needed to satisfy the client expected when // instantiating a Watcher instance. type Client interface { - HeaderByNumber(number *big.Int) (*miniheader.MiniHeader, error) - HeaderByHash(hash common.Hash) (*miniheader.MiniHeader, error) - FilterLogs(q ethereum.FilterQuery) ([]types.Log, error) + HeaderByNumber(number *big.Int) (*types.MiniHeader, error) + HeaderByHash(hash common.Hash) (*types.MiniHeader, error) + FilterLogs(q ethereum.FilterQuery) ([]ethtypes.Log, error) } // RpcClient is a Client for fetching Ethereum blocks from a specific JSON-RPC endpoint. @@ -63,7 +63,7 @@ func (e UnknownBlockNumberError) Error() string { // HeaderByNumber fetches a block header by its number. If no `number` is supplied, it will return the latest // block header. If no block exists with this number it will return a `ethereum.NotFound` error. -func (rc *RpcClient) HeaderByNumber(number *big.Int) (*miniheader.MiniHeader, error) { +func (rc *RpcClient) HeaderByNumber(number *big.Int) (*types.MiniHeader, error) { var blockParam string if number == nil { blockParam = "latest" @@ -101,7 +101,7 @@ func (rc *RpcClient) HeaderByNumber(number *big.Int) (*miniheader.MiniHeader, er if !ok { return nil, errors.New("Failed to parse big.Int value from hex-encoded block timestamp returned from eth_getBlockByNumber") } - miniHeader := &miniheader.MiniHeader{ + miniHeader := &types.MiniHeader{ Hash: header.Hash, Parent: header.ParentHash, Number: blockNum, @@ -122,7 +122,7 @@ func (e UnknownBlockHashError) Error() string { // HeaderByHash fetches a block header by its block hash. If no block exists with this number it will return // a `ethereum.NotFound` error. 
-func (rc *RpcClient) HeaderByHash(hash common.Hash) (*miniheader.MiniHeader, error) { +func (rc *RpcClient) HeaderByHash(hash common.Hash) (*types.MiniHeader, error) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) defer cancel() header, err := rc.ethRPCClient.HeaderByHash(ctx, hash) @@ -135,7 +135,7 @@ func (rc *RpcClient) HeaderByHash(hash common.Hash) (*miniheader.MiniHeader, err } return nil, err } - miniHeader := &miniheader.MiniHeader{ + miniHeader := &types.MiniHeader{ Hash: header.Hash(), Parent: header.ParentHash, Number: header.Number, @@ -156,7 +156,7 @@ func (e FilterUnknownBlockError) Error() string { } // FilterLogs returns the logs that satisfy the supplied filter query. -func (rc *RpcClient) FilterLogs(q ethereum.FilterQuery) ([]types.Log, error) { +func (rc *RpcClient) FilterLogs(q ethereum.FilterQuery) ([]ethtypes.Log, error) { ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) defer cancel() logs, err := rc.ethRPCClient.FilterLogs(ctx, q) diff --git a/ethereum/blockwatch/fake_client.go b/ethereum/blockwatch/fake_client.go index 1a0ad357a..db384cd43 100644 --- a/ethereum/blockwatch/fake_client.go +++ b/ethereum/blockwatch/fake_client.go @@ -7,20 +7,20 @@ import ( "math/big" "sync" - "github.com/0xProject/0x-mesh/ethereum/miniheader" + "github.com/0xProject/0x-mesh/common/types" "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" + ethtypes "github.com/ethereum/go-ethereum/core/types" ) // fixtureTimestep holds the JSON-RPC data available at every timestep of the simulation. 
type fixtureTimestep struct { - GetLatestBlock miniheader.MiniHeader `json:"getLatestBlock" gencodec:"required"` - GetBlockByNumber map[uint64]miniheader.MiniHeader `json:"getBlockByNumber" gencodec:"required"` - GetBlockByHash map[common.Hash]miniheader.MiniHeader `json:"getBlockByHash" gencodec:"required"` - GetCorrectChain []*miniheader.MiniHeader `json:"getCorrectChain" gencodec:"required"` - BlockEvents []*Event `json:"blockEvents" gencodec:"required"` - ScenarioLabel string `json:"scenarioLabel" gencodec:"required"` + GetLatestBlock types.MiniHeader `json:"getLatestBlock" gencodec:"required"` + GetBlockByNumber map[uint64]types.MiniHeader `json:"getBlockByNumber" gencodec:"required"` + GetBlockByHash map[common.Hash]types.MiniHeader `json:"getBlockByHash" gencodec:"required"` + GetCorrectChain []*types.MiniHeader `json:"getCorrectChain" gencodec:"required"` + BlockEvents []*Event `json:"blockEvents" gencodec:"required"` + ScenarioLabel string `json:"scenarioLabel" gencodec:"required"` } // fakeClient is a fake Client for testing purposes. @@ -46,11 +46,11 @@ func newFakeClient(fixtureFilePath string) (*fakeClient, error) { // HeaderByNumber fetches a block header by its number. If no `number` is supplied, it will return the latest // block header. If no block exists with this number it will return a `ethereum.NotFound` error. -func (fc *fakeClient) HeaderByNumber(number *big.Int) (*miniheader.MiniHeader, error) { +func (fc *fakeClient) HeaderByNumber(number *big.Int) (*types.MiniHeader, error) { fc.fixtureMut.Lock() defer fc.fixtureMut.Unlock() timestep := fc.fixtureData[fc.currentTimestep] - var miniHeader miniheader.MiniHeader + var miniHeader types.MiniHeader var ok bool if number == nil { miniHeader = timestep.GetLatestBlock @@ -65,7 +65,7 @@ func (fc *fakeClient) HeaderByNumber(number *big.Int) (*miniheader.MiniHeader, e // HeaderByHash fetches a block header by its block hash. 
If no block exists with this number it will return // a `ethereum.NotFound` error. -func (fc *fakeClient) HeaderByHash(hash common.Hash) (*miniheader.MiniHeader, error) { +func (fc *fakeClient) HeaderByHash(hash common.Hash) (*types.MiniHeader, error) { fc.fixtureMut.Lock() defer fc.fixtureMut.Unlock() timestep := fc.fixtureData[fc.currentTimestep] @@ -77,10 +77,10 @@ func (fc *fakeClient) HeaderByHash(hash common.Hash) (*miniheader.MiniHeader, er } // FilterLogs returns the logs that satisfy the supplied filter query. -func (fc *fakeClient) FilterLogs(q ethereum.FilterQuery) ([]types.Log, error) { +func (fc *fakeClient) FilterLogs(q ethereum.FilterQuery) ([]ethtypes.Log, error) { // IMPLEMENTED WITH A CANNED RESPONSE. FOR MORE ELABORATE TESTING, SEE `fakeLogClient` - return []types.Log{ - types.Log{ + return []ethtypes.Log{ + { Address: common.HexToAddress("0x21ab6c9fac80c59d401b37cb43f81ea9dde7fe34"), Topics: []common.Hash{ common.HexToHash("0xddf252ad1be2c89b69c2b068fc378daa952ba7f163c4a11628f55a4df523b3ef"), @@ -111,7 +111,7 @@ func (fc *fakeClient) NumberOfTimesteps() int { } // ExpectedRetainedBlocks returns the expected retained blocks at the current timestep. 
-func (fc *fakeClient) ExpectedRetainedBlocks() []*miniheader.MiniHeader { +func (fc *fakeClient) ExpectedRetainedBlocks() []*types.MiniHeader { fc.fixtureMut.Lock() defer fc.fixtureMut.Unlock() return fc.fixtureData[fc.currentTimestep].GetCorrectChain diff --git a/ethereum/blockwatch/fake_log_client.go b/ethereum/blockwatch/fake_log_client.go index fb2e34e0c..63a7407ad 100644 --- a/ethereum/blockwatch/fake_log_client.go +++ b/ethereum/blockwatch/fake_log_client.go @@ -7,14 +7,14 @@ import ( "sync/atomic" "time" - "github.com/0xProject/0x-mesh/ethereum/miniheader" + "github.com/0xProject/0x-mesh/common/types" "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" + ethtypes "github.com/ethereum/go-ethereum/core/types" ) type filterLogsResponse struct { - Logs []types.Log + Logs []ethtypes.Log Err error } @@ -31,17 +31,17 @@ func newFakeLogClient(rangeToResponse map[string]filterLogsResponse) (*fakeLogCl } // HeaderByNumber fetches a block header by its number -func (fc *fakeLogClient) HeaderByNumber(number *big.Int) (*miniheader.MiniHeader, error) { +func (fc *fakeLogClient) HeaderByNumber(number *big.Int) (*types.MiniHeader, error) { return nil, errors.New("NOT_IMPLEMENTED") } // HeaderByHash fetches a block header by its block hash -func (fc *fakeLogClient) HeaderByHash(hash common.Hash) (*miniheader.MiniHeader, error) { +func (fc *fakeLogClient) HeaderByHash(hash common.Hash) (*types.MiniHeader, error) { return nil, errors.New("NOT_IMPLEMENTED") } // FilterLogs returns the logs that satisfy the supplied filter query -func (fc *fakeLogClient) FilterLogs(q ethereum.FilterQuery) ([]types.Log, error) { +func (fc *fakeLogClient) FilterLogs(q ethereum.FilterQuery) ([]ethtypes.Log, error) { // Add a slight delay to simulate an actual network request. This also gives // BlockWatcher.getLogsInBlockRange multi-requests to hit the concurrent request // limit semaphore and simulate more realistic conditions. 
diff --git a/ethereum/blockwatch/stack.go b/ethereum/blockwatch/stack.go new file mode 100644 index 000000000..905dc6355 --- /dev/null +++ b/ethereum/blockwatch/stack.go @@ -0,0 +1,81 @@ +package blockwatch + +import ( + "fmt" + + "github.com/0xProject/0x-mesh/common/types" + "github.com/0xProject/0x-mesh/db" +) + +type MiniHeaderAlreadyExistsError struct { + miniHeader *types.MiniHeader +} + +func (e MiniHeaderAlreadyExistsError) Error() string { + return fmt.Sprintf("cannot add miniHeader with the same number (%s) or hash (%s) as an existing miniHeader", e.miniHeader.Number.String(), e.miniHeader.Hash.Hex()) +} + +// Stack is a simple in-memory stack used in tests +type Stack struct { + db *db.DB +} + +// New instantiates a new Stack +func NewStack(db *db.DB) *Stack { + return &Stack{ + db: db, + } +} + +// Peek returns the top of the stack +func (s *Stack) Peek() (*types.MiniHeader, error) { + latestMiniHeader, err := s.db.GetLatestMiniHeader() + if err != nil { + if err == db.ErrNotFound { + return nil, nil + } + return nil, err + } + return latestMiniHeader, nil +} + +// Pop returns the top of the stack and removes it from the stack +func (s *Stack) Pop() (*types.MiniHeader, error) { + removed, err := s.db.DeleteMiniHeaders(&db.MiniHeaderQuery{ + Limit: 1, + Sort: []db.MiniHeaderSort{ + { + Field: db.MFNumber, + Direction: db.Descending, + }, + }, + }) + if err != nil { + return nil, err + } else if len(removed) == 0 { + return nil, nil + } + return removed[0], nil +} + +// Push adds a db.MiniHeader to the stack. It returns an error if +// the stack already contains a miniHeader with the same number or +// hash. 
+func (s *Stack) Push(miniHeader *types.MiniHeader) error { + added, _, err := s.db.AddMiniHeaders([]*types.MiniHeader{miniHeader}) + if len(added) == 0 { + return MiniHeaderAlreadyExistsError{miniHeader: miniHeader} + } + return err +} + +// PeekAll returns all the miniHeaders currently in the stack +func (s *Stack) PeekAll() ([]*types.MiniHeader, error) { + return s.db.FindMiniHeaders(nil) +} + +// Clear removes all items from the stack and clears any set checkpoint +func (s *Stack) Clear() error { + _, err := s.db.DeleteMiniHeaders(nil) + return err +} diff --git a/ethereum/blockwatch/stack_test.go b/ethereum/blockwatch/stack_test.go new file mode 100644 index 000000000..f2abf626d --- /dev/null +++ b/ethereum/blockwatch/stack_test.go @@ -0,0 +1,92 @@ +package blockwatch + +import ( + "context" + "math/big" + "testing" + "time" + + "github.com/0xProject/0x-mesh/common/types" + "github.com/0xProject/0x-mesh/db" + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + miniHeaderOne = &types.MiniHeader{ + Number: big.NewInt(1), + Hash: common.HexToHash("0x1"), + Parent: common.HexToHash("0x0"), + Timestamp: time.Now().UTC(), + } +) + +func newTestStack(t *testing.T, ctx context.Context) *Stack { + database, err := db.New(ctx, db.TestOptions()) + require.NoError(t, err) + return NewStack(database) +} + +func TestStackPushPeekPop(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + stack := newTestStack(t, ctx) + + err := stack.Push(miniHeaderOne) + require.NoError(t, err) + expectedMiniHeader := miniHeaderOne + + actualMiniHeaders, err := stack.PeekAll() + require.NoError(t, err) + require.Len(t, actualMiniHeaders, 1) + assert.Equal(t, expectedMiniHeader, actualMiniHeaders[0]) + + actualMiniHeader, err := stack.Peek() + require.NoError(t, err) + assert.Equal(t, expectedMiniHeader, actualMiniHeader) + + actualMiniHeaders, err = stack.PeekAll() + 
require.NoError(t, err) + assert.Len(t, actualMiniHeaders, 1) + + actualMiniHeader, err = stack.Pop() + require.NoError(t, err) + assert.Equal(t, expectedMiniHeader, actualMiniHeader) + + actualMiniHeaders, err = stack.PeekAll() + require.NoError(t, err) + assert.Len(t, actualMiniHeaders, 0) +} + +func TestStackErrorIfPushTwoHeadersWithSameNumber(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + stack := newTestStack(t, ctx) + // Push miniHeaderOne + err := stack.Push(miniHeaderOne) + require.NoError(t, err) + // Push miniHeaderOne again + err = stack.Push(miniHeaderOne) + assert.Error(t, err) +} + +func TestStackClear(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + stack := newTestStack(t, ctx) + + err := stack.Push(miniHeaderOne) + require.NoError(t, err) + + miniHeader, err := stack.Peek() + require.NoError(t, err) + assert.Equal(t, miniHeaderOne, miniHeader) + + err = stack.Clear() + require.NoError(t, err) + + miniHeader, err = stack.Peek() + require.NoError(t, err) + require.Nil(t, miniHeader) +} diff --git a/ethereum/miniheader/miniheader.go b/ethereum/miniheader/miniheader.go index 81b35067e..b4b9b096f 100644 --- a/ethereum/miniheader/miniheader.go +++ b/ethereum/miniheader/miniheader.go @@ -19,7 +19,7 @@ type MiniHeader struct { // ID returns the MiniHeader's ID // HACK(fabio): This method is used when storing MiniHeaders in the DB -// Ideally this would live in the `meshdb` package but it adds the need +// Ideally this would live in the `db` package but it adds the need // to cast back-and-forth between two almost identical types so we keep // it here for convenience sake. 
func (m *MiniHeader) ID() []byte { diff --git a/ethereum/ratelimit/rate_limiter.go b/ethereum/ratelimit/rate_limiter.go index a74e7fbc5..e0e70d40d 100644 --- a/ethereum/ratelimit/rate_limiter.go +++ b/ethereum/ratelimit/rate_limiter.go @@ -7,10 +7,10 @@ import ( "sync" "time" - "github.com/0xProject/0x-mesh/meshdb" + "github.com/0xProject/0x-mesh/common/types" + "github.com/0xProject/0x-mesh/db" "github.com/benbjohnson/clock" log "github.com/sirupsen/logrus" - "github.com/syndtr/goleveldb/leveldb" "golang.org/x/time/rate" ) @@ -30,7 +30,7 @@ type rateLimiter struct { perSecondLimiter *rate.Limiter currentUTCCheckpoint time.Time // Start of current UTC 24hr period grantedInLast24hrsUTC int // Number of granted requests issued in last 24hr UTC - meshDB *meshdb.MeshDB + database *db.DB aClock clock.Clock wasStartedOnce bool // Whether the rate limiter has previously been started startMutex sync.Mutex // Mutex around the start check @@ -38,8 +38,8 @@ type rateLimiter struct { } // New instantiates a new RateLimiter -func New(maxRequestsPer24Hrs int, maxRequestsPerSecond float64, meshDB *meshdb.MeshDB, aClock clock.Clock) (RateLimiter, error) { - metadata, err := meshDB.GetMetadata() +func New(maxRequestsPer24Hrs int, maxRequestsPerSecond float64, database *db.DB, aClock clock.Clock) (RateLimiter, error) { + metadata, err := database.GetMetadata() if err != nil { return nil, err } @@ -53,7 +53,7 @@ func New(maxRequestsPer24Hrs int, maxRequestsPerSecond float64, meshDB *meshdb.M if currentUTCCheckpoint != storedUTCCheckpoint { storedUTCCheckpoint = currentUTCCheckpoint storedGrantedInLast24HrsUTC = 0 - if err := meshDB.UpdateMetadata(func(metadata meshdb.Metadata) meshdb.Metadata { + if err := database.UpdateMetadata(func(metadata *types.Metadata) *types.Metadata { metadata.StartOfCurrentUTCDay = storedUTCCheckpoint metadata.EthRPCRequestsSentInCurrentUTCDay = storedGrantedInLast24HrsUTC return metadata @@ -73,7 +73,7 @@ func New(maxRequestsPer24Hrs int, 
maxRequestsPerSecond float64, meshDB *meshdb.M aClock: aClock, maxRequestsPer24Hrs: maxRequestsPer24Hrs, perSecondLimiter: perSecondLimiter, - meshDB: meshDB, + database: database, currentUTCCheckpoint: storedUTCCheckpoint, grantedInLast24hrsUTC: storedGrantedInLast24HrsUTC, }, nil @@ -125,14 +125,14 @@ func (r *rateLimiter) Start(ctx context.Context, checkpointInterval time.Duratio case <-ticker.C: // Store grants issued and current UTC checkpoint to DB r.mu.Lock() - err := r.meshDB.UpdateMetadata(func(metadata meshdb.Metadata) meshdb.Metadata { + err := r.database.UpdateMetadata(func(metadata *types.Metadata) *types.Metadata { metadata.StartOfCurrentUTCDay = r.currentUTCCheckpoint metadata.EthRPCRequestsSentInCurrentUTCDay = r.grantedInLast24hrsUTC return metadata }) r.mu.Unlock() if err != nil { - if err == leveldb.ErrClosed { + if err == db.ErrClosed { // We can't continue if the database is closed. Stop the rateLimiter and // return an error. ticker.Stop() diff --git a/ethereum/ratelimit/rate_limiter_test.go b/ethereum/ratelimit/rate_limiter_test.go index 2ef431693..3509a0943 100644 --- a/ethereum/ratelimit/rate_limiter_test.go +++ b/ethereum/ratelimit/rate_limiter_test.go @@ -7,11 +7,10 @@ import ( "testing" "time" + "github.com/0xProject/0x-mesh/common/types" "github.com/0xProject/0x-mesh/constants" - "github.com/0xProject/0x-mesh/ethereum" - "github.com/0xProject/0x-mesh/meshdb" + "github.com/0xProject/0x-mesh/db" "github.com/benbjohnson/clock" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -27,15 +26,13 @@ const ( grantTimingTolerance = 50 * time.Millisecond ) -var contractAddresses = ethereum.GanacheAddresses - // Scenario1: If the 24 hour limit has *not* been hit, requests should be // granted based on the per second limiter. 
func TestScenario1(t *testing.T) { - meshDB, err := meshdb.New("/tmp/meshdb_testing/"+uuid.New().String(), contractAddresses) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) - defer meshDB.Close() - initMetadata(t, meshDB) + initMetadata(t, database) // Set up some constants for this test. const maxRequestsPer24Hrs = 100000 @@ -45,9 +42,8 @@ func TestScenario1(t *testing.T) { aClock := clock.NewMock() aClock.Set(GetUTCMidnightOfDate(time.Now())) - rateLimiter, err := New(maxRequestsPer24Hrs, maxRequestsPerSecond, meshDB, aClock) + rateLimiter, err := New(maxRequestsPer24Hrs, maxRequestsPerSecond, database, aClock) require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) wg := &sync.WaitGroup{} wg.Add(1) go func() { @@ -69,9 +65,9 @@ func TestScenario1(t *testing.T) { // Scenario 2: Max requests per 24 hours used up. Subsequent calls to Wait // should return an error. func TestScenario2(t *testing.T) { - meshDB, err := meshdb.New("/tmp/meshdb_testing/"+uuid.New().String(), contractAddresses) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) - defer meshDB.Close() now := time.Now() startOfCurrentUTCDay := GetUTCMidnightOfDate(now) @@ -79,13 +75,13 @@ func TestScenario2(t *testing.T) { requestsSentInCurrentDay := defaultMaxRequestsPer24Hrs - requestsRemainingInCurrentDay // Set metadata to just short of maximum requests per 24 hours. 
- metadata := &meshdb.Metadata{ + metadata := &types.Metadata{ EthereumChainID: 1337, MaxExpirationTime: constants.UnlimitedExpirationTime, StartOfCurrentUTCDay: startOfCurrentUTCDay, EthRPCRequestsSentInCurrentUTCDay: requestsSentInCurrentDay, } - err = meshDB.SaveMetadata(metadata) + err = database.SaveMetadata(metadata) require.NoError(t, err) // Start a new rate limiter and set time to a few hours past midnight. @@ -93,10 +89,9 @@ func TestScenario2(t *testing.T) { // what we're trying to test. aClock := clock.NewMock() aClock.Set(startOfCurrentUTCDay.Add(3 * time.Hour)) - rateLimiter, err := New(defaultMaxRequestsPer24Hrs, math.MaxFloat64, meshDB, aClock) + rateLimiter, err := New(defaultMaxRequestsPer24Hrs, math.MaxFloat64, database, aClock) require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) wg := &sync.WaitGroup{} wg.Add(1) go func() { @@ -126,9 +121,9 @@ func TestScenario2(t *testing.T) { // RateLimiter is instantiated. They then get updated after the checkpoint // interval elapses. 
func TestScenario3(t *testing.T) { - meshDB, err := meshdb.New("/tmp/meshdb_testing/"+uuid.New().String(), contractAddresses) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) - defer meshDB.Close() now := time.Now() yesterday := now.AddDate(0, 0, -1) @@ -136,18 +131,18 @@ func TestScenario3(t *testing.T) { // Set metadata to include an outdated `StartOfCurrentUTCDay` and an associated // non-zero `EthRPCRequestsSentInCurrentUTCDay` - metadata := &meshdb.Metadata{ + metadata := &types.Metadata{ EthereumChainID: 1337, MaxExpirationTime: constants.UnlimitedExpirationTime, StartOfCurrentUTCDay: yesterdayMidnightUTC, EthRPCRequestsSentInCurrentUTCDay: 5000, } - err = meshDB.SaveMetadata(metadata) + err = database.SaveMetadata(metadata) require.NoError(t, err) aClock := clock.NewMock() aClock.Set(now) - rateLimiter, err := New(defaultMaxRequestsPer24Hrs, defaultMaxRequestsPerSecond, meshDB, aClock) + rateLimiter, err := New(defaultMaxRequestsPer24Hrs, defaultMaxRequestsPerSecond, database, aClock) require.NoError(t, err) // Check that grant count and currentUTCCheckpoint were reset during instantiation @@ -156,7 +151,6 @@ func TestScenario3(t *testing.T) { assert.Equal(t, expectedCurrentUTCCheckpoint, rateLimiter.getCurrentUTCCheckpoint()) // Start the rateLimiter - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) wg := &sync.WaitGroup{} wg.Add(1) go func() { @@ -178,7 +172,7 @@ func TestScenario3(t *testing.T) { time.Sleep(50 * time.Millisecond) // Check metadata was stored in DB - metadata, err = meshDB.GetMetadata() + metadata, err = database.GetMetadata() require.NoError(t, err) assert.Equal(t, expectedCurrentUTCCheckpoint, metadata.StartOfCurrentUTCDay) @@ -188,12 +182,12 @@ func TestScenario3(t *testing.T) { wg.Wait() } -func initMetadata(t *testing.T, meshDB *meshdb.MeshDB) { - metadata := &meshdb.Metadata{ +func initMetadata(t *testing.T, 
database *db.DB) { + metadata := &types.Metadata{ EthereumChainID: 1337, MaxExpirationTime: constants.UnlimitedExpirationTime, } - err := meshDB.SaveMetadata(metadata) + err := database.SaveMetadata(metadata) require.NoError(t, err) } diff --git a/ethereum/simplestack/simple_stack.go b/ethereum/simplestack/simple_stack.go deleted file mode 100644 index 6adf35bb0..000000000 --- a/ethereum/simplestack/simple_stack.go +++ /dev/null @@ -1,177 +0,0 @@ -package simplestack - -import ( - "fmt" - "sync" - - "github.com/0xProject/0x-mesh/ethereum/miniheader" -) - -// UpdateType is the type of update applied to the in-memory stack -type UpdateType int - -// UpdateType values -const ( - Pop UpdateType = iota - Push -) - -// Update represents one update to the stack, either a pop or push of a miniHeader. -type Update struct { - Type UpdateType - MiniHeader *miniheader.MiniHeader -} - -// SimpleStack is a simple in-memory stack used in tests -type SimpleStack struct { - limit int - miniHeaders []*miniheader.MiniHeader - updates []*Update - mu sync.RWMutex - latestCheckpointID int -} - -// New instantiates a new SimpleStack -func New(retentionLimit int, miniHeaders []*miniheader.MiniHeader) *SimpleStack { - return &SimpleStack{ - limit: retentionLimit, - miniHeaders: miniHeaders, - updates: []*Update{}, - } -} - -// Peek returns the top of the stack -func (s *SimpleStack) Peek() (*miniheader.MiniHeader, error) { - s.mu.RLock() - defer s.mu.RUnlock() - - if len(s.miniHeaders) == 0 { - return nil, nil - } - return s.miniHeaders[len(s.miniHeaders)-1], nil -} - -// Pop returns the top of the stack and removes it from the stack -func (s *SimpleStack) Pop() (*miniheader.MiniHeader, error) { - s.mu.Lock() - defer s.mu.Unlock() - - return s.pop() -} - -// you MUST acquire a lock on the mutex `mu` before calling `pop()` -func (s *SimpleStack) pop() (*miniheader.MiniHeader, error) { - if len(s.miniHeaders) == 0 { - return nil, nil - } - top := s.miniHeaders[len(s.miniHeaders)-1] - 
s.miniHeaders = s.miniHeaders[:len(s.miniHeaders)-1] - s.updates = append(s.updates, &Update{ - Type: Pop, - MiniHeader: top, - }) - return top, nil -} - -// Push adds a miniheader.MiniHeader to the stack -func (s *SimpleStack) Push(miniHeader *miniheader.MiniHeader) error { - s.mu.Lock() - defer s.mu.Unlock() - - return s.push(miniHeader) -} - -// you MUST acquire a lock on the mutex `mu` before calling `push()` -func (s *SimpleStack) push(miniHeader *miniheader.MiniHeader) error { - for _, h := range s.miniHeaders { - if h.Number.Int64() == miniHeader.Number.Int64() { - return fmt.Errorf("attempted to push multiple blocks with block number %d to the stack", miniHeader.Number.Int64()) - } - } - - if len(s.miniHeaders) == s.limit { - s.miniHeaders = s.miniHeaders[1:] - } - s.miniHeaders = append(s.miniHeaders, miniHeader) - s.updates = append(s.updates, &Update{ - Type: Push, - MiniHeader: miniHeader, - }) - return nil -} - -// PeekAll returns all the miniHeaders currently in the stack -func (s *SimpleStack) PeekAll() ([]*miniheader.MiniHeader, error) { - s.mu.RLock() - defer s.mu.RUnlock() - - // Return copy of miniHeaders array - m := make([]*miniheader.MiniHeader, len(s.miniHeaders)) - copy(m, s.miniHeaders) - - return m, nil -} - -// Clear removes all items from the stack and clears any set checkpoint -func (s *SimpleStack) Clear() error { - s.mu.Lock() - defer s.mu.Unlock() - s.miniHeaders = []*miniheader.MiniHeader{} - s.updates = []*Update{} - s.latestCheckpointID = 0 - return nil -} - -// Checkpoint checkpoints the changes to the stack such that a subsequent -// call to `Reset(checkpointID)` with the checkpointID returned from this -// call will reset any subsequent changes back to the state of the stack -// at the time of this checkpoint -func (s *SimpleStack) Checkpoint() (int, error) { - s.mu.Lock() - defer s.mu.Unlock() - - s.updates = []*Update{} - s.latestCheckpointID++ - return s.latestCheckpointID, nil -} - -// Reset resets the in-memory stack with 
the contents from the latest checkpoint -func (s *SimpleStack) Reset(checkpointID int) error { - s.mu.Lock() - defer s.mu.Unlock() - - if s.latestCheckpointID == 0 { - return fmt.Errorf("Checkpoint() must be called before Reset() since without it the checkpointID is unspecified") - } else if checkpointID != s.latestCheckpointID { - return fmt.Errorf("Attempted to reset the stack to checkpoint %d but the latest checkpoint has ID %d", checkpointID, s.latestCheckpointID) - } - - for i := len(s.updates) - 1; i >= 0; i-- { - u := s.updates[i] - switch u.Type { - case Pop: - if err := s.push(u.MiniHeader); err != nil { - return err - } - case Push: - if _, err := s.pop(); err != nil { - return err - } - default: - return fmt.Errorf("Unrecognized update type encountered: %d", u.Type) - } - } - s.updates = []*Update{} - return nil -} - -// GetUpdates returns the updates applied since the last checkpoint -func (s *SimpleStack) GetUpdates() []*Update { - s.mu.RLock() - defer s.mu.RUnlock() - - // Return copy of updates array - u := make([]*Update, len(s.updates)) - copy(u, s.updates) - return u -} diff --git a/ethereum/simplestack/simple_stack_test.go b/ethereum/simplestack/simple_stack_test.go deleted file mode 100644 index 4505ca752..000000000 --- a/ethereum/simplestack/simple_stack_test.go +++ /dev/null @@ -1,225 +0,0 @@ -package simplestack - -import ( - "math/big" - "testing" - "time" - - "github.com/0xProject/0x-mesh/ethereum/miniheader" - "github.com/ethereum/go-ethereum/common" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -const limit = 10 - -var ( - miniHeaderOne = &miniheader.MiniHeader{ - Number: big.NewInt(1), - Hash: common.HexToHash("0x1"), - Parent: common.HexToHash("0x0"), - Timestamp: time.Now().UTC(), - } - miniHeaderTwo = &miniheader.MiniHeader{ - Number: big.NewInt(2), - Hash: common.HexToHash("0x2"), - Parent: common.HexToHash("0x1"), - Timestamp: time.Now().UTC(), - } -) - -func TestSimpleStackPushPeekPop(t 
*testing.T) { - stack := New(10, []*miniheader.MiniHeader{}) - err := stack.Push(miniHeaderOne) - require.NoError(t, err) - - expectedLen := 1 - miniHeaders, err := stack.PeekAll() - require.NoError(t, err) - assert.Len(t, miniHeaders, expectedLen) - - miniHeader, err := stack.Peek() - require.NoError(t, err) - assert.Equal(t, miniHeaders[0], miniHeader) - - expectedLen = 1 - miniHeaders, err = stack.PeekAll() - require.NoError(t, err) - assert.Len(t, miniHeaders, expectedLen) - - miniHeader, err = stack.Pop() - require.NoError(t, err) - assert.Equal(t, miniHeaders[0], miniHeader) - - expectedLen = 0 - miniHeaders, err = stack.PeekAll() - require.NoError(t, err) - assert.Len(t, miniHeaders, expectedLen) -} - -func TestSimpleStackErrorIfPushTwoHeadersWithSameNumber(t *testing.T) { - stack := New(10, []*miniheader.MiniHeader{}) - // Push miniHeaderOne - err := stack.Push(miniHeaderOne) - require.NoError(t, err) - // Push miniHeaderOne again - err = stack.Push(miniHeaderOne) - assert.Error(t, err) -} - -func TestSimpleStackErrorIfResetWithoutCheckpointFirst(t *testing.T) { - stack := New(10, []*miniheader.MiniHeader{}) - - checkpointID := 123 - err := stack.Reset(checkpointID) - require.Error(t, err) -} - -func TestSimpleStackClear(t *testing.T) { - stack := New(10, []*miniheader.MiniHeader{}) - - err := stack.Push(miniHeaderOne) - require.NoError(t, err) - - miniHeader, err := stack.Peek() - require.NoError(t, err) - assert.Equal(t, miniHeaderOne, miniHeader) - - err = stack.Clear() - require.NoError(t, err) - - miniHeader, err = stack.Peek() - require.NoError(t, err) - require.Nil(t, miniHeader) -} - -func TestSimpleStackErrorIfResetWithOldCheckpoint(t *testing.T) { - stack := New(10, []*miniheader.MiniHeader{}) - - checkpointIDOne, err := stack.Checkpoint() - require.NoError(t, err) - - checkpointIDTwo, err := stack.Checkpoint() - require.NoError(t, err) - - err = stack.Reset(checkpointIDOne) - require.Error(t, err) - - err = stack.Reset(checkpointIDTwo) - 
require.NoError(t, err) -} - -func TestSimpleStackCheckpoint(t *testing.T) { - stack := New(10, []*miniheader.MiniHeader{}) - err := stack.Push(miniHeaderOne) - require.NoError(t, err) - err = stack.Push(miniHeaderTwo) - require.NoError(t, err) - - assert.Len(t, stack.updates, 2) - - _, err = stack.Checkpoint() - require.NoError(t, err) - - assert.Len(t, stack.updates, 0) - - miniHeader, err := stack.Pop() - require.NoError(t, err) - assert.Equal(t, miniHeaderTwo, miniHeader) - - miniHeader, err = stack.Pop() - require.NoError(t, err) - assert.Equal(t, miniHeaderOne, miniHeader) - - assert.Len(t, stack.updates, 2) - - _, err = stack.Checkpoint() - require.NoError(t, err) - - assert.Len(t, stack.updates, 0) -} - -func TestSimpleStackCheckpointAfterSameHeaderPushedAndPopped(t *testing.T) { - stack := New(10, []*miniheader.MiniHeader{}) - // Push miniHeaderOne - err := stack.Push(miniHeaderOne) - require.NoError(t, err) - // Pop miniHeaderOne - miniHeader, err := stack.Pop() - require.NoError(t, err) - assert.Equal(t, miniHeaderOne, miniHeader) - - assert.Len(t, stack.miniHeaders, 0) - assert.Len(t, stack.updates, 2) - - _, err = stack.Checkpoint() - require.NoError(t, err) - - assert.Len(t, stack.updates, 0) -} - -func TestSimpleStackCheckpointAfterSameHeaderPushedThenPoppedThenPushed(t *testing.T) { - stack := New(10, []*miniheader.MiniHeader{}) - // Push miniHeaderOne - err := stack.Push(miniHeaderOne) - require.NoError(t, err) - // Pop miniHeaderOne - miniHeader, err := stack.Pop() - require.NoError(t, err) - assert.Equal(t, miniHeaderOne, miniHeader) - // Push miniHeaderOne again - err = stack.Push(miniHeaderOne) - require.NoError(t, err) - - assert.Len(t, stack.miniHeaders, 1) - assert.Len(t, stack.updates, 3) - - _, err = stack.Checkpoint() - require.NoError(t, err) - - assert.Len(t, stack.updates, 0) -} - -func TestSimpleStackCheckpointThenReset(t *testing.T) { - stack := New(10, []*miniheader.MiniHeader{}) - - checkpointID, err := stack.Checkpoint() - 
require.NoError(t, err) - - err = stack.Push(miniHeaderOne) - require.NoError(t, err) - - assert.Len(t, stack.miniHeaders, 1) - assert.Len(t, stack.updates, 1) - - err = stack.Reset(checkpointID) - require.NoError(t, err) - - assert.Len(t, stack.miniHeaders, 0) - assert.Len(t, stack.updates, 0) - - err = stack.Push(miniHeaderTwo) - require.NoError(t, err) - - assert.Len(t, stack.miniHeaders, 1) - assert.Len(t, stack.updates, 1) - - checkpointID, err = stack.Checkpoint() - require.NoError(t, err) - - assert.Len(t, stack.miniHeaders, 1) - assert.Len(t, stack.updates, 0) - - miniHeader, err := stack.Pop() - require.NoError(t, err) - assert.Equal(t, miniHeader, miniHeaderTwo) - - assert.Len(t, stack.miniHeaders, 0) - assert.Len(t, stack.updates, 1) - - checkpointID, err = stack.Checkpoint() - require.NoError(t, err) - - assert.Len(t, stack.miniHeaders, 0) - assert.Len(t, stack.updates, 0) -} diff --git a/expirationwatch/expiration_watcher.go b/expirationwatch/expiration_watcher.go deleted file mode 100644 index f5915b9ac..000000000 --- a/expirationwatch/expiration_watcher.go +++ /dev/null @@ -1,103 +0,0 @@ -package expirationwatch - -import ( - "sync" - "time" - - "github.com/albrow/stringset" - "github.com/ocdogan/rbt" - log "github.com/sirupsen/logrus" -) - -// ExpiredItem represents an expired item returned from the Watcher -type ExpiredItem struct { - ExpirationTimestamp time.Time - ID string -} - -// Watcher watches the expiration of items -type Watcher struct { - expiredItems chan []ExpiredItem - rbTreeMu sync.RWMutex - rbTree *rbt.RbTree -} - -// New instantiates a new expiration watcher -func New() *Watcher { - rbTree := rbt.NewRbTree() - return &Watcher{ - expiredItems: make(chan []ExpiredItem, 10), - rbTree: rbTree, - } -} - -// Add adds a new item identified by an ID to the expiration watcher -func (w *Watcher) Add(expirationTimestamp time.Time, id string) { - key := rbt.Int64Key(expirationTimestamp.Unix()) - w.rbTreeMu.Lock() - defer w.rbTreeMu.Unlock() - 
value, ok := w.rbTree.Get(&key) - var ids stringset.Set - if !ok { - ids = stringset.New() - } else { - ids = value.(stringset.Set) - } - ids.Add(id) - w.rbTree.Insert(&key, ids) -} - -// Remove removes the item with a specified id from the expiration watcher -func (w *Watcher) Remove(expirationTimestamp time.Time, id string) { - key := rbt.Int64Key(expirationTimestamp.Unix()) - w.rbTreeMu.Lock() - defer w.rbTreeMu.Unlock() - value, ok := w.rbTree.Get(&key) - if !ok { - // Due to the asynchronous nature of the Watcher and OrderWatcher, there are - // race-conditions where we try to remove an item from the Watcher after it - // has already been removed. - log.WithFields(log.Fields{ - "id": id, - }).Trace("Attempted to remove item from Watcher that no longer exists") - return // Noop - } else { - ids := value.(stringset.Set) - ids.Remove(id) - if len(ids) == 0 { - w.rbTree.Delete(&key) - } else { - w.rbTree.Insert(&key, ids) - } - } -} - -// Prune checks for any expired items given a timestamp and removes any expired -// items from the expiration watcher and returns them to the caller -func (w *Watcher) Prune(timestamp time.Time) []ExpiredItem { - pruned := []ExpiredItem{} - for { - w.rbTreeMu.RLock() - key, value := w.rbTree.Min() - w.rbTreeMu.RUnlock() - if key == nil { - break - } - expirationTimeSeconds := int64(*key.(*rbt.Int64Key)) - expirationTime := time.Unix(expirationTimeSeconds, 0) - if timestamp.Before(expirationTime) { - break - } - ids := value.(stringset.Set) - for id := range ids { - pruned = append(pruned, ExpiredItem{ - ExpirationTimestamp: expirationTime, - ID: id, - }) - } - w.rbTreeMu.Lock() - w.rbTree.Delete(key) - w.rbTreeMu.Unlock() - } - return pruned -} diff --git a/expirationwatch/expiration_watcher_test.go b/expirationwatch/expiration_watcher_test.go deleted file mode 100644 index 7a7e55275..000000000 --- a/expirationwatch/expiration_watcher_test.go +++ /dev/null @@ -1,137 +0,0 @@ -package expirationwatch - -import ( - "testing" - "time" - 
- "github.com/stretchr/testify/assert" -) - -func TestPrunesExpiredItems(t *testing.T) { - watcher := New() - - current := time.Now().Truncate(time.Second) - expiryEntryOne := ExpiredItem{ - ExpirationTimestamp: current.Add(-3 * time.Second), - ID: "0x8e209dda7e515025d0c34aa61a0d1156a631248a4318576a2ce0fb408d97385e", - } - watcher.Add(expiryEntryOne.ExpirationTimestamp, expiryEntryOne.ID) - - expiryEntryTwo := ExpiredItem{ - ExpirationTimestamp: current.Add(-1 * time.Second), - ID: "0x12ab7edd34515025d0c34aa61a0d1156a631248a4318576a2ce0fb408d3bee521", - } - watcher.Add(expiryEntryTwo.ExpirationTimestamp, expiryEntryTwo.ID) - - pruned := watcher.Prune(current) - assert.Len(t, pruned, 2, "two expired items should get pruned") - assert.Equal(t, expiryEntryOne, pruned[0]) - assert.Equal(t, expiryEntryTwo, pruned[1]) -} - -func TestPrunesTwoExpiredItemsWithSameExpiration(t *testing.T) { - watcher := New() - - current := time.Now().Truncate(time.Second) - expiration := current.Add(-3 * time.Second) - expiryEntryOne := ExpiredItem{ - ExpirationTimestamp: expiration, - ID: "0x8e209dda7e515025d0c34aa61a0d1156a631248a4318576a2ce0fb408d97385e", - } - watcher.Add(expiryEntryOne.ExpirationTimestamp, expiryEntryOne.ID) - - expiryEntryTwo := ExpiredItem{ - ExpirationTimestamp: expiration, - ID: "0x12ab7edd34515025d0c34aa61a0d1156a631248a4318576a2ce0fb408d3bee521", - } - watcher.Add(expiryEntryTwo.ExpirationTimestamp, expiryEntryTwo.ID) - - pruned := watcher.Prune(current) - assert.Len(t, pruned, 2, "two expired items should get pruned") - hashes := map[string]bool{ - expiryEntryOne.ID: true, - expiryEntryTwo.ID: true, - } - for _, expiredItem := range pruned { - assert.True(t, hashes[expiredItem.ID]) - } -} - -func TestPrunesBarelyExpiredItem(t *testing.T) { - watcher := New() - - current := time.Now().Truncate(time.Second) - expiryEntryOne := ExpiredItem{ - ExpirationTimestamp: current, - ID: "0x8e209dda7e515025d0c34aa61a0d1156a631248a4318576a2ce0fb408d97385e", - } - 
watcher.Add(expiryEntryOne.ExpirationTimestamp, expiryEntryOne.ID) - - pruned := watcher.Prune(current) - assert.Len(t, pruned, 1, "one expired item should get pruned") - assert.Equal(t, expiryEntryOne, pruned[0]) -} - -func TestKeepsUnexpiredItem(t *testing.T) { - watcher := New() - - id := "0x8e209dda7e515025d0c34aa61a0d1156a631248a4318576a2ce0fb408d97385e" - current := time.Now().Truncate(time.Second) - watcher.Add(current.Add(10*time.Second), id) - - pruned := watcher.Prune(current) - assert.Equal(t, 0, len(pruned), "Doesn't prune unexpired item") -} - -func TestReturnsEmptyIfNoItems(t *testing.T) { - watcher := New() - - pruned := watcher.Prune(time.Now()) - assert.Len(t, pruned, 0, "Returns empty array when no items tracked") -} - -func TestRemoveOnlyItemWithSpecificExpirationTime(t *testing.T) { - watcher := New() - - current := time.Now().Truncate(time.Second) - expiryEntryOne := ExpiredItem{ - ExpirationTimestamp: current.Add(-3 * time.Second), - ID: "0x8e209dda7e515025d0c34aa61a0d1156a631248a4318576a2ce0fb408d97385e", - } - watcher.Add(expiryEntryOne.ExpirationTimestamp, expiryEntryOne.ID) - - expiryEntryTwo := ExpiredItem{ - ExpirationTimestamp: current.Add(-1 * time.Second), - ID: "0x12ab7edd34515025d0c34aa61a0d1156a631248a4318576a2ce0fb408d3bee521", - } - watcher.Add(expiryEntryTwo.ExpirationTimestamp, expiryEntryTwo.ID) - - watcher.Remove(expiryEntryTwo.ExpirationTimestamp, expiryEntryTwo.ID) - - pruned := watcher.Prune(current) - assert.Len(t, pruned, 1, "two expired items should get pruned") - assert.Equal(t, expiryEntryOne, pruned[0]) -} -func TestRemoveItemWhichSharesExpirationTimeWithOtherItems(t *testing.T) { - watcher := New() - - current := time.Now().Truncate(time.Second) - singleExpirationTimestamp := current.Add(-3 * time.Second) - expiryEntryOne := ExpiredItem{ - ExpirationTimestamp: singleExpirationTimestamp, - ID: "0x8e209dda7e515025d0c34aa61a0d1156a631248a4318576a2ce0fb408d97385e", - } - watcher.Add(expiryEntryOne.ExpirationTimestamp, 
expiryEntryOne.ID) - - expiryEntryTwo := ExpiredItem{ - ExpirationTimestamp: singleExpirationTimestamp, - ID: "0x12ab7edd34515025d0c34aa61a0d1156a631248a4318576a2ce0fb408d3bee521", - } - watcher.Add(expiryEntryTwo.ExpirationTimestamp, expiryEntryTwo.ID) - - watcher.Remove(expiryEntryTwo.ExpirationTimestamp, expiryEntryTwo.ID) - - pruned := watcher.Prune(current) - assert.Len(t, pruned, 1, "two expired items should get pruned") - assert.Equal(t, expiryEntryOne, pruned[0]) -} diff --git a/go.mod b/go.mod index dbad54bb9..d4cdd70ad 100644 --- a/go.mod +++ b/go.mod @@ -34,15 +34,16 @@ require ( github.com/google/uuid v1.1.1 github.com/hashicorp/go-multierror v1.1.0 // indirect github.com/hashicorp/golang-lru v0.5.4 + github.com/ido50/sqlz v0.0.0-20200308174337-487b8faf612c github.com/ipfs/go-datastore v0.3.1 github.com/ipfs/go-ds-leveldb v0.4.0 + github.com/jmoiron/sqlx v1.2.0 github.com/jpillora/backoff v0.0.0-20170918002102-8eab2debe79d github.com/karalabe/usb v0.0.0-20191104083709-911d15fe12a9 // indirect github.com/karlseguin/ccache v2.0.3+incompatible github.com/karlseguin/expect v1.0.1 // indirect github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect github.com/lib/pq v1.2.0 - github.com/libp2p/go-conn-security v0.1.0 // indirect github.com/libp2p/go-libp2p v0.5.1 github.com/libp2p/go-libp2p-autonat-svc v0.1.0 github.com/libp2p/go-libp2p-circuit v0.1.4 @@ -52,13 +53,13 @@ require ( github.com/libp2p/go-libp2p-kad-dht v0.5.0 github.com/libp2p/go-libp2p-peer v0.2.0 github.com/libp2p/go-libp2p-peerstore v0.1.4 - github.com/libp2p/go-libp2p-protocol v0.1.0 // indirect github.com/libp2p/go-libp2p-pubsub v0.2.5 github.com/libp2p/go-libp2p-swarm v0.2.2 github.com/libp2p/go-maddr-filter v0.0.5 github.com/libp2p/go-tcp-transport v0.1.1 github.com/libp2p/go-ws-transport v0.2.0 github.com/mattn/go-colorable v0.1.2 // indirect + github.com/mattn/go-sqlite3 v2.0.3+incompatible github.com/multiformats/go-multiaddr v0.2.1 github.com/multiformats/go-multiaddr-dns 
v0.2.0 github.com/ocdogan/rbt v0.0.0-20160425054511-de6e2b48be33 @@ -79,5 +80,6 @@ require ( github.com/xeipuuv/gojsonschema v1.1.0 golang.org/x/crypto v0.0.0-20200311171314-f7b00557c8c4 golang.org/x/time v0.0.0-20191024005414-555d28b269f0 + gopkg.in/DATA-DOG/go-sqlmock.v1 v1.3.0 // indirect gopkg.in/karlseguin/expect.v1 v1.0.1 // indirect ) diff --git a/go.sum b/go.sum index 55e84ef81..40fce9c2e 100644 --- a/go.sum +++ b/go.sum @@ -20,10 +20,6 @@ github.com/0xProject/go-ethereum v1.8.8-0.20200603225022-cb1f52043425 h1:BFs4B5V github.com/0xProject/go-ethereum v1.8.8-0.20200603225022-cb1f52043425/go.mod h1:oP8FC5+TbICUyftkTWs+8JryntjIJLJvWvApK3z2AYw= github.com/0xProject/go-libp2p-pubsub v0.1.1-0.20200228234556-aaa0317e068a h1:OHjKy7tLiqETUbEzF2UmqaF8eUTjHqmJM2sP79dguJs= github.com/0xProject/go-libp2p-pubsub v0.1.1-0.20200228234556-aaa0317e068a/go.mod h1:R4R0kH/6p2vu8O9xsue0HNSjEuXMEPBgg4h3nVDI15o= -github.com/0xProject/go-ws-transport v0.1.1-0.20200123233232-0b38359294da h1:8POpSF5LiutCqYqgG+vP4OcUFj3nnyOSddcSjUEbGKA= -github.com/0xProject/go-ws-transport v0.1.1-0.20200123233232-0b38359294da/go.mod h1:9BHJz/4Q5A9ludYWKoGCFC5gUElzlHoKzu0yY9p/klM= -github.com/0xProject/go-ws-transport v0.1.1-0.20200131210609-7f37eee84b58 h1:p9qXd3Krt69MEC2YqNiNjuP+Hxe7cTuABx59GPLCc5s= -github.com/0xProject/go-ws-transport v0.1.1-0.20200131210609-7f37eee84b58/go.mod h1:9BHJz/4Q5A9ludYWKoGCFC5gUElzlHoKzu0yY9p/klM= github.com/0xProject/go-ws-transport v0.1.1-0.20200201000210-2db3396fec39 h1:zMth0Fw7e4MWjaNoN+lKzwdvqeNI2Mj12Zk63AMC3vI= github.com/0xProject/go-ws-transport v0.1.1-0.20200201000210-2db3396fec39/go.mod h1:9BHJz/4Q5A9ludYWKoGCFC5gUElzlHoKzu0yY9p/klM= github.com/0xProject/go-ws-transport v0.1.1-0.20200530011125-b4ab00766967 h1:D7HZfoMYUXCYTVflLSqAXvCTotwR3cQn8s9peVA3/5M= @@ -169,6 +165,8 @@ github.com/go-logfmt/logfmt v0.3.0 h1:8HUsc87TaSWLKwrnumgC8/YconD2fJQsRJAsWaPg2i github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= 
github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8= github.com/go-sourcemap/sourcemap v2.1.2+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= +github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk= +github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= @@ -219,12 +217,15 @@ github.com/hashicorp/golang-lru v0.5.3 h1:YPkqC67at8FYaadspW/6uE0COsBxS2656RLEr8 github.com/hashicorp/golang-lru v0.5.3/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huin/goupnp v1.0.0 h1:wg75sLpL6DZqwHQN6E1Cfk6mtfzS45z8OV+ic+DtHRo= github.com/huin/goupnp v1.0.0/go.mod h1:n9v9KO1tAxYH82qOn+UTIFQDmx5n1Zxd/ClZDMX7Bnc= github.com/huin/goutil v0.0.0-20170803182201-1ca381bf3150/go.mod h1:PpLOETDnJ0o3iZrZfqZzyLl6l7F3c6L1oWn7OICBi6o= +github.com/ido50/sqlz v0.0.0-20200308174337-487b8faf612c h1:29iV3Zn1Q5D6rviM3+Z/GN1ZKKzdcrTV6KbtZUAo1/c= +github.com/ido50/sqlz v0.0.0-20200308174337-487b8faf612c/go.mod h1:Fps9X8N3LiLLQNU9VT8fWOhEu8277Q3hBAVNtFfwswY= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/influxdata/influxdb 
v1.2.3-0.20180221223340-01288bdb0883/go.mod h1:qZna6X/4elxqT3yI9iZYdZrWWdeFOOprn86kgg4+IzY= github.com/ipfs/go-cid v0.0.1/go.mod h1:GHWU/WuQdMPmIosc4Yn1bcCT7dSeX4lBafM7iqUPQvM= @@ -275,6 +276,8 @@ github.com/jbenet/goprocess v0.1.3/go.mod h1:5yspPrukOVuOLORacaBi858NqyClJPQxYZl github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= +github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA= +github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= github.com/jpillora/backoff v0.0.0-20170918002102-8eab2debe79d h1:ix3WmphUvN0GDd0DO9MH0v6/5xTv+Xm1bPN+1UJn58k= github.com/jpillora/backoff v0.0.0-20170918002102-8eab2debe79d/go.mod h1:2iMrUgbbvHEiQClaW2NsSzMyGHqN+rDFqY705q49KG0= github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= @@ -310,6 +313,7 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/libp2p/go-addr-util v0.0.1 h1:TpTQm9cXVRVSKsYbgQ7GKc3KbbHVTnbostgGaDEP+88= @@ -317,8 +321,6 @@ github.com/libp2p/go-addr-util v0.0.1/go.mod h1:4ac6O7n9rIAKB1dnd+s8IbbMXkt+oBpz github.com/libp2p/go-buffer-pool v0.0.1/go.mod h1:xtyIz9PMobb13WaxR6Zo1Pd1zXJKYg0a8KiIvDp3TzQ= github.com/libp2p/go-buffer-pool v0.0.2 
h1:QNK2iAFa8gjAe1SPz6mHSMuCcjs+X1wlHzeOSqcmlfs= github.com/libp2p/go-buffer-pool v0.0.2/go.mod h1:MvaB6xw5vOrDl8rYZGLFdKAuk/hRoRZd1Vi32+RXyFM= -github.com/libp2p/go-conn-security v0.1.0 h1:q8ii9TUOtSBD1gIoKTSOZIzPFP/agPM28amrCCoeIIA= -github.com/libp2p/go-conn-security v0.1.0/go.mod h1:NQdPF4opCZ5twtEUadzPL0tNSdkrbFc/HmLO7eWqEzY= github.com/libp2p/go-conn-security-multistream v0.1.0 h1:aqGmto+ttL/uJgX0JtQI0tD21CIEy5eYd1Hlp0juHY0= github.com/libp2p/go-conn-security-multistream v0.1.0/go.mod h1:aw6eD7LOsHEX7+2hJkDxw1MteijaVcI+/eP2/x3J1xc= github.com/libp2p/go-eventbus v0.1.0 h1:mlawomSAjjkk97QnYiEmHsLu7E136+2oCWSHRUvMfzQ= @@ -329,7 +331,6 @@ github.com/libp2p/go-libp2p v0.1.0/go.mod h1:6D/2OBauqLUoqcADOJpn9WbKqvaM07tDw68 github.com/libp2p/go-libp2p v0.5.0/go.mod h1:Os7a5Z3B+ErF4v7zgIJ7nBHNu2LYt8ZMLkTQUB3G/wA= github.com/libp2p/go-libp2p v0.5.1 h1:kZ9jg+2B9IIptRcltBHKBrQdhXNNSrjCoztvrMx7tqI= github.com/libp2p/go-libp2p v0.5.1/go.mod h1:Os7a5Z3B+ErF4v7zgIJ7nBHNu2LYt8ZMLkTQUB3G/wA= -github.com/libp2p/go-libp2p v6.0.23+incompatible h1:J/h9LNTeQwMhJeg3M96r/UOPLGxJn1vqJBb3LeKufpM= github.com/libp2p/go-libp2p-autonat v0.1.0 h1:aCWAu43Ri4nU0ZPO7NyLzUvvfqd0nE3dX0R/ZGYVgOU= github.com/libp2p/go-libp2p-autonat v0.1.0/go.mod h1:1tLf2yXxiE/oKGtDwPYWTSYG3PtvYlJmg7NeVtPRqH8= github.com/libp2p/go-libp2p-autonat v0.1.1 h1:WLBZcIRsjZlWdAZj9CiBSvU2wQXoUOiS1Zk1tM7DTJI= @@ -381,10 +382,6 @@ github.com/libp2p/go-libp2p-peerstore v0.1.3 h1:wMgajt1uM2tMiqf4M+4qWKVyyFc8SfA+ github.com/libp2p/go-libp2p-peerstore v0.1.3/go.mod h1:BJ9sHlm59/80oSkpWgr1MyY1ciXAXV397W6h1GH/uKI= github.com/libp2p/go-libp2p-peerstore v0.1.4 h1:d23fvq5oYMJ/lkkbO4oTwBp/JP+I/1m5gZJobNXCE/k= github.com/libp2p/go-libp2p-peerstore v0.1.4/go.mod h1:+4BDbDiiKf4PzpANZDAT+knVdLxvqh7hXOujessqdzs= -github.com/libp2p/go-libp2p-protocol v0.1.0 h1:HdqhEyhg0ToCaxgMhnOmUO8snQtt/kQlcjVk3UoJU3c= -github.com/libp2p/go-libp2p-protocol v0.1.0/go.mod h1:KQPHpAabB57XQxGrXCNvbL6UEXfQqUgC/1adR2Xtflk= -github.com/libp2p/go-libp2p-pubsub 
v0.2.5 h1:tPKbkjAUI0xLGN3KKTKKy9TQEviVfrP++zJgH5Muke4= -github.com/libp2p/go-libp2p-pubsub v0.2.5/go.mod h1:9Q2RRq8ofXkoewORcyVlgUFDKLKw7BuYSlJVWRcVk3Y= github.com/libp2p/go-libp2p-record v0.1.2 h1:M50VKzWnmUrk/M5/Dz99qO9Xh4vs8ijsK+7HkJvRP+0= github.com/libp2p/go-libp2p-record v0.1.2/go.mod h1:pal0eNcT5nqZaTV7UGhqeGqxFgGdsU/9W//C8dqjQDk= github.com/libp2p/go-libp2p-routing v0.1.0 h1:hFnj3WR3E2tOcKaGpyzfP4gvFZ3t8JkQmbapN0Ct+oU= @@ -438,8 +435,6 @@ github.com/libp2p/go-stream-muxer-multistream v0.2.0/go.mod h1:j9eyPol/LLRqT+GPL github.com/libp2p/go-tcp-transport v0.1.0/go.mod h1:oJ8I5VXryj493DEJ7OsBieu8fcg2nHGctwtInJVpipc= github.com/libp2p/go-tcp-transport v0.1.1 h1:yGlqURmqgNA2fvzjSgZNlHcsd/IulAnKM8Ncu+vlqnw= github.com/libp2p/go-tcp-transport v0.1.1/go.mod h1:3HzGvLbx6etZjnFlERyakbaYPdfjg2pWP97dFZworkY= -github.com/libp2p/go-ws-transport v0.0.0-20191008032742-3098bba549e8 h1:F1fQYoej9mjMSkYq3fcdlZK8xGzn4Bhp3cxSgqswp6M= -github.com/libp2p/go-ws-transport v0.0.0-20191008032742-3098bba549e8/go.mod h1:040XOA+VSh/dAe8ZsMIjP4EZpI8yMRDFVLyADY+snxs= github.com/libp2p/go-yamux v1.2.2/go.mod h1:FGTiPvoV/3DVdgWpX+tM0OW3tsM+W5bSE3gZwqQTcow= github.com/libp2p/go-yamux v1.2.3 h1:xX8A36vpXb59frIzWFdEgptLMsOANMFq2K7fPRlunYI= github.com/libp2p/go-yamux v1.2.3/go.mod h1:FGTiPvoV/3DVdgWpX+tM0OW3tsM+W5bSE3gZwqQTcow= @@ -463,6 +458,9 @@ github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzp github.com/mattn/go-runewidth v0.0.4 h1:2BvfKmzob6Bmd4YsL0zygOqfdFnK7GR4QL06Do4/p7Y= github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= +github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= +github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= 
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/miekg/dns v1.1.12/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= @@ -665,7 +663,6 @@ golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnf golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190225124518-7f87c0fbb88b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -737,6 +734,7 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 h1:/atklqdjdhuosWIl6AIbOeHJjicWYPqR9bpxqxYG2pA= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= @@ 
-744,6 +742,8 @@ google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRn google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= +gopkg.in/DATA-DOG/go-sqlmock.v1 v1.3.0 h1:FVCohIoYO7IJoDDVpV2pdq7SgrMH6wHnuTyrdrxJNoY= +gopkg.in/DATA-DOG/go-sqlmock.v1 v1.3.0/go.mod h1:OdE7CF6DbADk7lN8LIKRzRJTTZXIjtWgA5THM5lhBAw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= diff --git a/integration-tests/browser_integration_test.go b/integration-tests/browser_integration_test.go index 1d0adc717..ca317ba23 100644 --- a/integration-tests/browser_integration_test.go +++ b/integration-tests/browser_integration_test.go @@ -59,10 +59,12 @@ func TestBrowserIntegration(t *testing.T) { customOrderFilter := `{"properties": { "makerAddress": { "const": "0x6ecbe1db9ef729cbe972c83fb886247691fb6beb" }}}` // Start the standalone node in a goroutine. + // Note(albrow) we need to use a specific data dir because we need to use the same private key for each test. + // The tests themselves are written in a way that depend on this. 
wg.Add(1) go func() { defer wg.Done() - startStandaloneNode(t, ctx, count, customOrderFilter, standaloneLogMessages) + startStandaloneNode(t, ctx, count, browserIntegrationTestDataDir, customOrderFilter, standaloneLogMessages) }() // standaloneOrder is an order that will be sent to the network by the diff --git a/integration-tests/constants.go b/integration-tests/constants.go index 4e0a775f4..c87e452dd 100644 --- a/integration-tests/constants.go +++ b/integration-tests/constants.go @@ -6,7 +6,7 @@ const ( wsRPCPort = 60501 httpRPCPort = 60701 - standaloneDataDirPrefix = "./data/standalone-" + browserIntegrationTestDataDir = "./data/standalone-0" standaloneWSRPCEndpointPrefix = "ws://localhost:" standaloneHTTPRPCEndpointPrefix = "http://localhost:" standaloneRPCAddrPrefix = "localhost:" diff --git a/integration-tests/rpc_integration_test.go b/integration-tests/rpc_integration_test.go index d3729b03f..b75e9fd0c 100644 --- a/integration-tests/rpc_integration_test.go +++ b/integration-tests/rpc_integration_test.go @@ -19,6 +19,7 @@ import ( "github.com/0xProject/0x-mesh/scenario" "github.com/0xProject/0x-mesh/scenario/orderopts" "github.com/0xProject/0x-mesh/zeroex" + "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -48,7 +49,7 @@ func runAddOrdersSuccessTest(t *testing.T, rpcEndpointPrefix, rpcServerType stri count := int(atomic.AddInt32(&nodeCount, 1)) go func() { defer wg.Done() - startStandaloneNode(t, ctx, count, "", logMessages) + startStandaloneNode(t, ctx, count, "", "", logMessages) }() // Wait until the rpc server has been started, and then create an rpc client @@ -117,7 +118,7 @@ func runGetOrdersTest(t *testing.T, rpcEndpointPrefix, rpcServerType string, rpc count := int(atomic.AddInt32(&nodeCount, 1)) go func() { defer wg.Done() - startStandaloneNode(t, ctx, count, "", logMessages) + startStandaloneNode(t, ctx, count, "", "", logMessages) }() _, err := waitForLogSubstring(ctx, 
logMessages, fmt.Sprintf("started %s RPC server", rpcServerType)) @@ -144,14 +145,13 @@ func runGetOrdersTest(t *testing.T, rpcEndpointPrefix, rpcServerType string, rpc assert.Len(t, validationResponse.Accepted, numOrders) assert.Len(t, validationResponse.Rejected, 0) - fixmeGetOrdersResponse, err := client.GetOrders(0, 10, "") + getOrdersResponse, err := client.GetOrders(10, common.Hash{}) require.NoError(t, err) // NOTE(jalextowle) This statement holds true for many pagination algorithms, but it may be necessary // to drop this requirement if the `GetOrders` endpoint changes dramatically. - require.Len(t, fixmeGetOrdersResponse.OrdersInfos, 10) + require.Len(t, getOrdersResponse.OrdersInfos, 10) // Make a new "GetOrders" request with different pagination parameters. - snapshotID := "" for _, testCase := range []struct { ordersPerPage int }{ @@ -169,29 +169,26 @@ func runGetOrdersTest(t *testing.T, rpcEndpointPrefix, rpcServerType string, rpc }, } { if testCase.ordersPerPage <= 0 { - _, err := client.GetOrders(0, testCase.ordersPerPage, snapshotID) + _, err := client.GetOrders(testCase.ordersPerPage, common.Hash{}) require.EqualError(t, err, "perPage cannot be zero") } else { - - // If numOrders % testCase.ordersPerPage is nonzero, then we must increment the number of pages to - // iterate through because the numOrder / testCase.ordersPerPage calculation rounds down. - highestPageNumber := numOrders / testCase.ordersPerPage - if numOrders%testCase.ordersPerPage > 0 { - highestPageNumber++ - } - // Iterate through enough pages to get all of the orders in the mesh nodes database. Compare the // responses to the orders that we expect to be in the database. 
var responseOrders []*types.OrderInfo - for pageNumber := 0; pageNumber < highestPageNumber; pageNumber++ { + currentMinHash := common.Hash{} + for { expectedTimestamp := time.Now().UTC() - getOrdersResponse, err := client.GetOrders(pageNumber, testCase.ordersPerPage, snapshotID) - assert.WithinDuration(t, expectedTimestamp, getOrdersResponse.SnapshotTimestamp, time.Second) + getOrdersResponse, err := client.GetOrders(testCase.ordersPerPage, currentMinHash) + assert.WithinDuration(t, expectedTimestamp, getOrdersResponse.Timestamp, time.Second) require.NoError(t, err) - // NOTE(jalextowle) This statement holds true for many pagination algorithms, but it may be necessary - // to drop this requirement if the `GetOrders` endpoint changes dramatically. - require.Len(t, getOrdersResponse.OrdersInfos, min(testCase.ordersPerPage, numOrders-pageNumber*testCase.ordersPerPage)) - responseOrders = append(responseOrders, getOrdersResponse.OrdersInfos...) + orderInfos := getOrdersResponse.OrdersInfos + assert.LessOrEqual(t, len(orderInfos), testCase.ordersPerPage, "response contained too many orders") + responseOrders = append(responseOrders, orderInfos...) + if len(orderInfos) > 0 { + currentMinHash = orderInfos[len(orderInfos)-1].OrderHash + } else { + break + } } assertSignedOrdersMatch(t, signedTestOrders, responseOrders) } @@ -226,7 +223,7 @@ func runGetStatsTest(t *testing.T, rpcEndpointPrefix, rpcServerType string, rpcP count := int(atomic.AddInt32(&nodeCount, 1)) go func() { defer wg.Done() - startStandaloneNode(t, ctx, count, "", logMessages) + startStandaloneNode(t, ctx, count, "", "", logMessages) }() // Wait for the rpc server to start and get the peer ID of the node. 
Start the @@ -253,16 +250,16 @@ func runGetStatsTest(t *testing.T, rpcEndpointPrefix, rpcServerType string, rpcP getStatsResponse.LatestBlock = types.LatestBlock{} // Ensure that the correct response was logged by "GetStats" - require.Equal(t, "/0x-orders/version/3/chain/1337/schema/e30=", getStatsResponse.PubSubTopic) - require.Equal(t, "/0x-mesh/network/1337/version/2", getStatsResponse.Rendezvous) - require.Equal(t, []string{}, getStatsResponse.SecondaryRendezvous) - require.Equal(t, jsonLog.PeerID, getStatsResponse.PeerID) - require.Equal(t, 1337, getStatsResponse.EthereumChainID) - require.Equal(t, types.LatestBlock{}, getStatsResponse.LatestBlock) - require.Equal(t, 0, getStatsResponse.NumOrders) - require.Equal(t, 0, getStatsResponse.NumPeers) - require.Equal(t, constants.UnlimitedExpirationTime.String(), getStatsResponse.MaxExpirationTime) - require.Equal(t, ratelimit.GetUTCMidnightOfDate(time.Now()), getStatsResponse.StartOfCurrentUTCDay) + require.Equal(t, "/0x-orders/version/3/chain/1337/schema/e30=", getStatsResponse.PubSubTopic, "PubSubTopic") + require.Equal(t, "/0x-mesh/network/1337/version/2", getStatsResponse.Rendezvous, "Rendezvous") + require.Equal(t, []string{}, getStatsResponse.SecondaryRendezvous, "SecondaryRendezvous") + require.Equal(t, jsonLog.PeerID, getStatsResponse.PeerID, "PeerID") + require.Equal(t, 1337, getStatsResponse.EthereumChainID, "EthereumChainID") + require.Equal(t, types.LatestBlock{}, getStatsResponse.LatestBlock, "LatestBlock") + require.Equal(t, 0, getStatsResponse.NumOrders, "NumOrders") + require.Equal(t, 0, getStatsResponse.NumPeers, "NumPeers") + require.Equal(t, constants.UnlimitedExpirationTime.String(), getStatsResponse.MaxExpirationTime, "MaxExpirationTime") + require.Equal(t, ratelimit.GetUTCMidnightOfDate(time.Now()), getStatsResponse.StartOfCurrentUTCDay, "StartOfCurrentDay") cancel() wg.Wait() @@ -285,7 +282,7 @@ func TestOrdersSubscription(t *testing.T) { count := int(atomic.AddInt32(&nodeCount, 1)) go func() 
{ defer wg.Done() - startStandaloneNode(t, ctx, count, "", logMessages) + startStandaloneNode(t, ctx, count, "", "", logMessages) }() // Wait for the rpc server to start and then start the rpc client. @@ -345,7 +342,7 @@ func TestHeartbeatSubscription(t *testing.T) { count := int(atomic.AddInt32(&nodeCount, 1)) go func() { defer wg.Done() - startStandaloneNode(t, ctx, count, "", logMessages) + startStandaloneNode(t, ctx, count, "", "", logMessages) }() // Wait for the rpc server to start and then start the rpc client diff --git a/integration-tests/utils.go b/integration-tests/utils.go index e38ee1f35..94912695d 100644 --- a/integration-tests/utils.go +++ b/integration-tests/utils.go @@ -21,6 +21,7 @@ import ( "github.com/chromedp/cdproto/runtime" "github.com/chromedp/chromedp" ethrpc "github.com/ethereum/go-ethereum/rpc" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -59,13 +60,8 @@ func min(a, b int) int { } func removeOldFiles(t *testing.T, ctx context.Context) { - oldFiles, err := filepath.Glob(filepath.Join(standaloneDataDirPrefix + "*")) - require.NoError(t, err) - - for _, oldFile := range oldFiles { - require.NoError(t, os.RemoveAll(filepath.Join(oldFile, "db"))) - require.NoError(t, os.RemoveAll(filepath.Join(oldFile, "p2p"))) - } + require.NoError(t, os.RemoveAll(filepath.Join(browserIntegrationTestDataDir, "sqlite-db"))) + require.NoError(t, os.RemoveAll(filepath.Join(browserIntegrationTestDataDir, "p2p"))) require.NoError(t, os.RemoveAll(filepath.Join(bootstrapDataDir, "p2p"))) } @@ -153,12 +149,16 @@ func startBootstrapNode(t *testing.T, ctx context.Context) { assert.NoError(t, err, "could not run bootstrap node: %s", string(output)) } -func startStandaloneNode(t *testing.T, ctx context.Context, nodeID int, customOrderFilter string, logMessages chan<- string) { +func startStandaloneNode(t *testing.T, ctx context.Context, nodeID int, dataDir string, customOrderFilter string, logMessages chan<- 
string) { cmd := exec.CommandContext(ctx, "mesh") + if dataDir == "" { + // If dataDir is empty. Set a default data dir to a file in the /tmp directory + dataDir = filepath.Join("/tmp", "mesh_testing", uuid.New().String()) + } cmd.Env = append( os.Environ(), "VERBOSITY=6", - "DATA_DIR="+standaloneDataDirPrefix+strconv.Itoa(nodeID), + "DATA_DIR="+dataDir, "BOOTSTRAP_LIST="+bootstrapList, "ETHEREUM_RPC_URL="+ethereumRPCURL, "ETHEREUM_CHAIN_ID="+strconv.Itoa(ethereumChainID), @@ -229,7 +229,7 @@ func startBrowserNode(t *testing.T, ctx context.Context, url string, browserLogM case runtime.APITypeError: // Report any console.error events as test failures. for _, arg := range ev.Args { - t.Errorf("JavaScript console error: (%s) %s", arg.Type, arg.Value) + t.Errorf("JavaScript console error: (%s) %s %s", arg.Type, arg.Value, arg.Description) } } } diff --git a/meshdb/meshdb.go b/meshdb/meshdb.go deleted file mode 100644 index 871e79733..000000000 --- a/meshdb/meshdb.go +++ /dev/null @@ -1,681 +0,0 @@ -package meshdb - -import ( - "bytes" - "errors" - "fmt" - "math/big" - "time" - - "github.com/0xProject/0x-mesh/constants" - "github.com/0xProject/0x-mesh/db" - "github.com/0xProject/0x-mesh/ethereum" - "github.com/0xProject/0x-mesh/ethereum/miniheader" - "github.com/0xProject/0x-mesh/zeroex" - "github.com/ethereum/go-ethereum/common" - log "github.com/sirupsen/logrus" -) - -const ( - // The default miniHeaderRetentionLimit used by Mesh. This default only gets overwritten in tests. 
- defaultMiniHeaderRetentionLimit = 20 - // The maximum MiniHeaders to query per page when deleting MiniHeaders - miniHeadersMaxPerPage = 1000 -) - -var ErrDBFilledWithPinnedOrders = errors.New("the database is full of pinned orders; no orders can be removed in order to make space") - -// Order is the database representation a 0x order along with some relevant metadata -type Order struct { - Hash common.Hash - SignedOrder *zeroex.SignedOrder - // When was this order last validated - LastUpdated time.Time - // How much of this order can still be filled - FillableTakerAssetAmount *big.Int - // Was this order flagged for removal? Due to the possibility of block-reorgs, instead - // of immediately removing an order when FillableTakerAssetAmount becomes 0, we instead - // flag it for removal. After this order isn't updated for X time and has IsRemoved = true, - // the order can be permanently deleted. - IsRemoved bool - // IsPinned indicates whether or not the order is pinned. Pinned orders are - // not removed from the database unless they become unfillable. 
- IsPinned bool -} - -// ID returns the Order's ID -func (o Order) ID() []byte { - return o.Hash.Bytes() -} - -// Metadata is the database representation of MeshDB instance metadata -type Metadata struct { - EthereumChainID int - MaxExpirationTime *big.Int - EthRPCRequestsSentInCurrentUTCDay int - StartOfCurrentUTCDay time.Time -} - -// ID returns the id used for the metadata collection (one per DB) -func (m Metadata) ID() []byte { - return []byte{0} -} - -// MeshDB instantiates the DB connection and creates all the collections used by the application -type MeshDB struct { - database *db.DB - metadata *MetadataCollection - MiniHeaders *MiniHeadersCollection - Orders *OrdersCollection - MiniHeaderRetentionLimit int -} - -// MiniHeadersCollection represents a DB collection of mini Ethereum block headers -type MiniHeadersCollection struct { - *db.Collection - numberIndex *db.Index -} - -// OrdersCollection represents a DB collection of 0x orders -type OrdersCollection struct { - *db.Collection - MakerAddressAndSaltIndex *db.Index - MakerAddressTokenAddressTokenIDIndex *db.Index - MakerAddressMakerFeeAssetAddressTokenIDIndex *db.Index - LastUpdatedIndex *db.Index - IsRemovedIndex *db.Index - ExpirationTimeIndex *db.Index -} - -// MetadataCollection represents a DB collection used to store instance metadata -type MetadataCollection struct { - *db.Collection -} - -// New instantiates a new MeshDB instance -func New(path string, contractAddresses ethereum.ContractAddresses) (*MeshDB, error) { - database, err := db.Open(path) - if err != nil { - return nil, err - } - - miniHeaders, err := setupMiniHeaders(database) - if err != nil { - return nil, err - } - - orders, err := setupOrders(database, contractAddresses) - if err != nil { - return nil, err - } - - metadata, err := setupMetadata(database) - if err != nil { - return nil, err - } - - return &MeshDB{ - database: database, - metadata: metadata, - MiniHeaders: miniHeaders, - Orders: orders, - MiniHeaderRetentionLimit: 
defaultMiniHeaderRetentionLimit, - }, nil -} - -func setupOrders(database *db.DB, contractAddresses ethereum.ContractAddresses) (*OrdersCollection, error) { - col, err := database.NewCollection("order", &Order{}) - if err != nil { - return nil, err - } - lastUpdatedIndex := col.AddIndex("lastUpdated", func(m db.Model) []byte { - index := []byte(m.(*Order).LastUpdated.UTC().Format(time.RFC3339Nano)) - return index - }) - makerAddressAndSaltIndex := col.AddIndex("makerAddressAndSalt", func(m db.Model) []byte { - // By default, the index is sorted in byte order. In order to sort by - // numerical order, we need to pad with zeroes. The maximum length of an - // unsigned 256 bit integer is 80, so we pad with zeroes such that the - // length of the number is always 80. - signedOrder := m.(*Order).SignedOrder - index := []byte(fmt.Sprintf("%s|%s", signedOrder.MakerAddress.Hex(), uint256ToConstantLengthBytes(signedOrder.Salt))) - return index - }) - // TODO(fabio): Optimize this index callback since it gets called many times under-the-hood. - // We might want to parse the assetData once and store it's components in the DB. The trade-off - // here is compute time for storage space. - makerAddressTokenAddressTokenIDIndex := col.AddMultiIndex("makerAddressTokenAddressTokenId", func(m db.Model) [][]byte { - order := m.(*Order) - singleAssetDatas, err := parseContractAddressesAndTokenIdsFromAssetData(order.SignedOrder.MakerAssetData, contractAddresses) - if err != nil { - log.WithFields(log.Fields{ - "error": err.Error(), - }).Panic("Parsing assetData failed") - } - indexValues := make([][]byte, len(singleAssetDatas)) - for i, singleAssetData := range singleAssetDatas { - indexValue := []byte(order.SignedOrder.MakerAddress.Hex() + "|" + singleAssetData.Address.Hex() + "|") - if singleAssetData.TokenID != nil { - indexValue = append(indexValue, singleAssetData.TokenID.Bytes()...) 
- } - indexValues[i] = indexValue - } - return indexValues - }) - makerAddressMakerFeeAssetAddressTokenIDIndex := col.AddMultiIndex("makerAddressMakerFeeAssetAddressTokenID", func(m db.Model) [][]byte { - order := m.(*Order) - if bytes.Equal(order.SignedOrder.MakerFeeAssetData, constants.NullBytes) { - // MakerFeeAssetData is optional and the lack of a maker fee is indicated - // by null bytes ("0x0"). We still want to index this value so we can look - // up orders without a maker fee. - return [][]byte{ - []byte(order.SignedOrder.MakerAddress.Hex() + "|" + common.ToHex(constants.NullBytes) + "|"), - } - } - singleAssetDatas, err := parseContractAddressesAndTokenIdsFromAssetData(order.SignedOrder.MakerFeeAssetData, contractAddresses) - if err != nil { - log.WithFields(log.Fields{ - "error": err.Error(), - }).Panic("Parsing assetData failed") - } - - indexValues := make([][]byte, len(singleAssetDatas)) - for i, singleAssetData := range singleAssetDatas { - indexValue := []byte(order.SignedOrder.MakerAddress.Hex() + "|" + singleAssetData.Address.Hex() + "|") - if singleAssetData.TokenID != nil { - indexValue = append(indexValue, singleAssetData.TokenID.Bytes()...) - } - indexValues[i] = indexValue - } - return indexValues - }) - - isRemovedIndex := col.AddIndex("isRemoved", func(m db.Model) []byte { - order := m.(*Order) - // false = 0; true = 1 - if order.IsRemoved { - return []byte{1} - } - return []byte{0} - }) - - expirationTimeIndex := col.AddIndex("expirationTime", func(m db.Model) []byte { - order := m.(*Order) - expTimeString := uint256ToConstantLengthBytes(order.SignedOrder.ExpirationTimeSeconds) - // We separate pinned and non-pinned orders via a prefix that is either 0 or - // 1. 
- pinnedString := "0" - if order.IsPinned { - pinnedString = "1" - } - return []byte(fmt.Sprintf("%s|%s", pinnedString, expTimeString)) - }) - - return &OrdersCollection{ - Collection: col, - MakerAddressTokenAddressTokenIDIndex: makerAddressTokenAddressTokenIDIndex, - MakerAddressMakerFeeAssetAddressTokenIDIndex: makerAddressMakerFeeAssetAddressTokenIDIndex, - MakerAddressAndSaltIndex: makerAddressAndSaltIndex, - LastUpdatedIndex: lastUpdatedIndex, - IsRemovedIndex: isRemovedIndex, - ExpirationTimeIndex: expirationTimeIndex, - }, nil -} - -func setupMiniHeaders(database *db.DB) (*MiniHeadersCollection, error) { - col, err := database.NewCollection("miniHeader", &miniheader.MiniHeader{}) - if err != nil { - return nil, err - } - numberIndex := col.AddIndex("number", func(model db.Model) []byte { - // By default, the index is sorted in byte order. In order to sort by - // numerical order, we need to pad with zeroes. The maximum length of an - // unsigned 256 bit integer is 80, so we pad with zeroes such that the - // length of the number is always 80. 
- number := model.(*miniheader.MiniHeader).Number - return uint256ToConstantLengthBytes(number) - }) - - return &MiniHeadersCollection{ - Collection: col, - numberIndex: numberIndex, - }, nil -} - -func setupMetadata(database *db.DB) (*MetadataCollection, error) { - col, err := database.NewCollection("metadata", &Metadata{}) - if err != nil { - return nil, err - } - return &MetadataCollection{col}, nil -} - -// Close closes the database connection -func (m *MeshDB) Close() { - m.database.Close() -} - -// FindAllMiniHeadersSortedByNumber returns all MiniHeaders sorted in ascending block number order -func (m *MeshDB) FindAllMiniHeadersSortedByNumber() ([]*miniheader.MiniHeader, error) { - miniHeaders := []*miniheader.MiniHeader{} - query := m.MiniHeaders.NewQuery(m.MiniHeaders.numberIndex.All()) - if err := query.Run(&miniHeaders); err != nil { - return nil, err - } - return miniHeaders, nil -} - -// MiniHeaderCollectionEmptyError is returned when no miniHeaders have been stored in -// the DB yet -type MiniHeaderCollectionEmptyError struct{} - -func (e MiniHeaderCollectionEmptyError) Error() string { - return "Latest MiniHeader not found" -} - -// FindLatestMiniHeader returns the latest MiniHeader (i.e. the one with the -// largest block number). It returns nil, MiniHeaderCollectionEmptyError if there -// are no MiniHeaders in the database. 
-func (m *MeshDB) FindLatestMiniHeader() (*miniheader.MiniHeader, error) { - miniHeaders := []*miniheader.MiniHeader{} - query := m.MiniHeaders.NewQuery(m.MiniHeaders.numberIndex.All()).Reverse().Max(1) - if err := query.Run(&miniHeaders); err != nil { - return nil, err - } - if len(miniHeaders) == 0 { - return nil, MiniHeaderCollectionEmptyError{} - } - return miniHeaders[0], nil -} - -// MiniHeaderNotFoundError is returned when a miniHeaders is not found for a specific -// block number -type MiniHeaderNotFoundError struct { - blockNumber int64 -} - -func (e MiniHeaderNotFoundError) Error() string { - return fmt.Sprintf("MiniHeader not found for block number: %d", e.blockNumber) -} - -// FindMiniHeaderByBlockNumber returns the MiniHeader with the specified block number -func (m *MeshDB) FindMiniHeaderByBlockNumber(blockNumber *big.Int) (*miniheader.MiniHeader, error) { - miniHeaders := []*miniheader.MiniHeader{} - blockNumberFilter := m.MiniHeaders.numberIndex.ValueFilter(uint256ToConstantLengthBytes(blockNumber)) - query := m.MiniHeaders.NewQuery(blockNumberFilter) - if err := query.Run(&miniHeaders); err != nil { - return nil, err - } - if len(miniHeaders) == 0 { - return nil, MiniHeaderNotFoundError{blockNumber: blockNumber.Int64()} - } - return miniHeaders[0], nil -} - -// UpdateMiniHeaderRetentionLimit updates the MiniHeaderRetentionLimit. 
This is only used by tests in order -// to set the retention limit to a smaller size, making the tests shorter in length -func (m *MeshDB) UpdateMiniHeaderRetentionLimit(limit int) error { - m.MiniHeaderRetentionLimit = limit - return m.PruneMiniHeadersAboveRetentionLimit() -} - -// PruneMiniHeadersAboveRetentionLimit prunes miniHeaders from the DB that are above the retention limit -func (m *MeshDB) PruneMiniHeadersAboveRetentionLimit() error { - if totalMiniHeaders, err := m.MiniHeaders.Count(); err != nil { - return err - } else if totalMiniHeaders > m.MiniHeaderRetentionLimit { - latestMiniHeader, err := m.FindLatestMiniHeader() - if err != nil { - return err - } else if latestMiniHeader != nil { - minBlockNumber := big.NewInt(0).Sub(latestMiniHeader.Number, big.NewInt(int64(m.MiniHeaderRetentionLimit)-1)) - if err := m.ClearOldMiniHeaders(minBlockNumber); err != nil { - return err - } - } - } - return nil -} - -// ClearAllMiniHeaders removes all stored MiniHeaders from the database. -func (m *MeshDB) ClearAllMiniHeaders() error { - return m.clearMiniHeadersWithFilter(m.MiniHeaders.numberIndex.All()) -} - -// ClearOldMiniHeaders removes all stored MiniHeaders with a block number less then -// the given minBlockNumber. -func (m *MeshDB) ClearOldMiniHeaders(minBlockNumber *big.Int) error { - filter := m.MiniHeaders.numberIndex.RangeFilter( - uint256ToConstantLengthBytes(big.NewInt(0)), - uint256ToConstantLengthBytes(minBlockNumber), - ) - return m.clearMiniHeadersWithFilter(filter) -} - -func (m *MeshDB) clearMiniHeadersWithFilter(filter *db.Filter) error { - for { - removed, err := m.clearMiniHeadersOnce(filter) - if err != nil { - return err - } - if removed == 0 { - break - } - } - return nil -} - -// clearMiniHeadersOnce removes up to miniHeadersMaxPerPage MiniHeaders from the -// database that match the given filter. It returns the number of MiniHeaders removed. 
-func (m *MeshDB) clearMiniHeadersOnce(filter *db.Filter) (removed int, err error) { - txn := m.MiniHeaders.OpenTransaction() - defer func() { - _ = txn.Discard() - }() - var miniHeaders []*miniheader.MiniHeader - if err := m.MiniHeaders.NewQuery(filter).Max(miniHeadersMaxPerPage).Run(&miniHeaders); err != nil { - return 0, err - } - log.WithFields(log.Fields{ - "maxPerPage": miniHeadersMaxPerPage, - "numberToRemove": len(miniHeaders), - }).Trace("Removing outdated MiniHeaders from database") - - for _, miniHeader := range miniHeaders { - if err := txn.Delete(miniHeader.ID()); err != nil { - return 0, err - } - } - if err := txn.Commit(); err != nil { - return 0, err - } - return len(miniHeaders), nil -} - -// FindOrdersByMakerAddress finds all orders belonging to a particular maker address -func (m *MeshDB) FindOrdersByMakerAddress(makerAddress common.Address) ([]*Order, error) { - prefix := []byte(makerAddress.Hex() + "|") - filter := m.Orders.MakerAddressTokenAddressTokenIDIndex.PrefixFilter(prefix) - orders := []*Order{} - if err := m.Orders.NewQuery(filter).Run(&orders); err != nil { - return nil, err - } - return orders, nil -} - -// FindOrdersByMakerAddressTokenAddressAndTokenID finds all orders belonging to a particular maker -// address where makerAssetData encodes for a particular token contract and optionally a token ID -func (m *MeshDB) FindOrdersByMakerAddressTokenAddressAndTokenID(makerAddress, tokenAddress common.Address, tokenID *big.Int) ([]*Order, error) { - prefix := []byte(makerAddress.Hex() + "|" + tokenAddress.Hex() + "|") - if tokenID != nil { - prefix = append(prefix, tokenID.Bytes()...) 
- } - filter := m.Orders.MakerAddressTokenAddressTokenIDIndex.PrefixFilter(prefix) - orders := []*Order{} - if err := m.Orders.NewQuery(filter).Run(&orders); err != nil { - return nil, err - } - return orders, nil -} - -// FindOrdersByMakerAddressMakerFeeAssetAddressTokenID finds all orders belonging to -// a particular maker address where makerFeeAssetData encodes for a particular -// token contract and optionally a token ID. To find orders without a maker fee, -// use constants.NullAddress for makerFeeAssetAddress. -func (m *MeshDB) FindOrdersByMakerAddressMakerFeeAssetAddressAndTokenID(makerAddress, makerFeeAssetAddress common.Address, tokenID *big.Int) ([]*Order, error) { - var prefix []byte - if makerFeeAssetAddress == constants.NullAddress { - prefix = []byte(makerAddress.Hex() + "|" + common.ToHex(constants.NullBytes) + "|") - } else { - prefix = []byte(makerAddress.Hex() + "|" + makerFeeAssetAddress.Hex() + "|") - if tokenID != nil { - prefix = append(prefix, tokenID.Bytes()...) - } - } - - filter := m.Orders.MakerAddressMakerFeeAssetAddressTokenIDIndex.PrefixFilter(prefix) - orders := []*Order{} - if err := m.Orders.NewQuery(filter).Run(&orders); err != nil { - return nil, err - } - return orders, nil -} - -// FindOrdersByMakerAddressAndMaxSalt finds all orders belonging to a particular maker address that -// also have a salt value less then or equal to X -func (m *MeshDB) FindOrdersByMakerAddressAndMaxSalt(makerAddress common.Address, salt *big.Int) ([]*Order, error) { - // DB range queries exclude the limit value however the 0x protocol `cancelOrdersUpTo` method - // is inclusive of the value supplied. 
In order to make this helper method more useful to our - // particular use-case, we add 1 to the supplied salt (making the query inclusive instead) - saltPlusOne := new(big.Int).Add(salt, big.NewInt(1)) - start := []byte(fmt.Sprintf("%s|%080s", makerAddress.Hex(), "0")) - limit := []byte(fmt.Sprintf("%s|%s", makerAddress.Hex(), uint256ToConstantLengthBytes(saltPlusOne))) - filter := m.Orders.MakerAddressAndSaltIndex.RangeFilter(start, limit) - orders := []*Order{} - if err := m.Orders.NewQuery(filter).Run(&orders); err != nil { - return nil, err - } - return orders, nil -} - -// FindOrdersLastUpdatedBefore finds all orders where the LastUpdated time is less -// than X -func (m *MeshDB) FindOrdersLastUpdatedBefore(lastUpdated time.Time) ([]*Order, error) { - start := []byte(time.Unix(0, 0).Format(time.RFC3339Nano)) - limit := []byte(lastUpdated.UTC().Format(time.RFC3339Nano)) - filter := m.Orders.LastUpdatedIndex.RangeFilter(start, limit) - orders := []*Order{} - if err := m.Orders.NewQuery(filter).Run(&orders); err != nil { - return nil, err - } - return orders, nil -} - -// FindRemovedOrders finds all orders that have been flagged for removal -func (m *MeshDB) FindRemovedOrders() ([]*Order, error) { - var removedOrders []*Order - isRemovedFilter := m.Orders.IsRemovedIndex.ValueFilter([]byte{1}) - if err := m.Orders.NewQuery(isRemovedFilter).Run(&removedOrders); err != nil { - return nil, err - } - return removedOrders, nil -} - -// GetMetadata returns the metadata (or a db.NotFoundError if no metadata has been found). -func (m *MeshDB) GetMetadata() (*Metadata, error) { - var metadata Metadata - if err := m.metadata.FindByID([]byte{0}, &metadata); err != nil { - return nil, err - } - return &metadata, nil -} - -// SaveMetadata inserts the metadata into the database, overwriting any existing -// metadata. 
-func (m *MeshDB) SaveMetadata(metadata *Metadata) error { - if err := m.metadata.Insert(metadata); err != nil { - return err - } - return nil -} - -// UpdateMetadata updates the metadata in the database via a transaction. It -// accepts a callback function which will be provided with the old metadata and -// should return the new metadata to save. -func (m *MeshDB) UpdateMetadata(updater func(oldmetadata Metadata) (newMetadata Metadata)) error { - txn := m.metadata.OpenTransaction() - defer func() { - _ = txn.Discard() - }() - - oldMetadata, err := m.GetMetadata() - if err != nil { - return err - } - newMetadata := updater(*oldMetadata) - if err := txn.Update(&newMetadata); err != nil { - return err - } - - return txn.Commit() -} - -type singleAssetData struct { - Address common.Address - TokenID *big.Int -} - -func parseContractAddressesAndTokenIdsFromAssetData(assetData []byte, contractAddresses ethereum.ContractAddresses) ([]singleAssetData, error) { - singleAssetDatas := []singleAssetData{} - assetDataDecoder := zeroex.NewAssetDataDecoder() - - assetDataName, err := assetDataDecoder.GetName(assetData) - if err != nil { - return nil, err - } - switch assetDataName { - case "ERC20Token": - var decodedAssetData zeroex.ERC20AssetData - err := assetDataDecoder.Decode(assetData, &decodedAssetData) - if err != nil { - return nil, err - } - a := singleAssetData{ - Address: decodedAssetData.Address, - } - singleAssetDatas = append(singleAssetDatas, a) - case "ERC721Token": - var decodedAssetData zeroex.ERC721AssetData - err := assetDataDecoder.Decode(assetData, &decodedAssetData) - if err != nil { - return nil, err - } - a := singleAssetData{ - Address: decodedAssetData.Address, - TokenID: decodedAssetData.TokenId, - } - singleAssetDatas = append(singleAssetDatas, a) - case "ERC1155Assets": - var decodedAssetData zeroex.ERC1155AssetData - err := assetDataDecoder.Decode(assetData, &decodedAssetData) - if err != nil { - return nil, err - } - for _, id := range 
decodedAssetData.Ids { - a := singleAssetData{ - Address: decodedAssetData.Address, - TokenID: id, - } - singleAssetDatas = append(singleAssetDatas, a) - } - case "StaticCall": - var decodedAssetData zeroex.StaticCallAssetData - err := assetDataDecoder.Decode(assetData, &decodedAssetData) - if err != nil { - return nil, err - } - // NOTE(jalextowle): As of right now, none of the supported staticcalls - // have important information in the StaticCallData. We choose not to add - // `singleAssetData` because it would not be used. - case "MultiAsset": - var decodedAssetData zeroex.MultiAssetData - err := assetDataDecoder.Decode(assetData, &decodedAssetData) - if err != nil { - return nil, err - } - for _, assetData := range decodedAssetData.NestedAssetData { - as, err := parseContractAddressesAndTokenIdsFromAssetData(assetData, contractAddresses) - if err != nil { - return nil, err - } - singleAssetDatas = append(singleAssetDatas, as...) - } - case "ERC20Bridge": - var decodedAssetData zeroex.ERC20BridgeAssetData - err := assetDataDecoder.Decode(assetData, &decodedAssetData) - if err != nil { - return nil, err - } - tokenAddress := decodedAssetData.TokenAddress - // HACK(fabio): Despite Chai ERC20Bridge orders encoding the Dai address as - // the tokenAddress, we actually want to react to the Chai token's contract - // events, so we actually return it instead. - if decodedAssetData.BridgeAddress == contractAddresses.ChaiBridge { - tokenAddress = contractAddresses.ChaiToken - } - a := singleAssetData{ - Address: tokenAddress, - } - singleAssetDatas = append(singleAssetDatas, a) - default: - return nil, fmt.Errorf("unrecognized assetData type name found: %s", assetDataName) - } - return singleAssetDatas, nil -} - -func uint256ToConstantLengthBytes(v *big.Int) []byte { - return []byte(fmt.Sprintf("%080s", v.String())) -} - -// TrimOrdersByExpirationTime removes existing orders with the highest -// expiration time until the number of remaining orders is <= targetMaxOrders. 
-// It returns any orders that were removed and the new max expiration time that -// can be used to eliminate incoming orders that expire too far in the future. -func (m *MeshDB) TrimOrdersByExpirationTime(targetMaxOrders int) (newMaxExpirationTime *big.Int, removedOrders []*Order, err error) { - txn := m.Orders.OpenTransaction() - defer func() { - _ = txn.Discard() - }() - - numOrders, err := m.Orders.Count() - if err != nil { - return nil, nil, err - } - if numOrders <= targetMaxOrders { - // If the number of orders is less than the target, we don't need to remove - // any orders. Return UnlimitedExpirationTime. - return constants.UnlimitedExpirationTime, nil, nil - } - - // Find the orders which we need to remove. We use a prefix filter of "0|: so - // that we only remove non-pinned orders. - filter := m.Orders.ExpirationTimeIndex.PrefixFilter([]byte("0|")) - numOrdersToRemove := numOrders - targetMaxOrders - if err := m.Orders.NewQuery(filter).Reverse().Max(numOrdersToRemove).Run(&removedOrders); err != nil { - return nil, nil, err - } - - // Remove those orders and commit the transaction. - for _, order := range removedOrders { - if err := txn.Delete(order.Hash.Bytes()); err != nil { - return nil, nil, err - } - } - if err := txn.Commit(); err != nil { - return nil, nil, err - } - - // If we could not remove numOrdersToRemove orders than it means the database - // is full of pinned orders. We still remove as many orders as we can and then - // return an error. - if len(removedOrders) < numOrdersToRemove { - return nil, nil, ErrDBFilledWithPinnedOrders - } - - // The new max expiration time is simply the minimum expiration time of the - // orders that were removed (i.e., the expiration time of the last order in - // the slice). We add a buffer of -1 just to make sure we don't exceed - // targetMaxOrders. 
This means it is technically possible that there are a - // number of orders currently in the database that exceed the max expiration - // time, but no new orders that exceed this time will be added. - newMaxExpirationTime = removedOrders[len(removedOrders)-1].SignedOrder.ExpirationTimeSeconds - newMaxExpirationTime = newMaxExpirationTime.Sub(newMaxExpirationTime, big.NewInt(1)) - return newMaxExpirationTime, removedOrders, nil -} - -// CountPinnedOrders returns the number of pinned orders. -func (m *MeshDB) CountPinnedOrders() (int, error) { - // We use a prefix filter of "1|" so that we only count pinned orders. - filter := m.Orders.ExpirationTimeIndex.PrefixFilter([]byte("1|")) - return m.Orders.NewQuery(filter).Count() -} diff --git a/meshdb/meshdb_test.go b/meshdb/meshdb_test.go deleted file mode 100644 index 8ef7854de..000000000 --- a/meshdb/meshdb_test.go +++ /dev/null @@ -1,501 +0,0 @@ -package meshdb - -import ( - "math/big" - "testing" - "time" - - "github.com/0xProject/0x-mesh/constants" - "github.com/0xProject/0x-mesh/db" - "github.com/0xProject/0x-mesh/ethereum" - "github.com/0xProject/0x-mesh/ethereum/miniheader" - "github.com/0xProject/0x-mesh/zeroex" - "github.com/ethereum/go-ethereum/common" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var contractAddresses = ethereum.GanacheAddresses - -func TestOrderCRUDOperations(t *testing.T) { - meshDB, err := New("/tmp/meshdb_testing/"+uuid.New().String(), contractAddresses) - require.NoError(t, err) - defer meshDB.Close() - - makerAddress := constants.GanacheAccount0 - salt := big.NewInt(1548619145450) - o := &zeroex.Order{ - ChainID: big.NewInt(constants.TestChainID), - ExchangeAddress: contractAddresses.Exchange, - MakerAddress: makerAddress, - TakerAddress: constants.NullAddress, - SenderAddress: constants.NullAddress, - FeeRecipientAddress: common.HexToAddress("0xa258b39954cef5cb142fd567a46cddb31a670124"), - TakerAssetData: 
common.Hex2Bytes("f47261b000000000000000000000000034d402f14d58e001d8efbe6585051bf9706aa064"), - TakerFeeAssetData: constants.NullBytes, - MakerAssetData: common.Hex2Bytes("025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001"), - MakerFeeAssetData: constants.NullBytes, - Salt: salt, - MakerFee: big.NewInt(0), - TakerFee: big.NewInt(0), - MakerAssetAmount: big.NewInt(3551808554499581700), - TakerAssetAmount: big.NewInt(1), - ExpirationTimeSeconds: big.NewInt(1548619325), - } - signedOrder, err := zeroex.SignTestOrder(o) - require.NoError(t, err) - - orderHash, err := o.ComputeOrderHash() - require.NoError(t, err) - - currentTime := time.Now().UTC() - fiveMinutesFromNow := currentTime.Add(5 * time.Minute) - - // Insert - order := &Order{ - Hash: orderHash, - SignedOrder: signedOrder, - FillableTakerAssetAmount: big.NewInt(1), - LastUpdated: currentTime, - IsRemoved: false, - } - require.NoError(t, meshDB.Orders.Insert(order)) - // We need to call ResetHash so that unexported hash field is equal in later - // assertions. 
- signedOrder.ResetHash() - - // Find - foundOrder := &Order{} - require.NoError(t, meshDB.Orders.FindByID(order.ID(), foundOrder)) - assert.Equal(t, order, foundOrder) - - // Check Indexes - orders, err := meshDB.FindOrdersByMakerAddressAndMaxSalt(makerAddress, salt) - require.NoError(t, err) - assert.Equal(t, []*Order{order}, orders) - - orders, err = meshDB.FindOrdersByMakerAddress(makerAddress) - require.NoError(t, err) - assert.Equal(t, []*Order{order}, orders) - - orders, err = meshDB.FindOrdersLastUpdatedBefore(fiveMinutesFromNow) - require.NoError(t, err) - assert.Equal(t, []*Order{order}, orders) - - // Update - modifiedOrder := foundOrder - modifiedOrder.FillableTakerAssetAmount = big.NewInt(0) - require.NoError(t, meshDB.Orders.Update(modifiedOrder)) - foundModifiedOrder := &Order{} - require.NoError(t, meshDB.Orders.FindByID(modifiedOrder.ID(), foundModifiedOrder)) - assert.Equal(t, modifiedOrder, foundModifiedOrder) - - // Delete - require.NoError(t, meshDB.Orders.Delete(foundModifiedOrder.ID())) - nonExistentOrder := &Order{} - err = meshDB.Orders.FindByID(foundModifiedOrder.ID(), nonExistentOrder) - assert.IsType(t, db.NotFoundError{}, err) -} - -func TestParseContractAddressesAndTokenIdsFromAssetData(t *testing.T) { - // ERC20 AssetData - erc20AssetData := common.Hex2Bytes("f47261b000000000000000000000000038ae374ecf4db50b0ff37125b591a04997106a32") - singleAssetDatas, err := parseContractAddressesAndTokenIdsFromAssetData(erc20AssetData, contractAddresses) - require.NoError(t, err) - assert.Len(t, singleAssetDatas, 1) - expectedAddress := common.HexToAddress("0x38ae374ecf4db50b0ff37125b591a04997106a32") - assert.Equal(t, expectedAddress, singleAssetDatas[0].Address) - var expectedTokenID *big.Int - assert.Equal(t, expectedTokenID, singleAssetDatas[0].TokenID) - - // ERC721 AssetData - erc721AssetData := 
common.Hex2Bytes("025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001") - singleAssetDatas, err = parseContractAddressesAndTokenIdsFromAssetData(erc721AssetData, contractAddresses) - require.NoError(t, err) - assert.Equal(t, 1, len(singleAssetDatas)) - expectedAddress = common.HexToAddress("0x1dC4c1cEFEF38a777b15aA20260a54E584b16C48") - assert.Equal(t, expectedAddress, singleAssetDatas[0].Address) - expectedTokenID = big.NewInt(1) - assert.Equal(t, expectedTokenID, singleAssetDatas[0].TokenID) - - // Multi AssetData - multiAssetData := common.Hex2Bytes("94cfcdd7000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000004600000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000024f47261b00000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c48000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000044025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000x94cfcdd7000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000460000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000
0000002000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000024f47261b00000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c48000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000044025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c48000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000") - singleAssetDatas, err = parseContractAddressesAndTokenIdsFromAssetData(multiAssetData, contractAddresses) - require.NoError(t, err) - assert.Equal(t, 2, len(singleAssetDatas)) - expectedSingleAssetDatas := []singleAssetData{ - singleAssetData{ - Address: common.HexToAddress("0x1dc4c1cefef38a777b15aa20260a54e584b16c48"), - }, - singleAssetData{ - Address: common.HexToAddress("0x1dc4c1cefef38a777b15aa20260a54e584b16c48"), - TokenID: big.NewInt(1), - }, - } - for i, singleAssetData := range singleAssetDatas { - expectedSingleAssetData := expectedSingleAssetDatas[i] - assert.Equal(t, expectedSingleAssetData.Address, singleAssetData.Address) - assert.Equal(t, expectedSingleAssetData.TokenID, singleAssetData.TokenID) - } -} - -func TestTrimOrdersByExpirationTime(t *testing.T) { - meshDB, err := New("/tmp/meshdb_testing/"+uuid.New().String(), contractAddresses) - require.NoError(t, err) - defer meshDB.Close() - - // TODO(albrow): Move these to top of file. - makerAddress := constants.GanacheAccount0 - - // Note: most of the fields in these orders are the same. For the purposes of - // this test, the only thing that matters is the Salt and ExpirationTime. 
- rawUnpinnedOrders := []*zeroex.Order{ - { - MakerAddress: makerAddress, - TakerAddress: constants.NullAddress, - SenderAddress: constants.NullAddress, - FeeRecipientAddress: common.HexToAddress("0xa258b39954cef5cb142fd567a46cddb31a670124"), - TakerAssetData: common.Hex2Bytes("f47261b000000000000000000000000034d402f14d58e001d8efbe6585051bf9706aa064"), - MakerAssetData: common.Hex2Bytes("025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001"), - ChainID: big.NewInt(constants.TestChainID), - TakerFeeAssetData: constants.NullBytes, - MakerFeeAssetData: constants.NullBytes, - Salt: big.NewInt(0), - MakerFee: big.NewInt(0), - TakerFee: big.NewInt(0), - MakerAssetAmount: big.NewInt(3551808554499581700), - TakerAssetAmount: big.NewInt(1), - ExpirationTimeSeconds: big.NewInt(100), - ExchangeAddress: contractAddresses.Exchange, - }, - { - MakerAddress: makerAddress, - TakerAddress: constants.NullAddress, - SenderAddress: constants.NullAddress, - FeeRecipientAddress: common.HexToAddress("0xa258b39954cef5cb142fd567a46cddb31a670124"), - TakerAssetData: common.Hex2Bytes("f47261b000000000000000000000000034d402f14d58e001d8efbe6585051bf9706aa064"), - MakerAssetData: common.Hex2Bytes("025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001"), - ChainID: big.NewInt(constants.TestChainID), - TakerFeeAssetData: constants.NullBytes, - MakerFeeAssetData: constants.NullBytes, - Salt: big.NewInt(1), - MakerFee: big.NewInt(0), - TakerFee: big.NewInt(0), - MakerAssetAmount: big.NewInt(3551808554499581700), - TakerAssetAmount: big.NewInt(1), - ExpirationTimeSeconds: big.NewInt(200), - ExchangeAddress: contractAddresses.Exchange, - }, - { - MakerAddress: makerAddress, - TakerAddress: constants.NullAddress, - SenderAddress: constants.NullAddress, - FeeRecipientAddress: 
common.HexToAddress("0xa258b39954cef5cb142fd567a46cddb31a670124"), - TakerAssetData: common.Hex2Bytes("f47261b000000000000000000000000034d402f14d58e001d8efbe6585051bf9706aa064"), - MakerAssetData: common.Hex2Bytes("025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001"), - ChainID: big.NewInt(constants.TestChainID), - TakerFeeAssetData: constants.NullBytes, - MakerFeeAssetData: constants.NullBytes, - Salt: big.NewInt(2), - MakerFee: big.NewInt(0), - TakerFee: big.NewInt(0), - MakerAssetAmount: big.NewInt(3551808554499581700), - TakerAssetAmount: big.NewInt(1), - ExpirationTimeSeconds: big.NewInt(200), - ExchangeAddress: contractAddresses.Exchange, - }, - { - MakerAddress: makerAddress, - TakerAddress: constants.NullAddress, - SenderAddress: constants.NullAddress, - FeeRecipientAddress: common.HexToAddress("0xa258b39954cef5cb142fd567a46cddb31a670124"), - TakerAssetData: common.Hex2Bytes("f47261b000000000000000000000000034d402f14d58e001d8efbe6585051bf9706aa064"), - MakerAssetData: common.Hex2Bytes("025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001"), - ChainID: big.NewInt(constants.TestChainID), - TakerFeeAssetData: constants.NullBytes, - MakerFeeAssetData: constants.NullBytes, - Salt: big.NewInt(3), - MakerFee: big.NewInt(0), - TakerFee: big.NewInt(0), - MakerAssetAmount: big.NewInt(3551808554499581700), - TakerAssetAmount: big.NewInt(1), - ExpirationTimeSeconds: big.NewInt(300), - ExchangeAddress: contractAddresses.Exchange, - }, - } - rawPinnedOrders := []*zeroex.Order{ - { - MakerAddress: makerAddress, - TakerAddress: constants.NullAddress, - SenderAddress: constants.NullAddress, - FeeRecipientAddress: common.HexToAddress("0xa258b39954cef5cb142fd567a46cddb31a670124"), - TakerAssetData: common.Hex2Bytes("f47261b000000000000000000000000034d402f14d58e001d8efbe6585051bf9706aa064"), - MakerAssetData: 
common.Hex2Bytes("025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001"), - ChainID: big.NewInt(constants.TestChainID), - TakerFeeAssetData: constants.NullBytes, - MakerFeeAssetData: constants.NullBytes, - Salt: big.NewInt(0), - MakerFee: big.NewInt(0), - TakerFee: big.NewInt(0), - MakerAssetAmount: big.NewInt(3551808554499581700), - TakerAssetAmount: big.NewInt(1), - ExpirationTimeSeconds: big.NewInt(250), - ExchangeAddress: contractAddresses.Exchange, - }, - { - MakerAddress: makerAddress, - TakerAddress: constants.NullAddress, - SenderAddress: constants.NullAddress, - FeeRecipientAddress: common.HexToAddress("0xa258b39954cef5cb142fd567a46cddb31a670124"), - TakerAssetData: common.Hex2Bytes("f47261b000000000000000000000000034d402f14d58e001d8efbe6585051bf9706aa064"), - MakerAssetData: common.Hex2Bytes("025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001"), - ChainID: big.NewInt(constants.TestChainID), - TakerFeeAssetData: constants.NullBytes, - MakerFeeAssetData: constants.NullBytes, - Salt: big.NewInt(1), - MakerFee: big.NewInt(0), - TakerFee: big.NewInt(0), - MakerAssetAmount: big.NewInt(3551808554499581700), - TakerAssetAmount: big.NewInt(1), - ExpirationTimeSeconds: big.NewInt(350), - ExchangeAddress: contractAddresses.Exchange, - }, - } - - insertRawOrders(t, meshDB, rawUnpinnedOrders, false) - pinnedOrders := insertRawOrders(t, meshDB, rawPinnedOrders, true) - - // Call CalculateNewMaxExpirationTimeAndTrimDatabase and check the results. 
- targetMaxOrders := 4 - gotExpirationTime, gotRemovedOrders, err := meshDB.TrimOrdersByExpirationTime(targetMaxOrders) - require.NoError(t, err) - assert.Equal(t, "199", gotExpirationTime.String(), "newMaxExpirationTime") - assert.Len(t, gotRemovedOrders, 2, "wrong number of orders removed") - - // Check that the expiration time of each removed order is >= the new max. - for _, removedOrder := range gotRemovedOrders { - expirationTimeOfRemovedOrder := removedOrder.SignedOrder.ExpirationTimeSeconds - assert.True(t, expirationTimeOfRemovedOrder.Cmp(gotExpirationTime) != -1, "an order was removed with expiration time (%s) less than the new max (%s)", expirationTimeOfRemovedOrder, gotExpirationTime) - } - var remainingOrders []*Order - require.NoError(t, meshDB.Orders.FindAll(&remainingOrders)) - assert.Len(t, remainingOrders, 4, "wrong number of orders remaining") - - // Check that the expiration time of each remaining order is <= the new max. - for _, remainingOrder := range remainingOrders { - if !remainingOrder.IsPinned { - // Unpinned orders should not have an expiration time greater than the - // new max. - expirationTimeOfRemainingOrder := remainingOrder.SignedOrder.ExpirationTimeSeconds - newMaxPlusOne := big.NewInt(0).Add(gotExpirationTime, big.NewInt(1)) - assert.True(t, expirationTimeOfRemainingOrder.Cmp(newMaxPlusOne) != 1, "a remaining order had an expiration time (%s) greater than the new max + 1 (%s)", expirationTimeOfRemainingOrder, newMaxPlusOne) - } - } - - // Check that the pinned orders are still in the database. - for _, pinnedOrder := range pinnedOrders { - require.NoError(t, meshDB.Orders.FindByID(pinnedOrder.Hash.Bytes(), &Order{})) - } - - // Trying to trim orders when the database is full of pinned orders should - // return an error. 
- _, _, err = meshDB.TrimOrdersByExpirationTime(1) - assert.EqualError(t, err, ErrDBFilledWithPinnedOrders.Error(), "expected ErrFilledWithPinnedOrders when targetMaxOrders is less than the number of pinned orders") -} - -func TestFindOrdersByMakerAddressMakerFeeAssetAddressTokenID(t *testing.T) { - meshDB, err := New("/tmp/meshdb_testing/"+uuid.New().String(), contractAddresses) - require.NoError(t, err) - defer meshDB.Close() - - makerAddress := constants.GanacheAccount0 - nextSalt := big.NewInt(1548619145450) - - zeroexOrders := []*zeroex.Order{ - // No Maker fee - &zeroex.Order{ - ChainID: big.NewInt(constants.TestChainID), - ExchangeAddress: contractAddresses.Exchange, - MakerAddress: makerAddress, - TakerAddress: constants.NullAddress, - SenderAddress: constants.NullAddress, - FeeRecipientAddress: common.HexToAddress("0xa258b39954cef5cb142fd567a46cddb31a670124"), - TakerAssetData: common.Hex2Bytes("f47261b000000000000000000000000034d402f14d58e001d8efbe6585051bf9706aa064"), - TakerFeeAssetData: constants.NullBytes, - MakerAssetData: common.Hex2Bytes("025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001"), - MakerFeeAssetData: constants.NullBytes, - Salt: nextSalt.Add(nextSalt, big.NewInt(1)), - MakerFee: big.NewInt(0), - TakerFee: big.NewInt(0), - MakerAssetAmount: big.NewInt(3551808554499581700), - TakerAssetAmount: big.NewInt(1), - ExpirationTimeSeconds: big.NewInt(1548619325), - }, - // ERC20 maker fee - &zeroex.Order{ - ChainID: big.NewInt(constants.TestChainID), - ExchangeAddress: contractAddresses.Exchange, - MakerAddress: makerAddress, - TakerAddress: constants.NullAddress, - SenderAddress: constants.NullAddress, - FeeRecipientAddress: common.HexToAddress("0xa258b39954cef5cb142fd567a46cddb31a670124"), - TakerAssetData: common.Hex2Bytes("f47261b000000000000000000000000034d402f14d58e001d8efbe6585051bf9706aa064"), - TakerFeeAssetData: constants.NullBytes, - MakerAssetData: 
common.Hex2Bytes("025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001"), - MakerFeeAssetData: common.Hex2Bytes("f47261b000000000000000000000000038ae374ecf4db50b0ff37125b591a04997106a32"), - Salt: nextSalt.Add(nextSalt, big.NewInt(1)), - MakerFee: big.NewInt(0), - TakerFee: big.NewInt(0), - MakerAssetAmount: big.NewInt(3551808554499581700), - TakerAssetAmount: big.NewInt(1), - ExpirationTimeSeconds: big.NewInt(1548619325), - }, - // ERC721 maker fee with token id = 1 - &zeroex.Order{ - ChainID: big.NewInt(constants.TestChainID), - ExchangeAddress: contractAddresses.Exchange, - MakerAddress: makerAddress, - TakerAddress: constants.NullAddress, - SenderAddress: constants.NullAddress, - FeeRecipientAddress: common.HexToAddress("0xa258b39954cef5cb142fd567a46cddb31a670124"), - TakerAssetData: common.Hex2Bytes("f47261b000000000000000000000000034d402f14d58e001d8efbe6585051bf9706aa064"), - TakerFeeAssetData: constants.NullBytes, - MakerAssetData: common.Hex2Bytes("025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001"), - MakerFeeAssetData: common.Hex2Bytes("025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001"), - Salt: nextSalt.Add(nextSalt, big.NewInt(1)), - MakerFee: big.NewInt(0), - TakerFee: big.NewInt(0), - MakerAssetAmount: big.NewInt(3551808554499581700), - TakerAssetAmount: big.NewInt(1), - ExpirationTimeSeconds: big.NewInt(1548619325), - }, - // ERC721 maker fee with token id = 2 - &zeroex.Order{ - ChainID: big.NewInt(constants.TestChainID), - ExchangeAddress: contractAddresses.Exchange, - MakerAddress: makerAddress, - TakerAddress: constants.NullAddress, - SenderAddress: constants.NullAddress, - FeeRecipientAddress: common.HexToAddress("0xa258b39954cef5cb142fd567a46cddb31a670124"), - TakerAssetData: 
common.Hex2Bytes("f47261b000000000000000000000000034d402f14d58e001d8efbe6585051bf9706aa064"), - TakerFeeAssetData: constants.NullBytes, - MakerAssetData: common.Hex2Bytes("025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000001"), - MakerFeeAssetData: common.Hex2Bytes("025717920000000000000000000000001dc4c1cefef38a777b15aa20260a54e584b16c480000000000000000000000000000000000000000000000000000000000000002"), - Salt: nextSalt.Add(nextSalt, big.NewInt(1)), - MakerFee: big.NewInt(0), - TakerFee: big.NewInt(0), - MakerAssetAmount: big.NewInt(3551808554499581700), - TakerAssetAmount: big.NewInt(1), - ExpirationTimeSeconds: big.NewInt(1548619325), - }, - } - orders := make([]*Order, len(zeroexOrders)) - for i, o := range zeroexOrders { - signedOrder, err := zeroex.SignTestOrder(o) - require.NoError(t, err) - orderHash, err := o.ComputeOrderHash() - require.NoError(t, err) - - orders[i] = &Order{ - Hash: orderHash, - SignedOrder: signedOrder, - FillableTakerAssetAmount: big.NewInt(1), - LastUpdated: time.Now().UTC(), - IsRemoved: false, - } - require.NoError(t, meshDB.Orders.Insert(orders[i])) - // We need to call ResetHash so that unexported hash field is equal in later - // assertions. - signedOrder.ResetHash() - } - - testCases := []struct { - makerFeeAssetAddress common.Address - makerFeeTokenID *big.Int - expectedOrders []*Order - }{ - { - makerFeeAssetAddress: constants.NullAddress, - makerFeeTokenID: nil, - expectedOrders: orders[0:1], - }, - { - makerFeeAssetAddress: common.HexToAddress("0x38ae374ecf4db50b0ff37125b591a04997106a32"), - makerFeeTokenID: nil, - expectedOrders: orders[1:2], - }, - { - // Since no token id was specified, this query should match all token ids. 
- makerFeeAssetAddress: common.HexToAddress("0x1dc4c1cefef38a777b15aa20260a54e584b16c48"), - makerFeeTokenID: nil, - expectedOrders: orders[2:4], - }, - { - makerFeeAssetAddress: common.HexToAddress("0x1dc4c1cefef38a777b15aa20260a54e584b16c48"), - makerFeeTokenID: big.NewInt(1), - expectedOrders: orders[2:3], - }, - { - makerFeeAssetAddress: common.HexToAddress("0x1dc4c1cefef38a777b15aa20260a54e584b16c48"), - makerFeeTokenID: big.NewInt(2), - expectedOrders: orders[3:4], - }, - } - for i, tc := range testCases { - foundOrders, err := meshDB.FindOrdersByMakerAddressMakerFeeAssetAddressAndTokenID(makerAddress, tc.makerFeeAssetAddress, tc.makerFeeTokenID) - require.NoError(t, err) - assert.Equal(t, tc.expectedOrders, foundOrders, "test case %d", i) - } -} - -func insertRawOrders(t *testing.T, meshDB *MeshDB, rawOrders []*zeroex.Order, isPinned bool) []*Order { - results := make([]*Order, len(rawOrders)) - for i, order := range rawOrders { - // Sign, compute order hash, and insert. - signedOrder, err := zeroex.SignTestOrder(order) - require.NoError(t, err) - orderHash, err := order.ComputeOrderHash() - require.NoError(t, err) - - order := &Order{ - Hash: orderHash, - SignedOrder: signedOrder, - FillableTakerAssetAmount: big.NewInt(1), - LastUpdated: time.Now(), - IsRemoved: false, - IsPinned: isPinned, - } - results[i] = order - require.NoError(t, meshDB.Orders.Insert(order)) - } - return results -} - -func TestPruneMiniHeadersAboveRetentionLimit(t *testing.T) { - t.Parallel() - - meshDB, err := New("/tmp/meshdb_testing/"+uuid.New().String(), contractAddresses) - require.NoError(t, err) - defer meshDB.Close() - - txn := meshDB.MiniHeaders.OpenTransaction() - defer func() { - _ = txn.Discard() - }() - - miniHeadersToAdd := miniHeadersMaxPerPage*2 + defaultMiniHeaderRetentionLimit + 1 - for i := 0; i < miniHeadersToAdd; i++ { - miniHeader := &miniheader.MiniHeader{ - Hash: common.BigToHash(big.NewInt(int64(i))), - Number: big.NewInt(int64(i)), - Timestamp: 
time.Now().Add(time.Duration(i)*time.Second - 5*time.Hour), - } - require.NoError(t, txn.Insert(miniHeader)) - } - require.NoError(t, txn.Commit()) - - require.NoError(t, meshDB.PruneMiniHeadersAboveRetentionLimit()) - remainingMiniHeaders, err := meshDB.MiniHeaders.Count() - assert.Equal(t, defaultMiniHeaderRetentionLimit, remainingMiniHeaders, "wrong number of MiniHeaders remaining") -} diff --git a/orderfilter/filter_js.go b/orderfilter/filter_js.go index 4a9c1de7d..7024ce9b0 100644 --- a/orderfilter/filter_js.go +++ b/orderfilter/filter_js.go @@ -24,13 +24,13 @@ func New(chainID int, customOrderSchema string, contractAddresses ethereum.Contr chainIDSchema := fmt.Sprintf(`{"$id": "/chainId", "const":%d}`, chainID) exchangeAddressSchema := fmt.Sprintf(`{"$id": "/exchangeAddress", "enum":[%q,%q]}`, contractAddresses.Exchange.Hex(), strings.ToLower(contractAddresses.Exchange.Hex())) - if jsutil.IsNullOrUndefined(js.Global().Get("createSchemaValidator")) { - return nil, errors.New(`"createSchemaValidator" has not been set on the Javascript "global" object`) + if jsutil.IsNullOrUndefined(js.Global().Get("__mesh_createSchemaValidator__")) { + return nil, errors.New(`"__mesh_createSchemaValidator__" has not been set on the Javascript "global" object`) } // NOTE(jalextowle): The order of the schemas within the two arrays // defines their order of compilation. 
schemaValidator := js.Global().Call( - "createSchemaValidator", + "__mesh_createSchemaValidator__", customOrderSchema, []interface{}{ addressSchema, diff --git a/packages/browser-lite/package.json b/packages/browser-lite/package.json index 3980571ea..de3137ea6 100644 --- a/packages/browser-lite/package.json +++ b/packages/browser-lite/package.json @@ -14,18 +14,19 @@ "config": { "docsPath": "../../docs/browser-bindings/browser-lite" }, + "devDependencies": { + "@types/dexie": "^1.3.1", + "@0x/ts-doc-gen": "^0.0.16", + "shx": "^0.3.2", + "typedoc": "^0.15.0", + "typescript": "^3.9.3" + }, "dependencies": { "@0x/order-utils": "^10.2.0", "@0x/utils": "^5.4.0", "ajv": "^6.12.2", "base64-arraybuffer": "^0.2.0", - "browserfs": "^1.4.3", + "dexie": "^3.0.1", "ethereum-types": "^3.0.0" - }, - "devDependencies": { - "@0x/ts-doc-gen": "^0.0.16", - "shx": "^0.3.2", - "typedoc": "^0.15.0", - "typescript": "^3.9.3" } } diff --git a/packages/browser-lite/src/database.ts b/packages/browser-lite/src/database.ts new file mode 100644 index 000000000..7378ebe48 --- /dev/null +++ b/packages/browser-lite/src/database.ts @@ -0,0 +1,577 @@ +// tslint:disable:max-file-line-count + +/** + * @hidden + */ + +/** + * NOTE(jalextowle): This comment must be here so that typedoc knows that the above + * comment is a module comment + */ + +import Dexie from 'dexie'; + +export type Record = Order | MiniHeader | Metadata; + +export interface Options { + dataSourceName: string; + maxOrders: number; + maxMiniHeaders: number; +} + +export interface Query { + filters?: Array>; + sort?: Array>; + limit?: number; + offset?: number; +} + +export interface SortOption { + field: Extract; + direction: SortDirection; +} + +export interface FilterOption { + field: Extract; + kind: FilterKind; + value: any; +} + +export enum SortDirection { + Asc = 'ASC', + Desc = 'DESC', +} + +export enum FilterKind { + Equal = '=', + NotEqual = '!=', + Less = '<', + Greater = '>', + LessOrEqual = '<=', + GreaterOrEqual = 
'>=', + Contains = 'CONTAINS', +} + +export interface Order { + hash: string; + chainId: number; + makerAddress: string; + makerAssetData: string; + makerAssetAmount: string; + makerFee: string; + makerFeeAssetData: string; + takerAddress: string; + takerAssetData: string; + takerFeeAssetData: string; + takerAssetAmount: string; + takerFee: string; + senderAddress: string; + feeRecipientAddress: string; + expirationTimeSeconds: string; + salt: string; + signature: string; + exchangeAddress: string; + fillableTakerAssetAmount: string; + lastUpdated: string; + isRemoved: number; + isPinned: number; + parsedMakerAssetData: string; + parsedMakerFeeAssetData: string; +} + +export type OrderField = keyof Order; + +export type OrderQuery = Query; + +export type OrderSort = SortOption; + +export type OrderFilter = FilterOption; + +export interface AddOrdersResult { + added: Order[]; + removed: Order[]; +} + +export interface MiniHeader { + hash: string; + parent: string; + number: string; + timestamp: string; + logs: string; +} + +export type MiniHeaderField = keyof MiniHeader; + +export type MiniHeaderQuery = Query; + +export type MiniHeaderSort = SortOption; + +export type MiniHeaderFilter = FilterOption; + +export interface AddMiniHeadersResult { + added: MiniHeader[]; + removed: MiniHeader[]; +} + +export interface Metadata { + ethereumChainID: number; + maxExpirationTime: string; + ethRPCRequestsSentInCurrentUTCDay: number; + startOfCurrentUTCDay: string; +} + +function newNotFoundError(): Error { + return new Error('could not find existing model or row in database'); +} + +function newMetadataAlreadExistsError(): Error { + return new Error('metadata already exists in the database (use UpdateMetadata instead?)'); +} + +/** + * Creates and returns a new database + * + * @param opts The options to use for the database + */ +export function createDatabase(opts: Options): Database { + return new Database(opts); +} + +export class Database { + private readonly _db: Dexie; 
+ // private readonly _maxOrders: number; + private readonly _maxMiniHeaders: number; + private readonly _orders: Dexie.Table; + private readonly _miniHeaders: Dexie.Table; + private readonly _metadata: Dexie.Table; + + constructor(opts: Options) { + this._db = new Dexie(opts.dataSourceName); + // this._maxOrders = opts.maxOrders; + this._maxMiniHeaders = opts.maxMiniHeaders; + + this._db.version(1).stores({ + orders: + '&hash,chainId,makerAddress,makerAssetData,makerAssetAmount,makerFee,makerFeeAssetData,takerAddress,takerAssetData,takerFeeAssetData,takerAssetAmount,takerFee,senderAddress,feeRecipientAddress,expirationTimeSeconds,salt,signature,exchangeAddress,fillableTakerAssetAmount,lastUpdated,isRemoved,isPinned,parsedMakerAssetData,parsedMakerFeeAssetData', + miniHeaders: '&hash,parent,number,timestamp,logs', + metadata: 'ðereumChainID', + }); + + this._orders = this._db.table('orders'); + this._miniHeaders = this._db.table('miniHeaders'); + this._metadata = this._db.table('metadata'); + } + + public close(): void { + this._db.close(); + } + + // AddOrders(orders []*types.OrderWithMetadata) (added []*types.OrderWithMetadata, removed []*types.OrderWithMetadata, err error) + public async addOrdersAsync(orders: Order[]): Promise { + // TODO(albrow): Remove orders with max expiration time. + const added: Order[] = []; + await this._db.transaction('rw', this._orders, async () => { + for (const order of orders) { + try { + await this._orders.add(order); + } catch (e) { + if (e.name === 'ConstraintError') { + // An order with this hash already exists. This is fine based on the semantics of + // addOrders. 
+ continue; + } + throw e; + } + added.push(order); + } + }); + return { + added, + removed: [], + }; + } + + // GetOrder(hash common.Hash) (*types.OrderWithMetadata, error) + public async getOrderAsync(hash: string): Promise { + const order = await this._orders.get(hash); + if (order === undefined) { + throw newNotFoundError(); + } + return order; + } + + // FindOrders(opts *OrderQuery) ([]*types.OrderWithMetadata, error) + public async findOrdersAsync(query?: OrderQuery): Promise { + if (!canUseNativeDexieIndexes(this._orders, query)) { + // As a fallback, implement the query inefficiently (in-memory). + // Note(albrow): If needed we can optimize specific common queries with compound indexes. + return runQueryInMemoryAsync(this._orders, query); + } + const col = buildCollectionWithDexieIndexes(this._orders, query); + return col.toArray(); + } + + // CountOrders(opts *OrderQuery) (int, error) + public async countOrdersAsync(query?: OrderQuery): Promise { + if (!canUseNativeDexieIndexes(this._orders, query)) { + // As a fallback, implement the query inefficiently (in-memory). + // Note(albrow): If needed we can optimize specific common queries with compound indexes. 
+ const records = await runQueryInMemoryAsync(this._orders, query); + return records.length; + } + const col = buildCollectionWithDexieIndexes(this._orders, query); + return col.count(); + } + + // DeleteOrder(hash common.Hash) error + public async deleteOrderAsync(hash: string): Promise { + return this._orders.delete(hash); + } + + // DeleteOrders(opts *OrderQuery) ([]*types.OrderWithMetadata, error) + public async deleteOrdersAsync(query: OrderQuery | undefined): Promise { + const deletedOrders: Order[] = []; + await this._db.transaction('rw', this._orders, async () => { + const orders = await this.findOrdersAsync(query); + for (const order of orders) { + await this._orders.delete(order.hash); + deletedOrders.push(order); + } + }); + return deletedOrders; + } + + // UpdateOrder(hash common.Hash, updateFunc func(existingOrder *types.OrderWithMetadata) (updatedOrder *types.OrderWithMetadata, err error)) error + public async updateOrderAsync(hash: string, updateFunc: (existingOrder: Order) => Order): Promise { + await this._db.transaction('rw', this._orders, async () => { + const existingOrder = await this.getOrderAsync(hash); + const updatedOrder = updateFunc(existingOrder); + await this._orders.put(updatedOrder, hash); + }); + } + + // AddMiniHeaders(miniHeaders []*types.MiniHeader) (added []*types.MiniHeader, removed []*types.MiniHeader, err error) + public async addMiniHeadersAsync(miniHeaders: MiniHeader[]): Promise { + const added: MiniHeader[] = []; + const removed: MiniHeader[] = []; + await this._db.transaction('rw', this._miniHeaders, async () => { + for (const miniHeader of miniHeaders) { + try { + await this._miniHeaders.add(miniHeader); + } catch (e) { + if (e.name === 'ConstraintError') { + // A miniHeader with this hash already exists. This is fine based on the semantics of + // addMiniHeaders. 
+ continue; + } + throw e; + } + added.push(miniHeader); + const outdatedMiniHeaders = await this._miniHeaders + .orderBy('number') + .offset(this._maxMiniHeaders) + .reverse() + .toArray(); + for (const outdated of outdatedMiniHeaders) { + await this._miniHeaders.delete(outdated.hash); + removed.push(outdated); + } + } + }); + return { + added, + removed, + }; + } + + // GetMiniHeader(hash common.Hash) (*types.MiniHeader, error) + public async getMiniHeaderAsync(hash: string): Promise { + const miniHeader = await this._miniHeaders.get(hash); + if (miniHeader === undefined) { + throw newNotFoundError(); + } + return miniHeader; + } + + // FindMiniHeaders(opts *MiniHeaderQuery) ([]*types.MiniHeader, error) + public async findMiniHeadersAsync(query: MiniHeaderQuery): Promise { + if (!canUseNativeDexieIndexes(this._miniHeaders, query)) { + // As a fallback, implement the query inefficiently (in-memory). + // Note(albrow): If needed we can optimize specific common queries with compound indexes. 
+ return runQueryInMemoryAsync(this._miniHeaders, query); + } + const col = buildCollectionWithDexieIndexes(this._miniHeaders, query); + return col.toArray(); + } + + // DeleteMiniHeader(hash common.Hash) error + public async deleteMiniHeaderAsync(hash: string): Promise { + return this._miniHeaders.delete(hash); + } + + // DeleteMiniHeaders(opts *MiniHeaderQuery) ([]*types.MiniHeader, error) + public async deleteMiniHeadersAsync(query: MiniHeaderQuery): Promise { + const deletedMiniHeaders: MiniHeader[] = []; + await this._db.transaction('rw', this._miniHeaders, async () => { + const miniHeaders = await this.findMiniHeadersAsync(query); + for (const miniHeader of miniHeaders) { + await this._miniHeaders.delete(miniHeader.hash); + deletedMiniHeaders.push(miniHeader); + } + }); + return deletedMiniHeaders; + } + + // GetMetadata() (*types.Metadata, error) + public async getMetadataAsync(): Promise { + const count = await this._metadata.count(); + if (count === 0) { + throw newNotFoundError(); + } else if (count > 1) { + // This should never happen, but it's possible if a user manually messed around with + // IndexedDB. In this case, just delete the metadata table and we should start + // over. 
+ await this._metadata.clear(); + throw new Error('more than one metadata entry stored in the database'); + } + const metadatas = await this._metadata.toArray(); + return metadatas[0]; + } + + // SaveMetadata(metadata *types.Metadata) error + public async saveMetadataAsync(metadata: Metadata): Promise { + await this._db.transaction('rw', this._metadata, async () => { + if ((await this._metadata.count()) > 0) { + throw newMetadataAlreadExistsError(); + } + await this._metadata.add(metadata); + }); + } + + // UpdateMetadata(updateFunc func(oldmetadata *types.Metadata) (newMetadata *types.Metadata)) error + public async updateMetadataAsync(updateFunc: (existingMetadata: Metadata) => Metadata): Promise { + await this._db.transaction('rw', this._metadata, async () => { + const existingMetadata = await this.getMetadataAsync(); + const updatedMetadata = updateFunc(existingMetadata); + await this._metadata.put(updatedMetadata); + }); + } +} + +function buildCollectionWithDexieIndexes( + table: Dexie.Table, + query?: Query, +): Dexie.Collection { + if (query === null || query === undefined) { + return table.toCollection(); + } + + // First we create the Collection based on the query fields. 
+ let col: Dexie.Collection; + if (queryUsesFilters(query)) { + // tslint:disable-next-line:no-non-null-assertion + const filter = query.filters![0]; + switch (filter.kind) { + case FilterKind.Equal: + col = table.where(filter.field).equals(filter.value); + break; + case FilterKind.NotEqual: + col = table.where(filter.field).notEqual(filter.value); + break; + case FilterKind.Greater: + col = table.where(filter.field).above(filter.value); + break; + case FilterKind.GreaterOrEqual: + col = table.where(filter.field).aboveOrEqual(filter.value); + break; + case FilterKind.Less: + col = table.where(filter.field).below(filter.value); + break; + case FilterKind.LessOrEqual: + col = table.where(filter.field).belowOrEqual(filter.value); + break; + case FilterKind.Contains: + // Note(albrow): This iterates through all orders and is very inefficient. + // If needed, we should try to find a way to optimize this. + col = table.filter(containsFilterFunc(filter)); + break; + default: + throw new Error(`unexpected filter kind: ${filter.kind}`); + } + // tslint:disable-next-line:no-non-null-assertion + if (queryUsesSortOptions(query) && query.sort![0].direction === SortDirection.Desc) { + // Note(albrow): This is only allowed if the sort and filter are using + // the same field. Dexie automatically returns records sorted by the filter + // field. If the direction is Ascending, we don't need to do anything else. + // If it the direction is Descending, we just need to call reverse(). + col.reverse(); + } + } else if (queryUsesSortOptions(query)) { + // tslint:disable-next-line:no-non-null-assertion + const sortOpt = query.sort![0]; + col = table.orderBy(sortOpt.field); + if (sortOpt.direction === SortDirection.Desc) { + col = col.reverse(); + } + } else { + // Query doesn't use filter or sort options. 
+ col = table.toCollection(); + } + if (queryUsesOffset(query)) { + // tslint:disable-next-line:no-non-null-assertion + col.offset(query.offset!); + } + if (queryUsesLimit(query)) { + // tslint:disable-next-line:no-non-null-assertion + col.limit(query.limit!); + } + return col; +} + +async function runQueryInMemoryAsync( + table: Dexie.Table, + query?: Query, +): Promise { + let records = await table.toArray(); + if (query === undefined || query === null) { + return records; + } + if (queryUsesFilters(query)) { + // tslint:disable-next-line:no-non-null-assertion + records = filterRecords(query.filters!, records); + } + if (queryUsesSortOptions(query)) { + // tslint:disable-next-line:no-non-null-assertion + records = sortRecords(query.sort!, records); + } + if (queryUsesOffset(query) && queryUsesLimit(query)) { + // tslint:disable-next-line:no-non-null-assertion + records = records.slice(query.offset!, query.limit!); + } else if (queryUsesLimit(query)) { + // tslint:disable-next-line:no-non-null-assertion + records = records.slice(0, query.limit!); + } else if (queryUsesOffset(query)) { + // tslint:disable-next-line:no-non-null-assertion + records = records.slice(query.offset!); + } + + return records; +} + +function filterRecords(filters: Array>, records: T[]): T[] { + let result = records; + // Note(albrow): As an optimization, we could use the native Dexie.js index for + // the *first* filter when possible. 
+ for (const filter of filters) { + switch (filter.kind) { + case FilterKind.Equal: + result = result.filter(record => record[filter.field] === filter.value); + break; + case FilterKind.NotEqual: + result = result.filter(record => record[filter.field] !== filter.value); + break; + case FilterKind.Greater: + result = result.filter(record => record[filter.field] > filter.value); + break; + case FilterKind.GreaterOrEqual: + result = result.filter(record => record[filter.field] >= filter.value); + break; + case FilterKind.Less: + result = result.filter(record => record[filter.field] < filter.value); + break; + case FilterKind.LessOrEqual: + result = result.filter(record => record[filter.field] <= filter.value); + break; + case FilterKind.Contains: + result = result.filter(containsFilterFunc(filter)); + break; + default: + throw new Error(`unexpected filter kind: ${filter.kind}`); + } + } + + return result; +} + +function sortRecords(sortOpts: Array>, records: T[]): T[] { + // Note(albrow): As an optimization, we could use native Dexie.js ordering for + // the *first* sort option when possible. 
+ const result = records; + return result.sort((a: T, b: T) => { + for (const s of sortOpts) { + switch (s.direction) { + case SortDirection.Asc: + if (a[s.field] < b[s.field]) { + return -1; + } else if (a[s.field] > b[s.field]) { + return 1; + } + break; + case SortDirection.Desc: + if (a[s.field] > b[s.field]) { + return -1; + } else if (a[s.field] < b[s.field]) { + return 1; + } + break; + default: + throw new Error(`unexpected sort direction: ${s.direction}`); + } + } + return 0; + }); +} + +function isString(x: any): x is string { + return typeof x === 'string'; +} + +function containsFilterFunc(filter: FilterOption): (record: T) => boolean { + return (record: T): boolean => { + const field = record[filter.field]; + if (!isString(field)) { + throw new Error( + `cannot use CONTAINS filter on non-string field ${filter.field} of type ${typeof record[filter.field]}`, + ); + } + return field.includes(filter.value); + }; +} + +function canUseNativeDexieIndexes(table: Dexie.Table, query?: Query): boolean { + if (query === null || query === undefined) { + return true; + } + // tslint:disable-next-line:no-non-null-assertion + if (queryUsesSortOptions(query) && query.sort!.length > 1) { + // Dexie does not support multiple sort orders. + return false; + } + // tslint:disable-next-line:no-non-null-assertion + if (queryUsesFilters(query) && query.filters!.length > 1) { + // Dexie does not support multiple filters. + return false; + } + // tslint:disable-next-line:no-non-null-assertion + if (queryUsesFilters(query) && queryUsesSortOptions(query) && query.filters![0].field !== query.sort![0].field) { + // Dexie does not support sorting and filtering by two different fields. 
+ return false; + } + return true; +} + +function queryUsesSortOptions(query: Query): boolean { + return query.sort !== null && query.sort !== undefined && query.sort.length > 0; +} + +function queryUsesFilters(query: Query): boolean { + return query.filters !== null && query.filters !== undefined && query.filters.length > 0; +} + +function queryUsesLimit(query: Query): boolean { + return query.limit !== null && query.limit !== undefined && query.limit !== 0; +} + +function queryUsesOffset(query: Query): boolean { + return query.offset !== null && query.offset !== undefined && query.offset !== 0; +} diff --git a/packages/browser-lite/src/mesh.ts b/packages/browser-lite/src/mesh.ts index 5d9c97d4a..a38087767 100644 --- a/packages/browser-lite/src/mesh.ts +++ b/packages/browser-lite/src/mesh.ts @@ -1,15 +1,10 @@ import { SignedOrder } from '@0x/order-utils'; -import * as BrowserFS from 'browserfs'; +import { createDatabase } from './database'; import { createSchemaValidator } from './schema_validator'; -import './wasm_exec'; - -export { SignedOrder } from '@0x/order-utils'; -export { BigNumber } from '@0x/utils'; -export { SupportedProvider } from 'ethereum-types'; - import { AcceptedOrderInfo, + BigNumber, Config, ContractAddresses, ContractEvent, @@ -35,6 +30,7 @@ import { RejectedOrderKind, RejectedOrderStatus, Stats, + SupportedProvider, ValidationResults, Verbosity, WethDepositEvent, @@ -42,6 +38,7 @@ import { WrapperOrderEvent, ZeroExMesh, } from './types'; +import './wasm_exec'; import { configToWrapperConfig, orderEventsHandlerToWrapperOrderEventsHandler, @@ -53,6 +50,7 @@ import { export { AcceptedOrderInfo, + BigNumber, Config, ContractAddresses, ContractEvent, @@ -76,6 +74,8 @@ export { RejectedOrderInfo, RejectedOrderKind, RejectedOrderStatus, + SignedOrder, + SupportedProvider, Stats, ValidationResults, Verbosity, @@ -96,27 +96,19 @@ declare global { const zeroExMesh: ZeroExMesh; } -// We use the global willLoadBrowserFS variable to signal that we are 
going to -// initialize BrowserFS. -(window as any).willLoadBrowserFS = true; +/** + * Sets the required global variables the Mesh needs to access from Go land. + * This includes the `db` and `orderfilter` packages. + * + * @ignore + */ +export function _setGlobals(): void { + (window as any).__mesh_createSchemaValidator__ = createSchemaValidator; + (window as any).__mesh_dexie_newDatabase__ = createDatabase; +} -BrowserFS.configure( - { - fs: 'IndexedDB', - options: { - storeName: '0x-mesh-db', - }, - }, - e => { - if (e) { - throw e; - } - // We use the global browserFS variable as a handle for Go/Wasm code to - // call into the BrowserFS API. Setting this variable also indicates - // that BrowserFS has finished loading. - (window as any).browserFS = BrowserFS.BFSRequire('fs'); - }, -); +// We immediately want to set the required globals. +_setGlobals(); // The interval (in milliseconds) to check whether Wasm is done loading. const wasmLoadCheckIntervalMs = 100; @@ -128,8 +120,6 @@ window.addEventListener(loadEventName, () => { isWasmLoaded = true; }); -(window as any).createSchemaValidator = createSchemaValidator; - /** * The main class for this package. Has methods for receiving order events and * sending orders through the 0x Mesh network. @@ -227,27 +217,21 @@ export class Mesh { return Promise.reject(new Error('Mesh is still loading. Try again soon.')); } - let snapshotID = ''; // New snapshot - // TODO(albrow): De-dupe this code with the method by the same name // in the TypeScript RPC client. 
- let page = 0; - let getOrdersResponse = await this.getOrdersForPageAsync(page, perPage, snapshotID); - snapshotID = getOrdersResponse.snapshotID; + let getOrdersResponse = await this.getOrdersForPageAsync(perPage); let ordersInfos = getOrdersResponse.ordersInfos; - let allOrderInfos: OrderInfo[] = []; do { allOrderInfos = [...allOrderInfos, ...ordersInfos]; - page++; - getOrdersResponse = await this.getOrdersForPageAsync(page, perPage, snapshotID); + const minOrderHash = ordersInfos[ordersInfos.length - 1].orderHash; + getOrdersResponse = await this.getOrdersForPageAsync(perPage, minOrderHash); ordersInfos = getOrdersResponse.ordersInfos; } while (ordersInfos.length > 0); getOrdersResponse = { - snapshotID, - snapshotTimestamp: getOrdersResponse.snapshotTimestamp, + timestamp: getOrdersResponse.timestamp, ordersInfos: allOrderInfos, }; return getOrdersResponse; @@ -255,12 +239,11 @@ export class Mesh { /** * Get page of 0x signed orders stored on the Mesh node at the specified snapshot - * @param page Page index at which to retrieve orders * @param perPage Number of signedOrders to fetch per paginated request - * @param snapshotID The DB snapshot at which to fetch orders. If omitted, a new snapshot is created - * @returns the snapshotID, snapshotTimestamp and all orders, their hashes and fillableTakerAssetAmounts + * @param minOrderHash The minimum order hash for the returned orders. Should be set based on the last hash from the previous response. 
+ * @returns Up to perPage orders with hash greater than minOrderHash, including order hashes and fillableTakerAssetAmounts */ - public async getOrdersForPageAsync(page: number, perPage: number, snapshotID?: string): Promise { + public async getOrdersForPageAsync(perPage: number, minOrderHash?: string): Promise { await waitForLoadAsync(); if (this._wrapper === undefined) { // If this is called after startAsync, this._wrapper is always @@ -269,7 +252,7 @@ export class Mesh { return Promise.reject(new Error('Mesh is still loading. Try again soon.')); } - const wrapperOrderResponse = await this._wrapper.getOrdersForPageAsync(page, perPage, snapshotID); + const wrapperOrderResponse = await this._wrapper.getOrdersForPageAsync(perPage, minOrderHash); return wrapperGetOrdersResponseToGetOrdersResponse(wrapperOrderResponse); } diff --git a/packages/browser-lite/src/types.ts b/packages/browser-lite/src/types.ts index 23d815a12..c7b7caf85 100644 --- a/packages/browser-lite/src/types.ts +++ b/packages/browser-lite/src/types.ts @@ -8,14 +8,12 @@ export { SupportedProvider } from 'ethereum-types'; /** @ignore */ export interface WrapperGetOrdersResponse { - snapshotID: string; - snapshotTimestamp: string; + timestamp: string; ordersInfos: WrapperOrderInfo[]; } export interface GetOrdersResponse { - snapshotID: string; - snapshotTimestamp: number; + timestamp: number; ordersInfos: OrderInfo[]; } @@ -231,7 +229,7 @@ export interface MeshWrapper { onError(handler: (err: Error) => void): void; onOrderEvents(handler: (events: WrapperOrderEvent[]) => void): void; getStatsAsync(): Promise; - getOrdersForPageAsync(page: number, perPage: number, snapshotID?: string): Promise; + getOrdersForPageAsync(perPage: number, minOrderHash?: string): Promise; addOrdersAsync(orders: WrapperSignedOrder[], pinned: boolean): Promise; } diff --git a/packages/browser-lite/src/wrapper_conversion.ts b/packages/browser-lite/src/wrapper_conversion.ts index 4aea6f696..671bc3167 100644 --- 
a/packages/browser-lite/src/wrapper_conversion.ts +++ b/packages/browser-lite/src/wrapper_conversion.ts @@ -284,7 +284,7 @@ export function wrapperGetOrdersResponseToGetOrdersResponse( ): GetOrdersResponse { return { ...wrapperGetOrdersResponse, - snapshotTimestamp: new Date(wrapperGetOrdersResponse.snapshotTimestamp).getTime(), + timestamp: new Date(wrapperGetOrdersResponse.timestamp).getTime(), ordersInfos: wrapperGetOrdersResponse.ordersInfos.map(wrapperOrderInfoToOrderInfo), }; } diff --git a/packages/browser/conversion-tests/conversion_test.ts b/packages/browser/conversion-tests/conversion_test.ts index d481b37ec..1a2777b7a 100644 --- a/packages/browser/conversion-tests/conversion_test.ts +++ b/packages/browser/conversion-tests/conversion_test.ts @@ -371,13 +371,11 @@ function testContractEventPrelude( function testGetOrdersResponse(getOrdersResponse: WrapperGetOrdersResponse[]): void { let printer = prettyPrintTestCase('getOrdersResponse', 'EmptyOrderInfo'); - printer('snapshotID', getOrdersResponse[0].snapshotID === '208c81f9-6f8d-44aa-b6ea-0a3276ec7318'); - printer('snapshotTimestamp', getOrdersResponse[0].snapshotTimestamp === '2006-01-01T00:00:00Z'); + printer('timestamp', getOrdersResponse[0].timestamp === '2006-01-01T00:00:00Z'); printer('orderInfo.length', getOrdersResponse[0].ordersInfos.length === 0); printer = prettyPrintTestCase('getOrdersResponse', 'OneOrderInfo'); - printer('snapshotID', getOrdersResponse[1].snapshotID === '208c81f9-6f8d-44aa-b6ea-0a3276ec7318'); - printer('snapshotTimestamp', getOrdersResponse[1].snapshotTimestamp === '2006-01-01T00:00:00Z'); + printer('timestamp', getOrdersResponse[1].timestamp === '2006-01-01T00:00:00Z'); printer('orderInfo.length', getOrdersResponse[1].ordersInfos.length === 1); printer('orderInfo.orderHash', getOrdersResponse[1].ordersInfos[0].orderHash === hexUtils.leftPad('0x1', 32)); printer('orderInfo.signedOrder.chainId', getOrdersResponse[1].ordersInfos[0].signedOrder.chainId === 1337); @@ -442,8 
+440,7 @@ function testGetOrdersResponse(getOrdersResponse: WrapperGetOrdersResponse[]): v ); printer = prettyPrintTestCase('getOrdersResponse', 'TwoOrderInfos'); - printer('snapshotID', getOrdersResponse[2].snapshotID === '208c81f9-6f8d-44aa-b6ea-0a3276ec7318'); - printer('snapshotTimestamp', getOrdersResponse[2].snapshotTimestamp === '2006-01-01T00:00:00Z'); + printer('timestamp', getOrdersResponse[2].timestamp === '2006-01-01T00:00:00Z'); printer('orderInfo.length', getOrdersResponse[2].ordersInfos.length === 2); printer('orderInfo.orderHash', getOrdersResponse[2].ordersInfos[0].orderHash === hexUtils.leftPad('0x1', 32)); printer('orderInfo.signedOrder.chainId', getOrdersResponse[2].ordersInfos[0].signedOrder.chainId === 1337); diff --git a/packages/browser/go/conversion-test/conversion_test.go b/packages/browser/go/conversion-test/conversion_test.go index 20e537b20..6621be402 100644 --- a/packages/browser/go/conversion-test/conversion_test.go +++ b/packages/browser/go/conversion-test/conversion_test.go @@ -241,8 +241,7 @@ func registerConvertConfigField(description string, field string) { } func registerGetOrdersResponseTest(description string, orderInfoLength int) { - registerGetOrdersResponseField(description, "snapshotID") - registerGetOrdersResponseField(description, "snapshotTimestamp") + registerGetOrdersResponseField(description, "timestamp") registerGetOrdersResponseField(description, "orderInfo.length") for i := 0; i < orderInfoLength; i++ { registerGetOrdersResponseField(description, "orderInfo.orderHash") diff --git a/packages/browser/go/conversion-test/main.go b/packages/browser/go/conversion-test/main.go index c5e593cdb..a64ad5251 100644 --- a/packages/browser/go/conversion-test/main.go +++ b/packages/browser/go/conversion-test/main.go @@ -258,13 +258,11 @@ func setGlobals() { "getOrdersResponse": js.FuncOf(func(this js.Value, args []js.Value) interface{} { return []interface{}{ types.GetOrdersResponse{ - SnapshotID: 
"208c81f9-6f8d-44aa-b6ea-0a3276ec7318", - SnapshotTimestamp: time.Date(2006, time.January, 1, 0, 0, 0, 0, time.UTC), - OrdersInfos: []*types.OrderInfo{}, + Timestamp: time.Date(2006, time.January, 1, 0, 0, 0, 0, time.UTC), + OrdersInfos: []*types.OrderInfo{}, }, types.GetOrdersResponse{ - SnapshotID: "208c81f9-6f8d-44aa-b6ea-0a3276ec7318", - SnapshotTimestamp: time.Date(2006, time.January, 1, 0, 0, 0, 0, time.UTC), + Timestamp: time.Date(2006, time.January, 1, 0, 0, 0, 0, time.UTC), OrdersInfos: []*types.OrderInfo{ &types.OrderInfo{ OrderHash: common.HexToHash("0x1"), @@ -294,8 +292,7 @@ func setGlobals() { }, }, types.GetOrdersResponse{ - SnapshotID: "208c81f9-6f8d-44aa-b6ea-0a3276ec7318", - SnapshotTimestamp: time.Date(2006, time.January, 1, 0, 0, 0, 0, time.UTC), + Timestamp: time.Date(2006, time.January, 1, 0, 0, 0, 0, time.UTC), OrdersInfos: []*types.OrderInfo{ &types.OrderInfo{ OrderHash: common.HexToHash("0x1"), diff --git a/packages/browser/go/jsutil/jsutil.go b/packages/browser/go/jsutil/jsutil.go index e1459b2ce..c7f0f91d3 100644 --- a/packages/browser/go/jsutil/jsutil.go +++ b/packages/browser/go/jsutil/jsutil.go @@ -6,6 +6,7 @@ package jsutil import ( "bytes" + "context" "encoding/json" "fmt" "syscall/js" @@ -44,6 +45,47 @@ func WrapInPromise(f func() (interface{}, error)) js.Value { return js.Global().Get("Promise").New(executor) } +// AwaitPromiseContext is like AwaitPromise but accepts a context. If the context +// is canceled or times out before the promise resolves, it will return +// (js.Undefined, ctx.Error). 
+func AwaitPromiseContext(ctx context.Context, promise js.Value) (result js.Value, err error) { + resultsChan := make(chan js.Value) + errChan := make(chan js.Error) + + thenFunc := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + go func() { + resultsChan <- args[0] + }() + return js.Undefined() + }) + defer thenFunc.Release() + catchFunc := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + go func() { + errChan <- js.Error{Value: args[0]} + }() + return js.Undefined() + }) + defer catchFunc.Release() + promise.Call("then", thenFunc).Call("catch", catchFunc) + + select { + case <-ctx.Done(): + return js.Undefined(), ctx.Err() + case result := <-resultsChan: + return result, nil + case err := <-errChan: + return js.Undefined(), err + } +} + +// AwaitPromise accepts a js.Value representing a Promise. If the promise +// resolves, it returns (result, nil). If the promise rejects, it returns +// (js.Undefined, error). AwaitPromise has a synchronous-like API but does not +// block the JavaScript event loop. +func AwaitPromise(promise js.Value) (result js.Value, err error) { + return AwaitPromiseContext(context.Background(), promise) +} + // InefficientlyConvertToJS converts the given Go value to a JS value by // encoding to JSON and then decoding it. This function is not very efficient // and its use should be phased out over time as much as possible. 
diff --git a/packages/browser/go/mesh-browser/main.go b/packages/browser/go/mesh-browser/main.go index e237506fa..113603d4c 100644 --- a/packages/browser/go/mesh-browser/main.go +++ b/packages/browser/go/mesh-browser/main.go @@ -12,6 +12,7 @@ import ( "github.com/0xProject/0x-mesh/packages/browser/go/browserutil" "github.com/0xProject/0x-mesh/packages/browser/go/jsutil" "github.com/0xProject/0x-mesh/zeroex" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/event" ) @@ -74,11 +75,11 @@ type MeshWrapper struct { // NewMeshWrapper creates a new wrapper from the given config. func NewMeshWrapper(config core.Config) (*MeshWrapper, error) { - app, err := core.New(config) + ctx, cancel := context.WithCancel(context.Background()) + app, err := core.New(ctx, config) if err != nil { return nil, err } - ctx, cancel := context.WithCancel(context.Background()) return &MeshWrapper{ app: app, ctx: ctx, @@ -97,7 +98,7 @@ func (cw *MeshWrapper) Start() error { // cw.app.Start blocks until there is an error or the app is closed, so we // need to start it in a goroutine. go func() { - cw.errChan <- cw.app.Start(cw.ctx) + cw.errChan <- cw.app.Start() }() // Wait up to 1 second to see if cw.app.Start returns an error right away. @@ -167,8 +168,8 @@ func (cw *MeshWrapper) GetStats() (js.Value, error) { // GetOrders converts raw JavaScript parameters into the appropriate type, calls // core.App.GetOrders, converts the result into basic JavaScript types (string, // int, etc.) and returns it. 
-func (cw *MeshWrapper) GetOrders(page int, perPage int, snapshotID string) (js.Value, error) { - ordersResponse, err := cw.app.GetOrders(page, perPage, snapshotID) +func (cw *MeshWrapper) GetOrders(perPage int, minOrderHash string) (js.Value, error) { + ordersResponse, err := cw.app.GetOrders(perPage, common.HexToHash(minOrderHash)) if err != nil { return js.Undefined(), err } @@ -204,16 +205,16 @@ func (cw *MeshWrapper) JSValue() js.Value { return cw.GetStats() }) }), - // getOrdersForPageAsync(page: number, perPage: number, snapshotID?: string): Promise + // getOrdersForPageAsync(perPage: number, minOrderHash?: string): Promise "getOrdersForPageAsync": js.FuncOf(func(this js.Value, args []js.Value) interface{} { return jsutil.WrapInPromise(func() (interface{}, error) { - // snapshotID is optional in the JavaScript function. Check if it is + // minOrderHash is optional in the JavaScript function. Check if it is // null or undefined. - snapshotID := "" - if !jsutil.IsNullOrUndefined(args[2]) { - snapshotID = args[2].String() + minOrderHash := "" + if !jsutil.IsNullOrUndefined(args[1]) { + minOrderHash = args[1].String() } - return cw.GetOrders(args[0].Int(), args[1].Int(), snapshotID) + return cw.GetOrders(args[0].Int(), minOrderHash) }) }), // addOrdersAsync(orders: Array): Promise diff --git a/packages/integration-tests/src/index.ts b/packages/integration-tests/src/index.ts index f9392b262..7bcab4c14 100644 --- a/packages/integration-tests/src/index.ts +++ b/packages/integration-tests/src/index.ts @@ -69,9 +69,11 @@ provider.start(); for (const event of events) { // Check the happy path for getOrdersForPageAsync. There should // be two orders. (just make sure it doesn't throw/reject). 
- const firstOrdersResponse = await mesh.getOrdersForPageAsync(0, 1, ''); + const firstOrdersResponse = await mesh.getOrdersForPageAsync(1); console.log(JSON.stringify(firstOrdersResponse)); - const secondOrdersResponse = await mesh.getOrdersForPageAsync(1, 1, firstOrdersResponse.snapshotID); + const nextMinOrderHash = + firstOrdersResponse.ordersInfos[firstOrdersResponse.ordersInfos.length - 1].orderHash; + const secondOrdersResponse = await mesh.getOrdersForPageAsync(1, nextMinOrderHash); console.log(JSON.stringify(secondOrdersResponse)); // Check the happy path for getOrders (just make sure it diff --git a/packages/rpc-client/src/types.ts b/packages/rpc-client/src/types.ts index 7d90d3d29..b85f2cad3 100644 --- a/packages/rpc-client/src/types.ts +++ b/packages/rpc-client/src/types.ts @@ -401,8 +401,7 @@ export interface ValidationResults { } export interface RawGetOrdersResponse { - snapshotID: string; - snapshotTimestamp: string; + timestamp: string; ordersInfos: RawAcceptedOrderInfo[]; } @@ -410,8 +409,7 @@ export interface RawGetOrdersResponse { // method. 
The `snapshotTimestamp` is the second UTC timestamp of when the Mesh // was queried for these orders export interface GetOrdersResponse { - snapshotID: string; - snapshotTimestamp: number; + timestamp: number; ordersInfos: OrderInfo[]; } diff --git a/packages/rpc-client/src/ws_client.ts b/packages/rpc-client/src/ws_client.ts index d43b25861..6cc22fd9d 100644 --- a/packages/rpc-client/src/ws_client.ts +++ b/packages/rpc-client/src/ws_client.ts @@ -110,9 +110,8 @@ export class WSClient { } private static _convertRawGetOrdersResponse(rawGetOrdersResponse: RawGetOrdersResponse): GetOrdersResponse { return { - snapshotID: rawGetOrdersResponse.snapshotID, // tslint:disable-next-line:custom-no-magic-numbers - snapshotTimestamp: Math.round(new Date(rawGetOrdersResponse.snapshotTimestamp).getTime() / 1000), + timestamp: Math.round(new Date(rawGetOrdersResponse.timestamp).getTime() / 1000), ordersInfos: WSClient._convertRawOrderInfos(rawGetOrdersResponse.ordersInfos), }; } @@ -309,25 +308,19 @@ export class WSClient { * @returns the snapshotID, snapshotTimestamp and all orders, their hashes and fillableTakerAssetAmounts */ public async getOrdersAsync(perPage: number = 200): Promise { - let snapshotID = ''; // New snapshot - - let page = 0; - let getOrdersResponse = await this.getOrdersForPageAsync(page, perPage, snapshotID); - snapshotID = getOrdersResponse.snapshotID; + let getOrdersResponse = await this.getOrdersForPageAsync(perPage); let ordersInfos = getOrdersResponse.ordersInfos; - let allOrderInfos: OrderInfo[] = []; do { allOrderInfos = [...allOrderInfos, ...ordersInfos]; - page++; - getOrdersResponse = await this.getOrdersForPageAsync(page, perPage, snapshotID); + const minOrderHash = ordersInfos[ordersInfos.length - 1].orderHash; + getOrdersResponse = await this.getOrdersForPageAsync(perPage, minOrderHash); ordersInfos = getOrdersResponse.ordersInfos; } while (ordersInfos.length > 0); getOrdersResponse = { - snapshotID, - snapshotTimestamp: 
getOrdersResponse.snapshotTimestamp, + timestamp: getOrdersResponse.timestamp, ordersInfos: allOrderInfos, }; return getOrdersResponse; @@ -339,17 +332,10 @@ export class WSClient { * @param snapshotID The DB snapshot at which to fetch orders. If omitted, a new snapshot is created * @returns the snapshotID, snapshotTimestamp and all orders, their hashes and fillableTakerAssetAmounts */ - public async getOrdersForPageAsync( - page: number, - perPage: number = 200, - snapshotID?: string, - ): Promise { - const finalSnapshotID = snapshotID === undefined ? '' : snapshotID; - + public async getOrdersForPageAsync(perPage: number = 200, minOrderHash: string = ''): Promise { const rawGetOrdersResponse: RawGetOrdersResponse = await this._wsProvider.send('mesh_getOrders', [ - page, perPage, - finalSnapshotID, + minOrderHash, ]); const getOrdersResponse = WSClient._convertRawGetOrdersResponse(rawGetOrdersResponse); return getOrdersResponse; diff --git a/packages/rpc-client/test/ws_client_test.ts b/packages/rpc-client/test/ws_client_test.ts index 611ab214e..0b9d29383 100644 --- a/packages/rpc-client/test/ws_client_test.ts +++ b/packages/rpc-client/test/ws_client_test.ts @@ -9,7 +9,6 @@ import { DoneCallback, SignedOrder } from '@0x/types'; import { BigNumber, hexUtils } from '@0x/utils'; import { Web3Wrapper } from '@0x/web3-wrapper'; import 'mocha'; -import * as uuidValidate from 'uuid-validate'; import * as WebSocket from 'websocket'; import { OrderEvent, OrderEventEndState, WSClient } from '../src/index'; @@ -185,9 +184,7 @@ blockchainTests.resets('WSClient', env => { const now = new Date(Date.now()).getTime(); const perPage = ordersLength / 2; const response = await deployment.client.getOrdersAsync(perPage); - assertRoughlyEquals(now, response.snapshotTimestamp * secondsToMs(1), secondsToMs(2)); - // Verify that snapshot ID in the response meets the expected schema. 
- expect(uuidValidate(response.snapshotID)).to.be.true(); + assertRoughlyEquals(now, response.timestamp * secondsToMs(1), secondsToMs(2)); // Verify that all of the orders that were added to the mesh node // were returned in the `getOrders` rpc response @@ -211,19 +208,17 @@ blockchainTests.resets('WSClient', env => { // timestamp is approximately equal (within 1 second) because the server // will receive the request slightly after it is sent. const now = new Date(Date.now()).getTime(); - let page = 0; const perPage = 5; // First request for page index 0 - let response = await deployment.client.getOrdersForPageAsync(page, perPage); - assertRoughlyEquals(now, response.snapshotTimestamp * secondsToMs(1), secondsToMs(2)); - expect(uuidValidate(response.snapshotID)).to.be.true(); + let response = await deployment.client.getOrdersForPageAsync(perPage); + assertRoughlyEquals(now, response.timestamp * secondsToMs(1), secondsToMs(2)); let responseOrders = response.ordersInfos; + expect(responseOrders.length).to.be.eq(perPage); + const nextMinOrderHash = responseOrders[responseOrders.length - 1].orderHash; // Second request for page index 1 - page = 1; - response = await deployment.client.getOrdersForPageAsync(page, perPage, response.snapshotID); - expect(uuidValidate(response.snapshotID)).to.be.true(); + response = await deployment.client.getOrdersForPageAsync(perPage, nextMinOrderHash); // Combine orders found in first and second paginated requests responseOrders = [...responseOrders, ...response.ordersInfos]; diff --git a/packages/test-wasm/package.json b/packages/test-wasm/package.json index f4e238acf..2ca5f3073 100644 --- a/packages/test-wasm/package.json +++ b/packages/test-wasm/package.json @@ -10,11 +10,15 @@ "clean": "shx rm -r ./dist && shx rm -r ./lib || exit 0", "lint": "tslint --format stylish --project ." 
}, + "devDependencies": { + "@types/dexie": "^1.3.1" + }, "dependencies": { "@0x/mesh-browser-lite": "1.0.0", + "dexie": "^3.0.1", "shx": "^0.3.2", - "typescript": "^3.9.3", "ts-loader": "^6.2.1", + "typescript": "^3.9.3", "webpack": "^4.43.0", "webpack-cli": "^3.3.10" } diff --git a/packages/test-wasm/src/browser_shim.ts b/packages/test-wasm/src/browser_shim.ts index 3667a4de8..6ed70179a 100644 --- a/packages/test-wasm/src/browser_shim.ts +++ b/packages/test-wasm/src/browser_shim.ts @@ -1,3 +1,4 @@ -import { createSchemaValidator } from '@0x/mesh-browser-lite/lib/schema_validator'; +import { _setGlobals } from '@0x/mesh-browser-lite'; -(window as any).createSchemaValidator = createSchemaValidator; +// Set the globals that are required for e.g. the `db` and `orderfilter` packages. +_setGlobals(); diff --git a/rpc/client.go b/rpc/client.go index c81cb8fe3..dd2a3f704 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -9,6 +9,7 @@ import ( "github.com/0xProject/0x-mesh/common/types" "github.com/0xProject/0x-mesh/zeroex" "github.com/0xProject/0x-mesh/zeroex/ordervalidator" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/rpc" peer "github.com/libp2p/go-libp2p-core/peer" peerstore "github.com/libp2p/go-libp2p-peerstore" @@ -51,9 +52,9 @@ func (c *Client) AddOrders(orders []*zeroex.SignedOrder, opts ...types.AddOrders } // GetOrders gets all orders stored on the Mesh node at a particular point in time in a paginated fashion -func (c *Client) GetOrders(page, perPage int, snapshotID string) (*types.GetOrdersResponse, error) { +func (c *Client) GetOrders(perPage int, minOrderHash common.Hash) (*types.GetOrdersResponse, error) { var getOrdersResponse types.GetOrdersResponse - if err := c.rpcClient.Call(&getOrdersResponse, "mesh_getOrders", page, perPage, snapshotID); err != nil { + if err := c.rpcClient.Call(&getOrdersResponse, "mesh_getOrders", perPage, minOrderHash.Hex()); err != nil { return nil, err } return &getOrdersResponse, nil diff --git 
a/rpc/service.go b/rpc/service.go index 98f7dac44..8b2b938ac 100644 --- a/rpc/service.go +++ b/rpc/service.go @@ -33,7 +33,7 @@ type RPCHandler interface { // AddOrders is called when the client sends an AddOrders request. AddOrders(signedOrdersRaw []*json.RawMessage, opts types.AddOrdersOpts) (*ordervalidator.ValidationResults, error) // GetOrders is called when the clients sends a GetOrders request - GetOrders(page, perPage int, snapshotID string) (*types.GetOrdersResponse, error) + GetOrders(perPage int, minOrderHash string) (*types.GetOrdersResponse, error) // AddPeer is called when the client sends an AddPeer request. AddPeer(peerInfo peerstore.PeerInfo) error // GetStats is called when the client sends an GetStats request. @@ -136,8 +136,8 @@ func (s *rpcService) AddOrders(signedOrdersRaw []*json.RawMessage, opts *types.A } // GetOrders calls rpcHandler.GetOrders and returns the validation results. -func (s *rpcService) GetOrders(page, perPage int, snapshotID string) (*types.GetOrdersResponse, error) { - return s.rpcHandler.GetOrders(page, perPage, snapshotID) +func (s *rpcService) GetOrders(perPage int, minOrderHash string) (*types.GetOrdersResponse, error) { + return s.rpcHandler.GetOrders(perPage, minOrderHash) } // AddPeer builds PeerInfo out of the given peer ID and multiaddresses and diff --git a/scenario/scenario.go b/scenario/scenario.go index 453a5639c..679250cc7 100644 --- a/scenario/scenario.go +++ b/scenario/scenario.go @@ -27,8 +27,8 @@ import ( var ( ethClient *ethclient.Client ganacheAddresses = ethereum.GanacheAddresses - ZRXAssetData = common.Hex2Bytes("f47261b0000000000000000000000000871dd7c2b4b25e1aa18728e9d5f2af4c4e431f5c") - WETHAssetData = common.Hex2Bytes("f47261b00000000000000000000000000b1ba0af832d7c05fd64161e0db78e85978e8082") + ZRXAssetData = constants.ZRXAssetData + WETHAssetData = constants.WETHAssetData ) func init() { diff --git a/yarn.lock b/yarn.lock index df5dbf364..875698c35 100644 --- a/yarn.lock +++ b/yarn.lock @@ -656,6 
+656,13 @@ dependencies: "@types/node" "*" +"@types/dexie@^1.3.1": + version "1.3.1" + resolved "https://registry.yarnpkg.com/@types/dexie/-/dexie-1.3.1.tgz#accca262f9071f1ed963a40255fcf4bc7af67e15" + integrity sha1-rMyiYvkHHx7ZY6QCVfz0vHr2fhU= + dependencies: + dexie "*" + "@types/ethereum-protocol@*": version "1.0.1" resolved "https://registry.yarnpkg.com/@types/ethereum-protocol/-/ethereum-protocol-1.0.1.tgz#04bb8a91824a5ee2fae959cc788412321350a75d" @@ -2953,6 +2960,11 @@ detect-node@2.0.3: resolved "https://registry.yarnpkg.com/detect-node/-/detect-node-2.0.3.tgz#a2033c09cc8e158d37748fbde7507832bd6ce127" integrity sha1-ogM8CcyOFY03dI+951B4Mr1s4Sc= +dexie@*, dexie@^3.0.1: + version "3.0.1" + resolved "https://registry.yarnpkg.com/dexie/-/dexie-3.0.1.tgz#faafeb94be0d5e18b25d700546a2c05725511cfc" + integrity sha512-/s4KzlaerQnCad/uY1ZNdFckTrbdMVhLlziYQzz62Ff9Ick1lHGomvTXNfwh4ApEZATyXRyVk5F6/y8UU84B0w== + diff@3.3.1: version "3.3.1" resolved "https://registry.yarnpkg.com/diff/-/diff-3.3.1.tgz#aa8567a6eed03c531fc89d3f711cd0e5259dec75" diff --git a/zeroex/orderwatch/order_watcher.go b/zeroex/orderwatch/order_watcher.go index 0a6b5519f..3d3c932f5 100644 --- a/zeroex/orderwatch/order_watcher.go +++ b/zeroex/orderwatch/order_watcher.go @@ -9,19 +9,17 @@ import ( "sync" "time" + "github.com/0xProject/0x-mesh/common/types" "github.com/0xProject/0x-mesh/constants" "github.com/0xProject/0x-mesh/db" "github.com/0xProject/0x-mesh/ethereum" "github.com/0xProject/0x-mesh/ethereum/blockwatch" - "github.com/0xProject/0x-mesh/ethereum/miniheader" - "github.com/0xProject/0x-mesh/expirationwatch" - "github.com/0xProject/0x-mesh/meshdb" "github.com/0xProject/0x-mesh/zeroex" "github.com/0xProject/0x-mesh/zeroex/ordervalidator" "github.com/0xProject/0x-mesh/zeroex/orderwatch/decoder" "github.com/0xProject/0x-mesh/zeroex/orderwatch/slowcounter" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" + ethtypes "github.com/ethereum/go-ethereum/core/types" 
"github.com/ethereum/go-ethereum/event" logger "github.com/sirupsen/logrus" ) @@ -76,16 +74,17 @@ const ( slowCounterInterval = 5 * time.Minute ) +var errNoBlocksStored = errors.New("no blocks were stored in the database") + // Watcher watches all order-relevant state and handles the state transitions type Watcher struct { - meshDB *meshdb.MeshDB + db *db.DB blockWatcher *blockwatch.Watcher eventDecoder *decoder.Decoder assetDataDecoder *zeroex.AssetDataDecoder blockSubscription event.Subscription blockEventsChan chan []*blockwatch.Event contractAddresses ethereum.ContractAddresses - expirationWatcher *expirationwatch.Watcher orderFeed event.Feed orderScope event.SubscriptionScope // Subscription scope tracking current live listeners contractAddressToSeenCount map[common.Address]uint @@ -104,7 +103,7 @@ type Watcher struct { } type Config struct { - MeshDB *meshdb.MeshDB + DB *db.DB BlockWatcher *blockwatch.Watcher OrderValidator *ordervalidator.OrderValidator ChainID int @@ -145,9 +144,8 @@ func New(config Config) (*Watcher, error) { } w := &Watcher{ - meshDB: config.MeshDB, + db: config.DB, blockWatcher: config.BlockWatcher, - expirationWatcher: expirationwatch.New(), contractAddressToSeenCount: map[common.Address]uint{}, orderValidator: config.OrderValidator, eventDecoder: decoder, @@ -161,22 +159,13 @@ func New(config Config) (*Watcher, error) { didProcessABlock: false, } - // Check if any orders need to be removed right away due to high expiration - // times. 
- orderEvents, err := w.decreaseMaxExpirationTimeIfNeeded() - if err != nil { - return nil, err - } - w.orderFeed.Send(orderEvents) - // Pre-populate the OrderWatcher with all orders already stored in the DB - orders := []*meshdb.Order{} - err = w.meshDB.Orders.FindAll(&orders) + orders, err := w.db.FindOrders(nil) if err != nil { return nil, err } for _, order := range orders { - err := w.setupInMemoryOrderState(order.SignedOrder) + err := w.setupInMemoryOrderState(order) if err != nil { return nil, err } @@ -238,21 +227,25 @@ func (w *Watcher) Watch(ctx context.Context) error { select { case err := <-mainLoopErrChan: if err != nil { + logger.WithError(err).Error("error in orderwatcher mainLoop") cancel() return err } case err := <-cleanupLoopErrChan: if err != nil { + logger.WithError(err).Error("error in orderwatcher cleanupLoop") cancel() return err } case err := <-maxExpirationTimeLoopErrChan: if err != nil { + logger.WithError(err).Error("error in orderwatcher maxExpirationTimeLoop") cancel() return err } case err := <-removedCheckerLoopErrChan: if err != nil { + logger.WithError(err).Error("error in orderwatcher removedCheckerLoop") cancel() return err } @@ -366,77 +359,58 @@ func (w *Watcher) removedCheckerLoop(ctx context.Context) error { // handleOrderExpirations takes care of generating expired and unexpired order events for orders that do not require re-validation. // Since expiry is now done according to block timestamp, we can figure out which orders have expired/unexpired statically. We do not -// process blocks that require re-validation, since the validation process will already emit the necessary events and we cannot make -// multiple updates to an order within a single DB transaction. +// process blocks that require re-validation, since the validation process will already emit the necessary events. 
// latestBlockTimestamp is the latest block timestamp Mesh knows about -// previousLatestBlockTimestamp is the previous latest block timestamp Mesh knew about // ordersToRevalidate contains all the orders Mesh needs to re-validate given the events emitted by the blocks processed -func (w *Watcher) handleOrderExpirations(ordersColTxn *db.Transaction, latestBlockTimestamp, previousLatestBlockTimestamp time.Time, ordersToRevalidate map[common.Hash]*meshdb.Order) ([]*zeroex.OrderEvent, error) { +func (w *Watcher) handleOrderExpirations(latestBlockTimestamp time.Time, ordersToRevalidate map[common.Hash]*types.OrderWithMetadata) ([]*zeroex.OrderEvent, error) { orderEvents := []*zeroex.OrderEvent{} - var defaultTime time.Time - - if previousLatestBlockTimestamp == defaultTime || previousLatestBlockTimestamp.Before(latestBlockTimestamp) { - expiredOrders := w.expirationWatcher.Prune(latestBlockTimestamp) - for _, expiredOrder := range expiredOrders { - orderHash := common.HexToHash(expiredOrder.ID) - // If we will re-validate this order, the revalidation process will discover that - // it's expired, and an appropriate event will already be emitted - if _, ok := ordersToRevalidate[orderHash]; ok { - continue - } - order := &meshdb.Order{} - err := w.meshDB.Orders.FindByID(orderHash.Bytes(), order) - if err != nil { - logger.WithFields(logger.Fields{ - "error": err.Error(), - "orderHash": expiredOrder.ID, - }).Trace("Order expired that was no longer in DB") - continue - } - w.unwatchOrder(ordersColTxn, order, order.FillableTakerAssetAmount) - orderEvent := &zeroex.OrderEvent{ - Timestamp: latestBlockTimestamp, - OrderHash: common.HexToHash(expiredOrder.ID), - SignedOrder: order.SignedOrder, - FillableTakerAssetAmount: big.NewInt(0), - EndState: zeroex.ESOrderExpired, - } - orderEvents = append(orderEvents, orderEvent) + // Check for any orders that have now expired. 
+ expiredOrders, err := w.findOrdersToExpire(latestBlockTimestamp) + if err != nil { + return orderEvents, err + } + for _, order := range expiredOrders { + // If we will re-validate this order, the revalidation process will discover that + // it's expired, and an appropriate event will already be emitted + if _, ok := ordersToRevalidate[order.Hash]; ok { + continue } - } else if previousLatestBlockTimestamp.After(latestBlockTimestamp) { - // A block re-org happened resulting in the latest block timestamp being - // lower than on the previous latest block. We need to "unexpire" any orders - // that have now become valid again as a result. - removedOrders, err := w.meshDB.FindRemovedOrders() - if err != nil { - return orderEvents, err + w.unwatchOrder(order, nil) + orderEvent := &zeroex.OrderEvent{ + Timestamp: latestBlockTimestamp, + OrderHash: order.Hash, + SignedOrder: order.SignedOrder(), + FillableTakerAssetAmount: big.NewInt(0), + EndState: zeroex.ESOrderExpired, } - for _, order := range removedOrders { - // Orders removed due to expiration have non-zero FillableTakerAssetAmounts - if order.FillableTakerAssetAmount.Cmp(big.NewInt(0)) == 0 { - continue - } - // If we will re-validate this order, the revalidation process will discover that - // it's unexpired, and an appropriate event will already be emitted - if _, ok := ordersToRevalidate[order.Hash]; ok { - continue - } - expiration := time.Unix(order.SignedOrder.ExpirationTimeSeconds.Int64(), 0) - if latestBlockTimestamp.Before(expiration) { - w.rewatchOrder(ordersColTxn, order, order.FillableTakerAssetAmount) - orderEvent := &zeroex.OrderEvent{ - Timestamp: latestBlockTimestamp, - OrderHash: order.Hash, - SignedOrder: order.SignedOrder, - FillableTakerAssetAmount: order.FillableTakerAssetAmount, - EndState: zeroex.ESOrderUnexpired, - } - orderEvents = append(orderEvents, orderEvent) - } + orderEvents = append(orderEvents, orderEvent) + } + + // Check for any orders which have now unexpired. 
+ // + // A block re-org may have happened resulting in the latest block timestamp + // being lower than on the previous latest block. We need to "unexpire" any + // orders that have now become valid again as a result. + unexpiredOrders, err := w.findOrdersToUnexpire(latestBlockTimestamp) + if err != nil { + return orderEvents, err + } + for _, order := range unexpiredOrders { + // If we will re-validate this order, the revalidation process will discover that + // it's unexpired, and an appropriate event will already be emitted + if _, ok := ordersToRevalidate[order.Hash]; ok { + continue + } + w.rewatchOrder(order, order.FillableTakerAssetAmount) + orderEvent := &zeroex.OrderEvent{ + Timestamp: latestBlockTimestamp, + OrderHash: order.Hash, + SignedOrder: order.SignedOrder(), + FillableTakerAssetAmount: order.FillableTakerAssetAmount, + EndState: zeroex.ESOrderUnexpired, } - } else { - // The block timestamp hasn't changed, noop + orderEvents = append(orderEvents, orderEvent) } return orderEvents, nil @@ -452,34 +426,8 @@ func (w *Watcher) handleBlockEvents( return nil } - miniHeadersColTxn := w.meshDB.MiniHeaders.OpenTransaction() - defer func() { - _ = miniHeadersColTxn.Discard() - }() - ordersColTxn := w.meshDB.Orders.OpenTransaction() - defer func() { - _ = ordersColTxn.Discard() - }() - - var previousLatestBlockTimestamp time.Time - previousLatestBlock, err := w.meshDB.FindLatestMiniHeader() - if err != nil { - // If no previousLatestBlock, that's ok - if _, ok := err.(meshdb.MiniHeaderCollectionEmptyError); !ok { - return err - } - } - if previousLatestBlock != nil { - previousLatestBlockTimestamp = previousLatestBlock.Timestamp - } latestBlockNumber, latestBlockTimestamp := w.getBlockchainState(events) - - err = updateBlockHeadersStoredInDB(miniHeadersColTxn, events) - if err != nil { - return err - } - - orderHashToDBOrder := map[common.Hash]*meshdb.Order{} + orderHashToDBOrder := map[common.Hash]*types.OrderWithMetadata{} orderHashToEvents := 
map[common.Hash][]*zeroex.ContractEvent{} for _, event := range events { for _, log := range event.BlockHeader.Logs { @@ -510,7 +458,7 @@ func (w *Watcher) handleBlockEvents( Address: log.Address, Kind: eventType, } - orders := []*meshdb.Order{} + orders := []*types.OrderWithMetadata{} switch eventType { case "ERC20TransferEvent": var transferEvent decoder.ERC20TransferEvent @@ -522,12 +470,12 @@ func (w *Watcher) handleBlockEvents( return err } contractEvent.Parameters = transferEvent - fromOrders, err := w.findOrdersByTokenAddressAndTokenID(transferEvent.From, log.Address, nil) + fromOrders, err := w.findOrdersByTokenAddress(transferEvent.From, log.Address) if err != nil { return err } orders = append(orders, fromOrders...) - toOrders, err := w.findOrdersByTokenAddressAndTokenID(transferEvent.To, log.Address, nil) + toOrders, err := w.findOrdersByTokenAddress(transferEvent.To, log.Address) if err != nil { return err } @@ -547,7 +495,7 @@ func (w *Watcher) handleBlockEvents( continue } contractEvent.Parameters = approvalEvent - orders, err = w.findOrdersByTokenAddressAndTokenID(approvalEvent.Owner, log.Address, nil) + orders, err = w.findOrdersByTokenAddress(approvalEvent.Owner, log.Address) if err != nil { return err } @@ -602,7 +550,7 @@ func (w *Watcher) handleBlockEvents( continue } contractEvent.Parameters = approvalForAllEvent - orders, err = w.findOrdersByTokenAddressAndTokenID(approvalForAllEvent.Owner, log.Address, nil) + orders, err = w.findOrdersByTokenAddress(approvalForAllEvent.Owner, log.Address) if err != nil { return err } @@ -624,12 +572,12 @@ func (w *Watcher) handleBlockEvents( // further. In the future, we might want to special-case this broader approach for the Augur // contract address specifically. 
contractEvent.Parameters = transferEvent - fromOrders, err := w.findOrdersByTokenAddressAndTokenID(transferEvent.From, log.Address, nil) + fromOrders, err := w.findOrdersByTokenAddress(transferEvent.From, log.Address) if err != nil { return err } orders = append(orders, fromOrders...) - toOrders, err := w.findOrdersByTokenAddressAndTokenID(transferEvent.To, log.Address, nil) + toOrders, err := w.findOrdersByTokenAddress(transferEvent.To, log.Address) if err != nil { return err } @@ -645,12 +593,12 @@ func (w *Watcher) handleBlockEvents( return err } contractEvent.Parameters = transferEvent - fromOrders, err := w.findOrdersByTokenAddressAndTokenID(transferEvent.From, log.Address, nil) + fromOrders, err := w.findOrdersByTokenAddress(transferEvent.From, log.Address) if err != nil { return err } orders = append(orders, fromOrders...) - toOrders, err := w.findOrdersByTokenAddressAndTokenID(transferEvent.To, log.Address, nil) + toOrders, err := w.findOrdersByTokenAddress(transferEvent.To, log.Address) if err != nil { return err } @@ -670,7 +618,7 @@ func (w *Watcher) handleBlockEvents( continue } contractEvent.Parameters = approvalForAllEvent - orders, err = w.findOrdersByTokenAddressAndTokenID(approvalForAllEvent.Owner, log.Address, nil) + orders, err = w.findOrdersByTokenAddress(approvalForAllEvent.Owner, log.Address) if err != nil { return err } @@ -685,7 +633,7 @@ func (w *Watcher) handleBlockEvents( return err } contractEvent.Parameters = withdrawalEvent - orders, err = w.findOrdersByTokenAddressAndTokenID(withdrawalEvent.Owner, log.Address, nil) + orders, err = w.findOrdersByTokenAddress(withdrawalEvent.Owner, log.Address) if err != nil { return err } @@ -700,7 +648,7 @@ func (w *Watcher) handleBlockEvents( return err } contractEvent.Parameters = depositEvent - orders, err = w.findOrdersByTokenAddressAndTokenID(depositEvent.Owner, log.Address, nil) + orders, err = w.findOrdersByTokenAddress(depositEvent.Owner, log.Address) if err != nil { return err } @@ -746,7 
+694,20 @@ func (w *Watcher) handleBlockEvents( return err } contractEvent.Parameters = exchangeCancelUpToEvent - cancelledOrders, err := w.meshDB.FindOrdersByMakerAddressAndMaxSalt(exchangeCancelUpToEvent.MakerAddress, exchangeCancelUpToEvent.OrderEpoch) + cancelledOrders, err := w.db.FindOrders(&db.OrderQuery{ + Filters: []db.OrderFilter{ + { + Field: db.OFMakerAddress, + Kind: db.Equal, + Value: exchangeCancelUpToEvent.MakerAddress, + }, + { + Field: db.OFSalt, + Kind: db.LessOrEqual, + Value: exchangeCancelUpToEvent.OrderEpoch, + }, + }, + }) if err != nil { logger.WithFields(logger.Fields{ "error": err.Error(), @@ -773,32 +734,19 @@ func (w *Watcher) handleBlockEvents( } } - expirationOrderEvents, err := w.handleOrderExpirations(ordersColTxn, latestBlockTimestamp, previousLatestBlockTimestamp, orderHashToDBOrder) + expirationOrderEvents, err := w.handleOrderExpirations(latestBlockTimestamp, orderHashToDBOrder) if err != nil { return err } // This timeout of 1min is for limiting how long this call should block at the ETH RPC rate limiter - ctx, done := context.WithTimeout(ctx, 1*time.Minute) - defer done() - postValidationOrderEvents, err := w.generateOrderEventsIfChanged(ctx, ordersColTxn, orderHashToDBOrder, orderHashToEvents, latestBlockNumber, latestBlockTimestamp) + ctx, cancel := context.WithTimeout(ctx, 1*time.Minute) + defer cancel() + postValidationOrderEvents, err := w.generateOrderEventsIfChanged(ctx, orderHashToDBOrder, orderHashToEvents, latestBlockNumber, latestBlockTimestamp) if err != nil { return err } - if err := ordersColTxn.Commit(); err != nil { - logger.WithFields(logger.Fields{ - "error": err.Error(), - }).Error("Failed to commit orders collection transaction") - return err - } - if err := miniHeadersColTxn.Commit(); err != nil { - logger.WithFields(logger.Fields{ - "error": err.Error(), - }).Error("Failed to commit miniheaders collection transaction") - return err - } - orderEvents := append(expirationOrderEvents, 
postValidationOrderEvents...) if len(orderEvents) > 0 { w.orderFeed.Send(orderEvents) @@ -811,14 +759,18 @@ func (w *Watcher) handleBlockEvents( } w.atLeastOneBlockProcessedMu.Unlock() - // Since we might have added MiniHeaders to the DB, we need to prune any excess MiniHeaders stored - // in the DB - err = w.meshDB.PruneMiniHeadersAboveRetentionLimit() + return nil +} + +func (w *Watcher) getLatestBlock() (*types.MiniHeader, error) { + latestBlock, err := w.db.GetLatestMiniHeader() if err != nil { - return err + if err == db.ErrNotFound { + return nil, errNoBlocksStored + } + return nil, err } - - return nil + return latestBlock, nil } // Cleanup re-validates all orders in DB which haven't been re-validated in @@ -828,12 +780,16 @@ func (w *Watcher) Cleanup(ctx context.Context, lastUpdatedBuffer time.Duration) w.handleBlockEventsMu.RLock() defer w.handleBlockEventsMu.RUnlock() - ordersColTxn := w.meshDB.Orders.OpenTransaction() - defer func() { - _ = ordersColTxn.Discard() - }() lastUpdatedCutOff := time.Now().Add(-lastUpdatedBuffer) - orders, err := w.meshDB.FindOrdersLastUpdatedBefore(lastUpdatedCutOff) + orders, err := w.db.FindOrders(&db.OrderQuery{ + Filters: []db.OrderFilter{ + { + Field: db.OFLastUpdated, + Kind: db.Less, + Value: lastUpdatedCutOff, + }, + }, + }) if err != nil { logger.WithFields(logger.Fields{ "error": err.Error(), @@ -841,7 +797,7 @@ func (w *Watcher) Cleanup(ctx context.Context, lastUpdatedBuffer time.Duration) }).Error("Failed to find orders by LastUpdatedBefore") return err } - orderHashToDBOrder := map[common.Hash]*meshdb.Order{} + orderHashToDBOrder := map[common.Hash]*types.OrderWithMetadata{} orderHashToEvents := map[common.Hash][]*zeroex.ContractEvent{} // No events when running cleanup job for _, order := range orders { select { @@ -853,24 +809,18 @@ func (w *Watcher) Cleanup(ctx context.Context, lastUpdatedBuffer time.Duration) orderHashToEvents[order.Hash] = []*zeroex.ContractEvent{} } - latestBlock, err := 
w.meshDB.FindLatestMiniHeader() + latestBlock, err := w.getLatestBlock() if err != nil { return err } // This timeout of 30min is for limiting how long this call should block at the ETH RPC rate limiter ctx, cancel := context.WithTimeout(ctx, 30*time.Minute) defer cancel() - orderEvents, err := w.generateOrderEventsIfChanged(ctx, ordersColTxn, orderHashToDBOrder, orderHashToEvents, latestBlock.Number, latestBlock.Timestamp) + orderEvents, err := w.generateOrderEventsIfChanged(ctx, orderHashToDBOrder, orderHashToEvents, latestBlock.Number, latestBlock.Timestamp) if err != nil { return err } - if err := ordersColTxn.Commit(); err != nil { - logger.WithFields(logger.Fields{ - "error": err.Error(), - }).Error("Failed to commit orders collection transaction") - } - if len(orderEvents) > 0 { w.orderFeed.Send(orderEvents) } @@ -879,20 +829,53 @@ func (w *Watcher) Cleanup(ctx context.Context, lastUpdatedBuffer time.Duration) } func (w *Watcher) permanentlyDeleteStaleRemovedOrders(ctx context.Context) error { - removedOrders, err := w.meshDB.FindRemovedOrders() + // TODO(albrow): This could be optimized by using a single query to delete + // stale orders instead of finding them and deleting one-by-one. Limited by + // the fact that we need to update in-memory state. When we remove in-memory + // state we can revisit this. + // + // opts := &db.DeleteOrdersOpts{ + // Filters: []db.OrderFilter{ + // { + // Field: db.OFIsRemoved, + // Kind: db.Equal, + // Value: true, + // }, + // { + // Field: db.OFLastUpdated, + // Kind: db.Less, + // Value: minLastUpdated, + // }, + // }, + // } + // return w.db.DeleteOrders(opts) + + // Find any orders marked as removed that have not been updated for a + // long time. The cutoff time is determined by permanentlyDeleteAfter. 
+ minLastUpdated := time.Now().Add(-permanentlyDeleteAfter) + opts := &db.OrderQuery{ + Filters: []db.OrderFilter{ + { + Field: db.OFIsRemoved, + Kind: db.Equal, + Value: true, + }, + { + Field: db.OFLastUpdated, + Kind: db.Less, + Value: minLastUpdated, + }, + }, + } + ordersToDelete, err := w.db.FindOrders(opts) if err != nil { return err } - - for _, order := range removedOrders { - if time.Since(order.LastUpdated) > permanentlyDeleteAfter { - if err := w.permanentlyDeleteOrder(w.meshDB.Orders, order); err != nil { - return err - } - continue + for _, order := range ordersToDelete { + if err := w.permanentlyDeleteOrder(order); err != nil { + return err } } - return nil } @@ -902,215 +885,102 @@ func (w *Watcher) permanentlyDeleteStaleRemovedOrders(ctx context.Context) error // by any DDoS prevention or incentive mechanisms and will always stay in // storage until they are no longer fillable. func (w *Watcher) add(orderInfos []*ordervalidator.AcceptedOrderInfo, validationBlockNumber *big.Int, pinned bool) ([]*zeroex.OrderEvent, error) { - orderEvents, err := w.decreaseMaxExpirationTimeIfNeeded() - if err != nil { - return orderEvents, err - } - - // TODO(albrow): technically we should count the current number of orders, - // remove some if needed, and then insert the order in a single transaction to - // ensure that we don't accidentally exceed the maximum. In practice, and - // because of the way OrderWatcher works, the distinction shouldn't matter. - txn := w.meshDB.Orders.OpenTransaction() - defer func() { - _ = txn.Discard() - }() - now := time.Now().UTC() + orderEvents := []*zeroex.OrderEvent{} + dbOrders := []*types.OrderWithMetadata{} for _, orderInfo := range orderInfos { - order := &meshdb.Order{ - Hash: orderInfo.OrderHash, - SignedOrder: orderInfo.SignedOrder, - LastUpdated: now, - FillableTakerAssetAmount: orderInfo.FillableTakerAssetAmount, - IsRemoved: false, - IsPinned: pinned, - } - // Final expiration time check before inserting the order. 
We might have just - // changed max expiration time above. - if !pinned && orderInfo.SignedOrder.ExpirationTimeSeconds.Cmp(w.maxExpirationTime) == 1 { - // HACK(albrow): This is technically not the ideal way to respond to this - // situation, but it is a lot easier to implement for the time being. In the - // future, we should return an error and then react to that error - // differently depending on whether the order was received via RPC or from a - // peer. In the former case, we should return an RPC error response - // indicating that the order was not in fact added. In the latter case, we - // should effectively no-op, neither penalizing the peer or emitting any - // order events. For now, we respond by emitting an ADDED event immediately - // followed by a STOPPED_WATCHING event. If this order was submitted via - // RPC, the RPC client will see a response that indicates the order was - // successfully added, and then it will look like we immediately stopped - // watching it. This is not too far off from what really happened but is - // slightly inefficient. - addedEvent := &zeroex.OrderEvent{ - Timestamp: now, - OrderHash: orderInfo.OrderHash, - SignedOrder: orderInfo.SignedOrder, - FillableTakerAssetAmount: orderInfo.FillableTakerAssetAmount, - EndState: zeroex.ESOrderAdded, - } - orderEvents = append(orderEvents, addedEvent) - stoppedWatchingEvent := &zeroex.OrderEvent{ - Timestamp: now, - OrderHash: orderInfo.OrderHash, - SignedOrder: orderInfo.SignedOrder, - FillableTakerAssetAmount: orderInfo.FillableTakerAssetAmount, - EndState: zeroex.ESStoppedWatching, - } - orderEvents = append(orderEvents, stoppedWatchingEvent) - } else { - err = txn.Insert(order) - if err != nil { - if _, ok := err.(db.AlreadyExistsError); ok { - // If we're already watching the order, that's fine in this case. Don't - // return an error. 
- return orderEvents, nil - } - if _, ok := err.(db.ConflictingOperationsError); ok { - logger.WithFields(logger.Fields{ - "error": err.Error(), - "order": order, - }).Error("Failed to insert order into DB") - return orderEvents, nil - } - return orderEvents, err - } + dbOrder, err := w.orderInfoToOrderWithMetadata(orderInfo, pinned, now) + if err != nil { + return nil, err } + dbOrders = append(dbOrders, dbOrder) } - if err := txn.Commit(); err != nil { - return orderEvents, err + // TODO(albrow): Should AddOrders return the new max expiration time? + // Or is there a better way to do this? + addedOrders, removedOrders, err := w.db.AddOrders(dbOrders) + if err != nil { + return nil, err } - - for _, orderInfo := range orderInfos { - err = w.setupInMemoryOrderState(orderInfo.SignedOrder) + for _, order := range addedOrders { + err = w.setupInMemoryOrderState(order) if err != nil { return orderEvents, err } - - addedOrderEvent := &zeroex.OrderEvent{ + addedEvent := &zeroex.OrderEvent{ Timestamp: now, - OrderHash: orderInfo.OrderHash, - SignedOrder: orderInfo.SignedOrder, - FillableTakerAssetAmount: orderInfo.FillableTakerAssetAmount, + OrderHash: order.Hash, + SignedOrder: order.SignedOrder(), + FillableTakerAssetAmount: order.FillableTakerAssetAmount, EndState: zeroex.ESOrderAdded, } - orderEvents = append(orderEvents, addedOrderEvent) + orderEvents = append(orderEvents, addedEvent) } - - return orderEvents, nil -} - -func (w *Watcher) trimOrdersAndGenerateEvents() ([]*zeroex.OrderEvent, error) { - orderEvents := []*zeroex.OrderEvent{} - - targetMaxOrders := int(maxOrdersTrimRatio * float64(w.maxOrders)) - newMaxExpirationTime, removedOrders, err := w.meshDB.TrimOrdersByExpirationTime(targetMaxOrders) - if err != nil { - return orderEvents, err - } - if len(removedOrders) > 0 { - logger.WithFields(logger.Fields{ - "numOrdersRemoved": len(removedOrders), - "targetMaxOrders": targetMaxOrders, - }).Debug("removing orders to make space") - } - now := time.Now().UTC() - 
for _, removedOrder := range removedOrders { - // Fire a "STOPPED_WATCHING" event for each order that was removed. - orderEvent := &zeroex.OrderEvent{ + for _, order := range removedOrders { + stoppedWatchingEvent := &zeroex.OrderEvent{ Timestamp: now, - OrderHash: removedOrder.Hash, - SignedOrder: removedOrder.SignedOrder, - FillableTakerAssetAmount: removedOrder.FillableTakerAssetAmount, + OrderHash: order.Hash, + SignedOrder: order.SignedOrder(), + FillableTakerAssetAmount: order.FillableTakerAssetAmount, EndState: zeroex.ESStoppedWatching, } - orderEvents = append(orderEvents, orderEvent) + orderEvents = append(orderEvents, stoppedWatchingEvent) // Remove in-memory state - expirationTimestamp := time.Unix(removedOrder.SignedOrder.ExpirationTimeSeconds.Int64(), 0) - w.expirationWatcher.Remove(expirationTimestamp, removedOrder.Hash.Hex()) - err = w.removeAssetDataAddressFromEventDecoder(removedOrder.SignedOrder.MakerAssetData) + err = w.removeAssetDataAddressFromEventDecoder(order.MakerAssetData) if err != nil { // This should never happen since the same error would have happened when adding // the assetData to the EventDecoder. logger.WithFields(logger.Fields{ "error": err.Error(), - "signedOrder": removedOrder.SignedOrder, + "signedOrder": order.SignedOrder(), }).Error("Unexpected error when trying to remove an assetData from decoder") - return orderEvents, err + return nil, err } } - if newMaxExpirationTime.Cmp(w.maxExpirationTime) == -1 { - // Decrease the max expiration time to account for the fact that orders were - // removed. 
- logger.WithFields(logger.Fields{ - "oldMaxExpirationTime": w.maxExpirationTime.String(), - "newMaxExpirationTime": newMaxExpirationTime.String(), - }).Debug("decreasing max expiration time") - w.maxExpirationTime = newMaxExpirationTime - w.maxExpirationCounter.Reset(newMaxExpirationTime) - w.saveMaxExpirationTime(newMaxExpirationTime) - } + + // TODO(albrow): How to handle the edge case of orders that were not + // added due to the max expiration time changing? return orderEvents, nil } -// updateBlockHeadersStoredInDB updates the block headers stored in the DB. Since our DB txns don't support -// multiple operations involving the same entry, we make sure we only perform either an insertion or a deletion -// for each block in this method. -func updateBlockHeadersStoredInDB(miniHeadersColTxn *db.Transaction, events []*blockwatch.Event) error { - blocksToAdd := map[common.Hash]*miniheader.MiniHeader{} - blocksToRemove := map[common.Hash]*miniheader.MiniHeader{} - for _, event := range events { - blockHeader := event.BlockHeader - switch event.Type { - case blockwatch.Added: - if _, ok := blocksToAdd[blockHeader.Hash]; ok { - continue - } - if _, ok := blocksToRemove[blockHeader.Hash]; ok { - delete(blocksToRemove, blockHeader.Hash) - } - blocksToAdd[blockHeader.Hash] = blockHeader - case blockwatch.Removed: - if _, ok := blocksToAdd[blockHeader.Hash]; ok { - delete(blocksToAdd, blockHeader.Hash) - } - if _, ok := blocksToRemove[blockHeader.Hash]; ok { - continue - } - blocksToRemove[blockHeader.Hash] = blockHeader - default: - return fmt.Errorf("Unrecognized block event type encountered: %d", event.Type) - } - } - - for _, blockHeader := range blocksToAdd { - if err := miniHeadersColTxn.Insert(blockHeader); err != nil { - if _, ok := err.(db.AlreadyExistsError); !ok { - logger.WithFields(logger.Fields{ - "error": err.Error(), - "hash": blockHeader.Hash, - "number": blockHeader.Number, - }).Error("Failed to insert miniHeaders") - } - } +func (w *Watcher) 
orderInfoToOrderWithMetadata(orderInfo *ordervalidator.AcceptedOrderInfo, pinned bool, now time.Time) (*types.OrderWithMetadata, error) { + parsedMakerAssetData, err := db.ParseContractAddressesAndTokenIdsFromAssetData(w.assetDataDecoder, orderInfo.SignedOrder.MakerAssetData, w.contractAddresses) + if err != nil { + return nil, err } - for _, blockHeader := range blocksToRemove { - if err := miniHeadersColTxn.Delete(blockHeader.ID()); err != nil { - if _, ok := err.(db.NotFoundError); !ok { - logger.WithFields(logger.Fields{ - "error": err.Error(), - "hash": blockHeader.Hash, - "number": blockHeader.Number, - }).Error("Failed to delete miniHeaders") - } - } + parsedMakerFeeAssetData, err := db.ParseContractAddressesAndTokenIdsFromAssetData(w.assetDataDecoder, orderInfo.SignedOrder.MakerFeeAssetData, w.contractAddresses) + if err != nil { + return nil, err } - - return nil + return &types.OrderWithMetadata{ + Hash: orderInfo.OrderHash, + ChainID: orderInfo.SignedOrder.ChainID, + ExchangeAddress: orderInfo.SignedOrder.ExchangeAddress, + MakerAddress: orderInfo.SignedOrder.MakerAddress, + MakerAssetData: orderInfo.SignedOrder.MakerAssetData, + MakerFeeAssetData: orderInfo.SignedOrder.MakerFeeAssetData, + MakerAssetAmount: orderInfo.SignedOrder.MakerAssetAmount, + MakerFee: orderInfo.SignedOrder.MakerFee, + TakerAddress: orderInfo.SignedOrder.TakerAddress, + TakerAssetData: orderInfo.SignedOrder.TakerAssetData, + TakerFeeAssetData: orderInfo.SignedOrder.TakerFeeAssetData, + TakerAssetAmount: orderInfo.SignedOrder.TakerAssetAmount, + TakerFee: orderInfo.SignedOrder.TakerFee, + SenderAddress: orderInfo.SignedOrder.SenderAddress, + FeeRecipientAddress: orderInfo.SignedOrder.FeeRecipientAddress, + ExpirationTimeSeconds: orderInfo.SignedOrder.ExpirationTimeSeconds, + Salt: orderInfo.SignedOrder.Salt, + Signature: orderInfo.SignedOrder.Signature, + IsRemoved: false, + IsPinned: pinned, + LastUpdated: now, + ParsedMakerAssetData: parsedMakerAssetData, + 
ParsedMakerFeeAssetData: parsedMakerFeeAssetData, + FillableTakerAssetAmount: orderInfo.FillableTakerAssetAmount, + }, nil } // MaxExpirationTime returns the current maximum expiration time for incoming @@ -1119,28 +989,22 @@ func (w *Watcher) MaxExpirationTime() *big.Int { return w.maxExpirationTime } -func (w *Watcher) setupInMemoryOrderState(signedOrder *zeroex.SignedOrder) error { - orderHash, err := signedOrder.ComputeOrderHash() - if err != nil { - return err - } - w.eventDecoder.AddKnownExchange(signedOrder.ExchangeAddress) +// TODO(albrow): All in-memory state can be removed. +func (w *Watcher) setupInMemoryOrderState(order *types.OrderWithMetadata) error { + w.eventDecoder.AddKnownExchange(order.ExchangeAddress) // Add MakerAssetData and MakerFeeAssetData to EventDecoder - err = w.addAssetDataAddressToEventDecoder(signedOrder.MakerAssetData) + err := w.addAssetDataAddressToEventDecoder(order.MakerAssetData) if err != nil { return err } - if signedOrder.MakerFee.Cmp(big.NewInt(0)) == 1 { - err = w.addAssetDataAddressToEventDecoder(signedOrder.MakerFeeAssetData) + if order.MakerFee.Cmp(big.NewInt(0)) == 1 { + err = w.addAssetDataAddressToEventDecoder(order.MakerFeeAssetData) if err != nil { return err } } - expirationTimestamp := time.Unix(signedOrder.ExpirationTimeSeconds.Int64(), 0) - w.expirationWatcher.Add(expirationTimestamp, orderHash.Hex()) - return nil } @@ -1152,35 +1016,51 @@ func (w *Watcher) Subscribe(sink chan<- []*zeroex.OrderEvent) event.Subscription return w.orderScope.Track(w.orderFeed.Subscribe(sink)) } -func (w *Watcher) findOrder(orderHash common.Hash) *meshdb.Order { - order := meshdb.Order{} - err := w.meshDB.Orders.FindByID(orderHash.Bytes(), &order) +func (w *Watcher) findOrder(orderHash common.Hash) *types.OrderWithMetadata { + order, err := w.db.GetOrder(orderHash) if err != nil { - if _, ok := err.(db.NotFoundError); ok { + if err == db.ErrNotFound { // short-circuit. 
We expect to receive events from orders we aren't actively tracking return nil } logger.WithFields(logger.Fields{ "error": err.Error(), "orderHash": orderHash, - }).Warning("Unexpected error using FindByID for order") + }).Warning("Unexpected error from db.GetOrder") return nil } - return &order + return order } // findOrdersByTokenAddressAndTokenID finds and returns all orders that have // either a makerAsset or a makerFeeAsset matching the given tokenAddress and // tokenID. -func (w *Watcher) findOrdersByTokenAddressAndTokenID(makerAddress, tokenAddress common.Address, tokenID *big.Int) ([]*meshdb.Order, error) { - ordersWithAffectedMakerAsset, err := w.meshDB.FindOrdersByMakerAddressTokenAddressAndTokenID(makerAddress, tokenAddress, tokenID) +func (w *Watcher) findOrdersByTokenAddressAndTokenID(makerAddress, tokenAddress common.Address, tokenID *big.Int) ([]*types.OrderWithMetadata, error) { + ordersWithAffectedMakerAsset, err := w.db.FindOrders(&db.OrderQuery{ + Filters: []db.OrderFilter{ + { + Field: db.OFMakerAddress, + Kind: db.Equal, + Value: makerAddress, + }, + db.MakerAssetIncludesTokenAddressAndTokenID(tokenAddress, tokenID), + }, + }) if err != nil { logger.WithFields(logger.Fields{ "error": err.Error(), }).Error("unexpected query error encountered") return nil, err } - ordersWithAffectedMakerFeeAsset, err := w.meshDB.FindOrdersByMakerAddressMakerFeeAssetAddressAndTokenID(makerAddress, tokenAddress, tokenID) + ordersWithAffectedMakerFeeAsset, err := w.db.FindOrders(&db.OrderQuery{ + Filters: []db.OrderFilter{ + { + Field: db.OFMakerAddress, + Kind: db.Equal, + Value: makerAddress, + }, + db.MakerFeeAssetIncludesTokenAddressAndTokenID(tokenAddress, tokenID)}, + }) if err != nil { logger.WithFields(logger.Fields{ "error": err.Error(), @@ -1191,10 +1071,95 @@ func (w *Watcher) findOrdersByTokenAddressAndTokenID(makerAddress, tokenAddress return append(ordersWithAffectedMakerAsset, ordersWithAffectedMakerFeeAsset...), nil } +// findOrdersByTokenAddress 
finds and returns all orders that have +// either a makerAsset or a makerFeeAsset matching the given tokenAddress and +// any tokenID (including null). +func (w *Watcher) findOrdersByTokenAddress(makerAddress, tokenAddress common.Address) ([]*types.OrderWithMetadata, error) { + ordersWithAffectedMakerAsset, err := w.db.FindOrders(&db.OrderQuery{ + Filters: []db.OrderFilter{ + { + Field: db.OFMakerAddress, + Kind: db.Equal, + Value: makerAddress, + }, + db.MakerAssetIncludesTokenAddress(tokenAddress), + }, + }) + if err != nil { + logger.WithFields(logger.Fields{ + "error": err.Error(), + }).Error("unexpected query error encountered") + return nil, err + } + ordersWithAffectedMakerFeeAsset, err := w.db.FindOrders(&db.OrderQuery{ + Filters: []db.OrderFilter{ + { + Field: db.OFMakerAddress, + Kind: db.Equal, + Value: makerAddress, + }, + db.MakerFeeAssetIncludesTokenAddress(tokenAddress)}, + }) + if err != nil { + logger.WithFields(logger.Fields{ + "error": err.Error(), + }).Error("unexpected query error encountered") + return nil, err + } + + return append(ordersWithAffectedMakerAsset, ordersWithAffectedMakerFeeAsset...), nil +} + +// findOrdersToExpire returns all orders with an expiration time less than or equal to the latest +// block timestamp that have not already been removed. +func (w *Watcher) findOrdersToExpire(latestBlockTimestamp time.Time) ([]*types.OrderWithMetadata, error) { + return w.db.FindOrders(&db.OrderQuery{ + Filters: []db.OrderFilter{ + { + Field: db.OFExpirationTimeSeconds, + Kind: db.LessOrEqual, + Value: big.NewInt(latestBlockTimestamp.Unix()), + }, + { + Field: db.OFIsRemoved, + Kind: db.Equal, + Value: false, + }, + }, + }) +} + +// findOrdersToUnexpire returns all orders that: +// +// 1. have an expiration time greater than the latest block timestamp +// 2. were previously removed +// 3. 
have a non-zero FillableTakerAssetAmount +// +func (w *Watcher) findOrdersToUnexpire(latestBlockTimestamp time.Time) ([]*types.OrderWithMetadata, error) { + return w.db.FindOrders(&db.OrderQuery{ + Filters: []db.OrderFilter{ + { + Field: db.OFExpirationTimeSeconds, + Kind: db.Greater, + Value: big.NewInt(latestBlockTimestamp.Unix()), + }, + { + Field: db.OFIsRemoved, + Kind: db.Equal, + Value: true, + }, + { + Field: db.OFFillableTakerAssetAmount, + Kind: db.NotEqual, + Value: 0, + }, + }, + }) +} + func (w *Watcher) convertValidationResultsIntoOrderEvents( - ordersColTxn *db.Transaction, validationResults *ordervalidator.ValidationResults, - orderHashToDBOrder map[common.Hash]*meshdb.Order, + orderHashToDBOrder map[common.Hash]*types.OrderWithMetadata, orderHashToEvents map[common.Hash][]*zeroex.ContractEvent, validationBlockTimestamp time.Time, ) ([]*zeroex.OrderEvent, error) { @@ -1217,27 +1182,27 @@ func (w *Watcher) convertValidationResultsIntoOrderEvents( // A previous event caused this order to be removed from DB because it's // fillableAmount became 0, but it has now been revived (e.g., block re-org // causes order fill txn to get reverted). We need to re-add order and emit an event. 
- w.rewatchOrder(ordersColTxn, order, acceptedOrderInfo.FillableTakerAssetAmount) + w.rewatchOrder(order, acceptedOrderInfo.FillableTakerAssetAmount) orderEvent := &zeroex.OrderEvent{ Timestamp: validationBlockTimestamp, OrderHash: acceptedOrderInfo.OrderHash, - SignedOrder: order.SignedOrder, + SignedOrder: order.SignedOrder(), FillableTakerAssetAmount: acceptedOrderInfo.FillableTakerAssetAmount, EndState: zeroex.ESOrderAdded, ContractEvents: orderHashToEvents[order.Hash], } orderEvents = append(orderEvents, orderEvent) } else { - expiration := time.Unix(order.SignedOrder.ExpirationTimeSeconds.Int64(), 0) + expiration := time.Unix(order.SignedOrder().ExpirationTimeSeconds.Int64(), 0) if oldFillableAmount.Cmp(newFillableAmount) == 0 { // If order was previously expired, check if it has become unexpired if order.IsRemoved && oldFillableAmount.Cmp(big.NewInt(0)) != 0 && validationBlockTimestamp.Before(expiration) { - w.rewatchOrder(ordersColTxn, order, order.FillableTakerAssetAmount) + w.rewatchOrder(order, nil) orderEvent := &zeroex.OrderEvent{ Timestamp: validationBlockTimestamp, OrderHash: order.Hash, - SignedOrder: order.SignedOrder, + SignedOrder: order.SignedOrder(), FillableTakerAssetAmount: order.FillableTakerAssetAmount, EndState: zeroex.ESOrderUnexpired, } @@ -1249,24 +1214,23 @@ func (w *Watcher) convertValidationResultsIntoOrderEvents( if oldFillableAmount.Cmp(big.NewInt(0)) == 1 && oldAmountIsMoreThenNewAmount { // If order was previously expired, check if it has become unexpired if order.IsRemoved && oldFillableAmount.Cmp(big.NewInt(0)) != 0 && validationBlockTimestamp.Before(expiration) { - w.rewatchOrder(ordersColTxn, order, newFillableAmount) + w.rewatchOrder(order, newFillableAmount) orderEvent := &zeroex.OrderEvent{ Timestamp: validationBlockTimestamp, OrderHash: order.Hash, - SignedOrder: order.SignedOrder, + SignedOrder: order.SignedOrder(), FillableTakerAssetAmount: order.FillableTakerAssetAmount, EndState: zeroex.ESOrderUnexpired, } orderEvents 
= append(orderEvents, orderEvent) } else { - order.FillableTakerAssetAmount = newFillableAmount - w.updateOrderDBEntry(ordersColTxn, order) + w.updateOrderFillableTakerAssetAmount(order, newFillableAmount) } // Order was filled, emit event orderEvent := &zeroex.OrderEvent{ Timestamp: validationBlockTimestamp, OrderHash: acceptedOrderInfo.OrderHash, - SignedOrder: order.SignedOrder, + SignedOrder: order.SignedOrder(), EndState: zeroex.ESOrderFilled, FillableTakerAssetAmount: acceptedOrderInfo.FillableTakerAssetAmount, ContractEvents: orderHashToEvents[order.Hash], @@ -1276,23 +1240,22 @@ func (w *Watcher) convertValidationResultsIntoOrderEvents( // The order is now fillable for more then it was before. E.g.: A fill txn reverted (block-reorg) // If order was previously expired, check if it has become unexpired if order.IsRemoved && oldFillableAmount.Cmp(big.NewInt(0)) != 0 && validationBlockTimestamp.Before(expiration) { - w.rewatchOrder(ordersColTxn, order, newFillableAmount) + w.rewatchOrder(order, newFillableAmount) orderEvent := &zeroex.OrderEvent{ Timestamp: validationBlockTimestamp, OrderHash: order.Hash, - SignedOrder: order.SignedOrder, + SignedOrder: order.SignedOrder(), FillableTakerAssetAmount: order.FillableTakerAssetAmount, EndState: zeroex.ESOrderUnexpired, } orderEvents = append(orderEvents, orderEvent) } else { - order.FillableTakerAssetAmount = newFillableAmount - w.updateOrderDBEntry(ordersColTxn, order) + w.updateOrderFillableTakerAssetAmount(order, newFillableAmount) } orderEvent := &zeroex.OrderEvent{ Timestamp: validationBlockTimestamp, OrderHash: acceptedOrderInfo.OrderHash, - SignedOrder: order.SignedOrder, + SignedOrder: order.SignedOrder(), EndState: zeroex.ESOrderFillabilityIncreased, FillableTakerAssetAmount: acceptedOrderInfo.FillableTakerAssetAmount, ContractEvents: orderHashToEvents[order.Hash], @@ -1320,7 +1283,7 @@ func (w *Watcher) convertValidationResultsIntoOrderEvents( // If the oldFillableAmount was already 0, this order is 
already flagged for removal. } else { // If oldFillableAmount > 0, it got fullyFilled, cancelled, expired or unfunded - w.unwatchOrder(ordersColTxn, order, big.NewInt(0)) + w.unwatchOrder(order, big.NewInt(0)) endState, ok := ordervalidator.ConvertRejectOrderCodeToOrderEventEndState(rejectedOrderInfo.Status) if !ok { err := fmt.Errorf("no OrderEventEndState corresponding to RejectedOrderStatus: %q", rejectedOrderInfo.Status) @@ -1349,8 +1312,7 @@ func (w *Watcher) convertValidationResultsIntoOrderEvents( func (w *Watcher) generateOrderEventsIfChanged( ctx context.Context, - ordersColTxn *db.Transaction, - orderHashToDBOrder map[common.Hash]*meshdb.Order, + orderHashToDBOrder map[common.Hash]*types.OrderWithMetadata, orderHashToEvents map[common.Hash][]*zeroex.ContractEvent, validationBlockNumber *big.Int, validationBlockTimestamp time.Time, @@ -1358,12 +1320,12 @@ func (w *Watcher) generateOrderEventsIfChanged( signedOrders := []*zeroex.SignedOrder{} for _, order := range orderHashToDBOrder { if order.IsRemoved && time.Since(order.LastUpdated) > permanentlyDeleteAfter { - if err := w.permanentlyDeleteOrder(ordersColTxn, order); err != nil { + if err := w.permanentlyDeleteOrder(order); err != nil { return nil, err } continue } - signedOrders = append(signedOrders, order.SignedOrder) + signedOrders = append(signedOrders, order.SignedOrder()) } if len(signedOrders) == 0 { return nil, nil @@ -1372,7 +1334,7 @@ func (w *Watcher) generateOrderEventsIfChanged( validationResults := w.orderValidator.BatchValidate(ctx, signedOrders, areNewOrders, validationBlockNumber) return w.convertValidationResultsIntoOrderEvents( - ordersColTxn, validationResults, orderHashToDBOrder, orderHashToEvents, validationBlockTimestamp, + validationResults, orderHashToDBOrder, orderHashToEvents, validationBlockTimestamp, ) } @@ -1433,14 +1395,14 @@ func (w *Watcher) ValidateAndStoreValidOrders(ctx context.Context, orders []*zer return results, nil } -func (w *Watcher) onchainOrderValidation(ctx 
context.Context, orders []*zeroex.SignedOrder) (*miniheader.MiniHeader, *ordervalidator.ValidationResults, error) { +func (w *Watcher) onchainOrderValidation(ctx context.Context, orders []*zeroex.SignedOrder) (*types.MiniHeader, *ordervalidator.ValidationResults, error) { // HACK(fabio): While we wait for EIP-1898 support in Parity, we have no choice but to do the `eth_call` // at the latest known block _number_. As outlined in the `Rationale` section of EIP-1898, this approach cannot account // for the block being re-org'd out before the `eth_call` and then back in before the `eth_getBlockByNumber` // call (an unlikely but possible situation leading to an incorrect view of the world for these orders). // Unfortunately, this is the best we can do until EIP-1898 support in Parity. // Source: https://github.com/ethereum/EIPs/blob/master/EIPS/eip-1898.md#rationale - validationBlock, err := w.meshDB.FindLatestMiniHeader() + latestMiniHeader, err := w.getLatestBlock() if err != nil { return nil, nil, err } @@ -1448,8 +1410,8 @@ func (w *Watcher) onchainOrderValidation(ctx context.Context, orders []*zeroex.S ctx, cancel := context.WithTimeout(ctx, 1*time.Minute) defer cancel() areNewOrders := true - zeroexResults := w.orderValidator.BatchValidate(ctx, orders, areNewOrders, validationBlock.Number) - return validationBlock, zeroexResults, nil + zeroexResults := w.orderValidator.BatchValidate(ctx, orders, areNewOrders, latestMiniHeader.Number) + return latestMiniHeader, zeroexResults, nil } func (w *Watcher) meshSpecificOrderValidation(orders []*zeroex.SignedOrder, chainID int) (*ordervalidator.ValidationResults, []*zeroex.SignedOrder, error) { @@ -1538,14 +1500,13 @@ func (w *Watcher) meshSpecificOrderValidation(orders []*zeroex.SignedOrder, chai } // Check if order is already stored in DB - var dbOrder meshdb.Order - err = w.meshDB.Orders.FindByID(orderHash.Bytes(), &dbOrder) + dbOrder, err := w.db.GetOrder(orderHash) if err != nil { - if _, ok := err.(db.NotFoundError); 
!ok { + if err != db.ErrNotFound { logger.WithField("error", err).Error("could not check if order was already stored") return nil, nil, err } - // If the error is a db.NotFoundError, it just means the order is not currently stored in + // If the error is db.ErrNotFound, it just means the order is not currently stored in // the database. There's nothing else in the database to check, so we can continue. } else { // If stored but flagged for removal, reject it @@ -1586,13 +1547,12 @@ func validateOrderSize(order *zeroex.SignedOrder) error { return nil } -type orderUpdater interface { - Update(model db.Model) error -} - -func (w *Watcher) updateOrderDBEntry(u orderUpdater, order *meshdb.Order) { - order.LastUpdated = time.Now().UTC() - err := u.Update(order) +func (w *Watcher) updateOrderFillableTakerAssetAmount(order *types.OrderWithMetadata, newFillableTakerAssetAmount *big.Int) { + err := w.db.UpdateOrder(order.Hash, func(orderToUpdate *types.OrderWithMetadata) (*types.OrderWithMetadata, error) { + orderToUpdate.LastUpdated = time.Now().UTC() + orderToUpdate.FillableTakerAssetAmount = newFillableTakerAssetAmount + return orderToUpdate, nil + }) if err != nil { logger.WithFields(logger.Fields{ "error": err.Error(), @@ -1601,61 +1561,51 @@ func (w *Watcher) updateOrderDBEntry(u orderUpdater, order *meshdb.Order) { } } -func (w *Watcher) rewatchOrder(u orderUpdater, order *meshdb.Order, fillableTakerAssetAmount *big.Int) { - order.IsRemoved = false - order.LastUpdated = time.Now().UTC() - order.FillableTakerAssetAmount = fillableTakerAssetAmount - err := u.Update(order) +func (w *Watcher) rewatchOrder(order *types.OrderWithMetadata, newFillableTakerAssetAmount *big.Int) { + err := w.db.UpdateOrder(order.Hash, func(orderToUpdate *types.OrderWithMetadata) (*types.OrderWithMetadata, error) { + orderToUpdate.IsRemoved = false + orderToUpdate.LastUpdated = time.Now().UTC() + if newFillableTakerAssetAmount != nil { + orderToUpdate.FillableTakerAssetAmount = 
newFillableTakerAssetAmount + } + return orderToUpdate, nil + }) if err != nil { logger.WithFields(logger.Fields{ "error": err.Error(), "order": order, }).Error("Failed to update order") } - - // Re-add order to expiration watcher - expirationTimestamp := time.Unix(order.SignedOrder.ExpirationTimeSeconds.Int64(), 0) - w.expirationWatcher.Add(expirationTimestamp, order.Hash.Hex()) } -func (w *Watcher) unwatchOrder(u orderUpdater, order *meshdb.Order, newFillableAmount *big.Int) { - order.IsRemoved = true - order.LastUpdated = time.Now().UTC() - order.FillableTakerAssetAmount = newFillableAmount - err := u.Update(order) +func (w *Watcher) unwatchOrder(order *types.OrderWithMetadata, newFillableAmount *big.Int) { + err := w.db.UpdateOrder(order.Hash, func(orderToUpdate *types.OrderWithMetadata) (*types.OrderWithMetadata, error) { + orderToUpdate.IsRemoved = true + orderToUpdate.LastUpdated = time.Now().UTC() + if newFillableAmount != nil { + orderToUpdate.FillableTakerAssetAmount = newFillableAmount + } + return orderToUpdate, nil + }) if err != nil { logger.WithFields(logger.Fields{ "error": err.Error(), "order": order, }).Error("Failed to update order") } - - expirationTimestamp := time.Unix(order.SignedOrder.ExpirationTimeSeconds.Int64(), 0) - w.expirationWatcher.Remove(expirationTimestamp, order.Hash.Hex()) } type orderDeleter interface { Delete(id []byte) error } -func (w *Watcher) permanentlyDeleteOrder(deleter orderDeleter, order *meshdb.Order) error { - err := deleter.Delete(order.Hash.Bytes()) - if err != nil { - if _, ok := err.(db.ConflictingOperationsError); ok { - logger.WithFields(logger.Fields{ - "error": err.Error(), - "order": order, - }).Error("Failed to permanently delete order") - return nil - } - if _, ok := err.(db.NotFoundError); ok { - return nil // Already deleted. Noop. 
- } +func (w *Watcher) permanentlyDeleteOrder(order *types.OrderWithMetadata) error { + if err := w.db.DeleteOrder(order.Hash); err != nil { return err } // After permanently deleting an order, we also remove it's assetData from the Decoder - err = w.removeAssetDataAddressFromEventDecoder(order.SignedOrder.MakerAssetData) + err := w.removeAssetDataAddressFromEventDecoder(order.MakerAssetData) if err != nil { // This should never happen since the same error would have happened when adding // the assetData to the EventDecoder. @@ -1815,18 +1765,8 @@ func (w *Watcher) removeAssetDataAddressFromEventDecoder(assetData []byte) error return nil } -func (w *Watcher) decreaseMaxExpirationTimeIfNeeded() ([]*zeroex.OrderEvent, error) { - orderEvents := []*zeroex.OrderEvent{} - if orderCount, err := w.meshDB.Orders.Count(); err != nil { - return orderEvents, err - } else if orderCount+1 > w.maxOrders { - return w.trimOrdersAndGenerateEvents() - } - return orderEvents, nil -} - func (w *Watcher) increaseMaxExpirationTimeIfPossible() error { - if orderCount, err := w.meshDB.Orders.Count(); err != nil { + if orderCount, err := w.db.CountOrders(nil); err != nil { return err } else if orderCount < w.maxOrders { // We have enough space for new orders. Set the new max expiration time to the @@ -1847,7 +1787,7 @@ func (w *Watcher) increaseMaxExpirationTimeIfPossible() error { // saveMaxExpirationTime saves the new max expiration time in the database. 
func (w *Watcher) saveMaxExpirationTime(maxExpirationTime *big.Int) { - if err := w.meshDB.UpdateMetadata(func(metadata meshdb.Metadata) meshdb.Metadata { + if err := w.db.UpdateMetadata(func(metadata *types.Metadata) *types.Metadata { metadata.MaxExpirationTime = maxExpirationTime return metadata }); err != nil { @@ -1880,5 +1820,5 @@ func (w *Watcher) WaitForAtLeastOneBlockToBeProcessed(ctx context.Context) error type logWithType struct { Type string - Log types.Log + Log ethtypes.Log } diff --git a/zeroex/orderwatch/order_watcher_test.go b/zeroex/orderwatch/order_watcher_test.go index 6fc3e2698..2bcb4e221 100644 --- a/zeroex/orderwatch/order_watcher_test.go +++ b/zeroex/orderwatch/order_watcher_test.go @@ -9,15 +9,14 @@ import ( "testing" "time" + "github.com/0xProject/0x-mesh/common/types" "github.com/0xProject/0x-mesh/constants" + "github.com/0xProject/0x-mesh/db" "github.com/0xProject/0x-mesh/ethereum" "github.com/0xProject/0x-mesh/ethereum/blockwatch" "github.com/0xProject/0x-mesh/ethereum/ethrpcclient" - "github.com/0xProject/0x-mesh/ethereum/miniheader" "github.com/0xProject/0x-mesh/ethereum/ratelimit" - "github.com/0xProject/0x-mesh/ethereum/simplestack" "github.com/0xProject/0x-mesh/ethereum/wrappers" - "github.com/0xProject/0x-mesh/meshdb" "github.com/0xProject/0x-mesh/scenario" "github.com/0xProject/0x-mesh/scenario/orderopts" "github.com/0xProject/0x-mesh/zeroex" @@ -25,11 +24,10 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" + ethtypes "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/rpc" ethrpc "github.com/ethereum/go-ethereum/rpc" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -119,13 +117,16 @@ func TestOrderWatcherUnfundedInsufficientERC20Balance(t *testing.T) { teardownSubTest 
:= setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) - signedOrder := scenario.NewSignedTestOrder(t, orderopts.SetupMakerState(true)) - ctx, cancelFn := context.WithCancel(context.Background()) - defer cancelFn() - blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, meshDB, signedOrder) + signedOrder := scenario.NewSignedTestOrder(t, + orderopts.SetupMakerState(true), + orderopts.MakerAssetData(scenario.ZRXAssetData), + ) + blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, database, signedOrder) // Transfer makerAsset out of maker address opts := &bind.TransactOpts{ @@ -144,8 +145,7 @@ func TestOrderWatcherUnfundedInsufficientERC20Balance(t *testing.T) { orderEvent := orderEvents[0] assert.Equal(t, zeroex.ESOrderBecameUnfunded, orderEvent.EndState) - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, 1) assert.Equal(t, orderEvent.OrderHash, orders[0].Hash) @@ -161,7 +161,9 @@ func TestOrderWatcherUnfundedInsufficientERC20BalanceForMakerFee(t *testing.T) { teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) makerAssetData := scenario.GetDummyERC721AssetData(big.NewInt(1)) @@ -173,9 +175,7 @@ func TestOrderWatcherUnfundedInsufficientERC20BalanceForMakerFee(t *testing.T) { orderopts.MakerFeeAssetData(scenario.WETHAssetData), orderopts.MakerFee(wethFeeAmount), ) - ctx, cancelFn := context.WithCancel(context.Background()) - defer cancelFn() - 
blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, meshDB, signedOrder) + blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, database, signedOrder) // Transfer makerAsset out of maker address opts := &bind.TransactOpts{ @@ -194,8 +194,7 @@ func TestOrderWatcherUnfundedInsufficientERC20BalanceForMakerFee(t *testing.T) { orderEvent := orderEvents[0] assert.Equal(t, zeroex.ESOrderBecameUnfunded, orderEvent.EndState) - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, 1) assert.Equal(t, orderEvent.OrderHash, orders[0].Hash) @@ -210,7 +209,9 @@ func TestOrderWatcherUnfundedInsufficientERC721Balance(t *testing.T) { teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) tokenID := big.NewInt(1) @@ -220,9 +221,7 @@ func TestOrderWatcherUnfundedInsufficientERC721Balance(t *testing.T) { orderopts.MakerAssetAmount(big.NewInt(1)), orderopts.MakerAssetData(makerAssetData), ) - ctx, cancelFn := context.WithCancel(context.Background()) - defer cancelFn() - blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, meshDB, signedOrder) + blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, database, signedOrder) // Transfer makerAsset out of maker address opts := &bind.TransactOpts{ @@ -241,8 +240,7 @@ func TestOrderWatcherUnfundedInsufficientERC721Balance(t *testing.T) { orderEvent := orderEvents[0] assert.Equal(t, zeroex.ESOrderBecameUnfunded, orderEvent.EndState) - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, 1) assert.Equal(t, 
orderEvent.OrderHash, orders[0].Hash) @@ -258,7 +256,9 @@ func TestOrderWatcherUnfundedInsufficientERC721Allowance(t *testing.T) { teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) tokenID := big.NewInt(1) @@ -268,9 +268,7 @@ func TestOrderWatcherUnfundedInsufficientERC721Allowance(t *testing.T) { orderopts.MakerAssetAmount(big.NewInt(1)), orderopts.MakerAssetData(makerAssetData), ) - ctx, cancelFn := context.WithCancel(context.Background()) - defer cancelFn() - blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, meshDB, signedOrder) + blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, database, signedOrder) // Remove Maker's NFT approval to ERC721Proxy. We do this by setting the // operator/spender to the null address. 
@@ -290,8 +288,7 @@ func TestOrderWatcherUnfundedInsufficientERC721Allowance(t *testing.T) { orderEvent := orderEvents[0] assert.Equal(t, zeroex.ESOrderBecameUnfunded, orderEvent.EndState) - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, 1) assert.Equal(t, orderEvent.OrderHash, orders[0].Hash) @@ -307,7 +304,9 @@ func TestOrderWatcherUnfundedInsufficientERC1155Allowance(t *testing.T) { teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) makerAssetData := scenario.GetDummyERC1155AssetData(t, []*big.Int{big.NewInt(1)}, []*big.Int{big.NewInt(100)}) @@ -316,9 +315,7 @@ func TestOrderWatcherUnfundedInsufficientERC1155Allowance(t *testing.T) { orderopts.MakerAssetAmount(big.NewInt(1)), orderopts.MakerAssetData(makerAssetData), ) - ctx, cancelFn := context.WithCancel(context.Background()) - defer cancelFn() - blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, meshDB, signedOrder) + blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, database, signedOrder) // Remove Maker's ERC1155 approval to ERC1155Proxy opts := &bind.TransactOpts{ @@ -337,8 +334,7 @@ func TestOrderWatcherUnfundedInsufficientERC1155Allowance(t *testing.T) { orderEvent := orderEvents[0] assert.Equal(t, zeroex.ESOrderBecameUnfunded, orderEvent.EndState) - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, 1) assert.Equal(t, orderEvent.OrderHash, orders[0].Hash) @@ -354,7 +350,9 @@ func TestOrderWatcherUnfundedInsufficientERC1155Balance(t *testing.T) { teardownSubTest := setupSubTest(t) defer 
teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) tokenID := big.NewInt(1) @@ -365,9 +363,7 @@ func TestOrderWatcherUnfundedInsufficientERC1155Balance(t *testing.T) { orderopts.MakerAssetAmount(big.NewInt(1)), orderopts.MakerAssetData(makerAssetData), ) - ctx, cancelFn := context.WithCancel(context.Background()) - defer cancelFn() - blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, meshDB, signedOrder) + blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, database, signedOrder) // Reduce Maker's ERC1155 balance opts := &bind.TransactOpts{ @@ -386,8 +382,7 @@ func TestOrderWatcherUnfundedInsufficientERC1155Balance(t *testing.T) { orderEvent := orderEvents[0] assert.Equal(t, zeroex.ESOrderBecameUnfunded, orderEvent.EndState) - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, 1) assert.Equal(t, orderEvent.OrderHash, orders[0].Hash) @@ -403,16 +398,16 @@ func TestOrderWatcherUnfundedInsufficientERC20Allowance(t *testing.T) { teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) signedOrder := scenario.NewSignedTestOrder(t, orderopts.SetupMakerState(true), orderopts.MakerAssetData(scenario.ZRXAssetData), ) - ctx, cancelFn := context.WithCancel(context.Background()) - defer cancelFn() - blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, meshDB, signedOrder) + blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, 
database, signedOrder) // Remove Maker's ZRX approval to ERC20Proxy opts := &bind.TransactOpts{ @@ -431,8 +426,7 @@ func TestOrderWatcherUnfundedInsufficientERC20Allowance(t *testing.T) { orderEvent := orderEvents[0] assert.Equal(t, zeroex.ESOrderBecameUnfunded, orderEvent.EndState) - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, 1) assert.Equal(t, orderEvent.OrderHash, orders[0].Hash) @@ -448,7 +442,9 @@ func TestOrderWatcherUnfundedThenFundedAgain(t *testing.T) { teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) signedOrder := scenario.NewSignedTestOrder(t, @@ -456,9 +452,7 @@ func TestOrderWatcherUnfundedThenFundedAgain(t *testing.T) { orderopts.MakerAssetData(scenario.ZRXAssetData), orderopts.TakerAssetData(scenario.WETHAssetData), ) - ctx, cancelFn := context.WithCancel(context.Background()) - defer cancelFn() - blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, meshDB, signedOrder) + blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, database, signedOrder) // Transfer makerAsset out of maker address opts := &bind.TransactOpts{ @@ -477,8 +471,7 @@ func TestOrderWatcherUnfundedThenFundedAgain(t *testing.T) { orderEvent := orderEvents[0] assert.Equal(t, zeroex.ESOrderBecameUnfunded, orderEvent.EndState) - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, 1) assert.Equal(t, orderEvent.OrderHash, orders[0].Hash) @@ -503,8 +496,7 @@ func TestOrderWatcherUnfundedThenFundedAgain(t *testing.T) { orderEvent = orderEvents[0] assert.Equal(t, 
zeroex.ESOrderAdded, orderEvent.EndState) - var newOrders []*meshdb.Order - err = meshDB.Orders.FindAll(&newOrders) + newOrders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, newOrders, 1) assert.Equal(t, orderEvent.OrderHash, newOrders[0].Hash) @@ -520,7 +512,9 @@ func TestOrderWatcherNoChange(t *testing.T) { teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) signedOrder := scenario.NewSignedTestOrder(t, @@ -528,12 +522,9 @@ func TestOrderWatcherNoChange(t *testing.T) { orderopts.MakerAssetData(scenario.ZRXAssetData), orderopts.TakerAssetData(scenario.WETHAssetData), ) - ctx, cancelFn := context.WithCancel(context.Background()) - defer cancelFn() - blockWatcher, _ := setupOrderWatcherScenario(ctx, t, ethClient, meshDB, signedOrder) + blockWatcher, _ := setupOrderWatcherScenario(ctx, t, ethClient, database, signedOrder) - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, 1) dbOrder := orders[0] @@ -552,8 +543,7 @@ func TestOrderWatcherNoChange(t *testing.T) { err = blockWatcher.SyncToLatestBlock() require.NoError(t, err) - var newOrders []*meshdb.Order - err = meshDB.Orders.FindAll(&newOrders) + newOrders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, newOrders, 1) require.NotEqual(t, dbOrder.LastUpdated, newOrders[0].Hash) @@ -569,7 +559,9 @@ func TestOrderWatcherWETHWithdrawAndDeposit(t *testing.T) { teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + database, err := db.New(ctx, 
db.TestOptions()) require.NoError(t, err) signedOrder := scenario.NewSignedTestOrder(t, @@ -577,9 +569,7 @@ func TestOrderWatcherWETHWithdrawAndDeposit(t *testing.T) { orderopts.MakerAssetData(scenario.WETHAssetData), orderopts.TakerAssetData(scenario.ZRXAssetData), ) - ctx, cancelFn := context.WithCancel(context.Background()) - defer cancelFn() - blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, meshDB, signedOrder) + blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, database, signedOrder) // Withdraw maker's WETH (i.e. decrease WETH balance) // HACK(fabio): For some reason the txn fails with "out of gas" error with the @@ -602,8 +592,7 @@ func TestOrderWatcherWETHWithdrawAndDeposit(t *testing.T) { orderEvent := orderEvents[0] assert.Equal(t, zeroex.ESOrderBecameUnfunded, orderEvent.EndState) - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, 1) assert.Equal(t, orderEvent.OrderHash, orders[0].Hash) @@ -627,8 +616,7 @@ func TestOrderWatcherWETHWithdrawAndDeposit(t *testing.T) { orderEvent = orderEvents[0] assert.Equal(t, zeroex.ESOrderAdded, orderEvent.EndState) - var newOrders []*meshdb.Order - err = meshDB.Orders.FindAll(&newOrders) + newOrders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, newOrders, 1) assert.Equal(t, orderEvent.OrderHash, newOrders[0].Hash) @@ -644,13 +632,13 @@ func TestOrderWatcherCanceled(t *testing.T) { teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) signedOrder := scenario.NewSignedTestOrder(t, orderopts.SetupMakerState(true)) - ctx, cancelFn := context.WithCancel(context.Background()) - defer cancelFn() 
- blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, meshDB, signedOrder) + blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, database, signedOrder) // Cancel order opts := &bind.TransactOpts{ @@ -670,8 +658,7 @@ func TestOrderWatcherCanceled(t *testing.T) { orderEvent := orderEvents[0] assert.Equal(t, zeroex.ESOrderCancelled, orderEvent.EndState) - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, 1) assert.Equal(t, orderEvent.OrderHash, orders[0].Hash) @@ -687,13 +674,13 @@ func TestOrderWatcherCancelUpTo(t *testing.T) { teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) signedOrder := scenario.NewSignedTestOrder(t, orderopts.SetupMakerState(true)) - ctx, cancelFn := context.WithCancel(context.Background()) - defer cancelFn() - blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, meshDB, signedOrder) + blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, database, signedOrder) // Cancel order with epoch opts := &bind.TransactOpts{ @@ -713,8 +700,7 @@ func TestOrderWatcherCancelUpTo(t *testing.T) { orderEvent := orderEvents[0] assert.Equal(t, zeroex.ESOrderCancelled, orderEvent.EndState) - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, 1) assert.Equal(t, orderEvent.OrderHash, orders[0].Hash) @@ -730,7 +716,9 @@ func TestOrderWatcherERC20Filled(t *testing.T) { teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), 
ganacheAddresses) + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) takerAddress := constants.GanacheAccount3 @@ -738,9 +726,7 @@ func TestOrderWatcherERC20Filled(t *testing.T) { orderopts.SetupMakerState(true), orderopts.SetupTakerAddress(takerAddress), ) - ctx, cancelFn := context.WithCancel(context.Background()) - defer cancelFn() - blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, meshDB, signedOrder) + blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, database, signedOrder) // Fill order opts := &bind.TransactOpts{ @@ -761,8 +747,7 @@ func TestOrderWatcherERC20Filled(t *testing.T) { orderEvent := orderEvents[0] assert.Equal(t, zeroex.ESOrderFullyFilled, orderEvent.EndState) - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, 1) assert.Equal(t, orderEvent.OrderHash, orders[0].Hash) @@ -778,7 +763,9 @@ func TestOrderWatcherERC20PartiallyFilled(t *testing.T) { teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) + ctx, cancelFn := context.WithCancel(context.Background()) + defer cancelFn() + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) takerAddress := constants.GanacheAccount3 @@ -786,9 +773,7 @@ func TestOrderWatcherERC20PartiallyFilled(t *testing.T) { orderopts.SetupMakerState(true), orderopts.SetupTakerAddress(takerAddress), ) - ctx, cancelFn := context.WithCancel(context.Background()) - defer cancelFn() - blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, meshDB, signedOrder) + blockWatcher, orderEventsChan := setupOrderWatcherScenario(ctx, t, ethClient, database, signedOrder) // Partially fill order opts := &bind.TransactOpts{ @@ -810,14 +795,14 @@ 
func TestOrderWatcherERC20PartiallyFilled(t *testing.T) { orderEvent := orderEvents[0] assert.Equal(t, zeroex.ESOrderFilled, orderEvent.EndState) - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, 1) assert.Equal(t, orderEvent.OrderHash, orders[0].Hash) assert.Equal(t, false, orders[0].IsRemoved) assert.Equal(t, halfAmount, orders[0].FillableTakerAssetAmount) } + func TestOrderWatcherOrderExpiredThenUnexpired(t *testing.T) { if !serialTestsEnabled { t.Skip("Serial tests (tests which cannot run in parallel) are disabled. You can enable them with the --serial flag") @@ -826,12 +811,11 @@ func TestOrderWatcherOrderExpiredThenUnexpired(t *testing.T) { // Set up test and orderWatcher teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) - require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer func() { - cancel() - }() + defer cancel() + dbOptions := db.TestOptions() + database, err := db.New(ctx, dbOptions) + require.NoError(t, err) // Create and add an order (which will later become expired) to OrderWatcher expirationTime := time.Now().Add(24 * time.Hour) @@ -840,23 +824,23 @@ func TestOrderWatcherOrderExpiredThenUnexpired(t *testing.T) { orderopts.SetupMakerState(true), orderopts.ExpirationTimeSeconds(expirationTimeSeconds), ) - blockwatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, meshDB) + blockwatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, database) watchOrder(ctx, t, orderWatcher, blockwatcher, ethClient, signedOrder) orderEventsChan := make(chan []*zeroex.OrderEvent, 2*orderWatcher.maxOrders) orderWatcher.Subscribe(orderEventsChan) // Simulate a block found with a timestamp past expirationTime - latestBlock, err := meshDB.FindLatestMiniHeader() + latestBlock, err := 
database.GetLatestMiniHeader() require.NoError(t, err) - nextBlock := &miniheader.MiniHeader{ + nextBlock := &types.MiniHeader{ Parent: latestBlock.Hash, Hash: common.HexToHash("0x1"), Number: big.NewInt(0).Add(latestBlock.Number, big.NewInt(1)), Timestamp: expirationTime.Add(1 * time.Minute), } expiringBlockEvents := []*blockwatch.Event{ - &blockwatch.Event{ + { Type: blockwatch.Added, BlockHeader: nextBlock, }, @@ -869,8 +853,7 @@ func TestOrderWatcherOrderExpiredThenUnexpired(t *testing.T) { orderEvent := orderEvents[0] assert.Equal(t, zeroex.ESOrderExpired, orderEvent.EndState) - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, 1) assert.Equal(t, orderEvent.OrderHash, orders[0].Hash) @@ -880,27 +863,27 @@ func TestOrderWatcherOrderExpiredThenUnexpired(t *testing.T) { // Simulate a block re-org replacementBlockHash := common.HexToHash("0x2") reorgBlockEvents := []*blockwatch.Event{ - &blockwatch.Event{ + { Type: blockwatch.Removed, BlockHeader: nextBlock, }, - &blockwatch.Event{ + { Type: blockwatch.Added, - BlockHeader: &miniheader.MiniHeader{ + BlockHeader: &types.MiniHeader{ Parent: nextBlock.Parent, Hash: replacementBlockHash, Number: nextBlock.Number, - Logs: []types.Log{}, + Logs: []ethtypes.Log{}, Timestamp: expirationTime.Add(-2 * time.Hour), }, }, - &blockwatch.Event{ + { Type: blockwatch.Added, - BlockHeader: &miniheader.MiniHeader{ + BlockHeader: &types.MiniHeader{ Parent: replacementBlockHash, Hash: common.HexToHash("0x3"), Number: big.NewInt(0).Add(nextBlock.Number, big.NewInt(1)), - Logs: []types.Log{}, + Logs: []ethtypes.Log{}, Timestamp: expirationTime.Add(-1 * time.Hour), }, }, @@ -913,8 +896,7 @@ func TestOrderWatcherOrderExpiredThenUnexpired(t *testing.T) { orderEvent = orderEvents[0] assert.Equal(t, zeroex.ESOrderUnexpired, orderEvent.EndState) - var newOrders []*meshdb.Order - err = meshDB.Orders.FindAll(&newOrders) + newOrders, err := 
database.FindOrders(nil) require.NoError(t, err) require.Len(t, newOrders, 1) assert.Equal(t, orderEvent.OrderHash, newOrders[0].Hash) @@ -922,7 +904,9 @@ func TestOrderWatcherOrderExpiredThenUnexpired(t *testing.T) { assert.Equal(t, signedOrder.TakerAssetAmount, newOrders[0].FillableTakerAssetAmount) } +// TODO(albrow): Re-enable this test or move it. func TestOrderWatcherDecreaseExpirationTime(t *testing.T) { + t.Skip("Decreasing expiratin time is not yet implemented") if !serialTestsEnabled { t.Skip("Serial tests (tests which cannot run in parallel) are disabled. You can enable them with the --serial flag") } @@ -930,22 +914,20 @@ func TestOrderWatcherDecreaseExpirationTime(t *testing.T) { // Set up test and orderWatcher. Manually change maxOrders. teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + database, err := db.New(ctx, db.TestOptions()) require.NoError(t, err) // Store metadata entry in DB - metadata := &meshdb.Metadata{ + metadata := &types.Metadata{ EthereumChainID: 1337, MaxExpirationTime: constants.UnlimitedExpirationTime, } - err = meshDB.SaveMetadata(metadata) + err = database.SaveMetadata(metadata) require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer func() { - cancel() - }() - blockWatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, meshDB) + blockWatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, database) orderWatcher.maxOrders = 20 // Create and watch maxOrders orders. Each order has a different expiration time. @@ -995,11 +977,11 @@ func TestOrderWatcherDecreaseExpirationTime(t *testing.T) { // Now we check that the correct number of orders remain and that all // remaining orders have an expiration time less than the current max. 
expectedRemainingOrders := int(float64(orderWatcher.maxOrders)*maxOrdersTrimRatio) + 1 - var remainingOrders []*meshdb.Order - require.NoError(t, meshDB.Orders.FindAll(&remainingOrders)) + remainingOrders, err := database.FindOrders(nil) + require.NoError(t, err) require.Len(t, remainingOrders, expectedRemainingOrders) for _, order := range remainingOrders { - assert.True(t, order.SignedOrder.ExpirationTimeSeconds.Cmp(orderWatcher.MaxExpirationTime()) == -1, "remaining order has an expiration time of %s which is *greater than* the maximum of %s", order.SignedOrder.ExpirationTimeSeconds, orderWatcher.MaxExpirationTime()) + assert.True(t, order.ExpirationTimeSeconds.Cmp(orderWatcher.MaxExpirationTime()) == -1, "remaining order has an expiration time of %s which is *greater than* the maximum of %s", order.ExpirationTimeSeconds, orderWatcher.MaxExpirationTime()) } } @@ -1011,13 +993,12 @@ func TestOrderWatcherBatchEmitsAddedEvents(t *testing.T) { teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) - require.NoError(t, err) - ctx, cancelFn := context.WithCancel(context.Background()) defer cancelFn() + database, err := db.New(ctx, db.TestOptions()) + require.NoError(t, err) - blockWatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, meshDB) + blockWatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, database) // Subscribe to OrderWatcher orderEventsChan := make(chan []*zeroex.OrderEvent, 10) @@ -1048,8 +1029,7 @@ func TestOrderWatcherBatchEmitsAddedEvents(t *testing.T) { assert.Equal(t, zeroex.ESOrderAdded, orderEvent.EndState) } - var orders []*meshdb.Order - err = meshDB.Orders.FindAll(&orders) + orders, err := database.FindOrders(nil) require.NoError(t, err) require.Len(t, orders, numOrders) } @@ -1062,12 +1042,11 @@ func TestOrderWatcherCleanup(t *testing.T) { teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := 
meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) - require.NoError(t, err) - ctx, cancelFn := context.WithCancel(context.Background()) defer cancelFn() - blockWatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, meshDB) + database, err := db.New(ctx, db.TestOptions()) + require.NoError(t, err) + blockWatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, database) // Create and add two orders to OrderWatcher orderOptions := scenario.OptionsForAll(orderopts.SetupMakerState(true)) @@ -1081,11 +1060,10 @@ func TestOrderWatcherCleanup(t *testing.T) { // Set lastUpdate for signedOrderOne to more than defaultLastUpdatedBuffer so that signedOrderOne // does not get re-validated by the cleanup job - signedOrderOneDB := &meshdb.Order{} - err = meshDB.Orders.FindByID(signedOrderOneHash.Bytes(), signedOrderOneDB) - require.NoError(t, err) - signedOrderOneDB.LastUpdated = time.Now().Add(-defaultLastUpdatedBuffer - 1*time.Minute) - err = meshDB.Orders.Update(signedOrderOneDB) + err = database.UpdateOrder(signedOrderOneHash, func(orderToUpdate *types.OrderWithMetadata) (*types.OrderWithMetadata, error) { + orderToUpdate.LastUpdated = time.Now().Add(-defaultLastUpdatedBuffer - 1*time.Minute) + return orderToUpdate, nil + }) require.NoError(t, err) // Subscribe to OrderWatcher @@ -1105,161 +1083,6 @@ func TestOrderWatcherCleanup(t *testing.T) { } } -func TestOrderWatcherUpdateBlockHeadersStoredInDBHeaderExists(t *testing.T) { - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) - require.NoError(t, err) - - headerOne := &miniheader.MiniHeader{ - Number: big.NewInt(5), - Hash: common.HexToHash("0x293b9ea024055a3e9eddbf9b9383dc7731744111894af6aa038594dc1b61f87f"), - Parent: common.HexToHash("0x26b13ac89500f7fcdd141b7d1b30f3a82178431eca325d1cf10998f9d68ff5ba"), - Timestamp: time.Now().UTC(), - } - - testCases := []struct { - events []*blockwatch.Event - startMiniHeaders []*miniheader.MiniHeader - 
expectedMiniHeaders []*miniheader.MiniHeader - }{ - // Scenario 1: Header 1 exists in DB. Get's removed and then re-added. - { - events: []*blockwatch.Event{ - &blockwatch.Event{ - Type: blockwatch.Removed, - BlockHeader: headerOne, - }, - &blockwatch.Event{ - Type: blockwatch.Added, - BlockHeader: headerOne, - }, - }, - startMiniHeaders: []*miniheader.MiniHeader{ - headerOne, - }, - expectedMiniHeaders: []*miniheader.MiniHeader{ - headerOne, - }, - }, - // Scenario 2: Header doesn't exist, get's added and then removed - { - events: []*blockwatch.Event{ - &blockwatch.Event{ - Type: blockwatch.Added, - BlockHeader: headerOne, - }, - &blockwatch.Event{ - Type: blockwatch.Removed, - BlockHeader: headerOne, - }, - }, - startMiniHeaders: []*miniheader.MiniHeader{}, - expectedMiniHeaders: []*miniheader.MiniHeader{}, - }, - // Scenario 3: Header added, removed then re-added - { - events: []*blockwatch.Event{ - &blockwatch.Event{ - Type: blockwatch.Added, - BlockHeader: headerOne, - }, - &blockwatch.Event{ - Type: blockwatch.Removed, - BlockHeader: headerOne, - }, - &blockwatch.Event{ - Type: blockwatch.Added, - BlockHeader: headerOne, - }, - }, - startMiniHeaders: []*miniheader.MiniHeader{}, - expectedMiniHeaders: []*miniheader.MiniHeader{ - headerOne, - }, - }, - // Scenario 4: Header removed, added then removed again - { - events: []*blockwatch.Event{ - &blockwatch.Event{ - Type: blockwatch.Removed, - BlockHeader: headerOne, - }, - &blockwatch.Event{ - Type: blockwatch.Added, - BlockHeader: headerOne, - }, - &blockwatch.Event{ - Type: blockwatch.Removed, - BlockHeader: headerOne, - }, - }, - startMiniHeaders: []*miniheader.MiniHeader{ - headerOne, - }, - expectedMiniHeaders: []*miniheader.MiniHeader{}, - }, - // Scenario 5: Call added twice for the same block - { - events: []*blockwatch.Event{ - &blockwatch.Event{ - Type: blockwatch.Added, - BlockHeader: headerOne, - }, - &blockwatch.Event{ - Type: blockwatch.Added, - BlockHeader: headerOne, - }, - }, - 
startMiniHeaders: []*miniheader.MiniHeader{}, - expectedMiniHeaders: []*miniheader.MiniHeader{ - headerOne, - }, - }, - // Scenario 6: Call removed twice for the same block - { - events: []*blockwatch.Event{ - &blockwatch.Event{ - Type: blockwatch.Removed, - BlockHeader: headerOne, - }, - &blockwatch.Event{ - Type: blockwatch.Removed, - BlockHeader: headerOne, - }, - }, - startMiniHeaders: []*miniheader.MiniHeader{ - headerOne, - }, - expectedMiniHeaders: []*miniheader.MiniHeader{}, - }, - } - - for _, testCase := range testCases { - for _, startMiniHeader := range testCase.startMiniHeaders { - err = meshDB.MiniHeaders.Insert(startMiniHeader) - require.NoError(t, err) - } - - miniHeadersColTxn := meshDB.MiniHeaders.OpenTransaction() - defer func() { - _ = miniHeadersColTxn.Discard() - }() - - err = updateBlockHeadersStoredInDB(miniHeadersColTxn, testCase.events) - require.NoError(t, err) - - err = miniHeadersColTxn.Commit() - require.NoError(t, err) - - miniHeaders := []*miniheader.MiniHeader{} - err = meshDB.MiniHeaders.FindAll(&miniHeaders) - require.NoError(t, err) - assert.Equal(t, testCase.expectedMiniHeaders, miniHeaders) - - err := meshDB.ClearAllMiniHeaders() - require.NoError(t, err) - } -} - func TestOrderWatcherHandleOrderExpirationsExpired(t *testing.T) { if !serialTestsEnabled { t.Skip("Serial tests (tests which cannot run in parallel) are disabled. 
You can enable them with the --serial flag") @@ -1268,12 +1091,10 @@ func TestOrderWatcherHandleOrderExpirationsExpired(t *testing.T) { // Set up test and orderWatcher teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) - require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer func() { - cancel() - }() + defer cancel() + database, err := db.New(ctx, db.TestOptions()) + require.NoError(t, err) // Create and add an order (which will later become expired) to OrderWatcher expirationTime := time.Now().Add(24 * time.Hour) @@ -1285,29 +1106,23 @@ func TestOrderWatcherHandleOrderExpirationsExpired(t *testing.T) { signedOrders := scenario.NewSignedTestOrdersBatch(t, 2, orderOptions) signedOrderOne := signedOrders[0] signedOrderTwo := signedOrders[1] - blockwatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, meshDB) + blockwatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, database) watchOrder(ctx, t, orderWatcher, blockwatcher, ethClient, signedOrderOne) watchOrder(ctx, t, orderWatcher, blockwatcher, ethClient, signedOrderTwo) signedOrderOneHash, err := signedOrderOne.ComputeOrderHash() require.NoError(t, err) - var orderOne meshdb.Order - err = meshDB.Orders.FindByID(signedOrderOneHash.Bytes(), &orderOne) + orderOne, err := database.GetOrder(signedOrderOneHash) require.NoError(t, err) // Since we flag SignedOrderOne for revalidation, we expect `handleOrderExpirations` not to return an // expiry event for it. 
- ordersToRevalidate := map[common.Hash]*meshdb.Order{ - signedOrderOneHash: &orderOne, + ordersToRevalidate := map[common.Hash]*types.OrderWithMetadata{ + signedOrderOneHash: orderOne, } - ordersColTxn := meshDB.Orders.OpenTransaction() - defer func() { - _ = ordersColTxn.Discard() - }() - - previousLatestBlockTimestamp := expirationTime.Add(-1 * time.Minute) + // previousLatestBlockTimestamp := expirationTime.Add(-1 * time.Minute) latestBlockTimestamp := expirationTime.Add(1 * time.Second) - orderEvents, err := orderWatcher.handleOrderExpirations(ordersColTxn, latestBlockTimestamp, previousLatestBlockTimestamp, ordersToRevalidate) + orderEvents, err := orderWatcher.handleOrderExpirations(latestBlockTimestamp, ordersToRevalidate) require.NoError(t, err) require.Len(t, orderEvents, 1) @@ -1319,11 +1134,7 @@ func TestOrderWatcherHandleOrderExpirationsExpired(t *testing.T) { assert.Equal(t, big.NewInt(0), orderEvent.FillableTakerAssetAmount) assert.Len(t, orderEvent.ContractEvents, 0) - err = ordersColTxn.Commit() - require.NoError(t, err) - - var orderTwo meshdb.Order - err = meshDB.Orders.FindByID(signedOrderTwoHash.Bytes(), &orderTwo) + orderTwo, err := database.GetOrder(signedOrderTwoHash) require.NoError(t, err) assert.Equal(t, true, orderTwo.IsRemoved) } @@ -1336,12 +1147,10 @@ func TestOrderWatcherHandleOrderExpirationsUnexpired(t *testing.T) { // Set up test and orderWatcher teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) - require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer func() { - cancel() - }() + defer cancel() + database, err := db.New(ctx, db.TestOptions()) + require.NoError(t, err) // Create and add an order (which will later become expired) to OrderWatcher expirationTime := time.Now().Add(24 * time.Hour) @@ -1353,7 +1162,7 @@ func TestOrderWatcherHandleOrderExpirationsUnexpired(t *testing.T) { 
signedOrders := scenario.NewSignedTestOrdersBatch(t, 2, orderOptions) signedOrderOne := signedOrders[0] signedOrderTwo := signedOrders[1] - blockwatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, meshDB) + blockwatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, database) watchOrder(ctx, t, orderWatcher, blockwatcher, ethClient, signedOrderOne) watchOrder(ctx, t, orderWatcher, blockwatcher, ethClient, signedOrderTwo) @@ -1361,17 +1170,17 @@ func TestOrderWatcherHandleOrderExpirationsUnexpired(t *testing.T) { orderWatcher.Subscribe(orderEventsChan) // Simulate a block found with a timestamp past expirationTime - latestBlock, err := meshDB.FindLatestMiniHeader() + latestBlock, err := database.GetLatestMiniHeader() require.NoError(t, err) blockTimestamp := expirationTime.Add(1 * time.Minute) - nextBlock := &miniheader.MiniHeader{ + nextBlock := &types.MiniHeader{ Parent: latestBlock.Hash, Hash: common.HexToHash("0x1"), Number: big.NewInt(0).Add(latestBlock.Number, big.NewInt(1)), Timestamp: blockTimestamp, } expiringBlockEvents := []*blockwatch.Event{ - &blockwatch.Event{ + { Type: blockwatch.Added, BlockHeader: nextBlock, }, @@ -1387,25 +1196,18 @@ func TestOrderWatcherHandleOrderExpirationsUnexpired(t *testing.T) { signedOrderOneHash, err := signedOrderOne.ComputeOrderHash() require.NoError(t, err) - var orderOne meshdb.Order - err = meshDB.Orders.FindByID(signedOrderOneHash.Bytes(), &orderOne) + orderOne, err := database.GetOrder(signedOrderOneHash) require.NoError(t, err) // Since we flag SignedOrderOne for revalidation, we expect `handleOrderExpirations` not to return an // unexpiry event for it. 
- ordersToRevalidate := map[common.Hash]*meshdb.Order{ - signedOrderOneHash: &orderOne, + ordersToRevalidate := map[common.Hash]*types.OrderWithMetadata{ + signedOrderOneHash: orderOne, } - ordersColTxn := meshDB.Orders.OpenTransaction() - defer func() { - _ = ordersColTxn.Discard() - }() - // LatestBlockTimestamp is earlier than previous latest simulating block-reorg where new latest block // has an earlier timestamp than the last - previousLatestBlockTimestamp := blockTimestamp latestBlockTimestamp := expirationTime.Add(-1 * time.Minute) - orderEvents, err = orderWatcher.handleOrderExpirations(ordersColTxn, latestBlockTimestamp, previousLatestBlockTimestamp, ordersToRevalidate) + orderEvents, err = orderWatcher.handleOrderExpirations(latestBlockTimestamp, ordersToRevalidate) require.NoError(t, err) require.Len(t, orderEvents, 1) @@ -1417,81 +1219,11 @@ func TestOrderWatcherHandleOrderExpirationsUnexpired(t *testing.T) { assert.Equal(t, signedOrderTwo.TakerAssetAmount, orderEvent.FillableTakerAssetAmount) assert.Len(t, orderEvent.ContractEvents, 0) - err = ordersColTxn.Commit() - require.NoError(t, err) - - var orderTwo meshdb.Order - err = meshDB.Orders.FindByID(signedOrderTwoHash.Bytes(), &orderTwo) + orderTwo, err := database.GetOrder(signedOrderTwoHash) require.NoError(t, err) assert.Equal(t, false, orderTwo.IsRemoved) } -func TestOrderWatcherMaintainMiniHeaderRetentionLimit(t *testing.T) { - if !serialTestsEnabled { - t.Skip("Serial tests (tests which cannot run in parallel) are disabled. 
You can enable them with the --serial flag") - } - - // Set up test and orderWatcher - teardownSubTest := setupSubTest(t) - defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) - require.NoError(t, err) - err = meshDB.UpdateMiniHeaderRetentionLimit(miniHeaderRetentionLimit) - require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer func() { - cancel() - }() - _, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, meshDB) - - latestMiniHeader, err := meshDB.FindLatestMiniHeader() - require.NoError(t, err) - - headerOne := &miniheader.MiniHeader{ - Number: big.NewInt(0).Add(latestMiniHeader.Number, big.NewInt(1)), - Hash: common.HexToHash("0x293b9ea024055a3e9eddbf9b9383dc7731744111894af6aa038594dc1b61f87f"), - Parent: common.HexToHash("0x26b13ac89500f7fcdd141b7d1b30f3a82178431eca325d1cf10998f9d68ff5ba"), - Timestamp: time.Now().UTC(), - } - headerTwo := &miniheader.MiniHeader{ - Number: big.NewInt(0).Add(headerOne.Number, big.NewInt(1)), - Hash: common.HexToHash("0x72ca9481b09b8c00b2c38575e5652f2de1077f1676c6b868cf575229fcb06a96"), - Parent: common.HexToHash("0x293b9ea024055a3e9eddbf9b9383dc7731744111894af6aa038594dc1b61f87f"), - Timestamp: time.Now().UTC(), - } - headerThree := &miniheader.MiniHeader{ - Number: big.NewInt(0).Add(headerTwo.Number, big.NewInt(1)), - Hash: common.HexToHash("0x1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347"), - Parent: common.HexToHash("0x72ca9481b09b8c00b2c38575e5652f2de1077f1676c6b868cf575229fcb06a96"), - Timestamp: time.Now().UTC(), - } - - blockEvents := []*blockwatch.Event{ - &blockwatch.Event{ - Type: blockwatch.Added, - BlockHeader: headerOne, - }, - &blockwatch.Event{ - Type: blockwatch.Added, - BlockHeader: headerTwo, - }, - &blockwatch.Event{ - Type: blockwatch.Added, - BlockHeader: headerThree, - }, - } - err = orderWatcher.handleBlockEvents(ctx, blockEvents) - require.NoError(t, err) - - 
latestMiniHeader, err = meshDB.FindLatestMiniHeader() - require.NoError(t, err) - assert.Equal(t, headerThree.Hash, latestMiniHeader.Hash) - - totalMiniHeaders, err := meshDB.MiniHeaders.Count() - require.NoError(t, err) - assert.Equal(t, meshDB.MiniHeaderRetentionLimit, totalMiniHeaders) -} - // Scenario: Order has become unexpired and filled in the same block events processed. We test this case using // `convertValidationResultsIntoOrderEvents` since we cannot properly time-travel using Ganache. // Source: https://github.com/trufflesuite/ganache-cli/issues/708 @@ -1503,12 +1235,10 @@ func TestConvertValidationResultsIntoOrderEventsUnexpired(t *testing.T) { // Set up test and orderWatcher teardownSubTest := setupSubTest(t) defer teardownSubTest(t) - meshDB, err := meshdb.New("/tmp/leveldb_testing/"+uuid.New().String(), ganacheAddresses) - require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer func() { - cancel() - }() + defer cancel() + database, err := db.New(ctx, db.TestOptions()) + require.NoError(t, err) // Create and add an order (which will later become expired) to OrderWatcher expirationTime := time.Now().Add(24 * time.Hour) @@ -1517,7 +1247,7 @@ func TestConvertValidationResultsIntoOrderEventsUnexpired(t *testing.T) { orderopts.SetupMakerState(true), orderopts.ExpirationTimeSeconds(expirationTimeSeconds), ) - blockwatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, meshDB) + blockwatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, database) watchOrder(ctx, t, orderWatcher, blockwatcher, ethClient, signedOrder) orderEventsChan := make(chan []*zeroex.OrderEvent, 2*orderWatcher.maxOrders) @@ -1525,10 +1255,10 @@ func TestConvertValidationResultsIntoOrderEventsUnexpired(t *testing.T) { // Simulate a block found with a timestamp past expirationTime. This will mark the order as removed // and will remove it from the expiration watcher. 
- latestBlock, err := meshDB.FindLatestMiniHeader() + latestBlock, err := database.GetLatestMiniHeader() require.NoError(t, err) blockTimestamp := expirationTime.Add(1 * time.Minute) - nextBlock := &miniheader.MiniHeader{ + nextBlock := &types.MiniHeader{ Parent: latestBlock.Hash, Hash: common.HexToHash("0x1"), Number: big.NewInt(0).Add(latestBlock.Number, big.NewInt(1)), @@ -1548,15 +1278,9 @@ func TestConvertValidationResultsIntoOrderEventsUnexpired(t *testing.T) { orderHash, err := signedOrder.ComputeOrderHash() require.NoError(t, err) - var orderOne meshdb.Order - err = meshDB.Orders.FindByID(orderHash.Bytes(), &orderOne) + orderOne, err := database.GetOrder(orderHash) require.NoError(t, err) - ordersColTxn := meshDB.Orders.OpenTransaction() - defer func() { - _ = ordersColTxn.Discard() - }() - validationResults := ordervalidator.ValidationResults{ Accepted: []*ordervalidator.AcceptedOrderInfo{ &ordervalidator.AcceptedOrderInfo{ @@ -1568,19 +1292,19 @@ func TestConvertValidationResultsIntoOrderEventsUnexpired(t *testing.T) { }, Rejected: []*ordervalidator.RejectedOrderInfo{}, } - orderHashToDBOrder := map[common.Hash]*meshdb.Order{ - orderHash: &orderOne, + orderHashToDBOrder := map[common.Hash]*types.OrderWithMetadata{ + orderHash: orderOne, } exchangeFillEvent := "ExchangeFillEvent" orderHashToEvents := map[common.Hash][]*zeroex.ContractEvent{ - orderHash: []*zeroex.ContractEvent{ + orderHash: { &zeroex.ContractEvent{ Kind: exchangeFillEvent, }, }, } validationBlockTimestamp := expirationTime.Add(-1 * time.Minute) - orderEvents, err = orderWatcher.convertValidationResultsIntoOrderEvents(ordersColTxn, &validationResults, orderHashToDBOrder, orderHashToEvents, validationBlockTimestamp) + orderEvents, err = orderWatcher.convertValidationResultsIntoOrderEvents(&validationResults, orderHashToDBOrder, orderHashToEvents, validationBlockTimestamp) require.NoError(t, err) require.Len(t, orderEvents, 2) @@ -1594,11 +1318,7 @@ func 
TestConvertValidationResultsIntoOrderEventsUnexpired(t *testing.T) { assert.Len(t, orderEventOne.ContractEvents, 1) assert.Equal(t, orderEventOne.ContractEvents[0].Kind, exchangeFillEvent) - err = ordersColTxn.Commit() - require.NoError(t, err) - - var existingOrder meshdb.Order - err = meshDB.Orders.FindByID(orderHash.Bytes(), &existingOrder) + existingOrder, err := database.GetOrder(orderHash) require.NoError(t, err) assert.Equal(t, false, existingOrder.IsRemoved) } @@ -1607,9 +1327,9 @@ func TestDrainAllBlockEventsChan(t *testing.T) { blockEventsChan := make(chan []*blockwatch.Event, 100) ts := time.Now().Add(1 * time.Hour) blockEventsOne := []*blockwatch.Event{ - &blockwatch.Event{ + { Type: blockwatch.Added, - BlockHeader: &miniheader.MiniHeader{ + BlockHeader: &types.MiniHeader{ Parent: common.HexToHash("0x0"), Hash: common.HexToHash("0x1"), Number: big.NewInt(1), @@ -1620,9 +1340,9 @@ func TestDrainAllBlockEventsChan(t *testing.T) { blockEventsChan <- blockEventsOne blockEventsTwo := []*blockwatch.Event{ - &blockwatch.Event{ + { Type: blockwatch.Added, - BlockHeader: &miniheader.MiniHeader{ + BlockHeader: &types.MiniHeader{ Parent: common.HexToHash("0x1"), Hash: common.HexToHash("0x2"), Number: big.NewInt(2), @@ -1648,8 +1368,8 @@ func TestDrainAllBlockEventsChan(t *testing.T) { require.Equal(t, allEvents[0], blockEventsOne[0]) } -func setupOrderWatcherScenario(ctx context.Context, t *testing.T, ethClient *ethclient.Client, meshDB *meshdb.MeshDB, signedOrder *zeroex.SignedOrder) (*blockwatch.Watcher, chan []*zeroex.OrderEvent) { - blockWatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, meshDB) +func setupOrderWatcherScenario(ctx context.Context, t *testing.T, ethClient *ethclient.Client, database *db.DB, signedOrder *zeroex.SignedOrder) (*blockwatch.Watcher, chan []*zeroex.OrderEvent) { + blockWatcher, orderWatcher := setupOrderWatcher(ctx, t, ethRPCClient, database) // Start watching an order watchOrder(ctx, t, orderWatcher, blockWatcher, 
ethClient, signedOrder) @@ -1673,13 +1393,12 @@ func watchOrder(ctx context.Context, t *testing.T, orderWatcher *Watcher, blockW require.Len(t, validationResults.Accepted, 1, "Expected order to pass validation and get added to OrderWatcher") } -func setupOrderWatcher(ctx context.Context, t *testing.T, ethRPCClient ethrpcclient.Client, meshDB *meshdb.MeshDB) (*blockwatch.Watcher, *Watcher) { +func setupOrderWatcher(ctx context.Context, t *testing.T, ethRPCClient ethrpcclient.Client, database *db.DB) (*blockwatch.Watcher, *Watcher) { blockWatcherClient, err := blockwatch.NewRpcClient(ethRPCClient) require.NoError(t, err) topics := GetRelevantTopics() - stack := simplestack.New(meshDB.MiniHeaderRetentionLimit, []*miniheader.MiniHeader{}) blockWatcherConfig := blockwatch.Config{ - Stack: stack, + DB: database, PollingInterval: blockPollingInterval, WithLogs: true, Topics: topics, @@ -1689,7 +1408,7 @@ func setupOrderWatcher(ctx context.Context, t *testing.T, ethRPCClient ethrpccli orderValidator, err := ordervalidator.New(ethRPCClient, constants.TestChainID, ethereumRPCMaxContentLength, ganacheAddresses) require.NoError(t, err) orderWatcher, err := New(Config{ - MeshDB: meshDB, + DB: database, BlockWatcher: blockWatcher, OrderValidator: orderValidator, ChainID: constants.TestChainID, @@ -1707,7 +1426,7 @@ func setupOrderWatcher(ctx context.Context, t *testing.T, ethRPCClient ethrpccli // Ensure at least one block has been processed and is stored in the DB // before tests run - storedBlocks, err := meshDB.FindAllMiniHeadersSortedByNumber() + storedBlocks, err := database.FindMiniHeaders(nil) require.NoError(t, err) if len(storedBlocks) == 0 { err := blockWatcher.SyncToLatestBlock() @@ -1743,7 +1462,7 @@ func waitForOrderEvents(t *testing.T, orderEventsChan <-chan []*zeroex.OrderEven } } -func waitTxnSuccessfullyMined(t *testing.T, ethClient *ethclient.Client, txn *types.Transaction) { +func waitTxnSuccessfullyMined(t *testing.T, ethClient *ethclient.Client, txn 
*ethtypes.Transaction) { ctx, cancelFn := context.WithTimeout(context.Background(), 4*time.Second) defer cancelFn() receipt, err := bind.WaitMined(ctx, ethClient, txn)