/**
 * Returns {@code new Bfloat16BufferIndexer(buffer)}.
 *
 * @param buffer the buffer whose {@code short} elements hold the bfloat16 values to index
 * @return a new {@code Bfloat16BufferIndexer} wrapping {@code buffer}
 */
public static Bfloat16Indexer create(ShortBuffer buffer) {
    return new Bfloat16BufferIndexer(buffer);
}
/** Returns {@code create(pointer, { pointer.limit() - pointer.position() }, { 1 }, true)} */
/**
 * Stores {@code h}, encoded with {@code fromFloat}, at element {@code (i, j, k)}.
 * The flat position is {@code i * strides[0] + j * strides[1] + k}: the innermost
 * dimension is treated as contiguous, so {@code strides[2]} is not consulted.
 */
@Override public Bfloat16Indexer put(long i, long j, long k, float h) {
    final int outer = (int)i * (int)strides[0];
    final int middle = (int)j * (int)strides[1];
    final int flat = outer + middle + (int)k;
    buffer.put(flat, (short)fromFloat(h));
    return this;
}
@Override public Bfloat16Indexer put(long[] indices, float h) {
/**
 * Copies {@code length} consecutive elements starting at {@code (i, j)} into {@code h},
 * decoding each with {@code toFloat} and writing from {@code h[offset]} onward.
 * Consecutive elements are read from adjacent buffer positions (unit innermost stride).
 */
@Override public Bfloat16Indexer get(long i, long j, float[] h, int offset, int length) {
    if (length > 0) {
        // Row start is computed once; it only depends on i, j and the leading strides.
        final int rowStart = (int)i * (int)strides[0] + (int)j * (int)strides[1];
        int source = rowStart;
        for (int n = 0; n < length; n++) {
            h[offset + n] = toFloat(buffer.get(source));
            source++;
        }
    }
    return this;
}
@Override public float get(long i, long j, long k) {
/**
 * Writes {@code length} values from {@code h}, starting at {@code h[offset]}, into
 * consecutive buffer positions beginning at {@code i * strides[0]}, encoding each with
 * {@code fromFloat}.
 */
@Override public Bfloat16Indexer put(long i, float[] h, int offset, int length) {
    if (length > 0) {
        final int rowStart = (int)i * (int)strides[0];
        int target = rowStart;
        for (int n = 0; n < length; n++) {
            buffer.put(target, (short)fromFloat(h[offset + n]));
            target++;
        }
    }
    return this;
}
@Override public Bfloat16Indexer put(long i, long j, float h) {
/**
 * Copies {@code length} consecutive elements starting at buffer position
 * {@code i * strides[0]} into {@code h} from {@code h[offset]} onward, decoding each
 * with {@code toFloat}.
 */
@Override public Bfloat16Indexer get(long i, float[] h, int offset, int length) {
    if (length > 0) {
        final int rowStart = (int)i * (int)strides[0];
        int source = rowStart;
        for (int n = 0; n < length; n++) {
            h[offset + n] = toFloat(buffer.get(source));
            source++;
        }
    }
    return this;
}
@Override public float get(long i, long j) {
/**
 * Stores {@code h}, encoded with {@code fromFloat}, at element {@code (i, j)}, i.e. at
 * buffer position {@code i * strides[0] + j} (the innermost stride is not applied).
 */
@Override public Bfloat16Indexer put(long i, long j, float h) {
    final int rowStart = (int)i * (int)strides[0];
    final int flat = rowStart + (int)j;
    buffer.put(flat, (short)fromFloat(h));
    return this;
}
@Override public Bfloat16Indexer put(long i, long j, float[] h, int offset, int length) {
/**
 * Returns element {@code (i, j)} decoded with {@code toFloat}, read from buffer position
 * {@code i * strides[0] + j} (the innermost stride is not applied).
 */
@Override public float get(long i, long j) {
    final int rowStart = (int)i * (int)strides[0];
    final int flat = rowStart + (int)j;
    return toFloat(buffer.get(flat));
}
@Override public Bfloat16Indexer get(long i, long j, float[] h, int offset, int length) {
/**
 * Returns {@code new Bfloat16BufferIndexer(buffer, sizes, strides)}.
 *
 * @param buffer the buffer whose {@code short} elements hold the bfloat16 values to index
 * @param sizes the sizes of the dimensions, passed through to the new indexer
 * @param strides the strides of the dimensions (in elements), passed through to the new indexer
 * @return a new {@code Bfloat16BufferIndexer} wrapping {@code buffer}
 */
public static Bfloat16Indexer create(ShortBuffer buffer, long[] sizes, long[] strides) {
    return new Bfloat16BufferIndexer(buffer, sizes, strides);
}
/** Returns {@code create(pointer, sizes, strides, true)} */
/**
 * Writes {@code length} values from {@code h}, starting at {@code h[offset]}, into
 * consecutive buffer positions beginning at element {@code (i, j)}, encoding each with
 * {@code fromFloat}.
 */
@Override public Bfloat16Indexer put(long i, long j, float[] h, int offset, int length) {
    if (length > 0) {
        // Row start is computed once; it only depends on i, j and the leading strides.
        final int rowStart = (int)i * (int)strides[0] + (int)j * (int)strides[1];
        int target = rowStart;
        for (int n = 0; n < length; n++) {
            buffer.put(target, (short)fromFloat(h[offset + n]));
            target++;
        }
    }
    return this;
}
@Override public Bfloat16Indexer put(long i, long j, long k, float h) {
/**
 * Returns element {@code (i, j, k)} decoded with {@code toFloat}, read from buffer
 * position {@code i * strides[0] + j * strides[1] + k} ({@code strides[2]} is not applied).
 */
@Override public float get(long i, long j, long k) {
    final int outer = (int)i * (int)strides[0];
    final int middle = (int)j * (int)strides[1];
    final int flat = outer + middle + (int)k;
    return toFloat(buffer.get(flat));
}
@Override public float get(long... indices) {
/**
 * Creates a bfloat16 indexer to access efficiently the data of a pointer.
 *
 * <p>When {@code direct} is {@code true}, the result is a {@code Bfloat16RawIndexer} over
 * {@code pointer} if {@code Raw.getInstance()} is non-null, otherwise a
 * {@code Bfloat16BufferIndexer} over {@code pointer.asBuffer()}.
 *
 * <p>When {@code direct} is {@code false}, the elements from the pointer's current position up
 * to its limit (capped at {@code Integer.MAX_VALUE} elements) are copied into a new
 * {@code short[]}, and a {@code Bfloat16ArrayIndexer} over that copy is returned. Calling
 * {@code release()} on it writes the array back into the pointer, first resetting the
 * pointer's position to the value it had when this method was called (the position is left at
 * that value afterwards).
 *
 * @param pointer data to access via a buffer or to copy to an array
 * @param sizes the sizes of the dimensions, passed through to the created indexer
 * @param strides the strides of the dimensions (in elements), passed through to the created indexer
 * @param direct {@code true} to use a direct buffer, see {@link Indexer} for details
 * @return the new bfloat16 indexer backed by the raw memory interface, a buffer, or an array
 */
public static Bfloat16Indexer create(final ShortPointer pointer, long[] sizes, long[] strides, boolean direct) {
    if (direct) {
        return Raw.getInstance() != null ? new Bfloat16RawIndexer(pointer, sizes, strides)
                                         : new Bfloat16BufferIndexer(pointer.asBuffer(), sizes, strides);
    } else {
        // Remember where the copy started so release() writes back to the same place.
        final long position = pointer.position();
        short[] array = new short[(int)Math.min(pointer.limit() - position, Integer.MAX_VALUE)];
        pointer.get(array);
        return new Bfloat16ArrayIndexer(array, sizes, strides) {
            @Override public void release() {
                pointer.position(position).put(array);
                super.release();
            }
        };
    }
}
/**
 * Stores {@code h}, encoded with {@code fromFloat}, directly at buffer position {@code i}
 * ({@code strides[0]} is not applied for the one-index form).
 */
@Override public Bfloat16Indexer put(long i, float h) {
    final int flat = (int)i;
    buffer.put(flat, (short)fromFloat(h));
    return this;
}
@Override public Bfloat16Indexer put(long i, float[] h, int offset, int length) {
/**
 * Returns the value at buffer position {@code i}, decoded with {@code toFloat}
 * ({@code strides[0]} is not applied for the one-index form).
 */
@Override public float get(long i) {
    final int flat = (int)i;
    return toFloat(buffer.get(flat));
}
@Override public Bfloat16Indexer get(long i, float[] h, int offset, int length) {