diff --git a/packages/network/src/NetworkCanvas.tsx b/packages/network/src/NetworkCanvas.tsx index c97476bb7..ea15a7e5d 100644 --- a/packages/network/src/NetworkCanvas.tsx +++ b/packages/network/src/NetworkCanvas.tsx @@ -1,4 +1,13 @@ -import { useCallback, useRef, useEffect, createElement, MouseEvent, useMemo } from 'react' +import { + ForwardedRef, + forwardRef, + useCallback, + useRef, + useEffect, + createElement, + MouseEvent, + useMemo, +} from 'react' import { getDistance, getRelativeCursor, Container, useDimensions, useTheme } from '@nivo/core' import { useTooltip } from '@nivo/tooltip' import { useComputedAnnotations, renderAnnotationsToCanvas } from '@nivo/annotations' @@ -17,7 +26,9 @@ import { type InnerNetworkCanvasProps = Omit< NetworkCanvasProps, 'renderWrapper' | 'theme' -> +> & { + canvasRef: ForwardedRef +} const InnerNetworkCanvas = ({ width, @@ -56,6 +67,7 @@ const InnerNetworkCanvas = ({ defaultActiveNodeIds = canvasDefaultProps.defaultActiveNodeIds, nodeTooltip = canvasDefaultProps.nodeTooltip as NodeTooltip, onClick, + canvasRef, }: InnerNetworkCanvasProps) => { const canvasEl = useRef(null) const { margin, innerWidth, innerHeight, outerWidth, outerHeight } = useDimensions( @@ -202,7 +214,10 @@ const InnerNetworkCanvas = ({ return ( { + canvasEl.current = canvas + if (canvasRef && 'current' in canvasRef) canvasRef.current = canvas + }} width={outerWidth * pixelRatio} height={outerHeight * pixelRatio} style={{ @@ -218,18 +233,24 @@ const InnerNetworkCanvas = ({ ) } -export const NetworkCanvas = < - Node extends InputNode = InputNode, - Link extends InputLink = InputLink ->({ - theme, - isInteractive = canvasDefaultProps.isInteractive, - animate = canvasDefaultProps.animate, - motionConfig = canvasDefaultProps.motionConfig, - renderWrapper, - ...otherProps -}: NetworkCanvasProps) => ( - - isInteractive={isInteractive} {...otherProps} /> - +export const NetworkCanvas = forwardRef( + ( + { + theme, + isInteractive = canvasDefaultProps.isInteractive, + animate = canvasDefaultProps.animate, + motionConfig = canvasDefaultProps.motionConfig, + renderWrapper, + ...otherProps + }: NetworkCanvasProps, + ref: ForwardedRef + ) => ( + + + isInteractive={isInteractive} + {...otherProps} + canvasRef={ref} + /> + + ) ) diff --git a/packages/network/src/ResponsiveNetworkCanvas.tsx b/packages/network/src/ResponsiveNetworkCanvas.tsx index ccbd5e65d..62c2031df 100644 --- a/packages/network/src/ResponsiveNetworkCanvas.tsx +++ b/packages/network/src/ResponsiveNetworkCanvas.tsx @@ -1,16 +1,28 @@ import { ResponsiveWrapper } from '@nivo/core' +import { ForwardedRef, forwardRef } from 'react' import { NetworkCanvasProps, InputNode, InputLink } from './types' import { NetworkCanvas } from './NetworkCanvas' -export const ResponsiveNetworkCanvas = < +export const ResponsiveNetworkCanvas = forwardRef(function ResponsiveBarCanvas< Node extends InputNode = InputNode, Link extends InputLink = InputLink >( - props: Omit, 'height' | 'width'> -) => ( - - {({ width, height }) => ( - width={width} height={height} {...props} /> - )} - -) + props: Omit, 'height' | 'width'>, + ref: ForwardedRef +) { + return ( + + {({ width, height }) => ( + , + 'height' | 'width' + >)} + ref={ref} + /> + )} + + ) +}) diff --git a/packages/network/stories/networkCanvas.stories.tsx b/packages/network/stories/networkCanvas.stories.tsx index 8cca846a4..1e79955e2 100644 --- a/packages/network/stories/networkCanvas.stories.tsx +++ b/packages/network/stories/networkCanvas.stories.tsx @@ -9,6 +9,7 @@ import { NodeTooltipProps, // @ts-ignore } from '../src' +import { useRef } from 'react' export default { component: NetworkCanvas, @@ -70,3 +71,22 @@ export const CustomNodeRenderer = () => ( export const OnClickHandler = () => ( {...commonProperties} onClick={action('onClick')} /> ) + +export const CustomCanvasRef = () => { + const ref = useRef(undefined) + + const download = ref => { + const canvas = ref.current + const link = document.createElement('a') + link.download = 'test.png' + link.href = canvas.toDataURL('image/png') + link.click() + } + + return ( + <> + {...commonProperties} ref={ref} /> + + + ) +} diff --git a/packages/scatterplot/src/ResponsiveScatterPlotCanvas.tsx b/packages/scatterplot/src/ResponsiveScatterPlotCanvas.tsx index 889436e27..99289d9fe 100644 --- a/packages/scatterplot/src/ResponsiveScatterPlotCanvas.tsx +++ b/packages/scatterplot/src/ResponsiveScatterPlotCanvas.tsx @@ -1,13 +1,28 @@ import { ResponsiveWrapper } from '@nivo/core' +import { ForwardedRef, forwardRef } from 'react' + import { ScatterPlotCanvas } from './ScatterPlotCanvas' import { ScatterPlotCanvasProps, ScatterPlotDatum } from './types' -export const ResponsiveScatterPlotCanvas = ( - props: Omit, 'width' | 'height'> -) => ( - - {({ width, height }) => ( - width={width} height={height} {...props} /> - )} - -) +export const ResponsiveScatterPlotCanvas = forwardRef(function ResponsiveScatterPlotCanvas< + RawDatum extends ScatterPlotDatum +>( + props: Omit, 'width' | 'height'>, + ref: ForwardedRef +) { + return ( + + {({ width, height }) => ( + , + 'height' | 'width' + >)} + ref={ref} + /> + )} + + ) +}) diff --git a/packages/scatterplot/src/ScatterPlotCanvas.tsx b/packages/scatterplot/src/ScatterPlotCanvas.tsx index 2a7264068..888828316 100644 --- a/packages/scatterplot/src/ScatterPlotCanvas.tsx +++ b/packages/scatterplot/src/ScatterPlotCanvas.tsx @@ -1,4 +1,13 @@ -import { createElement, useRef, useState, useEffect, useCallback, useMemo } from 'react' +import { + ForwardedRef, + createElement, + forwardRef, + useCallback, + useEffect, + useMemo, + useRef, + useState, +} from 'react' import { Container, useDimensions, useTheme, getRelativeCursor, isCursorInRect } from '@nivo/core' import { renderAnnotationsToCanvas } from '@nivo/annotations' import { CanvasAxisProps, renderAxesToCanvas, renderGridLinesToCanvas } from '@nivo/axes' @@ -12,7 +21,9 @@ import { ScatterPlotCanvasProps, ScatterPlotDatum, ScatterPlotNodeData } from '. type InnerScatterPlotCanvasProps = Omit< ScatterPlotCanvasProps, 'renderWrapper' | 'theme' -> +> & { + canvasRef: ForwardedRef +} const InnerScatterPlotCanvas = ({ data, @@ -46,6 +57,7 @@ const InnerScatterPlotCanvas = ({ onClick, tooltip = canvasDefaultProps.tooltip, legends = canvasDefaultProps.legends, + canvasRef, }: InnerScatterPlotCanvasProps) => { const canvasEl = useRef(null) const theme = useTheme() @@ -270,7 +282,10 @@ const InnerScatterPlotCanvas = ({ return ( { + canvasEl.current = canvas + if (canvasRef && 'current' in canvasRef) canvasRef.current = canvas + }} width={outerWidth * pixelRatio} height={outerHeight * pixelRatio} style={{ @@ -286,13 +301,13 @@ const InnerScatterPlotCanvas = ({ ) } -export const ScatterPlotCanvas = ({ - isInteractive, - renderWrapper, - theme, - ...props -}: ScatterPlotCanvasProps) => ( - - {...props} /> - +export const ScatterPlotCanvas = forwardRef( + ( + { isInteractive, renderWrapper, theme, ...props }: ScatterPlotCanvasProps, + ref: ForwardedRef + ) => ( + + {...props} canvasRef={ref} /> + + ) ) diff --git a/packages/scatterplot/stories/ScatterPlotCanvas.stories.tsx b/packages/scatterplot/stories/ScatterPlotCanvas.stories.tsx index f3c3db4ab..fb743b591 100644 --- a/packages/scatterplot/stories/ScatterPlotCanvas.stories.tsx +++ b/packages/scatterplot/stories/ScatterPlotCanvas.stories.tsx @@ -1,8 +1,8 @@ -import { useState, useCallback, useMemo } from 'react' +import { useCallback, useMemo, useRef, useState } from 'react' import omit from 'lodash/omit' import { Meta } from '@storybook/react' // @ts-ignore -import { ScatterPlotCanvas, ResponsiveScatterPlotCanvas, ScatterPlotNodeData } from '../src' +import { ResponsiveScatterPlotCanvas, ScatterPlotCanvas, ScatterPlotNodeData } from '../src' export default { component: ScatterPlotCanvas, @@ -403,3 +403,42 @@ export const CustomTooltip = () => ( )} /> ) + +export const CustomCanvasRef = () => { + const ref = useRef(undefined) + + const download = ref => { + const canvas = ref.current + const link = document.createElement('a') + link.download = 'test.png' + link.href = canvas.toDataURL('image/png') + link.click() + } + + return ( + <> + + {...commonProps} + ref={ref} + tooltip={({ node }) => ( +
+ + {node.id} ({node.serieId}) + +
+ {`x: ${node.formattedX}`} +
+ {`y: ${node.formattedY}`} +
+ )} + /> + + + ) +}