Python torch.autograd.Function() Examples
The following are 6 code examples of `torch.autograd.Function`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `torch.autograd`, or try the search function.
Example #1
Source File: losses.py From soccerontable with BSD 2-Clause "Simplified" License | 6 votes |
def forward(self, input, target, mask):
    """Evaluate ``self.criterion`` against a masked target.

    ``target`` is multiplied elementwise by ``mask`` before it is handed to
    the criterion; ``input`` is passed through unchanged.  The resulting loss
    is stored on ``self.loss`` and also returned.
    """
    # NOTE(review): only the target is masked; wherever mask is zero the
    # target becomes zero while the prediction is left as-is. Confirm this
    # is intended (vs. masking both input and target).
    masked_target = target * mask
    loss_value = self.criterion(input, masked_target)
    self.loss = loss_value
    return self.loss

# class DepthTo3D(Function):
#
#     def forward(self, input, pix_inv, R_inv, T):
#         self.save_for_backward(input, pix_inv, R_inv, T)
#         return torch.bmm(R_inv, input.resize(bs, 1, sx * sy).repeat(1, 3, 1) * pix_inv - T_var.repeat(1, 1, sx * sy)).resize(bs, 3, sx, sy)
#
#     def backward(self):
#         # pix_inv, R_inv, T = self.saved_tensors
#         # return grad_input,
Example #2
Source File: test_operators.py From onnx-fb-universe with MIT License | 6 votes |
def test_symbolic_mismatch(self):
    """Exporting a Function whose ``symbolic`` signature does not match its
    ``forward`` must raise a TypeError that names the offending Function."""

    class MyFun(Function):
        @staticmethod
        def symbolic(g, x):
            # The inside of this function should never be invoked, because
            # we will fail due to an argument mismatch first. Raise
            # explicitly so the guard also holds under ``python -O``.
            raise AssertionError("MyFun.symbolic should never be invoked")

        @staticmethod
        def forward(ctx, x, y):
            return x + y

    x = Variable(torch.ones(2, 2))
    y = Variable(torch.ones(2, 2))
    # NB: Don't use expect test here, the type error wobbles depending
    # on Python version
    with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"):
        # ``apply`` is a classmethod; instantiating an autograd Function is
        # deprecated in PyTorch, so call it on the class directly.
        export_to_string(FuncModule(MyFun.apply), (x, y))

# TODO: Do an nn style test for these
Example #3
Source File: test_operators.py From onnx-fb-universe with MIT License | 6 votes |
def test_at_op(self):
    """Export a Function whose ONNX symbolic emits an ATen op via ``g.at``."""

    class MyFun(Function):
        @staticmethod
        def forward(ctx, x):
            return x + x

        @staticmethod
        def symbolic(g, x):
            # Emit the ATen "add" op directly, using the same operand twice.
            return g.at("add", x, x)

    class MyModule(Module):
        def forward(self, x):
            return MyFun.apply(x)

    sample = Variable(torch.randn(3, 4))
    self.assertONNX(MyModule(), sample)
Example #4
Source File: test_verify.py From onnx-fb-universe with MIT License | 6 votes |
def test_result_different(self):
    """Verification must fail for a Function whose ONNX symbolic (``Add``)
    disagrees with what its PyTorch ``forward`` actually computes."""

    class BrokenAdd(Function):
        @staticmethod
        def symbolic(g, a, b):
            return g.op("Add", a, b)

        @staticmethod
        def forward(ctx, a, b):
            # Deliberately wrong: subtracts instead of adding, so the
            # exported graph and the eager result disagree.
            return a.sub(b)  # yahaha! you found me!

    class MyModel(Module):
        def forward(self, x, y):
            # ``apply`` is a classmethod; instantiating an autograd Function
            # is deprecated in PyTorch, so call it on the class directly.
            return BrokenAdd.apply(x, y)

    x = Variable(torch.Tensor([1, 2]))
    y = Variable(torch.Tensor([3, 4]))
    self.assertVerifyExpectFail(MyModel(), (x, y), backend)
Example #5
Source File: miscs.py From homura with Apache License 2.0 | 6 votes |
def straight_backprop(function):
    """Wrap ``function`` so that its backward pass is the identity.

    The forward pass returns ``function(inputs)`` unchanged, while the
    backward pass hands the incoming gradient straight back to the input,
    as if ``function`` were the identity map (a straight-through gradient).

    Example::

        straight_backprop_relu = straight_backprop(F.relu)
        straight_backprop_relu(tensor)

    :param function: original function, called with a single argument
    :return: the ``apply`` callable of a custom autograd Function that runs
        ``function`` forward and passes gradients through unchanged
    """

    class _StraightBackprop(Function):
        @staticmethod
        def forward(ctx, x):
            result = function(x)
            return result

        @staticmethod
        def backward(ctx, grad):
            # Ignore the true derivative of ``function``; pass the upstream
            # gradient through untouched.
            return grad

    apply_fn = _StraightBackprop.apply
    return apply_fn
Example #6
Source File: test_gridpooling.py From occupancy_networks with MIT License | 5 votes |
def grid_pooling_auto(pts, feat):
    """Max-pool per-point features into the cells of a regular 3-D grid.

    Cells are enumerated from the lower corners given by the module-level
    ``x_grids``/``y_grids``/``z_grids`` (via an ``indexing='ij'`` meshgrid,
    flattened), each extending ``len_cell`` along every axis. For each cell,
    the indices returned by ``pts_in_cell`` select rows of ``feat``. Those
    rows are max-pooled channel-wise into the cell's row of the output.
    Cells with no points keep an all-zero row.

    :param pts: point coordinates, forwarded to ``pts_in_cell`` with a leading
        batch dim added (presumably shape (N, 3) — TODO confirm)
    :param feat: per-point features indexed by point row (presumably (N, C))
    :return: tensor of shape (num_cells, C) with the pooled cell features
    """
    # Lower corner of every grid cell, flattened in C order of the 'ij' grid.
    corner_x, corner_y, corner_z = (
        axis.flatten()
        for axis in np.meshgrid(x_grids[:-1], y_grids[:-1], z_grids[:-1], indexing='ij')
    )
    n_cells = (len(x_grids) - 1) * (len(y_grids) - 1) * (len(z_grids) - 1)
    feat_cell = Variable(torch.zeros(n_cells, C).type(dtype))

    for cell_idx, (x0, y0, z0) in enumerate(zip(corner_x, corner_y, corner_z)):
        bounds = [x0, y0, z0, x0 + len_cell, y0 + len_cell, z0 + len_cell]
        members = pts_in_cell(pts.unsqueeze(0), bounds)
        if not len(members) > 0:
            # Empty cell: leave its row at zero.
            continue
        members = torch.LongTensor(members).type(dtype_long)
        cell_feat = feat[members, :]
        # Pool over the point axis: (k, C) -> (1, C, k) -> MaxPool(k) -> (1, C, 1).
        # (An equivalent alternative is torch.max(cell_feat, 0).)
        pool = nn.MaxPool1d(len(members))
        pooled = pool(cell_feat.t().unsqueeze(0))
        feat_cell[cell_idx, :] = pooled.squeeze()
    return feat_cell

# class GridPooling(Function):
#     def forward(self, points, feat_points):
#         feat_cells = torch.zeros(W*H*D, C).type(dtype)
#         indices = -1 * torch.ones(W*H*D, C).type(dtype_long)
#         shape = torch.LongTensor([W, H, D]).type(dtype_long)
#         forward_utils.grid_pooling_forward(points, feat_points, shape, feat_cells, indices)
#         self.saved_indices = indices
#         return feat_cells
#
#     def backward(self, grad_output):
#         grad_points = torch.zeros(N, C).type(torch.FloatTensor)
#         forward_utils.grid_pooling_backward(grad_output, self.saved_indices, grad_points)
#         return None, grad_points