diff options
Diffstat (limited to 'src/backend/libpq/auth-scram.c')
-rw-r--r-- | src/backend/libpq/auth-scram.c | 141 |
1 files changed, 87 insertions, 54 deletions
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c index c9bab85e82f..126eb70974a 100644 --- a/src/backend/libpq/auth-scram.c +++ b/src/backend/libpq/auth-scram.c @@ -141,10 +141,14 @@ typedef struct Port *port; bool channel_binding_in_use; + /* State data depending on the hash type */ + pg_cryptohash_type hash_type; + int key_length; + int iterations; char *salt; /* base64-encoded */ - uint8 StoredKey[SCRAM_KEY_LEN]; - uint8 ServerKey[SCRAM_KEY_LEN]; + uint8 StoredKey[SCRAM_MAX_KEY_LEN]; + uint8 ServerKey[SCRAM_MAX_KEY_LEN]; /* Fields of the first message from client */ char cbind_flag; @@ -155,7 +159,7 @@ typedef struct /* Fields from the last message from client */ char *client_final_message_without_proof; char *client_final_nonce; - char ClientProof[SCRAM_KEY_LEN]; + char ClientProof[SCRAM_MAX_KEY_LEN]; /* Fields generated in the server */ char *server_first_message; @@ -177,12 +181,15 @@ static char *build_server_first_message(scram_state *state); static char *build_server_final_message(scram_state *state); static bool verify_client_proof(scram_state *state); static bool verify_final_nonce(scram_state *state); -static void mock_scram_secret(const char *username, int *iterations, - char **salt, uint8 *stored_key, uint8 *server_key); +static void mock_scram_secret(const char *username, pg_cryptohash_type *hash_type, + int *iterations, int *key_length, char **salt, + uint8 *stored_key, uint8 *server_key); static bool is_scram_printable(char *p); static char *sanitize_char(char c); static char *sanitize_str(const char *s); -static char *scram_mock_salt(const char *username); +static char *scram_mock_salt(const char *username, + pg_cryptohash_type hash_type, + int key_length); /* * Get a list of SASL mechanisms that this module supports. @@ -266,8 +273,11 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass) if (password_type == PASSWORD_TYPE_SCRAM_SHA_256) { - if (parse_scram_secret(shadow_pass, &state->iterations, &state->salt, - state->StoredKey, state->ServerKey)) + if (parse_scram_secret(shadow_pass, &state->iterations, + &state->hash_type, &state->key_length, + &state->salt, + state->StoredKey, + state->ServerKey)) got_secret = true; else { @@ -310,8 +320,10 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass) */ if (!got_secret) { - mock_scram_secret(state->port->user_name, &state->iterations, - &state->salt, state->StoredKey, state->ServerKey); + mock_scram_secret(state->port->user_name, &state->hash_type, + &state->iterations, &state->key_length, + &state->salt, + state->StoredKey, state->ServerKey); state->doomed = true; } @@ -482,7 +494,8 @@ pg_be_scram_build_secret(const char *password) (errcode(ERRCODE_INTERNAL_ERROR), errmsg("could not generate random salt"))); - result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN, + result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN, + saltbuf, SCRAM_DEFAULT_SALT_LEN, SCRAM_DEFAULT_ITERATIONS, password, &errstr); @@ -505,16 +518,18 @@ scram_verify_plain_password(const char *username, const char *password, char *salt; int saltlen; int iterations; - uint8 salted_password[SCRAM_KEY_LEN]; - uint8 stored_key[SCRAM_KEY_LEN]; - uint8 server_key[SCRAM_KEY_LEN]; - uint8 computed_key[SCRAM_KEY_LEN]; + int key_length = 0; + pg_cryptohash_type hash_type; + uint8 salted_password[SCRAM_MAX_KEY_LEN]; + uint8 stored_key[SCRAM_MAX_KEY_LEN]; + uint8 server_key[SCRAM_MAX_KEY_LEN]; + uint8 computed_key[SCRAM_MAX_KEY_LEN]; char *prep_password; pg_saslprep_rc rc; const char *errstr = NULL; - if (!parse_scram_secret(secret, &iterations, &encoded_salt, - stored_key, server_key)) + if (!parse_scram_secret(secret, &iterations, &hash_type, &key_length, + &encoded_salt, stored_key, server_key)) { /* * The password looked like a SCRAM secret, but could not be parsed. @@ -541,9 +556,11 @@ scram_verify_plain_password(const char *username, const char *password, password = prep_password; /* Compute Server Key based on the user-supplied plaintext password */ - if (scram_SaltedPassword(password, salt, saltlen, iterations, + if (scram_SaltedPassword(password, hash_type, key_length, + salt, saltlen, iterations, salted_password, &errstr) < 0 || - scram_ServerKey(salted_password, computed_key, &errstr) < 0) + scram_ServerKey(salted_password, hash_type, key_length, + computed_key, &errstr) < 0) { elog(ERROR, "could not compute server key: %s", errstr); } @@ -555,7 +572,7 @@ scram_verify_plain_password(const char *username, const char *password, * Compare the secret's Server Key with the one computed from the * user-supplied password. */ - return memcmp(computed_key, server_key, SCRAM_KEY_LEN) == 0; + return memcmp(computed_key, server_key, key_length) == 0; } @@ -565,14 +582,15 @@ scram_verify_plain_password(const char *username, const char *password, * On success, the iteration count, salt, stored key, and server key are * extracted from the secret, and returned to the caller. For 'stored_key' * and 'server_key', the caller must pass pre-allocated buffers of size - * SCRAM_KEY_LEN. Salt is returned as a base64-encoded, null-terminated + * SCRAM_MAX_KEY_LEN. Salt is returned as a base64-encoded, null-terminated * string. The buffer for the salt is palloc'd by this function. * * Returns true if the SCRAM secret has been parsed, and false otherwise. */ bool -parse_scram_secret(const char *secret, int *iterations, char **salt, - uint8 *stored_key, uint8 *server_key) +parse_scram_secret(const char *secret, int *iterations, + pg_cryptohash_type *hash_type, int *key_length, + char **salt, uint8 *stored_key, uint8 *server_key) { char *v; char *p; @@ -606,6 +624,8 @@ parse_scram_secret(const char *secret, int *iterations, char **salt, /* Parse the fields */ if (strcmp(scheme_str, "SCRAM-SHA-256") != 0) goto invalid_secret; + *hash_type = PG_SHA256; + *key_length = SCRAM_SHA_256_KEY_LEN; errno = 0; *iterations = strtol(iterations_str, &p, 10); @@ -631,17 +651,17 @@ parse_scram_secret(const char *secret, int *iterations, char **salt, decoded_stored_buf = palloc(decoded_len); decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str), decoded_stored_buf, decoded_len); - if (decoded_len != SCRAM_KEY_LEN) + if (decoded_len != *key_length) goto invalid_secret; - memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN); + memcpy(stored_key, decoded_stored_buf, *key_length); decoded_len = pg_b64_dec_len(strlen(serverkey_str)); decoded_server_buf = palloc(decoded_len); decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str), decoded_server_buf, decoded_len); - if (decoded_len != SCRAM_KEY_LEN) + if (decoded_len != *key_length) goto invalid_secret; - memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN); + memcpy(server_key, decoded_server_buf, *key_length); return true; @@ -655,20 +675,25 @@ invalid_secret: * * In a normal authentication, these are extracted from the secret * stored in the server. This function generates values that look - * realistic, for when there is no stored secret. + * realistic, for when there is no stored secret, using SCRAM-SHA-256. * * Like in parse_scram_secret(), for 'stored_key' and 'server_key', the - * caller must pass pre-allocated buffers of size SCRAM_KEY_LEN, and + * caller must pass pre-allocated buffers of size SCRAM_MAX_KEY_LEN, and * the buffer for the salt is palloc'd by this function. */ static void -mock_scram_secret(const char *username, int *iterations, char **salt, +mock_scram_secret(const char *username, pg_cryptohash_type *hash_type, + int *iterations, int *key_length, char **salt, uint8 *stored_key, uint8 *server_key) { char *raw_salt; char *encoded_salt; int encoded_len; + /* Enforce the use of SHA-256, which would be realistic enough */ + *hash_type = PG_SHA256; + *key_length = SCRAM_SHA_256_KEY_LEN; + /* * Generate deterministic salt. * @@ -677,7 +702,7 @@ mock_scram_secret(const char *username, int *iterations, char **salt, * as the salt generated for mock authentication uses the cluster's nonce * value. */ - raw_salt = scram_mock_salt(username); + raw_salt = scram_mock_salt(username, *hash_type, *key_length); if (raw_salt == NULL) elog(ERROR, "could not encode salt"); @@ -695,8 +720,8 @@ mock_scram_secret(const char *username, int *iterations, char **salt, *iterations = SCRAM_DEFAULT_ITERATIONS; /* StoredKey and ServerKey are not used in a doomed authentication */ - memset(stored_key, 0, SCRAM_KEY_LEN); - memset(server_key, 0, SCRAM_KEY_LEN); + memset(stored_key, 0, SCRAM_MAX_KEY_LEN); + memset(server_key, 0, SCRAM_MAX_KEY_LEN); } /* @@ -1111,10 +1136,10 @@ verify_final_nonce(scram_state *state) static bool verify_client_proof(scram_state *state) { - uint8 ClientSignature[SCRAM_KEY_LEN]; - uint8 ClientKey[SCRAM_KEY_LEN]; - uint8 client_StoredKey[SCRAM_KEY_LEN]; - pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256); + uint8 ClientSignature[SCRAM_MAX_KEY_LEN]; + uint8 ClientKey[SCRAM_MAX_KEY_LEN]; + uint8 client_StoredKey[SCRAM_MAX_KEY_LEN]; + pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type); int i; const char *errstr = NULL; @@ -1123,7 +1148,7 @@ verify_client_proof(scram_state *state) * here even when processing the calculations as this could involve a mock * authentication. */ - if (pg_hmac_init(ctx, state->StoredKey, SCRAM_KEY_LEN) < 0 || + if (pg_hmac_init(ctx, state->StoredKey, state->key_length) < 0 || pg_hmac_update(ctx, (uint8 *) state->client_first_message_bare, strlen(state->client_first_message_bare)) < 0 || @@ -1135,7 +1160,7 @@ verify_client_proof(scram_state *state) pg_hmac_update(ctx, (uint8 *) state->client_final_message_without_proof, strlen(state->client_final_message_without_proof)) < 0 || - pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0) + pg_hmac_final(ctx, ClientSignature, state->key_length) < 0) { elog(ERROR, "could not calculate client signature: %s", pg_hmac_error(ctx)); @@ -1144,14 +1169,15 @@ verify_client_proof(scram_state *state) pg_hmac_free(ctx); /* Extract the ClientKey that the client calculated from the proof */ - for (i = 0; i < SCRAM_KEY_LEN; i++) + for (i = 0; i < state->key_length; i++) ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i]; /* Hash it one more time, and compare with StoredKey */ - if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey, &errstr) < 0) + if (scram_H(ClientKey, state->hash_type, state->key_length, + client_StoredKey, &errstr) < 0) elog(ERROR, "could not hash stored key: %s", errstr); - if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0) + if (memcmp(client_StoredKey, state->StoredKey, state->key_length) != 0) return false; return true; @@ -1349,12 +1375,12 @@ read_client_final_message(scram_state *state, const char *input) client_proof_len = pg_b64_dec_len(strlen(value)); client_proof = palloc(client_proof_len); if (pg_b64_decode(value, strlen(value), client_proof, - client_proof_len) != SCRAM_KEY_LEN) + client_proof_len) != state->key_length) ereport(ERROR, (errcode(ERRCODE_PROTOCOL_VIOLATION), errmsg("malformed SCRAM message"), errdetail("Malformed proof in client-final-message."))); - memcpy(state->ClientProof, client_proof, SCRAM_KEY_LEN); + memcpy(state->ClientProof, client_proof, state->key_length); pfree(client_proof); if (*p != '\0') @@ -1374,13 +1400,13 @@ read_client_final_message(scram_state *state, const char *input) static char * build_server_final_message(scram_state *state) { - uint8 ServerSignature[SCRAM_KEY_LEN]; + uint8 ServerSignature[SCRAM_MAX_KEY_LEN]; char *server_signature_base64; int siglen; - pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256); + pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type); /* calculate ServerSignature */ - if (pg_hmac_init(ctx, state->ServerKey, SCRAM_KEY_LEN) < 0 || + if (pg_hmac_init(ctx, state->ServerKey, state->key_length) < 0 || pg_hmac_update(ctx, (uint8 *) state->client_first_message_bare, strlen(state->client_first_message_bare)) < 0 || @@ -1392,7 +1418,7 @@ build_server_final_message(scram_state *state) pg_hmac_update(ctx, (uint8 *) state->client_final_message_without_proof, strlen(state->client_final_message_without_proof)) < 0 || - pg_hmac_final(ctx, ServerSignature, sizeof(ServerSignature)) < 0) + pg_hmac_final(ctx, ServerSignature, state->key_length) < 0) { elog(ERROR, "could not calculate server signature: %s", pg_hmac_error(ctx)); @@ -1400,11 +1426,11 @@ build_server_final_message(scram_state *state) pg_hmac_free(ctx); - siglen = pg_b64_enc_len(SCRAM_KEY_LEN); + siglen = pg_b64_enc_len(state->key_length); /* don't forget the zero-terminator */ server_signature_base64 = palloc(siglen + 1); siglen = pg_b64_encode((const char *) ServerSignature, - SCRAM_KEY_LEN, server_signature_base64, + state->key_length, server_signature_base64, siglen); if (siglen < 0) elog(ERROR, "could not encode server signature"); @@ -1431,10 +1457,11 @@ build_server_final_message(scram_state *state) * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN, or NULL. */ static char * -scram_mock_salt(const char *username) +scram_mock_salt(const char *username, pg_cryptohash_type hash_type, + int key_length) { pg_cryptohash_ctx *ctx; - static uint8 sha_digest[PG_SHA256_DIGEST_LENGTH]; + static uint8 sha_digest[SCRAM_MAX_KEY_LEN]; char *mock_auth_nonce = GetMockAuthenticationNonce(); /* @@ -1446,11 +1473,17 @@ scram_mock_salt(const char *username) StaticAssertDecl(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN, "salt length greater than SHA256 digest length"); - ctx = pg_cryptohash_create(PG_SHA256); + /* + * This may be worth refreshing if support for more hash methods is\ + * added. + */ + Assert(hash_type == PG_SHA256); + + ctx = pg_cryptohash_create(hash_type); if (pg_cryptohash_init(ctx) < 0 || pg_cryptohash_update(ctx, (uint8 *) username, strlen(username)) < 0 || pg_cryptohash_update(ctx, (uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN) < 0 || - pg_cryptohash_final(ctx, sha_digest, sizeof(sha_digest)) < 0) + pg_cryptohash_final(ctx, sha_digest, key_length) < 0) { pg_cryptohash_free(ctx); return NULL; |