diff --git a/src/forecast.py b/src/forecast.py index 2cf77a9..efb8956 100644 --- a/src/forecast.py +++ b/src/forecast.py @@ -45,16 +45,21 @@ def forecast_page(): ][0] forecast_models = st.sidebar.multiselect( - "Select a model", ["cnn", "National_xg", "pvnet_v2", "blend"], ["cnn"] + "Select a model", ["cnn", "National_xg", "pvnet_v2", "blend"], ["pvnet_v2"] ) + probabilistic_forecasts = [model for model in forecast_models if model in ["National_xg", "pvnet_v2", "blend"]] + + if len(probabilistic_forecasts) > 0: + show_prob = st.sidebar.checkbox('Show Probabilities Forecast', value=False) + else: + show_prob = False + if gsp_id != 0: if "National_xg" in forecast_models: forecast_models.remove("National_xg") st.sidebar.warning("National_xg only available for National forecast.") - if gsp_id == 0: - show_prob = st.sidebar.checkbox('Show Probabilities Forecast', value=False) use_adjuster = st.sidebar.radio("Use adjuster", [True, False], index=1) forecast_type = st.sidebar.radio( @@ -63,13 +68,13 @@ def forecast_page(): if forecast_type == "Creation Time": now = datetime.now(tz=timezone.utc) - timedelta(days=1) d = st.sidebar.date_input("Forecast creation date:", now.date()) - t = st.sidebar.time_input("Forecast creation time", time(12, 00)) + t = st.sidebar.time_input("Forecast creation time", time(8, 0)) forecast_time = datetime.combine(d, t) st.sidebar.write(f"Forecast creation time: {forecast_time}") elif forecast_type == "Forecast Horizon": now = datetime.now(tz=timezone.utc) - timedelta(days=1) start_d = st.sidebar.date_input("Forecast start date:", now.date()) - start_t = st.sidebar.time_input("Forecast start time", time(12, 00)) + start_t = st.sidebar.time_input("Forecast start time", time(0, 0)) start_datetime = datetime.combine(start_d, start_t) end_datetime = start_datetime + timedelta(days=2) @@ -199,39 +204,39 @@ def forecast_page(): ) ) - if model != "cnn" and len(forecast) > 0 and show_prob: - try: - properties_0 = forecast[0]._properties - if isinstance(properties_0, dict): - assert "10" in properties_0.keys() and "90" in properties_0.keys() - plevel_10 = [i._properties["10"] for i in forecast] - plevel_90 = [i._properties["90"] for i in forecast] - - fig.add_trace( - go.Scatter( - x=x, - y=plevel_10, - mode="lines", - name="p10: " + model, - line=dict(color=colour_per_model[model], width=0), - showlegend=False, + if len(forecast) > 0 and show_prob: + try: + properties_0 = forecast[0]._properties + if isinstance(properties_0, dict): + assert "10" in properties_0.keys() and "90" in properties_0.keys() + plevel_10 = [i._properties["10"] for i in forecast] + plevel_90 = [i._properties["90"] for i in forecast] + + fig.add_trace( + go.Scatter( + x=x, + y=plevel_10, + mode="lines", + name="p10: " + model, + line=dict(color=colour_per_model[model], width=0), + showlegend=False, + ) ) - ) - fig.add_trace( - go.Scatter( - x=x, - y=plevel_90, - mode="lines", - name="p90: " + model, - line=dict(color=colour_per_model[model], width=0), - fill="tonexty", - showlegend=False, + fig.add_trace( + go.Scatter( + x=x, + y=plevel_90, + mode="lines", + name="p90: " + model, + line=dict(color=colour_per_model[model], width=0), + fill="tonexty", + showlegend=False, + ) ) - ) - except Exception as e: - print(e) - print("Could not add plevel to chart") - raise e + except Exception as e: + print(e) + print("Could not add plevel to chart") + raise e # pvlive on the chart for k, v in pvlive_data.items(): diff --git a/src/main.py b/src/main.py index 013010c..9ca46e5 100644 --- a/src/main.py +++ b/src/main.py @@ -51,7 +51,7 @@ def metric_page(): use_adjuster = st.sidebar.radio("Use adjuster", [True, False], index=1) st.sidebar.subheader("Select Forecast Model") - model_name = st.sidebar.selectbox("Select", ["cnn", "National_xg", "pvnet_v2"]) + model_name = st.sidebar.selectbox("Select", ["cnn", "National_xg", "pvnet_v2"], "pvnet_v2") # set up database connection url = os.environ["DB_URL"]