/**
 * Builds the node data from a single parquet row.
 *
 * <p>The id is read from field 0, the prediction from field 1 and the left/right child ids
 * from fields 5 and 6. When both child ids equal {@code -1} the node is flagged as a leaf and
 * its feature index and threshold are set to {@code -1}. Otherwise the feature index and the
 * threshold are read from the nested group stored in field 7.
 *
 * @param g The given group presenting the node data from Spark DT model.
 * @return The populated node data.
 */
@NotNull
private static SparkModelParser.NodeData extractNodeDataFromParquetRow(SimpleGroup g) {
    final NodeData node = new NodeData();

    node.id = g.getInteger(0, 0);
    node.prediction = g.getDouble(1, 0);
    node.leftChildId = g.getInteger(5, 0);
    node.rightChildId = g.getInteger(6, 0);

    final boolean hasChild = node.leftChildId != -1 || node.rightChildId != -1;

    if (hasChild) {
        // Inner node: the split description lives in the group at field 7.
        final SimpleGroup split = (SimpleGroup)g.getGroup(7, 0);

        node.featureIdx = split.getInteger(0, 0);
        node.threshold = split.getGroup(1, 0).getGroup(0, 0).getDouble(0, 0);

        return node;
    }

    // No children on either side: mark as a leaf with placeholder split values.
    node.featureIdx = -1;
    node.threshold = -1;
    node.isLeafNode = true;

    return node;
}
/**
 * Reads the interceptor value from a parquet group.
 *
 * <p>Walks the nested groups at field 2, then field 3, then the first element of field 0, and
 * returns the double stored in field 0 of that innermost group.
 *
 * @param g Interceptor group.
 * @return The interceptor value.
 */
private static double readInterceptor(SimpleGroup g) {
    final SimpleGroup vector = (SimpleGroup)g.getGroup(2, 0);
    final SimpleGroup values = (SimpleGroup)vector.getGroup(3, 0);
    final SimpleGroup firstElement = (SimpleGroup)values.getGroup(0, 0);

    return firstElement.getDouble(0, 0);
}
// Read the tree id from field 0 of the row, then fetch the nested group stored in field 1
// (kept as nodeDataGroup and cast to SimpleGroup).
final int treeID = g.getInteger(0, 0); final SimpleGroup nodeDataGroup = (SimpleGroup)g.getGroup(1, 0);
/**
 * Returns the value stored for the given field at the given position, cast to {@link Group}.
 *
 * @param fieldIndex Index of the field.
 * @param index Position of the value within the field.
 * @return The stored value as a group.
 */
@Override
public Group getGroup(int fieldIndex, int index) {
    final Object stored = getValue(fieldIndex, index);

    return (Group)stored;
}
/**
 * Adds a boolean to the given field by wrapping it in a {@code BooleanValue} and forwarding
 * to the {@code add} overload that accepts the wrapper.
 *
 * @param fieldIndex Index of the field.
 * @param value Value to add.
 */
@Override
public void add(int fieldIndex, boolean value) {
    final BooleanValue wrapped = new BooleanValue(value);

    add(fieldIndex, wrapped);
}
/**
 * Creates a new {@code SimpleGroup} for the group type declared at the given field of the
 * schema, adds it to that field and returns it.
 *
 * @param fieldIndex Index of the field.
 * @return The newly created and added group.
 */
@Override
public Group addGroup(int fieldIndex) {
    final SimpleGroup child = new SimpleGroup(schema.getType(fieldIndex).asGroupType());

    add(fieldIndex, child);

    return child;
}
/**
 * Adds a binary value to the given field, choosing the wrapper by the field's primitive type:
 * {@code BINARY} and {@code FIXED_LEN_BYTE_ARRAY} are wrapped in a {@code BinaryValue},
 * {@code INT96} in an {@code Int96Value}.
 *
 * @param fieldIndex Index of the field.
 * @param value Binary value to add.
 * @throws UnsupportedOperationException If the field's primitive type is none of the above.
 */
@Override
public void add(int fieldIndex, Binary value) {
    switch (getType().getType(fieldIndex).asPrimitiveType().getPrimitiveTypeName()) {
        case BINARY:
        case FIXED_LEN_BYTE_ARRAY:
            add(fieldIndex, new BinaryValue(value));
            break;

        case INT96:
            add(fieldIndex, new Int96Value(value));
            break;

        default:
            // Report the field that was actually targeted: resolve it the same way the switch
            // does, instead of treating the type of the whole group as a primitive.
            throw new UnsupportedOperationException(
                getType().getType(fieldIndex).asPrimitiveType().getName() + " not supported for Binary");
    }
}
/**
 * Creates a new {@code SimpleGroup} built on this object's schema.
 *
 * @return The new group.
 */
@Override
public Group newGroup() {
    final SimpleGroup fresh = new SimpleGroup(schema);

    return fresh;
}
/**
 * Delegates to {@code toString(String)} with an empty string argument.
 *
 * @return The string produced by the delegate.
 */
@Override
public String toString() {
    final String empty = "";

    return toString(empty);
}
/**
 * Reads the coefficients from a parquet group into a dense vector.
 *
 * <p>The coefficients are taken from the group nested at field 3, then field 5. The
 * repetition count of its field 0 gives the vector size, and each repeated element supplies
 * one double from its own field 0.
 *
 * @param g Coefficient group.
 * @return Vector of coefficients.
 */
private static Vector readCoefficients(SimpleGroup g) {
    final int size = g.getGroup(3, 0).getGroup(5, 0).getFieldRepetitionCount(0);
    final Vector res = new DenseVector(size);

    int idx = 0;

    while (idx < size) {
        res.set(idx, g.getGroup(3, 0).getGroup(5, 0).getGroup(0, idx).getDouble(0, 0));

        idx++;
    }

    return res;
}
/**
 * Reads the interceptor value from a parquet group: the double stored in field 0.
 *
 * @param g Interceptor group.
 * @return The interceptor value.
 */
private static double readLinRegInterceptor(SimpleGroup g) {
    final double interceptor = g.getDouble(0, 0);

    return interceptor;
}
/**
 * Returns the value stored for the given field at the given position, cast to {@link Group}.
 *
 * @param fieldIndex Index of the field.
 * @param index Position of the value within the field.
 * @return The stored value as a group.
 */
@Override
public Group getGroup(int fieldIndex, int index) {
    final Object stored = getValue(fieldIndex, index);

    return (Group)stored;
}
/**
 * Adds a float to the given field by wrapping it in a {@code FloatValue} and forwarding to
 * the {@code add} overload that accepts the wrapper.
 *
 * @param fieldIndex Index of the field.
 * @param value Value to add.
 */
@Override
public void add(int fieldIndex, float value) {
    final FloatValue wrapped = new FloatValue(value);

    add(fieldIndex, wrapped);
}
/**
 * Creates a new {@code SimpleGroup} for the group type declared at the given field of the
 * schema, adds it to that field and returns it.
 *
 * @param fieldIndex Index of the field.
 * @return The newly created and added group.
 */
@Override
public Group addGroup(int fieldIndex) {
    final SimpleGroup child = new SimpleGroup(schema.getType(fieldIndex).asGroupType());

    add(fieldIndex, child);

    return child;
}
/**
 * Adds a binary value to the given field, choosing the wrapper by the field's primitive type:
 * {@code BINARY} and {@code FIXED_LEN_BYTE_ARRAY} are wrapped in a {@code BinaryValue},
 * {@code INT96} in an {@code Int96Value}.
 *
 * @param fieldIndex Index of the field.
 * @param value Binary value to add.
 * @throws UnsupportedOperationException If the field's primitive type is none of the above.
 */
@Override
public void add(int fieldIndex, Binary value) {
    switch (getType().getType(fieldIndex).asPrimitiveType().getPrimitiveTypeName()) {
        case BINARY:
        case FIXED_LEN_BYTE_ARRAY:
            add(fieldIndex, new BinaryValue(value));
            break;

        case INT96:
            add(fieldIndex, new Int96Value(value));
            break;

        default:
            // Report the field that was actually targeted: resolve it the same way the switch
            // does, instead of treating the type of the whole group as a primitive.
            throw new UnsupportedOperationException(
                getType().getType(fieldIndex).asPrimitiveType().getName() + " not supported for Binary");
    }
}
/**
 * Creates a new {@code SimpleGroup} built on this object's schema.
 *
 * @return The new group.
 */
@Override
public Group newGroup() {
    final SimpleGroup fresh = new SimpleGroup(schema);

    return fresh;
}
/**
 * Delegates to {@code toString(String)} with an empty string argument.
 *
 * @return The string produced by the delegate.
 */
@Override
public String toString() {
    final String empty = "";

    return toString(empty);
}
/**
 * Reads the coefficients from a parquet group into a dense vector.
 *
 * <p>The coefficients are taken from the group nested at field 1, then field 3. The
 * repetition count of its field 0 gives the vector size, and each repeated element supplies
 * one double from its own field 0.
 *
 * @param g Coefficient group.
 * @return Vector of coefficients.
 */
private static Vector readLinRegCoefficients(SimpleGroup g) {
    final Group values = g.getGroup(1, 0).getGroup(3, 0);
    final int size = values.getFieldRepetitionCount(0);
    final Vector res = new DenseVector(size);

    for (int pos = 0; pos < size; pos++)
        res.set(pos, values.getGroup(0, pos).getDouble(0, 0));

    return res;
}
/**
 * Reads the interceptor value from a parquet group: the double stored in field 1.
 *
 * @param g Interceptor group.
 * @return The interceptor value.
 */
private static double readSVMInterceptor(SimpleGroup g) {
    final double interceptor = g.getDouble(1, 0);

    return interceptor;
}
/**
 * Returns the string form, via {@link String#valueOf(Object)}, of the value stored for the
 * given field at the given position.
 *
 * @param fieldIndex Index of the field.
 * @param index Position of the value within the field.
 * @return String representation of the stored value.
 */
@Override
public String getValueToString(int fieldIndex, int index) {
    final Object stored = getValue(fieldIndex, index);

    return String.valueOf(stored);
}