/**
 * Sets the data type used by the current JVM runtime.
 *
 * <p>This is a thin facade over {@code DataTypeUtil.setDTypeForContext}.
 *
 * @param dType the data type to use from now on; annotated {@code @NonNull}
 */
public static void setDataType(@NonNull DataBuffer.Type dType) {
    final DataBuffer.Type requested = dType;
    DataTypeUtil.setDTypeForContext(requested);
}
/**
 * Returns the data type currently configured for the runtime.
 *
 * <p>The value is obtained from {@code DataTypeUtil.getDtypeFromContext()}.
 *
 * @return the runtime data type
 */
public static DataBuffer.Type dataType() {
    final DataBuffer.Type current = DataTypeUtil.getDtypeFromContext();
    return current;
}
/**
 * Sets the data type used by the current runtime.
 *
 * <p>Under the write lock, the value is stored in the {@code dtype} field. It is also passed,
 * after conversion with {@code getDTypeForName}, to the {@code String} overload of
 * {@code setDTypeForContext}. That overload presumably persists it in the Nd4j context
 * configuration; confirm.
 *
 * @param allocationModeForContext the data type to use for the runtime
 */
public static void setDTypeForContext(DataBuffer.Type allocationModeForContext) {
    // Acquire before entering try. If lock() itself failed inside the try, the finally
    // block would call unlock() on a lock this thread never acquired.
    lock.writeLock().lock();
    try {
        dtype = allocationModeForContext;
        setDTypeForContext(getDTypeForName(allocationModeForContext));
    } finally {
        lock.writeLock().unlock();
    }
}
/**
 * Returns the number of bytes needed for the graph.
 *
 * <p>Computed as {@code numElements()} times {@code DataTypeUtil.lengthForDtype} of the
 * current global data type ({@code Nd4j.dataType()}). {@code lengthForDtype} is assumed to
 * be the size in bytes of one element; confirm.
 *
 * @return estimated size of the graph in bytes
 */
public long memoryForGraph() {
    // Widen to long before multiplying so large element counts cannot overflow int
    // arithmetic. This is a no-op if numElements() already returns long.
    return (long) numElements() * DataTypeUtil.lengthForDtype(Nd4j.dataType());
}
/**
 * Replaces the element at position {@code i} with {@code aX}.
 *
 * <p>The value is written as a double if the global data type is DOUBLE, and as a float
 * otherwise.
 *
 * <p>NOTE(review): this returns {@code aX} itself, not the previous element that
 * {@link java.util.List#set} specifies. The previous value cannot be re-boxed into
 * {@code X} here.
 *
 * @param i  index to overwrite; must satisfy {@code 0 <= i < size()}
 * @param aX the new value
 * @return {@code aX}
 * @throws IndexOutOfBoundsException if {@code i} is outside {@code [0, size())}
 */
@Override
public X set(int i, X aX) {
    // Without this check, an index past the logical size would be written into spare
    // container capacity and stay invisible to the list.
    if (i < 0 || i >= size()) {
        throw new IndexOutOfBoundsException("Index: " + i + ", Size: " + size());
    }
    if (DataTypeUtil.getDtypeFromContext() == DataBuffer.Type.DOUBLE) {
        container.putScalar(i, aX.doubleValue());
    } else {
        container.putScalar(i, aX.floatValue());
    }
    return aX;
}
DataTypeUtil.setDTypeForContext(dtype);
/**
 * Returns the data type configured for the current runtime.
 *
 * <p>On first use, the value is lazily resolved from the {@code "dtype"} property of the
 * Nd4j context configuration and cached in {@code dtype}. Later calls return the cached value.
 *
 * @return the runtime data type
 */
public static DataBuffer.Type getDtypeFromContext() {
    // Fast path: if the value is already initialised, a shared read lock is enough.
    lock.readLock().lock();
    try {
        if (dtype != null) {
            return dtype;
        }
    } finally {
        lock.readLock().unlock();
    }

    // Slow path. A read lock cannot be upgraded to a write lock (assumes lock is a
    // ReentrantReadWriteLock), so the read lock is released above and the write lock taken
    // here. Each lock now has its own try/finally, so an exception during initialisation can
    // no longer unlock the wrong lock or leak the write lock.
    lock.writeLock().lock();
    try {
        // Re-check: another thread may have initialised dtype between the two locks.
        if (dtype == null) {
            dtype = getDtypeFromContext(Nd4jContext.getInstance().getConf().getProperty("dtype"));
        }
        return dtype;
    } finally {
        lock.writeLock().unlock();
    }
}
/**
 * Changes the data type used by the whole JVM runtime from this point on.
 *
 * @param dType requested data type (annotated {@code @NonNull})
 */
public static void setDataType(@NonNull DataBuffer.Type dType)
{
    // All state lives in DataTypeUtil; this method only forwards the request.
    DataTypeUtil.setDTypeForContext(dType);
}
Type currentType = Type.valueOf(s.readUTF()); if (currentType != Type.COMPRESSED) type = DataTypeUtil.getDtypeFromContext(); else type = currentType; else if (DataTypeUtil.getDtypeFromContext() == Type.DOUBLE && currentType != Type.INT) elementSize = 8; else if (DataTypeUtil.getDtypeFromContext() == Type.FLOAT || currentType == Type.INT) elementSize = 4; else if (DataTypeUtil.getDtypeFromContext() == Type.HALF && currentType != Type.INT) elementSize = 2; if (currentType != DataTypeUtil.getDtypeFromContext() && currentType != Type.HALF && currentType != Type.INT && !(DataTypeUtil.getDtypeFromContext() == Type.DOUBLE)) { log.warn("Loading a data stream with opType different from what is set globally. Expect precision loss"); if (DataTypeUtil.getDtypeFromContext() == Type.INT) log.warn("Int to float/double widening UNSUPPORTED!!!"); readContent(s, currentType, DataTypeUtil.getDtypeFromContext()); Type currentType = Type.valueOf(s.readUTF()); if (currentType != Type.COMPRESSED) type = DataTypeUtil.getDtypeFromContext(); else type = currentType; else if (DataTypeUtil.getDtypeFromContext() == Type.DOUBLE && currentType != Type.INT) elementSize = 8;
private Properties loadModelProperties() throws IOException { FileInputStream input = new FileInputStream(modelPath + "/config.properties"); // load a properties file Properties prop = new Properties(); prop.load(input); if (prop.getProperty("precision") != null && prop.getProperty("precision").equals("FP16")) { LOG.info("Model uses FP16 precision. Activating support."); DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF); } return prop; }
/**
 * Inserts {@code aX} at position {@code i} and increments the size.
 *
 * <p>The steps are: validate the index with {@code rangeCheck}, request capacity, shift the
 * tail with {@code moveForward}, then write the value. It is written as a double if the global
 * data type is DOUBLE, and as a float otherwise.
 *
 * @param i  insertion index
 * @param aX value to insert
 */
@Override
public void add(int i, X aX) {
    rangeCheck(i);
    // NOTE(review): capacity is requested for index i, not for size + 1. When i < size and the
    // container is already full, confirm growCapacity/moveForward still leave room for the
    // shifted tail.
    growCapacity(i);
    moveForward(i);

    boolean storeAsDouble = DataTypeUtil.getDtypeFromContext() == DataBuffer.Type.DOUBLE;
    if (storeAsDouble) {
        container.putScalar(i, aX.doubleValue());
    } else {
        container.putScalar(i, aX.floatValue());
    }
    size += 1;
}
DataTypeUtil.setDTypeForContext(dtype);
/**
 * Appends {@code aX} to the end of the list.
 *
 * <p>On first use this allocates a container of length 10. When the container is full, it
 * calls {@code growCapacity(size * 2)}. The value is written as a double if the global data
 * type is DOUBLE, and as a float otherwise.
 *
 * @param aX value to append
 * @return always {@code true}
 */
@Override
public boolean add(X aX) {
    if (container == null) {
        container = Nd4j.create(10);
    } else if (container.length() == size) {
        growCapacity(size * 2);
    }

    final boolean storeAsDouble = DataTypeUtil.getDtypeFromContext() == DataBuffer.Type.DOUBLE;
    if (storeAsDouble) {
        container.putScalar(size, aX.doubleValue());
    } else {
        container.putScalar(size, aX.floatValue());
    }
    size += 1;
    return true;
}
/**
 * Runs the mode with the configured arguments.
 *
 * <p>If {@code precision} is {@code "FP16"}, the parameter precision and the global data type
 * (HALF) are both switched before training. Training needs at least one training set; without
 * one, an error is printed and the method returns.
 */
@Override
public void execute() {
    // A single check now drives both FP16 side effects; previously the test was duplicated.
    if ("FP16".equals(args().precision)) {
        precision = ParameterPrecision.FP16;
        System.out.println("Parameter precision set to FP16.");
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF);
    }
    if (args().getTrainingSets().length == 0) {
        System.err.println("You must provide training datasets.");
        // Previously execution continued after this error with an empty training-set array.
        return;
    }
    try {
        FeatureMapper featureMapper = configureFeatureMapper(args().featureMapperClassname,
                args().isTrio, args().getTrainingSets());
        execute(featureMapper, args().getTrainingSets(), args().miniBatchSize);
    } catch (IOException e) {
        System.err.println("An exception occured. Details may be provided below");
        e.printStackTrace();
    }
}
throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError); DataBuffer.Type dataType = DataTypeUtil.getDtypeFromContext(); if (dataType != DataBuffer.Type.DOUBLE) { throw new IllegalStateException("Cannot perform gradient check: Datatype is not set to double precision ("
DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF);
/**
 * Allocates {@code pointer} and its indexer for this buffer's {@code length()}.
 *
 * <p>LONG and INT buffers keep their own type, and {@code type} is set accordingly. Any other
 * {@code currentType} is allocated according to the global data type (DOUBLE, FLOAT or LONG).
 * In that case {@code type} is not touched here.
 *
 * @param currentType the type the buffer was stored with
 */
public void pointerIndexerByGlobalType(Type currentType) {
    if (currentType == Type.LONG) {
        pointer = new LongPointer(length());
        setIndexer(LongRawIndexer.create((LongPointer) pointer));
        type = Type.LONG;
    } else if (currentType == Type.INT) {
        pointer = new IntPointer(length());
        setIndexer(IntIndexer.create((IntPointer) pointer));
        type = Type.INT;
    } else {
        // Read the global type once. Re-reading it per branch takes the lock each time, and a
        // concurrent change could make no branch, or the wrong branch, match.
        Type globalType = DataTypeUtil.getDtypeFromContext();
        if (globalType == Type.DOUBLE) {
            pointer = new DoublePointer(length());
            // Route through setIndexer like every other branch (was a direct field assignment).
            setIndexer(DoubleIndexer.create((DoublePointer) pointer));
        } else if (globalType == Type.FLOAT) {
            pointer = new FloatPointer(length());
            setIndexer(FloatIndexer.create((FloatPointer) pointer));
        } else if (globalType == Type.LONG) {
            pointer = new LongPointer(length());
            setIndexer(LongIndexer.create((LongPointer) pointer));
        }
        // NOTE(review): any other global type (e.g. HALF) leaves pointer/indexer unchanged.
        // Confirm this is intended.
    }
}
public static void main(String[] args) { DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
/**
 * Reports which data type the runtime is currently using.
 *
 * @return the runtime data type, as resolved by {@code DataTypeUtil.getDtypeFromContext()}
 */
public static DataBuffer.Type dataType()
{
    final DataBuffer.Type runtimeType = DataTypeUtil.getDtypeFromContext();
    return runtimeType;
}
/**
 * Estimates the total memory use, in bytes, for the given configuration.
 *
 * <p>The current ND4J global data type is used as the element type. This forwards to the
 * four-argument overload with {@code DataTypeUtil.getDtypeFromContext()}.
 *
 * @param minibatchSize mini batch size to estimate the memory for
 * @param memoryUseMode the memory use mode (training or inference)
 * @param cacheMode     the CacheMode to use
 * @return the estimated total memory consumption in bytes
 */
public long getTotalMemoryBytes(int minibatchSize,
                                @NonNull MemoryUseMode memoryUseMode,
                                @NonNull CacheMode cacheMode) {
    return this.getTotalMemoryBytes(
            minibatchSize,
            memoryUseMode,
            cacheMode,
            DataTypeUtil.getDtypeFromContext());
}