Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use AsyncLocalStorage for getContext #587

Merged
merged 4 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/05_actions/src/components/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ const App = ({ name }: { name: string }) => {
<h3>This is a server component.</h3>
<p>Server counter: {getCounter()}</p>
<Counter
greet={greet as unknown as ServerFunction<typeof greet>}
greet={greet}
increment={increment as unknown as ServerFunction<typeof increment>}
/>
<Balancer>My Awesome Title</Balancer>
Expand Down
3 changes: 2 additions & 1 deletion examples/05_actions/src/components/funcs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import { rerender, getContext } from 'waku/server';

export const greet = (name: string) => {
export const greet = async (name: string) => {
await Promise.resolve();
console.log('RSC Context:', getContext()); // ---> {}
return `Hello ${name} from server!`;
};
Expand Down
10 changes: 10 additions & 0 deletions examples/08_cookies/src/components/App.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import { Suspense } from 'react';
import { getContext } from 'waku/server';

import { Counter } from './Counter.js';

const InternalAsyncComponent = async () => {
await new Promise((resolve) => setTimeout(resolve, 1000));
console.log(getContext());
return null;
};

const App = ({ name, items }: { name: string; items: unknown[] }) => {
const context = getContext<{ count: number }>();
return (
Expand All @@ -12,6 +19,9 @@ const App = ({ name, items }: { name: string; items: unknown[] }) => {
<p>Cookie count: {context.count}</p>
<Counter />
<p>Item count: {items.length}</p>
<Suspense>
<InternalAsyncComponent />
</Suspense>
</div>
);
};
Expand Down
2 changes: 1 addition & 1 deletion examples/08_cookies/waku.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export default {
),
]
: []),
import('waku/middleware/ssr'),
import('waku/middleware/rsc'),
import('waku/middleware/fallback'),
],
};
1 change: 1 addition & 0 deletions packages/waku/src/lib/builder/output-cloudflare.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export const emitCloudflareOutput = async (
name = "waku-project"
main = "${config.distDir}/${config.serveJs}"
compatibility_date = "2023-12-06"
compatibility_flags = [ "nodejs_als" ]

[site]
bucket = "./${config.distDir}/${config.publicDir}"
Expand Down
4 changes: 2 additions & 2 deletions packages/waku/src/lib/renderers/html-renderer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ export const renderHtml = async (

const loadClientModule = <T>(key: keyof typeof CLIENT_MODULE_MAP) =>
(isDev
? import(CLIENT_MODULE_MAP[key])
? import(/* @vite-ignore */ CLIENT_MODULE_MAP[key])
: opts.loadModule(CLIENT_PREFIX + key)) as Promise<T>;

const [
Expand Down Expand Up @@ -268,7 +268,7 @@ export const renderHtml = async (
if (!moduleLoading.has(id)) {
moduleLoading.set(
id,
import(id).then((m) => {
import(/* @vite-ignore */ id).then((m) => {
moduleCache.set(id, m);
}),
);
Expand Down
145 changes: 47 additions & 98 deletions packages/waku/src/lib/renderers/rsc-renderer.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import type { default as ReactType, ReactNode } from 'react';
import type { ReactNode } from 'react';
import type { default as RSDWServerType } from 'react-server-dom-webpack/server.edge';
import type { default as RSDWClientType } from 'react-server-dom-webpack/client.edge';

import type {
EntriesDev,
EntriesPrd,
setRenderContext as setRenderContextType,
runWithRenderStore as runWithRenderStoreType,
} from '../../server.js';
import type { ResolvedConfig } from '../config.js';
import { filePathToFileURL } from '../utils/path.js';
Expand All @@ -14,9 +13,7 @@ import { streamToString } from '../utils/stream.js';
import { decodeActionId } from '../renderers/utils.js';

export const SERVER_MODULE_MAP = {
react: 'react',
'rsdw-server': 'react-server-dom-webpack/server.edge',
'rsdw-client': 'react-server-dom-webpack/client.edge',
'waku-server': 'waku/server',
} as const;

Expand Down Expand Up @@ -82,106 +79,69 @@ export async function renderRsc(
: loadModule(key)) as Promise<T>;

const [
{
default: { createElement },
},
{
default: { renderToReadableStream, decodeReply },
},
{
default: { createFromReadableStream },
},
{ setRenderContext },
{ runWithRenderStore },
] = await Promise.all([
loadServerModule<{ default: typeof ReactType }>('react'),
loadServerModule<{ default: typeof RSDWServerType }>('rsdw-server'),
loadServerModule<{ default: typeof RSDWClientType }>('rsdw-client'),
(isDev
? opts.loadServerModule(SERVER_MODULE_MAP['waku-server'])
: loadModule('waku-server')) as Promise<{
setRenderContext: typeof setRenderContextType;
runWithRenderStore: typeof runWithRenderStoreType;
}>,
]);

const runWithRenderContext = async <T>(
renderContext: Parameters<typeof setRenderContext>[0],
fn: () => T,
): Promise<Awaited<T>> =>
new Promise<Awaited<T>>((resolve, reject) => {
createFromReadableStream(
renderToReadableStream(
createElement((async () => {
setRenderContext(renderContext);
resolve(await fn());
}) as any),
{},
),
{
ssrManifest: { moduleMap: null, moduleLoading: null },
},
).catch(reject);
});

const wrapWithContext = (
context: Record<string, unknown> | undefined,
elements: Record<string, ReactNode>,
value?: unknown,
) => {
const renderContext = {
context: context || {},
rerender: () => {
throw new Error('Cannot rerender');
const bundlerConfig = new Proxy(
{},
{
get(_target, encodedId: string) {
const [file, name] = encodedId.split('#') as [string, string];
const id = resolveClientEntry(file, config);
moduleIdCallback?.(id);
return { id, chunks: [id], name, async: true };
},
};
const elementEntries: [string, unknown][] = Object.entries(elements).map(
([k, v]) => [
k,
createElement(() => {
setRenderContext(renderContext);
return v as ReactNode; // XXX lie the type
}),
],
);
if (value !== undefined) {
elementEntries.push(['_value', value]);
}
return Object.fromEntries(elementEntries);
};
},
);

const renderWithContext = async (
context: Record<string, unknown> | undefined,
input: string,
searchParams: URLSearchParams,
) => {
const renderContext = {
const renderStore = {
context: context || {},
rerender: () => {
throw new Error('Cannot rerender');
},
};
const elements = await runWithRenderContext(renderContext, () =>
renderEntries(input, { searchParams, buildConfig }),
);
if (elements === null) {
const err = new Error('No function component found');
(err as any).statusCode = 404; // HACK our convention for NotFound
throw err;
}
if (Object.keys(elements).some((key) => key.startsWith('_'))) {
throw new Error('"_" prefix is reserved');
}
return wrapWithContext(context, elements);
return runWithRenderStore(renderStore, async () => {
const elements = await renderEntries(input, {
searchParams,
buildConfig,
});
if (elements === null) {
const err = new Error('No function component found');
(err as any).statusCode = 404; // HACK our convention for NotFound
throw err;
}
if (Object.keys(elements).some((key) => key.startsWith('_'))) {
throw new Error('"_" prefix is reserved');
}
return renderToReadableStream(elements, bundlerConfig);
});
};

const renderWithContextWithAction = async (
context: Record<string, unknown> | undefined,
actionFn: () => unknown,
actionFn: (...args: unknown[]) => unknown,
actionArgs: unknown[],
) => {
let elementsPromise: Promise<Record<string, ReactNode>> = Promise.resolve(
{},
);
let rendered = false;
const renderContext = {
const renderStore = {
context: context || {},
rerender: async (input: string, searchParams = new URLSearchParams()) => {
if (rendered) {
Expand All @@ -197,27 +157,20 @@ export async function renderRsc(
}));
},
};
const actionValue = await runWithRenderContext(renderContext, actionFn);
const elements = await elementsPromise;
rendered = true;
if (Object.keys(elements).some((key) => key.startsWith('_'))) {
throw new Error('"_" prefix is reserved');
}
return wrapWithContext(context, elements, actionValue);
return runWithRenderStore(renderStore, async () => {
const actionValue = await actionFn(...actionArgs);
const elements = await elementsPromise;
rendered = true;
if (Object.keys(elements).some((key) => key.startsWith('_'))) {
throw new Error('"_" prefix is reserved');
}
return renderToReadableStream(
{ ...elements, _value: actionValue },
bundlerConfig,
);
});
};

const bundlerConfig = new Proxy(
{},
{
get(_target, encodedId: string) {
const [file, name] = encodedId.split('#') as [string, string];
const id = resolveClientEntry(file, config);
moduleIdCallback?.(id);
return { id, chunks: [id], name, async: true };
},
},
);

if (method === 'POST') {
const rsfId = decodeActionId(input);
let args: unknown[] = [];
Expand Down Expand Up @@ -246,15 +199,11 @@ export async function renderRsc(
mod = await loadModule(fileId.slice('@id/'.length));
}
const fn = mod[name] || mod;
const elements = await renderWithContextWithAction(context, () =>
fn(...args),
);
return renderToReadableStream(elements, bundlerConfig);
return renderWithContextWithAction(context, fn, args);
}

// method === 'GET'
const elements = await renderWithContext(context, input, searchParams);
return renderToReadableStream(elements, bundlerConfig);
return renderWithContext(context, input, searchParams);
}

export async function getBuildConfig(opts: {
Expand Down
54 changes: 40 additions & 14 deletions packages/waku/src/server.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { cache } from 'react';
import type { AsyncLocalStorage as AsyncLocalStorageType } from 'node:async_hooks';
import type { ReactNode } from 'react';

import type { Config } from './config.js';
Expand Down Expand Up @@ -68,37 +68,63 @@ export function getEnv(key: string): string | undefined {
return (globalThis as any).__WAKU_PRIVATE_ENV__[key];
}

type RenderContext<
type RenderStore<
RscContext extends Record<string, unknown> = Record<string, unknown>,
> = {
rerender: (input: string, searchParams?: URLSearchParams) => void;
context: RscContext;
};

const getRenderContextHolder = cache(() => [] as [RenderContext?]);
const DO_NOT_BUNDLE = '';

let renderStorage: AsyncLocalStorageType<RenderStore> | undefined;

import(/* @vite-ignore */ DO_NOT_BUNDLE + 'node:async_hooks')
.then(({ AsyncLocalStorage }) => {
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I think it needs to be bundled.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renderStorage = new AsyncLocalStorage();
})
.catch(() => {
console.warn(
'AsyncLocalStorage is not available, rerender and getContext are only available in sync.',
);
});

let previousRenderStore: RenderStore | undefined;
let currentRenderStore: RenderStore | undefined;

/**
* This is an internal function and not for public use.
*/
export const setRenderContext = (renderContext: RenderContext) => {
const holder = getRenderContextHolder();
holder[0] = renderContext;
export const runWithRenderStore = <T>(
renderStore: RenderStore,
fn: () => T,
): T => {
if (renderStorage) {
return renderStorage.run(renderStore, fn);
}
previousRenderStore = currentRenderStore;
currentRenderStore = renderStore;
try {
return fn();
} finally {
currentRenderStore = previousRenderStore;
}
};

export function rerender(input: string, searchParams?: URLSearchParams) {
const holder = getRenderContextHolder();
if (!holder[0]) {
throw new Error('[Bug] No render context found');
const renderStore = renderStorage?.getStore() ?? currentRenderStore;
if (!renderStore) {
throw new Error('Render store is not available');
}
holder[0].rerender(input, searchParams);
renderStore.rerender(input, searchParams);
}

export function getContext<
RscContext extends Record<string, unknown> = Record<string, unknown>,
>(): RscContext {
const holder = getRenderContextHolder();
if (!holder[0]) {
throw new Error('[Bug] No render context found');
const renderStore = renderStorage?.getStore() ?? currentRenderStore;
if (!renderStore) {
throw new Error('Render store is not available');
}
return holder[0].context as RscContext;
return renderStore.context as RscContext;
}
Loading