Python keras.backend.get_session() Examples
The following are 30 code examples of keras.backend.get_session().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module keras.backend, or try the search function.
Example #1
Source File: model_wrappers.py From nips-2017-adversarial with MIT License | 7 votes |
def load_ckpt(ckpt_name, var_scope_name, scope, constructor, input_tensor,
              label_offset, load_weights, **kwargs):
    """Build a model under a slim arg_scope and optionally restore its weights.

    Arguments
        ckpt_name       file name of the checkpoint
        var_scope_name  name of the variable scope
        scope           arg_scope passed to ``slim.arg_scope``
        constructor     constructor of the model
        input_tensor    tensor of input image
        label_offset    added to 1000 to obtain ``num_classes`` (the original
                        notes describe it as selecting 1000 vs. 1001 classes;
                        no class is removed inside this function)
        load_weights    whether to load weights from ``ckpt_name``
        kwargs          forwarded to ``constructor`` (original notes list
                        is_training and create_aux_logits)

    Returns
        The ``(logits, endpoints)`` pair returned by ``constructor``.
    """
    num_classes = 1000 + label_offset
    with slim.arg_scope(scope):
        logits, endpoints = constructor(
            input_tensor,
            num_classes=num_classes,
            scope=var_scope_name,
            **kwargs)

    if load_weights:
        model_variables = slim.get_model_variables(var_scope_name)
        restore_fn = slim.assign_from_checkpoint_fn(ckpt_name, model_variables)
        # The restore function is applied to the session Keras is using.
        restore_fn(K.get_session())

    return logits, endpoints
Example #2
Source File: tfrecord_model.py From sample-cnn with MIT License | 6 votes |
def predict_tfrecord(self, x_batch):
    """Run the TFRecord predict function once and return its output(s).

    ``x_batch`` is accepted but never read in this method. When the model
    uses the learning phase and that phase is not a plain int, a ``0.`` is
    passed as the only input; otherwise the input list is empty.

    Returns the single output directly when there is exactly one, otherwise
    the full list of outputs.
    """
    feeds_learning_phase = (self.uses_learning_phase
                            and not isinstance(K.learning_phase(), int))
    ins = [0.] if feeds_learning_phase else []
    self._make_tfrecord_predict_function()

    session = K.get_session()
    coordinator = tf.train.Coordinator()
    # The queue runners are started here and are not stopped afterwards.
    # TODO: If you close the queue, you can't open it again..
    #   (stopping would be: coord.request_stop(); coord.join(threads))
    tf.train.start_queue_runners(sess=session, coord=coordinator)
    outputs = self.predict_function(ins)

    return outputs[0] if len(outputs) == 1 else outputs
Example #3
Source File: util.py From deeplift with MIT License | 6 votes |
def compile_func(inputs, outputs):
    """Return a function that evaluates ``outputs`` for values fed to ``inputs``.

    Arguments:
        inputs: a feedable tensor or a list of them. Anything that is not a
            list is wrapped in a one-element list (a notice is printed).
        outputs: whatever should be fetched with ``session.run``.

    Returns:
        ``func_to_return(inp)``, which pairs ``inputs`` with ``inp``
        position by position, runs ``outputs`` in the session returned by
        ``get_session()`` and returns the result.

    Raises (from the returned function):
        AssertionError: when the number of provided values does not match
            the number of input tensors.
    """
    if not isinstance(inputs, list):
        print("Wrapping the inputs in a list...")
        inputs = [inputs]

    def func_to_return(inp):
        # With a single input tensor, a longer ``inp`` is taken to be the
        # value for that one tensor rather than a list of values.
        if len(inp) > len(inputs) and len(inputs) == 1:
            print("Wrapping the inputs in a list...")
            inp = [inp]
        assert len(inp) == len(inputs), (
            "length of provided list should be "
            + str(len(inputs)) + " for tensors " + str(inputs)
            + " but got input of length " + str(len(inp)))
        feed_dict = dict(zip(inputs, inp))
        return get_session().run(outputs, feed_dict=feed_dict)

    return func_to_return
Example #4
Source File: image_classifier_tf.py From aiexamples with Apache License 2.0 | 6 votes |
def keras_to_tensorflow(keras_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
    """Freeze a Keras model's graph and write it as a binary TensorFlow graph.

    Arguments:
        keras_model: model whose ``outputs`` are exported.
        output_dir: directory for the frozen graph; created (including
            missing parent directories) when it does not exist.
        model_name: file name of the written graph inside ``output_dir``.
        out_prefix: prefix of the identity nodes added for each output; the
            nodes are named ``<out_prefix>1``, ``<out_prefix>2``, ...
        log_tensorboard: when true, the written graph is also passed to
            ``import_pb_to_tensorboard.import_to_tensorboard``.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    out_nodes = []
    # Iterate over ``outputs`` (always a list). Indexing ``output`` would
    # slice the tensor itself when the model has a single output.
    for position, output_tensor in enumerate(keras_model.outputs):
        node_name = out_prefix + str(position + 1)
        out_nodes.append(node_name)
        tf.identity(output_tensor, node_name)

    sess = K.get_session()
    init_graph = sess.graph.as_graph_def()
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)

    if log_tensorboard:
        import_pb_to_tensorboard.import_to_tensorboard(
            os.path.join(output_dir, model_name), output_dir)
Example #5
Source File: test_shap.py From AIX360 with Apache License 2.0 | 6 votes |
def test_ShapGradientExplainer(self):
    """Disabled SHAP GradientExplainer test: only prints a skip notice.

    The test body is commented out below and kept for reference; nothing is
    asserted when this method runs.
    """
    # model = VGG16(weights='imagenet', include_top=True)
    # X, y = shap.datasets.imagenet50()
    # to_explain = X[[39, 41]]
    #
    # url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
    # fname = shap.datasets.cache(url)
    # with open(fname) as f:
    #     class_names = json.load(f)
    #
    # def map2layer(x, layer):
    #     feed_dict = dict(zip([model.layers[0].input], [preprocess_input(x.copy())]))
    #     return K.get_session().run(model.layers[layer].input, feed_dict)
    #
    # e = GradientExplainer((model.layers[7].input, model.layers[-1].output),
    #                       map2layer(preprocess_input(X.copy()), 7))
    # shap_values, indexes = e.explain_instance(map2layer(to_explain, 7), ranked_outputs=2)
    #
    print("Skipped Shap GradientExplainer")
Example #6
Source File: mnist_dnn.py From tensorflow_examples with Apache License 2.0 | 6 votes |
def export_savedmodel(model):
    """Export ``model`` as a SavedModel with a single predict signature.

    The signature maps ``'input'`` to ``model.input`` and ``'output'`` to
    ``model.output``. The export directory is the fixed path ``model``
    joined with the fixed version ``1``; the graph and variables are taken
    from the current Keras session and tagged for serving.
    """
    print("input: {}, output: {}".format(model.input, model.output))

    signature = tf.saved_model.signature_def_utils.predict_signature_def(
        inputs={'input': model.input},
        outputs={'output': model.output})

    model_path = "model"
    model_version = 1
    export_path = os.path.join(compat.as_bytes(model_path),
                               compat.as_bytes(str(model_version)))
    logging.info("Export the model to {}".format(export_path))

    saved_model_builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    saved_model_builder.add_meta_graph_and_variables(
        sess=K.get_session(),
        tags=[tf.saved_model.tag_constants.SERVING],
        clear_devices=True,
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                signature,
        })
    saved_model_builder.save()
Example #7
Source File: callbacks.py From training_results_v0.6 with Apache License 2.0 | 6 votes |
def _average_metrics_in_place(self, logs):
    """Replace every value in ``logs`` with the result of its allreduce op.

    For a metric seen for the first time, a variable and an allreduce op are
    created through ``self._make_variable``; for a known metric the existing
    variable is updated with the new value. The op is then run in the Keras
    session and its result is written back into ``logs``.
    """
    logs = logs or {}
    reduced = {}

    # Metrics are visited in sorted name order so the order is consistent.
    for name in sorted(logs):
        current = logs[name]
        if name in self.variables:
            K.set_value(self.variables[name], current)
        else:
            variable, reduce_op = self._make_variable(name, current)
            self.variables[name] = variable
            self.allreduce_ops[name] = reduce_op
        # The session is fetched per metric, after any variable creation.
        reduced[name] = K.get_session().run(self.allreduce_ops[name])

    # Write the reduced values back so other callbacks see them.
    logs.update(reduced)
Example #8
Source File: query_based_attack.py From blackbox-attacks with MIT License | 6 votes |
def CW_est(logits, x, x_plus_i, x_minus_i, curr_sample, curr_target):
    """Estimate finite-difference logit gradients for target and runner-up class.

    ``logits`` is evaluated three times in the Keras session, feeding ``x``
    with ``curr_sample``, ``x_plus_i`` and ``x_minus_i``. The differences
    between the "plus" and "minus" logits, divided by ``args.delta`` and
    halved, are returned for (1) the target class and (2) the class with the
    largest current logit once the target class has been masked out.

    Relies on the module-level names ``BATCH_SIZE`` and ``args``.
    """
    def evaluate(batch):
        return K.get_session().run([logits], feed_dict={x: batch})[0]

    rows = np.arange(BATCH_SIZE)
    target_cols = list(curr_target)

    curr_logits = evaluate(curr_sample)
    # Mask the target so the argmax picks the best class other than the target.
    curr_logits[rows, target_cols] = -1e4
    max_cols = list(np.argmax(curr_logits, 1))

    logit_plus = evaluate(x_plus_i)
    plus_target = logit_plus[rows, target_cols]
    plus_max = logit_plus[rows, max_cols]

    logit_minus = evaluate(x_minus_i)
    minus_target = logit_minus[rows, target_cols]
    minus_max = logit_minus[rows, max_cols]

    target_grad_est = (plus_target - minus_target) / args.delta
    max_grad_est = (plus_max - minus_max) / args.delta
    return target_grad_est / 2.0, max_grad_est / 2.0
Example #9
Source File: log_utils.py From rpg_public_dronet with MIT License | 6 votes |
def on_epoch_end(self, epoch, logs=None):
    """Log losses, periodically save weights and update the hard-mining sizes.

    Arguments:
        epoch: index of the epoch that just finished.
        logs: metric dict; ``'loss'`` and ``'val_loss'`` are read from it.
            ``None`` (the default) is treated as an empty dict.

    Every ``self.period`` epochs the model weights are written to
    ``<self.filepath>/model_weights_<epoch>.h5``. After that, the values
    loaded into ``self.model.k_mse`` and ``self.model.k_entropy`` are
    recomputed: both start at ``self.batch_size`` and decay towards 10 and 5
    respectively once ``epoch`` exceeds 30.
    """
    # A shared mutable default ({}) is avoided; fall back to a fresh dict.
    logs = logs or {}

    # Save training and validation losses
    logz.log_tabular('train_loss', logs.get('loss'))
    logz.log_tabular('val_loss', logs.get('val_loss'))
    logz.dump_tabular()

    # Save model every 'period' epochs
    if (epoch + 1) % self.period == 0:
        filename = self.filepath + '/model_weights_' + str(epoch) + '.h5'
        self.model.save_weights(filename, overwrite=True)
        # Report only after the save has actually succeeded.
        print("Saved model at {}".format(filename))

    # Hard mining
    sess = K.get_session()
    # 0 up to epoch 30, then rises towards 1 as the epoch number grows.
    decay = np.maximum(0.0, 1.0 - np.exp(-1.0 / 30.0 * (epoch - 30.0)))
    mse_function = self.batch_size - (self.batch_size - 10) * decay
    entropy_function = self.batch_size - (self.batch_size - 5) * decay
    self.model.k_mse.load(int(np.round(mse_function)), sess)
    self.model.k_entropy.load(int(np.round(entropy_function)), sess)
Example #10
Source File: cifar10_query_based.py From blackbox-attacks with MIT License | 6 votes |
def one_shot_method(prediction, x, curr_sample, curr_target, p_t):
    """Estimate the loss gradient from one random +/-1 perturbation pair.

    A random sign pattern is drawn, ``curr_sample`` is shifted by
    ``+/- args.delta`` times that pattern (clipped to
    ``[CLIP_MIN, CLIP_MAX]``), and ``prediction`` is evaluated on both
    shifted batches with the learning phase fed as 0. The difference of the
    target-class predictions yields the gradient estimate, which is divided
    by ``p_t`` and negated.

    Relies on the module-level names ``args``, ``BATCH_SIZE``,
    ``IMAGE_ROWS``, ``IMAGE_COLS``, ``NUM_CHANNELS``, ``CLIP_MIN`` and
    ``CLIP_MAX``.

    Raises:
        NotImplementedError: when ``args.CW_loss`` is not 0. The original
            code had no branch for that case and failed on an unbound local.
    """
    if args.CW_loss != 0:
        raise NotImplementedError(
            'one_shot_method only supports args.CW_loss == 0')

    shape = (BATCH_SIZE, IMAGE_ROWS, IMAGE_COLS, NUM_CHANNELS)
    # Random pattern of -1/+1 entries.
    DELTA = np.random.randint(2, size=shape)
    np.place(DELTA, DELTA == 0, -1)

    y_plus = np.clip(curr_sample + args.delta * DELTA, CLIP_MIN, CLIP_MAX)
    y_minus = np.clip(curr_sample - args.delta * DELTA, CLIP_MIN, CLIP_MAX)

    rows = np.arange(BATCH_SIZE)
    target_cols = list(curr_target)

    pred_plus = K.get_session().run(
        [prediction], feed_dict={x: y_plus, K.learning_phase(): 0})[0]
    pred_plus_t = pred_plus[rows, target_cols]

    pred_minus = K.get_session().run(
        [prediction], feed_dict={x: y_minus, K.learning_phase(): 0})[0]
    pred_minus_t = pred_minus[rows, target_cols]

    num_est = pred_plus_t - pred_minus_t
    grad_est = num_est[:, None, None, None] / (args.delta * DELTA)

    # Getting gradient of the loss
    loss_grad = -1.0 * grad_est / p_t[:, None, None, None]
    return loss_grad
Example #11
Source File: yolov3.py From keras-onnx with MIT License | 6 votes |
def __init__(self, model_path='model_data/yolo.h5', anchors_path='model_data/yolo_anchors.txt', yolo3_dir=None):
    """Store paths and thresholds, load class names/anchors, build box colors.

    Arguments:
        model_path: path stored as ``self.model_path``.
        anchors_path: path stored as ``self.anchors_path``.
        yolo3_dir: stored as ``self.yolo3_dir``.

    Also grabs the Keras session, builds one RGB color per class name and
    sets the Keras learning phase to 0.
    """
    self.yolo3_dir = yolo3_dir
    self.model_path = model_path
    self.anchors_path = anchors_path
    self.classes_path = 'model_data/coco_classes.txt'
    self.score = 0.3
    self.iou = 0.45
    self.class_names = self._get_class()
    self.anchors = self._get_anchors()
    self.sess = K.get_session()
    self.model_image_size = (416, 416)  # fixed size or (None, None), hw
    self.session = None
    self.final_model = None

    # Generate colors for drawing bounding boxes: evenly spaced hues,
    # converted to RGB and scaled to integer components.
    class_count = len(self.class_names)
    palette = []
    for x in range(class_count):
        rgb = colorsys.hsv_to_rgb(x / class_count, 1., 1.)
        palette.append((int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255)))
    self.colors = palette

    np.random.seed(10101)  # Fixed seed for consistent colors across runs.
    np.random.shuffle(self.colors)  # Shuffle colors to decorrelate adjacent classes.
    np.random.seed(None)  # Reset seed to default.
    K.set_learning_phase(0)
Example #12
Source File: callbacks.py From deform-conv with MIT License | 5 votes |
def set_model(self, model):
    """Attach ``model`` and create the TensorFlow summary ops and writer.

    When ``self.histogram_freq`` is truthy and no merged summary exists yet,
    histogram summaries are added for every layer weight and for the
    gradient of ``model.total_loss`` with respect to that weight, plus an
    optional image summary per weight and a histogram of each layer output.

    NOTE(review): the source was collapsed onto one line; the nesting below
    (per-weight summaries inside the weight loop, the layer-output histogram
    once per layer, ``merge_all`` after the ``if``) is reconstructed and
    should be confirmed against the original file.
    """
    self.model = model
    self.sess = K.get_session()
    total_loss = self.model.total_loss
    if self.histogram_freq and self.merged is None:
        for layer in self.model.layers:
            for weight in layer.weights:
                # dense_1/bias:0 > dense_1/bias_0
                name = weight.name.replace(':', '_')
                tf.summary.histogram(name, weight)
                # Histogram of d(total_loss)/d(weight).
                tf.summary.histogram(
                    '{}_gradients'.format(name),
                    K.gradients(total_loss, [weight])[0]
                )
                if self.write_images:
                    w_img = tf.squeeze(weight)
                    shape = w_img.get_shape()
                    # Transpose when the first dimension is the larger one.
                    if len(shape) > 1 and shape[0] > shape[1]:
                        w_img = tf.transpose(w_img)
                    # A rank-1 weight gets a leading dimension added.
                    if len(shape) == 1:
                        w_img = tf.expand_dims(w_img, 0)
                    # Add one leading and one trailing dimension.
                    w_img = tf.expand_dims(tf.expand_dims(w_img, 0), -1)
                    tf.summary.image(name, w_img)
            if hasattr(layer, 'output'):
                tf.summary.histogram('{}_out'.format(layer.name), layer.output)
    self.merged = tf.summary.merge_all()
    # The session graph is handed to the writer only when write_graph is set.
    if self.write_graph:
        self.writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
    else:
        self.writer = tf.summary.FileWriter(self.log_dir)
Example #13
Source File: keras_models.py From gentun with Apache License 2.0 | 5 votes |
def reset_weights(self):
    """Re-run the kernel initializer op of every layer that declares one.

    Only layers with a ``kernel_initializer`` attribute are touched, and
    only their ``kernel`` variable is re-initialized.
    """
    session = K.get_session()
    with_kernel = (
        lyr for lyr in self.model.layers if hasattr(lyr, 'kernel_initializer')
    )
    for lyr in with_kernel:
        lyr.kernel.initializer.run(session=session)
Example #14
Source File: train_mrcnn.py From maskrcnn with MIT License | 5 votes |
def set_debugger_session():
    """Wrap the current Keras session in a CLI debug session and install it.

    The wrapper gets the module-level ``name_filter`` registered as a tensor
    filter under the name ``'name_filter'``.
    """
    debug_session = tf_debug.LocalCLIDebugWrapperSession(K.get_session())
    debug_session.add_tensor_filter('name_filter', name_filter)
    K.set_session(debug_session)
Example #15
Source File: yolo.py From keras-YOLOv3-mobilenet with MIT License | 5 votes |
def __init__(self, **kwargs):
    """Apply default and user settings, then build the detection tensors.

    ``self._defaults`` is copied into the instance attributes first and
    ``kwargs`` second, so user-supplied values override the defaults.
    """
    # Defaults first, user overrides second.
    for settings in (self._defaults, kwargs):
        self.__dict__.update(settings)

    self.class_names = self._get_class()
    self.anchors = self._get_anchors()
    self.sess = K.get_session()

    generated = self.generate()
    self.boxes, self.scores, self.classes = generated
Example #16
Source File: yolo.py From yolo3_keras_Flag_Detection with MIT License | 5 votes |
def __init__(self, **kwargs):
    """Apply default and user settings, then build the detection tensors.

    ``self._defaults`` is copied into the instance attributes first and
    ``kwargs`` second, so user-supplied values override the defaults.
    """
    # Defaults first, user overrides second.
    for settings in (self._defaults, kwargs):
        self.__dict__.update(settings)

    self.class_names = self._get_class()
    self.anchors = self._get_anchors()
    self.sess = K.get_session()

    generated = self.generate()
    self.boxes, self.scores, self.classes = generated
Example #17
Source File: __init__.py From training_results_v0.6 with Apache License 2.0 | 5 votes |
def allgather(value, name=None):
    """
    Run an allgather over a tensor-compatible value and return the result.

    The value is turned into a constant, passed to ``hvd.allgather`` and the
    resulting op is evaluated in the Keras session. Per the original notes,
    concatenation happens on the first dimension, so inputs on different
    processes must agree in rank and shape except for that dimension.

    Arguments:
        value: A tensor-compatible value to gather.
        name: Optional name for the constant created by this operation.
    """
    constant = tf.constant(value, name=name)
    gather_op = hvd.allgather(constant)
    session = K.get_session()
    return session.run(gather_op)
Example #18
Source File: network_utils.py From nips-2017-adversarial with MIT License | 5 votes |
def restore_source_model(saved_pb_name, grad_dict=None):
    """Import a frozen graph and build one Keras function per pickled entry.

    Arguments:
        saved_pb_name: path prefix; ``<prefix>.pickle`` holds a dict mapping
            a name to ``(input_names, output_names)`` and ``<prefix>.pb``
            holds the frozen graph definition.
        grad_dict: dict that receives, per name, the tuple
            ``(input_tensors, output_tensors, K.function(...))``. A new dict
            is created when omitted (the original crashed on ``None``).

    Returns:
        ``grad_dict`` on success, or ``None`` when at least one required
        tensor is missing from the imported graph (nothing is added then).
    """
    if grad_dict is None:
        grad_dict = {}

    print('restoring', saved_pb_name)
    # NOTE(security): pickle.load can execute arbitrary code; only use this
    # with pickle files from a trusted source.
    with open(saved_pb_name + '.pickle', 'rb') as f:
        info = pickle.load(f)
    print(info)

    # The session object itself is not used below; the call is kept where
    # the original had it in case creating the session here matters.
    K.get_session()

    print('restoring frozen graph def')
    with open(saved_pb_name + '.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

    # prepare set of tensors that needs to be found
    tensor_search_name = set()
    for input_names, output_names in info.values():
        tensor_search_name.update(input_names + output_names)

    # search for tensors among ops that have exactly one output
    found_tensors = {}
    for op in tf.get_default_graph().get_operations():
        if len(op.outputs) != 1:
            continue
        tensor = op.outputs[0]
        if tensor.name in tensor_search_name:
            found_tensors[tensor.name] = tensor

    missing = [t for t in tensor_search_name if t not in found_tensors]
    for t in missing:
        print('Tensor not found:', t)
    if missing:
        return None
    print('all nodes found')

    for gname, (input_names, output_names) in info.items():
        input_list = [found_tensors[tname] for tname in input_names]
        output_list = [found_tensors[tname] for tname in output_names]
        print('{0}\n Input: {1}\n Output: {2}\n'.format(gname, input_list, output_list))
        grad_dict[gname] = (input_list, output_list, K.function(input_list, output_list))
    print('restore finished')
    return grad_dict
Example #19
Source File: network_utils.py From nips-2017-adversarial with MIT License | 5 votes |
def restore_source_model(saved_ckpt_name, grad_dict=None, from_frozen=False):
    """Restore a saved graph and build one Keras function per pickled entry.

    Arguments:
        saved_ckpt_name: path prefix. ``<prefix>.pickle`` holds a pair
            ``(input_tensor_names, {model_name: output_tensor_name})``.
            The graph comes from ``<prefix>.ckpt.meta`` / ``<prefix>.ckpt``
            or, with ``from_frozen``, from ``<prefix>.pb``.
        grad_dict: dict that receives ``K.function`` objects keyed by model
            name. A new dict is created when omitted (the original crashed
            on ``None``).
        from_frozen: import a frozen graph definition instead of restoring
            a meta graph plus checkpoint variables.

    Returns:
        ``grad_dict`` with one entry per model name.

    Raises:
        KeyError: when a tensor name from the pickle is not in the graph.
    """
    if grad_dict is None:
        grad_dict = {}

    print('restoring', saved_ckpt_name)
    # NOTE(security): pickle.load can execute arbitrary code; only use this
    # with pickle files from a trusted source.
    with open(saved_ckpt_name + '.pickle', 'rb') as f:
        info = pickle.load(f)
    input_names, output_name_of = info[0], info[1]
    print(input_names)
    print(output_name_of)

    sess = K.get_session()
    if from_frozen:
        print('restoring frozen graph def')
        with open(saved_ckpt_name + '.pb', 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
    else:
        print('restoring graph')
        saver = tf.train.import_meta_graph(saved_ckpt_name + '.ckpt.meta')
        print('restoring variables')
        saver.restore(sess, saved_ckpt_name + '.ckpt')

    tensor_search_name = set(input_names + list(output_name_of.values()))
    found_tensors = {}
    # Only ops with exactly one output are considered.
    for op in tf.get_default_graph().get_operations():
        if len(op.outputs) != 1:
            continue
        tensor = op.outputs[0]
        if tensor.name.startswith('src_input'):
            print(tensor.name)
        if tensor.name in tensor_search_name:
            found_tensors[tensor.name] = tensor

    input_tensors = [found_tensors[nm] for nm in input_names]
    # The 'PRED' function takes inputs 0, 1 and 3 only.
    pred_input_tensors = input_tensors[:2] + [input_tensors[3]]
    print(input_tensors)

    for model_name, tensor_name in output_name_of.items():
        function_inputs = pred_input_tensors if model_name == 'PRED' else input_tensors
        grad_dict[model_name] = K.function(function_inputs, [found_tensors[tensor_name]])
    print('restore finished')
    return grad_dict
Example #20
Source File: network_utils.py From nips-2017-adversarial with MIT License | 5 votes |
def restore_source_model(saved_ckpt_name, grad_dict=None, from_frozen=False):
    """Restore a saved graph and build one Keras function per pickled entry.

    Arguments:
        saved_ckpt_name: path prefix. ``<prefix>.pickle`` holds a pair
            ``(input_tensor_names, {model_name: output_tensor_name})``.
            The graph comes from ``<prefix>.ckpt.meta`` / ``<prefix>.ckpt``
            or, with ``from_frozen``, from ``<prefix>.pb``.
        grad_dict: dict that receives ``K.function`` objects keyed by model
            name. A new dict is created when omitted (the original crashed
            on ``None``).
        from_frozen: import a frozen graph definition instead of restoring
            a meta graph plus checkpoint variables.

    Returns:
        ``grad_dict`` with one entry per model name.

    Raises:
        KeyError: when a tensor name from the pickle is not in the graph.
    """
    if grad_dict is None:
        grad_dict = {}

    print('restoring', saved_ckpt_name)
    # NOTE(security): pickle.load can execute arbitrary code; only use this
    # with pickle files from a trusted source.
    with open(saved_ckpt_name + '.pickle', 'rb') as f:
        info = pickle.load(f)
    input_names, output_name_of = info[0], info[1]
    print(input_names)
    print(output_name_of)

    sess = K.get_session()
    if from_frozen:
        print('restoring frozen graph def')
        with open(saved_ckpt_name + '.pb', 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
    else:
        print('restoring graph')
        saver = tf.train.import_meta_graph(saved_ckpt_name + '.ckpt.meta')
        print('restoring variables')
        saver.restore(sess, saved_ckpt_name + '.ckpt')

    tensor_search_name = set(input_names + list(output_name_of.values()))
    found_tensors = {}
    # Only ops with exactly one output are considered.
    for op in tf.get_default_graph().get_operations():
        if len(op.outputs) != 1:
            continue
        tensor = op.outputs[0]
        if tensor.name.startswith('src_input'):
            print(tensor.name)
        if tensor.name in tensor_search_name:
            found_tensors[tensor.name] = tensor

    input_tensors = [found_tensors[nm] for nm in input_names]
    # TODO this part is not really right. need to integrate with
    # attack better
    # also in the end learning phase is added twice, that's why
    # it worked for the frozen graphs, but will break for the non-frozen ones
    pred_input_tensors = input_tensors + [input_tensors[2]]
    print(input_tensors)

    for model_name, tensor_name in output_name_of.items():
        function_inputs = pred_input_tensors if model_name == 'PRED' else input_tensors
        grad_dict[model_name] = K.function(function_inputs, [found_tensors[tensor_name]])
    print('restore finished')
    return grad_dict
Example #21
Source File: model_wrappers.py From nips-2017-adversarial with MIT License | 5 votes |
def load_ckpt(ckpt_name, var_scope_name, scope, constructor, input_tensor,
              label_offset, load_weights, **kwargs):
    """Build a model under a slim arg_scope and optionally restore its weights.

    The variable scope name is printed first. The model is built with
    ``1000 + label_offset`` classes; ``kwargs`` are forwarded to
    ``constructor`` (original notes list is_training and create_aux_logits).
    When ``load_weights`` is true, the model variables of ``var_scope_name``
    are restored from ``ckpt_name`` into the Keras session.

    Returns the ``(logits, endpoints)`` pair returned by ``constructor``.
    """
    print(var_scope_name)

    num_classes = 1000 + label_offset
    with slim.arg_scope(scope):
        logits, endpoints = constructor(
            input_tensor,
            num_classes=num_classes,
            scope=var_scope_name,
            **kwargs)

    if load_weights:
        model_variables = slim.get_model_variables(var_scope_name)
        restore_fn = slim.assign_from_checkpoint_fn(ckpt_name, model_variables)
        restore_fn(K.get_session())

    return logits, endpoints
Example #22
Source File: network_utils.py From nips-2017-adversarial with MIT License | 5 votes |
def restore_source_model(saved_ckpt_name, grad_dict=None, from_frozen=False):
    """Restore a saved graph and build one Keras function per pickled entry.

    Arguments:
        saved_ckpt_name: path prefix. ``<prefix>.pickle`` holds a pair
            ``(input_tensor_names, {model_name: output_tensor_name})``.
            The graph comes from ``<prefix>.ckpt.meta`` / ``<prefix>.ckpt``
            or, with ``from_frozen``, from ``<prefix>.pb``.
        grad_dict: dict that receives ``K.function`` objects keyed by model
            name. A new dict is created when omitted (the original crashed
            on ``None``).
        from_frozen: import a frozen graph definition instead of restoring
            a meta graph plus checkpoint variables.

    Returns:
        ``grad_dict`` with one entry per model name.

    Raises:
        KeyError: when a tensor name from the pickle is not in the graph.
    """
    if grad_dict is None:
        grad_dict = {}

    print('restoring', saved_ckpt_name)
    # NOTE(security): pickle.load can execute arbitrary code; only use this
    # with pickle files from a trusted source.
    with open(saved_ckpt_name + '.pickle', 'rb') as f:
        info = pickle.load(f)
    input_names, output_name_of = info[0], info[1]
    print(input_names)
    print(output_name_of)

    sess = K.get_session()
    if from_frozen:
        print('restoring frozen graph def')
        with open(saved_ckpt_name + '.pb', 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
    else:
        print('restoring graph')
        saver = tf.train.import_meta_graph(saved_ckpt_name + '.ckpt.meta')
        print('restoring variables')
        saver.restore(sess, saved_ckpt_name + '.ckpt')

    tensor_search_name = set(input_names + list(output_name_of.values()))
    found_tensors = {}
    # Only ops with exactly one output are considered.
    for op in tf.get_default_graph().get_operations():
        if len(op.outputs) != 1:
            continue
        tensor = op.outputs[0]
        if tensor.name.startswith('src_input'):
            print(tensor.name)
        if tensor.name in tensor_search_name:
            found_tensors[tensor.name] = tensor

    input_tensors = [found_tensors[nm] for nm in input_names]
    # The 'PRED' function takes inputs 0, 1 and 3 only.
    pred_input_tensors = input_tensors[:2] + [input_tensors[3]]
    print(input_tensors)

    for model_name, tensor_name in output_name_of.items():
        function_inputs = pred_input_tensors if model_name == 'PRED' else input_tensors
        grad_dict[model_name] = K.function(function_inputs, [found_tensors[tensor_name]])
    print('restore finished')
    return grad_dict
Example #23
Source File: model_wrappers.py From nips-2017-adversarial with MIT License | 5 votes |
def load_ckpt(ckpt_name, var_scope_name, scope, constructor, input_tensor,
              label_offset, load_weights, **kwargs):
    """Build a model under a slim arg_scope and optionally restore its weights.

    The variable scope name is printed first. The model is built with
    ``1000 + label_offset`` classes; ``kwargs`` are forwarded to
    ``constructor`` (original notes list is_training and create_aux_logits).
    When ``load_weights`` is true, the model variables of ``var_scope_name``
    are restored from ``ckpt_name`` into the Keras session.

    Returns the ``(logits, endpoints)`` pair returned by ``constructor``.
    """
    print(var_scope_name)

    num_classes = 1000 + label_offset
    with slim.arg_scope(scope):
        logits, endpoints = constructor(
            input_tensor,
            num_classes=num_classes,
            scope=var_scope_name,
            **kwargs)

    if load_weights:
        model_variables = slim.get_model_variables(var_scope_name)
        restore_fn = slim.assign_from_checkpoint_fn(ckpt_name, model_variables)
        restore_fn(K.get_session())

    return logits, endpoints
Example #24
Source File: __init__.py From training_results_v0.6 with Apache License 2.0 | 5 votes |
def allreduce(value, name=None, average=True):
    """
    Run an allreduce over a tensor-compatible value and return the result.

    The value is turned into a constant, passed to ``hvd.allreduce`` and the
    resulting op is evaluated in the Keras session.

    Arguments:
        value: A tensor-compatible value to reduce. Per the original notes,
               its shape must be identical across all ranks.
        name: Optional name for the constant created by this operation.
        average: Forwarded to ``hvd.allreduce``; per the original notes,
                 True averages over all ranks and False sums over them.
    """
    constant = tf.constant(value, name=name)
    reduce_op = hvd.allreduce(constant, average=average)
    session = K.get_session()
    return session.run(reduce_op)
Example #25
Source File: utils.py From dts with MIT License | 5 votes |
def get_flops(model):
    """Return the total float operation count of the Keras session graph.

    The profile is taken over the graph of the current Keras session, not
    over ``model`` itself; the ``model`` argument is not read.
    """
    run_metadata = tf.RunMetadata()
    float_op_options = tf.profiler.ProfileOptionBuilder.float_operation()

    # We use the Keras session graph in the call to the profiler.
    session_graph = K.get_session().graph
    profile = tf.profiler.profile(
        graph=session_graph,
        run_meta=run_metadata,
        cmd='op',
        options=float_op_options)
    return profile.total_float_ops
Example #26
Source File: __init__.py From BERT-keras with GNU General Public License v3.0 | 5 votes |
def tpu_compatible():
    '''Patch ``KerasTPUModel.compile`` to work around Keras TPU model problems.

    The patched ``compile`` refuses to run on TensorFlow older than 1.13,
    calls the original ``compile`` and then initializes every variable that
    TensorFlow still reports as uninitialized.

    The patch is applied at most once per process. The "done" flag is set
    only after the patch succeeded, so a failed first call (for example a
    failing import) is not silently ignored on the next call.
    '''
    if hasattr(tpu_compatible, 'once'):
        return

    import tensorflow as tf
    import tensorflow.keras.backend as K
    from tensorflow.contrib.tpu.python.tpu.keras_support import KerasTPUModel

    _version = tf.__version__.split('.')
    is_correct_version = (int(_version[0]), int(_version[1])) >= (1, 13)

    def initialize_uninitialized_variables():
        sess = K.get_session()
        uninitialized_variables = {
            i.decode('ascii') for i in sess.run(tf.report_uninitialized_variables())
        }
        init_op = tf.variables_initializer(
            [v for v in tf.global_variables()
             if v.name.split(':')[0] in uninitialized_variables]
        )
        sess.run(init_op)

    _tpu_compile = KerasTPUModel.compile

    def tpu_compile(self, optimizer, loss=None, metrics=None, loss_weights=None,
                    sample_weight_mode=None, weighted_metrics=None,
                    target_tensors=None, **kwargs):
        if not is_correct_version:
            # The check above is for 1.13; the message previously said 1.3.
            raise ValueError('You need tensorflow >= 1.13 for better keras tpu support!')
        _tpu_compile(self, optimizer, loss, metrics, loss_weights,
                     sample_weight_mode, weighted_metrics, target_tensors, **kwargs)
        initialize_uninitialized_variables()  # for unknown reason, we should run this after compile sometimes

    KerasTPUModel.compile = tpu_compile
    tpu_compatible.once = True
Example #27
Source File: utils.py From voxelmorph with GNU General Public License v3.0 | 5 votes |
def reset_weights(model, session=None):
    """
    Reset the weights of a model by re-running their initializer ops.

    Note: only layers exposing "kernel_initializer" and/or
    "bias_initializer" are handled, and the session is not closed.
    Reference:
    https://www.codementor.io/nitinsurya/how-to-re-initialize-keras-model-weights-et41zre2g

    Parameters:
        model: keras model to reset
        session (optional): the session to run the initializers in;
            defaults to the current Keras session

    A weight whose variable is missing or None (for example the bias of a
    layer built without one) is skipped instead of raising. For every layer
    where nothing was reset, a notice with the layer name is printed.
    """
    if session is None:
        session = K.get_session()

    for layer in model.layers:
        reset = False
        for weight_attr in ('kernel', 'bias'):
            if not hasattr(layer, weight_attr + '_initializer'):
                continue
            variable = getattr(layer, weight_attr, None)
            if variable is None:
                continue
            variable.initializer.run(session=session)
            reset = True

        if not reset:
            print('Could not find initializer for layer %s. skipping' % layer.name)
Example #28
Source File: __init__.py From training_results_v0.6 with Apache License 2.0 | 5 votes |
def broadcast(value, root_rank, name=None):
    """
    Run a broadcast of a tensor-compatible value and return the result.

    The value is turned into a constant, passed to ``hvd.broadcast`` together
    with ``root_rank`` and the resulting op is evaluated in the Keras session.

    Arguments:
        value: A tensor-compatible value to broadcast. Per the original
               notes, its shape must be identical across all ranks.
        root_rank: Rank of the process whose value is sent to all other
                   processes.
        name: Optional name for the constant created by this operation.
    """
    constant = tf.constant(value, name=name)
    broadcast_op = hvd.broadcast(constant, root_rank)
    session = K.get_session()
    return session.run(broadcast_op)
Example #29
Source File: word_vectors.py From keras-image-captioning with MIT License | 5 votes |
def vectorize_words(self, words):
    """Return an array with one vector per word, in the order of ``words``.

    Known words are looked up in ``self._word_vector_of``. Words without an
    entry get a freshly initialized vector: ``self._initializer`` is called
    once with shape ``(number of unknown words, self._embedding_size)`` and
    its rows are handed out in word order. The initializer and the Keras
    session are only used when at least one word is unknown.
    """
    vectors = [self._word_vector_of.get(word) for word in words]

    # Counted without len(filter(...)), which fails on Python 3 where
    # filter returns an iterator.
    num_unknowns = sum(1 for vector in vectors if vector is None)
    if num_unknowns:
        inits = self._initializer(shape=(num_unknowns, self._embedding_size))
        inits = iter(K.get_session().run(inits))
        vectors = [next(inits) if vector is None else vector
                   for vector in vectors]

    return np.array(vectors)
Example #30
Source File: callbacks.py From training_results_v0.6 with Apache License 2.0 | 5 votes |
def on_train_begin(self, logs=None):
    """Broadcast global variables from ``self.root_rank`` when training starts.

    The broadcast op is created and run under ``tf.device(self.device)``,
    using the Keras session. ``logs`` is not read.
    """
    with tf.device(self.device):
        broadcast_all = hvd.broadcast_global_variables(self.root_rank)
        session = K.get_session()
        session.run(broadcast_all)