# Author: Carsten Sachse 8-Jul-2011
# Copyright: EMBL (2010 - 2018), Forschungszentrum Juelich (2019 - 2021)
# License: see license.txt for details
from spring.csinfrastr.csproductivity import OpenMpi
from spring.segment2d.segment import Segment, SegmentPar
import os
class SegmentMpiPreparation(Segment):
    """Segment subclass that distributes the helix workload across MPI ranks."""

    def prepare_segmentation_mpi(self):
        """Initialise MPI/logging and scatter the helices from rank 0 to all ranks.

        Rank 0 validates the input, pairs micrographs with coordinate files and
        singles out the helices. It then splits them evenly into ``self.size``
        chunks. Every rank receives one chunk through ``comm.scatter``.

        Side effects: sets ``self.comm``, ``self.rank``, ``self.size`` and
        ``self.log`` from the MPI setup.

        Returns:
            tuple: (helices local to this rank as helixinfo namedtuples,
            tempdir returned by the MPI setup, assigned_stack_ids,
            assigned_helix_ids). The last two are only computed on rank 0
            and are ``None`` on all other ranks.
        """
        mpi_setup = OpenMpi().setup_mpi_and_simultaneous_logging(
            self.log, self.feature_set.logfile, self.temppath)
        self.comm, self.rank, self.size, self.log, tempdir = mpi_setup

        # Only the root rank builds the full helix list; the others just
        # contribute ``None`` to the scatter call below.
        helix_chunks = None
        assigned_stack_ids = None
        assigned_helix_ids = None
        if self.rank == 0:
            assigned_mics = self.validate_input()
            mic_coordinate_pairs = self.assign_reorganize(self.micrograph_files, self.coordinate_files)
            all_helices, assigned_stack_ids, assigned_helix_ids = self.single_out(
                mic_coordinate_pairs, self.stepsize, self.pixelsize, assigned_mics)
            # Namedtuples are flattened to plain lists before being split
            # into per-rank chunks and sent over MPI.
            helix_rows = OpenMpi().convert_list_of_namedtuples_to_list_of_lists(all_helices)
            helix_chunks = OpenMpi().split_sequence_evenly(helix_rows, self.size)

        # Collective call: every rank must reach this line.
        local_rows = self.comm.scatter(helix_chunks, root=0)
        helix_type = self.make_helixinfo_named_tuple()
        local_helices = OpenMpi().convert_list_of_lists_to_list_of_provided_namedtuple(local_rows, helix_type)
        return local_helices, tempdir, assigned_stack_ids, assigned_helix_ids
class SegmentMpi(SegmentMpiPreparation):
    """MPI-parallel driver: prepare, extract, optionally bin and collect segments."""

    def gather_distributed_helices_to_root(self, comm, helices):
        """Collect every rank's helices on rank 0.

        Args:
            comm: MPI communicator used for the collective ``gather``.
            helices: helixinfo namedtuples held by the calling rank.

        Returns:
            On rank 0, the merged list of helices from all ranks as helixinfo
            namedtuples. On every other rank, ``None``.
        """
        # Convert to plain lists before sending them over MPI.
        local_rows = OpenMpi().convert_list_of_namedtuples_to_list_of_lists(helices)
        gathered_rows = comm.gather(local_rows, root=0)
        if comm.rank != 0:
            return None

        merged_rows = OpenMpi().merge_sequence_of_sequences(gathered_rows)
        helix_type = self.make_helixinfo_named_tuple()
        return OpenMpi().convert_list_of_lists_to_list_of_provided_namedtuple(merged_rows, helix_type)

    def finish_segmentation_mpi(self, tempdir, imgstack, local_windowed_stack, assigned_stack_ids, assigned_helix_ids):
        """Merge per-rank results, update the database on rank 0 and clean up.

        Args:
            tempdir: temporary directory returned by the MPI setup for this rank.
            imgstack: name of the combined output stack.
            local_windowed_stack: this rank's locally extracted segment stack.
            assigned_stack_ids: stack ids computed on rank 0 (``None`` elsewhere).
            assigned_helix_ids: helix ids computed on rank 0 (``None`` elsewhere).
        """
        OpenMpi().gather_stacks_from_cpus_to_common_stack(self.comm, local_windowed_stack, imgstack)
        self.helices = self.gather_distributed_helices_to_root(self.comm, self.helices)

        self.comm.barrier()
        if self.rank == 0:
            self.enter_helix_parameters_in_database(self.helices, assigned_stack_ids, assigned_helix_ids)

        # Executed on every rank: each rank obtained its own ``tempdir`` from the
        # (unguarded) MPI setup. os.rmdir only removes empty directories, so this
        # presumes the local temporary files were already removed — verify.
        os.rmdir(tempdir)
        self.comm.barrier()

        if self.rank == 0:
            self.log.endlog(self.feature_set)

    def segment(self):
        """Run the complete MPI-parallel segmentation pipeline on this rank."""
        prepared = self.prepare_segmentation_mpi()
        self.helices, self.tempdir, stack_ids, helix_ids = prepared

        local_windowed_stack, imgstack = self.extract_segments_mpi(self.tempdir)
        self.perform_binning_if_demanded(imgstack, local_windowed_stack)
        self.finish_segmentation_mpi(self.tempdir, imgstack, local_windowed_stack, stack_ids, helix_ids)
def main():
    """Script entry point: start MPI with the segment parameters and run segmentation."""
    segment_parameters = SegmentPar()
    mpi_parameters = OpenMpi().start_main_mpi(segment_parameters)

    segmenter = SegmentMpi(mpi_parameters)
    segmenter.segment()
if __name__ == "__main__":
    # Only launch the MPI run when executed as a script, not when imported.
    main()