Python torch.eq() Examples
The following are 30 code examples of torch.eq().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: test_context_conditioned_policy.py From garage with MIT License | 6 votes |
def test_update_context(self):
    """Test update_context.

    Feeds the same TimeStep to ``self.module.update_context`` ten times and
    expects ``self.module.context`` to equal a (10 x encoder_input_dim)
    tensor of ones.
    """
    n_updates = 10
    step = TimeStep(env_spec=self.env_spec,
                    observation=np.ones(self.obs_dim),
                    next_observation=np.ones(self.obs_dim),
                    action=np.ones(self.action_dim),
                    reward=1.0,
                    terminal=False,
                    env_info={},
                    agent_info={})
    for _ in range(n_updates):
        self.module.update_context(step)
    expected_context = torch.ones(n_updates, self.encoder_input_dim)
    assert torch.all(torch.eq(self.module.context, expected_context))
Example #2
Source File: test_wrappers.py From torchbearer with MIT License | 6 votes |
def test_train(self):
    """In train mode, processing ``self._states`` must produce exactly two calls
    whose arguments match ``calls``; ``process_final`` must then produce one call
    with the expected 5-element tensors."""
    self._metric.train()
    calls = [[torch.FloatTensor([0.0]), torch.LongTensor([0])],
             [torch.FloatTensor([0.0, 0.1, 0.2, 0.3]), torch.LongTensor([0, 1, 2, 3])]]
    for state in self._states:
        self._metric.process(state)
    self.assertEqual(2, len(self._metric_function.call_args_list))
    # `.all` must be *called*: the bound method object itself is always truthy,
    # so `assertTrue(tensor.all)` could never fail and these checks were no-ops.
    for recorded, (expected_0, expected_1) in zip(self._metric_function.call_args_list, calls):
        args = recorded[0]
        self.assertTrue(torch.eq(args[0], expected_0).all())
        self.assertTrue(torch.lt(torch.abs(torch.add(args[1], -expected_1)), 1e-12).all())
    self._metric_function.reset_mock()

    self._metric.process_final({})
    self.assertEqual(self._metric_function.call_count, 1)
    final_args = self._metric_function.call_args_list[0][0]
    self.assertTrue(torch.eq(final_args[1], torch.LongTensor([0, 1, 2, 3, 4])).all())
    self.assertTrue(torch.lt(torch.abs(torch.add(final_args[0], -torch.FloatTensor([0.0, 0.1, 0.2, 0.3, 0.4]))),
                             1e-12).all())
Example #3
Source File: sequence_labeling.py From GraphIE with GNU General Public License v3.0 | 6 votes |
def decode(self, input_word_orig, input_word, input_char, adjs, target=None, mask=None, length=None, hx=None,
           leading_symbolic=0, graph_types=None):
    """Decode tag sequences with the CRF on top of the GCN encoder output.

    :param graph_types: graph types forwarded to ``_get_gcn_output``;
        ``None`` (the default) means ``['coref']``.
    :return: ``(preds, None)`` when no target is available, otherwise
        ``(preds, n_correct)`` where ``n_correct`` counts positions with
        ``preds == target``, weighted by the sentence mask when ``mask`` was given.
    """
    # Build a fresh list per call; a list default would be one object shared
    # by every call (and by anything the callee does to it).
    if graph_types is None:
        graph_types = ['coref']
    # encoder output [batch, length, tag_space]
    output, target, sent_mask, length, _ = self._get_gcn_output(
        input_word_orig, input_word, input_char, adjs, target, mask=mask, length=length, hx=hx,
        leading_symbolic=leading_symbolic, graph_types=graph_types)
    preds = self.crf.decode(output, mask=sent_mask, leading_symbolic=leading_symbolic)
    if target is None:
        return preds, None
    if mask is None:
        return preds, torch.eq(preds, target).float().sum()
    return preds, (torch.eq(preds, target).float() * sent_mask).sum()
Example #4
Source File: sequence_labeling.py From GraphIE with GNU General Public License v3.0 | 6 votes |
def decode(self, input_word_orig, input_word, input_char, _, target=None, mask=None, length=None, hx=None,
           leading_symbolic=0):
    """Decode tag sequences with the CRF on top of the RNN encoder output.

    A 3-D ``input_word`` is first unpacked into sentences via ``_doc2sent``.

    :return: ``(preds, None)`` when ``target`` is None, otherwise
        ``(preds, n_correct)`` where ``n_correct`` counts positions with
        ``preds == target`` (weighted by ``mask`` when one is returned).
    """
    if input_word.dim() == 3:
        # input_word is the packed sents [n_sent, sent_len]
        input_word, input_char, target, _sent_mask, length, _doc_n_sent = self._doc2sent(
            input_word, input_char, target)
    # output from rnn [batch, length, tag_space]
    output, _, mask, length = self._get_rnn_output(input_word_orig, input_word, input_char,
                                                   mask=mask, length=length, hx=hx)
    preds = self.crf.decode(output, mask=mask, leading_symbolic=leading_symbolic)
    if target is None:
        return preds, None
    if length is not None:
        # trim padded target columns to the longest sequence in the batch
        target = target[:, :length.max()]
    hits = torch.eq(preds, target).float()
    weighted = hits if mask is None else hits * mask
    return preds, weighted.sum()
Example #5
Source File: som.py From USIP with GNU General Public License v3.0 | 6 votes |
def query(self, x):
    '''
    Assign every input point to its nearest SOM node.

    :param x: input data CxN tensor
    :return: (mask, mask_row_max) -- mask is an N x node_num float tensor with a
        1 at each point's nearest node; mask_row_max (node_num) is 1 for every
        node that is the nearest node of at least one point.
    '''
    n_nodes = self.rows * self.cols
    # broadcast nodes and points to C x N x node_num
    nodes = self.node.unsqueeze(1).expand(x.size(0), x.size(1), n_nodes)
    points = x.unsqueeze(2).expand_as(nodes)
    sq_dist = torch.sum((points - nodes) ** 2, dim=0)  # N x node_num
    nearest = torch.min(sq_dist, dim=1)[1]  # N
    nearest_grid = nearest.unsqueeze(1).expand(nearest.size()[0], n_nodes).float()  # N x node_num
    ids_grid = self.node_idx_list.unsqueeze(0).expand_as(nearest_grid)  # N x node_num
    mask = torch.eq(nearest_grid, ids_grid).float()
    mask_row_max = torch.max(mask, dim=0)[0]  # node_num
    return mask, mask_row_max
Example #6
Source File: test_dataset.py From kge with MIT License | 6 votes |
def test_data_pickle_correctness(self):
    """Triples read back from the pickle cache must equal the freshly created ones."""
    # this will create new pickle files for train, valid, test
    dataset = Dataset.create(config=self.config, folder=self.dataset_folder, preload_data=True)
    # create new dataset which loads the triples from stored pckl files
    from_pickle = Dataset.create(config=self.config, folder=self.dataset_folder, preload_data=True)
    for split_name in dataset._triples:
        pickled_triples = from_pickle.split(split_name)
        fresh_triples = dataset.split(split_name)
        self.assertTrue(torch.all(torch.eq(pickled_triples, fresh_triples)))
    self.assertEqual(dataset._meta, from_pickle._meta)
Example #7
Source File: test_continuous_mlp_q_function.py From garage with MIT License | 6 votes |
def test_forward(self, hidden_sizes):
    """With all-ones weights, no hidden nonlinearity and all-ones inputs, the
    test expects Q(s, a) == (obs_dim + act_dim) * prod(hidden_sizes)."""
    env_spec = GarageEnv(DummyBoxEnv())
    obs_dim = env_spec.observation_space.flat_dim
    act_dim = env_spec.action_space.flat_dim
    batch_obs = torch.ones(obs_dim, dtype=torch.float32).unsqueeze(0)
    batch_act = torch.ones(act_dim, dtype=torch.float32).unsqueeze(0)
    qf = ContinuousMLPQFunction(env_spec=env_spec,
                                hidden_nonlinearity=None,
                                hidden_sizes=hidden_sizes,
                                hidden_w_init=nn.init.ones_,
                                output_w_init=nn.init.ones_)
    expected_value = (obs_dim + act_dim) * np.prod(hidden_sizes)
    expected_output = torch.full([1, 1], fill_value=expected_value, dtype=torch.float32)
    assert torch.eq(qf(batch_obs, batch_act), expected_output)
# yapf: disable
Example #8
Source File: test_continuous_mlp_q_function.py From garage with MIT License | 6 votes |
def test_is_pickleable(self, hidden_sizes):
    """A Q-function restored from a pickle must give the same output as the original."""
    env_spec = GarageEnv(DummyBoxEnv())
    obs = torch.ones(env_spec.observation_space.flat_dim, dtype=torch.float32).unsqueeze(0)
    act = torch.ones(env_spec.action_space.flat_dim, dtype=torch.float32).unsqueeze(0)
    qf = ContinuousMLPQFunction(env_spec=env_spec,
                                hidden_nonlinearity=None,
                                hidden_sizes=hidden_sizes,
                                hidden_w_init=nn.init.ones_,
                                output_w_init=nn.init.ones_)
    before = qf(obs, act)
    restored = pickle.loads(pickle.dumps(qf))
    after = restored(obs, act)
    assert torch.eq(before, after)
Example #9
Source File: precision.py From UnsupervisedGeometryAwareRepresentationLearning with GNU General Public License v3.0 | 6 votes |
def update(self, output):
    """Accumulate per-class predicted-positive and true-positive counts for one batch.

    ``output`` is a ``(y_pred, y)`` pair; the predicted class is the argmax of
    ``y_pred`` over dim 1 and is compared with ``y``.
    """
    y_pred, y = output
    n_classes = y_pred.size(1)
    _, predicted = torch.max(y_pred, 1)
    hit = torch.eq(predicted, y)
    # predicted positives per class (precision denominator)
    all_positives = to_onehot(predicted, n_classes).sum(dim=0)
    if hit.sum() > 0:
        true_positives = to_onehot(predicted[hit], n_classes).sum(dim=0)
    else:
        # nothing correct: skip one-hot encoding an empty selection
        true_positives = torch.zeros_like(all_positives)
    if self._all_positives is not None:
        self._all_positives += all_positives
        self._true_positives += true_positives
    else:
        self._all_positives = all_positives
        self._true_positives = true_positives
Example #10
Source File: recall.py From UnsupervisedGeometryAwareRepresentationLearning with GNU General Public License v3.0 | 6 votes |
def update(self, output):
    """Accumulate per-class actual-positive and true-positive counts for one batch.

    ``output`` is a ``(y_pred, y)`` pair; the predicted class is the argmax of
    ``y_pred`` over dim 1 and is compared with ``y``.
    """
    y_pred, y = output
    n_classes = y_pred.size(1)
    _, predicted = torch.max(y_pred, 1)
    hit = torch.eq(predicted, y)
    # actual positives per class (recall denominator)
    actual = to_onehot(y, n_classes).sum(dim=0)
    if hit.sum() > 0:
        true_positives = to_onehot(predicted[hit], n_classes).sum(dim=0)
    else:
        # nothing correct: skip one-hot encoding an empty selection
        true_positives = torch.zeros_like(actual)
    if self._actual is not None:
        self._actual += actual
        self._true_positives += true_positives
    else:
        self._actual = actual
        self._true_positives = true_positives
Example #11
Source File: test_coding.py From nn-compression with MIT License | 6 votes |
def test_encode_param():
    """Huffman, vanilla and fixed-point encodings must decode back to ``param`` exactly."""
    param = torch.rand(256, 128, 3, 3)
    prune_vanilla_elementwise(sparsity=0.7, param=param)
    quantize_linear_fix_zeros(param, k=16)

    def check(encoded):
        # print the encoding statistics, then require an exact round trip
        print(encoded.stats)
        assert torch.eq(param, encoded.data).all()

    huffman = EncodedParam(param=param, method='huffman', encode_indices=True, bit_length_zero_run_length=4)
    check(huffman)
    saved_state = huffman.state_dict()
    reloaded = EncodedParam()
    reloaded.load_state_dict(saved_state)
    assert torch.eq(param, reloaded.data).all()

    check(EncodedParam(param=param, method='vanilla', encode_indices=True, bit_length_zero_run_length=4))

    quantize_fixed_point(param=param, bit_length=4, bit_length_integer=0)
    check(EncodedParam(param=param, method='fixed_point', bit_length=4, bit_length_integer=0,
                       encode_indices=True, bit_length_zero_run_length=4))
Example #12
Source File: hinge.py From dfw with MIT License | 6 votes |
def _compute_xi(self, s, aug, y):
    """Return the selection tensor xi used by the multi-class hinge loss.

    Without smoothing, xi is the one-hot encoding of the argmax of ``aug``
    over dim 1. With ``MultiClassHingeLoss.smooth``, samples whose
    softmax-weighted augmented score is >= 0 use ``softmax(s)`` instead.
    ``y`` is not used here.
    """
    y_star = torch.max(aug, 1)[1]
    # xi_max: one-hot encoding of maximal indices
    xi_max = torch.eq(y_star[:, None], self._range).float()
    if not MultiClassHingeLoss.smooth:
        return xi_max
    xi_smooth = nn.functional.softmax(s, dim=1)
    # per-sample flag: does this sample contribute positively to the loss?
    mask_smooth = torch.ge(torch.sum(xi_smooth * aug, 1), 0).float()[:, None]
    return mask_smooth * xi_smooth + (1 - mask_smooth) * xi_max
Example #13
Source File: model.py From VSE-C with MIT License | 6 votes |
def forward(self, feed_dict):
    """Predict an embedding from both sentence directions plus image features.

    Training: returns ``(loss, {}, {})`` with loss the mean ``cosine_loss``
    between the prediction and ``self.embedding(label)``.
    Evaluation: returns a dict with ``pred`` and, when ``feed_dict`` has a
    ``label``, ``top1``/``top10``/``top100``/``top1000``: the per-sample count
    of label hits among the first k ranked candidates, averaged over the batch.
    """
    feed_dict = GView(feed_dict)
    feature_f = self._extract_sent_feature(feed_dict.sent_f, feed_dict.sent_f_length, self.gru_f)
    feature_b = self._extract_sent_feature(feed_dict.sent_b, feed_dict.sent_b_length, self.gru_b)
    feature_img = feed_dict.image
    feature = torch.cat([feature_f, feature_b, feature_img], dim=1)
    predict = self.predict(feature)

    if self.training:
        label = self.embedding(feed_dict.label)
        loss = cosine_loss(predict, label).mean()
        return loss, {}, {}

    output_dict = dict(pred=predict)
    if 'label' in feed_dict:
        dis = cosine_distance(predict, self.embedding.weight)
        # topk raises when k exceeds the number of candidates; clamp k so small
        # vocabularies work (the [:, :k] slices below then just cover them all).
        # NOTE(review): topk takes the *largest* values; confirm cosine_distance
        # returns a similarity here, otherwise this ranks the farthest words first.
        _, topk = dis.topk(min(1000, dis.size(1)), dim=1, sorted=True)
        hits = torch.eq(topk, feed_dict.label.unsqueeze(-1))
        for k in [1, 10, 100, 1000]:
            output_dict['top{}'.format(k)] = hits[:, :k].float().sum(dim=1).mean()
    return output_dict
Example #14
Source File: accuracy.py From LaSO with BSD 3-Clause "New" or "Revised" License | 6 votes |
def update(self, output):
    """Accumulate correct-prediction and example counts for one batch.

    ``self._type`` selects the rule: ``"binary"`` compares element-wise,
    ``"multiclass"`` compares the argmax over dim 1 with ``y``, and
    ``"multilabel"`` counts a sample as correct only if all its classes match.
    """
    y_pred, y = self._check_shape(output)
    self._check_type((y_pred, y))
    kind = self._type
    if kind == "binary":
        correct = torch.eq(y_pred.type(y.type()), y).view(-1)
    elif kind == "multiclass":
        _, indices = torch.max(y_pred, dim=1)
        correct = torch.eq(indices, y).view(-1)
    elif kind == "multilabel":
        # move the class axis last and flatten the rest: (N, C, ...) -> (N x ..., C)
        n_classes = y_pred.size(1)
        last_axis = y_pred.ndimension() - 1
        y_pred = torch.transpose(y_pred, 1, last_axis).reshape(-1, n_classes)
        y = torch.transpose(y, 1, last_axis).reshape(-1, n_classes)
        correct = torch.all(y == y_pred.type_as(y), dim=-1)
    self._num_correct += torch.sum(correct).item()
    self._num_examples += correct.shape[0]
Example #15
Source File: metrics.py From LaSO with BSD 3-Clause "New" or "Revised" License | 5 votes |
def update(self, output):
    """Accumulate element-wise accuracy of sigmoid-thresholded scores for one batch.

    Scores are passed through a sigmoid and thresholded at 0.5 before being
    compared element-wise with ``y``.
    """
    scores, y = output
    binarized = (torch.sigmoid(scores) > 0.5).float()
    matches = torch.eq(binarized, y).view(-1)
    self._num_correct += torch.sum(matches).item()
    self._num_examples += matches.shape[0]
Example #16
Source File: categorical_accuracy.py From argus with MIT License | 5 votes |
def update(self, step_output: dict):
    """Add one batch of argmax predictions to the running accuracy counters.

    Reads ``step_output['prediction']`` (class scores along dim 1) and
    ``step_output['target']``.
    """
    predicted = torch.max(step_output['prediction'], dim=1)[1]
    hits = torch.eq(predicted, step_output['target']).view(-1)
    self.correct += torch.sum(hits).item()
    self.count += hits.shape[0]
Example #17
Source File: loader.py From AGGCN with MIT License | 5 votes |
def __getitem__(self, key):
    """ Get a batch with index.

    Returns ``(words, masks, pos, deprel, head, first_positions,
    second_positions, third_positions, cross, rels, orig_idx)``. ``masks`` is
    True where the word id is 0 (presumably padding -- confirm against the vocab).
    Raises TypeError for a non-int key and IndexError for an out-of-range key.
    """
    if not isinstance(key, int):
        raise TypeError
    if not 0 <= key < len(self.data):
        raise IndexError
    batch = self.data[key]
    batch_size = len(batch)
    # transpose: one tuple per field
    fields = list(zip(*batch))
    # for nary dataset
    assert len(fields) == 9
    # sort all fields by lens for easy RNN operations
    lens = [len(sent) for sent in fields[0]]
    fields, orig_idx = sort_all(fields, lens)
    # word dropout is applied only outside eval mode
    if self.eval:
        words = fields[0]
    else:
        words = [word_dropout(sent, self.opt['word_dropout']) for sent in fields[0]]
    # convert to tensors
    words = get_long_tensor(words, batch_size)
    masks = torch.eq(words, 0)
    pos, deprel, head, first_positions, second_positions, third_positions = [
        get_long_tensor(fields[i], batch_size) for i in range(1, 7)]
    cross = fields[7]
    rels = torch.LongTensor(fields[8])
    return (words, masks, pos, deprel, head, first_positions, second_positions, third_positions,
            cross, rels, orig_idx)
Example #18
Source File: loader.py From AGGCN with MIT License | 5 votes |
def __getitem__(self, key):
    """ Get a batch with index.

    Returns ``(words, masks, pos, deprel, head, subj_positions,
    obj_positions, rels, orig_idx)``. ``masks`` is True where the word id is 0
    (presumably padding -- confirm against the vocab).
    Raises TypeError for a non-int key and IndexError for an out-of-range key.
    """
    if not isinstance(key, int):
        raise TypeError
    if key < 0 or key >= len(self.data):
        raise IndexError
    batch = self.data[key]
    batch_size = len(batch)
    # transpose the batch: one tuple per field
    batch = list(zip(*batch))
    # NOTE(review): `dataset` is neither a parameter nor a local of this method,
    # so it must be a module-level global; otherwise this line raises NameError.
    # NOTE(review): on the 'dataset/tacred' path a 10-field batch is expected, but
    # only fields 0-6 are read below and rels is taken from batch[6] -- verify
    # that this branch is still reachable and correct.
    if dataset == 'dataset/tacred':
        assert len(batch) == 10
    else:
        assert len(batch) == 7
    # sort all fields by lens for easy RNN operations
    lens = [len(x) for x in batch[0]]
    batch, orig_idx = sort_all(batch, lens)
    # word dropout (only outside eval mode)
    if not self.eval:
        words = [word_dropout(sent, self.opt['word_dropout']) for sent in batch[0]]
    else:
        words = batch[0]
    # convert to tensors
    words = get_long_tensor(words, batch_size)
    masks = torch.eq(words, 0)
    pos = get_long_tensor(batch[1], batch_size)
    deprel = get_long_tensor(batch[2], batch_size)
    head = get_long_tensor(batch[3], batch_size)
    subj_positions = get_long_tensor(batch[4], batch_size)
    obj_positions = get_long_tensor(batch[5], batch_size)
    rels = torch.LongTensor(batch[6])
    return (words, masks, pos, deprel, head, subj_positions, obj_positions, rels, orig_idx)
Example #19
Source File: loader.py From AGGCN with MIT License | 5 votes |
def __getitem__(self, key):
    """ Get a batch with index.

    Returns ``(words, masks, pos, deprel, head, first_positions,
    second_positions, cross, rels, orig_idx)``. ``masks`` is True where the
    word id is 0 (presumably padding -- confirm against the vocab).
    Raises TypeError for a non-int key and IndexError for an out-of-range key.
    """
    if not isinstance(key, int):
        raise TypeError
    if not 0 <= key < len(self.data):
        raise IndexError
    batch = self.data[key]
    batch_size = len(batch)
    # transpose: one tuple per field
    columns = list(zip(*batch))
    # for nary dataset
    assert len(columns) == 8
    # sort all fields by lens for easy RNN operations
    lens = [len(sent) for sent in columns[0]]
    columns, orig_idx = sort_all(columns, lens)
    # word dropout is applied only outside eval mode
    if self.eval:
        words = columns[0]
    else:
        words = [word_dropout(sent, self.opt['word_dropout']) for sent in columns[0]]
    # convert to tensors
    words = get_long_tensor(words, batch_size)
    masks = torch.eq(words, 0)
    pos, deprel, head, first_positions, second_positions = [
        get_long_tensor(columns[i], batch_size) for i in range(1, 6)]
    cross = columns[6]
    rels = torch.LongTensor(columns[7])
    return (words, masks, pos, deprel, head, first_positions, second_positions, cross, rels, orig_idx)
Example #20
Source File: loader.py From AGGCN with MIT License | 5 votes |
def __getitem__(self, key):
    """ Get a batch with index.

    Returns ``(words, masks, pos, ner, deprel, head, subj_positions,
    obj_positions, subj_type, obj_type, rels, orig_idx)``. ``masks`` is True
    where the word id is 0 (presumably padding -- confirm against the vocab).
    Raises TypeError for a non-int key and IndexError for an out-of-range key.
    """
    if not isinstance(key, int):
        raise TypeError
    if not 0 <= key < len(self.data):
        raise IndexError
    batch = self.data[key]
    batch_size = len(batch)
    # transpose: one tuple per field
    columns = list(zip(*batch))
    assert len(columns) == 10
    # sort all fields by lens for easy RNN operations
    lens = [len(sent) for sent in columns[0]]
    columns, orig_idx = sort_all(columns, lens)
    # word dropout is applied only outside eval mode
    if self.eval:
        words = columns[0]
    else:
        words = [word_dropout(sent, self.opt['word_dropout']) for sent in columns[0]]
    # convert to tensors
    words = get_long_tensor(words, batch_size)
    masks = torch.eq(words, 0)
    (pos, ner, deprel, head, subj_positions, obj_positions,
     subj_type, obj_type) = [get_long_tensor(columns[i], batch_size) for i in range(1, 9)]
    rels = torch.LongTensor(columns[9])
    return (words, masks, pos, ner, deprel, head, subj_positions, obj_positions,
            subj_type, obj_type, rels, orig_idx)
Example #21
Source File: som.py From USIP with GNU General Public License v3.0 | 5 votes |
def query_topk(node, x, M, k):
    '''
    For every input point, select its k nearest SOM nodes.

    :param node: SOM node of BxCxM tensor
    :param x: input data BxCxN tensor
    :param M: number of SOM nodes
    :param k: topk
    :return: (mask, mask_row_max, min_idx) where
        mask is a Bx(kN)xM int tensor, one row per (neighbour rank, point) pair,
        with a 1 at the selected node;
        mask_row_max is a BxM int tensor, 1 where the node was selected for at
        least one point;
        min_idx is a Bx(kN) tensor of the selected node indices.
    '''
    # ensure x, and other stored tensors are in the same device
    dev = x.device
    nodes = node.to(dev)
    candidate_ids = torch.from_numpy(np.arange(M).astype(np.int64)).to(dev)  # M
    # broadcast nodes and points to BxCxNxM
    nodes_4d = nodes.unsqueeze(2).expand(x.size(0), x.size(1), x.size(2), M)
    points_4d = x.unsqueeze(3).expand_as(nodes_4d)
    sq_dist = ((points_4d - nodes_4d) ** 2).sum(dim=1)  # BxNxM
    _, nearest = torch.topk(sq_dist, k=k, dim=2, largest=False, sorted=False)  # BxNxk
    nearest_4d = nearest.unsqueeze(2).expand(nearest.size()[0], nearest.size()[1], M, k)  # BxNxMxk
    ids_4d = candidate_ids.unsqueeze(0).unsqueeze(0).unsqueeze(3).expand_as(nearest_4d).long()  # BxNxMxk
    hits = torch.eq(nearest_4d, ids_4d).int()  # BxNxMxk
    batch, n_points, n_nodes = hits.size()[0], hits.size()[1], hits.size()[2]
    # BxNxMxk -> BxMxkxN -> BxMxkN -> BxkNxM
    hits = hits.permute(0, 2, 3, 1).contiguous().view(batch, n_nodes, k * n_points).permute(0, 2, 1).contiguous()
    nearest = nearest.permute(0, 2, 1).contiguous().view(batch, k * n_points)
    covered, _ = torch.max(hits, dim=1)  # BxM
    return hits, covered, nearest
Example #22
Source File: transforms.py From PyTorch-ENet with MIT License | 5 votes |
def __call__(self, tensor):
    """Performs the conversion from ``torch.LongTensor`` to a ``PIL image``

    Keyword arguments:
    - tensor (``torch.LongTensor``): the label tensor to convert. A 2-D
      (H x W) tensor is treated as single-channel; the caller's tensor is
      not modified.

    Returns:
    A ``PIL.Image``. The i-th entry of ``self.rgb_encoding`` colours the pixels
    whose label equals i; pixels matching no entry are black (0, 0, 0).

    Raises:
    TypeError if ``tensor`` is not a ``torch.LongTensor`` or
    ``self.rgb_encoding`` is not an ``OrderedDict``.
    """
    # Check if label_tensor is a LongTensor
    if not isinstance(tensor, torch.LongTensor):
        raise TypeError("label_tensor should be torch.LongTensor. Got {}"
                        .format(type(tensor)))
    # Check if encoding is a ordered dictionary
    if not isinstance(self.rgb_encoding, OrderedDict):
        raise TypeError("encoding should be an OrderedDict. Got {}".format(
            type(self.rgb_encoding)))

    # Add a channel dimension to 2-D inputs. Out-of-place: `unsqueeze_` would
    # silently reshape the caller's tensor.
    if tensor.dim() == 2:
        tensor = tensor.unsqueeze(0)

    # Zero-initialised: torch.ByteTensor(3, H, W) is uninitialised memory, so
    # pixels matching no class used to get arbitrary colours.
    color_tensor = torch.zeros(3, tensor.size(1), tensor.size(2), dtype=torch.uint8)

    for index, (_class_name, color) in enumerate(self.rgb_encoding.items()):
        # Get a mask of elements equal to index
        mask = torch.eq(tensor, index).squeeze()
        # Fill color_tensor with corresponding colors
        for channel, color_value in enumerate(color):
            color_tensor[channel].masked_fill_(mask, color_value)

    return ToPILImage()(color_tensor)
Example #23
Source File: BayesianConvs.py From UCB with MIT License | 5 votes |
def prune_module(self, mask):
    """Multiply ``weight_mu`` in place by ``mask`` and set ``mask_flag``.

    ``self.weight_mu.data`` is modified in place, so ``self.pruned_weight_mu``
    shares storage with it. ``weight_rho`` is left untouched.
    """
    self.mask_flag = True
    mu = self.weight_mu.data
    self.pruned_weight_mu = mu.mul_(mask)
Example #24
Source File: sequence_labeling.py From GraphIE with GNU General Public License v3.0 | 5 votes |
def loss(self, input_word_orig, input_word, input_char, target, mask=None, length=None, hx=None, leading_symbolic=0,
         show_net=False):
    """Token-level NLL loss plus the number of correct argmax predictions.

    :return: ``(loss, n_correct, preds)``. With a mask, the loss is the masked
        per-token sum divided by ``mask.sum()``. Without one, it is the mean over
        all ``batch * length`` tokens. ``show_net`` is unused.
    """
    # [batch, length, tag_space]
    output, mask, length = self.forward(input_word_orig, input_word, input_char, mask=mask, length=length, hx=hx)
    # [batch, length, num_labels]
    output = self.dense_softmax(output)
    # preds = [batch, length]; labels below leading_symbolic are never predicted
    _, preds = torch.max(output[:, :, leading_symbolic:], dim=2)
    preds += leading_symbolic

    batch_size, seq_len, num_labels = output.size()
    n_tokens = batch_size * seq_len
    # [batch * length, num_labels]
    output = output.view(n_tokens, num_labels)

    # Trim padded targets. Compare against the mask's length when there is one,
    # otherwise the output's -- a missing mask used to raise AttributeError here.
    ref_len = mask.size(1) if mask is not None else seq_len
    if length is not None and target.size(1) != ref_len:
        max_len = length.max()
        target = target[:, :max_len].contiguous()

    token_loss = self.nll_loss(self.logsoftmax(output), target.view(-1))
    if mask is not None:
        return (token_loss * mask.contiguous().view(-1)).sum() / mask.sum(), \
            (torch.eq(preds, target).type_as(mask) * mask).sum(), preds
    # Average over batch * length tokens. The old code multiplied the two entries
    # of the already-flattened size, i.e. also divided by num_labels.
    return token_loss.sum() / n_tokens, \
        (torch.eq(preds, target).type_as(output)).sum(), preds
Example #25
Source File: test_dataset.py From kge with MIT License | 5 votes |
def assertEqualTorch(self, first, second, msg=None):
    """Compares first and second using ==, except for PyTorch tensors, where
    `torch.eq` is used.

    Dicts and lists are compared recursively. For KvsAllIndex objects, the set
    of non-dunder attribute names must match, and the values of all
    non-callable attributes are compared recursively.
    """
    # TODO factor out to utility class
    self.assertEqual(type(first), type(second), msg=msg)
    if isinstance(first, dict):
        self.assertEqual(len(first), len(second), msg=msg)
        for key in first.keys():
            self.assertTrue(key in second, msg=msg)
            self.assertEqualTorch(first[key], second[key], msg=msg)
    elif isinstance(first, list):
        self.assertEqual(len(first), len(second), msg=msg)
        for first_item, second_item in zip(first, second):
            self.assertEqualTorch(first_item, second_item, msg=msg)
    elif isinstance(first, KvsAllIndex):
        # Compare attribute *values*. The old code passed the attribute names
        # themselves, which are identical for two objects of the same class,
        # so this branch could never fail.
        first_names = [a for a in dir(first) if not a.startswith("__")]
        second_names = [a for a in dir(second) if not a.startswith("__")]
        self.assertEqual(first_names, second_names, msg=msg)
        for name in first_names:
            first_value = getattr(first, name)
            # skip methods: bound methods of two different instances never compare equal
            if callable(first_value):
                continue
            self.assertEqualTorch(first_value, getattr(second, name), msg=msg)
    else:
        if type(first) is torch.Tensor:
            self.assertTrue(torch.all(torch.eq(first, second)), msg=msg)
        else:
            self.assertEqual(first, second, msg=msg)
Example #26
Source File: test_coding.py From nn-compression with MIT License | 5 votes |
def test_codec():
    """Prune, quantize, encode and decode a two-layer conv model; every
    parameter with dim > 1 must survive the round trip exactly."""
    quantize_rule = [
        ('0.weight', 'k-means', 4, 'k-means++'),
        ('1.weight', 'fixed_point', 6, 1),
    ]
    encode_rule = [
        ('0.weight', 'huffman', 0, 0, 4),
        ('1.weight', 'fixed_point', 6, 1, 4)
    ]

    def make_model():
        # fresh, randomly initialised two-layer conv stack
        return torch.nn.Sequential(torch.nn.Conv2d(256, 128, 3, bias=True),
                                   torch.nn.Conv2d(128, 512, 1, bias=False))

    model = make_model()
    mask_dict = {name: prune_vanilla_elementwise(sparsity=0.6, param=p.data)
                 for name, p in model.named_parameters()}
    quantizer = Quantizer(rule=quantize_rule, fix_zeros=True)
    quantizer.quantize(model, update_labels=False, verbose=True)

    codec = Codec(rule=encode_rule)
    encoded_module = codec.encode(model)
    print(codec.stats)
    state_dict = encoded_module.state_dict()

    decoded = Codec.decode(make_model(), state_dict)
    for p1, p2 in zip(model.parameters(), decoded.parameters()):
        # 1-D parameters (the bias) are not compared
        if p1.dim() > 1:
            assert torch.eq(p1, p2).all()
Example #27
Source File: linear.py From nn-compression with MIT License | 5 votes |
def quantize_linear_fix_zeros(param, k=16, **unused):
    """
    linearly quantize while fixing zeros
    :param param: torch.(cuda.)tensor, quantized in place; entries equal to 0 stay 0
    :param k: int, the number of quantization level, default=16. Level 0 is the
        zero, the remaining k - 1 levels are evenly spaced between the
        percentile-based bounds; must be >= 3
    :param unused: unused options
    :return: dict, {'cluster_centers_': torch.tensor of k centers, 'method': 'linear'}
    :raises ValueError: if k < 3 (the step below divides by k - 2)
    """
    if k < 3:
        raise ValueError("k must be at least 3, got {}".format(k))
    zero_mask = torch.eq(param, 0.0)  # get zero mask
    num_param = param.numel()
    # Bounds come from magic_percentile-based order statistics rather than the
    # raw min/max; values outside them are clamped.
    kth = int(math.ceil(num_param * magic_percentile))
    param_flatten = param.view(num_param)
    param_min, _ = torch.topk(param_flatten, kth, dim=0, largest=False, sorted=False)
    param_min = param_min.max()
    param_max, _ = torch.topk(param_flatten, kth, dim=0, largest=True, sorted=False)
    param_max = param_max.min()
    step = (param_max - param_min) / (k - 2)
    param.clamp_(param_min, param_max)
    # A zero step (degenerate range) would make 0/0 = NaN of every value.
    # Clamping has already mapped everything to param_min, so skip the rounding.
    if step != 0:
        param.sub_(param_min).div_(step).round_().mul_(step).add_(param_min)
    param.masked_fill_(zero_mask, 0)  # recover zeros
    codebook = {'cluster_centers_': torch.zeros(k),
                'method': 'linear', }
    codebook['cluster_centers_'][1:] = torch.linspace(param_min, param_max, k - 1)
    return codebook
Example #28
Source File: group.py From pose-ae-train with BSD 3-Clause "New" or "Revised" License | 5 votes |
def nms(self, det):
    """Zero every entry of ``det`` that differs from ``self.pool(det)``.

    With ``self.pool`` presumably a max-pooling layer (confirm), this keeps only
    local maxima of the detection map.
    """
    pooled = self.pool(det)
    keep = pooled.eq(det).float()
    return det * keep
Example #29
Source File: few_shot.py From cactus-protonets with MIT License | 5 votes |
def loss(self, sample):
    """Prototypical-network loss and accuracy for one episode.

    ``sample['xs']`` holds the support set and ``sample['xq']`` the query set.
    Both are indexed as (n_class, n_per_class, ...). Class prototypes are the
    mean support embeddings; queries are scored by ``-euclidean_dist``.

    :return: ``(loss_val, {'loss': float, 'acc': float})``
    """
    xs = Variable(sample['xs'])  # support
    xq = Variable(sample['xq'])  # query

    n_class = xs.size(0)
    assert xq.size(0) == n_class
    n_support = xs.size(1)
    n_query = xq.size(1)

    # target_inds[c, q, 0] == c
    target_inds = torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).long()
    target_inds = Variable(target_inds, requires_grad=False)

    if xq.is_cuda:
        target_inds = target_inds.cuda()

    x = torch.cat([xs.view(n_class * n_support, *xs.size()[2:]),
                   xq.view(n_class * n_query, *xq.size()[2:])], 0)

    z = self.encoder.forward(x)
    z_dim = z.size(-1)

    z_proto = z[:n_class * n_support].view(n_class, n_support, z_dim).mean(1)
    zq = z[n_class * n_support:]

    dists = euclidean_dist(zq, z_proto)

    log_p_y = F.log_softmax(-dists, dim=1).view(n_class, n_query, -1)

    loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()

    _, y_hat = log_p_y.max(2)  # (n_class, n_query)
    # Squeeze only the trailing singleton axis. A bare squeeze() also dropped the
    # query axis when n_query == 1, and y_hat (n_class, 1) then broadcast against
    # (n_class,) into an n_class x n_class comparison, giving a wrong accuracy.
    acc_val = torch.eq(y_hat, target_inds.squeeze(2)).float().mean()

    return loss_val, {
        'loss': loss_val.item(),
        'acc': acc_val.item()
    }
Example #30
Source File: netmath.py From ibeis with Apache License 2.0 | 5 votes |
def _siamese_metrics(output, label, margin=1): l2_dist_tensor = torch.from_numpy(output.data.cpu().numpy()) label_tensor = torch.from_numpy(label.data.cpu().numpy()) # Distance is_pos = torch.ByteTensor() POS_LABEL = 1 NEG_LABEL = 0 torch.eq(label_tensor, POS_LABEL, out=is_pos) # y==1 pos_dist = 0 if len(l2_dist_tensor[is_pos]) == 0 else l2_dist_tensor[is_pos].mean() neg_dist = 0 if len(l2_dist_tensor[~is_pos]) == 0 else l2_dist_tensor[~is_pos].mean() # print('same dis : diff dis {} : {}'.format(l2_dist_tensor[is_pos == 0].mean(), l2_dist_tensor[is_pos].mean())) # accuracy pred_pos_flags = torch.ByteTensor() torch.le(l2_dist_tensor, margin, out=pred_pos_flags) # y==1's idx cur_score = torch.FloatTensor(label.size(0)) cur_score.fill_(NEG_LABEL) cur_score[pred_pos_flags] = POS_LABEL label_tensor_ = label_tensor.type(torch.FloatTensor) accuracy = torch.eq(cur_score, label_tensor_).sum() / label_tensor.size(0) metrics = { 'accuracy': accuracy, 'pos_dist': pos_dist, 'neg_dist': neg_dist, } return metrics