List of usage examples for com.google.common.collect.Table#cellSet()
/**
 * Returns the contents of this table as a set of {@link Cell} triplets, each exposing a row key,
 * a column key and a value ({@code getRowKey()}, {@code getColumnKey()}, {@code getValue()}).
 * NOTE(review): whether the returned set is a live view of the table is not stated here; confirm
 * against the Guava {@code Table} documentation before mutating the table while iterating.
 */
Set<Cell<R, C, V>> cellSet();
From source file:net.librec.recommender.cf.ranking.FISMrmseRecommender.java
@Override protected void trainModel() throws LibrecException { int sampleSize = rho * trainMatrixSize; int totalSize = numUsers * numItems; for (int iter = 1; iter <= numIterations; iter++) { loss = 0.0d;//from w w w . j a v a 2s.c o m // temporal data DenseMatrix PS = new DenseMatrix(numItems, numFactors); DenseMatrix QS = new DenseMatrix(numItems, numFactors); // new training data by sampling negative values Table<Integer, Integer, Double> R = trainMatrix.getDataTable(); // make a random sample of negative feedback (total - nnz) List<Integer> indices = null; try { indices = Randoms.randInts(sampleSize, 0, totalSize - trainMatrixSize); } catch (Exception e) { e.printStackTrace(); } int index = 0, count = 0; boolean isDone = false; for (int u = 0; u < numUsers; u++) { for (int j = 0; j < numItems; j++) { double ruj = trainMatrix.get(u, j); if (ruj != 0) continue; // rated items if (count++ == indices.get(index)) { R.put(u, j, 0.0); index++; if (index >= indices.size()) { isDone = true; break; } } } if (isDone) break; } // update throughout each user-item-rating (u, j, ruj) cell for (Cell<Integer, Integer, Double> cell : R.cellSet()) { int u = cell.getRowKey(); int j = cell.getColumnKey(); double ruj = cell.getValue(); // for efficiency, use the below code to predict ruj instead of // simply using "predict(u,j)" SparseVector Ru = trainMatrix.row(u); double bu = userBiases.get(u), bj = itemBiases.get(j); double sum_ij = 0; int cnt = 0; for (VectorEntry ve : Ru) { int i = ve.index(); // for training, i and j should be equal as j may be rated // or unrated if (i != j) { sum_ij += DenseMatrix.rowMult(P, i, Q, j); cnt++; } } double wu = cnt > 0 ? 
Math.pow(cnt, -alpha) : 0; double puj = bu + bj + wu * sum_ij; double euj = puj - ruj; loss += euj * euj; // update bu userBiases.add(u, -learnRate * (euj + regBias * bu)); // update bj itemBiases.add(j, -learnRate * (euj + regBias * bj)); loss += regBias * bu * bu + regBias * bj * bj; // update qjf for (int f = 0; f < numFactors; f++) { double qjf = Q.get(j, f); double sum_i = 0; for (VectorEntry ve : Ru) { int i = ve.index(); if (i != j) { sum_i += P.get(i, f); } } double delta = euj * wu * sum_i + regItem * qjf; QS.add(j, f, -learnRate * delta); loss += regItem * qjf * qjf; } // update pif for (VectorEntry ve : Ru) { int i = ve.index(); if (i != j) { for (int f = 0; f < numFactors; f++) { double pif = P.get(i, f); double delta = euj * wu * Q.get(j, f) + regItem * pif; PS.add(i, f, -learnRate * delta); loss += regItem * pif * pif; } } } } P = P.add(PS); Q = Q.add(QS); loss *= 0.5d; if (isConverged(iter) && earlyStop) { break; } updateLRate(iter); } }
From source file:hu.ppke.itk.nlpg.purepos.decoder.BeamedViterbi.java
public List<Pair<List<Integer>, Double>> beamedSearch(final NGram<Integer> start, final List<String> observations, int resultsNumber) { HashMap<NGram<Integer>, Node> beam = new HashMap<NGram<Integer>, Node>(); beam.put(start, startNode(start));//from ww w . j av a 2 s . c o m boolean isFirst = true; int pos = 0; for (String obs : observations) { // System.err.println(obs); // logger.trace("Current observation " + obs); // logger.trace("\tCurrent states:"); // for (Entry<NGram<Integer>, Node> entry : beam.entrySet()) { // logger.trace("\t\t" + entry.getKey() + " - " + entry.getValue()); // } HashMap<NGram<Integer>, Node> newBeam = new HashMap<NGram<Integer>, Node>(); Table<NGram<Integer>, Integer, Double> nextProbs = HashBasedTable.create(); Map<NGram<Integer>, Double> obsProbs = new HashMap<NGram<Integer>, Double>(); Set<NGram<Integer>> contexts = beam.keySet(); Map<NGram<Integer>, Map<Integer, Pair<Double, Double>>> nexts = getNextProbs(contexts, obs, pos, isFirst); for (Map.Entry<NGram<Integer>, Map<Integer, Pair<Double, Double>>> nextsEntry : nexts.entrySet()) { NGram<Integer> context = nextsEntry.getKey(); Map<Integer, Pair<Double, Double>> nextContextProbs = nextsEntry.getValue(); for (Map.Entry<Integer, Pair<Double, Double>> entry : nextContextProbs.entrySet()) { Integer tag = entry.getKey(); nextProbs.put(context, tag, entry.getValue().getLeft()); obsProbs.put(context.add(tag), entry.getValue().getRight()); } } // for (Integer t : nextProbs.keySet()) { // logger.trace("\t\tNext node:" + context + t); // logger.trace("\t\tnode currentprob:" // + (beam.get(context) + nextProbs.get(t).getLeft())); // logger.trace("\t\tnode emissionprob:" // + nextProbs.get(t).getRight()); // logger.trace("\n"); // // logger.trace("\t\tNext node:" + context + t); // } for (Cell<NGram<Integer>, Integer, Double> cell : nextProbs.cellSet()) { Integer nextTag = cell.getColumnKey(); NGram<Integer> context = cell.getRowKey(); Double transVal = cell.getValue(); NGram<Integer> newState 
= context.add(nextTag); Node from = beam.get(context); double newVal = transVal + beam.get(context).getWeight(); update(newBeam, newState, newVal, from); } // adding observation probabilities // logger.trace("beam" + newBeam); if (nextProbs.size() > 1) for (NGram<Integer> tagSeq : newBeam.keySet()) { // Integer tag = tagSeq.getLast(); Node node = newBeam.get(tagSeq); // Double prevVal = node.getWeight(); Double obsProb = obsProbs.get(tagSeq); // logger.trace("put to beam: " + context + "(from) " // + tagSeq + " " + prevVal + "+" + obsProb); node.setWeight(obsProb + node.getWeight()); } beam = prune(newBeam); isFirst = false; // for (Entry<NGram<Integer>, Node> e : beam.entrySet()) { // logger.trace("\t\tNode state: " + e.getKey() + " " // + e.getValue()); // } ++pos; } return findMax(beam, resultsNumber); }
From source file:i5.las2peer.services.recommender.librec.rating.TimeComNeighSVD.java
/** * @return a list of length numCBins containing the tagging data that falls into each cbin (bins are numbered from 0 to numCBins-1) *//*from w ww . j a v a 2 s . c o m*/ private List<Table<Integer, Integer, Set<Long>>> tagDataCBins(Table<Integer, Integer, Set<Long>> tagTable) { Logs.info("{}{} split tagging data into bins for dynamic community structure detection ...", algoName, foldInfo); List<Table<Integer, Integer, Set<Long>>> tagTableList = new ArrayList<Table<Integer, Integer, Set<Long>>>( numCBins); for (int cbin = 0; cbin < numCBins; cbin++) { tagTableList.add(cbin, HashBasedTable.create()); } for (Cell<Integer, Integer, Set<Long>> c : tagTable.cellSet()) { int useritem = c.getRowKey(); // may be user or item int tag = c.getColumnKey(); Set<Long> times = c.getValue(); for (long time : times) { int days = days(time, minTrainTimestamp); int cbin = cbin(days); // cbins are numbered 1..numCBins if (!tagTableList.get(cbin - 1).contains(useritem, tag)) { tagTableList.get(cbin - 1).put(useritem, tag, new HashSet<Long>()); } tagTableList.get(cbin - 1).get(useritem, tag).add(time); } } for (int cbin = 0; cbin < numCBins; cbin++) { Logs.info("{}{} tagging data cbin {} contains {} tagging instances", algoName, foldInfo, cbin, tagTableList.get(cbin).size()); } return tagTableList; }
From source file:carskit.alg.cars.transformation.prefiltering.SPF.java
protected SparseMatrix getUIMatrix(int ctx) { DenseVector vc_target = getContextVector(ctx); // Table {row-id, col-id, rate} Table<Integer, Integer, Double> dataTable_ui = HashBasedTable.create(); Table<Integer, Integer, Double> dataTable_ui_count = HashBasedTable.create(); // Map {col-id, multiple row-id}: used to fast build a rating matrix Multimap<Integer, Integer> colMap = HashMultimap.create(); // read data to have a list of rating profiles for each uc pair for (MatrixEntry me : trainMatrix) { int ui = me.row(); // user-item int u = rateDao.getUserIdFromUI(ui); int j = rateDao.getItemIdFromUI(ui); int c = me.column(); // context DenseVector vc_current = getContextVector(c); double sim = cosineSimilarity(vc_target, vc_current); if (sim >= th) { double rujc = me.get(); if (dataTable_ui.contains(u, j)) { dataTable_ui.put(u, j, dataTable_ui.get(u, j)); dataTable_ui_count.put(u, j, dataTable_ui_count.get(u, j) + 1.0); } else { dataTable_ui.put(u, j, rujc); dataTable_ui_count.put(u, j, 1.0); }//from www .ja va2 s . c om } // formulate sparse matrix in order to perform SVD for (Cell<Integer, Integer, Double> cell : dataTable_ui.cellSet()) { int uu = cell.getRowKey(); int jj = cell.getColumnKey(); dataTable_ui.put(uu, jj, cell.getValue() / dataTable_ui_count.get(uu, jj)); colMap.put(jj, uu); } } return new SparseMatrix(numUsers, numItems, dataTable_ui, colMap); }
From source file:carskit.alg.cars.transformation.prefiltering.SPF.java
protected SparseMatrix getCUMatrix() { // Table {row-id, col-id, rate} Table<Integer, Integer, Double> dataTable_cu = HashBasedTable.create(); Table<Integer, Integer, Double> dataTable_cu_count = HashBasedTable.create(); // Map {col-id, multiple row-id}: used to fast build a rating matrix Multimap<Integer, Integer> colMap = HashMultimap.create(); // read data to have a list of rating profiles for each uc pair for (MatrixEntry me : trainMatrix) { int ui = me.row(); // user-item int u = rateDao.getUserIdFromUI(ui); int j = rateDao.getItemIdFromUI(ui); int ctx = me.column(); // context Collection<Integer> cs = rateDao.getContextConditionsList().get(ctx); double rujc = me.get(); double bui = mean + bu.get(u) + bi.get(j); for (int c : cs) { if (dataTable_cu.contains(c, u)) { dataTable_cu.put(c, u, dataTable_cu.get(c, u) + rujc - bui); dataTable_cu_count.put(c, u, dataTable_cu_count.get(c, u) + 1.0); } else { dataTable_cu.put(c, u, rujc - bui); dataTable_cu_count.put(c, u, 1.0); }//w w w. j a va 2s. c o m } } // formulate sparse matrix in order to perform SVD for (Cell<Integer, Integer, Double> cell : dataTable_cu.cellSet()) { int c = cell.getRowKey(); int u = cell.getColumnKey(); dataTable_cu.put(c, u, cell.getValue() / (beta + dataTable_cu_count.get(c, u))); colMap.put(u, c); } //Logs.info("numConditions = " + numConditions+", datatable.row = "+dataTable_cu.rowKeySet().size()); return new SparseMatrix(numConditions, numUsers, dataTable_cu, colMap); }
From source file:carskit.alg.cars.transformation.prefiltering.SPF.java
protected SparseMatrix getCIMatrix() { // Table {row-id, col-id, rate} Table<Integer, Integer, Double> dataTable_ci = HashBasedTable.create(); Table<Integer, Integer, Double> dataTable_ci_count = HashBasedTable.create(); // Map {col-id, multiple row-id}: used to fast build a rating matrix Multimap<Integer, Integer> colMap = HashMultimap.create(); // read data to have a list of rating profiles for each uc pair for (MatrixEntry me : trainMatrix) { int ui = me.row(); // user-item int u = rateDao.getUserIdFromUI(ui); int j = rateDao.getItemIdFromUI(ui); int ctx = me.column(); // context Collection<Integer> cs = rateDao.getContextConditionsList().get(ctx); double rujc = me.get(); double bui = mean + bu.get(u) + bi.get(j); for (int c : cs) { if (dataTable_ci.contains(c, j)) { dataTable_ci.put(c, j, dataTable_ci.get(c, j) + rujc - bui); dataTable_ci_count.put(c, j, dataTable_ci_count.get(c, j) + 1.0); } else { dataTable_ci.put(c, j, rujc - bui); dataTable_ci_count.put(c, j, 1.0); }// ww w. ja v a 2 s . c o m } } // formulate sparse matrix in order to perform SVD for (Cell<Integer, Integer, Double> cell : dataTable_ci.cellSet()) { int c = cell.getRowKey(); int j = cell.getColumnKey(); dataTable_ci.put(c, j, cell.getValue() / (beta + dataTable_ci_count.get(c, j))); colMap.put(j, c); } return new SparseMatrix(numConditions, numItems, dataTable_ci, colMap); }
From source file:net.librec.math.structure.SparseStringMatrix.java
/** * Construct a sparse matrix//w ww .j ava2s.co m * * @param dataTable data table * @param columnStructure column structure */ private void construct(Table<Integer, Integer, ? extends String> dataTable, Multimap<Integer, Integer> columnStructure) { int nnz = dataTable.size(); // CRS rowPtr = new int[numRows + 1]; colInd = new int[nnz]; rowData = new String[nnz]; int j = 0; for (int i = 1; i <= numRows; ++i) { Set<Integer> cols = dataTable.row(i - 1).keySet(); rowPtr[i] = rowPtr[i - 1] + cols.size(); for (int col : cols) { colInd[j++] = col; if (col < 0 || col >= numColumns) throw new IllegalArgumentException( "colInd[" + j + "]=" + col + ", which is not a valid column index"); } Arrays.sort(colInd, rowPtr[i - 1], rowPtr[i]); } // CCS colPtr = new int[numColumns + 1]; rowInd = new int[nnz]; colData = new String[nnz]; j = 0; for (int i = 1; i <= numColumns; ++i) { // dataTable.col(i-1) is more time-consuming than columnStructure.get(i-1) Collection<Integer> rows = columnStructure != null ? columnStructure.get(i - 1) : dataTable.column(i - 1).keySet(); colPtr[i] = colPtr[i - 1] + rows.size(); for (int row : rows) { rowInd[j++] = row; if (row < 0 || row >= numRows) throw new IllegalArgumentException( "rowInd[" + j + "]=" + row + ", which is not a valid row index"); } Arrays.sort(rowInd, colPtr[i - 1], colPtr[i]); } // set data for (Cell<Integer, Integer, ? extends String> en : dataTable.cellSet()) { int row = en.getRowKey(); int col = en.getColumnKey(); String val = en.getValue().toString(); set(row, col, val); } }
From source file:com.numb3r3.common.data.SparseMatrix.java
/** * Construct a sparse matrix//from w w w . j ava2s .c o m * * @param dataTable data table * @param columnStructure column structure */ private void construct(Table<Integer, Integer, Double> dataTable, Multimap<Integer, Integer> columnStructure) { int nnz = dataTable.size(); // CRS rowPtr = new int[numRows + 1]; colInd = new int[nnz]; rowData = new double[nnz]; int j = 0; for (int i = 1; i <= numRows; ++i) { Set<Integer> cols = dataTable.row(i - 1).keySet(); rowPtr[i] = rowPtr[i - 1] + cols.size(); for (int col : cols) { colInd[j++] = col; if (col < 0 || col >= numColumns) throw new IllegalArgumentException( "colInd[" + j + "]=" + col + ", which is not a valid column index"); } Arrays.sort(colInd, rowPtr[i - 1], rowPtr[i]); } // CCS if (columnStructure != null) { colPtr = new int[numColumns + 1]; rowInd = new int[nnz]; colData = new double[nnz]; j = 0; for (int i = 1; i <= numColumns; ++i) { // dataTable.col(i-1) is very time-consuming Collection<Integer> rows = columnStructure.get(i - 1); colPtr[i] = colPtr[i - 1] + rows.size(); for (int row : rows) { rowInd[j++] = row; if (row < 0 || row >= numRows) throw new IllegalArgumentException( "rowInd[" + j + "]=" + row + ", which is not a valid row index"); } Arrays.sort(rowInd, colPtr[i - 1], colPtr[i]); } } // set data for (Cell<Integer, Integer, Double> en : dataTable.cellSet()) { int row = en.getRowKey(); int col = en.getColumnKey(); double val = en.getValue(); set(row, col, val); } }
From source file:matrix.SparseMatrix.java
/** * Construct a sparse matrix//from w w w. java2 s. com * * @param dataTable * data table * @param columnStructure * column structure */ private void construct(Table<Integer, Integer, Float> dataTable, Multimap<Integer, Integer> columnStructure) { int nnz = dataTable.size(); // CRS rowPtr = new int[numRows + 1]; colInd = new int[nnz]; rowData = new float[nnz]; int j = 0; for (int i = 1; i <= numRows; ++i) { Set<Integer> cols = dataTable.row(i - 1).keySet(); rowPtr[i] = rowPtr[i - 1] + cols.size(); for (int col : cols) { colInd[j++] = col; if (col < 0 || col >= numColumns) throw new IllegalArgumentException( "colInd[" + j + "]=" + col + ", which is not a valid column index"); } Arrays.sort(colInd, rowPtr[i - 1], rowPtr[i]); } // CCS if (columnStructure != null) { colPtr = new int[numColumns + 1]; rowInd = new int[nnz]; colData = new float[nnz]; j = 0; for (int i = 1; i <= numColumns; ++i) { // dataTable.col(i-1) is very time-consuming Collection<Integer> rows = columnStructure.get(i - 1); colPtr[i] = colPtr[i - 1] + rows.size(); for (int row : rows) { rowInd[j++] = row; if (row < 0 || row >= numRows) throw new IllegalArgumentException( "rowInd[" + j + "]=" + row + ", which is not a valid row index"); } Arrays.sort(rowInd, colPtr[i - 1], colPtr[i]); } } // set data for (Cell<Integer, Integer, Float> en : dataTable.cellSet()) { int row = en.getRowKey(); int col = en.getColumnKey(); float val = en.getValue(); set(row, col, val); } }
From source file:librec.data.SparseMatrix.java
/** * Construct a sparse matrix//w ww .java 2 s .c o m * * @param dataTable * data table * @param columnStructure * column structure */ private void construct(Table<Integer, Integer, ? extends Number> dataTable, Multimap<Integer, Integer> columnStructure) { int nnz = dataTable.size(); // CRS rowPtr = new int[numRows + 1]; colInd = new int[nnz]; rowData = new double[nnz]; int j = 0; for (int i = 1; i <= numRows; ++i) { Set<Integer> cols = dataTable.row(i - 1).keySet(); rowPtr[i] = rowPtr[i - 1] + cols.size(); for (int col : cols) { colInd[j++] = col; if (col < 0 || col >= numColumns) throw new IllegalArgumentException( "colInd[" + j + "]=" + col + ", which is not a valid column index"); } Arrays.sort(colInd, rowPtr[i - 1], rowPtr[i]); } // CCS colPtr = new int[numColumns + 1]; rowInd = new int[nnz]; colData = new double[nnz]; j = 0; for (int i = 1; i <= numColumns; ++i) { // dataTable.col(i-1) is more time-consuming than columnStructure.get(i-1) Collection<Integer> rows = columnStructure != null ? columnStructure.get(i - 1) : dataTable.column(i - 1).keySet(); colPtr[i] = colPtr[i - 1] + rows.size(); for (int row : rows) { rowInd[j++] = row; if (row < 0 || row >= numRows) throw new IllegalArgumentException( "rowInd[" + j + "]=" + row + ", which is not a valid row index"); } Arrays.sort(rowInd, colPtr[i - 1], colPtr[i]); } // set data for (Cell<Integer, Integer, ? extends Number> en : dataTable.cellSet()) { int row = en.getRowKey(); int col = en.getColumnKey(); double val = en.getValue().doubleValue(); set(row, col, val); } }