org.apache.samza.util.EmbeddedTaggedRateLimiter.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.samza.util.EmbeddedTaggedRateLimiter.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.samza.util;

import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.samza.container.TaskName;
import org.apache.samza.context.Context;
import org.apache.samza.context.TaskContextImpl;
import org.apache.samza.job.model.JobModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static java.util.concurrent.TimeUnit.NANOSECONDS;

/**
 * An embedded rate limiter that supports tags. A default tag is used when the user specifies only a single
 * rate, for simple use cases.
 *
 * <p>The rate configured for each tag is the total rate for the whole job. During {@link #init(Context)} it is
 * divided evenly across all tasks in the job model, and this task gets its own Guava rate limiter per tag.
 *
 * <p>{@link #init(Context)} must be called before any {@code acquire} method; otherwise an
 * {@link IllegalStateException} is thrown.
 */
public class EmbeddedTaggedRateLimiter implements RateLimiter {
    private static final Logger LOGGER = LoggerFactory.getLogger(EmbeddedTaggedRateLimiter.class);
    private static final String DEFAULT_TAG = "default-tag";
    // Only the key matters: used to validate that the default tag is supported.
    private static final Map<String, Integer> DEFAULT_TAG_MAP = Collections.singletonMap(DEFAULT_TAG, 0);

    /** Job-wide target rate (credits per second) per tag; an unmodifiable copy of the constructor argument. */
    private final Map<String, Integer> tagToTargetRateMap;
    /** Per-task Guava rate limiter per tag; populated by {@link #init(Context)}. */
    private Map<String, com.google.common.util.concurrent.RateLimiter> tagToRateLimiterMap;
    private boolean initialized;

    /**
     * Creates a rate limiter with a single default tag.
     *
     * @param creditsPerSecond job-wide rate for the default tag; must be positive
     * @throws IllegalArgumentException if {@code creditsPerSecond} is not positive
     */
    public EmbeddedTaggedRateLimiter(int creditsPerSecond) {
        this(Collections.singletonMap(DEFAULT_TAG, creditsPerSecond));
    }

    /**
     * Creates a rate limiter with one job-wide rate per tag.
     *
     * @param tagToCreditsPerSecondMap job-wide rate (credits per second) for each supported tag; must be
     *                                 non-empty and every rate must be positive
     * @throws NullPointerException if the map is null
     * @throws IllegalArgumentException if the map is empty or contains a null or non-positive rate
     */
    public EmbeddedTaggedRateLimiter(Map<String, Integer> tagToCreditsPerSecondMap) {
        Preconditions.checkNotNull(tagToCreditsPerSecondMap, "Map of tags can't be null");
        Preconditions.checkArgument(!tagToCreditsPerSecondMap.isEmpty(), "Map of tags can't be empty");
        // A zero rate can never be turned into a Guava RateLimiter (it requires a positive rate), so reject it
        // here instead of failing later in init().
        tagToCreditsPerSecondMap.forEach((tag, credits) -> Preconditions.checkArgument(
                credits != null && credits > 0, "Credits must be positive, but was %s for tag %s", credits, tag));
        // Copy so later changes to the caller's map cannot bypass the validation above.
        this.tagToTargetRateMap = Collections.unmodifiableMap(new HashMap<>(tagToCreditsPerSecondMap));
    }

    /**
     * Blocks until the requested credits have been acquired for every tag. Tags are processed one at a time.
     *
     * @param tagToCreditsMap number of credits to acquire, keyed by tag
     * @throws IllegalStateException if {@link #init(Context)} has not been called
     * @throws IllegalArgumentException if any tag is not supported
     */
    @Override
    public void acquire(Map<String, Integer> tagToCreditsMap) {
        ensureTagsAreValid(tagToCreditsMap);
        tagToCreditsMap.forEach((tag, numberOfCredits) -> tagToRateLimiterMap.get(tag).acquire(numberOfCredits));
    }

    /**
     * Tries to acquire the requested credits for every tag within an overall timeout. Tags are processed one at
     * a time, and each attempt only gets whatever time remains of the overall timeout.
     *
     * @param tagToCreditsMap number of credits to acquire, keyed by tag
     * @param timeout overall time budget for all tags
     * @param unit unit of {@code timeout}
     * @return for each requested tag, the requested credits if they were acquired, otherwise 0
     * @throws IllegalStateException if {@link #init(Context)} has not been called
     * @throws IllegalArgumentException if any tag is not supported
     */
    @Override
    public Map<String, Integer> acquire(Map<String, Integer> tagToCreditsMap, long timeout, TimeUnit unit) {
        ensureTagsAreValid(tagToCreditsMap);

        long timeoutInNanos = NANOSECONDS.convert(timeout, unit);

        Stopwatch stopwatch = Stopwatch.createStarted();
        return tagToCreditsMap.entrySet().stream().map(e -> {
            String tag = e.getKey();
            int requiredCredits = e.getValue();
            // Time spent on earlier tags is deducted from the budget for this one.
            long remainingTimeoutInNanos = Math.max(0L, timeoutInNanos - stopwatch.elapsed(NANOSECONDS));
            com.google.common.util.concurrent.RateLimiter rateLimiter = tagToRateLimiterMap.get(tag);
            int availableCredits = rateLimiter.tryAcquire(requiredCredits, remainingTimeoutInNanos, NANOSECONDS)
                    ? requiredCredits
                    : 0;
            return new ImmutablePair<>(tag, availableCredits);
        }).collect(Collectors.toMap(ImmutablePair::getKey, ImmutablePair::getValue));
    }

    /**
     * Returns the tags this rate limiter was configured with. Available before {@link #init(Context)}.
     *
     * @return an unmodifiable view of the supported tags
     */
    @Override
    public Set<String> getSupportedTags() {
        // The target-rate map has exactly the keys init() gives the rate-limiter map, and it exists before init.
        return Collections.unmodifiableSet(tagToTargetRateMap.keySet());
    }

    /**
     * Blocks until the requested credits have been acquired for the default tag.
     *
     * @param numberOfCredits number of credits to acquire
     * @throws IllegalStateException if {@link #init(Context)} has not been called
     * @throws IllegalArgumentException if the default tag is not supported
     */
    @Override
    public void acquire(int numberOfCredits) {
        ensureTagsAreValid(DEFAULT_TAG_MAP);
        tagToRateLimiterMap.get(DEFAULT_TAG).acquire(numberOfCredits);
    }

    /**
     * Tries to acquire the requested credits for the default tag within the timeout.
     *
     * @param numberOfCredits number of credits to acquire
     * @param timeout maximum time to wait
     * @param unit unit of {@code timeout}
     * @return {@code numberOfCredits} if acquired, otherwise 0
     * @throws IllegalStateException if {@link #init(Context)} has not been called
     * @throws IllegalArgumentException if the default tag is not supported
     */
    @Override
    public int acquire(int numberOfCredits, long timeout, TimeUnit unit) {
        ensureTagsAreValid(DEFAULT_TAG_MAP);
        return tagToRateLimiterMap.get(DEFAULT_TAG).tryAcquire(numberOfCredits, timeout, unit) ? numberOfCredits : 0;
    }

    /**
     * Creates this task's per-tag rate limiters. Each tag's job-wide rate is split evenly over the total number
     * of tasks across all containers in the job model.
     *
     * @param context task context; its task context is expected to be a {@link TaskContextImpl}
     * @throws IllegalStateException if the job model contains no tasks
     */
    @Override
    public void init(Context context) {
        // The job model, task count and task name are the same for every tag, so resolve them once.
        JobModel jobModel = ((TaskContextImpl) context.getTaskContext()).getJobModel();
        int numTasks = jobModel.getContainers().values().stream().mapToInt(cm -> cm.getTasks().size()).sum();
        Preconditions.checkState(numTasks > 0, "Job model contains no tasks; cannot distribute rate limits");
        TaskName taskName = context.getTaskContext().getTaskModel().getTaskName();

        this.tagToRateLimiterMap = Collections.unmodifiableMap(tagToTargetRateMap.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, e -> {
                    // Floating-point division: integer division truncates the per-task rate, and yields 0
                    // (which Guava rejects) whenever the job-wide rate is smaller than the number of tasks.
                    double effectiveRate = (double) e.getValue() / numTasks;
                    LOGGER.info("Effective rate limit for task {} and tag {} is {}", taskName, e.getKey(),
                            effectiveRate);
                    return com.google.common.util.concurrent.RateLimiter.create(effectiveRate);
                })));
        initialized = true;
    }

    /** Throws {@link IllegalStateException} if {@link #init(Context)} has not completed. */
    private void ensureInitialized() {
        Preconditions.checkState(initialized, "Not initialized");
    }

    /**
     * Checks that this limiter is initialized and that every key of {@code tagMap} is a supported tag.
     *
     * @throws IllegalStateException if not initialized
     * @throws IllegalArgumentException on the first unsupported tag
     */
    private void ensureTagsAreValid(Map<String, ?> tagMap) {
        ensureInitialized();
        // Template form: the message is only built when the check fails, not on every acquire call.
        tagMap.keySet().forEach(
                tag -> Preconditions.checkArgument(tagToRateLimiterMap.containsKey(tag), "Invalid tag: %s", tag));
    }

}