Python tensorflow.VERSION Examples
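The following are 30
code examples of `tensorflow.VERSION`, the string constant holding the installed TensorFlow release (an attribute, not a function). In TensorFlow 2.x it is no longer exported at the top level; use `tf.__version__` or `tf.version.VERSION` there (`tf.compat.v1.VERSION` keeps the old name).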
Example #1
Source File: From incubator-tvm with Apache License 2.0 | 6 votes |
def test_forward_mobilenet_v3():
    """Test the Mobilenet V3 TF Lite model.

    Runs the float model on one random NHWC input through both the TFLite
    interpreter and TVM, and checks that the first outputs agree.
    """
    # In MobilenetV3, some ops are not supported before tf 1.15 fbs schema
    if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'):
        return
    # NOTE(review): the download URL appears to have been stripped from this
    # snippet (empty string); restore it from the upstream test file.
    tflite_model_file = tf_testing.get_workload_official(
        "",
        "v3-large_224_1.0_float/v3-large_224_1.0_float.tflite")
    with open(tflite_model_file, "rb") as f:
        # The serialized flatbuffer is the model; it must not be aliased to the
        # random input tensor built below.
        tflite_model_buf = f.read()
    data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
    tflite_output = run_tflite_graph(tflite_model_buf, data)
    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                rtol=1e-5, atol=1e-5)

#######################################################################
# Inception
# ---------
Example #2
Source File: From incubator-tvm with Apache License 2.0 | 6 votes |
def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape):
    """Build a ``tf.sparse_to_dense`` graph and compare TFLite against TVM.

    When ``default_value`` is None, the op is built without a default-value
    input. Otherwise the value is fed through an extra scalar placeholder.
    """
    # tflite 1.13 convert method does not accept empty shapes
    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
        with tf.Graph().as_default():
            indices = tf.placeholder(shape=sparse_indices.shape,
                                     dtype=str(sparse_indices.dtype), name="indices")
            values = tf.placeholder(shape=sparse_values.shape,
                                    dtype=str(sparse_values.dtype), name="values")
            oshape = tf.constant(output_shape, shape=output_shape.shape,
                                 dtype=str(output_shape.dtype))

            # Identity test, not `== None`: default_value is array-like (it has
            # a `.dtype`), and for such values `==` is an element-wise comparison.
            if default_value is None:
                output = tf.sparse_to_dense(indices, oshape, values)
                compare_tflite_with_tvm(
                    [sparse_indices, sparse_values],
                    ["indices", "values"],
                    [indices, values],
                    [output]
                )
            else:
                dv = tf.placeholder(shape=(), dtype=str(default_value.dtype),
                                    name="default_value")
                output = tf.sparse_to_dense(indices, oshape, values, dv)
                compare_tflite_with_tvm(
                    [sparse_indices, sparse_values, default_value],
                    ["indices", "values", "default_value"],
                    [indices, values, dv],
                    [output]
                )
Example #3
Source File: From incubator-tvm with Apache License 2.0 | 6 votes |
def test_forward_tflite2_qnn_inception_v1():
    """Test the Quantized TFLite version 2.1.0 Inception V1 model.

    Compares the top-3 predicted labels of TFLite and TVM on the same
    pre-processed image.
    """
    if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
        # NOTE(review): the download URL appears to have been stripped (empty string).
        tflite_model_file = download_testdata(
            "",
            "inception_v1_quantized.tflite")
        with open(tflite_model_file, "rb") as f:
            # The model bytes, kept separate from the input image below.
            tflite_model_buf = f.read()
        data = pre_processed_image(224, 224)

        tflite_output = run_tflite_graph(tflite_model_buf, data)
        tflite_predictions = np.squeeze(tflite_output)
        tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
        tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1')
        tvm_predictions = np.squeeze(tvm_output)
        tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
        tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
Example #4
Source File: From incubator-tvm with Apache License 2.0 | 6 votes |
def _test_range(start, limit, delta):
    """Compare a ``tf.range(start, limit, delta)`` graph between TFLite and TVM (VM mode)."""
    # TFLite 1.13's converter rejects empty (scalar) shapes, so older versions are skipped.
    if not (package_version.parse(tf.VERSION) >= package_version.parse('1.14.0')):
        return

    tf.reset_default_graph()
    with tf.Graph().as_default():
        scalar_inputs = [
            tf.placeholder(dtype=value.dtype, shape=(), name=label)
            for value, label in ((start, "start"), (limit, "limit"), (delta, "delta"))
        ]
        range_out = tf.range(*scalar_inputs, name="range")
        compare_tflite_with_tvm([start, limit, delta], ["start", "limit", "delta"],
                                scalar_inputs, [range_out],
                                mode="vm", quantized=False)
Example #5
Source File: From incubator-tvm with Apache License 2.0 | 6 votes |
def test_forward_tflite2_qnn_mobilenet_v2():
    """Test the Quantized TFLite version 2.1.0 Mobilenet V2 model.

    Compares the top-3 predicted labels of TFLite and TVM on the same
    pre-processed image.
    """
    if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
        # NOTE(review): the download URL appears to have been stripped (empty string).
        tflite_model_file = download_testdata(
            "",
            "mobilenet_v2_quantized.tflite")
        with open(tflite_model_file, "rb") as f:
            # The model bytes, kept separate from the input image below.
            tflite_model_buf = f.read()
        data = pre_processed_image(224, 224)

        tflite_output = run_tflite_graph(tflite_model_buf, data)
        tflite_predictions = np.squeeze(tflite_output)
        tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
        tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1')
        tvm_predictions = np.squeeze(tvm_output)
        tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
        tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

#######################################################################
# Quantized SSD Mobilenet
# -----------------------
Example #6
Source File: From incubator-tvm with Apache License 2.0 | 6 votes |
def _test_range_default():
    """Exercise ``tf.range`` with its default arguments through TFLite and TVM."""
    # TFLite 1.13's converter rejects empty (scalar) shapes.
    if not (package_version.parse(tf.VERSION) >= package_version.parse('1.14.0')):
        return

    tf.reset_default_graph()
    with tf.Graph().as_default():
        p1, p2 = (tf.placeholder(dtype=tf.int32, shape=(), name=label)
                  for label in ("p1", "p2"))
        # First op relies on the default delta; the second passes only `start`,
        # which is used as the limit of a range whose first item is 0.
        default_delta = tf.range(start=p1, limit=p2)
        start_as_limit = tf.range(start=p2)
        compare_tflite_with_tvm([np.int32(1), np.int32(18)], ["p1", "p2"],
                                [p1, p2], [default_delta, start_as_limit],
                                mode="vm")
Example #7
Source File: From incubator-tvm with Apache License 2.0 | 6 votes |
def test_tensor_array_size():
    """Check that a TensorArray ``size()`` converts the same way in TF and TVM."""
    if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
        pytest.skip("Needs fixing for tflite >= 1.15.0")

    def run(dtype_str, infer_shape):
        with tf.Graph().as_default():
            dtype = tf_dtypes[dtype_str]
            np_data = np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str)
            t1 = tf.constant(np_data, dtype=dtype)
            t2 = tf.constant(np_data, dtype=dtype)
            ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape)
            ta2 = ta1.write(0, t1)
            ta3 = ta2.write(1, t2)
            # Called only for its graph side effect: it presumably creates the
            # TensorArraySizeV3 op that compare_tf_with_tvm looks up by name.
            ta3.size()
            compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')

    for dtype in ["float32", "int8"]:
        run(dtype, False)
        run(dtype, True)
Example #8
Source File: From incubator-tvm with Apache License 2.0 | 6 votes |
def test_all_unary_elemwise():
    """Run each unary element-wise case through ``_test_forward_unary_elemwise``."""
    baseline_cases = (_test_abs, _test_floor, _test_exp, _test_log, _test_sin,
                      _test_sqrt, _test_rsqrt, _test_neg, _test_square)
    for unary_case in baseline_cases:
        _test_forward_unary_elemwise(unary_case)

    # ceil and cos come with TFLite 1.14.0.post1 fbs schema
    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
        # _test_tan is deliberately not run: with TF and TFLite 1.15.2 the
        # converter fails, asking for a custom Tan operator implementation.
        for unary_case in (_test_ceil, _test_cos, _test_round, _test_elu):
            _test_forward_unary_elemwise(unary_case)

#######################################################################
# Element-wise
# ------------
Example #9
Source File: From incubator-tvm with Apache License 2.0 | 6 votes |
def test_tensor_array_stack():
    """Check TensorArray ``scatter`` + ``stack()`` conversion in TF vs TVM."""
    # Checked once up front. `run` does no work before this check, so skipping
    # here is equivalent to skipping on its first call.
    if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
        pytest.skip("Needs fixing for tflite >= 1.15.0")

    def run(dtype_str, infer_shape):
        with tf.Graph().as_default():
            dtype = tf_dtypes[dtype_str]
            t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str))
            scatter_indices = tf.constant([2, 1, 0])
            ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=infer_shape)
            ta2 = ta1.scatter(scatter_indices, t)
            # Built for its graph side effect: presumably the
            # TensorArrayStack/TensorArrayGatherV3 op compared below.
            ta2.stack()
            compare_tf_with_tvm([], [], ['TensorArrayStack/TensorArrayGatherV3:0'], mode='vm')

    for dtype in ["float32", "int8"]:
        run(dtype, True)
Example #10
Source File: From incubator-tvm with Apache License 2.0 | 6 votes |
def _test_fill(dims, value_data, value_dtype):
    """Check ``tf.fill`` conversion with a runtime value and with a constant value.

    Args:
        dims: shape of the filled tensor (also the shape of the ``input1`` placeholder).
        value_data: fill value, converted to a NumPy array of ``value_dtype``.
        value_dtype: dtype used for the fill value and the placeholders.
    """
    fill_value = np.array(value_data, dtype=value_dtype)

    # Case 1: the fill value is fed at runtime through a scalar placeholder.
    # The TF 1.13 TFLite converter rejects empty shapes, so this needs TF >= 1.14.
    scalar_shapes_ok = package_version.parse(tf.VERSION) >= package_version.parse('1.14.0')
    if scalar_shapes_ok:
        with tf.Graph().as_default():
            value_ph = array_ops.placeholder(dtype=value_dtype, name="value", shape=[])
            compare_tflite_with_tvm([fill_value], ["value"], [value_ph],
                                    [tf.fill(dims, value_ph)])

    # Case 2: a constant fill value (converted to a static tensor during
    # conversion) is added to a runtime input.
    with tf.Graph().as_default():
        input1 = array_ops.placeholder(dtype=value_dtype, name="input1", shape=dims)
        summed = tf.add(tf.fill(dims, fill_value), input1)
        input1_data = np.random.uniform(0, 5, size=dims).astype(value_dtype)
        compare_tflite_with_tvm([input1_data], ["input1"], [input1], [summed])
Example #11
Source File: From incubator-tvm with Apache License 2.0 | 6 votes |
def test_tensor_array_unstack():
    """Check TensorArray ``unstack()`` followed by ``size()`` and ``read()`` in TF vs TVM."""
    # `run` does no work before its version check, so skipping once here is
    # equivalent to skipping on its first call.
    if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
        pytest.skip("Needs fixing for tflite >= 1.15.0")

    def run(dtype_str, input_shape, infer_shape):
        with tf.Graph().as_default():
            dtype = tf_dtypes[dtype_str]
            t = tf.constant(np.random.choice([0, 1, 2, 3],
                                             size=input_shape).astype(dtype_str))
            ta1 = tf.TensorArray(dtype=dtype, infer_shape=infer_shape, size=input_shape[0])
            ta2 = ta1.unstack(t)
            # Both ops are built only so compare_tf_with_tvm can find them by name.
            # The read op is required: without it there is no TensorArrayReadV3
            # tensor for the second comparison.
            ta2.size()
            ta2.read(0)
            compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
            compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')

    for dtype in ["float32", "int8"]:
        run(dtype, (5,), False)
        run(dtype, (5, 5), True)
        run(dtype, (5, 5, 5), False)
        run(dtype, (5, 5, 5, 5), True)

#######################################################################
# ConcatV2
# --------
Example #12
Source File: From incubator-tvm with Apache License 2.0 | 6 votes |
def test_forward_add_n():
    """Exercise ADD_N with int32 and float32 operands given as one array, a list or a tuple."""
    if not (package_version.parse(tf.VERSION) >= package_version.parse('1.14.0')):
        return

    x, y, z = (np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32) for _ in range(3))
    m, n, o = (arr.astype(np.float32) for arr in (x, y, z))
    for operands in (x, [x, y], (x, y, z), m, [m, n], (m, n, o)):
        _test_forward_add_n(operands)

#######################################################################
# Logical operators
# -----------------
Example #13
Source File: From incubator-tvm with Apache License 2.0 | 5 votes |
def test_all_logical():
    """Run the logical AND / OR / NOT conversion tests on two random boolean tensors."""
    data = [np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool')
            for _ in range(2)]
    # boolean dtype is not supported by older versions than TFLite 1.15.0
    if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
        for logical_case in (_test_forward_logical_and,
                             _test_forward_logical_or,
                             _test_forward_logical_not):
            logical_case(data)

#######################################################################
# Zeros like
# ----------
Example #14
Source File: From incubator-tvm with Apache License 2.0 | 5 votes |
def test_forward_convolution():
    """Run the NHWC convolution conversion cases.

    Every ``_test_convolution`` case runs twice (unquantized, then quantized).
    The TFLite2 quantized cases run once, and only with TF >= 2.1.0.
    """
    for quantized in [False, True]:
        def conv(*args, depthwise=False):
            # All cases use NHWC; depthwise cases add a trailing positional True.
            positional = list(args) + ['NHWC'] + ([True] if depthwise else [])
            _test_convolution(*positional, quantized=quantized)

        conv([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME')
        conv([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID')
        conv([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME')
        conv([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID')
        # depthwise convolution
        conv([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME', depthwise=True)
        conv([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', depthwise=True)
        conv([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', depthwise=True)
        conv([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', depthwise=True)
        conv([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', depthwise=True)
        # depthwise convolution with single input channel
        conv([1, 76, 64, 1], [9, 5, 1, 96], [1, 1], [1, 1], 'SAME', depthwise=True)

    # TFLite2 quantized convolution testing
    if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
        tflite2_cases = (
            ([1, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME'),
            ([1, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID'),
            ([1, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID'),
            ([1, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME'),
        )
        for case in tflite2_cases:
            _test_tflite2_quantized_convolution(*case, 'NHWC')

        # depthwise convolution
        _test_tflite2_quantized_depthwise_convolution(
            [1, 8, 8, 128], [1, 1, 128, 1], [1, 1], [1, 1], 'SAME', 'NHWC', 1)
        _test_tflite2_quantized_depthwise_convolution(
            [1, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', 1)
        _test_tflite2_quantized_depthwise_convolution(
            [1, 24, 24, 3], [7, 7, 3, 8], [1, 1], [2, 2], 'SAME', 'NHWC', 8)

#######################################################################
# Transpose Convolution
# ---------------------
Example #15
Source File: From incubator-tvm with Apache License 2.0 | 5 votes |
def test_all_elemwise():
    """Run every binary element-wise conversion case, in a fixed order."""
    _test_forward_elemwise(_test_add)
    _test_forward_elemwise_quantized(_test_add)
    _test_forward_elemwise(partial(_test_add, fused_activation_function="RELU"))
    # ADD with a fused RELU6 is intentionally not run: after the TF 1.15.2
    # upgrade it hits a segfault that still needs investigation.

    # SUB and MUL: plain, quantized, then with each fused activation.
    for arith_op in (_test_sub, _test_mul):
        _test_forward_elemwise(arith_op)
        _test_forward_elemwise_quantized(arith_op)
        for activation in ("RELU", "RELU6"):
            _test_forward_elemwise(partial(arith_op, fused_activation_function=activation))

    # DIV has no quantized case here.
    _test_forward_elemwise(_test_div)
    for activation in ("RELU", "RELU6"):
        _test_forward_elemwise(partial(_test_div, fused_activation_function=activation))

    for simple_op in (_test_pow, _test_maximum, _test_minimum, _test_greater,
                      _test_squared_difference, _test_greater_equal, _test_less,
                      _test_less_equal, _test_equal, _test_not_equal):
        _test_forward_elemwise(simple_op)

    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
        for floor_op in (_test_floor_divide, _test_floor_mod):
            _test_forward_elemwise(floor_op)

#######################################################################
# AddN
# ----
Example #16
Source File: From g-tensorflow-models with Apache License 2.0 | 5 votes |
def _collect_tensorflow_info(run_info):
    """Record the TensorFlow version and git hash under ``run_info["tensorflow_version"]``.

    Args:
        run_info: mapping that is updated in place.
    """
    # tf.VERSION / tf.GIT_VERSION are TF 1.x top-level aliases that TF 2.x no
    # longer exports. tf.__version__ / tf.__git_version__ hold the same strings
    # in both major versions.
    run_info["tensorflow_version"] = {
        "version": tf.__version__, "git_hash": tf.__git_version__}
Example #17
Source File: From incubator-tvm with Apache License 2.0 | 5 votes |
def test_forward_quantize_dequantize():
    """Run ``_test_quantize_dequantize`` on random float32 data (requires TF >= 2.1.0)."""
    sample = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32")
    tf_is_new_enough = package_version.parse(tf.VERSION) >= package_version.parse('2.1.0')
    if tf_is_new_enough:
        _test_quantize_dequantize(sample)

#######################################################################
# Pad
# ---
Example #18
Source File: From incubator-tvm with Apache License 2.0 | 5 votes |
def test_forward_unpack():
    """Run UNPACK conversion cases; negative-axis cases need TF >= 1.14.0."""
    def unpack_case(shape, dtype, axis, num_unpacks):
        samples = np.random.uniform(0, 5, shape)
        _test_unpack(np.array(samples, dtype=dtype), axis=axis, num_unpacks=num_unpacks)

    unpack_case((3, 1), np.int32, 1, 1)
    unpack_case((3, 4), np.float32, 0, 3)
    # tflite 1.13 doesn't accept negative axis
    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
        unpack_case((3, 6), np.int32, -2, 3)
        unpack_case((2, 3, 4), np.int32, -3, 2)

#######################################################################
# Local response normalization
# ----------------------------
Example #19
Source File: From incubator-tvm with Apache License 2.0 | 5 votes |
def test_forward_local_response_normalization():
    """Run the LOCAL_RESPONSE_NORMALIZATION conversion case (needs TF >= 1.14.0)."""
    sample = np.random.uniform(size=(1, 6, 4, 3)).astype('float32')
    lrn_params = dict(depth_radius=5, bias=1, alpha=1, beta=0.5)
    # LOCAL_RESPONSE_NORMALIZATION come with TFLite >= 1.14.0 fbs schema
    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
        _test_local_response_normalization(sample, **lrn_params)

#######################################################################
# L2 normalization
# ----------------
Example #20
Source File: From incubator-tvm with Apache License 2.0 | 5 votes |
def _test_relu(data, quantized=False):
    """One iteration of ReLU.

    With ``quantized=True``, a Keras ReLU model is quantized (TF >= 2.1.0) and
    compared between TFLite and TVM. Otherwise a plain ``nn_ops.relu`` graph
    is compared.
    """
    if quantized:
        if package_version.parse(tf.VERSION) < package_version.parse('2.1.0'):
            pytest.skip("Testcase requires tflite version >= 2.1.0")
        data_in = tf.keras.layers.Input(shape=data.shape[1:])
        relu = tf.keras.layers.ReLU()(data_in)
        keras_model = tf.keras.models.Model(inputs=data_in, outputs=relu)
        # Graph input name without the ":<output index>" suffix.
        input_name = data_in.name.split(":")[0]

        # To create quantized values with dynamic range of activations, needs representative dataset
        def representative_data_gen():
            yield [data]

        tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)

        tflite_output = run_tflite_graph(tflite_model_quant, data)
        tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
        tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                    rtol=1e-5, atol=1e-5)
    else:
        with tf.Graph().as_default():
            in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
            out = nn_ops.relu(in_data)
            compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
Example #21
Source File: From incubator-tvm with Apache License 2.0 | 5 votes |
def test_forward_leaky_relu():
    """Run LEAKY_RELU conversion cases; the quantized case needs TF >= 1.14.0."""
    float_input = np.random.uniform(-5, 5, (1, 6)).astype(np.float32)
    _test_leaky_relu(float_input, alpha=0.2)

    quantized_supported = package_version.parse(tf.VERSION) >= package_version.parse('1.14.0')
    if quantized_supported:
        uint8_input = np.random.uniform(0, 255, (2, 3)).astype(np.uint8)
        _test_leaky_relu(uint8_input, alpha=0.3, quantized=True)

#######################################################################
# ReLU_n1_to_1
# ------------
Example #22
Source File: From incubator-tvm with Apache License 2.0 | 5 votes |
def test_forward_depthtospace():
    """Run DEPTH_TO_SPACE conversion cases (needs TF >= 1.15.0)."""
    # DEPTH_TO_SPACE comes with TFLite >= 1.15.0 fbs schema
    if not (package_version.parse(tf.VERSION) >= package_version.parse('1.15.0')):
        return
    for shape, factor in (([1, 32, 32, 4], 2), ([1, 16, 8, 32], 4)):
        _test_depthtospace(np.random.normal(size=shape).astype("float32"), factor)

#######################################################################
# SpaceToDepth
# ------------
Example #23
Source File: From incubator-tvm with Apache License 2.0 | 5 votes |
def test_forward_reverse_sequence():
    """Run REVERSE_SEQUENCE conversion cases (needs TF >= 1.14.0)."""
    if not (package_version.parse(tf.VERSION) >= package_version.parse('1.14.0')):
        return
    cases = (
        ([4, 3], "float32", [3, 2, 1], 1, 0),
        ([4, 3], "float32", [3, 2, 1, 3], 0, 1),
        ([2, 3, 3, 3], "float32", [2, 3, 2], 2, 1),
        ([2, 4, 6, 4, 5], "float32", [5, 3], 0, 2),
        ([2, 4, 6, 4, 5], "float32", [5, 3, 1, 4], 3, 2),
    )
    for case in cases:
        _test_reverse_sequence(*case)

#######################################################################
# Sparse To Dense
# ---------------
Example #24
Source File: From incubator-tvm with Apache License 2.0 | 5 votes |
def test_forward_qnn_mobilenet_v3_net():
    """Test the Quantized TFLite Mobilenet V3 model.

    Currently always skipped. The comparison body below is kept so it can be
    re-enabled once the segfault noted below is resolved.
    """
    # In MobilenetV3, some ops are not supported before tf 1.15 fbs schema
    if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'):
        pytest.skip("Unsupported in tflite < 1.15.0")
    # Unconditional: pytest.skip raises, so everything below is unreachable
    # until this line is removed.
    pytest.skip("This segfaults with tensorflow 1.15.2 and above")

    # NOTE(review): the download URL appears to have been stripped (empty string).
    tflite_model_file = tf_testing.get_workload_official(
        "",
        "v3-large_224_1.0_uint8/v3-large_224_1.0_uint8.tflite")
    with open(tflite_model_file, "rb") as f:
        tflite_model_buf = f.read()

    # Test image. Checking the labels because the requantize implementation is different between
    # TFLite and Relay. This cause final output numbers to mismatch. So, testing accuracy via
    # labels. Also, giving a real image, instead of random inputs.
    data = get_real_image(224, 224)

    tflite_output = run_tflite_graph(tflite_model_buf, data)
    tflite_predictions = np.squeeze(tflite_output)
    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
    tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
    tvm_predictions = np.squeeze(tvm_output)
    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
    tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
Example #25
Source File: From models with Apache License 2.0 | 5 votes |
def _collect_tensorflow_info(run_info):
    """Record the TensorFlow version and git hash under ``run_info["tensorflow_version"]``.

    Args:
        run_info: mapping that is updated in place.
    """
    # tf.VERSION / tf.GIT_VERSION are TF 1.x top-level aliases that TF 2.x no
    # longer exports. tf.__version__ / tf.__git_version__ hold the same strings
    # in both major versions.
    run_info["tensorflow_version"] = {
        "version": tf.__version__, "git_hash": tf.__git_version__}
Example #26
Source File: From nsfw with Apache License 2.0 | 5 votes |
def _collect_tensorflow_info(run_info):
    """Record the TensorFlow version and git hash under ``run_info["tensorflow_version"]``.

    Args:
        run_info: mapping that is updated in place.
    """
    # tf.VERSION / tf.GIT_VERSION are TF 1.x top-level aliases that TF 2.x no
    # longer exports. tf.__version__ / tf.__git_version__ hold the same strings
    # in both major versions.
    run_info["tensorflow_version"] = {
        "version": tf.__version__, "git_hash": tf.__git_version__}
Example #27
Source File: From Gun-Detector with Apache License 2.0 | 5 votes |
def _collect_tensorflow_info(run_info):
    """Record the TensorFlow version and git hash under ``run_info["tensorflow_version"]``.

    Args:
        run_info: mapping that is updated in place.
    """
    # tf.VERSION / tf.GIT_VERSION are TF 1.x top-level aliases that TF 2.x no
    # longer exports. tf.__version__ / tf.__git_version__ hold the same strings
    # in both major versions.
    run_info["tensorflow_version"] = {
        "version": tf.__version__, "git_hash": tf.__git_version__}
Example #28
Source File: From multilabel-image-classification-tensorflow with MIT License | 5 votes |
def _collect_tensorflow_info(run_info):
    """Record the TensorFlow version and git hash under ``run_info["tensorflow_version"]``.

    Args:
        run_info: mapping that is updated in place.
    """
    # tf.VERSION / tf.GIT_VERSION are TF 1.x top-level aliases that TF 2.x no
    # longer exports. tf.__version__ / tf.__git_version__ hold the same strings
    # in both major versions.
    run_info["tensorflow_version"] = {
        "version": tf.__version__, "git_hash": tf.__git_version__}
Example #29
Source File: From SPFN with MIT License | 5 votes |
def compute_consistent_plane_frame(normal):
    """Build an in-plane frame (x_axis, y_axis) for each plane normal.

    Args:
        normal: B x 3 tensor of plane normals.

    Returns:
        (x_axes, y_axes), both B x 3. For each normal, y_axis is the
        L2-normalized cross product of the normal with whichever canonical
        axis gives the largest cross-product norm, and
        x_axis = cross(y_axis, normal).
    """
    batch_size = tf.shape(normal)[0]
    # Actually, 2 should be enough. This may still cause singularity TODO!!!
    candidate_axes = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
    y_axes = []
    for tmp_axis in candidate_axes:
        tf_axis = tf.tile(tf.expand_dims(tf.constant(dtype=tf.float32, value=tmp_axis), axis=0),
                          [batch_size, 1])  # Bx3
        y_axes.append(tf.cross(normal, tf_axis))
    y_axes = tf.stack(y_axes, axis=0)  # QxBx3
    y_axes_norm = tf.norm(y_axes, axis=2)  # QxB
    # choose the axis with largest norm
    y_axes_chosen_idx = tf.argmax(y_axes_norm, axis=0)  # B
    # y_axes_chosen[b, :] = y_axes[y_axes_chosen_idx[b], b, :]
    indices_0 = tf.tile(tf.expand_dims(y_axes_chosen_idx, axis=1), [1, 3])  # Bx3
    indices_1 = tf.tile(tf.expand_dims(tf.range(batch_size), axis=1), [1, 3])  # Bx3
    indices_2 = tf.tile(tf.expand_dims(tf.range(3), axis=0), [batch_size, 1])  # Bx3
    # argmax yields int64 indices; cast so all three index planes share int32.
    indices = tf.stack([tf.cast(indices_0, tf.int32), indices_1, indices_2], axis=2)  # Bx3x3
    y_axes = tf.gather_nd(y_axes, indices=indices)  # Bx3
    # Pass the axis positionally. Old releases name the second parameter `dim`,
    # newer ones name it `axis` (with `dim` deprecated). The original code
    # special-cased only the exact string '1.4.1', which broke on every other
    # pre-`axis` release.
    y_axes = tf.nn.l2_normalize(y_axes, 1)
    x_axes = tf.cross(y_axes, normal)  # Bx3
    return x_axes, y_axes
Example #30
Source File: From deep_image_model with Apache License 2.0 | 5 votes |
def testVersion(self):
    """Check that the TensorFlow version strings are ``str`` and match ``X.Y.Z``."""
    self.assertEqual(type(tf.__version__), str)
    self.assertEqual(type(tf.VERSION), str)
    # assertRegexpMatches is a deprecated alias that Python 3.12 removed. Prefer
    # assertRegex, and fall back to the old name only where it is missing.
    assert_regex = getattr(self, 'assertRegex', None) or self.assertRegexpMatches
    # This pattern will need to grow as we include alpha, builds, etc.
    assert_regex(tf.__version__, r'^\d+\.\d+\.\w+$')
    assert_regex(tf.VERSION, r'^\d+\.\d+\.\w+$')