summaryrefslogtreecommitdiff
path: root/src/interfaces/libpq/fe-protocol3.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/interfaces/libpq/fe-protocol3.c')
-rw-r--r--src/interfaces/libpq/fe-protocol3.c67
1 files changed, 56 insertions, 11 deletions
diff --git a/src/interfaces/libpq/fe-protocol3.c b/src/interfaces/libpq/fe-protocol3.c
index da7a8db68c8..838e42e661a 100644
--- a/src/interfaces/libpq/fe-protocol3.c
+++ b/src/interfaces/libpq/fe-protocol3.c
@@ -16,6 +16,7 @@
#include <ctype.h>
#include <fcntl.h>
+#include <limits.h>
#ifdef WIN32
#include "win32.h"
@@ -55,8 +56,8 @@ static int getCopyStart(PGconn *conn, ExecStatusType copytype);
static int getReadyForQuery(PGconn *conn);
static void reportErrorPosition(PQExpBuffer msg, const char *query,
int loc, int encoding);
-static int build_startup_packet(const PGconn *conn, char *packet,
- const PQEnvironmentOption *options);
+static size_t build_startup_packet(const PGconn *conn, char *packet,
+ const PQEnvironmentOption *options);
/*
@@ -1234,8 +1235,21 @@ reportErrorPosition(PQExpBuffer msg, const char *query, int loc, int encoding)
* scridx[] respectively.
*/
- /* we need a safe allocation size... */
+ /*
+ * We need a safe allocation size.
+ *
+ * The only caller of reportErrorPosition() is pqBuildErrorMessage3(); it
+ * gets its query from either a PQresultErrorField() or a PGcmdQueueEntry,
+ * both of which must have fit into conn->inBuffer/outBuffer. So slen fits
+ * inside an int, but we can't assume that (slen * sizeof(int)) fits
+ * inside a size_t.
+ */
slen = strlen(wquery) + 1;
+ if (slen > SIZE_MAX / sizeof(int))
+ {
+ free(wquery);
+ return;
+ }
qidx = (int *) malloc(slen * sizeof(int));
if (qidx == NULL)
@@ -2373,29 +2387,57 @@ pqBuildStartupPacket3(PGconn *conn, int *packetlen,
const PQEnvironmentOption *options)
{
char *startpacket;
+ size_t len;
+
+ len = build_startup_packet(conn, NULL, options);
+ if (len == 0 || len > INT_MAX)
+ return NULL;
- *packetlen = build_startup_packet(conn, NULL, options);
+ *packetlen = len;
startpacket = (char *) malloc(*packetlen);
if (!startpacket)
return NULL;
- *packetlen = build_startup_packet(conn, startpacket, options);
+
+ len = build_startup_packet(conn, startpacket, options);
+ Assert(*packetlen == len);
+
return startpacket;
}
/*
+ * Frontend version of the backend's add_size(), intended to be API-compatible
+ * with the pg_add_*_overflow() helpers. Stores the result into *dst on success.
+ * Returns true instead if the addition overflows.
+ *
+ * TODO: move to common/int.h
+ */
+static bool
+add_size_overflow(size_t s1, size_t s2, size_t *dst)
+{
+ size_t result;
+
+ result = s1 + s2;
+ if (result < s1 || result < s2)
+ return true;
+
+ *dst = result;
+ return false;
+}
+
+/*
* Build a startup packet given a filled-in PGconn structure.
*
* We need to figure out how much space is needed, then fill it in.
* To avoid duplicate logic, this routine is called twice: the first time
* (with packet == NULL) just counts the space needed, the second time
* (with packet == allocated space) fills it in. Return value is the number
- * of bytes used.
+ * of bytes used, or zero in the unlikely event of size_t overflow.
*/
-static int
+static size_t
build_startup_packet(const PGconn *conn, char *packet,
const PQEnvironmentOption *options)
{
- int packet_len = 0;
+ size_t packet_len = 0;
const PQEnvironmentOption *next_eo;
const char *val;
@@ -2414,10 +2456,12 @@ build_startup_packet(const PGconn *conn, char *packet,
do { \
if (packet) \
strcpy(packet + packet_len, optname); \
- packet_len += strlen(optname) + 1; \
+ if (add_size_overflow(packet_len, strlen(optname) + 1, &packet_len)) \
+ return 0; \
if (packet) \
strcpy(packet + packet_len, optval); \
- packet_len += strlen(optval) + 1; \
+ if (add_size_overflow(packet_len, strlen(optval) + 1, &packet_len)) \
+ return 0; \
} while(0)
if (conn->pguser && conn->pguser[0])
@@ -2452,7 +2496,8 @@ build_startup_packet(const PGconn *conn, char *packet,
/* Add trailing terminator */
if (packet)
packet[packet_len] = '\0';
- packet_len++;
+ if (add_size_overflow(packet_len, 1, &packet_len))
+ return 0;
return packet_len;
}