Diffstat (limited to 'src/interfaces/libpq-oauth/oauth-curl.c')
-rw-r--r-- | src/interfaces/libpq-oauth/oauth-curl.c | 219 |
1 file changed, 165 insertions, 54 deletions
diff --git a/src/interfaces/libpq-oauth/oauth-curl.c b/src/interfaces/libpq-oauth/oauth-curl.c
index dba9a684fa8..aa50b00d053 100644
--- a/src/interfaces/libpq-oauth/oauth-curl.c
+++ b/src/interfaces/libpq-oauth/oauth-curl.c
@@ -278,6 +278,7 @@ struct async_ctx
 	bool		user_prompted;	/* have we already sent the authz prompt? */
 	bool		used_basic_auth;	/* did we send a client secret? */
 	bool		debugging;		/* can we give unsafe developer assistance? */
+	int			dbg_num_calls;	/* (debug mode) how many times were we called? */
 };
 
 /*
@@ -1291,22 +1292,31 @@ register_socket(CURL *curl, curl_socket_t socket, int what, void *ctx,
 	return 0;
 
 #elif defined(HAVE_SYS_EVENT_H)
-	struct kevent ev[2] = {0};
+	struct kevent ev[2];
 	struct kevent ev_out[2];
 	struct timespec timeout = {0};
 	int			nev = 0;
 	int			res;
 
+	/*
+	 * We don't know which of the events is currently registered, perhaps
+	 * both, so we always try to remove unneeded events. This means we need to
+	 * tolerate ENOENT below.
+	 */
 	switch (what)
 	{
 		case CURL_POLL_IN:
 			EV_SET(&ev[nev], socket, EVFILT_READ, EV_ADD | EV_RECEIPT, 0, 0, 0);
 			nev++;
+			EV_SET(&ev[nev], socket, EVFILT_WRITE, EV_DELETE | EV_RECEIPT, 0, 0, 0);
+			nev++;
 			break;
 
 		case CURL_POLL_OUT:
 			EV_SET(&ev[nev], socket, EVFILT_WRITE, EV_ADD | EV_RECEIPT, 0, 0, 0);
 			nev++;
+			EV_SET(&ev[nev], socket, EVFILT_READ, EV_DELETE | EV_RECEIPT, 0, 0, 0);
+			nev++;
 			break;
 
 		case CURL_POLL_INOUT:
@@ -1317,12 +1327,6 @@ register_socket(CURL *curl, curl_socket_t socket, int what, void *ctx,
 			break;
 
 		case CURL_POLL_REMOVE:
-
-			/*
-			 * We don't know which of these is currently registered, perhaps
-			 * both, so we try to remove both. This means we need to tolerate
-			 * ENOENT below.
-			 */
 			EV_SET(&ev[nev], socket, EVFILT_READ, EV_DELETE | EV_RECEIPT, 0, 0, 0);
 			nev++;
 			EV_SET(&ev[nev], socket, EVFILT_WRITE, EV_DELETE | EV_RECEIPT, 0, 0, 0);
@@ -1334,7 +1338,10 @@ register_socket(CURL *curl, curl_socket_t socket, int what, void *ctx,
 			return -1;
 	}
 
-	res = kevent(actx->mux, ev, nev, ev_out, lengthof(ev_out), &timeout);
+	Assert(nev <= lengthof(ev));
+	Assert(nev <= lengthof(ev_out));
+
+	res = kevent(actx->mux, ev, nev, ev_out, nev, &timeout);
 	if (res < 0)
 	{
 		actx_error(actx, "could not modify kqueue: %m");
@@ -1377,6 +1384,53 @@ register_socket(CURL *curl, curl_socket_t socket, int what, void *ctx,
 }
 
 /*
+ * If there is no work to do on any of the descriptors in the multiplexer, then
+ * this function must ensure that the multiplexer is not readable.
+ *
+ * Unlike epoll descriptors, kqueue descriptors only transition from readable to
+ * unreadable when kevent() is called and finds nothing, after removing
+ * level-triggered conditions that have gone away. We therefore need a dummy
+ * kevent() call after operations might have been performed on the monitored
+ * sockets or timer_fd. Any event returned is ignored here, but it also remains
+ * queued (being level-triggered) and leaves the descriptor readable. This is a
+ * no-op for epoll descriptors.
+ */
+static bool
+comb_multiplexer(struct async_ctx *actx)
+{
+#if defined(HAVE_SYS_EPOLL_H)
+	/* The epoll implementation doesn't hold onto stale events. */
+	return true;
+#elif defined(HAVE_SYS_EVENT_H)
+	struct timespec timeout = {0};
+	struct kevent ev;
+
+	/*
+	 * Try to read a single pending event. We can actually ignore the result:
+	 * either we found an event to process, in which case the multiplexer is
+	 * correctly readable for that event at minimum, and it doesn't matter if
+	 * there are any stale events; or we didn't find any, in which case the
+	 * kernel will have discarded any stale events as it traveled to the end
+	 * of the queue.
+	 *
+	 * Note that this depends on our registrations being level-triggered --
+	 * even the timer, so we use a chained kqueue for that instead of an
+	 * EVFILT_TIMER on the top-level mux. If we used edge-triggered events,
+	 * this call would improperly discard them.
+	 */
+	if (kevent(actx->mux, NULL, 0, &ev, 1, &timeout) < 0)
+	{
+		actx_error(actx, "could not comb kqueue: %m");
+		return false;
+	}
+
+	return true;
+#else
+#error comb_multiplexer is not implemented on this platform
+#endif
+}
+
+/*
  * Enables or disables the timer in the multiplexer set. The timeout value is
  * in milliseconds (negative values disable the timer).
  *
@@ -1483,40 +1537,20 @@ set_timer(struct async_ctx *actx, long timeout)
 
 /*
  * Returns 1 if the timeout in the multiplexer set has expired since the last
- * call to set_timer(), 0 if the timer is still running, or -1 (with an
- * actx_error() report) if the timer cannot be queried.
+ * call to set_timer(), 0 if the timer is either still running or disarmed, or
+ * -1 (with an actx_error() report) if the timer cannot be queried.
  */
 static int
 timer_expired(struct async_ctx *actx)
 {
-#if defined(HAVE_SYS_EPOLL_H)
-	struct itimerspec spec = {0};
-
-	if (timerfd_gettime(actx->timerfd, &spec) < 0)
-	{
-		actx_error(actx, "getting timerfd value: %m");
-		return -1;
-	}
-
-	/*
-	 * This implementation assumes we're using single-shot timers. If you
-	 * change to using intervals, you'll need to reimplement this function
-	 * too, possibly with the read() or select() interfaces for timerfd.
-	 */
-	Assert(spec.it_interval.tv_sec == 0
-		   && spec.it_interval.tv_nsec == 0);
-
-	/* If the remaining time to expiration is zero, we're done. */
-	return (spec.it_value.tv_sec == 0
-			&& spec.it_value.tv_nsec == 0);
-#elif defined(HAVE_SYS_EVENT_H)
+#if defined(HAVE_SYS_EPOLL_H) || defined(HAVE_SYS_EVENT_H)
 	int			res;
 
-	/* Is the timer queue ready? */
+	/* Is the timer ready? */
 	res = PQsocketPoll(actx->timerfd, 1 /* forRead */ , 0, 0);
 	if (res < 0)
 	{
-		actx_error(actx, "checking kqueue for timeout: %m");
+		actx_error(actx, "checking timer expiration: %m");
 		return -1;
 	}
 
@@ -1549,6 +1583,36 @@ register_timer(CURLM *curlm, long timeout, void *ctx)
 }
 
 /*
+ * Removes any expired-timer event from the multiplexer. If was_expired is not
+ * NULL, it will contain whether or not the timer was expired at time of call.
+ */
+static bool
+drain_timer_events(struct async_ctx *actx, bool *was_expired)
+{
+	int			res;
+
+	res = timer_expired(actx);
+	if (res < 0)
+		return false;
+
+	if (res > 0)
+	{
+		/*
+		 * Timer is expired. We could drain the event manually from the
+		 * timerfd, but it's easier to simply disable it; that keeps the
+		 * platform-specific code in set_timer().
+		 */
+		if (!set_timer(actx, -1))
+			return false;
+	}
+
+	if (was_expired)
+		*was_expired = (res > 0);
+
+	return true;
+}
+
+/*
  * Prints Curl request debugging information to stderr.
  *
  * Note that this will expose a number of critical secrets, so users have to opt
@@ -2751,38 +2815,64 @@ pg_fe_run_oauth_flow_impl(PGconn *conn)
 				{
 					PostgresPollingStatusType status;
 
+					/*
+					 * Clear any expired timeout before calling back into
+					 * Curl. Curl is not guaranteed to do this for us, because
+					 * its API expects us to use single-shot (i.e.
+					 * edge-triggered) timeouts, and ours are level-triggered
+					 * via the mux.
+					 *
+					 * This can't be combined with the comb_multiplexer() call
+					 * below: we might accidentally clear a short timeout that
+					 * was both set and expired during the call to
+					 * drive_request().
+					 */
+					if (!drain_timer_events(actx, NULL))
+						goto error_return;
+
+					/* Move the request forward. */
 					status = drive_request(actx);
 
 					if (status == PGRES_POLLING_FAILED)
 						goto error_return;
-					else if (status != PGRES_POLLING_OK)
-					{
-						/* not done yet */
-						return status;
-					}
+					else if (status == PGRES_POLLING_OK)
+						break;	/* done! */
 
-					break;
+					/*
+					 * This request is still running.
+					 *
+					 * Make sure that stale events don't cause us to come back
+					 * early. (Currently, this can occur only with kqueue.) If
+					 * this is forgotten, the multiplexer can get stuck in a
+					 * signaled state and we'll burn CPU cycles pointlessly.
+					 */
+					if (!comb_multiplexer(actx))
+						goto error_return;
+
+					return status;
 				}
 
 			case OAUTH_STEP_WAIT_INTERVAL:
-
-				/*
-				 * The client application is supposed to wait until our timer
-				 * expires before calling PQconnectPoll() again, but that
-				 * might not happen. To avoid sending a token request early,
-				 * check the timer before continuing.
-				 */
-				if (!timer_expired(actx))
 				{
-					set_conn_altsock(conn, actx->timerfd);
-					return PGRES_POLLING_READING;
-				}
+					bool		expired;
 
-				/* Disable the expired timer. */
-				if (!set_timer(actx, -1))
-					goto error_return;
+					/*
+					 * The client application is supposed to wait until our
+					 * timer expires before calling PQconnectPoll() again, but
+					 * that might not happen. To avoid sending a token request
+					 * early, check the timer before continuing.
+					 */
+					if (!drain_timer_events(actx, &expired))
+						goto error_return;
 
-				break;
+					if (!expired)
+					{
+						set_conn_altsock(conn, actx->timerfd);
+						return PGRES_POLLING_READING;
+					}
+
+					break;
+				}
 		}
 
 		/*
@@ -2932,6 +3022,8 @@ PostgresPollingStatusType
 pg_fe_run_oauth_flow(PGconn *conn)
 {
 	PostgresPollingStatusType result;
+	fe_oauth_state *state = conn_sasl_state(conn);
+	struct async_ctx *actx;
 #ifndef WIN32
 	sigset_t	osigset;
 	bool		sigpipe_pending;
@@ -2960,6 +3052,25 @@ pg_fe_run_oauth_flow(PGconn *conn)
 
 	result = pg_fe_run_oauth_flow_impl(conn);
 
+	/*
+	 * To assist with finding bugs in comb_multiplexer() and
+	 * drain_timer_events(), when we're in debug mode, track the total number
+	 * of calls to this function and print that at the end of the flow.
+	 *
+	 * Be careful that state->async_ctx could be NULL if early initialization
+	 * fails during the first call.
+	 */
+	actx = state->async_ctx;
+	Assert(actx || result == PGRES_POLLING_FAILED);
+
+	if (actx && actx->debugging)
+	{
+		actx->dbg_num_calls++;
+		if (result == PGRES_POLLING_OK || result == PGRES_POLLING_FAILED)
+			fprintf(stderr, "[libpq] total number of polls: %d\n",
+					actx->dbg_num_calls);
+	}
+
 #ifndef WIN32
 	if (masked)
 	{