Python SimpleITK.BSplineTransformInitializer() Examples

The following are 5 code examples of SimpleITK.BSplineTransformInitializer(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module SimpleITK, or try the search function.
Example #1
Source file: data_augmentation.py — from dataset_loaders (GNU General Public License v3.0), 8 votes
def gen_warp_field(shape, sigma=0.1, grid_size=3):
    '''Generate a random B-spline displacement field.

    A B-spline transform is laid over a blank reference image whose size is
    ``shape``. Its control-point coefficients are filled with Gaussian noise,
    with the outermost ring of control points zeroed. The transform is then
    converted into a dense displacement field on the reference image grid.

    :param shape: image size passed to ``sitk.Image`` (SimpleITK x, y[, z]
        order); any sequence of ints (previously only a tuple was accepted).
    :param sigma: standard deviation of the random control-point shifts.
    :param grid_size: transform-domain mesh size used along every axis; the
        control grid has ``grid_size + 3`` points per axis.
    :return: the displacement-field image computed by
        ``TransformToDisplacementFieldFilter``.
    '''
    import SimpleITK as sitk

    ndim = len(shape)

    # Initialize bspline transform over a blank reference image.
    ref_image = sitk.Image(*(tuple(shape) + (sitk.sitkFloat32,)))
    tx = sitk.BSplineTransformInitializer(ref_image, [grid_size] * ndim)

    # mesh size = number of control points - spline order (3, the
    # initializer's default order)
    n_ctrl = grid_size + 3

    # SimpleITK orders the parameters as all x-coefficients, then all
    # y-coefficients (then z), and inside each block the control grid is
    # flattened with x varying fastest. A C-ordered array of shape
    # (ndim, [n_z,] n_y, n_x) flattens into exactly that layout; the previous
    # (n, n, 2) layout interleaved x and y per control point instead.
    p = sigma * np.random.randn(ndim, *([n_ctrl] * ndim))

    # Anchor the outermost ring of control points along every spatial axis.
    # NOTE(review): with a cubic spline the displacement at the image border
    # presumably also depends on the next ring inward, so this damps border
    # motion but may not pin it exactly — confirm if exact anchoring matters.
    for axis in range(1, ndim + 1):
        for edge in (0, -1):
            index = [slice(None)] * (ndim + 1)
            index[axis] = edge
            p[tuple(index)] = 0

    # Set bspline transform parameters to the above shifts.
    tx.SetParameters(p.ravel().tolist())

    # Compute deformation field on the reference image grid.
    displacement_filter = sitk.TransformToDisplacementFieldFilter()
    displacement_filter.SetReferenceImage(ref_image)
    return displacement_filter.Execute(tx)
Example #2
Source file: utilities.py — from VNet (GNU General Public License v3.0), 5 votes
def produceRandomlyDeformedImage(image, label, numcontrolpoints, stdDef):
    '''Apply one random B-spline deformation to an image and its label map.

    :param image: numpy array with the image intensities.
    :param label: numpy array with the label map; it is resampled on the
        image's grid, so presumably it has the same shape as ``image``.
    :param numcontrolpoints: transform-domain mesh size used along every axis.
    :param stdDef: standard deviation of the Gaussian noise added to the
        B-spline coefficients (drawn from the global NumPy RNG).
    :return: ``(outimg, outlbl)`` as float32 numpy arrays; the label is
        linearly interpolated and then binarized at 0.5.
    '''
    sitkImage = sitk.GetImageFromArray(image, isVector=False)
    sitklabel = sitk.GetImageFromArray(label, isVector=False)
    dimension = sitkImage.GetDimension()

    meshSize = [numcontrolpoints] * dimension
    tx = sitk.BSplineTransformInitializer(sitkImage, meshSize)

    params = tx.GetParameters()
    paramsNp = np.asarray(params, dtype=float)
    paramsNp = paramsNp + np.random.randn(paramsNp.shape[0]) * stdDef

    # remove z deformations! The resolution in z is too bad.
    # The parameters hold one coefficient block per SimpleITK axis, first axis
    # first, and that first block is zeroed. Dividing by the dimension (not a
    # hard-coded 3) keeps the slice aligned with exactly one axis block for
    # 2-D inputs too; for 3-D it is the same slice as before.
    # NOTE(review): SimpleITK's first axis is numpy's last axis of `image`,
    # so this assumes the caller's arrays store z last — confirm.
    paramsNp[:len(params) // dimension] = 0

    tx.SetParameters(tuple(paramsNp))

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(sitkImage)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(tx)

    outimgsitk = resampler.Execute(sitkImage)
    outlabsitk = resampler.Execute(sitklabel)

    outimg = sitk.GetArrayFromImage(outimgsitk).astype(dtype=np.float32)
    # Linear interpolation blurs the label edges; thresholding at 0.5 turns
    # it back into a binary mask.
    outlbl = (sitk.GetArrayFromImage(outlabsitk) > 0.5).astype(dtype=np.float32)

    return outimg, outlbl
Example #3
Source file: random_elastic_deformation.py — from torchio (MIT License), 5 votes
def get_bspline_transform(
        image: sitk.Image,
        num_control_points: Tuple[int, int, int],
        coarse_field: np.ndarray,
        ) -> sitk.BSplineTransform:
    """Build a B-spline transform over ``image`` from explicit coefficients.

    :param image: reference image defining the transform domain.
    :param num_control_points: number of control points along each axis.
    :param coarse_field: control-point displacements. It is flattened in
        Fortran order, so its leading axes index the control grid and its
        last axis presumably indexes the displacement component — TODO
        confirm the expected shape is ``(*num_control_points, ndim)``.
    :return: the B-spline transform with ``coarse_field`` as its parameters.
    """
    # Mesh size = number of control points - spline order.
    mesh_shape = [n - SPLINE_ORDER for n in num_control_points]
    # Pass the order explicitly so the initializer's control-grid size always
    # agrees with the mesh computed above (the initializer defaults to 3).
    bspline_transform = sitk.BSplineTransformInitializer(
        image, mesh_shape, SPLINE_ORDER)
    # Fortran order makes the first control-grid index vary fastest and the
    # component axis slowest: one contiguous coefficient block per axis.
    parameters = coarse_field.flatten(order='F').tolist()
    bspline_transform.SetParameters(parameters)
    return bspline_transform
Example #4
Source file: utils.py — from Brats2019 (MIT License), 4 votes
def produceRandomlyDeformedImage(image, label, numcontrolpoints, stdDef, seed=1):
    '''
    Deform an image and its label map with a random B-spline transform (adapted from V-Net).

    :param image: image, numpy array
    :param label: label map (GT), numpy array; resampled on the image's grid
    :param numcontrolpoints: B-spline transform-domain mesh size used along every axis (typically 2)
    :param stdDef: standard deviation of the Gaussian noise added to the B-spline coefficients (typically 15)
    :param seed: seed for that noise. Calls with the same seed produce the same deformation, which
        keeps the channels of a multi-channel sample aligned. A private generator is used, so the
        global NumPy random state is no longer reset on every call.
    :return: deformed image (float32) and deformed GT, numpy arrays
    '''
    sitkImage = sitk.GetImageFromArray(image, isVector=False)
    sitklabel = sitk.GetImageFromArray(label, isVector=False)
    dimension = sitkImage.GetDimension()

    meshSize = [numcontrolpoints] * dimension
    tx = sitk.BSplineTransformInitializer(sitkImage, meshSize)

    params = tx.GetParameters()
    paramsNp = np.asarray(params, dtype=float)
    # A fresh RandomState(seed) yields the same draws as
    # np.random.seed(seed); np.random.randn(...), without clobbering the
    # global RNG that the rest of the program draws from.
    rng = np.random.RandomState(seed)
    paramsNp = paramsNp + rng.randn(paramsNp.shape[0]) * stdDef

    # remove z deformations! The resolution in z is too bad.
    # The parameters hold one coefficient block per SimpleITK axis, first axis
    # first, and that first block is zeroed (same slice as before for 3-D).
    # NOTE(review): SimpleITK's first axis is numpy's last axis of `image`,
    # so this assumes the caller's arrays store z last — confirm.
    paramsNp[:len(params) // dimension] = 0

    tx.SetParameters(tuple(paramsNp))

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(sitkImage)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(tx)

    resampler.SetInterpolator(sitk.sitkLinear)
    outimgsitk = resampler.Execute(sitkImage)

    # Labels are categorical: nearest neighbour keeps every output voxel at an
    # existing label value, whereas linear interpolation blends neighbouring
    # ids (e.g. 2 and 4 would produce a spurious 3 along their border).
    resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    outlabsitk = resampler.Execute(sitklabel)

    outimg = sitk.GetArrayFromImage(outimgsitk).astype(dtype=np.float32)
    outlbl = sitk.GetArrayFromImage(outlabsitk)

    return outimg, outlbl
Example #5
Source file: data_augmentation.py — from Automated-Cardiac-Segmentation-and-Disease-Diagnosis (MIT License), 4 votes
def produceRandomlyDeformedImage(image, label, numcontrolpoints=2, stdDef=15):
    '''Deform an image and its label map with one random B-spline transform.

    :param image: numpy array with the image intensities.
    :param label: numpy array with the label map; resampled on its own grid.
    :param numcontrolpoints: transform-domain mesh size used along every axis.
    :param stdDef: standard deviation of the Gaussian noise added to the
        B-spline coefficients (drawn from the global NumPy RNG).
    :return: ``(outimg, outlbl)``: the image as float32 (linear
        interpolation) and the label as uint8 (sitkLabelGaussian
        interpolation).
    '''
    sitkImage = sitk.GetImageFromArray(image, isVector=False)
    sitklabel = sitk.GetImageFromArray(label, isVector=False)
    dimension = sitkImage.GetDimension()

    meshSize = [numcontrolpoints] * dimension
    tx = sitk.BSplineTransformInitializer(sitkImage, meshSize)

    params = tx.GetParameters()
    paramsNp = np.asarray(params, dtype=float)
    paramsNp = paramsNp + np.random.randn(paramsNp.shape[0]) * stdDef

    # remove z deformations! The resolution in z is too bad in case of 3D or
    # its channels in 2D. The parameters hold one coefficient block per
    # SimpleITK axis, first axis first, and that first block is zeroed.
    # Dividing by the dimension (not a hard-coded 3) keeps the slice aligned
    # with one axis block for 2-D images too; for 3-D it is unchanged.
    paramsNp[:len(params) // dimension] = 0

    tx.SetParameters(tuple(paramsNp))

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(sitkImage)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(tx)

    outimgsitk = resampler.Execute(sitkImage)

    # For the label, resample on the label's own grid with sitkLabelGaussian,
    # SimpleITK's interpolator for multi-label images.
    resampler.SetReferenceImage(sitklabel)
    resampler.SetInterpolator(sitk.sitkLabelGaussian)
    outlabsitk = resampler.Execute(sitklabel)

    outimg = sitk.GetArrayFromImage(outimgsitk).astype(dtype=np.float32)
    outlbl = sitk.GetArrayFromImage(outlabsitk).astype(dtype=np.uint8)
    return outimg, outlbl

# ********************************Augmentation Transforms**************************#