diff --git a/frontend/src/lib/OutputArtifactLoader.ts b/frontend/src/lib/OutputArtifactLoader.ts index fd25e2751cc4..071d7d4a693e 100644 --- a/frontend/src/lib/OutputArtifactLoader.ts +++ b/frontend/src/lib/OutputArtifactLoader.ts @@ -260,11 +260,56 @@ export class OutputArtifactLoader { return []; } - // TODO: Visualize other artifact types, such as Anomalies and Schema, using TFDV - // as well as ModelEvaluation using TFMA. - const tfdvArtifactPaths = filterTfdvArtifactsPaths(artifactTypes, artifacts); - const tfdvArtifactViewerConfigs = getTfdvArtifactViewers(tfdvArtifactPaths); - return Promise.all(tfdvArtifactViewerConfigs); + // TODO: Visualize non-TFDV artifacts, such as ModelEvaluation using TFMA + let viewers: Array> = []; + const exampleStatisticsArtifactUris = filterArtifactUrisByType( + 'ExampleStatistics', + artifactTypes, + artifacts, + ); + exampleStatisticsArtifactUris.forEach(uri => { + const evalUri = uri + '/eval/stats_tfrecord'; + const trainUri = uri + '/train/stats_tfrecord'; + viewers = viewers.concat( + [evalUri, trainUri].map(async specificUri => { + const script = [ + 'import tensorflow_data_validation as tfdv', + `stats = tfdv.load_statistics('${specificUri}')`, + 'tfdv.visualize_statistics(stats)', + ]; + return buildArtifactViewer(script); + }), + ); + }); + const schemaGenArtifactUris = filterArtifactUrisByType('Schema', artifactTypes, artifacts); + viewers = viewers.concat( + schemaGenArtifactUris.map(uri => { + uri = uri + '/schema.pbtxt'; + const script = [ + 'import tensorflow_data_validation as tfdv', + `schema = tfdv.load_schema_text('${uri}')`, + 'tfdv.display_schema(schema)', + ]; + return buildArtifactViewer(script); + }), + ); + const anomaliesArtifactUris = filterArtifactUrisByType( + 'ExampleAnomalies', + artifactTypes, + artifacts, + ); + viewers = viewers.concat( + anomaliesArtifactUris.map(uri => { + uri = uri + '/anomalies.pbtxt'; + const script = [ + 'import tensorflow_data_validation as tfdv', + `anomalies = tfdv.load_anomalies_text('${uri}')`, + 'tfdv.display_anomalies(anomalies)', + ]; + return buildArtifactViewer(script); + }), + ); + return Promise.all(viewers); } public static async buildMarkdownViewerConfig( @@ -450,45 +495,39 @@ async function getArtifactTypes(): Promise { return res.getArtifactTypesList(); } -function filterTfdvArtifactsPaths(artifactTypes: ArtifactType[], artifacts: Artifact[]): string[] { - const tfdvArtifactTypeIds = artifactTypes - .filter(artifactType => artifactType.getName() === 'ExampleStatistics') +function filterArtifactUrisByType( + artifactTypeName: string, + artifactTypes: ArtifactType[], + artifacts: Artifact[], +): string[] { + const artifactTypeIds = artifactTypes + .filter(artifactType => artifactType.getName() === artifactTypeName) .map(artifactType => artifactType.getId()); - const tfdvArtifacts = artifacts.filter(artifact => - tfdvArtifactTypeIds.includes(artifact.getTypeId()), + const matchingArtifacts = artifacts.filter(artifact => + artifactTypeIds.includes(artifact.getTypeId()), ); - const tfdvArtifactsPaths = tfdvArtifacts - .filter(artifact => artifact.getUri()) // uri not empty - .flatMap(artifact => [ - artifact.getUri() + '/eval/stats_tfrecord', // eval uri - artifact.getUri() + '/train/stats_tfrecord', // train uri - ]); + const tfdvArtifactsPaths = matchingArtifacts + .map(artifact => artifact.getUri()) + .filter(uri => uri); // uri not empty return tfdvArtifactsPaths; } -function getTfdvArtifactViewers(tfdvArtifactPaths: string[]): Array> { - return tfdvArtifactPaths.map(async artifactPath => { - const script = [ - 'import tensorflow_data_validation as tfdv', - `stats = tfdv.load_statistics('${artifactPath}')`, - 'tfdv.visualize_statistics(stats)', - ]; - const visualizationData: ApiVisualization = { - arguments: JSON.stringify({ code: script }), - source: '', - type: ApiVisualizationType.CUSTOM, - }; - const visualization = await Apis.buildPythonVisualizationConfig(visualizationData); - if (!visualization.htmlContent) { - // TODO: Improve error message with details. - throw new Error('Failed to build TFDV artifact visualization'); - } - return { - htmlContent: visualization.htmlContent, - type: PlotType.WEB_APP, - }; - }); +async function buildArtifactViewer(script: string[]): Promise { + const visualizationData: ApiVisualization = { + arguments: JSON.stringify({ code: script }), + source: '', + type: ApiVisualizationType.CUSTOM, + }; + const visualization = await Apis.buildPythonVisualizationConfig(visualizationData); + if (!visualization.htmlContent) { + // TODO: Improve error message with details. + throw new Error('Failed to build artifact viewer'); + } + return { + htmlContent: visualization.htmlContent, + type: PlotType.WEB_APP, + }; } // TODO: add tfma back