// Java tutorial
/* * (c) 2005 David B. Bracewell * * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package com.davidbracewell.ml.utils; import com.davidbracewell.tuple.Pair; import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import java.util.List; /** * The type Data sets. * * @author David B. Bracewell */ public class DataSets { private static <TrainType> void addAll(List<TrainType> list, List<List<TrainType>> listOfLists) { for (List<TrainType> l : listOfLists) { list.addAll(l); } } public static <TrainType> Pair<List<TrainType>, List<TrainType>> createTrainTest(List<List<TrainType>> folds, int i) { List<TrainType> train = Lists.newArrayList(); List<TrainType> test = Lists.newArrayList(folds.get(i)); if (i == 0) { addAll(train, folds.subList(1, folds.size())); } else if (i == folds.size() - 1) { addAll(train, folds.subList(0, folds.size() - 1)); } else { addAll(train, folds.subList(0, i)); addAll(train, folds.subList(i + 1, folds.size())); } return Pair.of(train, test); } /** * Split pair. 
* * @param data the data * @param trainPercentage the train percentage * @return the pair */ public static <V> Pair<List<V>, List<V>> split(List<V> data, double trainPercentage) { Preconditions.checkNotNull(data); Preconditions.checkArgument(data.size() > 1, "Must be at least 2 items in the data set."); Preconditions.checkArgument(trainPercentage > 0, "Training percentage must be > 0."); int index = (int) Math.floor((double) data.size() * trainPercentage); return Pair.of(data.subList(0, index), data.subList(index, data.size())); } /** * Create folds. * * @param data the data * @param numFolds the num folds * @return the list */ public static <V> List<List<V>> nFolds(List<V> data, int numFolds) { Preconditions.checkNotNull(data); Preconditions.checkArgument(numFolds > 1, "Must be at least two folds."); Preconditions.checkArgument(data.size() >= numFolds, "Must be at least number of fold items in the data set."); return Lists.partition(data, (int) Math.floor((double) data.size() / numFolds)); } }//END OF DataSets