# Author: Carsten Sachse 23-Aug-2011
# Copyright: EMBL (2010 - 2018), Forschungszentrum Juelich (2019 - 2021)
# License: see license.txt for details
"""
Program to correct helical segments by determined CTF
"""
from EMAN2 import EMData
from filter import filt_ctf
from spring.csinfrastr.csdatabase import SpringDataBase, base, SegmentTable, CtfFindMicrographTable, \
CtfTiltMicrographTable, CtfMicrographTable
from spring.csinfrastr.csfeatures import Features
from spring.csinfrastr.cslogger import Logger
from spring.csinfrastr.csproductivity import OpenMpi
from spring.csinfrastr.csreadinput import OptHandler
from tabulate import tabulate
from utilities import generate_ctf
import numpy as np
import os
import shutil
class SegmentCtfApplyPar(object):
    """
    Class to initiate default dictionary with input parameters including help and range values and status dictionary
    """
    def __init__(self):
        # Identity of this program inside the emspring package; proginfo is the module docstring.
        self.package = 'emspring'
        self.progname = 'segmentctfapply'
        self.proginfo = __doc__
        self.code_files = [self.progname, self.progname + '_mpi']

        self.segmentctfapply_features = Features()
        self.feature_set = self.segmentctfapply_features.setup(self)

        self.define_parameters_and_their_properties()
        self.define_program_states()

    def define_parameters_and_their_properties(self):
        """
        Register all input parameters on ``self.feature_set``, in the order in which they are presented.
        """
        features = self.segmentctfapply_features
        registration_order = (
            features.set_inp_stack,
            features.set_out_stack,
            features.set_spring_path_segments,
            features.set_pixelsize,
            self.set_ctf_correct,
            features.set_mpi,
            features.set_ncpus,
            features.set_temppath,
        )
        for register in registration_order:
            self.feature_set = register(self.feature_set)

    def define_program_states(self):
        """
        Map the names of the two main processing methods to their progress descriptions.
        """
        states = self.feature_set.program_states
        states['get_ctf_values_from_database_and_compute_local_ctf_based_if_demanded'] = \
            'Retrieve CTF values from database'
        states['apply_ctf_to_segments'] = 'Convolute each segment with CTF parameters'

    def add_ctf_correct_option_as_relative(self, feature_set, inp7):
        """
        Make parameter ``inp7`` a relative of 'CTF correct option', provided that option is registered.

        :returns: the (mutated) feature_set
        """
        parent_option = 'CTF correct option'
        if parent_option in feature_set.parameters:
            feature_set.relatives[inp7] = parent_option
        return feature_set

    def set_ctffind_or_ctftilt_choice(self, feature_set):
        """
        Register the choice between CTFFIND and CTFTILT values (default 'ctftilt').
        """
        option = 'CTFFIND or CTFTILT'
        feature_set.parameters[option] = 'ctftilt'
        feature_set.hints[option] = 'Choose whether \'ctffind\' or \'ctftilt\' values are used for CTF correction.'
        feature_set.properties[option] = feature_set.choice_properties(2, ['ctffind', 'ctftilt'], 'QComboBox')
        feature_set = self.add_ctf_correct_option_as_relative(feature_set, option)
        feature_set.level[option] = 'intermediate'
        return feature_set

    def set_ctfconvolve_or_ctfphase_flip_option(self, feature_set):
        """
        Register the choice between CTF convolution and phase flipping (default 'convolve').
        """
        option = 'convolve or phase-flip'
        feature_set.parameters[option] = 'convolve'
        feature_set.hints[option] = 'Choose whether to \'convolve\' or \'phase-flip\' images with determined CTF.'
        feature_set.properties[option] = feature_set.choice_properties(2, ['convolve', 'phase-flip'], 'QComboBox')
        feature_set = self.add_ctf_correct_option_as_relative(feature_set, option)
        feature_set.level[option] = 'intermediate'
        return feature_set

    def set_astigmatism_option(self, feature_set):
        """
        Register the boolean astigmatism-correction switch (default True, expert level).
        """
        option = 'Astigmatism correction'
        feature_set.parameters[option] = True
        feature_set.hints[option] = 'Option to correct for astigmatism in image otherwise average defocus is used.'
        feature_set = self.add_ctf_correct_option_as_relative(feature_set, option)
        feature_set.level[option] = 'expert'
        return feature_set

    def set_ctf_correct(self, feature_set):
        """
        Register the three CTF-correction parameters: program choice, filter mode and astigmatism switch.
        """
        for register in (self.set_ctffind_or_ctftilt_choice,
                         self.set_ctfconvolve_or_ctfphase_flip_option,
                         self.set_astigmatism_option):
            feature_set = register(feature_set)
        return feature_set
class SegmentCtfApplyCtfFindCtfTilt(object):
    """
    Base class that unpacks the merged segmentctfapply parameter set into instance attributes.
    """
    def __init__(self, parset = None):
        """
        :param parset: merged feature set (as returned by OptHandler); if None only the logger is created
        """
        self.log = Logger()
        if parset is not None:
            self.feature_set = parset
            p = self.feature_set.parameters
            self.infile = p['Image input stack']
            self.outfile = p['Image output stack']
            self.spring_path = p['spring.db file']
            self.pixelsize = float(p['Pixel size in Angstrom'])
            # The helper sets attributes on the object it is given; call it for that side effect instead of
            # rebinding ``self`` to its return value.
            self.define_ctf_correction_parameters(self, p)
            self.mpi_option = p['MPI option']
            self.cpu_count = p['Number of CPUs']
            self.temppath = p['Temporary directory']

    def define_ctf_correction_parameters(self, obj, p):
        """
        Copy the three CTF-correction choices from parameter dict ``p`` onto ``obj``.

        :param obj: object receiving the attributes ctffind_or_ctftilt_choice, convolve_or_phaseflip_choice and
            astigmatism_option
        :param p: parameter dictionary keyed by the option names
        :returns: ``obj`` (same object, mutated)
        """
        obj.ctffind_or_ctftilt_choice = p['CTFFIND or CTFTILT']
        obj.convolve_or_phaseflip_choice = p['convolve or phase-flip']
        obj.astigmatism_option = p['Astigmatism correction']
        return obj
class SegmentCtfApplyConversion(SegmentCtfApplyCtfFindCtfTilt):
    """
    Conversions of CTFFIND/CTFTILT defocus parameters (MRC convention) into local defocus values and into the
    SPARX and SPIDER defocus conventions.
    """
    def compute_local_defocus_from_ctftilt_parameters(self, coord_x, coord_y, df1, df2, center_x, center_y, pixelsize,
                                                      taxis, tangle):
        """
        Compute the local defoci at pixel position (coord_x, coord_y) from CTFTILT parameters.

        Follows the CTFTILT convention (inverted tilt angle) from version 1.7 (May 2012). A typical CTFTILT
        output is::

            DFMID1      DFMID2      ANGAST     TLTAXIS     TANGLE      CC
            37735.65    42714.57    78.74      135.10      -19.97      0.39982  Final Values

            EQUATION FOR CALCULATING DEFOCUS DFL1,DFL2 AT LOCATION NX,NY:
            DFL1  = DFMID1 + DF
            DFL2  = DFMID2 + DF
            DF    = (N1*DX + N2*DY)*PSIZE*TAN(TANGLE)
            DX    = CX - NX
            DY    = CY - NY
            CX    = CENTER_X = 500
            CY    = CENTER_Y = 500
            PSIZE = PIXEL SIZE [A] = 5.0000
            N1,N2 = TILT AXIS NORMAL:
            N1    =  SIN(TLTAXIS) = 0.705872
            N2    = -COS(TLTAXIS) = 0.708339

            36453.64, 41432.56  <--(DFMID1,DFMID2)-->  37734.70, 42713.62
                1, 1            <------(NX,NY)------>     1000, 1
            +------------------------------------------------------+
            |                                                      |
            |                  37735.65, 42714.57                  |
            |                       500, 500                       |
            |                                                      |
            +------------------------------------------------------+
                1, 1000         <------(NX,NY)------>     1000, 1000
            37739.17, 42718.09  <--(DFMID1,DFMID2)-->  39020.23, 43999.15

        coord_x/coord_y may be scalars or numpy arrays of equal shape; the angles taxis/tangle are in degrees.

        :returns: tuple (df1_local, df2_local)

        >>> from spring.segment2d.segmentctfapply import SegmentCtfApply
        >>> s = SegmentCtfApply()
        >>> s.compute_local_defocus_from_ctftilt_parameters(500, 500, 37735.65, 42714.57, 500, 500, 5.0, 135.10, -19.97)
        (37735.65, 42714.57)
        >>> s.compute_local_defocus_from_ctftilt_parameters(1, 1000, 37735.65, 42714.57, 500, 500, 5.0, 135.10, -19.97)
        (37739.1747698892, 42718.0947698892)
        >>> s.compute_local_defocus_from_ctftilt_parameters(1, 1, 37735.65, 42714.57, 500, 500, 5.0, 135.10, -19.97)
        (36453.48835358775, 41432.40835358775)
        >>> s.compute_local_defocus_from_ctftilt_parameters(1000, 1, 37735.65, 42714.57, 500, 500, 5.0, 135.10, -19.97)
        (37734.69469232806, 42713.61469232806)
        >>> c_x = np.array([1000, 500])
        >>> c_y = np.array([1, 500])
        >>> s.compute_local_defocus_from_ctftilt_parameters(c_x, c_y, 37735.65, 42714.57, 500, 500, 5.0, 135.10, -19.97)
        (array([37734.69469233, 37735.65      ]), array([42713.61469233, 42714.57      ]))
        """
        # Normal of the tilt axis (N1, N2 in the CTFTILT equation above).
        n1 = np.sin(np.deg2rad(taxis))
        n2 = -np.cos(np.deg2rad(taxis))
        dx = center_x - coord_x
        dy = center_y - coord_y
        df = (n1 * dx + n2 * dy) * pixelsize * np.tan(np.deg2rad(tangle))
        df1_local = df1 + df
        df2_local = df2 + df
        return df1_local, df2_local

    def convert_mrc_defocus_to_sparx_defocus(self, defocus1, defocus2, astigmation_angle):
        """
        Convert MRC-style (defocus1, defocus2, astigmatism angle) to the SPARX convention.

        :returns: (average defocus, absolute defocus difference, astigmatism angle reduced to [0, 180))
        :raises ValueError: if the defocus difference is NaN (e.g. a NaN defocus value)

        >>> from spring.segment2d.segmentctfapply import SegmentCtfApply
        >>> SegmentCtfApply().convert_mrc_defocus_to_sparx_defocus(18000.0, 22000.0, 20.0)
        (20000.0, 4000.0, 25.0)
        >>> SegmentCtfApply().convert_mrc_defocus_to_sparx_defocus(18000.0, 22000.0, 80.0)
        (20000.0, 4000.0, 145.0)
        >>> SegmentCtfApply().convert_mrc_defocus_to_sparx_defocus(22000.0, 18000.0, 80.0)
        (20000.0, 4000.0, 55.0)
        """
        df1_df2_diff = defocus1 - defocus2
        # NaN compares False to everything; without this check neither branch below would assign the angle.
        if np.isnan(df1_df2_diff):
            raise ValueError('Defocus values must not be NaN: {0}, {1}'.format(defocus1, defocus2))
        if df1_df2_diff < 0:
            astigmation_angle_sparx = 45.0 - astigmation_angle
        else:
            astigmation_angle_sparx = 135.0 - astigmation_angle
        astigmatism_sparx = abs(df1_df2_diff)
        avg_defocus_sparx = sum([defocus1, defocus2]) / 2.0
        astigmation_angle_sparx = astigmation_angle_sparx % 180
        return avg_defocus_sparx, astigmatism_sparx, astigmation_angle_sparx

    def convert_mrc_defocus_to_spider_defocus(self, defocus1, defocus2, astigmation_angle):
        """
        Convert MRC-style (defocus1, defocus2, astigmatism angle) to the SPIDER convention.

        :returns: (average defocus, absolute defocus difference, astigmatism angle reduced to [0, 180))
        :raises ValueError: if the defocus difference is NaN (e.g. a NaN defocus value)

        >>> from spring.segment2d.segmentctfapply import SegmentCtfApply
        >>> SegmentCtfApply().convert_mrc_defocus_to_spider_defocus(18000.0, 22000.0, 20.0)
        (20000.0, 4000.0, 155.0)
        >>> SegmentCtfApply().convert_mrc_defocus_to_spider_defocus(18000.0, 22000.0, 80.0)
        (20000.0, 4000.0, 35.0)
        >>> SegmentCtfApply().convert_mrc_defocus_to_spider_defocus(22000.0, 18000.0, 80.0)
        (20000.0, 4000.0, 125.0)
        """
        df1_df2_diff = defocus1 - defocus2
        # NaN compares False to everything; without this check neither branch below would assign the angle.
        if np.isnan(df1_df2_diff):
            raise ValueError('Defocus values must not be NaN: {0}, {1}'.format(defocus1, defocus2))
        if df1_df2_diff < 0:
            astigmation_angle_spider = astigmation_angle + 135
        else:
            astigmation_angle_spider = astigmation_angle + 45
        astigmatism_spider = abs(df1_df2_diff)
        avg_defocus_spider = sum([defocus1, defocus2]) / 2.0
        astigmation_angle_spider = astigmation_angle_spider % 180
        return avg_defocus_spider, astigmatism_spider, astigmation_angle_spider
class SegmentCtfApplyDatabase(SegmentCtfApplyConversion):
    """
    Database access for segmentctfapply: look up the CTF parameters of each segment's micrograph in spring.db
    and record on each segment which CTF correction was applied.
    """
    def raise_error_if_not_found(self, spring_path, matched_mic_find, program='CTFFIND'):
        """
        Raise a ValueError if a micrograph lookup returned None.

        :param spring_path: spring.db path quoted in the error message
        :param matched_mic_find: result of the database lookup
        :param program: name of the CTF program whose table was queried. It is quoted in the message, so a
            missing CTFTILT entry is not reported as missing CTFFIND information.
        :raises ValueError: if ``matched_mic_find`` is None
        """
        if matched_mic_find is None:
            find_error_msg = 'Specified {0} file does not contain micrograph information from '.format(spring_path) + \
                '{0}. Please re-run MicCtfDetermine.'.format(program)
            raise ValueError(find_error_msg)

    def get_micrograph_from_database_by_micid(self, session, mic_id, spring_path):
        """
        Return the CtfFindMicrographTable row with primary key ``mic_id``.

        :raises ValueError: if no such row exists
        """
        matched_mic_find = session.query(CtfFindMicrographTable).get(mic_id)
        self.raise_error_if_not_found(spring_path, matched_mic_find)
        return matched_mic_find

    def get_micrograph_from_database_by_micname(self, session, micrograph_file, spring_path):
        """
        Return the first CtfFindMicrographTable row whose micrograph_name equals the basename of
        ``micrograph_file``.

        :raises ValueError: if no such row exists
        """
        matched_mic_find = session.query(CtfFindMicrographTable).\
            filter(CtfFindMicrographTable.micrograph_name == os.path.basename(micrograph_file)).first()
        self.raise_error_if_not_found(spring_path, matched_mic_find)
        return matched_mic_find

    def get_ctfparameters_from_database(self, ctffind_or_ctftilt_choice, astigmatism_option, pixelsize, session,
                                        each_segment, matched_mic_find, spring_path):
        """
        Assemble the CTF parameters of one segment from its micrograph's CTFFIND or CTFTILT entry.

        With 'ctftilt', the defoci are evaluated locally at the segment position. The position is converted from
        Angstrom to micrograph pixels with the micrograph pixel size.

        :returns: (ctf_params, avg_defocus, astigmatism, astig_angle). ctf_params is
            [avg_defocus * 1e-4, Cs, voltage, pixelsize, 0, amplitude contrast, astigmatism * 1e-4, astig_angle].
            astigmatism and astig_angle are 0 when ``astigmatism_option`` is false.
        :raises ValueError: if a required database entry is missing or the choice is neither 'ctffind' nor
            'ctftilt'
        """
        if ctffind_or_ctftilt_choice in ['ctftilt']:
            matched_mic_tilt = session.query(CtfTiltMicrographTable).get(matched_mic_find.id)
            self.raise_error_if_not_found(spring_path, matched_mic_tilt, 'CTFTILT')
            mic_pixelsize = matched_mic_find.pixelsize
            defocus1, defocus2 = self.compute_local_defocus_from_ctftilt_parameters(
                each_segment.x_coordinate_A / mic_pixelsize, each_segment.y_coordinate_A / mic_pixelsize,
                matched_mic_tilt.defocus1, matched_mic_tilt.defocus2, matched_mic_tilt.center_x,
                matched_mic_tilt.center_y, mic_pixelsize, matched_mic_tilt.tilt_axis, matched_mic_tilt.tilt_angle)
            astigmation_angle = matched_mic_tilt.astigmation_angle
        elif ctffind_or_ctftilt_choice in ['ctffind']:
            defocus1 = matched_mic_find.defocus1
            defocus2 = matched_mic_find.defocus2
            astigmation_angle = matched_mic_find.astigmation_angle
        else:
            # Without this branch an unknown choice surfaced as an UnboundLocalError on defocus1.
            raise ValueError('Unknown CTFFIND or CTFTILT choice {0}: use \'ctffind\' or \'ctftilt\'.'.format(
                ctffind_or_ctftilt_choice))

        avg_defocus, astigmatism, astig_angle = self.convert_mrc_defocus_to_sparx_defocus(defocus1, defocus2,
                                                                                          astigmation_angle)

        matched_mic = session.query(CtfMicrographTable).get(matched_mic_find.id)
        self.raise_error_if_not_found(spring_path, matched_mic, 'CTF determination')
        if not astigmatism_option:
            astigmatism = 0
            astig_angle = 0
        ctf_params = [avg_defocus * 1e-4, matched_mic.spherical_aberration, matched_mic.voltage, pixelsize, 0,
                      matched_mic.amplitude_contrast, astigmatism * 1e-4, astig_angle]
        return ctf_params, avg_defocus, astigmatism, astig_angle

    def update_ctfparameters_in_database(self, ctffind_or_ctftilt_choice, convolve_or_phaseflip_choice,
                                         astigmatism_option, session, each_segment, avg_defocus, astigmatism,
                                         astig_angle):
        """
        Store the applied defocus values and correction flags on the segment row (not committed here).

        :returns: (session, each_segment)
        """
        each_segment.avg_defocus = avg_defocus
        each_segment.astigmatism = astigmatism
        each_segment.astigmation_angle = astig_angle
        if ctffind_or_ctftilt_choice in ['ctftilt']:
            each_segment.ctffind_applied = False
            each_segment.ctftilt_applied = True
        elif ctffind_or_ctftilt_choice in ['ctffind']:
            each_segment.ctffind_applied = True
            each_segment.ctftilt_applied = False
        if convolve_or_phaseflip_choice in ['convolve']:
            each_segment.ctf_convolved = True
            each_segment.ctf_phase_flipped = False
        elif convolve_or_phaseflip_choice in ['phase-flip']:
            each_segment.ctf_convolved = False
            each_segment.ctf_phase_flipped = True
        each_segment.ctf_astigmatism_applied = bool(astigmatism_option)
        return session, each_segment

    def get_ctf_values_from_database_and_compute_local_ctf_based_if_demanded(self, ctffind_or_ctftilt_choice,
            convolve_or_phaseflip_choice, astigmatism_option, pixelsize, spring_path):
        """
        Collect CTF parameters for every segment (ordered by stack_id) and record the applied correction.

        NOTE(review): the session comes from SpringDataBase().setup_sqlite_db(base); presumably this is the local
        'spring.db' copy made by apply_ctf_to_segment_stack — confirm.

        :returns: list of ctf_params lists, one per segment in stack_id order
        """
        self.log.fcttolog()
        self.log.plog(10)
        session = SpringDataBase().setup_sqlite_db(base)
        try:
            matched_segments = session.query(SegmentTable).order_by(SegmentTable.stack_id).all()
            ctf_parameters = []
            for each_segment in matched_segments:
                matched_mic_find = self.get_micrograph_from_database_by_micid(session, each_segment.mic_id,
                                                                              spring_path)
                ctf_params, avg_defocus, astigmatism, astig_angle = \
                    self.get_ctfparameters_from_database(ctffind_or_ctftilt_choice, astigmatism_option, pixelsize,
                                                         session, each_segment, matched_mic_find, spring_path)
                ctf_parameters.append(ctf_params)
                session, each_segment = self.update_ctfparameters_in_database(ctffind_or_ctftilt_choice,
                    convolve_or_phaseflip_choice, astigmatism_option, session, each_segment, avg_defocus,
                    astigmatism, astig_angle)
                session.merge(each_segment)
            session.commit()
        finally:
            # Release the database connection also when a lookup raises.
            session.close()
        return ctf_parameters
class SegmentCtfApply(SegmentCtfApplyDatabase):
    """
    Apply the CTF stored in spring.db to each image of a segment stack, by convolution or phase flipping.
    """
    def filter_image_by_ctf_convolve_or_phaseflip(self, convolve_or_phaseflip_choice, segment, each_segment_ctf_p):
        """
        Filter one segment image with the CTF generated from ``each_segment_ctf_p``.

        :param convolve_or_phaseflip_choice: 'convolve' calls ``filt_ctf`` with the CTF; 'phase-flip' calls it
            with ``binary=True``. For any other value the segment is returned unfiltered.
        :param segment: segment image (EMData)
        :param each_segment_ctf_p: parameter list passed to ``generate_ctf``
        :returns: the filtered segment image
        """
        local_ctf = generate_ctf(each_segment_ctf_p)
        if convolve_or_phaseflip_choice in ['convolve']:
            segment = filt_ctf(segment, local_ctf)
        elif convolve_or_phaseflip_choice in ['phase-flip']:
            segment = filt_ctf(segment, local_ctf, binary=True)
        return segment

    def apply_ctf_to_segments(self, segment_ids, ctf_parameters, convolve_or_phaseflip_choice, infile_stack,
                              outfile_stack):
        """
        Read each segment, filter it with its CTF and write it to the output stack.

        :param segment_ids: image indices in ``infile_stack``. The i-th id is filtered with ``ctf_parameters[i]``
            and written to index i of ``outfile_stack``.
        :param ctf_parameters: per-segment lists
            [avg_defocus(microm), Cs, voltage, pixelsize, bfactor, amp_contrast, astigmatism, astig_angle]
        """
        self.log.fcttolog()
        self.log.plog(20)
        segment = EMData()
        if ctf_parameters:
            # Indices 1..5 (slice end 6) are Cs, voltage, pixelsize, bfactor and amplitude contrast, matching the
            # five headers. They are logged once, taken from the first segment.
            log_info = [ctf_parameters[0][1:6]]
            msg = tabulate(log_info, ['Cs(mm)', 'voltage(kV)', 'pixelsize', 'bfactor', 'amp_contrast'])
            self.log.ilog(msg)

        log_info = []
        for each_local_seg_id, each_seg_id in enumerate(segment_ids):
            each_segment_ctf_p = ctf_parameters[each_local_seg_id]
            segment.read_image(infile_stack, each_seg_id)
            segment = self.filter_image_by_ctf_convolve_or_phaseflip(convolve_or_phaseflip_choice, segment,
                                                                     each_segment_ctf_p)
            segment.write_image(outfile_stack, each_local_seg_id)
            log_info.append([each_seg_id, each_local_seg_id, each_segment_ctf_p[0], each_segment_ctf_p[6],
                             each_segment_ctf_p[7]])

        if ctf_parameters:
            msg = tabulate(log_info, ['segment_id', 'local_id', 'avg_defocus(microm)', 'astigmatism', 'astig_angle'])
            # Log the formatted table, not the raw list it was built from.
            self.log.ilog(msg)
        self.log.plog(90)

    def apply_ctf_to_segment_stack(self):
        """
        Main entry point.

        1. Optionally start the MPI version.
        2. Work on a copy of the input spring.db named 'spring.db' in the current directory.
        3. Retrieve per-segment CTF parameters and write the CTF-filtered stack.
        """
        OpenMpi().setup_and_start_mpi_version_if_demanded(self.mpi_option, self.feature_set, self.cpu_count)
        # Copying a file onto itself raises in shutil.copy; skip the copy when the input already is ./spring.db.
        if not (os.path.exists('spring.db') and os.path.samefile(self.spring_path, 'spring.db')):
            shutil.copy(self.spring_path, 'spring.db')
        ctf_parameters = \
            self.get_ctf_values_from_database_and_compute_local_ctf_based_if_demanded(self.ctffind_or_ctftilt_choice,
            self.convolve_or_phaseflip_choice, self.astigmatism_option, self.pixelsize, self.spring_path)
        segment_ids = list(range(len(ctf_parameters)))
        self.apply_ctf_to_segments(segment_ids, ctf_parameters, self.convolve_or_phaseflip_choice, self.infile,
                                   self.outfile)
        self.log.endlog(self.feature_set)
def main():
    """
    Merge the default segmentctfapply parameters with the user options and run CTF application on the stack.
    """
    merged_parameters = OptHandler(SegmentCtfApplyPar())
    SegmentCtfApply(merged_parameters).apply_ctf_to_segment_stack()


if __name__ == '__main__':
    main()