test_state.cc 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. /* Copyright (c) 2018, Google Inc.
  2. *
  3. * Permission to use, copy, modify, and/or distribute this software for any
  4. * purpose with or without fee is hereby granted, provided that the above
  5. * copyright notice and this permission notice appear in all copies.
  6. *
  7. * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  8. * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  9. * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
  10. * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  11. * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
  12. * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
  13. * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
  14. #include "test_state.h"
  15. #include <openssl/ssl.h>
  16. #include "../../crypto/internal.h"
  17. #include "../internal.h"
  18. using namespace bssl;
  19. static CRYPTO_once_t g_once = CRYPTO_ONCE_INIT;
  20. static int g_state_index = 0;
  21. // Some code treats the zero time special, so initialize the clock to a
  22. // non-zero time.
  23. static timeval g_clock = { 1234, 1234 };
  24. static void TestStateExFree(void *parent, void *ptr, CRYPTO_EX_DATA *ad,
  25. int index, long argl, void *argp) {
  26. delete ((TestState *)ptr);
  27. }
  28. static void init_once() {
  29. g_state_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, TestStateExFree);
  30. if (g_state_index < 0) {
  31. abort();
  32. }
  33. }
  34. struct timeval *GetClock() {
  35. CRYPTO_once(&g_once, init_once);
  36. return &g_clock;
  37. }
  38. void AdvanceClock(unsigned seconds) {
  39. CRYPTO_once(&g_once, init_once);
  40. g_clock.tv_sec += seconds;
  41. }
  42. bool SetTestState(SSL *ssl, std::unique_ptr<TestState> state) {
  43. CRYPTO_once(&g_once, init_once);
  44. // |SSL_set_ex_data| takes ownership of |state| only on success.
  45. if (SSL_set_ex_data(ssl, g_state_index, state.get()) == 1) {
  46. state.release();
  47. return true;
  48. }
  49. return false;
  50. }
  51. TestState *GetTestState(const SSL *ssl) {
  52. CRYPTO_once(&g_once, init_once);
  53. return (TestState *)SSL_get_ex_data(ssl, g_state_index);
  54. }
  55. static void ssl_ctx_add_session(SSL_SESSION *session, void *void_param) {
  56. SSL_CTX *ctx = reinterpret_cast<SSL_CTX *>(void_param);
  57. UniquePtr<SSL_SESSION> new_session = SSL_SESSION_dup(
  58. session, SSL_SESSION_INCLUDE_NONAUTH | SSL_SESSION_INCLUDE_TICKET);
  59. if (new_session != nullptr) {
  60. SSL_CTX_add_session(ctx, new_session.get());
  61. }
  62. }
  63. void CopySessions(SSL_CTX *dst, const SSL_CTX *src) {
  64. lh_SSL_SESSION_doall_arg(src->sessions, ssl_ctx_add_session, dst);
  65. }
  66. static void push_session(SSL_SESSION *session, void *arg) {
  67. auto s = reinterpret_cast<std::vector<SSL_SESSION *> *>(arg);
  68. s->push_back(session);
  69. }
  70. bool SerializeContextState(SSL_CTX *ctx, CBB *cbb) {
  71. CBB out, ctx_sessions, ticket_keys;
  72. uint8_t keys[48];
  73. if (!CBB_add_u24_length_prefixed(cbb, &out) ||
  74. !CBB_add_u16(&out, 0 /* version */) ||
  75. !SSL_CTX_get_tlsext_ticket_keys(ctx, &keys, sizeof(keys)) ||
  76. !CBB_add_u8_length_prefixed(&out, &ticket_keys) ||
  77. !CBB_add_bytes(&ticket_keys, keys, sizeof(keys)) ||
  78. !CBB_add_asn1(&out, &ctx_sessions, CBS_ASN1_SEQUENCE)) {
  79. return false;
  80. }
  81. std::vector<SSL_SESSION *> sessions;
  82. lh_SSL_SESSION_doall_arg(ctx->sessions, push_session, &sessions);
  83. for (const auto &sess : sessions) {
  84. if (!ssl_session_serialize(sess, &ctx_sessions)) {
  85. return false;
  86. }
  87. }
  88. return CBB_flush(cbb);
  89. }
  90. bool DeserializeContextState(CBS *cbs, SSL_CTX *ctx) {
  91. CBS in, sessions, ticket_keys;
  92. uint16_t version;
  93. constexpr uint16_t kVersion = 0;
  94. if (!CBS_get_u24_length_prefixed(cbs, &in) ||
  95. !CBS_get_u16(&in, &version) ||
  96. version > kVersion ||
  97. !CBS_get_u8_length_prefixed(&in, &ticket_keys) ||
  98. !SSL_CTX_set_tlsext_ticket_keys(ctx, CBS_data(&ticket_keys),
  99. CBS_len(&ticket_keys)) ||
  100. !CBS_get_asn1(&in, &sessions, CBS_ASN1_SEQUENCE)) {
  101. return false;
  102. }
  103. while (CBS_len(&sessions)) {
  104. UniquePtr<SSL_SESSION> session =
  105. SSL_SESSION_parse(&sessions, ctx->x509_method, ctx->pool);
  106. if (!session) {
  107. return false;
  108. }
  109. SSL_CTX_add_session(ctx, session.get());
  110. }
  111. return true;
  112. }
  113. bool TestState::Serialize(CBB *cbb) const {
  114. CBB out, pending, text;
  115. if (!CBB_add_u24_length_prefixed(cbb, &out) ||
  116. !CBB_add_u16(&out, 0 /* version */) ||
  117. !CBB_add_u24_length_prefixed(&out, &pending) ||
  118. (pending_session &&
  119. !ssl_session_serialize(pending_session.get(), &pending)) ||
  120. !CBB_add_u16_length_prefixed(&out, &text) ||
  121. !CBB_add_bytes(
  122. &text, reinterpret_cast<const uint8_t *>(msg_callback_text.data()),
  123. msg_callback_text.length()) ||
  124. !CBB_flush(cbb)) {
  125. return false;
  126. }
  127. return true;
  128. }
  129. std::unique_ptr<TestState> TestState::Deserialize(CBS *cbs, SSL_CTX *ctx) {
  130. CBS in, pending_session, text;
  131. std::unique_ptr<TestState> out_state(new TestState());
  132. uint16_t version;
  133. constexpr uint16_t kVersion = 0;
  134. if (!CBS_get_u24_length_prefixed(cbs, &in) ||
  135. !CBS_get_u16(&in, &version) ||
  136. version > kVersion ||
  137. !CBS_get_u24_length_prefixed(&in, &pending_session) ||
  138. !CBS_get_u16_length_prefixed(&in, &text)) {
  139. return nullptr;
  140. }
  141. if (CBS_len(&pending_session)) {
  142. out_state->pending_session = SSL_SESSION_parse(
  143. &pending_session, ctx->x509_method, ctx->pool);
  144. if (!out_state->pending_session) {
  145. return nullptr;
  146. }
  147. }
  148. out_state->msg_callback_text = std::string(
  149. reinterpret_cast<const char *>(CBS_data(&text)), CBS_len(&text));
  150. return out_state;
  151. }