From 3156820e058cd410078a4c39d5a2a85f6b714c07 Mon Sep 17 00:00:00 2001 From: Sam Hughes Date: Mon, 19 Jun 2017 20:27:15 -0700 Subject: [PATCH] Make BIGNUM functions in geo code avoid removed SSL API's - BIGNUM functions now avoid accessing ->d and ->top - Allocates EVP_MD_CTX on the heap - Adds the unittest BignumTest --- src/client_protocol/server.cc | 9 +- .../geo/s2/util/math/exactfloat/exactfloat.cc | 130 +++++++++--------- .../geo/s2/util/math/exactfloat/exactfloat.h | 27 ++-- src/unittest/bignum_test.cc | 59 ++++++++ 4 files changed, 147 insertions(+), 78 deletions(-) create mode 100644 src/unittest/bignum_test.cc diff --git a/src/client_protocol/server.cc b/src/client_protocol/server.cc index 4b9c4299f06..ccce61ca297 100644 --- a/src/client_protocol/server.cc +++ b/src/client_protocol/server.cc @@ -174,12 +174,13 @@ void http_conn_cache_t::on_ring() { } size_t http_conn_cache_t::sha_hasher_t::operator()(const conn_key_t &x) const { - EVP_MD_CTX c; - EVP_DigestInit(&c, EVP_sha256()); - EVP_DigestUpdate(&c, x.data(), x.size()); + EVP_MD_CTX *c = EVP_MD_CTX_create(); + EVP_DigestInit(c, EVP_sha256()); + EVP_DigestUpdate(c, x.data(), x.size()); unsigned char digest[EVP_MAX_MD_SIZE]; unsigned int digest_size = 0; - EVP_DigestFinal(&c, digest, &digest_size); + EVP_DigestFinal(c, digest, &digest_size); + EVP_MD_CTX_destroy(c); rassert(digest_size >= sizeof(size_t)); size_t res = 0; memcpy(&res, digest, std::min(sizeof(size_t), static_cast(digest_size))); diff --git a/src/rdb_protocol/geo/s2/util/math/exactfloat/exactfloat.cc b/src/rdb_protocol/geo/s2/util/math/exactfloat/exactfloat.cc index 1f3c5c7a85f..383d040f55c 100644 --- a/src/rdb_protocol/geo/s2/util/math/exactfloat/exactfloat.cc +++ b/src/rdb_protocol/geo/s2/util/math/exactfloat/exactfloat.cc @@ -60,7 +60,7 @@ bool bn_is_negative_func(const BIGNUM* bn) { } // Set a BIGNUM to the given unsigned 64-bit value. -inline static void BN_ext_set_uint64(BIGNUM* bn, uint64 v) { +void BN_ext_set_uint64(BIGNUM* bn, uint64 v) { #if BN_BITS2 == 64 CHECK(BN_set_word(bn, v)); #else @@ -73,39 +73,38 @@ inline static void BN_ext_set_uint64(BIGNUM* bn, uint64 v) { // Return the absolute value of a BIGNUM as a 64-bit unsigned integer. // Requires that BIGNUM fits into 64 bits. -inline static uint64 BN_ext_get_uint64(const BIGNUM* bn) { - DCHECK_LE(BN_num_bytes(bn), static_cast(sizeof(uint64))); -#if BN_BITS2 == 64 - return BN_get_word(bn); -#else - COMPILE_ASSERT(BN_BITS2 == 32, at_least_32_bit_openssl_build_needed); - if (bn->top == 0) return 0; - if (bn->top == 1) return BN_get_word(bn); - DCHECK_EQ(bn->top, 2); - return (static_cast(bn->d[1]) << 32) + bn->d[0]; -#endif +uint64 BN_ext_get_uint64(const BIGNUM* bn) { + int num_bytes = BN_num_bytes(bn); + DCHECK_LE(num_bytes, 8); + + std::unique_ptr buf(new unsigned char[num_bytes]); + int res = BN_bn2bin(bn, buf.get()); + DCHECK_EQ(num_bytes, res); + + uint64_t ret = 0; + for (int i = 0; i < res; ++i) { + ret = ret * 256; + ret += buf[i]; + } + + return ret; } // Count the number of low-order zero bits in the given BIGNUM (ignoring its // sign). Returns 0 if the argument is zero. -static int BN_ext_count_low_zero_bits(const BIGNUM* bn) { - int count = 0; - for (int i = 0; i < bn->top; ++i) { - BN_ULONG w = bn->d[i]; - if (w == 0) { - count += 8 * sizeof(BN_ULONG); - } else { - for (; (w & 1) == 0; w >>= 1) { - ++count; +int BN_ext_count_low_zero_bits(const BIGNUM* bn) { + int num_bytes = BN_num_bytes(bn); + for (int i = 0; i < num_bytes * 8; ++i) { + if (BN_is_bit_set(bn, i)) { + return i; } - break; - } } - return count; + // They're all zero. What now? Just going to treat this as if BN_num_bytes had been + // zero in any case. + return 0; } -ExactFloat::ExactFloat(double v) { - BN_init(&bn_); +ExactFloat::ExactFloat(double v) : bn_(make_BN_new()) { sign_ = signbit(v) ? -1 : 1; if (std::isnan(v)) { set_nan(); @@ -121,27 +120,26 @@ ExactFloat::ExactFloat(double v) { int expl; double f = frexp(fabs(v), &expl); uint64 m = static_cast(ldexp(f, kDoubleMantissaBits)); - BN_ext_set_uint64(&bn_, m); + BN_ext_set_uint64(bn_.get(), m); bn_exp_ = expl - kDoubleMantissaBits; Canonicalize(); } } -ExactFloat::ExactFloat(int v) { - BN_init(&bn_); +ExactFloat::ExactFloat(int v) : bn_(make_BN_new()) { sign_ = (v >= 0) ? 1 : -1; // Note that this works even for INT_MIN because the parameter type for // BN_set_word() is unsigned. - CHECK(BN_set_word(&bn_, abs(v))); + CHECK(BN_set_word(bn_.get(), abs(v))); bn_exp_ = 0; Canonicalize(); } ExactFloat::ExactFloat(const ExactFloat& b) : sign_(b.sign_), - bn_exp_(b.bn_exp_) { - BN_init(&bn_); - BN_copy(&bn_, &b.bn_); + bn_exp_(b.bn_exp_), + bn_(make_BN_new()) { + BN_copy(bn_.get(), b.bn_.get()); } ExactFloat ExactFloat::SignedZero(int sign) { @@ -163,30 +161,30 @@ ExactFloat ExactFloat::NaN() { } int ExactFloat::prec() const { - return BN_num_bits(&bn_); + return BN_num_bits(bn_.get()); } int ExactFloat::exp() const { DCHECK(is_normal()); - return bn_exp_ + BN_num_bits(&bn_); + return bn_exp_ + BN_num_bits(bn_.get()); } void ExactFloat::set_zero(int sign) { sign_ = sign; bn_exp_ = kExpZero; - if (!BN_is_zero(&bn_)) BN_zero(&bn_); + if (!BN_is_zero(bn_.get())) BN_zero(bn_.get()); } void ExactFloat::set_inf(int sign) { sign_ = sign; bn_exp_ = kExpInfinity; - if (!BN_is_zero(&bn_)) BN_zero(&bn_); + if (!BN_is_zero(bn_.get())) BN_zero(bn_.get()); } void ExactFloat::set_nan() { sign_ = 1; bn_exp_ = kExpNaN; - if (!BN_is_zero(&bn_)) BN_zero(&bn_); + if (!BN_is_zero(bn_.get())) BN_zero(bn_.get()); } double ExactFloat::ToDouble() const { @@ -200,13 +198,13 @@ double ExactFloat::ToDouble() const { } double ExactFloat::ToDoubleHelper() const { - DCHECK_LE(BN_num_bits(&bn_), kDoubleMantissaBits); + DCHECK_LE(BN_num_bits(bn_.get()), kDoubleMantissaBits); if (!is_normal()) { if (is_zero()) return copysign(0, sign_); if (is_inf()) return copysign(INFINITY, sign_); return copysign(NAN, sign_); } - uint64 d_mantissa = BN_ext_get_uint64(&bn_); + uint64 d_mantissa = BN_ext_get_uint64(bn_.get()); // We rely on ldexp() to handle overflow and underflow. (It will return a // signed zero or infinity if the result is too small or too large.) return sign_ * ldexp(static_cast(d_mantissa), bn_exp_); @@ -257,11 +255,11 @@ ExactFloat ExactFloat::RoundToPowerOf2(int bit_exp, RoundingMode mode) const { // Never increment. } else if (mode == kRoundTiesAwayFromZero) { // Increment if the highest discarded bit is 1. - if (BN_is_bit_set(&bn_, shift - 1)) + if (BN_is_bit_set(bn_.get(), shift - 1)) increment = true; } else if (mode == kRoundAwayFromZero) { // Increment unless all discarded bits are zero. - if (BN_ext_count_low_zero_bits(&bn_) < shift) + if (BN_ext_count_low_zero_bits(bn_.get()) < shift) increment = true; } else { DCHECK_EQ(mode, kRoundTiesToEven); @@ -271,16 +269,16 @@ ExactFloat ExactFloat::RoundToPowerOf2(int bit_exp, RoundingMode mode) const { // 0/10* -> Don't increment (fraction = 1/2, kept part even) // 1/10* -> Increment (fraction = 1/2, kept part odd) // ./1.*1.* -> Increment (fraction > 1/2) - if (BN_is_bit_set(&bn_, shift - 1) && - ((BN_is_bit_set(&bn_, shift) || - BN_ext_count_low_zero_bits(&bn_) < shift - 1))) { + if (BN_is_bit_set(bn_.get(), shift - 1) && + ((BN_is_bit_set(bn_.get(), shift) || + BN_ext_count_low_zero_bits(bn_.get()) < shift - 1))) { increment = true; } } r.bn_exp_ = bn_exp_ + shift; - CHECK(BN_rshift(&r.bn_, &bn_, shift)); + CHECK(BN_rshift(r.bn_.get(), bn_.get(), shift)); if (increment) { - CHECK(BN_add_word(&r.bn_, 1)); + CHECK(BN_add_word(r.bn_.get(), 1)); } r.sign_ = sign_; r.Canonicalize(); @@ -387,7 +385,7 @@ int ExactFloat::GetDecimalDigits(int max_digits, std::string* digits) const { int bn_exp10; if (bn_exp_ >= 0) { // The easy case: bn = bn_ * (2 ** bn_exp_)), bn_exp10 = 0. - CHECK(BN_lshift(bn, &bn_, bn_exp_)); + CHECK(BN_lshift(bn, bn_.get(), bn_exp_)); bn_exp10 = 0; } else { // Set bn = bn_ * (5 ** -bn_exp_) and bn_exp10 = bn_exp_. This is @@ -397,7 +395,7 @@ int ExactFloat::GetDecimalDigits(int max_digits, std::string* digits) const { CHECK(BN_set_word(bn, 5)); BN_CTX* ctx = BN_CTX_new(); CHECK(BN_exp(bn, bn, power, ctx)); - CHECK(BN_mul(bn, bn, &bn_, ctx)); + CHECK(BN_mul(bn, bn, bn_.get(), ctx)); BN_CTX_free(ctx); BN_free(power); bn_exp10 = bn_exp_; @@ -453,7 +451,7 @@ ExactFloat& ExactFloat::operator=(const ExactFloat& b) { if (this != &b) { sign_ = b.sign_; bn_exp_ = b.bn_exp_; - BN_copy(&bn_, &b.bn_); + BN_copy(bn_.get(), b.bn_.get()); } return *this; } @@ -500,24 +498,24 @@ ExactFloat ExactFloat::SignedSum(int a_sign, const ExactFloat* a, // Shift "a" if necessary so that both values have the same bn_exp_. ExactFloat r; if (a->bn_exp_ > b->bn_exp_) { - CHECK(BN_lshift(&r.bn_, &a->bn_, a->bn_exp_ - b->bn_exp_)); + CHECK(BN_lshift(r.bn_.get(), a->bn_.get(), a->bn_exp_ - b->bn_exp_)); a = &r; // The only field of "a" used below is bn_. } r.bn_exp_ = b->bn_exp_; if (a_sign == b_sign) { - CHECK(BN_add(&r.bn_, &a->bn_, &b->bn_)); + CHECK(BN_add(r.bn_.get(), a->bn_.get(), b->bn_.get())); r.sign_ = a_sign; } else { // Note that the BIGNUM documentation is out of date -- all methods now // allow the result to be the same as any input argument, so it is okay if // (a == &r) due to the shift above. - CHECK(BN_sub(&r.bn_, &a->bn_, &b->bn_)); - if (bn_is_zero_func(&r.bn_)) { + CHECK(BN_sub(r.bn_.get(), a->bn_.get(), b->bn_.get())); + if (bn_is_zero_func(r.bn_.get())) { r.sign_ = +1; - } else if (bn_is_negative_func(&r.bn_)) { + } else if (bn_is_negative_func(r.bn_.get())) { // The magnitude of "b" was larger. r.sign_ = b_sign; - BN_set_negative(&r.bn_, false); + BN_set_negative(r.bn_.get(), false); } else { // They were equal, or the magnitude of "a" was larger. r.sign_ = a_sign; @@ -533,16 +531,16 @@ void ExactFloat::Canonicalize() { // Underflow/overflow occurs if exp() is not in [kMinExp, kMaxExp]. // We also convert a zero mantissa to signed zero. int my_exp = exp(); - if (my_exp < kMinExp || BN_is_zero(&bn_)) { + if (my_exp < kMinExp || BN_is_zero(bn_.get())) { set_zero(sign_); } else if (my_exp > kMaxExp) { set_inf(sign_); - } else if (!BN_is_odd(&bn_)) { + } else if (!BN_is_odd(bn_.get())) { // Remove any low-order zero bits from the mantissa. - DCHECK(!BN_is_zero(&bn_)); - int shift = BN_ext_count_low_zero_bits(&bn_); + DCHECK(!BN_is_zero(bn_.get())); + int shift = BN_ext_count_low_zero_bits(bn_.get()); if (shift > 0) { - CHECK(BN_rshift(&bn_, &bn_, shift)); + CHECK(BN_rshift(bn_.get(), bn_.get(), shift)); bn_exp_ += shift; } } @@ -575,7 +573,7 @@ ExactFloat operator*(const ExactFloat& a, const ExactFloat& b) { r.sign_ = result_sign; r.bn_exp_ = a.bn_exp_ + b.bn_exp_; BN_CTX* ctx = BN_CTX_new(); - CHECK(BN_mul(&r.bn_, &a.bn_, &b.bn_, ctx)); + CHECK(BN_mul(r.bn_.get(), a.bn_.get(), b.bn_.get(), ctx)); BN_CTX_free(ctx); r.Canonicalize(); return r; @@ -594,14 +592,14 @@ bool operator==(const ExactFloat& a, const ExactFloat& b) { // Otherwise, the signs and mantissas must match. Note that non-normal // values such as infinity have a mantissa of zero. - return a.sign_ == b.sign_ && BN_ucmp(&a.bn_, &b.bn_) == 0; + return a.sign_ == b.sign_ && BN_ucmp(a.bn_.get(), b.bn_.get()) == 0; } int ExactFloat::ScaleAndCompare(const ExactFloat& b) const { DCHECK(is_normal() && b.is_normal() && bn_exp_ >= b.bn_exp_); ExactFloat tmp = *this; - CHECK(BN_lshift(&tmp.bn_, &tmp.bn_, bn_exp_ - b.bn_exp_)); - return BN_ucmp(&tmp.bn_, &b.bn_); + CHECK(BN_lshift(tmp.bn_.get(), tmp.bn_.get(), bn_exp_ - b.bn_exp_)); + return BN_ucmp(tmp.bn_.get(), b.bn_.get()); } bool ExactFloat::UnsignedLess(const ExactFloat& b) const { @@ -692,7 +690,7 @@ T ExactFloat::ToInteger(RoundingMode mode) const { if (!r.is_inf()) { // If the unsigned value has more than 63 bits it is always clamped. if (r.exp() < 64) { - int64 value = BN_ext_get_uint64(&r.bn_) << r.bn_exp_; + int64 value = BN_ext_get_uint64(r.bn_.get()) << r.bn_exp_; if (r.sign_ < 0) value = -value; return max(kMinValue, min(kMaxValue, value)); } diff --git a/src/rdb_protocol/geo/s2/util/math/exactfloat/exactfloat.h b/src/rdb_protocol/geo/s2/util/math/exactfloat/exactfloat.h index b91290799c6..2c472778fda 100644 --- a/src/rdb_protocol/geo/s2/util/math/exactfloat/exactfloat.h +++ b/src/rdb_protocol/geo/s2/util/math/exactfloat/exactfloat.h @@ -96,8 +96,9 @@ #include #include -#include +#include +#include #include #include @@ -107,6 +108,10 @@ namespace geo { +struct BIGNUM_deleter { + void operator()(BIGNUM *b) { BN_free(b); } +}; + class ExactFloat { public: // The following limits are imposed by OpenSSL. @@ -499,7 +504,7 @@ class ExactFloat { // - bn_exp_ is the base-2 exponent applied to bn_. int32 sign_; int32 bn_exp_; - BIGNUM bn_; + std::unique_ptr bn_; // A standard IEEE "double" has a 53-bit mantissa consisting of a 52-bit // fraction plus an implicit leading "1" bit. @@ -555,16 +560,18 @@ class ExactFloat { static ExactFloat Unimplemented(); }; +inline std::unique_ptr make_BN_new() { + std::unique_ptr ret(BN_new(), BIGNUM_deleter{}); + guarantee(ret.get() != nullptr); + return ret; +} + ///////////////////////////////////////////////////////////////////////// // Implementation details follow: -inline ExactFloat::ExactFloat() : sign_(1), bn_exp_(kExpZero) { - BN_init(&bn_); -} +inline ExactFloat::ExactFloat() : sign_(1), bn_exp_(kExpZero), bn_(make_BN_new()) {} -inline ExactFloat::~ExactFloat() { - BN_free(&bn_); -} +inline ExactFloat::~ExactFloat() {} inline bool ExactFloat::is_zero() const { return bn_exp_ == kExpZero; } inline bool ExactFloat::is_inf() const { return bn_exp_ == kExpInfinity; } @@ -601,6 +608,10 @@ inline ExactFloat ExactFloat::CopyWithSign(int sign) const { return r; } +void BN_ext_set_uint64(BIGNUM *bn, uint64 v); +uint64 BN_ext_get_uint64(const BIGNUM *bn); +int BN_ext_count_low_zero_bits(const BIGNUM *bn); + } // namespace geo #endif // UTIL_MATH_EXACTFLOAT_EXACTFLOAT_H_ diff --git a/src/unittest/bignum_test.cc b/src/unittest/bignum_test.cc new file mode 100644 index 00000000000..2aba6456103 --- /dev/null +++ b/src/unittest/bignum_test.cc @@ -0,0 +1,59 @@ +#include "unittest/gtest.hpp" + +#include "rdb_protocol/geo/s2/util/math/exactfloat/exactfloat.h" + +namespace unittest { + +TEST(BignumTest, TestGetUint64) { + std::unique_ptr bn = geo::make_BN_new(); + + uint64_t x = geo::BN_ext_get_uint64(bn.get()); + ASSERT_EQ(0, x); + + uint64_t values[4] = { 1, (1ull << 32) - 1, (1ull << 32), uint64_t(-1) }; + for (int i = 0; i < 4; ++i) { + geo::BN_ext_set_uint64(bn.get(), values[i]); + ASSERT_EQ(values[i], geo::BN_ext_get_uint64(bn.get())); + } +} + +TEST(BignumTest, TestCountLowZeroBits) { + std::unique_ptr bn = geo::make_BN_new(); + int n = geo::BN_ext_count_low_zero_bits(bn.get()); + ASSERT_EQ(0, n); + + geo::BN_ext_set_uint64(bn.get(), 1); + n = geo::BN_ext_count_low_zero_bits(bn.get()); + ASSERT_EQ(0, n); + + geo::BN_ext_set_uint64(bn.get(), 2); + n = geo::BN_ext_count_low_zero_bits(bn.get()); + ASSERT_EQ(1, n); + + geo::BN_ext_set_uint64(bn.get(), 128); + n = geo::BN_ext_count_low_zero_bits(bn.get()); + ASSERT_EQ(7, n); + + geo::BN_ext_set_uint64(bn.get(), 256); + n = geo::BN_ext_count_low_zero_bits(bn.get()); + ASSERT_EQ(8, n); + + geo::BN_ext_set_uint64(bn.get(), 257); + n = geo::BN_ext_count_low_zero_bits(bn.get()); + ASSERT_EQ(0, n); + + for (int i = 0; i < 64; i++) { + geo::BN_ext_set_uint64(bn.get(), (1ull << i)); + n = geo::BN_ext_count_low_zero_bits(bn.get()); + ASSERT_EQ(i, n); + } + + geo::BN_ext_set_uint64(bn.get(), (1ull << 63)); + BN_CTX *ctx = BN_CTX_new(); + BN_mul(bn.get(), bn.get(), bn.get(), ctx); + BN_CTX_free(ctx); + n = geo::BN_ext_count_low_zero_bits(bn.get()); + ASSERT_EQ(126, n); +} + +} // namespace unittest