Skip to content

Commit

Permalink
[doc/fix]add a Jupyternotebook for bezierfit and fix some small bugs …
Browse files Browse the repository at this point in the history
…in bezierfit
  • Loading branch information
ZhenHuangLab committed Mar 14, 2024
1 parent 687a45c commit debd521
Show file tree
Hide file tree
Showing 29 changed files with 2,229 additions and 173 deletions.
2 changes: 0 additions & 2 deletions build/lib/memxterminator/GUI/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
import sys
sys.path.insert(0, '/data3/Zhen/MemXTerminator/src/')
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def retranslateUi(self, ParticleMembraneSubtraction_bezierfit):
self.controlpoints_label.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Control points file"))
self.template_browse_pushButton.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Browse..."))
self.particle.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Particle .cs file"))
self.points_step_lineEdit.setText(_translate("ParticleMembraneSubtraction_bezierfit", "0.005"))
self.points_step_lineEdit.setText(_translate("ParticleMembraneSubtraction_bezierfit", "0.001"))
self.points_step_label.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Points_step"))
self.physical_membrane_dist_label.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Physical membrane distance(Å)"))
self.physical_membrane_dist_lineEdit.setText(_translate("ParticleMembraneSubtraction_bezierfit", "35"))
Expand Down
2 changes: 1 addition & 1 deletion build/lib/memxterminator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__author__ = 'Zhen Huang'
__email__ = 'zhen.victor.huang@gmail.com'
__version__ = '1.2.1'
__version__ = '1.2.2'
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def membrane_subtract(particle_filename):
shifts = shift_list[mask]
classes = class_list[mask]
for particle_idx, psi, pixel_size, shift, class_ in zip(particle_idxes, psis, pixel_sizes, shifts, classes):
# class_得根据control_points.json找到control_points
# if str(class_) in control_points_dict:
control_points = np.array(control_points_dict[str(class_)])
# print(control_points)
Expand Down
53 changes: 1 addition & 52 deletions build/lib/memxterminator/bezierfit/lib/bezierfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,13 @@
import json
from scipy.ndimage import zoom


# def bezier_curve(control_points, t):
# B = np.outer((1 - t) ** 3, control_points[0]) + \
# np.outer(3 * (1 - t) ** 2 * t, control_points[1]) + \
# np.outer(3 * (1 - t) * t ** 2, control_points[2]) + \
# np.outer(t ** 3, control_points[3])
# return B.squeeze()
def bezier_curve(control_points, t):
n = len(control_points) - 1
B = np.zeros_like(control_points[0], dtype=float)
for i, point in enumerate(control_points):
B += comb(n, i) * (1 - t) ** (n - i) * t ** i * point
return B
# def bezier_curve_derivative(control_points, t):
# control_points = np.array(control_points)
# B_prime = 3 * (1 - t) ** 2 * (control_points[1] - control_points[0]) + \
# 6 * (1 - t) * t * (control_points[2] - control_points[1]) + \
# 3 * t ** 2 * (control_points[3] - control_points[2])
# return B_prime

def bezier_curve_derivative(control_points, t):
n = len(control_points) - 1
B_prime = np.zeros(2)
Expand All @@ -41,25 +29,6 @@ def bezier_curve_derivative(control_points, t):
B_prime += coef * (control_points[i+1] - control_points[i])
return B_prime

# def bezier_curvature(control_points, t):
# dB0 = -3 * (1 - t) ** 2
# dB1 = 3 * (1 - t) ** 2 - 6 * t * (1 - t)
# dB2 = 6 * t * (1 - t) - 3 * t ** 2
# dB3 = 3 * t ** 2

# ddB0 = 6 * (1 - t)
# ddB1 = 6 - 18 * t
# ddB2 = 18 * t - 6
# ddB3 = 6 * t

# p = control_points
# dx = sum([p[i, 0] * [dB0, dB1, dB2, dB3][i] for i in range(4)])
# dy = sum([p[i, 1] * [dB0, dB1, dB2, dB3][i] for i in range(4)])
# ddx = sum([p[i, 0] * [ddB0, ddB1, ddB2, ddB3][i] for i in range(4)])
# ddy = sum([p[i, 1] * [ddB0, ddB1, ddB2, ddB3][i] for i in range(4)])

# curvature = abs(dx * ddy - dy * ddx) / (dx * dx + dy * dy) ** 1.5
# return curvature
def bezier_curvature(control_points, t, threshold=1e-6, high_curvature_value=1e6):
n = len(control_points) - 1

Expand All @@ -79,7 +48,6 @@ def bezier_curvature(control_points, t, threshold=1e-6, high_curvature_value=1e6

magnitude_squared = dx * dx + dy * dy

# 规避除数接近零的问题
if magnitude_squared < threshold:
return high_curvature_value

Expand Down Expand Up @@ -331,22 +299,3 @@ def generate_curve_within_boundaries(control_points, image_shape, step):
break
fitted_curve_points = np.array([bezier_curve(control_points, t_val) for t_val in t_values])
return np.array(fitted_curve_points), np.array(t_values)

if __name__ == '__main__':
multiprocessing.set_start_method('spawn', force=True)
with mrcfile.open('/data3/kzhang/cryosparc/CS-vsv/J354/templates_selected.mrc') as f:
image = f.data[2]
image = zoom(image, 2)
coarsefit = Coarsefit(image, 600, 3, 300, 20)
initial_control_points = coarsefit()
ga_refine = GA_Refine(image, 1.068, 0.05, 50, 700, 18)
refined_control_points = ga_refine(initial_control_points, image)
refined_control_points = np.array(refined_control_points)
fitted_curve_points, t_values = generate_curve_within_boundaries(refined_control_points, image.shape, 0.01)
# save the control points in JSON format
with open('control_points.json', 'w') as f:
json.dump(refined_control_points.tolist(), f)
plt.imshow(image, cmap='gray')
plt.plot(fitted_curve_points[:, 0], fitted_curve_points[:, 1], 'r-')
plt.plot(refined_control_points[:, 0], refined_control_points[:, 1], 'g.')
plt.show()
23 changes: 10 additions & 13 deletions build/lib/memxterminator/bezierfit/lib/subtraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ def generate_2d_mask(self, image, fitted_points, membrane_distance):
membrane_mask[mask_outside_distance & ~mask_small_gray_value] = gray_value[mask_outside_distance & ~mask_small_gray_value]
return membrane_mask

def average_1d(self, image_gpu, fitted_points, normals, extra_mem_dist):
def average_1d(self, image_gpu, fitted_points, normals, mem_dist):
average_1d_lst = []
for membrane_dist in range(-extra_mem_dist, extra_mem_dist+1):
for membrane_dist in range(-mem_dist, mem_dist+1):
normals_points = fitted_points + membrane_dist * normals
# Ensure the points are within the image boundaries
mask = (normals_points[:, 0] >= 0) & (normals_points[:, 0] < image_gpu.shape[1]) & \
Expand Down Expand Up @@ -227,11 +227,11 @@ def average_1d(self, image_gpu, fitted_points, normals, extra_mem_dist):
# new_image[cp.isnan(new_image)] = 0
# return new_image.astype(image.dtype)

def average_2d(self, image_gpu, fitted_points, normals, average_1d_lst, extra_mem_dist):
def average_2d(self, image_gpu, fitted_points, normals, average_1d_lst, mem_dist):
image = image_gpu.get()
new_image = np.zeros_like(image)
count_image = np.zeros_like(image)
for membrane_dist, average_1d in zip(range(-extra_mem_dist, extra_mem_dist+1), average_1d_lst):
for membrane_dist, average_1d in zip(range(-mem_dist, mem_dist+1), average_1d_lst):
# start_time = time.time()
normals_points = fitted_points + membrane_dist * normals
mask = (normals_points[:, 0] >= 0) & (normals_points[:, 0] < image.shape[1]) & \
Expand All @@ -247,11 +247,11 @@ def average_2d(self, image_gpu, fitted_points, normals, average_1d_lst, extra_me
new_image[np.isnan(new_image)] = 0
return new_image.astype(image.dtype)

def average_2d_gpu(self, image_gpu, fitted_points, normals, average_1d_lst, extra_mem_dist):
def average_2d_gpu(self, image_gpu, fitted_points, normals, average_1d_lst, mem_dist):
new_image = cp.zeros_like(image_gpu)
count_image = cp.zeros_like(image_gpu)
fitted_points = cp.asarray(fitted_points)
membrane_dists = cp.arange(-extra_mem_dist, extra_mem_dist + 1)
membrane_dists = cp.arange(-mem_dist, mem_dist + 1)
# Expand dimensions for broadcasting
membrane_dists = membrane_dists[:, cp.newaxis, cp.newaxis]
# Calculate all normals_points at once
Expand All @@ -275,14 +275,11 @@ def average_2d_gpu(self, image_gpu, fitted_points, normals, average_1d_lst, extr
def mem_subtract(self):
control_points = self.control_points_trasf(self.control_points, self.psi, self.origin_x, self.origin_y)
fitted_curve_points, t_values = generate_curve_within_boundaries(control_points, self.image.shape, self.points_step)
# plt.imshow(self.image, cmap='gray')
# plt.plot(fitted_curve_points[:, 0], fitted_curve_points[:, 1], 'r-')
# plt.plot(control_points[:, 0], control_points[:, 1], 'g.')
# plt.show()
extra_mem_dist = 10
mem_mask = self.generate_2d_mask(self.image_gpu, fitted_curve_points, self.mem_dist)
raw_image_average_1d_lst = self.average_1d(self.image_gpu, fitted_curve_points, points_along_normal(control_points, t_values).get(), self.mem_dist)
raw_image_average_2d = self.average_2d(self.image_gpu, fitted_curve_points, points_along_normal(control_points, t_values).get(), raw_image_average_1d_lst, self.mem_dist)
raw_image_average_2d = cp.asarray(raw_image_average_2d)
raw_image_average_1d_lst = self.average_1d(self.image_gpu, fitted_curve_points, points_along_normal(control_points, t_values).get(), self.mem_dist+extra_mem_dist)
raw_image_average_2d = self.average_2d(self.image_gpu, fitted_curve_points, points_along_normal(control_points, t_values).get(), raw_image_average_1d_lst, self.mem_dist+extra_mem_dist)
raw_image_average_2d = cp.asarray(raw_image_average_2d) * mem_mask
kernel = gaussian_kernel(5, 1)
image_conv = convolve2d(self.image_gpu, kernel, mode = 'same')
raw_image_average_2d_conv = convolve2d(raw_image_average_2d, kernel, mode = 'same')
Expand Down
2 changes: 0 additions & 2 deletions build/lib/memxterminator/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
import sys
sys.path.insert(0, '/data3/Zhen/MemXTerminator/src/')
14 changes: 6 additions & 8 deletions build/lib/memxterminator/cli/bezierfit_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,11 @@ def kill_process(self):
os.remove(self.PID_FILE)
self.timer.stop()
def update_log(self):
# 读取日志文件内容
try:
with open('run.out', 'r') as f:
f.seek(self.last_read_position) # 跳转到上次读取的位置
new_content = f.read() # 读取新内容
self.last_read_position = f.tell() # 更新读取的位置
f.seek(self.last_read_position)
new_content = f.read()
self.last_read_position = f.tell()
if new_content:
self.LOG_textBrowser.append(new_content)
except FileNotFoundError:
Expand Down Expand Up @@ -198,12 +197,11 @@ def kill_process(self):
os.remove(self.PID_FILE)
self.timer.stop()
def update_log(self):
# 读取日志文件内容
try:
with open('run.out', 'r') as f:
f.seek(self.last_read_position) # 跳转到上次读取的位置
new_content = f.read() # 读取新内容
self.last_read_position = f.tell() # 更新读取的位置
f.seek(self.last_read_position)
new_content = f.read()
self.last_read_position = f.tell()
if new_content:
self.LOG_textBrowser.append(new_content)
except FileNotFoundError:
Expand Down
9 changes: 4 additions & 5 deletions build/lib/memxterminator/cli/radonfit_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,11 @@ def kill_process(self):
os.remove(self.PID_FILE)
self.timer.stop()
def update_log(self):
# 读取日志文件内容
try:
with open('run.out', 'r') as f:
f.seek(self.last_read_position) # 跳转到上次读取的位置
new_content = f.read() # 读取新内容
self.last_read_position = f.tell() # 更新读取的位置
f.seek(self.last_read_position)
new_content = f.read()
self.last_read_position = f.tell()
if new_content:
self.textBrowser_log.append(new_content)
except FileNotFoundError:
Expand Down Expand Up @@ -274,7 +273,7 @@ def __init__(self, parent=None):
self.last_read_position = 0
self.timer = QtCore.QTimer(self)
self.timer.timeout.connect(self.update_log)
self.timer.start(1000) # 每秒更新一次
self.timer.start(1000)


def browse_mem_analysis_starfile(self):
Expand Down
Binary file added dist/MemXTerminator-1.2.2-py3-none-any.whl
Binary file not shown.
Binary file added dist/MemXTerminator-1.2.2.tar.gz
Binary file not shown.
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -124,5 +124,6 @@ nav:
- Frequently Asked Questions: ./tutorials/faq.md
- Reference Pages:
- "Visualize Radonfit (.ipynb)": ./tutorials/reference/radonfit-mem-analysis-visualizer.ipynb
- "Visualize Bezierfit (.ipynb)": ./tutorials/reference/bezierfit-mem-analysis-visualizer.ipynb
- "Conventions": ./tutorials/reference/conventions.md
- "API Reference": ./tutorials/reference/api.md
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='MemXTerminator',
version='1.2.1',
version='1.2.2',
packages=find_packages(where='src'),
package_dir={'': 'src'},
author='Zhen Huang',
Expand Down
2 changes: 1 addition & 1 deletion src/MemXTerminator.egg-info/PKG-INFO
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: MemXTerminator
Version: 1.2.1
Version: 1.2.2
Summary: A software for membrane analysis and subtraction in cryo-EM
Home-page: https://github.com/ZhenHuangLab/MemXTerminator
Author: Zhen Huang
Expand Down
2 changes: 0 additions & 2 deletions src/memxterminator/GUI/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
import sys
sys.path.insert(0, '/data3/Zhen/MemXTerminator/src/')
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def retranslateUi(self, ParticleMembraneSubtraction_bezierfit):
self.controlpoints_label.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Control points file"))
self.template_browse_pushButton.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Browse..."))
self.particle.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Particle .cs file"))
self.points_step_lineEdit.setText(_translate("ParticleMembraneSubtraction_bezierfit", "0.005"))
self.points_step_lineEdit.setText(_translate("ParticleMembraneSubtraction_bezierfit", "0.001"))
self.points_step_label.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Points_step"))
self.physical_membrane_dist_label.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Physical membrane distance(Å)"))
self.physical_membrane_dist_lineEdit.setText(_translate("ParticleMembraneSubtraction_bezierfit", "35"))
Expand Down
2 changes: 1 addition & 1 deletion src/memxterminator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__author__ = 'Zhen Huang'
__email__ = 'zhen.victor.huang@gmail.com'
__version__ = '1.2.1'
__version__ = '1.2.2'
1 change: 0 additions & 1 deletion src/memxterminator/bezierfit/bin/mem_subtract_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def membrane_subtract(particle_filename):
shifts = shift_list[mask]
classes = class_list[mask]
for particle_idx, psi, pixel_size, shift, class_ in zip(particle_idxes, psis, pixel_sizes, shifts, classes):
# class_得根据control_points.json找到control_points
# if str(class_) in control_points_dict:
control_points = np.array(control_points_dict[str(class_)])
# print(control_points)
Expand Down
53 changes: 1 addition & 52 deletions src/memxterminator/bezierfit/lib/bezierfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,13 @@
import json
from scipy.ndimage import zoom


# def bezier_curve(control_points, t):
# B = np.outer((1 - t) ** 3, control_points[0]) + \
# np.outer(3 * (1 - t) ** 2 * t, control_points[1]) + \
# np.outer(3 * (1 - t) * t ** 2, control_points[2]) + \
# np.outer(t ** 3, control_points[3])
# return B.squeeze()
def bezier_curve(control_points, t):
n = len(control_points) - 1
B = np.zeros_like(control_points[0], dtype=float)
for i, point in enumerate(control_points):
B += comb(n, i) * (1 - t) ** (n - i) * t ** i * point
return B
# def bezier_curve_derivative(control_points, t):
# control_points = np.array(control_points)
# B_prime = 3 * (1 - t) ** 2 * (control_points[1] - control_points[0]) + \
# 6 * (1 - t) * t * (control_points[2] - control_points[1]) + \
# 3 * t ** 2 * (control_points[3] - control_points[2])
# return B_prime

def bezier_curve_derivative(control_points, t):
n = len(control_points) - 1
B_prime = np.zeros(2)
Expand All @@ -41,25 +29,6 @@ def bezier_curve_derivative(control_points, t):
B_prime += coef * (control_points[i+1] - control_points[i])
return B_prime

# def bezier_curvature(control_points, t):
# dB0 = -3 * (1 - t) ** 2
# dB1 = 3 * (1 - t) ** 2 - 6 * t * (1 - t)
# dB2 = 6 * t * (1 - t) - 3 * t ** 2
# dB3 = 3 * t ** 2

# ddB0 = 6 * (1 - t)
# ddB1 = 6 - 18 * t
# ddB2 = 18 * t - 6
# ddB3 = 6 * t

# p = control_points
# dx = sum([p[i, 0] * [dB0, dB1, dB2, dB3][i] for i in range(4)])
# dy = sum([p[i, 1] * [dB0, dB1, dB2, dB3][i] for i in range(4)])
# ddx = sum([p[i, 0] * [ddB0, ddB1, ddB2, ddB3][i] for i in range(4)])
# ddy = sum([p[i, 1] * [ddB0, ddB1, ddB2, ddB3][i] for i in range(4)])

# curvature = abs(dx * ddy - dy * ddx) / (dx * dx + dy * dy) ** 1.5
# return curvature
def bezier_curvature(control_points, t, threshold=1e-6, high_curvature_value=1e6):
n = len(control_points) - 1

Expand All @@ -79,7 +48,6 @@ def bezier_curvature(control_points, t, threshold=1e-6, high_curvature_value=1e6

magnitude_squared = dx * dx + dy * dy

# 规避除数接近零的问题
if magnitude_squared < threshold:
return high_curvature_value

Expand Down Expand Up @@ -331,22 +299,3 @@ def generate_curve_within_boundaries(control_points, image_shape, step):
break
fitted_curve_points = np.array([bezier_curve(control_points, t_val) for t_val in t_values])
return np.array(fitted_curve_points), np.array(t_values)

if __name__ == '__main__':
multiprocessing.set_start_method('spawn', force=True)
with mrcfile.open('/data3/kzhang/cryosparc/CS-vsv/J354/templates_selected.mrc') as f:
image = f.data[2]
image = zoom(image, 2)
coarsefit = Coarsefit(image, 600, 3, 300, 20)
initial_control_points = coarsefit()
ga_refine = GA_Refine(image, 1.068, 0.05, 50, 700, 18)
refined_control_points = ga_refine(initial_control_points, image)
refined_control_points = np.array(refined_control_points)
fitted_curve_points, t_values = generate_curve_within_boundaries(refined_control_points, image.shape, 0.01)
# save the control points in JSON format
with open('control_points.json', 'w') as f:
json.dump(refined_control_points.tolist(), f)
plt.imshow(image, cmap='gray')
plt.plot(fitted_curve_points[:, 0], fitted_curve_points[:, 1], 'r-')
plt.plot(refined_control_points[:, 0], refined_control_points[:, 1], 'g.')
plt.show()
Loading

0 comments on commit debd521

Please sign in to comment.