diff options
Diffstat (limited to 'src/interfaces/libpq/fe-auth-scram.c')
-rw-r--r-- | src/interfaces/libpq/fe-auth-scram.c | 170 |
1 files changed, 151 insertions, 19 deletions
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c index edfd42df854..f2403147ca5 100644 --- a/src/interfaces/libpq/fe-auth-scram.c +++ b/src/interfaces/libpq/fe-auth-scram.c @@ -17,6 +17,7 @@ #include "common/base64.h" #include "common/saslprep.h" #include "common/scram-common.h" +#include "libpq/scram.h" #include "fe-auth.h" /* These are needed for getpid(), in the fallback implementation */ @@ -44,6 +45,11 @@ typedef struct /* These are supplied by the user */ const char *username; char *password; + bool ssl_in_use; + char *tls_finished_message; + size_t tls_finished_len; + char *sasl_mechanism; + const char *channel_binding_type; /* We construct these */ uint8 SaltedPassword[SCRAM_KEY_LEN]; @@ -79,25 +85,50 @@ static bool pg_frontend_random(char *dst, int len); /* * Initialize SCRAM exchange status. + * + * The non-const char* arguments should be passed in malloc'ed. They will be + * freed by pg_fe_scram_free(). */ void * -pg_fe_scram_init(const char *username, const char *password) +pg_fe_scram_init(const char *username, + const char *password, + bool ssl_in_use, + const char *sasl_mechanism, + char *tls_finished_message, + size_t tls_finished_len) { fe_scram_state *state; char *prep_password; pg_saslprep_rc rc; + Assert(sasl_mechanism != NULL); + state = (fe_scram_state *) malloc(sizeof(fe_scram_state)); if (!state) return NULL; memset(state, 0, sizeof(fe_scram_state)); state->state = FE_SCRAM_INIT; state->username = username; + state->ssl_in_use = ssl_in_use; + state->tls_finished_message = tls_finished_message; + state->tls_finished_len = tls_finished_len; + state->sasl_mechanism = strdup(sasl_mechanism); + if (!state->sasl_mechanism) + { + free(state); + return NULL; + } + + /* + * Store channel binding type. Only one type is currently supported. + */ + state->channel_binding_type = SCRAM_CHANNEL_BINDING_TLS_UNIQUE; /* Normalize the password with SASLprep, if possible */ rc = pg_saslprep(password, &prep_password); if (rc == SASLPREP_OOM) { + free(state->sasl_mechanism); free(state); return NULL; } @@ -106,6 +137,7 @@ pg_fe_scram_init(const char *username, const char *password) prep_password = strdup(password); if (!prep_password) { + free(state->sasl_mechanism); free(state); return NULL; } @@ -125,6 +157,10 @@ pg_fe_scram_free(void *opaq) if (state->password) free(state->password); + if (state->tls_finished_message) + free(state->tls_finished_message); + if (state->sasl_mechanism) + free(state->sasl_mechanism); /* client messages */ if (state->client_nonce) @@ -297,9 +333,10 @@ static char * build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage) { char raw_nonce[SCRAM_RAW_NONCE_LEN + 1]; - char *buf; - char buflen; + char *result; + int channel_info_len; int encoded_len; + PQExpBufferData buf; /* * Generate a "raw" nonce. This is converted to ASCII-printable form by @@ -328,26 +365,61 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage) * prepared with SASLprep, the message parsing would fail if it includes * '=' or ',' characters. */ - buflen = 8 + strlen(state->client_nonce) + 1; - buf = malloc(buflen); - if (buf == NULL) + + initPQExpBuffer(&buf); + + /* + * First build the gs2-header with channel binding information. + */ + if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0) { - printfPQExpBuffer(errormessage, - libpq_gettext("out of memory\n")); - return NULL; + Assert(state->ssl_in_use); + appendPQExpBuffer(&buf, "p=%s", state->channel_binding_type); } - snprintf(buf, buflen, "n,,n=,r=%s", state->client_nonce); - - state->client_first_message_bare = strdup(buf + 3); - if (!state->client_first_message_bare) + else if (state->ssl_in_use) { - free(buf); - printfPQExpBuffer(errormessage, - libpq_gettext("out of memory\n")); - return NULL; + /* + * Client supports channel binding, but thinks the server does not. + */ + appendPQExpBuffer(&buf, "y"); } + else + { + /* + * Client does not support channel binding. + */ + appendPQExpBuffer(&buf, "n"); + } + + if (PQExpBufferDataBroken(buf)) + goto oom_error; + + channel_info_len = buf.len; + + appendPQExpBuffer(&buf, ",,n=,r=%s", state->client_nonce); + if (PQExpBufferDataBroken(buf)) + goto oom_error; + + /* + * The first message content needs to be saved without channel binding + * information. + */ + state->client_first_message_bare = strdup(buf.data + channel_info_len + 2); + if (!state->client_first_message_bare) + goto oom_error; + + result = strdup(buf.data); + if (result == NULL) + goto oom_error; + + termPQExpBuffer(&buf); + return result; - return buf; +oom_error: + termPQExpBuffer(&buf); + printfPQExpBuffer(errormessage, + libpq_gettext("out of memory\n")); + return NULL; } /* @@ -366,7 +438,67 @@ build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage) * Construct client-final-message-without-proof. We need to remember it * for verifying the server proof in the final step of authentication. */ - appendPQExpBuffer(&buf, "c=biws,r=%s", state->nonce); + if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0) + { + char *cbind_data; + size_t cbind_data_len; + size_t cbind_header_len; + char *cbind_input; + size_t cbind_input_len; + + if (strcmp(state->channel_binding_type, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0) + { + cbind_data = state->tls_finished_message; + cbind_data_len = state->tls_finished_len; + } + else + { + /* should not happen */ + termPQExpBuffer(&buf); + printfPQExpBuffer(errormessage, + libpq_gettext("invalid channel binding type\n")); + return NULL; + } + + /* should not happen */ + if (cbind_data == NULL || cbind_data_len == 0) + { + termPQExpBuffer(&buf); + printfPQExpBuffer(errormessage, + libpq_gettext("empty channel binding data for channel binding type \"%s\"\n"), + state->channel_binding_type); + return NULL; + } + + appendPQExpBuffer(&buf, "c="); + + cbind_header_len = 4 + strlen(state->channel_binding_type); /* p=type,, */ + cbind_input_len = cbind_header_len + cbind_data_len; + cbind_input = malloc(cbind_input_len); + if (!cbind_input) + goto oom_error; + snprintf(cbind_input, cbind_input_len, "p=%s,,", state->channel_binding_type); + memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len); + + if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(cbind_input_len))) + { + free(cbind_input); + goto oom_error; + } + buf.len += pg_b64_encode(cbind_input, cbind_input_len, buf.data + buf.len); + buf.data[buf.len] = '\0'; + + free(cbind_input); + } + else if (state->ssl_in_use) + appendPQExpBuffer(&buf, "c=eSws"); /* base64 of "y,," */ + else + appendPQExpBuffer(&buf, "c=biws"); /* base64 of "n,," */ + + if (PQExpBufferDataBroken(buf)) + goto oom_error; + + appendPQExpBuffer(&buf, ",r=%s", state->nonce); if (PQExpBufferDataBroken(buf)) goto oom_error; |