diff --git a/packages/x-data-grid-pro/src/tests/infiniteLoader.DataGridPro.test.tsx b/packages/x-data-grid-pro/src/tests/infiniteLoader.DataGridPro.test.tsx index 4ea7158627291..350aa89568c4f 100644 --- a/packages/x-data-grid-pro/src/tests/infiniteLoader.DataGridPro.test.tsx +++ b/packages/x-data-grid-pro/src/tests/infiniteLoader.DataGridPro.test.tsx @@ -1,5 +1,5 @@ import * as React from 'react'; -import { createRenderer, waitFor } from '@mui/internal-test-utils'; +import { act, createRenderer, waitFor } from '@mui/internal-test-utils'; import { expect } from 'chai'; import { DataGridPro } from '@mui/x-data-grid-pro'; import { spy, restore } from 'sinon'; @@ -40,24 +40,37 @@ describe(' - Infnite loader', () => { } const { container, setProps } = render(); const virtualScroller = container.querySelector('.MuiDataGrid-virtualScroller')!; - // arbitrary number to make sure that the bottom of the grid window is reached. - virtualScroller.scrollTop = 12345; - virtualScroller.dispatchEvent(new Event('scroll')); + + await act(async () => { + // arbitrary number to make sure that the bottom of the grid window is reached. + virtualScroller.scrollTop = 12345; + virtualScroller.dispatchEvent(new Event('scroll')); + }); + await waitFor(() => { expect(handleRowsScrollEnd.callCount).to.equal(1); }); - setProps({ - rows: baseRows.concat( - { id: 6, brand: 'Gucci' }, - { id: 7, brand: "Levi's" }, - { id: 8, brand: 'Ray-Ban' }, - ), + + await act(async () => { + setProps({ + rows: baseRows.concat( + { id: 6, brand: 'Gucci' }, + { id: 7, brand: "Levi's" }, + { id: 8, brand: 'Ray-Ban' }, + ), + }); + + // Trigger a scroll again to notify the grid that we're not in the bottom area anymore + virtualScroller.dispatchEvent(new Event('scroll')); }); - // Trigger a scroll again to notify the grid that we're not in the bottom area anymore - virtualScroller.dispatchEvent(new Event('scroll')); + expect(handleRowsScrollEnd.callCount).to.equal(1); - virtualScroller.scrollTop = 12345; - virtualScroller.dispatchEvent(new Event('scroll')); + + await act(async () => { + virtualScroller.scrollTop = 12345; + virtualScroller.dispatchEvent(new Event('scroll')); + }); + await waitFor(() => { expect(handleRowsScrollEnd.callCount).to.equal(2); }); diff --git a/packages/x-data-grid/src/components/virtualization/GridVirtualScrollbar.tsx b/packages/x-data-grid/src/components/virtualization/GridVirtualScrollbar.tsx index 912ed05493532..e5bf5410f4ccf 100644 --- a/packages/x-data-grid/src/components/virtualization/GridVirtualScrollbar.tsx +++ b/packages/x-data-grid/src/components/virtualization/GridVirtualScrollbar.tsx @@ -70,8 +70,8 @@ const GridVirtualScrollbar = React.forwardRef(null); const contentRef = React.useRef(null); const classes = useUtilityClasses(rootProps, props.position); @@ -96,28 +96,34 @@ const GridVirtualScrollbar = React.forwardRef { const scroller = apiRef.current.virtualScrollerRef.current!; const scrollbar = scrollbarRef.current!; - if (scrollbar[propertyScroll] === lastPositionScrollbar.current) { + if (isLocked.current) { + isLocked.current = false; return; } + isLocked.current = true; const value = scrollbar[propertyScroll] / scrollbarInnerSize; scroller[propertyScroll] = value * contentSize; - - lastPositionScroller.current = scroller[propertyScroll]; }); useOnMount(() => { diff --git a/packages/x-data-grid/src/hooks/core/useGridRefs.ts b/packages/x-data-grid/src/hooks/core/useGridRefs.ts index f1049fbc93c75..815c7f56e3281 100644 --- a/packages/x-data-grid/src/hooks/core/useGridRefs.ts +++ b/packages/x-data-grid/src/hooks/core/useGridRefs.ts @@ -7,6 +7,8 @@ export const useGridRefs = ( const rootElementRef = React.useRef(null); const mainElementRef = React.useRef(null); const virtualScrollerRef = React.useRef(null); + const virtualScrollbarVerticalRef = React.useRef(null); + const virtualScrollbarHorizontalRef = React.useRef(null); const columnHeadersContainerRef = React.useRef(null); apiRef.current.register('public', { @@ -16,6 +18,8 @@ export const useGridRefs = ( apiRef.current.register('private', { mainElementRef, virtualScrollerRef, + virtualScrollbarVerticalRef, + virtualScrollbarHorizontalRef, columnHeadersContainerRef, }); }; diff --git a/packages/x-data-grid/src/hooks/features/scroll/useGridScroll.ts b/packages/x-data-grid/src/hooks/features/scroll/useGridScroll.ts index f8628fa5451c4..67126ae3366e1 100644 --- a/packages/x-data-grid/src/hooks/features/scroll/useGridScroll.ts +++ b/packages/x-data-grid/src/hooks/features/scroll/useGridScroll.ts @@ -59,6 +59,8 @@ export const useGridScroll = ( const logger = useGridLogger(apiRef, 'useGridScroll'); const colRef = apiRef.current.columnHeadersContainerRef; const virtualScrollerRef = apiRef.current.virtualScrollerRef!; + const virtualScrollbarHorizontalRef = apiRef.current.virtualScrollbarHorizontalRef!; + const virtualScrollbarVerticalRef = apiRef.current.virtualScrollbarVerticalRef!; const visibleSortedRows = useGridSelector(apiRef, gridExpandedSortedRowEntriesSelector); const scrollToIndexes = React.useCallback( @@ -144,19 +146,37 @@ export const useGridScroll = ( const scroll = React.useCallback( (params: Partial) => { - if (virtualScrollerRef.current && params.left !== undefined && colRef.current) { + if ( + virtualScrollerRef.current && + virtualScrollbarHorizontalRef.current && + params.left !== undefined && + colRef.current + ) { const direction = isRtl ? -1 : 1; colRef.current.scrollLeft = params.left; virtualScrollerRef.current.scrollLeft = direction * params.left; + virtualScrollbarHorizontalRef.current.scrollLeft = direction * params.left; logger.debug(`Scrolling left: ${params.left}`); } - if (virtualScrollerRef.current && params.top !== undefined) { + if ( + virtualScrollerRef.current && + virtualScrollbarVerticalRef.current && + params.top !== undefined + ) { virtualScrollerRef.current.scrollTop = params.top; + virtualScrollbarVerticalRef.current.scrollTop = params.top; logger.debug(`Scrolling top: ${params.top}`); } logger.debug(`Scrolling, updating container, and viewport`); }, - [virtualScrollerRef, isRtl, colRef, logger], + [ + virtualScrollerRef, + virtualScrollbarHorizontalRef, + virtualScrollbarVerticalRef, + isRtl, + colRef, + logger, + ], ); const getScrollPosition = React.useCallback(() => { diff --git a/packages/x-data-grid/src/hooks/features/virtualization/useGridVirtualScroller.tsx b/packages/x-data-grid/src/hooks/features/virtualization/useGridVirtualScroller.tsx index cf6cb4401c697..26505febe3dd2 100644 --- a/packages/x-data-grid/src/hooks/features/virtualization/useGridVirtualScroller.tsx +++ b/packages/x-data-grid/src/hooks/features/virtualization/useGridVirtualScroller.tsx @@ -119,8 +119,8 @@ export const useGridVirtualScroller = () => { const gridRootRef = apiRef.current.rootElementRef; const mainRef = apiRef.current.mainElementRef; const scrollerRef = apiRef.current.virtualScrollerRef; - const scrollbarVerticalRef = React.useRef(null); - const scrollbarHorizontalRef = React.useRef(null); + const scrollbarVerticalRef = apiRef.current.virtualScrollbarVerticalRef; + const scrollbarHorizontalRef = apiRef.current.virtualScrollbarHorizontalRef; const contentHeight = dimensions.contentSize.height; const columnsTotalWidth = dimensions.columnsTotalWidth; const hasColSpan = useGridSelector(apiRef, gridHasColSpanSelector); diff --git a/packages/x-data-grid/src/models/api/gridCoreApi.ts b/packages/x-data-grid/src/models/api/gridCoreApi.ts index 8a48f4ee24125..1dfaf2834b597 100644 --- a/packages/x-data-grid/src/models/api/gridCoreApi.ts +++ b/packages/x-data-grid/src/models/api/gridCoreApi.ts @@ -69,9 +69,17 @@ export interface GridCorePrivateApi< */ mainElementRef: React.RefObject; /** - * The React ref of the grid virtual scroller container element. + * The React ref of the grid's virtual scroller container element. */ virtualScrollerRef: React.RefObject; + /** + * The React ref of the grid's vertical virtual scrollbar container element. + */ + virtualScrollbarVerticalRef: React.RefObject; + /** + * The React ref of the grid's horizontal virtual scrollbar container element. + */ + virtualScrollbarHorizontalRef: React.RefObject; /** * The React ref of the grid column container virtualized div element. */