/**
 * Returns the tracking point used by the Allocator.
 *
 * PLEASE NOTE: suitable and meaningful only for specific backends.
 *
 * @return the underlying buffer's tracking point if this buffer wraps another one, otherwise this buffer's own tracking point
 */
@Override
public Long getTrackingPoint() {
    if (underlyingDataBuffer() != this)
        return underlyingDataBuffer() == null ? trackingPoint : underlyingDataBuffer().getTrackingPoint();

    return trackingPoint;
}
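The delegation in `getTrackingPoint()` can be sketched with a hypothetical `SketchBuffer` class (not an ND4J type), assuming a root buffer reports `null` as its underlying buffer. Views, including views of views, then resolve to the root's tracking point:

```java
// Hypothetical stand-in for a data buffer; not the ND4J DataBuffer type.
class SketchBuffer {
    private final Long trackingPoint;
    private final SketchBuffer underlying; // buffer this one wraps, or null for a root

    SketchBuffer(Long trackingPoint, SketchBuffer underlying) {
        this.trackingPoint = trackingPoint;
        this.underlying = underlying;
    }

    SketchBuffer underlyingDataBuffer() {
        return underlying;
    }

    // Same control flow as getTrackingPoint(): defer to the wrapped buffer
    // (recursively) and fall back to our own field when nothing is wrapped.
    Long getTrackingPoint() {
        if (underlyingDataBuffer() != this)
            return underlyingDataBuffer() == null ? trackingPoint : underlyingDataBuffer().getTrackingPoint();
        return trackingPoint;
    }
}
```

A view created as `new SketchBuffer(7L, root)` ignores its own field and reports the root's value.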
@Override
public AllocationPoint getAllocationPoint(DataBuffer buffer) {
    if (buffer instanceof CompressedDataBuffer) {
        log.warn("Trying to get AllocationPoint from CompressedDataBuffer");
        throw new RuntimeException("AllocationPoint is not available for CompressedDataBuffer");
    }

    return getAllocationPoint(buffer.getTrackingPoint());
}
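The lookup itself can be sketched as a map from tracking point to allocation state. The `SketchAllocationRegistry` and `SketchAllocationPoint` classes below are hypothetical stand-ins, and the `ConcurrentHashMap` is an assumption rather than what the allocator necessarily uses; the point is that the compressed-buffer guard runs before the lookup:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical stand-in for an AllocationPoint; it only remembers its key.
class SketchAllocationPoint {
    final long id;

    SketchAllocationPoint(long id) {
        this.id = id;
    }
}

// Hypothetical registry keyed by tracking point, illustrating the lookup
// that getAllocationPoint(DataBuffer) delegates to.
class SketchAllocationRegistry {
    private final Map<Long, SketchAllocationPoint> points = new ConcurrentHashMap<>();

    SketchAllocationPoint register(long trackingPoint) {
        return points.computeIfAbsent(trackingPoint, SketchAllocationPoint::new);
    }

    // Compressed buffers are rejected before any lookup, as in the original guard.
    SketchAllocationPoint lookup(long trackingPoint, boolean compressed) {
        if (compressed)
            throw new IllegalStateException("AllocationPoint is not available for compressed buffers");
        return points.get(trackingPoint); // null if the tracking point was never registered
    }
}
```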
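@Override
public boolean sameUnderlyingData(DataBuffer buffer) {
    // getTrackingPoint() returns a boxed Long, so compare values rather than references
    return java.util.Objects.equals(buffer.getTrackingPoint(), getTrackingPoint());
}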
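Because `getTrackingPoint()` returns a boxed `Long`, `==` compares object identity. `Long.valueOf` is only guaranteed to cache values from -128 to 127, so two equal tracking points outside that range may be distinct objects. The hypothetical `TrackingPointEquality` helper below contrasts the two comparisons:

```java
import java.util.Objects;

class TrackingPointEquality {
    // Fragile: compares two Long references, not their values.
    static boolean sameByReference(Long a, Long b) {
        return a == b;
    }

    // Robust: compares values, and treats two nulls as equal.
    static boolean sameByValue(Long a, Long b) {
        return Objects.equals(a, b);
    }
}
```

With `Long a = Long.valueOf(1_000_000L)` and an equal `b`, `sameByValue(a, b)` is always `true`, while `sameByReference(a, b)` depends on whether the JVM happens to cache that value.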
@Override
public void tickHostWrite(DataBuffer buffer) {
    AllocationPoint point = getAllocationPoint(buffer.getTrackingPoint());
    point.tickHostWrite();
}
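How host and device write ticks might decide which copy is current can be sketched as follows. This `WriteTicks` class is a hypothetical model, not the `AllocationPoint` implementation: it assumes a single global counter orders all writes, and treats a side as current unless the other side was written later:

```java
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical model of write ticks; not the real AllocationPoint.
class WriteTicks {
    // One shared counter gives every write a unique, increasing tick.
    private static final AtomicLong CLOCK = new AtomicLong();

    private volatile long hostWrite;
    private volatile long deviceWrite;

    void tickHostWrite() {
        hostWrite = CLOCK.incrementAndGet();
    }

    void tickDeviceWrite() {
        deviceWrite = CLOCK.incrementAndGet();
    }

    // The host copy is current unless the device was written more recently.
    boolean isActualOnHostSide() {
        return hostWrite >= deviceWrite;
    }

    // The device copy is current unless the host was written more recently.
    boolean isActualOnDeviceSide() {
        return deviceWrite >= hostWrite;
    }
}
```

Under this model, a `tickDeviceWrite()` marks the host copy stale, which is the situation `synchronizeHostData` exists to repair.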
/**
 * This method should be called to make sure that the data on the host side is up to date.
 *
 * @param buffer the buffer whose host-side copy should be synchronized
 */
@Override
public void synchronizeHostData(DataBuffer buffer) {
    // we don't synchronize constant buffers, since we assume they are always valid on the host side
    if (buffer.isConstant()) {
        return;
    }

    // synchronization is needed only in a device-dependent environment; no-op otherwise
    if (memoryHandler.isDeviceDependant()) {
        AllocationPoint point = getAllocationPoint(buffer.getTrackingPoint());
        if (point == null)
            throw new RuntimeException("AllocationPoint is NULL");

        memoryHandler.synchronizeThreadDevice(Thread.currentThread().getId(), memoryHandler.getDeviceId(), point);
    }
}
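The control flow of `synchronizeHostData` reduces to three outcomes plus one error. A hypothetical `HostSyncPolicy` (not a class from the codebase) makes them explicit, with a plain `Object` standing in for the `AllocationPoint`:

```java
// Hypothetical reduction of synchronizeHostData()'s decisions.
class HostSyncPolicy {
    enum Action { SKIP_CONSTANT, NO_OP_HOST_ONLY, SYNCHRONIZE }

    static Action decide(boolean constantBuffer, boolean deviceDependent, Object allocationPoint) {
        if (constantBuffer)
            return Action.SKIP_CONSTANT;       // constants are assumed valid on the host
        if (!deviceDependent)
            return Action.NO_OP_HOST_ONLY;     // backend is not device-dependent: nothing to do
        if (allocationPoint == null)
            throw new IllegalStateException("AllocationPoint is NULL");
        return Action.SYNCHRONIZE;             // original calls memoryHandler.synchronizeThreadDevice(...)
    }
}
```

Note that the constant check comes first, so a constant buffer never triggers the null `AllocationPoint` error.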
@Override
public void tickDeviceWrite(INDArray array) {
    // views share the allocation of their original buffer, so the tick goes to that buffer
    DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
    AllocationPoint point = getAllocationPoint(buffer.getTrackingPoint());
    point.tickDeviceWrite();
}
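The `originalDataBuffer()` fallback in `tickDeviceWrite` resolves only one level. That is sufficient if every view stores the owning buffer rather than an intermediate view, as the view constructor does with `originalBuffer` (assuming `originalDataBuffer()` returns that field). A hypothetical `FlatBuffer` class illustrates the invariant:

```java
// Hypothetical model: views store the owning buffer directly, so resolving
// the owner is a single lookup instead of a walk down a chain of views.
class FlatBuffer {
    final String name;
    private final FlatBuffer original; // null when this buffer owns its memory

    private FlatBuffer(String name, FlatBuffer original) {
        this.name = name;
        this.original = original;
    }

    static FlatBuffer root(String name) {
        return new FlatBuffer(name, null);
    }

    FlatBuffer originalDataBuffer() {
        return original;
    }

    // Like the view constructor: record the owner, never an intermediate view.
    FlatBuffer view(String name) {
        return new FlatBuffer(name, owner());
    }

    // Like tickDeviceWrite(): one level is enough because views are flattened.
    FlatBuffer owner() {
        return originalDataBuffer() == null ? this : originalDataBuffer();
    }
}
```

A view of a view (`root.view("a").view("b")`) therefore reports `root` as its owner directly.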
public BaseCudaDataBuffer(@NonNull DataBuffer underlyingBuffer, long length, long offset) {
    this.allocationMode = AllocationMode.LONG_SHAPE;
    initTypeAndSize();

    // originalBuffer is the underlying buffer's own original (if any), so views of views still point at the owner
    this.wrappedDataBuffer = underlyingBuffer;
    this.originalBuffer = underlyingBuffer.originalDataBuffer() == null ? underlyingBuffer : underlyingBuffer.originalDataBuffer();
    this.length = length;
    this.offset = offset;
    this.originalOffset = offset;

    // share tracking and allocation state with the underlying buffer; no new memory is allocated
    this.trackingPoint = underlyingBuffer.getTrackingPoint();
    this.elementSize = (byte) underlyingBuffer.getElementSize();
    this.allocationPoint = ((BaseCudaDataBuffer) underlyingBuffer).allocationPoint;

    // wrap the shared host pointer, spanning the whole original buffer, in a typed pointer and indexer
    if (underlyingBuffer.dataType() == Type.DOUBLE) {
        this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asDoublePointer();
        indexer = DoubleIndexer.create((DoublePointer) pointer);
    } else if (underlyingBuffer.dataType() == Type.FLOAT) {
        this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asFloatPointer();
        indexer = FloatIndexer.create((FloatPointer) pointer);
    } else if (underlyingBuffer.dataType() == Type.INT) {
        this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asIntPointer();
        indexer = IntIndexer.create((IntPointer) pointer);
    } else if (underlyingBuffer.dataType() == Type.HALF) {
        this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asShortPointer();
        indexer = HalfIndexer.create((ShortPointer) pointer);
    } else if (underlyingBuffer.dataType() == Type.LONG) {
        this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), originalBuffer.length()).asLongPointer();
        indexer = LongIndexer.create((LongPointer) pointer);
    }
}
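The view constructor allocates nothing: it reuses the underlying buffer's allocation and records an `offset`. A hypothetical `OffsetView` over a plain `float[]` sketches shared storage plus index translation; it stands in for the CUDA host pointer and indexer and ignores device memory entirely:

```java
// Hypothetical sketch: a view shares the owner's storage and translates
// its own indices by a fixed offset instead of copying data.
class OffsetView {
    private final float[] storage; // shared with the owner, never copied
    private final long offset;
    private final long length;

    OffsetView(float[] storage, long offset, long length) {
        if (offset < 0 || length < 0 || offset + length > storage.length)
            throw new IllegalArgumentException("view exceeds backing storage");
        this.storage = storage;
        this.offset = offset;
        this.length = length;
    }

    float get(long i) {
        checkIndex(i);
        return storage[(int) (offset + i)];
    }

    // Writes land in the shared storage, so the owner sees them too.
    void put(long i, float value) {
        checkIndex(i);
        storage[(int) (offset + i)] = value;
    }

    private void checkIndex(long i) {
        if (i < 0 || i >= length)
            throw new IndexOutOfBoundsException("index " + i + ", length " + length);
    }
}
```

For example, `new OffsetView(data, 2, 3).put(1, 42f)` writes `data[3]`, mirroring how a write through a CUDA view touches the original buffer's memory.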