Python provider.shift_point_cloud() Examples
The following are 4 code examples of provider.shift_point_cloud().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module provider, or try the search function.
Example #1
Source File: train_cls.py From dgl with Apache License 2.0 | 5 votes |
def train(net, opt, scheduler, train_loader, dev):
    """Train ``net`` for one pass over ``train_loader``.

    Each batch is converted to NumPy, augmented through the ``provider``
    helpers (point dropout on the whole array, then random scale, jitter
    and shift on channels ``0:3`` of the last axis -- presumably the xyz
    coordinates), converted back to a tensor and used for one optimizer
    step.  Running averages of loss and accuracy are shown on the tqdm
    progress bar.  ``scheduler.step()`` is called once after all batches.

    Args:
        net: model to optimise; called as ``net(data)`` and expected to
            return per-class logits (``logits.max(1)`` is taken).
        opt: optimizer providing ``zero_grad()`` and ``step()``.
        scheduler: learning-rate scheduler providing ``step()``.
        train_loader: iterable yielding ``(data, label)`` tensor pairs.
        dev: device the batch is moved to before the forward pass.
    """
    net.train()

    criterion = nn.CrossEntropyLoss()
    running_loss = 0
    running_correct = 0
    batches_done = 0
    samples_seen = 0

    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label in tq:
            # The provider augmentations operate on NumPy arrays.
            points = provider.random_point_dropout(data.data.numpy())
            for augment in (provider.random_scale_point_cloud,
                            provider.jitter_point_cloud,
                            provider.shift_point_cloud):
                # Only the first three channels are augmented; any further
                # channels are left as they are.
                points[:, :, 0:3] = augment(points[:, :, 0:3])

            inputs = torch.tensor(points)
            label = label[:, 0]
            batch_size = label.shape[0]
            inputs = inputs.to(dev)
            targets = label.to(dev).squeeze().long()

            opt.zero_grad()
            logits = net(inputs)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            opt.step()

            preds = logits.max(1)[1]

            batches_done += 1
            samples_seen += batch_size
            running_loss += batch_loss.item()
            running_correct += (preds == targets).sum().item()

            tq.set_postfix({
                'AvgLoss': '%.5f' % (running_loss / batches_done),
                'AvgAcc': '%.5f' % (running_correct / samples_seen)})

    scheduler.step()
Example #2
Source File: modelnet_h5_dataset.py From dfc2019 with MIT License | 5 votes |
def _augment_batch_data(self, batch_data):
    """Augment a batch of point clouds and shuffle its points.

    The batch is passed through ``provider.rotate_point_cloud`` and
    ``provider.rotate_perturbation_point_cloud``.  Channels ``0:3`` of the
    last axis (presumably xyz coordinates) are then run through random
    scaling, shifting and jittering, in that order, and written back.

    Args:
        batch_data: array indexed as ``[:, :, 0:3]``, i.e. at least three
            axes -- presumably (batch, points, channels).

    Returns:
        The result of ``provider.shuffle_points`` on the augmented batch.
    """
    augmented = provider.rotate_point_cloud(batch_data)
    augmented = provider.rotate_perturbation_point_cloud(augmented)

    coords = augmented[:, :, 0:3]
    for transform in (provider.random_scale_point_cloud,
                      provider.shift_point_cloud,
                      provider.jitter_point_cloud):
        coords = transform(coords)
    augmented[:, :, 0:3] = coords

    return provider.shuffle_points(augmented)
Example #3
Source File: modelnet_dataset.py From dfc2019 with MIT License | 5 votes |
def _augment_batch_data(self, batch_data):
    """Augment a batch of point clouds and shuffle its points.

    When ``self.normal_channel`` is truthy the ``*_with_normal`` rotation
    helpers from ``provider`` are used; otherwise the plain rotation
    helpers are.  Channels ``0:3`` of the last axis (presumably xyz
    coordinates) are then randomly scaled, shifted and jittered, in that
    order, and written back into the rotated batch.

    Args:
        batch_data: array indexed as ``[:, :, 0:3]``, i.e. at least three
            axes -- presumably (batch, points, channels).

    Returns:
        The result of ``provider.shuffle_points`` on the augmented batch.
    """
    if self.normal_channel:
        augmented = provider.rotate_perturbation_point_cloud_with_normal(
            provider.rotate_point_cloud_with_normal(batch_data)
        )
    else:
        augmented = provider.rotate_perturbation_point_cloud(
            provider.rotate_point_cloud(batch_data)
        )

    # Scale -> shift -> jitter touches only the first three channels.
    xyz = provider.random_scale_point_cloud(augmented[:, :, 0:3])
    xyz = provider.shift_point_cloud(xyz)
    augmented[:, :, 0:3] = provider.jitter_point_cloud(xyz)

    return provider.shuffle_points(augmented)
Example #4
Source File: train.py From ldgcnn with MIT License | 4 votes |
def train_one_epoch(sess, ops, train_writer):
    """Run one training epoch over every file in ``TRAIN_FILES``.

    Files are visited in a freshly shuffled order.  For each file the data
    is truncated to ``NUM_POINT`` points, shuffled, and split into
    ``file_size // BATCH_SIZE`` full batches (a trailing partial batch is
    not used).  Every batch is augmented via ``provider`` and fed through
    ``sess.run``; the merged summary is written to ``train_writer``.  Mean
    loss and accuracy are logged for each file.

    Args:
        sess: session object whose ``run`` executes the fetched ops.
        ops: dict mapping from string to tf ops; the keys read here are
            'pointclouds_pl', 'labels_pl', 'is_training_pl', 'merged',
            'step', 'train_op', 'loss' and 'pred'.
        train_writer: summary writer providing ``add_summary``.
    """
    is_training = True

    # Visit the training files in a random order.
    file_order = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(file_order)

    for fn, file_idx in enumerate(file_order):
        log_string('----' + str(fn) + '-----')

        # Load one file, keep the first NUM_POINT points, then shuffle.
        current_data, current_label = provider.loadDataFile(TRAIN_FILES[file_idx])
        current_data = current_data[:, 0:NUM_POINT, :]
        current_data, current_label, _ = provider.shuffle_data(
            current_data, np.squeeze(current_label))
        current_label = np.squeeze(current_label)

        num_batches = current_data.shape[0] // BATCH_SIZE

        total_correct = 0
        total_seen = 0
        loss_sum = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = start_idx + BATCH_SIZE
            batch_labels = current_label[start_idx:end_idx]

            # Augmentation order: rotate, jitter, scale, perturb, shift.
            batch_points = provider.rotate_point_cloud(
                current_data[start_idx:end_idx, :, :])
            batch_points = provider.jitter_point_cloud(batch_points)
            batch_points = provider.random_scale_point_cloud(batch_points)
            batch_points = provider.rotate_perturbation_point_cloud(batch_points)
            batch_points = provider.shift_point_cloud(batch_points)

            feed_dict = {
                ops['pointclouds_pl']: batch_points,
                ops['labels_pl']: batch_labels,
                ops['is_training_pl']: is_training,
            }

            # One optimisation step; also fetch summary, loss and predictions.
            fetches = [ops['merged'], ops['step'], ops['train_op'],
                       ops['loss'], ops['pred']]
            summary, step, _, loss_val, pred_val = sess.run(
                fetches, feed_dict=feed_dict)
            train_writer.add_summary(summary, step)

            predicted = np.argmax(pred_val, 1)
            total_correct += np.sum(predicted == batch_labels)
            total_seen += BATCH_SIZE
            loss_sum += loss_val

        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))