Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of bad_times and mask_times #123

Merged
merged 1 commit into from
Sep 8, 2022

Conversation

taldcroft
Copy link
Member

@taldcroft taldcroft commented Sep 5, 2022

Description

This fixes a performance issue note on Slack in #thermal_working_group:

I am doing some work that requires me to run the thermal models many, many times in a row, and was profiling my code when I noticed that the pm2thv1t model is running significantly slower than the pm1thv2t model, e.g. in my MATLAB setup:

  • pm1thv2t takes ~032ms to run, whereas
  • pm2thv1t takes ~643ms to run

This PR vectorizes the bad times processing and minimizes the number of date --> time (CXC second) conversions. It also factors this processing out into a separate function (working to improve separation in this code).

Some points of note:

  • The new code assumes that the bad_times in a JSON model spec file are in standard Year-day-of-year date format. There was another part of the code already making this assumption so I think this is OK but this needs review.
  • This code changes the behavior of the mask_* attributes. Previously these values were set corresponding to every bad_times interval in the spec file, including those that are outside of the model date range. While this was not a problem per se it was likely impacting performance. See the functional testing section for more.

Interface impacts

None.

Testing

Unit tests

  • Mac

Independent check of unit tests by [REVIEWER NAME]

  • [PLATFORM]:

Functional tests

I ran this script on master (version 4.26.1) and this branch (4.27.1.dev1+g19aed41).

from time import time

import xija
from xija.get_model_spec import get_xija_model_spec

print(xija.__version__)


def print_mdl_times(mdl):
    print(f"{len(mdl.bad_times)=}")
    print(f"{len(mdl.bad_times_indices)=}")
    print(f"{len(mdl.mask_times)=}")
    print(f"{len(mdl.mask_time_secs)=}")
    print(f"{len(mdl.mask_times_bad)=}")


def timer_func(func):
    # This function shows the execution time of
    # the function object passed
    def wrap_func(*args, **kwargs):
        t1 = time()
        result = func(*args, **kwargs)
        t2 = time()
        name = kwargs["model_spec"]["name"]
        datestart, datestop = args[1:3]
        print(f"{name=} {datestart=} {datestop=} executed in {(t2-t1) * 1000:.1f} ms")
        print_mdl_times(result)
        return result

    return wrap_func


# No bad times in spec1
spec1, version = get_xija_model_spec("pm1thv2t", version="3.40.2")

# 5876 bad times in spec2
spec2, version = get_xija_model_spec("pm2thv1t", version="3.40.2")

XijaModel = timer_func(xija.XijaModel)

mdl = XijaModel("mdl", "2019:001", "2022:301", model_spec=spec1)
mdl = XijaModel("mdl", "2025:001", "2025:002", model_spec=spec1)
mdl = XijaModel("mdl", "2019:001", "2022:301", model_spec=spec2)
mdl = XijaModel("mdl", "2025:001", "2025:002", model_spec=spec2)
mdl = XijaModel("mdl", "2022:083:22:30:00", "2022:084:04:00:00", model_spec=spec2)

Master version 4.26.1

4.26.1
name='pm1thv2t' datestart='2019:001' datestop='2022:301' executed in 6.7 ms
len(mdl.bad_times)=0
len(mdl.bad_times_indices)=0
len(mdl.mask_times)=0
len(mdl.mask_time_secs)=0
len(mdl.mask_times_bad)=0
name='pm1thv2t' datestart='2025:001' datestop='2025:002' executed in 1.1 ms
len(mdl.bad_times)=0
len(mdl.bad_times_indices)=0
len(mdl.mask_times)=0
len(mdl.mask_time_secs)=0
len(mdl.mask_times_bad)=0
name='pm2thv1t' datestart='2019:001' datestop='2022:301' executed in 407.7 ms
len(mdl.bad_times)=5876
len(mdl.bad_times_indices)=3842
len(mdl.mask_times)=5876
len(mdl.mask_time_secs)=5876
len(mdl.mask_times_bad)=5876
name='pm2thv1t' datestart='2025:001' datestop='2025:002' executed in 396.8 ms
len(mdl.bad_times)=5876
len(mdl.bad_times_indices)=0
len(mdl.mask_times)=5876
len(mdl.mask_time_secs)=5876
len(mdl.mask_times_bad)=5876
name='pm2thv1t' datestart='2022:083:22:30:00' datestop='2022:084:04:00:00' executed in 382.5 ms
len(mdl.bad_times)=5876
len(mdl.bad_times_indices)=2
len(mdl.mask_times)=5876
len(mdl.mask_time_secs)=5876
len(mdl.mask_times_bad)=5876

Dev version 4.27.1.dev1+g19aed41

4.27.1.dev2+g859870c
name='pm1thv2t' datestart='2019:001' datestop='2022:301' executed in 7.2 ms
len(mdl.bad_times)=0
len(mdl.bad_times_indices)=0
len(mdl.mask_times)=0
len(mdl.mask_time_secs)=0
len(mdl.mask_times_bad)=0
name='pm1thv2t' datestart='2025:001' datestop='2025:002' executed in 1.0 ms
len(mdl.bad_times)=0
len(mdl.bad_times_indices)=0
len(mdl.mask_times)=0
len(mdl.mask_time_secs)=0
len(mdl.mask_times_bad)=0
name='pm2thv1t' datestart='2019:001' datestop='2022:301' executed in 19.8 ms
len(mdl.bad_times)=3842
len(mdl.bad_times_indices)=3842
len(mdl.mask_times)=3842
len(mdl.mask_time_secs)=3842
len(mdl.mask_times_bad)=3842
name='pm2thv1t' datestart='2025:001' datestop='2025:002' executed in 6.0 ms
len(mdl.bad_times)=0
len(mdl.bad_times_indices)=0
len(mdl.mask_times)=0
len(mdl.mask_time_secs)=0
len(mdl.mask_times_bad)=0
name='pm2thv1t' datestart='2022:083:22:30:00' datestop='2022:084:04:00:00' executed in 5.7 ms
len(mdl.bad_times)=2
len(mdl.bad_times_indices)=2
len(mdl.mask_times)=2
len(mdl.mask_time_secs)=2
len(mdl.mask_times_bad)=2

bad_times: np.ndarray = np.array(bad_times_in)

# Get inclusive overlap of bad_times with datestart to datestop
ok = (bad_times[:, 1] > datestart) & (bad_times[:, 0] < datestop)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have to admit I was a bit surprised to see that this works, but my it does! #TIL

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more obvious if you start with the logical converse for a single interval:

no_overlap = time1 < datestart or time0 > datestop
overlap  =  not(time1 < datestart) and not(time0 > datestop)
overlap = time1 >= datestart and time0 <= datestop

And writing it out that way points out that I need the equality in there.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants