solve(d, x, currentColumn+1, b1, c1, d1); return; } else { x[currentColumn] = anythingIfNotNaN(x[currentColumn+1]); return; solve(d, x, currentColumn+1, c0, 0, d0); solve(d, x, currentColumn+1, b1, c1, d1); solve(d, x, currentColumn+1, (b0 * b1 - a1 * c0), (b0 * c1),
private double anythingIfNotNaN (double testValue) { return ifNotNaN(testValue, 1.0); } private double ifNotNaN (double testValue, double value) {
private void solve (Vector d, double[] x, int currentColumn, double b0, double c0, double d0) { // Solve the case M x = d for x, where we (the tri-diagonal matrix) is m // // We do this recursively, solving one place at a time, pretending in // each case we are at the top left of the matrix. // // So, in each case, we have one of three cases, depending on whether // there are one, two, or three-or-more rows left to solve. int rowsLeft = _n-currentColumn; if (0 == rowsLeft) { return; // Nothing left to solve } else if (1 == rowsLeft) { solveSingleRow(d, x, currentColumn, b0, c0, d0); } else if (2 == rowsLeft) { solveDoubleRow(d, x, currentColumn, b0, c0, d0); } else { solveTripleRow(d, x, currentColumn, b0, c0, d0); } }
TriDiagonalMatrix M = new TriDiagonalMatrix(matrixEntries); Vector Y = getCoordinateVector(Ys, i); Vector K = getCoordinateVector(Ks, i); Vector P1i = M.solve(Y);
M = new TriDiagonalMatrix(0, 1, 0, 1, 2, 3, 4); D = new Vector(4, 7, 17); X = M.solve(D); Assert.assertEquals(vNaN, X); X = M.solve(D); Assert.assertEquals(3, X.size()); Assert.assertFalse(Double.isNaN(X.coord(0))); M = new TriDiagonalMatrix(0, 1, 2, 3, 4, 5, 6); X0 = new Vector(1, 2, 3); D = M.times(X0); X1 = M.solve(D); Assert.assertEquals(X0, X1); M = new TriDiagonalMatrix(1, 2, 0, 3, 4, 5, 6); X0 = new Vector(1, 2, 3); D = M.times(X0); X1 = M.solve(D); Assert.assertEquals(X0, X1); M = new TriDiagonalMatrix(1, 2, 3, 4, 5,
/** * Find the X for which this*X=d * * Taken from @see <a href="http://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm></a>, but modified to * use recursion instead of iteration, and thereby to handle degenerate * cases cleanly. * * @param d * The result (<code>d</code>) in the above equation * @return The <code>X</code> in the above equation */ public Vector solve (Vector d) { if (d.size() != _n) throw new IllegalArgumentException("Attempt to find tri-diagonal solution with improper-sized vector"); double[] x = new double[_n]; solve(d, x, 0, (_n > 0 ? _b[0] : 0), (_n > 1 ? _c[0] : 0), (_n > 0 ? d.coord(0) : 0)); //solve3(d, x, 1, _b[0], _c[0], 0, d.coord(0)); return new Vector(x); }
double sln = solveDegenerate2D(a1, b1, d1, b0, c0, d0); x[currentColumn] = sln; x[currentColumn+1] = anythingIfNotNaN(sln); } else if (Math.abs(b0) >= EPSILON) { double sln = solveDegenerate2D(b0, c0, d0, b1, a1, d1); x[currentColumn] = sln; x[currentColumn+1] = anythingIfNotNaN(sln); } else if (Math.abs(c0) >= EPSILON) { double sln = solveDegenerate2D(c0, b0, d0, b1, a1, d1); x[currentColumn] = anythingIfNotNaN(sln); x[currentColumn+1] = sln; } else { double sln = solveDegenerate2D(b1, a1, d1, c0, b0, d0); x[currentColumn] = anythingIfNotNaN(sln); x[currentColumn+1] = sln;
@Test public void testMultiplication () { TriDiagonalMatrix M = new TriDiagonalMatrix(1, 1, 1, 1, 1, 1, 1, 1, 1, 1); Vector X = new Vector(1, 2, 3, 4); Vector Y = M.times(X); Assert.assertEquals(4, Y.size()); Assert.assertEquals(3, Y.coord(0), EPSILON); Assert.assertEquals(7, Y.coord(3), EPSILON); M = new TriDiagonalMatrix(1, 2, 3, 4, 5, 6, 7, 8, 3, 2); X = new Vector(1, 2, 3, 4, 5, 6); Y = M.times(X); Assert.assertEquals(6, Y.size()); Assert.assertEquals(5, Y.coord(0), EPSILON); M = new TriDiagonalMatrix(1, 2, 1, 2); X = new Vector(3, 4); Y = M.times(X); Assert.assertEquals(2, Y.size()); Assert.assertEquals(11, Y.coord(0), EPSILON);
@Test
public void test1DSolving () {
    // Degenerate 1x1 matrix whose only entry is 0.
    TriDiagonalMatrix zero = new TriDiagonalMatrix(0);

    // 0 * x = 0 holds for every x; the test expects the solver to pick 1.
    Vector rhs = new Vector(0);
    Vector solved = zero.solve(rhs);
    Assert.assertEquals(new Vector(1), solved);

    // 0 * x = 1 has no solution; the test expects NaN.
    rhs = new Vector(1);
    solved = zero.solve(rhs);
    Assert.assertEquals(new Vector(Double.NaN), solved);

    // 1x1 matrix whose only entry is 1: the solution equals the right-hand side.
    TriDiagonalMatrix unit = new TriDiagonalMatrix(1);

    rhs = new Vector(0);
    solved = unit.solve(rhs);
    Assert.assertEquals(new Vector(0), solved);

    rhs = new Vector(1);
    solved = unit.solve(rhs);
    Assert.assertEquals(new Vector(1), solved);
}
/** * Find the X for which this*X=d * * Taken from {@linkplain http * ://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm}, but modified to * use recursion instead of iteration, and thereby to handle degenerate * cases cleanly. * * @param d * The result (<code>d</code>) in the above equation * @return The <code>X</code> in the above equation */ public Vector solve (Vector d) { if (d.size() != _n) throw new IllegalArgumentException("Attempt to find tri-diagonal solution with improper-sized vector"); double[] x = new double[_n]; solve(d, x, 0, (_n > 0 ? _b[0] : 0), (_n > 1 ? _c[0] : 0), (_n > 0 ? d.coord(0) : 0)); //solve3(d, x, 1, _b[0], _c[0], 0, d.coord(0)); return new Vector(x); }
double sln = solveDegenerate2D(a1, b1, d1, b0, c0, d0); x[currentColumn] = sln; x[currentColumn+1] = anythingIfNotNaN(sln); } else if (Math.abs(b0) >= EPSILON) { double sln = solveDegenerate2D(b0, c0, d0, b1, a1, d1); x[currentColumn] = sln; x[currentColumn+1] = anythingIfNotNaN(sln); } else if (Math.abs(c0) >= EPSILON) { double sln = solveDegenerate2D(c0, b0, d0, b1, a1, d1); x[currentColumn] = anythingIfNotNaN(sln); x[currentColumn+1] = sln; } else { double sln = solveDegenerate2D(b1, a1, d1, c0, b0, d0); x[currentColumn] = anythingIfNotNaN(sln); x[currentColumn+1] = sln;
TriDiagonalMatrix M = new TriDiagonalMatrix(matrixEntries); Vector Y = getCoordinateVector(Ys, i); Vector K = getCoordinateVector(Ks, i); Vector P1i = M.solve(Y);
private void solve (Vector d, double[] x, int currentColumn, double b0, double c0, double d0) { // Solve the case M x = d for x, where we (the tri-diagonal matrix) is m // // We do this recursively, solving one place at a time, pretending in // each case we are at the top left of the matrix. // // So, in each case, we have one of three cases, depending on whether // there are one, two, or three-or-more rows left to solve. int rowsLeft = _n-currentColumn; if (0 == rowsLeft) { return; // Nothing left to solve } else if (1 == rowsLeft) { solveSingleRow(d, x, currentColumn, b0, c0, d0); } else if (2 == rowsLeft) { solveDoubleRow(d, x, currentColumn, b0, c0, d0); } else { solveTripleRow(d, x, currentColumn, b0, c0, d0); } }
solve(d, x, currentColumn+1, b1, c1, d1); return; } else { x[currentColumn] = anythingIfNotNaN(x[currentColumn+1]); return; solve(d, x, currentColumn+1, c0, 0, d0); solve(d, x, currentColumn+1, b1, c1, d1); solve(d, x, currentColumn+1, (b0 * b1 - a1 * c0), (b0 * c1),
private double anythingIfNotNaN (double testValue) { return ifNotNaN(testValue, 1.0); } private double ifNotNaN (double testValue, double value) {
M = new TriDiagonalMatrix(0, 0, 0, 0); D = new Vector(0, 0); X = M.solve(D); Assert.assertEquals(2, X.size()); Assert.assertFalse(Double.isNaN(X.coord(0))); X = M.solve(D); Assert.assertEquals(vNaN, X); X = M.solve(D); Assert.assertEquals(vNaN, X); M = new TriDiagonalMatrix(0, 1, 0, 1); D = new Vector(1, 0); X = M.solve(D); Assert.assertEquals(vNaN, X); X = M.solve(D); Assert.assertEquals(vNaN, X); X = M.solve(D); Assert.assertEquals(2, X.size()); Assert.assertFalse(Double.isNaN(X.coord(0))); M = new TriDiagonalMatrix(1, 1, 0, 0); D = new Vector(4, 0); X = M.solve(D); Assert.assertEquals(2, X.size()); Assert.assertEquals(4.0, X.coord(0) + X.coord(1), EPSILON);