# Source code for sr3d_project

# Author: Carsten Sachse 08-Jun-2011
# Copyright: EMBL (2010 - 2018), Forschungszentrum Juelich (2019 - 2021)
# License: see license.txt for details
from EMAN2 import EMData, EMUtil
from collections import namedtuple
from spring.csinfrastr.csdatabase import SpringDataBase, refine_base, RefinementCycleTable
from spring.segment2d.segment import Segment
from spring.segment2d.segmentexam import SegmentExam
from spring.segment3d.refine.sr3d_prepare import SegmentRefine3dSymmetry
from spring.segment3d.segclassreconstruct import SegClassReconstruct
from sqlalchemy.sql.expression import desc
from tabulate import tabulate
from sparx import binarize, model_blank, model_circle, rot_shift2D
import numpy as np
import os
import shutil


class SegmentRefine3dPreparationStrategy(SegmentRefine3dSymmetry):
    """
    Plan a multi-stage refinement. For every requested resolution aim ('low', 'medium',
    'high', 'max') a bin factor, angular restraints, translational search ranges and an
    iteration count are assembled into a ``series_info`` record.
    """
    def make_series_info_named_tuple(self):
        """
        Return the namedtuple class ``series_info`` describing one refinement stage.

        Fields: bin_factor, resolution_aim, azimuthal_restraint, out_of_plane_restraint,
        x_range, y_range, max_range (search ranges divided by the binned pixel size),
        pixelsize (unbinned pixel size times bin factor) and iteration_count.
        """
        bin_info = namedtuple('series_info', 'bin_factor resolution_aim azimuthal_restraint out_of_plane_restraint ' +
                              'x_range y_range max_range pixelsize iteration_count')
        return bin_info

    def compute_resolution_ranges_for_binseries(self, bin_series, pixelsize):
        """
        Return the resolution associated with each bin factor: ``3.0 * pixelsize * bin_series``.

        :param bin_series: numpy array of bin factors
        :param pixelsize: unbinned pixel size
        :returns: numpy array with one resolution value per bin factor
        """
        # Three binned pixels, i.e. 1.5 x the Nyquist limit (2 x pixel size) of the binned images.
        thrice_pixs = 3.0 * pixelsize * bin_series
        return thrice_pixs

    def set_res_expectation(self):
        """
        Return the target resolutions for the aims 'low', 'medium', 'high' and 'max' (in this order).
        """
        low_res = 24.0
        medium_res = 12.0
        high_res = 7.0
        # Smaller than 3 x pixel size for any pixel size >= 1/3, so 'max' then always
        # selects the smallest bin factor in get_closest_binfactors.
        max_res = 1.0
        return [low_res, medium_res, high_res, max_res]

    def get_closest_binfactors(self, pixel_res, bin_series):
        """
        For each target resolution of :meth:`set_res_expectation` return the entry of
        ``bin_series`` whose resolution in ``pixel_res`` is closest (``np.argmin`` picks
        the first entry on ties).
        """
        res_avg = self.set_res_expectation()
        binfactors = [bin_series[np.argmin(np.abs(pixel_res - each_res))] for each_res in res_avg]
        return binfactors

    def determine_res_ranges_and_binfactors_to_be_used(self, pixelsize, resolution_aim):
        """
        Return a dictionary mapping every enabled resolution aim to its bin factor
        (bin factors 1-29 are considered).

        :param pixelsize: unbinned pixel size
        :param resolution_aim: dictionary with boolean entries for 'low', 'medium', 'high', 'max'

        >>> from spring.segment3d.refine.sr3d_main import SegmentRefine3d
        >>> s = SegmentRefine3d()
        >>> res_aim = {'low': True, 'medium': True, 'high': True, 'max': False}
        >>> s.determine_res_ranges_and_binfactors_to_be_used(1.78, res_aim)
        {'low': 4, 'medium': 2, 'high': 1}
        >>> res_aim['max']=True
        >>> s.determine_res_ranges_and_binfactors_to_be_used(0.6, res_aim)
        {'low': 13, 'medium': 7, 'high': 4, 'max': 1}
        """
        bin_series = np.arange(1, 30)
        pixel_res = self.compute_resolution_ranges_for_binseries(bin_series, pixelsize)
        closest_binfactors = self.get_closest_binfactors(pixel_res, bin_series)
        selected_binfactors = {}
        for each_aim, each_binfactor in zip(['low', 'medium', 'high', 'max'], closest_binfactors):
            if resolution_aim[each_aim]:
                selected_binfactors[each_aim] = each_binfactor
        return selected_binfactors

    def get_and_distribute_total_iteration_count(self, total_iteration_count, aims):
        """
        Split ``total_iteration_count`` evenly over the enabled aims. The remainder is
        handed out one by one starting with the lowest-resolution enabled aim. Returns an
        empty dictionary when no aim is enabled.

        >>> from spring.segment3d.refine.sr3d_main import SegmentRefine3d
        >>> s = SegmentRefine3d()
        >>> dd = {'low': True, 'medium': True, 'high': True, 'max': False}
        >>> s.get_and_distribute_total_iteration_count(14, dd)
        {'low': 5, 'medium': 5, 'high': 4}
        >>> s.get_and_distribute_total_iteration_count(7, dd)
        {'low': 3, 'medium': 2, 'high': 2}
        >>> s.get_and_distribute_total_iteration_count(5, {'low': False, 'medium': False, 'high': False, 'max': False})
        {}
        """
        selected_keys = [each_aim for each_aim in ['low', 'medium', 'high', 'max'] if aims[each_aim]]
        if not selected_keys:
            # Nothing to distribute; previously divmod(..., 0) raised ZeroDivisionError.
            return {}
        count, rest = divmod(total_iteration_count, len(selected_keys))
        selected_counts = [count + 1 if each_index < rest else count for each_index in range(len(selected_keys))]
        return dict(zip(selected_keys, selected_counts))

    def define_series_of_search_steps(self, pixelsize, refine_strategy, low_resolution, medium_resolution,
    high_resolution, max_resolution, total_iteration_count):
        """
        Assemble the list of refinement stages.

        :param pixelsize: unbinned pixel size
        :param refine_strategy: if True, raise ValueError when no resolution aim is enabled
        :param low_resolution, medium_resolution, high_resolution, max_resolution: each a tuple
            ``(enabled, (azimuthal_restraint, out_of_plane_restraint), (x_range, y_range))``
            with the search ranges in the same length unit as ``pixelsize``
        :param total_iteration_count: iterations to distribute; raised to at least one per stage
        :returns: ``(list of series_info, total_iteration_count)``. Search ranges are given in
            binned pixels. ``max_range`` is the largest x/y range over all enabled stages.
        :raises ValueError: when no aim is enabled. With ``refine_strategy`` the message is
            descriptive; otherwise ``max()`` of an empty sequence raises.

        >>> from spring.segment3d.refine.sr3d_main import SegmentRefine3d
        >>> s = SegmentRefine3d()
        >>> low_res = (False, (180.0, 180.0), (30, 20))
        >>> medium_res = (True, (180.0, 180.0), (20, 10))
        >>> high_res = (True, (20.0, 10.0), (10, 5))
        >>> max_res = (False, (2.0, 2.0), (5, 2.5))
        >>> p = 1.372, True, low_res, medium_res, high_res, max_res, 10
        >>> s.define_series_of_search_steps(p[0], p[1], p[2], p[3], p[4], p[5], p[6]) #doctest: +NORMALIZE_WHITESPACE
        ([series_info(bin_factor=3, resolution_aim='medium', azimuthal_restraint=180.0, out_of_plane_restraint=180.0,
        x_range=4.859086491739553, y_range=2.4295432458697763, max_range=4.859086491739553,
        pixelsize=4.1160000000000005, iteration_count=5), series_info(bin_factor=2, resolution_aim='high',
        azimuthal_restraint=20.0, out_of_plane_restraint=10.0, x_range=3.6443148688046643,
        y_range=1.8221574344023321, max_range=7.2886297376093285, pixelsize=2.744, iteration_count=5)], 10)
        >>> low_res = (True, (180.0, 180.0), (30, 20))
        >>> medium_res = (True, (180.0, 180.0), (20, 10))
        >>> high_res = (True, (20.0, 10.0), (10, 5))
        >>> max_res = (True, (2.0, 2.0), (5, 2.5))
        >>> p = 1.2, True, low_res, medium_res, high_res, max_res, 18
        >>> s.define_series_of_search_steps(p[0], p[1], p[2], p[3], p[4], p[5], p[6]) #doctest: +NORMALIZE_WHITESPACE
        ([series_info(bin_factor=7, resolution_aim='low', azimuthal_restraint=180.0, out_of_plane_restraint=180.0,
        x_range=3.571428571428571, y_range=2.380952380952381, max_range=3.571428571428571, pixelsize=8.4,
        iteration_count=5), series_info(bin_factor=3, resolution_aim='medium', azimuthal_restraint=180.0,
        out_of_plane_restraint=180.0, x_range=5.555555555555556, y_range=2.777777777777778,
        max_range=8.333333333333334, pixelsize=3.5999999999999996, iteration_count=5),
        series_info(bin_factor=2, resolution_aim='high', azimuthal_restraint=20.0, out_of_plane_restraint=10.0,
        x_range=4.166666666666667, y_range=2.0833333333333335, max_range=12.5, pixelsize=2.4, iteration_count=4),
        series_info(bin_factor=1, resolution_aim='max', azimuthal_restraint=2.0, out_of_plane_restraint=2.0,
        x_range=4.166666666666667, y_range=2.0833333333333335, max_range=25.0, pixelsize=1.2,
        iteration_count=4)], 18)
        >>> p = 1.2, True, low_res, medium_res, high_res, max_res, 18
        >>> ps = s.define_series_of_search_steps(p[0], p[1], p[2], p[3], p[4], p[5], p[6])
        >>> [(each.bin_factor, each.pixelsize) for each in ps[0]]
        [(7, 8.4), (3, 3.5999999999999996), (2, 2.4), (1, 1.2)]
        >>> p = 0.6, True, low_res, medium_res, high_res, max_res, 18
        >>> ps = s.define_series_of_search_steps(p[0], p[1], p[2], p[3], p[4], p[5], p[6])
        >>> [(each.bin_factor, each.pixelsize) for each in ps[0]]
        [(13, 7.8), (7, 4.2), (4, 2.4), (1, 0.6)]
        >>> low_res = (True, (180.0, 180.0), (30, 20))
        >>> medium_res = (True,(180.0, 180.0), (20, 10))
        >>> high_res = (False, (20.0, 10.0), (10, 5))
        >>> max_res = (False, (2.0, 2.0), (5, 2.5))
        >>> p = 1.2, True, low_res, medium_res, high_res, max_res, 13
        >>> s.define_series_of_search_steps(p[0], p[1], p[2], p[3], p[4], p[5], p[6]) #doctest: +NORMALIZE_WHITESPACE
        ([series_info(bin_factor=7, resolution_aim='low', azimuthal_restraint=180.0, out_of_plane_restraint=180.0,
        x_range=3.571428571428571, y_range=2.380952380952381, max_range=3.571428571428571, pixelsize=8.4,
        iteration_count=7), series_info(bin_factor=3, resolution_aim='medium', azimuthal_restraint=180.0,
        out_of_plane_restraint=180.0, x_range=5.555555555555556, y_range=2.777777777777778,
        max_range=8.333333333333334, pixelsize=3.5999999999999996, iteration_count=6)], 13)
        >>> low_res = (False, (180.0, 180.0), (30, 20))
        >>> medium_res = (False, (180.0, 180.0), (20, 10))
        >>> high_res = (False, (20.0, 10.0), (10, 5))
        >>> max_res = (True, (2.0, 2.0), (5, 2.5))
        >>> p = 1.2, True, low_res, medium_res, high_res, max_res, 7
        >>> s.define_series_of_search_steps(p[0], p[1], p[2], p[3], p[4], p[5], p[6]) #doctest: +NORMALIZE_WHITESPACE
        ([series_info(bin_factor=1, resolution_aim='max', azimuthal_restraint=2.0, out_of_plane_restraint=2.0,
        x_range=4.166666666666667, y_range=2.0833333333333335, max_range=4.166666666666667, pixelsize=1.2,
        iteration_count=7)], 7)
        >>> low_res = (False, (180.0, 180.0), (30, 20))
        >>> medium_res = (True, (180.0, 180.0), (20, 10))
        >>> high_res = (True, (10.0, 5.0), (10.0, 5.0))
        >>> max_res = (False, (2.0, 2.0), (5, 2.5))
        >>> p = 1.2, True, low_res, medium_res, high_res, max_res, 7
        >>> s.define_series_of_search_steps(p[0], p[1], p[2], p[3], p[4], p[5], p[6]) #doctest: +NORMALIZE_WHITESPACE
        ([series_info(bin_factor=3, resolution_aim='medium', azimuthal_restraint=180.0, out_of_plane_restraint=180.0,
        x_range=5.555555555555556, y_range=2.777777777777778, max_range=5.555555555555556,
        pixelsize=3.5999999999999996, iteration_count=4), series_info(bin_factor=2, resolution_aim='high',
        azimuthal_restraint=10.0, out_of_plane_restraint=5.0, x_range=4.166666666666667, y_range=2.0833333333333335,
        max_range=8.333333333333334, pixelsize=2.4, iteration_count=3)], 7)
        >>> low_res = (True, (180.0, 180.0), (30, 20))
        >>> medium_res = (True, (180.0, 180.0), (20, 10))
        >>> high_res = (True, (20.0, 10.0), (10, 5))
        >>> max_res = (True, (2.0, 2.0), (5, 2.5))
        >>> p = 1.2, True, low_res, medium_res, high_res, max_res, 3
        >>> s.define_series_of_search_steps(p[0], p[1], p[2], p[3], p[4], p[5], p[6]) #doctest: +NORMALIZE_WHITESPACE
        ([series_info(bin_factor=7, resolution_aim='low', azimuthal_restraint=180.0, out_of_plane_restraint=180.0,
        x_range=3.571428571428571, y_range=2.380952380952381, max_range=3.571428571428571, pixelsize=8.4,
        iteration_count=1), series_info(bin_factor=3, resolution_aim='medium', azimuthal_restraint=180.0,
        out_of_plane_restraint=180.0, x_range=5.555555555555556, y_range=2.777777777777778,
        max_range=8.333333333333334, pixelsize=3.5999999999999996, iteration_count=1),
        series_info(bin_factor=2, resolution_aim='high', azimuthal_restraint=20.0, out_of_plane_restraint=10.0,
        x_range=4.166666666666667, y_range=2.0833333333333335, max_range=12.5, pixelsize=2.4, iteration_count=1),
        series_info(bin_factor=1, resolution_aim='max', azimuthal_restraint=2.0, out_of_plane_restraint=2.0,
        x_range=4.166666666666667, y_range=2.0833333333333335, max_range=25.0, pixelsize=1.2,
        iteration_count=1)], 4)
        >>> low_res = (False, (180.0, 180.0), (30, 20))
        >>> medium_res = (False, (180.0, 180.0), (20, 10))
        >>> high_res = (False, (10.0, 5.0), (10.0, 5.0))
        >>> max_res = (False, (2.0, 2.0), (5, 2.5))
        >>> p = 3.5, True, low_res, medium_res, high_res, max_res, 3
        >>> s.define_series_of_search_steps(p[0], p[1], p[2], p[3], p[4], p[5], p[6])
        Traceback (most recent call last):
        ...
        ValueError: You have specified to assemble a refinement strategy without any resolution aim. Please, specify at least one resolution range you are targeting.
        >>> low_res = (True, (180.0, 180.0), (30, 20))
        >>> medium_res = (True, (180.0, 180.0), (20, 10))
        >>> high_res = (True, (180.0, 20.0), (10.0, 5.0))
        >>> max_res = (False, (2.0, 2.0), (5, 2.5))
        >>> p = 1.78, True, low_res, medium_res, high_res, max_res, 15
        >>> s.define_series_of_search_steps(p[0], p[1], p[2], p[3], p[4], p[5], p[6]) #doctest: +NORMALIZE_WHITESPACE
        ([series_info(bin_factor=4, resolution_aim='low', azimuthal_restraint=180.0, out_of_plane_restraint=180.0,
        x_range=4.213483146067416, y_range=2.8089887640449436, max_range=4.213483146067416, pixelsize=7.12,
        iteration_count=5), series_info(bin_factor=2, resolution_aim='medium', azimuthal_restraint=180.0,
        out_of_plane_restraint=180.0, x_range=5.617977528089887, y_range=2.8089887640449436,
        max_range=8.426966292134832, pixelsize=3.56, iteration_count=5), series_info(bin_factor=1,
        resolution_aim='high', azimuthal_restraint=180.0, out_of_plane_restraint=20.0, x_range=5.617977528089887,
        y_range=2.8089887640449436, max_range=16.853932584269664, pixelsize=1.78, iteration_count=5)], 15)
        """
        aim_names = ['low', 'medium', 'high', 'max']
        settings = [low_resolution, medium_resolution, high_resolution, max_resolution]

        aims = {}
        azimuthal_series = {}
        out_of_plane_series = {}
        x_ranges_A = {}
        y_ranges_A = {}
        for each_aim, (each_flag, each_ang_range, (each_x_range, each_y_range)) in zip(aim_names, settings):
            aims[each_aim] = each_flag
            azimuthal_series[each_aim] = each_ang_range[0]
            out_of_plane_series[each_aim] = each_ang_range[1]
            x_ranges_A[each_aim] = each_x_range
            y_ranges_A[each_aim] = each_y_range

        selected_binfactors = self.determine_res_ranges_and_binfactors_to_be_used(pixelsize, aims)
        # Guarantee at least one iteration per enabled stage.
        total_iteration_count = max(total_iteration_count, len(selected_binfactors))
        if len(selected_binfactors) == 0 and refine_strategy:
            error_msg = 'You have specified to assemble a refinement strategy without any resolution aim. Please, ' +\
                'specify at least one resolution range you are targeting.'
            raise ValueError(error_msg)

        # One common maximum search range (unbinned units) shared by all stages.
        max_range_A = max([max(x_ranges_A[each_aim], y_ranges_A[each_aim]) for each_aim in selected_binfactors])
        selected_counts = self.get_and_distribute_total_iteration_count(total_iteration_count, aims)

        bin_info = self.make_series_info_named_tuple()
        info_series = []
        for each_aim in aim_names:
            if not aims[each_aim]:
                continue
            each_pixelsize = selected_binfactors[each_aim] * pixelsize
            # Ranges are converted from unbinned units into binned pixels of this stage.
            info_series.append(bin_info(selected_binfactors[each_aim], each_aim, azimuthal_series[each_aim],
                                        out_of_plane_series[each_aim], x_ranges_A[each_aim] / each_pixelsize,
                                        y_ranges_A[each_aim] / each_pixelsize, max_range_A / each_pixelsize,
                                        each_pixelsize, selected_counts[each_aim]))

        return info_series, total_iteration_count
class SegmentRefine3dProjection(SegmentRefine3dPreparationStrategy):
    """
    Generate Euler angles for helical reference projections and project reference
    volumes into image stacks (optionally a finer sampled stack for high resolution aims).
    """
    def generate_thetas_evenly_dependent_on_cos_of_out_of_plane_angle(self, thetas):
        """
        Respace out-of-plane angles so that they are evenly distributed in cosine space.

        Angles above 90 degrees and angles at or below 90 degrees are respaced separately.
        The extreme values and the number of angles are preserved (asserted). At least one
        value > 90 and one value <= 90 are required, since ``np.max``/``np.min`` of an
        empty selection raises.

        >>> from spring.segment3d.refine.sr3d_main import SegmentRefine3d
        >>> s = SegmentRefine3d()
        >>> s.generate_thetas_evenly_dependent_on_cos_of_out_of_plane_angle(np.arange(78, 106, 4, dtype=float)) #doctest: +NORMALIZE_WHITESPACE
        array([ 78. , 80.20802669, 83.08024783, 90. , 96.91975217, 99.79197331, 102. ])
        >>> s.generate_thetas_evenly_dependent_on_cos_of_out_of_plane_angle(np.arange(78, 106, 2, dtype=float)) #doctest: +NORMALIZE_WHITESPACE
        array([ 78. , 79.0488977 , 80.20802669, 81.52248781, 83.08024783, 85.10848412, 90. , 95.28021705,
        97.47000184, 99.15209489, 100.5716871 , 101.82371539, 102.95685186, 104. ])
        """
        # Tilts above 90: one extra sample because 0 (= 90 degrees) is shared with the lower half.
        thetas_gt_90 = thetas[thetas > 90.0] - 90.0
        cos_even_space = np.linspace(np.cos(np.deg2rad(np.max(thetas_gt_90))), 1, len(thetas_gt_90) + 1)
        even_thetas_gt_90 = np.rad2deg(np.arccos(cos_even_space))

        thetas_le_90 = thetas[thetas <= 90.0] - 90.0
        cos_even_space = np.linspace(np.cos(np.deg2rad(np.min(thetas_le_90))), 1, len(thetas_le_90))
        even_thetas_le_90 = np.rad2deg(np.arccos(cos_even_space))

        # np.unique removes the duplicated 0 shared by both halves.
        unique_thetas = np.sort(np.unique(np.append(even_thetas_gt_90, -even_thetas_le_90))) + 90.0
        assert len(thetas) == len(unique_thetas)

        return unique_thetas

    def generate_phis_evenly_across_asymmetric_unit_and_distribute_over_360(self, helical_rotation,
    azimuth_view_count):
        """
        * Function contributed by Ambroise Desfosses (May 2012)

        Return ``azimuth_view_count`` sorted azimuthal angles. View ``i`` combines the
        sub-step ``i * |rotation| / count`` within one asymmetric unit with the
        ``(i mod m)``-th multiple of ``|rotation|``, where ``m = round(360 / |rotation|)``.
        The views therefore sample the asymmetric unit evenly while spreading over 360
        degrees. The sign of ``helical_rotation`` is ignored.

        >>> from spring.segment3d.refine.sr3d_main import SegmentRefine3d
        >>> s = SegmentRefine3d()
        >>> s.generate_phis_evenly_across_asymmetric_unit_and_distribute_over_360(22.04, 6) #doctest: +NORMALIZE_WHITESPACE
        array([ 0. , 25.71333333, 51.42666667, 77.14 , 102.85333333, 128.56666667])
        >>> s.generate_phis_evenly_across_asymmetric_unit_and_distribute_over_360(-22.04, 6) #doctest: +NORMALIZE_WHITESPACE
        array([ 0. , 25.71333333, 51.42666667, 77.14 , 102.85333333, 128.56666667])
        """
        abs_rotation = abs(helical_rotation)
        views_asym_unit = np.arange(azimuth_view_count) * abs_rotation / azimuth_view_count
        multiples_rot = abs_rotation * np.arange(int(360.0 / abs_rotation + 0.5))
        # Repeating the list makes index i equivalent to multiples_rot[i % len(multiples_rot)].
        all_multiples_rot = multiples_rot.tolist() * views_asym_unit.size
        phis = sorted(each_view + all_multiples_rot[each_index] for each_index, each_view in enumerate(views_asym_unit))

        return np.array(phis)

    def generate_Euler_angles_for_projection(self, azimuthal_count, out_of_plane_range, out_of_plane_count,
    helical_rotation):
        """
        Return a list of ``[phi, theta, psi, x_shift, y_shift]`` projection parameters.

        ``theta = 90 + linspace(min(range), max(range), out_of_plane_count)``; a single
        theta (90 + mean of the range) is used when the count is 0 or 1 or when both range
        limits are equal. Phis come from
        :meth:`generate_phis_evenly_across_asymmetric_unit_and_distribute_over_360`, or are
        evenly spaced over 360 degrees when ``helical_rotation`` is 0. psi is fixed at 270
        and shifts are 0. Theta varies in the outer loop, phi in the inner loop.

        >>> from spring.segment3d.refine.sr3d_main import SegmentRefine3d
        >>> s = SegmentRefine3d()
        >>> s.generate_Euler_angles_for_projection(4, [0, 0], 0, 22.04) #doctest: +NORMALIZE_WHITESPACE
        [[0.0, 90.0, 270.0, 0.0, 0.0], [27.549999999999997, 90.0, 270.0, 0.0, 0.0],
        [55.099999999999994, 90.0, 270.0, 0.0, 0.0], [82.65, 90.0, 270.0, 0.0, 0.0]]
        >>> s.generate_Euler_angles_for_projection(4, [-8, 8], 3, 22.04) #doctest: +NORMALIZE_WHITESPACE
        [[0.0, 82.0, 270.0, 0.0, 0.0], [27.549999999999997, 82.0, 270.0, 0.0, 0.0],
        [55.099999999999994, 82.0, 270.0, 0.0, 0.0], [82.65, 82.0, 270.0, 0.0, 0.0],
        [0.0, 90.0, 270.0, 0.0, 0.0], [27.549999999999997, 90.0, 270.0, 0.0, 0.0],
        [55.099999999999994, 90.0, 270.0, 0.0, 0.0], [82.65, 90.0, 270.0, 0.0, 0.0],
        [0.0, 98.0, 270.0, 0.0, 0.0], [27.549999999999997, 98.0, 270.0, 0.0, 0.0],
        [55.099999999999994, 98.0, 270.0, 0.0, 0.0], [82.65, 98.0, 270.0, 0.0, 0.0]]
        >>> s.generate_Euler_angles_for_projection(2, [-8, 8], 3, 22.04) #doctest: +NORMALIZE_WHITESPACE
        [[0.0, 82.0, 270.0, 0.0, 0.0], [33.06, 82.0, 270.0, 0.0, 0.0], [0.0, 90.0, 270.0, 0.0, 0.0],
        [33.06, 90.0, 270.0, 0.0, 0.0], [0.0, 98.0, 270.0, 0.0, 0.0], [33.06, 98.0, 270.0, 0.0, 0.0]]
        >>> s.generate_Euler_angles_for_projection(4, [-8, -8], 0, 22.04) #doctest: +NORMALIZE_WHITESPACE
        [[0.0, 82.0, 270.0, 0.0, 0.0], [27.549999999999997, 82.0, 270.0, 0.0, 0.0],
        [55.099999999999994, 82.0, 270.0, 0.0, 0.0], [82.65, 82.0, 270.0, 0.0, 0.0]]
        >>> s.generate_Euler_angles_for_projection(2, [-8, -8], 3, 0) #doctest: +NORMALIZE_WHITESPACE
        [[0.0, 82.0, 270.0, 0.0, 0.0], [180.0, 82.0, 270.0, 0.0, 0.0]]
        >>> s.generate_Euler_angles_for_projection(4, [0, 0], 0, 0) #doctest: +NORMALIZE_WHITESPACE
        [[0.0, 90.0, 270.0, 0.0, 0.0], [90.0, 90.0, 270.0, 0.0, 0.0], [180.0, 90.0, 270.0, 0.0, 0.0],
        [270.0, 90.0, 270.0, 0.0, 0.0]]
        >>> s.generate_Euler_angles_for_projection(4, [-12.0, 12.0], 1, 58) #doctest: +NORMALIZE_WHITESPACE
        [[0.0, 90.0, 270.0, 0.0, 0.0], [72.5, 90.0, 270.0, 0.0, 0.0], [145.0, 90.0, 270.0, 0.0, 0.0],
        [217.5, 90.0, 270.0, 0.0, 0.0]]
        """
        # A zero-width range (any equal limits, not only 0/0) yields one tilt; previously a
        # count > 1 over e.g. [-8, -8] produced duplicated identical theta rows.
        if out_of_plane_range[0] == out_of_plane_range[1] or out_of_plane_count == 0:
            out_of_plane_count = 1
        if out_of_plane_count == 1:
            out_of_plane_range = 2 * [float(np.mean(out_of_plane_range))]

        theta_angles = 90.0 + np.linspace(min(out_of_plane_range), max(out_of_plane_range), out_of_plane_count)
        # Cosine-weighted spacing (generate_thetas_evenly_dependent_on_cos_of_out_of_plane_angle)
        # is available but intentionally not applied here.

        if helical_rotation != 0:
            phi_angles = self.generate_phis_evenly_across_asymmetric_unit_and_distribute_over_360(helical_rotation,
                                                                                                  azimuthal_count)
        else:
            phi_angles = np.linspace(0.0, 360.0, azimuthal_count, endpoint=False)

        shifts = 0.0
        psi = 270.0
        parameter_list = [[float(each_phi), float(each_theta), psi, shifts, shifts]
                          for each_theta in theta_angles for each_phi in phi_angles]

        return parameter_list

    def collect_prj_params_and_update_reference_info(self, updated_ref_files, each_reference, projection_stack,
    projection_parameters, fine_projection_stack, fine_projection_parameters, merged_prj_params,
    merged_fine_prj_params):
        """
        Prefix each projection parameter row with the reference's ``model_id`` and extend
        ``merged_prj_params`` (in place). Do the same for the fine parameters. When
        ``fine_projection_parameters`` is None, ``merged_fine_prj_params`` is returned as None.
        The reference record gets its ``prj_stack``/``fine_prj_stack`` replaced and is
        appended to ``updated_ref_files`` (in place).

        :returns: ``(updated_ref_files, merged_prj_params, merged_fine_prj_params)``
        """
        projection_parameters = [[each_reference.model_id] + each_prj for each_prj in projection_parameters]
        merged_prj_params += projection_parameters
        if fine_projection_parameters is not None:
            fine_projection_parameters = [[each_reference.model_id] + each_prj
                                          for each_prj in fine_projection_parameters]
            merged_fine_prj_params += fine_projection_parameters
        else:
            merged_fine_prj_params = fine_projection_parameters

        each_reference = each_reference._replace(prj_stack=projection_stack)
        each_reference = each_reference._replace(fine_prj_stack=fine_projection_stack)
        updated_ref_files.append(each_reference)

        return updated_ref_files, merged_prj_params, merged_fine_prj_params

    def merge_prj_ref_stacks_into_single_prj_stack(self, updated_ref_files, prj_attr):
        """
        Append the images of every model > 0 stack (attribute ``prj_attr``) to the stack of
        model 0 and return the model-0 stack name.

        Only stacks located in ``self.tempdir`` are merged, and each is deleted afterwards.
        The model-0 reference may appear anywhere in ``updated_ref_files``.

        :raises ValueError: if no reference with ``model_id == 0`` is present
        """
        model_zero_stacks = [getattr(each_reference, prj_attr) for each_reference in updated_ref_files
                             if each_reference.model_id == 0]
        if not model_zero_stacks:
            raise ValueError('No reference with model_id 0 found to merge projection stacks into.')
        merged_prj_stack = model_zero_stacks[0]

        prj = EMData()
        for each_reference in updated_ref_files:
            if each_reference.model_id <= 0:
                continue
            each_stack = getattr(each_reference, prj_attr)
            if each_stack is None or not each_stack.startswith(self.tempdir):
                continue
            prj_img_count = EMUtil.get_image_count(each_stack)
            for each_img in range(prj_img_count):
                prj.read_image(each_stack, each_img)
                prj.append_image(merged_prj_stack)
            os.remove(each_stack)

        return merged_prj_stack

    def write_out_reference_and_get_prj_prefix_depending_on_number_of_models(self, reference_files, ref_cycle_id,
    each_iteration_number, each_reference, reference_volume):
        """
        Write the reference volume via ``self.write_out_reference_volume``. Return the
        updated reference together with the projection file prefix: a model-specific
        prefix when more than one reference is refined, ``'projection_stack'`` otherwise.
        """
        if len(reference_files) > 1:
            prj_prefix = 'projection_stack_mod{0:03}_'.format(each_reference.model_id)
            each_reference = self.write_out_reference_volume(each_reference, each_iteration_number, ref_cycle_id,
                                                             reference_volume, each_reference.model_id)
        else:
            prj_prefix = 'projection_stack'
            each_reference = self.write_out_reference_volume(each_reference, each_iteration_number, ref_cycle_id,
                                                             reference_volume)

        return each_reference, prj_prefix

    def copy_image_stack_to_new_stack(self, projection_stack, local_projection_stack):
        """Copy every image of ``projection_stack`` into ``local_projection_stack`` (same indices) via EMAN2."""
        img_count = EMUtil.get_image_count(projection_stack)
        img = EMData()
        for each_img in range(img_count):
            img.read_image(projection_stack, each_img)
            img.write_image(local_projection_stack, each_img)

    def copy_image_stack_to_new_stack_shutil(self, projection_stack, local_projection_stack):
        """Copy ``projection_stack`` to ``local_projection_stack`` as a plain file copy."""
        shutil.copy(projection_stack, local_projection_stack)

    def generate_projection_stack(self, resolution_aim, cycle_number, reference_volume, pixelinfo,
    azimuthal_angle_count, out_of_plane_tilt_angle_count, projection_stack, helical_symmetry, rotational_sym):
        """
        Project ``reference_volume`` along the generated Euler angles (non-MPI path).
        Optionally apply the layer-line filter, then copy the stack into ``self.tempdir``.

        ``helical_symmetry[1]`` is used as the helical rotation. ``cycle_number`` is not used here.

        :returns: ``(path of the stack in self.tempdir, projection parameters)``
        """
        projection_parameters = self.generate_Euler_angles_for_projection(azimuthal_angle_count,
            self.out_of_plane_tilt_angle_range, out_of_plane_tilt_angle_count, helical_symmetry[1])

        prj_ids = list(range(len(projection_parameters)))
        projection_stack = SegClassReconstruct().project_through_reference_using_parameters_and_log(
            projection_parameters, pixelinfo.alignment_size, prj_ids, projection_stack, reference_volume)

        self.filter_layer_lines_if_demanded(resolution_aim, projection_parameters, prj_ids, projection_stack,
                                            pixelinfo, helical_symmetry, rotational_sym)

        local_projection_stack = os.path.join(self.tempdir, projection_stack)
        self.copy_image_stack_to_new_stack(projection_stack, local_projection_stack)
        self.remove_intermediate_files_if_desired(projection_stack)

        return local_projection_stack, projection_parameters

    def project_through_reference_volume_in_helical_perspectives(self, resolution_aim, cycle_number,
    reference_volume_file, pixelinfo, helical_symmetry, rotational_sym, prj_prefix='projection_stack'):
        """
        Project the reference volume file into a projection stack. For the 'high' and 'max'
        aims, also produce a fine stack with 5x the azimuthal and out-of-plane counts.

        ``generate_projection_stack_mpi`` is used when the instance has a ``comm``
        attribute, otherwise ``generate_projection_stack``.

        :returns: ``(projection_stack, projection_parameters, fine_projection_stack,
            fine_projection_parameters)``; the fine entries are None for 'low'/'medium'
        """
        self.log.fcttolog()
        self.log.in_progress_log()

        reference_volume = EMData()
        reference_volume.read_image(reference_volume_file)

        if hasattr(self, 'comm'):
            generate_stack = self.generate_projection_stack_mpi
        else:
            generate_stack = self.generate_projection_stack

        projection_stack = '{0}{1:03}.hdf'.format(prj_prefix, cycle_number)
        projection_stack, projection_parameters = generate_stack(resolution_aim, cycle_number, reference_volume,
            pixelinfo, self.azimuthal_angle_count, self.out_of_plane_tilt_angle_count, projection_stack,
            helical_symmetry, rotational_sym)

        if resolution_aim in ['high', 'max']:
            fine_projection_stack = '{0}_fine{1:03}.hdf'.format(prj_prefix, cycle_number)
            fine_projection_stack, fine_projection_parameters = generate_stack(resolution_aim, cycle_number,
                reference_volume, pixelinfo, 5 * self.azimuthal_angle_count, 5 * self.out_of_plane_tilt_angle_count,
                fine_projection_stack, helical_symmetry, rotational_sym)
        else:
            fine_projection_stack = None
            fine_projection_parameters = None

        return projection_stack, projection_parameters, fine_projection_stack, fine_projection_parameters
class SegmentRefine3dProjectionLayerLineFilter(SegmentRefine3dProjection):
    """
    Layer-line based Fourier filtering of reference projections, with optional blurring
    of the ideal power spectra over out-of-plane and in-plane angles.
    """
    def convert_reciprocal_Angstrom_to_Fourier_pixel_position_in_power_spectrum(self, pixelsize, powersize,
    reciprocal_Angstrom):
        """
        Return ``round(reciprocal_Angstrom * powersize * pixelsize)`` as an integer Fourier pixel.

        >>> from spring.segment3d.refine.sr3d_main import SegmentRefine3d
        >>> s = SegmentRefine3d()
        >>> s.convert_reciprocal_Angstrom_to_Fourier_pixel_position_in_power_spectrum(2.0, 200, 0.25)
        100
        >>> s.convert_reciprocal_Angstrom_to_Fourier_pixel_position_in_power_spectrum(2.0, 200, 0.05)
        20
        """
        Fourier_pixel = int(round(reciprocal_Angstrom * powersize * pixelsize))
        return Fourier_pixel

    def convert_Fouier_pixel_to_reciprocal_Angstrom(self, pixelsize, powersize, Fourier_pixel):
        """
        Inverse of the conversion above: ``Fourier_pixel / (powersize * pixelsize)``.

        >>> from spring.segment3d.refine.sr3d_main import SegmentRefine3d
        >>> s = SegmentRefine3d()
        >>> s.convert_Fouier_pixel_to_reciprocal_Angstrom(2.0, 200, 100)
        0.25
        >>> s.convert_Fouier_pixel_to_reciprocal_Angstrom(2.0, 200, 20)
        0.05
        """
        reciprocal_Angstrom = Fourier_pixel / float(powersize * pixelsize)
        return reciprocal_Angstrom

    def determine_Fourier_pixel_position_of_highest_resolution_layer_line(self, helical_symmetry, rotational_sym,
    helixwidth, pixelsize, powersize, tilt):
        """
        Return the Fourier pixel of the largest layer-line value among the
        (layer line, Bessel order) pairs produced by SegClassReconstruct for the given tilt.
        The values 300.0 and ``2 * pixelsize`` are presumably the low/high resolution
        limits of that helper (TODO confirm).
        """
        min_layerline_bessel_pairs = SegClassReconstruct().generate_layerline_bessel_pairs_from_rise_and_rotation(
            helical_symmetry, rotational_sym, helixwidth, pixelsize, 300.0, 2 * pixelsize, tilt)

        min_ll, min_bessel = zip(*min_layerline_bessel_pairs)
        max_pixel = self.convert_reciprocal_Angstrom_to_Fourier_pixel_position_in_power_spectrum(pixelsize,
            powersize, np.max(min_ll))

        return max_pixel

    def compute_corresponding_tilts_from_Fourier_pixel_series(self, no_tilt_pixel, tilted_series):
        """
        Return ``arccos(no_tilt_pixel / tilted_series)`` in degrees, i.e. the tilt at which
        the untilted layer-line position moves to each pixel of ``tilted_series``.
        Pixels below ``no_tilt_pixel`` give NaN.

        >>> from spring.segment3d.refine.sr3d_main import SegmentRefine3d
        >>> s = SegmentRefine3d()
        >>> s.compute_corresponding_tilts_from_Fourier_pixel_series(200, 201)
        5.717679685251981
        """
        tilts = np.rad2deg(np.arccos(no_tilt_pixel / np.array(tilted_series, dtype=float)))
        return tilts

    def blur_ideal_power_spectrum_by_even_out_of_plane_variation(self, helical_symmetry, rotational_sym, helixwidth,
    pixelsize, powersize, each_unique_tilt, out_of_plane_blur, no_tilt_pixel):
        """
        Average binary ideal power spectra over the tilts ``each_unique_tilt +/- out_of_plane_blur``.
        The tilts are sampled so that the highest-resolution layer line moves by one Fourier
        pixel per step.
        """
        min_tilt = max(abs(each_unique_tilt) - out_of_plane_blur, 0)
        min_pixel = self.determine_Fourier_pixel_position_of_highest_resolution_layer_line(helical_symmetry,
            rotational_sym, helixwidth, pixelsize, powersize, min_tilt)

        max_tilt = abs(each_unique_tilt + out_of_plane_blur)
        max_pixel = self.determine_Fourier_pixel_position_of_highest_resolution_layer_line(helical_symmetry,
            rotational_sym, helixwidth, pixelsize, powersize, max_tilt)

        pixel_series = np.arange(min(min_pixel, max_pixel), max(min_pixel, max_pixel) + 1)
        tilts = self.compute_corresponding_tilts_from_Fourier_pixel_series(no_tilt_pixel, pixel_series)

        ideal_power_img = model_blank(powersize, powersize)
        for each_tilt in tilts:
            layerline_bessel_pairs = SegClassReconstruct().generate_layerline_bessel_pairs_from_rise_and_rotation(
                helical_symmetry, rotational_sym, helixwidth, pixelsize, 300.0, 2 * pixelsize, each_tilt)

            tilt_power_img, linex_fine = SegClassReconstruct().prepare_ideal_power_spectrum_from_layer_lines(
                layerline_bessel_pairs, helixwidth, powersize, pixelsize, binary=True)
            ideal_power_img += tilt_power_img

        ideal_power_img /= len(tilts)

        return ideal_power_img

    def generate_binary_layer_line_filter_for_different_out_of_plane_tilt_angles(self, tilts, helical_symmetry,
    rotational_sym, helixwidth, pixelsize, powersize, binary=True, out_of_plane_blur=None):
        """
        Return one ideal power spectrum per tilt in ``tilts``.

        With ``out_of_plane_blur`` the spectra are blurred over the tilt range (see
        :meth:`blur_ideal_power_spectrum_by_even_out_of_plane_variation`) and binarized if
        ``binary`` is True. Without blur, the spectra are always requested as binary from
        SegClassReconstruct; ``binary`` has no effect in that branch.
        """
        ideal_power_imgs = []
        if out_of_plane_blur is not None:
            no_tilt_pixel = self.determine_Fourier_pixel_position_of_highest_resolution_layer_line(helical_symmetry,
                rotational_sym, helixwidth, pixelsize, powersize, 0.0)

        for each_unique_tilt in tilts:
            if out_of_plane_blur is not None:
                ideal_power_img = self.blur_ideal_power_spectrum_by_even_out_of_plane_variation(helical_symmetry,
                    rotational_sym, helixwidth, pixelsize, powersize, each_unique_tilt, out_of_plane_blur,
                    no_tilt_pixel)
                if binary is True:
                    ideal_power_img = binarize(ideal_power_img, 1e-14)
            else:
                layerline_bessel_pairs = SegClassReconstruct().generate_layerline_bessel_pairs_from_rise_and_rotation(
                    helical_symmetry, rotational_sym, helixwidth, pixelsize, 300.0, 2 * pixelsize, each_unique_tilt)

                ideal_power_img, linex_fine = SegClassReconstruct().prepare_ideal_power_spectrum_from_layer_lines(
                    layerline_bessel_pairs, helixwidth, powersize, pixelsize, binary=True)

            ideal_power_imgs.append(ideal_power_img)

        return ideal_power_imgs

    def get_maximum_pixel_displacement_by_inplane_rotation_at_edge(self, inplane_blur, powersize, pixel_pos):
        """
        Rotate the point ``(0, pixel_pos)`` by ``-inplane_blur`` and draw a line of slope
        ``tan(inplane_blur)`` through it. Return ``int(pixel_pos - y)``, where ``y`` is the
        height of that line at ``x = -powersize / 2`` (the spectrum edge).

        >>> from spring.segment3d.refine.sr3d_main import SegmentRefine3d
        >>> s = SegmentRefine3d()
        >>> s.get_maximum_pixel_displacement_by_inplane_rotation_at_edge(5, 200, 80)
        8
        >>> s.get_maximum_pixel_displacement_by_inplane_rotation_at_edge(-5, 200, 80)
        -9
        >>> s.get_maximum_pixel_displacement_by_inplane_rotation_at_edge(-0.3, 200, 80)
        0
        """
        rot_x, rot_y = Segment().rotate_coordinates_by_angle(0.0, pixel_pos, -inplane_blur)
        slope = np.tan(np.deg2rad(inplane_blur))
        intercept = rot_y - slope * rot_x
        y_height = slope * (- powersize / 2.0) + intercept

        y_difference = int(pixel_pos - y_height)

        return y_difference

    def generate_power_spectrum_average_by_inplane_angle_blurring(self, power_spectrum, inplane_blur,
    helical_symmetry, rotational_sym, helixwidth, pixelsize, binary=True):
        """
        Average rotated copies of ``power_spectrum``. Each rotated copy is binarized at 0.1
        when ``binary`` is True.
        """
        powersize = power_spectrum.get_ysize()
        pixel_pos = self.determine_Fourier_pixel_position_of_highest_resolution_layer_line(helical_symmetry,
            rotational_sym, helixwidth, pixelsize, powersize, 0.0)

        min_pix = self.get_maximum_pixel_displacement_by_inplane_rotation_at_edge(-inplane_blur, powersize, pixel_pos)
        max_pix = self.get_maximum_pixel_displacement_by_inplane_rotation_at_edge(inplane_blur, powersize, pixel_pos)

        # NOTE(review): this integer range spans the pixel displacements at the spectrum edge,
        # but each value is then passed to rot_shift2D as a rotation angle (presumably degrees).
        # Confirm the intended unit before changing.
        inplane_angles = np.arange(min(min_pix, max_pix) - 1, max(min_pix, max_pix) + 2)
        blurred_power = model_blank(powersize, powersize, 1, 0)
        for each_angle in inplane_angles:
            rot_power = rot_shift2D(power_spectrum, float(each_angle))
            if binary is True:
                rot_power = binarize(rot_power, 0.1)
            blurred_power += rot_power

        blurred_power /= len(inplane_angles)

        return blurred_power

    def get_padsize_and_unique_tilts(self, alignment_size, projection_parameters):
        """
        Return ``(4 * alignment_size, sorted unique |theta - 90|)``, where theta is column 1
        of ``projection_parameters``.
        """
        padsize = 4 * alignment_size
        theta = np.array(projection_parameters)[:, 1]
        unique_tilts = np.unique(np.abs(theta - 90.0))

        return padsize, unique_tilts

    def compute_angular_blur_based_on_Crowther_criterion(self, ref_cycle_id, outer_diameter, pixelsize):
        """
        Return the angular step ``rad2deg(resolution / (outer_diameter / 2))``.

        The resolution is ``fsc_05`` of the latest RefinementCycleTable entry in
        ``refinement{ref_cycle_id:03}.db``, or ``4 * pixelsize`` when that database or entry
        is missing. The database is read from a temporary copy, which is always removed.
        """
        last_cycle = None
        if os.path.exists('refinement{0:03}.db'.format(ref_cycle_id)):
            temp_ref_db = self.copy_ref_db_to_tempdir(ref_cycle_id)
            try:
                ref_session = SpringDataBase().setup_sqlite_db(refine_base, temp_ref_db)
                try:
                    last_cycle = ref_session.query(RefinementCycleTable).order_by(
                        desc(RefinementCycleTable.id)).first()
                finally:
                    ref_session.close()
            finally:
                os.remove(temp_ref_db)

        # x / deg2rad(r) == rad2deg(x / r): resolution over radius, expressed in degrees.
        if last_cycle is not None:
            angular_blur = last_cycle.fsc_05 / np.deg2rad(outer_diameter / 2.0)
        else:
            angular_blur = pixelsize * 4.0 / np.deg2rad(outer_diameter / 2.0)

        return angular_blur

    def apply_inplane_blurring_to_set_of_power_spectra(self, ideal_power_imgs, helical_symmetry, rotational_sym,
    helixwidth, pixelsize, angular_blur, binary=True):
        """
        Blur each power spectrum in-plane by ``angular_blur / 5.0``. When ``binary`` is True
        the result is binarized, then multiplied by a centered circle of radius
        ``size // 2``. The input list is returned unchanged when ``angular_blur`` is None.
        """
        if angular_blur is not None:
            blurred_powers = []
            power_size = ideal_power_imgs[0].get_ysize()
            # Integer radius: restores the Python 2 semantics of `power_size / 2`.
            circle = model_circle(power_size // 2, power_size, power_size)
            for each_power in ideal_power_imgs:
                blurred_power = self.generate_power_spectrum_average_by_inplane_angle_blurring(each_power,
                    angular_blur / 5.0, helical_symmetry, rotational_sym, helixwidth, pixelsize, binary=binary)
                if binary:
                    blurred_power = binarize(blurred_power, 1e-14)
                blurred_power *= circle
                blurred_powers.append(blurred_power)

            ideal_power_imgs = blurred_powers

        return ideal_power_imgs

    def generate_binary_layer_line_filters_including_angular_blur(self, projection_parameters, pixelinfo,
    helical_symmetry, rotational_sym, angular_blur=None):
        """
        Build the binary layer-line filters for every unique tilt of ``projection_parameters``.
        ``angular_blur`` (if given) is applied out-of-plane and in-plane.

        :returns: ``(unique_tilts, padsize, ideal_power_imgs)``
        """
        padsize, unique_tilts = self.get_padsize_and_unique_tilts(pixelinfo.alignment_size, projection_parameters)

        ideal_power_imgs = self.generate_binary_layer_line_filter_for_different_out_of_plane_tilt_angles(unique_tilts,
            helical_symmetry, rotational_sym, self.helixwidth, pixelinfo.pixelsize, padsize, binary=True,
            out_of_plane_blur=angular_blur)

        ideal_power_imgs = self.apply_inplane_blurring_to_set_of_power_spectra(ideal_power_imgs, helical_symmetry,
            rotational_sym, self.helixwidth, pixelinfo.pixelsize, angular_blur)

        return unique_tilts, padsize, ideal_power_imgs

    def filter_projections_using_provided_layer_line_filters(self, projection_parameters, prj_ids, projection_stack,
    unique_tilts, padsize, ideal_power_imgs, pixelinfo):
        """
        Fourier-filter every image of ``projection_stack`` in place with the filter matching
        its tilt ``|theta - 90|``. Each image is then multiplied by a smooth rectangular
        mask, and the processed images are logged.
        """
        projection = EMData()
        rectangular_mask = SegmentExam().make_smooth_rectangular_mask(pixelinfo.helixwidthpix,
            pixelinfo.helix_heightpix, pixelinfo.alignment_size)

        filter_loginfo = []
        for each_local_prj_id, each_parameter in enumerate(projection_parameters):
            each_theta = each_parameter[1]
            each_total_prj_id = prj_ids[each_local_prj_id]
            # Exact comparison is consistent because unique_tilts is derived from the same thetas.
            tilt_index = np.where(unique_tilts == abs(each_theta - 90.0))
            projection.read_image(projection_stack, each_local_prj_id)

            projection = SegClassReconstruct().filter_image_by_fourier_filter_while_padding(projection,
                pixelinfo.alignment_size, padsize, ideal_power_imgs[tilt_index[0][0]])

            projection *= rectangular_mask
            projection.write_image(projection_stack, each_local_prj_id)
            filter_loginfo += [[each_total_prj_id, each_local_prj_id, each_theta - 90]]

        header = ['stack_id', 'local_id', 'out-of-plane tilt']
        msg = tabulate(filter_loginfo, header)
        self.log.tlog('The following projection images were filtered with a layer-line based ' +
                      'filter:\n{0}'.format(msg))

    def filter_layer_lines_if_demanded(self, resolution_aim, projection_parameters, prj_ids, projection_stack,
    pixelinfo, helical_symmetry, rotational_sym):
        """
        Apply the layer-line filter to ``projection_stack`` when ``self.layer_line_filter`` is set.
        ``resolution_aim`` is currently not consulted.
        """
        if self.layer_line_filter:
            unique_tilts, padsize, ideal_power_imgs = self.generate_binary_layer_line_filters_including_angular_blur(
                projection_parameters, pixelinfo, helical_symmetry, rotational_sym)

            self.filter_projections_using_provided_layer_line_filters(projection_parameters, prj_ids,
                projection_stack, unique_tilts, padsize, ideal_power_imgs, pixelinfo)