/**
 * Builds a single basis function of the requested degree by recursion.
 * <p>
 * Delegates to the four-argument overload without a pre-built set of lower-degree functions
 * (passes {@code null} for it — presumably the overload then builds what it needs itself; confirm).
 *
 * @param knots the knots that support the basis functions
 * @param degree the required degree
 * @param index the index of the required function
 * @return the basis function
 */
private Function<Double, Double> generate(final double[] knots, final int degree, final int index) {
  final List<Function<Double, Double>> noLowerDegreeSet = null;
  return generate(knots, degree, index, noLowerDegreeSet);
}
/**
 * Generates a set of N-dimensional b-splines. Each one is formed as the product of
 * one-dimensional b-splines, one per dimension, with the polynomial degree and knots given for
 * that dimension.
 * <p>
 * The returned list has one entry per combination of 1-D spline indices. The entry at position
 * {@code i} is built from the 1-D indices {@code FunctionUtils.fromTensorIndex(i, nSplines)},
 * where {@code nSplines[d]} is the number of splines in dimension {@code d}.
 *
 * @param knots holder for the knots and degree in each dimension; checked with
 *   {@code ArgChecker.noNulls}
 * @return a list of functions whose size is the product of the per-dimension spline counts
 * @throws ArithmeticException if that product overflows an {@code int}
 */
public List<Function<double[], Double>> generateSet(BasisFunctionKnots[] knots) {
  ArgChecker.noNulls(knots, "knots");
  int dim = knots.length;
  int[] nSplines = new int[dim];
  int product = 1;
  List<List<Function<Double, Double>>> oneDSets = new ArrayList<>(dim);
  for (int i = 0; i < dim; i++) {
    oneDSets.add(generateSet(knots[i]));
    nSplines[i] = knots[i].getNumSplines();
    // The tensor-product size grows multiplicatively with the dimension. Fail fast on overflow
    // instead of wrapping to a wrong (possibly negative) list capacity and function count.
    product = Math.multiplyExact(product, nSplines[i]);
  }
  List<Function<double[], Double>> functions = new ArrayList<>(product);
  for (int i = 0; i < product; i++) {
    int[] indices = FunctionUtils.fromTensorIndex(i, nSplines);
    functions.add(generateMultiDim(oneDSets, indices));
  }
  return functions;
}
/**
 * Fits a curve to x-y data using b-spline basis functions on uniformly spaced knots, with a
 * penalty on differences of the fitted weights.
 *
 * @param x the independent variables
 * @param y the dependent variables
 * @param sigma the error (or tolerance) on the y variables
 * @param xa the lowest value of x
 * @param xb the highest value of x
 * @param nKnots number of knots (note: the actual number of basis splines, and thus of fitted
 *   weights, equals nKnots + degree - 1)
 * @param degree the degree of the basis functions: 0 is piecewise constant, 1 is a sawtooth
 *   (two straight lines joined in the middle), 2 gives three quadratic sections joined together,
 *   and so on. For a large degree, the basis function tends to a Gaussian.
 * @param lambda the weight given to the penalty function
 * @param differenceOrder the order of the difference in the weights that the penalty is applied
 *   to; for example, a differenceOrder of 2 penalises large second derivatives
 * @return the results of the fit
 */
public GeneralizedLeastSquareResults<Double> solve(List<Double> x, List<Double> y, List<Double> sigma, double xa, double xb, int nKnots, int degree, double lambda, int differenceOrder) {
  final BasisFunctionKnots uniformKnots = BasisFunctionKnots.fromUniform(xa, xb, nKnots, degree);
  final List<Function<Double, Double>> basis = _generator.generateSet(uniformKnots);
  return _gls.solve(x, y, sigma, basis, lambda, differenceOrder);
}
public void testPSplineFit2() { final BasisFunctionGenerator generator = new BasisFunctionGenerator(); List<Function<Double, Double>> basisFuncs = generator.generateSet(BasisFunctionKnots.fromUniform(0, 12, 100, 3)); List<Function<Double, Double>> basisFuncsLog = generator.generateSet(BasisFunctionKnots.fromUniform(-5, 3, 100, 3));
/**
 * Generates the full set of b-splines of the requested polynomial degree on the given knots.
 * <p>
 * The set is built up one degree at a time, from 0 to the target degree. Each pass receives
 * the set produced by the previous pass (or {@code null} on the first pass).
 *
 * @param knots holder for the knots and degree
 * @return a list of functions (null if the degree held by {@code knots} is negative, since no
 *   pass runs in that case)
 */
public List<Function<Double, Double>> generateSet(BasisFunctionKnots knots) {
  ArgChecker.notNull(knots, "knots");
  final double[] knotValues = knots.getKnots();
  final int targetDegree = knots.getDegree();
  List<Function<Double, Double>> current = null;
  int currentDegree = 0;
  while (currentDegree <= targetDegree) {
    current = generateSet(knotValues, currentDegree, current);
    currentDegree++;
  }
  return current;
}
/**
 * Generates every basis function of a single degree on the given knots.
 *
 * @param knots the knots that support the basis functions
 * @param degree the degree of the functions to generate
 * @param degreeM1Set the previously generated set of degree {@code degree - 1} (may be null),
 *   forwarded unchanged to each per-index call
 * @return a list of {@code knots.length - degree - 1} functions, ordered by index
 */
private List<Function<Double, Double>> generateSet(final double[] knots, final int degree, final List<Function<Double, Double>> degreeM1Set) {
  final int count = knots.length - degree - 1;
  final List<Function<Double, Double>> result = new ArrayList<>(count);
  int splineIndex = 0;
  while (splineIndex < count) {
    result.add(generate(knots, degree, splineIndex, degreeM1Set));
    splineIndex++;
  }
  return result;
}
knots[i] = BasisFunctionKnots.fromUniform(xa[i], xb[i], nKnots[i], degree[i]); List<Function<double[], Double>> bSplines = _generator.generateSet(knots);
@Test public void testTwoD() { BasisFunctionKnots knots1 = BasisFunctionKnots.fromInternalKnots(KNOTS, 2); BasisFunctionKnots knots2 = BasisFunctionKnots.fromInternalKnots(KNOTS, 3); List<Function<double[], Double>> set = GENERATOR.generateSet(new BasisFunctionKnots[] {knots1, knots2 }); //pick of one of the basis functions for testing int index = FunctionUtils.toTensorIndex(new int[] {3, 3 }, new int[] {knots1.getNumSplines(), knots2.getNumSplines() }); Function<double[], Double> func = set.get(index); assertEquals(1. / 3., func.apply(new double[] {2.0, 2.0 }), 0.0); assertEquals(1. / 2., func.apply(new double[] {2.5, 2.0 }), 0.0); assertEquals(1. / 8. / 48., func.apply(new double[] {1.5, 3.5 }), 0.0); assertEquals(0.0, func.apply(new double[] {4.0, 2.5 }), 0.0); }
@Test(expectedExceptions = IllegalArgumentException.class)
public void testFunctionIndexOutOfRange1() {
  final BasisFunctionKnots quadraticKnots = BasisFunctionKnots.fromKnots(KNOTS, 2);
  // indices start at zero, so a negative index must be rejected
  final int negativeIndex = -1;
  GENERATOR.generate(quadraticKnots, negativeIndex);
}
@Test public void testThreeD() { BasisFunctionKnots knots1 = BasisFunctionKnots.fromInternalKnots(KNOTS, 2); BasisFunctionKnots knots2 = BasisFunctionKnots.fromInternalKnots(KNOTS, 3); BasisFunctionKnots knots3 = BasisFunctionKnots.fromInternalKnots(KNOTS, 1); List<Function<double[], Double>> set = GENERATOR.generateSet(new BasisFunctionKnots[] {knots1, knots2, knots3 }); //pick of one of the basis functions for testing int index = FunctionUtils.toTensorIndex(new int[] {3, 3, 3 }, new int[] {knots1.getNumSplines(), knots2.getNumSplines(), knots3.getNumSplines() }); Function<double[], Double> func = set.get(index); assertEquals(1. / 3., func.apply(new double[] {2.0, 2.0, 3.0 }), 0.0); }
@Test(expectedExceptions = IllegalArgumentException.class)
public void testFunctionIndexOutOfRange2() {
  final BasisFunctionKnots quinticKnots = BasisFunctionKnots.fromKnots(KNOTS, 5);
  // valid indices are 0 .. getNumSplines() - 1, so getNumSplines() itself is one past the end
  final int onePastLast = quinticKnots.getNumSplines();
  GENERATOR.generate(quinticKnots, onePastLast);
}
@Test
public void testFirstOrder() {
  final BasisFunctionKnots linearKnots = BasisFunctionKnots.fromInternalKnots(KNOTS, 1);
  final Function<Double, Double> basis = GENERATOR.generate(linearKnots, 3);
  // sample points and the expected value of the degree-1 basis function (index 3) at each
  final double[] xs = {1.76, 3.0, 4.0, 2.5 };
  final double[] expected = {0.0, 1.0, 0, 0.5 };
  for (int i = 0; i < xs.length; i++) {
    assertEquals(expected[i], basis.apply(xs[i]), 0.0);
  }
}
@Test
public void testZeroOrder() {
  final BasisFunctionKnots constantKnots = BasisFunctionKnots.fromInternalKnots(KNOTS, 0);
  final Function<Double, Double> basis = GENERATOR.generate(constantKnots, 4);
  // sample points and the expected value of the degree-0 basis function (index 4) at each
  final double[] xs = {3.5, 4.78, 4.0, 5.0 };
  final double[] expected = {0.0, 1.0, 1.0, 0.0 };
  for (int i = 0; i < xs.length; i++) {
    assertEquals(expected[i], basis.apply(xs[i]), 0.0);
  }
}
@Test
public void testSecondOrder() {
  final BasisFunctionKnots quadraticKnots = BasisFunctionKnots.fromInternalKnots(KNOTS, 2);
  final Function<Double, Double> basis = GENERATOR.generate(quadraticKnots, 3);
  // sample points and the expected value of the degree-2 basis function (index 3) at each
  final double[] xs = {0.76, 1.5, 2.0, 2.5, 4.0 };
  final double[] expected = {0.0, 0.125, 0.5, 0.75, 0.0 };
  for (int i = 0; i < xs.length; i++) {
    assertEquals(expected[i], basis.apply(xs[i]), 0.0);
  }
}
@Test
public void testThirdOrder() {
  final BasisFunctionKnots cubicKnots = BasisFunctionKnots.fromInternalKnots(KNOTS, 3);
  final Function<Double, Double> basis = GENERATOR.generate(cubicKnots, 3);
  // sample points and the expected value of the degree-3 basis function (index 3) at each
  final double[] xs = {-0.1, 1.0, 2.0, 3.5, 4.0 };
  final double[] expected = {0.0, 1. / 6., 2. / 3., 1 / 48., 0.0 };
  for (int i = 0; i < xs.length; i++) {
    assertEquals(expected[i], basis.apply(xs[i]), 0.0);
  }
}
/**
 * Generates the basis function with the given zero-based index.
 *
 * @param data container for the knots and degree of the basis function
 * @param index the index, from zero, of the function; must satisfy
 *   {@code 0 <= index < data.getNumSplines()}. For example, if the degree is 1 and the index
 *   is 0, the function covers the first three knots.
 * @return the basis function with the requested index
 * @throws IllegalArgumentException if {@code index} is outside that range
 */
protected Function<Double, Double> generate(BasisFunctionKnots data, final int index) {
  ArgChecker.notNull(data, "data");
  final int numSplines = data.getNumSplines();
  final boolean indexInRange = index >= 0 && index < numSplines;
  ArgChecker.isTrue(indexInRange, "index must be in range {} to {} (exclusive)", 0, numSplines);
  return generate(data.getKnots(), data.getDegree(), index);
}