diff --git a/contracts/SafeProtocolManager.sol b/contracts/SafeProtocolManager.sol index edee03f2..ec8ba73c 100644 --- a/contracts/SafeProtocolManager.sol +++ b/contracts/SafeProtocolManager.sol @@ -326,8 +326,10 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana bytes memory signatures, address msgSender ) external { - // Store hooks address in tempHooksAddress so that checkAfterExecution(...) and checkModuleTransaction(...) can access it. - address tempHooksAddressForSafe = tempHooksAddress[msg.sender] = enabledHooks[msg.sender]; + // Store hooks address in tempHooksAddress so that checkAfterExecution(...) can access it. + // A temprary storage is required to use old hooks in checkAfterExecution if hooks get updated in between transaction + tempHooksAddress[msg.sender] = enabledHooks[msg.sender]; + address tempHooksAddressForSafe = enabledHooks[msg.sender]; if (tempHooksAddressForSafe == address(0)) return; bytes memory executionMetadata = abi.encode( @@ -372,6 +374,16 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana tempHooksAddress[msg.sender] = address(0); } + /** + * @notice This function is introduced in Safe contracts v1.5 and used for checking module transactions when a guard is enabled. + * This function will be called when executing a transaction from a module with Safe v1.5 and Manager enabled as Guard on Safe. + * @param to The address to which the transaction is intended. + * @param value The value of the transaction in Wei. + * @param data The transaction data. + * @param operation The type of operation of the transaction. + * @param module The module involved in the transaction. + * @return moduleTxHash The hash of the module transaction. + */ function checkModuleTransaction( address to, uint256 value, @@ -379,27 +391,28 @@ contract SafeProtocolManager is ISafeProtocolManager, RegistryManager, HooksMana Enum.Operation operation, address module /* onlyPermittedPlugin(module) uncomment this? */ // Use term plugin? ) external returns (bytes32 moduleTxHash) { - // Store hooks address in tempHooksAddress so that checkAfterExecution(...) and checkModuleTransaction(...) can access it. - address tempHooksAddressForSafe = tempHooksAddress[msg.sender] = enabledHooks[msg.sender]; - - bytes memory executionMetadata = abi.encode(to, value, data, operation, module); + // Store hooks address in tempHooksAddress so that checkAfterExecution(...) can access it. + // A temprary storage is required to use old hooks in checkAfterExecution if hooks get updated in between transaction + tempHooksAddress[msg.sender] = enabledHooks[msg.sender]; + address tempHooksAddressForSafe = enabledHooks[msg.sender]; - if (tempHooksAddressForSafe == address(0)) return keccak256(executionMetadata); + moduleTxHash = keccak256(abi.encode(to, value, data, operation, module)); + if (tempHooksAddressForSafe == address(0)) return moduleTxHash; if (operation == Enum.Operation.Call) { SafeProtocolAction[] memory actions = new SafeProtocolAction[](1); actions[0] = SafeProtocolAction(payable(to), value, data); SafeTransaction memory safeTx = SafeTransaction(actions, 0, ""); - ISafeProtocolHooks(tempHooksAddressForSafe).preCheck(ISafe(msg.sender), safeTx, 0, executionMetadata); + ISafeProtocolHooks(tempHooksAddressForSafe).preCheck(ISafe(msg.sender), safeTx, 1, abi.encode(module)); } else { // Using else instead of "else if(operation == Enum.Operation.DelegateCall)" to reduce gas usage // and Safe allows only Call and DelegateCall operations. SafeProtocolAction memory action = SafeProtocolAction(payable(to), value, data); SafeRootAccess memory safeTx = SafeRootAccess(action, 0, ""); - ISafeProtocolHooks(tempHooksAddressForSafe).preCheckRootAccess(ISafe(msg.sender), safeTx, 0, executionMetadata); + ISafeProtocolHooks(tempHooksAddressForSafe).preCheckRootAccess(ISafe(msg.sender), safeTx, 1, abi.encode(module)); } - return keccak256(executionMetadata); + return moduleTxHash; } function supportsInterface(bytes4 interfaceId) external view virtual override returns (bool) { diff --git a/contracts/base/HooksManager.sol b/contracts/base/HooksManager.sol index 1668bd4f..7c55724f 100644 --- a/contracts/base/HooksManager.sol +++ b/contracts/base/HooksManager.sol @@ -1,9 +1,11 @@ // SPDX-License-Identifier: LGPL-3.0-only pragma solidity ^0.8.18; import {ISafeProtocolHooks} from "../interfaces/Integrations.sol"; + +import {RegistryManager} from "./RegistryManager.sol"; import {OnlyAccountCallable} from "./OnlyAccountCallable.sol"; -contract HooksManager is OnlyAccountCallable { +abstract contract HooksManager is RegistryManager, OnlyAccountCallable { mapping(address => address) public enabledHooks; /// @notice This variable should store the address of the hooks contract whenever @@ -14,9 +16,6 @@ contract HooksManager is OnlyAccountCallable { // Events event HooksChanged(address indexed safe, address indexed hooksAddress); - // Errors - error AddressDoesNotImplementHooksInterface(address hooksAddress); - /** * @notice Returns the address of hooks for a Safe account provided as a fucntion parameter. * Returns address(0) is no hooks are enabled. @@ -32,8 +31,10 @@ contract HooksManager is OnlyAccountCallable { * @param hooks Address of the hooks to be enabled for msg.sender. */ function setHooks(address hooks) external onlyAccount { - if (hooks != address(0) && !ISafeProtocolHooks(hooks).supportsInterface(type(ISafeProtocolHooks).interfaceId)) { - revert AddressDoesNotImplementHooksInterface(hooks); + if (hooks != address(0)) { + checkPermittedIntegration(hooks); + if (!ISafeProtocolHooks(hooks).supportsInterface(type(ISafeProtocolHooks).interfaceId)) + revert AccountDoesNotImplementValidInterfaceId(hooks); } enabledHooks[msg.sender] = hooks; emit HooksChanged(msg.sender, hooks); diff --git a/contracts/test/TestExecutor.sol b/contracts/test/TestExecutor.sol index 3965721e..eadc82da 100644 --- a/contracts/test/TestExecutor.sol +++ b/contracts/test/TestExecutor.sol @@ -1,7 +1,6 @@ // SPDX-License-Identifier: LGPL-3.0-only pragma solidity ^0.8.18; import {ISafe} from "../interfaces/Accounts.sol"; -import {MockContract} from "@safe-global/mock-contract/contracts/MockContract.sol"; contract TestExecutor is ISafe { address public module; diff --git a/test/SafeProtocolManager.spec.ts b/test/SafeProtocolManager.spec.ts index ecb8b6b2..11d84dd3 100644 --- a/test/SafeProtocolManager.spec.ts +++ b/test/SafeProtocolManager.spec.ts @@ -8,7 +8,7 @@ import { buildRootTx, buildSingleTx } from "./utils/builder"; import { getHooksWithFailingPrechecks, getHooksWithPassingChecks, getHooksWithFailingPostCheck } from "./utils/mockHooksBuilder"; import { IntegrationType } from "./utils/constants"; import { getInstance } from "./utils/contracts"; -import { SafeProtocolManager } from "../typechain-types"; +import { MockContract, SafeProtocolManager } from "../typechain-types"; describe("SafeProtocolManager", async () => { let deployer: SignerWithAddress, owner: SignerWithAddress, user1: SignerWithAddress, user2: SignerWithAddress; @@ -475,6 +475,7 @@ describe("SafeProtocolManager", async () => { const { safeProtocolManager, safe, safeProtocolRegistry } = await loadFixture(deployContractsWithEnabledManagerFixture); // Enable hooks on a safe const hooks = await getHooksWithPassingChecks(); + await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks); const dataSetHooks = safeProtocolManager.interface.encodeFunctionData("setHooks", [await hooks.getAddress()]); await safe.exec(safe.target, 0, dataSetHooks); @@ -509,6 +510,7 @@ describe("SafeProtocolManager", async () => { const { safeProtocolManager, safe, safeProtocolRegistry } = await loadFixture(deployContractsWithEnabledManagerFixture); // Enable hooks on a safe const hooks = await getHooksWithFailingPrechecks(); + await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks); const dataSetHooks = safeProtocolManager.interface.encodeFunctionData("setHooks", [await hooks.getAddress()]); await safe.exec(safe.target, 0, dataSetHooks); @@ -529,6 +531,7 @@ describe("SafeProtocolManager", async () => { const { safeProtocolManager, safe, safeProtocolRegistry } = await loadFixture(deployContractsWithEnabledManagerFixture); // Enable hooks on a safe const hooks = await getHooksWithFailingPostCheck(); + await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks); const dataSetHooks = safeProtocolManager.interface.encodeFunctionData("setHooks", [await hooks.getAddress()]); await safe.exec(safe.target, 0, dataSetHooks); @@ -671,8 +674,11 @@ describe("SafeProtocolManager", async () => { it("Should execute a transaction from root access enabled plugin with hooks enabled", async () => { const { safeProtocolManager, safe, safeProtocolRegistry } = await loadFixture(deployContractsWithEnabledManagerFixture); const safeAddress = await safe.getAddress(); + // Enable hooks on a safe const hooks = await getHooksWithPassingChecks(); + await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks); + const dataSetHooks = safeProtocolManager.interface.encodeFunctionData("setHooks", [await hooks.getAddress()]); await safe.exec(safe.target, 0, dataSetHooks); @@ -716,6 +722,7 @@ describe("SafeProtocolManager", async () => { const { safeProtocolManager, safe, safeProtocolRegistry } = await loadFixture(deployContractsWithEnabledManagerFixture); // Enable hooks on a safe const hooks = await getHooksWithFailingPrechecks(); + await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks); const dataSetHooks = safeProtocolManager.interface.encodeFunctionData("setHooks", [await hooks.getAddress()]); await safe.exec(safe.target, 0, dataSetHooks); @@ -748,6 +755,8 @@ describe("SafeProtocolManager", async () => { // Enable hooks on a safe const hooks = await getHooksWithFailingPostCheck(); + await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks); + const dataSetHooks = safeProtocolManager.interface.encodeFunctionData("setHooks", [await hooks.getAddress()]); await safe.exec(safe.target, 0, dataSetHooks); @@ -928,6 +937,7 @@ describe("SafeProtocolManager", async () => { await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks); await safeProtocolRegistry.connect(owner).addIntegration(hooksWithFailingPreChecks.target, IntegrationType.Hooks); + await safeProtocolRegistry.connect(owner).addIntegration(hooksWithFailingPostCheck.target, IntegrationType.Hooks); return { safe, safeProtocolManager, hooks, hooksWithFailingPreChecks, hooksWithFailingPostCheck }; }); @@ -1062,6 +1072,15 @@ describe("SafeProtocolManager", async () => { ]); expect(await safe.executeCallViaMock(safeProtocolManager.target, 0, execPostChecks, MaxUint256)); + + // Check if temporary hooks related storage is cleared after tx + expect(await safeProtocolManager.tempHooksAddress.staticCall(safe.target)).to.deep.equal(ZeroAddress); + + const mockHooks = await getInstance("MockContract", hooks.target); + // Pre-check hooks calls + expect(await mockHooks.invocationCountForMethod("0x176ae7b7")).to.equal(1); + const postCheckCallData = hooks.interface.encodeFunctionData("postCheck", [safe.target, true, "0x"]); + expect(await mockHooks.invocationCountForCalldata(postCheckCallData)).to.equal(1); }); it("Should pass hooks checks for module transaction with call operation", async () => { @@ -1084,8 +1103,25 @@ describe("SafeProtocolManager", async () => { hre.ethers.randomBytes(32), true, ]); - expect(await safe.executeCallViaMock(safeProtocolManager.target, 0, execPostChecks, MaxUint256)); + + // Check if temporary hooks related storage is cleared after tx + expect(await safeProtocolManager.tempHooksAddress.staticCall(safe.target)).to.deep.equal(ZeroAddress); + + const mockHooks = await getInstance("MockContract", hooks.target); + // preCheck hooks calls + const safeTx = buildSingleTx(user2.address, 0n, "0x", 0n, hre.ethers.ZeroHash); + const preCheckCalldata = hooks.interface.encodeFunctionData("preCheck", [ + safe.target, + safeTx, + 1, + hre.ethers.AbiCoder.defaultAbiCoder().encode(["address"], [ZeroAddress]), + ]); + expect(await mockHooks.invocationCountForMethod("0x176ae7b7")).to.equal(1); + expect(await mockHooks.invocationCountForCalldata(preCheckCalldata)).to.equal(1); + const postCheckCallData = hooks.interface.encodeFunctionData("postCheck", [safe.target, true, "0x"]); + + expect(await mockHooks.invocationCountForCalldata(postCheckCallData)).to.equal(1); }); it("Should pass hooks checks for module transaction with delegateCall operation", async () => { @@ -1110,9 +1146,25 @@ describe("SafeProtocolManager", async () => { ]); expect(await safe.executeCallViaMock(safeProtocolManager.target, 0, execPostChecks, MaxUint256)); + + const mockHooks = await getInstance("MockContract", hooks.target); + // preCheck hooks calls + const safeTx = buildRootTx(user2.address, 0n, "0x", 0n, hre.ethers.ZeroHash); + const preCheckCalldata = hooks.interface.encodeFunctionData("preCheckRootAccess", [ + safe.target, + safeTx, + 1, + hre.ethers.AbiCoder.defaultAbiCoder().encode(["address"], [ZeroAddress]), + ]); + // 0x7359b742 -> preCheckRootAccess function signature + expect(await mockHooks.invocationCountForMethod("0x7359b742")).to.equal(1); + expect(await mockHooks.invocationCountForCalldata(preCheckCalldata)).to.equal(1); + const postCheckCallData = hooks.interface.encodeFunctionData("postCheck", [safe.target, true, "0x"]); + + expect(await mockHooks.invocationCountForCalldata(postCheckCallData)).to.equal(1); }); - it("Should execute pass hooks checks for delegateCall operation", async () => { + it("Should pass hooks checks for delegateCall operation", async () => { const { safe, safeProtocolManager, hooks } = await setupTests(); // Set Hooks contract for the Safe const dataSetHooks = safeProtocolManager.interface.encodeFunctionData("setHooks", [hooks.target]); @@ -1139,6 +1191,12 @@ describe("SafeProtocolManager", async () => { ]); expect(await safe.executeCallViaMock(safeProtocolManager.target, 0, execPostChecks, MaxUint256)); + + const mockHooks = await getInstance("MockContract", hooks.target); + // preCheckRootAccess hooks calls + expect(await mockHooks.invocationCountForMethod("0x7359b742")).to.equal(1); + const postCheckCallData = hooks.interface.encodeFunctionData("postCheck", [safe.target, true, "0x"]); + expect(await mockHooks.invocationCountForCalldata(postCheckCallData)).to.equal(1); }); it("uses old hooks in checkAfterExecution if hooks get updated in between transactions", async () => { diff --git a/test/base/HooksManager.spec.ts b/test/base/HooksManager.spec.ts index 0e93bfaf..695f8d74 100644 --- a/test/base/HooksManager.spec.ts +++ b/test/base/HooksManager.spec.ts @@ -3,6 +3,7 @@ import { expect } from "chai"; import { SignerWithAddress } from "@nomicfoundation/hardhat-ethers/signers"; import { getHooksWithPassingChecks } from "../utils/mockHooksBuilder"; import { ZeroAddress } from "ethers"; +import { IntegrationType } from "../utils/constants"; describe("HooksManager", async () => { let deployer: SignerWithAddress, user1: SignerWithAddress, owner: SignerWithAddress; @@ -21,8 +22,9 @@ describe("HooksManager", async () => { const safe = await hre.ethers.deployContract("TestExecutor", [hooksManager.target], { signer: deployer }); const hooks = await getHooksWithPassingChecks(); + await safeProtocolRegistry.connect(owner).addIntegration(hooks.target, IntegrationType.Hooks); - return { hooksManager, hooks, safe }; + return { hooksManager, hooks, safe, safeProtocolRegistry }; }); it("Should emit HooksChanged event when hooks are enabled", async () => { @@ -65,15 +67,17 @@ describe("HooksManager", async () => { await expect(hooksManager.setHooks(hooksAddress)).to.be.reverted; }); - it("Should revert AddressDoesNotImplementHooksInterface if user attempts address does not implement Hooks interface", async () => { - const { hooksManager, safe } = await setupTests(); + it("Should revert AccountDoesNotImplementValidInterfaceId if user attempts address does not implement Hooks interface", async () => { + const { hooksManager, safe, safeProtocolRegistry } = await setupTests(); const contractNotImplementingHooksInterface = await (await hre.ethers.getContractFactory("MockContract")).deploy(); - await contractNotImplementingHooksInterface.givenMethodReturnBool("0x01ffc9a7", false); + await contractNotImplementingHooksInterface.givenMethodReturnBool("0x01ffc9a7", true); + await safeProtocolRegistry.connect(owner).addIntegration(contractNotImplementingHooksInterface.target, IntegrationType.Hooks); + await contractNotImplementingHooksInterface.givenMethodReturnBool("0x01ffc9a7", false); const calldata = hooksManager.interface.encodeFunctionData("setHooks", [contractNotImplementingHooksInterface.target]); await expect(safe.exec(safe.target, 0n, calldata)).to.be.revertedWithCustomError( hooksManager, - "AddressDoesNotImplementHooksInterface", + "AccountDoesNotImplementValidInterfaceId", ); }); });