/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.matrix.data;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
import org.apache.sysds.utils.Hash;

/**
 * Static helpers that count, or estimate, the number of distinct values stored in a
 * {@link MatrixBlock}.
 *
 * <p>Two strategies are implemented: an exact count that collects every value into a hash set
 * ({@code COUNT}) and a k-minimum-values estimate that retains only the smallest hashed values
 * ({@code KMV}).
 */
public class LibMatrixCountDistinct {
    private static final Log LOG = LogFactory.getLog(LibMatrixCountDistinct.class.getName());

    /**
     * Inputs whose non-zero count is below this threshold are always counted exactly, regardless
     * of the requested operator type.
     */
    public static int minimumSize = 1024;

    /** Upper bound on the number of hashes retained by the KMV estimator. */
    private static final int KMV_MAX_K = 64;

    /** Largest value whose square still fits into a signed 64-bit integer. */
    private static final long MAX_SQUARABLE = 3037000499L;

    private LibMatrixCountDistinct() {
        // Utility class; not meant to be instantiated.
    }

    /**
     * Counts or estimates the number of distinct values in {@code in}, depending on the operator.
     *
     * @param in the matrix to inspect
     * @param op the operator selecting the counting strategy and hash function
     * @return the (estimated) number of distinct values; {@code 1} for a single-cell or empty input
     * @throws DMLException if KMV is combined with the {@code ExpHash} or {@code StandardJava}
     *         hash type, or if the operator type is not handled
     * @throws NotImplementedException if the HLL operator type is requested
     * @throws DMLRuntimeException if the computed result is {@code 0}
     */
    public static int estimateDistinctValues(MatrixBlock in, CountDistinctOperator op) {
        int res = 0;
        if (op.operatorType == CountDistinctOperator.CountDistinctTypes.KMV
                && (op.hashType == Hash.HashType.ExpHash || op.hashType == Hash.HashType.StandardJava)) {
            throw new DMLException("Invalid hashing configuration using " + op.hashType + " and " + op.operatorType);
        }
        if (op.operatorType == CountDistinctOperator.CountDistinctTypes.HLL) {
            throw new NotImplementedException("HyperLogLog not implemented");
        }
        // Shortcut: a single cell or an empty block is reported as one distinct value.
        if (in.getLength() == 1L || in.isEmpty()) {
            return 1;
        }
        if (in.getNonZeros() < (long) minimumSize) {
            // Small inputs are cheap enough to count exactly.
            res = countDistinctValuesNaive(in);
        } else {
            switch (op.operatorType) {
                case COUNT: {
                    res = countDistinctValuesNaive(in);
                    break;
                }
                case KMV: {
                    res = countDistinctValuesKVM(in, op);
                    break;
                }
                default: {
                    throw new DMLException("Invalid or not implemented Estimator Type");
                }
            }
        }
        if (res == 0) {
            throw new DMLRuntimeException("Impossible estimate of distinct values");
        }
        return res;
    }

    /**
     * Exact distinct count: every stored value is inserted into a hash set.
     *
     * <p>Compressed inputs are decompressed first when they are overlapping; otherwise the values
     * of their column groups are collected directly.
     *
     * <p>NOTE(review): the complete arrays returned by {@code values(..)} / {@code valuesAt(..)}
     * are scanned. If those arrays can carry unused trailing capacity, such slots would be
     * counted as well - confirm against the block implementations.
     *
     * @param in the matrix to inspect
     * @return the number of distinct values collected
     */
    private static int countDistinctValuesNaive(MatrixBlock in) {
        final Set<Double> distinct = new HashSet<Double>();
        MatrixBlock block = in;
        if (block instanceof CompressedMatrixBlock) {
            final CompressedMatrixBlock compressed = (CompressedMatrixBlock) block;
            if (compressed.isOverlapping()) {
                block = compressed.decompress();
            } else {
                for (AColGroup cg : compressed.getColGroups()) {
                    countDistinctValuesNaive(cg.getValues(), distinct);
                }
            }
        }

        // Widen before multiplying: rows * cols can exceed the int range.
        final long nonZeros = block.getNonZeros();
        final long cells = (long) block.getNumColumns() * block.getNumRows();
        if (nonZeros != -1L && nonZeros < cells) {
            // Fewer non-zeros than cells means zero occurs at least once.
            distinct.add(0.0);
        }

        if (block.sparseBlock != null) {
            final SparseBlock sb = block.sparseBlock;
            if (sb.isContiguous()) {
                countDistinctValuesNaive(sb.values(0), distinct);
            } else {
                final int rows = block.getNumRows();
                for (int r = 0; r < rows; r++) {
                    if (!sb.isEmpty(r)) {
                        countDistinctValuesNaive(sb.values(r), distinct);
                    }
                }
            }
        } else if (block.denseBlock != null) {
            final DenseBlock db = block.denseBlock;
            // Valid block indexes are 0 .. numBlocks()-1; the previous inclusive bound asked for
            // one block too many.
            final int numBlocks = db.numBlocks();
            for (int b = 0; b < numBlocks; b++) {
                countDistinctValuesNaive(db.valuesAt(b), distinct);
            }
        }
        return distinct.size();
    }

    /**
     * Adds every element of {@code valuesPart} to {@code distinct}.
     *
     * @param valuesPart the values to add
     * @param distinct the set that accumulates the distinct values
     * @return the same {@code distinct} instance
     */
    private static Set<Double> countDistinctValuesNaive(double[] valuesPart, Set<Double> distinct) {
        for (double v : valuesPart) {
            distinct.add(v);
        }
        return distinct;
    }

    /**
     * KMV estimate: hashes every stored value into {@code [1, M-1]}, keeps the {@code k} smallest
     * distinct hashes and derives the estimate {@code (k - 1) / (kthSmallest / M)}, capped at
     * {@code nonZeros + 1}.
     *
     * @param in the matrix to inspect
     * @param op the operator supplying the hash type
     * @return the number of distinct hashes if fewer than {@code k} were seen, otherwise the
     *         capped estimate truncated to an int
     */
    private static int countDistinctValuesKVM(MatrixBlock in, CountDistinctOperator op) {
        final long maxDistinct = in.getNonZeros() + 1L;
        // Saturate instead of overflowing the long multiplication for very large counts.
        final long squared = maxDistinct > MAX_SQUARABLE ? Long.MAX_VALUE : maxDistinct * maxDistinct;
        final int m = squared > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) squared;
        final boolean debug = LOG.isDebugEnabled();
        if (debug) {
            LOG.debug("M not forced to int size: " + squared);
            LOG.debug("M: " + m);
        }

        final int k = maxDistinct > KMV_MAX_K ? KMV_MAX_K : (int) maxDistinct;
        final SmallestPriorityQueue spq = new SmallestPriorityQueue(k);
        countDistinctValuesKVM(in, op.hashType, k, spq, m);

        if (debug) {
            LOG.debug("M: " + m);
            if (spq.size() > 0) {
                // peek() yields the largest retained hash, i.e. the k-th smallest seen so far.
                LOG.debug("smallest hash:" + spq.peek());
            }
            LOG.debug("spq: " + spq.toString());
        }

        if (spq.size() < k) {
            // Fewer than k distinct hashes were observed: report that number directly.
            return spq.size();
        }

        final double kthNormalized = (double) spq.poll() / (double) m;
        final double estimate = (double) (k - 1) / kthNormalized;
        final double capped = Math.min(estimate, (double) maxDistinct);
        if (debug) {
            LOG.debug("U_k : " + kthNormalized);
            LOG.debug("Estimate: " + estimate);
            LOG.debug("Ceil worst case: " + maxDistinct);
        }
        return (int) capped;
    }

    /**
     * Feeds all values stored in {@code in} into the KMV queue.
     *
     * <p>A block with neither a sparse nor a dense representation is treated as a
     * {@link CompressedMatrixBlock} and its column-group values are hashed.
     *
     * <p>NOTE(review): unlike the exact count, overlapping compressed blocks are not decompressed
     * here - confirm whether that is intended.
     *
     * @param in the matrix to traverse
     * @param hashType the hash function to apply
     * @param k the number of hashes retained by {@code spq}
     * @param spq the queue collecting the smallest hashes
     * @param m the size of the hash range
     */
    private static void countDistinctValuesKVM(MatrixBlock in, Hash.HashType hashType, int k, SmallestPriorityQueue spq, int m) {
        if (in.sparseBlock == null && in.denseBlock == null) {
            final List<AColGroup> colGroups = ((CompressedMatrixBlock) in).getColGroups();
            for (AColGroup cg : colGroups) {
                countDistinctValuesKVM(cg.getValues(), hashType, k, spq, m);
            }
            return;
        }

        if (in.sparseBlock != null) {
            final SparseBlock sb = in.sparseBlock;
            if (sb.isContiguous()) {
                countDistinctValuesKVM(sb.values(0), hashType, k, spq, m);
            } else {
                final int rows = in.getNumRows();
                for (int r = 0; r < rows; r++) {
                    if (!sb.isEmpty(r)) {
                        countDistinctValuesKVM(sb.values(r), hashType, k, spq, m);
                    }
                }
            }
            return;
        }

        final DenseBlock db = in.denseBlock;
        final int firstBlock = db.index(0);
        // The last valid row is rlen - 1; asking for the block of row rlen can point past the
        // final block.
        final int lastBlock = db.index(Math.max(in.rlen - 1, 0));
        for (int b = firstBlock; b <= lastBlock; b++) {
            countDistinctValuesKVM(db.valuesAt(b), hashType, k, spq, m);
        }
    }

    /**
     * Hashes each value of {@code data} into {@code [1, m-1]} and offers it to {@code spq}.
     *
     * @param data the values to hash
     * @param hashType the hash function to apply
     * @param k the number of hashes retained by {@code spq} (unused here)
     * @param spq the queue collecting the smallest hashes
     * @param m the size of the hash range
     */
    private static void countDistinctValuesKVM(double[] data, Hash.HashType hashType, int k, SmallestPriorityQueue spq, int m) {
        // m - 1 is 0 for m == 1; use a modulus of 1 instead of dividing by zero. For m <= 0 this
        // maps everything to 1, exactly as "% -1" did before.
        final int range = Math.max(m - 1, 1);
        for (double fullValue : data) {
            final int hash = Hash.hash(fullValue, hashType);
            // Math.abs(Integer.MIN_VALUE) is still negative, which would yield a non-positive
            // bucket; map that single value to Integer.MAX_VALUE.
            final int nonNegative = hash == Integer.MIN_VALUE ? Integer.MAX_VALUE : Math.abs(hash);
            spq.add(nonNegative % range + 1);
        }
    }

    /**
     * Bounded collection that retains the {@code k} smallest distinct integers offered to it.
     *
     * <p>Internally a max-ordered priority queue, so {@link #peek()} and {@link #poll()} return
     * the largest retained value, plus a set used to reject duplicates.
     */
    private static class SmallestPriorityQueue {
        private final Set<Integer> containedSet;
        private final PriorityQueue<Integer> smallestHashes;
        private final int k;

        /**
         * @param k the maximum number of values to retain
         */
        public SmallestPriorityQueue(int k) {
            this.smallestHashes = new PriorityQueue<Integer>(k, Collections.<Integer>reverseOrder());
            this.containedSet = new HashSet<Integer>();
            this.k = k;
        }

        /**
         * Offers {@code v}; it is kept if it is not yet contained and either fewer than
         * {@code k} values are held or it is smaller than the current largest retained value
         * (which is then evicted).
         *
         * @param v the value to offer
         */
        public void add(int v) {
            if (this.containedSet.contains(v)) {
                return;
            }
            if (this.smallestHashes.size() < this.k) {
                this.smallestHashes.add(v);
                this.containedSet.add(v);
                return;
            }
            final Integer largest = this.smallestHashes.peek();
            if (v < largest) {
                if (LOG.isTraceEnabled()) {
                    LOG.trace(largest + " -- " + v);
                }
                this.smallestHashes.add(v);
                this.containedSet.add(v);
                // Evict the previous maximum so that exactly k values remain.
                this.containedSet.remove(this.smallestHashes.poll());
            }
        }

        /** @return the number of values currently retained */
        public int size() {
            return this.smallestHashes.size();
        }

        /**
         * @return the largest retained value, without removing it
         * @throws NullPointerException if the queue is empty
         */
        public int peek() {
            return this.smallestHashes.peek();
        }

        /**
         * @return the largest retained value, removing it from the queue (it is not removed from
         *         the duplicate-tracking set)
         * @throws NullPointerException if the queue is empty
         */
        public int poll() {
            return this.smallestHashes.poll();
        }

        @Override
        public String toString() {
            return this.smallestHashes.toString();
        }
    }
}

