net.myrrix.online.generation.InputFilesReader.java Source code

Java tutorial

Introduction

Here is the source code for net.myrrix.online.generation.InputFilesReader.java

Source

/*
 * Copyright Myrrix Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.myrrix.online.generation;

import java.io.File;
import java.io.FilenameFilter;
import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;

import com.google.common.base.Splitter;
import com.google.common.io.PatternFilenameFilter;
import org.apache.commons.math3.util.FastMath;
import org.apache.mahout.cf.taste.model.IDMigrator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.myrrix.common.LangUtils;
import net.myrrix.common.OneWayMigrator;
import net.myrrix.common.collection.FastByIDFloatMap;
import net.myrrix.common.collection.FastByIDMap;
import net.myrrix.common.collection.FastIDSet;
import net.myrrix.common.io.InvertedFilenameFilter;
import net.myrrix.common.iterator.FileLineIterable;
import net.myrrix.common.math.MatrixUtils;

/**
 * Reads CSV input files from a directory into the "R" matrix representation.
 *
 * <p>Each data line has the form {@code user,item[,value]}. A user or item token that is wrapped
 * in double quotes is treated as a tag and is hashed to a numeric ID; any other token must parse
 * as a {@code long}. A missing value column means 1.0; an empty value column means "remove".
 * Empty lines and lines starting with {@code #} are skipped.</p>
 *
 * @author Sean Owen
 */
final class InputFilesReader {

    private static final Logger log = LoggerFactory.getLogger(InputFilesReader.class);

    /** Splits a line on commas and trims whitespace around each field. */
    private static final Splitter COMMA = Splitter.on(',').trimResults();

    /**
     * Values with absolute value less than this in the input are considered 0.
     * Values are generally assumed to be > 1, actually,
     * and usually not negative, though they need not be.
     * Read from the system property {@code model.decay.zeroThreshold}, defaulting to 0.0001.
     */
    private static final float ZERO_THRESHOLD = Float
            .parseFloat(System.getProperty("model.decay.zeroThreshold", "0.0001"));

    /** Once more than this many bad lines have been seen, reading is aborted. */
    private static final int MAX_BAD_LINES = 100;

    private InputFilesReader() {
    }

    /**
     * Reads every {@code .csv}, {@code .csv.zip} and {@code .csv.gz} file in {@code inputDir},
     * in the order given by {@code ByLastModifiedComparator}, and applies each line to the
     * supplied data structures. Other files in the directory are logged and skipped. After all
     * files are read, entries whose absolute value is below the zero threshold are pruned from
     * both matrix views.
     *
     * @param knownItemIDs per-user set of item IDs seen in the input; may be {@code null}, in which
     *  case it is not maintained
     * @param rbyRow matrix view updated through {@code MatrixUtils} (presumably keyed by user first)
     * @param rbyColumn matrix view updated through {@code MatrixUtils} (presumably keyed by item first)
     * @param itemTagIDs receives the hashed ID of every tag that appears in the user column
     * @param userTagIDs receives the hashed ID of every tag that appears in the item column
     * @param inputDir directory to read input files from
     * @throws IOException if more than 100 bad lines have been encountered, or
     *  (presumably) if a file cannot be read
     */
    static void readInputFiles(FastByIDMap<FastIDSet> knownItemIDs, FastByIDMap<FastByIDFloatMap> rbyRow,
            FastByIDMap<FastByIDFloatMap> rbyColumn, FastIDSet itemTagIDs, FastIDSet userTagIDs, File inputDir)
            throws IOException {

        FilenameFilter csvFilter = new PatternFilenameFilter(".+\\.csv(\\.(zip|gz))?");

        File[] otherFiles = inputDir.listFiles(new InvertedFilenameFilter(csvFilter));
        if (otherFiles != null) {
            for (File otherFile : otherFiles) {
                log.info("Skipping file {}", otherFile.getName());
            }
        }

        File[] inputFiles = inputDir.listFiles(csvFilter);
        if (inputFiles == null) {
            log.info("No input files in {}", inputDir);
            return;
        }
        Arrays.sort(inputFiles, ByLastModifiedComparator.INSTANCE);

        IDMigrator hash = new OneWayMigrator();

        // Total across all files; long so that very large inputs cannot overflow the counter.
        long lines = 0;
        int badLines = 0;
        for (File inputFile : inputFiles) {
            log.info("Reading {}", inputFile);
            // Tracked per file so that a header row is recognized in every file, not just the first.
            long linesInFile = 0;
            for (String line : new FileLineIterable(inputFile)) {

                if (badLines > MAX_BAD_LINES) { // Crude check
                    throw new IOException("Too many bad lines; aborting");
                }

                lines++;
                linesInFile++;

                if (line.isEmpty() || line.charAt(0) == '#') {
                    continue;
                }

                Iterator<String> it = COMMA.split(line).iterator();

                long userID;
                boolean userIsTag;
                long itemID;
                boolean itemIsTag;
                float value;
                try {

                    String userIDString = it.next();
                    userIsTag = userIDString.startsWith("\"");
                    userID = parseID(userIDString, userIsTag, hash);

                    String itemIDString = it.next();
                    itemIsTag = itemIDString.startsWith("\"");
                    itemID = parseID(itemIDString, itemIsTag, hash);

                    if (it.hasNext()) {
                        String valueToken = it.next();
                        // An explicitly empty value column is the "remove" marker, carried as NaN
                        value = valueToken.isEmpty() ? Float.NaN : LangUtils.parseFloat(valueToken);
                    } else {
                        value = 1.0f;
                    }

                } catch (NoSuchElementException ignored) {
                    log.warn("Ignoring line with too few columns: '{}'", line);
                    badLines++;
                    continue;
                } catch (IllegalArgumentException ignored) { // includes NumberFormatException
                    if (linesInFile == 1) {
                        // An unparseable first line of a file is assumed to be a header
                        log.info("Ignoring header line: '{}'", line);
                    } else {
                        log.warn("Ignoring unparseable line: '{}'", line);
                        badLines++;
                    }
                    continue;
                }

                if (userIsTag && itemIsTag) {
                    log.warn("Two tags not allowed: '{}'", line);
                    badLines++;
                    continue;
                }

                // A tag in the user column is recorded as an item tag, and vice versa
                if (userIsTag) {
                    itemTagIDs.add(userID);
                }

                if (itemIsTag) {
                    userTagIDs.add(itemID);
                }

                if (Float.isNaN(value)) {
                    // Remove, not set
                    MatrixUtils.remove(userID, itemID, rbyRow, rbyColumn);
                } else {
                    MatrixUtils.addTo(userID, itemID, value, rbyRow, rbyColumn);
                }

                if (knownItemIDs != null) {
                    FastIDSet itemIDs = knownItemIDs.get(userID);
                    if (Float.isNaN(value)) {
                        // Remove, not set
                        if (itemIDs != null) {
                            itemIDs.remove(itemID);
                            if (itemIDs.isEmpty()) {
                                knownItemIDs.remove(userID);
                            }
                        }
                    } else {
                        if (itemIDs == null) {
                            itemIDs = new FastIDSet();
                            knownItemIDs.put(userID, itemIDs);
                        }
                        itemIDs.add(itemID);
                    }
                }

                if (lines % 1000000 == 0) {
                    log.info("Finished {} lines", lines);
                }
            }
        }

        log.info("Pruning near-zero entries");
        removeSmall(rbyRow);
        removeSmall(rbyColumn);
    }

    /**
     * Converts one user or item token to a numeric ID.
     *
     * @param token trimmed field text from the input line
     * @param isTag whether the token starts with a double quote and so denotes a tag
     * @param hash migrator used to hash the unquoted tag text to a {@code long}
     * @return the parsed numeric ID, or the hashed ID of the tag text
     * @throws IllegalArgumentException if a tag token lacks its closing quote (including a token
     *  that is just a single quote character), or if a non-tag token is not a valid {@code long}
     */
    private static long parseID(String token, boolean isTag, IDMigrator hash) {
        if (!isTag) {
            return Long.parseLong(token);
        }
        int last = token.length() - 1;
        // Without this check a lone quote makes substring(1, 0) throw StringIndexOutOfBoundsException,
        // which is not an IllegalArgumentException and so would abort the whole read; an
        // unterminated tag would silently lose its final character.
        if (last < 1 || token.charAt(last) != '"') {
            throw new IllegalArgumentException("Malformed quoted tag: " + token);
        }
        return hash.toLongID(token.substring(1, last));
    }

    /**
     * Removes every entry whose absolute value is below {@link #ZERO_THRESHOLD} from each inner
     * map of the given matrix. Inner maps that end up empty are not themselves removed here.
     *
     * @param matrix matrix view to prune in place
     */
    private static void removeSmall(FastByIDMap<FastByIDFloatMap> matrix) {
        for (FastByIDMap.MapEntry<FastByIDFloatMap> entry : matrix.entrySet()) {
            for (Iterator<FastByIDFloatMap.MapEntry> it = entry.getValue().entrySet().iterator(); it.hasNext();) {
                FastByIDFloatMap.MapEntry entry2 = it.next();
                if (FastMath.abs(entry2.getValue()) < ZERO_THRESHOLD) {
                    it.remove();
                }
            }
        }
    }

}