Java Code Examples for org.tensorflow.SavedModelBundle#load()
The following examples show how to use
org.tensorflow.SavedModelBundle#load().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: ModelPredictTest.java From DeepMachineLearning with Apache License 2.0 | 7 votes |
@Test public void test() throws IOException, InterruptedException, IpssCacheException, ODMException, ExecutionException { AclfTrainDataGenerator gateway = new AclfTrainDataGenerator(); //read case String filename = "testdata/cases/ieee14.ieee"; gateway.loadCase(filename, "BusVoltLoadChangeTrainCaseBuilder"); //run loadflow gateway.trainCaseBuilder.createTestCase(); //generate input double[] inputs = gateway.trainCaseBuilder.getNetInput(); float[][] inputs_f = new float[1][inputs.length]; for (int i = 0; i < inputs.length; i++) { inputs_f[0][i] =(float) inputs[i]; } //read model SavedModelBundle bundle = SavedModelBundle.load("py/c_graph/single_net/model", "voltage"); //predict float[][] output = bundle.session().runner().feed("x", Tensor.create(inputs_f)).fetch("z").run().get(0) .copyTo(new float[1][28]); double[][] output_d = new double[1][inputs.length]; for (int i = 0; i < inputs.length; i++) { output_d[0][i] = output[0][i]; } //print out mismatch System.out.println("Model out mismatch: "+gateway.getMismatchInfo(output_d[0])); }
Example 2
Source File: BlogEvaluationBenchmark.java From vespa with Apache License 2.0 | 6 votes |
public static void main(String[] args) throws ParseException { SavedModelBundle tensorFlowModel = SavedModelBundle.load(modelDir, "serve"); ImportedModel model = new TensorFlowImporter().importModel("blog", modelDir, tensorFlowModel); Context context = TestableTensorFlowModel.contextFrom(model); Tensor u = generateInputTensor(); Tensor d = generateInputTensor(); context.put("input_u", new TensorValue(u)); context.put("input_d", new TensorValue(d)); // Parse the ranking expression from imported string to force primitive tensor functions. RankingExpression expression = new RankingExpression(model.expressions().get("y").getRoot().toString()); benchmarkJava(expression, context, 20, 200); System.out.println("*** Optimizing expression ***"); ExpressionOptimizer optimizer = new ExpressionOptimizer(); OptimizationReport report = optimizer.optimize(expression, (ContextIndex)context); System.out.println(report.toString()); benchmarkJava(expression, context, 2000, 20000); benchmarkTensorFlow(tensorFlowModel, 2000, 20000); }
Example 3
Source File: ObjectDetector.java From OpenLabeler with Apache License 2.0 | 5 votes |
private Void update(Path path) { try { File savedModelFile = new File(Settings.getTFSavedModelDir()); if (savedModelFile.exists() && (path == null || "saved_model".equals(path.toString()))) { if (path != null) { // coming from watched file Thread.sleep(5000); // Wait a while for model to be exported } synchronized (ObjectDetector.this) { if (model != null) { model.close(); } model = SavedModelBundle.load(savedModelFile.getAbsolutePath(), "serve"); String message = MessageFormat.format(bundle.getString("msg.loadedSavedModel"), savedModelFile); LOG.info(message); printSignature(model); Platform.runLater(() -> statusProperty.set(message)); } } else if (!savedModelFile.exists() && path == null) { LOG.info(savedModelFile.toString() + " does not exist"); } } catch (Exception ex) { LOG.log(Level.SEVERE, "Unable to update " + path, ex); } return null; }
Example 4
Source File: Bert.java From easy-bert with MIT License | 5 votes |
/**
 * Loads a pre-trained BERT model from a TensorFlow saved model written by the easy-bert Python
 * utilities.
 *
 * <p>The model details and vocabulary are read from the {@code assets} sub-directory of the
 * saved model; the graph itself is loaded with the {@code "serve"} tag.
 *
 * @param path
 *        the path to load the model from
 * @return a ready-to-use BERT model
 * @since 1.0.3
 */
public static Bert load(Path path) {
    final Path modelRoot = path.toAbsolutePath();
    final Path assetsDir = modelRoot.resolve("assets");

    final ModelDetails details;
    try {
        details = new ObjectMapper().readValue(assetsDir.resolve(MODEL_DETAILS).toFile(), ModelDetails.class);
    } catch (final IOException e) {
        throw new RuntimeException(e);
    }

    final SavedModelBundle savedModel = SavedModelBundle.load(modelRoot.toString(), "serve");
    return new Bert(savedModel, details, assetsDir.resolve(VOCAB_FILE));
}
Example 5
Source File: TensorFlowModel.java From zoltar with Apache License 2.0 | 5 votes |
/** * Note: Please use Models from zoltar-models module. * * <p>Returns a TensorFlow model with metadata given {@link SavedModelBundle} export directory URI * and {@link Options}. */ public static TensorFlowModel create( final Model.Id id, final URI modelResource, final Options options, final String signatureDefinition) throws IOException { // GCS requires that directory URIs have a trailing slash, so add the slash if it's missing // and the URI starts with 'gs'. final URI normalizedUri = !CloudStorageFileSystem.URI_SCHEME.equalsIgnoreCase(modelResource.getScheme()) || modelResource.toString().endsWith("/") ? modelResource : URI.create(modelResource.toString() + "/"); final URI localDir = FileSystemExtras.downloadIfNonLocal(normalizedUri); final SavedModelBundle model = SavedModelBundle.load(localDir.toString(), options.tags().toArray(new String[0])); final MetaGraphDef metaGraphDef; try { metaGraphDef = extractMetaGraphDefinition(model); } catch (TensorflowMetaGraphDefParsingException e) { throw new IOException(e); } final SignatureDef signatureDef = metaGraphDef.getSignatureDefOrThrow(signatureDefinition); return new AutoValue_TensorFlowModel( id, model, options, metaGraphDef, signatureDef, toNameMap(signatureDef.getInputsMap()), toNameMap(signatureDef.getOutputsMap())); }
Example 6
Source File: TensorFlowModelProducer.java From samantha with MIT License | 5 votes |
/**
 * Loads the TensorFlow saved model exported to {@code exportDir} using
 * {@code TensorFlowModel.SAVED_MODEL_TAG}.
 *
 * @param exportDir directory the model was exported to
 * @return the loaded bundle, or {@code null} (after logging a warning) if the directory does not exist
 */
public static SavedModelBundle loadTensorFlowSavedModel(String exportDir) {
    boolean missing = !new File(exportDir).exists();
    if (missing) {
        logger.warn("TensorFlow exported model dir does not exist: {}.", exportDir);
        return null;
    }
    return SavedModelBundle.load(exportDir, TensorFlowModel.SAVED_MODEL_TAG);
}
Example 7
Source File: TensorFlowImporter.java From vespa with Apache License 2.0 | 5 votes |
/** * Imports a saved TensorFlow model from a directory. * The model should be saved as a .pbtxt or .pb file. * * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_] * @param modelDir the directory containing the TensorFlow model files to import */ @Override public ImportedModel importModel(String modelName, String modelDir) { // Temporary (for testing): if path contains "tf_2_onnx", convert to ONNX then import that model. if (modelDir.contains("tf_2_onnx")) { return convertToOnnxAndImport(modelName, modelDir); } try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { return importModel(modelName, modelDir, model); } catch (IllegalArgumentException e) { throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); } }
Example 8
Source File: VariableConverter.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Reads the tensor with the given TensorFlow name from the saved model at the given location,
 * and encodes it as UTF-8 Vespa document tensor JSON having the given ordered tensor type.
 * Note that the order of dimensions in the tensor type matters, as TensorFlow tensor
 * dimensions are implicitly ordered.
 *
 * @throws IllegalArgumentException if the model cannot be loaded or the variable cannot be converted
 */
static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) {
    try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) {
        // Same evaluation order as a nested call: read the variable first, then parse the type.
        var variable = GraphImporter.readVariable(tensorFlowVariableName, bundle);
        var orderedType = OrderedTensorType.fromSpec(orderedTypeSpec);
        var vespaTensor = TensorConverter.toVespaTensor(variable, orderedType);
        return JsonFormat.encode(vespaTensor);
    }
    catch (IllegalArgumentException cause) {
        throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", cause);
    }
}
Example 9
Source File: TestableTensorFlowModel.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Loads the TensorFlow saved model in {@code modelDir} (tag "serve") and imports it under
 * {@code modelName}, remembering the input dimensions and input type used by the tests.
 */
public TestableTensorFlowModel(String modelName, String modelDir, int d0Size, int d1Size, boolean floatInput) {
    this.d0Size = d0Size;
    this.d1Size = d1Size;
    this.floatInput = floatInput;

    this.tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
    this.model = new TensorFlowImporter().importModel(modelName, modelDir, this.tensorFlowModel);
}
Example 10
Source File: BlogEvaluationTestCase.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Imports the "blog" model and checks that its "serving_default.y" signature exists and takes no inputs.
 */
@Test
public void testImport() {
    // Close the bundle after the import so the test does not leak native TensorFlow memory.
    try (SavedModelBundle tensorFlowModel = SavedModelBundle.load(modelDir, "serve")) {
        ImportedModel model = new TensorFlowImporter().importModel("blog", modelDir, tensorFlowModel);

        ImportedModel.Signature y = model.signature("serving_default.y");
        assertNotNull(y);
        assertEquals(0, y.inputs().size());
    }
}
Example 11
Source File: Tf2OnnxImportTestCase.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Converts {@code tfModel} to ONNX at the given opset, then checks that TensorFlow evaluation,
 * evaluation of the imported TensorFlow model and evaluation of the imported ONNX model agree.
 *
 * @return the outcome recorded via {@code reportAndSucceed} / {@code reportAndFail}
 */
private boolean testModelWithOpset(Report report, int opset, String tfModel) throws IOException {
    String onnxModel = Paths.get(testFolder.getRoot().getAbsolutePath(), "converted.onnx").toString();

    var res = tf2onnxConvert(tfModel, onnxModel, opset);
    if (res.getFirst() != 0) {
        return reportAndFail(report, opset, tfModel, "tf2onnx conversion failed: " + res.getSecond());
    }

    // The bundle owns native TensorFlow resources; try-with-resources releases them on every
    // return path below, including the early failure exits.
    try (SavedModelBundle tensorFlowModel = SavedModelBundle.load(tfModel, "serve")) {
        ImportedModel model = new TensorFlowImporter().importModel("test", tfModel, tensorFlowModel);
        ImportedModel onnxImportedModel = new OnnxImporter().importModel("test", onnxModel);

        if (model.signature("serving_default").skippedOutputs().size() > 0) {
            return reportAndFail(report, opset, tfModel, "Failed to import model from TensorFlow due to skipped outputs");
        }
        if (onnxImportedModel.signature("default").skippedOutputs().size() > 0) {
            return reportAndFail(report, opset, tfModel, "Failed to import model from ONNX due to skipped outputs");
        }

        ImportedModel.Signature sig = model.signatures().values().iterator().next();
        String output = sig.outputs().values().iterator().next();
        String onnxOutput = onnxImportedModel.signatures().values().iterator().next().outputs().values().iterator().next();

        Tensor tfResult = evaluateTF(tensorFlowModel, output, model.inputs());
        Tensor vespaResult = evaluateVespa(model, output, model.inputs());
        Tensor onnxResult = evaluateVespa(onnxImportedModel, onnxOutput, model.inputs());

        if ( ! tfResult.equals(vespaResult) ) {
            return reportAndFail(report, opset, tfModel, "Diff between tf and imported tf evaluation:\n\t" + tfResult + "\n\t" + vespaResult);
        }
        if ( ! vespaResult.equals(onnxResult) ) {
            return reportAndFail(report, opset, tfModel, "Diff between imported tf eval and onnx eval:\n\t" + vespaResult + "\n\t" + onnxResult);
        }

        return reportAndSucceed(report, opset, tfModel, "Ok");
    }
}
Example 12
Source File: EstimatorTest.java From jpmml-tensorflow with GNU Affero General Public License v3.0 | 4 votes |
/**
 * Creates an integration-test batch whose PMML is produced by converting the matching
 * TensorFlow saved model found on the test classpath.
 */
@Override
protected ArchiveBatch createBatch(String name, String dataset, Predicate<FieldName> predicate){
	return new IntegrationTestBatch(name, dataset, predicate){

		@Override
		public IntegrationTest getIntegrationTest(){
			return EstimatorTest.this;
		}

		/**
		 * Loads the saved model (tag "serve"), encodes it as PMML and validates the result.
		 */
		@Override
		public PMML getPMML() throws Exception {
			String exportDir = getSavedModelDir().getAbsolutePath();
			SavedModelBundle bundle = SavedModelBundle.load(exportDir, "serve");

			try(SavedModel savedModel = new SavedModel(bundle)){
				Estimator estimator = EstimatorFactory.newInstance().newEstimator(savedModel);

				PMML pmml = estimator.encodePMML();
				ensureValidity(pmml);

				return pmml;
			}
		}

		/**
		 * Locates "savedmodel/&lt;name&gt;&lt;dataset&gt;/saved_model.pbtxt" on the classpath and
		 * returns its parent directory.
		 */
		private File getSavedModelDir() throws IOException, URISyntaxException {
			String protoPath = "savedmodel/" + getName() + getDataset() + "/saved_model.pbtxt";

			URL protoResource = EstimatorTest.this.getClass().getClassLoader().getResource(protoPath);
			if(protoResource == null){
				throw new NoSuchFileException(protoPath);
			}

			return Paths.get(protoResource.toURI()).toFile().getParentFile();
		}
	};
}
Example 13
Source File: TensorFlowProcessor.java From datacollector with Apache License 2.0 | 4 votes |
/**
 * Validates the configured model path, loads the saved model with the configured tags and
 * prepares the session, input lookup table and EL helpers.
 *
 * @return the configuration issues found; a missing path or a TensorFlow load failure is reported
 *         here instead of being thrown
 */
@Override
protected List<ConfigIssue> init() {
  List<ConfigIssue> issues = super.init();
  String[] modelTags = conf.modelTags.toArray(new String[conf.modelTags.size()]);

  if (Strings.isNullOrEmpty(conf.modelPath)) {
    issues.add(getContext().createConfigIssue(
        Groups.TENSOR_FLOW.name(),
        TensorFlowConfigBean.MODEL_PATH_CONFIG,
        Errors.TENSOR_FLOW_01
    ));
    return issues;
  }

  try {
    // Relative model paths are resolved against the stage's resources directory.
    File configuredDir = new File(conf.modelPath);
    File exportedModelDir = configuredDir.isAbsolute()
        ? configuredDir
        : new File(getContext().getResourcesDirectory(), conf.modelPath).getAbsoluteFile();
    this.savedModel = SavedModelBundle.load(exportedModelDir.getAbsolutePath(), modelTags);
  } catch (TensorFlowException ex) {
    issues.add(getContext().createConfigIssue(
        Groups.TENSOR_FLOW.name(),
        TensorFlowConfigBean.MODEL_PATH_CONFIG,
        Errors.TENSOR_FLOW_02,
        ex
    ));
    return issues;
  }

  this.session = this.savedModel.session();

  // Index input configurations by (operation, index).
  this.conf.inputConfigs.forEach(cfg ->
      inputConfigMap.put(Pair.of(cfg.operation, cfg.index), cfg)
  );

  fieldPathEval = getContext().createELEval("conf.inputConfigs");
  fieldPathVars = getContext().createELVars();
  errorRecordHandler = new DefaultErrorRecordHandler(getContext());

  return issues;
}
Example 14
Source File: TensorflowSavedModel.java From tutorials with MIT License | 4 votes |
/**
 * Loads the saved model in "./model", feeds the integer scalar 3 to both inputs "x" and "y",
 * and prints the integer value fetched from output "z".
 */
public static void main(String[] args) {
    // Bundle and tensors all hold native TensorFlow memory; close every one of them.
    try (SavedModelBundle model = SavedModelBundle.load("./model", "serve");
         Tensor<Integer> x = Tensor.create(3, Integer.class);
         Tensor<Integer> y = Tensor.create(3, Integer.class);
         Tensor<Integer> result = model.session()
             .runner()
             .fetch("z")
             .feed("x", x)
             .feed("y", y)
             .run()
             .get(0)
             .expect(Integer.class)) {
        System.out.println(result.intValue());
    }
}