Java tutorial: Usagi's Lucene-based search engine (UsagiSearchEngine.java)
/*******************************************************************************
 * Copyright 2014 Observational Health Data Sciences and Informatics
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/
package org.ohdsi.usagi;

import java.awt.BorderLayout;
import java.io.File;
import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import javax.swing.BorderFactory;
import javax.swing.JDialog;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JProgressBar;

import org.apache.lucene.analysis.core.KeywordAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexWriterConfig.OpenMode;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.queries.mlt.MoreLikeThis;
import org.apache.lucene.queryparser.classic.ParseException;
import org.apache.lucene.queryparser.classic.QueryParser;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.similarities.DefaultSimilarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;
import org.ohdsi.utilities.DirectoryUtilities;
import org.ohdsi.utilities.StringUtilities;

/**
 * The Usagi search engine is used to find matching concepts for source terms. The search engine uses Lucene.
 */
public class UsagiSearchEngine {

	public static String MAIN_INDEX_FOLDER = "mainIndex";
	public static String DERIVED_INDEX_FOLDER = "derivedIndex";
	public static String SOURCE_CODE_TYPE_STRING = "S";
	public static String CONCEPT_TYPE_STRING = "C";

	private String folder;
	private IndexWriter writer;
	private IndexReader reader = null;
	private IndexSearcher searcher;
	private UsagiAnalyzer analyzer = new UsagiAnalyzer();
	private Query conceptQuery;
	private QueryParser conceptIdQueryParser;
	private QueryParser conceptClassQueryParser;
	private QueryParser vocabularyQueryParser;
	private QueryParser keywordsQueryParser;
	private QueryParser invalidQueryParser;
	private QueryParser domainQueryParser;
	private int numDocs;
	private FieldType textVectorField = getTextVectorFieldType();

	public UsagiSearchEngine(String folder) {
		this.folder = folder;
	}

	// Term vectors are stored for the TERM field so that recomputeScores() can rebuild TF*IDF vectors later.
	private FieldType getTextVectorFieldType() {
		FieldType textVectorField = new FieldType();
		textVectorField.setIndexed(true);
		textVectorField.setTokenized(true);
		textVectorField.setStoreTermVectors(true);
		textVectorField.setStoreTermVectorPositions(false);
		textVectorField.setStoreTermVectorPayloads(false);
		textVectorField.setStoreTermVectorOffsets(false);
		textVectorField.setStored(true);
		textVectorField.freeze();
		return textVectorField;
	}

	public void createNewMainIndex() {
		try {
			File indexFolder = new File(folder + "/" + MAIN_INDEX_FOLDER);
			if (indexFolder.exists())
				DirectoryUtilities.deleteDir(indexFolder);
			Directory dir = FSDirectory.open(indexFolder);
			IndexWriterConfig iwc = new IndexWriterConfig(Version.LUCENE_4_9, new UsagiAnalyzer());
			iwc.setOpenMode(OpenMode.CREATE);
			iwc.setRAMBufferSizeMB(256.0);
			writer = new IndexWriter(dir, iwc);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}

	public boolean mainIndexExists() {
		return new File(folder + "/" + MAIN_INDEX_FOLDER).exists();
	}

	public void addConceptToIndex(TargetConcept concept) {
		if (writer == null)
			throw new RuntimeException("Index not open for writing");
		try {
			Document document = new Document();
			document.add(new StringField("TYPE", CONCEPT_TYPE_STRING, Store.YES));
			document.add(new Field("TERM", concept.term, textVectorField));
			document.add(new StringField("CONCEPT_ID", Integer.toString(concept.conceptId), Store.YES));
			document.add(new StringField("CONCEPT_NAME", concept.conceptName, Store.YES));
			document.add(new StringField("CONCEPT_CLASS", concept.conceptClass, Store.YES));
			document.add(new StringField("CONCEPT_CODE", concept.conceptCode, Store.YES));
			document.add(new StringField("VOCABULARY", concept.vocabulary, Store.YES));
			document.add(new StringField("VALID_START_DATE", concept.validStartDate, Store.YES));
			document.add(new StringField("VALID_END_DATE", concept.validEndDate, Store.YES));
			document.add(new StringField("INVALID_REASON", concept.invalidReason, Store.YES));
			document.add(new TextField("DOMAINS", StringUtilities.join(concept.domains, "\n"), Store.YES));
			document.add(new StringField("ADDITIONAL_INFORMATION", concept.additionalInformation, Store.YES));
			writer.addDocument(document);
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Tokens that appear very frequently in the source code names, but not very often in the vocabulary, would get high weights (high IDF) even though they
	 * probably are not very informative. To remedy this, we create a copy of the main index, and add all the source names to the index as well.
	 *
	 * @param sourceCodes
	 *            the list of source codes to add to the index
	 * @param frame
	 *            a reference to the frame in case we want to show a progress dialog. Set to null if no progress dialog needs to be shown
	 */
	public void createDerivedIndex(List<SourceCode> sourceCodes, JFrame frame) {
		JDialog dialog = null;
		JProgressBar progressBar = null;
		if (frame != null) {
			dialog = new JDialog(frame, "Progress Dialog", false);
			JPanel panel = new JPanel();
			panel.setBorder(BorderFactory.createRaisedBevelBorder());
			panel.setLayout(new BorderLayout());
			panel.add(BorderLayout.NORTH, new JLabel("Indexing source codes..."));
			progressBar = new JProgressBar(0, 100);
			panel.add(BorderLayout.CENTER, progressBar);
			dialog.add(panel);
			dialog.setDefaultCloseOperation(JDialog.DO_NOTHING_ON_CLOSE);
			dialog.setSize(300, 75);
			dialog.setLocationRelativeTo(frame);
			dialog.setUndecorated(true);
			dialog.setModal(true);
		}
		AddSourceCodesThread thread = new AddSourceCodesThread(sourceCodes, progressBar, dialog);
		thread.start();
		if (dialog != null)
			dialog.setVisible(true);
		try {
			thread.join();
		} catch (InterruptedException e) {
			e.printStackTrace();
		}
	}

	private class AddSourceCodesThread extends Thread {
		private JProgressBar progressBar;
		private List<SourceCode> sourceCodes;
		private JDialog dialog;

		public AddSourceCodesThread(List<SourceCode> sourceCodes, JProgressBar progressBar, JDialog dialog) {
			this.sourceCodes = sourceCodes;
			this.progressBar = progressBar;
			this.dialog = dialog;
		}

		public void run() {
			try {
				File derivedIndexFolder = new File(folder + "/" + DERIVED_INDEX_FOLDER);
				if (derivedIndexFolder.exists())
					if (!DirectoryUtilities.deleteDir(derivedIndexFolder))
						System.out.println("Unable to delete derived index folder");
				File indexFolder = new File(folder + "/" + MAIN_INDEX_FOLDER);
				DirectoryUtilities.copyDirectory(indexFolder, derivedIndexFolder);
				Directory dir = FSDirectory.open(derivedIndexFolder);
				IndexWriterConfig iwc = new IndexWriterConfig(Version.LUCENE_4_9, new UsagiAnalyzer());
				iwc.setOpenMode(OpenMode.APPEND);
				iwc.setRAMBufferSizeMB(256.0);
				IndexWriter writer = new IndexWriter(dir, iwc);
				for (int i = 0; i < sourceCodes.size(); i++) {
					Document document = new Document();
					document.add(new StringField("TYPE", SOURCE_CODE_TYPE_STRING, Store.YES));
					document.add(new Field("TERM", sourceCodes.get(i).sourceName, textVectorField));
					writer.addDocument(document);
					if (progressBar != null)
						progressBar.setValue(5 + (90 * i) / sourceCodes.size());
				}
				// writer.forceMerge(1);
				writer.close();
				System.gc();
				if (dialog != null)
					dialog.setVisible(false);
				openIndexForSearching();
			} catch (Exception e) {
				throw new RuntimeException(e);
			}
		}
	}

	public void openIndexForSearching() {
		try {
			reader = DirectoryReader.open(FSDirectory.open(new File(folder + "/" + DERIVED_INDEX_FOLDER)));
			searcher = new IndexSearcher(reader);
			searcher.setSimilarity(new CustomSimilarity());
			BooleanQuery.setMaxClauseCount(Integer.MAX_VALUE);
			QueryParser typeQueryParser = new QueryParser(Version.LUCENE_4_9, "TYPE", new KeywordAnalyzer());
			conceptQuery = typeQueryParser.parse(CONCEPT_TYPE_STRING);
			conceptIdQueryParser = new QueryParser(Version.LUCENE_4_9, "CONCEPT_ID", new KeywordAnalyzer());
			conceptClassQueryParser = new QueryParser(Version.LUCENE_4_9, "CONCEPT_CLASS", new KeywordAnalyzer());
			vocabularyQueryParser = new QueryParser(Version.LUCENE_4_9, "VOCABULARY", new KeywordAnalyzer());
			keywordsQueryParser = new QueryParser(Version.LUCENE_4_9, "TERM", analyzer);
			domainQueryParser = new QueryParser(Version.LUCENE_4_9, "DOMAINS", analyzer);
			invalidQueryParser = new
					QueryParser(Version.LUCENE_4_9, "INVALID_REASON", new KeywordAnalyzer());
			numDocs = reader.numDocs();
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}

	// Disables Lucene's length, query, and coordination normalization so scoring behaves like raw TF*IDF.
	public class CustomSimilarity extends DefaultSimilarity {

		@Override
		public float lengthNorm(FieldInvertState state) {
			// simply return the field's configured boost value
			// instead of also factoring in the field's length
			return 1;
		}

		@Override
		public float idf(long docFreq, long numDocs) {
			return (float) (Math.log(numDocs / (docFreq + 1)));
		}

		@Override
		public float queryNorm(float sumOfSquaredWeights) {
			return 1;
		}

		@Override
		public float tf(float freq) {
			return freq;
		}

		@Override
		public float coord(int overlap, int maxOverlap) {
			return 1;
		}
	}

	public void close() {
		try {
			if (reader != null) {
				reader.close();
				reader = null;
				System.gc();
			}
			if (writer != null) {
				writer.forceMerge(1);
				writer.close();
				writer = null;
			}
		} catch (IOException e) {
			e.printStackTrace();
		}
	}

	public List<ScoredConcept> search(String searchTerm, boolean useMlt, Collection<Integer> filterConceptIds, String filterDomain,
			String filterConceptClass, String filterVocabulary, boolean filterInvalid) {
		List<ScoredConcept> results = new ArrayList<ScoredConcept>();
		try {
			Query query;
			if (useMlt) {
				MoreLikeThis mlt = new MoreLikeThis(searcher.getIndexReader());
				mlt.setMinTermFreq(1);
				mlt.setMinDocFreq(1);
				mlt.setMaxDocFreq(9999);
				mlt.setMinWordLen(1);
				mlt.setMaxWordLen(9999);
				mlt.setMaxDocFreqPct(100);
				mlt.setMaxNumTokensParsed(9999);
				mlt.setMaxQueryTerms(9999);
				mlt.setStopWords(null);
				mlt.setFieldNames(new String[] { "TERM" });
				mlt.setAnalyzer(analyzer);
				query = mlt.like("TERM", new StringReader(searchTerm));
			} else {
				try {
					query = keywordsQueryParser.parse(searchTerm);
					// if (query instanceof BooleanQuery) {
					// List<BooleanClause> clauses = ((BooleanQuery) query).clauses();
					// BooleanClause lastClause = clauses.get(clauses.size() - 1);
					// lastClause.setQuery(new PrefixQuery(((TermQuery) lastClause.getQuery()).getTerm()));
					// } else if (query instanceof TermQuery) {// It's a single term
					// query = new PrefixQuery(((TermQuery) query).getTerm());
					// }
				} catch (ParseException e) {
					return results;
				}
			}
			BooleanQuery booleanQuery = new BooleanQuery();
			booleanQuery.add(query, Occur.SHOULD);
			booleanQuery.add(conceptQuery, Occur.MUST);
			if (filterConceptIds != null && filterConceptIds.size() > 0) {
				Query conceptIdQuery = conceptIdQueryParser.parse(StringUtilities.join(filterConceptIds, " OR "));
				booleanQuery.add(conceptIdQuery, Occur.MUST);
			}
			if (filterDomain != null) {
				Query domainQuery = domainQueryParser.parse("\"" + filterDomain + "\"");
				booleanQuery.add(domainQuery, Occur.MUST);
			}
			if (filterConceptClass != null) {
				Query conceptClassQuery = conceptClassQueryParser.parse("\"" + filterConceptClass.toString() + "\"");
				booleanQuery.add(conceptClassQuery, Occur.MUST);
			}
			if (filterVocabulary != null) {
				Query vocabularyQuery = vocabularyQueryParser.parse("\"" + filterVocabulary.toString() + "\"");
				booleanQuery.add(vocabularyQuery, Occur.MUST);
			}
			if (filterInvalid) {
				Query invalidQuery = invalidQueryParser.parse("\"\"");
				booleanQuery.add(invalidQuery, Occur.MUST);
			}
			TopDocs topDocs = searcher.search(booleanQuery, 100);
			recomputeScores(topDocs.scoreDocs, query);
			for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
				Document document = reader.document(scoreDoc.doc);
				int conceptId = Integer.parseInt(document.get("CONCEPT_ID"));
				// If matchscore = 0 but it was the one concept that was automatically selected, still allow it:
				if (scoreDoc.score > 0 || (filterConceptIds != null &&
						filterConceptIds.size() == 1 && filterConceptIds.contains(conceptId))) {
					TargetConcept targetConcept = new TargetConcept();
					targetConcept.term = document.get("TERM");
					targetConcept.conceptId = conceptId;
					targetConcept.conceptName = document.get("CONCEPT_NAME");
					targetConcept.conceptClass = document.get("CONCEPT_CLASS");
					targetConcept.vocabulary = document.get("VOCABULARY");
					targetConcept.conceptCode = document.get("CONCEPT_CODE");
					targetConcept.validStartDate = document.get("VALID_START_DATE");
					targetConcept.validEndDate = document.get("VALID_END_DATE");
					targetConcept.invalidReason = document.get("INVALID_REASON");
					for (String domain : document.get("DOMAINS").split("\n"))
						targetConcept.domains.add(domain);
					targetConcept.additionalInformation = document.get("ADDITIONAL_INFORMATION");
					results.add(new ScoredConcept(scoreDoc.score, targetConcept));
				}
			}
			reorderTies(results);
			removeDuplicateConcepts(results);
		} catch (Exception e) {
			System.err.println(e.getMessage());
			e.printStackTrace();
		}
		return results;
	}

	private void removeDuplicateConcepts(List<ScoredConcept> results) {
		Set<Integer> seenConceptIds = new HashSet<Integer>();
		Iterator<ScoredConcept> iterator = results.iterator();
		while (iterator.hasNext()) {
			ScoredConcept scoredConcept = iterator.next();
			if (!seenConceptIds.add(scoredConcept.concept.conceptId))
				iterator.remove();
		}
	}

	// Among concepts with equal scores, prefer those whose matched term exactly equals the concept name.
	private void reorderTies(List<ScoredConcept> scoredConcepts) {
		Collections.sort(scoredConcepts, new Comparator<ScoredConcept>() {

			@Override
			public int compare(ScoredConcept arg0, ScoredConcept arg1) {
				int result = -Float.compare(arg0.matchScore, arg1.matchScore);
				if (result == 0) {
					if (arg0.concept.term.toLowerCase().equals(arg0.concept.conceptName.toLowerCase()))
						return -1;
					else if (arg1.concept.term.toLowerCase().equals(arg1.concept.conceptName.toLowerCase()))
						return 1;
				}
				return result;
			}
		});
	}

	/**
	 * Lucene's matching score does some weird things: it is not normalized (the value can be greater than 1), and not all tokens are included in the
	 * computation. For that reason, we're recomputing the matching score as plain TF*IDF cosine matching here.
	 *
	 * @param scoreDocs
	 *            The array of documents scored by Lucene
	 * @param query
	 *            The query used for retrieval
	 */
	private void recomputeScores(ScoreDoc[] scoreDocs, Query query) {
		try {
			Term2Tfidf searchTerm = null;
			if (query instanceof BooleanQuery)
				searchTerm = new Term2Tfidf((BooleanQuery) query);
			else if (query instanceof TermQuery)
				searchTerm = new Term2Tfidf((TermQuery) query);
			if (searchTerm != null && !searchTerm.isInvalid()) {
				for (ScoreDoc scoreDoc : scoreDocs) {
					Term2Tfidf hit = new Term2Tfidf(scoreDoc.doc, "TERM");
					scoreDoc.score = (float) searchTerm.cosineSimilarity(hit);
				}
				Arrays.sort(scoreDocs, new Comparator<ScoreDoc>() {

					@Override
					public int compare(ScoreDoc arg0, ScoreDoc arg1) {
						return -Float.compare(arg0.score, arg1.score);
					}
				});
			}
		} catch (Exception e) {
			System.err.println(e.getMessage());
		}
	}

	public static class ScoredConcept {
		public float matchScore;
		public TargetConcept concept;

		public ScoredConcept(float matchScore, TargetConcept concept) {
			this.matchScore = matchScore;
			this.concept = concept;
		}
	}

	// Sparse TF*IDF vector over a field or query. Note: despite its name, l1 holds the Euclidean (L2) norm of the vector.
	private class Term2Tfidf {
		public TermTfidfPair[] pairs;
		public double l1;
		public boolean invalid = false;

		public Term2Tfidf(int docId, String field) throws IOException {
			Terms vector = reader.getTermVector(docId, field);
			pairs = new TermTfidfPair[(int) vector.size()];
			l1 = 0;
			TermsEnum termsEnum = vector.iterator(null);
			int i = 0;
			BytesRef text;
			while ((text = termsEnum.next()) != null) {
				double tfidf = termsEnum.totalTermFreq() * idf(reader.docFreq(new Term(field, termsEnum.term())), numDocs);
				pairs[i++] = new TermTfidfPair(BytesRef.deepCopyOf(text), tfidf);
				l1 += sqr(tfidf);
			}
			l1 = Math.sqrt(l1);
			sort();
		}

		public boolean isInvalid() {
			return invalid;
		}

		public Term2Tfidf(BooleanQuery query) throws IOException {
			pairs = new TermTfidfPair[query.clauses().size()];
			l1 = 0;
			int i = 0;
			for (BooleanClause clause : query.clauses()) {
				if (!(clause.getQuery() instanceof TermQuery))
					invalid = true;
				else {
					TermQuery q = (TermQuery) clause.getQuery();
					double tfidf = idf(reader.docFreq(q.getTerm()), numDocs);
					pairs[i++] = new TermTfidfPair(q.getTerm().bytes(), tfidf);
					l1 += sqr(tfidf);
				}
			}
			if (!invalid) {
				l1 = Math.sqrt(l1);
				sort();
			}
		}

		public Term2Tfidf(TermQuery query) throws IOException {
			pairs = new TermTfidfPair[1];
			l1 = 0;
			int i = 0;
			double tfidf = idf(reader.docFreq(query.getTerm()), numDocs);
			pairs[i++] = new TermTfidfPair(query.getTerm().bytes(), tfidf);
			l1 += sqr(tfidf);
			l1 = Math.sqrt(l1);
			sort();
		}

		public void sort() {
			Arrays.sort(pairs, new Comparator<TermTfidfPair>() {

				@Override
				public int compare(TermTfidfPair o1, TermTfidfPair o2) {
					return o1.term.compareTo(o2.term);
				}
			});
		}

		public double cosineSimilarity(Term2Tfidf other) {
			int cursor1 = 0;
			int cursor2 = 0;
			double dotProduct = 0;
			while (cursor1 < pairs.length && cursor2 < other.pairs.length) {
				int compare = pairs[cursor1].term.compareTo(other.pairs[cursor2].term);
				if (compare == 0) {
					dotProduct += pairs[cursor1].tfidf * other.pairs[cursor2].tfidf;
					cursor1++;
					cursor2++;
				} else if (compare < 0) {
					cursor1++;
				} else {
					cursor2++;
				}
			}
			if (l1 == 0 || other.l1 == 0)
				return 0;
			else
				return dotProduct / (l1 * other.l1);
		}
	}

	private class TermTfidfPair {
		public TermTfidfPair(BytesRef term, Double tfidf) {
			this.term = term;
			this.tfidf = tfidf;
		}

		public BytesRef term;
		public double tfidf;
	}

	private double sqr(double x) {
		return x * x;
	}

	private double idf(int docFreq, int d) {
		return Math.log(d / (double) docFreq);
	}

	public boolean isOpenForSearching() {
		return (reader != null);
	}
}
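
To make the recomputed score concrete: for a query q and a candidate TERM document d in the derived index of N documents, recomputeScores() replaces Lucene's score with the TF*IDF cosine

    score(q, d) = sum_t w_q(t) * w_d(t) / (|w_q| * |w_d|),  where w_d(t) = termFreq(t, d) * ln(N / docFreq(t)) and w_q(t) = ln(N / docFreq(t)),

which keeps scores between 0 and 1, unlike Lucene's raw scores.

The sketch below shows how a caller might wire the engine together: build the main index from the target vocabulary, derive the re-weighted index from the source codes, then search it. It is a minimal sketch only, relying on the methods defined above; the no-argument constructors for TargetConcept and SourceCode, the index folder name, and the example concept values are assumptions made for illustration, and a realistically sized vocabulary index is assumed for the IDF weights to be meaningful.

import java.util.ArrayList;
import java.util.List;

import org.ohdsi.usagi.SourceCode;
import org.ohdsi.usagi.TargetConcept;
import org.ohdsi.usagi.UsagiSearchEngine;
import org.ohdsi.usagi.UsagiSearchEngine.ScoredConcept;

public class UsagiSearchEngineSketch {

	public static void main(String[] args) {
		UsagiSearchEngine engine = new UsagiSearchEngine("someFolder"); // hypothetical working folder

		// 1. Build the main index from the target vocabulary (normally many thousands of concepts).
		engine.createNewMainIndex();
		TargetConcept concept = new TargetConcept(); // no-arg constructor assumed
		concept.conceptId = 1; // illustrative values only
		concept.term = "Myocardial infarction";
		concept.conceptName = "Myocardial infarction";
		concept.conceptClass = "Clinical Finding";
		concept.conceptCode = "22298006";
		concept.vocabulary = "SNOMED";
		concept.validStartDate = "19700101";
		concept.validEndDate = "20991231";
		concept.invalidReason = "";
		concept.domains.add("Condition");
		concept.additionalInformation = "";
		engine.addConceptToIndex(concept);
		engine.close(); // flushes and merges the main index

		// 2. Copy the main index and add the source terms, so IDF weights also reflect
		// how common each token is among the codes being mapped.
		List<SourceCode> sourceCodes = new ArrayList<SourceCode>();
		SourceCode sourceCode = new SourceCode(); // no-arg constructor assumed
		sourceCode.sourceName = "acute myocardial infarction";
		sourceCodes.add(sourceCode);
		engine.createDerivedIndex(sourceCodes, null); // null frame: no progress dialog; opens the index when done

		// 3. Search with keyword parsing (useMlt = false) and no filters.
		for (ScoredConcept result : engine.search("acute myocardial infarction", false, null, null, null, null, false))
			System.out.println(result.matchScore + "\t" + result.concept.conceptName);

		engine.close();
	}
}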