Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions src/enforcers/AllowedCalldataAnyOfEnforcer.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// SPDX-License-Identifier: MIT AND Apache-2.0
pragma solidity 0.8.23;

import { ExecutionLib } from "@erc7579/lib/ExecutionLib.sol";

import { CaveatEnforcer } from "./CaveatEnforcer.sol";
import { ModeCode } from "../utils/Types.sol";

/**
* @title AllowedCalldataAnyOfEnforcer
* @dev Like `AllowedCalldataEnforcer`, but the delegator supplies several allowed byte sequences of **equal length**.
* @dev At `startIndex`, the execution calldata must exactly match **at least one** of those sequences (each candidate is compared
* over `valueLength` bytes, starting at `startIndex`).
* @dev This enforcer operates only in single execution call type and with default execution mode.
* @dev Prefer static or fixed-layout regions of calldata; validating dynamic types remains possible but is more error-prone,
* same as for `AllowedCalldataEnforcer`.
*/
contract AllowedCalldataAnyOfEnforcer is CaveatEnforcer {
using ExecutionLib for bytes;

////////////////////////////// Public Methods //////////////////////////////

/**
* @notice Allows the delegator to restrict calldata so that one of several equal-length slices matches at a fixed offset.
* @dev For each candidate, checks `callData[startIndex : startIndex + valueLength] == candidate`.
* @param _terms Binary layout:
* - **First 32 bytes:** `uint128 startIndex` (high 128 bits) | `uint128 valueLength` (low 128 bits) of one big-endian word.
* - **Remainder:** `candidateCount` candidates concatenated, each exactly `valueLength` bytes (so `len(remainder) == candidateCount * valueLength`).
* @param _mode The execution mode. (Must be Single callType, Default execType)
* @param _executionCallData The execution the delegate is trying to execute.
*/
function beforeHook(
bytes calldata _terms,
bytes calldata,
ModeCode _mode,
bytes calldata _executionCallData,
bytes32,
address,
address
)
public
pure
override
onlySingleCallTypeMode(_mode)
onlyDefaultExecutionMode(_mode)
{
_validateCalldata(_terms, _executionCallData);
}

/**
* @notice Decodes and validates the terms used in this CaveatEnforcer.
* @dev After reading `valueLength` from the header word, requires `valueLength >= 1`, a non-empty remainder, and that the
* remainder length is a multiple of `valueLength`.
* @param _terms Encoded data used during the execution hooks.
* @return startIndex_ Start index in the execution's call data.
* @return valueLength_ Length of every candidate slice and of the compared execution calldata window.
* @return candidateCount_ Number of candidates in the concatenated tail (`(len(_terms) - 32) / valueLength_`).
*/
function getTermsInfo(bytes calldata _terms)
public
pure
returns (uint128 startIndex_, uint128 valueLength_, uint256 candidateCount_)
{
require(_terms.length > 32, "AllowedCalldataAnyOfEnforcer:invalid-terms-size");
uint256 metadataWord_ = uint256(bytes32(_terms[0:32]));
startIndex_ = uint128(metadataWord_ >> 128);
valueLength_ = uint128(metadataWord_);

require(valueLength_ >= 1, "AllowedCalldataAnyOfEnforcer:invalid-value-length");

uint256 concatenatedValuesLength_ = _terms.length - 32;
require(concatenatedValuesLength_ != 0, "AllowedCalldataAnyOfEnforcer:no-allowed-values");
require(concatenatedValuesLength_ % uint256(valueLength_) == 0, "AllowedCalldataAnyOfEnforcer:invalid-values-padding");

candidateCount_ = concatenatedValuesLength_ / uint256(valueLength_);
}

/**
* @notice Validates that the execution calldata matches one of the allowed slices at `startIndex`.
* @param _terms Encoded terms (see `beforeHook`).
* @param _executionCallData The encoded single execution payload.
*/
function _validateCalldata(bytes calldata _terms, bytes calldata _executionCallData) private pure {
(uint128 startIndex_, uint128 valueLength_, uint256 candidateCount_) = getTermsInfo(_terms);

uint256 dataStart_ = uint256(startIndex_);
uint256 lengthToMatch_ = uint256(valueLength_);
(,, bytes calldata callData_) = _executionCallData.decodeSingle();

require(dataStart_ + lengthToMatch_ <= callData_.length, "AllowedCalldataAnyOfEnforcer:invalid-calldata-length");

bytes calldata callDataToMatch_ = callData_[dataStart_:dataStart_ + lengthToMatch_];

bool matched_;
for (uint256 i = 0; i < candidateCount_; ++i) {
uint256 offset_ = 32 + i * lengthToMatch_;
if (callDataToMatch_ == _terms[offset_:offset_ + lengthToMatch_]) {
matched_ = true;
break;
}
}
require(matched_, "AllowedCalldataAnyOfEnforcer:invalid-calldata");
}
}
285 changes: 285 additions & 0 deletions test/enforcers/AllowedCalldataAnyOfEnforcer.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
// SPDX-License-Identifier: MIT AND Apache-2.0
pragma solidity 0.8.23;

import "forge-std/Test.sol";
import { ExecutionLib } from "@erc7579/lib/ExecutionLib.sol";

import { Execution, Caveat, Delegation } from "../../src/utils/Types.sol";
import { CaveatEnforcerBaseTest } from "./CaveatEnforcerBaseTest.t.sol";
import { AllowedCalldataAnyOfEnforcer } from "../../src/enforcers/AllowedCalldataAnyOfEnforcer.sol";
import { BasicERC20, IERC20 } from "../utils/BasicERC20.t.sol";
import { ICaveatEnforcer } from "../../src/interfaces/ICaveatEnforcer.sol";

contract AllowedCalldataAnyOfEnforcerTest is CaveatEnforcerBaseTest {
////////////////////////////// State //////////////////////////////
AllowedCalldataAnyOfEnforcer public allowedCalldataAnyOfEnforcer;
BasicERC20 public basicCF20;

////////////////////// Set up //////////////////////

function setUp() public override {
super.setUp();
allowedCalldataAnyOfEnforcer = new AllowedCalldataAnyOfEnforcer();
vm.label(address(allowedCalldataAnyOfEnforcer), "Allowed Calldata Any-Of Enforcer");
basicCF20 = new BasicERC20(address(users.alice.deleGator), "TestToken1", "TestToken1", 100 ether);
}

/// @dev Header: `uint128 startIndex` (high) | `uint128 valueLength` (low), then `candidateCount * valueLength` bytes.
function _packTerms(uint128 startIndex_, uint128 valueLength_, bytes memory concatenatedValues_) internal pure returns (bytes memory) {
require(
concatenatedValues_.length > 0 && concatenatedValues_.length % uint256(valueLength_) == 0,
"test: bad concatenatedValues length"
);
uint256 metadataWord_ = (uint256(uint128(startIndex_)) << 128) | uint256(uint128(valueLength_));
return bytes.concat(bytes32(metadataWord_), concatenatedValues_);
}

////////////////////// Valid cases //////////////////////

// should allow when the calldata matches the first allowed slice at startIndex
function test_allowsWhenFirstCandidateMatches() public {
Execution memory execution_ = Execution({
target: address(basicCF20),
value: 0,
callData: abi.encodeWithSelector(BasicERC20.mint.selector, address(users.alice.deleGator), uint256(100))
});
bytes memory executionCallData_ = ExecutionLib.encodeSingle(execution_.target, execution_.value, execution_.callData);

uint128 startIndex_ = uint128(abi.encodeWithSelector(BasicERC20.mint.selector, address(0)).length);
uint128 valueLength_ = 32;
bytes memory concatenatedValues_ = bytes.concat(abi.encodePacked(uint256(100)), abi.encodePacked(uint256(200)));
bytes memory terms_ = _packTerms(startIndex_, valueLength_, concatenatedValues_);

vm.prank(address(delegationManager));
allowedCalldataAnyOfEnforcer.beforeHook(
terms_, hex"", singleDefaultMode, executionCallData_, keccak256(""), address(0), address(0)
);
}

// should allow when the calldata matches a later candidate at startIndex
function test_allowsWhenSecondCandidateMatches() public {
Execution memory execution_ = Execution({
target: address(basicCF20),
value: 0,
callData: abi.encodeWithSelector(BasicERC20.mint.selector, address(users.alice.deleGator), uint256(200))
});
bytes memory executionCallData_ = ExecutionLib.encodeSingle(execution_.target, execution_.value, execution_.callData);

uint128 startIndex_ = uint128(abi.encodeWithSelector(BasicERC20.mint.selector, address(0)).length);
uint128 valueLength_ = 32;
bytes memory concatenatedValues_ = bytes.concat(abi.encodePacked(uint256(100)), abi.encodePacked(uint256(200)));
bytes memory terms_ = _packTerms(startIndex_, valueLength_, concatenatedValues_);

vm.prank(address(delegationManager));
allowedCalldataAnyOfEnforcer.beforeHook(
terms_, hex"", singleDefaultMode, executionCallData_, keccak256(""), address(0), address(0)
);
}

// should allow when several equal-length candidates include the executed uint256
function test_allowsWhenOneOfSeveralUint256CandidatesMatches() public {
Execution memory execution_ = Execution({
target: address(basicCF20),
value: 0,
callData: abi.encodeWithSelector(BasicERC20.mint.selector, address(users.alice.deleGator), uint256(0xabcd))
});
bytes memory executionCallData_ = ExecutionLib.encodeSingle(execution_.target, execution_.value, execution_.callData);

uint128 startIndex_ = uint128(abi.encodeWithSelector(BasicERC20.mint.selector, address(0)).length);
uint128 valueLength_ = 32;
bytes memory concatenatedValues_ =
bytes.concat(abi.encodePacked(uint256(1)), abi.encodePacked(uint256(0xabcd)), abi.encodePacked(uint256(2)));
bytes memory terms_ = _packTerms(startIndex_, valueLength_, concatenatedValues_);

vm.prank(address(delegationManager));
allowedCalldataAnyOfEnforcer.beforeHook(
terms_, hex"", singleDefaultMode, executionCallData_, keccak256(""), address(0), address(0)
);
}

////////////////////// Invalid cases //////////////////////

// should NOT allow when no candidate matches at startIndex
function test_revertsWhenNoCandidateMatches() public {
Execution memory execution_ = Execution({
target: address(basicCF20),
value: 0,
callData: abi.encodeWithSelector(BasicERC20.mint.selector, address(users.alice.deleGator), uint256(300))
});
bytes memory executionCallData_ = ExecutionLib.encodeSingle(execution_.target, execution_.value, execution_.callData);

uint128 startIndex_ = uint128(abi.encodeWithSelector(BasicERC20.mint.selector, address(0)).length);
uint128 valueLength_ = 32;
bytes memory concatenatedValues_ = bytes.concat(abi.encodePacked(uint256(100)), abi.encodePacked(uint256(200)));
bytes memory terms_ = _packTerms(startIndex_, valueLength_, concatenatedValues_);

vm.prank(address(delegationManager));
vm.expectRevert("AllowedCalldataAnyOfEnforcer:invalid-calldata");
allowedCalldataAnyOfEnforcer.beforeHook(
terms_, hex"", singleDefaultMode, executionCallData_, keccak256(""), address(0), address(0)
);
}

// should NOT allow when the execution window is shorter than valueLength
function test_revertsWhenCalldataTooShortForSlice() public {
Execution memory execution_ = Execution({
target: address(basicCF20),
value: 0,
callData: abi.encodeWithSelector(BasicERC20.mint.selector, address(users.alice.deleGator), uint256(100))
});
bytes memory executionCallData_ = ExecutionLib.encodeSingle(execution_.target, execution_.value, execution_.callData);

uint128 startIndex_ = uint128(abi.encodeWithSelector(BasicERC20.mint.selector, address(0)).length);
uint128 valueLength_ = uint128(execution_.callData.length - uint256(startIndex_) + 1);
bytes memory concatenatedValues_ = new bytes(uint256(valueLength_));
for (uint256 i = 0; i < concatenatedValues_.length; ++i) {
concatenatedValues_[i] = 0xff;
}
bytes memory terms_ = _packTerms(startIndex_, valueLength_, concatenatedValues_);

vm.prank(address(delegationManager));
vm.expectRevert("AllowedCalldataAnyOfEnforcer:invalid-calldata-length");
allowedCalldataAnyOfEnforcer.beforeHook(
terms_, hex"", singleDefaultMode, executionCallData_, keccak256(""), address(0), address(0)
);
}

// should FAIL getTermsInfo when terms are shorter than 32 bytes
function test_getTermsInfoFailsForShortTerms() public {
vm.expectRevert("AllowedCalldataAnyOfEnforcer:invalid-terms-size");
allowedCalldataAnyOfEnforcer.getTermsInfo(hex"010203");
}

// should FAIL getTermsInfo when there is no candidate tail
function test_getTermsInfoFailsForEmptyCandidatesTail() public {
uint256 metadataWord_ = (uint256(uint128(0)) << 128) | uint256(uint128(32));
bytes memory terms_ = abi.encodePacked(bytes32(metadataWord_));
vm.expectRevert("AllowedCalldataAnyOfEnforcer:no-allowed-values");
allowedCalldataAnyOfEnforcer.getTermsInfo(terms_);
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test expects wrong revert from unreachable check

Medium Severity

test_getTermsInfoFailsForEmptyCandidatesTail will fail. It constructs 32-byte terms via abi.encodePacked(bytes32(...)) and expects revert "AllowedCalldataAnyOfEnforcer:no-allowed-values". However, the contract's getTermsInfo checks _terms.length > 32 first (line 64), which fails for exactly 32 bytes, reverting with "AllowedCalldataAnyOfEnforcer:invalid-terms-size" instead. The concatenatedValuesLength_ != 0 check on line 72 is unreachable dead code since _terms.length > 32 already guarantees concatenatedValuesLength_ >= 1.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 60dddc6. Configure here.


// should FAIL getTermsInfo when valueLength is zero
function test_getTermsInfoFailsForZeroValueLength() public {
uint256 metadataWord_ = (uint256(uint128(4)) << 128) | uint256(uint128(0));
bytes memory terms_ = bytes.concat(bytes32(metadataWord_), hex"aabb");
vm.expectRevert("AllowedCalldataAnyOfEnforcer:invalid-value-length");
allowedCalldataAnyOfEnforcer.getTermsInfo(terms_);
}

// should FAIL getTermsInfo when tail is not a multiple of valueLength
function test_getTermsInfoFailsForInvalidValuesPadding() public {
uint256 metadataWord_ = (uint256(uint128(0)) << 128) | uint256(uint128(32));
bytes memory terms_ = bytes.concat(bytes32(metadataWord_), new bytes(33));
vm.expectRevert("AllowedCalldataAnyOfEnforcer:invalid-values-padding");
allowedCalldataAnyOfEnforcer.getTermsInfo(terms_);
}

// should decode header via getTermsInfo
function test_getTermsInfoDecodesHeaderAndCount() public view {
uint128 expectedStartIndex_ = 40;
uint128 expectedValueLength_ = 32;
bytes memory concatenatedValues_ = bytes.concat(abi.encodePacked(uint256(1)), abi.encodePacked(uint256(2)));
bytes memory terms_ = _packTerms(expectedStartIndex_, expectedValueLength_, concatenatedValues_);
(uint128 startIndex_, uint128 valueLength_, uint256 candidateCount_) =
allowedCalldataAnyOfEnforcer.getTermsInfo(terms_);
assertEq(startIndex_, expectedStartIndex_);
assertEq(valueLength_, expectedValueLength_);
assertEq(candidateCount_, 2);
}

// should fail with invalid call type mode (batch instead of single mode)
function test_revertWithInvalidCallTypeMode() public {
bytes memory executionCallData_ = ExecutionLib.encodeBatch(new Execution[](2));

vm.expectRevert("CaveatEnforcer:invalid-call-type");

allowedCalldataAnyOfEnforcer.beforeHook(
hex"", hex"", batchDefaultMode, executionCallData_, bytes32(0), address(0), address(0)
);
}

// should fail with invalid call type mode (try instead of default)
function test_revertWithInvalidExecutionMode() public {
vm.prank(address(delegationManager));
vm.expectRevert("CaveatEnforcer:invalid-execution-type");
allowedCalldataAnyOfEnforcer.beforeHook(hex"", hex"", singleTryMode, hex"", bytes32(0), address(0), address(0));
}

////////////////////// Integration //////////////////////

// should allow execution when the amount matches one of the allowed encodings Integration
function test_integrationAllowsMatchingAmount() public {
assertEq(basicCF20.balanceOf(address(users.bob.deleGator)), 0);

Execution memory execution1_ = Execution({
target: address(basicCF20),
value: 0,
callData: abi.encodeWithSelector(IERC20.transfer.selector, address(users.bob.deleGator), uint256(2))
});

uint128 startIndex_ = uint128(abi.encodeWithSelector(IERC20.transfer.selector, address(0)).length);
uint128 valueLength_ = 32;
bytes memory concatenatedValues_ = bytes.concat(abi.encodePacked(uint256(1)), abi.encodePacked(uint256(2)));
bytes memory terms_ = _packTerms(startIndex_, valueLength_, concatenatedValues_);

Caveat[] memory caveats_ = new Caveat[](1);
caveats_[0] = Caveat({ args: hex"", enforcer: address(allowedCalldataAnyOfEnforcer), terms: terms_ });
Delegation memory delegation_ = Delegation({
delegate: address(users.bob.deleGator),
delegator: address(users.alice.deleGator),
authority: ROOT_AUTHORITY,
caveats: caveats_,
salt: 0,
signature: hex""
});

delegation_ = signDelegation(users.alice, delegation_);

Delegation[] memory delegations_ = new Delegation[](1);
delegations_[0] = delegation_;

invokeDelegation_UserOp(users.bob, delegations_, execution1_);

assertEq(basicCF20.balanceOf(address(users.bob.deleGator)), uint256(2));
}

// should NOT allow execution when the amount matches none of the allowed encodings Integration
function test_integrationRejectsNonMatchingAmount() public {
assertEq(basicCF20.balanceOf(address(users.bob.deleGator)), 0);

Execution memory execution1_ = Execution({
target: address(basicCF20),
value: 0,
callData: abi.encodeWithSelector(IERC20.transfer.selector, address(users.bob.deleGator), uint256(3))
});

uint128 startIndex_ = uint128(abi.encodeWithSelector(IERC20.transfer.selector, address(0)).length);
uint128 valueLength_ = 32;
bytes memory concatenatedValues_ = bytes.concat(abi.encodePacked(uint256(1)), abi.encodePacked(uint256(2)));
bytes memory terms_ = _packTerms(startIndex_, valueLength_, concatenatedValues_);

Caveat[] memory caveats_ = new Caveat[](1);
caveats_[0] = Caveat({ args: hex"", enforcer: address(allowedCalldataAnyOfEnforcer), terms: terms_ });
Delegation memory delegation_ = Delegation({
delegate: address(users.bob.deleGator),
delegator: address(users.alice.deleGator),
authority: ROOT_AUTHORITY,
caveats: caveats_,
salt: 0,
signature: hex""
});

delegation_ = signDelegation(users.alice, delegation_);

Delegation[] memory delegations_ = new Delegation[](1);
delegations_[0] = delegation_;

invokeDelegation_UserOp(users.bob, delegations_, execution1_);

assertEq(basicCF20.balanceOf(address(users.bob.deleGator)), uint256(0));
}

function _getEnforcer() internal view override returns (ICaveatEnforcer) {
return ICaveatEnforcer(address(allowedCalldataAnyOfEnforcer));
}
}
Loading