# Source code for segclassmodel_mpi

# Author: Carsten Sachse 
# Copyright: EMBL (2010 - 2018), Forschungszentrum Juelich (2019 - 2021)
# License: see license.txt for details

from spring.csinfrastr.csproductivity import OpenMpi
from spring.segment3d.segclassmodel import SegClassModelPar, SegClassModel
import json
import os


class SegClassModelMpi(SegClassModel):
    """MPI-enabled variant of SegClassModel.

    The methods rely on ``self.comm`` and ``self.rank``, which are assigned in
    match_reprojections_to_classes() before any projection work starts.
    """

    def prepare_merged_stack_of_projections_mpi(self):
        """Project every reference model and merge the results on rank 0.

        For each entry of ``self.references`` rank 0 prepares the volume and
        broadcasts the resulting reference and pixel information to all ranks.
        Every rank then calls the projection routine; only rank 0 accumulates
        the projection summary.

        Side effects:
            Rank 0 writes the collected projection info as JSON to
            'prj_info.dat' in the current working directory.

        Returns:
            tuple: ``(merged_prj_stack, projection_info)`` on rank 0 and
            ``(None, None)`` on every other rank.
        """
        sr3d = self.prepare_sr3d_object_with_filter_settings()
        projection_size, projection_info = self.prepare_prj_through_series_of_models()

        for each_model_id, each_reference in enumerate(self.references):
            if self.rank == 0:
                reference_info, pixelinfo = self.prepare_volume_for_projection(
                    projection_size, sr3d, each_model_id, each_reference)

                # Named tuples are flattened to plain lists before broadcasting
                # (presumably to keep them picklable - TODO confirm).
                reference_info = list(reference_info)
                pixelinfo = list(pixelinfo)
            else:
                reference_info = None
                pixelinfo = None

            reference_info = self.comm.bcast(reference_info, root=0)
            pixelinfo = self.comm.bcast(pixelinfo, root=0)

            # Rebuild the named tuples on every rank from the broadcast lists.
            pixelinfo_nt = sr3d.make_pixel_info_named_tuple()
            reference_info_nt = sr3d.make_reference_info_named_tuple()
            pixelinfo = pixelinfo_nt._make(pixelinfo)
            reference_info = reference_info_nt._make(reference_info)

            # The fine projection outputs are returned but not used here.
            projection_stack, projection_parameters, _fine_projection_stack, _fine_projection_parameters = \
                sr3d.project_through_reference_volume_in_helical_perspectives(
                    'medium', reference_info.model_id, reference_info.ref_file, pixelinfo,
                    reference_info.helical_symmetry, reference_info.rotational_symmetry)

            if self.rank == 0:
                projection_info = self.summarize_prj_info(
                    projection_info, each_model_id, projection_stack, projection_parameters)

        # NOTE(review): barrier assumed to sit after the loop rather than once
        # per model - confirm against the upstream source.
        self.comm.barrier()
        if self.rank == 0:
            merged_prj_stack, projection_info = self.merge_prj_stacks_and_collect_prj_info(projection_info)

            # Context manager guarantees the file is closed even if json.dump raises.
            with open('prj_info.dat', 'w') as dfile:
                json.dump(projection_info, dfile, indent=4, sort_keys=False)
        else:
            merged_prj_stack = None
            projection_info = None

        # Keep the other ranks from running ahead while rank 0 merges and writes.
        self.comm.barrier()

        return merged_prj_stack, projection_info

    def match_reprojections_to_classes(self):
        """Set up MPI, build the merged projection stack and clean up.

        Assigns ``self.comm``, ``self.rank``, ``self.size``, ``self.log`` and
        ``self.tempdir`` from OpenMpi().setup_mpi_and_simultaneous_logging(),
        runs prepare_merged_stack_of_projections_mpi(), removes
        ``self.tempdir`` and ends the log on rank 0.
        """
        self.comm, self.rank, self.size, self.log, self.tempdir = \
            OpenMpi().setup_mpi_and_simultaneous_logging(self.log, self.feature_set.logfile, self.temppath)

        # The returned stack and info are not used further in this method.
        self.prepare_merged_stack_of_projections_mpi()

        # os.rmdir only removes an empty directory; runs on every rank.
        os.rmdir(self.tempdir)
        if self.rank == 0:
            self.log.endlog(self.feature_set)
def main():
    """Entry point: build the parameter set, start MPI and run the program.

    Creates a SegClassModelPar, passes it through OpenMpi().start_main_mpi()
    and runs SegClassModelMpi.match_reprojections_to_classes() on the result.
    """
    parameters = SegClassModelPar()
    mpi_launcher = OpenMpi()
    reduced_parameters = mpi_launcher.start_main_mpi(parameters)

    ####### Program
    SegClassModelMpi(reduced_parameters).match_reprojections_to_classes()
# Run the program only when this module is executed as a script.
if __name__ == '__main__':
    main()