# Author: Carsten Sachse 1-Jan-2012
# Copyright: EMBL (2010 - 2018), Forschungszentrum Juelich (2019 - 2021)
# License: see license.txt for details
from EMAN2 import EMData
from spring.csinfrastr.csproductivity import OpenMpi
from spring.segment2d.segmentexam import SegmentExam, SegmentExamPar
import os
class SegmentExamMpi(SegmentExam):
    """
    MPI variant of SegmentExam.

    Segment ids are split on rank 0 and scattered to all ranks. Each rank works on its own share inside
    its own temporary directory, and the per-rank results are gathered or reduced on rank 0.
    """

    def generate_local_name_for_reduction(self, emdata_file, rank):
        """
        Insert the rank between the base name and the extension of emdata_file.

        >>> from spring.segment2d.segmentexam_mpi import SegmentExamMpi
        >>> SegmentExamMpi().generate_local_name_for_reduction('ps_234567891.hdf', 2)
        'ps_2345678912.hdf'
        """
        pre, ext = os.path.splitext(emdata_file)
        local_emdata_file = '{pre}{rank}{ext}'.format(pre=pre, rank=rank, ext=ext)

        return local_emdata_file

    def reduce_emdata_from_memory_on_main_node(self, widthavg):
        """
        Write the local image to a rank-specific file and reduce all of them on rank 0.

        Returns the result of OpenMpi().reduce_emdata_on_main_node on rank 0 and None on all other ranks.
        """
        widthavg_file = 'width_avg.hdf'
        local_width_avg = self.generate_local_name_for_reduction(widthavg_file, self.rank)
        widthavg.write_image(local_width_avg)

        # On rank 0 the gather yields the list of all per-rank file names.
        local_width_avg = self.comm.gather(local_width_avg, root=0)
        if self.rank == 0:
            widthavg = OpenMpi().reduce_emdata_on_main_node(widthavg, local_width_avg)
        else:
            widthavg = None

        return widthavg

    def prepare_segmentexam_mpi(self):
        """
        Set up MPI and logging, distribute the segment ids and apply the bin factor to the input stack.

        Returns the segment ids assigned to the calling rank.
        """
        mpi_setup = OpenMpi().setup_mpi_and_simultaneous_logging(self.log, self.feature_set.logfile,
                                                                 self.temppath)
        self.comm, self.rank, self.size, self.log, self.tempdir = mpi_setup

        # Only rank 0 filters the ids from the database; the split list is then scattered.
        if self.rank == 0:
            segment_ids = self.copy_database_and_filter_segment_ids()
            segment_ids = OpenMpi().split_sequence_evenly(segment_ids, self.size)
        else:
            segment_ids = None
        segment_ids = self.comm.scatter(segment_ids, root=0)

        # Use os.path.join like the other temp files of this class, so that the binned stack ends up
        # inside the temporary directory whether or not self.tempdir carries a trailing separator.
        binned_stack = os.path.join(self.tempdir, os.path.basename(self.infilestack))

        self.infilestack, self.segsizepix, self.helixwidthpix, self.pixelsize = self.apply_binfactor(
            self.binfactor, self.infilestack, self.segsizepix, self.helixwidthpix, self.pixelsize, segment_ids,
            binned_stack)

        # All ranks continue with the values held by rank 0.
        self.segsizepix = self.comm.bcast(self.segsizepix, root=0)
        self.helixwidthpix = self.comm.bcast(self.helixwidthpix, root=0)
        self.pixelsize = self.comm.bcast(self.pixelsize, root=0)
        self.comm.barrier()

        if self.rank == 0:
            self.log.plog(10)

        return segment_ids

    def add_powers_locally_and_reduce_on_main_node(self, segment_ids):
        """
        Add up the power spectra of the local segments and reduce the per-rank averages on rank 0.

        Returns (avg_periodogram, power_infilestack, masked_infilestack, avg_periodogram_enhanced,
        avg_collapsed_power_line, avg_collapsed_line_enhanced). On rank 0 the periodograms are the reduced
        ones and have been written to disk; on all other ranks they are the local ones and both collapsed
        lines are None.
        """
        masked_infilestack = os.path.join(self.tempdir, 'infilestack-masked.hdf')
        power_infilestack = os.path.join(self.tempdir, 'infilestack-power.hdf')

        avg_periodogram = self.add_power_spectra_from_verticalized_stack(self.infilestack, segment_ids,
            self.helixwidthpix, masked_infilestack, power_infilestack)

        local_power_img = self.generate_local_name_for_reduction(self.power_img, self.rank)
        if self.enhanced_power_option:
            local_power_enhanced_img = self.generate_local_name_for_reduction(self.power_enhanced_img, self.rank)
        else:
            # No enhanced output name was requested: derive one from the power image name.
            power_pre, power_ext = os.path.splitext(self.power_img)
            local_power_enhanced_img = self.generate_local_name_for_reduction(power_pre + '_enh' + power_ext,
                                                                              self.rank)

        avg_periodogram_enhanced = self.write_avg_periodograms(avg_periodogram, local_power_img,
            local_power_enhanced_img)

        # Every rank must have written its local files before rank 0 starts reducing them.
        self.comm.barrier()
        local_power_img = self.comm.gather(local_power_img, root=0)
        local_power_enhanced_img = self.comm.gather(local_power_enhanced_img, root=0)

        if self.rank == 0:
            avg_periodogram = OpenMpi().reduce_emdata_on_main_node(avg_periodogram, local_power_img)
            avg_periodogram.write_image(self.power_img)

            avg_periodogram_enhanced = OpenMpi().reduce_emdata_on_main_node(avg_periodogram_enhanced,
                local_power_enhanced_img)
            if self.enhanced_power_option:
                avg_periodogram_enhanced.write_image(self.power_enhanced_img)

            avg_collapsed_power_line, avg_collapsed_line_enhanced = self.collapse_periodograms(avg_periodogram,
                avg_periodogram_enhanced)
            self.log.plog(60)
        else:
            avg_collapsed_power_line = None
            avg_collapsed_line_enhanced = None

        self.comm.barrier()

        return (avg_periodogram, power_infilestack, masked_infilestack, avg_periodogram_enhanced,
                avg_collapsed_power_line, avg_collapsed_line_enhanced)

    def correlate_layer_line_region_mpi(self, segment_ids, avg_periodogram, power_infilestack):
        """
        Optionally correlate the average power spectrum with the local segments, then remove the local
        power stack.

        If self.layer_ccc_option is set, the passed avg_periodogram is replaced by the image read from
        self.power_img, the per-rank correlations are gathered on rank 0 and entered into the database.
        """
        if self.layer_ccc_option:
            # Re-read the average from disk instead of using the avg_periodogram argument.
            avg_periodogram = EMData()
            avg_periodogram.read_image(self.power_img)
            correlations = self.correlate_layer_lines_of_average_power_with_individual_segments(avg_periodogram,
                power_infilestack, segment_ids)

            correlations = self.comm.gather(correlations, root=0)
            segment_ids = self.comm.gather(segment_ids, root=0)
            if self.rank == 0:
                correlations = OpenMpi().merge_sequence_of_sequences(correlations)
                segment_ids = OpenMpi().merge_sequence_of_sequences(segment_ids)
                self.enter_correlation_values_in_database(correlations, segment_ids)
            # NOTE(review): nesting of this barrier was reconstructed; confirm against the repository.
            self.comm.barrier()

        # NOTE(review): placed at method level so the power stack is removed on every rank whether or
        # not the correlation was requested; add_up_power_spectra later calls os.rmdir(self.tempdir).
        os.remove(power_infilestack)

    def determine_width_from_collapsed_profile_mpi(self, segment_ids, masked_infilestack):
        """
        Determine the widths of the local segments and gather the per-rank stacks into common stacks.

        The gathered widths are stored in self.widths. Returns the names of the common masked stack and
        the common rows stack.
        """
        temp_rowsadd, widths = self.determine_width(masked_infilestack, self.segsizepix, segment_ids)

        common_masked_stack = 'common_masked.hdf'
        OpenMpi().gather_stacks_from_cpus_to_common_stack(self.comm, masked_infilestack, common_masked_stack)

        common_rows = 'common_rows.hdf'
        OpenMpi().gather_stacks_from_cpus_to_common_stack(self.comm, temp_rowsadd, common_rows)

        self.widths = self.comm.gather(widths, root=0)

        return common_masked_stack, common_rows

    def visualize_avg_var_widths_and_power_spectra_mpi(self, avg_periodogram, avg_periodogram_enhanced,
    avg_collapsed_power_line, avg_collapsed_line_enhanced, common_masked_stack, common_rows):
        """
        On rank 0 compute the average and variance of widths and images and produce the visualization.

        The common masked stack is removed on rank 0 once it has been evaluated.
        """
        if self.rank == 0:
            widthavg, widthvar, twodavg, twodvar = self.compute_avg_and_var_of_width_and_image(
                common_masked_stack, common_rows)
            os.remove(common_masked_stack)

            # self.widths holds one sequence per rank after the gather; flatten it for plotting.
            self.widths = OpenMpi().merge_sequence_of_sequences(self.widths)
            self.visualize_power_avg_and_width_analysis(widthavg, widthvar, self.widths, twodavg, twodvar,
                avg_periodogram, avg_periodogram_enhanced, avg_collapsed_power_line, avg_collapsed_line_enhanced)
            self.log.plog(80)

        # NOTE(review): placed at method level, assuming cleanup() discards the rank-local input stack
        # that has to be gone before os.rmdir(self.tempdir) - confirm against the repository.
        self.cleanup(self.infilestack)

    def add_up_power_spectra(self):
        """
        Run the complete MPI workflow: preparation, power spectra, optional layer-line correlation,
        width determination and visualization, followed by removal of the temporary directory.
        """
        segment_ids = self.prepare_segmentexam_mpi()

        (avg_periodogram, power_infilestack, masked_infilestack, avg_periodogram_enhanced,
         avg_collapsed_power_line, avg_collapsed_line_enhanced) = \
            self.add_powers_locally_and_reduce_on_main_node(segment_ids)

        self.correlate_layer_line_region_mpi(segment_ids, avg_periodogram, power_infilestack)

        common_masked_stack, common_rows = self.determine_width_from_collapsed_profile_mpi(segment_ids,
            masked_infilestack)

        self.visualize_avg_var_widths_and_power_spectra_mpi(avg_periodogram, avg_periodogram_enhanced,
            avg_collapsed_power_line, avg_collapsed_line_enhanced, common_masked_stack, common_rows)

        # os.rmdir only succeeds on an empty directory, so all temp files must be gone by now.
        os.rmdir(self.tempdir)
        self.comm.barrier()

        if self.rank == 0:
            self.log.endlog(self.feature_set)
def main():
    """
    Command-line entry point: build the parameter set, hand it to OpenMpi().start_main_mpi and run the
    MPI power spectrum workflow with the parameter set returned from there.
    """
    parset = SegmentExamPar()
    reduced_parset = OpenMpi().start_main_mpi(parset)

    ####### Program
    stack = SegmentExamMpi(reduced_parset)
    stack.add_up_power_spectra()
# Run the workflow only when the module is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()