diff --git a/.eslintrc.json b/.eslintrc.json index ac8fb75f6..e1c5c0e84 100644 --- a/.eslintrc.json +++ b/.eslintrc.json @@ -16,7 +16,10 @@ "eqeqeq": "error", "no-unused-vars": "off", "unicorn/filename-case": ["error", { "case": "camelCase" }], - "@typescript-eslint/no-unused-vars": ["error", { "argsIgnorePattern": "^_" }], + "@typescript-eslint/no-unused-vars": [ + "error", + { "argsIgnorePattern": "^_", "varsIgnorePattern": "^_" } + ], "@typescript-eslint/ban-ts-comment": ["error", { "ts-ignore": "allow-with-description" }] } } diff --git a/docs/i18n/vi/docusaurus-plugin-content-docs/current/features/cairo_stubs.mdx b/docs/i18n/vi/docusaurus-plugin-content-docs/current/features/cairo_stubs.mdx index 4ff75f69a..ee4144275 100644 --- a/docs/i18n/vi/docusaurus-plugin-content-docs/current/features/cairo_stubs.mdx +++ b/docs/i18n/vi/docusaurus-plugin-content-docs/current/features/cairo_stubs.mdx @@ -9,7 +9,7 @@ The system works in the following way: 1. To start a Cairo Block add your Cairo code above a Solidity function with 3 forward slashes at the beginning of each line and the phrase `warp-cairo` at the top. 2. The user then uses a number of MACROS to interact with the transpiled contract. -3. The Soldiity function will then be replaced with the Cario function that is above it. +3. The Solidity function will then be replaced with the Cairo function that is above it. The following MACROS are supported: diff --git a/exampleContracts/internalFunctions.sol b/exampleContracts/internalFunctions.sol new file mode 100644 index 000000000..ffd4abd36 --- /dev/null +++ b/exampleContracts/internalFunctions.sol @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +pragma solidity >=0.8.6; + +contract WARP { + function f( + uint256 x, + uint8 y, + uint8 z + ) internal pure returns (bytes memory) { + return abi.encode(x, y, z); + } +} diff --git a/package.json b/package.json index 18ecc6c60..9f89663b7 100644 --- a/package.json +++ b/package.json @@ -74,10 +74,12 @@ }, "dependencies": { "@algorithm.ts/gcd": "^2.0.14", + "@types/glob": "^8.1.0", "chalk": "^4.1.2", "commander": "^9.1.0", "eslint-plugin-unicorn": "^45.0.2", "ethers": "^5.6.2", + "glob": "^8.1.0", "keccak": "^3.0.2", "peggy": "^1.2.0", "prompts": "^2.4.2", diff --git a/src/ast/ast.ts b/src/ast/ast.ts index 108ee0040..4ec945fab 100644 --- a/src/ast/ast.ts +++ b/src/ast/ast.ts @@ -25,9 +25,11 @@ import { printNode } from '../utils/astPrinter'; import { TranspileFailedError } from '../utils/errors'; import { Implicits } from '../utils/implicits'; import { createBlock } from '../utils/nodeTemplates'; +import { createImport } from '../utils/importFuncGenerator'; import { safeGetNodeType } from '../utils/nodeTypeProcessing'; -import { getContainingSourceUnit, isExternalCall, mergeImports } from '../utils/utils'; -import { CairoFunctionDefinition } from './cairoNodes'; +import { getContainingSourceUnit, isExternalCall } from '../utils/utils'; +import { CairoFunctionDefinition, CairoImportFunctionDefinition } from './cairoNodes'; +import { ParameterInfo } from '../export'; /* A centralised store of information required for transpilation, a reference @@ -54,8 +56,6 @@ export class AST { private cairoUtilFuncGen: Map = new Map(); context: ASTContext; - // node requiring cairo import -> file to import from -> symbols to import - imports: Map>> = new Map(); public inference: InferType; readonly tempId = -1; @@ -81,16 +81,6 @@ export class AST { ); } - copyRegisteredImports(oldNode: ASTNode, newNode: ASTNode): void { - this.imports.set( - newNode, - mergeImports( - this.imports.get(oldNode) ?? new Map>(), - this.imports.get(newNode) ?? new Map>(), - ), - ); - } - extractToConstant( node: Expression, vType: TypeName, @@ -175,20 +165,6 @@ export class AST { return containingFunction.implicits; } - getImports(sourceUnit: SourceUnit): Map> { - assert( - this.roots.includes(sourceUnit), - `Tried to get imports associated with ${printNode( - sourceUnit, - )}, which is not one of the roots of the AST`, - ); - const reachableNodeImports = sourceUnit - .getChildren(true) - .map((node) => this.imports.get(node) ?? new Map>()); - const utilFunctionImports = this.getUtilFuncGen(sourceUnit)?.getImports(); - return mergeImports(utilFunctionImports, ...reachableNodeImports); - } - getUtilFuncGen(node: ASTNode): CairoUtilFuncGen { const sourceUnit = node instanceof SourceUnit ? node : getContainingSourceUnit(node); const gen = this.cairoUtilFuncGen.get(sourceUnit.id); @@ -298,12 +274,15 @@ export class AST { return child.id; } - registerImport(node: ASTNode, location: string, name: string): void { - const nodeImports = this.imports.get(node) ?? new Map>(); - const fileImports = nodeImports.get(location) ?? new Set(); - fileImports.add(name); - nodeImports.set(location, fileImports); - this.imports.set(node, nodeImports); + registerImport( + node: ASTNode, + location: string, + name: string, + inputs: ParameterInfo[], + outputs: ParameterInfo[], + options?: { acceptsRawDarray?: boolean; acceptsUnpackedStructArray?: boolean }, + ): CairoImportFunctionDefinition { + return createImport(location, name, node, this, inputs, outputs, options); } removeStatement(statement: Statement): void { @@ -317,19 +296,9 @@ export class AST { } // Reference notes/astnodetypes.ts for exact restrictions on what can safely be replaced with what - replaceNode( - oldNode: Expression, - newNode: Expression, - parent?: ASTNode, - copyImports?: boolean, - ): number; - replaceNode( - oldNode: Statement, - newNode: Statement, - parent?: ASTNode, - copyImports?: boolean, - ): number; - replaceNode(oldNode: ASTNode, newNode: ASTNode, parent?: ASTNode, copyImports = true): number { + replaceNode(oldNode: Expression, newNode: Expression, parent?: ASTNode): number; + replaceNode(oldNode: Statement, newNode: Statement, parent?: ASTNode): number; + replaceNode(oldNode: ASTNode, newNode: ASTNode, parent?: ASTNode): number { if (oldNode === newNode) { console.log(`WARNING: Attempted to replace node ${printNode(oldNode)} with itself`); return oldNode.id; @@ -355,9 +324,6 @@ export class AST { replaceNode(oldNode, newNode, parent); this.context.unregister(oldNode); this.setContextRecursive(newNode); - if (copyImports) { - this.copyRegisteredImports(oldNode, newNode); - } return newNode.id; } diff --git a/src/ast/cairoNodes/cairoBlockFunctionDefinition.ts b/src/ast/cairoNodes/cairoBlockFunctionDefinition.ts new file mode 100644 index 000000000..c70ec570e --- /dev/null +++ b/src/ast/cairoNodes/cairoBlockFunctionDefinition.ts @@ -0,0 +1,37 @@ +import { + FunctionKind, + FunctionStateMutability, + FunctionVisibility, + ParameterList, +} from 'solc-typed-ast'; +import { FunctionStubKind } from './cairoFunctionDefinition'; +import { CairoRawStringFunctionDefinition } from './cairoRawStringFunctionDefinition'; + +export class CairoBlockFunctionDefinition extends CairoRawStringFunctionDefinition { + constructor( + id: number, + src: string, + scope: number, + kind: FunctionKind, + name: string, + visibility: FunctionVisibility, + stateMutability: FunctionStateMutability, + parameters: ParameterList, + returnParameters: ParameterList, + rawStringDefinition: string, + ) { + super( + id, + src, + scope, + kind, + name, + visibility, + stateMutability, + parameters, + returnParameters, + FunctionStubKind.FunctionDefStub, + rawStringDefinition, + ); + } +} diff --git a/src/ast/cairoNodes/cairoFunctionDefinition.ts b/src/ast/cairoNodes/cairoFunctionDefinition.ts index 63005852a..25b4e48a8 100644 --- a/src/ast/cairoNodes/cairoFunctionDefinition.ts +++ b/src/ast/cairoNodes/cairoFunctionDefinition.ts @@ -13,17 +13,18 @@ import { import { Implicits } from '../../utils/implicits'; /* - An extension of FunctionDefinition to track which implicit arguments are used + An extension of FunctionDefinition to track which implicit arguments are used. Additionally we often use function stubs for instances where we want to be able to insert function during transpilation where it wouldn't make sense to include their body in the AST. For example, stubs are used for warplib functions, and - those generated to handle memory and storage processing. Marking a CairoFunctionDefintion - as a stub tells the CairoWriter not to print it + those generated to handle memory and storage processing. Marking a CairoFunctionDefinition + as a stub tells the CairoWriter not to print it. */ export enum FunctionStubKind { None, FunctionDefStub, + StorageDefStub, StructDefStub, } diff --git a/src/ast/cairoNodes/cairoGeneratedFunctionDefinition.ts b/src/ast/cairoNodes/cairoGeneratedFunctionDefinition.ts new file mode 100644 index 000000000..b2b39587e --- /dev/null +++ b/src/ast/cairoNodes/cairoGeneratedFunctionDefinition.ts @@ -0,0 +1,54 @@ +import { + FunctionDefinition, + FunctionKind, + FunctionStateMutability, + FunctionVisibility, + ParameterList, +} from 'solc-typed-ast'; +import { FunctionStubKind } from './cairoFunctionDefinition'; +import { CairoRawStringFunctionDefinition } from './cairoRawStringFunctionDefinition'; + +export class CairoGeneratedFunctionDefinition extends CairoRawStringFunctionDefinition { + /** + * List of function defintions called by the generated function + */ + public functionsCalled: FunctionDefinition[]; + + constructor( + id: number, + src: string, + scope: number, + kind: FunctionKind, + name: string, + visibility: FunctionVisibility, + stateMutability: FunctionStateMutability, + parameters: ParameterList, + returnParameters: ParameterList, + functionStubKind: FunctionStubKind, + rawStringDefinition: string, + functionsCalled: FunctionDefinition[], + acceptsRawDArray = false, + acceptsUnpackedStructArray = false, + ) { + super( + id, + src, + scope, + kind, + name, + visibility, + stateMutability, + parameters, + returnParameters, + functionStubKind, + rawStringDefinition, + acceptsRawDArray, + acceptsUnpackedStructArray, + ); + this.functionsCalled = removeRepeatedFunction(functionsCalled); + } +} + +function removeRepeatedFunction(functionsCalled: FunctionDefinition[]) { + return [...new Set(functionsCalled)]; +} diff --git a/src/ast/cairoNodes/cairoImportFunctionDefinition.ts b/src/ast/cairoNodes/cairoImportFunctionDefinition.ts new file mode 100644 index 000000000..417e2b7d1 --- /dev/null +++ b/src/ast/cairoNodes/cairoImportFunctionDefinition.ts @@ -0,0 +1,45 @@ +import { + FunctionKind, + FunctionStateMutability, + FunctionVisibility, + ParameterList, +} from 'solc-typed-ast'; +import { CairoFunctionDefinition, FunctionStubKind } from './cairoFunctionDefinition'; +import { Implicits } from '../../utils/implicits'; + +export class CairoImportFunctionDefinition extends CairoFunctionDefinition { + path: string; + constructor( + id: number, + src: string, + scope: number, + name: string, + path: string, + implicits: Set, + parameters: ParameterList, + returnParameters: ParameterList, + stubKind: FunctionStubKind, + acceptsRawDArray = false, + acceptsUnpackedStructArray = false, + ) { + super( + id, + src, + scope, + FunctionKind.Function, + name, + false, + FunctionVisibility.Internal, + FunctionStateMutability.NonPayable, + false, + parameters, + returnParameters, + [], + implicits, + stubKind, + acceptsRawDArray, + acceptsUnpackedStructArray, + ); + this.path = path; + } +} diff --git a/src/ast/cairoNodes/cairoRawStringFunctionDefinition.ts b/src/ast/cairoNodes/cairoRawStringFunctionDefinition.ts new file mode 100644 index 000000000..cb44ab87d --- /dev/null +++ b/src/ast/cairoNodes/cairoRawStringFunctionDefinition.ts @@ -0,0 +1,51 @@ +import { + FunctionKind, + FunctionStateMutability, + FunctionVisibility, + ParameterList, +} from 'solc-typed-ast'; +import { CairoFunctionDefinition, FunctionStubKind } from './cairoFunctionDefinition'; +import { getRawCairoFunctionInfo } from '../../utils/cairoParsing'; + +export class CairoRawStringFunctionDefinition extends CairoFunctionDefinition { + rawStringDefinition: string; + constructor( + id: number, + src: string, + scope: number, + kind: FunctionKind, + name: string, + visibility: FunctionVisibility, + stateMutability: FunctionStateMutability, + parameters: ParameterList, + returnParameters: ParameterList, + functionSutbKind: FunctionStubKind, + rawStringDefinition: string, + acceptsRawDArray = false, + acceptsUnpackedStructArray = false, + ) { + super( + id, + src, + scope, + kind, + name, + false, // Virtual + visibility, + stateMutability, + false, // IsConstructor + parameters, + returnParameters, + [], // Modifier Invocation + functionSutbKind === FunctionStubKind.FunctionDefStub + ? new Set(getRawCairoFunctionInfo(rawStringDefinition).implicits) + : functionSutbKind === FunctionStubKind.StorageDefStub + ? new Set(['pedersen_ptr']) + : new Set(), + functionSutbKind, + acceptsRawDArray, + acceptsUnpackedStructArray, + ); + this.rawStringDefinition = rawStringDefinition; + } +} diff --git a/src/ast/cairoNodes/index.ts b/src/ast/cairoNodes/index.ts index 7d422a2b5..db088463b 100644 --- a/src/ast/cairoNodes/index.ts +++ b/src/ast/cairoNodes/index.ts @@ -1,4 +1,6 @@ export * from './cairoAssert'; export * from './cairoContract'; +export * from './cairoGeneratedFunctionDefinition'; export * from './cairoFunctionDefinition'; +export * from './cairoImportFunctionDefinition'; export * from './cairoTempVarStatement'; diff --git a/src/ast/visitor.ts b/src/ast/visitor.ts index 16e772c47..1326efe30 100644 --- a/src/ast/visitor.ts +++ b/src/ast/visitor.ts @@ -73,6 +73,7 @@ import { } from './cairoNodes'; import { AST } from './ast'; +import { CairoGeneratedFunctionDefinition } from './cairoNodes/cairoGeneratedFunctionDefinition'; /* Visits every node in a tree in depth first order, calling visitT for each T extends ASTNode @@ -120,6 +121,8 @@ export abstract class ASTVisitor { else if (node instanceof Literal) res = this.visitLiteral(node, ast); else if (node instanceof TupleExpression) res = this.visitTupleExpression(node, ast); else if (node instanceof UnaryOperation) res = this.visitUnaryOperation(node, ast); + else if (node instanceof CairoGeneratedFunctionDefinition) + res = this.visitCairoGeneratedFunctionDefinition(node, ast); else if (node instanceof CairoFunctionDefinition) res = this.visitCairoFunctionDefinition(node, ast); else if (node instanceof CairoTempVarStatement) res = this.visitCairoTempVar(node, ast); @@ -169,6 +172,9 @@ export abstract class ASTVisitor { visitCairoFunctionDefinition(node: CairoFunctionDefinition, ast: AST): T { return this.visitFunctionDefinition(node, ast); } + visitCairoGeneratedFunctionDefinition(node: CairoGeneratedFunctionDefinition, ast: AST): T { + return this.visitCairoFunctionDefinition(node, ast); + } visitCairoTempVar(node: CairoTempVarStatement, ast: AST): T { return this.commonVisit(node, ast); } diff --git a/src/cairoUtilFuncGen/abi/abiDecode.ts b/src/cairoUtilFuncGen/abi/abiDecode.ts index 83f86b01b..61f223f36 100644 --- a/src/cairoUtilFuncGen/abi/abiDecode.ts +++ b/src/cairoUtilFuncGen/abi/abiDecode.ts @@ -14,10 +14,12 @@ import { UserDefinedType, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoFunctionDefinition } from '../../ast/cairoNodes'; +import { GeneratedFunctionInfo } from '../base'; import { printTypeNode } from '../../utils/astPrinter'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { TranspileFailedError } from '../../utils/errors'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { createBytesTypeName } from '../../utils/nodeTemplates'; import { getByteSize, @@ -33,7 +35,7 @@ import { import { typeNameFromTypeNode } from '../../utils/utils'; import { uint256 } from '../../warplib/utils'; import { add, delegateBasedOnType, mul, StringIndexedFuncGenWithAuxiliar } from '../base'; -import { MemoryWriteGen } from '../export'; +import { MemoryWriteGen } from '../memory/memoryWrite'; import { removeSizeInfo } from './base'; const IMPLICITS = @@ -48,7 +50,7 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { this.memoryWrite = memoryWrite; } - public gen(expressions: Expression[], sourceUnit?: SourceUnit): FunctionCall { + public gen(expressions: Expression[]): FunctionCall { assert( expressions.length === 2, 'ABI decode must recieve two arguments: data to decode, and types to decode into', @@ -62,47 +64,65 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { ); const typesToDecode = types instanceof TupleType ? types.elements : [types]; - const functionName = this.getOrCreate(typesToDecode.map((t) => generalizeType(t)[0])); + const generatedFunction = this.getOrCreateFuncDef( + typesToDecode.map((t) => generalizeType(t)[0]), + ); + + return createCallToFunction(generatedFunction, [expressions[0]], this.ast); + } - const functionStub = createCairoFunctionStub( - functionName, + public getOrCreateFuncDef(types: TypeNode[]): CairoFunctionDefinition { + const key = types.map((t) => t.pp()).join(','); + const existing = this.generatedFunctionsDef.get(key); + if (existing !== undefined) { + return existing; + } + + const funcInfo = this.getOrCreate(types); + + const funcDef = createCairoGeneratedFunction( + funcInfo, [['data', createBytesTypeName(this.ast), DataLocation.Memory]], - typesToDecode.map((t, index) => + types.map((t, index) => isValueType(t) ? [`result${index}`, typeNameFromTypeNode(t, this.ast)] : [`result${index}`, typeNameFromTypeNode(t, this.ast), DataLocation.Memory], ), - ['bitwise_ptr', 'range_check_ptr', 'warp_memory'], this.ast, - sourceUnit ?? this.sourceUnit, + this.sourceUnit, ); - return createCallToFunction(functionStub, [expressions[0]], this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - public getOrCreate(types: TypeNode[]): string { - const key = types.map((t) => t.pp()).join(','); - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - const [returnParams, decodings] = types.reduce( - ([returnParams, decodings], type, index) => [ - [ + private getOrCreate(types: TypeNode[]): GeneratedFunctionInfo { + const [returnParams, decodings, functionsCalled] = types.reduce( + ([returnParams, decodings, functionsCalled], type, index) => { + const newReturnParams = [ ...returnParams, { name: `result${index}`, type: CairoType.fromSol(type, this.ast, TypeConversionContext.Ref).toString(), }, - ], - [ - ...decodings, - `// Param ${index} decoding:`, - this.generateDecodingCode(type, 'mem_index', `result${index}`, 'mem_index'), - ], + ]; + const [newDecodings, newFunctionsCalled] = this.generateDecodingCode( + type, + 'mem_index', + `result${index}`, + 'mem_index', + ); + return [ + newReturnParams, + [...decodings, `// Param ${index} decoding:`, newDecodings], + [...functionsCalled, ...newFunctionsCalled], + ]; + }, + [ + new Array<{ name: string; type: string }>(), + new Array(), + new Array(), ], - [new Array<{ name: string; type: string }>(), new Array()], ); const indexLength = types.reduce( @@ -112,7 +132,7 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { const returnCairoParams = returnParams.map((r) => `${r.name} : ${r.type}`).join(','); const returnValues = returnParams.map((r) => `${r.name} = ${r.name}`).join(','); - const funcName = `${this.functionName}${this.generatedFunctions.size}`; + const funcName = `${this.functionName}${this.generatedFunctionsDef.size}`; const code = [ `func ${funcName}${IMPLICITS}(mem_ptr : felt) -> (${returnCairoParams}){`, ` alloc_locals;`, @@ -124,17 +144,15 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { `}`, ].join('\n'); - const cairoFunc = { name: funcName, code: code }; - this.generatedFunctions.set(key, cairoFunc); - return cairoFunc.name; + return { name: funcName, code: code, functionsCalled: functionsCalled }; } - public getOrCreateDecoding(type: TypeNode): string { + public getOrCreateDecoding(type: TypeNode): CairoFunctionDefinition { const unexpectedType = () => { throw new TranspileFailedError(`Decoding of ${printTypeNode(type)} is not valid`); }; - return delegateBasedOnType( + return delegateBasedOnType( type, (type) => type instanceof ArrayType @@ -153,14 +171,14 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { * @param newIndexVar cairo var to store new index position after decoding the type * @param decodeResult cairo var that stores the result of the decoding * @param relativeIndexVar cairo var to handle offset values - * @returns the generated code + * @returns the generated code and functions called */ public generateDecodingCode( type: TypeNode, newIndexVar: string, decodeResult: string, relativeIndexVar: string, - ): string { + ): [string, CairoFunctionDefinition[]] { assert( !(type instanceof PointerType), 'Pointer types are not valid types for decoding. Try to generalize them', @@ -168,14 +186,18 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { // address types get special treatment due to different byte size in ethereum and starknet if (isAddressType(type)) { - const funcName = this.createValueTypeDecoding(31); + const func = this.createValueTypeDecoding(31); return [ - `let (${decodeResult} : felt) = ${funcName}(mem_index, mem_index + 32, mem_ptr, 0);`, - `let ${newIndexVar} = mem_index + 32;`, - ].join('\n'); + [ + `let (${decodeResult} : felt) = ${func.name}(mem_index, mem_index + 32, mem_ptr, 0);`, + `let ${newIndexVar} = mem_index + 32;`, + ].join('\n'), + [func], + ]; } - const funcName = this.getOrCreateDecoding(type); + const auxFunc = this.getOrCreateDecoding(type); + const importedFuncs = []; if (isReferenceType(type)) { // Find where the type is encoded in the bytes array: @@ -188,7 +210,9 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { let initInstructions: string[] = []; let typeIndex = 'mem_index'; if (isDynamicallySized(type, this.ast.inference)) { - this.requireImport('warplib.dynamic_arrays_util', 'byte_array_to_felt_value'); + importedFuncs.push( + this.requireImport('warplib.dynamic_arrays_util', 'byte_array_to_felt_value'), + ); initInstructions = [ `let (param_offset) = byte_array_to_felt_value(mem_index, mem_index + 32, mem_ptr, 0);`, `let mem_offset = ${calcOffset('mem_index', 'param_offset', relativeIndexVar)};`, @@ -216,7 +240,7 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { `let (${decodeResult}) = wm_new(${decodeResult}_dyn_array_length256, ${uint256( elementTWidth, )});`, - `${funcName}(`, + `${auxFunc.name}(`, ` ${typeIndex} + 32,`, ` mem_ptr,`, ` 0,`, @@ -225,7 +249,7 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { `);`, ]; // Other relevant imports get added when the function is generated - this.requireImport('warplib.memory', 'wm_new'); + importedFuncs.push(this.requireImport('warplib.memory', 'wm_new')); } else if (type instanceof ArrayType) { // Handling static arrays assert(type.size !== undefined); @@ -234,7 +258,7 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { ); callInstructions = [ `let (${decodeResult}) = wm_alloc(${uint256(type.size * elemenTWidth)});`, - `${funcName}(`, + `${auxFunc.name}(`, ` ${typeIndex},`, ` mem_ptr,`, ` 0,`, @@ -242,7 +266,7 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { ` ${decodeResult}`, `);`, ]; - this.requireImport('warplib.memory', 'wm_alloc'); + importedFuncs.push(this.requireImport('warplib.memory', 'wm_alloc')); } else if (type instanceof UserDefinedType && type.definition instanceof StructDefinition) { const maxSize = CairoType.fromSol( type, @@ -251,13 +275,13 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { ).width; callInstructions = [ `let (${decodeResult}) = wm_alloc(${uint256(maxSize)});`, - `${funcName}(`, + `${auxFunc.name}(`, ` ${typeIndex},`, ` mem_ptr,`, ` ${decodeResult}`, `);`, ]; - this.requireImport('warplib.memory', 'wm_alloc'); + importedFuncs.push(this.requireImport('warplib.memory', 'wm_alloc')); } else { throw new TranspileFailedError( `Unexpected reference type to generate decoding code: ${printTypeNode(type)}`, @@ -265,10 +289,13 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { } return [ - ...initInstructions, - ...callInstructions, - `let ${newIndexVar} = mem_index + ${getByteSize(type, this.ast.inference)};`, - ].join('\n'); + [ + ...initInstructions, + ...callInstructions, + `let ${newIndexVar} = mem_index + ${getByteSize(type, this.ast.inference)};`, + ].join('\n'), + [...importedFuncs, auxFunc], + ]; } // Handling value types @@ -282,25 +309,31 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { if (byteSize === 32) { args.push('0'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); + importedFuncs.push(this.requireImport('starkware.cairo.common.uint256', 'Uint256')); } const decodeType = byteSize === 32 ? 'Uint256' : 'felt'; return [ - `let (${decodeResult} : ${decodeType}) = ${funcName}(${args.join(',')});`, - `let ${newIndexVar} = mem_index + 32;`, - ].join('\n'); + [ + `let (${decodeResult} : ${decodeType}) = ${auxFunc.name}(${args.join(',')});`, + `let ${newIndexVar} = mem_index + 32;`, + ].join('\n'), + [...importedFuncs, auxFunc], + ]; } - private createStaticArrayDecoding(type: ArrayType) { + private createStaticArrayDecoding(type: ArrayType): CairoFunctionDefinition { assert(type.size !== undefined); + const key = 'static' + removeSizeInfo(type); const existing = this.auxiliarGeneratedFunctions.get(key); - if (existing !== undefined) return existing.name; + if (existing !== undefined) { + return existing; + } const elementTWidth = CairoType.fromSol(type.elementT, this.ast).width; - const decodingCode = this.generateDecodingCode( + const [decodingCode, functionsCalled] = this.generateDecodingCode( type.elementT, 'next_mem_index', 'element', @@ -310,9 +343,8 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { 'array_index', elementTWidth, )};`; - const writeToMemCode = `${this.memoryWrite.getOrCreate( - type.elementT, - )}(write_to_mem_location, element);`; + const writeToMemFunc = this.memoryWrite.getOrCreateFuncDef(type.elementT); + const writeToMemCode = `${writeToMemFunc.name}(write_to_mem_location, element);`; const name = `${this.functionName}_static_array${this.auxiliarGeneratedFunctions.size}`; const code = [ @@ -334,21 +366,30 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { `}`, ].join('\n'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); + const funcInfo = { + name, + code, + functionsCalled: [ + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + ...functionsCalled, + writeToMemFunc, + ], + }; - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + const generatedFunc = this.createAuxiliarGeneratedFunction(funcInfo); + this.auxiliarGeneratedFunctions.set(key, generatedFunc); + return generatedFunc; } - private createDynamicArrayDecoding(type: ArrayType): string { + private createDynamicArrayDecoding(type: ArrayType): CairoFunctionDefinition { const key = 'dynamic' + type.pp(); const existing = this.auxiliarGeneratedFunctions.get(key); - if (existing !== undefined) return existing.name; + if (existing !== undefined) return existing; const elementT = getElementType(type); const elemenTWidth = CairoType.fromSol(elementT, this.ast, TypeConversionContext.Ref).width; - const decodingCode = this.generateDecodingCode( + const [decodingCode, functionsCalled] = this.generateDecodingCode( elementT, 'next_mem_index', 'element', @@ -360,9 +401,8 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { elemenTWidth, )});`, ].join('\n'); - const writeToMemCode = `${this.memoryWrite.getOrCreate( - elementT, - )}(write_to_mem_location, element);`; + const writeToMemFunc = this.memoryWrite.getOrCreateFuncDef(elementT); + const writeToMemCode = `${writeToMemFunc.name}(write_to_mem_location, element);`; const name = `${this.functionName}_dynamic_array${this.auxiliarGeneratedFunctions.size}`; const code = [ @@ -390,44 +430,64 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { `}`, ].join('\n'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('warplib.memory', 'wm_index_dyn'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - this.requireImport('warplib.maths.utils', 'narrow_safe'); + const importedFuncs = [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('warplib.memory', 'wm_index_dyn'), + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + this.requireImport('warplib.maths.utils', 'narrow_safe'), + ]; - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + const funcInfo = { + name, + code, + functionsCalled: [...importedFuncs, ...functionsCalled, writeToMemFunc], + }; + const generatedFunc = this.createAuxiliarGeneratedFunction(funcInfo); + + this.auxiliarGeneratedFunctions.set(key, generatedFunc); + return generatedFunc; } - private createStructDecoding(type: UserDefinedType, definition: StructDefinition) { + private createStructDecoding( + type: UserDefinedType, + definition: StructDefinition, + ): CairoFunctionDefinition { const key = type.pp(); const existing = this.auxiliarGeneratedFunctions.get(key); - if (existing !== undefined) return existing.name; + if (existing !== undefined) return existing; let indexWalked = 0; let structWriteLocation = 0; - const instructions = definition.vMembers.map((member, index) => { - const type = generalizeType(safeGetNodeType(member, this.ast.inference))[0]; - const elemWidth = CairoType.fromSol(type, this.ast, TypeConversionContext.Ref).width; - const decodingCode = this.generateDecodingCode( - type, - 'mem_index', - `member${index}`, - `${indexWalked}`, - ); - indexWalked += Number(getByteSize(type, this.ast.inference)); - structWriteLocation += index * elemWidth; - const getMemLocCode = `let mem_to_write_loc = ${add('struct_ptr', structWriteLocation)};`; - const writeMemLocCode = `${this.memoryWrite.getOrCreate( - type, - )}(mem_to_write_loc, member${index});`; - return [ - `// Decoding member ${member.name}`, - `${decodingCode}`, - `${getMemLocCode}`, - `${writeMemLocCode}`, - ].join('\n'); - }); + const decodingInfo: [string, CairoFunctionDefinition[]][] = definition.vMembers.map( + (member, index) => { + const [type] = generalizeType(safeGetNodeType(member, this.ast.inference)); + const elemWidth = CairoType.fromSol(type, this.ast, TypeConversionContext.Ref).width; + const [decodingCode, functionsCalled] = this.generateDecodingCode( + type, + 'mem_index', + `member${index}`, + `${indexWalked}`, + ); + indexWalked += Number(getByteSize(type, this.ast.inference)); + structWriteLocation += index * elemWidth; + const getMemLocCode = `let mem_to_write_loc = ${add('struct_ptr', structWriteLocation)};`; + + const writeMemLocFunc = this.memoryWrite.getOrCreateFuncDef(type); + const writeMemLocCode = `${writeMemLocFunc.name}(mem_to_write_loc, member${index});`; + return [ + [ + `// Decoding member ${member.name}`, + `${decodingCode}`, + `${getMemLocCode}`, + `${writeMemLocCode}`, + ].join('\n'), + [...functionsCalled, writeMemLocFunc], + ]; + }, + ); + + const instructions = decodingInfo.map((info) => info[0]); + const functionsCalled = decodingInfo.flatMap((info) => info[1]); const name = `${this.functionName}_struct_${definition.name}`; const code = [ @@ -442,21 +502,28 @@ export class AbiDecode extends StringIndexedFuncGenWithAuxiliar { `}`, ].join('\n'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + const importedFuncs = [this.requireImport('warplib.maths.utils', 'felt_to_uint256')]; + const genFuncInfo = { + name, + code, + functionsCalled: [...importedFuncs, ...functionsCalled], + }; + const auxFunc = this.createAuxiliarGeneratedFunction(genFuncInfo); + + this.auxiliarGeneratedFunctions.set(key, auxFunc); + return auxFunc; } - private createStringBytesDecoding(): string { + private createStringBytesDecoding(): CairoFunctionDefinition { const funcName = 'memory_dyn_array_copy'; - this.requireImport('warplib.dynamic_arrays_util', funcName); - return funcName; + const importedFunc = this.requireImport('warplib.dynamic_arrays_util', funcName); + return importedFunc; } - private createValueTypeDecoding(byteSize: number | bigint): string { + private createValueTypeDecoding(byteSize: number | bigint): CairoFunctionDefinition { const funcName = byteSize === 32 ? 'byte_array_to_uint256_value' : 'byte_array_to_felt_value'; - this.requireImport('warplib.dynamic_arrays_util', funcName); - return funcName; + const importedFunc = this.requireImport('warplib.dynamic_arrays_util', funcName); + return importedFunc; } } diff --git a/src/cairoUtilFuncGen/abi/abiEncode.ts b/src/cairoUtilFuncGen/abi/abiEncode.ts index e385eef5d..c55c3793b 100644 --- a/src/cairoUtilFuncGen/abi/abiEncode.ts +++ b/src/cairoUtilFuncGen/abi/abiEncode.ts @@ -8,6 +8,7 @@ import { UserDefinedType, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoFunctionDefinition } from '../../export'; import { printTypeNode } from '../../utils/astPrinter'; import { CairoType, MemoryLocation, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { TranspileFailedError } from '../../utils/errors'; @@ -22,7 +23,7 @@ import { safeGetNodeType, } from '../../utils/nodeTypeProcessing'; import { uint256 } from '../../warplib/utils'; -import { delegateBasedOnType, mul } from '../base'; +import { delegateBasedOnType, GeneratedFunctionInfo, mul } from '../base'; import { MemoryReadGen } from '../memory/memoryRead'; import { AbiBase, removeSizeInfo } from './base'; @@ -42,23 +43,29 @@ export class AbiEncode extends AbiBase { this.memoryRead = memoryRead; } - public getOrCreate(types: TypeNode[]): string { - const key = types.map((t) => t.pp()).join(','); - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - const [params, encodings] = types.reduce( - ([params, encodings], type, index) => { + public getOrCreate(types: TypeNode[]): GeneratedFunctionInfo { + const [params, encodings, functionsCalled] = types.reduce( + ([params, encodings, functionsCalled], type, index) => { const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.Ref); params.push({ name: `param${index}`, type: cairoType.toString() }); - encodings.push( - this.generateEncodingCode(type, 'bytes_index', 'bytes_offset', '0', `param${index}`), + + const [paramEncoding, paramFunctionsCalled] = this.generateEncodingCode( + type, + 'bytes_index', + 'bytes_offset', + '0', + `param${index}`, ); - return [params, encodings]; + + encodings.push(paramEncoding); + + return [params, encodings, functionsCalled.concat(paramFunctionsCalled)]; }, - [new Array<{ name: string; type: string }>(), new Array()], + [ + new Array<{ name: string; type: string }>(), + new Array(), + new Array(), + ], ); const initialOffset = types.reduce( @@ -67,7 +74,7 @@ export class AbiEncode extends AbiBase { ); const cairoParams = params.map((p) => `${p.name} : ${p.type}`).join(', '); - const funcName = `${this.functionName}${this.generatedFunctions.size}`; + const funcName = `${this.functionName}${this.generatedFunctionsDef.size}`; const code = [ `func ${funcName}${IMPLICITS}(${cairoParams}) -> (result_ptr : felt){`, ` alloc_locals;`, @@ -82,16 +89,20 @@ export class AbiEncode extends AbiBase { `}`, ].join('\n'); - this.requireImport('starkware.cairo.common.alloc', 'alloc'); - this.requireImport('starkware.cairo.common.cairo_builtins', 'BitwiseBuiltin'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - this.requireImport('warplib.memory', 'wm_new'); - this.requireImport('warplib.dynamic_arrays_util', 'felt_array_to_warp_memory_array'); - - const cairoFunc = { name: funcName, code: code }; - this.generatedFunctions.set(key, cairoFunc); - return cairoFunc.name; + const importedFuncs = [ + this.requireImport('starkware.cairo.common.alloc', 'alloc'), + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + this.requireImport('warplib.memory', 'wm_new'), + this.requireImport('warplib.dynamic_arrays_util', 'felt_array_to_warp_memory_array'), + ]; + + const funcInfo = { + name: funcName, + code: code, + functionsCalled: [...importedFuncs, ...functionsCalled], + }; + return funcInfo; } /** @@ -99,12 +110,12 @@ export class AbiEncode extends AbiBase { * @param type type to encode * @returns the name of the generated function */ - public getOrCreateEncoding(type: TypeNode): string { + public getOrCreateEncoding(type: TypeNode): CairoFunctionDefinition { const unexpectedType = () => { throw new TranspileFailedError(`Encoding ${printTypeNode(type)} is not supported yet`); }; - return delegateBasedOnType( + return delegateBasedOnType( type, (type) => type instanceof ArrayType @@ -139,67 +150,78 @@ export class AbiEncode extends AbiBase { newOffsetVar: string, elementOffset: string, varToEncode: string, - ): string { - const funcName = this.getOrCreateEncoding(type); + ): [string, CairoFunctionDefinition[]] { + const func = this.getOrCreateEncoding(type); if (isDynamicallySized(type, this.ast.inference) || isStruct(type)) { return [ - `let (${newIndexVar}, ${newOffsetVar}) = ${funcName}(`, - ` bytes_index,`, - ` bytes_offset,`, - ` bytes_array,`, - ` ${elementOffset},`, - ` ${varToEncode}`, - `);`, - ].join('\n'); + [ + `let (${newIndexVar}, ${newOffsetVar}) = ${func.name}(`, + ` bytes_index,`, + ` bytes_offset,`, + ` bytes_array,`, + ` ${elementOffset},`, + ` ${varToEncode}`, + `);`, + ].join('\n'), + [func], + ]; } // Static array with known compile time size if (type instanceof ArrayType) { assert(type.size !== undefined); return [ - `let (${newIndexVar}, ${newOffsetVar}) = ${funcName}(`, - ` bytes_index,`, - ` bytes_offset,`, - ` bytes_array,`, - ` ${elementOffset},`, - ` 0,`, - ` ${type.size},`, - ` ${varToEncode},`, - `);`, - ].join('\n'); + [ + `let (${newIndexVar}, ${newOffsetVar}) = ${func.name}(`, + ` bytes_index,`, + ` bytes_offset,`, + ` bytes_array,`, + ` ${elementOffset},`, + ` 0,`, + ` ${type.size},`, + ` ${varToEncode},`, + `);`, + ].join('\n'), + [func], + ]; } // Is value type const size = getPackedByteSize(type, this.ast.inference); const instructions: string[] = []; + + const funcsCalled: CairoFunctionDefinition[] = [func]; // packed size of addresses is 32 bytes, but they are treated as felts, // so they should be converted to Uint256 accordingly if (size < 32 || isAddressType(type)) { - this.requireImport(`warplib.maths.utils`, 'felt_to_uint256'); + funcsCalled.push(this.requireImport(`warplib.maths.utils`, 'felt_to_uint256')); instructions.push(`let (${varToEncode}256) = felt_to_uint256(${varToEncode});`); varToEncode = `${varToEncode}256`; } instructions.push( ...[ - `${funcName}(bytes_index, bytes_array, 0, ${varToEncode});`, + `${func.name}(bytes_index, bytes_array, 0, ${varToEncode});`, `let ${newIndexVar} = bytes_index + 32;`, ], ); if (newOffsetVar !== 'bytes_offset') { instructions.push(`let ${newOffsetVar} = bytes_offset;`); } - return instructions.join('\n'); + + return [instructions.join('\n'), funcsCalled]; } - private createDynamicArrayHeadEncoding(type: ArrayType): string { + private createDynamicArrayHeadEncoding(type: ArrayType): CairoFunctionDefinition { const key = 'head ' + type.pp(); const existing = this.auxiliarGeneratedFunctions.get(key); - if (existing !== undefined) return existing.name; + if (existing !== undefined) return existing; const elementT = getElementType(type); const elementByteSize = getByteSize(elementT, this.ast.inference); const tailEncoding = this.createDynamicArrayTailEncoding(type); + const valueEncoding = this.createValueTypeHeadEncoding(); + const name = `${this.functionName}_head_dynamic_array${this.auxiliarGeneratedFunctions.size}`; const code = [ `func ${name}${IMPLICITS}(`, @@ -212,16 +234,16 @@ export class AbiEncode extends AbiBase { ` alloc_locals;`, ` // Storing pointer to data`, ` let (bytes_offset256) = felt_to_uint256(bytes_offset - element_offset);`, - ` ${this.createValueTypeHeadEncoding()}(bytes_index, bytes_array, 0, bytes_offset256);`, + ` ${valueEncoding.name}(bytes_index, bytes_array, 0, bytes_offset256);`, ` let new_index = bytes_index + 32;`, ` // Storing the length`, ` let (length256) = wm_dyn_array_length(mem_ptr);`, - ` ${this.createValueTypeHeadEncoding()}(bytes_offset, bytes_array, 0, length256);`, + ` ${valueEncoding.name}(bytes_offset, bytes_array, 0, length256);`, ` let bytes_offset = bytes_offset + 32;`, ` // Storing the data`, ` let (length) = narrow_safe(length256);`, ` let bytes_offset_offset = bytes_offset + ${mul('length', elementByteSize)};`, - ` let (extended_offset) = ${tailEncoding}(`, + ` let (extended_offset) = ${tailEncoding.name}(`, ` bytes_offset,`, ` bytes_offset_offset,`, ` bytes_array,`, @@ -237,24 +259,33 @@ export class AbiEncode extends AbiBase { `}`, ].join('\n'); - this.requireImport('warplib.memory', 'wm_dyn_array_length'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - this.requireImport('warplib.maths.utils', 'narrow_safe'); + const importedFuncs = [ + this.requireImport('warplib.memory', 'wm_dyn_array_length'), + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + this.requireImport('warplib.maths.utils', 'narrow_safe'), + ]; - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + const genFuncInfo = { + name, + code, + functionsCalled: [...importedFuncs, valueEncoding, tailEncoding], + }; + const auxFunc = this.createAuxiliarGeneratedFunction(genFuncInfo); + + this.auxiliarGeneratedFunctions.set(key, auxFunc); + return auxFunc; } - private createDynamicArrayTailEncoding(type: ArrayType): string { + private createDynamicArrayTailEncoding(type: ArrayType): CairoFunctionDefinition { const key = 'tail ' + type.pp(); const existing = this.auxiliarGeneratedFunctions.get(key); - if (existing !== undefined) return existing.name; + if (existing !== undefined) return existing; const elementT = getElementType(type); const elemntTSize = CairoType.fromSol(elementT, this.ast).width; - const readElement = this.readMemory(elementT, 'elem_loc'); - const headEncodingCode = this.generateEncodingCode( + const [readElement, readFunc] = this.readMemory(elementT, 'elem_loc'); + const [headEncodingCode, functionsCalled] = this.generateEncodingCode( elementT, 'new_bytes_index', 'new_bytes_offset', @@ -284,23 +315,33 @@ export class AbiEncode extends AbiBase { `}`, ].join('\n'); - this.requireImport('warplib.memory', 'wm_index_dyn'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); + const importedFuncs = [ + this.requireImport('warplib.memory', 'wm_index_dyn'), + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + ]; + + const genFuncInfo = { + name, + code, + functionsCalled: [...importedFuncs, ...functionsCalled, readFunc], + }; + const auxFunc = this.createAuxiliarGeneratedFunction(genFuncInfo); - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + this.auxiliarGeneratedFunctions.set(key, auxFunc); + return auxFunc; } - private createStaticArrayHeadEncoding(type: ArrayType): string { + private createStaticArrayHeadEncoding(type: ArrayType): CairoFunctionDefinition { assert(type.size !== undefined); const key = 'head ' + type.pp(); const existing = this.auxiliarGeneratedFunctions.get(key); - if (existing !== undefined) return existing.name; + if (existing !== undefined) return existing; const elementT = getElementType(type); const elementByteSize = getByteSize(elementT, this.ast.inference); const inlineEncoding = this.createArrayInlineEncoding(type); + const valueEncoding = this.createValueTypeHeadEncoding(); const name = `${this.functionName}_head_static_array${this.auxiliarGeneratedFunctions.size}`; const code = [ @@ -314,12 +355,12 @@ export class AbiEncode extends AbiBase { ` alloc_locals;`, ` // Storing pointer to data`, ` let (bytes_offset256) = felt_to_uint256(bytes_offset - element_offset);`, - ` ${this.createValueTypeHeadEncoding()}(bytes_index, bytes_array, 0, bytes_offset256);`, + ` ${valueEncoding.name}(bytes_index, bytes_array, 0, bytes_offset256);`, ` let new_bytes_index = bytes_index + 32;`, ` // Storing the data`, ` let length = ${type.size};`, ` let bytes_offset_offset = bytes_offset + ${mul('length', elementByteSize)};`, - ` let (_, extended_offset) = ${inlineEncoding}(`, + ` let (_, extended_offset) = ${inlineEncoding.name}(`, ` bytes_offset,`, ` bytes_offset_offset,`, ` bytes_array,`, @@ -335,22 +376,29 @@ export class AbiEncode extends AbiBase { `}`, ].join('\n'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); + const importedFunc = this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + const genFuncInfo = { + name, + code, + functionsCalled: [importedFunc, inlineEncoding, valueEncoding], + }; + const auxFunc = this.createAuxiliarGeneratedFunction(genFuncInfo); + + this.auxiliarGeneratedFunctions.set(key, auxFunc); + return auxFunc; } - private createArrayInlineEncoding(type: ArrayType) { + private createArrayInlineEncoding(type: ArrayType): CairoFunctionDefinition { const key = 'inline ' + removeSizeInfo(type); const existing = this.auxiliarGeneratedFunctions.get(key); - if (existing !== undefined) return existing.name; + if (existing !== undefined) return existing; const elementTWidth = CairoType.fromSol(type.elementT, this.ast).width; - const readElement = this.readMemory(type.elementT, 'elem_loc'); + const [readElement, readFunc] = this.readMemory(type.elementT, 'elem_loc'); - const headEncodingCode = this.generateEncodingCode( + const [headEncodingCode, functionsCalled] = this.generateEncodingCode( type.elementT, 'new_bytes_index', 'new_bytes_offset', @@ -388,16 +436,21 @@ export class AbiEncode extends AbiBase { `}`, ].join('\n'); - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + const genFuncInfo = { name, code, functionsCalled: [...functionsCalled, readFunc] }; + const auxFunc = this.createAuxiliarGeneratedFunction(genFuncInfo); + + this.auxiliarGeneratedFunctions.set(key, auxFunc); + return auxFunc; } - private createStructHeadEncoding(type: UserDefinedType, def: StructDefinition) { + private createStructHeadEncoding( + type: UserDefinedType, + def: StructDefinition, + ): CairoFunctionDefinition { const key = 'struct head ' + type.pp(); const existing = this.auxiliarGeneratedFunctions.get(key); - if (existing !== undefined) return existing.name; + if (existing !== undefined) return existing; - const inlineEncoding = this.createStructInlineEncoding(type, def); // Get the size of all it's members const typeByteSize = def.vMembers.reduce( (sum, varDecl) => @@ -411,6 +464,9 @@ export class AbiEncode extends AbiBase { 0n, ); + const inlineEncoding = this.createStructInlineEncoding(type, def); + const valueEncoding = this.createValueTypeHeadEncoding(); + const name = `${this.functionName}_head_${def.name}`; const code = [ `func ${name}${IMPLICITS}(`, @@ -423,11 +479,11 @@ export class AbiEncode extends AbiBase { ` alloc_locals;`, ` // Storing pointer to data`, ` let (bytes_offset256) = felt_to_uint256(bytes_offset - element_offset);`, - ` ${this.createValueTypeHeadEncoding()}(bytes_index, bytes_array, 0, bytes_offset256);`, + ` ${valueEncoding.name}(bytes_index, bytes_array, 0, bytes_offset256);`, ` let new_bytes_index = bytes_index + 32;`, ` // Storing the data`, ` let bytes_offset_offset = bytes_offset + ${typeByteSize};`, - ` let (_, new_bytes_offset) = ${inlineEncoding}(`, + ` let (_, new_bytes_offset) = ${inlineEncoding.name}(`, ` bytes_offset,`, ` bytes_offset_offset,`, ` bytes_array,`, @@ -438,34 +494,55 @@ export class AbiEncode extends AbiBase { `}`, ].join('\n'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + const genFuncInfo = { + name, + code, + functionsCalled: [ + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + inlineEncoding, + valueEncoding, + ], + }; + const auxFunc = this.createAuxiliarGeneratedFunction(genFuncInfo); + + this.auxiliarGeneratedFunctions.set(key, auxFunc); + return auxFunc; } - private createStructInlineEncoding(type: UserDefinedType, def: StructDefinition) { + private createStructInlineEncoding( + type: UserDefinedType, + def: StructDefinition, + ): CairoFunctionDefinition { const key = 'struct inline ' + type.pp(); const existing = this.auxiliarGeneratedFunctions.get(key); - if (existing !== undefined) return existing.name; - - const instructions = def.vMembers.map((member, index) => { - const type = generalizeType(safeGetNodeType(member, this.ast.inference))[0]; - const elemWidth = CairoType.fromSol(type, this.ast).width; - const readFunc = this.readMemory(type, 'mem_ptr'); - const encoding = this.generateEncodingCode( - type, - 'bytes_index', - 'bytes_offset', - 'element_offset', - `elem${index}`, - ); - return [ - `// Encoding member ${member.name}`, - `let (elem${index}) = ${readFunc};`, - `${encoding}`, - `let mem_ptr = mem_ptr + ${elemWidth};`, - ].join('\n'); - }); + if (existing !== undefined) return existing; + + const decodingInfo: [string, CairoFunctionDefinition[]][] = def.vMembers.map( + (member, index) => { + const type = generalizeType(safeGetNodeType(member, this.ast.inference))[0]; + const elemWidth = CairoType.fromSol(type, this.ast).width; + const [readElement, readFunc] = this.readMemory(type, 'mem_ptr'); + const [encoding, funcsCalled] = this.generateEncodingCode( + type, + 'bytes_index', + 'bytes_offset', + 'element_offset', + `elem${index}`, + ); + return [ + [ + `// Encoding member ${member.name}`, + `let (elem${index}) = ${readElement};`, + `${encoding}`, + `let mem_ptr = mem_ptr + ${elemWidth};`, + ].join('\n'), + [...funcsCalled, readFunc], + ]; + }, + ); + + const instructions = decodingInfo.map((info) => info[0]); + const functionsCalled = decodingInfo.flatMap((info) => info[1]); const name = `${this.functionName}_inline_struct_${def.name}`; const code = [ @@ -482,29 +559,30 @@ export class AbiEncode extends AbiBase { `}`, ].join('\n'); - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + const genFuncInfo = { name, code, functionsCalled }; + const auxFunc = this.createAuxiliarGeneratedFunction(genFuncInfo); + + this.auxiliarGeneratedFunctions.set(key, auxFunc); + return auxFunc; } - private createStringOrBytesHeadEncoding(): string { + private createStringOrBytesHeadEncoding(): CairoFunctionDefinition { const funcName = 'bytes_to_felt_dynamic_array'; - this.requireImport('warplib.dynamic_arrays_util', funcName); - return funcName; + return this.requireImport('warplib.dynamic_arrays_util', funcName); } - private createValueTypeHeadEncoding(): string { + private createValueTypeHeadEncoding(): CairoFunctionDefinition { const funcName = 'fixed_bytes256_to_felt_dynamic_array'; - this.requireImport('warplib.dynamic_arrays_util', funcName); - return funcName; + return this.requireImport('warplib.dynamic_arrays_util', funcName); } - protected readMemory(type: TypeNode, arg: string) { + protected readMemory(type: TypeNode, arg: string): [string, CairoFunctionDefinition] { + const func = this.memoryRead.getOrCreateFuncDef(type); const cairoType = CairoType.fromSol(type, this.ast); - const funcName = this.memoryRead.getOrCreate(cairoType); const args = cairoType instanceof MemoryLocation ? [arg, isDynamicArray(type) ? uint256(2) : uint256(0)] : [arg]; - return `${funcName}(${args.join(',')})`; + return [`${func.name}(${args.join(',')})`, func]; } } diff --git a/src/cairoUtilFuncGen/abi/abiEncodePacked.ts b/src/cairoUtilFuncGen/abi/abiEncodePacked.ts index 2b89e4df9..581852b0d 100644 --- a/src/cairoUtilFuncGen/abi/abiEncodePacked.ts +++ b/src/cairoUtilFuncGen/abi/abiEncodePacked.ts @@ -1,5 +1,14 @@ -import { ArrayType, BytesType, SourceUnit, StringType, TypeNode } from 'solc-typed-ast'; +import { + ArrayType, + BytesType, + FunctionDefinition, + SourceUnit, + StringType, + TypeNode, +} from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoFunctionDefinition } from '../../ast/cairoNodes'; +import { GeneratedFunctionInfo } from '../base'; import { printTypeNode } from '../../utils/astPrinter'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { TranspileFailedError } from '../../utils/errors'; @@ -30,25 +39,29 @@ export class AbiEncodePacked extends AbiBase { this.memoryRead = memoryRead; } - public getOrCreate(types: TypeNode[]): string { - const key = types.map((t) => t.pp()).join(','); - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - const [params, encodings] = types.reduce( - ([params, encodings], type, index) => { + public getOrCreate(types: TypeNode[]): GeneratedFunctionInfo { + const [params, encodings, functionsCalled] = types.reduce( + ([params, encodings, functionsCalled], type, index) => { const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.Ref); params.push({ name: `param${index}`, type: cairoType.toString() }); - encodings.push(this.generateEncodingCode(type, 'bytes_index', `param${index}`)); - return [params, encodings]; + const [paramEncoding, paramFuncCalls] = this.generateEncodingCode( + type, + 'bytes_index', + `param${index}`, + ); + encodings.push(paramEncoding); + + return [params, encodings, functionsCalled.concat(paramFuncCalls)]; }, - [new Array<{ name: string; type: string }>(), new Array()], + [ + new Array<{ name: string; type: string }>(), + new Array(), + new Array(), + ], ); const cairoParams = params.map((p) => `${p.name} : ${p.type}`).join(', '); - const funcName = `${this.functionName}${this.generatedFunctions.size}`; + const funcName = `${this.functionName}${this.generatedFunctionsDef.size}`; const code = [ `func ${funcName}${IMPLICITS}(${cairoParams}) -> (result_ptr : felt){`, ` alloc_locals;`, @@ -62,24 +75,27 @@ export class AbiEncodePacked extends AbiBase { `}`, ].join('\n'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('starkware.cairo.common.alloc', 'alloc'); - this.requireImport('starkware.cairo.common.cairo_builtins', 'BitwiseBuiltin'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - this.requireImport('warplib.memory', 'wm_new'); - this.requireImport('warplib.dynamic_arrays_util', 'felt_array_to_warp_memory_array'); - - const cairoFunc = { name: funcName, code: code }; - this.generatedFunctions.set(key, cairoFunc); - return cairoFunc.name; + const importedFuncs = [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('starkware.cairo.common.alloc', 'alloc'), + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + this.requireImport('warplib.memory', 'wm_new'), + this.requireImport('warplib.dynamic_arrays_util', 'felt_array_to_warp_memory_array'), + ]; + + return { + name: funcName, + code: code, + functionsCalled: [...importedFuncs, ...functionsCalled], + }; } - public override getOrCreateEncoding(type: TypeNode): string { + public override getOrCreateEncoding(type: TypeNode): CairoFunctionDefinition { const unexpectedType = () => { throw new TranspileFailedError(`Encoding ${printTypeNode(type)} is not supported`); }; - return delegateBasedOnType( + return delegateBasedOnType( type, (type) => this.createArrayInlineEncoding(type), (type) => this.createArrayInlineEncoding(type), @@ -89,34 +105,50 @@ export class AbiEncodePacked extends AbiBase { ); } - private generateEncodingCode(type: TypeNode, newIndexVar: string, varToEncode: string): string { + private generateEncodingCode( + type: TypeNode, + newIndexVar: string, + varToEncode: string, + ): [string, CairoFunctionDefinition[]] { // Cairo address are 251 bits in size but solidity is 160. // It was decided to store them fully before just a part if (isAddressType(type)) { - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - this.requireImport('warplib.dynamic_arrays_util', 'fixed_bytes256_to_felt_dynamic_array'); return [ - `let (${varToEncode}256) = felt_to_uint256(${varToEncode});`, - `fixed_bytes256_to_felt_dynamic_array(bytes_index, bytes_array, 0, ${varToEncode}256);`, - `let ${newIndexVar} = bytes_index + 32;`, - ].join('\n'); + [ + `let (${varToEncode}256) = felt_to_uint256(${varToEncode});`, + `fixed_bytes256_to_felt_dynamic_array(bytes_index, bytes_array, 0, ${varToEncode}256);`, + `let ${newIndexVar} = bytes_index + 32;`, + ].join('\n'), + [ + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + this.requireImport('warplib.dynamic_arrays_util', 'fixed_bytes256_to_felt_dynamic_array'), + ], + ]; } - const funcName = this.getOrCreateEncoding(type); + const func = this.getOrCreateEncoding(type); if (isDynamicArray(type)) { - this.requireImport('warplib.memory', 'wm_dyn_array_length'); - this.requireImport('warplib.maths.utils', 'narrow_safe'); return [ - `let (length256) = wm_dyn_array_length(${varToEncode});`, - `let (length) = narrow_safe(length256);`, - `let (${newIndexVar}) = ${funcName}(bytes_index, bytes_array, 0, length, ${varToEncode});`, - ].join('\n'); + [ + `let (length256) = wm_dyn_array_length(${varToEncode});`, + `let (length) = narrow_safe(length256);`, + `let (${newIndexVar}) = ${func.name}(bytes_index, bytes_array, 0, length, ${varToEncode});`, + ].join('\n'), + [ + this.requireImport('warplib.memory', 'wm_dyn_array_length'), + this.requireImport('warplib.maths.utils', 'narrow_safe'), + func, + ], + ]; } // Type is a static array if (type instanceof ArrayType) { - return `let (${newIndexVar}) = ${funcName}(bytes_index, bytes_array, 0, ${type.size}, ${varToEncode});`; + return [ + `let (${newIndexVar}) = ${func.name}(bytes_index, bytes_array, 0, ${type.size}, ${varToEncode});`, + [func], + ]; } // Type is value type @@ -125,19 +157,24 @@ export class AbiEncodePacked extends AbiBase { if (packedByteSize < 32) args.push(`${packedByteSize}`); return [ - `${funcName}(${args.join(',')});`, - `let ${newIndexVar} = bytes_index + ${packedByteSize};`, - ].join('\n'); + [ + `${func.name}(${args.join(',')});`, + `let ${newIndexVar} = bytes_index + ${packedByteSize};`, + ].join('\n'), + [func], + ]; } /* * Produce inline array encoding for static and dynamic array types */ - private createArrayInlineEncoding(type: ArrayType | BytesType | StringType): string { + private createArrayInlineEncoding( + type: ArrayType | BytesType | StringType, + ): CairoFunctionDefinition { const key = type.pp(); const existing = this.auxiliarGeneratedFunctions.get(key); if (existing !== undefined) { - return existing.name; + return existing; } const elementT = getElementType(type); @@ -152,10 +189,14 @@ export class AbiEncodePacked extends AbiBase { ].join('\n') : `let elem_loc : felt = mem_ptr + ${mul('mem_index', cairoElementT.width)};`; - const readFunc = this.memoryRead.getOrCreate(cairoElementT); - const readCode = `let (elem) = ${readFunc}(elem_loc);`; + const readFunc = this.memoryRead.getOrCreateFuncDef(elementT); + const readCode = `let (elem) = ${readFunc.name}(elem_loc);`; - const encodingCode = this.generateEncodingCode(elementT, 'new_bytes_index', 'elem'); + const [encodingCode, funcCalls] = this.generateEncodingCode( + elementT, + 'new_bytes_index', + 'elem', + ); const name = `${this.functionName}_inline_array${this.auxiliarGeneratedFunctions.size}`; const code = [ @@ -183,19 +224,24 @@ export class AbiEncodePacked extends AbiBase { `}`, ].join('\n'); - if (isDynamicArray(type)) { - this.requireImport('warplib.memory', 'wm_index_dyn'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - } + const importedFuncs = isDynamicArray(type) + ? [ + this.requireImport('warplib.memory', 'wm_index_dyn'), + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + ] + : []; + + const genFuncInfo = { name, code, functionsCalled: [...importedFuncs, ...funcCalls, readFunc] }; + const auxFunc = this.createAuxiliarGeneratedFunction(genFuncInfo); - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + this.auxiliarGeneratedFunctions.set(key, auxFunc); + return auxFunc; } - private createValueTypeHeadEncoding(size: number | bigint): string { + private createValueTypeHeadEncoding(size: number | bigint): CairoFunctionDefinition { const funcName = size === 32 ? 'fixed_bytes256_to_felt_dynamic_array' : `fixed_bytes_to_felt_dynamic_array`; - this.requireImport('warplib.dynamic_arrays_util', funcName); - return funcName; + + return this.requireImport('warplib.dynamic_arrays_util', funcName); } } diff --git a/src/cairoUtilFuncGen/abi/abiEncodeWithSelector.ts b/src/cairoUtilFuncGen/abi/abiEncodeWithSelector.ts index 8f5afea26..adac1aa99 100644 --- a/src/cairoUtilFuncGen/abi/abiEncodeWithSelector.ts +++ b/src/cairoUtilFuncGen/abi/abiEncodeWithSelector.ts @@ -1,10 +1,12 @@ import { FixedBytesType, SourceUnit, TypeNode } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoFunctionDefinition } from '../../export'; import { printTypeNode } from '../../utils/astPrinter'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { TranspileFailedError } from '../../utils/errors'; import { getByteSize } from '../../utils/nodeTypeProcessing'; import { uint256 } from '../../warplib/utils'; +import { GeneratedFunctionInfo } from '../base'; import { AbiEncode } from './abiEncode'; import { AbiBase } from './base'; @@ -20,7 +22,7 @@ export class AbiEncodeWithSelector extends AbiBase { this.abiEncode = abiEncode; } - public override getOrCreate(types: TypeNode[]): string { + public override getOrCreate(types: TypeNode[]): GeneratedFunctionInfo { const selector = types[0]; if (!(selector instanceof FixedBytesType && selector.size === 4)) { throw new TranspileFailedError( @@ -31,26 +33,20 @@ export class AbiEncodeWithSelector extends AbiBase { } types = types.slice(1); - const key = types.map((t) => t.pp()).join(','); - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - const [params, encodings] = types.reduce( - ([params, encodings], type, index) => { + const [params, encodings, functionsCalled] = types.reduce( + ([params, encodings, functionsCalled], type, index) => { const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.Ref); params.push({ name: `param${index}`, type: cairoType.toString() }); - encodings.push( - this.abiEncode.generateEncodingCode( - type, - 'bytes_index', - 'bytes_offset', - '4', - `param${index}`, - ), + const [paramEncodings, paramFuncCalls] = this.abiEncode.generateEncodingCode( + type, + 'bytes_index', + 'bytes_offset', + '4', + `param${index}`, ); - return [params, encodings]; + + encodings.push(paramEncodings); + return [params, encodings, functionsCalled.concat(paramFuncCalls)]; }, [ [{ name: 'selector', type: 'felt' }], @@ -60,6 +56,7 @@ export class AbiEncodeWithSelector extends AbiBase { 'let bytes_index = bytes_index + 4;', ].join('\n'), ], + new Array(), ], ); @@ -69,7 +66,7 @@ export class AbiEncodeWithSelector extends AbiBase { ); const cairoParams = params.map((p) => `${p.name} : ${p.type}`).join(', '); - const funcName = `${this.functionName}${this.generatedFunctions.size}`; + const funcName = `${this.functionName}${this.generatedFunctionsDef.size}`; const code = [ `func ${funcName}${IMPLICITS}(${cairoParams}) -> (result_ptr : felt){`, ` alloc_locals;`, @@ -84,17 +81,21 @@ export class AbiEncodeWithSelector extends AbiBase { `}`, ].join('\n'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('starkware.cairo.common.alloc', 'alloc'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - this.requireImport('warplib.memory', 'wm_new'); - this.requireImport('warplib.dynamic_arrays_util', 'felt_array_to_warp_memory_array'); - this.requireImport('warplib.dynamic_arrays_util', 'fixed_bytes_to_felt_dynamic_array'); - this.requireImport('warplib.keccak', 'warp_keccak'); - - const cairoFunc = { name: funcName, code: code }; - this.generatedFunctions.set(key, cairoFunc); + const importedFuncs = [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('starkware.cairo.common.alloc', 'alloc'), + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + this.requireImport('warplib.memory', 'wm_new'), + this.requireImport('warplib.dynamic_arrays_util', 'felt_array_to_warp_memory_array'), + this.requireImport('warplib.dynamic_arrays_util', 'fixed_bytes_to_felt_dynamic_array'), + this.requireImport('warplib.keccak', 'warp_keccak'), + ]; - return cairoFunc.name; + const funcInfo = { + name: funcName, + code: code, + functionsCalled: [...importedFuncs, ...functionsCalled], + }; + return funcInfo; } } diff --git a/src/cairoUtilFuncGen/abi/abiEncodeWithSignature.ts b/src/cairoUtilFuncGen/abi/abiEncodeWithSignature.ts index 6126296e0..599ae7a47 100644 --- a/src/cairoUtilFuncGen/abi/abiEncodeWithSignature.ts +++ b/src/cairoUtilFuncGen/abi/abiEncodeWithSignature.ts @@ -7,14 +7,16 @@ import { StringType, TypeNode, } from 'solc-typed-ast'; +import { CairoFunctionDefinition } from '../../export'; import { printTypeNode } from '../../utils/astPrinter'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { TranspileFailedError } from '../../utils/errors'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { createBytesTypeName } from '../../utils/nodeTemplates'; import { getByteSize, isValueType, safeGetNodeType } from '../../utils/nodeTypeProcessing'; import { typeNameFromTypeNode } from '../../utils/utils'; import { uint256 } from '../../warplib/utils'; +import { GeneratedFunctionInfo } from '../base'; import { AbiEncodeWithSelector } from './abiEncodeWithSelector'; const IMPLICITS = @@ -27,17 +29,16 @@ export class AbiEncodeWithSignature extends AbiEncodeWithSelector { const exprTypes = expressions.map( (expr) => generalizeType(safeGetNodeType(expr, this.ast.inference))[0], ); - const functionName = this.getOrCreate(exprTypes); + const funcInfo = this.getOrCreate(exprTypes); - const functionStub = createCairoFunctionStub( - functionName, + const functionStub = createCairoGeneratedFunction( + funcInfo, exprTypes.map((exprT, index) => isValueType(exprT) ? [`param${index}`, typeNameFromTypeNode(exprT, this.ast)] : [`param${index}`, typeNameFromTypeNode(exprT, this.ast), DataLocation.Memory], ), [['result', createBytesTypeName(this.ast), DataLocation.Memory]], - ['bitwise_ptr', 'keccak_ptr', 'range_check_ptr', 'warp_memory'], this.ast, sourceUnit ?? this.sourceUnit, ); @@ -45,7 +46,7 @@ export class AbiEncodeWithSignature extends AbiEncodeWithSelector { return createCallToFunction(functionStub, expressions, this.ast); } - public override getOrCreate(types: TypeNode[]): string { + public override getOrCreate(types: TypeNode[]): GeneratedFunctionInfo { const signature = types[0]; if (!(signature instanceof StringType)) { throw new TranspileFailedError( @@ -56,26 +57,20 @@ export class AbiEncodeWithSignature extends AbiEncodeWithSelector { } types = types.slice(1); - const key = types.map((t) => t.pp()).join(','); - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - const [params, encodings] = types.reduce( - ([params, encodings], type, index) => { + const [params, encodings, functionsCalled] = types.reduce( + ([params, encodings, functionsCalled], type, index) => { const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.Ref); params.push({ name: `param${index}`, type: cairoType.toString() }); - encodings.push( - this.abiEncode.generateEncodingCode( - type, - 'bytes_index', - 'bytes_offset', - '4', - `param${index}`, - ), + const [paramEncodings, paramFuncCalls] = this.abiEncode.generateEncodingCode( + type, + 'bytes_index', + 'bytes_offset', + '4', + `param${index}`, ); - return [params, encodings]; + + encodings.push(paramEncodings); + return [params, encodings, functionsCalled.concat(paramFuncCalls)]; }, [ [{ name: 'signature', type: 'felt' }], @@ -93,6 +88,7 @@ export class AbiEncodeWithSignature extends AbiEncodeWithSelector { 'let bytes_index = bytes_index + 4;', ].join('\n'), ], + new Array(), ], ); @@ -102,7 +98,7 @@ export class AbiEncodeWithSignature extends AbiEncodeWithSelector { ); const cairoParams = params.map((p) => `${p.name} : ${p.type}`).join(', '); - const funcName = `${this.functionName}${this.generatedFunctions.size}`; + const funcName = `${this.functionName}${this.generatedFunctionsDef.size}`; const code = [ `func ${funcName}${IMPLICITS}(${cairoParams}) -> (result_ptr : felt){`, ` alloc_locals;`, @@ -117,17 +113,22 @@ export class AbiEncodeWithSignature extends AbiEncodeWithSelector { `}`, ].join('\n'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('starkware.cairo.common.alloc', 'alloc'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - this.requireImport('warplib.memory', 'wm_new'); - this.requireImport('warplib.dynamic_arrays_util', 'felt_array_to_warp_memory_array'); - this.requireImport('warplib.maths.bytes_access', 'byte256_at_index'); - this.requireImport('warplib.keccak', 'warp_keccak'); + const importedFuncs = [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('starkware.cairo.common.alloc', 'alloc'), + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + this.requireImport('warplib.memory', 'wm_new'), + this.requireImport('warplib.dynamic_arrays_util', 'felt_array_to_warp_memory_array'), + this.requireImport('warplib.maths.bytes_access', 'byte256_at_index'), + this.requireImport('warplib.keccak', 'warp_keccak'), + ]; - const cairoFunc = { name: funcName, code: code }; - this.generatedFunctions.set(key, cairoFunc); + const cairoFunc = { + name: funcName, + code: code, + functionsCalled: [...importedFuncs, ...functionsCalled], + }; - return cairoFunc.name; + return cairoFunc; } } diff --git a/src/cairoUtilFuncGen/abi/base.ts b/src/cairoUtilFuncGen/abi/base.ts index 5a283089c..a8cd4eb95 100644 --- a/src/cairoUtilFuncGen/abi/base.ts +++ b/src/cairoUtilFuncGen/abi/base.ts @@ -5,45 +5,57 @@ import { Expression, FunctionCall, generalizeType, - SourceUnit, TypeNode, } from 'solc-typed-ast'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { CairoFunctionDefinition } from '../../ast/cairoNodes'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { createBytesTypeName } from '../../utils/nodeTemplates'; import { isValueType, safeGetNodeType } from '../../utils/nodeTypeProcessing'; import { typeNameFromTypeNode } from '../../utils/utils'; -import { StringIndexedFuncGenWithAuxiliar } from '../base'; +import { GeneratedFunctionInfo, StringIndexedFuncGenWithAuxiliar } from '../base'; export abstract class AbiBase extends StringIndexedFuncGenWithAuxiliar { protected functionName = 'not_implemented'; - public gen(expressions: Expression[], sourceUnit?: SourceUnit): FunctionCall { + public gen(expressions: Expression[]): FunctionCall { const exprTypes = expressions.map( (expr) => generalizeType(safeGetNodeType(expr, this.ast.inference))[0], ); - const functionName = this.getOrCreate(exprTypes); - const functionStub = createCairoFunctionStub( - functionName, - exprTypes.map((exprT, index) => + const generatedFunction = this.getOrCreateFuncDef(exprTypes); + + return createCallToFunction(generatedFunction, expressions, this.ast); + } + + public getOrCreateFuncDef(types: TypeNode[]): CairoFunctionDefinition { + const key = types.map((t) => t.pp()).join(','); + const existing = this.generatedFunctionsDef.get(key); + if (existing !== undefined) { + return existing; + } + + const genFuncInfo = this.getOrCreate(types); + const functionStub = createCairoGeneratedFunction( + genFuncInfo, + types.map((exprT, index) => isValueType(exprT) ? [`param${index}`, typeNameFromTypeNode(exprT, this.ast)] : [`param${index}`, typeNameFromTypeNode(exprT, this.ast), DataLocation.Memory], ), [['result', createBytesTypeName(this.ast), DataLocation.Memory]], - ['bitwise_ptr', 'range_check_ptr', 'warp_memory'], this.ast, - sourceUnit ?? this.sourceUnit, + this.sourceUnit, ); - return createCallToFunction(functionStub, expressions, this.ast); + this.generatedFunctionsDef.set(key, functionStub); + return functionStub; } - public getOrCreate(_types: TypeNode[]): string { + public getOrCreate(_types: TypeNode[]): GeneratedFunctionInfo { throw new Error('Method not implemented.'); } - public getOrCreateEncoding(_type: TypeNode): string { + public getOrCreateEncoding(_type: TypeNode): CairoFunctionDefinition { throw new Error('Method not implemented.'); } } diff --git a/src/cairoUtilFuncGen/abi/indexEncode.ts b/src/cairoUtilFuncGen/abi/indexEncode.ts index 6542da301..6fcf360d6 100644 --- a/src/cairoUtilFuncGen/abi/indexEncode.ts +++ b/src/cairoUtilFuncGen/abi/indexEncode.ts @@ -8,6 +8,8 @@ import { UserDefinedType, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoImportFunctionDefinition } from '../../ast/cairoNodes'; +import { CairoFunctionDefinition } from '../../export'; import { printTypeNode } from '../../utils/astPrinter'; import { CairoType, MemoryLocation, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { TranspileFailedError } from '../../utils/errors'; @@ -21,7 +23,7 @@ import { safeGetNodeType, } from '../../utils/nodeTypeProcessing'; import { uint256 } from '../../warplib/utils'; -import { delegateBasedOnType, mul } from '../base'; +import { delegateBasedOnType, GeneratedFunctionInfo, mul } from '../base'; import { MemoryReadGen } from '../memory/memoryRead'; import { AbiBase, removeSizeInfo } from './base'; @@ -42,28 +44,29 @@ export class IndexEncode extends AbiBase { this.memoryRead = memoryRead; } - public getOrCreate(types: TypeNode[]): string { - const key = types.map((t) => t.pp()).join(','); - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - const [params, encodings] = types.reduce( - ([params, encodings], type, index) => { + public getOrCreate(types: TypeNode[]): GeneratedFunctionInfo { + const [params, encodings, functionsCalled] = types.reduce( + ([params, encodings, functionsCalled], type, index) => { const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.Ref); params.push({ name: `param${index}`, type: cairoType.toString() }); - encodings.push( - // padding is not required for strings and bytes - this.generateEncodingCode(type, 'bytes_index', `param${index}`, false), + const [paramEncoding, paramFuncCalls] = this.generateEncodingCode( + type, + 'bytes_index', + `param${index}`, + false, ); - return [params, encodings]; + encodings.push(paramEncoding); + return [params, encodings, functionsCalled.concat(paramFuncCalls)]; }, - [new Array<{ name: string; type: string }>(), new Array()], + [ + new Array<{ name: string; type: string }>(), + new Array(), + new Array(), + ], ); const cairoParams = params.map((p) => `${p.name} : ${p.type}`).join(', '); - const funcName = `${this.functionName}${this.generatedFunctions.size}`; + const funcName = `${this.functionName}${this.generatedFunctionsDef.size}`; const code = [ `func ${funcName}${IMPLICITS}(${cairoParams}) -> (result_ptr : felt){`, ` alloc_locals;`, @@ -77,16 +80,21 @@ export class IndexEncode extends AbiBase { `}`, ].join('\n'); - this.requireImport('starkware.cairo.common.alloc', 'alloc'); - this.requireImport('starkware.cairo.common.cairo_builtins', 'BitwiseBuiltin'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - this.requireImport('warplib.memory', 'wm_new'); - this.requireImport('warplib.dynamic_arrays_util', 'felt_array_to_warp_memory_array'); - - const cairoFunc = { name: funcName, code: code }; - this.generatedFunctions.set(key, cairoFunc); - return cairoFunc.name; + const importedFuncs = [ + this.requireImport('starkware.cairo.common.alloc', 'alloc'), + this.requireImport('starkware.cairo.common.cairo_builtins', 'BitwiseBuiltin'), + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + this.requireImport('warplib.memory', 'wm_new'), + this.requireImport('warplib.dynamic_arrays_util', 'felt_array_to_warp_memory_array'), + ]; + + const cairoFunc = { + name: funcName, + code: code, + functionsCalled: [...importedFuncs, ...functionsCalled], + }; + return cairoFunc; } /** @@ -94,12 +102,12 @@ export class IndexEncode extends AbiBase { * @param type type to encode * @returns the name of the generated function */ - public getOrCreateEncoding(type: TypeNode, padding = true): string { + public getOrCreateEncoding(type: TypeNode, padding = true): CairoFunctionDefinition { const unexpectedType = () => { throw new TranspileFailedError(`Encoding ${printTypeNode(type)} is not supported yet`); }; - return delegateBasedOnType( + return delegateBasedOnType( type, (type) => type instanceof ArrayType @@ -133,55 +141,62 @@ export class IndexEncode extends AbiBase { newIndexVar: string, varToEncode: string, padding = true, - ): string { - const funcName = this.getOrCreateEncoding(type, padding); + ): [string, CairoFunctionDefinition[]] { + const func = this.getOrCreateEncoding(type, padding); if (isDynamicallySized(type, this.ast.inference) || isStruct(type)) { return [ - `let (${newIndexVar}) = ${funcName}(`, - ` bytes_index,`, - ` bytes_array,`, - ` ${varToEncode}`, - `);`, - ].join('\n'); + [ + `let (${newIndexVar}) = ${func.name}(`, + ` bytes_index,`, + ` bytes_array,`, + ` ${varToEncode}`, + `);`, + ].join('\n'), + [func], + ]; } // Static array with known compile time size if (type instanceof ArrayType) { assert(type.size !== undefined); return [ - `let (${newIndexVar}) = ${funcName}(`, - ` bytes_index,`, - ` bytes_array,`, - ` 0,`, - ` ${type.size},`, - ` ${varToEncode},`, - `);`, - ].join('\n'); + [ + `let (${newIndexVar}) = ${func.name}(`, + ` bytes_index,`, + ` bytes_array,`, + ` 0,`, + ` ${type.size},`, + ` ${varToEncode},`, + `);`, + ].join('\n'), + [func], + ]; } // Is value type const size = getPackedByteSize(type, this.ast.inference); const instructions: string[] = []; + const importedFunc = []; // packed size of addresses is 32 bytes, but they are treated as felts, // so they should be converted to Uint256 accordingly if (size < 32 || isAddressType(type)) { - this.requireImport(`warplib.maths.utils`, 'felt_to_uint256'); instructions.push(`let (${varToEncode}256) = felt_to_uint256(${varToEncode});`); + importedFunc.push(this.requireImport(`warplib.maths.utils`, 'felt_to_uint256')); varToEncode = `${varToEncode}256`; } instructions.push( ...[ - `${funcName}(bytes_index, bytes_array, 0, ${varToEncode});`, + `${func.name}(bytes_index, bytes_array, 0, ${varToEncode});`, `let ${newIndexVar} = bytes_index + 32;`, ], ); - return instructions.join('\n'); + return [instructions.join('\n'), importedFunc]; } - private createDynamicArrayHeadEncoding(type: ArrayType): string { + private createDynamicArrayHeadEncoding(type: ArrayType): CairoFunctionDefinition { const key = 'head ' + type.pp(); const existing = this.auxiliarGeneratedFunctions.get(key); - if (existing !== undefined) return existing.name; + if (existing !== undefined) return existing; const tailEncoding = this.createDynamicArrayTailEncoding(type); const name = `${this.functionName}_head_dynamic_array_spl${this.auxiliarGeneratedFunctions.size}`; @@ -195,7 +210,7 @@ export class IndexEncode extends AbiBase { ` let (length256) = wm_dyn_array_length(mem_ptr);`, ` let (length) = narrow_safe(length256);`, ` // Storing the element values encoding`, - ` let (new_index) = ${tailEncoding}(`, + ` let (new_index) = ${tailEncoding.name}(`, ` bytes_index,`, ` bytes_array,`, ` 0,`, @@ -208,24 +223,33 @@ export class IndexEncode extends AbiBase { `}`, ].join('\n'); - this.requireImport('warplib.memory', 'wm_dyn_array_length'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - this.requireImport('warplib.maths.utils', 'narrow_safe'); + const importedFuncs = [ + this.requireImport('warplib.memory', 'wm_dyn_array_length'), + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + this.requireImport('warplib.maths.utils', 'narrow_safe'), + ]; + + const funcInfo = { name, code, functionsCalled: [...importedFuncs, tailEncoding] }; + const auxFunc = this.createAuxiliarGeneratedFunction(funcInfo); - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + this.auxiliarGeneratedFunctions.set(key, auxFunc); + return auxFunc; } - private createDynamicArrayTailEncoding(type: ArrayType): string { + private createDynamicArrayTailEncoding(type: ArrayType): CairoFunctionDefinition { const key = 'tail ' + type.pp(); const existing = this.auxiliarGeneratedFunctions.get(key); - if (existing !== undefined) return existing.name; + if (existing !== undefined) return existing; const elementT = getElementType(type); const elemntTSize = CairoType.fromSol(elementT, this.ast).width; - const readElement = this.readMemory(elementT, 'elem_loc'); - const headEncodingCode = this.generateEncodingCode(elementT, 'bytes_index', 'elem'); + const [readElement, readFunc] = this.readMemory(elementT, 'elem_loc'); + const [headEncodingCode, functionsCalled] = this.generateEncodingCode( + elementT, + 'bytes_index', + 'elem', + ); const name = `${this.functionName}_tail_dynamic_array_spl${this.auxiliarGeneratedFunctions.size}`; const code = [ `func ${name}${IMPLICITS}(`, @@ -247,18 +271,27 @@ export class IndexEncode extends AbiBase { `}`, ].join('\n'); - this.requireImport('warplib.memory', 'wm_index_dyn'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); + const importedFuncs = [ + this.requireImport('warplib.memory', 'wm_index_dyn'), + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + ]; + + const funcInfo = { + name, + code, + functionsCalled: [...importedFuncs, ...functionsCalled, readFunc], + }; + const auxFunc = this.createAuxiliarGeneratedFunction(funcInfo); - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + this.auxiliarGeneratedFunctions.set(key, auxFunc); + return auxFunc; } - private createStaticArrayHeadEncoding(type: ArrayType): string { + private createStaticArrayHeadEncoding(type: ArrayType): CairoFunctionDefinition { assert(type.size !== undefined); const key = 'head ' + type.pp(); const existing = this.auxiliarGeneratedFunctions.get(key); - if (existing !== undefined) return existing.name; + if (existing !== undefined) return existing; const inlineEncoding = this.createArrayInlineEncoding(type); @@ -285,22 +318,29 @@ export class IndexEncode extends AbiBase { `}`, ].join('\n'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); + const importedFunc = this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + const funcInfo = { name, code, functionsCalled: [importedFunc] }; + const auxFunc = this.createAuxiliarGeneratedFunction(funcInfo); + + this.auxiliarGeneratedFunctions.set(key, auxFunc); + return auxFunc; } - private createArrayInlineEncoding(type: ArrayType) { + private createArrayInlineEncoding(type: ArrayType): CairoFunctionDefinition { const key = 'inline ' + removeSizeInfo(type); const existing = this.auxiliarGeneratedFunctions.get(key); - if (existing !== undefined) return existing.name; + if (existing !== undefined) return existing; const elementTWidth = CairoType.fromSol(type.elementT, this.ast).width; - const readElement = this.readMemory(type.elementT, 'elem_loc'); + const [readElement, readFunc] = this.readMemory(type.elementT, 'elem_loc'); - const headEncodingCode = this.generateEncodingCode(type.elementT, 'bytes_index', 'elem'); + const [headEncodingCode, functionsCalled] = this.generateEncodingCode( + type.elementT, + 'bytes_index', + 'elem', + ); const name = `${this.functionName}_inline_array_spl${this.auxiliarGeneratedFunctions.size}`; const code = [ @@ -328,14 +368,21 @@ export class IndexEncode extends AbiBase { `}`, ].join('\n'); - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + const funcInfo = { name, code, functionsCalled: [...functionsCalled, readFunc] }; + + const auxFunc = this.createAuxiliarGeneratedFunction(funcInfo); + + this.auxiliarGeneratedFunctions.set(key, auxFunc); + return auxFunc; } - private createStructHeadEncoding(type: UserDefinedType, def: StructDefinition) { + private createStructHeadEncoding( + type: UserDefinedType, + def: StructDefinition, + ): CairoFunctionDefinition { const key = 'struct head ' + type.pp(); const existing = this.auxiliarGeneratedFunctions.get(key); - if (existing !== undefined) return existing.name; + if (existing !== undefined) return existing; const inlineEncoding = this.createStructInlineEncoding(type, def); @@ -357,28 +404,53 @@ export class IndexEncode extends AbiBase { `}`, ].join('\n'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + const importedFunction = this.requireImport('warplib.maths.utils', 'felt_to_uint256'); + + const funcInfo = { name, code, functionsCalled: [importedFunction, inlineEncoding] }; + const auxFunc = this.createAuxiliarGeneratedFunction(funcInfo); + + this.auxiliarGeneratedFunctions.set(key, auxFunc); + + return auxFunc; } - private createStructInlineEncoding(type: UserDefinedType, def: StructDefinition) { + private createStructInlineEncoding( + type: UserDefinedType, + def: StructDefinition, + ): CairoFunctionDefinition { const key = 'struct inline ' + type.pp(); const existing = this.auxiliarGeneratedFunctions.get(key); - if (existing !== undefined) return existing.name; + if (existing !== undefined) return existing; + + const encodingInfo: [string, CairoFunctionDefinition[]][] = def.vMembers.map( + (member, index) => { + const type = generalizeType(safeGetNodeType(member, this.ast.inference))[0]; + const elemWidth = CairoType.fromSol(type, this.ast).width; + const [readElement, readFunc] = this.readMemory(type, 'mem_ptr'); + const [encoding, functionsCalled] = this.generateEncodingCode( + type, + 'bytes_index', + `elem${index}`, + ); + return [ + [ + `// Encoding member ${member.name}`, + `let (elem${index}) = ${readElement};`, + `${encoding}`, + `let mem_ptr = mem_ptr + ${elemWidth};`, + ].join('\n'), + [...functionsCalled, readFunc], + ]; + }, + ); - const instructions = def.vMembers.map((member, index) => { - const type = generalizeType(safeGetNodeType(member, this.ast.inference))[0]; - const elemWidth = CairoType.fromSol(type, this.ast).width; - const readFunc = this.readMemory(type, 'mem_ptr'); - const encoding = this.generateEncodingCode(type, 'bytes_index', `elem${index}`); - return [ - `// Encoding member ${member.name}`, - `let (elem${index}) = ${readFunc};`, - `${encoding}`, - `let mem_ptr = mem_ptr + ${elemWidth};`, - ].join('\n'); - }); + const [instructions, functionsCalled] = encodingInfo.reduce( + ([instructions, functionsCalled], [currentInstruction, currentFuncs]) => [ + [...instructions, currentInstruction], + [...functionsCalled, ...currentFuncs], + ], + [new Array(), new Array()], + ); const name = `${this.functionName}_inline_struct_spl_${def.name}`; const code = [ @@ -393,35 +465,35 @@ export class IndexEncode extends AbiBase { `}`, ].join('\n'); - this.auxiliarGeneratedFunctions.set(key, { name, code }); - return name; + const funcInfo = { name, code, functionsCalled }; + const auxFunc = this.createAuxiliarGeneratedFunction(funcInfo); + + this.auxiliarGeneratedFunctions.set(key, auxFunc); + return auxFunc; } - private createStringOrBytesHeadEncoding(): string { + private createStringOrBytesHeadEncoding(): CairoImportFunctionDefinition { const funcName = 'bytes_to_felt_dynamic_array_spl'; - this.requireImport('warplib.dynamic_arrays_util', funcName); - return funcName; + return this.requireImport('warplib.dynamic_arrays_util', funcName); } - private createStringOrBytesHeadEncodingWithoutPadding(): string { + private createStringOrBytesHeadEncodingWithoutPadding(): CairoImportFunctionDefinition { const funcName = 'bytes_to_felt_dynamic_array_spl_without_padding'; - this.requireImport('warplib.dynamic_arrays_util', funcName); - return funcName; + return this.requireImport('warplib.dynamic_arrays_util', funcName); } - private createValueTypeHeadEncoding(): string { + private createValueTypeHeadEncoding(): CairoImportFunctionDefinition { const funcName = 'fixed_bytes256_to_felt_dynamic_array_spl'; - this.requireImport('warplib.dynamic_arrays_util', funcName); - return funcName; + return this.requireImport('warplib.dynamic_arrays_util', funcName); } - protected readMemory(type: TypeNode, arg: string) { + protected readMemory(type: TypeNode, arg: string): [string, CairoFunctionDefinition] { + const func = this.memoryRead.getOrCreateFuncDef(type); const cairoType = CairoType.fromSol(type, this.ast); - const funcName = this.memoryRead.getOrCreate(cairoType); const args = cairoType instanceof MemoryLocation ? [arg, isDynamicArray(type) ? uint256(2) : uint256(0)] : [arg]; - return `${funcName}(${args.join(',')})`; + return [`${func.name}(${args.join(',')})`, func]; } } diff --git a/src/cairoUtilFuncGen/base.ts b/src/cairoUtilFuncGen/base.ts index fb9af808b..19452c0fa 100644 --- a/src/cairoUtilFuncGen/base.ts +++ b/src/cairoUtilFuncGen/base.ts @@ -3,8 +3,8 @@ import { ArrayType, BytesType, DataLocation, + FunctionDefinition, generalizeType, - IntType, MappingType, PointerType, SourceUnit, @@ -14,17 +14,21 @@ import { UserDefinedType, } from 'solc-typed-ast'; import { AST } from '../ast/ast'; +import { CairoFunctionDefinition, CairoImportFunctionDefinition } from '../ast/cairoNodes'; +import { createCairoGeneratedFunction, ParameterInfo } from '../utils/functionGeneration'; import { TranspileFailedError } from '../utils/errors'; +import { createImport } from '../utils/importFuncGenerator'; import { isDynamicArray, isReferenceType } from '../utils/nodeTypeProcessing'; -export type CairoFunction = { +export type CairoStructDef = { name: string; code: string; }; -export type CairoStructDef = { +export type GeneratedFunctionInfo = { name: string; code: string; + functionsCalled: FunctionDefinition[]; }; /* @@ -38,29 +42,19 @@ export abstract class CairoUtilFuncGenBase { protected ast: AST; protected imports: Map> = new Map(); protected sourceUnit: SourceUnit; + constructor(ast: AST, sourceUnit: SourceUnit) { this.ast = ast; this.sourceUnit = sourceUnit; } - // import file -> import symbols - getImports(): Map> { - return this.imports; - } - - // Concatenate all the generated cairo code into a single string - abstract getGeneratedCode(): string; - - protected requireImport(location: string, name: string): void { - const existingImports = this.imports.get(location) ?? new Set(); - existingImports.add(name); - this.imports.set(location, existingImports); - } - - protected checkForImport(type: TypeNode): void { - if (type instanceof IntType && type.nBits === 256) { - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - } + protected requireImport( + location: string, + name: string, + inputs?: ParameterInfo[], + outputs?: ParameterInfo[], + ): CairoImportFunctionDefinition { + return createImport(location, name, this.sourceUnit, this.ast, inputs, outputs); } } @@ -68,21 +62,15 @@ export abstract class CairoUtilFuncGenBase { Most subclasses of CairoUtilFuncGenBase index their CairoFunctions off a single string, usually the cairo type of the input that the function's code depends on */ -export class StringIndexedFuncGen extends CairoUtilFuncGenBase { - protected generatedFunctions: Map = new Map(); - - getGeneratedCode(): string { - return [...this.generatedFunctions.values()].map((func) => func.code).join('\n\n'); - } +export abstract class StringIndexedFuncGen extends CairoUtilFuncGenBase { + protected generatedFunctionsDef: Map = new Map(); } -export class StringIndexedFuncGenWithAuxiliar extends StringIndexedFuncGen { - protected auxiliarGeneratedFunctions: Map = new Map(); +export abstract class StringIndexedFuncGenWithAuxiliar extends StringIndexedFuncGen { + protected auxiliarGeneratedFunctions: Map = new Map(); - getGeneratedCode(): string { - return [...this.auxiliarGeneratedFunctions.values(), ...this.generatedFunctions.values()] - .map((func) => func.code) - .join('\n\n'); + protected createAuxiliarGeneratedFunction(genFuncInfo: GeneratedFunctionInfo) { + return createCairoGeneratedFunction(genFuncInfo, [], [], this.ast, this.sourceUnit); } } diff --git a/src/cairoUtilFuncGen/calldata/calldataToMemory.ts b/src/cairoUtilFuncGen/calldata/calldataToMemory.ts index 7e478465a..1b32fdbd9 100644 --- a/src/cairoUtilFuncGen/calldata/calldataToMemory.ts +++ b/src/cairoUtilFuncGen/calldata/calldataToMemory.ts @@ -3,7 +3,6 @@ import { DataLocation, ArrayType, Expression, - ASTNode, generalizeType, FunctionStateMutability, TypeNode, @@ -13,9 +12,9 @@ import { StringType, } from 'solc-typed-ast'; import assert from 'assert'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { CairoDynArray, CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; -import { add, CairoFunction, delegateBasedOnType, StringIndexedFuncGen } from '../base'; +import { add, delegateBasedOnType, GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; import { uint256 } from '../../warplib/utils'; import { NotSupportedYetError } from '../../utils/errors'; import { printTypeNode } from '../../utils/astPrinter'; @@ -26,61 +25,59 @@ import { isReferenceType, safeGetNodeType, } from '../../utils/nodeTypeProcessing'; +import { CairoFunctionDefinition } from '../../export'; + +const IMPLICITS = + '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; export class CallDataToMemoryGen extends StringIndexedFuncGen { - gen(node: Expression, nodeInSourceUnit?: ASTNode): FunctionCall { + public gen(node: Expression): FunctionCall { const type = generalizeType(safeGetNodeType(node, this.ast.inference))[0]; + const funcDef = this.getOrCreateFuncDef(type); + return createCallToFunction(funcDef, [node], this.ast); + } - const name = this.getOrCreate(type); - const functionStub = createCairoFunctionStub( - name, + public getOrCreateFuncDef(type: TypeNode) { + const key = type.pp(); + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; + } + + const funcInfo = this.getOrCreate(type); + const funcDef = createCairoGeneratedFunction( + funcInfo, [['calldata', typeNameFromTypeNode(type, this.ast), DataLocation.CallData]], [['mem_loc', typeNameFromTypeNode(type, this.ast), DataLocation.Memory]], - ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr', 'warp_memory'], this.ast, - nodeInSourceUnit ?? node, + this.sourceUnit, { mutability: FunctionStateMutability.Pure }, ); - return createCallToFunction(functionStub, [node], this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - private getOrCreate(type: TypeNode): string { - const key = type.pp(); - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - const funcName = `cd_to_memory${this.generatedFunctions.size}`; - // Set an empty entry so recursive function generation doesn't clash - this.generatedFunctions.set(key, { name: funcName, code: '' }); - + private getOrCreate(type: TypeNode): GeneratedFunctionInfo { const unexpectedTypeFunc = () => { throw new NotSupportedYetError( `Copying ${printTypeNode(type)} from calldata to memory not implemented yet`, ); }; - const code = delegateBasedOnType( + const funcInfo = delegateBasedOnType( type, - (type) => this.createDynamicArrayCopyFunction(funcName, type), - (type) => this.createStaticArrayCopyFunction(funcName, type), - (type) => this.createStructCopyFunction(funcName, type), + (type) => this.createDynamicArrayCopyFunction(type), + (type) => this.createStaticArrayCopyFunction(type), + (type, def) => this.createStructCopyFunction(type, def), unexpectedTypeFunc, unexpectedTypeFunc, ); - - this.generatedFunctions.set(key, code); - return code.name; + return funcInfo; } - createDynamicArrayCopyFunction( - funcName: string, + + private createDynamicArrayCopyFunction( type: ArrayType | BytesType | StringType, - ): CairoFunction { - this.requireImport('starkware.cairo.common.dict', 'dict_write'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('warplib.memory', 'wm_new'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); + ): GeneratedFunctionInfo { const elementT = getElementType(type); const size = getSize(type); @@ -90,27 +87,31 @@ export class CallDataToMemoryGen extends StringIndexedFuncGen { const memoryElementWidth = CairoType.fromSol(elementT, this.ast).width; let copyCode: string; - + let auxFunc: CairoFunctionDefinition; if (isReferenceType(elementT)) { - const recursiveFunc = this.getOrCreate(elementT); + const recursiveFunc = this.getOrCreateFuncDef(elementT); copyCode = [ `let cdElem = calldata[0];`, - `let (mElem) = ${recursiveFunc}(cdElem);`, + `let (mElem) = ${recursiveFunc.name}(cdElem);`, `dict_write{dict_ptr=warp_memory}(mem_start, mElem);`, ].join('\n'); + auxFunc = recursiveFunc; } else if (memoryElementWidth === 2) { copyCode = [ `dict_write{dict_ptr=warp_memory}(mem_start, calldata[0].low);`, `dict_write{dict_ptr=warp_memory}(mem_start+1, calldata[0].high);`, ].join('\n'); + auxFunc = this.requireImport('starkware.cairo.common.dict', 'dict_write'); } else { copyCode = `dict_write{dict_ptr=warp_memory}(mem_start, calldata[0]);`; + auxFunc = this.requireImport('starkware.cairo.common.dict', 'dict_write'); } + const funcName = `cd_to_memory_dynamic_array${this.generatedFunctionsDef.size}`; return { name: funcName, code: [ - `func ${funcName}_elem${implicits}(calldata: ${callDataType.vPtr}, mem_start: felt, length: felt){`, + `func ${funcName}_elem${IMPLICITS}(calldata: ${callDataType.vPtr}, mem_start: felt, length: felt){`, ` alloc_locals;`, ` if (length == 0){`, ` return ();`, @@ -118,7 +119,7 @@ export class CallDataToMemoryGen extends StringIndexedFuncGen { copyCode, ` return ${funcName}_elem(calldata + ${callDataType.vPtr.to.width}, mem_start + ${memoryElementWidth}, length - 1);`, `}`, - `func ${funcName}${implicits}(calldata : ${callDataType}) -> (mem_loc: felt){`, + `func ${funcName}${IMPLICITS}(calldata : ${callDataType}) -> (mem_loc: felt){`, ` alloc_locals;`, ` let (len256) = felt_to_uint256(calldata.len);`, ` let (mem_start) = wm_new(len256, ${uint256(memoryElementWidth)});`, @@ -126,31 +127,36 @@ export class CallDataToMemoryGen extends StringIndexedFuncGen { ` return (mem_start,);`, `}`, ].join('\n'), + functionsCalled: [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('warplib.memory', 'wm_new'), + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + auxFunc, + ], }; } - createStaticArrayCopyFunction(funcName: string, type: ArrayType): CairoFunction { - this.requireImport('starkware.cairo.common.dict', 'dict_write'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('warplib.memory', 'wm_alloc'); + private createStaticArrayCopyFunction(type: ArrayType): GeneratedFunctionInfo { assert(type.size !== undefined); const callDataType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); const memoryType = CairoType.fromSol(type, this.ast, TypeConversionContext.MemoryAllocation); const memoryElementWidth = CairoType.fromSol(type.elementT, this.ast).width; const memoryOffsetMultiplier = memoryElementWidth === 1 ? '' : `* ${memoryElementWidth}`; - let copyCode: (index: number) => string; - const loc = (index: number) => index === 0 ? `mem_start` : `mem_start + ${index}${memoryOffsetMultiplier}`; + + let copyCode: (index: number) => string; + let funcCalls: CairoFunctionDefinition[] = []; if (isReferenceType(type.elementT)) { - const recursiveFunc = this.getOrCreate(type.elementT); + const recursiveFunc = this.getOrCreateFuncDef(type.elementT); copyCode = (index) => [ `let cdElem = calldata[${index}];`, - `let (mElem) = ${recursiveFunc}(cdElem);`, + `let (mElem) = ${recursiveFunc.name}(cdElem);`, `dict_write{dict_ptr=warp_memory}(${loc(index)}, mElem);`, ].join('\n'); + funcCalls = [recursiveFunc]; } else if (memoryElementWidth === 2) { copyCode = (index) => [ @@ -161,69 +167,82 @@ export class CallDataToMemoryGen extends StringIndexedFuncGen { copyCode = (index) => `dict_write{dict_ptr=warp_memory}(${loc(index)}, calldata[${index}]);`; } + const funcName = `cd_to_memory_static_array${this.generatedFunctionsDef.size}`; return { name: funcName, code: [ - `func ${funcName}${implicits}(calldata : ${callDataType}) -> (mem_loc: felt){`, + `func ${funcName}${IMPLICITS}(calldata : ${callDataType}) -> (mem_loc: felt){`, ` alloc_locals;`, ` let (mem_start) = wm_alloc(${uint256(memoryType.width)});`, ...mapRange(narrowBigIntSafe(type.size), (n) => copyCode(n)), ` return (mem_start,);`, `}`, ].join('\n'), + functionsCalled: [ + this.requireImport('warplib.memory', 'wm_alloc'), + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('starkware.cairo.common.dict', 'dict_write'), + ...funcCalls, + ], }; } - createStructCopyFunction(funcName: string, type: UserDefinedType): CairoFunction { - this.requireImport('starkware.cairo.common.dict', 'dict_write'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('warplib.memory', 'wm_alloc'); - const callDataType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); - const memoryType = CairoType.fromSol(type, this.ast, TypeConversionContext.MemoryAllocation); - const structDef = type.definition; - assert(structDef instanceof StructDefinition); + private createStructCopyFunction( + type: UserDefinedType, + structDef: StructDefinition, + ): GeneratedFunctionInfo { + const calldataType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); + const memoryType = CairoType.fromSol(type, this.ast, TypeConversionContext.MemoryAllocation); - let memOffset = 0; - return { - name: funcName, - code: [ - `func ${funcName}${implicits}(calldata : ${callDataType}) -> (mem_loc: felt){`, - ` alloc_locals;`, - ` let (mem_start) = wm_alloc(${uint256(memoryType.width)});`, - ...structDef.vMembers.map((decl): string => { - const memberType = safeGetNodeType(decl, this.ast.inference); - if (isReferenceType(memberType)) { - const recursiveFunc = this.getOrCreate(memberType); - const code = [ - `let (m${memOffset}) = ${recursiveFunc}(calldata.${decl.name});`, - `dict_write{dict_ptr=warp_memory}(${add('mem_start', memOffset)}, m${memOffset});`, - ].join('\n'); - memOffset++; - return code; - } else { - const memberWidth = CairoType.fromSol(memberType, this.ast).width; - if (memberWidth === 2) { - return [ - `dict_write{dict_ptr=warp_memory}(${add('mem_start', memOffset++)}, calldata.${ + const [copyCode, funcCalls] = structDef.vMembers.reduce( + ([copyCode, funcCalls, offset], decl) => { + const type = safeGetNodeType(decl, this.ast.inference); + + if (isReferenceType(type)) { + const recursiveFunc = this.getOrCreateFuncDef(type); + const code = [ + `let (member_${decl.name}) = ${recursiveFunc.name}(calldata.${decl.name});`, + `dict_write{dict_ptr=warp_memory}(${add('mem_start', offset)}, member_${decl.name});`, + ].join('\n'); + return [[...copyCode, code], [...funcCalls, recursiveFunc], offset + 1]; + } + + const memberWidth = CairoType.fromSol(type, this.ast).width; + const code = + memberWidth === 2 + ? [ + `dict_write{dict_ptr=warp_memory}(${add('mem_start', offset)}, calldata.${ decl.name }.low);`, - `dict_write{dict_ptr=warp_memory}(${add('mem_start', memOffset++)}, calldata.${ + `dict_write{dict_ptr=warp_memory}(${add('mem_start', offset + 1)}, calldata.${ decl.name }.high);`, - ].join('\n'); - } else { - return `dict_write{dict_ptr=warp_memory}(${add('mem_start', memOffset++)}, calldata.${ + ].join('\n') + : `dict_write{dict_ptr=warp_memory}(${add('mem_start', offset)}, calldata.${ decl.name });`; - } - } - }), + return [[...copyCode, code], funcCalls, offset + memberWidth]; + }, + [new Array(), new Array(), 0], + ); + + const funcName = `cd_to_memory_struct_${structDef.name}`; + return { + name: funcName, + code: [ + `func ${funcName}${IMPLICITS}(calldata : ${calldataType}) -> (mem_loc: felt){`, + ` alloc_locals;`, + ` let (mem_start) = wm_alloc(${uint256(memoryType.width)});`, + ...copyCode, ` return (mem_start,);`, `}`, ].join('\n'), + functionsCalled: [ + this.requireImport('starkware.cairo.common.dict', 'dict_write'), + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('warplib.memory', 'wm_alloc'), + ...funcCalls, + ], }; } } - -const implicits = - '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; diff --git a/src/cairoUtilFuncGen/calldata/calldataToStorage.ts b/src/cairoUtilFuncGen/calldata/calldataToStorage.ts index 2b681d35b..4e849e6f1 100644 --- a/src/cairoUtilFuncGen/calldata/calldataToStorage.ts +++ b/src/cairoUtilFuncGen/calldata/calldataToStorage.ts @@ -1,10 +1,10 @@ import assert from 'assert'; import { ArrayType, - ASTNode, BytesType, DataLocation, Expression, + FunctionCall, generalizeType, SourceUnit, StringType, @@ -13,18 +13,21 @@ import { UserDefinedType, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoFunctionDefinition } from '../../export'; import { printTypeNode } from '../../utils/astPrinter'; import { CairoDynArray, CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { NotSupportedYetError } from '../../utils/errors'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { getElementType, safeGetNodeType } from '../../utils/nodeTypeProcessing'; import { mapRange, narrowBigIntSafe, typeNameFromTypeNode } from '../../utils/utils'; -import { add, delegateBasedOnType, StringIndexedFuncGen } from '../base'; +import { add, delegateBasedOnType, GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; import { DynArrayGen } from '../storage/dynArray'; import { StorageWriteGen } from '../storage/storageWrite'; +const IMPLICITS = '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}'; + export class CalldataToStorageGen extends StringIndexedFuncGen { - constructor( + public constructor( private dynArrayGen: DynArrayGen, private storageWriteGen: StorageWriteGen, ast: AST, @@ -33,143 +36,141 @@ export class CalldataToStorageGen extends StringIndexedFuncGen { super(ast, sourceUnit); } - gen(storageLocation: Expression, calldataLocation: Expression, nodeInSourceUnit?: ASTNode) { + public gen(storageLocation: Expression, calldataLocation: Expression): FunctionCall { const storageType = generalizeType(safeGetNodeType(storageLocation, this.ast.inference))[0]; const calldataType = generalizeType(safeGetNodeType(calldataLocation, this.ast.inference))[0]; + const funcDef = this.getOrCreateFuncDef(calldataType, storageType); + + return createCallToFunction(funcDef, [storageLocation, calldataLocation], this.ast); + } - const name = this.getOrCreate(calldataType); - const functionStub = createCairoFunctionStub( - name, + public getOrCreateFuncDef(calldataType: TypeNode, storageType: TypeNode) { + const key = `calldataToStorage(${calldataType.pp()},${storageType.pp()})`; + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; + } + + const funcInfo = this.getOrCreate(calldataType); + const funcDef = createCairoGeneratedFunction( + funcInfo, [ ['loc', typeNameFromTypeNode(storageType, this.ast), DataLocation.Storage], ['dynarray', typeNameFromTypeNode(calldataType, this.ast), DataLocation.CallData], ], [['loc', typeNameFromTypeNode(storageType, this.ast), DataLocation.Storage]], - ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr'], this.ast, - nodeInSourceUnit ?? storageLocation, + this.sourceUnit, ); - - return createCallToFunction(functionStub, [storageLocation, calldataLocation], this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - getOrCreate(type: TypeNode) { - const key = type.pp(); - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - + private getOrCreate(type: TypeNode): GeneratedFunctionInfo { const unexpectedTypeFunc = () => { throw new NotSupportedYetError( `Copying ${printTypeNode(type)} from calldata to storage is not supported yet`, ); }; - return delegateBasedOnType( + return delegateBasedOnType( type, - (type) => this.createDynamicArrayCopyFunction(key, type), - (type) => this.createStaticArrayCopyFunction(key, type), - (type) => this.createStructCopyFunction(key, type), + (type) => this.createDynamicArrayCopyFunction(type), + (type) => this.createStaticArrayCopyFunction(type), + (type, def) => this.createStructCopyFunction(type, def), unexpectedTypeFunc, unexpectedTypeFunc, ); } - private createStructCopyFunction(key: string, structType: UserDefinedType): string { - assert(structType.definition instanceof StructDefinition); - const structDef = structType.definition; + private createStructCopyFunction( + structType: UserDefinedType, + structDef: StructDefinition, + ): GeneratedFunctionInfo { const cairoStruct = CairoType.fromSol( structType, this.ast, TypeConversionContext.StorageAllocation, ); - const structName = `struct_${cairoStruct.toString()}`; - + const structName = `struct_${structDef.name}`; const members = structDef.vMembers.map((varDecl) => `${structName}.${varDecl.name}`); - const copyInstructions = this.generateStructCopyInstructions( + + const [copyInstructions, funcsCalled] = this.generateStructCopyInstructions( structDef.vMembers.map((varDecl) => safeGetNodeType(varDecl, this.ast.inference)), members, ); - const implicits = '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}'; const funcName = `cd_struct_${cairoStruct.toString()}_to_storage`; const code = [ - `func ${funcName}${implicits}(loc : felt, ${structName} : ${cairoStruct.toString()}) -> (loc : felt){`, + `func ${funcName}${IMPLICITS}(loc : felt, ${structName} : ${cairoStruct.toString()}) -> (loc : felt){`, ` alloc_locals;`, ...copyInstructions, ` return (loc,);`, `}`, ].join('\n'); - this.generatedFunctions.set(key, { name: funcName, code: code }); - return funcName; + return { name: funcName, code: code, functionsCalled: funcsCalled }; } - private createStaticArrayCopyFunction(key: string, arrayType: ArrayType): string { + // TODO: Check if funcion size can be reduced for big static arrays + private createStaticArrayCopyFunction(arrayType: ArrayType): GeneratedFunctionInfo { assert(arrayType.size !== undefined); - const cairoType = CairoType.fromSol(arrayType, this.ast, TypeConversionContext.CallDataRef); - const len = narrowBigIntSafe(arrayType.size); + const cairoType = CairoType.fromSol(arrayType, this.ast, TypeConversionContext.CallDataRef); const elems = mapRange(len, (n) => `static_array[${n}]`); - const copyInstructions = this.generateStructCopyInstructions( + const [copyInstructions, funcsCalled] = this.generateStructCopyInstructions( mapRange(len, () => arrayType.elementT), elems, ); - const implicits = '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}'; - const funcName = `cd_static_array_to_storage${this.generatedFunctions.size}`; + const funcName = `cd_static_array_to_storage${this.generatedFunctionsDef.size}`; const code = [ - `func ${funcName}${implicits}(loc : felt, static_array : ${cairoType.toString()}) -> (loc : felt){`, + `func ${funcName}${IMPLICITS}(loc : felt, static_array : ${cairoType.toString()}) -> (loc : felt){`, ` alloc_locals;`, ...copyInstructions, ` return (loc,);`, `}`, ].join('\n'); - this.generatedFunctions.set(key, { name: funcName, code: code }); - - return funcName; + return { name: funcName, code: code, functionsCalled: funcsCalled }; } private createDynamicArrayCopyFunction( - key: string, arrayType: ArrayType | BytesType | StringType, - ): string { + ): GeneratedFunctionInfo { const elementT = getElementType(arrayType); const structDef = CairoType.fromSol(arrayType, this.ast, TypeConversionContext.CallDataRef); assert(structDef instanceof CairoDynArray); - const [arrayName, arrayLen] = this.dynArrayGen.gen( - CairoType.fromSol(elementT, this.ast, TypeConversionContext.StorageAllocation), - ); + const [dynArray, dynArrayLength] = this.dynArrayGen.getOrCreateFuncDef(elementT); + const lenName = dynArrayLength.name; const cairoElementType = CairoType.fromSol( elementT, this.ast, TypeConversionContext.StorageAllocation, ); - const copyCode = `${this.storageWriteGen.getOrCreate(elementT)}(elem_loc, elem[index]);`; + const writeDef = this.storageWriteGen.getOrCreateFuncDef(elementT); + const copyCode = `${writeDef.name}(elem_loc, elem[index]);`; - const implicits = '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}'; const pointerType = `${cairoElementType.toString()}*`; - const funcName = `cd_dynamic_array_to_storage${this.generatedFunctions.size}`; + const funcName = `cd_dynamic_array_to_storage${this.generatedFunctionsDef.size}`; const code = [ - `func ${funcName}_write${implicits}(loc : felt, index : felt, len : felt, elem: ${pointerType}){`, + `func ${funcName}_write${IMPLICITS}(loc : felt, index : felt, len : felt, elem: ${pointerType}){`, ` alloc_locals;`, ` if (index == len){`, ` return ();`, ` }`, ` let (index_uint256) = warp_uint256(index);`, - ` let (elem_loc) = ${arrayName}.read(loc, index_uint256);`, + ` let (elem_loc) = ${dynArray.name}.read(loc, index_uint256);`, ` if (elem_loc == 0){`, ` let (elem_loc) = WARP_USED_STORAGE.read();`, ` WARP_USED_STORAGE.write(elem_loc + ${cairoElementType.width});`, - ` ${arrayName}.write(loc, index_uint256, elem_loc);`, + ` ${dynArray.name}.write(loc, index_uint256, elem_loc);`, ` ${copyCode}`, ` return ${funcName}_write(loc, index + 1, len, elem);`, ` }else{`, @@ -178,39 +179,50 @@ export class CalldataToStorageGen extends StringIndexedFuncGen { ` }`, `}`, - `func ${funcName}${implicits}(loc : felt, dyn_array_struct : ${structDef.name}) -> (loc : felt){ `, + `func ${funcName}${IMPLICITS}(loc : felt, dyn_array_struct : ${structDef.name}) -> (loc : felt){ `, ` alloc_locals;`, ` let (len_uint256) = warp_uint256(dyn_array_struct.len);`, - ` ${arrayLen}.write(loc, len_uint256);`, + ` ${lenName}.write(loc, len_uint256);`, ` ${funcName}_write(loc, 0, dyn_array_struct.len, dyn_array_struct.ptr);`, ` return (loc,);`, `}`, ].join('\n'); - this.requireImport('warplib.maths.int_conversions', 'warp_uint256'); - - this.generatedFunctions.set(key, { name: funcName, code: code }); - - return funcName; + return { + name: funcName, + code: code, + functionsCalled: [ + this.requireImport('warplib.maths.int_conversions', 'warp_uint256'), + dynArray, + dynArrayLength, + writeDef, + ], + }; } - private generateStructCopyInstructions(varTypes: TypeNode[], names: string[]): string[] { - let offset = 0; - const copyInstructions = varTypes.map((varType, index) => { - const varCairoTypeWidth = CairoType.fromSol( - varType, - this.ast, - TypeConversionContext.CallDataRef, - ).width; - - const funcName = this.storageWriteGen.getOrCreate(varType); - const location = add('loc', offset); - - offset += varCairoTypeWidth; - - return ` ${funcName}(${location}, ${names[index]});`; - }); + private generateStructCopyInstructions( + varTypes: TypeNode[], + names: string[], + ): [string[], CairoFunctionDefinition[]] { + const [copyInstructions, funcCalls] = varTypes.reduce( + ([copyInstructions, funcCalls, offset], varType, index) => { + const varCairoTypeWidth = CairoType.fromSol( + varType, + this.ast, + TypeConversionContext.CallDataRef, + ).width; + + const writeDef = this.storageWriteGen.getOrCreateFuncDef(varType); + const location = add('loc', offset); + return [ + [...copyInstructions, ` ${writeDef.name}(${location}, ${names[index]});`], + [...funcCalls, writeDef], + offset + varCairoTypeWidth, + ]; + }, + [new Array(), new Array(), 0], + ); - return copyInstructions; + return [copyInstructions, funcCalls]; } } diff --git a/src/cairoUtilFuncGen/calldata/externalDynArray/externalDynArrayStructConstructor.ts b/src/cairoUtilFuncGen/calldata/externalDynArray/externalDynArrayStructConstructor.ts index bae5bf49d..a247b5672 100644 --- a/src/cairoUtilFuncGen/calldata/externalDynArray/externalDynArrayStructConstructor.ts +++ b/src/cairoUtilFuncGen/calldata/externalDynArray/externalDynArrayStructConstructor.ts @@ -10,15 +10,19 @@ import { generalizeType, BytesType, StringType, + TypeNode, } from 'solc-typed-ast'; import assert from 'assert'; -import { createCairoFunctionStub, createCallToFunction } from '../../../utils/functionGeneration'; +import { + createCairoGeneratedFunction, + createCallToFunction, +} from '../../../utils/functionGeneration'; import { CairoType, generateCallDataDynArrayStructName, TypeConversionContext, } from '../../../utils/cairoTypeSystem'; -import { StringIndexedFuncGen } from '../../base'; +import { GeneratedFunctionInfo, StringIndexedFuncGen } from '../../base'; import { createIdentifier } from '../../../utils/nodeTemplates'; import { FunctionStubKind } from '../../../ast/cairoNodes'; import { typeNameFromTypeNode } from '../../../utils/utils'; @@ -32,9 +36,9 @@ import { const INDENT = ' '.repeat(4); export class ExternalDynArrayStructConstructor extends StringIndexedFuncGen { - gen(astNode: VariableDeclaration, nodeInSourceUnit?: ASTNode): FunctionCall; - gen(astNode: Expression, nodeInSourceUnit?: ASTNode): void; - gen( + public gen(astNode: VariableDeclaration, nodeInSourceUnit?: ASTNode): FunctionCall; + public gen(astNode: Expression, nodeInSourceUnit?: ASTNode): void; + public gen( astNode: VariableDeclaration | Expression, nodeInSourceUnit?: ASTNode, ): FunctionCall | undefined { @@ -46,57 +50,62 @@ export class ExternalDynArrayStructConstructor extends StringIndexedFuncGen { `Attempted to create dynArray struct for non-dynarray type ${printTypeNode(type)}`, ); - const name = this.getOrCreate(type); - const structDefStub = createCairoFunctionStub( - name, + const funcDef = this.getOrCreateFuncDef(type); + if (astNode instanceof VariableDeclaration) { + const functionInputs: Identifier[] = [ + createIdentifier(astNode, this.ast, DataLocation.CallData, nodeInSourceUnit ?? astNode), + ]; + return createCallToFunction(funcDef, functionInputs, this.ast); + } else { + // When CallData DynArrays are being returned and we do not need the StructConstructor + // to be returned, we just need the StructDefinition to be in the contract. + return; + } + } + + public getOrCreateFuncDef(type: ArrayType | BytesType | StringType) { + const elemType = getElementType(type); + + const key = elemType.pp(); + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; + } + + const funcInfo = this.getOrCreate(elemType); + const funcDef = createCairoGeneratedFunction( + funcInfo, [['darray', typeNameFromTypeNode(type, this.ast), DataLocation.CallData]], [['darray_struct', typeNameFromTypeNode(type, this.ast), DataLocation.CallData]], - [], this.ast, - nodeInSourceUnit ?? astNode, + this.sourceUnit, { mutability: FunctionStateMutability.View, stubKind: FunctionStubKind.StructDefStub, acceptsRawDArray: true, }, ); - - if (astNode instanceof VariableDeclaration) { - const functionInputs: Identifier[] = [ - createIdentifier(astNode, this.ast, DataLocation.CallData, nodeInSourceUnit ?? astNode), - ]; - return createCallToFunction(structDefStub, functionInputs, this.ast); - } else { - // When CallData DynArrays are being returned and we do not need the StructConstructor to be returned, we just need - // the StructDefinition to be in the contract. - return; - } + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - getOrCreate(type: ArrayType | BytesType | StringType): string { - const elemType = getElementType(type); + private getOrCreate(elemType: TypeNode): GeneratedFunctionInfo { const elementCairoType = CairoType.fromSol( elemType, this.ast, TypeConversionContext.CallDataRef, ); - const key = generateCallDataDynArrayStructName(elemType, this.ast); - - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - this.generatedFunctions.set(key, { - name: key, + const structName = generateCallDataDynArrayStructName(elemType, this.ast); + const funcInfo: GeneratedFunctionInfo = { + name: structName, code: [ - `struct ${key}{`, + `struct ${structName}{`, `${INDENT} len : felt ,`, `${INDENT} ptr : ${elementCairoType.toString()}*,`, `}`, ].join('\n'), - }); - - return key; + functionsCalled: [], + }; + return funcInfo; } } diff --git a/src/cairoUtilFuncGen/calldata/implicitArrayConversion.ts b/src/cairoUtilFuncGen/calldata/implicitArrayConversion.ts index 639edd27b..dda90377e 100644 --- a/src/cairoUtilFuncGen/calldata/implicitArrayConversion.ts +++ b/src/cairoUtilFuncGen/calldata/implicitArrayConversion.ts @@ -7,25 +7,31 @@ import { FunctionCall, generalizeType, IntType, - PointerType, SourceUnit, TypeNode, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoFunctionDefinition } from '../../export'; import { CairoDynArray, CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { cloneASTNode } from '../../utils/cloning'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; -import { isDynamicStorageArray, safeGetNodeType } from '../../utils/nodeTypeProcessing'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; +import { isDynamicArray, safeGetNodeType } from '../../utils/nodeTypeProcessing'; import { mapRange, narrowBigIntSafe, typeNameFromTypeNode } from '../../utils/utils'; import { uint256 } from '../../warplib/utils'; -import { add, CairoFunction, StringIndexedFuncGen } from '../base'; +import { add, delegateBasedOnType, GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; import { getBaseType } from '../memory/implicitConversion'; import { DynArrayGen } from '../storage/dynArray'; import { DynArrayIndexAccessGen } from '../storage/dynArrayIndexAccess'; import { StorageWriteGen } from '../storage/storageWrite'; +import { NotSupportedYetError } from '../../utils/errors'; +import { printTypeNode } from '../../utils/astPrinter'; +const IMPLICITS = + '{syscall_ptr : felt*, range_check_ptr, pedersen_ptr : HashBuiltin*, bitwise_ptr : BitwiseBuiltin*}'; + +// TODO: Add checks for expressions locations when generating export class ImplicitArrayConversion extends StringIndexedFuncGen { - constructor( + public constructor( private storageWriteGen: StorageWriteGen, private dynArrayGen: DynArrayGen, private dynArrayIndexAccessGen: DynArrayIndexAccessGen, @@ -35,88 +41,44 @@ export class ImplicitArrayConversion extends StringIndexedFuncGen { super(ast, sourceUnit); } - genIfNecessary( + public genIfNecessary( targetExpression: Expression, sourceExpression: Expression, ): [Expression, boolean] { const targetType = generalizeType(safeGetNodeType(targetExpression, this.ast.inference))[0]; const sourceType = generalizeType(safeGetNodeType(sourceExpression, this.ast.inference))[0]; - if (this.checkDims(targetType, sourceType) || this.checkSizes(targetType, sourceType)) { + if (checkDims(targetType, sourceType) || checkSizes(targetType, sourceType)) { return [this.gen(targetExpression, sourceExpression), true]; } else { return [sourceExpression, false]; } } - checkSizes(targetType: TypeNode, sourceType: TypeNode): boolean { - const targetBaseType = getBaseType(targetType); - const sourceBaseType = getBaseType(sourceType); - if (targetBaseType instanceof IntType && sourceBaseType instanceof IntType) { - return ( - (targetBaseType.nBits > sourceBaseType.nBits && sourceBaseType.signed) || - (!targetBaseType.signed && targetBaseType.nBits === 256 && 256 > sourceBaseType.nBits) - ); - } - if (targetBaseType instanceof FixedBytesType && sourceBaseType instanceof FixedBytesType) { - return targetBaseType.size > sourceBaseType.size; - } - return false; - } - - checkDims(targetType: TypeNode, sourceType: TypeNode): boolean { - const targetArray = generalizeType(targetType)[0]; - const sourceArray = generalizeType(sourceType)[0]; - - if (targetArray instanceof ArrayType && sourceArray instanceof ArrayType) { - const targetArrayElm = generalizeType(targetArray.elementT)[0]; - const sourceArrayElm = generalizeType(sourceArray.elementT)[0]; - - if (targetArray.size !== undefined && sourceArray.size !== undefined) { - if (targetArray.size > sourceArray.size) { - return true; - } else if (targetArrayElm instanceof ArrayType && sourceArrayElm instanceof ArrayType) { - return this.checkDims(targetArrayElm, sourceArrayElm); - } else { - return false; - } - } else if (targetArray.size === undefined && sourceArray.size !== undefined) { - return true; - } else if (targetArray.size === undefined && sourceArray.size === undefined) - if (targetArrayElm instanceof ArrayType && sourceArrayElm instanceof ArrayType) { - return this.checkDims(targetArrayElm, sourceArrayElm); - } - } - return false; - } - - gen(lhs: Expression, rhs: Expression): FunctionCall { + public gen(lhs: Expression, rhs: Expression): FunctionCall { const lhsType = safeGetNodeType(lhs, this.ast.inference); const rhsType = safeGetNodeType(rhs, this.ast.inference); - - const name = this.getOrCreate(lhsType, rhsType); - - const functionStub = createCairoFunctionStub( - name, - [ - ['lhs', typeNameFromTypeNode(lhsType, this.ast), DataLocation.Storage], - ['rhs', typeNameFromTypeNode(rhsType, this.ast), DataLocation.CallData], - ], - [], - ['syscall_ptr', 'bitwise_ptr', 'range_check_ptr', 'pedersen_ptr', 'bitwise_ptr'], - this.ast, - rhs, - ); + const funcDef = this.getOrCreateFuncDef(lhsType, rhsType); return createCallToFunction( - functionStub, + funcDef, [cloneASTNode(lhs, this.ast), cloneASTNode(rhs, this.ast)], this.ast, ); } - getOrCreate(targetType: TypeNode, sourceType: TypeNode): string { + + public getOrCreateFuncDef(targetType: TypeNode, sourceType: TypeNode) { + targetType = generalizeType(targetType)[0]; + sourceType = generalizeType(sourceType)[0]; + assert( + targetType instanceof ArrayType && sourceType instanceof ArrayType, + `Invalid calldata implicit conversion: Expected ArrayType type but found: ${printTypeNode( + targetType, + )} and ${printTypeNode(sourceType)}`, + ); + const sourceRepForKey = CairoType.fromSol( - generalizeType(sourceType)[0], + sourceType, this.ast, TypeConversionContext.CallDataRef, ).fullStringRepresentation; @@ -125,329 +87,164 @@ export class ImplicitArrayConversion extends StringIndexedFuncGen { // Using Calldata here gives us the full representation instead of WarpId provided by Storage. // This is only for KeyGen and no further processing. const targetRepForKey = CairoType.fromSol( - generalizeType(targetType)[0], + targetType, this.ast, TypeConversionContext.CallDataRef, ).fullStringRepresentation; - const key = `${targetRepForKey}_${getBaseType( - targetType, - ).pp()} -> ${sourceRepForKey}_${getBaseType(sourceType).pp()}`; - - const existing = this.generatedFunctions.get(key); + const targetBaseType = getBaseType(targetType).pp(); + const sourceBaseType = getBaseType(sourceType).pp(); + const key = `${targetRepForKey}_${targetBaseType} -> ${sourceRepForKey}_${sourceBaseType}`; + const existing = this.generatedFunctionsDef.get(key); if (existing !== undefined) { - return existing.name; + return existing; } - assert(targetType instanceof PointerType && sourceType instanceof PointerType); - assert(targetType.to instanceof ArrayType && sourceType.to instanceof ArrayType); - - let cairoFunc: CairoFunction; - if (targetType.to.size === undefined && sourceType.to.size === undefined) { - cairoFunc = this.DynamicToDynamicConversion(key, targetType, sourceType); - } else if (targetType.to.size === undefined && sourceType.to.size !== undefined) { - cairoFunc = this.staticToDynamicConversion(key, targetType, sourceType); - } else { - cairoFunc = this.staticToStaticConversion(key, targetType, sourceType); - } - return cairoFunc.name; + + const funcInfo = this.getOrCreate(targetType, sourceType); + const funcDef = createCairoGeneratedFunction( + funcInfo, + [ + ['lhs', typeNameFromTypeNode(targetType, this.ast), DataLocation.Storage], + ['rhs', typeNameFromTypeNode(sourceType, this.ast), DataLocation.CallData], + ], + [], + this.ast, + this.sourceUnit, + ); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - private staticToStaticConversion( - key: string, - targetType: TypeNode, - sourceType: TypeNode, - ): CairoFunction { - assert(targetType instanceof PointerType && sourceType instanceof PointerType); - assert(targetType.to instanceof ArrayType && sourceType.to instanceof ArrayType); + private getOrCreate(targetType: ArrayType, sourceType: ArrayType): GeneratedFunctionInfo { + const unexpectedTypeFunc = () => { + throw new NotSupportedYetError( + `Scaling ${printTypeNode(sourceType)} to ${printTypeNode( + targetType, + )} from memory to storage not implemented yet`, + ); + }; - const targetElmType = targetType.to.elementT; - const sourceElmType = sourceType.to.elementT; + return delegateBasedOnType( + targetType, + (targetType) => { + assert(targetType instanceof ArrayType && sourceType instanceof ArrayType); + return sourceType.size === undefined + ? this.dynamicToDynamicArrayConversion(targetType, sourceType) + : this.staticToDynamicArrayConversion(targetType, sourceType); + }, + (targetType) => { + assert(sourceType instanceof ArrayType); + return this.staticToStaticArrayConversion(targetType, sourceType); + }, + unexpectedTypeFunc, + unexpectedTypeFunc, + unexpectedTypeFunc, + ); + } - const funcName = `CD_ST_TO_WS_ST${this.generatedFunctions.size}`; - this.generatedFunctions.set(key, { name: funcName, code: '' }); + private staticToStaticArrayConversion( + targetType: ArrayType, + sourceType: ArrayType, + ): GeneratedFunctionInfo { + assert(targetType.size !== undefined && sourceType.size !== undefined); + assert( + targetType.size >= sourceType.size, + `Cannot convert a bigger static array (${targetType.size}) into a smaller one (${sourceType.size})`, + ); - const cairoSourceType = CairoType.fromSol( + const [generateCopyCode, requiredFunctions] = this.createStaticToStaticCopyCode( + targetType, sourceType, - this.ast, - TypeConversionContext.CallDataRef, ); - assert(sourceType.to.size !== undefined); - const sizeSource = narrowBigIntSafe(sourceType.to.size); - - const copyInstructions = this.generateS2SCopyInstructions( - targetElmType, - sourceElmType, - sizeSource, + const sourceSize = narrowBigIntSafe(sourceType.size); + const targetElementTSize = CairoType.fromSol( + targetType.elementT, + this.ast, + TypeConversionContext.StorageAllocation, + ).width; + const copyInstructions: string[] = mapRange(sourceSize, (index) => + generateCopyCode(index, index * targetElementTSize), ); - const implicit = - '{syscall_ptr : felt*, range_check_ptr, pedersen_ptr : HashBuiltin*, bitwise_ptr : BitwiseBuiltin*}'; + const cairoSourceTypeName = CairoType.fromSol( + sourceType, + this.ast, + TypeConversionContext.CallDataRef, + ).toString(); + const funcName = `calldata_conversion_static_to_static${this.generatedFunctionsDef.size}`; const code = [ - `func ${funcName}${implicit}(storage_loc: felt, arg: ${cairoSourceType.toString()}){`, + `func ${funcName}${IMPLICITS}(storage_loc: felt, arg: ${cairoSourceTypeName}){`, `alloc_locals;`, ...copyInstructions, ' return ();', '}', ].join('\n'); - this.addImports(targetElmType, sourceElmType); - this.generatedFunctions.set(key, { name: funcName, code: code }); - return { name: funcName, code: code }; - } - private generateS2SCopyInstructions( - targetElmType: TypeNode, - sourceElmType: TypeNode, - length: number, - ): string[] { - const cairoTargetElementType = CairoType.fromSol( - targetElmType, - this.ast, - TypeConversionContext.StorageAllocation, - ); - - let offset = 0; - const instructions = mapRange(length, (index) => { - let code; - if (targetElmType instanceof IntType) { - assert(sourceElmType instanceof IntType); - if (targetElmType.nBits === sourceElmType.nBits) { - code = ` ${this.storageWriteGen.getOrCreate(targetElmType)}(${add( - 'storage_loc', - offset, - )}, arg[${index}]);`; - } else if (targetElmType.signed) { - code = [ - ` let (arg_${index}) = warp_int${sourceElmType.nBits}_to_int${targetElmType.nBits}(arg[${index}]);`, - `${this.storageWriteGen.getOrCreate(targetElmType)}(${add( - 'storage_loc', - offset, - )}, arg_${index});`, - ].join('\n'); - } else { - code = [ - ` let (arg_${index}) = felt_to_uint256(arg[${index}]);`, - ` ${this.storageWriteGen.getOrCreate(targetElmType)}(${add( - 'storage_loc', - offset, - )}, arg_${index});`, - ].join('\n'); - } - } else if ( - targetElmType instanceof FixedBytesType && - sourceElmType instanceof FixedBytesType - ) { - if (targetElmType.size > sourceElmType.size) { - code = [ - ` let (arg_${index}) = warp_bytes_widen${ - targetElmType.size === 32 ? '_256' : '' - }(arg[${index}], ${(targetElmType.size - sourceElmType.size) * 8});`, - ` ${this.storageWriteGen.getOrCreate(targetElmType)}(${add( - 'storage_loc', - offset, - )}, arg_${index});`, - ].join('\n'); - } else { - code = ` ${this.storageWriteGen.getOrCreate(targetElmType)}(${add( - 'storage_loc', - offset, - )}, arg[${index}]);`; - } - } else { - if (isDynamicStorageArray(targetElmType)) { - code = [ - ` let (ref_${index}) = readId(${add('storage_loc', offset)});`, - ` ${this.getOrCreate(targetElmType, sourceElmType)}(ref_${index}, arg[${index}]);`, - ].join('\n'); - } else { - code = [ - ` ${this.getOrCreate(targetElmType, sourceElmType)}(${add( - 'storage_loc', - offset, - )}, arg[${index}]);`, - ].join('\n'); - } - } - offset = offset + cairoTargetElementType.width; - return code; - }); - return instructions; + return { + name: funcName, + code: code, + functionsCalled: requiredFunctions, + }; } - private staticToDynamicConversion( - key: string, - targetType: TypeNode, - sourceType: TypeNode, - ): CairoFunction { - assert(targetType instanceof PointerType && sourceType instanceof PointerType); - assert(targetType.to instanceof ArrayType && sourceType.to instanceof ArrayType); - - assert(targetType.to.size === undefined && sourceType.to.size !== undefined); + private staticToDynamicArrayConversion( + targetType: ArrayType, + sourceType: ArrayType, + ): GeneratedFunctionInfo { + assert(targetType.size === undefined && sourceType.size !== undefined); - const targetElmType = targetType.to.elementT; - const sourceElmType = sourceType.to.elementT; + const [generateCopyCode, requiredFunctions] = this.createStaticToDynamicCopyCode( + targetType, + sourceType, + ); - const funcName = `CD_ST_TO_WS_DY${this.generatedFunctions.size}`; - this.generatedFunctions.set(key, { name: funcName, code: '' }); + const sourceSize = narrowBigIntSafe(sourceType.size); + const copyInstructions: string[] = mapRange(sourceSize, (index) => generateCopyCode(index)); - const cairoTargetElementType = CairoType.fromSol( - targetType.to.elementT, - this.ast, - TypeConversionContext.StorageAllocation, - ); + let optionalCode = ''; + let optionalImport: CairoFunctionDefinition[] = []; + if (isDynamicArray(targetType)) { + const [_dynArray, dynArrayLength] = this.dynArrayGen.getOrCreateFuncDef(targetType.elementT); + optionalImport = [dynArrayLength]; + optionalCode = `${dynArrayLength.name}.write(ref, ${uint256(sourceSize)});`; + } - const cairoSourceType = CairoType.fromSol( + const cairoSourceTypeName = CairoType.fromSol( sourceType, this.ast, TypeConversionContext.CallDataRef, - ); - - const cairoSourceTypeString = cairoSourceType.toString(); - - const sizeSource = narrowBigIntSafe(sourceType.to.size); - - assert(sizeSource !== undefined); - - const dynArrayLengthName = this.dynArrayGen.gen(cairoTargetElementType)[1]; - const copyInstructions = this.generateS2DCopyInstructions( - targetElmType, - sourceElmType, - sizeSource, - ); - - const implicit = - '{syscall_ptr : felt*, range_check_ptr, pedersen_ptr : HashBuiltin*, bitwise_ptr : BitwiseBuiltin*}'; + ).toString(); + const funcName = `calldata_conversion_static_to_dynamic${this.generatedFunctionsDef.size}`; const code = [ - `func ${funcName}${implicit}(ref: felt, arg: ${cairoSourceTypeString}){`, - ` alloc_locals;`, - isDynamicStorageArray(targetType) - ? ` ${dynArrayLengthName}.write(ref, ${uint256(sourceType.to.size)});` - : '', + `func ${funcName}${IMPLICITS}(ref: felt, arg: ${cairoSourceTypeName}){`, + `alloc_locals;`, + ` ${optionalCode}`, ...copyInstructions, ' return ();', '}', ].join('\n'); - this.addImports(targetElmType, sourceElmType); - this.generatedFunctions.set(key, { name: funcName, code: code }); - return { name: funcName, code: code }; - } - - private generateS2DCopyInstructions( - targetElmType: TypeNode, - sourceElmType: TypeNode, - length: number, - ): string[] { - const cairoTargetElementType = CairoType.fromSol( - targetElmType, - this.ast, - TypeConversionContext.StorageAllocation, - ); - const instructions = mapRange(length, (index) => { - if (targetElmType instanceof IntType) { - assert(sourceElmType instanceof IntType); - if (targetElmType.nBits === sourceElmType.nBits) { - return [ - ` let (storage_loc${index}) = ${this.dynArrayIndexAccessGen.getOrCreate( - targetElmType, - )}(ref, ${uint256(index)});`, - ` ${this.storageWriteGen.getOrCreate( - targetElmType, - )}(storage_loc${index}, arg[${index}]);`, - ].join('\n'); - } else if (targetElmType.signed) { - return [ - ` let (arg_${index}) = warp_int${sourceElmType.nBits}_to_int${targetElmType.nBits}(arg[${index}]);`, - ` let (storage_loc${index}) = ${this.dynArrayIndexAccessGen.getOrCreate( - targetElmType, - )}(ref, ${uint256(index)});`, - ` ${this.storageWriteGen.getOrCreate( - targetElmType, - )}(storage_loc${index}, arg_${index});`, - ].join('\n'); - } else { - return [ - ` let (arg_${index}) = felt_to_uint256(arg[${index}]);`, - ` let (storage_loc${index}) = ${this.dynArrayIndexAccessGen.getOrCreate( - targetElmType, - )}(ref, ${uint256(index)});`, - ` ${this.storageWriteGen.getOrCreate( - targetElmType, - )}(storage_loc${index}, arg_${index});`, - ].join('\n'); - } - } else if ( - targetElmType instanceof FixedBytesType && - sourceElmType instanceof FixedBytesType - ) { - if (targetElmType.size > sourceElmType.size) { - return [ - ` let (arg_${index}) = warp_bytes_widen${ - targetElmType.size === 32 ? '_256' : '' - }(arg[${index}], ${(targetElmType.size - sourceElmType.size) * 8});`, - ` let (storage_loc${index}) = ${this.dynArrayIndexAccessGen.getOrCreate( - targetElmType, - )}(ref, ${uint256(index)});`, - ` ${this.storageWriteGen.getOrCreate( - targetElmType, - )}(storage_loc${index}, arg_${index});`, - ].join('\n'); - } else { - return [ - ` let (storage_loc${index}) = ${this.dynArrayIndexAccessGen.getOrCreate( - targetElmType, - )}(ref, ${uint256(index)});`, - ` ${this.storageWriteGen.getOrCreate( - targetElmType, - )}(storage_loc${index}, arg[${index}]);`, - ].join('\n'); - } - } else { - if (isDynamicStorageArray(targetElmType)) { - const dynArrayLengthName = this.dynArrayGen.gen(cairoTargetElementType)[1]; - return [ - ` let (storage_loc${index}) = ${this.dynArrayIndexAccessGen.getOrCreate( - targetElmType, - )}(ref, ${uint256(index)});`, - ` let (ref_${index}) = readId(storage_loc${index});`, - ` ${dynArrayLengthName}.write(ref_${index}, ${uint256(length)});`, - ` ${this.getOrCreate(targetElmType, sourceElmType)}(ref_${index}, arg[${index}]);`, - ].join('\n'); - } else { - return [ - ` let (storage_loc${index}) = ${this.dynArrayIndexAccessGen.getOrCreate( - targetElmType, - )}(ref, ${uint256(index)});`, - ` ${this.getOrCreate( - targetElmType, - sourceElmType, - )}(storage_loc${index}, arg[${index}]);`, - ].join('\n'); - } - } - }); - return instructions; + return { + name: funcName, + code: code, + functionsCalled: [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + ...requiredFunctions, + ...optionalImport, + ], + }; } - private DynamicToDynamicConversion( - key: string, - targetType: TypeNode, - sourceType: TypeNode, - ): CairoFunction { - assert(targetType instanceof PointerType && sourceType instanceof PointerType); - assert(targetType.to instanceof ArrayType && sourceType.to instanceof ArrayType); - - assert(targetType.to.size === undefined && sourceType.to.size === undefined); - - const targetElmType = targetType.to.elementT; - const sourceElmType = sourceType.to.elementT; + private dynamicToDynamicArrayConversion( + targetType: ArrayType, + sourceType: ArrayType, + ): GeneratedFunctionInfo { + assert(targetType.size === undefined && sourceType.size === undefined); - const funcName = `CD_DY_TO_WS_DY${this.generatedFunctions.size}`; - this.generatedFunctions.set(key, { name: funcName, code: '' }); - - const cairoTargetElementType = CairoType.fromSol( - targetType.to.elementT, - this.ast, - TypeConversionContext.StorageAllocation, + const [_dynArray, dynArrayLength] = this.dynArrayGen.getOrCreateFuncDef(targetType.elementT); + const arrayDef = this.dynArrayIndexAccessGen.getOrCreateFuncDef( + targetType.elementT, + targetType, ); const cairoSourceType = CairoType.fromSol( @@ -455,85 +252,323 @@ export class ImplicitArrayConversion extends StringIndexedFuncGen { this.ast, TypeConversionContext.CallDataRef, ); - assert(cairoSourceType instanceof CairoDynArray); - const dynArrayLengthName = this.dynArrayGen.gen(cairoTargetElementType)[1]; - const implicit = - '{syscall_ptr : felt*, range_check_ptr, pedersen_ptr : HashBuiltin*, bitwise_ptr : BitwiseBuiltin*}'; - const loaderName = `DY_LOADER${this.generatedFunctions.size}`; - - const copyInstructions = this.generateDynCopyInstructions(targetElmType, sourceElmType); + const [copyInstructions, requiredFunctions] = this.createDyamicToDynamicCopyCode( + targetType, + sourceType, + ); + assert(cairoSourceType instanceof CairoDynArray); + const funcName = `calldata_conversion_dynamic_to_dynamic${this.generatedFunctionsDef.size}`; + const recursiveFuncName = `${funcName}_helper`; const code = [ - `func ${loaderName}${implicit}(ref: felt, len: felt, ptr: ${cairoSourceType.ptr_member.toString()}*, target_index: felt){`, + `func ${recursiveFuncName}${IMPLICITS}(ref: felt, len: felt, ptr: ${cairoSourceType.ptr_member.toString()}*, target_index: felt){`, ` alloc_locals;`, ` if (len == 0){`, ` return ();`, ` }`, - ` let (storage_loc) = ${this.dynArrayIndexAccessGen.getOrCreate( - targetElmType, - )}(ref, Uint256(target_index, 0));`, - copyInstructions, + ` let (storage_loc) = ${arrayDef.name}(ref, Uint256(target_index, 0));`, + copyInstructions(), - ` return ${loaderName}(ref, len - 1, ptr + ${cairoSourceType.ptr_member.width}, target_index+ 1 );`, + ` return ${recursiveFuncName}(ref, len - 1, ptr + ${cairoSourceType.ptr_member.width}, target_index+ 1 );`, `}`, ``, - `func ${funcName}${implicit}(ref: felt, source: ${cairoSourceType.toString()}){`, + `func ${funcName}${IMPLICITS}(ref: felt, source: ${cairoSourceType.toString()}){`, ` alloc_locals;`, - ` ${dynArrayLengthName}.write(ref, Uint256(source.len, 0));`, - ` ${loaderName}(ref, source.len, source.ptr, 0);`, + ` ${dynArrayLength.name}.write(ref, Uint256(source.len, 0));`, + ` ${recursiveFuncName}(ref, source.len, source.ptr, 0);`, ' return ();', '}', ].join('\n'); - this.addImports(targetElmType, sourceElmType); - this.generatedFunctions.set(key, { name: funcName, code: code }); - return { name: funcName, code: code }; + + return { name: funcName, code: code, functionsCalled: [...requiredFunctions, dynArrayLength] }; } - private generateDynCopyInstructions(targetElmType: TypeNode, sourceElmType: TypeNode): string { - if (sourceElmType instanceof IntType && targetElmType instanceof IntType) { + private createStaticToStaticCopyCode( + targetType: ArrayType, + sourceType: ArrayType, + ): [(index: number, offset: number) => string, CairoFunctionDefinition[]] { + const targetElementT = targetType.elementT; + const sourceElementT = sourceType.elementT; + + if (targetElementT instanceof IntType) { + assert(sourceElementT instanceof IntType); + const writeToStorage = this.storageWriteGen.getOrCreateFuncDef(targetElementT); + if (targetElementT.nBits === sourceElementT.nBits) { + return [ + (index, offset) => + `${writeToStorage.name}(${add('storage_loc', offset)}, arg[${index}]);`, + [writeToStorage], + ]; + } + if (targetElementT.signed) { + const convertionFunc = this.requireImport( + 'warplib.maths.int_conversions', + `warp_int${sourceElementT.nBits}_to_int${targetElementT.nBits}`, + ); + return [ + (index, offset) => + [ + ` let (arg_${index}) = ${convertionFunc.name}(arg[${index}]);`, + ` ${writeToStorage.name}(${add('storage_loc', offset)}, arg_${index});`, + ].join('\n'), + [writeToStorage, convertionFunc], + ]; + } + const toUintFunc = this.requireImport('warplib.maths.utils', 'felt_to_uint256'); return [ - sourceElmType.signed - ? ` let (val) = warp_int${sourceElmType.nBits}_to_int${targetElmType.nBits}(ptr[0]);` - : ` let (val) = felt_to_uint256(ptr[0]);`, - ` ${this.storageWriteGen.getOrCreate(targetElmType)}(storage_loc, val);`, - ].join('\n'); - } else if (targetElmType instanceof FixedBytesType && sourceElmType instanceof FixedBytesType) { + (index, offset) => + [ + ` let (arg_${index}) = ${toUintFunc.name}(arg[${index}]);`, + ` ${writeToStorage.name}(${add('storage_loc', offset)}, arg_${index});`, + ].join('\n'), + [writeToStorage, toUintFunc], + ]; + } + + if (targetElementT instanceof FixedBytesType) { + assert(sourceElementT instanceof FixedBytesType); + const writeToStorage = this.storageWriteGen.getOrCreateFuncDef(targetElementT); + if (targetElementT.size > sourceElementT.size) { + const widenFunc = this.requireImport( + 'warplib.maths.bytes_conversions', + `warp_bytes_widen${targetElementT.size === 32 ? '_256' : ''}`, + ); + return [ + (index, offset) => + [ + ` let (arg_${index}) = ${widenFunc.name}(arg[${index}], ${ + (targetElementT.size - sourceElementT.size) * 8 + });`, + ` ${writeToStorage.name}(${add('storage_loc', offset)}, arg_${index});`, + ].join('\n'), + [writeToStorage, widenFunc], + ]; + } return [ - targetElmType.size === 32 - ? ` let (val) = warp_bytes_widen_256(ptr[0], ${ - (targetElmType.size - sourceElmType.size) * 8 - });` - : ` let (val) = warp_bytes_widen(ptr[0], ${ - (targetElmType.size - sourceElmType.size) * 8 - });`, - ` ${this.storageWriteGen.getOrCreate(targetElmType)}(storage_loc, val);`, - ].join('\n'); - } else { - return isDynamicStorageArray(targetElmType) - ? ` let (ref_name) = readId(storage_loc); - ${this.getOrCreate(targetElmType, sourceElmType)}(ref_name, ptr[0]);` - : ` ${this.getOrCreate(targetElmType, sourceElmType)}(storage_loc, ptr[0]);`; + (index, offset) => + ` ${writeToStorage.name}(${add('storage_loc', offset)}, arg[${index}]);`, + [writeToStorage], + ]; } + + const auxFunc = this.getOrCreateFuncDef(targetElementT, sourceElementT); + return [ + isDynamicArray(targetElementT) + ? (index, offset) => + [ + ` let (ref_${index}) = readId(${add('storage_loc', offset)});`, + ` ${auxFunc.name}(ref_${index}, arg[${index}]);`, + ].join('\n') + : (index, offset) => ` ${auxFunc.name}(${add('storage_loc', offset)}, arg[${index}]);`, + [auxFunc], + ]; } - addImports(targetElmType: TypeNode, sourceElmType: TypeNode): void { + private createStaticToDynamicCopyCode( + targetType: ArrayType, + sourceType: ArrayType, + ): [(index: number) => string, CairoFunctionDefinition[]] { + const targetElmType = targetType.elementT; + const sourceElmType = sourceType.elementT; + if (targetElmType instanceof IntType) { assert(sourceElmType instanceof IntType); - if (targetElmType.nBits > sourceElmType.nBits && targetElmType.signed) { - this.requireImport( + const arrayDef = this.dynArrayIndexAccessGen.getOrCreateFuncDef(targetElmType, targetType); + const writeDef = this.storageWriteGen.getOrCreateFuncDef(targetElmType); + if (targetElmType.nBits === sourceElmType.nBits) { + return [ + (index) => + [ + ` let (storage_loc${index}) = ${arrayDef.name}(ref, ${uint256(index)});`, + ` ${writeDef.name}(storage_loc${index}, arg[${index}]);`, + ].join('\n'), + [arrayDef, writeDef], + ]; + } + if (targetElmType.signed) { + const conversionFunc = this.requireImport( 'warplib.maths.int_conversions', `warp_int${sourceElmType.nBits}_to_int${targetElmType.nBits}`, ); - } else { - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); + return [ + (index) => + [ + ` let (arg_${index}) = ${conversionFunc.name}(arg[${index}]);`, + ` let (storage_loc${index}) = ${arrayDef.name}(ref, ${uint256(index)});`, + ` ${writeDef.name}(storage_loc${index}, arg_${index});`, + ].join('\n'), + [arrayDef, writeDef, conversionFunc], + ]; } - } else if (targetElmType instanceof FixedBytesType) { - this.requireImport( + const toUintFunc = this.requireImport('warplib.maths.utils', 'felt_to_uint256'); + return [ + (index) => + [ + ` let (arg_${index}) = ${toUintFunc.name}(arg[${index}]);`, + ` let (storage_loc${index}) = ${arrayDef.name}(ref, ${uint256(index)});`, + ` ${writeDef.name}(storage_loc${index}, arg_${index});`, + ].join('\n'), + [arrayDef, writeDef, toUintFunc], + ]; + } + + if (targetElmType instanceof FixedBytesType) { + assert(sourceElmType instanceof FixedBytesType); + const arrayDef = this.dynArrayIndexAccessGen.getOrCreateFuncDef(targetElmType, targetType); + const writeDef = this.storageWriteGen.getOrCreateFuncDef(targetElmType); + + if (targetElmType.size > sourceElmType.size) { + const widenFunc = this.requireImport( + 'warplib.maths.bytes_conversions', + `warp_bytes_widen${targetElmType.size === 32 ? '_256' : ''}`, + ); + const bits = (targetElmType.size - sourceElmType.size) * 8; + return [ + (index) => + [ + ` let (arg_${index}) = ${widenFunc.name}(arg[${index}], ${bits});`, + ` let (storage_loc${index}) = ${arrayDef.name}(ref, ${uint256(index)});`, + ` ${writeDef.name}(storage_loc${index}, arg_${index});`, + ].join('\n'), + [arrayDef, writeDef, widenFunc], + ]; + } + + return [ + (index) => + [ + ` let (storage_loc${index}) = ${arrayDef.name}(ref, ${uint256(index)});`, + ` ${writeDef.name}(storage_loc${index}, arg[${index}]);`, + ].join('\n'), + [arrayDef, writeDef], + ]; + } + + const sourceSize = sourceType.size; + assert(sourceSize !== undefined); + + const arrayDef = this.dynArrayIndexAccessGen.getOrCreateFuncDef(targetElmType, targetType); + const auxFunc = this.getOrCreateFuncDef(targetElmType, sourceElmType); + const [_dynArray, dynArrayLength] = this.dynArrayGen.getOrCreateFuncDef(targetElmType); + if (isDynamicArray(targetElmType)) { + return [ + (index) => + [ + ` let (storage_loc${index}) = ${arrayDef.name}(ref, ${uint256(index)});`, + ` let (ref_${index}) = readId(storage_loc${index});`, + // TODO: Potential bug here: when array size is reduced, remaining elements must be + // deleted. Investigate + ` ${dynArrayLength.name}.write(ref_${index}, ${uint256(sourceSize)});`, + ` ${auxFunc.name}(ref_${index}, arg[${index}]);`, + ].join('\n'), + [arrayDef, auxFunc, dynArrayLength], + ]; + } + + return [ + (index) => + [ + ` let (storage_loc${index}) = ${arrayDef.name}(ref, ${uint256(index)});`, + ` ${auxFunc.name}(storage_loc${index}, arg[${index}]);`, + ].join('\n'), + [arrayDef, auxFunc], + ]; + } + + private createDyamicToDynamicCopyCode( + targetType: ArrayType, + sourceType: ArrayType, + ): [() => string, CairoFunctionDefinition[]] { + const targetElmType = targetType.elementT; + const sourceElmType = sourceType.elementT; + + const writeDef = this.storageWriteGen.getOrCreateFuncDef(targetElmType); + + if (targetElmType instanceof IntType) { + assert(sourceElmType instanceof IntType); + const convertionFunc = targetElmType.signed + ? this.requireImport( + 'warplib.maths.int_conversions', + `warp_int${sourceElmType.nBits}_to_int${targetElmType.nBits}`, + ) + : this.requireImport('warplib.maths.utils', 'felt_to_uint256'); + return [ + () => + [ + sourceElmType.signed + ? ` let (val) = ${convertionFunc.name}(ptr[0]);` + : ` let (val) = felt_to_uint256(ptr[0]);`, + ` ${writeDef.name}(storage_loc, val);`, + ].join('\n'), + [writeDef, convertionFunc], + ]; + } + + if (targetElmType instanceof FixedBytesType) { + assert(sourceElmType instanceof FixedBytesType); + const widenFunc = this.requireImport( 'warplib.maths.bytes_conversions', - targetElmType.size === 32 ? 'warp_bytes_widen_256' : 'warp_bytes_widen', + `warp_bytes_widen${targetElmType.size === 32 ? '_256' : ''}`, ); + const bits = (targetElmType.size - sourceElmType.size) * 8; + return [ + () => + [ + ` let (val) = ${widenFunc.name}(ptr[0], ${bits});`, + ` ${writeDef.name}(storage_loc, val);`, + ].join('\n'), + [writeDef, widenFunc], + ]; } - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); + + const auxFunc = this.getOrCreateFuncDef(targetElmType, sourceElmType); + return [ + isDynamicArray(targetElmType) + ? () => + [`let (ref_name) = readId(storage_loc);`, `${auxFunc.name}(ref_name, ptr[0]);`].join( + '\n', + ) + : () => `${auxFunc.name}(storage_loc, ptr[0]);`, + [auxFunc], + ]; + } +} + +function checkSizes(targetType: TypeNode, sourceType: TypeNode): boolean { + const targetBaseType = getBaseType(targetType); + const sourceBaseType = getBaseType(sourceType); + if (targetBaseType instanceof IntType && sourceBaseType instanceof IntType) { + return ( + (targetBaseType.nBits > sourceBaseType.nBits && sourceBaseType.signed) || + (!targetBaseType.signed && targetBaseType.nBits === 256 && 256 > sourceBaseType.nBits) + ); + } + if (targetBaseType instanceof FixedBytesType && sourceBaseType instanceof FixedBytesType) { + return targetBaseType.size > sourceBaseType.size; + } + return false; +} + +function checkDims(targetType: TypeNode, sourceType: TypeNode): boolean { + if (targetType instanceof ArrayType && sourceType instanceof ArrayType) { + const targetArrayElm = targetType.elementT; + const sourceArrayElm = sourceType.elementT; + + if (targetType.size !== undefined && sourceType.size !== undefined) { + if (targetType.size > sourceType.size) { + return true; + } else if (targetArrayElm instanceof ArrayType && sourceArrayElm instanceof ArrayType) { + return checkDims(targetArrayElm, sourceArrayElm); + } else { + return false; + } + } else if (targetType.size === undefined && sourceType.size !== undefined) { + return true; + } else if (targetType.size === undefined && sourceType.size === undefined) + if (targetArrayElm instanceof ArrayType && sourceArrayElm instanceof ArrayType) { + return checkDims(targetArrayElm, sourceArrayElm); + } } + return false; } diff --git a/src/cairoUtilFuncGen/enumInputCheck.ts b/src/cairoUtilFuncGen/enumInputCheck.ts index 0fe07841c..eb74637c5 100644 --- a/src/cairoUtilFuncGen/enumInputCheck.ts +++ b/src/cairoUtilFuncGen/enumInputCheck.ts @@ -10,13 +10,16 @@ import { TypeNode, } from 'solc-typed-ast'; import { FunctionStubKind } from '../ast/cairoNodes'; -import { createCairoFunctionStub, createCallToFunction } from '../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../utils/functionGeneration'; import { safeGetNodeType } from '../utils/nodeTypeProcessing'; import { typeNameFromTypeNode } from '../utils/utils'; -import { StringIndexedFuncGen } from './base'; +import { GeneratedFunctionInfo, StringIndexedFuncGen } from './base'; +// TODO: Does this enum input check overrides the input check from the general method?! +// It looks like it does export class EnumInputCheck extends StringIndexedFuncGen { - gen( + // TODO: When is nodeInSourceUnit different thant the current sourceUnit?? + public gen( node: Expression, nodeInput: Expression, enumDef: EnumDefinition, @@ -26,36 +29,50 @@ export class EnumInputCheck extends StringIndexedFuncGen { const inputType = safeGetNodeType(nodeInput, this.ast.inference); this.sourceUnit = this.ast.getContainingRoot(nodeInSourceUnit); - const name = this.getOrCreate(inputType, enumDef); - const functionStub = createCairoFunctionStub( - name, + const funcDef = this.getOrCreateFuncDef(inputType, nodeType, enumDef); + return createCallToFunction(funcDef, [nodeInput], this.ast); + } + + public getOrCreateFuncDef(inputType: TypeNode, nodeType: TypeNode, enumDef: EnumDefinition) { + assert(inputType instanceof IntType); + + const key = enumDef.name + (inputType.nBits === 256 ? '256' : ''); + const exisiting = this.generatedFunctionsDef.get(key); + if (exisiting !== undefined) { + return exisiting; + } + + const funcInfo = this.getOrCreate(inputType, enumDef); + const funcDef = createCairoGeneratedFunction( + funcInfo, [['arg', typeNameFromTypeNode(inputType, this.ast), DataLocation.Default]], [['ret', typeNameFromTypeNode(nodeType, this.ast), DataLocation.Default]], - ['range_check_ptr'], this.ast, - nodeInSourceUnit ?? nodeInput, + this.sourceUnit, { mutability: FunctionStateMutability.Pure, stubKind: FunctionStubKind.FunctionDefStub, }, ); - - return createCallToFunction(functionStub, [nodeInput], this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - private getOrCreate(type: TypeNode, enumDef: EnumDefinition): string { - const key = `${enumDef.name}_${type.pp()}`; - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; + private getOrCreate(type: IntType, enumDef: EnumDefinition) { + const input256Bits = type.nBits === 256; + const funcName = `enum_bound_check_${enumDef.name}` + (input256Bits ? '_256' : ''); + + const imports = [this.requireImport('starkware.cairo.common.math_cmp', 'is_le_felt')]; + if (input256Bits) { + imports.push( + this.requireImport('warplib.maths.utils', 'narrow_safe'), + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + ); } - assert(type instanceof IntType); - const funcName = `enum_bound_check${this.generatedFunctions.size}`; const implicits = '{range_check_ptr : felt}'; const nMembers = enumDef.vMembers.length; - const input256Bits = type.nBits === 256; - this.generatedFunctions.set(key, { + const funcInfo: GeneratedFunctionInfo = { name: funcName, code: [ `func ${funcName}${implicits}(${ @@ -72,12 +89,8 @@ export class EnumInputCheck extends StringIndexedFuncGen { ` return (arg,);`, `}`, ].join('\n'), - }); - if (input256Bits) { - this.requireImport('warplib.maths.utils', 'narrow_safe'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - } - this.requireImport('starkware.cairo.common.math_cmp', 'is_le_felt'); - return funcName; + functionsCalled: imports, + }; + return funcInfo; } } diff --git a/src/cairoUtilFuncGen/event.ts b/src/cairoUtilFuncGen/event.ts index b64f483ae..632060680 100644 --- a/src/cairoUtilFuncGen/event.ts +++ b/src/cairoUtilFuncGen/event.ts @@ -8,21 +8,22 @@ import { TypeNode, } from 'solc-typed-ast'; import { - AbiEncode, AST, CairoType, - createCairoFunctionStub, createCallToFunction, isValueType, - IndexEncode, safeGetNodeType, TypeConversionContext, typeNameFromTypeNode, + createCairoGeneratedFunction, warpEventSignatureHash256FromString, EMIT_PREFIX, + CairoFunctionDefinition, } from '../export'; -import { StringIndexedFuncGen } from './base'; +import { GeneratedFunctionInfo, StringIndexedFuncGen } from './base'; import { ABIEncoderVersion } from 'solc-typed-ast/dist/types/abi'; +import { AbiEncode } from './abi/abiEncode'; +import { IndexEncode } from './abi/indexEncode'; export const BYTES_IN_FELT_PACKING = 31; const BIG_ENDIAN = 1; // 0 for little endian, used for packing of bytes (31 byte felts -> a 248 bit felt) @@ -48,37 +49,39 @@ export class EventFunction extends StringIndexedFuncGen { const argsTypes: TypeNode[] = node.vEventCall.vArguments.map( (arg) => generalizeType(safeGetNodeType(arg, this.ast.inference))[0], ); - - const funcName = this.getOrCreate(refEventDef); - - const functionStub = createCairoFunctionStub( - funcName, + const funcDef = this.getOrCreateFuncDef(refEventDef, argsTypes); + return createCallToFunction(funcDef, node.vEventCall.vArguments, this.ast); + } + private getOrCreateFuncDef(eventDef: EventDefinition, argsTypes: TypeNode[]) { + const key = `${eventDef.name}_${this.ast.inference.signatureHash( + eventDef, + ABIEncoderVersion.V2, + )}`; + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; + } + const funcInfo = this.getOrCreate(eventDef); + const funcDef = createCairoGeneratedFunction( + funcInfo, argsTypes.map((argT, index) => isValueType(argT) ? [`param${index}`, typeNameFromTypeNode(argT, this.ast)] : [`param${index}`, typeNameFromTypeNode(argT, this.ast), DataLocation.Memory], ), [], - ['bitwise_ptr', 'range_check_ptr', 'warp_memory', 'keccak_ptr'], this.ast, this.sourceUnit, ); - - return createCallToFunction(functionStub, node.vEventCall.vArguments, this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - private getOrCreate(node: EventDefinition): string { + private getOrCreate(node: EventDefinition) { // Add the canonicalSignatureHash so that generated function names don't collide when overloaded - const key = `${node.name}_${this.ast.inference.signatureHash(node, ABIEncoderVersion.V2)}`; - const existing = this.generatedFunctions.get(key); - - if (existing !== undefined) { - return existing.name; - } - - const [params, keysInsertions, dataParams, dataParamTypes] = + const [params, keysInsertions, dataParams, dataParamTypes, requiredFuncs] = node.vParameters.vParameters.reduce( - ([params, keysInsertions, dataParams, dataParamTypes], param, index) => { + ([params, keysInsertions, dataParams, dataParamTypes, requiredFuncs], param, index) => { const paramType = generalizeType(safeGetNodeType(param, this.ast.inference))[0]; const cairoType = CairoType.fromSol(paramType, this.ast, TypeConversionContext.Ref); @@ -89,16 +92,20 @@ export class EventFunction extends StringIndexedFuncGen { if (isValueType(paramType)) { // If the parameter is a value type, we can just add it to the keys array // as it is, as we do regular abi encoding - keysInsertions.push( - this.generateSimpleEncodingCode([paramType], 'keys', [`param${index}`]), - ); + const [code, calledFuncs] = this.generateSimpleEncodingCode([paramType], 'keys', [ + `param${index}`, + ]); + keysInsertions.push(code); + requiredFuncs.push(...calledFuncs); } else { // If the parameter is a reference type, we hash the with special encoding // function: more at: // https://docs.soliditylang.org/en/v0.8.14/abi-spec.html#encoding-of-indexed-event-parameters - keysInsertions.push( - this.generateComplexEncodingCode([paramType], 'keys', [`param${index}`]), - ); + const [code, calledFuncs] = this.generateComplexEncodingCode([paramType], 'keys', [ + `param${index}`, + ]); + keysInsertions.push(code); + requiredFuncs.push(...calledFuncs); } } else { // A non-indexed parameter should go to the data array @@ -106,17 +113,22 @@ export class EventFunction extends StringIndexedFuncGen { dataParamTypes.push(paramType); } - return [params, keysInsertions, dataParams, dataParamTypes]; + return [params, keysInsertions, dataParams, dataParamTypes, requiredFuncs]; }, [ new Array<{ name: string; type: string }>(), new Array(), new Array(), new Array(), + new Array(), ], ); - const dataInsertions = this.generateSimpleEncodingCode(dataParamTypes, 'data', dataParams); + const [dataInsertions, dataInsertionsCalls] = this.generateSimpleEncodingCode( + dataParamTypes, + 'data', + dataParams, + ); const cairoParams = params.map((p) => `${p.name} : ${p.type}`).join(', '); @@ -124,18 +136,20 @@ export class EventFunction extends StringIndexedFuncGen { this.ast.inference.signature(node, ABIEncoderVersion.V2), ); + const [anonymousCode, anonymousCalls] = this.generateAnonymizeCode( + node.anonymous, + topic, + this.ast.inference.signature(node, ABIEncoderVersion.V2), + ); + const suffix = `${node.name}_${this.ast.inference.signatureHash(node, ABIEncoderVersion.V2)}`; const code = [ - `func ${EMIT_PREFIX}${key}${IMPLICITS}(${cairoParams}){`, + `func ${EMIT_PREFIX}${suffix}${IMPLICITS}(${cairoParams}){`, ` alloc_locals;`, ` // keys arrays`, ` let keys_len: felt = 0;`, ` let (keys: felt*) = alloc();`, ` //Insert topic`, - this.generateAnonymizeCode( - node.anonymous, - topic, - this.ast.inference.signature(node, ABIEncoderVersion.V2), - ), + anonymousCode, ...keysInsertions, ` // keys: pack 31 byte felts into a single 248 bit felt`, ` let (keys_len: felt, keys: felt*) = pack_bytes_felt(${BYTES_IN_FELT_PACKING}, ${BIG_ENDIAN}, keys_len, keys);`, @@ -150,65 +164,93 @@ export class EventFunction extends StringIndexedFuncGen { `}`, ].join('\n'); - this.requireImport('starkware.starknet.common.syscalls', 'emit_event'); - this.requireImport('starkware.cairo.common.alloc', 'alloc'); - this.requireImport('warplib.keccak', 'pack_bytes_felt'); - this.requireImport('starkware.cairo.common.cairo_builtins', 'BitwiseBuiltin'); - - this.generatedFunctions.set(key, { name: `${EMIT_PREFIX}${key}`, code: code }); - return `${EMIT_PREFIX}${key}`; + const funcInfo: GeneratedFunctionInfo = { + name: `${EMIT_PREFIX}${suffix}`, + code: code, + functionsCalled: [ + this.requireImport('starkware.starknet.common.syscalls', 'emit_event'), + this.requireImport('starkware.cairo.common.alloc', 'alloc'), + this.requireImport('warplib.keccak', 'pack_bytes_felt'), + this.requireImport('starkware.cairo.common.cairo_builtins', 'BitwiseBuiltin'), + ...requiredFuncs, + ...dataInsertionsCalls, + ...anonymousCalls, + ], + }; + return funcInfo; } private generateAnonymizeCode( isAnonymous: boolean, topic: { low: string; high: string }, eventSig: string, - ): string { - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport(`warplib.maths.utils`, 'felt_to_uint256'); - this.requireImport('warplib.dynamic_arrays_util', 'fixed_bytes256_to_felt_dynamic_array_spl'); + ): [string, CairoFunctionDefinition[]] { if (isAnonymous) { - return [`// Event is anonymous, topic won't be added to keys`].join('\n'); + return [[`// Event is anonymous, topic won't be added to keys`].join('\n'), []]; } return [ - ` let topic256: Uint256 = Uint256(${topic.low}, ${topic.high});// keccak of event signature: ${eventSig}`, - ` let (keys_len: felt) = fixed_bytes256_to_felt_dynamic_array_spl(keys_len, keys, 0, topic256);`, - ].join('\n'); + [ + ` let topic256: Uint256 = Uint256(${topic.low}, ${topic.high});// keccak of event signature: ${eventSig}`, + ` let (keys_len: felt) = fixed_bytes256_to_felt_dynamic_array_spl(keys_len, keys, 0, topic256);`, + ].join('\n'), + [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport(`warplib.maths.utils`, 'felt_to_uint256'), + this.requireImport( + 'warplib.dynamic_arrays_util', + 'fixed_bytes256_to_felt_dynamic_array_spl', + ), + ], + ]; } private generateSimpleEncodingCode( types: TypeNode[], arrayName: string, argNames: string[], - ): string { - const abiFunc = this.abiEncode.getOrCreate(types); + ): [string, CairoFunctionDefinition[]] { + const abiFunc = this.abiEncode.getOrCreateFuncDef(types); this.requireImport('warplib.memory', 'wm_to_felt_array'); this.requireImport('warplib.keccak', 'felt_array_concat'); return [ - ` let (mem_encode: felt) = ${abiFunc}(${argNames.join(',')});`, - ` let (encode_bytes_len: felt, encode_bytes: felt*) = wm_to_felt_array(mem_encode);`, - ` let (${arrayName}_len: felt) = felt_array_concat(encode_bytes_len, 0, encode_bytes, ${arrayName}_len, ${arrayName});`, - ].join('\n'); + [ + ` let (mem_encode: felt) = ${abiFunc.name}(${argNames.join(',')});`, + ` let (encode_bytes_len: felt, encode_bytes: felt*) = wm_to_felt_array(mem_encode);`, + ` let (${arrayName}_len: felt) = felt_array_concat(encode_bytes_len, 0, encode_bytes, ${arrayName}_len, ${arrayName});`, + ].join('\n'), + [ + this.requireImport('warplib.memory', 'wm_to_felt_array'), + this.requireImport('warplib.keccak', 'felt_array_concat'), + abiFunc, + ], + ]; } private generateComplexEncodingCode( types: TypeNode[], arrayName: string, argNames: string[], - ): string { - const abiFunc = this.indexEncode.getOrCreate(types); - - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport(`warplib.maths.utils`, 'felt_to_uint256'); - this.requireImport('warplib.keccak', 'warp_keccak'); - this.requireImport('warplib.dynamic_arrays_util', 'fixed_bytes256_to_felt_dynamic_array_spl'); + ): [string, CairoFunctionDefinition[]] { + const abiFunc = this.indexEncode.getOrCreateFuncDef(types); return [ - ` let (mem_encode: felt) = ${abiFunc}(${argNames.join(',')});`, - ` let (keccak_hash256: Uint256) = warp_keccak(mem_encode);`, - ` let (${arrayName}_len: felt) = fixed_bytes256_to_felt_dynamic_array_spl(${arrayName}_len, ${arrayName}, 0, keccak_hash256);`, - ].join('\n'); + [ + ` let (mem_encode: felt) = ${abiFunc.name}(${argNames.join(',')});`, + ` let (keccak_hash256: Uint256) = warp_keccak(mem_encode);`, + ` let (${arrayName}_len: felt) = fixed_bytes256_to_felt_dynamic_array_spl(${arrayName}_len, ${arrayName}, 0, keccak_hash256);`, + ].join('\n'), + [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport(`warplib.maths.utils`, 'felt_to_uint256'), + this.requireImport('warplib.keccak', 'warp_keccak'), + this.requireImport( + 'warplib.dynamic_arrays_util', + 'fixed_bytes256_to_felt_dynamic_array_spl', + ), + abiFunc, + ], + ]; } } diff --git a/src/cairoUtilFuncGen/index.ts b/src/cairoUtilFuncGen/index.ts index 3cad8bc42..f6848e964 100644 --- a/src/cairoUtilFuncGen/index.ts +++ b/src/cairoUtilFuncGen/index.ts @@ -1,6 +1,4 @@ import { AST } from '../ast/ast'; -import { mergeImports } from '../utils/utils'; -import { CairoUtilFuncGenBase } from './base'; import { InputCheckGen } from './inputArgCheck/inputCheck'; import { MemoryArrayLiteralGen } from './memory/arrayLiteral'; import { MemoryDynArrayLengthGen } from './memory/memoryDynArrayLength'; @@ -11,7 +9,6 @@ import { MemoryWriteGen } from './memory/memoryWrite'; import { MemoryStaticArrayIndexAccessGen } from './memory/staticIndexAccess'; import { DynArrayGen } from './storage/dynArray'; import { DynArrayIndexAccessGen } from './storage/dynArrayIndexAccess'; -import { DynArrayLengthGen } from './storage/dynArrayLength'; import { DynArrayPopGen } from './storage/dynArrayPop'; import { DynArrayPushWithArgGen } from './storage/dynArrayPushWithArg'; import { DynArrayPushWithoutArgGen } from './storage/dynArrayPushWithoutArg'; @@ -40,8 +37,8 @@ import { AbiEncodePacked } from './abi/abiEncodePacked'; import { AbiEncodeWithSelector } from './abi/abiEncodeWithSelector'; import { AbiEncodeWithSignature } from './abi/abiEncodeWithSignature'; import { AbiDecode } from './abi/abiDecode'; -import { EventFunction } from '../export'; import { IndexEncode } from './abi/indexEncode'; +import { EventFunction } from './event'; export class CairoUtilFuncGen { abi: { @@ -72,8 +69,8 @@ export class CairoUtilFuncGen { }; storage: { delete: StorageDeleteGen; + dynArray: DynArrayGen; dynArrayIndexAccess: DynArrayIndexAccessGen; - dynArrayLength: DynArrayLengthGen; dynArrayPop: DynArrayPopGen; dynArrayPush: { withArg: DynArrayPushWithArgGen; @@ -100,56 +97,31 @@ export class CairoUtilFuncGen { encodeAsFelt: EncodeAsFelt; }; - private implementation: { - dynArray: DynArrayGen; - }; - constructor(ast: AST, sourceUnit: SourceUnit) { - this.implementation = { - dynArray: new DynArrayGen(ast, sourceUnit), - }; - + const dynArray = new DynArrayGen(ast, sourceUnit); + const memoryRead = new MemoryReadGen(ast, sourceUnit); const storageReadGen = new StorageReadGen(ast, sourceUnit); - const storageDelete = new StorageDeleteGen( - this.implementation.dynArray, - storageReadGen, - ast, - sourceUnit, - ); + const storageDelete = new StorageDeleteGen(dynArray, storageReadGen, ast, sourceUnit); const memoryToStorage = new MemoryToStorageGen( - this.implementation.dynArray, + dynArray, + memoryRead, storageDelete, ast, sourceUnit, ); const storageWrite = new StorageWriteGen(ast, sourceUnit); - const storageToStorage = new StorageToStorageGen( - this.implementation.dynArray, - storageDelete, - ast, - sourceUnit, - ); - const calldataToStorage = new CalldataToStorageGen( - this.implementation.dynArray, - storageWrite, - ast, - sourceUnit, - ); + const storageToStorage = new StorageToStorageGen(dynArray, storageDelete, ast, sourceUnit); + const calldataToStorage = new CalldataToStorageGen(dynArray, storageWrite, ast, sourceUnit); const externalDynArrayStructConstructor = new ExternalDynArrayStructConstructor( ast, sourceUnit, ); - const memoryRead = new MemoryReadGen(ast, sourceUnit); const memoryWrite = new MemoryWriteGen(ast, sourceUnit); - const storageDynArrayIndexAccess = new DynArrayIndexAccessGen( - this.implementation.dynArray, - ast, - sourceUnit, - ); + const storageDynArrayIndexAccess = new DynArrayIndexAccessGen(dynArray, ast, sourceUnit); const callDataConvert = new ImplicitArrayConversion( storageWrite, - this.implementation.dynArray, + dynArray, storageDynArrayIndexAccess, ast, sourceUnit, @@ -163,18 +135,23 @@ export class CairoUtilFuncGen { read: memoryRead, staticArrayIndexAccess: new MemoryStaticArrayIndexAccessGen(ast, sourceUnit), struct: new MemoryStructGen(ast, sourceUnit), - toCallData: new MemoryToCallDataGen(externalDynArrayStructConstructor, ast, sourceUnit), + toCallData: new MemoryToCallDataGen( + externalDynArrayStructConstructor, + memoryRead, + ast, + sourceUnit, + ), toStorage: memoryToStorage, write: memoryWrite, }; this.storage = { delete: storageDelete, + dynArray: dynArray, dynArrayIndexAccess: storageDynArrayIndexAccess, - dynArrayLength: new DynArrayLengthGen(this.implementation.dynArray, ast, sourceUnit), - dynArrayPop: new DynArrayPopGen(this.implementation.dynArray, storageDelete, ast, sourceUnit), + dynArrayPop: new DynArrayPopGen(dynArray, storageDelete, ast, sourceUnit), dynArrayPush: { withArg: new DynArrayPushWithArgGen( - this.implementation.dynArray, + dynArray, storageWrite, memoryToStorage, storageToStorage, @@ -183,20 +160,20 @@ export class CairoUtilFuncGen { ast, sourceUnit, ), - withoutArg: new DynArrayPushWithoutArgGen(this.implementation.dynArray, ast, sourceUnit), + withoutArg: new DynArrayPushWithoutArgGen(dynArray, ast, sourceUnit), }, - mappingIndexAccess: new MappingIndexAccessGen(this.implementation.dynArray, ast, sourceUnit), + mappingIndexAccess: new MappingIndexAccessGen(dynArray, ast, sourceUnit), memberAccess: new StorageMemberAccessGen(ast, sourceUnit), read: storageReadGen, staticArrayIndexAccess: new StorageStaticArrayIndexAccessGen(ast, sourceUnit), toCallData: new StorageToCalldataGen( - this.implementation.dynArray, + dynArray, storageReadGen, externalDynArrayStructConstructor, ast, sourceUnit, ), - toMemory: new StorageToMemoryGen(this.implementation.dynArray, ast, sourceUnit), + toMemory: new StorageToMemoryGen(dynArray, ast, sourceUnit), toStorage: storageToStorage, write: storageWrite, }; @@ -228,33 +205,4 @@ export class CairoUtilFuncGen { encodeAsFelt: new EncodeAsFelt(externalDynArrayStructConstructor, ast, sourceUnit), }; } - - getImports(): Map> { - return mergeImports(...this.getAllChildren().map((c) => c.getImports())); - } - getGeneratedCode(): string { - return this.getAllChildren() - .map((c) => c.getGeneratedCode()) - .sort((a, b) => { - // This sort is needed to make sure the structs generated from CairoUtilGen are before the generated functions that - // reference them. This sort is also order preserving in that it will only make sure the structs come before - // any functions and not sort the struct/functions within their respective groups. - if (a.slice(0, 1) < b.slice(0, 1)) { - return 1; - } else if (a.slice(0, 1) > b.slice(0, 1)) { - return -1; - } - return 0; - }) - .join('\n\n'); - } - private getAllChildren(): CairoUtilFuncGenBase[] { - return getAllGenerators(this); - } -} - -function getAllGenerators(container: unknown): CairoUtilFuncGenBase[] { - if (typeof container !== 'object' || container === null) return []; - if (container instanceof CairoUtilFuncGenBase) return [container]; - return Object.values(container).flatMap(getAllGenerators); } diff --git a/src/cairoUtilFuncGen/inputArgCheck/inputCheck.ts b/src/cairoUtilFuncGen/inputArgCheck/inputCheck.ts index 745e3e7d5..90d7ecda1 100644 --- a/src/cairoUtilFuncGen/inputArgCheck/inputCheck.ts +++ b/src/cairoUtilFuncGen/inputArgCheck/inputCheck.ts @@ -1,11 +1,8 @@ import assert from 'assert'; import { - AddressType, ArrayType, - ASTNode, BoolType, BytesType, - ContractDefinition, DataLocation, EnumDefinition, Expression, @@ -20,28 +17,32 @@ import { UserDefinedType, VariableDeclaration, } from 'solc-typed-ast'; -import { FunctionStubKind } from '../../ast/cairoNodes'; +import { CairoFunctionDefinition, FunctionStubKind } from '../../ast/cairoNodes'; import { printTypeNode } from '../../utils/astPrinter'; import { CairoDynArray, CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { NotSupportedYetError } from '../../utils/errors'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { createIdentifier } from '../../utils/nodeTemplates'; import { mapRange, narrowBigIntSafe, typeNameFromTypeNode } from '../../utils/utils'; -import { delegateBasedOnType, locationIfComplexType, StringIndexedFuncGen } from '../base'; +import { + delegateBasedOnType, + GeneratedFunctionInfo, + locationIfComplexType, + StringIndexedFuncGen, +} from '../base'; import { checkableType, getElementType, + isAddressType, isDynamicArray, safeGetNodeType, } from '../../utils/nodeTypeProcessing'; import { cloneASTNode } from '../../utils/cloning'; +const IMPLICITS = '{range_check_ptr : felt}'; + export class InputCheckGen extends StringIndexedFuncGen { - gen( - nodeInput: VariableDeclaration | Expression, - typeToCheck: TypeNode, - nodeInSourceUnit: ASTNode, - ): FunctionCall { + public gen(nodeInput: VariableDeclaration | Expression, typeToCheck: TypeNode): FunctionCall { let functionInput; let isUint256 = false; if (nodeInput instanceof VariableDeclaration) { @@ -51,166 +52,164 @@ export class InputCheckGen extends StringIndexedFuncGen { const inputType = safeGetNodeType(nodeInput, this.ast.inference); this.ast.setContextRecursive(functionInput); isUint256 = inputType instanceof IntType && inputType.nBits === 256; - this.requireImport('warplib.maths.utils', 'narrow_safe'); } - this.sourceUnit = this.ast.getContainingRoot(nodeInSourceUnit); - const name = this.getOrCreate(typeToCheck, isUint256); - const functionStub = createCairoFunctionStub( - name, + + const funcDef = this.getOrCreateFuncDef(typeToCheck, isUint256); + return createCallToFunction(funcDef, [functionInput], this.ast); + } + + private getOrCreateFuncDef(type: TypeNode, takesUint256: boolean): CairoFunctionDefinition { + const key = type.pp(); + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; + } + + if (type instanceof FixedBytesType) + return this.requireImport( + 'warplib.maths.external_input_check_ints', + `warp_external_input_check_int${type.size * 8}`, + ); + if (type instanceof IntType) + return this.requireImport( + 'warplib.maths.external_input_check_ints', + `warp_external_input_check_int${type.nBits}`, + ); + if (isAddressType(type)) + return this.requireImport( + 'warplib.maths.external_input_check_address', + `warp_external_input_check_address`, + ); + if (type instanceof BoolType) + return this.requireImport( + 'warplib.maths.external_input_check_bool', + `warp_external_input_check_bool`, + ); + + const funcInfo = this.getOrCreate(type, takesUint256); + const funcDef = createCairoGeneratedFunction( + funcInfo, [ [ 'ref_var', - typeNameFromTypeNode(typeToCheck, this.ast), - locationIfComplexType(typeToCheck, DataLocation.CallData), + typeNameFromTypeNode(type, this.ast), + locationIfComplexType(type, DataLocation.CallData), ], ], [], - ['range_check_ptr'], this.ast, - nodeInSourceUnit ?? nodeInput, + this.sourceUnit, { mutability: FunctionStateMutability.Pure, stubKind: FunctionStubKind.FunctionDefStub, - acceptsRawDArray: isDynamicArray(typeToCheck), + acceptsRawDArray: isDynamicArray(type), }, ); - return createCallToFunction(functionStub, [functionInput], this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - private getOrCreate(type: TypeNode, takesUint = false): string { - const key = type.pp(); - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - + private getOrCreate(type: TypeNode, takesUint: boolean): GeneratedFunctionInfo { const unexpectedTypeFunc = () => { throw new NotSupportedYetError(`Input check for ${printTypeNode(type)} not defined yet.`); }; - return delegateBasedOnType( + return delegateBasedOnType( type, - (type) => this.createDynArrayInputCheck(key, this.generateFuncName(key), type), - (type) => this.createStaticArrayInputCheck(key, this.generateFuncName(key), type), - (type) => this.createStructInputCheck(key, this.generateFuncName(key), type), + (type) => this.createDynArrayInputCheck(type), + (type) => this.createStaticArrayInputCheck(type), + (type, def) => this.createStructInputCheck(type, def), unexpectedTypeFunc, (type) => { - if (type instanceof FixedBytesType) { - return this.createIntInputCheck(type.size * 8); - } else if (type instanceof IntType) { - return this.createIntInputCheck(type.nBits); - } else if (type instanceof BoolType) { - return this.createBoolInputCheck(); - } else if (type instanceof UserDefinedType && type.definition instanceof EnumDefinition) { - return this.createEnumInputCheck(key, type, takesUint); - } else if ( - type instanceof AddressType || - (type instanceof UserDefinedType && type.definition instanceof ContractDefinition) - ) { - return this.createAddressInputCheck(); - } else { - return unexpectedTypeFunc(); - } + if (type instanceof UserDefinedType && type.definition instanceof EnumDefinition) + return this.createEnumInputCheck(type, takesUint); + return unexpectedTypeFunc(); }, ); } - private generateFuncName(key: string): string { - const funcName = `extern_input_check${this.generatedFunctions.size}`; - this.generatedFunctions.set(key, { name: funcName, code: '' }); - return funcName; - } - - private createIntInputCheck(bitWidth: number): string { - const funcName = `warp_external_input_check_int${bitWidth}`; - this.requireImport( - 'warplib.maths.external_input_check_ints', - `warp_external_input_check_int${bitWidth}`, - ); - return funcName; - } + private createStructInputCheck( + type: UserDefinedType, + structDef: StructDefinition, + ): GeneratedFunctionInfo { + const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); - private createAddressInputCheck(): string { - const funcName = 'warp_external_input_check_address'; - this.requireImport( - 'warplib.maths.external_input_check_address', - `warp_external_input_check_address`, + const [inputCheckCode, funcCalls] = structDef.vMembers.reduce( + ([inputCheckCode, funcCalls], decl) => { + const memberType = safeGetNodeType(decl, this.ast.inference); + if (checkableType(memberType)) { + const memberCheckFunc = this.getOrCreateFuncDef(memberType, false); + return [ + [...inputCheckCode, `${memberCheckFunc.name}(arg.${decl.name});`], + [...funcCalls, memberCheckFunc], + ]; + } + return [inputCheckCode, funcCalls]; + }, + [new Array(), new Array()], ); - return funcName; - } - private createStructInputCheck(key: string, funcName: string, type: UserDefinedType): string { - const implicits = '{range_check_ptr : felt}'; - - const structDef = type.definition; - assert(structDef instanceof StructDefinition); - const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); - - this.generatedFunctions.set(key, { + const funcName = `external_input_check_struct_${structDef.name}`; + const funcInfo: GeneratedFunctionInfo = { name: funcName, code: [ - `func ${funcName}${implicits}(arg : ${cairoType.toString()}) -> (){`, + `func ${funcName}${IMPLICITS}(arg : ${cairoType.toString()}) -> (){`, `alloc_locals;`, - ...structDef.vMembers.map((decl) => { - const memberType = safeGetNodeType(decl, this.ast.inference); - this.checkForImport(memberType); - if (checkableType(memberType)) { - const memberCheck = this.getOrCreate(memberType); - return [`${memberCheck}(arg.${decl.name});`]; - } else { - return ''; - } - }), + ...inputCheckCode, `return ();`, `}`, ].join('\n'), - }); - return funcName; + functionsCalled: funcCalls, + }; + return funcInfo; } - private createStaticArrayInputCheck(key: string, funcName: string, type: ArrayType): string { - const implicits = '{range_check_ptr : felt}'; - + // Todo: This function can probably be made recursive for big size static arrays + private createStaticArrayInputCheck(type: ArrayType): GeneratedFunctionInfo { assert(type.size !== undefined); const length = narrowBigIntSafe(type.size); - assert(length !== undefined); const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); const elementType = generalizeType(type.elementT)[0]; - this.checkForImport(elementType); - this.generatedFunctions.set(key, { + + const auxFunc = this.getOrCreateFuncDef(elementType, false); + + const funcName = `external_input_check_static_array${this.generatedFunctionsDef.size}`; + const funcInfo: GeneratedFunctionInfo = { name: funcName, code: [ - `func ${funcName}${implicits}(arg : ${cairoType.toString()}) -> (){`, + `func ${funcName}${IMPLICITS}(arg : ${cairoType.toString()}) -> (){`, `alloc_locals;`, ...mapRange(length, (index) => { - const indexCheck = this.getOrCreate(elementType); - return [`${indexCheck}(arg[${index}]);`]; + return [`${auxFunc.name}(arg[${index}]);`]; }), `return ();`, `}`, ].join('\n'), - }); - return funcName; - } - - private createBoolInputCheck(): string { - const funcName = `warp_external_input_check_bool`; - this.requireImport('warplib.maths.external_input_check_bool', `warp_external_input_check_bool`); - return funcName; + functionsCalled: [auxFunc], + }; + return funcInfo; } - private createEnumInputCheck(key: string, type: UserDefinedType, takesUint = false): string { - const funcName = `extern_input_check${this.generatedFunctions.size}`; - const implicits = '{range_check_ptr : felt}'; - + // TODO: this function and EnumInputCheck single file do the same??? + // TODO: When does takesUint == true? + private createEnumInputCheck(type: UserDefinedType, takesUint = false): GeneratedFunctionInfo { const enumDef = type.definition; assert(enumDef instanceof EnumDefinition); + + // TODO: enum names are unique right? + const funcName = `external_input_check_enum_${enumDef.name}`; + + const importFuncs = [this.requireImport('starkware.cairo.common.math_cmp', 'is_le_felt')]; + if (takesUint) { + importFuncs.push(this.requireImport('warplib.maths.utils', 'narrow_safe')); + } + const nMembers = enumDef.vMembers.length; - this.generatedFunctions.set(key, { + const funcInfo: GeneratedFunctionInfo = { name: funcName, code: [ - `func ${funcName}${implicits}(arg : ${takesUint ? 'Uint256' : 'felt'}) -> (){`, + `func ${funcName}${IMPLICITS}(arg : ${takesUint ? 'Uint256' : 'felt'}) -> (){`, takesUint ? [ ' let (arg_0) = narrow_safe(arg);', @@ -225,39 +224,38 @@ export class InputCheckGen extends StringIndexedFuncGen { ` return ();`, `}`, ].join('\n'), - }); - this.requireImport('starkware.cairo.common.math_cmp', 'is_le_felt'); - return funcName; + functionsCalled: importFuncs, + }; + return funcInfo; } private createDynArrayInputCheck( - key: string, - funcName: string, type: ArrayType | BytesType | StringType, - ): string { - const implicits = '{range_check_ptr : felt}'; - + ): GeneratedFunctionInfo { const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); assert(cairoType instanceof CairoDynArray); + const ptrType = cairoType.vPtr; const elementType = generalizeType(getElementType(type))[0]; - this.checkForImport(elementType); - const indexCheck = [`${this.getOrCreate(elementType)}(ptr[0]);`]; - this.generatedFunctions.set(key, { + const calledFunction = this.getOrCreateFuncDef(elementType, false); + + const funcName = `external_input_check_dynamic_array${this.generatedFunctionsDef.size}`; + const funcInfo: GeneratedFunctionInfo = { name: funcName, code: [ - `func ${funcName}${implicits}(len: felt, ptr : ${ptrType.toString()}) -> (){`, + `func ${funcName}${IMPLICITS}(len: felt, ptr : ${ptrType.toString()}) -> (){`, ` alloc_locals;`, ` if (len == 0){`, ` return ();`, ` }`, - ...indexCheck, - ` ${funcName}(len = len - 1, ptr = ptr + ${ptrType.to.width});`, + ` ${calledFunction.name}(ptr[0]);`, + ` ${funcName}(len = len - 1, ptr = ptr + ${ptrType.to.width});`, ` return ();`, `}`, ].join('\n'), - }); - return funcName; + functionsCalled: [calledFunction], + }; + return funcInfo; } } diff --git a/src/cairoUtilFuncGen/memory/arrayConcat.ts b/src/cairoUtilFuncGen/memory/arrayConcat.ts index bf215d13f..9f28eec66 100644 --- a/src/cairoUtilFuncGen/memory/arrayConcat.ts +++ b/src/cairoUtilFuncGen/memory/arrayConcat.ts @@ -1,101 +1,104 @@ import assert from 'assert'; import { + BytesType, DataLocation, FixedBytesType, FunctionCall, + generalizeType, IntType, - PointerType, + isReferenceType, + StringType, TypeName, TypeNode, } from 'solc-typed-ast'; -import { printNode, printTypeNode } from '../../utils/astPrinter'; +import { CairoImportFunctionDefinition } from '../../ast/cairoNodes'; +import { createBytesTypeName, createStringTypeName } from '../../export'; +import { printTypeNode } from '../../utils/astPrinter'; import { CairoType } from '../../utils/cairoTypeSystem'; import { TranspileFailedError } from '../../utils/errors'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; -import { Implicits } from '../../utils/implicits'; -import { isDynamicArray, safeGetNodeType } from '../../utils/nodeTypeProcessing'; +import { + createCairoGeneratedFunction, + createCallToFunction, + ParameterInfo, +} from '../../utils/functionGeneration'; +import { safeGetNodeType } from '../../utils/nodeTypeProcessing'; import { mapRange, typeNameFromTypeNode } from '../../utils/utils'; import { getIntOrFixedByteBitWidth, uint256 } from '../../warplib/utils'; -import { CairoFunction, StringIndexedFuncGen } from '../base'; +import { GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; export class MemoryArrayConcat extends StringIndexedFuncGen { - gen(concat: FunctionCall) { - const args = concat.vArguments; - args.forEach((expr) => { - const exprType = safeGetNodeType(expr, this.ast.inference); - if ( - !isDynamicArray(exprType) && - !(exprType instanceof IntType || exprType instanceof FixedBytesType) - ) - throw new TranspileFailedError( - `Unexpected type ${printTypeNode(exprType)} in ${printNode(expr)} to concatenate.` + - 'Expected FixedBytes, IntType, ArrayType, BytesType, or StringType', - ); - }); - - const inputs: [string, TypeName, DataLocation][] = mapRange(args.length, (n) => [ - `arg_${n}`, - typeNameFromTypeNode(safeGetNodeType(args[n], this.ast.inference), this.ast), - DataLocation.Memory, - ]); - const output: [string, TypeName, DataLocation] = [ - 'res_loc', - typeNameFromTypeNode(safeGetNodeType(concat, this.ast.inference), this.ast), - DataLocation.Memory, - ]; - - const argTypes = args.map((e) => safeGetNodeType(e, this.ast.inference)); - const name = this.getOrCreate(argTypes); - - const implicits: Implicits[] = argTypes.some( - (type) => type instanceof IntType || type instanceof FixedBytesType, - ) - ? ['bitwise_ptr', 'range_check_ptr', 'warp_memory'] - : ['range_check_ptr', 'warp_memory']; - - const functionStub = createCairoFunctionStub( - name, - inputs, - [output], - implicits, - this.ast, - concat, + public gen(concat: FunctionCall) { + const argTypes = concat.vArguments.map( + (expr) => generalizeType(safeGetNodeType(expr, this.ast.inference))[0], ); - return createCallToFunction(functionStub, args, this.ast); + const funcDef = this.getOrCreateFuncDef(argTypes); + return createCallToFunction(funcDef, concat.vArguments, this.ast); } - private getOrCreate(argTypes: TypeNode[]): string { + public getOrCreateFuncDef(argTypes: TypeNode[]) { + // TODO: Check for hex"" and unicode"" which are treated as bytes instead of strings?! + const validArgs = argTypes.every( + (type) => type instanceof BytesType || type instanceof FixedBytesType || StringType, + ); + assert( + validArgs, + `Concat arguments must be all of string, bytes or fixed bytes type. Instead of: ${argTypes.map( + (t) => printTypeNode(t), + )}`, + ); + const key = argTypes + // TODO: Wouldn't type.pp() work here? .map((type) => { - if (type instanceof PointerType) return 'A'; + if (isReferenceType(type)) return 'A'; return `B${getIntOrFixedByteBitWidth(type)}`; }) .join(''); - - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; } + const inputs: ParameterInfo[] = argTypes.map((arg, index) => [ + `arg_${index}`, + typeNameFromTypeNode(arg, this.ast), + DataLocation.Memory, + ]); + + const outputTypeName: TypeName = argTypes.some((t) => t instanceof StringType) + ? createStringTypeName(this.ast) + : createBytesTypeName(this.ast); + const output: ParameterInfo = ['res_loc', outputTypeName, DataLocation.Memory]; + + const funcInfo = this.getOrCreate(argTypes); + const funcDef = createCairoGeneratedFunction( + funcInfo, + inputs, + [output], + this.ast, + this.sourceUnit, + ); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; + } + + private getOrCreate(argTypes: TypeNode[]): GeneratedFunctionInfo { const implicits = argTypes.some( (type) => type instanceof IntType || type instanceof FixedBytesType, ) ? '{bitwise_ptr : BitwiseBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}' : '{range_check_ptr : felt, warp_memory : DictAccess*}'; - const cairoFunc = this.generateBytesConcat(argTypes, implicits); - this.generatedFunctions.set(key, cairoFunc); - return cairoFunc.name; + const funcInfo = this.generateBytesConcat(argTypes, implicits); + return funcInfo; } - private generateBytesConcat(argTypes: TypeNode[], implicits: string): CairoFunction { + private generateBytesConcat(argTypes: TypeNode[], implicits: string): GeneratedFunctionInfo { const argAmount = argTypes.length; - const funcName = `concat${this.generatedFunctions.size}_${argAmount}`; + const funcName = `concat${this.generatedFunctionsDef.size}_${argAmount}`; if (argAmount === 0) { - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('warplib.memory', 'wm_new'); return { name: funcName, code: [ @@ -105,6 +108,10 @@ export class MemoryArrayConcat extends StringIndexedFuncGen { ` return (res_loc,);`, `}`, ].join('\n'), + functionsCalled: [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('warplib.memory', 'wm_new'), + ], }; } @@ -112,51 +119,82 @@ export class MemoryArrayConcat extends StringIndexedFuncGen { const cairoType = CairoType.fromSol(type, this.ast).toString(); return `arg_${index} : ${cairoType}`; }); + + const [argSizes, argSizesImports] = argTypes + .map((t, n) => this.getSize(t, n)) + .reduce(([argSizes, argSizesImports], [sizeCode, sizeImport]) => { + return [`${argSizes}\n${sizeCode}`, [...argSizesImports, ...sizeImport]]; + }); + + const [concatCode, concatImports] = argTypes.reduce( + ([concatCode, concatImports], argType, index) => { + const [copyCode, copyImport] = this.getCopyFunctionCall(argType, index); + const fullCopyCode = [ + `let end_loc = start_loc + size_${index};`, + copyCode, + `let start_loc = end_loc;`, + ]; + return [ + [ + ...concatCode, + index < argTypes.length - 1 + ? fullCopyCode.join('\n') + : fullCopyCode.slice(0, -1).join('\n'), + ], + [...concatImports, copyImport], + ]; + }, + [new Array(), new Array()], + ); + const code = [ `func ${funcName}${implicits}(${cairoArgs}) -> (res_loc : felt){`, ` alloc_locals;`, ` // Get all sizes`, - ...argTypes.map((t, n) => this.getSize(t, n)), + argSizes, ` let total_length = ${mapRange(argAmount, (n) => `size_${n}`).join('+')};`, ` let (total_length256) = felt_to_uint256(total_length);`, ` let (res_loc) = wm_new(total_length256, ${uint256(1)});`, ` // Copy values`, ` let start_loc = 0;`, - ...mapRange(argAmount, (n) => { - const copy = [ - `let end_loc = start_loc + size_${n};`, - this.getCopyFunctionCall(argTypes[n], n), - `let start_loc = end_loc;`, - ]; - return n < argAmount - 1 ? copy.join('\n') : copy.slice(0, -1).join('\n'); - }), + ...concatCode, ` return (res_loc,);`, `}`, ].join('\n'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - this.requireImport('warplib.memory', 'wm_new'); - - return { name: funcName, code: code }; + return { + name: funcName, + code: code, + functionsCalled: [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + this.requireImport('warplib.memory', 'wm_new'), + ...argSizesImports, + ...concatImports, + ], + }; } - private getSize(type: TypeNode, index: number): string { - if (type instanceof PointerType) { - this.requireImport('warplib.memory', 'wm_dyn_array_length'); - this.requireImport('warplib.maths.utils', 'narrow_safe'); + private getSize(type: TypeNode, index: number): [string, CairoImportFunctionDefinition[]] { + if (type instanceof StringType || type instanceof BytesType) { return [ - `let (size256_${index}) = wm_dyn_array_length(arg_${index});`, - `let (size_${index}) = narrow_safe(size256_${index});`, - ].join('\n'); + [ + `let (size256_${index}) = wm_dyn_array_length(arg_${index});`, + `let (size_${index}) = narrow_safe(size256_${index});`, + ].join('\n'), + [ + this.requireImport('warplib.memory', 'wm_dyn_array_length'), + this.requireImport('warplib.maths.utils', 'narrow_safe'), + ], + ]; } if (type instanceof IntType) { - return `let size_${index} = ${type.nBits / 8};`; + return [`let size_${index} = ${type.nBits / 8};`, []]; } if (type instanceof FixedBytesType) { - return `let size_${index} = ${type.size};`; + return [`let size_${index} = ${type.size};`, []]; } throw new TranspileFailedError( @@ -164,19 +202,28 @@ export class MemoryArrayConcat extends StringIndexedFuncGen { ); } - private getCopyFunctionCall(type: TypeNode, index: number): string { - if (type instanceof PointerType) { - this.requireImport('warplib.dynamic_arrays_util', 'dynamic_array_copy_felt'); - return `dynamic_array_copy_felt(res_loc, start_loc, end_loc, arg_${index}, 0);`; + private getCopyFunctionCall( + type: TypeNode, + index: number, + ): [string, CairoImportFunctionDefinition] { + if (type instanceof StringType || type instanceof BytesType) { + return [ + `dynamic_array_copy_felt(res_loc, start_loc, end_loc, arg_${index}, 0);`, + this.requireImport('warplib.dynamic_arrays_util', 'dynamic_array_copy_felt'), + ]; } assert(type instanceof FixedBytesType); if (type.size < 32) { - this.requireImport('warplib.dynamic_arrays_util', 'fixed_bytes_to_dynamic_array'); - return `fixed_bytes_to_dynamic_array(res_loc, start_loc, end_loc, arg_${index}, 0, size_${index});`; + return [ + `fixed_bytes_to_dynamic_array(res_loc, start_loc, end_loc, arg_${index}, 0, size_${index});`, + this.requireImport('warplib.dynamic_arrays_util', 'fixed_bytes_to_dynamic_array'), + ]; } - this.requireImport('warplib.dynamic_arrays_util', 'fixed_bytes256_to_dynamic_array'); - return `fixed_bytes256_to_dynamic_array(res_loc, start_loc, end_loc, arg_${index}, 0);`; + return [ + `fixed_bytes256_to_dynamic_array(res_loc, start_loc, end_loc, arg_${index}, 0);`, + this.requireImport('warplib.dynamic_arrays_util', 'fixed_bytes256_to_dynamic_array'), + ]; } } diff --git a/src/cairoUtilFuncGen/memory/arrayLiteral.ts b/src/cairoUtilFuncGen/memory/arrayLiteral.ts index 45f1eb525..38a51a794 100644 --- a/src/cairoUtilFuncGen/memory/arrayLiteral.ts +++ b/src/cairoUtilFuncGen/memory/arrayLiteral.ts @@ -3,11 +3,11 @@ import { ArrayType, BytesType, DataLocation, - FixedBytesType, FunctionCall, generalizeType, Literal, LiteralKind, + StringLiteralType, StringType, TupleExpression, TupleType, @@ -16,8 +16,8 @@ import { import { printNode } from '../../utils/astPrinter'; import { CairoType } from '../../utils/cairoTypeSystem'; import { cloneASTNode } from '../../utils/cloning'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; -import { createNumberLiteral, createStringTypeName } from '../../utils/nodeTemplates'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; +import { createNumberLiteral } from '../../utils/nodeTemplates'; import { getElementType, getSize, @@ -27,7 +27,7 @@ import { import { notNull } from '../../utils/typeConstructs'; import { mapRange, narrowBigIntSafe, typeNameFromTypeNode } from '../../utils/utils'; import { uint256 } from '../../warplib/utils'; -import { add, locationIfComplexType, StringIndexedFuncGen } from '../base'; +import { add, GeneratedFunctionInfo, locationIfComplexType, StringIndexedFuncGen } from '../base'; /* Converts [a,b,c] and "abc" into WM0_arr(a,b,c), which allocates new space in warp_memory @@ -35,7 +35,7 @@ import { add, locationIfComplexType, StringIndexedFuncGen } from '../base'; start of the array */ export class MemoryArrayLiteralGen extends StringIndexedFuncGen { - stringGen(node: Literal): FunctionCall { + public stringGen(node: Literal): FunctionCall { // Encode the literal to the uint-8 byte representation assert( node.kind === LiteralKind.String || @@ -44,21 +44,11 @@ export class MemoryArrayLiteralGen extends StringIndexedFuncGen { ); const size = node.hexValue.length / 2; - const baseType = new FixedBytesType(1); - const baseTypeName = typeNameFromTypeNode(baseType, this.ast); - const name = this.getOrCreate(baseType, size, true); - - const stub = createCairoFunctionStub( - name, - mapRange(size, (n) => [`e${n}`, cloneASTNode(baseTypeName, this.ast), DataLocation.Default]), - [['arr', createStringTypeName(false, this.ast), DataLocation.Memory]], - ['range_check_ptr', 'warp_memory'], - this.ast, - node, - ); + const type = generalizeType(safeGetNodeType(node, this.ast.inference))[0]; + const funcDef = this.getOrCreateFuncDef(type, size); return createCallToFunction( - stub, + funcDef, mapRange(size, (n) => createNumberLiteral(parseInt(node.hexValue.slice(2 * n, 2 * n + 2), 16), this.ast), ), @@ -66,7 +56,7 @@ export class MemoryArrayLiteralGen extends StringIndexedFuncGen { ); } - tupleGen(node: TupleExpression): FunctionCall { + public tupleGen(node: TupleExpression): FunctionCall { const elements = node.vOriginalComponents.filter(notNull); assert(elements.length === node.vOriginalComponents.length); @@ -78,47 +68,55 @@ export class MemoryArrayLiteralGen extends StringIndexedFuncGen { type instanceof StringType, ); - const elementT = getElementType(type); - const wideSize = getSize(type); const size = wideSize !== undefined ? narrowBigIntSafe(wideSize, `${printNode(node)} too long to process`) : elements.length; - const name = this.getOrCreate(elementT, size, isDynamicArray(type)); + const funcDef = this.getOrCreateFuncDef(type, size); + return createCallToFunction(funcDef, elements, this.ast); + } + + public getOrCreateFuncDef(type: ArrayType | StringType, size: number) { + const baseType = getElementType(type); - const stub = createCairoFunctionStub( - name, + const key = baseType.pp() + size; + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; + } + + const baseTypeName = typeNameFromTypeNode(baseType, this.ast); + const funcInfo = this.getOrCreate( + baseType, + size, + isDynamicArray(type) || type instanceof StringLiteralType, + ); + const funcDef = createCairoGeneratedFunction( + funcInfo, mapRange(size, (n) => [ - `e${n}`, - typeNameFromTypeNode(elementT, this.ast), - locationIfComplexType(elementT, DataLocation.Memory), + `arg_${n}`, + cloneASTNode(baseTypeName, this.ast), + locationIfComplexType(baseType, DataLocation.Memory), ]), [['arr', typeNameFromTypeNode(type, this.ast), DataLocation.Memory]], - ['range_check_ptr', 'warp_memory'], this.ast, - node, + this.sourceUnit, ); - - return createCallToFunction(stub, elements, this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - private getOrCreate(type: TypeNode, size: number, dynamic: boolean): string { + private getOrCreate(type: TypeNode, size: number, dynamic: boolean): GeneratedFunctionInfo { const elementCairoType = CairoType.fromSol(type, this.ast); - const key = `${dynamic ? 'd' : 's'}${size}${elementCairoType.fullStringRepresentation}`; - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - const funcName = `WM${this.generatedFunctions.size}_${dynamic ? 'd' : 's'}_arr`; + const funcName = `wm${this.generatedFunctionsDef.size}_${dynamic ? 'dynamic' : 'static'}_array`; const argString = mapRange(size, (n) => `e${n}: ${elementCairoType.toString()}`).join(', '); // If it's dynamic we need to include the length at the start const alloc_len = dynamic ? size * elementCairoType.width + 2 : size * elementCairoType.width; - this.generatedFunctions.set(key, { + return { name: funcName, code: [ `func ${funcName}{range_check_ptr, warp_memory: DictAccess*}(${argString}) -> (loc: felt){`, @@ -139,13 +137,12 @@ export class MemoryArrayLiteralGen extends StringIndexedFuncGen { ` return (start,);`, `}`, ].join('\n'), - }); - - this.requireImport('warplib.memory', 'wm_alloc'); - this.requireImport('warplib.memory', 'wm_write_256'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('starkware.cairo.common.dict', 'dict_write'); - - return funcName; + functionsCalled: [ + this.requireImport('warplib.memory', 'wm_alloc'), + this.requireImport('warplib.memory', 'wm_write_256'), + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('starkware.cairo.common.dict', 'dict_write'), + ], + }; } } diff --git a/src/cairoUtilFuncGen/memory/implicitConversion.ts b/src/cairoUtilFuncGen/memory/implicitConversion.ts index 1b7439668..400e1d95a 100644 --- a/src/cairoUtilFuncGen/memory/implicitConversion.ts +++ b/src/cairoUtilFuncGen/memory/implicitConversion.ts @@ -15,20 +15,21 @@ import { UserDefinedType, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoFunctionDefinition, CairoImportFunctionDefinition } from '../../ast/cairoNodes'; import { printTypeNode } from '../../utils/astPrinter'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { NotSupportedYetError, TranspileFailedError } from '../../utils/errors'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { isDynamicArray, safeGetNodeType } from '../../utils/nodeTypeProcessing'; import { narrowBigIntSafe, typeNameFromTypeNode } from '../../utils/utils'; import { uint256 } from '../../warplib/utils'; -import { CairoFunction, delegateBasedOnType, StringIndexedFuncGen } from '../base'; +import { delegateBasedOnType, GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; import { MemoryReadGen } from './memoryRead'; import { MemoryWriteGen } from './memoryWrite'; /* Class that converts arrays with smaller element types into bigger types - e. g. + e.g. uint8[] -> uint256[] uint8[3] -> uint256[] uint8[3] -> uint256[3] @@ -39,7 +40,7 @@ import { MemoryWriteGen } from './memoryWrite'; const IMPLICITS = '{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, warp_memory : DictAccess*}'; export class MemoryImplicitConversionGen extends StringIndexedFuncGen { - constructor( + public constructor( private memoryWrite: MemoryWriteGen, private memoryRead: MemoryReadGen, ast: AST, @@ -48,7 +49,7 @@ export class MemoryImplicitConversionGen extends StringIndexedFuncGen { super(ast, sourceUnit); } - genIfNecesary(sourceExpression: Expression, targetType: TypeNode): [Expression, boolean] { + public genIfNecesary(sourceExpression: Expression, targetType: TypeNode): [Expression, boolean] { const sourceType = safeGetNodeType(sourceExpression, this.ast.inference); const generalTarget = generalizeType(targetType)[0]; @@ -98,31 +99,33 @@ export class MemoryImplicitConversionGen extends StringIndexedFuncGen { return [sourceExpression, false]; } - gen(source: Expression, targetType: TypeNode): FunctionCall { + public gen(source: Expression, targetType: TypeNode): FunctionCall { const sourceType = safeGetNodeType(source, this.ast.inference); - const name = this.getOrCreate(targetType, sourceType); + const funcDef = this.getOrCreateFuncDef(targetType, sourceType); + return createCallToFunction(funcDef, [source], this.ast); + } + + public getOrCreateFuncDef(targetType: TypeNode, sourceType: TypeNode) { + const key = targetType.pp() + sourceType.pp(); + const existing = this.generatedFunctionsDef.get(key); + if (existing !== undefined) { + return existing; + } - const functionStub = createCairoFunctionStub( - name, + const funcInfo = this.getOrCreate(targetType, sourceType); + const funcDef = createCairoGeneratedFunction( + funcInfo, [['source', typeNameFromTypeNode(sourceType, this.ast), DataLocation.Memory]], [['target', typeNameFromTypeNode(targetType, this.ast), DataLocation.Memory]], - ['range_check_ptr', 'bitwise_ptr', 'warp_memory'], this.ast, this.sourceUnit, ); - - return createCallToFunction(functionStub, [source], this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - getOrCreate(targetType: TypeNode, sourceType: TypeNode): string { - const key = targetType.pp() + sourceType.pp(); - - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - + private getOrCreate(targetType: TypeNode, sourceType: TypeNode): GeneratedFunctionInfo { assert(targetType instanceof PointerType && sourceType instanceof PointerType); targetType = targetType.to; sourceType = sourceType.to; @@ -135,7 +138,7 @@ export class MemoryImplicitConversionGen extends StringIndexedFuncGen { ); }; - const cairoFunc = delegateBasedOnType( + const funcInfo = delegateBasedOnType( targetType, (targetType) => { assert(targetType instanceof ArrayType && sourceType instanceof ArrayType); @@ -151,15 +154,13 @@ export class MemoryImplicitConversionGen extends StringIndexedFuncGen { unexpectedTypeFunc, unexpectedTypeFunc, ); - - this.generatedFunctions.set(key, cairoFunc); - return cairoFunc.name; + return funcInfo; } private staticToStaticArrayConversion( targetType: ArrayType, sourceType: ArrayType, - ): CairoFunction { + ): GeneratedFunctionInfo { assert( targetType.size !== undefined && sourceType.size !== undefined && @@ -172,26 +173,28 @@ export class MemoryImplicitConversionGen extends StringIndexedFuncGen { ); const sourceLoc = `${getOffset('source', 'index', cairoSourceElementType.width)}`; - let sourceLocationCode; + let sourceLocationFunc: CairoFunctionDefinition; + let sourceLocationCode: string; if (targetType.elementT instanceof PointerType) { - this.requireImport('warplib.memory', 'wm_read_id'); const idAllocSize = isDynamicArray(sourceType.elementT) ? 2 : cairoSourceElementType.width; + sourceLocationFunc = this.requireImport('warplib.memory', 'wm_read_id'); sourceLocationCode = `let (source_elem) = wm_read_id(${sourceLoc}, ${uint256(idAllocSize)});`; } else { - sourceLocationCode = `let (source_elem) = ${this.memoryRead.getOrCreate( - cairoSourceElementType, - )}(${sourceLoc});`; + sourceLocationFunc = this.memoryRead.getOrCreateFuncDef(sourceType.elementT); + sourceLocationCode = `let (source_elem) = ${sourceLocationFunc.name}(${sourceLoc});`; } - const conversionCode = this.generateScalingCode(targetType.elementT, sourceType.elementT); + const [conversionCode, calledFuncs] = this.generateScalingCode( + targetType.elementT, + sourceType.elementT, + ); + const memoryWriteDef = this.memoryWrite.getOrCreateFuncDef(targetType.elementT); const targetLoc = `${getOffset('target', 'index', cairoTargetElementType.width)}`; - const targetCopyCode = `${this.memoryWrite.getOrCreate( - targetType.elementT, - )}(${targetLoc}, target_elem);`; + const targetCopyCode = `${memoryWriteDef.name}(${targetLoc}, target_elem);`; const allocSize = narrowBigIntSafe(targetType.size) * cairoTargetElementType.width; - const funcName = `memory_conversion_static_to_static${this.generatedFunctions.size}`; + const funcName = `memory_conversion_static_to_static${this.generatedFunctionsDef.size}`; const code = [ `func ${funcName}_copy${IMPLICITS}(source : felt, target : felt, index : felt){`, ` alloc_locals;`, @@ -212,16 +215,23 @@ export class MemoryImplicitConversionGen extends StringIndexedFuncGen { `}`, ].join('\n'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('warplib.memory', 'wm_alloc'); - - return { name: funcName, code: code }; + return { + name: funcName, + code: code, + functionsCalled: [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('warplib.memory', 'wm_alloc'), + sourceLocationFunc, + ...calledFuncs, + memoryWriteDef, + ], + }; } private staticToDynamicArrayConversion( targetType: ArrayType, sourceType: ArrayType, - ): CairoFunction { + ): GeneratedFunctionInfo { assert(sourceType.size !== undefined); const [cairoTargetElementType, cairoSourceElementType] = typesToCairoTypes( [targetType.elementT, sourceType.elementT], @@ -231,9 +241,9 @@ export class MemoryImplicitConversionGen extends StringIndexedFuncGen { const sourceTWidth = cairoSourceElementType.width; const targetTWidth = cairoTargetElementType.width; + const memoryRead = this.memoryRead.getOrCreateFuncDef(sourceType.elementT); const sourceLocationCode = ['let felt_index = index.low + index.high * 128;']; if (sourceType.elementT instanceof PointerType) { - this.requireImport('warplib.memory', 'wm_read_id'); const idAllocSize = isDynamicArray(sourceType.elementT) ? 2 : cairoSourceElementType.width; sourceLocationCode.push( `let (source_elem) = wm_read_id(${getOffset( @@ -244,7 +254,7 @@ export class MemoryImplicitConversionGen extends StringIndexedFuncGen { ); } else { sourceLocationCode.push( - `let (source_elem) = ${this.memoryRead.getOrCreate(cairoSourceElementType)}(${getOffset( + `let (source_elem) = ${memoryRead.name}(${getOffset( 'source', 'felt_index', sourceTWidth, @@ -252,14 +262,18 @@ export class MemoryImplicitConversionGen extends StringIndexedFuncGen { ); } - const conversionCode = this.generateScalingCode(targetType.elementT, sourceType.elementT); + const [conversionCode, conversionFuncs] = this.generateScalingCode( + targetType.elementT, + sourceType.elementT, + ); + const memoryWrite = this.memoryWrite.getOrCreateFuncDef(targetType.elementT); const targetCopyCode = [ `let (target_elem_loc) = wm_index_dyn(target, index, ${uint256(targetTWidth)});`, - `${this.memoryWrite.getOrCreate(targetType.elementT)}(target_elem_loc, target_elem);`, + `${memoryWrite.name}(target_elem_loc, target_elem);`, ]; - const funcName = `memory_conversion_static_to_dynamic${this.generatedFunctions.size}`; + const funcName = `memory_conversion_static_to_dynamic${this.generatedFunctionsDef.size}`; const code = [ `func ${funcName}_copy${IMPLICITS}(source : felt, target : felt, index : Uint256, len : Uint256){`, ` alloc_locals;`, @@ -281,18 +295,25 @@ export class MemoryImplicitConversionGen extends StringIndexedFuncGen { `}`, ].join('\n'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_add'); - this.requireImport('warplib.memory', 'wm_index_dyn'); - this.requireImport('warplib.memory', 'wm_new'); - - return { name: funcName, code: code }; + return { + name: funcName, + code: code, + functionsCalled: [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_add'), + this.requireImport('warplib.memory', 'wm_index_dyn'), + this.requireImport('warplib.memory', 'wm_new'), + memoryRead, + ...conversionFuncs, + memoryWrite, + ], + }; } private dynamicToDynamicArrayConversion( targetType: ArrayType, sourceType: ArrayType, - ): CairoFunction { + ): GeneratedFunctionInfo { const [cairoTargetElementType, cairoSourceElementType] = typesToCairoTypes( [targetType.elementT, sourceType.elementT], this.ast, @@ -304,29 +325,30 @@ export class MemoryImplicitConversionGen extends StringIndexedFuncGen { const sourceLocationCode = [ `let (source_elem_loc) = wm_index_dyn(source, index, ${uint256(sourceTWidth)});`, ]; + + const memoryRead = this.memoryRead.getOrCreateFuncDef(sourceType.elementT); if (sourceType.elementT instanceof PointerType) { - this.requireImport('warplib.memory', 'wm_read_id'); const idAllocSize = isDynamicArray(sourceType.elementT) ? 2 : cairoSourceElementType.width; sourceLocationCode.push( `let (source_elem) = wm_read_id(source_elem_loc, ${uint256(idAllocSize)});`, ); } else { - sourceLocationCode.push( - `let (source_elem) = ${this.memoryRead.getOrCreate( - cairoSourceElementType, - )}(source_elem_loc);`, - ); + sourceLocationCode.push(`let (source_elem) = ${memoryRead.name}(source_elem_loc);`); } - const conversionCode = this.generateScalingCode(targetType.elementT, sourceType.elementT); + const [conversionCode, conversionCalls] = this.generateScalingCode( + targetType.elementT, + sourceType.elementT, + ); + const memoryWrite = this.memoryWrite.getOrCreateFuncDef(targetType.elementT); const targetCopyCode = [ `let (target_elem_loc) = wm_index_dyn(target, index, ${uint256(targetTWidth)});`, - `${this.memoryWrite.getOrCreate(targetType.elementT)}(target_elem_loc, target_elem);`, + `${memoryWrite.name}(target_elem_loc, target_elem);`, ]; const targetWidth = cairoTargetElementType.width; - const funcName = `memory_conversion_dynamic_to_dynamic${this.generatedFunctions.size}`; + const funcName = `memory_conversion_dynamic_to_dynamic${this.generatedFunctionsDef.size}`; const code = [ `func ${funcName}_copy${IMPLICITS}(source : felt, target : felt, index : Uint256, len : Uint256){`, ` alloc_locals;`, @@ -349,15 +371,25 @@ export class MemoryImplicitConversionGen extends StringIndexedFuncGen { `}`, ].join('\n'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_add'); - this.requireImport('warplib.memory', 'wm_index_dyn'); - this.requireImport('warplib.memory', 'wm_new'); - - return { name: funcName, code: code }; + return { + name: funcName, + code: code, + functionsCalled: [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_add'), + this.requireImport('warplib.memory', 'wm_index_dyn'), + this.requireImport('warplib.memory', 'wm_new'), + memoryRead, + ...conversionCalls, + memoryWrite, + ], + }; } - private generateScalingCode(targetType: TypeNode, sourceType: TypeNode) { + private generateScalingCode( + targetType: TypeNode, + sourceType: TypeNode, + ): [string, CairoFunctionDefinition[]] { if (targetType instanceof IntType) { assert(sourceType instanceof IntType); return this.generateIntegerScalingCode(targetType, sourceType, 'target_elem', 'source_elem'); @@ -371,14 +403,15 @@ export class MemoryImplicitConversionGen extends StringIndexedFuncGen { ); } else if (targetType instanceof PointerType) { assert(sourceType instanceof PointerType); - return `let (target_elem) = ${this.getOrCreate(targetType, sourceType)}(source_elem);`; + const auxFunc = this.getOrCreateFuncDef(targetType, sourceType); + return [`let (target_elem) = ${auxFunc.name}(source_elem);`, [auxFunc]]; } else if (isNoScalableType(targetType)) { - return `let target_elem = source_elem;`; + return [`let target_elem = source_elem;`, []]; } else { throw new TranspileFailedError( `Cannot scale ${printTypeNode(sourceType)} into ${printTypeNode( targetType, - )} from memory to strage`, + )} from memory to storage`, ); } } @@ -388,17 +421,21 @@ export class MemoryImplicitConversionGen extends StringIndexedFuncGen { sourceType: IntType, targetVar: string, sourceVar: string, - ): string { + ): [string, CairoImportFunctionDefinition[]] { if (targetType.signed && targetType.nBits !== sourceType.nBits) { const conversionFunc = `warp_int${sourceType.nBits}_to_int${targetType.nBits}`; - this.requireImport('warplib.maths.int_conversions', conversionFunc); - return `let (${targetVar}) = ${conversionFunc}(${sourceVar});`; + return [ + `let (${targetVar}) = ${conversionFunc}(${sourceVar});`, + [this.requireImport('warplib.maths.int_conversions', conversionFunc)], + ]; } else if (!targetType.signed && targetType.nBits === 256 && sourceType.nBits < 256) { const conversionFunc = `felt_to_uint256`; - this.requireImport('warplib.maths.utils', conversionFunc); - return `let (${targetVar}) = ${conversionFunc}(${sourceVar});`; + return [ + `let (${targetVar}) = ${conversionFunc}(${sourceVar});`, + [this.requireImport('warplib.maths.utils', conversionFunc)], + ]; } else { - return `let ${targetVar} = ${sourceVar};`; + return [`let ${targetVar} = ${sourceVar};`, []]; } } @@ -407,16 +444,18 @@ export class MemoryImplicitConversionGen extends StringIndexedFuncGen { sourceType: FixedBytesType, targetVar: string, sourceVar: string, - ): string { + ): [string, CairoImportFunctionDefinition[]] { const widthDiff = targetType.size - sourceType.size; if (widthDiff === 0) { - return `let ${targetVar} = ${sourceVar};`; + return [`let ${targetVar} = ${sourceVar};`, []]; } const conversionFunc = targetType.size === 32 ? 'warp_bytes_widen_256' : 'warp_bytes_widen'; - this.requireImport('warplib.maths.bytes_conversions', conversionFunc); - return `let (${targetVar}) = ${conversionFunc}(${sourceVar}, ${widthDiff * 8});`; + return [ + `let (${targetVar}) = ${conversionFunc}(${sourceVar}, ${widthDiff * 8});`, + [this.requireImport('warplib.maths.bytes_conversions', conversionFunc)], + ]; } } diff --git a/src/cairoUtilFuncGen/memory/memoryDynArrayLength.ts b/src/cairoUtilFuncGen/memory/memoryDynArrayLength.ts index dec938d8f..2d9ac40af 100644 --- a/src/cairoUtilFuncGen/memory/memoryDynArrayLength.ts +++ b/src/cairoUtilFuncGen/memory/memoryDynArrayLength.ts @@ -1,29 +1,21 @@ import { AST } from '../../ast/ast'; import { MemberAccess, FunctionCall, DataLocation, generalizeType } from 'solc-typed-ast'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCallToFunction } from '../../utils/functionGeneration'; import { createUint256TypeName } from '../../utils/nodeTemplates'; import { typeNameFromTypeNode } from '../../utils/utils'; -import { CairoUtilFuncGenBase } from '../base'; import { safeGetNodeType } from '../../utils/nodeTypeProcessing'; +import { CairoUtilFuncGenBase } from '../base'; export class MemoryDynArrayLengthGen extends CairoUtilFuncGenBase { - getGeneratedCode(): string { - return ''; - } - gen(node: MemberAccess, ast: AST): FunctionCall { const arrayType = generalizeType(safeGetNodeType(node.vExpression, ast.inference))[0]; const arrayTypeName = typeNameFromTypeNode(arrayType, ast); - const functionStub = createCairoFunctionStub( + const funcDef = this.requireImport( + 'warplib.memory', 'wm_dyn_array_length', [['arrayLoc', arrayTypeName, DataLocation.Memory]], [['len', createUint256TypeName(this.ast)]], - ['warp_memory'], - this.ast, - node, ); - const call = createCallToFunction(functionStub, [node.vExpression], this.ast); - this.ast.registerImport(call, 'warplib.memory', 'wm_dyn_array_length'); - return call; + return createCallToFunction(funcDef, [node.vExpression], this.ast); } } diff --git a/src/cairoUtilFuncGen/memory/memoryMemberAccess.ts b/src/cairoUtilFuncGen/memory/memoryMemberAccess.ts index 658c97ae2..493894a2a 100644 --- a/src/cairoUtilFuncGen/memory/memoryMemberAccess.ts +++ b/src/cairoUtilFuncGen/memory/memoryMemberAccess.ts @@ -1,19 +1,20 @@ import assert = require('assert'); import { MemberAccess, - ASTNode, FunctionCall, PointerType, UserDefinedType, VariableDeclaration, DataLocation, + StructDefinition, } from 'solc-typed-ast'; +import { printTypeNode } from '../../utils/astPrinter'; import { CairoType, TypeConversionContext, CairoStruct } from '../../utils/cairoTypeSystem'; import { cloneASTNode } from '../../utils/cloning'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { safeGetNodeType } from '../../utils/nodeTypeProcessing'; -import { typeNameFromTypeNode, countNestedMapItems } from '../../utils/utils'; -import { CairoUtilFuncGenBase, CairoFunction, add } from '../base'; +import { typeNameFromTypeNode } from '../../utils/utils'; +import { add, GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; /* Produces a separate function for each struct type and member name, that when given @@ -22,52 +23,63 @@ import { CairoUtilFuncGenBase, CairoFunction, add } from '../base'; so that it doesn't get converted into fixed-width solidity arithmetic. A CairoExpression node could serve as an optimisation here */ -export class MemoryMemberAccessGen extends CairoUtilFuncGenBase { - // cairoType -> property name -> code - private generatedFunctions: Map> = new Map(); - - // Concatenate all the generated cairo code into a single string - getGeneratedCode(): string { - return [...this.generatedFunctions.values()] - .flatMap((map) => [...map.values()]) - .map((cairoMapping) => cairoMapping.code) - .join('\n\n'); - } - - gen(memberAccess: MemberAccess, nodeInSourceUnit?: ASTNode): FunctionCall { +export class MemoryMemberAccessGen extends StringIndexedFuncGen { + public gen(memberAccess: MemberAccess): FunctionCall { const solType = safeGetNodeType(memberAccess.vExpression, this.ast.inference); - assert(solType instanceof PointerType); - assert(solType.to instanceof UserDefinedType); - const structCairoType = CairoType.fromSol( - solType, - this.ast, - TypeConversionContext.MemoryAllocation, + assert( + solType instanceof PointerType && + solType.to instanceof UserDefinedType && + solType.to.definition instanceof StructDefinition, + `Trying to generate a member access for a type different than a struct: ${printTypeNode( + solType, + )}`, ); - const name = this.getOrCreate(structCairoType, memberAccess.memberName); + const referencedDeclaration = memberAccess.vReferencedDeclaration; assert(referencedDeclaration instanceof VariableDeclaration); + const outType = referencedDeclaration.vType; assert(outType !== undefined); - const functionStub = createCairoFunctionStub( - name, - [['loc', typeNameFromTypeNode(solType, this.ast), DataLocation.Memory]], - [['memberLoc', cloneASTNode(outType, this.ast), DataLocation.Memory]], - [], + + const funcDef = this.getOrCreateFuncDef(solType.to, memberAccess.memberName); + return createCallToFunction(funcDef, [memberAccess.vExpression], this.ast); + } + + public getOrCreateFuncDef(solType: UserDefinedType, memberName: string) { + assert(solType.definition instanceof StructDefinition); + const structCairoType = CairoType.fromSol( + solType, this.ast, - nodeInSourceUnit ?? memberAccess, + TypeConversionContext.MemoryAllocation, ); - return createCallToFunction(functionStub, [memberAccess.vExpression], this.ast); - } - private getOrCreate(structCairoType: CairoType, memberName: string): string { - const existingMemberAccesses = - this.generatedFunctions.get(structCairoType.fullStringRepresentation) ?? - new Map(); - const existing = existingMemberAccesses.get(memberName); + const key = structCairoType.fullStringRepresentation + memberName; + const existing = this.generatedFunctionsDef.get(key); if (existing !== undefined) { - return existing.name; + return existing; } + const funcInfo = this.getOrCreate(structCairoType, memberName); + + const solTypeName = typeNameFromTypeNode(solType, this.ast); + const [outTypeName] = solType.definition.vMembers + .filter((member) => member.name === memberName) + .map((member) => member.vType); + assert(outTypeName !== undefined); + + const funcDef = createCairoGeneratedFunction( + funcInfo, + [['loc', solTypeName, DataLocation.Memory]], + [['member_loc', cloneASTNode(outTypeName, this.ast), DataLocation.Memory]], + this.ast, + this.sourceUnit, + ); + + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; + } + + private getOrCreate(structCairoType: CairoType, memberName: string): GeneratedFunctionInfo { const structName = structCairoType.toString(); assert( structCairoType instanceof CairoStruct, @@ -75,20 +87,16 @@ export class MemoryMemberAccessGen extends CairoUtilFuncGenBase { ); const offset = structCairoType.offsetOf(memberName); - const funcName = `WM${countNestedMapItems( - this.generatedFunctions, - )}_${structName}_${memberName}`; + const funcName = `wm_${structName}_${memberName}`; - existingMemberAccesses.set(memberName, { + return { name: funcName, code: [ `func ${funcName}(loc: felt) -> (memberLoc: felt){`, ` return (${add('loc', offset)},);`, `}`, ].join('\n'), - }); - - this.generatedFunctions.set(structCairoType.fullStringRepresentation, existingMemberAccesses); - return funcName; + functionsCalled: [], + }; } } diff --git a/src/cairoUtilFuncGen/memory/memoryRead.ts b/src/cairoUtilFuncGen/memory/memoryRead.ts index e308342f7..1c337e396 100644 --- a/src/cairoUtilFuncGen/memory/memoryRead.ts +++ b/src/cairoUtilFuncGen/memory/memoryRead.ts @@ -1,12 +1,13 @@ import { Expression, TypeName, - ASTNode, FunctionCall, DataLocation, FunctionStateMutability, generalizeType, + TypeNode, } from 'solc-typed-ast'; +import { CairoFunctionDefinition, typeNameFromTypeNode } from '../../export'; import { CairoFelt, CairoType, @@ -15,10 +16,10 @@ import { TypeConversionContext, } from '../../utils/cairoTypeSystem'; import { cloneASTNode } from '../../utils/cloning'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { createNumberLiteral, createNumberTypeName } from '../../utils/nodeTemplates'; import { isDynamicArray, safeGetNodeType } from '../../utils/nodeTypeProcessing'; -import { add, locationIfComplexType, StringIndexedFuncGen } from '../base'; +import { add, GeneratedFunctionInfo, locationIfComplexType, StringIndexedFuncGen } from '../base'; import { serialiseReads } from '../serialisation'; /* @@ -28,19 +29,15 @@ import { serialiseReads } from '../serialisation'; */ export class MemoryReadGen extends StringIndexedFuncGen { - gen(memoryRef: Expression, type: TypeName, nodeInSourceUnit?: ASTNode): FunctionCall { + gen(memoryRef: Expression): FunctionCall { const valueType = generalizeType(safeGetNodeType(memoryRef, this.ast.inference))[0]; const resultCairoType = CairoType.fromSol(valueType, this.ast); - const params: [string, TypeName, DataLocation][] = [ - ['loc', cloneASTNode(type, this.ast), DataLocation.Memory], - ]; const args = [memoryRef]; if (resultCairoType instanceof MemoryLocation) { // The size parameter represents how much space to allocate - // for the contents of the newly accessed suboject - params.push(['size', createNumberTypeName(256, false, this.ast), DataLocation.Default]); + // for the contents of the newly accessed subobject args.push( createNumberLiteral( isDynamicArray(valueType) @@ -51,49 +48,57 @@ export class MemoryReadGen extends StringIndexedFuncGen { ), ); } - - const name = this.getOrCreate(resultCairoType); - const functionStub = createCairoFunctionStub( - name, - params, - [ - [ - 'val', - cloneASTNode(type, this.ast), - locationIfComplexType(valueType, DataLocation.Memory), - ], - ], - ['range_check_ptr', 'warp_memory'], - this.ast, - nodeInSourceUnit ?? memoryRef, - { mutability: FunctionStateMutability.View }, - ); - - return createCallToFunction(functionStub, args, this.ast); + const funcDef = this.getOrCreateFuncDef(valueType); + return createCallToFunction(funcDef, args, this.ast); } - getOrCreate(typeToRead: CairoType): string { - if (typeToRead instanceof MemoryLocation) { - this.requireImport('warplib.memory', 'wm_read_id'); - return 'wm_read_id'; - } else if (typeToRead instanceof CairoFelt) { - this.requireImport('warplib.memory', 'wm_read_felt'); - return 'wm_read_felt'; - } else if (typeToRead.fullStringRepresentation === CairoUint256.fullStringRepresentation) { - this.requireImport('warplib.memory', 'wm_read_256'); - return 'wm_read_256'; + getOrCreateFuncDef(typeToRead: TypeNode) { + const key = typeToRead.pp(); + const existing = this.generatedFunctionsDef.get(key); + if (existing !== undefined) { + return existing; } - const key = typeToRead.fullStringRepresentation; - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; + const typeToReadName = typeNameFromTypeNode(typeToRead, this.ast); + const resultCairoType = CairoType.fromSol(typeToRead, this.ast); + + const inputs: [string, TypeName, DataLocation][] = + resultCairoType instanceof MemoryLocation + ? [ + ['loc', cloneASTNode(typeToReadName, this.ast), DataLocation.Memory], + ['size', createNumberTypeName(256, false, this.ast), DataLocation.Default], + ] + : [['loc', cloneASTNode(typeToReadName, this.ast), DataLocation.Memory]]; + const outputs: [string, TypeName, DataLocation][] = [ + [ + 'val', + cloneASTNode(typeToReadName, this.ast), + locationIfComplexType(typeToRead, DataLocation.Memory), + ], + ]; + + let funcDef: CairoFunctionDefinition; + if (resultCairoType instanceof MemoryLocation) { + funcDef = this.requireImport('warplib.memory', 'wm_read_id', inputs, outputs); + } else if (resultCairoType instanceof CairoFelt) { + funcDef = this.requireImport('warplib.memory', 'wm_read_felt', inputs, outputs); + } else if (resultCairoType.fullStringRepresentation === CairoUint256.fullStringRepresentation) { + funcDef = this.requireImport('warplib.memory', 'wm_read_256', inputs, outputs); + } else { + const funcInfo = this.getOrCreate(resultCairoType); + funcDef = createCairoGeneratedFunction(funcInfo, inputs, outputs, this.ast, this.sourceUnit, { + mutability: FunctionStateMutability.View, + }); } + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; + } - const funcName = `WM${this.generatedFunctions.size}_READ_${typeToRead.typeName}`; + private getOrCreate(typeToRead: CairoType): GeneratedFunctionInfo { + const funcName = `WM${this.generatedFunctionsDef.size}_READ_${typeToRead.typeName}`; const resultCairoType = typeToRead.toString(); const [reads, pack] = serialiseReads(typeToRead, readFelt, readFelt); - this.generatedFunctions.set(key, { + const funcInfo: GeneratedFunctionInfo = { name: funcName, code: [ `func ${funcName}{range_check_ptr, warp_memory : DictAccess*}(loc: felt) ->(val: ${resultCairoType}){`, @@ -102,9 +107,9 @@ export class MemoryReadGen extends StringIndexedFuncGen { ` return (${pack},);`, '}', ].join('\n'), - }); - this.requireImport('starkware.cairo.common.dict', 'dict_read'); - return funcName; + functionsCalled: [this.requireImport('starkware.cairo.common.dict', 'dict_read')], + }; + return funcInfo; } } diff --git a/src/cairoUtilFuncGen/memory/memoryStruct.ts b/src/cairoUtilFuncGen/memory/memoryStruct.ts index 356ecef43..77252fbef 100644 --- a/src/cairoUtilFuncGen/memory/memoryStruct.ts +++ b/src/cairoUtilFuncGen/memory/memoryStruct.ts @@ -5,34 +5,44 @@ import { IdentifierPath, PointerType, StructDefinition, + TypeNode, UserDefinedTypeName, } from 'solc-typed-ast'; import { CairoStruct, CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { cloneASTNode } from '../../utils/cloning'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { safeGetNodeType, typeNameToSpecializedTypeNode } from '../../utils/nodeTypeProcessing'; import { uint256 } from '../../warplib/utils'; -import { add, StringIndexedFuncGen } from '../base'; +import { add, GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; /* Produces functions to allocate memory structs, assign their members, and return their location This replaces StructConstructorCalls referencing memory with normal FunctionCalls */ export class MemoryStructGen extends StringIndexedFuncGen { - gen(node: FunctionCall): FunctionCall { + public gen(node: FunctionCall): FunctionCall { const structDef = node.vReferencedDeclaration; assert(structDef instanceof StructDefinition); - const cairoType = CairoType.fromSol( - safeGetNodeType(node, this.ast.inference), - this.ast, - TypeConversionContext.MemoryAllocation, - ); - assert(cairoType instanceof CairoStruct); - const name = this.getOrCreate(cairoType); + const nodeType = safeGetNodeType(node, this.ast.inference); + const funcDef = this.getOrCreateFuncDef(nodeType, structDef); - const stub = createCairoFunctionStub( - name, + structDef.vScope.acceptChildren(); + return createCallToFunction(funcDef, node.vArguments, this.ast); + } + + public getOrCreateFuncDef(nodeType: TypeNode, structDef: StructDefinition) { + const key = `memoryStruct(${nodeType.pp()},${structDef.name})`; + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; + } + + const cairoType = CairoType.fromSol(nodeType, this.ast, TypeConversionContext.MemoryAllocation); + assert(cairoType instanceof CairoStruct); + const funcInfo = this.getOrCreate(cairoType); + const funcDef = createCairoGeneratedFunction( + funcInfo, structDef.vMembers.map((decl) => { assert(decl.vType !== undefined); const type = typeNameToSpecializedTypeNode( @@ -60,34 +70,24 @@ export class MemoryStructGen extends StringIndexedFuncGen { DataLocation.Memory, ], ], - ['range_check_ptr', 'warp_memory'], this.ast, - node, + this.sourceUnit, ); - - structDef.vScope.acceptChildren(); - - return createCallToFunction(stub, node.vArguments, this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - private getOrCreate(structType: CairoStruct): string { - const key = structType.fullStringRepresentation; - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - const funcName = `WM${this.generatedFunctions.size}_struct_${structType.name}`; + private getOrCreate(structType: CairoStruct): GeneratedFunctionInfo { + const funcName = `WM${this.generatedFunctionsDef.size}_struct_${structType.name}`; const mangledStructMembers: [string, CairoType][] = [...structType.members.entries()].map( ([name, type]) => [`member_${name}`, type], ); - const argString = mangledStructMembers .map(([name, type]) => `${name}: ${type.toString()}`) .join(', '); - this.generatedFunctions.set(key, { + return { name: funcName, code: [ `func ${funcName}{range_check_ptr, warp_memory: DictAccess*}(${argString}) -> (res:felt){`, @@ -100,14 +100,13 @@ export class MemoryStructGen extends StringIndexedFuncGen { ` return (start,);`, `}`, ].join('\n'), - }); - - this.requireImport('warplib.memory', 'wm_alloc'); - this.requireImport('starkware.cairo.common.dict', 'dict_write'); - this.requireImport('starkware.cairo.common.dict_access', 'DictAccess'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - - return funcName; + functionsCalled: [ + this.requireImport('warplib.memory', 'wm_alloc'), + this.requireImport('starkware.cairo.common.dict', 'dict_write'), + this.requireImport('starkware.cairo.common.dict_access', 'DictAccess'), + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + ], + }; } } diff --git a/src/cairoUtilFuncGen/memory/memoryToCalldata.ts b/src/cairoUtilFuncGen/memory/memoryToCalldata.ts index e02b581c5..377f55e2b 100644 --- a/src/cairoUtilFuncGen/memory/memoryToCalldata.ts +++ b/src/cairoUtilFuncGen/memory/memoryToCalldata.ts @@ -1,7 +1,6 @@ import assert from 'assert'; import { ArrayType, - ASTNode, BytesType, DataLocation, Expression, @@ -15,15 +14,11 @@ import { UserDefinedType, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoFunctionDefinition } from '../../export'; import { printTypeNode } from '../../utils/astPrinter'; -import { - CairoDynArray, - CairoType, - generateCallDataDynArrayStructName, - TypeConversionContext, -} from '../../utils/cairoTypeSystem'; +import { CairoDynArray, CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { NotSupportedYetError } from '../../utils/errors'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { getElementType, getSize, @@ -33,184 +28,140 @@ import { } from '../../utils/nodeTypeProcessing'; import { mapRange, narrowBigIntSafe, typeNameFromTypeNode } from '../../utils/utils'; import { uint256 } from '../../warplib/utils'; -import { add, delegateBasedOnType, StringIndexedFuncGen } from '../base'; +import { add, delegateBasedOnType, GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; import { ExternalDynArrayStructConstructor } from '../calldata/externalDynArray/externalDynArrayStructConstructor'; +import { MemoryReadGen } from './memoryRead'; +const IMPLICITS = + '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; export class MemoryToCallDataGen extends StringIndexedFuncGen { - constructor( + public constructor( private dynamicArrayStructGen: ExternalDynArrayStructConstructor, + private memoryReadGen: MemoryReadGen, ast: AST, sourceUnit: SourceUnit, ) { super(ast, sourceUnit); } - gen(node: Expression, nodeInSourceUnit?: ASTNode): FunctionCall { + + public gen(node: Expression): FunctionCall { const type = generalizeType(safeGetNodeType(node, this.ast.inference))[0]; - if (isDynamicArray(type)) { - this.dynamicArrayStructGen.gen(node, nodeInSourceUnit); + const funcDef = this.getOrCreateFuncDef(type); + return createCallToFunction(funcDef, [node], this.ast); + } + + public getOrCreateFuncDef(type: TypeNode) { + const key = type.pp(); + const existing = this.generatedFunctionsDef.get(key); + if (existing !== undefined) { + return existing; } - const name = this.getOrCreate(type); - const functionStub = createCairoFunctionStub( - name, + const funcInfo = this.getOrCreate(type); + + const funcDef = createCairoGeneratedFunction( + funcInfo, [['mem_loc', typeNameFromTypeNode(type, this.ast), DataLocation.Memory]], [['retData', typeNameFromTypeNode(type, this.ast), DataLocation.CallData]], - ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr', 'warp_memory'], this.ast, - nodeInSourceUnit ?? node, + this.sourceUnit, { mutability: FunctionStateMutability.Pure }, ); - return createCallToFunction(functionStub, [node], this.ast); - } - private getOrCreate(type: TypeNode): string { - const key = type.pp(); - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; + } + private getOrCreate(type: TypeNode): GeneratedFunctionInfo { const unexpectedTypeFunc = () => { throw new NotSupportedYetError( `Copying ${printTypeNode(type)} from memory to calldata not implemented yet`, ); }; - return delegateBasedOnType( + return delegateBasedOnType( type, - (type) => this.createDynamicArrayCopyFunction(key, type), - (type) => this.createStaticArrayCopyFunction(key, type), - (type) => this.createStructCopyFunction(key, type), + (type) => this.createDynamicArrayCopyFunction(type), + (type) => this.createStaticArrayCopyFunction(type), + (type) => this.createStructCopyFunction(type), unexpectedTypeFunc, unexpectedTypeFunc, ); } - private createStructCopyFunction(key: string, type: TypeNode): string { - const funcName = `wm_to_calldata${this.generatedFunctions.size}`; - const implicits = - '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; - const outputType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); - - assert(type instanceof UserDefinedType); + private createStructCopyFunction(type: TypeNode): GeneratedFunctionInfo { + assert(type instanceof UserDefinedType && type.definition instanceof StructDefinition); const structDef = type.definition; - assert(structDef instanceof StructDefinition); + const outputType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); - // Set an empty entry so recursive function generation doesn't clash - this.generatedFunctions.set(key, { name: funcName, code: '' }); + const [code, funcCalls] = structDef.vMembers + .map((decl) => safeGetNodeType(decl, this.ast.inference)) + .reduce( + ([code, funcCalls, offset], type, index) => { + const [copyCode, copyFuncCalls, newOffset] = this.generateElementCopyCode( + type, + offset, + index, + ); + return [[...code, ...copyCode], [...funcCalls, ...copyFuncCalls], newOffset]; + }, + [new Array(), new Array(), 0], + ); - let offset = 0; - this.generatedFunctions.set(key, { + const funcName = `wm_to_calldata${this.generatedFunctionsDef.size}_struct_${structDef.name}`; + return { name: funcName, code: [ - `func ${funcName}${implicits}(mem_loc : felt) -> (retData: ${outputType.toString()}){`, + `func ${funcName}${IMPLICITS}(mem_loc : felt) -> (ret_data: ${outputType.toString()}){`, ` alloc_locals;`, - ...structDef.vMembers.map((decl, index) => { - const memberType = safeGetNodeType(decl, this.ast.inference); - if (isReferenceType(memberType)) { - this.requireImport('warplib.memory', 'wm_read_id'); - const allocSize = isDynamicArray(memberType) - ? 2 - : CairoType.fromSol(memberType, this.ast, TypeConversionContext.Ref).width; - const memberGetter = this.getOrCreate(memberType); - return [ - `let (read_${index}) = wm_read_id(${add('mem_loc', offset++)}, ${uint256( - allocSize, - )});`, - `let (member${index}) = ${memberGetter}(read_${index});`, - ].join('\n'); - } else { - const memberCairoType = CairoType.fromSol(memberType, this.ast); - if (memberCairoType.width === 1) { - const code = `let (member${index}) = wm_read_felt(${add('mem_loc', offset++)});`; - this.requireImport('warplib.memory', 'wm_read_felt'); - return code; - } else if (memberCairoType.width === 2) { - const code = `let (member${index}) = wm_read_256(${add('mem_loc', offset)});`; - this.requireImport('warplib.memory', 'wm_read_256'); - offset += 2; - return code; - } - } - }), + ...code, ` return (${outputType.toString()}(${mapRange( structDef.vMembers.length, (n) => `member${n}`, )}),);`, `}`, ].join('\n'), - }); - - this.requireImport('starkware.cairo.common.dict', 'dict_read'); - - return funcName; + functionsCalled: funcCalls, + }; } - private createStaticArrayCopyFunction(key: string, type: ArrayType): string { - const funcName = `wm_to_calldata${this.generatedFunctions.size}`; - const implicits = - '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; + // TODO: With big static arrays, this functions gets huge. Can that be fixed?! + private createStaticArrayCopyFunction(type: ArrayType): GeneratedFunctionInfo { const outputType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); assert(type.size !== undefined); const length = narrowBigIntSafe(type.size); const elementT = type.elementT; - // Set an empty entry so recursive function generation doesn't clash - this.generatedFunctions.set(key, { name: funcName, code: '' }); + const memberFeltSize = CairoType.fromSol(elementT, this.ast).width; + const [copyCode, funcCalls] = mapRange(length, (n): [string[], CairoFunctionDefinition[]] => { + const [memberCopyCode, memberCalls] = this.generateElementCopyCode( + elementT, + n * memberFeltSize, + n, + ); + return [memberCopyCode, memberCalls]; + }).reduce(([copyCode, funcCalls], [memberCode, memberCalls]) => [ + [...copyCode, ...memberCode], + [...funcCalls, ...memberCalls], + ]); - let offset = 0; - this.generatedFunctions.set(key, { + const funcName = `wm_to_calldata_static_array${this.generatedFunctionsDef.size}`; + return { name: funcName, code: [ - `func ${funcName}${implicits}(mem_loc : felt) -> (retData: ${outputType.toString()}){`, + `func ${funcName}${IMPLICITS}(mem_loc : felt) -> (ret_data: ${outputType.toString()}){`, ` alloc_locals;`, - ...mapRange(length, (index) => { - if (isReferenceType(elementT)) { - this.requireImport('warplib.memory', 'wm_read_id'); - const memberGetter = this.getOrCreate(elementT); - const allocSize = isDynamicArray(elementT) - ? 2 - : CairoType.fromSol(elementT, this.ast, TypeConversionContext.Ref).width; - return [ - `let (read${index}) = wm_read_id(${add('mem_loc', offset++)}, ${uint256( - allocSize, - )});`, - `let (member${index}) = ${memberGetter}(read${index});`, - ].join('\n'); - } else { - const memberCairoType = CairoType.fromSol(elementT, this.ast); - if (memberCairoType.width === 1) { - const code = `let (member${index}) = wm_read_felt(${add('mem_loc', offset++)});`; - this.requireImport('warplib.memory', 'wm_read_felt'); - return code; - } else if (memberCairoType.width === 2) { - const code = `let (member${index}) = wm_read_256(${add('mem_loc', offset)});`; - this.requireImport('warplib.memory', 'wm_read_256'); - offset += 2; - return code; - } - } - }), + ...copyCode, ` return ((${mapRange(length, (n) => `member${n}`)}),);`, `}`, ].join('\n'), - }); - - this.requireImport('starkware.cairo.common.dict', 'dict_read'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - - return funcName; + functionsCalled: funcCalls, + }; } - private createDynamicArrayCopyFunction(key: string, type: TypeNode): string { - const funcName = `wm_to_calldata${this.generatedFunctions.size}`; - // Set an empty entry so recursive function generation doesn't clash. - this.generatedFunctions.set(key, { name: funcName, code: '' }); - const implicits = - '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; - + private createDynamicArrayCopyFunction(type: TypeNode): GeneratedFunctionInfo { const outputType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); assert(outputType instanceof CairoDynArray); @@ -226,70 +177,69 @@ export class MemoryToCallDataGen extends StringIndexedFuncGen { ); } - this.generatedFunctions.set(key, { + const dynArrayReaderInfo = this.createDynArrayReader(elementT); + const calldataDynArrayStruct = this.dynamicArrayStructGen.getOrCreateFuncDef(type); + + const funcName = `wm_to_calldata_dynamic_array${this.generatedFunctionsDef.size}`; + const funcInfo: GeneratedFunctionInfo = { name: funcName, code: [ - `func ${funcName}${implicits}(mem_loc: felt) -> (retData: ${outputType.toString()}){`, + dynArrayReaderInfo.code, + `func ${funcName}${IMPLICITS}(mem_loc: felt) -> (retData: ${outputType.toString()}){`, ` alloc_locals;`, ` let (len_256) = wm_read_256(mem_loc);`, ` let (ptr : ${outputType.vPtr.toString()}) = alloc();`, ` let (len_felt) = narrow_safe(len_256);`, - ` ${this.createDynArrayReader(elementT)}(len_felt, ptr, mem_loc + 2);`, - ` return (${generateCallDataDynArrayStructName( - elementT, - this.ast, - )}(len=len_felt, ptr=ptr),);`, + ` ${dynArrayReaderInfo.name}(len_felt, ptr, mem_loc + 2);`, + ` return (${calldataDynArrayStruct.name}(len=len_felt, ptr=ptr),);`, `}`, ].join('\n'), - }); - this.requireImport('starkware.cairo.common.alloc', 'alloc'); - this.requireImport('warplib.maths.utils', 'narrow_safe'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_add'); - this.requireImport('warplib.memory', 'wm_read_256'); - return funcName; + functionsCalled: [ + this.requireImport('starkware.cairo.common.alloc', 'alloc'), + this.requireImport('warplib.maths.utils', 'narrow_safe'), + this.requireImport('warplib.memory', 'wm_read_256'), + calldataDynArrayStruct, + ...dynArrayReaderInfo.functionsCalled, + ], + }; + return funcInfo; } - private createDynArrayReader(elementT: TypeNode): string { - const funcName = `wm_to_calldata${this.generatedFunctions.size}`; - const key = elementT.pp() + 'dynReader'; - // Set an empty entry so recursive function generation doesn't clash - this.generatedFunctions.set(key, { name: funcName, code: '' }); - - const implicits = - '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; + private createDynArrayReader(elementT: TypeNode): GeneratedFunctionInfo { + const funcName = `wm_to_calldata_dynamic_array_reader${this.generatedFunctionsDef.size}`; const cairoType = CairoType.fromSol(elementT, this.ast, TypeConversionContext.CallDataRef); - const memWidth = CairoType.fromSol(elementT, this.ast).width; + const memWidth = CairoType.fromSol(elementT, this.ast, TypeConversionContext.Ref).width; const ptrString = `${cairoType.toString()}`; - let code = ['']; + const readFunc = this.memoryReadGen.getOrCreateFuncDef(elementT); + let code: string[]; + let funcCalls: CairoFunctionDefinition[]; if (isReferenceType(elementT)) { const allocSize = isDynamicArray(elementT) ? 2 : CairoType.fromSol(elementT, this.ast, TypeConversionContext.Ref).width; + + const auxFunc = this.getOrCreateFuncDef(elementT); code = [ - `let (mem_read0) = wm_read_id(mem_loc, ${uint256(allocSize)});`, - `let (mem_read1) = ${this.getOrCreate(elementT)}(mem_read0);`, + `let (mem_read0) = ${readFunc.name}(mem_loc, ${uint256(allocSize)});`, + `let (mem_read1) = ${auxFunc.name}(mem_read0);`, `assert ptr[0] = mem_read1;`, ]; - this.requireImport('warplib.memory', 'wm_read_id'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - } else if (cairoType.width === 1) { - code = ['let (mem_read0) = wm_read_felt(mem_loc);', 'assert ptr[0] = mem_read0;']; - this.requireImport('warplib.memory', 'wm_read_felt'); - } else if (cairoType.width === 2) { - code = ['let (mem_read0) = wm_read_256(mem_loc);', 'assert ptr[0] = mem_read0;']; - this.requireImport('warplib.memory', 'wm_read_256'); + funcCalls = [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + auxFunc, + readFunc, + ]; } else { - throw new NotSupportedYetError( - `Element type ${cairoType.toString()} not supported yet in m->c`, - ); + code = [`let (mem_read0) = ${readFunc.name}(mem_loc);`, 'assert ptr[0] = mem_read0;']; + funcCalls = [readFunc]; } - this.generatedFunctions.set(funcName, { + return { name: funcName, code: [ - `func ${funcName}${implicits}(len: felt, ptr: ${ptrString}*, mem_loc: felt) -> (){`, + `func ${funcName}${IMPLICITS}(len: felt, ptr: ${ptrString}*, mem_loc: felt) -> (){`, ` alloc_locals;`, ` if (len == 0){`, ` return ();`, @@ -299,9 +249,42 @@ export class MemoryToCallDataGen extends StringIndexedFuncGen { ` return ();`, `}`, ].join('\n'), - }); - this.requireImport('starkware.cairo.common.uint256', 'uint256_sub'); + functionsCalled: funcCalls, + }; + } + + private generateElementCopyCode( + type: TypeNode, + offset: number, + index: number, + ): [string[], CairoFunctionDefinition[], number] { + const readFunc = this.memoryReadGen.getOrCreateFuncDef(type); + if (isReferenceType(type)) { + const memberGetterFunc = this.getOrCreateFuncDef(type); + const allocSize = isDynamicArray(type) + ? 2 + : CairoType.fromSol(type, this.ast, TypeConversionContext.Ref).width; + return [ + [ + `let (read_${index}) = ${readFunc.name}(${add('mem_loc', offset)}, ${uint256( + allocSize, + )});`, + `let (member${index})= ${memberGetterFunc.name}(read_${index});`, + ], + [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + memberGetterFunc, + readFunc, + ], + offset + 1, + ]; + } - return funcName; + const memberFeltSize = CairoType.fromSol(type, this.ast).width; + return [ + [`let (member${index}) = ${readFunc.name}(${add('mem_loc', offset)});`], + [readFunc], + offset + memberFeltSize, + ]; } } diff --git a/src/cairoUtilFuncGen/memory/memoryToStorage.ts b/src/cairoUtilFuncGen/memory/memoryToStorage.ts index e2d30f1e9..3291904ef 100644 --- a/src/cairoUtilFuncGen/memory/memoryToStorage.ts +++ b/src/cairoUtilFuncGen/memory/memoryToStorage.ts @@ -1,10 +1,10 @@ import assert from 'assert'; import { ArrayType, - ASTNode, BytesType, DataLocation, Expression, + FunctionCall, FunctionStateMutability, generalizeType, SourceUnit, @@ -14,10 +14,11 @@ import { UserDefinedType, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoFunctionDefinition } from '../../export'; import { printTypeNode } from '../../utils/astPrinter'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; -import { NotSupportedYetError, TranspileFailedError } from '../../utils/errors'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { NotSupportedYetError } from '../../utils/errors'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { getElementType, isDynamicArray, @@ -27,9 +28,10 @@ import { } from '../../utils/nodeTypeProcessing'; import { mapRange, narrowBigIntSafe, typeNameFromTypeNode } from '../../utils/utils'; import { uint256 } from '../../warplib/utils'; -import { add, delegateBasedOnType, StringIndexedFuncGen } from '../base'; +import { add, delegateBasedOnType, GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; import { DynArrayGen } from '../storage/dynArray'; import { StorageDeleteGen } from '../storage/storageDelete'; +import { MemoryReadGen } from './memoryRead'; /* Generates functions to copy data from warp_memory to WARP_STORAGE @@ -37,57 +39,61 @@ import { StorageDeleteGen } from '../storage/storageDelete'; These require extra care because the representations are different in storage and memory In storage nested structures are stored in place, whereas in memory 'pointers' are used */ +const IMPLICITS = + '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; export class MemoryToStorageGen extends StringIndexedFuncGen { - constructor( + public constructor( private dynArrayGen: DynArrayGen, + private memoryReadGen: MemoryReadGen, private storageDeleteGen: StorageDeleteGen, ast: AST, sourceUnit: SourceUnit, ) { super(ast, sourceUnit); } - gen( - storageLocation: Expression, - memoryLocation: Expression, - nodeInSourceUnit?: ASTNode, - ): Expression { + + public gen(storageLocation: Expression, memoryLocation: Expression): FunctionCall { const type = generalizeType(safeGetNodeType(storageLocation, this.ast.inference))[0]; + const funcDef = this.getOrCreateFuncDef(type); + return createCallToFunction(funcDef, [storageLocation, memoryLocation], this.ast); + } + + public getOrCreateFuncDef(type: TypeNode) { + const key = type.pp(); + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; + } - const name = this.getOrCreate(type); - const functionStub = createCairoFunctionStub( - name, + const funcInfo = this.getOrCreate(type); + const funcDef = createCairoGeneratedFunction( + funcInfo, [ ['loc', typeNameFromTypeNode(type, this.ast), DataLocation.Storage], ['mem_loc', typeNameFromTypeNode(type, this.ast), DataLocation.Memory], ], [['loc', typeNameFromTypeNode(type, this.ast), DataLocation.Storage]], - ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr', 'warp_memory'], this.ast, - nodeInSourceUnit ?? storageLocation, + this.sourceUnit, { mutability: FunctionStateMutability.View }, ); - return createCallToFunction(functionStub, [storageLocation, memoryLocation], this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - getOrCreate(type: TypeNode): string { - const key = type.pp(); - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - + private getOrCreate(type: TypeNode): GeneratedFunctionInfo { const unexpectedTypeFunc = () => { throw new NotSupportedYetError( `Copying ${printTypeNode(type)} from memory to storage not implemented yet`, ); }; - return delegateBasedOnType( + return delegateBasedOnType( type, - (type) => this.createDynamicArrayCopyFunction(key, type), - (type) => this.createStaticArrayCopyFunction(key, type), - (type) => this.createStructCopyFunction(key, type), + (type) => this.createDynamicArrayCopyFunction(type), + (type) => this.createStaticArrayCopyFunction(type), + (type, def) => this.createStructCopyFunction(type, def), unexpectedTypeFunc, unexpectedTypeFunc, ); @@ -95,80 +101,63 @@ export class MemoryToStorageGen extends StringIndexedFuncGen { // This can also be used for static arrays, in which case they are treated // like structs with members of the same type - private createStructCopyFunction(key: string, type: TypeNode): string { - const funcName = `wm_to_storage${this.generatedFunctions.size}`; - const implicits = - '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; - - // Set an empty entry so recursive function generation doesn't clash - this.generatedFunctions.set(key, { name: funcName, code: '' }); - - this.generatedFunctions.set(key, { + private createStructCopyFunction( + _type: UserDefinedType, + def: StructDefinition, + ): GeneratedFunctionInfo { + const funcName = `wm_to_storage_struct_${def.name}`; + + const [copyInstructions, funcsCalled] = this.generateTupleCopyInstructions( + def.vMembers.map((decl) => safeGetNodeType(decl, this.ast.inference)), + ); + return { name: funcName, code: [ - `func ${funcName}${implicits}(loc : felt, mem_loc: felt) -> (loc: felt){`, + `func ${funcName}${IMPLICITS}(loc : felt, mem_loc: felt) -> (loc: felt){`, ` alloc_locals;`, - ...generateCopyInstructions(type, this.ast).flatMap( - ({ storageOffset, copyType }, index) => { - const elemLoc = `elem_mem_loc_${index}`; - if (copyType === undefined) { - return [ - `let (${elemLoc}) = dict_read{dict_ptr=warp_memory}(${add('mem_loc', index)});`, - `WARP_STORAGE.write(${add('loc', storageOffset)}, ${elemLoc});`, - ]; - } else if (isDynamicArray(copyType)) { - this.requireImport('warplib.memory', 'wm_read_id'); - const funcName = this.getOrCreate(copyType); - return [ - `let (${elemLoc}) = wm_read_id(${add('mem_loc', index)}, ${uint256(2)});`, - `let (storage_dyn_array_loc) = readId(${add('loc', storageOffset)});`, - `${funcName}(storage_dyn_array_loc, ${elemLoc});`, - ]; - } else { - this.requireImport('warplib.memory', 'wm_read_id'); - const funcName = this.getOrCreate(copyType); - const copyTypeWidth = CairoType.fromSol( - copyType, - this.ast, - TypeConversionContext.Ref, - ).width; - return [ - `let (${elemLoc}) = wm_read_id(${add('mem_loc', index)}, ${uint256( - copyTypeWidth, - )});`, - `${funcName}(${add('loc', storageOffset)}, ${elemLoc});`, - ]; - } - }, - ), + ...copyInstructions, ` return (loc,);`, `}`, ].join('\n'), - }); - - this.requireImport('starkware.cairo.common.dict', 'dict_read'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - return funcName; + functionsCalled: funcsCalled, + }; } - private createStaticArrayCopyFunction(key: string, type: ArrayType): string { + private createStaticArrayCopyFunction(type: ArrayType): GeneratedFunctionInfo { assert(type.size !== undefined, 'Expected static array with known size'); return type.size <= 5 - ? this.createStructCopyFunction(key, type) - : this.createLargeStaticArrayCopyFunction(key, type); + ? this.createSmallStaticArrayCopyFunction(type) + : this.createLargeStaticArrayCopyFunction(type); + } + + private createSmallStaticArrayCopyFunction(type: ArrayType): GeneratedFunctionInfo { + assert(type.size !== undefined); + const size = narrowBigIntSafe(type.size, 'Static array size is unsupported'); + + const [copyInstructions, funcsCalled] = this.generateTupleCopyInstructions( + new Array(size).fill(type.elementT), + ); + + const funcName = `wm_to_storage_static_array_${this.generatedFunctionsDef.size}`; + return { + name: funcName, + code: [ + `func ${funcName}${IMPLICITS}(loc : felt, mem_loc: felt) -> (loc: felt){`, + ` alloc_locals;`, + ...copyInstructions, + ` return (loc,);`, + `}`, + ].join('\n'), + functionsCalled: funcsCalled, + }; } - private createLargeStaticArrayCopyFunction(key: string, type: ArrayType) { + private createLargeStaticArrayCopyFunction(type: ArrayType): GeneratedFunctionInfo { assert(type.size !== undefined, 'Expected static array with known size'); const length = narrowBigIntSafe( type.size, `Failed to narrow size of ${printTypeNode(type)} in memory->storage copy generation`, ); - const funcName = `wm_to_storage${this.generatedFunctions.size}`; - this.generatedFunctions.set(key, { name: funcName, code: '' }); - - const implicits = - '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; const elementStorageWidth = CairoType.fromSol( type.elementT, @@ -178,19 +167,26 @@ export class MemoryToStorageGen extends StringIndexedFuncGen { const elementMemoryWidth = CairoType.fromSol(type.elementT, this.ast).width; let copyCode: string; + let calledFuncs: CairoFunctionDefinition[]; if (isDynamicArray(type.elementT)) { + const readFunc = this.memoryReadGen.getOrCreateFuncDef(type.elementT); + const auxFunc = this.getOrCreateFuncDef(type.elementT); copyCode = [ ` let (storage_id) = readId(storage_loc);`, - ` let (read) = wm_read_id(mem_loc, ${uint256(2)});`, - ` ${this.getOrCreate(type.elementT)}(storage_id, read);`, + ` let (memory_id) = ${readFunc.name}(mem_loc, ${uint256(2)});`, + ` ${auxFunc.name}(storage_id, memory_id);`, ].join('\n'); + calledFuncs = [readFunc, auxFunc]; } else if (isStruct(type.elementT)) { + const readFunc = this.memoryReadGen.getOrCreateFuncDef(type.elementT); + const auxFunc = this.getOrCreateFuncDef(type.elementT); copyCode = [ - ` let (read) = wm_read_id{dict_ptr=warp_memory}(mem_loc, ${uint256( + ` let (memory_id) = ${readFunc.name}{dict_ptr=warp_memory}(mem_loc, ${uint256( elementMemoryWidth, )});`, - ` ${this.getOrCreate(type.elementT)}(storage_loc, read);`, + ` ${auxFunc.name}(storage_loc, memory_id);`, ].join('\n'); + calledFuncs = [readFunc, auxFunc]; } else { copyCode = mapRange(elementStorageWidth, (n) => [ @@ -198,59 +194,45 @@ export class MemoryToStorageGen extends StringIndexedFuncGen { ` WARP_STORAGE.write(${add('storage_loc', n)}, copy);`, ].join('\n'), ).join('\n'); + calledFuncs = [this.requireImport('starkware.cairo.common.dict', 'dict_read')]; } - this.generatedFunctions.set(key, { + const funcName = `wm_to_storage_static_array_${this.generatedFunctionsDef.size}`; + return { name: funcName, code: [ - `func ${funcName}_elem${implicits}(storage_loc: felt, mem_loc : felt, length: felt) -> (){`, + `func ${funcName}_elem${IMPLICITS}(storage_loc: felt, mem_loc : felt, length: felt) -> (){`, ` alloc_locals;`, ` if (length == 0){`, ` return ();`, ` }`, ` let index = length - 1;`, - copyCode, + ` ${copyCode}`, ` return ${funcName}_elem(${add('storage_loc', elementStorageWidth)}, ${add( 'mem_loc', elementMemoryWidth, )}, index);`, `}`, - `func ${funcName}${implicits}(loc : felt, mem_loc : felt) -> (loc : felt){`, + `func ${funcName}${IMPLICITS}(loc : felt, mem_loc : felt) -> (loc : felt){`, ` alloc_locals;`, ` ${funcName}_elem(loc, mem_loc, ${length});`, ` return (loc,);`, `}`, ].join('\n'), - }); - - this.requireImport('starkware.cairo.common.dict', 'dict_write'); - this.requireImport('warplib.memory', 'wm_alloc'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_sub'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - if (isReferenceType(type.elementT)) { - this.requireImport('warplib.memory', 'wm_read_id'); - } - - return funcName; + functionsCalled: calledFuncs, + }; } private createDynamicArrayCopyFunction( - key: string, type: ArrayType | BytesType | StringType, - ): string { - const funcName = `wm_to_storage${this.generatedFunctions.size}`; - - this.generatedFunctions.set(key, { name: funcName, code: '' }); - + ): GeneratedFunctionInfo { const elementT = getElementType(type); - const [elemMapping, lengthMapping] = this.dynArrayGen.gen( - CairoType.fromSol(elementT, this.ast, TypeConversionContext.StorageAllocation), - ); + const [dynArray, dynArrayLength] = this.dynArrayGen.getOrCreateFuncDef(elementT); - const implicits = - '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; + const elemMappingName = dynArray.name; + const lengthMappingName = dynArrayLength.name; const elementStorageWidth = CairoType.fromSol( elementT, @@ -258,18 +240,23 @@ export class MemoryToStorageGen extends StringIndexedFuncGen { TypeConversionContext.StorageAllocation, ).width; const elementMemoryWidth = CairoType.fromSol(elementT, this.ast).width; + let copyCode: string; - if (isDynamicArray(elementT)) { - copyCode = [ - ` let (storage_id) = readId(storage_loc);`, - ` let (read) = wm_read_id(mem_loc, ${uint256(2)});`, - ` ${this.getOrCreate(elementT)}(storage_id, read);`, - ].join('\n'); - } else if (isReferenceType(elementT)) { - copyCode = [ - ` let (read) = wm_read_id(mem_loc, ${uint256(elementMemoryWidth)});`, - ` ${this.getOrCreate(elementT)}(storage_loc, read);`, - ].join('\n'); + let funcCalls: CairoFunctionDefinition[]; + if (isReferenceType(elementT)) { + const readFunc = this.memoryReadGen.getOrCreateFuncDef(elementT); + const auxFunc = this.getOrCreateFuncDef(elementT); + copyCode = isDynamicArray(elementT) + ? [ + ` let (storage_id) = readId(storage_loc);`, + ` let (read) = ${readFunc.name}(mem_loc, ${uint256(2)});`, + ` ${auxFunc.name}(storage_id, read);`, + ].join('\n') + : [ + ` let (read) = ${readFunc.name}(mem_loc, ${uint256(elementMemoryWidth)});`, + ` ${auxFunc.name}(storage_loc, read);`, + ].join('\n'); + funcCalls = [readFunc, auxFunc]; } else { copyCode = mapRange(elementStorageWidth, (n) => [ @@ -277,40 +264,42 @@ export class MemoryToStorageGen extends StringIndexedFuncGen { ` WARP_STORAGE.write(${add('storage_loc', n)}, copy);`, ].join('\n'), ).join('\n'); + funcCalls = [this.requireImport('starkware.cairo.common.dict', 'dict_read')]; } - const deleteRemainingCode = `${this.storageDeleteGen.genAuxFuncName( - type, - )}(loc, mem_length, length);`; + const deleteFunc = this.storageDeleteGen.getOrCreateFuncDef(type); + const auxDeleteFuncName = deleteFunc.name + '_elem'; + const deleteRemainingCode = `${auxDeleteFuncName}(loc, mem_length, length);`; - this.generatedFunctions.set(key, { + const funcName = `wm_to_storage_dynamic_array${this.generatedFunctionsDef.size}`; + const funcInfo: GeneratedFunctionInfo = { name: funcName, code: [ - `func ${funcName}_elem${implicits}(storage_name: felt, mem_loc : felt, length: Uint256) -> (){`, + `func ${funcName}_elem${IMPLICITS}(storage_name: felt, mem_loc : felt, length: Uint256) -> (){`, ` alloc_locals;`, ` if (length.low == 0 and length.high == 0){`, ` return ();`, ` }`, ` let (index) = uint256_sub(length, Uint256(1,0));`, - ` let (storage_loc) = ${elemMapping}.read(storage_name, index);`, + ` let (storage_loc) = ${elemMappingName}.read(storage_name, index);`, ` let mem_loc = mem_loc - ${elementMemoryWidth};`, ` if (storage_loc == 0){`, ` let (storage_loc) = WARP_USED_STORAGE.read();`, ` WARP_USED_STORAGE.write(storage_loc + ${elementStorageWidth});`, - ` ${elemMapping}.write(storage_name, index, storage_loc);`, - copyCode, + ` ${elemMappingName}.write(storage_name, index, storage_loc);`, + ` ${copyCode}`, ` return ${funcName}_elem(storage_name, mem_loc, index);`, ` }else{`, - copyCode, + ` ${copyCode}`, ` return ${funcName}_elem(storage_name, mem_loc, index);`, ` }`, `}`, - `func ${funcName}${implicits}(loc : felt, mem_loc : felt) -> (loc : felt){`, + `func ${funcName}${IMPLICITS}(loc : felt, mem_loc : felt) -> (loc : felt){`, ` alloc_locals;`, - ` let (length) = ${lengthMapping}.read(loc);`, + ` let (length) = ${lengthMappingName}.read(loc);`, ` let (mem_length) = wm_dyn_array_length(mem_loc);`, - ` ${lengthMapping}.write(loc, mem_length);`, + ` ${lengthMappingName}.write(loc, mem_length);`, ` let (narrowedLength) = narrow_safe(mem_length);`, ` ${funcName}_elem(loc, mem_loc + 2 + ${elementMemoryWidth} * narrowedLength, mem_length);`, ` let (lesser) = uint256_lt(mem_length, length);`, @@ -322,60 +311,75 @@ export class MemoryToStorageGen extends StringIndexedFuncGen { ` }`, `}`, ].join('\n'), - }); - - this.requireImport('starkware.cairo.common.dict', 'dict_read'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_sub'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_lt'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('warplib.memory', 'wm_dyn_array_length'); - this.requireImport('warplib.maths.utils', 'narrow_safe'); - if (isReferenceType(elementT)) { - this.requireImport('warplib.memory', 'wm_read_id'); - } - - return funcName; + functionsCalled: [ + this.requireImport('warplib.maths.utils', 'narrow_safe'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_lt'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_sub'), + this.requireImport('warplib.memory', 'wm_dyn_array_length'), + ...funcCalls, + dynArray, + dynArrayLength, + deleteFunc, + ], + }; + return funcInfo; } -} - -type CopyInstruction = { - // The offset into the storage object to write to - storageOffset: number; - // If the copy requires a recursive call, this is the type to copy - copyType?: TypeNode; -}; - -function generateCopyInstructions(type: TypeNode, ast: AST): CopyInstruction[] { - let members: TypeNode[]; - - if (type instanceof UserDefinedType && type.definition instanceof StructDefinition) { - members = type.definition.vMembers.map((decl) => safeGetNodeType(decl, ast.inference)); - } else if (type instanceof ArrayType && type.size !== undefined) { - const narrowedWidth = narrowBigIntSafe(type.size, `Array size ${type.size} not supported`); - members = mapRange(narrowedWidth, () => type.elementT); - } else { - throw new TranspileFailedError( - `Attempted to create incorrect form of memory->storage copy for ${printTypeNode(type)}`, + private generateTupleCopyInstructions(types: TypeNode[]): [string[], CairoFunctionDefinition[]] { + const [code, funcCalls] = types.reduce( + ([code, funcCalls, storageOffset, memOffset], type, index) => { + const typeFeltWidth = getFeltWidth(type, this.ast); + const readFunc = this.memoryReadGen.getOrCreateFuncDef(type); + const elemLoc = `elem_mem_loc_${index}`; + if (isReferenceType(type)) { + const auxFunc = this.getOrCreateFuncDef(type); + const copyCode = isDynamicArray(type) + ? [ + `let (${elemLoc}) = ${readFunc.name}(${add('mem_loc', memOffset)}, ${uint256(2)});`, + `let (storage_dyn_array_loc) = readId(${add('loc', storageOffset)});`, + `${auxFunc.name}(storage_dyn_array_loc, ${elemLoc});`, + ] + : [ + `let (${elemLoc}) = ${readFunc.name}(${add('mem_loc', memOffset)}, ${uint256( + CairoType.fromSol(type, this.ast, TypeConversionContext.Ref).width, + )});`, + `${auxFunc.name}(${add('loc', storageOffset)}, ${elemLoc});`, + ]; + return [ + [...code, ...copyCode], + [ + ...funcCalls, + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + readFunc, + auxFunc, + ], + storageOffset + typeFeltWidth, + memOffset + 1, + ]; + } + return [ + [ + ...code, + ...mapRange(typeFeltWidth, (n) => + [ + `let (${elemLoc}_prt_${n}) = dict_read{dict_ptr=warp_memory}(${add( + 'mem_loc', + memOffset + n, + )});`, + `WARP_STORAGE.write(${add('loc', storageOffset + n)}, ${elemLoc}_prt_${n});`, + ].join('\n'), + ), + ], + [...funcCalls, this.requireImport('starkware.cairo.common.dict', 'dict_read')], + storageOffset + typeFeltWidth, + memOffset + typeFeltWidth, + ]; + }, + [new Array(), new Array(), 0, 0], ); + return [code, funcCalls]; } +} - let storageOffset = 0; - return members.flatMap((memberType) => { - if (isReferenceType(memberType)) { - const offset = storageOffset; - storageOffset += CairoType.fromSol( - memberType, - ast, - TypeConversionContext.StorageAllocation, - ).width; - return [{ storageOffset: offset, copyType: memberType }]; - } else { - const width = CairoType.fromSol( - memberType, - ast, - TypeConversionContext.StorageAllocation, - ).width; - return mapRange(width, () => ({ storageOffset: storageOffset++ })); - } - }); +function getFeltWidth(type: TypeNode, ast: AST): number { + return CairoType.fromSol(type, ast, TypeConversionContext.StorageAllocation).width; } diff --git a/src/cairoUtilFuncGen/memory/memoryWrite.ts b/src/cairoUtilFuncGen/memory/memoryWrite.ts index f76d4d676..dfed5d024 100644 --- a/src/cairoUtilFuncGen/memory/memoryWrite.ts +++ b/src/cairoUtilFuncGen/memory/memoryWrite.ts @@ -1,73 +1,77 @@ -import { - Expression, - FunctionCall, - TypeNode, - ASTNode, - DataLocation, - PointerType, -} from 'solc-typed-ast'; +import { Expression, FunctionCall, TypeNode, DataLocation, PointerType } from 'solc-typed-ast'; import { CairoFelt, CairoType, CairoUint256 } from '../../utils/cairoTypeSystem'; import { cloneASTNode } from '../../utils/cloning'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { + createCairoGeneratedFunction, + createCallToFunction, + ParameterInfo, +} from '../../utils/functionGeneration'; import { safeGetNodeType } from '../../utils/nodeTypeProcessing'; import { typeNameFromTypeNode } from '../../utils/utils'; -import { add, StringIndexedFuncGen } from '../base'; +import { add, GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; /* Produces functions to write a given value into warp_memory, returning that value (to simulate assignments) This involves serialising the data into a series of felts and writing each one into the DictAccess */ export class MemoryWriteGen extends StringIndexedFuncGen { - gen(memoryRef: Expression, writeValue: Expression, nodeInSourceUnit?: ASTNode): FunctionCall { + public gen(memoryRef: Expression, writeValue: Expression): FunctionCall { const typeToWrite = safeGetNodeType(memoryRef, this.ast.inference); - const name = this.getOrCreate(typeToWrite); + const funcDef = this.getOrCreateFuncDef(typeToWrite); + return createCallToFunction(funcDef, [memoryRef, writeValue], this.ast); + } + + public getOrCreateFuncDef(typeToWrite: TypeNode) { + const key = typeToWrite.pp(); + const existing = this.generatedFunctionsDef.get(key); + if (existing !== undefined) { + return existing; + } + const argTypeName = typeNameFromTypeNode(typeToWrite, this.ast); - const functionStub = createCairoFunctionStub( - name, + const inputs: ParameterInfo[] = [ + ['loc', argTypeName, DataLocation.Memory], [ - ['loc', argTypeName, DataLocation.Memory], - [ - 'value', - cloneASTNode(argTypeName, this.ast), - typeToWrite instanceof PointerType ? DataLocation.Memory : DataLocation.Default, - ], + 'value', + cloneASTNode(argTypeName, this.ast), + typeToWrite instanceof PointerType ? DataLocation.Memory : DataLocation.Default, ], + ]; + const outputs: ParameterInfo[] = [ [ - [ - 'res', - cloneASTNode(argTypeName, this.ast), - typeToWrite instanceof PointerType ? DataLocation.Memory : DataLocation.Default, - ], + 'res', + cloneASTNode(argTypeName, this.ast), + typeToWrite instanceof PointerType ? DataLocation.Memory : DataLocation.Default, ], - ['warp_memory'], - this.ast, - nodeInSourceUnit ?? memoryRef, - ); - return createCallToFunction(functionStub, [memoryRef, writeValue], this.ast); - } + ]; - getOrCreate(typeToWrite: TypeNode): string { const cairoTypeToWrite = CairoType.fromSol(typeToWrite, this.ast); - if (cairoTypeToWrite instanceof CairoFelt) { - this.requireImport('warplib.memory', 'wm_write_felt'); - return 'wm_write_felt'; + return this.requireImport('warplib.memory', 'wm_write_felt', inputs, outputs); } else if ( cairoTypeToWrite.fullStringRepresentation === CairoUint256.fullStringRepresentation ) { - this.requireImport('warplib.memory', 'wm_write_256'); - return 'wm_write_256'; + return this.requireImport('warplib.memory', 'wm_write_256', inputs, outputs); } - const key = cairoTypeToWrite.fullStringRepresentation; - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } + const funcInfo = this.getOrCreate(typeToWrite); + const funcDef = createCairoGeneratedFunction( + funcInfo, + inputs, + outputs, + this.ast, + this.sourceUnit, + ); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; + } + + private getOrCreate(typeToWrite: TypeNode): GeneratedFunctionInfo { + const cairoTypeToWrite = CairoType.fromSol(typeToWrite, this.ast); const cairoTypeString = cairoTypeToWrite.toString(); - const funcName = `WM_WRITE${this.generatedFunctions.size}`; - this.generatedFunctions.set(key, { + const funcName = `WM_WRITE${this.generatedFunctionsDef.size}`; + const funcInfo: GeneratedFunctionInfo = { name: funcName, code: [ `func ${funcName}{warp_memory : DictAccess*}(loc: felt, value: ${cairoTypeString}) -> (res: ${cairoTypeString}){`, @@ -77,11 +81,9 @@ export class MemoryWriteGen extends StringIndexedFuncGen { ' return (value,);', '}', ].join('\n'), - }); - - this.requireImport('starkware.cairo.common.dict', 'dict_write'); - - return funcName; + functionsCalled: [this.requireImport('starkware.cairo.common.dict', 'dict_write')], + }; + return funcInfo; } } diff --git a/src/cairoUtilFuncGen/memory/staticIndexAccess.ts b/src/cairoUtilFuncGen/memory/staticIndexAccess.ts index b9e1f2124..8f2d61b22 100644 --- a/src/cairoUtilFuncGen/memory/staticIndexAccess.ts +++ b/src/cairoUtilFuncGen/memory/staticIndexAccess.ts @@ -1,8 +1,8 @@ import assert = require('assert'); -import { ArrayType, ASTNode, DataLocation, FunctionCall, IndexAccess } from 'solc-typed-ast'; +import { ArrayType, DataLocation, FunctionCall, IndexAccess } from 'solc-typed-ast'; import { printNode } from '../../utils/astPrinter'; import { CairoType } from '../../utils/cairoTypeSystem'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCallToFunction } from '../../utils/functionGeneration'; import { createNumberLiteral, createUint256TypeName } from '../../utils/nodeTemplates'; import { typeNameFromTypeNode } from '../../utils/utils'; import { CairoUtilFuncGenBase } from '../base'; @@ -15,16 +15,18 @@ import { CairoUtilFuncGenBase } from '../base'; as parameters to avoid bloating the code with separate functions for each case */ export class MemoryStaticArrayIndexAccessGen extends CairoUtilFuncGenBase { - getGeneratedCode(): string { - return ''; - } - - gen(indexAccess: IndexAccess, arrayType: ArrayType, nodeInSourceUnit?: ASTNode): FunctionCall { + gen(indexAccess: IndexAccess, arrayType: ArrayType): FunctionCall { assert( arrayType.size !== undefined, `Attempted to use static indexing for dynamic index ${printNode(indexAccess)}`, ); - const stub = createCairoFunctionStub( + assert( + indexAccess.vIndexExpression, + `Found index access without index expression at ${printNode(indexAccess)}`, + ); + + const importFunc = this.requireImport( + 'warplib.memory', 'wm_index_static', [ ['arr', typeNameFromTypeNode(arrayType, this.ast), DataLocation.Memory], @@ -33,21 +35,10 @@ export class MemoryStaticArrayIndexAccessGen extends CairoUtilFuncGenBase { ['length', createUint256TypeName(this.ast)], ], [['child', typeNameFromTypeNode(arrayType.elementT, this.ast), DataLocation.Memory]], - ['range_check_ptr'], - this.ast, - nodeInSourceUnit ?? indexAccess, ); - - this.ast.registerImport(stub, 'warplib.memory', 'wm_index_static'); - const width = CairoType.fromSol(arrayType.elementT, this.ast).width; - - assert( - indexAccess.vIndexExpression, - `Found index access without index expression at ${printNode(indexAccess)}`, - ); return createCallToFunction( - stub, + importFunc, [ indexAccess.vBaseExpression, indexAccess.vIndexExpression, diff --git a/src/cairoUtilFuncGen/storage/copyToStorage.ts b/src/cairoUtilFuncGen/storage/copyToStorage.ts index 490ae2510..d059c2ba4 100644 --- a/src/cairoUtilFuncGen/storage/copyToStorage.ts +++ b/src/cairoUtilFuncGen/storage/copyToStorage.ts @@ -1,11 +1,11 @@ import assert from 'assert'; import { ArrayType, - ASTNode, BytesType, DataLocation, Expression, FixedBytesType, + FunctionCall, FunctionStateMutability, generalizeType, IntType, @@ -13,13 +13,13 @@ import { StringType, StructDefinition, TypeNode, - UserDefinedType, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoFunctionDefinition } from '../../export'; import { printTypeNode } from '../../utils/astPrinter'; import { CairoType, TypeConversionContext, WarpLocation } from '../../utils/cairoTypeSystem'; import { TranspileFailedError } from '../../utils/errors'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { getElementType, getSize, @@ -28,10 +28,13 @@ import { } from '../../utils/nodeTypeProcessing'; import { mapRange, narrowBigIntSafe, typeNameFromTypeNode } from '../../utils/utils'; import { uint256 } from '../../warplib/utils'; -import { add, CairoFunction, delegateBasedOnType, StringIndexedFuncGen } from '../base'; +import { add, delegateBasedOnType, GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; import { DynArrayGen } from './dynArray'; import { StorageDeleteGen } from './storageDelete'; +const IMPLICITS = + '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, bitwise_ptr : BitwiseBuiltin*}'; + /* Generates functions to copy data from WARP_STORAGE to WARP_STORAGE The main point of care here is to copy dynamic arrays. Mappings and types containing them @@ -40,7 +43,7 @@ import { StorageDeleteGen } from './storageDelete'; */ export class StorageToStorageGen extends StringIndexedFuncGen { - constructor( + public constructor( private dynArrayGen: DynArrayGen, private storageDeleteGen: StorageDeleteGen, ast: AST, @@ -48,39 +51,38 @@ export class StorageToStorageGen extends StringIndexedFuncGen { ) { super(ast, sourceUnit); } - gen(to: Expression, from: Expression, nodeInSourceUnit?: ASTNode): Expression { + public gen(to: Expression, from: Expression): FunctionCall { const toType = generalizeType(safeGetNodeType(to, this.ast.inference))[0]; const fromType = generalizeType(safeGetNodeType(from, this.ast.inference))[0]; + const funcDef = this.getOrCreateFuncDef(toType, fromType); - const name = this.getOrCreate(toType, fromType); - const functionStub = createCairoFunctionStub( - name, + return createCallToFunction(funcDef, [to, from], this.ast); + } + + public getOrCreateFuncDef(toType: TypeNode, fromType: TypeNode) { + const key = `${fromType.pp()}->${toType.pp()}`; + const exisiting = this.generatedFunctionsDef.get(key); + if (exisiting !== undefined) { + return exisiting; + } + const funcInfo = this.getOrCreate(toType, fromType); + const funcDef = createCairoGeneratedFunction( + funcInfo, [ ['toLoc', typeNameFromTypeNode(toType, this.ast), DataLocation.Storage], ['fromLoc', typeNameFromTypeNode(fromType, this.ast), DataLocation.Storage], ], [['retLoc', typeNameFromTypeNode(toType, this.ast), DataLocation.Storage]], - ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr', 'bitwise_ptr'], this.ast, - nodeInSourceUnit ?? to, + this.sourceUnit, { mutability: FunctionStateMutability.View }, ); - return createCallToFunction(functionStub, [to, from], this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - getOrCreate(toType: TypeNode, fromType: TypeNode): string { - const key = `${fromType.pp()}->${toType.pp()}`; - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - const funcName = `ws_copy${this.generatedFunctions.size}`; - - // Set an empty entry so recursive function generation doesn't clash - this.generatedFunctions.set(key, { name: funcName, code: '' }); - - const cairoFunction = delegateBasedOnType( + private getOrCreate(toType: TypeNode, fromType: TypeNode): GeneratedFunctionInfo { + const funcInfo = delegateBasedOnType( toType, (toType) => { assert( @@ -89,75 +91,82 @@ export class StorageToStorageGen extends StringIndexedFuncGen { fromType instanceof StringType, ); if (getSize(fromType) === undefined) { - return this.createDynamicArrayCopyFunction(funcName, toType, fromType); + return this.createDynamicArrayCopyFunction(toType, fromType); } else { assert(fromType instanceof ArrayType); - return this.createStaticToDynamicArrayCopyFunction(funcName, toType, fromType); + return this.createStaticToDynamicArrayCopyFunction(toType, fromType); } }, (toType) => { assert(fromType instanceof ArrayType); - return this.createStaticArrayCopyFunction(funcName, toType, fromType); + return this.createStaticArrayCopyFunction(toType, fromType); }, - (toType) => this.createStructCopyFunction(funcName, toType), + (_toType, def) => this.createStructCopyFunction(def), () => { throw new TranspileFailedError('Attempted to create mapping clone function'); }, (toType) => { if (toType instanceof IntType) { assert(fromType instanceof IntType); - return this.createIntegerCopyFunction(funcName, toType, fromType); + return this.createIntegerCopyFunction(toType, fromType); } else if (toType instanceof FixedBytesType) { assert(fromType instanceof FixedBytesType); - return this.createFixedBytesCopyFunction(funcName, toType, fromType); + return this.createFixedBytesCopyFunction(toType, fromType); } else { - return this.createValueTypeCopyFunction(funcName, toType); + return this.createValueTypeCopyFunction(toType); } }, ); - - this.generatedFunctions.set(key, cairoFunction); - return cairoFunction.name; + return funcInfo; } - private createStructCopyFunction(funcName: string, type: UserDefinedType): CairoFunction { - const def = type.definition; - assert(def instanceof StructDefinition); + private createStructCopyFunction(def: StructDefinition): GeneratedFunctionInfo { const members = def.vMembers.map((decl) => safeGetNodeType(decl, this.ast.inference)); + const [copyCode, funcsCalled] = members.reduce( + ([copyCode, funcsCalled, offset], memberType) => { + const width = CairoType.fromSol( + memberType, + this.ast, + TypeConversionContext.StorageAllocation, + ).width; + + if (isReferenceType(memberType)) { + const memberCopyFunc = this.getOrCreateFuncDef(memberType, memberType); + const toLoc = add('to_loc', offset); + const fromLoc = add('from_loc', offset); + return [ + [...copyCode, `${memberCopyFunc.name}(${toLoc}, ${fromLoc});`], + [...funcsCalled, memberCopyFunc], + offset + width, + ]; + } + return [ + [...copyCode, mapRange(width, (index) => copyAtOffset(index + offset)).join('\n')], + funcsCalled, + offset + width, + ]; + }, + [new Array(), new Array(), 0], + ); - let offset = 0; + const funcName = `WS_COPY_STRUCT_${def.name}`; return { name: funcName, code: [ - `func ${funcName}${implicits}(to_loc: felt, from_loc: felt) -> (retLoc: felt){`, + `func ${funcName}${IMPLICITS}(to_loc: felt, from_loc: felt) -> (retLoc: felt){`, ` alloc_locals;`, - ...members.map((memberType): string => { - const width = CairoType.fromSol( - memberType, - this.ast, - TypeConversionContext.StorageAllocation, - ).width; - let code: string; - if (isReferenceType(memberType)) { - const memberCopyFunc = this.getOrCreate(memberType, memberType); - code = `${memberCopyFunc}(${add('to_loc', offset)}, ${add('from_loc', offset)});`; - } else { - code = mapRange(width, (index) => copyAtOffset(index + offset)).join('\n'); - } - offset += width; - return code; - }), + ...copyCode, ` return (to_loc,);`, `}`, ].join('\n'), + functionsCalled: funcsCalled, }; } private createStaticArrayCopyFunction( - funcName: string, toType: ArrayType, fromType: ArrayType, - ): CairoFunction { + ): GeneratedFunctionInfo { assert( toType.size !== undefined, `Attempted to copy to storage dynamic array as static array in ${printTypeNode( @@ -171,7 +180,7 @@ export class StorageToStorageGen extends StringIndexedFuncGen { )}->${printTypeNode(toType)}`, ); - const elementCopyFunc = this.getOrCreate(toType.elementT, fromType.elementT); + const elementCopyFunc = this.getOrCreateFuncDef(toType.elementT, fromType.elementT); const toElemType = CairoType.fromSol( toType.elementT, @@ -183,22 +192,27 @@ export class StorageToStorageGen extends StringIndexedFuncGen { this.ast, TypeConversionContext.StorageAllocation, ); - const copyCode = createElementCopy(toElemType, fromElemType, elementCopyFunc); + const copyCode = createElementCopy(toElemType, fromElemType, elementCopyFunc.name); const fromSize = narrowBigIntSafe(fromType.size); const toSize = narrowBigIntSafe(toType.size); - let stopRecursion; + + const funcName = `WS_COPY_STATIC_${this.generatedFunctionsDef.size}`; + let optionalCalls: CairoFunctionDefinition[]; + let stopRecursion: string[]; if (fromSize === toSize) { + optionalCalls = []; stopRecursion = [`if (index == ${fromSize}){`, `return ();`, `}`]; } else { - this.requireImport('starkware.cairo.common.math_cmp', 'is_le'); + const deleteFunc = this.storageDeleteGen.getOrCreateFuncDef(toType.elementT); + optionalCalls = [deleteFunc, this.requireImport('starkware.cairo.common.math_cmp', 'is_le')]; stopRecursion = [ `if (index == ${toSize}){`, ` return ();`, `}`, `let lesser = is_le(index, ${fromSize - 1});`, `if (lesser == 0){`, - ` ${this.storageDeleteGen.genFuncName(toType.elementT)}(to_elem_loc);`, + ` ${deleteFunc.name}(to_elem_loc);`, ` return ${funcName}_elem(to_elem_loc + ${toElemType.width}, from_elem_loc, index + 1);`, `}`, ]; @@ -207,24 +221,24 @@ export class StorageToStorageGen extends StringIndexedFuncGen { return { name: funcName, code: [ - `func ${funcName}_elem${implicits}(to_elem_loc: felt, from_elem_loc: felt, index: felt) -> (){`, + `func ${funcName}_elem${IMPLICITS}(to_elem_loc: felt, from_elem_loc: felt, index: felt) -> (){`, ...stopRecursion, ` ${copyCode('to_elem_loc', 'from_elem_loc')}`, ` return ${funcName}_elem(to_elem_loc + ${toElemType.width}, from_elem_loc + ${fromElemType.width}, index + 1);`, `}`, - `func ${funcName}${implicits}(to_elem_loc: felt, from_elem_loc: felt) -> (retLoc: felt){`, + `func ${funcName}${IMPLICITS}(to_elem_loc: felt, from_elem_loc: felt) -> (retLoc: felt){`, ` ${funcName}_elem(to_elem_loc, from_elem_loc, 0);`, ` return (to_elem_loc,);`, `}`, ].join('\n'), + functionsCalled: [elementCopyFunc, ...optionalCalls], }; } private createDynamicArrayCopyFunction( - funcName: string, toType: ArrayType | BytesType | StringType, fromType: ArrayType | BytesType | StringType, - ): CairoFunction { + ): GeneratedFunctionInfo { const fromElementT = getElementType(fromType); const fromSize = getSize(fromType); const toElementT = getElementType(toType); @@ -232,11 +246,8 @@ export class StorageToStorageGen extends StringIndexedFuncGen { assert(toSize === undefined, 'Attempted to copy to storage static array as dynamic array'); assert(fromSize === undefined, 'Attempted to copy from storage static array as dynamic array'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_sub'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_lt'); + const elementCopyFunc = this.getOrCreateFuncDef(toElementT, fromElementT); - const elementCopyFunc = this.getOrCreate(toElementT, fromElementT); const fromElementCairoType = CairoType.fromSol( fromElementT, this.ast, @@ -247,30 +258,41 @@ export class StorageToStorageGen extends StringIndexedFuncGen { this.ast, TypeConversionContext.StorageAllocation, ); - const [fromElementMapping, fromLengthMapping] = this.dynArrayGen.gen(fromElementCairoType); - const [toElementMapping, toLengthMapping] = this.dynArrayGen.gen(toElementCairoType); - const copyCode = createElementCopy(toElementCairoType, fromElementCairoType, elementCopyFunc); + const [fromElementMapping, fromLengthMapping] = + this.dynArrayGen.getOrCreateFuncDef(fromElementT); + const fromElementMappingName = fromElementMapping.name; + const fromLengthMappingName = fromLengthMapping.name; - const deleteRemainingCode = `${this.storageDeleteGen.genAuxFuncName( - toType, - )}(to_loc, from_length, to_length)`; + const [toElementMapping, toLengthMapping] = this.dynArrayGen.getOrCreateFuncDef(toElementT); + const toElementMappingName = toElementMapping.name; + const toLengthMappingName = toLengthMapping.name; + const copyCode = createElementCopy( + toElementCairoType, + fromElementCairoType, + elementCopyFunc.name, + ); + + const deleteFunc = this.storageDeleteGen.getOrCreateFuncDef(toType); + const deleteRemainingCode = `${deleteFunc.name}_elem(to_loc, from_length, to_length)`; + + const funcName = `WS_COPY_DYNAMIC_${this.generatedFunctionsDef.size}`; return { name: funcName, code: [ - `func ${funcName}_elem${implicits}(to_loc: felt, from_loc: felt, length: Uint256) -> (){`, + `func ${funcName}_elem${IMPLICITS}(to_loc: felt, from_loc: felt, length: Uint256) -> (){`, ` alloc_locals;`, ` if (length.low == 0 and length.high == 0){`, ` return ();`, ` }`, ` let (index) = uint256_sub(length, Uint256(1,0));`, - ` let (from_elem_loc) = ${fromElementMapping}.read(from_loc, index);`, - ` let (to_elem_loc) = ${toElementMapping}.read(to_loc, index);`, + ` let (from_elem_loc) = ${fromElementMappingName}.read(from_loc, index);`, + ` let (to_elem_loc) = ${toElementMappingName}.read(to_loc, index);`, ` if (to_elem_loc == 0){`, ` let (to_elem_loc) = WARP_USED_STORAGE.read();`, ` WARP_USED_STORAGE.write(to_elem_loc + ${toElementCairoType.width});`, - ` ${toElementMapping}.write(to_loc, index, to_elem_loc);`, + ` ${toElementMappingName}.write(to_loc, index, to_elem_loc);`, ` ${copyCode('to_elem_loc', 'from_elem_loc')}`, ` return ${funcName}_elem(to_loc, from_loc, index);`, ` }else{`, @@ -278,11 +300,11 @@ export class StorageToStorageGen extends StringIndexedFuncGen { ` return ${funcName}_elem(to_loc, from_loc, index);`, ` }`, `}`, - `func ${funcName}${implicits}(to_loc: felt, from_loc: felt) -> (retLoc: felt){`, + `func ${funcName}${IMPLICITS}(to_loc: felt, from_loc: felt) -> (retLoc: felt){`, ` alloc_locals;`, - ` let (from_length) = ${fromLengthMapping}.read(from_loc);`, - ` let (to_length) = ${toLengthMapping}.read(to_loc);`, - ` ${toLengthMapping}.write(to_loc, from_length);`, + ` let (from_length) = ${fromLengthMappingName}.read(from_loc);`, + ` let (to_length) = ${toLengthMappingName}.read(to_loc);`, + ` ${toLengthMappingName}.write(to_loc, from_length);`, ` ${funcName}_elem(to_loc, from_loc, from_length);`, ` let (lesser) = uint256_lt(from_length, to_length);`, ` if (lesser == 1){`, @@ -293,24 +315,31 @@ export class StorageToStorageGen extends StringIndexedFuncGen { ` }`, `}`, ].join('\n'), + functionsCalled: [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_sub'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_lt'), + elementCopyFunc, + fromElementMapping, + fromLengthMapping, + toElementMapping, + toLengthMapping, + deleteFunc, + ], }; } private createStaticToDynamicArrayCopyFunction( - funcName: string, toType: ArrayType | BytesType | StringType, fromType: ArrayType, - ): CairoFunction { + ): GeneratedFunctionInfo { const toSize = getSize(toType); const toElementT = getElementType(toType); assert(fromType.size !== undefined); assert(toSize === undefined); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_add'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_lt'); + const elementCopyFunc = this.getOrCreateFuncDef(toElementT, fromType.elementT); - const elementCopyFunc = this.getOrCreate(toElementT, fromType.elementT); const fromElementCairoType = CairoType.fromSol( fromType.elementT, this.ast, @@ -321,30 +350,37 @@ export class StorageToStorageGen extends StringIndexedFuncGen { this.ast, TypeConversionContext.StorageAllocation, ); - const [toElementMapping, toLengthMapping] = this.dynArrayGen.gen(toElementCairoType); - const copyCode = createElementCopy(toElementCairoType, fromElementCairoType, elementCopyFunc); - const deleteRemainingCode = `${this.storageDeleteGen.genAuxFuncName( - toType, - )}(to_loc, from_length, to_length)`; + const [toElementMapping, toLengthMapping] = this.dynArrayGen.getOrCreateFuncDef(toElementT); + const toElementMappingName = toElementMapping.name; + const toLengthMappingName = toLengthMapping.name; + const copyCode = createElementCopy( + toElementCairoType, + fromElementCairoType, + elementCopyFunc.name, + ); + const deleteFunc = this.storageDeleteGen.getOrCreateFuncDef(toType); + const deleteRemainingCode = `${deleteFunc.name}_elem(to_loc, from_length, to_length)`; + + const funcName = `WS_COPY_STATIC_TO_DYNAMIC_${this.generatedFunctionsDef.size}`; return { name: funcName, code: [ - `func ${funcName}_elem${implicits}(to_loc: felt, from_elem_loc: felt, length: Uint256, index: Uint256) -> (){`, + `func ${funcName}_elem${IMPLICITS}(to_loc: felt, from_elem_loc: felt, length: Uint256, index: Uint256) -> (){`, ` alloc_locals;`, ` if (length.low == index.low){`, ` if (length.high == index.high){`, ` return ();`, ` }`, ` }`, - ` let (to_elem_loc) = ${toElementMapping}.read(to_loc, index);`, + ` let (to_elem_loc) = ${toElementMappingName}.read(to_loc, index);`, ` let (next_index, carry) = uint256_add(index, Uint256(1,0));`, ` assert carry = 0;`, ` if (to_elem_loc == 0){`, ` let (to_elem_loc) = WARP_USED_STORAGE.read();`, ` WARP_USED_STORAGE.write(to_elem_loc + ${toElementCairoType.width});`, - ` ${toElementMapping}.write(to_loc, index, to_elem_loc);`, + ` ${toElementMappingName}.write(to_loc, index, to_elem_loc);`, ` ${copyCode('to_elem_loc', 'from_elem_loc')}`, ` return ${funcName}_elem(to_loc, from_elem_loc + ${fromElementCairoType.width}, length, next_index);`, ` }else{`, @@ -352,11 +388,11 @@ export class StorageToStorageGen extends StringIndexedFuncGen { ` return ${funcName}_elem(to_loc, from_elem_loc + ${fromElementCairoType.width}, length, next_index);`, ` }`, `}`, - `func ${funcName}${implicits}(to_loc: felt, from_loc: felt) -> (retLoc: felt){`, + `func ${funcName}${IMPLICITS}(to_loc: felt, from_loc: felt) -> (retLoc: felt){`, ` alloc_locals;`, ` let from_length = ${uint256(narrowBigIntSafe(fromType.size))};`, - ` let (to_length) = ${toLengthMapping}.read(to_loc);`, - ` ${toLengthMapping}.write(to_loc, from_length);`, + ` let (to_length) = ${toLengthMappingName}.read(to_loc);`, + ` ${toLengthMappingName}.write(to_loc, from_length);`, ` ${funcName}_elem(to_loc, from_loc, from_length , Uint256(0,0));`, ` let (lesser) = uint256_lt(from_length, to_length);`, ` if (lesser == 1){`, @@ -367,29 +403,24 @@ export class StorageToStorageGen extends StringIndexedFuncGen { ` }`, `}`, ].join('\n'), + functionsCalled: [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_add'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_lt'), + elementCopyFunc, + toElementMapping, + toLengthMapping, + deleteFunc, + ], }; } - private createIntegerCopyFunction( - funcName: string, - toType: IntType, - fromType: IntType, - ): CairoFunction { + private createIntegerCopyFunction(toType: IntType, fromType: IntType): GeneratedFunctionInfo { assert( fromType.nBits <= toType.nBits, `Attempted to scale integer ${fromType.nBits} to ${toType.nBits}`, ); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - if (toType.signed) { - this.requireImport( - 'warplib.maths.int_conversions', - `warp_int${fromType.nBits}_to_int${toType.nBits}`, - ); - } else { - this.requireImport('warplib.maths.utils', 'felt_to_uint256'); - } - // Read changes depending if From is 256 bits or less const readFromCode = fromType.nBits === 256 @@ -418,10 +449,11 @@ export class StorageToStorageGen extends StringIndexedFuncGen { ].join('\n') : 'WARP_STORAGE.write(to_loc, to_elem);'; + const funcName = `WS_COPY_INTEGER_${this.generatedFunctionsDef.size}`; return { name: funcName, code: [ - `func ${funcName}${implicits}(to_loc : felt, from_loc : felt) -> (ret_loc : felt){`, + `func ${funcName}${IMPLICITS}(to_loc : felt, from_loc : felt) -> (ret_loc : felt){`, ` alloc_locals;`, ` ${readFromCode}`, ` ${scalingCode}`, @@ -429,19 +461,26 @@ export class StorageToStorageGen extends StringIndexedFuncGen { ` return (to_loc,);`, `}`, ].join('\n'), + functionsCalled: [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + toType.signed + ? this.requireImport( + 'warplib.maths.int_conversions', + `warp_int${fromType.nBits}_to_int${toType.nBits}`, + ) + : this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + ], }; } private createFixedBytesCopyFunction( - funcName: string, toType: FixedBytesType, fromType: FixedBytesType, - ) { + ): GeneratedFunctionInfo { const bitWidthDiff = (toType.size - fromType.size) * 8; assert(bitWidthDiff >= 0, `Attempted to scale fixed byte ${fromType.size} to ${toType.size}`); const conversionFunc = toType.size === 32 ? 'warp_bytes_widen_256' : 'warp_bytes_widen'; - this.requireImport('warplib.maths.bytes_conversions', conversionFunc); const readFromCode = fromType.size === 32 @@ -465,10 +504,11 @@ export class StorageToStorageGen extends StringIndexedFuncGen { ].join('\n') : 'WARP_STORAGE.write(to_loc, to_elem);'; + const funcName = `WS_COPY_FIXED_BYTES_${this.generatedFunctionsDef.size}`; return { name: funcName, code: [ - `func ${funcName}${implicits}(to_loc : felt, from_loc : felt) -> (ret_loc : felt){`, + `func ${funcName}${IMPLICITS}(to_loc : felt, from_loc : felt) -> (ret_loc : felt){`, ` alloc_locals;`, ` ${readFromCode}`, ` ${scalingCode}`, @@ -476,28 +516,31 @@ export class StorageToStorageGen extends StringIndexedFuncGen { ` return (to_loc,);`, `}`, ].join('\n'), + functionsCalled: [ + this.requireImport('warplib.maths.bytes_conversions', conversionFunc), + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + ], }; } - private createValueTypeCopyFunction(funcName: string, type: TypeNode): CairoFunction { + private createValueTypeCopyFunction(type: TypeNode): GeneratedFunctionInfo { const width = CairoType.fromSol(type, this.ast, TypeConversionContext.StorageAllocation).width; + const funcName = `WS_COPY_VALUE_${this.generatedFunctionsDef.size}`; return { name: funcName, code: [ - `func ${funcName}${implicits}(to_loc : felt, from_loc : felt) -> (ret_loc : felt){`, + `func ${funcName}${IMPLICITS}(to_loc : felt, from_loc : felt) -> (ret_loc : felt){`, ` alloc_locals;`, ...mapRange(width, copyAtOffset), ` return (to_loc,);`, `}`, ].join('\n'), + functionsCalled: [], }; } } -const implicits = - '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, bitwise_ptr : BitwiseBuiltin*}'; - function copyAtOffset(n: number): string { return [ `let (copy) = WARP_STORAGE.read(${add('from_loc', n)});`, @@ -505,6 +548,8 @@ function copyAtOffset(n: number): string { ].join('\n'); } +// TODO: There is a bunch of `readId` here! +// Do they need to be imported function createElementCopy( toElementCairoType: CairoType, fromElementCairoType: CairoType, diff --git a/src/cairoUtilFuncGen/storage/dynArray.ts b/src/cairoUtilFuncGen/storage/dynArray.ts index de7555ccc..6ee79d0a8 100644 --- a/src/cairoUtilFuncGen/storage/dynArray.ts +++ b/src/cairoUtilFuncGen/storage/dynArray.ts @@ -1,72 +1,100 @@ -import { CairoType } from '../../utils/cairoTypeSystem'; -import { StringIndexedFuncGen } from '../base'; -import { INCLUDE_CAIRO_DUMP_FUNCTIONS } from '../../cairoWriter/utils'; +import assert from 'assert'; +import { + ArrayType, + BytesType, + FunctionCall, + FunctionStateMutability, + MemberAccess, + StringType, + TypeNode, +} from 'solc-typed-ast'; +import { + CairoFunctionDefinition, + createCairoGeneratedFunction, + createCallToFunction, + createUint256TypeName, + createUintNTypeName, + FunctionStubKind, + getElementType, +} from '../../export'; +import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; +import { GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; export class DynArrayGen extends StringIndexedFuncGen { - gen(valueCairoType: CairoType): [data: string, len: string] { - const key = valueCairoType.fullStringRepresentation; - const existing = this.generatedFunctions.get(key); + public genLength( + node: MemberAccess, + arrayType: ArrayType | BytesType | StringType, + ): FunctionCall { + const [_dynArray, dynArrayLength] = this.getOrCreateFuncDef(getElementType(arrayType)); + return createCallToFunction(dynArrayLength, [node.vExpression], this.ast); + } + + public getOrCreateFuncDef(type: TypeNode): [CairoFunctionDefinition, CairoFunctionDefinition] { + const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.StorageAllocation); + + const key = cairoType.fullStringRepresentation; + const lenghtKey = key + '_LENGTH'; + const existing = this.generatedFunctionsDef.get(key); if (existing !== undefined) { - return [existing.name, `${existing.name}_LENGTH`]; + const exsitingLength = this.generatedFunctionsDef.get(lenghtKey); + assert(exsitingLength !== undefined); + return [existing, exsitingLength]; } - const mappingName = `WARP_DARRAY${this.generatedFunctions.size}_${valueCairoType.typeName}`; - this.generatedFunctions.set(key, { + const [arrayInfo, lengthInfo] = this.getOrCreate(cairoType); + + const dynArray = createCairoGeneratedFunction( + arrayInfo, + [ + ['name', createUintNTypeName(248, this.ast)], + ['index', createUint256TypeName(this.ast)], + ], + [['res_loc', createUintNTypeName(248, this.ast)]], + this.ast, + this.sourceUnit, + { + mutability: FunctionStateMutability.View, + stubKind: FunctionStubKind.StorageDefStub, + }, + ); + const dynArrayLength = createCairoGeneratedFunction( + lengthInfo, + [['name', createUintNTypeName(248, this.ast)]], + [['length', createUint256TypeName(this.ast)]], + this.ast, + this.sourceUnit, + { + mutability: FunctionStateMutability.View, + stubKind: FunctionStubKind.StorageDefStub, + }, + ); + + this.generatedFunctionsDef.set(key, dynArray); + this.generatedFunctionsDef.set(lenghtKey, dynArrayLength); + return [dynArray, dynArrayLength]; + } + + private getOrCreate(valueCairoType: CairoType): [GeneratedFunctionInfo, GeneratedFunctionInfo] { + const mappingName = `WARP_DARRAY${this.generatedFunctionsDef.size}_${valueCairoType.typeName}`; + const funcInfo: GeneratedFunctionInfo = { name: mappingName, code: [ `@storage_var`, - `func ${mappingName}(name: felt, index: Uint256) -> (resLoc : felt){`, + `func ${mappingName}(name: felt, index: Uint256) -> (res_loc : felt){`, `}`, + ].join('\n'), + functionsCalled: [], + }; + + const lengthFuncInfo: GeneratedFunctionInfo = { + name: `${mappingName}_LENGTH`, + code: [ `@storage_var`, - `func ${mappingName}_LENGTH(name: felt) -> (index: Uint256){`, + `func ${mappingName}_LENGTH(name: felt) -> (length: Uint256){`, `}`, - ...getDumpFunctions(mappingName), ].join('\n'), - }); - return [mappingName, `${mappingName}_LENGTH`]; + functionsCalled: [], + }; + return [funcInfo, lengthFuncInfo]; } } - -function getDumpFunctions(mappingName: string): string[] { - return INCLUDE_CAIRO_DUMP_FUNCTIONS - ? [ - `func DUMP_${mappingName}_ITER{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(name: felt, length : felt, ptr: felt*){`, - ` alloc_locals;`, - ` if (length == 0){`, - ` return ();`, - ` }`, - ` let index = length - 1;`, - ` let (read) = ${mappingName}.read(name, Uint256(index, 0));`, - ` assert ptr[index] = read;`, - ` DUMP_${mappingName}_ITER(name, index, ptr);`, - ` return ();`, - `}`, - `@external`, - `func DUMP_${mappingName}{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(name: felt, length : felt) -> (data_len : felt, data: felt*){`, - ` alloc_locals;`, - ` let (p: felt*) = alloc();`, - ` DUMP_${mappingName}_ITER(name, length, p);`, - ` return (length, p);`, - `}`, - `func DUMP_${mappingName}_LENGTH_ITER{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(length : felt, ptr: felt*){`, - ` alloc_locals;`, - ` if (length == 0){`, - ` return ();`, - ` }`, - ` let index = length - 1;`, - ` let (read) = ${mappingName}_LENGTH.read(index);`, - ` assert ptr[2*index] = read.low;`, - ` assert ptr[2*index+1] = read.high;`, - ` DUMP_${mappingName}_LENGTH_ITER(index, ptr);`, - ` return ();`, - `}`, - `@external`, - `func DUMP_${mappingName}_LENGTH{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(length : felt) -> (data_len : felt, data: felt*){`, - ` alloc_locals;`, - ` let (p: felt*) = alloc();`, - ` DUMP_${mappingName}_LENGTH_ITER(length, p);`, - ` return (length*2, p);`, - `}`, - ] - : []; -} diff --git a/src/cairoUtilFuncGen/storage/dynArrayIndexAccess.ts b/src/cairoUtilFuncGen/storage/dynArrayIndexAccess.ts index a439b3508..b14d66624 100644 --- a/src/cairoUtilFuncGen/storage/dynArrayIndexAccess.ts +++ b/src/cairoUtilFuncGen/storage/dynArrayIndexAccess.ts @@ -1,6 +1,5 @@ import assert from 'assert'; import { - ASTNode, DataLocation, FunctionCall, IndexAccess, @@ -10,19 +9,19 @@ import { } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { createUint256TypeName } from '../../utils/nodeTemplates'; import { isDynamicArray, safeGetNodeType } from '../../utils/nodeTypeProcessing'; import { typeNameFromTypeNode } from '../../utils/utils'; -import { StringIndexedFuncGen } from '../base'; +import { GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; import { DynArrayGen } from './dynArray'; export class DynArrayIndexAccessGen extends StringIndexedFuncGen { - constructor(private dynArrayGen: DynArrayGen, ast: AST, sourceUnit: SourceUnit) { + public constructor(private dynArrayGen: DynArrayGen, ast: AST, sourceUnit: SourceUnit) { super(ast, sourceUnit); } - gen(node: IndexAccess, nodeInSourceUnit?: ASTNode): FunctionCall { + public gen(node: IndexAccess): FunctionCall { const base = node.vBaseExpression; const index = node.vIndexExpression; assert(index !== undefined); @@ -31,38 +30,45 @@ export class DynArrayIndexAccessGen extends StringIndexedFuncGen { const baseType = safeGetNodeType(base, this.ast.inference); assert(baseType instanceof PointerType && isDynamicArray(baseType.to)); - const name = this.getOrCreate(nodeType); - const functionStub = createCairoFunctionStub( - name, + const funcDef = this.getOrCreateFuncDef(nodeType, baseType); + return createCallToFunction(funcDef, [base, index], this.ast); + } + + public getOrCreateFuncDef(nodeType: TypeNode, baseType: TypeNode) { + const nodeCairoType = CairoType.fromSol( + nodeType, + this.ast, + TypeConversionContext.StorageAllocation, + ); + + const key = nodeCairoType.fullStringRepresentation; + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; + } + + const funcInfo = this.getOrCreate(nodeType, nodeCairoType); + const funcDef = createCairoGeneratedFunction( + funcInfo, [ ['loc', typeNameFromTypeNode(baseType, this.ast), DataLocation.Storage], ['offset', createUint256TypeName(this.ast)], ], - [['resLoc', typeNameFromTypeNode(nodeType, this.ast), DataLocation.Storage]], - ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr'], + [['res_loc', typeNameFromTypeNode(nodeType, this.ast), DataLocation.Storage]], this.ast, - nodeInSourceUnit ?? node, + this.sourceUnit, ); - - return createCallToFunction(functionStub, [base, index], this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - getOrCreate(valueType: TypeNode): string { - const valueCairoType = CairoType.fromSol( - valueType, - this.ast, - TypeConversionContext.StorageAllocation, - ); - const key = valueCairoType.fullStringRepresentation; - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - const [arrayName, lengthName] = this.dynArrayGen.gen(valueCairoType); + private getOrCreate(valueType: TypeNode, valueCairoType: CairoType): GeneratedFunctionInfo { + const [arrayDef, arrayLength] = this.dynArrayGen.getOrCreateFuncDef(valueType); + const arrayName = arrayDef.name; + const lengthName = arrayLength.name; const funcName = `${arrayName}_IDX`; - this.generatedFunctions.set(key, { + return { name: funcName, code: [ `func ${funcName}{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}(ref: felt, index: Uint256) -> (res: felt){`, @@ -81,9 +87,12 @@ export class DynArrayIndexAccessGen extends StringIndexedFuncGen { ` }`, `}`, ].join('\n'), - }); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_lt'); - return funcName; + functionsCalled: [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_lt'), + arrayDef, + arrayLength, + ], + }; } } diff --git a/src/cairoUtilFuncGen/storage/dynArrayLength.ts b/src/cairoUtilFuncGen/storage/dynArrayLength.ts deleted file mode 100644 index 5a55f18b5..000000000 --- a/src/cairoUtilFuncGen/storage/dynArrayLength.ts +++ /dev/null @@ -1,53 +0,0 @@ -import { - MemberAccess, - ArrayType, - FunctionCall, - ASTNode, - DataLocation, - SourceUnit, - BytesType, - StringType, -} from 'solc-typed-ast'; -import { AST } from '../../ast/ast'; -import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; -import { createUint256TypeName } from '../../utils/nodeTemplates'; -import { getElementType } from '../../utils/nodeTypeProcessing'; -import { typeNameFromTypeNode } from '../../utils/utils'; -import { CairoUtilFuncGenBase } from '../base'; -import { DynArrayGen } from './dynArray'; - -export class DynArrayLengthGen extends CairoUtilFuncGenBase { - constructor(private dynArrayGen: DynArrayGen, ast: AST, sourceUnit: SourceUnit) { - super(ast, sourceUnit); - } - - getGeneratedCode(): string { - return ''; - } - - gen( - node: MemberAccess, - arrayType: ArrayType | BytesType | StringType, - nodeInSourceUnit?: ASTNode, - ): FunctionCall { - const lengthName = this.dynArrayGen.gen( - CairoType.fromSol( - getElementType(arrayType), - this.ast, - TypeConversionContext.StorageAllocation, - ), - )[1]; - - const functionStub = createCairoFunctionStub( - `${lengthName}.read`, - [['name', typeNameFromTypeNode(arrayType, this.ast), DataLocation.Storage]], - [['len', createUint256TypeName(this.ast)]], - ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr'], - this.ast, - nodeInSourceUnit ?? node, - ); - - return createCallToFunction(functionStub, [node.vExpression], this.ast); - } -} diff --git a/src/cairoUtilFuncGen/storage/dynArrayPop.ts b/src/cairoUtilFuncGen/storage/dynArrayPop.ts index 0efd88da4..6651e2793 100644 --- a/src/cairoUtilFuncGen/storage/dynArrayPop.ts +++ b/src/cairoUtilFuncGen/storage/dynArrayPop.ts @@ -1,7 +1,6 @@ import assert from 'assert'; import { ArrayType, - ASTNode, BytesType, DataLocation, FunctionCall, @@ -12,8 +11,9 @@ import { TypeNode, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoFunctionDefinition } from '../../export'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { getElementType, isDynamicArray, @@ -21,7 +21,7 @@ import { safeGetNodeType, } from '../../utils/nodeTypeProcessing'; import { typeNameFromTypeNode } from '../../utils/utils'; -import { StringIndexedFuncGen } from '../base'; +import { GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; import { DynArrayGen } from './dynArray'; import { StorageDeleteGen } from './storageDelete'; @@ -35,7 +35,7 @@ export class DynArrayPopGen extends StringIndexedFuncGen { super(ast, sourceUnit); } - gen(pop: FunctionCall, nodeInSourceUnit?: ASTNode): FunctionCall { + public gen(pop: FunctionCall): FunctionCall { assert(pop.vExpression instanceof MemberAccess); const arrayType = generalizeType( safeGetNodeType(pop.vExpression.vExpression, this.ast.inference), @@ -46,35 +46,44 @@ export class DynArrayPopGen extends StringIndexedFuncGen { arrayType instanceof StringType, ); - const name = this.getOrCreate(getElementType(arrayType)); - - const functionStub = createCairoFunctionStub( - name, - [['loc', typeNameFromTypeNode(arrayType, this.ast), DataLocation.Storage]], - [], - ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr'], - this.ast, - nodeInSourceUnit ?? pop, - ); - - return createCallToFunction(functionStub, [pop.vExpression.vExpression], this.ast); + const funcDef = this.getOrCreateFuncDef(arrayType); + return createCallToFunction(funcDef, [pop.vExpression.vExpression], this.ast); } - private getOrCreate(elementType: TypeNode): string { + public getOrCreateFuncDef( + arrayType: ArrayType | BytesType | StringType, + ): CairoFunctionDefinition { + const elementT = getElementType(arrayType); const cairoElementType = CairoType.fromSol( - elementType, + elementT, this.ast, TypeConversionContext.StorageAllocation, ); const key = cairoElementType.fullStringRepresentation; - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; } - const [arrayName, lengthName] = this.dynArrayGen.gen(cairoElementType); - const deleteFuncName = this.storageDelete.genFuncName(elementType); + const funcInfo = this.getOrCreate(elementT); + const funcDef = createCairoGeneratedFunction( + funcInfo, + [['loc', typeNameFromTypeNode(arrayType, this.ast), DataLocation.Storage]], + [], + this.ast, + this.sourceUnit, + ); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; + } + + private getOrCreate(elementType: TypeNode): GeneratedFunctionInfo { + const deleteFunc = this.storageDelete.getOrCreateFuncDef(elementType); + const [dynArray, dynArrayLength] = this.dynArrayGen.getOrCreateFuncDef(elementType); + + const arrayName = dynArray.name; + const lengthName = dynArrayLength.name; const getElemLoc = isDynamicArray(elementType) || isMapping(elementType) @@ -85,7 +94,7 @@ export class DynArrayPopGen extends StringIndexedFuncGen { : `let (elem_loc) = ${arrayName}.read(loc, newLen);`; const funcName = `${arrayName}_POP`; - this.generatedFunctions.set(key, { + return { name: funcName, code: [ `func ${funcName}{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}(loc: felt) -> (){`, @@ -96,13 +105,17 @@ export class DynArrayPopGen extends StringIndexedFuncGen { ` let (newLen) = uint256_sub(len, Uint256(1,0));`, ` ${lengthName}.write(loc, newLen);`, ` ${getElemLoc}`, - ` return ${deleteFuncName}(elem_loc);`, + ` return ${deleteFunc.name}(elem_loc);`, `}`, ].join('\n'), - }); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_eq'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_sub'); - return funcName; + functionsCalled: [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_eq'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_sub'), + deleteFunc, + dynArray, + dynArrayLength, + ], + }; } } diff --git a/src/cairoUtilFuncGen/storage/dynArrayPushWithArg.ts b/src/cairoUtilFuncGen/storage/dynArrayPushWithArg.ts index ab6b4cd60..659f991d5 100644 --- a/src/cairoUtilFuncGen/storage/dynArrayPushWithArg.ts +++ b/src/cairoUtilFuncGen/storage/dynArrayPushWithArg.ts @@ -2,7 +2,6 @@ import assert from 'assert'; import { ArrayType, MappingType, - ASTNode, BytesType, DataLocation, FunctionCall, @@ -11,18 +10,18 @@ import { SourceUnit, StringType, TypeNode, + FunctionDefinition, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { typeNameFromTypeNode } from '../../utils/utils'; -import { StringIndexedFuncGen } from '../base'; +import { GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; import { MemoryToStorageGen } from '../memory/memoryToStorage'; import { DynArrayGen } from './dynArray'; import { StorageWriteGen } from './storageWrite'; import { StorageToStorageGen } from './copyToStorage'; import { CalldataToStorageGen } from '../calldata/calldataToStorage'; -import { Implicits } from '../../utils/implicits'; import { getElementType, isDynamicArray, @@ -32,7 +31,7 @@ import { import { ImplicitArrayConversion } from '../calldata/implicitArrayConversion'; export class DynArrayPushWithArgGen extends StringIndexedFuncGen { - constructor( + public constructor( private dynArrayGen: DynArrayGen, private storageWrite: StorageWriteGen, private memoryToStorage: MemoryToStorageGen, @@ -45,7 +44,7 @@ export class DynArrayPushWithArgGen extends StringIndexedFuncGen { super(ast, sourceUnit); } - gen(push: FunctionCall, nodeInSourceUnit?: ASTNode): FunctionCall { + public gen(push: FunctionCall): FunctionCall { assert(push.vExpression instanceof MemberAccess); const arrayType = generalizeType( safeGetNodeType(push.vExpression.vExpression, this.ast.inference), @@ -64,58 +63,66 @@ export class DynArrayPushWithArgGen extends StringIndexedFuncGen { safeGetNodeType(push.vArguments[0], this.ast.inference), ); - const name = this.getOrCreate( + const funcDef = this.getOrCreateFuncDef(arrayType, argType, argLoc); + return createCallToFunction( + funcDef, + [push.vExpression.vExpression, push.vArguments[0]], + this.ast, + ); + } + + public getOrCreateFuncDef( + arrayType: TypeNode, + argType: TypeNode, + argLoc: DataLocation | undefined, + ) { + const key = `dynArrayPushWithArg(${arrayType.pp()},${argType.pp()},${argLoc})`; + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; + } + + const funcInfo = this.getOrCreate( getElementType(arrayType), argType, argLoc ?? DataLocation.Default, ); - const implicits: Implicits[] = - argLoc === DataLocation.Memory - ? ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr', 'warp_memory'] - : ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr', 'bitwise_ptr']; - const functionStub = createCairoFunctionStub( - name, + const funcDef = createCairoGeneratedFunction( + funcInfo, [ ['loc', typeNameFromTypeNode(arrayType, this.ast), DataLocation.Storage], ['value', typeNameFromTypeNode(argType, this.ast), argLoc ?? DataLocation.Default], ], [], - implicits, - this.ast, - nodeInSourceUnit ?? push, - ); - - return createCallToFunction( - functionStub, - [push.vExpression.vExpression, push.vArguments[0]], this.ast, + this.sourceUnit, ); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - private getOrCreate(elementType: TypeNode, argType: TypeNode, argLoc: DataLocation): string { - const key = `${elementType.pp()}->${argType.pp()}`; - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - let elementWriteFunc: string; + private getOrCreate( + elementType: TypeNode, + argType: TypeNode, + argLoc: DataLocation, + ): GeneratedFunctionInfo { + let elementWriteDef: FunctionDefinition; let inputType: string; if (argLoc === DataLocation.Memory) { - elementWriteFunc = this.memoryToStorage.getOrCreate(elementType); + elementWriteDef = this.memoryToStorage.getOrCreateFuncDef(elementType); inputType = 'felt'; } else if (argLoc === DataLocation.Storage) { - elementWriteFunc = this.storageToStorage.getOrCreate(elementType, argType); + elementWriteDef = this.storageToStorage.getOrCreateFuncDef(elementType, argType); inputType = 'felt'; } else if (argLoc === DataLocation.CallData) { if (elementType.pp() !== argType.pp()) { - elementWriteFunc = this.calldataToStorageConversion.getOrCreate( + elementWriteDef = this.calldataToStorageConversion.getOrCreateFuncDef( specializeType(elementType, DataLocation.Storage), specializeType(argType, DataLocation.CallData), ); } else { - elementWriteFunc = this.calldataToStorage.getOrCreate(elementType); + elementWriteDef = this.calldataToStorage.getOrCreateFuncDef(elementType, argType); } inputType = CairoType.fromSol( argType, @@ -123,16 +130,19 @@ export class DynArrayPushWithArgGen extends StringIndexedFuncGen { TypeConversionContext.CallDataRef, ).toString(); } else { - elementWriteFunc = this.storageWrite.getOrCreate(elementType); + elementWriteDef = this.storageWrite.getOrCreateFuncDef(elementType); inputType = CairoType.fromSol(elementType, this.ast).toString(); } + const allocationCairoType = CairoType.fromSol( elementType, this.ast, TypeConversionContext.StorageAllocation, ); - const [arrayName, lengthName] = this.dynArrayGen.gen(allocationCairoType); - const funcName = `${arrayName}_PUSHV${this.generatedFunctions.size}`; + const [dynArray, dynArrayLength] = this.dynArrayGen.getOrCreateFuncDef(elementType); + const arrayName = dynArray.name; + const lengthName = dynArrayLength.name; + const funcName = `${arrayName}_PUSHV${this.generatedFunctionsDef.size}`; const implicits = argLoc === DataLocation.Memory ? '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory: DictAccess*}' @@ -140,10 +150,10 @@ export class DynArrayPushWithArgGen extends StringIndexedFuncGen { const callWriteFunc = (cairoVar: string) => isDynamicArray(argType) || argType instanceof MappingType - ? [`let (elem_id) = readId(${cairoVar});`, `${elementWriteFunc}(elem_id, value);`] - : [`${elementWriteFunc}(${cairoVar}, value);`]; + ? [`let (elem_id) = readId(${cairoVar});`, `${elementWriteDef.name}(elem_id, value);`] + : [`${elementWriteDef.name}(${cairoVar}, value);`]; - this.generatedFunctions.set(key, { + return { name: funcName, code: [ `func ${funcName}${implicits}(loc: felt, value: ${inputType}) -> (){`, @@ -164,9 +174,13 @@ export class DynArrayPushWithArgGen extends StringIndexedFuncGen { ` return ();`, `}`, ].join('\n'), - }); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_add'); - return funcName; + functionsCalled: [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_add'), + elementWriteDef, + dynArray, + dynArrayLength, + ], + }; } } diff --git a/src/cairoUtilFuncGen/storage/dynArrayPushWithoutArg.ts b/src/cairoUtilFuncGen/storage/dynArrayPushWithoutArg.ts index 4a71c64f2..4473538a0 100644 --- a/src/cairoUtilFuncGen/storage/dynArrayPushWithoutArg.ts +++ b/src/cairoUtilFuncGen/storage/dynArrayPushWithoutArg.ts @@ -1,11 +1,21 @@ import assert from 'assert'; -import { ASTNode, DataLocation, FunctionCall, MemberAccess, SourceUnit } from 'solc-typed-ast'; +import { + ArrayType, + BytesType, + DataLocation, + FunctionCall, + generalizeType, + MemberAccess, + SourceUnit, + TypeNode, +} from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { printTypeNode } from '../../export'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; -import { safeGetNodeType } from '../../utils/nodeTypeProcessing'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; +import { getElementType, safeGetNodeType } from '../../utils/nodeTypeProcessing'; import { typeNameFromTypeNode } from '../../utils/utils'; -import { StringIndexedFuncGen } from '../base'; +import { GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; import { DynArrayGen } from './dynArray'; export class DynArrayPushWithoutArgGen extends StringIndexedFuncGen { @@ -13,37 +23,52 @@ export class DynArrayPushWithoutArgGen extends StringIndexedFuncGen { super(ast, sourceUnit); } - gen(push: FunctionCall, nodeInSourceUnit?: ASTNode): FunctionCall { + gen(push: FunctionCall): FunctionCall { assert(push.vExpression instanceof MemberAccess); - const arrayType = safeGetNodeType(push.vExpression.vExpression, this.ast.inference); - const elementType = safeGetNodeType(push, this.ast.inference); + const arrayType = generalizeType( + safeGetNodeType(push.vExpression.vExpression, this.ast.inference), + )[0]; + assert( + arrayType instanceof ArrayType || arrayType instanceof BytesType, + `Pushing without args to a non array: ${printTypeNode(arrayType)}`, + ); + const funcDef = this.getOrCreateFuncDef(arrayType); + + return createCallToFunction(funcDef, [push.vExpression.vExpression], this.ast); + } - const name = this.getOrCreate( - CairoType.fromSol(elementType, this.ast, TypeConversionContext.StorageAllocation), + getOrCreateFuncDef(arrayType: ArrayType | BytesType) { + const elementType = getElementType(arrayType); + const cairoElementType = CairoType.fromSol( + elementType, + this.ast, + TypeConversionContext.StorageAllocation, ); - const functionStub = createCairoFunctionStub( - name, + const key = elementType.pp(); //cairoElementType.fullStringRepresentation; + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; + } + + const funcInfo = this.getOrCreate(elementType, cairoElementType); + const funcDef = createCairoGeneratedFunction( + funcInfo, [['loc', typeNameFromTypeNode(arrayType, this.ast), DataLocation.Storage]], - [['newElemLoc', typeNameFromTypeNode(elementType, this.ast), DataLocation.Storage]], - ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr'], + [['new_elem_loc', typeNameFromTypeNode(elementType, this.ast), DataLocation.Storage]], this.ast, - nodeInSourceUnit ?? push, + this.sourceUnit, ); - - return createCallToFunction(functionStub, [push.vExpression.vExpression], this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - private getOrCreate(elementType: CairoType): string { - const key = elementType.fullStringRepresentation; - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - const [arrayName, lengthName] = this.dynArrayGen.gen(elementType); + private getOrCreate(elementType: TypeNode, cairoElementType: CairoType): GeneratedFunctionInfo { + const [dynArray, dynArrayLength] = this.dynArrayGen.getOrCreateFuncDef(elementType); + const arrayName = dynArray.name; + const lengthName = dynArrayLength.name; const funcName = `${arrayName}_PUSH`; - this.generatedFunctions.set(key, { + return { name: funcName, code: [ `func ${funcName}{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}(loc: felt) -> (newElemLoc: felt){`, @@ -55,7 +80,7 @@ export class DynArrayPushWithoutArgGen extends StringIndexedFuncGen { ` let (existing) = ${arrayName}.read(loc, len);`, ` if ((existing) == 0){`, ` let (used) = WARP_USED_STORAGE.read();`, - ` WARP_USED_STORAGE.write(used + ${elementType.width});`, + ` WARP_USED_STORAGE.write(used + ${cairoElementType.width});`, ` ${arrayName}.write(loc, len, used);`, ` return (used,);`, ` }else{`, @@ -63,9 +88,12 @@ export class DynArrayPushWithoutArgGen extends StringIndexedFuncGen { ` }`, `}`, ].join('\n'), - }); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_add'); - return funcName; + functionsCalled: [ + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_add'), + dynArray, + dynArrayLength, + ], + }; } } diff --git a/src/cairoUtilFuncGen/storage/export.ts b/src/cairoUtilFuncGen/storage/export.ts index c3624e893..3dea6dfff 100644 --- a/src/cairoUtilFuncGen/storage/export.ts +++ b/src/cairoUtilFuncGen/storage/export.ts @@ -9,7 +9,6 @@ export * from './storageToMemory'; export * from './storageRead'; export * from './dynArray'; export * from './dynArrayIndexAccess'; -export * from './dynArrayLength'; export * from './dynArrayPop'; export * from './storageWrite'; export * from './storageDelete'; diff --git a/src/cairoUtilFuncGen/storage/mappingIndexAccess.ts b/src/cairoUtilFuncGen/storage/mappingIndexAccess.ts index 635ef1415..cb407fde8 100644 --- a/src/cairoUtilFuncGen/storage/mappingIndexAccess.ts +++ b/src/cairoUtilFuncGen/storage/mappingIndexAccess.ts @@ -1,65 +1,86 @@ import assert from 'assert'; import { - ASTNode, DataLocation, - Expression, FunctionCall, generalizeType, IndexAccess, MappingType, PointerType, SourceUnit, + TypeNode, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoFunctionDefinition, FunctionStubKind } from '../../export'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; -import { createUint8TypeName } from '../../utils/nodeTemplates'; -import { isReferenceType, safeGetNodeType } from '../../utils/nodeTypeProcessing'; +import { + createCairoGeneratedFunction, + createCallToFunction, + ParameterInfo, +} from '../../utils/functionGeneration'; +import { createUint8TypeName, createUintNTypeName } from '../../utils/nodeTemplates'; +import { + getElementType, + isDynamicArray, + isReferenceType, + safeGetNodeType, +} from '../../utils/nodeTypeProcessing'; import { typeNameFromTypeNode } from '../../utils/utils'; -import { locationIfComplexType, StringIndexedFuncGen } from '../base'; +import { CairoUtilFuncGenBase, GeneratedFunctionInfo, locationIfComplexType } from '../base'; import { DynArrayGen } from './dynArray'; -export class MappingIndexAccessGen extends StringIndexedFuncGen { - private generatedHashFunctionNumber = 0; - +export class MappingIndexAccessGen extends CairoUtilFuncGenBase { + private indexAccesFunctions = new Map(); + private stringHashFunctions = new Map(); constructor(private dynArrayGen: DynArrayGen, ast: AST, sourceUnit: SourceUnit) { super(ast, sourceUnit); } - gen(node: IndexAccess, nodeInSourceUnit?: ASTNode): FunctionCall { + public gen(node: IndexAccess): FunctionCall { const base = node.vBaseExpression; let index = node.vIndexExpression; assert(index !== undefined); const nodeType = safeGetNodeType(node, this.ast.inference); const baseType = safeGetNodeType(base, this.ast.inference); - assert(baseType instanceof PointerType && baseType.to instanceof MappingType); - const indexCairoType = CairoType.fromSol(baseType.to.keyType, this.ast); - const valueCairoType = CairoType.fromSol( - nodeType, - this.ast, - TypeConversionContext.StorageAllocation, - ); - if (isReferenceType(baseType.to.keyType)) { - const stringLoc = generalizeType(safeGetNodeType(index, this.ast.inference))[1]; + const [stringType, stringLoc] = generalizeType(safeGetNodeType(index, this.ast.inference)); assert(stringLoc !== undefined); - const call = this.createStringHashFunction(node, stringLoc, indexCairoType); - index = call; + const stringHashFunc = this.getOrCreateStringHashFunction(stringType, stringLoc); + index = createCallToFunction(stringHashFunc, [index], this.ast, this.sourceUnit); } - const name = this.getOrCreate(indexCairoType, valueCairoType); + const funcDef = this.getOrCreateIndexAccessFunction(baseType.to.keyType, nodeType); + return createCallToFunction(funcDef, [base, index], this.ast); + } - const functionStub = createCairoFunctionStub( - name, + public getOrCreateIndexAccessFunction(indexType: TypeNode, nodeType: TypeNode) { + const indexKey = CairoType.fromSol( + indexType, + this.ast, + TypeConversionContext.StorageAllocation, + ).fullStringRepresentation; + const nodeKey = CairoType.fromSol( + nodeType, + this.ast, + TypeConversionContext.StorageAllocation, + ).fullStringRepresentation; + const key = indexKey + '-' + nodeKey; + const existing = this.indexAccesFunctions.get(key); + if (existing !== undefined) { + return existing; + } + + const funcInfo = this.generateIndexAccess(indexType, nodeType); + const funcDef = createCairoGeneratedFunction( + funcInfo, [ - ['name', typeNameFromTypeNode(baseType, this.ast), DataLocation.Storage], + ['name', typeNameFromTypeNode(indexType, this.ast), DataLocation.Storage], [ 'index', - typeNameFromTypeNode(baseType.to.keyType, this.ast), - locationIfComplexType(baseType.to.keyType, DataLocation.Memory), + typeNameFromTypeNode(indexType, this.ast), + locationIfComplexType(indexType, DataLocation.Memory), ], ], [ @@ -69,40 +90,56 @@ export class MappingIndexAccessGen extends StringIndexedFuncGen { locationIfComplexType(nodeType, DataLocation.Storage), ], ], - ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr'], this.ast, - nodeInSourceUnit ?? node, + this.sourceUnit, ); - - return createCallToFunction(functionStub, [base, index], this.ast); + this.indexAccesFunctions.set(key, funcDef); + return funcDef; } - private getOrCreate(indexType: CairoType, valueType: CairoType): string { - const key = `${indexType.fullStringRepresentation}/${valueType.fullStringRepresentation}`; - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } + private generateIndexAccess(indexType: TypeNode, valueType: TypeNode): GeneratedFunctionInfo { + const indexCairoType = CairoType.fromSol(indexType, this.ast); + const valueCairoType = CairoType.fromSol( + valueType, + this.ast, + TypeConversionContext.StorageAllocation, + ); - const funcName = `WS${this.generatedFunctions.size - this.generatedHashFunctionNumber}_INDEX_${ - indexType.typeName - }_to_${valueType.typeName}`; - const mappingName = `WARP_MAPPING${ - this.generatedFunctions.size - this.generatedHashFunctionNumber - }`; - const indexTypeString = indexType.toString(); - this.generatedFunctions.set(key, { - name: funcName, + const identifier = this.indexAccesFunctions.size; + const funcName = `WS_INDEX_${indexCairoType.typeName}_to_${valueCairoType.typeName}${identifier}`; + const mappingName = `WARP_MAPPING${identifier}`; + const indexTypeString = indexCairoType.toString(); + + const mappingFuncInfo: GeneratedFunctionInfo = { + name: mappingName, code: [ `@storage_var`, `func ${mappingName}(name: felt, index: ${indexTypeString}) -> (resLoc : felt){`, `}`, + ].join('\n'), + functionsCalled: [], + }; + const mappingFunc = createCairoGeneratedFunction( + mappingFuncInfo, + [ + ['name', createUintNTypeName(248, this.ast)], + ['index', typeNameFromTypeNode(indexType, this.ast)], + ], + [], + this.ast, + this.sourceUnit, + { stubKind: FunctionStubKind.StorageDefStub }, + ); + + return { + name: funcName, + code: [ `func ${funcName}{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}(name: felt, index: ${indexTypeString}) -> (res: felt){`, ` alloc_locals;`, ` let (existing) = ${mappingName}.read(name, index);`, ` if (existing == 0){`, ` let (used) = WARP_USED_STORAGE.read();`, - ` WARP_USED_STORAGE.write(used + ${valueType.width});`, + ` WARP_USED_STORAGE.write(used + ${valueCairoType.width});`, ` ${mappingName}.write(name, index, used);`, ` return (used,);`, ` }else{`, @@ -110,100 +147,112 @@ export class MappingIndexAccessGen extends StringIndexedFuncGen { ` }`, `}`, ].join('\n'), - }); - return funcName; + functionsCalled: [mappingFunc], + }; } - private createStringHashFunction( - node: IndexAccess, - loc: DataLocation, - indexCairoType: CairoType, - ): FunctionCall { - assert(node.vIndexExpression instanceof Expression); - const indexType = safeGetNodeType(node.vIndexExpression, this.ast.inference); + public getOrCreateStringHashFunction( + indexType: TypeNode, + dataLocation: DataLocation, + ): CairoFunctionDefinition { + assert(dataLocation !== DataLocation.Default); + + const key = indexType.pp() + dataLocation; + const existing = this.stringHashFunctions.get(key); + if (existing !== undefined) { + return existing; + } + const indexTypeName = typeNameFromTypeNode(indexType, this.ast); - if (loc === DataLocation.CallData) { - const stub = createCairoFunctionStub( + + const inputInfo: ParameterInfo[] = [['str', indexTypeName, dataLocation]]; + const outputInfo: ParameterInfo[] = [ + ['hashed_str', createUint8TypeName(this.ast), DataLocation.Default], + ]; + + if (dataLocation === DataLocation.CallData) { + const importFunction = this.ast.registerImport( + this.sourceUnit, + 'warplib.string_hash', 'string_hash', - [['str', indexTypeName, DataLocation.CallData]], - [['hashedStr', createUint8TypeName(this.ast), DataLocation.Default]], - ['pedersen_ptr'], - this.ast, - node, + inputInfo, + outputInfo, ); - const call = createCallToFunction(stub, [node.vIndexExpression], this.ast); - this.ast.registerImport(call, 'warplib.string_hash', 'string_hash'); - return call; - } else if (loc === DataLocation.Memory) { - const stub = createCairoFunctionStub( + return importFunction; + } + + if (dataLocation === DataLocation.Memory) { + const importFunction = this.ast.registerImport( + this.sourceUnit, + 'warplib.string_hash', 'wm_string_hash', - [['str', indexTypeName, DataLocation.Memory]], - [['hashedStr', createUint8TypeName(this.ast), DataLocation.Default]], - ['pedersen_ptr', 'range_check_ptr', 'warp_memory'], - this.ast, - node, - ); - const call = createCallToFunction(stub, [node.vIndexExpression], this.ast); - this.ast.registerImport(call, 'warplib.string_hash', 'wm_string_hash'); - return call; - } else { - const [data, len] = this.dynArrayGen.gen(indexCairoType); - const key = `${data}/${len}_hash`; - let funcName = `ws_string_hash${this.generatedHashFunctionNumber}`; - const helperFuncName = `ws_to_felt_array${this.generatedHashFunctionNumber}`; - - const existing = this.generatedFunctions.get(key); - if (existing === undefined) { - this.generatedFunctions.set(key, { - name: funcName, - code: [ - `func ${helperFuncName}{pedersen_ptr : HashBuiltin*, range_check_ptr, syscall_ptr : felt*}(`, - ` name : felt, ptr : felt*, len : felt`, - `){`, - ` alloc_locals;`, - ` if (len == 0){`, - ` return ();`, - ` }`, - ` let index = len - 1;`, - ` let (index256) = felt_to_uint256(index);`, - ` let (loc) = ${data}.read(name, index256);`, - ` let (value) = WARP_STORAGE.read(loc);`, - ` assert ptr[index] = value;`, - ` ${helperFuncName}(name, ptr, index);`, - ` return ();`, - `}`, - `func ${funcName}{pedersen_ptr : HashBuiltin*, range_check_ptr, syscall_ptr : felt*}(`, - ` name : felt`, - `) -> (hashedValue : felt){`, - ` alloc_locals;`, - ` let (len256) = ${len}.read(name);`, - ` let (len) = narrow_safe(len256);`, - ` let (ptr) = alloc();`, - ` ${helperFuncName}(name, ptr, len);`, - ` let (hashValue) = string_hash(len, ptr);`, - ` return (hashValue,);`, - `}`, - ].join('\n'), - }); - this.generatedHashFunctionNumber++; - } else { - funcName = existing.name; - } - const stub = createCairoFunctionStub( - funcName, - [['name', indexTypeName, DataLocation.Storage]], - [['hashedStr', createUint8TypeName(this.ast), DataLocation.Default]], - ['pedersen_ptr', 'range_check_ptr', 'syscall_ptr'], - this.ast, - node, + inputInfo, + outputInfo, ); - - const call = createCallToFunction(stub, [node.vIndexExpression], this.ast); - this.ast.registerImport(call, 'warplib.maths.utils', 'narrow_safe'); - this.ast.registerImport(call, 'warplib.maths.utils', 'felt_to_uint256'); - this.ast.registerImport(call, 'starkware.cairo.common.alloc', 'alloc'); - this.ast.registerImport(call, 'warplib.string_hash', 'string_hash'); - return call; + return importFunction; } + + // Datalocation is storage + const funcInfo = this.generateStringHashFunction(indexType); + const genFunc = createCairoGeneratedFunction( + funcInfo, + inputInfo, + outputInfo, + this.ast, + this.sourceUnit, + ); + this.stringHashFunctions.set(key, genFunc); + return genFunc; + } + + private generateStringHashFunction(indexType: TypeNode): GeneratedFunctionInfo { + assert(isDynamicArray(indexType)); + const elemenT = getElementType(indexType); + + const [dynArray, dynArrayLen] = this.dynArrayGen.getOrCreateFuncDef(elemenT); + const arrayName = dynArray.name; + const lenName = dynArrayLen.name; + + const funcName = `WS_STRING_HASH${this.stringHashFunctions.size}`; + const helperFuncName = `WS_TO_FELT_ARRAY${this.stringHashFunctions.size}`; + return { + name: funcName, + code: [ + `func ${helperFuncName}{pedersen_ptr : HashBuiltin*, range_check_ptr, syscall_ptr : felt*}(`, + ` name : felt, ptr : felt*, len : felt`, + `){`, + ` alloc_locals;`, + ` if (len == 0){`, + ` return ();`, + ` }`, + ` let index = len - 1;`, + ` let (index256) = felt_to_uint256(index);`, + ` let (loc) = ${arrayName}.read(name, index256);`, + ` let (value) = WARP_STORAGE.read(loc);`, + ` assert ptr[index] = value;`, + ` ${helperFuncName}(name, ptr, index);`, + ` return ();`, + `}`, + `func ${funcName}{pedersen_ptr : HashBuiltin*, range_check_ptr, syscall_ptr : felt*}(`, + ` name : felt`, + `) -> (hashedValue : felt){`, + ` alloc_locals;`, + ` let (len256) = ${lenName}.read(name);`, + ` let (len) = narrow_safe(len256);`, + ` let (ptr) = alloc();`, + ` ${helperFuncName}(name, ptr, len);`, + ` let (hashValue) = string_hash(len, ptr);`, + ` return (hashValue,);`, + `}`, + ].join('\n'), + functionsCalled: [ + this.requireImport('warplib.maths.utils', 'narrow_safe'), + this.requireImport('warplib.maths.utils', 'felt_to_uint256'), + this.requireImport('starkware.cairo.common.alloc', 'alloc'), + this.requireImport('warplib.string_hash', 'string_hash'), + dynArray, + dynArrayLen, + ], + }; } } diff --git a/src/cairoUtilFuncGen/storage/staticArrayIndexAccess.ts b/src/cairoUtilFuncGen/storage/staticArrayIndexAccess.ts index c06751dbd..6d027164e 100644 --- a/src/cairoUtilFuncGen/storage/staticArrayIndexAccess.ts +++ b/src/cairoUtilFuncGen/storage/staticArrayIndexAccess.ts @@ -1,56 +1,34 @@ import assert from 'assert'; import { ArrayType, - ASTNode, DataLocation, FunctionCall, IndexAccess, PointerType, + TypeNode, } from 'solc-typed-ast'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { createNumberLiteral, createUint256TypeName } from '../../utils/nodeTemplates'; import { safeGetNodeType } from '../../utils/nodeTypeProcessing'; import { typeNameFromTypeNode } from '../../utils/utils'; -import { CairoUtilFuncGenBase } from '../base'; +import { GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; -export class StorageStaticArrayIndexAccessGen extends CairoUtilFuncGenBase { - private generatedFunction: string | null = null; - - getGeneratedCode(): string { - return this.generatedFunction ?? ''; - } - - gen(node: IndexAccess, nodeInSourceUnit?: ASTNode): FunctionCall { +export class StorageStaticArrayIndexAccessGen extends StringIndexedFuncGen { + public gen(node: IndexAccess): FunctionCall { assert(node.vIndexExpression !== undefined); - const name = this.getOrCreate(); - const arrayType = safeGetNodeType(node.vBaseExpression, this.ast.inference); assert( arrayType instanceof PointerType && arrayType.to instanceof ArrayType && arrayType.to.size !== undefined, ); - const valueType = safeGetNodeType(node, this.ast.inference); - const functionStub = createCairoFunctionStub( - name, - [ - ['loc', typeNameFromTypeNode(arrayType, this.ast), DataLocation.Storage], - ['index', createUint256TypeName(this.ast)], - ['size', createUint256TypeName(this.ast)], - ['limit', createUint256TypeName(this.ast)], - ], - [['resLoc', typeNameFromTypeNode(valueType, this.ast), DataLocation.Storage]], - ['range_check_ptr'], - this.ast, - nodeInSourceUnit ?? node, - ); - + const funcDef = this.getOrCreateFuncDef(arrayType, valueType); return createCallToFunction( - functionStub, + funcDef, [ node.vBaseExpression, node.vIndexExpression, @@ -65,33 +43,58 @@ export class StorageStaticArrayIndexAccessGen extends CairoUtilFuncGenBase { ); } - private getOrCreate(): string { - if (this.generatedFunction === null) { - this.generatedFunction = idxCode; - this.requireImport('starkware.cairo.common.math', 'split_felt'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_add'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_le'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_lt'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_mul'); + public getOrCreateFuncDef(arrayType: TypeNode, valueType: TypeNode) { + const key = arrayType.pp() + valueType.pp(); + const existing = this.generatedFunctionsDef.get(key); + if (existing !== undefined) { + return existing; } - return 'WS0_IDX'; + + const funcInfo = this.getOrCreate(); + const funcDef = createCairoGeneratedFunction( + funcInfo, + [ + ['loc', typeNameFromTypeNode(arrayType, this.ast), DataLocation.Storage], + ['index', createUint256TypeName(this.ast)], + ['size', createUint256TypeName(this.ast)], + ['limit', createUint256TypeName(this.ast)], + ], + [['res_loc', typeNameFromTypeNode(valueType, this.ast), DataLocation.Storage]], + this.ast, + this.sourceUnit, + ); + + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } -} -const idxCode = [ - `func WS0_IDX{range_check_ptr}(loc: felt, index: Uint256, size: Uint256, limit: Uint256) -> (resLoc: felt){`, - ` alloc_locals;`, - ` let (inRange) = uint256_lt(index, limit);`, - ` assert inRange = 1;`, - ` let (locHigh, locLow) = split_felt(loc);`, - ` let (offset, overflow) = uint256_mul(index, size);`, - ` assert overflow.low = 0;`, - ` assert overflow.high = 0;`, - ` let (res256, carry) = uint256_add(Uint256(locLow, locHigh), offset);`, - ` assert carry = 0;`, - ` let (feltLimitHigh, feltLimitLow) = split_felt(-1);`, - ` let (narrowable) = uint256_le(res256, Uint256(feltLimitLow, feltLimitHigh));`, - ` assert narrowable = 1;`, - ` return (res256.low + 2**128 * res256.high,);`, - `}`, -].join('\n'); + private getOrCreate(): GeneratedFunctionInfo { + return { + name: 'WS0_IDX', + code: [ + `func WS0_IDX{range_check_ptr}(loc: felt, index: Uint256, size: Uint256, limit: Uint256) -> (resLoc: felt){`, + ` alloc_locals;`, + ` let (inRange) = uint256_lt(index, limit);`, + ` assert inRange = 1;`, + ` let (locHigh, locLow) = split_felt(loc);`, + ` let (offset, overflow) = uint256_mul(index, size);`, + ` assert overflow.low = 0;`, + ` assert overflow.high = 0;`, + ` let (res256, carry) = uint256_add(Uint256(locLow, locHigh), offset);`, + ` assert carry = 0;`, + ` let (feltLimitHigh, feltLimitLow) = split_felt(-1);`, + ` let (narrowable) = uint256_le(res256, Uint256(feltLimitLow, feltLimitHigh));`, + ` assert narrowable = 1;`, + ` return (res256.low + 2**128 * res256.high,);`, + `}`, + ].join('\n'), + functionsCalled: [ + this.requireImport('starkware.cairo.common.math', 'split_felt'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_add'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_le'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_lt'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_mul'), + ], + }; + } +} diff --git a/src/cairoUtilFuncGen/storage/storageDelete.ts b/src/cairoUtilFuncGen/storage/storageDelete.ts index 25e7b54d6..3b3eece5e 100644 --- a/src/cairoUtilFuncGen/storage/storageDelete.ts +++ b/src/cairoUtilFuncGen/storage/storageDelete.ts @@ -1,32 +1,45 @@ import assert from 'assert'; import { ArrayType, - ASTNode, BytesType, DataLocation, Expression, FunctionCall, generalizeType, MappingType, - PointerType, SourceUnit, StringType, StructDefinition, TypeNode, + UserDefinedType, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoGeneratedFunctionDefinition } from '../../ast/cairoNodes'; +import { CairoFunctionDefinition } from '../../export'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; -import { TranspileFailedError } from '../../utils/errors'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { getElementType, isDynamicArray, safeGetNodeType } from '../../utils/nodeTypeProcessing'; import { typeNameFromTypeNode, mapRange, narrowBigIntSafe } from '../../utils/utils'; import { uint256 } from '../../warplib/utils'; -import { add, CairoFunction, delegateBasedOnType, StringIndexedFuncGen } from '../base'; +import { add, delegateBasedOnType, GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; import { DynArrayGen } from './dynArray'; import { StorageReadGen } from './storageRead'; +const IMPLICITS = '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}'; + export class StorageDeleteGen extends StringIndexedFuncGen { - private nothingHandlerGen: boolean; + // Map to store functions being created to + // avoid infinite recursion when deleting + // recursive types such as: + // struct S { + // S[]; + // } + private creatingFunctions: Map; + + // Map to store unsolved function dependecies + // of generated functions + private functionDependencies: Map; + constructor( private dynArrayGen: DynArrayGen, private storageReadGen: StorageReadGen, @@ -34,118 +47,113 @@ export class StorageDeleteGen extends StringIndexedFuncGen { sourceUnit: SourceUnit, ) { super(ast, sourceUnit); - this.nothingHandlerGen = false; + this.creatingFunctions = new Map(); + this.functionDependencies = new Map(); } - gen(node: Expression, nodeInSourceUnit?: ASTNode): FunctionCall { - const nodeType = dereferenceType(safeGetNodeType(node, this.ast.inference)); + public gen(node: Expression): FunctionCall { + const nodeType = generalizeType(safeGetNodeType(node, this.ast.inference))[0]; + const funcDef = this.getOrCreateFuncDef(nodeType); + return createCallToFunction(funcDef, [node], this.ast); + } - const functionName = this.getOrCreate(nodeType); + public getOrCreateFuncDef(type: TypeNode) { + const key = generateKey(type); + const existing = this.generatedFunctionsDef.get(key); + if (existing !== undefined) { + return existing; + } - const functionStub = createCairoFunctionStub( - functionName, - [['loc', typeNameFromTypeNode(nodeType, this.ast), DataLocation.Storage]], + const funcInfo = this.getOrCreate(type); + const funcDef = createCairoGeneratedFunction( + funcInfo, + [['loc', typeNameFromTypeNode(type, this.ast), DataLocation.Storage]], [], - ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr'], this.ast, - nodeInSourceUnit ?? node, + this.sourceUnit, ); - return createCallToFunction(functionStub, [node], this.ast); - } - genFuncName(node: TypeNode): string { - return this.getOrCreate(node); - } + assert( + this.creatingFunctions.delete(key), + 'Cannot delete function which is not being processed', + ); + this.generatedFunctionsDef.set(key, funcDef); + this.processRecursiveDependencies(); - genAuxFuncName(node: TypeNode): string { - return `${this.getOrCreate(node)}_elem`; + return funcDef; } - private getOrCreate(type: TypeNode): string { - const key = type.pp(); - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; + private safeGetOrCreateFuncDef(parentType: TypeNode, type: TypeNode) { + const parentKey = generateKey(parentType); + const childKey = generateKey(type); + const dependencies = this.functionDependencies.get(parentKey); + if (dependencies === undefined) { + this.functionDependencies.set(parentKey, [childKey]); + } else { + dependencies.push(childKey); } - const cairoFuncName = delegateBasedOnType( - type, - () => `WS${this.generatedFunctions.size}_DYNAMIC_ARRAY_DELETE`, - () => `WS${this.generatedFunctions.size}_STATIC_ARRAY_DELETE`, - (_type, def) => `WS_STRUCT_${def.name}_DELETE`, - () => `WSMAP_DELETE`, - () => `WS${this.generatedFunctions.size}_DELETE`, - ); + const processingName = this.creatingFunctions.get(childKey); + if (processingName !== undefined) { + return processingName; + } - this.generatedFunctions.set(key, { - name: cairoFuncName, - get code(): string { - throw new TranspileFailedError('Tried accessing code yet to be generated'); - }, - }); + return this.getOrCreateFuncDef(type).name; + } - const cairoFunc = delegateBasedOnType( + private getOrCreate(type: TypeNode): GeneratedFunctionInfo { + const funcInfo = delegateBasedOnType( type, - (type) => this.deleteDynamicArray(type, cairoFuncName), + (type) => this.deleteDynamicArray(type), (type) => { assert(type.size !== undefined); return type.size <= 5 - ? this.deleteSmallStaticArray(type, cairoFuncName) - : this.deleteLargeStaticArray(type, cairoFuncName); + ? this.deleteSmallStaticArray(type) + : this.deleteLargeStaticArray(type); }, - (_type, def) => this.deleteStruct(def, cairoFuncName), - () => this.deleteNothing(cairoFuncName), - () => this.deleteGeneric(CairoType.fromSol(type, this.ast), cairoFuncName), + (type, def) => this.deleteStruct(type, def), + (type) => this.deleteNothing(type), + (type) => this.deleteGeneric(type), ); - - // WSMAP_DELETE can be keyed with multiple types but since its definition - // is always the same we want to make sure its not duplicated or else it - // clashes with itself. - if (cairoFunc.name === 'WSMAP_DELETE' && !this.nothingHandlerGen) { - this.nothingHandlerGen = true; - } else if (cairoFunc.name === 'WSMAP_DELETE' && this.nothingHandlerGen) { - this.generatedFunctions.set(key, { ...cairoFunc, code: '' }); - return cairoFunc.name; - } - - this.generatedFunctions.set(key, cairoFunc); - - return cairoFunc.name; + return funcInfo; } - private deleteGeneric(cairoType: CairoType, funcName: string): CairoFunction { - const implicits = '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}'; + private deleteGeneric(type: TypeNode): GeneratedFunctionInfo { + const funcName = `WS${this.getId()}_GENERIC_DELETE`; + this.creatingFunctions.set(generateKey(type), funcName); + + const cairoType = CairoType.fromSol(type, this.ast); return { name: funcName, code: [ - `func ${funcName}${implicits}(loc: felt){`, + `func ${funcName}${IMPLICITS}(loc: felt){`, ...mapRange(cairoType.width, (n) => ` WARP_STORAGE.write(${add('loc', n)}, 0);`), ` return ();`, `}`, ].join('\n'), + functionsCalled: [], }; } - private deleteDynamicArray( - type: ArrayType | BytesType | StringType, - funcName: string, - ): CairoFunction { - const implicits = '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}'; + private deleteDynamicArray(type: ArrayType | BytesType | StringType): GeneratedFunctionInfo { + const funcName = `WS${this.getId()}_DYNAMIC_ARRAY_DELETE`; + this.creatingFunctions.set(generateKey(type), funcName); - const elementT = dereferenceType(getElementType(type)); - const [arrayName, lengthName] = this.dynArrayGen.gen( - CairoType.fromSol(elementT, this.ast, TypeConversionContext.StorageAllocation), - ); + const elementT = generalizeType(getElementType(type))[0]; + + const [dynArray, dynArrayLen] = this.dynArrayGen.getOrCreateFuncDef(elementT); + const arrayName = dynArray.name; + const lengthName = dynArrayLen.name; + + const readFunc = this.storageReadGen.getOrCreateFuncDef(elementT); + const auxDeleteFuncName = this.safeGetOrCreateFuncDef(type, elementT); const deleteCode = requiresReadBeforeRecursing(elementT) - ? [ - ` let (elem_id) = ${this.storageReadGen.genFuncName(elementT)}(elem_loc);`, - ` ${this.getOrCreate(elementT)}(elem_id);`, - ] - : [` ${this.getOrCreate(elementT)}(elem_loc);`]; + ? [` let (elem_id) = ${readFunc.name}(elem_loc);`, ` ${auxDeleteFuncName}(elem_id);`] + : [` ${auxDeleteFuncName}(elem_loc);`]; const deleteFunc = [ - `func ${funcName}_elem${implicits}(loc : felt, index : Uint256, length : Uint256){`, + `func ${funcName}_elem${IMPLICITS}(loc : felt, index : Uint256, length : Uint256){`, ` alloc_locals;`, ` let (stop) = uint256_eq(index, length);`, ` if (stop == 1){`, @@ -156,7 +164,7 @@ export class StorageDeleteGen extends StringIndexedFuncGen { ` let (next_index, _) = uint256_add(index, ${uint256(1)});`, ` return ${funcName}_elem(loc, next_index, length);`, `}`, - `func ${funcName}${implicits}(loc : felt){`, + `func ${funcName}${IMPLICITS}(loc : felt){`, ` alloc_locals;`, ` let (length) = ${lengthName}.read(loc);`, ` ${lengthName}.write(loc, ${uint256(0)});`, @@ -164,53 +172,67 @@ export class StorageDeleteGen extends StringIndexedFuncGen { `}`, ].join('\n'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_eq'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_add'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - - return { name: funcName, code: deleteFunc }; + const importedFuncs = [ + this.requireImport('starkware.cairo.common.uint256', 'uint256_eq'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_add'), + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + ]; + return { + name: funcName, + code: deleteFunc, + functionsCalled: [...importedFuncs, dynArray, dynArrayLen, readFunc], + }; } - private deleteSmallStaticArray(type: ArrayType, funcName: string) { + private deleteSmallStaticArray(type: ArrayType): GeneratedFunctionInfo { + const funcName = `WS${this.getId()}_SMALL_STATIC_ARRAY_DELETE`; + this.creatingFunctions.set(generateKey(type), funcName); + assert(type.size !== undefined); - const implicits = '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}'; + const [deleteCode, funcCalls] = this.generateStaticArrayDeletionCode( + type, + type.elementT, + narrowBigIntSafe(type.size), + ); const code = [ + `func ${funcName}${IMPLICITS}(loc: felt) {`, ` alloc_locals;`, - ...this.generateStructDeletionCode( - mapRange(narrowBigIntSafe(type.size), () => type.elementT), - ), + ...deleteCode, ` return ();`, `}`, - ]; + ].join('\n'); return { name: funcName, - code: [`func ${funcName}${implicits}(loc : felt){`, ...code].join('\n'), + code: code, + functionsCalled: funcCalls, }; } - private deleteLargeStaticArray(type: ArrayType, funcName: string) { + private deleteLargeStaticArray(type: ArrayType): GeneratedFunctionInfo { assert(type.size !== undefined); + const funcName = `WS${this.getId()}_LARGE_STATIC_ARRAY_DELETE`; + this.creatingFunctions.set(generateKey(type), funcName); - const implicits = '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}'; - const elementT = dereferenceType(type.elementT); + const elementT = generalizeType(type.elementT)[0]; const elementTWidht = CairoType.fromSol( elementT, this.ast, TypeConversionContext.StorageAllocation, ).width; + const storageReadFunc = this.storageReadGen.getOrCreateFuncDef(elementT); + const auxDeleteFuncName = this.safeGetOrCreateFuncDef(type, elementT); + const deleteCode = requiresReadBeforeRecursing(elementT) - ? [ - ` let (elem_id) = ${this.storageReadGen.genFuncName(elementT)}(loc);`, - ` ${this.getOrCreate(elementT)}(elem_id);`, - ] - : [` ${this.getOrCreate(elementT)}(loc);`]; + ? [` let (elem_id) = ${storageReadFunc.name}(loc);`, ` ${auxDeleteFuncName}(elem_id);`] + : [` ${auxDeleteFuncName}(loc);`]; const length = narrowBigIntSafe(type.size); const nextLoc = add('loc', elementTWidht); + const deleteFunc = [ - `func ${funcName}_elem${implicits}(loc : felt, index : felt){`, + `func ${funcName}_elem${IMPLICITS}(loc : felt, index : felt){`, ` alloc_locals;`, ` if (index == ${length}){`, ` return ();`, @@ -219,72 +241,156 @@ export class StorageDeleteGen extends StringIndexedFuncGen { ...deleteCode, ` return ${funcName}_elem(${nextLoc}, next_index);`, `}`, - `func ${funcName}${implicits}(loc : felt){`, + `func ${funcName}${IMPLICITS}(loc : felt){`, ` alloc_locals;`, ` return ${funcName}_elem(loc, 0);`, `}`, ].join('\n'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_eq'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_sub'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); + const importedFuncs = [ + this.requireImport('starkware.cairo.common.uint256', 'uint256_eq'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_sub'), + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + ]; - return { name: funcName, code: deleteFunc }; + return { + name: funcName, + code: deleteFunc, + functionsCalled: [...importedFuncs, storageReadFunc], + }; } - private deleteStruct(structDef: StructDefinition, funcName: string): CairoFunction { - const implicits = '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}'; - // struct names are unique + private deleteStruct(type: UserDefinedType, structDef: StructDefinition): GeneratedFunctionInfo { + const funcName = `WS_STRUCT_${structDef.name}_DELETE`; + this.creatingFunctions.set(generateKey(type), funcName); + + const [deleteCode, funcCalls] = this.generateStructDeletionCode( + type, + structDef.vMembers.map((varDecl) => safeGetNodeType(varDecl, this.ast.inference)), + ); + const deleteFunc = [ - `func ${funcName}${implicits}(loc : felt){`, + `func ${funcName}${IMPLICITS}(loc : felt){`, ` alloc_locals;`, - ...this.generateStructDeletionCode( - structDef.vMembers.map((varDecl) => safeGetNodeType(varDecl, this.ast.inference)), - ), + ...deleteCode, ` return ();`, `}`, ].join('\n'); - return { name: funcName, code: deleteFunc }; + return { name: funcName, code: deleteFunc, functionsCalled: funcCalls }; } - private deleteNothing(funcName: string): CairoFunction { - const implicits = '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}'; + private deleteNothing(type: TypeNode): GeneratedFunctionInfo { + const funcName = 'WS_MAP_DELETE'; + this.creatingFunctions.set(generateKey(type), funcName); + return { name: funcName, - code: [`func ${funcName}${implicits}(loc: felt){`, ` return ();`, `}`].join('\n'), + code: [`func ${funcName}${IMPLICITS}(loc: felt){`, ` return ();`, `}`].join('\n'), + functionsCalled: [], }; } - private generateStructDeletionCode(varDeclarations: TypeNode[], index = 0, offset = 0): string[] { - if (index >= varDeclarations.length) return []; - const varType = dereferenceType(varDeclarations[index]); + private generateStructDeletionCode( + structType: UserDefinedType, + varDeclarations: TypeNode[], + index = 0, + offset = 0, + ): [string[], CairoFunctionDefinition[]] { + if (index >= varDeclarations.length) { + return [[], []]; + } + + const varType = generalizeType(varDeclarations[index])[0]; const varWidth = CairoType.fromSol( varType, this.ast, TypeConversionContext.StorageAllocation, ).width; + const readIdFunc = this.storageReadGen.getOrCreateFuncDef(varType); + const auxDeleteFuncName = this.safeGetOrCreateFuncDef(structType, varType); + const deleteLoc = add('loc', offset); const deleteCode = requiresReadBeforeRecursing(varType) ? [ - ` let (elem_id) = ${this.storageReadGen.genFuncName(varType)}(${deleteLoc});`, - ` ${this.getOrCreate(varType)}(elem_id);`, + ` let (elem_id) = ${readIdFunc.name}(${deleteLoc});`, + ` ${auxDeleteFuncName}(elem_id);`, ] - : [` ${this.getOrCreate(varType)}(${deleteLoc});`]; + : [` ${auxDeleteFuncName}(${deleteLoc});`]; + const [code, funcsCalled] = this.generateStructDeletionCode( + structType, + varDeclarations, + index + 1, + offset + varWidth, + ); return [ - ...deleteCode, - ...this.generateStructDeletionCode(varDeclarations, index + 1, offset + varWidth), + [...deleteCode, ...code], + [readIdFunc, ...funcsCalled], ]; } -} -function dereferenceType(type: TypeNode): TypeNode { - return generalizeType(type)[0]; + private generateStaticArrayDeletionCode( + arrayType: ArrayType, + elementT: TypeNode, + size: number, + ): [string[], CairoFunctionDefinition[]] { + const elementTWidth = CairoType.fromSol( + elementT, + this.ast, + TypeConversionContext.StorageAllocation, + ).width; + const readIdFunc = this.storageReadGen.getOrCreateFuncDef(elementT); + const auxDeleteFuncName = this.safeGetOrCreateFuncDef(arrayType, elementT); + + const generateDeleteCode = requiresReadBeforeRecursing(elementT) + ? (deleteLoc: string) => [ + ` let (elem_id) = ${readIdFunc.name}(${deleteLoc});`, + ` ${auxDeleteFuncName}(elem_id);`, + ] + : (deleteLoc: string) => [` ${auxDeleteFuncName}(${deleteLoc});`]; + + const generateCode = (index: number, offset: number): string[] => { + if (index === size) { + return []; + } + const deleteLoc = add('loc', offset); + const deleteCode = generateDeleteCode(deleteLoc); + + return [...deleteCode, ...generateCode(index + 1, offset + elementTWidth)]; + }; + + return [generateCode(0, 0), requiresReadBeforeRecursing(elementT) ? [readIdFunc] : []]; + } + + private processRecursiveDependencies() { + [...this.functionDependencies.entries()].forEach(([key, dependencies]) => { + if (!this.creatingFunctions.has(key)) { + const generatedFunc = this.generatedFunctionsDef.get(key); + assert(generatedFunc instanceof CairoGeneratedFunctionDefinition); + + dependencies.forEach((otherKey) => { + const otherFunc = this.generatedFunctionsDef.get(otherKey); + if (otherFunc === undefined) { + assert(this.creatingFunctions.has(otherKey)); + return; + } + generatedFunc.functionsCalled.push(otherFunc); + }); + } + }); + } + + private getId() { + return this.generatedFunctionsDef.size + this.creatingFunctions.size; + } } function requiresReadBeforeRecursing(type: TypeNode): boolean { - if (type instanceof PointerType) return requiresReadBeforeRecursing(type.to); return isDynamicArray(type) || type instanceof MappingType; } + +function generateKey(type: TypeNode) { + return type.pp(); +} diff --git a/src/cairoUtilFuncGen/storage/storageMemberAccess.ts b/src/cairoUtilFuncGen/storage/storageMemberAccess.ts index ee6e79743..7522b6297 100644 --- a/src/cairoUtilFuncGen/storage/storageMemberAccess.ts +++ b/src/cairoUtilFuncGen/storage/storageMemberAccess.ts @@ -1,65 +1,80 @@ import assert from 'assert'; import { MemberAccess, - ASTNode, FunctionCall, PointerType, UserDefinedType, VariableDeclaration, DataLocation, + StructDefinition, } from 'solc-typed-ast'; +import { CairoFunctionDefinition } from '../../ast/cairoNodes'; +import { printTypeNode } from '../../export'; import { CairoType, TypeConversionContext, CairoStruct } from '../../utils/cairoTypeSystem'; import { cloneASTNode } from '../../utils/cloning'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { safeGetNodeType } from '../../utils/nodeTypeProcessing'; -import { typeNameFromTypeNode, countNestedMapItems } from '../../utils/utils'; -import { CairoUtilFuncGenBase, CairoFunction, add } from '../base'; +import { typeNameFromTypeNode } from '../../utils/utils'; +import { add, GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; -export class StorageMemberAccessGen extends CairoUtilFuncGenBase { - // cairoType -> property name -> code - private generatedFunctions: Map> = new Map(); - - getGeneratedCode(): string { - return [...this.generatedFunctions.values()] - .flatMap((map) => [...map.values()]) - .map((cairoMapping) => cairoMapping.code) - .join('\n\n'); - } - - gen(memberAccess: MemberAccess, nodeInSourceUnit?: ASTNode): FunctionCall { +export class StorageMemberAccessGen extends StringIndexedFuncGen { + public gen(memberAccess: MemberAccess): FunctionCall { const solType = safeGetNodeType(memberAccess.vExpression, this.ast.inference); - assert(solType instanceof PointerType); - assert(solType.to instanceof UserDefinedType); - const structCairoType = CairoType.fromSol( - solType, - this.ast, - TypeConversionContext.StorageAllocation, + assert( + solType instanceof PointerType && + solType.to instanceof UserDefinedType && + solType.to.definition instanceof StructDefinition, + `Trying to generate a member access for a type different than a struct: ${printTypeNode( + solType, + )}`, ); - const name = this.getOrCreate(structCairoType, memberAccess.memberName); + const referencedDeclaration = memberAccess.vReferencedDeclaration; assert(referencedDeclaration instanceof VariableDeclaration); + const outType = referencedDeclaration.vType; assert(outType !== undefined); - const functionStub = createCairoFunctionStub( - name, - [['loc', typeNameFromTypeNode(solType, this.ast), DataLocation.Storage]], - [['memberLoc', cloneASTNode(outType, this.ast), DataLocation.Storage]], - [], + + const funcDef = this.getOrCreateFuncDef(solType.to, memberAccess.memberName); + return createCallToFunction(funcDef, [memberAccess.vExpression], this.ast); + } + + public getOrCreateFuncDef(solType: UserDefinedType, memberName: string): CairoFunctionDefinition { + assert(solType.definition instanceof StructDefinition); + const structCairoType = CairoType.fromSol( + solType, this.ast, - nodeInSourceUnit ?? memberAccess, + TypeConversionContext.StorageAllocation, ); - return createCallToFunction(functionStub, [memberAccess.vExpression], this.ast); - } - private getOrCreate(structCairoType: CairoType, memberName: string): string { - const existingMemberAccesses = - this.generatedFunctions.get(structCairoType.fullStringRepresentation) ?? - new Map(); - const existing = existingMemberAccesses.get(memberName); + const key = structCairoType.fullStringRepresentation + memberName; + const existing = this.generatedFunctionsDef.get(key); if (existing !== undefined) { - return existing.name; + return existing; } + const funcInfo = this.getOrCreate(structCairoType, memberName); + + const solTypeName = typeNameFromTypeNode(solType, this.ast); + const [outTypeName] = solType.definition.vMembers + .filter((member) => member.name === memberName) + .map((member) => member.vType); + assert(outTypeName !== undefined); + + const funcDef = createCairoGeneratedFunction( + funcInfo, + [['loc', solTypeName, DataLocation.Storage]], + [['memberLoc', cloneASTNode(outTypeName, this.ast), DataLocation.Storage]], + this.ast, + this.sourceUnit, + ); + + this.generatedFunctionsDef.set(key, funcDef); + + return funcDef; + } + + private getOrCreate(structCairoType: CairoType, memberName: string): GeneratedFunctionInfo { const structName = structCairoType.toString(); assert( structCairoType instanceof CairoStruct, @@ -67,20 +82,16 @@ export class StorageMemberAccessGen extends CairoUtilFuncGenBase { ); const offset = structCairoType.offsetOf(memberName); - const funcName = `WSM${countNestedMapItems( - this.generatedFunctions, - )}_${structName}_${memberName}`; + const funcName = `WS_${structName}_${memberName}`; - existingMemberAccesses.set(memberName, { + return { name: funcName, code: [ `func ${funcName}(loc: felt) -> (memberLoc: felt){`, ` return (${add('loc', offset)},);`, `}`, ].join('\n'), - }); - - this.generatedFunctions.set(structCairoType.fullStringRepresentation, existingMemberAccesses); - return funcName; + functionsCalled: [], + }; } } diff --git a/src/cairoUtilFuncGen/storage/storageRead.ts b/src/cairoUtilFuncGen/storage/storageRead.ts index 11c06258d..70503b67e 100644 --- a/src/cairoUtilFuncGen/storage/storageRead.ts +++ b/src/cairoUtilFuncGen/storage/storageRead.ts @@ -1,62 +1,67 @@ import { Expression, - TypeName, FunctionCall, DataLocation, FunctionStateMutability, TypeNode, - ASTNode, + TypeName, } from 'solc-typed-ast'; +import { typeNameFromTypeNode } from '../../export'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { cloneASTNode } from '../../utils/cloning'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { safeGetNodeType } from '../../utils/nodeTypeProcessing'; -import { add, locationIfComplexType, StringIndexedFuncGen } from '../base'; +import { add, GeneratedFunctionInfo, locationIfComplexType, StringIndexedFuncGen } from '../base'; import { serialiseReads } from '../serialisation'; export class StorageReadGen extends StringIndexedFuncGen { - gen(storageLocation: Expression, type: TypeName, nodeInSourceUnit?: ASTNode): FunctionCall { + // TODO: is typename safe to remove? + public gen(storageLocation: Expression, typeName?: TypeName): FunctionCall { const valueType = safeGetNodeType(storageLocation, this.ast.inference); + + const funcDef = this.getOrCreateFuncDef(valueType, typeName); + + return createCallToFunction(funcDef, [storageLocation], this.ast); + } + + public getOrCreateFuncDef(valueType: TypeNode, typeName?: TypeName) { + typeName = typeName ?? typeNameFromTypeNode(valueType, this.ast); const resultCairoType = CairoType.fromSol( valueType, this.ast, TypeConversionContext.StorageAllocation, ); - const name = this.getOrCreate(resultCairoType); - const functionStub = createCairoFunctionStub( - name, - [['loc', cloneASTNode(type, this.ast), DataLocation.Storage]], + + const key = resultCairoType.fullStringRepresentation + typeName.typeString; + const existing = this.generatedFunctionsDef.get(key); + if (existing !== undefined) { + return existing; + } + + const funcInfo = this.getOrCreate(resultCairoType); + const funcDef = createCairoGeneratedFunction( + funcInfo, + [['loc', cloneASTNode(typeName, this.ast), DataLocation.Storage]], [ [ 'val', - cloneASTNode(type, this.ast), + cloneASTNode(typeName, this.ast), locationIfComplexType(valueType, DataLocation.Storage), ], ], - ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr'], this.ast, - nodeInSourceUnit ?? storageLocation, + this.sourceUnit, { mutability: FunctionStateMutability.View }, ); - return createCallToFunction(functionStub, [storageLocation], this.ast); - } - - genFuncName(type: TypeNode) { - const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.StorageAllocation); - return this.getOrCreate(cairoType); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - private getOrCreate(typeToRead: CairoType): string { - const key = typeToRead.fullStringRepresentation; - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - const funcName = `WS${this.generatedFunctions.size}_READ_${typeToRead.typeName}`; + private getOrCreate(typeToRead: CairoType): GeneratedFunctionInfo { + const funcName = `WS${this.generatedFunctionsDef.size}_READ_${typeToRead.typeName}`; const resultCairoType = typeToRead.toString(); const [reads, pack] = serialiseReads(typeToRead, readFelt, readId); - this.generatedFunctions.set(key, { + const funcInfo: GeneratedFunctionInfo = { name: funcName, code: [ `func ${funcName}{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}(loc: felt) ->(val: ${resultCairoType}){`, @@ -65,8 +70,9 @@ export class StorageReadGen extends StringIndexedFuncGen { ` return (${pack},);`, '}', ].join('\n'), - }); - return funcName; + functionsCalled: [], + }; + return funcInfo; } } diff --git a/src/cairoUtilFuncGen/storage/storageToCalldata.ts b/src/cairoUtilFuncGen/storage/storageToCalldata.ts index c3dfe2e29..a1f453560 100644 --- a/src/cairoUtilFuncGen/storage/storageToCalldata.ts +++ b/src/cairoUtilFuncGen/storage/storageToCalldata.ts @@ -12,13 +12,14 @@ import { UserDefinedType, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoFunctionDefinition } from '../../export'; import { printTypeNode } from '../../utils/astPrinter'; import { CairoDynArray, CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { NotSupportedYetError } from '../../utils/errors'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { getElementType, safeGetNodeType } from '../../utils/nodeTypeProcessing'; import { mapRange, narrowBigIntSafe, typeNameFromTypeNode } from '../../utils/utils'; -import { add, delegateBasedOnType, StringIndexedFuncGen } from '../base'; +import { add, delegateBasedOnType, GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; import { ExternalDynArrayStructConstructor } from '../calldata/externalDynArray/externalDynArrayStructConstructor'; import { DynArrayGen } from './dynArray'; import { StorageReadGen } from './storageRead'; @@ -34,46 +35,49 @@ export class StorageToCalldataGen extends StringIndexedFuncGen { super(ast, sourceUnit); } - gen(storageLocation: Expression) { + public gen(storageLocation: Expression) { const storageType = generalizeType(safeGetNodeType(storageLocation, this.ast.inference))[0]; - const name = this.getOrCreate(storageType); - const functionStub = createCairoFunctionStub( - name, - [['loc', typeNameFromTypeNode(storageType, this.ast), DataLocation.Storage]], - [['obj', typeNameFromTypeNode(storageType, this.ast), DataLocation.CallData]], - ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr'], - this.ast, - storageLocation, - ); - - return createCallToFunction(functionStub, [storageLocation], this.ast); + const funcDef = this.getOrCreateFuncDef(storageType); + return createCallToFunction(funcDef, [storageLocation], this.ast); } - private getOrCreate(type: TypeNode): string { + public getOrCreateFuncDef(type: TypeNode) { const key = type.pp(); - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; } + const funcInfo = this.getOrCreate(type); + const funcDef = createCairoGeneratedFunction( + funcInfo, + [['loc', typeNameFromTypeNode(type, this.ast), DataLocation.Storage]], + [['obj', typeNameFromTypeNode(type, this.ast), DataLocation.CallData]], + this.ast, + this.sourceUnit, + ); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; + } + private getOrCreate(type: TypeNode): GeneratedFunctionInfo { const unexpectedTypeFunc = () => { throw new NotSupportedYetError( `Copying ${printTypeNode(type)} from storage to calldata is not supported yet`, ); }; - return delegateBasedOnType( + return delegateBasedOnType( type, - (type) => this.createDynamicArrayCopyFunction(key, type), - (type) => this.createStaticArrayCopyFunction(key, type), - (type) => this.createStructCopyFunction(key, type), + (type) => this.createDynamicArrayCopyFunction(type), + (type) => this.createStaticArrayCopyFunction(type), + (type) => this.createStructCopyFunction(type), unexpectedTypeFunc, unexpectedTypeFunc, ); } - private createStructCopyFunction(key: string, structType: UserDefinedType) { + private createStructCopyFunction(structType: UserDefinedType): GeneratedFunctionInfo { assert(structType.definition instanceof StructDefinition); const structDef = structType.definition; @@ -85,7 +89,7 @@ export class StorageToCalldataGen extends StringIndexedFuncGen { const structName = `struct_${cairoStruct.toString()}`; - const [copyInstructions, members] = this.generateStructCopyInstructions( + const [copyInstructions, members, funcsCalled] = this.generateStructCopyInstructions( structDef.vMembers.map((varDecl) => safeGetNodeType(varDecl, this.ast.inference)), 'member', ); @@ -100,22 +104,26 @@ export class StorageToCalldataGen extends StringIndexedFuncGen { `}`, ].join('\n'); - this.generatedFunctions.set(key, { name: funcName, code: code }); - return funcName; + const funcInfo: GeneratedFunctionInfo = { + name: funcName, + code: code, + functionsCalled: funcsCalled, + }; + return funcInfo; } - private createStaticArrayCopyFunction(key: string, arrayType: ArrayType): string { + private createStaticArrayCopyFunction(arrayType: ArrayType): GeneratedFunctionInfo { assert(arrayType.size !== undefined); const cairoType = CairoType.fromSol(arrayType, this.ast, TypeConversionContext.CallDataRef); - const [copyInstructions, members] = this.generateStructCopyInstructions( + const [copyInstructions, members, funcsCalled] = this.generateStructCopyInstructions( mapRange(narrowBigIntSafe(arrayType.size), () => arrayType.elementT), 'elem', ); const implicits = '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}'; - const funcName = `ws_static_array_to_calldata${this.generatedFunctions.size}`; + const funcName = `ws_static_array_to_calldata${this.generatedFunctionsDef.size}`; const code = [ `func ${funcName}${implicits}(loc : felt) -> (static_array : ${cairoType.toString()}){`, ` alloc_locals;`, @@ -124,23 +132,26 @@ export class StorageToCalldataGen extends StringIndexedFuncGen { `}`, ].join('\n'); - this.generatedFunctions.set(key, { name: funcName, code: code }); - - return funcName; + return { + name: funcName, + code: code, + functionsCalled: funcsCalled, + }; } private createDynamicArrayCopyFunction( - key: string, arrayType: ArrayType | BytesType | StringType, - ): string { + ): GeneratedFunctionInfo { const elementT = getElementType(arrayType); const structDef = CairoType.fromSol(arrayType, this.ast, TypeConversionContext.CallDataRef); assert(structDef instanceof CairoDynArray); - this.externalDynArrayStructConstructor.getOrCreate(arrayType); - const [arrayName, arrayLen] = this.dynArrayGen.gen( - CairoType.fromSol(elementT, this.ast, TypeConversionContext.StorageAllocation), - ); + const storageReadFunc = this.storageReadGen.getOrCreateFuncDef(elementT); + const sturctDynArray = this.externalDynArrayStructConstructor.getOrCreateFuncDef(arrayType); + const [dynArray, dynArrayLength] = this.dynArrayGen.getOrCreateFuncDef(elementT); + + const arrayName = dynArray.name; + const lenName = dynArrayLength.name; const cairoElementType = CairoType.fromSol( elementT, this.ast, @@ -150,7 +161,7 @@ export class StorageToCalldataGen extends StringIndexedFuncGen { const ptrType = `${cairoElementType.toString()}*`; const implicits = '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}'; - const funcName = `ws_dynamic_array_to_calldata${this.generatedFunctions.size}`; + const funcName = `ws_dynamic_array_to_calldata${this.generatedFunctionsDef.size}`; const code = [ `func ${funcName}_write${implicits}(`, ` loc : felt,`, @@ -163,14 +174,14 @@ export class StorageToCalldataGen extends StringIndexedFuncGen { ` }`, ` let (index_uint256) = warp_uint256(index);`, ` let (elem_loc) = ${arrayName}.read(loc, index_uint256);`, // elem_loc should never be zero - ` let (elem) = ${this.storageReadGen.genFuncName(elementT)}(elem_loc);`, + ` let (elem) = ${storageReadFunc.name}(elem_loc);`, ` assert ptr[index] = elem;`, ` return ${funcName}_write(loc, index + 1, len, ptr);`, `}`, `func ${funcName}${implicits}(loc : felt) -> (dyn_array_struct : ${structDef.name}){`, ` alloc_locals;`, - ` let (len_uint256) = ${arrayLen}.read(loc);`, + ` let (len_uint256) = ${lenName}.read(loc);`, ` let len = len_uint256.low + len_uint256.high*128;`, ` let (ptr : ${ptrType}) = alloc();`, ` let (ptr : ${ptrType}) = ${funcName}_write(loc, 0, len, ptr);`, @@ -179,38 +190,56 @@ export class StorageToCalldataGen extends StringIndexedFuncGen { `}`, ].join('\n'); - this.requireImport('warplib.maths.int_conversions', 'warp_uint256'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('starkware.cairo.common.alloc', 'alloc'); - - this.generatedFunctions.set(key, { name: funcName, code: code }); - - return funcName; + const importedFuncs = [ + this.requireImport('warplib.maths.int_conversions', 'warp_uint256'), + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('starkware.cairo.common.alloc', 'alloc'), + ]; + + const funcInfo: GeneratedFunctionInfo = { + name: funcName, + code: code, + functionsCalled: [ + ...importedFuncs, + sturctDynArray, + dynArray, + dynArrayLength, + storageReadFunc, + ], + }; + return funcInfo; } + // TODO: static arrays functions are going to be huge with big size. We should build + // a copy function for them instead of reusing structs private generateStructCopyInstructions( varDeclarations: TypeNode[], tempVarName: string, - ): [string[], string[]] { - const members: string[] = []; - let offset = 0; - const copyInstructions = varDeclarations.map((varType, index) => { - const varCairoTypeWidth = CairoType.fromSol( - varType, - this.ast, - TypeConversionContext.CallDataRef, - ).width; - - const funcName = this.storageReadGen.genFuncName(varType); - const location = add('loc', offset); - const memberName = `${tempVarName}_${index}`; - - members.push(memberName); - offset += varCairoTypeWidth; - - return ` let (${memberName}) = ${funcName}(${location});`; - }); - - return [copyInstructions, members]; + ): [string[], string[], CairoFunctionDefinition[]] { + const [members, copyInstructions, funcsCalled] = varDeclarations.reduce( + ([members, copyInstructions, funcsCalled, offset], varType, index) => { + const varCairoTypeWidth = CairoType.fromSol( + varType, + this.ast, + TypeConversionContext.CallDataRef, + ).width; + + const readFunc = this.storageReadGen.getOrCreateFuncDef(varType); + const location = add('loc', offset); + const memberName = `${tempVarName}_${index}`; + + const instruction = ` let (${memberName}) = ${readFunc.name}(${location});`; + + return [ + [...members, memberName], + [...copyInstructions, instruction], + [...funcsCalled, readFunc], + offset + varCairoTypeWidth, + ]; + }, + [new Array(), new Array(), new Array(), 0], + ); + + return [copyInstructions, members, funcsCalled]; } } diff --git a/src/cairoUtilFuncGen/storage/storageToMemory.ts b/src/cairoUtilFuncGen/storage/storageToMemory.ts index 32635a663..417a72ddb 100644 --- a/src/cairoUtilFuncGen/storage/storageToMemory.ts +++ b/src/cairoUtilFuncGen/storage/storageToMemory.ts @@ -1,12 +1,13 @@ import assert from 'assert'; import { ArrayType, - ASTNode, BytesType, DataLocation, Expression, + FunctionCall, FunctionStateMutability, generalizeType, + isReferenceType, SourceUnit, StringType, StructDefinition, @@ -14,16 +15,19 @@ import { UserDefinedType, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoFunctionDefinition, TranspileFailedError } from '../../export'; import { printTypeNode } from '../../utils/astPrinter'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { NotSupportedYetError } from '../../utils/errors'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { getElementType, isDynamicArray, safeGetNodeType } from '../../utils/nodeTypeProcessing'; import { mapRange, narrowBigIntSafe, typeNameFromTypeNode } from '../../utils/utils'; import { uint256 } from '../../warplib/utils'; -import { add, delegateBasedOnType, StringIndexedFuncGen } from '../base'; +import { add, delegateBasedOnType, GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; import { DynArrayGen } from './dynArray'; +const IMPLICITS = + '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; /* Generates functions to copy data from WARP_STORAGE to warp_memory Specifically this has to deal with structs, static arrays, and dynamic arrays @@ -32,134 +36,150 @@ import { DynArrayGen } from './dynArray'; */ export class StorageToMemoryGen extends StringIndexedFuncGen { - constructor(private dynArrayGen: DynArrayGen, ast: AST, sourceUnit: SourceUnit) { + public constructor(private dynArrayGen: DynArrayGen, ast: AST, sourceUnit: SourceUnit) { super(ast, sourceUnit); } - gen(node: Expression, nodeInSourceUnit?: ASTNode): Expression { - const type = generalizeType(safeGetNodeType(node, this.ast.inference))[0]; - const name = this.getOrCreate(type); - const functionStub = createCairoFunctionStub( - name, + public gen(node: Expression): FunctionCall { + const type = safeGetNodeType(node, this.ast.inference); + + const funcDef = this.getOrCreateFuncDef(type); + return createCallToFunction(funcDef, [node], this.ast); + } + + public getOrCreateFuncDef(type: TypeNode): CairoFunctionDefinition { + type = generalizeType(type)[0]; + + const key = type.pp(); + const existing = this.generatedFunctionsDef.get(key); + if (existing !== undefined) { + return existing; + } + + const funcInfo = this.getOrCreate(type); + const funcDef = createCairoGeneratedFunction( + funcInfo, [['loc', typeNameFromTypeNode(type, this.ast), DataLocation.Storage]], [['mem_loc', typeNameFromTypeNode(type, this.ast), DataLocation.Memory]], - ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr', 'warp_memory'], this.ast, - nodeInSourceUnit ?? node, + this.sourceUnit, { mutability: FunctionStateMutability.View }, ); - return createCallToFunction(functionStub, [node], this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - private getOrCreate(type: TypeNode): string { - const key = type.pp(); - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - + private getOrCreate(type: TypeNode): GeneratedFunctionInfo { const unexpectedTypeFunc = () => { throw new NotSupportedYetError( `Copying ${printTypeNode(type)} from storage to memory not implemented yet`, ); }; - return delegateBasedOnType( + return delegateBasedOnType( type, - (type) => this.createDynamicArrayCopyFunction(key, type), - (type) => this.createStaticArrayCopyFunction(key, type), - (type) => this.createStructCopyFunction(key, type), + (type) => this.createDynamicArrayCopyFunction(type), + (type) => this.createStaticArrayCopyFunction(type), + (type, def) => this.createStructCopyFunction(type, def), unexpectedTypeFunc, unexpectedTypeFunc, ); } - private createStructCopyFunction(key: string, type: UserDefinedType): string { + private createStructCopyFunction( + type: UserDefinedType, + def: StructDefinition, + ): GeneratedFunctionInfo { const memoryType = CairoType.fromSol(type, this.ast, TypeConversionContext.MemoryAllocation); - const funcName = `ws_to_memory${this.generatedFunctions.size}`; - const implicits = - '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; - - // Set an empty entry so recursive function generation doesn't clash - this.generatedFunctions.set(key, { name: funcName, code: '' }); + const [copyInstructions, copyCalls] = generateCopyInstructions(type, this.ast).reduce( + ([copyInstructions, copyCalls], { storageOffset, copyType }, index) => { + const [copyCode, calls] = this.getIterCopyCode(copyType, index, storageOffset); + return [ + [ + ...copyInstructions, + copyCode, + `dict_write{dict_ptr=warp_memory}(${add('mem_start', index)}, copy${index});`, + ], + [...copyCalls, ...calls], + ]; + }, + [new Array(), new Array()], + ); - this.generatedFunctions.set(key, { + const funcName = `ws_to_memory${this.generatedFunctionsDef.size}_struct_${def.name}`; + const funcInfo: GeneratedFunctionInfo = { name: funcName, code: [ - `func ${funcName}${implicits}(loc : felt) -> (mem_loc: felt){`, + `func ${funcName}${IMPLICITS}(loc : felt) -> (mem_loc: felt){`, ` alloc_locals;`, ` let (mem_start) = wm_alloc(${uint256(memoryType.width)});`, - ...generateCopyInstructions(type, this.ast).flatMap( - ({ storageOffset, copyType }, index) => [ - this.getIterCopyCode(copyType, index, storageOffset), - `dict_write{dict_ptr=warp_memory}(${add('mem_start', index)}, copy${index});`, - ], - ), + ...copyInstructions, ` return (mem_start,);`, `}`, ].join('\n'), - }); - - this.requireImport('starkware.cairo.common.dict', 'dict_write'); - this.requireImport('warplib.memory', 'wm_alloc'); - - return funcName; + functionsCalled: [ + this.requireImport('starkware.cairo.common.dict', 'dict_write'), + this.requireImport('warplib.memory', 'wm_alloc'), + ...copyCalls, + ], + }; + return funcInfo; } - private createStaticArrayCopyFunction(key: string, type: ArrayType): string { + private createStaticArrayCopyFunction(type: ArrayType): GeneratedFunctionInfo { assert(type.size !== undefined, 'Expected static array with known size'); return type.size <= 5 - ? this.createSmallStaticArrayCopyFunction(key, type) - : this.createLargeStaticArrayCopyFunction(key, type); + ? this.createSmallStaticArrayCopyFunction(type) + : this.createLargeStaticArrayCopyFunction(type); } - private createSmallStaticArrayCopyFunction(key: string, type: ArrayType) { + private createSmallStaticArrayCopyFunction(type: ArrayType): GeneratedFunctionInfo { const memoryType = CairoType.fromSol(type, this.ast, TypeConversionContext.MemoryAllocation); - const funcName = `ws_to_memory${this.generatedFunctions.size}`; - const implicits = - '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; - - // Set an empty entry so recursive function generation doesn't clash - this.generatedFunctions.set(key, { name: funcName, code: '' }); + const [copyInstructions, copyCalls] = generateCopyInstructions(type, this.ast).reduce( + ([copyInstructions, copyCalls], { storageOffset, copyType }, index) => { + const [copyCode, calls] = this.getIterCopyCode(copyType, index, storageOffset); + return [ + [ + ...copyInstructions, + copyCode, + `dict_write{dict_ptr=warp_memory}(${add('mem_start', index)}, copy${index});`, + ], + [...copyCalls, ...calls], + ]; + }, + [new Array(), new Array()], + ); - this.generatedFunctions.set(key, { + const funcName = `ws_to_memory_small_static_array${this.generatedFunctionsDef.size}`; + const funcInfo: GeneratedFunctionInfo = { name: funcName, code: [ - `func ${funcName}${implicits}(loc : felt) -> (mem_loc : felt){`, + `func ${funcName}${IMPLICITS}(loc : felt) -> (mem_loc : felt){`, ` alloc_locals;`, ` let length = ${uint256(memoryType.width)};`, ` let (mem_start) = wm_alloc(length);`, - ...generateCopyInstructions(type, this.ast).flatMap( - ({ storageOffset, copyType }, index) => [ - this.getIterCopyCode(copyType, index, storageOffset), - `dict_write{dict_ptr=warp_memory}(${add('mem_start', index)}, copy${index});`, - ], - ), + ...copyInstructions, ` return (mem_start,);`, `}`, ].join('\n'), - }); - - this.requireImport('starkware.cairo.common.dict', 'dict_write'); - this.requireImport('warplib.memory', 'wm_alloc'); + functionsCalled: [ + this.requireImport('starkware.cairo.common.dict', 'dict_write'), + this.requireImport('warplib.memory', 'wm_alloc'), + ...copyCalls, + ], + }; - return funcName; + return funcInfo; } - private createLargeStaticArrayCopyFunction(key: string, type: ArrayType) { + private createLargeStaticArrayCopyFunction(type: ArrayType): GeneratedFunctionInfo { assert(type.size !== undefined, 'Expected static array with known size'); - const funcName = `ws_to_memory${this.generatedFunctions.size}`; const length = narrowBigIntSafe( type.size, `Failed to narrow size of ${printTypeNode(type)} in memory->storage copy generation`, ); - const implicits = - '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; - - // Set an empty entry so recursive function generation doesn't clash - this.generatedFunctions.set(key, { name: funcName, code: '' }); const elementMemoryWidth = CairoType.fromSol(type.elementT, this.ast).width; const elementStorageWidth = CairoType.fromSol( @@ -167,18 +187,18 @@ export class StorageToMemoryGen extends StringIndexedFuncGen { this.ast, TypeConversionContext.StorageAllocation, ).width; - - const copyCode: string = this.getRecursiveCopyCode( + const [copyCode, copyCalls] = this.getRecursiveCopyCode( type.elementT, elementMemoryWidth, 'loc', 'mem_start', ); - this.generatedFunctions.set(key, { + const funcName = `ws_to_memory_large_static_array${this.generatedFunctionsDef.size}`; + const funcInfo: GeneratedFunctionInfo = { name: funcName, code: [ - `func ${funcName}_elem${implicits}(mem_start: felt, loc : felt, length: Uint256) -> (){`, + `func ${funcName}_elem${IMPLICITS}(mem_start: felt, loc : felt, length: Uint256) -> (){`, ` alloc_locals;`, ` if (length.low == 0){`, ` if (length.high == 0){`, @@ -193,7 +213,7 @@ export class StorageToMemoryGen extends StringIndexedFuncGen { )}, index);`, `}`, - `func ${funcName}${implicits}(loc : felt) -> (mem_loc : felt){`, + `func ${funcName}${IMPLICITS}(loc : felt) -> (mem_loc : felt){`, ` alloc_locals;`, ` let length = ${uint256(length)};`, ` let (mem_start) = wm_alloc(length);`, @@ -201,77 +221,74 @@ export class StorageToMemoryGen extends StringIndexedFuncGen { ` return (mem_start,);`, `}`, ].join('\n'), - }); - - this.requireImport('starkware.cairo.common.dict', 'dict_write'); - this.requireImport('warplib.memory', 'wm_alloc'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_sub'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - - return funcName; + functionsCalled: [ + this.requireImport('starkware.cairo.common.dict', 'dict_write'), + this.requireImport('warplib.memory', 'wm_alloc'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_sub'), + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + ...copyCalls, + ], + }; + return funcInfo; } private createDynamicArrayCopyFunction( - key: string, type: ArrayType | BytesType | StringType, - ): string { + ): GeneratedFunctionInfo { const elementT = getElementType(type); const memoryElementType = CairoType.fromSol(elementT, this.ast); - const funcName = `ws_to_memory${this.generatedFunctions.size}`; - this.generatedFunctions.set(key, { - name: funcName, - code: '', - }); - const [elemMapping, lengthMapping] = this.dynArrayGen.gen( - CairoType.fromSol(elementT, this.ast, TypeConversionContext.StorageAllocation), - ); - const implicits = - '{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt, warp_memory : DictAccess*}'; + const [dynArray, dynArrayLength] = this.dynArrayGen.getOrCreateFuncDef(elementT); + const elemMappingName = dynArray.name; + const lengthMappingName = dynArrayLength.name; // This is the code to copy a single element // Complex types require calls to another function generated here // Simple types take one or two WARP_STORAGE-dict_write pairs - const copyCode: string = this.getRecursiveCopyCode( + const [copyCode, copyCalls] = this.getRecursiveCopyCode( elementT, memoryElementType.width, 'element_storage_loc', 'mem_loc', ); - // Now generate two functions: the setup function funcName, and the elementwise copy function: funcName_elem - this.generatedFunctions.set(key, { + const funcName = `ws_to_memory_dynamic_array${this.generatedFunctionsDef.size}`; + const funcInfo: GeneratedFunctionInfo = { name: funcName, code: [ - `func ${funcName}_elem${implicits}(storage_name: felt, mem_start: felt, length: Uint256) -> (){`, + `func ${funcName}_elem${IMPLICITS}(storage_name: felt, mem_start: felt, length: Uint256) -> (){`, ` alloc_locals;`, ` if (length.low == 0 and length.high == 0){`, ` return ();`, ` }`, ` let (index) = uint256_sub(length, Uint256(1,0));`, ` let (mem_loc) = wm_index_dyn(mem_start, index, ${uint256(memoryElementType.width)});`, - ` let (element_storage_loc) = ${elemMapping}.read(storage_name, index);`, + ` let (element_storage_loc) = ${elemMappingName}.read(storage_name, index);`, copyCode, ` return ${funcName}_elem(storage_name, mem_start, index);`, `}`, - `func ${funcName}${implicits}(loc : felt) -> (mem_loc : felt){`, + `func ${funcName}${IMPLICITS}(loc : felt) -> (mem_loc : felt){`, ` alloc_locals;`, - ` let (length: Uint256) = ${lengthMapping}.read(loc);`, + ` let (length: Uint256) = ${lengthMappingName}.read(loc);`, ` let (mem_start) = wm_new(length, ${uint256(memoryElementType.width)});`, ` ${funcName}_elem(loc, mem_start, length);`, ` return (mem_start,);`, `}`, ].join('\n'), - }); - - this.requireImport('starkware.cairo.common.dict', 'dict_write'); - this.requireImport('starkware.cairo.common.uint256', 'uint256_sub'); - this.requireImport('starkware.cairo.common.uint256', 'Uint256'); - this.requireImport('warplib.memory', 'wm_new'); - this.requireImport('warplib.memory', 'wm_index_dyn'); + functionsCalled: [ + this.requireImport('starkware.cairo.common.dict', 'dict_write'), + this.requireImport('starkware.cairo.common.uint256', 'uint256_sub'), + this.requireImport('starkware.cairo.common.uint256', 'Uint256'), + this.requireImport('warplib.memory', 'wm_new'), + this.requireImport('warplib.memory', 'wm_index_dyn'), + ...copyCalls, + dynArray, + dynArrayLength, + ], + }; - return funcName; + return funcInfo; } // Copy code generation for iterative copy instructions (small static arrays and structs) @@ -279,18 +296,21 @@ export class StorageToMemoryGen extends StringIndexedFuncGen { copyType: TypeNode | undefined, index: number, storageOffset: number, - ): string { + ): [string, CairoFunctionDefinition[]] { if (copyType === undefined) { - return `let (copy${index}) = WARP_STORAGE.read(${add('loc', storageOffset)});`; + return [`let (copy${index}) = WARP_STORAGE.read(${add('loc', storageOffset)});`, []]; } - const funcName = this.getOrCreate(copyType); - return isDynamicArray(copyType) - ? [ - `let (dyn_loc) = WARP_STORAGE.read(${add('loc', storageOffset)});`, - `let (copy${index}) = ${funcName}(dyn_loc);`, - ].join('\n') - : `let (copy${index}) = ${funcName}(${add('loc', storageOffset)});`; + const func = this.getOrCreateFuncDef(copyType); + return [ + isDynamicArray(copyType) + ? [ + `let (dyn_loc) = WARP_STORAGE.read(${add('loc', storageOffset)});`, + `let (copy${index}) = ${func.name}(dyn_loc);`, + ].join('\n') + : `let (copy${index}) = ${func.name}(${add('loc', storageOffset)});`, + [func], + ]; } // Copy code generation for recursive copy instructions (large static arrays and dynamic arrays) @@ -299,26 +319,43 @@ export class StorageToMemoryGen extends StringIndexedFuncGen { elementMemoryWidth: number, storageLoc: string, memoryLoc: string, - ) { - if (isStaticArrayOrStruct(elementT)) { - return [ - ` let (copy) = ${this.getOrCreate(elementT)}(${storageLoc});`, - ` dict_write{dict_ptr=warp_memory}(${memoryLoc}, copy);`, - ].join('\n'); - } else if (isDynamicArray(elementT)) { - return [ - ` let (dyn_loc) = readId(${storageLoc});`, - ` let (copy) = ${this.getOrCreate(elementT)}(dyn_loc);`, - ` dict_write{dict_ptr=warp_memory}(${memoryLoc}, copy);`, - ].join('\n'); - } else { - return mapRange(elementMemoryWidth, (n) => + ): [string, CairoFunctionDefinition[]] { + if (isReferenceType(elementT)) { + const auxFunc = this.getOrCreateFuncDef(elementT); + if (isStaticArrayOrStruct(elementT)) { + return [ + [ + ` let (copy) = ${auxFunc.name}(${storageLoc});`, + ` dict_write{dict_ptr=warp_memory}(${memoryLoc}, copy);`, + ].join('\n'), + [auxFunc], + ]; + } else if (isDynamicArray(elementT)) { + return [ + [ + ` let (dyn_loc) = readId(${storageLoc});`, + ` let (copy) = ${auxFunc.name}(dyn_loc);`, + ` dict_write{dict_ptr=warp_memory}(${memoryLoc}, copy);`, + ].join('\n'), + [auxFunc], + ]; + } + throw new TranspileFailedError( + `Trying to create recursive code for unsupported referency type: ${printTypeNode( + elementT, + )}`, + ); + } + + return [ + mapRange(elementMemoryWidth, (n) => [ ` let (copy) = WARP_STORAGE.read(${add(`${storageLoc}`, n)});`, ` dict_write{dict_ptr=warp_memory}(${add(`${memoryLoc}`, n)}, copy);`, ].join('\n'), - ).join('\n'); - } + ).join('\n'), + [], + ]; } } diff --git a/src/cairoUtilFuncGen/storage/storageWrite.ts b/src/cairoUtilFuncGen/storage/storageWrite.ts index 9c7319a51..d0ed013d5 100644 --- a/src/cairoUtilFuncGen/storage/storageWrite.ts +++ b/src/cairoUtilFuncGen/storage/storageWrite.ts @@ -1,29 +1,29 @@ -import { - Expression, - FunctionCall, - TypeNode, - ASTNode, - DataLocation, - PointerType, -} from 'solc-typed-ast'; +import { Expression, FunctionCall, TypeNode, DataLocation, PointerType } from 'solc-typed-ast'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { cloneASTNode } from '../../utils/cloning'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { safeGetNodeType } from '../../utils/nodeTypeProcessing'; import { typeNameFromTypeNode } from '../../utils/utils'; -import { add, StringIndexedFuncGen } from '../base'; +import { add, GeneratedFunctionInfo, StringIndexedFuncGen } from '../base'; export class StorageWriteGen extends StringIndexedFuncGen { - gen( - storageLocation: Expression, - writeValue: Expression, - nodeInSourceUnit?: ASTNode, - ): FunctionCall { + public gen(storageLocation: Expression, writeValue: Expression): FunctionCall { const typeToWrite = safeGetNodeType(storageLocation, this.ast.inference); - const name = this.getOrCreate(typeToWrite); + const funcDef = this.getOrCreateFuncDef(typeToWrite); + return createCallToFunction(funcDef, [storageLocation, writeValue], this.ast); + } + + public getOrCreateFuncDef(typeToWrite: TypeNode) { + const key = `dynArrayPop(${typeToWrite.pp()})`; + const value = this.generatedFunctionsDef.get(key); + if (value !== undefined) { + return value; + } + + const funcInfo = this.getOrCreate(typeToWrite); const argTypeName = typeNameFromTypeNode(typeToWrite, this.ast); - const functionStub = createCairoFunctionStub( - name, + const funcDef = createCairoGeneratedFunction( + funcInfo, [ ['loc', argTypeName, DataLocation.Storage], [ @@ -39,28 +39,23 @@ export class StorageWriteGen extends StringIndexedFuncGen { typeToWrite instanceof PointerType ? DataLocation.Storage : DataLocation.Default, ], ], - ['syscall_ptr', 'pedersen_ptr', 'range_check_ptr'], this.ast, - nodeInSourceUnit ?? storageLocation, + this.sourceUnit, ); - return createCallToFunction(functionStub, [storageLocation, writeValue], this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } - getOrCreate(typeToWrite: TypeNode): string { + private getOrCreate(typeToWrite: TypeNode): GeneratedFunctionInfo { const cairoTypeToWrite = CairoType.fromSol( typeToWrite, this.ast, TypeConversionContext.StorageAllocation, ); - const key = cairoTypeToWrite.fullStringRepresentation; - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } const cairoTypeString = cairoTypeToWrite.toString(); - const funcName = `WS_WRITE${this.generatedFunctions.size}`; - this.generatedFunctions.set(key, { + const funcName = `WS_WRITE${this.generatedFunctionsDef.size}`; + const funcInfo: GeneratedFunctionInfo = { name: funcName, code: [ `func ${funcName}{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr : felt}(loc: felt, value: ${cairoTypeString}) -> (res: ${cairoTypeString}){`, @@ -70,8 +65,9 @@ export class StorageWriteGen extends StringIndexedFuncGen { ' return (value,);', '}', ].join('\n'), - }); - return funcName; + functionsCalled: [], + }; + return funcInfo; } } diff --git a/src/cairoUtilFuncGen/utils/encodeToFelt.ts b/src/cairoUtilFuncGen/utils/encodeToFelt.ts index 12660e3e7..3f0cc60ac 100644 --- a/src/cairoUtilFuncGen/utils/encodeToFelt.ts +++ b/src/cairoUtilFuncGen/utils/encodeToFelt.ts @@ -14,6 +14,8 @@ import { UserDefinedType, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoGeneratedFunctionDefinition } from '../../ast/cairoNodes/cairoGeneratedFunctionDefinition'; +import { CairoFunctionDefinition, notUndefined } from '../../export'; import { printTypeNode } from '../../utils/astPrinter'; import { CairoDynArray, @@ -22,7 +24,7 @@ import { TypeConversionContext, } from '../../utils/cairoTypeSystem'; import { NotSupportedYetError, WillNotSupportError } from '../../utils/errors'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCairoGeneratedFunction, createCallToFunction } from '../../utils/functionGeneration'; import { createBytesTypeName } from '../../utils/nodeTemplates'; import { getElementType, @@ -32,7 +34,11 @@ import { safeGetNodeType, } from '../../utils/nodeTypeProcessing'; import { mapRange, narrowBigIntSafe, typeNameFromTypeNode } from '../../utils/utils'; -import { CairoFunction, delegateBasedOnType, StringIndexedFuncGen } from '../base'; +import { + delegateBasedOnType, + GeneratedFunctionInfo, + StringIndexedFuncGenWithAuxiliar, +} from '../base'; import { ExternalDynArrayStructConstructor } from '../calldata/externalDynArray/externalDynArrayStructConstructor'; const IMPLICITS = ''; @@ -41,7 +47,7 @@ const IMPLICITS = ''; * This class generate `encode` cairo util functions with the objective of making * a list of values into a single list where all items are felts. For example: * Value list: [a : felt, b : Uint256, c : (felt, felt, felt), d_len : felt, d : felt*] - * Result: [a, b.low, b.high, c[0], c[1], c[2], d_len, d[0], ..., d[n]] + * Result: [a, b.low, b.high, c[0], c[1], c[2], d_len, d[0], ..., d[d_len - 1]] * * It generates a different function depending on the amount of expressions * and their types. It also generate different auxiliar functions depending @@ -51,10 +57,8 @@ const IMPLICITS = ''; * generated encoding functions. I.e. the auxiliar function to encode felt * dynamic arrays will be always the same */ -export class EncodeAsFelt extends StringIndexedFuncGen { - private auxiliarGeneratedFunctions = new Map(); - - constructor( +export class EncodeAsFelt extends StringIndexedFuncGenWithAuxiliar { + public constructor( private externalArrayGen: ExternalDynArrayStructConstructor, ast: AST, sourceUnit: SourceUnit, @@ -62,12 +66,6 @@ export class EncodeAsFelt extends StringIndexedFuncGen { super(ast, sourceUnit); } - getGeneratedCode(): string { - return [...this.auxiliarGeneratedFunctions.values(), ...this.generatedFunctions.values()] - .map((func) => func.code) - .join('\n\n'); - } - /** * Given a expression list it generates a `encode` cairo function definition * and call that serializes the arguments into a list of felts @@ -76,23 +74,32 @@ export class EncodeAsFelt extends StringIndexedFuncGen { * @param sourceUnit source unit where the expression is defined * @returns a function call that serializes the value of `expressions` */ - gen(expressions: Expression[], expectedTypes: TypeNode[], sourceUnit?: SourceUnit): FunctionCall { + public gen(expressions: Expression[], expectedTypes: TypeNode[]): FunctionCall { assert(expectedTypes.length === expressions.length); expectedTypes = expectedTypes.map((type) => generalizeType(type)[0]); - const functionName = this.getOrCreate(expectedTypes); + const funcDef = this.getOrCreateFuncDef(expectedTypes); + return createCallToFunction(funcDef, expressions, this.ast); + } - const functionStub = createCairoFunctionStub( - functionName, - expectedTypes.map((exprT, index) => { + public getOrCreateFuncDef(typesToEncode: TypeNode[]) { + const key = typesToEncode.map((t) => t.pp()).join(','); + const existing = this.generatedFunctionsDef.get(key); + if (existing !== undefined) { + return existing; + } + const funcInfo = this.getOrCreate(typesToEncode); + const funcDef = createCairoGeneratedFunction( + funcInfo, + typesToEncode.map((exprT, index) => { const input: [string, TypeName] = [`arg${index}`, typeNameFromTypeNode(exprT, this.ast)]; return isValueType(exprT) ? input : [...input, DataLocation.CallData]; }), [['result', createBytesTypeName(this.ast), DataLocation.CallData]], - [], this.ast, - sourceUnit ?? this.sourceUnit, + this.sourceUnit, ); - return createCallToFunction(functionStub, expressions, this.ast); + this.generatedFunctionsDef.set(key, funcDef); + return funcDef; } /** @@ -101,82 +108,99 @@ export class EncodeAsFelt extends StringIndexedFuncGen { * @param typesToEncode type list * @returns the name of the generated function */ - getOrCreate(typesToEncode: TypeNode[]): string { - const key = typesToEncode.map((t) => t.pp()).join(','); - const existing = this.generatedFunctions.get(key); - if (existing !== undefined) { - return existing.name; - } - - const parameters: string[] = []; - const encodeCode: string[] = []; - - typesToEncode.forEach((type, index) => { - const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); - const prefix = `arg_${index}`; - - if (isDynamicArray(type)) { - assert(cairoType instanceof CairoDynArray); - const arrayName = `${prefix}_dynamic`; - parameters.push(` ${arrayName} : ${cairoType.typeName}`); - const auxFuncName = this.getOrCreateAuxiliar(type); - encodeCode.push( - `assert decode_array[total_size] = ${arrayName}.len;`, - `let total_size = total_size + 1;`, - `let (total_size) = ${auxFuncName}(total_size, decode_array, 0, ${arrayName}.len, ${arrayName}.ptr);`, - ); - } else if (type instanceof ArrayType) { - parameters.push(`${prefix}_static : ${cairoType.toString()}`); - const auxFuncName = this.getOrCreateAuxiliar(type); - encodeCode.push( - `let (total_size) = ${auxFuncName}(total_size, decode_array, ${prefix}_static);`, - ); - } else if (isStruct(type)) { - assert(cairoType instanceof CairoStruct); - parameters.push(`${prefix}_${cairoType.name} : ${cairoType.typeName}`); - const auxFuncName = this.getOrCreateAuxiliar(type); - encodeCode.push( - `let (total_size) = ${auxFuncName}(total_size, decode_array, ${prefix}_${cairoType.name});`, - ); - } else if (isValueType(type)) { - parameters.push(`${prefix} : ${cairoType.typeName}`); - encodeCode.push( - cairoType.width > 1 - ? [ - `assert decode_array[total_size] = ${prefix}.low;`, - `assert decode_array[total_size + 1] = ${prefix}.high;`, - `let total_size = total_size + 2;`, - ].join('\n') - : [ - `assert decode_array[total_size] = ${prefix};`, - `let total_size = total_size + 1;`, - ].join('\n'), - ); - } else { + private getOrCreate(typesToEncode: TypeNode[]): GeneratedFunctionInfo { + const [parameters, encodeCode, encodeCalls] = typesToEncode.reduce( + ([parameters, encodeCode, encodeCalls], type, index) => { + const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); + const prefix = `arg_${index}`; + if (isDynamicArray(type)) { + // Handle dynamic arrays + assert(cairoType instanceof CairoDynArray); + const arrayName = `${prefix}_dynamic`; + const auxFunc = this.getOrCreateAuxiliar(type); + return [ + [...parameters, ` ${arrayName} : ${cairoType.typeName}`], + [ + ...encodeCode, + `assert decode_array[total_size] = ${arrayName}.len;`, + `let total_size = total_size + 1;`, + `let (total_size) = ${auxFunc.name}(total_size, decode_array, 0, ${arrayName}.len, ${arrayName}.ptr);`, + ], + [...encodeCalls, auxFunc], + ]; + } else if (type instanceof ArrayType) { + // Handle static arrays + const auxFunc = this.getOrCreateAuxiliar(type); + return [ + [...parameters, `${prefix}_static : ${cairoType.toString()}`], + [ + ...encodeCode, + `let (total_size) = ${auxFunc.name}(total_size, decode_array, ${prefix}_static);`, + ], + [...encodeCalls, auxFunc], + ]; + } else if (isStruct(type)) { + // Handle structs + assert(cairoType instanceof CairoStruct); + const auxFuncName = this.getOrCreateAuxiliar(type); + return [ + [...parameters, `${prefix}_${cairoType.name} : ${cairoType.typeName}`], + [ + ...encodeCode, + `let (total_size) = ${auxFuncName.name}(total_size, decode_array, ${prefix}_${cairoType.name});`, + ], + [...encodeCalls, auxFuncName], + ]; + } else if (isValueType(type)) { + // Handle value types + return [ + [...parameters, `${prefix} : ${cairoType.typeName}`], + [ + ...encodeCode, + cairoType.width > 1 + ? [ + `assert decode_array[total_size] = ${prefix}.low;`, + `assert decode_array[total_size + 1] = ${prefix}.high;`, + `let total_size = total_size + 2;`, + ].join('\n') + : [ + `assert decode_array[total_size] = ${prefix};`, + `let total_size = total_size + 1;`, + ].join('\n'), + ], + [...encodeCalls], + ]; + } throw new WillNotSupportError( `Decoding ${printTypeNode(type)} into felt dynamic array is not supported`, ); - } - }); + }, + [new Array(), new Array(), new Array()], + ); - const resultStruct = this.externalArrayGen.getOrCreate(new BytesType()); + const resultStruct = this.externalArrayGen.getOrCreateFuncDef(new BytesType()); const cairoParams = parameters.join(','); - const funcName = `encode_as_felt${this.generatedFunctions.size}`; + const funcName = `encode_as_felt${this.generatedFunctionsDef.size}`; const code = [ - `func ${funcName}${IMPLICITS}(${cairoParams}) -> (calldata_array : ${resultStruct}){`, + `func ${funcName}${IMPLICITS}(${cairoParams}) -> (calldata_array : ${resultStruct.name}){`, ` alloc_locals;`, ` let total_size : felt = 0;`, ` let (decode_array : felt*) = alloc();`, ...encodeCode, - ` let result = ${resultStruct}(total_size, decode_array);`, + ` let result = ${resultStruct.name}(total_size, decode_array);`, ` return (result,);`, `}`, ].join('\n'); - this.requireImport('starkware.cairo.common.alloc', 'alloc'); - this.generatedFunctions.set(key, { name: funcName, code: code }); - return funcName; + const importFunc = this.requireImport('starkware.cairo.common.alloc', 'alloc'); + + const funcInfo = { + name: funcName, + code: code, + functionsCalled: [importFunc, ...encodeCalls, resultStruct].filter(notUndefined), + }; + return funcInfo; } /** @@ -184,12 +208,12 @@ export class EncodeAsFelt extends StringIndexedFuncGen { * @param type to encode (only arrays and structs allowed) * @returns name of the generated function */ - private getOrCreateAuxiliar(type: TypeNode): string { + private getOrCreateAuxiliar(type: TypeNode): CairoFunctionDefinition { const key = type.pp(); const existing = this.auxiliarGeneratedFunctions.get(key); if (existing !== undefined) { - return existing.name; + return existing; } const unexpectedTypeFunc = () => { @@ -198,7 +222,7 @@ export class EncodeAsFelt extends StringIndexedFuncGen { ); }; - const cairoFunc = delegateBasedOnType( + const cairoFunc = delegateBasedOnType( type, (type) => this.generateDynamicArrayEncodeFunction(type), (type) => this.generateStaticArrayEncodeFunction(type), @@ -208,7 +232,7 @@ export class EncodeAsFelt extends StringIndexedFuncGen { ); this.auxiliarGeneratedFunctions.set(key, cairoFunc); - return cairoFunc.name; + return cairoFunc; } /** @@ -219,31 +243,41 @@ export class EncodeAsFelt extends StringIndexedFuncGen { * @param currentElementName cairo variable to encode to felt * @returns generated code */ - private generateEncodeCode(type: TypeNode, currentElementName: string): string[] { + private generateEncodeCode( + type: TypeNode, + currentElementName: string, + ): [string[], CairoFunctionDefinition[]] { if (isValueType(type)) { const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); - return cairoType.width > 1 - ? [ - `assert to_array[to_index] = ${currentElementName}.low;`, - `assert to_array[to_index + 1] = ${currentElementName}.high;`, - `let to_index = to_index + 2;`, - ] - : [`assert to_array[to_index] = ${currentElementName};`, `let to_index = to_index + 1;`]; + return [ + cairoType.width === 2 + ? [ + `assert to_array[to_index] = ${currentElementName}.low;`, + `assert to_array[to_index + 1] = ${currentElementName}.high;`, + `let to_index = to_index + 2;`, + ] + : [`assert to_array[to_index] = ${currentElementName};`, `let to_index = to_index + 1;`], + [], + ]; } const auxFuncName = this.getOrCreateAuxiliar(type); - return [`let (to_index) = ${auxFuncName}(to_index, to_array, ${currentElementName});`]; + return [ + [`let (to_index) = ${auxFuncName.name}(to_index, to_array, ${currentElementName});`], + [auxFuncName], + ]; } private generateDynamicArrayEncodeFunction( type: ArrayType | BytesType | StringType, - ): CairoFunction { + ): CairoGeneratedFunctionDefinition { const cairoElementType = CairoType.fromSol( getElementType(type), this.ast, TypeConversionContext.CallDataRef, ); const elemenT = getElementType(type); + const [encodingCode, encodingCalls] = this.generateEncodeCode(elemenT, 'current_element'); const funcName = `encode_dynamic_array${this.auxiliarGeneratedFunctions.size}`; const code = [ `func ${funcName}${IMPLICITS}(`, @@ -258,50 +292,97 @@ export class EncodeAsFelt extends StringIndexedFuncGen { ` return (total_copied=to_index,);`, ` }`, ` let current_element = from_array[from_index];`, - ...this.generateEncodeCode(elemenT, 'current_element'), + ...encodingCode, ` return ${funcName}(to_index, to_array, from_index + 1, from_size, from_array);`, `}`, ]; - return { name: funcName, code: code.join('\n') }; + const funcInfo: GeneratedFunctionInfo = { + name: funcName, + code: code.join('\n'), + functionsCalled: encodingCalls, + }; + + return createCairoGeneratedFunction(funcInfo, [], [], this.ast, this.sourceUnit); } - private generateStructEncodeFunction(type: UserDefinedType): CairoFunction { + private generateStructEncodeFunction(type: UserDefinedType): CairoGeneratedFunctionDefinition { assert(type.definition instanceof StructDefinition); - const encodeCode = type.definition.vMembers.map((varDecl, index) => { - const varType = safeGetNodeType(varDecl, this.ast.inference); - return [ - `let member_${index} = from_struct.${varDecl.name};`, - ...this.generateEncodeCode(varType, `member_${index}`), - ].join('\n'); - }); + const [encodeCode, encodeCalls] = type.definition.vMembers.reduce( + ([encodeCode, encodeCalls], varDecl, index) => { + const varType = safeGetNodeType(varDecl, this.ast.inference); + const [memberEncodeCode, memberEncodeCalls] = this.generateEncodeCode( + varType, + `member_${index}`, + ); + return [ + [ + ...encodeCode, + `let member_${index} = from_struct.${varDecl.name};`, + ...memberEncodeCode, + ], + [...encodeCalls, ...memberEncodeCalls], + ]; + }, + [new Array(), new Array()], + ); const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); assert(cairoType instanceof CairoStruct); const funcName = `encode_struct_${cairoType.name}`; - const code = [ - `func ${funcName}${IMPLICITS}(`, - ` to_index : felt, to_array : felt*, from_struct : ${cairoType.toString()}`, - `) -> (total_copied : felt){`, - ` alloc_locals;`, - ...encodeCode, - ` return (to_index,);`, - `}`, - ]; - return { name: funcName, code: code.join('\n') }; + return this.createAuxiliarGeneratedFunction({ + name: funcName, + code: [ + `func ${funcName}${IMPLICITS}(`, + ` to_index : felt, to_array : felt*, from_struct : ${cairoType.toString()}`, + `) -> (total_copied : felt){`, + ` alloc_locals;`, + ...encodeCode, + ` return (to_index,);`, + `}`, + ].join('\n'), + functionsCalled: encodeCalls, + }); } - private generateStaticArrayEncodeFunction(type: ArrayType): CairoFunction { + // TODO: Do a small version of static array encoding + private generateStaticArrayEncodeFunction(type: ArrayType): CairoGeneratedFunctionDefinition { assert(type.size !== undefined); const cairoType = CairoType.fromSol(type, this.ast, TypeConversionContext.CallDataRef); const elemenT = type.elementT; + + const cairoElementT = CairoType.fromSol(elemenT, this.ast, TypeConversionContext.CallDataRef); + + let staticArrayEncoding: (element: string) => string[]; + let funcsCalled: CairoFunctionDefinition[]; + if (isValueType(elemenT)) { + staticArrayEncoding = + cairoElementT.width === 2 + ? (element: string) => [ + `assert to_array[to_index] = ${element}.low;`, + `assert to_array[to_index + 1] = ${element}.high;`, + `let to_index = to_index + 2;`, + ] + : (element: string) => [ + `assert to_array[to_index] = ${element};`, + `let to_index = to_index + 1;`, + ]; + funcsCalled = []; + } else { + const auxFunc = this.getOrCreateAuxiliar(elemenT); + staticArrayEncoding = (element: string) => [ + `let (to_index) = ${auxFunc.name}(to_index, to_array, ${element});`, + ]; + funcsCalled = [auxFunc]; + } + const encodeCode = mapRange(narrowBigIntSafe(type.size), (index) => { return [ `let elem_${index} = from_static_array[${index}];`, - ...this.generateEncodeCode(elemenT, `elem_${index}`), + ...staticArrayEncoding(`elem_${index}`), ].join('\n'); }); @@ -312,7 +393,12 @@ export class EncodeAsFelt extends StringIndexedFuncGen { ...encodeCode, ` return (to_index,);`, `}`, - ]; - return { name: funcName, code: code.join('\n') }; + ].join('\n'); + + return this.createAuxiliarGeneratedFunction({ + name: funcName, + code: code, + functionsCalled: funcsCalled, + }); } } diff --git a/src/cairoWriter/index.ts b/src/cairoWriter/index.ts index b467a152d..4b966538c 100644 --- a/src/cairoWriter/index.ts +++ b/src/cairoWriter/index.ts @@ -61,6 +61,8 @@ import { CairoAssert, CairoContract, CairoFunctionDefinition, + CairoGeneratedFunctionDefinition, + CairoImportFunctionDefinition, CairoTempVarStatement, } from '../ast/cairoNodes'; import { @@ -90,6 +92,8 @@ import { StructuredDocumentationWriter, TupleExpressionWriter, VariableDeclarationStatementWriter, + CairoGeneratedFunctionDefinitionWriter, + CairoImportFunctionDefinitionWriter, VariableDeclarationWriter, } from './writers'; @@ -103,6 +107,14 @@ export const CairoASTMapping = (ast: AST, throwOnUnimplemented: boolean) => [CairoAssert, new CairoAssertWriter(ast, throwOnUnimplemented)], [CairoContract, new CairoContractWriter(ast, throwOnUnimplemented)], [CairoFunctionDefinition, new CairoFunctionDefinitionWriter(ast, throwOnUnimplemented)], + [ + CairoGeneratedFunctionDefinition, + new CairoGeneratedFunctionDefinitionWriter(ast, throwOnUnimplemented), + ], + [ + CairoImportFunctionDefinition, + new CairoImportFunctionDefinitionWriter(ast, throwOnUnimplemented), + ], [CairoTempVarStatement, new CairoTempVarWriter(ast, throwOnUnimplemented)], [Conditional, new NotImplementedWriter(ast, throwOnUnimplemented)], [Continue, new NotImplementedWriter(ast, throwOnUnimplemented)], diff --git a/src/cairoWriter/utils.ts b/src/cairoWriter/utils.ts index 6cdd3c1f2..84bff7349 100644 --- a/src/cairoWriter/utils.ts +++ b/src/cairoWriter/utils.ts @@ -1,6 +1,5 @@ import assert from 'assert'; import { ASTNode, ASTWriter, SourceUnit, StructuredDocumentation } from 'solc-typed-ast'; -import { mergeImports } from '../utils/utils'; export const INDENT = ' '.repeat(4); export const INCLUDE_CAIRO_DUMP_FUNCTIONS = false; @@ -16,21 +15,6 @@ export function getDocumentation( : ''; } -export function writeImports(imports: Map>): string { - if (INCLUDE_CAIRO_DUMP_FUNCTIONS) { - imports = mergeImports( - imports, - new Map([['starkware.cairo.common.alloc', new Set(['alloc'])]]), - ); - } - return [...imports.entries()] - .map( - ([location, importedSymbols]) => - `from ${location} import ${[...importedSymbols.keys()].join(', ')}`, - ) - .join('\n'); -} - export function getInterfaceNameForContract( contractName: string, nodeInSourceUnit: ASTNode, diff --git a/src/cairoWriter/writers/cairoContractWriter2.ts b/src/cairoWriter/writers/cairoContractWriter2.ts deleted file mode 100644 index 0b8b0bd2e..000000000 --- a/src/cairoWriter/writers/cairoContractWriter2.ts +++ /dev/null @@ -1,38 +0,0 @@ -// import { ASTWriter, ContractKind, SrcDesc } from 'solc-typed-ast'; -// import { isExternallyVisible } from '../../utils/utils'; -// import { CairoContract } from '../../ast/cairoNodes'; -// import { TEMP_INTERFACE_SUFFIX } from '../../utils/nameModifiers'; -// import { CairoASTNodeWriter } from '../base'; -// import { getDocumentation, getInterfaceNameForContract, INDENT } from '../utils'; -// import { interfaceNameMappings } from './sourceUnitWriter'; -// import { TranspileFailedError } from '../../export'; -// -// export class CairoContractWriter extends CairoASTNodeWriter { -// writeInner(node: CairoContract, writer: ASTWriter): SrcDesc { -// // TODO: Deal with interfaces -// // TODO: Deal with abstracts (Currently are being dropped) -// if (node.kind === ContractKind.Interface || node.abstract) { -// throw new TranspileFailedError('Cannot transpile abstract contracts or interfaces'); -// } -// -// // TODO: Figure out constants outside function definitions -// // const staticVariables = [...node.staticStorageAllocations.entries()].map( -// // ([decl, loc]) => `const ${decl.name} = ${loc};`, -// // ); -// // const dynamicVariables = [...node.dynamicStorageAllocations.entries()].map( -// // ([decl, loc]) => `const ${decl.name} = ${loc};`, -// // ); -// -// const storageVars = ['WARP_STORAGE', 'WARP_USED_STORAGE', 'WARP']; -// -// const externalFunctions = node.vFunctions -// .filter((func) => isExternallyVisible(func)) -// .map((func) => writer.write(func)); -// -// const otherFunctions = node.vFunctions -// .filter((func) => !isExternallyVisible(func)) -// .map((func) => writer.write(func)); -// } -// -// writeWhole(node: CairoContract, writer: ASTWriter): SrcDesc {} -// } diff --git a/src/cairoWriter/writers/cairoGeneratedFunctionDefinitionWriter.ts b/src/cairoWriter/writers/cairoGeneratedFunctionDefinitionWriter.ts new file mode 100644 index 000000000..0044fc5f0 --- /dev/null +++ b/src/cairoWriter/writers/cairoGeneratedFunctionDefinitionWriter.ts @@ -0,0 +1,9 @@ +import { ASTWriter, SrcDesc } from 'solc-typed-ast'; +import { CairoGeneratedFunctionDefinition } from '../../ast/cairoNodes/cairoGeneratedFunctionDefinition'; +import { CairoASTNodeWriter } from '../base'; + +export class CairoGeneratedFunctionDefinitionWriter extends CairoASTNodeWriter { + writeInner(node: CairoGeneratedFunctionDefinition, _writer: ASTWriter): SrcDesc { + return [node.rawStringDefinition]; + } +} diff --git a/src/cairoWriter/writers/cairoImportFunctionDefinitionWriter.ts b/src/cairoWriter/writers/cairoImportFunctionDefinitionWriter.ts new file mode 100644 index 000000000..2d8ec444f --- /dev/null +++ b/src/cairoWriter/writers/cairoImportFunctionDefinitionWriter.ts @@ -0,0 +1,10 @@ +import { ASTWriter, SrcDesc } from 'solc-typed-ast'; +import { CairoImportFunctionDefinition } from '../../ast/cairoNodes'; +import { CairoASTNodeWriter } from '../base'; + +// Not being used as for now +export class CairoImportFunctionDefinitionWriter extends CairoASTNodeWriter { + writeInner(node: CairoImportFunctionDefinition, _writer: ASTWriter): SrcDesc { + return [`from ${node.path} import ${node.name}`]; + } +} diff --git a/src/cairoWriter/writers/functionCallWriter.ts b/src/cairoWriter/writers/functionCallWriter.ts index 46ade7a10..19a1ede35 100644 --- a/src/cairoWriter/writers/functionCallWriter.ts +++ b/src/cairoWriter/writers/functionCallWriter.ts @@ -11,6 +11,7 @@ import { StructDefinition, UserDefinedType, } from 'solc-typed-ast'; +import { CairoGeneratedFunctionDefinition } from '../../ast/cairoNodes'; import { CairoFunctionDefinition, isDynamicArray, @@ -63,6 +64,14 @@ export class FunctionCallWriter extends CairoASTNodeWriter { }${args})`, ]; } + } else if ( + node.vReferencedDeclaration instanceof CairoGeneratedFunctionDefinition && + node.vReferencedDeclaration.rawStringDefinition.includes('@storage_var') + ) { + return node.vArguments.length === + node.vReferencedDeclaration.vParameters.vParameters.length + ? [`${func}.read(${args})`] + : [`${func}.write(${args})`]; } else if ( node.vReferencedDeclaration instanceof CairoFunctionDefinition && (node.vReferencedDeclaration.acceptsRawDarray || diff --git a/src/cairoWriter/writers/index.ts b/src/cairoWriter/writers/index.ts index 7dd522648..0b616505d 100644 --- a/src/cairoWriter/writers/index.ts +++ b/src/cairoWriter/writers/index.ts @@ -4,6 +4,8 @@ export * from './blockWriter'; export * from './cairoAssertWriter'; export * from './cairoContractWriter'; export * from './cairoFunctionDefinitionWriter'; +export * from './cairoGeneratedFunctionDefinitionWriter'; +export * from './cairoImportFunctionDefinitionWriter'; export * from './cairoTempVarWriter'; export * from './elementaryTypeNameExpressionWriter'; export * from './emitStatementWriter'; diff --git a/src/cairoWriter/writers/sourceUnitWriter.ts b/src/cairoWriter/writers/sourceUnitWriter.ts index 7c3244c00..f936f9743 100644 --- a/src/cairoWriter/writers/sourceUnitWriter.ts +++ b/src/cairoWriter/writers/sourceUnitWriter.ts @@ -1,10 +1,14 @@ import assert from 'assert'; import { ASTWriter, ContractKind, SourceUnit, SrcDesc } from 'solc-typed-ast'; +import { + CairoImportFunctionDefinition, + CairoGeneratedFunctionDefinition, + FunctionStubKind, +} from '../../ast/cairoNodes'; import { getStructsAndRemappings } from '../../freeStructWritter'; import { removeExcessNewlines } from '../../utils/formatting'; import { TEMP_INTERFACE_SUFFIX } from '../../utils/nameModifiers'; import { CairoASTNodeWriter } from '../base'; -import { writeImports } from '../utils'; // Used by: // -> CairoContractWriter @@ -23,7 +27,7 @@ export class SourceUnitWriter extends CairoASTNodeWriter { ? node.vContracts.filter((cd) => !cd.name.endsWith(TEMP_INTERFACE_SUFFIX)) : node.vContracts; - assert(mainContract_.length <= 1, 'xx'); + assert(mainContract_.length <= 1); const [mainContract] = mainContract_; const [freeStructs, freeStructRemappings_] = mainContract @@ -41,21 +45,52 @@ export class SourceUnitWriter extends CairoASTNodeWriter { writer.write(v), ); - const functions = node.vFunctions.map((v) => writer.write(v)); + const importFunctions = node.vFunctions.filter( + (f): f is CairoImportFunctionDefinition => f instanceof CairoImportFunctionDefinition, + ); + const generatedFunctions = node.vFunctions.filter( + (f): f is CairoGeneratedFunctionDefinition => f instanceof CairoGeneratedFunctionDefinition, + ); + const functions = node.vFunctions.filter( + (f) => + !(f instanceof CairoGeneratedFunctionDefinition) && + !(f instanceof CairoImportFunctionDefinition), + ); + + const writtenImportFuncs = getGroupedImports( + importFunctions + .sort((funcA, funcB) => + `${funcA.path}.${funcA.name}`.localeCompare(`${funcB.path}.${funcB.name}`), + ) + .filter((func, index, importFuncs) => func.name !== importFuncs[index - 1]?.name), + ).reduce((writtenImports, importFunc) => `${writtenImports}\n${importFunc}`, ''); + + const writtenGeneratedFuncs = generatedFunctions + .sort((funcA, funcB) => funcA.name.localeCompare(funcB.name)) + .sort((funcA, funcB) => { + const stubA = funcA.functionStubKind; + const stubB = funcB.functionStubKind; + if (stubA === stubB) return 0; + if (stubA === FunctionStubKind.StructDefStub) return -1; + if (stubA === FunctionStubKind.StorageDefStub) return -1; + return 1; + }) + .filter((func, index, genFuncs) => func.name !== genFuncs[index - 1]?.name) + .map((func) => writer.write(func)); + + const writtenFuncs = functions.map((func) => writer.write(func)); const contracts = node.vContracts.map((v) => writer.write(v)); - const generatedUtilFunctions = this.ast.getUtilFuncGen(node).getGeneratedCode(); - const imports = writeImports(this.ast.getImports(node)); return [ removeExcessNewlines( [ '%lang starknet', - [imports], + writtenImportFuncs, ...constants, ...structs, - generatedUtilFunctions, - ...functions, + ...writtenGeneratedFuncs, + ...writtenFuncs, ...contracts, ].join('\n\n\n'), 3, @@ -86,3 +121,16 @@ export class SourceUnitWriter extends CairoASTNodeWriter { interfaceNameMappings.set(node, map); } } + +function getGroupedImports(imports: CairoImportFunctionDefinition[]): string[] { + const processedImports: string[] = []; + imports.reduce((functionNames: string[], importNode, index) => { + functionNames.push(importNode.name); + if (importNode.path !== imports[index + 1]?.path) { + processedImports.push(`from ${importNode.path} import ${functionNames.join(', ')}`); + functionNames = []; + } + return functionNames; + }, []); + return processedImports; +} diff --git a/src/export.ts b/src/export.ts index 41c7a2cbd..310b89528 100644 --- a/src/export.ts +++ b/src/export.ts @@ -1,5 +1,5 @@ export * from './ast/export'; -export * from './cairoUtilFuncGen/export'; +export * from './cairoUtilFuncGen'; export * from './cairoWriter'; export * from './cli'; export * from './freeStructWritter'; diff --git a/src/passes/annotateImplicits.ts b/src/passes/annotateImplicits.ts index 9b4c58386..cfce32be6 100644 --- a/src/passes/annotateImplicits.ts +++ b/src/passes/annotateImplicits.ts @@ -11,11 +11,11 @@ import { CairoFunctionDefinition, FunctionStubKind } from '../ast/cairoNodes'; import { ASTMapper } from '../ast/mapper'; import { ASTVisitor } from '../ast/visitor'; import { printNode } from '../utils/astPrinter'; -import { TranspileFailedError } from '../utils/errors'; -import { Implicits, implicitTypes, registerImportsForImplicit } from '../utils/implicits'; +import { Implicits, registerImportsForImplicit } from '../utils/implicits'; import { isExternallyVisible, union } from '../utils/utils'; import { getDocString, isCairoStub } from './cairoStubProcessor'; import { EMIT_PREFIX } from '../export'; +import { parseImplicits } from '../utils/cairoParsing'; export class AnnotateImplicits extends ASTMapper { // Function to add passes that should have been run before this pass @@ -56,7 +56,6 @@ export class AnnotateImplicits extends ASTMapper { node.raw, ); ast.replaceNode(node, annotatedFunction); - ast.copyRegisteredImports(node, annotatedFunction); implicits.forEach((i) => registerImportsForImplicit(ast, annotatedFunction, i)); node.children.forEach((child) => this.dispatchVisit(child, ast)); } @@ -146,32 +145,6 @@ function extractImplicitFromStubs(node: FunctionDefinition, result: Set impl1 : type1, impl2, ..., impln : typen - const implicits = funcSignature[1]; - - // implicitsList -> [impl1 : type1, impl2, ...., impln : typen] - const implicitsList = [...implicits.matchAll(/[A-Za-z][A-Za-z_: 0-9]*/g)].map((w) => w[0]); - - // implicitsNameList -> [impl1, impl2, ..., impln] - const implicitsNameList = implicitsList.map((i) => i.match(/[A-Za-z][A-Za-z_0-9]*/)); - if (!notContainsNull(implicitsNameList)) return; - - // Check that implicits are valid and add them to result - implicitsNameList.forEach((i) => { - const impl = i[0]; - if (!elementIsImplicit(impl)) { - throw new TranspileFailedError( - `Implicit ${impl} defined on function stub (${printNode(node)}) is not known`, - ); - } - result.add(impl); - }); -} - -function elementIsImplicit(e: string): e is Implicits { - return Object.keys(implicitTypes).includes(e); -} - -function notContainsNull(l: (T | null)[]): l is T[] { - return !l.some((e) => e === null); + const implicits = parseImplicits(funcSignature[1]); + implicits.forEach((impl) => result.add(impl)); } diff --git a/src/passes/argBoundChecker.ts b/src/passes/argBoundChecker.ts index 001781131..a9836fe76 100644 --- a/src/passes/argBoundChecker.ts +++ b/src/passes/argBoundChecker.ts @@ -32,9 +32,7 @@ export class ArgBoundChecker extends ASTMapper { node.vParameters.vParameters.forEach((decl) => { const type = safeGetNodeType(decl, ast.inference); if (checkableType(type)) { - const functionCall = ast - .getUtilFuncGen(node) - .boundChecks.inputCheck.gen(decl, type, node); + const functionCall = ast.getUtilFuncGen(node).boundChecks.inputCheck.gen(decl, type); this.insertFunctionCall(node, functionCall, ast); } }); diff --git a/src/passes/builtinHandler/blockMethods.ts b/src/passes/builtinHandler/blockMethods.ts index cc5a39c06..6cfc4553b 100644 --- a/src/passes/builtinHandler/blockMethods.ts +++ b/src/passes/builtinHandler/blockMethods.ts @@ -1,7 +1,7 @@ import { MemberAccess, Identifier, ExternalReferenceType, Expression } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; import { ASTMapper } from '../../ast/mapper'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCallToFunction } from '../../utils/functionGeneration'; import { createUint256TypeName } from '../../utils/nodeTemplates'; /* @@ -22,34 +22,30 @@ export class BlockMethods extends ASTMapper { ) { if (node.memberName === 'number') { const replacementCall = createCallToFunction( - createCairoFunctionStub( + ast.registerImport( + node, + 'warplib.block_methods', 'warp_block_number', [], [['block_num', createUint256TypeName(ast)]], - ['syscall_ptr', 'range_check_ptr'], - ast, - node, ), [], ast, ); ast.replaceNode(node, replacementCall); - ast.registerImport(replacementCall, 'warplib.block_methods', 'warp_block_number'); } else if (node.memberName === 'timestamp') { const replacementCall = createCallToFunction( - createCairoFunctionStub( + ast.registerImport( + node, + 'warplib.block_methods', 'warp_block_timestamp', [], [['block_timestamp', createUint256TypeName(ast)]], - ['syscall_ptr', 'range_check_ptr'], - ast, - node, ), [], ast, ); ast.replaceNode(node, replacementCall); - ast.registerImport(replacementCall, 'warplib.block_methods', 'warp_block_timestamp'); } } else { this.visitExpression(node, ast); diff --git a/src/passes/builtinHandler/ecrecover.ts b/src/passes/builtinHandler/ecrecover.ts index 08f74c88f..86c52d19b 100644 --- a/src/passes/builtinHandler/ecrecover.ts +++ b/src/passes/builtinHandler/ecrecover.ts @@ -1,7 +1,7 @@ import { ExternalReferenceType, FunctionCall } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; import { ASTMapper } from '../../ast/mapper'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCallToFunction } from '../../utils/functionGeneration'; import { createUintNTypeName } from '../../utils/nodeTemplates'; export class Ecrecover extends ASTMapper { @@ -15,7 +15,9 @@ export class Ecrecover extends ASTMapper { return this.commonVisit(node, ast); } - const ecrecoverEth = createCairoFunctionStub( + const ecrecoverEth = ast.registerImport( + node, + 'warplib.ecrecover', 'ecrecover_eth', [ ['msg_hash', createUintNTypeName(256, ast)], @@ -24,13 +26,7 @@ export class Ecrecover extends ASTMapper { ['s', createUintNTypeName(256, ast)], ], [['eth_address', createUintNTypeName(160, ast)]], - ['range_check_ptr', 'bitwise_ptr', 'keccak_ptr'], - ast, - node, ); - - ast.registerImport(node, 'warplib.ecrecover', 'ecrecover_eth'); - ast.replaceNode(node, createCallToFunction(ecrecoverEth, node.vArguments, ast)); } } diff --git a/src/passes/builtinHandler/explicitConversionToFunc.ts b/src/passes/builtinHandler/explicitConversionToFunc.ts index 21580da77..9183c06e4 100644 --- a/src/passes/builtinHandler/explicitConversionToFunc.ts +++ b/src/passes/builtinHandler/explicitConversionToFunc.ts @@ -25,7 +25,7 @@ import { NotSupportedYetError } from '../../utils/errors'; import { createAddressTypeName, createUint256TypeName } from '../../utils/nodeTemplates'; import { bigintToTwosComplement, toHexString } from '../../utils/utils'; import { functionaliseIntConversion } from '../../warplib/implementations/conversions/int'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCallToFunction } from '../../utils/functionGeneration'; import { functionaliseFixedBytesConversion } from '../../warplib/implementations/conversions/fixedBytes'; import { functionaliseBytesToFixedBytes } from '../../warplib/implementations/conversions/dynBytesToFixed'; import { safeGetNodeType } from '../../utils/nodeTypeProcessing'; @@ -79,20 +79,17 @@ export class ExplicitConversionToFunc extends ASTMapper { functionaliseIntConversion(node, ast); } else if (argType instanceof AddressType) { const replacementCall = createCallToFunction( - createCairoFunctionStub( + ast.registerImport( + node, + 'warplib.maths.utils', 'felt_to_uint256', [['address_arg', createAddressTypeName(false, ast)]], [['uint_ret', createUint256TypeName(ast)]], - ['range_check_ptr'], - ast, - node, ), [node.vArguments[0]], ast, ); - ast.replaceNode(node, replacementCall); - ast.registerImport(replacementCall, 'warplib.maths.utils', 'felt_to_uint256'); } else { throw new NotSupportedYetError( `Unexpected type ${printTypeNode(argType)} in uint256 conversion`, @@ -107,20 +104,17 @@ export class ExplicitConversionToFunc extends ASTMapper { (argType instanceof FixedBytesType && argType.size === 32) ) { const replacementCall = createCallToFunction( - createCairoFunctionStub( + ast.registerImport( + node, + 'warplib.maths.utils', 'uint256_to_address_felt', [['uint_arg', createUint256TypeName(ast)]], [['address_ret', createAddressTypeName(false, ast)]], - [], - ast, - node, ), [node.vArguments[0]], ast, ); - ast.replaceNode(node, replacementCall); - ast.registerImport(replacementCall, 'warplib.maths.utils', 'uint256_to_address_felt'); } else { ast.replaceNode(node, node.vArguments[0]); } @@ -130,20 +124,18 @@ export class ExplicitConversionToFunc extends ASTMapper { if (typeTo instanceof FixedBytesType) { if (argType instanceof AddressType) { const replacementCall = createCallToFunction( - createCairoFunctionStub( + ast.registerImport( + node, + 'warplib.maths.utils', 'felt_to_uint256', [['address_arg', createAddressTypeName(false, ast)]], [['uint_ret', createUint256TypeName(ast)]], - ['range_check_ptr'], - ast, - node, ), [node.vArguments[0]], ast, ); ast.replaceNode(node, replacementCall); - ast.registerImport(replacementCall, 'warplib.maths.utils', 'felt_to_uint256'); return; } else if (argType instanceof BytesType) { functionaliseBytesToFixedBytes(node, typeTo, ast); diff --git a/src/passes/builtinHandler/keccak.ts b/src/passes/builtinHandler/keccak.ts index 5a0857a7b..967a6c76c 100644 --- a/src/passes/builtinHandler/keccak.ts +++ b/src/passes/builtinHandler/keccak.ts @@ -1,7 +1,7 @@ import { DataLocation, ExternalReferenceType, FunctionCall } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; import { ASTMapper } from '../../ast/mapper'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCallToFunction } from '../../utils/functionGeneration'; import { createArrayTypeName, createBytesNTypeName, @@ -19,16 +19,14 @@ export class Keccak extends ASTMapper { return this.commonVisit(node, ast); } - const warpKeccak = createCairoFunctionStub( + const warpKeccak = ast.registerImport( + node, + 'warplib.keccak', 'warp_keccak', [['input', createArrayTypeName(createUintNTypeName(8, ast), ast), DataLocation.Memory]], [['hash', createBytesNTypeName(32, ast)]], - ['range_check_ptr', 'bitwise_ptr', 'warp_memory', 'keccak_ptr'], - ast, - node, ); - ast.registerImport(node, 'warplib.keccak', 'warp_keccak'); ast.replaceNode(node, createCallToFunction(warpKeccak, node.vArguments, ast)); this.commonVisit(node, ast); diff --git a/src/passes/builtinHandler/mathsOperationToFunction.ts b/src/passes/builtinHandler/mathsOperationToFunction.ts index 10aff008b..aa1a23d3e 100644 --- a/src/passes/builtinHandler/mathsOperationToFunction.ts +++ b/src/passes/builtinHandler/mathsOperationToFunction.ts @@ -8,7 +8,7 @@ import { import { AST } from '../../ast/ast'; import { ASTMapper } from '../../ast/mapper'; import { NotSupportedYetError } from '../../utils/errors'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCallToFunction } from '../../utils/functionGeneration'; import { createNumberLiteral, createUint256TypeName } from '../../utils/nodeTemplates'; import { functionaliseAdd } from '../../warplib/implementations/maths/add'; import { functionaliseAnd } from '../../warplib/implementations/maths/and'; @@ -105,20 +105,18 @@ export class MathsOperationToFunction extends ASTMapper { ) { if (['mulmod', 'addmod'].includes(node.vExpression.name)) { const name = `warp_${node.vExpression.name}`; - const cairoStub = createCairoFunctionStub( + const importedFunc = ast.registerImport( + node, + `warplib.maths.${node.vExpression.name}`, name, [ ['x', createUint256TypeName(ast)], ['y', createUint256TypeName(ast)], ], [['res', createUint256TypeName(ast)]], - [], - ast, - node, ); - const replacement = createCallToFunction(cairoStub, node.vArguments, ast); + const replacement = createCallToFunction(importedFunc, node.vArguments, ast); ast.replaceNode(node, replacement); - ast.registerImport(replacement, `warplib.maths.${node.vExpression.name}`, name); } } } diff --git a/src/passes/builtinHandler/msgSender.ts b/src/passes/builtinHandler/msgSender.ts index b6109345d..46c7207a6 100644 --- a/src/passes/builtinHandler/msgSender.ts +++ b/src/passes/builtinHandler/msgSender.ts @@ -1,7 +1,7 @@ import { MemberAccess, Identifier, ExternalReferenceType } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; import { ASTMapper } from '../../ast/mapper'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCallToFunction } from '../../utils/functionGeneration'; import { createAddressTypeName } from '../../utils/nodeTemplates'; export class MsgSender extends ASTMapper { @@ -13,23 +13,17 @@ export class MsgSender extends ASTMapper { node.memberName === 'sender' ) { const replacementCall = createCallToFunction( - createCairoFunctionStub( + ast.registerImport( + node, + 'starkware.starknet.common.syscalls', 'get_caller_address', [], [['address', createAddressTypeName(false, ast)]], - ['syscall_ptr'], - ast, - node, ), [], ast, ); ast.replaceNode(node, replacementCall); - ast.registerImport( - replacementCall, - 'starkware.starknet.common.syscalls', - 'get_caller_address', - ); } // Fine to recurse because there is a check that the member access is a Builtin. Therefor a.msg.sender should // not be picked up. diff --git a/src/passes/builtinHandler/thisKeyword.ts b/src/passes/builtinHandler/thisKeyword.ts index 2c12e0a22..4757c6952 100644 --- a/src/passes/builtinHandler/thisKeyword.ts +++ b/src/passes/builtinHandler/thisKeyword.ts @@ -8,7 +8,7 @@ import { } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; import { ASTMapper } from '../../ast/mapper'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCallToFunction } from '../../utils/functionGeneration'; import { CairoContract } from '../../ast/cairoNodes'; import { typeNameFromTypeNode } from '../../utils/utils'; import { @@ -22,23 +22,17 @@ export class ThisKeyword extends ASTMapper { visitIdentifier(node: Identifier, ast: AST): void { if (node.name === 'this') { const replacementCall = createCallToFunction( - createCairoFunctionStub( + ast.registerImport( + node, + 'starkware.starknet.common.syscalls', 'get_contract_address', [], [['address', typeNameFromTypeNode(safeGetNodeType(node, ast.inference), ast)]], - ['syscall_ptr'], - ast, - node, ), [], ast, ); ast.replaceNode(node, replacementCall); - ast.registerImport( - replacementCall, - 'starkware.starknet.common.syscalls', - 'get_contract_address', - ); } else { return; } diff --git a/src/passes/bytesConverter.ts b/src/passes/bytesConverter.ts index b95c92937..d235c0fd3 100644 --- a/src/passes/bytesConverter.ts +++ b/src/passes/bytesConverter.ts @@ -21,7 +21,7 @@ import { } from 'solc-typed-ast'; import { AST } from '../ast/ast'; import { ASTMapper } from '../ast/mapper'; -import { createCairoFunctionStub, createCallToFunction } from '../utils/functionGeneration'; +import { createCallToFunction } from '../utils/functionGeneration'; import { generateExpressionTypeString } from '../utils/getTypeString'; import { typeNameFromTypeNode } from '../utils/utils'; import { @@ -107,21 +107,15 @@ export class BytesConverter extends ASTMapper { callArgs.push(createNumberLiteral(width, ast, 'uint8')); } - const functionStub = createCairoFunctionStub( + const importedFunc = ast.registerImport( + node, + 'warplib.maths.bytes_access', selectWarplibFunction(baseTypeName, indexTypeName), stubParams, [['res', createUint8TypeName(ast)]], - ['bitwise_ptr', 'range_check_ptr'], - ast, - node, ); - const call = createCallToFunction(functionStub, callArgs, ast); - ast.registerImport( - call, - 'warplib.maths.bytes_access', - selectWarplibFunction(baseTypeName, indexTypeName), - ); + const call = createCallToFunction(importedFunc, callArgs, ast); ast.replaceNode(node, call, node.parent); const typeNode = replaceBytesType(safeGetNodeType(call, ast.inference)); call.typeString = generateExpressionTypeString(typeNode); diff --git a/src/passes/cairoUtilImporter.ts b/src/passes/cairoUtilImporter.ts index c51f27809..5fab01fd3 100644 --- a/src/passes/cairoUtilImporter.ts +++ b/src/passes/cairoUtilImporter.ts @@ -1,9 +1,18 @@ -import { ElementaryTypeName, IntType, Literal } from 'solc-typed-ast'; +import { + ElementaryTypeName, + IntType, + Literal, + SourceUnit, + StructDefinition, + UserDefinedType, + VariableDeclaration, +} from 'solc-typed-ast'; import { AST } from '../ast/ast'; import { CairoFunctionDefinition } from '../ast/cairoNodes'; import { ASTMapper } from '../ast/mapper'; +import { createImport } from '../utils/importFuncGenerator'; import { safeGetNodeType } from '../utils/nodeTypeProcessing'; -import { isExternallyVisible, primitiveTypeToCairo } from '../utils/utils'; +import { getContainingSourceUnit, isExternallyVisible, primitiveTypeToCairo } from '../utils/utils'; /* Analyses the tree after all processing has been done to find code the relies on @@ -11,7 +20,10 @@ import { isExternallyVisible, primitiveTypeToCairo } from '../utils/utils'; the warplib maths functions as they are added to the code, but for determining if Uint256 needs to be imported, it's easier to do it here */ + export class CairoUtilImporter extends ASTMapper { + private dummySourceUnit: SourceUnit | undefined; + // Function to add passes that should have been run before this pass addInitialPassPrerequisites(): void { const passKeys: Set = new Set([]); @@ -20,28 +32,47 @@ export class CairoUtilImporter extends ASTMapper { visitElementaryTypeName(node: ElementaryTypeName, ast: AST): void { if (primitiveTypeToCairo(node.name) === 'Uint256') { - ast.registerImport(node, 'starkware.cairo.common.uint256', 'Uint256'); + createImport('starkware.cairo.common.uint256', 'Uint256', this.dummySourceUnit ?? node, ast); } } visitLiteral(node: Literal, ast: AST): void { const type = safeGetNodeType(node, ast.inference); if (type instanceof IntType && type.nBits > 251) { - ast.registerImport(node, 'starkware.cairo.common.uint256', 'Uint256'); + createImport('starkware.cairo.common.uint256', 'Uint256', this.dummySourceUnit ?? node, ast); + } + } + + visitVariableDeclaration(node: VariableDeclaration, ast: AST): void { + const type = safeGetNodeType(node, ast.inference); + if (type instanceof IntType && type.nBits > 251) { + createImport('starkware.cairo.common.uint256', 'Uint256', this.dummySourceUnit ?? node, ast); + } + + // Patch to struct inlining + if (type instanceof UserDefinedType && type.definition instanceof StructDefinition) { + const currentSourceUnit = getContainingSourceUnit(node); + if (currentSourceUnit !== type.definition.getClosestParentByType(SourceUnit)) { + this.dummySourceUnit = this.dummySourceUnit ?? currentSourceUnit; + type.definition.walkChildren((child) => this.commonVisit(child, ast)); + this.dummySourceUnit = + this.dummySourceUnit === currentSourceUnit ? undefined : this.dummySourceUnit; + } } + this.visitExpression(node, ast); } visitCairoFunctionDefinition(node: CairoFunctionDefinition, ast: AST): void { if (node.implicits.has('warp_memory') && isExternallyVisible(node)) { - ast.registerImport(node, 'starkware.cairo.common.default_dict', 'default_dict_new'); - ast.registerImport(node, 'starkware.cairo.common.default_dict', 'default_dict_finalize'); - ast.registerImport(node, 'starkware.cairo.common.dict', 'dict_write'); + createImport('starkware.cairo.common.default_dict', 'default_dict_new', node, ast); + createImport('starkware.cairo.common.default_dict', 'default_dict_finalize', node, ast); + createImport('starkware.cairo.common.dict', 'dict_write', node, ast); } if (node.implicits.has('keccak_ptr') && isExternallyVisible(node)) { - ast.registerImport(node, 'starkware.cairo.common.cairo_keccak.keccak', 'finalize_keccak'); + createImport('starkware.cairo.common.cairo_keccak.keccak', 'finalize_keccak', node, ast); // Required to create a keccak_ptr - ast.registerImport(node, 'starkware.cairo.common.alloc', 'alloc'); + createImport('starkware.cairo.common.alloc', 'alloc', node, ast); } this.commonVisit(node, ast); diff --git a/src/passes/functionPruner/callGraph.ts b/src/passes/functionPruner/callGraph.ts index e4c8869ab..0beb65a78 100644 --- a/src/passes/functionPruner/callGraph.ts +++ b/src/passes/functionPruner/callGraph.ts @@ -1,6 +1,7 @@ import assert from 'assert'; import { FunctionCall, FunctionDefinition } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; +import { CairoGeneratedFunctionDefinition } from '../../ast/cairoNodes/cairoGeneratedFunctionDefinition'; import { ASTMapper } from '../../ast/mapper'; import { printNode } from '../../utils/astPrinter'; @@ -30,6 +31,14 @@ export class CallGraphBuilder extends ASTMapper { this.currentFunction = undefined; } + visitCairoGeneratedFunctionDefinition(node: CairoGeneratedFunctionDefinition, ast: AST): void { + this.currentFunction = node; + this.functionId.set(node.id, node); + this.callGraph.set(node.id, new Set(node.functionsCalled.map((funcDef) => funcDef.id))); + node.functionsCalled.forEach((funcDef) => this.commonVisit(funcDef, ast)); + this.currentFunction = undefined; + } + visitFunctionCall(node: FunctionCall, ast: AST) { assert( this.currentFunction !== undefined, @@ -41,7 +50,7 @@ export class CallGraphBuilder extends ASTMapper { `${printNode(this.currentFunction)} should have been added to the map`, ); const refFunc = node.vReferencedDeclaration; - if (refFunc !== undefined && refFunc instanceof FunctionDefinition) { + if (refFunc instanceof FunctionDefinition) { existingCalls.add(refFunc.id); this.callGraph.set(this.currentFunction.id, existingCalls); } diff --git a/src/passes/functionPruner/functionRemover.ts b/src/passes/functionPruner/functionRemover.ts index 41775f24f..739d9a82c 100644 --- a/src/passes/functionPruner/functionRemover.ts +++ b/src/passes/functionPruner/functionRemover.ts @@ -1,37 +1,47 @@ import assert from 'assert'; -import { ContractDefinition, FunctionDefinition } from 'solc-typed-ast'; +import { ContractDefinition, FunctionDefinition, SourceUnit } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; import { ASTMapper } from '../../ast/mapper'; import { printNode } from '../../utils/astPrinter'; import { isExternallyVisible } from '../../utils/utils'; export class FunctionRemover extends ASTMapper { - functionGraph: Map; + private functionGraph: Map; + private reachableFunctions: Set; constructor(graph: Map) { super(); this.functionGraph = graph; + this.reachableFunctions = new Set(); + } + + visitSourceUnit(node: SourceUnit, ast: AST): void { + node.vFunctions.filter((func) => isExternallyVisible(func)).forEach((func) => this.dfs(func)); + + node.vContracts.forEach((c) => this.visitContractDefinition(c, ast)); + + node.vFunctions + .filter((func) => !this.reachableFunctions.has(func.id)) + .forEach((func) => node.removeChild(func)); } visitContractDefinition(node: ContractDefinition, _ast: AST) { - const reachableFunctions: Set = new Set(); // Collect visible functions and obtain ids of all reachable functions - node.vFunctions - .filter((func) => isExternallyVisible(func)) - .forEach((func) => this.dfs(func, reachableFunctions)); + node.vFunctions.filter((func) => isExternallyVisible(func)).forEach((func) => this.dfs(func)); + // Remove unreachable functions node.vFunctions - .filter((func) => !reachableFunctions.has(func.id)) + .filter((func) => !this.reachableFunctions.has(func.id)) .forEach((func) => node.removeChild(func)); } - dfs(f: FunctionDefinition, visited: Set): void { - visited.add(f.id); + dfs(f: FunctionDefinition): void { + this.reachableFunctions.add(f.id); const functions = this.functionGraph.get(f.id); assert(functions !== undefined, `Function ${printNode(f)} was not added to the functionGraph`); functions.forEach((f) => { - if (!visited.has(f.id)) this.dfs(f, visited); + if (!this.reachableFunctions.has(f.id)) this.dfs(f); }); } } diff --git a/src/passes/newToDeploy.ts b/src/passes/newToDeploy.ts index 197bfff4a..b9ca009aa 100644 --- a/src/passes/newToDeploy.ts +++ b/src/passes/newToDeploy.ts @@ -17,11 +17,7 @@ import { AST } from '../ast/ast'; import { ASTMapper } from '../ast/mapper'; import { UserDefinedTypeName } from 'solc-typed-ast'; import assert from 'assert'; -import { - createCairoFunctionStub, - createCallToFunction, - createElementaryConversionCall, -} from '../utils/functionGeneration'; +import { createCallToFunction, createElementaryConversionCall } from '../utils/functionGeneration'; import { createAddressTypeName, createBoolLiteral, @@ -144,7 +140,9 @@ export class NewToDeploy extends ASTMapper { salt: Expression, ast: AST, ): FunctionCall { - const deployStub = createCairoFunctionStub( + const deployFunc = ast.registerImport( + node, + 'starkware.starknet.common.syscalls', 'deploy', [ ['class_hash', createAddressTypeName(false, ast)], @@ -153,19 +151,15 @@ export class NewToDeploy extends ASTMapper { ['deploy_from_zero', createBoolTypeName(ast)], ], [['contract_address', cloneASTNode(typeName, ast)]], - ['syscall_ptr'], - ast, - node, { acceptsUnpackedStructArray: true }, ); - ast.registerImport(node, 'starkware.starknet.common.syscalls', 'deploy'); const encodedArguments = ast .getUtilFuncGen(node) .utils.encodeAsFelt.gen(node.vArguments, getParameterTypes(node, ast)); const deployFromZero = createBoolLiteral(false, ast); return createCallToFunction( - deployStub, + deployFunc, [placeHolderIdentifier, salt, encodedArguments, deployFromZero], ast, node, diff --git a/src/passes/references/arrayFunctions.ts b/src/passes/references/arrayFunctions.ts index 8e0fb294a..6a30fab4e 100644 --- a/src/passes/references/arrayFunctions.ts +++ b/src/passes/references/arrayFunctions.ts @@ -5,14 +5,12 @@ import { ExternalReferenceType, FixedBytesType, FunctionCall, - FunctionStateMutability, generalizeType, MemberAccess, PointerType, } from 'solc-typed-ast'; import { AST } from '../../ast/ast'; -import { FunctionStubKind } from '../../ast/cairoNodes'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCallToFunction } from '../../utils/functionGeneration'; import { getSize, isDynamicArray, @@ -108,22 +106,15 @@ export class ArrayFunctions extends ReferenceSubPass { const parent = node.parent; const type = generalizeType(safeGetNodeType(node, ast.inference))[0]; - const funcStub = createCairoFunctionStub( + const importedFunc = ast.registerImport( + node, + 'warplib.maths.utils', 'felt_to_uint256', [['cd_dstruct_array_len', typeNameFromTypeNode(type, ast)]], [['len256', typeNameFromTypeNode(type, ast)]], - ['range_check_ptr'], - ast, - node, - { - mutability: FunctionStateMutability.Pure, - stubKind: FunctionStubKind.FunctionDefStub, - }, ); - const funcCall = createCallToFunction(funcStub, [node], ast); - - ast.registerImport(funcCall, 'warplib.maths.utils', 'felt_to_uint256'); + const funcCall = createCallToFunction(importedFunc, [node], ast); this.replace( node, @@ -151,7 +142,7 @@ export class ArrayFunctions extends ReferenceSubPass { } else { const replacement = baseType.location === DataLocation.Storage - ? ast.getUtilFuncGen(node).storage.dynArrayLength.gen(node, baseType.to) + ? ast.getUtilFuncGen(node).storage.dynArray.genLength(node, baseType.to) : ast.getUtilFuncGen(node).memory.dynArrayLength.gen(node, ast); // The length function returns the actual length rather than a storage pointer to it, // so the new actual location is Default diff --git a/src/passes/references/dataAccessFunctionaliser.ts b/src/passes/references/dataAccessFunctionaliser.ts index dfe667f01..d83a8d4d8 100644 --- a/src/passes/references/dataAccessFunctionaliser.ts +++ b/src/passes/references/dataAccessFunctionaliser.ts @@ -27,7 +27,7 @@ import { printNode, printTypeNode } from '../../utils/astPrinter'; import { AST } from '../../ast/ast'; import { isCairoConstant, typeNameFromTypeNode } from '../../utils/utils'; import { error } from '../../utils/formatting'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCallToFunction } from '../../utils/functionGeneration'; import { createNumberLiteral, createUint256TypeName, @@ -67,17 +67,16 @@ export class DataAccessFunctionaliser extends ReferenceSubPass { return this.commonVisit(node, ast); } - const nodeType = safeGetNodeType(node, ast.inference); const utilFuncGen = ast.getUtilFuncGen(node); const parent = node.parent; // Finally if a copy from actual to expected location is required, insert this last - let copyFunc: Expression | null = null; + let copyFunc: FunctionCall | null = null; if (actualLoc !== expectedLoc) { if (actualLoc === DataLocation.Storage) { switch (expectedLoc) { case DataLocation.Default: { - copyFunc = utilFuncGen.storage.read.gen(node, typeNameFromTypeNode(nodeType, ast)); + copyFunc = utilFuncGen.storage.read.gen(node); break; } case DataLocation.Memory: { @@ -92,7 +91,7 @@ export class DataAccessFunctionaliser extends ReferenceSubPass { } else if (actualLoc === DataLocation.Memory) { switch (expectedLoc) { case DataLocation.Default: { - copyFunc = utilFuncGen.memory.read.gen(node, typeNameFromTypeNode(nodeType, ast)); + copyFunc = utilFuncGen.memory.read.gen(node); break; } case DataLocation.Storage: { @@ -243,7 +242,9 @@ export class DataAccessFunctionaliser extends ReferenceSubPass { assert( (type instanceof UserDefinedType && type.definition instanceof ContractDefinition) || type instanceof FixedBytesType, - `Unexpected unhandled non-pointer non-contract member access. Found at ${printNode(node)}`, + `Unexpected unhandled non-pointer non-contract member access. Found at ${printNode( + node, + )}: '${node.memberName}' with type ${printTypeNode(type)}`, ); return this.visitExpression(node, ast); } @@ -324,7 +325,9 @@ function createMemoryDynArrayIndexAccess(indexAccess: IndexAccess, ast: AST): Fu ? cloneASTNode(arrayTypeName.vBaseType, ast) : createUint8TypeName(ast); - const stub = createCairoFunctionStub( + const importedFunc = ast.registerImport( + indexAccess, + 'warplib.memory', 'wm_index_dyn', [ ['arrayLoc', arrayTypeName, DataLocation.Memory], @@ -332,9 +335,6 @@ function createMemoryDynArrayIndexAccess(indexAccess: IndexAccess, ast: AST): Fu ['width', createUint256TypeName(ast)], ], [['loc', returnTypeName, DataLocation.Memory]], - ['range_check_ptr', 'warp_memory'], - ast, - indexAccess, ); assert(indexAccess.vIndexExpression); @@ -350,7 +350,7 @@ function createMemoryDynArrayIndexAccess(indexAccess: IndexAccess, ast: AST): Fu ).width; const call = createCallToFunction( - stub, + importedFunc, [ indexAccess.vBaseExpression, indexAccess.vIndexExpression, @@ -359,8 +359,6 @@ function createMemoryDynArrayIndexAccess(indexAccess: IndexAccess, ast: AST): Fu ast, ); - ast.registerImport(call, 'warplib.memory', 'wm_index_dyn'); - return call; } diff --git a/src/passes/references/externalReturnReceiver.ts b/src/passes/references/externalReturnReceiver.ts index fb36c31ff..6fc177366 100644 --- a/src/passes/references/externalReturnReceiver.ts +++ b/src/passes/references/externalReturnReceiver.ts @@ -60,7 +60,7 @@ export class ExternalReturnReceiver extends ASTMapper { function addOutputValidation(decl: VariableDeclaration, ast: AST) { const validationFunctionCall = ast .getUtilFuncGen(decl) - .boundChecks.inputCheck.gen(decl, safeGetNodeType(decl, ast.inference), decl); + .boundChecks.inputCheck.gen(decl, safeGetNodeType(decl, ast.inference)); const validationStatement = createExpressionStatement(ast, validationFunctionCall); ast.insertStatementAfter(decl, validationStatement); } diff --git a/src/passes/references/memoryAllocations.ts b/src/passes/references/memoryAllocations.ts index 87a85743a..26740211a 100644 --- a/src/passes/references/memoryAllocations.ts +++ b/src/passes/references/memoryAllocations.ts @@ -16,7 +16,7 @@ import { AST } from '../../ast/ast'; import { printNode } from '../../utils/astPrinter'; import { CairoType, TypeConversionContext } from '../../utils/cairoTypeSystem'; import { NotSupportedYetError } from '../../utils/errors'; -import { createCairoFunctionStub, createCallToFunction } from '../../utils/functionGeneration'; +import { createCallToFunction } from '../../utils/functionGeneration'; import { createNumberLiteral, createUint256TypeName } from '../../utils/nodeTemplates'; import { getElementType, safeGetNodeType } from '../../utils/nodeTypeProcessing'; @@ -78,16 +78,15 @@ export class MemoryAllocations extends ReferenceSubPass { }`, ); - const stub = createCairoFunctionStub( + const funcImport = ast.registerImport( + node, + 'warplib.memory', 'wm_new', [ ['len', createUint256TypeName(ast)], ['elemWidth', createUint256TypeName(ast)], ], [['loc', node.vExpression.vTypeName, DataLocation.Memory]], - ['range_check_ptr', 'warp_memory'], - ast, - node, ); const arrayType = generalizeType(safeGetNodeType(node, ast.inference))[0]; @@ -104,13 +103,12 @@ export class MemoryAllocations extends ReferenceSubPass { ); const call = createCallToFunction( - stub, + funcImport, [node.vArguments[0], createNumberLiteral(elementCairoType.width, ast, 'uint256')], ast, ); const [actualLoc, expectedLoc] = this.getLocations(node); this.replace(node, call, undefined, actualLoc, expectedLoc, ast); - ast.registerImport(call, 'warplib.memory', 'wm_new'); } } diff --git a/src/passes/references/storedPointerDereference.ts b/src/passes/references/storedPointerDereference.ts index 8b9ae1a38..4175201a8 100644 --- a/src/passes/references/storedPointerDereference.ts +++ b/src/passes/references/storedPointerDereference.ts @@ -14,7 +14,6 @@ import { isReferenceType, safeGetNodeType, } from '../../utils/nodeTypeProcessing'; -import { typeNameFromTypeNode } from '../../utils/utils'; import { ReferenceSubPass } from './referenceSubPass'; export class StoredPointerDereference extends ReferenceSubPass { @@ -33,9 +32,9 @@ export class StoredPointerDereference extends ReferenceSubPass { // Next, if the node is a type that requires an extra read, insert this first let readFunc: FunctionCall | null = null; if (actualLoc === DataLocation.Storage && (isDynamicArray(nodeType) || isMapping(nodeType))) { - readFunc = utilFuncGen.storage.read.gen(node, typeNameFromTypeNode(nodeType, ast), parent); + readFunc = utilFuncGen.storage.read.gen(node); } else if (actualLoc === DataLocation.Memory && isReferenceType(nodeType)) { - readFunc = utilFuncGen.memory.read.gen(node, typeNameFromTypeNode(nodeType, ast), parent); + readFunc = utilFuncGen.memory.read.gen(node); } if (readFunc !== null) { this.replace(node, readFunc, parent, actualLoc, expectedLoc, ast); diff --git a/src/solWriter.ts b/src/solWriter.ts index 2658674ae..fd2e3120e 100644 --- a/src/solWriter.ts +++ b/src/solWriter.ts @@ -12,6 +12,8 @@ import { CairoAssert, CairoContract, CairoFunctionDefinition, + CairoGeneratedFunctionDefinition, + CairoImportFunctionDefinition, FunctionStubKind, } from './ast/cairoNodes'; @@ -94,6 +96,18 @@ class CairoFunctionDefinitionSolWriter extends ASTNodeWriter { } } +class CairoGeneratedFunctionDefinitionSolWriter extends ASTNodeWriter { + writeInner(node: CairoGeneratedFunctionDefinition, _writer: ASTWriter): SrcDesc { + return [node.rawStringDefinition]; + } +} + +class CairoImportFunctionDefinitionSolWriter extends ASTNodeWriter { + writeInner(node: CairoImportFunctionDefinition, _writer: ASTWriter): SrcDesc { + return [`from ${node.path} import ${node.name}`]; + } +} + class CairoAssertSolWriter extends ASTNodeWriter { writeInner(node: CairoAssert, writer: ASTWriter): SrcDesc { const result: SrcDesc = []; @@ -106,6 +120,8 @@ const CairoExtendedASTWriterMapping = (printStubs: boolean) => new Map, ASTNodeWriter>([ [CairoContract, new CairoContractSolWriter()], [CairoFunctionDefinition, new CairoFunctionDefinitionSolWriter(printStubs)], + [CairoGeneratedFunctionDefinition, new CairoGeneratedFunctionDefinitionSolWriter()], + [CairoImportFunctionDefinition, new CairoImportFunctionDefinitionSolWriter()], [CairoAssert, new CairoAssertSolWriter()], ]); diff --git a/src/utils/cairoParsing.ts b/src/utils/cairoParsing.ts new file mode 100644 index 000000000..904794855 --- /dev/null +++ b/src/utils/cairoParsing.ts @@ -0,0 +1,88 @@ +import assert from 'assert'; +import { TranspileFailedError } from './errors'; +import { Implicits, implicitTypes } from './implicits'; + +export type RawCairoFunctionInfo = { + name: string; + implicits: Implicits[]; +}; + +/** + * Given several Cairo function represented in plain text extracts information from it + * @param rawFunctions Multiple cairo functions in a single text + * @returns A list of each function information + */ +export function parseMultipleRawCairoFunctions(rawFunctions: string): RawCairoFunctionInfo[] { + const functions = rawFunctions.matchAll(/func (\w+)\s*[{]?.*?[}]?\s*[(].*?[)]/gis); + + return [...functions].map((func) => getRawCairoFunctionInfo(func[0])); +} + +/** + * Given a Cairo function represented in plain text extracts information from it + * @param rawFunction Cairo code + * @returns The function implicits and it's name + */ +export function getRawCairoFunctionInfo(rawFunction: string): RawCairoFunctionInfo { + // Todo: Update match so implicit can be empty and there is a version of them without keys + const funcSignature = + rawFunction.match(/func (?\w+)\s*[{](?.+)[}]/is) ?? + rawFunction.match(/func (?\w+)\s*/); + + assert( + funcSignature !== null && funcSignature.groups !== undefined, + `Invalid parsing of raw string function:\n${rawFunction}`, + ); + + const name = funcSignature.groups.name; + const implicits = + funcSignature.groups.implicits !== undefined + ? parseImplicits(funcSignature.groups.implicits) + : []; + + return { name, implicits }; +} + +/** + * @param rawImplicits implicits in string form ?\{ impl1, impl2:type2, ... ?\} + * @returns a list of each Implicit after checking it's valid + */ +export function parseImplicits(rawImplicits: string): Implicits[] { + const matchedImplicits = + rawImplicits.match(/[{](?[a-zA-Z0-9:,_*\n ]*)[}]/) ?? + rawImplicits.match(/(?[a-zA-Z0-9:,_*\n ]*)/); + + assert( + matchedImplicits !== null && matchedImplicits.groups !== undefined, + `Failure to parse implicits: '${rawImplicits}'`, + ); + + // implicits -> impl1 : type1, impl2, ..., impln : typen + const implicits = matchedImplicits.groups.implicits; + + // implicitsList -> [impl1 : type1, impl2, ...., impln : typen] + const implicitsList = [...implicits.matchAll(/[A-Za-z][A-Za-z_: 0-9]*/g)].map((w) => w[0]); + + // implicitsNameList -> [impl1, impl2, ..., impln] + const implicitsNameList = implicitsList.map((i) => i.match(/[A-Za-z][A-Za-z_0-9]*/)); + + assert(notContainsNull(implicitsNameList), 'Failure to parse implicits: Invalid implicit name'); + + return implicitsNameList.map((i) => { + const impl = i[0]; + if (!elementIsImplicit(impl)) { + throw new TranspileFailedError( + `Implicit '${impl}' defined on raw string is not known: '${rawImplicits}'`, + ); + } + return impl; + }); +} + +function elementIsImplicit(e: string): e is Implicits { + return Object.keys(implicitTypes).includes(e); +} + +function notContainsNull(l: (T | null)[]): l is T[] { + return !l.some((e) => e === null); +} diff --git a/src/utils/cloning.ts b/src/utils/cloning.ts index 4a8446d66..ff7526274 100644 --- a/src/utils/cloning.ts +++ b/src/utils/cloning.ts @@ -520,7 +520,6 @@ function cloneASTNodeImpl( if (notNull(newNode) && sameType(newNode, node)) { ast.setContextRecursive(newNode); - ast.copyRegisteredImports(node, newNode); return newNode; } else { throw new NotSupportedYetError(`Unable to clone ${printNode(node)}`); diff --git a/src/utils/functionGeneration.ts b/src/utils/functionGeneration.ts index c7cc23cae..3d4dba4eb 100644 --- a/src/utils/functionGeneration.ts +++ b/src/utils/functionGeneration.ts @@ -20,7 +20,8 @@ import { VariableDeclaration, } from 'solc-typed-ast'; import { AST } from '../ast/ast'; -import { CairoFunctionDefinition, FunctionStubKind } from '../ast/cairoNodes'; +import { CairoImportFunctionDefinition, FunctionStubKind } from '../ast/cairoNodes'; +import { CairoGeneratedFunctionDefinition } from '../ast/cairoNodes/cairoGeneratedFunctionDefinition'; import { getFunctionTypeString, getReturnTypeString } from './getTypeString'; import { Implicits } from './implicits'; import { createIdentifier, createParameterList } from './nodeTemplates'; @@ -72,11 +73,12 @@ interface CairoFunctionStubOptions { acceptsUnpackedStructArray?: boolean; } -export function createCairoFunctionStub( - name: string, - inputs: ([string, TypeName] | [string, TypeName, DataLocation])[], - returns: ([string, TypeName] | [string, TypeName, DataLocation])[], - implicits: Implicits[], +export type ParameterInfo = [string, TypeName] | [string, TypeName, DataLocation]; + +export function createCairoGeneratedFunction( + genFuncInfo: { name: string; code: string; functionsCalled: FunctionDefinition[] }, + inputs: ParameterInfo[], + returns: ParameterInfo[], ast: AST, nodeInSourceUnit: ASTNode, options: CairoFunctionStubOptions = { @@ -85,46 +87,24 @@ export function createCairoFunctionStub( acceptsRawDArray: false, acceptsUnpackedStructArray: false, }, -): CairoFunctionDefinition { +): CairoGeneratedFunctionDefinition { const sourceUnit = ast.getContainingRoot(nodeInSourceUnit); const funcDefId = ast.reserveId(); - const createParameters = (inputs: ([string, TypeName] | [string, TypeName, DataLocation])[]) => - inputs.map( - ([name, type, location]) => - new VariableDeclaration( - ast.reserveId(), - '', - false, // constant - false, // indexed - name, - funcDefId, - false, // stateVariable - location ?? DataLocation.Default, - StateVariableVisibility.Private, - Mutability.Mutable, - type.typeString, - undefined, - type, - ), - ); - - const funcDef = new CairoFunctionDefinition( + const funcDef = new CairoGeneratedFunctionDefinition( funcDefId, '', sourceUnit.id, FunctionKind.Function, - name, - false, // virtual + genFuncInfo.name, FunctionVisibility.Private, options.mutability ?? FunctionStateMutability.NonPayable, - false, // isConstructor - createParameterList(createParameters(inputs), ast), - createParameterList(createParameters(returns), ast), - [], - new Set(implicits), + createParameterList(createParameters(inputs, funcDefId, ast), ast), + createParameterList(createParameters(returns, funcDefId, ast), ast), options.stubKind ?? FunctionStubKind.FunctionDefStub, - options.acceptsRawDArray ?? false, - options.acceptsUnpackedStructArray ?? false, + genFuncInfo.code, + genFuncInfo.functionsCalled, + options.acceptsRawDArray, + options.acceptsUnpackedStructArray, ); ast.setContextRecursive(funcDef); @@ -133,6 +113,97 @@ export function createCairoFunctionStub( return funcDef; } +export function createCairoImportFunctionDefintion( + funcName: string, + path: string, + implicits: Set, + params: ParameterInfo[], + retParams: ParameterInfo[], + ast: AST, + nodeInSourceUnit: ASTNode, + options: CairoFunctionStubOptions = { + acceptsRawDArray: false, + acceptsUnpackedStructArray: false, + }, +): CairoImportFunctionDefinition { + const sourceUnit = ast.getContainingRoot(nodeInSourceUnit); + + const id = ast.reserveId(); + const scope = sourceUnit.id; + + const funcDef = new CairoImportFunctionDefinition( + id, + '', + scope, + funcName, + path, + implicits, + createParameterList(createParameters(params, id, ast), ast), + createParameterList(createParameters(retParams, id, ast), ast), + FunctionStubKind.FunctionDefStub, + options.acceptsRawDArray, + options.acceptsUnpackedStructArray, + ); + ast.setContextRecursive(funcDef); + sourceUnit.insertAtBeginning(funcDef); + return funcDef; +} + +export function createCairoImportStructDefinition( + structName: string, + path: string, + ast: AST, + nodeInSourceUnit: ASTNode, +): CairoImportFunctionDefinition { + const sourceUnit = ast.getContainingRoot(nodeInSourceUnit); + + const id = ast.reserveId(); + const scope = sourceUnit.id; + + const implicits = new Set(); + const params = createParameterList([], ast); + const retParams = createParameterList([], ast); + const funcDef = new CairoImportFunctionDefinition( + id, + '', + scope, + structName, + path, + implicits, + params, + retParams, + FunctionStubKind.StructDefStub, + ); + ast.setContextRecursive(funcDef); + sourceUnit.insertAtBeginning(funcDef); + return funcDef; +} + +function createParameters( + inputs: ([string, TypeName] | [string, TypeName, DataLocation])[], + funcDefId: number, + ast: AST, +) { + return inputs.map( + ([name, type, location]) => + new VariableDeclaration( + ast.reserveId(), + '', + false, // constant + false, // indexed + name, + funcDefId, + false, // stateVariable + location ?? DataLocation.Default, + StateVariableVisibility.Private, + Mutability.Mutable, + type.typeString, + undefined, + type, + ), + ); +} + export function createElementaryConversionCall( typeTo: ElementaryTypeName, expression: Expression, diff --git a/src/utils/implicits.ts b/src/utils/implicits.ts index 1f6043312..d6aa09d40 100644 --- a/src/utils/implicits.ts +++ b/src/utils/implicits.ts @@ -7,7 +7,8 @@ export type Implicits = | 'range_check_ptr' | 'syscall_ptr' | 'warp_memory' - | 'keccak_ptr'; + | 'keccak_ptr' + | 'dict_ptr'; export type CairoBuiltin = 'bitwise' | 'pedersen' | 'range_check'; const implicitsOrder = { @@ -17,6 +18,7 @@ const implicitsOrder = { bitwise_ptr: 3, warp_memory: 4, keccak_ptr: 5, + dict_ptr: 6, }; export function implicitOrdering(a: Implicits, b: Implicits): number { @@ -30,18 +32,19 @@ export const implicitTypes: { [key in Implicits]: string } = { syscall_ptr: 'felt*', warp_memory: 'DictAccess*', keccak_ptr: 'felt*', + dict_ptr: 'DictAccess*', }; export function registerImportsForImplicit(ast: AST, node: ASTNode, implicit: Implicits) { switch (implicit) { case 'bitwise_ptr': - ast.registerImport(node, 'starkware.cairo.common.cairo_builtins', 'BitwiseBuiltin'); + ast.registerImport(node, 'starkware.cairo.common.cairo_builtins', 'BitwiseBuiltin', [], []); break; case 'pedersen_ptr': - ast.registerImport(node, 'starkware.cairo.common.cairo_builtins', 'HashBuiltin'); + ast.registerImport(node, 'starkware.cairo.common.cairo_builtins', 'HashBuiltin', [], []); break; case 'warp_memory': - ast.registerImport(node, 'starkware.cairo.common.dict_access', 'DictAccess'); + ast.registerImport(node, 'starkware.cairo.common.dict_access', 'DictAccess', [], []); break; } } @@ -53,4 +56,5 @@ export const requiredBuiltin: { [key in Implicits]: CairoBuiltin | null } = { syscall_ptr: null, warp_memory: null, keccak_ptr: null, + dict_ptr: null, }; diff --git a/src/utils/importFuncGenerator.ts b/src/utils/importFuncGenerator.ts new file mode 100644 index 000000000..93f6530e8 --- /dev/null +++ b/src/utils/importFuncGenerator.ts @@ -0,0 +1,133 @@ +import { ASTNode, SourceUnit } from 'solc-typed-ast'; +import { CairoImportFunctionDefinition } from '../ast/cairoNodes'; +import { AST } from '../ast/ast'; +import { TranspileFailedError } from '../utils/errors'; +import { warplibImportInfo } from '../warplib/gatherWarplibImports'; +import { Implicits } from './implicits'; +import { + createCairoImportFunctionDefintion, + createCairoImportStructDefinition, + ParameterInfo, +} from './functionGeneration'; +import { getContainingSourceUnit } from './utils'; + +// Paths +const STARKWARE_CAIRO_COMMON_ALLOC = 'starkware.cairo.common.alloc'; +const STARKWARE_CAIRO_COMMON_BUILTINS = 'starkware.cairo.common.cairo_builtins'; +const STARKWARE_CAIRO_COMMON_CAIRO_KECCAK = 'starkware.cairo.common.cairo_keccak.keccak'; +const STARKWARE_CAIRO_COMMON_DEFAULT_DICT = 'starkware.cairo.common.default_dict'; +const STARKWARE_CAIRO_COMMON_DICT = 'starkware.cairo.common.dict'; +const STARKWARE_CAIRO_COMMON_DICT_ACCESS = 'starkware.cairo.common.dict_access'; +const STARKWARE_CAIRO_COMMON_MATH = 'starkware.cairo.common.math'; +const STARKWARE_CAIRO_COMMON_MATH_CMP = 'starkware.cairo.common.math_cmp'; +const STARKWARE_CAIRO_COMMON_UINT256 = 'starkware.cairo.common.uint256'; +const STARKWARE_STARKNET_COMMON_SYSCALLS = 'starkware.starknet.common.syscalls'; + +export function createImport( + path: string, + name: string, + nodeInSourceUnit: ASTNode, + ast: AST, + inputs?: ParameterInfo[], + outputs?: ParameterInfo[], + options?: { acceptsRawDarray?: boolean; acceptsUnpackedStructArray?: boolean }, +) { + const sourceUnit = getContainingSourceUnit(nodeInSourceUnit); + + const existingImport = findExistingImport(name, sourceUnit); + if (existingImport !== undefined) { + const hasInputs = inputs !== undefined && inputs.length > 0; + const hasOutputs = outputs !== undefined && outputs.length > 0; + if (!hasInputs || !hasOutputs) return existingImport; + return createCairoImportFunctionDefintion( + name, + path, + existingImport.implicits, + inputs, + outputs, + ast, + sourceUnit, + options, + ); + } + + const createFuncImport = (...implicits: Implicits[]) => + createCairoImportFunctionDefintion( + name, + path, + new Set(implicits), + inputs ?? [], + outputs ?? [], + ast, + sourceUnit, + options, + ); + const createStructImport = () => createCairoImportStructDefinition(name, path, ast, sourceUnit); + + const warplibFunc = warplibImportInfo.get(path)?.get(name); + if (warplibFunc !== undefined) { + return createFuncImport(...warplibFunc); + } + + switch (path + name) { + case STARKWARE_CAIRO_COMMON_ALLOC + 'alloc': + return createFuncImport(); + case STARKWARE_CAIRO_COMMON_BUILTINS + 'BitwiseBuiltin': + return createStructImport(); + case STARKWARE_CAIRO_COMMON_BUILTINS + 'HashBuiltin': + return createStructImport(); + case STARKWARE_CAIRO_COMMON_CAIRO_KECCAK + 'finalize_keccak': + return createFuncImport('range_check_ptr', 'bitwise_ptr'); + case STARKWARE_CAIRO_COMMON_DEFAULT_DICT + 'default_dict_new': + return createFuncImport(); + case STARKWARE_CAIRO_COMMON_DEFAULT_DICT + 'default_dict_finalize': + return createFuncImport('range_check_ptr'); + case STARKWARE_CAIRO_COMMON_DICT + 'dict_read': + return createFuncImport('dict_ptr'); + case STARKWARE_CAIRO_COMMON_DICT + 'dict_write': + return createFuncImport('dict_ptr'); + case STARKWARE_CAIRO_COMMON_DICT_ACCESS + 'DictAccess': + return createStructImport(); + case STARKWARE_CAIRO_COMMON_MATH + 'split_felt': + return createFuncImport('range_check_ptr'); + case STARKWARE_CAIRO_COMMON_MATH_CMP + 'is_le': + return createFuncImport('range_check_ptr'); + case STARKWARE_CAIRO_COMMON_MATH_CMP + 'is_le_felt': + return createFuncImport('range_check_ptr'); + case STARKWARE_CAIRO_COMMON_UINT256 + 'Uint256': + return createStructImport(); + case STARKWARE_CAIRO_COMMON_UINT256 + 'uint256_add': + return createFuncImport('range_check_ptr'); + case STARKWARE_CAIRO_COMMON_UINT256 + 'uint256_eq': + return createFuncImport('range_check_ptr'); + case STARKWARE_CAIRO_COMMON_UINT256 + 'uint256_le': + return createFuncImport('range_check_ptr'); + case STARKWARE_CAIRO_COMMON_UINT256 + 'uint256_lt': + return createFuncImport('range_check_ptr'); + case STARKWARE_CAIRO_COMMON_UINT256 + 'uint256_mul': + return createFuncImport('range_check_ptr'); + case STARKWARE_CAIRO_COMMON_UINT256 + 'uint256_sub': + return createFuncImport('range_check_ptr'); + case STARKWARE_STARKNET_COMMON_SYSCALLS + 'deploy': + return createFuncImport('syscall_ptr'); + case STARKWARE_STARKNET_COMMON_SYSCALLS + 'emit_event': + return createFuncImport('syscall_ptr'); + case STARKWARE_STARKNET_COMMON_SYSCALLS + 'get_caller_address': + return createFuncImport('syscall_ptr'); + case STARKWARE_STARKNET_COMMON_SYSCALLS + 'get_contract_address': + return createFuncImport('syscall_ptr'); + default: + throw new TranspileFailedError(`Import ${name} from ${path} is not defined.`); + } +} + +function findExistingImport( + name: string, + node: SourceUnit, +): CairoImportFunctionDefinition | undefined { + const found = node.vFunctions.filter( + (n): n is CairoImportFunctionDefinition => + n instanceof CairoImportFunctionDefinition && n.name === name, + ); + return found[0]; +} diff --git a/src/utils/nodeTemplates.ts b/src/utils/nodeTemplates.ts index 6b28b9dfb..9b4e99785 100644 --- a/src/utils/nodeTemplates.ts +++ b/src/utils/nodeTemplates.ts @@ -50,14 +50,8 @@ export function createAddressTypeName(payable: boolean, ast: AST): ElementaryTyp return node; } -export function createStringTypeName(payable: boolean, ast: AST): ElementaryTypeName { - const node = new ElementaryTypeName( - ast.reserveId(), - '', - 'string', - 'string', - payable ? 'payable' : 'nonpayable', - ); +export function createStringTypeName(ast: AST): ElementaryTypeName { + const node = new ElementaryTypeName(ast.reserveId(), '', 'string', 'string', 'nonpayable'); ast.setContextRecursive(node); return node; } diff --git a/src/utils/utils.ts b/src/utils/utils.ts index f9283142e..35ed08fd3 100644 --- a/src/utils/utils.ts +++ b/src/utils/utils.ts @@ -40,6 +40,7 @@ import { SourceLocation, SourceUnit, StateVariableVisibility, + StringLiteralType, StringType, StructDefinition, TimeUnit, @@ -253,7 +254,7 @@ export function typeNameFromTypeNode(node: TypeNode, ast: AST): TypeName { node.definition.id, new IdentifierPath(ast.reserveId(), '', node.definition.name, node.definition.id), ); - } else if (node instanceof StringType) { + } else if (node instanceof StringType || node instanceof StringLiteralType) { return new ElementaryTypeName(ast.reserveId(), '', 'string', 'string', 'nonpayable'); } @@ -265,17 +266,6 @@ export function typeNameFromTypeNode(node: TypeNode, ast: AST): TypeName { return result; } -export function mergeImports(...maps: Map>[]): Map> { - return maps.reduce((acc, curr) => { - curr.forEach((importedSymbols, location) => { - const accSet = acc.get(location) ?? new Set(); - importedSymbols.forEach((s) => accSet.add(s)); - acc.set(location, accSet); - }); - return acc; - }, new Map>()); -} - export function groupBy(arr: V[], groupFunc: (arg: V) => K): Map> { const grouped = new Map>(); arr.forEach((v) => { @@ -565,6 +555,10 @@ export function getContainingFunction(node: ASTNode): FunctionDefinition { } export function getContainingSourceUnit(node: ASTNode): SourceUnit { + if (node instanceof SourceUnit) { + return node; + } + const root = node.getClosestParentByType(SourceUnit); assert(root !== undefined, `Unable to find root source unit for ${printNode(node)}`); return root; diff --git a/src/warplib/gatherWarplibImports.ts b/src/warplib/gatherWarplibImports.ts new file mode 100644 index 000000000..3c7d5bc58 --- /dev/null +++ b/src/warplib/gatherWarplibImports.ts @@ -0,0 +1,28 @@ +import fs from 'fs'; +import { Implicits } from '../utils/implicits'; +import { parseMultipleRawCairoFunctions } from '../utils/cairoParsing'; +import { glob } from 'glob'; + +export const warplibImportInfo = glob + .sync('warplib/**/*.cairo') + .reduce((warplibMap, pathToFile) => { + const rawCairoCode = fs.readFileSync(pathToFile, { encoding: 'utf8' }); + + const importPath = pathToFile + .split('/') + .join('.') + .slice(0, pathToFile.length - '.cairo'.length); + + const fileMap: Map = + warplibMap.get(importPath) ?? new Map(); + + if (!warplibMap.has(importPath)) { + warplibMap.set(importPath, fileMap); + } + + parseMultipleRawCairoFunctions(rawCairoCode).forEach((cairoFunc) => + fileMap.set(cairoFunc.name, cairoFunc.implicits), + ); + + return warplibMap; + }, new Map>()); diff --git a/src/warplib/generateWarplib.ts b/src/warplib/generateWarplib.ts index 6a678ce08..50e98d118 100644 --- a/src/warplib/generateWarplib.ts +++ b/src/warplib/generateWarplib.ts @@ -1,6 +1,6 @@ +import { generateFile, WarplibFunctionInfo } from './utils'; import { int_conversions } from './implementations/conversions/int'; import { add, add_unsafe, add_signed, add_signed_unsafe } from './implementations/maths/add'; -import { bitwise_not } from './implementations/maths/bitwiseNot'; import { div_signed, div_signed_unsafe } from './implementations/maths/div'; import { exp, exp_signed, exp_signed_unsafe, exp_unsafe } from './implementations/maths/exp'; import { ge_signed } from './implementations/maths/ge'; @@ -13,64 +13,54 @@ import { negate } from './implementations/maths/negate'; import { shl } from './implementations/maths/shl'; import { shr, shr_signed } from './implementations/maths/shr'; import { sub_unsafe, sub_signed, sub_signed_unsafe } from './implementations/maths/sub'; +import { bitwise_not } from './implementations/maths/bitwiseNot'; import { external_input_check_ints } from './implementations/external_input_checks/externalInputChecksInts'; -add(); -add_unsafe(); -add_signed(); -add_signed_unsafe(); - -//sub - handwritten -sub_unsafe(); -sub_signed(); -sub_signed_unsafe(); - -mul(); -mul_unsafe(); -mul_signed(); -mul_signed_unsafe(); - -//div - handwritten -div_signed(); -div_signed_unsafe(); - -// mod - handwritten -mod_signed(); - -exp(); -exp_signed(); -exp_unsafe(); -exp_signed_unsafe(); - -negate(); - -shl(); - -shr(); -shr_signed(); - -//ge - handwritten -ge_signed(); - -//gt - handwritten -gt_signed(); - -//le - handwritten -le_signed(); - -//lt - handwritten -lt_signed(); - -//xor - handwritten -//bitwise_and - handwritten -//bitwise_or - handwritten -bitwise_not(); - -// ---conversions--- - -int_conversions(); - -// ---external_input_checks--- -external_input_check_ints(); -// and - handwritten -// or - handwritten +export const warplibFunctions: WarplibFunctionInfo[] = [ + add(), + add_unsafe(), + add_signed(), + add_signed_unsafe(), + // sub - handwritten + sub_unsafe(), + sub_signed(), + sub_signed_unsafe(), + mul(), + mul_unsafe(), + mul_signed(), + mul_signed_unsafe(), + // div - handwritten + // div_unsafe - handwritten + div_signed(), + div_signed_unsafe(), + // mod - handwritten + mod_signed(), + exp(), + exp_signed(), + exp_unsafe(), + exp_signed_unsafe(), + negate(), + shl(), + shr(), + shr_signed(), + // ge - handwritten + ge_signed(), + // gt - handwritten + gt_signed(), + // le - handwritten + le_signed(), + // lt - handwritten + lt_signed(), + // and - handwritten + // xor - handwritten + // bitwise_and - handwritten + // bitwise_or - handwritten + bitwise_not(), + // ---conversions--- + int_conversions(), + // ---external_input_checks--- + external_input_check_ints(), + // external_inputt_check_address - handwritten +]; + +warplibFunctions.forEach((warpFunc: WarplibFunctionInfo) => generateFile(warpFunc)); diff --git a/src/warplib/implementations/conversions/dynBytesToFixed.ts b/src/warplib/implementations/conversions/dynBytesToFixed.ts index d3554133c..98bce1cfa 100644 --- a/src/warplib/implementations/conversions/dynBytesToFixed.ts +++ b/src/warplib/implementations/conversions/dynBytesToFixed.ts @@ -1,13 +1,6 @@ -import { - DataLocation, - FixedBytesType, - FunctionCall, - FunctionStateMutability, - TypeName, -} from 'solc-typed-ast'; +import { DataLocation, FixedBytesType, FunctionCall } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { createCairoFunctionStub, createCallToFunction } from '../../../utils/functionGeneration'; -import { Implicits } from '../../../utils/implicits'; +import { createCallToFunction, ParameterInfo } from '../../../utils/functionGeneration'; import { createBytesTypeName, createNumberLiteral, @@ -22,31 +15,23 @@ export function functionaliseBytesToFixedBytes( ): void { const wide = targetType.size === 32; const funcName = wide ? 'wm_bytes_to_fixed32' : 'wm_bytes_to_fixed'; - const args: ([string, TypeName] | [string, TypeName, DataLocation])[] = wide + const args: ParameterInfo[] = wide ? [['bytesLoc', createBytesTypeName(ast), DataLocation.Memory]] : [ ['bytesLoc', createBytesTypeName(ast), DataLocation.Memory], ['width', createUint8TypeName(ast)], ]; - const implicits: Implicits[] = wide ? ['range_check_ptr', 'warp_memory'] : ['warp_memory']; - const stub = createCairoFunctionStub( - funcName, - args, - [['res', typeNameFromTypeNode(targetType, ast)]], - implicits, - ast, - node, - { mutability: FunctionStateMutability.Pure }, - ); + const importedFunc = ast.registerImport(node, 'warplib.memory', funcName, args, [ + ['res', typeNameFromTypeNode(targetType, ast)], + ]); const replacement = createCallToFunction( - stub, + importedFunc, wide ? node.vArguments : [...node.vArguments, createNumberLiteral(targetType.size, ast, 'uint8')], ast, ); ast.replaceNode(node, replacement); - ast.registerImport(replacement, 'warplib.memory', `wm_bytes_to_fixed${wide ? '32' : ''}`); } diff --git a/src/warplib/implementations/conversions/fixedBytes.ts b/src/warplib/implementations/conversions/fixedBytes.ts index 735c9b4a7..56040588a 100644 --- a/src/warplib/implementations/conversions/fixedBytes.ts +++ b/src/warplib/implementations/conversions/fixedBytes.ts @@ -2,7 +2,7 @@ import assert from 'assert'; import { FixedBytesType, FunctionCall, generalizeType } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; import { printTypeNode, printNode } from '../../../utils/astPrinter'; -import { createCairoFunctionStub, createCallToFunction } from '../../../utils/functionGeneration'; +import { createCallToFunction } from '../../../utils/functionGeneration'; import { createNumberLiteral, createUint8TypeName } from '../../../utils/nodeTemplates'; import { safeGetNodeType } from '../../../utils/nodeTypeProcessing'; import { typeNameFromTypeNode } from '../../../utils/utils'; @@ -26,52 +26,48 @@ export function functionaliseFixedBytesConversion(conversion: FunctionCall, ast: if (fromType.size < toType.size) { const fullName = `warp_bytes_widen${toType.size === 32 ? '_256' : ''}`; - const stub = createCairoFunctionStub( + const importedFunc = ast.registerImport( + conversion, + `warplib.maths.bytes_conversions`, fullName, [ ['op', typeNameFromTypeNode(fromType, ast)], ['widthDiff', createUint8TypeName(ast)], ], [['res', typeNameFromTypeNode(toType, ast)]], - toType.size === 32 ? ['range_check_ptr'] : [], - ast, - conversion, ); const call = createCallToFunction( - stub, + importedFunc, [arg, createNumberLiteral(8 * (toType.size - fromType.size), ast, 'uint8')], ast, ); ast.replaceNode(conversion, call); - ast.registerImport(call, `warplib.maths.bytes_conversions`, fullName); return; } else if (fromType.size === toType.size) { ast.replaceNode(conversion, arg); return; } else { const fullName = `warp_bytes_narrow${fromType.size === 32 ? '_256' : ''}`; - const stub = createCairoFunctionStub( + const importedFunc = ast.registerImport( + conversion, + `warplib.maths.bytes_conversions`, fullName, [ ['op', typeNameFromTypeNode(fromType, ast)], ['widthDiff', createUint8TypeName(ast)], ], [['res', typeNameFromTypeNode(toType, ast)]], - fromType.size === 32 ? ['range_check_ptr'] : ['bitwise_ptr'], - ast, - conversion, ); const call = createCallToFunction( - stub, + importedFunc, [arg, createNumberLiteral(8 * (fromType.size - toType.size), ast, 'uint8')], ast, ); ast.replaceNode(conversion, call); - ast.registerImport(call, `warplib.maths.bytes_conversions`, fullName); return; } } diff --git a/src/warplib/implementations/conversions/int.ts b/src/warplib/implementations/conversions/int.ts index b7ba2a314..6ecd0be13 100644 --- a/src/warplib/implementations/conversions/int.ts +++ b/src/warplib/implementations/conversions/int.ts @@ -2,22 +2,29 @@ import assert from 'assert'; import { FunctionCall, generalizeType, IntType } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; import { printNode, printTypeNode } from '../../../utils/astPrinter'; -import { Implicits } from '../../../utils/implicits'; import { safeGetNodeType } from '../../../utils/nodeTypeProcessing'; -import { bound, forAllWidths, generateFile, IntFunction, mask, msb, uint256 } from '../../utils'; +import { + bound, + forAllWidths, + IntFunction, + mask, + msb, + uint256, + WarplibFunctionInfo, +} from '../../utils'; -export function int_conversions(): void { - generateFile( - 'int_conversions', - [ +export function int_conversions(): WarplibFunctionInfo { + return { + fileName: 'int_conversions', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.math import split_felt', 'from starkware.cairo.common.uint256 import Uint256, uint256_add', ], - [ + functions: [ ...forAllWidths((from) => { - return forAllWidths((to) => { + const x = forAllWidths((to) => { if (from < to) { if (to === 256) { return [ @@ -76,14 +83,16 @@ export function int_conversions(): void { } } }); + return x.map((f) => f.join('\n')).join('\n'); }), - '', - 'func warp_uint256{range_check_ptr}(op : felt) -> (res : Uint256){', - ' let split = split_felt(op);', - ' return (Uint256(low=split.low, high=split.high),);', - '}', + [ + 'func warp_uint256{range_check_ptr}(op : felt) -> (res : Uint256){', + ' let split = split_felt(op);', + ' return (Uint256(low=split.low, high=split.high),);', + '}', + ].join('\n'), ], - ); + }; } function sign_extend_value(from: number, to: number): bigint { @@ -108,14 +117,7 @@ export function functionaliseIntConversion(conversion: FunctionCall, ast: AST): ); if (fromType.nBits < 256 && toType.nBits === 256 && !fromType.signed && !toType.signed) { - IntFunction( - conversion, - conversion.vArguments[0], - 'uint', - 'int_conversions', - () => ['range_check_ptr'], - ast, - ); + IntFunction(conversion, conversion.vArguments[0], 'uint', 'int_conversions', ast); return; } else if ( fromType.nBits === toType.nBits || @@ -126,11 +128,7 @@ export function functionaliseIntConversion(conversion: FunctionCall, ast: AST): return; } else { const name = `${fromType.pp().startsWith('u') ? fromType.pp().slice(1) : fromType.pp()}_to_int`; - const implicitsFn = (wide: boolean): Implicits[] => { - if (wide) return ['range_check_ptr', 'bitwise_ptr']; - return ['bitwise_ptr']; - }; - IntFunction(conversion, conversion.vArguments[0], name, 'int_conversions', implicitsFn, ast); + IntFunction(conversion, conversion.vArguments[0], name, 'int_conversions', ast); return; } } diff --git a/src/warplib/implementations/external_input_checks/externalInputChecksInts.ts b/src/warplib/implementations/external_input_checks/externalInputChecksInts.ts index 35d140ac6..9fc9889f8 100644 --- a/src/warplib/implementations/external_input_checks/externalInputChecksInts.ts +++ b/src/warplib/implementations/external_input_checks/externalInputChecksInts.ts @@ -1,40 +1,38 @@ -import { generateFile, forAllWidths, mask } from '../../utils'; +import { forAllWidths, mask, WarplibFunctionInfo } from '../../utils'; const INDENT = ' '.repeat(4); -const import_strings: string[] = [ - 'from starkware.cairo.common.math_cmp import is_le_felt', - 'from starkware.cairo.common.uint256 import Uint256', -]; - -const BitBoundChecker: Array = [ - ...forAllWidths((int_width) => { - if (int_width === 256) { - return [ - `func warp_external_input_check_int256{range_check_ptr}(x : Uint256){`, - `${INDENT}alloc_locals;`, - `${INDENT}let inRangeHigh : felt = is_le_felt(x.high, ${mask(128)});`, - `${INDENT}let inRangeLow : felt = is_le_felt(x.low, ${mask(128)});`, - `${INDENT}with_attr error_message("Error: value out-of-bounds. Values passed to high and low members of Uint256 must be less than 2**128."){`, - `${INDENT.repeat(2)}assert 1 = (inRangeHigh * inRangeLow);`, - `${INDENT}}`, - `${INDENT}return();`, - `}\n`, - ]; - } else { - return [ - `func warp_external_input_check_int${int_width}{range_check_ptr}(x : felt){`, - `${INDENT}let inRange : felt = is_le_felt(x, ${mask(int_width)});`, - `${INDENT}with_attr error_message("Error: value out-of-bounds. Value must be less than 2**${int_width}."){`, - `${INDENT.repeat(2)}assert 1 = inRange;`, - `${INDENT}}`, - `${INDENT}return ();`, - `}\n`, - ]; - } - }), -]; - -export function external_input_check_ints(): void { - generateFile('external_input_check_ints', import_strings, [...BitBoundChecker]); +export function external_input_check_ints(): WarplibFunctionInfo { + return { + fileName: 'external_input_check_ints', + imports: [ + 'from starkware.cairo.common.math_cmp import is_le_felt', + 'from starkware.cairo.common.uint256 import Uint256', + ], + functions: forAllWidths((int_width) => { + if (int_width === 256) { + return [ + `func warp_external_input_check_int256{range_check_ptr}(x : Uint256){`, + `${INDENT}alloc_locals;`, + `${INDENT}let inRangeHigh : felt = is_le_felt(x.high, ${mask(128)});`, + `${INDENT}let inRangeLow : felt = is_le_felt(x.low, ${mask(128)});`, + `${INDENT}with_attr error_message("Error: value out-of-bounds. Values passed to high and low members of Uint256 must be less than 2**128."){`, + `${INDENT.repeat(2)}assert 1 = (inRangeHigh * inRangeLow);`, + `${INDENT}}`, + `${INDENT}return();`, + `}\n`, + ].join('\n'); + } else { + return [ + `func warp_external_input_check_int${int_width}{range_check_ptr}(x : felt){`, + `${INDENT}let inRange : felt = is_le_felt(x, ${mask(int_width)});`, + `${INDENT}with_attr error_message("Error: value out-of-bounds. Value must be less than 2**${int_width}."){`, + `${INDENT.repeat(2)}assert 1 = inRange;`, + `${INDENT}}`, + `${INDENT}return ();`, + `}\n`, + ].join('\n'); + } + }), + }; } diff --git a/src/warplib/implementations/maths/add.ts b/src/warplib/implementations/maths/add.ts index 0d73f85bd..e99fdfde6 100644 --- a/src/warplib/implementations/maths/add.ts +++ b/src/warplib/implementations/maths/add.ts @@ -1,55 +1,61 @@ import { BinaryOperation } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { Implicits } from '../../../utils/implicits'; -import { forAllWidths, generateFile, IntxIntFunction, mask, msb, msbAndNext } from '../../utils'; +import { + forAllWidths, + IntxIntFunction, + mask, + msb, + msbAndNext, + WarplibFunctionInfo, +} from '../../utils'; -export function add(): void { - generateFile( - 'add', - [ - 'from starkware.cairo.common.math_cmp import is_le_felt', - 'from starkware.cairo.common.uint256 import Uint256, uint256_add', - ], - forAllWidths((width) => { - if (width === 256) { - return [ - `func warp_add256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){`, - ` let (res : Uint256, carry : felt) = uint256_add(lhs, rhs);`, - ` assert carry = 0;`, - ` return (res,);`, - `}`, - ]; - } else { - return [ - `func warp_add${width}{range_check_ptr}(lhs : felt, rhs : felt) -> (res : felt){`, - ` let res = lhs + rhs;`, - ` let inRange : felt = is_le_felt(res, ${mask(width)});`, - ` assert inRange = 1;`, - ` return (res,);`, - `}`, - ]; - } - }), - ); +export function add(): WarplibFunctionInfo { + const fileName = 'add'; + const imports = [ + 'from starkware.cairo.common.math_cmp import is_le_felt', + 'from starkware.cairo.common.uint256 import Uint256, uint256_add', + ]; + const functions = forAllWidths((width) => { + if (width === 256) { + return [ + `func warp_add256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){`, + ` let (res : Uint256, carry : felt) = uint256_add(lhs, rhs);`, + ` assert carry = 0;`, + ` return (res,);`, + `}`, + ].join('\n'); + } else { + return [ + `func warp_add${width}{range_check_ptr}(lhs : felt, rhs : felt) -> (res : felt){`, + ` let res = lhs + rhs;`, + ` let inRange : felt = is_le_felt(res, ${mask(width)});`, + ` assert inRange = 1;`, + ` return (res,);`, + `}`, + ].join('\n'); + } + }); + + return { fileName, imports, functions }; } -export function add_unsafe(): void { - generateFile( - 'add_unsafe', - [ +export function add_unsafe(): WarplibFunctionInfo { + return { + fileName: 'add_unsafe', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.math_cmp import is_le_felt', 'from starkware.cairo.common.uint256 import Uint256, uint256_add', ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ `func warp_add_unsafe256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){`, ` let (res : Uint256, _) = uint256_add(lhs, rhs);`, ` return (res,);`, `}`, - ]; + ].join('\n'); } else { return [ `func warp_add_unsafe${width}{bitwise_ptr : BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (`, @@ -57,22 +63,22 @@ export function add_unsafe(): void { ` let (res) = bitwise_and(lhs + rhs, ${mask(width)});`, ` return (res,);`, `}`, - ]; + ].join('\n'); } }), - ); + }; } -export function add_signed(): void { - generateFile( - 'add_signed', - [ +export function add_signed(): WarplibFunctionInfo { + return { + fileName: 'add_signed', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.math_cmp import is_le_felt', 'from starkware.cairo.common.uint256 import Uint256, uint256_add', ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ `func warp_add_signed256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, @@ -86,7 +92,7 @@ export function add_signed(): void { ` assert msb = carry_lsb;`, ` return (res,);`, `}`, - ]; + ].join('\n'); } else { return [ `func warp_add_signed${width}{bitwise_ptr : BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (`, @@ -102,29 +108,29 @@ export function add_signed(): void { ` let (res) = bitwise_and(big_res, ${mask(width)});`, ` return (res,);`, `}`, - ]; + ].join('\n'); } }), - ); + }; } -export function add_signed_unsafe(): void { - generateFile( - 'add_signed_unsafe', - [ +export function add_signed_unsafe(): WarplibFunctionInfo { + return { + fileName: 'add_signed_unsafe', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.math_cmp import is_le_felt', 'from starkware.cairo.common.uint256 import Uint256, uint256_add', ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ `func warp_add_signed_unsafe256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){`, ` let (res : Uint256, _) = uint256_add(lhs, rhs);`, ` return (res,);`, `}`, - ]; + ].join('\n'); } else { return [ `func warp_add_signed_unsafe${width}{bitwise_ptr : BitwiseBuiltin*}(`, @@ -132,17 +138,12 @@ export function add_signed_unsafe(): void { ` let (res) = bitwise_and(lhs + rhs, ${mask(width)});`, ` return (res,);`, `}`, - ]; + ].join('\n'); } }), - ); + }; } export function functionaliseAdd(node: BinaryOperation, unsafe: boolean, ast: AST): void { - const implicitsFn = (width: number, signed: boolean): Implicits[] => { - if (!unsafe && signed && width === 256) return ['range_check_ptr', 'bitwise_ptr']; - else if ((!unsafe && !signed) || width === 256) return ['range_check_ptr']; - else return ['bitwise_ptr']; - }; - IntxIntFunction(node, 'add', 'always', true, unsafe, implicitsFn, ast); + IntxIntFunction(node, 'add', 'always', true, unsafe, ast); } diff --git a/src/warplib/implementations/maths/bitwiseAnd.ts b/src/warplib/implementations/maths/bitwiseAnd.ts index 75914a9ed..10b994053 100644 --- a/src/warplib/implementations/maths/bitwiseAnd.ts +++ b/src/warplib/implementations/maths/bitwiseAnd.ts @@ -1,12 +1,7 @@ import { BinaryOperation } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { Implicits } from '../../../utils/implicits'; import { IntxIntFunction } from '../../utils'; export function functionaliseBitwiseAnd(node: BinaryOperation, ast: AST): void { - const implicitsFn = (width: number): Implicits[] => { - if (width === 256) return ['range_check_ptr', 'bitwise_ptr']; - else return ['bitwise_ptr']; - }; - IntxIntFunction(node, 'bitwise_and', 'only256', false, false, implicitsFn, ast); + IntxIntFunction(node, 'bitwise_and', 'only256', false, false, ast); } diff --git a/src/warplib/implementations/maths/bitwiseNot.ts b/src/warplib/implementations/maths/bitwiseNot.ts index 36865dbed..f15c7b9dd 100644 --- a/src/warplib/implementations/maths/bitwiseNot.ts +++ b/src/warplib/implementations/maths/bitwiseNot.ts @@ -1,40 +1,35 @@ import { UnaryOperation } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { Implicits } from '../../../utils/implicits'; -import { forAllWidths, generateFile, IntFunction, mask } from '../../utils'; +import { forAllWidths, IntFunction, mask, WarplibFunctionInfo } from '../../utils'; -export function bitwise_not(): void { - generateFile( - 'bitwise_not', - [ +export function bitwise_not(): WarplibFunctionInfo { + return { + fileName: 'bitwise_not', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_xor', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.uint256 import Uint256, uint256_not', ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ 'func warp_bitwise_not256{range_check_ptr}(op : Uint256) -> (res : Uint256){', ' let (res) = uint256_not(op);', ' return (res,);', '}', - ]; + ].join('\n'); } else { return [ `func warp_bitwise_not${width}{bitwise_ptr : BitwiseBuiltin*}(op : felt) -> (res : felt){`, ` let (res) = bitwise_xor(op, ${mask(width)});`, ` return (res,);`, '}', - ]; + ].join('\n'); } }), - ); + }; } export function functionaliseBitwiseNot(node: UnaryOperation, ast: AST): void { - const implicitsFn = (wide: boolean): Implicits[] => { - if (wide) return ['range_check_ptr']; - else return ['bitwise_ptr']; - }; - IntFunction(node, node.vSubExpression, 'bitwise_not', 'bitwise_not', implicitsFn, ast); + IntFunction(node, node.vSubExpression, 'bitwise_not', 'bitwise_not', ast); } diff --git a/src/warplib/implementations/maths/bitwiseOr.ts b/src/warplib/implementations/maths/bitwiseOr.ts index 5162994a4..a1acdf678 100644 --- a/src/warplib/implementations/maths/bitwiseOr.ts +++ b/src/warplib/implementations/maths/bitwiseOr.ts @@ -1,12 +1,7 @@ import { BinaryOperation } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { Implicits } from '../../../utils/implicits'; import { IntxIntFunction } from '../../utils'; export function functionaliseBitwiseOr(node: BinaryOperation, ast: AST): void { - const implicitsFn = (width: number): Implicits[] => { - if (width === 256) return ['range_check_ptr', 'bitwise_ptr']; - else return ['bitwise_ptr']; - }; - IntxIntFunction(node, 'bitwise_or', 'only256', false, false, implicitsFn, ast); + IntxIntFunction(node, 'bitwise_or', 'only256', false, false, ast); } diff --git a/src/warplib/implementations/maths/div.ts b/src/warplib/implementations/maths/div.ts index d13e22028..ec292617c 100644 --- a/src/warplib/implementations/maths/div.ts +++ b/src/warplib/implementations/maths/div.ts @@ -1,13 +1,12 @@ import { BinaryOperation } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { Implicits } from '../../../utils/implicits'; import { mapRange } from '../../../utils/utils'; -import { forAllWidths, generateFile, IntxIntFunction, mask } from '../../utils'; +import { forAllWidths, IntxIntFunction, mask, WarplibFunctionInfo } from '../../utils'; -export function div_signed() { - generateFile( - 'div_signed', - [ +export function div_signed(): WarplibFunctionInfo { + return { + fileName: 'div_signed', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_div_rem, uint256_eq', @@ -18,7 +17,7 @@ export function div_signed() { ).join(', ')}, ${mapRange(31, (n) => `warp_int256_to_int${8 * n + 8}`).join(', ')}`, `from warplib.maths.mul_signed import ${mapRange(32, (n) => `warp_mul_signed${8 * n + 8}`)}`, ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ 'func warp_div_signed256{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){', @@ -35,7 +34,7 @@ export function div_signed() { ' let (res : Uint256, _) = uint256_signed_div_rem(lhs, rhs);', ' return (res,);', '}', - ]; + ].join('\n'); } else { return [ `func warp_div_signed${width}{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, @@ -55,16 +54,16 @@ export function div_signed() { ` let (truncated) = warp_int256_to_int${width}(res256);`, ` return (truncated,);`, '}', - ]; + ].join('\n'); } }), - ); + }; } -export function div_signed_unsafe() { - generateFile( - 'div_signed_unsafe', - [ +export function div_signed_unsafe(): WarplibFunctionInfo { + return { + fileName: 'div_signed_unsafe', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_div_rem, uint256_eq', @@ -78,7 +77,7 @@ export function div_signed_unsafe() { (n) => `warp_mul_signed_unsafe${8 * n + 8}`, )}`, ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ 'func warp_div_signed_unsafe256{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){', @@ -95,7 +94,7 @@ export function div_signed_unsafe() { ' let (res : Uint256, _) = uint256_signed_div_rem(lhs, rhs);', ' return (res,);', '}', - ]; + ].join('\n'); } else { return [ `func warp_div_signed_unsafe${width}{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, @@ -115,18 +114,12 @@ export function div_signed_unsafe() { ` let (truncated) = warp_int256_to_int${width}(res256);`, ` return (truncated,);`, '}', - ]; + ].join('\n'); } }), - ); + }; } export function functionaliseDiv(node: BinaryOperation, unsafe: boolean, ast: AST): void { - const implicitsFn = (width: number, signed: boolean): Implicits[] => { - if (signed || (unsafe && width >= 128 && width < 256)) - return ['range_check_ptr', 'bitwise_ptr']; - else if (unsafe && width < 128) return ['bitwise_ptr']; - else return ['range_check_ptr']; - }; - IntxIntFunction(node, 'div', 'signedOrWide', true, unsafe, implicitsFn, ast); + IntxIntFunction(node, 'div', 'signedOrWide', true, unsafe, ast); } diff --git a/src/warplib/implementations/maths/eq.ts b/src/warplib/implementations/maths/eq.ts index 508bf413e..520e6c1d1 100644 --- a/src/warplib/implementations/maths/eq.ts +++ b/src/warplib/implementations/maths/eq.ts @@ -1,12 +1,7 @@ import { BinaryOperation } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { Implicits } from '../../../utils/implicits'; import { Comparison } from '../../utils'; export function functionaliseEq(node: BinaryOperation, ast: AST): void { - const implicitsFn = (wide: boolean): Implicits[] => { - if (wide) return ['range_check_ptr']; - return []; - }; - Comparison(node, 'eq', 'only256', false, implicitsFn, ast); + Comparison(node, 'eq', 'only256', false, ast); } diff --git a/src/warplib/implementations/maths/exp.ts b/src/warplib/implementations/maths/exp.ts index 90cbfc235..daed4d75d 100644 --- a/src/warplib/implementations/maths/exp.ts +++ b/src/warplib/implementations/maths/exp.ts @@ -8,32 +8,31 @@ import { } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; import { printNode, printTypeNode } from '../../../utils/astPrinter'; -import { createCairoFunctionStub } from '../../../utils/functionGeneration'; import { safeGetNodeType } from '../../../utils/nodeTypeProcessing'; import { mapRange, typeNameFromTypeNode } from '../../../utils/utils'; -import { forAllWidths, generateFile, getIntOrFixedByteBitWidth, mask } from '../../utils'; +import { forAllWidths, getIntOrFixedByteBitWidth, mask, WarplibFunctionInfo } from '../../utils'; export function exp() { - createExp(false, false); + return createExp(false, false); } export function exp_signed() { - createExp(true, false); + return createExp(true, false); } export function exp_unsafe() { - createExp(false, true); + return createExp(false, true); } export function exp_signed_unsafe() { - createExp(true, true); + return createExp(true, true); } -function createExp(signed: boolean, unsafe: boolean) { +function createExp(signed: boolean, unsafe: boolean): WarplibFunctionInfo { const suffix = `${signed ? '_signed' : ''}${unsafe ? '_unsafe' : ''}`; - generateFile( - `exp${suffix}`, - [ + return { + fileName: `exp${suffix}`, + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.uint256 import Uint256, uint256_sub', @@ -42,7 +41,7 @@ function createExp(signed: boolean, unsafe: boolean) { (n) => `warp_mul${suffix}${8 * n + 8}`, ).join(', ')}`, ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ `func _repeated_multiplication${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(op : Uint256, count : felt) -> (res : Uint256){`, @@ -86,7 +85,7 @@ function createExp(signed: boolean, unsafe: boolean) { ` let (res) = _repeated_multiplication_256_${width}(lhs, rhs);`, ` return (res,);`, `}`, - ]; + ].join('\n'); } else { return [ `func _repeated_multiplication${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(op : felt, count : felt) -> (res : felt){`, @@ -138,10 +137,10 @@ function createExp(signed: boolean, unsafe: boolean) { ` let (res) = _repeated_multiplication_256_${width}(lhs, rhs);`, ` return (res,);`, '}', - ]; + ].join('\n'); } }), - ); + }; } function getNegativeOneShortcutCode(signed: boolean, lhsWidth: number, rhsWide: boolean): string[] { @@ -192,17 +191,17 @@ export function functionaliseExp(node: BinaryOperation, unsafe: boolean, ast: AS unsafe ? '_unsafe' : '', ].join(''); - const stub = createCairoFunctionStub( + const importedFunc = ast.registerImport( + node, + importName, fullName, [ ['lhs', typeNameFromTypeNode(lhsType, ast)], ['rhs', typeNameFromTypeNode(rhsType, ast)], ], [['res', typeNameFromTypeNode(retType, ast)]], - ['range_check_ptr', 'bitwise_ptr'], - ast, - node, ); + const call = new FunctionCall( ast.reserveId(), node.src, @@ -213,11 +212,10 @@ export function functionaliseExp(node: BinaryOperation, unsafe: boolean, ast: AS '', `function (${node.typeString}, ${node.typeString}) returns (${node.typeString})`, fullName, - stub.id, + importedFunc.id, ), [node.vLeftExpression, node.vRightExpression], ); ast.replaceNode(node, call); - ast.registerImport(call, importName, fullName); } diff --git a/src/warplib/implementations/maths/ge.ts b/src/warplib/implementations/maths/ge.ts index e1893d2aa..426e9df1d 100644 --- a/src/warplib/implementations/maths/ge.ts +++ b/src/warplib/implementations/maths/ge.ts @@ -1,13 +1,12 @@ import { BinaryOperation } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { Implicits } from '../../../utils/implicits'; import { mapRange } from '../../../utils/utils'; -import { generateFile, forAllWidths, Comparison } from '../../utils'; +import { forAllWidths, Comparison, WarplibFunctionInfo } from '../../utils'; -export function ge_signed() { - generateFile( - 'ge_signed', - [ +export function ge_signed(): WarplibFunctionInfo { + return { + fileName: 'ge_signed', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_le', @@ -15,14 +14,14 @@ export function ge_signed() { ', ', )}`, ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ 'func warp_ge_signed256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : felt){', ' let (res) = uint256_signed_le(rhs, lhs);', ' return (res,);', '}', - ]; + ].join('\n'); } else { return [ `func warp_ge_signed${width}{bitwise_ptr : BitwiseBuiltin*, range_check_ptr}(`, @@ -30,16 +29,12 @@ export function ge_signed() { ` let (res) = warp_le_signed${width}(rhs, lhs);`, ` return (res,);`, '}', - ]; + ].join('\n'); } }), - ); + }; } export function functionaliseGe(node: BinaryOperation, ast: AST): void { - const implicitsFn = (wide: boolean, signed: boolean): Implicits[] => { - if (wide || !signed) return ['range_check_ptr']; - else return ['range_check_ptr', 'bitwise_ptr']; - }; - Comparison(node, 'ge', 'signedOrWide', true, implicitsFn, ast); + Comparison(node, 'ge', 'signedOrWide', true, ast); } diff --git a/src/warplib/implementations/maths/gt.ts b/src/warplib/implementations/maths/gt.ts index 64e4e6eac..7764c5aca 100644 --- a/src/warplib/implementations/maths/gt.ts +++ b/src/warplib/implementations/maths/gt.ts @@ -1,13 +1,12 @@ import { BinaryOperation } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { Implicits } from '../../../utils/implicits'; import { mapRange } from '../../../utils/utils'; -import { generateFile, forAllWidths, Comparison } from '../../utils'; +import { forAllWidths, Comparison, WarplibFunctionInfo } from '../../utils'; -export function gt_signed() { - generateFile( - 'gt_signed', - [ +export function gt_signed(): WarplibFunctionInfo { + return { + fileName: 'gt_signed', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_lt', @@ -15,14 +14,14 @@ export function gt_signed() { ', ', )}`, ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ 'func warp_gt_signed256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : felt){', ' let (res) = uint256_signed_lt(rhs, lhs);', ' return (res,);', '}', - ]; + ].join('\n'); } else { return [ `func warp_gt_signed${width}{bitwise_ptr : BitwiseBuiltin*, range_check_ptr}(`, @@ -30,16 +29,12 @@ export function gt_signed() { ` let (res) = warp_lt_signed${width}(rhs, lhs);`, ` return (res,);`, '}', - ]; + ].join('\n'); } }), - ); + }; } export function functionaliseGt(node: BinaryOperation, ast: AST): void { - const implicitsFn = (wide: boolean, signed: boolean): Implicits[] => { - if (wide || !signed) return ['range_check_ptr']; - else return ['range_check_ptr', 'bitwise_ptr']; - }; - Comparison(node, 'gt', 'signedOrWide', true, implicitsFn, ast); + Comparison(node, 'gt', 'signedOrWide', true, ast); } diff --git a/src/warplib/implementations/maths/le.ts b/src/warplib/implementations/maths/le.ts index e37c95475..9e18a1f5a 100644 --- a/src/warplib/implementations/maths/le.ts +++ b/src/warplib/implementations/maths/le.ts @@ -1,25 +1,24 @@ import { BinaryOperation } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { Implicits } from '../../../utils/implicits'; -import { generateFile, forAllWidths, msb, Comparison } from '../../utils'; +import { forAllWidths, msb, Comparison, WarplibFunctionInfo } from '../../utils'; -export function le_signed() { - generateFile( - 'le_signed', - [ +export function le_signed(): WarplibFunctionInfo { + return { + fileName: 'le_signed', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.math_cmp import is_le_felt', 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_le', ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ `func warp_le_signed${width}{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : felt){`, ' let (res) = uint256_signed_le(lhs, rhs);', ' return (res,);', '}', - ]; + ].join('\n'); } else { return [ `func warp_le_signed${width}{bitwise_ptr : BitwiseBuiltin*, range_check_ptr}(`, @@ -51,17 +50,12 @@ export function le_signed() { ` }`, ` }`, `}`, - ]; + ].join('\n'); } }), - ); + }; } export function functionaliseLe(node: BinaryOperation, ast: AST): void { - const implicitsFn = (wide: boolean, signed: boolean): Implicits[] => { - if (!wide && signed) return ['range_check_ptr', 'bitwise_ptr']; - else return ['range_check_ptr']; - }; - - Comparison(node, 'le', 'signedOrWide', true, implicitsFn, ast); + Comparison(node, 'le', 'signedOrWide', true, ast); } diff --git a/src/warplib/implementations/maths/lt.ts b/src/warplib/implementations/maths/lt.ts index 8d43d21ab..b58212b8c 100644 --- a/src/warplib/implementations/maths/lt.ts +++ b/src/warplib/implementations/maths/lt.ts @@ -1,13 +1,12 @@ import { BinaryOperation } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { Implicits } from '../../../utils/implicits'; import { mapRange } from '../../../utils/utils'; -import { generateFile, forAllWidths, Comparison } from '../../utils'; +import { forAllWidths, Comparison, WarplibFunctionInfo } from '../../utils'; -export function lt_signed() { - generateFile( - 'lt_signed', - [ +export function lt_signed(): WarplibFunctionInfo { + return { + fileName: 'lt_signed', + imports: [ 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_lt', 'from warplib.maths.utils import felt_to_uint256', @@ -15,14 +14,14 @@ export function lt_signed() { ', ', )}`, ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ 'func warp_lt_signed256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : felt){', ' let (res) = uint256_signed_lt(lhs, rhs);', ' return (res,);', '}', - ]; + ].join('\n'); } else { return [ `func warp_lt_signed${width}{bitwise_ptr : BitwiseBuiltin*, range_check_ptr}(`, @@ -33,17 +32,12 @@ export function lt_signed() { ` let (res) = warp_le_signed${width}(lhs, rhs);`, ` return (res,);`, '}', - ]; + ].join('\n'); } }), - ); + }; } export function functionaliseLt(node: BinaryOperation, ast: AST): void { - const implicitsFn = (wide: boolean, signed: boolean): Implicits[] => { - if (!wide && signed) return ['range_check_ptr', 'bitwise_ptr']; - else return ['range_check_ptr']; - }; - - Comparison(node, 'lt', 'signedOrWide', true, implicitsFn, ast); + Comparison(node, 'lt', 'signedOrWide', true, ast); } diff --git a/src/warplib/implementations/maths/mod.ts b/src/warplib/implementations/maths/mod.ts index 45b8148b7..6cae243d4 100644 --- a/src/warplib/implementations/maths/mod.ts +++ b/src/warplib/implementations/maths/mod.ts @@ -1,13 +1,12 @@ import { BinaryOperation } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { Implicits } from '../../../utils/implicits'; import { mapRange } from '../../../utils/utils'; -import { forAllWidths, generateFile, IntxIntFunction } from '../../utils'; +import { forAllWidths, IntxIntFunction, WarplibFunctionInfo } from '../../utils'; -export function mod_signed() { - generateFile( - 'mod_signed', - [ +export function mod_signed(): WarplibFunctionInfo { + return { + fileName: 'mod_signed', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.uint256 import Uint256, uint256_signed_div_rem', @@ -17,7 +16,7 @@ export function mod_signed() { (n) => `warp_int${8 * n + 8}_to_int256`, ).join(', ')}, ${mapRange(31, (n) => `warp_int256_to_int${8 * n + 8}`).join(', ')}`, ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ 'func warp_mod_signed256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){', @@ -29,7 +28,7 @@ export function mod_signed() { ' let (_, res : Uint256) = uint256_signed_div_rem(lhs, rhs);', ' return (res,);', '}', - ]; + ].join('\n'); } else { return [ `func warp_mod_signed${width}{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, @@ -45,16 +44,12 @@ export function mod_signed() { ` let (truncated) = warp_int256_to_int${width}(res256);`, ` return (truncated,);`, '}', - ]; + ].join('\n'); } }), - ); + }; } export function functionaliseMod(node: BinaryOperation, ast: AST): void { - const implicits = (width: number, signed: boolean): Implicits[] => { - if (width !== 256 && signed) return ['range_check_ptr', 'bitwise_ptr']; - return ['range_check_ptr']; - }; - IntxIntFunction(node, 'mod', 'signedOrWide', true, false, implicits, ast); + IntxIntFunction(node, 'mod', 'signedOrWide', true, false, ast); } diff --git a/src/warplib/implementations/maths/mul.ts b/src/warplib/implementations/maths/mul.ts index 58e9131bd..738c5ba2b 100644 --- a/src/warplib/implementations/maths/mul.ts +++ b/src/warplib/implementations/maths/mul.ts @@ -1,9 +1,7 @@ import { BinaryOperation } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { Implicits } from '../../../utils/implicits'; import { mapRange } from '../../../utils/utils'; import { - generateFile, forAllWidths, uint256, pow2, @@ -11,18 +9,19 @@ import { mask, msb, IntxIntFunction, + WarplibFunctionInfo, } from '../../utils'; -export function mul(): void { - generateFile( - 'mul', - [ +export function mul(): WarplibFunctionInfo { + return { + fileName: 'mul', + imports: [ 'from starkware.cairo.common.uint256 import Uint256, uint256_mul', 'from starkware.cairo.common.math_cmp import is_le_felt', 'from warplib.maths.ge import warp_ge256', 'from warplib.maths.utils import felt_to_uint256', ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ 'func warp_mul256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){', @@ -31,7 +30,7 @@ export function mul(): void { ' assert overflow.high = 0;', ' return (result,);', '}', - ]; + ].join('\n'); } else if (width >= 128) { return [ `func warp_mul${width}{range_check_ptr}(lhs : felt, rhs : felt) -> (res : felt){`, @@ -43,7 +42,7 @@ export function mul(): void { ' assert outOfRange = 0;', ` return (res.low + ${bound(128)} * res.high,);`, '}', - ]; + ].join('\n'); } else { return [ `func warp_mul${width}{range_check_ptr}(lhs : felt, rhs : felt) -> (res : felt){`, @@ -52,29 +51,29 @@ export function mul(): void { ' assert inRange = 1;', ' return (res,);', '}', - ]; + ].join('\n'); } }), - ); + }; } -export function mul_unsafe(): void { - generateFile( - 'mul_unsafe', - [ +export function mul_unsafe(): WarplibFunctionInfo { + return { + fileName: 'mul_unsafe', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.uint256 import Uint256, uint256_mul', 'from warplib.maths.utils import felt_to_uint256', ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ `func warp_mul_unsafe256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){`, ` let (res : Uint256, _) = uint256_mul(lhs, rhs);`, ` return (res,);`, `}`, - ]; + ].join('\n'); } else if (width >= 128) { return [ `func warp_mul_unsafe${width}{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, @@ -85,23 +84,23 @@ export function mul_unsafe(): void { ` let (high) = bitwise_and(res.high, ${mask(width - 128)});`, ` return (res.low + ${bound(128)} * high,);`, `}`, - ]; + ].join('\n'); } else { return [ `func warp_mul_unsafe${width}{bitwise_ptr : BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (res : felt){`, ` let (res) = bitwise_and(lhs * rhs, ${mask(width)});`, ` return (res,);`, '}', - ]; + ].join('\n'); } }), - ); + }; } -export function mul_signed(): void { - generateFile( - 'mul_signed', - [ +export function mul_signed(): WarplibFunctionInfo { + return { + fileName: 'mul_signed', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.math_cmp import is_le_felt', @@ -113,7 +112,7 @@ export function mul_signed(): void { ', ', )}`, ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ `func warp_mul_signed256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, @@ -142,7 +141,7 @@ export function mul_signed(): void { ` return (res_abs,);`, ` }`, `}`, - ]; + ].join('\n'); } else { return [ `func warp_mul_signed${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, @@ -183,16 +182,16 @@ export function mul_signed(): void { ` }`, ` }`, `}`, - ]; + ].join('\n'); } }), - ); + }; } -export function mul_signed_unsafe(): void { - generateFile( - 'mul_signed_unsafe', - [ +export function mul_signed_unsafe(): WarplibFunctionInfo { + return { + fileName: 'mul_signed_unsafe', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.math_cmp import is_le_felt', @@ -206,7 +205,7 @@ export function mul_signed_unsafe(): void { )}`, 'from warplib.maths.utils import felt_to_uint256', ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ `func warp_mul_signed_unsafe256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, @@ -220,7 +219,7 @@ export function mul_signed_unsafe(): void { ` let (res) = uint256_cond_neg(res_abs, (lhs_nn + rhs_nn) * (2 - lhs_nn - rhs_nn));`, ` return (res,);`, `}`, - ]; + ].join('\n'); } else { return [ `func warp_mul_signed_unsafe${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, @@ -237,18 +236,12 @@ export function mul_signed_unsafe(): void { ` return (res,);`, ` }`, `}`, - ]; + ].join('\n'); } }), - ); + }; } export function functionaliseMul(node: BinaryOperation, unsafe: boolean, ast: AST): void { - const implicitsFn = (width: number, signed: boolean): Implicits[] => { - if (signed || (unsafe && width >= 128 && width < 256)) - return ['range_check_ptr', 'bitwise_ptr']; - else if (unsafe && width < 128) return ['bitwise_ptr']; - else return ['range_check_ptr']; - }; - IntxIntFunction(node, 'mul', 'always', true, unsafe, implicitsFn, ast); + IntxIntFunction(node, 'mul', 'always', true, unsafe, ast); } diff --git a/src/warplib/implementations/maths/negate.ts b/src/warplib/implementations/maths/negate.ts index 845201212..388b91138 100644 --- a/src/warplib/implementations/maths/negate.ts +++ b/src/warplib/implementations/maths/negate.ts @@ -1,25 +1,24 @@ import { UnaryOperation } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { Implicits } from '../../../utils/implicits'; -import { bound, forAllWidths, generateFile, IntFunction, mask } from '../../utils'; +import { bound, forAllWidths, IntFunction, mask, WarplibFunctionInfo } from '../../utils'; // This satisfies the solidity convention of -type(intX).min = type(intX).min -export function negate(): void { - generateFile( - 'negate', - [ +export function negate(): WarplibFunctionInfo { + return { + fileName: 'negate', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.uint256 import Uint256, uint256_neg', ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ 'func warp_negate256{range_check_ptr}(op : Uint256) -> (res : Uint256){', ' let (res) = uint256_neg(op);', ' return (res,);', '}', - ]; + ].join('\n'); } else { // Could also have if op == 0: 0 else limit-op return [ @@ -28,16 +27,12 @@ export function negate(): void { ` let (res) = bitwise_and(raw_res, ${mask(width)});`, ` return (res,);`, `}`, - ]; + ].join('\n'); } }), - ); + }; } export function functionaliseNegate(node: UnaryOperation, ast: AST): void { - const implicitsFn = (wide: boolean): Implicits[] => { - if (wide) return ['range_check_ptr']; - else return ['bitwise_ptr']; - }; - IntFunction(node, node.vSubExpression, 'negate', 'negate', implicitsFn, ast); + IntFunction(node, node.vSubExpression, 'negate', 'negate', ast); } diff --git a/src/warplib/implementations/maths/neq.ts b/src/warplib/implementations/maths/neq.ts index 344beefad..0d62da748 100644 --- a/src/warplib/implementations/maths/neq.ts +++ b/src/warplib/implementations/maths/neq.ts @@ -1,13 +1,7 @@ import { BinaryOperation } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { Implicits } from '../../../utils/implicits'; import { Comparison } from '../../utils'; export function functionaliseNeq(node: BinaryOperation, ast: AST): void { - const implicitsFn = (wide: boolean): Implicits[] => { - if (wide) return ['range_check_ptr']; - else return []; - }; - - Comparison(node, 'neq', 'only256', false, implicitsFn, ast); + Comparison(node, 'neq', 'only256', false, ast); } diff --git a/src/warplib/implementations/maths/shl.ts b/src/warplib/implementations/maths/shl.ts index 3b6a5b7bb..9ee6f870a 100644 --- a/src/warplib/implementations/maths/shl.ts +++ b/src/warplib/implementations/maths/shl.ts @@ -9,17 +9,16 @@ import { } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; import { printNode, printTypeNode } from '../../../utils/astPrinter'; -import { createCairoFunctionStub } from '../../../utils/functionGeneration'; import { safeGetNodeType } from '../../../utils/nodeTypeProcessing'; import { typeNameFromTypeNode } from '../../../utils/utils'; -import { generateFile, forAllWidths, getIntOrFixedByteBitWidth } from '../../utils'; +import { forAllWidths, getIntOrFixedByteBitWidth, WarplibFunctionInfo } from '../../utils'; // rhs is always unsigned, and signed and unsigned shl are the same -export function shl(): void { +export function shl(): WarplibFunctionInfo { //Need to provide an implementation with 256bit rhs and <256bit lhs - generateFile( - 'shl', - [ + return { + fileName: 'shl', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.math import split_felt', @@ -27,7 +26,7 @@ export function shl(): void { 'from starkware.cairo.common.uint256 import Uint256, uint256_shl', 'from warplib.maths.pow2 import pow2', ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ 'func warp_shl256{range_check_ptr}(lhs : Uint256, rhs : felt) -> (result : Uint256){', @@ -39,7 +38,7 @@ export function shl(): void { ' let (res) = uint256_shl(lhs, rhs);', ' return (res,);', '}', - ]; + ].join('\n'); } else { return [ `func warp_shl${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, @@ -65,10 +64,10 @@ export function shl(): void { ` return (0,);`, ` }`, `}`, - ]; + ].join('\n'); } }), - ); + }; } export function functionaliseShl(node: BinaryOperation, ast: AST): void { @@ -91,16 +90,15 @@ export function functionaliseShl(node: BinaryOperation, ast: AST): void { const importName = 'warplib.maths.shl'; - const stub = createCairoFunctionStub( + const importedFunc = ast.registerImport( + node, + importName, fullName, [ ['lhs', typeNameFromTypeNode(lhsType, ast)], ['rhs', typeNameFromTypeNode(rhsType, ast)], ], [['res', typeNameFromTypeNode(retType, ast)]], - lhsWidth === 256 ? ['range_check_ptr'] : ['range_check_ptr', 'bitwise_ptr'], - ast, - node, ); const call = new FunctionCall( ast.reserveId(), @@ -112,11 +110,10 @@ export function functionaliseShl(node: BinaryOperation, ast: AST): void { '', `function (${node.vLeftExpression.typeString}, ${node.vRightExpression.typeString}) returns (${node.typeString})`, fullName, - stub.id, + importedFunc.id, ), [node.vLeftExpression, node.vRightExpression], ); ast.replaceNode(node, call); - ast.registerImport(call, importName, fullName); } diff --git a/src/warplib/implementations/maths/shr.ts b/src/warplib/implementations/maths/shr.ts index ac4e3e2a0..a42152d4e 100644 --- a/src/warplib/implementations/maths/shr.ts +++ b/src/warplib/implementations/maths/shr.ts @@ -9,29 +9,28 @@ import { } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; import { printNode, printTypeNode } from '../../../utils/astPrinter'; -import { createCairoFunctionStub } from '../../../utils/functionGeneration'; import { safeGetNodeType } from '../../../utils/nodeTypeProcessing'; import { mapRange, typeNameFromTypeNode } from '../../../utils/utils'; import { - generateFile, forAllWidths, bound, msb, mask, getIntOrFixedByteBitWidth, + WarplibFunctionInfo, } from '../../utils'; -export function shr(): void { - generateFile( - 'shr', - [ +export function shr(): WarplibFunctionInfo { + return { + fileName: 'shr', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and, bitwise_not', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.math_cmp import is_le, is_le_felt', 'from starkware.cairo.common.uint256 import Uint256, uint256_and', 'from warplib.maths.pow2 import pow2', ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ `func warp_shr256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : felt) -> (`, @@ -95,7 +94,7 @@ export function shr(): void { ` }`, ` return (Uint256(0, 0),);`, `}`, - ]; + ].join('\n'); } else { return [ `func warp_shr${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, @@ -122,16 +121,16 @@ export function shr(): void { ` return (0,);`, ` }`, `}`, - ]; + ].join('\n'); } }), - ); + }; } -export function shr_signed(): void { - generateFile( - 'shr_signed', - [ +export function shr_signed(): WarplibFunctionInfo { + return { + fileName: 'shr_signed', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.math_cmp import is_le, is_le_felt', @@ -139,7 +138,7 @@ export function shr_signed(): void { 'from warplib.maths.pow2 import pow2', `from warplib.maths.shr import ${mapRange(32, (n) => `warp_shr${8 * n + 8}`).join(', ')}`, ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ `func warp_shr_signed256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : felt) -> (res : Uint256){`, @@ -177,7 +176,7 @@ export function shr_signed(): void { ` return (res,);`, ` }`, `}`, - ]; + ].join('\n'); } else { return [ `func warp_shr_signed${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(`, @@ -211,10 +210,10 @@ export function shr_signed(): void { ` return (res,);`, ` }`, `}`, - ]; + ].join('\n'); } }), - ); + }; } export function functionaliseShr(node: BinaryOperation, ast: AST): void { @@ -240,16 +239,15 @@ export function functionaliseShr(node: BinaryOperation, ast: AST): void { const importName = `warplib.maths.shr${signed ? '_signed' : ''}`; - const stub = createCairoFunctionStub( + const importedFunc = ast.registerImport( + node, + importName, fullName, [ ['lhs', typeNameFromTypeNode(lhsType, ast)], ['rhs', typeNameFromTypeNode(rhsType, ast)], ], [['res', typeNameFromTypeNode(retType, ast)]], - ['range_check_ptr', 'bitwise_ptr'], - ast, - node, ); const call = new FunctionCall( ast.reserveId(), @@ -261,11 +259,10 @@ export function functionaliseShr(node: BinaryOperation, ast: AST): void { '', `function (${node.vLeftExpression.typeString}, ${node.vRightExpression.typeString}) returns (${node.typeString})`, fullName, - stub.id, + importedFunc.id, ), [node.vLeftExpression, node.vRightExpression], ); ast.replaceNode(node, call); - ast.registerImport(call, importName, fullName); } diff --git a/src/warplib/implementations/maths/sub.ts b/src/warplib/implementations/maths/sub.ts index 5a394ef57..d62feb3cb 100644 --- a/src/warplib/implementations/maths/sub.ts +++ b/src/warplib/implementations/maths/sub.ts @@ -2,27 +2,26 @@ import assert from 'assert'; import { BinaryOperation, IntType } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; import { printTypeNode } from '../../../utils/astPrinter'; -import { Implicits } from '../../../utils/implicits'; import { safeGetNodeType } from '../../../utils/nodeTypeProcessing'; import { - generateFile, forAllWidths, bound, mask, msb, msbAndNext, IntxIntFunction, + WarplibFunctionInfo, } from '../../utils'; -export function sub_unsafe(): void { - generateFile( - 'sub_unsafe', - [ +export function sub_unsafe(): WarplibFunctionInfo { + return { + fileName: 'sub_unsafe', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.uint256 import Uint256', ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ `func warp_sub_unsafe256{bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (`, @@ -42,7 +41,7 @@ export function sub_unsafe(): void { ` return (Uint256(low_safe, high),);`, ` }`, `}`, - ]; + ].join('\n'); } else { return [ `func warp_sub_unsafe${width}{bitwise_ptr : BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (`, @@ -51,21 +50,21 @@ export function sub_unsafe(): void { ` let (res) = bitwise_and(res, ${mask(width)});`, ` return (res,);`, `}`, - ]; + ].join('\n'); } }), - ); + }; } -export function sub_signed(): void { - generateFile( - 'sub_signed', - [ +export function sub_signed(): WarplibFunctionInfo { + return { + fileName: 'sub_signed', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.uint256 import Uint256, uint256_add, uint256_signed_le, uint256_sub, uint256_not', ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ `func warp_sub_signed${width}{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (`, @@ -91,7 +90,7 @@ export function sub_signed(): void { ` // Narrow and return`, ` return (res,);`, `}`, - ]; + ].join('\n'); } else { return [ `func warp_sub_signed${width}{bitwise_ptr : BitwiseBuiltin*}(lhs : felt, rhs : felt) -> (`, @@ -114,28 +113,28 @@ export function sub_signed(): void { ` let (res) = bitwise_and(extended_res, ${mask(width)});`, ` return (res,);`, `}`, - ]; + ].join('\n'); } }), - ); + }; } -export function sub_signed_unsafe(): void { - generateFile( - 'sub_signed_unsafe', - [ +export function sub_signed_unsafe(): WarplibFunctionInfo { + return { + fileName: 'sub_signed_unsafe', + imports: [ 'from starkware.cairo.common.bitwise import bitwise_and', 'from starkware.cairo.common.cairo_builtins import BitwiseBuiltin', 'from starkware.cairo.common.uint256 import Uint256, uint256_sub', ], - forAllWidths((width) => { + functions: forAllWidths((width) => { if (width === 256) { return [ 'func warp_sub_signed_unsafe256{range_check_ptr}(lhs : Uint256, rhs : Uint256) -> (res : Uint256){', ' let (res) = uint256_sub(lhs, rhs);', ' return (res,);', '}', - ]; + ].join('\n'); } else { return [ `func warp_sub_signed_unsafe${width}{bitwise_ptr : BitwiseBuiltin*}(`, @@ -154,37 +153,24 @@ export function sub_signed_unsafe(): void { ` let (res) = bitwise_and(extended_res, ${mask(width)});`, ` return (res,);`, `}`, - ]; + ].join('\n'); } }), - ); + }; } //func warp_sub{range_check_ptr}(lhs : felt, rhs : felt) -> (res : felt): //func warp_sub256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}(lhs : Uint256, rhs : Uint256) -> (res : Uint256): export function functionaliseSub(node: BinaryOperation, unsafe: boolean, ast: AST): void { - const implicitsFn = (width: number, signed: boolean): Implicits[] => { - if (signed) { - if (width === 256) return ['range_check_ptr', 'bitwise_ptr']; - else return ['bitwise_ptr']; - } else { - if (unsafe) { - return ['bitwise_ptr']; - } else { - if (width === 256) return ['range_check_ptr', 'bitwise_ptr']; - else return ['range_check_ptr']; - } - } - }; const typeNode = safeGetNodeType(node, ast.inference); assert( typeNode instanceof IntType, `Expected IntType for subtraction, got ${printTypeNode(typeNode)}`, ); if (unsafe) { - IntxIntFunction(node, 'sub', 'always', true, unsafe, implicitsFn, ast); + IntxIntFunction(node, 'sub', 'always', true, unsafe, ast); } else { - IntxIntFunction(node, 'sub', 'signedOrWide', true, unsafe, implicitsFn, ast); + IntxIntFunction(node, 'sub', 'signedOrWide', true, unsafe, ast); } } diff --git a/src/warplib/implementations/maths/xor.ts b/src/warplib/implementations/maths/xor.ts index 30cc0ca25..5097ad341 100644 --- a/src/warplib/implementations/maths/xor.ts +++ b/src/warplib/implementations/maths/xor.ts @@ -1,12 +1,7 @@ import { BinaryOperation } from 'solc-typed-ast'; import { AST } from '../../../ast/ast'; -import { Implicits } from '../../../utils/implicits'; import { IntxIntFunction } from '../../utils'; export function functionaliseXor(node: BinaryOperation, ast: AST): void { - const implicitsFn = (width: number): Implicits[] => { - if (width === 256) return ['range_check_ptr', 'bitwise_ptr']; - else return ['bitwise_ptr']; - }; - IntxIntFunction(node, 'xor', 'only256', false, false, implicitsFn, ast); + IntxIntFunction(node, 'xor', 'only256', false, false, ast); } diff --git a/src/warplib/utils.ts b/src/warplib/utils.ts index 889acb297..9cb9162ee 100644 --- a/src/warplib/utils.ts +++ b/src/warplib/utils.ts @@ -13,13 +13,18 @@ import { } from 'solc-typed-ast'; import { AST } from '../ast/ast'; import { printNode, printTypeNode } from '../utils/astPrinter'; -import { createCairoFunctionStub } from '../utils/functionGeneration'; -import { Implicits } from '../utils/implicits'; import { mapRange, typeNameFromTypeNode } from '../utils/utils'; import { safeGetNodeType } from '../utils/nodeTypeProcessing'; +import path from 'path'; -export function forAllWidths(funcGen: (width: number) => string[]): string[] { - return mapRange(32, (n) => 8 * (n + 1)).flatMap(funcGen); +export type WarplibFunctionInfo = { + fileName: string; + imports: string[]; + functions: string[]; +}; + +export function forAllWidths(funcGen: (width: number) => T): T[] { + return mapRange(32, (n) => 8 * (n + 1)).map(funcGen); } export function pow2(n: number): bigint { @@ -54,10 +59,12 @@ export function msbAndNext(width: number): string { // This is used along with the commented out code in generateFile to enable cairo-formatting // const warpVenvPrefix = `PATH=${path.resolve(__dirname, '..', '..', 'warp_venv', 'bin')}:$PATH`; -export function generateFile(name: string, imports: string[], functions: string[]): void { +export function generateFile(warpFunc: WarplibFunctionInfo): void { + const pathToFile = path.join('.', 'warplib', 'maths', `${warpFunc.fileName}.cairo`); + fs.writeFileSync( - `./warplib/maths/${name}.cairo`, - `//AUTO-GENERATED\n${imports.join('\n')}\n\n${functions.join('\n')}\n`, + pathToFile, + `//AUTO-GENERATED\n${warpFunc.imports.join('\n')}\n\n${warpFunc.functions.join('\n')}\n`, ); // Disable cairo-formatting for now, as it has a bug that breaks the generated code // execSync(`${warpVenvPrefix} cairo-format -i ./warplib/maths/${name}.cairo`); @@ -69,7 +76,6 @@ export function IntxIntFunction( appendWidth: 'always' | 'only256' | 'signedOrWide', separateSigned: boolean, unsafe: boolean, - implicits: (width: number, signed: boolean) => Implicits[], ast: AST, ) { const lhsType = typeNameFromTypeNode(safeGetNodeType(node.vLeftExpression, ast.inference), ast); @@ -99,17 +105,17 @@ export function IntxIntFunction( unsafe ? '_unsafe' : '', ].join(''); - const stub = createCairoFunctionStub( + const importedFunc = ast.registerImport( + node, + importName, fullName, [ ['lhs', lhsType], ['rhs', rhsType], ], [['res', typeNameFromTypeNode(retType, ast)]], - implicits(width, signed), - ast, - node, ); + const call = new FunctionCall( ast.reserveId(), node.src, @@ -120,13 +126,12 @@ export function IntxIntFunction( '', `function (${node.typeString}, ${node.typeString}) returns (${node.typeString})`, fullName, - stub.id, + importedFunc.id, ), [node.vLeftExpression, node.vRightExpression], ); ast.replaceNode(node, call); - ast.registerImport(call, importName, fullName); } export function Comparison( @@ -134,7 +139,6 @@ export function Comparison( name: string, appendWidth: 'only256' | 'signedOrWide', separateSigned: boolean, - implicits: (wide: boolean, signed: boolean) => Implicits[], ast: AST, ): void { const lhsType = safeGetNodeType(node.vLeftExpression, ast.inference); @@ -154,16 +158,15 @@ export function Comparison( const importName = `warplib.maths.${name}${signed && separateSigned ? '_signed' : ''}`; - const stub = createCairoFunctionStub( + const importedFunc = ast.registerImport( + node, + importName, fullName, [ ['lhs', typeNameFromTypeNode(lhsType, ast)], ['rhs', typeNameFromTypeNode(rhsType, ast)], ], [['res', typeNameFromTypeNode(retType, ast)]], - implicits(wide, signed), - ast, - node, ); const call = new FunctionCall( @@ -176,13 +179,12 @@ export function Comparison( '', `function (${node.vLeftExpression.typeString}, ${node.vRightExpression.typeString}) returns (${node.typeString})`, fullName, - stub.id, + importedFunc.id, ), [node.vLeftExpression, node.vRightExpression], ); ast.replaceNode(node, call); - ast.registerImport(call, importName, fullName); } export function IntFunction( @@ -190,7 +192,6 @@ export function IntFunction( argument: Expression, name: string, fileName: string, - implicits: (wide: boolean) => Implicits[], ast: AST, ): void { const opType = safeGetNodeType(argument, ast.inference); @@ -201,13 +202,13 @@ export function IntFunction( ); const width = getIntOrFixedByteBitWidth(retType); const fullName = `warp_${name}${width}`; - const stub = createCairoFunctionStub( + + const importedFunc = ast.registerImport( + node, + `warplib.maths.${fileName}`, fullName, [['op', typeNameFromTypeNode(opType, ast)]], [['res', typeNameFromTypeNode(retType, ast)]], - implicits(width === 256), - ast, - node, ); const call = new FunctionCall( @@ -220,13 +221,12 @@ export function IntFunction( '', `function (${argument.typeString}) returns (${node.typeString})`, fullName, - stub.id, + importedFunc.id, ), [argument], ); ast.replaceNode(node, call); - ast.registerImport(call, `warplib.maths.${fileName}`, fullName); } export function BoolxBoolFunction(node: BinaryOperation, name: string, ast: AST): void { @@ -248,16 +248,13 @@ export function BoolxBoolFunction(node: BinaryOperation, name: string, ast: AST) ); const fullName = `warp_${name}`; - const stub = createCairoFunctionStub( - fullName, - [ - ['lhs', typeNameFromTypeNode(lhsType, ast)], - ['rhs', typeNameFromTypeNode(rhsType, ast)], - ], - [['res', typeNameFromTypeNode(retType, ast)]], - [], - ast, + + const importedFunc = ast.registerImport( node, + `warplib.maths.${name}`, + fullName, + [['lhs', typeNameFromTypeNode(lhsType, ast)]], + [['rhs', typeNameFromTypeNode(rhsType, ast)]], ); const call = new FunctionCall( @@ -270,13 +267,12 @@ export function BoolxBoolFunction(node: BinaryOperation, name: string, ast: AST) '', `function (${node.vLeftExpression.typeString}, ${node.vRightExpression.typeString}) returns (${node.typeString})`, fullName, - stub.id, + importedFunc.id, ), [node.vLeftExpression, node.vRightExpression], ); ast.replaceNode(node, call); - ast.registerImport(call, `warplib.maths.${name}`, fullName); } export function getIntOrFixedByteBitWidth(type: TypeNode): number { diff --git a/tests/testing.ts b/tests/testing.ts index a64b9d288..80f6948ee 100644 --- a/tests/testing.ts +++ b/tests/testing.ts @@ -55,7 +55,6 @@ const expectedResults = new Map( ['exampleContracts/boolOpSideEffects.sol', 'Success'], ['exampleContracts/bytesXAccess.sol', 'Success'], ['exampleContracts/c2c.sol', 'Success'], - // Uses conditionals explicitly ['exampleContracts/conditional.sol', 'Success'], ['exampleContracts/conditionalSimple.sol', 'Success'], ['exampleContracts/contractToContract.sol', 'Success'], @@ -82,9 +81,7 @@ const expectedResults = new Map( ['exampleContracts/externalFunction.sol', 'Success'], ['exampleContracts/fallbackWithoutArgs.sol', 'Success'], ['exampleContracts/fallbackWithArgs.sol', 'WillNotSupport'], - // Cannot import with a - in the filename ['exampleContracts/fileWithMinusSignIncluded-.sol', 'Success'], - // Typestring for the internal function call doesn't contain a location so a read isn't generated ['exampleContracts/freeFunction.sol', 'Success'], ['exampleContracts/freeStruct.sol', 'Success'], ['exampleContracts/functionWithNestedReturn.sol', 'Success'], @@ -135,9 +132,9 @@ const expectedResults = new Map( ['exampleContracts/inheritance/super/derived.sol', 'Success'], ['exampleContracts/inheritance/super/mid.sol', 'Success'], ['exampleContracts/inheritance/variables.sol', 'Success'], - // Requires struct imports ['exampleContracts/interfaces.sol', 'Success'], ['exampleContracts/interfaceFromBaseContract.sol', 'Success'], + ['exampleContracts/internalFunctions.sol', 'Success'], ['exampleContracts/invalidSolidity.sol', 'SolCompileFailed'], ['exampleContracts/lib.sol', 'Success'], ['exampleContracts/libraries/usingForStar.sol', 'Success'], diff --git a/yarn.lock b/yarn.lock index 499dcf1a4..f1f33e889 100644 --- a/yarn.lock +++ b/yarn.lock @@ -764,6 +764,14 @@ dependencies: "@types/node" "*" +"@types/glob@^8.1.0": + version "8.1.0" + resolved "https://registry.yarnpkg.com/@types/glob/-/glob-8.1.0.tgz#b63e70155391b0584dce44e7ea25190bbc38f2fc" + integrity sha512-IO+MJPVhoqz+28h1qLAcBEH2+xHMK6MTyHJc7MTnnYb6wsoLR29POVGJ7LycmVXIqyy/4/2ShP5sUwTXuOwb/w== + dependencies: + "@types/minimatch" "^5.1.2" + "@types/node" "*" + "@types/json-schema@^7.0.9": version "7.0.11" resolved "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.11.tgz" @@ -781,6 +789,11 @@ resolved "https://registry.npmjs.org/@types/minimatch/-/minimatch-3.0.5.tgz" integrity sha512-Klz949h02Gz2uZCMGwDUSDS1YBlTdDDgbWHi+81l29tQALUtvz4rAYi5uoVhE5Lagoq6DeqAUlbrHvW/mXDgdQ== +"@types/minimatch@^5.1.2": + version "5.1.2" + resolved "https://registry.yarnpkg.com/@types/minimatch/-/minimatch-5.1.2.tgz#07508b45797cb81ec3f273011b054cd0755eddca" + integrity sha512-K0VQKziLUWkVKiRVrx4a40iPaxTUefQmjtkQofBkYRcoaaL/8rhwDWww9qWbrgicNOgnpIsMxyNIUM4+n6dUIA== + "@types/mocha@^9.1.0": version "9.1.0" resolved "https://registry.npmjs.org/@types/mocha/-/mocha-9.1.0.tgz" @@ -1200,6 +1213,13 @@ brace-expansion@^1.1.7: balanced-match "^1.0.0" concat-map "0.0.1" +brace-expansion@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/brace-expansion/-/brace-expansion-2.0.1.tgz#1edc459e0f0c548486ecf9fc99f2221364b9a0ae" + integrity sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA== + dependencies: + balanced-match "^1.0.0" + braces@^3.0.2, braces@~3.0.2: version "3.0.2" resolved "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz" @@ -2655,6 +2675,17 @@ glob@7.2.0, glob@^7.1.3: once "^1.3.0" path-is-absolute "^1.0.0" +glob@^8.1.0: + version "8.1.0" + resolved "https://registry.yarnpkg.com/glob/-/glob-8.1.0.tgz#d388f656593ef708ee3e34640fdfb99a9fd1c33e" + integrity sha512-r8hpEjiQEYlF2QU0df3dS+nxxSIreXQS1qRhMJM0Q5NDdR386C7jb7Hwwod8Fgiuex+k0GFjgft18yvxm5XoCQ== + dependencies: + fs.realpath "^1.0.0" + inflight "^1.0.4" + inherits "2" + minimatch "^5.0.1" + once "^1.3.0" + global-modules@^1.0.0: version "1.0.0" resolved "https://registry.npmjs.org/global-modules/-/global-modules-1.0.0.tgz" @@ -3593,6 +3624,13 @@ minimatch@^3.0.4, minimatch@^3.0.5, minimatch@^3.1.2: dependencies: brace-expansion "^1.1.7" +minimatch@^5.0.1: + version "5.1.6" + resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-5.1.6.tgz#1cfcb8cf5522ea69952cd2af95ae09477f122a96" + integrity sha512-lKwV/1brpG6mBUFHtb7NUmtABCb2WZZmm2wNiOA5hAb8VdCS4B3dtMWyvcoViccwAW/COERjXLt0zP1zXUN26g== + dependencies: + brace-expansion "^2.0.1" + minimist@^1.2.6: version "1.2.6" resolved "https://registry.npmjs.org/minimist/-/minimist-1.2.6.tgz"