/**
 * Computes the element-wise average of the unary matrices stored under each name across a
 * list of maps.
 *
 * <p>For every name returned by {@code getUnaryMatrixNames(maps)}, the matrices of all maps
 * that contain that name are summed, and the sum is divided by the number of such maps. Maps
 * that lack a name do not contribute to that name's average.
 *
 * @param maps the per-model maps from matrix name to matrix
 * @return a map created with {@code Generics.newTreeMap()} from each name to its averaged matrix
 * @throws NullPointerException if some name returned by {@code getUnaryMatrixNames} is contained
 *     in none of the maps (there is then nothing to divide)
 */
public static Map<String, SimpleMatrix> averageUnaryMatrices(List<Map<String, SimpleMatrix>> maps) {
  Map<String, SimpleMatrix> result = Generics.newTreeMap();
  for (String key : getUnaryMatrixNames(maps)) {
    SimpleMatrix sum = null;
    int seen = 0;
    for (Map<String, SimpleMatrix> candidate : maps) {
      if (candidate.containsKey(key)) {
        SimpleMatrix m = candidate.get(key);
        seen++;
        // The first contribution seeds the running sum; later ones are added to it.
        sum = (sum == null) ? m : sum.plus(m);
      }
    }
    result.put(key, sum.divide(seen));
  }
  return result;
}
/**
 * Returns the mean static word embedding of the given tokens as a column vector.
 *
 * <p>Each token's embedding is looked up with {@code getStaticWordEmbedding(word.word())} and
 * summed into a zero column vector of size {@code staticWordEmbeddings.getEmbeddingSize()}. The
 * sum is divided by the number of tokens, or by 1 when the list is empty, so an empty list yields
 * the zero vector.
 *
 * @param words the tokens to average
 * @return the averaged embedding column vector
 */
private SimpleMatrix getAverageEmbedding(List<CoreLabel> words) {
  int dim = staticWordEmbeddings.getEmbeddingSize();
  SimpleMatrix sum = new SimpleMatrix(dim, 1);
  for (CoreLabel token : words) {
    SimpleMatrix vec = getStaticWordEmbedding(token.word());
    sum = sum.plus(vec);
  }
  // Avoid dividing by zero for an empty token list.
  int denominator = words.size() > 0 ? words.size() : 1;
  return sum.divide(denominator);
}
/**
 * Returns the average embedding of the words in {@code sequence} as a column vector.
 *
 * <p>Each word is lower-cased and looked up in {@code embeddings}. Words with no embedding
 * ({@code null} lookup) contribute nothing to the sum but still count toward the divisor, so
 * the result is the sum of the found vectors divided by the full sequence length.
 *
 * <p>An empty sequence returns the zero vector. The previous code divided by zero there and
 * returned NaNs.
 *
 * @return the averaged embedding column vector of size {@code embeddings.getEmbeddingSize()}
 */
private SimpleMatrix getAverageEmbeddings() {
  // Zero-initialized column vector, built the same way as in getAverageEmbedding.
  SimpleMatrix totalVector = new SimpleMatrix(embeddings.getEmbeddingSize(), 1);
  for (IndexedWord w : sequence) {
    SimpleMatrix vector = embeddings.get(w.word().toLowerCase());
    if (vector != null) {
      totalVector = totalVector.plus(vector);
    }
  }
  // Guard against an empty sequence: dividing by zero would fill the result with NaN.
  return totalVector.divide(Math.max(1, this.sequence.size()));
}
} // closes the enclosing type
// Finalizes each per-category "unknown" vector. When the category's match count is > 0, the
// vector is divided by that count (presumably turning an accumulated sum into a mean -- TODO
// confirm). Otherwise the category falls back to a copy of unknownWordVector. The match counts
// for the Chinese year/number/percent categories are logged before each check.
// NOTE(review): this excerpt starts and ends mid-statement and its braces do not balance. For
// example, the `if` guarding the caps division is not shown, and no `}` closes the number/caps
// `else` branches. Lines appear to be missing here; confirm against the full file.
unknownNumberVector = unknownNumberVector.divide(numberCount); } else { unknownNumberVector = new SimpleMatrix(unknownWordVector); unknownCapsVector = unknownCapsVector.divide(capsCount); } else { unknownCapsVector = new SimpleMatrix(unknownWordVector); log.info("Matched " + chineseYearCount + " chinese year vectors"); if (chineseYearCount > 0) { unknownChineseYearVector = unknownChineseYearVector.divide(chineseYearCount); } else { unknownChineseYearVector = new SimpleMatrix(unknownWordVector); log.info("Matched " + chineseNumberCount + " chinese number vectors"); if (chineseNumberCount > 0) { unknownChineseNumberVector = unknownChineseNumberVector.divide(chineseNumberCount); } else { unknownChineseNumberVector = new SimpleMatrix(unknownWordVector); log.info("Matched " + chinesePercentCount + " chinese percent vectors"); if (chinesePercentCount > 0) { unknownChinesePercentVector = unknownChinesePercentVector.divide(chinesePercentCount); } else { unknownChinesePercentVector = new SimpleMatrix(unknownWordVector);
/**
 * Computes the element-wise average of the binary matrices stored under each (left, right)
 * key pair across a list of two-dimensional maps.
 *
 * <p>For every pair returned by {@code getBinaryMatrixNames(maps)}, the matrices of all maps that
 * contain that pair are summed, and the sum is divided by the number of such maps. Maps lacking
 * the pair do not contribute to its average.
 *
 * @param maps the per-model maps from (left, right) names to matrix
 * @return a map created with {@code TwoDimensionalMap.treeMap()} from each pair to its averaged
 *     matrix
 * @throws NullPointerException if some pair returned by {@code getBinaryMatrixNames} is contained
 *     in none of the maps (there is then nothing to divide)
 */
public static TwoDimensionalMap<String, String, SimpleMatrix> averageBinaryMatrices(List<TwoDimensionalMap<String, String, SimpleMatrix>> maps) {
  TwoDimensionalMap<String, String, SimpleMatrix> result = TwoDimensionalMap.treeMap();
  for (Pair<String, String> key : getBinaryMatrixNames(maps)) {
    String left = key.first();
    String right = key.second();
    SimpleMatrix sum = null;
    int seen = 0;
    for (TwoDimensionalMap<String, String, SimpleMatrix> candidate : maps) {
      if (candidate.contains(left, right)) {
        SimpleMatrix m = candidate.get(left, right);
        seen++;
        // The first contribution seeds the running sum; later ones are added to it.
        sum = (sum == null) ? m : sum.plus(m);
      }
    }
    result.put(left, right, sum.divide(seen));
  }
  return result;
}
/**
 * Averages the static embeddings of a list of tokens.
 *
 * <p>Starts from a zero column vector whose height is
 * {@code staticWordEmbeddings.getEmbeddingSize()}. It adds
 * {@code getStaticWordEmbedding(label.word())} for each token, then divides by
 * {@code max(words.size(), 1)}. An empty input therefore returns the zero vector.
 *
 * @param words the tokens whose embeddings are averaged
 * @return the mean embedding as a column vector
 */
private SimpleMatrix getAverageEmbedding(List<CoreLabel> words) {
  SimpleMatrix accumulator = new SimpleMatrix(staticWordEmbeddings.getEmbeddingSize(), 1);
  for (CoreLabel label : words) {
    accumulator = accumulator.plus(getStaticWordEmbedding(label.word()));
  }
  final int divisor = Math.max(words.size(), 1);
  return accumulator.divide(divisor);
}
// Finalizes each per-category "unknown" vector. When the category's match count is > 0, the
// vector is divided by that count (presumably turning an accumulated sum into a mean -- TODO
// confirm). Otherwise the category falls back to a copy of unknownWordVector. The match counts
// for the Chinese year/number/percent categories are logged before each check.
// NOTE(review): this excerpt starts and ends mid-statement and its braces do not balance. For
// example, the `if` guarding the caps division is not shown, and no `}` closes the number/caps
// `else` branches. Lines appear to be missing here; confirm against the full file.
unknownNumberVector = unknownNumberVector.divide(numberCount); } else { unknownNumberVector = new SimpleMatrix(unknownWordVector); unknownCapsVector = unknownCapsVector.divide(capsCount); } else { unknownCapsVector = new SimpleMatrix(unknownWordVector); log.info("Matched " + chineseYearCount + " chinese year vectors"); if (chineseYearCount > 0) { unknownChineseYearVector = unknownChineseYearVector.divide(chineseYearCount); } else { unknownChineseYearVector = new SimpleMatrix(unknownWordVector); log.info("Matched " + chineseNumberCount + " chinese number vectors"); if (chineseNumberCount > 0) { unknownChineseNumberVector = unknownChineseNumberVector.divide(chineseNumberCount); } else { unknownChineseNumberVector = new SimpleMatrix(unknownWordVector); log.info("Matched " + chinesePercentCount + " chinese percent vectors"); if (chinesePercentCount > 0) { unknownChinesePercentVector = unknownChinesePercentVector.divide(chinesePercentCount); } else { unknownChinesePercentVector = new SimpleMatrix(unknownWordVector);
// Finalizes each per-category "unknown" vector. When the category's match count is > 0, the
// vector is divided by that count (presumably turning an accumulated sum into a mean -- TODO
// confirm). Otherwise the category falls back to a copy of unknownWordVector. The match counts
// for the Chinese year/number/percent categories are printed to stderr before each check.
// NOTE(review): this copy reports counts with System.err.println, while the other copies of
// this logic in this file use log.info. Consider switching to the logger for consistency.
// NOTE(review): this excerpt starts and ends mid-statement and its braces do not balance. For
// example, the `if` guarding the caps division is not shown, and no `}` closes the number/caps
// `else` branches. Lines appear to be missing here; confirm against the full file.
unknownNumberVector = unknownNumberVector.divide(numberCount); } else { unknownNumberVector = new SimpleMatrix(unknownWordVector); unknownCapsVector = unknownCapsVector.divide(capsCount); } else { unknownCapsVector = new SimpleMatrix(unknownWordVector); System.err.println("Matched " + chineseYearCount + " chinese year vectors"); if (chineseYearCount > 0) { unknownChineseYearVector = unknownChineseYearVector.divide(chineseYearCount); } else { unknownChineseYearVector = new SimpleMatrix(unknownWordVector); System.err.println("Matched " + chineseNumberCount + " chinese number vectors"); if (chineseNumberCount > 0) { unknownChineseNumberVector = unknownChineseNumberVector.divide(chineseNumberCount); } else { unknownChineseNumberVector = new SimpleMatrix(unknownWordVector); System.err.println("Matched " + chinesePercentCount + " chinese percent vectors"); if (chinesePercentCount > 0) { unknownChinesePercentVector = unknownChinesePercentVector.divide(chinesePercentCount); } else { unknownChinesePercentVector = new SimpleMatrix(unknownWordVector);
// Finalizes each per-category "unknown" vector. When the category's match count is > 0, the
// vector is divided by that count (presumably turning an accumulated sum into a mean -- TODO
// confirm). Otherwise the category falls back to a copy of unknownWordVector. The match counts
// for the Chinese year/number/percent categories are logged before each check.
// NOTE(review): this excerpt starts and ends mid-statement and its braces do not balance. For
// example, the `if` guarding the caps division is not shown, and no `}` closes the number/caps
// `else` branches. Lines appear to be missing here; confirm against the full file.
unknownNumberVector = unknownNumberVector.divide(numberCount); } else { unknownNumberVector = new SimpleMatrix(unknownWordVector); unknownCapsVector = unknownCapsVector.divide(capsCount); } else { unknownCapsVector = new SimpleMatrix(unknownWordVector); log.info("Matched " + chineseYearCount + " chinese year vectors"); if (chineseYearCount > 0) { unknownChineseYearVector = unknownChineseYearVector.divide(chineseYearCount); } else { unknownChineseYearVector = new SimpleMatrix(unknownWordVector); log.info("Matched " + chineseNumberCount + " chinese number vectors"); if (chineseNumberCount > 0) { unknownChineseNumberVector = unknownChineseNumberVector.divide(chineseNumberCount); } else { unknownChineseNumberVector = new SimpleMatrix(unknownWordVector); log.info("Matched " + chinesePercentCount + " chinese percent vectors"); if (chinesePercentCount > 0) { unknownChinesePercentVector = unknownChinesePercentVector.divide(chinesePercentCount); } else { unknownChinesePercentVector = new SimpleMatrix(unknownWordVector);
// NOTE(review): these updates look like stages of a classical fourth-order Runge-Kutta step.
// The pattern is half-step probes, then a 1/6-weighted final update. The remaining stage is
// not visible in this excerpt; confirm against the enclosing method.
// Half-step probes from the current state, scaled by tkArray.get(i) / 2.
pos2Array = posArray.plus(pos1dotArray.scale(tkArray.get(i)).divide(2));
vel2Array = velArray.plus(vel1dotArray.scale(tkArray.get(i)).divide(2));
pos3Array = posArray.plus(pos2dotArray.scale(tkArray.get(i)).divide(2));
vel3Array = velArray.plus(vel2dotArray.scale(tkArray.get(i)).divide(2));
// Final update: rescale subPosArray / subVelArray in place by tkArray.get(i) / 6, then add them
// to the state.
subPosArray = subPosArray.scale(tkArray.get(i)).divide(6);
posArray = posArray.plus(subPosArray);
subVelArray = subVelArray.scale(tkArray.get(i)).divide(6);
velArray = velArray.plus(subVelArray);