Skip to content

Commit

Permalink
Added authorization via ZKP to SmartAccount (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
KyrylR authored Sep 25, 2024
1 parent a3245a3 commit 4784caa
Show file tree
Hide file tree
Showing 17 changed files with 641 additions and 71 deletions.
2 changes: 1 addition & 1 deletion .github/actions/setup/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ runs:
- name: Setup node
uses: actions/setup-node@v3
with:
node-version: "18.18.x"
node-version: "20.x"
cache: npm
- name: Install packages
run: npm install
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,7 @@ jobs:
uses: actions/checkout@v3
- name: Setup
uses: ./.github/actions/setup
- name: Setup circuits
run: npm run prepare-circuits
- name: Run tests
run: npm run test
22 changes: 22 additions & 0 deletions circuits/BuildNullifier.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// LICENSE: MIT
pragma circom 2.1.6;

include "circomlib/circuits/poseidon.circom";

template BuildNullifier() {
signal output nullifier;

signal input sk_i;
signal input eventID;

component hasher1 = Poseidon(1);
component hasher3 = Poseidon(3);

sk_i ==> hasher1.inputs[0];

sk_i ==> hasher3.inputs[0];
hasher1.out ==> hasher3.inputs[1];
eventID ==> hasher3.inputs[2];

nullifier <== hasher3.out;
}
27 changes: 27 additions & 0 deletions circuits/ExtractPublicKey.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// LICENSE: MIT
pragma circom 2.1.6;

include "circomlib/circuits/escalarmulfix.circom";

template ExtractPublicKey() {
signal input privateKey;
signal output Ax;
signal output Ay;

var BASE8[2] = [
5299619240641551281634865583518297030282874472190772894086521144482721001553,
16950150798460657717958625567821834550301663161624707787222815936182638968203
];

component privateKeyBits = Num2Bits(254);
privateKeyBits.in <== privateKey;

component mulFix = EscalarMulFix(254, BASE8);

for (var i = 0; i < 254; i++) {
mulFix.e[i] <== privateKeyBits.out[i];
}

Ax <== mulFix.out[0];
Ay <== mulFix.out[1];
}
48 changes: 48 additions & 0 deletions circuits/IdentityAuth.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// LICENSE: MIT
pragma circom 2.1.6;

include "circomlib/circuits/babyjub.circom";
include "circomlib/circuits/poseidon.circom";

include "BuildNullifier.circom";
include "ExtractPublicKey.circom";
include "OptimizedEdDSAPoseidonVerifier.circom";

template IdentityAuth() {
// Public Outputs
signal output nullifier; // Poseidon3(sk_i, Poseidon1(sk_i), eventID)

// Public Inputs
signal input messageHash;

// Private Inputs
signal input sk_i;
signal input eventID;

signal input signatureR8x;
signal input signatureR8y;
signal input signatureS;

// Verify Nullifier
component nullifierVerifier = BuildNullifier();

sk_i ==> nullifierVerifier.sk_i;
eventID ==> nullifierVerifier.eventID;

nullifier <== nullifierVerifier.nullifier;

component getPubKey = ExtractPublicKey();
sk_i ==> getPubKey.privateKey;

component sigVerifier = OptimizedEdDSAPoseidonVerifier();
sigVerifier.enabled <== 1;

sigVerifier.Ax <== getPubKey.Ax;
sigVerifier.Ay <== getPubKey.Ay;
sigVerifier.S <== signatureS;
sigVerifier.R8x <== signatureR8x;
sigVerifier.R8y <== signatureR8y;
sigVerifier.M <== messageHash;
}

component main {public [messageHash]} = IdentityAuth();
86 changes: 86 additions & 0 deletions circuits/OptimizedEdDSAPoseidonVerifier.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// LICENSE: MIT
pragma circom 2.1.6;

include "circomlib/circuits/compconstant.circom";
include "circomlib/circuits/poseidon.circom";
include "circomlib/circuits/bitify.circom";
include "circomlib/circuits/escalarmulany.circom";
include "circomlib/circuits/escalarmulfix.circom";

template OptimizedEdDSAPoseidonVerifier() {
signal input enabled;
signal input Ax;
signal input Ay;

signal input S;
signal input R8x;
signal input R8y;

signal input M;

// Ensure S < Subgroup Order
component sNum2Bits = Num2Bits(253);
sNum2Bits.in <== S;

component compConstant = CompConstant(2736030358979909402780800718157159386076813972158567259200215660948447373040);

for (var i = 0; i < 253; i++) {
sNum2Bits.out[i] ==> compConstant.in[i];
}

compConstant.in[253] <== 0;
compConstant.out * enabled === 0;

// Calculate the h = H(R, A, msg)
component hash = Poseidon(5);

hash.inputs[0] <== R8x;
hash.inputs[1] <== R8y;
hash.inputs[2] <== Ax;
hash.inputs[3] <== Ay;
hash.inputs[4] <== M;

component h2bits = Num2Bits_strict();
h2bits.in <== hash.out;

component b2Num = Bits2Num_strict();
h2bits.out ==> b2Num.in;

// Calculate second part of the right side: right2 = h * A
component mulAny = EscalarMulAny(254);
for (var i = 0; i < 254; i++) {
mulAny.e[i] <== h2bits.out[i];
}

mulAny.p[0] <== Ax;
mulAny.p[1] <== Ay;

// Compute the right side: right = R8 + right2
component addRight = BabyAdd();
addRight.x1 <== R8x;
addRight.y1 <== R8y;
addRight.x2 <== mulAny.out[0];
addRight.y2 <== mulAny.out[1];

// Calculate left side of equation left = S * B8
var BASE8[2] = [
5299619240641551281634865583518297030282874472190772894086521144482721001553,
16950150798460657717958625567821834550301663161624707787222815936182638968203
];

component mulFix = EscalarMulFix(253, BASE8);
for (var i = 0; i < 253; i++) {
mulFix.e[i] <== sNum2Bits.out[i];
}

// Do the comparison left == right if enabled;
component eqCheckX = ForceEqualIfEnabled();
eqCheckX.enabled <== enabled;
eqCheckX.in[0] <== mulFix.out[0];
eqCheckX.in[1] <== addRight.xout;

component eqCheckY = ForceEqualIfEnabled();
eqCheckY.enabled <== enabled;
eqCheckY.in[0] <== mulFix.out[1];
eqCheckY.in[1] <== addRight.yout;
}
94 changes: 69 additions & 25 deletions contracts/SmartAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,35 @@
pragma solidity ^0.8.20;

import {Nonces} from "@openzeppelin/contracts/utils/Nonces.sol";
import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol";
import {ERC1967Utils} from "@openzeppelin/contracts/proxy/ERC1967/ERC1967Utils.sol";
import {UUPSUpgradeable} from "@openzeppelin/contracts/proxy/utils/UUPSUpgradeable.sol";
import {ERC1155Holder} from "@openzeppelin/contracts/token/ERC1155/utils/ERC1155Holder.sol";
import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol";

import {OwnableUpgradeable} from "@openzeppelin/contracts-upgradeable/access/OwnableUpgradeable.sol";
import {Initializable} from "@openzeppelin/contracts/proxy/utils/Initializable.sol";

import {IAccount} from "@account-abstraction/contracts/interfaces/IAccount.sol";
import {IEntryPoint} from "@account-abstraction/contracts/interfaces/IEntryPoint.sol";
import {PackedUserOperation} from "@account-abstraction/contracts/interfaces/PackedUserOperation.sol";
import {SIG_VALIDATION_FAILED, SIG_VALIDATION_SUCCESS} from "@account-abstraction/contracts/core/Helpers.sol";

contract SmartAccount is IAccount, UUPSUpgradeable, ERC1155Holder, Nonces, OwnableUpgradeable {
IEntryPoint private immutable ENTRY_POINT;
import {TypeCaster} from "@solarity/solidity-lib/libs/utils/TypeCaster.sol";
import {VerifierHelper} from "@solarity/solidity-lib/libs/zkp/snarkjs/VerifierHelper.sol";

contract SmartAccount is IAccount, Initializable, UUPSUpgradeable, ERC1155Holder, Nonces {
using TypeCaster for *;
using VerifierHelper for address;

struct IdentityProof {
VerifierHelper.ProofPoints identityProof;
}

IEntryPoint public immutable ENTRY_POINT;

address public immutable IDENTITY_AUTH_VERIFIER;

bytes32 public nullifier;

mapping(address => uint48) public sessionAccounts;

modifier onlyThis() {
_requireThis();
Expand All @@ -33,6 +47,9 @@ contract SmartAccount is IAccount, UUPSUpgradeable, ERC1155Holder, Nonces, Ownab
_;
}

event SessionAccountSet(address indexed account, uint256 timestamp);

error InvalidProof();
error CallFailed(bytes result);
error NotFromThis(address sender);
error InvalidNonce(uint256 nonce);
Expand All @@ -41,14 +58,16 @@ contract SmartAccount is IAccount, UUPSUpgradeable, ERC1155Holder, Nonces, Ownab

receive() external payable {}

constructor(address entryPoint_) {
constructor(address entryPoint_, address identityAuthVerifier_) {
ENTRY_POINT = IEntryPoint(entryPoint_);

IDENTITY_AUTH_VERIFIER = identityAuthVerifier_;

_disableInitializers();
}

function __SmartAccount_init(address owner_) external initializer {
__Ownable_init(owner_);
function __SmartAccount_init(bytes32 nullifier_) external initializer {
nullifier = nullifier_;
}

function execute(
Expand All @@ -72,28 +91,53 @@ contract SmartAccount is IAccount, UUPSUpgradeable, ERC1155Holder, Nonces, Ownab
_payPrefund(missingAccountFunds);
}

function validateSignature(
bytes32 messageHash_,
bytes memory signature_
) public view returns (bool) {
return
ECDSA.recover(MessageHashUtils.toEthSignedMessageHash(messageHash_), signature_) ==
owner();
function setSessionAccount(address candidate_, bytes memory signature_) external {
IdentityProof memory identityProof_ = decodeIdentityProof(signature_);

bool proofResult_ = IDENTITY_AUTH_VERIFIER.verifyProofSafe(
[uint256(nullifier), uint256(uint160(candidate_))].asDynamic(),
identityProof_.identityProof,
2
);

if (!proofResult_) {
revert InvalidProof();
}

sessionAccounts[candidate_] = uint48(block.timestamp);

emit SessionAccountSet(candidate_, block.timestamp);
}

function encodeIdentityProof(
IdentityProof memory proof_
) external pure returns (bytes memory) {
return abi.encode(proof_);
}

function getCurrentNonce() public view virtual returns (uint256) {
function decodeIdentityProof(bytes memory data_) public pure returns (IdentityProof memory) {
return abi.decode(data_, (IdentityProof));
}

function getCurrentNonce() public view returns (uint256) {
return nonces(address(this));
}

function supportsInterface(bytes4 interfaceId) public view virtual override returns (bool) {
function supportsInterface(bytes4 interfaceId) public view override returns (bool) {
return interfaceId == type(IAccount).interfaceId || super.supportsInterface(interfaceId);
}

function _validateSignature(
PackedUserOperation calldata userOp,
bytes32 userOpHash
) internal view returns (uint256 validationData) {
bool proofResult_ = validateSignature(userOpHash, userOp.signature);
IdentityProof memory identityProof_ = decodeIdentityProof(userOp.signature);

bool proofResult_ = IDENTITY_AUTH_VERIFIER.verifyProofSafe(
[uint256(nullifier), uint256(userOpHash)].asDynamic(),
identityProof_.identityProof,
2
);

if (!proofResult_) {
return SIG_VALIDATION_FAILED;
Expand All @@ -104,7 +148,7 @@ contract SmartAccount is IAccount, UUPSUpgradeable, ERC1155Holder, Nonces, Ownab

function _payPrefund(uint256 missingAccountFunds_) internal {
if (missingAccountFunds_ != 0) {
(bool success, ) = payable(_msgSender()).call{
(bool success, ) = payable(msg.sender).call{
value: missingAccountFunds_,
gas: type(uint256).max
}("");
Expand All @@ -126,20 +170,20 @@ contract SmartAccount is IAccount, UUPSUpgradeable, ERC1155Holder, Nonces, Ownab
}

function _requireEntryPoint() internal view {
if (_msgSender() != address(ENTRY_POINT)) {
revert NotFromEntryPoint(_msgSender());
if (msg.sender != address(ENTRY_POINT)) {
revert NotFromEntryPoint(msg.sender);
}
}

function _requireEntryPointOrOwner() internal view {
if (_msgSender() != address(ENTRY_POINT) && _msgSender() != owner()) {
revert NotFromEntryPointOrOwner(_msgSender());
if (msg.sender != address(ENTRY_POINT) && sessionAccounts[msg.sender] == 0) {
revert NotFromEntryPointOrOwner(msg.sender);
}
}

function _requireThis() internal view {
if (_msgSender() != address(this)) {
revert NotFromThis(_msgSender());
if (msg.sender != address(this)) {
revert NotFromThis(msg.sender);
}
}
}
Loading

0 comments on commit 4784caa

Please sign in to comment.