/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.util.infotheory;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.tribuo.util.infotheory.impl.CachedPair;
import org.tribuo.util.infotheory.impl.CachedTriple;
import org.tribuo.util.infotheory.impl.PairDistribution;
import org.tribuo.util.infotheory.impl.RowList;
import org.tribuo.util.infotheory.impl.TripleDistribution;

public final class InformationTheory {
    /** Logger used for the low-sample INFO notices and the NaN diagnostics emitted by this class. */
    private static final Logger logger = Logger.getLogger(InformationTheory.class.getName());
    /**
     * Samples-per-state ratio threshold. The estimators in this class compare their
     * observations/states ratio against the literal 5.0 (the same value) and log at INFO when below it.
     */
    public static final double SAMPLES_RATIO = 5.0;
    /** Default map capacity; {@code calculateCountDist} sizes its HashMap with the same value (20). */
    public static final int DEFAULT_MAP_SIZE = 20;
    /** Natural log of 2; dividing a natural-log quantity by this converts it to base 2. */
    public static final double LOG_2 = Math.log(2.0);
    /** Natural log of e (i.e. 1.0); dividing by this leaves a natural-log quantity unchanged. */
    public static final double LOG_E = Math.log(Math.E);
    /**
     * Divisor applied to every natural-log sum computed in this class. Defaults to {@link #LOG_2}.
     * The field is public and non-final, so reassigning it changes the base for all later calls.
     */
    public static double LOG_BASE = LOG_2;

    /** Private constructor: the class exposes only static members and is never instantiated. */
    private InformationTheory() {
    }

    /**
     * Calculates the mutual information between two sets of variables.
     * <p>
     * Each set is wrapped in a {@link RowList} and the pair is passed to
     * {@link #mi(List, List)}.
     *
     * @param first The first set of variables.
     * @param second The second set of variables.
     * @param <T1> The element type of the first set's lists.
     * @param <T2> The element type of the second set's lists.
     * @return The mutual information, in units determined by {@link #LOG_BASE}.
     */
    public static <T1, T2> double mi(Set<List<T1>> first, Set<List<T2>> second) {
        // Parameterised rather than raw so the call below is type checked.
        RowList<T1> firstRows = new RowList<>(first);
        RowList<T2> secondRows = new RowList<>(second);
        return mi(firstRows, secondRows);
    }

    /**
     * Calculates the mutual information between {@code first} and {@code second},
     * conditioned on a set of variables.
     * <p>
     * A {@code null} or empty conditioning set means "no conditioning" and falls back
     * to {@link #mi(List, List)}; this matches how {@code gTest} treats a null condition.
     *
     * @param first The first variable.
     * @param second The second variable.
     * @param condition The conditioning variables; may be null or empty.
     * @param <T1> The element type of the first variable.
     * @param <T2> The element type of the second variable.
     * @param <T3> The element type of the conditioning lists.
     * @return The (conditional) mutual information, in units determined by {@link #LOG_BASE}.
     */
    public static <T1, T2, T3> double cmi(List<T1> first, List<T2> second, Set<List<T3>> condition) {
        if (condition == null || condition.isEmpty()) {
            return mi(first, second);
        }
        RowList<T3> conditionRows = new RowList<>(condition);
        return conditionalMI(first, second, conditionRows);
    }

    /**
     * Computes a G-test style statistic for the dependence between {@code first} and
     * {@code second}, optionally conditioned on a set of variables.
     * <p>
     * The statistic is {@code 2 * second.size() * score}, where the score is the mutual
     * information (null or empty condition) or the conditional mutual information.
     * The reported probability is the chi-squared CDF at that statistic, with the number
     * of observed joint states used as the degrees of freedom.
     * <p>
     * NOTE(review): the score has already been divided by {@link #LOG_BASE}, so with the
     * default base-2 setting the statistic is scaled relative to the natural-log form
     * usually used for a G-test. Confirm whether a LOG_BASE correction is intended.
     *
     * @param first The first variable.
     * @param second The second variable.
     * @param condition The conditioning variables; may be null or empty.
     * @param <T1> The element type of the first variable.
     * @param <T2> The element type of the second variable.
     * @param <T3> The element type of the conditioning lists.
     * @return The statistic, the state count and the chi-squared cumulative probability.
     */
    public static <T1, T2, T3> GTestStatistics gTest(List<T1> first, List<T2> second, Set<List<T3>> condition) {
        ScoreStateCountTuple tuple;
        if (condition == null || condition.isEmpty()) {
            tuple = innerMI(first, second);
        } else {
            RowList<T3> conditionRows = new RowList<>(condition);
            tuple = innerConditionalMI(first, second, conditionRows);
        }
        // Multiply in double precision: the int product 2 * size() overflows above 2^30 elements.
        double gMetric = 2.0 * second.size() * tuple.score;
        ChiSquaredDistribution dist = new ChiSquaredDistribution((double) tuple.stateCount);
        double prob = dist.cumulativeProbability(gMetric);
        return new GTestStatistics(gMetric, tuple.stateCount, prob);
    }

    /**
     * Calculates the mutual information between the joint variable (first, second) and target.
     *
     * @param first The first variable of the joint pair.
     * @param second The second variable of the joint pair.
     * @param target The target variable.
     * @param <T1> The element type of the first variable.
     * @param <T2> The element type of the second variable.
     * @param <T3> The element type of the target variable.
     * @return The joint mutual information, in units determined by {@link #LOG_BASE}.
     * @throws IllegalArgumentException if the three lists differ in length.
     */
    public static <T1, T2, T3> double jointMI(List<T1> first, List<T2> second, List<T3> target) {
        boolean lengthsAgree = first.size() == second.size() && first.size() == target.size();
        if (!lengthsAgree) {
            throw new IllegalArgumentException("Joint Mutual Information requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", target.size() = " + target.size());
        }
        TripleDistribution<T1, T2, T3> distribution = TripleDistribution.constructFromLists(first, second, target);
        return jointMI(distribution);
    }

    /**
     * Calculates the mutual information between the joint variable (A, B) and C from a
     * pre-computed triple distribution.
     * <p>
     * Sums {@code p(a,b,c) * log(N * n(a,b,c) / (n(a,b) * n(c)))} over the observed joint
     * states and divides by {@link #LOG_BASE}. Logs at INFO when there are fewer than
     * 5 observations per observed state.
     *
     * @param rv The triple distribution supplying the counts.
     * @param <T1> The type of variable A.
     * @param <T2> The type of variable B.
     * @param <T3> The type of variable C.
     * @return The joint mutual information.
     */
    public static <T1, T2, T3> double jointMI(TripleDistribution<T1, T2, T3> rv) {
        final double total = rv.count;
        final Map<CachedTriple<T1, T2, T3>, MutableLong> jointCounts = rv.getJointCount();
        final Map<CachedPair<T1, T2>, MutableLong> abCounts = rv.getABCount();
        final Map<T3, MutableLong> cCounts = rv.getCCount();

        double sum = 0.0;
        for (Map.Entry<CachedTriple<T1, T2, T3>, MutableLong> entry : jointCounts.entrySet()) {
            CachedTriple<T1, T2, T3> state = entry.getKey();
            double joint = entry.getValue().doubleValue();
            double ab = abCounts.get(state.getAB()).doubleValue();
            double c = cCounts.get(state.getC()).doubleValue();
            double prob = joint / total;
            sum += prob * Math.log(total * joint / (ab * c));
        }

        final double jmi = sum / LOG_BASE;
        final int numStates = jointCounts.size();
        final double stateRatio = total / (double) numStates;
        // 5.0 is the same value as SAMPLES_RATIO.
        if (stateRatio < 5.0) {
            logger.log(Level.INFO, "Joint MI estimate of {0} had samples/state ratio of {1}, with {2} observations and {3} states", new Object[]{jmi, stateRatio, total, numStates});
        }
        return jmi;
    }

    /**
     * Computes conditional mutual information from a triple distribution and also
     * returns the number of observed joint states.
     * <p>
     * When {@code flipped} is false each term is
     * {@code p(a,b,c) * log(n(c) * n(a,b,c) / (n(a,c) * n(b,c)))}, i.e. conditioning on C.
     * When {@code flipped} is true each term is
     * {@code p(a,b,c) * log(n(b) * n(a,b,c) / (n(a,b) * n(b,c)))}, i.e. conditioning on B.
     * The sum is divided by {@link #LOG_BASE}. Logs at INFO when there are fewer than
     * 5 observations per observed state.
     *
     * @param rv The triple distribution supplying the counts.
     * @param flipped If true condition on B, otherwise condition on C.
     * @return The score paired with the number of observed joint states.
     */
    private static <T1, T2, T3> ScoreStateCountTuple innerConditionalMI(TripleDistribution<T1, T2, T3> rv, boolean flipped) {
        final Map<CachedTriple<T1, T2, T3>, MutableLong> jointCounts = rv.getJointCount();
        final Map<CachedPair<T1, T2>, MutableLong> abCounts = rv.getABCount();
        final Map<CachedPair<T1, T3>, MutableLong> acCounts = rv.getACCount();
        final Map<CachedPair<T2, T3>, MutableLong> bcCounts = rv.getBCCount();
        final Map<T2, MutableLong> bCounts = rv.getBCount();
        final Map<T3, MutableLong> cCounts = rv.getCCount();
        final double total = rv.count;

        double sum = 0.0;
        for (Map.Entry<CachedTriple<T1, T2, T3>, MutableLong> entry : jointCounts.entrySet()) {
            CachedTriple<T1, T2, T3> state = entry.getKey();
            double joint = entry.getValue().doubleValue();
            double prob = joint / total;

            double conditionCount;   // count of the conditioning variable's state
            double pairwiseProduct;  // product of the two pairwise counts that include it
            if (flipped) {
                double ab = abCounts.get(state.getAB()).doubleValue();
                double bc = bcCounts.get(state.getBC()).doubleValue();
                conditionCount = bCounts.get(state.getB()).doubleValue();
                pairwiseProduct = ab * bc;
            } else {
                double ac = acCounts.get(state.getAC()).doubleValue();
                double bc = bcCounts.get(state.getBC()).doubleValue();
                conditionCount = cCounts.get(state.getC()).doubleValue();
                pairwiseProduct = ac * bc;
            }
            sum += prob * Math.log(conditionCount * joint / pairwiseProduct);
        }

        final double cmi = sum / LOG_BASE;
        final int numStates = jointCounts.size();
        final double stateRatio = total / (double) numStates;
        // 5.0 is the same value as SAMPLES_RATIO.
        if (stateRatio < 5.0) {
            logger.log(Level.INFO, "Conditional MI estimate of {0} had samples/state ratio of {1}", new Object[]{cmi, stateRatio});
        }
        return new ScoreStateCountTuple(cmi, numStates);
    }

    /**
     * Builds a triple distribution from three equal-length lists and computes the
     * conditional mutual information of {@code first} and {@code second} given
     * {@code condition} (the non-flipped form).
     *
     * @param first The first variable.
     * @param second The second variable.
     * @param condition The conditioning variable.
     * @return The score paired with the number of observed joint states.
     * @throws IllegalArgumentException if the three lists differ in length.
     */
    private static <T1, T2, T3> ScoreStateCountTuple innerConditionalMI(List<T1> first, List<T2> second, List<T3> condition) {
        boolean lengthsAgree = first.size() == second.size() && first.size() == condition.size();
        if (!lengthsAgree) {
            throw new IllegalArgumentException("Conditional Mutual Information requires three vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size() + ", condition.size() = " + condition.size());
        }
        TripleDistribution<T1, T2, T3> distribution = TripleDistribution.constructFromLists(first, second, condition);
        return innerConditionalMI(distribution, false);
    }

    /**
     * Calculates the mutual information of {@code first} and {@code second} conditioned
     * on {@code condition}.
     *
     * @param first The first variable.
     * @param second The second variable.
     * @param condition The conditioning variable.
     * @param <T1> The element type of the first variable.
     * @param <T2> The element type of the second variable.
     * @param <T3> The element type of the conditioning variable.
     * @return The conditional mutual information, in units determined by {@link #LOG_BASE}.
     * @throws IllegalArgumentException if the three lists differ in length.
     */
    public static <T1, T2, T3> double conditionalMI(List<T1> first, List<T2> second, List<T3> condition) {
        ScoreStateCountTuple result = innerConditionalMI(first, second, condition);
        return result.score;
    }

    /**
     * Calculates the mutual information of A and B conditioned on C from a
     * pre-computed triple distribution.
     *
     * @param rv The triple distribution supplying the counts.
     * @param <T1> The type of variable A.
     * @param <T2> The type of variable B.
     * @param <T3> The type of variable C.
     * @return The conditional mutual information, in units determined by {@link #LOG_BASE}.
     */
    public static <T1, T2, T3> double conditionalMI(TripleDistribution<T1, T2, T3> rv) {
        ScoreStateCountTuple result = innerConditionalMI(rv, false);
        return result.score;
    }

    /**
     * Calculates the mutual information of A and C conditioned on B from a
     * pre-computed triple distribution (the "flipped" form of
     * {@code conditionalMI(TripleDistribution)}).
     *
     * @param rv The triple distribution supplying the counts.
     * @param <T1> The type of variable A.
     * @param <T2> The type of variable B.
     * @param <T3> The type of variable C.
     * @return The conditional mutual information, in units determined by {@link #LOG_BASE}.
     */
    public static <T1, T2, T3> double conditionalMIFlipped(TripleDistribution<T1, T2, T3> rv) {
        ScoreStateCountTuple result = innerConditionalMI(rv, true);
        return result.score;
    }

    /**
     * Computes mutual information from a pair distribution and also returns the number
     * of observed joint states.
     * <p>
     * Sums {@code p(a,b) * log(N * n(a,b) / (n(a) * n(b)))} over the observed joint states
     * and divides by {@link #LOG_BASE}. If any term involves a NaN the state and the
     * intermediate values are logged at WARNING and a SEVERE record is logged after the
     * loop; the (NaN) result is still returned. Logs at INFO when there are fewer than
     * 5 observations per observed state.
     *
     * @param pairDist The pair distribution supplying the counts.
     * @return The score paired with the number of observed joint states.
     */
    private static <T1, T2> ScoreStateCountTuple innerMI(PairDistribution<T1, T2> pairDist) {
        // Typed maps: with raw Map/Map.Entry the key and value accessors below do not type check.
        Map<CachedPair<T1, T2>, MutableLong> countDist = pairDist.jointCounts;
        Map<T1, MutableLong> firstCountDist = pairDist.firstCount;
        Map<T2, MutableLong> secondCountDist = pairDist.secondCount;
        double vectorLength = pairDist.count;
        double mi = 0.0;
        boolean error = false;
        for (Map.Entry<CachedPair<T1, T2>, MutableLong> e : countDist.entrySet()) {
            double jointCount = e.getValue().doubleValue();
            double prob = jointCount / vectorLength;
            double top = vectorLength * jointCount;
            double firstCount = firstCountDist.get(e.getKey().getA()).doubleValue();
            double secondCount = secondCountDist.get(e.getKey().getB()).doubleValue();
            double bottom = firstCount * secondCount;
            double ratio = top / bottom;
            double logRatio = Math.log(ratio);
            if (Double.isNaN(logRatio) || Double.isNaN(prob) || Double.isNaN(mi)) {
                logger.log(Level.WARNING, "State = " + e.getKey().toString());
                logger.log(Level.WARNING, "mi = " + mi + " prob = " + prob + " top = " + top + " bottom = " + bottom + " ratio = " + ratio + " logRatio = " + logRatio);
                error = true;
            }
            mi += prob * logRatio;
        }
        mi /= LOG_BASE;
        double stateRatio = vectorLength / (double) countDist.size();
        // 5.0 is the same value as SAMPLES_RATIO.
        if (stateRatio < 5.0) {
            logger.log(Level.INFO, "MI estimate of {0} had samples/state ratio of {1}", new Object[]{mi, stateRatio});
        }
        if (error) {
            // The exception is only attached to the log record to capture a stack trace; it is not thrown.
            logger.log(Level.SEVERE, "NanFound ", new IllegalStateException("NaN found"));
        }
        return new ScoreStateCountTuple(mi, countDist.size());
    }

    /**
     * Builds a pair distribution from two equal-length lists and computes their
     * mutual information together with the observed joint state count.
     *
     * @param first The first variable.
     * @param second The second variable.
     * @return The score paired with the number of observed joint states.
     * @throws IllegalArgumentException if the two lists differ in length.
     */
    private static <T1, T2> ScoreStateCountTuple innerMI(List<T1> first, List<T2> second) {
        if (first.size() != second.size()) {
            throw new IllegalArgumentException("Mutual Information requires two vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size());
        }
        PairDistribution<T1, T2> distribution = PairDistribution.constructFromLists(first, second);
        return innerMI(distribution);
    }

    /**
     * Calculates the mutual information between two equal-length lists.
     *
     * @param first The first variable.
     * @param second The second variable.
     * @param <T1> The element type of the first variable.
     * @param <T2> The element type of the second variable.
     * @return The mutual information, in units determined by {@link #LOG_BASE}.
     * @throws IllegalArgumentException if the two lists differ in length.
     */
    public static <T1, T2> double mi(List<T1> first, List<T2> second) {
        ScoreStateCountTuple result = innerMI(first, second);
        return result.score;
    }

    /**
     * Calculates the mutual information from a pre-computed pair distribution.
     *
     * @param pairDist The pair distribution supplying the counts.
     * @param <T1> The type of the first variable.
     * @param <T2> The type of the second variable.
     * @return The mutual information, in units determined by {@link #LOG_BASE}.
     */
    public static <T1, T2> double mi(PairDistribution<T1, T2> pairDist) {
        ScoreStateCountTuple result = innerMI(pairDist);
        return result.score;
    }

    /**
     * Calculates the Shannon joint entropy of two equal-length lists.
     * <p>
     * Sums {@code -p(a,b) * log(p(a,b))} over the observed joint states and divides by
     * {@link #LOG_BASE}. Logs at INFO when there are fewer than 5 observations per
     * observed state.
     *
     * @param first The first variable.
     * @param second The second variable.
     * @param <T1> The element type of the first variable.
     * @param <T2> The element type of the second variable.
     * @return The joint entropy.
     * @throws IllegalArgumentException if the two lists differ in length.
     */
    public static <T1, T2> double jointEntropy(List<T1> first, List<T2> second) {
        if (first.size() != second.size()) {
            throw new IllegalArgumentException("Joint Entropy requires two vectors the same length. first.size() = " + first.size() + ", second.size() = " + second.size());
        }
        double vectorLength = first.size();
        double jointEntropy = 0.0;
        PairDistribution<T1, T2> countPair = PairDistribution.constructFromLists(first, second);
        // Typed map: with a raw Map the entry values are Objects and doubleValue() does not resolve.
        Map<CachedPair<T1, T2>, MutableLong> countDist = countPair.jointCounts;
        for (Map.Entry<CachedPair<T1, T2>, MutableLong> e : countDist.entrySet()) {
            double prob = e.getValue().doubleValue() / vectorLength;
            jointEntropy -= prob * Math.log(prob);
        }
        jointEntropy /= LOG_BASE;
        double stateRatio = vectorLength / (double) countDist.size();
        // 5.0 is the same value as SAMPLES_RATIO.
        if (stateRatio < 5.0) {
            logger.log(Level.INFO, "Joint Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{jointEntropy, stateRatio});
        }
        return jointEntropy;
    }

    /**
     * Calculates the conditional entropy of {@code vector} given {@code condition}.
     * <p>
     * Sums {@code -p(v,c) * log(p(v,c) / p(c))} over the observed joint states and divides
     * by {@link #LOG_BASE}. Logs at INFO when there are fewer than 5 observations per
     * observed state.
     *
     * @param vector The variable whose entropy is measured.
     * @param condition The conditioning variable.
     * @param <T1> The element type of the measured variable.
     * @param <T2> The element type of the conditioning variable.
     * @return The conditional entropy.
     * @throws IllegalArgumentException if the two lists differ in length.
     */
    public static <T1, T2> double conditionalEntropy(List<T1> vector, List<T2> condition) {
        if (vector.size() != condition.size()) {
            throw new IllegalArgumentException("Conditional Entropy requires two vectors the same length. vector.size() = " + vector.size() + ", condition.size() = " + condition.size());
        }
        double vectorLength = vector.size();
        double condEntropy = 0.0;
        PairDistribution<T1, T2> countPair = PairDistribution.constructFromLists(vector, condition);
        // Typed maps: with raw Map/Map.Entry the key and value accessors below do not type check.
        Map<CachedPair<T1, T2>, MutableLong> countDist = countPair.jointCounts;
        Map<T2, MutableLong> conditionCountDist = countPair.secondCount;
        for (Map.Entry<CachedPair<T1, T2>, MutableLong> e : countDist.entrySet()) {
            double prob = e.getValue().doubleValue() / vectorLength;
            double condProb = conditionCountDist.get(e.getKey().getB()).doubleValue() / vectorLength;
            condEntropy -= prob * Math.log(prob / condProb);
        }
        condEntropy /= LOG_BASE;
        double stateRatio = vectorLength / (double) countDist.size();
        // 5.0 is the same value as SAMPLES_RATIO.
        if (stateRatio < 5.0) {
            logger.log(Level.INFO, "Conditional Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{condEntropy, stateRatio});
        }
        return condEntropy;
    }

    /**
     * Calculates the Shannon entropy of a list of observations.
     * <p>
     * Counts each distinct element, sums {@code -p * log(p)} over the observed values and
     * divides by {@link #LOG_BASE}. Logs at INFO when there are fewer than 5 observations
     * per distinct value.
     *
     * @param vector The observations.
     * @param <T> The element type.
     * @return The entropy.
     */
    public static <T> double entropy(List<T> vector) {
        final double total = vector.size();
        final Map<T, Long> counts = calculateCountDist(vector);

        double sum = 0.0;
        for (Long occurrences : counts.values()) {
            double prob = occurrences.doubleValue() / total;
            sum -= prob * Math.log(prob);
        }

        final double entropy = sum / LOG_BASE;
        final double stateRatio = total / (double) counts.size();
        // 5.0 is the same value as SAMPLES_RATIO.
        if (stateRatio < 5.0) {
            logger.log(Level.INFO, "Entropy estimate of {0} had samples/state ratio of {1}", new Object[]{entropy, stateRatio});
        }
        return entropy;
    }

    /**
     * Counts how many times each distinct element occurs in the list.
     *
     * @param vector The observations.
     * @param <T> The element type.
     * @return A new map from each distinct element to its occurrence count.
     */
    public static <T> Map<T, Long> calculateCountDist(List<T> vector) {
        // Initial capacity 20 is the same value as DEFAULT_MAP_SIZE.
        Map<T, Long> occurrences = new HashMap<>(20);
        for (T element : vector) {
            occurrences.merge(element, 1L, Long::sum);
        }
        return occurrences;
    }

    /**
     * Calculates the entropy of a stream of probabilities.
     * <p>
     * Each probability {@code p} contributes {@code -p * log(p) / LOG_BASE}. A probability
     * of exactly zero contributes 0 (the usual {@code 0 * log 0 = 0} convention); computing
     * it directly would give {@code 0 * -Infinity = NaN} and turn the whole sum into NaN.
     *
     * @param vector The probabilities; the stream is consumed by this call.
     * @return The entropy, in units determined by {@link #LOG_BASE}.
     */
    public static double calculateEntropy(Stream<Double> vector) {
        return vector
                .map(p -> p.doubleValue() == 0.0 ? 0.0 : -p.doubleValue() * Math.log(p) / LOG_BASE)
                .reduce(0.0, Double::sum);
    }

    /**
     * Calculates the entropy of a stream of probabilities.
     * <p>
     * Each probability {@code p} contributes {@code -p * log(p) / LOG_BASE}. A probability
     * of exactly zero contributes 0 (the usual {@code 0 * log 0 = 0} convention); computing
     * it directly would give {@code 0 * -Infinity = NaN} and turn the whole sum into NaN.
     *
     * @param vector The probabilities; the stream is consumed by this call.
     * @return The entropy, in units determined by {@link #LOG_BASE}.
     */
    public static double calculateEntropy(DoubleStream vector) {
        return vector
                .map(p -> p == 0.0 ? 0.0 : -p * Math.log(p) / LOG_BASE)
                .sum();
    }

    /**
     * Immutable holder for the result of a G-test: the statistic, the number of
     * states it was computed over, and the associated probability.
     */
    public static final class GTestStatistics {
        /** The G statistic. */
        public final double gStatistic;
        /** The number of states. */
        public final int numStates;
        /** The probability associated with the statistic. */
        public final double probability;

        /**
         * Stores the three values as given.
         *
         * @param gStatistic The G statistic.
         * @param numStates The number of states.
         * @param probability The probability associated with the statistic.
         */
        public GTestStatistics(double gStatistic, int numStates, double probability) {
            this.probability = probability;
            this.numStates = numStates;
            this.gStatistic = gStatistic;
        }

        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("GTest(statistic=");
            text.append(gStatistic);
            text.append(",probability=").append(probability);
            text.append(",numStates=").append(numStates);
            text.append(")");
            return text.toString();
        }
    }

    /**
     * Immutable pair of an information-theoretic score and the number of observed
     * states it was computed over.
     */
    private static class ScoreStateCountTuple {
        /** The computed score. */
        public final double score;
        /** The number of observed states. */
        public final int stateCount;

        /**
         * Stores the two values as given.
         *
         * @param score The computed score.
         * @param stateCount The number of observed states.
         */
        ScoreStateCountTuple(double score, int stateCount) {
            this.stateCount = stateCount;
            this.score = score;
        }

        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("ScoreStateCount(score=");
            text.append(score);
            text.append(",stateCount=").append(stateCount);
            text.append(")");
            return text.toString();
        }
    }
}

