Skip to content

Commit

Permalink
Ensure the first channel column doesn't get added when Rosetta round …
Browse files Browse the repository at this point in the history
…2 compensation matrices are combined (#444)
  • Loading branch information
alex-l-kong authored Oct 20, 2023
1 parent 81d28e2 commit 9264f13
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
5 changes: 3 additions & 2 deletions src/toffy/rosetta.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,10 @@ def combine_compensation_files(comp_matrix_path, compensation_matrix_names, fina
)

# loop over the rest and add them in
# NOTE: skip the first column, since that just delineates the channels
for matrix in compensation_matrix_names[1:]:
final_compensation_matrix = final_compensation_matrix.add(
pd.read_csv(os.path.join(comp_matrix_path, matrix))
final_compensation_matrix.iloc[:, 1:] = final_compensation_matrix.iloc[:, 1:].add(
pd.read_csv(os.path.join(comp_matrix_path, matrix)).iloc[:, 1:]
)

# save the final compensation matrix to final_matrix_name
Expand Down
15 changes: 11 additions & 4 deletions tests/rosetta_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,11 @@ def test_combine_compensation_files():
for m in mults:
for cp in channel_pairs:
df = pd.DataFrame(
np.zeros((3, 3)),
np.zeros((3, 4)),
index=channels,
columns=channels,
columns=[None] + channels,
)
df.iloc[:, 0] = np.arange(3)
df.loc[cp[0], cp[1]] = m
df.to_csv(
os.path.join(rosetta_test_dir, f"{cp[0]}_{cp[1]}_compensation_matrix_{m}.csv"),
Expand All @@ -261,12 +262,18 @@ def test_combine_compensation_files():
rosetta_test_dir, compensation_matrices, "final_rosetta_matrix.csv"
)

# assert final_rosetta_matrix.csv created and generated correctly
# assert final_rosetta_matrix.csv created
final_rosetta_matrix_path = os.path.join(rosetta_test_dir, "final_rosetta_matrix.csv")
assert os.path.exists(final_rosetta_matrix_path)

# assert the compensation coefficients got combined correctly
final_rosetta_matrix = pd.read_csv(final_rosetta_matrix_path)
actual_final_values = np.array([[0, 0.5, 0], [0, 0, 1], [2, 0, 0]])
assert np.all(final_rosetta_matrix.values == actual_final_values)
assert np.all(final_rosetta_matrix.values[:, 1:] == actual_final_values)

# assert the channel column didn't get modified
actual_column_values = np.arange(3)
assert np.all(final_rosetta_matrix.iloc[:, 0].values == actual_column_values)


def test_flat_field_correction():
Expand Down

0 comments on commit 9264f13

Please sign in to comment.