/** * Solves the system Ax = y for the unknown vector x, where A is a tridiagonal matrix and y is a vector. * This takes order n operations where n is the size of the system * (number of linear equations), as opposed to order n^3 for the general problem. * @param aM tridiagonal matrix * @param b known vector (must be same length as rows/columns of matrix) * @return vector (as an array of doubles) with same length as y */ public static double[] solvTriDag(TridiagonalMatrix aM, double[] b) { ArgChecker.notNull(aM, "null matrix"); ArgChecker.notNull(b, "null vector"); double[] d = aM.getDiagonal(); //b is modified, so get copy of diagonal int n = d.length; ArgChecker.isTrue(n == b.length, "vector y wrong length for matrix"); double[] y = Arrays.copyOf(b, n); double[] l = aM.getLowerSubDiagonalData(); double[] u = aM.getUpperSubDiagonalData(); double[] x = new double[n]; for (int i = 1; i < n; i++) { double m = l[i - 1] / d[i - 1]; d[i] = d[i] - m * u[i - 1]; y[i] = y[i] - m * y[i - 1]; } x[n - 1] = y[n - 1] / d[n - 1]; for (int i = n - 2; i >= 0; i--) { x[i] = (y[i] - u[i] * x[i + 1]) / d[i]; } return x; }
/**
 * Computes the matrix-vector product {@code M x} for a tridiagonal matrix {@code M}.
 * <p>
 * Entry i of the result is {@code lower[i-1] * x[i-1] + diag[i] * x[i] + upper[i] * x[i+1]}.
 * Terms whose indices fall outside the matrix are omitted, so matrices of size 0 and 1
 * (which have empty sub-diagonals) are handled as well.
 *
 * @param matrix  the tridiagonal matrix M
 * @param vector  the vector x; its length must equal the length of the diagonal of M
 * @return the product M x, with the same length as x
 */
private DoubleArray multiply(TridiagonalMatrix matrix, DoubleArray vector) {
  double[] a = matrix.getLowerSubDiagonalData();
  double[] b = matrix.getDiagonalData();
  double[] c = matrix.getUpperSubDiagonalData();
  double[] x = vector.toArrayUnsafe();
  int n = x.length;
  ArgChecker.isTrue(b.length == n, "Matrix/vector size mismatch");
  double[] res = new double[n];
  for (int i = 0; i < n; i++) {
    // summation order (diag, lower, upper) matches the original expressions exactly for n >= 2
    double sum = b[i] * x[i];
    if (i > 0) {
      sum += a[i - 1] * x[i - 1];
    }
    if (i < n - 1) {
      sum += c[i] * x[i + 1];
    }
    res[i] = sum;
  }
  return DoubleArray.ofUnsafe(res);
}
/**
 * Computes the vector-matrix product {@code x^T M} for a tridiagonal matrix {@code M}.
 * <p>
 * Entry j of the result is {@code upper[j-1] * x[j-1] + diag[j] * x[j] + lower[j] * x[j+1]},
 * i.e. column j of M dotted with x. Terms whose indices fall outside the matrix are omitted,
 * so matrices of size 0 and 1 (which have empty sub-diagonals) are handled as well.
 *
 * @param vector  the vector x; its length must equal the length of the diagonal of M
 * @param matrix  the tridiagonal matrix M
 * @return the product x^T M, with the same length as x
 */
private DoubleArray multiply(DoubleArray vector, TridiagonalMatrix matrix) {
  double[] a = matrix.getLowerSubDiagonalData();
  double[] b = matrix.getDiagonalData();
  double[] c = matrix.getUpperSubDiagonalData();
  double[] x = vector.toArrayUnsafe();
  int n = x.length;
  ArgChecker.isTrue(b.length == n, "Matrix/vector size mismatch");
  double[] res = new double[n];
  for (int i = 0; i < n; i++) {
    // summation order (diag, lower-below, upper-above) matches the original expressions exactly for n >= 2
    double sum = b[i] * x[i];
    if (i < n - 1) {
      sum += a[i] * x[i + 1];
    }
    if (i > 0) {
      sum += c[i - 1] * x[i - 1];
    }
    res[i] = sum;
  }
  return DoubleArray.ofUnsafe(res);
}
// NOTE(review): fragment of a larger method whose signature is not visible here;
// x is presumably a TridiagonalMatrix - confirm against the enclosing method.
ArgChecker.notNull(x, "x");
// Local naming here: a = main diagonal, b = upper sub-diagonal, c = lower sub-diagonal
// (this differs from the multiply helpers in this file, where a is the lower and c the upper sub-diagonal).
// NOTE(review): the *Data getters may expose the matrix's backing arrays rather than copies -
// do not modify a, b or c without confirming.
double[] a = x.getDiagonalData();
double[] b = x.getUpperSubDiagonalData();
double[] c = x.getLowerSubDiagonalData();
// size of the system = length of the main diagonal
int n = a.length;
/**
 * Checks that the raw diagonal getters return the construction arrays, and that the dense
 * form places the diagonal on offset 0, the upper sub-diagonal on offset +1, the lower
 * sub-diagonal on offset -1 and zeros everywhere else.
 */
@Test
public void testGetters() {
  assertTrue(Arrays.equals(A, M.getDiagonalData()));
  assertTrue(Arrays.equals(B, M.getUpperSubDiagonalData()));
  assertTrue(Arrays.equals(C, M.getLowerSubDiagonalData()));
  final int size = A.length;
  final DoubleMatrix dense = M.toDoubleMatrix();
  for (int row = 0; row < size; row++) {
    for (int col = 0; col < size; col++) {
      // offset of the entry from the main diagonal decides which band it belongs to
      switch (col - row) {
        case 0:
          assertEquals(dense.get(row, col), A[row], 0);
          break;
        case 1:
          // entry (row, row + 1) of the upper band
          assertEquals(dense.get(row, col), B[row], 0);
          break;
        case -1:
          // entry (col + 1, col) of the lower band
          assertEquals(dense.get(row, col), C[col], 0);
          break;
        default:
          assertEquals(dense.get(row, col), 0, 0);
          break;
      }
    }
  }
}