# Copyright 2014 MINES ParisTech
#
# This file is part of LinPy.
#
# LinPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# LinPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with LinPy.  If not, see <http://www.gnu.org/licenses/>.

import ast
import functools
import re
import math

from fractions import Fraction

from . import islhelper
from .islhelper import mainctx, libisl
from .linexprs import Expression, Symbol, Rational
from .geometry import GeometricObject, Point, Vector


__all__ = [
    'Domain',
    'And', 'Or', 'Not',
]


@functools.total_ordering
class Domain(GeometricObject):
    """
    A union of polyhedra over an ordered tuple of symbols.

    Set operations are delegated to isl: each method converts the polyhedra
    to an isl set with `_toislset`, calls the matching isl function, and
    converts the result back with `_fromislset`.  `_fromislset` returns the
    `Empty` polyhedron, a single polyhedron, or a new Domain, depending on how
    many basic sets the result contains.
    """

    __slots__ = (
        '_polyhedra',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, *polyhedra):
        """
        Create a domain.

        Accepted forms:
        - a single string, parsed with `fromstring`;
        - a single GeometricObject, converted with its `aspolyhedron` method;
        - any number of Polyhedron instances, whose union is returned.

        With no argument, the `Empty` polyhedron is returned (the neutral
        element of union, consistent with `Or()`).

        Raises TypeError for any other kind of argument.
        """
        from .polyhedra import Polyhedron
        if len(polyhedra) == 1:
            argument = polyhedra[0]
            if isinstance(argument, str):
                return cls.fromstring(argument)
            elif isinstance(argument, GeometricObject):
                return argument.aspolyhedron()
            else:
                raise TypeError('argument must be a string '
                    'or a GeometricObject instance')
        elif len(polyhedra) == 0:
            # _toislset needs at least one polyhedron; an empty union is Empty.
            from .polyhedra import Empty
            return Empty
        else:
            for polyhedron in polyhedra:
                if not isinstance(polyhedron, Polyhedron):
                    raise TypeError('arguments must be Polyhedron instances')
            symbols = cls._xsymbols(polyhedra)
            islset = cls._toislset(polyhedra, symbols)
            return cls._fromislset(islset, symbols)

    @classmethod
    def _xsymbols(cls, iterator):
        """
        Return the ordered tuple of symbols present in iterator.

        Symbols are merged from the `symbols` attribute of every item and
        sorted with `Symbol.sortkey`.
        """
        symbols = set()
        for item in iterator:
            symbols.update(item.symbols)
        return tuple(sorted(symbols, key=Symbol.sortkey))

    @property
    def polyhedra(self):
        """The tuple of polyhedra whose union forms this domain."""
        return self._polyhedra

    @property
    def symbols(self):
        """The ordered tuple of symbols (dimensions) of this domain."""
        return self._symbols

    @property
    def dimension(self):
        """The number of symbols of this domain."""
        return self._dimension

    def disjoint(self):
        """
        Return an equivalent domain whose polyhedra are pairwise disjoint.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        # isl_set_make_disjoint takes only the set (no context argument).
        islset = libisl.isl_set_make_disjoint(islset)
        return self._fromislset(islset, self.symbols)

    def isempty(self):
        """
        Return True if this domain contains no point.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        empty = bool(libisl.isl_set_is_empty(islset))
        libisl.isl_set_free(islset)
        return empty

    def __bool__(self):
        """A domain is truthy when it is not empty."""
        return not self.isempty()

    def isuniverse(self):
        """
        Return True if this domain is the universe set.

        This uses isl's "plain" check, which inspects the representation
        without full simplification.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        universe = bool(libisl.isl_set_plain_is_universe(islset))
        libisl.isl_set_free(islset)
        return universe

    def isbounded(self):
        """
        Return True if this domain is bounded.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        bounded = bool(libisl.isl_set_is_bounded(islset))
        libisl.isl_set_free(islset)
        return bounded

    def __eq__(self, other):
        """
        Return True if both domains contain exactly the same points.

        Return NotImplemented when *other* is not a Domain, so Python falls
        back to its default comparison instead of failing on a missing
        `polyhedra` attribute.
        """
        if not isinstance(other, Domain):
            return NotImplemented
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        equal = bool(libisl.isl_set_is_equal(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return equal

    def isdisjoint(self, other):
        """
        Return True if the two domains have an empty intersection.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        disjoint = bool(libisl.isl_set_is_disjoint(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return disjoint

    def issubset(self, other):
        """
        Return True if *other* contains every point of this domain.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        subset = bool(libisl.isl_set_is_subset(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return subset

    def __le__(self, other):
        """
        Return True if this domain is a subset of *other*.
        """
        return self.issubset(other)

    def __lt__(self, other):
        """
        Return True if this domain is a strict subset of *other*.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        strict = bool(libisl.isl_set_is_strict_subset(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return strict

    def complement(self):
        """
        Return the complement of this domain.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_complement(islset)
        return self._fromislset(islset, self.symbols)

    def __invert__(self):
        """
        Return the complement of this domain (``~domain``).
        """
        return self.complement()

    def simplify(self):
        """
        Return an equivalent domain without redundant constraints.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_remove_redundancies(islset)
        return self._fromislset(islset, self.symbols)

    def aspolyhedron(self):
        """
        Return the polyhedral hull of this domain as a single Polyhedron.
        """
        from .polyhedra import Polyhedron
        islset = self._toislset(self.polyhedra, self.symbols)
        islbset = libisl.isl_set_polyhedral_hull(islset)
        return Polyhedron._fromislbasicset(islbset, self.symbols)

    def asdomain(self):
        """Return this domain unchanged."""
        return self

    def project(self, dims):
        """
        Return a new domain with the symbols in *dims* projected out.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        # Walk the symbols from last to first, grouping consecutive symbols
        # to remove into one isl_set_project_out call.  Going backwards keeps
        # the indices of the not-yet-visited symbols valid.
        n = 0
        for index, symbol in reversed(list(enumerate(self.symbols))):
            if symbol in dims:
                n += 1
            elif n > 0:
                # The run of n removed symbols starts just after `index`.
                islset = libisl.isl_set_project_out(islset, libisl.isl_dim_set, index + 1, n)
                n = 0
        if n > 0:
            islset = libisl.isl_set_project_out(islset, libisl.isl_dim_set, 0, n)
        dims = [symbol for symbol in self.symbols if symbol not in dims]
        return Domain._fromislset(islset, dims)

    def sample(self):
        """
        Return one point of this domain as a dict mapping each symbol to its
        coordinate (converted with `islhelper.isl_val_to_int`).

        Raises ValueError if the domain is empty.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islpoint = libisl.isl_set_sample_point(islset)
        if bool(libisl.isl_point_is_void(islpoint)):
            libisl.isl_point_free(islpoint)
            raise ValueError('domain must be non-empty')
        point = {}
        for index, symbol in enumerate(self.symbols):
            coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                libisl.isl_dim_set, index)
            coordinate = islhelper.isl_val_to_int(coordinate)
            point[symbol] = coordinate
        libisl.isl_point_free(islpoint)
        return point

    def intersection(self, *others):
        """
        Return the intersection of this domain with all *others*.

        With no argument, this domain itself is returned.
        """
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_intersect(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __and__(self, other):
        """
        Return the intersection of two domains (``self & other``).
        """
        return self.intersection(other)

    def union(self, *others):
        """
        Return the union of this domain with all *others*.

        With no argument, this domain itself is returned.
        """
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __or__(self, other):
        """
        Return the union of two domains (``self | other``).
        """
        return self.union(other)

    def __add__(self, other):
        """
        Return the union of two domains (``self + other``): the points that
        belong to either domain.
        """
        return self.union(other)

    def difference(self, other):
        """
        Return the points of this domain that are not in *other*.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        islset = libisl.isl_set_subtract(islset1, islset2)
        return self._fromislset(islset, symbols)

    def __sub__(self, other):
        """
        Return the difference of two domains (``self - other``).
        """
        return self.difference(other)

    def lexmin(self):
        """
        Return a new domain containing the lexicographic minimum of the
        points of this domain.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmin(islset)
        return self._fromislset(islset, self.symbols)

    def lexmax(self):
        """
        Return a new domain containing the lexicographic maximum of the
        points of this domain.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmax(islset)
        return self._fromislset(islset, self.symbols)

    def involves_vars(self, vars):
        """
        Return True if the domain depends on at least one symbol in *vars*.

        Symbols of *vars* that are not dimensions of this domain are ignored;
        if none of them is, False is returned.
        """
        wanted = set(vars)
        indices = [index for index, symbol in enumerate(self.symbols)
            if symbol in wanted]
        if not indices:
            return False
        islset = self._toislset(self.polyhedra, self.symbols)
        try:
            # Test each requested dimension on its own: the requested symbols
            # need not be contiguous in the domain's symbol order.
            involved = any(
                bool(libisl.isl_set_involves_dims(islset,
                    libisl.isl_dim_set, index, 1))
                for index in indices)
        finally:
            libisl.isl_set_free(islset)
        return involved

    # Matches one coordinate "(num)" or "(num)/den" in isl's textual form of
    # a multi_aff.
    _RE_COORDINATE = re.compile(r'\((?P<num>\-?\d+)\)(/(?P<den>\d+))?')

    def vertices(self):
        """
        Return the vertices of this domain as a list of Point instances.

        Raises ValueError if the domain is unbounded.
        """
        if not self.isbounded():
            raise ValueError('domain must be bounded')
        # NOTE(review): `equalities`, `inequalities` and `_toislbasicset` are
        # not defined on Domain in this file.  This works for objects that
        # provide them (presumably Polyhedron), but raises AttributeError on a
        # union of several polyhedra.  TODO: confirm the intended behaviour.
        islbset = self._toislbasicset(self.equalities, self.inequalities, self.symbols)
        vertices = libisl.isl_basic_set_compute_vertices(islbset)
        vertices = islhelper.isl_vertices_vertices(vertices)
        points = []
        for vertex in vertices:
            expr = libisl.isl_vertex_get_expr(vertex)
            coordinates = []
            # NOTE(review): if isl_version is a str, this is a lexicographic
            # comparison (e.g. '0.9' > '0.13').  Verify against islhelper.
            if islhelper.isl_version < '0.13':
                # Older isl: the vertex is a basic set of constraints of the
                # form coefficient * symbol + constant = 0.
                constraints = islhelper.isl_basic_set_constraints(expr)
                for constraint in constraints:
                    constant = libisl.isl_constraint_get_constant_val(constraint)
                    constant = islhelper.isl_val_to_int(constant)
                    for index, symbol in enumerate(self.symbols):
                        coefficient = libisl.isl_constraint_get_coefficient_val(constraint,
                            libisl.isl_dim_set, index)
                        coefficient = islhelper.isl_val_to_int(coefficient)
                        if coefficient != 0:
                            coordinate = -Fraction(constant, coefficient)
                            coordinates.append((symbol, coordinate))
            else:
                # Newer isl: parse the coordinates from the multi_aff string.
                string = islhelper.isl_multi_aff_to_str(expr)
                matches = self._RE_COORDINATE.finditer(string)
                for symbol, match in zip(self.symbols, matches):
                    numerator = int(match.group('num'))
                    denominator = match.group('den')
                    denominator = 1 if denominator is None else int(denominator)
                    coordinate = Fraction(numerator, denominator)
                    coordinates.append((symbol, coordinate))
            points.append(Point(coordinates))
        return points

    def points(self):
        """
        Return the list of integer points contained in this domain.

        Raises ValueError if the domain is unbounded.
        """
        if not self.isbounded():
            raise ValueError('domain must be bounded')
        islset = self._toislset(self.polyhedra, self.symbols)
        islpoints = islhelper.isl_set_points(islset)
        points = []
        for islpoint in islpoints:
            coordinates = {}
            for index, symbol in enumerate(self.symbols):
                coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                    libisl.isl_dim_set, index)
                coordinate = islhelper.isl_val_to_int(coordinate)
                coordinates[symbol] = coordinate
            points.append(Point(coordinates))
        return points

    @classmethod
    def _polygon_inner_point(cls, points):
        """Return the centroid (coordinate-wise mean) of *points*."""
        symbols = points[0].symbols
        coordinates = {symbol: 0 for symbol in symbols}
        for point in points:
            for symbol, coordinate in point.coordinates():
                coordinates[symbol] += coordinate
        for symbol in symbols:
            coordinates[symbol] /= len(points)
        return Point(coordinates)

    @classmethod
    def _sort_polygon_2d(cls, points):
        """
        Sort 2D *points* by their polar angle around the centroid, so that
        they can be drawn as a polygon.
        """
        if len(points) <= 3:
            return points
        o = cls._polygon_inner_point(points)
        angles = {}
        for m in points:
            om = Vector(o, m)
            dx, dy = (coordinate for symbol, coordinate in om.coordinates())
            angle = math.atan2(dy, dx)
            angles[m] = angle
        return sorted(points, key=angles.get)

    @classmethod
    def _sort_polygon_3d(cls, points):
        """
        Sort coplanar 3D *points* by their signed angle around the centroid,
        measured from the first point in the plane of the polygon.

        Raises ValueError if all points are aligned with the centroid.
        """
        if len(points) <= 3:
            return points
        o = cls._polygon_inner_point(points)
        a = points[0]
        oa = Vector(o, a)
        norm_oa = oa.norm()
        # Find the polygon's normal vector to orient the angles.
        for b in points[1:]:
            ob = Vector(o, b)
            u = oa.cross(ob)
            if not u.isnull():
                u = u.asunit()
                break
        else:
            raise ValueError('degenerate polygon')
        angles = {a: 0.}
        for m in points[1:]:
            om = Vector(o, m)
            normprod = norm_oa * om.norm()
            # Clamp to [-1, 1]: rounding errors may push the ratio slightly
            # outside acos's domain.
            cosinus = min(max(oa.dot(om) / normprod, -1.), 1.)
            sinus = u.dot(oa.cross(om)) / normprod
            angle = math.acos(cosinus)
            angle = math.copysign(angle, sinus)
            angles[m] = angle
        return sorted(points, key=angles.get)

    def faces(self):
        """
        Return the faces of the polyhedra as lists of vertices.

        A face is the set of vertices that saturate one constraint; only sets
        of at least three vertices are kept.
        """
        faces = []
        for polyhedron in self.polyhedra:
            vertices = polyhedron.vertices()
            for constraint in polyhedron.constraints:
                face = []
                for vertex in vertices:
                    if constraint.subs(vertex.coordinates()) == 0:
                        face.append(vertex)
                if len(face) >= 3:
                    faces.append(face)
        return faces

    def _plot_2d(self, plot=None, **kwargs):
        """Draw each polyhedron as a filled polygon on a matplotlib axes."""
        import matplotlib.pyplot as plt
        from matplotlib.patches import Polygon
        if plot is None:
            fig = plt.figure()
            plot = fig.add_subplot(1, 1, 1)
        xmin, xmax = plot.get_xlim()
        ymin, ymax = plot.get_ylim()
        for polyhedron in self.polyhedra:
            vertices = polyhedron._sort_polygon_2d(polyhedron.vertices())
            xys = [tuple(vertex.values()) for vertex in vertices]
            xs, ys = zip(*xys)
            xmin, xmax = min(xmin, float(min(xs))), max(xmax, float(max(xs)))
            ymin, ymax = min(ymin, float(min(ys))), max(ymax, float(max(ys)))
            plot.add_patch(Polygon(xys, closed=True, **kwargs))
        plot.set_xlim(xmin, xmax)
        plot.set_ylim(ymin, ymax)
        return plot

    def _plot_3d(self, plot=None, **kwargs):
        """Draw the faces of the domain on a matplotlib 3D axes."""
        import matplotlib.pyplot as plt
        # Older matplotlib versions need this import to register the '3d'
        # projection.
        from mpl_toolkits.mplot3d import Axes3D
        from mpl_toolkits.mplot3d.art3d import Poly3DCollection
        if plot is None:
            fig = plt.figure()
            # Axes3D(fig) no longer attaches the axes to the figure in recent
            # matplotlib; add_subplot with the '3d' projection does.
            axes = fig.add_subplot(1, 1, 1, projection='3d')
        else:
            axes = plot
        xmin, xmax = axes.get_xlim()
        ymin, ymax = axes.get_ylim()
        zmin, zmax = axes.get_zlim()
        poly_xyzs = []
        for vertices in self.faces():
            vertices = self._sort_polygon_3d(vertices)
            vertices.append(vertices[0])
            face_xyzs = [tuple(vertex.values()) for vertex in vertices]
            xs, ys, zs = zip(*face_xyzs)
            xmin, xmax = min(xmin, float(min(xs))), max(xmax, float(max(xs)))
            ymin, ymax = min(ymin, float(min(ys))), max(ymax, float(max(ys)))
            zmin, zmax = min(zmin, float(min(zs))), max(zmax, float(max(zs)))
            poly_xyzs.append(face_xyzs)
        collection = Poly3DCollection(poly_xyzs, **kwargs)
        axes.add_collection3d(collection)
        axes.set_xlim(xmin, xmax)
        axes.set_ylim(ymin, ymax)
        axes.set_zlim(zmin, zmax)
        return axes

    def plot(self, plot=None, **kwargs):
        """
        Plot this domain with matplotlib and return the axes.

        *plot* is an optional existing axes; extra keyword arguments are
        forwarded to the matplotlib patch/collection.  Raises ValueError if
        the domain is unbounded or not 2- or 3-dimensional.
        """
        if not self.isbounded():
            raise ValueError('domain must be bounded')
        elif self.dimension == 2:
            return self._plot_2d(plot=plot, **kwargs)
        elif self.dimension == 3:
            return self._plot_3d(plot=plot, **kwargs)
        else:
            raise ValueError('polyhedron must be 2 or 3-dimensional')

    def __contains__(self, point):
        """Return True if *point* belongs to one of the polyhedra."""
        for polyhedron in self.polyhedra:
            if point in polyhedron:
                return True
        return False

    def subs(self, symbol, expression=None):
        """
        Substitute *expression* for *symbol* in every polyhedron and return
        the resulting domain.
        """
        polyhedra = [polyhedron.subs(symbol, expression)
            for polyhedron in self.polyhedra]
        return Domain(*polyhedra)

    @classmethod
    def _fromislset(cls, islset, symbols):
        """
        Convert an isl set (consumed) into Empty, a single polyhedron, or a
        Domain, depending on the number of basic sets it contains.
        """
        from .polyhedra import Polyhedron
        islset = libisl.isl_set_remove_divs(islset)
        islbsets = islhelper.isl_set_basic_sets(islset)
        libisl.isl_set_free(islset)
        polyhedra = []
        for islbset in islbsets:
            polyhedron = Polyhedron._fromislbasicset(islbset, symbols)
            polyhedra.append(polyhedron)
        if len(polyhedra) == 0:
            from .polyhedra import Empty
            return Empty
        elif len(polyhedra) == 1:
            return polyhedra[0]
        else:
            # Bypass __new__, which would rebuild the isl set.
            self = object.__new__(Domain)
            self._polyhedra = tuple(polyhedra)
            self._symbols = cls._xsymbols(polyhedra)
            self._dimension = len(self._symbols)
            return self

    @classmethod
    def _toislset(cls, polyhedra, symbols):
        """
        Build the isl set union of *polyhedra* over the ordered *symbols*.

        *polyhedra* must contain at least one polyhedron.
        """
        polyhedron = polyhedra[0]
        islbset = polyhedron._toislbasicset(polyhedron.equalities,
            polyhedron.inequalities, symbols)
        islset1 = libisl.isl_set_from_basic_set(islbset)
        for polyhedron in polyhedra[1:]:
            islbset = polyhedron._toislbasicset(polyhedron.equalities,
                polyhedron.inequalities, symbols)
            islset2 = libisl.isl_set_from_basic_set(islbset)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return islset1

    @classmethod
    def _fromast(cls, node):
        """
        Build a domain from a Python AST.

        Supported nodes: a module or expression wrapper, ``~`` (complement),
        ``&`` (intersection), ``|`` (union), and (chained) comparisons
        between linear expressions.  Raises SyntaxError otherwise.
        """
        from .polyhedra import Polyhedron
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # Produced by ast.parse(..., mode='eval').
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.UnaryOp):
            if isinstance(node.op, ast.Invert):
                return Not(cls._fromast(node.operand))
        elif isinstance(node, ast.BinOp):
            domain1 = cls._fromast(node.left)
            domain2 = cls._fromast(node.right)
            if isinstance(node.op, ast.BitAnd):
                return And(domain1, domain2)
            elif isinstance(node.op, ast.BitOr):
                return Or(domain1, domain2)
        elif isinstance(node, ast.Compare):
            equalities = []
            inequalities = []
            left = Expression._fromast(node.left)
            for i in range(len(node.ops)):
                op = node.ops[i]
                right = Expression._fromast(node.comparators[i])
                # Every constraint is normalized to "expr == 0" or
                # "expr >= 0"; strict inequalities use integer semantics.
                if isinstance(op, ast.Lt):
                    inequalities.append(right - left - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append(right - left)
                elif isinstance(op, ast.Eq):
                    equalities.append(left - right)
                elif isinstance(op, ast.GtE):
                    inequalities.append(left - right)
                elif isinstance(op, ast.Gt):
                    inequalities.append(left - right - 1)
                else:
                    break
                left = right
            else:
                return Polyhedron(equalities, inequalities)
        raise SyntaxError('invalid syntax')

    _RE_BRACES = re.compile(r'^\{\s*|\s*\}$')
    _RE_EQ = re.compile(r'([^<=>])=([^<=>])')
    _RE_AND = re.compile(r'\band\b|,|&&|/\\|∧|∩')
    _RE_OR = re.compile(r'\bor\b|;|\|\||\\/|∨|∪')
    _RE_NOT = re.compile(r'\bnot\b|!|¬')
    _RE_NUM_VAR = Expression._RE_NUM_VAR
    _RE_OPERATORS = re.compile(r'(&|\||~)')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a domain from a string such as ``'{x >= 0, y > x or x = 5}'``.

        The string is rewritten into a Python expression (``=`` becomes
        ``==``; and/or/not spellings become ``&``/``|``/``~``; implicit
        multiplication is made explicit).  It is then parsed with the `ast`
        module.  Raises SyntaxError on invalid input.
        """
        # remove curly brackets
        string = cls._RE_BRACES.sub(r'', string)
        # replace '=' by '=='
        string = cls._RE_EQ.sub(r'\1==\2', string)
        # replace 'and', 'or', 'not'
        string = cls._RE_AND.sub(r' & ', string)
        string = cls._RE_OR.sub(r' | ', string)
        string = cls._RE_NOT.sub(r' ~', string)
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # Wrap each operand in parentheses so that comparisons bind tighter
        # than &, | and ~.  Even indices are operands, odd ones operators.
        # Blank operands (e.g. before a unary '~') are left unwrapped:
        # wrapping them would produce an invalid '( )'.
        tokens = cls._RE_OPERATORS.split(string)
        for i, token in enumerate(tokens):
            if i % 2 == 0 and token.strip():
                tokens[i] = '({})'.format(token)
        # compile() rejects leading whitespace as an unexpected indent.
        string = ''.join(tokens).strip()
        tree = ast.parse(string, mode='eval')
        return cls._fromast(tree)

    def __repr__(self):
        """Return ``Or(...)`` of the polyhedra representations."""
        assert len(self.polyhedra) >= 2
        strings = [repr(polyhedron) for polyhedron in self.polyhedra]
        return 'Or({})'.format(', '.join(strings))

    def _repr_latex_(self):
        """Return a LaTeX disjunction of the polyhedra (for IPython)."""
        strings = []
        for polyhedron in self.polyhedra:
            strings.append('({})'.format(polyhedron._repr_latex_().strip('$')))
        return '${}$'.format(' \\vee '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy boolean/relational expression into a domain.

        Raises ValueError for unsupported expressions.
        """
        import sympy
        from .polyhedra import Lt, Le, Eq, Ne, Ge, Gt
        funcmap = {
            sympy.And: And, sympy.Or: Or, sympy.Not: Not,
            sympy.Lt: Lt, sympy.Le: Le,
            sympy.Eq: Eq, sympy.Ne: Ne,
            sympy.Ge: Ge, sympy.Gt: Gt,
        }
        if expr.func in funcmap:
            args = [Domain.fromsympy(arg) for arg in expr.args]
            return funcmap[expr.func](*args)
        elif isinstance(expr, sympy.Expr):
            return Expression.fromsympy(expr)
        raise ValueError('non-domain expression: {!r}'.format(expr))

    def tosympy(self):
        """
        Convert this domain into a sympy ``Or`` of its polyhedra.
        """
        import sympy
        polyhedra = [polyhedron.tosympy() for polyhedron in self.polyhedra]
        return sympy.Or(*polyhedra)


def And(*domains):
    """
    Return the intersection of all the given domains.

    Called without arguments, it returns the universe set, the neutral
    element of intersection.
    """
    if not domains:
        from .polyhedra import Universe
        return Universe
    first, *rest = domains
    return first.intersection(*rest)

def Or(*domains):
    """
    Return the union of all the given domains.

    Called without arguments, it returns the empty set, the neutral element
    of union.
    """
    if not domains:
        from .polyhedra import Empty
        return Empty
    first, *rest = domains
    return first.union(*rest)

def Not(domain):
    """
    Return the complement of *domain*.

    This is the functional spelling of ``~domain``.
    """
    complement = ~domain
    return complement
