Python torch.solve() Examples

The following are 18 code examples of torch.solve(). You can go to the original project or source file by following the link above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
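Note that torch.solve(B, A) solves the linear system AX = B and returns a (solution, LU) tuple; it has been deprecated since PyTorch 1.8 and removed in later releases in favor of torch.linalg.solve(A, B), which takes the matrix first and returns the solution directly. A minimal migration sketch, assuming a PyTorch version where both APIs are still available:

import torch

A = torch.randn(3, 3) + 3 * torch.eye(3)  # a well-conditioned square system matrix
B = torch.randn(3, 2)

# Deprecated API: right-hand side comes first; returns a (solution, LU) namedtuple.
X_old, _ = torch.solve(B, A)

# Replacement API: matrix comes first; returns the solution tensor directly.
X_new = torch.linalg.solve(A, B)

assert torch.allclose(X_old, X_new)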
Example #1
Source File: test_cvxpylayer.py    From cvxpylayers with Apache License 2.0
def test_example(self):
        n, m = 2, 3
        x = cp.Variable(n)
        A = cp.Parameter((m, n))
        b = cp.Parameter(m)
        constraints = [x >= 0]
        objective = cp.Minimize(0.5 * cp.pnorm(A @ x - b, p=1))
        problem = cp.Problem(objective, constraints)
        assert problem.is_dpp()

        cvxpylayer = CvxpyLayer(problem, parameters=[A, b], variables=[x])
        A_tch = torch.randn(m, n, requires_grad=True)
        b_tch = torch.randn(m, requires_grad=True)

        # solve the problem
        solution, = cvxpylayer(A_tch, b_tch)

        # compute the gradient of the sum of the solution with respect to A, b
        solution.sum().backward() 
Example #2
Source File: birkhoff_polytope.py    From geoopt with Apache License 2.0
def proj_tangent(x, u):
    assert x.shape[-2:] == u.shape[-2:], "Wrong shapes"
    x, u = torch.broadcast_tensors(x, u)
    x_shape = x.shape
    x = x.reshape(-1, x_shape[-2], x_shape[-1])
    u = u.reshape(-1, x_shape[-2], x_shape[-1])
    xt = x.transpose(-1, -2)
    batch_size, n = x.shape[0:2]

    I = torch.eye(n, dtype=x.dtype, device=x.device)
    I = I.expand_as(x)

    mu = x * u

    A = linalg.block_matrix([[I, x], [xt, I]])

    B = A[:, :, 1:]

    z1 = mu.sum(dim=-1).unsqueeze(-1)
    zt1 = mu.sum(dim=-2).unsqueeze(-1)

    b = torch.cat([z1, zt1], dim=1)
    rhs = B.transpose(1, 2) @ (b - A[:, :, 0:1])
    lhs = B.transpose(1, 2) @ B
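    # torch.solve(rhs, lhs) solves lhs @ zeta = rhs, i.e. the normal equations
    # (B^T B) zeta = B^T (b - A[:, :, 0:1]) for the projection coefficients.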
    zeta, _ = torch.solve(rhs, lhs)
    alpha = torch.cat(
        [torch.ones(batch_size, 1, 1, dtype=x.dtype), zeta[:, 0 : n - 1]], dim=1
    )
    beta = zeta[:, n - 1 : 2 * n - 1]
    rgrad = mu - (alpha + beta.transpose(-1, -2)) * x

    rgrad = rgrad.reshape(x_shape)
    return rgrad 
Example #3
Source File: stiefel.py    From geoopt with Apache License 2.0
def _transp_follow_one(
        self, x: torch.Tensor, v: torch.Tensor, *, u: torch.Tensor
    ) -> torch.Tensor:
        a = self._amat(x, u)
        rhs = v + 1 / 2 * a @ v
        lhs = -1 / 2 * a
        lhs[..., torch.arange(a.shape[-2]), torch.arange(x.shape[-2])] += 1
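        # torch.solve(rhs, lhs) solves lhs @ qv = rhs, i.e. (I - a/2) qv = (I + a/2) v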
        qv, _ = torch.solve(rhs, lhs)
        return qv 
Example #4
Source File: test_manifold_basic.py    From geoopt with Apache License 2.0
def proju_original(x, u):
    import geoopt.linalg as linalg

    # takes batch data
    # batch_size, n, _ = x.shape
    x_shape = x.shape
    x = x.reshape(-1, x_shape[-2], x_shape[-1])
    batch_size, n = x.shape[0:2]

    e = torch.ones(batch_size, n, 1, dtype=x.dtype)
    I = torch.unsqueeze(torch.eye(x.shape[-1], dtype=x.dtype), 0).repeat(
        batch_size, 1, 1
    )

    mu = x * u

    A = linalg.block_matrix([[I, x], [torch.transpose(x, 1, 2), I]])

    B = A[:, :, 1:]
    b = torch.cat(
        [
            torch.sum(mu, dim=2, keepdim=True),
            torch.transpose(torch.sum(mu, dim=1, keepdim=True), 1, 2),
        ],
        dim=1,
    )

    zeta, _ = torch.solve(B.transpose(1, 2) @ (b - A[:, :, 0:1]), B.transpose(1, 2) @ B)
    alpha = torch.cat(
        [torch.ones(batch_size, 1, 1, dtype=x.dtype), zeta[:, 0 : n - 1]], dim=1
    )
    beta = zeta[:, n - 1 : 2 * n - 1]
    rgrad = mu - (alpha @ e.transpose(1, 2) + e @ beta.transpose(1, 2)) * x

    rgrad = rgrad.reshape(x_shape)
    return rgrad 
Example #5
Source File: gaussian.py    From torch-kalman with MIT License
def kalman_gain(covariance: Tensor, system_covariance: Tensor, H: Tensor):
        Ht = H.permute(0, 2, 1)
        covs_measured = torch.bmm(covariance, Ht)

        A = system_covariance.permute(0, 2, 1)
        B = covs_measured.permute(0, 2, 1)
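        # torch.solve(B, A) solves A @ Kt = B; transposing back yields
        # K = covariance @ H^T @ system_covariance^{-1} without forming an explicit inverse.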
        Kt, _ = torch.solve(B, A)
        K = Kt.permute(0, 2, 1)

        return K 
Example #6
Source File: test_pivoted_cholesky.py    From gpytorch with MIT License
def test_solve_qr(self, dtype=torch.float64, tol=1e-8):
        size = 50
        X = torch.rand((size, 2)).to(dtype=dtype)
        y = torch.sin(torch.sum(X, 1)).unsqueeze(-1).to(dtype=dtype)
        with settings.min_preconditioning_size(0):
            noise = torch.DoubleTensor(size).uniform_(math.log(1e-3), math.log(1e-1)).exp_().to(dtype=dtype)
            lazy_tsr = RBFKernel().to(dtype=dtype)(X).evaluate_kernel().add_diag(noise)
            precondition_qr, _, logdet_qr = lazy_tsr._preconditioner()

            F = lazy_tsr._piv_chol_self
            M = noise.diag() + F.matmul(F.t())

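        # torch.solve(y, M) returns (solution, LU); index [0] gives the exact solution of M x = y.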
        x_exact = torch.solve(y, M)[0]
        x_qr = precondition_qr(y)

        self.assertTrue(approx_equal(x_exact, x_qr, tol))

        logdet = 2 * torch.cholesky(M).diag().log().sum(-1)
        self.assertTrue(approx_equal(logdet, logdet_qr, tol)) 
Example #7
Source File: test_pivoted_cholesky.py    From gpytorch with MIT License
def test_solve_qr_constant_noise(self, dtype=torch.float64, tol=1e-8):
        size = 50
        X = torch.rand((size, 2)).to(dtype=dtype)
        y = torch.sin(torch.sum(X, 1)).unsqueeze(-1).to(dtype=dtype)

        with settings.min_preconditioning_size(0):
            noise = 1e-2 * torch.ones(size, dtype=dtype)
            lazy_tsr = RBFKernel().to(dtype=dtype)(X).evaluate_kernel().add_diag(noise)
            precondition_qr, _, logdet_qr = lazy_tsr._preconditioner()

            F = lazy_tsr._piv_chol_self
        M = noise.diag() + F.matmul(F.t())

        x_exact = torch.solve(y, M)[0]
        x_qr = precondition_qr(y)

        self.assertTrue(approx_equal(x_exact, x_qr, tol))

        logdet = 2 * torch.cholesky(M).diag().log().sum(-1)
        self.assertTrue(approx_equal(logdet, logdet_qr, tol)) 
Example #8
Source File: test_minres.py    From gpytorch with MIT License
def _test_minres(self, rhs_shape, shifts=None, matrix_batch_shape=torch.Size([])):
        size = rhs_shape[-2] if len(rhs_shape) > 1 else rhs_shape[-1]
        rhs = torch.randn(rhs_shape, dtype=torch.float64)

        matrix = torch.randn(*matrix_batch_shape, size, size, dtype=torch.float64)
        matrix = matrix.matmul(matrix.transpose(-1, -2))
        matrix.div_(matrix.norm())
        matrix.add_(torch.eye(size, dtype=torch.float64).mul_(1e-1))

        # Compute solves with minres
        if shifts is not None:
            shifts = shifts.type_as(rhs)

        with gpytorch.settings.minres_tolerance(1e-6):
            solves = minres(matrix, rhs=rhs, value=-1, shifts=shifts)

        # Make sure that we're not getting weird batch dim effects
        while matrix.dim() < len(rhs_shape):
            matrix = matrix.unsqueeze(0)

        # Maybe add shifts
        if shifts is not None:
            matrix = matrix - torch.mul(
                torch.eye(size, dtype=torch.float64), shifts.view(*shifts.shape, *[1 for _ in matrix.shape])
            )

        # Compute solves exactly
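        # minres was called with value=-1, so the exact reference solve uses -matrix.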
        actual, _ = torch.solve(rhs.unsqueeze(-1) if rhs.dim() == 1 else rhs, -matrix)
        if rhs.dim() == 1:
            actual = actual.squeeze(-1)

        self.assertAllClose(solves, actual, atol=1e-3, rtol=1e-4) 
Example #9
Source File: wishart_prior.py    From gpytorch with MIT License
def log_prob(self, X):
        logdetp = torch.logdet(X)
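        # torch.solve(self.K, X) solves X @ pinvK = self.K, i.e. pinvK = X^{-1} K.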
        pinvK = torch.solve(self.K, X)[0]
        trpinvK = torch.diagonal(pinvK, dim1=-2, dim2=-1).sum(-1)  # trace in batch mode
        return self.C - 0.5 * ((self.nu + 2 * self.n) * logdetp + trpinvK) 
Example #10
Source File: echo_state_network.py    From pytorch-esn with MIT License
def fit(self):
        if self.readout_training in {'gd', 'svd'}:
            return

        if self.readout_training == 'cholesky':
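            # Solve the ridge normal equations (X^T X + lambda_reg * I) Z = X^T y,
            # then transpose to obtain the readout weights W.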
            W = torch.solve(self.XTy,
                           self.XTX + self.lambda_reg * torch.eye(
                               self.XTX.size(0), device=self.XTX.device))[0].t()
            self.XTX = None
            self.XTy = None

            self.readout.bias = nn.Parameter(W[:, 0])
            self.readout.weight = nn.Parameter(W[:, 1:])
        elif self.readout_training == 'inv':
            I = (self.lambda_reg * torch.eye(self.XTX.size(0))).to(
                self.XTX.device)
            A = self.XTX + I

            if torch.det(A) != 0:
                W = torch.mm(torch.inverse(A), self.XTy).t()
            else:
                pinv = torch.pinverse(A)
                W = torch.mm(pinv, self.XTy).t()

            self.readout.bias = nn.Parameter(W[:, 0])
            self.readout.weight = nn.Parameter(W[:, 1:])

            self.XTX = None
            self.XTy = None 
Example #11
Source File: test_cvxpylayer.py    From cvxpylayers with Apache License 2.0
def test_least_squares(self):
        set_seed(243)
        m, n = 100, 20

        A = cp.Parameter((m, n))
        b = cp.Parameter(m)
        x = cp.Variable(n)
        obj = cp.sum_squares(A@x - b) + cp.sum_squares(x)
        prob = cp.Problem(cp.Minimize(obj))
        prob_th = CvxpyLayer(prob, [A, b], [x])

        A_th = torch.randn(m, n).double().requires_grad_()
        b_th = torch.randn(m).double().requires_grad_()

        x = prob_th(A_th, b_th, solver_args={"eps": 1e-10})[0]

        def lstsq(A, b):
            return torch.solve(
                (A.t() @ b).unsqueeze(1),
                A.t() @ A + torch.eye(n).double())[0]

        x_lstsq = lstsq(A_th, b_th)

        grad_A_cvxpy, grad_b_cvxpy = grad(x.sum(), [A_th, b_th])
        grad_A_lstsq, grad_b_lstsq = grad(x_lstsq.sum(), [A_th, b_th])

        self.assertAlmostEqual(
            torch.norm(grad_A_cvxpy - grad_A_lstsq).item(), 0.0)
        self.assertAlmostEqual(
            torch.norm(grad_b_cvxpy - grad_b_lstsq).item(), 0.0)
Example #12
Source File: test_cvxpylayer.py    From cvxpylayers with Apache License 2.0
def test_broadcasting(self):
        set_seed(243)
        n_batch, m, n = 2, 100, 20

        A = cp.Parameter((m, n))
        b = cp.Parameter(m)
        x = cp.Variable(n)
        obj = cp.sum_squares(A@x - b) + cp.sum_squares(x)
        prob = cp.Problem(cp.Minimize(obj))
        prob_th = CvxpyLayer(prob, [A, b], [x])

        A_th = torch.randn(m, n).double().requires_grad_()
        b_th = torch.randn(m).double().unsqueeze(0).repeat(n_batch, 1) \
            .requires_grad_()
        b_th_0 = b_th[0]

        x = prob_th(A_th, b_th, solver_args={"eps": 1e-10})[0]

        def lstsq(A, b):
            return torch.solve(
                (A.t() @ b).unsqueeze(1),
                A.t() @ A + torch.eye(n).double())[0]

        x_lstsq = lstsq(A_th, b_th_0)

        grad_A_cvxpy, grad_b_cvxpy = grad(x.sum(), [A_th, b_th])
        grad_A_lstsq, grad_b_lstsq = grad(x_lstsq.sum(), [A_th, b_th_0])

        self.assertAlmostEqual(
            torch.norm(grad_A_cvxpy / n_batch - grad_A_lstsq).item(), 0.0)
        self.assertAlmostEqual(
            torch.norm(grad_b_cvxpy[0] - grad_b_lstsq).item(), 0.0)
Example #13
Source File: test_cvxpylayer.py    From cvxpylayers with Apache License 2.0
def test_basic_gp(self):
        set_seed(243)

        x = cp.Variable(pos=True)
        y = cp.Variable(pos=True)
        z = cp.Variable(pos=True)

        a = cp.Parameter(pos=True, value=2.0)
        b = cp.Parameter(pos=True, value=1.0)
        c = cp.Parameter(value=0.5)

        objective_fn = 1/(x*y*z)
        constraints = [a*(x*y + x*z + y*z) <= b, x >= y**c]
        problem = cp.Problem(cp.Minimize(objective_fn), constraints)
        problem.solve(cp.SCS, gp=True, eps=1e-12)

        layer = CvxpyLayer(
            problem, parameters=[a, b, c], variables=[x, y, z], gp=True)
        a_tch = torch.tensor(2.0, requires_grad=True)
        b_tch = torch.tensor(1.0, requires_grad=True)
        c_tch = torch.tensor(0.5, requires_grad=True)
        with torch.no_grad():
            x_tch, y_tch, z_tch = layer(a_tch, b_tch, c_tch)

        self.assertAlmostEqual(x.value, x_tch.detach().numpy(), places=5)
        self.assertAlmostEqual(y.value, y_tch.detach().numpy(), places=5)
        self.assertAlmostEqual(z.value, z_tch.detach().numpy(), places=5)

        torch.autograd.gradcheck(lambda a, b, c: layer(
            a, b, c, solver_args={
                "eps": 1e-12, "acceleration_lookback": 0})[0].sum(),
                (a_tch, b_tch, c_tch), atol=1e-3, rtol=1e-3) 
Example #14
Source File: pairwise_gp.py    From botorch with MIT License
def _util_newton_updates(self, x0, max_iter=100, xtol=None) -> Tensor:
        r"""Make `max_iter` newton updates on utility.

        This is used in `forward` to calculate and fill in gradient into tensors.
        Instead of doing utility -= H^-1 @ g, use substition method.
        See more explanation in _update_utility_derived_values.

        Args:
            x0: A `batch_size x n` dimension tensor, initial values
            max_iter: Max number of iterations
            xtol: Stop creteria. If `None`, do not stop until
                finishing `max_iter` updates
        """
        xtol = float("-Inf") if xtol is None else xtol
        dp, D, DT, sn, ch, ci = (
            self.datapoints,
            self.D,
            self.DT,
            self.std_noise,
            self.covar_chol,
            self.covar_inv,
        )
        covar = self.covar
        diff = float("Inf")
        i = 0
        x = x0
        eye = None
        while i < max_iter and diff > xtol:
            hl = self._hess_likelihood_f_sum(x, D, DT, sn)
            cov_hl = covar @ hl
            if eye is None:
                eye = torch.eye(cov_hl.size(-1)).expand(cov_hl.shape)
            cov_hl = cov_hl + eye  # add the identity to cov_hl
            g = self._grad_posterior_f(x, dp, D, DT, sn, ch, ci)
            cov_g = covar @ g.unsqueeze(-1)
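            # torch.solve returns a namedtuple; .solution solves (covar @ hl + I) @ x_update = covar @ g.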
            x_update = torch.solve(cov_g, cov_hl).solution.squeeze(-1)
            x_next = x - x_update
            diff = torch.norm(x - x_next)
            x = x_next
            i += 1
        return x

    # ============== public APIs ============== 
Example #15
Source File: sparse_image_warp_pytorch.py    From Speech-Transformer with MIT License
def solve_interpolation(train_points, train_values, order, regularization_weight):
    b, n, d = train_points.shape
    k = train_values.shape[-1]

    # First, rename variables so that the notation (c, f, w, v, A, B, etc.)
    # follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
    # To account for python style guidelines we use
    # matrix_a for A and matrix_b for B.

    c = train_points
    f = train_values.float()

    matrix_a = phi(cross_squared_distance_matrix(c, c), order).unsqueeze(0)  # [b, n, n]
    #     if regularization_weight > 0:
    #         batch_identity_matrix = array_ops.expand_dims(
    #           linalg_ops.eye(n, dtype=c.dtype), 0)
    #         matrix_a += regularization_weight * batch_identity_matrix

    # Append ones to the feature values for the bias term in the linear model.
    ones = torch.ones(1, dtype=train_points.dtype).view([-1, 1, 1])
    matrix_b = torch.cat((c, ones), 2).float()  # [b, n, d + 1]

    # [b, n + d + 1, n]
    left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1)

    num_b_cols = matrix_b.shape[2]  # d + 1

    # In TensorFlow, zeros are used here, but PyTorch's gesv/solve fails with exact zeros
    # for reasons we don't understand. Instead we fill this block with tiny zero-mean
    # random values (standard-normal samples scaled down by 1e10).
    lhs_zeros = torch.randn((b, num_b_cols, num_b_cols)) / 1e10
    right_block = torch.cat((matrix_b, lhs_zeros),
                            1)  # [b, n + d + 1, d + 1]
    lhs = torch.cat((left_block, right_block),
                    2)  # [b, n + d + 1, n + d + 1]

    rhs_zeros = torch.zeros((b, d + 1, k), dtype=train_points.dtype).float()
    rhs = torch.cat((f, rhs_zeros), 1)  # [b, n + d + 1, k]

    # Then, solve the linear system and unpack the results.
    X, LU = torch.solve(rhs, lhs)
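    # w holds the n radial-basis coefficients; v holds the d + 1 affine coefficients.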
    w = X[:, :n, :]
    v = X[:, n:, :]

    return w, v 
Example #16
Source File: torch_wpe.py    From nara_wpe with MIT License
def wpe_v6(Y, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode='full'):
    """
    A short variant of wpe_v7 with no external references.
    Applicable inside for-loops.

    >>> T = np.random.randint(100, 120)
    >>> D = np.random.randint(2, 8)
    >>> K = np.random.randint(3, 5)
    >>> delay = np.random.randint(0, 2)
    
    # Real test:
    >>> Y = np.random.normal(size=(D, T))
    >>> from nara_wpe import wpe as np_wpe
    >>> desired = np_wpe.wpe_v6(Y, K, delay, statistics_mode='full')
    >>> actual = wpe_v6(torch.tensor(Y), K, delay, statistics_mode='full').numpy()
    >>> np.testing.assert_allclose(actual, desired, atol=1e-6)

    # Complex test:
    >>> Y = np.random.normal(size=(D, T)) + 1j * np.random.normal(size=(D, T))
    >>> from nara_wpe import wpe as np_wpe
    >>> desired = np_wpe.wpe_v6(Y, K, delay, statistics_mode='full')
    >>> actual = wpe_v6(torch.tensor(Y), K, delay, statistics_mode='full').numpy()
    >>> np.testing.assert_allclose(actual, desired, atol=1e-6)
    """

    if statistics_mode == 'full':
        s = Ellipsis
    elif statistics_mode == 'valid':
        s = (Ellipsis, slice(delay + taps - 1, None))
    else:
        raise ValueError(statistics_mode)

    X = torch.clone(Y)
    Y_tilde = build_y_tilde(Y, taps, delay)
    for iteration in range(iterations):
        inverse_power = get_power_inverse(X, psd_context=psd_context)
        Y_tilde_inverse_power = Y_tilde * inverse_power[..., None, :]
        R = torch.matmul(Y_tilde_inverse_power[s], hermite(Y_tilde[s]))
        P = torch.matmul(Y_tilde_inverse_power[s], hermite(Y[s]))
        # G = _stable_solve(R, P)
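        # torch.solve(P, R) solves R @ G = P for the prediction filter taps G.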
        G, _ = torch.solve(P, R)
        X = Y - torch.matmul(hermite(G), Y_tilde)

    return X 
Example #17
Source File: models.py    From EndoscopyDepthEstimation-Pytorch with GNU General Public License v3.0
def _warp_coordinate_generate(depth_maps_1, img_masks, translation_vectors, rotation_matrices, intrinsic_matrices):
    # Generate a meshgrid for each depth map to calculate value
    num_batch, height, width, channels = depth_maps_1.shape

    y_grid, x_grid = torch.meshgrid(
        [torch.arange(start=0, end=height, dtype=torch.float32).cuda(),
         torch.arange(start=0, end=width, dtype=torch.float32).cuda()])

    x_grid = x_grid.view(1, height, width, 1)
    y_grid = y_grid.view(1, height, width, 1)

    ones_grid = torch.ones((1, height, width, 1), dtype=torch.float32).cuda()

    # intrinsic_matrix_inverse = intrinsic_matrix.inverse()
    eye = torch.eye(3).float().cuda().view(1, 3, 3).expand(intrinsic_matrices.shape[0], -1, -1)
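    # Batched inverse via a linear solve: solving intrinsic_matrices @ X = eye gives the inverse intrinsics.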
    intrinsic_matrices_inverse, _ = torch.solve(eye, intrinsic_matrices)

    rotation_matrices_inverse = rotation_matrices.transpose(1, 2)

    # The following is when we have different intrinsic matrices for samples within a batch
    temp_mat = torch.bmm(intrinsic_matrices, rotation_matrices_inverse)
    W = torch.bmm(temp_mat, -translation_vectors)
    M = torch.bmm(temp_mat, intrinsic_matrices_inverse)

    mesh_grid = torch.cat((x_grid, y_grid, ones_grid), dim=-1).view(height, width, 3, 1)
    intermediate_result = torch.matmul(M.view(-1, 1, 1, 3, 3), mesh_grid).view(-1, height, width, 3)

    depth_maps_2_calculate = W.view(-1, 3).narrow(dim=-1, start=2, length=1).view(-1, 1, 1, 1) + torch.mul(
        depth_maps_1,
        intermediate_result.narrow(dim=-1, start=2, length=1).view(-1, height, width, 1))

    # expand operation doesn't allocate new memory (repeat does)
    depth_maps_2_calculate = torch.tensor(1.0e30).float().cuda() * (torch.tensor(1.0).float().cuda() - img_masks) + \
                             img_masks * depth_maps_2_calculate

    # These are the source coordinates in coordinate system 2, ordered in coordinate
    # system 1, used to warp image 2 into coordinate system 1.
    u_2 = (W.view(-1, 3).narrow(dim=-1, start=0, length=1).view(-1, 1, 1, 1) +
           torch.mul(depth_maps_1,
                     intermediate_result.narrow(dim=-1, start=0, length=1).view(-1, height, width, 1))
           ) / depth_maps_2_calculate

    v_2 = (W.view(-1, 3).narrow(dim=-1, start=1, length=1).view(-1, 1, 1, 1) +
           torch.mul(depth_maps_1,
                     intermediate_result.narrow(dim=-1, start=1, length=1).view(-1, height, width, 1))
           ) / depth_maps_2_calculate
    return [u_2, v_2]


# Optical flow for frame 1 to frame 2 
Example #18
Source File: sparse_image_warp_pytorch.py    From SpecAugment with Apache License 2.0
def solve_interpolation(train_points, train_values, order, regularization_weight):
    b, n, d = train_points.shape
    k = train_values.shape[-1]

    # First, rename variables so that the notation (c, f, w, v, A, B, etc.)
    # follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
    # To account for python style guidelines we use
    # matrix_a for A and matrix_b for B.

    c = train_points
    f = train_values.float()

    matrix_a = phi(cross_squared_distance_matrix(c, c), order).unsqueeze(0)  # [b, n, n]
    #     if regularization_weight > 0:
    #         batch_identity_matrix = array_ops.expand_dims(
    #           linalg_ops.eye(n, dtype=c.dtype), 0)
    #         matrix_a += regularization_weight * batch_identity_matrix

    # Append ones to the feature values for the bias term in the linear model.
    ones = torch.ones(1, dtype=train_points.dtype).view([-1, 1, 1])
    matrix_b = torch.cat((c, ones), 2).float()  # [b, n, d + 1]

    # [b, n + d + 1, n]
    left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1)

    num_b_cols = matrix_b.shape[2]  # d + 1

    # In TensorFlow, zeros are used here, but PyTorch's gesv/solve fails with exact zeros
    # for reasons we don't understand. Instead we fill this block with tiny zero-mean
    # random values (standard-normal samples scaled down by 1e10).
    lhs_zeros = torch.randn((b, num_b_cols, num_b_cols)) / 1e10
    right_block = torch.cat((matrix_b, lhs_zeros),
                            1)  # [b, n + d + 1, d + 1]
    lhs = torch.cat((left_block, right_block),
                    2)  # [b, n + d + 1, n + d + 1]

    rhs_zeros = torch.zeros((b, d + 1, k), dtype=train_points.dtype).float()
    rhs = torch.cat((f, rhs_zeros), 1)  # [b, n + d + 1, k]

    # Then, solve the linear system and unpack the results.
    X, LU = torch.solve(rhs, lhs)
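    # w holds the n radial-basis coefficients; v holds the d + 1 affine coefficients.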
    w = X[:, :n, :]
    v = X[:, n:, :]

    return w, v