summaryrefslogtreecommitdiff
path: root/src/backend/libpq/auth-scram.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/backend/libpq/auth-scram.c')
-rw-r--r--src/backend/libpq/auth-scram.c141
1 files changed, 87 insertions, 54 deletions
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index c9bab85e82f..126eb70974a 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -141,10 +141,14 @@ typedef struct
Port *port;
bool channel_binding_in_use;
+ /* State data depending on the hash type */
+ pg_cryptohash_type hash_type;
+ int key_length;
+
int iterations;
char *salt; /* base64-encoded */
- uint8 StoredKey[SCRAM_KEY_LEN];
- uint8 ServerKey[SCRAM_KEY_LEN];
+ uint8 StoredKey[SCRAM_MAX_KEY_LEN];
+ uint8 ServerKey[SCRAM_MAX_KEY_LEN];
/* Fields of the first message from client */
char cbind_flag;
@@ -155,7 +159,7 @@ typedef struct
/* Fields from the last message from client */
char *client_final_message_without_proof;
char *client_final_nonce;
- char ClientProof[SCRAM_KEY_LEN];
+ char ClientProof[SCRAM_MAX_KEY_LEN];
/* Fields generated in the server */
char *server_first_message;
@@ -177,12 +181,15 @@ static char *build_server_first_message(scram_state *state);
static char *build_server_final_message(scram_state *state);
static bool verify_client_proof(scram_state *state);
static bool verify_final_nonce(scram_state *state);
-static void mock_scram_secret(const char *username, int *iterations,
- char **salt, uint8 *stored_key, uint8 *server_key);
+static void mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
+ int *iterations, int *key_length, char **salt,
+ uint8 *stored_key, uint8 *server_key);
static bool is_scram_printable(char *p);
static char *sanitize_char(char c);
static char *sanitize_str(const char *s);
-static char *scram_mock_salt(const char *username);
+static char *scram_mock_salt(const char *username,
+ pg_cryptohash_type hash_type,
+ int key_length);
/*
* Get a list of SASL mechanisms that this module supports.
@@ -266,8 +273,11 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
if (password_type == PASSWORD_TYPE_SCRAM_SHA_256)
{
- if (parse_scram_secret(shadow_pass, &state->iterations, &state->salt,
- state->StoredKey, state->ServerKey))
+ if (parse_scram_secret(shadow_pass, &state->iterations,
+ &state->hash_type, &state->key_length,
+ &state->salt,
+ state->StoredKey,
+ state->ServerKey))
got_secret = true;
else
{
@@ -310,8 +320,10 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
*/
if (!got_secret)
{
- mock_scram_secret(state->port->user_name, &state->iterations,
- &state->salt, state->StoredKey, state->ServerKey);
+ mock_scram_secret(state->port->user_name, &state->hash_type,
+ &state->iterations, &state->key_length,
+ &state->salt,
+ state->StoredKey, state->ServerKey);
state->doomed = true;
}
@@ -482,7 +494,8 @@ pg_be_scram_build_secret(const char *password)
(errcode(ERRCODE_INTERNAL_ERROR),
errmsg("could not generate random salt")));
- result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
+ result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN,
+ saltbuf, SCRAM_DEFAULT_SALT_LEN,
SCRAM_DEFAULT_ITERATIONS, password,
&errstr);
@@ -505,16 +518,18 @@ scram_verify_plain_password(const char *username, const char *password,
char *salt;
int saltlen;
int iterations;
- uint8 salted_password[SCRAM_KEY_LEN];
- uint8 stored_key[SCRAM_KEY_LEN];
- uint8 server_key[SCRAM_KEY_LEN];
- uint8 computed_key[SCRAM_KEY_LEN];
+ int key_length = 0;
+ pg_cryptohash_type hash_type;
+ uint8 salted_password[SCRAM_MAX_KEY_LEN];
+ uint8 stored_key[SCRAM_MAX_KEY_LEN];
+ uint8 server_key[SCRAM_MAX_KEY_LEN];
+ uint8 computed_key[SCRAM_MAX_KEY_LEN];
char *prep_password;
pg_saslprep_rc rc;
const char *errstr = NULL;
- if (!parse_scram_secret(secret, &iterations, &encoded_salt,
- stored_key, server_key))
+ if (!parse_scram_secret(secret, &iterations, &hash_type, &key_length,
+ &encoded_salt, stored_key, server_key))
{
/*
* The password looked like a SCRAM secret, but could not be parsed.
@@ -541,9 +556,11 @@ scram_verify_plain_password(const char *username, const char *password,
password = prep_password;
/* Compute Server Key based on the user-supplied plaintext password */
- if (scram_SaltedPassword(password, salt, saltlen, iterations,
+ if (scram_SaltedPassword(password, hash_type, key_length,
+ salt, saltlen, iterations,
salted_password, &errstr) < 0 ||
- scram_ServerKey(salted_password, computed_key, &errstr) < 0)
+ scram_ServerKey(salted_password, hash_type, key_length,
+ computed_key, &errstr) < 0)
{
elog(ERROR, "could not compute server key: %s", errstr);
}
@@ -555,7 +572,7 @@ scram_verify_plain_password(const char *username, const char *password,
* Compare the secret's Server Key with the one computed from the
* user-supplied password.
*/
- return memcmp(computed_key, server_key, SCRAM_KEY_LEN) == 0;
+ return memcmp(computed_key, server_key, key_length) == 0;
}
@@ -565,14 +582,15 @@ scram_verify_plain_password(const char *username, const char *password,
* On success, the iteration count, salt, stored key, and server key are
* extracted from the secret, and returned to the caller. For 'stored_key'
* and 'server_key', the caller must pass pre-allocated buffers of size
- * SCRAM_KEY_LEN. Salt is returned as a base64-encoded, null-terminated
+ * SCRAM_MAX_KEY_LEN. Salt is returned as a base64-encoded, null-terminated
* string. The buffer for the salt is palloc'd by this function.
*
* Returns true if the SCRAM secret has been parsed, and false otherwise.
*/
bool
-parse_scram_secret(const char *secret, int *iterations, char **salt,
- uint8 *stored_key, uint8 *server_key)
+parse_scram_secret(const char *secret, int *iterations,
+ pg_cryptohash_type *hash_type, int *key_length,
+ char **salt, uint8 *stored_key, uint8 *server_key)
{
char *v;
char *p;
@@ -606,6 +624,8 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
/* Parse the fields */
if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
goto invalid_secret;
+ *hash_type = PG_SHA256;
+ *key_length = SCRAM_SHA_256_KEY_LEN;
errno = 0;
*iterations = strtol(iterations_str, &p, 10);
@@ -631,17 +651,17 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
decoded_stored_buf = palloc(decoded_len);
decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
decoded_stored_buf, decoded_len);
- if (decoded_len != SCRAM_KEY_LEN)
+ if (decoded_len != *key_length)
goto invalid_secret;
- memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN);
+ memcpy(stored_key, decoded_stored_buf, *key_length);
decoded_len = pg_b64_dec_len(strlen(serverkey_str));
decoded_server_buf = palloc(decoded_len);
decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
decoded_server_buf, decoded_len);
- if (decoded_len != SCRAM_KEY_LEN)
+ if (decoded_len != *key_length)
goto invalid_secret;
- memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN);
+ memcpy(server_key, decoded_server_buf, *key_length);
return true;
@@ -655,20 +675,25 @@ invalid_secret:
*
* In a normal authentication, these are extracted from the secret
* stored in the server. This function generates values that look
- * realistic, for when there is no stored secret.
+ * realistic, for when there is no stored secret, using SCRAM-SHA-256.
*
* Like in parse_scram_secret(), for 'stored_key' and 'server_key', the
- * caller must pass pre-allocated buffers of size SCRAM_KEY_LEN, and
+ * caller must pass pre-allocated buffers of size SCRAM_MAX_KEY_LEN, and
* the buffer for the salt is palloc'd by this function.
*/
static void
-mock_scram_secret(const char *username, int *iterations, char **salt,
+mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
+ int *iterations, int *key_length, char **salt,
uint8 *stored_key, uint8 *server_key)
{
char *raw_salt;
char *encoded_salt;
int encoded_len;
+ /* Enforce the use of SHA-256, which would be realistic enough */
+ *hash_type = PG_SHA256;
+ *key_length = SCRAM_SHA_256_KEY_LEN;
+
/*
* Generate deterministic salt.
*
@@ -677,7 +702,7 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
* as the salt generated for mock authentication uses the cluster's nonce
* value.
*/
- raw_salt = scram_mock_salt(username);
+ raw_salt = scram_mock_salt(username, *hash_type, *key_length);
if (raw_salt == NULL)
elog(ERROR, "could not encode salt");
@@ -695,8 +720,8 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
*iterations = SCRAM_DEFAULT_ITERATIONS;
/* StoredKey and ServerKey are not used in a doomed authentication */
- memset(stored_key, 0, SCRAM_KEY_LEN);
- memset(server_key, 0, SCRAM_KEY_LEN);
+ memset(stored_key, 0, SCRAM_MAX_KEY_LEN);
+ memset(server_key, 0, SCRAM_MAX_KEY_LEN);
}
/*
@@ -1111,10 +1136,10 @@ verify_final_nonce(scram_state *state)
static bool
verify_client_proof(scram_state *state)
{
- uint8 ClientSignature[SCRAM_KEY_LEN];
- uint8 ClientKey[SCRAM_KEY_LEN];
- uint8 client_StoredKey[SCRAM_KEY_LEN];
- pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+ uint8 ClientSignature[SCRAM_MAX_KEY_LEN];
+ uint8 ClientKey[SCRAM_MAX_KEY_LEN];
+ uint8 client_StoredKey[SCRAM_MAX_KEY_LEN];
+ pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
int i;
const char *errstr = NULL;
@@ -1123,7 +1148,7 @@ verify_client_proof(scram_state *state)
* here even when processing the calculations as this could involve a mock
* authentication.
*/
- if (pg_hmac_init(ctx, state->StoredKey, SCRAM_KEY_LEN) < 0 ||
+ if (pg_hmac_init(ctx, state->StoredKey, state->key_length) < 0 ||
pg_hmac_update(ctx,
(uint8 *) state->client_first_message_bare,
strlen(state->client_first_message_bare)) < 0 ||
@@ -1135,7 +1160,7 @@ verify_client_proof(scram_state *state)
pg_hmac_update(ctx,
(uint8 *) state->client_final_message_without_proof,
strlen(state->client_final_message_without_proof)) < 0 ||
- pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
+ pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
{
elog(ERROR, "could not calculate client signature: %s",
pg_hmac_error(ctx));
@@ -1144,14 +1169,15 @@ verify_client_proof(scram_state *state)
pg_hmac_free(ctx);
/* Extract the ClientKey that the client calculated from the proof */
- for (i = 0; i < SCRAM_KEY_LEN; i++)
+ for (i = 0; i < state->key_length; i++)
ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
/* Hash it one more time, and compare with StoredKey */
- if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey, &errstr) < 0)
+ if (scram_H(ClientKey, state->hash_type, state->key_length,
+ client_StoredKey, &errstr) < 0)
elog(ERROR, "could not hash stored key: %s", errstr);
- if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0)
+ if (memcmp(client_StoredKey, state->StoredKey, state->key_length) != 0)
return false;
return true;
@@ -1349,12 +1375,12 @@ read_client_final_message(scram_state *state, const char *input)
client_proof_len = pg_b64_dec_len(strlen(value));
client_proof = palloc(client_proof_len);
if (pg_b64_decode(value, strlen(value), client_proof,
- client_proof_len) != SCRAM_KEY_LEN)
+ client_proof_len) != state->key_length)
ereport(ERROR,
(errcode(ERRCODE_PROTOCOL_VIOLATION),
errmsg("malformed SCRAM message"),
errdetail("Malformed proof in client-final-message.")));
- memcpy(state->ClientProof, client_proof, SCRAM_KEY_LEN);
+ memcpy(state->ClientProof, client_proof, state->key_length);
pfree(client_proof);
if (*p != '\0')
@@ -1374,13 +1400,13 @@ read_client_final_message(scram_state *state, const char *input)
static char *
build_server_final_message(scram_state *state)
{
- uint8 ServerSignature[SCRAM_KEY_LEN];
+ uint8 ServerSignature[SCRAM_MAX_KEY_LEN];
char *server_signature_base64;
int siglen;
- pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+ pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
/* calculate ServerSignature */
- if (pg_hmac_init(ctx, state->ServerKey, SCRAM_KEY_LEN) < 0 ||
+ if (pg_hmac_init(ctx, state->ServerKey, state->key_length) < 0 ||
pg_hmac_update(ctx,
(uint8 *) state->client_first_message_bare,
strlen(state->client_first_message_bare)) < 0 ||
@@ -1392,7 +1418,7 @@ build_server_final_message(scram_state *state)
pg_hmac_update(ctx,
(uint8 *) state->client_final_message_without_proof,
strlen(state->client_final_message_without_proof)) < 0 ||
- pg_hmac_final(ctx, ServerSignature, sizeof(ServerSignature)) < 0)
+ pg_hmac_final(ctx, ServerSignature, state->key_length) < 0)
{
elog(ERROR, "could not calculate server signature: %s",
pg_hmac_error(ctx));
@@ -1400,11 +1426,11 @@ build_server_final_message(scram_state *state)
pg_hmac_free(ctx);
- siglen = pg_b64_enc_len(SCRAM_KEY_LEN);
+ siglen = pg_b64_enc_len(state->key_length);
/* don't forget the zero-terminator */
server_signature_base64 = palloc(siglen + 1);
siglen = pg_b64_encode((const char *) ServerSignature,
- SCRAM_KEY_LEN, server_signature_base64,
+ state->key_length, server_signature_base64,
siglen);
if (siglen < 0)
elog(ERROR, "could not encode server signature");
@@ -1431,10 +1457,11 @@ build_server_final_message(scram_state *state)
* pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN, or NULL.
*/
static char *
-scram_mock_salt(const char *username)
+scram_mock_salt(const char *username, pg_cryptohash_type hash_type,
+ int key_length)
{
pg_cryptohash_ctx *ctx;
- static uint8 sha_digest[PG_SHA256_DIGEST_LENGTH];
+ static uint8 sha_digest[SCRAM_MAX_KEY_LEN];
char *mock_auth_nonce = GetMockAuthenticationNonce();
/*
@@ -1446,11 +1473,17 @@ scram_mock_salt(const char *username)
StaticAssertDecl(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
"salt length greater than SHA256 digest length");
- ctx = pg_cryptohash_create(PG_SHA256);
+ /*
+ * This may be worth refreshing if support for more hash methods is\
+ * added.
+ */
+ Assert(hash_type == PG_SHA256);
+
+ ctx = pg_cryptohash_create(hash_type);
if (pg_cryptohash_init(ctx) < 0 ||
pg_cryptohash_update(ctx, (uint8 *) username, strlen(username)) < 0 ||
pg_cryptohash_update(ctx, (uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN) < 0 ||
- pg_cryptohash_final(ctx, sha_digest, sizeof(sha_digest)) < 0)
+ pg_cryptohash_final(ctx, sha_digest, key_length) < 0)
{
pg_cryptohash_free(ctx);
return NULL;