Python torch.allclose() Examples
The following are 30 code examples of torch.allclose().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: autograd_hacks_test.py From autograd-hacks with The Unlicense | 6 votes |
def test_grad1():
    """Compare the ``grad1`` values filled in by autograd_hacks against autograd."""
    torch.manual_seed(1)
    model = Net()
    loss_fn = nn.CrossEntropyLoss()

    batch_size = 4
    data = torch.rand(batch_size, 1, 28, 28)
    targets = torch.LongTensor(batch_size).random_(0, 10)

    autograd_hacks.add_hooks(model)
    output = model(data)
    loss_fn(output, targets).backward(retain_graph=True)
    autograd_hacks.compute_grad1(model)
    autograd_hacks.disable_hooks()

    # Compare values against autograd: one loss per example, each computed
    # on a slice of length one.
    per_example_losses = []
    for idx in range(data.size(0)):
        per_example_losses.append(
            loss_fn(output[idx:idx + 1], targets[idx:idx + 1])
        )
    losses = torch.stack(per_example_losses)

    for layer in model.modules():
        if autograd_hacks.is_supported(layer):
            for param in layer.parameters():
                # The batch gradient must equal the mean over dim 0 of grad1.
                assert torch.allclose(param.grad, param.grad1.mean(dim=0))
                assert torch.allclose(jacobian(losses, param), param.grad1)
Example #2
Source File: test_root_finding.py From entmax with MIT License | 6 votes |
def test_arbitrary_dimension(dim):
    """Check entmax_bisect along ``dim`` against slice-by-slice evaluation.

    A 4-D input is transformed in one call with a per-slice ``alpha`` tensor
    (size 1 along ``dim``); every 1-D slice along ``dim`` is then recomputed
    on its own with the matching scalar alpha and compared via
    ``torch.allclose``.

    Args:
        dim: axis of the ``[3, 4, 2, 5]`` input along which entmax is applied.
    """
    shape = [3, 4, 2, 5]
    X = torch.randn(*shape, dtype=torch.float64)

    # Copy before editing: a plain ``alpha_shape = shape`` aliases the list,
    # so writing alpha_shape[dim] would silently rewrite shape[dim] as well.
    alpha_shape = list(shape)
    alpha_shape[dim] = 1
    alphas = 1.05 + torch.rand(alpha_shape, dtype=torch.float64)

    P = entmax_bisect(X, alpha=alphas, dim=dim)

    # Normalise so a negative ``dim`` still matches the enumerate() index.
    axis = dim % len(shape)
    ranges = [
        list(range(k)) if i != axis else [slice(None)]
        for i, k in enumerate(shape)
    ]

    for ix in product(*ranges):
        x = X[ix].unsqueeze(0)
        alpha = alphas[ix].item()
        p_true = entmax_bisect(x, alpha=alpha, dim=-1)
        assert torch.allclose(P[ix], p_true)
Example #3
Source File: losses_label_smoothing_cross_entropy_loss_test.py From ClassyVision with MIT License | 6 votes |
def test_smoothing_ignore_index_one_hot_targets(self):
    """Label smoothing on a one-hot target row that contains the ignore index."""
    num_classes = 5
    crit = build_loss(
        {
            "name": "label_smoothing_cross_entropy",
            "ignore_index": -1,
            "smoothing_param": 0.5,
        }
    )
    targets = torch.tensor([[-1, 0, 0, 0, 1]])
    self.assertTrue(isinstance(crit, LabelSmoothingCrossEntropyLoss))

    # The -1 entry is expected to come back as 0.0.
    valid_targets = crit.compute_valid_targets(targets, num_classes)
    expected_valid = torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0]])
    self.assertTrue(torch.allclose(valid_targets, expected_valid))

    smoothed_targets = crit.smooth_targets(valid_targets, num_classes)
    expected_smoothed = torch.tensor(
        [[1 / 15, 1 / 15, 1 / 15, 1 / 15, 11 / 15]]
    )
    self.assertTrue(torch.allclose(smoothed_targets, expected_smoothed))
Example #4
Source File: losses_label_smoothing_cross_entropy_loss_test.py From ClassyVision with MIT License | 6 votes |
def test_smoothing_multilabel_one_hot_targets(self):
    """Label smoothing on a one-hot target row with two positive classes."""
    num_classes = 5
    crit = build_loss(
        {
            "name": "label_smoothing_cross_entropy",
            "ignore_index": -1,
            "smoothing_param": 0.5,
        }
    )
    targets = torch.tensor([[1, 0, 0, 0, 1]])
    self.assertTrue(isinstance(crit, LabelSmoothingCrossEntropyLoss))

    valid_targets = crit.compute_valid_targets(targets, num_classes)
    expected_valid = torch.tensor([[1.0, 0.0, 0.0, 0.0, 1.0]])
    self.assertTrue(torch.allclose(valid_targets, expected_valid))

    smoothed_targets = crit.smooth_targets(valid_targets, num_classes)
    expected_smoothed = torch.tensor(
        [[6 / 15, 1 / 15, 1 / 15, 1 / 15, 6 / 15]]
    )
    self.assertTrue(torch.allclose(smoothed_targets, expected_smoothed))
Example #5
Source File: losses_label_smoothing_cross_entropy_loss_test.py From ClassyVision with MIT License | 6 votes |
def test_smoothing_all_ones_one_hot_targets(self):
    """Label smoothing on a target row in which every class is set."""
    num_classes = 4
    crit = build_loss(
        {
            "name": "label_smoothing_cross_entropy",
            "ignore_index": -1,
            "smoothing_param": 0.1,
        }
    )
    targets = torch.tensor([[1, 1, 1, 1]])
    self.assertTrue(isinstance(crit, LabelSmoothingCrossEntropyLoss))

    valid_targets = crit.compute_valid_targets(targets, num_classes)
    expected_valid = torch.tensor([[1.0, 1.0, 1.0, 1.0]])
    self.assertTrue(torch.allclose(valid_targets, expected_valid))

    smoothed_targets = crit.smooth_targets(valid_targets, num_classes)
    expected_smoothed = torch.tensor([[0.25, 0.25, 0.25, 0.25]])
    self.assertTrue(torch.allclose(smoothed_targets, expected_smoothed))
Example #6
Source File: birkhoff_polytope.py From geoopt with Apache License 2.0 | 6 votes |
def _check_point_on_manifold(self, x, *, atol=1e-4, rtol=1e-4): row_sum = x.sum(dim=-1) col_sum = x.sum(dim=-2) row_ok = torch.allclose( row_sum, row_sum.new((1,)).fill_(1), atol=atol, rtol=rtol ) col_ok = torch.allclose( col_sum, col_sum.new((1,)).fill_(1), atol=atol, rtol=rtol ) if row_ok and col_ok: return True, None else: return ( False, "illegal doubly stochastic matrix with atol={}, rtol={}".format( atol, rtol ), )
Example #7
Source File: sphere.py From geoopt with Apache License 2.0 | 6 votes |
def _check_point_on_manifold( self, x: torch.Tensor, *, atol=1e-5, rtol=1e-5 ) -> Tuple[bool, Optional[str]]: norm = x.norm(dim=-1) ok = torch.allclose(norm, norm.new((1,)).fill_(1), atol=atol, rtol=rtol) if not ok: return False, "`norm(x) != 1` with atol={}, rtol={}".format(atol, rtol) ok = torch.allclose(self._project_on_subspace(x), x, atol=atol, rtol=rtol) if not ok: return ( False, "`x` is not in the subspace of the manifold with atol={}, rtol={}".format( atol, rtol ), ) return True, None
Example #8
Source File: generic_util_test.py From ClassyVision with MIT License | 6 votes |
def test_split_batchnorm_params(self):
    """Check that split_batchnorm_params separates BatchNorm parameters from the rest."""

    class MyModel(nn.Module):
        """Toy model: bias-free Linear -> ReLU -> BatchNorm1d."""

        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(2, 3, bias=False)
            self.relu = nn.ReLU()
            self.bn = nn.BatchNorm1d(3)

        def forward(self, x):
            return self.bn(self.relu(self.lin(x)))

    torch.manual_seed(1)
    model = MyModel()

    bn_params, lin_params = split_batchnorm_params(model)

    # assertEqual rather than the assertEquals alias, which is deprecated
    # and no longer exists in Python 3.12+.
    self.assertEqual(len(bn_params), 2)
    self.assertEqual(len(lin_params), 1)

    self.assertTrue(torch.allclose(bn_params[0], model.bn.weight))
    self.assertTrue(torch.allclose(bn_params[1], model.bn.bias))
    self.assertTrue(torch.allclose(lin_params[0], model.lin.weight))
Example #9
Source File: test_polar.py From pytorch_geometric with MIT License | 6 votes |
def test_polar():
    """Exercise the Polar transform with and without normalisation."""
    assert repr(Polar()) == 'Polar(norm=True, max_value=None)'

    pos = torch.Tensor([[0, 0], [1, 0]])
    edge_index = torch.tensor([[0, 1], [1, 0]])
    edge_attr = torch.Tensor([1, 1])

    # Without existing edge attributes and without normalisation.
    data = Polar(norm=False)(Data(edge_index=edge_index, pos=pos))
    assert len(data) == 3
    assert data.pos.tolist() == pos.tolist()
    assert data.edge_index.tolist() == edge_index.tolist()
    expected = torch.Tensor([[1, 0], [1, PI]])
    assert torch.allclose(data.edge_attr, expected, atol=1e-04)

    # With existing edge attributes and normalisation enabled.
    data = Polar(norm=True)(
        Data(edge_index=edge_index, pos=pos, edge_attr=edge_attr)
    )
    assert len(data) == 3
    assert data.pos.tolist() == pos.tolist()
    assert data.edge_index.tolist() == edge_index.tolist()
    expected = torch.Tensor([[1, 1, 0], [1, 1, 0.5]])
    assert torch.allclose(data.edge_attr, expected, atol=1e-04)
Example #10
Source File: models_classy_model_test.py From ClassyVision with MIT License | 6 votes |
def test_classy_model_wrapper_torch_jittable(self):
    """Check that MyTestModel2 can be traced with and without a wrapper class.

    For each wrapper class, both the eager model and its ``torch.jit.trace``
    version must reproduce the expected output.
    """
    orig_wrapper_cls = MyTestModel2.wrapper_cls
    # Named ``inputs`` so the ``input`` builtin is not shadowed.
    inputs = torch.ones((2, 2))
    cases = [
        (None, inputs + 1),
        (TestSimpleClassyModelWrapper, (inputs + 1) * 2),
    ]
    try:
        for wrapper_cls, expected_output in cases:
            MyTestModel2.wrapper_cls = wrapper_cls
            model = MyTestModel2()
            jitted_model = torch.jit.trace(model, inputs)
            self.assertTrue(torch.allclose(expected_output, model(inputs)))
            self.assertTrue(
                torch.allclose(expected_output, jitted_model(inputs))
            )
    finally:
        # Restore the original wrapper class even when an assertion fails,
        # so the class-level mutation cannot leak into other tests.
        MyTestModel2.wrapper_cls = orig_wrapper_cls
Example #11
Source File: models_classy_model_test.py From ClassyVision with MIT License | 6 votes |
def test_classy_model_wrapper_torch_scriptable(self):
    """Check that MyTestModel2 can be compiled with ``torch.jit.script``.

    Both the eager model and its scripted version must reproduce the
    expected output.
    """
    orig_wrapper_cls = MyTestModel2.wrapper_cls
    # Named ``inputs`` so the ``input`` builtin is not shadowed.
    inputs = torch.ones((2, 2))
    cases = [
        (None, inputs + 1),
        # NOTE: scripting with TestSimpleClassyModelWrapper (expected output
        # (inputs + 1) * 2) isn't supported yet, so that case is not run.
    ]
    try:
        for wrapper_cls, expected_output in cases:
            MyTestModel2.wrapper_cls = wrapper_cls
            model = MyTestModel2()
            scripted_model = torch.jit.script(model)
            self.assertTrue(torch.allclose(expected_output, model(inputs)))
            self.assertTrue(
                torch.allclose(expected_output, scripted_model(inputs))
            )
    finally:
        # Restore the original wrapper class even when an assertion fails,
        # so the class-level mutation cannot leak into other tests.
        MyTestModel2.wrapper_cls = orig_wrapper_cls
Example #12
Source File: test_glob.py From pytorch_geometric with MIT License | 6 votes |
def test_permuted_global_pool():
    """Check add/mean/max global pooling when the batch vector is not sorted."""
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    batch = torch.cat([torch.zeros(N_1), torch.ones(N_2)]).to(torch.long)
    perm = torch.randperm(N_1 + N_2)

    px = x[perm]
    pbatch = batch[perm]
    px1 = px[pbatch == 0]
    px2 = px[pbatch == 1]

    # Each pooling operator paired with the per-graph reduction it must match.
    cases = [
        (global_add_pool, lambda t: t.sum(dim=0)),
        (global_mean_pool, lambda t: t.mean(dim=0)),
        (global_max_pool, lambda t: t.max(dim=0)[0]),
    ]
    for pool, reference in cases:
        out = pool(px, pbatch)
        assert out.size() == (2, 4)
        assert torch.allclose(out[0], reference(px1))
        assert torch.allclose(out[1], reference(px2))
Example #13
Source File: utils.py From ClassyVision with MIT License | 6 votes |
def compare_batches(test_fixture, batch1, batch2):
    """Compare two batches. Does not do recursive comparison"""

    def check_pair(left, right):
        # Types must match; tensors are compared with allclose, everything
        # else with the fixture's equality assertion.
        test_fixture.assertEqual(type(left), type(right))
        if torch.is_tensor(left):
            test_fixture.assertTrue(torch.allclose(left, right))
        else:
            test_fixture.assertEqual(left, right)

    test_fixture.assertEqual(type(batch1), type(batch2))

    if isinstance(batch1, (tuple, list)):
        test_fixture.assertEqual(len(batch1), len(batch2))
        for idx, left in enumerate(batch1):
            check_pair(left, batch2[idx])
    elif isinstance(batch1, dict):
        test_fixture.assertEqual(batch1.keys(), batch2.keys())
        for key, left in batch1.items():
            check_pair(left, batch2[key])
Example #14
Source File: optim_test_util.py From ClassyVision with MIT License | 6 votes |
def _compare_momentum_values(self, optim1, optim2): self.assertEqual(len(optim1["param_groups"]), len(optim2["param_groups"])) for i in range(len(optim1["param_groups"])): self.assertEqual( len(optim1["param_groups"][i]["params"]), len(optim2["param_groups"][i]["params"]), ) if self._check_momentum_buffer(): for j in range(len(optim1["param_groups"][i]["params"])): id1 = optim1["param_groups"][i]["params"][j] id2 = optim2["param_groups"][i]["params"][j] self.assertTrue( torch.allclose( optim1["state"][id1]["momentum_buffer"], optim2["state"][id2]["momentum_buffer"], ) )
Example #15
Source File: test_static_graph.py From pytorch_geometric with MIT License | 6 votes |
def test_static_graph():
    """Compare a batched call against a stacked call on a shared edge_index."""
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    x1 = torch.randn(3, 8)
    x2 = torch.randn(3, 8)
    batch = Batch.from_data_list(
        [Data(edge_index=edge_index, x=feat) for feat in (x1, x2)]
    )
    stacked = torch.stack([x1, x2], dim=0)

    convs = [MyConv(), GCNConv(8, 16), ChebConv(8, 16, K=2)]
    for conv in convs:
        batched_out = conv(batch.x, batch.edge_index)
        assert batched_out.size(0) == 6

        # Treat dim 1 as the node dimension so the stacked input is accepted.
        conv.node_dim = 1
        stacked_out = conv(stacked, edge_index)
        assert stacked_out.size()[:2] == (2, 3)

        flattened = stacked_out.view(-1, stacked_out.size(-1))
        assert torch.allclose(batched_out, flattened)
Example #16
Source File: test_appnp.py From pytorch_geometric with MIT License | 6 votes |
def test_appnp():
    """Check APPNP on edge_index and SparseTensor inputs, eager and scripted."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = APPNP(K=10, alpha=0.1)
    assert conv.__repr__() == 'APPNP(K=10, alpha=0.1)'
    out = conv(x, edge_index)
    assert out.size() == (4, 16)
    assert torch.allclose(conv(x, adj.t()), out, atol=1e-6)

    t = '(Tensor, Tensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x, edge_index).tolist() == out.tolist()

    t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    # Call the scripted module: calling ``conv`` here would only repeat the
    # eager check above and leave the SparseTensor script path unexercised.
    assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)
Example #17
Source File: test_cluster_gcn_conv.py From pytorch_geometric with MIT License | 6 votes |
def test_cluster_gcn_conv():
    """Check ClusterGCNConv on edge_index and SparseTensor inputs, eager and scripted."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = ClusterGCNConv(16, 32, diag_lambda=1.)
    assert repr(conv) == 'ClusterGCNConv(16, 32, diag_lambda=1.0)'

    out = conv(x, edge_index)
    expected = out.tolist()
    assert out.size() == (4, 32)
    assert conv(x, edge_index, size=(4, 4)).tolist() == expected
    assert torch.allclose(conv(x, adj.t()), out)

    # Scripted variant taking a dense edge_index.
    dense_sig = '(Tensor, Tensor, Size) -> Tensor'
    jit = torch.jit.script(conv.jittable(dense_sig))
    assert jit(x, edge_index).tolist() == expected
    assert jit(x, edge_index, size=(4, 4)).tolist() == expected

    # Scripted variant taking a SparseTensor adjacency.
    sparse_sig = '(Tensor, SparseTensor, Size) -> Tensor'
    jit = torch.jit.script(conv.jittable(sparse_sig))
    assert torch.allclose(jit(x, adj.t()), out)
Example #18
Source File: test_rgcn_conv.py From pytorch_geometric with MIT License | 6 votes |
def test_rgcn_conv_equality(conf):
    """Check that RGCNConv and FastRGCNConv give matching outputs.

    Both layers are built right after seeding with the same value and are
    applied to the same graph with four relation types.

    Args:
        conf: ``(num_bases, num_blocks)`` pair forwarded to both constructors.
    """
    num_bases, num_blocks = conf

    x1 = torch.randn(4, 4)
    # The original assigned a smaller edge_index/edge_type pair first and
    # overwrote it immediately; those dead stores are dropped.
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 3],
        [0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1],
    ])
    edge_type = torch.tensor([0, 1, 1, 0, 0, 1, 2, 3, 3, 2, 2, 3])

    # Re-seed before each constructor so both layers are initialised under
    # the same seed.
    torch.manual_seed(12345)
    conv1 = RGCNConv(4, 32, 4, num_bases, num_blocks)

    torch.manual_seed(12345)
    conv2 = FastRGCNConv(4, 32, 4, num_bases, num_blocks)

    out1 = conv1(x1, edge_index, edge_type)
    out2 = conv2(x1, edge_index, edge_type)
    assert torch.allclose(out1, out2, atol=1e-6)
Example #19
Source File: test_agnn_conv.py From pytorch_geometric with MIT License | 6 votes |
def test_agnn_conv(requires_grad):
    """Check AGNNConv on edge_index and SparseTensor inputs, eager and scripted.

    Args:
        requires_grad: forwarded to the ``AGNNConv`` constructor.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = AGNNConv(requires_grad=requires_grad)
    assert conv.__repr__() == 'AGNNConv()'
    out = conv(x, edge_index)
    assert out.size() == (4, 16)
    assert torch.allclose(conv(x, adj.t()), out, atol=1e-6)

    t = '(Tensor, Tensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x, edge_index).tolist() == out.tolist()

    t = '(Tensor, SparseTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    # Call the scripted module: calling ``conv`` here would only repeat the
    # eager check above and leave the SparseTensor script path unexercised.
    assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)
Example #20
Source File: utils.py From ClassyVision with MIT License | 6 votes |
def compare_model_state(test_fixture, state, state2, check_heads=True):
    """Assert that two model state dicts hold numerically close tensors.

    Every key of ``state["model"]["trunk"]`` is compared with the same key in
    ``state2``; a mismatching entry is printed before the assertion fires.
    When ``check_heads`` is true, every tensor under ``state["model"]["heads"]``
    (indexed by block, head id and key) is compared as well.

    Args:
        test_fixture: object exposing ``assertTrue``.
        state: reference state with ``["model"]["trunk"]`` and
            ``["model"]["heads"]`` entries.
        state2: state compared against ``state``; must contain the same keys.
        check_heads: whether the head tensors are compared too.
    """
    trunk1 = state["model"]["trunk"]
    trunk2 = state2["model"]["trunk"]
    for k in trunk1:
        # Compare once and reuse the verdict for both the diagnostic print
        # and the assertion (the tensors were previously compared twice).
        matches = torch.allclose(trunk1[k], trunk2[k])
        if not matches:
            print(k, trunk1[k], trunk2[k])
        test_fixture.assertTrue(matches)

    if check_heads:
        heads2 = state2["model"]["heads"]
        for block, head_states in state["model"]["heads"].items():
            for head_id, states in head_states.items():
                for k in states:
                    test_fixture.assertTrue(
                        torch.allclose(states[k], heads2[block][head_id][k])
                    )
Example #21
Source File: test_arma_conv.py From pytorch_geometric with MIT License | 6 votes |
def test_arma_conv():
    """Check ARMAConv on edge_index and SparseTensor inputs, eager and scripted."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = ARMAConv(16, 32, num_stacks=8, num_layers=4)
    assert conv.__repr__() == 'ARMAConv(16, 32, num_stacks=8, num_layers=4)'
    out = conv(x, edge_index)
    assert out.size() == (4, 32)
    assert conv(x, adj.t()).tolist() == out.tolist()

    t = '(Tensor, Tensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x, edge_index).tolist() == out.tolist()

    t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    # Call the scripted module: calling ``conv`` here would only repeat the
    # eager check above and leave the SparseTensor script path unexercised.
    assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)
Example #22
Source File: test_losses.py From mmdetection with Apache License 2.0 | 6 votes |
def test_ce_loss():
    """Exercise CrossEntropyLoss construction and its weighted/unweighted values."""
    # use_mask and use_sigmoid cannot be true at the same time
    invalid_cfg = {
        'type': 'CrossEntropyLoss',
        'use_mask': True,
        'use_sigmoid': True,
        'loss_weight': 1.0,
    }
    with pytest.raises(AssertionError):
        build_loss(invalid_cfg)

    fake_pred = torch.Tensor([[100, -100]])
    fake_label = torch.Tensor([1]).long()

    # test loss with class weights
    weighted_cfg = {
        'type': 'CrossEntropyLoss',
        'use_sigmoid': False,
        'class_weight': [0.8, 0.2],
        'loss_weight': 1.0,
    }
    loss_cls = build_loss(weighted_cfg)
    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))

    # same prediction without class weights
    plain_cfg = {
        'type': 'CrossEntropyLoss',
        'use_sigmoid': False,
        'loss_weight': 1.0,
    }
    loss_cls = build_loss(plain_cfg)
    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.))
Example #23
Source File: autograd_hacks_test.py From autograd-hacks with The Unlicense | 5 votes |
def subtest_hess_type(hess_type):
    """Compare ``param.hess`` from autograd_hacks with a Hessian from autograd.

    Args:
        hess_type: ``'LeastSquares'`` selects the local least-squares loss;
            any other value selects ``nn.CrossEntropyLoss``. It is also
            forwarded to ``autograd_hacks.backprop_hess``.
    """
    torch.manual_seed(1)
    model = TinyNet()

    def least_squares_loss(data_, targets_):
        # Half the summed squared error, divided by the number of rows.
        assert len(data_) == len(targets_)
        residual = data_ - targets_
        return torch.sum(residual * residual) / 2 / len(data_)

    n = 3
    data = torch.rand(n, 1, 28, 28)

    autograd_hacks.add_hooks(model)
    output = model(data)

    if hess_type != 'LeastSquares':
        # hess_type == 'CrossEntropy'
        targets = torch.LongTensor(n).random_(0, 10)
        loss_fn = nn.CrossEntropyLoss()
    else:
        targets = torch.rand(output.shape)
        loss_fn = least_squares_loss

    # Backprop, clear the stored backprops, then backprop again before
    # computing the Hessians.
    autograd_hacks.backprop_hess(output, hess_type=hess_type)
    autograd_hacks.clear_backprops(model)
    autograd_hacks.backprop_hess(output, hess_type=hess_type)

    autograd_hacks.compute_hess(model)
    autograd_hacks.disable_hooks()

    for layer in model.modules():
        if autograd_hacks.is_supported(layer):
            for param in layer.parameters():
                loss = loss_fn(output, targets)
                reference = hessian(loss, param)
                computed = param.hess
                assert torch.allclose(
                    computed, reference.reshape(computed.shape)
                )
Example #24
Source File: test_nn.py From PySyft with Apache License 2.0 | 5 votes |
def test_conv2d(workers):
    """
    Test the nn.Conv2d module to ensure that it produces the exact same
    output as the primary torch implementation, in the same order.

    Args:
        workers: mapping that provides the "bob", "alice" and "james" workers.
    """
    torch.manual_seed(121)  # Truncation might not always work so we set the random seed

    # Disable mkldnn to avoid rounding errors due to difference in implementation
    mkldnn_enabled_init = torch._C._get_mkldnn_enabled()
    torch._C._set_mkldnn_enabled(False)

    try:
        # Direct Import from Syft
        model = syft_nn.Conv2d(1, 2, 3, bias=True)
        model_1 = nn.Conv2d(1, 2, 3, bias=True)
        model.weight = model_1.weight.fix_prec()
        model.bias = model_1.bias.fix_prec()
        data = torch.rand(10, 1, 28, 28)  # eg. mnist data

        out = model(data.fix_prec()).float_prec()
        out_1 = model_1(data)

        assert torch.allclose(out, out_1, atol=1e-2)

        # Fixed Precision Tensor
        model_2 = model_1.copy().fix_prec()
        out_2 = model_2(data.fix_prec()).float_prec()

        # Note: absolute tolerance can be reduced by increasing
        # precision_fractional of fix_prec()
        assert torch.allclose(out_1, out_2, atol=1e-2)

        # Additive Shared Tensor
        bob, alice, james = (workers["bob"], workers["alice"], workers["james"])
        shared_data = data.fix_prec().share(bob, alice, crypto_provider=james)

        model_3 = model_2.share(bob, alice, crypto_provider=james)
        out_3 = model_3(shared_data).get().float_prec()

        assert torch.allclose(out_1, out_3, atol=1e-2)
    finally:
        # Reset mkldnn to the original state, even when an assertion above
        # fails, so the global flag does not leak into other tests.
        torch._C._set_mkldnn_enabled(mkldnn_enabled_init)
Example #25
Source File: test_nn.py From PySyft with Apache License 2.0 | 5 votes |
def test_cnn_model(workers):
    """Compare a fixed-precision, shared copy of a small CNN with the plain model."""
    torch.manual_seed(121)  # Truncation might not always work so we set the random seed
    bob = workers["bob"]
    alice = workers["alice"]
    james = workers["james"]

    class Net(nn.Module):
        """Two convolutions followed by two linear layers."""

        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5, 1)
            self.conv2 = nn.Conv2d(20, 50, 5, 1)
            self.fc1 = nn.Linear(4 * 4 * 50, 500)
            self.fc2 = nn.Linear(500, 10)

        def forward(self, x):
            # TODO: add F.max_pool2d(x, 2, 2) after each convolution
            # once it is supported with smpc.
            hidden = F.relu(self.conv1(x))
            hidden = F.relu(self.conv2(hidden))
            flat = hidden.view(-1, 4 * 4 * 50)
            return self.fc2(F.relu(self.fc1(flat)))

    model = Net()
    sh_model = copy.deepcopy(model).fix_precision().share(
        alice, bob, crypto_provider=james
    )

    data = torch.zeros((1, 1, 28, 28))
    sh_data = torch.zeros((1, 1, 28, 28)).fix_precision().share(
        alice, bob, crypto_provider=james
    )

    shared_out = sh_model(sh_data).get().float_prec()
    plain_out = model(data)
    assert torch.allclose(shared_out, plain_out, atol=1e-2)
Example #26
Source File: losses_generic_utils_test.py From ClassyVision with MIT License | 5 votes |
def test_two(self):
    """convert_to_one_hot on two single-label rows with three classes."""
    targets = torch.tensor([[0], [1]])
    one_hot_target = convert_to_one_hot(targets, 3)

    expected = torch.tensor([[1, 0, 0], [0, 1, 0]])
    matches = torch.allclose(one_hot_target, expected)
    self.assertTrue(matches)
Example #27
Source File: __init__.py From dgl with Apache License 2.0 | 5 votes |
def allclose(a, b, rtol=1e-4, atol=1e-4):
    """Return whether ``a`` and ``b`` are element-wise close.

    Both tensors are converted with ``.float()`` and moved to the CPU before
    being handed to ``th.allclose`` with the given tolerances.
    """
    lhs = a.float().cpu()
    rhs = b.float().cpu()
    return th.allclose(lhs, rhs, rtol=rtol, atol=atol)
Example #28
Source File: test_differentiable_sgd.py From garage with MIT License | 5 votes |
def test_differentiable_sgd():
    """Test second order derivative after taking optimization step."""
    lr = 0.01
    policy = torch.nn.Linear(10, 10, bias=False)
    diff_sgd = DifferentiableSGD(policy, lr=lr)

    named_theta = dict(policy.named_parameters())
    theta = next(iter(named_theta.values()))

    # create_graph=True keeps the backward graph so the update step can be
    # differentiated through afterwards.
    meta_loss = theta.pow(2).sum()
    meta_loss.backward(create_graph=True)

    diff_sgd.step()

    theta_prime = next(iter(policy.parameters()))
    loss = theta_prime.pow(2).sum()
    update_module_params(policy, named_theta)

    diff_sgd.zero_grad()
    loss.backward()
    result = theta.grad

    dtheta_prime = 1 - 2 * lr  # dtheta_prime/dtheta
    dloss = 2 * theta_prime  # dloss/dtheta_prime
    expected_result = dloss * dtheta_prime  # dloss/dtheta

    assert torch.allclose(result, expected_result)
Example #29
Source File: test_tanh_normal_dist.py From garage with MIT License | 5 votes |
def test_tanh_normal_log_prob(self):
    """Verify the correctness of the tanh_normal log likelihood function."""
    dist = TanhNormal(torch.zeros(1), torch.ones(1))

    pre_tanh_action = torch.Tensor([[2.0960]])
    action = pre_tanh_action.tanh()

    # log_prob is evaluated both with and without the pre-tanh value; the
    # two results must agree with the same reference number.
    log_prob = dist.log_prob(action, pre_tanh_action)
    log_prob_approx = dist.log_prob(action)

    assert torch.allclose(log_prob, torch.Tensor([-0.2798519]))
    assert torch.allclose(log_prob_approx, torch.Tensor([-0.2798519]))

    del dist
Example #30
Source File: test_nn.py From PySyft with Apache License 2.0 | 5 votes |
def test_RNNCell():
    """
    Test the RNNCell module to ensure that it produces the exact same
    output as the primary torch implementation, in the same order.
    """
    # Disable mkldnn to avoid rounding errors due to difference in implementation
    mkldnn_enabled_init = torch._C._get_mkldnn_enabled()
    torch._C._set_mkldnn_enabled(False)

    try:
        batch_size = 5
        input_size = 10
        hidden_size = 50

        test_input = torch.rand(batch_size, input_size)
        test_hidden = torch.rand(batch_size, hidden_size)

        # RNNCell implemented in pysyft
        rnn_syft = syft_nn.RNNCell(input_size, hidden_size, True, "tanh")

        # RNNCell implemented in original pytorch
        rnn_torch = nn.RNNCell(input_size, hidden_size, True, "tanh")

        # Make sure the weights of both RNNCell are identical
        rnn_syft.fc_xh.weight = rnn_torch.weight_ih
        rnn_syft.fc_hh.weight = rnn_torch.weight_hh
        rnn_syft.fc_xh.bias = rnn_torch.bias_ih
        rnn_syft.fc_hh.bias = rnn_torch.bias_hh

        output_syft = rnn_syft(test_input, test_hidden)
        output_torch = rnn_torch(test_input, test_hidden)

        assert torch.allclose(output_syft, output_torch, atol=1e-2)
    finally:
        # Reset mkldnn to the original state, even when the assertion above
        # fails, so the global flag does not leak into other tests.
        torch._C._set_mkldnn_enabled(mkldnn_enabled_init)