diff --git a/src/enforcers/AllowedCalldataAnyOfEnforcer.sol b/src/enforcers/AllowedCalldataAnyOfEnforcer.sol new file mode 100644 index 00000000..c31e6dd0 --- /dev/null +++ b/src/enforcers/AllowedCalldataAnyOfEnforcer.sol @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT AND Apache-2.0 +pragma solidity 0.8.23; + +import { ExecutionLib } from "@erc7579/lib/ExecutionLib.sol"; + +import { CaveatEnforcer } from "./CaveatEnforcer.sol"; +import { ModeCode } from "../utils/Types.sol"; + +/** + * @title AllowedCalldataAnyOfEnforcer + * @dev Like `AllowedCalldataEnforcer`, but the delegator supplies several allowed byte sequences of **equal length**. + * @dev At `startIndex`, the execution calldata must exactly match **at least one** of those sequences (each candidate is compared + * over `valueLength` bytes, starting at `startIndex`). + * @dev This enforcer operates only in single execution call type and with default execution mode. + * @dev Prefer static or fixed-layout regions of calldata; validating dynamic types remains possible but is more error-prone, + * same as for `AllowedCalldataEnforcer`. + */ +contract AllowedCalldataAnyOfEnforcer is CaveatEnforcer { + using ExecutionLib for bytes; + + ////////////////////////////// Public Methods ////////////////////////////// + + /** + * @notice Allows the delegator to restrict calldata so that one of several equal-length slices matches at a fixed offset. + * @dev For each candidate, checks `callData[startIndex : startIndex + valueLength] == candidate`. + * @param _terms Binary layout: + * - **First 32 bytes:** `uint128 startIndex` (high 128 bits) | `uint128 valueLength` (low 128 bits) of one big-endian word. + * - **Remainder:** `candidateCount` candidates concatenated, each exactly `valueLength` bytes (so `len(remainder) == candidateCount * valueLength`). + * @param _mode The execution mode. (Must be Single callType, Default execType) + * @param _executionCallData The execution the delegate is trying to execute. + */ + function beforeHook( + bytes calldata _terms, + bytes calldata, + ModeCode _mode, + bytes calldata _executionCallData, + bytes32, + address, + address + ) + public + pure + override + onlySingleCallTypeMode(_mode) + onlyDefaultExecutionMode(_mode) + { + _validateCalldata(_terms, _executionCallData); + } + + /** + * @notice Decodes and validates the terms used in this CaveatEnforcer. + * @dev After reading `valueLength` from the header word, requires `valueLength >= 1`, a non-empty remainder, and that the + * remainder length is a multiple of `valueLength`. + * @param _terms Encoded data used during the execution hooks. + * @return startIndex_ Start index in the execution's call data. + * @return valueLength_ Length of every candidate slice and of the compared execution calldata window. + * @return candidateCount_ Number of candidates in the concatenated tail (`(len(_terms) - 32) / valueLength_`). + */ + function getTermsInfo(bytes calldata _terms) + public + pure + returns (uint128 startIndex_, uint128 valueLength_, uint256 candidateCount_) + { + require(_terms.length > 32, "AllowedCalldataAnyOfEnforcer:invalid-terms-size"); + uint256 metadataWord_ = uint256(bytes32(_terms[0:32])); + startIndex_ = uint128(metadataWord_ >> 128); + valueLength_ = uint128(metadataWord_); + + require(valueLength_ >= 1, "AllowedCalldataAnyOfEnforcer:invalid-value-length"); + + uint256 concatenatedValuesLength_ = _terms.length - 32; + require(concatenatedValuesLength_ != 0, "AllowedCalldataAnyOfEnforcer:no-allowed-values"); + require(concatenatedValuesLength_ % uint256(valueLength_) == 0, "AllowedCalldataAnyOfEnforcer:invalid-values-padding"); + + candidateCount_ = concatenatedValuesLength_ / uint256(valueLength_); + } + + /** + * @notice Validates that the execution calldata matches one of the allowed slices at `startIndex`. + * @param _terms Encoded terms (see `beforeHook`). + * @param _executionCallData The encoded single execution payload. + */ + function _validateCalldata(bytes calldata _terms, bytes calldata _executionCallData) private pure { + (uint128 startIndex_, uint128 valueLength_, uint256 candidateCount_) = getTermsInfo(_terms); + + uint256 dataStart_ = uint256(startIndex_); + uint256 lengthToMatch_ = uint256(valueLength_); + (,, bytes calldata callData_) = _executionCallData.decodeSingle(); + + require(dataStart_ + lengthToMatch_ <= callData_.length, "AllowedCalldataAnyOfEnforcer:invalid-calldata-length"); + + bytes calldata callDataToMatch_ = callData_[dataStart_:dataStart_ + lengthToMatch_]; + + bool matched_; + for (uint256 i = 0; i < candidateCount_; ++i) { + uint256 offset_ = 32 + i * lengthToMatch_; + if (callDataToMatch_ == _terms[offset_:offset_ + lengthToMatch_]) { + matched_ = true; + break; + } + } + require(matched_, "AllowedCalldataAnyOfEnforcer:invalid-calldata"); + } +} diff --git a/test/enforcers/AllowedCalldataAnyOfEnforcer.t.sol b/test/enforcers/AllowedCalldataAnyOfEnforcer.t.sol new file mode 100644 index 00000000..56467f7e --- /dev/null +++ b/test/enforcers/AllowedCalldataAnyOfEnforcer.t.sol @@ -0,0 +1,285 @@ +// SPDX-License-Identifier: MIT AND Apache-2.0 +pragma solidity 0.8.23; + +import "forge-std/Test.sol"; +import { ExecutionLib } from "@erc7579/lib/ExecutionLib.sol"; + +import { Execution, Caveat, Delegation } from "../../src/utils/Types.sol"; +import { CaveatEnforcerBaseTest } from "./CaveatEnforcerBaseTest.t.sol"; +import { AllowedCalldataAnyOfEnforcer } from "../../src/enforcers/AllowedCalldataAnyOfEnforcer.sol"; +import { BasicERC20, IERC20 } from "../utils/BasicERC20.t.sol"; +import { ICaveatEnforcer } from "../../src/interfaces/ICaveatEnforcer.sol"; + +contract AllowedCalldataAnyOfEnforcerTest is CaveatEnforcerBaseTest { + ////////////////////////////// State ////////////////////////////// + AllowedCalldataAnyOfEnforcer public allowedCalldataAnyOfEnforcer; + BasicERC20 public basicCF20; + + ////////////////////// Set up ////////////////////// + + function setUp() public override { + super.setUp(); + allowedCalldataAnyOfEnforcer = new AllowedCalldataAnyOfEnforcer(); + vm.label(address(allowedCalldataAnyOfEnforcer), "Allowed Calldata Any-Of Enforcer"); + basicCF20 = new BasicERC20(address(users.alice.deleGator), "TestToken1", "TestToken1", 100 ether); + } + + /// @dev Header: `uint128 startIndex` (high) | `uint128 valueLength` (low), then `candidateCount * valueLength` bytes. + function _packTerms(uint128 startIndex_, uint128 valueLength_, bytes memory concatenatedValues_) internal pure returns (bytes memory) { + require( + concatenatedValues_.length > 0 && concatenatedValues_.length % uint256(valueLength_) == 0, + "test: bad concatenatedValues length" + ); + uint256 metadataWord_ = (uint256(uint128(startIndex_)) << 128) | uint256(uint128(valueLength_)); + return bytes.concat(bytes32(metadataWord_), concatenatedValues_); + } + + ////////////////////// Valid cases ////////////////////// + + // should allow when the calldata matches the first allowed slice at startIndex + function test_allowsWhenFirstCandidateMatches() public { + Execution memory execution_ = Execution({ + target: address(basicCF20), + value: 0, + callData: abi.encodeWithSelector(BasicERC20.mint.selector, address(users.alice.deleGator), uint256(100)) + }); + bytes memory executionCallData_ = ExecutionLib.encodeSingle(execution_.target, execution_.value, execution_.callData); + + uint128 startIndex_ = uint128(abi.encodeWithSelector(BasicERC20.mint.selector, address(0)).length); + uint128 valueLength_ = 32; + bytes memory concatenatedValues_ = bytes.concat(abi.encodePacked(uint256(100)), abi.encodePacked(uint256(200))); + bytes memory terms_ = _packTerms(startIndex_, valueLength_, concatenatedValues_); + + vm.prank(address(delegationManager)); + allowedCalldataAnyOfEnforcer.beforeHook( + terms_, hex"", singleDefaultMode, executionCallData_, keccak256(""), address(0), address(0) + ); + } + + // should allow when the calldata matches a later candidate at startIndex + function test_allowsWhenSecondCandidateMatches() public { + Execution memory execution_ = Execution({ + target: address(basicCF20), + value: 0, + callData: abi.encodeWithSelector(BasicERC20.mint.selector, address(users.alice.deleGator), uint256(200)) + }); + bytes memory executionCallData_ = ExecutionLib.encodeSingle(execution_.target, execution_.value, execution_.callData); + + uint128 startIndex_ = uint128(abi.encodeWithSelector(BasicERC20.mint.selector, address(0)).length); + uint128 valueLength_ = 32; + bytes memory concatenatedValues_ = bytes.concat(abi.encodePacked(uint256(100)), abi.encodePacked(uint256(200))); + bytes memory terms_ = _packTerms(startIndex_, valueLength_, concatenatedValues_); + + vm.prank(address(delegationManager)); + allowedCalldataAnyOfEnforcer.beforeHook( + terms_, hex"", singleDefaultMode, executionCallData_, keccak256(""), address(0), address(0) + ); + } + + // should allow when several equal-length candidates include the executed uint256 + function test_allowsWhenOneOfSeveralUint256CandidatesMatches() public { + Execution memory execution_ = Execution({ + target: address(basicCF20), + value: 0, + callData: abi.encodeWithSelector(BasicERC20.mint.selector, address(users.alice.deleGator), uint256(0xabcd)) + }); + bytes memory executionCallData_ = ExecutionLib.encodeSingle(execution_.target, execution_.value, execution_.callData); + + uint128 startIndex_ = uint128(abi.encodeWithSelector(BasicERC20.mint.selector, address(0)).length); + uint128 valueLength_ = 32; + bytes memory concatenatedValues_ = + bytes.concat(abi.encodePacked(uint256(1)), abi.encodePacked(uint256(0xabcd)), abi.encodePacked(uint256(2))); + bytes memory terms_ = _packTerms(startIndex_, valueLength_, concatenatedValues_); + + vm.prank(address(delegationManager)); + allowedCalldataAnyOfEnforcer.beforeHook( + terms_, hex"", singleDefaultMode, executionCallData_, keccak256(""), address(0), address(0) + ); + } + + ////////////////////// Invalid cases ////////////////////// + + // should NOT allow when no candidate matches at startIndex + function test_revertsWhenNoCandidateMatches() public { + Execution memory execution_ = Execution({ + target: address(basicCF20), + value: 0, + callData: abi.encodeWithSelector(BasicERC20.mint.selector, address(users.alice.deleGator), uint256(300)) + }); + bytes memory executionCallData_ = ExecutionLib.encodeSingle(execution_.target, execution_.value, execution_.callData); + + uint128 startIndex_ = uint128(abi.encodeWithSelector(BasicERC20.mint.selector, address(0)).length); + uint128 valueLength_ = 32; + bytes memory concatenatedValues_ = bytes.concat(abi.encodePacked(uint256(100)), abi.encodePacked(uint256(200))); + bytes memory terms_ = _packTerms(startIndex_, valueLength_, concatenatedValues_); + + vm.prank(address(delegationManager)); + vm.expectRevert("AllowedCalldataAnyOfEnforcer:invalid-calldata"); + allowedCalldataAnyOfEnforcer.beforeHook( + terms_, hex"", singleDefaultMode, executionCallData_, keccak256(""), address(0), address(0) + ); + } + + // should NOT allow when the execution window is shorter than valueLength + function test_revertsWhenCalldataTooShortForSlice() public { + Execution memory execution_ = Execution({ + target: address(basicCF20), + value: 0, + callData: abi.encodeWithSelector(BasicERC20.mint.selector, address(users.alice.deleGator), uint256(100)) + }); + bytes memory executionCallData_ = ExecutionLib.encodeSingle(execution_.target, execution_.value, execution_.callData); + + uint128 startIndex_ = uint128(abi.encodeWithSelector(BasicERC20.mint.selector, address(0)).length); + uint128 valueLength_ = uint128(execution_.callData.length - uint256(startIndex_) + 1); + bytes memory concatenatedValues_ = new bytes(uint256(valueLength_)); + for (uint256 i = 0; i < concatenatedValues_.length; ++i) { + concatenatedValues_[i] = 0xff; + } + bytes memory terms_ = _packTerms(startIndex_, valueLength_, concatenatedValues_); + + vm.prank(address(delegationManager)); + vm.expectRevert("AllowedCalldataAnyOfEnforcer:invalid-calldata-length"); + allowedCalldataAnyOfEnforcer.beforeHook( + terms_, hex"", singleDefaultMode, executionCallData_, keccak256(""), address(0), address(0) + ); + } + + // should FAIL getTermsInfo when terms are shorter than 32 bytes + function test_getTermsInfoFailsForShortTerms() public { + vm.expectRevert("AllowedCalldataAnyOfEnforcer:invalid-terms-size"); + allowedCalldataAnyOfEnforcer.getTermsInfo(hex"010203"); + } + + // should FAIL getTermsInfo when there is no candidate tail + function test_getTermsInfoFailsForEmptyCandidatesTail() public { + uint256 metadataWord_ = (uint256(uint128(0)) << 128) | uint256(uint128(32)); + bytes memory terms_ = abi.encodePacked(bytes32(metadataWord_)); + vm.expectRevert("AllowedCalldataAnyOfEnforcer:no-allowed-values"); + allowedCalldataAnyOfEnforcer.getTermsInfo(terms_); + } + + // should FAIL getTermsInfo when valueLength is zero + function test_getTermsInfoFailsForZeroValueLength() public { + uint256 metadataWord_ = (uint256(uint128(4)) << 128) | uint256(uint128(0)); + bytes memory terms_ = bytes.concat(bytes32(metadataWord_), hex"aabb"); + vm.expectRevert("AllowedCalldataAnyOfEnforcer:invalid-value-length"); + allowedCalldataAnyOfEnforcer.getTermsInfo(terms_); + } + + // should FAIL getTermsInfo when tail is not a multiple of valueLength + function test_getTermsInfoFailsForInvalidValuesPadding() public { + uint256 metadataWord_ = (uint256(uint128(0)) << 128) | uint256(uint128(32)); + bytes memory terms_ = bytes.concat(bytes32(metadataWord_), new bytes(33)); + vm.expectRevert("AllowedCalldataAnyOfEnforcer:invalid-values-padding"); + allowedCalldataAnyOfEnforcer.getTermsInfo(terms_); + } + + // should decode header via getTermsInfo + function test_getTermsInfoDecodesHeaderAndCount() public view { + uint128 expectedStartIndex_ = 40; + uint128 expectedValueLength_ = 32; + bytes memory concatenatedValues_ = bytes.concat(abi.encodePacked(uint256(1)), abi.encodePacked(uint256(2))); + bytes memory terms_ = _packTerms(expectedStartIndex_, expectedValueLength_, concatenatedValues_); + (uint128 startIndex_, uint128 valueLength_, uint256 candidateCount_) = + allowedCalldataAnyOfEnforcer.getTermsInfo(terms_); + assertEq(startIndex_, expectedStartIndex_); + assertEq(valueLength_, expectedValueLength_); + assertEq(candidateCount_, 2); + } + + // should fail with invalid call type mode (batch instead of single mode) + function test_revertWithInvalidCallTypeMode() public { + bytes memory executionCallData_ = ExecutionLib.encodeBatch(new Execution[](2)); + + vm.expectRevert("CaveatEnforcer:invalid-call-type"); + + allowedCalldataAnyOfEnforcer.beforeHook( + hex"", hex"", batchDefaultMode, executionCallData_, bytes32(0), address(0), address(0) + ); + } + + // should fail with invalid call type mode (try instead of default) + function test_revertWithInvalidExecutionMode() public { + vm.prank(address(delegationManager)); + vm.expectRevert("CaveatEnforcer:invalid-execution-type"); + allowedCalldataAnyOfEnforcer.beforeHook(hex"", hex"", singleTryMode, hex"", bytes32(0), address(0), address(0)); + } + + ////////////////////// Integration ////////////////////// + + // should allow execution when the amount matches one of the allowed encodings Integration + function test_integrationAllowsMatchingAmount() public { + assertEq(basicCF20.balanceOf(address(users.bob.deleGator)), 0); + + Execution memory execution1_ = Execution({ + target: address(basicCF20), + value: 0, + callData: abi.encodeWithSelector(IERC20.transfer.selector, address(users.bob.deleGator), uint256(2)) + }); + + uint128 startIndex_ = uint128(abi.encodeWithSelector(IERC20.transfer.selector, address(0)).length); + uint128 valueLength_ = 32; + bytes memory concatenatedValues_ = bytes.concat(abi.encodePacked(uint256(1)), abi.encodePacked(uint256(2))); + bytes memory terms_ = _packTerms(startIndex_, valueLength_, concatenatedValues_); + + Caveat[] memory caveats_ = new Caveat[](1); + caveats_[0] = Caveat({ args: hex"", enforcer: address(allowedCalldataAnyOfEnforcer), terms: terms_ }); + Delegation memory delegation_ = Delegation({ + delegate: address(users.bob.deleGator), + delegator: address(users.alice.deleGator), + authority: ROOT_AUTHORITY, + caveats: caveats_, + salt: 0, + signature: hex"" + }); + + delegation_ = signDelegation(users.alice, delegation_); + + Delegation[] memory delegations_ = new Delegation[](1); + delegations_[0] = delegation_; + + invokeDelegation_UserOp(users.bob, delegations_, execution1_); + + assertEq(basicCF20.balanceOf(address(users.bob.deleGator)), uint256(2)); + } + + // should NOT allow execution when the amount matches none of the allowed encodings Integration + function test_integrationRejectsNonMatchingAmount() public { + assertEq(basicCF20.balanceOf(address(users.bob.deleGator)), 0); + + Execution memory execution1_ = Execution({ + target: address(basicCF20), + value: 0, + callData: abi.encodeWithSelector(IERC20.transfer.selector, address(users.bob.deleGator), uint256(3)) + }); + + uint128 startIndex_ = uint128(abi.encodeWithSelector(IERC20.transfer.selector, address(0)).length); + uint128 valueLength_ = 32; + bytes memory concatenatedValues_ = bytes.concat(abi.encodePacked(uint256(1)), abi.encodePacked(uint256(2))); + bytes memory terms_ = _packTerms(startIndex_, valueLength_, concatenatedValues_); + + Caveat[] memory caveats_ = new Caveat[](1); + caveats_[0] = Caveat({ args: hex"", enforcer: address(allowedCalldataAnyOfEnforcer), terms: terms_ }); + Delegation memory delegation_ = Delegation({ + delegate: address(users.bob.deleGator), + delegator: address(users.alice.deleGator), + authority: ROOT_AUTHORITY, + caveats: caveats_, + salt: 0, + signature: hex"" + }); + + delegation_ = signDelegation(users.alice, delegation_); + + Delegation[] memory delegations_ = new Delegation[](1); + delegations_[0] = delegation_; + + invokeDelegation_UserOp(users.bob, delegations_, execution1_); + + assertEq(basicCF20.balanceOf(address(users.bob.deleGator)), uint256(0)); + } + + function _getEnforcer() internal view override returns (ICaveatEnforcer) { + return ICaveatEnforcer(address(allowedCalldataAnyOfEnforcer)); + } +}