# Author: Carsten Sachse
# Copyright: EMBL (2010 - 2018), Forschungszentrum Juelich (2019 - 2021)
# License: see license.txt for details
"""
Program to extract amplitudes and phases from desired layer lines of class averages
"""
from EMAN2 import Util, EMData, EMNumPy, periodogram
from collections import OrderedDict, namedtuple
from filter import filt_table
from functools import partial
import os
from spring.csinfrastr.csfeatures import Features
from spring.csinfrastr.csgui import QTabWidgetCloseable, NumbersOptionsGuiWindow
from spring.csinfrastr.cslogger import Logger
from spring.csinfrastr.csproductivity import DiagnosticPlot
from spring.csinfrastr.csreadinput import OptHandler
from spring.segment2d.segclassexam import SegClassExam
from spring.segment2d.segmentalign2d import SegmentAlign2d
from spring.segment2d.segmentexam import SegmentExam
from spring.segment3d.segclassreconstruct import SegClassReconstruct
from spring.springgui.springdataexplore import SpringCommon, SpringDataExplore
import sys
from utilities import model_blank
from PyQt5.QtCore import Qt
##from PyQt5.QtCore import pyqtSignal as SIGNAL
from PyQt5.QtWidgets import QApplication, QWidget, QComboBox, QStackedWidget, QSplitter, QToolTip, QGridLayout, \
QDoubleSpinBox, QLabel
from PyQt5.QtGui import QFont
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.font_manager import FontProperties
from scipy import interpolate
from tabulate import tabulate
import numpy as np
class SegClassLayerPar(object):
    """
    Parameter definition for segclasslayer: fills the feature set with default
    parameter values, hints, value properties, expert levels and relatives, and
    registers the program states.
    """

    def __init__(self):
        # identity of this program inside the package
        self.package = 'emspring'
        self.progname = 'segclasslayer'
        self.proginfo = __doc__
        self.code_files = [self.progname]

        self.segclasslayer_features = Features()
        self.feature_set = self.segclasslayer_features.setup(self)

        self.define_parameters_and_their_properties()
        self.define_program_states()

    def define_parameters_and_their_properties(self):
        """
        Apply every parameter setter, in order, to the feature set. Each setter
        receives the current feature set and returns the updated one.
        """
        feats = self.segclasslayer_features
        configurators = (
            feats.set_class_avg_stack,
            feats.set_interactive_vs_batch_mode,
            lambda fs: feats.set_output_plot(fs, self.progname + '_diag.pdf', 'Batch mode'),
            lambda fs: self.set_class_format_choice(fs),
            feats.set_pixelsize,
            feats.set_exact_helix_width,
            feats.set_bfactor_on_images,
            feats.set_power_cutoff,
            feats.set_class_number_range_to_be_analyzed,
            lambda fs: self.set_layer_line_position(fs),
            lambda fs: self.set_pad_option(fs),
        )
        for configure in configurators:
            self.feature_set = configure(self.feature_set)

    def define_program_states(self):
        """Register the program states with their descriptions."""
        state_descriptions = (
            ('extract_layerlines', 'Extracts layer lines from specified position'),
            ('visualize_layerlines', 'Visualization of layer lines'),
        )
        for state_name, description in state_descriptions:
            self.feature_set.program_states[state_name] = description

    def set_layer_line_position(self, feature_set):
        """
        Add the 'Layer line positions' parameter: a comma-separated string of
        layer line positions in reciprocal Angstrom (expert, Batch mode only).
        """
        key = 'Layer line positions'
        feature_set.parameters[key] = '0.1234,0.234'
        feature_set.hints[key] = 'List of comma-separated values in reciprocal Angstrom'
        feature_set.properties[key] = feature_set.file_properties(1, ['*'], None)
        feature_set.level[key] = 'expert'
        feature_set.relatives[key] = 'Batch mode'
        return feature_set

    def set_pad_option(self, feature_set):
        """
        Add the boolean 'Pad option' parameter (default True, expert, Batch
        mode only). No value properties are registered for it.
        """
        key = 'Pad option'
        feature_set.parameters[key] = True
        feature_set.hints[key] = 'If layer lines are part of a continuous lattice set, average will be padded so ' + \
            'that layer lines lie exactly on pixel grid'
        feature_set.level[key] = 'expert'
        feature_set.relatives[key] = 'Batch mode'
        return feature_set
class SegClassLayerGui(QWidget):
    """
    Interactive viewer: for every class average in the requested range it shows the
    class image, its power spectrum and a tab widget of layer-line amplitude/phase
    profiles. Clicking into the power spectrum opens a new profile tab.
    """

    def __init__(self, feature_set, parent = None):
        QWidget.__init__(self, parent)
        self.feature_set = feature_set
        self.properties = feature_set.properties
        self = SpringCommon().setup_spring_page_top(self, feature_set)
        self.segclasslayer = SegClassLayer(self.feature_set)
        start_cls, end_cls = self.segclasslayer.classno_range
        class_ids = list(range(start_cls, end_cls + 1))
        class_count = len(class_ids)

        self.classes = ['{0} - {1:03}'.format(os.path.basename(self.segclasslayer.infile), each_class)
                        for each_class in class_ids]
        self.stackedComboBox = QComboBox()
        self.stackedComboBox.addItems(self.classes)
        self.layout.addWidget(self.stackedComboBox, 0, 3, 1, 1)
        self.angstrom_str = NumbersOptionsGuiWindow().convert_angstrom_string('Angstrom')
        self.stackedWidget = QStackedWidget()
        self.class_plot = DiagnosticPlot()

        container = self.segclasslayer.make_named_tuple_amp_phase_images()
        # per-class slots, filled lazily by setCurrentDisplay / the loop below
        self.class_img_canvas = [None] * class_count
        self.class_ps_canvas = [None] * class_count
        self.class_data = [container(each_class, None, None, None) for each_class in class_ids]
        self.bfactor_dials = [None] * class_count
        self.tabwidgets = []

        for each_id in range(class_count):
            self.splitter = QSplitter(Qt.Horizontal)
            self.splitter_vert = QSplitter(Qt.Vertical)
            # computes canvases, amplitude/phase images, self.layer_profile and self.max_amp
            self.setCurrentDisplay(each_id)
            self.splitter_vert.addWidget(self.class_img_canvas[each_id])
            self.splitter_vert.addWidget(self.class_ps_canvas[each_id])
            self.splitter.addWidget(self.splitter_vert)
            self.class_ps_canvas[each_id].main_frame.setToolTip('Mouse button to display amplitude and phases of ' +
                                                                'corresponding layer line.')
            self.class_ps_canvas[each_id].picked_left_point.connect(partial(self.add_new_tab_with_layer_line_profile))
            self.class_ps_canvas[each_id].picked_middle_point.connect(partial(self.add_new_tab_with_layer_line_profile))

            fig = self.prepare_layer_line_profile_plot(self.class_plot, self.layer_profile, self.max_amp)
            self.layer_fig = SpringDataExplore(fig)
            self.layer_plane = QWidget()
            self.layer_canvas = QGridLayout()
            self.layer_canvas.addWidget(self.layer_fig, 0, 0, 0, 4)

            self.bfactor_label = QLabel()
            self.bfactor_label.setText('B-factor')
            self.layer_canvas.addWidget(self.bfactor_label, 1, 1, 1, 1)
            self.bfactor_dials[each_id] = QDoubleSpinBox()
            bfact_cutoff_tip = 'Apply a negative B-factor to improve visualization of dampened high resolution ' + \
                'layer lines.'
            self.bfactor_dials[each_id].setToolTip(bfact_cutoff_tip)
            self.bfactor_dials[each_id].setRange(-90000, 90000)
            self.bfactor_dials[each_id].setSingleStep(50)
            self.bfactor_dials[each_id].setDecimals(0)
            self.bfactor_dials[each_id].setValue(self.segclasslayer.bfactor)
            self.bfactor_dials[each_id].editingFinished.connect(self.activateOrInactivateBfactor)
            self.layer_canvas.addWidget(self.bfactor_dials[each_id], 1, 2, 1, 1)
            self.layer_plane.setLayout(self.layer_canvas)

            self.tabWidget = QTabWidgetCloseable()
            self.tabWidget.addTab(self.layer_plane, 'Equator')
            self.tabwidgets.append(self.tabWidget)
            self.splitter.addWidget(self.tabWidget)
            self.stackedWidget.addWidget(self.splitter)

        self.stackedComboBox.currentIndexChanged.connect(self.stackedWidget.setCurrentIndex)
        self.layout.addWidget(self.stackedWidget, 2, 0, 2, 5)
        self.setLayout(self.layout)
        self.setMouseTracking(True)

    def enterEvent(self, event):
        """Use a fixed-width font for tooltips so the Bessel table tooltip lines up."""
        QToolTip.setFont(QFont('Courier', 8))

    def setCurrentDisplay(self, each_index):
        """
        Draw class image and power spectrum for the class at position each_index
        (creating the canvases on first use), then compute amplitude/phase images
        and the profile at Fourier pixel 1. Sets self.layer_profile, self.fourier_dim,
        self.max_amp and updates self.class_data[each_index].
        """
        class_id = self.class_data[each_index].class_id
        class_img_pd, cls_img_np, cls_ps_np = self.segclasslayer.prepare_class_img_and_power_spectrum(class_id,
            self.class_plot)

        img_dim_A = cls_img_np.shape[0] * self.segclasslayer.pixelsize
        A_grid = np.array([(-img_dim_A, -img_dim_A), (-img_dim_A, img_dim_A), (img_dim_A, img_dim_A),
                           (img_dim_A, -img_dim_A)]).reshape((2, 2, 2))
        if self.class_img_canvas[each_index] is None:
            self.class_img_canvas[each_index] = SpringDataExplore()
        self.class_img_canvas[each_index].on_draw(cls_img_np, A_grid, [''] + ['Angstrom'] * 2, '2d', color_map='gray')

        nyquist = 1 / (2 * self.segclasslayer.pixelsize)
        recip_A_grid = np.array([(-nyquist, -nyquist), (-nyquist, nyquist), (nyquist, nyquist),
                                 (nyquist, -nyquist)]).reshape((2, 2, 2))
        if self.class_ps_canvas[each_index] is None:
            self.class_ps_canvas[each_index] = SpringDataExplore()
        self.class_ps_canvas[each_index].on_draw(cls_ps_np, recip_A_grid, [''] + ['1/Angstrom'] * 2, '2d',
                                                 color_map='hot')

        amp_img, phase_img = self.segclasslayer.prepare_amplitude_and_phase_image(class_img_pd)
        self.layer_profile = self.segclasslayer.get_amplitude_and_phase_at_position(amp_img, phase_img, 1)
        self.fourier_dim = amp_img.get_ysize() / 2
        container = self.segclasslayer.make_named_tuple_amp_phase_images()
        # 10 % head room above the largest amplitude for the profile plot's y-limit
        self.max_amp = 1.1 * max(np.append(self.layer_profile.left_amp, self.layer_profile.right_amp))
        self.class_data[each_index] = container(class_id, amp_img, phase_img, self.max_amp)

    def activateOrInactivateBfactor(self):
        """Apply the B-factor from the current class's dial and redraw that class."""
        cur_id = self.stackedComboBox.currentIndex()
        self.segclasslayer.bfactor = self.bfactor_dials[cur_id].value()
        self.setCurrentDisplay(cur_id)

    def add_new_tab_with_layer_line_profile(self, index_pair):
        """
        Open a new tab with the layer-line profile at the picked position.
        index_pair[1] is used as the vertical position; presumably a pixel
        coordinate of the power spectrum canvas -- NOTE(review): confirm.
        """
        y_pos = index_pair[1]
        fourier_pixel = int(abs(y_pos - self.fourier_dim)) + 1
        nyquist = 1 / (2 * self.segclasslayer.pixelsize)
        fourier_pix_A = fourier_pixel * nyquist / self.fourier_dim

        cur_id = self.stackedComboBox.currentIndex()
        amp_img = self.class_data[cur_id].amp
        phase_img = self.class_data[cur_id].phase
        max_amp = self.class_data[cur_id].max_amp
        layer_profile = self.segclasslayer.get_amplitude_and_phase_at_position(amp_img, phase_img, fourier_pixel)
        fig = self.prepare_layer_line_profile_plot(self.class_plot, layer_profile, max_amp)
        self.layer_canvas = SpringDataExplore(fig)

        besselorder, primarymax, primarymaxwidth = self.segclasslayer.get_bessel_table_quant(self.segclasslayer.helixwidth)
        radius_label = 'At helix radius (R={0} A)'.format(self.segclasslayer.helixwidth // 2)
        tool_str = tabulate([['Bessel order'] + besselorder,
                             ['Primary max at 2*Pi*r*R'] + primarymax,
                             [radius_label] + primarymaxwidth],
                            tablefmt='grid')
        self.layer_canvas.main_frame.setToolTip(tool_str)

        title = '{0:.04} 1/{1}'.format(fourier_pix_A, self.angstrom_str)
        self.tabwidgets[cur_id].addTab(self.layer_canvas, title)
        self.tabwidgets[cur_id].setCurrentIndex(self.tabwidgets[cur_id].indexOf(self.layer_canvas))

    def prepare_layer_line_profile_plot(self, class_plot, layer_profile, ampmax=None):
        """Create a new figure with amplitude and phase-difference curves of layer_profile."""
        arrresolution = SegmentExam().make_oneoverres(layer_profile.left_phase, self.segclasslayer.pixelsize)
        fig = class_plot.create_next_figure()
        ax1 = fig.add_subplot(111)
        ax1, ax2 = self.segclasslayer.add_amplitude_and_phase_difference_to_plot(layer_profile,
            arrresolution, ax1, ampmax)
        self.segclasslayer.add_labels_to_plot(ax1, ax2)
        return fig
class SegClassLayerGuiSupport(SegClassLayerExtract):
    """Helpers used by the segclasslayer GUI to load class averages and their power spectra."""

    # Built once: every call of make_named_tuple_amp_phase_images returns the same record type
    _amp_phase_image_type = namedtuple('image', 'class_id amp phase max_amp')

    def make_named_tuple_amp_phase_images(self):
        """Return the namedtuple type (class_id, amp, phase, max_amp) holding per-class GUI data."""
        return self._amp_phase_image_type

    def get_class_img(self, class_id):
        """
        Read class average class_id from self.infilestack and, if self.bfactor is
        non-zero, filter it with the corresponding B-factor filter.
        Returns (segment_size, class_img) where segment_size is the image x-size in pixels.
        """
        class_img = EMData()
        class_img.read_image(self.infilestack, class_id)
        segment_size = class_img.get_xsize()
        if self.bfactor != 0:
            filter_coefficients = SegmentAlign2d().prepare_filter_function(False, 1 / 300.0, True, 0.02,
                self.pixelsize, segment_size, 0.08, False, None, self.bfactor)
            class_img = filt_table(class_img, filter_coefficients)
        return segment_size, class_img

    def prepare_class_img_and_power_spectrum(self, class_id, class_plot):
        """
        Load class average class_id and derive its power spectrum.

        For class_format 'real' the image is masked to the helix width, padded
        four-fold and its normalized periodogram is computed.
        For class_format 'power' the stored image already is the power spectrum;
        no padded real-space image exists and None is returned in its place.

        Returns (class_img_pd, cls_img_np, cls_ps_np).
        Raises ValueError for any other class_format.
        class_plot is accepted for interface compatibility and is not used.
        """
        segment_size, class_img = self.get_class_img(class_id)
        cls_img_np = np.copy(EMNumPy.em2numpy(class_img))
        if self.class_format == 'real':
            helixmask = SegmentExam().make_smooth_rectangular_mask(self.helixwidthpix, segment_size * 0.8,
                segment_size, 0.15)
            padsize = 4
            class_img_pd = Util.pad(class_img * helixmask, padsize * segment_size, padsize * segment_size, 1, 0, 0, 0,
                                    '0')
            cls_ps = periodogram(class_img_pd)
            cls_ps.process_inplace('normalize')
            cls_ps_np = np.copy(EMNumPy.em2numpy(cls_ps))
        elif self.class_format == 'power':
            # previously left unassigned, which raised UnboundLocalError at the return below
            class_img_pd = None
            cls_ps_np = np.copy(cls_img_np)
        else:
            raise ValueError('Unknown class format {0!r}; expected \'real\' or \'power\''.format(self.class_format))
        return class_img_pd, cls_img_np, cls_ps_np
class SegClassLayer(SegClassLayerGuiSupport):
    """Plotting, Bessel look-up table and batch/GUI drivers of segclasslayer."""

    def add_amplitude_and_phase_difference_to_plot(self, layer_profile, arrresolution, ax1, ampmax=None):
        """
        * Function to plot left/right quadrant amplitudes and their phase difference along a layer line
        #. Input: layer profile (left_amp, right_amp, left_phase, right_phase), resolution array \
        (1/Angstrom), amplitude axis, optional upper amplitude limit
        #. Output: amplitude axis and its twin phase-difference axis
        """
        ax1.plot(arrresolution, layer_profile.left_amp, label='Left quadrant')
        ax1.plot(arrresolution, layer_profile.right_amp, label='Right quadrant amplitudes')
        if ampmax is not None:
            ax1.set_ylim(0, ampmax)
        ax1.set_xlim(arrresolution[0], arrresolution[-1])
        ax1.set_yticks([])

        ax2 = ax1.twinx()
        ax2.set_xlim(arrresolution[0], arrresolution[-1])
        ax2.set_ylim(-190, 230)
        phase_diff = self.compute_phase_difference(layer_profile.left_phase, layer_profile.right_phase)
        ax2.plot(arrresolution, phase_diff, 'r.', label='Phase difference left/right')
        # shaded band between -90 and +90 degrees, labelled 'Even Bessel order'
        ax2.fill([0, arrresolution[-1], arrresolution[-1], 0], [-90, -90, 90, 90], 'grey', alpha=0.05,
                 label='Even Bessel order')
        ax2.set_yticks([-180, -90, 0, 90, 180])
        for tick_label in ax1.get_xticklabels() + ax1.get_yticklabels() + ax2.get_yticklabels():
            tick_label.set_fontsize(6)
        return ax1, ax2

    def determine_first_maxima_computationally(self, layer_profile, ax1):
        """
        * Function to annotate amplitude maxima of the left and right quadrant profiles
        #. Input: layer profile, amplitude axis (requires self.arrresolution)
        #. Output: text annotations on ax1; the annotated position is the last entry \
        returned by SegmentExam().find_local_extrema on a 100x finer cubic-spline interpolation
        """
        finestep = (self.arrresolution[1] - self.arrresolution[0]) / 100
        resarrpol = np.arange(self.arrresolution[0], self.arrresolution[-1], finestep)

        t = interpolate.splrep(self.arrresolution, layer_profile.left_amp, k=3, s=0)
        leftamppol = interpolate.splev(resarrpol, t)
        ind, val = SegmentExam().find_local_extrema(leftamppol)
        leftmax = resarrpol[ind[-1]]
        leftmaxval = val[-1]
        ax1.text(leftmax, 0.2 * leftmaxval, 'Max {0:.4}'.format(leftmax), fontsize=6)

        t = interpolate.splrep(self.arrresolution, layer_profile.right_amp, k=3, s=0)
        rightamppol = interpolate.splev(resarrpol, t)
        ind, val = SegmentExam().find_local_extrema(rightamppol)
        rightmax = resarrpol[ind[-1]]
        # NOTE(review): label height is scaled by the LEFT maximum (0.4 vs 0.2), presumably to avoid overlap
        ax1.text(rightmax, 0.4 * leftmaxval, 'Max {0:.4}'.format(rightmax), fontsize=6)

    def add_labels_to_plot(self, ax1, ax2):
        """Add axis labels and legends to the amplitude (ax1) and phase-difference (ax2) axes."""
        ax1.set_xlabel('Resolution in 1/Angstrom', fontsize=8)
        ax1.set_ylabel('LL amplitudes', fontsize=8)
        ax1.legend(loc='upper left', ncol=1, prop=FontProperties(size='x-small'))
        ax2.set_ylabel('LL phase difference (degrees)', fontsize=8)
        ax2.legend(loc='upper right', ncol=1, prop=FontProperties(size='x-small'))
        return ax1, ax2

    def plot_individual_layer_lines(self, diag_plot, layerlines, layer_profiles):
        """
        * Function to stack one amplitude/phase-difference plot per layer line and a Bessel table below
        #. Input: diagnostic plot, list of layer line positions (1/Angstrom), list of layer profiles
        #. Output: diagnostic plot (self.order is left at the index of the last layer line)
        """
        plots = OrderedDict()
        # one grid row per layer line plus one row for the Bessel look-up table
        row_count = len(layerlines) + 1
        for self.order, line in enumerate(layerlines):
            llplot = 'plot{0}'.format(self.order)
            plots[llplot] = diag_plot.plt.subplot2grid((row_count, 1), (self.order, 0), colspan=1, rowspan=1)
            plots[llplot].set_title('{order}. layer line ({res} 1/Angstrom)'.format(order=self.order, res=line),
                                    fontsize=9)
            llplottwin = 'plot{0}twin'.format(self.order)
            plots[llplot], plots[llplottwin] = self.add_amplitude_and_phase_difference_to_plot(
                layer_profiles[self.order], self.arrresolution, plots[llplot])
            self.determine_first_maxima_computationally(layer_profiles[self.order], plots[llplot])
            plots[llplottwin].minorticks_on()
            plots[llplottwin].grid(True)
            # format the whole message (previously .format applied only to the trailing 'difference')
            self.log.ilog(('Layer line {order} was visualized left and right quadrant amplitudes and phase '
                           'difference').format(order=self.order))
            # last plot: add labels and the Bessel table in the extra bottom row
            if self.order == (len(layerlines) - 1):
                self.add_labels_to_plot(plots[llplot], plots[llplottwin])
                tbl = diag_plot.plt.subplot2grid((row_count, 1), (self.order + 1, 0), colspan=1, rowspan=1,
                                                 frameon=False)
                self.add_besseltable(diag_plot, tbl)
                self.log.ilog('Final Bessel order lookup table added')
        return diag_plot

    def determine_max(self):
        """
        * Function to determine maximum of layer line
        #. Uses self.leftamp[self.order], self.rightamp[self.order] and self.arrresolution
        #. Output: (resolution, amplitude) of the larger of the left/right maxima; also stored \
        in self.maxx, self.maxy
        """
        # lexsort sorts by the last key (amplitude) first, so the final index is the maximum amplitude
        ind1 = np.lexsort((self.arrresolution, self.leftamp[self.order]))
        ind2 = np.lexsort((self.arrresolution, self.rightamp[self.order]))
        if self.leftamp[self.order][ind1[-1]] > self.rightamp[self.order][ind2[-1]]:
            self.maxx = self.arrresolution[ind1[-1]]
            self.maxy = self.leftamp[self.order][ind1[-1]]
        else:
            self.maxx = self.arrresolution[ind2[-1]]
            self.maxy = self.rightamp[self.order][ind2[-1]]
        return self.maxx, self.maxy

    def get_bessel_table_quant(self, helixwidth):
        """
        Return Bessel orders 0-19, the positions of their primary maxima (2*Pi*r*R)
        and the corresponding positions for a helix of the given width (radius
        helixwidth / 2), as plain lists.

        >>> from spring.segment2d.segclasslayer import SegClassLayer
        >>> s = SegClassLayer()
        >>> s.get_bessel_table_quant(100) #doctest: +NORMALIZE_WHITESPACE
        ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        [0.0, 1.8, 3.1, 4.2, 5.3, 6.4, 7.5, 8.6, 9.6, 10.7, 11.8, 12.8, 13.9,
        14.9, 16.0, 17.0, 18.1, 19.1, 20.1, 21.2], [0.0, 0.006, 0.01, 0.013,
        0.017, 0.02, 0.024, 0.027, 0.031, 0.034, 0.038, 0.041, 0.044, 0.047,
        0.051, 0.054, 0.058, 0.061, 0.064, 0.067])
        """
        primarymax = SegClassReconstruct().get_list_of_bessel_order_maxima(20)
        primarymax = np.around(primarymax, decimals=1)
        primarymaxwidth = primarymax / (2 * np.pi * helixwidth / 2.0)
        primarymaxwidth = np.around(primarymaxwidth, decimals=3)
        besselorder = np.arange(len(primarymax), dtype=int)
        return besselorder.tolist(), primarymax.tolist(), primarymaxwidth.tolist()

    def add_besseltable(self, diag_plot, tbl, helixwidth=None):
        """
        * Function to add Bessel order look-up table to printout
        #. Input: diagnostic plot, table axis, helix width in Angstrom (defaults to self.helixwidth)
        #. Output: table with Bessel orders, their corresponding primary maxima and the expected \
        maxima for the given helix width
        #. Usage: add_besseltable(diag_plot, tbl, helixwidth)
        """
        if helixwidth is None:
            helixwidth = self.helixwidth
        besselorder, primarymax, primarymaxwidth = self.get_bessel_table_quant(helixwidth)
        tbl.set_title('Bessel order look-up table (from Stewart 1988)', fontsize=10)
        tbl.xaxis.set_visible(False)
        tbl.yaxis.set_visible(False)
        rowLabels = ['Bessel order',
                     'Primary maximum at 2*Pi*r*R', 'at helix radius \nof R={width} Angstrom'.format(width=helixwidth/2)]
        # np.vstack: np.row_stack is a deprecated alias of it
        cellText = np.vstack((besselorder, primarymax, primarymaxwidth))
        the_table = diag_plot.plt.table(cellText=cellText, rowLabels=rowLabels, loc='center')
        the_table.auto_set_font_size(False)
        the_table.set_fontsize(4)

    def visualize_layerlines(self, figno, layerlines, layer_profiles):
        """
        * Function to visualize layer lines
        #. Input: figure number (not used), list of layer line positions (1/Angstrom), list of \
        layer profiles (left/right amplitudes and phases); layer_profiles must not be empty
        #. Output: figure with stacked amplitude profile and phase difference plots
        #. Usage: figure = visualize_layerlines(figno, layerlines, layer_profiles)
        """
        self.log.fcttolog()
        layerline_plot = DiagnosticPlot()
        layerline_plot.plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=None, hspace=0.65)
        self.arrresolution = SegmentExam().make_oneoverres(layer_profiles[0].left_phase, self.pixelsize)
        layerline_plot = self.plot_individual_layer_lines(layerline_plot, layerlines, layer_profiles)
        return layerline_plot.fig

    def print_layerlines_as_requested(self):
        """
        Batch mode: extract and plot layer lines of every class in self.classno_range.
        A '.pdf' output collects all pages in one file; otherwise one image file is
        written per class (renamed per class when more than one class is requested).
        """
        write_pdf = os.path.splitext(self.outfile)[-1].endswith('pdf')
        if write_pdf:
            self.pdf = PdfPages(self.outfile)
        try:
            self.classno_range = SegClassExam().check_maximum_class_number(self.infilestack, self.classno_range)
            classno_start, classno_end = self.classno_range
            self.log.plog(10)
            classes_iter = list(range(classno_start, classno_end + 1))
            class_count = len(classes_iter)
            for position, each_class_index in enumerate(classes_iter):
                if class_count > 1:
                    plot_file, self.feature_set = SegClassExam().rename_plot_title_for_multiple_classes(self.infile,
                        self.outfile, classno_start, classno_end, each_class_index, self.feature_set)
                else:
                    plot_file = self.outfile
                layerlines, layer_profiles = self.extract_layerlines(self.infilestack, each_class_index)
                self.fig = self.visualize_layerlines(each_class_index, layerlines, layer_profiles)
                if write_pdf:
                    self.pdf.savefig(self.fig)
                else:
                    # equals self.outfile when only a single class is processed
                    self.fig.savefig(plot_file, dpi=600)
                # progress relative to the requested range (not to the absolute class number)
                self.log.plog(100 * (position + 1) / class_count)
        finally:
            if write_pdf:
                self.pdf.close()

    def launch_segclasslayer_gui(self, feature_set):
        """Start the Qt application with the interactive SegClassLayerGui."""
        self.log.fcttolog()
        app = QApplication(sys.argv)
        gridexplor = SegClassLayerGui(feature_set)
        gridexplor.show()
        app.exec_()

    def extract_and_visualize_layer_lines(self):
        """Run batch output or the GUI depending on self.batch_mode, then close the log."""
        if self.batch_mode:
            self.print_layerlines_as_requested()
        else:
            self.launch_segclasslayer_gui(self.feature_set)
        self.log.endlog(self.feature_set)
def main():
    """Entry point: build default parameters, pass them through OptHandler and run segclasslayer."""
    program = SegClassLayer(OptHandler(SegClassLayerPar()))
    program.extract_and_visualize_layer_lines()


if __name__ == '__main__':
    main()