diff --git a/src/hooks/swap/useSyncTokenDefaults.test.ts b/src/hooks/swap/useSyncTokenDefaults.test.ts new file mode 100644 index 000000000..6952fa244 --- /dev/null +++ b/src/hooks/swap/useSyncTokenDefaults.test.ts @@ -0,0 +1,127 @@ +import { TradeType } from '@uniswap/sdk-core' +import { SupportedChainId } from 'constants/chains' +import { DAI_POLYGON, nativeOnChain } from 'constants/tokens' +import { useAtomValue } from 'jotai/utils' +import { Field, stateAtom, Swap, swapAtom } from 'state/swap' +import { renderHook } from 'test' + +import { USDC_MAINNET } from '../../constants/tokens' +import useSyncTokenDefaults, { TokenDefaults } from './useSyncTokenDefaults' + +const MOCK_DAI_POLYGON = DAI_POLYGON +const MOCK_USDC_MAINNET = USDC_MAINNET + +const INITIAL_SWAP: Swap = { + type: TradeType.EXACT_INPUT, + amount: '10', + [Field.INPUT]: MOCK_USDC_MAINNET, + [Field.OUTPUT]: MOCK_DAI_POLYGON, +} + +const TOKEN_DEFAULTS: TokenDefaults = { + defaultInputAmount: 10, + defaultInputTokenAddress: 'NATIVE', + defaultOutputTokenAddress: 'NATIVE', +} + +jest.mock('@web3-react/core', () => { + const { SupportedChainId } = jest.requireActual('constants/chains') + + return { + useWeb3React: () => ({ + chainId: SupportedChainId.MAINNET, + connector: {}, + }), + } +}) + +jest.mock('../useTokenList', () => { + return { + useIsTokenListLoaded: () => true, + } +}) + +jest.mock('hooks/useCurrency', () => { + return { + useToken: () => MOCK_DAI_POLYGON, + } +}) + +describe('useSyncTokenDefaults', () => { + it('syncs to default chainId on initial render if defaultChainId is provided', () => { + const { rerender } = renderHook( + () => { + useSyncTokenDefaults({ ...TOKEN_DEFAULTS, defaultChainId: SupportedChainId.POLYGON }) + }, + { + initialAtomValues: [[stateAtom, INITIAL_SWAP]], + } + ) + + const { result } = rerender(() => useAtomValue(swapAtom)) + expect(result.current).toMatchObject({ + ...INITIAL_SWAP, + INPUT: nativeOnChain(SupportedChainId.POLYGON), + OUTPUT: nativeOnChain(SupportedChainId.POLYGON), + }) + }) + + it('does not sync to default chainId on initial render if defaultChainId is not provided', () => { + const { rerender } = renderHook( + () => { + useSyncTokenDefaults(TOKEN_DEFAULTS) + }, + { + initialAtomValues: [[stateAtom, INITIAL_SWAP]], + } + ) + + const { result } = rerender(() => useAtomValue(swapAtom)) + expect(result.current).toMatchObject({ + ...INITIAL_SWAP, + INPUT: nativeOnChain(SupportedChainId.MAINNET), + OUTPUT: nativeOnChain(SupportedChainId.MAINNET), + }) + }) + + it('syncs to default non NATIVE tokens of default chainId on initial render if defaultChainId is provided', () => { + const { rerender } = renderHook( + () => { + useSyncTokenDefaults({ + ...TOKEN_DEFAULTS, + defaultInputTokenAddress: DAI_POLYGON.address, + defaultOutputTokenAddress: DAI_POLYGON.address, + defaultChainId: SupportedChainId.POLYGON, + }) + }, + { + initialAtomValues: [[stateAtom, INITIAL_SWAP]], + } + ) + + const { result } = rerender(() => useAtomValue(swapAtom)) + expect(result.current).toMatchObject({ + ...INITIAL_SWAP, + INPUT: DAI_POLYGON, + OUTPUT: DAI_POLYGON, + }) + }) + + it('syncs to non NATIVE tokens of chainId on initial render if defaultChainId is not provided', () => { + const { rerender } = renderHook( + () => { + useSyncTokenDefaults(TOKEN_DEFAULTS) + }, + { + initialAtomValues: [[stateAtom, INITIAL_SWAP]], + } + ) + + const { result } = rerender(() => useAtomValue(swapAtom)) + expect(result.current).toMatchObject({ + ...INITIAL_SWAP, + INPUT: nativeOnChain(SupportedChainId.MAINNET), + OUTPUT: nativeOnChain(SupportedChainId.MAINNET), + }) + }) +}) diff --git a/src/hooks/swap/useSyncTokenDefaults.ts b/src/hooks/swap/useSyncTokenDefaults.ts index 95762611a..a8459b6a6 100644 --- a/src/hooks/swap/useSyncTokenDefaults.ts +++ b/src/hooks/swap/useSyncTokenDefaults.ts @@ -1,10 +1,11 @@ import { Currency, TradeType } from '@uniswap/sdk-core' import { useWeb3React } from '@web3-react/core' +import { Connector } from '@web3-react/types' +import { SupportedChainId } from 'constants/chains' import { nativeOnChain } from 'constants/tokens' import { useToken } from 'hooks/useCurrency' -import useNativeCurrency from 'hooks/useNativeCurrency' import { useUpdateAtom } from 'jotai/utils' -import { useCallback, useEffect, useRef } from 'react' +import { useCallback, useEffect, useMemo, useRef } from 'react' import { Field, Swap, swapAtom } from 'state/swap' import useOnSupportedNetwork from '../useOnSupportedNetwork' @@ -17,27 +18,31 @@ export interface TokenDefaults { defaultInputAmount?: number | string defaultOutputTokenAddress?: DefaultAddress defaultOutputAmount?: number | string + defaultChainId?: SupportedChainId } function useDefaultToken( defaultAddress: DefaultAddress | undefined, - chainId: number | undefined + chainId: number | undefined, + defaultToNative: boolean ): Currency | undefined { - let address = undefined + let address: undefined | string = undefined if (typeof defaultAddress === 'string') { address = defaultAddress } else if (typeof defaultAddress === 'object' && chainId) { address = defaultAddress[chainId] } - const token = useToken(address) + const token = useToken(address, chainId) + const onSupportedNetwork = useOnSupportedNetwork(chainId) - const onSupportedNetwork = useOnSupportedNetwork() + return useMemo(() => { + // Only use native currency if chain ID is in supported chains. ExtendedEther will error otherwise. + if (chainId && onSupportedNetwork && (address === 'NATIVE' || (!token && defaultToNative))) { + return nativeOnChain(chainId) + } - // Only use native currency if chain ID is in supported chains. ExtendedEther will error otherwise. - if (chainId && address === 'NATIVE' && onSupportedNetwork) { - return nativeOnChain(chainId) - } - return token ?? undefined + return token ?? undefined + }, [address, chainId, defaultToNative, onSupportedNetwork, token]) } export default function useSyncTokenDefaults({ @@ -45,40 +50,60 @@ export default function useSyncTokenDefaults({ defaultInputAmount, defaultOutputTokenAddress, defaultOutputAmount, + defaultChainId, }: TokenDefaults) { + const lastChainId = useRef(undefined) + const lastConnector = useRef(undefined) const updateSwap = useUpdateAtom(swapAtom) - const { chainId } = useWeb3React() - const onSupportedNetwork = useOnSupportedNetwork() - const nativeCurrency = useNativeCurrency() - const defaultOutputToken = useDefaultToken(defaultOutputTokenAddress, chainId) - const defaultInputToken = - useDefaultToken(defaultInputTokenAddress, chainId) ?? - // Default the input token to the native currency if it is not the output token. - (defaultOutputToken !== nativeCurrency && onSupportedNetwork ? nativeCurrency : undefined) + const { chainId, connector } = useWeb3React() - const setToDefaults = useCallback(() => { - const defaultSwapState: Swap = { - amount: '', - [Field.INPUT]: defaultInputToken, - [Field.OUTPUT]: defaultOutputToken, - type: TradeType.EXACT_INPUT, - } - if (defaultInputToken && defaultInputAmount) { - defaultSwapState.amount = defaultInputAmount.toString() - } else if (defaultOutputToken && defaultOutputAmount) { - defaultSwapState.type = TradeType.EXACT_OUTPUT - defaultSwapState.amount = defaultOutputAmount.toString() - } - updateSwap((swap) => ({ ...swap, ...defaultSwapState })) - }, [defaultInputAmount, defaultInputToken, defaultOutputAmount, defaultOutputToken, updateSwap]) + const defaultOutputToken = useDefaultToken(defaultOutputTokenAddress, chainId, false) + const defaultChainIdOutputToken = useDefaultToken(defaultOutputTokenAddress, defaultChainId, false) + + const defaultInputToken = useDefaultToken(defaultInputTokenAddress, chainId, true) + const defaultChainIdInputToken = useDefaultToken(defaultInputTokenAddress, defaultChainId, true) + + const setToDefaults = useCallback( + (shouldUseDefaultChainId) => { + const defaultSwapState: Swap = { + amount: '', + [Field.INPUT]: shouldUseDefaultChainId ? defaultChainIdInputToken : defaultInputToken, + [Field.OUTPUT]: shouldUseDefaultChainId ? defaultChainIdOutputToken : defaultOutputToken, + type: TradeType.EXACT_INPUT, + } + + if (defaultInputToken && defaultInputAmount) { + defaultSwapState.amount = defaultInputAmount.toString() + } else if (defaultOutputToken && defaultOutputAmount) { + defaultSwapState.type = TradeType.EXACT_OUTPUT + defaultSwapState.amount = defaultOutputAmount.toString() + } + updateSwap((swap) => ({ ...swap, ...defaultSwapState })) + }, + [ + defaultChainIdInputToken, + defaultInputToken, + defaultChainIdOutputToken, + defaultOutputToken, + defaultInputAmount, + defaultOutputAmount, + updateSwap, + ] + ) - const lastChainId = useRef(undefined) const isTokenListLoaded = useIsTokenListLoaded() + useEffect(() => { - const shouldSync = isTokenListLoaded && chainId && chainId !== lastChainId.current + const isChainSwitched = chainId && chainId !== lastChainId.current + const isConnectorSwitched = connector && connector !== lastConnector.current + const shouldSync = isTokenListLoaded && (isChainSwitched || isConnectorSwitched) + const shouldUseDefaultChainId = isConnectorSwitched && defaultChainId + if (shouldSync) { - setToDefaults() + setToDefaults(shouldUseDefaultChainId) + lastChainId.current = chainId + lastConnector.current = connector } - }, [isTokenListLoaded, chainId, setToDefaults]) + }, [isTokenListLoaded, chainId, setToDefaults, connector, defaultChainId]) } diff --git a/src/hooks/useCurrency.ts b/src/hooks/useCurrency.ts index 7b307eccf..801515dd3 100644 --- a/src/hooks/useCurrency.ts +++ b/src/hooks/useCurrency.ts @@ -10,6 +10,7 @@ import { useMemo } from 'react' import { isAddress } from 'utils' import { supportedChainId } from 'utils/supportedChainId' +import { SupportedChainId } from '..' import { TokenMap, useTokenMap } from './useTokenList' // parse a name or symbol from a token response @@ -76,13 +77,17 @@ export function useTokenFromNetwork(tokenAddress: string | null | undefined): To * Returns null if token is loading or null was passed. * Returns undefined if tokenAddress is invalid or token does not exist. */ -export function useTokenFromMapOrNetwork(tokens: TokenMap, tokenAddress?: string | null): Token | null | undefined { +export function useTokenFromMapOrNetwork( + tokens: TokenMap, + tokenAddress?: string | null, + skipNetwork = false +): Token | null | undefined { const address = isAddress(tokenAddress) const token: Token | undefined = address ? tokens[address] : undefined const tokenFromNetwork = useTokenFromNetwork(token ? undefined : address ? address : undefined) - return tokenFromNetwork ?? token + return skipNetwork ? token : tokenFromNetwork || token } /** @@ -90,9 +95,12 @@ export function useTokenFromMapOrNetwork(tokens: TokenMap, tokenAddress?: string * Returns null if token is loading or null was passed. * Returns undefined if tokenAddress is invalid or token does not exist. */ -export function useToken(tokenAddress?: string | null): Token | null | undefined { - const tokens = useTokenMap() - return useTokenFromMapOrNetwork(tokens, tokenAddress) +export function useToken(tokenAddress?: string | null, chainId?: SupportedChainId): Token | null | undefined { + const { chainId: activeChainId } = useWeb3React() + + const tokens = useTokenMap(chainId) + const skipNetwork = chainId && chainId !== activeChainId + return useTokenFromMapOrNetwork(tokens, tokenAddress, skipNetwork) } /** diff --git a/src/hooks/useOnSupportedNetwork.ts b/src/hooks/useOnSupportedNetwork.ts index aeec47377..cf8bd26c9 100644 --- a/src/hooks/useOnSupportedNetwork.ts +++ b/src/hooks/useOnSupportedNetwork.ts @@ -1,9 +1,12 @@ import { useWeb3React } from '@web3-react/core' -import { ALL_SUPPORTED_CHAIN_IDS } from 'constants/chains' +import { ALL_SUPPORTED_CHAIN_IDS, SupportedChainId } from 'constants/chains' import { useMemo } from 'react' -function useOnSupportedNetwork() { - const { chainId } = useWeb3React() +function useOnSupportedNetwork(chainId?: SupportedChainId) { + const { chainId: activeChainId } = useWeb3React() + + chainId = chainId || activeChainId + return useMemo(() => Boolean(chainId && ALL_SUPPORTED_CHAIN_IDS.includes(chainId)), [chainId]) } diff --git a/src/hooks/useTokenList/index.tsx b/src/hooks/useTokenList/index.tsx index 34bac0a91..7bdb42f87 100644 --- a/src/hooks/useTokenList/index.tsx +++ b/src/hooks/useTokenList/index.tsx @@ -1,6 +1,7 @@ import { Token } from '@uniswap/sdk-core' import { TokenInfo, TokenList } from '@uniswap/token-lists' import { useWeb3React } from '@web3-react/core' +import { SupportedChainId } from 'constants/chains' import { createContext, PropsWithChildren, useCallback, useContext, useEffect, useMemo, useState } from 'react' import { WrappedTokenInfo } from 'state/lists/wrappedTokenInfo' import resolveENSContentHash from 'utils/resolveENSContentHash' @@ -41,8 +42,11 @@ export default function useTokenList(): WrappedTokenInfo[] { export type TokenMap = { [address: string]: Token } -export function useTokenMap(): TokenMap { - const { chainId } = useWeb3React() +export function useTokenMap(chainId?: SupportedChainId): TokenMap { + const { chainId: activeChainId } = useWeb3React() + + chainId = chainId || activeChainId + const chainTokenMap = useChainTokenMapContext() const tokenMap = chainId && chainTokenMap?.[chainId] return useMemo(() => {