Java Code Examples for org.nd4j.linalg.factory.Nd4j#createFromNpyFile()
The following examples show how to use
org.nd4j.linalg.factory.Nd4j#createFromNpyFile().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: NumpyFormatTests.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test(expected = RuntimeException.class) public void readNumpyCorruptHeader1() throws Exception { File f = testDir.newFolder(); File fValid = new ClassPathResource("numpy_arrays/arange_3,4_float32.npy").getFile(); byte[] numpyBytes = FileUtils.readFileToByteArray(fValid); for( int i=0; i<10; i++ ){ numpyBytes[i] = 0; } File fCorrupt = new File(f, "corrupt.npy"); FileUtils.writeByteArrayToFile(fCorrupt, numpyBytes); INDArray exp = Nd4j.arange(12).castTo(DataType.FLOAT).reshape(3,4); INDArray act1 = Nd4j.createFromNpyFile(fValid); assertEquals(exp, act1); INDArray probablyShouldntLoad = Nd4j.createFromNpyFile(fCorrupt); //Loads fine boolean eq = exp.equals(probablyShouldntLoad); //And is actually equal content }
Example 2
Source File: NumpyFormatTests.java From deeplearning4j with Apache License 2.0 | 6 votes |
@Test(expected = RuntimeException.class) public void readNumpyCorruptHeader2() throws Exception { File f = testDir.newFolder(); File fValid = new ClassPathResource("numpy_arrays/arange_3,4_float32.npy").getFile(); byte[] numpyBytes = FileUtils.readFileToByteArray(fValid); for( int i=1; i<10; i++ ){ numpyBytes[i] = 0; } File fCorrupt = new File(f, "corrupt.npy"); FileUtils.writeByteArrayToFile(fCorrupt, numpyBytes); INDArray exp = Nd4j.arange(12).castTo(DataType.FLOAT).reshape(3,4); INDArray act1 = Nd4j.createFromNpyFile(fValid); assertEquals(exp, act1); INDArray probablyShouldntLoad = Nd4j.createFromNpyFile(fCorrupt); //Loads fine boolean eq = exp.equals(probablyShouldntLoad); //And is actually equal content }
Example 3
Source File: ImportModelDebugger.java From deeplearning4j with Apache License 2.0 | 6 votes |
/**
 * Loads every placeholder array stored (recursively) under {@code rootDir/__placeholders}.
 * <p>
 * Each regular file is read with {@link Nd4j#createFromNpyFile(File)}. Its map key is the file's path
 * relative to the placeholders directory, truncated at the last {@code "__"} occurrence.
 *
 * @param rootDir directory that contains the {@code __placeholders} subdirectory
 * @return map from placeholder name to the loaded array
 * @throws IllegalStateException if the placeholders directory does not exist, or if a file's relative path
 *                               contains no {@code "__"} from which a name can be derived
 */
public static Map<String, INDArray> loadPlaceholders(File rootDir) {
    File dir = new File(rootDir, "__placeholders");
    if (!dir.exists()) {
        throw new IllegalStateException("Cannot find placeholders: directory does not exist: " + dir.getAbsolutePath());
    }
    Map<String, INDArray> ret = new HashMap<>();
    Iterator<File> iter = FileUtils.iterateFiles(dir, null, true);
    while (iter.hasNext()) {
        File f = iter.next();
        if (!f.isFile()) {
            continue;
        }
        String s = dir.toURI().relativize(f.toURI()).getPath();
        int idx = s.lastIndexOf("__");
        if (idx < 0) {
            // Without this guard substring(0, -1) fails with an uninformative StringIndexOutOfBoundsException
            throw new IllegalStateException("Cannot determine placeholder name: no \"__\" separator in file path: "
                    + f.getAbsolutePath());
        }
        String name = s.substring(0, idx);
        INDArray arr = Nd4j.createFromNpyFile(f);
        ret.put(name, arr);
    }
    return ret;
}
Example 4
Source File: TestNDArrayCreation.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Loads the {@code test.npy} fixture and checks that it is 2x2 and holds 1, 2, 3, 4 in row-major order,
 * within a tolerance of 0.1 (currently disabled).
 */
@Test
@Ignore
public void testCreateNpy() throws Exception {
    INDArray loaded = Nd4j.createFromNpyFile(new ClassPathResource("test.npy").getFile());
    assertEquals(2, loaded.size(0));
    assertEquals(2, loaded.size(1));

    double expected = 1.0;
    for (int row = 0; row < 2; row++) {
        for (int col = 0; col < 2; col++) {
            assertEquals(expected, loaded.getDouble(row, col), 1e-1);
            expected += 1.0;
        }
    }
}
Example 5
Source File: TestNDArrayCreation.java From nd4j with Apache License 2.0 | 5 votes |
/**
 * Loads the {@code rank3.npy} fixture and checks that it has 8 elements and rank 3. It also checks that
 * the device native ops map the array's data address to a pointer with that same address.
 */
@Test
public void testCreateNpy3() throws Exception {
    INDArray rank3 = Nd4j.createFromNpyFile(new ClassPathResource("rank3.npy").getFile());
    assertEquals(8, rank3.length());
    assertEquals(3, rank3.rank());

    Pointer resolved = NativeOpsHolder.getInstance()
            .getDeviceNativeOps()
            .pointerForAddress(rank3.data().address());
    assertEquals(rank3.data().address(), resolved.address());
}
Example 6
Source File: TestNDArrayCreation.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Loads {@code nd4j-tests/test.npy} and checks that it is 2x2 with the values in {@code expected}, within a
 * tolerance of 0.1 (currently disabled).
 */
@Test
@Ignore
public void testCreateNpy() throws Exception {
    INDArray arr = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/test.npy").getFile());
    assertEquals(2, arr.size(0));
    assertEquals(2, arr.size(1));

    double[][] expected = {{1.0, 2.0}, {3.0, 4.0}};
    for (int r = 0; r < expected.length; r++) {
        for (int c = 0; c < expected[r].length; c++) {
            assertEquals(expected[r][c], arr.getDouble(r, c), 1e-1);
        }
    }
}
Example 7
Source File: TestNDArrayCreation.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Loads {@code nd4j-tests/rank3.npy}, checks length 8 and rank 3, and verifies that the device native ops
 * resolve the array's data address to a pointer at the same address.
 */
@Test
@Ignore("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
public void testCreateNpy3() throws Exception {
    INDArray loaded = Nd4j.createFromNpyFile(new ClassPathResource("nd4j-tests/rank3.npy").getFile());
    assertEquals(8, loaded.length());
    assertEquals(3, loaded.rank());

    Pointer ptr = NativeOpsHolder.getInstance()
            .getDeviceNativeOps()
            .pointerForAddress(loaded.data().address());
    assertEquals(loaded.data().address(), ptr.address());
}
Example 8
Source File: NumpyFormatTests.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Ignore @Test public void testNumpyBoolean() { INDArray out = Nd4j.createFromNpyFile(new File("c:/Users/raver/Downloads/error2.npy")); // System.out.println(ArrayUtil.toList(ArrayUtil.toInts(out.shape()))); // System.out.println(out); }
Example 9
Source File: Nd4jValidator.java From deeplearning4j with Apache License 2.0 | 5 votes |
/** * Validate whether the file represents a valid Numpy .npy file to be read with {@link Nd4j#createFromNpyFile(File)} } * * @param f File that should represent a Numpy .npy file written with Numpy save method * @return Result of validation */ public static ValidationResult validateNpyFile(@NonNull File f) { ValidationResult vr = Nd4jCommonValidator.isValidFile(f, "Numpy .npy File", false); if (vr != null && !vr.isValid()) return vr; //TODO let's do this without reading whole thing into memory try (INDArray arr = Nd4j.createFromNpyFile(f)) { //Using the fact that INDArray.close() exists -> deallocate memory as soon as reading is done } catch (Throwable t) { if (t instanceof OutOfMemoryError || t.getMessage().toLowerCase().contains("failed to allocate")) { //This is a memory exception during reading... result is indeterminant (might be valid, might not be, can't tell here) return ValidationResult.builder() .valid(true) .formatType("Numpy .npy File") .path(Nd4jCommonValidator.getPath(f)) .build(); } return ValidationResult.builder() .valid(false) .formatType("Numpy .npy File") .path(Nd4jCommonValidator.getPath(f)) .issues(Collections.singletonList("File may be corrupt or is not a Numpy .npy file")) .exception(t) .build(); } return ValidationResult.builder() .valid(true) .formatType("Numpy .npy File") .path(Nd4jCommonValidator.getPath(f)) .build(); }
Example 10
Source File: FullModelComparisons.java From deeplearning4j with Apache License 2.0 | 5 votes |
/**
 * Imports the Keras model {@code cnn/cnn_batch_norm.h5} and runs it on the stored {@code input.npy}.
 * The output is compared against the Keras predictions stored in {@code predictions.npy}.
 */
@Test
public void cnnBatchNormTest() throws IOException, UnsupportedKerasConfigurationException,
        InvalidKerasConfigurationException {
    String modelPath = "modelimport/keras/fullconfigs/cnn/cnn_batch_norm.h5";
    KerasSequentialModel kerasModel = new KerasModel().modelBuilder()
            .modelHdf5Filename(Resources.asFile(modelPath).getAbsolutePath())
            .enforceTrainingConfig(false)
            .buildSequential();
    MultiLayerNetwork model = kerasModel.getMultiLayerNetwork();
    model.init();
    System.out.println(model.summary());

    INDArray input = Nd4j.createFromNpyFile(Resources.asFile("modelimport/keras/fullconfigs/cnn/input.npy"));
    // Reorder axes to (0, 3, 1, 2). The stored input is presumably channels-last as saved by Keras; confirm.
    input = input.permute(0, 3, 1, 2);
    assertTrue(Arrays.equals(input.shape(), new long[] {5, 3, 10, 10}));

    INDArray output = model.output(input);
    INDArray kerasOutput = Nd4j.createFromNpyFile(Resources.asFile("modelimport/keras/fullconfigs/cnn/predictions.npy"));

    // NOTE(review): getDouble(i) reads linear element i, so only the first 5 elements are compared.
    // Confirm whether a per-example comparison was intended.
    for (int i = 0; i < 5; i++) {
        // Keras predictions are the reference (expected) value; the imported model's output is the actual value
        TestCase.assertEquals(kerasOutput.getDouble(i), output.getDouble(i), 1e-4);
    }
}
Example 11
Source File: FullModelComparisons.java From deeplearning4j with Apache License 2.0 | 5 votes |
@Test public void cnnBatchNormLargerTest() throws IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException { String modelPath = "modelimport/keras/fullconfigs/cnn_batch_norm/cnn_batch_norm_medium.h5"; KerasSequentialModel kerasModel = new KerasModel().modelBuilder() .modelHdf5Filename(Resources.asFile(modelPath).getAbsolutePath()) .enforceTrainingConfig(false) .buildSequential(); MultiLayerNetwork model = kerasModel.getMultiLayerNetwork(); model.init(); System.out.println(model.summary()); INDArray input = Nd4j.createFromNpyFile(Resources.asFile("modelimport/keras/fullconfigs/cnn_batch_norm/input.npy")); input = input.permute(0, 3, 1, 2); assertTrue(Arrays.equals(input.shape(), new long[] {5, 1, 48, 48})); INDArray output = model.output(input); INDArray kerasOutput = Nd4j.createFromNpyFile(Resources.asFile("modelimport/keras/fullconfigs/cnn_batch_norm/predictions.npy")); for (int i = 0; i < 5; i++) { // TODO this should be a little closer TestCase.assertEquals(output.getDouble(i), kerasOutput.getDouble(i), 1e-2); } }
Example 12
Source File: NumpyFormatTests.java From deeplearning4j with Apache License 2.0 | 4 votes |
@Test public void testNpy() throws Exception { for(boolean empty : new boolean[]{false, true}) { val dir = testDir.newFolder(); if(!empty) { new ClassPathResource("numpy_arrays/npy/3,4/").copyDirectory(dir); } else { new ClassPathResource("numpy_arrays/npy/0,3_empty/").copyDirectory(dir); } File[] files = dir.listFiles(); int cnt = 0; for (File f : files) { if (!f.getPath().endsWith(".npy")) { log.warn("Skipping: {}", f); continue; } String path = f.getAbsolutePath(); int lastDot = path.lastIndexOf('.'); int lastUnderscore = path.lastIndexOf('_'); String dtype = path.substring(lastUnderscore + 1, lastDot); // System.out.println(path + " : " + dtype); DataType dt = DataType.fromNumpy(dtype); //System.out.println(dt); INDArray exp; if(empty){ exp = Nd4j.create(dt, 0, 3); } else { exp = Nd4j.arange(12).castTo(dt).reshape(3, 4); } INDArray act = Nd4j.createFromNpyFile(f); assertEquals("Failed with file [" + f.getName() + "]", exp, act); cnt++; } assertTrue(cnt > 0); } }
Example 13
Source File: NumpyFormatTests.java From deeplearning4j with Apache License 2.0 | 4 votes |
/** Loads {@code numpy_oneoff/scalar.npy} and checks that it equals an INT scalar with value 1. */
@Test
public void testFromNumpyScalar() throws Exception {
    val loaded = Nd4j.createFromNpyFile(new ClassPathResource("numpy_oneoff/scalar.npy").getFile());
    val expected = Nd4j.scalar(DataType.INT, 1);
    assertEquals(expected, loaded);
}
Example 14
Source File: NumpyFormatTests.java From deeplearning4j with Apache License 2.0 | 4 votes |
/** Loading an absent file with an unrelated extension must fail with {@link IllegalArgumentException}. */
@Test(expected = IllegalArgumentException.class)
public void testAbsentNumpyFile_1() throws Exception {
    val missing = new File("pew-pew-zomg.some_extension_that_wont_exist");
    Nd4j.createFromNpyFile(missing);
}
Example 15
Source File: NumpyFormatTests.java From deeplearning4j with Apache License 2.0 | 4 votes |
/**
 * Loading an absent file that has a {@code .npy} extension must fail with {@link IllegalArgumentException}.
 * The path points inside a freshly created test folder. A fixed absolute path could exist on some
 * machines, which would make this test fail there.
 */
@Test(expected = IllegalArgumentException.class)
public void testAbsentNumpyFile_2() throws Exception {
    val f = new File(testDir.newFolder(), "batch-x-1.npy");
    INDArray act1 = Nd4j.createFromNpyFile(f);
    // Reached only if loading unexpectedly succeeds; the test then fails because no exception was thrown
    log.info("Array shape: {}; sum: {};", act1.shape(), act1.sumNumber().doubleValue());
}