Skip to content

Commit

Permalink
[#47] Add tests for function handler
Browse files Browse the repository at this point in the history
  • Loading branch information
akshay-ap committed Aug 20, 2023
1 parent f48c12d commit 8319bfd
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 43 deletions.
6 changes: 5 additions & 1 deletion contracts/base/OnlySelfCallable.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ abstract contract OnlySelfCallable {
error InvalidCalldataLength();

modifier onlySelf() {
checkOnlySelf();
_;
}

function checkOnlySelf() private view {
if (msg.data.length < 20) {
revert InvalidCalldataLength();
}
Expand All @@ -17,6 +22,5 @@ abstract contract OnlySelfCallable {
if (sender != msg.sender) {
revert InvalidSender(sender);
}
_;
}
}
3 changes: 1 addition & 2 deletions contracts/test/TestExecutor.sol
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ contract TestExecutor is ISafe {
// The handler can make us of `HandlerContext.sol` to extract the address.
// This is done because in the next call frame the `msg.sender` will be FallbackManager's address
// and having the original caller address may enable additional verification scenarios.
// Source: https://github.com/safe-global/safe-contracts/blob/main/contracts/base/FallbackManager.sol#L62
// solhint-disable-next-line payable-fallback,no-complex-fallback
fallback() external {
address handler = fallbackHandler;
// solhint-disable-next-line no-inline-assembly
/// @solidity memory-safe-assembly
assembly {
// When compiled with the optimizer, the compiler relies on a certain assumptions on how the
// memory is used, therefore we need to guarantee memory safety (keeping the free memory point 0x40 slot intact,
Expand Down Expand Up @@ -123,6 +123,5 @@ contract TestExecutor is ISafe {
return(returnDataPtr, returndatasize())
}
}

receive() external payable {}
}
57 changes: 54 additions & 3 deletions test/FunctionHandlerManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import hre, { deployments, ethers } from "hardhat";
import { getMockFunctionHandler } from "./utils/mockFunctionHandlerBuilder";
import { IntegrationType } from "./utils/constants";
import { expect } from "chai";
import { getMockTestExecutorInstance, getInstance } from "./utils/contracts";
import { getInstance } from "./utils/contracts";
import { MaxUint256, ZeroAddress } from "ethers";
import { ISafeProtocolFunctionHandler__factory, MockContract } from "../typechain-types";

Expand All @@ -27,8 +27,7 @@ describe("Test Function Handler", async () => {
).deploy(owner.address, await safeProtocolRegistry.getAddress());

await safeProtocolRegistry.addIntegration(mockFunctionHandler.target, IntegrationType.FunctionHandler);

const safe = await getMockTestExecutorInstance();
const safe = await hre.ethers.deployContract("TestExecutor", [functionHandlerManager.target], { signer: deployer });

return { safe, functionHandlerManager, mockFunctionHandler, safeProtocolRegistry };
});
Expand Down Expand Up @@ -136,4 +135,56 @@ describe("Test Function Handler", async () => {
expect(await mockContract.invocationCountForCalldata(expectedCallData)).to.equal(1n);
expect(await mockContract.invocationCount()).to.equal(1n);
});

it("Should revert if address does not implement expected interface Id", async () => {
const { safe, functionHandlerManager, mockFunctionHandler } = await setupTests();

const mock = await getInstance<MockContract>("MockContract", mockFunctionHandler.target);
await mock.givenMethodReturnBool("0x01ffc9a7", false);
// 0xf8a8fd6d -> function test() external {}
const functionId = "0xf8a8fd6d";
const dataSetFunctionHandler = functionHandlerManager.interface.encodeFunctionData("setFunctionHandler", [
functionId,
mockFunctionHandler.target,
]);

await expect(safe.executeCallViaMock(functionHandlerManager, 0n, dataSetFunctionHandler, MaxUint256))
.to.be.revertedWithCustomError(functionHandlerManager, "AccountDoesNotImplementValidInterfaceId")
.withArgs(mockFunctionHandler.target);
});

// it("Should revert with InvalidSender when caller it not safe", async () => {
// const { functionHandlerManager, mockFunctionHandler, safeProtocolRegistry } = await setupTests();
// const plugin = await (await hre.ethers.getContractFactory("TestPluginWithRootAccess")).deploy();
// const safe = await getSafeWithOwners([owner], 1, functionHandlerManager.target);

// const encodedPluginAdd = functionHandlerManager.interface.encodeFunctionData("enablePlugin", [plugin.target, true]);

// await safeProtocolRegistry.addIntegration(plugin.target, IntegrationType.Plugin);
// const funcSIg = "0x250db3c0"; // enable plugin;
// await functionHandlerManager.connect(user1).setFunctionHandler(funcSIg, mockFunctionHandler.target);

// await (
// await user1.sendTransaction({
// to: safe.target,
// value: 0,
// data: encodedPluginAdd,
// })
// ).wait();

// const encodedPluginDisable = functionHandlerManager.interface.encodeFunctionData("disablePlugin", [
// SENTINEL_MODULES,
// plugin.target,
// ]);

// await (
// await user1.sendTransaction({
// to: safe.target,
// value: 0,
// data: encodedPluginDisable,
// })
// ).wait();

// console.log(await functionHandlerManager.isPluginEnabled.staticCall(plugin.target, safe.target));
// });
});
36 changes: 29 additions & 7 deletions test/SafeProtocolManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import { SignerWithAddress } from "@nomicfoundation/hardhat-ethers/signers";
import { buildRootTx, buildSingleTx } from "./utils/builder";
import { getHooksWithFailingPrechecks, getHooksWithPassingChecks, getHooksWithFailingPostCheck } from "./utils/mockHooksBuilder";
import { IntegrationType } from "./utils/constants";
import { getMockTestExecutorInstance } from "./utils/contracts";
import { getInstance } from "./utils/contracts";
import { SafeProtocolManager } from "../typechain-types";

describe("SafeProtocolManager", async () => {
let deployer: SignerWithAddress, owner: SignerWithAddress, user1: SignerWithAddress, user2: SignerWithAddress;
Expand All @@ -24,9 +25,6 @@ describe("SafeProtocolManager", async () => {
).deploy(owner.address, await safeProtocolRegistry.getAddress());

const safe = await hre.ethers.deployContract("TestExecutor", [safeProtocolManager.target], { signer: deployer });
// const otherSafe = await hre.ethers.deployContract("TestExecutor", [manager.target], {signer: deployer});
// const safeProtocolManager = await getInstance<SafeProtocolManager>("SafeProtocolManager", safe.target);
// const safeProtocolManager = (await hre.ethers.getContractFactory("SafeProtocolManager")).attach(otherSafe.target);
return { safeProtocolManager, safeProtocolRegistry, safe };
});

Expand Down Expand Up @@ -72,6 +70,19 @@ describe("SafeProtocolManager", async () => {
.withArgs(hre.ethers.ZeroAddress);
});

it("Blocks calls not initiated from Safe", async () => {
const { safeProtocolManager, plugin, safe } = await loadFixture(deployContractsWithPluginFixture);
const pluginAddress = await plugin.getAddress();
await expect(safeProtocolManager.enablePlugin(pluginAddress, false))
.to.be.revertedWithCustomError(safeProtocolManager, "InvalidSender")
.withArgs(ZeroAddress);

const contract = await getInstance<SafeProtocolManager>("SafeProtocolManager", safe);
await expect(contract.connect(user1).enablePlugin(pluginAddress, false))
.to.be.revertedWithCustomError(safeProtocolManager, "InvalidSender")
.withArgs(user1.address);
});

it("Should not allow a Safe to enable plugin if not added as a integration in registry", async () => {
const { safeProtocolManager, safe } = await loadFixture(deployContractsWithPluginFixture);
await safe.setModule(await safeProtocolManager.getAddress());
Expand Down Expand Up @@ -139,6 +150,19 @@ describe("SafeProtocolManager", async () => {
.withArgs(hre.ethers.ZeroAddress);
});

it("Blocks calls not initiated from Safe", async () => {
const { safeProtocolManager, plugin, safe } = await loadFixture(deployContractsWithPluginFixture);
const pluginAddress = await plugin.getAddress();

const data = safeProtocolManager.interface.encodeFunctionData("enablePlugin", [pluginAddress, false]);
await safe.exec(safe.target, 0, data);

const contract = await getInstance<SafeProtocolManager>("SafeProtocolManager", safe);
await expect(contract.connect(user1).disablePlugin(SENTINEL_MODULES, pluginAddress))
.to.be.revertedWithCustomError(safeProtocolManager, "InvalidSender")
.withArgs(user1.address);
});

it("Should not allow a Safe to disable SENTINEL_MODULES plugin", async () => {
const { safeProtocolManager, safe } = await loadFixture(deployContractsWithPluginFixture);
const safeProtocolManagerAddress = await safeProtocolManager.getAddress();
Expand Down Expand Up @@ -503,7 +527,6 @@ describe("SafeProtocolManager", async () => {

it("Should fail executing a transaction through plugin when hooks post-check fails", async () => {
const { safeProtocolManager, safe, safeProtocolRegistry } = await loadFixture(deployContractsWithEnabledManagerFixture);
const safeProtocolManagerAddress = await safeProtocolManager.getAddress();
// Enable hooks on a safe
const hooks = await getHooksWithFailingPostCheck();

Expand Down Expand Up @@ -722,7 +745,6 @@ describe("SafeProtocolManager", async () => {
it("Should fail to execute a transaction from root access enabled plugin when hooks post-check fails", async () => {
const { safeProtocolManager, safe, safeProtocolRegistry } = await loadFixture(deployContractsWithEnabledManagerFixture);
const safeAddress = await safe.getAddress();
const safeProtocolManagerAddress = await safeProtocolManager.getAddress();

// Enable hooks on a safe
const hooks = await getHooksWithFailingPostCheck();
Expand Down Expand Up @@ -898,7 +920,7 @@ describe("SafeProtocolManager", async () => {
await hre.ethers.getContractFactory("SafeProtocolManager")
).deploy(owner.address, safeProtocolRegistry.target);

const safe = await getMockTestExecutorInstance(safeProtocolManager.target);
const safe = await hre.ethers.deployContract("TestExecutor", [safeProtocolManager.target], { signer: deployer });

const hooks = await getHooksWithPassingChecks();
const hooksWithFailingPreChecks = await getHooksWithFailingPrechecks();
Expand Down
61 changes: 36 additions & 25 deletions test/base/HooksManager.spec.ts
Original file line number Diff line number Diff line change
@@ -1,66 +1,77 @@
import hre from "hardhat";
import { loadFixture } from "@nomicfoundation/hardhat-toolbox/network-helpers";
import hre, { deployments } from "hardhat";
import { expect } from "chai";
import { SignerWithAddress } from "@nomicfoundation/hardhat-ethers/signers";
import { getHooksWithPassingChecks } from "../utils/mockHooksBuilder";
import { ZeroAddress } from "ethers";

describe("HooksManager", async () => {
let deployer: SignerWithAddress, user1: SignerWithAddress;
let deployer: SignerWithAddress, user1: SignerWithAddress, owner: SignerWithAddress;

before(async () => {
[deployer, user1] = await hre.ethers.getSigners();
[deployer, owner, user1] = await hre.ethers.getSigners();
});

async function deployContractsFixture() {
const setupTests = deployments.createFixture(async ({ deployments }) => {
await deployments.fixture();
[deployer, user1] = await hre.ethers.getSigners();
const safeProtocolRegistry = await hre.ethers.deployContract("SafeProtocolRegistry", [owner.address]);
const hooksManager = await (
await hre.ethers.getContractFactory("SafeProtocolManager")
).deploy(owner.address, await safeProtocolRegistry.getAddress());

const hooksManager = await hre.ethers.deployContract("HooksManager", { signer: deployer });
const safe = await hre.ethers.deployContract("TestExecutor", [hooksManager.target], { signer: deployer });
const hooks = await getHooksWithPassingChecks();

return { hooksManager, hooks };
}
return { hooksManager, hooks, safe };
});

it("Should emit HooksChanged event when hooks are enabled", async () => {
const { hooksManager, hooks } = await loadFixture(deployContractsFixture);
const hooksAddress = await hooks.getAddress();
expect(await hooksManager.connect(user1).setHooks(hooksAddress))
const { hooksManager, hooks, safe } = await setupTests();

const calldata = hooksManager.interface.encodeFunctionData("setHooks", [hooks.target]);

expect(await safe.exec(safe.target, 0n, calldata))
.to.emit(hooksManager, "HooksChanged")
.withArgs(user1, hooksAddress);
.withArgs(safe.target, hooks.target);
});

it("Should return correct hooks address", async () => {
const { hooksManager, hooks } = await loadFixture(deployContractsFixture);
const hooksAddress = await hooks.getAddress();
await hooksManager.connect(user1).setHooks(hooksAddress);
expect(await hooksManager.getEnabledHooks(user1.address)).to.be.equal(hooksAddress);
const { hooksManager, hooks, safe } = await setupTests();
const calldata = hooksManager.interface.encodeFunctionData("setHooks", [hooks.target]);
await safe.exec(safe.target, 0n, calldata);
expect(await hooksManager.getEnabledHooks(safe.target)).to.be.equal(hooks.target);
});

it("Should return zero address if hooks are not enabled", async () => {
const { hooksManager } = await loadFixture(deployContractsFixture);
const { hooksManager } = await setupTests();
expect(await hooksManager.getEnabledHooks(user1.address)).to.be.equal(hre.ethers.ZeroAddress);
});

it("Should return zero address if hooks address is reset to zero address", async () => {
const { hooksManager, hooks } = await loadFixture(deployContractsFixture);
const { hooksManager, hooks, safe } = await setupTests();

const hooksAddress = await hooks.getAddress();
expect(await hooksManager.connect(user1).setHooks(hooksAddress));
expect(await hooksManager.connect(user1).setHooks(hre.ethers.ZeroAddress));
expect(await hooksManager.getEnabledHooks(user1.address)).to.be.equal(hre.ethers.ZeroAddress);
const calldata = hooksManager.interface.encodeFunctionData("setHooks", [hooks.target]);
await safe.exec(safe.target, 0n, calldata);

const calldata2 = hooksManager.interface.encodeFunctionData("setHooks", [ZeroAddress]);
await safe.exec(safe.target, 0n, calldata2);

expect(await hooksManager.getEnabledHooks(safe.target)).to.be.equal(hre.ethers.ZeroAddress);
});

it("Should revert if user attempts to set random address as hooks", async () => {
const { hooksManager } = await loadFixture(deployContractsFixture);
const { hooksManager } = await setupTests();
const hooksAddress = hre.ethers.getAddress(hre.ethers.hexlify(hre.ethers.randomBytes(20)));
await expect(hooksManager.setHooks(hooksAddress)).to.be.reverted;
});

it("Should revert AddressDoesNotImplementHooksInterface if user attempts address does not implement Hooks interface", async () => {
const { hooksManager } = await loadFixture(deployContractsFixture);
const { hooksManager, safe } = await setupTests();
const contractNotImplementingHooksInterface = await (await hre.ethers.getContractFactory("MockContract")).deploy();
await contractNotImplementingHooksInterface.givenMethodReturnBool("0x01ffc9a7", false);

await expect(hooksManager.setHooks(await contractNotImplementingHooksInterface.getAddress())).to.be.revertedWithCustomError(
const calldata = hooksManager.interface.encodeFunctionData("setHooks", [contractNotImplementingHooksInterface.target]);
await expect(safe.exec(safe.target, 0n, calldata)).to.be.revertedWithCustomError(
hooksManager,
"AddressDoesNotImplementHooksInterface",
);
Expand Down
5 changes: 0 additions & 5 deletions test/utils/contracts.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import { Addressable, BaseContract } from "ethers";
import hre from "hardhat";
import { TestExecutor } from "../../typechain-types";
export const getInstance = async <T extends BaseContract>(name: string, address: string | Addressable): Promise<T> => {
// TODO: this typecasting should be refactored
return (await hre.ethers.getContractAt(name, address)) as unknown as T;
};

export const getMockTestExecutorInstance = async (): Promise<TestExecutor> => {
return await (await hre.ethers.getContractFactory("TestExecutor")).deploy();
};

0 comments on commit 8319bfd

Please sign in to comment.