diff --git a/CHANGELOG.md b/CHANGELOG.md index 812c24687b3..f29774455c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### New features * `AccessControl`: new contract for managing permissions in a system, replacement for `Ownable` and `Roles`. ([#2112](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2112)) * `SafeCast`: new functions to convert to and from signed and unsigned values: `toUint256` and `toInt256`. ([#2123](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2123)) + * `EnumerableMap`: a new data structure for key-value pairs (like `mapping`) that can be iterated over. ([#2160](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2160)) ### Breaking changes * `ERC721`: `burn(owner, tokenId)` was removed, use `burn(tokenId)` instead. ([#2125](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2125)) @@ -30,6 +31,8 @@ * `ERC777`: removed `_callsTokensToSend` and `_callTokensReceived`. ([#2134](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2134)) * `EnumerableSet`: renamed `get` to `at`. ([#2151](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2151)) * `ERC165Checker`: functions no longer have a leading underscore. ([#2150](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2150)) + * `ERC721Metadata`, `ERC721Enumerable`: these contracts were removed, and their functionality merged into `ERC721`. ([#2160](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2160)) + * `ERC721`: added a constructor for `name` and `symbol`. ([#2160](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2160)) * `ERC20Detailed`: this contract was removed and its functionality merged into `ERC20`. ([#2161](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2161)) * `ERC20`: added a constructor for `name` and `symbol`. `decimals` now defaults to 18. ([#2161](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2161)) diff --git a/contracts/mocks/ERC721Mock.sol b/contracts/mocks/ERC721Mock.sol index 7290aba4976..03fdbdd8bf7 100644 --- a/contracts/mocks/ERC721Mock.sol +++ b/contracts/mocks/ERC721Mock.sol @@ -13,10 +13,6 @@ contract ERC721Mock is ERC721 { return _exists(tokenId); } - function tokensOfOwner(address owner) public view returns (uint256[] memory) { - return _tokensOfOwner(owner); - } - function setTokenURI(uint256 tokenId, string memory uri) public { _setTokenURI(tokenId, uri); } diff --git a/contracts/mocks/EnumerableMapMock.sol b/contracts/mocks/EnumerableMapMock.sol new file mode 100644 index 00000000000..74e35a2029b --- /dev/null +++ b/contracts/mocks/EnumerableMapMock.sol @@ -0,0 +1,38 @@ +pragma solidity ^0.6.0; + +import "../utils/EnumerableMap.sol"; + +contract EnumerableMapMock { + using EnumerableMap for EnumerableMap.UintToAddressMap; + + event OperationResult(bool result); + + EnumerableMap.UintToAddressMap private _map; + + function contains(uint256 key) public view returns (bool) { + return _map.contains(key); + } + + function set(uint256 key, address value) public { + bool result = _map.set(key, value); + emit OperationResult(result); + } + + function remove(uint256 key) public { + bool result = _map.remove(key); + emit OperationResult(result); + } + + function length() public view returns (uint256) { + return _map.length(); + } + + function at(uint256 index) public view returns (uint256 key, address value) { + return _map.at(index); + } + + + function get(uint256 key) public view returns (address) { + return _map.get(key); + } +} diff --git a/contracts/mocks/EnumerableSetMock.sol b/contracts/mocks/EnumerableSetMock.sol index f33f2bb098c..dc9db00ec7b 100644 --- a/contracts/mocks/EnumerableSetMock.sol +++ b/contracts/mocks/EnumerableSetMock.sol @@ -5,7 +5,7 @@ import "../utils/EnumerableSet.sol"; contract EnumerableSetMock { using EnumerableSet for EnumerableSet.AddressSet; - event TransactionResult(bool result); + event OperationResult(bool result); EnumerableSet.AddressSet private _set; @@ -15,16 +15,12 @@ contract EnumerableSetMock { function add(address value) public { bool result = _set.add(value); - emit TransactionResult(result); + emit OperationResult(result); } function remove(address value) public { bool result = _set.remove(value); - emit TransactionResult(result); - } - - function enumerate() public view returns (address[] memory) { - return _set.enumerate(); + emit OperationResult(result); } function length() public view returns (uint256) { diff --git a/contracts/token/ERC721/ERC721.sol b/contracts/token/ERC721/ERC721.sol index ad6cdb89add..17f7efff6b6 100644 --- a/contracts/token/ERC721/ERC721.sol +++ b/contracts/token/ERC721/ERC721.sol @@ -8,7 +8,8 @@ import "./IERC721Receiver.sol"; import "../../introspection/ERC165.sol"; import "../../math/SafeMath.sol"; import "../../utils/Address.sol"; -import "../../utils/Counters.sol"; +import "../../utils/EnumerableSet.sol"; +import "../../utils/EnumerableMap.sol"; /** * @title ERC721 Non-Fungible Token Standard basic implementation @@ -17,21 +18,22 @@ import "../../utils/Counters.sol"; contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable { using SafeMath for uint256; using Address for address; - using Counters for Counters.Counter; + using EnumerableSet for EnumerableSet.UintSet; + using EnumerableMap for EnumerableMap.UintToAddressMap; // Equals to `bytes4(keccak256("onERC721Received(address,address,uint256,bytes)"))` // which can be also obtained as `IERC721Receiver(0).onERC721Received.selector` bytes4 private constant _ERC721_RECEIVED = 0x150b7a02; - // Mapping from token ID to owner - mapping (uint256 => address) private _tokenOwner; + // Mapping from holder address to their (enumerable) set of owned tokens + mapping (address => EnumerableSet.UintSet) private _holderTokens; + + // Enumerable mapping from token ids to their owners + EnumerableMap.UintToAddressMap private _tokenOwners; // Mapping from token ID to approved address mapping (uint256 => address) private _tokenApprovals; - // Mapping from owner to number of owned token - mapping (address => Counters.Counter) private _ownedTokensCount; - // Mapping from owner to operator approvals mapping (address => mapping (address => bool)) private _operatorApprovals; @@ -47,18 +49,6 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable // Base URI string private _baseURI; - // Mapping from owner to list of owned token IDs - mapping(address => uint256[]) private _ownedTokens; - - // Mapping from token ID to index of the owner tokens list - mapping(uint256 => uint256) private _ownedTokensIndex; - - // Array with all token ids, used for enumeration - uint256[] private _allTokens; - - // Mapping from token id to position in the allTokens array - mapping(uint256 => uint256) private _allTokensIndex; - /* * bytes4(keccak256('balanceOf(address)')) == 0x70a08231 * bytes4(keccak256('ownerOf(uint256)')) == 0x6352211e @@ -111,7 +101,7 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable function balanceOf(address owner) public view override returns (uint256) { require(owner != address(0), "ERC721: balance query for the zero address"); - return _ownedTokensCount[owner].current(); + return _holderTokens[owner].length(); } /** @@ -120,10 +110,7 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable * @return address currently marked as the owner of the given token ID */ function ownerOf(uint256 tokenId) public view override returns (address) { - address owner = _tokenOwner[tokenId]; - require(owner != address(0), "ERC721: owner query for nonexistent token"); - - return owner; + return _tokenOwners.get(tokenId, "ERC721: owner query for nonexistent token"); } /** @@ -180,8 +167,7 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable * @return uint256 token ID at the given index of the tokens list owned by the requested address */ function tokenOfOwnerByIndex(address owner, uint256 index) public view override returns (uint256) { - require(index < balanceOf(owner), "ERC721Enumerable: owner index out of bounds"); - return _ownedTokens[owner][index]; + return _holderTokens[owner].at(index); } /** @@ -189,7 +175,8 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable * @return uint256 representing the total amount of tokens */ function totalSupply() public view override returns (uint256) { - return _allTokens.length; + // _tokenOwners are indexed by tokenIds, so .length() returns the number of tokenIds + return _tokenOwners.length(); } /** @@ -199,8 +186,8 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable * @return uint256 token ID at the given index of the tokens list */ function tokenByIndex(uint256 index) public view override returns (uint256) { - require(index < totalSupply(), "ERC721Enumerable: global index out of bounds"); - return _allTokens[index]; + (uint256 tokenId, ) = _tokenOwners.at(index); + return tokenId; } /** @@ -327,8 +314,7 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable * @return bool whether the token exists */ function _exists(uint256 tokenId) internal view returns (bool) { - address owner = _tokenOwner[tokenId]; - return owner != address(0); + return _tokenOwners.contains(tokenId); } /** @@ -386,11 +372,9 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable _beforeTokenTransfer(address(0), to, tokenId); - _addTokenToOwnerEnumeration(to, tokenId); - _addTokenToAllTokensEnumeration(tokenId); + _holderTokens[to].add(tokenId); - _tokenOwner[tokenId] = to; - _ownedTokensCount[to].increment(); + _tokenOwners.set(tokenId, to); emit Transfer(address(0), to, tokenId); } @@ -405,22 +389,17 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable _beforeTokenTransfer(owner, address(0), tokenId); + // Clear approvals + _approve(address(0), tokenId); + // Clear metadata (if any) if (bytes(_tokenURIs[tokenId]).length != 0) { delete _tokenURIs[tokenId]; } - _removeTokenFromOwnerEnumeration(owner, tokenId); - // Since tokenId will be deleted, we can clear its slot in _ownedTokensIndex to trigger a gas refund - _ownedTokensIndex[tokenId] = 0; - - _removeTokenFromAllTokensEnumeration(tokenId); - - // Clear approvals - _approve(address(0), tokenId); + _holderTokens[owner].remove(tokenId); - _ownedTokensCount[owner].decrement(); - _tokenOwner[tokenId] = address(0); + _tokenOwners.remove(tokenId); emit Transfer(owner, address(0), tokenId); } @@ -438,16 +417,13 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable _beforeTokenTransfer(from, to, tokenId); - _removeTokenFromOwnerEnumeration(from, tokenId); - _addTokenToOwnerEnumeration(to, tokenId); - - // Clear approvals + // Clear approvals from the previous owner _approve(address(0), tokenId); - _ownedTokensCount[from].decrement(); - _ownedTokensCount[to].increment(); + _holderTokens[from].remove(tokenId); + _holderTokens[to].add(tokenId); - _tokenOwner[tokenId] = to; + _tokenOwners.set(tokenId, to); emit Transfer(from, to, tokenId); } @@ -474,15 +450,6 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable _baseURI = baseURI_; } - /** - * @dev Gets the list of token IDs of the requested owner. - * @param owner address owning the tokens - * @return uint256[] List of token IDs owned by the requested address - */ - function _tokensOfOwner(address owner) internal view returns (uint256[] storage) { - return _ownedTokens[owner]; - } - /** * @dev Internal function to invoke {IERC721Receiver-onERC721Received} on a target address. * The call is not executed if the target address is not a contract. @@ -528,81 +495,6 @@ contract ERC721 is Context, ERC165, IERC721, IERC721Metadata, IERC721Enumerable emit Approval(ownerOf(tokenId), to, tokenId); } - /** - * @dev Private function to add a token to this extension's ownership-tracking data structures. - * @param to address representing the new owner of the given token ID - * @param tokenId uint256 ID of the token to be added to the tokens list of the given address - */ - function _addTokenToOwnerEnumeration(address to, uint256 tokenId) private { - _ownedTokensIndex[tokenId] = _ownedTokens[to].length; - _ownedTokens[to].push(tokenId); - } - - /** - * @dev Private function to add a token to this extension's token tracking data structures. - * @param tokenId uint256 ID of the token to be added to the tokens list - */ - function _addTokenToAllTokensEnumeration(uint256 tokenId) private { - _allTokensIndex[tokenId] = _allTokens.length; - _allTokens.push(tokenId); - } - - /** - * @dev Private function to remove a token from this extension's ownership-tracking data structures. Note that - * while the token is not assigned a new owner, the `_ownedTokensIndex` mapping is _not_ updated: this allows for - * gas optimizations e.g. when performing a transfer operation (avoiding double writes). - * This has O(1) time complexity, but alters the order of the _ownedTokens array. - * @param from address representing the previous owner of the given token ID - * @param tokenId uint256 ID of the token to be removed from the tokens list of the given address - */ - function _removeTokenFromOwnerEnumeration(address from, uint256 tokenId) private { - // To prevent a gap in from's tokens array, we store the last token in the index of the token to delete, and - // then delete the last slot (swap and pop). - - uint256 lastTokenIndex = _ownedTokens[from].length.sub(1); - uint256 tokenIndex = _ownedTokensIndex[tokenId]; - - // When the token to delete is the last token, the swap operation is unnecessary - if (tokenIndex != lastTokenIndex) { - uint256 lastTokenId = _ownedTokens[from][lastTokenIndex]; - - _ownedTokens[from][tokenIndex] = lastTokenId; // Move the last token to the slot of the to-delete token - _ownedTokensIndex[lastTokenId] = tokenIndex; // Update the moved token's index - } - - // Deletes the contents at the last position of the array - _ownedTokens[from].pop(); - - // Note that _ownedTokensIndex[tokenId] hasn't been cleared: it still points to the old slot (now occupied by - // lastTokenId, or just over the end of the array if the token was the last one). - } - - /** - * @dev Private function to remove a token from this extension's token tracking data structures. - * This has O(1) time complexity, but alters the order of the _allTokens array. - * @param tokenId uint256 ID of the token to be removed from the tokens list - */ - function _removeTokenFromAllTokensEnumeration(uint256 tokenId) private { - // To prevent a gap in the tokens array, we store the last token in the index of the token to delete, and - // then delete the last slot (swap and pop). - - uint256 lastTokenIndex = _allTokens.length.sub(1); - uint256 tokenIndex = _allTokensIndex[tokenId]; - - // When the token to delete is the last token, the swap operation is unnecessary. However, since this occurs so - // rarely (when the last minted token is burnt) that we still do the swap here to avoid the gas cost of adding - // an 'if' statement (like in _removeTokenFromOwnerEnumeration) - uint256 lastTokenId = _allTokens[lastTokenIndex]; - - _allTokens[tokenIndex] = lastTokenId; // Move the last token to the slot of the to-delete token - _allTokensIndex[lastTokenId] = tokenIndex; // Update the moved token's index - - // Delete the contents at the last position of the array - _allTokens.pop(); - - _allTokensIndex[tokenId] = 0; - } - /** * @dev Hook that is called before any token transfer. This includes minting * and burning. diff --git a/contracts/utils/EnumerableMap.sol b/contracts/utils/EnumerableMap.sol new file mode 100644 index 00000000000..9ba74de5688 --- /dev/null +++ b/contracts/utils/EnumerableMap.sol @@ -0,0 +1,211 @@ +pragma solidity ^0.6.0; + +library EnumerableMap { + // To implement this library for multiple types with as little code + // repetition as possible, we write it in terms of a generic Map type with + // bytes32 keys and values. + // The Map implementation uses private functions, and user-facing + // implementations (such as Uint256ToAddressMap) are just wrappers around + // the underlying Map. + // This means that we can only create new EnumerableMaps for types that fit + // in bytes32. + + struct MapEntry { + bytes32 _key; + bytes32 _value; + } + + struct Map { + // Storage of map keys and values + MapEntry[] _entries; + + // Position of the entry defined by a key in the `entries` array, plus 1 + // because index 0 means a key is not in the map. + mapping (bytes32 => uint256) _indexes; + } + + /** + * @dev Adds a key-value pair to a map, or updates the value for an existing + * key. O(1). + * + * Returns true if the key was added to the map, that is if it was not + * already present. + */ + function _set(Map storage map, bytes32 key, bytes32 value) private returns (bool) { + // We read and store the key's index to prevent multiple reads from the same storage slot + uint256 keyIndex = map._indexes[key]; + + if (keyIndex == 0) { // Equivalent to !contains(map, key) + map._entries.push(MapEntry({ _key: key, _value: value })); + // The entry is stored at length-1, but we add 1 to all indexes + // and use 0 as a sentinel value + map._indexes[key] = map._entries.length; + return true; + } else { + map._entries[keyIndex - 1]._value = value; + return false; + } + } + + /** + * @dev Removes a key-value pair from a map. O(1). + * + * Returns true if the key was removed from the map, that is if it was present. + */ + function _remove(Map storage map, bytes32 key) private returns (bool) { + // We read and store the key's index to prevent multiple reads from the same storage slot + uint256 keyIndex = map._indexes[key]; + + if (keyIndex != 0) { // Equivalent to contains(map, key) + // To delete a key-value pair from the _entries array in O(1), we swap the entry to delete with the last one + // in the array, and then remove the last entry (sometimes called as 'swap and pop'). + // This modifies the order of the array, as noted in {at}. + + uint256 toDeleteIndex = keyIndex - 1; + uint256 lastIndex = map._entries.length - 1; + + // When the entry to delete is the last one, the swap operation is unnecessary. However, since this occurs + // so rarely, we still do the swap anyway to avoid the gas cost of adding an 'if' statement. + + MapEntry storage lastEntry = map._entries[lastIndex]; + + // Move the last entry to the index where the entry to delete is + map._entries[toDeleteIndex] = lastEntry; + // Update the index for the moved entry + map._indexes[lastEntry._key] = toDeleteIndex + 1; // All indexes are 1-based + + // Delete the slot where the moved entry was stored + map._entries.pop(); + + // Delete the index for the deleted slot + delete map._indexes[key]; + + return true; + } else { + return false; + } + } + + /** + * @dev Returns true if the key is in the map. O(1). + */ + function _contains(Map storage map, bytes32 key) private view returns (bool) { + return map._indexes[key] != 0; + } + + /** + * @dev Returns the number of key-value pairs in the map. O(1). + */ + function _length(Map storage map) private view returns (uint256) { + return map._entries.length; + } + + /** + * @dev Returns the key-value pair stored at position `index` in the map. O(1). + * + * Note that there are no guarantees on the ordering of entries inside the + * array, and it may change when more entries are added or removed. + * + * Requirements: + * + * - `index` must be strictly less than {length}. + */ + function _at(Map storage map, uint256 index) private view returns (bytes32, bytes32) { + require(map._entries.length > index, "EnumerableMap: index out of bounds"); + + MapEntry storage entry = map._entries[index]; + return (entry._key, entry._value); + } + + /** + * @dev Returns the value associated with `key`. O(1). + * + * Requirements: + * + * - `key` must be in the map. + */ + function _get(Map storage map, bytes32 key) private view returns (bytes32) { + return _get(map, key, "EnumerableMap: nonexistent key"); + } + + /** + * @dev Same as {_get}, with a custom error message when `key` is not in the map. + */ + function _get(Map storage map, bytes32 key, string memory errorMessage) private view returns (bytes32) { + uint256 keyIndex = map._indexes[key]; + require(keyIndex != 0, errorMessage); // Equivalent to contains(map, key) + return map._entries[keyIndex - 1]._value; // All indexes are 1-based + } + + // UintToAddressMap + + struct UintToAddressMap { + Map _inner; + } + + /** + * @dev Adds a key-value pair to a map, or updates the value for an existing + * key. O(1). + * + * Returns true if the key was added to the map, that is if it was not + * already present. + */ + function set(UintToAddressMap storage map, uint256 key, address value) internal returns (bool) { + return _set(map._inner, bytes32(key), bytes32(uint256(value))); + } + + /** + * @dev Removes a value from a set. O(1). + * + * Returns true if the key was removed from the map, that is if it was present. + */ + function remove(UintToAddressMap storage map, uint256 key) internal returns (bool) { + return _remove(map._inner, bytes32(key)); + } + + /** + * @dev Returns true if the key is in the map. O(1). + */ + function contains(UintToAddressMap storage map, uint256 key) internal view returns (bool) { + return _contains(map._inner, bytes32(key)); + } + + /** + * @dev Returns the number of elements in the map. O(1). + */ + function length(UintToAddressMap storage map) internal view returns (uint256) { + return _length(map._inner); + } + + /** + * @dev Returns the element stored at position `index` in the set. O(1). + * Note that there are no guarantees on the ordering of values inside the + * array, and it may change when more values are added or removed. + * + * Requirements: + * + * - `index` must be strictly less than {length}. + */ + function at(UintToAddressMap storage map, uint256 index) internal view returns (uint256, address) { + (bytes32 key, bytes32 value) = _at(map._inner, index); + return (uint256(key), address(uint256(value))); + } + + /** + * @dev Returns the value associated with `key`. O(1). + * + * Requirements: + * + * - `key` must be in the map. + */ + function get(UintToAddressMap storage map, uint256 key) internal view returns (address) { + return address(uint256(_get(map._inner, bytes32(key)))); + } + + /** + * @dev Same as {get}, with a custom error message when `key` is not in the map. + */ + function get(UintToAddressMap storage map, uint256 key, string memory errorMessage) internal view returns (address) { + return address(uint256(_get(map._inner, bytes32(key), errorMessage))); + } +} diff --git a/contracts/utils/EnumerableSet.sol b/contracts/utils/EnumerableSet.sol index dbe54458be1..4c040e9faff 100644 --- a/contracts/utils/EnumerableSet.sol +++ b/contracts/utils/EnumerableSet.sol @@ -18,24 +18,32 @@ pragma solidity ^0.6.0; * @author Alberto Cuesta CaƱada */ library EnumerableSet { + // To implement this library for multiple types with as little code + // repetition as possible, we write it in terms of a generic Set type with + // bytes32 values. + // The Set implementation uses private functions, and user-facing + // implementations (such as AddressSet) are just wrappers around the + // underlying Set. + // This means that we can only create new EnumerableSets for types that fit + // in bytes32. + + struct Set { + // Storage of set values + bytes32[] _values; - struct AddressSet { - address[] _values; // Position of the value in the `values` array, plus 1 because index 0 // means a value is not in the set. - mapping (address => uint256) _indexes; + mapping (bytes32 => uint256) _indexes; } /** * @dev Add a value to a set. O(1). * - * Returns false if the value was already in the set. + * Returns true if the value was added to the set, that is if it was not + * already present. */ - function add(AddressSet storage set, address value) - internal - returns (bool) - { - if (!contains(set, value)) { + function _add(Set storage set, bytes32 value) private returns (bool) { + if (!_contains(set, value)) { set._values.push(value); // The value is stored at length-1, but we add 1 to all indexes // and use 0 as a sentinel value @@ -49,25 +57,30 @@ library EnumerableSet { /** * @dev Removes a value from a set. O(1). * - * Returns false if the value was not present in the set. + * Returns true if the value was removed from the set, that is if it was + * present. */ - function remove(AddressSet storage set, address value) - internal - returns (bool) - { - if (contains(set, value)){ - uint256 toDeleteIndex = set._indexes[value] - 1; + function _remove(Set storage set, bytes32 value) private returns (bool) { + // We read and store the value's index to prevent multiple reads from the same storage slot + uint256 valueIndex = set._indexes[value]; + + if (valueIndex != 0) { // Equivalent to contains(set, value) + // To delete an element from the _values array in O(1), we swap the element to delete with the last one in + // the array, and then remove the last element (sometimes called as 'swap and pop'). + // This modifies the order of the array, as noted in {at}. + + uint256 toDeleteIndex = valueIndex - 1; uint256 lastIndex = set._values.length - 1; - // If the value we're deleting is the last one, we can just remove it without doing a swap - if (lastIndex != toDeleteIndex) { - address lastvalue = set._values[lastIndex]; + // When the value to delete is the last one, the swap operation is unnecessary. However, since this occurs + // so rarely, we still do the swap anyway to avoid the gas cost of adding an 'if' statement. + + bytes32 lastvalue = set._values[lastIndex]; - // Move the last value to the index where the deleted value is - set._values[toDeleteIndex] = lastvalue; - // Update the index for the moved value - set._indexes[lastvalue] = toDeleteIndex + 1; // All indexes are 1-based - } + // Move the last value to the index where the value to delete is + set._values[toDeleteIndex] = lastvalue; + // Update the index for the moved value + set._indexes[lastvalue] = toDeleteIndex + 1; // All indexes are 1-based // Delete the slot where the moved value was stored set._values.pop(); @@ -84,44 +97,125 @@ library EnumerableSet { /** * @dev Returns true if the value is in the set. O(1). */ - function contains(AddressSet storage set, address value) - internal - view - returns (bool) - { + function _contains(Set storage set, bytes32 value) private view returns (bool) { return set._indexes[value] != 0; } /** - * @dev Returns an array with all values in the set. O(N). + * @dev Returns the number of values on the set. O(1). + */ + function _length(Set storage set) private view returns (uint256) { + return set._values.length; + } + + /** + * @dev Returns the value stored at position `index` in the set. O(1). + * + * Note that there are no guarantees on the ordering of values inside the + * array, and it may change when more values are added or removed. + * + * Requirements: + * + * - `index` must be strictly less than {length}. + */ + function _at(Set storage set, uint256 index) private view returns (bytes32) { + require(set._values.length > index, "EnumerableSet: index out of bounds"); + return set._values[index]; + } + + // AddressSet + + struct AddressSet { + Set _inner; + } + + /** + * @dev Add a value to a set. O(1). + * + * Returns true if the value was added to the set, that is if it was not + * already present. + */ + function add(AddressSet storage set, address value) internal returns (bool) { + return _add(set._inner, bytes32(uint256(value))); + } + + /** + * @dev Removes a value from a set. O(1). * - * Note that there are no guarantees on the ordering of values inside the - * array, and it may change when more values are added or removed. + * Returns true if the value was removed from the set, that is if it was + * present. + */ + function remove(AddressSet storage set, address value) internal returns (bool) { + return _remove(set._inner, bytes32(uint256(value))); + } - * WARNING: This function may run out of gas on large sets: use {length} and - * {at} instead in these cases. + /** + * @dev Returns true if the value is in the set. O(1). */ - function enumerate(AddressSet storage set) - internal - view - returns (address[] memory) - { - address[] memory output = new address[](set._values.length); - for (uint256 i; i < set._values.length; i++){ - output[i] = set._values[i]; - } - return output; + function contains(AddressSet storage set, address value) internal view returns (bool) { + return _contains(set._inner, bytes32(uint256(value))); + } + + /** + * @dev Returns the number of values in the set. O(1). + */ + function length(AddressSet storage set) internal view returns (uint256) { + return _length(set._inner); + } + + /** + * @dev Returns the value stored at position `index` in the set. O(1). + * + * Note that there are no guarantees on the ordering of values inside the + * array, and it may change when more values are added or removed. + * + * Requirements: + * + * - `index` must be strictly less than {length}. + */ + function at(AddressSet storage set, uint256 index) internal view returns (address) { + return address(uint256(_at(set._inner, index))); + } + + + // UintSet + + struct UintSet { + Set _inner; + } + + /** + * @dev Add a value to a set. O(1). + * + * Returns true if the value was added to the set, that is if it was not + * already present. + */ + function add(UintSet storage set, uint256 value) internal returns (bool) { + return _add(set._inner, bytes32(value)); + } + + /** + * @dev Removes a value from a set. O(1). + * + * Returns true if the value was removed from the set, that is if it was + * present. + */ + function remove(UintSet storage set, uint256 value) internal returns (bool) { + return _remove(set._inner, bytes32(value)); + } + + /** + * @dev Returns true if the value is in the set. O(1). + */ + function contains(UintSet storage set, uint256 value) internal view returns (bool) { + return _contains(set._inner, bytes32(value)); } /** * @dev Returns the number of values on the set. O(1). */ - function length(AddressSet storage set) - internal - view - returns (uint256) - { - return set._values.length; + function length(UintSet storage set) internal view returns (uint256) { + return _length(set._inner); } /** @@ -134,12 +228,7 @@ library EnumerableSet { * * - `index` must be strictly less than {length}. */ - function at(AddressSet storage set, uint256 index) - internal - view - returns (address) - { - require(set._values.length > index, "EnumerableSet: index out of bounds"); - return set._values[index]; + function at(UintSet storage set, uint256 index) internal view returns (uint256) { + return uint256(_at(set._inner, index)); } } diff --git a/package-lock.json b/package-lock.json index 5595a78f22d..584648521d9 100644 --- a/package-lock.json +++ b/package-lock.json @@ -31087,6 +31087,12 @@ "lodash._reinterpolate": "^3.0.0" } }, + "lodash.zip": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/lodash.zip/-/lodash.zip-4.2.0.tgz", + "integrity": "sha1-7GZi5IlkCO1KtsVCo5kLcswIACA=", + "dev": true + }, "log-symbols": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/log-symbols/-/log-symbols-3.0.0.tgz", diff --git a/package.json b/package.json index ba1748332ce..c0584ee4a4f 100644 --- a/package.json +++ b/package.json @@ -61,6 +61,7 @@ "ethereumjs-util": "^6.2.0", "ganache-core-coverage": "https://github.com/OpenZeppelin/ganache-core-coverage/releases/download/2.5.3-coverage/ganache-core-coverage-2.5.3.tgz", "lodash.startcase": "^4.4.0", + "lodash.zip": "^4.2.0", "micromatch": "^4.0.2", "mocha": "^7.1.1", "solhint": "^3.0.0-rc.6", diff --git a/test/token/ERC721/ERC721.test.js b/test/token/ERC721/ERC721.test.js index eb156da3ff2..2b850a0dca8 100644 --- a/test/token/ERC721/ERC721.test.js +++ b/test/token/ERC721/ERC721.test.js @@ -176,27 +176,17 @@ describe('ERC721', function () { expect(await this.token.ownerOf(tokenId)).to.be.equal(this.toWhom); }); + it('emits a Transfer event', async function () { + expectEvent.inLogs(logs, 'Transfer', { from: owner, to: this.toWhom, tokenId: tokenId }); + }); + it('clears the approval for the token ID', async function () { expect(await this.token.getApproved(tokenId)).to.be.equal(ZERO_ADDRESS); }); - if (approved) { - it('emit only a transfer event', async function () { - expectEvent.inLogs(logs, 'Transfer', { - from: owner, - to: this.toWhom, - tokenId: tokenId, - }); - }); - } else { - it('emits only a transfer event', async function () { - expectEvent.inLogs(logs, 'Transfer', { - from: owner, - to: this.toWhom, - tokenId: tokenId, - }); - }); - } + it('emits an Approval event', async function () { + expectEvent.inLogs(logs, 'Approval', { owner, approved: ZERO_ADDRESS, tokenId: tokenId }); + }); it('adjusts owners balances', async function () { expect(await this.token.balanceOf(owner)).to.be.bignumber.equal('1'); @@ -708,15 +698,6 @@ describe('ERC721', function () { }); }); - describe('tokensOfOwner', function () { - it('returns total tokens of owner', async function () { - const tokenIds = await this.token.tokensOfOwner(owner); - expect(tokenIds.length).to.equal(2); - expect(tokenIds[0]).to.be.bignumber.equal(firstTokenId); - expect(tokenIds[1]).to.be.bignumber.equal(secondTokenId); - }); - }); - describe('totalSupply', function () { it('returns total token supply', async function () { expect(await this.token.totalSupply()).to.be.bignumber.equal('2'); @@ -733,7 +714,7 @@ describe('ERC721', function () { describe('when the index is greater than or equal to the total tokens owned by the given address', function () { it('reverts', async function () { await expectRevert( - this.token.tokenOfOwnerByIndex(owner, 2), 'ERC721Enumerable: owner index out of bounds' + this.token.tokenOfOwnerByIndex(owner, 2), 'EnumerableSet: index out of bounds' ); }); }); @@ -741,7 +722,7 @@ describe('ERC721', function () { describe('when the given address does not own any token', function () { it('reverts', async function () { await expectRevert( - this.token.tokenOfOwnerByIndex(other, 0), 'ERC721Enumerable: owner index out of bounds' + this.token.tokenOfOwnerByIndex(other, 0), 'EnumerableSet: index out of bounds' ); }); }); @@ -764,7 +745,7 @@ describe('ERC721', function () { it('returns empty collection for original owner', async function () { expect(await this.token.balanceOf(owner)).to.be.bignumber.equal('0'); await expectRevert( - this.token.tokenOfOwnerByIndex(owner, 0), 'ERC721Enumerable: owner index out of bounds' + this.token.tokenOfOwnerByIndex(owner, 0), 'EnumerableSet: index out of bounds' ); }); }); @@ -781,7 +762,7 @@ describe('ERC721', function () { it('should revert if index is greater than supply', async function () { await expectRevert( - this.token.tokenByIndex(2), 'ERC721Enumerable: global index out of bounds' + this.token.tokenByIndex(2), 'EnumerableMap: index out of bounds' ); }); @@ -790,7 +771,7 @@ describe('ERC721', function () { const newTokenId = new BN(300); const anotherNewTokenId = new BN(400); - await this.token.burn(tokenId, { from: owner }); + await this.token.burn(tokenId); await this.token.mint(newOwner, newTokenId); await this.token.mint(newOwner, anotherNewTokenId); @@ -865,6 +846,10 @@ describe('ERC721', function () { expectEvent.inLogs(this.logs, 'Transfer', { from: owner, to: ZERO_ADDRESS, tokenId: firstTokenId }); }); + it('emits an Approval event', function () { + expectEvent.inLogs(this.logs, 'Approval', { owner, approved: ZERO_ADDRESS, tokenId: firstTokenId }); + }); + it('deletes the token', async function () { expect(await this.token.balanceOf(owner)).to.be.bignumber.equal('1'); await expectRevert( @@ -884,7 +869,7 @@ describe('ERC721', function () { await this.token.burn(secondTokenId, { from: owner }); expect(await this.token.totalSupply()).to.be.bignumber.equal('0'); await expectRevert( - this.token.tokenByIndex(0), 'ERC721Enumerable: global index out of bounds' + this.token.tokenByIndex(0), 'EnumerableMap: index out of bounds' ); }); diff --git a/test/utils/EnumerableMap.test.js b/test/utils/EnumerableMap.test.js new file mode 100644 index 00000000000..19e24c8876c --- /dev/null +++ b/test/utils/EnumerableMap.test.js @@ -0,0 +1,139 @@ +const { accounts, contract } = require('@openzeppelin/test-environment'); +const { BN, expectEvent } = require('@openzeppelin/test-helpers'); +const { expect } = require('chai'); + +const zip = require('lodash.zip'); + +const EnumerableMapMock = contract.fromArtifact('EnumerableMapMock'); + +describe('EnumerableMap', function () { + const [ accountA, accountB, accountC ] = accounts; + + const keyA = new BN('7891'); + const keyB = new BN('451'); + const keyC = new BN('9592328'); + + beforeEach(async function () { + this.map = await EnumerableMapMock.new(); + }); + + async function expectMembersMatch (map, keys, values) { + expect(keys.length).to.equal(values.length); + + await Promise.all(keys.map(async key => + expect(await map.contains(key)).to.equal(true) + )); + + expect(await map.length()).to.bignumber.equal(keys.length.toString()); + + expect(await Promise.all(keys.map(key => + map.get(key) + ))).to.have.same.members(values); + + // To compare key-value pairs, we zip keys and values, and convert BNs to + // strings to workaround Chai limitations when dealing with nested arrays + expect(await Promise.all([...Array(keys.length).keys()].map(async (index) => { + const entry = await map.at(index); + return [entry.key.toString(), entry.value]; + }))).to.have.same.deep.members( + zip(keys.map(k => k.toString()), values) + ); + } + + it('starts empty', async function () { + expect(await this.map.contains(keyA)).to.equal(false); + + await expectMembersMatch(this.map, [], []); + }); + + it('adds a key', async function () { + const receipt = await this.map.set(keyA, accountA); + expectEvent(receipt, 'OperationResult', { result: true }); + + await expectMembersMatch(this.map, [keyA], [accountA]); + }); + + it('adds several keys', async function () { + await this.map.set(keyA, accountA); + await this.map.set(keyB, accountB); + + await expectMembersMatch(this.map, [keyA, keyB], [accountA, accountB]); + expect(await this.map.contains(keyC)).to.equal(false); + }); + + it('returns false when adding keys already in the set', async function () { + await this.map.set(keyA, accountA); + + const receipt = (await this.map.set(keyA, accountA)); + expectEvent(receipt, 'OperationResult', { result: false }); + + await expectMembersMatch(this.map, [keyA], [accountA]); + }); + + it('updates values for keys already in the set', async function () { + await this.map.set(keyA, accountA); + + await this.map.set(keyA, accountB); + + await expectMembersMatch(this.map, [keyA], [accountB]); + }); + + it('removes added keys', async function () { + await this.map.set(keyA, accountA); + + const receipt = await this.map.remove(keyA); + expectEvent(receipt, 'OperationResult', { result: true }); + + expect(await this.map.contains(keyA)).to.equal(false); + await expectMembersMatch(this.map, [], []); + }); + + it('returns false when removing keys not in the set', async function () { + const receipt = await this.map.remove(keyA); + expectEvent(receipt, 'OperationResult', { result: false }); + + expect(await this.map.contains(keyA)).to.equal(false); + }); + + it('adds and removes multiple keys', async function () { + // [] + + await this.map.set(keyA, accountA); + await this.map.set(keyC, accountC); + + // [A, C] + + await this.map.remove(keyA); + await this.map.remove(keyB); + + // [C] + + await this.map.set(keyB, accountB); + + // [C, B] + + await this.map.set(keyA, accountA); + await this.map.remove(keyC); + + // [A, B] + + await this.map.set(keyA, accountA); + await this.map.set(keyB, accountB); + + // [A, B] + + await this.map.set(keyC, accountC); + await this.map.remove(keyA); + + // [B, C] + + await this.map.set(keyA, accountA); + await this.map.remove(keyB); + + // [A, C] + + await expectMembersMatch(this.map, [keyA, keyC], [accountA, accountC]); + + expect(await this.map.contains(keyB)).to.equal(false); + }); +}); diff --git a/test/utils/EnumerableSet.test.js b/test/utils/EnumerableSet.test.js index 9d309c661c6..8585b755dc2 100644 --- a/test/utils/EnumerableSet.test.js +++ b/test/utils/EnumerableSet.test.js @@ -11,18 +11,16 @@ describe('EnumerableSet', function () { this.set = await EnumerableSetMock.new(); }); - async function expectMembersMatch (set, members) { - await Promise.all(members.map(async account => + async function expectMembersMatch (set, values) { + await Promise.all(values.map(async account => expect(await set.contains(account)).to.equal(true) )); - expect(await set.enumerate()).to.have.same.members(members); + expect(await set.length()).to.bignumber.equal(values.length.toString()); - expect(await set.length()).to.bignumber.equal(members.length.toString()); - - expect(await Promise.all([...Array(members.length).keys()].map(index => + expect(await Promise.all([...Array(values.length).keys()].map(index => set.at(index) - ))).to.have.same.members(members); + ))).to.have.same.members(values); } it('starts empty', async function () { @@ -33,7 +31,7 @@ describe('EnumerableSet', function () { it('adds a value', async function () { const receipt = await this.set.add(accountA); - expectEvent(receipt, 'TransactionResult', { result: true }); + expectEvent(receipt, 'OperationResult', { result: true }); await expectMembersMatch(this.set, [accountA]); }); @@ -46,11 +44,11 @@ describe('EnumerableSet', function () { expect(await this.set.contains(accountC)).to.equal(false); }); - it('returns false when adding elements already in the set', async function () { + it('returns false when adding values already in the set', async function () { await this.set.add(accountA); const receipt = (await this.set.add(accountA)); - expectEvent(receipt, 'TransactionResult', { result: false }); + expectEvent(receipt, 'OperationResult', { result: false }); await expectMembersMatch(this.set, [accountA]); }); @@ -63,15 +61,15 @@ describe('EnumerableSet', function () { await this.set.add(accountA); const receipt = await this.set.remove(accountA); - expectEvent(receipt, 'TransactionResult', { result: true }); + expectEvent(receipt, 'OperationResult', { result: true }); expect(await this.set.contains(accountA)).to.equal(false); await expectMembersMatch(this.set, []); }); - it('returns false when removing elements not in the set', async function () { + it('returns false when removing values not in the set', async function () { const receipt = await this.set.remove(accountA); - expectEvent(receipt, 'TransactionResult', { result: false }); + expectEvent(receipt, 'OperationResult', { result: false }); expect(await this.set.contains(accountA)).to.equal(false); });