Skip to content

Commit

Permalink
Render flexible confusion matrices as expected (#2523)
Browse files Browse the repository at this point in the history
* concatenate data required for flexible confusion matrix into rev field

* refactor fill template

* remove unnecessary stroke dash entries from confusion matrix

* refactor suppression of encoding elements for confusion matrices

* add unit test for get children

* include integration test

* refactor fill template

* update HEAD revision to main
  • Loading branch information
mattseddon committed Oct 5, 2022
1 parent f69837c commit 6c5e2aa
Show file tree
Hide file tree
Showing 7 changed files with 337,534 additions and 31 deletions.
51 changes: 35 additions & 16 deletions extension/src/plots/model/collect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -501,30 +501,49 @@ export const collectTemplates = (data: PlotsOutput): TemplateAccumulator => {
return acc
}

const updateDatapoints = (
datapoints: unknown[],
key: string,
fields: string[]
): unknown[] =>
datapoints.map(data => {
const obj = data as Record<string, unknown>
return {
...obj,
[key]: mergeFields(fields.map(field => obj[field] as string))
}
})

const stringifyDatapoints = (
datapoints: unknown[],
field: string | undefined,
isMultiView: boolean
): string => {
if (!field || (!isMultiView && !isConcatenatedField(field))) {
return JSON.stringify(datapoints)
}

const fields = unmergeConcatenatedFields(field)

if (isMultiView) {
fields.unshift('rev')
return JSON.stringify(updateDatapoints(datapoints, 'rev', fields))
}

return JSON.stringify(updateDatapoints(datapoints, field, fields))
}

const fillTemplate = (
template: string,
datapoints: unknown[],
field?: string
) => {
if (!field || !isConcatenatedField(field)) {
return JSON.parse(
template.replace('"<DVC_METRIC_DATA>"', JSON.stringify(datapoints))
) as TopLevelSpec
}
): TopLevelSpec => {
const isMultiView = isMultiViewPlot(JSON.parse(template))

const fields = unmergeConcatenatedFields(field)
return JSON.parse(
template.replace(
'"<DVC_METRIC_DATA>"',
JSON.stringify(
datapoints.map(data => {
const obj = data as Record<string, unknown>
return {
...obj,
[field]: mergeFields(fields.map(field => obj[field] as string))
}
})
)
stringifyDatapoints(datapoints, field, isMultiView)
)
) as TopLevelSpec
}
Expand Down
92 changes: 92 additions & 0 deletions extension/src/plots/paths/model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,96 @@ describe('PathsModel', () => {

expect(model.getComparisonPaths()).toStrictEqual(newOrder)
})

it('should return the expected children from the test fixture', () => {
const model = new PathsModel(mockDvcRoot, buildMockMemento())
model.transformAndSet(plotsDiffFixture)

const rootChildren = model.getChildren(undefined, {
'predictions.json': {
strokeDash: { field: '', scale: { domain: [], range: [] } }
}
})
expect(rootChildren).toStrictEqual([
{
descendantStatuses: [2, 2, 2],
hasChildren: true,
label: 'plots',
parentPath: undefined,
path: 'plots',
status: 2
},
{
descendantStatuses: [2, 2],
hasChildren: true,
label: 'logs',
parentPath: undefined,
path: 'logs',
status: 2
},
{
descendantStatuses: [],
hasChildren: false,
label: 'predictions.json',
parentPath: undefined,
path: 'predictions.json',
status: 2,
type: new Set([PathType.TEMPLATE_MULTI])
}
])

const directoryChildren = model.getChildren('logs')
expect(directoryChildren).toStrictEqual([
{
descendantStatuses: [],
hasChildren: false,
label: 'loss.tsv',
parentPath: 'logs',
path: logsLoss,
status: 2,
type: new Set([PathType.TEMPLATE_SINGLE])
},
{
descendantStatuses: [],
hasChildren: false,
label: 'acc.tsv',
parentPath: 'logs',
path: logsAcc,
status: 2,
type: new Set([PathType.TEMPLATE_SINGLE])
}
])

const plotsWithEncoding = model.getChildren('logs', {
[logsAcc]: {
strokeDash: { field: '', scale: { domain: [], range: [] } }
},
[logsLoss]: {
strokeDash: { field: '', scale: { domain: [], range: [] } }
}
})
expect(plotsWithEncoding).toStrictEqual([
{
descendantStatuses: [],
hasChildren: true,
label: 'loss.tsv',
parentPath: 'logs',
path: logsLoss,
status: 2,
type: new Set([PathType.TEMPLATE_SINGLE])
},
{
descendantStatuses: [],
hasChildren: true,
label: 'acc.tsv',
parentPath: 'logs',
path: logsAcc,
status: 2,
type: new Set([PathType.TEMPLATE_SINGLE])
}
])

const noChildren = model.getChildren(logsLoss)
expect(noChildren).toStrictEqual([])
})
})
37 changes: 23 additions & 14 deletions extension/src/plots/paths/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,13 @@ export class PathsModel extends PathSelectionModel<PlotPath> {
path: string | undefined,
multiSourceEncoding: MultiSourceEncoding = {}
) {
return this.filterChildren(path).map(element => {
const hasChildren =
element.hasChildren === false
? !!multiSourceEncoding[element.path]
: element.hasChildren

return {
...element,
descendantStatuses: this.getTerminalNodeStatuses(element.path),
hasChildren,
label: element.label,
status: this.status[element.path]
}
})
return this.filterChildren(path).map(element => ({
...element,
descendantStatuses: this.getTerminalNodeStatuses(element.path),
hasChildren: this.getHasChildren(element, multiSourceEncoding),
label: element.label,
status: this.status[element.path]
}))
}

public getTemplateOrder(): TemplateOrder {
Expand Down Expand Up @@ -116,4 +109,20 @@ export class PathsModel extends PathSelectionModel<PlotPath> {
return element.parentPath === path
})
}

private getHasChildren(
element: PlotPath,
multiSourceEncoding: MultiSourceEncoding
) {
const hasEncodingChildren =
!element.hasChildren &&
!element.type?.has(PathType.TEMPLATE_MULTI) &&
!!multiSourceEncoding[element.path]

if (hasEncodingChildren) {
return true
}

return element.hasChildren
}
}
4 changes: 4 additions & 0 deletions extension/src/test/fixtures/plotsDiff/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,10 @@ export const getOutput = (

export const getMinimalOutput = (): PlotsOutput => ({ ...basicVega })

export const getMultiSourceOutput = (): PlotsOutput => ({
...require('./multiSource').default
})

const expectedRevisions = ['workspace', 'main', '4fb124a', '42b8736', '1ba7bcd']

const extendedSpecs = (plotsOutput: TemplatePlots): TemplatePlotSection[] => {
Expand Down
Loading

0 comments on commit 6c5e2aa

Please sign in to comment.