/**
 * Solves {@code doubMat * x = doubVec}, treating {@code doubMat} as a tridiagonal matrix.
 * <p>
 * Only the main diagonal and the first super- and sub-diagonals of {@code doubMat} are read.
 * Every other entry is ignored. The resulting banded system is passed to
 * {@code TridiagonalSolver.solvTriDag}.
 *
 * @param doubMat  coefficient matrix, assumed square and tridiagonal (not checked here)
 * @param doubVec  right-hand side vector
 * @return the solution vector
 */
@Override
protected double[] matrixEqnSolver(final double[][] doubMat, final double[] doubVec) {
  final int n = doubMat.length;
  final double[] upper = new double[n - 1];
  final double[] diag = new double[n];
  final double[] lower = new double[n - 1];
  for (int row = 0; row < n; row++) {
    diag[row] = doubMat[row][row];
    // The last row has no super-diagonal entry and there is no row below it.
    if (row < n - 1) {
      upper[row] = doubMat[row][row + 1];
      lower[row] = doubMat[row + 1][row];
    }
  }
  final TridiagonalMatrix banded = new TridiagonalMatrix(diag, upper, lower);
  return TridiagonalSolver.solvTriDag(banded, doubVec);
}
/**
 * Solves the system {@code A x = b} for the unknown vector {@code x}, where {@code A} is a
 * tridiagonal matrix and {@code b} is a known vector.
 * <p>
 * This overload works on {@link DoubleArray}. It unwraps {@code b} to a primitive array,
 * delegates to the {@code double[]} overload of this method, and wraps the result in a new
 * {@code DoubleArray}. Solving a tridiagonal system takes order n operations, where n is the
 * number of equations, as opposed to order n^3 for a general dense system.
 *
 * @param aM  the tridiagonal matrix A
 * @param b  the known right-hand-side vector; must have the same length as the rows/columns of A
 * @return the solution vector x, with the same length as {@code b}
 */
public static DoubleArray solvTriDag(TridiagonalMatrix aM, DoubleArray b) {
  final double[] rhs = b.toArray();
  final double[] solution = solvTriDag(aM, rhs);
  return DoubleArray.copyOf(solution);
}
/**
 * Solves a single tridiagonal coefficient matrix against several right-hand sides.
 * <p>
 * The main, super- and sub-diagonals are read from {@code doubMat1*}; all other entries are
 * ignored. Element 0 of the result solves the system for {@code doubVec}. Element {@code k + 1}
 * solves it for column {@code k} of {@code doubMat2}, as returned by {@code DoubleMatrix.column},
 * for every {@code k} in {@code [0, doubVec.length)}.
 *
 * @param doubMat1  coefficient matrix, assumed tridiagonal with dimension {@code doubVec.length}
 * @param doubVec  primary right-hand side; its length fixes the size of the system
 * @param doubMat2  matrix whose first {@code doubVec.length} columns are extra right-hand sides
 * @return {@code doubVec.length + 1} solution vectors, primary solution first
 */
@Override
protected DoubleArray[] combinedMatrixEqnSolver(
    final double[][] doubMat1,
    final double[] doubVec,
    final double[][] doubMat2) {
  final int n = doubVec.length;
  final DoubleArray[] solutions = new DoubleArray[n + 1];
  final DoubleMatrix rhsMatrix = DoubleMatrix.copyOf(doubMat2);

  final double[] upper = new double[n - 1];
  final double[] diag = new double[n];
  final double[] lower = new double[n - 1];
  for (int row = 0; row < n; row++) {
    diag[row] = doubMat1[row][row];
    // The last row has no super-diagonal entry and there is no row below it.
    if (row < n - 1) {
      upper[row] = doubMat1[row][row + 1];
      lower[row] = doubMat1[row + 1][row];
    }
  }
  final TridiagonalMatrix banded = new TridiagonalMatrix(diag, upper, lower);

  final double[] primary = TridiagonalSolver.solvTriDag(banded, doubVec);
  solutions[0] = DoubleArray.copyOf(primary);
  for (int col = 0; col < n; col++) {
    solutions[col + 1] = TridiagonalSolver.solvTriDag(banded, rhsMatrix.column(col));
  }
  return solutions;
}
/**
 * Round-trip check for the tridiagonal solver.
 * <p>
 * Builds a random tridiagonal matrix A and a random vector x, then forms y = A x. It solves
 * A x' = y and checks two things: that x' matches x, and that the residual A x' - y is
 * negligible.
 */
@Test
public void test() {
  final int n = 97;
  final double[] diag = new double[n];
  final double[] upper = new double[n - 1];
  final double[] lower = new double[n - 1];
  final double[] expected = new double[n];
  // The draw order (diag, expected, then upper/lower) is deliberate: changing it would
  // change the generated system.
  for (int i = 0; i < n; i++) {
    diag[i] = RANDOM.nextRandom();
    expected[i] = RANDOM.nextRandom();
    if (i < n - 1) {
      upper[i] = RANDOM.nextRandom();
      lower[i] = RANDOM.nextRandom();
    }
  }

  final TridiagonalMatrix matrix = new TridiagonalMatrix(diag, upper, lower);
  final DoubleArray rhs = (DoubleArray) MA.multiply(matrix, DoubleArray.copyOf(expected));
  final DoubleArray solved = solvTriDag(matrix, rhs);
  final double[] actual = solved.toArray();
  for (int i = 0; i < n; i++) {
    assertEquals(expected[i], actual[i], 1e-9);
  }

  // Residual A * x' - y of the computed solution.
  final DoubleArray residual = (DoubleArray) MA.subtract(MA.multiply(matrix, solved), rhs);
  assertEquals(0.0, MA.getNorm2(residual), 1e-14);
}