summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/backend/libpq/be-secure.c87
-rw-r--r--src/backend/tcop/postgres.c6
2 files changed, 76 insertions, 17 deletions
diff --git a/src/backend/libpq/be-secure.c b/src/backend/libpq/be-secure.c
index 82a608dc5ba..0dccd3f3287 100644
--- a/src/backend/libpq/be-secure.c
+++ b/src/backend/libpq/be-secure.c
@@ -102,6 +102,9 @@ char *ssl_crl_file;
int ssl_renegotiation_limit;
#ifdef USE_SSL
+/* are we in the middle of a renegotiation? */
+static bool in_ssl_renegotiation = false;
+
static SSL_CTX *SSL_context = NULL;
static bool ssl_loaded_verify_locations = false;
#endif
@@ -295,6 +298,7 @@ rloop:
(errcode(ERRCODE_PROTOCOL_VIOLATION),
errmsg("unrecognized SSL error code: %d",
err)));
+ errno = ECONNRESET;
n = -1;
break;
}
@@ -326,29 +330,55 @@ secure_write(Port *port, void *ptr, size_t len)
int err;
unsigned long ecode;
- if (ssl_renegotiation_limit && port->count > ssl_renegotiation_limit * 1024L)
+ /*
+ * If SSL renegotiations are enabled and we're getting close to the
+ * limit, start one now; but avoid it if there's one already in
+ * progress. Request the renegotiation 1kB before the limit has
+ * actually expired.
+ */
+ if (ssl_renegotiation_limit && !in_ssl_renegotiation &&
+ port->count > (ssl_renegotiation_limit - 1) * 1024L)
{
+ in_ssl_renegotiation = true;
+
+ /*
+ * The way we determine that a renegotiation has completed is by
+ * observing OpenSSL's internal renegotiation counter. Make sure
+ * we start out at zero, and assume that the renegotiation is
+ * complete when the counter advances.
+ *
+ * OpenSSL provides SSL_renegotiation_pending(), but this doesn't
+ * seem to work in testing.
+ */
+ SSL_clear_num_renegotiations(port->ssl);
+
SSL_set_session_id_context(port->ssl, (void *) &SSL_context,
sizeof(SSL_context));
if (SSL_renegotiate(port->ssl) <= 0)
ereport(COMMERROR,
(errcode(ERRCODE_PROTOCOL_VIOLATION),
- errmsg("SSL renegotiation failure")));
- if (SSL_do_handshake(port->ssl) <= 0)
- ereport(COMMERROR,
- (errcode(ERRCODE_PROTOCOL_VIOLATION),
- errmsg("SSL renegotiation failure")));
- if (port->ssl->state != SSL_ST_OK)
- ereport(COMMERROR,
- (errcode(ERRCODE_PROTOCOL_VIOLATION),
- errmsg("SSL failed to send renegotiation request")));
- port->ssl->state |= SSL_ST_ACCEPT;
- SSL_do_handshake(port->ssl);
- if (port->ssl->state != SSL_ST_OK)
- ereport(COMMERROR,
- (errcode(ERRCODE_PROTOCOL_VIOLATION),
- errmsg("SSL renegotiation failure")));
- port->count = 0;
+ errmsg("SSL failure during renegotiation start")));
+ else
+ {
+ int retries;
+
+ /*
+ * A handshake can fail, so be prepared to retry it, but only
+ * a few times.
+ */
+ for (retries = 0;; retries++)
+ {
+ if (SSL_do_handshake(port->ssl) > 0)
+ break; /* done */
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("SSL handshake failure on renegotiation, retrying")));
+ if (retries >= 20)
+ ereport(FATAL,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("could not complete SSL handshake on renegotiation, too many failures")));
+ }
+ }
}
wloop:
@@ -393,9 +423,32 @@ wloop:
(errcode(ERRCODE_PROTOCOL_VIOLATION),
errmsg("unrecognized SSL error code: %d",
err)));
+ errno = ECONNRESET;
n = -1;
break;
}
+
+ if (n >= 0)
+ {
+ /* is renegotiation complete? */
+ if (in_ssl_renegotiation &&
+ SSL_num_renegotiations(port->ssl) >= 1)
+ {
+ in_ssl_renegotiation = false;
+ port->count = 0;
+ }
+
+ /*
+ * if renegotiation is still ongoing, and we've gone beyond the
+ * limit, kill the connection now -- continuing to use it can be
+ * considered a security problem.
+ */
+ if (in_ssl_renegotiation &&
+ port->count > ssl_renegotiation_limit * 1024L)
+ ereport(FATAL,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("SSL failed to renegotiate connection before limit expired")));
+ }
}
else
#endif
diff --git a/src/backend/tcop/postgres.c b/src/backend/tcop/postgres.c
index 3ba3a0c6142..4c140ee3154 100644
--- a/src/backend/tcop/postgres.c
+++ b/src/backend/tcop/postgres.c
@@ -553,16 +553,22 @@ prepare_for_client_read(void)
/*
* client_read_ended -- get out of the client-input state
+ *
+ * This is called just after low-level reads. It must preserve errno!
*/
void
client_read_ended(void)
{
if (DoingCommandRead)
{
+ int save_errno = errno;
+
ImmediateInterruptOK = false;
DisableNotifyInterrupt();
DisableCatchupInterrupt();
+
+ errno = save_errno;
}
}