Skip to content
This repository has been archived by the owner on May 6, 2024. It is now read-only.

Update plotting script to include xkey and ykey #259

Merged
merged 3 commits into from
Sep 22, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 48 additions & 20 deletions nle/scripts/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
python -m nle.scripts.plot path/to/run -x 100 -y 50
```

Plot a specific run to a specific target column.
```
python -m nle.scripts.plot path/to/multiple_runs/2020 --ykey "total_loss"
```


Plot all runs under a specific directory without a legend matching plots to runs.
```
python -m nle.scripts.plot path/to/multiple_runs --no_legend
Expand Down Expand Up @@ -74,6 +80,11 @@ def str_to_float_pair(s):
default="~/torchbeast/latest/logs.tsv",
help="file to plot or directory to look for log files",
)
parser.add_argument("--xkey", default="# Step", type=str, help="x values to plot.")
parser.add_argument(
"--ykey", default="mean_episode_return", type=str, help="y values to plot."
)

parser.add_argument(
"-w", "--window", type=int, default=-1, help="override automatic window size."
)
Expand Down Expand Up @@ -103,27 +114,36 @@ def str_to_float_pair(s):
)


def plot_single_ascii(target, width, height, window=-1, xrange=None, yrange=None):
def plot_single_ascii(
target,
width,
height,
xkey="# Step",
ykey="mean_episode_return",
window=-1,
xrange=None,
yrange=None,
):
"""
Plot the target file using the specified width and height.
If window > 0, use it to specify the window size for rolling averages.
xrange and yrange are used to specify the zoom level of the plot.
"""
print("plotting %s" % str(target))
df = pd.read_csv(target, sep="\t")
steps = np.array(df["# Step"])
steps = np.array(df[xkey])

if window < 0:
window = len(steps) // width + 1
window = df["mean_episode_return"].rolling(window=window, min_periods=0)
window = df[ykey].rolling(window=window, min_periods=0)
returns = np.array(window.mean())
stderrs = np.array(window.std())

plot_options = {}
plot_options["with"] = "yerrorbars"
plot_options["terminal"] = "dumb %d %d ansi" % (width, height)
plot_options["tuplesize"] = 3
plot_options["title"] = "averaged episode return"
plot_options["title"] = "averaged %s" % ykey
plot_options["xlabel"] = "steps"

if xrange is not None:
Expand Down Expand Up @@ -174,6 +194,8 @@ def plot_multiple_ascii(
target,
width,
height,
xkey="# Step",
ykey="mean_episode_return",
window=-1,
xrange=None,
yrange=None,
Expand All @@ -191,21 +213,21 @@ def plot_multiple_ascii(
dfs = collect_logs(target)

if window < 0:
max_size = max(len(df["# Step"]) for name, df in dfs)
max_size = max(len(df[xkey]) for name, df in dfs)
window = 2 * max_size // width + 1

datasets = []
for name, df in dfs:
steps = np.array(df["# Step"])
steps = np.array(df[xkey])
if window > 1:
roll = df["mean_episode_return"].rolling(window=window, min_periods=0)
roll = df[ykey].rolling(window=window, min_periods=0)
try:
rewards = np.array(roll.mean())
except pd.core.base.DataError:
print("Error reading file at %s" % name)
continue
else:
rewards = np.array(df["mean_episode_return"])
rewards = np.array(df[ykey])
if no_legend:
datasets.append((steps, rewards))
else:
Expand All @@ -223,7 +245,7 @@ def plot_multiple_ascii(
plot_options = {}
plot_options["terminal"] = "dumb %d %d ansi" % (width, height)
plot_options["tuplesize"] = 2
plot_options["title"] = "averaged episode return"
plot_options["title"] = "averaged %s" % ykey
plot_options["xlabel"] = "steps"
plot_options["set"] = "key outside below"

Expand All @@ -248,9 +270,11 @@ def plot(flags):
target,
flags.width,
flags.height,
flags.window,
flags.xrange,
flags.yrange,
xkey=flags.xkey,
ykey=flags.ykey,
window=flags.window,
xrange=flags.xrange,
yrange=flags.yrange,
)
else:
raise RuntimeError(
Expand All @@ -262,21 +286,25 @@ def plot(flags):
target / "logs.tsv",
flags.width,
flags.height,
flags.window,
flags.xrange,
flags.yrange,
xkey=flags.xkey,
ykey=flags.ykey,
window=flags.window,
xrange=flags.xrange,
yrange=flags.yrange,
)
else:
# look for runs underneath the specified directory
plot_multiple_ascii(
target,
flags.width,
flags.height,
flags.window,
flags.xrange,
flags.yrange,
flags.no_legend,
flags.shuffle,
xkey=flags.xkey,
ykey=flags.ykey,
window=flags.window,
xrange=flags.xrange,
yrange=flags.yrange,
no_legend=flags.no_legend,
shuffle=flags.shuffle,
)


Expand Down