Skip to content
This repository has been archived by the owner on Sep 5, 2023. It is now read-only.

Cross contract call #1034

Merged
merged 17 commits into from
May 15, 2023
1 change: 1 addition & 0 deletions src/ast/cairoNodes/cairoFunctionDefinition.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ export enum FunctionStubKind {
FunctionDefStub,
StorageDefStub,
StructDefStub,
TraitStructDefStub,
}

export class CairoFunctionDefinition extends FunctionDefinition {
Expand Down
3 changes: 2 additions & 1 deletion src/cairoUtilFuncGen/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ export abstract class CairoUtilFuncGenBase {
name: string,
inputs?: ParameterInfo[],
outputs?: ParameterInfo[],
options?: { isTrait?: boolean },
): CairoImportFunctionDefinition {
return createImport(location, name, this.sourceUnit, this.ast, inputs, outputs);
return createImport(location, name, this.sourceUnit, this.ast, inputs, outputs, options);
}
}

Expand Down
12 changes: 10 additions & 2 deletions src/cairoUtilFuncGen/memory/memoryRead.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
FunctionStateMutability,
generalizeType,
TypeNode,
FunctionDefinition,
} from 'solc-typed-ast';
import { CairoFunctionDefinition, typeNameFromTypeNode } from '../../export';
import {
Expand Down Expand Up @@ -101,9 +102,16 @@ export class MemoryReadGen extends StringIndexedFuncGen {
}

private getOrCreate(typeToRead: CairoType): GeneratedFunctionInfo {
const functionsCalled: FunctionDefinition[] = [this.requireImport(...DICT_READ)];

const funcName = `WM${this.generatedFunctionsDef.size}_READ_${typeToRead.typeName}`;
const resultCairoType = typeToRead.toString();
const [reads, pack] = serialiseReads(typeToRead, readFelt, readFelt);
const [reads, pack, requiredImports] = serialiseReads(typeToRead, readFelt, readFelt);

requiredImports.map((i) => {
functionsCalled.push(this.requireImport(...i.import, [], [], { isTrait: i.isTrait }));
});

const funcInfo: GeneratedFunctionInfo = {
name: funcName,
code: [
Expand All @@ -113,7 +121,7 @@ export class MemoryReadGen extends StringIndexedFuncGen {
` return (${pack},);`,
'}',
].join('\n'),
functionsCalled: [this.requireImport(...DICT_READ)],
functionsCalled: functionsCalled,
};
return funcInfo;
}
Expand Down
94 changes: 65 additions & 29 deletions src/cairoUtilFuncGen/serialisation.ts
rjnrohit marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import {
CairoContractAddress,
CairoBool,
CairoFelt,
CairoStaticArray,
Expand All @@ -9,28 +10,48 @@ import {
WarpLocation,
} from '../utils/cairoTypeSystem';
import { TranspileFailedError } from '../utils/errors';
import { FELT252_INTO_BOOL } from '../utils/importPaths';
import { CONTRACT_ADDRESS, FELT252_INTO_BOOL, OPTION_TRAIT } from '../utils/importPaths';

export function serialiseReads(
type: CairoType,
readFelt: (offset: number) => string,
readId: (offset: number) => string,
): [reads: string[], pack: string, requiredImports: [string[], string][]] {
): [
reads: string[],
pack: string,
requiredImports: { import: [string[], string]; isTrait?: boolean }[],
] {
const packExpression = producePackExpression(type);
const reads: string[] = [];
const requiredImports: [string[], string][] = [];
const requiredImports: { import: [string[], string]; isTrait?: boolean }[] = [];
const packString: string = packExpression
.map((elem: string | Read) => {
if (elem === Read.Felt) {
.map((elem: packExpReturnType) => {
if (elem.dataOrDataType === Read.UN) {
reads.push(readFelt(reads.length));
} else if (elem === Read.Id) {
reads.push(
`let read${reads.length} = core::integer::u${elem.metadata.nBits}_from_felt252(read${
reads.length - 1
});`,
);
} else if (elem.dataOrDataType === Read.Address) {
requiredImports.push({ import: OPTION_TRAIT, isTrait: true });
requiredImports.push({ import: CONTRACT_ADDRESS });
reads.push(readFelt(reads.length));
reads.push(
`let read${reads.length} = starknet::contract_address_try_from_felt252(read${
reads.length - 1
}).unwrap();`,
);
} else if (elem.dataOrDataType === Read.Felt) {
reads.push(readFelt(reads.length));
} else if (elem.dataOrDataType === Read.Id) {
reads.push(readId(reads.length));
} else if (elem === Read.Bool) {
requiredImports.push(FELT252_INTO_BOOL);
} else if (elem.dataOrDataType === Read.Bool) {
requiredImports.push({ import: FELT252_INTO_BOOL });
reads.push(readFelt(reads.length));
reads.push(`let read${reads.length} = felt252_into_bool(read${reads.length - 1});`);
} else {
return elem;
return elem.dataOrDataType;
}
return `read${reads.length - 1}`;
})
Expand All @@ -41,58 +62,73 @@ export function serialiseReads(
enum Read {
Felt,
Id,
Address,
UN,
Bool,
}

function producePackExpression(type: CairoType): (string | Read)[] {
if (type instanceof WarpLocation) return [Read.Id];
if (type instanceof CairoFelt) return [Read.Felt];
if (type instanceof CairoBool) return [Read.Bool];
interface packExpReturnType {
dataOrDataType: string | Read;
metadata: {
nBits?: number;
};
}

function createPackExpReturnType(dataOrDataType: string | Read): packExpReturnType {
return { dataOrDataType, metadata: {} };
}

function producePackExpression(type: CairoType): packExpReturnType[] {
if (type instanceof WarpLocation) return [createPackExpReturnType(Read.Id)];
if (type instanceof CairoFelt) return [createPackExpReturnType(Read.Felt)];
if (type instanceof CairoContractAddress) return [createPackExpReturnType(Read.Address)];
if (type instanceof CairoBool) return [createPackExpReturnType(Read.Bool)];
if (type instanceof CairoStaticArray) {
return [
'(',
createPackExpReturnType('('),
...Array(type.size)
.fill([...producePackExpression(type.type), ','])
.flat(),
')',
.flat()
.map(createPackExpReturnType),
createPackExpReturnType(')'),
];
}

if (type instanceof CairoUint) {
if (type.fullStringRepresentation === CairoUint256.fullStringRepresentation) {
return [
type.toString(),
'{',
createPackExpReturnType(type.toString()),
createPackExpReturnType('{'),
...[
['low', new CairoUint(128)],
['high', new CairoUint(128)],
]
.flatMap(([memberName, memberType]) => [
memberName as string,
':',
createPackExpReturnType(memberName as string),
createPackExpReturnType(':'),
...producePackExpression(memberType as CairoType),
',',
createPackExpReturnType(','),
])
.slice(0, -1),
'}',
createPackExpReturnType('}'),
];
}
return [`core::integer::${type.toString()}_from_felt252(${Read.Felt})`];
return [{ dataOrDataType: Read.UN, metadata: { nBits: type.nBits } }];
}

if (type instanceof CairoStruct) {
return [
type.name,
'{',
createPackExpReturnType(type.name),
createPackExpReturnType('{'),
...[...type.members.entries()]
.flatMap(([memberName, memberType]) => [
memberName,
':',
createPackExpReturnType(memberName),
createPackExpReturnType(':'),
...producePackExpression(memberType),
',',
createPackExpReturnType(','),
])
.slice(0, -1),
'}',
createPackExpReturnType('}'),
];
}

Expand Down
3 changes: 1 addition & 2 deletions src/cairoUtilFuncGen/storage/storageRead.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ export class StorageReadGen extends StringIndexedFuncGen {
const [reads, pack, requiredImports] = serialiseReads(typeToRead, readFelt, readId);

requiredImports.map((i) => {
const funcDef = this.requireImport(...i);
if (!functionsCalled.includes(funcDef)) functionsCalled.push(funcDef);
functionsCalled.push(this.requireImport(...i.import, [], [], { isTrait: i.isTrait }));
});

const funcInfo: GeneratedFunctionInfo = {
Expand Down
6 changes: 5 additions & 1 deletion src/cairoWriter/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ export function getInterfaceNameForContract(
contractName: string,
nodeInSourceUnit: ASTNode,
interfaceNameMappings: Map<SourceUnit, Map<string, string>>,
isDelegateCall = false,
dispatcherSuffix = true,
): string {
const sourceUnit =
nodeInSourceUnit instanceof SourceUnit
Expand All @@ -36,5 +38,7 @@ export function getInterfaceNameForContract(
`An error occurred during name substitution for the interface ${contractName}`,
);

return interfaceName;
return (
interfaceName + (dispatcherSuffix ? (isDelegateCall ? 'LibraryDispatcher' : 'Dispatcher') : '')
);
}
2 changes: 1 addition & 1 deletion src/cairoWriter/writers/assignmentWriter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ export class AssignmentWriter extends CairoASTNodeWriter {
assert(node.operator === '=', `Unexpected operator ${node.operator}`);
const [lhs, rhs] = [node.vLeftHandSide, node.vRightHandSide];
const nodes = [lhs, rhs].map((v) => writer.write(v));
return [`let ${nodes[0]} ${node.operator} ${nodes[1]};`];
return [`${nodes[0]} ${node.operator} ${nodes[1]}`];
}
}
19 changes: 13 additions & 6 deletions src/cairoWriter/writers/cairoContractWriter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,11 @@ export class CairoContractWriter extends CairoASTNodeWriter {

const contractHeader = '#[contract] \n' + `mod ${node.name} {`;

const globalImports = ['use starknet::ContractAddress;'].join('\n');
rjnrohit marked this conversation as resolved.
Show resolved Hide resolved

return [
[
globalImports,
contractHeader,
documentation,
writtenImportFuncs,
Expand Down Expand Up @@ -194,9 +197,10 @@ export class CairoContractWriter extends CairoASTNodeWriter {
// remove all content between any two pairing curly braces
.replace(/\{[^]*\}/g, '')
.split('\n');
const funcLineIndex = resultLines.findIndex((line) => line.startsWith('func'));
const funcLineIndex = resultLines.findIndex((line) => line.startsWith('fn'));
resultLines.splice(0, funcLineIndex);
return resultLines.join('\n') + '{\n}';
resultLines[0] = '#[external] ' + resultLines[0];
return resultLines.join('\n') + ';';
});
// Handle the workaround of genContractInterface function of externalContractInterfaceInserter.ts
// Remove `@interface` to get the actual contract interface name
Expand All @@ -206,14 +210,17 @@ export class CairoContractWriter extends CairoASTNodeWriter {
node.name.replace(TEMP_INTERFACE_SUFFIX, ''),
node,
interfaceNameMappings,
false,
false,
)
: node.name;

return [
[
documentation,
[`@contract_interface`, `namespace ${interfaceName}{`, ...functions, `}`].join('\n'),
].join('\n'),
endent`#[abi]
${documentation}
trait ${interfaceName}{
${functions.join('\n')}
}`,
];
}
}
17 changes: 8 additions & 9 deletions src/cairoWriter/writers/functionCallWriter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {
import { CairoASTNodeWriter } from '../base';
import { getDocumentation, getInterfaceNameForContract } from '../utils';
import { interfaceNameMappings } from './sourceUnitWriter';
import endent from 'endent';

export class FunctionCallWriter extends CairoASTNodeWriter {
writeInner(node: FunctionCall, writer: ASTWriter): SrcDesc {
Expand Down Expand Up @@ -50,15 +51,13 @@ export class FunctionCallWriter extends CairoASTNodeWriter {
isDelegateCall = true;
firstArg = classHashTextLets[1];
}
return [
`${getInterfaceNameForContract(
nodeType.definition.name,
node,
interfaceNameMappings,
)}.${(isDelegateCall ? 'library_call_' : '') + memberName}(${firstArg}${
args ? ', ' : ''
}${args})`,
];
const interfaceName = getInterfaceNameForContract(
nodeType.definition.name,
node,
interfaceNameMappings,
isDelegateCall,
);
return [endent`${interfaceName}{contract_address: ${firstArg}}.${memberName}(${args})`];
}
} else if (
node.vReferencedDeclaration instanceof CairoGeneratedFunctionDefinition &&
Expand Down
12 changes: 7 additions & 5 deletions src/cairoWriter/writers/variableDeclarationStatementWriter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import { isDynamicArray, safeGetNodeType } from '../../utils/nodeTypeProcessing'
import { isExternalCall } from '../../utils/utils';
import { CairoASTNodeWriter } from '../base';
import { getDocumentation } from '../utils';
import endent from 'endent';

export class VariableDeclarationStatementWriter extends CairoASTNodeWriter {
gapVarCounter = 0;
Expand Down Expand Up @@ -69,14 +70,15 @@ export class VariableDeclarationStatementWriter extends CairoASTNodeWriter {
});
if (declarations.length > 1) {
return [
[
documentation,
`let (${declarations.join(', ')}) = ${writer.write(node.vInitialValue)};`,
].join('\n'),
endent`${documentation}
let (${declarations.map((decl) => `mut ${decl}`).join(', ')}) = ${writer.write(
node.vInitialValue,
)};`,
];
}
return [
[documentation, `let ${declarations[0]} = ${writer.write(node.vInitialValue)};`].join('\n'),
endent`${documentation}
let mut ${declarations[0]} = ${writer.write(node.vInitialValue)};`,
];
}
}
2 changes: 1 addition & 1 deletion src/passes/builtinHandler/thisKeyword.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ export class ThisKeyword extends ASTMapper {
(contract) => contract.name === getTemporaryInterfaceName(currentContract.name),
);
if (contractIndex === -1) {
const insertedInterface = genContractInterface(currentContract, sourceUnit, ast);
const insertedInterface = genContractInterface(node, currentContract, sourceUnit, ast);
replaceInterfaceWithCairoContract(insertedInterface, ast);
}
}
Expand Down
Loading