summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDamien George <damien@micropython.org>2021-11-30 00:31:46 +1100
committerDamien George <damien@micropython.org>2021-12-21 18:00:05 +1100
commit2c139bbf4e5724ab253b5b034ce925e04267a9c4 (patch)
tree68837dc2bebd0350d59b7ab50fbf0c4cab275676
parent05bea70979232629e059a7453fb7965545113d9f (diff)
py/mpz: Fix bugs with bitwise of -0 by ensuring all 0's are positive.
This commit makes sure that the value zero is always encoded in an mpz_t as neg=0 and len=0 (previously it was just len=0). This invariant is needed for some of the bitwise operations that operate on negative numbers, because they cannot handle -0. For example (-((1<<100)-(1<<100)))|1 was being computed as -65535, instead of 1. Fixes issue #8042. Signed-off-by: Damien George <damien@micropython.org>
-rw-r--r--py/mpz.c19
-rw-r--r--py/mpz.h3
-rw-r--r--tests/basics/int_big_zeroone.py41
3 files changed, 52 insertions, 11 deletions
diff --git a/py/mpz.c b/py/mpz.c
index 75e1fb1fd..b61997e2f 100644
--- a/py/mpz.c
+++ b/py/mpz.c
@@ -713,6 +713,7 @@ void mpz_set(mpz_t *dest, const mpz_t *src) {
void mpz_set_from_int(mpz_t *z, mp_int_t val) {
if (val == 0) {
+ z->neg = 0;
z->len = 0;
return;
}
@@ -899,10 +900,6 @@ bool mpz_is_even(const mpz_t *z) {
#endif
int mpz_cmp(const mpz_t *z1, const mpz_t *z2) {
- // to catch comparison of -0 with +0
- if (z1->len == 0 && z2->len == 0) {
- return 0;
- }
int cmp = (int)z2->neg - (int)z1->neg;
if (cmp != 0) {
return cmp;
@@ -1052,7 +1049,9 @@ void mpz_neg_inpl(mpz_t *dest, const mpz_t *z) {
if (dest != z) {
mpz_set(dest, z);
}
- dest->neg = 1 - dest->neg;
+ if (dest->len) {
+ dest->neg = 1 - dest->neg;
+ }
}
/* computes dest = ~z (= -z - 1)
@@ -1148,7 +1147,7 @@ void mpz_add_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
}
- dest->neg = lhs->neg;
+ dest->neg = lhs->neg & !!dest->len;
}
/* computes dest = lhs - rhs
@@ -1172,7 +1171,9 @@ void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
}
- if (neg) {
+ if (dest->len == 0) {
+ dest->neg = 0;
+ } else if (neg) {
dest->neg = 1 - lhs->neg;
} else {
dest->neg = lhs->neg;
@@ -1484,14 +1485,16 @@ void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const m
mpz_need_dig(dest_quo, lhs->len + 1); // +1 necessary?
memset(dest_quo->dig, 0, (lhs->len + 1) * sizeof(mpz_dig_t));
+ dest_quo->neg = 0;
dest_quo->len = 0;
mpz_need_dig(dest_rem, lhs->len + 1); // +1 necessary?
mpz_set(dest_rem, lhs);
mpn_div(dest_rem->dig, &dest_rem->len, rhs->dig, rhs->len, dest_quo->dig, &dest_quo->len);
+ dest_rem->neg &= !!dest_rem->len;
// check signs and do Python style modulo
if (lhs->neg != rhs->neg) {
- dest_quo->neg = 1;
+ dest_quo->neg = !!dest_quo->len;
if (!mpz_is_zero(dest_rem)) {
mpz_t mpzone;
mpz_init_from_int(&mpzone, -1);
diff --git a/py/mpz.h b/py/mpz.h
index 425587ee9..d27f57240 100644
--- a/py/mpz.h
+++ b/py/mpz.h
@@ -91,6 +91,7 @@ typedef int8_t mpz_dbl_dig_signed_t;
#define MPZ_NUM_DIG_FOR_LL ((sizeof(long long) * 8 + MPZ_DIG_SIZE - 1) / MPZ_DIG_SIZE)
typedef struct _mpz_t {
+ // Zero has neg=0, len=0. Negative zero is not allowed.
size_t neg : 1;
size_t fixed_dig : 1;
size_t alloc : (8 * sizeof(size_t) - 2);
@@ -119,7 +120,7 @@ static inline bool mpz_is_zero(const mpz_t *z) {
return z->len == 0;
}
static inline bool mpz_is_neg(const mpz_t *z) {
- return z->len != 0 && z->neg != 0;
+ return z->neg != 0;
}
int mpz_cmp(const mpz_t *lhs, const mpz_t *rhs);
diff --git a/tests/basics/int_big_zeroone.py b/tests/basics/int_big_zeroone.py
index 7e0b7a720..81381526b 100644
--- a/tests/basics/int_big_zeroone.py
+++ b/tests/basics/int_big_zeroone.py
@@ -1,4 +1,4 @@
-# test [0,-0,1,-1] edge cases of bignum
+# test [0,1,-1] edge cases of bignum
long_zero = (2**64) >> 65
long_neg_zero = -long_zero
@@ -13,7 +13,7 @@ print([~c for c in cases])
print([c >> 1 for c in cases])
print([c << 1 for c in cases])
-# comparison of 0/-0/+0
+# comparison of 0
print(long_zero == 0)
print(long_neg_zero == 0)
print(long_one - 1 == 0)
@@ -26,3 +26,40 @@ print(long_neg_zero < 1)
print(long_neg_zero < -1)
print(long_neg_zero > 1)
print(long_neg_zero > -1)
+
+# generate zeros that involve negative numbers
+large = 1 << 70
+large_plus_one = large + 1
+zeros = (
+ large - large,
+ -large + large,
+ large + -large,
+ -(large - large),
+ large - large_plus_one + 1,
+ -large & (large - large),
+ -large ^ -large,
+ -large * (large - large),
+ (large - large) // -large,
+ -large // -large_plus_one,
+ -(large + large) % large,
+ (large + large) % -large,
+ -(large + large) % -large,
+)
+print(zeros)
+
+# compute arithmetic operations that may have problems with -0
+# (this checks that -0 is never generated in the zeros tuple)
+cases = (0, 1, -1) + zeros
+for lhs in cases:
+ print("-{} = {}".format(lhs, -lhs))
+ print("~{} = {}".format(lhs, ~lhs))
+ print("{} >> 1 = {}".format(lhs, lhs >> 1))
+ print("{} << 1 = {}".format(lhs, lhs << 1))
+ for rhs in cases:
+ print("{} == {} = {}".format(lhs, rhs, lhs == rhs))
+ print("{} + {} = {}".format(lhs, rhs, lhs + rhs))
+ print("{} - {} = {}".format(lhs, rhs, lhs - rhs))
+ print("{} * {} = {}".format(lhs, rhs, lhs * rhs))
+ print("{} | {} = {}".format(lhs, rhs, lhs | rhs))
+ print("{} & {} = {}".format(lhs, rhs, lhs & rhs))
+ print("{} ^ {} = {}".format(lhs, rhs, lhs ^ rhs))