Skip to content

Commit

Permalink
Replace implicit kwargs with explicit defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
ldeso committed Oct 11, 2023
1 parent 19a5e6a commit fdafa8f
Showing 1 changed file with 30 additions and 19 deletions.
49 changes: 30 additions & 19 deletions src/Loggers/LogCustomScalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,12 @@ function chart(name::String, metadata::Tuple{tb_chart_type, AbstractArray})
chart_type, tags = metadata

if chart_type == tb_multiline
content = MultilineChartContent(tag = tags)
return Chart(title = name, multiline = content)
content = MultilineChartContent(tags)
return Chart(name, OneOf(:multiline, content))
elseif chart_type == tb_margin
@assert length(tags) == 3
args = Dict(k => v for (k, v) in zip([:value, :lower, :upper], tags))
content = MarginChartContent(
series = [MarginChartContent_Series(; args...)])
return Chart(title = name, margin = content)
@assert length(tags) == 3 # value, lower, upper
content = MarginChartContent([MarginChartContent_Series(tags...)])
return Chart(name, OneOf(:margin, content))
else
@error "The chart type must be `tb_multiline` or `tb_margin`"
end
Expand All @@ -47,16 +45,29 @@ end

function custom_scalar_summary(layout)
cat_spec = zip(keys(layout), values(layout))
categories = [Category(title = k, chart = charts(c)) for (k, c) in cat_spec]

layout = Layout(category = categories)
plugin_data = SummaryMetadata_PluginData(plugin_name = "custom_scalars")
smd = SummaryMetadata(plugin_data = plugin_data)
cs_tensor = TensorProto(dtype = _DataType.DT_STRING,
string_val = [serialize_proto(layout)],
tensor_shape = TensorShapeProto())

Summary_Value(tag = "custom_scalars__config__",
tensor = cs_tensor,
metadata = smd)
categories = [Category(title, charts(c), false) for (title, c) in cat_spec]

layout = Layout(zero(Int32), categories)
plugin_data = SummaryMetadata_PluginData("custom_scalars", UInt8[])
smd = SummaryMetadata(plugin_data, "", "", DataClass.DATA_CLASS_UNKNOWN)
cs_tensor = TensorProto(_DataType.DT_STRING,
nothing,
zero(Int32),
UInt8[],
Int32[],
Float32[],
Float64[],
Int32[],
[serialize_proto(layout)],
Float32[],
Int64[],
Bool[],
Float64[],
ResourceHandleProto[],
VariantTensorDataProto[],
UInt32[],
UInt64[],
UInt8[])

Summary_Value("", "custom_scalars__config__", smd, OneOf(:tensor, cs_tensor))
end

0 comments on commit fdafa8f

Please sign in to comment.