org.tensorflow.framework.TensorProto Java Examples
The following examples show how to use
org.tensorflow.framework.TensorProto.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: TensorConverter.java From vespa with Apache License 2.0 | 5 votes |
static Tensor toVespaTensor(TensorProto tensorProto, OrderedTensorType type) { Values values = readValuesOf(tensorProto); if (values.size() == 0) { // Might be stored as "tensor_content" instead return toVespaTensor(readTensorContentOf(tensorProto), type); } IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type.type()); for (int i = 0; i < values.size(); ++i) builder.cellByDirectIndex(i, values.get(i)); return builder.build(); }
Example #2
Source File: TensorConverter.java From vespa with Apache License 2.0 | 5 votes |
/**
 * Returns a {@code Values} accessor over the typed value field of {@code tensorProto} that
 * corresponds to its dtype.
 *
 * @param tensorProto the serialized TensorFlow tensor
 * @return the value accessor for the proto's dtype
 * @throws IllegalArgumentException if the dtype is not one of DT_BOOL, DT_HALF, DT_INT16,
 *     DT_INT32, DT_INT64, DT_FLOAT or DT_DOUBLE; the message names the offending dtype
 */
private static Values readValuesOf(TensorProto tensorProto) {
    switch (tensorProto.getDtype()) {
        case DT_BOOL:
            return new ProtoBoolValues(tensorProto);
        case DT_HALF:
            return new ProtoHalfValues(tensorProto);
        case DT_INT16:
        case DT_INT32:
            // Both dtypes are read through the same int-valued accessor.
            return new ProtoIntValues(tensorProto);
        case DT_INT64:
            return new ProtoInt64Values(tensorProto);
        case DT_FLOAT:
            return new ProtoFloatValues(tensorProto);
        case DT_DOUBLE:
            return new ProtoDoubleValues(tensorProto);
        default:
            // Include the dtype so import failures can be diagnosed without a debugger.
            throw new IllegalArgumentException(
                    "Unsupported data type in attribute tensor import: " + tensorProto.getDtype());
    }
}
Example #3
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates the mapper used by {@code newMapper} for {@code DT_BOOL} tensors.
 *
 * @param tensorProto the tensor to wrap; passed unchanged to the superclass constructor
 */
public BoolTensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
Example #4
Source File: ModelServerClassification.java From hazelcast-jet-demos with Apache License 2.0 | 4 votes |
private static Pipeline buildPipeline(String serverAddress, IMap<Long, String> reviewsMap) { ServiceFactory<Tuple2<PredictionServiceFutureStub, WordIndex>, Tuple2<PredictionServiceFutureStub, WordIndex>> tfServingContext = ServiceFactory .withCreateContextFn(context -> { WordIndex wordIndex = new WordIndex(context.attachedDirectory("data")); ManagedChannel channel = ManagedChannelBuilder.forTarget(serverAddress) .usePlaintext().build(); return Tuple2.tuple2(PredictionServiceGrpc.newFutureStub(channel), wordIndex); }) .withDestroyContextFn(t -> ((ManagedChannel) t.f0().getChannel()).shutdownNow()) .withCreateServiceFn((context, tuple2) -> tuple2); Pipeline p = Pipeline.create(); p.readFrom(Sources.map(reviewsMap)) .map(Map.Entry::getValue) .mapUsingServiceAsync(tfServingContext, 16, true, (t, review) -> { float[][] featuresTensorData = t.f1().createTensorInput(review); TensorProto.Builder featuresTensorBuilder = TensorProto.newBuilder(); for (float[] featuresTensorDatum : featuresTensorData) { for (float v : featuresTensorDatum) { featuresTensorBuilder.addFloatVal(v); } } TensorShapeProto.Dim featuresDim1 = TensorShapeProto.Dim.newBuilder().setSize(featuresTensorData.length).build(); TensorShapeProto.Dim featuresDim2 = TensorShapeProto.Dim.newBuilder().setSize(featuresTensorData[0].length).build(); TensorShapeProto featuresShape = TensorShapeProto.newBuilder().addDim(featuresDim1).addDim(featuresDim2).build(); featuresTensorBuilder.setDtype(org.tensorflow.framework.DataType.DT_FLOAT) .setTensorShape(featuresShape); TensorProto featuresTensorProto = featuresTensorBuilder.build(); // Generate gRPC request Int64Value version = Int64Value.newBuilder().setValue(1).build(); Model.ModelSpec modelSpec = Model.ModelSpec.newBuilder().setName("reviewSentiment").setVersion(version).build(); Predict.PredictRequest request = Predict.PredictRequest.newBuilder() .setModelSpec(modelSpec) .putInputs("input_review", featuresTensorProto) .build(); return 
toCompletableFuture(t.f0().predict(request)) .thenApply(response -> { float classification = response .getOutputsOrThrow("dense_1/Sigmoid:0") .getFloatVal(0); // emit the review along with the classification return tuple2(review, classification); }); }) .setLocalParallelism(1) // one worker is enough to drive they async calls .writeTo(Sinks.logger()); return p; }
Example #5
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates the mapper used by {@code newMapper} for {@code DT_STRING} tensors.
 *
 * @param tensorProto the tensor to wrap; passed unchanged to the superclass constructor
 */
public StringTensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
Example #6
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates the mapper used by {@code newMapper} for {@code DT_UINT64} tensors.
 *
 * @param tensorProto the tensor to wrap; passed unchanged to the superclass constructor
 */
public UInt64TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
Example #7
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates the mapper used by {@code newMapper} for {@code DT_UINT32} tensors.
 *
 * @param tensorProto the tensor to wrap; passed unchanged to the superclass constructor
 */
public UInt32TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
Example #8
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates the mapper used by {@code newMapper} for {@code DT_UINT16} tensors.
 *
 * @param tensorProto the tensor to wrap; passed unchanged to the superclass constructor
 */
public UInt16TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
Example #9
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates the mapper used by {@code newMapper} for {@code DT_UINT8} tensors.
 *
 * @param tensorProto the tensor to wrap; passed unchanged to the superclass constructor
 */
public UInt8TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
Example #10
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates the mapper used by {@code newMapper} for {@code DT_INT64} tensors.
 *
 * @param tensorProto the tensor to wrap; passed unchanged to the superclass constructor
 */
public Int64TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
Example #11
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates the mapper used by {@code newMapper} for {@code DT_INT32} tensors.
 *
 * @param tensorProto the tensor to wrap; passed unchanged to the superclass constructor
 */
public Int32TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
Example #12
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates the mapper used by {@code newMapper} for {@code DT_INT16} tensors.
 *
 * @param tensorProto the tensor to wrap; passed unchanged to the superclass constructor
 */
public Int16TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
Example #13
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates the mapper used by {@code newMapper} for {@code DT_INT8} tensors.
 *
 * @param tensorProto the tensor to wrap; passed unchanged to the superclass constructor
 */
public Int8TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
Example #14
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates the mapper used by {@code newMapper} for {@code DT_BFLOAT16} tensors.
 *
 * @param tensorProto the tensor to wrap; passed unchanged to the superclass constructor
 */
public BFloat16TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
Example #15
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates the mapper used by {@code newMapper} for {@code DT_DOUBLE} tensors.
 *
 * @param tensorProto the tensor to wrap; passed unchanged to the superclass constructor
 */
public Float64TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
Example #16
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates the mapper used by {@code newMapper} for {@code DT_FLOAT} tensors.
 *
 * @param tensorProto the tensor to wrap; passed unchanged to the superclass constructor
 */
public Float32TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
Example #17
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates the mapper used by {@code newMapper} for {@code DT_HALF} tensors.
 *
 * @param tensorProto the tensor to wrap; passed unchanged to the superclass constructor
 */
public Float16TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
Example #18
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates a mapper backed by the given TensorFlow tensor.
 *
 * @param tensorProto the tensor to wrap; stored as-is in {@code tfTensor} (not copied, not
 *     null-checked)
 */
public BaseTensorMapper(TensorProto tensorProto) {
    this.tfTensor = tensorProto;
}
Example #19
Source File: TFTensorMappers.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Creates the {@code TFTensorMapper} that matches the dtype of {@code tp}.
 *
 * @param tp the TensorFlow tensor to wrap
 * @return a mapper for the tensor's element type
 * @throws IllegalStateException if the dtype is quantized, complex, a reference type, or any
 *     other type without a mapper
 */
public static TFTensorMapper<?,?> newMapper(TensorProto tp) {
    org.tensorflow.framework.DataType dtype = tp.getDtype();
    switch (dtype) {
        // Floating point
        case DT_HALF:     return new Float16TensorMapper(tp);
        case DT_FLOAT:    return new Float32TensorMapper(tp);
        case DT_DOUBLE:   return new Float64TensorMapper(tp);
        case DT_BFLOAT16: return new BFloat16TensorMapper(tp);
        // Signed integers
        case DT_INT8:     return new Int8TensorMapper(tp);
        case DT_INT16:    return new Int16TensorMapper(tp);
        case DT_INT32:    return new Int32TensorMapper(tp);
        case DT_INT64:    return new Int64TensorMapper(tp);
        // Unsigned integers
        case DT_UINT8:    return new UInt8TensorMapper(tp);
        case DT_UINT16:   return new UInt16TensorMapper(tp);
        case DT_UINT32:   return new UInt32TensorMapper(tp);
        case DT_UINT64:   return new UInt64TensorMapper(tp);
        // Non-numeric
        case DT_STRING:   return new StringTensorMapper(tp);
        case DT_BOOL:     return new BoolTensorMapper(tp);
        default:
            throw unmappableTypeError(dtype);
    }
}

/** Builds the exception explaining which category of unsupported dtype {@code dtype} falls into. */
private static IllegalStateException unmappableTypeError(org.tensorflow.framework.DataType dtype) {
    switch (dtype) {
        case DT_QINT8:
        case DT_QUINT8:
        case DT_QINT32:
        case DT_QINT16:
        case DT_QUINT16:
            return new IllegalStateException("Unable to map quantized type: " + dtype);
        case DT_COMPLEX64:
        case DT_COMPLEX128:
            return new IllegalStateException("Unable to map complex type: " + dtype);
        case DT_FLOAT_REF:
        case DT_DOUBLE_REF:
        case DT_INT32_REF:
        case DT_UINT8_REF:
        case DT_INT16_REF:
        case DT_INT8_REF:
        case DT_STRING_REF:
        case DT_COMPLEX64_REF:
        case DT_INT64_REF:
        case DT_BOOL_REF:
        case DT_QINT8_REF:
        case DT_QUINT8_REF:
        case DT_QINT32_REF:
        case DT_BFLOAT16_REF:
        case DT_QINT16_REF:
        case DT_QUINT16_REF:
        case DT_UINT16_REF:
        case DT_COMPLEX128_REF:
        case DT_HALF_REF:
        case DT_RESOURCE_REF:
        case DT_VARIANT_REF:
        case DT_UINT32_REF:
        case DT_UINT64_REF:
            return new IllegalStateException("Unable to map reference type: " + dtype);
        case UNRECOGNIZED:
        case DT_RESOURCE:
        case DT_VARIANT:
        case DT_INVALID:
        default:
            return new IllegalStateException("Unable to map type: " + dtype);
    }
}
Example #20
Source File: TensorConverter.java From vespa with Apache License 2.0 | 4 votes |
/**
 * Decodes the raw {@code tensor_content} bytes of {@code tensorProto} into a native TensorFlow
 * tensor with the proto's dtype and shape.
 *
 * <p>The result comes from {@code org.tensorflow.Tensor.create}. Per the TensorFlow Java API, it
 * owns native memory that is released only when the caller closes it.
 *
 * @param tensorProto the serialized tensor whose {@code tensor_content} holds the data
 * @return a new native tensor over a read-only view of the content bytes
 */
private static org.tensorflow.Tensor readTensorContentOf(TensorProto tensorProto) {
    org.tensorflow.framework.TensorShapeProto shape = tensorProto.getTensorShape();
    return org.tensorflow.Tensor.create(
            dataTypeToClass(tensorProto.getDtype()),
            asSizeArray(shape.getDimList()),
            tensorProto.getTensorContent().asReadOnlyByteBuffer());
}
Example #21
Source File: TensorConverter.java From vespa with Apache License 2.0 | votes |
/**
 * Creates the value accessor that {@code readValuesOf} uses for {@code DT_DOUBLE} protos.
 *
 * @param tensorProto the proto to read from; passed unchanged to the superclass constructor
 */
ProtoDoubleValues(TensorProto tensorProto) {
    super(tensorProto);
}
Example #22
Source File: TensorConverter.java From vespa with Apache License 2.0 | votes |
/**
 * Creates the value accessor that {@code readValuesOf} uses for {@code DT_FLOAT} protos.
 *
 * @param tensorProto the proto to read from; passed unchanged to the superclass constructor
 */
ProtoFloatValues(TensorProto tensorProto) {
    super(tensorProto);
}
Example #23
Source File: TensorConverter.java From vespa with Apache License 2.0 | votes |
/**
 * Creates the value accessor that {@code readValuesOf} uses for {@code DT_INT64} protos.
 *
 * @param tensorProto the proto to read from; passed unchanged to the superclass constructor
 */
ProtoInt64Values(TensorProto tensorProto) {
    super(tensorProto);
}
Example #24
Source File: TensorConverter.java From vespa with Apache License 2.0 | votes |
/**
 * Creates the value accessor that {@code readValuesOf} uses for {@code DT_INT16} and
 * {@code DT_INT32} protos.
 *
 * @param tensorProto the proto to read from; passed unchanged to the superclass constructor
 */
ProtoIntValues(TensorProto tensorProto) {
    super(tensorProto);
}
Example #25
Source File: TensorConverter.java From vespa with Apache License 2.0 | votes |
/**
 * Creates the value accessor that {@code readValuesOf} uses for {@code DT_HALF} protos.
 *
 * @param tensorProto the proto to read from; passed unchanged to the superclass constructor
 */
ProtoHalfValues(TensorProto tensorProto) {
    super(tensorProto);
}
Example #26
Source File: TensorConverter.java From vespa with Apache License 2.0 | votes |
/**
 * Creates the value accessor that {@code readValuesOf} uses for {@code DT_BOOL} protos.
 *
 * @param tensorProto the proto to read from; passed unchanged to the superclass constructor
 */
ProtoBoolValues(TensorProto tensorProto) {
    super(tensorProto);
}
Example #27
Source File: TensorConverter.java From vespa with Apache License 2.0 | votes |
/**
 * Creates a value accessor backed by the given proto.
 *
 * @param tensorProto the proto to read values from; stored as-is in {@code tensorProto} (not
 *     copied, not null-checked)
 */
ProtoValues(TensorProto tensorProto) {
    this.tensorProto = tensorProto;
}