Python nets.nasnet.nasnet_utils.global_avg_pool() Examples
The following are 30 code examples of nets.nasnet.nasnet_utils.global_avg_pool().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module nets.nasnet.nasnet_utils, or try the search function.
Example #1
Source File: nasnet_utils_test.py From style_swap_tensorflow with Apache License 2.0 | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #2
Source File: nasnet_utils_test.py From nasnet-tensorflow with Apache License 2.0 | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #3
Source File: nasnet_utils_test.py From models with Apache License 2.0 | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #4
Source File: nasnet_utils_test.py From edafa with MIT License | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #5
Source File: nasnet.py From tf-pose with Apache License 2.0 | 5 votes |
def build_nasnet_cifar(
    images, num_classes, is_training=True):
  """Build NASNet model for the Cifar Dataset.

  Args:
    images: Input image tensor. It is transposed with perm [0, 3, 1, 2] when
      the CIFAR hparams select 'NCHW'. This suggests it arrives in NHWC order
      (TODO confirm against callers).
    num_classes: Passed through to `_build_nasnet_base`.
    is_training: Passed to `_cifar_config`. It is also applied through
      `arg_scope` to slim.dropout, nasnet_utils.drop_path and
      slim.batch_norm.

  Returns:
    The result of `_build_nasnet_base(..., stem_type='cifar')`.
  """
  hparams = _cifar_config(is_training=is_training)

  on_gpu_with_nhwc = (tf.test.is_gpu_available() and
                      hparams.data_format == 'NHWC')
  if on_gpu_with_nhwc:
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')
  if hparams.data_format == 'NCHW':
    # Reorder axes with perm [0, 3, 1, 2] (NHWC -> NCHW).
    images = tf.transpose(images, [0, 3, 1, 2])

  # Total cell count: hparams.num_cells plus the 2 reduction cells.
  total_num_cells = hparams.num_cells + 2
  # Both cell types are constructed from the same four arguments.
  cell_args = (hparams.num_conv_filters, hparams.drop_path_keep_prob,
               total_num_cells, hparams.total_training_steps)
  normal_cell = nasnet_utils.NasNetANormalCell(*cell_args)
  reduction_cell = nasnet_utils.NasNetAReductionCell(*cell_args)

  # Ops that receive `is_training` through arg_scope.
  training_mode_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
  # Ops that receive the hparams data_format through arg_scope.
  data_format_ops = [slim.avg_pool2d, slim.max_pool2d, slim.conv2d,
                     slim.batch_norm, slim.separable_conv2d,
                     nasnet_utils.factorized_reduction,
                     nasnet_utils.global_avg_pool,
                     nasnet_utils.get_channel_index,
                     nasnet_utils.get_channel_dim]
  with arg_scope(training_mode_ops, is_training=is_training), \
       arg_scope(data_format_ops, data_format=hparams.data_format):
    return _build_nasnet_base(images,
                              normal_cell=normal_cell,
                              reduction_cell=reduction_cell,
                              num_classes=num_classes,
                              hparams=hparams,
                              is_training=is_training,
                              stem_type='cifar')
Example #6
Source File: nasnet.py From nasnet-tensorflow with Apache License 2.0 | 5 votes |
def build_nasnet_cifar(
    images, num_classes, is_training=True):
  """Build NASNet model for the Cifar Dataset.

  Args:
    images: Input image tensor. It is transposed with perm [0, 3, 1, 2] when
      the CIFAR hparams select 'NCHW'. This suggests it arrives in NHWC order
      (TODO confirm against callers).
    num_classes: Passed through to `_build_nasnet_base`.
    is_training: Passed to `_cifar_config`. It is also applied through
      `arg_scope` to slim.dropout, nasnet_utils.drop_path and
      slim.batch_norm.

  Returns:
    The result of `_build_nasnet_base(..., stem_type='cifar')`.
  """
  hparams = _cifar_config(is_training=is_training)

  on_gpu_with_nhwc = (tf.test.is_gpu_available() and
                      hparams.data_format == 'NHWC')
  if on_gpu_with_nhwc:
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')
  if hparams.data_format == 'NCHW':
    # Reorder axes with perm [0, 3, 1, 2] (NHWC -> NCHW).
    images = tf.transpose(images, [0, 3, 1, 2])

  # Total cell count: hparams.num_cells plus the 2 reduction cells.
  total_num_cells = hparams.num_cells + 2
  # Both cell types are constructed from the same four arguments.
  cell_args = (hparams.num_conv_filters, hparams.drop_path_keep_prob,
               total_num_cells, hparams.total_training_steps)
  normal_cell = nasnet_utils.NasNetANormalCell(*cell_args)
  reduction_cell = nasnet_utils.NasNetAReductionCell(*cell_args)

  # Ops that receive `is_training` through arg_scope.
  training_mode_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
  # Ops that receive the hparams data_format through arg_scope.
  data_format_ops = [slim.avg_pool2d, slim.max_pool2d, slim.conv2d,
                     slim.batch_norm, slim.separable_conv2d,
                     nasnet_utils.factorized_reduction,
                     nasnet_utils.global_avg_pool,
                     nasnet_utils.get_channel_index,
                     nasnet_utils.get_channel_dim]
  with arg_scope(training_mode_ops, is_training=is_training), \
       arg_scope(data_format_ops, data_format=hparams.data_format):
    return _build_nasnet_base(images,
                              normal_cell=normal_cell,
                              reduction_cell=reduction_cell,
                              num_classes=num_classes,
                              hparams=hparams,
                              is_training=is_training,
                              stem_type='cifar')
Example #7
Source File: nasnet_utils_test.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #8
Source File: nasnet_utils_test.py From CVTron with Apache License 2.0 | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #9
Source File: nasnet_utils_test.py From tf-pose with Apache License 2.0 | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #10
Source File: nasnet_utils_test.py From yolo_v2 with Apache License 2.0 | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #11
Source File: nasnet.py From yolo_v2 with Apache License 2.0 | 5 votes |
def build_nasnet_cifar(
    images, num_classes, is_training=True):
  """Build NASNet model for the Cifar Dataset.

  Args:
    images: Input image tensor. It is transposed with perm [0, 3, 1, 2] when
      the CIFAR hparams select 'NCHW'. This suggests it arrives in NHWC order
      (TODO confirm against callers).
    num_classes: Passed through to `_build_nasnet_base`.
    is_training: Passed to `_cifar_config`. It is also applied through
      `arg_scope` to slim.dropout, nasnet_utils.drop_path and
      slim.batch_norm.

  Returns:
    The result of `_build_nasnet_base(..., stem_type='cifar')`.
  """
  hparams = _cifar_config(is_training=is_training)

  on_gpu_with_nhwc = (tf.test.is_gpu_available() and
                      hparams.data_format == 'NHWC')
  if on_gpu_with_nhwc:
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')
  if hparams.data_format == 'NCHW':
    # Reorder axes with perm [0, 3, 1, 2] (NHWC -> NCHW).
    images = tf.transpose(images, [0, 3, 1, 2])

  # Total cell count: hparams.num_cells plus the 2 reduction cells.
  total_num_cells = hparams.num_cells + 2
  # Both cell types are constructed from the same four arguments.
  cell_args = (hparams.num_conv_filters, hparams.drop_path_keep_prob,
               total_num_cells, hparams.total_training_steps)
  normal_cell = nasnet_utils.NasNetANormalCell(*cell_args)
  reduction_cell = nasnet_utils.NasNetAReductionCell(*cell_args)

  # Ops that receive `is_training` through arg_scope.
  training_mode_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
  # Ops that receive the hparams data_format through arg_scope.
  data_format_ops = [slim.avg_pool2d, slim.max_pool2d, slim.conv2d,
                     slim.batch_norm, slim.separable_conv2d,
                     nasnet_utils.factorized_reduction,
                     nasnet_utils.global_avg_pool,
                     nasnet_utils.get_channel_index,
                     nasnet_utils.get_channel_dim]
  with arg_scope(training_mode_ops, is_training=is_training), \
       arg_scope(data_format_ops, data_format=hparams.data_format):
    return _build_nasnet_base(images,
                              normal_cell=normal_cell,
                              reduction_cell=reduction_cell,
                              num_classes=num_classes,
                              hparams=hparams,
                              is_training=is_training,
                              stem_type='cifar')
Example #12
Source File: nasnet_utils_test.py From MAX-Object-Detector with Apache License 2.0 | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #13
Source File: nasnet_utils_test.py From Gun-Detector with Apache License 2.0 | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #14
Source File: nasnet.py From object_detection_with_tensorflow with MIT License | 5 votes |
def build_nasnet_cifar(
    images, num_classes, is_training=True):
  """Build NASNet model for the Cifar Dataset.

  Args:
    images: Input image tensor. It is transposed with perm [0, 3, 1, 2] when
      the CIFAR hparams select 'NCHW'. This suggests it arrives in NHWC order
      (TODO confirm against callers).
    num_classes: Passed through to `_build_nasnet_base`.
    is_training: Passed to `_cifar_config`. It is also applied through
      `arg_scope` to slim.dropout, nasnet_utils.drop_path and
      slim.batch_norm.

  Returns:
    The result of `_build_nasnet_base(..., stem_type='cifar')`.
  """
  hparams = _cifar_config(is_training=is_training)

  on_gpu_with_nhwc = (tf.test.is_gpu_available() and
                      hparams.data_format == 'NHWC')
  if on_gpu_with_nhwc:
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')
  if hparams.data_format == 'NCHW':
    # Reorder axes with perm [0, 3, 1, 2] (NHWC -> NCHW).
    images = tf.transpose(images, [0, 3, 1, 2])

  # Total cell count: hparams.num_cells plus the 2 reduction cells.
  total_num_cells = hparams.num_cells + 2
  # Both cell types are constructed from the same four arguments.
  cell_args = (hparams.num_conv_filters, hparams.drop_path_keep_prob,
               total_num_cells, hparams.total_training_steps)
  normal_cell = nasnet_utils.NasNetANormalCell(*cell_args)
  reduction_cell = nasnet_utils.NasNetAReductionCell(*cell_args)

  # Ops that receive `is_training` through arg_scope.
  training_mode_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
  # Ops that receive the hparams data_format through arg_scope.
  data_format_ops = [slim.avg_pool2d, slim.max_pool2d, slim.conv2d,
                     slim.batch_norm, slim.separable_conv2d,
                     nasnet_utils.factorized_reduction,
                     nasnet_utils.global_avg_pool,
                     nasnet_utils.get_channel_index,
                     nasnet_utils.get_channel_dim]
  with arg_scope(training_mode_ops, is_training=is_training), \
       arg_scope(data_format_ops, data_format=hparams.data_format):
    return _build_nasnet_base(images,
                              normal_cell=normal_cell,
                              reduction_cell=reduction_cell,
                              num_classes=num_classes,
                              hparams=hparams,
                              is_training=is_training,
                              stem_type='cifar')
Example #15
Source File: nasnet_utils_test.py From CBAM-tensorflow-slim with MIT License | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #16
Source File: nasnet.py From style_swap_tensorflow with Apache License 2.0 | 5 votes |
def build_nasnet_cifar(
    images, num_classes, is_training=True):
  """Build NASNet model for the Cifar Dataset.

  Args:
    images: Input image tensor. It is transposed with perm [0, 3, 1, 2] when
      the CIFAR hparams select 'NCHW'. This suggests it arrives in NHWC order
      (TODO confirm against callers).
    num_classes: Passed through to `_build_nasnet_base`.
    is_training: Passed to `_cifar_config`. It is also applied through
      `arg_scope` to slim.dropout, nasnet_utils.drop_path and
      slim.batch_norm.

  Returns:
    The result of `_build_nasnet_base(..., stem_type='cifar')`.
  """
  hparams = _cifar_config(is_training=is_training)

  on_gpu_with_nhwc = (tf.test.is_gpu_available() and
                      hparams.data_format == 'NHWC')
  if on_gpu_with_nhwc:
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')
  if hparams.data_format == 'NCHW':
    # Reorder axes with perm [0, 3, 1, 2] (NHWC -> NCHW).
    images = tf.transpose(images, [0, 3, 1, 2])

  # Total cell count: hparams.num_cells plus the 2 reduction cells.
  total_num_cells = hparams.num_cells + 2
  # Both cell types are constructed from the same four arguments.
  cell_args = (hparams.num_conv_filters, hparams.drop_path_keep_prob,
               total_num_cells, hparams.total_training_steps)
  normal_cell = nasnet_utils.NasNetANormalCell(*cell_args)
  reduction_cell = nasnet_utils.NasNetAReductionCell(*cell_args)

  # Ops that receive `is_training` through arg_scope.
  training_mode_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
  # Ops that receive the hparams data_format through arg_scope.
  data_format_ops = [slim.avg_pool2d, slim.max_pool2d, slim.conv2d,
                     slim.batch_norm, slim.separable_conv2d,
                     nasnet_utils.factorized_reduction,
                     nasnet_utils.global_avg_pool,
                     nasnet_utils.get_channel_index,
                     nasnet_utils.get_channel_dim]
  with arg_scope(training_mode_ops, is_training=is_training), \
       arg_scope(data_format_ops, data_format=hparams.data_format):
    return _build_nasnet_base(images,
                              normal_cell=normal_cell,
                              reduction_cell=reduction_cell,
                              num_classes=num_classes,
                              hparams=hparams,
                              is_training=is_training,
                              stem_type='cifar')
Example #17
Source File: nasnet_utils_test.py From MAX-Image-Segmenter with Apache License 2.0 | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #18
Source File: nasnet_utils_test.py From TwinGAN with Apache License 2.0 | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #19
Source File: nasnet_utils_test.py From DeepLab_v3 with MIT License | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #20
Source File: nasnet.py From TwinGAN with Apache License 2.0 | 5 votes |
def build_nasnet_cifar(
    images, num_classes, is_training=True):
  """Build NASNet model for the Cifar Dataset.

  Args:
    images: Input image tensor. It is transposed with perm [0, 3, 1, 2] when
      the CIFAR hparams select 'NCHW'. This suggests it arrives in NHWC order
      (TODO confirm against callers).
    num_classes: Passed through to `_build_nasnet_base`.
    is_training: Passed to `_cifar_config`. It is also applied through
      `arg_scope` to slim.dropout, nasnet_utils.drop_path and
      slim.batch_norm.

  Returns:
    The result of `_build_nasnet_base(..., stem_type='cifar')`.
  """
  hparams = _cifar_config(is_training=is_training)

  on_gpu_with_nhwc = (tf.test.is_gpu_available() and
                      hparams.data_format == 'NHWC')
  if on_gpu_with_nhwc:
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')
  if hparams.data_format == 'NCHW':
    # Reorder axes with perm [0, 3, 1, 2] (NHWC -> NCHW).
    images = tf.transpose(images, [0, 3, 1, 2])

  # Total cell count: hparams.num_cells plus the 2 reduction cells.
  total_num_cells = hparams.num_cells + 2
  # Both cell types are constructed from the same four arguments.
  cell_args = (hparams.num_conv_filters, hparams.drop_path_keep_prob,
               total_num_cells, hparams.total_training_steps)
  normal_cell = nasnet_utils.NasNetANormalCell(*cell_args)
  reduction_cell = nasnet_utils.NasNetAReductionCell(*cell_args)

  # Ops that receive `is_training` through arg_scope.
  training_mode_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
  # Ops that receive the hparams data_format through arg_scope.
  data_format_ops = [slim.avg_pool2d, slim.max_pool2d, slim.conv2d,
                     slim.batch_norm, slim.separable_conv2d,
                     nasnet_utils.factorized_reduction,
                     nasnet_utils.global_avg_pool,
                     nasnet_utils.get_channel_index,
                     nasnet_utils.get_channel_dim]
  with arg_scope(training_mode_ops, is_training=is_training), \
       arg_scope(data_format_ops, data_format=hparams.data_format):
    return _build_nasnet_base(images,
                              normal_cell=normal_cell,
                              reduction_cell=reduction_cell,
                              num_classes=num_classes,
                              hparams=hparams,
                              is_training=is_training,
                              stem_type='cifar')
Example #21
Source File: nasnet_utils_test.py From BMW-TensorFlow-Training-GUI with Apache License 2.0 | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #22
Source File: nasnet_utils_test.py From object_detection_with_tensorflow with MIT License | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #23
Source File: nasnet.py From Creative-Adversarial-Networks with MIT License | 5 votes |
def build_nasnet_cifar(
    images, num_classes, is_training=True):
  """Build NASNet model for the Cifar Dataset.

  Args:
    images: Input image tensor. It is transposed with perm [0, 3, 1, 2] when
      the CIFAR hparams select 'NCHW'. This suggests it arrives in NHWC order
      (TODO confirm against callers).
    num_classes: Passed through to `_build_nasnet_base`.
    is_training: Passed to `_cifar_config`. It is also applied through
      `arg_scope` to slim.dropout, nasnet_utils.drop_path and
      slim.batch_norm.

  Returns:
    The result of `_build_nasnet_base(..., stem_type='cifar')`.
  """
  hparams = _cifar_config(is_training=is_training)

  on_gpu_with_nhwc = (tf.test.is_gpu_available() and
                      hparams.data_format == 'NHWC')
  if on_gpu_with_nhwc:
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')
  if hparams.data_format == 'NCHW':
    # Reorder axes with perm [0, 3, 1, 2] (NHWC -> NCHW).
    images = tf.transpose(images, [0, 3, 1, 2])

  # Total cell count: hparams.num_cells plus the 2 reduction cells.
  total_num_cells = hparams.num_cells + 2
  # Both cell types are constructed from the same four arguments.
  cell_args = (hparams.num_conv_filters, hparams.drop_path_keep_prob,
               total_num_cells, hparams.total_training_steps)
  normal_cell = nasnet_utils.NasNetANormalCell(*cell_args)
  reduction_cell = nasnet_utils.NasNetAReductionCell(*cell_args)

  # Ops that receive `is_training` through arg_scope.
  training_mode_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
  # Ops that receive the hparams data_format through arg_scope.
  data_format_ops = [slim.avg_pool2d, slim.max_pool2d, slim.conv2d,
                     slim.batch_norm, slim.separable_conv2d,
                     nasnet_utils.factorized_reduction,
                     nasnet_utils.global_avg_pool,
                     nasnet_utils.get_channel_index,
                     nasnet_utils.get_channel_dim]
  with arg_scope(training_mode_ops, is_training=is_training), \
       arg_scope(data_format_ops, data_format=hparams.data_format):
    return _build_nasnet_base(images,
                              normal_cell=normal_cell,
                              reduction_cell=reduction_cell,
                              num_classes=num_classes,
                              hparams=hparams,
                              is_training=is_training,
                              stem_type='cifar')
Example #24
Source File: nasnet_utils_test.py From Creative-Adversarial-Networks with MIT License | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #25
Source File: nasnet_utils_test.py From SENet-tensorflow-slim with MIT License | 5 votes |
def testGlobalAvgPool(self):
  """Check that global_avg_pool keeps the right channel axis for each layout.

  Axis 1 (size 10) and axis 3 (size 30) deliberately have different sizes.
  global_avg_pool is expected to average away the two spatial axes and keep
  (batch, channels). The NHWC result must therefore keep axis 3 and the NCHW
  result must keep axis 1. If both axes had the same size, the two layouts
  would produce the same shape and a data_format bug would go unnoticed.
  """
  inputs = tf.placeholder(tf.float32, (5, 10, 20, 30))
  # (data_format, expected output shape) pairs for the same input tensor.
  cases = (
      ('NHWC', [5, 30]),  # channels on the last axis
      ('NCHW', [5, 10]),  # channels on axis 1
  )
  for data_format, expected_shape in cases:
    output = nasnet_utils.global_avg_pool(inputs, data_format)
    self.assertEqual(output.shape.as_list(), expected_shape, msg=data_format)
Example #26
Source File: pnasnet.py From models with Apache License 2.0 | 4 votes |
def build_pnasnet_large(images,
                        num_classes,
                        is_training=True,
                        final_endpoint=None,
                        config=None):
  """Build PNASNet Large model for the ImageNet Dataset.

  Args:
    images: Input image tensor. It is transposed with perm [0, 3, 1, 2] when
      the hparams select 'NCHW'. This suggests it arrives in NHWC order
      (TODO confirm against callers).
    num_classes: Passed through to `_build_pnasnet_base`.
    is_training: Passed to `nasnet._update_hparams`. It is also applied
      through `arg_scope` to slim.dropout, nasnet_utils.drop_path and
      slim.batch_norm.
    final_endpoint: Passed through to `_build_pnasnet_base`.
    config: Optional hparams. A truthy value is deep-copied; otherwise
      `large_imagenet_config()` supplies the hparams.

  Returns:
    The result of `_build_pnasnet_base`.
  """
  if config:
    hparams = copy.deepcopy(config)
  else:
    hparams = large_imagenet_config()
  # pylint: disable=protected-access
  nasnet._update_hparams(hparams, is_training)
  # pylint: enable=protected-access

  on_gpu_with_nhwc = (tf.test.is_gpu_available() and
                      hparams.data_format == 'NHWC')
  if on_gpu_with_nhwc:
    tf.logging.info(
        'A GPU is available on the machine, consider using NCHW '
        'data format for increased speed on GPU.')
  if hparams.data_format == 'NCHW':
    # Reorder axes with perm [0, 3, 1, 2] (NHWC -> NCHW).
    images = tf.transpose(a=images, perm=[0, 3, 1, 2])

  # PNAS does not distinguish reduction cells from normal cells, so the
  # total is the normal cells plus the stem cells (two by default).
  total_num_cells = hparams.num_cells + 2
  normal_cell = PNasNetNormalCell(hparams.num_conv_filters,
                                  hparams.drop_path_keep_prob,
                                  total_num_cells,
                                  hparams.total_training_steps,
                                  hparams.use_bounded_activation)

  # Ops that receive `is_training` through arg_scope.
  training_mode_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
  # Ops that receive the hparams data_format through arg_scope.
  data_format_ops = [slim.avg_pool2d, slim.max_pool2d, slim.conv2d,
                     slim.batch_norm, slim.separable_conv2d,
                     nasnet_utils.factorized_reduction,
                     nasnet_utils.global_avg_pool,
                     nasnet_utils.get_channel_index,
                     nasnet_utils.get_channel_dim]
  with arg_scope(training_mode_ops, is_training=is_training), \
       arg_scope(data_format_ops, data_format=hparams.data_format):
    return _build_pnasnet_base(images,
                               normal_cell=normal_cell,
                               num_classes=num_classes,
                               hparams=hparams,
                               is_training=is_training,
                               final_endpoint=final_endpoint)
Example #27
Source File: pnasnet.py From g-tensorflow-models with Apache License 2.0 | 4 votes |
def build_pnasnet_large(images,
                        num_classes,
                        is_training=True,
                        final_endpoint=None,
                        config=None):
  """Build PNASNet Large model for the ImageNet Dataset.

  Args:
    images: Input image tensor. It is transposed with perm [0, 3, 1, 2] when
      the hparams select 'NCHW'. This suggests it arrives in NHWC order
      (TODO confirm against callers).
    num_classes: Passed through to `_build_pnasnet_base`.
    is_training: Passed to `nasnet._update_hparams`. It is also applied
      through `arg_scope` to slim.dropout, nasnet_utils.drop_path and
      slim.batch_norm.
    final_endpoint: Passed through to `_build_pnasnet_base`.
    config: Optional hparams. A truthy value is deep-copied; otherwise
      `large_imagenet_config()` supplies the hparams.

  Returns:
    The result of `_build_pnasnet_base`.
  """
  if config:
    hparams = copy.deepcopy(config)
  else:
    hparams = large_imagenet_config()
  # pylint: disable=protected-access
  nasnet._update_hparams(hparams, is_training)
  # pylint: enable=protected-access

  on_gpu_with_nhwc = (tf.test.is_gpu_available() and
                      hparams.data_format == 'NHWC')
  if on_gpu_with_nhwc:
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')
  if hparams.data_format == 'NCHW':
    # Reorder axes with perm [0, 3, 1, 2] (NHWC -> NCHW).
    images = tf.transpose(images, [0, 3, 1, 2])

  # PNAS does not distinguish reduction cells from normal cells, so the
  # total is the normal cells plus the stem cells (two by default).
  total_num_cells = hparams.num_cells + 2
  normal_cell = PNasNetNormalCell(hparams.num_conv_filters,
                                  hparams.drop_path_keep_prob,
                                  total_num_cells,
                                  hparams.total_training_steps,
                                  hparams.use_bounded_activation)

  # Ops that receive `is_training` through arg_scope.
  training_mode_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
  # Ops that receive the hparams data_format through arg_scope.
  data_format_ops = [slim.avg_pool2d, slim.max_pool2d, slim.conv2d,
                     slim.batch_norm, slim.separable_conv2d,
                     nasnet_utils.factorized_reduction,
                     nasnet_utils.global_avg_pool,
                     nasnet_utils.get_channel_index,
                     nasnet_utils.get_channel_dim]
  with arg_scope(training_mode_ops, is_training=is_training), \
       arg_scope(data_format_ops, data_format=hparams.data_format):
    return _build_pnasnet_base(images,
                               normal_cell=normal_cell,
                               num_classes=num_classes,
                               hparams=hparams,
                               is_training=is_training,
                               final_endpoint=final_endpoint)
Example #28
Source File: nasnet.py From TwinGAN with Apache License 2.0 | 4 votes |
def build_nasnet_large(images,
                       num_classes,
                       is_training=True,
                       final_endpoint=None):
  """Build NASNet Large model for the ImageNet Dataset.

  Args:
    images: Input image tensor. It is transposed with perm [0, 3, 1, 2] when
      the ImageNet hparams select 'NCHW'. This suggests it arrives in NHWC
      order (TODO confirm against callers).
    num_classes: Passed through to `_build_nasnet_base`.
    is_training: Passed to `_large_imagenet_config`. It is also applied
      through `arg_scope` to slim.dropout, nasnet_utils.drop_path and
      slim.batch_norm.
    final_endpoint: Passed through to `_build_nasnet_base`.

  Returns:
    The result of `_build_nasnet_base(..., stem_type='imagenet')`.
  """
  hparams = _large_imagenet_config(is_training=is_training)

  on_gpu_with_nhwc = (tf.test.is_gpu_available() and
                      hparams.data_format == 'NHWC')
  if on_gpu_with_nhwc:
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')
  if hparams.data_format == 'NCHW':
    # Reorder axes with perm [0, 3, 1, 2] (NHWC -> NCHW).
    images = tf.transpose(images, [0, 3, 1, 2])

  # Total cell count: hparams.num_cells, plus 2 reduction cells, plus 2 more
  # for the ImageNet stem cells.
  total_num_cells = hparams.num_cells + 2 + 2
  # Both cell types are constructed from the same four arguments.
  cell_args = (hparams.num_conv_filters, hparams.drop_path_keep_prob,
               total_num_cells, hparams.total_training_steps)
  normal_cell = nasnet_utils.NasNetANormalCell(*cell_args)
  reduction_cell = nasnet_utils.NasNetAReductionCell(*cell_args)

  # Ops that receive `is_training` through arg_scope.
  training_mode_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
  # Ops that receive the hparams data_format through arg_scope.
  data_format_ops = [slim.avg_pool2d, slim.max_pool2d, slim.conv2d,
                     slim.batch_norm, slim.separable_conv2d,
                     nasnet_utils.factorized_reduction,
                     nasnet_utils.global_avg_pool,
                     nasnet_utils.get_channel_index,
                     nasnet_utils.get_channel_dim]
  with arg_scope(training_mode_ops, is_training=is_training), \
       arg_scope(data_format_ops, data_format=hparams.data_format):
    return _build_nasnet_base(images,
                              normal_cell=normal_cell,
                              reduction_cell=reduction_cell,
                              num_classes=num_classes,
                              hparams=hparams,
                              is_training=is_training,
                              stem_type='imagenet',
                              final_endpoint=final_endpoint)
Example #29
Source File: pnasnet.py From g-tensorflow-models with Apache License 2.0 | 4 votes |
def build_pnasnet_mobile(images,
                         num_classes,
                         is_training=True,
                         final_endpoint=None,
                         config=None):
  """Build PNASNet Mobile model for the ImageNet Dataset.

  Args:
    images: Input image tensor. It is transposed with perm [0, 3, 1, 2] when
      the hparams select 'NCHW'. This suggests it arrives in NHWC order
      (TODO confirm against callers).
    num_classes: Passed through to `_build_pnasnet_base`.
    is_training: Passed to `nasnet._update_hparams`. It is also applied
      through `arg_scope` to slim.dropout, nasnet_utils.drop_path and
      slim.batch_norm.
    final_endpoint: Passed through to `_build_pnasnet_base`.
    config: Optional hparams. A truthy value is deep-copied; otherwise
      `mobile_imagenet_config()` supplies the hparams.

  Returns:
    The result of `_build_pnasnet_base`.
  """
  if config:
    hparams = copy.deepcopy(config)
  else:
    hparams = mobile_imagenet_config()
  # pylint: disable=protected-access
  nasnet._update_hparams(hparams, is_training)
  # pylint: enable=protected-access

  on_gpu_with_nhwc = (tf.test.is_gpu_available() and
                      hparams.data_format == 'NHWC')
  if on_gpu_with_nhwc:
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')
  if hparams.data_format == 'NCHW':
    # Reorder axes with perm [0, 3, 1, 2] (NHWC -> NCHW).
    images = tf.transpose(images, [0, 3, 1, 2])

  # PNAS does not distinguish reduction cells from normal cells, so the
  # total is the normal cells plus the stem cells (two by default).
  total_num_cells = hparams.num_cells + 2
  normal_cell = PNasNetNormalCell(hparams.num_conv_filters,
                                  hparams.drop_path_keep_prob,
                                  total_num_cells,
                                  hparams.total_training_steps,
                                  hparams.use_bounded_activation)

  # Ops that receive `is_training` through arg_scope.
  training_mode_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
  # Ops that receive the hparams data_format through arg_scope.
  data_format_ops = [slim.avg_pool2d, slim.max_pool2d, slim.conv2d,
                     slim.batch_norm, slim.separable_conv2d,
                     nasnet_utils.factorized_reduction,
                     nasnet_utils.global_avg_pool,
                     nasnet_utils.get_channel_index,
                     nasnet_utils.get_channel_dim]
  with arg_scope(training_mode_ops, is_training=is_training), \
       arg_scope(data_format_ops, data_format=hparams.data_format):
    return _build_pnasnet_base(images,
                               normal_cell=normal_cell,
                               num_classes=num_classes,
                               hparams=hparams,
                               is_training=is_training,
                               final_endpoint=final_endpoint)
Example #30
Source File: nasnet.py From g-tensorflow-models with Apache License 2.0 | 4 votes |
def build_nasnet_cifar(images,
                       num_classes,
                       is_training=True,
                       config=None,
                       current_step=None):
  """Build NASNet model for the Cifar Dataset.

  Args:
    images: Input image tensor. It is transposed with perm [0, 3, 1, 2] when
      the hparams select 'NCHW'. This suggests it arrives in NHWC order
      (TODO confirm against callers).
    num_classes: Passed through to `_build_nasnet_base`.
    is_training: Passed to `_update_hparams`. It is also applied through
      `arg_scope` to slim.dropout, nasnet_utils.drop_path and
      slim.batch_norm.
    config: Optional hparams. When given it is deep-copied; when None,
      `cifar_config()` supplies the hparams.
    current_step: Passed through to `_build_nasnet_base`.

  Returns:
    The result of `_build_nasnet_base(..., stem_type='cifar')`.
  """
  if config is None:
    hparams = cifar_config()
  else:
    # Deep copy, presumably so that _update_hparams below cannot mutate the
    # caller's config object.
    hparams = copy.deepcopy(config)
  _update_hparams(hparams, is_training)

  on_gpu_with_nhwc = (tf.test.is_gpu_available() and
                      hparams.data_format == 'NHWC')
  if on_gpu_with_nhwc:
    tf.logging.info('A GPU is available on the machine, consider using NCHW '
                    'data format for increased speed on GPU.')
  if hparams.data_format == 'NCHW':
    # Reorder axes with perm [0, 3, 1, 2] (NHWC -> NCHW).
    images = tf.transpose(images, [0, 3, 1, 2])

  # Total cell count: hparams.num_cells plus the 2 reduction cells.
  total_num_cells = hparams.num_cells + 2
  # Both cell types are constructed from the same five arguments.
  cell_args = (hparams.num_conv_filters, hparams.drop_path_keep_prob,
               total_num_cells, hparams.total_training_steps,
               hparams.use_bounded_activation)
  normal_cell = nasnet_utils.NasNetANormalCell(*cell_args)
  reduction_cell = nasnet_utils.NasNetAReductionCell(*cell_args)

  # Ops that receive `is_training` through arg_scope.
  training_mode_ops = [slim.dropout, nasnet_utils.drop_path, slim.batch_norm]
  # Ops that receive the hparams data_format through arg_scope.
  data_format_ops = [slim.avg_pool2d, slim.max_pool2d, slim.conv2d,
                     slim.batch_norm, slim.separable_conv2d,
                     nasnet_utils.factorized_reduction,
                     nasnet_utils.global_avg_pool,
                     nasnet_utils.get_channel_index,
                     nasnet_utils.get_channel_dim]
  with arg_scope(training_mode_ops, is_training=is_training), \
       arg_scope(data_format_ops, data_format=hparams.data_format):
    return _build_nasnet_base(images,
                              normal_cell=normal_cell,
                              reduction_cell=reduction_cell,
                              num_classes=num_classes,
                              hparams=hparams,
                              is_training=is_training,
                              stem_type='cifar',
                              current_step=current_step)