diff --git a/src/enforcers/LogicalOrWrapperEnforcer.sol b/src/enforcers/LogicalOrWrapperEnforcer.sol index f0b771dc..0837c348 100644 --- a/src/enforcers/LogicalOrWrapperEnforcer.sol +++ b/src/enforcers/LogicalOrWrapperEnforcer.sol @@ -32,10 +32,11 @@ import { ModeCode, Caveat } from "../utils/Types.sol"; * - The enforcer iterates over all caveats in the specified `CaveatGroup`. * - For a group to pass, all caveats within that group must succeed. * - Every caveat in the group is evaluated. - * - The group index provided via `SelectedGroup.groupIndex` must be valid (i.e. less than or equal to the length of the terms - * array). + * - The group index provided via `SelectedGroup.groupIndex` must be valid (i.e. less than the length of the terms array). * - The length of `SelectedGroup.caveatArgs` must exactly match the number of caveats in the corresponding `CaveatGroup`. * Empty bytes can be used for caveats that do not require arguments. + * - To prevent delegationHash collisions between different caveat groups using the same delegationHash, this enforcer + * creates a unique delegationHash by combining the original delegationHash with the caveat group index. * * @dev Security Notice: This enforcer allows the redeemer to choose which caveat group to use at * execution time, via the groupIndex parameter. If multiple caveat groups are defined with varying @@ -105,6 +106,17 @@ contract LogicalOrWrapperEnforcer is CaveatEnforcer { ////////////////////////////// Public Methods ////////////////////////////// + /** + * @notice Combines a delegation hash with the group index to create a unique identifier + * @dev This is the delegationHash that the caveat enforcers will receive. + * @param _delegationHash The original delegation hash + * @param _groupIndex The index of the caveat group being evaluated + * @return bytes32 A unique hash combining the delegation hash and group index + */ + function getLogicalOrDelegationHash(bytes32 _delegationHash, uint256 _groupIndex) external pure returns (bytes32) { + return _getLogicalOrDelegationHash(_delegationHash, _groupIndex); + } + /** * @notice Hook called before all delegations are executed * @dev Validates that the execution mode is default and calls the appropriate hook on all caveats @@ -266,11 +278,21 @@ contract LogicalOrWrapperEnforcer is CaveatEnforcer { selectedGroup_.caveatArgs[i], _params.mode, _params.executionCallData, - _params.delegationHash, + _getLogicalOrDelegationHash(_params.delegationHash, selectedGroup_.groupIndex), _params.delegator, _params.redeemer ) ); } } + + /** + * @notice Combines a delegation hash with the group index to create a unique identifier + * @param _delegationHash The original delegation hash + * @param _groupIndex The index of the caveat group being evaluated + * @return bytes32 A unique hash combining the delegation hash and group index + */ + function _getLogicalOrDelegationHash(bytes32 _delegationHash, uint256 _groupIndex) internal pure returns (bytes32) { + return keccak256(abi.encode(_delegationHash, _groupIndex)); + } } diff --git a/test/enforcers/LogicalOrWrapperEnforcer.t.sol b/test/enforcers/LogicalOrWrapperEnforcer.t.sol index 18c2aba4..b92137e4 100644 --- a/test/enforcers/LogicalOrWrapperEnforcer.t.sol +++ b/test/enforcers/LogicalOrWrapperEnforcer.t.sol @@ -3,6 +3,7 @@ pragma solidity 0.8.23; import { Test } from "forge-std/Test.sol"; import { ExecutionLib } from "@erc7579/lib/ExecutionLib.sol"; +import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import { Execution, Caveat, Delegation } from "../../src/utils/Types.sol"; import { Counter } from "../utils/Counter.t.sol"; @@ -14,6 +15,12 @@ import { AllowedTargetsEnforcer } from "../../src/enforcers/AllowedTargetsEnforc import { NativeTokenTransferAmountEnforcer } from "../../src/enforcers/NativeTokenTransferAmountEnforcer.sol"; import { TimestampEnforcer } from "../../src/enforcers/TimestampEnforcer.sol"; import { ArgsEqualityCheckEnforcer } from "../../src/enforcers/ArgsEqualityCheckEnforcer.sol"; +import { ERC20PeriodTransferEnforcer } from "../../src/enforcers/ERC20PeriodTransferEnforcer.sol"; +import { NativeTokenPeriodTransferEnforcer } from "../../src/enforcers/NativeTokenPeriodTransferEnforcer.sol"; +import { LimitedCallsEnforcer } from "../../src/enforcers/LimitedCallsEnforcer.sol"; +import { EncoderLib } from "../../src/libraries/EncoderLib.sol"; +import { BasicERC20 } from "../utils/BasicERC20.t.sol"; +import "forge-std/Test.sol"; contract LogicalOrWrapperEnforcerTest is CaveatEnforcerBaseTest { ////////////////////// State ////////////////////// @@ -24,6 +31,11 @@ contract LogicalOrWrapperEnforcerTest is CaveatEnforcerBaseTest { TimestampEnforcer public timestampEnforcer; ArgsEqualityCheckEnforcer public argsEqualityCheckEnforcer; NativeTokenTransferAmountEnforcer public nativeTokenTransferAmountEnforcer; + ERC20PeriodTransferEnforcer public erc20PeriodTransferEnforcer; + NativeTokenPeriodTransferEnforcer public nativeTokenPeriodTransferEnforcer; + LimitedCallsEnforcer public limitedCallsEnforcer; + + address[] tokens = new address[](3); ////////////////////// Set up ////////////////////// @@ -35,12 +47,18 @@ contract LogicalOrWrapperEnforcerTest is CaveatEnforcerBaseTest { timestampEnforcer = new TimestampEnforcer(); argsEqualityCheckEnforcer = new ArgsEqualityCheckEnforcer(); nativeTokenTransferAmountEnforcer = new NativeTokenTransferAmountEnforcer(); + erc20PeriodTransferEnforcer = new ERC20PeriodTransferEnforcer(); + nativeTokenPeriodTransferEnforcer = new NativeTokenPeriodTransferEnforcer(); + limitedCallsEnforcer = new LimitedCallsEnforcer(); vm.label(address(logicalOrWrapperEnforcer), "Logical OR Wrapper Enforcer"); vm.label(address(allowedMethodsEnforcer), "Allowed Methods Enforcer"); vm.label(address(timestampEnforcer), "Timestamp Enforcer"); vm.label(address(argsEqualityCheckEnforcer), "Args Equality Check Enforcer"); vm.label(address(allowedTargetsEnforcer), "Allowed Targets Enforcer"); vm.label(address(nativeTokenTransferAmountEnforcer), "Native Token Transfer Amount Enforcer"); + vm.label(address(erc20PeriodTransferEnforcer), "ERC20 Period Transfer Enforcer"); + vm.label(address(nativeTokenPeriodTransferEnforcer), "Native Token Period Transfer Enforcer"); + vm.label(address(limitedCallsEnforcer), "Limited Calls Enforcer"); } ////////////////////// Helper Functions ////////////////////// @@ -285,6 +303,243 @@ contract LogicalOrWrapperEnforcerTest is CaveatEnforcerBaseTest { ); } + /// @notice Tests that multiple groups with NativeTokenPeriodTransferEnforcer work correctly through the redemption flow + function test_multipleNativeTokenPeriodTransferGroups() public { + vm.warp(block.timestamp + 1 days); + // Create 3 groups with different period transfer limits + address[] memory enforcers = new address[](1); + enforcers[0] = address(nativeTokenPeriodTransferEnforcer); + + bytes[] memory terms1_ = new bytes[](1); + terms1_[0] = abi.encode(1 ether, 1 days, block.timestamp); // 1 ETH per day + + bytes[] memory terms2_ = new bytes[](1); + terms2_[0] = abi.encode(1 ether, 2 days, block.timestamp); // 2 ETH per 2 days + + bytes[] memory terms3_ = new bytes[](1); + terms3_[0] = abi.encode(1 ether, 3 days, block.timestamp); // 3 ETH per 3 days + + LogicalOrWrapperEnforcer.CaveatGroup[] memory groups_ = new LogicalOrWrapperEnforcer.CaveatGroup[](3); + groups_[0] = _createCaveatGroup(enforcers, terms1_); + groups_[1] = _createCaveatGroup(enforcers, terms2_); + groups_[2] = _createCaveatGroup(enforcers, terms3_); + + // Create and sign delegation + Caveat[] memory caveats_ = new Caveat[](1); + caveats_[0] = Caveat({ enforcer: address(logicalOrWrapperEnforcer), terms: abi.encode(groups_), args: hex"" }); + + Delegation memory delegation_ = Delegation({ + delegate: address(users.bob.deleGator), + delegator: address(users.alice.deleGator), + authority: ROOT_AUTHORITY, + caveats: caveats_, + salt: 0, + signature: hex"" + }); + + delegation_ = signDelegation(users.alice, delegation_); + // bytes32 delegationHash_ = EncoderLib._getDelegationHash(delegation_); + + // Verify initial balance + uint256 recipientInitialBalance_ = address(0x123).balance; + for (uint256 i = 0; i < groups_.length; i++) { + // Create execution data for a 1 ETH transfer + Execution memory execution_ = Execution({ target: payable(address(0x123)), value: 1 ether, callData: hex"" }); + + // Create selected group using group index 1 (2 ETH per 2 days) + // bytes[] memory caveatArgs_ = new bytes[](1); + // caveatArgs_[0] = hex""; // No args needed for NativeTokenPeriodTransferEnforcer + LogicalOrWrapperEnforcer.SelectedGroup memory selectedGroup_ = _createSelectedGroup(i, new bytes[](1)); + delegation_.caveats[0].args = abi.encode(selectedGroup_); + + // Execute Bob's UserOp + Delegation[] memory delegations_ = new Delegation[](1); + delegations_[0] = delegation_; + + uint256 recipientBalanceBefore_ = address(0x123).balance; + + // Execute the delegation + invokeDelegation_UserOp(users.bob, delegations_, execution_); + + (uint256 availableAmount_, bool isNewPeriod_, uint256 currentPeriod_) = nativeTokenPeriodTransferEnforcer + .getAvailableAmount( + logicalOrWrapperEnforcer.getLogicalOrDelegationHash( + EncoderLib._getDelegationHash(delegation_), selectedGroup_.groupIndex + ), + address(logicalOrWrapperEnforcer), + groups_[selectedGroup_.groupIndex].caveats[0].terms + ); + assertEq(availableAmount_, 0, "Available amount should be 0"); + assertEq(isNewPeriod_, false, "Is new period should be false"); + assertEq(currentPeriod_, 1, "Current period should be 1"); + + // Verify the transfer occurred + assertEq(address(0x123).balance, recipientBalanceBefore_ + 1 ether, "Transfer should have occurred with 1 ether"); + } + // Verify the transfer occurred + assertEq(address(0x123).balance, recipientInitialBalance_ + 3 ether, "Transfer should have occurred with 3 ether"); + } + + /// @notice Tests that multiple ERC20 period transfer groups work correctly by verifying that transfers within different period + /// limits succeed + function test_multipleERC20PeriodTransferGroups() public { + vm.warp(block.timestamp + 1 days); + + // Create test token and mint initial balance + // address[] memory tokens_ = new address[](3); + tokens[0] = address(new BasicERC20(address(users.alice.deleGator), "TEST1", "TEST1", 100 ether)); + tokens[1] = address(new BasicERC20(address(users.alice.deleGator), "TEST2", "TEST2", 100 ether)); + tokens[2] = address(new BasicERC20(address(users.alice.deleGator), "TEST2", "TEST2", 100 ether)); + + // Create groups with different period transfer limits + address[] memory enforcers_ = new address[](1); + enforcers_[0] = address(erc20PeriodTransferEnforcer); + bytes[] memory terms1_ = new bytes[](1); + terms1_[0] = abi.encodePacked(address(tokens[0]), uint256(1 ether), uint256(1 days), block.timestamp); // 1 ETH per day + + bytes[] memory terms2_ = new bytes[](1); + terms2_[0] = abi.encodePacked(address(tokens[1]), uint256(1 ether), uint256(2 days), block.timestamp); // 1 ETH per 2 days + + bytes[] memory terms3_ = new bytes[](1); + terms3_[0] = abi.encodePacked(address(tokens[2]), uint256(1 ether), uint256(3 days), block.timestamp); // 1 ETH per 3 days + + LogicalOrWrapperEnforcer.CaveatGroup[] memory groups_ = new LogicalOrWrapperEnforcer.CaveatGroup[](3); + groups_[0] = _createCaveatGroup(enforcers_, terms1_); + groups_[1] = _createCaveatGroup(enforcers_, terms2_); + groups_[2] = _createCaveatGroup(enforcers_, terms3_); + + // Create the caveat with the groups + Caveat[] memory caveats_ = new Caveat[](1); + caveats_[0] = Caveat({ enforcer: address(logicalOrWrapperEnforcer), terms: abi.encode(groups_), args: hex"" }); + + Delegation memory delegation_ = Delegation({ + delegate: address(users.bob.deleGator), + delegator: address(users.alice.deleGator), + authority: ROOT_AUTHORITY, + caveats: caveats_, + salt: 0, + signature: hex"" + }); + + delegation_ = signDelegation(users.alice, delegation_); + + for (uint256 i = 0; i < groups_.length; i++) { + LogicalOrWrapperEnforcer.SelectedGroup memory selectedGroup_ = _createSelectedGroup(i, new bytes[](1)); + delegation_.caveats[0].args = abi.encode(selectedGroup_); + + // Execute Bob's UserOp + Delegation[] memory delegations_ = new Delegation[](1); + delegations_[0] = delegation_; + + uint256 recipientBalanceBefore_ = IERC20(tokens[i]).balanceOf(address(0x123)); + + // Execute the delegation + invokeDelegation_UserOp( + users.bob, + delegations_, + Execution({ + target: address(tokens[i]), + value: 0, + callData: abi.encodeWithSelector(IERC20.transfer.selector, address(0x123), 1 ether) + }) + ); + + (uint256 availableAmount_, bool isNewPeriod_, uint256 currentPeriod_) = erc20PeriodTransferEnforcer.getAvailableAmount( + logicalOrWrapperEnforcer.getLogicalOrDelegationHash( + EncoderLib._getDelegationHash(delegation_), selectedGroup_.groupIndex + ), + address(logicalOrWrapperEnforcer), + groups_[selectedGroup_.groupIndex].caveats[0].terms + ); + assertEq(availableAmount_, 0, "Available amount should be 0"); + assertEq(isNewPeriod_, false, "Is new period should be false"); + assertEq(currentPeriod_, 1, "Current period should be 1"); + + // Verify the transfer occurred + assertEq( + IERC20(tokens[i]).balanceOf(address(0x123)), + recipientBalanceBefore_ + 1 ether, + "Transfer should have occurred with 1 ether" + ); + } + // Verify the total transfers occurred + assertEq(IERC20(tokens[0]).balanceOf(address(0x123)), 1 ether, "Transfer should have occurred with 1 ether"); + assertEq(IERC20(tokens[1]).balanceOf(address(0x123)), 1 ether, "Transfer should have occurred with 1 ether"); + assertEq(IERC20(tokens[2]).balanceOf(address(0x123)), 1 ether, "Transfer should have occurred with 1 ether"); + } + + /// @notice Tests that two CaveatGroups with LimitedCallsEnforcer can be redeemed successfully + function test_twoGroupsWithLimitedCallsEnforcer() public { + uint256 initialValue_ = aliceDeleGatorCounter.count(); + + // Create two groups with LimitedCallsEnforcer + LogicalOrWrapperEnforcer.CaveatGroup[] memory groups_ = new LogicalOrWrapperEnforcer.CaveatGroup[](2); + address[] memory enforcers_ = new address[](1); + enforcers_[0] = address(limitedCallsEnforcer); + + // Group 0: Allow 2 calls to increment + bytes[] memory terms_ = new bytes[](1); + terms_[0] = abi.encodePacked(uint256(2)); // Allow 2 calls + groups_[0] = _createCaveatGroup(enforcers_, terms_); + + // Group 1: Allow 1 call to increment + bytes[] memory terms1_ = new bytes[](1); + terms1_[0] = abi.encodePacked(uint256(1)); // Allow 1 call + groups_[1] = _createCaveatGroup(enforcers_, terms1_); + + // Create execution for counter increment + Execution memory execution_ = Execution({ + target: address(aliceDeleGatorCounter), + value: 0, + callData: abi.encodeWithSelector(Counter.increment.selector) + }); + + // Create caveat for the logical OR wrapper + Caveat[] memory caveats_ = new Caveat[](1); + caveats_[0] = Caveat({ enforcer: address(logicalOrWrapperEnforcer), terms: abi.encode(groups_), args: hex"" }); + + // Create delegation + Delegation memory delegation_ = Delegation({ + delegate: address(users.bob.deleGator), + delegator: address(users.alice.deleGator), + authority: ROOT_AUTHORITY, + caveats: caveats_, + salt: 0, + signature: hex"" + }); + + delegation_ = signDelegation(users.alice, delegation_); + + // Execute Bob's UserOp for first group (2 calls) + Delegation[] memory delegations_ = new Delegation[](1); + delegations_[0] = delegation_; + + delegations_[0].caveats[0].args = abi.encode(_createSelectedGroup(0, new bytes[](1))); + + // First call using group 0 + invokeDelegation_UserOp(users.bob, delegations_, execution_); + assertEq(aliceDeleGatorCounter.count(), initialValue_ + 1, "First call should increment counter"); + + // Second call using group 0 + invokeDelegation_UserOp(users.bob, delegations_, execution_); + assertEq(aliceDeleGatorCounter.count(), initialValue_ + 2, "Second call should increment counter"); + + // Third call using group 1 should fail + invokeDelegation_UserOp(users.bob, delegations_, execution_); + assertEq(aliceDeleGatorCounter.count(), initialValue_ + 2, "Third call should not increment counter"); + + // Switch to group 1 (1 call allowed) + delegations_[0].caveats[0].args = abi.encode(_createSelectedGroup(1, new bytes[](1))); + + // Fourth call using group 1 + invokeDelegation_UserOp(users.bob, delegations_, execution_); + assertEq(aliceDeleGatorCounter.count(), initialValue_ + 3, "Fourth call should increment counter"); + + // Fifth call using group 1 should fail + invokeDelegation_UserOp(users.bob, delegations_, execution_); + assertEq(aliceDeleGatorCounter.count(), initialValue_ + 3, "Fifth call should not increment counter"); + } + ////////////////////// Invalid cases ////////////////////// /// @notice Tests that an invalid group index reverts with the expected error by verifying that selecting a group index beyond @@ -403,13 +658,24 @@ contract LogicalOrWrapperEnforcerTest is CaveatEnforcerBaseTest { ); } - /// @notice Tests that only the delegation manager can call the beforeHook function by verifying that calls from non-delegation + /// @notice Tests that only the delegation manager can call the hooks by verifying that calls from non-delegation /// manager addresses revert with the expected error function test_onlyDelegationManager() public { // Call the hook from a non-delegation manager address - vm.prank(address(0x1234)); + vm.startPrank(address(0x1234)); vm.expectRevert("LogicalOrWrapperEnforcer:only-delegation-manager"); logicalOrWrapperEnforcer.beforeHook(hex"", hex"", singleDefaultMode, hex"", keccak256(""), address(0), address(0)); + + vm.expectRevert("LogicalOrWrapperEnforcer:only-delegation-manager"); + logicalOrWrapperEnforcer.beforeAllHook(hex"", hex"", singleDefaultMode, hex"", keccak256(""), address(0), address(0)); + + vm.expectRevert("LogicalOrWrapperEnforcer:only-delegation-manager"); + logicalOrWrapperEnforcer.afterHook(hex"", hex"", singleDefaultMode, hex"", keccak256(""), address(0), address(0)); + + vm.expectRevert("LogicalOrWrapperEnforcer:only-delegation-manager"); + logicalOrWrapperEnforcer.afterAllHook(hex"", hex"", singleDefaultMode, hex"", keccak256(""), address(0), address(0)); + + vm.stopPrank(); } ////////////////////// Integration //////////////////////