Python tensorflow.sparse_concat() Examples
The following are 29 code examples of tensorflow.sparse_concat(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
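Before the project examples, here is a minimal, self-contained sketch of what tensorflow.sparse_concat() does. This is a hedged illustration assuming the TensorFlow 1.x API used throughout the examples below (tf.sparse_concat, tf.Session); in TensorFlow 2.x the equivalent is tf.sparse.concat.

import tensorflow as tf

# Two 2x2 sparse tensors:
#   a = [[1 0]     b = [[0 3]
#        [0 2]]         [4 0]]
a = tf.SparseTensor(indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
b = tf.SparseTensor(indices=[[0, 1], [1, 0]], values=[3, 4], dense_shape=[2, 2])

# Concatenate along the column axis; indices of later inputs are offset along
# that axis and the result is re-sorted into row-major order.
c = tf.sparse_concat(axis=1, sp_inputs=[a, b])

with tf.Session() as sess:
    out = sess.run(c)
    print(out.indices)      # [[0 0] [0 3] [1 1] [1 2]]
    print(out.values)       # [1 3 2 4]
    print(out.dense_shape)  # [2 4]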
Example #1
Source File: tensorflow_backend.py From DeepLearning_Wavelet-LSTM with MIT License | 6 votes |
def concatenate(tensors, axis=-1):
    """Concatenates a list of tensors alongside the specified axis.

    # Arguments
        tensors: list of tensors to concatenate.
        axis: concatenation axis.

    # Returns
        A tensor.
    """
    if axis < 0:
        rank = ndim(tensors[0])
        if rank:
            axis %= rank
        else:
            axis = 0

    if py_all([is_sparse(x) for x in tensors]):
        return tf.sparse_concat(axis, tensors)
    else:
        return tf.concat([to_dense(x) for x in tensors], axis)
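A hypothetical usage of the backend function above (hedged: it assumes a Keras 2.x installation with the TensorFlow 1.x backend, where this function is exposed as keras.backend.concatenate). When every input is a SparseTensor the result stays sparse via tf.sparse_concat; otherwise all inputs are densified and passed to tf.concat.

from keras import backend as K
import tensorflow as tf

a = tf.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[2, 2])
b = tf.SparseTensor(indices=[[1, 1]], values=[2.0], dense_shape=[2, 2])

c_sparse = K.concatenate([a, b], axis=1)          # all sparse -> tf.sparse_concat
c_dense = K.concatenate([a, tf.ones([2, 2])], 1)  # mixed -> to_dense + tf.concat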
Example #2
Source File: tensorflow_backend.py From keras-lambda with MIT License | 6 votes |
def concatenate(tensors, axis=-1):
    """Concatenates a list of tensors alongside the specified axis.

    # Arguments
        tensors: list of tensors to concatenate.
        axis: concatenation axis.

    # Returns
        A tensor.
    """
    if axis < 0:
        rank = ndim(tensors[0])
        if rank:
            axis %= rank
        else:
            axis = 0

    if py_all([is_sparse(x) for x in tensors]):
        return tf.sparse_concat(axis, tensors)
    else:
        return tf.concat([to_dense(x) for x in tensors], axis)
Example #3
Source File: tf_sequence_example_decoder.py From multilabel-image-classification-tensorflow with MIT License | 6 votes |
def tensors_to_item(self, keys_to_tensors):
  """Maps the given dictionary of tensors to a concatenated list of bboxes.

  Args:
    keys_to_tensors: a mapping of TF-Example keys to parsed tensors.

  Returns:
    [time, num_boxes, 4] tensor of bounding box coordinates, in order
        [y_min, x_min, y_max, x_max]. Whether the tensor is a SparseTensor
        or a dense Tensor is determined by the return_dense parameter. Empty
        positions in the sparse tensor are filled with -1.0 values.
  """
  sides = []
  for key in self._full_keys:
    value = keys_to_tensors[key]
    expanded_dims = tf.concat(
        [tf.to_int64(tf.shape(value)),
         tf.constant([1], dtype=tf.int64)], 0)
    side = tf.sparse_reshape(value, expanded_dims)
    sides.append(side)
  bounding_boxes = tf.sparse_concat(2, sides)
  if self._return_dense:
    bounding_boxes = tf.sparse_tensor_to_dense(
        bounding_boxes, default_value=self._default_value)
  return bounding_boxes
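The key step above is giving each per-coordinate sparse tensor a trailing unit dimension with tf.sparse_reshape and then concatenating on axis 2 to form the [time, num_boxes, 4] tensor. Here is a toy, self-contained version of that step (hedged: made-up shapes and values, TF 1.x API):

import tensorflow as tf

time, num_boxes = 2, 3
sides = []
for coord in range(4):
    # One [time, num_boxes] sparse tensor per coordinate (toy values).
    value = tf.SparseTensor(indices=[[0, 0], [1, 2]],
                            values=[float(coord), float(coord) + 10.0],
                            dense_shape=[time, num_boxes])
    # Append a trailing dimension of size 1: [time, num_boxes, 1].
    expanded_dims = tf.concat(
        [tf.to_int64(tf.shape(value)), tf.constant([1], dtype=tf.int64)], 0)
    sides.append(tf.sparse_reshape(value, expanded_dims))

# Concatenate the four unit dimensions: [time, num_boxes, 4].
bounding_boxes = tf.sparse_concat(2, sides)
dense = tf.sparse_tensor_to_dense(bounding_boxes, default_value=-1.0)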
Example #4
Source File: tf_sequence_example_decoder.py From g-tensorflow-models with Apache License 2.0 | 6 votes |
def tensors_to_item(self, keys_to_tensors):
  """Maps the given dictionary of tensors to a concatenated list of bboxes.

  Args:
    keys_to_tensors: a mapping of TF-Example keys to parsed tensors.

  Returns:
    [time, num_boxes, 4] tensor of bounding box coordinates, in order
        [y_min, x_min, y_max, x_max]. Whether the tensor is a SparseTensor
        or a dense Tensor is determined by the return_dense parameter. Empty
        positions in the sparse tensor are filled with -1.0 values.
  """
  sides = []
  for key in self._full_keys:
    value = keys_to_tensors[key]
    expanded_dims = tf.concat(
        [tf.to_int64(tf.shape(value)),
         tf.constant([1], dtype=tf.int64)], 0)
    side = tf.sparse_reshape(value, expanded_dims)
    sides.append(side)
  bounding_boxes = tf.sparse_concat(2, sides)
  if self._return_dense:
    bounding_boxes = tf.sparse_tensor_to_dense(
        bounding_boxes, default_value=self._default_value)
  return bounding_boxes
Example #5
Source File: tensorflow_backend.py From deepQuest with BSD 3-Clause "New" or "Revised" License | 6 votes |
def concatenate(tensors, axis=-1):
    """Concatenates a list of tensors alongside the specified axis.

    # Arguments
        tensors: list of tensors to concatenate.
        axis: concatenation axis.

    # Returns
        A tensor.
    """
    if axis < 0:
        rank = ndim(tensors[0])
        if rank:
            axis %= rank
        else:
            axis = 0

    if py_all([is_sparse(x) for x in tensors]):
        return tf.sparse_concat(axis, tensors)
    else:
        return tf.concat([to_dense(x) for x in tensors], axis)
Example #6
Source File: node_edge_models.py From gcnn-survey-paper with Apache License 2.0 | 6 votes |
def compute_inference(self, node_features_in, sp_adj_matrix, is_training):
  with tf.variable_scope('edge-model'):
    z_latent = gcn_module(node_features_in, sp_adj_matrix, self.n_hidden_edge,
                          self.p_drop_edge, is_training, self.input_dim,
                          self.sparse_features)
    adj_matrix_pred = compute_adj(z_latent, self.att_mechanism,
                                  self.p_drop_edge, is_training)
    self.adj_matrix_pred = adj_matrix_pred
  with tf.variable_scope('node-model'):
    z_latent = tf.sparse_concat(
        axis=1,
        sp_inputs=[
            tf.contrib.layers.dense_to_sparse(z_latent), node_features_in
        ])
    sparse_features = True
    input_dim = self.n_hidden_edge[-1] + self.input_dim
    logits = gcn_module(
        z_latent,
        sp_adj_matrix,
        self.n_hidden_node,
        self.p_drop_node,
        is_training,
        input_dim,
        sparse_features=sparse_features)
  return logits, adj_matrix_pred
Example #7
Source File: model_utils.py From gcnn-survey-paper with Apache License 2.0 | 6 votes |
def get_sp_topk(adj_pred, sp_adj_train, nb_nodes, k):
  """Returns binary matrix with topK."""
  _, indices = tf.nn.top_k(tf.reshape(adj_pred, (-1,)), k)
  indices = tf.reshape(tf.cast(indices, tf.int64), (-1, 1))
  sp_adj_pred = tf.SparseTensor(
      indices=indices,
      values=tf.ones(k),
      dense_shape=(nb_nodes * nb_nodes,))
  sp_adj_pred = tf.sparse_reshape(sp_adj_pred,
                                  shape=(nb_nodes, nb_nodes, 1))
  sp_adj_train = tf.SparseTensor(
      indices=sp_adj_train.indices,
      values=tf.ones_like(sp_adj_train.values),
      dense_shape=sp_adj_train.dense_shape)
  sp_adj_train = tf.sparse_reshape(sp_adj_train,
                                   shape=(nb_nodes, nb_nodes, 1))
  sp_adj_pred = tf.sparse_concat(
      sp_inputs=[sp_adj_pred, sp_adj_train], axis=-1)
  return tf.sparse_reduce_max(sp_adj_pred, axis=-1)
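The function above relies on a reusable pattern: reshape two sparse matrices to have a trailing unit dimension, concatenate them along that new axis, and reduce with max to obtain their element-wise union. A minimal self-contained sketch of that pattern (hedged: toy values, TF 1.x API; per the TensorFlow docs, positions with no entries along the reduced axis come out as 0):

import tensorflow as tf

a = tf.SparseTensor(indices=[[0, 0], [1, 1]], values=[1.0, 1.0], dense_shape=[2, 2])
b = tf.SparseTensor(indices=[[0, 0], [0, 1]], values=[1.0, 1.0], dense_shape=[2, 2])

# Add a trailing axis of size 1, stack on it, then take the max over it.
a3 = tf.sparse_reshape(a, shape=(2, 2, 1))
b3 = tf.sparse_reshape(b, shape=(2, 2, 1))
stacked = tf.sparse_concat(sp_inputs=[a3, b3], axis=-1)  # shape (2, 2, 2)
union = tf.sparse_reduce_max(stacked, axis=-1)           # dense (2, 2)

with tf.Session() as sess:
    print(sess.run(union))
    # [[1. 1.]
    #  [0. 1.]]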
Example #8
Source File: tensorflow_backend.py From DeepLearning_Wavelet-LSTM with MIT License | 6 votes |
def concatenate(tensors, axis=-1):
    """Concatenates a list of tensors alongside the specified axis.

    # Arguments
        tensors: list of tensors to concatenate.
        axis: concatenation axis.

    # Returns
        A tensor.
    """
    if axis < 0:
        rank = ndim(tensors[0])
        if rank:
            axis %= rank
        else:
            axis = 0

    if py_all([is_sparse(x) for x in tensors]):
        return tf.sparse_concat(axis, tensors)
    else:
        return tf.concat([to_dense(x) for x in tensors], axis)
Example #9
Source File: tensorflow_backend.py From DeepLearning_Wavelet-LSTM with MIT License | 6 votes |
def concatenate(tensors, axis=-1):
    """Concatenates a list of tensors alongside the specified axis.

    # Arguments
        tensors: list of tensors to concatenate.
        axis: concatenation axis.

    # Returns
        A tensor.
    """
    if axis < 0:
        rank = ndim(tensors[0])
        if rank:
            axis %= rank
        else:
            axis = 0

    if py_all([is_sparse(x) for x in tensors]):
        return tf.sparse_concat(axis, tensors)
    else:
        return tf.concat([to_dense(x) for x in tensors], axis)
Example #10
Source File: tensorflow_backend.py From DeepLearning_Wavelet-LSTM with MIT License | 6 votes |
def concatenate(tensors, axis=-1):
    """Concatenates a list of tensors alongside the specified axis.

    # Arguments
        tensors: list of tensors to concatenate.
        axis: concatenation axis.

    # Returns
        A tensor.
    """
    if axis < 0:
        rank = ndim(tensors[0])
        if rank:
            axis %= rank
        else:
            axis = 0

    if py_all([is_sparse(x) for x in tensors]):
        return tf.sparse_concat(axis, tensors)
    else:
        return tf.concat([to_dense(x) for x in tensors], axis)
Example #11
Source File: tensorflow_backend.py From DeepLearning_Wavelet-LSTM with MIT License | 6 votes |
def concatenate(tensors, axis=-1):
    """Concatenates a list of tensors alongside the specified axis.

    # Arguments
        tensors: list of tensors to concatenate.
        axis: concatenation axis.

    # Returns
        A tensor.
    """
    if axis < 0:
        rank = ndim(tensors[0])
        if rank:
            axis %= rank
        else:
            axis = 0

    if py_all([is_sparse(x) for x in tensors]):
        return tf.sparse_concat(axis, tensors)
    else:
        return tf.concat([to_dense(x) for x in tensors], axis)
Example #12
Source File: tensorflow_backend.py From DeepLearning_Wavelet-LSTM with MIT License | 6 votes |
def concatenate(tensors, axis=-1):
    """Concatenates a list of tensors alongside the specified axis.

    # Arguments
        tensors: list of tensors to concatenate.
        axis: concatenation axis.

    # Returns
        A tensor.
    """
    if axis < 0:
        rank = ndim(tensors[0])
        if rank:
            axis %= rank
        else:
            axis = 0

    if py_all([is_sparse(x) for x in tensors]):
        return tf.sparse_concat(axis, tensors)
    else:
        return tf.concat([to_dense(x) for x in tensors], axis)
Example #13
Source File: tensorflow_backend.py From DeepLearning_Wavelet-LSTM with MIT License | 6 votes |
def concatenate(tensors, axis=-1):
    """Concatenates a list of tensors alongside the specified axis.

    # Arguments
        tensors: list of tensors to concatenate.
        axis: concatenation axis.

    # Returns
        A tensor.
    """
    if axis < 0:
        rank = ndim(tensors[0])
        if rank:
            axis %= rank
        else:
            axis = 0

    if py_all([is_sparse(x) for x in tensors]):
        return tf.sparse_concat(axis, tensors)
    else:
        return tf.concat([to_dense(x) for x in tensors], axis)
Example #14
Source File: sparse_concat_op_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testConcat3(self):
  with self.test_session(use_gpu=False) as sess:
    # concat(A, B, C):
    # [    1                ]
    # [2        1        1  ]
    # [3     4  2     1  0 2]
    sp_a = self._SparseTensor_3x3()
    sp_b = self._SparseTensor_3x5()
    sp_c = self._SparseTensor_3x2()

    for concat_dim in (-1, 1):
      sp_concat = tf.sparse_concat(concat_dim, [sp_a, sp_b, sp_c])

      self.assertEqual(sp_concat.indices.get_shape(), [10, 2])
      self.assertEqual(sp_concat.values.get_shape(), [10])
      self.assertEqual(sp_concat.shape.get_shape(), [2])

      concat_out = sess.run(sp_concat)

      self.assertAllEqual(concat_out.indices,
                          [[0, 2], [1, 0], [1, 4], [1, 8], [2, 0], [2, 2],
                           [2, 3], [2, 6], [2, 7], [2, 8]])
      self.assertAllEqual(concat_out.values, [1, 2, 1, 1, 3, 4, 2, 1, 0, 2])
      self.assertAllEqual(concat_out.shape, [3, 10])
Example #15
Source File: sparse_concat_op_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testConcatDim0(self):
  with self.test_session(use_gpu=False) as sess:
    # concat(A, D):
    # [    1]
    # [2    ]
    # [3   4]
    # [  1  ]
    # [1   2]
    sp_a = self._SparseTensor_3x3()
    sp_d = self._SparseTensor_2x3()

    for concat_dim in (-2, 0):
      sp_concat = tf.sparse_concat(concat_dim, [sp_a, sp_d])

      self.assertEqual(sp_concat.indices.get_shape(), [7, 2])
      self.assertEqual(sp_concat.values.get_shape(), [7])
      self.assertEqual(sp_concat.shape.get_shape(), [2])

      concat_out = sess.run(sp_concat)

      self.assertAllEqual(
          concat_out.indices,
          [[0, 2], [1, 0], [2, 0], [2, 2], [3, 1], [4, 0], [4, 2]])
      self.assertAllEqual(concat_out.values, np.array([1, 2, 3, 4, 1, 1, 2]))
      self.assertAllEqual(concat_out.shape, np.array([5, 3]))
Example #16
Source File: sparse_concat_op_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testConcat2(self):
  with self.test_session(use_gpu=False) as sess:
    # concat(A, B):
    # [    1           ]
    # [2        1      ]
    # [3     4  2   1 0]
    for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
      for sp_b in (self._SparseTensorValue_3x5(), self._SparseTensor_3x5()):
        for concat_dim in (-1, 1):
          sp_concat = tf.sparse_concat(concat_dim, [sp_a, sp_b])

          self.assertEqual(sp_concat.indices.get_shape(), [8, 2])
          self.assertEqual(sp_concat.values.get_shape(), [8])
          self.assertEqual(sp_concat.shape.get_shape(), [2])

          concat_out = sess.run(sp_concat)

          self.assertAllEqual(concat_out.indices,
                              [[0, 2], [1, 0], [1, 4], [2, 0], [2, 2], [2, 3],
                               [2, 6], [2, 7]])
          self.assertAllEqual(concat_out.values, [1, 2, 1, 3, 4, 2, 1, 0])
          self.assertAllEqual(concat_out.shape, [3, 8])
Example #17
Source File: sparse_concat_op_test.py From deep_image_model with Apache License 2.0 | 6 votes |
def testConcat1(self):
  with self.test_session(use_gpu=False) as sess:
    # concat(A):
    # [    1]
    # [2    ]
    # [3   4]
    for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
      # Note that we ignore concat_dim in this case since we short-circuit the
      # single-input case in python.
      for concat_dim in (-2000, 1, 2000):
        sp_concat = tf.sparse_concat(concat_dim, [sp_a])

        self.assertEqual(sp_concat.indices.get_shape(), [4, 2])
        self.assertEqual(sp_concat.values.get_shape(), [4])
        self.assertEqual(sp_concat.shape.get_shape(), [2])

        concat_out = sess.run(sp_concat)

        self.assertAllEqual(concat_out.indices,
                            [[0, 2], [1, 0], [2, 0], [2, 2]])
        self.assertAllEqual(concat_out.values, [1, 2, 3, 4])
        self.assertAllEqual(concat_out.shape, [3, 3])
Example #18
Source File: tensorflow_backend.py From GraphicDesignPatternByPython with MIT License | 6 votes |
def concatenate(tensors, axis=-1):
    """Concatenates a list of tensors alongside the specified axis.

    # Arguments
        tensors: list of tensors to concatenate.
        axis: concatenation axis.

    # Returns
        A tensor.
    """
    if axis < 0:
        rank = ndim(tensors[0])
        if rank:
            axis %= rank
        else:
            axis = 0

    if py_all([is_sparse(x) for x in tensors]):
        return tf.sparse_concat(axis, tensors)
    else:
        return tf.concat([to_dense(x) for x in tensors], axis)
Example #19
Source File: sparse_concat_op_test.py From deep_image_model with Apache License 2.0 | 5 votes |
def testShapeInferenceUnknownShapes(self):
  with self.test_session(use_gpu=False):
    sp_inputs = [
        self._SparseTensor_UnknownShape(),
        self._SparseTensor_UnknownShape(val_shape=[3]),
        self._SparseTensor_UnknownShape(ind_shape=[1, 3]),
        self._SparseTensor_UnknownShape(shape_shape=[3])]

    for concat_dim in (-2, 0):
      sp_concat = tf.sparse_concat(concat_dim, sp_inputs)

      self.assertEqual(sp_concat.indices.get_shape().as_list(), [None, 3])
      self.assertEqual(sp_concat.values.get_shape().as_list(), [None])
      self.assertEqual(sp_concat.shape.get_shape(), [3])
Example #20
Source File: tf_ops.py From nucleus7 with Mozilla Public License 2.0 | 5 votes |
def concat_padded(list_of_tensors: List[tf.Tensor], axis: int = 0,
                  expand_nonconcat_dim: bool = True) -> tf.Tensor:
    """
    Concatenate tensors and pad tensors with smaller dimension.

    Uses sparse concatenation inside, so can be slow

    Parameters
    ----------
    list_of_tensors
        list of tensors
    axis
        axis to concatenate
    expand_nonconcat_dim
        whether to allow the expansion in the non-concat dimensions.

    Returns
    -------
    concatenated_tensor
        concatenated tensor
    """
    t_sparse = [dense_to_sparse(t, tf.shape(t, out_type=tf.int64))
                for t in list_of_tensors]
    t_concatenated_sparse = tf.sparse_concat(
        axis, t_sparse, expand_nonconcat_dim=expand_nonconcat_dim)
    return tf.sparse_tensor_to_dense(t_concatenated_sparse)
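For context, here is a minimal plain-TensorFlow illustration of the padded concatenation idea that concat_padded implements (a hedged sketch assuming the TF 1.x API; it does not depend on nucleus7, and the dense-to-sparse helper below is a simplified stand-in that treats exact zeros as implicitly missing, which the real helper need not do):

import tensorflow as tf

a = tf.constant([[1, 2, 3]])        # shape [1, 3]
b = tf.constant([[4, 5], [6, 7]])   # shape [2, 2]

def to_sparse(t):
    # Simplified dense-to-sparse conversion: keep only non-zero entries.
    idx = tf.where(tf.not_equal(t, 0))
    return tf.SparseTensor(idx, tf.gather_nd(t, idx),
                           tf.shape(t, out_type=tf.int64))

# expand_nonconcat_dim=True lets the non-concat dimensions differ; the
# smaller tensor is effectively zero-padded after densification.
sp = tf.sparse_concat(0, [to_sparse(a), to_sparse(b)],
                      expand_nonconcat_dim=True)
padded = tf.sparse_tensor_to_dense(sp)   # shape [3, 3]

with tf.Session() as sess:
    print(sess.run(padded))
    # [[1 2 3]
    #  [4 5 0]
    #  [6 7 0]]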
Example #21
Source File: sparse_concat_op_test.py From deep_image_model with Apache License 2.0 | 5 votes |
def testMismatchedShapesExpandNonconcatDim(self):
  with self.test_session(use_gpu=False) as sess:
    sp_a = self._SparseTensor_3x3()
    sp_b = self._SparseTensor_3x5()
    sp_c = self._SparseTensor_3x2()
    sp_d = self._SparseTensor_2x3()
    for concat_dim0 in (-2, 0):
      for concat_dim1 in (-1, 1):
        sp_concat_dim0 = tf.sparse_concat(
            concat_dim0, [sp_a, sp_b, sp_c, sp_d], expand_nonconcat_dim=True)
        sp_concat_dim1 = tf.sparse_concat(
            concat_dim1, [sp_a, sp_b, sp_c, sp_d], expand_nonconcat_dim=True)

        sp_concat_dim0_out = sess.run(sp_concat_dim0)
        sp_concat_dim1_out = sess.run(sp_concat_dim1)

        self.assertAllEqual(sp_concat_dim0_out.indices,
                            [[0, 2], [1, 0], [2, 0], [2, 2], [4, 1], [5, 0],
                             [5, 3], [5, 4], [7, 0], [8, 0], [9, 1], [10, 0],
                             [10, 2]])
        self.assertAllEqual(sp_concat_dim0_out.values,
                            [1, 2, 3, 4, 1, 2, 1, 0, 1, 2, 1, 1, 2])
        self.assertAllEqual(sp_concat_dim0_out.shape, [11, 5])

        self.assertAllEqual(sp_concat_dim1_out.indices,
                            [[0, 2], [0, 11], [1, 0], [1, 4], [1, 8], [1, 10],
                             [1, 12], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7],
                             [2, 8]])
        self.assertAllEqual(sp_concat_dim1_out.values,
                            [1, 1, 2, 1, 1, 1, 2, 3, 4, 2, 1, 0, 2])
        self.assertAllEqual(sp_concat_dim1_out.shape, [3, 13])
Example #22
Source File: sparse_concat_op_test.py From deep_image_model with Apache License 2.0 | 5 votes |
def testMismatchedShapes(self):
  with self.test_session(use_gpu=False) as sess:
    sp_a = self._SparseTensor_3x3()
    sp_b = self._SparseTensor_3x5()
    sp_c = self._SparseTensor_3x2()
    sp_d = self._SparseTensor_2x3()
    for concat_dim in (-1, 1):
      sp_concat = tf.sparse_concat(concat_dim, [sp_a, sp_b, sp_c, sp_d])

      # Shape mismatches can only be caught when the op is run.
      with self.assertRaisesOpError("Input shapes must match"):
        sess.run(sp_concat)
Example #23
Source File: sparse_concat_op_test.py From deep_image_model with Apache License 2.0 | 5 votes |
def testMismatchedRankExpandNonconcatDim(self):
  with self.test_session(use_gpu=False):
    sp_a = self._SparseTensor_3x3()
    sp_e = self._SparseTensor_2x3x4()

    # Rank mismatches should be caught at shape-inference time, even for
    # expand_nonconcat_dim=True.
    for concat_dim in (-1, 1):
      with self.assertRaises(ValueError):
        tf.sparse_concat(concat_dim, [sp_a, sp_e], expand_nonconcat_dim=True)
Example #24
Source File: sparse_concat_op_test.py From deep_image_model with Apache License 2.0 | 5 votes |
def testMismatchedRank(self):
  with self.test_session(use_gpu=False):
    sp_a = self._SparseTensor_3x3()
    sp_e = self._SparseTensor_2x3x4()

    # Rank mismatches can be caught at shape-inference time.
    for concat_dim in (-1, 1):
      with self.assertRaises(ValueError):
        tf.sparse_concat(concat_dim, [sp_a, sp_e])
Example #25
Source File: node_edge_models.py From gcnn-survey-paper with Apache License 2.0 | 5 votes |
def compute_inference(self, node_features_in, sp_adj_matrix, is_training):
  with tf.variable_scope('edge-model'):
    z_latent = gcn_module(node_features_in, sp_adj_matrix, self.n_hidden_edge,
                          self.p_drop_edge, is_training, self.input_dim,
                          self.sparse_features)
    adj_matrix_pred = compute_adj(z_latent, self.att_mechanism,
                                  self.p_drop_edge, is_training)
    self.adj_matrix_pred = adj_matrix_pred
  with tf.variable_scope('node-model'):
    z_latent = tf.sparse_concat(
        axis=1,
        sp_inputs=[
            tf.contrib.layers.dense_to_sparse(z_latent), node_features_in
        ])
    sparse_features = True
    input_dim = self.n_hidden_edge[-1] + self.input_dim
    sp_adj_train = tf.SparseTensor(
        indices=sp_adj_matrix.indices,
        values=tf.ones_like(sp_adj_matrix.values),
        dense_shape=sp_adj_matrix.dense_shape)
    logits = gat_module(
        z_latent,
        sp_adj_train,
        self.n_hidden_node,
        self.n_att_node,
        self.p_drop_node,
        is_training,
        input_dim,
        sparse_features=sparse_features,
        average_last=True)
  return logits, adj_matrix_pred
Example #26
Source File: sparse_split_op_test.py From deep_image_model with Apache License 2.0 | 5 votes |
def testSliceConcat(self):
  for sp_input in (self._SparseTensorValue_3x4x2(),
                   self._SparseTensor_3x4x2()):
    with self.test_session(use_gpu=False):
      sparse_tensors = tf.sparse_split(1, 2, sp_input)
      concat_tensor = tf.sparse_concat(1, sparse_tensors)
      expected_output = self._SparseTensor_3x4x2()
      self.assertAllEqual(concat_tensor.indices.eval(),
                          expected_output.indices.eval())
Example #27
Source File: model_utils.py From nucleus7 with Mozilla Public License 2.0 | 5 votes |
def combine_predictions_from_devices(
        predictions_devices: List[Dict[str, tf.Tensor]],
        predictions_have_variable_shape: bool = False) -> Dict[str, tf.Tensor]:
    """
    Combines (concatenates) the predictions from multiple devices

    Parameters
    ----------
    predictions_devices
        list of dicts with same structure from multiple devices
    predictions_have_variable_shape
        if predictions from different devices may have different shapes;
        if so, it will use sparse operations to combine them

    Returns
    -------
    dict with same structure as first element in predictions_devices with
    concatenated over first dimension (batch dimension) values. If inputs
    have variable shape, then concatenation is done using
    :obj:`tf.sparse_concat` instead of :obj:`tf.concat`
    """
    if len(predictions_devices) == 1:
        return _dict_identity(predictions_devices[0])
    if predictions_have_variable_shape:
        combine_fun = lambda x: tf_ops.concat_padded(x, axis=0)
    else:
        combine_fun = lambda x: tf_ops.concat_or_stack(x, axis=0)
    with tf.variable_scope('combine_predictions'):
        predictions = nest_utils.combine_nested(predictions_devices,
                                                combine_fun=combine_fun)
    return predictions
Example #28
Source File: node_edge_models.py From gcnn-survey-paper with Apache License 2.0 | 4 votes |
def compute_inference(self, node_features_in, sp_adj_matrix, is_training):
  with tf.variable_scope('edge-model'):
    z_latent = gat_module(
        node_features_in,
        sp_adj_matrix,
        self.n_hidden_edge,
        self.n_att_edge,
        self.p_drop_edge,
        is_training,
        self.input_dim,
        self.sparse_features,
        average_last=False)
    adj_matrix_pred = compute_adj(z_latent, self.att_mechanism,
                                  self.p_drop_edge, is_training)
    self.adj_matrix_pred = adj_matrix_pred
  with tf.variable_scope('node-model'):
    concat = True
    if concat:
      z_latent = tf.sparse_concat(
          axis=1,
          sp_inputs=[
              tf.contrib.layers.dense_to_sparse(z_latent), node_features_in
          ])
      sparse_features = True
      input_dim = self.n_hidden_edge[-1] * self.n_att_edge[-1] + self.input_dim
    else:
      sparse_features = False
      input_dim = self.n_hidden_edge[-1] * self.n_att_edge[-1]
    logits = gat_module(
        z_latent,
        sp_adj_matrix,
        self.n_hidden_node,
        self.n_att_node,
        self.p_drop_node,
        is_training,
        input_dim,
        sparse_features=sparse_features,
        average_last=False)
  return logits, adj_matrix_pred
Example #29
Source File: example.py From rgat with Apache License 2.0 | 4 votes |
def model_fn(features, labels, mode, params):
    training = mode == ModeKeys.TRAIN

    if params.model == "rgat":
        model_class = RGATNModel
    elif params.model == "rgc":
        model_class = RGCNModel
    else:
        raise ValueError(
            "Unknown model {}. Must be one of `'rgat'` or `'rgc'`".format(
                params.model))

    model = model_class(params=params, training=training)

    inputs, support = features['features'], features['support']

    # Combine dict of supports into a single matrix
    support = tf.sparse_concat(axis=1, sp_inputs=list(support.values()),
                               name="combine_supports")

    logits = model(inputs=inputs, support=support)
    predictions = tf.argmax(logits, axis=-1, name='predictions')

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode, predictions={'logits': logits, 'predictions': predictions})

    mask, labels = labels['mask'], labels['labels']

    # Get only unmasked labels, logits and predictions
    labels, logits = tf.gather(labels, mask), tf.gather(logits, mask)
    predictions = tf.gather(predictions, mask)

    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    with tf.name_scope('metrics'):
        accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)
        metrics = {'accuracy': accuracy}

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=metrics)

    assert mode == tf.estimator.ModeKeys.TRAIN

    optimizer = tf.train.AdamOptimizer(learning_rate=params.learning_rate)
    global_step = tf.train.get_global_step()
    train_op = optimizer.minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
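For context, the support concatenation above collapses one sparse support matrix per relation type into a single wide block matrix. A toy, self-contained version of that step (hedged: made-up relation names and values, TF 1.x API):

import tensorflow as tf

support = {
    "rel_a": tf.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[2, 2]),
    "rel_b": tf.SparseTensor(indices=[[1, 1]], values=[2.0], dense_shape=[2, 2]),
}

# One [num_nodes, num_nodes] support per relation becomes a single
# [num_nodes, num_nodes * num_relations] block matrix.
combined = tf.sparse_concat(axis=1, sp_inputs=list(support.values()),
                            name="combine_supports")

with tf.Session() as sess:
    print(sess.run(tf.sparse_tensor_to_dense(combined)))
    # [[1. 0. 0. 0.]
    #  [0. 0. 0. 2.]]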