org.tensorflow.framework.ConfigProto Java Examples
The following examples show how to use org.tensorflow.framework.ConfigProto.
You can vote up the examples you find useful and vote down the ones you don't,
and follow the links above each example to its original project or source file. Related API usage is listed in the sidebar.
Example #1
Source File: GraphRunner.java From deeplearning4j with Apache License 2.0 | 6 votes |
public static org.tensorflow.framework.ConfigProto getAlignedWithNd4j() { org.tensorflow.framework.ConfigProto configProto = org.tensorflow.framework.ConfigProto.getDefaultInstance(); ConfigProto.Builder builder1 = configProto.toBuilder().addDeviceFilters(TensorflowConversion.defaultDeviceForThread()); try { //cuda if(Nd4j.getBackend().getClass().getName().toLowerCase().contains("jcu")) { builder1.setGpuOptions(org.tensorflow.framework.GPUOptions.newBuilder() .setAllowGrowth(true) .setPerProcessGpuMemoryFraction(0.5) .build()); } //cpu else { } } catch (Exception e) { log.error("",e); } return builder1.build(); }
Example #2
Source File: GraphRunner.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Converts a JSON string written by {@link org.nd4j.shade.protobuf.util.JsonFormat} into a
 * {@link org.tensorflow.framework.ConfigProto}.
 *
 * @param json the JSON to read
 * @return the parsed config proto, or {@code null} if parsing fails for any reason (including a
 *     {@code null} argument); the failure is logged
 */
public static org.tensorflow.framework.ConfigProto fromJson(String json) {
    org.tensorflow.framework.ConfigProto.Builder builder = org.tensorflow.framework.ConfigProto.newBuilder();
    try {
        org.nd4j.shade.protobuf.util.JsonFormat.parser().merge(json, builder);
        // build() already yields an immutable message of the target type, so serializing it and
        // parsing it back again is unnecessary.
        return builder.build();
    } catch (Exception e) {
        log.error("", e);
        return null;
    }
}
Example #3
Source File: TensorFlowGraphModel.java From zoltar with Apache License 2.0 | 6 votes |
/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}.
 *
 * <p>If creating the {@link Session} or importing the graph definition fails, the session and
 * graph created here are closed before the exception is rethrown, so no native resources leak.
 *
 * @param id model id {@link Model.Id}.
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config ConfigProto config for TensorFlow {@link Session}; {@code null} passes no config.
 * @param prefix a prefix that will be prepended to names in graphDef; {@code null} for none.
 */
public static TensorFlowGraphModel create(
    final Model.Id id,
    final byte[] graphDef,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  final Graph graph = new Graph();
  try {
    final Session session = new Session(graph, config != null ? config.toByteArray() : null);
    try {
      final long loadStart = System.currentTimeMillis();
      if (prefix == null) {
        LOG.debug("Loading graph definition without prefix");
        graph.importGraphDef(graphDef);
      } else {
        LOG.debug("Loading graph definition with prefix: {}", prefix);
        graph.importGraphDef(graphDef, prefix);
      }
      LOG.info("TensorFlow graph loaded in {} ms", System.currentTimeMillis() - loadStart);
      return new AutoValue_TensorFlowGraphModel(id, graph, session);
    } catch (RuntimeException e) {
      // The model never took ownership of the session; release it before propagating.
      session.close();
      throw e;
    }
  } catch (RuntimeException e) {
    // Likewise for the graph (closed after its session, if one was created).
    graph.close();
    throw e;
  }
}
Example #4
Source File: GraphRunnerTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Checks that a {@link GraphRunner} round-trips its session options through JSON, exposes its
 * input/output order, and computes {@code input_0 + input_1} on two consecutive runs.
 *
 * @param graphRunner the runner under test
 */
private void runGraphRunnerTest(GraphRunner graphRunner) throws Exception {
    final String sessionJson = graphRunner.sessionOptionsToJson();

    if (sessionJson != null) {
        org.tensorflow.framework.ConfigProto.Builder parsed = org.tensorflow.framework.ConfigProto.newBuilder();
        JsonFormat.parser().merge(sessionJson, parsed);
        org.tensorflow.framework.ConfigProto expectedOptions = parsed.build();
        assertEquals(expectedOptions, graphRunner.getSessionOptionsConfigProto());
    }

    assertNotNull(graphRunner.getInputOrder());
    assertNotNull(graphRunner.getOutputOrder());

    org.tensorflow.framework.ConfigProto roundTripped;
    if (sessionJson == null) {
        roundTripped = null;
    } else {
        roundTripped = GraphRunner.fromJson(sessionJson);
    }
    assertEquals(graphRunner.getSessionOptionsConfigProto(), roundTripped);

    assertEquals(2, graphRunner.getInputOrder().size());
    assertEquals(1, graphRunner.getOutputOrder().size());

    final INDArray left = Nd4j.linspace(1, 4, 4).reshape(4);
    final INDArray right = Nd4j.linspace(1, 4, 4).reshape(4);
    final Map<String, INDArray> feed = new LinkedHashMap<>();
    feed.put("input_0", left);
    feed.put("input_1", right);

    // Running twice exercises reuse of the same runner.
    int run = 0;
    while (run < 2) {
        Map<String, INDArray> produced = graphRunner.run(feed);
        INDArray expectedSum = left.add(right);
        assertEquals(expectedSum, produced.get("output"));
        run++;
    }
}
Example #5
Source File: GraphRunnerTest.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Returns a session config for the current ND4J backend.
 *
 * @return for a {@code "CUDA"} backend (case-insensitive), a config with the thread's default
 *     device filter plus GPU memory growth and a 0.5 per-process memory fraction; otherwise
 *     {@code null}
 */
public static ConfigProto getConfig() {
    final String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
    if (!"CUDA".equalsIgnoreCase(backend)) {
        return null;
    }

    final ConfigProto.Builder cudaConfig = org.tensorflow.framework.ConfigProto.getDefaultInstance()
            .toBuilder()
            .addDeviceFilters(TensorflowConversion.defaultDeviceForThread());
    final GPUOptions gpuOptions = GPUOptions.newBuilder()
            .setAllowGrowth(true)
            .setPerProcessGpuMemoryFraction(0.5)
            .build();
    cudaConfig.setGpuOptions(gpuOptions);
    return cudaConfig.build();
}
Example #6
Source File: TFServer.java From TensorFlowOnYARN with Apache License 2.0 | 5 votes |
/**
 * Creates a single-worker server whose only task is {@code worker:0} at {@code localhost:0},
 * using the {@code grpc} protocol and a default {@link ConfigProto}.
 *
 * @return the new server
 */
public static TFServer createLocalServer() {
    List<String> workerAddresses = new ArrayList<>();
    workerAddresses.add("localhost:0");

    HashMap<String, List<String>> jobs = new HashMap<>();
    jobs.put("worker", workerAddresses);

    ClusterSpec localCluster = new ClusterSpec(jobs);
    return new TFServer(localCluster, "worker", 0, "grpc", ConfigProto.getDefaultInstance());
}
Example #7
Source File: MtcnnService.java From mtcnn-java with Apache License 2.0 | 5 votes |
/**
 * Loads a frozen TensorFlow model from a Spring resource URI and wraps it in a {@link GraphRunner}.
 *
 * @param tensorflowModelUri resource location of the serialized graph
 * @param inputLabel name of the single graph input
 * @return a runner using a default {@link ConfigProto}
 * @throws IllegalStateException if the model resource cannot be read
 */
private GraphRunner createGraphRunner(String tensorflowModelUri, String inputLabel) {
    try {
        final byte[] modelBytes;
        // IOUtils.toByteArray does not close the stream, so close it here once the bytes are read.
        try (java.io.InputStream modelStream =
                new DefaultResourceLoader().getResource(tensorflowModelUri).getInputStream()) {
            modelBytes = IOUtils.toByteArray(modelStream);
        }
        return new GraphRunner(
                modelBytes,
                Arrays.asList(inputLabel),
                ConfigProto.getDefaultInstance());
    } catch (IOException e) {
        throw new IllegalStateException(String.format("Failed to load TF model [%s] and input [%s]:", tensorflowModelUri, inputLabel), e);
    }
}
Example #8
Source File: GraphRunner.java From deeplearning4j with Apache License 2.0 | 4 votes |
/** * The constructor for creating a graph runner via builder * @param inputNames the input names to use * @param outputNames the output names to use * @param savedModelConfig the saved model configuration to load from (note this can not be used in conjunction * with graph path) * @param sessionOptionsConfigProto the session options for running the model (this maybe null) * @param sessionOptionsProtoBytes the proto bytes equivalent of the session configuration * @param sessionOptionsProtoPath the file path to a session configuration proto file * @param graph the tensorflow graph to use * @param graphPath the path to the graph * @param graphBytes the in memory bytes of the graph * @param inputDataTypes the expected input data types * @param outputDataTypes the expected output data types */ @Builder public GraphRunner(List<String> inputNames, List<String> outputNames, SavedModelConfig savedModelConfig, org.tensorflow.framework.ConfigProto sessionOptionsConfigProto, byte[] sessionOptionsProtoBytes, File sessionOptionsProtoPath, TF_Graph graph, File graphPath, byte[] graphBytes, Map<String, TensorDataType> inputDataTypes, Map<String, TensorDataType> outputDataTypes) { try { if(sessionOptionsConfigProto == null) { if(sessionOptionsConfigProto != null) { this.sessionOptionsConfigProto = ConfigProto.parseFrom(sessionOptionsProtoBytes); } else if(sessionOptionsProtoPath != null) { byte[] load = FileUtils.readFileToByteArray(sessionOptionsProtoPath); this.sessionOptionsConfigProto = ConfigProto.parseFrom(load); } } else this.sessionOptionsConfigProto = sessionOptionsConfigProto; this.inputDataTypes = inputDataTypes; this.outputDataTypes = outputDataTypes; //note that the input and output order, maybe null here //if the names are specified, we should defer to those instead this.inputOrder = inputNames; this.outputOrder = outputNames; initOptionsIfNeeded(); if(graph != null) { this.graph = graph; } else if(graphBytes != null) { this.graph = conversion.loadGraph(graphBytes, 
status); } else if(graphPath != null) { graphBytes = IOUtils.toByteArray(graphPath.toURI()); this.graph = conversion.loadGraph(graphBytes, status); } else this.graph = TF_NewGraph(); if(savedModelConfig != null) { this.savedModelConfig = savedModelConfig; Map<String,String> inputsMap = new LinkedHashMap<>(); Map<String,String> outputsMap = new LinkedHashMap<>(); this.session = conversion.loadSavedModel(savedModelConfig, options, null, this.graph, inputsMap, outputsMap, status); if(inputOrder == null || inputOrder.isEmpty()) inputOrder = new ArrayList<>(inputsMap.values()); if(outputOrder == null || outputOrder.isEmpty()) outputOrder = new ArrayList<>(outputsMap.values()); savedModelConfig.setSavedModelInputOrder(new ArrayList<>(inputsMap.values())); savedModelConfig.setSaveModelOutputOrder(new ArrayList<>(outputsMap.values())); log.info("Loaded input names from saved model configuration " + inputOrder); log.info("Loaded output names from saved model configuration " + outputOrder); } initSessionAndStatusIfNeeded(graphBytes); } catch (Exception e) { throw new IllegalArgumentException("Unable to parse protobuf",e); } }
Example #9
Source File: TFServer.java From TensorFlowOnYARN with Apache License 2.0 | 4 votes |
/**
 * Builds a {@link ServerDef} from a cluster spec and the local task's identity.
 *
 * @param clusterSpec the cluster whose definition is copied into the server def
 * @param jobName the job this server belongs to
 * @param taskIndex the task index within {@code jobName}
 * @param proto the communication protocol name
 * @param config the default session config for the server
 * @return the assembled server definition
 */
public static ServerDef makeServerDef(ClusterSpec clusterSpec, String jobName, int taskIndex, String proto, ConfigProto config) {
    ServerDef.Builder serverDef = ServerDef.newBuilder();
    serverDef.setCluster(clusterSpec.as_cluster_def());
    serverDef.setJobName(jobName);
    serverDef.setProtocol(proto);
    serverDef.setTaskIndex(taskIndex);
    serverDef.setDefaultSessionConfig(config);
    return serverDef.build();
}
Example #10
Source File: TFServer.java From TensorFlowOnYARN with Apache License 2.0 | 4 votes |
/**
 * Derives a {@link ServerDef} from an existing one, overriding the task identity, protocol and
 * default session config.
 *
 * @param serverDef the definition to start from
 * @param jobName the job name to set
 * @param taskIndex the task index to set
 * @param proto the communication protocol name to set
 * @param config the default session config to set
 * @return the new server definition
 */
public static ServerDef makeServerDef(ServerDef serverDef, String jobName, int taskIndex, String proto, ConfigProto config) {
    ServerDef.Builder derived = ServerDef.newBuilder();
    derived.mergeFrom(serverDef);
    derived.setJobName(jobName);
    derived.setTaskIndex(taskIndex);
    derived.setProtocol(proto);
    derived.setDefaultSessionConfig(config);
    return derived.build();
}
Example #11
Source File: TFServer.java From TensorFlowOnYARN with Apache License 2.0 | 4 votes |
/**
 * Creates a server for the given cluster layout using the {@code grpc} protocol and a default
 * {@link ConfigProto}.
 *
 * @param clusterSpec job name to list of task addresses
 * @param jobName the job this server belongs to
 * @param taskIndex the task index within {@code jobName}
 * @throws TFServerException as declared by the delegated constructor
 */
public TFServer(Map<String, List<String>> clusterSpec, String jobName, int taskIndex)
        throws TFServerException {
    this(
            new ClusterSpec(clusterSpec),
            jobName,
            taskIndex,
            "grpc",
            ConfigProto.getDefaultInstance());
}
Example #12
Source File: TensorFlowGraphLoader.java From zoltar with Apache License 2.0 | 3 votes |
/**
 * Returns a TensorFlow model loader based on a serialized TensorFlow {@link Graph}.
 *
 * <p>Model construction is wrapped in a lambda handed to {@code create(...)}; it delegates to
 * {@code TensorFlowGraphModel.create(id, graphDef, config, prefix)}.
 *
 * @param id model id {@link Model.Id}.
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
static TensorFlowGraphLoader create(
    final Model.Id id,
    final byte[] graphDef,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  return create(
      () -> TensorFlowGraphModel.create(id, graphDef, config, prefix));
}
Example #13
Source File: TensorFlowGraphLoader.java From zoltar with Apache License 2.0 | 3 votes |
/**
 * Returns a TensorFlow model loader based on a serialized TensorFlow {@link Graph}.
 *
 * <p>The URI string is parsed inside the lambda handed to {@code create(...)}, together with the
 * call to {@code TensorFlowGraphModel.create}.
 *
 * @param id model id {@link Model.Id}.
 * @param modelUri should point to a serialized TensorFlow {@link org.tensorflow.Graph} file on
 *     local filesystem, resource, GCS etc.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
static TensorFlowGraphLoader create(
    final Model.Id id,
    final String modelUri,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  return create(
      () -> TensorFlowGraphModel.create(id, URI.create(modelUri), config, prefix));
}
Example #14
Source File: TensorFlowGraphModel.java From zoltar with Apache License 2.0 | 3 votes |
/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}. The graph
 * file is read fully into memory and then handed to the byte-array overload.
 *
 * @param id model id {@link Model.Id}.
 * @param graphUri URI to the TensorFlow graph definition.
 * @param config config for TensorFlow {@link Session}.
 * @param prefix optional prefix that will be prepended to names in the graph.
 * @throws IOException if the graph file cannot be read
 */
public static TensorFlowGraphModel create(
    final Model.Id id,
    final URI graphUri,
    @Nullable final ConfigProto config,
    @Nullable final String prefix)
    throws IOException {
  return create(id, Files.readAllBytes(FileSystemExtras.path(graphUri)), config, prefix);
}
Example #15
Source File: Models.java From zoltar with Apache License 2.0 | 3 votes |
/**
 * Returns a TensorFlow model loader based on a serialized TensorFlow {@link Graph}.
 *
 * <p>Thin public entry point that forwards to {@code TensorFlowGraphLoader.create}.
 *
 * @param id model id {@link Model.Id}.
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
public static TensorFlowGraphLoader tensorFlowGraph(
    final Model.Id id,
    final byte[] graphDef,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  return TensorFlowGraphLoader.create(
      id,
      graphDef,
      config,
      prefix);
}
Example #16
Source File: Models.java From zoltar with Apache License 2.0 | 3 votes |
/**
 * Returns a TensorFlow model loader based on a serialized TensorFlow {@link Graph}.
 *
 * <p>Thin public entry point that forwards to {@code TensorFlowGraphLoader.create}.
 *
 * @param id model id {@link Model.Id}.
 * @param modelUri should point to a serialized TensorFlow {@link org.tensorflow.Graph} file on
 *     local filesystem, resource, GCS etc.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
public static TensorFlowGraphLoader tensorFlowGraph(
    final Model.Id id,
    final String modelUri,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  return TensorFlowGraphLoader.create(
      id,
      modelUri,
      config,
      prefix);
}
Example #17
Source File: Models.java From zoltar with Apache License 2.0 | 2 votes |
/**
 * Returns a TensorFlow model loader based on a serialized TensorFlow {@link Graph}, using the
 * default model id.
 *
 * <p>Thin public entry point that forwards to {@code TensorFlowGraphLoader.create}.
 *
 * @param modelUri should point to a serialized TensorFlow {@link org.tensorflow.Graph} file on
 *     local filesystem, resource, GCS etc.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
public static TensorFlowGraphLoader tensorFlowGraph(
    final String modelUri,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  return TensorFlowGraphLoader.create(
      modelUri,
      config,
      prefix);
}
Example #18
Source File: TensorFlowGraphLoader.java From zoltar with Apache License 2.0 | 2 votes |
/**
 * Returns a TensorFlow model loader based on a serialized TensorFlow {@link Graph}.
 *
 * <p>Model construction is wrapped in a lambda handed to {@code create(...)}; it delegates to
 * {@code TensorFlowGraphModel.create(graphDef, config, prefix)}.
 *
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
static TensorFlowGraphLoader create(
    final byte[] graphDef,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  return create(
      () -> TensorFlowGraphModel.create(graphDef, config, prefix));
}
Example #19
Source File: TensorFlowGraphLoader.java From zoltar with Apache License 2.0 | 2 votes |
/**
 * Returns a TensorFlow model loader based on a serialized TensorFlow {@link Graph}.
 *
 * <p>The URI string is parsed inside the lambda handed to {@code create(...)}, together with the
 * call to {@code TensorFlowGraphModel.create}.
 *
 * @param modelUri should point to a serialized TensorFlow {@link org.tensorflow.Graph} file on
 *     local filesystem, resource, GCS etc.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
static TensorFlowGraphLoader create(
    final String modelUri,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  return create(
      () -> TensorFlowGraphModel.create(URI.create(modelUri), config, prefix));
}
Example #20
Source File: TensorFlowGraphModel.java From zoltar with Apache License 2.0 | 2 votes |
/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}, using
 * {@code DEFAULT_ID} as the model id.
 *
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config ConfigProto config for TensorFlow {@link Session}.
 * @param prefix a prefix that will be prepended to names in graphDef.
 * @throws IOException declared by this overload's signature (kept so callers that catch it still
 *     compile)
 */
public static TensorFlowGraphModel create(
    final byte[] graphDef,
    @Nullable final ConfigProto config,
    @Nullable final String prefix)
    throws IOException {
  final Model.Id defaultId = DEFAULT_ID;
  return create(defaultId, graphDef, config, prefix);
}
Example #21
Source File: TensorFlowGraphModel.java From zoltar with Apache License 2.0 | 2 votes |
/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}, using
 * {@code DEFAULT_ID} as the model id.
 *
 * @param graphUri URI to the TensorFlow graph definition.
 * @param config config for TensorFlow {@link Session}.
 * @param prefix optional prefix that will be prepended to names in the graph.
 * @throws IOException if the delegated overload fails to read the graph
 */
public static TensorFlowGraphModel create(
    final URI graphUri,
    @Nullable final ConfigProto config,
    @Nullable final String prefix)
    throws IOException {
  final Model.Id defaultId = DEFAULT_ID;
  return create(defaultId, graphUri, config, prefix);
}
Example #22
Source File: Models.java From zoltar with Apache License 2.0 | 2 votes |
/**
 * Returns a TensorFlow model loader based on a serialized TensorFlow {@link Graph}, using the
 * default model id.
 *
 * <p>Thin public entry point that forwards to {@code TensorFlowGraphLoader.create}.
 *
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config optional TensorFlow {@link ConfigProto} config.
 * @param prefix optional prefix that will be prepended to names in the graph.
 */
public static TensorFlowGraphLoader tensorFlowGraph(
    final byte[] graphDef,
    @Nullable final ConfigProto config,
    @Nullable final String prefix) {
  return TensorFlowGraphLoader.create(
      graphDef,
      config,
      prefix);
}