Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] - Add support for converting model results, including to DFs #196

Merged
merged 9 commits into from
Jun 29, 2023
2 changes: 1 addition & 1 deletion fooof/core/modutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,6 @@ def wrapped_func(*args, **kwargs):
if not dep:
raise ImportError("Optional FOOOF dependency " + name + \
" is required for this functionality.")
func(*args, **kwargs)
return func(*args, **kwargs)
return wrapped_func
return wrap
106 changes: 106 additions & 0 deletions fooof/data/conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Conversion functions for organizing model results into alternate representations."""

import numpy as np

from fooof import Bands
from fooof.core.funcs import infer_ap_func
from fooof.core.info import get_ap_indices, get_peak_indices
from fooof.core.modutils import safe_import, check_dependency
from fooof.analysis.periodic import get_band_peak

pd = safe_import('pandas')

###################################################################################################
###################################################################################################

def model_to_dict(fit_results, peak_org):
"""Convert model fit results to a dictionary.

Parameters
----------
fit_results : FOOOFResults
Results of a model fit.
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.

Returns
-------
dict
Model results organized into a dictionary.
"""

fr_dict = {}

# aperiodic parameters
for label, param in zip(get_ap_indices(infer_ap_func(fit_results.aperiodic_params)),
fit_results.aperiodic_params):
fr_dict[label] = param

# periodic parameters
peaks = fit_results.peak_params

if isinstance(peak_org, int):

if len(peaks) < peak_org:
nans = [np.array([np.nan] * 3) for ind in range(peak_org-len(peaks))]
peaks = np.vstack((peaks, nans))
ryanhammonds marked this conversation as resolved.
Show resolved Hide resolved

for ind, peak in enumerate(peaks[:peak_org, :]):
for pe_label, pe_param in zip(get_peak_indices(), peak):
ryanhammonds marked this conversation as resolved.
Show resolved Hide resolved
fr_dict[pe_label.lower() + '_' + str(ind)] = pe_param

elif isinstance(peak_org, Bands):
for band, f_range in peak_org:
for label, param in zip(get_peak_indices(), get_band_peak(peaks, f_range)):
fr_dict[band + '_' + label.lower()] = param

# goodness-of-fit metrics
fr_dict['error'] = fit_results.error
fr_dict['r_squared'] = fit_results.r_squared

return fr_dict

@check_dependency(pd, 'pandas')
def model_to_dataframe(fit_results, peak_org):
"""Convert model fit results to a dataframe.

Parameters
----------
fit_results : FOOOFResults
Results of a model fit.
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.

Returns
-------
pd.Series
Model results organized into a dataframe.
"""

return pd.Series(model_to_dict(fit_results, peak_org))


@check_dependency(pd, 'pandas')
def group_to_dataframe(fit_results, peak_org):
"""Convert a group of model fit results into a dataframe.

Parameters
----------
fit_results : list of FOOOFResults
List of FOOOFResults objects.
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.

Returns
-------
pd.DataFrame
Model results organized into a dataframe.
"""

return pd.DataFrame([model_to_dataframe(f_res, peak_org) for f_res in fit_results])
20 changes: 20 additions & 0 deletions fooof/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from fooof.utils.data import trim_spectrum
from fooof.utils.params import compute_gauss_std
from fooof.data import FOOOFResults, FOOOFSettings, FOOOFMetaData
from fooof.data.conversions import model_to_dataframe
from fooof.sim.gen import gen_freqs, gen_aperiodic, gen_periodic, gen_model

###################################################################################################
Expand Down Expand Up @@ -716,6 +717,25 @@ def set_check_data_mode(self, check_data):
self._check_data = check_data


def to_df(self, peak_org):
"""Convert and extract the model results as a pandas object.

Parameters
----------
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.

Returns
-------
pd.Series
Model results organized into a pandas object.
"""

return model_to_dataframe(self.get_results(), peak_org)


def _check_width_limits(self):
"""Check and warn about peak width limits / frequency resolution interaction."""

Expand Down
20 changes: 20 additions & 0 deletions fooof/objs/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from fooof.core.strings import gen_results_fg_str
from fooof.core.io import save_fg, load_jsonlines
from fooof.core.modutils import copy_doc_func_to_method, safe_import
from fooof.data.conversions import group_to_dataframe

###################################################################################################
###################################################################################################
Expand Down Expand Up @@ -543,6 +544,25 @@ def print_results(self, concise=False):
print(gen_results_fg_str(self, concise))


def to_df(self, peak_org):
"""Convert and extract the model results as a pandas object.

Parameters
----------
peak_org : int or Bands
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.

Returns
-------
pd.DataFrame
Model results organized into a pandas object.
"""

return group_to_dataframe(self.get_results(), peak_org)


def _fit(self, *args, **kwargs):
"""Create an alias to FOOOF.fit for FOOOFGroup object, for internal use."""

Expand Down
11 changes: 10 additions & 1 deletion fooof/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from fooof.core.modutils import safe_import

from fooof.tests.tutils import get_tfm, get_tfg, get_tbands
from fooof.tests.tutils import get_tfm, get_tfg, get_tbands, get_tresults
from fooof.tests.settings import BASE_TEST_FILE_PATH, TEST_DATA_PATH, TEST_REPORTS_PATH

plt = safe_import('.pyplot', 'matplotlib')
Expand Down Expand Up @@ -46,7 +46,16 @@ def tfg():
def tbands():
yield get_tbands()

@pytest.fixture(scope='session')
def tresults():
yield get_tresults()

@pytest.fixture(scope='session')
def skip_if_no_mpl():
if not safe_import('matplotlib'):
pytest.skip('Matplotlib not available: skipping test.')

@pytest.fixture(scope='session')
def skip_if_no_pandas():
if not safe_import('pandas'):
pytest.skip('Pandas not available: skipping test.')
53 changes: 53 additions & 0 deletions fooof/tests/data/test_conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Tests for the fooof.data.conversions."""

from copy import deepcopy

import numpy as np

from fooof.core.modutils import safe_import
pd = safe_import('pandas')

from fooof.data.conversions import *

###################################################################################################
###################################################################################################

def test_model_to_dict(tresults, tbands):

out = model_to_dict(tresults, peak_org=1)
assert isinstance(out, dict)
assert 'cf_0' in out
assert out['cf_0'] == tresults.peak_params[0, 0]
assert not 'cf_1' in out

out = model_to_dict(tresults, peak_org=2)
assert 'cf_0' in out
assert 'cf_1' in out
assert out['cf_1'] == tresults.peak_params[1, 0]

out = model_to_dict(tresults, peak_org=3)
assert 'cf_2' in out
assert np.isnan(out['cf_2'])

out = model_to_dict(tresults, peak_org=tbands)
assert 'alpha_cf' in out

def test_model_to_dataframe(tresults, tbands, skip_if_no_pandas):

for peak_org in [1, 2, 3]:
out = model_to_dataframe(tresults, peak_org=peak_org)
assert isinstance(out, pd.Series)

out = model_to_dataframe(tresults, peak_org=tbands)
assert isinstance(out, pd.Series)

def test_group_to_dataframe(tresults, tbands, skip_if_no_pandas):

fit_results = [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)]

for peak_org in [1, 2, 3]:
out = group_to_dataframe(fit_results, peak_org=peak_org)
assert isinstance(out, pd.DataFrame)

out = group_to_dataframe(fit_results, peak_org=tbands)
assert isinstance(out, pd.DataFrame)
12 changes: 11 additions & 1 deletion fooof/tests/objs/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
from fooof.core.items import OBJ_DESC
from fooof.core.errors import FitError
from fooof.core.utils import group_three
from fooof.core.modutils import safe_import
from fooof.core.errors import DataError, NoDataError, InconsistentDataError
from fooof.sim import gen_freqs, gen_power_spectrum
from fooof.data import FOOOFSettings, FOOOFMetaData, FOOOFResults
from fooof.core.errors import DataError, NoDataError, InconsistentDataError

pd = safe_import('pandas')

from fooof.tests.settings import TEST_DATA_PATH
from fooof.tests.tutils import get_tfm, plot_test
Expand Down Expand Up @@ -425,3 +428,10 @@ def test_fooof_check_data():
# Model fitting should execute, but return a null model fit, given the NaNs, without failing
tfm.fit()
assert not tfm.has_model

def test_fooof_to_df(tfm, tbands, skip_if_no_pandas):

df1 = tfm.to_df(2)
assert isinstance(df1, pd.Series)
df2 = tfm.to_df(tbands)
assert isinstance(df2, pd.Series)
13 changes: 12 additions & 1 deletion fooof/tests/objs/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
import numpy as np
from numpy.testing import assert_equal

from fooof.data import FOOOFResults
from fooof.core.items import OBJ_DESC
from fooof.core.modutils import safe_import
from fooof.core.errors import DataError, NoDataError, InconsistentDataError
from fooof.data import FOOOFResults
from fooof.sim import gen_group_power_spectra

pd = safe_import('pandas')

from fooof.tests.settings import TEST_DATA_PATH
from fooof.tests.tutils import default_group_params, plot_test

Expand Down Expand Up @@ -349,3 +353,10 @@ def test_fg_get_group(tfg):
# Check that the correct results are extracted
assert [tfg.group_results[ind] for ind in inds1] == nfg1.group_results
assert [tfg.group_results[ind] for ind in inds2] == nfg2.group_results

def test_fg_to_df(tfg, tbands, skip_if_no_pandas):

df1 = tfg.to_df(2)
assert isinstance(df1, pd.DataFrame)
df2 = tfg.to_df(tbands)
assert isinstance(df2, pd.DataFrame)
11 changes: 11 additions & 0 deletions fooof/tests/tutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

from functools import wraps

import numpy as np

from fooof.bands import Bands
from fooof.data import FOOOFResults
from fooof.objs import FOOOF, FOOOFGroup
from fooof.core.modutils import safe_import
from fooof.sim.params import param_sampler
Expand Down Expand Up @@ -43,6 +46,14 @@ def get_tbands():

return Bands({'theta' : (4, 8), 'alpha' : (8, 12), 'beta' : (13, 30)})

def get_tresults():
"""Get a FOOOFResults objet, for testing."""

return FOOOFResults(aperiodic_params=np.array([1.0, 1.00]),
peak_params=np.array([[10.0, 1.25, 2.0], [20.0, 1.0, 3.0]]),
r_squared=0.97, error=0.01,
gaussian_params=np.array([[10.0, 1.25, 1.0], [20.0, 1.0, 1.5]]))

def default_group_params():
"""Create default parameters for generating a test group of power spectra."""

Expand Down
3 changes: 2 additions & 1 deletion optional-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
matplotlib
tqdm
tqdm
pandas