diff --git a/.gitignore b/.gitignore index da11ea3..500c9f6 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,6 @@ target/ tool-versions .vscode/ -node_modules/ \ No newline at end of file +node_modules/ + +deployment-script \ No newline at end of file diff --git a/src/airdrop/NFTHolderAirdrop.cairo b/src/airdrop/NFTHolderAirdrop.cairo index b0fcbc6..8ad537c 100644 --- a/src/airdrop/NFTHolderAirdrop.cairo +++ b/src/airdrop/NFTHolderAirdrop.cairo @@ -50,6 +50,12 @@ mod NFTHolderAirdrop { use starknet::get_contract_address; use starknet::get_caller_address; + use openzeppelin::security::ReentrancyGuardComponent; + + component!(path: ReentrancyGuardComponent, storage: reentrancy, event: ReentrancyEvent); + + impl ReentrancyInternalImpl = ReentrancyGuardComponent::InternalImpl; + #[storage] struct Storage { initialized: bool, @@ -58,6 +64,8 @@ mod NFTHolderAirdrop { eligible_nft: ContractAddress, rewards_per_nft: u256, claimed_nfts: LegacyMap::, + #[substorage(v0)] + reentrancy: ReentrancyGuardComponent::Storage } #[event] @@ -67,6 +75,8 @@ mod NFTHolderAirdrop { RewardTokensWithdrawnByAdmin: RewardTokensWithdrawnByAdmin, UpdatedRewardsPerNft: UpdatedRewardsPerNft, RewardsClaimed: RewardsClaimed, + #[flat] + ReentrancyEvent: ReentrancyGuardComponent::Event, } #[derive(Drop, starknet::Event)] @@ -111,6 +121,7 @@ mod NFTHolderAirdrop { } fn withdraw_reward_tokens(ref self: ContractState, _amount: u256) { + self.reentrancy.start(); assert(self.owner.read() == get_caller_address(), 'Only owner'); let reward_token = IERC20Dispatcher { contract_address: self.reward_token.read() }; reward_token.transfer(get_caller_address(), _amount); @@ -121,9 +132,12 @@ mod NFTHolderAirdrop { RewardTokensWithdrawnByAdmin { amount: _amount } ) ); + + self.reentrancy.end(); } - + fn set_reward_per_nft(ref self: ContractState, _rewards_per_nft: u256) { + self.reentrancy.start(); assert(self.owner.read() == get_caller_address(), 'Only owner'); let old_rewards_per_nft = self.rewards_per_nft.read(); self.rewards_per_nft.write(_rewards_per_nft); @@ -136,10 +150,13 @@ mod NFTHolderAirdrop { } ) ); + self.reentrancy.end(); } fn claim_rewards(ref self: ContractState, _token_id: u256) { + self.reentrancy.start(); self._claim_rewards(_token_id); + self.reentrancy.end(); } fn get_rewards_per_nft(self: @ContractState) -> u256 { @@ -163,7 +180,9 @@ mod NFTHolderAirdrop { self .emit( Event::RewardsClaimed( - RewardsClaimed { tokenId: _token_id, recipient: nftOwner, rewardAmount: rewards } + RewardsClaimed { + tokenId: _token_id, recipient: nftOwner, rewardAmount: rewards + } ) ); }