Improve mapping recognition
[linpy.git] / pypol / polyhedra.py
1
2 import functools
3 import math
4 import numbers
5
6 from . import islhelper
7
8 from .islhelper import mainctx, libisl
9 from .linexprs import Expression, Symbol, Rational
10 from .domains import Domain
11
12
# Public API of this module: the polyhedron class, the comparison
# constructors and the two constant polyhedra defined at the bottom.
__all__ = [
    'Polyhedron',
    'Lt',
    'Le',
    'Eq',
    'Ne',
    'Ge',
    'Gt',
    'Empty',
    'Universe',
]
18
19
class Polyhedron(Domain):
    """Conjunction of linear equalities ``0 == e`` and inequalities ``0 <= e``.

    The constructor round-trips the constraints through an isl basic set
    (``_toislbasicset`` then ``_fromislbasicset``), so the stored constraints
    are the ones read back from isl rather than copies of the arguments.
    """

    __slots__ = (
        '_equalities',
        '_inequalities',
        '_constraints',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, equalities=None, inequalities=None):
        """Create a polyhedron.

        ``Polyhedron(string)`` parses *string* with ``fromstring``,
        ``Polyhedron(polyhedron)`` returns the argument unchanged and
        ``Polyhedron(domain)`` returns ``domain.aspolyhedron()``; in these
        forms *inequalities* must be omitted.

        Otherwise *equalities* and *inequalities* are iterables of linear
        expressions ``e``, read as ``0 == e`` and ``0 <= e`` respectively;
        ``None`` means no constraint of that kind.  The arguments are not
        modified.

        Raises:
            TypeError: if extra arguments are given with the string, polyhedron
                or domain forms, or if a constraint is not an ``Expression``.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        elif isinstance(equalities, Polyhedron):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return equalities
        elif isinstance(equalities, Domain):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return equalities.aspolyhedron()
        # Work on fresh lists: the caller's sequences must not be rewritten in
        # place, and tuples or generators must work as well as lists (both for
        # item assignment and for the concatenation below).
        equalities = [] if equalities is None else list(equalities)
        for i, equality in enumerate(equalities):
            if not isinstance(equality, Expression):
                raise TypeError('equalities must be linear expressions')
            equalities[i] = equality.scaleint()
        inequalities = [] if inequalities is None else list(inequalities)
        for i, inequality in enumerate(inequalities):
            if not isinstance(inequality, Expression):
                raise TypeError('inequalities must be linear expressions')
            inequalities[i] = inequality.scaleint()
        symbols = cls._xsymbols(equalities + inequalities)
        islbset = cls._toislbasicset(equalities, inequalities, symbols)
        return cls._fromislbasicset(islbset, symbols)

    @property
    def equalities(self):
        """Tuple of expressions ``e`` constrained by ``0 == e``."""
        return self._equalities

    @property
    def inequalities(self):
        """Tuple of expressions ``e`` constrained by ``0 <= e``."""
        return self._inequalities

    @property
    def constraints(self):
        """Tuple of all constraint expressions, equalities first."""
        return self._constraints

    @property
    def polyhedra(self):
        """One-element tuple containing this polyhedron."""
        return self,

    def disjoint(self):
        """Return this polyhedron unchanged."""
        return self

    def isuniverse(self):
        """Return True if isl reports the constraints as always satisfied."""
        islbset = self._toislbasicset(self.equalities, self.inequalities,
            self.symbols)
        universe = bool(libisl.isl_basic_set_is_universe(islbset))
        libisl.isl_basic_set_free(islbset)
        return universe

    def aspolyhedron(self):
        """Return this polyhedron unchanged."""
        return self

    def subs(self, symbol, expression=None):
        """Return a new polyhedron with ``subs(symbol, expression)`` applied
        to every equality and inequality expression."""
        equalities = [equality.subs(symbol, expression)
            for equality in self.equalities]
        inequalities = [inequality.subs(symbol, expression)
            for inequality in self.inequalities]
        return Polyhedron(equalities, inequalities)

    @classmethod
    def _fromislbasicset(cls, islbset, symbols):
        """Build a Polyhedron from the constraints of *islbset*.

        The i-th element of *symbols* names set dimension i.  *islbset* is
        freed here, so the caller must not use it afterwards.
        """
        islconstraints = islhelper.isl_basic_set_constraints(islbset)
        equalities = []
        inequalities = []
        for islconstraint in islconstraints:
            constant = libisl.isl_constraint_get_constant_val(islconstraint)
            constant = islhelper.isl_val_to_int(constant)
            coefficients = {}
            for index, symbol in enumerate(symbols):
                coefficient = libisl.isl_constraint_get_coefficient_val(
                    islconstraint, libisl.isl_dim_set, index)
                coefficient = islhelper.isl_val_to_int(coefficient)
                if coefficient != 0:
                    coefficients[symbol] = coefficient
            expression = Expression(coefficients, constant)
            if libisl.isl_constraint_is_equality(islconstraint):
                equalities.append(expression)
            else:
                inequalities.append(expression)
        libisl.isl_basic_set_free(islbset)
        # Allocate directly: Polyhedron(...) would go through __new__, which
        # calls back into this method.
        self = object.__new__(Polyhedron)
        self._equalities = tuple(equalities)
        self._inequalities = tuple(inequalities)
        self._constraints = tuple(equalities + inequalities)
        self._symbols = cls._xsymbols(self._constraints)
        self._dimension = len(self._symbols)
        return self

    @staticmethod
    def _setislconstraint(islconstraint, expression, indices):
        """Copy the coefficients and constant of *expression* into
        *islconstraint* and return the resulting isl constraint.

        *indices* maps each symbol of *expression* to its set-dimension index.
        isl values are parsed from the string form of each number.
        """
        for symbol, coefficient in expression.coefficients():
            islval = libisl.isl_val_read_from_str(mainctx,
                str(coefficient).encode())
            islconstraint = libisl.isl_constraint_set_coefficient_val(
                islconstraint, libisl.isl_dim_set, indices[symbol], islval)
        if expression.constant != 0:
            islval = libisl.isl_val_read_from_str(mainctx,
                str(expression.constant).encode())
            islconstraint = libisl.isl_constraint_set_constant_val(
                islconstraint, islval)
        return islconstraint

    @classmethod
    def _toislbasicset(cls, equalities, inequalities, symbols):
        """Build a new isl basic set from linear constraints.

        *equalities* and *inequalities* are expressions ``e`` read as
        ``0 == e`` and ``0 <= e``; the i-th element of *symbols* becomes set
        dimension i.  Callers are responsible for freeing the returned basic
        set (``_fromislbasicset`` frees the one it receives).
        """
        indices = {symbol: index for index, symbol in enumerate(symbols)}
        islsp = libisl.isl_space_set_alloc(mainctx, 0, len(symbols))
        islbset = libisl.isl_basic_set_universe(libisl.isl_space_copy(islsp))
        islls = libisl.isl_local_space_from_space(islsp)
        for equality in equalities:
            islconstraint = libisl.isl_equality_alloc(
                libisl.isl_local_space_copy(islls))
            islconstraint = cls._setislconstraint(islconstraint, equality,
                indices)
            islbset = libisl.isl_basic_set_add_constraint(islbset,
                islconstraint)
        for inequality in inequalities:
            islconstraint = libisl.isl_inequality_alloc(
                libisl.isl_local_space_copy(islls))
            islconstraint = cls._setislconstraint(islconstraint, inequality,
                indices)
            islbset = libisl.isl_basic_set_add_constraint(islbset,
                islconstraint)
        # Every constraint above was given its own copy of the local space;
        # release the reference this method still holds.
        libisl.isl_local_space_free(islls)
        return islbset

    @classmethod
    def fromstring(cls, string):
        """Parse *string* with ``Domain.fromstring`` and require a polyhedron.

        Raises:
            ValueError: if the parsed domain is not a Polyhedron.
        """
        domain = Domain.fromstring(string)
        if not isinstance(domain, Polyhedron):
            raise ValueError('non-polyhedral expression: {!r}'.format(string))
        return domain

    def __repr__(self):
        """Return 'Empty', 'Universe', a single constraint or an And(...)."""
        if self.isempty():
            return 'Empty'
        elif self.isuniverse():
            return 'Universe'
        else:
            strings = []
            for equality in self.equalities:
                strings.append('0 == {}'.format(equality))
            for inequality in self.inequalities:
                strings.append('0 <= {}'.format(inequality))
            if len(strings) == 1:
                return strings[0]
            else:
                return 'And({})'.format(', '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        """Convert *expr* with ``Domain.fromsympy`` and require a polyhedron.

        Raises:
            ValueError: if the converted domain is not a Polyhedron.
        """
        domain = Domain.fromsympy(expr)
        if not isinstance(domain, Polyhedron):
            raise ValueError('non-polyhedral expression: {!r}'.format(expr))
        return domain

    def tosympy(self):
        """Return ``sympy.And`` of ``Eq(e, 0)`` and ``Ge(e, 0)`` constraints."""
        import sympy
        constraints = []
        for equality in self.equalities:
            constraints.append(sympy.Eq(equality.tosympy(), 0))
        for inequality in self.inequalities:
            constraints.append(sympy.Ge(inequality.tosympy(), 0))
        return sympy.And(*constraints)

    @classmethod
    def _sort_polygon_2d(cls, points):
        """Sort 2D *points* by their angle around the mean of the points.

        Sequences of at most three points are returned unchanged.
        """
        # NOTE(review): Vector and Point are not imported in this module; they
        # must be made available here before this method can run.
        if len(points) <= 3:
            return points
        o = sum((Vector(point) for point in points)) / len(points)
        o = Point(o.coordinates())
        angles = {}
        for m in points:
            om = Vector(o, m)
            dx, dy = (coordinate for symbol, coordinate in om.coordinates())
            angles[m] = math.atan2(dy, dx)
        return sorted(points, key=angles.get)

    @classmethod
    def _sort_polygon_3d(cls, points):
        """Sort 3D *points* by signed angle around the mean of the points.

        The rotation axis is the normal of the plane through the mean and the
        first two points, so the points are presumably assumed coplanar.
        Sequences of at most three points are returned unchanged.
        """
        # NOTE(review): Vector and Point are not imported in this module; they
        # must be made available here before this method can run.
        if len(points) <= 3:
            return points
        o = sum((Vector(point) for point in points)) / len(points)
        o = Point(o.coordinates())
        a, b = points[:2]
        oa = Vector(o, a)
        ob = Vector(o, b)
        norm_oa = oa.norm()
        u = (oa.cross(ob)).asunit()
        angles = {a: 0.}
        for m in points[1:]:
            om = Vector(o, m)
            normprod = norm_oa * om.norm()
            cosinus = oa.dot(om) / normprod
            # Floating-point rounding can push the ratio slightly outside
            # [-1, 1], where math.acos raises ValueError.
            cosinus = max(-1., min(1., cosinus))
            sinus = u.dot(oa.cross(om)) / normprod
            angle = math.acos(cosinus)
            angle = math.copysign(angle, sinus)
            angles[m] = angle
        return sorted(points, key=angles.get)

    def plot(self):
        """Draw a 2-dimensional polyhedron with matplotlib and show it.

        Each vertex from ``vertices()`` is used as a mapping from symbol to
        coordinate; coordinates are ordered by ``Symbol.sortkey``.  The axes
        are fixed to ``[-5, 5]``.

        Returns:
            For 2 dimensions, the list of vertex coordinate tuples followed by
            the placeholder closing vertex ``(0.0, 0.0)``; for 3 dimensions,
            ``0`` (3D drawing is not implemented).

        Raises:
            TypeError: if the polyhedron has fewer than 2 or more than 3
                dimensions.
        """
        import matplotlib.pyplot as plt
        from matplotlib.path import Path
        import matplotlib.patches as patches

        dimension = len(self.symbols)
        if dimension == 3:
            return 0
        if dimension != 2:
            raise TypeError('cannot plot a {}-dimensional polyhedron'.format(
                dimension))
        # NOTE(review): vertices are joined in the order vertices() returns
        # them, not sorted around the polygon (cf. _sort_polygon_2d).
        points = []
        for vertex in self.vertices():
            points.append(tuple(vertex.get(symbol)
                for symbol in sorted(vertex, key=Symbol.sortkey)))
        # Path needs one vertex per code; the vertex paired with CLOSEPOLY
        # is ignored by matplotlib, so a placeholder is appended.
        points.append((0.0, 0.0))
        codes = ([Path.MOVETO] + [Path.LINETO] * (len(points) - 2)
            + [Path.CLOSEPOLY])
        path = Path(points, codes)
        fig = plt.figure()
        ax = fig.add_subplot(111)
        patch = patches.PathPatch(path, facecolor='blue', lw=2)
        ax.add_patch(patch)
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)
        plt.show()
        return points
274
275
def _polymorphic(func):
    """Decorate a binary constraint constructor so both operands are linear
    expressions.

    Each operand that is a ``numbers.Rational`` (e.g. ``int`` or
    ``fractions.Fraction``) is converted with ``Rational(...)``; an
    ``Expression`` is passed through unchanged.

    Raises:
        TypeError: (from the wrapper) if an operand is neither.
    """
    def coerce(operand, name):
        # Convert one operand; *name* ('left' or 'right') is used in errors.
        if isinstance(operand, numbers.Rational):
            return Rational(operand)
        if not isinstance(operand, Expression):
            raise TypeError('{} must be a rational number '
                'or a linear expression'.format(name))
        return operand

    @functools.wraps(func)
    def wrapper(left, right):
        left = coerce(left, 'left')
        right = coerce(right, 'right')
        return func(left, right)
    return wrapper
291
@_polymorphic
def Lt(left, right):
    """Return the polyhedron where ``left < right``.

    Encoded as ``0 <= right - left - 1``, i.e. strict inequality over
    integer values.
    """
    difference = right - left
    return Polyhedron([], [difference - 1])
295
@_polymorphic
def Le(left, right):
    """Return the polyhedron where ``left <= right`` (``0 <= right - left``)."""
    slack = right - left
    return Polyhedron(equalities=[], inequalities=[slack])
299
@_polymorphic
def Eq(left, right):
    """Return the polyhedron where ``left == right`` (``0 == left - right``)."""
    difference = left - right
    return Polyhedron(equalities=[difference], inequalities=[])
303
@_polymorphic
def Ne(left, right):
    """Return the complement (``~``) of ``Eq(left, right)``."""
    equal = Eq(left, right)
    return ~equal
307
@_polymorphic
def Gt(left, right):
    """Return the polyhedron where ``left > right``.

    Encoded as ``0 <= left - right - 1``, i.e. strict inequality over
    integer values.
    """
    excess = left - right
    return Polyhedron([], [excess - 1])
311
@_polymorphic
def Ge(left, right):
    """Return the polyhedron where ``left >= right`` (``0 <= left - right``)."""
    slack = left - right
    return Polyhedron(equalities=[], inequalities=[slack])
315
316
# Polyhedron with no points: Eq(1, 0) is the unsatisfiable constraint 1 == 0.
Empty = Eq(1, 0)

# Polyhedron with no constraints at all.
Universe = Polyhedron(equalities=[])