/**
 * Infers one {@link InputType} per supplied array by delegating each element to
 * {@code inferInputType(INDArray)}.
 *
 * @param inputArrays the arrays to inspect; must not be {@code null}
 * @return a new array of inferred input types, positionally aligned with {@code inputArrays}
 */
public static InputType[] inferInputTypes(INDArray... inputArrays) {
  final int count = inputArrays.length;
  final InputType[] inferred = new InputType[count];
  int idx = 0;
  for (INDArray arr : inputArrays) {
    inferred[idx++] = inferInputType(arr);
  }
  return inferred;
}
/** * Build the multilayer network defined by the networkconfiguration and the list of layers. */ protected void createModel() throws Exception { final INDArray features = getFirstBatchFeatures(trainData); ComputationGraphConfiguration.GraphBuilder gb = netConfig.builder().seed(getSeed()).graphBuilder(); // Set ouput size final Layer lastLayer = layers[layers.length - 1]; final int nOut = trainData.numClasses(); if (lastLayer instanceof FeedForwardLayer) { ((FeedForwardLayer) lastLayer).setNOut(nOut); } if (getInstanceIterator() instanceof CnnTextEmbeddingInstanceIterator) { makeCnnTextLayerSetup(gb); } else { makeDefaultLayerSetup(gb); } gb.setInputTypes(InputType.inferInputType(features)); ComputationGraphConfiguration conf = gb.pretrain(false).backprop(true).build(); ComputationGraph model = new ComputationGraph(conf); model.init(); this.model = model; }
// Infers a single InputType from `features` and registers it on the graph builder `gb`.
// NOTE(review): this statement sits outside any visible method. It duplicates the identical call
// inside createModel(), where `gb` is the ComputationGraphConfiguration.GraphBuilder and
// `features` is the first training batch. Confirm that the enclosing scope declares both names.
gb.setInputTypes(InputType.inferInputType(features));