Python SimpleITK.BSplineTransformInitializer() Examples
The following are 5 code examples of SimpleITK.BSplineTransformInitializer().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module SimpleITK, or try the search function.
Example #1
Source File: data_augmentation.py From dataset_loaders with GNU General Public License v3.0 | 8 votes |
def gen_warp_field(shape, sigma=0.1, grid_size=3):
    '''Generate a random B-spline warp (displacement) field.

    A cubic B-spline transform with ``grid_size`` mesh cells per axis is
    placed over an image of size ``shape``. Its control-point coefficients
    are drawn from N(0, sigma**2), and the outermost ring of control points
    is pinned to zero.

    :param shape: image size passed to ``sitk.Image`` (SimpleITK x, y[, z]
        order); any sequence of ints (tuple or list).
    :param sigma: standard deviation of the control-point displacements.
    :param grid_size: number of B-spline mesh cells per axis.
    :return: SimpleITK displacement-field image on the ``shape`` grid.
    '''
    import numpy as np
    import SimpleITK as sitk

    dim = len(shape)

    # Reference image: defines the transform domain and the output grid.
    ref_image = sitk.Image(*(tuple(shape) + (sitk.sitkFloat32,)))
    tx = sitk.BSplineTransformInitializer(ref_image, [grid_size] * dim)

    # Mesh size = number of control points - spline order.
    # BSplineTransformInitializer uses a cubic (order 3) spline by default.
    spline_order = 3
    n_ctrl = grid_size + spline_order

    # SimpleITK orders B-spline parameters component-major: all x
    # coefficients, then all y (then all z), each block in raster order.
    # Shaping the noise as (dim, n, ..., n) makes a C-order flatten match
    # that layout. An (n, n, dim) array would interleave the components,
    # and the edge anchoring below would then zero the wrong coefficients.
    coeffs = sigma * np.random.randn(dim, *([n_ctrl] * dim))

    # Anchor the edges of the control grid, for every displacement component.
    for axis in range(1, dim + 1):
        for edge in (0, -1):
            index = [slice(None)] * (dim + 1)
            index[axis] = edge
            coeffs[tuple(index)] = 0

    # Plain Python floats are the safest input for the SWIG std::vector binding.
    tx.SetParameters(coeffs.ravel().tolist())

    # Sample the transform on the reference grid to get the displacement field.
    displacement_filter = sitk.TransformToDisplacementFieldFilter()
    displacement_filter.SetReferenceImage(ref_image)
    return displacement_filter.Execute(tx)
Example #2
Source File: utilities.py From VNet with GNU General Public License v3.0 | 5 votes |
def produceRandomlyDeformedImage(image, label, numcontrolpoints, stdDef):
    """Apply one random B-spline deformation to an image and its label.

    :param image: image as a numpy array. It is converted with
        ``sitk.GetImageFromArray``, so SimpleITK axis 0 is the array's LAST axis.
    :param label: label array with the same shape as ``image``. It is
        re-binarised with a ``> 0.5`` threshold, so binary labels are assumed.
    :param numcontrolpoints: B-spline transform-domain mesh size per axis.
    :param stdDef: standard deviation of the Gaussian noise added to the
        control-point coefficients.
    :return: tuple ``(outimg, outlbl)`` of float32 numpy arrays.
    """
    sitkImage = sitk.GetImageFromArray(image, isVector=False)
    sitklabel = sitk.GetImageFromArray(label, isVector=False)
    dim = sitkImage.GetDimension()

    transfromDomainMeshSize = [numcontrolpoints] * dim
    tx = sitk.BSplineTransformInitializer(sitkImage, transfromDomainMeshSize)

    paramsNp = np.asarray(tx.GetParameters(), dtype=float)
    paramsNp = paramsNp + np.random.randn(paramsNp.shape[0]) * stdDef

    # Remove deformations along the low-resolution axis.
    # Parameters come in one equal-sized block per SimpleITK axis, axis 0 first.
    # Axis 0 is the array's last axis: this is "z" when the caller passes
    # (x, y, z)-ordered arrays. NOTE(review): confirm the caller's axis order.
    # Only done for 3-D images: in 2-D, a len/3 slice would zero an arbitrary
    # part of the first block.
    if dim == 3:
        paramsNp[:paramsNp.shape[0] // dim] = 0
    tx.SetParameters(tuple(paramsNp))

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(sitkImage)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(tx)

    outimgsitk = resampler.Execute(sitkImage)
    outlabsitk = resampler.Execute(sitklabel)

    outimg = sitk.GetArrayFromImage(outimgsitk).astype(dtype=np.float32)
    # Linear interpolation blurs label edges; threshold back to {0, 1}.
    outlbl = (sitk.GetArrayFromImage(outlabsitk) > 0.5).astype(dtype=np.float32)
    return outimg, outlbl
Example #3
Source File: random_elastic_deformation.py From torchio with MIT License | 5 votes |
def get_bspline_transform(
        image: sitk.Image,
        num_control_points: Tuple[int, int, int],
        coarse_field: np.ndarray,
        ) -> sitk.BSplineTransform:
    """Build a B-spline transform over ``image`` from control-point displacements.

    :param image: image whose physical domain the transform covers.
    :param num_control_points: number of control points per axis; the mesh
        size is this minus ``SPLINE_ORDER``.
    :param coarse_field: control-point displacements. It is flattened in
        Fortran order, presumably from shape ``(*num_control_points, 3)``, to
        give SimpleITK's component-major, x-fastest parameter layout.
        TODO: confirm the expected shape against the caller.
    :return: the initialised ``sitk.BSplineTransform`` with its parameters set.
    """
    mesh_shape = [n - SPLINE_ORDER for n in num_control_points]
    # Pass the spline order explicitly so it matches the one used to derive
    # mesh_shape, rather than relying on the initializer's default.
    bspline_transform = sitk.BSplineTransformInitializer(
        image, mesh_shape, SPLINE_ORDER)
    parameters = coarse_field.flatten(order='F').tolist()
    bspline_transform.SetParameters(parameters)
    return bspline_transform
Example #4
Source File: utils.py From Brats2019 with MIT License | 4 votes |
def produceRandomlyDeformedImage(image, label, numcontrolpoints, stdDef, seed=1):
    '''
    Deform an image and its label with the same random B-spline transform (adapted from V-Net).

    :param image: image, numpy array. Converted with sitk.GetImageFromArray,
        so SimpleITK axis 0 is the array's last axis.
    :param label: label map (ground truth), numpy array with the same shape as image
    :param numcontrolpoints: B-spline transform-domain mesh size per axis; 2 is the suggested default
    :param stdDef: standard deviation of the random control-point displacements; 15 is the suggested default
    :param seed: seed for the displacement noise. Calls with the same seed produce the
        same deformation, so all channels of a multi-channel sample are deformed identically.
    :return: deformed image (float32) and deformed label, as numpy arrays
    '''
    sitkImage = sitk.GetImageFromArray(image, isVector=False)
    sitklabel = sitk.GetImageFromArray(label, isVector=False)
    dim = sitkImage.GetDimension()

    transfromDomainMeshSize = [numcontrolpoints] * dim
    tx = sitk.BSplineTransformInitializer(sitkImage, transfromDomainMeshSize)

    paramsNp = np.asarray(tx.GetParameters(), dtype=float)
    # Seed a private generator so that multiple channels get the same deformation.
    # It yields the same values as np.random.seed(seed) + np.random.randn,
    # without resetting the global NumPy RNG for every other caller.
    rng = np.random.RandomState(seed)
    paramsNp = paramsNp + rng.randn(paramsNp.shape[0]) * stdDef

    # Remove z deformations: the resolution in z is too coarse.
    # Parameters come in one equal-sized block per SimpleITK axis, axis 0 first.
    # Axis 0 is the array's last axis. NOTE(review): this is z only if the
    # caller's arrays have z last. Only applied to 3-D images; in 2-D, a len/3
    # slice would zero an arbitrary part of the first block.
    if dim == 3:
        paramsNp[:paramsNp.shape[0] // dim] = 0
    tx.SetParameters(tuple(paramsNp))

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(sitkImage)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(tx)

    resampler.SetInterpolator(sitk.sitkLinear)
    outimgsitk = resampler.Execute(sitkImage)

    # Labels are categorical. Linear interpolation would invent in-between
    # values (e.g. 3 between labels 2 and 4), so use nearest neighbour.
    resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    outlabsitk = resampler.Execute(sitklabel)

    outimg = sitk.GetArrayFromImage(outimgsitk).astype(dtype=np.float32)
    outlbl = sitk.GetArrayFromImage(outlabsitk)
    return outimg, outlbl
Example #5
Source File: data_augmentation.py From Automated-Cardiac-Segmentation-and-Disease-Diagnosis with MIT License | 4 votes |
def produceRandomlyDeformedImage(image, label, numcontrolpoints=2, stdDef=15):
    """Apply one random B-spline deformation to an image and its label.

    :param image: image as a numpy array. It is converted with
        ``sitk.GetImageFromArray``, so SimpleITK axis 0 is the array's LAST axis.
    :param label: label array with the same shape as ``image``.
    :param numcontrolpoints: B-spline transform-domain mesh size per axis.
    :param stdDef: standard deviation of the Gaussian noise added to the
        control-point coefficients.
    :return: tuple ``(outimg, outlbl)``: the image as float32, the label as uint8.
    """
    sitkImage = sitk.GetImageFromArray(image, isVector=False)
    sitklabel = sitk.GetImageFromArray(label, isVector=False)
    dim = sitkImage.GetDimension()

    transfromDomainMeshSize = [numcontrolpoints] * dim
    tx = sitk.BSplineTransformInitializer(sitkImage, transfromDomainMeshSize)

    paramsNp = np.asarray(tx.GetParameters(), dtype=float)
    paramsNp = paramsNp + np.random.randn(paramsNp.shape[0]) * stdDef

    # Remove deformations along the array's last axis (SimpleITK axis 0). That
    # is z for 3-D volumes, or the channel axis for 2-D images stacked
    # channels-last. The resolution along that axis is too coarse to deform.
    # Parameters come in one equal-sized block per SimpleITK axis, axis 0 first.
    # Only applied to 3-D images; in 2-D, a len/3 slice would zero an arbitrary
    # part of the first block.
    if dim == 3:
        paramsNp[:paramsNp.shape[0] // dim] = 0
    tx.SetParameters(tuple(paramsNp))

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(sitkImage)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(tx)
    outimgsitk = resampler.Execute(sitkImage)

    # Labels: resample with sitkLabelGaussian, SimpleITK's interpolator for
    # label images, rather than the linear interpolator used for the image.
    resampler.SetReferenceImage(sitklabel)
    resampler.SetInterpolator(sitk.sitkLabelGaussian)
    resampler.SetDefaultPixelValue(0)
    outlabsitk = resampler.Execute(sitklabel)

    outimg = sitk.GetArrayFromImage(outimgsitk).astype(dtype=np.float32)
    outlbl = sitk.GetArrayFromImage(outlabsitk).astype(dtype=np.uint8)
    return outimg, outlbl


# ********************************Augmentation Transforms**************************#