Python tensorflow.batch_gather() Examples
The following are 13 code examples of tensorflow.batch_gather(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
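For orientation, here is a minimal sketch of what tf.batch_gather does: it gathers along the first non-batch axis, using a separate set of indices for each batch entry. The values and shapes below are illustrative only; the op belongs to the TF 1.x API, and newer TensorFlow versions express the same operation as tf.gather(..., batch_dims=...) (see Example #6 for a compatibility wrapper).

import tensorflow as tf

params = tf.constant([[10, 20, 30],
                      [40, 50, 60]])       # [batch_size=2, n=3]
indices = tf.constant([[2, 0],
                       [1, 1]])            # [batch_size=2, k=2], per-row indices
result = tf.batch_gather(params, indices)  # [[30, 10], [50, 40]]

# Equivalent call in newer TensorFlow versions:
# result = tf.gather(params, indices, batch_dims=1)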
Example #1
Source File: neigh_samplers.py From DGFraud with Apache License 2.0 | 6 votes |
def _call(self, inputs):
    eps = 0.001
    ids, num_samples, features, batch_size = inputs
    adj_lists = tf.gather(self.adj_info, ids)
    node_features = tf.gather(features, ids)
    feature_size = tf.shape(features)[-1]
    node_feature_repeat = tf.tile(node_features, [1, self.num_neighs])
    node_feature_repeat = tf.reshape(node_feature_repeat, [batch_size, self.num_neighs, feature_size])
    neighbor_feature = tf.gather(features, adj_lists)
    distance = tf.sqrt(tf.reduce_sum(tf.square(node_feature_repeat - neighbor_feature), -1))
    prob = tf.exp(-distance)
    prob_sum = tf.reduce_sum(prob, -1, keepdims=True)
    prob_sum = tf.tile(prob_sum, [1, self.num_neighs])
    prob = tf.divide(prob, prob_sum)
    prob = tf.where(prob > eps, prob, 0 * prob)  # optional: use eps to filter out small probabilities
    samples_idx = tf.random.categorical(tf.math.log(prob), num_samples)
    selected = tf.batch_gather(adj_lists, samples_idx)
    return selected
Example #2
Source File: neural_dater.py From NeuralDater with Apache License 2.0 | 6 votes |
def gather(self, data, pl_idx, pl_mask, max_len, name=None):
    """
    Lookup equivalent for tensors with dim > 2 (can be simplified using tf.batch_gather)

    Parameters
    ----------
    data:     Tensor in which lookup has to be performed
    pl_idx:   The indices to be taken
    pl_mask:  For handling padding in pl_idx
    max_len:  Maximum length of indices

    Returns
    -------
    et_vecs * mask_vec: Extracted vectors at given indices
    """
    idx1 = tf.range(self.p.batch_size, dtype=tf.int32)
    idx1 = tf.reshape(idx1, [-1, 1])
    idx1_ = tf.reshape(tf.tile(idx1, [1, max_len]), [-1, 1])
    idx_reshape = tf.reshape(pl_idx, [-1, 1])
    indices = tf.concat((idx1_, idx_reshape), axis=1)
    et_vecs = tf.gather_nd(data, indices)
    et_vecs = tf.reshape(et_vecs, [self.p.batch_size, self.max_et, -1])
    mask_vec = tf.expand_dims(pl_mask, axis=2)
    return et_vecs * mask_vec
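As the docstring above notes, the explicit batch-index construction can be collapsed into a single call. A minimal sketch of that simplification, assuming data has shape [batch_size, num_items, dim] and pl_idx is an int32 tensor of shape [batch_size, max_et] (these shapes are assumptions for illustration, not taken from the original project):

# Hedged sketch of the tf.batch_gather equivalent (TF 1.x API).
et_vecs = tf.batch_gather(data, pl_idx)      # [batch_size, max_et, dim]
mask_vec = tf.expand_dims(pl_mask, axis=2)   # [batch_size, max_et, 1]
masked = et_vecs * mask_vec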
Example #3
Source File: augdesc.py From pyslam with GNU General Public License v3.0 | 6 votes |
def _interpolate(self, xy1, xy2, points2):
    batch_size = tf.shape(xy1)[0]
    ndataset1 = tf.shape(xy1)[1]
    eps = 1e-6
    dist_mat = tf.matmul(xy1, xy2, transpose_b=True)
    norm1 = tf.reduce_sum(xy1 * xy1, axis=-1, keepdims=True)
    norm2 = tf.reduce_sum(xy2 * xy2, axis=-1, keepdims=True)
    dist_mat = tf.sqrt(norm1 - 2 * dist_mat + tf.linalg.matrix_transpose(norm2) + eps)
    dist, idx = tf.math.top_k(tf.negative(dist_mat), k=3)
    dist = tf.maximum(dist, 1e-10)
    norm = tf.reduce_sum((1.0 / dist), axis=2, keepdims=True)
    norm = tf.tile(norm, [1, 1, 3])
    weight = (1.0 / dist) / norm
    idx = tf.reshape(idx, (batch_size, -1))
    nn_points = tf.batch_gather(points2, idx)
    nn_points = tf.reshape(nn_points, (batch_size, ndataset1, 3, points2.get_shape()[-1].value))
    interpolated_points = tf.reduce_sum(weight[..., tf.newaxis] * nn_points, axis=-2)
    return interpolated_points
Example #4
Source File: modeling.py From grover with Apache License 2.0 | 5 votes |
def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10):
    """
    Does top-k sampling. if ignore_ids is on, then we will zero out those logits.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and
        never predict, like padding maybe
    :param k: top-k threshold to use, either an int or a [batch_size] vector
    :return: [batch_size, num_samples] samples

    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    with tf.variable_scope('top_p_sample'):
        batch_size, vocab_size = get_shape_list(logits, expected_rank=2)

        probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                              axis=-1)
        # [batch_size, vocab_perm]
        indices = tf.argsort(probs, direction='DESCENDING')

        # find the top kth index to cut off. careful we don't want to cutoff everything!
        # result will be [batch_size, vocab_perm]
        k_expanded = k if isinstance(k, int) else k[:, None]
        exclude_mask = tf.range(vocab_size)[None] >= k_expanded

        # OPTION A - sample in the sorted space, then unsort.
        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

    return {
        'probs': probs,
        'sample': sample,
    }
Example #5
Source File: nuelus_sampling_utils.py From BERT with Apache License 2.0 | 5 votes |
def nucleus_sampling(logits, vocab_size, p=0.9,
                     input_ids=None, input_ori_ids=None,
                     **kargs):
    input_shape_list = bert_utils.get_shape_list(logits, expected_rank=[2, 3])
    if len(input_shape_list) == 3:
        logits = tf.reshape(logits, (-1, vocab_size))
    probs = tf.nn.softmax(logits, axis=-1)
    # [batch_size, seq, vocab_perm]
    # indices = tf.argsort(probs, direction='DESCENDING')
    indices = tf.contrib.framework.argsort(probs, direction='DESCENDING')

    cumulative_probabilities = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1, exclusive=False)

    # find the top pth index to cut off. careful we don't want to cutoff everything!
    # result will be [batch_size, seq, vocab_perm]
    exclude_mask = tf.logical_not(
        tf.logical_or(cumulative_probabilities < p, tf.range(vocab_size)[None] < 1))
    exclude_mask = tf.cast(exclude_mask, tf.float32)

    indices_v1 = tf.contrib.framework.argsort(indices)
    exclude_mask = reorder(exclude_mask, tf.cast(indices_v1, dtype=tf.int32))
    if len(input_shape_list) == 3:
        exclude_mask = tf.reshape(exclude_mask, input_shape_list)
        # logits = tf.reshape(logits, input_shape_list)

    if input_ids is not None and input_ori_ids is not None:
        exclude_mask, input_ori_ids = get_extra_mask(
            input_ids, input_ori_ids, exclude_mask, vocab_size, **kargs)
        return [exclude_mask, input_ori_ids]
    else:
        return [exclude_mask]
Example #6
Source File: backend.py From bert4keras with Apache License 2.0 | 5 votes |
def batch_gather(params, indices):
    """Equivalent to batch_gather in older versions of tf."""
    try:
        return tf.gather(params, indices, batch_dims=K.ndim(indices) - 1)
    except Exception as e1:
        try:
            return tf.batch_gather(params, indices)
        except Exception as e2:
            raise ValueError('%s\n%s\n' % (e1, e2))
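The wrapper above first tries the newer tf.gather(..., batch_dims=...) signature and only falls back to tf.batch_gather on older TensorFlow builds, so one call works across TF 1.x and 2.x. A hedged usage sketch (the 2-D shapes are illustrative; for higher-rank indices the batch_dims argument above generalizes accordingly):

import tensorflow as tf

params = tf.constant([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])     # [batch, n]
indices = tf.constant([[0, 2],
                       [2, 2]])             # [batch, k]
# Newer API path taken by the wrapper:
out = tf.gather(params, indices, batch_dims=1)  # [[1.0, 3.0], [6.0, 6.0]]
# On TF versions without batch_dims, the wrapper would instead call:
# out = tf.batch_gather(params, indices)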
Example #7
Source File: modeling.py From grover with Apache License 2.0 | 4 votes |
def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
    """
    Does top-p sampling. if ignore_ids is on, then we will zero out those logits.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and
        never predict, like padding maybe
    :param p: topp threshold to use, either a float or a [batch_size] vector
    :return: [batch_size, num_samples] samples

    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    with tf.variable_scope('top_p_sample'):
        batch_size, vocab_size = get_shape_list(logits, expected_rank=2)

        probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                              axis=-1)

        if isinstance(p, float) and p > 0.999999:
            # Don't do top-p sampling in this case
            print("Top-p sampling DISABLED", flush=True)
            return {
                'probs': probs,
                'sample': tf.random.categorical(
                    logits=logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                    num_samples=num_samples, dtype=tf.int32),
            }

        # [batch_size, vocab_perm]
        indices = tf.argsort(probs, direction='DESCENDING')
        cumulative_probabilities = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1, exclusive=False)

        # find the top pth index to cut off. careful we don't want to cutoff everything!
        # result will be [batch_size, vocab_perm]
        p_expanded = p if isinstance(p, float) else p[:, None]
        exclude_mask = tf.logical_not(
            tf.logical_or(cumulative_probabilities < p_expanded, tf.range(vocab_size)[None] < 1))

        # OPTION A - sample in the sorted space, then unsort.
        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

        # OPTION B - unsort first - Indices need to go back to 0 -> N-1 -- then sample
        # unperm_indices = tf.argsort(indices, direction='ASCENDING')
        # include_mask_unperm = tf.batch_gather(include_mask, unperm_indices)
        # logits_to_use = logits - (1 - tf.cast(include_mask_unperm, tf.float32)) * 1e10
        # sample = tf.random.categorical(logits=logits_to_use, num_samples=num_samples, dtype=tf.int32)

    return {
        'probs': probs,
        'sample': sample,
    }
Example #8
Source File: modeling.py From grover with Apache License 2.0 | 4 votes |
def sample_step(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False):
    """
    Helper function that samples from grover for a single step
    :param tokens: [batch_size, n_ctx_b] tokens that we will predict from
    :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict
    :param news_config: config for the GroverModel
    :param batch_size: batch size to use
    :param p_for_topp: top-p or top-k threshold
    :param cache: [batch_size, news_config.num_hidden_layers, 2,
                   news_config.num_attention_heads, n_ctx_a,
                   news_config.hidden_size // news_config.num_attention_heads] OR, None
    :return: new_tokens, size [batch_size]
             new_probs, also size [batch_size]
             new_cache, size [batch_size, news_config.num_hidden_layers, 2, n_ctx_b,
                              news_config.num_attention_heads,
                              news_config.hidden_size // news_config.num_attention_heads]
    """
    model = GroverModel(
        config=news_config,
        is_training=False,
        input_ids=tokens,
        reuse=tf.AUTO_REUSE,
        scope='newslm',
        chop_off_last_token=False,
        do_cache=True,
        cache=cache,
    )

    # Extract the FINAL SEQ LENGTH
    batch_size_times_seq_length, vocab_size = get_shape_list(model.logits_flat, expected_rank=2)
    next_logits = tf.reshape(model.logits_flat, [batch_size, -1, vocab_size])[:, -1]

    if do_topk:
        sample_info = _top_k_sample(next_logits, num_samples=1, k=tf.cast(p_for_topp, dtype=tf.int32))
    else:
        sample_info = _top_p_sample(next_logits, ignore_ids=ignore_ids, num_samples=1, p=p_for_topp)

    new_tokens = tf.squeeze(sample_info['sample'], 1)
    new_probs = tf.squeeze(tf.batch_gather(sample_info['probs'], sample_info['sample']), 1)
    return {
        'new_tokens': new_tokens,
        'new_probs': new_probs,
        'new_cache': model.new_kvs,
    }
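The final gather above illustrates a common pattern: reading back, for each row, the probability of the id that was just sampled. A minimal standalone sketch of that pattern (values are illustrative; TF 1.x API assumed):

import tensorflow as tf

probs = tf.constant([[0.1, 0.7, 0.2],
                     [0.5, 0.3, 0.2]])       # [batch_size, vocab_size]
sample = tf.constant([[1], [0]])             # [batch_size, 1] sampled ids
picked = tf.batch_gather(probs, sample)      # [[0.7], [0.5]]
picked = tf.squeeze(picked, 1)               # [batch_size]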
Example #9
Source File: beam_search.py From BERT with Apache License 2.0 | 4 votes |
def fast_tpu_gather(params, indices, name=None):
    """Fast gather implementation for models running on TPU.

    This function uses one_hot and batch matmul to do gather, which is faster
    than gather_nd on TPU. For params that have dtype of int32 (sequences to
    gather from), batch_gather is used to keep accuracy.

    Args:
      params: A tensor from which to gather values.
        [batch_size, original_size, ...]
      indices: A tensor used as the index to gather values.
        [batch_size, selected_size].
      name: A string, name of the operation (optional).

    Returns:
      gather_result: A tensor that has the same rank as params.
        [batch_size, selected_size, ...]
    """
    with tf.name_scope(name):
        dtype = params.dtype

        def _gather(params, indices):
            """Fast gather using one_hot and batch matmul."""
            if dtype != tf.float32:
                params = tf.to_float(params)
            shape = common_layers.shape_list(params)
            indices_shape = common_layers.shape_list(indices)
            ndims = params.shape.ndims
            # Adjust the shape of params to match one-hot indices, which is the
            # requirement of Batch MatMul.
            if ndims == 2:
                params = tf.expand_dims(params, axis=-1)
            if ndims > 3:
                params = tf.reshape(params, [shape[0], shape[1], -1])
            gather_result = tf.matmul(
                tf.one_hot(indices, shape[1], dtype=params.dtype), params)
            if ndims == 2:
                gather_result = tf.squeeze(gather_result, axis=-1)
            if ndims > 3:
                shape[1] = indices_shape[1]
                gather_result = tf.reshape(gather_result, shape)
            if dtype != tf.float32:
                gather_result = tf.cast(gather_result, dtype)
            return gather_result

        # If the dtype is an integer, use gather instead of the one_hot matmul to
        # avoid precision loss. The max int value that can be represented by
        # bfloat16 in the MXU is 256, which is smaller than the possible id values.
        # Encoding/decoding could potentially be used to make it work, but the
        # benefit is small right now.
        if dtype.is_integer:
            gather_result = tf.batch_gather(params, indices)
        else:
            gather_result = _gather(params, indices)

        return gather_result
Example #10
Source File: beam_search.py From training_results_v0.5 with Apache License 2.0 | 4 votes |
def fast_tpu_gather(params, indices, name=None):
    """Fast gather implementation for models running on TPU.

    This function uses one_hot and batch matmul to do gather, which is faster
    than gather_nd on TPU. For params that have dtype of int32 (sequences to
    gather from), batch_gather is used to keep accuracy.

    Args:
      params: A tensor from which to gather values.
        [batch_size, original_size, ...]
      indices: A tensor used as the index to gather values.
        [batch_size, selected_size].
      name: A string, name of the operation (optional).

    Returns:
      gather_result: A tensor that has the same rank as params.
        [batch_size, selected_size, ...]
    """
    with tf.name_scope(name):
        dtype = params.dtype

        def _gather(params, indices):
            """Fast gather using one_hot and batch matmul."""
            if dtype != tf.float32:
                params = tf.to_float(params)
            shape = common_layers.shape_list(params)
            indices_shape = common_layers.shape_list(indices)
            ndims = params.shape.ndims
            # Adjust the shape of params to match one-hot indices, which is the
            # requirement of Batch MatMul.
            if ndims == 2:
                params = tf.expand_dims(params, axis=-1)
            if ndims > 3:
                params = tf.reshape(params, [shape[0], shape[1], -1])
            gather_result = tf.matmul(
                tf.one_hot(indices, shape[1], dtype=params.dtype), params)
            if ndims == 2:
                gather_result = tf.squeeze(gather_result, axis=-1)
            if ndims > 3:
                shape[1] = indices_shape[1]
                gather_result = tf.reshape(gather_result, shape)
            if dtype != tf.float32:
                gather_result = tf.cast(gather_result, dtype)
            return gather_result

        # If the dtype is int32, use gather instead of the one_hot matmul to avoid
        # precision loss. The max int value that can be represented by bfloat16 in
        # the MXU is 256, which is smaller than the possible id values.
        # Encoding/decoding could potentially be used to make it work, but the
        # benefit is small right now.
        if dtype == tf.int32:
            gather_result = tf.batch_gather(params, indices)
        else:
            gather_result = _gather(params, indices)

        return gather_result
Example #11
Source File: beam_search.py From training_results_v0.5 with Apache License 2.0 | 4 votes |
def fast_tpu_gather(params, indices, name=None):
    """Fast gather implementation for models running on TPU.

    This function uses one_hot and batch matmul to do gather, which is faster
    than gather_nd on TPU. For params that have dtype of int32 (sequences to
    gather from), batch_gather is used to keep accuracy.

    Args:
      params: A tensor from which to gather values.
        [batch_size, original_size, ...]
      indices: A tensor used as the index to gather values.
        [batch_size, selected_size].
      name: A string, name of the operation (optional).

    Returns:
      gather_result: A tensor that has the same rank as params.
        [batch_size, selected_size, ...]
    """
    with tf.name_scope(name):
        dtype = params.dtype

        def _gather(params, indices):
            """Fast gather using one_hot and batch matmul."""
            if dtype != tf.float32:
                params = tf.to_float(params)
            shape = common_layers.shape_list(params)
            indices_shape = common_layers.shape_list(indices)
            ndims = params.shape.ndims
            # Adjust the shape of params to match one-hot indices, which is the
            # requirement of Batch MatMul.
            if ndims == 2:
                params = tf.expand_dims(params, axis=-1)
            if ndims > 3:
                params = tf.reshape(params, [shape[0], shape[1], -1])
            gather_result = tf.matmul(
                tf.one_hot(indices, shape[1], dtype=params.dtype), params)
            if ndims == 2:
                gather_result = tf.squeeze(gather_result, axis=-1)
            if ndims > 3:
                shape[1] = indices_shape[1]
                gather_result = tf.reshape(gather_result, shape)
            if dtype != tf.float32:
                gather_result = tf.cast(gather_result, dtype)
            return gather_result

        # If the dtype is int32, use gather instead of the one_hot matmul to avoid
        # precision loss. The max int value that can be represented by bfloat16 in
        # the MXU is 256, which is smaller than the possible id values.
        # Encoding/decoding could potentially be used to make it work, but the
        # benefit is small right now.
        if dtype == tf.int32:
            gather_result = tf.batch_gather(params, indices)
        else:
            gather_result = _gather(params, indices)

        return gather_result
Example #12
Source File: net.py From gcdn with MIT License | 4 votes |
def gconv(self, h, name, in_feat, out_feat, stride_th1, stride_th2, compute_graph=True, return_graph=False, D=[]):
    if compute_graph:
        D = self.compute_graph(h)
    _, top_idx = tf.nn.top_k(-D, self.config.min_nn + 1)  # (B, N, d+1)
    top_idx2 = tf.reshape(
        tf.tile(tf.expand_dims(top_idx[:, :, 0], 2), [1, 1, self.config.min_nn - 8]),
        [-1, self.N * (self.config.min_nn - 8)])  # (B, N*d)
    top_idx = tf.reshape(top_idx[:, :, 9:], [-1, self.N * (self.config.min_nn - 8)])  # (B, N*d)
    x_tilde1 = tf.batch_gather(h, top_idx)   # (B, K, dlm1)
    x_tilde2 = tf.batch_gather(h, top_idx2)  # (B, K, dlm1)
    labels = x_tilde1 - x_tilde2             # (B, K, dlm1)
    x_tilde1 = tf.reshape(x_tilde1, [-1, in_feat])  # (B*K, dlm1)
    labels = tf.reshape(labels, [-1, in_feat])      # (B*K, dlm1)
    d_labels = tf.reshape(tf.reduce_sum(labels * labels, 1), [-1, self.config.min_nn - 8])  # (B*N, d)
    name_flayer = name + "_flayer0"
    labels = tf.nn.leaky_relu(tf.matmul(labels, self.W[name_flayer]) + self.b[name_flayer])  # (B*K, F)
    name_flayer = name + "_flayer1"
    labels_exp = tf.expand_dims(labels, 1)  # (B*K, 1, F)
    labels1 = labels_exp + 0.0
    for ss in range(1, in_feat // stride_th1):
        labels1 = tf.concat([labels1, self.myroll(labels_exp, shift=(ss + 1) * stride_th1, axis=2)], axis=1)  # (B*K, dlm1/stride, dlm1)
    labels2 = labels_exp + 0.0
    for ss in range(1, out_feat // stride_th2):
        labels2 = tf.concat([labels2, self.myroll(labels_exp, shift=(ss + 1) * stride_th2, axis=2)], axis=1)  # (B*K, dl/stride, dlm1)
    theta1 = tf.matmul(tf.reshape(labels1, [-1, in_feat]), self.W[name_flayer + "_th1"])  # (B*K*dlm1/stride, R*stride)
    theta1 = tf.reshape(theta1, [-1, self.config.rank_theta, in_feat]) + self.b[name_flayer + "_th1"]
    theta2 = tf.matmul(tf.reshape(labels2, [-1, in_feat]), self.W[name_flayer + "_th2"])  # (B*K*dl/stride, R*stride)
    theta2 = tf.reshape(theta2, [-1, self.config.rank_theta, out_feat]) + self.b[name_flayer + "_th2"]
    thetal = tf.expand_dims(tf.matmul(labels, self.W[name_flayer + "_thl"]) + self.b[name_flayer + "_thl"], 2)  # (B*K, R, 1)
    x = tf.matmul(theta1, tf.expand_dims(x_tilde1, 2))  # (B*K, R, 1)
    x = tf.multiply(x, thetal)                          # (B*K, R, 1)
    x = tf.matmul(theta2, x, transpose_a=True)[:, :, 0]  # (B*K, dl)
    x = tf.reshape(x, [-1, self.config.min_nn - 8, out_feat])  # (N, d, dl)
    x = tf.multiply(x, tf.expand_dims(tf.exp(-tf.div(d_labels, 10)), 2))  # (N, d, dl)
    x = tf.reduce_mean(x, 1)                # (N, dl)
    x = tf.reshape(x, [-1, self.N, out_feat])  # (B, N, dl)
    if return_graph:
        return x, D
    else:
        return x
Example #13
Source File: net_conv2.py From gcdn with MIT License | 4 votes |
def gconv_conv_inner(self, h, name, in_feat, out_feat, stride_th1, stride_th2, compute_graph=True, return_graph=False, D=[]):
    h = tf.expand_dims(h, 0)  # (1, M, dl)
    p = tf.image.extract_image_patches(
        h,
        ksizes=[1, self.config.search_window[0], self.config.search_window[1], 1],
        strides=[1, 1, 1, 1],
        rates=[1, 1, 1, 1],
        padding="VALID")  # (1, X, Y, dlm1*W)
    p = tf.reshape(p, [-1, self.config.search_window[0], self.config.search_window[1], in_feat])
    p = tf.reshape(p, [-1, self.config.searchN, in_feat])  # (N, W, dlm1)
    if compute_graph:
        D = tf.map_fn(
            lambda feat: self.gconv_conv_inner2(feat),
            tf.reshape(p, [self.config.search_window[0], self.config.search_window[1], self.config.searchN, in_feat]),
            parallel_iterations=16,
            swap_memory=False)  # (B, N/B, W)
        D = tf.reshape(D, [-1, self.config.searchN])  # (N, W)
    _, top_idx = tf.nn.top_k(-D, self.config.min_nn + 1)  # (N, d+1)
    # top_idx2 = tf.reshape(tf.tile(tf.expand_dims(top_idx[:,0],1), [1, self.config.min_nn[i]]), [-1])
    top_idx2 = tf.tile(tf.expand_dims(top_idx[:, 0], 1), [1, self.config.min_nn - 8])  # (N, d)
    # top_idx = tf.reshape(top_idx[:,1:],[-1])  # (N*d,)
    top_idx = top_idx[:, 9:]  # (N, d)
    x_tilde1 = tf.batch_gather(p, top_idx)          # (N, d, dlm1)
    x_tilde1 = tf.reshape(x_tilde1, [-1, in_feat])  # (K, dlm1)
    x_tilde2 = tf.batch_gather(p, top_idx2)         # (N, d, dlm1)
    x_tilde2 = tf.reshape(x_tilde2, [-1, in_feat])  # (K, dlm1)
    labels = x_tilde1 - x_tilde2                    # (K, dlm1)
    d_labels = tf.reshape(tf.reduce_sum(labels * labels, 1), [-1, self.config.min_nn - 8])  # (N, d)
    name_flayer = name + "_flayer0"
    labels = tf.nn.leaky_relu(tf.matmul(labels, self.W[name_flayer]) + self.b[name_flayer])
    name_flayer = name + "_flayer1"
    labels_exp = tf.expand_dims(labels, 1)  # (B*K, 1, F)
    labels1 = labels_exp + 0.0
    for ss in range(1, in_feat // stride_th1):
        labels1 = tf.concat([labels1, self.myroll(labels_exp, shift=(ss + 1) * stride_th1, axis=2)], axis=1)  # (B*K, dlm1/stride, dlm1)
    labels2 = labels_exp + 0.0
    for ss in range(1, out_feat // stride_th2):
        labels2 = tf.concat([labels2, self.myroll(labels_exp, shift=(ss + 1) * stride_th2, axis=2)], axis=1)  # (B*K, dl/stride, dlm1)
    theta1 = tf.matmul(tf.reshape(labels1, [-1, in_feat]), self.W[name_flayer + "_th1"])  # (B*K*dlm1/stride, R*stride)
    theta1 = tf.reshape(theta1, [-1, self.config.rank_theta, in_feat]) + self.b[name_flayer + "_th1"]
    theta2 = tf.matmul(tf.reshape(labels2, [-1, in_feat]), self.W[name_flayer + "_th2"])  # (B*K*dl/stride, R*stride)
    theta2 = tf.reshape(theta2, [-1, self.config.rank_theta, out_feat]) + self.b[name_flayer + "_th2"]
    thetal = tf.expand_dims(tf.matmul(labels, self.W[name_flayer + "_thl"]) + self.b[name_flayer + "_thl"], 2)  # (B*K, R, 1)
    x = tf.matmul(theta1, tf.expand_dims(x_tilde1, 2))  # (K, R, 1)
    x = tf.multiply(x, thetal)                          # (K, R, 1)
    x = tf.matmul(theta2, x, transpose_a=True)[:, :, 0]  # (K, dl)
    x = tf.reshape(x, [-1, self.config.min_nn - 8, out_feat])  # (N, d, dl)
    x = tf.multiply(x, tf.expand_dims(tf.exp(-tf.div(d_labels, 10)), 2))  # (N, d, dl)
    x = tf.reduce_mean(x, 1)    # (N, dl)
    x = tf.expand_dims(x, 0)    # (1, N, dl)
    return [x, D]