Java Code Examples for org.tensorflow.Session#Runner
The following examples show how to use org.tensorflow.Session#Runner.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: TfSymbolBlock.java From djl with Apache License 2.0 | 6 votes |
/**
 * {@inheritDoc}
 *
 * <p>Feeds each input array to the session under the name given by {@code describeInput()},
 * fetches every output named by {@code describeOutput()}, runs the session once and wraps
 * each resulting tensor in an array created by the manager of the first input.
 */
@Override
public NDList forward(
        ParameterStore parameterStore,
        NDList inputs,
        boolean training,
        PairList<String, Object> params) {
    Session.Runner runner = session.runner();
    PairList<String, Shape> inputDescriptions = describeInput();
    PairList<String, Shape> outputDescriptions = describeOutput();

    // Inputs are matched to their descriptions by position.
    int inputIndex = 0;
    while (inputIndex < inputDescriptions.size()) {
        String inputName = inputDescriptions.get(inputIndex).getKey();
        TfNDArray inputArray = (TfNDArray) inputs.get(inputIndex);
        runner.feed(inputName, inputArray.getTensor());
        inputIndex++;
    }

    int outputIndex = 0;
    while (outputIndex < outputDescriptions.size()) {
        runner.fetch(outputDescriptions.get(outputIndex).getKey());
        outputIndex++;
    }

    List<Tensor<?>> fetched = runner.run();

    NDList outputs = new NDList();
    TfNDManager manager = (TfNDManager) inputs.head().getManager();
    for (Tensor<?> fetchedTensor : fetched) {
        outputs.add(manager.create(fetchedTensor));
    }
    return outputs;
}
Example 2
Source File: TensorFlowExtras.java From zoltar with Apache License 2.0 | 6 votes |
/**
 * Registers the given operations on a {@link Session.Runner}, runs it, converts every output
 * {@link Tensor} to a {@link JTensor} and closes the output tensors afterwards.
 *
 * @param runner {@link Session.Runner} to fetch operations and extract outputs from.
 * @param fetchOps operations to fetch.
 * @return a {@link Map} from operation name to output {@link JTensor}. Map keys are in the same
 *     order as {@code fetchOps}.
 */
public static Map<String, JTensor> runAndExtract(
    final Session.Runner runner, final String... fetchOps) {
  for (int i = 0; i < fetchOps.length; i++) {
    runner.fetch(fetchOps[i]);
  }

  final Map<String, JTensor> extracted =
      Maps.newLinkedHashMapWithExpectedSize(fetchOps.length);
  final List<Tensor<?>> outputs = runner.run();
  try {
    // Outputs come back in fetch order, so pair them with the op names by position.
    int position = 0;
    for (final String op : fetchOps) {
      extracted.put(op, JTensor.create(outputs.get(position)));
      position++;
    }
  } finally {
    // Close every output tensor, whether or not extraction succeeded.
    for (final Tensor<?> output : outputs) {
      output.close();
    }
  }
  return extracted;
}
Example 3
Source File: TensorFlowExtrasTest.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Feeds 10.0 to "input", extracts the single {@code mul2} operation and checks that the result
 * map contains exactly that key with the scalar value 20.0.
 */
@Test
public void testExtract1() {
  // try-with-resources closes the session and then the graph even when an assertion fails;
  // the previous explicit close() calls were skipped on failure.
  try (Graph graph = createDummyGraph();
      Session session = new Session(graph)) {
    final Session.Runner runner = session.runner();
    // NOTE(review): the fed input tensor is still never closed here.
    runner.feed("input", Tensors.create(10.0));
    final Map<String, JTensor> result = TensorFlowExtras.runAndExtract(runner, mul2);
    assertEquals(Sets.newHashSet(mul2), result.keySet());
    assertScalar(result.get(mul2), 20.0);
  }
}
Example 4
Source File: TensorFlowExtrasTest.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Feeds 10.0 to "input", extracts {@code mul2} then {@code mul3}, and checks that the result
 * keys keep that order and map to the scalars 20.0 and 30.0 respectively.
 */
@Test
public void testExtract2a() {
  // try-with-resources closes the session and then the graph even when an assertion fails;
  // the previous explicit close() calls were skipped on failure.
  try (Graph graph = createDummyGraph();
      Session session = new Session(graph)) {
    final Session.Runner runner = session.runner();
    // NOTE(review): the fed input tensor is still never closed here.
    runner.feed("input", Tensors.create(10.0));
    final Map<String, JTensor> result = TensorFlowExtras.runAndExtract(runner, mul2, mul3);
    assertEquals(Lists.newArrayList(mul2, mul3), new ArrayList<>(result.keySet()));
    assertScalar(result.get(mul2), 20.0);
    assertScalar(result.get(mul3), 30.0);
  }
}
Example 5
Source File: TensorFlowExtrasTest.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Feeds 10.0 to "input", extracts {@code mul3} then {@code mul2}, and checks that the result
 * keys keep that order and map to the scalars 30.0 and 20.0 respectively.
 */
@Test
public void testExtract2b() {
  // try-with-resources closes the session and then the graph even when an assertion fails;
  // the previous explicit close() calls were skipped on failure.
  try (Graph graph = createDummyGraph();
      Session session = new Session(graph)) {
    final Session.Runner runner = session.runner();
    // NOTE(review): the fed input tensor is still never closed here.
    runner.feed("input", Tensors.create(10.0));
    final Map<String, JTensor> result = TensorFlowExtras.runAndExtract(runner, mul3, mul2);
    assertEquals(Lists.newArrayList(mul3, mul2), new ArrayList<>(result.keySet()));
    assertScalar(result.get(mul2), 20.0);
    assertScalar(result.get(mul3), 30.0);
  }
}
Example 6
Source File: TensorFlowPredictFn.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Builds a prediction function that serializes a batch of {@code Example}s, feeds them to the
 * model as the {@code "input_example_tensor"} input, runs the requested operations and hands
 * the extracted tensors to {@code outTensorExtractor}.
 *
 * <p>Each input vector is predicted in its own {@link CompletableFuture} started with
 * {@code supplyAsync}, and the returned stage completes with all predictions as a list.
 *
 * @deprecated Use {@link #example(Function, String...)}
 * @param outTensorExtractor Function to extract the output value from JTensor's
 * @param fetchOps operations to fetch.
 */
@Deprecated
static <InputT, ValueT> TensorFlowPredictFn<InputT, List<Example>, ValueT> exampleBatch(
    final Function<Map<String, JTensor>, ValueT> outTensorExtractor, final String... fetchOps) {
  final BiFunction<TensorFlowModel, List<Example>, ValueT> batchPredictor =
      (tfModel, batch) -> {
        final byte[][] serialized =
            batch.stream().map(Example::toByteArray).toArray(byte[][]::new);

        // The input tensor is closed as soon as the run has been extracted.
        try (final Tensor<String> inputTensor = Tensors.create(serialized)) {
          final Session.Runner runner = tfModel.instance().session().runner();
          runner.feed("input_example_tensor", inputTensor);
          final Map<String, JTensor> outputs =
              TensorFlowExtras.runAndExtract(runner, fetchOps);
          return outTensorExtractor.apply(outputs);
        }
      };

  return (tfModel, vectors) -> {
    final List<CompletableFuture<Prediction<InputT, ValueT>>> pending =
        vectors
            .stream()
            .map(
                vector ->
                    CompletableFuture.supplyAsync(
                            () -> batchPredictor.apply(tfModel, vector.value()))
                        .thenApply(value -> Prediction.create(vector.input(), value)))
            .collect(Collectors.toList());
    return CompletableFutures.allAsList(pending);
  };
}
Example 7
Source File: RNTensorflowInference.java From react-native-tensorflow with Apache License 2.0 | 5 votes |
/**
 * Loads a serialized graph definition through {@code ResourceManager}, imports it into a new
 * {@link Graph}, opens a {@link Session} on it and bundles session, runner and graph into a
 * {@code TfContext}.
 *
 * <p>If importing the graph or creating the session or context fails with a runtime
 * exception, the native resources allocated so far are released before the exception is
 * rethrown.
 *
 * @param reactContext context handed to {@code ResourceManager} to load the resource
 * @param model identifier of the model resource to load
 * @return a context holding the new session, a runner created from it, and the graph
 * @throws IOException if loading the model resource fails
 */
private static TfContext createContext(ReactContext reactContext, String model) throws IOException {
    byte[] graphDef = new ResourceManager(reactContext).loadResource(model);

    Graph graph = new Graph();
    Session session = null;
    try {
        graph.importGraphDef(graphDef);
        session = new Session(graph);
        Session.Runner runner = session.runner();
        return new TfContext(session, runner, graph);
    } catch (RuntimeException e) {
        // Close the session first: an open session keeps a reference to the graph.
        if (session != null) {
            session.close();
        }
        graph.close();
        throw e;
    }
}
Example 8
Source File: GraphImporter.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Fetches the operation with the given name from the bundle's session and returns the single
 * resulting tensor.
 *
 * @param name name of the operation to fetch
 * @param bundle saved model whose session is used
 * @return the one tensor produced by the fetch
 * @throws IllegalStateException if the run does not produce exactly one tensor
 */
static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
    List<org.tensorflow.Tensor<?>> fetchedTensors = bundle.session().runner().fetch(name).run();
    int count = fetchedTensors.size();
    if (count == 1) {
        return fetchedTensors.get(0);
    }
    throw new IllegalStateException("Expected 1 tensor from fetching " + name + ", but got " + count);
}
Example 9
Source File: TestableTensorFlowModel.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Runs one operation of the model on a generated input and converts the result to a Vespa
 * tensor.
 *
 * <p>The input argument is float or double depending on the {@code floatInput} flag. The run
 * is asserted to produce exactly one tensor.
 *
 * @param model saved model whose session is used
 * @param inputName name under which the generated input is fed
 * @param operationName name of the operation to fetch
 * @return the fetched tensor converted with {@code TensorConverter.toVespaTensor}
 */
private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
    Session.Runner runner = model.session().runner();

    org.tensorflow.Tensor<?> argument;
    if (floatInput) {
        argument = tensorFlowFloatInputArgument();
    } else {
        argument = tensorFlowDoubleInputArgument();
    }

    List<org.tensorflow.Tensor<?>> outputs =
            runner.feed(inputName, argument).fetch(operationName).run();
    assertEquals(1, outputs.size());
    org.tensorflow.Tensor<?> onlyOutput = outputs.get(0);
    return TensorConverter.toVespaTensor(onlyOutput);
}
Example 10
Source File: BlogEvaluationBenchmark.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Evaluates the "y" operation of the model {@code iterations} times, feeding {@code u} as
 * "input_u" and {@code d} as "input_d" each time.
 *
 * <p>The output tensors of every run are closed once their value has been read, so repeated
 * iterations do not accumulate unreleased native tensors.
 *
 * @param tensorFlowModel saved model whose session is used
 * @param u tensor fed as "input_u"
 * @param d tensor fed as "input_d"
 * @param iterations number of evaluations to perform
 * @return the sum over the first output tensor of the last iteration, or 0 when
 *     {@code iterations} is not positive
 */
private static double evaluateTensorflow(SavedModelBundle tensorFlowModel, org.tensorflow.Tensor<?> u, org.tensorflow.Tensor<?> d, int iterations) {
    double result = 0;
    for (int i = 0; i < iterations; i++) {
        Session.Runner runner = tensorFlowModel.session().runner();
        runner.feed("input_u", u);
        runner.feed("input_d", d);
        List<org.tensorflow.Tensor<?>> results = runner.fetch("y").run();
        try {
            result = TensorConverter.toVespaTensor(results.get(0)).sum().asDouble();
        } finally {
            // The value has been reduced to a double above, so the outputs can be released.
            for (org.tensorflow.Tensor<?> output : results) {
                output.close();
            }
        }
    }
    return result;
}
Example 11
Source File: TestableModel.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Runs one operation of the model on generated inputs and converts the result to a Vespa
 * tensor.
 *
 * <p>For every entry of {@code inputs} an argument of shape parameters
 * {@code (1, size of dimension 1)} is generated and fed under the entry's name. A float
 * argument is tried first; if generating or feeding it throws, a double argument is fed
 * instead. The run is asserted to produce exactly one tensor.
 *
 * @param tensorFlowModel saved model whose session is used
 * @param operationName name of the operation to fetch
 * @param inputs input names mapped to their tensor types
 * @return the fetched tensor converted with {@code TensorConverter.toVespaTensor}
 */
Tensor evaluateTF(SavedModelBundle tensorFlowModel, String operationName, Map<String, TensorType> inputs) {
    Session.Runner runner = tensorFlowModel.session().runner();

    for (Map.Entry<String, TensorType> input : inputs.entrySet()) {
        String inputName = input.getKey();
        int width = input.getValue().dimensions().get(1).size().get().intValue();
        try {
            runner.feed(inputName, tensorFlowFloatInputArgument(1, width));
        } catch (Exception e) {
            // Fall back to a double-typed argument when the float one is rejected.
            runner.feed(inputName, tensorFlowDoubleInputArgument(1, width));
        }
    }

    List<org.tensorflow.Tensor<?>> outputs = runner.fetch(operationName).run();
    assertEquals(1, outputs.size());
    org.tensorflow.Tensor<?> onlyOutput = outputs.get(0);
    return TensorConverter.toVespaTensor(onlyOutput);
}
Example 12
Source File: TensorFlowProcessor.java From datacollector with Apache License 2.0 | 5 votes |
/**
 * Runs the session once for the whole batch.
 *
 * <p>When the batch has at least one record, all records are converted into input tensors,
 * the configured outputs are fetched, and the outputs are published as a single event record.
 * The input tensors are closed afterwards and every record of the batch is then passed
 * through to the batch maker unchanged. An empty batch does nothing.
 *
 * @param batch batch of records to evaluate
 * @param singleLaneBatchMaker receives every record of the batch
 * @throws StageException declared by the signature; raised by the called stage code
 */
private void processUseEntireBatch(Batch batch, SingleLaneBatchMaker singleLaneBatchMaker) throws StageException {
  Session.Runner runner = this.session.runner();
  Iterator<Record> batchRecords = batch.getRecords();
  if (!batchRecords.hasNext()) {
    return;
  }

  Map<Pair<String, Integer>, Tensor> inputs = convertBatch(batch, conf.inputConfigs);
  try {
    for (Map.Entry<Pair<String, Integer>, Tensor> feedEntry : inputs.entrySet()) {
      Pair<String, Integer> target = feedEntry.getKey();
      runner.feed(target.getLeft(), target.getRight(), feedEntry.getValue());
    }
    for (TensorConfig outputConfig : conf.outputConfigs) {
      runner.fetch(outputConfig.operation, outputConfig.index);
    }

    List<Tensor<?>> tensorOutput = runner.run();
    LinkedHashMap<String, Field> outputFields = createOutputFieldValue(tensorOutput);

    // The outputs of a whole-batch run are emitted as an event, not written to records.
    EventRecord outputEvent = TensorFlowEvents.TENSOR_FLOW_OUTPUT_CREATOR.create(getContext()).create();
    outputEvent.set(Field.createListMap(outputFields));
    getContext().toEvent(outputEvent);
  } finally {
    for (Tensor input : inputs.values()) {
      input.close();
    }
  }

  for (Iterator<Record> passThrough = batch.getRecords(); passThrough.hasNext(); ) {
    singleLaneBatchMaker.addRecord(passThrough.next());
  }
}
Example 13
Source File: TensorFlowProcessor.java From datacollector with Apache License 2.0 | 5 votes |
/**
 * Runs the session once per record.
 *
 * <p>Each record is converted into input tensors; a record whose conversion raises
 * {@code OnRecordErrorException} is handed to the error handler and skipped. Otherwise the
 * configured outputs are fetched, written to the record under {@code conf.outputField} and
 * the record is added to the batch maker. The input tensors are closed after every run.
 *
 * @param batch batch of records to evaluate
 * @param singleLaneBatchMaker receives every successfully evaluated record
 * @throws StageException declared by the signature; raised by the called stage code
 */
public void processUseRecordByRecord(Batch batch, SingleLaneBatchMaker singleLaneBatchMaker) throws StageException {
  for (Iterator<Record> records = batch.getRecords(); records.hasNext(); ) {
    Record record = records.next();
    setInputConfigFields(record);
    Session.Runner runner = this.session.runner();

    Map<Pair<String, Integer>, Tensor> inputs;
    try {
      inputs = convertRecord(record, conf.inputConfigs);
    } catch (OnRecordErrorException ex) {
      errorRecordHandler.onError(ex);
      continue;
    }

    try {
      for (Map.Entry<Pair<String, Integer>, Tensor> feedEntry : inputs.entrySet()) {
        Pair<String, Integer> target = feedEntry.getKey();
        runner.feed(target.getLeft(), target.getRight(), feedEntry.getValue());
      }
      for (TensorConfig outputConfig : conf.outputConfigs) {
        runner.fetch(outputConfig.operation, outputConfig.index);
      }

      List<Tensor<?>> tensorOutput = runner.run();
      LinkedHashMap<String, Field> outputFields = createOutputFieldValue(tensorOutput);
      record.set(conf.outputField, Field.create(outputFields));
      singleLaneBatchMaker.addRecord(record);
    } finally {
      for (Tensor input : inputs.values()) {
        input.close();
      }
    }
  }
}
Example 14
Source File: RNTensorflowInference.java From react-native-tensorflow with Apache License 2.0 | 4 votes |
/**
 * Creates a context that holds the given session, runner and graph and starts with an empty
 * map of output tensors.
 *
 * @param session session stored in this context
 * @param runner runner stored in this context
 * @param graph graph stored in this context
 */
TfContext(Session session, Session.Runner runner, Graph graph) {
    this.graph = graph;
    this.session = session;
    this.runner = runner;
    this.outputTensors = new HashMap<>();
}