From 2fdd05d95617f0ac9a44463c37b40172cfbfccb4 Mon Sep 17 00:00:00 2001 From: Anton Jurisevic Date: Fri, 20 Nov 2020 18:24:19 +1100 Subject: [PATCH] Address list renamed to address set and make adding elements idempotent. (#917) --- contracts/AddressListLib.sol | 62 ------- contracts/AddressSetLib.sol | 65 +++++++ contracts/BinaryOptionMarketManager.sol | 18 +- contracts/test-helpers/TestableAddressSet.sol | 38 ++++ test/contracts/AddressSetLib.js | 164 ++++++++++++++++++ 5 files changed, 276 insertions(+), 71 deletions(-) delete mode 100644 contracts/AddressListLib.sol create mode 100644 contracts/AddressSetLib.sol create mode 100644 contracts/test-helpers/TestableAddressSet.sol create mode 100644 test/contracts/AddressSetLib.js diff --git a/contracts/AddressListLib.sol b/contracts/AddressListLib.sol deleted file mode 100644 index ef3a900106..0000000000 --- a/contracts/AddressListLib.sol +++ /dev/null @@ -1,62 +0,0 @@ -pragma solidity ^0.5.16; - - -// https://docs.synthetix.io/contracts/source/libraries/addresslistlib/ -library AddressListLib { - struct AddressList { - address[] elements; - mapping(address => uint) indices; - } - - function contains(AddressList storage list, address candidate) internal view returns (bool) { - if (list.elements.length == 0) { - return false; - } - uint index = list.indices[candidate]; - return index != 0 || list.elements[0] == candidate; - } - - function getPage( - AddressList storage list, - uint index, - uint pageSize - ) internal view returns (address[] memory) { - // NOTE: This implementation should be converted to slice operators if the compiler is updated to v0.6.0+ - uint endIndex = index + pageSize; // The check below that endIndex <= index handles overflow. - - // If the page extends past the end of the list, truncate it. - if (endIndex > list.elements.length) { - endIndex = list.elements.length; - } - if (endIndex <= index) { - return new address[](0); - } - - uint n = endIndex - index; // We already checked for negative overflow. - address[] memory page = new address[](n); - for (uint i; i < n; i++) { - page[i] = list.elements[i + index]; - } - return page; - } - - function push(AddressList storage list, address element) internal { - list.indices[element] = list.elements.length; - list.elements.push(element); - } - - function remove(AddressList storage list, address element) internal { - require(contains(list, element), "Element not in list."); - // Replace the removed element with the last element of the list. - uint index = list.indices[element]; - uint lastIndex = list.elements.length - 1; // We required that element is in the list, so it is not empty. - if (index != lastIndex) { - // No need to shift the last element if it is the one we want to delete. - address shiftedElement = list.elements[lastIndex]; - list.elements[index] = shiftedElement; - list.indices[shiftedElement] = index; - } - list.elements.pop(); - delete list.indices[element]; - } -} diff --git a/contracts/AddressSetLib.sol b/contracts/AddressSetLib.sol new file mode 100644 index 0000000000..4040ac411c --- /dev/null +++ b/contracts/AddressSetLib.sol @@ -0,0 +1,65 @@ +pragma solidity ^0.5.16; + + +// https://docs.synthetix.io/contracts/source/libraries/addresssetlib/ +library AddressSetLib { + struct AddressSet { + address[] elements; + mapping(address => uint) indices; + } + + function contains(AddressSet storage set, address candidate) internal view returns (bool) { + if (set.elements.length == 0) { + return false; + } + uint index = set.indices[candidate]; + return index != 0 || set.elements[0] == candidate; + } + + function getPage( + AddressSet storage set, + uint index, + uint pageSize + ) internal view returns (address[] memory) { + // NOTE: This implementation should be converted to slice operators if the compiler is updated to v0.6.0+ + uint endIndex = index + pageSize; // The check below that endIndex <= index handles overflow. + + // If the page extends past the end of the list, truncate it. + if (endIndex > set.elements.length) { + endIndex = set.elements.length; + } + if (endIndex <= index) { + return new address[](0); + } + + uint n = endIndex - index; // We already checked for negative overflow. + address[] memory page = new address[](n); + for (uint i; i < n; i++) { + page[i] = set.elements[i + index]; + } + return page; + } + + function add(AddressSet storage set, address element) internal { + // Adding to a set is an idempotent operation. + if (!contains(set, element)) { + set.indices[element] = set.elements.length; + set.elements.push(element); + } + } + + function remove(AddressSet storage set, address element) internal { + require(contains(set, element), "Element not in set."); + // Replace the removed element with the last element of the list. + uint index = set.indices[element]; + uint lastIndex = set.elements.length - 1; // We required that element is in the list, so it is not empty. + if (index != lastIndex) { + // No need to shift the last element if it is the one we want to delete. + address shiftedElement = set.elements[lastIndex]; + set.elements[index] = shiftedElement; + set.indices[shiftedElement] = index; + } + set.elements.pop(); + delete set.indices[element]; + } +} diff --git a/contracts/BinaryOptionMarketManager.sol b/contracts/BinaryOptionMarketManager.sol index 0c20f282c0..7f7bf1bc03 100644 --- a/contracts/BinaryOptionMarketManager.sol +++ b/contracts/BinaryOptionMarketManager.sol @@ -7,7 +7,7 @@ import "./MixinResolver.sol"; import "./interfaces/IBinaryOptionMarketManager.sol"; // Libraries -import "./AddressListLib.sol"; +import "./AddressSetLib.sol"; import "./SafeDecimalMath.sol"; // Internal references @@ -24,7 +24,7 @@ contract BinaryOptionMarketManager is Owned, Pausable, MixinResolver, IBinaryOpt /* ========== LIBRARIES ========== */ using SafeMath for uint; - using AddressListLib for AddressListLib.AddressList; + using AddressSetLib for AddressSetLib.AddressSet; /* ========== TYPES ========== */ @@ -54,8 +54,8 @@ contract BinaryOptionMarketManager is Owned, Pausable, MixinResolver, IBinaryOpt bool public marketCreationEnabled = true; uint public totalDeposited; - AddressListLib.AddressList internal _activeMarkets; - AddressListLib.AddressList internal _maturedMarkets; + AddressSetLib.AddressSet internal _activeMarkets; + AddressSetLib.AddressSet internal _maturedMarkets; BinaryOptionMarketManager internal _migratingManager; @@ -275,7 +275,7 @@ contract BinaryOptionMarketManager is Owned, Pausable, MixinResolver, IBinaryOpt [fees.poolFee, fees.creatorFee, fees.refundFee] ); market.setResolverAndSyncCache(resolver); - _activeMarkets.push(address(market)); + _activeMarkets.add(address(market)); // The debt can't be incremented in the new market's constructor because until construction is complete, // the manager doesn't know its address in order to grant it permission. @@ -290,7 +290,7 @@ contract BinaryOptionMarketManager is Owned, Pausable, MixinResolver, IBinaryOpt require(_activeMarkets.contains(market), "Not an active market"); BinaryOptionMarket(market).resolve(); _activeMarkets.remove(market); - _maturedMarkets.push(market); + _maturedMarkets.add(market); } function cancelMarket(address market) external notPaused { @@ -346,7 +346,7 @@ contract BinaryOptionMarketManager is Owned, Pausable, MixinResolver, IBinaryOpt if (_numMarkets == 0) { return; } - AddressListLib.AddressList storage markets = active ? _activeMarkets : _maturedMarkets; + AddressSetLib.AddressSet storage markets = active ? _activeMarkets : _maturedMarkets; uint runningDepositTotal; for (uint i; i < _numMarkets; i++) { @@ -375,7 +375,7 @@ contract BinaryOptionMarketManager is Owned, Pausable, MixinResolver, IBinaryOpt if (_numMarkets == 0) { return; } - AddressListLib.AddressList storage markets = active ? _activeMarkets : _maturedMarkets; + AddressSetLib.AddressSet storage markets = active ? _activeMarkets : _maturedMarkets; uint runningDepositTotal; for (uint i; i < _numMarkets; i++) { @@ -383,7 +383,7 @@ contract BinaryOptionMarketManager is Owned, Pausable, MixinResolver, IBinaryOpt require(!_isKnownMarket(address(market)), "Market already known."); market.acceptOwnership(); - markets.push(address(market)); + markets.add(address(market)); // Update the market with the new manager address, runningDepositTotal = runningDepositTotal.add(market.deposited()); } diff --git a/contracts/test-helpers/TestableAddressSet.sol b/contracts/test-helpers/TestableAddressSet.sol new file mode 100644 index 0000000000..f29f216d5d --- /dev/null +++ b/contracts/test-helpers/TestableAddressSet.sol @@ -0,0 +1,38 @@ +pragma solidity ^0.5.16; + +import "../AddressSetLib.sol"; + + +contract TestableAddressSet { + using AddressSetLib for AddressSetLib.AddressSet; + + AddressSetLib.AddressSet internal set; + + function contains(address candidate) public view returns (bool) { + return set.contains(candidate); + } + + function getPage(uint index, uint pageSize) public view returns (address[] memory) { + return set.getPage(index, pageSize); + } + + function add(address element) public { + set.add(element); + } + + function remove(address element) public { + set.remove(element); + } + + function size() public view returns (uint) { + return set.elements.length; + } + + function element(uint index) public view returns (address) { + return set.elements[index]; + } + + function index(address element) public view returns (uint) { + return set.indices[element]; + } +} diff --git a/test/contracts/AddressSetLib.js b/test/contracts/AddressSetLib.js new file mode 100644 index 0000000000..917d5a7abd --- /dev/null +++ b/test/contracts/AddressSetLib.js @@ -0,0 +1,164 @@ +const { contract, artifacts } = require('@nomiclabs/buidler'); +const { assert, addSnapshotBeforeRestoreAfterEach } = require('./common'); +const TestableAddressSet = artifacts.require('TestableAddressSet'); + +contract('AddressSetLib', accounts => { + let set; + + const [a, b, c, d, e] = accounts; + const testAccounts = [a, b, c, d, e]; + + before(async () => { + set = await TestableAddressSet.new(); + }); + + addSnapshotBeforeRestoreAfterEach(); + + it('Adding elements', async () => { + for (const account of testAccounts) { + assert.isFalse(await set.contains(account)); + } + assert.bnEqual(await set.size(), 0); + + for (let i = 0; i < testAccounts.length; i++) { + await set.add(testAccounts[i]); + // included + for (const account of accounts.slice(0, i + 1)) { + assert.isTrue(await set.contains(account)); + } + // not included + for (const account of accounts.slice(i + 1)) { + assert.isFalse(await set.contains(account)); + } + assert.bnEqual(await set.size(), i + 1); + } + }); + + it('Adding existing elements does nothing', async () => { + for (const account of testAccounts) { + await set.add(account); + } + + const preSize = await set.size(); + const preElements = []; + + for (let i = 0; i < preSize; i++) { + preElements.push(await set.element(i)); + } + + for (const account of testAccounts) { + await set.add(account); + } + + const postSize = await set.size(); + const postElements = []; + + for (let i = 0; i < postSize; i++) { + postElements.push(await set.element(i)); + } + assert.bnEqual(postSize, preSize); + assert.bnEqual(JSON.stringify(postElements), JSON.stringify(preElements)); + }); + + it('Removing elements', async () => { + for (const account of testAccounts) { + await set.add(account); + } + + const remainingAccounts = Array.from(testAccounts); + const accountsToRemove = [b, e, c, d, a]; + + for (let i = 0; i < testAccounts.length; i++) { + const account = accountsToRemove[i]; + remainingAccounts.splice(remainingAccounts.indexOf(account), 1); + remainingAccounts.sort(); + await set.remove(account); + + const elements = []; + const size = await set.size(); + for (let j = 0; j < size; j++) { + elements.push(await set.element(j)); + } + elements.sort(); + + assert.equal(JSON.stringify(elements), JSON.stringify(remainingAccounts)); + assert.bnEqual(size, remainingAccounts.length); + } + }); + + it("Can't remove nonexistent elements", async () => { + assert.bnEqual(await set.size(), 0); + await assert.revert(set.remove(a), 'Element not in set.'); + await set.add(a); + await assert.revert(set.remove(b), 'Element not in set.'); + await set.add(b); + await set.remove(a); + await assert.revert(set.remove(a), 'Element not in set.'); + }); + + it('Retrieving pages', async () => { + const windowSize = 2; + let ms; + + // Empty list + for (let i = 0; i < testAccounts.length; i++) { + ms = await set.getPage(i, 2); + assert.equal(ms.length, 0); + } + + for (const address of testAccounts) { + await set.add(address); + } + + // Single elements + for (let i = 0; i < testAccounts.length; i++) { + ms = await set.getPage(i, 1); + assert.equal(ms.length, 1); + assert.equal(ms[0], testAccounts[i]); + } + + // shifting window + for (let i = 0; i < testAccounts.length - windowSize; i++) { + ms = await set.getPage(i, windowSize); + assert.equal(ms.length, windowSize); + + for (let j = 0; j < windowSize; j++) { + assert.equal(ms[j], testAccounts[i + j]); + } + } + + // entire list + ms = await set.getPage(0, testAccounts.length); + assert.equal(ms.length, testAccounts.length); + for (let i = 0; i < testAccounts.length; i++) { + assert.equal(ms[i], testAccounts[i]); + } + + // Page extends past end of list + ms = await set.getPage(testAccounts.length - windowSize, windowSize * 2); + assert.equal(ms.length, windowSize); + for (let i = testAccounts.length - windowSize; i < testAccounts.length; i++) { + const j = i - (testAccounts.length - windowSize); + assert.equal(ms[j], testAccounts[i]); + } + + // zero page size + for (let i = 0; i < testAccounts.length; i++) { + ms = await set.getPage(i, 0); + assert.equal(ms.length, 0); + } + + // index past the end + for (let i = 0; i < 3; i++) { + ms = await set.getPage(testAccounts.length, i); + assert.equal(ms.length, 0); + } + + // Page size larger than entire list + ms = await set.getPage(0, testAccounts.length * 2); + assert.equal(ms.length, testAccounts.length); + for (let i = 0; i < testAccounts.length; i++) { + assert.equal(ms[i], testAccounts[i]); + } + }); +});