/*
 * Decompiled with CFR 0.152.
 */
package org.ldaptive;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ThreadLocalRandom;
import javax.naming.Context;
import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.Attribute;
import javax.naming.directory.Attributes;
import javax.naming.directory.InitialDirContext;
import org.ldaptive.ConnectionFactoryMetadata;
import org.ldaptive.ConnectionStrategy;
import org.ldaptive.LdapUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Connection strategy that discovers LDAP servers through DNS SRV records.
 *
 * <p>The value returned by {@link ConnectionFactoryMetadata#getLdapUrl()} is used as the DNS name queried for
 * {@code SRV} records (NOTE(review): presumably something like {@code _ldap._tcp.example.com}; confirm with
 * callers). Retrieved records are ordered by ascending priority; within a priority, records with a non-zero
 * weight are ordered by weighted random selection and zero-weight records follow them. The ordered list is
 * cached until its time-to-live (milliseconds) elapses.</p>
 *
 * <p>NOTE(review): the record cache is a plain, unsynchronized field; confirm whether instances are shared
 * between threads.</p>
 */
public class DnsSrvConnectionStrategy
implements ConnectionStrategy {
    /** JNDI initial context factory used when the supplied environment does not name one. */
    private static final String DNS_CONTEXT_FACTORY = "com.sun.jndi.dns.DnsContextFactory";
    /** JNDI provider URL used when the supplied environment does not name one. */
    private static final String DNS_PROVIDER_URL = "dns:";
    /** Default time-to-live for cached SRV records, in milliseconds (one hour). */
    private static final long DEFAULT_TTL = 3600000L;
    protected final Logger logger = LoggerFactory.getLogger(this.getClass());
    /** JNDI environment used for DNS lookups; a private copy, never null. */
    private Map<String, Object> jndiEnv = new HashMap<String, Object>();
    /** Time-to-live applied to newly retrieved SRV records, in milliseconds. */
    private long srvTtl;
    /** Sorted SRV records from the last successful lookup; null until the first successful lookup. */
    private List<SrvRecord> srvRecords;

    /** Creates a strategy with an empty JNDI environment and the default one hour time-to-live. */
    public DnsSrvConnectionStrategy() {
        this(null, DEFAULT_TTL);
    }

    /**
     * Creates a strategy with an empty JNDI environment.
     *
     * @param ttl time-to-live of cached SRV records, in milliseconds
     */
    public DnsSrvConnectionStrategy(long ttl) {
        this(null, ttl);
    }

    /**
     * Creates a strategy with the given JNDI environment and time-to-live.
     *
     * @param env JNDI environment for DNS lookups (copied); {@code null} keeps the empty default
     * @param ttl time-to-live of cached SRV records, in milliseconds
     */
    public DnsSrvConnectionStrategy(Map<String, Object> env, long ttl) {
        if (env != null) {
            this.setJndiEnvironment(env);
        }
        this.setTimeToLive(ttl);
    }

    /**
     * Returns the JNDI environment used for DNS lookups.
     *
     * @return the internal environment map (not a copy)
     */
    public Map<String, Object> getJndiEnvironment() {
        return this.jndiEnv;
    }

    /**
     * Returns the time-to-live applied to retrieved SRV records.
     *
     * @return time-to-live in milliseconds
     */
    public long getTimeToLive() {
        return this.srvTtl;
    }

    /**
     * Sets the JNDI environment used for DNS lookups. The map is copied.
     *
     * @param env JNDI environment; must not be null
     */
    public void setJndiEnvironment(Map<String, Object> env) {
        this.jndiEnv = new HashMap<String, Object>(env);
    }

    /**
     * Sets the time-to-live for SRV records. The value is applied at the next DNS lookup; records that are
     * already cached keep the expiration time they were created with.
     *
     * @param ttl time-to-live in milliseconds
     */
    public void setTimeToLive(long ttl) {
        this.srvTtl = ttl;
    }

    /**
     * Returns LDAP URLs for the servers advertised in DNS SRV records, in connection order.
     *
     * <p>SRV records are looked up using {@code metadata.getLdapUrl()} as the DNS name. DNS is queried when no
     * records are cached or when the cached records have expired. A lookup that yields no records is not
     * cached.</p>
     *
     * @param metadata connection factory metadata supplying the DNS name to query
     * @return LDAP URLs in connection order, or {@code null} if {@code metadata} or its LDAP URL is null
     * @throws IllegalArgumentException if the DNS query fails or returns no usable SRV records
     */
    @Override
    public String[] getLdapUrls(ConnectionFactoryMetadata metadata) {
        if (metadata == null || metadata.getLdapUrl() == null) {
            return null;
        }
        final String srvName = metadata.getLdapUrl();
        // Work on a local reference so the URLs below come from a single, consistent list.
        List<SrvRecord> records = this.srvRecords;
        // Every record from one lookup shares the same expiration time, so checking the first one is enough.
        if (records == null || records.isEmpty() || System.currentTimeMillis() >= records.get(0).getExpirationTime()) {
            try {
                records = this.sortSrvRecords(this.retrieveDNSRecords(srvName, this.jndiEnv, this.srvTtl));
            }
            catch (NamingException e) {
                throw new IllegalArgumentException("Could not retrieve DNS SRV record for " + srvName, e);
            }
            if (records.isEmpty()) {
                // Do not cache the empty result; the next call queries DNS again.
                throw new IllegalArgumentException("No DNS SRV records found for " + srvName);
            }
            this.logger.debug("Retrieved SRV records from DNS: {}", records);
            this.srvRecords = records;
        } else {
            this.logger.debug("Using SRV records from internal cache: {}", records);
        }
        final String[] urls = new String[records.size()];
        for (int i = 0; i < urls.length; ++i) {
            urls[i] = records.get(i).getLdapURL();
        }
        return urls;
    }

    /**
     * Queries DNS for the SRV records of {@code name}.
     *
     * <p>If {@code props} does not define them, the JNDI DNS context factory and the {@code dns:} provider
     * URL are added to a copy of the environment. Records whose target is "." (service explicitly not
     * available, per RFC 2782) are skipped.</p>
     *
     * @param name DNS name to query
     * @param props JNDI environment (copied, not modified)
     * @param ttl time-to-live in milliseconds, used to compute each record's expiration time
     * @return the records found, in the order DNS returned them; empty if there are none
     * @throws NamingException if the JNDI lookup fails
     */
    protected List<SrvRecord> retrieveDNSRecords(String name, Map<String, Object> props, long ttl) throws NamingException {
        final List<SrvRecord> records = new ArrayList<SrvRecord>();
        final Hashtable<String, Object> env = new Hashtable<String, Object>(props);
        if (!env.containsKey(Context.INITIAL_CONTEXT_FACTORY)) {
            env.put(Context.INITIAL_CONTEXT_FACTORY, DNS_CONTEXT_FACTORY);
        }
        if (!env.containsKey(Context.PROVIDER_URL)) {
            env.put(Context.PROVIDER_URL, DNS_PROVIDER_URL);
        }
        final InitialDirContext context = new InitialDirContext(env);
        try {
            final Attributes attrs = context.getAttributes(name, new String[]{"SRV"});
            final Attribute attr = attrs != null ? attrs.get("SRV") : null;
            if (attr != null) {
                final NamingEnumeration<?> en = attr.getAll();
                // Nested try/finally so a failure closing the enumeration cannot skip closing the context.
                try {
                    final long now = System.currentTimeMillis();
                    // Saturate instead of overflowing, so a huge TTL (e.g. Long.MAX_VALUE) means "never expire".
                    final long expTime = ttl > Long.MAX_VALUE - now ? Long.MAX_VALUE : now + ttl;
                    while (en.hasMore()) {
                        final SrvRecord record = new SrvRecord((String)en.next(), expTime);
                        // A target of "." is stripped to "" by SrvRecord and would produce "ldap://:port".
                        if (!record.getTarget().isEmpty()) {
                            records.add(record);
                        }
                    }
                }
                finally {
                    en.close();
                }
            }
        }
        finally {
            context.close();
        }
        return records;
    }

    /**
     * Orders SRV records for connection attempts.
     *
     * <p>Records are grouped by ascending priority value. Within a group, records with a non-zero weight are
     * chosen one at a time at random, each with probability proportional to its weight among the records not
     * yet chosen. Zero-weight records are appended after them in their original order.</p>
     *
     * @param records records to order; not modified
     * @return a new list containing the records in connection order
     */
    protected List<SrvRecord> sortSrvRecords(List<SrvRecord> records) {
        // TreeMap iterates its keys in ascending order, i.e. lowest priority value first.
        final Map<Long, List<SrvRecord>> priorityRecords = new TreeMap<Long, List<SrvRecord>>();
        for (SrvRecord record : records) {
            List<SrvRecord> group = priorityRecords.get(record.getPriority());
            if (group == null) {
                group = new ArrayList<SrvRecord>();
                priorityRecords.put(record.getPriority(), group);
            }
            group.add(record);
        }
        final List<SrvRecord> sortedRecords = new ArrayList<SrvRecord>(records.size());
        for (List<SrvRecord> group : priorityRecords.values()) {
            final List<SrvRecord> weighted = new ArrayList<SrvRecord>();
            final List<SrvRecord> unweighted = new ArrayList<SrvRecord>();
            long totalWeight = 0L;
            for (SrvRecord record : group) {
                if (record.getWeight() == 0L) {
                    unweighted.add(record);
                } else {
                    weighted.add(record);
                    totalWeight += record.getWeight();
                }
            }
            while (!weighted.isEmpty()) {
                // Draw from [1, totalWeight]. The running sum over the remaining records reaches totalWeight,
                // so exactly one record is always selected, with probability weight / totalWeight.
                final long random = ThreadLocalRandom.current().nextLong(1L, totalWeight + 1L);
                long runningSum = 0L;
                final Iterator<SrvRecord> it = weighted.iterator();
                while (it.hasNext()) {
                    final SrvRecord candidate = it.next();
                    runningSum += candidate.getWeight();
                    if (runningSum >= random) {
                        sortedRecords.add(candidate);
                        totalWeight -= candidate.getWeight();
                        it.remove();
                        break;
                    }
                }
            }
            sortedRecords.addAll(unweighted);
        }
        return sortedRecords;
    }

    @Override
    public String toString() {
        return String.format("[%s@%d::jndiEnv=%s, srvTtl=%s, srvRecords=%s]", this.getClass().getName(), this.hashCode(), this.jndiEnv, this.srvTtl, this.srvRecords);
    }

    /** One parsed DNS SRV record plus the time at which it should be treated as expired. */
    protected static class SrvRecord {
        private static final int HASH_CODE_SEED = 1201;
        /** Number of whitespace-separated fields in an SRV value: priority, weight, port, target. */
        private static final int FIELD_COUNT = 4;
        private final long priority;
        private final long weight;
        private final int port;
        private final String target;
        private final long expirationTime;

        /**
         * Parses an SRV value of the form {@code "priority weight port target"}. A single trailing dot on the
         * target is removed.
         *
         * @param record SRV value as returned by the JNDI DNS provider
         * @param time expiration time, in milliseconds since the epoch
         * @throws IllegalArgumentException if the record has fewer than four fields
         * @throws NumberFormatException if priority, weight or port is not numeric
         */
        public SrvRecord(String record, long time) {
            final String[] parts = record.trim().split("\\s+");
            if (parts.length < FIELD_COUNT) {
                throw new IllegalArgumentException("Invalid SRV record: " + record);
            }
            this.priority = Long.parseLong(parts[0]);
            this.weight = Long.parseLong(parts[1]);
            this.port = Integer.parseInt(parts[2]);
            final String host = parts[3];
            this.target = host.endsWith(".") ? host.substring(0, host.length() - 1) : host;
            this.expirationTime = time;
        }

        /** @return the record priority (lower values are tried first) */
        public long getPriority() {
            return this.priority;
        }

        /** @return the record weight used for random selection within a priority */
        public long getWeight() {
            return this.weight;
        }

        /** @return the service port */
        public int getPort() {
            return this.port;
        }

        /** @return the target host name without a trailing dot */
        public String getTarget() {
            return this.target;
        }

        /** @return an {@code ldap://target:port} URL for this record */
        public String getLdapURL() {
            return String.format("ldap://%s:%s", this.target, this.port);
        }

        /** @return the expiration time, in milliseconds since the epoch */
        public long getExpirationTime() {
            return this.expirationTime;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (o instanceof SrvRecord) {
                SrvRecord v = (SrvRecord)o;
                return LdapUtils.areEqual(this.priority, v.priority) && LdapUtils.areEqual(this.weight, v.weight) && LdapUtils.areEqual(this.port, v.port) && LdapUtils.areEqual(this.target, v.target) && LdapUtils.areEqual(this.expirationTime, v.expirationTime);
            }
            return false;
        }

        @Override
        public int hashCode() {
            return LdapUtils.computeHashCode(HASH_CODE_SEED, this.priority, this.weight, this.port, this.target, this.expirationTime);
        }

        @Override
        public String toString() {
            return String.format("[%s@%d::priority=%s, weight=%s, port=%s, target=%s, expirationTime=%s]", this.getClass().getName(), this.hashCode(), this.priority, this.weight, this.port, this.target, this.expirationTime);
        }
    }
}

