diff --git a/frontend/src/lib/OutputArtifactLoader.test.ts b/frontend/src/lib/OutputArtifactLoader.test.ts index d2cbeee266cf..8f076db2e20c 100644 --- a/frontend/src/lib/OutputArtifactLoader.test.ts +++ b/frontend/src/lib/OutputArtifactLoader.test.ts @@ -145,27 +145,27 @@ describe('OutputArtifactLoader', () => { ); }); + const basicMetadata = { + labels: ['field1', 'field2'], + schema: [ + { + name: 'field1', + }, + { + name: 'field2', + type: 'field2 type', + }, + ], + source: 'gs://path', + }; it('returns a confusion matrix config with basic metadata', async () => { - const metadata = { - labels: ['field1', 'field2'], - schema: [ - { - name: 'field1', - }, - { - name: 'field2', - type: 'field2 type', - }, - ], - source: 'gs://path', - }; fileToRead = ` field1,field1,0 field1,field2,0 field2,field1,0 field2,field2,0 `; - const result = await OutputArtifactLoader.buildConfusionMatrixConfig(metadata as any); + const result = await OutputArtifactLoader.buildConfusionMatrixConfig(basicMetadata as any); expect(result).toEqual({ axes: ['field1', 'field2'], data: [ @@ -176,6 +176,32 @@ describe('OutputArtifactLoader', () => { type: PlotType.CONFUSION_MATRIX, } as ConfusionMatrixConfig); }); + it('supports inline confusion matrix data', async () => { + fileToRead = ''; + + const source = ` + field1,field1,1 + field1,field2,2 + field2,field1,3 + field2,field2,4 + `; + const expectedResult: ConfusionMatrixConfig = { + axes: ['field1', 'field2'], + data: [ + [1, 2], + [3, 4], + ], + labels: ['field1', 'field2'], + type: PlotType.CONFUSION_MATRIX, + }; + + const result = await OutputArtifactLoader.buildConfusionMatrixConfig({ + ...basicMetadata, + storage: 'inline', + source, + } as any); + expect(result).toEqual(expectedResult); + }); }); describe('buildPagedTableConfig', () => { @@ -207,19 +233,20 @@ describe('OutputArtifactLoader', () => { ); }); + const basicMetadata = { + format: 'csv', + header: ['field1', 'field2'], + source: 'gs://path', + }; + it('returns a paged table config with basic metadata', async () => { - const metadata = { - format: 'csv', - header: ['field1', 'field2'], - source: 'gs://path', - }; fileToRead = ` field1,field1,0 field1,field2,0 field2,field1,0 field2,field2,0 `; - const result = await OutputArtifactLoader.buildPagedTableConfig(metadata as any); + const result = await OutputArtifactLoader.buildPagedTableConfig(basicMetadata as any); expect(result).toEqual({ data: [ ['field1', 'field1', '0'], @@ -231,6 +258,29 @@ describe('OutputArtifactLoader', () => { type: PlotType.TABLE, } as PagedTableConfig); }); + + it('returns a paged table config with inline metadata', async () => { + fileToRead = ''; + const source = ` + field1,field1,1 + field1,field2,2 + field2,field1,3 + field2,field2,4 + `; + const metadata = { ...basicMetadata, storage: 'inline', source }; + const result = await OutputArtifactLoader.buildPagedTableConfig(metadata as any); + const expectedResult: PagedTableConfig = { + data: [ + ['field1', 'field1', '1'], + ['field1', 'field2', '2'], + ['field2', 'field1', '3'], + ['field2', 'field2', '4'], + ], + labels: ['field1', 'field2'], + type: PlotType.TABLE, + }; + expect(result).toEqual(expectedResult); + }); }); describe('buildTensorboardConfig', () => { @@ -268,6 +318,19 @@ describe('OutputArtifactLoader', () => { type: PlotType.WEB_APP, } as HTMLViewerConfig); }); + + it('returns source as html content when storage type is inline', async () => { + const metadata = { + source: ` + Hello World! + `, + storage: 'inline', + }; + expect(await OutputArtifactLoader.buildHtmlViewerConfig(metadata as any)).toEqual({ + htmlContent: metadata.source, + type: PlotType.WEB_APP, + } as HTMLViewerConfig); + }); }); describe('buildMarkdownViewerConfig', () => { @@ -364,11 +427,13 @@ describe('OutputArtifactLoader', () => { ); }); + const basicMetadata = { + schema: [{ name: 'fpr' }, { name: 'tpr' }, { name: 'threshold' }], + source: 'gs://path', + }; + it('returns an ROC viewer config with basic metadata', async () => { - const metadata = { - schema: [{ name: 'fpr' }, { name: 'tpr' }, { name: 'threshold' }], - source: 'gs://path', - }; + const metadata = basicMetadata; fileToRead = ` 0,1,2 3,4,5 @@ -384,6 +449,27 @@ describe('OutputArtifactLoader', () => { } as ROCCurveConfig); }); + it('returns an ROC viewer config with basic metadata', async () => { + const source = ` + 9,1,2 + 3,4,5 + 6,7,8 + `; + const metadata = { ...basicMetadata, source, storage: 'inline' }; + fileToRead = ''; + const expectedResult: ROCCurveConfig = { + data: [ + { label: '2', x: 9, y: 1 }, + { label: '5', x: 3, y: 4 }, + { label: '8', x: 6, y: 7 }, + ], + type: PlotType.ROC, + }; + expect(await OutputArtifactLoader.buildRocCurveConfig(metadata as any)).toEqual( + expectedResult, + ); + }); + it('returns an ROC viewer config with fields out of order', async () => { const metadata = { schema: [{ name: 'threshold' }, { name: 'tpr' }, { name: 'fpr' }], diff --git a/frontend/src/lib/OutputArtifactLoader.ts b/frontend/src/lib/OutputArtifactLoader.ts index 687d04de1231..da2f2796114d 100644 --- a/frontend/src/lib/OutputArtifactLoader.ts +++ b/frontend/src/lib/OutputArtifactLoader.ts @@ -57,6 +57,8 @@ export interface PlotMetadata { type: PlotType; } +type PlotMetadataContent = Omit; + export interface OutputMetadata { outputs: PlotMetadata[]; } @@ -109,7 +111,7 @@ export class OutputArtifactLoader { } public static async buildConfusionMatrixConfig( - metadata: PlotMetadata, + metadata: PlotMetadataContent, ): Promise { if (!metadata.source) { throw new Error('Malformed metadata, property "source" is required.'); @@ -124,8 +126,8 @@ export class OutputArtifactLoader { throw new Error('"schema" must be an array of {"name": string, "type": string} objects'); } - const path = WorkflowParser.parseStoragePath(metadata.source); - const csvRows = csvParseRows((await Apis.readFile(path)).trim()); + const content = await getSourceContent(metadata.source, metadata.storage); + const csvRows = csvParseRows(content.trim()); const labels = metadata.labels; const labelIndex: { [label: string]: number } = {}; let index = 0; @@ -162,7 +164,9 @@ export class OutputArtifactLoader { }; } - public static async buildPagedTableConfig(metadata: PlotMetadata): Promise { + public static async buildPagedTableConfig( + metadata: PlotMetadataContent, + ): Promise { if (!metadata.source) { throw new Error('Malformed metadata, property "source" is required.'); } @@ -174,11 +178,11 @@ export class OutputArtifactLoader { } let data: string[][] = []; const labels = metadata.header || []; + const content = await getSourceContent(metadata.source, metadata.storage); switch (metadata.format) { case 'csv': - const path = WorkflowParser.parseStoragePath(metadata.source); - data = csvParseRows((await Apis.readFile(path)).trim()).map(r => r.map(c => c.trim())); + data = csvParseRows(content.trim()).map(r => r.map(c => c.trim())); break; default: throw new Error('Unsupported table format: ' + metadata.format); @@ -192,7 +196,7 @@ export class OutputArtifactLoader { } public static async buildTensorboardConfig( - metadata: PlotMetadata, + metadata: PlotMetadataContent, ): Promise { if (!metadata.source) { throw new Error('Malformed metadata, property "source" is required.'); @@ -204,15 +208,14 @@ export class OutputArtifactLoader { }; } - public static async buildHtmlViewerConfig(metadata: PlotMetadata): Promise { + public static async buildHtmlViewerConfig( + metadata: PlotMetadataContent, + ): Promise { if (!metadata.source) { throw new Error('Malformed metadata, property "source" is required.'); } - const path = WorkflowParser.parseStoragePath(metadata.source); - const htmlContent = await Apis.readFile(path); - return { - htmlContent, + htmlContent: await getSourceContent(metadata.source, metadata.storage), type: PlotType.WEB_APP, }; } @@ -350,26 +353,18 @@ export class OutputArtifactLoader { } public static async buildMarkdownViewerConfig( - metadata: PlotMetadata, + metadata: PlotMetadataContent, ): Promise { if (!metadata.source) { throw new Error('Malformed metadata, property "source" is required.'); } - let markdownContent = ''; - if (metadata.storage === 'inline') { - markdownContent = metadata.source; - } else { - const path = WorkflowParser.parseStoragePath(metadata.source); - markdownContent = await Apis.readFile(path); - } - return { - markdownContent, + markdownContent: await getSourceContent(metadata.source, metadata.storage), type: PlotType.MARKDOWN, }; } - public static async buildRocCurveConfig(metadata: PlotMetadata): Promise { + public static async buildRocCurveConfig(metadata: PlotMetadataContent): Promise { if (!metadata.source) { throw new Error('Malformed metadata, property "source" is required.'); } @@ -380,8 +375,8 @@ export class OutputArtifactLoader { throw new Error('Malformed schema, must be an array of {"name": string, "type": string}'); } - const path = WorkflowParser.parseStoragePath(metadata.source); - const stringData = csvParseRows((await Apis.readFile(path)).trim()); + const content = await getSourceContent(metadata.source, metadata.storage); + const stringData = csvParseRows(content.trim()); const fprIndex = metadata.schema.findIndex(field => field.name === 'fpr'); if (fprIndex === -1) { @@ -581,3 +576,13 @@ async function buildArtifactViewerTfdvStatistics(url: string): Promise { + if (storage === 'inline') { + return source; + } + return await Apis.readFile(WorkflowParser.parseStoragePath(source)); +}