This repository has been archived by the owner on Aug 29, 2023. It is now read-only.

Dynamic kwargs generation as part of the recipe #7

Merged · 7 commits · May 17, 2022
Changes from 1 commit
262 changes: 247 additions & 15 deletions feedstock/recipe.py
@@ -4,8 +4,10 @@
import requests
import numpy
import pandas as pd
import warnings
from typing import List, Dict

# dummy comment
# copied from Naomi's code https://github.com/pangeo-data/pangeo-cmip6-cloud/blob/master/myconfig.py
target_keys = [
"activity_id",
@@ -173,6 +175,7 @@ def esgf_search(

return dz


####################################################
from pangeo_forge_recipes.patterns import pattern_from_file_sequence
from pangeo_forge_recipes.recipes import XarrayZarrRecipe
@@ -188,6 +191,7 @@ def esgf_search(
"dkrz": "https://esgf-data.dkrz.de/esg-search/search",
}


def urls_from_instance_id(instance_id):
# get facets from instance_id
facet_labels = (
@@ -215,8 +219,7 @@ def urls_from_instance_id(instance_id):

if facets["mip_era"] != "CMIP6":
        raise ValueError("Only CMIP6 mip_era supported")

# version doesn't work here
keep_facets = (
"activity_id",
@@ -235,7 +238,6 @@ def urls_from_instance_id(instance_id):
search_node
    ]  # TODO: We might have to be more clever here and search through different nodes. For later.

df = esgf_search(search_facets, server=ESGF_site) # this modifies the dict inside?

# get list of urls
@@ -244,31 +246,261 @@ def urls_from_instance_id(instance_id):
    # sort urls in ascending time order (so they can be passed directly to the pangeo-forge recipe)
end_dates = [url.split("-")[-1].replace(".nc", "") for url in urls]
urls = [url for _, url in sorted(zip(end_dates, urls))]

# version is still not working
# if facets["version"].startswith("v"):
# facets["version"] = facets["version"][1:]

# TODO Check that there are no gaps or duplicates.

return urls


## Misc logic
def facets_from_iid(iid):
iid_name_template = "mip_era.activity_id.institution_id.source_id.experiment_id.variant_label.table_id.variable_id.grid_label.version"
facets = {}
for name, value in zip(iid_name_template.split("."), iid.split(".")):
facets[name] = value
return facets
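
# Illustrative example of the mapping above (iid taken from the list used below):
# facets_from_iid("CMIP6.CMIP.CCCma.CanESM5.historical.r1i1p1f1.Omon.zos.gn.v20190429")
# returns {"mip_era": "CMIP6", "activity_id": "CMIP", ..., "version": "v20190429"}.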


## Logic to dynamically generate input kwargs
def choose_chunksize(
chunksize_candidates: List[int],
max_size: float,
element_size_lst: List[float],
timesteps_lst: List[int],
include_last: bool = True,
) -> int:
    """Determine the ideal chunksize based on a list of candidate chunksizes and
    information about the input files, given the following constraints:
    - The resulting chunks are smaller than `max_size`
    - The determined chunksize divides each file into even chunks
      (if `include_last` is False, the last file is allowed to have uneven chunks,
      but the chunksize cannot be larger than the number of timesteps in the last file)

    Parameters
    ----------
    chunksize_candidates : List[int]
        A list of chunksizes to consider.
    max_size : float
        Maximum size (in bytes) of the resulting chunks
    element_size_lst : List[float]
        List of sizes (in bytes) of a single element along the chunking dimension (often time)
        for each of the input elements (files).
    timesteps_lst : List[int]
        List of timesteps for each input element
    include_last : bool, optional
        Option to include or exclude the last element of the lists above, by default True.
        If the lists above contain only one element, this is always True.

    Returns
    -------
    int
        Chosen chunksize
    """
# # TODO: infer clean divisions of the divisor (e.g. [1, 2, 3, 4, 6] for 12) automatically here
# candidate_chunks = divisors[:-1]+list(range(divisors[-1], max(timesteps_lst), divisors[-1]))

if (
not include_last and len(timesteps_lst) > 1
): # we cannot exclude the last one if there is only one element.
chunksize_filtered = [
cs
for cs in chunksize_candidates
if all(
nt % cs == 0 for nt in timesteps_lst[:-1]
) # do I need and timesteps_lst[-1] > cs
]
else:
chunksize_filtered = [
cs
for cs in chunksize_candidates
if all(nt % cs == 0 for nt in timesteps_lst)
]
output_chunksizes = [
max([cs for cs in chunksize_filtered if cs * element_size <= max_size])
for element_size in element_size_lst
]
# what do we do if somehow this ends up being different? Take the min/max?
if not all(oc == output_chunksizes[0] for oc in output_chunksizes):
raise ValueError("Determined chunksizes are not all equal.")
else:
return output_chunksizes[0]
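
# A minimal worked example for choose_chunksize (hypothetical numbers): two files
# with 600 and 1200 monthly steps at ~0.5 MB per step and a 200 MB chunk limit.
# All of [1, 12, 60, 120] divide both files evenly, and 120 * 0.5 MB = 60 MB
# stays under the limit, so the largest candidate is chosen:
# choose_chunksize([1, 12, 60, 120], 200e6, [0.5e6, 0.5e6], [600, 1200])  # -> 120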


def dynamic_kwarg_generation(
iid: str,
) -> Dict[str, Dict[str, int]]:
    """Dynamically generates the keyword arguments `target_chunks` and `subset_inputs` for
    recipe generation, based on information available via the ESGF API

Parameters
----------
iid : str
ESGF instance_id

Returns
-------
Dict[str,Dict[str, int]]
Dictionary containing keyword arguments that can be passed to `XarrayZarrRecipe`

"""
    # TODO: I query the API in multiple places. Need to refactor this into something robust (which might try different urls?)
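    # Illustrative return value (hypothetical numbers; the actual values depend on
    # the dataset): {"target_chunks": {"time": 120}, "subset_inputs": {"time": 2}}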

url = "https://esgf-node.llnl.gov/esg-search/search"
# url = "https://esgf-data.dkrz.de/esg-search/search"

# TODO: the 'distrib' parameter does not work as expected for all datasets.
# Need to investigate that on the ESGF side.
# Could just iterate through nodes for now.

params = {
"type": "File",
"retracted": "false",
"replica": "false",
"format": "application/solr+json",
# "fields": "size",
"latest": "true",
# "distrib": "true",
"limit": 500,
}

facets = facets_from_iid(iid)
params.update(facets)

del params[
"version"
] # TODO: Why do we have to delete this? Need to understand that better
resp = requests.get(url=url, params=params)

file_resp = resp.json()["response"]["docs"]

if not len(file_resp) > 0:
raise ValueError("ESGF API query did not return any files.")

# Check that all responses indeed have the same attributes
# (error out on e.g. mixed versions for now)
# TODO: We might allow mixed versions later, but need to be careful with that!
check_facets = [
"mip_era",
"activity_id",
"institution_id",
"source_id",
"experiment_id",
"variant_label",
"table_id",
"variable_id",
"grid_label",
"version",
]

def _check_single_element_list(lst):
# double check that the facet returns are just a single element
[out] = lst # errors on a list with more than one element
return out
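    # e.g. _check_single_element_list(["CMIP6"]) returns "CMIP6", while a
    # multi-element list such as ["a", "b"] raises ValueError during unpacking.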

for fac in check_facets:
file_facets = [_check_single_element_list(f[fac]) for f in file_resp]
if not all(ff == file_facets[0] for ff in file_facets):
raise ValueError(
f"Found non-matching values for {fac} in search query response. Got {file_facets}"
)

    # now make sure that the table_id is a key in `allowed_divisors`, otherwise error
table_id = file_resp[0]["table_id"][
0
] # Confirmed before that this list is only 1 element
if table_id not in allowed_divisors.keys():
        raise ValueError(
            f"Didn't find `table_id` value {table_id} in the `allowed_divisors` dict."
        )

filesizes = [f["size"] for f in file_resp]

# extract date range from filename
# TODO: Is there a more robust way to do this?
# otherwise maybe use `id` (harder to parse)
dates = [a["title"].replace(".nc", "").split("_")[-1].split("-") for a in file_resp]

# infer number of timesteps using pandas
def format_date(str_date):
return "-".join([str_date[0:4], str_date[4:]])
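    # e.g. format_date("185001") -> "1850-01". pd.date_range("1850-01", "1869-12",
    # freq="1MS") then has 240 entries, counting both endpoints at month starts.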

# TODO: For non-monthly data, we have to make the freq input smarter
timesteps = [
len(pd.date_range(format_date(a[0]), format_date(a[1]), freq="1MS"))
for a in dates
]
element_sizes = [size / n_t for size, n_t in zip(filesizes, timesteps)]

target_chunks = {
"time": choose_chunksize(
allowed_divisors[table_id],
200e6,
element_sizes,
timesteps,
include_last=False,
)
}

subset_chunks = choose_chunksize(
allowed_divisors[table_id],
500e6,
[max(element_sizes)],
[max(timesteps)],
include_last=True,
)
# print([es*subset_chunks/1e6 for es in element_sizes])
# convert chunksize into number of chunks
subset_input = int(max(timesteps) / subset_chunks)
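    # Illustrative arithmetic (hypothetical numbers): 1980 timesteps in the largest
    # file and a subset chunksize of 990 give subset_input = 1980 / 990 = 2,
    # i.e. each input file would be split into two pieces on read.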

    # make sure that this actually divides all files cleanly
if not all(ts % subset_input == 0 for ts in timesteps):
warnings.warn(
"The dynamically inferred `subset_input` does not divide each file cleanly"
)

dynamic_kwargs = {"target_chunks": target_chunks}
if subset_input > 1:
dynamic_kwargs["subset_inputs"] = {"time": subset_input}
print(f"Dynamically determined kwargs: {dynamic_kwargs} for {iid}")
print(
f"Will result in max chunksize of {max(element_sizes)*target_chunks['time']/1e6}MB"
)
return dynamic_kwargs


## global variables

# For certain table_ids it is preferable to have time chunks that are a multiple of e.g. 1 year for monthly data.
monthly_divisors = [1, 3, 6, 12, 12 * 3] + list(range(12 * 5, 12 * 200, 12 * 5))
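# monthly_divisors evaluates to [1, 3, 6, 12, 36, 60, 120, ..., 2340]: sub-annual
# divisors plus multiples of 5 years (60 months) up to just under 200 years.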
allowed_divisors = {
"Omon": monthly_divisors,
"SImon": monthly_divisors,
} # Add table_ids and allowed divisors as needed


## Recipe Generation
iids = [
"CMIP6.CMIP.CCCma.CanESM5.historical.r1i1p1f1.Omon.zos.gn.v20190429",
"CMIP6.CMIP.CCCma.CanESM5.historical.r1i1p1f1.Omon.so.gn.v20190429",
]
inputs = {iid: dynamic_kwarg_generation(iid) for iid in iids}


def recipe_from_urls(urls, instance_kwargs):
pattern = pattern_from_file_sequence(urls, "time")

    recipe = XarrayZarrRecipe(
        pattern, xarray_concat_kwargs={"join": "exact"}, **instance_kwargs
)
return recipe


recipes = {
iid: recipe_from_urls(urls_from_instance_id(iid), kwargs)
for iid, kwargs in inputs.items()
}