Skip to content

Commit

Permalink
Coding norms
Browse files Browse the repository at this point in the history
  • Loading branch information
azadeh-gh committed May 1, 2024
1 parent dbdcf79 commit 3b6f933
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 34 deletions.
55 changes: 26 additions & 29 deletions ush/SpatialTemporalStatsTool/SpatialTemporalStats.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,21 +235,24 @@ def plot_obs(self, selected_var_gdf, var_name, region, resolution, output_path):
)

if item == "Obs_Minus_Forecast_adjusted_Average":
max_val_cbar = 5.0*std_val
min_val_cbar = -5.0*std_val
max_val_cbar = 5.0 * std_val
min_val_cbar = -5.0 * std_val
cmap = "bwr"
else:
max_val_cbar = max_val
min_val_cbar = min_val
cmap = "jet"

cbar_label = "grid=%dx%d, min=%.3lf, max=%.3lf, bias=%.3lf, std=%.3lf\n" % (
resolution,
resolution,
min_val,
max_val,
avg_val,
std_val,
cbar_label = (
"grid=%dx%d, min=%.3lf, max=%.3lf, bias=%.3lf, std=%.3lf\n"
% (
resolution,
resolution,
min_val,
max_val,
avg_val,
std_val,
)
)

filtered_gdf.plot(
Expand Down Expand Up @@ -320,14 +323,14 @@ def make_summary_plots(
self.sensor = sensor
# read all obs files
all_files = os.listdir(obs_files_path)
# obs_files = [os.path.join(obs_files_path, file) for file in all_files if file.endswith('.nc4')]
obs_files = [
os.path.join(obs_files_path, file)
for file in all_files
if file.endswith(".nc4") and "diag_%s_ges" % sensor in file
]

# get date time from file names. alternatively could get from attribute but that needs reading the entire nc4
# get date time from file names.
# alternatively could get from attribute but that needs reading the entire nc4
files_date_times_df = pd.DataFrame()

files_date_times = self._extract_date_times(obs_files)
Expand All @@ -354,51 +357,47 @@ def make_summary_plots(
# get unique channels from one of the files
ds = xarray.open_dataset(studied_cycle_files[index[0]])
unique_channels = np.unique(ds["Channel_Index"].data).tolist()
print('Total Number of Channels ', len(unique_channels))

Allchannels_data={}
print("Total Number of Channels ", len(unique_channels))
Allchannels_data = {}
for this_channel in unique_channels:
Allchannels_data[this_channel] = np.empty(shape=(0,))
Allchannels_data[this_channel] = np.empty(shape=(0,))
for this_cycle_obs_file in studied_cycle_files:
ds = xarray.open_dataset(this_cycle_obs_file)
if QC_filter:
QC_bool = ds["QC_Flag"].data == 0

for this_channel in unique_channels:
channel_bool = ds["Channel_Index"].data == this_channel

this_cycle_channel_var_values = ds[var_name].data[channel_bool*QC_bool]
this_cycle_channel_var_values = ds[var_name].data[
channel_bool * QC_bool
]
Allchannels_data[this_channel] = np.append(
Allchannels_data[this_channel], this_cycle_channel_var_values
)

for this_channel in unique_channels:
this_channel_values=Allchannels_data[this_channel]
this_channel_values = Allchannels_data[this_channel]
squared_values = [x**2 for x in this_channel_values]
mean_of_squares = sum(squared_values) / len(squared_values)
rms_value=mean_of_squares ** 0.5
rms_value = mean_of_squares ** 0.5
Summary_results.append(
[
this_channel,
np.size(this_channel_values),
np.std(this_channel_values),
np.mean(this_channel_values),
rms_value
rms_value,
]
)


Summary_resultsDF = pd.DataFrame(
Summary_results, columns=["channel", "count", "std", "mean", "rms"]
)
Summary_results, columns=["channel", "count", "std", "mean", "rms"])
# Plotting
plt.figure(figsize=(10, 6))
plt.scatter(Summary_resultsDF["channel"], Summary_resultsDF["count"], s=50)
plt.xlabel("Channel")
plt.ylabel("Count")
plt.title("%s %s" % ((self.sensor, var_name)))
#plt.xticks(Summary_resultsDF["channel"])
#plt.xticks(rotation=45)
plt.grid(True)
plt.tight_layout()
plt.savefig(
Expand Down Expand Up @@ -429,14 +428,12 @@ def make_summary_plots(
Summary_resultsDF["rms"],
s=50,
label="Rms",
facecolors='none',
edgecolors='blue'
facecolors="none",
edgecolors="blue",
)
plt.xlabel("Channel")
plt.ylabel("Statistics")
plt.title("%s %s" % ((self.sensor, var_name)))
#plt.xticks(Summary_resultsDF["channel"])
#plt.xticks(rotation=45)
plt.grid(True)
plt.tight_layout()
plt.legend()
Expand Down
7 changes: 2 additions & 5 deletions ush/SpatialTemporalStatsTool/user_Analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

# Set input and output paths
input_path = "/PATH/TO/Input/Files"

#output_path = r'./Results'

output_path = r'./Results'

# Set sensor name
sensor = "iasi_metop-c"
Expand All @@ -14,7 +12,7 @@
channel_no = 1

# Set start and end dates
start_date, end_date = '2024-01-01', '2024-01-31'
start_date, end_date = "2024-01-01", "2024-01-31"

# Set region
# 1: global, 2: polar region, 3: mid-latitudes region,
Expand Down Expand Up @@ -69,5 +67,4 @@
input_path, sensor, var_name, start_date, end_date, QC_filter, output_path
)
print("Summary plots created!")

# Print summary results

0 comments on commit 3b6f933

Please sign in to comment.