diff --git a/ush/SpatialTemporalStatsTool/SpatialTemporalStats.py b/ush/SpatialTemporalStatsTool/SpatialTemporalStats.py index ffca03f..2ba6b46 100644 --- a/ush/SpatialTemporalStatsTool/SpatialTemporalStats.py +++ b/ush/SpatialTemporalStatsTool/SpatialTemporalStats.py @@ -235,21 +235,24 @@ def plot_obs(self, selected_var_gdf, var_name, region, resolution, output_path): ) if item == "Obs_Minus_Forecast_adjusted_Average": - max_val_cbar = 5.0*std_val - min_val_cbar = -5.0*std_val + max_val_cbar = 5.0 * std_val + min_val_cbar = -5.0 * std_val cmap = "bwr" else: max_val_cbar = max_val min_val_cbar = min_val cmap = "jet" - cbar_label = "grid=%dx%d, min=%.3lf, max=%.3lf, bias=%.3lf, std=%.3lf\n" % ( - resolution, - resolution, - min_val, - max_val, - avg_val, - std_val, + cbar_label = ( + "grid=%dx%d, min=%.3lf, max=%.3lf, bias=%.3lf, std=%.3lf\n" + % ( + resolution, + resolution, + min_val, + max_val, + avg_val, + std_val, + ) ) filtered_gdf.plot( @@ -320,14 +323,14 @@ def make_summary_plots( self.sensor = sensor # read all obs files all_files = os.listdir(obs_files_path) - # obs_files = [os.path.join(obs_files_path, file) for file in all_files if file.endswith('.nc4')] obs_files = [ os.path.join(obs_files_path, file) for file in all_files if file.endswith(".nc4") and "diag_%s_ges" % sensor in file ] - # get date time from file names. alternatively could get from attribute but that needs reading the entire nc4 + # get date time from file names. + # alternatively could get from attribute but that needs reading the entire nc4 files_date_times_df = pd.DataFrame() files_date_times = self._extract_date_times(obs_files) @@ -354,51 +357,47 @@ def make_summary_plots( # get unique channels from one of the files ds = xarray.open_dataset(studied_cycle_files[index[0]]) unique_channels = np.unique(ds["Channel_Index"].data).tolist() - print('Total Number of Channels ', len(unique_channels)) - - Allchannels_data={} + print("Total Number of Channels ", len(unique_channels)) + Allchannels_data = {} for this_channel in unique_channels: - Allchannels_data[this_channel] = np.empty(shape=(0,)) + Allchannels_data[this_channel] = np.empty(shape=(0,)) for this_cycle_obs_file in studied_cycle_files: ds = xarray.open_dataset(this_cycle_obs_file) if QC_filter: QC_bool = ds["QC_Flag"].data == 0 - for this_channel in unique_channels: channel_bool = ds["Channel_Index"].data == this_channel - this_cycle_channel_var_values = ds[var_name].data[channel_bool*QC_bool] + this_cycle_channel_var_values = ds[var_name].data[ + channel_bool * QC_bool + ] Allchannels_data[this_channel] = np.append( Allchannels_data[this_channel], this_cycle_channel_var_values ) for this_channel in unique_channels: - this_channel_values=Allchannels_data[this_channel] + this_channel_values = Allchannels_data[this_channel] squared_values = [x**2 for x in this_channel_values] mean_of_squares = sum(squared_values) / len(squared_values) - rms_value=mean_of_squares ** 0.5 + rms_value = mean_of_squares ** 0.5 Summary_results.append( [ this_channel, np.size(this_channel_values), np.std(this_channel_values), np.mean(this_channel_values), - rms_value + rms_value, ] ) - Summary_resultsDF = pd.DataFrame( - Summary_results, columns=["channel", "count", "std", "mean", "rms"] - ) + Summary_results, columns=["channel", "count", "std", "mean", "rms"]) # Plotting plt.figure(figsize=(10, 6)) plt.scatter(Summary_resultsDF["channel"], Summary_resultsDF["count"], s=50) plt.xlabel("Channel") plt.ylabel("Count") plt.title("%s %s" % ((self.sensor, var_name))) - #plt.xticks(Summary_resultsDF["channel"]) - #plt.xticks(rotation=45) plt.grid(True) plt.tight_layout() plt.savefig( @@ -429,14 +428,12 @@ def make_summary_plots( Summary_resultsDF["rms"], s=50, label="Rms", - facecolors='none', - edgecolors='blue' + facecolors="none", + edgecolors="blue", ) plt.xlabel("Channel") plt.ylabel("Statistics") plt.title("%s %s" % ((self.sensor, var_name))) - #plt.xticks(Summary_resultsDF["channel"]) - #plt.xticks(rotation=45) plt.grid(True) plt.tight_layout() plt.legend() diff --git a/ush/SpatialTemporalStatsTool/user_Analysis.py b/ush/SpatialTemporalStatsTool/user_Analysis.py index c98354a..7f69c82 100644 --- a/ush/SpatialTemporalStatsTool/user_Analysis.py +++ b/ush/SpatialTemporalStatsTool/user_Analysis.py @@ -2,9 +2,7 @@ # Set input and output paths input_path = "/PATH/TO/Input/Files" - -#output_path = r'./Results' - +output_path = r'./Results' # Set sensor name sensor = "iasi_metop-c" @@ -14,7 +12,7 @@ channel_no = 1 # Set start and end dates -start_date, end_date = '2024-01-01', '2024-01-31' +start_date, end_date = "2024-01-01", "2024-01-31" # Set region # 1: global, 2: polar region, 3: mid-latitudes region, @@ -69,5 +67,4 @@ input_path, sensor, var_name, start_date, end_date, QC_filter, output_path ) print("Summary plots created!") - # Print summary results