用高斯消元法求解線性方程組
阿新 • • 發佈:2019-02-15
con sta eat sed 高斯 sso 一個 ids row
線性方程組問題可以利用矩陣變換求解。利用高斯消元法,先將方程組的增廣矩陣轉換成行階梯矩陣,再化為簡化行階梯矩陣,此時矩陣的最後一列就是方程組的解。參考資料(高斯消元法)
Java代碼
/**
 * Solves square systems of linear equations with Gauss-Jordan elimination.
 *
 * <p>{@link LinearEquationGroup} holds the augmented matrix and performs the elimination;
 * {@link LinearEquationGroupBuilder} parses textual equations such as {@code "2a+b-c" = "8"}
 * into such a matrix.
 */
public class FunctionResolver {

    /**
     * A system of n linear equations in n unknowns, stored as an n x (n+1) augmented matrix.
     *
     * <p>The equations are expected in normalized form: every term with an unknown is on the
     * left of "=", and the merged constant is on the right. For example
     * <pre>
     *    2a + b -  c =   8
     *   -3a - b + 2c = -11
     *   -2a + b + 2c =  -3
     * </pre>
     * is represented by the matrix
     * <pre>
     *    2  1 -1   8
     *   -3 -1  2 -11
     *   -2  1  2  -3
     * </pre>
     *
     * <p>Quotients are rounded to {@code MathContext.DECIMAL128} (34 significant digits), so
     * results are exact only when every intermediate value fits in that precision. A singular
     * system whose intermediate values needed rounding may therefore not be detected reliably.
     */
    public static class LinearEquationGroup {

        /** Precision used for every division, product and sum during elimination. */
        private static final java.math.MathContext PRECISION = java.math.MathContext.DECIMAL128;

        private static final String BAD_SHAPE = "輸入矩陣有誤! 必須為n*(n+1)矩陣";
        private static final String NOT_SOLVABLE = "方程組無解或者無唯一解!";

        private BigDecimal[][] matrix;

        /**
         * Names of the unknowns, in matrix column order; for the example in the class
         * comment this is {@code a, b, c}.
         */
        private List<String> variantes;

        public LinearEquationGroup() {
        }

        /**
         * @param matrix the n x (n+1) augmented matrix; it is used directly (not copied) and is
         *               modified in place by {@link #solve()}
         * @throws IllegalArgumentException if the matrix is null, empty or not n x (n+1)
         */
        public LinearEquationGroup(BigDecimal[][] matrix) {
            this.matrix = matrix;
            check();
        }

        /** Verifies that the matrix is a non-empty n x (n+1) matrix. */
        private void check() {
            if (matrix == null || matrix.length == 0) {
                throw new IllegalArgumentException(BAD_SHAPE);
            }
            for (BigDecimal[] row : matrix) {
                if (row == null || row.length != matrix.length + 1) {
                    throw new IllegalArgumentException(BAD_SHAPE);
                }
            }
        }

        /**
         * Solves the system in place and returns the value of every unknown.
         *
         * @return a map from variable name to its value
         * @throws IllegalArgumentException if the matrix has the wrong shape, or the system has
         *                                  no solution or no unique solution
         * @throws IllegalStateException    if the variable names are missing or their count
         *                                  does not match the number of equations
         */
        public Map<String, BigDecimal> solve() {
            check();
            if (variantes == null || variantes.size() != matrix.length) {
                throw new IllegalStateException(
                        "variable names must be set and match the number of equations ("
                                + matrix.length + ")");
            }
            sortRows();
            eliminateVarianteDownwards();
            normalize();
            eliminateVarianteUpwards();

            final int constantColumn = matrix.length;
            final Map<String, BigDecimal> solution = new HashMap<>();
            int row = 0;
            for (String name : variantes) {
                solution.put(name, matrix[row++][constantColumn]);
            }
            return solution;
        }

        /**
         * Best-effort reordering that moves a row with a non-zero entry onto each zero diagonal
         * position. It never rejects a system: a diagonal entry that is zero now may become
         * non-zero once earlier columns are eliminated, so the final decision is left to
         * {@link #eliminateVarianteDownwards()} and {@link #normalize()}.
         */
        void sortRows() {
            for (int row = 0; row < matrix.length - 1; row++) {
                if (matrix[row][row].signum() != 0) {
                    continue;
                }
                for (int below = row + 1; below < matrix.length; below++) {
                    if (matrix[below][row].signum() != 0) {
                        swapRows(row, below);
                        break;
                    }
                }
            }
        }

        /**
         * Forward elimination: afterwards every entry below the diagonal is zero. Uses partial
         * pivoting, i.e. for every column the remaining row with the largest absolute entry is
         * moved onto the diagonal first.
         *
         * @throws IllegalArgumentException if a column has no usable pivot
         */
        void eliminateVarianteDownwards() {
            final int rowCount = matrix.length;
            for (int pivotRow = 0; pivotRow < rowCount - 1; pivotRow++) {
                selectPivot(pivotRow);
                final BigDecimal pivot = matrix[pivotRow][pivotRow];
                for (int targetRow = pivotRow + 1; targetRow < rowCount; targetRow++) {
                    final BigDecimal entry = matrix[targetRow][pivotRow];
                    if (entry.signum() == 0) {
                        continue;
                    }
                    addScaledRow(targetRow, pivotRow, entry.divide(pivot, PRECISION).negate());
                    // Mathematically zero; store it exactly so rounding noise cannot leak in.
                    matrix[targetRow][pivotRow] = BigDecimal.ZERO;
                }
            }
        }

        /**
         * Scales every row so that its diagonal entry becomes 1; once the upward elimination
         * has run, the last column then holds the solution.
         *
         * @throws IllegalArgumentException if a diagonal entry is zero
         */
        void normalize() {
            for (int i = 0; i < matrix.length; i++) {
                final BigDecimal diagonal = matrix[i][i];
                if (diagonal.signum() == 0) {
                    throw new IllegalArgumentException(NOT_SOLVABLE);
                }
                if (diagonal.compareTo(BigDecimal.ONE) == 0) {
                    continue;
                }
                for (int j = matrix[i].length - 1; j > i; j--) {
                    matrix[i][j] = matrix[i][j].divide(diagonal, PRECISION);
                }
                matrix[i][i] = BigDecimal.ONE;
            }
        }

        /**
         * Backward elimination: afterwards every entry above the diagonal is zero.
         *
         * @throws IllegalArgumentException if a diagonal entry is zero
         */
        void eliminateVarianteUpwards() {
            for (int pivotRow = matrix.length - 1; pivotRow > 0; pivotRow--) {
                final BigDecimal pivot = matrix[pivotRow][pivotRow];
                if (pivot.signum() == 0) {
                    throw new IllegalArgumentException(NOT_SOLVABLE);
                }
                for (int targetRow = pivotRow - 1; targetRow >= 0; targetRow--) {
                    final BigDecimal entry = matrix[targetRow][pivotRow];
                    if (entry.signum() == 0) {
                        continue;
                    }
                    addScaledRow(targetRow, pivotRow, entry.divide(pivot, PRECISION).negate());
                    matrix[targetRow][pivotRow] = BigDecimal.ZERO;
                }
            }
        }

        /** Moves the row (at or below {@code row}) with the largest entry in column {@code row} onto the diagonal. */
        private void selectPivot(int row) {
            int best = row;
            BigDecimal bestMagnitude = matrix[row][row].abs();
            for (int candidate = row + 1; candidate < matrix.length; candidate++) {
                final BigDecimal magnitude = matrix[candidate][row].abs();
                if (magnitude.compareTo(bestMagnitude) > 0) {
                    best = candidate;
                    bestMagnitude = magnitude;
                }
            }
            if (bestMagnitude.signum() == 0) {
                throw new IllegalArgumentException(NOT_SOLVABLE);
            }
            if (best != row) {
                swapRows(row, best);
            }
        }

        /** Exchanges two rows of the matrix. */
        private void swapRows(int first, int second) {
            final BigDecimal[] held = matrix[first];
            matrix[first] = matrix[second];
            matrix[second] = held;
        }

        /** Adds {@code factor * matrix[sourceRow]} to {@code matrix[targetRow]}, column by column. */
        private void addScaledRow(int targetRow, int sourceRow, BigDecimal factor) {
            final BigDecimal[] target = matrix[targetRow];
            final BigDecimal[] source = matrix[sourceRow];
            for (int j = 0; j < target.length; j++) {
                target[j] = target[j].add(factor.multiply(source[j], PRECISION), PRECISION);
            }
        }

        /** @return the internal matrix itself (not a copy) */
        public BigDecimal[][] getMatrix() {
            return matrix;
        }

        public List<String> getVariantes() {
            return variantes;
        }

        public void setVariantes(List<String> variantes) {
            this.variantes = variantes;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            LinearEquationGroup that = (LinearEquationGroup) o;
            return Arrays.deepEquals(matrix, that.matrix);
        }

        @Override
        public int hashCode() {
            return Arrays.deepHashCode(matrix);
        }
    }

    /**
     * Parses a system of equations given as a {@code Map<String, String>}: every entry is one
     * equation, the key being the part left of "=" and the value the part right of it.
     * Unknowns are written as letter strings, optionally preceded by a numeric coefficient
     * (for example {@code 2a}, {@code 1.5b}, {@code -c}).
     */
    public static class LinearEquationGroupBuilder {
        /** Terms of every left-hand side, keyed by the original expression text. */
        private Map<String, List<String>> keyTokens = new HashMap<>();
        /** Terms of every right-hand side, keyed by the original expression text. */
        private Map<String, List<String>> valueTokens = new HashMap<>();

        /**
         * Builds the augmented matrix and the variable list for the given equations.
         *
         * @param funcMap left-hand side to right-hand side of every equation
         * @return the equation group; its shape is only validated when it is solved
         * @throws IllegalArgumentException if a term cannot be parsed
         */
        public LinearEquationGroup createFrom(Map<String, String> funcMap) {
            // Drop tokens cached by an earlier call so the builder can be reused.
            keyTokens.clear();
            valueTokens.clear();
            funcMap.forEach((left, right) -> {
                keyTokens.put(left, tokenize(left));
                valueTokens.put(right, tokenize(right));
            });

            final Map<String, Integer> vars = findVariantes(funcMap);
            // Place every name at its column index; the map's iteration order is unrelated to it.
            final String[] ordered = new String[vars.size()];
            vars.forEach((name, index) -> ordered[index] = name);

            final LinearEquationGroup linearEquationGroup = new LinearEquationGroup();
            linearEquationGroup.setVariantes(new ArrayList<>(Arrays.asList(ordered)));

            final BigDecimal[][] matrix = new BigDecimal[funcMap.size()][vars.size() + 1];
            for (BigDecimal[] row : matrix) {
                Arrays.fill(row, BigDecimal.ZERO);
            }
            linearEquationGroup.matrix = matrix;
            buildMatrix(matrix, funcMap, vars);
            return linearEquationGroup;
        }

        /** Assigns a column index to every variable, in order of first appearance. */
        private Map<String, Integer> findVariantes(Map<String, String> fucMap) {
            final Map<String, Integer> indexByName = new HashMap<>();
            final int[] nextIndex = {0};
            fucMap.forEach((left, right) -> {
                parseVariantes(indexByName, nextIndex, tokensOf(keyTokens, left));
                parseVariantes(indexByName, nextIndex, tokensOf(valueTokens, right));
            });
            return indexByName;
        }

        /** Registers every not-yet-known variable found in {@code tokens}. */
        private void parseVariantes(Map<String, Integer> map, int[] keyIdx, List<String> tokens) {
            for (String token : tokens) {
                final String name = (String) splitFactorAndVar(token)[1];
                if (name != null && !map.containsKey(name)) {
                    map.put(name, keyIdx[0]++);
                }
            }
        }

        /** Returns the cached terms of {@code expression}, tokenizing it on first use. */
        private List<String> tokensOf(Map<String, List<String>> cache, String expression) {
            return cache.computeIfAbsent(expression, this::tokenize);
        }

        /**
         * Splits an expression into its terms at every '+' and '-'. A '-' stays attached to the
         * term that follows it, a '+' is dropped; blank terms are skipped.
         */
        private List<String> tokenize(String k) {
            final List<String> terms = new ArrayList<>();
            int start = 0;
            for (int i = 0; i < k.length(); i++) {
                final char c = k.charAt(i);
                if (c != '-' && c != '+') {
                    continue;
                }
                addIfNotBlank(terms, k.substring(start, i));
                start = (c == '-') ? i : i + 1;
            }
            addIfNotBlank(terms, k.substring(start));
            return terms;
        }

        private static void addIfNotBlank(List<String> terms, String raw) {
            final String trimmed = raw.trim();
            if (!trimmed.isEmpty()) {
                terms.add(trimmed);
            }
        }

        /**
         * Splits one term into its coefficient and variable name.
         *
         * @param part a number with a variable, a bare number or a bare variable, for example
         *             {@code 2a}, {@code 1.5b}, {@code 7}, {@code c}, {@code -3d}, {@code -e}
         * @return a two-element array: element 0 is the coefficient as a {@code Double}
         *         (1.0 or -1.0 when the term has no explicit number), element 1 is the variable
         *         name, or {@code null} for a bare number
         * @throws IllegalArgumentException if the term is empty or consists of a sign only;
         *                                  a {@link NumberFormatException} if the numeric part
         *                                  is not a valid number (e.g. a variable name that
         *                                  ends in a digit)
         */
        Object[] splitFactorAndVar(String part) {
            String term = part.trim();
            final boolean negative = term.startsWith("-");
            if (negative) {
                term = term.substring(1).trim();
            }
            if (term.isEmpty()) {
                throw new IllegalArgumentException("malformed term: \"" + part + "\"");
            }

            // Everything up to the last digit is the coefficient, the rest is the variable.
            int lastDigit = term.length() - 1;
            while (lastDigit >= 0 && !Character.isDigit(term.charAt(lastDigit))) {
                lastDigit--;
            }

            double factor = 1.0;
            String name = null;
            if (lastDigit < 0) {
                name = term;
            } else {
                factor = Double.parseDouble(term.substring(0, lastDigit + 1));
                if (lastDigit + 1 < term.length()) {
                    name = term.substring(lastDigit + 1).trim();
                }
            }
            return new Object[]{negative ? -factor : factor, name};
        }

        /** Fills one matrix row per equation, in the iteration order of {@code funcMap}. */
        private void buildMatrix(BigDecimal[][] matric, Map<String, String> funcMap, Map<String, Integer> varIds) {
            int row = 0;
            for (Map.Entry<String, String> equation : funcMap.entrySet()) {
                processOneFunction(matric[row++], equation.getKey(), equation.getValue(), varIds);
            }
        }

        /**
         * Accumulates one equation into {@code matricRow}.
         *
         * @param matricRow the row to fill; its last cell is the constant
         * @param key       the left-hand side of the equation
         * @param value     the right-hand side of the equation
         * @param varIds    column index of every variable
         */
        void processOneFunction(BigDecimal[] matricRow, String key, String value, Map<String, Integer> varIds) {
            processOneside(matricRow, key, varIds, true);
            processOneside(matricRow, value, varIds, false);
        }

        /**
         * Adds the terms of one side of an equation to the row. Variable terms are moved to the
         * left of "=" and constants to the right, so right-hand variable terms and left-hand
         * constants change sign.
         */
        private void processOneside(BigDecimal[] matricRow, String key, Map<String, Integer> varIds, boolean isKey) {
            final BigDecimal sideSign = isKey ? BigDecimal.ONE : BigDecimal.ONE.negate();
            final List<String> parts = tokensOf(isKey ? keyTokens : valueTokens, key);

            for (String part : parts) {
                final Object[] parsed = splitFactorAndVar(part);
                final Double factor = (Double) parsed[0];
                final String name = (String) parsed[1];
                if (factor == null && name == null) {
                    throw new IllegalArgumentException("系數和數字不能同時為null: " + key);
                }
                // valueOf goes through the decimal text of the double, so 0.1 stays 0.1.
                final BigDecimal signed =
                        BigDecimal.valueOf(factor == null ? 1.0 : factor).multiply(sideSign);

                if (name != null) {
                    final Integer idx = varIds.get(name);
                    if (idx == null) {
                        throw new IllegalArgumentException("can‘t found " + name + " in varIds");
                    }
                    matricRow[idx] = matricRow[idx] == null ? signed : matricRow[idx].add(signed);
                } else {
                    final int last = matricRow.length - 1;
                    final BigDecimal constant = signed.negate();
                    matricRow[last] = matricRow[last] == null ? constant : matricRow[last].add(constant);
                }
            }
        }
    }

    /** Solves the three-equation example from the {@link LinearEquationGroup} documentation. */
    public static void main(String[] args) {
        final FunctionResolver.LinearEquationGroupBuilder linearEquationGroupBuilder =
                new FunctionResolver.LinearEquationGroupBuilder();
        Map<String, String> fMap = new HashMap<>();
        fMap.put("2a+b-c", "8");
        fMap.put("-3a-b+2c", "-11");
        fMap.put("-2a+b+2c", "-3");
        FunctionResolver.LinearEquationGroup equationGroup = linearEquationGroupBuilder.createFrom(fMap);
        final Map<String, BigDecimal> solve = equationGroup.solve();
        solve.forEach((k, v) -> System.out.println(k + "=" + v));
    }
}
復雜度分析
該算法的時間復雜度為O(n^3),空間復雜度為O(n^2)。對於維度不高的線性方程組還是可以接受。
用高斯消元法求解線性方程組