org.killbill.billing.util.security.shiro.realm.KillBillJndiLdapRealm.java Source code

Java tutorial

Introduction

Here is the source code for org.killbill.billing.util.security.shiro.realm.KillBillJndiLdapRealm.java

Source

/*
 * Copyright 2010-2013 Ning, Inc.
 *
 * Ning licenses this file to you under the Apache License, version 2.0
 * (the "License"); you may not use this file except in compliance with the
 * License.  You may obtain a copy of the License at:
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package org.killbill.billing.util.security.shiro.realm;

import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.Attribute;
import javax.naming.directory.SearchControls;
import javax.naming.directory.SearchResult;
import javax.naming.ldap.LdapContext;

import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.config.Ini;
import org.apache.shiro.config.Ini.Section;
import org.apache.shiro.realm.ldap.JndiLdapContextFactory;
import org.apache.shiro.realm.ldap.JndiLdapRealm;
import org.apache.shiro.realm.ldap.LdapContextFactory;
import org.apache.shiro.realm.ldap.LdapUtils;
import org.apache.shiro.subject.PrincipalCollection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.killbill.billing.util.config.definition.SecurityConfig;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterators;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.inject.Inject;

public class KillBillJndiLdapRealm extends JndiLdapRealm {

    private static final Logger log = LoggerFactory.getLogger(KillBillJndiLdapRealm.class);

    private static final String USERDN_SUBSTITUTION_TOKEN = "{0}";

    private static final SearchControls SUBTREE_SCOPE = new SearchControls();

    static {
        SUBTREE_SCOPE.setSearchScope(SearchControls.SUBTREE_SCOPE);
    }

    private static final Splitter SPLITTER = Splitter.on(',').omitEmptyStrings().trimResults();

    private final String searchBase;
    private final String groupSearchFilter;
    private final String groupNameId;
    private final Map<String, Collection<String>> permissionsByGroup = Maps.newLinkedHashMap();

    @Inject
    public KillBillJndiLdapRealm(final SecurityConfig securityConfig) {
        super();

        if (securityConfig.getShiroLDAPUserDnTemplate() != null) {
            setUserDnTemplate(securityConfig.getShiroLDAPUserDnTemplate());
        }

        final JndiLdapContextFactory contextFactory = (JndiLdapContextFactory) getContextFactory();
        if (securityConfig.disableShiroLDAPSSLCheck()) {
            contextFactory.getEnvironment().put("java.naming.ldap.factory.socket",
                    SkipSSLCheckSocketFactory.class.getName());
        }
        if (securityConfig.getShiroLDAPUrl() != null) {
            contextFactory.setUrl(securityConfig.getShiroLDAPUrl());
        }
        if (securityConfig.getShiroLDAPSystemUsername() != null) {
            contextFactory.setSystemUsername(securityConfig.getShiroLDAPSystemUsername());
        }
        if (securityConfig.getShiroLDAPSystemPassword() != null) {
            contextFactory.setSystemPassword(securityConfig.getShiroLDAPSystemPassword());
        }
        if (securityConfig.getShiroLDAPAuthenticationMechanism() != null) {
            contextFactory.setAuthenticationMechanism(securityConfig.getShiroLDAPAuthenticationMechanism());
        }
        setContextFactory(contextFactory);

        searchBase = securityConfig.getShiroLDAPSearchBase();
        groupSearchFilter = securityConfig.getShiroLDAPGroupSearchFilter();
        groupNameId = securityConfig.getShiroLDAPGroupNameID();

        if (securityConfig.getShiroLDAPPermissionsByGroup() != null) {
            final Ini ini = new Ini();
            // When passing properties on the command line, \n can be escaped
            ini.load(securityConfig.getShiroLDAPPermissionsByGroup().replace("\\n", "\n"));
            for (final Section section : ini.getSections()) {
                for (final String role : section.keySet()) {
                    final Collection<String> permissions = ImmutableList
                            .<String>copyOf(SPLITTER.split(section.get(role)));
                    permissionsByGroup.put(role, permissions);
                }
            }
        }
    }

    @Override
    protected AuthorizationInfo queryForAuthorizationInfo(final PrincipalCollection principals,
            final LdapContextFactory ldapContextFactory) throws NamingException {
        final Set<String> userGroups = findLDAPGroupsForUser(principals, ldapContextFactory);

        final SimpleAuthorizationInfo simpleAuthorizationInfo = new SimpleAuthorizationInfo(userGroups);
        final Set<String> stringPermissions = groupsPermissions(userGroups);
        simpleAuthorizationInfo.setStringPermissions(stringPermissions);

        return simpleAuthorizationInfo;
    }

    private Set<String> findLDAPGroupsForUser(final PrincipalCollection principals,
            final LdapContextFactory ldapContextFactory) throws NamingException {
        final String username = (String) getAvailablePrincipal(principals);

        LdapContext systemLdapCtx = null;
        try {
            systemLdapCtx = ldapContextFactory.getSystemLdapContext();
            return findLDAPGroupsForUser(username, systemLdapCtx);
        } catch (AuthenticationException ex) {
            log.info("LDAP authentication exception='{}'", ex.getLocalizedMessage());
            return ImmutableSet.<String>of();
        } finally {
            LdapUtils.closeContext(systemLdapCtx);
        }
    }

    private Set<String> findLDAPGroupsForUser(final String userName, final LdapContext ldapCtx)
            throws NamingException {
        final NamingEnumeration<SearchResult> foundGroups = ldapCtx.search(searchBase,
                groupSearchFilter.replace(USERDN_SUBSTITUTION_TOKEN, userName), SUBTREE_SCOPE);

        // Extract the name of all the groups
        final Iterator<SearchResult> groupsIterator = Iterators.<SearchResult>forEnumeration(foundGroups);
        final Iterator<String> groupsNameIterator = Iterators.<SearchResult, String>transform(groupsIterator,
                new Function<SearchResult, String>() {
                    @Override
                    public String apply(final SearchResult groupEntry) {
                        return extractGroupNameFromSearchResult(groupEntry);
                    }
                });
        final Iterator<String> finalGroupsNameIterator = Iterators.<String>filter(groupsNameIterator,
                Predicates.notNull());

        return Sets.newHashSet(finalGroupsNameIterator);
    }

    private String extractGroupNameFromSearchResult(final SearchResult searchResult) {
        // Get all attributes for that group
        final Iterator<? extends Attribute> attributesIterator = Iterators
                .forEnumeration(searchResult.getAttributes().getAll());

        // Find the attribute representing the group name
        final Iterator<? extends Attribute> groupNameAttributesIterator = Iterators.filter(attributesIterator,
                new Predicate<Attribute>() {
                    @Override
                    public boolean apply(final Attribute attribute) {
                        return groupNameId.equalsIgnoreCase(attribute.getID());
                    }
                });

        // Extract the group name from the attribute
        // Note: at this point, groupNameAttributesIterator should really contain a single element
        final Iterator<String> groupNamesIterator = Iterators.transform(groupNameAttributesIterator,
                new Function<Attribute, String>() {
                    @Override
                    public String apply(final Attribute groupNameAttribute) {
                        try {
                            final NamingEnumeration<?> enumeration = groupNameAttribute.getAll();
                            if (enumeration.hasMore()) {
                                return enumeration.next().toString();
                            } else {
                                return null;
                            }
                        } catch (NamingException namingException) {
                            log.warn("Unable to read group name", namingException);
                            return null;
                        }
                    }
                });
        final Iterator<String> finalGroupNamesIterator = Iterators.<String>filter(groupNamesIterator,
                Predicates.notNull());

        if (finalGroupNamesIterator.hasNext()) {
            return finalGroupNamesIterator.next();
        } else {
            log.warn("Unable to find an attribute matching {}", groupNameId);
            return null;
        }
    }

    private Set<String> groupsPermissions(final Set<String> groups) {
        final Set<String> permissions = new HashSet<String>();
        for (final String group : groups) {
            final Collection<String> permissionsForGroup = permissionsByGroup.get(group);
            if (permissionsForGroup != null) {
                permissions.addAll(permissionsForGroup);
            }
        }
        return permissions;
    }

    @VisibleForTesting
    public Map<String, Collection<String>> getPermissionsByGroup() {
        return permissionsByGroup;
    }
}