Skip to content

Commit

Permalink
Fix dummy output viz in console
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Aug 26, 2023
1 parent a8ec7b6 commit b95fd18
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions inseq/data/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def show_attributions(
rprint(get_heatmap_type(attribution, curr_color, "Source", use_html=False))
if attribution.target_attributions is not None:
curr_color = colors[idx + 1]
if attribution.target_attributions is not None and display:
display_scores = attribution.source_attributions is None and attribution.step_scores
if (attribution.target_attributions is not None or display_scores) and display:
if curr_color is None and colors:
curr_color = colors[idx]
print("\n\n")
Expand Down Expand Up @@ -280,10 +281,7 @@ def get_saliency_heatmap_rich(
style = lambda val, limit: "bold" if abs(val) >= limit and isinstance(val, float) else ""
score_row = [Text(step_score_name, style="bold")]
for score in step_score_values:
if isinstance(step_score_values[col_index].item(), float):
curr_score = round(step_score_values[col_index].item(), 2)
else:
curr_score = step_score_values[col_index].item()
curr_score = round(score.item(), 2) if isinstance(score, float) else score.item()
score_row.append(Text(f"{score:.2f}", justify="center", style=style(curr_score, threshold)))
table.add_row(*score_row, end_section=True)
return table
Expand Down

0 comments on commit b95fd18

Please sign in to comment.