/**
 * Generate a random shape to broadcast to, given a shape that contains 1s.
 * <p>
 * Every dimension equal to 1 is replaced by a random size in [1, 9]; every other dimension
 * is copied unchanged. The global Nd4j random generator is reseeded with {@code seed} first,
 * so the same input and seed yield the same result.
 *
 * @param inputShapeWithOnes the source shape, possibly containing dimensions of size 1
 * @param seed               seed applied to {@code Nd4j.getRandom()} before sampling
 * @return a new shape array of the same rank
 */
public static long[] broadcastToShape(long[] inputShapeWithOnes, long seed) {
    Nd4j.getRandom().setSeed(seed);
    final int rank = inputShapeWithOnes.length;
    long[] result = new long[rank];
    for (int dim = 0; dim < rank; dim++) {
        long size = inputShapeWithOnes[dim];
        // nextInt(9) yields [0, 8], so replacement sizes fall in [1, 9]
        result[dim] = (size == 1) ? Nd4j.getRandom().nextInt(9) + 1 : size;
    }
    return result;
}
/**
 * Generate a random shape to broadcast to, given a shape that contains 1s.
 * <p>
 * Each dimension of size 1 is replaced with a random size in [1, 9]; all other dimensions
 * are kept as-is. The global Nd4j RNG is reseeded with {@code seed} beforehand, making the
 * result reproducible for a given input and seed.
 *
 * @param inputShapeWithOnes the source shape, possibly containing 1s
 * @param seed               seed applied to {@code Nd4j.getRandom()} before sampling
 * @return a new shape array of the same length as {@code inputShapeWithOnes}
 */
public static int[] broadcastToShape(int[] inputShapeWithOnes, long seed) {
    Nd4j.getRandom().setSeed(seed);
    int[] target = new int[inputShapeWithOnes.length];
    int idx = 0;
    for (int size : inputShapeWithOnes) {
        // nextInt(9) is in [0, 8]; +1 shifts to [1, 9]
        target[idx++] = (size == 1) ? Nd4j.getRandom().nextInt(9) + 1 : size;
    }
    return target;
}
/**
 * Sample a dataset.
 * <p>
 * Draws {@code numSamples} example rows from this dataset using {@code rng}. When
 * {@code withReplacement} is false, each example row number is picked at most once.
 *
 * @param numSamples      the number of samples to draw
 * @param rng             the rng to use
 * @param withReplacement whether to allow duplicates (only tracked by example row number)
 * @return the sample dataset
 * @throws IllegalArgumentException if sampling without replacement and {@code numSamples}
 *                                  exceeds the number of available examples
 */
@Override
public DataSet sample(int numSamples, org.nd4j.linalg.api.rng.Random rng, boolean withReplacement) {
    int available = numExamples();
    // Without replacement we can never draw more distinct rows than exist; the rejection
    // loop below would otherwise spin forever.
    if (!withReplacement && numSamples > available)
        throw new IllegalArgumentException("Cannot sample " + numSamples
                        + " examples without replacement from a dataset of " + available + " examples");

    INDArray examples = Nd4j.create(numSamples, getFeatures().columns());
    INDArray outcomes = Nd4j.create(numSamples, numOutcomes());
    Set<Integer> added = new HashSet<>();
    for (int i = 0; i < numSamples; i++) {
        int picked = rng.nextInt(available);
        if (!withReplacement) {
            // Set.add returns false for an already-used row; previously rows were never
            // recorded, so "without replacement" silently allowed duplicates.
            while (!added.add(picked))
                picked = rng.nextInt(available);
        }
        examples.putRow(i, get(picked).getFeatures());
        outcomes.putRow(i, get(picked).getLabels());
    }
    return new DataSet(examples, outcomes);
}
/** * Create an ndarray * of * @param seed * @param rank * @param numShapes * @return */ public static int[][] getRandomBroadCastShape(long seed, int rank, int numShapes) { Nd4j.getRandom().setSeed(seed); INDArray coinFlip = Nd4j.getDistributions().createBinomial(1, 0.5).sample(new int[] {numShapes, rank}); int[][] ret = new int[(int) coinFlip.rows()][(int) coinFlip.columns()]; for (int i = 0; i < coinFlip.rows(); i++) { for (int j = 0; j < coinFlip.columns(); j++) { int set = coinFlip.getInt(i, j); if (set > 0) ret[i][j] = set; else { //anything from 0 to 9 ret[i][j] = Nd4j.getRandom().nextInt(9) + 1; } } } return ret; }
/**
 * Draw an int from the wrapped nd4j random generator.
 *
 * @param maxExclusive upper bound forwarded unchanged to {@code nd4jRandom.nextInt}
 * @return whatever the wrapped generator returns
 */
public int nextInt(int maxExclusive) {
    final int drawn = nd4jRandom.nextInt(maxExclusive);
    return drawn;
}
/**
 * Set the global Nd4j data type to {@code bufferType}, then request an array of the given
 * shape from the wrapped nd4j random generator's {@code nextInt(long[])}.
 * <p>
 * NOTE(review): {@code Nd4j.setDataType} changes process-wide state and is never restored
 * here; confirm this side effect is intended.
 *
 * @param shape shape of the array to generate
 * @return the array produced by {@code nd4jRandom.nextInt(shape)}
 */
private INDArray doubleNextInt(long[] shape) {
    Nd4j.setDataType(bufferType);
    INDArray sampled = nd4jRandom.nextInt(shape);
    return sampled;
}
/** * Get the device number for a particular host thread * @return the device for the given host thread * */ public int getDeviceForThread() { if(numDevices > 1) { Integer device = threadNameToDeviceNumber.get(Thread.currentThread().getName()); if(device == null) { org.nd4j.linalg.api.rng.Random random = Nd4j.getRandom(); if(random == null) throw new IllegalStateException("Unable to load random class"); device = Nd4j.getRandom().nextInt(numDevices); //reroute banned devices while(bannedDevices != null && bannedDevices.contains(device)) device = Nd4j.getRandom().nextInt(numDevices); threadNameToDeviceNumber.put(Thread.currentThread().getName(),device); return device; } } return 0; }
/**
 * Generate a random shape to broadcast to from a shape that contains 1s.
 * <p>
 * Dimensions that are not 1 are copied through. Each dimension equal to 1 becomes a random
 * size between 1 and 9 inclusive. The global Nd4j RNG is reseeded with {@code seed} first.
 *
 * @param inputShapeWithOnes the source shape, possibly containing 1s
 * @param seed               seed applied to {@code Nd4j.getRandom()} before sampling
 * @return a new shape array of the same length as the input
 */
public static int[] broadcastToShape(int[] inputShapeWithOnes, long seed) {
    Nd4j.getRandom().setSeed(seed);
    final int n = inputShapeWithOnes.length;
    int[] out = new int[n];
    int i = 0;
    while (i < n) {
        int dim = inputShapeWithOnes[i];
        if (dim != 1) {
            out[i] = dim;
        } else {
            out[i] = Nd4j.getRandom().nextInt(9) + 1;
        }
        i++;
    }
    return out;
}
/** * Takes an image and returns a randomly cropped image. * * @param image to transform, null == end of stream * @param random object to use (or null for deterministic) * @return transformed image */ @Override public ImageWritable transform(ImageWritable image, Random random) { if (image == null) { return null; } // ensure that transform is valid if (image.getFrame().imageHeight < outputHeight || image.getFrame().imageWidth < outputWidth) throw new UnsupportedOperationException( "Output height/width cannot be more than the input image. Requested: " + outputHeight + "+x" + outputWidth + ", got " + image.getFrame().imageHeight + "+x" + image.getFrame().imageWidth); // determine boundary to place random offset int cropTop = image.getFrame().imageHeight - outputHeight; int cropLeft = image.getFrame().imageWidth - outputWidth; Mat mat = converter.convert(image.getFrame()); int top = rng.nextInt(cropTop + 1); int left = rng.nextInt(cropLeft + 1); int y = Math.min(top, mat.rows() - 1); int x = Math.min(left, mat.cols() - 1); Mat result = mat.apply(new Rect(x, y, outputWidth, outputHeight)); return new ImageWritable(converter.convert(result)); }
/**
 * Sample a dataset.
 * <p>
 * Draws {@code numSamples} example rows from this dataset using {@code rng}. When
 * {@code withReplacement} is false, each example row number is picked at most once.
 *
 * @param numSamples      the number of samples to draw
 * @param rng             the rng to use
 * @param withReplacement whether to allow duplicates (only tracked by example row number)
 * @return the sample dataset
 * @throws IllegalArgumentException if sampling without replacement and {@code numSamples}
 *                                  exceeds the number of available examples
 */
@Override
public DataSet sample(int numSamples, org.nd4j.linalg.api.rng.Random rng, boolean withReplacement) {
    int available = numExamples();
    // Without replacement we can never draw more distinct rows than exist; the rejection
    // loop below would otherwise spin forever.
    if (!withReplacement && numSamples > available)
        throw new IllegalArgumentException("Cannot sample " + numSamples
                        + " examples without replacement from a dataset of " + available + " examples");

    INDArray examples = Nd4j.create(numSamples, getFeatures().columns());
    INDArray outcomes = Nd4j.create(numSamples, numOutcomes());
    Set<Integer> added = new HashSet<>();
    for (int i = 0; i < numSamples; i++) {
        int picked = rng.nextInt(available);
        if (!withReplacement) {
            // Set.add returns false for an already-used row; previously rows were never
            // recorded, so "without replacement" silently allowed duplicates.
            while (!added.add(picked))
                picked = rng.nextInt(available);
        }
        examples.putRow(i, get(picked).getFeatures());
        outcomes.putRow(i, get(picked).getLabels());
    }
    return new DataSet(examples, outcomes);
}
/** * Create an ndarray * of * @param seed * @param rank * @param numShapes * @return */ public static int[][] getRandomBroadCastShape(long seed ,int rank,int numShapes) { Nd4j.getRandom().setSeed(seed); INDArray coinFlip = Nd4j.getDistributions().createBinomial(1,0.5).sample(new int[]{numShapes,rank}); int[][] ret = new int[coinFlip.rows()][coinFlip.columns()]; for(int i = 0; i < coinFlip.rows(); i++) { for(int j = 0; j < coinFlip.columns(); j++) { int set = coinFlip.getInt(i,j); if(set > 0) ret[i][j] = set; else { //anything from 0 to 9 ret[i][j] = Nd4j.getRandom().nextInt(9) + 1; } } } return ret; }