// AbstractBucketSorter.java :  » Search » obsearch » net » obsearch » index » sorter » Java Open Source
//
// Java Open Source » Search » obsearch
// obsearch » net » obsearch » index » sorter » AbstractBucketSorter.java
package net.obsearch.index.sorter;

import hep.aida.bin.StaticBin1D;

import java.io.IOException;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.logging.Logger;

import net.obsearch.Index;
import net.obsearch.OB;
import net.obsearch.OperationStatus;
import net.obsearch.Status;
import net.obsearch.asserts.OBAsserts;
import net.obsearch.cache.OBCacheByteArray;
import net.obsearch.cache.OBCacheHandlerByteArray;
import net.obsearch.constants.OBSearchProperties;
import net.obsearch.dimension.AbstractDimension;
import net.obsearch.exception.AlreadyFrozenException;
import net.obsearch.exception.IllegalIdException;
import net.obsearch.exception.OBException;
import net.obsearch.exception.OBStorageException;
import net.obsearch.exception.OutOfRangeException;
import net.obsearch.index.bucket.AbstractBucketIndex;
import net.obsearch.index.bucket.BucketContainer;
import net.obsearch.index.bucket.BucketObject;
import net.obsearch.index.ghs.FixedPriorityQueue;
import net.obsearch.pivots.IncrementalPivotSelector;
import net.obsearch.query.AbstractOBQuery;
import net.obsearch.storage.CloseIterator;
import net.obsearch.storage.OBStorageConfig;
import net.obsearch.storage.OBStore;
import net.obsearch.storage.OBStoreFactory;
import net.obsearch.storage.OBStoreLong;
import net.obsearch.storage.TupleBytes;
import net.obsearch.storage.TupleLong;
import net.obsearch.storage.OBStorageConfig.IndexType;

public abstract class AbstractBucketSorter<O extends OB, B extends BucketObject<O>, Q, BC extends BucketContainer<O, B, Q>, P extends Projection<P, CP>, CP>
    extends AbstractBucketIndex<O, B, Q, BC> {

  /** Logger shared by this class; package-visible. */
  static final transient Logger logger = Logger
      .getLogger(AbstractBucketSorter.class.getName());

  /**
   * Estimators for k ranges. This object tells us the amount of buckets we
   * should read for a given k query to match a certain value of error.
   * Allocated by {@link #maxKEstimation()} with one entry per value returned
   * by {@link #getMaxK()} and read by {@link #estimateK(int)}; it stays
   * {@code null} until the estimation has run.
   */
  protected StaticBin1D[] kEstimators;
  
  
  /**
   * Number of objects requested from the sampler in
   * {@link #maxKEstimation()} to generate the k estimators.
   */
  private int sampleSize = 20;
  /**
   * Multiplier of the standard deviation in {@link #estimateK(int)}: for a
   * query k we take the kEstimators[k] estimation and return the value
   * kEstimators[k].mean() + (kEstimators[k].stdDev() * kAlpha)
   */
  private float kAlpha = 3f;
  /**
   * K configuration that will be used by the user. {@link #estimateK(int)}
   * only accepts values contained in this array.
   */
  protected int[] userK = new int[] { 1, 3, 10, 50 };
  /**
   * The target CompoundError value used in the estimation. This class only
   * stores and exposes it; presumably subclasses consume it in
   * {@link #maxKEstimationAux(OB)} -- confirm.
   */
  private double expectedEP = 0.0;
  /** Random source handed to the sampler in {@link #maxKEstimation()}. */
  private Random r = new Random();

  /**
   * Persistent projection list, created in {@link #init(OBStoreFactory)}
   * under the name "projections". Entries are written as (object id, bucket
   * address) by the insert and freeze methods.
   */
  protected transient OBStoreLong projectionStorage;

  /**
   * In-memory list of the distinct projections found in
   * {@link #projectionStorage}. Filled lazily by {@link #loadMasks()} and
   * reset to {@code null} whenever an insert or a freeze may have changed
   * the stored projections.
   */
  protected transient List<CP> projections;

  /**
   * Cache used for storing buckets, keyed by bucket address. Created in
   * {@link #initCache()}.
   */
  protected transient OBCacheByteArray<BC> bucketCache;

  /**
   * Pivot count for each bucket. Set by the four-argument constructor.
   */
  protected int bucketPivotCount;

  /**
   * Creates a sorter index. The type, pivot selector and pivot count are
   * forwarded unchanged to the parent constructor; the bucket pivot count is
   * stored in this class.
   * 
   * @param type
   *            forwarded to the parent constructor (presumably the class of
   *            the indexed objects -- confirm against the parent).
   * @param pivotSelector
   *            forwarded to the parent constructor.
   * @param pivotCount
   *            forwarded to the parent constructor.
   * @param bucketPivotCount
   *            pivot count for each bucket.
   * @throws OBStorageException
   *             declared by this constructor; origin not visible here.
   * @throws OBException
   *             declared by this constructor; origin not visible here.
   */
  public AbstractBucketSorter(Class type,
      IncrementalPivotSelector pivotSelector, int pivotCount,
      int bucketPivotCount) throws OBStorageException, OBException {
    super(type, pivotSelector, pivotCount);
    this.bucketPivotCount = bucketPivotCount;
  }
  
  /**
   * No-argument constructor that only invokes the parent's no-argument
   * constructor. {@link #bucketPivotCount} keeps its default value of 0.
   * NOTE(review): presumably required by a serialization framework that
   * re-creates the index -- confirm.
   */
  public AbstractBucketSorter(){
    super();
  }
  

  /**
   * Sets kAlpha, the factor applied to the standard deviation when the
   * number of buckets for a query is estimated. The estimation for a query k
   * is kEstimators[k].mean() + (kEstimators[k].stdDev() * kAlpha).
   * 
   * @param kAlpha
   *            the new factor.
   */
  public void setKAlpha(final float kAlpha) {
    this.kAlpha = kAlpha;
  }

  /**
   * Replaces the k values supported by this index. The array is stored by
   * reference (no copy is made) and is the one returned by
   * {@link #getMaxK()}.
   * 
   * @param maxK
   *            the k values to use from now on.
   */
  public void setMaxK(final int[] maxK) {
    userK = maxK;
  }

  /**
   * Changes how many samples are requested when the k estimators are
   * generated.
   * 
   * @param size
   *            the new sample size.
   */
  public void setSampleSize(final int size) {
    sampleSize = size;
  }

  /**
   * Returns the storage address of a bucket, which is the address of the
   * bucket's projection.
   * 
   * @param bucket
   *            the bucket to locate.
   * @return the address reported by the projection of the bucket.
   * @throws OBException
   *             if the projection cannot be computed.
   */
  @Override
  public byte[] getAddress(B bucket) throws OBException {
    final P projection = getProjection(bucket);
    return projection.getAddress();
  }

  /**
   * Returns the projection of the given bucket. {@link #getAddress} uses the
   * address of this projection as the address of the bucket.
   * 
   * @param b
   *            the bucket to project.
   * @return the projection of b.
   * @throws OBException
   *             declared for implementations; cause not visible here.
   */
  protected abstract P getProjection(B b) throws OBException;

  /**
   * Converts a compact projection into bytes. Not called inside this class;
   * presumably the inverse of {@link #bytesToCompactRepresentation(byte[])}
   * -- confirm in the implementations.
   * 
   * @param cp
   *            the compact projection to convert.
   * @return the byte form of cp.
   */
  protected abstract byte[] compactRepresentationToBytes(CP cp);

  /**
   * Converts bytes into a compact projection. {@link #loadMasks()} applies
   * it to every value read from {@link #projectionStorage}, i.e. to the
   * bucket addresses written by the insert and freeze methods.
   * 
   * @param data
   *            the stored bytes.
   * @return the compact projection decoded from data.
   */
  protected abstract CP bytesToCompactRepresentation(byte[] data);

  /**
   * Returns the class object of the compact projection type. Not used inside
   * this class.
   * 
   * @return the class of CP.
   */
  protected abstract Class<CP> getPInstance();

  /**
   * Load the masks from the storage device into memory.
   * <p>
   * Does nothing when {@link #projections} is already populated. Otherwise
   * every value in {@link #projectionStorage} is decoded with
   * {@link #bytesToCompactRepresentation(byte[])} and each distinct result is
   * kept once, in the order in which it is first seen.
   * <p>
   * The list is published to {@link #projections} only after the whole
   * storage has been read, so a failure while loading cannot leave a partial
   * list that later calls would treat as complete. The cursor is always
   * closed, even when loading fails.
   * 
   * @throws OBException
   *             if the storage holds more than Integer.MAX_VALUE entries or
   *             if reading the storage fails.
   */
  protected void loadMasks() throws OBException {
    if (projections != null) {
      return;
    }
    logger.info("Loading masks!");
    long stored = projectionStorage.size();
    OBAsserts.chkAssert(stored <= Integer.MAX_VALUE, "Capacity exceeded");

    List<CP> loaded = new ArrayList<CP>((int) stored);
    // Tracks the projections already added; many objects share a bucket.
    HashSet<CP> viewed = new HashSet<CP>();
    CloseIterator<TupleLong> it = projectionStorage.processAll();
    try {
      while (it.hasNext()) {
        TupleLong t = it.next();
        CP cp = this.bytesToCompactRepresentation(t.getValue());
        // add() returns false when cp was seen before.
        if (viewed.add(cp)) {
          loaded.add(cp);
        }
      }
    } finally {
      it.closeCursor();
    }
    projections = loaded;
    logger.info("Loaded: " + projections.size() + " masks");
  }

  /**
   * Evaluates one stored projection against the query and updates the given
   * queue. {@link #searchBuckets} calls this once for every loaded
   * projection and afterwards returns the sorted content of the queue.
   * Presumably implementations compute the distance between query and proj
   * and offer the result to the queue -- confirm in the implementations.
   * 
   * @param query
   *            the projection of the query.
   * @param proj
   *            the stored compact projection to compare against.
   * @param queue
   *            the queue that collects the closest projections.
   */
  protected abstract void updateDistance(P query, CP proj,FixedPriorityQueue<P> queue );

  /**
   * Finds the buckets closest to the given query. The masks are loaded if
   * necessary, every loaded projection is passed to
   * {@link #updateDistance} together with a fixed-size priority queue of
   * capacity maxF, and the sorted content of that queue is returned.
   * 
   * @param query
   *            the query to employ
   * @param maxF
   *            the max number of items that will be returned
   * @return the sorted data held by the queue after all projections have
   *         been processed.
   * @throws InstantiationException
   * @throws IllegalAccessException
   * @throws OBException
   */
  protected List<P> searchBuckets(P query, int maxF)
      throws InstantiationException, IllegalAccessException, OBException {
    loadMasks();
    final FixedPriorityQueue<P> closest = new FixedPriorityQueue<P>(maxF);
    final List<CP> candidates = this.projections;
    final int total = candidates.size();
    for (int idx = 0; idx < total; idx++) {
      updateDistance(query, candidates.get(idx), closest);
    }
    return closest.getSortedData();
  }

  /**
   * Initializes the parent index and then opens the projection storage: a
   * non-temporary store named "projections" without duplicates, using fixed
   * size records of {@link #getCPSize()} bytes.
   * 
   * @param fact
   *            the factory used to create the storage devices.
   * @throws OBStorageException
   * @throws OBException
   * @throws InstantiationException
   * @throws IllegalAccessException
   */
  public void init(OBStoreFactory fact) throws OBStorageException,
      OBException, InstantiationException, IllegalAccessException {
    super.init(fact);
    final OBStorageConfig projectionConf = new OBStorageConfig();
    projectionConf.setTemp(false);
    projectionConf.setDuplicates(false);
    projectionConf.setIndexType(IndexType.FIXED_RECORD);
    final int recordSize = getCPSize();
    projectionConf.setRecordSize(recordSize);
    projectionStorage = fact.createOBStoreLong("projections", projectionConf);
  }

  /**
   * Fetches the bucket container stored under the given address through the
   * bucket cache.
   * 
   * @param id
   *            the address of the bucket.
   * @return whatever the cache returns for id.
   * @throws OBException
   * @throws InstantiationException
   * @throws IllegalAccessException
   */
  protected BC getBucketContainer(byte[] id) throws OBException,
      InstantiationException, IllegalAccessException {
    return bucketCache.get(id);
  }

  /**
   * Opens the bucket storage "Buckets_byte_array": non-temporary, without
   * duplicates, and in bulk mode as long as the index is not frozen.
   * 
   * @throws OBException
   */
  protected void initByteArrayBuckets() throws OBException {
    final boolean bulk = !isFrozen();
    final OBStorageConfig bucketConf = new OBStorageConfig();
    bucketConf.setTemp(false);
    bucketConf.setDuplicates(false);
    bucketConf.setBulkMode(bulk);
    Buckets = fact.createOBStore("Buckets_byte_array", bucketConf);
  }

  /**
   * Return the compact representation size. {@link #init(OBStoreFactory)}
   * uses this value as the fixed record size of the projection storage.
   * 
   * @return the size of a compact representation (presumably in bytes --
   *         confirm against OBStorageConfig.setRecordSize).
   */
  protected abstract int getCPSize();

  /**
   * Defines the NN error that the estimation should aim for.
   * 
   * @param ep
   *            CompoundError value, later returned by
   *            {@link #getExpectedEP()}.
   */
  public void setExpectedError(final double ep) {
    expectedEP = ep;
  }

  /**
   * Computes the estimators of this index. The work is done entirely by
   * {@link #maxKEstimation()}.
   * 
   * @throws IllegalIdException
   * @throws OBException
   * @throws IllegalAccessException
   * @throws InstantiationException
   */
  protected void calculateEstimators() throws IllegalIdException,
      OBException, IllegalAccessException, InstantiationException {
    this.maxKEstimation();
  }

  /**
   * Builds the k estimators from a random sample of the database.
   * <p>
   * Returns immediately when no k values are configured. Otherwise one empty
   * estimator is created per value returned by {@link #getMaxK()}, a sample
   * of {@link #sampleSize} object ids is selected, and
   * {@link #maxKEstimationAux(OB)} is invoked for every sampled object; that
   * method is expected to feed the estimators. Finally each estimator is
   * logged, together with the optional text of
   * {@link #printEstimation(int)}, which is evaluated once per estimator.
   * 
   * @throws IllegalIdException
   * @throws OBException
   * @throws IllegalAccessException
   * @throws InstantiationException
   */
  protected void maxKEstimation() throws IllegalIdException, OBException,
      IllegalAccessException, InstantiationException {
    if (userK.length == 0) {
      return;
    }
    kEstimators = new StaticBin1D[getMaxK().length];
    logger.fine("Max k estimation");
    for (int e = 0; e < kEstimators.length; e++) {
      kEstimators[e] = new StaticBin1D();
    }

    long[] sample = AbstractDimension.select(sampleSize, r, null,
        (Index) this, null);
    O[] sampleSet = getObjects(sample);
    int sampleNumber = 0;
    for (O o : sampleSet) {
      logger.info("Estimating k sample #: " + sampleNumber + " of "
          + sampleSize);
      maxKEstimationAux(o);
      sampleNumber++;
    }

    int i = 0;
    for (StaticBin1D s : kEstimators) {
      logger.info(" k" + userK[i]);
      // Evaluate once: subclasses may compute this text at some cost.
      String estimation = printEstimation(i);
      if (estimation != null) {
        logger.info(estimation);
      }
      logger.info(s.toString());
      i++;
    }
  }
  
  /**
   * Hook that lets subclasses add a text to the log entry of estimator i in
   * {@link #maxKEstimation()}. This default implementation returns
   * {@code null}, in which case nothing extra is logged.
   * 
   * @param i
   *            index of the estimator in {@link #kEstimators}.
   * @return the text to log, or {@code null} for none.
   */
  protected String printEstimation(int i){
    return null;
  }

  /**
   * Handler that connects the bucket cache with the bucket storage: it
   * reports the storage size and materializes bucket containers from their
   * stored bytes. Storing through the cache is not implemented.
   */
  private class BucketsLoader implements OBCacheHandlerByteArray<BC> {

    /**
     * @return the number of entries in the bucket storage.
     */
    public long getDBSize() throws OBStorageException {
      return Buckets.size();
    }

    /**
     * Reads the bytes stored under the given address and builds a bucket
     * container from them.
     * 
     * @param i
     *            the address of the bucket.
     * @return the container, or null if nothing is stored under i.
     */
    public BC loadObject(byte[] i) throws OBException,
        InstantiationException, IllegalAccessException,
        IllegalIdException {
      final byte[] serialized = Buckets.getValue(i);
      return serialized == null ? null
          : instantiateBucketContainer(serialized, i);
    }

    /**
     * Does nothing. The write-back logic below is disabled in the source:
     * 
     * if (object.isModified()) { OperationStatus s =
     * Buckets.putIfNew(key, object .serialize()); if(s.)
     * stats.addExtraStats("B_SIZE", object.size()); }
     */
    @Override
    public void store(byte[] key, BC object) throws OBException {
      // no-op
    }

  }

  /**
   * Stores the given bucket b into the {@link #Buckets} storage device. The
   * given bucket b should have been returned by getBucket.
   * <p>
   * The stored container for the address of b is read from {@link #Buckets},
   * b is inserted into it, and the container is serialized back, placed in
   * the bucket cache, and its size is recorded under the "B_SIZE" statistic.
   * When the insertion reports {@link Status#OK} the in-memory projection
   * list is invalidated and the address is recorded in
   * {@link #projectionStorage} under the id of b.
   * 
   * @param b
   *            The bucket in which we will insert the object.
   * @param object
   *            The object to insert.
   * @return the OperationStatus returned by the insertion into the stored
   *         container.
   * @throws OBStorageException
   */
  protected OperationStatus insertBucket(B b, O object)
      throws OBStorageException, IllegalIdException,
      IllegalAccessException, InstantiationException,
      OutOfRangeException, OBException {

    byte[] bucketId = getAddress(b);

    // NOTE(review): this first container is built from null data and both
    // it and the status of this insert are overwritten below. It looks
    // redundant, but bc.insert may have side effects on b that are not
    // visible here -- confirm before removing.
    BC bc = instantiateBucketContainer(null, bucketId);
    OperationStatus s = bc.insert(b, object);
    // store the data in the index.

    // we have to re-do everything: this time against the container that is
    // actually stored under bucketId (bucketData is whatever the storage
    // returns for that address).
    byte[] bucketData = Buckets.getValue(bucketId);
    bc = instantiateBucketContainer(bucketData, bucketId);
    s = bc.insert(b, object);
    if (s.getStatus() == Status.OK) {
      projections = null; // make the sketch set void
      projectionStorage.put(b.getId(), bucketId);
    }
    // NOTE(review): the container is written back regardless of the status
    // of the insertion.
    Buckets.put(bucketId, bc.serialize());

    this.bucketCache.put(bucketId, bc);
    stats.addExtraStats("B_SIZE", bc.size());
    return s;
  }

  /**
   * Stores the given bucket b into the {@link #Buckets} storage device. The
   * given bucket b should have been returned by getBucket.
   * No checks are performed, we simply add the objects believing they are
   * unique.
   * <p>
   * The in-memory projection list is always invalidated. The stored
   * container for the address of b is read from {@link #Buckets}, b is
   * bulk-inserted into it, and the container is serialized back. The address
   * is recorded in {@link #projectionStorage} under the id of b, the
   * container is placed in the bucket cache, and its size is recorded under
   * the "B_SIZE" statistic.
   * 
   * @param b
   *            The bucket in which we will insert the object.
   * @param object
   *            The object to insert.
   * @return the OperationStatus returned by the bulk insertion into the
   *         stored container.
   * @throws OBStorageException
   */
  @Override
  protected OperationStatus insertBucketBulk(B b, O object)
      throws OBStorageException, IllegalIdException,
      IllegalAccessException, InstantiationException,
      OutOfRangeException, OBException {

    projections = null; // make the sketch set void

    byte[] bucketId = getAddress(b);
    // NOTE(review): this first container is built from null data and both
    // it and the status of this insert are overwritten below. It looks
    // redundant, but bc.insertBulk may have side effects on b that are not
    // visible here -- confirm before removing.
    BC bc = instantiateBucketContainer(null, bucketId);
    OperationStatus s = bc.insertBulk(b, object);

    // we have to re-do everything: this time against the container that is
    // actually stored under bucketId.
    byte[] bucketData = Buckets.getValue(bucketId);
    bc = instantiateBucketContainer(bucketData, bucketId);
    s = bc.insertBulk(b, object);
    // long prevSize = Buckets.size();
    byte[] data = bc.serialize();
    Buckets.put(bucketId, data);

    // assert Arrays.equals(Buckets.getValue(bucketId), data) :
    // " Bucket storage is not working";
    /*
     * if(bucketData == null){ assert Buckets.size() == (prevSize + 1); }
     */
    // Unlike insertBucket, the projection is stored without looking at the
    // status of the insertion.
    projectionStorage.put(b.getId(), bucketId);
    this.bucketCache.put(bucketId, bc);
    stats.addExtraStats("B_SIZE", bc.size());
    return s;
  }

  /**
   * Default freeze procedure: rebuilds the projection storage and the bucket
   * storage from all the objects of the database.
   * <ol>
   * <li>For every object id in [0, databaseSize()) the bucket and its
   * address are computed, the address is written to
   * {@link #projectionStorage}, and a mask holder is kept in memory.</li>
   * <li>The mask holders are sorted by bucket address so that objects of the
   * same bucket become adjacent.</li>
   * <li>Each run of equal addresses is bulk-inserted into one fresh bucket
   * container, which is serialized and written to {@link #Buckets} when the
   * run ends.</li>
   * </ol>
   * An empty database is handled: no bucket is written and no
   * NullPointerException is raised. Existing content of {@link #Buckets} is
   * not cleared by this method (the Buckets.deleteAll() call is commented
   * out in the source this was derived from).
   * 
   * @throws AlreadyFrozenException
   * @throws IllegalIdException
   * @throws IllegalAccessException
   * @throws InstantiationException
   * @throws OutOfRangeException
   * @throws OBException
   *             also raised when the database holds more than
   *             Integer.MAX_VALUE objects.
   */
  protected void freezeDefault() throws AlreadyFrozenException,
      IllegalIdException, IllegalAccessException, InstantiationException,
      OutOfRangeException, OBException {
    projections = null;
    long max = databaseSize();
    logger.info("Creating masks...");

    OBAsserts.chkAssert(max <= Integer.MAX_VALUE, "No more than Integer.MAX_VALUE objects during freeze");
    List<MaskHolder> masks = new ArrayList<MaskHolder>((int) max);

    for (long id = 0; id < max; id++) {
      O o = getObject(id);
      B b = getBucket(o);
      b.setId(id);
      byte[] bucketId = getAddress(b);
      projectionStorage.put(id, bucketId);
      masks.add(new MaskHolder(bucketId, id, b));
      // Drop the object reference while the masks are held and sorted;
      // presumably to avoid keeping every object in memory. It is set
      // again right before the bulk insert below.
      b.setObject(null);
    }

    logger.info("Sorting " + masks.size() + " masks...");
    // Sorting by bucket address makes all members of a bucket adjacent.
    Collections.sort(masks);
    logger.info("Sorted masks!");

    MaskHolder previous = null;
    BC bc = null;
    logger.info("Bulk insert");
    long objects = 0;
    int inserted = 0;
    for (MaskHolder m : masks) {
      boolean startsNewBucket = previous == null || !previous.equals(m);
      if (startsNewBucket) {
        if (previous != null) {
          // The previous run is complete: flush its container.
          assert bc.size() > 0;
          byte[] data = bc.serialize();
          Buckets.put(previous.bucketId, data);
          inserted++;
          if (inserted % 100000 == 0) {
            logger.info("Inserted: " + inserted + " buckets, " + objects + " objects");
          }
        }
        bc = instantiateBucketContainer(null, m.bucketId);
      }
      O o = getObject(m.id);
      // Restore the object reference that was cleared in the first pass.
      m.bucket.setObject(o);
      bc.insertBulk(m.bucket, o);
      previous = m;
      objects++;
    }
    // Flush the last run; there is none when the database is empty.
    if (previous != null) {
      Buckets.put(previous.bucketId, bc.serialize());
      inserted++;
    }
    assert inserted == Buckets.size();
    logger.info("Buckets size: " + Buckets.size());
  }

  /**
   * Associates an object id and its bucket with the address of the bucket.
   * Instances are ordered and compared by bucket address only, which lets
   * {@link #freezeDefault()} group the objects of one bucket by sorting.
   * The ordering is: shorter addresses first, then byte by byte using signed
   * byte comparison.
   */
  private class MaskHolder implements Comparable<MaskHolder> {

    /** Address of the bucket; the only field used for equality/order. */
    private final byte[] bucketId;
    /** Id of the object that belongs to the bucket. */
    private final long id;
    /** The bucket computed for the object. */
    private final B bucket;

    public MaskHolder(byte[] bucketId, long id, B bucket) {
      this.bucketId = bucketId;
      this.id = id;
      this.bucket = bucket;
    }

    /**
     * Two holders are equal when their bucket addresses have the same
     * content. Returns false for null and for objects of another class
     * instead of failing.
     */
    @Override
    public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      if (o == null || getClass() != o.getClass()) {
        return false;
      }
      @SuppressWarnings("unchecked")
      MaskHolder m = (MaskHolder) o;
      return Arrays.equals(bucketId, m.bucketId);
    }

    /**
     * Hash of the bucket address content, consistent with
     * {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
      return Arrays.hashCode(bucketId);
    }

    /**
     * Orders by address length first and then by the first differing byte
     * (signed comparison). Returns 0 exactly when the addresses have the
     * same content, so the order is consistent with equals.
     */
    @Override
    public int compareTo(MaskHolder o) {
      if (bucketId.length != o.bucketId.length) {
        return bucketId.length < o.bucketId.length ? -1 : 1;
      }
      for (int i = 0; i < bucketId.length; i++) {
        if (bucketId[i] != o.bucketId[i]) {
          return bucketId[i] < o.bucketId[i] ? -1 : 1;
        }
      }
      return 0;
    }

  }

  /**
   * Closes this index: clears the bucket cache, closes the projection
   * storage and finally closes the parent index.
   * 
   * @throws OBException
   */
  @Override
  public void close() throws OBException {
    this.bucketCache.clearAll();
    this.projectionStorage.close();
    super.close();
  }

  /**
   * Returns a k query for the given object. Not called inside this class;
   * presumably used by subclasses while estimating (see
   * {@link #maxKEstimationAux(OB)}) -- confirm.
   * 
   * @param object
   *            (query object)
   * @param k
   *            the number of objects to accept in the query.
   * @return the query created for object and k.
   * @throws IllegalAccessException
   * @throws InstantiationException
   * @throws OBException
   */
  protected abstract AbstractOBQuery<O> getKQuery(O object, int k)
      throws OBException, InstantiationException, IllegalAccessException;

  /**
   * Returns a list of all the objects of this index, in id order (ids 0 to
   * databaseSize() - 1).
   * <p>
   * The database size is read once and verified to fit in a list before
   * anything is allocated; previously an oversized database produced a
   * negative or truncated list capacity.
   * 
   * @return a list of all the objects of this index.
   * @throws OBException
   *             also raised when the database holds more than
   *             Integer.MAX_VALUE objects.
   * @throws InstantiationException
   * @throws IllegalAccessException
   * @throws IllegalIdException
   */
  public List<O> getAllObjects() throws IllegalIdException,
      IllegalAccessException, InstantiationException, OBException {
    final long max = databaseSize();
    OBAsserts.chkAssert(max <= Integer.MAX_VALUE,
        "Cannot return more than Integer.MAX_VALUE objects in a list");
    List<O> db = new ArrayList<O>((int) max);
    for (long id = 0; id < max; id++) {
      db.add(getObject(id));
    }
    return db;
  }

  /**
   * Initializes the caches of the parent and then creates the bucket cache,
   * backed by a {@link BucketsLoader} and sized with
   * OBSearchProperties.getBucketsCacheSize().
   * 
   * @throws OBException
   */
  protected void initCache() throws OBException {
    super.initCache();
    final BucketsLoader loader = new BucketsLoader();
    this.bucketCache = new OBCacheByteArray<BC>(loader,
        OBSearchProperties.getBucketsCacheSize());
  }

  /**
   * Estimate ks for the given query object. {@link #maxKEstimation()} calls
   * this once per sampled object after creating empty estimators in
   * {@link #kEstimators}; presumably implementations add their measurements
   * to those estimators -- confirm in the implementations.
   * 
   * @param object
   *            the sampled object to use as a query.
   * @throws OBException
   * @throws InstantiationException
   * @throws IllegalAccessException
   */
  protected abstract void maxKEstimationAux(O object) throws OBException,
      InstantiationException, IllegalAccessException;

  /**
   * Estimate the number of buckets needed for a k-nn query.
   * <p>
   * queryK must be one of the configured k values (see {@link #getMaxK()}).
   * The result is round(mean + standardDeviation * kAlpha) of the estimator
   * that belongs to queryK, or 1 when that estimator has no samples.
   * 
   * @param queryK
   *            k of the k-nn query.
   * @return Number of buckets that should be retrieved for this query.
   * @throws OBException
   *             if queryK is not a configured k value, if no estimator is
   *             available for it (the estimators were not calculated, or
   *             the k values were replaced afterwards), or if the
   *             estimation does not fit in an int.
   */
  public int estimateK(int queryK) throws OBException {
    int idx = -1;
    for (int j = 0; j < this.userK.length; j++) {
      if (this.userK[j] == queryK) {
        idx = j;
        break;
      }
    }
    if (idx < 0) {
      throw new OBException("Wrong k value");
    }
    // Fail with a checked, descriptive error instead of a
    // NullPointerException / ArrayIndexOutOfBoundsException.
    if (kEstimators == null || idx >= kEstimators.length) {
      throw new OBException("No estimator available for k = " + queryK
          + "; the estimators must be calculated first");
    }
    StaticBin1D estimator = kEstimators[idx];
    if (estimator.size() == 0) {
      return 1; // hack to avoid NaNs in very small DBs
    }
    long x = Math.round(estimator.mean()
        + (estimator.standardDeviation() * kAlpha));
    // Enforced even when assertions are disabled, so the cast below can
    // never silently overflow.
    OBAsserts.chkAssert(x <= Integer.MAX_VALUE,
        "Estimated number of buckets exceeds Integer.MAX_VALUE");
    return (int) x;
  }

  /**
   * @return the k values configured for this index (the internal array, not
   *         a copy).
   */
  public int[] getMaxK() {
    return this.userK;
  }

  /**
   * @return the expected error value set with
   *         {@link #setExpectedError(double)}; 0.0 if it was never set.
   */
  public double getExpectedEP() {
    return this.expectedEP;
  }

  /**
   * @return the pivot count for each bucket.
   */
  public int getBucketPivotCount() {
    return this.bucketPivotCount;
  }

  /**
   * Computes statistics on the size of every stored bucket. Each entry of
   * {@link #Buckets} is turned into a bucket container and its size is added
   * to a bin, which is then registered under "BUCKET_STATS" and logged.
   * <p>
   * The storage cursor is always closed, even when reading or instantiating
   * a bucket fails.
   * 
   * @throws OBStorageException
   * @throws IllegalIdException
   * @throws IllegalAccessException
   * @throws InstantiationException
   * @throws OBException
   */
  public void bucketStats() throws OBStorageException, IllegalIdException,
      IllegalAccessException, InstantiationException, OBException {

    logger.fine("Bucket stats starts!");
    StaticBin1D s = new StaticBin1D();
    CloseIterator<TupleBytes> it = Buckets.processAll();
    try {
      while (it.hasNext()) {
        TupleBytes t = it.next();
        BC bc = instantiateBucketContainer(t.getValue(), t.getKey());
        s.add(bc.size());
      }
    } finally {
      it.closeCursor();
    }
    getStats().putStats("BUCKET_STATS", s);
    logger.info("Bucket Stats:");
    logger.info(s.toString());
  }

}
// java2s.com  | Contact Us | Privacy Policy
// Copyright 2009 - 12 Demo Source and Support. All rights reserved.
// All other trademarks are property of their respective owners.