
Unexpected get_direct_beam_position behaviour with center_of_mass method. #1079

emichr opened this issue May 22, 2024 · 9 comments
emichr commented May 22, 2024

Describe the bug
When calling shift = signal.get_direct_beam_position(method="center_of_mass") with the half_square_width / signal_slice and mask keyword arguments, the resulting shifts are all NaNs. I expect this is because the mask coordinates are not correctly shifted/transformed after the signal is sliced internally.

To Reproduce
The following code was run with Pyxem 0.18.0

%matplotlib widget
import hyperspy.api as hs
import pyxem as pxm
import numpy as np

# Load dummy data and set signal axes units to uncalibrated units
s = pxm.data.tilt_boundary_data(correct_pivot_point=False)
for ax in s.axes_manager.signal_axes:
    s.axes_manager[ax].scale = 1
    s.axes_manager[ax].offset = 0
s.axes_manager

# Plot the maximum through the stack and add a widget to get an idea of the total direct beam offsets
m = s.max(axis=[0, 1])
m.plot()
roi = hs.roi.CircleROI(131, 131, 14.5)
roi.add_widget(m)
print(roi)

# Test different combinations of arguments.
cx, cy, r = roi.cx, roi.cy, roi.r  # Mask in calibrated coordinates
hsw = 30  # Half square width in pixels

kwargs_list = [{},
    {"half_square_width": hsw}, #Works
    {"mask": (cx, cy, r)}, #Works
    {"half_square_width": hsw, "mask": (cx, cy, r)} #Combination does not work
]
print(kwargs_list)
shifts = [
    s.get_direct_beam_position(method="center_of_mass", **kwargs) for kwargs in kwargs_list
]

hs.plot.plot_signals(shifts)  # Plotting the shifts of the different methods

The last plot clearly shows the method failing.

I did some quick debugging by extracting the relevant parts of the pyxem code:

half_square_width = hsw
from pyxem.utils._signals import (
    _select_method_from_method_dict,
    _to_hyperspy_index,
)


# Code from line 677 to 695 in pyxem/signals/diffraction2d (without the if/else blocks):
signal_shape = s.axes_manager.signal_shape
signal_center = np.array(signal_shape) / 2
min_x = int(signal_center[0] - half_square_width)
max_x = int(signal_center[0] + half_square_width)
min_y = int(signal_center[1] - half_square_width)
max_y = int(signal_center[1] + half_square_width)
signal_slice = (min_x, max_x, min_y, max_y)

sig_axes = s.axes_manager.signal_axes
sig_axes = np.repeat(sig_axes, 2)
low_x, high_x, low_y, high_y = [
    _to_hyperspy_index(ind, ax)
    for ind, ax in zip(
        signal_slice,
        sig_axes,
    )
]
# End of code from pyxem/signals/diffraction2d

ss = s.isig[low_x:high_x, low_y:high_y] #The signal that `get_direct_beam_position(...)` works on when `half_square_width` is supplied
ss.plot() 

ccx, ccy = [hsw + c - o//2 for c, o in zip([cx, cy], s.axes_manager.signal_shape)] #Transform the `cx` and `cy` values into the axes of the sliced `ss` signal

#Calculate COM for different masks:
mask_list = [
    (cx, cy, r), # Same mask that gives the unexpected results above
    (hsw, hsw, r), #Using the half-square-width to "get back" to the pattern center
    (ccx, ccy, r), #Using the transformed `cx` and `cy` center of `ss`
]
print(mask_list)

coms = [ss.center_of_mass(mask=mask) for mask in mask_list]
hs.plot.plot_signals(coms)

It looks like the last part works and gives the expected result, i.e. I expect the issue would be solved by a suitable offset/transformation of the mask coordinates into the internally sliced signal. This also appears to work with scaled axes:

# Load dummy data, this time keeping the calibrated signal axes
s = pxm.data.tilt_boundary_data(correct_pivot_point=False)
s.axes_manager

half_square_width = hsw
from pyxem.utils._signals import (
    _select_method_from_method_dict,
    _to_hyperspy_index,
)


# Code from line 677 to 695 in pyxem/signals/diffraction2d (without the if/else blocks):
signal_shape = s.axes_manager.signal_shape
signal_center = np.array(signal_shape) / 2
min_x = int(signal_center[0] - half_square_width)
max_x = int(signal_center[0] + half_square_width)
min_y = int(signal_center[1] - half_square_width)
max_y = int(signal_center[1] + half_square_width)
signal_slice = (min_x, max_x, min_y, max_y)

sig_axes = s.axes_manager.signal_axes
sig_axes = np.repeat(sig_axes, 2)
low_x, high_x, low_y, high_y = [
    _to_hyperspy_index(ind, ax)
    for ind, ax in zip(
        signal_slice,
        sig_axes,
    )
]
# End of code from pyxem/signals/diffraction2d

ss = s.isig[low_x:high_x, low_y:high_y] #The signal that `get_direct_beam_position(...)` works on when `half_square_width` is supplied
ss.plot() 

ccx, ccy = [hsw + c - o//2 for c, o in zip([cx, cy], s.axes_manager.signal_shape)] #Transform the `cx` and `cy` values into the axes of the sliced `ss` signal

#Calculate COM for different masks:
mask_list = [
    (cx, cy, r), # Same mask that gives the unexpected results above
    (hsw, hsw, r), #Using the half-square-width to "get back" to the pattern center
    (ccx, ccy, r), #Using the transformed `cx` and `cy` center of `ss`
]
print(mask_list)

coms = [ss.center_of_mass(mask=mask) for mask in mask_list]
hs.plot.plot_signals(coms)
pc494 added the bug label May 25, 2024

emichr commented May 27, 2024

I guess #1057 will change how signal.get_direct_beam_position works and "fix" this issue. Just to be clear about why I wanted to use both the mask and the signal_slice/half_square_width arguments: there is currently a memory leak when finding the direct beam position with the mask argument, which causes issues with big datasets. A workaround (for now) is to first slice the signal down to a rectangular region that contains the ROI of the diffraction pattern and pass a modified mask to get_direct_beam_position. This can be done through e.g. the ROI objects of hyperspy (useful for illustrations along the way):

#First, define the circular region of the data that should be used for centering
_max = signal.max(axis=[0, 1])  # Calculate the maximum through the stack for illustration
_max.plot()
circle = hs.roi.CircleROI(128, 128, 17.5) #Create a circular ROI that you would like to use as a mask for `get_direct_beam_position(method="center_of_mass")` 
circle.add_widget(_max)
s = circle(signal, axes=[2, 3]) #Get the part of the data that contains the circular ROI. This signal cannot be used directly due to missing data outside the circular ROI making the center of mass calculation return NaNs.

#Next, find the corresponding square region of that data
extent = s.axes_manager.signal_extent #Get the signal extent
rectangle = hs.roi.RectangularROI(extent[0], extent[2], extent[1], extent[3]) #Create a rectangular ROI based on the sliced signal extent
s = rectangle(signal, axes=[2,3]) #Replace the sliced/masked signal to work around the missing data.

#Finally, calculate the shifts of the data using the square signal and the circular mask
circle.add_widget(s, axes=[2, 3]) #Add the circular ROI to the new sliced signal
shifts = s.get_direct_beam_position(method='center_of_mass', mask=(circle.cx-s.axes_manager[2].offset, circle.cy-s.axes_manager[3].offset, circle.r)) #Calculate shifts
shifts.make_linear_plane() #Make the linear plane estimation

# [Optional] Center the datasets
signal.center_direct_beam(shifts=shifts)


CSSFrancis commented May 27, 2024

@emichr I think the solution is that it shouldn't be possible to pass the mask in addition to a signal slice.

As for the memory leak: it's fairly likely that dask isn't processing things optimally. What version of pyxem are you using? I changed how the COM is computed in #1005, which should remove a fair bit of the read amplification that was previously happening. In particular, the function was rechunking rather aggressively, which sometimes forces data to be held in memory longer than it should be.

Another thing to consider is that your data should be saved in chunks only in the navigation dimensions. If that is not the case, the following code should automatically rechunk:

s.rechunk()
s.save("rechunked.zspy")


sivborg commented May 27, 2024

@emichr The issue seems to be from a small bug in the code, at this line.
The index here should be 2 instead of 1. The use of a mask in addition to slicing is a bit of an edge case and wasn't considered fully due to the niche use case; in fact it is handled with a less-than-optimal solution that modifies kwargs directly.
I can create a quick PR fixing this issue in a little bit.
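
For context, here is a minimal sketch of the coordinate shift the mask would need when the signal is sliced to a square of half-width half_square_width around the pattern centre. It just mirrors the ccx/ccy calculation in the report above; the helper name is made up and this is not the actual pyxem internals:

def shift_mask_to_slice(mask, signal_shape, half_square_width):
    # Move a (cx, cy, r) mask given in full-frame pixel coordinates into the
    # frame of the internally sliced signal.
    cx, cy, r = mask
    nx, ny = signal_shape
    new_cx = cx - (nx // 2 - half_square_width)
    new_cy = cy - (ny // 2 - half_square_width)
    return (new_cx, new_cy, r)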


sivborg commented May 27, 2024

I have created a fix for this issue in #1080.


emichr commented May 28, 2024

> @emichr I think the solution is that it shouldn't be possible to pass the mask in addition to a signal slice.
>
> As for the memory leak: it's fairly likely that dask isn't processing things optimally. What version of pyxem are you using? I changed how the COM is computed in #1005, which should remove a fair bit of the read amplification that was previously happening. In particular, the function was rechunking rather aggressively, which sometimes forces data to be held in memory longer than it should be.
>
> Another thing to consider is that your data should be saved in chunks only in the navigation dimensions. If that is not the case, the following code should automatically rechunk:
>
> s.rechunk()
> s.save("rechunked.zspy")

@CSSFrancis I am using pyxem 0.18.0, so I guess the changes you made to the COM computation should be in there. Thanks for the tips; converting to .zspy and using appropriate chunking might work better - I'll give that a try! I might get a headache from working with .zspy files on an HPC cluster with restrictions on the number of files, but that's not really a pyxem headache :)

Just to be clear though, I was not working on lazy data. It seems that when working with lazy data on an HPC cluster, the memory requirements are higher than I would expect (not sure why, but there are memory spikes above the size of the dataset when e.g. loading .mib data lazily and when calculating COM etc.). So I figured I should use the memory I ask for anyway and work non-lazily.

@sivborg Great, thanks for looking into this and squashing that bug! I guess the best solution would be to slice the signal down to the mask size when a mask is given, but that might be for the upcoming pull requests for 0.19.0 or 1.0.0? Maybe the unit tests for beam centering with COM should also be looked at to see if they can be improved to catch similar bugs?


emichr commented May 28, 2024

> Just to be clear though, I was not working on lazy data. It seems that when working with lazy data on an HPC cluster, the memory requirements are higher than I would expect (not sure why, but there are memory spikes above the size of the dataset when e.g. loading .mib data lazily and when calculating COM etc.). So I figured I should use the memory I ask for anyway and work non-lazily.

To be a bit more quantitative here: when loading a 31.3 GB .mib dataset with hs.load('dataset.mib', lazy=True), the RSS memory spikes at about 95 GB. I might create a separate issue over at hyperspy about this.

CSSFrancis commented

Part of this is fixed by #1080

magnunor commented

I think the bigger issue with get_direct_beam_position is that there are currently multiple ways of "cropping" the signal dimension during the processing.

In my opinion, we should only have one common way of doing this. I think it should be some slicing, followed by some kind of round mask, which I think makes the most sense for this kind of data.

We should make sure that the data is sliced before any "masking", to avoid using too much memory.
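
A minimal sketch of that order on a single pattern, using plain numpy (the function name and signature are illustrative only, not pyxem API): slice a square around the expected centre first, then apply a round mask and compute the centre of mass inside it.

import numpy as np

def sliced_masked_com(pattern, center, half_width, radius):
    # Slice a square region around the expected centre first,
    # so the round mask only ever touches a small array.
    cy, cx = center
    sub = pattern[cy - half_width:cy + half_width, cx - half_width:cx + half_width]
    # Then build the round mask in the sliced frame and compute the COM inside it.
    yy, xx = np.mgrid[0:sub.shape[0], 0:sub.shape[1]]
    mask = (yy - half_width) ** 2 + (xx - half_width) ** 2 <= radius ** 2
    masked = np.where(mask, sub, 0.0)
    total = masked.sum()
    com_y = (yy * masked).sum() / total + (cy - half_width)
    com_x = (xx * masked).sum() / total + (cx - half_width)
    return com_x, com_y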


> Another thing to consider is that your data should be saved in chunks only in the navigation dimensions. If that is not the case, the following code should automatically rechunk:

@CSSFrancis, I don't think that is always the optimal chunking. For example, if you only want to use the direct beam and ignore the rest of the signal space, having chunks in the signal dimension means you will not have to load the full dataset from the hard drive.

> Converting to .zspy and using appropriate chunking might work better - I'll give that a try! I might get a headache from working with .zspy files on an HPC cluster with restrictions on the number of files, but that's not really a pyxem headache :)

@emichr, you could also try the zarr ZipStore. It seems to work pretty well, at least locally on my own computer: hyperspy/rosettasciio#249 (comment)

> To be a bit more quantitative here: when loading a 31.3 GB .mib dataset with hs.load('dataset.mib', lazy=True), the RSS memory spikes at about 95 GB. I might create a separate issue over at hyperspy about this.

In case someone else bumps into the same issue in the future: we figured out the origin of this bug. It is due to a change in dask: dask/dask#11152 (comment)

A temporary workaround until this is fixed upstream is to downgrade to dask version 2024.1.1.


CSSFrancis commented May 31, 2024

> @CSSFrancis, I don't think that is always the optimal chunking. For example, if you only want to use the direct beam and ignore the rest of the signal space, having chunks in the signal dimension means you will not have to load the full dataset from the hard drive.

@magnunor I'm definitely in the camp of 1. make everything run and 2. make everything run fast. With the way the map function works, data gets forced into equal chunks in all dimensions, and that can cause some rechunking, which can increase your memory usage. That being said, dask is much better at this now than it was two years ago, as it now throttles tasks that produce a lot of data in memory.

Disk IO as a bottleneck is a lot less of a problem now with something like zarr than it was previously. Spinning-disk hard drives read at ~160 MB/sec, and you can pretty easily put 10 of them in an array to get over 1 GB/sec. Paired with something like a factor of 10 from good compression, most likely the CPU becomes the bottleneck, or things like passing data between workers in dask during rechunking become an issue. That said, I think dask/distributed#7507 should solve some of the rechunking problems dask has.

There are a couple other fun things I've found over the years:

  1. Memory-mapping binary files into equal-dimensional chunks is slow, because the data is stored as [DP1, DP2, ...], so that chunking requires tons of movement on the disk. For big files it is faster to read the data chunked along only one of the navigation dimensions, compress and save it, then read it back, rechunk it and save it again (see the sketch after this list).
  2. Plotting with equal chunks is faster than plotting without, as you can load and decompress in parallel. I'd love to be able to implement something like sharding to take advantage of this.
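
For what it's worth, a rough sketch of that two-step conversion (the file names are made up; lazy loading, rechunking and saving are standard hyperspy calls):

import hyperspy.api as hs

s = hs.load("dataset.mib", lazy=True)        # lazy (memory-mapped) load of the raw data
s.save("intermediate.zspy")                  # write a compressed copy to disk
s = hs.load("intermediate.zspy", lazy=True)
s.rechunk()                                  # rechunk (navigation-only chunks, as suggested above)
s.save("rechunked.zspy", overwrite=True)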

So while, yes, in this case you are correct that slicing and then doing the DPC is faster, it comes with some additional considerations and potentially some workflows that no longer work.

I've been starting to write up a 4D STEM hardware/software guide that I was going to publish, as I think some of these things are a bit lost. It's a little more confusing when you add a GPU to the mix and try to figure out what the biggest issue is.
