Python SimpleITK.sitkFloat32 Examples

The following are 22 code examples of SimpleITK.sitkFloat32 (a pixel-type constant, not a function). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module SimpleITK, or try the search function.
Example #1
Source File: utils.py    From Brats2019 with MIT License 9 votes vote down vote up
def N4BiasFieldCorrection(src_path, dst_path):
    '''
    Run N4 bias field correction on a single image file.

    A binary mask is derived from the input with sitk.OtsuThreshold
    (insideValue=0, outsideValue=1, 200 histogram bins) and passed to the
    N4 filter together with a float32 copy of the input. Only the corrected
    image is written to ``dst_path``; the intermediate mask is not saved.

    :param src_path: path of the source image file
    :param dst_path: path of the corrected output image file
    :return: None
    '''
    print("N4 bias correction runs.")
    inputImage = sitk.ReadImage(src_path)

    # Binary mask restricting where N4 estimates the bias field.
    maskImage = sitk.OtsuThreshold(inputImage, 0, 1, 200)

    # N4 works on real-valued images.
    inputImage = sitk.Cast(inputImage, sitk.sitkFloat32)

    corrector = sitk.N4BiasFieldCorrectionImageFilter()

    # corrector.SetMaximumNumberOfIterations(10)

    output = corrector.Execute(inputImage, maskImage)
    sitk.WriteImage(output, dst_path)
    print("Finished N4 Bias Field Correction.....")

    # normalize the data(zero mean and unit variance)
Example #2
Source File: random_motion.py    From torchio with MIT License 8 votes vote down vote up
def resample_images(
        image: sitk.Image,
        transforms: List[sitk.Euler3DTransform],
        interpolation: Interpolation,
        ) -> List[sitk.Image]:
    """Resample ``image`` with every transform except the first.

    The first transform is treated as the identity, so the unmodified input
    image is placed at index 0 of the result. Every other transform produces
    a float32 image on the input's own grid, padded with the input's minimum
    intensity.

    :param image: image that is both the floating and the reference image
    :param transforms: transforms to apply; ``transforms[0]`` is skipped
    :param interpolation: interpolation mode mapped via get_sitk_interpolator
    :return: list with the input image followed by one resampled image per
        remaining transform
    """
    fill_value = np.float64(sitk.GetArrayViewFromImage(image).min())
    resampled_images = [image]  # slot 0 stands for the identity transform
    for tfm in transforms[1:]:
        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(image)
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetDefaultPixelValue(fill_value)
        resampler.SetTransform(tfm)
        resampled_images.append(resampler.Execute(image))
    return resampled_images
Example #3
Source File: data_augmentation.py    From dataset_loaders with GNU General Public License v3.0 8 votes vote down vote up
def gen_warp_field(shape, sigma=0.1, grid_size=3):
    '''Generate a random B-spline displacement field.

    A B-spline transform with a ``grid_size`` mesh along every axis is
    initialised over a float32 reference image of size ``shape``. Its
    control-point coefficients are drawn from N(0, sigma^2), except the
    outermost control points along every axis, which are pinned to zero.
    The transform is then converted to a dense displacement field.

    :param shape: image size as accepted by sitk.Image (2 or more entries)
    :param sigma: standard deviation of the control-point displacements
    :param grid_size: B-spline mesh size along every axis
    :return: SimpleITK displacement field on the reference image grid
    '''
    import SimpleITK as sitk

    shape = tuple(shape)
    ndim = len(shape)

    # Initialize bspline transform
    ref_image = sitk.Image(*(shape + (sitk.sitkFloat32,)))
    tx = sitk.BSplineTransformInitializer(ref_image, [grid_size] * ndim)

    # mesh size = number of control points - spline order (cubic: 3)
    n_ctrl = grid_size + 3

    # ITK takes the parameters as one coefficient grid per displacement
    # component, concatenated (all x coefficients, then all y, ...). The
    # component axis therefore has to come FIRST before flattening;
    # a trailing component axis would interleave x/y per control point.
    p = sigma * np.random.randn(ndim, *([n_ctrl] * ndim))

    # Anchor the edges: zero the first and last control points along each
    # spatial axis, for every displacement component.
    for axis in range(1, ndim + 1):
        index = [slice(None)] * (ndim + 1)
        for edge in (0, -1):
            index[axis] = edge
            p[tuple(index)] = 0

    # Set bspline transform parameters to the above shifts
    tx.SetParameters(p.ravel().tolist())

    # Compute deformation field
    displacement_filter = sitk.TransformToDisplacementFieldFilter()
    displacement_filter.SetReferenceImage(ref_image)
    displacement_field = displacement_filter.Execute(tx)

    return displacement_field
Example #4
Source File: data_writer.py    From NiftyMIC with BSD 3-Clause "New" or "Revised" License 7 votes vote down vote up
def write_image(
    image_sitk,
    path_to_file,
    compress=True,
    verbose=True,
    description=None,
):
    """Write a SimpleITK image via sitkh.write_nifti_image_sitk.

    :param image_sitk: image to write
    :param path_to_file: destination path
    :param compress: if True, cast the image to float32 before writing
    :param verbose: if True, print progress before and after writing
    :param description: forwarded to DataWriter._get_header_update
    """
    message = "Write image to %s" % path_to_file
    if compress:
        # "compress" here means down-casting to float32 prior to writing.
        image_sitk = sitk.Cast(image_sitk, sitk.sitkFloat32)
        message = message + " (float32)"
    if verbose:
        ph.print_info("%s ... " % message, newline=False)

    sitkh.write_nifti_image_sitk(
        image_sitk,
        path_to_file,
        header_update=DataWriter._get_header_update(description=description),
    )

    if verbose:
        print("done")
Example #5
Source File: preprocessing.py    From gdl-fire-4d with GNU General Public License v3.0 6 votes vote down vote up
def normalize_image(image,
                    valid_min=-1024,
                    valid_max=3071):
    """Map intensities from [valid_min, valid_max] linearly onto [0, 1].

    The image is cast to float32. If its value range exceeds the valid
    window, a warning is logged and values are clamped to the window first.

    :param image: SimpleITK image
    :param valid_min: intensity mapped to 0
    :param valid_max: intensity mapped to 1
    :return: float32 SimpleITK image
    """
    image = sitk.Cast(image, sitk.sitkFloat32)

    range_filter = sitk.MinimumMaximumImageFilter()
    range_filter.Execute(image)
    lo = range_filter.GetMinimum()
    hi = range_filter.GetMaximum()
    log.debug(f'Got image with value range [{lo}, {hi}]')

    if lo < valid_min or hi > valid_max:
        log.warning(
            f'Got image with non-default hounsfield scale range: Got range '
            f'[{lo}, {hi}]. Values will be clipped to [{valid_min}, {valid_max}].'
        )
        image = sitk.Clamp(image, lowerBound=valid_min, upperBound=valid_max)

    shifted = sitk.Subtract(image, valid_min)
    return sitk.Divide(shifted, valid_max - valid_min)
Example #6
Source File: heatmap_test.py    From MedicalDataAugmentationTool with GNU General Public License v3.0 6 votes vote down vote up
def get_transformed_image_sitk(self, prediction_np, reference_sitk=None, output_spacing=None, transformation=None):
    """
    Convert a prediction np array into sitk images, optionally resampled back to the reference image.
    :param prediction_np: The np array to convert.
    :param reference_sitk: The reference sitk image (used only when transformation is given).
    :param output_spacing: The output spacing of the np array.
    :param transformation: The transformation. If None, the array is only split along self.channel_axis
                           and each part converted to a sitk image, without resampling.
    :return: The transformed sitk image(s) (a list in the untransformed case).
    """
    if transformation is None:
        channel_arrays = utils.np_image.split_by_axis(prediction_np, self.channel_axis)
        return [utils.sitk_np.np_to_sitk(channel_array) for channel_array in channel_arrays]
    # NOTE(review): channel_axis=None here, whereas the untransformed branch splits on
    # self.channel_axis -- confirm this asymmetry is intended.
    return utils.sitk_image.transform_np_output_to_sitk_input(output_image=prediction_np,
                                                              output_spacing=output_spacing,
                                                              channel_axis=None,
                                                              input_image_sitk=reference_sitk,
                                                              transform=transformation,
                                                              interpolator=self.interpolator,
                                                              output_pixel_type=sitk.sitkFloat32)
Example #7
Source File: segmentation_test.py    From MedicalDataAugmentationTool with GNU General Public License v3.0 6 votes vote down vote up
def get_transformed_image(self, prediction_np, reference_sitk=None, output_spacing=None, transformation=None):
    """
    Returns the predictions as an np array with channels on axis 0.
    If transformation is given, the prediction is resampled back to the reference_sitk image
    (one float32 sitk image per channel) and the results are stacked on axis 0.
    Otherwise the channel axis is only moved to axis 0 (no-op when it already is 0).
    :param prediction_np: The predicted np array.
    :param reference_sitk: The reference sitk image from which origin/spacing/direction is taken from.
    :param output_spacing: The output spacing of the prediction_np array.
    :param transformation: The sitk transformation used to transform the reference_sitk image to the network input.
    :return: An np array with the channel dimension first.
    """
    if transformation is None:
        if self.channel_axis == 0:
            return prediction_np
        channels = utils.np_image.split_by_axis(prediction_np, self.channel_axis)
        return np.stack(channels, axis=0)

    resampled = utils.sitk_image.transform_np_output_to_sitk_input(output_image=prediction_np,
                                                                   output_spacing=output_spacing,
                                                                   channel_axis=self.channel_axis,
                                                                   input_image_sitk=reference_sitk,
                                                                   transform=transformation,
                                                                   interpolator=self.interpolator,
                                                                   output_pixel_type=sitk.sitkFloat32)
    return utils.sitk_np.sitk_list_to_np(resampled, axis=0)
Example #8
Source File: itkRegionGrow.py    From MedImg_Py_Library with MIT License 6 votes vote down vote up
def itkRegionGrow(im_arr, type_str, seedlist, lower=0, upper=1, smoothing=None):
    """Segment an array by region growing from seed points.

    :param im_arr: array-like image passed to sitk.GetImageFromArray
    :param type_str: 'CT' -> sitk.ConnectedThreshold (uses lower/upper),
        'CC' -> sitk.ConfidenceConnected (lower/upper are ignored)
    :param seedlist: seed indices forwarded to the SimpleITK filter
    :param lower: lower intensity bound for 'CT'
    :param upper: upper intensity bound for 'CT'
    :param smoothing: optional itkEdgePreservedSmoothing type string (e.g. 'CF').
        When given, the image is smoothed with it before region growing; the
        default None grows on the unsmoothed image.
    :return: SimpleITK image produced by the selected filter
    :raises ValueError: if type_str is neither 'CT' nor 'CC'
    """
    func_1 = 'CT'
    func_2 = 'CC'

    # Validate before doing any image work so a typo fails fast and clearly.
    if type_str not in (func_1, func_2):
        raise ValueError(
            'Unknown type_str %r (expected %r or %r). '
            'Please check your spelling, and try again.' % (type_str, func_1, func_2))

    # get an image from the array input (optionally edge-preserving smoothed)
    if smoothing is None:
        image = sitk.GetImageFromArray(im_arr)
    else:
        image = itkEdgePreservedSmoothing.itkEdgePreservedSmoothing(im_arr, smoothing)
    image = sitk.Cast(image, sitk.sitkFloat32)

    if type_str == func_1:
        return sitk.ConnectedThreshold(image, seedlist, lower, upper)
    return sitk.ConfidenceConnected(image, seedlist)


# an example of using the function
Example #9
Source File: dataReader.py    From MARL-for-Anatomical-Landmark-Detection with Apache License 2.0 5 votes vote down vote up
def decode(self, filename,label=False):
    """Decode a single NIfTI image.

    Args
      filename: path of the input image
      label: if True, read as sitkInt8 with no intensity processing; otherwise
        read as sitkFloat32, clip to the [p10, p99] percentile window and
        rescale to [0, 255]
    Returns
      (sitk_image, image): the SimpleITK image and an ImageRecord with
      attributes name, data (array with axis order reversed) and dims
    """
    record = ImageRecord()
    record.name = filename
    assert self._is_nifti(record.name), "unknown image format for %r" % record.name

    if label:
        sitk_image = sitk.ReadImage(record.name, sitk.sitkInt8)
    else:
        sitk_image = sitk.ReadImage(record.name, sitk.sitkFloat32)
        voxels = sitk.GetArrayFromImage(sitk_image)
        lowest = voxels.min().astype('float')
        highest = voxels.max().astype('float')
        p10 = np.percentile(voxels, 10)
        p99 = np.percentile(voxels, 99)

        # Lift everything below p10 up to p10 (the upper bound is the max).
        sitk_image = sitk.Threshold(sitk_image, lower=p10, upper=highest, outsideValue=p10)
        # Pull everything above p99 down to p99.
        sitk_image = sitk.Threshold(sitk_image, lower=lowest, upper=p99, outsideValue=p99)
        sitk_image = sitk.RescaleIntensity(sitk_image, outputMinimum=0, outputMaximum=255)

    # SimpleITK arrays are indexed (z, y, x); reverse to (x, y, z).
    record.data = sitk.GetArrayFromImage(sitk_image).transpose(2, 1, 0)
    record.dims = np.shape(record.data)

    return sitk_image, record
Example #10
Source File: getPatchImageAndMask.py    From LiTS---Liver-Tumor-Segmentation-Challenge with MIT License 5 votes vote down vote up
def load_itk(filename):
    """
    Read an image, cast it to float32 and rescale its intensities to [0, 255].
    :param filename: path of the image file (e.g. .mhd)
    :return: float32 SimpleITK image with values in [0, 255]
    """
    # Reads the image using SimpleITK
    as_float = sitk.Cast(sitk.ReadImage(filename), sitk.sitkFloat32)
    return sitk.RescaleIntensity(as_float, outputMinimum=0, outputMaximum=255)
Example #11
Source File: sampleTrain.py    From rl-medical with Apache License 2.0 5 votes vote down vote up
def decode(self, filename,label=False):
    """Decode a single NIfTI image.

    Args
      filename: path of the input image
      label: if True, read as sitkInt8 with no intensity processing; otherwise
        read as sitkFloat32, clip to the [p10, p99] percentile window and
        rescale to [0, 255]
    Returns
      (sitk_image, image): the SimpleITK image and an ImageRecord with
      attributes name, data (array with axis order reversed) and dims
    """
    record = ImageRecord()
    record.name = filename
    assert self._is_nifti(record.name), "unknown image format for %r" % record.name

    if label:
        sitk_image = sitk.ReadImage(record.name, sitk.sitkInt8)
    else:
        sitk_image = sitk.ReadImage(record.name, sitk.sitkFloat32)
        voxels = sitk.GetArrayFromImage(sitk_image)
        lowest = voxels.min().astype('float')
        highest = voxels.max().astype('float')
        p10 = np.percentile(voxels, 10)
        p99 = np.percentile(voxels, 99)

        # Lift everything below p10 up to p10 (the upper bound is the max).
        sitk_image = sitk.Threshold(sitk_image, lower=p10, upper=highest, outsideValue=p10)
        # Pull everything above p99 down to p99.
        sitk_image = sitk.Threshold(sitk_image, lower=lowest, upper=p99, outsideValue=p99)
        sitk_image = sitk.RescaleIntensity(sitk_image, outputMinimum=0, outputMaximum=255)

    # SimpleITK arrays are indexed (z, y, x); reverse to (x, y, z).
    record.data = sitk.GetArrayFromImage(sitk_image).transpose(2, 1, 0)
    record.dims = np.shape(record.data)

    return sitk_image, record
Example #12
Source File: dataReader.py    From rl-medical with Apache License 2.0 5 votes vote down vote up
def decode(self, filename, label=False):
    """Decode a single NIfTI image.

    Args
      filename: path of the input image
      label: if True, read as sitkInt8 with no intensity processing; otherwise
        read as sitkFloat32, clip to the [p10, p99] percentile window and
        rescale to [0, 255]
    Returns
      (sitk_image, image): the SimpleITK image and an ImageRecord with
      attributes name, data (array with axis order reversed) and dims
    """
    record = ImageRecord()
    record.name = filename
    assert self._is_nifti(record.name), "unknown image format for %r" % record.name

    if label:
        sitk_image = sitk.ReadImage(record.name, sitk.sitkInt8)
    else:
        sitk_image = sitk.ReadImage(record.name, sitk.sitkFloat32)
        voxels = sitk.GetArrayFromImage(sitk_image)
        lowest = voxels.min().astype('float')
        highest = voxels.max().astype('float')
        p10 = np.percentile(voxels, 10)
        p99 = np.percentile(voxels, 99)

        # Lift everything below p10 up to p10 (the upper bound is the max).
        sitk_image = sitk.Threshold(sitk_image, lower=p10, upper=highest, outsideValue=p10)
        # Pull everything above p99 down to p99.
        sitk_image = sitk.Threshold(sitk_image, lower=lowest, upper=p99, outsideValue=p99)
        sitk_image = sitk.RescaleIntensity(sitk_image, outputMinimum=0, outputMaximum=255)

    # SimpleITK arrays are indexed (z, y, x); reverse to (x, y, z).
    record.data = sitk.GetArrayFromImage(sitk_image).transpose(2, 1, 0)
    record.dims = np.shape(record.data)

    return sitk_image, record
Example #13
Source File: processing.py    From istn with Apache License 2.0 5 votes vote down vote up
def resample_image(image, out_spacing=(1.0, 1.0, 1.0), out_size=None, is_label=False, pad_value=0):
    """Resamples an image to given element spacing and output size.

    The output keeps the input direction and is centred on the same physical
    point as the input. When out_size is None it is derived so that the
    physical extent is preserved. Labels use nearest-neighbour interpolation,
    everything else linear. The image is cast to float32 before resampling,
    so the result is float32 in both cases.
    """
    in_spacing = np.array(image.GetSpacing())
    in_size = np.array(image.GetSize())

    if out_size is not None:
        out_size = np.array(out_size)
    else:
        out_size = np.round(np.array(in_size * in_spacing / np.array(out_spacing))).astype(int)

    direction = np.array(image.GetDirection()).reshape(len(in_spacing), -1)
    # Centre offsets (from the origin), rotated into physical space.
    in_center = direction @ ((np.array(in_size, dtype=float) - 1.0) / 2.0 * in_spacing)
    out_center = direction @ ((np.array(out_size, dtype=float) - 1.0) / 2.0 * np.array(out_spacing))
    out_origin = np.array(image.GetOrigin()) + (in_center - out_center)

    interpolator = sitk.sitkNearestNeighbor if is_label else sitk.sitkLinear

    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(out_spacing)
    resampler.SetSize(out_size.tolist())
    resampler.SetOutputDirection(image.GetDirection())
    resampler.SetOutputOrigin(out_origin.tolist())
    resampler.SetTransform(sitk.Transform())
    resampler.SetDefaultPixelValue(pad_value)
    resampler.SetInterpolator(interpolator)

    return resampler.Execute(sitk.Cast(image, sitk.sitkFloat32))
Example #14
Source File: image.py    From airlab with Apache License 2.0 5 votes vote down vote up
def read_image_as_tensor(filename, dtype=th.float32, device='cpu'):
    """Read ``filename`` as a float32 SimpleITK image and wrap it as a tensor image."""
    return create_tensor_image_from_itk_image(
        sitk.ReadImage(filename, sitk.sitkFloat32),
        dtype=dtype,
        device=device,
    )
Example #15
Source File: image.py    From airlab with Apache License 2.0 5 votes vote down vote up
def read(filename, dtype=th.float32, device='cpu'):
    """
    Static helper: read an image file straight into an airlab Image.

    filename (str): path of the image file (read with pixel type sitkFloat32)
    dtype: dtype of the tensor representation
    device: device on which the image is allocated
    return (Image): an airlab image
    """
    itk_image = sitk.ReadImage(filename, sitk.sitkFloat32)
    return Image(itk_image, dtype, device)
Example #16
Source File: random_elastic_deformation.py    From torchio with MIT License 5 votes vote down vote up
def apply_bspline_transform(
        self,
        tensor: torch.Tensor,
        affine: np.ndarray,
        bspline_params: np.ndarray,
        interpolation: Interpolation,
        ) -> torch.Tensor:
    """Warp the single channel of ``tensor`` in place with a B-spline transform.

    The channel is converted to a SimpleITK image (via nib_to_sitk) that serves
    as both floating and reference image. It is resampled as float32, padded
    with the tensor's minimum value, and written back into ``tensor[0]``.

    :param tensor: 4-D tensor whose first dimension has length 1
    :param affine: affine passed to nib_to_sitk
    :param bspline_params: coefficients passed to get_bspline_transform
    :param interpolation: interpolation mode mapped via get_sitk_interpolator
    :return: the same tensor object, modified in place
    """
    assert tensor.dim() == 4
    assert len(tensor) == 1
    image = self.nib_to_sitk(tensor[0], affine)

    transform = self.get_bspline_transform(
        image,
        self.num_control_points,
        bspline_params,
    )
    self.parse_free_form_transform(transform, self.max_displacement)

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(image)
    resampler.SetTransform(transform)
    resampler.SetInterpolator(get_sitk_interpolator(interpolation))
    resampler.SetDefaultPixelValue(tensor.min().item())
    resampler.SetOutputPixelType(sitk.sitkFloat32)
    warped = resampler.Execute(image)

    # SimpleITK arrays come back with reversed axis order; undo it.
    tensor[0] = torch.from_numpy(sitk.GetArrayFromImage(warped).transpose())
    return tensor
Example #17
Source File: DataManager.py    From VNet with GNU General Public License v3.0 5 votes vote down vote up
def loadGT(self):
    """Load every file in self.gtList (relative to self.srcFolder) as a float32 mask of voxels > 0.5."""
    self.sitkGT = {}
    for gt_name in self.gtList:
        gt_path = join(self.srcFolder, gt_name)
        foreground = sitk.ReadImage(gt_path) > 0.5
        self.sitkGT[gt_name] = sitk.Cast(foreground, sitk.sitkFloat32)
Example #18
Source File: DataManager.py    From VNet with GNU General Public License v3.0 5 votes vote down vote up
def loadImages(self):
    """Load training images rescaled to [0, 1] and record their mean intensity.

    Every file in ``self.fileList`` (relative to ``self.srcFolder``) is read,
    cast to float32 and rescaled to [0, 1]. The results are stored in
    ``self.sitkImages`` (file name -> image). ``self.meanIntensityTrain`` is
    set to the average of the per-image mean intensities.
    """
    self.sitkImages = dict()
    rescalFilt = sitk.RescaleIntensityImageFilter()
    rescalFilt.SetOutputMaximum(1)
    rescalFilt.SetOutputMinimum(0)

    for f in self.fileList:
        self.sitkImages[f] = rescalFilt.Execute(
            sitk.Cast(sitk.ReadImage(join(self.srcFolder, f)), sitk.sitkFloat32))

    # Average over the stored images rather than over loop iterations, so a
    # file listed twice in fileList is not counted twice in the numerator but
    # only once in the denominator.
    stats = sitk.StatisticsImageFilter()
    m = 0.
    for img in self.sitkImages.values():
        stats.Execute(img)
        m += stats.GetMean()

    self.meanIntensityTrain = m / len(self.sitkImages)
Example #19
Source File: registration.py    From DeepBrainSeg with MIT License 4 votes vote down vote up
def register_patient(self, moving_images,
                     fixed_image,
                     save_path,
                     save_transform=True,
                     isotropic=True):
    """
    Rigidly co-register each moving image to the fixed image and save the results.

    moving_images : {'key1': path1, 'key2': path2}
    fixed_image : path of the fixed (reference) image, e.g. t1c
    save_path : output directory. Registered images (and the fixed image) go to
        ``registered/``, transforms to ``transforms/`` (only if save_transform),
        and resized volumes to ``isotropic/`` (only if isotropic).
    save_transform : write each final transform as ``<key>.tfm``
    isotropic : also write volumes resized with self.resize_sitk_3D
    """
    fixed_name = os.path.basename(fixed_image).split('.')[0]
    fixed_image = sitk.ReadImage(fixed_image, sitk.sitkFloat32)
    coregistration_path = os.path.join(save_path, 'registered')
    isotropic_path = os.path.join(save_path, 'isotropic')
    transform_path = os.path.join(save_path, 'transforms')

    os.makedirs(coregistration_path, exist_ok=True)
    if isotropic:
        os.makedirs(isotropic_path, exist_ok=True)
    if save_transform:
        os.makedirs(transform_path, exist_ok=True)

    def _log(message):
        print("[INFO: DeepBrainSeg] (" + strftime("%a, %d %b %Y %H:%M:%S +0000", gmtime()) + ") " + message)

    for key, moving_path in moving_images.items():
        moving_image = sitk.ReadImage(moving_path, sitk.sitkFloat32)
        initial_transform = sitk.CenteredTransformInitializer(fixed_image,
                                                              moving_image,
                                                              sitk.VersorRigid3DTransform(),
                                                              sitk.CenteredTransformInitializerFilter.GEOMETRY)

        self.registration_method.SetInitialTransform(initial_transform, inPlace=False)
        # Both images were read as float32, so no extra cast is required.
        final_transform = self.registration_method.Execute(fixed_image, moving_image)

        _log('Final metric value: {0}'.format(self.registration_method.GetMetricValue()))
        _log('Optimizer\'s stopping condition, {0}'.format(self.registration_method.GetOptimizerStopConditionDescription()))

        moving_resampled = sitk.Resample(moving_image,
                                         fixed_image,
                                         final_transform,
                                         sitk.sitkLinear, 0.0,
                                         moving_image.GetPixelID())

        sitk.WriteImage(moving_resampled, os.path.join(coregistration_path, key + '.nii.gz'))
        # transforms/ only exists when save_transform is set; writing
        # unconditionally would fail for save_transform=False.
        if save_transform:
            sitk.WriteTransform(final_transform, os.path.join(transform_path, key + '.tfm'))
        if isotropic:
            _log('converting to isotropic volume')
            moving_resized = self.resize_sitk_3D(moving_resampled)
            sitk.WriteImage(moving_resized, os.path.join(isotropic_path, key + '.nii.gz'))

    # Write Fixed image in nii.gz
    sitk.WriteImage(fixed_image, os.path.join(coregistration_path, fixed_name + '.nii.gz'))
    if isotropic:
        fixed_resized = self.resize_sitk_3D(fixed_image)
        sitk.WriteImage(fixed_resized, os.path.join(isotropic_path, fixed_name + '.nii.gz'))
Example #20
Source File: itkEdgePreservedSmoothing.py    From MedImg_Py_Library with MIT License 4 votes vote down vote up
def itkEdgePreservedSmoothing(im_arr, type_str):
    """Smooth an array with an edge-preserving SimpleITK filter.

    :param im_arr: array-like image passed to sitk.GetImageFromArray
    :param type_str: filter selector:
        'B'   -> sitk.Bilateral
        'MC'  -> sitk.MinMaxCurvatureFlow
        'CF'  -> sitk.CurvatureFlow
        'CAD' -> sitk.CurvatureAnisotropicDiffusion
        'GAD' -> sitk.GradientAnisotropicDiffusion
    :return: SimpleITK image produced by the selected filter, applied to a
        float32 copy of the input
    :raises ValueError: if type_str is not one of the keys above
    """
    # Lambdas defer the sitk attribute lookup until a filter is actually used.
    filters = {
        'B': lambda img: sitk.Bilateral(img),
        'MC': lambda img: sitk.MinMaxCurvatureFlow(img),
        'CF': lambda img: sitk.CurvatureFlow(img),
        'CAD': lambda img: sitk.CurvatureAnisotropicDiffusion(img),
        'GAD': lambda img: sitk.GradientAnisotropicDiffusion(img),
    }
    if type_str not in filters:
        raise ValueError(
            'Unknown type_str %r (expected one of %s). '
            'Please check your spelling, and try again.'
            % (type_str, ', '.join(repr(k) for k in filters)))

    # get an image from the array input
    image = sitk.Cast(sitk.GetImageFromArray(im_arr), sitk.sitkFloat32)
    return filters[type_str](image)


# an example of using the function
Example #21
Source File: sitk_image.py    From MedicalDataAugmentationTool with GNU General Public License v3.0 4 votes vote down vote up
def surface_distance(label_image_0, label_image_1):
    """Symmetric surface-distance statistics between two label images.

    For every contour pixel (value 1 in sitk.LabelContour) of each image, the
    absolute value of the OTHER image's signed Maurer distance map (computed
    with image spacing) is taken. Both directions are pooled.

    :param label_image_0: segmented label image
    :param label_image_1: reference label image
    :return: (mean, median, std, max) of the pooled distances. All four are
        NaN if the computation raises or no surface pixels are found.
    """
    # code adapted from https://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/Python_html/34_Segmentation_Evaluation.html
    nan_result = (np.nan, np.nan, np.nan, np.nan)
    try:
        # calculate distances on label contours
        reference_distance_map = sitk.SignedMaurerDistanceMap(label_image_1, squaredDistance=False, useImageSpacing=True)
        reference_distance_map_arr = sitk.GetArrayViewFromImage(reference_distance_map)
        reference_surface = sitk.LabelContour(label_image_1)
        reference_surface_arr = sitk.GetArrayViewFromImage(reference_surface)

        segmented_distance_map = sitk.SignedMaurerDistanceMap(label_image_0, squaredDistance=False, useImageSpacing=True)
        segmented_distance_map_arr = sitk.GetArrayViewFromImage(segmented_distance_map)
        segmented_surface = sitk.LabelContour(label_image_0)
        segmented_surface_arr = sitk.GetArrayViewFromImage(segmented_surface)

        seg2ref_distances = np.abs(reference_distance_map_arr[segmented_surface_arr == 1])
        ref2seg_distances = np.abs(segmented_distance_map_arr[reference_surface_arr == 1])

        all_surface_distances = np.concatenate([seg2ref_distances, ref2seg_distances])
    except Exception:
        # Best-effort metric: any failure yields NaNs. Catching Exception
        # (not a bare except) still lets KeyboardInterrupt/SystemExit through.
        return nan_result

    # No surface pixels at all: statistics are undefined.
    if all_surface_distances.size == 0:
        return nan_result

    return (np.mean(all_surface_distances),
            np.median(all_surface_distances),
            np.std(all_surface_distances),
            np.max(all_surface_distances))
Example #22
Source File: random_affine.py    From torchio with MIT License 4 votes vote down vote up
def apply_affine_transform(
        self,
        tensor: torch.Tensor,
        affine: np.ndarray,
        scaling_params: List[float],
        rotation_params: List[float],
        translation_params: List[float],
        interpolation: Interpolation,
        center_lps: Optional[TypeTripletFloat] = None,
        ) -> torch.Tensor:
    """Apply scaling followed by rotation/translation to ``tensor[0]`` in place.

    The channel is converted to a SimpleITK image that serves as both floating
    and reference image. It is resampled as float32 through a composite
    transform (scaling added first, then rotation) and written back into
    ``tensor[0]``. The pad value depends on ``self.default_pad_value``:
    'minimum' -> tensor minimum, 'mean' -> border mean, 'otsu' -> border mean
    with Otsu filtering, anything else -> used as-is.

    :return: the same tensor object, modified in place
    """
    assert tensor.ndim == 4
    assert len(tensor) == 1

    image = self.nib_to_sitk(tensor[0], affine)

    composite = sitk.Transform(3, sitk.sitkComposite)
    composite.AddTransform(self.get_scaling_transform(
        scaling_params,
        center_lps=center_lps,
    ))
    composite.AddTransform(self.get_rotation_transform(
        rotation_params,
        translation_params,
        center_lps=center_lps,
    ))

    pad_mode = self.default_pad_value
    if pad_mode == 'minimum':
        fill = tensor.min().item()
    elif pad_mode == 'mean' or pad_mode == 'otsu':
        fill = get_borders_mean(image, filter_otsu=(pad_mode == 'otsu'))
    else:
        fill = pad_mode

    resampler = sitk.ResampleImageFilter()
    resampler.SetInterpolator(get_sitk_interpolator(interpolation))
    resampler.SetReferenceImage(image)
    resampler.SetDefaultPixelValue(float(fill))
    resampler.SetOutputPixelType(sitk.sitkFloat32)
    resampler.SetTransform(composite)
    warped = resampler.Execute(image)

    # SimpleITK arrays come back with reversed axis order; undo it.
    tensor[0] = torch.from_numpy(sitk.GetArrayFromImage(warped).transpose())
    return tensor