org.tensorflow.SavedModelBundle Java Examples
The following examples show how to use
org.tensorflow.SavedModelBundle.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: ModelPredictTest.java From DeepMachineLearning with Apache License 2.0 | 7 votes |
@Test public void test() throws IOException, InterruptedException, IpssCacheException, ODMException, ExecutionException { AclfTrainDataGenerator gateway = new AclfTrainDataGenerator(); //read case String filename = "testdata/cases/ieee14.ieee"; gateway.loadCase(filename, "BusVoltLoadChangeTrainCaseBuilder"); //run loadflow gateway.trainCaseBuilder.createTestCase(); //generate input double[] inputs = gateway.trainCaseBuilder.getNetInput(); float[][] inputs_f = new float[1][inputs.length]; for (int i = 0; i < inputs.length; i++) { inputs_f[0][i] =(float) inputs[i]; } //read model SavedModelBundle bundle = SavedModelBundle.load("py/c_graph/single_net/model", "voltage"); //predict float[][] output = bundle.session().runner().feed("x", Tensor.create(inputs_f)).fetch("z").run().get(0) .copyTo(new float[1][28]); double[][] output_d = new double[1][inputs.length]; for (int i = 0; i < inputs.length; i++) { output_d[0][i] = output[0][i]; } //print out mismatch System.out.println("Model out mismatch: "+gateway.getMismatchInfo(output_d[0])); }
Example #2
Source File: GraphImporter.java From vespa with Apache License 2.0 | 6 votes |
/**
 * Imports the TensorFlow node with the given name, recursively importing its data and
 * control inputs first. If the node was already imported, the existing operation is returned.
 *
 * @param nodeName          name of the node (possibly with an output suffix)
 * @param tfGraph           the TensorFlow graph definition to read nodes from
 * @param intermediateGraph graph that imported operations are registered in
 * @param bundle            the loaded SavedModel, used to read constant/variable values lazily
 * @return the intermediate operation for this node
 */
private static IntermediateOperation importOperation(String nodeName,
                                                     GraphDef tfGraph,
                                                     IntermediateGraph intermediateGraph,
                                                     SavedModelBundle bundle) {
    if (intermediateGraph.alreadyImported(nodeName)) {
        return intermediateGraph.get(nodeName);
    }
    NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(nodeName), tfGraph);
    List<IntermediateOperation> inputs = importOperationInputs(node, tfGraph, intermediateGraph, bundle);
    IntermediateOperation operation = mapOperation(node, inputs, intermediateGraph);
    // Registered before control inputs are imported, so any control-dependency path that leads
    // back to this node hits the alreadyImported check above instead of re-importing it.
    intermediateGraph.put(nodeName, operation);

    List<IntermediateOperation> controlInputs = importControlInputs(node, tfGraph, intermediateGraph, bundle);
    if (controlInputs.size() > 0) {
        operation.setControlInputs(controlInputs);
    }
    if (operation.isConstant()) {
        // The value is read from the bundle on demand; the fetched TensorFlow tensor is closed
        // once it has been converted so its native memory is not leaked.
        operation.setConstantValueFunction(type -> {
            try (org.tensorflow.Tensor<?> variable = readVariable(nodeName, bundle)) {
                return new TensorValue(TensorConverter.toVespaTensor(variable, type));
            }
        });
    }
    return operation;
}
Example #3
Source File: InProcessClassification.java From hazelcast-jet-demos with Apache License 2.0 | 6 votes |
private static Pipeline buildPipeline(IMap<Long, String> reviewsMap) { // Set up the mapping context that loads the model on each member, shared // by all parallel processors on that member. ServiceFactory<Tuple2<SavedModelBundle, WordIndex>, Tuple2<SavedModelBundle, WordIndex>> modelContext = ServiceFactory .withCreateContextFn(context -> { File data = context.attachedDirectory("data"); SavedModelBundle bundle = SavedModelBundle.load(data.toPath().resolve("model/1").toString(), "serve"); return tuple2(bundle, new WordIndex(data)); }) .withDestroyContextFn(t -> t.f0().close()) .withCreateServiceFn((context, tuple2) -> tuple2); Pipeline p = Pipeline.create(); p.readFrom(Sources.map(reviewsMap)) .map(Map.Entry::getValue) .mapUsingService(modelContext, (tuple, review) -> classify(review, tuple.f0(), tuple.f1())) // TensorFlow executes models in parallel, we'll use 2 local threads to maximize throughput. .setLocalParallelism(2) .writeTo(Sinks.logger(t -> String.format("Sentiment rating for review \"%s\" is %.2f", t.f0(), t.f1()))); return p; }
Example #4
Source File: BlogEvaluationBenchmark.java From vespa with Apache License 2.0 | 6 votes |
public static void main(String[] args) throws ParseException { SavedModelBundle tensorFlowModel = SavedModelBundle.load(modelDir, "serve"); ImportedModel model = new TensorFlowImporter().importModel("blog", modelDir, tensorFlowModel); Context context = TestableTensorFlowModel.contextFrom(model); Tensor u = generateInputTensor(); Tensor d = generateInputTensor(); context.put("input_u", new TensorValue(u)); context.put("input_d", new TensorValue(d)); // Parse the ranking expression from imported string to force primitive tensor functions. RankingExpression expression = new RankingExpression(model.expressions().get("y").getRoot().toString()); benchmarkJava(expression, context, 20, 200); System.out.println("*** Optimizing expression ***"); ExpressionOptimizer optimizer = new ExpressionOptimizer(); OptimizationReport report = optimizer.optimize(expression, (ContextIndex)context); System.out.println(report.toString()); benchmarkJava(expression, context, 2000, 20000); benchmarkTensorFlow(tensorFlowModel, 2000, 20000); }
Example #5
Source File: TestableTensorFlowModel.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Evaluates the given operation directly in TensorFlow with a generated float or double input
 * (depending on {@code floatInput}) and converts the single result to a Vespa tensor.
 * All TensorFlow tensors created here are closed before returning.
 */
private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
    Session.Runner runner = model.session().runner();
    try (org.tensorflow.Tensor<?> input = floatInput ? tensorFlowFloatInputArgument() : tensorFlowDoubleInputArgument()) {
        runner.feed(inputName, input);
        List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
        try {
            assertEquals(1, results.size());
            return TensorConverter.toVespaTensor(results.get(0));
        }
        finally {
            // Fetched tensors are owned by the caller of run() and must be closed explicitly.
            results.forEach(org.tensorflow.Tensor::close);
        }
    }
}
Example #6
Source File: TestableTensorFlowModel.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Loads the SavedModel in {@code modelDir} (tag "serve") and imports it under {@code modelName}.
 *
 * @param d0Size     size of the first input dimension
 * @param d1Size     size of the second input dimension
 * @param floatInput whether generated inputs are float (otherwise double)
 */
public TestableTensorFlowModel(String modelName, String modelDir, int d0Size, int d1Size, boolean floatInput) {
    this.d0Size = d0Size;
    this.d1Size = d1Size;
    this.floatInput = floatInput;
    // The loaded bundle is kept as a field (presumably for native TensorFlow evaluation) and imported.
    this.tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
    this.model = new TensorFlowImporter().importModel(modelName, modelDir, this.tensorFlowModel);
}
Example #7
Source File: VariableConverter.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Reads the tensor with the given TensorFlow name at the given model location,
 * and encodes it as UTF-8 Vespa document tensor JSON having the given ordered tensor type.
 * Note that order of dimensions in the tensor type does matter as the TensorFlow
 * tensor dimensions are implicitly ordered.
 *
 * @throws IllegalArgumentException if the model or variable cannot be imported
 */
static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) {
    // Both the bundle and the fetched tensor hold TensorFlow resources; close them (tensor first).
    try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve");
         org.tensorflow.Tensor<?> variable = GraphImporter.readVariable(tensorFlowVariableName, bundle)) {
        return JsonFormat.encode(TensorConverter.toVespaTensor(variable,
                                                               OrderedTensorType.fromSpec(orderedTypeSpec)));
    }
    catch (IllegalArgumentException e) {
        throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
    }
}
Example #8
Source File: GraphImporter.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Fetches the value of the named node from the bundle's session.
 * The caller owns the returned tensor and is responsible for closing it.
 *
 * @throws IllegalStateException if the fetch does not produce exactly one tensor
 */
static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
    Session.Runner fetched = bundle.session().runner().fetch(name);
    List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
    if (importedTensors.size() != 1) {
        int count = importedTensors.size();
        // Release whatever was fetched before failing; otherwise these tensors leak.
        importedTensors.forEach(org.tensorflow.Tensor::close);
        throw new IllegalStateException("Expected 1 tensor from fetching " + name + ", but got " + count);
    }
    return importedTensors.get(0);
}
Example #9
Source File: GraphImporter.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Imports the control-dependency inputs of the given node (those for which
 * {@code isControlDependency} is true), in input-list order.
 */
private static List<IntermediateOperation> importControlInputs(NodeDef node,
                                                               GraphDef tfGraph,
                                                               IntermediateGraph intermediateGraph,
                                                               SavedModelBundle bundle) {
    return node.getInputList()
               .stream()
               .filter(inputName -> isControlDependency(inputName))
               .map(controlInput -> importOperation(controlInput, tfGraph, intermediateGraph, bundle))
               .collect(Collectors.toList());
}
Example #10
Source File: GraphImporter.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Imports the regular (non-control-dependency) data inputs of the given node, in input-list order.
 */
private static List<IntermediateOperation> importOperationInputs(NodeDef node,
                                                                 GraphDef tfGraph,
                                                                 IntermediateGraph intermediateGraph,
                                                                 SavedModelBundle bundle) {
    return node.getInputList()
               .stream()
               .filter(inputName -> ! isControlDependency(inputName))
               .map(dataInput -> importOperation(dataInput, tfGraph, intermediateGraph, bundle))
               .collect(Collectors.toList());
}
Example #11
Source File: GraphImporter.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Imports every output operation declared by every signature of the intermediate graph,
 * which in turn recursively imports everything those outputs depend on.
 */
private static void importOperations(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph, SavedModelBundle bundle) {
    GraphDef graphDef = tfGraph.getGraphDef();
    for (String signature : intermediateGraph.signatures()) {
        for (String output : intermediateGraph.outputs(signature).values()) {
            importOperation(output, graphDef, intermediateGraph, bundle);
        }
    }
}
Example #12
Source File: GraphImporter.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Parses the bundle's meta graph and builds the intermediate graph for it:
 * signatures, then operations, then output-type verification.
 *
 * @throws IOException if the serialized meta graph cannot be parsed
 */
static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException {
    MetaGraphDef metaGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef());
    IntermediateGraph graph = new IntermediateGraph(modelName);

    // importOperations walks graph.signatures()/outputs(), presumably filled in by importSignatures.
    importSignatures(metaGraph, graph);
    importOperations(metaGraph, graph, bundle);
    verifyOutputTypes(metaGraph, graph);

    return graph;
}
Example #13
Source File: TensorFlowImporter.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Imports a TensorFlow model from an already loaded SavedModel bundle.
 *
 * @param modelName the name to give the imported model
 * @param modelDir  the directory the bundle was loaded from
 * @param model     the loaded bundle; not closed by this method
 * @throws IllegalArgumentException if reading the model's graph fails with an IOException
 */
public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) {
    try {
        IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
        return convertIntermediateGraphToModel(graph, modelDir);
    }
    catch (IOException e) {
        // Identify the model by name and directory; the bundle's toString() is not meaningful.
        throw new IllegalArgumentException("Could not import TensorFlow model '" + modelName +
                                           "' from directory '" + modelDir + "'", e);
    }
}
Example #14
Source File: TensorFlowImporter.java From vespa with Apache License 2.0 | 5 votes |
/** * Imports a saved TensorFlow model from a directory. * The model should be saved as a .pbtxt or .pb file. * * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_] * @param modelDir the directory containing the TensorFlow model files to import */ @Override public ImportedModel importModel(String modelName, String modelDir) { // Temporary (for testing): if path contains "tf_2_onnx", convert to ONNX then import that model. if (modelDir.contains("tf_2_onnx")) { return convertToOnnxAndImport(modelName, modelDir); } try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { return importModel(modelName, modelDir, model); } catch (IllegalArgumentException e) { throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); } }
Example #15
Source File: TensorFlowModelProducer.java From samantha with MIT License | 5 votes |
/**
 * Creates a TensorFlowModel backed by the SavedModel in {@code exportDir}.
 * If the export directory does not exist, graph and session are passed as null.
 */
public TensorFlowModel createTensorFlowModelModelFromExportDir(
        String modelName, SpaceMode spaceMode, String exportDir,
        List<String> groupKeys, List<List<String>> equalSizeChecks,
        List<String> indexKeys, List<FeatureExtractor> featureExtractors,
        String predItemFea, String lossOper, String updateOper,
        String outputOper, String initOper, String topKOper,
        String topKId, String topKValue, String itemIndex) {
    IndexSpace indexSpace = getIndexSpace(modelName, spaceMode, indexKeys);
    VariableSpace variableSpace = getVariableSpace(modelName, spaceMode);

    SavedModelBundle bundle = loadTensorFlowSavedModel(exportDir);
    boolean haveBundle = bundle != null;
    Session session = haveBundle ? bundle.session() : null;
    Graph graph = haveBundle ? bundle.graph() : null;

    return new TensorFlowModel(graph, session, null, exportDir,
            indexSpace, variableSpace, featureExtractors,
            predItemFea, lossOper, updateOper,
            topKId, itemIndex, topKValue,
            outputOper, topKOper, initOper,
            groupKeys, equalSizeChecks);
}
Example #16
Source File: TensorFlowModelProducer.java From samantha with MIT License | 5 votes |
/**
 * Loads the SavedModel in {@code exportDir} with {@code TensorFlowModel.SAVED_MODEL_TAG}.
 *
 * @return the loaded bundle, or null (after logging a warning) if the directory does not exist
 */
public static SavedModelBundle loadTensorFlowSavedModel(String exportDir) {
    if (!new File(exportDir).exists()) {
        logger.warn("TensorFlow exported model dir does not exist: {}.", exportDir);
        return null;
    }
    return SavedModelBundle.load(exportDir, TensorFlowModel.SAVED_MODEL_TAG);
}
Example #17
Source File: BlogEvaluationTestCase.java From vespa with Apache License 2.0 | 5 votes |
/** Verifies that the blog model imports with a "serving_default.y" signature taking no inputs. */
@Test
public void testImport() {
    // Close the bundle after the test so TensorFlow resources are not leaked across tests.
    try (SavedModelBundle tensorFlowModel = SavedModelBundle.load(modelDir, "serve")) {
        ImportedModel model = new TensorFlowImporter().importModel("blog", modelDir, tensorFlowModel);

        ImportedModel.Signature y = model.signature("serving_default.y");
        assertNotNull(y);
        assertEquals(0, y.inputs().size());
    }
}
Example #18
Source File: BlogEvaluationBenchmark.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Times native TensorFlow evaluation of the model: a warmup run, then a timed run,
 * printing total and average evaluation time in milliseconds.
 */
private static void benchmarkTensorFlow(SavedModelBundle tensorFlowModel, int warmup, int iterations) {
    // The generated input tensors hold TensorFlow resources; release them when the benchmark ends.
    try (org.tensorflow.Tensor<?> u = generateInputTensorFlow();
         org.tensorflow.Tensor<?> d = generateInputTensorFlow()) {
        System.out.println("*** TensorFlow evaluation - warmup ***");
        evaluateTensorflow(tensorFlowModel, u, d, warmup);

        System.gc();

        System.out.println("*** TensorFlow evaluation - " + iterations + " iterations ***");
        // nanoTime is a long; subtract as longs before converting to avoid precision loss.
        long startTime = System.nanoTime();
        evaluateTensorflow(tensorFlowModel, u, d, iterations);
        long endTime = System.nanoTime();
        double elapsedMs = (endTime - startTime) / 1_000_000.0;
        System.out.println("Model evaluation time is " + elapsedMs + " ms");
        System.out.println("Average model evaluation time is " + elapsedMs / iterations + " ms");
    }
}
Example #19
Source File: BlogEvaluationBenchmark.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Runs the model's "y" output {@code iterations} times with the given inputs.
 *
 * @return the sum of the last result's cells (0 if {@code iterations} is not positive)
 */
private static double evaluateTensorflow(SavedModelBundle tensorFlowModel,
                                         org.tensorflow.Tensor<?> u,
                                         org.tensorflow.Tensor<?> d,
                                         int iterations) {
    double result = 0;
    for (int i = 0; i < iterations; i++) {
        Session.Runner runner = tensorFlowModel.session().runner();
        runner.feed("input_u", u);
        runner.feed("input_d", d);
        List<org.tensorflow.Tensor<?>> results = runner.fetch("y").run();
        try {
            result = TensorConverter.toVespaTensor(results.get(0)).sum().asDouble();
        }
        finally {
            // Without this, every iteration leaks the fetched output tensor.
            results.forEach(org.tensorflow.Tensor::close);
        }
    }
    return result;
}
Example #20
Source File: Tf2OnnxImportTestCase.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Converts the TensorFlow model to ONNX with the given opset, imports both, and checks that
 * native TensorFlow, the imported TensorFlow model and the imported ONNX model agree.
 *
 * @return the result of reportAndFail/reportAndSucceed for this opset
 */
private boolean testModelWithOpset(Report report, int opset, String tfModel) throws IOException {
    String onnxModel = Paths.get(testFolder.getRoot().getAbsolutePath(), "converted.onnx").toString();

    var res = tf2onnxConvert(tfModel, onnxModel, opset);
    if (res.getFirst() != 0) {
        return reportAndFail(report, opset, tfModel, "tf2onnx conversion failed: " + res.getSecond());
    }

    // The bundle is closed on every return path below.
    try (SavedModelBundle tensorFlowModel = SavedModelBundle.load(tfModel, "serve")) {
        ImportedModel model = new TensorFlowImporter().importModel("test", tfModel, tensorFlowModel);
        ImportedModel onnxImportedModel = new OnnxImporter().importModel("test", onnxModel);

        if (model.signature("serving_default").skippedOutputs().size() > 0) {
            return reportAndFail(report, opset, tfModel, "Failed to import model from TensorFlow due to skipped outputs");
        }
        if (onnxImportedModel.signature("default").skippedOutputs().size() > 0) {
            return reportAndFail(report, opset, tfModel, "Failed to import model from ONNX due to skipped outputs");
        }

        ImportedModel.Signature sig = model.signatures().values().iterator().next();
        String output = sig.outputs().values().iterator().next();
        String onnxOutput = onnxImportedModel.signatures().values().iterator().next().outputs().values().iterator().next();

        Tensor tfResult = evaluateTF(tensorFlowModel, output, model.inputs());
        Tensor vespaResult = evaluateVespa(model, output, model.inputs());
        Tensor onnxResult = evaluateVespa(onnxImportedModel, onnxOutput, model.inputs());

        if ( ! tfResult.equals(vespaResult) ) {
            return reportAndFail(report, opset, tfModel, "Diff between tf and imported tf evaluation:\n\t" + tfResult + "\n\t" + vespaResult);
        }
        if ( ! vespaResult.equals(onnxResult) ) {
            return reportAndFail(report, opset, tfModel, "Diff between imported tf eval and onnx eval:\n\t" + vespaResult + "\n\t" + onnxResult);
        }

        return reportAndSucceed(report, opset, tfModel, "Ok");
    }
}
Example #21
Source File: TestableModel.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Evaluates the given operation of a SavedModel directly in TensorFlow, feeding one generated
 * argument per entry in {@code inputs}, and converts the single fetched result to a Vespa tensor.
 *
 * @param tensorFlowModel the loaded SavedModel to run
 * @param operationName   name of the operation to fetch
 * @param inputs          input names mapped to their Vespa tensor types; dimension 1 of each type
 *                        must have a known size, since it is read with {@code size().get()}
 * @return the fetched result converted to a Vespa tensor
 */
Tensor evaluateTF(SavedModelBundle tensorFlowModel, String operationName, Map<String, TensorType> inputs) {
    Session.Runner runner = tensorFlowModel.session().runner();
    for (Map.Entry<String, TensorType> entry : inputs.entrySet()) {
        // Each argument is generated with (1, size of dimension 1); presumably a batch of one row — confirm.
        try {
            runner.feed(entry.getKey(), tensorFlowFloatInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue()));
        }
        catch (Exception e) {
            // NOTE(review): fallback to a double argument if feeding a float one throws. In TensorFlow
            // Java 1.x, Runner.feed appears not to check dtypes (a mismatch would only fail in run()),
            // so this branch may rarely trigger — verify. Fed and fetched tensors are never closed here.
            runner.feed(entry.getKey(), tensorFlowDoubleInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue()));
        }
    }
    List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
    // Exactly one fetched result is expected.
    assertEquals(1, results.size());
    return TensorConverter.toVespaTensor(results.get(0));
}
Example #22
Source File: InProcessClassification.java From hazelcast-jet-demos with Apache License 2.0 | 5 votes |
/**
 * Scores the sentiment of a single review with the model.
 *
 * @return the review paired with the model's sigmoid output
 */
private static Tuple2<String, Float> classify(
        String review, SavedModelBundle model, WordIndex wordIndex
) {
    // Both the input and the fetched output tensor are closed when scoring finishes.
    try (Tensor<Float> reviewTensor = Tensors.create(wordIndex.createTensorInput(review));
         Tensor<?> sentiment = model.session()
                                    .runner()
                                    .feed("embedding_input:0", reviewTensor)
                                    .fetch("dense_1/Sigmoid:0")
                                    .run()
                                    .get(0)) {
        float[][] score = sentiment.copyTo(new float[1][1]);
        return tuple2(review, score[0][0]);
    }
}
Example #23
Source File: TensorFlowModel.java From zoltar with Apache License 2.0 | 5 votes |
/**
 * Parses the serialized MetaGraphDef carried by the bundle.
 *
 * @throws TensorflowMetaGraphDefParsingException if the bytes are not a valid MetaGraphDef
 */
private static MetaGraphDef extractMetaGraphDefinition(final SavedModelBundle bundle)
    throws TensorflowMetaGraphDefParsingException {
  try {
    return MetaGraphDef.parseFrom(bundle.metaGraphDef());
  } catch (final InvalidProtocolBufferException parseFailure) {
    throw new TensorflowMetaGraphDefParsingException(
        "Failed parsing tensorflow metagraph " + "definition", parseFailure);
  }
}
Example #24
Source File: TensorFlowModel.java From zoltar with Apache License 2.0 | 5 votes |
/** * Note: Please use Models from zoltar-models module. * * <p>Returns a TensorFlow model with metadata given {@link SavedModelBundle} export directory URI * and {@link Options}. */ public static TensorFlowModel create( final Model.Id id, final URI modelResource, final Options options, final String signatureDefinition) throws IOException { // GCS requires that directory URIs have a trailing slash, so add the slash if it's missing // and the URI starts with 'gs'. final URI normalizedUri = !CloudStorageFileSystem.URI_SCHEME.equalsIgnoreCase(modelResource.getScheme()) || modelResource.toString().endsWith("/") ? modelResource : URI.create(modelResource.toString() + "/"); final URI localDir = FileSystemExtras.downloadIfNonLocal(normalizedUri); final SavedModelBundle model = SavedModelBundle.load(localDir.toString(), options.tags().toArray(new String[0])); final MetaGraphDef metaGraphDef; try { metaGraphDef = extractMetaGraphDefinition(model); } catch (TensorflowMetaGraphDefParsingException e) { throw new IOException(e); } final SignatureDef signatureDef = metaGraphDef.getSignatureDefOrThrow(signatureDefinition); return new AutoValue_TensorFlowModel( id, model, options, metaGraphDef, signatureDef, toNameMap(signatureDef.getInputsMap()), toNameMap(signatureDef.getOutputsMap())); }
Example #25
Source File: Bert.java From easy-bert with MIT License | 5 votes |
/**
 * Wraps a loaded BERT SavedModel, building a tokenizer from the vocabulary file and
 * caching the ids of the start and separator tokens.
 */
private Bert(final SavedModelBundle bundle, final ModelDetails model, final Path vocabulary) {
    this.bundle = bundle;
    this.model = model;
    tokenizer = new FullTokenizer(vocabulary, model.doLowerCase);

    final int[] specialTokenIds = tokenizer.convert(new String[] {START_TOKEN, SEPARATOR_TOKEN});
    startTokenId = specialTokenIds[0];
    separatorTokenId = specialTokenIds[1];
}
Example #26
Source File: Bert.java From easy-bert with MIT License | 5 votes |
/**
 * Loads a pre-trained BERT model from a TensorFlow saved model saved by the easy-bert Python utilities
 *
 * @param path
 *        the path to load the model from
 * @return a ready-to-use BERT model
 * @throws java.io.UncheckedIOException
 *         if the model details file cannot be read
 * @since 1.0.3
 */
public static Bert load(Path path) {
    path = path.toAbsolutePath();
    final Path assets = path.resolve("assets");
    final Path detailsFile = assets.resolve(MODEL_DETAILS);

    final ModelDetails model;
    try {
        model = new ObjectMapper().readValue(detailsFile.toFile(), ModelDetails.class);
    } catch(final IOException e) {
        // UncheckedIOException is still a RuntimeException, so existing callers are unaffected.
        throw new java.io.UncheckedIOException("Could not read BERT model details from " + detailsFile, e);
    }

    final SavedModelBundle bundle = SavedModelBundle.load(path.toString(), "serve");
    try {
        return new Bert(bundle, model, assets.resolve(VOCAB_FILE));
    } catch(final RuntimeException e) {
        // Don't leak the loaded TensorFlow graph/session if building the tokenizer fails.
        bundle.close();
        throw e;
    }
}
Example #27
Source File: TfModel.java From djl with Apache License 2.0 | 5 votes |
/** {@inheritDoc} */
@Override
public void load(Path modelPath, String prefix, Map<String, Object> options)
        throws FileNotFoundException {
    modelDir = modelPath.toAbsolutePath();
    String modelPrefix = prefix == null ? modelName : prefix;

    // Look for the model by prefix first, then fall back to a plain saved_model.pb.
    Path exportDir = findModleDir(modelPrefix);
    if (exportDir == null) {
        exportDir = findModleDir("saved_model.pb");
    }
    if (exportDir == null) {
        throw new FileNotFoundException("No TensorFlow model found in: " + modelDir);
    }

    // Optional loader settings supplied by the caller.
    String[] tags = options == null ? null : (String[]) options.get("Tags");
    ConfigProto proto = options == null ? null : (ConfigProto) options.get("ConfigProto");
    RunOptions runOptions = options == null ? null : (RunOptions) options.get("RunOptions");

    SavedModelBundle.Loader loader = SavedModelBundle.loader(exportDir.toString())
            .withTags(tags == null ? new String[] {"serve"} : tags);
    if (proto != null) {
        loader.withConfigProto(proto);
    }
    if (runOptions != null) {
        loader.withRunOptions(runOptions);
    }
    block = new TfSymbolBlock(loader.load());
}
Example #28
Source File: ObjectDetector.java From OpenLabeler with Apache License 2.0 | 5 votes |
private Void update(Path path) { try { File savedModelFile = new File(Settings.getTFSavedModelDir()); if (savedModelFile.exists() && (path == null || "saved_model".equals(path.toString()))) { if (path != null) { // coming from watched file Thread.sleep(5000); // Wait a while for model to be exported } synchronized (ObjectDetector.this) { if (model != null) { model.close(); } model = SavedModelBundle.load(savedModelFile.getAbsolutePath(), "serve"); String message = MessageFormat.format(bundle.getString("msg.loadedSavedModel"), savedModelFile); LOG.info(message); printSignature(model); Platform.runLater(() -> statusProperty.set(message)); } } else if (!savedModelFile.exists() && path == null) { LOG.info(savedModelFile.toString() + " does not exist"); } } catch (Exception ex) { LOG.log(Level.SEVERE, "Unable to update " + path, ex); } return null; }
Example #29
Source File: TensorFlowProcessor.java From datacollector with Apache License 2.0 | 4 votes |
/**
 * Validates the configuration, loads the SavedModel (resolving relative paths against the
 * resources directory), and prepares the input-config lookup and EL evaluators.
 *
 * @return configuration issues; non-empty if the model path is missing or the model fails to load
 */
@Override
protected List<ConfigIssue> init() {
    List<ConfigIssue> issues = super.init();
    String[] modelTags = conf.modelTags.toArray(new String[conf.modelTags.size()]);

    if (Strings.isNullOrEmpty(conf.modelPath)) {
        issues.add(getContext().createConfigIssue(
            Groups.TENSOR_FLOW.name(),
            TensorFlowConfigBean.MODEL_PATH_CONFIG,
            Errors.TENSOR_FLOW_01
        ));
        return issues;
    }

    try {
        File configuredDir = new File(conf.modelPath);
        File modelDir = configuredDir.isAbsolute()
            ? configuredDir
            : new File(getContext().getResourcesDirectory(), conf.modelPath).getAbsoluteFile();
        this.savedModel = SavedModelBundle.load(modelDir.getAbsolutePath(), modelTags);
    } catch (TensorFlowException ex) {
        issues.add(getContext().createConfigIssue(
            Groups.TENSOR_FLOW.name(),
            TensorFlowConfigBean.MODEL_PATH_CONFIG,
            Errors.TENSOR_FLOW_02,
            ex
        ));
        return issues;
    }

    this.session = this.savedModel.session();
    // Index input configs by (operation, index) for lookup.
    this.conf.inputConfigs.forEach(cfg -> {
        Pair<String, Integer> operationAndIndex = Pair.of(cfg.operation, cfg.index);
        inputConfigMap.put(operationAndIndex, cfg);
    });

    fieldPathEval = getContext().createELEval("conf.inputConfigs");
    fieldPathVars = getContext().createELVars();
    errorRecordHandler = new DefaultErrorRecordHandler(getContext());
    return issues;
}
Example #30
Source File: TensorflowSavedModel.java From tutorials with MIT License | 4 votes |
/**
 * Loads the SavedModel in ./model, feeds 3 to both "x" and "y", and prints the integer
 * value fetched from "z".
 */
public static void main(String[] args) {
    // The bundle and every tensor are closed to release TensorFlow resources.
    try (SavedModelBundle model = SavedModelBundle.load("./model", "serve");
         Tensor<Integer> x = Tensor.create(3, Integer.class);
         Tensor<Integer> y = Tensor.create(3, Integer.class);
         Tensor<Integer> z = model.session()
                 .runner()
                 .feed("x", x)
                 .feed("y", y)
                 .fetch("z")
                 .run()
                 .get(0)
                 .expect(Integer.class)) {
        System.out.println(z.intValue());
    }
}