io.hops.erasure_coding.TestNativeErasureCodes.java Source code

Java tutorial

Introduction

Here is the source code for io.hops.erasure_coding.TestNativeErasureCodes.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.hops.erasure_coding;

import static io.hops.erasure_coding.TestGaloisField.GF;
import java.nio.ByteBuffer;
import junit.framework.TestCase;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
import org.apache.hadoop.io.erasurecode.ErasureCodeNative;
import org.apache.hadoop.io.erasurecode.ErasureCoderOptions;
import org.apache.hadoop.io.erasurecode.rawcoder.NativeRSRawDecoder;
import org.apache.hadoop.io.erasurecode.rawcoder.NativeRSRawEncoder;
import org.apache.hadoop.io.erasurecode.rawcoder.NativeRSRawErasureCoderFactory;
import org.junit.Test;

public class TestNativeErasureCodes extends TestCase {
    final int TEST_CODES = 100;
    final int TEST_TIMES = 1000;
    final Random RAND = new Random();

    @Test
    public void testEC() {
        ErasureCodeNative.checkNativeCodeLoaded();
    }

    public void setZero(byte[] b) {
        for (int i = 0; i < b.length; i++) {
            b[i] = 0x00;
        }
    }

    public void testEncodeDecode() {
        long overallEncode = 0L;
        long overallDecode = 0L;

        for (int n = 0; n < TEST_CODES; n++) {
            int stripeSize = 10;//RAND.nextInt(99) + 1; // 1, 2, 3, ... 100
            int paritySize = 4;//RAND.nextInt(9) + 1; //1, 2, 3, 4, ... 10
            ErasureCode ec = new ReedSolomonCode(stripeSize, paritySize);
            for (int m = 0; m < TEST_TIMES; m++) {

                byte[][] inputs = new byte[stripeSize][1024];
                byte[][] outputs = new byte[paritySize][1024];

                byte[][] outputData = new byte[stripeSize + paritySize][1024];
                byte[][] outputCopy = new byte[stripeSize + paritySize][1024];

                for (int i = 0; i < stripeSize; i++) {
                    byte[] b = new byte[1024];
                    RAND.nextBytes(b);
                    inputs[i] = b;
                    outputData[i + paritySize] = Arrays.copyOf(b, b.length);
                    outputCopy[i + paritySize] = Arrays.copyOf(b, b.length);
                }

                long startTime = System.currentTimeMillis();
                ec.encodeBulk(inputs, outputs);
                long stopTime = System.currentTimeMillis();
                overallEncode += (stopTime - startTime);

                for (int i = 0; i < paritySize; i++) {
                    outputData[i] = Arrays.copyOf(outputs[i], outputs.length);
                    outputCopy[i] = Arrays.copyOf(outputs[i], outputs.length);
                }

                int erasedLen = 4;//paritySize == 1 ? 1 : RAND.nextInt(paritySize - 1) + 1;
                int[] erasedLocations = randomErasedLocation(erasedLen, inputs.length);
                for (int i = 0; i < erasedLocations.length; i++) {
                    erasedLocations[i] += paritySize;
                    setZero(outputData[erasedLocations[i]]);
                }

                int[] erasedValues = new int[erasedLen];
                byte[][] writeBufs = new byte[erasedLen][1024];

                startTime = System.currentTimeMillis();
                ((ReedSolomonCode) ec).decodeBulk(outputData, writeBufs, erasedLocations);
                stopTime = System.currentTimeMillis();
                overallDecode += (stopTime - startTime);

            }
        }

        System.out.println("Encoding " + (overallEncode) + " milliseconds");
        System.out.println("Decoding " + (overallDecode) + " milliseconds");
    }

    @Test
    public void testNativeEncodeDecode() {

        long overallEncode = 0L;
        long overallDecode = 0L;

        for (int n = 0; n < TEST_CODES; n++) {
            int stripeSize = 10;//RAND.nextInt(99) + 1; // 1, 2, 3, ... 100
            int paritySize = 4;//RAND.nextInt(9) + 1; //1, 2, 3, 4, ... 10
            NativeRSRawErasureCoderFactory factory = new NativeRSRawErasureCoderFactory();
            NativeRSRawEncoder enc = (NativeRSRawEncoder) factory
                    .createEncoder(new ErasureCoderOptions(stripeSize, paritySize));
            NativeRSRawDecoder dec = (NativeRSRawDecoder) factory
                    .createDecoder(new ErasureCoderOptions(stripeSize, paritySize));

            for (int m = 0; m < TEST_TIMES; m++) {

                int symbolMax = (int) Math.pow(2, (int) Math.round(Math.log(GF.getFieldSize()) / Math.log(2)));

                int[] message = new int[stripeSize];
                for (int i = 0; i < stripeSize; i++) {
                    message[i] = RAND.nextInt(symbolMax) + 2;
                }

                int[] parity = new int[paritySize];

                /* Native Encode Starts */
                ByteBuffer[] encodeData = new ByteBuffer[message.length];
                ByteBuffer[] parityData = new ByteBuffer[parity.length];

                int[] inputOffsets = new int[encodeData.length];
                int[] outputOffsets = new int[parityData.length];

                for (int i = 0; i < message.length; i++) {
                    encodeData[i] = ByteBuffer.allocateDirect(1024);
                    encodeData[i].putInt(message[i]);
                    for (int j = 0; j < 255; j++) {
                        encodeData[i].putInt(RAND.nextInt(symbolMax) + 2);
                    }
                    encodeData[i].flip();
                }

                for (int i = 0; i < parity.length; i++) {
                    parityData[i] = ByteBuffer.allocateDirect(1024);
                }

                long startTime = System.currentTimeMillis();
                enc.performEncodeImpl(encodeData, inputOffsets, 1024, parityData, outputOffsets);
                long stopTime = System.currentTimeMillis();
                overallEncode += (stopTime - startTime);

                /* Native Encode Ends here*/

                /* Native Decode Starts here*/
                int[] data = new int[stripeSize + paritySize];
                int[] copy = new int[data.length];
                for (int i = 0; i < stripeSize; i++) {
                    data[i] = message[i];
                    copy[i] = message[i];
                }
                for (int i = 0; i < paritySize; i++) {
                    data[i + stripeSize] = parity[i];
                    copy[i + stripeSize] = parity[i];
                }
                int erasedLen = 4;//paritySize == 1 ? 1 : RAND.nextInt(paritySize - 1) + 1;
                int[] erasedLocations = randomErasedLocation(erasedLen, message.length);
                for (int i = 0; i < erasedLocations.length; i++) {
                    data[erasedLocations[i]] = 0;
                }
                int[] erasedValues = new int[erasedLen];

                //Native Decode
                ByteBuffer[] decodeData = new ByteBuffer[stripeSize + paritySize];
                ByteBuffer[] recoverData = new ByteBuffer[erasedValues.length];

                inputOffsets = new int[decodeData.length];
                outputOffsets = new int[recoverData.length];

                for (int i = 0; i < stripeSize; i++) {
                    if (data[i] == 0) {
                        decodeData[i] = null;
                        continue;
                    }
                    decodeData[i] = encodeData[i];
                    decodeData[i].flip();
                }
                for (int i = stripeSize; i < stripeSize + paritySize; i++) {
                    decodeData[i] = parityData[i - stripeSize];
                    decodeData[i].flip();
                }
                for (int i = 0; i < erasedValues.length; i++) {
                    recoverData[i] = ByteBuffer.allocateDirect(1024);
                }

                startTime = System.currentTimeMillis();
                dec.performDecodeImpl(decodeData, inputOffsets, 1024, erasedLocations, recoverData, outputOffsets);
                stopTime = System.currentTimeMillis();
                overallDecode += (stopTime - startTime);

                for (int i = 0; i < recoverData.length; i++) {
                    erasedValues[i] = recoverData[i].getInt();
                }

                /* Native Decode Ends here */

                for (int i = 0; i < erasedLen; i++) {

                    StringBuffer sb = new StringBuffer();
                    sb.append("\nC ");
                    for (int j = 0; j < data.length; j++) {
                        sb.append(" " + copy[j]);
                    }

                    sb.append("\nD ");

                    for (int j = 0; j < data.length; j++) {
                        sb.append(" " + data[j]);
                    }

                    assertEquals("Decode failed " + sb, copy[erasedLocations[i]], erasedValues[i]);

                }
            }
            enc.release();
            dec.release();

        }

        System.out.println("Encoding " + (overallEncode) + " milliseconds");
        System.out.println("Decoding " + (overallDecode) + " milliseconds");
    }

    public void testRSPerformance() {
        int stripeSize = 10;
        int paritySize = 4;
        ErasureCode ec = new ReedSolomonCode(stripeSize, paritySize);
        int symbolMax = (int) Math.pow(2, ec.symbolSize());
        byte[][] message = new byte[stripeSize][];
        int bufsize = 1024 * 1024 * 10;
        for (int i = 0; i < stripeSize; i++) {
            message[i] = new byte[bufsize];
            for (int j = 0; j < bufsize; j++) {
                message[i][j] = (byte) RAND.nextInt(symbolMax);
            }
        }
        byte[][] parity = new byte[paritySize][];
        for (int i = 0; i < paritySize; i++) {
            parity[i] = new byte[bufsize];
        }
        long encodeStart = System.currentTimeMillis();
        int[] tmpIn = new int[stripeSize];
        int[] tmpOut = new int[paritySize];
        for (int i = 0; i < bufsize; i++) {
            // Copy message.
            for (int j = 0; j < stripeSize; j++) {
                tmpIn[j] = 0x000000FF & message[j][i];
            }
            ec.encode(tmpIn, tmpOut);
            // Copy parity.
            for (int j = 0; j < paritySize; j++) {
                parity[j][i] = (byte) tmpOut[j];
            }
        }
        long encodeEnd = System.currentTimeMillis();
        float encodeMSecs = (encodeEnd - encodeStart);
        System.out.println("Time to encode rs = " + encodeMSecs + "msec ("
                + message[0].length / (1000 * encodeMSecs) + " MB/s)");

        // Copy erased array.
        int[] data = new int[paritySize + stripeSize];
        // 4th location is the 0th symbol in the message
        int[] erasedLocations = new int[] { 4, 1, 5, 7 };
        int[] erasedValues = new int[erasedLocations.length];
        byte[] copy = new byte[bufsize];
        for (int j = 0; j < bufsize; j++) {
            copy[j] = message[0][j];
            message[0][j] = 0;
        }

        long decodeStart = System.currentTimeMillis();
        for (int i = 0; i < bufsize; i++) {
            // Copy parity first.
            for (int j = 0; j < paritySize; j++) {
                data[j] = 0x000000FF & parity[j][i];
            }
            // Copy message. Skip 0 as the erased symbol
            for (int j = 1; j < stripeSize; j++) {
                data[j + paritySize] = 0x000000FF & message[j][i];
            }
            // Use 0, 2, 3, 6, 8, 9, 10, 11, 12, 13th symbol to reconstruct the data
            ec.decode(data, erasedLocations, erasedValues);
            message[0][i] = (byte) erasedValues[0];
        }
        long decodeEnd = System.currentTimeMillis();
        float decodeMSecs = (decodeEnd - decodeStart);
        System.out.println(
                "Time to decode = " + decodeMSecs + "msec (" + message[0].length / (1000 * decodeMSecs) + " MB/s)");
        assertTrue("Decode failed", Arrays.equals(copy, message[0]));
    }

    public void testRSEncodeDecodeBulk() {
        // verify the production size.
        verifyRSEncodeDecodeBulk(10, 4);

        // verify a test size
        verifyRSEncodeDecodeBulk(3, 3);
    }

    public void verifyRSEncodeDecodeBulk(int stripeSize, int paritySize) {
        ReedSolomonCode rsCode = new ReedSolomonCode(stripeSize, paritySize);
        int symbolMax = (int) Math.pow(2, rsCode.symbolSize());
        byte[][] message = new byte[stripeSize][];
        byte[][] cpMessage = new byte[stripeSize][];
        int bufsize = 1024 * 1024 * 10;
        for (int i = 0; i < stripeSize; i++) {
            message[i] = new byte[bufsize];
            cpMessage[i] = new byte[bufsize];
            for (int j = 0; j < bufsize; j++) {
                message[i][j] = (byte) RAND.nextInt(symbolMax);
                cpMessage[i][j] = message[i][j];
            }
        }
        byte[][] parity = new byte[paritySize][];
        for (int i = 0; i < paritySize; i++) {
            parity[i] = new byte[bufsize];
        }

        // encode.
        rsCode.encodeBulk(cpMessage, parity);

        int erasedLocation = RAND.nextInt(stripeSize);
        byte[] copy = new byte[bufsize];
        for (int i = 0; i < bufsize; i++) {
            copy[i] = message[erasedLocation][i];
            message[erasedLocation][i] = (byte) 0;
        }

        // test decode
        byte[][] data = new byte[stripeSize + paritySize][];
        for (int i = 0; i < paritySize; i++) {
            data[i] = new byte[bufsize];
            for (int j = 0; j < bufsize; j++) {
                data[i][j] = parity[i][j];
            }
        }

        for (int i = 0; i < stripeSize; i++) {
            data[i + paritySize] = new byte[bufsize];
            for (int j = 0; j < bufsize; j++) {
                data[i + paritySize][j] = message[i][j];
            }
        }
        byte[][] writeBufs = new byte[1][];
        writeBufs[0] = new byte[bufsize];
        rsCode.decodeBulk(data, writeBufs, new int[] { erasedLocation + paritySize });
        assertTrue("Decode failed", Arrays.equals(copy, writeBufs[0]));
    }

    public void testXorPerformance() {
        Random RAND = new Random();
        int stripeSize = 10;
        byte[][] message = new byte[stripeSize][];
        int bufsize = 1024 * 1024 * 10;
        for (int i = 0; i < stripeSize; i++) {
            message[i] = new byte[bufsize];
            for (int j = 0; j < bufsize; j++) {
                message[i][j] = (byte) RAND.nextInt(256);
            }
        }
        byte[] parity = new byte[bufsize];

        long encodeStart = System.currentTimeMillis();
        for (int i = 0; i < bufsize; i++) {
            for (int j = 0; j < stripeSize; j++) {
                parity[i] ^= message[j][i];
            }
        }
        long encodeEnd = System.currentTimeMillis();
        float encodeMSecs = encodeEnd - encodeStart;
        System.out.println("Time to encode xor = " + encodeMSecs + " msec ("
                + message[0].length / (1000 * encodeMSecs) + "MB/s)");

        byte[] copy = new byte[bufsize];
        for (int j = 0; j < bufsize; j++) {
            copy[j] = message[0][j];
            message[0][j] = 0;
        }

        long decodeStart = System.currentTimeMillis();
        for (int i = 0; i < bufsize; i++) {
            for (int j = 1; j < stripeSize; j++) {
                message[0][i] ^= message[j][i];
            }
            message[0][i] ^= parity[i];
        }
        long decodeEnd = System.currentTimeMillis();
        float decodeMSecs = decodeEnd - decodeStart;
        System.out.println("Time to decode xor = " + decodeMSecs + " msec ("
                + message[0].length / (1000 * decodeMSecs) + "MB/s)");
        assertTrue("Decode failed", Arrays.equals(copy, message[0]));
    }

    public void testComputeErrorLocations() {
        for (int i = 0; i < TEST_TIMES; ++i) {
            verifyErrorLocations(10, 4, 1);
            verifyErrorLocations(10, 4, 2);
        }
    }

    public void verifyErrorLocations(int stripeSize, int paritySize, int errors) {
        int[] message = new int[stripeSize];
        int[] parity = new int[paritySize];
        Set<Integer> errorLocations = new HashSet<Integer>();
        for (int i = 0; i < message.length; ++i) {
            message[i] = RAND.nextInt(256);
        }
        while (errorLocations.size() < errors) {
            int loc = RAND.nextInt(stripeSize + paritySize);
            errorLocations.add(loc);
        }
        ReedSolomonCode codec = new ReedSolomonCode(stripeSize, paritySize);
        codec.encode(message, parity);
        int[] data = combineArrays(parity, message);
        for (Integer i : errorLocations) {
            data[i] = randError(data[i]);
        }
        Set<Integer> recoveredLocations = new HashSet<Integer>();
        boolean resolved = codec.computeErrorLocations(data, recoveredLocations);
        if (resolved) {
            assertEquals(errorLocations, recoveredLocations);
        }
    }

    private int randError(int actual) {
        while (true) {
            int r = RAND.nextInt(256);
            if (r != actual) {
                return r;
            }
        }
    }

    private int[] combineArrays(int[] array1, int[] array2) {
        int[] result = new int[array1.length + array2.length];
        for (int i = 0; i < array1.length; ++i) {
            result[i] = array1[i];
        }
        for (int i = 0; i < array2.length; ++i) {
            result[i + array1.length] = array2[i];
        }
        return result;
    }

    private int[] randomErasedLocation(int erasedLen, int dataLen) {
        int[] erasedLocations = new int[erasedLen];
        for (int i = 0; i < erasedLen; i++) {
            Set<Integer> s = new HashSet<Integer>();
            while (s.size() != erasedLen) {
                s.add(RAND.nextInt(dataLen));
            }
            int t = 0;
            for (int erased : s) {
                erasedLocations[t++] = erased;
            }
        }
        return erasedLocations;
    }
}