Python chainer.functions.squeeze() Examples

The following are 30 code examples of chainer.functions.squeeze(). The original project and source file for each example are noted above it. You may also want to check out all available functions and classes of the module chainer.functions.
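Before diving into the examples, here is a minimal, self-contained sketch of what the function does (illustrative only, assuming nothing beyond NumPy and Chainer): chainer.functions.squeeze() removes size-1 axes from an array or Variable, either all of them or only those named by the axis argument.

import numpy as np
import chainer.functions as F

x = np.zeros((2, 1, 3, 1), dtype=np.float32)

# Without axis, every size-1 axis is removed.
print(F.squeeze(x).shape)               # (2, 3)

# With axis, only the named size-1 axes are removed.
print(F.squeeze(x, axis=1).shape)       # (2, 3, 1)
print(F.squeeze(x, axis=(1, 3)).shape)  # (2, 3)

# Selecting an axis whose size is not 1 raises an error (see Example #9),
# and a non-integer axis raises a TypeError (see Example #10).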
Example #1
Source File: pcl.py    From chainerrl with MIT License
def update(self, loss):

        self.average_loss += (
            (1 - self.average_loss_decay) *
            (asfloat(loss) - self.average_loss))

        # Compute gradients using thread-specific model
        self.model.cleargrads()
        F.squeeze(loss).backward()
        if self.train_async:
            # Copy the gradients to the globally shared model
            copy_param.copy_grad(
                target_link=self.shared_model, source_link=self.model)
            if self.process_idx == 0:
                xp = self.xp
                norm = sum(xp.sum(xp.square(param.grad))
                           for param in self.optimizer.target.params()
                           if param.grad is not None)
                self.logger.debug('grad norm:%s', norm)
        self.optimizer.update()

        if self.train_async:
            self.sync_parameters()
        if isinstance(self.model, Recurrent):
            self.model.unchain_backward() 
Example #2
Source File: reinforce.py    From chainerrl with MIT License
def accumulate_grad(self):
        if self.n_backward == 0:
            self.model.cleargrads()
        # Compute losses
        losses = []
        for r_seq, log_prob_seq, ent_seq in zip(self.reward_sequences,
                                                self.log_prob_sequences,
                                                self.entropy_sequences):
            assert len(r_seq) - 1 == len(log_prob_seq) == len(ent_seq)
            # Convert rewards into returns (=sum of future rewards)
            R_seq = np.cumsum(list(reversed(r_seq[1:])))[::-1]
            for R, log_prob, entropy in zip(R_seq, log_prob_seq, ent_seq):
                loss = -R * log_prob - self.beta * entropy
                losses.append(loss)
        total_loss = chainerrl.functions.sum_arrays(losses)
        # When self.batchsize is future.types.newint.newint, dividing a
        # Variable by it will raise an error, so it is manually converted
        # to float here.
        total_loss /= float(self.batchsize)
        F.squeeze(total_loss).backward()
        self.reward_sequences = [[]]
        self.log_prob_sequences = [[]]
        self.entropy_sequences = [[]]
        self.n_backward += 1 
Example #3
Source File: chainer_functions.py    From chainer-compiler with MIT License
def __call__(self, ty_args, ty_kwargs):
        x_type, = ty_args

        self.axis, lacks_axis = get_kwarg(ty_kwargs, 'axis', None)
        if isinstance(self.axis, int):
            self.axis = (self.axis,)

        if is_incomplete_shape(x_type.shape):
            # TODO: use ty_kwargs['axis'].size()
            if lacks_axis or self.axis is None:
                assert False, "chainer.fucntions.squeeze: cannot guess ndim of return type"

        self.check_type_forward(make_multiple_tc_variable(ty_args, ('x',)))

        if self.axis is not None:
            for i in self.axis:
                assert x_type.shape[i] == 1, "chainer.functions.squeeze: invalid axis"
        return self.infer_return(x_type) 
Example #4
Source File: light_voxelnet.py    From voxelnet_chainer with MIT License
def __call__(self, x, *args):
        """
           Args:
               x (ndarray): Shape is (Batch * K, 7, t).
                             each set has (xi, yi, zi, ri, xi - vx, yi - vy, zi - vz).
                             vx, vy, vz are the local means at each voxel.
           Return:
               y (ndarray): Shape is (Batch * K, 128)
        """
        n_batch, n_channels, n_points = x.shape
        # mask = F.max(x, axis=(1, 2), keepdims=True).data != 0
        mask = F.max(x, axis=1, keepdims=True).data != 0
        active_length = 0  # mask.sum()

        # Convolution1D -> BN -> relu -> pool -> concat
        h = F.relu(self.bn1(self.conv1(x), active_length, mask))
        global_feat = F.max_pooling_nd(h, n_points)
        # Shape is (Batch, channel, points)
        global_feat_expand = F.tile(global_feat, (1, 1, n_points))
        h = F.concat((h, global_feat_expand))
        h *= mask

        h = self.conv2(h)
        return F.squeeze(F.max_pooling_nd(h, n_points)) 
Example #5
Source File: net.py    From convolutional_seq2seq with BSD 3-Clause "New" or "Revised" License
def __call__(self, x, z, ze, mask, conv_mask):
        att_scale = self.xp.sum(
            mask, axis=2, keepdims=True)[:, None, :, :] ** 0.5
        pad = self.xp.zeros(
            (x.shape[0], x.shape[1], self.width - 1, 1), dtype=x.dtype)
        base_x = x
        z = F.squeeze(z, axis=3)
        # Note: the way the input, output, and attention results are
        # handled follows the authors' code, which differs slightly from
        # what the paper describes.
        for conv_name, preatt_name in zip(self.conv_names, self.preatt_names):
            # Calculate Output of GLU
            out = getattr(self, conv_name)(
                F.concat([pad, x], axis=2), conv_mask)
            # Calculate Output of Attention using Output of GLU
            preatt = seq_linear(getattr(self, preatt_name), out)
            query = base_x + preatt
            query = F.squeeze(query, axis=3)
            c = self.attend(query, z, ze, mask) * att_scale
            # Merge Them in Residual Calculation and Scaling
            x = (x + (c + out) * scale05) * scale05

        return x 
Example #6
Source File: fsns.py    From see with GNU General Public License v3.0
def attend(self, encoded_features):
        self.out_lstm.reset_state()
        transformed_encoded_features = F.concat([F.expand_dims(self.transform_encoded_features(feature), axis=1) for feature in encoded_features], axis=1)
        concat_encoded_features = F.concat([F.expand_dims(e, axis=1) for e in encoded_features], axis=1)

        lstm_output = self.xp.zeros_like(encoded_features[0])
        outputs = []
        for _ in range(self.num_labels):
            transformed_lstm_output = self.transform_out_lstm_feature(lstm_output)
            attended_feats = []
            for transformed_encoded_feature in F.separate(transformed_encoded_features, axis=1):
                attended_feat = transformed_encoded_feature + transformed_lstm_output
                attended_feat = F.tanh(attended_feat)
                attended_feats.append(self.generate_attended_feat(attended_feat))

            attended_feats = F.concat(attended_feats, axis=1)
            alphas = F.softmax(attended_feats, axis=1)

            lstm_input_feature = F.batch_matmul(alphas, concat_encoded_features, transa=True)
            lstm_input_feature = F.squeeze(lstm_input_feature, axis=1)
            lstm_output = self.out_lstm(lstm_input_feature)
            outputs.append(lstm_output)
        return outputs 
Example #7
Source File: svhn_bbox_plotter.py    From see with GNU General Public License v3.0
def decode_predictions(self, predictions):
        # concat all individual predictions and slice for each time step
        predictions = F.concat([F.expand_dims(p, axis=0) for p in predictions], axis=0)

        words = []
        with cuda.get_device_from_array(predictions.data):
            for prediction in F.separate(predictions, axis=0):
                prediction = F.squeeze(prediction, axis=0)
                prediction = F.softmax(prediction, axis=1)
                prediction = self.xp.argmax(prediction.data, axis=1)
                word = self.loss_metrics.strip_prediction(prediction[self.xp.newaxis, ...])[0]
                if len(word) == 1 and word[0] == 0:
                    return ''

                word = "".join(map(self.loss_metrics.label_to_char, word))
                word = word.replace(chr(self.loss_metrics.char_map[str(self.loss_metrics.blank_symbol)]), '')
                words.append(word)

        text = " ".join(words)
        return text 
Example #8
Source File: elmo.py    From models with MIT License
def _load_cnn_weights(self):
        cnn_options = self._options['char_cnn']
        filters = cnn_options['filters']
        char_embed_dim = cnn_options['embedding']['dim']

        convolutions = []

        for i, (width, num) in enumerate(filters):
            conv = L.Convolution2D(
                in_channels=char_embed_dim,
                out_channels=num,
                ksize=(width, 1),
                nobias=False
            )
            # load the weights
            with h5py.File(cached_path(self._weight_file), 'r') as fin:
                weight = fin['CNN']['W_cnn_{}'.format(i)][...]
                bias = fin['CNN']['b_cnn_{}'.format(i)][...]

            w_reshaped = numpy.transpose(
                weight.squeeze(axis=0), axes=(2, 1, 0))
            # if w_reshaped.shape != tuple(conv.W.data.shape):
            #     raise ValueError("Invalid weight file")
            w_reshaped = w_reshaped[:, :, :, None]
            conv.W.data[:] = w_reshaped
            conv.b.data[:] = bias

            conv.W._requires_grad = self.requires_grad
            conv.b._requires_grad = self.requires_grad

            convolutions.append(conv)
            with self.init_scope():
                setattr(self, 'char_conv_{}'.format(i), conv)

        self._convolutions = convolutions 
Example #9
Source File: test_squeeze.py    From chainer with MIT License
def check_invalid_type(self, x_data):
        with self.assertRaises(ValueError):
            functions.squeeze(x_data, axis=self.axis) 
Example #10
Source File: test_squeeze.py    From chainer with MIT License
def test_invalid_axis(self):
        with self.assertRaises(TypeError):
            functions.squeeze(self.x, axis='a') 
Example #11
Source File: memnn.py    From pfio with MIT License
def query(self, u):
        xp = backend.get_array_module(u)
        size = self.m.shape[1]
        inds = xp.arange(size - 1, -1, -1, dtype=numpy.int32)
        tm = self.TA(inds)
        tc = self.TC(inds)
        tm = F.broadcast_to(tm, self.m.shape)
        tc = F.broadcast_to(tc, self.c.shape)
        p = F.softmax(F.batch_matmul(self.m + tm, u))
        o = F.batch_matmul(F.swapaxes(self.c + tc, 2, 1), p)
        o = F.squeeze(o, -1)
        u = o + u
        return u 
Example #12
Source File: text_rec_bbox_plotter.py    From see with GNU General Public License v3.0
def decode_predictions(self, predictions):
        # concat all individual predictions and slice for each time step
        predictions = F.concat([F.expand_dims(prediction, axis=0) for prediction in predictions], axis=0)

        with cuda.get_device_from_array(predictions.data):
            prediction = F.squeeze(predictions, axis=1)
            classification = F.softmax(prediction, axis=1)
            classification = classification.data
            classification = self.xp.argmax(classification, axis=1)

            words = self.loss_metrics.strip_prediction(classification[self.xp.newaxis, ...])[0]
            word = "".join(map(self.loss_metrics.label_to_char, words))

        return word 
Example #13
Source File: textrec_bbox_plotter.py    From see with GNU General Public License v3.0
def decode_predictions(self, predictions):
        # concat all individual predictions and slice for each time step
        predictions = predictions[0]

        with cuda.get_device_from_array(predictions.data):
            prediction = F.squeeze(predictions, axis=1)
            classification = F.softmax(prediction, axis=1)
            classification = classification.data
            classification = self.xp.argmax(classification, axis=1)

            words = self.loss_metrics.strip_prediction(classification[self.xp.newaxis, ...])[0]
            word = "".join(map(self.loss_metrics.label_to_char, words))

        return word 
Example #14
Source File: modeling.py    From models with MIT License
def predict(self, input_ids, input_mask, token_type_ids, unique_ids):
        (start_logits, end_logits) = self.__call__(
            input_ids, input_mask, token_type_ids)
        predictions = {
            "unique_ids": unique_ids[:, 0].tolist(),  # squeeze
            "start_logits": start_logits.array.tolist(),
            "end_logits": end_logits.array.tolist(),
        }
        return predictions 
Example #15
Source File: acer.py    From chainerrl with MIT License
def compute_loss_with_kl_constraint(distrib, another_distrib, original_loss,
                                    delta):
    """Compute loss considering a KL constraint.

    Args:
        distrib (Distribution): Distribution to optimize
        another_distrib (Distribution): Distribution used to compute KL
        original_loss (chainer.Variable): Loss to minimize
        delta (float): Minimum KL difference
    Returns:
        loss (chainer.Variable)
    """
    for param in distrib.params:
        assert param.shape[0] == 1
        assert param.requires_grad
    # Compute g: a direction to minimize the original loss
    g = [grad.array[0] for grad in
         chainer.grad([F.squeeze(original_loss)], distrib.params)]

    # Compute k: a direction to increase KL div.
    kl = F.squeeze(another_distrib.kl(distrib))
    k = [grad.array[0] for grad in
         chainer.grad([-kl], distrib.params)]

    # Compute z: combination of g and k to keep small KL div.
    kg_dot = sum(np.dot(kp.ravel(), gp.ravel())
                 for kp, gp in zip(k, g))
    kk_dot = sum(np.dot(kp.ravel(), kp.ravel()) for kp in k)
    if kk_dot > 0:
        k_factor = max(0, ((kg_dot - delta) / kk_dot))
    else:
        k_factor = 0
    z = [gp - k_factor * kp for kp, gp in zip(k, g)]
    loss = 0
    for p, zp in zip(distrib.params, z):
        loss += F.sum(p * zp)
    return F.reshape(loss, original_loss.shape), float(kl.array) 
Example #16
Source File: elmo.py    From models with MIT License
def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: ``chainer.Variable``
            Shape ``(batch_size, timesteps, 50)`` of character ids representing the current batch.

        Returns
        -------
        Dict with keys:

        ``'activations'``: ``List[chainer.Variable]``
            A list of activations at each layer of the network, each of shape
            ``(batch_size, timesteps + 2, embedding_dim)``
        ``'mask'``: ``chainer.Variable``
            Shape ``(batch_size, timesteps + 2)`` integer array with the sequence mask.

        Note that the output tensors all include additional special begin and end of sequence
        markers.
        """
        token_embedding = self._token_embedder.forward(inputs)
        type_representation = token_embedding['token_embedding']
        mask = token_embedding['mask']

        lstm_outputs = self._elmo_lstm.forward(type_representation, mask)

        # Prepare the output.  The first layer is duplicated.
        output_tensors = [
            F.concat([type_representation, type_representation], axis=-1)
        ]
        for layer_activations in F.split_axis(lstm_outputs, lstm_outputs.shape[0], axis=0):
            output_tensors.append(F.squeeze(layer_activations, 0))

        return {
            'activations': output_tensors,
            'mask': mask,
        } 
Example #17
Source File: light_voxelnet.py    From voxelnet_chainer with MIT License
def __call__(self, x, *args):
        """
           Args:
               x (ndarray): Shape is (Batch * K, 7, t).
                             each set has (xi, yi, zi, intensity, xi - vx, yi - vy, zi - vz).
                             vx, vy, vz are the local means at each voxel.
           Return:
               y (ndarray): Shape is (Batch * K, 128)
        """
        n_batch, n_channels, n_points = x.shape
        # mask = F.max(x, axis=(1, 2), keepdims=True).data != 0
        mask = F.max(x, axis=1, keepdims=True).data != 0
        active_length = 0  # mask.sum()

        # Convolution1D -> BN -> relu -> pool -> concat
        h = F.relu(self.bn1(self.conv1(x), active_length, mask))
        global_feat = F.max_pooling_nd(h, n_points)
        # Shape is (Batch, channel, points)
        global_feat_expand = F.tile(global_feat, (1, 1, n_points))
        h = F.concat((h, global_feat_expand))
        h *= mask

        h = F.relu(self.bn2(self.conv2(h), active_length, mask))
        global_feat = F.max_pooling_nd(h, n_points)
        global_feat_expand = F.tile(global_feat, (1, 1, n_points))
        h = F.concat((h, global_feat_expand))
        h *= mask

        # h = F.relu(self.bn3(self.conv3(h), active_length))
        h = self.conv3(h)
        # h *= mask
        return F.squeeze(F.max_pooling_nd(h, n_points)) 
Example #18
Source File: light_voxelnet.py    From voxelnet_chainer with MIT License
def __call__(self, x, *args):
        """
           Args:
               x (ndarray): Shape is (Batch * K, 7, t).
                             each set has (xi, yi, zi, ri, xi - vx, yi - vy, zi - vz).
                             vx, vy, vz are the local means at each voxel.
           Return:
               y (ndarray): Shape is (Batch * K, 128)
        """
        n_batch, n_channels, n_points = x.shape
        h = F.relu(self.conv1(x))
        return F.squeeze(F.max_pooling_nd(h, n_points)) 
Example #19
Source File: light_voxelnet.py    From voxelnet_chainer with MIT License
def __call__(self, x, *args):
        """
           Args:
               x (ndarray): Shape is (Batch * K, 7, t).
                             each set has (xi, yi, zi, ri, xi - vx, yi - vy, zi - vz).
                             vx, vy, vz are the local means at each voxel.
           Return:
               y (ndarray): Shape is (Batch * K, 128)
        """
        n_batch, n_channels, n_points = x.shape
        # mask = F.max(x, axis=(1, 2), keepdims=True).data != 0
        mask = F.max(x, axis=1, keepdims=True).data != 0
        active_length = 0  # mask.sum()

        # Convolution1D -> BN -> relu -> pool -> concat
        h = F.relu(self.bn1(self.conv1(x), active_length, mask))
        global_feat = F.max_pooling_nd(h, n_points)
        # Shape is (Batch, channel, points)
        global_feat_expand = F.tile(global_feat, (1, 1, n_points))
        h = F.concat((h, global_feat_expand))
        h *= mask

        h = F.relu(self.bn2(self.conv2(h), active_length, mask))
        global_feat = F.max_pooling_nd(h, n_points)
        global_feat_expand = F.tile(global_feat, (1, 1, n_points))
        h = F.concat((h, global_feat_expand))
        h *= mask

        # h = F.relu(self.bn3(self.conv3(h), active_length))
        h = self.conv3(h)
        # h *= mask
        return F.squeeze(F.max_pooling_nd(h, n_points)) 
Example #20
Source File: GAIN.py    From Guided-Attention-Inference-Network with MIT License
def get_mask(gcam, sigma=.5, w=8):
		gcam = (gcam - F.min(gcam).data) / (F.max(gcam) - F.min(gcam)).data
		mask = F.squeeze(F.sigmoid(w * (gcam - sigma)))
		return mask 
Example #21
Source File: test_squeeze.py    From chainer with MIT License
def forward_expected(self, inputs):
        x, = inputs
        y = numpy.squeeze(x, axis=self.axis)
        return y, 
Example #22
Source File: acer.py    From chainerrl with MIT License
def update(self, t_start, t_stop, R, states, actions, rewards, values,
               action_values, action_distribs, action_distribs_mu,
               avg_action_distribs):

        assert np.isscalar(R)

        total_loss = self.compute_loss(
            t_start=t_start,
            t_stop=t_stop,
            R=R,
            states=states,
            actions=actions,
            rewards=rewards,
            values=values,
            action_values=action_values,
            action_distribs=action_distribs,
            action_distribs_mu=action_distribs_mu,
            avg_action_distribs=avg_action_distribs)

        # Compute gradients using thread-specific model
        self.model.cleargrads()
        F.squeeze(total_loss).backward()
        # Copy the gradients to the globally shared model
        copy_param.copy_grad(
            target_link=self.shared_model, source_link=self.model)
        # Update the globally shared model
        if self.process_idx == 0:
            norm = sum(np.sum(np.square(param.grad))
                       for param in self.optimizer.target.params()
                       if param.grad is not None)
            self.logger.debug('grad norm:%s', norm)
        self.optimizer.update()

        self.sync_parameters()
        if isinstance(self.model, Recurrent):
            self.model.unchain_backward() 
Example #23
Source File: Squeeze.py    From chainer-compiler with MIT License
def forward(self, x):
        return F.squeeze(x, 1) 
Example #24
Source File: Squeeze.py    From chainer-compiler with MIT License
def forward(self, x):
        return F.squeeze(x, axis=(1, 3)) 
Example #25
Source File: Squeeze.py    From chainer-compiler with MIT License
def forward(self, x):
        return F.squeeze(x, 1) 
Example #26
Source File: Squeeze.py    From chainer-compiler with MIT License
def forward(self, x):
        return F.squeeze(x, axis=(1, 3)) 
Example #27
Source File: Squeeze.py    From chainer-compiler with MIT License
def forward(self, x):
        return F.squeeze(x)


# ====================================== 
Example #28
Source File: ExtFunctions_test.py    From chainer-compiler with MIT License
def test_squeeze(self):
        class Test():
            def forward(self):
                F.squeeze(np.zeros((2, 1, 1, 3)))
                F.squeeze(np.zeros((2, 1, 1, 3)), axis=2)
                F.squeeze(np.zeros((2, 1, 1, 3)), axis=(1, 2))

        id2type = generate_id2type_from_forward(Test(), ())

        self.assertEqual(str(id2type[1]), "class Test -> NoneType")	# FunctionDef forward (line 1)
        self.assertEqual(str(id2type[5]), "NoneType")	# Expr
        self.assertEqual(str(id2type[6]), "Variable(float64, (2, 3))")	# Call F.squeeze(np.zeros((2, 1, 1, 3))) (line 2)
        self.assertEqual(str(id2type[11]), "ndarray(float64, (2, 1, 1, 3))")	# Call np.zeros((2, 1, 1, 3)) (line 2)
        self.assertEqual(str(id2type[16]), "(int, int, int, int)")	# Tuple (2, 1, 1, 3) (line 2)
        self.assertEqual(str(id2type[17]), "int")	# Num 2 (line 2)
        self.assertEqual(str(id2type[18]), "int")	# Num 1 (line 2)
        self.assertEqual(str(id2type[19]), "int")	# Num 1 (line 2)
        self.assertEqual(str(id2type[20]), "int")	# Num 3 (line 2)
        self.assertEqual(str(id2type[22]), "NoneType")	# Expr
        self.assertEqual(str(id2type[23]), "Variable(float64, (2, 1, 3))")	# Call F.squeeze(np.zeros((2, 1, 1, 3)), axis=2) (line 3)
        self.assertEqual(str(id2type[28]), "ndarray(float64, (2, 1, 1, 3))")	# Call np.zeros((2, 1, 1, 3)) (line 3)
        self.assertEqual(str(id2type[33]), "(int, int, int, int)")	# Tuple (2, 1, 1, 3) (line 3)
        self.assertEqual(str(id2type[34]), "int")	# Num 2 (line 3)
        self.assertEqual(str(id2type[35]), "int")	# Num 1 (line 3)
        self.assertEqual(str(id2type[36]), "int")	# Num 1 (line 3)
        self.assertEqual(str(id2type[37]), "int")	# Num 3 (line 3)
        self.assertEqual(str(id2type[40]), "int")	# Num 2 (line 3)
        self.assertEqual(str(id2type[41]), "NoneType")	# Expr
        self.assertEqual(str(id2type[42]), "Variable(float64, (2, 3))")	# Call F.squeeze(np.zeros((2, 1, 1, 3)), axis=(1, 2)) (line 4)
        self.assertEqual(str(id2type[47]), "ndarray(float64, (2, 1, 1, 3))")	# Call np.zeros((2, 1, 1, 3)) (line 4)
        self.assertEqual(str(id2type[52]), "(int, int, int, int)")	# Tuple (2, 1, 1, 3) (line 4)
        self.assertEqual(str(id2type[53]), "int")	# Num 2 (line 4)
        self.assertEqual(str(id2type[54]), "int")	# Num 1 (line 4)
        self.assertEqual(str(id2type[55]), "int")	# Num 1 (line 4)
        self.assertEqual(str(id2type[56]), "int")	# Num 3 (line 4)
        self.assertEqual(str(id2type[59]), "(int, int)")	# Tuple (1, 2) (line 4)
        self.assertEqual(str(id2type[60]), "int")	# Num 1 (line 4)
        self.assertEqual(str(id2type[61]), "int")	# Num 2 (line 4) 
Example #29
Source File: model.py    From GP-GAN with MIT License
def __call__(self, x):
        x = self.encode(x)
        x = F.sum(x, axis=0) / x.shape[0]
        return F.squeeze(x) 
Example #30
Source File: ddpg_pendulum.py    From chainer with MIT License
def update(Q, target_Q, policy, target_policy, opt_Q, opt_policy,
           samples, gamma=0.99):
    """Update a Q-function and a policy."""
    dtype = chainer.get_dtype()
    xp = Q.xp
    obs = xp.asarray([sample[0] for sample in samples], dtype=dtype)
    action = xp.asarray([sample[1] for sample in samples], dtype=dtype)
    reward = xp.asarray([sample[2] for sample in samples], dtype=dtype)
    done = xp.asarray([sample[3] for sample in samples], dtype=dtype)
    obs_next = xp.asarray([sample[4] for sample in samples], dtype=dtype)

    def update_Q():
        # Predicted values: Q(s,a)
        y = F.squeeze(Q(obs, action), axis=1)
        # Target values: r + gamma * Q(s,policy(s))
        with chainer.no_backprop_mode():
            next_q = F.squeeze(target_Q(obs_next, target_policy(obs_next)),
                               axis=1)
            target = reward + gamma * (1 - done) * next_q
        loss = F.mean_squared_error(y, target)
        Q.cleargrads()
        loss.backward()
        opt_Q.update()

    def update_policy():
        # Maximize Q(s,policy(s))
        q = Q(obs, policy(obs))
        q = q[:]  # Avoid https://github.com/chainer/chainer/issues/2744
        loss = - F.mean(q)
        policy.cleargrads()
        loss.backward()
        opt_policy.update()

    update_Q()
    update_policy()