/**
 * Stores {@code h} at linear index {@code i} (validated by {@code checkIndex} against {@code size}),
 * converting it with {@code fromFloat} and writing the resulting 16 bits through {@code RAW}.
 */
@Override public Bfloat16Indexer put(long i, float h) {
    final long address = base + checkIndex(i, size) * VALUE_BYTES;
    final short bits = (short)fromFloat(h);
    RAW.putShort(address, bits);
    return this;
}
@Override public Bfloat16Indexer put(long i, float[] h, int offset, int length) {
/**
 * Reads the 16-bit value at linear index {@code i} (validated by {@code checkIndex} against
 * {@code size}) through {@code RAW} and converts it to a float with {@code toFloat}.
 */
@Override public float get(long i) {
    final long address = base + checkIndex(i, size) * VALUE_BYTES;
    final short bits = RAW.getShort(address);
    return toFloat(bits);
}
@Override public Bfloat16Indexer get(long i, float[] h, int offset, int length) {
/**
 * Writes {@code length} values taken from {@code h[offset..]} to consecutive linear indices
 * starting at {@code i * strides[0]}.
 */
@Override public Bfloat16Indexer put(long i, float[] h, int offset, int length) {
    int n = 0;
    while (n < length) {
        // The base index is recomputed on each pass so strides is not read when length <= 0.
        put(i * strides[0] + n, h[offset + n]);
        n++;
    }
    return this;
}
@Override public Bfloat16Indexer put(long i, long j, float h) {
/**
 * Reads {@code length} values from consecutive linear indices starting at {@code i * strides[0]}
 * into {@code h[offset..]}.
 */
@Override public Bfloat16Indexer get(long i, float[] h, int offset, int length) {
    int n = 0;
    while (n < length) {
        final float value = get(i * strides[0] + n);
        h[offset + n] = value;
        n++;
    }
    return this;
}
@Override public float get(long i, long j) {
/**
 * Creates a bfloat16 indexer to access efficiently the data of a pointer.
 *
 * @param pointer data to access via a buffer or to copy to an array
 * @param sizes dimension sizes forwarded to the created indexer
 * @param strides strides forwarded to the created indexer
 * @param direct {@code true} to use a direct buffer, see {@link Indexer} for details
 * @return the new bfloat16 indexer backed by the raw memory interface, a buffer, or an array
 */
public static Bfloat16Indexer create(final ShortPointer pointer, long[] sizes, long[] strides, boolean direct) {
    if (direct) {
        // Use the raw memory interface when it is available; otherwise view the pointer as a buffer.
        return Raw.getInstance() != null ? new Bfloat16RawIndexer(pointer, sizes, strides)
                                         : new Bfloat16BufferIndexer(pointer.asBuffer(), sizes, strides);
    } else {
        // Copy the pointer's data into a heap array sized (limit - position), capped at
        // Integer.MAX_VALUE so the length fits in an int.
        final long position = pointer.position();
        final short[] array = new short[(int)Math.min(pointer.limit() - position, Integer.MAX_VALUE)];
        pointer.get(array);
        return new Bfloat16ArrayIndexer(array, sizes, strides) {
            /** Writes the (possibly modified) array back at the original position before releasing. */
            @Override public void release() {
                pointer.position(position).put(array);
                super.release();
            }
        };
    }
}
/** Stores {@code h} at the linear index {@code i * strides[0] + j}. */
@Override public Bfloat16Indexer put(long i, long j, float h) {
    final long index = i * strides[0] + j;
    put(index, h);
    return this;
}
@Override public Bfloat16Indexer put(long i, long j, float[] h, int offset, int length) {
/**
 * Reads {@code length} values from consecutive linear indices starting at
 * {@code i * strides[0] + j * strides[1]} into {@code h[offset..]}.
 */
@Override public Bfloat16Indexer get(long i, long j, float[] h, int offset, int length) {
    int n = 0;
    while (n < length) {
        final float value = get(i * strides[0] + j * strides[1] + n);
        h[offset + n] = value;
        n++;
    }
    return this;
}
@Override public float get(long i, long j, long k) {
/**
 * Writes {@code length} values taken from {@code h[offset..]} to consecutive linear indices
 * starting at {@code i * strides[0] + j * strides[1]}.
 */
@Override public Bfloat16Indexer put(long i, long j, float[] h, int offset, int length) {
    int n = 0;
    while (n < length) {
        put(i * strides[0] + j * strides[1] + n, h[offset + n]);
        n++;
    }
    return this;
}
@Override public Bfloat16Indexer put(long i, long j, long k, float h) {
/** Returns the value at the linear index {@code i * strides[0] + j}. */
@Override public float get(long i, long j) {
    final long index = i * strides[0] + j;
    return get(index);
}
@Override public Bfloat16Indexer get(long i, long j, float[] h, int offset, int length) {
/** Stores {@code h} at the linear index {@code i * strides[0] + j * strides[1] + k}. */
@Override public Bfloat16Indexer put(long i, long j, long k, float h) {
    final long index = i * strides[0] + j * strides[1] + k;
    put(index, h);
    return this;
}
@Override public Bfloat16Indexer put(long[] indices, float h) {
/** Returns the value at the linear index {@code i * strides[0] + j * strides[1] + k}. */
@Override public float get(long i, long j, long k) {
    final long index = i * strides[0] + j * strides[1] + k;
    return get(index);
}
@Override public float get(long... indices) {