com.tinspx.util.net.MultipartBody.java Source code

Java tutorial

Introduction

Here is the source code for com.tinspx.util.net.MultipartBody.java

Source

/* Copyright (C) 2013-2014 Ian Teune <ian.teune@gmail.com>
 * 
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files (the
 * "Software"), to deal in the Software without restriction, including
 * without limitation the rights to use, copy, modify, merge, publish,
 * distribute, sublicense, and/or sell copies of the Software, and to
 * permit persons to whom the Software is furnished to do so, subject to
 * the following conditions:
 * 
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
 * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
 * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
 * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */
package com.tinspx.util.net;

import com.google.common.base.CharMatcher;
import com.google.common.base.Charsets;
import static com.google.common.base.Preconditions.*;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.collect.Iterables;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.io.CountingOutputStream;
import com.google.common.net.HttpHeaders;
import com.google.common.net.MediaType;
import com.google.common.primitives.Ints;
import com.tinspx.util.collect.Listenable;
import com.tinspx.util.collect.Predicated;
import com.tinspx.util.io.BAOutputStream;
import com.tinspx.util.io.ByteUtils;
import com.tinspx.util.io.CharUtils;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import javax.annotation.concurrent.NotThreadSafe;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.experimental.Accessors;
import lombok.experimental.FieldDefaults;
import lombok.experimental.PackagePrivate;
import lombok.extern.slf4j.Slf4j;

/**
 * A <a href="http://tools.ietf.org/html/rfc2046#section-5.1">multipart</a>
 * {@code RequestBody}.
 * <p>
 * {@link #headers() headers()} returns a modifiable multimap; however the
 * {@code Content-Type} and {@code Content-Length} headers cannot be set
 * through this multimap. The {@code Content-Type} may be set in the
 * constructor and through {@link #contentType(MediaType) contentType(...)}.
 * The boundary may be explicitly set through {@link #boundary(String)} or
 * supplied as the {@code boundary} parameter of the {@code Content-Type}; if
 * neither is supplied at construction, a random boundary is generated.
 * <p>
 * The body is written as: the preamble; then, for each part, a delimiter
 * line ({@code CRLF "--" boundary CRLF}), the part's header lines (plus a
 * {@code Content-Length} line when the part's size is known), a blank line
 * and the part's content; then the close delimiter
 * ({@code CRLF "--" boundary "--" CRLF}); and finally the epilogue.
 * <p>
 * The encoded part headers and the total length are cached. The cache is
 * invalidated when parts are added or removed, or when the charset, content
 * type or boundary changes. Otherwise it expires after 500 milliseconds, so
 * changes made directly to a part (its headers or content) may not be
 * reflected until then.
 *
 * @author Ian
 */
@Slf4j
@NotThreadSafe
@FieldDefaults(level = AccessLevel.PRIVATE)
@Accessors(fluent = true)
public class MultipartBody extends RequestBody {
    /**
     * Source of the random suffix appended to generated boundaries.
     */
    static final Random RANDOM = new Random();

    private static final byte[] COLONSPACE = { ':', ' ' };
    private static final byte[] CRLF = { '\r', '\n' };
    private static final byte[] DASHDASH = { '-', '-' };
    private static final byte[] CONTENT_LENGTH = CharUtils.encodeAscii(HttpHeaders.CONTENT_LENGTH);

    /**
     * The current boundary delimiter, without the leading {@code "--"}.
     */
    @Getter
    String boundary;
    /**
     * ASCII encoding of {@link #boundary}, written into every delimiter line.
     */
    byte[] boundaryBytes;
    /**
     * Charset used to encode part headers and string preambles/epilogues.
     */
    @Getter
    Charset charset = Charsets.UTF_8;
    @Getter
    byte[] preamble = ByteUtils.emptyArray();
    @Getter
    byte[] epilogue = ByteUtils.emptyArray();

    /**
     * Backing header multimap; always holds the current {@code Content-Type}.
     */
    final ListMultimap<String, String> headers;
    /**
     * Lazily created predicated view of {@link #headers} returned by
     * {@link #headers()}.
     */
    ListMultimap<String, String> headersView;
    /**
     * Backing set of body parts in insertion order.
     */
    final Set<RequestBody> parts;
    /**
     * Lazily created view of {@link #parts} that rejects {@code null}/this
     * and invalidates the cache on modification.
     */
    Set<RequestBody> partsView;
    /**
     * The effective {@code Content-Type}. It always carries a single
     * {@code boundary} parameter equal to {@link #boundary}.
     */
    @Getter
    MediaType type;

    /**
     * true if {@link #partCache} is valid. Set to false when a part is added
     * or removed, when the charset changes, or when the content type or
     * boundary changes.
     */
    boolean valid;
    /**
     * Scratch buffer reused to encode each part's delimiter line and headers.
     */
    BAOutputStream buffer;

    /**
     * max age in nanoseconds of the part cache.
     */
    private static final long MAX_AGE = TimeUnit.MILLISECONDS.toNanos(500);

    /**
     * A body part paired with its pre-encoded delimiter line and headers.
     */
    @RequiredArgsConstructor
    private static class PartCache {
        final @NonNull RequestBody part;
        final @NonNull byte[] bytes;
    }

    /**
     * One entry per part, in iteration order of {@link #parts}.
     */
    List<PartCache> partCache;
    /**
     * Sum of every part's encoded header bytes and content size, or -1 if any
     * part's size is unknown. It excludes the preamble, close delimiter and
     * epilogue (see {@link #computeSize(long)}).
     */
    long length;
    /**
     * {@link System#nanoTime()} value when {@link #partCache} was last built.
     */
    long cacheTime;
    /**
     * counts the number of times that the cache has been computed, used for
     * testing
     */
    @PackagePrivate
    int computeCount;

    private static final String BOUNDARY = "boundary";

    /**
     * RFC 2046 {@code bchars}: the characters permitted in a boundary.
     */
    private static final CharMatcher BCHARS = CharMatcher
            .anyOf(CharUtils.DIGIT + CharUtils.ALPHA_LOWER + CharUtils.ALPHA_UPPER + "'()+_,-./:=? ");

    /**
     * Validates {@code boundary} against RFC 2046: 1 to 70 characters, only
     * {@code bchars}, and not ending in a space.
     *
     * @return {@code boundary}, unchanged
     * @throws IllegalArgumentException if the boundary is invalid
     */
    static String checkBoundary(@NonNull String boundary) {
        checkArgument(!boundary.isEmpty(), "boundary is empty");
        checkArgument(boundary.length() <= 70, "boundary (%s) is too long", boundary);
        checkArgument(BCHARS.matchesAllOf(boundary), "boundary (%s) contains invalid characters", boundary);
        checkArgument(!boundary.endsWith(" "), "boundary (%s) cannot end with a space", boundary);
        return boundary;
    }

    /**
     * Generates a random boundary. It is a UUID followed by a random
     * {@code long}, at most 56 characters, which satisfies
     * {@link #checkBoundary(String)}.
     */
    private static String generateBoundary() {
        return UUID.randomUUID().toString() + RANDOM.nextLong();
    }

    public MultipartBody(String contentType) {
        this(Headers.parseMediaType(contentType));
    }

    public MultipartBody(Headers.ContentType contentType) {
        this(contentType.asMediaType().get());
    }

    /**
     * Creates a body with the given content type. The boundary is taken from
     * its {@code boundary} parameter, or generated if the parameter is absent.
     */
    public MultipartBody(MediaType contentType) {
        this(contentType, null, true);
    }

    public MultipartBody(String contentType, String boundary) {
        this(Headers.parseMediaType(contentType), boundary);
    }

    public MultipartBody(Headers.ContentType contentType, String boundary) {
        this(contentType.asMediaType().get(), boundary);
    }

    /**
     * Creates a body with the given content type and explicit boundary, which
     * overrides any {@code boundary} parameter of {@code contentType}.
     */
    public MultipartBody(MediaType contentType, String boundary) {
        this(contentType, boundary, false);
    }

    /**
     * @param isNull true if {@code boundary} is allowed to be {@code null}
     */
    private MultipartBody(@NonNull MediaType contentType, @Nullable String boundary, boolean isNull) {
        if (!isNull) {
            checkNotNull(boundary, BOUNDARY);
        }
        headers = LinkedListMultimap.create();
        parts = Sets.newLinkedHashSet();
        setContentType(contentType, boundary, true);
        assert this.boundary != null;
    }

    /**
     * Copy constructor. The preamble, epilogue and header multimap are
     * copied. The parts set is copied, but the part instances themselves are
     * shared. The boundary is carried over through {@code m.type}.
     */
    private MultipartBody(MultipartBody m) {
        preamble = m.preamble.clone();
        epilogue = m.epilogue.clone();
        charset = m.charset;
        headers = LinkedListMultimap.create(m.headers);
        parts = Sets.newLinkedHashSet(m.parts);
        setContentType(m.type, null, true);
        assert this.boundary != null;
    }

    @Override
    public MultipartBody duplicate() {
        return new MultipartBody(this);
    }

    /**
     * Sets the preamble {@code discard-text} encoded with
     * {@link #charset() charset()}.
     */
    public MultipartBody preamble(@NonNull String preamble) {
        return preamble(preamble, charset);
    }

    /**
     * Sets the preamble {@code discard-text} encoded with {@code charset}.
     */
    public MultipartBody preamble(@NonNull String preamble, @NonNull Charset charset) {
        return preamble(preamble.getBytes(charset));
    }

    /**
     * Sets the preamble {@code discard-text}. The array is stored without
     * copying.
     */
    public MultipartBody preamble(@NonNull byte[] preamble) {
        this.preamble = preamble;
        return this;
    }

    /**
     * Sets the epilogue {@code discard-text} encoded with
     * {@link #charset() charset()}.
     */
    public MultipartBody epilogue(@NonNull String epilogue) {
        return epilogue(epilogue, charset);
    }

    /**
     * Sets the epilogue {@code discard-text} encoded with {@code charset}.
     */
    public MultipartBody epilogue(@NonNull String epilogue, @NonNull Charset charset) {
        return epilogue(epilogue.getBytes(charset));
    }

    /**
     * Sets the epilogue {@code discard-text}. The array is stored without
     * copying.
     */
    public MultipartBody epilogue(@NonNull byte[] epilogue) {
        this.epilogue = epilogue;
        return this;
    }

    /**
     * Replaces all values of header {@code name} with {@code value}. This
     * goes through {@link #headers()}, so {@code Content-Type} and
     * {@code Content-Length} are rejected.
     */
    public MultipartBody header(@NonNull String name, @NonNull String value) {
        headers().replaceValues(name, Collections.singleton(value));
        return this;
    }

    /**
     * Adds a value for header {@code name}. This goes through
     * {@link #headers()}, so {@code Content-Type} and {@code Content-Length}
     * are rejected.
     */
    public MultipartBody addHeader(@NonNull String name, @NonNull String value) {
        headers().put(name, value);
        return this;
    }

    public MultipartBody contentType(@NonNull String contentType) {
        return contentType(Headers.parseMediaType(contentType));
    }

    public MultipartBody contentType(@NonNull Headers.ContentType contentType) {
        return contentType(contentType.asMediaType().get());
    }

    /**
     * Sets the content type. If it has a {@code boundary} parameter, that
     * parameter becomes the new boundary; otherwise the current boundary is
     * kept.
     *
     * @throws IllegalArgumentException if the type has several
     *         {@code boundary} parameters or the boundary is invalid
     */
    public MultipartBody contentType(@NonNull MediaType contentType) {
        return setContentType(contentType, null, false);
    }

    /**
     * Installs {@code contentType} as the effective content type. The
     * boundary is resolved from, in order: the explicit {@code boundary}
     * argument; the type's {@code boundary} parameter; and, when neither is
     * present, a newly generated boundary if {@code init} is true, or the
     * current boundary otherwise. The boundary is validated before any field
     * is modified, so a failure leaves this instance unchanged.
     *
     * @throws IllegalArgumentException if the type has more than one
     *         {@code boundary} parameter or the resolved boundary is invalid
     */
    private MultipartBody setContentType(@NonNull MediaType contentType, @Nullable String boundary, boolean init) {
        if (boundary != null) {
            contentType = contentType.withParameter(BOUNDARY, boundary);
        } else {
            List<String> b = contentType.parameters().get(BOUNDARY);
            if (b.isEmpty()) {
                boundary = init ? generateBoundary() : this.boundary;
                contentType = contentType.withParameter(BOUNDARY, boundary);
            } else if (b.size() > 1) {
                throw new IllegalArgumentException(
                        String.format("type (%s) has multiple boundary parameters", contentType));
            } else {
                boundary = b.get(0);
            }
        }
        checkBoundary(boundary);
        boundaryBytes = CharUtils.encodeAscii(boundary);
        this.boundary = boundary;
        // written to the backing map directly; the public view rejects Content-Type
        headers.replaceValues(HttpHeaders.CONTENT_TYPE, Collections.singleton(contentType.toString()));
        type = contentType;
        valid = false;
        return this;
    }

    /**
     * Sets the boundary, updating the {@code boundary} parameter of the
     * content type.
     *
     * @throws IllegalArgumentException if the boundary is invalid
     */
    public MultipartBody boundary(@NonNull String boundary) {
        setContentType(type, boundary, false);
        return this;
    }

    /**
     * Sets the {@code Charset} used to encode the part headers (defaults to
     * {@code UTF-8}). This charset is also used by the {@code String}
     * overloads of {@code preamble} and {@code epilogue}.
     */
    public MultipartBody charset(@NonNull Charset charset) {
        if (!this.charset.equals(charset)) {
            this.charset = charset;
            valid = false;
        }
        return this;
    }

    /**
     * Adds the specified part.
     *
     * @throws IllegalArgumentException if {@code part} is this instance
     */
    public MultipartBody part(@NonNull RequestBody part) {
        checkArgument(part != this, "cannot add this instance as a part");
        parts.add(part);
        valid = false;
        return this;
    }

    /**
     * Returns modifiable {@code Set} of all body parts. The returned set will
     * preserve iteration order of added parts. It rejects {@code null} and
     * this instance, and any modification invalidates the part cache.
     */
    public Set<RequestBody> parts() {
        if (partsView == null) {
            partsView = Predicated.set(Listenable.set(parts, new Listenable.Modification() {
                @Override
                public void onModify(Object src, Listenable.Event type) {
                    valid = false;
                }
            }), new Predicate<RequestBody>() {
                @Override
                public boolean apply(RequestBody input) {
                    return input != null && input != MultipartBody.this;
                }
            });
        }
        return partsView;
    }

    /**
     * Returns a modifiable view of the headers. Any header may be added
     * <i>except</i> {@code Content-Type} and {@code Content-Length}.
     */
    @Override
    public ListMultimap<String, String> headers() {
        if (headersView == null) {
            headersView = Predicated.listMultimap(headers, Requests.HKEY_NO_TYPE_LEN, Predicates.notNull());
        }
        return headersView;
    }

    /**
     * Always returns true so it can be used inside {@code assert}.
     *
     * @throws IOException if there are no parts
     */
    private boolean checkNotEmpty() throws IOException {
        if (parts.isEmpty()) {
            throw new IOException("must have at least one body part");
        }
        return true;
    }

    /**
     * (Re)builds the part cache if needed, then reports whether every part
     * has a known size.
     *
     * @throws IOException if there are no parts
     */
    @Override
    public boolean hasKnownSize() throws IOException {
        buildCache();
        // a known length always includes at least one delimiter line, so it is > 0
        return length > 0;
    }

    @Override
    public long copyTo(OutputStream output) throws IOException {
        buildCache();
        return doCopyTo(output);
    }

    /**
     * Writes the full body using the current {@link #partCache}. The caller
     * must have called {@link #buildCache()} first.
     *
     * @return the number of bytes written
     */
    private long doCopyTo(OutputStream output) throws IOException {
        assert checkNotEmpty();
        CountingOutputStream counter = new CountingOutputStream(output);
        counter.write(preamble);
        for (PartCache pc : partCache) {
            counter.write(pc.bytes);
            pc.part.copyTo(counter);
        }
        counter.write(CRLF);
        counter.write(DASHDASH);
        counter.write(boundaryBytes);
        counter.write(DASHDASH);
        counter.write(CRLF);
        counter.write(epilogue);
        return counter.getCount();
    }

    /**
     * Serializes the entire body into an in-memory buffer and returns a
     * stream over it. The buffer is presized when the total size is known.
     */
    @Override
    public InputStream openStream() throws IOException {
        BAOutputStream out;
        if (hasKnownSize()) { //builds cache
            out = new BAOutputStream(Ints.checkedCast(computeSize(length)));
        } else {
            out = new BAOutputStream(4096);
        }
        doCopyTo(out);
        return out.toInputStream();
    }

    @Override
    public InputStream openBufferedStream() throws IOException {
        return openStream();
    }

    /**
     * Returns the total body size. If any part's size is unknown, that part
     * may be read completely to measure it (see {@link #computeSizeSlow()}).
     */
    @Override
    public long size() throws IOException {
        if (hasKnownSize()) {
            return computeSize(length);
        }
        return computeSizeSlow();
    }

    /**
     * Adds the preamble, close delimiter and epilogue to {@code length}. Here
     * {@code length} is the summed part headers and contents.
     */
    private long computeSize(long length) {
        assert length > 0 : length;
        return preamble.length + length + CRLF.length + DASHDASH.length + boundaryBytes.length + DASHDASH.length
                + CRLF.length + epilogue.length;
    }

    /**
     * Computes the total size when some part sizes were unknown at cache time.
     * Each part is first asked for its {@code size()}. A negative result means
     * the part is copied to {@code ByteUtils.nullOutputStream()} (presumably a
     * discarding sink) to count its bytes.
     */
    private long computeSizeSlow() throws IOException {
        long total = 0;
        for (PartCache pc : partCache) {
            total += pc.bytes.length;
            long size = pc.part.size();
            if (size < 0) {
                size = pc.part.copyTo(ByteUtils.nullOutputStream());
            }
            total += size;
        }
        return computeSize(total);
    }

    /**
     * Never returns true. A multipart body must contain at least one part.
     *
     * @throws IOException if there are no parts
     */
    @Override
    public boolean isEmpty() throws IOException {
        checkNotEmpty();
        return false;
    }

    /**
     * Encodes the delimiter line and headers that precede {@code part}'s
     * content.
     * <p>
     * The output is {@code CRLF "--" boundary CRLF}, then the part's headers
     * as filtered by {@code Requests.NOT_CONTENT_LENGTH}, then a
     * {@code Content-Length} line if {@code size >= 0}, then an empty line.
     *
     * @param size the part's content size, or -1 if unknown
     */
    private byte[] buildHeaderBytes(RequestBody part, long size) throws IOException {
        if (buffer == null) {
            buffer = new BAOutputStream(256);
        } else {
            buffer.reset();
        }
        buffer.write(CRLF);
        buffer.write(DASHDASH);
        buffer.write(boundaryBytes);
        buffer.write(CRLF);
        for (Map.Entry<String, String> h : Iterables.filter(part.headers().entries(),
                Requests.NOT_CONTENT_LENGTH)) {
            buffer.write(h.getKey(), charset);
            buffer.write(COLONSPACE);
            buffer.write(h.getValue(), charset);
            buffer.write(CRLF);
        }
        if (size >= 0) {
            buffer.write(CONTENT_LENGTH);
            buffer.write(COLONSPACE);
            buffer.writeAscii(String.valueOf(size));
            buffer.write(CRLF);
        }
        buffer.write(CRLF);
        return buffer.toByteArray();
    }

    /**
     * Rebuilds {@link #partCache} and {@link #length}, unless the cache is
     * valid and younger than {@link #MAX_AGE}.
     *
     * @throws IOException if there are no parts, or if querying a part's size
     *         fails
     */
    private void buildCache() throws IOException {
        checkNotEmpty();
        final long now = System.nanoTime();
        if (partCache == null) {
            partCache = Lists.newArrayList();
        } else if (valid && now - cacheTime <= MAX_AGE) {
            return;
        } else {
            partCache.clear();
        }
        // Mark invalid while rebuilding so that a failure part-way through
        // (e.g. a part's size() throwing) never leaves a partially filled
        // cache flagged as valid.
        valid = false;
        computeCount++;
        length = 0;
        for (RequestBody part : parts) {
            final long size = part.hasKnownSize() ? part.size() : -1;
            // once any part is unknown, length stays -1
            if (length >= 0) {
                if (size >= 0) {
                    length += size;
                } else {
                    length = -1;
                }
            }
            byte[] bytes = buildHeaderBytes(part, size);
            partCache.add(new PartCache(part, bytes));
            if (length >= 0) {
                length += bytes.length;
            }
        }
        valid = true;
        cacheTime = now;
    }
}