summaryrefslogtreecommitdiff
path: root/src/interfaces/libpq/fe-auth-scram.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/interfaces/libpq/fe-auth-scram.c')
-rw-r--r--src/interfaces/libpq/fe-auth-scram.c170
1 files changed, 151 insertions, 19 deletions
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index edfd42df854..f2403147ca5 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -17,6 +17,7 @@
#include "common/base64.h"
#include "common/saslprep.h"
#include "common/scram-common.h"
+#include "libpq/scram.h"
#include "fe-auth.h"
/* These are needed for getpid(), in the fallback implementation */
@@ -44,6 +45,11 @@ typedef struct
/* These are supplied by the user */
const char *username;
char *password;
+ bool ssl_in_use;
+ char *tls_finished_message;
+ size_t tls_finished_len;
+ char *sasl_mechanism;
+ const char *channel_binding_type;
/* We construct these */
uint8 SaltedPassword[SCRAM_KEY_LEN];
@@ -79,25 +85,50 @@ static bool pg_frontend_random(char *dst, int len);
/*
* Initialize SCRAM exchange status.
+ *
+ * The non-const char* arguments should be passed in malloc'ed. They will be
+ * freed by pg_fe_scram_free().
*/
void *
-pg_fe_scram_init(const char *username, const char *password)
+pg_fe_scram_init(const char *username,
+ const char *password,
+ bool ssl_in_use,
+ const char *sasl_mechanism,
+ char *tls_finished_message,
+ size_t tls_finished_len)
{
fe_scram_state *state;
char *prep_password;
pg_saslprep_rc rc;
+ Assert(sasl_mechanism != NULL);
+
state = (fe_scram_state *) malloc(sizeof(fe_scram_state));
if (!state)
return NULL;
memset(state, 0, sizeof(fe_scram_state));
state->state = FE_SCRAM_INIT;
state->username = username;
+ state->ssl_in_use = ssl_in_use;
+ state->tls_finished_message = tls_finished_message;
+ state->tls_finished_len = tls_finished_len;
+ state->sasl_mechanism = strdup(sasl_mechanism);
+ if (!state->sasl_mechanism)
+ {
+ free(state);
+ return NULL;
+ }
+
+ /*
+ * Store channel binding type. Only one type is currently supported.
+ */
+ state->channel_binding_type = SCRAM_CHANNEL_BINDING_TLS_UNIQUE;
/* Normalize the password with SASLprep, if possible */
rc = pg_saslprep(password, &prep_password);
if (rc == SASLPREP_OOM)
{
+ free(state->sasl_mechanism);
free(state);
return NULL;
}
@@ -106,6 +137,7 @@ pg_fe_scram_init(const char *username, const char *password)
prep_password = strdup(password);
if (!prep_password)
{
+ free(state->sasl_mechanism);
free(state);
return NULL;
}
@@ -125,6 +157,10 @@ pg_fe_scram_free(void *opaq)
if (state->password)
free(state->password);
+ if (state->tls_finished_message)
+ free(state->tls_finished_message);
+ if (state->sasl_mechanism)
+ free(state->sasl_mechanism);
/* client messages */
if (state->client_nonce)
@@ -297,9 +333,10 @@ static char *
build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage)
{
char raw_nonce[SCRAM_RAW_NONCE_LEN + 1];
- char *buf;
- char buflen;
+ char *result;
+ int channel_info_len;
int encoded_len;
+ PQExpBufferData buf;
/*
* Generate a "raw" nonce. This is converted to ASCII-printable form by
@@ -328,26 +365,61 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage)
* prepared with SASLprep, the message parsing would fail if it includes
* '=' or ',' characters.
*/
- buflen = 8 + strlen(state->client_nonce) + 1;
- buf = malloc(buflen);
- if (buf == NULL)
+
+ initPQExpBuffer(&buf);
+
+ /*
+ * First build the gs2-header with channel binding information.
+ */
+ if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
{
- printfPQExpBuffer(errormessage,
- libpq_gettext("out of memory\n"));
- return NULL;
+ Assert(state->ssl_in_use);
+ appendPQExpBuffer(&buf, "p=%s", state->channel_binding_type);
}
- snprintf(buf, buflen, "n,,n=,r=%s", state->client_nonce);
-
- state->client_first_message_bare = strdup(buf + 3);
- if (!state->client_first_message_bare)
+ else if (state->ssl_in_use)
{
- free(buf);
- printfPQExpBuffer(errormessage,
- libpq_gettext("out of memory\n"));
- return NULL;
+ /*
+ * Client supports channel binding, but thinks the server does not.
+ */
+ appendPQExpBuffer(&buf, "y");
}
+ else
+ {
+ /*
+ * Client does not support channel binding.
+ */
+ appendPQExpBuffer(&buf, "n");
+ }
+
+ if (PQExpBufferDataBroken(buf))
+ goto oom_error;
+
+ channel_info_len = buf.len;
+
+ appendPQExpBuffer(&buf, ",,n=,r=%s", state->client_nonce);
+ if (PQExpBufferDataBroken(buf))
+ goto oom_error;
+
+ /*
+ * The first message content needs to be saved without channel binding
+ * information.
+ */
+ state->client_first_message_bare = strdup(buf.data + channel_info_len + 2);
+ if (!state->client_first_message_bare)
+ goto oom_error;
+
+ result = strdup(buf.data);
+ if (result == NULL)
+ goto oom_error;
+
+ termPQExpBuffer(&buf);
+ return result;
- return buf;
+oom_error:
+ termPQExpBuffer(&buf);
+ printfPQExpBuffer(errormessage,
+ libpq_gettext("out of memory\n"));
+ return NULL;
}
/*
@@ -366,7 +438,67 @@ build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage)
* Construct client-final-message-without-proof. We need to remember it
* for verifying the server proof in the final step of authentication.
*/
- appendPQExpBuffer(&buf, "c=biws,r=%s", state->nonce);
+ if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
+ {
+ char *cbind_data;
+ size_t cbind_data_len;
+ size_t cbind_header_len;
+ char *cbind_input;
+ size_t cbind_input_len;
+
+ if (strcmp(state->channel_binding_type, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0)
+ {
+ cbind_data = state->tls_finished_message;
+ cbind_data_len = state->tls_finished_len;
+ }
+ else
+ {
+ /* should not happen */
+ termPQExpBuffer(&buf);
+ printfPQExpBuffer(errormessage,
+ libpq_gettext("invalid channel binding type\n"));
+ return NULL;
+ }
+
+ /* should not happen */
+ if (cbind_data == NULL || cbind_data_len == 0)
+ {
+ termPQExpBuffer(&buf);
+ printfPQExpBuffer(errormessage,
+ libpq_gettext("empty channel binding data for channel binding type \"%s\"\n"),
+ state->channel_binding_type);
+ return NULL;
+ }
+
+ appendPQExpBuffer(&buf, "c=");
+
+ cbind_header_len = 4 + strlen(state->channel_binding_type); /* p=type,, */
+ cbind_input_len = cbind_header_len + cbind_data_len;
+ cbind_input = malloc(cbind_input_len);
+ if (!cbind_input)
+ goto oom_error;
+ snprintf(cbind_input, cbind_input_len, "p=%s,,", state->channel_binding_type);
+ memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
+
+ if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(cbind_input_len)))
+ {
+ free(cbind_input);
+ goto oom_error;
+ }
+ buf.len += pg_b64_encode(cbind_input, cbind_input_len, buf.data + buf.len);
+ buf.data[buf.len] = '\0';
+
+ free(cbind_input);
+ }
+ else if (state->ssl_in_use)
+ appendPQExpBuffer(&buf, "c=eSws"); /* base64 of "y,," */
+ else
+ appendPQExpBuffer(&buf, "c=biws"); /* base64 of "n,," */
+
+ if (PQExpBufferDataBroken(buf))
+ goto oom_error;
+
+ appendPQExpBuffer(&buf, ",r=%s", state->nonce);
if (PQExpBufferDataBroken(buf))
goto oom_error;