6cd1ff432127b2a30e1bbae759f1140cbb328434
4 from fractions
import Fraction
6 from pypol
.linear
import *
11 def _requires_sympy(func
):
12 @functools.wraps(func
)
17 def _requires_sympy(func
):
18 @functools.wraps(func
)
20 raise unittest
.SkipTest('SymPy is not available')
24 class TestExpression(unittest
.TestCase
):
27 self
.x
= Expression({'x': 1})
28 self
.y
= Expression({'y': 1})
29 self
.z
= Expression({'z': 1})
30 self
.zero
= Expression(constant
=0)
31 self
.one
= Expression(constant
=1)
32 self
.pi
= Expression(constant
=Fraction(22, 7))
33 self
.expr
= self
.x
- 2*self
.y
+ 3
def test_new_subclass(self):
    """Check which class ``Expression(...)`` hands back for the fixtures.

    The ``x`` fixture must be a ``Symbol`` and the ``pi`` fixture a
    ``Constant``; their sum, and a lone symbol with coefficient 2, must
    be neither.
    """
    symbol_x, const_pi = self.x, self.pi
    self.assertIsInstance(symbol_x, Symbol)
    self.assertIsInstance(const_pi, Constant)
    # The sum mixes a symbol and a constant term, so it is neither subclass.
    self.assertNotIsInstance(symbol_x + const_pi, Symbol)
    self.assertNotIsInstance(symbol_x + const_pi, Constant)
    # Same single symbol as the ``x`` fixture, but with coefficient 2.
    doubled_x = Expression({'x': 2})
    self.assertNotIsInstance(doubled_x, Symbol)
43 def test_new_types(self
):
44 with self
.assertRaises(TypeError):
45 Expression('x + y', 2)
46 self
.assertEqual(Expression({'x': 2}), Expression({self
.x
: 2}))
47 with self
.assertRaises(TypeError):
49 with self
.assertRaises(TypeError):
50 Expression({'x': '2'})
51 self
.assertEqual(Expression(constant
=1), Expression(constant
=self
.one
))
52 with self
.assertRaises(TypeError):
53 Expression(constant
='1')
def test_symbols(self):
    """Check ``symbols`` on the symbol, constant and mixed fixtures."""
    same_items = self.assertCountEqual  # order-insensitive comparison
    same_items(self.x.symbols, ['x'])
    same_items(self.pi.symbols, [])
    same_items(self.expr.symbols, ['x', 'y'])
def test_dimension(self):
    """Check ``dimension`` is 1 for ``x``, 0 for ``pi`` and 2 for ``expr``."""
    equal = self.assertEqual
    equal(self.x.dimension, 1)
    equal(self.pi.dimension, 0)
    equal(self.expr.dimension, 2)
def test_coefficient(self):
    """Exercise ``coefficient()`` with valid and invalid arguments.

    Lookup is tried by name and by symbol object, an absent symbol is
    expected to yield 0, and an integer or a compound expression as the
    argument is expected to raise ``TypeError``.
    """
    expression = self.expr
    self.assertEqual(expression.coefficient('x'), 1)
    self.assertEqual(expression.coefficient('y'), -2)
    self.assertEqual(expression.coefficient(self.y), -2)
    self.assertEqual(expression.coefficient('z'), 0)
    with self.assertRaises(TypeError):
        expression.coefficient(0)
    with self.assertRaises(TypeError):
        expression.coefficient(expression)
75 def test_getitem(self
):
76 self
.assertEqual(self
.expr
['x'], 1)
77 self
.assertEqual(self
.expr
['y'], -2)
78 self
.assertEqual(self
.expr
[self
.y
], -2)
79 self
.assertEqual(self
.expr
['z'], 0)
80 with self
.assertRaises(TypeError):
82 with self
.assertRaises(TypeError):
def test_coefficients(self):
    """Check ``coefficients()`` yields the (name, value) pairs of ``expr``."""
    expected_pairs = [('x', 1), ('y', -2)]
    # Order is not significant, hence assertCountEqual.
    self.assertCountEqual(self.expr.coefficients(), expected_pairs)
def test_constant(self):
    """Check the ``constant`` attribute of the ``x``, ``pi`` and ``expr`` fixtures."""
    equal = self.assertEqual
    equal(self.x.constant, 0)
    equal(self.pi.constant, Fraction(22, 7))
    equal(self.expr.constant, 3)
def test_isconstant(self):
    """Check ``isconstant()`` is true only for the ``pi`` fixture."""
    x_flag = self.x.isconstant()
    self.assertFalse(x_flag)
    pi_flag = self.pi.isconstant()
    self.assertTrue(pi_flag)
    expr_flag = self.expr.isconstant()
    self.assertFalse(expr_flag)
def test_values(self):
    """Check ``values()`` of ``expr`` holds 1, -2 and 3, in any order."""
    expected_values = [1, -2, 3]
    self.assertCountEqual(self.expr.values(), expected_values)
def test_issymbol(self):
    """Check ``issymbol()`` is true only for the ``x`` fixture."""
    x_flag = self.x.issymbol()
    self.assertTrue(x_flag)
    pi_flag = self.pi.issymbol()
    self.assertFalse(pi_flag)
    expr_flag = self.expr.issymbol()
    self.assertFalse(expr_flag)
107 self
.assertTrue(self
.x
)
108 self
.assertFalse(self
.zero
)
109 self
.assertTrue(self
.pi
)
110 self
.assertTrue(self
.expr
)
113 self
.assertEqual(+self
.expr
, self
.expr
)
116 self
.assertEqual(-self
.expr
, -self
.x
+ 2*self
.y
- 3)
119 self
.assertEqual(self
.x
+ Fraction(22, 7), self
.x
+ self
.pi
)
120 self
.assertEqual(Fraction(22, 7) + self
.x
, self
.x
+ self
.pi
)
121 self
.assertEqual(self
.x
+ self
.x
, 2 * self
.x
)
122 self
.assertEqual(self
.expr
+ 2*self
.y
, self
.x
+ 3)
125 self
.assertEqual(self
.x
- self
.x
, 0)
126 self
.assertEqual(self
.expr
- 3, self
.x
- 2*self
.y
)
127 self
.assertEqual(0 - self
.x
, -self
.x
)
130 self
.assertEqual(self
.pi
* 7, 22)
131 self
.assertEqual(self
.expr
* 0, 0)
132 self
.assertEqual(0 * self
.expr
, 0)
133 self
.assertEqual(self
.expr
* 2, 2*self
.x
- 4*self
.y
+ 6)
136 with self
.assertRaises(ZeroDivisionError):
138 self
.assertEqual(self
.expr
/ 2, self
.x
/ 2 - self
.y
+ Fraction(3, 2))
141 self
.assertEqual(str(Expression()), '0')
142 self
.assertEqual(str(self
.x
), 'x')
143 self
.assertEqual(str(-self
.x
), '-x')
144 self
.assertEqual(str(self
.pi
), '22/7')
145 self
.assertEqual(str(self
.expr
), 'x - 2*y + 3')
148 self
.assertEqual(repr(self
.x
), "Symbol('x')")
149 self
.assertEqual(repr(self
.one
), 'Constant(1)')
150 self
.assertEqual(repr(self
.pi
), 'Constant(22, 7)')
151 self
.assertEqual(repr(self
.x
+ self
.one
), "Expression('x + 1')")
152 self
.assertEqual(repr(self
.expr
), "Expression('x - 2*y + 3')")
def test_fromstring(self):
    """Check ``Expression.fromstring`` against the fixtures.

    The last three inputs spell the same expression with implicit,
    parenthesised and explicit multiplication; all must equal ``expr``.
    """
    parse = Expression.fromstring
    self.assertEqual(parse('x'), self.x)
    self.assertEqual(parse('-x'), -self.x)
    self.assertEqual(parse('22/7'), self.pi)
    for text in ('x - 2y + 3', 'x - (3-1)y + 3', 'x - 2*y + 3'):
        self.assertEqual(parse(text), self.expr)
163 self
.assertEqual(self
.expr
, self
.expr
)
164 self
.assertNotEqual(self
.x
, self
.y
)
165 self
.assertEqual(self
.zero
, 0)
def test__toint(self):
    """Check ``_toint()`` maps x + y/2 + z/3 to 6*x + 3*y + 2*z."""
    fractional = self.x + self.y/2 + self.z/3
    scaled = fractional._toint()
    expected = 6*self.x + 3*self.y + 2*self.z
    self.assertEqual(scaled, expected)
def test_fromsympy(self):
    """Check ``Expression.fromsympy`` on a symbol, a rational and a linear sum.

    A product of two SymPy symbols is expected to raise ``ValueError``.
    """
    sym_x, sym_y = sympy.symbols('x y')
    convert = Expression.fromsympy
    self.assertEqual(convert(sym_x), self.x)
    self.assertEqual(convert(sympy.Rational(22, 7)), self.pi)
    self.assertEqual(convert(sym_x - 2*sym_y + 3), self.expr)
    with self.assertRaises(ValueError):
        convert(sym_x*sym_y)
def test_tosympy(self):
    """Check ``tosympy()`` of each fixture equals its SymPy counterpart."""
    sym_x, sym_y = sympy.symbols('x y')
    equal = self.assertEqual
    equal(self.x.tosympy(), sym_x)
    equal(self.pi.tosympy(), sympy.Rational(22, 7))
    equal(self.expr.tosympy(), sym_x - 2*sym_y + 3)
188 class TestConstant(unittest
.TestCase
):
191 self
.zero
= Constant(0)
192 self
.one
= Constant(1)
193 self
.pi
= Constant(Fraction(22, 7))
def test_fromsympy(self):
    """Check ``Constant.fromsympy`` on a rational and on a symbol.

    The rational 22/7 must equal the ``pi`` fixture; a SymPy symbol is
    expected to raise ``TypeError``.
    """
    rational = sympy.Rational(22, 7)
    self.assertEqual(Constant.fromsympy(rational), self.pi)
    with self.assertRaises(TypeError):
        Constant.fromsympy(sympy.Symbol('x'))
202 class TestSymbol(unittest
.TestCase
):
209 self
.assertEqual(self
.x
.name
, 'x')
def test_symbols(self):
    """Check ``symbols()`` with three input forms.

    A space-separated string, a comma-separated string and a list of
    names must each produce the ``x`` and ``y`` fixtures in that order.
    """
    for spec in ('x y', 'x,y', ['x', 'y']):
        produced = list(symbols(spec))
        self.assertListEqual(produced, [self.x, self.y])
def test_fromsympy(self):
    """Check ``Symbol.fromsympy`` accepts a bare symbol and rejects the rest.

    A rational, a scaled symbol and a squared symbol are each expected
    to raise ``TypeError``.
    """
    sym_x = sympy.Symbol('x')
    convert = Symbol.fromsympy
    self.assertEqual(convert(sym_x), self.x)
    with self.assertRaises(TypeError):
        convert(sympy.Rational(22, 7))
    with self.assertRaises(TypeError):
        convert(2 * sym_x)
    with self.assertRaises(TypeError):
        convert(sym_x*sym_x)
228 class TestOperators(unittest
.TestCase
):
233 class TestPolyhedron(unittest
.TestCase
):
236 x
, y
= symbols('x y')
237 self
.square
= Polyhedron(inequalities
=[x
, 1 - x
, y
, 1 - y
])
def test_symbols(self):
    """Check the ``square`` fixture reports symbols ``x`` and ``y``."""
    expected_names = ['x', 'y']
    # Order is not significant, hence assertCountEqual.
    self.assertCountEqual(self.square.symbols, expected_names)
def test_dimension(self):
    """Check the ``square`` fixture has dimension 2."""
    measured = self.square.dimension
    self.assertEqual(measured, 2)
246 self
.assertEqual(str(self
.square
),
247 'x >= 0, -x + 1 >= 0, y >= 0, -y + 1 >= 0')
250 self
.assertEqual(repr(self
.square
),
251 "Polyhedron('x >= 0, -x + 1 >= 0, y >= 0, -y + 1 >= 0')")
def test_fromstring(self):
    """Check ``Polyhedron.fromstring`` rebuilds the ``square`` fixture."""
    text = '{x >= 0, -x + 1 >= 0, y >= 0, -y + 1 >= 0}'
    parsed = Polyhedron.fromstring(text)
    self.assertEqual(parsed, self.square)
def test_isempty(self):
    """Check the ``square`` fixture is not reported as empty."""
    empty_flag = self.square.isempty()
    self.assertFalse(empty_flag)
def test_isuniverse(self):
    """Check the ``square`` fixture is not reported as the universe."""
    universe_flag = self.square.isuniverse()
    self.assertFalse(universe_flag)
263 @unittest.expectedFailure
def test_fromsympy(self):
    """Check ``Polyhedron.fromsympy`` on a SymPy conjunction of bounds.

    The conjunction 0 <= x <= 1 and 0 <= y <= 1 is compared against the
    ``square`` fixture.
    """
    sym_x, sym_y = sympy.symbols('x y')
    # Combined left to right, exactly as a chained ``a & b & c & d``.
    constraint = (sym_x >= 0) & (sym_x <= 1)
    constraint = constraint & (sym_y >= 0)
    constraint = constraint & (sym_y <= 1)
    self.assertEqual(Polyhedron.fromsympy(constraint), self.square)
def test_tosympy(self):
    """Check ``tosympy()`` of ``square`` equals a SymPy ``And`` of its four bounds."""
    sym_x, sym_y = sympy.symbols('x y')
    converted = self.square.tosympy()
    expected = sympy.And(-sym_x + 1 >= 0, -sym_y + 1 >= 0,
                         sym_x >= 0, sym_y >= 0)
    self.assertEqual(converted, expected)
280 self
.assertEqual(repr(Empty
), 'Empty')
def test_isempty(self):
    """Check ``Empty`` reports itself as empty."""
    empty_flag = Empty.isempty()
    self.assertTrue(empty_flag)
def test_isuniverse(self):
    """Check ``Empty`` does not report itself as the universe."""
    universe_flag = Empty.isuniverse()
    self.assertFalse(universe_flag)
292 self
.assertEqual(repr(Universe
), 'Universe')
def test_isempty(self):
    """Check ``Universe`` does not report itself as empty.

    The previous assertion used ``assertTrue``, which demanded that the
    universe be empty. That contradicts the neighbouring expectation
    that ``Universe.isuniverse()`` is true, and it is the mirror image
    of the ``Empty`` checks (``isempty()`` true, ``isuniverse()``
    false). The universe places no constraint on any point, so it
    cannot be empty.
    """
    self.assertFalse(Universe.isempty())
def test_isuniverse(self):
    """Check ``Universe`` reports itself as the universe."""
    universe_flag = Universe.isuniverse()
    self.assertTrue(universe_flag)