Python tensorflow.get_default_graph() Examples
The following are 30 code examples of tensorflow.get_default_graph().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: utils_pytorch.py From neural-fingerprinting with BSD 3-Clause "New" or "Revised" License | 7 votes |
def _py_func_with_gradient(func, inp, Tout, stateful=True, name=None,
                           grad_func=None):
    """Wrap ``tf.py_func`` so that it uses a custom gradient function.

    :param func: Python function to wrap as a TensorFlow op.
    :param inp: List of input tensors passed to ``func``.
    :param Tout: Output type(s) of ``func``.
    :param stateful: Forwarded unchanged to ``tf.py_func``.
    :param name: Optional name of the resulting op.
    :param grad_func: Gradient function to register for the op. When
        ``None``, no gradient override is installed and a plain
        ``tf.py_func`` is returned.
    :return: The output tensor(s) of ``tf.py_func``.
    """
    if grad_func is None:
        # Nothing to override; do not register ``None`` as a gradient.
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

    # Generate random name in order to avoid conflicts with inbuilt names
    # and with gradients registered by earlier calls.
    rnd_name = 'PyFuncGrad-' + '%0x' % getrandbits(30 * 4)
    tf.RegisterGradient(rnd_name)(grad_func)

    g = tf.get_default_graph()
    # Map both PyFunc op variants so the override applies whether or not
    # ``stateful`` is set.
    with g.gradient_override_map(
            {"PyFunc": rnd_name, "PyFuncStateless": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
Example #2
Source File: build.py From Traffic_sign_detection_YOLO with MIT License | 6 votes |
def build_from_pb(self):
    """Load a frozen GraphDef plus its JSON metadata and bind I/O tensors.

    Reads the serialized ``GraphDef`` at ``self.FLAGS.pbLoad`` and imports it
    into the default graph without a name prefix. Loads ``self.meta`` from
    the JSON file ``self.FLAGS.metaLoad`` and builds ``self.framework`` from
    it. Then sets ``self.inp`` / ``self.out`` to the tensors ``input:0`` /
    ``output:0`` and calls ``self.setup_meta_ops()``.
    """
    # tf.gfile.GFile supersedes the deprecated tf.gfile.FastGFile.
    with tf.gfile.GFile(self.FLAGS.pbLoad, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # name="" keeps the original op names so 'input:0'/'output:0' resolve.
    tf.import_graph_def(graph_def, name="")

    with open(self.FLAGS.metaLoad, 'r') as fp:
        self.meta = json.load(fp)
    self.framework = create_framework(self.meta, self.FLAGS)

    graph = tf.get_default_graph()
    # Placeholders
    self.inp = graph.get_tensor_by_name('input:0')
    self.feed = dict()  # other placeholders
    self.out = graph.get_tensor_by_name('output:0')
    self.setup_meta_ops()
Example #3
Source File: test_model.py From models with MIT License | 6 votes |
def network_surgery():
    """Re-import the saved meta graph with remapped inputs and re-export it.

    Resets the default graph and creates replacement placeholders. It then
    imports ``model_files/model.tf.meta`` with ``input_map`` pointing the
    saved inputs at those placeholders, and writes the resulting graph as a
    text meta graph to ``model_files/model.tf-modified.meta``.

    Returns:
        The first 15 operations of the rewritten default graph. The original
        computed this slice as a bare expression and discarded it.
    """
    tf.reset_default_graph()
    inputs = tf.placeholder(tf.float32, shape=(None, 131072, 4), name='inputs')
    targets = tf.placeholder(tf.float32, shape=(None, 1024, 4229), name='targets')
    targets_na = tf.placeholder(tf.bool, shape=(None, 1024), name="targets_na")
    preds_adhoc = tf.placeholder(tf.float32, shape=(None, 960, 4229),
                                 name="Placeholder_15")
    tf.train.import_meta_graph(
        "model_files/model.tf.meta",
        input_map={'Placeholder_15:0': preds_adhoc,
                   'Placeholder:0': targets_na,
                   'inputs:0': inputs,
                   'targets:0': targets})
    ops = tf.get_default_graph().get_operations()
    tf.train.export_meta_graph(filename='model_files/model.tf-modified.meta',
                               as_text=True)
    return ops[:15]
Example #4
Source File: graph_rewriter_builder.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def build(graph_rewriter_config, is_training):
    """Return a function that quantizes the default graph.

    Args:
      graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
      is_training: whether the graph is built for training or for eval.

    Returns:
      A zero-argument function. Calling it raises ValueError unless both
      weight and activation bit widths are 8. Otherwise it inserts quantize
      ops into the default graph and summarizes the 'quant_vars' collection.
    """

    def graph_rewrite_fn():
        """Quantize weights and activations of the default graph."""
        quant = graph_rewriter_config.quantization
        if not (quant.weight_bits == 8 and quant.activation_bits == 8):
            raise ValueError('Only 8bit quantization is supported')

        graph = tf.get_default_graph()
        if not is_training:
            tf.contrib.quantize.create_eval_graph(input_graph=graph)
        else:
            tf.contrib.quantize.create_training_graph(
                input_graph=graph, quant_delay=quant.delay)
        tf.contrib.layers.summarize_collection('quant_vars')

    return graph_rewrite_fn
Example #5
Source File: build.py From Traffic-Signs-and-Object-Detection with GNU General Public License v3.0 | 6 votes |
def build_from_pb(self):
    """Load a frozen GraphDef plus its JSON metadata and bind I/O tensors.

    Reads the serialized ``GraphDef`` at ``self.FLAGS.pbLoad`` and imports it
    into the default graph without a name prefix. Loads ``self.meta`` from
    the JSON file ``self.FLAGS.metaLoad`` and builds ``self.framework`` from
    it. Then sets ``self.inp`` / ``self.out`` to the tensors ``input:0`` /
    ``output:0`` and calls ``self.setup_meta_ops()``.
    """
    # tf.gfile.GFile supersedes the deprecated tf.gfile.FastGFile.
    with tf.gfile.GFile(self.FLAGS.pbLoad, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # name="" keeps the original op names so 'input:0'/'output:0' resolve.
    tf.import_graph_def(graph_def, name="")

    with open(self.FLAGS.metaLoad, 'r') as fp:
        self.meta = json.load(fp)
    self.framework = create_framework(self.meta, self.FLAGS)

    graph = tf.get_default_graph()
    # Placeholders
    self.inp = graph.get_tensor_by_name('input:0')
    self.feed = dict()  # other placeholders
    self.out = graph.get_tensor_by_name('output:0')
    self.setup_meta_ops()
Example #6
Source File: tfutil.py From disentangling_conditional_gans with MIT License | 6 votes |
def init_uninited_vars(vars=None):
    """Run the initializer of every variable in ``vars`` not yet initialized.

    ``vars`` defaults to ``tf.global_variables()``. A variable is treated as
    possibly uninitialized when no ``<name>/IsVariableInitialized:0`` tensor
    exists in the default graph yet. For each such variable, that check op is
    created under the variable's absolute name scope. All new checks are
    evaluated together, and the initializers of variables reporting False
    are then run.
    """
    if vars is None:
        vars = tf.global_variables()
    candidates = []
    checks = []
    with tf.control_dependencies(None):  # ignore surrounding control_dependencies
        lookup = tf.get_default_graph().get_tensor_by_name
        for v in vars:
            assert is_tf_expression(v)
            probe_name = v.name.replace(':0', '/IsVariableInitialized:0')
            try:
                lookup(probe_name)
            except KeyError:
                # Op does not exist => variable may be uninitialized.
                candidates.append(v)
                with absolute_name_scope(v.name.split(':')[0]):
                    checks.append(tf.is_variable_initialized(v))
    flags = run(checks)
    pending = [v for v, is_ready in zip(candidates, flags) if not is_ready]
    run([v.initializer for v in pending])

#----------------------------------------------------------------------------
# Set the values of given tf.Variables.
# Equivalent to the following, but more efficient and does not bloat the tf graph:
# tfutil.run([tf.assign(var, value) for var, value in var_to_value_dict.items()]
Example #7
Source File: train.py From TFFRCNN with MIT License | 6 votes |
def __init__(self, sess, network, imdb, roidb, output_dir, logdir, pretrained_model=None):
    """Initialize the SolverWrapper.

    Stores the network, dataset, output directory and optional pretrained
    model path. When ``cfg.TRAIN.BBOX_REG`` is set, it computes the
    bounding-box regression target means/stds from ``roidb``. It also
    creates a checkpoint saver (keeping up to 100 checkpoints) and a summary
    writer for ``logdir``. ``sess`` is accepted but not stored here.
    """
    self.net = network
    self.imdb = imdb
    self.roidb = roidb
    self.output_dir = output_dir
    self.pretrained_model = pretrained_model

    # print() with a single string argument behaves the same on Python 2 and
    # Python 3; the former `print '...'` statements were Python-2-only syntax.
    print('Computing bounding-box regression targets...')
    if cfg.TRAIN.BBOX_REG:
        self.bbox_means, self.bbox_stds = rdl_roidb.add_bbox_regression_targets(roidb)
    print('done')

    # For checkpoint
    self.saver = tf.train.Saver(max_to_keep=100)
    self.writer = tf.summary.FileWriter(logdir=logdir, graph=tf.get_default_graph(), flush_secs=5)
Example #8
Source File: graph_rewriter_builder_test.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
    """Training mode must pass the default graph and the delay to
    create_training_graph and summarize the 'quant_vars' collection."""
    with mock.patch.object(tf.contrib.quantize, 'create_training_graph') as mock_quant_fn, \
            mock.patch.object(tf.contrib.layers, 'summarize_collection') as mock_summarize_col:
        config = graph_rewriter_pb2.GraphRewriter()
        quantization = config.quantization
        quantization.delay = 10
        quantization.weight_bits = 8
        quantization.activation_bits = 8

        rewrite = graph_rewriter_builder.build(config, is_training=True)
        rewrite()

        # call_args is an (args, kwargs) pair; only the keywords are checked.
        call_kwargs = mock_quant_fn.call_args[1]
        self.assertEqual(call_kwargs['input_graph'], tf.get_default_graph())
        self.assertEqual(call_kwargs['quant_delay'], 10)
        mock_summarize_col.assert_called_with('quant_vars')
Example #9
Source File: ssd_mobilenet_v2_feature_extractor_test.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def test_has_fused_batchnorm(self, use_keras):
    """The SSD MobileNet v2 extractor must build a fused batch-norm op.

    TensorFlow releases differ in which fused op type they emit
    (``FusedBatchNorm``, ``FusedBatchNormV2`` or ``FusedBatchNormV3``). Any
    of them is accepted; checking only the first fails on newer versions.
    """
    image_height = 40
    image_width = 40
    depth_multiplier = 1
    pad_to_multiple = 1
    image_placeholder = tf.placeholder(tf.float32, [1, image_height, image_width, 3])
    feature_extractor = self._create_feature_extractor(depth_multiplier, pad_to_multiple,
                                                       use_keras=use_keras)
    preprocessed_image = feature_extractor.preprocess(image_placeholder)
    if use_keras:
        _ = feature_extractor(preprocessed_image)
    else:
        _ = feature_extractor.extract_features(preprocessed_image)
    fused_op_types = ('FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3')
    self.assertTrue(any(op.type in fused_op_types
                        for op in tf.get_default_graph().get_operations()))
Example #10
Source File: evaluate.py From DOTA_models with Apache License 2.0 | 6 votes |
def main(_):
    """Build the evaluation graph, then evaluate it once or periodically.

    After each evaluation the loop stops if ``FLAGS.run_once`` is set;
    otherwise it sleeps ``FLAGS.eval_interval_secs`` seconds and evaluates
    again.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.gfile.MakeDirs(FLAGS.eval_dir)

    tf.logging.info('Building eval graph...')
    eval_ops, moving_averaged_variables = graphs.get_model().eval_graph(FLAGS.eval_data)
    saver = tf.train.Saver(moving_averaged_variables)
    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, graph=tf.get_default_graph())

    run_eval(eval_ops, summary_writer, saver)
    while not FLAGS.run_once:
        time.sleep(FLAGS.eval_interval_secs)
        run_eval(eval_ops, summary_writer, saver)
Example #11
Source File: count_weights.py From soccer-matlab with BSD 2-Clause "Simplified" License | 6 votes |
def count_weights(scope=None, exclude=None, graph=None):
    """Count learnable parameters.

    Args:
      scope: Restrict the count to a variable scope; a trailing '/' is added
        when missing.
      exclude: Regex for variable names to exclude. It is applied with
        ``re.match``, i.e. anchored at the start of the name.
      graph: Operate on a graph other than the current default graph.

    Returns:
      Number of learnable parameters as integer.
    """
    if scope and not scope.endswith('/'):
        scope += '/'
    graph = graph or tf.get_default_graph()
    variables = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    if scope:
        variables = [v for v in variables if v.name.startswith(scope)]
    if exclude:
        pattern = re.compile(exclude)
        variables = [v for v in variables if not pattern.match(v.name)]
    total = 0
    for v in variables:
        total += np.prod(v.get_shape().as_list())
    return int(total)
Example #12
Source File: model.py From Voice_Converter_CycleGAN with MIT License | 6 votes |
def __init__(self, num_features, discriminator = discriminator, generator = generator_gatedcnn, mode = 'train', log_dir = './log'):
    """Build the model graph, open a session and initialize all variables.

    Args:
        num_features: Size of the middle input dimension; inputs are declared
            with shape [None, num_features, None].
        discriminator: Stored as ``self.discriminator`` (defaults to the
            module-level ``discriminator``); presumably consumed by
            ``build_model`` -- verify there.
        generator: Stored as ``self.generator`` (defaults to the module-level
            ``generator_gatedcnn``); presumably consumed by ``build_model``.
        mode: When 'train', a timestamped log directory, a summary writer and
            the summaries from ``self.summary()`` are created.
        log_dir: Parent directory of the timestamped log directory.
    """
    self.num_features = num_features
    self.input_shape = [None, num_features, None] # [batch_size, num_features, num_frames]
    self.discriminator = discriminator
    self.generator = generator
    self.mode = mode

    # Graph construction must precede Saver creation and variable init below.
    self.build_model()
    self.optimizer_initializer()

    self.saver = tf.train.Saver()
    self.sess = tf.Session()
    self.sess.run(tf.global_variables_initializer())

    if self.mode == 'train':
        self.train_step = 0
        now = datetime.now()
        # One log subdirectory per run, named by start time.
        self.log_dir = os.path.join(log_dir, now.strftime('%Y%m%d-%H%M%S'))
        self.writer = tf.summary.FileWriter(self.log_dir, tf.get_default_graph())
        self.generator_summaries, self.discriminator_summaries = self.summary()
Example #13
Source File: nn_model.py From mercari-price-suggestion with MIT License | 6 votes |
def __init__(self, train_df, word_count, batch_size, epochs):
    """Configure the Keras TensorFlow session and record input size limits.

    Args:
        train_df: Training data (presumably a pandas DataFrame). Its
            ``brand_name``, ``item_condition_id`` and ``subcat_0``..``subcat_2``
            columns determine the ``max_*`` bounds (column max + 1).
        word_count: ``max_text`` is set to ``word_count + 1``.
        batch_size: Stored as ``self.batch_size``.
        epochs: Stored as ``self.epochs``.
    """
    tf.set_random_seed(4)
    config = tf.ConfigProto(intra_op_parallelism_threads=2,
                            inter_op_parallelism_threads=8)
    session = tf.Session(graph=tf.get_default_graph(), config=config)
    backend.set_session(session)

    self.batch_size, self.epochs = batch_size, epochs
    self.max_name_seq = 10
    self.max_item_desc_seq = 75
    self.max_text = word_count + 1

    def _column_bound(column):
        # One past the largest value found in the given column.
        return np.max(getattr(train_df, column).max()) + 1

    self.max_brand = _column_bound('brand_name')
    self.max_condition = _column_bound('item_condition_id')
    self.max_subcat0 = _column_bound('subcat_0')
    self.max_subcat1 = _column_bound('subcat_1')
    self.max_subcat2 = _column_bound('subcat_2')
Example #14
Source File: freeze_model.py From deep_sort with GNU General Public License v3.0 | 6 votes |
def main():
    """Freeze the feature network into a serialized GraphDef file.

    Builds the network in a fresh graph on a uint8 placeholder named
    ``images`` of shape (None, 128, 64, 3). Restores weights from
    ``args.checkpoint_in`` and converts variables to constants, keeping the
    ``features`` output. Writes the result to ``args.graphdef_out``.
    """
    args = parse_args()
    with tf.Session(graph=tf.Graph()) as session:
        images = tf.placeholder(tf.uint8, (None, 128, 64, 3), name="images")
        as_float = tf.cast(images, tf.float32)
        preprocessed = tf.map_fn(lambda x: _preprocess(x), as_float, back_prop=False)

        build_network = _network_factory()
        features, _ = build_network(preprocessed, reuse=None)
        features = tf.identity(features, name="features")

        restorer = tf.train.Saver(slim.get_variables_to_restore())
        restorer.restore(session, args.checkpoint_in)

        graph_def = tf.get_default_graph().as_graph_def()
        output_node = features.name.split(":")[0]
        frozen = tf.graph_util.convert_variables_to_constants(
            session, graph_def, [output_node])

        with tf.gfile.GFile(args.graphdef_out, "wb") as out_file:
            out_file.write(frozen.SerializeToString())
Example #15
Source File: generate_detections.py From deep_sort with GNU General Public License v3.0 | 6 votes |
def __init__(self, checkpoint_filename, input_name="images",
             output_name="features"):
    """Load a frozen GraphDef and locate its input and output tensors.

    The graph in ``checkpoint_filename`` is imported into the default graph
    under the ``net`` prefix. The input must be rank 4 and the output rank 2
    (checked with asserts). ``feature_dim`` is the output's last dimension
    and ``image_shape`` the input shape without the batch dimension.
    """
    self.session = tf.Session()
    with tf.gfile.GFile(checkpoint_filename, "rb") as handle:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(handle.read())
    tf.import_graph_def(graph_def, name="net")

    graph = tf.get_default_graph()
    self.input_var = graph.get_tensor_by_name("net/%s:0" % input_name)
    self.output_var = graph.get_tensor_by_name("net/%s:0" % output_name)

    output_shape = self.output_var.get_shape()
    input_shape = self.input_var.get_shape()
    assert len(output_shape) == 2
    assert len(input_shape) == 4
    self.feature_dim = output_shape.as_list()[-1]
    self.image_shape = input_shape.as_list()[1:]
Example #16
Source File: build.py From Automatic-Identification-and-Counting-of-Blood-Cells with GNU General Public License v3.0 | 6 votes |
def build_from_pb(self):
    """Load a frozen GraphDef plus its JSON metadata and bind I/O tensors.

    Reads the serialized ``GraphDef`` at ``self.FLAGS.pbLoad`` and imports it
    into the default graph without a name prefix. Loads ``self.meta`` from
    the JSON file ``self.FLAGS.metaLoad`` and builds ``self.framework`` from
    it. Then sets ``self.inp`` / ``self.out`` to the tensors ``input:0`` /
    ``output:0`` and calls ``self.setup_meta_ops()``.
    """
    # tf.gfile.GFile supersedes the deprecated tf.gfile.FastGFile.
    with tf.gfile.GFile(self.FLAGS.pbLoad, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # name="" keeps the original op names so 'input:0'/'output:0' resolve.
    tf.import_graph_def(graph_def, name="")

    with open(self.FLAGS.metaLoad, 'r') as fp:
        self.meta = json.load(fp)
    self.framework = create_framework(self.meta, self.FLAGS)

    graph = tf.get_default_graph()
    # Placeholders
    self.inp = graph.get_tensor_by_name('input:0')
    self.feed = dict()  # other placeholders
    self.out = graph.get_tensor_by_name('output:0')
    self.setup_meta_ops()
Example #17
Source File: count_weights.py From soccer-matlab with BSD 2-Clause "Simplified" License | 6 votes |
def count_weights(scope=None, exclude=None, graph=None):
    """Count learnable parameters.

    Args:
      scope: Restrict the count to a variable scope (a trailing '/' is
        appended when missing).
      exclude: Regex for variable names to exclude. It is matched with
        ``re.match``, i.e. from the start of the name.
      graph: Operate on a graph other than the current default graph.

    Returns:
      Number of learnable parameters as integer.
    """
    prefix = None
    if scope:
        prefix = scope if scope.endswith('/') else scope + '/'
    graph = graph or tf.get_default_graph()
    trainable = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    pattern = re.compile(exclude) if exclude else None
    selected = [
        var for var in trainable
        if (prefix is None or var.name.startswith(prefix))
        and (pattern is None or not pattern.match(var.name))
    ]
    return int(sum(np.prod(var.get_shape().as_list()) for var in selected))
Example #18
Source File: shape_utils_test.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def test_with_dynamic_shape(self):
    """With an unknown leading dimension, static_or_dynamic_map_fn must build
    ops whose names start with 'map' and compute per-row sums at run time."""
    def fn(input_tensor):
        return tf.reduce_sum(input_tensor)

    input_tensor = tf.placeholder(tf.float32, shape=(None, 2))
    map_fn_output = shape_utils.static_or_dynamic_map_fn(fn, input_tensor)

    graph_ops = tf.get_default_graph().get_operations()
    self.assertTrue(any(op.name[:3] == 'map' for op in graph_ops))

    with self.test_session() as sess:
        first = sess.run(map_fn_output,
                         feed_dict={input_tensor: [[1, 2], [3, 1], [0, 4]]})
        second = sess.run(map_fn_output,
                          feed_dict={input_tensor: [[-1, 1], [0, 9]]})
        self.assertAllEqual(first, [3, 4, 4])
        self.assertAllEqual(second, [0, 9])
Example #19
Source File: run_RingNet.py From RingNet with MIT License | 5 votes |
def __init__(self, config, sess=None):
    """Create the image placeholder and session, then import the model graph.

    Args:
        config: Provides ``load_path`` (checkpoint prefix), ``batch_size``,
            ``img_size`` and ``data_format``.
        sess: Optional existing session; a new ``tf.Session`` is created
            when omitted.

    Raises:
        Exception: if ``config.load_path`` is empty.
        IOError: if ``config.load_path + '.index'`` does not exist.
    """
    self.config = config
    self.load_path = config.load_path
    if not config.load_path:
        raise Exception(
            "provide a pretrained model path"
        )
    if not exists(config.load_path + '.index'):
        # Previously this dropped into an ipdb debugger, which halts
        # non-interactive runs and requires ipdb to be installed.
        raise IOError('%s couldnt find..' % config.load_path)

    # Data
    self.batch_size = config.batch_size
    self.img_size = config.img_size
    self.data_format = config.data_format
    input_size = (self.batch_size, self.img_size, self.img_size, 3)
    self.images_pl = tf.placeholder(tf.float32, shape=input_size, name='input_images')

    if sess is None:
        self.sess = tf.Session()
    else:
        self.sess = sess

    # Load graph.
    self.saver = tf.train.import_meta_graph(self.load_path + '.meta')
    self.graph = tf.get_default_graph()
    self.prepare()
Example #20
Source File: graph_rewriter_builder_test.py From vehicle_counting_tensorflow with MIT License | 5 votes |
def testQuantizationBuilderSetsUpCorrectEvalArguments(self):
    """Eval mode must pass the default graph to create_eval_graph and
    summarize the 'quant_vars' collection.

    Bit widths are not set here, so this relies on the proto defaults
    passing the builder's 8-bit check (presumably both default to 8).
    """
    with mock.patch.object(tf.contrib.quantize, 'create_eval_graph') as mock_quant_fn, \
            mock.patch.object(tf.contrib.layers, 'summarize_collection') as mock_summarize_col:
        config = graph_rewriter_pb2.GraphRewriter()
        config.quantization.delay = 10

        rewrite = graph_rewriter_builder.build(config, is_training=False)
        rewrite()

        # call_args is an (args, kwargs) pair; only the keywords are checked.
        call_kwargs = mock_quant_fn.call_args[1]
        self.assertEqual(call_kwargs['input_graph'], tf.get_default_graph())
        mock_summarize_col.assert_called_with('quant_vars')
Example #21
Source File: generators.py From python-esppy with Apache License 2.0 | 5 votes |
def gen_wrap_str(self):
    """Return Python source defining ``tf_score``, a TF checkpoint scorer.

    The generated module sets ``sess = None`` and defines
    ``tf_score(<input_name>)``. On first call, that function imports the meta
    graph ``self.file``, restores the latest checkpoint from the file's
    directory, and looks up the tensors ``<input_op>:0`` / ``<score_op>:0``.
    It returns the first score as a list (ndarray) or Python scalar.

    Fixes over the previous version:
    * the template keeps its line structure, so the generated code is valid
      Python;
    * a bare file name now yields checkpoint directory './' instead of '/'
      (the filesystem root);
    * paths are embedded with ``repr`` so backslashes (Windows paths) and
      quotes cannot corrupt the generated string literals.
    """
    dir_path = (ntpath.dirname(self.file) or '.') + '/'
    wrap_str = '''
sess = None
def tf_score({4}):
    "Output: {5}"
    import tensorflow as tf
    import numpy as np

    global sess
    global score_op
    global input_op
    #If it is called for the first time, restore the model and necessary operations
    if sess is None:
        sess=tf.Session()
        #load meta graph and restore weights
        saver = tf.train.import_meta_graph({0})
        saver.restore(sess,tf.train.latest_checkpoint({1}))

        graph = tf.get_default_graph()
        #restore the ops. Both ops were pre-defined in the model.
        input_op = graph.get_tensor_by_name("{2}:0") #op to feed input data
        score_op = graph.get_tensor_by_name("{3}:0") #op to score the input

    #Note that the feed value of x has shape (?,xyz), NOT (,xyz)
    {4}_wrap = np.array([{4}])
    {5} = sess.run(score_op, feed_dict={{input_op: {4}_wrap}})[0]

    if isinstance({5}, np.ndarray):
        {5} = {5}.tolist()
    else:
        {5} = {5}.item()

    return {5}'''.format(repr(str(self.file)), repr(dir_path), self.input_op,
                         self.score_op, self.input_name, self.output_name)
    return wrap_str
Example #22
Source File: generators.py From python-esppy with Apache License 2.0 | 5 votes |
def gen_wrap_str(self):
    """Return Python source defining ``ks_score``, a Keras model scorer.

    The generated module sets ``model = None`` and defines
    ``ks_score(<input_name>)``. On first call, that function loads the Keras
    model from ``self.file`` and captures the default TF graph. It then
    calls ``predict_classes`` (if ``self.output_class``) or ``predict`` on
    the wrapped input and returns the first result as a list or scalar.

    Fixes over the previous version: the template keeps its line structure,
    so the generated code is valid Python. The model path is embedded with
    ``repr`` so backslashes or quotes cannot break the generated literal.
    """
    predict = 'predict_classes' if self.output_class else 'predict'
    wrap_str = '''
model = None
def ks_score({0}):
    "Output: {1}"
    from keras.models import load_model
    import tensorflow as tf
    import numpy as np

    global model
    global graph

    # If it is called for the first time, restore the model
    if model is None:
        model = load_model({2})
        model._make_predict_function()
        graph = tf.get_default_graph()

    # make prediction
    {0}_wrap = np.array([{0}])
    with graph.as_default():
        {1} = model.{3}({0}_wrap)[0]

    if isinstance({1}, np.ndarray):
        {1} = {1}.tolist()
    else:
        {1} = {1}.item()

    return {1}'''.format(self.input_name, self.output_name,
                         repr(str(self.file)), predict)
    return wrap_str
Example #23
Source File: ssd_mobilenet_v1_fpn_feature_extractor_test.py From vehicle_counting_tensorflow with MIT License | 5 votes |
def test_fused_batchnorm(self):
    """The SSD MobileNet v1 FPN extractor must build a fused batch-norm op.

    TensorFlow releases differ in which fused op type they emit
    (``FusedBatchNorm``, ``FusedBatchNormV2`` or ``FusedBatchNormV3``). Any
    of them is accepted; checking only the first fails on newer versions.
    """
    image_height = 256
    image_width = 256
    depth_multiplier = 1
    pad_to_multiple = 1
    image_placeholder = tf.placeholder(tf.float32, [1, image_height, image_width, 3])
    feature_extractor = self._create_feature_extractor(depth_multiplier, pad_to_multiple)
    preprocessed_image = feature_extractor.preprocess(image_placeholder)
    _ = feature_extractor.extract_features(preprocessed_image)
    fused_op_types = ('FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3')
    self.assertTrue(
        any(op.type in fused_op_types
            for op in tf.get_default_graph().get_operations()))
Example #24
Source File: model_base.py From tf-recsys with MIT License | 5 votes |
def load_model(self, model_dir):
    """Loads Tensorflow model.

    Imports ``<model_dir>/<ClassName>.meta``, restores its weights into
    ``self._sess``, and re-binds ``_users``, ``_items``, ``_ratings`` and
    ``_pred`` (tensors) and ``_optimizer`` (operation) from the default
    graph. Marks the model as built.

    Args:
        model_dir: A string, the path of saving directory
    """
    tensor_names = ['placeholder/users:0', 'placeholder/items:0',
                    'placeholder/ratings:0', 'prediction/pred:0']
    operation_names = ['optimizer/optimizer']

    model_path = os.path.join(model_dir, type(self).__name__)
    self._saver = tf.train.import_meta_graph(model_path + '.meta')
    self._saver.restore(self._sess, model_path)

    def _attr_name(full_name):
        # 'scope/leaf:0' -> '_leaf'
        return '_' + full_name.split('/')[1].split(':')[0]

    graph = tf.get_default_graph()
    for tensor_name in tensor_names:
        setattr(self, _attr_name(tensor_name), graph.get_tensor_by_name(tensor_name))
    for op_name in operation_names:
        setattr(self, _attr_name(op_name), graph.get_operation_by_name(op_name))

    self._built = True
Example #25
Source File: ml_util.py From sparkflow with MIT License | 5 votes |
def predict_func(rows, graph_json, prediction, graph_weights, inp, activation, tf_input, tf_dropout=None, to_keep_dropout=False):
    """Score a partition of rows with a serialized TensorFlow graph.

    Args:
        rows: Iterable of rows supporting ``asDict()`` (presumably pyspark
            ``Row`` objects).
        graph_json: JSON-serialized ``MetaGraphDef``.
        prediction: Field name under which each row's prediction is stored.
        graph_weights: JSON list of weight arrays, applied via
            ``tensorflow_set_weights``.
        inp: Field of each row holding the model input.
        activation: Name of the output tensor to evaluate.
        tf_input: Feed key for the model input.
        tf_dropout: Optional feed key for a dropout value.
        to_keep_dropout: With ``tf_dropout`` set, feed 1.0 if True, else 0.0.

    Returns:
        A list of ``Row`` objects with ``prediction`` set. The value is a
        float when the prediction converts to one, otherwise a dense vector.
        Returns ``[]`` for empty input.
    """
    rows = [r.asDict() for r in rows]
    if not rows:
        return []

    meta_graph = json_format.Parse(graph_json, tf.MetaGraphDef())
    loaded_weights = [np.asarray(x) for x in json.loads(graph_weights)]
    features = [np.asarray(row[inp]) for row in rows]

    with tf.Session(graph=tf.Graph()) as sess:
        tf.train.import_meta_graph(meta_graph)
        sess.run(tf.global_variables_initializer())
        tensorflow_set_weights(loaded_weights)
        out_node = tf.get_default_graph().get_tensor_by_name(activation)
        feed_dict = {tf_input: features}
        if tf_dropout is not None:
            feed_dict[tf_dropout] = 1.0 if to_keep_dropout else 0.0
        pred = sess.run(out_node, feed_dict=feed_dict)

    for i, row in enumerate(rows):
        try:
            # Vectors Dense are handled differently in python 3
            row[prediction] = float(pred[i])
        except (TypeError, ValueError):
            # Multi-valued outputs cannot be converted to a float. The former
            # bare `except:` also swallowed KeyboardInterrupt/SystemExit.
            row[prediction] = Vectors.dense(pred[i])
    return [Row(**a) for a in rows]
Example #26
Source File: app.py From easy-tensorflow-multimodel-server with MIT License | 5 votes |
def load_model(model_dir, model_prefix):
    """Load a frozen detection graph and its label map.

    Returns:
        dict with 'session' (a session on the loaded graph), 'image_tensor',
        'tensor_dict' (whichever of num_detections / detection_boxes /
        detection_scores / detection_classes exist in the graph) and
        'category_index'.
    """
    label_map_path = '{}/{}{}'.format(model_dir, model_prefix, LABEL_MAP_SUFFIX)
    label_map = label_map_util.load_labelmap(label_map_path)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=90, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)

    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        graph_path = '{}/{}{}'.format(model_dir, model_prefix, MODEL_SUFFIX)
        with tf.gfile.GFile(graph_path, 'rb') as fid:
            od_graph_def.ParseFromString(fid.read())
            tf.import_graph_def(od_graph_def, name='')

        # Get handles to input and output tensors
        graph = tf.get_default_graph()
        all_tensor_names = {output.name
                            for op in graph.get_operations()
                            for output in op.outputs}
        wanted = ('num_detections', 'detection_boxes',
                  'detection_scores', 'detection_classes')
        tensor_dict = {key: graph.get_tensor_by_name(key + ':0')
                       for key in wanted
                       if key + ':0' in all_tensor_names}
        image_tensor = graph.get_tensor_by_name('image_tensor:0')

    sess = tf.Session(graph=detection_graph)
    return {
        'session': sess,
        'image_tensor': image_tensor,
        'tensor_dict': tensor_dict,
        'category_index': category_index,
    }
Example #27
Source File: model.py From DeepLab_v3 with MIT License | 5 votes |
def __init__(self, base_architecture, training=True, num_classes=21, ignore_label=255, batch_norm_momentum=0.9997, pre_trained_model=None, log_dir='data/logs/deeplab/'):
    """Build the DeepLab graph, open a session and initialize all variables.

    Args:
        base_architecture: Passed to ``self.backbone_initializer``.
        training: Stored as ``self.training``; when true, a timestamped log
            directory, a summary writer and the train/valid summaries are
            created.
        num_classes: Stored as ``self.num_classes``.
        ignore_label: Stored as ``self.ignore_label``.
        batch_norm_momentum: Stored as ``self.batch_norm_momentum``.
        pre_trained_model: If truthy, passed to
            ``self.initialize_backbone_from_pretrained_weights``.
        log_dir: Parent directory of the timestamped log directory.
    """
    self.is_training = tf.placeholder(tf.bool, None, name='is_training')
    self.num_classes = num_classes
    self.ignore_label = ignore_label
    # Images are fed as [batch, height, width, 3] and labels as
    # [batch, height, width, 1], with all but the channel dim unknown.
    self.inputs_shape = [None, None, None, 3]
    self.labels_shape = [None, None, None, 1]
    self.training = training
    self.inputs = tf.placeholder(tf.float32, shape=self.inputs_shape, name='inputs')
    self.labels = tf.placeholder(tf.uint8, shape=self.labels_shape, name='labels')

    self.target_height = tf.placeholder(tf.int32, None, name='target_image_height')
    self.target_width = tf.placeholder(tf.int32, None, name='target_image_width')

    # The L2 regularization scale is fed at run time through a placeholder.
    self.weight_decay = tf.placeholder(tf.float32, None, name='weight_decay')
    self.regularizer = tf.contrib.layers.l2_regularizer(scale=self.weight_decay)
    self.batch_norm_momentum = batch_norm_momentum

    self.feature_map = self.backbone_initializer(base_architecture)
    if pre_trained_model:
        # NOTE(review): this runs before global_variables_initializer() below.
        # Pretrained weights survive only if this helper changes the
        # variables' initializers (e.g. tf.train.init_from_checkpoint) --
        # confirm in its implementation.
        self.initialize_backbone_from_pretrained_weights(pre_trained_model)
    self.outputs = self.model_initializer()

    self.learning_rate = tf.placeholder(tf.float32, None, name='learning_rate')
    self.loss = self.loss_initializer()
    self.optimizer = self.optimizer_initializer()

    # Initialize tensorflow session
    self.saver = tf.train.Saver()
    self.sess = tf.Session()
    self.sess.run(tf.global_variables_initializer())

    if self.training:
        self.train_step = 0
        now = datetime.now()
        # One log subdirectory per run, named by start time.
        self.log_dir = os.path.join(log_dir, now.strftime('%Y%m%d-%H%M%S'))
        self.writer = tf.summary.FileWriter(self.log_dir, tf.get_default_graph())
        self.train_summaries, self.valid_summaries = self.summary()
Example #28
Source File: mobilenet_v2_test.py From DeepLab_v3 with MIT License | 5 votes |
def find_ops(optype, graph=None):
    """Find ops of a given type in a graph.

    Args:
      optype: operation type (e.g. Conv2D)
      graph: Graph to search. Defaults to the current default graph, which
        was previously the only option despite the docstring mentioning
        other graphs.

    Returns:
      List of operations whose ``type`` equals ``optype``, in graph order.
    """
    if graph is None:
        graph = tf.get_default_graph()
    return [op for op in graph.get_operations() if op.type == optype]
Example #29
Source File: ssd_mobilenet_v1_ppn_feature_extractor_test.py From vehicle_counting_tensorflow with MIT License | 5 votes |
def test_has_fused_batchnorm(self):
    """The SSD MobileNet v1 PPN extractor must build a fused batch-norm op.

    TensorFlow releases differ in which fused op type they emit
    (``FusedBatchNorm``, ``FusedBatchNormV2`` or ``FusedBatchNormV3``). Any
    of them is accepted; checking only the first fails on newer versions.
    """
    image_height = 320
    image_width = 320
    depth_multiplier = 1
    pad_to_multiple = 1
    image_placeholder = tf.placeholder(tf.float32, [1, image_height, image_width, 3])
    feature_extractor = self._create_feature_extractor(depth_multiplier, pad_to_multiple)
    preprocessed_image = feature_extractor.preprocess(image_placeholder)
    _ = feature_extractor.extract_features(preprocessed_image)
    fused_op_types = ('FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3')
    self.assertTrue(any(op.type in fused_op_types
                        for op in tf.get_default_graph().get_operations()))
Example #30
Source File: tf_to_uff.py From iAI with MIT License | 5 votes |
def getFrozenModel(model_path, output_node_names=None):
    """Restore a checkpoint and return a frozen, inference-only GraphDef.

    Args:
        model_path: Checkpoint prefix; ``model_path + '.meta'`` is imported
            and the weights are restored from ``model_path``.
        output_node_names: Names of the nodes to keep when freezing. Defaults
            to the module-level ``frozen_node_name``, which was previously
            the only option.

    Returns:
        The frozen GraphDef with training-only nodes removed.
    """
    if output_node_names is None:
        output_node_names = frozen_node_name
    with tf.Session() as sess:
        # NOTE(review): this runs before the meta graph is imported, so it
        # only initializes variables already in the default graph; the model
        # weights come from saver.restore below.
        sess.run(tf.global_variables_initializer())
        saver = tf.train.import_meta_graph(model_path + '.meta')
        saver.restore(sess, model_path)
        graph_def = tf.get_default_graph().as_graph_def()
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            sess, graph_def, output_node_names)
        return tf.graph_util.remove_training_nodes(frozen_graph)