/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.sidecar.cluster.driver;

import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.Host;
import com.datastax.driver.core.HostDistance;
import com.datastax.driver.core.Statement;
import com.datastax.driver.core.policies.DCAwareRoundRobinPolicy;
import com.datastax.driver.core.policies.LoadBalancingPolicy;
import com.datastax.driver.core.policies.RoundRobinPolicy;
import com.google.common.collect.Iterators;
import java.net.InetSocketAddress;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.apache.cassandra.sidecar.common.server.utils.DriverUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SidecarLoadBalancingPolicy
implements LoadBalancingPolicy {
    public static final int MIN_NON_LOCAL_CONNECTIONS = 2;
    private static final Logger LOGGER = LoggerFactory.getLogger(SidecarLoadBalancingPolicy.class);
    private final Set<Host> selectedHosts = new HashSet<Host>();
    private final Set<InetSocketAddress> localHostAddresses;
    private final DriverUtils driverUtils;
    private final LoadBalancingPolicy childPolicy;
    private final int totalRequestedConnections;
    private final Random random = new Random();
    private final HashSet<Host> allHosts = new HashSet();
    private Cluster cluster;

    public SidecarLoadBalancingPolicy(List<InetSocketAddress> localHostAddresses, String localDc, int numAdditionalConnections, DriverUtils driverUtils) {
        this.childPolicy = this.createChildPolicy(localDc);
        this.localHostAddresses = new HashSet<InetSocketAddress>(localHostAddresses);
        this.driverUtils = driverUtils;
        if (numAdditionalConnections < 2) {
            LOGGER.warn("Additional instances requested was {}, which is less than the minimum of {}. Using {}.", new Object[]{numAdditionalConnections, 2, 2});
            numAdditionalConnections = 2;
        }
        this.totalRequestedConnections = this.localHostAddresses.size() + numAdditionalConnections;
    }

    public void init(Cluster cluster, Collection<Host> hosts) {
        this.cluster = cluster;
        this.allHosts.addAll(hosts);
        this.recalculateSelectedHosts();
        this.childPolicy.init(cluster, hosts);
    }

    public HostDistance distance(Host host) {
        if (this.selectedHosts.contains(host) || this.isLocalHost(host)) {
            return this.childPolicy.distance(host);
        }
        return HostDistance.IGNORED;
    }

    public Iterator<Host> newQueryPlan(String loggedKeyspace, Statement statement) {
        Iterator child = this.childPolicy.newQueryPlan(loggedKeyspace, statement);
        return Iterators.filter((Iterator)child, this.selectedHosts::contains);
    }

    public synchronized void onAdd(Host host) {
        this.onUp(host);
        this.childPolicy.onAdd(host);
    }

    public synchronized void onUp(Host host) {
        this.allHosts.add(host);
        if (this.selectedHosts.size() < this.totalRequestedConnections) {
            this.recalculateSelectedHosts();
        }
        this.childPolicy.onUp(host);
    }

    public synchronized void onDown(Host host) {
        if (this.localHostAddresses.contains(this.driverUtils.getSocketAddress(host))) {
            LOGGER.debug("Local Node {} has been marked down.", (Object)host);
            return;
        }
        boolean wasSelected = this.selectedHosts.remove(host);
        if (!wasSelected) {
            this.driverUtils.startPeriodicReconnectionAttempt(this.cluster, host);
        }
        this.recalculateSelectedHosts();
        this.childPolicy.onDown(host);
    }

    public synchronized void onRemove(Host host) {
        this.allHosts.remove(host);
        this.onDown(host);
        this.childPolicy.onRemove(host);
    }

    public void close() {
        this.childPolicy.close();
    }

    private LoadBalancingPolicy createChildPolicy(String localDc) {
        if (localDc != null) {
            return DCAwareRoundRobinPolicy.builder().withLocalDc(localDc).build();
        }
        return new RoundRobinPolicy();
    }

    private synchronized void recalculateSelectedHosts() {
        Map<Boolean, List<Host>> partitionedHosts = this.allHosts.stream().collect(Collectors.partitioningBy(this::isLocalHost));
        List<Host> localHosts = partitionedHosts.get(true);
        int numLocalHostsConfigured = this.localHostAddresses.size();
        if (localHosts == null || localHosts.isEmpty()) {
            LOGGER.warn("Did not find any local hosts in allHosts.");
        } else {
            if (localHosts.size() < numLocalHostsConfigured) {
                LOGGER.warn("Could not find all configured local hosts in host list. ConfiguredHosts={} AvailableHosts={}", (Object)numLocalHostsConfigured, (Object)localHosts.size());
            }
            this.selectedHosts.addAll(localHosts);
        }
        int requiredNonLocalHosts = this.totalRequestedConnections - this.selectedHosts.size();
        if (requiredNonLocalHosts > 0) {
            List<Object> nonLocalHosts = partitionedHosts.get(false);
            if (nonLocalHosts == null || nonLocalHosts.isEmpty()) {
                LOGGER.debug("Did not find any non-local hosts in allHosts");
                return;
            }
            if ((nonLocalHosts = nonLocalHosts.stream().filter(h -> !this.selectedHosts.contains(h) && h.isUp()).collect(Collectors.toList())).size() < requiredNonLocalHosts) {
                LOGGER.warn("Could not find enough new, up non-local hosts to meet requested number {}", (Object)requiredNonLocalHosts);
            } else {
                LOGGER.debug("Found enough new, up, non-local hosts to meet requested number {}", (Object)requiredNonLocalHosts);
            }
            if (nonLocalHosts.size() > requiredNonLocalHosts) {
                Collections.shuffle(nonLocalHosts, this.random);
            }
            int hostsToAdd = Math.min(requiredNonLocalHosts, nonLocalHosts.size());
            for (int i = 0; i < hostsToAdd; ++i) {
                this.selectedHosts.add((Host)nonLocalHosts.get(i));
            }
        }
    }

    private boolean isLocalHost(Host host) {
        return this.localHostAddresses.contains(this.driverUtils.getSocketAddress(host));
    }
}

