diff --git a/packages/react-query/src/__tests__/suspense.test.tsx b/packages/react-query/src/__tests__/suspense.test.tsx new file mode 100644 index 0000000000..cc8ae1fe7f --- /dev/null +++ b/packages/react-query/src/__tests__/suspense.test.tsx @@ -0,0 +1,194 @@ +import { act, render, waitFor } from '@testing-library/react' +import { Suspense } from 'react' +import { + afterAll, + beforeAll, + beforeEach, + describe, + expect, + it, + vi, +} from 'vitest' +import { QueryClient, QueryClientProvider, useSuspenseQuery } from '..' +import { queryKey } from './utils' +import type { QueryKey } from '..' + +function renderWithSuspense(client: QueryClient, ui: React.ReactNode) { + return render( + + {ui} + , + ) +} + +function createTestQuery(options: { + fetchCount: { count: number } + queryKey: QueryKey + staleTime?: number | (() => number) +}) { + return function TestComponent() { + const { data } = useSuspenseQuery({ + queryKey: options.queryKey, + queryFn: () => { + options.fetchCount.count++ + return 'data' + }, + staleTime: options.staleTime, + }) + return
data: {data}
+ } +} + +describe('Suspense Timer Tests', () => { + let queryClient: QueryClient + let fetchCount: { count: number } + + beforeAll(() => { + vi.useFakeTimers({ shouldAdvanceTime: true }) + }) + + afterAll(() => { + vi.useRealTimers() + }) + + beforeEach(() => { + queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }) + fetchCount = { count: 0 } + }) + + it('should enforce minimum staleTime of 1000ms when using suspense with number', async () => { + const TestComponent = createTestQuery({ + fetchCount, + queryKey: ['test'], + staleTime: 10, + }) + + const rendered = renderWithSuspense(queryClient, ) + + await waitFor(() => rendered.getByText('data: data')) + + rendered.rerender( + + + + + , + ) + + act(() => { + vi.advanceTimersByTime(100) + }) + + expect(fetchCount.count).toBe(1) + }) + + it('should enforce minimum staleTime of 1000ms when using suspense with function', async () => { + const TestComponent = createTestQuery({ + fetchCount, + queryKey: ['test-func'], + staleTime: () => 10, + }) + + const rendered = renderWithSuspense(queryClient, ) + + await waitFor(() => rendered.getByText('data: data')) + + rendered.rerender( + + + + + , + ) + + act(() => { + vi.advanceTimersByTime(100) + }) + + expect(fetchCount.count).toBe(1) + }) + + it('should respect staleTime when value is greater than 1000ms', async () => { + const TestComponent = createTestQuery({ + fetchCount, + queryKey: queryKey(), + staleTime: 2000, + }) + + const rendered = renderWithSuspense(queryClient, ) + + await waitFor(() => rendered.getByText('data: data')) + + rendered.rerender( + + + + + , + ) + + act(() => { + vi.advanceTimersByTime(1500) + }) + + expect(fetchCount.count).toBe(1) + }) + + it('should enforce minimum staleTime when undefined is provided', async () => { + const TestComponent = createTestQuery({ + fetchCount, + queryKey: queryKey(), + staleTime: undefined, + }) + + const rendered = renderWithSuspense(queryClient, ) + + await waitFor(() => rendered.getByText('data: data')) + + rendered.rerender( + + + + + , + ) + + act(() => { + vi.advanceTimersByTime(500) + }) + + expect(fetchCount.count).toBe(1) + }) + + it('should respect staleTime when function returns value greater than 1000ms', async () => { + const TestComponent = createTestQuery({ + fetchCount, + queryKey: queryKey(), + staleTime: () => 3000, + }) + + const rendered = renderWithSuspense(queryClient, ) + + await waitFor(() => rendered.getByText('data: data')) + + rendered.rerender( + + + + + , + ) + + act(() => { + vi.advanceTimersByTime(2000) + }) + + expect(fetchCount.count).toBe(1) + }) +}) diff --git a/packages/react-query/src/__tests__/useSuspenseQuery.test.tsx b/packages/react-query/src/__tests__/useSuspenseQuery.test.tsx index 13d1184a0d..30aa9bbdb1 100644 --- a/packages/react-query/src/__tests__/useSuspenseQuery.test.tsx +++ b/packages/react-query/src/__tests__/useSuspenseQuery.test.tsx @@ -327,61 +327,6 @@ describe('useSuspenseQuery', () => { consoleMock.mockRestore() }) - it('should refetch when re-mounting', async () => { - const key = queryKey() - let count = 0 - - function Component() { - const result = useSuspenseQuery({ - queryKey: key, - queryFn: async () => { - await sleep(100) - count++ - return count - }, - retry: false, - staleTime: 0, - }) - return ( -
- data: {result.data} - fetching: {result.isFetching ? 'true' : 'false'} -
- ) - } - - function Page() { - const [show, setShow] = React.useState(true) - return ( -
- - - {show && } - -
- ) - } - - const rendered = renderWithClient(queryClient, ) - - await waitFor(() => rendered.getByText('Loading...')) - await waitFor(() => rendered.getByText('data: 1')) - await waitFor(() => rendered.getByText('fetching: false')) - await waitFor(() => rendered.getByText('hide')) - fireEvent.click(rendered.getByText('hide')) - await waitFor(() => rendered.getByText('show')) - fireEvent.click(rendered.getByText('show')) - await waitFor(() => rendered.getByText('fetching: true')) - await waitFor(() => rendered.getByText('data: 2')) - await waitFor(() => rendered.getByText('fetching: false')) - }) - it('should set staleTime when having passed a function', async () => { const key = queryKey() let count = 0 diff --git a/packages/react-query/src/suspense.ts b/packages/react-query/src/suspense.ts index 497bb83bd8..6981e07422 100644 --- a/packages/react-query/src/suspense.ts +++ b/packages/react-query/src/suspense.ts @@ -21,12 +21,16 @@ export const defaultThrowOnError = < export const ensureSuspenseTimers = ( defaultedOptions: DefaultedQueryObserverOptions, ) => { + const originalStaleTime = defaultedOptions.staleTime + if (defaultedOptions.suspense) { - // Always set stale time when using suspense to prevent - // fetching again when directly mounting after suspending - if (defaultedOptions.staleTime === undefined) { - defaultedOptions.staleTime = 1000 - } + // Handle staleTime to ensure minimum 1000ms in Suspense mode + // This prevents unnecessary refetching when components remount after suspending + defaultedOptions.staleTime = + typeof originalStaleTime === 'function' + ? (...args) => Math.max(originalStaleTime(...args), 1000) + : Math.max(originalStaleTime ?? 1000, 1000) + if (typeof defaultedOptions.gcTime === 'number') { defaultedOptions.gcTime = Math.max(defaultedOptions.gcTime, 1000) }