shithub: tlsclient

ref: 009439541d2c6e8af2596f8fb1b4df85861fd212
dir: /third_party/boringssl/src/crypto/mldsa/mldsa_test.cc/

View raw version
/* Copyright (c) 2024, Google LLC
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */

#include <openssl/mldsa.h>

#include <memory>
#include <vector>

#include <gtest/gtest.h>

#include <openssl/bytestring.h>
#include <openssl/mem.h>
#include <openssl/span.h>

#include "../test/file_test.h"
#include "../test/test_util.h"
#include "./internal.h"


namespace {

// Serializes |t| by running |marshal_func| against a freshly initialized CBB
// and returns a copy of the resulting bytes. Any failure along the way calls
// abort(), so this helper never returns a partial encoding.
template <typename T>
std::vector<uint8_t> Marshal(int (*marshal_func)(CBB *, const T *),
                             const T *t) {
  bssl::ScopedCBB builder;
  if (!CBB_init(builder.get(), 1)) {
    abort();
  }
  if (!marshal_func(builder.get(), t)) {
    abort();
  }

  uint8_t *out = nullptr;
  size_t out_len = 0;
  if (!CBB_finish(builder.get(), &out, &out_len)) {
    abort();
  }

  // Copy the finished buffer into a vector, then release the original.
  std::vector<uint8_t> bytes;
  bytes.assign(out, out + out_len);
  OPENSSL_free(out);
  return bytes;
}

// Signs a fixed message, then flips every bit of the signature one at a time
// and expects verification to reject each corrupted signature. This test is
// very slow, so it is disabled by default.
TEST(MLDSATest, DISABLED_BitFlips) {
  static const uint8_t kMessage[] = {'H', 'e', 'l', 'l', 'o', ' ',
                                     'w', 'o', 'r', 'l', 'd'};

  std::vector<uint8_t> public_key_bytes(MLDSA65_PUBLIC_KEY_BYTES);
  uint8_t seed[MLDSA_SEED_BYTES];
  auto private_key = std::make_unique<MLDSA65_private_key>();
  EXPECT_TRUE(
      MLDSA65_generate_key(public_key_bytes.data(), seed, private_key.get()));

  std::vector<uint8_t> sig(MLDSA65_SIGNATURE_BYTES);
  EXPECT_TRUE(MLDSA65_sign(sig.data(), private_key.get(), kMessage,
                           sizeof(kMessage), nullptr, 0));

  auto public_key = std::make_unique<MLDSA65_public_key>();
  CBS public_key_cbs;
  CBS_init(&public_key_cbs, public_key_bytes.data(), public_key_bytes.size());
  ASSERT_TRUE(MLDSA65_parse_public_key(public_key.get(), &public_key_cbs));

  // Verifies whatever |sig| currently holds against |kMessage| with an empty
  // context.
  const auto verify_current = [&]() -> int {
    return MLDSA65_verify(public_key.get(), sig.data(), sig.size(), kMessage,
                          sizeof(kMessage), nullptr, 0);
  };

  // The untouched signature must verify.
  EXPECT_EQ(verify_current(), 1);

  for (size_t byte = 0; byte < MLDSA65_SIGNATURE_BYTES; byte++) {
    for (int bit = 0; bit < 8; bit++) {
      const uint8_t mask = static_cast<uint8_t>(1 << bit);
      sig[byte] ^= mask;
      EXPECT_EQ(verify_current(), 0)
          << "Bit flip in signature at byte " << byte << " bit " << bit
          << " didn't cause a verification failure";
      // Undo the flip so the next iteration starts from the valid signature.
      sig[byte] ^= mask;
    }
  }
}

// Round-trip check: generate a key pair, sign a message under a context,
// verify the signature with the parsed public key, and confirm that
// re-deriving the private key from the seed yields the same encoding.
TEST(MLDSATest, Basic) {
  static const uint8_t kMessage[] = {'H', 'e', 'l', 'l', 'o', ' ',
                                     'w', 'o', 'r', 'l', 'd'};
  static const uint8_t kContext[] = {'c', 't', 'x'};

  std::vector<uint8_t> public_key_bytes(MLDSA65_PUBLIC_KEY_BYTES);
  uint8_t seed[MLDSA_SEED_BYTES];
  auto private_key = std::make_unique<MLDSA65_private_key>();
  EXPECT_TRUE(
      MLDSA65_generate_key(public_key_bytes.data(), seed, private_key.get()));

  std::vector<uint8_t> sig(MLDSA65_SIGNATURE_BYTES);
  EXPECT_TRUE(MLDSA65_sign(sig.data(), private_key.get(), kMessage,
                           sizeof(kMessage), kContext, sizeof(kContext)));

  auto public_key = std::make_unique<MLDSA65_public_key>();
  CBS public_key_cbs;
  CBS_init(&public_key_cbs, public_key_bytes.data(), public_key_bytes.size());
  ASSERT_TRUE(MLDSA65_parse_public_key(public_key.get(), &public_key_cbs));

  const int verified =
      MLDSA65_verify(public_key.get(), sig.data(), sig.size(), kMessage,
                     sizeof(kMessage), kContext, sizeof(kContext));
  EXPECT_EQ(verified, 1);

  // A private key rebuilt from the same seed must serialize identically.
  auto rederived_key = std::make_unique<MLDSA65_private_key>();
  EXPECT_TRUE(
      MLDSA65_private_key_from_seed(rederived_key.get(), seed, sizeof(seed)));

  const std::vector<uint8_t> original_encoding =
      Marshal(MLDSA65_marshal_private_key, private_key.get());
  const std::vector<uint8_t> rederived_encoding =
      Marshal(MLDSA65_marshal_private_key, rederived_key.get());
  EXPECT_EQ(Bytes(original_encoding), Bytes(rederived_encoding));
}

// Signs the same message twice with the same key and expects two different
// signatures, each of which still verifies.
TEST(MLDSATest, SignatureIsRandomized) {
  static const uint8_t kMessage[] = {'H', 'e', 'l', 'l', 'o', ' ',
                                     'w', 'o', 'r', 'l', 'd'};

  std::vector<uint8_t> public_key_bytes(MLDSA65_PUBLIC_KEY_BYTES);
  uint8_t seed[MLDSA_SEED_BYTES];
  auto private_key = std::make_unique<MLDSA65_private_key>();
  EXPECT_TRUE(
      MLDSA65_generate_key(public_key_bytes.data(), seed, private_key.get()));

  auto public_key = std::make_unique<MLDSA65_public_key>();
  CBS public_key_cbs;
  CBS_init(&public_key_cbs, public_key_bytes.data(), public_key_bytes.size());
  ASSERT_TRUE(MLDSA65_parse_public_key(public_key.get(), &public_key_cbs));

  // Signs |kMessage| with an empty context into |out|.
  const auto sign_into = [&](std::vector<uint8_t> &out) -> int {
    return MLDSA65_sign(out.data(), private_key.get(), kMessage,
                        sizeof(kMessage), nullptr, 0);
  };
  // Verifies |sig| over |kMessage| with an empty context.
  const auto verify = [&](const std::vector<uint8_t> &sig) -> int {
    return MLDSA65_verify(public_key.get(), sig.data(), sig.size(), kMessage,
                          sizeof(kMessage), nullptr, 0);
  };

  std::vector<uint8_t> first_sig(MLDSA65_SIGNATURE_BYTES);
  std::vector<uint8_t> second_sig(MLDSA65_SIGNATURE_BYTES);
  EXPECT_TRUE(sign_into(first_sig));
  EXPECT_TRUE(sign_into(second_sig));

  EXPECT_NE(Bytes(first_sig), Bytes(second_sig));

  // Differing bytes notwithstanding, both signatures must be accepted.
  EXPECT_EQ(verify(first_sig), 1);
  EXPECT_EQ(verify(second_sig), 1);
}

// The public key derived from a private key must encode to the same bytes
// that key generation emitted.
TEST(MLDSATest, PublicFromPrivateIsConsistent) {
  std::vector<uint8_t> generated_public_key(MLDSA65_PUBLIC_KEY_BYTES);
  uint8_t seed[MLDSA_SEED_BYTES];
  auto private_key = std::make_unique<MLDSA65_private_key>();
  EXPECT_TRUE(MLDSA65_generate_key(generated_public_key.data(), seed,
                                   private_key.get()));

  auto derived_public_key = std::make_unique<MLDSA65_public_key>();
  EXPECT_TRUE(
      MLDSA65_public_from_private(derived_public_key.get(), private_key.get()));

  // Marshal the derived key into a fixed buffer of the public-key size.
  std::vector<uint8_t> derived_bytes(MLDSA65_PUBLIC_KEY_BYTES);
  CBB out;
  CBB_init_fixed(&out, derived_bytes.data(), derived_bytes.size());
  ASSERT_TRUE(MLDSA65_marshal_public_key(&out, derived_public_key.get()));

  EXPECT_EQ(Bytes(derived_bytes), Bytes(generated_public_key));
}

// Public-key parsing must accept exactly MLDSA65_PUBLIC_KEY_BYTES of input and
// reject inputs that are one byte shorter or one byte longer.
TEST(MLDSATest, InvalidPublicKeyEncodingLength) {
  // The buffer holds the generated public key plus one extra trailing zero
  // byte, which is used for the "too long" case below.
  std::vector<uint8_t> padded_public_key(MLDSA65_PUBLIC_KEY_BYTES + 1);
  uint8_t seed[MLDSA_SEED_BYTES];
  auto private_key = std::make_unique<MLDSA65_private_key>();
  EXPECT_TRUE(MLDSA65_generate_key(padded_public_key.data(), seed,
                                   private_key.get()));

  auto parsed_public_key = std::make_unique<MLDSA65_public_key>();
  CBS input;

  // One byte short of a full encoding.
  CBS_init(&input, padded_public_key.data(), MLDSA65_PUBLIC_KEY_BYTES - 1);
  EXPECT_FALSE(MLDSA65_parse_public_key(parsed_public_key.get(), &input));

  // Exactly the right length.
  CBS_init(&input, padded_public_key.data(), MLDSA65_PUBLIC_KEY_BYTES);
  EXPECT_TRUE(MLDSA65_parse_public_key(parsed_public_key.get(), &input));

  // The whole buffer, i.e. one byte beyond a full encoding.
  CBS_init(&input, padded_public_key.data(), padded_public_key.size());
  EXPECT_FALSE(MLDSA65_parse_public_key(parsed_public_key.get(), &input));
}

// Private-key parsing must accept exactly MLDSA65_PRIVATE_KEY_BYTES of input
// and reject inputs that are one byte shorter or one byte longer.
TEST(MLDSATest, InvalidPrivateKeyEncodingLength) {
  std::vector<uint8_t> public_key_bytes(MLDSA65_PUBLIC_KEY_BYTES);
  uint8_t seed[MLDSA_SEED_BYTES];
  auto private_key = std::make_unique<MLDSA65_private_key>();
  EXPECT_TRUE(
      MLDSA65_generate_key(public_key_bytes.data(), seed, private_key.get()));

  // Marshal the key into the front of a buffer that has one spare zero byte at
  // the end, so all three lengths can be sliced from the same storage.
  std::vector<uint8_t> padded_private_key(MLDSA65_PRIVATE_KEY_BYTES + 1, 0);
  CBB out;
  CBB_init_fixed(&out, padded_private_key.data(), MLDSA65_PRIVATE_KEY_BYTES);
  ASSERT_TRUE(MLDSA65_marshal_private_key(&out, private_key.get()));

  const auto all_bytes = bssl::MakeConstSpan(padded_private_key);
  auto parsed_private_key = std::make_unique<MLDSA65_private_key>();

  // One byte short of a full encoding.
  CBS input = all_bytes.first(MLDSA65_PRIVATE_KEY_BYTES - 1);
  EXPECT_FALSE(MLDSA65_parse_private_key(parsed_private_key.get(), &input));

  // Exactly the right length.
  input = all_bytes.first(MLDSA65_PRIVATE_KEY_BYTES);
  EXPECT_TRUE(MLDSA65_parse_private_key(parsed_private_key.get(), &input));

  // One byte beyond a full encoding.
  input = all_bytes.first(MLDSA65_PRIVATE_KEY_BYTES + 1);
  EXPECT_FALSE(MLDSA65_parse_private_key(parsed_private_key.get(), &input));
}

// Known-answer signing test. Each case provides an encoded private key
// ("sk"), a message ("message") and the expected signature ("signature").
// Signing goes through the internal entry point with an all-zero randomizer
// and no context so the output can be compared against the expected bytes.
static void MLDSASigGenTest(FileTest *t) {
  std::vector<uint8_t> sk_bytes;
  std::vector<uint8_t> message;
  std::vector<uint8_t> want_signature;
  ASSERT_TRUE(t->GetBytes(&sk_bytes, "sk"));
  ASSERT_TRUE(t->GetBytes(&message, "message"));
  ASSERT_TRUE(t->GetBytes(&want_signature, "signature"));

  CBS sk_cbs = bssl::MakeConstSpan(sk_bytes);
  auto private_key = std::make_unique<MLDSA65_private_key>();
  EXPECT_TRUE(MLDSA65_parse_private_key(private_key.get(), &sk_cbs));

  const uint8_t kZeroRandomizer[MLDSA_SIGNATURE_RANDOMIZER_BYTES] = {0};
  std::vector<uint8_t> got_signature(MLDSA65_SIGNATURE_BYTES);
  EXPECT_TRUE(MLDSA65_sign_internal(
      got_signature.data(), private_key.get(), message.data(), message.size(),
      nullptr, 0, nullptr, 0, kZeroRandomizer));

  EXPECT_EQ(Bytes(got_signature), Bytes(want_signature));

  // The freshly produced signature must also verify under the matching
  // public key.
  auto public_key = std::make_unique<MLDSA65_public_key>();
  ASSERT_TRUE(MLDSA65_public_from_private(public_key.get(), private_key.get()));
  EXPECT_TRUE(MLDSA65_verify_internal(public_key.get(), got_signature.data(),
                                      message.data(), message.size(), nullptr,
                                      0, nullptr, 0));
}

// Drives MLDSASigGenTest from the NIST signature-generation vector file.
TEST(MLDSATest, SigGenTests) {
  static const char kVectorPath[] = "crypto/mldsa/mldsa_nist_siggen_tests.txt";
  FileTestGTest(kVectorPath, MLDSASigGenTest);
}

// Known-answer key generation test. Each case provides the entropy ("seed")
// and the expected encodings of the resulting public key ("pub") and private
// key ("priv"); both encodings are compared against what key generation
// produces from that seed.
static void MLDSAKeyGenTest(FileTest *t) {
  std::vector<uint8_t> seed, expected_public_key, expected_private_key;
  ASSERT_TRUE(t->GetBytes(&seed, "seed"));
  ASSERT_TRUE(t->GetBytes(&expected_public_key, "pub"));
  ASSERT_TRUE(t->GetBytes(&expected_private_key, "priv"));

  // The generator only receives a bare pointer, so reject a seed of the wrong
  // size here rather than letting it read past the end of the vector.
  ASSERT_EQ(seed.size(), size_t{MLDSA_SEED_BYTES});

  std::vector<uint8_t> encoded_public_key(MLDSA65_PUBLIC_KEY_BYTES);
  auto priv = std::make_unique<MLDSA65_private_key>();
  ASSERT_TRUE(MLDSA65_generate_key_external_entropy(encoded_public_key.data(),
                                                    priv.get(), seed.data()));

  EXPECT_EQ(Bytes(encoded_public_key), Bytes(expected_public_key));

  // Check the private key too; previously "priv" was read but never compared.
  const std::vector<uint8_t> encoded_private_key =
      Marshal(MLDSA65_marshal_private_key, priv.get());
  EXPECT_EQ(Bytes(encoded_private_key), Bytes(expected_private_key));
}

// Drives MLDSAKeyGenTest from the NIST key-generation vector file.
TEST(MLDSATest, KeyGenTests) {
  static const char kVectorPath[] = "crypto/mldsa/mldsa_nist_keygen_tests.txt";
  FileTestGTest(kVectorPath, MLDSAKeyGenTest);
}

// Wycheproof signing test. The private key comes from the "privateKey"
// instruction; each case supplies "msg", "sig", "result" and optionally "ctx".
// Cases whose private key fails to parse, or whose context is longer than 255
// bytes, must not be marked "valid". Otherwise the message is signed with an
// all-zero randomizer and the output is compared against "sig".
static void MLDSAWycheproofSignTest(FileTest *t) {
  std::vector<uint8_t> private_key_bytes, msg, expected_signature, context;
  ASSERT_TRUE(t->GetInstructionBytes(&private_key_bytes, "privateKey"));
  ASSERT_TRUE(t->GetBytes(&msg, "msg"));
  ASSERT_TRUE(t->GetBytes(&expected_signature, "sig"));
  if (t->HasAttribute("ctx")) {
    // A "ctx" attribute that is present but undecodable is a broken test
    // vector; fail instead of silently signing with an empty context.
    ASSERT_TRUE(t->GetBytes(&context, "ctx"));
  }
  std::string result;
  ASSERT_TRUE(t->GetAttribute(&result, "result"));
  t->IgnoreAttribute("flags");

  CBS cbs;
  CBS_init(&cbs, private_key_bytes.data(), private_key_bytes.size());
  auto priv = std::make_unique<MLDSA65_private_key>();
  const int priv_ok = MLDSA65_parse_private_key(priv.get(), &cbs);

  if (!priv_ok) {
    ASSERT_TRUE(result != "valid");
    return;
  }

  // Unfortunately we need to reimplement the context length check here because
  // we are using the internal function in order to pass in an all-zero
  // randomizer.
  if (context.size() > 255) {
    ASSERT_TRUE(result != "valid");
    return;
  }

  const uint8_t zero_randomizer[MLDSA_SIGNATURE_RANDOMIZER_BYTES] = {0};
  std::vector<uint8_t> signature(MLDSA65_SIGNATURE_BYTES);
  // Two-byte prefix passed ahead of the context: a zero byte followed by the
  // context length (which fits in one byte thanks to the check above).
  const uint8_t context_prefix[2] = {0, static_cast<uint8_t>(context.size())};
  EXPECT_TRUE(MLDSA65_sign_internal(
      signature.data(), priv.get(), msg.data(), msg.size(), context_prefix,
      sizeof(context_prefix), context.data(), context.size(), zero_randomizer));

  EXPECT_EQ(Bytes(signature), Bytes(expected_signature));
}

// Drives MLDSAWycheproofSignTest from the Wycheproof ML-DSA-65 signing
// vector file.
TEST(MLDSATest, WycheproofSignTests) {
  static const char kVectorPath[] =
      "third_party/wycheproof_testvectors/mldsa_65_standard_sign_test.txt";
  FileTestGTest(kVectorPath, MLDSAWycheproofSignTest);
}

// Wycheproof verification test. The public key comes from the "publicKey"
// instruction; each case supplies "msg", "sig", "result", "flags" and
// optionally "ctx". A public key that fails to parse is only acceptable when
// the case is flagged "IncorrectPublicKeyLength". Otherwise the outcome of
// verification must agree with "result".
static void MLDSAWycheproofVerifyTest(FileTest *t) {
  std::vector<uint8_t> public_key_bytes, msg, signature, context;
  ASSERT_TRUE(t->GetInstructionBytes(&public_key_bytes, "publicKey"));
  ASSERT_TRUE(t->GetBytes(&msg, "msg"));
  ASSERT_TRUE(t->GetBytes(&signature, "sig"));
  if (t->HasAttribute("ctx")) {
    // A "ctx" attribute that is present but undecodable is a broken test
    // vector; fail instead of silently verifying with an empty context.
    ASSERT_TRUE(t->GetBytes(&context, "ctx"));
  }
  std::string result, flags;
  ASSERT_TRUE(t->GetAttribute(&result, "result"));
  ASSERT_TRUE(t->GetAttribute(&flags, "flags"));

  CBS cbs;
  CBS_init(&cbs, public_key_bytes.data(), public_key_bytes.size());
  auto pub = std::make_unique<MLDSA65_public_key>();
  const int pub_ok = MLDSA65_parse_public_key(pub.get(), &cbs);

  if (!pub_ok) {
    EXPECT_EQ(flags, "IncorrectPublicKeyLength");
    return;
  }

  const int sig_ok =
      MLDSA65_verify(pub.get(), signature.data(), signature.size(), msg.data(),
                     msg.size(), context.data(), context.size());
  if (!sig_ok) {
    EXPECT_EQ(result, "invalid");
  } else {
    EXPECT_EQ(result, "valid");
  }
}

// Drives MLDSAWycheproofVerifyTest from the Wycheproof ML-DSA-65 verification
// vector file.
TEST(MLDSATest, WycheproofVerifyTests) {
  static const char kVectorPath[] =
      "third_party/wycheproof_testvectors/mldsa_65_standard_verify_test.txt";
  FileTestGTest(kVectorPath, MLDSAWycheproofVerifyTest);
}

}  // namespace