Skip to content
This repository has been archived by the owner on Jun 11, 2024. It is now read-only.

Commit

Permalink
Added LSK check on beforeCrossChainMessageForwarding (#8806)
Browse files Browse the repository at this point in the history
* Added LSK check on beforeCrossChainMessageForwarding

* Implemented function getMessageFeeTokenIDFromCCM and migrated code

* use sendingChainID to align with LIP

* Add chainID == ownChainAccount.chainID check

* Rearrange ownChainAccount check

* Rearrange code inside _getChannelCommon

* Update method name
  • Loading branch information
Phanco authored Aug 10, 2023
1 parent 47e3191 commit 3a280f7
Show file tree
Hide file tree
Showing 12 changed files with 93 additions and 36 deletions.
10 changes: 2 additions & 8 deletions framework/src/modules/fee/cc_method.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ export class FeeInteroperableMethod extends BaseCCMethod {
}

public async beforeCrossChainCommandExecute(ctx: CrossChainMessageContext): Promise<void> {
const messageTokenID = await this._interopMethod.getMessageFeeTokenID(
ctx,
ctx.ccm.sendingChainID,
);
const messageTokenID = await this._interopMethod.getMessageFeeTokenIDFromCCM(ctx, ctx.ccm);
await this._tokenMethod.lock(
ctx.getMethodContext(),
ctx.transaction.senderAddress,
Expand All @@ -59,10 +56,7 @@ export class FeeInteroperableMethod extends BaseCCMethod {
}

public async afterCrossChainCommandExecute(ctx: CrossChainMessageContext): Promise<void> {
const messageTokenID = await this._interopMethod.getMessageFeeTokenID(
ctx,
ctx.ccm.sendingChainID,
);
const messageTokenID = await this._interopMethod.getMessageFeeTokenIDFromCCM(ctx, ctx.ccm);
await this._tokenMethod.unlock(
ctx.getMethodContext(),
ctx.transaction.senderAddress,
Expand Down
2 changes: 2 additions & 0 deletions framework/src/modules/fee/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import { MethodContext, ImmutableMethodContext } from '../../state_machine/types';
import { JSONObject } from '../../types';
import { CCMsg } from '../interoperability';

export type FeeTokenID = Buffer;

Expand Down Expand Up @@ -76,4 +77,5 @@ export interface GetMinFeePerByteResponse {

export interface InteroperabilityMethod {
getMessageFeeTokenID(methodContext: ImmutableMethodContext, chainID: Buffer): Promise<Buffer>;
getMessageFeeTokenIDFromCCM(methodContext: ImmutableMethodContext, ccm: CCMsg): Promise<Buffer>;
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,13 @@ export abstract class BaseInteroperabilityMethod<
}

private async _getChannelCommon(context: ImmutableMethodContext, chainID: Buffer) {
const mainchainID = getMainchainID(chainID);
const ownChainAccount = await this.getOwnChainAccount(context);
const hasChainAccount = await this.stores.get(ChainAccountStore).has(context, chainID);
if (chainID.equals(ownChainAccount.chainID)) {
throw new Error('Channel with own chain account does not exist.');
}

const mainchainID = getMainchainID(chainID);
const hasChainAccount = await this.stores.get(ChainAccountStore).has(context, chainID);
let updatedChainID = chainID;
// Check for direct channel while processing on a sidechain
if (!ownChainAccount.chainID.equals(mainchainID) && !hasChainAccount) {
Expand All @@ -120,6 +123,13 @@ export abstract class BaseInteroperabilityMethod<
return channel.messageFeeTokenID;
}

public async getMessageFeeTokenIDFromCCM(
context: ImmutableMethodContext,
ccm: CCMsg,
): Promise<Buffer> {
return this.getMessageFeeTokenID(context, ccm.sendingChainID);
}

// https://github.com/LiskHQ/lips/blob/main/proposals/lip-0045.md#getminreturnfeeperbyte
public async getMinReturnFeePerByte(
context: ImmutableMethodContext,
Expand Down
20 changes: 9 additions & 11 deletions framework/src/modules/token/cc_method.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import { RecoverEvent } from './events/recover';
import { EMPTY_BYTES } from '../interoperability/constants';
import { BeforeCCMForwardingEvent } from './events/before_ccm_forwarding';
import { splitTokenID } from './utils';
import { getEncodedCCMAndID } from '../interoperability/utils';
import { getEncodedCCMAndID, getTokenIDLSK } from '../interoperability/utils';
import { InternalMethod } from './internal_method';

export class TokenInteroperableMethod extends BaseCCMethod {
Expand All @@ -45,10 +45,7 @@ export class TokenInteroperableMethod extends BaseCCMethod {
ccm,
} = ctx;
const methodContext = ctx.getMethodContext();
const tokenID = await this._interopMethod.getMessageFeeTokenID(
methodContext,
ccm.sendingChainID,
);
const tokenID = await this._interopMethod.getMessageFeeTokenIDFromCCM(methodContext, ccm);
const { ccmID } = getEncodedCCMAndID(ccm);
const [chainID] = splitTokenID(tokenID);
const userStore = this.stores.get(UserStore);
Expand Down Expand Up @@ -88,12 +85,16 @@ export class TokenInteroperableMethod extends BaseCCMethod {
public async beforeCrossChainMessageForwarding(ctx: CrossChainMessageContext): Promise<void> {
const { ccm } = ctx;
const methodContext = ctx.getMethodContext();
const messageFeeTokenID = await this._interopMethod.getMessageFeeTokenID(
const messageFeeTokenID = await this._interopMethod.getMessageFeeTokenIDFromCCM(
methodContext,
ccm.receivingChainID,
ccm,
);
const { ccmID } = getEncodedCCMAndID(ccm);

if (!messageFeeTokenID.equals(getTokenIDLSK(ctx.chainID))) {
throw new Error('Message fee token should be LSK.');
}

const escrowStore = this.stores.get(EscrowStore);
const escrowKey = escrowStore.getKey(ccm.sendingChainID, messageFeeTokenID);
const escrowAccount = await escrowStore.getOrDefault(methodContext, escrowKey);
Expand Down Expand Up @@ -128,10 +129,7 @@ export class TokenInteroperableMethod extends BaseCCMethod {
public async verifyCrossChainMessage(ctx: CrossChainMessageContext): Promise<void> {
const { ccm } = ctx;
const methodContext = ctx.getMethodContext();
const tokenID = await this._interopMethod.getMessageFeeTokenID(
methodContext,
ccm.sendingChainID,
);
const tokenID = await this._interopMethod.getMessageFeeTokenIDFromCCM(methodContext, ccm);
const [chainID] = splitTokenID(tokenID);
if (chainID.equals(ctx.chainID)) {
const escrowStore = this.stores.get(EscrowStore);
Expand Down
1 change: 1 addition & 0 deletions framework/src/modules/token/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ export interface InteroperabilityMethod {
terminateChain(methodContext: MethodContext, chainID: Buffer): Promise<void>;
getChannel(methodContext: MethodContext, chainID: Buffer): Promise<ChannelData>;
getMessageFeeTokenID(methodContext: ImmutableMethodContext, chainID: Buffer): Promise<Buffer>;
getMessageFeeTokenIDFromCCM(methodContext: ImmutableMethodContext, ccm: CCMsg): Promise<Buffer>;
}

export interface FeeMethod {
Expand Down
1 change: 1 addition & 0 deletions framework/test/unit/modules/fee/cc_method.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ describe('FeeInteroperableMethod', () => {
feeMethod.addDependencies(
{
getMessageFeeTokenID: jest.fn().mockResolvedValue(messageFeeTokenID),
getMessageFeeTokenIDFromCCM: jest.fn().mockResolvedValue(messageFeeTokenID),
},
{
burn: jest.fn(),
Expand Down
9 changes: 9 additions & 0 deletions framework/test/unit/modules/interoperability/method.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,15 @@ describe('Sample Method', () => {
expect(channelDataStoreMock.get).toHaveBeenCalledWith(expect.anything(), newChainID);
});

it('should throw error if chainID equals ownChainAccount.chainID', async () => {
jest.spyOn(ownChainAccountStoreMock, 'get').mockResolvedValue({
chainID: newChainID,
});
await expect(
sampleInteroperabilityMethod['_getChannelCommon'](methodContext, newChainID),
).rejects.toThrow('Channel with own chain account does not exist.');
});

it('should throw error if channel is not found', async () => {
jest.spyOn(channelDataStoreMock, 'has').mockResolvedValue(false);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ describe('CrossChain Transfer Command', () => {
terminateChain: jest.fn(),
getChannel: jest.fn(),
getMessageFeeTokenID: jest.fn(),
getMessageFeeTokenIDFromCCM: jest.fn(),
};
const config = {
ownChainID,
Expand Down
67 changes: 52 additions & 15 deletions framework/test/unit/modules/token/cc_method.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ describe('TokenInteroperableMethod', () => {
'hex',
);
const defaultAddress = address.getAddressFromPublicKey(defaultPublicKey);
const ownChainID = Buffer.from([0, 0, 0, 1]);
const ownChainID = Buffer.from([1, 0, 0, 0]);
const defaultTokenID = Buffer.concat([ownChainID, Buffer.alloc(4)]);
const defaultForeignTokenID = Buffer.from([0, 0, 0, 2, 0, 0, 0, 0]);
const defaultForeignTokenID = Buffer.from([2, 0, 0, 0, 0, 0, 0, 0]);
const defaultAccount = {
availableBalance: BigInt(10000000000),
lockedBalances: [
Expand Down Expand Up @@ -119,6 +119,7 @@ describe('TokenInteroperableMethod', () => {
{
send: jest.fn().mockResolvedValue(true),
getMessageFeeTokenID: jest.fn().mockResolvedValue(defaultTokenID),
getMessageFeeTokenIDFromCCM: jest.fn().mockResolvedValue(defaultTokenID),
} as never,
internalMethod,
);
Expand Down Expand Up @@ -153,7 +154,7 @@ describe('TokenInteroperableMethod', () => {
describe('beforeCrossChainCommandExecute', () => {
it('should credit fee to transaction sender if token id is not native', async () => {
jest
.spyOn(tokenInteropMethod['_interopMethod'], 'getMessageFeeTokenID')
.spyOn(tokenInteropMethod['_interopMethod'], 'getMessageFeeTokenIDFromCCM')
.mockResolvedValue(defaultForeignTokenID);
await expect(
tokenInteropMethod.beforeCrossChainCommandExecute({
Expand All @@ -162,7 +163,7 @@ describe('TokenInteroperableMethod', () => {
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
Expand Down Expand Up @@ -205,7 +206,7 @@ describe('TokenInteroperableMethod', () => {
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee: fee + defaultEscrowAmount,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
Expand Down Expand Up @@ -243,7 +244,7 @@ describe('TokenInteroperableMethod', () => {
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
Expand Down Expand Up @@ -285,6 +286,42 @@ describe('TokenInteroperableMethod', () => {
});

describe('beforeCrossChainMessageForwarding', () => {
it('should throw if messageFeeTokenID is not LSK', async () => {
jest
.spyOn(tokenInteropMethod['_interopMethod'], 'getMessageFeeTokenIDFromCCM')
.mockResolvedValue(defaultForeignTokenID);
await expect(
tokenInteropMethod.beforeCrossChainMessageForwarding({
ccm: {
crossChainCommand: CROSS_CHAIN_COMMAND_NAME_TRANSFER,
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: ownChainID,
fee: fee + defaultEscrowAmount,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
},
getMethodContext: () => methodContext,
eventQueue: new EventQueue(0),
getStore: (moduleID: Buffer, prefix: Buffer) => stateStore.getStore(moduleID, prefix),
logger: fakeLogger,
chainID: ownChainID,
header: {
timestamp: Date.now(),
height: 10,
},
stateStore,
contextStore,
transaction: {
fee,
senderAddress: defaultAddress,
params: defaultEncodedCCUParams,
},
}),
).rejects.toThrow('Message fee token should be LSK.');
});

it('should throw if escrow balance is not sufficient', async () => {
await expect(
tokenInteropMethod.beforeCrossChainMessageForwarding({
Expand All @@ -293,7 +330,7 @@ describe('TokenInteroperableMethod', () => {
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee: fee + defaultEscrowAmount,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
Expand Down Expand Up @@ -323,15 +360,15 @@ describe('TokenInteroperableMethod', () => {
);
});

it('should deduct escrow account for fee and credit to receving chain escrow account if ccm command is not transfer', async () => {
it('should deduct escrow account for fee and credit to receiving chain escrow account if ccm command is not transfer', async () => {
await expect(
tokenInteropMethod.beforeCrossChainMessageForwarding({
ccm: {
crossChainCommand: CROSS_CHAIN_COMMAND_REGISTRATION,
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee,
status: CCM_STATUS_OK,
params: codec.encode(crossChainForwardMessageParams, {
Expand Down Expand Up @@ -370,7 +407,7 @@ describe('TokenInteroperableMethod', () => {
expect(amount).toEqual(defaultEscrowAmount - fee);
const { amount: receiver } = await escrowStore.get(
methodContext,
escrowStore.getKey(Buffer.from([0, 0, 0, 1]), defaultTokenID),
escrowStore.getKey(ownChainID, defaultTokenID),
);
expect(receiver).toEqual(fee);
});
Expand All @@ -385,7 +422,7 @@ describe('TokenInteroperableMethod', () => {
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
Expand Down Expand Up @@ -418,7 +455,7 @@ describe('TokenInteroperableMethod', () => {
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee: fee + defaultEscrowAmount,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
Expand All @@ -445,7 +482,7 @@ describe('TokenInteroperableMethod', () => {

it('should resolve if token id is not native', async () => {
jest
.spyOn(tokenInteropMethod['_interopMethod'], 'getMessageFeeTokenID')
.spyOn(tokenInteropMethod['_interopMethod'], 'getMessageFeeTokenIDFromCCM')
.mockResolvedValue(defaultForeignTokenID);
await expect(
tokenInteropMethod.verifyCrossChainMessage({
Expand All @@ -454,7 +491,7 @@ describe('TokenInteroperableMethod', () => {
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
Expand Down Expand Up @@ -549,7 +586,7 @@ describe('TokenInteroperableMethod', () => {

it('should reject if token is not native', async () => {
jest
.spyOn(tokenInteropMethod['_interopMethod'], 'getMessageFeeTokenID')
.spyOn(tokenInteropMethod['_interopMethod'], 'getMessageFeeTokenIDFromCCM')
.mockResolvedValue(defaultForeignTokenID);
await expect(
tokenInteropMethod.recover({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ describe('Transfer command', () => {
terminateChain: jest.fn(),
getChannel: jest.fn(),
getMessageFeeTokenID: jest.fn(),
getMessageFeeTokenIDFromCCM: jest.fn(),
};
internalMethod.addDependencies({ payFee: jest.fn() });
method.addDependencies(interopMethod, internalMethod);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ describe('CCTransfer command', () => {
terminateChain: jest.Mock;
getChannel: jest.Mock;
getMessageFeeTokenID: jest.Mock;
getMessageFeeTokenIDFromCCM: jest.Mock;
};

beforeEach(() => {
Expand All @@ -129,6 +130,7 @@ describe('CCTransfer command', () => {
terminateChain: jest.fn(),
getChannel: jest.fn(),
getMessageFeeTokenID: jest.fn(),
getMessageFeeTokenIDFromCCM: jest.fn(),
};
internalMethod.addDependencies({
payFee: jest.fn(),
Expand Down
1 change: 1 addition & 0 deletions framework/test/unit/modules/token/endpoint.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ describe('token endpoint', () => {
terminateChain: jest.fn(),
getChannel: jest.fn(),
getMessageFeeTokenID: jest.fn(),
getMessageFeeTokenIDFromCCM: jest.fn(),
},
internalMethod,
);
Expand Down

0 comments on commit 3a280f7

Please sign in to comment.