/**
 * Builds a new GMM restricted to the coefficients selected by a mask.
 *
 * @param mask flags indexed by coefficient; entries set to {@code true}
 *             are kept. It is read at every index below {@code ncoefs}.
 * @return a new GMM with as many gaussians as this one, the same weights,
 *         and only the selected means and variances
 */
public GMMDiag getMarginal(boolean[] mask) {
    // Count the selected dimensions to size the result.
    int kept = 0;
    for (int k = 0; k < mask.length; k++) {
        if (mask[k]) {
            kept++;
        }
    }

    GMMDiag marginal = new GMMDiag(getNgauss(), kept);

    // Copy the selected columns, packing them contiguously in the result.
    int target = 0;
    for (int src = 0; src < ncoefs; src++) {
        if (!mask[src]) {
            continue;
        }
        for (int gauss = 0; gauss < ngauss; gauss++) {
            marginal.setMean(gauss, target, getMean(gauss, src));
            marginal.setVar(gauss, target, getVar(gauss, src));
        }
        target++;
    }

    for (int gauss = 0; gauss < ngauss; gauss++) {
        marginal.setWeight(gauss, getWeight(gauss));
    }

    marginal.precomputeDistance();
    return marginal;
}
fin.close(); fin = new BufferedReader(new FileReader(nom)); allocate(); nT = 0; for (int i = 0; i < ngauss; i++) { nT += weights[i]; for (int j = 0; j < ncoefs; j++) { setMean(i, j, Float.parseFloat(ss[j + 1])); setVar(i, j, Float.parseFloat(ss[j])); setWeight(i, weights[i] / nT); precomputeDistance(); } catch (IOException e) { e.printStackTrace();
/**
 * Creates the parameter arrays if they do not exist yet.
 * Weights are delegated to {@code allocateWeights()} when {@code weights}
 * is null. The remaining arrays are created together, sized from
 * {@code ngauss} and {@code ncoefs}, only when {@code means} is null.
 */
private void allocate() {
    if (weights == null) {
        allocateWeights();
    }
    if (means != null) {
        // Already allocated: keep the existing arrays untouched.
        return;
    }
    loglikes = new float[ngauss];
    means = new float[ngauss][ncoefs];
    covar = new float[ngauss][ncoefs];
    logPreComputedGaussianFactor = new float[ngauss];
}
/**
 * Extracts a single gaussian of this GMM as a one-component GMM.
 *
 * @param i index of the gaussian to extract
 * @return a new GMM holding a copy of that gaussian's means and variances,
 *         with its single weight set to 1
 */
public GMMDiag getGauss(int i) {
    final int dim = getNcoefs();
    final GMMDiag single = new GMMDiag(1, dim);

    System.arraycopy(means[i], 0, single.means[0], 0, dim);
    System.arraycopy(covar[i], 0, single.covar[0], 0, dim);

    single.setWeight(0, 1);
    single.precomputeDistance();
    return single;
}
/**
 * Compares this GMM with another one, parameter by parameter.
 * Two GMMs are considered equal when they have the same number of
 * gaussians and of coefficients, and no weight, mean or variance is
 * reported as different by {@code isDiff} (documented here as a tolerance
 * of 1%; the helper itself decides).
 *
 * @param g second GMM to compare to
 * @return {@code true} if the GMMs are equal in the sense above
 */
public boolean isEqual(GMMDiag g) {
    if (getNgauss() != g.getNgauss()) {
        return false;
    }
    // Compare coefficient counts with each other. The previous check
    // compared this GMM's gaussian count with g's coefficient count.
    if (getNcoefs() != g.getNcoefs()) {
        return false;
    }
    for (int i = 0; i < getNgauss(); i++) {
        if (isDiff(getWeight(i), g.getWeight(i))) {
            return false;
        }
        for (int j = 0; j < getNcoefs(); j++) {
            if (isDiff(getMean(i, j), g.getMean(i, j))) {
                return false;
            }
            if (isDiff(getVar(i, j), g.getVar(i, j))) {
                return false;
            }
        }
    }
    return true;
}
/**
 * Concatenates the gaussians of this GMM and of another one into a new GMM.
 * Weights coming from this GMM are scaled by {@code w1}; weights coming
 * from {@code g} are scaled by {@code 1 - w1}.
 *
 * @param g  second GMM for the merge
 * @param w1 weight of the first GMM (this one) for the merge
 * @return a new GMM containing the gaussians of both inputs
 */
public GMMDiag merge(GMMDiag g, float w1) {
    final int dim = getNcoefs();
    final GMMDiag merged = new GMMDiag(getNgauss() + g.getNgauss(), dim);

    // First part: gaussians of this GMM, at the same indices.
    for (int k = 0; k < getNgauss(); k++) {
        System.arraycopy(means[k], 0, merged.means[k], 0, dim);
        System.arraycopy(covar[k], 0, merged.covar[k], 0, dim);
        merged.setWeight(k, getWeight(k) * w1);
    }

    // Second part: gaussians of g, appended after this GMM's gaussians.
    final float w2 = 1f - w1;
    for (int k = 0; k < g.getNgauss(); k++) {
        final int dst = ngauss + k;
        System.arraycopy(g.means[k], 0, merged.means[dst], 0, dim);
        System.arraycopy(g.covar[k], 0, merged.covar[dst], 0, dim);
        merged.setWeight(dst, g.getWeight(k) * w2);
    }

    merged.precomputeDistance();
    return merged;
}
/**
 * Opens a file and writes the header lines of an HTK-style model file.
 * The writer is left open so the caller can append to it and close it.
 *
 * @param nomFich  path of the file to create
 * @param parmKind parameter-kind text inserted in the vector-size line
 * @return the open writer, or {@code null} if the file could not be
 *         opened (the stack trace is printed in that case)
 */
public PrintWriter saveHTKheader(String nomFich, String parmKind) {
    final PrintWriter out;
    try {
        // Opening the file is the only step that can raise IOException.
        out = new PrintWriter(new FileWriter(nomFich));
    } catch (IOException e) {
        e.printStackTrace();
        return null;
    }

    out.println("~o");
    out.println("<HMMSETID> tree");
    out.println("<STREAMINFO> 1 " + getNcoefs());
    out.println("<VECSIZE> " + getNcoefs() + "<NULLD>" + parmKind + "<DIAGC>");
    out.println("~r \"rtree_1\"");
    out.println("<REGTREE> 1");
    out.println("<TNODE> 1 " + getNgauss());
    return out;
}
/**
 * Saves the GMM in a proprietary text format.
 * Layout: a first line with {@code ngauss} and {@code ncoefs}; then for
 * each gaussian a {@code gauss <index> <weight>} line, a line of means and
 * a line of variances; finally a line holding {@code nT}.
 * If the file cannot be opened, the stack trace is printed and nothing is
 * written.
 *
 * @param name name of file to save
 */
public void save(String name) {
    // try-with-resources closes the writer even when an unchecked
    // exception interrupts the writing loop.
    try (PrintWriter fout = new PrintWriter(new FileWriter(name))) {
        fout.println(ngauss + " " + ncoefs);
        for (int i = 0; i < ngauss; i++) {
            fout.println("gauss " + i + ' ' + getWeight(i));
            for (int j = 0; j < ncoefs; j++) {
                fout.print(means[i][j] + " ");
            }
            fout.println();
            for (int j = 0; j < ncoefs; j++) {
                fout.print(getVar(i, j) + " ");
            }
            fout.println();
        }
        fout.println(nT);
    } catch (IOException e) {
        e.printStackTrace();
    }
}
/**
 * Precomputes the per-gaussian constant stored in
 * {@code logPreComputedGaussianFactor}.
 * For each gaussian the value is half of: the sum over coefficients of the
 * log-domain variance, plus {@code ncoefs} times the log-domain value of
 * 2*pi. Log conversion is done by {@code logMath.linearToLog}.
 */
public void precomputeDistance() {
    for (int gauss = 0; gauss < ngauss; gauss++) {
        // Accumulate the log of every variance of this gaussian.
        float logSum = 0.0f;
        for (int dim = 0; dim < ncoefs; dim++) {
            logSum += logMath.linearToLog(getVar(gauss, dim));
        }
        // Add the (2*pi) term once per dimension.
        logSum += logMath.linearToLog(2.0 * Math.PI) * ncoefs;

        logPreComputedGaussianFactor[gauss] = logSum * 0.5f;
    }
}
/**
 * Creates a GMM with the given dimensions and allocates its storage.
 *
 * @param ng number of gaussians
 * @param nc number of coefficients per gaussian
 */
public GMMDiag(int ng, int nc) {
    this.ngauss = ng;
    this.ncoefs = nc;
    allocate();
}
/**
 * Returns the number of gaussians of the first GMM in {@code hmmsHTK.gmms}.
 *
 * @return gaussian count of the GMM at index 0
 */
int getGMMSize() {
    return hmmsHTK.gmms.get(0).getNgauss();
}
/**
 * Returns the number of coefficients of the first GMM in
 * {@code hmmsHTK.gmms}.
 *
 * @return coefficient count of the GMM at index 0
 */
int getNcoefs() {
    return hmmsHTK.gmms.get(0).getNcoefs();
}
/**
 * Returns the log-likelihood reported by the wrapped {@code gmm}.
 *
 * @return the value of {@code gmm.getLogLike()}
 */
public float getLogLike() {
    final float logLike = gmm.getLogLike();
    return logLike;
}
public GaussianWeights htkWeights(String path, float floor) { int numStates = getNumStates(); int numStreams = 1; int numGaussiansPerState = getGMMSize(); GaussianWeights mixtureWeights = new GaussianWeights(path, numStates, numGaussiansPerState, numStreams); for (int i = 0; i < numStates; i++) { GMMDiag gmm = hmmsHTK.gmms.get(i); float[] logWeights = new float[numGaussiansPerState]; for (int j = 0; j < numGaussiansPerState; j++) { logWeights[j] = gmm.getWeight(j); } Utilities.floorData(logWeights, mixtureWeightFloor); logMath.linearToLog(logWeights); // the order of the weights is the order in the HMMSet.gmms // vector which is the order of appearance in the MMF file mixtureWeights.put(i, 0, logWeights); } return mixtureWeights; }
public Pool<float[]> htkMeans(String path) { Pool<float[]> pool = new Pool<float[]>(path); // Suppose this is the number of shared states int numStates = getNumStates(); int numStreams = 1; int numGaussiansPerState = getGMMSize(); pool.setFeature(NUM_SENONES, numStates); pool.setFeature(NUM_STREAMS, numStreams); pool.setFeature(NUM_GAUSSIANS_PER_STATE, numGaussiansPerState); int ncoefs = getNcoefs(); for (int i = 0; i < numStates; i++) { GMMDiag gmm = hmmsHTK.gmms.get(i); for (int j = 0; j < numGaussiansPerState; j++) { float[] density = new float[ncoefs]; for (int k = 0; k < ncoefs; k++) { density[k] = gmm.getMean(j, k); } int id = i * numGaussiansPerState + j; // the order of the means is the order in the HMMSet.gmms // vector which is the order of appearance in the MMF file pool.put(id, density); } } return pool; }
/**
 * Writes this GMM as an HTK-style state description.
 * Emits the mixture count, then for each gaussian a 1-based
 * {@code <MIXTURE>} line with its weight, an {@code <RCLASS>} line, and the
 * mean and variance vectors, each preceded by its size.
 *
 * @param fout writer to append to; it is neither flushed nor closed here
 */
public void saveHTKState(PrintWriter fout) {
    fout.println("<NUMMIXES> " + getNgauss());

    for (int mix = 0; mix < getNgauss(); mix++) {
        // HTK numbers mixtures from 1; storage is 0-based.
        final int htkIndex = mix + 1;
        fout.println("<MIXTURE> " + htkIndex + ' ' + getWeight(mix));
        fout.println("<RCLASS> 1");

        fout.println("<MEAN> " + getNcoefs());
        for (int dim = 0; dim < getNcoefs(); dim++) {
            fout.print(getMean(mix, dim) + " ");
        }
        fout.println();

        fout.println("<VARIANCE> " + getNcoefs());
        for (int dim = 0; dim < getNcoefs(); dim++) {
            fout.print(getVar(mix, dim) + " ");
        }
        fout.println();
    }
}
/**
 * Merges this GMM with another one by stacking their gaussians.
 * The result first holds this GMM's gaussians with weights multiplied by
 * {@code w1}, then those of {@code g} with weights multiplied by
 * {@code 1 - w1}.
 *
 * @param g  second GMM for the merge
 * @param w1 weight of the first GMM (this one) for the merge
 * @return a new GMM containing the gaussians of both inputs
 */
public GMMDiag merge(GMMDiag g, float w1) {
    final int width = getNcoefs();
    final GMMDiag combined = new GMMDiag(getNgauss() + g.getNgauss(), width);

    // Gaussians of this GMM keep their positions.
    for (int own = 0; own < getNgauss(); own++) {
        System.arraycopy(means[own], 0, combined.means[own], 0, width);
        System.arraycopy(covar[own], 0, combined.covar[own], 0, width);
        combined.setWeight(own, getWeight(own) * w1);
    }

    // Gaussians of g follow, shifted by this GMM's gaussian count.
    final float otherScale = 1f - w1;
    for (int other = 0; other < g.getNgauss(); other++) {
        final int slot = ngauss + other;
        System.arraycopy(g.means[other], 0, combined.means[slot], 0, width);
        System.arraycopy(g.covar[other], 0, combined.covar[slot], 0, width);
        combined.setWeight(slot, g.getWeight(other) * otherScale);
    }

    combined.precomputeDistance();
    return combined;
}
/**
 * Returns one gaussian of this GMM wrapped in a single-component GMM.
 *
 * @param i index of the gaussian to extract
 * @return a new one-gaussian GMM with copied means and variances and a
 *         weight of 1
 */
public GMMDiag getGauss(int i) {
    final int width = getNcoefs();
    final GMMDiag extracted = new GMMDiag(1, width);

    // Copy the parameters so the result does not share rows with this GMM.
    System.arraycopy(means[i], 0, extracted.means[0], 0, width);
    System.arraycopy(covar[i], 0, extracted.covar[0], 0, width);

    extracted.setWeight(0, 1);
    extracted.precomputeDistance();
    return extracted;
}