/*
 * Decompiled with CFR 0.152.
 */
package uniol.apt.analysis.synthesize.separation;

import de.uni_freiburg.informatik.ultimate.logic.Logics;
import de.uni_freiburg.informatik.ultimate.logic.Script;
import de.uni_freiburg.informatik.ultimate.logic.Sort;
import de.uni_freiburg.informatik.ultimate.logic.Term;
import de.uni_freiburg.informatik.ultimate.logic.TermVariable;
import de.uni_freiburg.informatik.ultimate.smtinterpol.DefaultLogger;
import de.uni_freiburg.informatik.ultimate.smtinterpol.LogProxy;
import de.uni_freiburg.informatik.ultimate.smtinterpol.smtlib2.SMTInterpol;
import de.uni_freiburg.informatik.ultimate.smtinterpol.smtlib2.TerminationRequest;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.commons.collections4.collection.CompositeCollection;
import uniol.apt.adt.ts.Arc;
import uniol.apt.adt.ts.State;
import uniol.apt.adt.ts.TransitionSystem;
import uniol.apt.analysis.synthesize.PNProperties;
import uniol.apt.analysis.synthesize.RegionUtility;
import uniol.apt.analysis.synthesize.UnreachableException;
import uniol.apt.util.AbstractEquivalenceRelation;
import uniol.apt.util.DebugUtil;
import uniol.apt.util.DifferentPairsIterable;
import uniol.apt.util.DomainEquivalenceRelation;
import uniol.apt.util.IEquivalenceRelation;
import uniol.apt.util.Pair;
import uniol.apt.util.interrupt.InterrupterRegistry;

/**
 * Sets up an SMTInterpol {@link Script} (logic QF_LIA) in which a Boolean function {@code isRegion} is defined.
 * The body of {@code isRegion} is the conjunction of the region constraints derived from the transition system
 * of a {@link RegionUtility} and of the extra constraints required by the given {@link PNProperties}.
 */
public class SMTInterpolHelper {
    private final Script script;
    private final RegionUtility utility;
    private final PNProperties properties;
    private final String[] locationMap;

    /**
     * Creates the SMT script and defines {@code isRegion} in it.
     * <p>
     * The parameters of {@code isRegion} are the initial marking {@code m0}, followed by either
     * <ul>
     * <li>one effective weight {@code e-<event>} per event, if the net must be pure, or</li>
     * <li>one backward weight {@code b-<event>} per event followed by one forward weight
     * {@code f-<event>} per event, otherwise.</li>
     * </ul>
     * The weights that are not parameters are derived from the parameters with a {@code let}.
     *
     * @param utility the region utility describing the transition system and its events
     * @param properties the properties the synthesized net must have; output-nonbranching is not supported
     * @param locationMap the location of each event by event index ({@code null} entries have no location);
     *        the array is copied
     */
    public SMTInterpolHelper(RegionUtility utility, PNProperties properties, String[] locationMap) {
        DefaultLogger logger = new DefaultLogger();
        this.script = new SMTInterpol((LogProxy) logger, new TerminationRequest() {
            @Override
            public boolean isTerminationRequested() {
                // Allow the solver to be aborted through APT's interrupt mechanism.
                return InterrupterRegistry.getCurrentThreadInterrupter().isInterruptRequested();
            }
        });
        // NOTE(review): 0 is presumably SMTInterpol's "logging off" level; confirm against LogProxy constants.
        logger.setLoglevel(0);
        this.utility = utility;
        this.properties = properties;
        this.locationMap = Arrays.copyOf(locationMap, locationMap.length);
        this.script.setLogic(Logics.QF_LIA);

        int numberEvents = utility.getNumberOfEvents();
        List<String> eventList = utility.getEventList();
        Sort intSort = script.sort("Int");

        // These must be TermVariable[] because they are used as function parameters and let-variables below.
        TermVariable[] weight = new TermVariable[numberEvents];
        TermVariable[] backwardWeight = new TermVariable[numberEvents];
        TermVariable[] forwardWeight = new TermVariable[numberEvents];
        for (int event = 0; event < numberEvents; event++) {
            String eventName = eventList.get(event);
            weight[event] = script.variable("e-" + eventName, intSort);
            backwardWeight[event] = script.variable("b-" + eventName, intSort);
            forwardWeight[event] = script.variable("f-" + eventName, intSort);
        }

        TermVariable[] params;
        TermVariable[] letVariables;
        Term[] letTerms;
        if (properties.isPure()) {
            // Pure: the effective weight is the parameter. The backward weight is its negative part and the
            // forward weight is its positive part.
            params = new TermVariable[1 + numberEvents];
            letVariables = new TermVariable[2 * numberEvents];
            letTerms = new Term[2 * numberEvents];
            Term zero = script.numeral(BigInteger.ZERO);
            for (int event = 0; event < numberEvents; event++) {
                params[1 + event] = weight[event];
                letVariables[event] = backwardWeight[event];
                letVariables[event + numberEvents] = forwardWeight[event];
                letTerms[event] = script.term("ite", script.term(">", zero, weight[event]),
                        script.term("-", weight[event]), zero);
                letTerms[event + numberEvents] = script.term("ite", script.term("<", zero, weight[event]),
                        weight[event], zero);
            }
        } else {
            // Impure: backward and forward weights are parameters; effective weight = forward - backward.
            params = new TermVariable[1 + 2 * numberEvents];
            letVariables = weight;
            letTerms = new Term[numberEvents];
            for (int event = 0; event < numberEvents; event++) {
                params[1 + event] = backwardWeight[event];
                params[1 + event + numberEvents] = forwardWeight[event];
                letTerms[event] = script.term("-", forwardWeight[event], backwardWeight[event]);
            }
        }
        TermVariable initialMarking = script.variable("m0", intSort);
        params[0] = initialMarking;

        List<Term> isRegion = new ArrayList<Term>();
        isRegion.addAll(requireRegion(initialMarking, weight, backwardWeight, forwardWeight));
        if (properties.isKBounded()) {
            isRegion.addAll(requireKBounded(initialMarking, weight, properties.getKForKBounded()));
        }
        if (properties.isPlain() || properties.isConflictFree()) {
            isRegion.addAll(requirePlainness(backwardWeight, forwardWeight));
        }
        // Output-nonbranching is not handled by this helper.
        assert !properties.isOutputNonbranching();
        // Returns no constraints when no event has a location.
        isRegion.addAll(requireDistributableNet(backwardWeight));
        if (properties.isMergeFree()) {
            isRegion.addAll(requireMergeFree(forwardWeight));
        }
        if (properties.isConflictFree()) {
            isRegion.addAll(requireConflictFree(weight, backwardWeight));
        }
        if (properties.isTNet() || properties.isMarkedGraph()) {
            isRegion.addAll(requireTNetOrMarkedGraph(backwardWeight, properties.isMarkedGraph()));
            isRegion.addAll(requireTNetOrMarkedGraph(forwardWeight, properties.isMarkedGraph()));
        }
        if (properties.isHomogeneous()) {
            isRegion.addAll(requireHomogeneous(backwardWeight));
        }
        if (properties.isKMarking()) {
            isRegion.addAll(requireKMarking(initialMarking, properties.getKForKMarking()));
        }
        if (properties.isBehaviourallyConflictFree()) {
            isRegion.addAll(requireBehaviourallyConflictFree(backwardWeight));
        }
        if (properties.isBinaryConflictFree()) {
            isRegion.addAll(requireBinaryConflictFree(initialMarking, weight, backwardWeight));
        }
        if (properties.isEqualConflict()) {
            isRegion.addAll(requireEqualConflict(backwardWeight));
        }

        Term isRegionTerm = collectTerms("and", isRegion.toArray(new Term[isRegion.size()]), script.term("true"));
        isRegionTerm = script.let(letVariables, letTerms, isRegionTerm);
        script.defineFun("isRegion", params, script.sort("Bool"), isRegionTerm);
    }

    /**
     * Builds the term for the marking of {@code state}: the initial marking plus the effect of the Parikh vector
     * that reaches {@code state}.
     *
     * @param initialMarking term for the initial marking
     * @param weight effective weight term per event index
     * @param state the state whose marking is wanted
     * @return the marking term
     * @throws UnreachableException if the utility reports the state as unreachable
     */
    public Term evaluateReachingParikhVector(Term initialMarking, Term[] weight, State state)
            throws UnreachableException {
        Term effect = evaluateParikhVector(weight, utility.getReachingParikhVector(state));
        return script.term("+", initialMarking, effect);
    }

    /**
     * Builds the term {@code sum_i pv[i] * weight[i]}, or the numeral 0 if there are no events.
     */
    private Term evaluateParikhVector(Term[] weight, List<BigInteger> pv) {
        assert weight.length == pv.size();
        Term[] summands = new Term[weight.length];
        for (int event = 0; event < weight.length; event++) {
            BigInteger count = pv.get(event);
            // Numerals are non-negative, so negative coefficients are written as (- |count|).
            Term coefficient = count.signum() >= 0
                    ? script.numeral(count)
                    : script.term("-", script.numeral(count.negate()));
            summands[event] = script.term("*", coefficient, weight[event]);
        }
        return collectTerms("+", summands, script.numeral(BigInteger.ZERO));
    }

    /**
     * Basic region constraints:
     * <ul>
     * <li>every cycle closed by a chord of the spanning tree has zero effect;</li>
     * <li>every transition's backward weight is at most the marking of its (reachable) source state;</li>
     * <li>the initial marking and all backward and forward weights are non-negative.</li>
     * </ul>
     */
    private List<Term> requireRegion(Term initialMarking, Term[] weight, Term[] backwardWeight,
            Term[] forwardWeight) {
        List<Term> result = new ArrayList<Term>();
        Term zero = script.numeral(BigInteger.ZERO);

        // Collect the distinct Parikh vectors of the cycles induced by the spanning tree's chords.
        Set<List<BigInteger>> parikhVectorsOfCycles = new HashSet<List<BigInteger>>();
        for (Arc chord : utility.getSpanningTree().getChords()) {
            try {
                parikhVectorsOfCycles.add(utility.getParikhVectorForEdge(chord));
            } catch (UnreachableException e) {
                throw new RuntimeException("Chords of a spanning tree cannot belong to unreachable states?!", e);
            }
        }
        for (List<BigInteger> cycle : parikhVectorsOfCycles) {
            result.add(script.term("=", zero, evaluateParikhVector(weight, cycle)));
        }

        for (Arc arc : utility.getTransitionSystem().getEdges()) {
            try {
                int event = utility.getEventIndex(arc.getLabel());
                Term sourceMarking = evaluateReachingParikhVector(initialMarking, weight, (State) arc.getSource());
                result.add(script.term("<=", backwardWeight[event], sourceMarking));
            } catch (UnreachableException ignored) {
                // Deliberately skipped: edges leaving unreachable states contribute no constraint.
            }
        }

        result.add(script.term("<=", zero, initialMarking));
        for (Term backward : backwardWeight) {
            result.add(script.term("<=", zero, backward));
        }
        for (Term forward : forwardWeight) {
            result.add(script.term("<=", zero, forward));
        }
        return result;
    }

    /**
     * k-boundedness: the marking of every reachable state is at most {@code k}.
     */
    private List<Term> requireKBounded(Term initialMarking, Term[] weight, int k) {
        List<Term> result = new ArrayList<Term>();
        Term bound = script.numeral(BigInteger.valueOf(k));
        for (State state : utility.getTransitionSystem().getNodes()) {
            try {
                Term marking = evaluateReachingParikhVector(initialMarking, weight, state);
                result.add(script.term("<=", marking, bound));
            } catch (UnreachableException ignored) {
                // Deliberately skipped: unreachable states have no marking to bound.
            }
        }
        return result;
    }

    /**
     * Plainness: every backward and forward weight is at most one.
     */
    private List<Term> requirePlainness(Term[] backwardWeight, Term[] forwardWeight) {
        List<Term> result = new ArrayList<Term>();
        Term one = script.numeral(BigInteger.ONE);
        for (int event = 0; event < backwardWeight.length; event++) {
            result.add(script.term(">=", one, backwardWeight[event]));
            result.add(script.term(">=", one, forwardWeight[event]));
        }
        return result;
    }

    /**
     * T-net / marked graph constraint on one side (backward or forward) of a place: all weights except
     * possibly one are zero. For marked graphs the remaining weight must additionally be positive.
     */
    private List<Term> requireTNetOrMarkedGraph(Term[] weight, boolean markedGraph) {
        Term zero = script.numeral(BigInteger.ZERO);
        Term[] alternatives = new Term[weight.length];
        for (int event = 0; event < weight.length; event++) {
            Term onlyThisEvent = script.term("=", zero, sumOfOthers(weight, event));
            if (markedGraph) {
                onlyThisEvent = script.term("and", onlyThisEvent, script.term("<", zero, weight[event]));
            }
            alternatives[event] = onlyThisEvent;
        }
        return Collections.singletonList(collectTerms("or", alternatives, script.term("true")));
    }

    /**
     * Distributability: for some location, every event that has a location different from it has a zero
     * backward weight. Returns no constraint if no event has a location.
     */
    private List<Term> requireDistributableNet(Term[] backwardWeight) {
        Set<String> locations = new HashSet<String>(Arrays.asList(locationMap));
        locations.remove(null);
        if (locations.isEmpty()) {
            return Collections.emptyList();
        }
        Term zero = script.numeral(BigInteger.ZERO);
        Term[] alternatives = new Term[locations.size()];
        int next = 0;
        for (String location : locations) {
            Term foreignConsumption = zero;
            for (int eventIndex = 0; eventIndex < utility.getNumberOfEvents(); eventIndex++) {
                String eventLocation = locationMap[eventIndex];
                if (eventLocation != null && !eventLocation.equals(location)) {
                    foreignConsumption = script.term("+", backwardWeight[eventIndex], foreignConsumption);
                }
            }
            alternatives[next++] = script.term("=", zero, foreignConsumption);
        }
        return Collections.singletonList(collectTerms("or", alternatives, script.term("true")));
    }

    /**
     * Merge-freeness: all forward weights except possibly one are zero.
     */
    private List<Term> requireMergeFree(Term[] forwardWeight) {
        Term zero = script.numeral(BigInteger.ZERO);
        Term[] alternatives = new Term[forwardWeight.length];
        for (int event = 0; event < forwardWeight.length; event++) {
            alternatives[event] = script.term("=", zero, sumOfOthers(forwardWeight, event));
        }
        return Collections.singletonList(collectTerms("or", alternatives, script.term("true")));
    }

    /**
     * Conflict-freeness: either the backward weights sum to at most one, or every effective weight is
     * non-negative.
     */
    private List<Term> requireConflictFree(Term[] weight, Term[] backwardWeight) {
        Term zero = script.numeral(BigInteger.ZERO);
        Term one = script.numeral(BigInteger.ONE);
        // The caller's array has runtime type TermVariable[]; copy it into a real Term[].
        // NOTE(review): presumably because SMTInterpol may store other Term objects into the vararg array.
        Term[] consumers = Arrays.copyOf(backwardWeight, backwardWeight.length, Term[].class);
        Term atMostOneConsumer = script.term(">=", one, collectTerms("+", consumers, zero));
        Term[] nonNegativeEffect = new Term[weight.length];
        for (int event = 0; event < weight.length; event++) {
            nonNegativeEffect[event] = script.term("<=", zero, weight[event]);
        }
        // An empty conjunction is true. With no events, the first disjunct (1 >= 0) already holds anyway.
        return Collections.singletonList(script.term("or", atMostOneConsumer,
                collectTerms("and", nonNegativeEffect, script.term("true"))));
    }

    /**
     * Homogeneity: any two backward weights are equal unless at least one of them is zero.
     */
    private List<Term> requireHomogeneous(Term[] backwardWeight) {
        List<Term> result = new ArrayList<Term>();
        Term zero = script.numeral(BigInteger.ZERO);
        for (Pair<Term, Term> pair : new DifferentPairsIterable<Term>(Arrays.asList(backwardWeight))) {
            Term first = pair.getFirst();
            Term second = pair.getSecond();
            result.add(script.term("or", script.term("=", zero, first), script.term("=", zero, second),
                    script.term("=", first, second)));
        }
        return result;
    }

    /**
     * k-marking: the initial marking is divisible by {@code k}.
     */
    private List<Term> requireKMarking(Term initialMarking, int k) {
        return Collections.singletonList(
                script.term("divisible", new BigInteger[] {BigInteger.valueOf(k)}, null, initialMarking));
    }

    /**
     * Behavioural conflict-freeness: for every set of events enabled together in some state, at most one of
     * them has a non-zero backward weight.
     */
    private List<Term> requireBehaviourallyConflictFree(Term[] backwardWeight) {
        Set<Set<Integer>> simultaneouslyActivated = new HashSet<Set<Integer>>();
        for (State state : utility.getTransitionSystem().getNodes()) {
            Set<Integer> activated = enabledEvents(state);
            if (!activated.isEmpty()) {
                simultaneouslyActivated.add(activated);
            }
        }

        List<Term> result = new ArrayList<Term>();
        Term zero = script.numeral(BigInteger.ZERO);
        for (Set<Integer> activated : simultaneouslyActivated) {
            // One alternative per event: it is the only activated event that may have a non-zero backward weight.
            Term[] alternatives = new Term[activated.size()];
            int nextAlternative = 0;
            for (int allowed : activated) {
                Term[] summands = new Term[activated.size() - 1];
                int nextSummand = 0;
                for (int other : activated) {
                    if (other != allowed) {
                        summands[nextSummand++] = backwardWeight[other];
                    }
                }
                alternatives[nextAlternative++] = script.term("=", zero, collectTerms("+", summands, zero));
            }
            result.add(collectTerms("or", alternatives, script.term("true")));
        }
        return result;
    }

    /**
     * Binary conflict-freeness: in every reachable state, for every pair of distinct enabled events, the
     * marking is at least the sum of their backward weights.
     */
    private List<Term> requireBinaryConflictFree(Term initialMarking, Term[] weight, Term[] backwardWeight) {
        List<Term> result = new ArrayList<Term>();
        for (State state : utility.getTransitionSystem().getNodes()) {
            Term stateMarking;
            try {
                stateMarking = evaluateReachingParikhVector(initialMarking, weight, state);
            } catch (UnreachableException ignored) {
                // Deliberately skipped: unreachable states impose no constraint.
                continue;
            }
            for (Pair<Integer, Integer> pair : new DifferentPairsIterable<Integer>(enabledEvents(state))) {
                Term consumed = script.term("+", backwardWeight[pair.getFirst()], backwardWeight[pair.getSecond()]);
                result.add(script.term(">=", stateMarking, consumed));
            }
        }
        return result;
    }

    /**
     * Equal conflict. Events are grouped into classes of events that every state enables either all together
     * or not at all. The constraint then requires that, for one such class or for the empty set, all events in
     * it share the same positive backward weight and all other events have backward weight zero.
     */
    private List<Term> requireEqualConflict(Term[] backwardWeight) {
        if (backwardWeight.length == 0) {
            return Collections.emptyList();
        }
        TransitionSystem ts = utility.getTransitionSystem();

        // Start with a single class containing every event, then split it by enabledness in each state.
        AbstractEquivalenceRelation<String> relation = new DomainEquivalenceRelation<String>(ts.getAlphabet());
        String someEvent = ts.getAlphabet().iterator().next();
        for (String otherEvent : ts.getAlphabet()) {
            relation.joinClasses(someEvent, otherEvent);
        }
        for (final State state : ts.getNodes()) {
            relation = relation.refine(new IEquivalenceRelation<String>() {
                @Override
                public boolean isEquivalent(String event1, String event2) {
                    Set<State> postset1 = state.getPostsetNodesByLabel(event1);
                    Set<State> postset2 = state.getPostsetNodesByLabel(event2);
                    return postset1.isEmpty() == postset2.isEmpty();
                }
            });
        }
        DebugUtil.debug((Object) "Enabling-equivalent transitions: ", relation);

        List<String> eventList = utility.getEventList();
        Term zero = script.numeral(BigInteger.ZERO);
        List<Term> alternatives = new ArrayList<Term>();
        // Candidate sets of consuming events: each equivalence class, plus the empty set.
        Collection<Set<String>> noConsumer = Collections.<Set<String>>singleton(Collections.<String>emptySet());
        for (Set<String> equivalenceClass : new CompositeCollection<Set<String>>(relation, noConsumer)) {
            Term[] current = new Term[backwardWeight.length];
            Term pivot = null;
            for (int index = 0; index < backwardWeight.length; index++) {
                if (!equivalenceClass.contains(eventList.get(index))) {
                    current[index] = script.term("=", zero, backwardWeight[index]);
                } else if (pivot == null) {
                    // The first event of the class fixes the common, positive backward weight.
                    pivot = backwardWeight[index];
                    current[index] = script.term("<", zero, pivot);
                } else {
                    current[index] = script.term("=", pivot, backwardWeight[index]);
                }
            }
            alternatives.add(collectTerms("and", current, script.term("true")));
        }
        return Collections.singletonList(collectTerms("or",
                alternatives.toArray(new Term[alternatives.size()]), script.term("false")));
    }

    /**
     * Returns the indices of the events labelling the outgoing edges of {@code state}.
     */
    private Set<Integer> enabledEvents(State state) {
        Set<Integer> activated = new HashSet<Integer>();
        for (Arc arc : state.getPostsetEdges()) {
            activated.add(utility.getEventIndex(arc.getLabel()));
        }
        return activated;
    }

    /**
     * Returns the sum of all entries of {@code weights} except the one at index {@code excluded}, keeping index
     * order. Returns the numeral 0 if no entries remain.
     */
    private Term sumOfOthers(Term[] weights, int excluded) {
        Term[] others = new Term[weights.length - 1];
        int next = 0;
        for (int idx = 0; idx < weights.length; idx++) {
            if (idx != excluded) {
                others[next++] = weights[idx];
            }
        }
        return collectTerms("+", others, script.numeral(BigInteger.ZERO));
    }

    /**
     * Combines {@code terms} with the n-ary {@code operation}. Returns {@code def} for an empty array and the
     * sole element for a one-element array, since such operators need at least two arguments.
     */
    private Term collectTerms(String operation, Term[] terms, Term def) {
        if (terms.length == 0) {
            return def;
        }
        if (terms.length == 1) {
            return terms[0];
        }
        return script.term(operation, terms);
    }

    /**
     * @return the SMT script in which {@code isRegion} is defined
     */
    public Script getScript() {
        return script;
    }
}

