Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

intro tutorial: Switch to using Seaborn #1718

Merged
merged 1 commit into from
Jun 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 47 additions & 21 deletions docs/tutorials/intro_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
"Install Mesa:\n",
"\n",
"```bash\n",
"pip install mesa\n",
"pip install --upgrade mesa\n",
"```\n",
"\n",
"Install Jupyter Notebook (optional):\n",
Expand Down Expand Up @@ -130,8 +130,8 @@
"source": [
"import mesa\n",
"\n",
"# Data visualization tool.\n",
"import matplotlib.pyplot as plt\n",
"# Data visualization tools.\n",
"import seaborn as sns\n",
"\n",
"# Has multi-dimensional arrays and matrices. Has a large collection of\n",
"# mathematical functions to operate on these arrays.\n",
Expand Down Expand Up @@ -509,6 +509,7 @@
"If you are running from a text editor or IDE, you'll also need to add this line, to make the graph appear.\n",
"\n",
"```python\n",
"import matplotlib.pyplot as plt\n",
"plt.show()\n",
"```"
]
Expand All @@ -531,7 +532,11 @@
"import matplotlib.pyplot as plt\n",
"\n",
"agent_wealth = [a.wealth for a in model.schedule.agents]\n",
"plt.hist(agent_wealth)"
"# Create a histogram with seaborn\n",
"g = sns.histplot(agent_wealth, discrete=True)\n",
"g.set(\n",
" title=\"Wealth distribution\", xlabel=\"Wealth\", ylabel=\"Number of agents\"\n",
"); # The semicolon is just to avoid printing the object representation"
]
},
{
Expand Down Expand Up @@ -571,7 +576,9 @@
" for agent in model.schedule.agents:\n",
" all_wealth.append(agent.wealth)\n",
"\n",
"plt.hist(all_wealth, bins=range(max(all_wealth) + 1))"
"# Use seaborn\n",
"g = sns.histplot(all_wealth, discrete=True)\n",
"g.set(title=\"Wealth distribution\", xlabel=\"Wealth\", ylabel=\"Number of agents\");"
]
},
{
Expand Down Expand Up @@ -758,7 +765,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create a model with 50 agents on a 10x10 grid, and run it for 20 steps."
"Let's create a model with 100 agents on a 10x10 grid, and run it for 20 steps."
]
},
{
Expand All @@ -772,7 +779,7 @@
},
"outputs": [],
"source": [
"model = MoneyModel(50, 10, 10)\n",
"model = MoneyModel(100, 10, 10)\n",
"for i in range(20):\n",
" model.step()"
]
Expand Down Expand Up @@ -802,11 +809,10 @@
" cell_content, x, y = cell\n",
" agent_count = len(cell_content)\n",
" agent_counts[x][y] = agent_count\n",
"plt.imshow(agent_counts, interpolation=\"nearest\")\n",
"plt.colorbar()\n",
"\n",
"# If running from a text editor or IDE, remember you'll need the following:\n",
"# plt.show()"
"# Plot using seaborn, with a size of 5x5\n",
"g = sns.heatmap(agent_counts, cmap=\"viridis\", annot=True, cbar=False, square=True)\n",
"g.figure.set_size_inches(4, 4)\n",
"g.set(title=\"Number of agents on each cell of the grid\");"
]
},
{
Expand Down Expand Up @@ -923,7 +929,7 @@
},
"outputs": [],
"source": [
"model = MoneyModel(50, 10, 10)\n",
"model = MoneyModel(100, 10, 10)\n",
"for i in range(100):\n",
" model.step()"
]
Expand All @@ -947,7 +953,9 @@
"outputs": [],
"source": [
"gini = model.datacollector.get_model_vars_dataframe()\n",
"gini.plot()"
"# Plot the Gini coefficient over time\n",
"g = sns.lineplot(data=gini)\n",
"g.set(title=\"Gini Coefficient over Time\", ylabel=\"Gini Coefficient\");"
]
},
{
Expand Down Expand Up @@ -976,7 +984,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"You'll see that the DataFrame's index is pairings of model step and agent ID. You can analyze it the way you would any other DataFrame. For example, to get a histogram of agent wealth at the model's end:"
"You'll see that the DataFrame's index is pairings of model step and agent ID. This is because the data collector stores the data in a dictionary, with the step number as the key, and a dictionary of agent ID and variable value pairs as the value. The data collector then converts this dictionary into a DataFrame, which is why the index is a pair of (model step, agent ID). You can analyze it the way you would any other DataFrame. For example, to get a histogram of agent wealth at the model's end:"
]
},
{
Expand All @@ -990,8 +998,15 @@
},
"outputs": [],
"source": [
"end_wealth = agent_wealth.xs(99, level=\"Step\")[\"Wealth\"]\n",
"end_wealth.hist(bins=range(agent_wealth.Wealth.max() + 1))"
"last_step = agent_wealth.index.get_level_values(\"Step\").max()\n",
"end_wealth = agent_wealth.xs(last_step, level=\"Step\")[\"Wealth\"]\n",
"# Create a histogram of wealth at the last step\n",
"g = sns.histplot(end_wealth, discrete=True)\n",
"g.set(\n",
" title=\"Distribution of wealth at the end of simulation\",\n",
" xlabel=\"Wealth\",\n",
" ylabel=\"Number of agents\",\n",
");"
]
},
{
Expand All @@ -1012,8 +1027,12 @@
},
"outputs": [],
"source": [
"# Get the wealth of agent 14 over time\n",
"one_agent_wealth = agent_wealth.xs(14, level=\"AgentID\")\n",
"one_agent_wealth.Wealth.plot()"
"\n",
"# Plot the wealth of agent 14 over time\n",
"g = sns.lineplot(data=one_agent_wealth, x=\"Step\", y=\"Wealth\")\n",
"g.set(title=\"Wealth of agent 14 over time\");"
]
},
{
Expand Down Expand Up @@ -1235,10 +1254,17 @@
},
"outputs": [],
"source": [
"# Filter the results to only contain the data of one agent (the Gini coefficient will be the same for the entire population at any time) at the 100th step of each episode\n",
"results_filtered = results_df[(results_df.AgentID == 0) & (results_df.Step == 100)]\n",
"N_values = results_filtered.N.values\n",
"gini_values = results_filtered.Gini.values\n",
"plt.scatter(N_values, gini_values)"
"results_filtered[[\"iteration\", \"N\", \"Gini\"]].reset_index(\n",
" drop=True\n",
").head() # Create a scatter plot\n",
"g = sns.scatterplot(data=results_filtered, x=\"N\", y=\"Gini\")\n",
"g.set(\n",
" xlabel=\"Number of agents\",\n",
" ylabel=\"Gini coefficient\",\n",
" title=\"Gini coefficient vs. number of agents\",\n",
");"
]
},
{
Expand Down
9 changes: 8 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@
# Explicitly install ipykernel for Python 3.8.
# See https://stackoverflow.com/questions/28831854/how-do-i-add-python3-kernel-to-jupyter-ipython
# Could be removed in the future
"docs": ["sphinx<7", "ipython", "nbsphinx", "ipykernel", "pydata_sphinx_theme"],
"docs": [
"sphinx<7",
"ipython",
"nbsphinx",
"ipykernel",
"pydata_sphinx_theme",
"seaborn",
],
}

version = ""
Expand Down