diff --git a/src/loan/vault/PWNVault.sol b/src/loan/vault/PWNVault.sol index 1288fb7..2485167 100644 --- a/src/loan/vault/PWNVault.sol +++ b/src/loan/vault/PWNVault.sol @@ -61,6 +61,11 @@ abstract contract PWNVault is IERC721Receiver, IERC1155Receiver { */ error IncompleteTransfer(); + /** + * @notice Thrown when an asset transfer source and destination address are the same. + */ + error VaultTransferSameSourceAndDestination(address addr); + /*----------------------------------------------------------*| |* # TRANSFER FUNCTIONS *| @@ -76,7 +81,13 @@ abstract contract PWNVault is IERC721Receiver, IERC1155Receiver { uint256 originalBalance = asset.balanceOf(address(this)); asset.transferAssetFrom(origin, address(this)); - _checkTransfer(asset, originalBalance, address(this), true); + _checkTransfer({ + asset: asset, + originalBalance: originalBalance, + checkedAddress: address(this), + counterPartyAddress: origin, + checkIncreasingBalance: true + }); emit VaultPull(asset, origin); } @@ -91,7 +102,13 @@ abstract contract PWNVault is IERC721Receiver, IERC1155Receiver { uint256 originalBalance = asset.balanceOf(beneficiary); asset.safeTransferAssetFrom(address(this), beneficiary); - _checkTransfer(asset, originalBalance, beneficiary, true); + _checkTransfer({ + asset: asset, + originalBalance: originalBalance, + checkedAddress: beneficiary, + counterPartyAddress: address(this), + checkIncreasingBalance: true + }); emit VaultPush(asset, beneficiary); } @@ -107,7 +124,13 @@ abstract contract PWNVault is IERC721Receiver, IERC1155Receiver { uint256 originalBalance = asset.balanceOf(beneficiary); asset.safeTransferAssetFrom(origin, beneficiary); - _checkTransfer(asset, originalBalance, beneficiary, true); + _checkTransfer({ + asset: asset, + originalBalance: originalBalance, + checkedAddress: beneficiary, + counterPartyAddress: origin, + checkIncreasingBalance: true + }); emit VaultPushFrom(asset, origin, beneficiary); } @@ -124,7 +147,13 @@ abstract contract PWNVault is IERC721Receiver, IERC1155Receiver { uint256 originalBalance = asset.balanceOf(owner); poolAdapter.withdraw(pool, owner, asset.assetAddress, asset.amount); - _checkTransfer(asset, originalBalance, owner, true); + _checkTransfer({ + asset: asset, + originalBalance: originalBalance, + checkedAddress: owner, + counterPartyAddress: pool, + checkIncreasingBalance: true + }); emit PoolWithdraw(asset, address(poolAdapter), pool, owner); } @@ -143,7 +172,13 @@ abstract contract PWNVault is IERC721Receiver, IERC1155Receiver { asset.transferAssetFrom(address(this), address(poolAdapter)); poolAdapter.supply(pool, owner, asset.assetAddress, asset.amount); - _checkTransfer(asset, originalBalance, address(this), false); + _checkTransfer({ + asset: asset, + originalBalance: originalBalance, + checkedAddress: address(this), + counterPartyAddress: pool, + checkIncreasingBalance: false + }); // Note: Assuming pool will revert supply transaction if it fails. @@ -154,8 +189,13 @@ abstract contract PWNVault is IERC721Receiver, IERC1155Receiver { MultiToken.Asset memory asset, uint256 originalBalance, address checkedAddress, + address counterPartyAddress, bool checkIncreasingBalance ) private view { + if (checkedAddress == counterPartyAddress) { + revert VaultTransferSameSourceAndDestination({ addr: checkedAddress }); + } + uint256 expectedBalance = checkIncreasingBalance ? originalBalance + asset.getTransferAmount() : originalBalance - asset.getTransferAmount(); diff --git a/test/unit/PWNVault.t.sol b/test/unit/PWNVault.t.sol index b3a86fc..a376c35 100644 --- a/test/unit/PWNVault.t.sol +++ b/test/unit/PWNVault.t.sol @@ -102,6 +102,14 @@ contract PWNVault_Pull_Test is PWNVaultTest { vault.pull(asset, alice); } + function test_shouldFail_whenSameSourceAndDestination() external { + t721.mint(address(vault), 42); + + vm.expectRevert(abi.encodeWithSelector(PWNVault.VaultTransferSameSourceAndDestination.selector, address(vault))); + MultiToken.Asset memory asset = MultiToken.Asset(MultiToken.Category.ERC721, address(t721), 42, 0); + vault.pull(asset, address(vault)); + } + function test_shouldEmitEvent_VaultPull() external { t721.mint(alice, 42); vm.prank(alice); @@ -148,6 +156,14 @@ contract PWNVault_Push_Test is PWNVaultTest { vault.push(asset, alice); } + function test_shouldFail_whenSameSourceAndDestination() external { + t721.mint(address(vault), 42); + + vm.expectRevert(abi.encodeWithSelector(PWNVault.VaultTransferSameSourceAndDestination.selector, address(vault))); + MultiToken.Asset memory asset = MultiToken.Asset(MultiToken.Category.ERC721, address(t721), 42, 0); + vault.push(asset, address(vault)); + } + function test_shouldEmitEvent_VaultPush() external { t721.mint(address(vault), 42); @@ -194,6 +210,16 @@ contract PWNVault_PushFrom_Test is PWNVaultTest { vault.pushFrom(asset, alice, bob); } + function test_shouldFail_whenSameSourceAndDestination() external { + t721.mint(alice, 42); + vm.prank(alice); + t721.approve(address(vault), 42); + + vm.expectRevert(abi.encodeWithSelector(PWNVault.VaultTransferSameSourceAndDestination.selector, alice)); + MultiToken.Asset memory asset = MultiToken.Asset(MultiToken.Category.ERC721, address(t721), 42, 0); + vault.pushFrom(asset, alice, alice); + } + function test_shouldEmitEvent_VaultPushFrom() external { t721.mint(alice, 42); vm.prank(alice); @@ -252,6 +278,11 @@ contract PWNVault_WithdrawFromPool_Test is PWNVaultTest { vault.withdrawFromPool(asset, poolAdapter, pool, alice); } + function test_shouldFail_whenSameSourceAndDestination() external { + vm.expectRevert(abi.encodeWithSelector(PWNVault.VaultTransferSameSourceAndDestination.selector, pool)); + vault.withdrawFromPool(asset, poolAdapter, pool, pool); + } + function test_shouldEmitEvent_PoolWithdraw() external { vm.expectEmit(); emit PoolWithdraw(asset, address(poolAdapter), pool, alice); @@ -311,6 +342,11 @@ contract PWNVault_SupplyToPool_Test is PWNVaultTest { vault.supplyToPool(asset, poolAdapter, pool, alice); } + function test_shouldFail_whenSameSourceAndDestination() external { + vm.expectRevert(abi.encodeWithSelector(PWNVault.VaultTransferSameSourceAndDestination.selector, address(vault))); + vault.supplyToPool(asset, poolAdapter, address(vault), alice); + } + function test_shouldEmitEvent_PoolSupply() external { vm.expectEmit(); emit PoolSupply(asset, address(poolAdapter), pool, alice);