org.tensorflow.Tensors Java Examples
The following examples show how to use
org.tensorflow.Tensors.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: TensorFlowGraphModelTest.java From zoltar with Apache License 2.0 | 6 votes |
/**
 * Loads the dummy graph with the op-name prefix {@code "test"} and checks that feeding 3.0 into
 * the prefixed input op yields 6.0 from the prefixed multiply op.
 */
@Test
public void testDummyLoadOfTensorFlowGraphWithPrefix() throws Exception {
  final String prefix = "test";
  final Path graphFile = createADummyTFGraph();
  try (final TensorFlowGraphModel model =
          TensorFlowGraphModel.create(graphFile.toUri(), null, prefix);
      final Session session = model.instance();
      final Tensor<Double> double3 = Tensors.create(3.0D)) {
    List<Tensor<?>> result = null;
    try {
      result =
          session
              .runner()
              .fetch(prefix + "/" + mulResult)
              .feed(prefix + "/" + inputOpName, double3)
              .run();
      // JUnit's assertEquals takes (expected, actual, delta); keeping that order makes failure
      // messages report the values the right way round.
      assertEquals(6.0D, result.get(0).doubleValue(), Double.MIN_VALUE);
    } finally {
      // The fetched tensors are not declared in the try-with-resources header above, so they
      // are closed explicitly here.
      if (result != null) {
        result.forEach(Tensor::close);
      }
    }
  }
}
Example #2
Source File: TensorFlowGraphModelTest.java From zoltar with Apache License 2.0 | 6 votes |
/**
 * Loads the dummy graph without a prefix and checks that feeding 3.0 into the input op yields
 * 6.0 from the multiply op.
 */
@Test
public void testDummyLoadOfTensorFlowGraph() throws Exception {
  final Path graphFile = createADummyTFGraph();
  try (final TensorFlowGraphModel model =
          TensorFlowGraphModel.create(graphFile.toUri(), null, null);
      final Session session = model.instance();
      final Tensor<Double> double3 = Tensors.create(3.0D)) {
    List<Tensor<?>> result = null;
    try {
      result = session.runner().fetch(mulResult).feed(inputOpName, double3).run();
      // JUnit's assertEquals takes (expected, actual, delta).
      assertEquals(6.0D, result.get(0).doubleValue(), Double.MIN_VALUE);
    } finally {
      // Fetched tensors are not managed by the outer try-with-resources; close them here.
      if (result != null) {
        result.forEach(Tensor::close);
      }
    }
  }
}
Example #3
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/** Checks that a JTensor built from a long vector can be written with ObjectOutputStream. */
@Test
public void longTensorSerializable() throws IOException {
  final long[] longValue = {1, 2, 3, 4, 5};
  // Both the tensor and the stream are closed even if JTensor.create or writeObject throws.
  try (final Tensor<Long> tensor = Tensors.create(longValue);
      final ObjectOutputStream out = new ObjectOutputStream(new ByteArrayOutputStream())) {
    final JTensor jt = JTensor.create(tensor);
    out.writeObject(jt);
  }
}
Example #4
Source File: FaceRecognizer.java From server_face_recognition with GNU General Public License v3.0 | 5 votes |
/** * Running neural network * * @param image cropped, centralized face * @return describing of a face based on 128 float features */ private FaceFeatures passImageThroughNeuralNetwork(BufferedImage image, int faceType) { FaceFeatures features; try (Session session = new Session(graph)) { Tensor<Float> feedImage = Tensors.create(imageToMultiDimensionalArray(image)); long timeResponse = System.currentTimeMillis(); Tensor<Float> response = session.runner() .feed("input", feedImage) .feed("phase_train", Tensor.create(false)) .fetch("embeddings") .run().get(0) .expect(Float.class); FileUtils.timeSpent(timeResponse, "RESPONSE"); final long[] shape = response.shape(); //first dimension should return 1 as for image with normal size //second dimension should give 128 characteristics of face if (shape[0] != 1 || shape[1] != 128) { throw new IllegalStateException("illegal output values: 1 = " + shape[0] + " 2 = " + shape[1]); } float[][] featuresHolder = new float[1][128]; response.copyTo(featuresHolder); features = new FaceFeatures(featuresHolder[0], faceType); response.close(); } return features; }
Example #5
Source File: InProcessClassification.java From hazelcast-jet-demos with Apache License 2.0 | 5 votes |
/**
 * Runs one review through the saved model.
 *
 * <p>The review is converted by {@code wordIndex.createTensorInput}, fed to
 * {@code "embedding_input:0"}, and the single value of {@code "dense_1/Sigmoid:0"} is read back.
 *
 * @return the review paired with the model's output value
 */
private static Tuple2<String, Float> classify(
        String review, SavedModelBundle model, WordIndex wordIndex
) {
    try (Tensor<Float> inputTensor = Tensors.create(wordIndex.createTensorInput(review))) {
        try (Tensor<?> scoreTensor = model.session().runner()
                .feed("embedding_input:0", inputTensor)
                .fetch("dense_1/Sigmoid:0")
                .run()
                .get(0)) {
            // The output is copied into a 1x1 array, so only element [0][0] is meaningful.
            float[][] score = new float[1][1];
            scoreTensor.copyTo(score);
            return tuple2(review, score[0][0]);
        }
    }
}
Example #6
Source File: TensorFlowPredictFn.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * TensorFlow Example prediction function for batched input.
 *
 * <p>For every input vector, its {@code List<Example>} is serialized into one string tensor fed to
 * {@code input_example_tensor}. The requested ops are fetched, and {@code outTensorExtractor}
 * turns the extracted tensors into one value. Each vector is handled in its own
 * {@link CompletableFuture#supplyAsync} task.
 *
 * @deprecated Use {@link #example(Function, String...)}
 * @param outTensorExtractor Function to extract the output value from JTensor's
 * @param fetchOps operations to fetch.
 */
@Deprecated
static <InputT, ValueT> TensorFlowPredictFn<InputT, List<Example>, ValueT> exampleBatch(
    final Function<Map<String, JTensor>, ValueT> outTensorExtractor, final String... fetchOps) {
  final BiFunction<TensorFlowModel, List<Example>, ValueT> runOne =
      (model, batch) -> {
        final byte[][] serialized = new byte[batch.size()][];
        int next = 0;
        for (final Example example : batch) {
          serialized[next++] = example.toByteArray();
        }
        try (final Tensor<String> input = Tensors.create(serialized)) {
          final Session.Runner runner =
              model.instance().session().runner().feed("input_example_tensor", input);
          return outTensorExtractor.apply(TensorFlowExtras.runAndExtract(runner, fetchOps));
        }
      };

  return (model, vectors) -> {
    final List<CompletableFuture<Prediction<InputT, ValueT>>> futures = new ArrayList<>();
    for (final Vector<InputT, List<Example>> vector : vectors) {
      futures.add(
          CompletableFuture.supplyAsync(() -> runOne.apply(model, vector.value()))
              .thenApply(value -> Prediction.create(vector.input(), value)));
    }
    return CompletableFutures.allAsList(futures);
  };
}
Example #7
Source File: TensorFlowPredictFn.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * TensorFlow Example prediction function.
 *
 * <p>All input vectors are serialized into one string tensor fed to
 * {@code input_example_tensor}, and the run happens in a single
 * {@link CompletableFuture#supplyAsync} task. The list returned by {@code outTensorExtractor} is
 * paired positionally with the inputs. If the two differ in length, the surplus elements of the
 * longer one are ignored.
 *
 * @param outTensorExtractor Function to extract the output value from JTensor's
 * @param fetchOps operations to fetch.
 */
static <InputT, ValueT> TensorFlowPredictFn<InputT, Example, ValueT> example(
    final Function<Map<String, JTensor>, List<ValueT>> outTensorExtractor,
    final String... fetchOps) {
  return (model, vectors) ->
      CompletableFuture.supplyAsync(
          () -> {
            final byte[][] serialized = new byte[vectors.size()][];
            int next = 0;
            for (final Vector<InputT, Example> vector : vectors) {
              serialized[next++] = vector.value().toByteArray();
            }
            try (final Tensor<String> input = Tensors.create(serialized)) {
              final Session.Runner runner =
                  model.instance().session().runner().feed("input_example_tensor", input);
              final Iterator<ValueT> values =
                  outTensorExtractor
                      .apply(TensorFlowExtras.runAndExtract(runner, fetchOps))
                      .iterator();
              final List<Prediction<InputT, ValueT>> predictions = new ArrayList<>();
              for (final Vector<InputT, Example> vector : vectors) {
                if (!values.hasNext()) {
                  break;
                }
                predictions.add(Prediction.create(vector.input(), values.next()));
              }
              return predictions;
            }
          });
}
Example #8
Source File: TensorFlowGraphModelTest.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Builds a tiny graph whose {@code mulResult} op multiplies the double fed into the
 * {@code inputOpName} placeholder by the constant 2.0. The serialized GraphDef is written to a
 * new temporary file.
 *
 * @return path of the temporary file holding the serialized graph
 * @throws IOException if the temporary file cannot be created or written
 */
private Path createADummyTFGraph() throws IOException {
  try (final Graph graph = new Graph();
      final Tensor<Double> twoValue = Tensors.create(2.0D)) {
    final Output<Double> placeholder =
        graph
            .opBuilder("Placeholder", inputOpName)
            .setAttr("dtype", twoValue.dataType())
            .build()
            .output(0);
    final Output<Double> constTwo =
        graph
            .opBuilder("Const", "two")
            .setAttr("dtype", twoValue.dataType())
            .setAttr("value", twoValue)
            .build()
            .output(0);
    graph.opBuilder("Mul", mulResult).addInput(constTwo).addInput(placeholder).build();

    final Path tempGraph = Files.createTempFile("tf-graph", ".bin");
    Files.write(tempGraph, graph.toGraphDef());
    return tempGraph;
  }
}
Example #9
Source File: TensorFlowExtrasTest.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Feeds 10.0 into the dummy graph and extracts {@code mul3} then {@code mul2}, checking that the
 * result map keeps that order and holds 30.0 and 20.0.
 */
@Test
public void testExtract2b() {
  // The fed tensor, session and graph are all closed, even when an assertion fails.
  try (final Graph graph = createDummyGraph();
      final Session session = new Session(graph);
      final Tensor<Double> input = Tensors.create(10.0)) {
    final Session.Runner runner = session.runner();
    runner.feed("input", input);
    final Map<String, JTensor> result = TensorFlowExtras.runAndExtract(runner, mul3, mul2);
    assertEquals(Lists.newArrayList(mul3, mul2), new ArrayList<>(result.keySet()));
    assertScalar(result.get(mul2), 20.0);
    assertScalar(result.get(mul3), 30.0);
  }
}
Example #10
Source File: TensorFlowExtrasTest.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Feeds 10.0 into the dummy graph and extracts {@code mul2} then {@code mul3}, checking that the
 * result map keeps that order and holds 20.0 and 30.0.
 */
@Test
public void testExtract2a() {
  // The fed tensor, session and graph are all closed, even when an assertion fails.
  try (final Graph graph = createDummyGraph();
      final Session session = new Session(graph);
      final Tensor<Double> input = Tensors.create(10.0)) {
    final Session.Runner runner = session.runner();
    runner.feed("input", input);
    final Map<String, JTensor> result = TensorFlowExtras.runAndExtract(runner, mul2, mul3);
    assertEquals(Lists.newArrayList(mul2, mul3), new ArrayList<>(result.keySet()));
    assertScalar(result.get(mul2), 20.0);
    assertScalar(result.get(mul3), 30.0);
  }
}
Example #11
Source File: TensorFlowExtrasTest.java From zoltar with Apache License 2.0 | 5 votes |
/** Feeds 10.0 into the dummy graph and extracts only {@code mul2}, expecting 20.0. */
@Test
public void testExtract1() {
  // The fed tensor, session and graph are all closed, even when an assertion fails.
  try (final Graph graph = createDummyGraph();
      final Session session = new Session(graph);
      final Tensor<Double> input = Tensors.create(10.0)) {
    final Session.Runner runner = session.runner();
    runner.feed("input", input);
    final Map<String, JTensor> result = TensorFlowExtras.runAndExtract(runner, mul2);
    assertEquals(Sets.newHashSet(mul2), result.keySet());
    assertScalar(result.get(mul2), 20.0);
  }
}
Example #12
Source File: TensorFlowExtrasTest.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Builds a graph with a double placeholder named {@code "input"} and two multiply ops: the op
 * named by {@code mul2} computes {@code input * 2.0}, and the one named by {@code mul3} computes
 * {@code input * 3.0}.
 *
 * <p>The caller owns the returned graph and must close it.
 */
private static Graph createDummyGraph() {
  final Graph graph = new Graph();
  // The constant tensors are only needed while their Const ops are built.
  // NOTE(review): relies on setAttr copying the tensor value into the op definition, as
  // createADummyTFGraph in the sibling test also closes its constant tensor after use.
  try (final Tensor<Double> t2 = Tensors.create(2.0);
      final Tensor<Double> t3 = Tensors.create(3.0)) {
    final Output<Double> input =
        graph.opBuilder("Placeholder", "input").setAttr("dtype", DataType.DOUBLE).build().output(0);
    final Output<Double> two =
        graph
            .opBuilder("Const", "two")
            .setAttr("dtype", t2.dataType())
            .setAttr("value", t2)
            .build()
            .output(0);
    final Output<Double> three =
        graph
            .opBuilder("Const", "three")
            .setAttr("dtype", t3.dataType())
            .setAttr("value", t3)
            .build()
            .output(0);
    graph.opBuilder("Mul", mul2).addInput(input).addInput(two).build();
    graph.opBuilder("Mul", mul3).addInput(input).addInput(three).build();
    return graph;
  } catch (RuntimeException | Error e) {
    // Do not leak the graph if building it fails part-way.
    graph.close();
    throw e;
  }
}
Example #13
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/** Checks that a JTensor built from a double vector can be written with ObjectOutputStream. */
@Test
public void doubleTensorSerializable() throws IOException {
  final double[] doubleValue = {1, 2, 3, 4, 5};
  // Both the tensor and the stream are closed even if JTensor.create or writeObject throws.
  try (final Tensor<Double> tensor = Tensors.create(doubleValue);
      final ObjectOutputStream out = new ObjectOutputStream(new ByteArrayOutputStream())) {
    final JTensor jt = JTensor.create(tensor);
    out.writeObject(jt);
  }
}
Example #14
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/** Checks that a JTensor built from a float vector can be written with ObjectOutputStream. */
@Test
public void floatTensorSerializable() throws IOException {
  final float[] floatValue = {1, 2, 3, 4, 5};
  // Both the tensor and the stream are closed even if JTensor.create or writeObject throws.
  try (final Tensor<Float> tensor = Tensors.create(floatValue);
      final ObjectOutputStream out = new ObjectOutputStream(new ByteArrayOutputStream())) {
    final JTensor jt = JTensor.create(tensor);
    out.writeObject(jt);
  }
}
Example #15
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/** Checks that a JTensor built from an int vector can be written with ObjectOutputStream. */
@Test
public void intTensorSerializable() throws IOException {
  final int[] intValue = {1, 2, 3, 4, 5};
  // Both the tensor and the stream are closed even if JTensor.create or writeObject throws.
  try (final Tensor<Integer> tensor = Tensors.create(intValue);
      final ObjectOutputStream out = new ObjectOutputStream(new ByteArrayOutputStream())) {
    final JTensor jt = JTensor.create(tensor);
    out.writeObject(jt);
  }
}
Example #16
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Checks that a JTensor built from a 3-dimensional string tensor can be written with
 * ObjectOutputStream.
 */
@Test
public void multidimensionalStringTensorSerializable() throws IOException {
  final byte[][][][] byteArray = toByteArray(STRING_ARRAY_3DIMENSIONS);
  // Both the tensor and the stream are closed even if JTensor.create or writeObject throws.
  try (final Tensor<String> tensor = Tensors.create(byteArray);
      final ObjectOutputStream out = new ObjectOutputStream(new ByteArrayOutputStream())) {
    final JTensor jt = JTensor.create(tensor);
    out.writeObject(jt);
  }
}
Example #17
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/** Checks that a JTensor built from a scalar string can be written with ObjectOutputStream. */
@Test
public void stringTensorSerializable() throws IOException {
  final String stringValue = "world";
  // Both the tensor and the stream are closed even if JTensor.create or writeObject throws.
  try (final Tensor<String> tensor = Tensors.create(stringValue);
      final ObjectOutputStream out = new ObjectOutputStream(new ByteArrayOutputStream())) {
    final JTensor jt = JTensor.create(tensor);
    out.writeObject(jt);
  }
}
Example #18
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Checks dtype, rank, shape and values of a JTensor built from a 1-D boolean tensor. The
 * accessors for the other value types are passed to {@code testException}.
 */
@Test
public void testBooleanTensor() {
  final boolean[] booleanValue = {true, true, false, true, false};
  // Close the tensor once the JTensor has been built and checked, even if an assertion fails.
  try (final Tensor<Boolean> tensor = Tensors.create(booleanValue)) {
    final JTensor jt = JTensor.create(tensor);
    assertEquals(DataType.BOOL, jt.dataType());
    assertEquals(1, jt.numDimensions());
    assertArrayEquals(shape, jt.shape());
    assertArrayEquals(booleanValue, jt.booleanValue());
    testException(jt, JTensor::stringValue);
    testException(jt, JTensor::intValue);
    testException(jt, JTensor::longValue);
    testException(jt, JTensor::floatValue);
    testException(jt, JTensor::doubleValue);
  }
}
Example #19
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Checks dtype, rank, shape and values of a JTensor built from a 1-D double tensor. The
 * accessors for other value types are passed to {@code testException}.
 */
@Test
public void testDoubleTensor() {
  final double[] doubleValue = {1, 2, 3, 4, 5};
  // Close the tensor once the JTensor has been built and checked, even if an assertion fails.
  try (final Tensor<Double> tensor = Tensors.create(doubleValue)) {
    final JTensor jt = JTensor.create(tensor);
    assertEquals(DataType.DOUBLE, jt.dataType());
    assertEquals(1, jt.numDimensions());
    assertArrayEquals(shape, jt.shape());
    assertArrayEquals(doubleValue, jt.doubleValue(), 0.0);
    testException(jt, JTensor::stringValue);
    testException(jt, JTensor::intValue);
    testException(jt, JTensor::longValue);
    testException(jt, JTensor::floatValue);
  }
}
Example #20
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Checks dtype, rank, shape and values of a JTensor built from a 1-D float tensor. The accessors
 * for other value types are passed to {@code testException}.
 */
@Test
public void testFloatTensor() {
  final float[] floatValue = {1, 2, 3, 4, 5};
  // Close the tensor once the JTensor has been built and checked, even if an assertion fails.
  try (final Tensor<Float> tensor = Tensors.create(floatValue)) {
    final JTensor jt = JTensor.create(tensor);
    assertEquals(DataType.FLOAT, jt.dataType());
    assertEquals(1, jt.numDimensions());
    assertArrayEquals(shape, jt.shape());
    assertArrayEquals(floatValue, jt.floatValue(), 0.0f);
    testException(jt, JTensor::stringValue);
    testException(jt, JTensor::intValue);
    testException(jt, JTensor::longValue);
    testException(jt, JTensor::doubleValue);
  }
}
Example #21
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Checks dtype (INT64), rank, shape and values of a JTensor built from a 1-D long tensor. The
 * accessors for other value types are passed to {@code testException}.
 */
@Test
public void testLongTensor() {
  final long[] longValue = {1, 2, 3, 4, 5};
  // Close the tensor once the JTensor has been built and checked, even if an assertion fails.
  try (final Tensor<Long> tensor = Tensors.create(longValue)) {
    final JTensor jt = JTensor.create(tensor);
    assertEquals(DataType.INT64, jt.dataType());
    assertEquals(1, jt.numDimensions());
    assertArrayEquals(shape, jt.shape());
    assertArrayEquals(longValue, jt.longValue());
    testException(jt, JTensor::stringValue);
    testException(jt, JTensor::intValue);
    testException(jt, JTensor::floatValue);
    testException(jt, JTensor::doubleValue);
  }
}
Example #22
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Checks dtype (INT32), rank, shape and values of a JTensor built from a 1-D int tensor. The
 * accessors for other value types are passed to {@code testException}.
 */
@Test
public void testIntTensor() {
  final int[] intValue = {1, 2, 3, 4, 5};
  // Close the tensor once the JTensor has been built and checked, even if an assertion fails.
  try (final Tensor<Integer> tensor = Tensors.create(intValue)) {
    final JTensor jt = JTensor.create(tensor);
    assertEquals(DataType.INT32, jt.dataType());
    assertEquals(1, jt.numDimensions());
    assertArrayEquals(shape, jt.shape());
    assertArrayEquals(intValue, jt.intValue());
    testException(jt, JTensor::stringValue);
    testException(jt, JTensor::longValue);
    testException(jt, JTensor::floatValue);
    testException(jt, JTensor::doubleValue);
  }
}
Example #23
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/** Checks a JTensor built from the descending 3-D string fixture, expecting shape [3, 2, 1]. */
@Test
public void testStringTensor3DimensionsDescending() {
  final byte[][][][] byteArray = toByteArray(STRING_ARRAY_3DIMENSIONS_DESCENDING);
  // Close the tensor once the JTensor has been built and checked, even if an assertion fails.
  try (final Tensor<String> tensor = Tensors.create(byteArray)) {
    final JTensor jt = JTensor.create(tensor);
    testMultidimensionalStringTensor(
        jt, STRING_ARRAY_3DIMENSIONS_DESCENDING, new long[] {3, 2, 1});
  }
}
Example #24
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/** Checks a JTensor built from the ascending 3-D string fixture, expecting shape [1, 2, 3]. */
@Test
public void testStringTensor3DimensionsAscending() {
  final byte[][][][] byteArray = toByteArray(STRING_ARRAY_3DIMENSIONS_ASCENDING);
  // Close the tensor once the JTensor has been built and checked, even if an assertion fails.
  try (final Tensor<String> tensor = Tensors.create(byteArray)) {
    final JTensor jt = JTensor.create(tensor);
    testMultidimensionalStringTensor(
        jt, STRING_ARRAY_3DIMENSIONS_ASCENDING, new long[] {1, 2, 3});
  }
}
Example #25
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/** Checks a JTensor built from the 3-D string fixture, expecting shape [3, 3, 3]. */
@Test
public void testStringTensor3Dimensions() {
  final byte[][][][] byteArray = toByteArray(STRING_ARRAY_3DIMENSIONS);
  // Close the tensor once the JTensor has been built and checked, even if an assertion fails.
  try (final Tensor<String> tensor = Tensors.create(byteArray)) {
    final JTensor jt = JTensor.create(tensor);
    testMultidimensionalStringTensor(jt, STRING_ARRAY_3DIMENSIONS, new long[] {3, 3, 3});
  }
}
Example #26
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/** Checks a JTensor built from the 1-D string fixture, expecting shape [3]. */
@Test
public void testStringTensor1Dimension() {
  final byte[][] byteArray = toByteArray(STRING_ARRAY_1DIMENSION);
  // Close the tensor once the JTensor has been built and checked, even if an assertion fails.
  try (final Tensor<String> tensor = Tensors.create(byteArray)) {
    final JTensor jt = JTensor.create(tensor);
    testMultidimensionalStringTensor(jt, STRING_ARRAY_1DIMENSION, new long[] {3});
  }
}
Example #27
Source File: JTensorTest.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Checks that a JTensor built from a scalar string has STRING dtype, rank 0, an empty shape and
 * the original value. The numeric accessors are passed to {@code testException}.
 */
@Test
public void testStringTensor() {
  final String stringValue = "world";
  // Close the tensor once the JTensor has been built and checked, even if an assertion fails.
  try (final Tensor<String> tensor = Tensors.create(stringValue)) {
    final JTensor jt = JTensor.create(tensor);
    assertEquals(DataType.STRING, jt.dataType());
    assertEquals(0, jt.numDimensions());
    assertArrayEquals(new long[0], jt.shape());
    assertEquals(stringValue, jt.stringValue());
    testException(jt, JTensor::intValue);
    testException(jt, JTensor::longValue);
    testException(jt, JTensor::floatValue);
    testException(jt, JTensor::doubleValue);
  }
}
Example #28
Source File: ObjectDetector.java From OpenLabeler with Apache License 2.0 | 4 votes |
/**
 * Wraps the raw bytes of an image file as the single element of a string tensor.
 *
 * <p>See <a href="https://github.com/tensorflow/tensorflow/issues/24331#issuecomment-447523402">GitHub issue</a>
 *
 * @param imageFile file whose bytes are read unmodified
 * @return a tensor created from a one-element {@code byte[][]} holding the file contents
 * @throws IOException if the file cannot be read
 */
private static Tensor<?> makeImageStringTensor(File imageFile) throws IOException {
    var imageBytes = FileUtils.readFileToByteArray(imageFile);
    return Tensors.create(new byte[][] { imageBytes });
}