Python config.BATCH_SIZE Examples
The following are 20 code examples of config.BATCH_SIZE, collected from open-source projects. You can go to the original project or source file by following the link above each example. You may also want to check out the other available functions and classes of the config module.
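Note that BATCH_SIZE is a plain attribute of a configuration module or object (imported as config or cfg in the examples below), not a function. As a minimal sketch, a hypothetical config.py backing these examples might look like the following; the values (and any name beyond BATCH_SIZE, IMAGE_SIZE, and CACHE_PATH, which do appear below) are illustrative assumptions, not taken from any of the projects:

# config.py -- hypothetical minimal configuration module (illustrative values only)
BATCH_SIZE = 64          # number of samples processed per training step
IMAGE_SIZE = 224         # input image side length in pixels
CACHE_PATH = './cache'   # directory for cached, preprocessed data

# typical usage, as seen throughout the examples below:
#   import config
#   num_batches = num_samples // config.BATCH_SIZE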
Example #1
Source File: datalayer.py From tripletloss with MIT License
def setup(self, bottom, top):
    """Setup the RoIDataLayer."""
    # parse the layer parameter string, which must be valid YAML
    layer_params = yaml.load(self.param_str_)
    self._batch_size = config.BATCH_SIZE
    self._triplet = self._batch_size // 3  # integer division: one triplet per 3 images
    assert self._batch_size % 3 == 0
    self._name_to_top_map = {
        'data': 0,
        'labels': 1}
    self.data_container = sampledata()
    self._index = 0
    # data blob: holds a batch of N images, each with 3 channels
    # of height and width 224 x 224
    top[0].reshape(self._batch_size, 3, 224, 224)
    top[1].reshape(self._batch_size)
Example #2
Source File: data_layer.py From triplet with MIT License
def setup(self, bottom, top):
    """Setup the DataLayer."""
    if cfg.TRIPLET_LOSS:
        self.batch_size = cfg.TRIPLET_BATCH_SIZE
    else:
        self.batch_size = cfg.BATCH_SIZE
    self._name_to_top_map = {
        'data': 0,
        'labels': 1}
    self._index = 0
    self._epoch = 1
    # data blob: holds a batch of N images, each with 3 channels
    # of height and width 224 x 224
    top[0].reshape(self.batch_size, 3, 224, 224)
    top[1].reshape(self.batch_size)
Example #3
Source File: TF_flowers.py From tensorflow_yolo2 with MIT License
def __init__(self, val_split, rebuild=False, data_aug=False):
    self.name = 'TF_flowers'
    self.devkit_path = cfg.FLOWERS_PATH
    self.data_path = self.devkit_path
    self.cache_path = cfg.CACHE_PATH
    self.batch_size = cfg.BATCH_SIZE
    self.image_size = cfg.IMAGE_SIZE
    self.rebuild = rebuild
    self.data_aug = data_aug
    self.num_class = 5
    self.classes = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    self.class_to_ind = dict(
        list(zip(self.classes, list(range(self.num_class)))))
    self.train_cursor = 0
    self.val_cursor = 0
    self.epoch = 1
    self.gt_labels = None
    self.val_split = val_split
    assert os.path.exists(self.devkit_path), \
        'TF_flowers path does not exist: {}'.format(self.devkit_path)
    assert os.path.exists(self.data_path), \
        'Path does not exist: {}'.format(self.data_path)
    self.prepare()
Example #4
Source File: ilsvrc2017_cls.py From tensorflow_yolo2 with MIT License
def __init__(self, image_set, rebuild=False, data_aug=True):
    self.name = 'ilsvrc_2017'
    self.devkit_path = cfg.ILSVRC_PATH
    self.data_path = self.devkit_path
    self.cache_path = cfg.CACHE_PATH
    self.batch_size = cfg.BATCH_SIZE
    self.image_size = cfg.IMAGE_SIZE
    self.image_set = image_set
    self.rebuild = rebuild
    self.data_aug = data_aug
    self.cursor = 0
    self.load_classes()
    # self.gt_labels = None
    assert os.path.exists(self.devkit_path), \
        'ILSVRC path does not exist: {}'.format(self.devkit_path)
    assert os.path.exists(self.data_path), \
        'Path does not exist: {}'.format(self.data_path)
    self.prepare()
Example #5
Source File: data.py From PINTO_model_zoo with MIT License
def __init__(self):
    """
    Always keep in mind: small_detector corresponds to index 0,
    medium_detector to index 1, and big_detector to index 2.
    :param dataset_type: whether to load training or test samples; must be 'train' or 'test'
    """
    self.__dataset_path = cfg.DATASET_PATH
    self.__train_input_sizes = cfg.TRAIN_INPUT_SIZES
    self.__strides = np.array(cfg.STRIDES)
    self.__batch_size = cfg.BATCH_SIZE
    self.__classes = cfg.CLASSES
    self.__num_classes = len(self.__classes)
    self.__gt_per_grid = cfg.GT_PER_GRID
    self.__class_to_ind = dict(zip(self.__classes, range(self.__num_classes)))
    annotations_2007 = self.__load_annotations(os.path.join(self.__dataset_path, '2007_trainval'))
    annotations_2012 = self.__load_annotations(os.path.join(self.__dataset_path, '2012_trainval'))
    self.__annotations = annotations_2007 + annotations_2012
    self.__num_samples = len(self.__annotations)
    logging.info(('The number of images for training is:').ljust(50) + str(self.__num_samples))
    self.__num_batchs = np.ceil(self.__num_samples / self.__batch_size)
    self.__batch_count = 0
Example #6
Source File: Train.py From LeNet with MIT License
def main():
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    sess = tf.Session()
    batch_size = cfg.BATCH_SIZE
    parameter_path = cfg.PARAMETER_FILE
    lenet = Lenet()
    max_iter = cfg.MAX_ITER

    saver = tf.train.Saver()
    if os.path.exists(parameter_path):
        # tf.train.Saver.restore takes the session as its first argument
        saver.restore(sess, parameter_path)
    else:
        sess.run(tf.initialize_all_variables())

    for i in range(max_iter):
        # draw the next training batch of cfg.BATCH_SIZE samples
        batch = mnist.train.next_batch(batch_size)
        if i % 100 == 0:
            train_accuracy = sess.run(lenet.train_accuracy, feed_dict={
                lenet.raw_input_image: batch[0],
                lenet.raw_input_label: batch[1]})
            print("step %d, training accuracy %g" % (i, train_accuracy))
        sess.run(lenet.train_op, feed_dict={lenet.raw_input_image: batch[0],
                                            lenet.raw_input_label: batch[1]})
    save_path = saver.save(sess, parameter_path)
Example #7
Source File: chatbot.py From stanford-tensorflow-tutorials with MIT License
def train():
    """ Train the bot """
    test_buckets, data_buckets, train_buckets_scale = _get_buckets()
    # in train mode, we need to create the backward path, so forward_only is False
    model = ChatBotModel(False, config.BATCH_SIZE)
    model.build_graph()

    saver = tf.train.Saver()

    with tf.Session() as sess:
        print('Running session')
        sess.run(tf.global_variables_initializer())
        _check_restore_parameters(sess, saver)

        iteration = model.global_step.eval()
        total_loss = 0
        while True:
            skip_step = _get_skip_step(iteration)
            bucket_id = _get_random_bucket(train_buckets_scale)
            encoder_inputs, decoder_inputs, decoder_masks = data.get_batch(data_buckets[bucket_id],
                                                                           bucket_id,
                                                                           batch_size=config.BATCH_SIZE)
            start = time.time()
            _, step_loss, _ = run_step(sess, model, encoder_inputs, decoder_inputs,
                                       decoder_masks, bucket_id, False)
            total_loss += step_loss
            iteration += 1

            if iteration % skip_step == 0:
                print('Iter {}: loss {}, time {}'.format(iteration, total_loss / skip_step,
                                                         time.time() - start))
                start = time.time()
                total_loss = 0
                saver.save(sess, os.path.join(config.CPT_PATH, 'chatbot'),
                           global_step=model.global_step)
                if iteration % (10 * skip_step) == 0:
                    # Run evals on development set and print their loss
                    _eval_test_set(sess, model, test_buckets)
                    start = time.time()
                sys.stdout.flush()
Example #8
Source File: tripletselectlayer.py From tripletloss with MIT License
def setup(self, bottom, top):
    """Setup the TripletSelectLayer."""
    # each triplet consumes three images from the batch
    self.triplet = config.BATCH_SIZE // 3
    top[0].reshape(self.triplet, shape(bottom[0].data)[1])
    top[1].reshape(self.triplet, shape(bottom[0].data)[1])
    top[2].reshape(self.triplet, shape(bottom[0].data)[1])
Example #9
Source File: chatbot.py From stanford-tensorflow-tutorials with MIT License
def _eval_test_set(sess, model, test_buckets):
    """ Evaluate on the test set. """
    for bucket_id in range(len(config.BUCKETS)):
        if len(test_buckets[bucket_id]) == 0:
            print("  Test: empty bucket %d" % (bucket_id))
            continue
        start = time.time()
        encoder_inputs, decoder_inputs, decoder_masks = data.get_batch(test_buckets[bucket_id],
                                                                       bucket_id,
                                                                       batch_size=config.BATCH_SIZE)
        _, step_loss, _ = run_step(sess, model, encoder_inputs, decoder_inputs,
                                   decoder_masks, bucket_id, True)
        print('Test bucket {}: loss {}, time {}'.format(bucket_id, step_loss,
                                                        time.time() - start))
Example #10
Source File: model.py From MAX-Sports-Video-Classifier with Apache License 2.0
def process_frames(dirname, means, batch_size=BATCH_SIZE,
                   num_frames_per_clip=NUM_FRAMES_PER_CLIP, crop_size=CROP_SIZE):
    tmp_data, _ = get_frames_data(dirname, num_frames_per_clip)
    img_datas = []
    data = []
    if len(tmp_data) != 0:
        for j in xrange(len(tmp_data)):
            img = Image.fromarray(tmp_data[j].astype(np.uint8))
            if img.width > img.height:
                scale = float(crop_size) / float(img.height)
                img = np.array(cv2.resize(np.array(img),
                                          (int(img.width * scale + 1), crop_size))).astype(np.float32)
            else:
                scale = float(crop_size) / float(img.width)
                img = np.array(cv2.resize(np.array(img),
                                          (crop_size, int(img.height * scale + 1)))).astype(np.float32)
            crop_x = int((img.shape[0] - crop_size) / 2)
            crop_y = int((img.shape[1] - crop_size) / 2)
            img = img[crop_x:crop_x + crop_size, crop_y:crop_y + crop_size, :] - means[j]
            img_datas.append(img)
        data.append(img_datas)

    # pad (duplicate) data/label if less than batch_size
    valid_len = len(data)
    pad_len = batch_size - valid_len
    if pad_len:
        for i in range(pad_len):
            data.append(img_datas)

    np_arr_data = np.array(data).astype(np.float32)
    return np_arr_data
Example #11
Source File: utils.py From CapsNet with MIT License
def get_iterator(mode):
    dataset = MNIST(root='./data', train=mode, download=True)
    data = getattr(dataset, 'train_data' if mode else 'test_data')
    labels = getattr(dataset, 'train_labels' if mode else 'test_labels')
    tensor_dataset = tnt.dataset.TensorDataset([data, labels])
    return tensor_dataset.parallel(batch_size=config.BATCH_SIZE, num_workers=4, shuffle=mode)
Example #12
Source File: main.py From CapsNet with MIT License
def on_end_epoch(state):
    print('[Epoch %d] Training Loss: %.4f (Accuracy: %.2f%%)' % (
        state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0]))

    train_loss_logger.log(state['epoch'], meter_loss.value()[0])
    train_accuracy_logger.log(state['epoch'], meter_accuracy.value()[0])

    reset_meters()

    engine.test(processor, utils.get_iterator(False))
    test_loss_logger.log(state['epoch'], meter_loss.value()[0])
    test_accuracy_logger.log(state['epoch'], meter_accuracy.value()[0])
    confusion_logger.log(confusion_meter.value())

    print('[Epoch %d] Testing Loss: %.4f (Accuracy: %.2f%%)' % (
        state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0]))

    torch.save(model.state_dict(), 'epochs/epoch_%d.pt' % state['epoch'])

    # reconstruction visualization
    test_sample = next(iter(utils.get_iterator(False)))
    ground_truth = (test_sample[0].unsqueeze(1).float() / 255.0)
    if torch.cuda.is_available():
        _, reconstructions = model(Variable(ground_truth).cuda())
    else:
        _, reconstructions = model(Variable(ground_truth))
    reconstruction = reconstructions.cpu().view_as(ground_truth).data

    ground_truth_logger.log(
        make_grid(ground_truth, nrow=int(config.BATCH_SIZE ** 0.5),
                  normalize=True, range=(0, 1)).numpy())
    reconstruction_logger.log(
        make_grid(reconstruction, nrow=int(config.BATCH_SIZE ** 0.5),
                  normalize=True, range=(0, 1)).numpy())
Example #13
Source File: ilsvrc_cls_multithread_scipy.py From tensorflow_yolo2 with MIT License
def __init__(self, image_set, rebuild=False, multithread=False,
             batch_size=cfg.BATCH_SIZE, image_size=cfg.IMAGE_SIZE,
             random_noise=False):
    self.name = 'ilsvrc_2017_cls'
    self.devkit_path = cfg.ILSVRC_PATH
    self.data_path = self.devkit_path
    self.cache_path = cfg.CACHE_PATH
    self.batch_size = batch_size
    self.image_size = image_size
    self.image_set = image_set
    self.rebuild = rebuild
    self.multithread = multithread
    self.random_noise = random_noise
    self.load_classes()
    self.cursor = 0
    self.epoch = 1
    self.gt_labels = None
    assert os.path.exists(self.devkit_path), \
        'ILSVRC path does not exist: {}'.format(self.devkit_path)
    assert os.path.exists(self.data_path), \
        'Path does not exist: {}'.format(self.data_path)
    self.prepare()
    if self.multithread:
        self.prepare_multithread()
        self.get = self._get_multithread
    else:
        self.get = self._get
Example #14
Source File: pascal_voc.py From tensorflow_yolo2 with MIT License
def __init__(self, image_set, batch_size=cfg.BATCH_SIZE, rebuild=False):
    self.name = 'voc_2007'
    self.devkit_path = cfg.PASCAL_PATH
    self.data_path = os.path.join(self.devkit_path, 'VOC2007')
    self.cache_path = cfg.CACHE_PATH
    self.batch_size = batch_size
    self.image_size = cfg.IMAGE_SIZE
    self.cell_size = cfg.S
    self.classes = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
                    'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
                    'dog', 'horse', 'motorbike', 'person', 'pottedplant',
                    'sheep', 'sofa', 'train', 'tvmonitor')
    self.num_class = len(self.classes)
    self.class_to_ind = dict(
        list(zip(self.classes, list(range(self.num_class)))))
    self.flipped = cfg.FLIPPED
    self.image_set = image_set
    self.rebuild = rebuild
    self.cursor = 0
    self.gt_labels = None
    assert os.path.exists(self.devkit_path), \
        'VOCdevkit path does not exist: {}'.format(self.devkit_path)
    assert os.path.exists(self.data_path), \
        'Path does not exist: {}'.format(self.data_path)
    self.prepare()
Example #15
Source File: ilsvrc2017_cls_multithread.py From tensorflow_yolo2 with MIT License
def __init__(self, image_set, rebuild=False, data_aug=False, multithread=False,
             batch_size=cfg.BATCH_SIZE, image_size=cfg.IMAGE_SIZE, RGB=False):
    self.name = 'ilsvrc_2017_cls'
    self.devkit_path = cfg.ILSVRC_PATH
    self.data_path = self.devkit_path
    self.cache_path = cfg.CACHE_PATH
    self.batch_size = batch_size
    self.image_size = image_size
    self.image_set = image_set
    self.rebuild = rebuild
    self.multithread = multithread
    self.data_aug = data_aug
    self.RGB = RGB
    self.load_classes()
    self.cursor = 0
    self.epoch = 1
    self.gt_labels = None
    assert os.path.exists(self.devkit_path), \
        'ILSVRC path does not exist: {}'.format(self.devkit_path)
    assert os.path.exists(self.data_path), \
        'Path does not exist: {}'.format(self.data_path)
    self.prepare()
    if self.multithread:
        self.prepare_multithread()
        self.get = self._get_multithread
    else:
        self.get = self._get
Example #16
Source File: model_rpointnet.py From GSPN with MIT License
def placeholder_inputs(config):
    pc_pl = tf.placeholder(tf.float32, shape=(config.BATCH_SIZE, config.NUM_POINT, 3))
    color_pl = tf.placeholder(tf.float32, shape=(config.BATCH_SIZE, config.NUM_POINT, 3))
    pc_ins_pl = tf.placeholder(tf.float32,
                               shape=(config.BATCH_SIZE, config.NUM_GROUP, config.NUM_POINT_INS, 3))
    group_label_pl = tf.placeholder(tf.int32, shape=(config.BATCH_SIZE, config.NUM_POINT))
    group_indicator_pl = tf.placeholder(tf.int32, shape=(config.BATCH_SIZE, config.NUM_GROUP))
    seg_label_pl = tf.placeholder(tf.int32, shape=(config.BATCH_SIZE, config.NUM_POINT))
    bbox_ins_pl = tf.placeholder(tf.float32, shape=(config.BATCH_SIZE, config.NUM_GROUP, 6))
    return pc_pl, color_pl, pc_ins_pl, group_label_pl, group_indicator_pl, seg_label_pl, bbox_ins_pl
Example #17
Source File: prepare_data.py From TensorFlow2.0_ResNet with MIT License
def generate_datasets():
    train_dataset, train_count = get_dataset(dataset_root_dir=config.train_dir)
    valid_dataset, valid_count = get_dataset(dataset_root_dir=config.valid_dir)
    test_dataset, test_count = get_dataset(dataset_root_dir=config.test_dir)

    # read the original dataset in the form of batches
    train_dataset = train_dataset.shuffle(buffer_size=train_count).batch(batch_size=config.BATCH_SIZE)
    valid_dataset = valid_dataset.batch(batch_size=config.BATCH_SIZE)
    test_dataset = test_dataset.batch(batch_size=config.BATCH_SIZE)

    return train_dataset, valid_dataset, test_dataset, train_count, valid_count, test_count
Example #18
Source File: batch_generator.py From BirdCLEF-Baseline with MIT License
def getNextImageBatch(split, augmentation=True):
    # fill batch
    for chunk in getDatasetChunk(split):
        # allocate numpy arrays for image data and targets
        x_b = np.zeros((cfg.BATCH_SIZE, cfg.IM_DIM, cfg.IM_SIZE[1], cfg.IM_SIZE[0]), dtype='float32')
        y_b = np.zeros((cfg.BATCH_SIZE, len(cfg.CLASSES)), dtype='float32')

        ib = 0
        for sample in chunk:
            try:
                # load image data and class label from path
                x, y = loadImageAndTarget(sample, augmentation)

                # pack into batch array
                x_b[ib] = x
                y_b[ib] = y
                ib += 1
            except:
                continue

        # trim to actual size
        x_b = x_b[:ib]
        y_b = y_b[:ib]

        # instead of return, we use yield
        yield x_b, y_b

# Loading images with CPU background threads during GPU forward passes saves a lot of time
# Credit: J. Schlüter (https://github.com/Lasagne/Lasagne/issues/12)
Example #19
Source File: batch_generator.py From BirdCLEF-Baseline with MIT License
def getDatasetChunk(split):
    # get batch-sized chunks of image paths
    for i in xrange(0, len(split), cfg.BATCH_SIZE):
        yield split[i:i + cfg.BATCH_SIZE]
Example #20
Source File: stats.py From BirdCLEF-Baseline with MIT License
def showProgress(epoch, done=False):
    global last_update

    # First call?
    if 'batch_count' not in cfg.STATS:
        bcnt = 0
    else:
        bcnt = cfg.STATS['batch_count']

    # Calculate number of batches to train
    total_batches = cfg.STATS['sample_count'] // cfg.BATCH_SIZE + 1

    # Current progress
    if not done:
        if bcnt == 0:
            log.p(('EPOCH', epoch, '['), new_line=False)
        else:
            p = bcnt * 100 / total_batches
            if not p % 5 and not p == last_update:
                log.p('=', new_line=False)
                last_update = p
    else:
        log.p(']', new_line=False)