From 97df682209ab4f672c8ffce67577a2781e3d4d2e Mon Sep 17 00:00:00 2001 From: vimiix Date: Tue, 1 Aug 2023 14:46:38 +0800 Subject: [PATCH] feat(extras):add prepare and execute_prepared_batch function --- lib/extras.py | 36 +++++++- psycopg/cursor_type.c | 206 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 241 insertions(+), 1 deletion(-) diff --git a/lib/extras.py b/lib/extras.py index 1aed151..332b08f 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -1301,4 +1301,38 @@ def _split_sql(sql): if curr is pre: raise ValueError("the query doesn't contain any '%s' placeholder") - return pre, post \ No newline at end of file + return pre, post + + +def execute_prepared_batch(cur, prepared_statement_name, args_list, page_size=100): + r""" + [openGauss libpq only] + + Execute prepared statement with api `PQexecPreparedBatch` (new api in openGauss) + + Param: + argslist: 2d list, do nothing if empty + """ + if len(args_list) == 0: + return + + nparams = len(args_list[0]) + for page in _paginate(args_list, page_size=page_size): + cur.execute_prepared_batch(prepared_statement_name, nparams, len(page), page) + + +def execute_params_batch(cur, sql_format, args_list, page_size=100): + r""" + [openGauss libpq only] + + Execute sql with api `PQexecParamsBatch` (new api in openGauss) + + Arguments: + argslist: 2d list, do nothing if empty + """ + if len(args_list) == 0: + return + + nparams = len(args_list[0]) + for page in _paginate(args_list, page_size=page_size): + cur.execute_params_batch(sql_format, nparams, len(page), page) diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index 74ab1e7..17ad401 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -618,6 +618,206 @@ curs_executemany(cursorObject *self, PyObject *args, PyObject *kwargs) } } +#define curs_execute_prepared_batch_doc \ +"execute_prepared_batch(statement_name, nParams, nBatch, varsList) -- Execute prepared sql with bound batch vars" + +static PyObject * +curs_execute_prepared_batch(cursorObject *self, PyObject *args) +{ + const char *stmtName = NULL; + int nParams = 0, nBatch = 0; + PyObject *argsList = NULL; + + Py_ssize_t rowIdx, colIdx, total; + char **paramValues = NULL; + PGresult *res = NULL; + + /* reset rowcount to -1 to avoid setting it when an exception is raised */ + self->rowcount = -1; + + if (!PyArg_ParseTuple(args, "siiO", &stmtName, &nParams, &nBatch, &argsList)) + { + Dprintf("execute_prepared_batch: parse tuple failed"); + goto exit; + } + Dprintf("execute_prepared_batch parsed statement_name: %s, nParams: %d, nBatch: %d", + stmtName, nParams, nBatch); + total = nBatch*nParams; + + EXC_IF_CURS_CLOSED(self); + EXC_IF_CURS_ASYNC(self, execute_prepared_batch); + EXC_IF_TPC_PREPARED(self->conn, execute_prepared_batch); + + if (self->name != NULL) { + psyco_set_error(ProgrammingError, self, "can't call .execute_prepared_batch() on named cursors"); + goto exit; + } + + // allocate all memory + if (!(paramValues = (char**)malloc(sizeof(char*) * total))) { + PyErr_NoMemory(); + goto exit; + } + memset(paramValues, 0, sizeof(char *) * total); + + if (!PySequence_Check(argsList)) { + psyco_set_error(DataError, self, "expect varsList is a list"); + goto exit; + } + + for (rowIdx = 0; rowIdx < nBatch; rowIdx++) { + PyObject *rowArgs = PySequence_GetItem(argsList, rowIdx); + + // Check if the inner object is a list + if (!PySequence_Check(rowArgs)) { + psyco_set_error(DataError, self, "expect every item in varsList is a list"); + goto exit; + } + + // Loop through each row of parameters + for (colIdx = 0; colIdx < nParams; colIdx++) { + PyObject *argItem = PySequence_GetItem(rowArgs, colIdx); + + if (argItem == Py_None) { + paramValues[rowIdx*nParams+colIdx] = "NULL"; + } else { + PyObject *t = microprotocol_getquoted(argItem, self->conn); + paramValues[rowIdx*nParams+colIdx] = strdup(Bytes_AsString(t)); + Py_XDECREF(t); + } + Py_XDECREF(argItem); + } + Py_XDECREF(rowArgs); + } + + res = PQexecPreparedBatch(self->conn->pgconn, stmtName, nParams, nBatch, + paramValues, NULL, NULL, 0); + conn_set_result(self->conn, res); + if (PQresultStatus(res) != PGRES_COMMAND_OK) { + Dprintf("execute_prepared_batch error: %s", PQresultErrorMessage(res)); + psyco_set_error(OperationalError, self, PQresultErrorMessage(res)); + goto exit; + } + +exit: + free(paramValues); + Py_RETURN_NONE; +} + + +#define curs_execute_params_batch_doc \ +"execute_params_batch(sql, nParams, nBatch, varsList) -- Execute sql with bound batch vars" + +static PyObject * +curs_execute_params_batch(cursorObject *self, PyObject *args) +{ + const char *sql = NULL; + int nParams = 0, nBatch = 0; + PyObject *argsList = NULL; + + Py_ssize_t rowIdx, colIdx, total; + char **paramValues = NULL; + PGresult *res = NULL; + + self->rowcount = -1; + + if (!PyArg_ParseTuple(args, "siiO", &sql, &nParams, &nBatch, &argsList)) + { + Dprintf("execute_params_batch: parse tuple failed"); + goto exit; + } + Dprintf("execute_params_batch parsed sql: %s, nParams: %d, nBatch: %d", + sql, nParams, nBatch); + + total = nBatch*nParams; + + EXC_IF_CURS_CLOSED(self); + EXC_IF_CURS_ASYNC(self, execute_params_batch); + EXC_IF_TPC_PREPARED(self->conn, execute_params_batch); + + if (self->name != NULL) { + psyco_set_error(ProgrammingError, self, "can't call .execute_params_batch() on named cursors"); + goto exit; + } + + if (!(paramValues = (char**)malloc(sizeof(char*) * total))) { + PyErr_NoMemory(); + goto exit; + } + memset(paramValues, 0, sizeof(char *) * total); + + if (!PySequence_Check(argsList)) { + psyco_set_error(DataError, self, "expect varsList is a list"); + goto exit; + } + + for (rowIdx = 0; rowIdx < nBatch; rowIdx++) { + PyObject *rowArgs = PySequence_GetItem(argsList, rowIdx); + + // Check if the inner object is a list + if (!PySequence_Check(rowArgs)) { + psyco_set_error(DataError, self, "expect every item in varsList is a list"); + goto exit; + } + + // Loop through each row of parameters + for (colIdx = 0; colIdx < nParams; colIdx++) { + PyObject *argItem = PySequence_GetItem(rowArgs, colIdx); + + if (argItem == Py_None) { + paramValues[rowIdx*nParams+colIdx] = "NULL"; + } else { + PyObject *t = microprotocol_getquoted(argItem, self->conn); + paramValues[rowIdx*nParams+colIdx] = strdup(Bytes_AsString(t)); + Py_XDECREF(t); + } + Py_XDECREF(argItem); + } + Py_XDECREF(rowArgs); + } + + res = PQexecParamsBatch(self->conn->pgconn, sql, nParams, nBatch, NULL, + paramValues, NULL, NULL, 0); + conn_set_result(self->conn, res); + if (PQresultStatus(res) != PGRES_COMMAND_OK) { + Dprintf("execute_params_batch error: %s", PQresultErrorMessage(res)); + psyco_set_error(OperationalError, self, PQresultErrorMessage(res)); + goto exit; + } + +exit: + free(paramValues); + Py_RETURN_NONE; +} + +#define curs_prepare_doc \ +"prepare(name, command, nparams) -- Prepare a statement" + +static PyObject * +curs_prepare(cursorObject *self, PyObject *args) +{ + const char *stmtname = NULL; + const char *query = NULL; + int nparams; + PGresult *res = NULL; + + if (!PyArg_ParseTuple(args, "ssi", &stmtname, &query, &nparams)) { + return NULL; + } + + EXC_IF_CURS_CLOSED(self); + EXC_IF_ASYNC_IN_PROGRESS(self, prepare); + EXC_IF_TPC_PREPARED(self->conn, prepare); + + res = PQprepare(self->conn->pgconn, stmtname, query, nparams, NULL); + conn_set_result(self->conn, res); + if (PQresultStatus(res) != PGRES_COMMAND_OK) { + psyco_set_error(OperationalError, self, PQresultErrorMessage(res)); + return NULL; + } + + Py_RETURN_NONE; +} #define curs_mogrify_doc \ "mogrify(query, vars=None) -> str -- Return query after vars binding." @@ -1913,6 +2113,12 @@ static struct PyMethodDef cursorObject_methods[] = { METH_VARARGS|METH_KEYWORDS, curs_execute_doc}, {"executemany", (PyCFunction)curs_executemany, METH_VARARGS|METH_KEYWORDS, curs_executemany_doc}, + {"execute_prepared_batch", (PyCFunction)curs_execute_prepared_batch, + METH_VARARGS, curs_execute_prepared_batch_doc}, + {"execute_params_batch", (PyCFunction)curs_execute_params_batch, + METH_VARARGS, curs_execute_params_batch_doc}, + {"prepare", (PyCFunction)curs_prepare, + METH_VARARGS, curs_prepare_doc}, {"fetchone", (PyCFunction)curs_fetchone, METH_NOARGS, curs_fetchone_doc}, {"fetchmany", (PyCFunction)curs_fetchmany, -- Gitee