Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ci] prevent Python tests from leaving behind files #6626

Merged
merged 10 commits into from
Sep 3, 2024
12 changes: 9 additions & 3 deletions tests/python_package_test/test_plotting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# coding: utf-8

from pathlib import Path

vnherdeiro marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -177,7 +180,7 @@ def test_plot_tree(breast_cancer_split):


@pytest.mark.skipif(not GRAPHVIZ_INSTALLED, reason="graphviz is not installed")
def test_create_tree_digraph(breast_cancer_split):
def test_create_tree_digraph(tmp_path, breast_cancer_split):
X_train, _, y_train, _ = breast_cancer_split

constraints = [-1, 1] * int(X_train.shape[1] / 2)
Expand All @@ -193,6 +196,7 @@ def test_create_tree_digraph(breast_cancer_split):
show_info=["split_gain", "internal_value", "internal_weight"],
name="Tree4",
node_attr={"color": "red"},
directory=tmp_path
)
graph.render(view=False)
assert isinstance(graph, graphviz.Digraph)
Expand All @@ -213,7 +217,7 @@ def test_create_tree_digraph(breast_cancer_split):


@pytest.mark.skipif(not GRAPHVIZ_INSTALLED, reason="graphviz is not installed")
def test_tree_with_categories_below_max_category_values():
def test_tree_with_categories_below_max_category_values(tmp_path: Path):
vnherdeiro marked this conversation as resolved.
Show resolved Hide resolved
X_train, y_train = _categorical_data(2, 10)
params = {
"n_estimators": 10,
Expand All @@ -238,6 +242,7 @@ def test_tree_with_categories_below_max_category_values():
name="Tree4",
node_attr={"color": "red"},
max_category_values=10,
directory=tmp_path,
)
graph.render(view=False)
assert isinstance(graph, graphviz.Digraph)
Expand All @@ -257,7 +262,7 @@ def test_tree_with_categories_below_max_category_values():


@pytest.mark.skipif(not GRAPHVIZ_INSTALLED, reason="graphviz is not installed")
def test_tree_with_categories_above_max_category_values():
def test_tree_with_categories_above_max_category_values(tmp_path: Path):
vnherdeiro marked this conversation as resolved.
Show resolved Hide resolved
X_train, y_train = _categorical_data(20, 30)
params = {
"n_estimators": 10,
Expand All @@ -282,6 +287,7 @@ def test_tree_with_categories_above_max_category_values():
name="Tree4",
node_attr={"color": "red"},
max_category_values=4,
directory=tmp_path,
)
graph.render(view=False)
assert isinstance(graph, graphviz.Digraph)
Expand Down