Skip to content

Commit

Permalink
Add visualizations for Schema and ExampleAnomalies (kubeflow#3026)
Browse files Browse the repository at this point in the history
* Add visualizations for Schema and ExampleAnomalies

* Run npm format

* Fix compile warnings

* Address PR comments
  • Loading branch information
Realsen authored and Jeffwan committed Dec 9, 2020
1 parent e33c21d commit 7d77ada
Showing 1 changed file with 77 additions and 38 deletions.
115 changes: 77 additions & 38 deletions frontend/src/lib/OutputArtifactLoader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,56 @@ export class OutputArtifactLoader {
return [];
}

// TODO: Visualize other artifact types, such as Anomalies and Schema, using TFDV
// as well as ModelEvaluation using TFMA.
const tfdvArtifactPaths = filterTfdvArtifactsPaths(artifactTypes, artifacts);
const tfdvArtifactViewerConfigs = getTfdvArtifactViewers(tfdvArtifactPaths);
return Promise.all(tfdvArtifactViewerConfigs);
// TODO: Visualize non-TFDV artifacts, such as ModelEvaluation using TFMA
let viewers: Array<Promise<HTMLViewerConfig>> = [];
const exampleStatisticsArtifactUris = filterArtifactUrisByType(
'ExampleStatistics',
artifactTypes,
artifacts,
);
exampleStatisticsArtifactUris.forEach(uri => {
const evalUri = uri + '/eval/stats_tfrecord';
const trainUri = uri + '/train/stats_tfrecord';
viewers = viewers.concat(
[evalUri, trainUri].map(async specificUri => {
const script = [
'import tensorflow_data_validation as tfdv',
`stats = tfdv.load_statistics('${specificUri}')`,
'tfdv.visualize_statistics(stats)',
];
return buildArtifactViewer(script);
}),
);
});
const schemaGenArtifactUris = filterArtifactUrisByType('Schema', artifactTypes, artifacts);
viewers = viewers.concat(
schemaGenArtifactUris.map(uri => {
uri = uri + '/schema.pbtxt';
const script = [
'import tensorflow_data_validation as tfdv',
`schema = tfdv.load_schema_text('${uri}')`,
'tfdv.display_schema(schema)',
];
return buildArtifactViewer(script);
}),
);
const anomaliesArtifactUris = filterArtifactUrisByType(
'ExampleAnomalies',
artifactTypes,
artifacts,
);
viewers = viewers.concat(
anomaliesArtifactUris.map(uri => {
uri = uri + '/anomalies.pbtxt';
const script = [
'import tensorflow_data_validation as tfdv',
`anomalies = tfdv.load_anomalies_text('${uri}')`,
'tfdv.display_anomalies(anomalies)',
];
return buildArtifactViewer(script);
}),
);
return Promise.all(viewers);
}

public static async buildMarkdownViewerConfig(
Expand Down Expand Up @@ -450,45 +495,39 @@ async function getArtifactTypes(): Promise<ArtifactType[]> {
return res.getArtifactTypesList();
}

function filterTfdvArtifactsPaths(artifactTypes: ArtifactType[], artifacts: Artifact[]): string[] {
const tfdvArtifactTypeIds = artifactTypes
.filter(artifactType => artifactType.getName() === 'ExampleStatistics')
function filterArtifactUrisByType(
artifactTypeName: string,
artifactTypes: ArtifactType[],
artifacts: Artifact[],
): string[] {
const artifactTypeIds = artifactTypes
.filter(artifactType => artifactType.getName() === artifactTypeName)
.map(artifactType => artifactType.getId());
const tfdvArtifacts = artifacts.filter(artifact =>
tfdvArtifactTypeIds.includes(artifact.getTypeId()),
const matchingArtifacts = artifacts.filter(artifact =>
artifactTypeIds.includes(artifact.getTypeId()),
);

const tfdvArtifactsPaths = tfdvArtifacts
.filter(artifact => artifact.getUri()) // uri not empty
.flatMap(artifact => [
artifact.getUri() + '/eval/stats_tfrecord', // eval uri
artifact.getUri() + '/train/stats_tfrecord', // train uri
]);
const tfdvArtifactsPaths = matchingArtifacts
.map(artifact => artifact.getUri())
.filter(uri => uri); // uri not empty
return tfdvArtifactsPaths;
}

function getTfdvArtifactViewers(tfdvArtifactPaths: string[]): Array<Promise<HTMLViewerConfig>> {
return tfdvArtifactPaths.map(async artifactPath => {
const script = [
'import tensorflow_data_validation as tfdv',
`stats = tfdv.load_statistics('${artifactPath}')`,
'tfdv.visualize_statistics(stats)',
];
const visualizationData: ApiVisualization = {
arguments: JSON.stringify({ code: script }),
source: '',
type: ApiVisualizationType.CUSTOM,
};
const visualization = await Apis.buildPythonVisualizationConfig(visualizationData);
if (!visualization.htmlContent) {
// TODO: Improve error message with details.
throw new Error('Failed to build TFDV artifact visualization');
}
return {
htmlContent: visualization.htmlContent,
type: PlotType.WEB_APP,
};
});
async function buildArtifactViewer(script: string[]): Promise<HTMLViewerConfig> {
const visualizationData: ApiVisualization = {
arguments: JSON.stringify({ code: script }),
source: '',
type: ApiVisualizationType.CUSTOM,
};
const visualization = await Apis.buildPythonVisualizationConfig(visualizationData);
if (!visualization.htmlContent) {
// TODO: Improve error message with details.
throw new Error('Failed to build artifact viewer');
}
return {
htmlContent: visualization.htmlContent,
type: PlotType.WEB_APP,
};
}

// TODO: add tfma back
Expand Down

0 comments on commit 7d77ada

Please sign in to comment.