/**
 * Wraps the memory returned by {@link #tensor_data()} in an NIO {@link Buffer} whose element
 * type is chosen from {@link #dtype()}.
 *
 * <p>The buffer's capacity is {@link #TotalBytes()} divided by the element width used for the
 * dtype. Widths are 1 for byte-like types, 2 for short-like types, 4 for int/float types and
 * 8 for long/double types. {@code DT_COMPLEX64} is exposed as a {@code FloatBuffer}, presumably
 * two floats per complex value (TODO confirm with callers). {@code DT_BFLOAT16} is exposed as
 * raw 16-bit shorts.
 *
 * @param index starting position applied to the typed pointer before wrapping; it is given in
 *     elements of the chosen buffer type, not in bytes
 * @return the typed buffer, cast to the caller's requested {@code Buffer} subtype. For
 *     {@code DT_STRING} or any unlisted dtype, an assertion fails when assertions are enabled;
 *     otherwise {@code null} is returned.
 */
public <B extends Buffer> B createBuffer(long index) {
    final BytePointer data = tensor_data();
    final long totalBytes = TotalBytes();
    final Buffer buffer;
    switch (dtype()) {
        case DT_BOOL:
        case DT_QUINT8:
        case DT_UINT8:
        case DT_QINT8:
        case DT_INT8:
            // Byte-sized elements: reposition the BytePointer from tensor_data() directly.
            buffer = data.position(index).capacity(totalBytes).asBuffer();
            break;
        case DT_BFLOAT16:
        case DT_INT16:
            buffer = new ShortPointer(data).position(index).capacity(totalBytes / 2).asBuffer();
            break;
        case DT_QINT32:
        case DT_INT32:
            buffer = new IntPointer(data).position(index).capacity(totalBytes / 4).asBuffer();
            break;
        case DT_COMPLEX64:
        case DT_FLOAT:
            buffer = new FloatPointer(data).position(index).capacity(totalBytes / 4).asBuffer();
            break;
        case DT_INT64:
            buffer = new LongPointer(data).position(index).capacity(totalBytes / 8).asBuffer();
            break;
        case DT_DOUBLE:
            buffer = new DoublePointer(data).position(index).capacity(totalBytes / 8).asBuffer();
            break;
        case DT_STRING:
        default:
            // No fixed-width buffer view exists for these types.
            assert false;
            buffer = null;
            break;
    }
    return (B) buffer;
}
case DT_BFLOAT16: return (I)UShortIndexer.create(new ShortPointer(ptr).capacity(size/2), sizes, strides, direct).indexable(this); case DT_INT16: return (I)ShortIndexer.create(new ShortPointer(ptr).capacity(size/2), sizes, strides, direct).indexable(this); case DT_INT64: return (I)LongIndexer.create(new LongPointer(ptr).capacity(size/8), sizes, strides, direct).indexable(this); case DT_STRING: default: assert false;