diff --git a/docs/tutorials/intro_tutorial.ipynb b/docs/tutorials/intro_tutorial.ipynb index 8f053d17b71..fe76ea032be 100644 --- a/docs/tutorials/intro_tutorial.ipynb +++ b/docs/tutorials/intro_tutorial.ipynb @@ -53,7 +53,7 @@ "Install Mesa:\n", "\n", "```bash\n", - "pip install mesa\n", + "pip install --upgrade mesa\n", "```\n", "\n", "Install Jupyter Notebook (optional):\n", @@ -130,8 +130,8 @@ "source": [ "import mesa\n", "\n", - "# Data visualization tool.\n", - "import matplotlib.pyplot as plt\n", + "# Data visualization tools.\n", + "import seaborn as sns\n", "\n", "# Has multi-dimensional arrays and matrices. Has a large collection of\n", "# mathematical functions to operate on these arrays.\n", @@ -509,6 +509,7 @@ "If you are running from a text editor or IDE, you'll also need to add this line, to make the graph appear.\n", "\n", "```python\n", + "import matplotlib.pyplot as plt\n", "plt.show()\n", "```" ] @@ -531,7 +532,11 @@ "import matplotlib.pyplot as plt\n", "\n", "agent_wealth = [a.wealth for a in model.schedule.agents]\n", - "plt.hist(agent_wealth)" + "# Create a histogram with seaborn\n", + "g = sns.histplot(agent_wealth, discrete=True)\n", + "g.set(\n", + " title=\"Wealth distribution\", xlabel=\"Wealth\", ylabel=\"Number of agents\"\n", + "); # The semicolon is just to avoid printing the object representation" ] }, { @@ -571,7 +576,9 @@ " for agent in model.schedule.agents:\n", " all_wealth.append(agent.wealth)\n", "\n", - "plt.hist(all_wealth, bins=range(max(all_wealth) + 1))" + "# Use seaborn\n", + "g = sns.histplot(all_wealth, discrete=True)\n", + "g.set(title=\"Wealth distribution\", xlabel=\"Wealth\", ylabel=\"Number of agents\");" ] }, { @@ -758,7 +765,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Let's create a model with 50 agents on a 10x10 grid, and run it for 20 steps." + "Let's create a model with 100 agents on a 10x10 grid, and run it for 20 steps." ] }, { @@ -772,7 +779,7 @@ }, "outputs": [], "source": [ - "model = MoneyModel(50, 10, 10)\n", + "model = MoneyModel(100, 10, 10)\n", "for i in range(20):\n", " model.step()" ] @@ -802,11 +809,10 @@ " cell_content, x, y = cell\n", " agent_count = len(cell_content)\n", " agent_counts[x][y] = agent_count\n", - "plt.imshow(agent_counts, interpolation=\"nearest\")\n", - "plt.colorbar()\n", - "\n", - "# If running from a text editor or IDE, remember you'll need the following:\n", - "# plt.show()" + "# Plot using seaborn, with a size of 5x5\n", + "g = sns.heatmap(agent_counts, cmap=\"viridis\", annot=True, cbar=False, square=True)\n", + "g.figure.set_size_inches(4, 4)\n", + "g.set(title=\"Number of agents on each cell of the grid\");" ] }, { @@ -923,7 +929,7 @@ }, "outputs": [], "source": [ - "model = MoneyModel(50, 10, 10)\n", + "model = MoneyModel(100, 10, 10)\n", "for i in range(100):\n", " model.step()" ] @@ -947,7 +953,9 @@ "outputs": [], "source": [ "gini = model.datacollector.get_model_vars_dataframe()\n", - "gini.plot()" + "# Plot the Gini coefficient over time\n", + "g = sns.lineplot(data=gini)\n", + "g.set(title=\"Gini Coefficient over Time\", ylabel=\"Gini Coefficient\");" ] }, { @@ -976,7 +984,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You'll see that the DataFrame's index is pairings of model step and agent ID. You can analyze it the way you would any other DataFrame. For example, to get a histogram of agent wealth at the model's end:" + "You'll see that the DataFrame's index is pairings of model step and agent ID. This is because the data collector stores the data in a dictionary, with the step number as the key, and a dictionary of agent ID and variable value pairs as the value. The data collector then converts this dictionary into a DataFrame, which is why the index is a pair of (model step, agent ID). You can analyze it the way you would any other DataFrame. For example, to get a histogram of agent wealth at the model's end:" ] }, { @@ -990,8 +998,15 @@ }, "outputs": [], "source": [ - "end_wealth = agent_wealth.xs(99, level=\"Step\")[\"Wealth\"]\n", - "end_wealth.hist(bins=range(agent_wealth.Wealth.max() + 1))" + "last_step = agent_wealth.index.get_level_values(\"Step\").max()\n", + "end_wealth = agent_wealth.xs(last_step, level=\"Step\")[\"Wealth\"]\n", + "# Create a histogram of wealth at the last step\n", + "g = sns.histplot(end_wealth, discrete=True)\n", + "g.set(\n", + " title=\"Distribution of wealth at the end of simulation\",\n", + " xlabel=\"Wealth\",\n", + " ylabel=\"Number of agents\",\n", + ");" ] }, { @@ -1012,8 +1027,12 @@ }, "outputs": [], "source": [ + "# Get the wealth of agent 14 over time\n", "one_agent_wealth = agent_wealth.xs(14, level=\"AgentID\")\n", - "one_agent_wealth.Wealth.plot()" + "\n", + "# Plot the wealth of agent 14 over time\n", + "g = sns.lineplot(data=one_agent_wealth, x=\"Step\", y=\"Wealth\")\n", + "g.set(title=\"Wealth of agent 14 over time\");" ] }, { @@ -1235,10 +1254,17 @@ }, "outputs": [], "source": [ + "# Filter the results to only contain the data of one agent (the Gini coefficient will be the same for the entire population at any time) at the 100th step of each episode\n", "results_filtered = results_df[(results_df.AgentID == 0) & (results_df.Step == 100)]\n", - "N_values = results_filtered.N.values\n", - "gini_values = results_filtered.Gini.values\n", - "plt.scatter(N_values, gini_values)" + "results_filtered[[\"iteration\", \"N\", \"Gini\"]].reset_index(\n", + " drop=True\n", + ").head() # Create a scatter plot\n", + "g = sns.scatterplot(data=results_filtered, x=\"N\", y=\"Gini\")\n", + "g.set(\n", + " xlabel=\"Number of agents\",\n", + " ylabel=\"Gini coefficient\",\n", + " title=\"Gini coefficient vs. number of agents\",\n", + ");" ] }, { diff --git a/setup.py b/setup.py index 00e88cbe18c..7eff8bb317c 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,14 @@ # Explicitly install ipykernel for Python 3.8. # See https://stackoverflow.com/questions/28831854/how-do-i-add-python3-kernel-to-jupyter-ipython # Could be removed in the future - "docs": ["sphinx<7", "ipython", "nbsphinx", "ipykernel", "pydata_sphinx_theme"], + "docs": [ + "sphinx<7", + "ipython", + "nbsphinx", + "ipykernel", + "pydata_sphinx_theme", + "seaborn", + ], } version = ""