Java tutorial: MalletCalculator (LDA-based tag recommendation with MALLET)
/*
 * TagRecommender: A framework to implement and evaluate algorithms
 * for the recommendation of tags.
 * Copyright (C) 2013 Dominik Kowald
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package processing;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.TimeUnit;

import com.google.common.base.Stopwatch;
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Ints;

import common.DoubleMapComparator;
import common.UserData;
import common.Utilities;

import cc.mallet.pipe.StringList2FeatureSequence;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Alphabet;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;

import file.PredictionFileWriter;
import file.BookmarkReader;
import file.BookmarkSplitter;

/**
 * Trains an LDA topic model (MALLET's ParallelTopicModel) on the tag assignments of users
 * and/or resources and turns the resulting topic and term distributions into ranked tag
 * recommendations.
 */
public class MalletCalculator {

    private final static int MAX_RECOMMENDATIONS = 10;
    private final static int MAX_TERMS = 100;
    private final static int NUM_THREADS = 10;
    private final static int NUM_ITERATIONS = 2000;
    private final static double ALPHA = 0.01;
    private final static double BETA = 0.01;
    private final static double TOPIC_THRESHOLD = 0.001;

    private int numTopics;
    private List<Map<Integer, Integer>> maps;
    private InstanceList instances;
    private List<Map<Integer, Double>> docList;
    private List<Map<Integer, Double>> topicList;

    public MalletCalculator(List<Map<Integer, Integer>> maps, int numTopics) {
        this.numTopics = numTopics;
        this.maps = maps;
        initializeDataStructures();
    }

    /** Converts the tag-count maps into MALLET instances (one document per user or resource). */
    private void initializeDataStructures() {
        this.instances = new InstanceList(new StringList2FeatureSequence());
        for (Map<Integer, Integer> map : this.maps) {
            List<String> tags = new ArrayList<String>();
            for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
                // repeat each tag ID as often as it was assigned
                for (int i = 0; i < entry.getValue(); i++) {
                    tags.add(entry.getKey().toString());
                }
            }
            Instance inst = new Instance(tags, null, null, null);
            inst.setData(tags);
            this.instances.addThruPipe(inst);
        }
    }

    /** Returns the topic probabilities of every document, sorted by probability. */
    private List<Map<Integer, Double>> getMaxTopicsByDocs(ParallelTopicModel LDA, int maxTopicsPerDoc) {
        List<Map<Integer, Double>> docList = new ArrayList<Map<Integer, Double>>();
        int numDocs = this.instances.size();
        for (int doc = 0; doc < numDocs; ++doc) {
            Map<Integer, Double> topicList = new LinkedHashMap<Integer, Double>();
            double[] topicProbs = LDA.getTopicProbabilities(doc);
            //double probSum = 0.0;
            for (int topic = 0; topic < topicProbs.length && topic < maxTopicsPerDoc; topic++) {
                //if (topicProbs[topic] > 0.01) { // TODO
                topicList.put(topic, topicProbs[topic]);
                //probSum += topicProbs[topic];
                //}
            }
            //System.out.println("Topic Sum: " + probSum);
            Map<Integer, Double> sortedTopicList = new TreeMap<Integer, Double>(new DoubleMapComparator(topicList));
            sortedTopicList.putAll(topicList);
            docList.add(sortedTopicList);
        }
        return docList;
    }

    /** Returns, for every topic, its term (tag) weights normalized to relative values. */
    private List<Map<Integer, Double>> getMaxTermsByTopics(ParallelTopicModel LDA, int limit) {
        Alphabet alphabet = LDA.getAlphabet();
        List<Map<Integer, Double>> topicList = new ArrayList<Map<Integer, Double>>();
        int numTopics = LDA.getNumTopics();
        List<TreeSet<IDSorter>> sortedWords = LDA.getSortedWords();
        for (int topic = 0; topic < numTopics; ++topic) {
            Map<Integer, Double> termList = new LinkedHashMap<Integer, Double>();
            TreeSet<IDSorter> topicWords = sortedWords.get(topic);
            //int i = 0;
            double weightSum = 0.0;
            for (IDSorter entry : topicWords) {
                if (entry.getWeight() > 0.0) { // TODO
                    //if (i++ < limit) { // TODO
                    int tag = Integer.parseInt(alphabet.lookupObject(entry.getID()).toString());
                    termList.put(tag, entry.getWeight());
                    weightSum += entry.getWeight();
                    //} else {
                    //    break;
                    //}
                }
            }
            // relative values
            //double relSum = 0.0;
            for (Map.Entry<Integer, Double> entry : termList.entrySet()) {
                //relSum += (entry.getValue() / weightSum);
                entry.setValue(entry.getValue() / weightSum);
            }
            //System.out.println("RelSum: " + relSum);
            topicList.add(termList);
        }
        return topicList;
    }

    /** Trains the LDA model and caches the document-topic and topic-term distributions. */
    public void predictValuesProbs() {
        ParallelTopicModel LDA = new ParallelTopicModel(this.numTopics, ALPHA * this.numTopics, BETA); // TODO
        LDA.addInstances(this.instances);
        LDA.setNumThreads(1);
        LDA.setNumIterations(NUM_ITERATIONS);
        LDA.setRandomSeed(43);
        try {
            LDA.estimate();
        } catch (Exception e) {
            e.printStackTrace();
        }
        this.docList = getMaxTopicsByDocs(LDA, this.numTopics);
        System.out.println("Fetched Doc-List");
        this.topicList = getMaxTermsByTopics(LDA, MAX_TERMS);
        System.out.println("Fetched Topic-List");
    }

    /** Combines the topic probabilities of one document with the term probabilities of its topics. */
    public Map<Integer, Double> getValueProbsForID(int id, boolean topicCreation) {
        Map<Integer, Double> terms = null;
        if (id < this.docList.size()) {
            terms = new LinkedHashMap<Integer, Double>();
            Map<Integer, Double> docVals = this.docList.get(id);
            for (Map.Entry<Integer, Double> topic : docVals.entrySet()) { // look at each assigned topic
                Set<Entry<Integer, Double>> entrySet = this.topicList.get(topic.getKey()).entrySet();
                double topicProb = topic.getValue();
                for (Map.Entry<Integer, Double> entry : entrySet) { // and its terms
                    if (topicCreation) {
                        if (topicProb > TOPIC_THRESHOLD) {
                            terms.put(entry.getKey(), topicProb);
                            break; // only use first tag as topic-name with the topic probability
                        }
                    } else {
                        double wordProb = entry.getValue();
                        Double val = terms.get(entry.getKey());
                        terms.put(entry.getKey(), val == null ? wordProb * topicProb : val + wordProb * topicProb);
                    }
                }
            }
        }
        return terms;
    }

    // Statics -------------------------------------------------------------------------------------------------------------------------

    private static List<Double> getDenoms(List<Map<Integer, Double>> maps) {
        List<Double> denoms = new ArrayList<Double>();
        for (Map<Integer, Double> map : maps) {
            double denom = 0.0;
            for (Map.Entry<Integer, Double> entry : map.entrySet()) {
                denom += Math.exp(entry.getValue());
            }
            denoms.add(denom);
        }
        return denoms;
    }

    /** Merges user- and resource-based tag scores and keeps the MAX_RECOMMENDATIONS highest ones. */
    private static Map<Integer, Double> getRankedTagList(BookmarkReader reader, Map<Integer, Double> userMap, double userDenomVal,
            Map<Integer, Double> resMap, double resDenomVal, boolean sorting, boolean smoothing, boolean topicCreation) {
        Map<Integer, Double> resultMap = new LinkedHashMap<Integer, Double>();
        if (userMap != null) {
            for (Map.Entry<Integer, Double> entry : userMap.entrySet()) {
                resultMap.put(entry.getKey(), entry.getValue().doubleValue());
            }
        }
        if (resMap != null) {
            for (Map.Entry<Integer, Double> entry : resMap.entrySet()) {
                double resVal = entry.getValue().doubleValue();
                Double val = resultMap.get(entry.getKey());
                resultMap.put(entry.getKey(), val == null ? resVal : val.doubleValue() + resVal);
            }
        }
        if (sorting) {
            Map<Integer, Double> sortedResultMap = new TreeMap<Integer, Double>(new DoubleMapComparator(resultMap));
            sortedResultMap.putAll(resultMap);
            Map<Integer, Double> returnMap = new LinkedHashMap<Integer, Double>(MAX_RECOMMENDATIONS);
            int i = 0;
            for (Map.Entry<Integer, Double> entry : sortedResultMap.entrySet()) {
                if (i++ < MAX_RECOMMENDATIONS) {
                    returnMap.put(entry.getKey(), entry.getValue());
                } else {
                    break;
                }
            }
            return returnMap;
        }
        return resultMap;
        /*
        double size = (double) reader.getTagAssignmentsCount();
        double tagSize = (double) reader.getTags().size();
        Map<Integer, Double> resultMap = new LinkedHashMap<Integer, Double>();
        for (int i = 0; i < tagSize; i++) {
            double pt = (double) reader.getTagCounts().get(i) / size;
            Double userVal = 0.0;
            if (userMap != null && userMap.containsKey(i)) {
                userVal = userMap.get(i); //Math.exp(userMap.get(i)) / userDenomVal;
            }
            Double resVal = 0.0;
            if (resMap != null && resMap.containsKey(i)) {
                resVal = resMap.get(i); //Math.exp(resMap.get(i)) / resDenomVal;
            }
            if (userVal > 0.0 || resVal > 0.0) { // TODO
                if (smoothing) {
                    resultMap.put(i, Utilities.getSmoothedTagValue(userVal, userDenomVal, resVal, resDenomVal, pt));
                } else {
                    resultMap.put(i, userVal + resVal);
                }
            }
        }
        if (sorting) {
            Map<Integer, Double> sortedResultMap = new TreeMap<Integer, Double>(new DoubleMapComparator(resultMap));
            sortedResultMap.putAll(resultMap);
            if (topicCreation) {
                return sortedResultMap;
            } else { // otherwise filter
                Map<Integer, Double> returnMap = new LinkedHashMap<Integer, Double>(MAX_RECOMMENDATIONS);
                int i = 0;
                for (Map.Entry<Integer, Double> entry : sortedResultMap.entrySet()) {
                    if (i++ < MAX_RECOMMENDATIONS) {
                        returnMap.put(entry.getKey(), entry.getValue());
                    } else {
                        break;
                    }
                }
                return returnMap;
            }
        }
        return resultMap;
        */
    }

    private static String timeString;

    /** Trains the user- and/or resource-based LDA models and calculates ranked tag lists for the test set. */
    public static List<Map<Integer, Double>> startLdaCreation(BookmarkReader reader, int sampleSize, boolean sorting, int numTopics,
            boolean userBased, boolean resBased, boolean topicCreation, boolean smoothing) {
        timeString = "";
        int size = reader.getUserLines().size();
        int trainSize = size - sampleSize;
        Stopwatch timer = new Stopwatch();
        timer.start();
        MalletCalculator userCalc = null;
        List<Map<Integer, Integer>> userMaps = null;
        //List<Double> userDenoms = null;
        if (userBased) {
            userMaps = Utilities.getUserMaps(reader.getUserLines().subList(0, trainSize));
            userCalc = new MalletCalculator(userMaps, numTopics);
            userCalc.predictValuesProbs();
            //userDenoms = getDenoms(userPredictionValues);
            System.out.println("User-Training finished");
        }
        MalletCalculator resCalc = null;
        List<Map<Integer, Integer>> resMaps = null;
        //List<Double> resDenoms = null;
        if (resBased) {
            resMaps = Utilities.getResMaps(reader.getUserLines().subList(0, trainSize));
            resCalc = new MalletCalculator(resMaps, numTopics);
            resCalc.predictValuesProbs();
            //resDenoms = getDenoms(resPredictionValues);
            System.out.println("Res-Training finished");
        }
        List<Map<Integer, Double>> results = new ArrayList<Map<Integer, Double>>();
        if (trainSize == size) {
            trainSize = 0;
        }
        timer.stop();
        long trainingTime = timer.elapsed(TimeUnit.MILLISECONDS);
        timer = new Stopwatch();
        timer.start();
        for (int i = trainSize; i < size; i++) { // the test set
            UserData data = reader.getUserLines().get(i);
            int userID = data.getUserID();
            int resID = data.getWikiID();
            //Map<Integer, Integer> userMap = null;
            //if (userBased && userMaps != null && userID < userMaps.size()) {
            //    userMap = userMaps.get(userID);
            //}
            //Map<Integer, Integer> resMap = null;
            //if (resBased && resMaps != null && resID < resMaps.size()) {
            //    resMap = resMaps.get(resID);
            //}
            double userTagCount = 0.0; //Utilities.getMapCount(userMap);
            double resTagCount = 0.0; //Utilities.getMapCount(resMap);
            /*
            double userDenomVal = 0.0;
            if (userDenoms != null && userID < userDenoms.size()) {
                userDenomVal = userDenoms.get(userID);
            }
            double resDenomVal = 0.0;
            if (resDenoms != null && resID < resDenoms.size()) {
                resDenomVal = resDenoms.get(resID);
            }
            */
            Map<Integer, Double> userPredMap = null;
            if (userCalc != null) {
                userPredMap = userCalc.getValueProbsForID(userID, topicCreation);
            }
            Map<Integer, Double> resPredMap = null;
            if (resCalc != null) {
                resPredMap = resCalc.getValueProbsForID(resID, topicCreation);
            }
            Map<Integer, Double> map = getRankedTagList(reader, userPredMap, userTagCount, resPredMap, resTagCount, sorting, smoothing, topicCreation);
            results.add(map);
        }
        timer.stop();
        long testTime = timer.elapsed(TimeUnit.MILLISECONDS);
        timeString += ("Full training time: " + trainingTime + "\n");
        timeString += ("Full test time: " + testTime + "\n");
        timeString += ("Average test time: " + testTime / (double) sampleSize) + "\n";
        timeString += ("Total time: " + (trainingTime + testTime) + "\n");
        return results;
    }

    /** Runs the LDA recommender on a train/test split and writes the predictions and timing metrics to disk. */
    public static void predictSample(String filename, int trainSize, int sampleSize, int numTopics, boolean userBased, boolean resBased) {
        BookmarkReader reader = new BookmarkReader(trainSize, false);
        reader.readFile(filename);
        List<Map<Integer, Double>> ldaValues = startLdaCreation(reader, sampleSize, true, numTopics, userBased, resBased, false, true);
        List<int[]> predictionValues = new ArrayList<int[]>();
        for (int i = 0; i < ldaValues.size(); i++) {
            Map<Integer, Double> ldaVal = ldaValues.get(i);
            predictionValues.add(Ints.toArray(ldaVal.keySet()));
        }
        reader.setUserLines(reader.getUserLines().subList(trainSize, reader.getUserLines().size()));
        PredictionFileWriter writer = new PredictionFileWriter(reader, predictionValues);
        writer.writeFile(filename + "_lda_" + numTopics);
        Utilities.writeStringToFile("./data/metrics/" + filename + "_lda_" + numTopics + "_TIME.txt", timeString);
    }

    /** Creates train/test/full sample files annotated with the LDA topic predictions. */
    public static void createSample(String filename, int sampleSize, short numTopics, boolean userBased, boolean resBased) {
        String outputFile = filename + "_lda_" + numTopics;
        filename += "_res";
        outputFile += "_res";
        BookmarkReader reader = new BookmarkReader(0, false);
        reader.readFile(filename);
        int trainSize = reader.getUserLines().size() - sampleSize;
        List<Map<Integer, Double>> ldaValues = startLdaCreation(reader, 0, true, numTopics, userBased, resBased, true, true);
        List<int[]> predictionValues = new ArrayList<int[]>(); // TODO: make argument for the probValues
        List<double[]> probValues = new ArrayList<double[]>();
        for (int i = 0; i < ldaValues.size(); i++) {
            Map<Integer, Double> ldaVal = ldaValues.get(i);
            predictionValues.add(Ints.toArray(ldaVal.keySet()));
            probValues.add(Doubles.toArray(ldaVal.values()));
            /*
            int[] values = new int[MAX_RECOMMENDATIONS];
            int j = 0;
            for (Integer val : ldaVal.keySet()) {
                if (j < MAX_RECOMMENDATIONS) {
                    values[j++] = val;
                } else {
                    break;
                }
            }
            predictionValues.add(values);
            */
        }
        List<UserData> trainUserSample = reader.getUserLines().subList(0, trainSize);
        List<UserData> testUserSample = reader.getUserLines().subList(trainSize, trainSize + sampleSize);
        List<UserData> userSample = reader.getUserLines().subList(0, trainSize + sampleSize);
        BookmarkSplitter.writeWikiSample(reader, trainUserSample, outputFile + "_train", predictionValues);
        BookmarkSplitter.writeWikiSample(reader, testUserSample, outputFile + "_test", predictionValues);
        BookmarkSplitter.writeWikiSample(reader, userSample, outputFile, predictionValues);
    }
}
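For readers who want to try the listing, the snippet below shows one way to drive its public entry point. It is a minimal sketch and not part of the original TagRecommender code: the dataset name, split sizes and topic count are placeholder assumptions, and the expected input format and folder layout are whatever BookmarkReader and the rest of the framework define.

package processing;

// Minimal driver sketch (not part of the original project). The dataset name,
// split sizes and topic count below are placeholder assumptions; the input
// format is whatever BookmarkReader expects.
public class MalletCalculatorExample {

    public static void main(String[] args) {
        String dataset = "my_dataset"; // hypothetical input file read by BookmarkReader
        int trainSize = 8000;          // hypothetical number of training lines
        int sampleSize = 2000;         // hypothetical number of test lines
        int numTopics = 100;           // number of LDA topics

        // Trains user- and resource-based LDA models on the first trainSize lines,
        // ranks tags for the remaining sampleSize lines, and writes the top-10
        // predictions plus timing metrics to files named after the dataset.
        MalletCalculator.predictSample(dataset, trainSize, sampleSize, numTopics, true, true);
    }
}

Because predictValuesProbs() fixes the random seed (43), runs 2000 sampling iterations, and sets alpha to 0.01 * numTopics and beta to 0.01, repeated runs on the same split should produce the same recommendations.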