diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0656c5a60eb04bea6846b8c4549736c6e1184351 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +build +dist +*.pyc +.idea/ +.DS_Store +py_opengauss.egg-info diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 0000000000000000000000000000000000000000..58c8986d313d04492cdb7eaf22bd793ab81c4307 --- /dev/null +++ b/AUTHORS @@ -0,0 +1,25 @@ +Contributors: + James William Pye [faults are mostly mine] + Elvis Pranskevichus + William Grzybowski [subjective paramstyle] + Barry Grussling [inet/cidr support] + Matthew Grant [inet/cidr support] + +Support by Donation: + AppCove Network + +Imported +======== + +DB-API 2.0 Test Case +-------------------- + +postgresql/test/test_dbapi20.py: + Stuart Bishop + + +fcrypt +------ + +postgresql/resolved/crypt.py: + Carey Evans diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..875a834e685d0b8ac67c091d8429c80bd28e26fe --- /dev/null +++ b/LICENSE @@ -0,0 +1,29 @@ +Redistribution and use in source and binary forms, +with or without modification, are permitted provided +that the following conditions are met: + + Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + Neither the name of the James William Pye nor the names of [its] + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING +IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..fcb59edebf84ff2b2b569a7dc1228d3b82f58d25 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,6 @@ +include AUTHORS +include LICENSE +recursive-include postgresql *.c +recursive-include postgresql *.sql +recursive-include postgresql *.txt +recursive-include postgresql/documentation/sphinx *.rst conf.py diff --git a/README.md b/README.md index 20a1fd2d1185945ef01a24fb47fb5417b5ab1270..0f6089c75c1b2c54f9dd20796919c623419cc561 100644 --- a/README.md +++ b/README.md @@ -1,39 +1,60 @@ -# openGauss-connector-python +### About -#### 介绍 -{**以下是 Gitee 平台说明,您可以替换此简介** -Gitee 是 OSCHINA 推出的基于 Git 的代码托管平台(同时支持 SVN)。专为开发者提供稳定、高效、安全的云端软件开发协作平台 -无论是个人、团队、或是企业,都能够用 Gitee 实现代码托管、项目管理、协作开发。企业项目请看 [https://gitee.com/enterprises](https://gitee.com/enterprises)} +py-opengauss is a Python 3 package providing modules for working with openGauss. +Primarily, a high-level driver for querying databases. 
-#### 软件架构 -软件架构说明 +For a high performance async interface, MagicStack's asyncpg +http://github.com/MagicStack/asyncpg should be considered. +py-opengauss, currently, does not have direct support for high-level async +interfaces provided by recent versions of Python. Future versions may change this. -#### 安装教程 +### Advisory -1. xxxx -2. xxxx -3. xxxx +In v1.3, `py_opengauss.driver.dbapi20.connect` will now raise `ClientCannotConnectError` directly. +Exception traps around connect should still function, but the `__context__` attribute +on the error instance will be `None` in the usual failure case as it is no longer +incorrectly chained. Trapping `ClientCannotConnectError` ahead of `Error` should +allow both cases to co-exist in the event that data is being extracted from +the `ClientCannotConnectError`. -#### 使用说明 +In v2.0, support for older versions of PostgreSQL and Python will be removed. +If you have automated installations using PyPI, make sure that they specify a major version. -1. xxxx -2. xxxx -3. xxxx +### Installation -#### 参与贡献 +Using PyPI.org: -1. Fork 本仓库 -2. 新建 Feat_xxx 分支 -3. 提交代码 -4. 新建 Pull Request + $ pip install py-opengauss +From a clone: -#### 特技 + $ git clone https://github.com/vimiix/py-opengauss.git + $ cd py-opengauss + $ python3 ./setup.py install # Or use in-place without installation(PYTHONPATH). -1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md -2. Gitee 官方博客 [blog.gitee.com](https://blog.gitee.com) -3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解 Gitee 上的优秀开源项目 -4. [GVP](https://gitee.com/gvp) 全称是 Gitee 最有价值开源项目,是综合评定出的优秀开源项目 -5. Gitee 官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help) -6. Gitee 封面人物是一档用来展示 Gitee 会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) +### Basic Usage + +> Support schemes: ['pq', 'postgres', 'postgresql', 'og', 'opengauss'] + +```python +import py_opengauss +db = py_opengauss.open('opengauss://user:password@host:port/database') + +get_table = db.prepare("SELECT * from information_schema.tables WHERE table_name = $1") +print(get_table("tables")) + +# Streaming, in a transaction. +with db.xact(): + for x in get_table.rows("tables"): + print(x) +``` + +### Documentation + +http://py-postgresql.readthedocs.io + +### Related + +- http://postgresql.org +- http://python.org diff --git a/py_opengauss/__init__.py b/py_opengauss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..621f9a62efae0afb33d17c8f41fb9687dc17ec1f --- /dev/null +++ b/py_opengauss/__init__.py @@ -0,0 +1,98 @@ +## +# py-postgresql root package +# http://github.com/python-postgres/fe +## +""" +py-postgresql is a Python package for using PostgreSQL. This includes low-level +protocol tools, a driver(PG-API and DB-API 2.0), and cluster management tools. + +See for more information about PostgreSQL and +for information about Python. +""" +__all__ = [ + '__author__', + '__date__', + '__version__', + '__docformat__', + 'version', + 'version_info', + 'open', +] + +#: The version string of py-postgresql. +version = '' # overridden by subsequent import from .project. + +#: The version triple of py-postgresql: (major, minor, patch). +version_info = () # overridden by subsequent import from .project. + +# Optional. +try: + from .project import version_info, version, \ + author as __author__, date as __date__ + __version__ = version +except ImportError: + pass + +# Avoid importing these until requested. 
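+# The submodules below are bound lazily inside open() so that importing
+# py_opengauss stays cheap. As an illustrative sketch of the '&' prefix
+# handled by open() (hypothetical connection values):
+#
+#   connector = py_opengauss.open('&pq://user:password@localhost/postgres')
+#   db = connector()   # build a Connection object from the Connector
+#   db.connect()       # establish the connection when actually needed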
+_pg_iri = _pg_driver = _pg_param = None +def open(iri = None, prompt_title = None, **kw): + """ + Create a `postgresql.api.Connection` to the server referenced by the given + `iri`:: + + >>> import py_opengauss + # General Format: + >>> db = py_opengauss.open('pq://user:password@host:port/database') + + # Also support opengauss scheme: + >>> db = py_opengauss.open('opengauss://user:password@host:port/database') + + # Connect to 'postgres' at localhost. + >>> db = py_opengauss.open('localhost/postgres') + + Connection keywords can also be used with `open`. See the narratives for + more information. + + The `prompt_title` keyword is ignored. `open` will never prompt for + the password unless it is explicitly instructed to do so. + + (Note: "pq" is the name of the protocol used to communicate with PostgreSQL) + """ + global _pg_iri, _pg_driver, _pg_param + if _pg_iri is None: + from . import iri as _pg_iri + from . import driver as _pg_driver + from . import clientparameters as _pg_param + + return_connector = False + if iri is not None: + if iri.startswith('&'): + return_connector = True + iri = iri[1:] + iri_params = _pg_iri.parse(iri) + iri_params.pop('path', None) + else: + iri_params = {} + + std_params = _pg_param.collect(prompt_title = None) + # If unix is specified, it's going to conflict with any standard + # settings, so remove them right here. + if 'unix' in kw or 'unix' in iri_params: + std_params.pop('host', None) + std_params.pop('port', None) + params = _pg_param.normalize( + list(_pg_param.denormalize_parameters(std_params)) + \ + list(_pg_param.denormalize_parameters(iri_params)) + \ + list(_pg_param.denormalize_parameters(kw)) + ) + _pg_param.resolve_password(params) + + C = _pg_driver.default.fit(**params) + if return_connector is True: + return C + else: + c = C() + c.connect() + return c + +__docformat__ = 'reStructuredText' diff --git a/py_opengauss/alock.py b/py_opengauss/alock.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf413d16704cb716058415f2b35bd64b5827af0 --- /dev/null +++ b/py_opengauss/alock.py @@ -0,0 +1,157 @@ +## +# .alock - Advisory Locks +## +""" +Tools for Advisory Locks +""" +import abc +from .python import element + +__all__ = [ + 'ALock', + 'ExclusiveLock', + 'ShareLock', +] + +class ALock(element.Element): + """ + Advisory Lock class for managing the acquisition and release of a sequence + of PostgreSQL advisory locks. + + ALock()'s are fairly consistent with threading.RLock()'s. They can be + acquired multiple times, and they must be released the same number of times + for the lock to actually be released. + + A notably difference is that ALock's manage a sequence of lock identifiers. + This means that a given ALock() may represent multiple advisory locks. + """ + _e_factors = ('database', 'identifiers',) + _e_label = 'ALOCK' + def _e_metas(self, + headfmt = "{1} [{0}]".format + ): + yield None, headfmt(self.state, self.mode) + + @abc.abstractproperty + def mode(self): + """ + The mode of the lock class. + """ + + @abc.abstractproperty + def __select_statements__(self): + """ + Implemented by subclasses to return the statements to try, acquire, and + release the advisory lock. + + Returns a triple of callables where each callable takes two arguments, + the lock-id pairs, and then the int8 lock-ids. + ``(try, acquire, release)``. 
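+
+		As an illustration, mirroring the concrete subclasses defined below, a
+		share-mode implementation returns::
+
+			(database.sys.try_advisory_shared,
+			 database.sys.acquire_advisory_shared,
+			 database.sys.release_advisory_shared)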
+ """ + + @staticmethod + def _split_lock_identifiers(idseq): + # lame O(2) + id_pairs = [ + list(x) if x.__class__ is not int else [None,None] + for x in idseq + ] + ids = [ + x if x.__class__ is int else None + for x in idseq + ] + return (id_pairs, ids) + + def acquire(self, blocking = True, len = len): + """ + Acquire the locks using the configured identifiers. + """ + if self._count == 0: + # _count is zero, so the locks need to be acquired. + wait = bool(blocking) + if wait: + self._acquire(self._id_pairs, self._ids) + else: + # grab the success of each lock id. if some were + # unsuccessful, then the ones that were successful need to be + # released. + r = self._try(self._id_pairs, self._ids) + # accumulate the identifiers that *did* lock + release_seq = [ + id for didlock, id in zip(r, self.identifiers) if didlock[0] + ] + if len(release_seq) != len(self.identifiers): + # some failed, so release the acquired and return False + # + # reverse in case there is another waiting for all. + # that is, release last-to-first so that if another is waiting + # on the same seq that it should be able to acquire all of + # them once the contended lock is released. + release_seq.reverse() + self._release(*self._split_lock_identifiers(release_seq)) + # unable to acquire all. + return False + self._count = self._count + 1 + return True + + def __enter__(self): + self.acquire() + return self + + def release(self): + """ + Release the locks using the configured identifiers. + """ + if self._count < 1: + raise RuntimeError("cannot release un-acquired lock") + if not self.database.closed and self._count > 0: + # if the database has been closed, or the count will + # remain non-zero, there is no need to release. + self._release(reversed(self._id_pairs), reversed(self._ids)) + # decrement the count nonetheless. + self._count = self._count - 1 + + def __exit__(self, typ, val, tb): + self.release() + + def locked(self): + """ + Whether the locks have been acquired. This method is sensitive to the + connection's state. If the connection is closed, it will return False. + """ + return (self._count > 0) and (not self.database.closed) + + @property + def state(self): + return 'locked' if self.locked() else 'unlocked' + + def __init__(self, database, *identifiers): + """ + Initialize the lock object to manage a sequence of advisory locks + for use with the given database. + """ + self._count = 0 + self.connection = self.database = database + self.identifiers = identifiers + self._id_pairs, self._ids = self._split_lock_identifiers(identifiers) + self._try, self._acquire, self._release = self.__select_statements__() + +class ShareLock(ALock): + mode = 'share' + + def __select_statements__(self): + return ( + self.database.sys.try_advisory_shared, + self.database.sys.acquire_advisory_shared, + self.database.sys.release_advisory_shared, + ) + +class ExclusiveLock(ALock): + mode = 'exclusive' + + def __select_statements__(self): + return ( + self.database.sys.try_advisory_exclusive, + self.database.sys.acquire_advisory_exclusive, + self.database.sys.release_advisory_exclusive, + ) diff --git a/py_opengauss/api.py b/py_opengauss/api.py new file mode 100644 index 0000000000000000000000000000000000000000..e177c30ec567cdb60630ad8498c4ce6b3b1ea1a1 --- /dev/null +++ b/py_opengauss/api.py @@ -0,0 +1,1379 @@ +## +# .api - ABCs for database interface elements +## +""" +Application Programmer Interfaces for PostgreSQL. + +``postgresql.api`` is a collection of Python APIs for the PostgreSQL DBMS. 
It +is designed to take full advantage of PostgreSQL's features to provide the +Python programmer with substantial convenience. + +This module is used to define "PG-API". It creates a set of ABCs +that makes up the basic interfaces used to work with a PostgreSQL server. +""" +import collections.abc +import abc + +from .python.element import Element + +__all__ = [ + 'Message', + 'Statement', + 'Chunks', + 'Cursor', + 'Connector', + 'Category', + 'Database', + 'TypeIO', + 'Connection', + 'Transaction', + 'Settings', + 'StoredProcedure', + 'Driver', + 'Installation', + 'Cluster', +] + +class Message(Element): + """ + A message emitted by PostgreSQL. + A message being a NOTICE, WARNING, INFO, etc. + """ + _e_label = 'MESSAGE' + + severities = ( + 'DEBUG', + 'INFO', + 'NOTICE', + 'WARNING', + 'ERROR', + 'FATAL', + 'PANIC', + ) + sources = ( + 'SERVER', + 'CLIENT', + ) + + @property + @abc.abstractmethod + def source(self) -> str: + """ + Where the message originated from. Normally, 'SERVER', but sometimes + 'CLIENT'. + """ + + @property + @abc.abstractmethod + def code(self) -> str: + """ + The SQL state code of the message. + """ + + @property + @abc.abstractmethod + def message(self) -> str: + """ + The primary message string. + """ + + @property + @abc.abstractmethod + def details(self) -> dict: + """ + The additional details given with the message. Common keys *should* be the + following: + + * 'severity' + * 'context' + * 'detail' + * 'hint' + * 'file' + * 'line' + * 'function' + * 'position' + * 'internal_position' + * 'internal_query' + """ + + @abc.abstractmethod + def isconsistent(self, other) -> bool: + """ + Whether the fields of the `other` Message object is consistent with the + fields of `self`. + + This *must* return the result of the comparison of code, source, message, + and details. + + This method is provided as the alternative to overriding equality; + often, pointer equality is the desirable means for comparison, but + equality of the fields is also necessary. + """ + +class Result(Element): + """ + A result is an object managing the results of a prepared statement. + + These objects represent a binding of parameters to a given statement object. + + For results that were constructed on the server and a reference passed back + to the client, statement and parameters may be None. + """ + _e_label = 'RESULT' + _e_factors = ('statement', 'parameters', 'cursor_id') + + @abc.abstractmethod + def close(self) -> None: + """ + Close the Result discarding any supporting resources and causing + future read operations to emit empty record sets. + """ + + @property + @abc.abstractmethod + def cursor_id(self) -> str: + """ + The cursor's identifier. + """ + + @property + @abc.abstractmethod + def sql_column_types(self) -> [str]: + """ + The type of the columns produced by the cursor. + + A sequence of `str` objects stating the SQL type name:: + + ['INTEGER', 'CHARACTER VARYING', 'INTERVAL'] + """ + + @property + @abc.abstractmethod + def pg_column_types(self) -> [int]: + """ + The type Oids of the columns produced by the cursor. + + A sequence of `int` objects stating the SQL type name:: + + [27, 28] + """ + + @property + @abc.abstractmethod + def column_names(self) -> [str]: + """ + The attribute names of the columns produced by the cursor. + + A sequence of `str` objects stating the column name:: + + ['column1', 'column2', 'emp_name'] + """ + + @property + @abc.abstractmethod + def column_types(self) -> [str]: + """ + The Python types of the columns produced by the cursor. 
+ + A sequence of type objects:: + + [, ] + """ + + @property + @abc.abstractmethod + def parameters(self) -> (tuple, None): + """ + The parameters bound to the cursor. `None`, if unknown and an empty tuple + `()`, if no parameters were given. + + These should be the *original* parameters given to the invoked statement. + + This should only be `None` when the cursor is created from an identifier, + `postgresql.api.Database.cursor_from_id`. + """ + + @property + @abc.abstractmethod + def statement(self) -> ("Statement", None): + """ + The query object used to create the cursor. `None`, if unknown. + + This should only be `None` when the cursor is created from an identifier, + `postgresql.api.Database.cursor_from_id`. + """ + +@collections.abc.Iterator.register +class Chunks(Result): + pass + +@collections.abc.Iterator.register +class Cursor(Result): + """ + A `Cursor` object is an interface to a sequence of tuples(rows). A result + set. Cursors publish a file-like interface for reading tuples from a cursor + declared on the database. + + `Cursor` objects are created by invoking the `Statement.declare` + method or by opening a cursor using an identifier via the + `Database.cursor_from_id` method. + """ + _e_label = 'CURSOR' + + _seek_whence_map = { + 0 : 'ABSOLUTE', + 1 : 'RELATIVE', + 2 : 'FROM_END', + 3 : 'FORWARD', + 4 : 'BACKWARD' + } + _direction_map = { + True : 'FORWARD', + False : 'BACKWARD', + } + + @abc.abstractmethod + def clone(self) -> "Cursor": + """ + Create a new cursor using the same factors as `self`. + """ + + def __iter__(self): + return self + + @property + @abc.abstractmethod + def direction(self) -> bool: + """ + The default `direction` argument for read(). + + When `True` reads are FORWARD. + When `False` reads are BACKWARD. + + Cursor operation option. + """ + + @abc.abstractmethod + def read(self, quantity = None, direction = None) -> ["Row"]: + """ + Read, fetch, the specified number of rows and return them in a list. + If quantity is `None`, all records will be fetched. + + `direction` can be used to override the default configured direction. + + This alters the cursor's position. + + Read does not directly correlate to FETCH. If zero is given as the + quantity, an empty sequence *must* be returned. + """ + + @abc.abstractmethod + def __next__(self) -> "Row": + """ + Get the next tuple in the cursor. + Advances the cursor position by one. + """ + + @abc.abstractmethod + def seek(self, offset, whence = 'ABSOLUTE'): + """ + Set the cursor's position to the given offset with respect to the + whence parameter and the configured direction. + + Whence values: + + ``0`` or ``"ABSOLUTE"`` + Absolute. + ``1`` or ``"RELATIVE"`` + Relative. + ``2`` or ``"FROM_END"`` + Absolute from end. + ``3`` or ``"FORWARD"`` + Relative forward. + ``4`` or ``"BACKWARD"`` + Relative backward. + + Direction effects whence. If direction is BACKWARD, ABSOLUTE positioning + will effectively be FROM_END, RELATIVE's position will be negated, and + FROM_END will effectively be ABSOLUTE. + """ + +class Execution(metaclass = abc.ABCMeta): + """ + The abstract class of execution methods. + """ + + @abc.abstractmethod + def __call__(self, *parameters) -> ["Row"]: + """ + Execute the prepared statement with the given arguments as parameters. + + Usage: + + >>> p=db.prepare("SELECT column FROM ttable WHERE key = $1") + >>> p('identifier') + [...] 
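+
+		For a single value, `first` (documented below) is usually more
+		convenient; for large result sets, prefer `rows`, `column`, or `chunks`.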
+ """ + + @abc.abstractmethod + def column(self, *parameters) -> collections.abc.Iterable: + """ + Return an iterator producing the values of first column of the + rows produced by the cursor created from the statement bound with the + given parameters. + + Column iterators are never scrollable. + + Supporting cursors will be WITH HOLD when outside of a transaction to + allow cross-transaction access. + + `column` is designed for the situations involving large data sets. + + Each iteration returns a single value. + + column expressed in sibling terms:: + + return map(operator.itemgetter(0), self.rows(*parameters)) + """ + + @abc.abstractmethod + def chunks(self, *parameters) -> collections.abc.Iterable: + """ + Return an iterator producing sequences of rows produced by the cursor + created from the statement bound with the given parameters. + + Chunking iterators are *never* scrollable. + + Supporting cursors will be WITH HOLD when outside of a transaction. + + `chunks` is designed for moving large data sets efficiently. + + Each iteration returns sequences of rows *normally* of length(seq) == + chunksize. If chunksize is unspecified, a default, positive integer will + be filled in. The rows contained in the sequences are only required to + support the basic `collections.abc.Sequence` interfaces; simple and quick + sequence types should be used. + """ + + @abc.abstractmethod + def rows(self, *parameters) -> collections.abc.Iterable: + """ + Return an iterator producing rows produced by the cursor + created from the statement bound with the given parameters. + + Row iterators are never scrollable. + + Supporting cursors will be WITH HOLD when outside of a transaction to + allow cross-transaction access. + + `rows` is designed for the situations involving large data sets. + + Each iteration returns a single row. Arguably, best implemented:: + + return itertools.chain.from_iterable(self.chunks(*parameters)) + """ + + @abc.abstractmethod + def column(self, *parameters) -> collections.abc.Iterable: + """ + Return an iterator producing the values of the first column in + the cursor created from the statement bound with the given parameters. + + Column iterators are never scrollable. + + Supporting cursors will be WITH HOLD when outside of a transaction to + allow cross-transaction access. + + `column` is designed for the situations involving large data sets. + + Each iteration returns a single value. `column` is equivalent to:: + + return map(operator.itemgetter(0), self.rows(*parameters)) + """ + + @abc.abstractmethod + def declare(self, *parameters) -> Cursor: + """ + Return a scrollable cursor with hold using the statement bound with the + given parameters. + """ + + @abc.abstractmethod + def first(self, *parameters): + """ + Execute the prepared statement with the given arguments as parameters. + If the statement returns rows with multiple columns, return the first + row. If the statement returns rows with a single column, return the + first column in the first row. If the query does not return rows at all, + return the count or `None` if no count exists in the completion message. + + Usage: + + >>> db.prepare("SELECT * FROM ttable WHERE key = $1").first("somekey") + ('somekey', 'somevalue') + >>> db.prepare("SELECT 'foo'").first() + 'foo' + >>> db.prepare("INSERT INTO atable (col) VALUES (1)").first() + 1 + """ + + @abc.abstractmethod + def load_rows(self, iterable): + """ + Given an iterable, `iterable`, feed the produced parameters to the + query. 
This is a bulk-loading interface for parameterized queries. + + Effectively, it is equivalent to: + + >>> q = db.prepare(sql) + >>> for i in iterable: + ... q(*i) + + Its purpose is to allow the implementation to take advantage of the + knowledge that a series of parameters are to be loaded so that the + operation can be optimized. + """ + + @abc.abstractmethod + def load_chunks(self, iterable): + """ + Given an iterable, `iterable`, feed the produced parameters of the chunks + produced by the iterable to the query. This is a bulk-loading interface + for parameterized queries. + + Effectively, it is equivalent to: + + >>> ps = db.prepare(...) + >>> for c in iterable: + ... for i in c: + ... q(*i) + + Its purpose is to allow the implementation to take advantage of the + knowledge that a series of chunks of parameters are to be loaded so + that the operation can be optimized. + """ + +@collections.abc.Iterator.register +@collections.abc.Callable.register +@Execution.register +class Statement(Element): + """ + Instances of `Statement` are returned by the `prepare` method of + `Database` instances. + + A Statement is an Iterable as well as Callable. + + The Iterable interface is supported for queries that take no arguments at + all. It allows the syntax:: + + >>> for x in db.prepare('select * FROM table'): + ... pass + """ + _e_label = 'STATEMENT' + _e_factors = ('database', 'statement_id', 'string',) + + @property + @abc.abstractmethod + def statement_id(self) -> str: + """ + The statment's identifier. + """ + + @property + @abc.abstractmethod + def string(self) -> object: + """ + The SQL string of the prepared statement. + + `None` if not available. This can happen in cases where a statement is + prepared on the server and a reference to the statement is sent to the + client which subsequently uses the statement via the `Database`'s + `statement` constructor. + """ + + @property + @abc.abstractmethod + def sql_parameter_types(self) -> [str]: + """ + The type of the parameters required by the statement. + + A sequence of `str` objects stating the SQL type name:: + + ['INTEGER', 'VARCHAR', 'INTERVAL'] + """ + + @property + @abc.abstractmethod + def sql_column_types(self) -> [str]: + """ + The type of the columns produced by the statement. + + A sequence of `str` objects stating the SQL type name:: + + ['INTEGER', 'VARCHAR', 'INTERVAL'] + """ + + @property + @abc.abstractmethod + def pg_parameter_types(self) -> [int]: + """ + The type Oids of the parameters required by the statement. + + A sequence of `int` objects stating the PostgreSQL type Oid:: + + [27, 28] + """ + + @property + @abc.abstractmethod + def pg_column_types(self) -> [int]: + """ + The type Oids of the columns produced by the statement. + + A sequence of `int` objects stating the SQL type name:: + + [27, 28] + """ + + @property + @abc.abstractmethod + def column_names(self) -> [str]: + """ + The attribute names of the columns produced by the statement. + + A sequence of `str` objects stating the column name:: + + ['column1', 'column2', 'emp_name'] + """ + + @property + @abc.abstractmethod + def column_types(self) -> [type]: + """ + The Python types of the columns produced by the statement. + + A sequence of type objects:: + + [, ] + """ + + @property + @abc.abstractmethod + def parameter_types(self) -> [type]: + """ + The Python types expected of parameters given to the statement. 
+ + A sequence of type objects:: + + [, ] + """ + + @abc.abstractmethod + def clone(self) -> "Statement": + """ + Create a new statement object using the same factors as `self`. + + When used for refreshing plans, the new clone should replace references to + the original. + """ + + @abc.abstractmethod + def close(self) -> None: + """ + Close the prepared statement releasing resources associated with it. + """ +PreparedStatement = Statement + +@collections.abc.Callable.register +class StoredProcedure(Element): + """ + A function stored on the database. + """ + _e_label = 'FUNCTION' + _e_factors = ('database',) + + @abc.abstractmethod + def __call__(self, *args, **kw) -> (object, Cursor, collections.abc.Iterable): + """ + Execute the procedure with the given arguments. If keyword arguments are + passed they must be mapped to the argument whose name matches the key. + If any positional arguments are given, they must fill in gaps created by + the stated keyword arguments. If too few or too many arguments are + given, a TypeError must be raised. If a keyword argument is passed where + the procedure does not have a corresponding argument name, then, + likewise, a TypeError must be raised. + + In the case where the `StoredProcedure` references a set returning + function(SRF), the result *must* be an iterable. SRFs that return single + columns *must* return an iterable of that column; not row data. If the + SRF returns a composite(OUT parameters), it *should* return a `Cursor`. + """ + +## +# Arguably, it would be wiser to isolate blocks, and savepoints, but the utility +# of the separation is not significant. It's really +# more interesting as a formality that the user may explicitly state the +# type of the transaction. However, this capability is not completely absent +# from the current interface as the configuration parameters, or lack thereof, +# help imply the expectations. +class Transaction(Element): + """ + A `Tranaction` is an element that represents a transaction in the session. + Once created, it's ready to be started, and subsequently committed or + rolled back. + + Read-only transaction: + + >>> with db.xact(mode = 'read only'): + ... ... + + Read committed isolation: + + >>> with db.xact(isolation = 'READ COMMITTED'): + ... ... + + Savepoints are created if inside a transaction block: + + >>> with db.xact(): + ... with db.xact(): + ... ... + """ + _e_label = 'XACT' + _e_factors = ('database',) + + @property + @abc.abstractmethod + def mode(self) -> (None, str): + """ + The mode of the transaction block: + + START TRANSACTION [ISOLATION] ; + + The `mode` property is a string and will be directly interpolated into the + START TRANSACTION statement. + """ + + @property + @abc.abstractmethod + def isolation(self) -> (None, str): + """ + The isolation level of the transaction block: + + START TRANSACTION [MODE]; + + The `isolation` property is a string and will be directly interpolated into + the START TRANSACTION statement. + """ + + @abc.abstractmethod + def start(self) -> None: + """ + Start the transaction. + + If the database is in a transaction block, the transaction should be + configured as a savepoint. If any transaction block configuration was + applied to the transaction, raise a `postgresql.exceptions.OperationError`. + + If the database is not in a transaction block, start one using the + configuration where: + + `self.isolation` specifies the ``ISOLATION LEVEL``. Normally, ``READ + COMMITTED``, ``SERIALIZABLE``, or ``READ UNCOMMITTED``. 
+ + `self.mode` specifies the mode of the transaction. Normally, ``READ + ONLY`` or ``READ WRITE``. + + If the transaction is already open, do nothing. + + If the transaction has been committed or aborted, raise an + `postgresql.exceptions.OperationError`. + """ + begin = start + + @abc.abstractmethod + def commit(self) -> None: + """ + Commit the transaction. + + If the transaction is a block, issue a COMMIT statement. + + If the transaction was started inside a transaction block, it should be + identified as a savepoint, and the savepoint should be released. + + If the transaction has already been committed, do nothing. + """ + + @abc.abstractmethod + def rollback(self) -> None: + """ + Abort the transaction. + + If the transaction is a savepoint, ROLLBACK TO the savepoint identifier. + + If the transaction is a transaction block, issue an ABORT. + + If the transaction has already been aborted, do nothing. + """ + abort = rollback + + @abc.abstractmethod + def __enter__(self): + """ + Run the `start` method and return self. + """ + + @abc.abstractmethod + def __exit__(self, typ, obj, tb): + """ + If an exception is indicated by the parameters, run the transaction's + `rollback` method iff the database is still available(not closed), and + return a `False` value. + + If an exception is not indicated, but the database's transaction state is + in error, run the transaction's `rollback` method and raise a + `postgresql.exceptions.InFailedTransactionError`. If the database is + unavailable, the `rollback` method should cause a + `postgresql.exceptions.ConnectionDoesNotExistError` exception to occur. + + Otherwise, run the transaction's `commit` method. + + When the `commit` is ultimately unsuccessful or not ran at all, the purpose + of __exit__ is to resolve the error state of the database iff the + database is available(not closed) so that more commands can be after the + block's exit. + """ + +@collections.abc.MutableMapping.register +class Settings(Element): + """ + A mapping interface to the session's settings. This provides a direct + interface to ``SHOW`` or ``SET`` commands. Identifiers and values need + not be quoted specially as the implementation must do that work for the + user. + """ + _e_label = 'SETTINGS' + + @abc.abstractmethod + def __getitem__(self, key): + """ + Return the setting corresponding to the given key. The result should be + consistent with what the ``SHOW`` command returns. If the key does not + exist, raise a KeyError. + """ + + @abc.abstractmethod + def __setitem__(self, key, value): + """ + Set the setting with the given key to the given value. The action should + be consistent with the effect of the ``SET`` command. + """ + + @abc.abstractmethod + def __call__(self, **kw): + """ + Create a context manager applying the given settings on __enter__ and + restoring the old values on __exit__. + + >>> with db.settings(search_path = 'local,public'): + ... ... + """ + + @abc.abstractmethod + def get(self, key, default = None): + """ + Get the setting with the corresponding key. If the setting does not + exist, return the `default`. + """ + + @abc.abstractmethod + def getset(self, keys): + """ + Return a dictionary containing the key-value pairs of the requested + settings. If *any* of the keys do not exist, a `KeyError` must be raised + with the set of keys that did not exist. + """ + + @abc.abstractmethod + def update(self, mapping): + """ + For each key-value pair, incur the effect of the `__setitem__` method. 
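+
+		A rough sketch, assuming `mapping` is a dict-like object::
+
+			for k, v in mapping.items():
+				self[k] = v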
+ """ + + @abc.abstractmethod + def keys(self): + """ + Return an iterator to all of the settings' keys. + """ + + @abc.abstractmethod + def values(self): + """ + Return an iterator to all of the settings' values. + """ + + @abc.abstractmethod + def items(self): + """ + Return an iterator to all of the setting value pairs. + """ + +class Database(Element): + """ + The interface to an individual database. `Connection` objects inherit from + this + """ + _e_label = 'DATABASE' + + @property + @abc.abstractmethod + def backend_id(self) -> (int, None): + """ + The backend's process identifier. + """ + + @property + @abc.abstractmethod + def version_info(self) -> tuple: + """ + A version tuple of the database software similar Python's `sys.version_info`. + + >>> db.version_info + (8, 1, 3, '', 0) + """ + + @property + @abc.abstractmethod + def client_address(self) -> (str, None): + """ + The client address that the server sees. This is obtainable by querying + the ``pg_catalog.pg_stat_activity`` relation. + + `None` if unavailable. + """ + + @property + @abc.abstractmethod + def client_port(self) -> (int, None): + """ + The client port that the server sees. This is obtainable by querying + the ``pg_catalog.pg_stat_activity`` relation. + + `None` if unavailable. + """ + + @property + @abc.abstractmethod + def xact(self, isolation = None, mode = None) -> Transaction: + """ + Create a `Transaction` object using the given keyword arguments as its + configuration. + """ + + @property + @abc.abstractmethod + def settings(self) -> Settings: + """ + A `Settings` instance bound to the `Database`. + """ + + @abc.abstractmethod + def do(language, source) -> None: + """ + Execute a DO statement using the given language and source. + Always returns `None`. + + Likely to be a function of Connection.execute. + """ + + @abc.abstractmethod + def execute(sql) -> None: + """ + Execute an arbitrary block of SQL. Always returns `None` and raise + an exception on error. + """ + + @abc.abstractmethod + def prepare(self, sql : str) -> Statement: + """ + Create a new `Statement` instance bound to the connection + using the given SQL. + + >>> s = db.prepare("SELECT 1") + >>> c = s() + >>> c.next() + (1,) + """ + + @abc.abstractmethod + def query(self, sql : str, *args) -> Execution: + """ + Prepare and execute the statement, `sql`, with the given arguments. + Equivalent to ``db.prepare(sql)(*args)``. + """ + + @abc.abstractmethod + def statement_from_id(self, statement_id) -> Statement: + """ + Create a `Statement` object that was already prepared on the + server. The distinction between this and a regular query is that it + must be explicitly closed if it is no longer desired, and it is + instantiated using the statement identifier as opposed to the SQL + statement itself. + """ + + @abc.abstractmethod + def cursor_from_id(self, cursor_id) -> Cursor: + """ + Create a `Cursor` object from the given `cursor_id` that was already + declared on the server. + + `Cursor` objects created this way must *not* be closed when the object + is garbage collected. Rather, the user must explicitly close it for + the server resources to be released. This is in contrast to `Cursor` + objects that are created by invoking a `Statement` or a SRF + `StoredProcedure`. + """ + + @abc.abstractmethod + def proc(self, procedure_id) -> StoredProcedure: + """ + Create a `StoredProcedure` instance using the given identifier. 
+ + The `proc_id` given can be either an ``Oid``, or a ``regprocedure`` + that identifies the stored procedure to create the interface for. + + >>> p = db.proc('version()') + >>> p() + 'PostgreSQL 8.3.0' + >>> qstr = "select oid from pg_proc where proname = 'generate_series'" + >>> db.prepare(qstr).first() + 1069 + >>> generate_series = db.proc(1069) + >>> list(generate_series(1,5)) + [1, 2, 3, 4, 5] + """ + + @abc.abstractmethod + def reset(self) -> None: + """ + Reset the connection into it's original state. + + Issues a ``RESET ALL`` to the database. If the database supports + removing temporary tables created in the session, then remove them. + Reapply initial configuration settings such as path. + + The purpose behind this method is to provide a soft-reconnect method + that re-initializes the connection into its original state. One + obvious use of this would be in a connection pool where the connection + is being recycled. + """ + + @abc.abstractmethod + def notify(self, *channels, **channel_and_payload) -> int: + """ + NOTIFY the channels with the given payload. + + Equivalent to issuing "NOTIFY " or "NOTIFY , " + for each item in `channels` and `channel_and_payload`. All NOTIFYs issued + *must* occur in the same transaction. + + The items in `channels` can either be a string or a tuple. If a string, + no payload is given, but if an item is a `builtins.tuple`, the second item + will be given as the payload. `channels` offers a means to issue NOTIFYs + in guaranteed order. + + The items in `channel_and_payload` are all payloaded NOTIFYs where the + keys are the channels and the values are the payloads. Order is undefined. + """ + + @abc.abstractmethod + def listen(self, *channels) -> None: + """ + Start listening to the given channels. + + Equivalent to issuing "LISTEN " for x in channels. + """ + + @abc.abstractmethod + def unlisten(self, *channels) -> None: + """ + Stop listening to the given channels. + + Equivalent to issuing "UNLISTEN " for x in channels. + """ + + @abc.abstractmethod + def listening_channels(self) -> ["channel name", ...]: + """ + Return an *iterator* to all the channels currently being listened to. + """ + + @abc.abstractmethod + def iternotifies(self, timeout = None) -> collections.abc.Iterator: + """ + Return an iterator to the notifications received by the connection. The + iterator *must* produce triples in the form ``(channel, payload, pid)``. + + If timeout is not `None`, `None` *must* be emitted at the specified + timeout interval. If the timeout is zero, all the pending notifications + *must* be yielded by the iterator and then `StopIteration` *must* be + raised. + + If the connection is closed for any reason, the iterator *must* silently + stop by raising `StopIteration`. Further error control is then the + responsibility of the user. + """ + +class TypeIO(Element): + _e_label = 'TYPIO' + + def _e_metas(self): + return () + +class SocketFactory(object): + @property + @abc.abstractmethod + def fatal_exception(self) -> Exception: + """ + The exception that is raised by sockets that indicate a fatal error. + + The exception can be a base exception as the `fatal_error_message` will + indicate if that particular exception is actually fatal. + """ + + @property + @abc.abstractmethod + def timeout_exception(self) -> Exception: + """ + The exception raised by the socket when an operation could not be + completed due to a configured time constraint. 
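+
+		For a plain TCP socket this would typically be `socket.timeout`, though
+		the concrete type is implementation-defined.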
+ """ + + @property + @abc.abstractmethod + def tryagain_exception(self) -> Exception: + """ + The exception raised by the socket when an operation was interrupted, but + should be tried again. + """ + + @property + @abc.abstractmethod + def tryagain(self, err : Exception) -> bool: + """ + Whether or not `err` suggests the operation should be tried again. + """ + + @abc.abstractmethod + def fatal_exception_message(self, err : Exception) -> (str, None): + """ + A function returning a string describing the failure, this string will be + given to the `postgresql.exceptions.ConnectionFailure` instance that will + subsequently be raised by the `Connection` object. + + Returns `None` when `err` is not actually fatal. + """ + + @abc.abstractmethod + def socket_secure(self, socket): + """ + Return a reference to the secured socket using the given parameters. + + If securing the socket for the connector is impossible, the user should + never be able to instantiate the connector with parameters requesting + security. + """ + + @abc.abstractmethod + def socket_factory_sequence(self) -> [collections.abc.Callable]: + """ + Return a sequence of `SocketCreator`s that `Connection` objects will use to + create the socket object. + """ + +class Category(Element): + """ + A category is an object that initializes the subject connection for a + specific purpose. + + Arguably, a runtime class for use with connections. + """ + _e_label = 'CATEGORY' + _e_factors = () + + @abc.abstractmethod + def __call__(self, connection): + """ + Initialize the given connection in order to conform to the category. + """ + +class Connector(Element): + """ + A connector is an object providing the necessary information to establish a + connection. This includes credentials, database settings, and many times + addressing information. + """ + _e_label = 'CONNECTOR' + _e_factors = ('driver', 'category') + + def __call__(self, *args, **kw): + """ + Create and connect. Arguments will be given to the `Connection` instance's + `connect` method. + """ + return self.driver.connection(self, *args, **kw) + + def __init__(self, + user : str = None, + password : str = None, + database : str = None, + settings : (dict, [(str,str)]) = None, + category : Category = None, + ): + if user is None: + # sure, it's a "required" keyword, makes for better documentation + raise TypeError("'user' is a required keyword") + self.user = user + self.password = password + self.database = database + self.settings = settings + self.category = category + if category is not None and not isinstance(category, Category): + raise TypeError("'category' must a be `None` or `postgresql.api.Category`") + +class Connection(Database): + """ + The interface to a connection to a PostgreSQL database. This is a + `Database` interface with the additional connection management tools that + are particular to using a remote database. + """ + _e_label = 'CONNECTION' + _e_factors = ('connector',) + + @property + @abc.abstractmethod + def connector(self) -> Connector: + """ + The :py:class:`Connector` instance facilitating the `Connection` object's + communication and initialization. + """ + + @property + @abc.abstractmethod + def closed(self) -> bool: + """ + `True` if the `Connection` is closed, `False` if the `Connection` is + open. + + >>> db.closed + True + """ + + @abc.abstractmethod + def clone(self) -> "Connection": + """ + Create another connection using the same factors as `self`. The returned + object should be open and ready for use. 
+ """ + + @abc.abstractmethod + def connect(self) -> None: + """ + Establish the connection to the server and initialize the category. + + Does nothing if the connection is already established. + """ + cat = self.connector.category + if cat is not None: + cat(self) + + @abc.abstractmethod + def close(self) -> None: + """ + Close the connection. + + Does nothing if the connection is already closed. + """ + + @abc.abstractmethod + def __enter__(self): + """ + Establish the connection and return self. + """ + + @abc.abstractmethod + def __exit__(self, typ, obj, tb): + """ + Closes the connection and returns `False` when an exception is passed in, + `True` when `None`. + """ + +class Driver(Element): + """ + The `Driver` element provides the `Connector` and other information + pertaining to the implementation of the driver. Information about what the + driver supports is available in instances. + """ + _e_label = "DRIVER" + _e_factors = () + + @abc.abstractmethod + def connect(**kw): + """ + Create a connection using the given parameters for the Connector. + """ + +class Installation(Element): + """ + Interface to a PostgreSQL installation. Instances would provide various + information about an installation of PostgreSQL accessible by the Python + """ + _e_label = "INSTALLATION" + _e_factors = () + + @property + @abc.abstractmethod + def version(self): + """ + A version string consistent with what `SELECT version()` would output. + """ + + @property + @abc.abstractmethod + def version_info(self): + """ + A tuple specifying the version in a form similar to Python's + sys.version_info. (8, 3, 3, 'final', 0) + + See `postgresql.versionstring`. + """ + + @property + @abc.abstractmethod + def type(self): + """ + The "type" of PostgreSQL. Normally, the first component of the string + returned by pg_config. + """ + + @property + @abc.abstractmethod + def ssl(self) -> bool: + """ + Whether the installation supports SSL. + """ + +class Cluster(Element): + """ + Interface to a PostgreSQL cluster--a data directory. An implementation of + this provides a means to control a server. + """ + _e_label = 'CLUSTER' + _e_factors = ('installation', 'data_directory') + + @property + @abc.abstractmethod + def installation(self) -> Installation: + """ + The installation used by the cluster. + """ + + @property + @abc.abstractmethod + def data_directory(self) -> str: + """ + The path to the data directory of the cluster. + """ + + @abc.abstractmethod + def init(self, + initdb = None, + user = None, password = None, + encoding = None, locale = None, + collate = None, ctype = None, + monetary = None, numeric = None, time = None, + text_search_config = None, + xlogdir = None, + ): + """ + Create the cluster at the `data_directory` associated with the Cluster + instance. + """ + + @abc.abstractmethod + def drop(self): + """ + Kill the server and completely remove the data directory. + """ + + @abc.abstractmethod + def start(self): + """ + Start the cluster. + """ + + @abc.abstractmethod + def stop(self): + """ + Signal the server to shutdown. + """ + + @abc.abstractmethod + def kill(self): + """ + Kill the server. + """ + + @abc.abstractmethod + def restart(self): + """ + Restart the cluster. + """ + + @abc.abstractmethod + def wait_until_started(self, timeout = 10): + """ + After the start() method is ran, the database may not be ready for use. + This method provides a mechanism to block until the cluster is ready for + use. 
+ + If the `timeout` is reached, the method *must* throw a + `postgresql.exceptions.ClusterTimeoutError`. + """ + + @abc.abstractmethod + def wait_until_stopped(self, timeout = 10): + """ + After the stop() method is ran, the database may still be running. + This method provides a mechanism to block until the cluster is completely + shutdown. + + If the `timeout` is reached, the method *must* throw a + `postgresql.exceptions.ClusterTimeoutError`. + """ + + @property + @abc.abstractmethod + def settings(self): + """ + A `Settings` interface to the ``postgresql.conf`` file associated with the + cluster. + """ + + @abc.abstractmethod + def __enter__(self): + """ + Start the cluster if it's not already running, and wait for it to be + readied. + """ + + @abc.abstractmethod + def __exit__(self, exc, val, tb): + """ + Stop the cluster and wait for it to shutdown *iff* it was started by the + corresponding enter. + """ + +__docformat__ = 'reStructuredText' +if __name__ == '__main__': + help(__package__ + '.api') diff --git a/py_opengauss/bin/__init__.py b/py_opengauss/bin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..946865ce5214fb93834d3c8b53d7a30016f157ad --- /dev/null +++ b/py_opengauss/bin/__init__.py @@ -0,0 +1,11 @@ +""" +Console-script collection package. + +Contents: + + pg_python + Python console with a PostgreSQL connection bound to `db`. + + pg_dotconf + Modify a PostgreSQL configuration file. +""" diff --git a/py_opengauss/bin/pg_dotconf.py b/py_opengauss/bin/pg_dotconf.py new file mode 100644 index 0000000000000000000000000000000000000000..31306e990ff1216c1e1be030b071e38ed87aaae4 --- /dev/null +++ b/py_opengauss/bin/pg_dotconf.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python +import sys +import os +from optparse import OptionParser +from .. import configfile +from .. import __version__ + +__all__ = ['command'] + +def command(args): + """ + pg_dotconf script entry point. + """ + op = OptionParser( + "%prog [--stdout] [-f settings] postgresql.conf ([param=val]|[param])*", + version = __version__ + ) + op.add_option( + '-f', '--file', + dest = 'settings', + help = 'A file of settings to *apply* to the given "postgresql.conf"', + default = [], + action = 'append', + ) + op.add_option( + '--stdout', + dest = 'stdout', + help = 'Redirect the product to standard output instead of writing back to the "postgresql.conf" file', + action = 'store_true', + default = False + ) + co, ca = op.parse_args(args[1:]) + if not ca: + return 0 + + settings = {} + for sfp in co.settings: + with open(sfp) as sf: + for line in sf: + pl = configfile.parse_line(line) + if pl is not None: + if comment not in line[pl[0].start]: + settings[line[pl[0]]] = unquote(line[pl[1]]) + + prev = None + for p in ca[1:]: + if '=' not in p: + k = p + v = None + else: + k, v = p.split('=', 1) + k = k.strip() + if not k: + sys.stderr.write("ERROR: invalid setting, %r after %r%s" %( + p, prev, os.linesep + )) + sys.stderr.write( + "HINT: Settings must take the form 'setting=value' " \ + "or 'setting_name_to_comment'. Settings must also be received " \ + "as a single argument." 
+ os.linesep + ) + sys.exit(1) + prev = p + settings[k] = v + + fp = ca[0] + with open(fp, 'r') as fr: + lines = configfile.alter_config(settings, fr) + + if co.stdout or fp == '/dev/stdin': + for l in lines: + sys.stdout.write(l) + else: + with open(fp, 'w') as fw: + for l in lines: + fw.write(l) + return 0 + +if __name__ == '__main__': + sys.exit(command(sys.argv)) diff --git a/py_opengauss/bin/pg_python.py b/py_opengauss/bin/pg_python.py new file mode 100644 index 0000000000000000000000000000000000000000..a97aa97b39a7c65e6d508dacc6f9ca81789d5cb3 --- /dev/null +++ b/py_opengauss/bin/pg_python.py @@ -0,0 +1,136 @@ +## +# .bin.pg_python - Python console with a connection. +## +""" +Python command with a PG-API connection(``db``). +""" +import os +import sys +import re +import code +import optparse +import contextlib +from .. import clientparameters +from ..python import command as pycmd +from .. import project + +from ..driver import default as pg_driver +from .. import exceptions as pg_exc +from .. import sys as pg_sys +from .. import lib as pg_lib + +pq_trace = optparse.make_option( + '--pq-trace', + dest = 'pq_trace', + help = 'trace PQ protocol transmissions', + default = None, +) +default_options = [ + pq_trace, + clientparameters.option_lib, + clientparameters.option_libpath, +] + pycmd.default_optparse_options + +def command(argv = sys.argv): + p = clientparameters.DefaultParser( + "%prog [connection options] [script] ...", + version = project.version, + option_list = default_options + ) + p.disable_interspersed_args() + co, ca = p.parse_args(argv[1:]) + rv = 1 + + # Resolve the category. + pg_sys.libpath.insert(0, os.path.curdir) + pg_sys.libpath.extend(co.libpath or []) + if co.lib: + cat = pg_lib.Category(*map(pg_lib.load, co.lib)) + else: + cat = None + + trace_file = None + if co.pq_trace is not None: + trace_file = open(co.pq_trace, 'a') + + try: + need_prompt = False + cond = None + connector = None + connection = None + while connection is None: + try: + cond = clientparameters.collect(parsed_options = co, prompt_title = None) + if need_prompt: + # authspec error thrown last time, so force prompt. + cond['prompt_password'] = True + try: + clientparameters.resolve_password(cond, prompt_title = 'pg_python') + except EOFError: + raise SystemExit(1) + connector = pg_driver.fit(category = cat, **cond) + connection = connector() + if trace_file is not None: + connection.tracer = trace_file.write + connection.connect() + except pg_exc.ClientCannotConnectError as err: + for att in connection.failures: + exc = att.error + if isinstance(exc, pg_exc.AuthenticationSpecificationError): + sys.stderr.write(os.linesep + exc.message + (os.linesep*2)) + # keep prompting the user + need_prompt = True + connection = None + break + else: + # no invalid password failures.. 
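+					# The for-loop's else clause: no AuthenticationSpecificationError
+					# was found among the recorded failures, so re-raise the original
+					# ClientCannotConnectError.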
+ raise + + pythonexec = pycmd.Execution(ca, + context = getattr(co, 'python_context', None), + loader = getattr(co, 'python_main', None), + ) + + builtin_overload = { + # New built-ins + 'connector' : connector, + 'db' : connection, + 'do' : connection.do, + 'prepare' : connection.prepare, + + 'sqlexec' : connection.execute, + 'settings' : connection.settings, + 'proc' : connection.proc, + 'xact' : connection.xact, + } + if not isinstance(__builtins__, dict): + builtins_d = __builtins__.__dict__ + else: + builtins_d = __builtins__ + restore = {k : builtins_d.get(k) for k in builtin_overload} + + builtins_d.update(builtin_overload) + try: + with connection: + rv = pythonexec( + context = pycmd.postmortem(os.environ.get('PYTHON_POSTMORTEM')) + ) + exc = getattr(sys, 'last_type', None) + if rv and exc and not issubclass(exc, Exception): + # Don't try to close it if wasn't an Exception. + del connection.pq.socket + finally: + # restore __builtins__ + builtins_d.update(restore) + for k, v in builtin_overload.items(): + if v is None: + del builtins_d[x] + if trace_file is not None: + trace_file.close() + except: + pg_sys.libpath.remove(os.path.curdir) + raise + return rv + +if __name__ == '__main__': + sys.exit(command(sys.argv)) diff --git a/py_opengauss/clientparameters.py b/py_opengauss/clientparameters.py new file mode 100644 index 0000000000000000000000000000000000000000..554097671e2032ed055df65f0fa1a0bd7c31f643 --- /dev/null +++ b/py_opengauss/clientparameters.py @@ -0,0 +1,643 @@ +## +# .clientparameters +## +""" +Collect client connection parameters from various sources. + +This module provides functions for collecting client parameters from various +sources such as user relative defaults, environment variables, and even command +line options. + +There are two primary data-structures that this module deals with: normalized +parameters and denormalized parameters. + +Normalized parameters is a proper mapping object, dictionary, consisting of +the parameters used to apply to a connection creation interface. The high-level +interface, ``collect`` returns normalized parameters. + +Denormalized parameters is a sequence or iterable of key-value pairs. However, +the key is always a tuple whose components make up the "key-path". This is used +to support sub-dictionaries like settings:: + + >>> normal_params = { + 'user' : 'dbusername', + 'host' : 'localhost', + 'settings' : {'default_statistics_target' : 200, 'search_path' : 'home,public'} + } + +Denormalized parameters are used to simplify the overriding of past parameters. +For this to work with dictionaries in a general fashion, dictionary objects +would need a "deep update" method. +""" +import sys +import os +import configparser +import optparse +from itertools import chain +from functools import partial + +from . import iri as pg_iri +from . import pgpassfile as pg_pass +from . 
exceptions import Error + +class ClientParameterError(Error): + code = '-*000' + source = '.clientparameters' +class ServiceDoesNotExistError(ClientParameterError): + code = '-*srv' + +try: + from getpass import getuser, getpass +except ImportError: + getpass = raw_input + def getuser(): + return 'postgres' + +default_host = 'localhost' +default_port = 5432 + +pg_service_envvar = 'PGSERVICE' +pg_service_file_envvar = 'PGSERVICEFILE' +pg_sysconfdir_envvar = 'PGSYSCONFDIR' +pg_service_filename = 'pg_service.conf' +pg_service_user_filename = '.pg_service.conf' + +# posix +pg_home_passfile = '.pgpass' +pg_home_directory = '.postgresql' + +# win32 +pg_appdata_directory = 'postgresql' +pg_appdata_passfile = 'pgpass.conf' + +# In order to support pg_service.conf, it is +# necessary to identify driver parameters, so +# that database configuration parameters can +# be placed in settings. +pg_service_driver_parameters = set([ + 'user', + 'host', + 'database', + 'port', + 'password', + + 'sslcrtfile', + 'sslkeyfile', + 'sslrootcrtfile', + 'sslrootkeyfile', + + 'sslmode', + 'server_encoding', + 'connect_timeout', +]) + +# environment variables that will be in the parameters' "settings" dictionary. +default_envvar_settings_map = { + 'TZ' : 'timezone', + 'DATESTYLE' : 'datestyle', + 'CLIENTENCODING' : 'client_encoding', + 'GEQO' : 'geqo', + 'OPTIONS' : 'options', +} + +# Environment variables that require no transformation. +default_envvar_map = { + 'USER' : 'user', + 'DATABASE' : 'database', + 'HOST' : 'host', + 'PORT' : 'port', + 'PASSWORD' : 'password', + 'SSLMODE' : 'sslmode', + 'SSLKEY' : 'sslkey', + 'CONNECT_TIMEOUT' : 'connect_timeout', + + 'REALM' : 'kerberos4_realm', + 'KRBSRVNAME' : 'kerberos5_service', + + # Extensions + #'ROLE' : 'role', # SET ROLE $PGROLE + + # This keyword *should* never make it to a connect() function + # as `resolve_password` should be called to fill in the + # parameter accordingly. + 'PASSFILE' : 'pgpassfile', +} + +def defaults(environ = os.environ): + """ + Produce the defaults based on the existing configuration. + """ + user = getuser() or 'postgres' + userdir = os.path.expanduser('~' + user) or '/dev/null' + pgdata = os.path.join(userdir, pg_home_directory) + yield ('user',), getuser() + yield ('host',), default_host + yield ('port',), default_port + + # If appdata is available, override the pgdata and pgpassfile + # configuration settings. + if sys.platform == 'win32': + appdata = environ.get('APPDATA') + if appdata: + pgdata = os.path.join(appdata, pg_appdata_directory) + pgpassfile = os.path.join(pgdata, pg_appdata_passfile) + else: + pgpassfile = os.path.join(userdir, pg_home_passfile) + + for k, v in ( + ('sslcrtfile', os.path.join(pgdata, 'postgresql.crt')), + ('sslkeyfile', os.path.join(pgdata, 'postgresql.key')), + ('sslrootcrtfile', os.path.join(pgdata, 'root.crt')), + ('sslrootcrlfile', os.path.join(pgdata, 'root.crl')), + ('pgpassfile', pgpassfile), + ): + if os.path.exists(v): + yield (k,), v + +def envvars(environ = os.environ, modifier = 'PG'.__add__): + """ + Create a clientparams dictionary from the given environment variables. + + PGUSER -> user + PGDATABASE -> database + PGHOST -> host + PGHOSTADDR -> host (overrides PGHOST) + PGPORT -> port + + PGPASSWORD -> password + PGPASSFILE -> pgpassfile + + PGSSLMODE -> sslmode + PGREQUIRESSL gets rewritten into "sslmode = 'require'". 
+ + PGREALM -> kerberos4_realm + PGKRBSVRNAME -> kerberos5_service + PGSSLKEY -> sslkey + + PGTZ -> settings['timezone'] + PGDATESTYLE -> settings['datestyle'] + PGCLIENTENCODING -> settings['client_encoding'] + PGGEQO -> settings['geqo'] + + The 'PG' prefix can be customized via the `modifier` argument. However, + PGSYSCONFDIR will not respect any such change as it's not a client parameter + itself. + + :param modifier: environment variable key modifier + """ + hostaddr = modifier('HOSTADDR') + reqssl = modifier('REQUIRESSL') + if reqssl in environ: + if environ[reqssl].strip() == '1': + yield ('sslmode',), ('require', reqssl + '=1') + + for k, v in default_envvar_map.items(): + k = modifier(k) + if k in environ: + yield ((v,), environ[k]) + if hostaddr in environ: + yield (('host',), environ[hostaddr]) + + envvar_settings_map = (( + (modifier(k), v) for k,v in default_envvar_settings_map.items() + )) + settings = [ + (('settings', v,), environ[k]) for k, v in envvar_settings_map if k in environ + ] + + # PGSYSCONFDIR based + if pg_sysconfdir_envvar in environ: + yield ('config-pg_sysconfdir', environ[pg_sysconfdir_envvar]) + # PGSERVICEFILE based + if pg_service_file_envvar in environ: + yield ('config-pg_service_file', environ[pg_service_file_envvar]) + + service = modifier('SERVICE') + if service in environ: + yield ('pg_service', environ[service]) + +## +# optparse options +## + +option_datadir = optparse.make_option('-D', '--datadir', + help = 'location of the database storage area', + default = None, + dest = 'datadir', +) + +option_in_xact = optparse.make_option('-1', '--with-transaction', + dest = 'in_xact', + action = 'store_true', + help = 'run operation with a transaction block', +) + +def append_db_client_parameters(option, opt_str, value, parser): + # for options without arguments, None is passed in. 
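+	# Each parsed option is accumulated as a denormalized ((key,), value)
+	# pair on parser.values.db_client_parameters; e.g. "-h example.com"
+	# becomes (('host',), 'example.com').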
+ value = True if value is None else value + parser.values.db_client_parameters.append( + ((option.dest,), value) + ) + +make_option = partial( + optparse.make_option, + action = 'callback', + callback = append_db_client_parameters +) + +option_user = make_option('-U', '--username', + dest = 'user', + type = 'str', + help = 'user name to connect as', +) +option_database = make_option('-d', '--database', + type = 'str', + help = "database's name", + dest = 'database', +) +option_password = make_option('-W', '--password', + dest = 'prompt_password', + help = 'prompt for password', +) +option_host = make_option('-h', '--host', + help = 'database server host', + type = 'str', + dest = 'host', +) +option_port = make_option('-p', '--port', + help = 'database server port', + type = 'str', + dest = 'port', +) +option_unix = make_option('--unix', + help = 'path to filesystem socket', + type = 'str', + dest = 'unix', +) + +def append_settings(option, opt_str, value, parser): + """ + split the string into a (key,value) pair tuple + """ + kv = value.split('=', 1) + if len(kv) != 2: + raise OptionValueError("invalid setting argument, %r" %(value,)) + parser.values.db_client_parameters.append( + ((option.dest, kv[0]), kv[1]) + ) + +option_settings = make_option('-s', '--setting', + dest = 'settings', + help = 'run-time parameters to set upon connecting', + callback = append_settings, + type = 'str', +) + +option_sslmode = make_option('--ssl-mode', + dest = 'sslmode', + help = 'SSL requirement for connectivity: require, prefer, allow, disable', + choices = ('require','prefer','allow','disable'), + type = 'choice', +) + +def append_db_client_x_parameters(option, opt_str, value, parser): + parser.values.db_client_parameters.append((option.dest, value)) +make_x_option = partial(make_option, callback = append_db_client_x_parameters) + +option_iri = make_x_option('-I', '--iri', + help = 'database locator string [pq://user:password@host:port/database?[driver_param]=value&setting=value]', + type = 'str', + dest = 'pq_iri', +) + +option_lib = optparse.make_option('-l', + help = 'bind the library found in postgresql.sys.libpath to the connection', + type = 'str', + dest = 'lib', + action = 'append' +) +option_libpath = optparse.make_option('-L', + help = 'append the library path', + type = 'str', + dest = 'libpath', + action = 'append' +) + +# PostgreSQL Standard Options +standard_optparse_options = ( + option_host, option_port, + option_user, option_password, + option_database, +) + +class StandardParser(optparse.OptionParser): + """ + Option parser limited to the basic -U, -h, -p, -W, and -D options. + This parser subclass is necessary for two reasons: + + 1. _add_help_option override to not conflict with -h + 2. Initialize the db_client_parameters on the parser's values. + + See the DefaultParser for more fun. + """ + standard_option_list = standard_optparse_options + + def get_default_values(self, *args, **kw): + v = super().get_default_values(*args, **kw) + v.db_client_parameters = [] + return v + + def _add_help_option(self): + # Only allow long --help so that it will not conflict with -h(host) + self.add_option("--help", + action = "help", + help = "show this help message and exit", + ) + +# Extended Options +default_optparse_options = [ + option_unix, + option_sslmode, + option_settings, +# Complex Options + option_iri, +] +default_optparse_options.extend(standard_optparse_options) + +class DefaultParser(StandardParser): + """ + Parser that includes a variety of connectivity options. 
+ (IRI, sslmode, settings) + """ + standard_option_list = default_optparse_options + +def resolve_password(parameters, getpass = getpass, prompt_title = ''): + """ + Given a parameters dictionary, resolve the 'password' key. + + If `prompt_password` is `True`. + If sys.stdin is a TTY, use `getpass` to prompt the user. + Otherwise, read a single line from sys.stdin. + delete 'prompt_password' from the dictionary. + + Otherwise. + If the 'password' key is `None`, attempt to resolve the password using the + 'pgpassfile' key. + + Finally, remove the pgpassfile key as the password has been resolved for the + given parameters. + + :param parameters: a fully normalized set of client parameters(dict) + """ + prompt_for_password = parameters.pop('prompt_password', False) + pgpassfile = parameters.pop('pgpassfile', None) + prompt_title = parameters.pop('prompt_title', None) + if prompt_for_password is True: + # it's a prompt + if sys.stdin.isatty(): + prompt = prompt_title or parameters.pop('prompt_title', '') + prompt += '[' + pg_iri.serialize(parameters, obscure_password = True) + ']' + parameters['password'] = getpass("Password for " + prompt +": ") + else: + # getpass will throw an exception if it's not a tty, + # so just take the next line. + pw = sys.stdin.readline() + # try to clean it up.. + if pw.endswith(os.linesep): + pw = pw[:len(pw)-len(os.linesep)] + parameters['password'] = pw + else: + if parameters.get('password') is None: + # No password? Look in the pgpassfile. + if pgpassfile is not None: + parameters['password'] = pg_pass.lookup_pgpass(parameters, pgpassfile) + # Don't need the pgpassfile parameter anymore as the password + # has been resolved. + +def x_settings(sdict, config): + d=dict(sdict) + for (k,v) in d.items(): + yield (('settings', k), v) + +def denormalize_parameters(p): + """ + Given a fully normalized parameters dictionary: + {'host': 'localhost', 'settings' : {'timezone':'utc'}} + + Denormalize it: + [(('host',), 'localhost'), (('settings','timezone'), 'utc')] + """ + for k,v in p.items(): + if k == 'settings': + for sk, sv in dict(v).items(): + yield (('settings', sk), sv) + else: + yield ((k,), v) + +def x_pq_iri(iri, config): + return denormalize_parameters(pg_iri.parse(iri)) + +# Lookup service data using the `service_name` +# Be sure to map 'dbname' to 'database'. +def x_pg_service(service_name, config): + service_files = [] + + f = config.get('pg_service_file') + if f is not None: + # service file override + service_files.append(f) + else: + # override is not specified, use the user service file + home = os.path.expanduser('~' + getuser()) + service_files.append(os.path.join(home, pg_service_user_filename)) + + # global service file is checked next. + sysconfdir = config.get('pg_sysconfdir') + if sysconfdir: + sf = config.get('pg_service_filename', pg_service_filename) + f = os.path.join(sysconfdir, sf) + # existence will be checked later. + service_files.append(f) + + for sf in service_files: + if not os.path.exists(sf): + continue + + cp = configparser.RawConfigParser() + cp.read(sf) + try: + s = cp.items(service_name) + except configparser.NoSectionError: + continue + + for (k, v) in s: + k = k.lower() + if k == 'ldap': + yield ('pg_ldap', ':'.join((k, v))) + elif k == 'pg_service': + # ignore + pass + elif k == 'hostaddr': + # XXX: should yield ipv as well? + yield (('host',), v) + elif k == 'dbname': + yield (('database',), v) + elif k not in pg_service_driver_parameters: + # it's a GUC. 
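+				# Keys that are not recognized driver parameters are forwarded
+				# into the 'settings' sub-dictionary as server settings.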
+ yield (('settings', k), v) + else: + yield ((k,), v) + else: + break + else: + # iterator exhausted; service not found + if sum([os.path.exists(x) for x in service_files]): + details = { + 'context': ', '.join(service_files), + } + else: + details = { + 'hint': "No service files could be found." + } + raise ServiceDoesNotExistError( + 'cannot find service named "{0}"'.format(service_name), + details = details + ) + +def x_pg_ldap(ldap_url, config): + raise Exception("cannot resolve ldap URLs") + +default_x_callbacks = { + 'settings' : x_settings, + 'pq_iri' : x_pq_iri, + 'pg_service' : x_pg_service, + 'pg_ldap' : x_pg_ldap, +} + +def extrapolate(iter, config = None, callbacks = default_x_callbacks): + """ + Given an iterable of standardized settings, + + [((path0, path1, ..., pathN), value)] + + Process any callbacks. + """ + config = config or {} + for item in iter: + k = item[0] + if isinstance(k, str): + if k.startswith('config-'): + config[k[len('config-'):]] = item[1] + else: + cb = callbacks.get(k) + if cb: + for x in extrapolate( + cb(item[1], config), + config = config, + callbacks = callbacks + ): + yield x + else: + pass + else: + yield item + +def normalize_parameter(kv): + """ + Translate a parameter into standard form. + """ + (k, v) = kv + if k[0] == 'requiressl' and v in ('1', True): + k[0] = 'sslmode' + v = 'require' + elif k[0] == 'dbname': + k[0] = 'database' + elif k[0] == 'sslmode': + v = v.lower() + return (tuple(k),v) + +def normalize(iter): + """ + Normally takes the output of `extrapolate` and makes a dictionary suitable + for applying to a connector. + """ + rd = {} + for (k, v) in iter: + sd = rd + for sk in k[:len(k)-1]: + sd = sd.setdefault(sk, {}) + sd[k[-1]] = v + return rd + +def resolve_pg_service_file( + environ = os.environ, + default_pg_sysconfdir = None, + default_pg_service_filename = pg_service_filename +): + sysconfdir = environ.get(pg_sysconfdir_envvar, default_pg_sysconfdir) + if sysconfdir: + return os.path.join(sysconfdir, default_pg_service_filename) + return None + +def collect( + parsed_options = None, + no_defaults = False, + environ = os.environ, + environ_prefix = 'PG', + default_pg_sysconfdir = None, + pg_service_file = None, + prompt_title = '', + parameters = (), +): + """ + Build a normalized client parameters dictionary for use with a connection + construction interface. 
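+
+	A minimal, illustrative sketch; the resulting values depend on the
+	environment (PG* variables, service files, and user defaults), and
+	``prompt_title = None`` skips password resolution::
+
+		>>> from py_opengauss import clientparameters
+		>>> p = clientparameters.collect(prompt_title = None)
+		>>> p['host'], p['port']
+		('localhost', 5432)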
+ + :param parsed_options: options parsed using the `DefaultParser` + :param no_defaults: Don't build-out defaults like 'user' from getpass.getuser() + :param environ: environment variables to use, `None` to disable + :param environ_prefix: prefix to use for collecting environment variables + :param default_pg_sysconfdir: default 'PGSYSCONFDIR' to use + :param pg_service_file: the pg-service file to actually use + :param prompt_title: additional title to use if a prompt request is made + :param parameters: base-client parameters to use(applied after defaults) + """ + d_parameters = [] + d_parameters.append([('config-environ', environ)]) + if default_pg_sysconfdir is not None: + d_parameters.append([ + ('config-pg_sysconfdir', default_pg_sysconfdir) + ]) + if pg_service_file is not None: + d_parameters.append([ + ('config-pg_service_file', pg_service_file) + ]) + + if not no_defaults: + d_parameters.append(defaults(environ = environ)) + + if parameters: + d_parameters.append(denormalize_parameters(dict(parameters))) + + if environ is not None: + d_parameters.append(envvars( + environ = environ, + modifier = environ_prefix.__add__ + )) + cop = getattr(parsed_options, 'db_client_parameters', None) + if cop: + d_parameters.append(cop) + + cpd = normalize(extrapolate(chain(*d_parameters))) + if prompt_title is not None: + resolve_password(cpd, prompt_title = prompt_title) + return cpd + +if __name__ == '__main__': + import pprint + p = DefaultParser( + description = "print the clientparams dictionary for the environment" + ) + (co, ca) = p.parse_args() + r = collect(parsed_options = co, prompt_title = 'custom_prompt_title') + pprint.pprint(r) diff --git a/py_opengauss/cluster.py b/py_opengauss/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..122103ebab24b7faf7f18d510f26903d0d6104bc --- /dev/null +++ b/py_opengauss/cluster.py @@ -0,0 +1,637 @@ +## +# .cluster - PostgreSQL cluster management +## +""" +Create, control, and destroy PostgreSQL clusters. + +postgresql.cluster provides a programmer's interface to controlling a PostgreSQL +cluster. It provides direct access to proper signalling interfaces. +""" +import sys +import os +import errno +import time +import subprocess as sp +from tempfile import NamedTemporaryFile + +from . import api as pg_api +from . import configfile +from . import installation as pg_inn +from . import exceptions as pg_exc +from . import driver as pg_driver +from .encodings.aliases import get_python_name +from .python.os import close_fds + +if sys.platform in ('win32', 'win64'): + from .port import signal1_msw as signal + pg_kill = signal.kill + def namedtemp(encoding): + return NamedTemporaryFile(delete = False, mode = 'w', encoding=encoding) +else: + import signal + pg_kill = os.kill + def namedtemp(encoding): + return NamedTemporaryFile(mode = 'w', encoding=encoding) + +class ClusterError(pg_exc.Error): + """ + General cluster error. + """ + code = '-C000' + source = 'CLUSTER' +class ClusterInitializationError(ClusterError): + """ + General cluster initialization failure. + """ + code = '-Cini' +class InitDBError(ClusterInitializationError): + """ + A non-zero result was returned by the initdb command. + """ + code = '-Cidb' +class ClusterStartupError(ClusterError): + """ + Cluster startup failed. + """ + code = '-Cbot' +class ClusterNotRunningError(ClusterError): + """ + Cluster is not running. + """ + code = '-Cdwn' +class ClusterTimeoutError(ClusterError): + """ + Cluster operation timed out. 
+ """ + code = '-Cout' + +class ClusterWarning(pg_exc.Warning): + """ + Warning issued by cluster operations. + """ + code = '-Cwrn' + source = 'CLUSTER' + +DEFAULT_CLUSTER_ENCODING = 'utf-8' +DEFAULT_CONFIG_FILENAME = 'postgresql.conf' +DEFAULT_HBA_FILENAME = 'pg_hba.conf' +DEFAULT_PID_FILENAME = 'postmaster.pid' + +initdb_option_map = { + 'encoding' : '-E', + 'authentication' : '-A', + 'user' : '-U', + # pwprompt is not supported. + # interactive use should be implemented by the application + # calling Cluster.init() +} + +class Cluster(pg_api.Cluster): + """ + Interface to a PostgreSQL cluster. + + Provides mechanisms to start, stop, restart, kill, drop, and initalize a + cluster(data directory). + + Cluster does not strive to be consistent with ``pg_ctl``. This is considered + to be a base class for managing a cluster, and is intended to be extended to + accommodate for a particular purpose. + """ + driver = pg_driver.default + installation = None + data_directory = None + DEFAULT_CLUSTER_ENCODING = DEFAULT_CLUSTER_ENCODING + DEFAULT_CONFIG_FILENAME = DEFAULT_CONFIG_FILENAME + DEFAULT_PID_FILENAME = DEFAULT_PID_FILENAME + DEFAULT_HBA_FILENAME = DEFAULT_HBA_FILENAME + + @property + def state(self): + if self.running(): + return 'running' + if not os.path.exists(self.data_directory): + return 'void' + return 'stopped' + + def _e_metas(self): + state = self.state + yield (None, '[' + state + ']') + if state == 'running': + yield ('pid', self.state) + + @property + def daemon_path(self): + """ + Path to the executable to use to startup the cluster. + """ + return self.installation.postmaster or self.installation.postgres + + def get_pid_from_file(self): + """ + The current pid from the postmaster.pid file. + """ + try: + path = os.path.join(self.data_directory, self.DEFAULT_PID_FILENAME) + with open(path) as f: + return int(f.readline()) + except IOError as e: + if e.errno in (errno.EIO, errno.ENOENT): + return None + + @property + def pid(self): + """ + If we have the subprocess, use the pid on the object. + """ + pid = self.get_pid_from_file() + if pid is None: + d = self.daemon_process + if d is not None: + return d.pid + return pid + + @property + def settings(self): + if not hasattr(self, '_settings'): + self._settings = configfile.ConfigFile(self.pgsql_dot_conf) + return self._settings + + @property + def hba_file(self, join = os.path.join): + """ + The path to the HBA file of the cluster. + """ + return self.settings.get( + 'hba_file', + join(self.data_directory, self.DEFAULT_HBA_FILENAME) + ) + + def __init__(self, installation, data_directory): + self.installation = installation + self.data_directory = os.path.abspath(data_directory) + self.pgsql_dot_conf = os.path.join( + self.data_directory, + self.DEFAULT_CONFIG_FILENAME + ) + self.daemon_process = None + self.daemon_command = None + + def __repr__(self, format = "{mod}.{name}({ins!r}, {dir!r})".format): + return format( + type(self).__module__, + type(self).__name__, + self.installation, + self.data_directory, + ) + + def __enter__(self): + """ + Start the cluster and wait for it to startup. + """ + self.start() + self.wait_until_started() + return self + + def __exit__(self, typ, val, tb): + """ + Stop the cluster and wait for it to shutdown. + """ + self.stop() + self.wait_until_stopped() + + def init(self, password = None, timeout = None, **kw): + """ + Create the cluster at the given `data_directory` using the + provided keyword parameters as options to the command. 
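+
+		For example, given an instance ``cluster`` (illustrative; the accepted
+		keywords are those in ``initdb_option_map`` plus ``logfile`` and
+		``extra_arguments``)::
+
+			>>> cluster.init(user = 'postgres', encoding = 'utf-8', password = 'secret')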
+ + `command_option_map` provides the mapping of keyword arguments + to command options. + """ + initdb = self.installation.initdb + if initdb is None: + initdb = (self.installation.pg_ctl, 'initdb',) + else: + initdb = (initdb,) + + if None in initdb: + raise ClusterInitializationError( + "unable to find executable for cluster initialization", + details = { + 'detail' : "The installation does not have 'initdb' or 'pg_ctl'.", + }, + creator = self + ) + # Transform keyword options into command options for the executable. + + # A default is used rather than looking at the environment to, well, + # avoid looking at the environment. + kw.setdefault('encoding', self.DEFAULT_CLUSTER_ENCODING) + opts = [] + for x in kw: + if x in ('logfile', 'extra_arguments'): + continue + if x not in initdb_option_map: + raise TypeError("got an unexpected keyword argument %r" %(x,)) + opts.append(initdb_option_map[x]) + opts.append(kw[x]) + logfile = kw.get('logfile') or sp.PIPE + extra_args = tuple([ + str(x) for x in kw.get('extra_arguments', ()) + ]) + + supw_file = () + supw_tmp = None + p = None + try: + if password is not None: + # got a superuserpass, store it in a tempfile for initdb + supw_tmp = namedtemp(encoding = get_python_name(kw['encoding'])) + supw_tmp.write(password) + supw_tmp.flush() + supw_file = ('--pwfile=' + supw_tmp.name,) + + cmd = initdb + ('-D', self.data_directory) \ + + tuple(opts) \ + + supw_file \ + + extra_args + + p = sp.Popen( + cmd, + close_fds = close_fds, + bufsize = 1024 * 5, # not expecting this to ever be filled. + stdin = None, + stdout = logfile, + # stderr is used to identify a reasonable error message. + stderr = sp.PIPE, + ) + + try: + stdout, stderr = p.communicate(timeout=timeout) + except sp.TimeoutExpired: + p.kill() + stdout, stderr = p.communicate() + finally: + rc = p.returncode + + if rc != 0: + # initdb returned non-zero, pickup stderr and attach to exception. + + r = stderr + try: + msg = r.decode('utf-8') + except UnicodeDecodeError: + # split up the lines, and use rep. + msg = os.linesep.join([ + repr(x)[2:-1] for x in r.splitlines() + ]) + + raise InitDBError( + "initdb exited with non-zero status", + details = { + 'command': cmd, + 'stderr': msg, + 'stdout': msg, + }, + creator = self + ) + finally: + if supw_tmp is not None: + n = supw_tmp.name + supw_tmp.close() + # XXX: win32 compensation. + if os.path.exists(n): + os.unlink(n) + + def drop(self): + """ + Stop the cluster and remove it from the filesystem + """ + if self.running(): + self.shutdown() + try: + self.wait_until_stopped() + except ClusterTimeoutError: + self.kill() + try: + self.wait_until_stopped() + except ClusterTimeoutError: + ClusterWarning( + 'cluster failed to shutdown after kill', + details = {'hint' : 'Shared memory may have been leaked.'}, + creator = self + ).emit() + # Really, using rm -rf would be the best, but use this for portability. + for root, dirs, files in os.walk(self.data_directory, topdown = False): + for name in files: + os.remove(os.path.join(root, name)) + for name in dirs: + os.rmdir(os.path.join(root, name)) + os.rmdir(self.data_directory) + + def start(self, logfile = None, settings = None): + """ + Start the cluster. 
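+
+		The daemon is spawned in the background and the call returns
+		immediately; use `wait_until_started` (or the context manager
+		protocol) to block until connections are accepted.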
+ """ + if self.running(): + return + cmd = (self.daemon_path, '-D', self.data_directory) + if settings is not None: + for k,v in dict(settings).items(): + cmd.append('--{k}={v}'.format(k=k,v=v)) + + p = sp.Popen( + cmd, + close_fds = close_fds, + bufsize = 1024, + # send everything to logfile + stdout = sp.PIPE if logfile is None else logfile, + stderr = sp.STDOUT, + stdin = sp.PIPE, + ) + if logfile is None: + p.stdout.close() + p.stdin.close() + self.daemon_process = p + self.daemon_command = cmd + + def restart(self, logfile = None, settings = None, timeout = 10): + """ + Restart the cluster gracefully. + + This provides a higher level interface to stopping then starting the + cluster. It will perform the wait operations and block until the + restart is complete. + + If waiting is not desired, .start() and .stop() should be used directly. + """ + if self.running(): + self.stop() + self.wait_until_stopped(timeout = timeout) + if self.running(): + raise ClusterError( + "failed to shutdown cluster", + creator = self + ) + self.start(logfile = logfile, settings = settings) + self.wait_until_started(timeout = timeout) + + def reload(self): + """ + Signal the cluster to reload its configuration file. + """ + pid = self.pid + if pid is not None: + try: + pg_kill(pid, signal.SIGHUP) + except OSError as e: + if e.errno != errno.ESRCH: + raise + + def stop(self): + """ + Stop the cluster gracefully waiting for clients to disconnect(SIGTERM). + """ + pid = self.pid + if pid is not None: + try: + pg_kill(pid, signal.SIGTERM) + except OSError as e: + if e.errno != errno.ESRCH: + raise + + def shutdown(self): + """ + Shutdown the cluster as soon as possible, disconnecting clients. + """ + pid = self.pid + if pid is not None: + try: + pg_kill(pid, signal.SIGINT) + except OSError as e: + if e.errno != errno.ESRCH: + raise + + def kill(self): + """ + Stop the cluster immediately(SIGKILL). + + Does *not* wait for shutdown. + """ + pid = self.pid + if pid is not None: + try: + pg_kill(pid, signal.SIGKILL) + except OSError as e: + if e.errno != errno.ESRCH: + raise + # already dead, so it would seem. + + def initialized(self): + """ + Whether or not the data directory *appears* to be a valid cluster. + """ + if os.path.isdir(self.data_directory) and \ + os.path.exists(self.pgsql_dot_conf) and \ + os.path.isdir(os.path.join(self.data_directory, 'base')): + return True + return False + + def running(self): + """ + Whether or not the postmaster is running. + + This does *not* mean the cluster is accepting connections. + """ + if self.daemon_process is not None: + r = self.daemon_process.poll() + if r is not None: + pid = self.get_pid_from_file() + if pid is not None: + # daemon process does not exist, but there's a pidfile. + self.daemon_process = None + return self.running() + return False + else: + return True + else: + pid = self.get_pid_from_file() + if pid is None: + return False + try: + pg_kill(pid, signal.SIG_DFL) + except OSError as e: + if e.errno != errno.ESRCH: + raise + return False + return True + + def connector(self, **kw): + """ + Create a postgresql.driver connector based on the given keywords and + listen_addresses and port configuration in settings. + """ + host, port = self.address() + return self.driver.fit( + host = host or 'localhost', + port = port or 5432, + **kw + ) + + def connection(self, **kw): + """ + Create a connection object to the cluster, but do not connect. 
+ """ + return self.connector(**kw)() + + def connect(self, **kw): + """ + Create an established connection from the connector. + + Cluster must be running. + """ + if not self.running(): + raise ClusterNotRunningError( + "cannot connect if cluster is not running", + creator = self + ) + x = self.connection(**kw) + x.connect() + return x + + def address(self): + """ + Get the host-port pair from the configuration. + """ + d = self.settings.getset(( + 'listen_addresses', 'port', + )) + if d.get('listen_addresses') is not None: + # Prefer localhost over other addresses. + # More likely to get a successful connection. + addrs = d.get('listen_addresses').lower().split(',') + if 'localhost' in addrs or '*' in addrs: + host = 'localhost' + elif '127.0.0.1' in addrs: + host = '127.0.0.1' + elif '::1' in addrs: + host = '::1' + else: + host = addrs[0] + else: + host = None + return (host, d.get('port')) + + def ready_for_connections(self): + """ + If the daemon is running, and is not in startup mode. + + This only works for clusters configured for TCP/IP connections. + """ + if not self.running(): + return False + e = None + host, port = self.address() + connection = self.driver.fit( + user = ' -*- ping -*- ', + host = host, port = port, + database = 'template1', + sslmode = 'disable', + )() + try: + connection.connect() + except pg_exc.ClientCannotConnectError as err: + for attempt in err.database.failures: + x = attempt.error + if self.installation.version_info[:2] < (8,1): + if isinstance(x, ( + pg_exc.UndefinedObjectError, + pg_exc.AuthenticationSpecificationError, + )): + # undefined user.. whatever... + return True + else: + if isinstance(x, pg_exc.AuthenticationSpecificationError): + return True + # configuration file error. ya, that's probably not going to change. + if isinstance(x, (pg_exc.CFError, pg_exc.ProtocolError)): + raise x + if isinstance(x, pg_exc.ServerNotReadyError): + e = x + break + else: + e = err + # the else true means we successfully connected with those + # credentials... strange, but true.. + return e if e is not None else True + + def wait_until_started(self, timeout = 10, delay = 0.05): + """ + After the `start` method is used, this can be ran in order to block + until the cluster is ready for use. + + This method loops until `ready_for_connections` returns `True` in + order to make sure that the cluster is actually up. + """ + start = time.time() + checkpoint = start + while True: + if not self.running(): + if self.daemon_process is not None: + r = self.daemon_process.returncode + if r is not None: + raise ClusterStartupError( + "postgres daemon terminated", + details = { + 'RESULT' : r, + 'COMMAND' : self.daemon_command, + }, + creator = self + ) + else: + raise ClusterNotRunningError( + "postgres daemon has not been started", + creator = self + ) + r = self.ready_for_connections() + + checkpoint = time.time() + if r is True: + break + + if checkpoint - start >= timeout: + # timeout was reached, but raise ServerNotReadyError + # to signal to the user that it was *not* due to some unknown + # condition, rather it's *still* starting up. + if r is not None and isinstance(r, pg_exc.ServerNotReadyError): + raise r + e = ClusterTimeoutError( + 'timeout on startup', + creator = self + ) + if r not in (True,False): + raise e from r + raise e + time.sleep(delay) + + def wait_until_stopped(self, timeout = 10, delay = 0.05): + """ + After the `stop` method is used, this can be ran in order to block until + the cluster is shutdown. 
+ + Additionally, catching `ClusterTimeoutError` exceptions would be a + starting point for making decisions about whether or not to issue a kill + to the daemon. + """ + start = time.time() + while self.running() is True: + # pickup the exit code. + if self.daemon_process is not None: + self.last_exit_code = self.daemon_process.poll() + else: + self.last_exit_code = pg_kill(self.get_pid_from_file(), 0) + if time.time() - start >= timeout: + raise ClusterTimeoutError( + 'timeout on shutdown', + creator = self, + ) + time.sleep(delay) diff --git a/py_opengauss/configfile.py b/py_opengauss/configfile.py new file mode 100644 index 0000000000000000000000000000000000000000..2f94f0e1a0476aac9396629ef23b7b17b13c90de --- /dev/null +++ b/py_opengauss/configfile.py @@ -0,0 +1,319 @@ +## +# .configfile +## +""" +PostgreSQL configuration file parser and editor functions. +""" +import sys +import os +from . import string as pg_str +from . import api as pg_api + +quote = "'" +comment = '#' + +def parse_line(line, equality = '=', comment = comment, quote = quote): + keyval = line.split(equality, 1) + if len(keyval) == 2: + key, val = keyval + + prekey_len = 0 + for c in key: + if not c.isspace() and c not in comment: + break + prekey_len += 1 + + key_len = 0 + for c in key[prekey_len:]: + if not (c.isalpha() or c.isdigit() or c in '_'): + break + key_len += 1 + + # If non-whitespace exists after the key, + # it's a complex comment, so just bail out. + if key[prekey_len + key_len:].strip(): + return + + preval_len = 0 + for c in val: + if not c.isspace() or c in '\n\r': + break + preval_len += 1 + + inquotes = False + escaped = False + val_len = 0 + for i in range(preval_len, len(val)): + c = val[i] + if c == quote: + if inquotes is False: + inquotes = True + else: + if escaped is False: + # Peek ahead to see if it's escaped with another quote + escaped = (len(val) > i+1 and val[i+1] == quote) + if escaped is False: + inquotes = False + elif escaped is True: + # It *was* an escaped quote. + escaped = False + elif inquotes is False and (c.isspace() or c in comment): + break + val_len += 1 + + return ( + # The key slice + slice(prekey_len, key_len + prekey_len), + # The value slice + slice(len(key) + 1 + preval_len, len(key) + 1 + preval_len + val_len) + ) + +def unquote(s, quote = quote): + """ + Unquote the string `s` if quoted. + """ + s = s.strip() + if not s.startswith(quote): + return s + return s[1:-1].replace(quote*2, quote) + +def write_config(map, writer, keys = None): + """ + A configuration writer that will trample & merely write the settings. + """ + if keys is None: + keys = map + for k in keys: + writer('='.join((k, map[k])) + os.linesep) + +def alter_config(map, fo, keys = None): + """ + Alters a configuration file without trampling on the existing structure. + """ + if keys is None: + keys = list(map.keys()) + # Normalize keys and map them back to + pkeys = { + k.lower().strip() : keys.index(k) for k in keys + } + + lines = [] + candidates = {} + i = -1 + # Process lines in fo + for l in fo: + i += 1 + lines.append(l) + pl = parse_line(l) + # "Bad" line? fuh-get-duh-bowt-it. + if pl is None: + continue + sk, sv = pl + k = l[sk].lower() + v = l[sv] + # It's a candidate? + if k in pkeys: + c = candidates.get(k) + if c is None: + candidates[k] = c = [] + c.append((i, sk, sv)) + # Simply insert the data somewhere for unfound keys. + for k in pkeys: + if k not in candidates: + key = keys[pkeys[k]] + val = map[key] + # Can't comment without an uncommented candidate. 
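+			# Keys that never appeared in the file are appended as new
+			# "key = 'value'" lines; None means "comment out", which is a
+			# no-op for a key that was never present.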
+ if val is not None: + if not lines[-1].endswith(os.linesep): + lines[-1] = lines[-1] + os.linesep + lines.append("%s = '%s'" %(key, val.replace("'", "''"))) + + # Multiple lines may have the key, so make a decision based on the value. + for ck in candidates.keys(): + to_set_key = keys[pkeys[ck]] + to_set_val = map[keys[pkeys[ck]]] + + if to_set_val is None: + # Comment uncommented occurrences. + for cl in candidates[ck]: + line_num, sk, sv = cl + if comment not in lines[line_num][:sk.start]: + lines[line_num] = '#' + lines[line_num] + else: + # Manage occurrences. + # w_ is for winner. + # Now, a winner is elected for alteration. The winner + # is decided based on a two factors: commenting and value. + w_score = -1 + w_commented = None + w_val = None + w_cl = None + for cl in candidates[ck]: + line_num, sk, sv = cl + l = lines[line_num] + lkey = l[sk] + lval = l[sv] + commented = (comment in l[:sk.start]) + score = \ + (not commented and 1 or 0) + \ + (unquote(lval) == to_set_val and 2 or 0) + # So, if a line is not commented, and has equal + # values, then that's the winner. If a line is commented, + # and has a has equal values, it will succeed over a mere + # uncommented value. + + if score > w_score: + if w_commented is False: + # It's now a loser, so comment it out if necessary. + lines[w_cl[0]] = '#' + lines[w_cl[0]] + w_score = score + w_commented = commented + w_val = lval + w_cl = cl + elif commented is False: + # Loser needs to be commented. + lines[line_num] = '#' + l + + line_num, sk, sv = w_cl + l = lines[line_num] + if w_commented: + bol = '' + else: + bol = l[:sk.start] + post_val = l[sv.stop:] + # If there is post-value data, validate that it's commented. + if post_val and not post_val.isspace(): + stripped_post_val = post_val.lstrip() + if not stripped_post_val.startswith(comment): + post_val = '%s%s%s' %( + # The whitespace before the uncommented visibles + post_val[0:len(post_val) - len(stripped_post_val)], + # A comment followed by the uncommented visibles + comment, stripped_post_val + ) + # Everything is set as quoted as it's the only safe + # way to set something without delving into setting types. + lines[line_num] = \ + bol + l[sk.start:sv.start] + \ + "'%s'" %(to_set_val.replace("'", "''"),) + post_val + return lines + +def read_config(iter, d = None, selector = None): + if d is None: + d = {} + for line in iter: + kv = parse_line(line) + if kv: + key = line[kv[0]] + if comment not in line[:kv[0].start] and \ + (selector is None or selector(key)): + d[key] = unquote(line[kv[1]]) + return d + +class ConfigFile(pg_api.Settings): + """ + Provides a mapping interface to a configuration file. + + Every operation will cause the file to be wholly read, so using `update` to make + multiple changes is desirable. 
+ """ + _e_factors = ('path',) + _e_label = 'CONFIGFILE' + + def _e_metas(self): + yield (None, len(self.keys())) + + def __init__(self, path, open = open): + self.path = path + self._open = open + self._store = [] + self._restore = {} + + def __repr__(self): + return "%s.%s(%r)" %( + type(self).__module__, + type(self).__name__, + self.path + ) + + def _save(self, lines : [str]): + with self._open(self.path, 'w') as cf: + for l in lines: + cf.write(l) + + def __delitem__(self, k): + with self._open(self.path) as cf: + lines = alter_config({k : None}, cf) + self._save() + + def __getitem__(self, k): + with self._open(self.path) as cfo: + return read_config( + cfo, + selector = k.__eq__ + )[k] + + def __setitem__(self, k, v): + self.update({k : v}) + + def __call__(self, **kw): + self._store.insert(0, kw) + + def __context__(self): + return self + + def __iter__(self): + return self.keys() + + def __len__(self): + return len(list(self.keys())) + + def __enter__(self): + res = self.getset(self._store[0].keys()) + self.update(self._store[0]) + del self._store[0] + self._restore.append(res) + + def __exit__(self, exc, val, tb): + self._restored.update(self._restore[-1]) + del self._restore[-1] + self.update(self._restored) + self._restored.clear() + return exc is None + + def get(self, k, alt = None): + with self._open(self.path) as cf: + return read_config(cf, selector = k.__eq__).get(k, alt) + + def keys(self): + return read_config(self._open(self.path)).keys() + + def values(self): + return read_config(self._open(self.path)).values() + + def items(self): + return read_config(self._open(self.path)).items() + + def update(self, keyvals): + """ + Given a dictionary of settings, apply them to the cluster's + postgresql.conf. + """ + with self._open(self.path) as cf: + lines = alter_config(keyvals, cf) + self._save(lines) + + def getset(self, keys): + """ + Get all the settings in the list of keys. + Returns a dictionary of those keys. + """ + keys = set(keys) + with self._open(self.path) as cfo: + cfg = read_config( + cfo, + selector = keys.__contains__ + ) + for x in (keys - set(cfg.keys())): + cfg[x] = None + return cfg diff --git a/py_opengauss/copyman.py b/py_opengauss/copyman.py new file mode 100644 index 0000000000000000000000000000000000000000..b0f7a6c0ecd152586cf29f0b964c28730afc93ca --- /dev/null +++ b/py_opengauss/copyman.py @@ -0,0 +1,841 @@ +## +# .copyman - COPY manager +## +""" +Manage complex COPY operations; one-to-many COPY streaming. + +Primarily this module houses the `CopyManager` class, and the `transfer` +function for a high-level interface to using the `CopyManager`. +""" +import sys +from abc import abstractmethod, abstractproperty +from collections import Iterator +from .python.element import Element, ElementSet +from .python.structlib import ulong_unpack, ulong_pack +from .protocol.buffer import pq_message_stream +from .protocol.element3 import CopyData, CopyDone, Complete, cat_messages +from .protocol.xact3 import Complete as xactComplete + +#: 10KB buffer for COPY messages by default. +default_buffer_size = 1024 * 10 + +class Fault(Exception): + pass + +class ProducerFault(Fault): + """ + Exception raised when the Producer caused an exception. + + Normally, Producer faults are fatal. + """ + def __init__(self, manager): + self.manager = manager + + def __str__(self): + return "producer raised exception" + +class ReceiverFault(Fault): + """ + Exception raised when Receivers cause an exception. 
+ + Faults should be trapped if recovery from an exception is + possible, or if the failed receiver is optional to the succes of the + operation. + + The 'manager' attribute is the CopyManager that raised the fault. + + The 'faults' attribute is a dictionary mapping the receiver to the exception + instance raised. + """ + def __init__(self, manager, faults): + self.manager = manager + self.faults = faults + + def __str__(self): + return "{0} faults occurred".format(len(self.faults)) + +class CopyFail(Exception): + """ + Exception thrown by the CopyManager when the COPY operation failed. + + The 'manager' attribute the CopyManager that raised the CopyFail. + + The 'reason' attribute is a string indicating why it failed. + + The 'receiver_faults' attribute is a mapping of receivers to exceptions that were + raised on exit. + + The 'producer_fault' attribute specifies if the producer raise an exception + on exit. + """ + def __init__(self, manager, reason = None, + receiver_faults = None, + producer_fault = None, + ): + self.manager = manager + self.reason = reason + self.receiver_faults = receiver_faults or {} + self.producer_fault = producer_fault + + def __str__(self): + return self.reason or 'copy aborted' + +# The identifier for PQv3 copy data. +PROTOCOL_PQv3 = "PQv3" +# The identifier for iterables of copy data sequences. +# iter([[row1, row2], [row3, row4]]) +PROTOCOL_CHUNKS = "CHUNKS" +# The protocol identifier for NULL producers and receivers. +PROTOCOL_NULL = None + +class ChunkProtocol(object): + __slots__ = ('buffer',) + def __init__(self): + self.buffer = pq_message_stream() + + def __call__(self, data): + self.buffer.write(bytes(data)) + return [x[1] for x in self.buffer.read()] + +# Null protocol mapping. +def EmptyView(arg): + return memoryview(b'') +def EmptyList(arg): + return [] +def ReturnNone(arg): + return None +# Zero-Transformation +def NoTransformation(arg): + return arg + +# Copy protocols being at the Python level; *not* wire/serialization format. +copy_protocol_mappings = { + # PQv3 -> Chunks + (PROTOCOL_PQv3, PROTOCOL_CHUNKS) : ChunkProtocol, + # Chunks -> PQv3 + (PROTOCOL_CHUNKS, PROTOCOL_PQv3) : lambda: cat_messages, + # Null Producers and Receivers + (PROTOCOL_NULL, PROTOCOL_PQv3) : lambda: EmptyView, + (PROTOCOL_NULL, PROTOCOL_CHUNKS) : lambda: EmptyList, + (PROTOCOL_PQv3, PROTOCOL_NULL) : lambda: ReturnNone, + (PROTOCOL_CHUNKS, PROTOCOL_NULL) : lambda: ReturnNone, + # Zero Transformations + (PROTOCOL_NULL, PROTOCOL_NULL) : lambda: NoTransformation, + (PROTOCOL_CHUNKS, PROTOCOL_CHUNKS) : lambda: NoTransformation, + (PROTOCOL_PQv3, PROTOCOL_PQv3) : lambda: NoTransformation, +} + +# Used to manage the conversions of COPY data. +# Notably, chunks -> PQv3 or PQv3 -> chunks. +class CopyTransformer(object): + __slots__ = ('current', 'transformers', 'get') + def __init__(self, source_protocol, target_protocols): + self.current = {} + self.transformers = { + x : copy_protocol_mappings[(source_protocol, x)]() + for x in set(target_protocols) + } + self.get = self.current.__getitem__ + + def __call__(self, data): + for protocol, transformer in self.transformers.items(): + self.current[protocol] = transformer(data) + +## +# This is the object that does the magic. +# It tracks the state of the wire. +# It ends when non-COPY data is found. +class WireState(object): + """ + Manages the state of the wire. + + This class manages three possible positions: + + 1. Between wire messages + 2. Inside message header + 3. 
Inside message (with complete header) + + The wire state will become unusable when the configured condition is True. + """ + __slots__ = ('remaining_bytes', 'size_fragment', 'final_view', 'condition',) + + def update(self, view, getlen = ulong_unpack, len = len): + """ + Given the state of the COPY and new data, advance the position on the + COPY stream. + """ + # Only usable until the terminating condition. + if self.final_view is not None: + raise RuntimeError("wire state encountered exceptional condition") + + nmessages = 0 + + # State carried over from prior run. + remaining_bytes = self.remaining_bytes + size_fragment = self.size_fragment + + # Terminating condition. + CONDITION = self.condition + + # Is it a continuation of a message header? + if remaining_bytes == -1: + ## + # Inside message header; after message type. + # Continue adding to the 'size_fragment' + # until there are four bytes to unpack. + ## + o = len(size_fragment) + size_fragment += bytes(view[:4-o]) + if len(size_fragment) == 4: + # The size fragment is completed; only part + # of the fragment remains to be consumed. + remaining_bytes = getlen(size_fragment) - o + size_fragment = b'' + else: + assert len(size_fragment) < 4 + # size_fragment got updated.. + + if remaining_bytes >= 0: + vlen = len(view) + while True: + if remaining_bytes: + ## + # Inside message body. Message length has been unpacked. + ## + view = view[remaining_bytes:] + # How much is remaining now? + rb = remaining_bytes - vlen + if rb <= 0: + # Finished it. + vlen = -rb + remaining_bytes = 0 + nmessages += 1 + else: + vlen = 0 + remaining_bytes = rb + ## + # In between protocol messages. + ## + if not view: + # no more data to analyze + break + # There is at least one byte in the view. + if CONDITION(view[0]): + # State is dead now. + # User needs to handle unexpected message, then continue. + self.final_view = view + assert remaining_bytes == 0 + break + if vlen < 5: + # Header continuation. + remaining_bytes = -1 + view = view[1:] + size_fragment += bytes(view) + # Not enough left for the header of the next message? + break + # Update remaining_bytes to include the header, and start over. + remaining_bytes = getlen(view[1:5]) + 1 + + # Update the state for the next update. + self.remaining_bytes, self.size_fragment = ( + remaining_bytes, size_fragment, + ) + # Emit the number of messages "consumed" this round. + return nmessages + + def __init__(self, condition = (CopyData.type[0].__ne__ if isinstance(memoryview(b'f')[0], int) else CopyData.type.__ne__)): + self.remaining_bytes = 0 + self.size_fragment = b'' + self.final_view = None + self.condition = condition + +class Fitting(Element): + _e_label = 'FITTING' + + def _e_metas(self): + yield None, '[' + self.state + ']' + + @abstractproperty + def protocol(self): + """ + The COPY data format produced or consumed. + """ + + # Used to setup the Receiver/Producer + def __enter__(self): + pass + + # Used to tear down the Receiver/Producer + def __exit__(self, typ, val, tb): + pass + +class Producer(Fitting, Iterator): + _e_label = 'PRODUCER' + + def _e_metas(self): + for x in super()._e_metas(): + yield x + yield 'data', str(self.total_bytes / (1024**2)) + 'MB' + yield 'messages', self.total_messages + yield 'average size', (self.total_bytes / self.total_messages) + + def __init__(self): + self.total_messages = 0 + self.total_bytes = 0 + + @abstractmethod + def realign(self): + """ + Method implemented by producers that emit COPY data that is not + guaranteed to be aligned. 
+ + This is only necessary in failure cases where receivers still need more + data to complete the message. + """ + + @abstractmethod + def __next__(self): + """ + Produce the next set of data. + """ + +class Receiver(Fitting): + _e_label = 'RECEIVER' + + @abstractmethod + def transmit(self): + """ + Finish the reception of the accepted data. + """ + + @abstractmethod + def accept(self, data): + """ + Take the data object to be processed. + """ + +class NullProducer(Producer): + """ + Produces no copy data. + """ + _e_factors = () + protocol = PROTOCOL_NULL + + def realign(self): + # Never needs to realigned. + pass + + def __next__(self): + raise StopIteration + +class IteratorProducer(Producer): + _e_factors = ('iterator',) + protocol = PROTOCOL_CHUNKS + + def __init__(self, iterator): + self.iterator = iter(iterator) + self.__next__ = self.iterator.__next__ + super().__init__() + + def realign(self): + # Never needs to realign; data is emitted on message boundaries. + pass + + def __next__(self, next = next): + n = next(self.iterator) + self.total_messages += len(n) + self.total_bytes += sum(map(len, n)) + return n + +class ProtocolProducer(Producer): + """ + Producer using a PQv3 data stream. + + Normally, this class needs to be subclassed as it assumes that the given + recv_into function will write COPY messages. + """ + protocol = PROTOCOL_PQv3 + + @abstractmethod + def recover(self, view): + """ + Given a view containing data read from the wire, recover the + controller's state. + + This needs to be implemented by subclasses in order for the + ProtocolReceiver to pass control back to the original state machine. + """ + + ## + # When a COPY is interrupted, this can be used to accommodate + # the original state machine to identify the message boundaries. + def realign(self): + s = self._state + + if s is None: + # It's already aligned. + self.nextchunk = iter(()).__next__ + return + + if s.final_view: + # It was at the end or non-COPY. + for_producer = bytes(s.final_view) + for_receivers = b'' + elif s.remaining_bytes == -1: + # In the middle of a message header. + for_producer = CopyData.type + s.size_fragment + # receivers: + header = (self._state.size_fragment.ljust(3, b'\x00') + b'\x04') + # Don't include the already sent parts. + buf = header[len(self._state.size_fragment):] + bodylen = ulong_unpack(header) - 4 + # This will often cause an invalid copy data error, + # but it doesn't matter much because we will issue a copy fail. + buf += b'\x00' * bodylen + for_receivers = buf + elif s.remaining_bytes > 0: + # In the middle of a message. + for_producer = CopyData.type + ulong_pack(s.remaining_bytes + 4) + for_receivers = b'\x00' * self._state.remaining_bytes + else: + for_producer = for_receivers = b'' + + self.recover(for_producer) + if for_receivers: + self.nextchunk = iter((for_receivers,)).__next__ + else: + self.nextchunk = iter(()).__next__ + + def process_copy_data(self, view): + self.total_messages += self._state.update(view) + if self._state.final_view is not None: + # It's not COPY data. + fv = self._state.final_view + # Only publish up to the final_view. + if fv: + view = view[:-len(fv)] + # The next next() will handle the async, error, or completion. + self.recover(fv) + self._state = None + self.total_bytes += len(view) + return view + + # Given a view, begin tracking the state of the wire. + def track_state(self, view): + self._state = WireState() + self.nextchunk = self.recv_view + return self.process_copy_data(view) + + # The usual method for receiving more data. 
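+	# Fills the reusable buffer via recv_into() and runs the result through
+	# the wire-state tracker, returning only the COPY data portion.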
+ def recv_view(self): + view = self.buffer_view[:self.recv_into(self.buffer, self.buffer_size)] + if not view: + # Zero read; let the subclass handle the situation. + self.recover(memoryview(b'')) + return self.nextchunk() + view = self.process_copy_data(view) + return view + + def nextchunk(self): + raise RuntimeError("producer not properly initialized") + + def __next__(self): + return self.nextchunk() + + def __init__(self, + recv_into, + buffer_size = default_buffer_size + ): + super().__init__() + self.recv_into = recv_into + self.buffer_size = buffer_size + self.buffer = bytearray(buffer_size) + self.buffer_view = memoryview(self.buffer) + self._state = None + +class StatementProducer(ProtocolProducer): + _e_factors = ('statement', 'parameters',) + + def _e_metas(self): + for x in super()._e_metas(): + yield x + + @property + def state(self): + if self._chunks is None: + return 'created' + return 'producing' + + def count(self): + return self._chunks.count() + + def command(self): + return self._chunks.command() + + def __init__(self, statement, *args, **kw): + super().__init__(statement.database.pq.socket.recv_into, **kw) + self.statement = statement + self.parameters = args + self._chunks = None + + ## + # Take any data held by the statement's chunks and connection. + def confiscate(self, next = next): + current = [] + try: + while not current: + current.extend(next(self._chunks)) + except StopIteration: + if not current: + # End of COPY. + raise + pq = self._chunks.database.pq + buffer = cat_messages(current) + pq.message_buffer.getvalue() + (pq.read_data or b'') + view = memoryview(buffer) + pq.read_data = None + pq.message_buffer.truncate() + # Reconstruct the buffer from the already parsed lines. + r = self.track_state(view) + # XXX: Better way? Probably shouldn't do the full track_state if complete.. + if self._chunks._xact.state is xactComplete: + # It's over, don't hand off to recv_view. + self.nextchunk = self.confiscate + assert self._state.final_view is None + return r + + def recover(self, view): + # Method used when non-COPY data is found. + self._chunks.database.pq.message_buffer.write(bytes(view)) + self.nextchunk = self.confiscate + + def __enter__(self): + super().__enter__() + if self._chunks is not None: + raise RuntimeError("receiver already used") + self._chunks = self.statement.chunks(*self.parameters) + # Start by confiscating the connection state. + self.nextchunk = self.confiscate + + def __exit__(self, typ, val, tb): + if typ is None or issubclass(typ, Exception): + db = self.statement.database + if not db.closed and self._chunks._xact is not None: + # The COPY transaction is still happening, + # force an interrupt if the connection still exists. + db.interrupt() + if db.pq.xact: + # Raise, CopyManager should trap. + db._pq_complete() + super().__exit__(typ, val, tb) + +class NullReceiver(Receiver): + _e_factors = () + protocol = PROTOCOL_NULL + state = 'null' + + def transmit(self): + # Nothing to do. 
+ pass + + def accept(self, data): + pass + +class ProtocolReceiver(Receiver): + protocol = PROTOCOL_PQv3 + __slots__ = ('send', 'view') + + def __init__(self, send): + super().__init__() + self.send = send + self.view = memoryview(b'') + + def accept(self, data): + self.view = data + + def transmit(self): + while self.view: + self.view = self.view[self.send(self.view):] + + def __enter__(self): + return self + + def __exit__(self, typ, val, tb): + pass + +class StatementReceiver(ProtocolReceiver): + _e_factors = ('statement', 'parameters',) + __slots__ = ProtocolReceiver.__slots__ + _e_factors + ('xact',) + + def _e_metas(self): + yield None, '[' + self.state + ']' + + def __init__(self, statement, *parameters): + self.statement = statement + self.parameters = parameters + self.xact = None + super().__init__(statement.database.pq.socket.send,) + + # XXX: A bit of a hack... + # This is actually a good indication that statements need a .copy() + # execution method for producing a "CopyCursor" that reads or writes. + class WireReady(BaseException): + pass + def raise_wire_ready(self): + raise self.WireReady() + yield None + + def __enter__(self, iter = iter): + super().__enter__() + # Get the connection in the COPY state. + try: + self.statement.load_chunks( + iter(self.raise_wire_ready()), *self.parameters + ) + except self.WireReady: + # It's a BaseException; nothing should trap it. + # Note the transaction object; we'll use it on exit. + self.xact = self.statement.database.pq.xact + + def __exit__(self, typ, val, tb): + if self.xact is None: + # Nothing to do. + return super().__exit__(typ, val, tb) + + if self.view: + # The realigned producer emitted the necessary + # data for message boundary alignment. + # + # In this case, we unconditionally fail. + pq = self.statement.database.pq + # There shouldn't be any message_data, atm. + pq.message_data = bytes(self.view) + self.statement.database._pq_complete() + # It is possible for a non-alignment view to exist in cases of + # faults. However, exit should *not* be called in those cases. + ## + elif typ is None: + # Success? + self.xact.messages = self.xact.CopyDoneSequence + # If not, this will blow up. + self.statement.database._pq_complete() + # Find the complete message for command and count. + for x in self.xact.messages_received(): + if getattr(x, 'type', None) == Complete.type: + self._complete_message = x + elif issubclass(typ, Exception): + # Likely raises. CopyManager should trap. + self.statement.database._pq_complete() + + return super().__exit__(typ, val, tb) + + def count(self): + return self._complete_message.extract_count() + + def command(self): + return self._complete_message.extract_command().decode('ascii') + +class CallReceiver(Receiver): + """ + Call the given object with a list of COPY lines. + """ + _e_factors = ('callable',) + protocol = PROTOCOL_CHUNKS + + def __init__(self, callable): + self.callable = callable + self.lines = None + super().__init__() + + def transmit(self): + if self.lines is not None: + self.callable(self.lines) + self.lines = None + + def accept(self, lines): + self.lines = lines + +class CopyManager(Element, Iterator): + """ + A class for managing COPY operations. + + Connects the producer to the receivers. 
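+
+	A sketch of direct use, assuming two established connections ``src`` and
+	``dst`` (the module-level `transfer` function wraps this pattern)::
+
+		>>> sp = StatementProducer(src.prepare("COPY t TO STDOUT"))
+		>>> sr = StatementReceiver(dst.prepare("COPY t FROM STDIN"))
+		>>> CopyManager(sp, sr).run()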
+ """ + _e_label = 'COPY' + _e_factors = ('producer', 'receivers',) + + def _e_metas(self): + yield None, '[' + self.state + ']' + + @property + def state(self): + if self.transformer is None: + return 'initialized' + return str(self.producer.total_messages) + ' messages transferred' + + def __init__(self, producer, *receivers): + self.producer = producer + self.transformer = None + self.receivers = ElementSet(receivers) + self._seen_stop_iteration = False + rp = set() + add = rp.add + for x in self.receivers: + add(x.protocol) + self.protocols = rp + + def __enter__(self): + if self.transformer: + raise RuntimeError("copy already started") + self._stats = (0, 0) + self.transformer = CopyTransformer(self.producer.protocol, self.protocols) + self.producer.__enter__() + try: + for x in self.receivers: + x.__enter__() + except Exception: + self.__exit__(*sys.exc_info()) + return self + + def __exit__(self, typ, val, tb): + ## + # Exiting the CopyManager is a fairly complex operation. + # + # In cases of failure, re-alignment may need to happen + # for when the receivers are not on a message boundary. + ## + if typ is not None and not issubclass(typ, Exception): + # Don't bother, it's an interrupt or sufficient resources. + return + + profail = None + try: + # Does nothing if the COPY was successful. + self.producer.realign() + try: + ## + # If the producer is not aligned to a message boundary, + # it can emit completion data that will put the receivers + # back on track. + # This last service call will move that data onto the receivers. + self._service_producer() + ## + # The receivers need to handle any new data in their __exit__. + except StopIteration: + # No re-alignment needed. + pass + + self.producer.__exit__(typ, val, tb) + except Exception as x: + # reference profail later. + profail = x + + # No receivers? It wasn't a success. + if not self.receivers: + raise CopyFail(self, "no receivers", producer_fault = profail) + + exit_faults = {} + for x in self.receivers: + try: + x.__exit__(typ, val, tb) + except Exception as e: + exit_faults[x] = e + + if typ or exit_faults or profail or not self._seen_stop_iteration: + raise CopyFail(self, + "could not complete the COPY operation", + receiver_faults = exit_faults or None, + producer_fault = profail + ) + + def reconcile(self, r): + """ + Reconcile a receiver that faulted. + + This method should be used to add back a receiver that failed to + complete its write operation, but is capable of completing the + operation at this time. + """ + if r.protocol not in self.protocols: + raise RuntimeError("cannot add new receivers to copy operations") + r.transmit() + # Okay, add it back. + self.receivers.add(r) + + def _service_producer(self): + # Setup current data. + if not self.receivers: + # No receivers to take the data. + raise StopIteration + + try: + nextdata = next(self.producer) + except StopIteration: + # Should be over. + self._seen_stop_iteration = True + raise + except Exception: + raise ProducerFault(self) + + self.transformer(nextdata) + + # Distribute data to receivers. + for x in self.receivers: + x.accept(self.transformer.get(x.protocol)) + + def _service_receivers(self): + faults = {} + for x in self.receivers: + # Process all the receivers. + try: + x.transmit() + except Exception as e: + faults[x] = e + if faults: + # The CopyManager is eager to continue the operation. + for x in faults: + self.receivers.discard(x) + raise ReceiverFault(self, faults) + + # Run the COPY to completion. 
+ def run(self): + with self: + try: + while True: + self._service_producer() + self._service_receivers() + except StopIteration: + # It's done. + pass + + def __iter__(self): + return self + + def __next__(self): + messages = self.producer.total_messages + bytes = self.producer.total_bytes + + self._service_producer() + # Record the progress in case a receiver faults. + self._stats = ( + self._stats[0] + (self.producer.total_messages - messages), + self._stats[1] + (self.producer.total_bytes - bytes), + ) + self._service_receivers() + # Return the progress. + current_stats = self._stats + self._stats = (0, 0) + return current_stats + +def transfer(producer, *receivers): + """ + Perform a COPY operation using the given statements:: + + >>> import copyman + >>> copyman.transfer(src.prepare("COPY table TO STDOUT"), dst.prepare("COPY table FROM STDIN")) + """ + cm = CopyManager( + StatementProducer(producer), + *[x if isinstance(x, Receiver) else StatementReceiver(x) for x in receivers] + ) + cm.run() + return (cm.producer.total_messages, cm.producer.total_bytes) diff --git a/py_opengauss/documentation/__init__.py b/py_opengauss/documentation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f2c7b839cad22de5f9104031572b0e438ab8c0a --- /dev/null +++ b/py_opengauss/documentation/__init__.py @@ -0,0 +1,7 @@ +## +# .documentation +## +r""" +See: `postgresql.documentation.index` +""" +__docformat__ = 'reStructuredText' diff --git a/py_opengauss/documentation/admin.rst b/py_opengauss/documentation/admin.rst new file mode 100644 index 0000000000000000000000000000000000000000..092600d54975feb4dfe5f464deeaa861c652951a --- /dev/null +++ b/py_opengauss/documentation/admin.rst @@ -0,0 +1,33 @@ +Administration +============== + +This chapter covers the administration of py-postgresql. This includes +installation and other aspects of working with py-postgresql such as +environment variables and configuration files. + +Installation +------------ + +py-postgresql uses Python's distutils package to manage the build and +installation process of the package. The normal entry point for +this is the ``setup.py`` script contained in the root project directory. + +After extracting the archive and changing the into the project's directory, +installation is normally as simple as:: + + $ python3 ./setup.py install + +However, if you need to install for use with a particular version of python, +just use the path of the executable that should be used:: + + $ /usr/opt/bin/python3 ./setup.py install + + +Environment +----------- + +These environment variables effect the operation of the package: + + ============== =============================================================================== + PGINSTALLATION The path to the ``pg_config`` executable of the installation to use by default. + ============== =============================================================================== diff --git a/py_opengauss/documentation/alock.rst b/py_opengauss/documentation/alock.rst new file mode 100644 index 0000000000000000000000000000000000000000..6767767422ada306f816e7cbe456c5b1b7aaa449 --- /dev/null +++ b/py_opengauss/documentation/alock.rst @@ -0,0 +1,108 @@ +.. _alock: + +************** +Advisory Locks +************** + +.. warning:: `postgresql.alock` is a new feature in v1.0. + +`Explicit Locking in PostgreSQL `_. + +PostgreSQL's advisory locks offer a cooperative synchronization primitive. 
+These are used in cases where an application needs access to a resource, but +using table locks may cause interference with other operations that can be +safely performed alongside the application-level, exclusive operation. + +Advisory locks can be used by directly executing the stored procedures in the +database or by using the :class:`postgresql.alock.ALock` subclasses, which +provides a context manager that uses those stored procedures. + +Currently, only two subclasses exist. Each represents the lock mode +supported by PostgreSQL's advisory locks: + + * :class:`postgresql.alock.ShareLock` + * :class:`postgresql.alock.ExclusiveLock` + + +Acquiring ALocks +================ + +An ALock instance represents a sequence of advisory locks. A single ALock can +acquire and release multiple advisory locks by creating the instance with +multiple lock identifiers:: + + >>> from postgresql import alock + >>> table1_oid = 192842 + >>> table2_oid = 192849 + >>> l = alock.ExclusiveLock(db, (table1_oid, 0), (table2_oid, 0)) + >>> l.acquire() + >>> ... + >>> l.release() + +:class:`postgresql.alock.ALock` is similar to :class:`threading.RLock`; in +order for an ALock to be released, it must be released the number of times it +has been acquired. ALocks are associated with and survived by their session. +Much like how RLocks are associated with the thread they are acquired in: +acquiring an ALock again will merely increment its count. + +PostgreSQL allows advisory locks to be identified using a pair of `int4` or a +single `int8`. ALock instances represent a *sequence* of those identifiers:: + + >>> from postgresql import alock + >>> ids = [(0,0), 0, 1] + >>> with alock.ShareLock(db, *ids): + ... ... + +Both types of identifiers may be used within the same ALock, and, regardless of +their type, will be aquired in the order that they were given to the class' +constructor. In the above example, ``(0,0)`` is acquired first, then ``0``, and +lastly ``1``. + + +ALocks +====== + +`postgresql.alock.ALock` is abstract; it defines the interface and some common +functionality. The lock mode is selected by choosing the appropriate subclass. + +There are two: + + ``postgresql.alock.ExclusiveLock(database, *identifiers)`` + Instantiate an ALock object representing the `identifiers` for use with the + `database`. Exclusive locks will conflict with other exclusive locks and share + locks. + + ``postgresql.alock.ShareLock(database, *identifiers)`` + Instantiate an ALock object representing the `identifiers` for use with the + `database`. Share locks can be acquired when a share lock with the same + identifier has been acquired by another backend. However, an exclusive lock + with the same identifier will conflict. + + +ALock Interface Points +---------------------- + +Methods and properties available on :class:`postgresql.alock.ALock` instances: + + ``alock.acquire(blocking = True)`` + Acquire the advisory locks represented by the ``alock`` object. If blocking is + `True`, the default, the method will block until locks on *all* the + identifiers have been acquired. + + If blocking is `False`, acquisition may not block, and success will be + indicated by the returned object: `True` if *all* lock identifiers were + acquired and `False` if any of the lock identifiers could not be acquired. + + ``alock.release()`` + Release the advisory locks represented by the ``alock`` object. If the lock + has not been acquired, a `RuntimeError` will be raised. 
+ + ``alock.locked()`` + Returns a boolean describing whether the locks are held or not. This will + return `False` if the lock connection has been closed. + + ``alock.__enter__()`` + Alias to ``acquire``; context manager protocol. Always blocking. + + ``alock.__exit__(typ, val, tb)`` + Alias to ``release``; context manager protocol. diff --git a/py_opengauss/documentation/bin.rst b/py_opengauss/documentation/bin.rst new file mode 100644 index 0000000000000000000000000000000000000000..43e7a76e9c2488ac57f1730965e3fab5e90c0e73 --- /dev/null +++ b/py_opengauss/documentation/bin.rst @@ -0,0 +1,170 @@ +Commands +******** + +This chapter discusses the usage of the available console scripts. + + +postgresql.bin.pg_python +======================== + +The ``pg_python`` command provides a simple way to write Python scripts against a +single target database. It acts like the regular Python console command, but +takes standard PostgreSQL options as well to specify the client parameters +to make establish connection with. The Python environment is then augmented +with the following built-ins: + + ``db`` + The PG-API connection object. + + ``xact`` + ``db.xact``, the transaction creator. + + ``settings`` + ``db.settings`` + + ``prepare`` + ``db.prepare``, the statement creator. + + ``proc`` + ``db.proc`` + + ``do`` + ``db.do``, execute a single DO statement. + + ``sqlexec`` + ``db.execute``, execute multiple SQL statements (``None`` is always returned) + +pg_python Usage +--------------- + +Usage: postgresql.bin.pg_python [connection options] [script] ... + +Options: + --unix=UNIX path to filesystem socket + --ssl-mode=SSLMODE SSL requirement for connectivity: require, prefer, + allow, disable + -s SETTINGS, --setting=SETTINGS + run-time parameters to set upon connecting + -I PQ_IRI, --iri=PQ_IRI + database locator string + [pq://user:password@host:port/database?setting=value] + -h HOST, --host=HOST database server host + -p PORT, --port=PORT database server port + -U USER, --username=USER + user name to connect as + -W, --password prompt for password + -d DATABASE, --database=DATABASE + database's name + --pq-trace=PQ_TRACE trace PQ protocol transmissions + -C PYTHON_CONTEXT, --context=PYTHON_CONTEXT + Python context code to run[file://,module:,] + -m PYTHON_MAIN Python module to run as script(__main__) + -c PYTHON_MAIN Python expression to run(__main__) + --version show program's version number and exit + --help show this help message and exit + + +Interactive Console Backslash Commands +-------------------------------------- + +Inspired by ``psql``:: + + >>> \? + Backslash Commands: + + \? Show this help message. + \E Edit a file or a temporary script. + \e Edit and Execute the file directly in the context. + \i Execute a Python script within the interpreter's context. + \set Configure environment variables. \set without arguments to show all + \x Execute the Python command within this process. 
+ + +pg_python Examples +------------------ + +Module execution taking advantage of the new built-ins:: + + $ python3 -m postgresql.bin.pg_python -h localhost -W -m timeit "prepare('SELECT 1').first()" + Password for pg_python[pq://dbusername@localhost:5432]: + 1000 loops, best of 3: 1.35 msec per loop + + $ python3 -m postgresql.bin.pg_python -h localhost -W -m timeit -s "ps=prepare('SELECT 1')" "ps.first()" + Password for pg_python[pq://dbusername@localhost:5432]: + 1000 loops, best of 3: 442 usec per loop + +Simple interactive usage:: + + $ python3 -m postgresql.bin.pg_python -h localhost -W + Password for pg_python[pq://dbusername@localhost:5432]: + >>> ps = prepare('select 1') + >>> ps.first() + 1 + >>> c = ps() + >>> c.read() + [(1,)] + >>> ps.close() + >>> import sys + >>> sys.exit(0) + + +postgresql.bin.pg_dotconf +========================= + +pg_dotconf is used to modify a PostgreSQL cluster's configuration file. +It provides a means to apply settings specified from the command line and from a +file referenced using the ``-f`` option. + +.. warning:: + ``include`` directives in configuration files are *completely* ignored. If + modification of an included file is desired, the command must be applied to + that specific file. + + +pg_dotconf Usage +---------------- + +Usage: postgresql.bin.pg_dotconf [--stdout] [-f filepath] postgresql.conf ([param=val]|[param])* + +Options: + --version show program's version number and exit + -h, --help show this help message and exit + -f SETTINGS, --file=SETTINGS + A file of settings to *apply* to the given + "postgresql.conf" + --stdout Redirect the product to standard output instead of + writing back to the "postgresql.conf" file + + +Examples +-------- + +Modifying a simple configuration file:: + + $ echo "setting = value" >pg.conf + + # change 'setting' + $ python3 -m postgresql.bin.pg_dotconf pg.conf setting=newvalue + + $ cat pg.conf + setting = 'newvalue' + + # new settings are appended to the file + $ python3 -m postgresql.bin.pg_dotconf pg.conf another_setting=value + $ cat pg.conf + setting = 'newvalue' + another_setting = 'value' + + # comment a setting + $ python3 -m postgresql.bin.pg_dotconf pg.conf another_setting + + $ cat pg.conf + setting = 'newvalue' + #another_setting = 'value' + +When a setting is given on the command line, it must been seen as one argument +to the command, so it's *very* important to avoid invocations like:: + + $ python3 -m postgresql.bin.pg_dotconf pg.conf setting = value + ERROR: invalid setting, '=' after 'setting' + HINT: Settings must take the form 'setting=value' or 'setting_name_to_comment'. Settings must also be received as a single argument. diff --git a/py_opengauss/documentation/changes-v1.0.rst b/py_opengauss/documentation/changes-v1.0.rst new file mode 100644 index 0000000000000000000000000000000000000000..89c8cea31f2a5bdfc58ff91b8a5bdb4924d7dc58 --- /dev/null +++ b/py_opengauss/documentation/changes-v1.0.rst @@ -0,0 +1,79 @@ +Changes in v1.0 +=============== + +1.0.4 in development +-------------------- + + * Alter how changes are represented in documentation to simplify merging. + +1.0.3 released on 2011-09-24 +---------------------------- + + * Use raise x from y to generalize exceptions. (Elvis Pranskevichus) + * Alter postgresql.string.quote_ident to always quote. 
(Elvis Pranskevichus) + * Add postgresql.string.quote_ident_if_necessary (Modification of Elvis Pranskevichus' patch) + * Many postgresql.string bug fixes (Elvis Pranskevichus) + * Correct ResourceWarnings improving Python 3.2 support. (jwp) + * Add test command to setup.py (Elvis Pranskevichus) + +1.0.2 released on 2010-09-18 +---------------------------- + + * Add support for DOMAINs in registered composites. (Elvis Pranskevichus) + * Properly raise StopIteration in Cursor.__next__. (Elvis Pranskevichus) + * Add Cluster Management documentation. + * Release savepoints after rolling them back. + * Fix Startup() usage for Python 3.2. + * Emit deprecation warning when 'gid' is given to xact(). + * Compensate for Python3.2's ElementTree API changes. + +1.0.1 released on 2010-04-24 +---------------------------- + + * Fix unpacking of array NULLs. (Elvis Pranskevichus) + * Fix .first()'s handling of counts and commands. + Bad logic caused zero-counts to return the command tag. + * Don't interrupt and close a temporal connection if it's not open. + * Use the Driver's typio attribute for TypeIO overrides. (Elvis Pranskevichus) + +1.0 released on 2010-03-27 +-------------------------- + + * **DEPRECATION**: Removed 2PC support documentation. + * **DEPRECATION**: Removed pg_python and pg_dotconf 'scripts'. + They are still accessible by python3 -m postgresql.bin.pg_* + * Add support for binary hstore. + * Add support for user service files. + * Implement a Copy manager for direct connection-to-connection COPY operations. + * Added db.do() method for DO-statement support(convenience method). + * Set the default client_min_messages level to WARNING. + NOTICEs are often not desired by programmers, and py-postgresql's + high verbosity further irritates that case. + * Added postgresql.project module to provide project information. + Project name, author, version, etc. + * Increased default recvsize and chunksize for improved performance. + * 'D' messages are special cased as builtins.tuples instead of + protocol.element3.Tuple + * Alter Statement.chunks() to return chunks of builtins.tuple. Being + an interface intended for speed, types.Row() impedes its performance. + * Fix handling of infinity values with timestamptz, timestamp, and date. + [Bug reported by Axel Rau.] + * Correct representation of PostgreSQL ARRAYs by properly recording + lowerbounds and upperbounds. Internally, sub-ARRAYs have their own + element lists. + * Implement a NotificationManager for managing the NOTIFYs received + by a connection. The class can manage NOTIFYs from multiple + connections, whereas the db.wait() method is tailored for single targets. + * Implement an ALock class for managing advisory locks using the + threading.Lock APIs. [Feedback from Valentine Gogichashvili] + * Implement reference symbols. Allow libraries to define symbols that + are used to create queries that inherit the original symbol's type and + execution method. ``db.prepare(db.prepare(...).first())`` + * Fix handling of unix domain sockets by pg.open and driver.connect. + [Reported by twitter.com/rintavarustus] + * Fix typo/dropped parts of a raise LoadError in .lib. + [Reported by Vlad Pranskevichus] + * Fix db.tracer and pg_python's --pq-trace= + * Fix count return from .first() method. Failed to provide an empty + tuple for the rformats of the bind statement. 
+ [Reported by dou dou] diff --git a/py_opengauss/documentation/changes-v1.1.rst b/py_opengauss/documentation/changes-v1.1.rst new file mode 100644 index 0000000000000000000000000000000000000000..aa1abbac53dfa184dcd5e9bb861f6f32140aa56d --- /dev/null +++ b/py_opengauss/documentation/changes-v1.1.rst @@ -0,0 +1,25 @@ +Changes in v1.1 +=============== + +1.1.0 +----- + + * Remove two-phase commit interfaces per deprecation in v1.0. + For proper two phase commit use, a lock manager must be employed that + the implementation did nothing to accommodate for. + * Add support for unpacking anonymous records (Elvis) + * Support PostgreSQL 9.2 (Elvis) + * Python 3.3 Support (Elvis) + * Add column execution method. (jwp) + * Add one-shot statement interface. Connection.query.* (jwp) + * Modify the inet/cidr support by relying on the ipaddress module introduced in Python 3.3 (Google's ipaddr project) + The existing implementation relied on simple str() representation supported by the + socket module. Unfortunately, MS Windows' socket library does not appear to support the + necessary functionality, or Python's socket module does not expose it. ipaddress fixes + the problem. + +.. note:: + The `ipaddress` module is now required for local inet and cidr. While it is + of "preliminary" status, the ipaddr project has been around for some time and + well supported. ipaddress appears to be the safest way forward for native + network types. diff --git a/py_opengauss/documentation/changes-v1.2.rst b/py_opengauss/documentation/changes-v1.2.rst new file mode 100644 index 0000000000000000000000000000000000000000..8b79fa165e6aa4c8cc57d36df7f27cd00d52fcf2 --- /dev/null +++ b/py_opengauss/documentation/changes-v1.2.rst @@ -0,0 +1,18 @@ +Changes in v1.2 +=============== + +1.2.2 released on 2020-09-22 +---------------------------- + + * Correct broken Connection.proc. + * Correct IPv6 IRI host oversight. + * Document an ambiguity case of DB-API 2.0 connection creation and the workaround(unix vs host/port). + * (Pending, active in 1.3) DB-API 2.0 connect() failures caused an undesired exception chain; ClientCannotConnect is now raised. + * Minor maintenance on tests and support modules. + +1.2.0 released on 2016-06-23 +---------------------------- + + * PostgreSQL 9.3 compatibility fixes (Elvis) + * Python 3.5 compatibility fixes (Elvis) + * Add support for JSONB type (Elvis) diff --git a/py_opengauss/documentation/changes-v1.3.rst b/py_opengauss/documentation/changes-v1.3.rst new file mode 100644 index 0000000000000000000000000000000000000000..8b8686c3fd39aa2583172f9e937909e69be228cb --- /dev/null +++ b/py_opengauss/documentation/changes-v1.3.rst @@ -0,0 +1,14 @@ +Changes in v1.3 +=============== + +1.3.0 +----- + + * Commit DB-API 2.0 ClientCannotConnect exception correction. + * Eliminate types-as-documentation annotations. + * Add Connection.transaction alias for asyncpg consistency. + * Eliminate multiple inheritance in `postgresql.api` in favor of ABC registration. + * Add support for PGTEST environment variable (pq-IRI) to improve test performance + and to aid in cases where the target fixture is already available. + This should help for testing the driver against servers that are not actually + postgresql. 
diff --git a/py_opengauss/documentation/clientparameters.rst b/py_opengauss/documentation/clientparameters.rst new file mode 100644 index 0000000000000000000000000000000000000000..8c8441cf8aa0a2465d43bce459ecff09aeb8a052 --- /dev/null +++ b/py_opengauss/documentation/clientparameters.rst @@ -0,0 +1,260 @@ +Client Parameters +***************** + +.. warning:: **The interfaces dealing with optparse are subject to change in 1.0**. + +There are various sources of parameters used by PostgreSQL client applications. +The `postgresql.clientparameters` module provides a means for collecting and +managing those parameters. + +Connection creation interfaces in `postgresql.driver` are purposefully simple. +All parameters taken by those interfaces are keywords, and are taken +literally; if a parameter is not given, it will effectively be `None`. +libpq-based drivers tend differ as they inherit some default client parameters +from the environment. Doing this by default is undesirable as it can cause +trivial failures due to unexpected parameter inheritance. However, using these +parameters from the environment and other sources are simply expected in *some* +cases: `postgresql.open`, `postgresql.bin.pg_python`, and other high-level +utilities. The `postgresql.clientparameters` module provides a means to collect +them into one dictionary object for subsequent application to a connection +creation interface. + +`postgresql.clientparameters` is primarily useful to script authors that want to +provide an interface consistent with PostgreSQL commands like ``psql``. + + +Collecting Parameters +===================== + +The primary entry points in `postgresql.clientparameters` are +`postgresql.clientparameters.collect` and +`postgresql.clientparameters.resolve_password`. + +For most purposes, ``collect`` will suffice. By default, it will prompt for the +password if instructed to(``-W``). Therefore, ``resolve_password`` need not be +used in most cases:: + + >>> import sys + >>> import postgresql.clientparameters as pg_param + >>> p = pg_param.DefaultParser() + >>> co, ca = p.parse_args(sys.argv[1:]) + >>> params = pg_param.collect(parsed_options = co) + +The `postgresql.clientparameters` module is executable, so you can see the +results of the above snippet by:: + + $ python -m postgresql.clientparameters -h localhost -U a_db_user -ssearch_path=public + {'host': 'localhost', + 'password': None, + 'port': 5432, + 'settings': {'search_path': 'public'}, + 'user': 'a_db_user'} + + +`postgresql.clientparameters.collect` +-------------------------------------- + +Build a client parameter dictionary from the environment and parsed command +line options. The following is a list of keyword arguments that ``collect`` will +accept: + + ``parsed_options`` + Options parsed by `postgresql.clientparameters.StandardParser` or + `postgresql.clientparameters.DefaultParser` instances. + + ``no_defaults`` + When `True`, don't include defaults like ``pgpassfile`` and ``user``. + Defaults to `False`. + + ``environ`` + Environment variables to extract client parameter variables from. + Defaults to `os.environ` and expects a `collections.abc.Mapping` interface. + + ``environ_prefix`` + Environment variable prefix to use. Defaults to "PG". This allows the + collection of non-standard environment variables whose keys are partially + consistent with the standard variants. e.g. "PG_SRC_USER", "PG_SRC_HOST", + etc. + + ``default_pg_sysconfdir`` + The location of the pg_service.conf file. 
The ``PGSYSCONFDIR`` environment + variable will override this. When a default installation is present, + ``PGINSTALLATION``, it should be set to this. + + ``pg_service_file`` + Explicit location of the service file. This will override the "sysconfdir" + based path. + + ``prompt_title`` + Descriptive title to use if a password prompt is needed. `None` to disable + password resolution entirely. Setting this to `None` will also disable + pgpassfile lookups, so it is necessary that further processing occurs when + this is `None`. + + ``parameters`` + Base client parameters to use. These are set after the *defaults* are + collected. (The defaults that can be disabled by ``no_defaults``). + +If ``prompt_title`` is not set to `None`, it will prompt for the password when +instructed to do by the ``prompt_password`` key in the parameters:: + + >>> import postgresql.clientparameters as pg_param + >>> p = pg_param.collect(prompt_title = 'my_prompt!', parameters = {'prompt_password':True}) + Password for my_prompt![pq://dbusername@localhost:5432]: + >>> p + {'host': 'localhost', 'user': 'dbusername', 'password': 'secret', 'port': 5432} + +If `None`, it will leave the necessary password resolution information in the +parameters dictionary for ``resolve_password``:: + + >>> p = pg_param.collect(prompt_title = None, parameters = {'prompt_password':True}) + >>> p + {'pgpassfile': '/home/{USER}/.pgpass', 'prompt_password': True, 'host': 'localhost', 'user': 'dbusername', 'port': 5432} + +Of course, ``'prompt_password'`` is normally specified when ``parsed_options`` +received a ``-W`` option from the command line:: + + >>> op = pg_param.DefaultParser() + >>> co, ca = op.parse_args(['-W']) + >>> p = pg_param.collect(parsed_options = co) + >>> p=pg_param.collect(parsed_options = co) + Password for [pq://dbusername@localhost:5432]: + >>> p + {'host': 'localhost', 'user': 'dbusername', 'password': 'secret', 'port': 5432} + >>> + + +`postgresql.clientparameters.resolve_password` +---------------------------------------------- + +Resolve the password for the given client parameters dictionary returned by +``collect``. By default, this function need not be used as ``collect`` will +resolve the password by default. `resolve_password` accepts the following +arguments: + + ``parameters`` + First positional argument. Normalized client parameters dictionary to update + in-place with the resolved password. If the 'prompt_password' key is in + ``parameters``, it will prompt regardless(normally comes from ``-W``). + + ``getpass`` + Function to call to prompt for the password. Defaults to `getpass.getpass`. + + ``prompt_title`` + Additional title to use if a prompt is requested. This can also be specified + in the ``parameters`` as the ``prompt_title`` key. This *augments* the IRI + display on the prompt. Defaults to an empty string, ``''``. + +The resolution process is effected by the contents of the given ``parameters``. +Notable keywords: + + ``prompt_password`` + If present in the given parameters, the user will be prompted for the using + the given ``getpass`` function. This disables the password file lookup + process. + + ``prompt_title`` + This states a default prompt title to use. If the ``prompt_title`` keyword + argument is given to ``resolve_password``, this will not be used. + + ``pgpassfile`` + The PostgreSQL password file to lookup the password in. If the ``password`` + parameter is present, this will not be used. 
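+
+As a rough sketch of where ``resolve_password`` fits, the following combines
+collection, explicit password resolution, and application of the resulting
+parameters to a connection creation interface. The use of
+``postgresql.driver.connect`` here assumes that the collected keywords are
+acceptable to that interface and that a server is reachable::
+
+    >>> import postgresql.driver as pg_driver
+    >>> import postgresql.clientparameters as pg_param
+    >>> # Collect without prompting; leave password resolution for later.
+    >>> p = pg_param.collect(prompt_title = None)
+    >>> pg_param.resolve_password(p)
+    >>> db = pg_driver.connect(**p)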
+ +When resolution occurs, the ``prompt_password``, ``prompt_title``, and +``pgpassfile`` keys are *removed* from the given parameters dictionary:: + + >>> p=pg_param.collect(prompt_title = None) + >>> p + {'pgpassfile': '/home/{USER}/.pgpass', 'host': 'localhost', 'user': 'dbusername', 'port': 5432} + >>> pg_param.resolve_password(p) + >>> p + {'host': 'localhost', 'password': 'secret', 'user': 'dbusername', 'port': 5432} + + +Defaults +======== + +The following is a list of default parameters provided by ``collect`` and the +sources of their values: + + ==================== =================================================================== + Key Value + ==================== =================================================================== + ``'user'`` `getpass.getuser()` or ``'postgres'`` + ``'host'`` `postgresql.clientparameters.default_host` (``'localhost'``) + ``'port'`` `postgresql.clientparameters.default_port` (``5432``) + ``'pgpassfile'`` ``"$HOME/.pgpassfile"`` or ``[PGDATA]`` + ``'pgpass.conf'`` (Win32) + ``'sslcrtfile'`` ``[PGDATA]`` + ``'postgresql.crt'`` + ``'sslkeyfile'`` ``[PGDATA]`` + ``'postgresql.key'`` + ``'sslrootcrtfile'`` ``[PGDATA]`` + ``'root.crt'`` + ``'sslrootcrlfile'`` ``[PGDATA]`` + ``'root.crl'`` + ==================== =================================================================== + +``[PGDATA]`` referenced in the above table is a directory whose path is platform +dependent. On most systems, it is ``"$HOME/.postgresql"``, but on Windows based +systems it is ``"%APPDATA%\postgresql"`` + +.. note:: + [PGDATA] is *not* an environment variable. + + +.. _pg_envvars: + +PostgreSQL Environment Variables +================================ + +The following is a list of environment variables that will be collected by the +`postgresql.clientparameter.collect` function using "PG" as the +``environ_prefix`` and the keyword that it will be mapped to: + + ===================== ====================================== + Environment Variable Keyword + ===================== ====================================== + ``PGUSER`` ``'user'`` + ``PGDATABASE`` ``'database'`` + ``PGHOST`` ``'host'`` + ``PGPORT`` ``'port'`` + ``PGPASSWORD`` ``'password'`` + ``PGSSLMODE`` ``'sslmode'`` + ``PGSSLKEY`` ``'sslkey'`` + ``PGCONNECT_TIMEOUT`` ``'connect_timeout'`` + ``PGREALM`` ``'kerberos4_realm'`` + ``PGKRBSRVNAME`` ``'kerberos5_service'`` + ``PGPASSFILE`` ``'pgpassfile'`` + ``PGTZ`` ``'settings' = {'timezone': }`` + ``PGDATESTYLE`` ``'settings' = {'datestyle': }`` + ``PGCLIENTENCODING`` ``'settings' = {'client_encoding': }`` + ``PGGEQO`` ``'settings' = {'geqo': }`` + ===================== ====================================== + + +.. _pg_passfile: + +PostgreSQL Password File +======================== + +The password file is a simple newline separated list of ``:`` separated fields. It +is located at ``$HOME/.pgpass`` for most systems and at +``%APPDATA%\postgresql\pgpass.conf`` for Windows based systems. However, the +``PGPASSFILE`` environment variable may be used to override that location. + +The lines in the file must be in the following form:: + + hostname:port:database:username:password + +A single asterisk, ``*``, may be used to indicate that any value will match the +field. However, this only effects fields other than ``password``. + +See http://www.postgresql.org/docs/current/static/libpq-pgpass.html for more +details. + +Client parameters produced by ``collect`` that have not been processed +by ``resolve_password`` will include a ``'pgpassfile'`` key. 
This is the value +that ``resolve_password`` will use to locate the pgpassfile to interrogate if a +password key is not present and it is not instructed to prompt for a password. + +.. warning:: + Connection creation interfaces will *not* resolve ``'pgpassfile'``, so it is + important that the parameters produced by ``collect()`` are properly processed + before an attempt is made to establish a connection. diff --git a/py_opengauss/documentation/cluster.rst b/py_opengauss/documentation/cluster.rst new file mode 100644 index 0000000000000000000000000000000000000000..1993ea280df5d1269579197c577aac6b31d6731e --- /dev/null +++ b/py_opengauss/documentation/cluster.rst @@ -0,0 +1,378 @@ +.. _cluster_management: + +****************** +Cluster Management +****************** + +py-postgresql provides cluster management tools in order to give the user +fine-grained control over a PostgreSQL cluster and access to information about an +installation of PostgreSQL. + + +.. _installation: + +Installations +============= + +`postgresql.installation.Installation` objects are primarily used to +access PostgreSQL installation information. Normally, they are created using a +dictionary constructed from the output of the pg_config_ executable:: + + from postgresql.installation import Installation, pg_config_dictionary + pg_install = Installation(pg_config_dictionary('/usr/local/pgsql/bin/pg_config')) + +The extraction of pg_config_ information is isolated from Installation +instantiation in order to allow Installations to be created from arbitrary +dictionaries. This can be useful in cases where the installation layout is +inconsistent with the standard PostgreSQL installation layout, or if a faux +Installation needs to be created for testing purposes. + + +Installation Interface Points +----------------------------- + + ``Installation(info)`` + Instantiate an Installation using the given information. Normally, this + information is extracted from a pg_config_ executable using + `postgresql.installation.pg_config_dictionary`:: + + info = pg_config_dictionary('/usr/local/pgsql/bin/pg_config') + pg_install = Installation(info) + + ``Installation.version`` + The installation's version string:: + + pg_install.version + 'PostgreSQL 9.0devel' + + ``Installation.version_info`` + A tuple containing the version's ``(major, minor, patch, state, level)``. + Where ``major``, ``minor``, ``patch``, and ``level`` are `int` objects, and + ``state`` is a `str` object:: + + pg_install.version_info + (9, 0, 0, 'devel', 0) + + ``Installation.ssl`` + A `bool` indicating whether or not the installation has SSL support. + + ``Installation.configure_options`` + The options given to the ``configure`` script that built the installation. The + options are represented using a dictionary object whose keys are normalized + long option names, and whose values are the option's argument. If the option + takes no argument, `True` will be used as the value. + + The normalization of the long option names consists of removing the preceding + dashes, lowering the string, and replacing any dashes with underscores. 
For + instance, ``--enable-debug`` will be ``enable_debug``:: + + pg_install.configure_options + {'enable_debug': True, 'with_libxml': True, + 'enable_cassert': True, 'with_libedit_preferred': True, + 'prefix': '/src/build/pg90', 'with_openssl': True, + 'enable_integer_datetimes': True, 'enable_depend': True} + + ``Installation.paths`` + The paths of the installation as a dictionary where the keys are the path + identifiers and the values are the absolute file system paths. For instance, + ``'bindir'`` is associated with ``$PREFIX/bin``, ``'libdir'`` is associated + with ``$PREFIX/lib``, etc. The paths included in this dictionary are + listed on the class' attributes: `Installation.pg_directories` and + `Installation.pg_executables`. + + The keys that point to installation directories are: ``bindir``, ``docdir``, + ``includedir``, ``pkgincludedir``, ``includedir_server``, ``libdir``, + ``pkglibdir``, ``localedir``, ``mandir``, ``sharedir``, and ``sysconfdir``. + + The keys that point to installation executables are: ``pg_config``, ``psql``, + ``initdb``, ``pg_resetxlog``, ``pg_controldata``, ``clusterdb``, ``pg_ctl``, + ``pg_dump``, ``pg_dumpall``, ``postgres``, ``postmaster``, ``reindexdb``, + ``vacuumdb``, ``ipcclean``, ``createdb``, ``ecpg``, ``createuser``, + ``createlang``, ``droplang``, ``dropuser``, and ``pg_restore``. + + .. note:: If the executable does not exist, the value will be `None` instead + of an absoluate path. + + To get the path to the psql_ executable:: + + from postgresql.installation import Installation + pg_install = Installation('/usr/local/pgsql/bin/pg_config') + psql_path = pg_install.paths['psql'] + + +Clusters +======== + +`postgresql.cluster.Cluster` is the class used to manage a PostgreSQL +cluster--a data directory created by initdb_. A Cluster represents a data +directory with respect to a given installation of PostgreSQL, so +creating a `postgresql.cluster.Cluster` object requires a +`postgresql.installation.Installation`, and a +file system path to the data directory. + +In part, a `postgresql.cluster.Cluster` is the Python programmer's variant of +the pg_ctl_ command. However, it goes beyond the basic process control +functionality and extends into initialization and configuration as well. + +A Cluster manages the server process using the `subprocess` module and +signals. The `subprocess.Popen` object, ``Cluster.daemon_process``, is +retained when the Cluster starts the server process itself. This gives +the Cluster access to the result code of server process when it exits, and the +ability to redirect stderr and stdout to a parameterized file object using +subprocess features. + +Despite its use of `subprocess`, Clusters can control a server process +that was *not* started by the Cluster's ``start`` method. + + +Initializing Clusters +--------------------- + +`postgresql.cluster.Cluster` provides a method for initializing a +`Cluster`'s data directory, ``init``. This method provides a Python interface to +the PostgreSQL initdb_ command. + +``init`` is a regular method and accepts a few keyword parameters. Normally, +parameters are directly mapped to initdb_ command options. However, ``password`` +makes use of initdb's capability to read the superuser's password from a file. 
+To do this, a temporary file is allocated internally by the method:: + + from postgresql.installation import Installation, pg_config_dictionary + from postgresql.cluster import Cluster + pg_install = Installation(pg_config_dictionary('/usr/local/pgsql/bin/pg_config')) + pg_cluster = Cluster(pg_install, 'pg_data') + pg_cluster.init(user = 'pg', password = 'secret', encoding = 'utf-8') + +The init method will block until the initdb command is complete. Once +initialized, the Cluster may be configured. + + +Configuring Clusters +-------------------- + +A Cluster's `configuration file`_ can be manipulated using the +`Cluster.settings` mapping. The mapping's methods will always access the +configuration file, so it may be desirable to cache repeat reads. Also, if +multiple settings are being applied, using the ``update()`` method may be +important to avoid writing the entire file multiple times:: + + pg_cluster.settings.update({'listen_addresses' : 'localhost', 'port' : '6543'}) + +Similarly, to avoid opening and reading the entire file multiple times, +`Cluster.settings.getset` should be used to retrieve multiple settings:: + + d = pg_cluster.settings.getset(set(('listen_addresses', 'port'))) + d + {'listen_addresses' : 'localhost', 'port' : '6543'} + +Values contained in ``settings`` are always Python strings:: + + assert pg_cluster.settings['max_connections'].__class__ is str + +The ``postgresql.conf`` file is only one part of the server configuration. +Structured access and manipulation of the pg_hba_ file is not +supported. Clusters only provide the file path to the pg_hba_ file:: + + hba = open(pg_cluster.hba_file) + +If the configuration of the Cluster is altered while the server process is +running, it may be necessary to signal the process that configuration changes +have been made. This signal can be sent using the ``Cluster.reload()`` method. +``Cluster.reload()`` will send a SIGHUP signal to the server process. However, +not all changes to configuration settings can go into effect after calling +``Cluster.reload()``. In those cases, the server process will need to be +shutdown and started again. + + +Controlling Clusters +-------------------- + +The server process of a Cluster object can be controlled with the ``start()``, +``stop()``, ``shutdown()``, ``kill()``, and ``restart()`` methods. +These methods start the server process, signal the server process, or, in the +case of restart, a combination of the two. + +When a Cluster starts the server process, it's ran as a subprocess. Therefore, +if the current process exits, the server process will exit as well. ``start()`` +does *not* automatically daemonize the server process. + +.. note:: Under Microsoft Windows, above does not hold true. The server process + will continue running despite the exit of the parent process. + +To terminate a server process, one of these three methods should be called: +``stop``, ``shutdown``, or ``kill``. ``stop`` is a graceful shutdown and will +*wait for all clients to disconnect* before shutting down. ``shutdown`` will +close any open connections and safely shutdown the server process. +``kill`` will immediately terminate the server process leading to recovery upon +starting the server process again. + +.. note:: Using ``kill`` may cause shared memory to be leaked. + +Normally, `Cluster.shutdown` is the appropriate way to terminate a server +process. 
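+
+For example, a complete control cycle, reusing the ``pg_cluster`` object from
+the initialization example above and assuming a writable log file, might look
+like this::
+
+    logfile = open('cluster.log', 'a')
+    pg_cluster.start(logfile = logfile)
+    pg_cluster.wait_until_started()
+    try:
+        ...  # connect to the cluster and do the work
+    finally:
+        pg_cluster.shutdown()
+        pg_cluster.wait_until_stopped()
+        logfile.close()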
+ + +Cluster Interface Points +------------------------ + +Methods and properties available on `postgresql.cluster.Cluster` instances: + + ``Cluster(installation, data_directory)`` + Create a `postgresql.cluster.Cluster` object for the specified + `postgresql.installation.Installation`, and ``data_directory``. + + The ``data_directory`` must be an absoluate file system path. The directory + does *not* need to exist. The ``init()`` method may later be used to create + the cluster. + + ``Cluster.installation`` + The Cluster's `postgresql.installation.Installation` instance. + + ``Cluster.data_directory`` + The absolute path to the PostgreSQL data directory. + This directory may not exist. + + ``Cluster.init([encoding = None[, user = None[, password = None]]])`` + Run the `initdb`_ executable of the configured installation to initialize the + cluster at the configured data directory, `Cluster.data_directory`. + + ``encoding`` is mapped to ``-E``, the default database encoding. By default, + the encoding is determined from the environment's locale. + + ``user`` is mapped to ``-U``, the database superuser name. By default, the + current user's name. + + ``password`` is ultimately mapped to ``--pwfile``. The argument given to the + long option is actually a path to the temporary file that holds the given + password. + + Raises `postgresql.cluster.InitDBError` when initdb_ returns a non-zero result + code. + + Raises `postgresql.cluster.ClusterInitializationError` when there is no + initdb_ in the Installation. + + ``Cluster.initialized()`` + Whether or not the data directory exists, *and* if it looks like a PostgreSQL + data directory. Meaning, the directory must contain a ``postgresql.conf`` file + and a ``base`` directory. + + ``Cluster.drop()`` + Shutdown the Cluster's server process and completely remove the + `Cluster.data_directory` from the file system. + + ``Cluster.pid()`` + The server's process identifier as a Python `int`. `None` if there is no + server process running. + This is a method rather than a property as it may read the PID from a file + in cases where the server process was not started by the Cluster. + + ``Cluster.start([logfile = None[, settings = None]])`` + Start the PostgreSQL server process for the Cluster if it is not + already running. This will execute postgres_ as a subprocess. + + If ``logfile``, an opened and writable file object, is given, stderr and + stdout will be redirected to that file. By default, both stderr and stdout are + closed. + + If ``settings`` is given, the mapping or sequence of pairs will be used as + long options to the subprocess. For each item, ``--{key}={value}`` will be + given as an argument to the subprocess. + + ``Cluster.running()`` + Whether or not the cluster's server process is running. Returns `True` or + `False`. Even if `True` is returned, it does *not* mean that the server + process is ready to accept connections. + + ``Cluster.ready_for_connections()`` + Whether or not the Cluster is ready to accept connections. Usually called + after `Cluster.start`. + + Returns `True` when the Cluster can accept connections, `False` when it + cannot, and `None` if the Cluster's server process is not running at all. + + ``Cluster.wait_until_started([timeout = 10[, delay = 0.05]])`` + Blocks the process until the cluster is identified as being ready for + connections. Usually called after ``Cluster.start()``. + + Raises `postgresql.cluster.ClusterNotRunningError` if the server process is + not running at all. 
+ + Raises `postgresql.cluster.ClusterTimeoutError` if + `Cluster.ready_for_connections()` does not return `True` within the given + `timeout` period. + + Raises `postgresql.cluster.ClusterStartupError` if the server process + terminates while polling for readiness. + + ``timeout`` and ``delay`` are both in seconds. Where ``timeout`` is the + maximum time to wait for the Cluster to be ready for connections, and + ``delay`` is the time to sleep between calls to + `Cluster.ready_for_connections()`. + + ``Cluster.stop()`` + Signal the cluster to shutdown when possible. The *server* will wait for all + clients to disconnect before shutting down. + + ``Cluster.shutdown()`` + Signal the cluster to shutdown immediately. Any open client connections will + be closed. + + ``Cluster.kill()`` + Signal the absolute destruction of the server process(SIGKILL). + *This will require recovery when the cluster is started again.* + *Shared memory may be leaked.* + + ``Cluster.wait_until_stopped([timeout = 10[, delay = 0.05]])`` + Blocks the process until the cluster is identified as being shutdown. Usually + called after `Cluster.stop` or `Cluster.shutdown`. + + Raises `postgresql.cluster.ClusterTimeoutError` if + `Cluster.ready_for_connections` does not return `None` within the given + `timeout` period. + + ``Cluster.reload()`` + Signal the server that it should reload its configuration files(SIGHUP). + Usually called after manipulating `Cluster.settings` or modifying the + contents of `Cluster.hba_file`. + + ``Cluster.restart([logfile = None[, settings = None[, timeout = 10]]])`` + Stop the server process, wait until it is stopped, start the server + process, and wait until it has started. + + .. note:: This calls ``Cluster.stop()``, so it will wait until clients + disconnect before starting up again. + + The ``logfile`` and ``settings`` parameters will be given to `Cluster.start`. + ``timeout`` will be given to `Cluster.wait_until_stopped` and + `Cluster.wait_until_started`. + + ``Cluster.settings`` + A `collections.abc.Mapping` interface to the ``postgresql.conf`` file of the + cluster. + + A notable extension to the mapping interface is the ``getset`` method. This + method will return a dictionary object containing the settings whose names + were contained in the `set` object given to the method. + This method should be used when multiple settings need to be retrieved from + the configuration file. + + ``Cluster.hba_file`` + The path to the cluster's pg_hba_ file. This property respects the HBA file + location setting in ``postgresql.conf``. Usually, ``$PGDATA/pg_hba.conf``. + + ``Cluster.daemon_path`` + The path to the executable to use to start the server process. + + ``Cluster.daemon_process`` + The `subprocess.Popen` instance of the server process. `None` if the server + process was not started or was not started using the Cluster object. + + +.. _pg_hba: http://www.postgresql.org/docs/current/static/auth-pg-hba-conf.html +.. _pg_config: http://www.postgresql.org/docs/current/static/app-pgconfig.html +.. _initdb: http://www.postgresql.org/docs/current/static/app-initdb.html +.. _psql: http://www.postgresql.org/docs/current/static/app-psql.html +.. _postgres: http://www.postgresql.org/docs/current/static/app-postgres.html +.. _pg_ctl: http://www.postgresql.org/docs/current/static/app-pg-ctl.html +.. 
_configuration file: http://www.postgresql.org/docs/current/static/runtime-config.html diff --git a/py_opengauss/documentation/copyman.rst b/py_opengauss/documentation/copyman.rst new file mode 100644 index 0000000000000000000000000000000000000000..d4a18cb16bcc8ee4e911498117e1d4ad6491f1fe --- /dev/null +++ b/py_opengauss/documentation/copyman.rst @@ -0,0 +1,317 @@ +.. _pg_copyman: + +*************** +Copy Management +*************** + +The `postgresql.copyman` module provides a way to quickly move COPY data coming +from one connection to many connections. Alternatively, it can be sourced +by arbitrary iterators and target arbitrary callables. + +Statement execution methods offer a way for running COPY operations +with iterators, but the cost of allocating objects for each row is too +significant for transferring gigabytes of COPY data from one connection to +another. The interfaces available on statement objects are primarily intended to +be used when transferring COPY data to and from arbitrary Python +objects. + +Direct connection-to-connection COPY operations can be performed using the +high-level `postgresql.copyman.transfer` function:: + + >>> from postgresql import copyman + >>> send_stmt = source.prepare("COPY (SELECT i FROM generate_series(1, 1000000) AS g(i)) TO STDOUT") + >>> destination.execute("CREATE TEMP TABLE loading_table (i int8)") + >>> receive_stmt = destination.prepare("COPY loading_table FROM STDIN") + >>> total_rows, total_bytes = copyman.transfer(send_stmt, receive_stmt) + +However, if more control is needed, the `postgresql.copyman.CopyManager` class +should be used directly. + + +Copy Managers +============= + +The `postgresql.copyman.CopyManager` class manages the Producer and the +Receivers involved in a COPY operation. Normally, +`postgresql.copyman.StatementProducer` and +`postgresql.copyman.StatementReceiver` instances. Naturally, a Producer is the +object that produces the COPY data to be given to the Manager's Receivers. + +Using a Manager directly means that there is a need for more control over +the operation. The Manager is both a context manager and an iterator. The +context manager interfaces handle initialization and finalization of the COPY +state, and the iterator provides an event loop emitting information about the +amount of COPY data transferred this cycle. Normal usage takes the form:: + + >>> from postgresql import copyman + >>> send_stmt = source.prepare("COPY (SELECT i FROM generate_series(1, 1000000) AS g(i)) TO STDOUT") + >>> destination.execute("CREATE TEMP TABLE loading_table (i int8)") + >>> receive_stmt = destination.prepare("COPY loading_table FROM STDIN") + >>> producer = copyman.StatementProducer(send_stmt) + >>> receiver = copyman.StatementReceiver(receive_stmt) + >>> + >>> with source.xact(), destination.xact(): + ... with copyman.CopyManager(producer, receiver) as copy: + ... for num_messages, num_bytes in copy: + ... update_rate(num_bytes) + +As an alternative to a for-loop inside a with-statement block, the `run` method +can be called to perform the operation:: + + >>> with source.xact(), destination.xact(): + ... copyman.CopyManager(producer, receiver).run() + +However, there is little benefit beyond using the high-level +`postgresql.copyman.transfer` function. + +Manager Interface Points +------------------------ + +Primarily, the `postgresql.copyman.CopyManager` provides a context manager and +an iterator for controlling the COPY operation. + + ``CopyManager.run()`` + Perform the entire COPY operation. 
+ + ``CopyManager.__enter__()`` + Start the COPY operation. Connections taking part in the COPY should **not** + be used until ``__exit__`` is ran. + + ``CopyManager.__exit__(typ, val, tb)`` + Finish the COPY operation. Fails in the case of an incomplete + COPY, or an untrapped exception. Either returns `None` or raises the generalized + exception, `postgresql.copyman.CopyFail`. + + ``CopyManager.__iter__()`` + Returns the CopyManager instance. + + ``CopyManager.__next__()`` + Transfer the next chunk of COPY data to the receivers. Yields a tuple + consisting of the number of messages and bytes transferred, + ``(num_messages, num_bytes)``. Raises `StopIteration` when complete. + + Raises `postgresql.copyman.ReceiverFault` when a Receiver raises an + exception. + Raises `postgresql.copyman.ProducerFault` when the Producer raises an + exception. The original exception is available via the exception's + ``__context__`` attribute. + + ``CopyManager.reconcile(faulted_receiver)`` + Reconcile a faulted receiver. When a receiver faults, it will no longer + be in the set of Receivers. This method is used to signal to the manager that the + problem has been corrected, and the receiver is again ready to receive. + + ``CopyManager.receivers`` + The `builtins.set` of Receivers involved in the COPY operation. + + ``CopyManager.producer`` + The Producer emitting the data to be given to the Receivers. + + +Faults +====== + +The CopyManager generalizes any exceptions that occur during transfer. While +inside the context manager, `postgresql.copyman.Fault` may be raised if a +Receiver or a Producer raises an exception. A `postgresql.copyman.ProducerFault` +in the case of the Producer, and `postgresql.copyman.ReceiverFault` in the case +of the Receivers. + +.. note:: + Faults are only raised by `postgresql.copyman.CopyManager.__next__`. The + ``run()`` method will only raise `postgresql.copyman.CopyFail`. + +Receiver Faults +--------------- + +The Manager assumes the Fault is fatal to a Receiver, and immediately removes +it from the set of target receivers. Additionally, if the Fault exception goes +untrapped, the copy will ultimately fail. + +The Fault exception references the Manager that raised the exception, and the +actual exceptions that occurred associated with the Receiver that caused them. + +In order to identify the exception that caused a Fault, the ``faults`` attribute +on the `postgresql.copyman.ReceiverFault` must be referenced:: + + >>> from postgresql import copyman + >>> send_stmt = source.prepare("COPY (SELECT i FROM generate_series(1, 1000000) AS g(i)) TO STDOUT") + >>> destination.execute("CREATE TEMP TABLE loading_table (i int8)") + >>> receive_stmt = destination.prepare("COPY loading_table FROM STDIN") + >>> producer = copyman.StatementProducer(send_stmt) + >>> receiver = copyman.StatementReceiver(receive_stmt) + >>> + >>> with source.xact(), destination.xact(): + ... with copyman.CopyManager(producer, receiver) as copy: + ... while copy.receivers: + ... try: + ... for num_messages, num_bytes in copy: + ... update_rate(num_bytes) + ... break + ... except copyman.ReceiverFault as cf: + ... # Access the original exception using the receiver as the key. + ... original_exception = cf.faults[receiver] + ... if unknown_failure(original_exception): + ... ... + ... raise + + +ReceiverFault Properties +~~~~~~~~~~~~~~~~~~~~~~~~ + +The following attributes exist on `postgresql.copyman.ReceiverFault` instances: + + ``ReceiverFault.manager`` + The subject `postgresql.copyman.CopyManager` instance. 
+ + ``ReceiverFault.faults`` + A dictionary mapping the Receiver to the exception raised by that Receiver. + + +Reconciliation +~~~~~~~~~~~~~~ + +When a `postgresql.copyman.ReceiverFault` is raised, the Manager immediately +removes the Receiver so that the COPY operation can continue. Continuation of +the COPY can occur by trapping the exception and continuing the iteration of the +Manager. However, if the fault is recoverable, the +`postgresql.copyman.CopyManager.reconcile` method must be used to reintroduce the +Receiver into the Manager's set. Faults must be trapped from within the +Manager's context:: + + >>> import socket + >>> from postgresql import copyman + >>> send_stmt = source.prepare("COPY (SELECT i FROM generate_series(1, 1000000) AS g(i)) TO STDOUT") + >>> destination.execute("CREATE TEMP TABLE loading_table (i int8)") + >>> receive_stmt = destination.prepare("COPY loading_table FROM STDIN") + >>> producer = copyman.StatementProducer(send_stmt) + >>> receiver = copyman.StatementReceiver(receive_stmt) + >>> + >>> with source.xact(), destination.xact(): + ... with copyman.CopyManager(producer, receiver) as copy: + ... while copy.receivers: + ... try: + ... for num_messages, num_bytes in copy: + ... update_rate(num_bytes) + ... except copyman.ReceiverFault as cf: + ... if isinstance(cf.faults[receiver], socket.timeout): + ... copy.reconcile(receiver) + ... else: + ... raise + +Recovering from Faults does add significant complexity to a COPY operation, +so, often, it's best to avoid conditions in which reconciliable Faults may +occur. + + +Producer Faults +--------------- + +Producer faults are normally fatal to the COPY operation and should rarely be +trapped. However, the Manager makes no state changes when a Producer faults, +so, unlike Receiver Faults, no reconciliation process is necessary; rather, +if it's safe to continue, the Manager's iterator should continue to be +processed. + +ProducerFault Properties +~~~~~~~~~~~~~~~~~~~~~~~~ + +The following attributes exist on `postgresql.copyman.ProducerFault` instances: + + ``ReceiverFault.manager`` + The subject `postgresql.copyman.CopyManager`. + + ``ReceiverFault.__context__`` + The original exception raised by the Producer. + + +Failures +======== + +When a COPY operation is aborted, either by an exception or by the iterator +being broken, a `postgresql.copyman.CopyFail` exception will be raised by the +Manager's ``__exit__()`` method. The `postgresql.copyman.CopyFail` exception +offers to record any exceptions that occur during the exit of the context +managers of the Producer and the Receivers. + + +CopyFail Properties +------------------- + +The following properties exist on `postgresql.copyman.CopyFail` exceptions: + + ``CopyFail.manager`` + The Manager whose COPY operation failed. + + ``CopyFail.receiver_faults`` + A dictionary mapping a `postgresql.copyman.Receiver` to the exception raised + by that Receiver's ``__exit__``. `None` if no exceptions were raised by the + Receivers. + + ``CopyFail.producer_fault`` + The exception Raised by the `postgresql.copyman.Producer`. `None` if none. + + +Producers +========= + +The following Producers are available: + + ``postgresql.copyman.StatementProducer(postgresql.api.Statement)`` + Given a Statement producing COPY data, construct a Producer. + + ``postgresql.copyman.IteratorProducer(collections.abc.Iterator)`` + Given an Iterator producing *chunks* of COPY lines, construct a Producer to + manage the data coming from the iterator. 
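+
+The iterator-based source does not require a second connection. The following
+is only a rough sketch (the ``loading_table`` target and the literal byte
+lines are purely illustrative) showing a `postgresql.copyman.IteratorProducer`
+feeding chunks of pre-formatted COPY lines to a
+`postgresql.copyman.StatementReceiver`::
+
+ >>> from postgresql import copyman
+ >>> destination.execute("CREATE TEMP TABLE loading_table (i int8)")
+ >>> receive_stmt = destination.prepare("COPY loading_table FROM STDIN")
+ >>> # Illustrative chunks of text-format COPY lines for the int8 column.
+ >>> producer = copyman.IteratorProducer(iter([
+ ...     [b'1\n', b'2\n'],
+ ...     [b'3\n', b'4\n'],
+ ... ]))
+ >>> receiver = copyman.StatementReceiver(receive_stmt)
+ >>>
+ >>> with destination.xact():
+ ...     copyman.CopyManager(producer, receiver).run()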
+ + +Receivers +========= + + ``postgresql.copyman.StatementReceiver(postgresql.api.Statement)`` + Given a Statement producing COPY data, construct a Producer. + + ``postgresql.copyman.CallReceiver(callable)`` + Given a callable, construct a Receiver that will transmit COPY data in chunks + of lines. That is, the callable will be given a list of COPY lines for each + transfer cycle. + + +Terminology +=========== + +The following terms are regularly used to describe the implementation and +processes of the `postgresql.copyman` module: + + Manager + The object used to manage data coming from a Producer and being given to the + Receivers. It also manages the necessary initialization and finalization steps + required by those factors. + + Producer + The object used to produce the COPY data to be given to the Receivers. The + source. + + Receiver + An object that consumes COPY data. A target. + + Fault + Specifically, `postgresql.copyman.Fault` exceptions. A Fault is raised + when a Receiver or a Producer raises an exception during the COPY operation. + + Reconciliation + Generally, the steps performed by the "reconcile" method on + `postgresql.copyman.CopyManager` instances. More precisely, the + necessary steps for a Receiver's reintroduction into the COPY operation after + a Fault. + + Failed Copy + A failed copy is an aborted COPY operation. This occurs in situations of + untrapped exceptions or an incomplete COPY. Specifically, the COPY will be + noted as failed in cases where the Manager's iterator is *not* ran until + exhaustion. + + Realignment + The process of providing compensating data to the Receivers so that the + connection will be on a message boundary. Occurs when the COPY operation + is aborted. diff --git a/py_opengauss/documentation/driver.rst b/py_opengauss/documentation/driver.rst new file mode 100644 index 0000000000000000000000000000000000000000..aaebde26a48be1c41f1b2ce566954402105ab5d8 --- /dev/null +++ b/py_opengauss/documentation/driver.rst @@ -0,0 +1,1806 @@ +.. _db_interface: + +****** +Driver +****** + +`postgresql.driver` provides a PG-API, `postgresql.api`, interface to a +PostgreSQL server using PQ version 3.0 to facilitate communication. It makes +use of the protocol's extended features to provide binary datatype transmission +and protocol level prepared statements for strongly typed parameters. + +`postgresql.driver` currently supports PostgreSQL servers as far back as 8.0. +Prior versions are not tested. While any version of PostgreSQL supporting +version 3.0 of the PQ protocol *should* work, many features may not work due to +absent functionality in the remote end. + +For DB-API 2.0 users, the driver module is located at +`postgresql.driver.dbapi20`. The DB-API 2.0 interface extends PG-API. All of the +features discussed in this chapter are available on DB-API connections. + +.. warning:: + PostgreSQL versions 8.1 and earlier do not support standard conforming + strings. In order to avoid subjective escape methods on connections, + `postgresql.driver.pq3` enables the ``standard_conforming_strings`` setting + by default. Greater care must be taken when working versions that do not + support standard strings. + **The majority of issues surrounding the interpolation of properly quoted literals can be easily avoided by using parameterized statements**. + +The following identifiers are regularly used as shorthands for significant +interface elements: + + ``db`` + `postgresql.api.Connection`, a database connection. 
`Connections`_ + + ``ps`` + `postgresql.api.Statement`, a prepared statement. `Prepared Statements`_ + + ``c`` + `postgresql.api.Cursor`, a cursor; the results of a prepared statement. + `Cursors`_ + + ``C`` + `postgresql.api.Connector`, a connector. `Connectors`_ + + +Establishing a Connection +========================= + +There are many ways to establish a `postgresql.api.Connection` to a +PostgreSQL server using `postgresql.driver`. This section discusses those, +connection creation, interfaces. + + +`postgresql.open` +----------------- + +In the root package module, the ``open()`` function is provided for accessing +databases using a locator string and optional connection keywords. The string +taken by `postgresql.open` is a URL whose components make up the client +parameters:: + + >>> db = postgresql.open("pq://localhost/postgres") + +This will connect to the host, ``localhost`` and to the database named +``postgres`` via the ``pq`` protocol. open will inherit client parameters from +the environment, so the user name given to the server will come from +``$PGUSER``, or if that is unset, the result of `getpass.getuser`--the username +of the user running the process. The user's "pgpassfile" will even be +referenced if no password is given:: + + >>> db = postgresql.open("pq://username:password@localhost/postgres") + +In this case, the password *is* given, so ``~/.pgpass`` would never be +referenced. The ``user`` client parameter is also given, ``username``, so +``$PGUSER`` or `getpass.getuser` will not be given to the server. + +Settings can also be provided by the query portion of the URL:: + + >>> db = postgresql.open("pq://user@localhost/postgres?search_path=public&timezone=mst") + +The above syntax ultimately passes the query as settings(see the description of +the ``settings`` keyword in `Connection Keywords`). Driver parameters require a +distinction. This distinction is made when the setting's name is wrapped in +square-brackets, '[' and ']': + + >>> db = postgresql.open("pq://user@localhost/postgres?[sslmode]=require&[connect_timeout]=5") + +``sslmode`` and ``connect_timeout`` are driver parameters. These are never sent +to the server, but if they were not in square-brackets, they would be, and the +driver would never identify them as driver parameters. + +The general structure of a PQ-locator is:: + + protocol://user:password@host:port/database?[driver_setting]=value&server_setting=value + +Optionally, connection keyword arguments can be used to override anything given +in the locator:: + + >>> db = postgresql.open("pq://user:secret@host", password = "thE_real_sekrat") + +Or, if the locator is not desired, individual keywords can be used exclusively:: + + >>> db = postgresql.open(user = 'user', host = 'localhost', port = 6543) + +In fact, all arguments to `postgresql.open` are optional as all arguments are +keywords; ``iri`` is merely the first keyword argument taken by +`postgresql.open`. If the environment has all the necessary parameters for a +successful connection, there is no need to pass anything to open:: + + >>> db = postgresql.open() + +For a complete list of keywords that `postgresql.open` can accept, see +`Connection Keywords`_. +For more information about the environment variables, see :ref:`pg_envvars`. +For more information about the ``pgpassfile``, see :ref:`pg_passfile`. + +`postgresql.driver.connect` +--------------------------- + +`postgresql.open` is a high-level interface to connection creation. 
It provides +password resolution services and client parameter inheritance. For some +applications, this is undesirable as such implicit inheritance may lead to +failures due to unanticipated parameters being used. For those applications, +use of `postgresql.open` is not recommended. Rather, `postgresql.driver.connect` +should be used when explicit parameterization is desired by an application: + + >>> import postgresql.driver as pg_driver + >>> db = pg_driver.connect( + ... user = 'usename', + ... password = 'secret', + ... host = 'localhost', + ... port = 5432 + ... ) + +This will create a connection to the server listening on port ``5432`` +on the host ``localhost`` as the user ``usename`` with the password ``secret``. + +.. note:: + `connect` will *not* inherit parameters from the environment as libpq-based drivers do. + +See `Connection Keywords`_ for a full list of acceptable keyword parameters and +their meaning. + + +Connectors +---------- + +Connectors are the supporting objects used to instantiate a connection. They +exist for the purpose of providing connections with the necessary abstractions +for facilitating the client's communication with the server, *and to act as a +container for the client parameters*. The latter purpose is of primary interest +to this section. + +Each connection object is associated with its connector by the ``connector`` +attribute on the connection. This provides the user with access to the +parameters used to establish the connection in the first place, and the means to +create another connection to the same server. The attributes on the connector +should *not* be altered. If parameter changes are needed, a new connector should +be created. + +The attributes available on a connector are consistent with the names of the +connection parameters described in `Connection Keywords`_, so that list can be +used as a reference to identify the information available on the connector. + +Connectors fit into the category of "connection creation interfaces", so +connector instantiation normally takes the same parameters that the +`postgresql.driver.connect` function takes. + +.. note:: + Connector implementations are specific to the transport, so keyword arguments + like ``host`` and ``port`` aren't supported by the ``Unix`` connector. + +The driver, `postgresql.driver.default` provides a set of connectors for making +a connection: + + ``postgresql.driver.default.host(...)`` + Provides a ``getaddrinfo()`` abstraction for establishing a connection. + + ``postgresql.driver.default.ip4(...)`` + Connect to a single IPv4 addressed host. + + ``postgresql.driver.default.ip6(...)`` + Connect to a single IPv6 addressed host. + + ``postgresql.driver.default.unix(...)`` + Connect to a single unix domain socket. Requires the ``unix`` keyword which + must be an absolute path to the unix domain socket to connect to. + +``host`` is the usual connector used to establish a connection:: + + >>> C = postgresql.driver.default.host( + ... user = 'auser', + ... host = 'foo.com', + ... 
port = 5432) + >>> # create + >>> db = C() + >>> # establish + >>> db.connect() + +If a constant internet address is used, ``ip4`` or ``ip6`` can be used:: + + >>> C = postgresql.driver.default.ip4(user='auser', host='127.0.0.1', port=5432) + >>> db = C() + >>> db.connect() + +Additionally, ``db.connect()`` on ``db.__enter__()`` for with-statement support: + + >>> with C() as db: + ... ... + +Connectors are constant. They have no knowledge of PostgreSQL service files, +environment variables or LDAP services, so changes made to those facilities +will *not* be reflected in a connector's configuration. If the latest +information from any of these sources is needed, a new connector needs to be +created as the credentials have changed. + +.. note:: + ``host`` connectors use ``getaddrinfo()``, so if DNS changes are made, + new connections *will* use the latest information. + + +Connection Keywords +------------------- + +The following is a list of keywords accepted by connection creation +interfaces: + + ``user`` + The user to connect as. + + ``password`` + The user's password. + + ``database`` + The name of the database to connect to. (PostgreSQL defaults it to `user`) + + ``host`` + The hostname or IP address to connect to. + + ``port`` + The port on the host to connect to. + + ``unix`` + The unix domain socket to connect to. Exclusive with ``host`` and ``port``. + Expects a string containing the *absolute path* to the unix domain socket to + connect to. + + ``settings`` + A dictionary or key-value pair sequence stating the parameters to give to the + database. These settings are included in the startup packet, and should be + used carefully as when an invalid setting is given, it will cause the + connection to fail. + + ``connect_timeout`` + Amount of time to wait for a connection to be made. (in seconds) + + ``server_encoding`` + Hint given to the driver to properly encode password data and some information + in the startup packet. + This should only be used in cases where connections cannot be made due to + authentication failures that occur while using known-correct credentials. + + ``sslmode`` + ``'disable'`` + Don't allow SSL connections. + ``'allow'`` + Try without SSL first, but if that doesn't work, try with. + ``'prefer'`` + Try SSL first, then without. + ``'require'`` + Require an SSL connection. + + ``sslcrtfile`` + Certificate file path given to `ssl.wrap_socket`. + + ``sslkeyfile`` + Key file path given to `ssl.wrap_socket`. + + ``sslrootcrtfile`` + Root certificate file path given to `ssl.wrap_socket` + + ``sslrootcrlfile`` + Revocation list file path. [Currently not checked.] + + +Connections +=========== + +`postgresql.open` and `postgresql.driver.connect` provide the means to +establish a connection. Connections provide a `postgresql.api.Database` +interface to a PostgreSQL server; specifically, a `postgresql.api.Connection`. + +Connections are one-time objects. Once, it is closed or lost, it can longer be +used to interact with the database provided by the server. If further use of the +server is desired, a new connection *must* be established. + +.. note:: + Cannot connect failures, exceptions raised on ``connect()``, are also terminal. + +In cases where operations are performed on a closed connection, a +`postgresql.exceptions.ConnectionDoesNotExistError` will be raised. + + +Database Interface Points +------------------------- + +After a connection is established:: + + >>> import postgresql + >>> db = postgresql.open(...) 
+ +The methods and properties on the connection object are ready for use: + + ``Connection.prepare(sql_statement_string)`` + Create a `postgresql.api.Statement` object for querying the database. + This provides an "SQL statement template" that can be executed multiple times. + See `Prepared Statements`_ for more information. + + ``Connection.proc(procedure_id)`` + Create a `postgresql.api.StoredProcedure` object referring to a stored + procedure on the database. The returned object will provide a + `collections.abc.Callable` interface to the stored procedure on the server. See + `Stored Procedures`_ for more information. + + ``Connection.statement_from_id(statement_id)`` + Create a `postgresql.api.Statement` object from an existing statement + identifier. This is used in cases where the statement was prepared on the + server. See `Prepared Statements`_ for more information. + + ``Connection.cursor_from_id(cursor_id)`` + Create a `postgresql.api.Cursor` object from an existing cursor identifier. + This is used in cases where the cursor was declared on the server. See + `Cursors`_ for more information. + + ``Connection.do(language, source)`` + Execute a DO statement on the server using the specified language. + *DO statements are available on PostgreSQL 9.0 and greater.* + *Executing this method on servers that do not support DO statements will* + *likely cause a SyntaxError*. + + ``Connection.execute(sql_statements_string)`` + Run a block of SQL on the server. This method returns `None` unless an error + occurs. If errors occur, the processing of the statements will stop and the + error will be raised. + + ``Connection.xact(isolation = None, mode = None)`` + The `postgresql.api.Transaction` constructor for creating transactions. + This method creates a transaction reference. The transaction will not be + started until it's instructed to do so. See `Transactions`_ for more + information. + + ``Connection.settings`` + A property providing a `collections.abc.MutableMapping` interface to the + database's SQL settings. See `Settings`_ for more information. + + ``Connection.clone()`` + Create a new connection object based on the same factors that were used to + create ``db``. The new connection returned will already be connected. + + ``Connection.msghook(msg)`` + By default, the `msghook` attribute does not exist. If set to a callable, any + message that occurs during an operation of the database or an operation of a + database derived object will be given to the callable. See the + `Database Messages`_ section for more information. + + ``Connection.listen(*channels)`` + Start listening for asynchronous notifications in the specified channels. + Sends a batch of ``LISTEN`` statements to the server. + + ``Connection.unlisten(*channels)`` + Stop listening for asynchronous notifications in the specified channels. + Sends a batch of ``UNLISTEN`` statements to the server. + + ``Connection.listening_channels()`` + Return an iterator producing the channel names that are currently being + listened to. + + ``Connection.notify(*channels, **channel_and_payload)`` + NOTIFY the channels with the given payload. Sends a batch of ``NOTIFY`` + statements to the server. + + Equivalent to issuing "NOTIFY " or "NOTIFY , " + for each item in `channels` and `channel_and_payload`. All NOTIFYs issued + will occur in the same transaction, regardless of auto-commit. + + The items in `channels` can either be a string or a tuple. 
If a string, + no payload is given, but if an item is a `builtins.tuple`, the second item + in the pair will be given as the payload, and the first as the channel. + `channels` offers a means to issue NOTIFYs in guaranteed order:: + + >>> db.notify('channel1', ('different_channel', 'payload')) + + In the above, ``NOTIFY "channel1";`` will be issued first, followed by + ``NOTIFY "different_channel", 'payload';``. + + The items in `channel_and_payload` are all payloaded NOTIFYs where the + keys are the channels and the values are the payloads. Order is undefined:: + + >>> db.notify(channel_name = 'payload_data') + + `channels` and `channels_and_payload` can be used together. In such cases all + NOTIFY statements generated from `channels_and_payload` will follow those in + `channels`. + + ``Connection.iternotifies(timeout = None)`` + Return an iterator to the NOTIFYs received on the connection. The iterator + will yield notification triples consisting of ``(channel, payload, pid)``. + While iterating, the connection should *not* be used in other threads. + The optional timeout can be used to enable "idle" events in which `None` + objects will be yielded by the iterator. + See :ref:`notifyman` for details. + +When a connection is established, certain pieces of information are collected from +the backend. The following are the attributes set on the connection object after +the connection is made: + + ``Connection.version`` + The version string of the *server*; the result of ``SELECT version()``. + + ``Connection.version_info`` + A ``sys.version_info`` form of the ``server_version`` setting. eg. + ``(8, 1, 2, 'final', 0)``. + + ``Connection.security`` + `None` if no security. ``'ssl'`` if SSL is enabled. + + ``Connection.backend_id`` + The process-id of the backend process. + + ``Connection.backend_start`` + When backend was started. ``datetime.datetime`` instance. + + ``Connection.client_address`` + The address of the client that the backend is communicating with. + + ``Connection.client_port`` + The port of the client that the backend is communicating with. + + ``Connection.fileno()`` + Method to get the file descriptor number of the connection's socket. This + method will return `None` if the socket object does not have a ``fileno``. + Under normal circumstances, it will return an `int`. + +The ``backend_start``, ``client_address``, and ``client_port`` are collected +from pg_stat_activity. If this information is unavailable, the attributes will +be `None`. + + +Prepared Statements +=================== + +Prepared statements are the primary entry point for initiating an operation on +the database. Prepared statement objects represent a request that will, likely, +be sent to the database at some point in the future. A statement is a single +SQL command. + +The ``prepare`` entry point on the connection provides the standard method for +creating a `postgersql.api.Statement` instance bound to the +connection(``db``) from an SQL statement string:: + + >>> ps = db.prepare("SELECT 1") + >>> ps() + [(1,)] + +Statement objects may also be created from a statement identifier using the +``statement_from_id`` method on the connection. When this method is used, the +statement must have already been prepared or an error will be raised. + + >>> db.execute("PREPARE a_statement_id AS SELECT 1;") + >>> ps = db.statement_from_id('a_statement_id') + >>> ps() + [(1,)] + +When a statement is executed, it binds any given parameters to a *new* cursor +and the entire result-set is returned. 
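+
+As a brief illustration of that behaviour (the arithmetic is arbitrary), the
+same statement object can be executed repeatedly, each invocation binding its
+parameters to a new cursor and returning a fully materialized result set::
+
+ >>> ps = db.prepare("SELECT $1::int + $2::int")
+ >>> ps(1, 2)
+ [(3,)]
+ >>> ps(10, 20)
+ [(30,)]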
+ +Statements created using ``prepare()`` will leverage garbage collection in order +to automatically close statements that are no longer referenced. However, +statements created from pre-existing identifiers, ``statement_from_id``, must +be explicitly closed if the statement is to be discarded. + +Statement objects are one-time objects. Once closed, they can no longer be used. + + +Statement Interface Points +-------------------------- + +Prepared statements can be executed just like functions: + + >>> ps = db.prepare("SELECT 'hello, world!'") + >>> ps() + [('hello, world!',)] + +The default execution method, ``__call__``, produces the entire result set. It +is the simplest form of statement execution. Statement objects can be executed in +different ways to accommodate for the larger results or random access(scrollable +cursors). + +Prepared statement objects have a few execution methods: + + ``Statement(*parameters)`` + As shown before, statement objects can be invoked like a function to get + the statement's results. + + ``Statement.rows(*parameters)`` + Return a iterator to all the rows produced by the statement. This + method will stream rows on demand, so it is ideal for situations where + each individual row in a large result-set must be processed. + + ``iter(Statement)`` + Convenience interface that executes the ``rows()`` method without arguments. + This enables the following syntax: + + >>> for table_name, in db.prepare("SELECT table_name FROM information_schema.tables"): + ... print(table_name) + + ``Statement.column(*parameters)`` + Return a iterator to the first column produced by the statement. This + method will stream values on demand, and *should* only be used with statements + that have a single column; otherwise, bandwidth will ultimately be wasted as + the other columns will be dropped. + *This execution method cannot be used with COPY statements.* + + ``Statement.first(*parameters)`` + For simple statements, cursor objects are unnecessary. + Consider the data contained in ``c`` from above, 'hello world!'. To get at this + data directly from the ``__call__(...)`` method, it looks something like:: + + >>> ps = db.prepare("SELECT 'hello, world!'") + >>> ps()[0][0] + 'hello, world!' + + To simplify access to simple data, the ``first`` method will simply return + the "first" of the result set:: + + >>> ps.first() + 'hello, world!' + + The first value. + When the result set consists of a single column, ``first()`` will return + that column in the first row. + + The first row. + When the result set consists of multiple columns, ``first()`` will return + that first row. + + The first, and only, row count. + When DML--for instance, an INSERT-statement--is executed, ``first()`` will + return the row count returned by the statement as an integer. + + .. note:: + DML that returns row data, RETURNING, will *not* return a row count. + + The result set created by the statement determines what is actually returned. + Naturally, a statement used with ``first()`` should be crafted with these + rules in mind. + + ``Statement.chunks(*parameters)`` + This access point is designed for situations where rows are being streamed out + quickly. It is a method that returns a ``collections.abc.Iterator`` that produces + *sequences* of rows. This is the most efficient way to get rows from the + database. The rows in the sequences are ``builtins.tuple`` objects. + + ``Statement.declare(*parameters)`` + Create a scrollable cursor with hold. 
This returns a `postgresql.api.Cursor` + ready for accessing random rows in the result-set. Applications that use the + database to support paging can use this method to manage the view. + + ``Statement.close()`` + Close the statement inhibiting further use. + + ``Statement.load_rows(collections.abc.Iterable(parameters))`` + Given an iterable producing parameters, execute the statement for each + iteration. Always returns `None`. + + ``Statement.load_chunks(collections.abc.Iterable(collections.abc.Iterable(parameters)))`` + Given an iterable of iterables producing parameters, execute the statement + for each parameter produced. However, send the all execution commands with + the corresponding parameters of each chunk before reading any results. + Always returns `None`. This access point is designed to be used in conjunction + with ``Statement.chunks()`` for transferring rows from one connection to another with + great efficiency:: + + >>> dst.prepare(...).load_chunks(src.prepare(...).chunks()) + + ``Statement.clone()`` + Create a new statement object based on the same factors that were used to + create ``ps``. + + ``Statement.msghook(msg)`` + By default, the `msghook` attribute does not exist. If set to a callable, any + message that occurs during an operation of the statement or an operation of a + statement derived object will be given to the callable. See the + `Database Messages`_ section for more information. + +In order to provide the appropriate type transformations, the driver must +acquire metadata about the statement's parameters and results. This data is +published via the following properties on the statement object: + + ``Statement.sql_parameter_types`` + A sequence of SQL type names specifying the types of the parameters used in + the statement. + + ``Statement.sql_column_types`` + A sequence of SQL type names specifying the types of the columns produced by + the statement. `None` if the statement does not return row-data. + + ``Statement.pg_parameter_types`` + A sequence of PostgreSQL type Oid's specifying the types of the parameters + used in the statement. + + ``Statement.pg_column_types`` + A sequence of PostgreSQL type Oid's specifying the types of the columns produced by + the statement. `None` if the statement does not return row-data. + + ``Statement.parameter_types`` + A sequence of Python types that the statement expects. + + ``Statement.column_types`` + A sequence of Python types that the statement will produce. + + ``Statement.column_names`` + A sequence of `str` objects specifying the names of the columns produced by + the statement. `None` if the statement does not return row-data. + +The indexes of the parameter sequences correspond to the parameter's +identifier, N+1: ``sql_parameter_types[0]`` -> ``'$1'``. + + >>> ps = db.prepare("SELECT $1::integer AS intname, $2::varchar AS chardata") + >>> ps.sql_parameter_types + ('INTEGER','VARCHAR') + >>> ps.sql_column_types + ('INTEGER','VARCHAR') + >>> ps.column_names + ('intname','chardata') + >>> ps.column_types + (, ) + + +Parameterized Statements +------------------------ + +Statements can take parameters. Using statement parameters is the recommended +way to interrogate the database when variable information is needed to formulate +a complete request. In order to do this, the statement must be defined using +PostgreSQL's positional parameter notation. ``$1``, ``$2``, ``$3``, etc:: + + >>> ps = db.prepare("SELECT $1") + >>> ps('hello, world!')[0][0] + 'hello, world!' 
+ +PostgreSQL determines the type of the parameter based on the context of the +parameter's identifier:: + + >>> ps = db.prepare( + ... "SELECT * FROM information_schema.tables WHERE table_name = $1 LIMIT $2" + ... ) + >>> ps("tables", 1) + [('postgres', 'information_schema', 'tables', 'VIEW', None, None, None, None, None, 'NO', 'NO', None)] + +Parameter ``$1`` in the above statement will take on the type of the +``table_name`` column and ``$2`` will take on the type required by the LIMIT +clause(text and int8). + +However, parameters can be forced to a specific type using explicit casts: + + >>> ps = db.prepare("SELECT $1::integer") + >>> ps.first(-400) + -400 + +Parameters are typed. PostgreSQL servers provide the driver with the +type information about a positional parameter, and the serialization routine +will raise an exception if the given object is inappropriate. The Python +types expected by the driver for a given SQL-or-PostgreSQL type are listed +in `Type Support`_. + +This usage of types is not always convenient. Notably, the `datetime` module +does not provide a friendly way for a user to express intervals, dates, or +times. There is a likely inclination to forego these parameter type +requirements. + +In such cases, explicit casts can be made to work-around the type +requirements:: + + >>> ps = db.prepare("SELECT $1::text::date") + >>> ps.first('yesterday') + datetime.date(2009, 3, 11) + +The parameter, ``$1``, is given to the database as a string, which is then +promptly cast into a date. Of course, without the explicit cast as text, the +outcome would be different:: + + >>> ps = db.prepare("SELECT $1::date") + >>> ps.first('yesterday') + Traceback: + ... + postgresql.exceptions.ParameterError + +The function that processes the parameter expects a `datetime.date` object, and +the given `str` object does not provide the necessary interfaces for the +conversion, so the driver raises a `postgresql.exceptions.ParameterError` from +the original conversion exception. + + +Inserting and DML +----------------- + +Loading data into the database is facilitated by prepared statements. In these +examples, a table definition is necessary for a complete illustration:: + + >>> db.execute( + ... """ + ... CREATE TABLE employee ( + ... employee_name text, + ... employee_salary numeric, + ... employee_dob date, + ... employee_hire_date date + ... ); + ... """ + ... ) + +Create an INSERT statement using ``prepare``:: + + >>> mkemp = db.prepare("INSERT INTO employee VALUES ($1, $2, $3, $4)") + +And add "Mr. Johnson" to the table:: + + >>> import datetime + >>> r = mkemp( + ... "John Johnson", + ... "92000", + ... datetime.date(1950, 12, 10), + ... datetime.date(1998, 4, 23) + ... ) + >>> print(r[0]) + INSERT + >>> print(r[1]) + 1 + +The execution of DML will return a tuple. This tuple contains the completed +command name and the associated row count. + +Using the call interface is fine for making a single insert, but when multiple +records need to be inserted, it's not the most efficient means to load data. For +multiple records, the ``ps.load_rows([...])`` provides an efficient way to load +large quantities of structured data:: + + >>> from datetime import date + >>> mkemp.load_rows([ + ... ("Jack Johnson", "85000", date(1962, 11, 23), date(1990, 3, 5)), + ... ("Debra McGuffer", "52000", date(1973, 3, 4), date(2002, 1, 14)), + ... ("Barbara Smith", "86000", date(1965, 2, 24), date(2005, 7, 19)), + ... 
]) + +While small, the above illustrates the ``ps.load_rows()`` method taking an +iterable of tuples that provides parameters for the each execution of the +statement. + +``load_rows`` is also used to support ``COPY ... FROM STDIN`` statements:: + + >>> copy_emps_in = db.prepare("COPY employee FROM STDIN") + >>> copy_emps_in.load_rows([ + ... b'Emp Name1\t72000\t1970-2-01\t1980-10-22\n', + ... b'Emp Name2\t62000\t1968-9-11\t1985-11-1\n', + ... b'Emp Name3\t62000\t1968-9-11\t1985-11-1\n', + ... ]) + +Copy data goes in as bytes and come out as bytes regardless of the type of COPY +taking place. It is the user's obligation to make sure the row-data is in the +appropriate encoding. + + +COPY Statements +--------------- + +`postgresql.driver` transparently supports PostgreSQL's COPY command. To the +user, COPY will act exactly like other statements that produce tuples; COPY +tuples, however, are `bytes` objects. The only distinction in usability is that +the COPY *should* be completed before other actions take place on the +connection--this is important when a COPY is invoked via ``rows()`` or +``chunks()``. + +In situations where other actions are invoked during a ``COPY TO STDOUT``, the +entire result set of the COPY will be read. However, no error will be raised so +long as there is enough memory available, so it is *very* desirable to avoid +doing other actions on the connection while a COPY is active. + +In situations where other actions are invoked during a ``COPY FROM STDIN``, a +COPY failure error will occur. The driver manages the connection state in such +a way that will purposefully cause the error as the COPY was inappropriately +interrupted. This not usually a problem as ``load_rows(...)`` and +``load_chunks(...)`` methods must complete the COPY command before returning. + +Copy data is always transferred using ``bytes`` objects. Even in cases where the +COPY is not in ``BINARY`` mode. Any needed encoding transformations *must* be +made the caller. This is done to avoid any unnecessary overhead by default:: + + >>> ps = db.prepare("COPY (SELECT i FROM generate_series(0, 99) AS g(i)) TO STDOUT") + >>> r = ps() + >>> len(r) + 100 + >>> r[0] + b'0\n' + >>> r[-1] + b'99\n' + +Of course, invoking a statement that way will read the entire result-set into +memory, which is not usually desirable for COPY. Using the ``chunks(...)`` +iterator is the *fastest* way to move data:: + + >>> ci = ps.chunks() + >>> import sys + >>> for rowset in ps.chunks(): + ... sys.stdout.buffer.writelines(rowset) + ... + + +``COPY FROM STDIN`` commands are supported via +`postgresql.api.Statement.load_rows`. Each invocation to +``load_rows`` is a single invocation of COPY. ``load_rows`` takes an iterable of +COPY lines to send to the server:: + + >>> db.execute(""" + ... CREATE TABLE sample_copy ( + ... sc_number int, + ... sc_text text + ... ); + ... """) + >>> copyin = db.prepare('COPY sample_copy FROM STDIN') + >>> copyin.load_rows([ + ... b'123\tone twenty three\n', + ... b'350\ttree fitty\n', + ... ]) + +For direct connection-to-connection COPY, use of ``load_chunks(...)`` is +recommended as it will provide the most efficient transfer method:: + + >>> copyout = src.prepare('COPY atable TO STDOUT') + >>> copyin = dst.prepare('COPY atable FROM STDIN') + >>> copyin.load_chunks(copyout.chunks()) + +Specifically, each chunk of row data produced by ``chunks()`` will be written in +full by ``load_chunks()`` before getting another chunk to write. 
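+
+Since iterating over a binary file yields ``bytes`` lines, a file holding
+text-format COPY data can be fed to ``load_rows(...)`` directly. The file name
+below is hypothetical; ``sample_copy`` is the table created above::
+
+ >>> copyin = db.prepare('COPY sample_copy FROM STDIN')
+ >>> with open('sample_copy.tsv', 'rb') as copy_data:  # hypothetical file
+ ...     copyin.load_rows(copy_data)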
+ + +Cursors +======= + +When a prepared statement is declared, ``ps.declare(...)``, a +`postgresql.api.Cursor` is created and returned for random access to the rows in +the result set. Direct use of cursors is primarily useful for applications that +need to implement paging. For situations that need to iterate over the result +set, the ``ps.rows(...)`` or ``ps.chunks(...)`` execution methods should be +used. + +Cursors can also be created directly from ``cursor_id``'s using the +``cursor_from_id`` method on connection objects:: + + >>> db.execute('DECLARE the_cursor_id CURSOR WITH HOLD FOR SELECT 1;') + >>> c = db.cursor_from_id('the_cursor_id') + >>> c.read() + [(1,)] + >>> c.close() + +.. hint:: + If the cursor that needs to be opened is going to be treated as an iterator, + then a FETCH-statement should be prepared instead using ``cursor_from_id``. + +Like statements created from an identifier, cursors created from an identifier +must be explicitly closed in order to destroy the object on the server. +Likewise, cursors created from statement invocations will be automatically +released when they are no longer referenced. + +.. note:: + PG-API cursors are a direct interface to single result-set SQL cursors. This + is in contrast with DB-API cursors, which have interfaces for dealing with + multiple result-sets. There is no execute method on PG-API cursors. + + +Cursor Interface Points +----------------------- + +For cursors that return row data, these interfaces are provided for accessing +those results: + + ``Cursor.read(quantity = None, direction = None)`` + This method name is borrowed from `file` objects, and are semantically + similar. However, this being a cursor, rows are returned instead of bytes or + characters. When the number of rows returned is less then the quantity + requested, it means that the cursor has been exhausted in the configured + direction. The ``direction`` argument can be either ``'FORWARD'`` or `True` + to FETCH FORWARD, or ``'BACKWARD'`` or `False` to FETCH BACKWARD. + + Like, ``seek()``, the ``direction`` *property* on the cursor object effects + this method. + + ``Cursor.seek(position[, whence = 0])`` + When the cursor is scrollable, this seek interface can be used to move the + position of the cursor. See `Scrollable Cursors`_ for more information. + + ``next(Cursor)`` + This fetches the next row in the cursor object. Cursors support the iterator + protocol. While equivalent to ``cursor.read(1)[0]``, `StopIteration` is raised + if the returned sequence is empty. (``__next__()``) + + ``Cursor.close()`` + For cursors opened using ``cursor_from_id()``, this method must be called in + order to ``CLOSE`` the cursor. For cursors created by invoking a prepared + statement, this is not necessary as the garbage collection interface will take + the appropriate steps. + + ``Cursor.clone()`` + Create a new cursor object based on the same factors that were used to + create ``c``. + + ``Cursor.msghook(msg)`` + By default, the `msghook` attribute does not exist. If set to a callable, any + message that occurs during an operation of the cursor will be given to the + callable. See the `Database Messages`_ section for more information. + + +Cursors have some additional configuration properties that may be modified +during the use of the cursor: + + ``Cursor.direction`` + A value of `True`, the default, will cause read to fetch forwards, whereas a + value of `False` will cause it to fetch backwards. 
``'BACKWARD'`` and + ``'FORWARD'`` can be used instead of `False` and `True`. + +Cursors normally share metadata with the statements that create them, so it is +usually unnecessary for referencing the cursor's column descriptions directly. +However, when a cursor is opened from an identifier, the cursor interface must +collect the metadata itself. These attributes provide the metadata in absence of +a statement object: + + ``Cursor.sql_column_types`` + A sequence of SQL type names specifying the types of the columns produced by + the cursor. `None` if the cursor does not return row-data. + + ``Cursor.pg_column_types`` + A sequence of PostgreSQL type Oid's specifying the types of the columns produced by + the cursor. `None` if the cursor does not return row-data. + + ``Cursor.column_types`` + A sequence of Python types that the cursor will produce. + + ``Cursor.column_names`` + A sequence of `str` objects specifying the names of the columns produced by + the cursor. `None` if the cursor does not return row-data. + + ``Cursor.statement`` + The statement that was executed that created the cursor. `None` if + unknown--``db.cursor_from_id()``. + + +Scrollable Cursors +------------------ + +Scrollable cursors are supported for applications that need to implement paging. +When statements are invoked via the ``declare(...)`` method, the returned cursor +is scrollable. + +.. note:: + Scrollable cursors never pre-fetch in order to provide guaranteed positioning. + +The cursor interface supports scrolling using the ``seek`` method. Like +``read``, it is semantically similar to a file object's ``seek()``. + +``seek`` takes two arguments: ``position`` and ``whence``: + + ``position`` + The position to scroll to. The meaning of this is determined by ``whence``. + + ``whence`` + How to use the position: absolute, relative, or absolute from end: + + absolute: ``'ABSOLUTE'`` or ``0`` (default) + seek to the absolute position in the cursor relative to the beginning of the + cursor. + + relative: ``'RELATIVE'`` or ``1`` + seek to the relative position. Negative ``position``'s will cause a MOVE + backwards, while positive ``position``'s will MOVE forwards. + + from end: ``'FROM_END'`` or ``2`` + seek to the end of the cursor and then MOVE backwards by the given + ``position``. + +The ``whence`` keyword argument allows for either numeric and textual +specifications. + +Scrolling through employees:: + + >>> emps_by_age = db.prepare(""" + ... SELECT + ... employee_name, employee_salary, employee_dob, employee_hire_date, + ... EXTRACT(years FROM AGE(employee_dob)) AS age + ... ORDER BY age ASC + ... """) + >>> c = emps_by_age.declare() + >>> # seek to the end, ``2`` works as well. + >>> c.seek(0, 'FROM_END') + >>> # scroll back one, ``1`` works as well. + >>> c.seek(-1, 'RELATIVE') + >>> # and back to the beginning again + >>> c.seek(0) + +Additionally, scrollable cursors support backward fetches by specifying the +direction keyword argument:: + + >>> c.seek(0, 2) + >>> c.read(1, 'BACKWARD') + + +Cursor Direction +---------------- + +The ``direction`` property on the cursor states the default direction for read +and seek operations. Normally, the direction is `True`, ``'FORWARD'``. When the +property is set to ``'BACKWARD'`` or `False`, the read method will fetch +backward by default, and seek operations will be inverted to simulate a +reversely ordered cursor. 
The following example illustrates the effect:: + + >>> reverse_c = db.prepare('SELECT i FROM generate_series(99, 0, -1) AS g(i)').declare() + >>> c = db.prepare('SELECT i FROM generate_series(0, 99) AS g(i)').declare() + >>> reverse_c.direction = 'BACKWARD' + >>> reverse_c.seek(0) + >>> c.read() == reverse_c.read() + +Furthermore, when the cursor is configured to read backwards, specifying +``'BACKWARD'`` for read's ``direction`` argument will ultimately cause a forward +fetch. This potentially confusing facet of direction configuration is +implemented in order to create an appropriate symmetry in functionality. +The cursors in the above example contain the same rows, but are ultimately in +reverse order. The backward direction property is designed so that the effect +of any read or seek operation on those cursors is the same:: + + >>> reverse_c.seek(50) + >>> c.seek(50) + >>> c.read(10) == reverse_c.read(10) + >>> c.read(10, 'BACKWARD') == reverse_c.read(10, 'BACKWARD') + +And for relative seeks:: + + >>> c.seek(-10, 1) + >>> reverse_c.seek(-10, 1) + >>> c.read(10, 'BACKWARD') == reverse_c.read(10, 'BACKWARD') + + +Rows +==== + +Rows received from PostgreSQL are instantiated into `postgresql.types.Row` +objects. Rows are both a sequence and a mapping. Items accessed with an `int` +are seen as indexes and other objects are seen as keys:: + + >>> row = db.prepare("SELECT 't'::text AS col0, 2::int4 AS col1").first() + >>> row + ('t', 2) + >>> row[0] + 't' + >>> row["col0"] + 't' + +However, this extra functionality is not free. The cost of instantiating +`postgresql.types.Row` objects is quite measurable, so the `chunks()` execution +method will produce `builtins.tuple` objects for cases where performance is +critical. + +.. note:: + Attributes aren't used to provide access to values due to potential conflicts + with existing method and property names. + + +Row Interface Points +-------------------- + +Rows implement the `collections.abc.Mapping` and `collections.abc.Sequence` interfaces. + + ``Row.keys()`` + An iterable producing the column names. Order is not guaranteed. See the + ``column_names`` property to get an ordered sequence. + + ``Row.values()`` + Iterable to the values in the row. + + ``Row.get(key_or_index[, default=None])`` + Get the item in the row. If the key doesn't exist or the index is out of + range, return the default. + + ``Row.items()`` + Iterable of key-value pairs. Ordered by index. + + ``iter(Row)`` + Iterable to the values in index order. + + ``value in Row`` + Whether or not the value exists in the row. (__contains__) + + ``Row[key_or_index]`` + If ``key_or_index`` is an integer, return the value at that index. If the + index is out of range, raise an `IndexError`. Otherwise, return the value + associated with column name. If the given key, ``key_or_index``, does not + exist, raise a `KeyError`. + + ``Row.index_from_key(key)`` + Return the index associated with the given key. + + ``Row.key_from_index(index)`` + Return the key associated with the given index. + + ``Row.transform(*args, **kw)`` + Create a new row object of the same length, with the same keys, but with new + values produced by applying the given callables to the corresponding items. + Callables given as ``args`` will be associated with values by their index and + callables given as keywords will be associated with values by their key, + column name. 
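+
+A short sketch of those interfaces in use; the column names and values are
+arbitrary::
+
+ >>> row = db.prepare("SELECT 1::int4 AS i, 'x'::text AS t").first()
+ >>> sorted(row.keys())
+ ['i', 't']
+ >>> list(row.items())
+ [('i', 1), ('t', 'x')]
+ >>> row.get('missing') is None
+ True
+ >>> row.index_from_key('t')
+ 1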
+ +While the mapping interfaces will provide most of the needed information, some +additional properties are provided for consistency with statement and cursor +objects. + + ``Row.column_names`` + Property providing an ordered sequence of column names. The index corresponds + to the row value-index that the name refers to. + + >>> row[row.column_names[i]] == row[i] + + +Row Transformations +------------------- + +After a row is returned, sometimes the data in the row is not in the desired +format. Further processing is needed if the row object is to going to be +given to another piece of code which requires an object of differring +consistency. + +The ``transform`` method on row objects provides a means to create a new row +object consisting of the old row's items, but with certain columns transformed +using the given callables:: + + >>> row = db.prepare(""" + ... SELECT + ... 'XX9301423'::text AS product_code, + ... 2::int4 AS quantity, + ... '4.92'::numeric AS total + ... """).first() + >>> row + ('XX9301423', 2, Decimal("4.92")) + >>> row.transform(quantity = str) + ('XX9301423', '2', Decimal("4.92")) + +``transform`` supports both positional and keyword arguments in order to +assign the callable for a column's transformation:: + + >>> from operator import methodcaller + >>> row.transform(methodcaller('strip', 'XX')) + ('9301423', 2, Decimal("4.92")) + +Of course, more than one column can be transformed:: + + >>> stripxx = methodcaller('strip', 'XX') + >>> row.transform(stripxx, str, str) + ('9301423', '2', '4.92') + +`None` can also be used to indicate no transformation:: + + >>> row.transform(None, str, str) + ('XX9301423', '2', '4.92') + +More advanced usage can make use of lambdas for compound transformations in a +single pass of the row:: + + >>> strip_and_int = lambda x: int(stripxx(x)) + >>> row.transform(strip_and_int) + (9301423, 2, Decimal("4.92")) + +Transformations will be, more often than not, applied against *rows* as +opposed to *a* row. Using `operator.methodcaller` with `map` provides the +necessary functionality to create simple iterables producing transformed row +sequences:: + + >>> import decimal + >>> apply_tax = lambda x: (x * decimal.Decimal("0.1")) + x + >>> transform_row = methodcaller('transform', strip_and_int, None, apply_tax) + >>> r = map(transform_row, [row]) + >>> list(r) + [(9301423, 2, Decimal('5.412'))] + +And finally, `functools.partial` can be used to create a simple callable:: + + >>> from functools import partial + >>> transform_rows = partial(map, transform_row) + >>> list(transform_rows([row])) + [(9301423, 2, Decimal('5.412'))] + + +Queries +======= + +Queries in `py-postgresql` are single use prepared statements. They exist primarily for +syntactic convenience, but they also allow the driver to recognize the short lifetime of +the statement. + +Single use statements are supported using the ``query`` property on connection +objects, :py:class:`postgresql.api.Connection.query`. The statement object is not +available when using queries as the results, or handle to the results, are directly returned. 
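+
+A quick sketch, assuming an established connection ``db`` and arbitrary SQL;
+the available execution methods are enumerated below::
+
+ >>> db.query.first("SELECT 1")
+ 1
+ >>> db.query("SELECT table_name FROM information_schema.tables WHERE table_name = $1", "tables")
+ [('tables',)]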
+ +Queries have access to all execution methods: + + * ``Connection.query(sql, *parameters)`` + * ``Connection.query.rows(sql, *parameters)`` + * ``Connection.query.column(sql, *parameters)`` + * ``Connection.query.first(sql, *parameters)`` + * ``Connection.query.chunks(sql, *parameters)`` + * ``Connection.query.declare(sql, *parameters)`` + * ``Connection.query.load_rows(sql, collections.abc.Iterable(parameters))`` + * ``Connection.query.load_chunks(collections.abc.Iterable(collections.abc.Iterable(parameters)))`` + +In cases where a sequence of one-shot queries needs to be performed, it may be important to +avoid unnecessary repeat attribute resolution from the connection object as the ``query`` +property is an interface object created on access. Caching the target execution methods is +recommended:: + + qrows = db.query.rows + l = [] + for x in my_queries: + l.append(qrows(x)) + +The characteristic of Each execution method is discussed in the prior +`Prepared Statements`_ section. + +Stored Procedures +================= + +The ``proc`` method on `postgresql.api.Database` objects provides a means to +create a reference to a stored procedure on the remote database. +`postgresql.api.StoredProcedure` objects are used to represent the referenced +SQL routine. + +This provides a direct interface to functions stored on the database. It +leverages knowledge of the parameters and results of the function in order +to provide the user with a natural interface to the procedure:: + + >>> func = db.proc('version()') + >>> func() + 'PostgreSQL 8.3.6 on ...' + + +Stored Procedure Interface Points +--------------------------------- + +It's more-or-less a function, so there's only one interface point: + + ``func(*args, **kw)`` (``__call__``) + Stored procedure objects are callable, executing a procedure will return an + object of suitable representation for a given procedure's type signature. + + If it returns a single object, it will return the single object produced by + the procedure. + + If it's a set returning function, it will return an *iterable* to the values + produced by the procedure. + + In cases of set returning function with multiple OUT-parameters, a cursor + will be returned. + + +Stored Procedure Type Support +----------------------------- + +Stored procedures support most types of functions. "Function Types" being set +returning functions, multiple-OUT parameters, and simple single-object returns. + +Set-returning functions, SRFs return a sequence:: + + >>> generate_series = db.proc('generate_series(int,int)') + >>> gs = generate_series(1, 20) + >>> gs + > + >>> next(gs) + 1 + >>> list(gs) + [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] + +For functions like ``generate_series()``, the driver is able to identify that +the return is a sequence of *solitary* integer objects, so the result of the +function is just that, a sequence of integers. + +Functions returning composite types are recognized, and return row objects:: + + >>> db.execute(""" + ... CREATE FUNCTION composite(OUT i int, OUT t text) + ... LANGUAGE SQL AS + ... $body$ + ... SELECT 900::int AS i, 'sample text'::text AS t; + ... $body$; + ... """) + >>> composite = db.proc('composite()') + >>> r = composite() + >>> r + (900, 'sample text') + >>> r['i'] + 900 + >>> r['t'] + 'sample text' + +Functions returning a set of composites are recognized, and the result is a +`postgresql.api.Cursor` object whose column names are consistent with the names +of the OUT parameters:: + + >>> db.execute(""" + ... 
CREATE FUNCTION srfcomposite(out i int, out t text) + ... RETURNS SETOF RECORD + ... LANGUAGE SQL AS + ... $body$ + ... SELECT 900::int AS i, 'sample text'::text AS t + ... UNION ALL + ... SELECT 450::int AS i, 'more sample text'::text AS t + ... $body$; + ... """) + >>> srfcomposite = db.proc('srfcomposite()') + >>> r = srfcomposite() + >>> next(r) + (900, 'sample text') + >>> v = next(r) + >>> v['i'], v['t'] + (450, 'more sample text') + + +Transactions +============ + +Transactions are managed by creating an object corresponding to a +transaction started on the server. A transaction is a transaction block, +a savepoint, or a prepared transaction. The ``xact(...)`` method on the +connection object provides the standard method for creating a +`postgresql.api.Transaction` object to manage a transaction on the connection. + +The creation of a transaction object does not start the transaction. Rather, the +transaction must be explicitly started using the ``start()`` method on the +transaction object. Usually, transactions *should* be managed with the context +manager interfaces:: + + >>> with db.xact(): + ... ... + +The transaction in the above example is opened, started, by the ``__enter__`` +method invoked by the with-statement's usage. It will be subsequently +committed or rolled-back depending on the exception state and the error state +of the connection when ``__exit__`` is called. + +**Using the with-statement syntax for managing transactions is strongly +recommended.** By using the transaction's context manager, it allows for Python +exceptions to be properly treated as fatal to the transaction as when an +uncaught exception of any kind occurs within the block, it is unlikely that +the state of the transaction can be trusted. Additionally, the ``__exit__`` +method provides a safe-guard against invalid commits. This can occur if a +database error is inappropriately caught within a block without being raised. + +The context manager interfaces are higher level interfaces to the explicit +instruction methods provided by `postgresql.api.Transaction` objects. + + +Transaction Configuration +------------------------- + +Keyword arguments given to ``xact()`` provide the means for configuring the +properties of the transaction. Only three points of configuration are available: + + ``isolation`` + The isolation level of the transaction. This must be a string. It will be + interpolated directly into the START TRANSACTION statement. Normally, + 'SERIALIZABLE' or 'READ COMMITTED': + + >>> with db.xact('SERIALIZABLE'): + ... ... + + ``mode`` + A string, 'READ ONLY' or 'READ WRITE'. States the mutability of stored + information in the database. Like ``isolation``, this is interpolated + directly into the START TRANSACTION string. + +The specification of any of these transaction properties imply that the transaction +is a block. Savepoints do not take configuration, so if a transaction identified +as a block is started while another block is running, an exception will be +raised. + + +Transaction Interface Points +---------------------------- + +The methods available on transaction objects manage the state of the transaction +and relay any necessary instructions to the remote server in order to reflect +that change of state. + + >>> x = db.xact(...) + + ``x.start()`` + Start the transaction. + + ``x.commit()`` + Commit the transaction. + + ``x.rollback()`` + Abort the transaction. 
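+
+A rough sketch of manual management follows; it approximates what the context
+manager interfaces do on entry and exit, and the statement executed inside the
+transaction is arbitrary::
+
+ >>> x = db.xact()
+ >>> x.start()
+ >>> try:
+ ...     db.execute("CREATE TEMP TABLE xact_example (i int)")  # arbitrary work
+ ... except Exception:
+ ...     x.rollback()
+ ...     raise
+ ... else:
+ ...     x.commit()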
+ +These methods are primarily provided for applications that manage transactions +in a way that cannot be formed around single, sequential blocks of code. +Generally, using these methods require additional work to be performed by the +code that is managing the transaction. +If usage of these direct, instructional methods is necessary, it is important to +note that if the database is in an error state when a *transaction block's* +commit() is executed, an implicit rollback will occur. The transaction object +will simply follow instructions and issue the ``COMMIT`` statement, and it will +succeed without exception. + + +Error Control +------------- + +Handling *database* errors inside transaction CMs is generally discouraged as +any database operation that occurs within a failed transaction is an error +itself. It is important to trap any recoverable database errors *outside* of the +scope of the transaction's context manager: + + >>> try: + ... with db.xact(): + ... ... + ... except postgresql.exceptions.UniqueError: + ... pass + +In cases where the database is in an error state, but the context exits +without an exception, a `postgresql.exceptions.InFailedTransactionError` is +raised by the driver: + + >>> with db.xact(): + ... try: + ... ... + ... except postgresql.exceptions.UniqueError: + ... pass + ... + Traceback (most recent call last): + ... + postgresql.exceptions.InFailedTransactionError: invalid block exit detected + CODE: 25P02 + SEVERITY: ERROR + +Normally, if a ``COMMIT`` is issued on a failed transaction, the command implies a +``ROLLBACK`` without error. This is a very undesirable result for the CM's exit +as it may allow for code to be ran that presumes the transaction was committed. +The driver intervenes here and raises the +`postgresql.exceptions.InFailedTransactionError` to safe-guard against such +cases. This effect is consistent with savepoint releases that occur during an +error state. The distinction between the two cases is made using the ``source`` +property on the raised exception. + + +Settings +======== + +SQL's SHOW and SET provides a means to configure runtime parameters on the +database("GUC"s). In order to save the user some grief, a +`collections.abc.MutableMapping` interface is provided to simplify configuration. + +The ``settings`` attribute on the connection provides the interface extension. + +The standard dictionary interface is supported: + + >>> db.settings['search_path'] = "$user,public" + +And ``update(...)`` is better performing for multiple sets: + + >>> db.settings.update({ + ... 'search_path' : "$user,public", + ... 'default_statistics_target' : "1000" + ... }) + +.. note:: + The ``transaction_isolation`` setting cannot be set using the ``settings`` + mapping. Internally, ``settings`` uses ``set_config``, which cannot adjust + that particular setting. + +Settings Interface Points +------------------------- + +Manipulation and interrogation of the connection's settings is achieved by +using the standard `collections.abc.MutableMapping` interfaces. + + ``Connection.settings[k]`` + Get the value of a single setting. + + ``Connection.settings[k] = v`` + Set the value of a single setting. + + ``Connection.settings.update([(k1,v2), (k2,v2), ..., (kn,vn)])`` + Set multiple settings using a sequence of key-value pairs. + + ``Connection.settings.update({k1 : v1, k2 : v2, ..., kn : vn})`` + Set multiple settings using a dictionary or mapping object. + + ``Connection.settings.getset([k1, k2, ..., kn])`` + Get a set of a settings. 
This is the most efficient way to get multiple + settings as it uses a single request. + + ``Connection.settings.keys()`` + Get all available setting names. + + ``Connection.settings.values()`` + Get all setting values. + + ``Connection.settings.items()`` + Get a sequence of key-value pairs corresponding to all settings on the + database. + +Settings Management +------------------- + +`postgresql.api.Settings` objects can create context managers when called. +This gives the user with the ability to specify sections of code that are to +be ran with certain settings. The settings' context manager takes full +advantage of keyword arguments in order to configure the context manager: + + >>> with db.settings(search_path = 'local,public', timezone = 'mst'): + ... ... + +`postgresql.api.Settings` objects are callable; the return is a context manager +configured with the given keyword arguments representing the settings to use for +the block of code that is about to be executed. + +When the block exits, the settings will be restored to the values that they had +before the block entered. + + +Type Support +============ + +The driver supports a large number of PostgreSQL types at the binary level. +Most types are converted to standard Python types. The remaining types are +usually PostgreSQL specific types that are converted into objects whose class +is defined in `postgresql.types`. + +When a conversion function is not available for a particular type, the driver +will use the string format of the type and instantiate a `str` object +for the data. It will also expect `str` data when parameter of a type without a +conversion function is bound. + + +.. note:: + Generally, these standard types are provided for convenience. If conversions into + these datatypes are not desired, it is recommended that explicit casts into + ``text`` are made in statement string. + + +.. table:: Python types used to represent PostgreSQL types. + + ================================= ================================== =========== + PostgreSQL Types Python Types SQL Types + ================================= ================================== =========== + `postgresql.types.INT2OID` `int` smallint + `postgresql.types.INT4OID` `int` integer + `postgresql.types.INT8OID` `int` bigint + `postgresql.types.FLOAT4OID` `float` float + `postgresql.types.FLOAT8OID` `float` double + `postgresql.types.VARCHAROID` `str` varchar + `postgresql.types.BPCHAROID` `str` char + `postgresql.types.XMLOID` `xml.etree` (cElementTree) xml + + `postgresql.types.DATEOID` `datetime.date` date + `postgresql.types.TIMESTAMPOID` `datetime.datetime` timestamp + `postgresql.types.TIMESTAMPTZOID` `datetime.datetime` (tzinfo) timestamptz + `postgresql.types.TIMEOID` `datetime.time` time + `postgresql.types.TIMETZOID` `datetime.time` timetz + `postgresql.types.INTERVALOID` `datetime.timedelta` interval + + `postgresql.types.NUMERICOID` `decimal.Decimal` numeric + `postgresql.types.BYTEAOID` `bytes` bytea + `postgresql.types.TEXTOID` `str` text + `dict` hstore + ================================= ================================== =========== + +The mapping in the above table *normally* goes both ways. So when a parameter +is passed to a statement, the type *should* be consistent with the corresponding +Python type. However, many times, for convenience, the object will be passed +through the type's constructor, so it is not always necessary. + + +Arrays +------ + +Arrays of PostgreSQL types are supported with near transparency. 
For simple +arrays, arbitrary iterables can just be given as a statement's parameter and the +array's constructor will consume the objects produced by the iterator into a +`postgresql.types.Array` instance. However, in situations where the array has +multiple dimensions, `list` objects are used to delimit the boundaries of the +array. + + >>> ps = db.prepare("select $1::int[]") + >>> ps.first([(1,2), (2,3)]) + Traceback: + ... + postgresql.exceptions.ParameterError + +In the above case, it is apparent that this array is supposed to have two +dimensions. However, this is not the case for other types: + + >>> ps = db.prepare("select $1::point[]") + >>> ps.first([(1,2), (2,3)]) + postgresql.types.Array([postgresql.types.point((1.0, 2.0)), postgresql.types.point((2.0, 3.0))]) + +Lists are used to provide the necessary boundary information: + + >>> ps = db.prepare("select $1::int[]") + >>> ps.first([[1,2],[2,3]]) + postgresql.types.Array([[1,2],[2,3]]) + +The above is the appropriate way to define the array from the original example. + +.. hint:: + The root-iterable object given as an array parameter does not need to be a + list-type as it's assumed to be made up of elements. + + +Composites +---------- + +Composites are supported using `postgresql.types.Row` objects to represent +the data. When a composite is referenced for the first time, the driver +queries the database for information about the columns that make up the type. +This information is then used to create the necessary I/O routines for packing +and unpacking the parameters and columns of that type:: + + >>> db.execute("CREATE TYPE ctest AS (i int, t text, n numeric);") + >>> ps = db.prepare("SELECT $1::ctest") + >>> i = (100, 'text', "100.02013") + >>> r = ps.first(i) + >>> r["t"] + 'text' + >>> r["n"] + Decimal("100.02013") + +Or if use of a dictionary is desired:: + + >>> r = ps.first({'t' : 'just-the-text'}) + >>> r + (None, 'just-the-text', None) + +When a dictionary is given to construct the row, absent values are filled with +`None`. + +.. _db_messages: + +Database Messages +================= + +By default, py-postgresql gives detailed reports of messages emitted by the +database. Often, the verbosity is excessive due to single target processes or +existing application infrastructure for tracing the sources of various events. +Normally, this verbosity is not a significant problem as the driver defaults the +``client_min_messages`` setting to ``'WARNING'`` by default. + +However, if ``NOTICE`` or ``INFO`` messages are needed, finer grained control +over message propagation may be desired, py-postgresql's object relationship +model provides a common protocol for controlling message propagation and, +ultimately, display. + +The ``msghook`` attribute on elements--for instance, Statements, Connections, +and Connectors--is absent by default. However, when present on an object that +contributed the cause of a message event, it will be invoked with the Message, +`postgresql.message.Message`, object as its sole parameter. The attribute of +the object that is closest to the event is checked first, if present it will +be called. If the ``msghook()`` call returns a `True` +value(specficially, ``bool(x) is True``), the message will *not* be +propagated any further. However, if a `False` value--notably, `None`--is +returned, the next element is checked until the list is exhausted and the +message is given to `postgresql.sys.msghook`. 
The normal list of elements is +as follows:: + + Output → Statement → Connection → Connector → Driver → postgresql.sys + +Where ``Output`` can be a `postgresql.api.Cursor` object produced by +``declare(...)`` or an implicit output management object used *internally* by +``Statement.__call__()`` and other statement execution methods. Setting the +``msghook`` attribute on `postgresql.api.Statement` gives very fine +control over raised messages. Consider filtering the notice message on create +table statements that implicitly create indexes:: + + >>> db = postgresql.open(...) + >>> db.settings['client_min_messages'] = 'NOTICE' + >>> ct_this = db.prepare('CREATE TEMP TABLE "this" (i int PRIMARY KEY)') + >>> ct_that = db.prepare('CREATE TEMP TABLE "that" (i int PRIMARY KEY)') + >>> def filter_notices(msg): + ... if msg.details['severity'] == 'NOTICE': + ... return True + ... + >>> ct_that() + NOTICE: CREATE TABLE / PRIMARY KEY will create implicit index "that_pkey" for table "that" + ... + ('CREATE TABLE', None) + >>> ct_this.msghook = filter_notices + >>> ct_this() + ('CREATE TABLE', None) + >>> + +The above illustrates the quality of an installed ``msghook`` that simply +inhibits further propagation of messages with a severity of 'NOTICE'--but, only +notices coming from objects derived from the ``ct_this`` +`postgresql.api.Statement` object. + +Subsequently, if the filter is installed on the connection's ``msghook``:: + + >>> db = postgresql.open(...) + >>> db.settings['client_min_messages'] = 'NOTICE' + >>> ct_this = db.prepare('CREATE TEMP TABLE "this" (i int PRIMARY KEY)') + >>> ct_that = db.prepare('CREATE TEMP TABLE "that" (i int PRIMARY KEY)') + >>> def filter_notices(msg): + ... if msg.details['severity'] == 'NOTICE': + ... return True + ... + >>> db.msghook = filter_notices + >>> ct_that() + ('CREATE TABLE', None) + >>> ct_this() + ('CREATE TABLE', None) + >>> + +Any message with ``'NOTICE'`` severity coming from the connection, ``db``, will be +suffocated by the ``filter_notices`` function. However, if a ``msghook`` is +installed on either of those statements, it would be possible for display to +occur depending on the implementation of the hook installed on the statement +objects. + + +Message Metadata +---------------- + +PostgreSQL messages, `postgresql.message.Message`, are primarily described in three +parts: the SQL-state code, the main message string, and a mapping containing the +details. The follow attributes are available on message objects: + + ``Message.message`` + The primary message string. + + ``Message.code`` + The SQL-state code associated with a given message. + + ``Message.source`` + The origins of the message. Normally, ``'SERVER'`` or ``'CLIENT'``. + + ``Message.location`` + A terse, textual representation of ``'file'``, ``'line'``, and ``'function'`` + provided by the associated ``details``. + + ``Message.details`` + A mapping providing extended information about a message. This mapping + object **can** contain the following keys: + + ``'severity'`` + Any of ``'DEBUG'``, ``'INFO'``, ``'NOTICE'``, ``'WARNING'``, ``'ERROR'``, + ``'FATAL'``, or ``'PANIC'``; the latter three are usually associated with a + `postgresql.exceptions.Error` instance. + + ``'context'`` + The CONTEXT portion of the message. + + ``'detail'`` + The DETAIL portion of the message. + + ``'hint'`` + The HINT portion of the message. + + ``'position'`` + A number identifying the position in the statement string that caused a + parse error. 
+ + ``'file'`` + The name of the file that emitted the message. + (*normally* server information) + + ``'function'`` + The name of the function that emitted the message. + (*normally* server information) + + ``'line'`` + The line of the file that emitted the message. + (*normally* server information) diff --git a/py_opengauss/documentation/gotchas.rst b/py_opengauss/documentation/gotchas.rst new file mode 100644 index 0000000000000000000000000000000000000000..915e3360c993a01df73707528ecf7ed1c420f423 --- /dev/null +++ b/py_opengauss/documentation/gotchas.rst @@ -0,0 +1,114 @@ +Gotchas +======= + +It is recognized that decisions were made that may not always be ideal for a +given user. In order to highlight those potential issues and hopefully bring +some sense into a confusing situation, this document was drawn. + +Thread Safety +------------- + +py-postgresql connection operations are not thread safe. + +`client_encoding` setting should be altered carefully +----------------------------------------------------- + +`postgresql.driver`'s streaming cursor implementation reads a fixed set of rows +when it queries the server for more. In order to optimize some situations, the +driver will send a request for more data, but makes no attempt to wait and +process the data as it is not yet needed. When the user comes back to read more +data from the cursor, it will then look at this new data. The problem being, if +`client_encoding` was switched, it may use the wrong codec to transform the +wire data into higher level Python objects(str). + +To avoid this problem from ever happening, set the `client_encoding` early. +Furthermore, it is probably best to never change the `client_encoding` as the +driver automatically makes the necessary transformation to Python strings. + + +The user and password is correct, but it does not work when using `postgresql.driver` +------------------------------------------------------------------------------------- + +This issue likely comes from the possibility that the information sent to the +server early in the negotiation phase may not be in an encoding that is +consistent with the server's encoding. + +One problem is that PostgreSQL does not provide the client with the server +encoding early enough in the negotiation phase, and, therefore, is unable to +process the password data in a way that is consistent with the server's +expectations. + +Another problem is that PostgreSQL takes much of the data in the startup message +as-is, so a decision about the best way to encode parameters is difficult. + +The easy way to avoid *most* issues with this problem is to initialize the +database in the `utf-8` encoding. The driver defaults the expected server +encoding to `utf-8`. However, this can be overridden by creating the `Connector` +with a `server_encoding` parameter. Setting `server_encoding` to the proper +value of the target server will allow the driver to properly encode *some* of +the parameters. Also, any GUC parameters passed via the `settings` parameter +should use typed objects when possible to hint that the server encoding should +not be used on that parameter(`bytes`, for instance). + + +Backslash characters are being treated literally +------------------------------------------------ + +The driver enables standard compliant strings. Stop using non-standard features. +;) + +If support for non-standard strings was provided it would require to the +driver to provide subjective quote interfaces(eg, db.quote_literal). 
Doing so is
+not desirable as it introduces difficulties for the driver *and* the user.
+
+
+Types without binary support in the driver are unsupported in arrays and records
+--------------------------------------------------------------------------------
+
+When an array or composite type is identified, `postgresql.protocol.typio`
+ultimately chooses the binary format for the transfer of the column or
+parameter. When this is done, PostgreSQL will pack or expect *all* the values
+in binary format as well. If that binary format is not supported and the type
+is not a string, it will fail to unpack the row or pack the appropriate data for
+the element or attribute.
+
+In most cases, issues related to this can be avoided with explicit casts to text.
+
+
+NOTICEs, WARNINGs, and other messages are too verbose
+-----------------------------------------------------
+
+For many situations, the information provided with database messages is
+far too verbose. However, considering that py-postgresql is a programmer's
+library, the default of high verbosity is taken with the express purpose of
+allowing the programmer to "adjust the volume" until appropriate.
+
+By default, py-postgresql adjusts the ``client_min_messages`` to only emit
+messages at the WARNING level or higher--ERRORs, FATALs, and PANICs.
+This reduces the number of messages generated by most connections dramatically.
+
+If further customization is needed, the :ref:`db_messages` section has
+information on overriding the default action taken with database messages.
+
+Strange TypeError using load_rows() or load_chunks()
+----------------------------------------------------
+
+When a prepared statement is directly executed using ``__call__()``, it can easily
+validate that the appropriate number of parameters are given to the function.
+When ``load_rows()`` or ``load_chunks()`` is used, any tuple in the
+entire sequence can cause this TypeError during the loading process::
+
+ TypeError: inconsistent items, N processors and M items in row
+
+This exception is raised by a generic processing routine whose functionality
+is abstract in nature, so the message is abstract as well. It essentially means
+that a tuple in the sequence given to the loading method had too many or too few
+items.
+
+Non-English Locales
+-------------------
+
+In the past, some builds of PostgreSQL localized the severity field of some protocol messages.
+`py-postgresql` expects these fields to be consistent with their English terms. If the driver
+raises strange exceptions during the use of non-English locales, it may be necessary to use an
+English setting in order to coax the server into issuing familiar terms.
diff --git a/py_opengauss/documentation/index.rst b/py_opengauss/documentation/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..9189c563ede57efa2c3c6de6bed70bae6d1f5fa7
--- /dev/null
+++ b/py_opengauss/documentation/index.rst
@@ -0,0 +1,75 @@
+py-postgresql
+=============
+
+py-postgresql is a project dedicated to improving the Python client interfaces to PostgreSQL.
+
+At its core, py-postgresql provides a PG-API, `postgresql.api`, and
+DB-API 2.0 interface for using a PostgreSQL database.
+
+Contents
+--------
+
+.. toctree::
+ :maxdepth: 2
+
+ admin
+ driver
+ clientparameters
+ cluster
+ notifyman
+ alock
+ copyman
+ gotchas
+
+Reference
+---------
+
+.. toctree::
+ :maxdepth: 2
+
+ bin
+ reference
+
+Changes
+-------
+
+.. toctree::
+ :maxdepth: 1
+
+ changes-v1.3
+ changes-v1.2
+ changes-v1.1
+ changes-v1.0
+
+Sample Code
+-----------
+
+Using `postgresql.driver`::
+
+ >>> import postgresql
+ >>> db = postgresql.open("pq://user:password@host/name_of_database")
+ >>> db.execute("CREATE TABLE emp (emp_name text PRIMARY KEY, emp_salary numeric)")
+ >>>
+ >>> # Create the statements.
+ >>> make_emp = db.prepare("INSERT INTO emp VALUES ($1, $2)")
+ >>> raise_emp = db.prepare("UPDATE emp SET emp_salary = emp_salary + $2 WHERE emp_name = $1")
+ >>> get_emp_with_salary_lt = db.prepare("SELECT emp_name FROM emp WHERE emp_salary < $1")
+ >>>
+ >>> # Create some employees, but do it in a transaction--all or nothing.
+ >>> with db.xact():
+ ...     make_emp("John Doe", "150000")
+ ...     make_emp("Jane Doe", "150000")
+ ...     make_emp("Andrew Doe", "55000")
+ ...     make_emp("Susan Doe", "60000")
+ >>>
+ >>> # Give some raises
+ >>> with db.xact():
+ ...     for row in get_emp_with_salary_lt("125000"):
+ ...         print(row["emp_name"])
+ ...         raise_emp(row["emp_name"], "10000")
+
+Of course, if DB-API 2.0 is desired, the module is located at
+`postgresql.driver.dbapi20`. DB-API extends PG-API, so the features
+illustrated above are available on DB-API connections.
+
+See :ref:`db_interface` for more information.
diff --git a/py_opengauss/documentation/lib.rst b/py_opengauss/documentation/lib.rst
new file mode 100644
index 0000000000000000000000000000000000000000..592b96fa7f238d1f9f8eefd145f50736300b5efe
--- /dev/null
+++ b/py_opengauss/documentation/lib.rst
@@ -0,0 +1,534 @@
+Categories and Libraries
+************************
+
+This chapter discusses the usage and implementation of connection categories and
+libraries. Originally these features were written with general purpose use in mind;
+however, it is recommended that these features **not** be used in applications.
+They are primarily used internally by the driver and may be removed in the future.
+
+Libraries are a collection of SQL statements that can be bound to a
+connection. Libraries are *normally* bound directly to the connection object as
+an attribute using a name specified by the library.
+
+Libraries provide a common way for SQL statements to be managed outside of the
+code that uses them. When using ILFs, this increases the portability of the SQL
+by keeping the statements isolated from the Python code in an accessible format
+that can be easily used by other languages or systems --- an ILF parser can be
+implemented within a few dozen lines using basic text tools.
+
+SQL statements defined by a Library are identified by their Symbol. These
+symbols are named and annotated in order to allow the user to define how a
+statement is to be used. The user may state the default execution method of
+the statement object, or whether the symbol is to be preloaded at bind
+time--these properties are Symbol Annotations.
+
+The purpose of libraries is to provide a means to manage statements on
+disk and at runtime. ILFs provide a means to reference a collection
+of statements on disk, and, when loaded, the symbol bindings provide a means to
+reference a statement already prepared for use on a given connection.
+
+The `postgresql.lib` package-module provides fundamental classes for supporting
+categories and libraries.
+
+
+Writing Libraries
+=================
+
+ILF files are the recommended way to build a library. These files use the
+naming convention "lib{NAME}.sql". 
The prefix and suffix are used describe the +purpose of the file and to provide a hint to editors that SQL highlighting +should be used. The format of an ILF takes the form:: + + + [name:type:method] + + ... + +Where multiple symbols may be defined. The Preface that comes before the first +symbol is an arbitrary block of text that should be used to describe the library. +This block is free-form, and should be considered a good place for some +general documentation. + +Symbols are named and described using the contents of section markers: +``('[' ... ']')``. Section markers have three components: the symbol name, +the symbol type and the symbol method. Each of these components are separated +using a single colon, ``:``. All components are optional except the Symbol name. +For example:: + + [get_user_info] + SELECT * FROM user WHERE user_id = $1 + + [get_user_info_v2::] + SELECT * FROM user WHERE user_id = $1 + +In the above example, ``get_user_info`` and ``get_user_info_v2`` are identical. +Empty components indicate the default effect. + +The second component in the section identifier is the symbol type. All Symbol +types are listed in `Symbol Types`_. This can be +used to specify what the section's contents are or when to bind the +symbol:: + + [get_user_info:preload] + SELECT * FROM user WHERE user_id = $1 + +This provides the Binding with the knowledge that the statement should be +prepared when the Library is bound. Therefore, when this Symbol's statement +is used for the first time, it will have already been prepared. + +Another type is the ``const`` Symbol type. This defines a data Symbol whose +*statement results* will be resolved when the Library is bound:: + + [user_type_ids:const] + SELECT user_type_id, user_type FROM user_types; + +Constant Symbols cannot take parameters as they are data properties. The +*result* of the above query is set to the Bindings' ``user_type_ids`` +attribute:: + + >>> db.lib.user_type_ids + + +Where ``lib`` in the above is a Binding of the Library containing the +``user_type_ids`` Symbol. + +Finally, procedures can be bound as symbols using the ``proc`` type:: + + [remove_user:proc] + remove_user(bigint) + +All procedures symbols are loaded when the Library is bound. Procedure symbols +are special because the execution method is effectively specified by the +procedure itself. + + +The third component is the symbol ``method``. This defines the execution method +of the statement and ultimately what is returned when the Symbol is called at +runtime. All the execution methods are listed in `Symbol Execution Methods`_. + +The default execution method is the default execution method of +`postgresql.api.PreparedStatement` objects; return the entire result set in a +list object:: + + [get_numbers] + SELECT i FROM generate_series(0, 100-1) AS g(i); + +When bound:: + + >>> db.lib.get_numbers() == [(x,) for x in range(100)] + True + +The transformation of range in the above is necessary as statements +return a sequence of row objects by default. + +For large result-sets, fetching all the rows would be taxing on a system's +memory. 
The ``rows`` and ``chunks`` methods provide an iterator to rows produced +by a statement using a stream:: + + [get_some_rows::rows] + SELECT i FROM generate_series(0, 1000) AS g(i); + + [get_some_chunks::chunks] + SELECT i FROM generate_series(0, 1000) AS g(i); + +``rows`` means that the Symbol will return an iterator producing individual rows +of the result, and ``chunks`` means that the Symbol will return an iterator +producing sequences of rows of the result. + +When bound:: + + >>> from itertools import chain + >>> list(db.lib.get_some_rows()) == list(chain.from_iterable(db.lib.get_some_chunks())) + True + +Other methods include ``column`` and ``first``. The column method provides a +means to designate that the symbol should return an iterator of the values in +the first column instead of an iterator to the rows:: + + [another_generate_series_example::column] + SELECT i FROM generate_series(0, $1::int) AS g(i) + +In use:: + + >>> list(db.lib.another_generate_series_example(100-1)) == list(range(100)) + True + >>> list(db.lib.another_generate_series_example(10-1)) + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + +The ``first`` method provides direct access to simple results. +Specifically, the first column of the first row when there is only one column. +When there are multiple columns the first row is returned:: + + [get_one::first] + SELECT 1 + + [get_one_twice::first] + SELECT 1, 1 + +In use:: + + >>> db.lib.get_one() == 1 + True + >>> db.lib.get_one_twice() == (1,1) + True + +.. note:: + ``first`` should be used with care. When the result returns no rows, `None` + will be returned. + + +Using Libraries +=============== + +After a library is created, it must be loaded before it can be bound using +programmer interfaces. The `postgresql.lib.load` interface provides the +primary entry point for loading libraries. + +When ``load`` is given a string, it identifies if a directory separator is in +the string, if there is it will treat the string as a *path* to the ILF to be +loaded. If no separator is found, it will treat the string as the library +name fragment and look for "lib{NAME}.sql" in the directories listed in +`postgresql.sys.libpath`. + +Once a `postgresql.lib.Library` instance has been acquired, it can then be +bound to a connection for use. `postgresql.lib.Binding` is used to create an +object that provides and manages the Bound Symbols:: + + >>> import postgresql.lib as pg_lib + >>> lib = pg_lib.load(...) + >>> B = pg_lib.Binding(db, lib) + +The ``B`` object in the above example provides the Library's Symbols as +attributes which can be called to in order to execute the Symbol's statement:: + + >>> B.symbol(param) + ... + +While it is sometimes necessary, manual creation of a Binding is discouraged. +Rather, `postgresql.lib.Category` objects should be used to manage the set of +Libraries to be bound to a connection. + + +Categories +---------- + +Libraries provide access to a collection of symbols; Bindings provide an +interface to the symbols with respect to a subject database. When a connection +is established, multiple Bindings may need to be created in order to fulfill +the requirements of the programmer. When a Binding is created, it exists in +isolation; this can be an inconvenience when access to both the Binding and +the Connection is necessary. Categories exist to provide a formal method for +defining the interface extensions on a `postgresql.api.Database` +instance(connection). + +A Category is essentially a runtime-class for connections. 
It provides a +formal initialization procedure for connection objects at runtime. However, +the connection resource must be connected prior to category initialization. + +Categories are sets of Libraries to be bound to a connection with optional name +substitutions. In order to create one directly, pass the Library instances to +`postgresql.lib.Category`:: + + >>> import postgresql.lib as pg_lib + >>> cat = pg_lib.Category(lib1, lib2, libN) + +Where ``lib1``, ``lib2``, ``libN`` are `postgresql.lib.Library` instances; +usually created by `postgresql.lib.load`. Once created, categories can then +used by passing the ``category`` keyword to connection creation interfaces:: + + >>> import postgresql + >>> db = postgresql.open(category = cat) + +The ``db`` object will now have Bindings for ``lib1``, ``lib2``, ..., and +``libN``. + +Categories can alter the access point(attribute name) of Bindings. This is done +by instantiating the Category using keyword parameters:: + + >>> cat = pg_lib.Category(lib1, lib2, libname = libN) + +At this point, when a connection is established as the category ``cat``, +``libN`` will be bound to the connection object on the attribute ``libname`` +instead of the name defined by the library. + +And a final illustration of Category usage:: + + >>> db = postgresql.open(category = pg_lib.Category(pg_lib.load('name'))) + >>> db.name + + + +Symbol Types +============ + +The symbol type determines how a symbol is going to be treated by the Binding. +For instance, ``const`` symbols are resolved when the Library is bound and +the statement object is immediately discarded. Here is a list of symbol types +that can be used in ILF libraries: + + ```` (Empty component) + The symbol's statement will never change. This allows the Bound Symbol to + hold onto the `postgresql.api.PreparedStatement` object. When the symbol is + used again, it will refer to the existing prepared statement object. + + ``preload`` + Like the default type, the Symbol is a simple statement, but it should be + loaded when the library is bound to the connection. + + ``const`` + The statement takes no parameters and only needs to be executed once. This + will cause the statement to be executed when the library is bound and the + results of the statement will be set to the Binding using the symbol name so + that it may be used as a property by the user. + + ``proc`` + The contents of the section is a procedure identifier. When this type is used + the symbol method *should not* be specified as the method annotation will be + automatically resolved based on the procedure's signature. + + ``transient`` + The Symbol is a statement that should *not* be retained. Specifically, it is + a statement object that will be discarded when the user discard the referenced + Symbol. Used in cases where the statement is used once or very infrequently. + + +Symbol Execution Methods +======================== + +The Symbol Execution Method provides a way to specify how a statement is going +to be used. Specifically, which `postgresql.api.PreparedStatement` method +should be executed when a Bound Symbol is called. The following is a list of +the symbol execution methods and the effect it will have when invoked: + + ```` (Empty component) + Returns the entire result set in a single list object. If the statement does + not return rows, a ``(command, count)`` pair will be returned. + + ``rows`` + Returns an iterator producing each row in the result set. + + ``chunks`` + Returns an iterator producing "chunks" of rows in the result set. 
+ + ``first`` + Returns the first column of the first row if there is one column in the result + set. If there are multiple columns in the result set, the first row is + returned. If query is non-RETURNING DML--insert, update, or delete, the row + count is returned. + + ``column`` + Returns an iterator to values in the first column. (Equivalent to + executing a statement as ``map(operator.itemgetter(0), ps.rows())``.) + + ``declare`` + Returns a scrollable cursor, `postgresql.api.Cursor`, to the result set. + + ``load_chunks`` + Takes an iterable row-chunks to be given to the statement. Returns `None`. If + the statement is a ``COPY ... FROM STDIN``, the iterable must produce chunks + of COPY lines. + + ``load_rows`` + Takes an iterable rows to be given as parameters. If the statement is a ``COPY + ... FROM STDIN``, the iterable must produce COPY lines. + + +Reference Symbols +================= + +Reference Symbols provide a way to construct a Bound Symbol using the Symbol's +query. When invoked, A Reference Symbol's query is executed in order to produce +an SQL statement to be used as a Bound Symbol. In ILF files, a reference is +identified by its symbol name being prefixed with an ampersand:: + + [&refsym::first] + SELECT 'SELECT 1::int4'::text + +Then executed:: + + >>> # Runs the 'refsym' SQL, and creates a Bound Symbol using the results. + >>> sym = lib.refsym() + >>> assert sym() == 1 + +The Reference Symbol's type and execution method are inherited by the created +Bound Symbol. With one exception, ``const`` reference symbols are +special in that they immediately resolved into the target Bound Symbol. + +A Reference Symbol's source query *must* produce rows of text columns. Multiple +columns and multiple rows may be produced by the query, but they must be +character types as the results are promptly joined together with whitespace so +that the target statement may be prepared. + +Reference Symbols are most likely to be used in dynamic DDL and DML situations, +or, somewhat more specifically, any query whose definition depends on a +generated column list. + +Distributing and Usage +====================== + +For applications, distribution and management can easily be a custom +process. The application designates the library directory; the entry point +adds the path to the `postgresql.sys.libpath` list; a category is built; and, a +connection is made using the category. + +For mere Python extensions, however, ``distutils`` has a feature that can +aid in ILF distribution. The ``package_data`` setup keyword can be used to +include ILF files alongside the Python modules that make up a project. See +http://docs.python.org/3.1/distutils/setupscript.html#installing-package-data +for more detailed information on the keyword parameter. + +The recommended way to manage libraries for extending projects is to +create a package to contain them. 
For instance, consider the following layout:: + + project/ + setup.py + pkg/ + __init__.py + lib/ + __init__.py + libthis.sql + libthat.sql + +The project's SQL libraries are organized into a single package directory, +``lib``, so ``package_data`` would be configured:: + + package_data = {'pkg.lib': ['*.sql']} + +Subsequently, the ``lib`` package initialization script can then be used to +load the libraries, and create any categories(``project/pkg/lib/__init__.py``):: + + import os.path + import postgresql.lib as pg_lib + import postgresql.sys as pg_sys + libdir = os.path.dirname(__file__) + pg_sys.libpath.append(libdir) + libthis = pg_lib.load('this') + libthat = pg_lib.load('that') + stdcat = pg_lib.Category(libthis, libthat) + +However, it can be undesirable to add the package directory to the global +`postgresql.sys.libpath` search paths. Direct path loading can be used in those +cases:: + + import os.path + import postgresql.lib as pg_lib + libdir = os.path.dirname(__file__) + libthis = pg_lib.load(os.path.join(libdir, 'libthis.sql')) + libthat = pg_lib.load(os.path.join(libdir, 'libthat.sql')) + stdcat = pg_lib.Category(libthis, libthat) + +Using the established project context, a connection would then be created as:: + + from pkg.lib import stdcat + import postgresql as pg + db = pg.open(..., category = stdcat) + # And execute some fictitious symbols. + db.this.sym_from_libthis() + db.that.sym_from_libthat(...) + + +Audience and Motivation +======================= + +This chapter covers advanced material. It is **not** recommended that categories +and libraries be used for trivial applications or introductory projects. + +.. note:: + Libraries and categories are not likely to be of interest to ORM or DB-API users. + +With exception to ORMs or other similar abstractions, the most common pattern +for managing connections and statements is delegation:: + + class MyAppDB(object): + def __init__(self, connection): + self.connection = connection + + def my_operation(self, op_arg1, op_arg2): + return self.connection.prepare( + "SELECT my_operation_proc($1,$2)", + )(op_arg1, op_arg2) + ... + +The straightforward nature is likeable, but the usage does not take advantage of +prepared statements. In order to do that an extra condition is necessary to see +if the statement has already been prepared:: + + ... + + def my_operation(self, op_arg1, op_arg2): + if self.hasattr(self, '_my_operation'): + ps = self._my_operation + else: + ps = self._my_operation = self.connection.prepare( + "SELECT my_operation_proc($1, $2)", + ) + return ps(op_arg1, op_arg2) + ... + +There are many variations that can implement the above. It works and it's +simple, but it will be exhausting if repeated and error prone if the +initialization condition is not factored out. Additionally, if access to statement +metadata is needed, the above example is still lacking as it would require +execution of the statement and further protocol expectations to be established. +This is the province of libraries: direct database interface management. + +Categories and Libraries are used to factor out and simplify +the above functionality so re-implementation is unnecessary. For example, an +ILF library containing the symbol:: + + [my_operation] + SELECT my_operation_proc($1, $2) + + [] + ... + +Will provide the same functionality as the ``my_operation`` method in the +latter Python implementation. 
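+
+To make the comparison concrete, a sketch of loading and using such a library
+follows; the ``libmyapp.sql`` file name, the ``myapp`` binding name, the
+connection parameters, and the argument values are all hypothetical::
+
+ >>> import postgresql
+ >>> import postgresql.lib as pg_lib
+ >>> lib = pg_lib.load('myapp')  # looks for "libmyapp.sql" on postgresql.sys.libpath
+ >>> cat = pg_lib.Category(lib)
+ >>> db = postgresql.open('pq://user@host/dbname', category = cat)
+ >>> db.myapp.my_operation(1, 2)
+
+Because the default symbol type holds onto the prepared statement, repeated
+calls of the Bound Symbol reuse the statement prepared on first use.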
+ + +Terminology +=========== + +The following terms are used throughout this chapter: + + Annotations + The information of about a Symbol describing what it is and how it should be + used. + + Binding + An interface to the Symbols provided by a Library for use with a given + connection. + + Bound Symbol + An interface to an individual Symbol ready for execution against the subject + database. + + Bound Reference + An interface to an individual Reference Symbol that will produce a Bound + Symbol when executed. + + ILF + INI-style Library Format. "lib{NAME}.sql" files. + + Library + A collection of Symbols--mapping of names to SQL statements. + + Local Symbol + A relative term used to denote a symbol that exists in the same library as + the subject symbol. + + Preface + The block of text that comes before the first symbol in an ILF file. + + Symbol + An named database operation provided by a Library. Usually, an SQL statement + with Annotations. + + Reference Symbol + A Symbol whose SQL statement *produces* the source for a Bound Symbol. + + Category + An object supporting a classification for connectors that provides database + initialization facilities for produced connections. For libraries, + `postgresql.lib.Category` objects are a set of Libraries, + `postgresql.lib.Library`. diff --git a/py_opengauss/documentation/notifyman.rst b/py_opengauss/documentation/notifyman.rst new file mode 100644 index 0000000000000000000000000000000000000000..d774ee52b7e579bd73ef6a3e1fa199feb8db6cca --- /dev/null +++ b/py_opengauss/documentation/notifyman.rst @@ -0,0 +1,237 @@ +.. _notifyman: + +*********************** +Notification Management +*********************** + +Relevant SQL commands: `NOTIFY `_, +`LISTEN `_, +`UNLISTEN `_. + +Asynchronous notifications offer a means for PostgreSQL to signal application +code. Often these notifications are used to signal cache invalidation. In 9.0 +and greater, notifications may include a "payload" in which arbitrary data may +be delivered on a channel being listened to. + +By default, received notifications will merely be appended to an internal +list on the connection object. This list will remain empty for the duration +of a connection *unless* the connection begins listening to a channel that +receives notifications. + +The `postgresql.notifyman.NotificationManager` class is used to wait for +messages to come in on a set of connections, pick up the messages, and deliver +the messages to the object's user via the `collections.Iterator` protocol. + + +Listening on a Single Connection +================================ + +The ``db.iternotifies()`` method is a simplification of the notification manager. It +returns an iterator to the notifications received on the subject connection. +The iterator yields triples consisting of the ``channel`` being +notified, the ``payload`` sent with the notification, and the ``pid`` of the +backend that caused the notification:: + + >>> db.listen('for_rabbits') + >>> db.notify('for_rabbits') + >>> for x in db.iternotifies(): + ... channel, payload, pid = x + ... break + >>> assert channel == 'for_rabbits' + True + >>> assert payload == '' + True + >>> assert pid == db.backend_id + True + +The iterator, by default, will continue listening forever unless the connection +is terminated--thus the immediate ``break`` statement in the above loop. 
In +cases where some additional activity is necessary, a timeout parameter may be +given to the ``iternotifies`` method in order to allow "idle" events to occur +at the designated frequency:: + + >>> for x in db.iternotifies(0.5): + ... if x is None: + ... break + +The above example illustrates that idle events are represented using `None` +objects. Idle events are guaranteed to occur *approximately* at the +specified interval--the ``timeout`` keyword parameter. In addition to +providing a means to do other processing or polling, they also offer a safe +break point for the loop. Internally, the iterator produced by the +``iternotifies`` method *is* a `NotificationManager`, which will localize the +notifications prior to emitting them via the iterator. +*It's not safe to break out of the loop, unless an idle event is being handled.* +If the loop is broken while a regular event is being processed, some events may +remain in the iterator. In order to consume those events, the iterator *must* +be accessible. + +The iterator will be exhausted when the connection is closed, but if the +connection is closed during the loop, any remaining notifications *will* +be emitted prior to the loop ending, so it is important to be prepared to +handle exceptions or check for a closed connection. + +In situations where multiple connections need to be watched, direct use of the +`NotificationManager` is necessary. + + +Listening on Multiple Connections +================================= + +The `postgresql.notifyman.NotificationManager` class is used to manage +*connections* that are expecting to receive notifications. Instances are +iterators that yield the connection object and notifications received on the +connection or `None` in the case of an idle event. The manager emits events as +a pair; the connection object that received notifications, and *all* the +notifications picked up on that connection:: + + >>> from postgresql.notifyman import NotificationManager + >>> # Using ``nm`` to reference the manager from here on. + >>> nm = NotificationManager(db1, db2, ..., dbN) + >>> nm.settimeout(2) + >>> for x in nm: + ... if x is None: + ... # idle + ... break + ... + ... db, notifies = x + ... for channel, payload, pid in notifies: + ... ... + +The manager will continue to wait for and emit events so long as there are +good connections available in the set; it is possible for connections to be +added and removed at any time. Although, in rare circumstances, discarded +connections may still have pending events if it not removed during an idle +event. The ``connections`` attribute on `NotificationManager` objects is a +set object that may be used directly in order to add and remove connections +from the manager:: + + >>> y = [] + >>> for x in nm: + ... if x is None: + ... if y: + ... nm.connections.add(y[0]) + ... del y[0] + ... + +The notification manager is resilient; if a connection dies, it will discard the +connection from the set, and add it to the set of bad connections, the +``garbage`` attribute. In these cases, the idle event *should* be leveraged to +check for these failures if that's a concern. It is the user's +responsibility to explicitly handle the failure cases, and remove the bad +connections from the ``garbage`` set:: + + >>> for x in nm: + ... if x is None: + ... if nm.garbage: + ... recovered = take_out_trash(nm.garbage) + ... nm.connections.update(recovered) + ... nm.garbage.clear() + ... db, notifies = x + ... for channel, payload, pid in notifies: + ... ... 
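+
+The ``take_out_trash`` routine above is not part of the package; it stands in
+for whatever recovery policy the application requires. A minimal sketch,
+assuming the connections support ``clone()`` and that the channels to
+re-``listen()`` to are known to the application::
+
+ >>> def take_out_trash(garbage, channels = ('for_rabbits',)):
+ ...     recovered = set()
+ ...     for dead in garbage:
+ ...         try:
+ ...             fresh = dead.clone()  # new connection from the same connector
+ ...             fresh.listen(*channels)  # LISTEN is per-backend; it must be re-issued
+ ...             recovered.add(fresh)
+ ...         except Exception:
+ ...             pass  # still unreachable; not recovered this round
+ ...     return recovered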
+
+Explicitly removing connections from the set can also be a means to gracefully
+terminate the event loop::
+
+ >>> for x in nm:
+ ...     if x is None:
+ ...         if done_listening is True:
+ ...             nm.connections.clear()
+
+However, doing so inside the loop is not a requirement; it is safe to remove a
+connection from the set at any point.
+
+
+Notification Managers
+=====================
+
+The `postgresql.notifyman.NotificationManager` is an event loop that services
+multiple connections. In cases where only one connection needs to be serviced,
+the `postgresql.api.Database.iternotifies` method can be used to simplify the
+process.
+
+
+Notification Manager Constructors
+---------------------------------
+
+ ``NotificationManager(*connections, timeout = None)``
+ Create a NotificationManager instance that manages the notifications coming
+ from the given set of connections. The ``timeout`` keyword is optional and
+ can be configured using the ``settimeout`` method as well.
+
+
+Notification Manager Interface Points
+-------------------------------------
+
+ ``NotificationManager.__iter__()``
+ Returns the instance; it is an iterator.
+
+ ``NotificationManager.__next__()``
+ Normally, yields the pair, connection and notifications list, when the next
+ event is received. If a timeout is configured, `None` may be yielded to signal
+ an idle event. The notifications list is a list of triples:
+ ``(channel, payload, pid)``.
+
+ ``NotificationManager.settimeout(timeout : int)``
+ Set the amount of time to wait before the manager yields an idle event.
+ If zero, the manager will never wait and only yield notifications that are
+ immediately available.
+ If `None`, the manager will never emit idle events.
+
+ ``NotificationManager.gettimeout() -> [int, None]``
+ Get the configured timeout; returns either `None`, or an `int`.
+
+ ``NotificationManager.connections``
+ The set of connections that the manager is actively watching for
+ notifications. Connections may be added or removed from the set at any time.
+
+ ``NotificationManager.garbage``
+ The set of connections that failed. Normally empty, but when a connection gets
+ an exceptional condition or explicitly raises an exception, it is removed from
+ the ``connections`` set, and placed in ``garbage``.
+
+
+Zero Timeout
+------------
+
+When a timeout of zero, ``0``, is configured, the notification manager will
+terminate early. Specifically, each connection will be polled for any pending
+notifications, and once all of the collected notifications have been emitted
+by the iterator, `StopIteration` will be raised. Notably, *no* idle events will
+occur when the timeout is configured to zero.
+
+Zero timeouts offer a means for the notification "queue" to be polled. Often,
+this is the appropriate way to collect pending notifications on active
+connections where using the connection exclusively for waiting is not
+practical::
+
+ >>> notifies = list(db.iternotifies(0))
+
+Or with a NotificationManager instance::
+
+ >>> nm.settimeout(0)
+ >>> db_notifies = list(nm)
+
+In both cases of zero timeout, the iterator may be promptly discarded without
+losing any events.
+
+
+Summary of Characteristics
+--------------------------
+
+ * The iterator will continue until the connections die.
+ * Objects yielded by the iterator are either `None`, an "idle event", or an
+ individual notification triple if using ``db.iternotifies()``, or a
+ ``(db, notifies)`` pair if using the base `NotificationManager`. 
+ * When a connection dies or raises an exception, it will be removed from + the ``nm.connections`` set and added to the ``nm.garbage`` set. + * The NotificationManager instance will *not* hold any notifications + during an idle event. Idle events offer a break point in which the manager + may be discarded. + * A timeout of zero will cause the iterator to only yield the events + that are pending right now, and promptly end. However, the same manager + object may be used again. + * A notification triple is a tuple consisting of ``(channel, payload, pid)``. + * Connections may be added and removed from the ``nm.connections`` set at + any time. diff --git a/py_opengauss/documentation/reference.rst b/py_opengauss/documentation/reference.rst new file mode 100644 index 0000000000000000000000000000000000000000..466a672e3b7fb3805b99d789349b7a9157c805f7 --- /dev/null +++ b/py_opengauss/documentation/reference.rst @@ -0,0 +1,82 @@ +Reference +========= + +:mod:`postgresql` +----------------- + +.. automodule:: postgresql +.. autodata:: version +.. autodata:: version_info +.. autofunction:: open + +:mod:`postgresql.api` +--------------------- + +.. automodule:: + postgresql.api + :members: + :show-inheritance: + +:mod:`postgresql.sys` +--------------------- + +.. automodule:: + postgresql.sys + :members: + :show-inheritance: + +:mod:`postgresql.string` +------------------------ + +.. automodule:: + postgresql.string + :members: + :show-inheritance: + +:mod:`postgresql.exceptions` +---------------------------- + +.. automodule:: + postgresql.exceptions + :members: + :show-inheritance: + +:mod:`postgresql.temporal` +-------------------------- + +.. automodule:: + postgresql.temporal + :members: + :show-inheritance: + +:mod:`postgresql.installation` +------------------------------ + +.. automodule:: + postgresql.installation + :members: + :show-inheritance: + +:mod:`postgresql.cluster` +------------------------- + +.. automodule:: + postgresql.cluster + :members: + :show-inheritance: + +:mod:`postgresql.copyman` +------------------------- + +.. automodule:: + postgresql.copyman + :members: + :show-inheritance: + +:mod:`postgresql.alock` +----------------------- + +.. automodule:: + postgresql.alock + :members: + :show-inheritance: diff --git a/py_opengauss/documentation/sphinx/.gitignore b/py_opengauss/documentation/sphinx/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..2211df63dd2831aa0cfc38ba1ebc95e3c4620894 --- /dev/null +++ b/py_opengauss/documentation/sphinx/.gitignore @@ -0,0 +1 @@ +*.txt diff --git a/py_opengauss/documentation/sphinx/admin.rst b/py_opengauss/documentation/sphinx/admin.rst new file mode 100644 index 0000000000000000000000000000000000000000..092600d54975feb4dfe5f464deeaa861c652951a --- /dev/null +++ b/py_opengauss/documentation/sphinx/admin.rst @@ -0,0 +1,33 @@ +Administration +============== + +This chapter covers the administration of py-postgresql. This includes +installation and other aspects of working with py-postgresql such as +environment variables and configuration files. + +Installation +------------ + +py-postgresql uses Python's distutils package to manage the build and +installation process of the package. The normal entry point for +this is the ``setup.py`` script contained in the root project directory. 
+
+After extracting the archive and changing into the project's directory,
+installation is normally as simple as::
+
+ $ python3 ./setup.py install
+
+However, if you need to install for use with a particular version of Python,
+just use the path of the executable that should be used::
+
+ $ /usr/opt/bin/python3 ./setup.py install
+
+
+Environment
+-----------
+
+These environment variables affect the operation of the package:
+
+ ============== ===============================================================================
+ PGINSTALLATION The path to the ``pg_config`` executable of the installation to use by default.
+ ============== ===============================================================================
diff --git a/py_opengauss/documentation/sphinx/alock.rst b/py_opengauss/documentation/sphinx/alock.rst
new file mode 100644
index 0000000000000000000000000000000000000000..6767767422ada306f816e7cbe456c5b1b7aaa449
--- /dev/null
+++ b/py_opengauss/documentation/sphinx/alock.rst
@@ -0,0 +1,108 @@
+.. _alock:
+
+**************
+Advisory Locks
+**************
+
+.. warning:: `postgresql.alock` is a new feature in v1.0.
+
+`Explicit Locking in PostgreSQL <https://www.postgresql.org/docs/current/explicit-locking.html>`_.
+
+PostgreSQL's advisory locks offer a cooperative synchronization primitive.
+These are used in cases where an application needs access to a resource, but
+using table locks may cause interference with other operations that can be
+safely performed alongside the application-level, exclusive operation.
+
+Advisory locks can be used by directly executing the stored procedures in the
+database or by using the :class:`postgresql.alock.ALock` subclasses, which
+provide a context manager that uses those stored procedures.
+
+Currently, only two subclasses exist. Each represents the lock mode
+supported by PostgreSQL's advisory locks:
+
+ * :class:`postgresql.alock.ShareLock`
+ * :class:`postgresql.alock.ExclusiveLock`
+
+
+Acquiring ALocks
+================
+
+An ALock instance represents a sequence of advisory locks. A single ALock can
+acquire and release multiple advisory locks by creating the instance with
+multiple lock identifiers::
+
+ >>> from postgresql import alock
+ >>> table1_oid = 192842
+ >>> table2_oid = 192849
+ >>> l = alock.ExclusiveLock(db, (table1_oid, 0), (table2_oid, 0))
+ >>> l.acquire()
+ >>> ...
+ >>> l.release()
+
+:class:`postgresql.alock.ALock` is similar to :class:`threading.RLock`; in
+order for an ALock to be released, it must be released the number of times it
+has been acquired. ALocks are associated with and survived by their session.
+Much like how RLocks are associated with the thread they are acquired in:
+acquiring an ALock again will merely increment its count.
+
+PostgreSQL allows advisory locks to be identified using a pair of `int4` or a
+single `int8`. ALock instances represent a *sequence* of those identifiers::
+
+ >>> from postgresql import alock
+ >>> ids = [(0,0), 0, 1]
+ >>> with alock.ShareLock(db, *ids):
+ ...     ...
+
+Both types of identifiers may be used within the same ALock, and, regardless of
+their type, will be acquired in the order that they were given to the class'
+constructor. In the above example, ``(0,0)`` is acquired first, then ``0``, and
+lastly ``1``.
+
+
+ALocks
+======
+
+`postgresql.alock.ALock` is abstract; it defines the interface and some common
+functionality. The lock mode is selected by choosing the appropriate subclass. 
+ +There are two: + + ``postgresql.alock.ExclusiveLock(database, *identifiers)`` + Instantiate an ALock object representing the `identifiers` for use with the + `database`. Exclusive locks will conflict with other exclusive locks and share + locks. + + ``postgresql.alock.ShareLock(database, *identifiers)`` + Instantiate an ALock object representing the `identifiers` for use with the + `database`. Share locks can be acquired when a share lock with the same + identifier has been acquired by another backend. However, an exclusive lock + with the same identifier will conflict. + + +ALock Interface Points +---------------------- + +Methods and properties available on :class:`postgresql.alock.ALock` instances: + + ``alock.acquire(blocking = True)`` + Acquire the advisory locks represented by the ``alock`` object. If blocking is + `True`, the default, the method will block until locks on *all* the + identifiers have been acquired. + + If blocking is `False`, acquisition may not block, and success will be + indicated by the returned object: `True` if *all* lock identifiers were + acquired and `False` if any of the lock identifiers could not be acquired. + + ``alock.release()`` + Release the advisory locks represented by the ``alock`` object. If the lock + has not been acquired, a `RuntimeError` will be raised. + + ``alock.locked()`` + Returns a boolean describing whether the locks are held or not. This will + return `False` if the lock connection has been closed. + + ``alock.__enter__()`` + Alias to ``acquire``; context manager protocol. Always blocking. + + ``alock.__exit__(typ, val, tb)`` + Alias to ``release``; context manager protocol. diff --git a/py_opengauss/documentation/sphinx/bin.rst b/py_opengauss/documentation/sphinx/bin.rst new file mode 100644 index 0000000000000000000000000000000000000000..43e7a76e9c2488ac57f1730965e3fab5e90c0e73 --- /dev/null +++ b/py_opengauss/documentation/sphinx/bin.rst @@ -0,0 +1,170 @@ +Commands +******** + +This chapter discusses the usage of the available console scripts. + + +postgresql.bin.pg_python +======================== + +The ``pg_python`` command provides a simple way to write Python scripts against a +single target database. It acts like the regular Python console command, but +takes standard PostgreSQL options as well to specify the client parameters +to make establish connection with. The Python environment is then augmented +with the following built-ins: + + ``db`` + The PG-API connection object. + + ``xact`` + ``db.xact``, the transaction creator. + + ``settings`` + ``db.settings`` + + ``prepare`` + ``db.prepare``, the statement creator. + + ``proc`` + ``db.proc`` + + ``do`` + ``db.do``, execute a single DO statement. + + ``sqlexec`` + ``db.execute``, execute multiple SQL statements (``None`` is always returned) + +pg_python Usage +--------------- + +Usage: postgresql.bin.pg_python [connection options] [script] ... 
+ +Options: + --unix=UNIX path to filesystem socket + --ssl-mode=SSLMODE SSL requirement for connectivity: require, prefer, + allow, disable + -s SETTINGS, --setting=SETTINGS + run-time parameters to set upon connecting + -I PQ_IRI, --iri=PQ_IRI + database locator string + [pq://user:password@host:port/database?setting=value] + -h HOST, --host=HOST database server host + -p PORT, --port=PORT database server port + -U USER, --username=USER + user name to connect as + -W, --password prompt for password + -d DATABASE, --database=DATABASE + database's name + --pq-trace=PQ_TRACE trace PQ protocol transmissions + -C PYTHON_CONTEXT, --context=PYTHON_CONTEXT + Python context code to run[file://,module:,] + -m PYTHON_MAIN Python module to run as script(__main__) + -c PYTHON_MAIN Python expression to run(__main__) + --version show program's version number and exit + --help show this help message and exit + + +Interactive Console Backslash Commands +-------------------------------------- + +Inspired by ``psql``:: + + >>> \? + Backslash Commands: + + \? Show this help message. + \E Edit a file or a temporary script. + \e Edit and Execute the file directly in the context. + \i Execute a Python script within the interpreter's context. + \set Configure environment variables. \set without arguments to show all + \x Execute the Python command within this process. + + +pg_python Examples +------------------ + +Module execution taking advantage of the new built-ins:: + + $ python3 -m postgresql.bin.pg_python -h localhost -W -m timeit "prepare('SELECT 1').first()" + Password for pg_python[pq://dbusername@localhost:5432]: + 1000 loops, best of 3: 1.35 msec per loop + + $ python3 -m postgresql.bin.pg_python -h localhost -W -m timeit -s "ps=prepare('SELECT 1')" "ps.first()" + Password for pg_python[pq://dbusername@localhost:5432]: + 1000 loops, best of 3: 442 usec per loop + +Simple interactive usage:: + + $ python3 -m postgresql.bin.pg_python -h localhost -W + Password for pg_python[pq://dbusername@localhost:5432]: + >>> ps = prepare('select 1') + >>> ps.first() + 1 + >>> c = ps() + >>> c.read() + [(1,)] + >>> ps.close() + >>> import sys + >>> sys.exit(0) + + +postgresql.bin.pg_dotconf +========================= + +pg_dotconf is used to modify a PostgreSQL cluster's configuration file. +It provides a means to apply settings specified from the command line and from a +file referenced using the ``-f`` option. + +.. warning:: + ``include`` directives in configuration files are *completely* ignored. If + modification of an included file is desired, the command must be applied to + that specific file. 
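+
+As a brief illustration of the ``-f`` form mentioned above, a file of settings
+can be applied alongside, or instead of, settings given on the command line.
+This is only a sketch; the ``tuning.conf`` file name is hypothetical::
+
+ $ python3 -m postgresql.bin.pg_dotconf -f tuning.conf postgresql.conf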
+ + +pg_dotconf Usage +---------------- + +Usage: postgresql.bin.pg_dotconf [--stdout] [-f filepath] postgresql.conf ([param=val]|[param])* + +Options: + --version show program's version number and exit + -h, --help show this help message and exit + -f SETTINGS, --file=SETTINGS + A file of settings to *apply* to the given + "postgresql.conf" + --stdout Redirect the product to standard output instead of + writing back to the "postgresql.conf" file + + +Examples +-------- + +Modifying a simple configuration file:: + + $ echo "setting = value" >pg.conf + + # change 'setting' + $ python3 -m postgresql.bin.pg_dotconf pg.conf setting=newvalue + + $ cat pg.conf + setting = 'newvalue' + + # new settings are appended to the file + $ python3 -m postgresql.bin.pg_dotconf pg.conf another_setting=value + $ cat pg.conf + setting = 'newvalue' + another_setting = 'value' + + # comment a setting + $ python3 -m postgresql.bin.pg_dotconf pg.conf another_setting + + $ cat pg.conf + setting = 'newvalue' + #another_setting = 'value' + +When a setting is given on the command line, it must been seen as one argument +to the command, so it's *very* important to avoid invocations like:: + + $ python3 -m postgresql.bin.pg_dotconf pg.conf setting = value + ERROR: invalid setting, '=' after 'setting' + HINT: Settings must take the form 'setting=value' or 'setting_name_to_comment'. Settings must also be received as a single argument. diff --git a/py_opengauss/documentation/sphinx/build.sh b/py_opengauss/documentation/sphinx/build.sh new file mode 100755 index 0000000000000000000000000000000000000000..acf026c1690cfb753cdfb912e27070f9558760ef --- /dev/null +++ b/py_opengauss/documentation/sphinx/build.sh @@ -0,0 +1,34 @@ +#!/bin/sh +cd "$(dirname $0)" + +# distutils doesn't make it straighforward to include an arbitrary +# directory in the package data, so manage .static and .templates here. +mkdir -p .static .templates +cat >.static/unsuck.css_t <.templates/layout.html < +{% endblock %} +EOF + +mkdir -p ../html/doctrees +sphinx-build -c "$(pwd)" -E -b html -d ../html/doctrees .. ../html +cd ../html && pwd diff --git a/py_opengauss/documentation/sphinx/changes-v1.0.rst b/py_opengauss/documentation/sphinx/changes-v1.0.rst new file mode 100644 index 0000000000000000000000000000000000000000..89c8cea31f2a5bdfc58ff91b8a5bdb4924d7dc58 --- /dev/null +++ b/py_opengauss/documentation/sphinx/changes-v1.0.rst @@ -0,0 +1,79 @@ +Changes in v1.0 +=============== + +1.0.4 in development +-------------------- + + * Alter how changes are represented in documentation to simplify merging. + +1.0.3 released on 2011-09-24 +---------------------------- + + * Use raise x from y to generalize exceptions. (Elvis Pranskevichus) + * Alter postgresql.string.quote_ident to always quote. (Elvis Pranskevichus) + * Add postgresql.string.quote_ident_if_necessary (Modification of Elvis Pranskevichus' patch) + * Many postgresql.string bug fixes (Elvis Pranskevichus) + * Correct ResourceWarnings improving Python 3.2 support. (jwp) + * Add test command to setup.py (Elvis Pranskevichus) + +1.0.2 released on 2010-09-18 +---------------------------- + + * Add support for DOMAINs in registered composites. (Elvis Pranskevichus) + * Properly raise StopIteration in Cursor.__next__. (Elvis Pranskevichus) + * Add Cluster Management documentation. + * Release savepoints after rolling them back. + * Fix Startup() usage for Python 3.2. + * Emit deprecation warning when 'gid' is given to xact(). + * Compensate for Python3.2's ElementTree API changes. 
+ +1.0.1 released on 2010-04-24 +---------------------------- + + * Fix unpacking of array NULLs. (Elvis Pranskevichus) + * Fix .first()'s handling of counts and commands. + Bad logic caused zero-counts to return the command tag. + * Don't interrupt and close a temporal connection if it's not open. + * Use the Driver's typio attribute for TypeIO overrides. (Elvis Pranskevichus) + +1.0 released on 2010-03-27 +-------------------------- + + * **DEPRECATION**: Removed 2PC support documentation. + * **DEPRECATION**: Removed pg_python and pg_dotconf 'scripts'. + They are still accessible by python3 -m postgresql.bin.pg_* + * Add support for binary hstore. + * Add support for user service files. + * Implement a Copy manager for direct connection-to-connection COPY operations. + * Added db.do() method for DO-statement support(convenience method). + * Set the default client_min_messages level to WARNING. + NOTICEs are often not desired by programmers, and py-postgresql's + high verbosity further irritates that case. + * Added postgresql.project module to provide project information. + Project name, author, version, etc. + * Increased default recvsize and chunksize for improved performance. + * 'D' messages are special cased as builtins.tuples instead of + protocol.element3.Tuple + * Alter Statement.chunks() to return chunks of builtins.tuple. Being + an interface intended for speed, types.Row() impedes its performance. + * Fix handling of infinity values with timestamptz, timestamp, and date. + [Bug reported by Axel Rau.] + * Correct representation of PostgreSQL ARRAYs by properly recording + lowerbounds and upperbounds. Internally, sub-ARRAYs have their own + element lists. + * Implement a NotificationManager for managing the NOTIFYs received + by a connection. The class can manage NOTIFYs from multiple + connections, whereas the db.wait() method is tailored for single targets. + * Implement an ALock class for managing advisory locks using the + threading.Lock APIs. [Feedback from Valentine Gogichashvili] + * Implement reference symbols. Allow libraries to define symbols that + are used to create queries that inherit the original symbol's type and + execution method. ``db.prepare(db.prepare(...).first())`` + * Fix handling of unix domain sockets by pg.open and driver.connect. + [Reported by twitter.com/rintavarustus] + * Fix typo/dropped parts of a raise LoadError in .lib. + [Reported by Vlad Pranskevichus] + * Fix db.tracer and pg_python's --pq-trace= + * Fix count return from .first() method. Failed to provide an empty + tuple for the rformats of the bind statement. + [Reported by dou dou] diff --git a/py_opengauss/documentation/sphinx/changes-v1.1.rst b/py_opengauss/documentation/sphinx/changes-v1.1.rst new file mode 100644 index 0000000000000000000000000000000000000000..aa1abbac53dfa184dcd5e9bb861f6f32140aa56d --- /dev/null +++ b/py_opengauss/documentation/sphinx/changes-v1.1.rst @@ -0,0 +1,25 @@ +Changes in v1.1 +=============== + +1.1.0 +----- + + * Remove two-phase commit interfaces per deprecation in v1.0. + For proper two phase commit use, a lock manager must be employed that + the implementation did nothing to accommodate for. + * Add support for unpacking anonymous records (Elvis) + * Support PostgreSQL 9.2 (Elvis) + * Python 3.3 Support (Elvis) + * Add column execution method. (jwp) + * Add one-shot statement interface. 
Connection.query.* (jwp) + * Modify the inet/cidr support by relying on the ipaddress module introduced in Python 3.3 (Google's ipaddr project) + The existing implementation relied on simple str() representation supported by the + socket module. Unfortunately, MS Windows' socket library does not appear to support the + necessary functionality, or Python's socket module does not expose it. ipaddress fixes + the problem. + +.. note:: + The `ipaddress` module is now required for local inet and cidr. While it is + of "preliminary" status, the ipaddr project has been around for some time and + well supported. ipaddress appears to be the safest way forward for native + network types. diff --git a/py_opengauss/documentation/sphinx/changes-v1.2.rst b/py_opengauss/documentation/sphinx/changes-v1.2.rst new file mode 100644 index 0000000000000000000000000000000000000000..8b79fa165e6aa4c8cc57d36df7f27cd00d52fcf2 --- /dev/null +++ b/py_opengauss/documentation/sphinx/changes-v1.2.rst @@ -0,0 +1,18 @@ +Changes in v1.2 +=============== + +1.2.2 released on 2020-09-22 +---------------------------- + + * Correct broken Connection.proc. + * Correct IPv6 IRI host oversight. + * Document an ambiguity case of DB-API 2.0 connection creation and the workaround(unix vs host/port). + * (Pending, active in 1.3) DB-API 2.0 connect() failures caused an undesired exception chain; ClientCannotConnect is now raised. + * Minor maintenance on tests and support modules. + +1.2.0 released on 2016-06-23 +---------------------------- + + * PostgreSQL 9.3 compatibility fixes (Elvis) + * Python 3.5 compatibility fixes (Elvis) + * Add support for JSONB type (Elvis) diff --git a/py_opengauss/documentation/sphinx/changes-v1.3.rst b/py_opengauss/documentation/sphinx/changes-v1.3.rst new file mode 100644 index 0000000000000000000000000000000000000000..8b8686c3fd39aa2583172f9e937909e69be228cb --- /dev/null +++ b/py_opengauss/documentation/sphinx/changes-v1.3.rst @@ -0,0 +1,14 @@ +Changes in v1.3 +=============== + +1.3.0 +----- + + * Commit DB-API 2.0 ClientCannotConnect exception correction. + * Eliminate types-as-documentation annotations. + * Add Connection.transaction alias for asyncpg consistency. + * Eliminate multiple inheritance in `postgresql.api` in favor of ABC registration. + * Add support for PGTEST environment variable (pq-IRI) to improve test performance + and to aid in cases where the target fixture is already available. + This should help for testing the driver against servers that are not actually + postgresql. diff --git a/py_opengauss/documentation/sphinx/clientparameters.rst b/py_opengauss/documentation/sphinx/clientparameters.rst new file mode 100644 index 0000000000000000000000000000000000000000..8c8441cf8aa0a2465d43bce459ecff09aeb8a052 --- /dev/null +++ b/py_opengauss/documentation/sphinx/clientparameters.rst @@ -0,0 +1,260 @@ +Client Parameters +***************** + +.. warning:: **The interfaces dealing with optparse are subject to change in 1.0**. + +There are various sources of parameters used by PostgreSQL client applications. +The `postgresql.clientparameters` module provides a means for collecting and +managing those parameters. + +Connection creation interfaces in `postgresql.driver` are purposefully simple. +All parameters taken by those interfaces are keywords, and are taken +literally; if a parameter is not given, it will effectively be `None`. +libpq-based drivers tend differ as they inherit some default client parameters +from the environment. 
Doing this by default is undesirable as it can cause +trivial failures due to unexpected parameter inheritance. However, using these +parameters from the environment and other sources are simply expected in *some* +cases: `postgresql.open`, `postgresql.bin.pg_python`, and other high-level +utilities. The `postgresql.clientparameters` module provides a means to collect +them into one dictionary object for subsequent application to a connection +creation interface. + +`postgresql.clientparameters` is primarily useful to script authors that want to +provide an interface consistent with PostgreSQL commands like ``psql``. + + +Collecting Parameters +===================== + +The primary entry points in `postgresql.clientparameters` are +`postgresql.clientparameters.collect` and +`postgresql.clientparameters.resolve_password`. + +For most purposes, ``collect`` will suffice. By default, it will prompt for the +password if instructed to(``-W``). Therefore, ``resolve_password`` need not be +used in most cases:: + + >>> import sys + >>> import postgresql.clientparameters as pg_param + >>> p = pg_param.DefaultParser() + >>> co, ca = p.parse_args(sys.argv[1:]) + >>> params = pg_param.collect(parsed_options = co) + +The `postgresql.clientparameters` module is executable, so you can see the +results of the above snippet by:: + + $ python -m postgresql.clientparameters -h localhost -U a_db_user -ssearch_path=public + {'host': 'localhost', + 'password': None, + 'port': 5432, + 'settings': {'search_path': 'public'}, + 'user': 'a_db_user'} + + +`postgresql.clientparameters.collect` +-------------------------------------- + +Build a client parameter dictionary from the environment and parsed command +line options. The following is a list of keyword arguments that ``collect`` will +accept: + + ``parsed_options`` + Options parsed by `postgresql.clientparameters.StandardParser` or + `postgresql.clientparameters.DefaultParser` instances. + + ``no_defaults`` + When `True`, don't include defaults like ``pgpassfile`` and ``user``. + Defaults to `False`. + + ``environ`` + Environment variables to extract client parameter variables from. + Defaults to `os.environ` and expects a `collections.abc.Mapping` interface. + + ``environ_prefix`` + Environment variable prefix to use. Defaults to "PG". This allows the + collection of non-standard environment variables whose keys are partially + consistent with the standard variants. e.g. "PG_SRC_USER", "PG_SRC_HOST", + etc. + + ``default_pg_sysconfdir`` + The location of the pg_service.conf file. The ``PGSYSCONFDIR`` environment + variable will override this. When a default installation is present, + ``PGINSTALLATION``, it should be set to this. + + ``pg_service_file`` + Explicit location of the service file. This will override the "sysconfdir" + based path. + + ``prompt_title`` + Descriptive title to use if a password prompt is needed. `None` to disable + password resolution entirely. Setting this to `None` will also disable + pgpassfile lookups, so it is necessary that further processing occurs when + this is `None`. + + ``parameters`` + Base client parameters to use. These are set after the *defaults* are + collected. (The defaults that can be disabled by ``no_defaults``). 
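+
+As a sketch of how these keywords combine, the following collects parameters
+for a secondary data source using a non-standard environment prefix. The
+``PG_SRC_*`` variable names and the ``'replication_source'`` database are
+assumptions made purely for illustration::
+
+ >>> import os
+ >>> import postgresql.clientparameters as pg_param
+ >>> # Reads PG_SRC_USER, PG_SRC_HOST, PG_SRC_PORT, etc. from the environment.
+ >>> src_params = pg_param.collect(
+ ...  environ = os.environ,
+ ...  environ_prefix = 'PG_SRC_',
+ ...  prompt_title = None,
+ ...  parameters = {'database': 'replication_source'},
+ ... )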
+ +If ``prompt_title`` is not set to `None`, it will prompt for the password when +instructed to do by the ``prompt_password`` key in the parameters:: + + >>> import postgresql.clientparameters as pg_param + >>> p = pg_param.collect(prompt_title = 'my_prompt!', parameters = {'prompt_password':True}) + Password for my_prompt![pq://dbusername@localhost:5432]: + >>> p + {'host': 'localhost', 'user': 'dbusername', 'password': 'secret', 'port': 5432} + +If `None`, it will leave the necessary password resolution information in the +parameters dictionary for ``resolve_password``:: + + >>> p = pg_param.collect(prompt_title = None, parameters = {'prompt_password':True}) + >>> p + {'pgpassfile': '/home/{USER}/.pgpass', 'prompt_password': True, 'host': 'localhost', 'user': 'dbusername', 'port': 5432} + +Of course, ``'prompt_password'`` is normally specified when ``parsed_options`` +received a ``-W`` option from the command line:: + + >>> op = pg_param.DefaultParser() + >>> co, ca = op.parse_args(['-W']) + >>> p = pg_param.collect(parsed_options = co) + >>> p=pg_param.collect(parsed_options = co) + Password for [pq://dbusername@localhost:5432]: + >>> p + {'host': 'localhost', 'user': 'dbusername', 'password': 'secret', 'port': 5432} + >>> + + +`postgresql.clientparameters.resolve_password` +---------------------------------------------- + +Resolve the password for the given client parameters dictionary returned by +``collect``. By default, this function need not be used as ``collect`` will +resolve the password by default. `resolve_password` accepts the following +arguments: + + ``parameters`` + First positional argument. Normalized client parameters dictionary to update + in-place with the resolved password. If the 'prompt_password' key is in + ``parameters``, it will prompt regardless(normally comes from ``-W``). + + ``getpass`` + Function to call to prompt for the password. Defaults to `getpass.getpass`. + + ``prompt_title`` + Additional title to use if a prompt is requested. This can also be specified + in the ``parameters`` as the ``prompt_title`` key. This *augments* the IRI + display on the prompt. Defaults to an empty string, ``''``. + +The resolution process is effected by the contents of the given ``parameters``. +Notable keywords: + + ``prompt_password`` + If present in the given parameters, the user will be prompted for the using + the given ``getpass`` function. This disables the password file lookup + process. + + ``prompt_title`` + This states a default prompt title to use. If the ``prompt_title`` keyword + argument is given to ``resolve_password``, this will not be used. + + ``pgpassfile`` + The PostgreSQL password file to lookup the password in. If the ``password`` + parameter is present, this will not be used. 
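+
+For non-interactive programs, the prompt itself can be replaced by passing a
+custom ``getpass`` callable. A minimal sketch, assuming the password is
+supplied by the surrounding application rather than a terminal prompt::
+
+ >>> import postgresql.clientparameters as pg_param
+ >>> p = {'prompt_password': True, 'host': 'localhost', 'user': 'dbusername', 'port': 5432}
+ >>> pg_param.resolve_password(p, getpass = lambda prompt: 'secret')
+ >>> p['password']
+ 'secret'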
+ +When resolution occurs, the ``prompt_password``, ``prompt_title``, and +``pgpassfile`` keys are *removed* from the given parameters dictionary:: + + >>> p=pg_param.collect(prompt_title = None) + >>> p + {'pgpassfile': '/home/{USER}/.pgpass', 'host': 'localhost', 'user': 'dbusername', 'port': 5432} + >>> pg_param.resolve_password(p) + >>> p + {'host': 'localhost', 'password': 'secret', 'user': 'dbusername', 'port': 5432} + + +Defaults +======== + +The following is a list of default parameters provided by ``collect`` and the +sources of their values: + + ==================== =================================================================== + Key Value + ==================== =================================================================== + ``'user'`` `getpass.getuser()` or ``'postgres'`` + ``'host'`` `postgresql.clientparameters.default_host` (``'localhost'``) + ``'port'`` `postgresql.clientparameters.default_port` (``5432``) + ``'pgpassfile'`` ``"$HOME/.pgpassfile"`` or ``[PGDATA]`` + ``'pgpass.conf'`` (Win32) + ``'sslcrtfile'`` ``[PGDATA]`` + ``'postgresql.crt'`` + ``'sslkeyfile'`` ``[PGDATA]`` + ``'postgresql.key'`` + ``'sslrootcrtfile'`` ``[PGDATA]`` + ``'root.crt'`` + ``'sslrootcrlfile'`` ``[PGDATA]`` + ``'root.crl'`` + ==================== =================================================================== + +``[PGDATA]`` referenced in the above table is a directory whose path is platform +dependent. On most systems, it is ``"$HOME/.postgresql"``, but on Windows based +systems it is ``"%APPDATA%\postgresql"`` + +.. note:: + [PGDATA] is *not* an environment variable. + + +.. _pg_envvars: + +PostgreSQL Environment Variables +================================ + +The following is a list of environment variables that will be collected by the +`postgresql.clientparameter.collect` function using "PG" as the +``environ_prefix`` and the keyword that it will be mapped to: + + ===================== ====================================== + Environment Variable Keyword + ===================== ====================================== + ``PGUSER`` ``'user'`` + ``PGDATABASE`` ``'database'`` + ``PGHOST`` ``'host'`` + ``PGPORT`` ``'port'`` + ``PGPASSWORD`` ``'password'`` + ``PGSSLMODE`` ``'sslmode'`` + ``PGSSLKEY`` ``'sslkey'`` + ``PGCONNECT_TIMEOUT`` ``'connect_timeout'`` + ``PGREALM`` ``'kerberos4_realm'`` + ``PGKRBSRVNAME`` ``'kerberos5_service'`` + ``PGPASSFILE`` ``'pgpassfile'`` + ``PGTZ`` ``'settings' = {'timezone': }`` + ``PGDATESTYLE`` ``'settings' = {'datestyle': }`` + ``PGCLIENTENCODING`` ``'settings' = {'client_encoding': }`` + ``PGGEQO`` ``'settings' = {'geqo': }`` + ===================== ====================================== + + +.. _pg_passfile: + +PostgreSQL Password File +======================== + +The password file is a simple newline separated list of ``:`` separated fields. It +is located at ``$HOME/.pgpass`` for most systems and at +``%APPDATA%\postgresql\pgpass.conf`` for Windows based systems. However, the +``PGPASSFILE`` environment variable may be used to override that location. + +The lines in the file must be in the following form:: + + hostname:port:database:username:password + +A single asterisk, ``*``, may be used to indicate that any value will match the +field. However, this only effects fields other than ``password``. + +See http://www.postgresql.org/docs/current/static/libpq-pgpass.html for more +details. + +Client parameters produced by ``collect`` that have not been processed +by ``resolve_password`` will include a ``'pgpassfile'`` key. 
This is the value +that ``resolve_password`` will use to locate the pgpassfile to interrogate if a +password key is not present and it is not instructed to prompt for a password. + +.. warning:: + Connection creation interfaces will *not* resolve ``'pgpassfile'``, so it is + important that the parameters produced by ``collect()`` are properly processed + before an attempt is made to establish a connection. diff --git a/py_opengauss/documentation/sphinx/cluster.rst b/py_opengauss/documentation/sphinx/cluster.rst new file mode 100644 index 0000000000000000000000000000000000000000..1993ea280df5d1269579197c577aac6b31d6731e --- /dev/null +++ b/py_opengauss/documentation/sphinx/cluster.rst @@ -0,0 +1,378 @@ +.. _cluster_management: + +****************** +Cluster Management +****************** + +py-postgresql provides cluster management tools in order to give the user +fine-grained control over a PostgreSQL cluster and access to information about an +installation of PostgreSQL. + + +.. _installation: + +Installations +============= + +`postgresql.installation.Installation` objects are primarily used to +access PostgreSQL installation information. Normally, they are created using a +dictionary constructed from the output of the pg_config_ executable:: + + from postgresql.installation import Installation, pg_config_dictionary + pg_install = Installation(pg_config_dictionary('/usr/local/pgsql/bin/pg_config')) + +The extraction of pg_config_ information is isolated from Installation +instantiation in order to allow Installations to be created from arbitrary +dictionaries. This can be useful in cases where the installation layout is +inconsistent with the standard PostgreSQL installation layout, or if a faux +Installation needs to be created for testing purposes. + + +Installation Interface Points +----------------------------- + + ``Installation(info)`` + Instantiate an Installation using the given information. Normally, this + information is extracted from a pg_config_ executable using + `postgresql.installation.pg_config_dictionary`:: + + info = pg_config_dictionary('/usr/local/pgsql/bin/pg_config') + pg_install = Installation(info) + + ``Installation.version`` + The installation's version string:: + + pg_install.version + 'PostgreSQL 9.0devel' + + ``Installation.version_info`` + A tuple containing the version's ``(major, minor, patch, state, level)``. + Where ``major``, ``minor``, ``patch``, and ``level`` are `int` objects, and + ``state`` is a `str` object:: + + pg_install.version_info + (9, 0, 0, 'devel', 0) + + ``Installation.ssl`` + A `bool` indicating whether or not the installation has SSL support. + + ``Installation.configure_options`` + The options given to the ``configure`` script that built the installation. The + options are represented using a dictionary object whose keys are normalized + long option names, and whose values are the option's argument. If the option + takes no argument, `True` will be used as the value. + + The normalization of the long option names consists of removing the preceding + dashes, lowering the string, and replacing any dashes with underscores. 
For + instance, ``--enable-debug`` will be ``enable_debug``:: + + pg_install.configure_options + {'enable_debug': True, 'with_libxml': True, + 'enable_cassert': True, 'with_libedit_preferred': True, + 'prefix': '/src/build/pg90', 'with_openssl': True, + 'enable_integer_datetimes': True, 'enable_depend': True} + + ``Installation.paths`` + The paths of the installation as a dictionary where the keys are the path + identifiers and the values are the absolute file system paths. For instance, + ``'bindir'`` is associated with ``$PREFIX/bin``, ``'libdir'`` is associated + with ``$PREFIX/lib``, etc. The paths included in this dictionary are + listed on the class' attributes: `Installation.pg_directories` and + `Installation.pg_executables`. + + The keys that point to installation directories are: ``bindir``, ``docdir``, + ``includedir``, ``pkgincludedir``, ``includedir_server``, ``libdir``, + ``pkglibdir``, ``localedir``, ``mandir``, ``sharedir``, and ``sysconfdir``. + + The keys that point to installation executables are: ``pg_config``, ``psql``, + ``initdb``, ``pg_resetxlog``, ``pg_controldata``, ``clusterdb``, ``pg_ctl``, + ``pg_dump``, ``pg_dumpall``, ``postgres``, ``postmaster``, ``reindexdb``, + ``vacuumdb``, ``ipcclean``, ``createdb``, ``ecpg``, ``createuser``, + ``createlang``, ``droplang``, ``dropuser``, and ``pg_restore``. + + .. note:: If the executable does not exist, the value will be `None` instead + of an absoluate path. + + To get the path to the psql_ executable:: + + from postgresql.installation import Installation + pg_install = Installation('/usr/local/pgsql/bin/pg_config') + psql_path = pg_install.paths['psql'] + + +Clusters +======== + +`postgresql.cluster.Cluster` is the class used to manage a PostgreSQL +cluster--a data directory created by initdb_. A Cluster represents a data +directory with respect to a given installation of PostgreSQL, so +creating a `postgresql.cluster.Cluster` object requires a +`postgresql.installation.Installation`, and a +file system path to the data directory. + +In part, a `postgresql.cluster.Cluster` is the Python programmer's variant of +the pg_ctl_ command. However, it goes beyond the basic process control +functionality and extends into initialization and configuration as well. + +A Cluster manages the server process using the `subprocess` module and +signals. The `subprocess.Popen` object, ``Cluster.daemon_process``, is +retained when the Cluster starts the server process itself. This gives +the Cluster access to the result code of server process when it exits, and the +ability to redirect stderr and stdout to a parameterized file object using +subprocess features. + +Despite its use of `subprocess`, Clusters can control a server process +that was *not* started by the Cluster's ``start`` method. + + +Initializing Clusters +--------------------- + +`postgresql.cluster.Cluster` provides a method for initializing a +`Cluster`'s data directory, ``init``. This method provides a Python interface to +the PostgreSQL initdb_ command. + +``init`` is a regular method and accepts a few keyword parameters. Normally, +parameters are directly mapped to initdb_ command options. However, ``password`` +makes use of initdb's capability to read the superuser's password from a file. 
+To do this, a temporary file is allocated internally by the method:: + + from postgresql.installation import Installation, pg_config_dictionary + from postgresql.cluster import Cluster + pg_install = Installation(pg_config_dictionary('/usr/local/pgsql/bin/pg_config')) + pg_cluster = Cluster(pg_install, 'pg_data') + pg_cluster.init(user = 'pg', password = 'secret', encoding = 'utf-8') + +The init method will block until the initdb command is complete. Once +initialized, the Cluster may be configured. + + +Configuring Clusters +-------------------- + +A Cluster's `configuration file`_ can be manipulated using the +`Cluster.settings` mapping. The mapping's methods will always access the +configuration file, so it may be desirable to cache repeat reads. Also, if +multiple settings are being applied, using the ``update()`` method may be +important to avoid writing the entire file multiple times:: + + pg_cluster.settings.update({'listen_addresses' : 'localhost', 'port' : '6543'}) + +Similarly, to avoid opening and reading the entire file multiple times, +`Cluster.settings.getset` should be used to retrieve multiple settings:: + + d = pg_cluster.settings.getset(set(('listen_addresses', 'port'))) + d + {'listen_addresses' : 'localhost', 'port' : '6543'} + +Values contained in ``settings`` are always Python strings:: + + assert pg_cluster.settings['max_connections'].__class__ is str + +The ``postgresql.conf`` file is only one part of the server configuration. +Structured access and manipulation of the pg_hba_ file is not +supported. Clusters only provide the file path to the pg_hba_ file:: + + hba = open(pg_cluster.hba_file) + +If the configuration of the Cluster is altered while the server process is +running, it may be necessary to signal the process that configuration changes +have been made. This signal can be sent using the ``Cluster.reload()`` method. +``Cluster.reload()`` will send a SIGHUP signal to the server process. However, +not all changes to configuration settings can go into effect after calling +``Cluster.reload()``. In those cases, the server process will need to be +shutdown and started again. + + +Controlling Clusters +-------------------- + +The server process of a Cluster object can be controlled with the ``start()``, +``stop()``, ``shutdown()``, ``kill()``, and ``restart()`` methods. +These methods start the server process, signal the server process, or, in the +case of restart, a combination of the two. + +When a Cluster starts the server process, it's ran as a subprocess. Therefore, +if the current process exits, the server process will exit as well. ``start()`` +does *not* automatically daemonize the server process. + +.. note:: Under Microsoft Windows, above does not hold true. The server process + will continue running despite the exit of the parent process. + +To terminate a server process, one of these three methods should be called: +``stop``, ``shutdown``, or ``kill``. ``stop`` is a graceful shutdown and will +*wait for all clients to disconnect* before shutting down. ``shutdown`` will +close any open connections and safely shutdown the server process. +``kill`` will immediately terminate the server process leading to recovery upon +starting the server process again. + +.. note:: Using ``kill`` may cause shared memory to be leaked. + +Normally, `Cluster.shutdown` is the appropriate way to terminate a server +process. 
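+
+Tying the above together, a typical lifecycle looks like the following. This is
+only a sketch; the ``pg_config`` path and the ``pg_data`` directory are
+assumptions::
+
+ from postgresql.installation import Installation, pg_config_dictionary
+ from postgresql.cluster import Cluster
+
+ pg_install = Installation(pg_config_dictionary('/usr/local/pgsql/bin/pg_config'))
+ pg_cluster = Cluster(pg_install, 'pg_data')
+
+ # Initialize the data directory if it does not already look like a cluster.
+ if not pg_cluster.initialized():
+     pg_cluster.init(user = 'pg', password = 'secret', encoding = 'utf-8')
+
+ # Configure before starting; update() writes the file once.
+ pg_cluster.settings.update({'listen_addresses' : 'localhost', 'port' : '6543'})
+
+ pg_cluster.start()
+ pg_cluster.wait_until_started()
+ try:
+     ...  # connect to and use the cluster
+ finally:
+     pg_cluster.shutdown()
+     pg_cluster.wait_until_stopped()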
+ + +Cluster Interface Points +------------------------ + +Methods and properties available on `postgresql.cluster.Cluster` instances: + + ``Cluster(installation, data_directory)`` + Create a `postgresql.cluster.Cluster` object for the specified + `postgresql.installation.Installation`, and ``data_directory``. + + The ``data_directory`` must be an absoluate file system path. The directory + does *not* need to exist. The ``init()`` method may later be used to create + the cluster. + + ``Cluster.installation`` + The Cluster's `postgresql.installation.Installation` instance. + + ``Cluster.data_directory`` + The absolute path to the PostgreSQL data directory. + This directory may not exist. + + ``Cluster.init([encoding = None[, user = None[, password = None]]])`` + Run the `initdb`_ executable of the configured installation to initialize the + cluster at the configured data directory, `Cluster.data_directory`. + + ``encoding`` is mapped to ``-E``, the default database encoding. By default, + the encoding is determined from the environment's locale. + + ``user`` is mapped to ``-U``, the database superuser name. By default, the + current user's name. + + ``password`` is ultimately mapped to ``--pwfile``. The argument given to the + long option is actually a path to the temporary file that holds the given + password. + + Raises `postgresql.cluster.InitDBError` when initdb_ returns a non-zero result + code. + + Raises `postgresql.cluster.ClusterInitializationError` when there is no + initdb_ in the Installation. + + ``Cluster.initialized()`` + Whether or not the data directory exists, *and* if it looks like a PostgreSQL + data directory. Meaning, the directory must contain a ``postgresql.conf`` file + and a ``base`` directory. + + ``Cluster.drop()`` + Shutdown the Cluster's server process and completely remove the + `Cluster.data_directory` from the file system. + + ``Cluster.pid()`` + The server's process identifier as a Python `int`. `None` if there is no + server process running. + This is a method rather than a property as it may read the PID from a file + in cases where the server process was not started by the Cluster. + + ``Cluster.start([logfile = None[, settings = None]])`` + Start the PostgreSQL server process for the Cluster if it is not + already running. This will execute postgres_ as a subprocess. + + If ``logfile``, an opened and writable file object, is given, stderr and + stdout will be redirected to that file. By default, both stderr and stdout are + closed. + + If ``settings`` is given, the mapping or sequence of pairs will be used as + long options to the subprocess. For each item, ``--{key}={value}`` will be + given as an argument to the subprocess. + + ``Cluster.running()`` + Whether or not the cluster's server process is running. Returns `True` or + `False`. Even if `True` is returned, it does *not* mean that the server + process is ready to accept connections. + + ``Cluster.ready_for_connections()`` + Whether or not the Cluster is ready to accept connections. Usually called + after `Cluster.start`. + + Returns `True` when the Cluster can accept connections, `False` when it + cannot, and `None` if the Cluster's server process is not running at all. + + ``Cluster.wait_until_started([timeout = 10[, delay = 0.05]])`` + Blocks the process until the cluster is identified as being ready for + connections. Usually called after ``Cluster.start()``. + + Raises `postgresql.cluster.ClusterNotRunningError` if the server process is + not running at all. 
+ + Raises `postgresql.cluster.ClusterTimeoutError` if + `Cluster.ready_for_connections()` does not return `True` within the given + `timeout` period. + + Raises `postgresql.cluster.ClusterStartupError` if the server process + terminates while polling for readiness. + + ``timeout`` and ``delay`` are both in seconds. Where ``timeout`` is the + maximum time to wait for the Cluster to be ready for connections, and + ``delay`` is the time to sleep between calls to + `Cluster.ready_for_connections()`. + + ``Cluster.stop()`` + Signal the cluster to shutdown when possible. The *server* will wait for all + clients to disconnect before shutting down. + + ``Cluster.shutdown()`` + Signal the cluster to shutdown immediately. Any open client connections will + be closed. + + ``Cluster.kill()`` + Signal the absolute destruction of the server process(SIGKILL). + *This will require recovery when the cluster is started again.* + *Shared memory may be leaked.* + + ``Cluster.wait_until_stopped([timeout = 10[, delay = 0.05]])`` + Blocks the process until the cluster is identified as being shutdown. Usually + called after `Cluster.stop` or `Cluster.shutdown`. + + Raises `postgresql.cluster.ClusterTimeoutError` if + `Cluster.ready_for_connections` does not return `None` within the given + `timeout` period. + + ``Cluster.reload()`` + Signal the server that it should reload its configuration files(SIGHUP). + Usually called after manipulating `Cluster.settings` or modifying the + contents of `Cluster.hba_file`. + + ``Cluster.restart([logfile = None[, settings = None[, timeout = 10]]])`` + Stop the server process, wait until it is stopped, start the server + process, and wait until it has started. + + .. note:: This calls ``Cluster.stop()``, so it will wait until clients + disconnect before starting up again. + + The ``logfile`` and ``settings`` parameters will be given to `Cluster.start`. + ``timeout`` will be given to `Cluster.wait_until_stopped` and + `Cluster.wait_until_started`. + + ``Cluster.settings`` + A `collections.abc.Mapping` interface to the ``postgresql.conf`` file of the + cluster. + + A notable extension to the mapping interface is the ``getset`` method. This + method will return a dictionary object containing the settings whose names + were contained in the `set` object given to the method. + This method should be used when multiple settings need to be retrieved from + the configuration file. + + ``Cluster.hba_file`` + The path to the cluster's pg_hba_ file. This property respects the HBA file + location setting in ``postgresql.conf``. Usually, ``$PGDATA/pg_hba.conf``. + + ``Cluster.daemon_path`` + The path to the executable to use to start the server process. + + ``Cluster.daemon_process`` + The `subprocess.Popen` instance of the server process. `None` if the server + process was not started or was not started using the Cluster object. + + +.. _pg_hba: http://www.postgresql.org/docs/current/static/auth-pg-hba-conf.html +.. _pg_config: http://www.postgresql.org/docs/current/static/app-pgconfig.html +.. _initdb: http://www.postgresql.org/docs/current/static/app-initdb.html +.. _psql: http://www.postgresql.org/docs/current/static/app-psql.html +.. _postgres: http://www.postgresql.org/docs/current/static/app-postgres.html +.. _pg_ctl: http://www.postgresql.org/docs/current/static/app-pg-ctl.html +.. 
_configuration file: http://www.postgresql.org/docs/current/static/runtime-config.html diff --git a/py_opengauss/documentation/sphinx/conf.py b/py_opengauss/documentation/sphinx/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..15c68cff2a7f495a873ee85261d2877ff002efe5 --- /dev/null +++ b/py_opengauss/documentation/sphinx/conf.py @@ -0,0 +1,137 @@ +import sys, os +sys.path.insert(0, os.path.abspath('../../..')) # needed for autodoc. +sys.dont_write_bytecode = True + +# read the project info from the PKG.project module. +mod = {} +with open(os.path.abspath('../../project.py')) as f: + exec(f.read(), mod, mod) + +rst_prolog = "" +rst_epilog = "" + +# General configuration +# --------------------- + +# Add any Sphinx extension module names here, as strings. They can be extensions +# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. +extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode'] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['.templates'] + +# The suffix of source filenames. +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# General substitutions. +copyright = mod['meaculpa'] + +# The default replacements for |version| and |release|, also used in various +# other places throughout the built documents. +# +# The short X.Y version. +version = '.'.join(map(str, mod['version_info'][:2])) +# The full version, including alpha/beta/rc tags. +release = mod['version'] +project = mod['name'] + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +#today = '' +# Else, today_fmt is used as the format for a strftime call. +today_fmt = '%B %d, %Y' + +# List of documents that shouldn't be included in the build. +#unused_docs = [] + +# List of directories, relative to source directories, that shouldn't be searched +# for source files. +#exclude_dirs = [] + +# The reST default role (used for this markup: `text`) to use for all documents. +#default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +#add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +#add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +#show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + + +# Options for HTML output +# ----------------------- + +# The style sheet to use for HTML and HTML Help pages. A file of that name +# must exist either in Sphinx' static/ path, or in one of the custom paths +# given in html_static_path. +html_style = 'default.css' + +# The name for this set of Sphinx documents. If None, it defaults to +# " v documentation". +#html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +#html_short_title = None + +# The name of an image file (within the static path) to place at the top of +# the sidebar. +#html_logo = None + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +#html_favicon = None + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. 
They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['.static'] + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +#html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names. +#html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +#html_additional_pages = {} + +# If false, no module index is generated. +#html_use_modindex = True + +# If false, no index is generated. +#html_use_index = True + +# If true, the index is split into individual pages for each letter. +#html_split_index = False + +# If true, the reST sources are included in the HTML build as _sources/. +#html_copy_source = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a tag referring to it. The value of this option must be the +# base URL from which the finished HTML is served. +#html_use_opensearch = '' + +# If nonempty, this is the file name suffix for HTML files (e.g. ".xhtml"). +#html_file_suffix = '' + +# Output file base name for HTML help builder. +htmlhelp_basename = project diff --git a/py_opengauss/documentation/sphinx/copyman.rst b/py_opengauss/documentation/sphinx/copyman.rst new file mode 100644 index 0000000000000000000000000000000000000000..d4a18cb16bcc8ee4e911498117e1d4ad6491f1fe --- /dev/null +++ b/py_opengauss/documentation/sphinx/copyman.rst @@ -0,0 +1,317 @@ +.. _pg_copyman: + +*************** +Copy Management +*************** + +The `postgresql.copyman` module provides a way to quickly move COPY data coming +from one connection to many connections. Alternatively, it can be sourced +by arbitrary iterators and target arbitrary callables. + +Statement execution methods offer a way for running COPY operations +with iterators, but the cost of allocating objects for each row is too +significant for transferring gigabytes of COPY data from one connection to +another. The interfaces available on statement objects are primarily intended to +be used when transferring COPY data to and from arbitrary Python +objects. + +Direct connection-to-connection COPY operations can be performed using the +high-level `postgresql.copyman.transfer` function:: + + >>> from postgresql import copyman + >>> send_stmt = source.prepare("COPY (SELECT i FROM generate_series(1, 1000000) AS g(i)) TO STDOUT") + >>> destination.execute("CREATE TEMP TABLE loading_table (i int8)") + >>> receive_stmt = destination.prepare("COPY loading_table FROM STDIN") + >>> total_rows, total_bytes = copyman.transfer(send_stmt, receive_stmt) + +However, if more control is needed, the `postgresql.copyman.CopyManager` class +should be used directly. + + +Copy Managers +============= + +The `postgresql.copyman.CopyManager` class manages the Producer and the +Receivers involved in a COPY operation. Normally, +`postgresql.copyman.StatementProducer` and +`postgresql.copyman.StatementReceiver` instances. Naturally, a Producer is the +object that produces the COPY data to be given to the Manager's Receivers. + +Using a Manager directly means that there is a need for more control over +the operation. The Manager is both a context manager and an iterator. 
The +context manager interfaces handle initialization and finalization of the COPY +state, and the iterator provides an event loop emitting information about the +amount of COPY data transferred this cycle. Normal usage takes the form:: + + >>> from postgresql import copyman + >>> send_stmt = source.prepare("COPY (SELECT i FROM generate_series(1, 1000000) AS g(i)) TO STDOUT") + >>> destination.execute("CREATE TEMP TABLE loading_table (i int8)") + >>> receive_stmt = destination.prepare("COPY loading_table FROM STDIN") + >>> producer = copyman.StatementProducer(send_stmt) + >>> receiver = copyman.StatementReceiver(receive_stmt) + >>> + >>> with source.xact(), destination.xact(): + ... with copyman.CopyManager(producer, receiver) as copy: + ... for num_messages, num_bytes in copy: + ... update_rate(num_bytes) + +As an alternative to a for-loop inside a with-statement block, the `run` method +can be called to perform the operation:: + + >>> with source.xact(), destination.xact(): + ... copyman.CopyManager(producer, receiver).run() + +However, there is little benefit beyond using the high-level +`postgresql.copyman.transfer` function. + +Manager Interface Points +------------------------ + +Primarily, the `postgresql.copyman.CopyManager` provides a context manager and +an iterator for controlling the COPY operation. + + ``CopyManager.run()`` + Perform the entire COPY operation. + + ``CopyManager.__enter__()`` + Start the COPY operation. Connections taking part in the COPY should **not** + be used until ``__exit__`` is ran. + + ``CopyManager.__exit__(typ, val, tb)`` + Finish the COPY operation. Fails in the case of an incomplete + COPY, or an untrapped exception. Either returns `None` or raises the generalized + exception, `postgresql.copyman.CopyFail`. + + ``CopyManager.__iter__()`` + Returns the CopyManager instance. + + ``CopyManager.__next__()`` + Transfer the next chunk of COPY data to the receivers. Yields a tuple + consisting of the number of messages and bytes transferred, + ``(num_messages, num_bytes)``. Raises `StopIteration` when complete. + + Raises `postgresql.copyman.ReceiverFault` when a Receiver raises an + exception. + Raises `postgresql.copyman.ProducerFault` when the Producer raises an + exception. The original exception is available via the exception's + ``__context__`` attribute. + + ``CopyManager.reconcile(faulted_receiver)`` + Reconcile a faulted receiver. When a receiver faults, it will no longer + be in the set of Receivers. This method is used to signal to the manager that the + problem has been corrected, and the receiver is again ready to receive. + + ``CopyManager.receivers`` + The `builtins.set` of Receivers involved in the COPY operation. + + ``CopyManager.producer`` + The Producer emitting the data to be given to the Receivers. + + +Faults +====== + +The CopyManager generalizes any exceptions that occur during transfer. While +inside the context manager, `postgresql.copyman.Fault` may be raised if a +Receiver or a Producer raises an exception. A `postgresql.copyman.ProducerFault` +in the case of the Producer, and `postgresql.copyman.ReceiverFault` in the case +of the Receivers. + +.. note:: + Faults are only raised by `postgresql.copyman.CopyManager.__next__`. The + ``run()`` method will only raise `postgresql.copyman.CopyFail`. + +Receiver Faults +--------------- + +The Manager assumes the Fault is fatal to a Receiver, and immediately removes +it from the set of target receivers. Additionally, if the Fault exception goes +untrapped, the copy will ultimately fail. 
+ +The Fault exception references the Manager that raised the exception, and the +actual exceptions that occurred associated with the Receiver that caused them. + +In order to identify the exception that caused a Fault, the ``faults`` attribute +on the `postgresql.copyman.ReceiverFault` must be referenced:: + + >>> from postgresql import copyman + >>> send_stmt = source.prepare("COPY (SELECT i FROM generate_series(1, 1000000) AS g(i)) TO STDOUT") + >>> destination.execute("CREATE TEMP TABLE loading_table (i int8)") + >>> receive_stmt = destination.prepare("COPY loading_table FROM STDIN") + >>> producer = copyman.StatementProducer(send_stmt) + >>> receiver = copyman.StatementReceiver(receive_stmt) + >>> + >>> with source.xact(), destination.xact(): + ... with copyman.CopyManager(producer, receiver) as copy: + ... while copy.receivers: + ... try: + ... for num_messages, num_bytes in copy: + ... update_rate(num_bytes) + ... break + ... except copyman.ReceiverFault as cf: + ... # Access the original exception using the receiver as the key. + ... original_exception = cf.faults[receiver] + ... if unknown_failure(original_exception): + ... ... + ... raise + + +ReceiverFault Properties +~~~~~~~~~~~~~~~~~~~~~~~~ + +The following attributes exist on `postgresql.copyman.ReceiverFault` instances: + + ``ReceiverFault.manager`` + The subject `postgresql.copyman.CopyManager` instance. + + ``ReceiverFault.faults`` + A dictionary mapping the Receiver to the exception raised by that Receiver. + + +Reconciliation +~~~~~~~~~~~~~~ + +When a `postgresql.copyman.ReceiverFault` is raised, the Manager immediately +removes the Receiver so that the COPY operation can continue. Continuation of +the COPY can occur by trapping the exception and continuing the iteration of the +Manager. However, if the fault is recoverable, the +`postgresql.copyman.CopyManager.reconcile` method must be used to reintroduce the +Receiver into the Manager's set. Faults must be trapped from within the +Manager's context:: + + >>> import socket + >>> from postgresql import copyman + >>> send_stmt = source.prepare("COPY (SELECT i FROM generate_series(1, 1000000) AS g(i)) TO STDOUT") + >>> destination.execute("CREATE TEMP TABLE loading_table (i int8)") + >>> receive_stmt = destination.prepare("COPY loading_table FROM STDIN") + >>> producer = copyman.StatementProducer(send_stmt) + >>> receiver = copyman.StatementReceiver(receive_stmt) + >>> + >>> with source.xact(), destination.xact(): + ... with copyman.CopyManager(producer, receiver) as copy: + ... while copy.receivers: + ... try: + ... for num_messages, num_bytes in copy: + ... update_rate(num_bytes) + ... except copyman.ReceiverFault as cf: + ... if isinstance(cf.faults[receiver], socket.timeout): + ... copy.reconcile(receiver) + ... else: + ... raise + +Recovering from Faults does add significant complexity to a COPY operation, +so, often, it's best to avoid conditions in which reconciliable Faults may +occur. + + +Producer Faults +--------------- + +Producer faults are normally fatal to the COPY operation and should rarely be +trapped. However, the Manager makes no state changes when a Producer faults, +so, unlike Receiver Faults, no reconciliation process is necessary; rather, +if it's safe to continue, the Manager's iterator should continue to be +processed. + +ProducerFault Properties +~~~~~~~~~~~~~~~~~~~~~~~~ + +The following attributes exist on `postgresql.copyman.ProducerFault` instances: + + ``ReceiverFault.manager`` + The subject `postgresql.copyman.CopyManager`. 
+ + ``ReceiverFault.__context__`` + The original exception raised by the Producer. + + +Failures +======== + +When a COPY operation is aborted, either by an exception or by the iterator +being broken, a `postgresql.copyman.CopyFail` exception will be raised by the +Manager's ``__exit__()`` method. The `postgresql.copyman.CopyFail` exception +offers to record any exceptions that occur during the exit of the context +managers of the Producer and the Receivers. + + +CopyFail Properties +------------------- + +The following properties exist on `postgresql.copyman.CopyFail` exceptions: + + ``CopyFail.manager`` + The Manager whose COPY operation failed. + + ``CopyFail.receiver_faults`` + A dictionary mapping a `postgresql.copyman.Receiver` to the exception raised + by that Receiver's ``__exit__``. `None` if no exceptions were raised by the + Receivers. + + ``CopyFail.producer_fault`` + The exception Raised by the `postgresql.copyman.Producer`. `None` if none. + + +Producers +========= + +The following Producers are available: + + ``postgresql.copyman.StatementProducer(postgresql.api.Statement)`` + Given a Statement producing COPY data, construct a Producer. + + ``postgresql.copyman.IteratorProducer(collections.abc.Iterator)`` + Given an Iterator producing *chunks* of COPY lines, construct a Producer to + manage the data coming from the iterator. + + +Receivers +========= + + ``postgresql.copyman.StatementReceiver(postgresql.api.Statement)`` + Given a Statement producing COPY data, construct a Producer. + + ``postgresql.copyman.CallReceiver(callable)`` + Given a callable, construct a Receiver that will transmit COPY data in chunks + of lines. That is, the callable will be given a list of COPY lines for each + transfer cycle. + + +Terminology +=========== + +The following terms are regularly used to describe the implementation and +processes of the `postgresql.copyman` module: + + Manager + The object used to manage data coming from a Producer and being given to the + Receivers. It also manages the necessary initialization and finalization steps + required by those factors. + + Producer + The object used to produce the COPY data to be given to the Receivers. The + source. + + Receiver + An object that consumes COPY data. A target. + + Fault + Specifically, `postgresql.copyman.Fault` exceptions. A Fault is raised + when a Receiver or a Producer raises an exception during the COPY operation. + + Reconciliation + Generally, the steps performed by the "reconcile" method on + `postgresql.copyman.CopyManager` instances. More precisely, the + necessary steps for a Receiver's reintroduction into the COPY operation after + a Fault. + + Failed Copy + A failed copy is an aborted COPY operation. This occurs in situations of + untrapped exceptions or an incomplete COPY. Specifically, the COPY will be + noted as failed in cases where the Manager's iterator is *not* ran until + exhaustion. + + Realignment + The process of providing compensating data to the Receivers so that the + connection will be on a message boundary. Occurs when the COPY operation + is aborted. diff --git a/py_opengauss/documentation/sphinx/driver.rst b/py_opengauss/documentation/sphinx/driver.rst new file mode 100644 index 0000000000000000000000000000000000000000..aaebde26a48be1c41f1b2ce566954402105ab5d8 --- /dev/null +++ b/py_opengauss/documentation/sphinx/driver.rst @@ -0,0 +1,1806 @@ +.. 
_db_interface: + +****** +Driver +****** + +`postgresql.driver` provides a PG-API, `postgresql.api`, interface to a +PostgreSQL server using PQ version 3.0 to facilitate communication. It makes +use of the protocol's extended features to provide binary datatype transmission +and protocol level prepared statements for strongly typed parameters. + +`postgresql.driver` currently supports PostgreSQL servers as far back as 8.0. +Prior versions are not tested. While any version of PostgreSQL supporting +version 3.0 of the PQ protocol *should* work, many features may not work due to +absent functionality in the remote end. + +For DB-API 2.0 users, the driver module is located at +`postgresql.driver.dbapi20`. The DB-API 2.0 interface extends PG-API. All of the +features discussed in this chapter are available on DB-API connections. + +.. warning:: + PostgreSQL versions 8.1 and earlier do not support standard conforming + strings. In order to avoid subjective escape methods on connections, + `postgresql.driver.pq3` enables the ``standard_conforming_strings`` setting + by default. Greater care must be taken when working versions that do not + support standard strings. + **The majority of issues surrounding the interpolation of properly quoted literals can be easily avoided by using parameterized statements**. + +The following identifiers are regularly used as shorthands for significant +interface elements: + + ``db`` + `postgresql.api.Connection`, a database connection. `Connections`_ + + ``ps`` + `postgresql.api.Statement`, a prepared statement. `Prepared Statements`_ + + ``c`` + `postgresql.api.Cursor`, a cursor; the results of a prepared statement. + `Cursors`_ + + ``C`` + `postgresql.api.Connector`, a connector. `Connectors`_ + + +Establishing a Connection +========================= + +There are many ways to establish a `postgresql.api.Connection` to a +PostgreSQL server using `postgresql.driver`. This section discusses those, +connection creation, interfaces. + + +`postgresql.open` +----------------- + +In the root package module, the ``open()`` function is provided for accessing +databases using a locator string and optional connection keywords. The string +taken by `postgresql.open` is a URL whose components make up the client +parameters:: + + >>> db = postgresql.open("pq://localhost/postgres") + +This will connect to the host, ``localhost`` and to the database named +``postgres`` via the ``pq`` protocol. open will inherit client parameters from +the environment, so the user name given to the server will come from +``$PGUSER``, or if that is unset, the result of `getpass.getuser`--the username +of the user running the process. The user's "pgpassfile" will even be +referenced if no password is given:: + + >>> db = postgresql.open("pq://username:password@localhost/postgres") + +In this case, the password *is* given, so ``~/.pgpass`` would never be +referenced. The ``user`` client parameter is also given, ``username``, so +``$PGUSER`` or `getpass.getuser` will not be given to the server. + +Settings can also be provided by the query portion of the URL:: + + >>> db = postgresql.open("pq://user@localhost/postgres?search_path=public&timezone=mst") + +The above syntax ultimately passes the query as settings(see the description of +the ``settings`` keyword in `Connection Keywords`). Driver parameters require a +distinction. 
This distinction is made when the setting's name is wrapped in
+square-brackets, '[' and ']':
+
+    >>> db = postgresql.open("pq://user@localhost/postgres?[sslmode]=require&[connect_timeout]=5")
+
+``sslmode`` and ``connect_timeout`` are driver parameters. These are never sent
+to the server, but if they were not in square-brackets, they would be, and the
+driver would never identify them as driver parameters.
+
+The general structure of a PQ-locator is::
+
+    protocol://user:password@host:port/database?[driver_setting]=value&server_setting=value
+
+Optionally, connection keyword arguments can be used to override anything given
+in the locator::
+
+    >>> db = postgresql.open("pq://user:secret@host", password = "thE_real_sekrat")
+
+Or, if the locator is not desired, individual keywords can be used exclusively::
+
+    >>> db = postgresql.open(user = 'user', host = 'localhost', port = 6543)
+
+In fact, all arguments to `postgresql.open` are optional as all arguments are
+keywords; ``iri`` is merely the first keyword argument taken by
+`postgresql.open`. If the environment has all the necessary parameters for a
+successful connection, there is no need to pass anything to open::
+
+    >>> db = postgresql.open()
+
+For a complete list of keywords that `postgresql.open` can accept, see
+`Connection Keywords`_.
+For more information about the environment variables, see :ref:`pg_envvars`.
+For more information about the ``pgpassfile``, see :ref:`pg_passfile`.
+
+`postgresql.driver.connect`
+---------------------------
+
+`postgresql.open` is a high-level interface to connection creation. It provides
+password resolution services and client parameter inheritance. For some
+applications, this is undesirable as such implicit inheritance may lead to
+failures due to unanticipated parameters being used. For those applications,
+use of `postgresql.open` is not recommended. Rather, `postgresql.driver.connect`
+should be used when explicit parameterization is desired by an application:
+
+    >>> import postgresql.driver as pg_driver
+    >>> db = pg_driver.connect(
+    ...  user = 'usename',
+    ...  password = 'secret',
+    ...  host = 'localhost',
+    ...  port = 5432
+    ... )
+
+This will create a connection to the server listening on port ``5432``
+on the host ``localhost`` as the user ``usename`` with the password ``secret``.
+
+.. note::
+   `connect` will *not* inherit parameters from the environment as libpq-based drivers do.
+
+See `Connection Keywords`_ for a full list of acceptable keyword parameters and
+their meaning.
+
+
+Connectors
+----------
+
+Connectors are the supporting objects used to instantiate a connection. They
+exist for the purpose of providing connections with the necessary abstractions
+for facilitating the client's communication with the server, *and to act as a
+container for the client parameters*. The latter purpose is of primary interest
+to this section.
+
+Each connection object is associated with its connector by the ``connector``
+attribute on the connection.
This provides the user with access to the +parameters used to establish the connection in the first place, and the means to +create another connection to the same server. The attributes on the connector +should *not* be altered. If parameter changes are needed, a new connector should +be created. + +The attributes available on a connector are consistent with the names of the +connection parameters described in `Connection Keywords`_, so that list can be +used as a reference to identify the information available on the connector. + +Connectors fit into the category of "connection creation interfaces", so +connector instantiation normally takes the same parameters that the +`postgresql.driver.connect` function takes. + +.. note:: + Connector implementations are specific to the transport, so keyword arguments + like ``host`` and ``port`` aren't supported by the ``Unix`` connector. + +The driver, `postgresql.driver.default` provides a set of connectors for making +a connection: + + ``postgresql.driver.default.host(...)`` + Provides a ``getaddrinfo()`` abstraction for establishing a connection. + + ``postgresql.driver.default.ip4(...)`` + Connect to a single IPv4 addressed host. + + ``postgresql.driver.default.ip6(...)`` + Connect to a single IPv6 addressed host. + + ``postgresql.driver.default.unix(...)`` + Connect to a single unix domain socket. Requires the ``unix`` keyword which + must be an absolute path to the unix domain socket to connect to. + +``host`` is the usual connector used to establish a connection:: + + >>> C = postgresql.driver.default.host( + ... user = 'auser', + ... host = 'foo.com', + ... port = 5432) + >>> # create + >>> db = C() + >>> # establish + >>> db.connect() + +If a constant internet address is used, ``ip4`` or ``ip6`` can be used:: + + >>> C = postgresql.driver.default.ip4(user='auser', host='127.0.0.1', port=5432) + >>> db = C() + >>> db.connect() + +Additionally, ``db.connect()`` on ``db.__enter__()`` for with-statement support: + + >>> with C() as db: + ... ... + +Connectors are constant. They have no knowledge of PostgreSQL service files, +environment variables or LDAP services, so changes made to those facilities +will *not* be reflected in a connector's configuration. If the latest +information from any of these sources is needed, a new connector needs to be +created as the credentials have changed. + +.. note:: + ``host`` connectors use ``getaddrinfo()``, so if DNS changes are made, + new connections *will* use the latest information. + + +Connection Keywords +------------------- + +The following is a list of keywords accepted by connection creation +interfaces: + + ``user`` + The user to connect as. + + ``password`` + The user's password. + + ``database`` + The name of the database to connect to. (PostgreSQL defaults it to `user`) + + ``host`` + The hostname or IP address to connect to. + + ``port`` + The port on the host to connect to. + + ``unix`` + The unix domain socket to connect to. Exclusive with ``host`` and ``port``. + Expects a string containing the *absolute path* to the unix domain socket to + connect to. + + ``settings`` + A dictionary or key-value pair sequence stating the parameters to give to the + database. These settings are included in the startup packet, and should be + used carefully as when an invalid setting is given, it will cause the + connection to fail. + + ``connect_timeout`` + Amount of time to wait for a connection to be made. 
(in seconds) + + ``server_encoding`` + Hint given to the driver to properly encode password data and some information + in the startup packet. + This should only be used in cases where connections cannot be made due to + authentication failures that occur while using known-correct credentials. + + ``sslmode`` + ``'disable'`` + Don't allow SSL connections. + ``'allow'`` + Try without SSL first, but if that doesn't work, try with. + ``'prefer'`` + Try SSL first, then without. + ``'require'`` + Require an SSL connection. + + ``sslcrtfile`` + Certificate file path given to `ssl.wrap_socket`. + + ``sslkeyfile`` + Key file path given to `ssl.wrap_socket`. + + ``sslrootcrtfile`` + Root certificate file path given to `ssl.wrap_socket` + + ``sslrootcrlfile`` + Revocation list file path. [Currently not checked.] + + +Connections +=========== + +`postgresql.open` and `postgresql.driver.connect` provide the means to +establish a connection. Connections provide a `postgresql.api.Database` +interface to a PostgreSQL server; specifically, a `postgresql.api.Connection`. + +Connections are one-time objects. Once, it is closed or lost, it can longer be +used to interact with the database provided by the server. If further use of the +server is desired, a new connection *must* be established. + +.. note:: + Cannot connect failures, exceptions raised on ``connect()``, are also terminal. + +In cases where operations are performed on a closed connection, a +`postgresql.exceptions.ConnectionDoesNotExistError` will be raised. + + +Database Interface Points +------------------------- + +After a connection is established:: + + >>> import postgresql + >>> db = postgresql.open(...) + +The methods and properties on the connection object are ready for use: + + ``Connection.prepare(sql_statement_string)`` + Create a `postgresql.api.Statement` object for querying the database. + This provides an "SQL statement template" that can be executed multiple times. + See `Prepared Statements`_ for more information. + + ``Connection.proc(procedure_id)`` + Create a `postgresql.api.StoredProcedure` object referring to a stored + procedure on the database. The returned object will provide a + `collections.abc.Callable` interface to the stored procedure on the server. See + `Stored Procedures`_ for more information. + + ``Connection.statement_from_id(statement_id)`` + Create a `postgresql.api.Statement` object from an existing statement + identifier. This is used in cases where the statement was prepared on the + server. See `Prepared Statements`_ for more information. + + ``Connection.cursor_from_id(cursor_id)`` + Create a `postgresql.api.Cursor` object from an existing cursor identifier. + This is used in cases where the cursor was declared on the server. See + `Cursors`_ for more information. + + ``Connection.do(language, source)`` + Execute a DO statement on the server using the specified language. + *DO statements are available on PostgreSQL 9.0 and greater.* + *Executing this method on servers that do not support DO statements will* + *likely cause a SyntaxError*. + + ``Connection.execute(sql_statements_string)`` + Run a block of SQL on the server. This method returns `None` unless an error + occurs. If errors occur, the processing of the statements will stop and the + error will be raised. + + ``Connection.xact(isolation = None, mode = None)`` + The `postgresql.api.Transaction` constructor for creating transactions. + This method creates a transaction reference. 
The transaction will not be + started until it's instructed to do so. See `Transactions`_ for more + information. + + ``Connection.settings`` + A property providing a `collections.abc.MutableMapping` interface to the + database's SQL settings. See `Settings`_ for more information. + + ``Connection.clone()`` + Create a new connection object based on the same factors that were used to + create ``db``. The new connection returned will already be connected. + + ``Connection.msghook(msg)`` + By default, the `msghook` attribute does not exist. If set to a callable, any + message that occurs during an operation of the database or an operation of a + database derived object will be given to the callable. See the + `Database Messages`_ section for more information. + + ``Connection.listen(*channels)`` + Start listening for asynchronous notifications in the specified channels. + Sends a batch of ``LISTEN`` statements to the server. + + ``Connection.unlisten(*channels)`` + Stop listening for asynchronous notifications in the specified channels. + Sends a batch of ``UNLISTEN`` statements to the server. + + ``Connection.listening_channels()`` + Return an iterator producing the channel names that are currently being + listened to. + + ``Connection.notify(*channels, **channel_and_payload)`` + NOTIFY the channels with the given payload. Sends a batch of ``NOTIFY`` + statements to the server. + + Equivalent to issuing "NOTIFY " or "NOTIFY , " + for each item in `channels` and `channel_and_payload`. All NOTIFYs issued + will occur in the same transaction, regardless of auto-commit. + + The items in `channels` can either be a string or a tuple. If a string, + no payload is given, but if an item is a `builtins.tuple`, the second item + in the pair will be given as the payload, and the first as the channel. + `channels` offers a means to issue NOTIFYs in guaranteed order:: + + >>> db.notify('channel1', ('different_channel', 'payload')) + + In the above, ``NOTIFY "channel1";`` will be issued first, followed by + ``NOTIFY "different_channel", 'payload';``. + + The items in `channel_and_payload` are all payloaded NOTIFYs where the + keys are the channels and the values are the payloads. Order is undefined:: + + >>> db.notify(channel_name = 'payload_data') + + `channels` and `channels_and_payload` can be used together. In such cases all + NOTIFY statements generated from `channels_and_payload` will follow those in + `channels`. + + ``Connection.iternotifies(timeout = None)`` + Return an iterator to the NOTIFYs received on the connection. The iterator + will yield notification triples consisting of ``(channel, payload, pid)``. + While iterating, the connection should *not* be used in other threads. + The optional timeout can be used to enable "idle" events in which `None` + objects will be yielded by the iterator. + See :ref:`notifyman` for details. + +When a connection is established, certain pieces of information are collected from +the backend. The following are the attributes set on the connection object after +the connection is made: + + ``Connection.version`` + The version string of the *server*; the result of ``SELECT version()``. + + ``Connection.version_info`` + A ``sys.version_info`` form of the ``server_version`` setting. eg. + ``(8, 1, 2, 'final', 0)``. + + ``Connection.security`` + `None` if no security. ``'ssl'`` if SSL is enabled. + + ``Connection.backend_id`` + The process-id of the backend process. + + ``Connection.backend_start`` + When backend was started. ``datetime.datetime`` instance. 
+ + ``Connection.client_address`` + The address of the client that the backend is communicating with. + + ``Connection.client_port`` + The port of the client that the backend is communicating with. + + ``Connection.fileno()`` + Method to get the file descriptor number of the connection's socket. This + method will return `None` if the socket object does not have a ``fileno``. + Under normal circumstances, it will return an `int`. + +The ``backend_start``, ``client_address``, and ``client_port`` are collected +from pg_stat_activity. If this information is unavailable, the attributes will +be `None`. + + +Prepared Statements +=================== + +Prepared statements are the primary entry point for initiating an operation on +the database. Prepared statement objects represent a request that will, likely, +be sent to the database at some point in the future. A statement is a single +SQL command. + +The ``prepare`` entry point on the connection provides the standard method for +creating a `postgersql.api.Statement` instance bound to the +connection(``db``) from an SQL statement string:: + + >>> ps = db.prepare("SELECT 1") + >>> ps() + [(1,)] + +Statement objects may also be created from a statement identifier using the +``statement_from_id`` method on the connection. When this method is used, the +statement must have already been prepared or an error will be raised. + + >>> db.execute("PREPARE a_statement_id AS SELECT 1;") + >>> ps = db.statement_from_id('a_statement_id') + >>> ps() + [(1,)] + +When a statement is executed, it binds any given parameters to a *new* cursor +and the entire result-set is returned. + +Statements created using ``prepare()`` will leverage garbage collection in order +to automatically close statements that are no longer referenced. However, +statements created from pre-existing identifiers, ``statement_from_id``, must +be explicitly closed if the statement is to be discarded. + +Statement objects are one-time objects. Once closed, they can no longer be used. + + +Statement Interface Points +-------------------------- + +Prepared statements can be executed just like functions: + + >>> ps = db.prepare("SELECT 'hello, world!'") + >>> ps() + [('hello, world!',)] + +The default execution method, ``__call__``, produces the entire result set. It +is the simplest form of statement execution. Statement objects can be executed in +different ways to accommodate for the larger results or random access(scrollable +cursors). + +Prepared statement objects have a few execution methods: + + ``Statement(*parameters)`` + As shown before, statement objects can be invoked like a function to get + the statement's results. + + ``Statement.rows(*parameters)`` + Return a iterator to all the rows produced by the statement. This + method will stream rows on demand, so it is ideal for situations where + each individual row in a large result-set must be processed. + + ``iter(Statement)`` + Convenience interface that executes the ``rows()`` method without arguments. + This enables the following syntax: + + >>> for table_name, in db.prepare("SELECT table_name FROM information_schema.tables"): + ... print(table_name) + + ``Statement.column(*parameters)`` + Return a iterator to the first column produced by the statement. This + method will stream values on demand, and *should* only be used with statements + that have a single column; otherwise, bandwidth will ultimately be wasted as + the other columns will be dropped. 
+ *This execution method cannot be used with COPY statements.* + + ``Statement.first(*parameters)`` + For simple statements, cursor objects are unnecessary. + Consider the data contained in ``c`` from above, 'hello world!'. To get at this + data directly from the ``__call__(...)`` method, it looks something like:: + + >>> ps = db.prepare("SELECT 'hello, world!'") + >>> ps()[0][0] + 'hello, world!' + + To simplify access to simple data, the ``first`` method will simply return + the "first" of the result set:: + + >>> ps.first() + 'hello, world!' + + The first value. + When the result set consists of a single column, ``first()`` will return + that column in the first row. + + The first row. + When the result set consists of multiple columns, ``first()`` will return + that first row. + + The first, and only, row count. + When DML--for instance, an INSERT-statement--is executed, ``first()`` will + return the row count returned by the statement as an integer. + + .. note:: + DML that returns row data, RETURNING, will *not* return a row count. + + The result set created by the statement determines what is actually returned. + Naturally, a statement used with ``first()`` should be crafted with these + rules in mind. + + ``Statement.chunks(*parameters)`` + This access point is designed for situations where rows are being streamed out + quickly. It is a method that returns a ``collections.abc.Iterator`` that produces + *sequences* of rows. This is the most efficient way to get rows from the + database. The rows in the sequences are ``builtins.tuple`` objects. + + ``Statement.declare(*parameters)`` + Create a scrollable cursor with hold. This returns a `postgresql.api.Cursor` + ready for accessing random rows in the result-set. Applications that use the + database to support paging can use this method to manage the view. + + ``Statement.close()`` + Close the statement inhibiting further use. + + ``Statement.load_rows(collections.abc.Iterable(parameters))`` + Given an iterable producing parameters, execute the statement for each + iteration. Always returns `None`. + + ``Statement.load_chunks(collections.abc.Iterable(collections.abc.Iterable(parameters)))`` + Given an iterable of iterables producing parameters, execute the statement + for each parameter produced. However, send the all execution commands with + the corresponding parameters of each chunk before reading any results. + Always returns `None`. This access point is designed to be used in conjunction + with ``Statement.chunks()`` for transferring rows from one connection to another with + great efficiency:: + + >>> dst.prepare(...).load_chunks(src.prepare(...).chunks()) + + ``Statement.clone()`` + Create a new statement object based on the same factors that were used to + create ``ps``. + + ``Statement.msghook(msg)`` + By default, the `msghook` attribute does not exist. If set to a callable, any + message that occurs during an operation of the statement or an operation of a + statement derived object will be given to the callable. See the + `Database Messages`_ section for more information. + +In order to provide the appropriate type transformations, the driver must +acquire metadata about the statement's parameters and results. This data is +published via the following properties on the statement object: + + ``Statement.sql_parameter_types`` + A sequence of SQL type names specifying the types of the parameters used in + the statement. 
+
+ ``Statement.sql_column_types``
+  A sequence of SQL type names specifying the types of the columns produced by
+  the statement. `None` if the statement does not return row-data.
+
+ ``Statement.pg_parameter_types``
+  A sequence of PostgreSQL type Oid's specifying the types of the parameters
+  used in the statement.
+
+ ``Statement.pg_column_types``
+  A sequence of PostgreSQL type Oid's specifying the types of the columns produced by
+  the statement. `None` if the statement does not return row-data.
+
+ ``Statement.parameter_types``
+  A sequence of Python types that the statement expects.
+
+ ``Statement.column_types``
+  A sequence of Python types that the statement will produce.
+
+ ``Statement.column_names``
+  A sequence of `str` objects specifying the names of the columns produced by
+  the statement. `None` if the statement does not return row-data.
+
+The indexes of the parameter sequences correspond to the parameter's
+identifier, N+1: ``sql_parameter_types[0]`` -> ``'$1'``.
+
+    >>> ps = db.prepare("SELECT $1::integer AS intname, $2::varchar AS chardata")
+    >>> ps.sql_parameter_types
+    ('INTEGER','VARCHAR')
+    >>> ps.sql_column_types
+    ('INTEGER','VARCHAR')
+    >>> ps.column_names
+    ('intname','chardata')
+    >>> ps.column_types
+    (<class 'int'>, <class 'str'>)
+
+
+Parameterized Statements
+------------------------
+
+Statements can take parameters. Using statement parameters is the recommended
+way to interrogate the database when variable information is needed to formulate
+a complete request. In order to do this, the statement must be defined using
+PostgreSQL's positional parameter notation, ``$1``, ``$2``, ``$3``, etc.::
+
+    >>> ps = db.prepare("SELECT $1")
+    >>> ps('hello, world!')[0][0]
+    'hello, world!'
+
+PostgreSQL determines the type of the parameter based on the context of the
+parameter's identifier::
+
+    >>> ps = db.prepare(
+    ...  "SELECT * FROM information_schema.tables WHERE table_name = $1 LIMIT $2"
+    ... )
+    >>> ps("tables", 1)
+    [('postgres', 'information_schema', 'tables', 'VIEW', None, None, None, None, None, 'NO', 'NO', None)]
+
+Parameter ``$1`` in the above statement will take on the type of the
+``table_name`` column and ``$2`` will take on the type required by the LIMIT
+clause (text and int8).
+
+However, parameters can be forced to a specific type using explicit casts:
+
+    >>> ps = db.prepare("SELECT $1::integer")
+    >>> ps.first(-400)
+    -400
+
+Parameters are typed. PostgreSQL servers provide the driver with the
+type information about a positional parameter, and the serialization routine
+will raise an exception if the given object is inappropriate. The Python
+types expected by the driver for a given SQL-or-PostgreSQL type are listed
+in `Type Support`_.
+
+This usage of types is not always convenient. Notably, the `datetime` module
+does not provide a friendly way for a user to express intervals, dates, or
+times. There is a likely inclination to forego these parameter type
+requirements.
+
+In such cases, explicit casts can be made to work around the type
+requirements::
+
+    >>> ps = db.prepare("SELECT $1::text::date")
+    >>> ps.first('yesterday')
+    datetime.date(2009, 3, 11)
+
+The parameter, ``$1``, is given to the database as a string, which is then
+promptly cast into a date. Of course, without the explicit cast as text, the
+outcome would be different::
+
+    >>> ps = db.prepare("SELECT $1::date")
+    >>> ps.first('yesterday')
+    Traceback:
+     ...
+ postgresql.exceptions.ParameterError + +The function that processes the parameter expects a `datetime.date` object, and +the given `str` object does not provide the necessary interfaces for the +conversion, so the driver raises a `postgresql.exceptions.ParameterError` from +the original conversion exception. + + +Inserting and DML +----------------- + +Loading data into the database is facilitated by prepared statements. In these +examples, a table definition is necessary for a complete illustration:: + + >>> db.execute( + ... """ + ... CREATE TABLE employee ( + ... employee_name text, + ... employee_salary numeric, + ... employee_dob date, + ... employee_hire_date date + ... ); + ... """ + ... ) + +Create an INSERT statement using ``prepare``:: + + >>> mkemp = db.prepare("INSERT INTO employee VALUES ($1, $2, $3, $4)") + +And add "Mr. Johnson" to the table:: + + >>> import datetime + >>> r = mkemp( + ... "John Johnson", + ... "92000", + ... datetime.date(1950, 12, 10), + ... datetime.date(1998, 4, 23) + ... ) + >>> print(r[0]) + INSERT + >>> print(r[1]) + 1 + +The execution of DML will return a tuple. This tuple contains the completed +command name and the associated row count. + +Using the call interface is fine for making a single insert, but when multiple +records need to be inserted, it's not the most efficient means to load data. For +multiple records, the ``ps.load_rows([...])`` provides an efficient way to load +large quantities of structured data:: + + >>> from datetime import date + >>> mkemp.load_rows([ + ... ("Jack Johnson", "85000", date(1962, 11, 23), date(1990, 3, 5)), + ... ("Debra McGuffer", "52000", date(1973, 3, 4), date(2002, 1, 14)), + ... ("Barbara Smith", "86000", date(1965, 2, 24), date(2005, 7, 19)), + ... ]) + +While small, the above illustrates the ``ps.load_rows()`` method taking an +iterable of tuples that provides parameters for the each execution of the +statement. + +``load_rows`` is also used to support ``COPY ... FROM STDIN`` statements:: + + >>> copy_emps_in = db.prepare("COPY employee FROM STDIN") + >>> copy_emps_in.load_rows([ + ... b'Emp Name1\t72000\t1970-2-01\t1980-10-22\n', + ... b'Emp Name2\t62000\t1968-9-11\t1985-11-1\n', + ... b'Emp Name3\t62000\t1968-9-11\t1985-11-1\n', + ... ]) + +Copy data goes in as bytes and come out as bytes regardless of the type of COPY +taking place. It is the user's obligation to make sure the row-data is in the +appropriate encoding. + + +COPY Statements +--------------- + +`postgresql.driver` transparently supports PostgreSQL's COPY command. To the +user, COPY will act exactly like other statements that produce tuples; COPY +tuples, however, are `bytes` objects. The only distinction in usability is that +the COPY *should* be completed before other actions take place on the +connection--this is important when a COPY is invoked via ``rows()`` or +``chunks()``. + +In situations where other actions are invoked during a ``COPY TO STDOUT``, the +entire result set of the COPY will be read. However, no error will be raised so +long as there is enough memory available, so it is *very* desirable to avoid +doing other actions on the connection while a COPY is active. + +In situations where other actions are invoked during a ``COPY FROM STDIN``, a +COPY failure error will occur. The driver manages the connection state in such +a way that will purposefully cause the error as the COPY was inappropriately +interrupted. 
This not usually a problem as ``load_rows(...)`` and +``load_chunks(...)`` methods must complete the COPY command before returning. + +Copy data is always transferred using ``bytes`` objects. Even in cases where the +COPY is not in ``BINARY`` mode. Any needed encoding transformations *must* be +made the caller. This is done to avoid any unnecessary overhead by default:: + + >>> ps = db.prepare("COPY (SELECT i FROM generate_series(0, 99) AS g(i)) TO STDOUT") + >>> r = ps() + >>> len(r) + 100 + >>> r[0] + b'0\n' + >>> r[-1] + b'99\n' + +Of course, invoking a statement that way will read the entire result-set into +memory, which is not usually desirable for COPY. Using the ``chunks(...)`` +iterator is the *fastest* way to move data:: + + >>> ci = ps.chunks() + >>> import sys + >>> for rowset in ps.chunks(): + ... sys.stdout.buffer.writelines(rowset) + ... + + +``COPY FROM STDIN`` commands are supported via +`postgresql.api.Statement.load_rows`. Each invocation to +``load_rows`` is a single invocation of COPY. ``load_rows`` takes an iterable of +COPY lines to send to the server:: + + >>> db.execute(""" + ... CREATE TABLE sample_copy ( + ... sc_number int, + ... sc_text text + ... ); + ... """) + >>> copyin = db.prepare('COPY sample_copy FROM STDIN') + >>> copyin.load_rows([ + ... b'123\tone twenty three\n', + ... b'350\ttree fitty\n', + ... ]) + +For direct connection-to-connection COPY, use of ``load_chunks(...)`` is +recommended as it will provide the most efficient transfer method:: + + >>> copyout = src.prepare('COPY atable TO STDOUT') + >>> copyin = dst.prepare('COPY atable FROM STDIN') + >>> copyin.load_chunks(copyout.chunks()) + +Specifically, each chunk of row data produced by ``chunks()`` will be written in +full by ``load_chunks()`` before getting another chunk to write. + + +Cursors +======= + +When a prepared statement is declared, ``ps.declare(...)``, a +`postgresql.api.Cursor` is created and returned for random access to the rows in +the result set. Direct use of cursors is primarily useful for applications that +need to implement paging. For situations that need to iterate over the result +set, the ``ps.rows(...)`` or ``ps.chunks(...)`` execution methods should be +used. + +Cursors can also be created directly from ``cursor_id``'s using the +``cursor_from_id`` method on connection objects:: + + >>> db.execute('DECLARE the_cursor_id CURSOR WITH HOLD FOR SELECT 1;') + >>> c = db.cursor_from_id('the_cursor_id') + >>> c.read() + [(1,)] + >>> c.close() + +.. hint:: + If the cursor that needs to be opened is going to be treated as an iterator, + then a FETCH-statement should be prepared instead using ``cursor_from_id``. + +Like statements created from an identifier, cursors created from an identifier +must be explicitly closed in order to destroy the object on the server. +Likewise, cursors created from statement invocations will be automatically +released when they are no longer referenced. + +.. note:: + PG-API cursors are a direct interface to single result-set SQL cursors. This + is in contrast with DB-API cursors, which have interfaces for dealing with + multiple result-sets. There is no execute method on PG-API cursors. + + +Cursor Interface Points +----------------------- + +For cursors that return row data, these interfaces are provided for accessing +those results: + + ``Cursor.read(quantity = None, direction = None)`` + This method name is borrowed from `file` objects, and are semantically + similar. 
However, this being a cursor, rows are returned instead of bytes or + characters. When the number of rows returned is less then the quantity + requested, it means that the cursor has been exhausted in the configured + direction. The ``direction`` argument can be either ``'FORWARD'`` or `True` + to FETCH FORWARD, or ``'BACKWARD'`` or `False` to FETCH BACKWARD. + + Like, ``seek()``, the ``direction`` *property* on the cursor object effects + this method. + + ``Cursor.seek(position[, whence = 0])`` + When the cursor is scrollable, this seek interface can be used to move the + position of the cursor. See `Scrollable Cursors`_ for more information. + + ``next(Cursor)`` + This fetches the next row in the cursor object. Cursors support the iterator + protocol. While equivalent to ``cursor.read(1)[0]``, `StopIteration` is raised + if the returned sequence is empty. (``__next__()``) + + ``Cursor.close()`` + For cursors opened using ``cursor_from_id()``, this method must be called in + order to ``CLOSE`` the cursor. For cursors created by invoking a prepared + statement, this is not necessary as the garbage collection interface will take + the appropriate steps. + + ``Cursor.clone()`` + Create a new cursor object based on the same factors that were used to + create ``c``. + + ``Cursor.msghook(msg)`` + By default, the `msghook` attribute does not exist. If set to a callable, any + message that occurs during an operation of the cursor will be given to the + callable. See the `Database Messages`_ section for more information. + + +Cursors have some additional configuration properties that may be modified +during the use of the cursor: + + ``Cursor.direction`` + A value of `True`, the default, will cause read to fetch forwards, whereas a + value of `False` will cause it to fetch backwards. ``'BACKWARD'`` and + ``'FORWARD'`` can be used instead of `False` and `True`. + +Cursors normally share metadata with the statements that create them, so it is +usually unnecessary for referencing the cursor's column descriptions directly. +However, when a cursor is opened from an identifier, the cursor interface must +collect the metadata itself. These attributes provide the metadata in absence of +a statement object: + + ``Cursor.sql_column_types`` + A sequence of SQL type names specifying the types of the columns produced by + the cursor. `None` if the cursor does not return row-data. + + ``Cursor.pg_column_types`` + A sequence of PostgreSQL type Oid's specifying the types of the columns produced by + the cursor. `None` if the cursor does not return row-data. + + ``Cursor.column_types`` + A sequence of Python types that the cursor will produce. + + ``Cursor.column_names`` + A sequence of `str` objects specifying the names of the columns produced by + the cursor. `None` if the cursor does not return row-data. + + ``Cursor.statement`` + The statement that was executed that created the cursor. `None` if + unknown--``db.cursor_from_id()``. + + +Scrollable Cursors +------------------ + +Scrollable cursors are supported for applications that need to implement paging. +When statements are invoked via the ``declare(...)`` method, the returned cursor +is scrollable. + +.. note:: + Scrollable cursors never pre-fetch in order to provide guaranteed positioning. + +The cursor interface supports scrolling using the ``seek`` method. Like +``read``, it is semantically similar to a file object's ``seek()``. + +``seek`` takes two arguments: ``position`` and ``whence``: + + ``position`` + The position to scroll to. 
The meaning of this is determined by ``whence``.
+
+ ``whence``
+  How to use the position: absolute, relative, or absolute from end:
+
+   absolute: ``'ABSOLUTE'`` or ``0`` (default)
+    seek to the absolute position in the cursor relative to the beginning of the
+    cursor.
+
+   relative: ``'RELATIVE'`` or ``1``
+    seek to the relative position. Negative ``position``'s will cause a MOVE
+    backwards, while positive ``position``'s will MOVE forwards.
+
+   from end: ``'FROM_END'`` or ``2``
+    seek to the end of the cursor and then MOVE backwards by the given
+    ``position``.
+
+The ``whence`` keyword argument allows for either numeric or textual
+specifications.
+
+Scrolling through employees::
+
+    >>> emps_by_age = db.prepare("""
+    ... SELECT
+    ...  employee_name, employee_salary, employee_dob, employee_hire_date,
+    ...  EXTRACT(years FROM AGE(employee_dob)) AS age
+    ... FROM employee
+    ... ORDER BY age ASC
+    ... """)
+    >>> c = emps_by_age.declare()
+    >>> # seek to the end, ``2`` works as well.
+    >>> c.seek(0, 'FROM_END')
+    >>> # scroll back one, ``1`` works as well.
+    >>> c.seek(-1, 'RELATIVE')
+    >>> # and back to the beginning again
+    >>> c.seek(0)
+
+Additionally, scrollable cursors support backward fetches by specifying the
+direction keyword argument::
+
+    >>> c.seek(0, 2)
+    >>> c.read(1, 'BACKWARD')
+
+
+Cursor Direction
+----------------
+
+The ``direction`` property on the cursor states the default direction for read
+and seek operations. Normally, the direction is `True`, ``'FORWARD'``. When the
+property is set to ``'BACKWARD'`` or `False`, the read method will fetch
+backward by default, and seek operations will be inverted to simulate a
+reversely ordered cursor. The following example illustrates the effect::
+
+    >>> reverse_c = db.prepare('SELECT i FROM generate_series(99, 0, -1) AS g(i)').declare()
+    >>> c = db.prepare('SELECT i FROM generate_series(0, 99) AS g(i)').declare()
+    >>> reverse_c.direction = 'BACKWARD'
+    >>> reverse_c.seek(0)
+    >>> c.read() == reverse_c.read()
+
+Furthermore, when the cursor is configured to read backwards, specifying
+``'BACKWARD'`` for read's ``direction`` argument will ultimately cause a forward
+fetch. This potentially confusing facet of direction configuration is
+implemented in order to create an appropriate symmetry in functionality.
+The cursors in the above example contain the same rows, but are ultimately in
+reverse order. The backward direction property is designed so that the effect
+of any read or seek operation on those cursors is the same::
+
+    >>> reverse_c.seek(50)
+    >>> c.seek(50)
+    >>> c.read(10) == reverse_c.read(10)
+    >>> c.read(10, 'BACKWARD') == reverse_c.read(10, 'BACKWARD')
+
+And for relative seeks::
+
+    >>> c.seek(-10, 1)
+    >>> reverse_c.seek(-10, 1)
+    >>> c.read(10, 'BACKWARD') == reverse_c.read(10, 'BACKWARD')
+
+
+Rows
+====
+
+Rows received from PostgreSQL are instantiated into `postgresql.types.Row`
+objects. Rows are both a sequence and a mapping. Items accessed with an `int`
+are seen as indexes and other objects are seen as keys::
+
+    >>> row = db.prepare("SELECT 't'::text AS col0, 2::int4 AS col1").first()
+    >>> row
+    ('t', 2)
+    >>> row[0]
+    't'
+    >>> row["col0"]
+    't'
+
+However, this extra functionality is not free. The cost of instantiating
+`postgresql.types.Row` objects is quite measurable, so the `chunks()` execution
+method will produce `builtins.tuple` objects for cases where performance is
+critical.
+
+.. note::
+   Attributes aren't used to provide access to values due to potential conflicts
+   with existing method and property names.
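+
+To make that cost concrete, a small, illustrative check--reusing the statement
+from the example above, and the `postgresql.types.Row` class named earlier in
+this section--shows the difference between the two execution methods::
+
+    >>> from postgresql.types import Row
+    >>> ps = db.prepare("SELECT 't'::text AS col0, 2::int4 AS col1")
+    >>> isinstance(ps()[0], Row)
+    True
+    >>> first_chunk = next(iter(ps.chunks()))
+    >>> isinstance(first_chunk[0], Row)
+    False
+    >>> isinstance(first_chunk[0], tuple)
+    True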
+ + +Row Interface Points +-------------------- + +Rows implement the `collections.abc.Mapping` and `collections.abc.Sequence` interfaces. + + ``Row.keys()`` + An iterable producing the column names. Order is not guaranteed. See the + ``column_names`` property to get an ordered sequence. + + ``Row.values()`` + Iterable to the values in the row. + + ``Row.get(key_or_index[, default=None])`` + Get the item in the row. If the key doesn't exist or the index is out of + range, return the default. + + ``Row.items()`` + Iterable of key-value pairs. Ordered by index. + + ``iter(Row)`` + Iterable to the values in index order. + + ``value in Row`` + Whether or not the value exists in the row. (__contains__) + + ``Row[key_or_index]`` + If ``key_or_index`` is an integer, return the value at that index. If the + index is out of range, raise an `IndexError`. Otherwise, return the value + associated with column name. If the given key, ``key_or_index``, does not + exist, raise a `KeyError`. + + ``Row.index_from_key(key)`` + Return the index associated with the given key. + + ``Row.key_from_index(index)`` + Return the key associated with the given index. + + ``Row.transform(*args, **kw)`` + Create a new row object of the same length, with the same keys, but with new + values produced by applying the given callables to the corresponding items. + Callables given as ``args`` will be associated with values by their index and + callables given as keywords will be associated with values by their key, + column name. + +While the mapping interfaces will provide most of the needed information, some +additional properties are provided for consistency with statement and cursor +objects. + + ``Row.column_names`` + Property providing an ordered sequence of column names. The index corresponds + to the row value-index that the name refers to. + + >>> row[row.column_names[i]] == row[i] + + +Row Transformations +------------------- + +After a row is returned, sometimes the data in the row is not in the desired +format. Further processing is needed if the row object is to going to be +given to another piece of code which requires an object of differring +consistency. + +The ``transform`` method on row objects provides a means to create a new row +object consisting of the old row's items, but with certain columns transformed +using the given callables:: + + >>> row = db.prepare(""" + ... SELECT + ... 'XX9301423'::text AS product_code, + ... 2::int4 AS quantity, + ... '4.92'::numeric AS total + ... """).first() + >>> row + ('XX9301423', 2, Decimal("4.92")) + >>> row.transform(quantity = str) + ('XX9301423', '2', Decimal("4.92")) + +``transform`` supports both positional and keyword arguments in order to +assign the callable for a column's transformation:: + + >>> from operator import methodcaller + >>> row.transform(methodcaller('strip', 'XX')) + ('9301423', 2, Decimal("4.92")) + +Of course, more than one column can be transformed:: + + >>> stripxx = methodcaller('strip', 'XX') + >>> row.transform(stripxx, str, str) + ('9301423', '2', '4.92') + +`None` can also be used to indicate no transformation:: + + >>> row.transform(None, str, str) + ('XX9301423', '2', '4.92') + +More advanced usage can make use of lambdas for compound transformations in a +single pass of the row:: + + >>> strip_and_int = lambda x: int(stripxx(x)) + >>> row.transform(strip_and_int) + (9301423, 2, Decimal("4.92")) + +Transformations will be, more often than not, applied against *rows* as +opposed to *a* row. 
Using `operator.methodcaller` with `map` provides the +necessary functionality to create simple iterables producing transformed row +sequences:: + + >>> import decimal + >>> apply_tax = lambda x: (x * decimal.Decimal("0.1")) + x + >>> transform_row = methodcaller('transform', strip_and_int, None, apply_tax) + >>> r = map(transform_row, [row]) + >>> list(r) + [(9301423, 2, Decimal('5.412'))] + +And finally, `functools.partial` can be used to create a simple callable:: + + >>> from functools import partial + >>> transform_rows = partial(map, transform_row) + >>> list(transform_rows([row])) + [(9301423, 2, Decimal('5.412'))] + + +Queries +======= + +Queries in `py-postgresql` are single use prepared statements. They exist primarily for +syntactic convenience, but they also allow the driver to recognize the short lifetime of +the statement. + +Single use statements are supported using the ``query`` property on connection +objects, :py:class:`postgresql.api.Connection.query`. The statement object is not +available when using queries as the results, or handle to the results, are directly returned. + +Queries have access to all execution methods: + + * ``Connection.query(sql, *parameters)`` + * ``Connection.query.rows(sql, *parameters)`` + * ``Connection.query.column(sql, *parameters)`` + * ``Connection.query.first(sql, *parameters)`` + * ``Connection.query.chunks(sql, *parameters)`` + * ``Connection.query.declare(sql, *parameters)`` + * ``Connection.query.load_rows(sql, collections.abc.Iterable(parameters))`` + * ``Connection.query.load_chunks(collections.abc.Iterable(collections.abc.Iterable(parameters)))`` + +In cases where a sequence of one-shot queries needs to be performed, it may be important to +avoid unnecessary repeat attribute resolution from the connection object as the ``query`` +property is an interface object created on access. Caching the target execution methods is +recommended:: + + qrows = db.query.rows + l = [] + for x in my_queries: + l.append(qrows(x)) + +The characteristic of Each execution method is discussed in the prior +`Prepared Statements`_ section. + +Stored Procedures +================= + +The ``proc`` method on `postgresql.api.Database` objects provides a means to +create a reference to a stored procedure on the remote database. +`postgresql.api.StoredProcedure` objects are used to represent the referenced +SQL routine. + +This provides a direct interface to functions stored on the database. It +leverages knowledge of the parameters and results of the function in order +to provide the user with a natural interface to the procedure:: + + >>> func = db.proc('version()') + >>> func() + 'PostgreSQL 8.3.6 on ...' + + +Stored Procedure Interface Points +--------------------------------- + +It's more-or-less a function, so there's only one interface point: + + ``func(*args, **kw)`` (``__call__``) + Stored procedure objects are callable, executing a procedure will return an + object of suitable representation for a given procedure's type signature. + + If it returns a single object, it will return the single object produced by + the procedure. + + If it's a set returning function, it will return an *iterable* to the values + produced by the procedure. + + In cases of set returning function with multiple OUT-parameters, a cursor + will be returned. + + +Stored Procedure Type Support +----------------------------- + +Stored procedures support most types of functions. "Function Types" being set +returning functions, multiple-OUT parameters, and simple single-object returns. 
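+
+For the simple single-object case, the value produced by the procedure is
+handed back directly. As a brief sketch--``add_one`` here is merely an
+illustrative function created for the example, not something provided by the
+driver::
+
+    >>> db.execute("CREATE FUNCTION add_one(int) RETURNS int LANGUAGE SQL AS 'SELECT $1 + 1';")
+    >>> add_one = db.proc('add_one(int)')
+    >>> add_one(10)
+    11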
+
+Set-returning functions, SRFs, return a sequence::
+
+    >>> generate_series = db.proc('generate_series(int,int)')
+    >>> gs = generate_series(1, 20)
+    >>> next(gs)
+    1
+    >>> list(gs)
+    [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
+
+For functions like ``generate_series()``, the driver is able to identify that
+the return is a sequence of *solitary* integer objects, so the result of the
+function is just that, a sequence of integers.
+
+Functions returning composite types are recognized, and return row objects::
+
+    >>> db.execute("""
+    ... CREATE FUNCTION composite(OUT i int, OUT t text)
+    ... LANGUAGE SQL AS
+    ... $body$
+    ...  SELECT 900::int AS i, 'sample text'::text AS t;
+    ... $body$;
+    ... """)
+    >>> composite = db.proc('composite()')
+    >>> r = composite()
+    >>> r
+    (900, 'sample text')
+    >>> r['i']
+    900
+    >>> r['t']
+    'sample text'
+
+Functions returning a set of composites are recognized, and the result is a
+`postgresql.api.Cursor` object whose column names are consistent with the names
+of the OUT parameters::
+
+    >>> db.execute("""
+    ... CREATE FUNCTION srfcomposite(out i int, out t text)
+    ... RETURNS SETOF RECORD
+    ... LANGUAGE SQL AS
+    ... $body$
+    ...  SELECT 900::int AS i, 'sample text'::text AS t
+    ...  UNION ALL
+    ...  SELECT 450::int AS i, 'more sample text'::text AS t
+    ... $body$;
+    ... """)
+    >>> srfcomposite = db.proc('srfcomposite()')
+    >>> r = srfcomposite()
+    >>> next(r)
+    (900, 'sample text')
+    >>> v = next(r)
+    >>> v['i'], v['t']
+    (450, 'more sample text')
+
+
+Transactions
+============
+
+Transactions are managed by creating an object corresponding to a
+transaction started on the server. A transaction is a transaction block,
+a savepoint, or a prepared transaction. The ``xact(...)`` method on the
+connection object provides the standard method for creating a
+`postgresql.api.Transaction` object to manage a transaction on the connection.
+
+The creation of a transaction object does not start the transaction. Rather, the
+transaction must be explicitly started using the ``start()`` method on the
+transaction object. Usually, transactions *should* be managed with the context
+manager interfaces::
+
+    >>> with db.xact():
+    ...  ...
+
+The transaction in the above example is opened (started) by the ``__enter__``
+method invoked by the with-statement's usage. It will be subsequently
+committed or rolled back depending on the exception state and the error state
+of the connection when ``__exit__`` is called.
+
+**Using the with-statement syntax for managing transactions is strongly
+recommended.** The transaction's context manager allows Python exceptions to be
+properly treated as fatal to the transaction: when an uncaught exception of any
+kind occurs within the block, it is unlikely that the state of the transaction
+can be trusted. Additionally, the ``__exit__`` method provides a safeguard
+against invalid commits. These can occur if a database error is inappropriately
+caught within a block without being raised.
+
+The context manager interfaces are higher level interfaces to the explicit
+instruction methods provided by `postgresql.api.Transaction` objects.
+
+
+Transaction Configuration
+-------------------------
+
+Keyword arguments given to ``xact()`` provide the means for configuring the
+properties of the transaction. Only two points of configuration are available:
+
+ ``isolation``
+  The isolation level of the transaction. This must be a string.
It will be + interpolated directly into the START TRANSACTION statement. Normally, + 'SERIALIZABLE' or 'READ COMMITTED': + + >>> with db.xact('SERIALIZABLE'): + ... ... + + ``mode`` + A string, 'READ ONLY' or 'READ WRITE'. States the mutability of stored + information in the database. Like ``isolation``, this is interpolated + directly into the START TRANSACTION string. + +The specification of any of these transaction properties imply that the transaction +is a block. Savepoints do not take configuration, so if a transaction identified +as a block is started while another block is running, an exception will be +raised. + + +Transaction Interface Points +---------------------------- + +The methods available on transaction objects manage the state of the transaction +and relay any necessary instructions to the remote server in order to reflect +that change of state. + + >>> x = db.xact(...) + + ``x.start()`` + Start the transaction. + + ``x.commit()`` + Commit the transaction. + + ``x.rollback()`` + Abort the transaction. + +These methods are primarily provided for applications that manage transactions +in a way that cannot be formed around single, sequential blocks of code. +Generally, using these methods require additional work to be performed by the +code that is managing the transaction. +If usage of these direct, instructional methods is necessary, it is important to +note that if the database is in an error state when a *transaction block's* +commit() is executed, an implicit rollback will occur. The transaction object +will simply follow instructions and issue the ``COMMIT`` statement, and it will +succeed without exception. + + +Error Control +------------- + +Handling *database* errors inside transaction CMs is generally discouraged as +any database operation that occurs within a failed transaction is an error +itself. It is important to trap any recoverable database errors *outside* of the +scope of the transaction's context manager: + + >>> try: + ... with db.xact(): + ... ... + ... except postgresql.exceptions.UniqueError: + ... pass + +In cases where the database is in an error state, but the context exits +without an exception, a `postgresql.exceptions.InFailedTransactionError` is +raised by the driver: + + >>> with db.xact(): + ... try: + ... ... + ... except postgresql.exceptions.UniqueError: + ... pass + ... + Traceback (most recent call last): + ... + postgresql.exceptions.InFailedTransactionError: invalid block exit detected + CODE: 25P02 + SEVERITY: ERROR + +Normally, if a ``COMMIT`` is issued on a failed transaction, the command implies a +``ROLLBACK`` without error. This is a very undesirable result for the CM's exit +as it may allow for code to be ran that presumes the transaction was committed. +The driver intervenes here and raises the +`postgresql.exceptions.InFailedTransactionError` to safe-guard against such +cases. This effect is consistent with savepoint releases that occur during an +error state. The distinction between the two cases is made using the ``source`` +property on the raised exception. + + +Settings +======== + +SQL's SHOW and SET provides a means to configure runtime parameters on the +database("GUC"s). In order to save the user some grief, a +`collections.abc.MutableMapping` interface is provided to simplify configuration. + +The ``settings`` attribute on the connection provides the interface extension. 
+ +The standard dictionary interface is supported: + + >>> db.settings['search_path'] = "$user,public" + +And ``update(...)`` is better performing for multiple sets: + + >>> db.settings.update({ + ... 'search_path' : "$user,public", + ... 'default_statistics_target' : "1000" + ... }) + +.. note:: + The ``transaction_isolation`` setting cannot be set using the ``settings`` + mapping. Internally, ``settings`` uses ``set_config``, which cannot adjust + that particular setting. + +Settings Interface Points +------------------------- + +Manipulation and interrogation of the connection's settings is achieved by +using the standard `collections.abc.MutableMapping` interfaces. + + ``Connection.settings[k]`` + Get the value of a single setting. + + ``Connection.settings[k] = v`` + Set the value of a single setting. + + ``Connection.settings.update([(k1,v2), (k2,v2), ..., (kn,vn)])`` + Set multiple settings using a sequence of key-value pairs. + + ``Connection.settings.update({k1 : v1, k2 : v2, ..., kn : vn})`` + Set multiple settings using a dictionary or mapping object. + + ``Connection.settings.getset([k1, k2, ..., kn])`` + Get a set of a settings. This is the most efficient way to get multiple + settings as it uses a single request. + + ``Connection.settings.keys()`` + Get all available setting names. + + ``Connection.settings.values()`` + Get all setting values. + + ``Connection.settings.items()`` + Get a sequence of key-value pairs corresponding to all settings on the + database. + +Settings Management +------------------- + +`postgresql.api.Settings` objects can create context managers when called. +This gives the user with the ability to specify sections of code that are to +be ran with certain settings. The settings' context manager takes full +advantage of keyword arguments in order to configure the context manager: + + >>> with db.settings(search_path = 'local,public', timezone = 'mst'): + ... ... + +`postgresql.api.Settings` objects are callable; the return is a context manager +configured with the given keyword arguments representing the settings to use for +the block of code that is about to be executed. + +When the block exits, the settings will be restored to the values that they had +before the block entered. + + +Type Support +============ + +The driver supports a large number of PostgreSQL types at the binary level. +Most types are converted to standard Python types. The remaining types are +usually PostgreSQL specific types that are converted into objects whose class +is defined in `postgresql.types`. + +When a conversion function is not available for a particular type, the driver +will use the string format of the type and instantiate a `str` object +for the data. It will also expect `str` data when parameter of a type without a +conversion function is bound. + + +.. note:: + Generally, these standard types are provided for convenience. If conversions into + these datatypes are not desired, it is recommended that explicit casts into + ``text`` are made in statement string. + + +.. table:: Python types used to represent PostgreSQL types. 
+ + ================================= ================================== =========== + PostgreSQL Types Python Types SQL Types + ================================= ================================== =========== + `postgresql.types.INT2OID` `int` smallint + `postgresql.types.INT4OID` `int` integer + `postgresql.types.INT8OID` `int` bigint + `postgresql.types.FLOAT4OID` `float` float + `postgresql.types.FLOAT8OID` `float` double + `postgresql.types.VARCHAROID` `str` varchar + `postgresql.types.BPCHAROID` `str` char + `postgresql.types.XMLOID` `xml.etree` (cElementTree) xml + + `postgresql.types.DATEOID` `datetime.date` date + `postgresql.types.TIMESTAMPOID` `datetime.datetime` timestamp + `postgresql.types.TIMESTAMPTZOID` `datetime.datetime` (tzinfo) timestamptz + `postgresql.types.TIMEOID` `datetime.time` time + `postgresql.types.TIMETZOID` `datetime.time` timetz + `postgresql.types.INTERVALOID` `datetime.timedelta` interval + + `postgresql.types.NUMERICOID` `decimal.Decimal` numeric + `postgresql.types.BYTEAOID` `bytes` bytea + `postgresql.types.TEXTOID` `str` text + `dict` hstore + ================================= ================================== =========== + +The mapping in the above table *normally* goes both ways. So when a parameter +is passed to a statement, the type *should* be consistent with the corresponding +Python type. However, many times, for convenience, the object will be passed +through the type's constructor, so it is not always necessary. + + +Arrays +------ + +Arrays of PostgreSQL types are supported with near transparency. For simple +arrays, arbitrary iterables can just be given as a statement's parameter and the +array's constructor will consume the objects produced by the iterator into a +`postgresql.types.Array` instance. However, in situations where the array has +multiple dimensions, `list` objects are used to delimit the boundaries of the +array. + + >>> ps = db.prepare("select $1::int[]") + >>> ps.first([(1,2), (2,3)]) + Traceback: + ... + postgresql.exceptions.ParameterError + +In the above case, it is apparent that this array is supposed to have two +dimensions. However, this is not the case for other types: + + >>> ps = db.prepare("select $1::point[]") + >>> ps.first([(1,2), (2,3)]) + postgresql.types.Array([postgresql.types.point((1.0, 2.0)), postgresql.types.point((2.0, 3.0))]) + +Lists are used to provide the necessary boundary information: + + >>> ps = db.prepare("select $1::int[]") + >>> ps.first([[1,2],[2,3]]) + postgresql.types.Array([[1,2],[2,3]]) + +The above is the appropriate way to define the array from the original example. + +.. hint:: + The root-iterable object given as an array parameter does not need to be a + list-type as it's assumed to be made up of elements. + + +Composites +---------- + +Composites are supported using `postgresql.types.Row` objects to represent +the data. When a composite is referenced for the first time, the driver +queries the database for information about the columns that make up the type. 
+This information is then used to create the necessary I/O routines for packing +and unpacking the parameters and columns of that type:: + + >>> db.execute("CREATE TYPE ctest AS (i int, t text, n numeric);") + >>> ps = db.prepare("SELECT $1::ctest") + >>> i = (100, 'text', "100.02013") + >>> r = ps.first(i) + >>> r["t"] + 'text' + >>> r["n"] + Decimal("100.02013") + +Or if use of a dictionary is desired:: + + >>> r = ps.first({'t' : 'just-the-text'}) + >>> r + (None, 'just-the-text', None) + +When a dictionary is given to construct the row, absent values are filled with +`None`. + +.. _db_messages: + +Database Messages +================= + +By default, py-postgresql gives detailed reports of messages emitted by the +database. Often, the verbosity is excessive due to single target processes or +existing application infrastructure for tracing the sources of various events. +Normally, this verbosity is not a significant problem as the driver defaults the +``client_min_messages`` setting to ``'WARNING'`` by default. + +However, if ``NOTICE`` or ``INFO`` messages are needed, finer grained control +over message propagation may be desired, py-postgresql's object relationship +model provides a common protocol for controlling message propagation and, +ultimately, display. + +The ``msghook`` attribute on elements--for instance, Statements, Connections, +and Connectors--is absent by default. However, when present on an object that +contributed the cause of a message event, it will be invoked with the Message, +`postgresql.message.Message`, object as its sole parameter. The attribute of +the object that is closest to the event is checked first, if present it will +be called. If the ``msghook()`` call returns a `True` +value(specficially, ``bool(x) is True``), the message will *not* be +propagated any further. However, if a `False` value--notably, `None`--is +returned, the next element is checked until the list is exhausted and the +message is given to `postgresql.sys.msghook`. The normal list of elements is +as follows:: + + Output → Statement → Connection → Connector → Driver → postgresql.sys + +Where ``Output`` can be a `postgresql.api.Cursor` object produced by +``declare(...)`` or an implicit output management object used *internally* by +``Statement.__call__()`` and other statement execution methods. Setting the +``msghook`` attribute on `postgresql.api.Statement` gives very fine +control over raised messages. Consider filtering the notice message on create +table statements that implicitly create indexes:: + + >>> db = postgresql.open(...) + >>> db.settings['client_min_messages'] = 'NOTICE' + >>> ct_this = db.prepare('CREATE TEMP TABLE "this" (i int PRIMARY KEY)') + >>> ct_that = db.prepare('CREATE TEMP TABLE "that" (i int PRIMARY KEY)') + >>> def filter_notices(msg): + ... if msg.details['severity'] == 'NOTICE': + ... return True + ... + >>> ct_that() + NOTICE: CREATE TABLE / PRIMARY KEY will create implicit index "that_pkey" for table "that" + ... + ('CREATE TABLE', None) + >>> ct_this.msghook = filter_notices + >>> ct_this() + ('CREATE TABLE', None) + >>> + +The above illustrates the quality of an installed ``msghook`` that simply +inhibits further propagation of messages with a severity of 'NOTICE'--but, only +notices coming from objects derived from the ``ct_this`` +`postgresql.api.Statement` object. + +Subsequently, if the filter is installed on the connection's ``msghook``:: + + >>> db = postgresql.open(...) 
+ >>> db.settings['client_min_messages'] = 'NOTICE' + >>> ct_this = db.prepare('CREATE TEMP TABLE "this" (i int PRIMARY KEY)') + >>> ct_that = db.prepare('CREATE TEMP TABLE "that" (i int PRIMARY KEY)') + >>> def filter_notices(msg): + ... if msg.details['severity'] == 'NOTICE': + ... return True + ... + >>> db.msghook = filter_notices + >>> ct_that() + ('CREATE TABLE', None) + >>> ct_this() + ('CREATE TABLE', None) + >>> + +Any message with ``'NOTICE'`` severity coming from the connection, ``db``, will be +suffocated by the ``filter_notices`` function. However, if a ``msghook`` is +installed on either of those statements, it would be possible for display to +occur depending on the implementation of the hook installed on the statement +objects. + + +Message Metadata +---------------- + +PostgreSQL messages, `postgresql.message.Message`, are primarily described in three +parts: the SQL-state code, the main message string, and a mapping containing the +details. The follow attributes are available on message objects: + + ``Message.message`` + The primary message string. + + ``Message.code`` + The SQL-state code associated with a given message. + + ``Message.source`` + The origins of the message. Normally, ``'SERVER'`` or ``'CLIENT'``. + + ``Message.location`` + A terse, textual representation of ``'file'``, ``'line'``, and ``'function'`` + provided by the associated ``details``. + + ``Message.details`` + A mapping providing extended information about a message. This mapping + object **can** contain the following keys: + + ``'severity'`` + Any of ``'DEBUG'``, ``'INFO'``, ``'NOTICE'``, ``'WARNING'``, ``'ERROR'``, + ``'FATAL'``, or ``'PANIC'``; the latter three are usually associated with a + `postgresql.exceptions.Error` instance. + + ``'context'`` + The CONTEXT portion of the message. + + ``'detail'`` + The DETAIL portion of the message. + + ``'hint'`` + The HINT portion of the message. + + ``'position'`` + A number identifying the position in the statement string that caused a + parse error. + + ``'file'`` + The name of the file that emitted the message. + (*normally* server information) + + ``'function'`` + The name of the function that emitted the message. + (*normally* server information) + + ``'line'`` + The line of the file that emitted the message. + (*normally* server information) diff --git a/py_opengauss/documentation/sphinx/gotchas.rst b/py_opengauss/documentation/sphinx/gotchas.rst new file mode 100644 index 0000000000000000000000000000000000000000..915e3360c993a01df73707528ecf7ed1c420f423 --- /dev/null +++ b/py_opengauss/documentation/sphinx/gotchas.rst @@ -0,0 +1,114 @@ +Gotchas +======= + +It is recognized that decisions were made that may not always be ideal for a +given user. In order to highlight those potential issues and hopefully bring +some sense into a confusing situation, this document was drawn. + +Thread Safety +------------- + +py-postgresql connection operations are not thread safe. + +`client_encoding` setting should be altered carefully +----------------------------------------------------- + +`postgresql.driver`'s streaming cursor implementation reads a fixed set of rows +when it queries the server for more. In order to optimize some situations, the +driver will send a request for more data, but makes no attempt to wait and +process the data as it is not yet needed. When the user comes back to read more +data from the cursor, it will then look at this new data. 
The problem is that if `client_encoding` was switched, the driver may use the
+wrong codec to transform the wire data into higher level Python objects (str).
+
+To prevent this problem, set the `client_encoding` early.
+Furthermore, it is probably best to never change the `client_encoding`, as the
+driver automatically makes the necessary transformation to Python strings.
+
+
+The user and password are correct, but it does not work when using `postgresql.driver`
+------------------------------------------------------------------------------------------
+
+This issue likely comes from the possibility that the information sent to the
+server early in the negotiation phase may not be in an encoding that is
+consistent with the server's encoding.
+
+One problem is that PostgreSQL does not provide the client with the server
+encoding early enough in the negotiation phase, and, therefore, the driver is
+unable to process the password data in a way that is consistent with the
+server's expectations.
+
+Another problem is that PostgreSQL takes much of the data in the startup message
+as-is, so a decision about the best way to encode parameters is difficult.
+
+The easy way to avoid *most* issues with this problem is to initialize the
+database in the `utf-8` encoding. The driver defaults the expected server
+encoding to `utf-8`. However, this can be overridden by creating the `Connector`
+with a `server_encoding` parameter. Setting `server_encoding` to the proper
+value for the target server will allow the driver to properly encode *some* of
+the parameters. Also, any GUC parameters passed via the `settings` parameter
+should use typed objects when possible to hint that the server encoding should
+not be used on that parameter (`bytes`, for instance).
+
+
+Backslash characters are being treated literally
+--------------------------------------------------
+
+The driver enables standard conforming strings. Stop using non-standard features.
+;)
+
+If support for non-standard strings were provided, it would require the
+driver to provide subjective quoting interfaces (e.g., db.quote_literal). Doing so
+is not desirable as it introduces difficulties for the driver *and* the user.
+
+
+Types without binary support in the driver are unsupported in arrays and records
+-------------------------------------------------------------------------------------
+
+When an array or composite type is identified, `postgresql.protocol.typio`
+ultimately chooses the binary format for the transfer of the column or
+parameter. When this is done, PostgreSQL will pack or expect *all* the values
+in binary format as well. If that binary format is not supported and the type
+is not a string, it will fail to unpack the row or pack the appropriate data for
+the element or attribute.
+
+In most cases, issues related to this can be avoided with explicit casts to text.
+
+
+NOTICEs, WARNINGs, and other messages are too verbose
+-------------------------------------------------------
+
+For many situations, the information provided with database messages is
+far too verbose. However, considering that py-postgresql is a programmer's
+library, the default of high verbosity is taken with the express purpose of
+allowing the programmer to "adjust the volume" until appropriate.
+
+By default, py-postgresql adjusts the ``client_min_messages`` setting to only emit
+messages at the WARNING level or higher--ERRORs, FATALs, and PANICs.
+This reduces the number of messages generated by most connections dramatically.
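+
+For example, the threshold can be adjusted at runtime through the connection's
+``settings`` mapping. The exact level is an application decision; the values
+below are only illustrative::
+
+   >>> # Re-enable NOTICE messages for debugging.
+   >>> db.settings['client_min_messages'] = 'NOTICE'
+   >>> # Or restrict output to errors only.
+   >>> db.settings['client_min_messages'] = 'ERROR'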
+
+If further customization is needed, the :ref:`db_messages` section has
+information on overriding the default action taken with database messages.
+
+Strange TypeError using load_rows() or load_chunks()
+------------------------------------------------------
+
+When a prepared statement is directly executed using ``__call__()``, it can easily
+validate that the appropriate number of parameters are given to the function.
+When ``load_rows()`` or ``load_chunks()`` is used, any tuple in
+the entire sequence can cause this TypeError during the loading process::
+
+   TypeError: inconsistent items, N processors and M items in row
+
+This exception is raised by a generic processing routine whose functionality
+is abstract in nature, so the message is abstract as well. It essentially means
+that a tuple in the sequence given to the loading method had too many or too few
+items.
+
+Non-English Locales
+-------------------
+
+In the past, some builds of PostgreSQL localized the severity field of some protocol messages.
+`py-postgresql` expects these fields to be consistent with their English terms. If the driver
+raises strange exceptions during the use of non-English locales, it may be necessary to use an
+English setting in order to coax the server into issuing familiar terms.
diff --git a/py_opengauss/documentation/sphinx/index.rst b/py_opengauss/documentation/sphinx/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..9189c563ede57efa2c3c6de6bed70bae6d1f5fa7
--- /dev/null
+++ b/py_opengauss/documentation/sphinx/index.rst
@@ -0,0 +1,75 @@
+py-postgresql
+=============
+
+py-postgresql is a project dedicated to improving the Python client interfaces to PostgreSQL.
+
+At its core, py-postgresql provides a PG-API, `postgresql.api`, and
+DB-API 2.0 interface for using a PostgreSQL database.
+
+Contents
+--------
+
+.. toctree::
+   :maxdepth: 2
+
+   admin
+   driver
+   clientparameters
+   cluster
+   notifyman
+   alock
+   copyman
+   gotchas
+
+Reference
+---------
+
+.. toctree::
+   :maxdepth: 2
+
+   bin
+   reference
+
+Changes
+-------
+
+.. toctree::
+   :maxdepth: 1
+
+   changes-v1.3
+   changes-v1.2
+   changes-v1.1
+   changes-v1.0
+
+Sample Code
+-----------
+
+Using `postgresql.driver`::
+
+   >>> import postgresql
+   >>> db = postgresql.open("pq://user:password@host/name_of_database")
+   >>> db.execute("CREATE TABLE emp (emp_name text PRIMARY KEY, emp_salary numeric)")
+   >>>
+   >>> # Create the statements.
+   >>> make_emp = db.prepare("INSERT INTO emp VALUES ($1, $2)")
+   >>> raise_emp = db.prepare("UPDATE emp SET emp_salary = emp_salary + $2 WHERE emp_name = $1")
+   >>> get_emp_with_salary_lt = db.prepare("SELECT emp_name FROM emp WHERE emp_salary < $1")
+   >>>
+   >>> # Create some employees, but do it in a transaction--all or nothing.
+   >>> with db.xact():
+   ...  make_emp("John Doe", "150,000")
+   ...  make_emp("Jane Doe", "150,000")
+   ...  make_emp("Andrew Doe", "55,000")
+   ...  make_emp("Susan Doe", "60,000")
+   >>>
+   >>> # Give some raises
+   >>> with db.xact():
+   ...  for row in get_emp_with_salary_lt("125,000"):
+   ...   print(row["emp_name"])
+   ...   raise_emp(row["emp_name"], "10,000")
+
+Of course, if DB-API 2.0 is desired, the module is located at
+`postgresql.driver.dbapi20`. DB-API extends PG-API, so the features
+illustrated above are available on DB-API connections.
+
+See :ref:`db_interface` for more information.
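+
+A rough sketch of similar usage through the DB-API module follows; the
+connection keywords shown are illustrative and must be adjusted for the
+target server::
+
+   >>> import postgresql.driver.dbapi20 as dbapi20
+   >>> # Connection keywords here are examples only.
+   >>> con = dbapi20.connect(user = 'user', password = 'password',
+   ...  host = 'localhost', port = 5432, database = 'name_of_database')
+   >>> cur = con.cursor()
+   >>> cur.execute("SELECT emp_name FROM emp WHERE emp_salary < %s", (125000,))
+   >>> print(cur.fetchall())
+   >>> con.commit()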
diff --git a/py_opengauss/documentation/sphinx/lib.rst b/py_opengauss/documentation/sphinx/lib.rst new file mode 100644 index 0000000000000000000000000000000000000000..592b96fa7f238d1f9f8eefd145f50736300b5efe --- /dev/null +++ b/py_opengauss/documentation/sphinx/lib.rst @@ -0,0 +1,534 @@ +Categories and Libraries +************************ + +This chapter discusses the usage and implementation of connection categories and +libraries. Originally these features were written with general purpose use in mind; +however, it is recommended that these features **not** be used in applications. +They are primarily used internally by the the driver and may be removed in the future. + +Libraries are a collection of SQL statements that can be bound to a +connection. Libraries are *normally* bound directly to the connection object as +an attribute using a name specified by the library. + +Libraries provide a common way for SQL statements to be managed outside of the +code that uses them. When using ILFs, this increases the portability of the SQL +by keeping the statements isolated from the Python code in an accessible format +that can be easily used by other languages or systems --- An ILF parser can be +implemented within a few dozen lines using basic text tools. + +SQL statements defined by a Library are identified by their Symbol. These +symbols are named and annotated in order to allow the user to define how a +statement is to be used. The user may state the default execution method of +the statement object, or whether the symbol is to be preloaded at bind +time--these properties are Symbol Annotations. + +The purpose of libraries are to provide a means to manage statements on +disk and at runtime. ILFs provide a means to reference a collection +of statements on disk, and, when loaded, the symbol bindings provides means to +reference a statement already prepared for use on a given connection. + +The `postgresql.lib` package-module provides fundamental classes for supporting +categories and libraries. + + +Writing Libraries +================= + +ILF files are the recommended way to build a library. These files use the +naming convention "lib{NAME}.sql". The prefix and suffix are used describe the +purpose of the file and to provide a hint to editors that SQL highlighting +should be used. The format of an ILF takes the form:: + + + [name:type:method] + + ... + +Where multiple symbols may be defined. The Preface that comes before the first +symbol is an arbitrary block of text that should be used to describe the library. +This block is free-form, and should be considered a good place for some +general documentation. + +Symbols are named and described using the contents of section markers: +``('[' ... ']')``. Section markers have three components: the symbol name, +the symbol type and the symbol method. Each of these components are separated +using a single colon, ``:``. All components are optional except the Symbol name. +For example:: + + [get_user_info] + SELECT * FROM user WHERE user_id = $1 + + [get_user_info_v2::] + SELECT * FROM user WHERE user_id = $1 + +In the above example, ``get_user_info`` and ``get_user_info_v2`` are identical. +Empty components indicate the default effect. + +The second component in the section identifier is the symbol type. All Symbol +types are listed in `Symbol Types`_. 
This can be +used to specify what the section's contents are or when to bind the +symbol:: + + [get_user_info:preload] + SELECT * FROM user WHERE user_id = $1 + +This provides the Binding with the knowledge that the statement should be +prepared when the Library is bound. Therefore, when this Symbol's statement +is used for the first time, it will have already been prepared. + +Another type is the ``const`` Symbol type. This defines a data Symbol whose +*statement results* will be resolved when the Library is bound:: + + [user_type_ids:const] + SELECT user_type_id, user_type FROM user_types; + +Constant Symbols cannot take parameters as they are data properties. The +*result* of the above query is set to the Bindings' ``user_type_ids`` +attribute:: + + >>> db.lib.user_type_ids + + +Where ``lib`` in the above is a Binding of the Library containing the +``user_type_ids`` Symbol. + +Finally, procedures can be bound as symbols using the ``proc`` type:: + + [remove_user:proc] + remove_user(bigint) + +All procedures symbols are loaded when the Library is bound. Procedure symbols +are special because the execution method is effectively specified by the +procedure itself. + + +The third component is the symbol ``method``. This defines the execution method +of the statement and ultimately what is returned when the Symbol is called at +runtime. All the execution methods are listed in `Symbol Execution Methods`_. + +The default execution method is the default execution method of +`postgresql.api.PreparedStatement` objects; return the entire result set in a +list object:: + + [get_numbers] + SELECT i FROM generate_series(0, 100-1) AS g(i); + +When bound:: + + >>> db.lib.get_numbers() == [(x,) for x in range(100)] + True + +The transformation of range in the above is necessary as statements +return a sequence of row objects by default. + +For large result-sets, fetching all the rows would be taxing on a system's +memory. The ``rows`` and ``chunks`` methods provide an iterator to rows produced +by a statement using a stream:: + + [get_some_rows::rows] + SELECT i FROM generate_series(0, 1000) AS g(i); + + [get_some_chunks::chunks] + SELECT i FROM generate_series(0, 1000) AS g(i); + +``rows`` means that the Symbol will return an iterator producing individual rows +of the result, and ``chunks`` means that the Symbol will return an iterator +producing sequences of rows of the result. + +When bound:: + + >>> from itertools import chain + >>> list(db.lib.get_some_rows()) == list(chain.from_iterable(db.lib.get_some_chunks())) + True + +Other methods include ``column`` and ``first``. The column method provides a +means to designate that the symbol should return an iterator of the values in +the first column instead of an iterator to the rows:: + + [another_generate_series_example::column] + SELECT i FROM generate_series(0, $1::int) AS g(i) + +In use:: + + >>> list(db.lib.another_generate_series_example(100-1)) == list(range(100)) + True + >>> list(db.lib.another_generate_series_example(10-1)) + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + +The ``first`` method provides direct access to simple results. +Specifically, the first column of the first row when there is only one column. +When there are multiple columns the first row is returned:: + + [get_one::first] + SELECT 1 + + [get_one_twice::first] + SELECT 1, 1 + +In use:: + + >>> db.lib.get_one() == 1 + True + >>> db.lib.get_one_twice() == (1,1) + True + +.. note:: + ``first`` should be used with care. When the result returns no rows, `None` + will be returned. 
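+
+Putting these pieces together, a small library file might look like the
+following. This is an illustrative sketch only--a file named, say,
+``libexample.sql`` that reuses the hypothetical symbols from the examples
+above::
+
+   This preface is free-form text describing the example library.
+
+   [get_user_info:preload]
+   SELECT * FROM user WHERE user_id = $1
+
+   [user_type_ids:const]
+   SELECT user_type_id, user_type FROM user_types
+
+   [get_numbers::column]
+   SELECT i FROM generate_series(0, 100-1) AS g(i)
+
+Loading and binding such a file is covered in the next section.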
+ + +Using Libraries +=============== + +After a library is created, it must be loaded before it can be bound using +programmer interfaces. The `postgresql.lib.load` interface provides the +primary entry point for loading libraries. + +When ``load`` is given a string, it identifies if a directory separator is in +the string, if there is it will treat the string as a *path* to the ILF to be +loaded. If no separator is found, it will treat the string as the library +name fragment and look for "lib{NAME}.sql" in the directories listed in +`postgresql.sys.libpath`. + +Once a `postgresql.lib.Library` instance has been acquired, it can then be +bound to a connection for use. `postgresql.lib.Binding` is used to create an +object that provides and manages the Bound Symbols:: + + >>> import postgresql.lib as pg_lib + >>> lib = pg_lib.load(...) + >>> B = pg_lib.Binding(db, lib) + +The ``B`` object in the above example provides the Library's Symbols as +attributes which can be called to in order to execute the Symbol's statement:: + + >>> B.symbol(param) + ... + +While it is sometimes necessary, manual creation of a Binding is discouraged. +Rather, `postgresql.lib.Category` objects should be used to manage the set of +Libraries to be bound to a connection. + + +Categories +---------- + +Libraries provide access to a collection of symbols; Bindings provide an +interface to the symbols with respect to a subject database. When a connection +is established, multiple Bindings may need to be created in order to fulfill +the requirements of the programmer. When a Binding is created, it exists in +isolation; this can be an inconvenience when access to both the Binding and +the Connection is necessary. Categories exist to provide a formal method for +defining the interface extensions on a `postgresql.api.Database` +instance(connection). + +A Category is essentially a runtime-class for connections. It provides a +formal initialization procedure for connection objects at runtime. However, +the connection resource must be connected prior to category initialization. + +Categories are sets of Libraries to be bound to a connection with optional name +substitutions. In order to create one directly, pass the Library instances to +`postgresql.lib.Category`:: + + >>> import postgresql.lib as pg_lib + >>> cat = pg_lib.Category(lib1, lib2, libN) + +Where ``lib1``, ``lib2``, ``libN`` are `postgresql.lib.Library` instances; +usually created by `postgresql.lib.load`. Once created, categories can then +used by passing the ``category`` keyword to connection creation interfaces:: + + >>> import postgresql + >>> db = postgresql.open(category = cat) + +The ``db`` object will now have Bindings for ``lib1``, ``lib2``, ..., and +``libN``. + +Categories can alter the access point(attribute name) of Bindings. This is done +by instantiating the Category using keyword parameters:: + + >>> cat = pg_lib.Category(lib1, lib2, libname = libN) + +At this point, when a connection is established as the category ``cat``, +``libN`` will be bound to the connection object on the attribute ``libname`` +instead of the name defined by the library. + +And a final illustration of Category usage:: + + >>> db = postgresql.open(category = pg_lib.Category(pg_lib.load('name'))) + >>> db.name + + + +Symbol Types +============ + +The symbol type determines how a symbol is going to be treated by the Binding. +For instance, ``const`` symbols are resolved when the Library is bound and +the statement object is immediately discarded. 
Here is a list of symbol types +that can be used in ILF libraries: + + ```` (Empty component) + The symbol's statement will never change. This allows the Bound Symbol to + hold onto the `postgresql.api.PreparedStatement` object. When the symbol is + used again, it will refer to the existing prepared statement object. + + ``preload`` + Like the default type, the Symbol is a simple statement, but it should be + loaded when the library is bound to the connection. + + ``const`` + The statement takes no parameters and only needs to be executed once. This + will cause the statement to be executed when the library is bound and the + results of the statement will be set to the Binding using the symbol name so + that it may be used as a property by the user. + + ``proc`` + The contents of the section is a procedure identifier. When this type is used + the symbol method *should not* be specified as the method annotation will be + automatically resolved based on the procedure's signature. + + ``transient`` + The Symbol is a statement that should *not* be retained. Specifically, it is + a statement object that will be discarded when the user discard the referenced + Symbol. Used in cases where the statement is used once or very infrequently. + + +Symbol Execution Methods +======================== + +The Symbol Execution Method provides a way to specify how a statement is going +to be used. Specifically, which `postgresql.api.PreparedStatement` method +should be executed when a Bound Symbol is called. The following is a list of +the symbol execution methods and the effect it will have when invoked: + + ```` (Empty component) + Returns the entire result set in a single list object. If the statement does + not return rows, a ``(command, count)`` pair will be returned. + + ``rows`` + Returns an iterator producing each row in the result set. + + ``chunks`` + Returns an iterator producing "chunks" of rows in the result set. + + ``first`` + Returns the first column of the first row if there is one column in the result + set. If there are multiple columns in the result set, the first row is + returned. If query is non-RETURNING DML--insert, update, or delete, the row + count is returned. + + ``column`` + Returns an iterator to values in the first column. (Equivalent to + executing a statement as ``map(operator.itemgetter(0), ps.rows())``.) + + ``declare`` + Returns a scrollable cursor, `postgresql.api.Cursor`, to the result set. + + ``load_chunks`` + Takes an iterable row-chunks to be given to the statement. Returns `None`. If + the statement is a ``COPY ... FROM STDIN``, the iterable must produce chunks + of COPY lines. + + ``load_rows`` + Takes an iterable rows to be given as parameters. If the statement is a ``COPY + ... FROM STDIN``, the iterable must produce COPY lines. + + +Reference Symbols +================= + +Reference Symbols provide a way to construct a Bound Symbol using the Symbol's +query. When invoked, A Reference Symbol's query is executed in order to produce +an SQL statement to be used as a Bound Symbol. In ILF files, a reference is +identified by its symbol name being prefixed with an ampersand:: + + [&refsym::first] + SELECT 'SELECT 1::int4'::text + +Then executed:: + + >>> # Runs the 'refsym' SQL, and creates a Bound Symbol using the results. + >>> sym = lib.refsym() + >>> assert sym() == 1 + +The Reference Symbol's type and execution method are inherited by the created +Bound Symbol. 
With one exception, ``const`` reference symbols are +special in that they immediately resolved into the target Bound Symbol. + +A Reference Symbol's source query *must* produce rows of text columns. Multiple +columns and multiple rows may be produced by the query, but they must be +character types as the results are promptly joined together with whitespace so +that the target statement may be prepared. + +Reference Symbols are most likely to be used in dynamic DDL and DML situations, +or, somewhat more specifically, any query whose definition depends on a +generated column list. + +Distributing and Usage +====================== + +For applications, distribution and management can easily be a custom +process. The application designates the library directory; the entry point +adds the path to the `postgresql.sys.libpath` list; a category is built; and, a +connection is made using the category. + +For mere Python extensions, however, ``distutils`` has a feature that can +aid in ILF distribution. The ``package_data`` setup keyword can be used to +include ILF files alongside the Python modules that make up a project. See +http://docs.python.org/3.1/distutils/setupscript.html#installing-package-data +for more detailed information on the keyword parameter. + +The recommended way to manage libraries for extending projects is to +create a package to contain them. For instance, consider the following layout:: + + project/ + setup.py + pkg/ + __init__.py + lib/ + __init__.py + libthis.sql + libthat.sql + +The project's SQL libraries are organized into a single package directory, +``lib``, so ``package_data`` would be configured:: + + package_data = {'pkg.lib': ['*.sql']} + +Subsequently, the ``lib`` package initialization script can then be used to +load the libraries, and create any categories(``project/pkg/lib/__init__.py``):: + + import os.path + import postgresql.lib as pg_lib + import postgresql.sys as pg_sys + libdir = os.path.dirname(__file__) + pg_sys.libpath.append(libdir) + libthis = pg_lib.load('this') + libthat = pg_lib.load('that') + stdcat = pg_lib.Category(libthis, libthat) + +However, it can be undesirable to add the package directory to the global +`postgresql.sys.libpath` search paths. Direct path loading can be used in those +cases:: + + import os.path + import postgresql.lib as pg_lib + libdir = os.path.dirname(__file__) + libthis = pg_lib.load(os.path.join(libdir, 'libthis.sql')) + libthat = pg_lib.load(os.path.join(libdir, 'libthat.sql')) + stdcat = pg_lib.Category(libthis, libthat) + +Using the established project context, a connection would then be created as:: + + from pkg.lib import stdcat + import postgresql as pg + db = pg.open(..., category = stdcat) + # And execute some fictitious symbols. + db.this.sym_from_libthis() + db.that.sym_from_libthat(...) + + +Audience and Motivation +======================= + +This chapter covers advanced material. It is **not** recommended that categories +and libraries be used for trivial applications or introductory projects. + +.. note:: + Libraries and categories are not likely to be of interest to ORM or DB-API users. + +With exception to ORMs or other similar abstractions, the most common pattern +for managing connections and statements is delegation:: + + class MyAppDB(object): + def __init__(self, connection): + self.connection = connection + + def my_operation(self, op_arg1, op_arg2): + return self.connection.prepare( + "SELECT my_operation_proc($1,$2)", + )(op_arg1, op_arg2) + ... 
+ +The straightforward nature is likeable, but the usage does not take advantage of +prepared statements. In order to do that an extra condition is necessary to see +if the statement has already been prepared:: + + ... + + def my_operation(self, op_arg1, op_arg2): + if self.hasattr(self, '_my_operation'): + ps = self._my_operation + else: + ps = self._my_operation = self.connection.prepare( + "SELECT my_operation_proc($1, $2)", + ) + return ps(op_arg1, op_arg2) + ... + +There are many variations that can implement the above. It works and it's +simple, but it will be exhausting if repeated and error prone if the +initialization condition is not factored out. Additionally, if access to statement +metadata is needed, the above example is still lacking as it would require +execution of the statement and further protocol expectations to be established. +This is the province of libraries: direct database interface management. + +Categories and Libraries are used to factor out and simplify +the above functionality so re-implementation is unnecessary. For example, an +ILF library containing the symbol:: + + [my_operation] + SELECT my_operation_proc($1, $2) + + [] + ... + +Will provide the same functionality as the ``my_operation`` method in the +latter Python implementation. + + +Terminology +=========== + +The following terms are used throughout this chapter: + + Annotations + The information of about a Symbol describing what it is and how it should be + used. + + Binding + An interface to the Symbols provided by a Library for use with a given + connection. + + Bound Symbol + An interface to an individual Symbol ready for execution against the subject + database. + + Bound Reference + An interface to an individual Reference Symbol that will produce a Bound + Symbol when executed. + + ILF + INI-style Library Format. "lib{NAME}.sql" files. + + Library + A collection of Symbols--mapping of names to SQL statements. + + Local Symbol + A relative term used to denote a symbol that exists in the same library as + the subject symbol. + + Preface + The block of text that comes before the first symbol in an ILF file. + + Symbol + An named database operation provided by a Library. Usually, an SQL statement + with Annotations. + + Reference Symbol + A Symbol whose SQL statement *produces* the source for a Bound Symbol. + + Category + An object supporting a classification for connectors that provides database + initialization facilities for produced connections. For libraries, + `postgresql.lib.Category` objects are a set of Libraries, + `postgresql.lib.Library`. diff --git a/py_opengauss/documentation/sphinx/notifyman.rst b/py_opengauss/documentation/sphinx/notifyman.rst new file mode 100644 index 0000000000000000000000000000000000000000..d774ee52b7e579bd73ef6a3e1fa199feb8db6cca --- /dev/null +++ b/py_opengauss/documentation/sphinx/notifyman.rst @@ -0,0 +1,237 @@ +.. _notifyman: + +*********************** +Notification Management +*********************** + +Relevant SQL commands: `NOTIFY `_, +`LISTEN `_, +`UNLISTEN `_. + +Asynchronous notifications offer a means for PostgreSQL to signal application +code. Often these notifications are used to signal cache invalidation. In 9.0 +and greater, notifications may include a "payload" in which arbitrary data may +be delivered on a channel being listened to. + +By default, received notifications will merely be appended to an internal +list on the connection object. 
This list will remain empty for the duration +of a connection *unless* the connection begins listening to a channel that +receives notifications. + +The `postgresql.notifyman.NotificationManager` class is used to wait for +messages to come in on a set of connections, pick up the messages, and deliver +the messages to the object's user via the `collections.Iterator` protocol. + + +Listening on a Single Connection +================================ + +The ``db.iternotifies()`` method is a simplification of the notification manager. It +returns an iterator to the notifications received on the subject connection. +The iterator yields triples consisting of the ``channel`` being +notified, the ``payload`` sent with the notification, and the ``pid`` of the +backend that caused the notification:: + + >>> db.listen('for_rabbits') + >>> db.notify('for_rabbits') + >>> for x in db.iternotifies(): + ... channel, payload, pid = x + ... break + >>> assert channel == 'for_rabbits' + True + >>> assert payload == '' + True + >>> assert pid == db.backend_id + True + +The iterator, by default, will continue listening forever unless the connection +is terminated--thus the immediate ``break`` statement in the above loop. In +cases where some additional activity is necessary, a timeout parameter may be +given to the ``iternotifies`` method in order to allow "idle" events to occur +at the designated frequency:: + + >>> for x in db.iternotifies(0.5): + ... if x is None: + ... break + +The above example illustrates that idle events are represented using `None` +objects. Idle events are guaranteed to occur *approximately* at the +specified interval--the ``timeout`` keyword parameter. In addition to +providing a means to do other processing or polling, they also offer a safe +break point for the loop. Internally, the iterator produced by the +``iternotifies`` method *is* a `NotificationManager`, which will localize the +notifications prior to emitting them via the iterator. +*It's not safe to break out of the loop, unless an idle event is being handled.* +If the loop is broken while a regular event is being processed, some events may +remain in the iterator. In order to consume those events, the iterator *must* +be accessible. + +The iterator will be exhausted when the connection is closed, but if the +connection is closed during the loop, any remaining notifications *will* +be emitted prior to the loop ending, so it is important to be prepared to +handle exceptions or check for a closed connection. + +In situations where multiple connections need to be watched, direct use of the +`NotificationManager` is necessary. + + +Listening on Multiple Connections +================================= + +The `postgresql.notifyman.NotificationManager` class is used to manage +*connections* that are expecting to receive notifications. Instances are +iterators that yield the connection object and notifications received on the +connection or `None` in the case of an idle event. The manager emits events as +a pair; the connection object that received notifications, and *all* the +notifications picked up on that connection:: + + >>> from postgresql.notifyman import NotificationManager + >>> # Using ``nm`` to reference the manager from here on. + >>> nm = NotificationManager(db1, db2, ..., dbN) + >>> nm.settimeout(2) + >>> for x in nm: + ... if x is None: + ... # idle + ... break + ... + ... db, notifies = x + ... for channel, payload, pid in notifies: + ... ... 
+
+The manager will continue to wait for and emit events so long as there are
+good connections available in the set; it is possible for connections to be
+added and removed at any time. However, in rare circumstances, a discarded
+connection may still have pending events if it is not removed during an idle
+event. The ``connections`` attribute on `NotificationManager` objects is a
+set object that may be used directly in order to add and remove connections
+from the manager::
+
+   >>> y = []
+   >>> for x in nm:
+   ...  if x is None:
+   ...   if y:
+   ...    nm.connections.add(y[0])
+   ...    del y[0]
+   ...
+
+The notification manager is resilient; if a connection dies, it will discard the
+connection from the set, and add it to the set of bad connections, the
+``garbage`` attribute. In these cases, the idle event *should* be leveraged to
+check for these failures if that's a concern. It is the user's
+responsibility to explicitly handle the failure cases, and remove the bad
+connections from the ``garbage`` set::
+
+   >>> for x in nm:
+   ...  if x is None:
+   ...   if nm.garbage:
+   ...    recovered = take_out_trash(nm.garbage)
+   ...    nm.connections.update(recovered)
+   ...    nm.garbage.clear()
+   ...   continue
+   ...  db, notifies = x
+   ...  for channel, payload, pid in notifies:
+   ...   ...
+
+Explicitly removing connections from the set can also be a means to gracefully
+terminate the event loop::
+
+   >>> for x in nm:
+   ...  if x is None:
+   ...   if done_listening is True:
+   ...    nm.connections.clear()
+
+However, doing so inside the loop is not a requirement; it is safe to remove a
+connection from the set at any point.
+
+
+Notification Managers
+=====================
+
+The `postgresql.notifyman.NotificationManager` is an event loop that services
+multiple connections. In cases where only one connection needs to be serviced,
+the `postgresql.api.Database.iternotifies` method can be used to simplify the
+process.
+
+
+Notification Manager Constructors
+---------------------------------
+
+ ``NotificationManager(*connections, timeout = None)``
+  Create a NotificationManager instance that manages the notifications coming
+  from the given set of connections. The ``timeout`` keyword is optional and
+  can be configured using the ``settimeout`` method as well.
+
+
+Notification Manager Interface Points
+-------------------------------------
+
+ ``NotificationManager.__iter__()``
+  Returns the instance; it is an iterator.
+
+ ``NotificationManager.__next__()``
+  Normally, yields a pair, the connection and its notifications list, when the
+  next event is received. If a timeout is configured, `None` may be yielded to
+  signal an idle event. The notifications list is a list of triples:
+  ``(channel, payload, pid)``.
+
+ ``NotificationManager.settimeout(timeout : int)``
+  Set the amount of time to wait before the manager yields an idle event.
+  If zero, the manager will never wait and only yield notifications that are
+  immediately available.
+  If `None`, the manager will never emit idle events.
+
+ ``NotificationManager.gettimeout() -> [int, None]``
+  Get the configured timeout; returns either `None` or an `int`.
+
+ ``NotificationManager.connections``
+  The set of connections that the manager is actively watching for
+  notifications. Connections may be added or removed from the set at any time.
+
+ ``NotificationManager.garbage``
+  The set of connections that failed. Normally empty, but when a connection gets
+  an exceptional condition or explicitly raises an exception, it is removed from
+  the ``connections`` set, and placed in ``garbage``.
+ + +Zero Timeout +------------ + +When a timeout of zero, ``0``, is configured, the notification manager will +terminate early. Specifically, each connection will be polled for any pending +notifications, and once all of the collected notifications have been emitted +by the iterator, `StopIteration` will be raised. Notably, *no* idle events will +occur when the timeout is configured to zero. + +Zero timeouts offer a means for the notification "queue" to be polled. Often, +this is the appropriate way to collect pending notifications on active +connections where using the connection exclusively for waiting is not +practical:: + + >>> notifies = list(db.iternotifies(0)) + +Or with a NotificationManager instance:: + + >>> nm.settimeout(0) + >>> db_notifies = list(nm) + +In both cases of zero timeout, the iterator may be promptly discarded without +losing any events. + + +Summary of Characteristics +-------------------------- + + * The iterator will continue until the connections die. + * Objects yielded by the iterator are either `None`, an "idle event", or an + individual notification triple if using ``db.iternotifies()``, or a + ``(db, notifies)`` pair if using the base `NotificationManager`. + * When a connection dies or raises an exception, it will be removed from + the ``nm.connections`` set and added to the ``nm.garbage`` set. + * The NotificationManager instance will *not* hold any notifications + during an idle event. Idle events offer a break point in which the manager + may be discarded. + * A timeout of zero will cause the iterator to only yield the events + that are pending right now, and promptly end. However, the same manager + object may be used again. + * A notification triple is a tuple consisting of ``(channel, payload, pid)``. + * Connections may be added and removed from the ``nm.connections`` set at + any time. diff --git a/py_opengauss/documentation/sphinx/reference.rst b/py_opengauss/documentation/sphinx/reference.rst new file mode 100644 index 0000000000000000000000000000000000000000..466a672e3b7fb3805b99d789349b7a9157c805f7 --- /dev/null +++ b/py_opengauss/documentation/sphinx/reference.rst @@ -0,0 +1,82 @@ +Reference +========= + +:mod:`postgresql` +----------------- + +.. automodule:: postgresql +.. autodata:: version +.. autodata:: version_info +.. autofunction:: open + +:mod:`postgresql.api` +--------------------- + +.. automodule:: + postgresql.api + :members: + :show-inheritance: + +:mod:`postgresql.sys` +--------------------- + +.. automodule:: + postgresql.sys + :members: + :show-inheritance: + +:mod:`postgresql.string` +------------------------ + +.. automodule:: + postgresql.string + :members: + :show-inheritance: + +:mod:`postgresql.exceptions` +---------------------------- + +.. automodule:: + postgresql.exceptions + :members: + :show-inheritance: + +:mod:`postgresql.temporal` +-------------------------- + +.. automodule:: + postgresql.temporal + :members: + :show-inheritance: + +:mod:`postgresql.installation` +------------------------------ + +.. automodule:: + postgresql.installation + :members: + :show-inheritance: + +:mod:`postgresql.cluster` +------------------------- + +.. automodule:: + postgresql.cluster + :members: + :show-inheritance: + +:mod:`postgresql.copyman` +------------------------- + +.. automodule:: + postgresql.copyman + :members: + :show-inheritance: + +:mod:`postgresql.alock` +----------------------- + +.. 
automodule:: + postgresql.alock + :members: + :show-inheritance: diff --git a/py_opengauss/driver/__init__.py b/py_opengauss/driver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7c2c4337aaa04b593600c4eda87dda2e70408b7 --- /dev/null +++ b/py_opengauss/driver/__init__.py @@ -0,0 +1,14 @@ +## +# .driver package +## +""" +Driver package for providing an interface to a PostgreSQL database. +""" +__all__ = ['connect', 'default'] + +from .pq3 import Driver +default = Driver() + +def connect(*args, **kw): + 'Establish a connection using the default driver.' + return default.connect(*args, **kw) diff --git a/py_opengauss/driver/dbapi20.py b/py_opengauss/driver/dbapi20.py new file mode 100644 index 0000000000000000000000000000000000000000..7296793a33c3fc93c8197b350f70ae2d8a3e7876 --- /dev/null +++ b/py_opengauss/driver/dbapi20.py @@ -0,0 +1,418 @@ +## +# .driver.dbapi20 - DB-API 2.0 Implementation +## +""" +DB-API 2.0 conforming interface using postgresql.driver. +""" +threadsafety = 1 +paramstyle = 'pyformat' +apilevel = '2.0' + +from operator import itemgetter +from functools import partial +import datetime +import time +import re + +from .. import clientparameters as pg_param +from .. import driver as pg_driver +from .. import types as pg_type +from .. import string as pg_str +from .pq3 import Connection + +## +# Basically, is it a mapping, or is it a sequence? +# If findall()'s first index is 's', it's a sequence. +# If it starts with '(', it's mapping. +# The pain here is due to a need to recognize any %% escapes. +parameters_re = re.compile( + r'(?:%%)+|%(s|[(][^)]*[)]s)' +) +def percent_parameters(sql): + # filter any %% matches(empty strings). + return [x for x in parameters_re.findall(sql) if x] + +def convert_keywords(keys, mapping): + return [mapping[k] for k in keys] + +from py_opengauss.exceptions import \ + Error, DataError, InternalError, \ + ICVError as IntegrityError, \ + SEARVError as ProgrammingError, \ + IRError as OperationalError, \ + DriverError as InterfaceError, \ + Warning +DatabaseError = Error +class NotSupportedError(DatabaseError): + pass + +STRING = str +BINARY = bytes +NUMBER = int +DATETIME = datetime.datetime +ROWID = int + +Binary = BINARY +Date = datetime.date +Time = datetime.time +Timestamp = datetime.datetime +DateFromTicks = lambda x: Date(*time.localtime(x)[:3]) +TimeFromTicks = lambda x: Time(*time.localtime(x)[3:6]) +TimestampFromTicks = lambda x: Timestamp(*time.localtime(x)[:7]) + +def dbapi_type(typid): + if typid in ( + pg_type.TEXTOID, + pg_type.CHAROID, + pg_type.VARCHAROID, + pg_type.NAMEOID, + pg_type.CSTRINGOID, + ): + return STRING + elif typid == pg_type.BYTEAOID: + return BINARY + elif typid in (pg_type.INT8OID, pg_type.INT2OID, pg_type.INT4OID): + return NUMBER + elif typid in (pg_type.TIMESTAMPOID, pg_type.TIMESTAMPTZOID): + return DATETIME + elif typid == pg_type.OIDOID: + return ROWID + +class Portal(object): + """ + Manages read() interfaces to a chunks iterator. + """ + def __init__(self, chunks): + self.chunks = chunks + self.buf = [] + self.pos = 0 + + def __next__(self): + try: + r = self.buf[self.pos] + self.pos += 1 + return r + except IndexError: + # Any alledged infinite recursion will stop on the StopIteration + # thrown by this next(). Recursion is unlikely to occur more than + # once; specifically, empty chunks would need to be returned + # by this invocation of next(). 
+ self.buf = next(self.chunks) + self.pos = 0 + return self.__next__() + + def readall(self): + self.buf = self.buf[self.pos:] + self.pos = 0 + for x in self.chunks: + self.buf.extend(x) + r = self.buf + self.buf = [] + return r + + def read(self, amount): + try: + while (len(self.buf) - self.pos) < amount: + self.buf.extend(next(self.chunks)) + end = self.pos + amount + except StopIteration: + # end of cursor + end = len(self.buf) + + r = self.buf[self.pos:end] + del self.buf[:end] + self.pos = 0 + return r + +class Cursor(object): + rowcount = -1 + arraysize = 1 + description = None + + def __init__(self, C): + self.database = self.connection = C + self.description = () + self.__portals = [] + + # Describe the "real" cursor as a "portal". + # This should keep ambiguous terminology out of the adaptation. + def _portal(): + def fget(self): + if self.__portals is None: + raise Error("cursor is closed", + source = 'CLIENT', creator = self.database + ) + try: + p = self.__portals[0] + except IndexError: + raise InterfaceError("no portal on stack") + return p + def fdel(self): + if self.__portals is None: + raise Error("cursor is closed", + source = 'CLIENT', creator = self.database + ) + try: + del self.__portals[0] + except IndexError: + raise InterfaceError("no portal on stack") + return locals() + _portal = property(**_portal()) + + def setinputsizes(self, sizes): + if self.__portals is None: + raise Error("cursor is closed", + source = 'CLIENT', creator = self.database) + + def setoutputsize(self, sizes, columns = None): + if self.__portals is None: + raise Error("cursor is closed", + source = 'CLIENT', creator = self.database) + + def callproc(self, proname, args): + if self.__portals is None: + raise Error("cursor is closed", + source = 'CLIENT', creator = self.database) + + p = self.database.prepare("SELECT %s(%s)" %( + proname, ','.join([ + '$' + str(x) for x in range(1, len(args) + 1) + ]) + )) + self.__portals.insert(0, Portal(p.chunks(*args))) + return args + + def fetchone(self): + try: + return next(self._portal) + except StopIteration: + return None + + def __next__(self): + return next(self._portal) + next = __next__ + + def __iter__(self): + return self + + def fetchmany(self, arraysize = None): + return self._portal.read(arraysize or self.arraysize or 1) + + def fetchall(self): + return self._portal.readall() + + def nextset(self): + del self._portal + return len(self.__portals) or None + + def fileno(self): + return self.database.fileno() + + def _convert_query(self, string): + parts = list(pg_str.split(string)) + style = None + count = 0 + keys = [] + kmap = {} + transformer = tuple + rparts = [] + for part in parts: + if part.__class__ is ().__class__: + # skip quoted portions + rparts.append(part) + else: + r = percent_parameters(part) + pcount = 0 + for x in r: + if x == 's': + pcount += 1 + else: + x = x[1:-2] + if x not in keys: + kmap[x] = '$' + str(len(keys) + 1) + keys.append(x) + if r: + if pcount: + # format + params = tuple([ + '$' + str(i+1) for i in range(count, count + pcount) + ]) + count += pcount + rparts.append(part % params) + else: + # pyformat + rparts.append(part % kmap) + else: + # no parameters identified in string + rparts.append(part) + + if keys: + if count: + raise TypeError( + "keyword parameters and positional parameters used in query" + ) + transformer = partial(convert_keywords, keys) + count = len(keys) + + return (pg_str.unsplit(rparts) if rparts else string, transformer, count) + + def execute(self, statement, parameters = ()): + if 
self.__portals is None: + raise Error("cursor is closed", + source = 'CLIENT', creator = self.database) + + sql, pxf, nparams = self._convert_query(statement) + if nparams != -1 and len(parameters) != nparams: + raise TypeError( + "statement require %d parameters, given %d" %( + nparams, len(parameters) + ) + ) + ps = self.database.prepare(sql) + c = ps.chunks(*pxf(parameters)) + if ps._output is not None and len(ps._output) > 0: + # name, relationId, columnNumber, typeId, typlen, typmod, format + self.rowcount = -1 + self.description = tuple([ + (self.database.typio.decode(x[0]), dbapi_type(x[3]), + None, None, None, None, None) + for x in ps._output + ]) + self.__portals.insert(0, Portal(c)) + else: + self.rowcount = c.count() + if self.rowcount is None: + self.rowcount = -1 + self.description = None + # execute bumps any current portal + if self.__portals: + del self._portal + return self + + def executemany(self, statement, parameters): + if self.__portals is None: + raise Error("cursor is closed", + source = 'CLIENT', creator = self.database) + + sql, pxf, nparams = self._convert_query(statement) + ps = self.database.prepare(sql) + if ps._input is not None: + ps.load_rows(map(pxf, parameters)) + else: + ps.load_rows(parameters) + self.rowcount = -1 + return self + + def close(self): + if self.__portals is None: + raise Error("cursor is closed", + source = 'CLIENT', creator = self.database) + self.description = None + self.__portals = None + +class Connection(Connection): + """ + DB-API 2.0 connection implementation for PG-API connection objects. + """ + from py_opengauss.exceptions import \ + Error, DataError, InternalError, \ + ICVError as IntegrityError, \ + SEARVError as ProgrammingError, \ + IRError as OperationalError, \ + DriverError as InterfaceError, \ + Warning + DatabaseError = DatabaseError + NotSupportedError = NotSupportedError + + # Explicitly manage DB-API connected state to properly + # throw the already closed error. + _dbapi_connected_flag = False + + def autocommit_set(self, val): + if val: + # already in autocommit mode. + if self._xact is None: + return + self._xact.rollback() + self._xact = None + else: + if self._xact is not None: + return + self._xact = self.xact() + self._xact.start() + + def autocommit_get(self): + return self._xact is None + + def autocommit_del(self): + self.autocommit = False + + autocommit = property( + fget = autocommit_get, + fset = autocommit_set, + fdel = autocommit_del, + ) + del autocommit_set, autocommit_get, autocommit_del + + def connect(self, *args, **kw): + super().connect(*args, **kw) + self._xact = self.xact() + self._xact.start() + self._dbapi_connected_flag = True + + def close(self): + if self.closed and self._dbapi_connected_flag: + raise Error( + "connection already closed", + source = 'CLIENT', + creator = self + ) + self._dbapi_connected_flag = True + super().close() + + def cursor(self): + return Cursor(self) + + def commit(self): + if self._xact is None: + raise InterfaceError( + "commit on connection in autocommit mode", + source = 'CLIENT', + details = { + 'hint': 'The "autocommit" property on the connection was set to True.' + }, + creator = self + ) + self._xact.commit() + self._xact = self.xact() + self._xact.start() + + def rollback(self): + if self._xact is None: + raise InterfaceError( + "rollback on connection in autocommit mode", + source = 'DRIVER', + details = { + 'hint': 'The "autocommit" property on the connection was set to True.' 
+ },
+ creator = self
+ )
+ self._xact.rollback()
+ self._xact = self.xact()
+ self._xact.start()
+
+driver = pg_driver.Driver(connection = Connection)
+def connect(**kw):
+ """
+ Create a DB-API connection using the given parameters.
+
+ Due to the way defaults are populated, when connecting to a local filesystem socket
+ using the `unix` keyword parameter, `host` and `port` must also be set to ``None``.
+ """
+ std_params = pg_param.collect(prompt_title = None)
+ params = pg_param.normalize(
+ list(pg_param.denormalize_parameters(std_params)) + \
+ list(pg_param.denormalize_parameters(kw))
+ )
+ pg_param.resolve_password(params)
+ return driver.connect(**params)
diff --git a/py_opengauss/driver/pq3.py b/py_opengauss/driver/pq3.py
new file mode 100644
index 0000000000000000000000000000000000000000..24976de025842310facc2a363285bc2c9bb0ea1a
--- /dev/null
+++ b/py_opengauss/driver/pq3.py
@@ -0,0 +1,3064 @@
+##
+# .driver.pq3 - interface to PostgreSQL using PQ v3.0.
+##
+"""
+PG-API interface for PostgreSQL using PQ version 3.0.
+"""
+import os
+import sys
+import weakref
+import socket
+from traceback import format_exception
+from itertools import repeat, chain, count
+from functools import partial
+from abc import abstractmethod
+from codecs import lookup as lookup_codecs
+
+from operator import itemgetter
+get0 = itemgetter(0)
+get1 = itemgetter(1)
+
+from .. import lib as pg_lib
+
+from .. import versionstring as pg_version
+from .. import iri as pg_iri
+from .. import exceptions as pg_exc
+from .. import string as pg_str
+from .. import api as pg_api
+from .. import message as pg_msg
+from ..encodings.aliases import get_python_name
+from ..string import quote_ident
+
+from ..python.itertools import interlace, chunk
+from ..python.socket import SocketFactory
+from ..python.functools import process_tuple, process_chunk
+from ..python.functools import Composition as compose
+
+from ..protocol import xact3 as xact
+from ..protocol import element3 as element
+from ..protocol import client3 as client
+from ..protocol.message_types import message_types
+
+from ..notifyman import NotificationManager
+from .. import types as pg_types
+from ..types import io as pg_types_io
+from ..types.io import lib as io_lib
+
+import warnings
+
+# Map element3.Notice field identifiers
+# to names used by message.Message.
+notice_field_to_name = { + message_types[b'S'[0]] : 'severity', + message_types[b'C'[0]] : 'code', + message_types[b'M'[0]] : 'message', + message_types[b'D'[0]] : 'detail', + message_types[b'H'[0]] : 'hint', + message_types[b'W'[0]] : 'context', + message_types[b'P'[0]] : 'position', + message_types[b'p'[0]] : 'internal_position', + message_types[b'q'[0]] : 'internal_query', + message_types[b'F'[0]] : 'file', + message_types[b'L'[0]] : 'line', + message_types[b'R'[0]] : 'function', +} +del message_types + +notice_field_from_name = dict( + (v, k) for (k, v) in notice_field_to_name.items() +) + +could_not_connect = element.ClientError(( + (b'S', 'FATAL'), + (b'C', '08001'), + (b'M', "could not establish connection to server"), +)) + +# generate an id for a client statement or cursor +def ID(s, title = None, IDNS = 'py:'): + return IDNS + hex(id(s)) + +def declare_statement_string( + cursor_id, + statement_string, + insensitive = True, + scroll = True, + hold = True +): + s = 'DECLARE ' + cursor_id + if insensitive is True: + s += ' INSENSITIVE' + if scroll is True: + s += ' SCROLL' + s += ' CURSOR' + if hold is True: + s += ' WITH HOLD' + else: + s += ' WITHOUT HOLD' + return s + ' FOR ' + statement_string + +def direction_str_to_bool(str): + s = str.upper() + if s == 'FORWARD': + return True + elif s == 'BACKWARD': + return False + else: + raise ValueError("invalid direction " + repr(str)) + +def direction_to_bool(v): + if isinstance(v, str): + return direction_str_to_bool(v) + elif v is not True and v is not False: + raise TypeError("invalid direction " + repr(v)) + else: + return v + +class TypeIO(pg_api.TypeIO): + """ + A class that manages I/O for a given configuration. Normally, a connection + would create an instance, and configure it based upon the version and + configuration of PostgreSQL that it is connected to. + """ + _e_factors = ('database',) + strio = (None, None, str) + + def __init__(self, database): + self.database = database + self.encoding = None + strio = self.strio + self._cache = { + # Encoded character strings + pg_types.ACLITEMOID : strio, # No binary functions. + pg_types.NAMEOID : strio, + pg_types.BPCHAROID : strio, + pg_types.VARCHAROID : strio, + pg_types.CSTRINGOID : strio, + pg_types.TEXTOID : strio, + pg_types.REGTYPEOID : strio, + pg_types.REGPROCOID : strio, + pg_types.REGPROCEDUREOID : strio, + pg_types.REGOPEROID : strio, + pg_types.REGOPERATOROID : strio, + pg_types.REGCLASSOID : strio, + } + self.typinfo = {} + super().__init__() + + def lookup_type_info(self, typid): + return self.database.sys.lookup_type(typid) + + def lookup_composite_type_info(self, typid): + return self.database.sys.lookup_composite(typid) + + def lookup_domain_basetype(self, typid): + if self.database.version_info[:2] >= (8, 4): + return self.lookup_domain_basetype_84(typid) + + while typid: + r = self.database.sys.lookup_basetype(typid) + if not r[0][0]: + return typid + else: + typid = r[0][0] + + def lookup_domain_basetype_84(self, typid): + r = self.database.sys.lookup_basetype_recursive(typid) + return r[0][0] + + def set_encoding(self, value): + """ + Set a new client encoding. 
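+ The value is lowercased and stripped, mapped to a Python codec name when
+ possible, and then looked up in the codecs registry; the resulting encoder
+ and decoder are used by encode() and decode() for wire text.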
+ """ + self.encoding = value.lower().strip() + enc = get_python_name(self.encoding) + ci = lookup_codecs(enc or self.encoding) + self._encode, self._decode, *_ = ci + + def encode(self, string_data): + return self._encode(string_data)[0] + + def decode(self, bytes_data): + return self._decode(bytes_data)[0] + + def encodes(self, iter, get0 = get0): + """ + Encode the items in the iterable in the configured encoding. + """ + return map(compose((self._encode, get0)), iter) + + def decodes(self, iter, get0 = get0): + """ + Decode the items in the iterable from the configured encoding. + """ + return map(compose((self._decode, get0)), iter) + + def resolve_pack(self, typid): + return self.resolve(typid)[0] or self.encode + + def resolve_unpack(self, typid): + return self.resolve(typid)[1] or self.decode + + def attribute_map(self, pq_descriptor): + return zip(self.decodes(pq_descriptor.keys()), count()) + + def sql_type_from_oid(self, oid, qi = quote_ident): + if oid in pg_types.oid_to_sql_name: + return pg_types.oid_to_sql_name[oid] + if oid in self.typinfo: + nsp, name, *_ = self.typinfo[oid] + return qi(nsp) + '.' + qi(name) + name = pg_types.oid_to_name.get(oid) + if name: + return 'pg_catalog.%s' % name + else: + return None + + def type_from_oid(self, oid): + if oid in self._cache: + typ = self._cache[oid][2] + return typ + + def resolve_descriptor(self, desc, index): + """ + Create a sequence of I/O routines from a pq descriptor. + """ + return [ + (self.resolve(x[3]) or (None, None))[index] for x in desc + ] + + # lookup a type's IO routines from a given typid + def resolve(self, + typid : int, + from_resolution_of : [int] = (), + builtins = pg_types_io.resolve, + quote_ident = quote_ident + ): + if from_resolution_of and typid in from_resolution_of: + raise TypeError( + "type, %d, is already being looked up: %r" %( + typid, from_resolution_of + ) + ) + typid = int(typid) + typio = None + + if typid in self._cache: + typio = self._cache[typid] + else: + typio = builtins(typid) + if typio is not None: + # If typio is a tuple, it's a constant pair: (pack, unpack) + # otherwise, it's an I/O pair constructor. + if typio.__class__ is not tuple: + typio = typio(typid, self) + self._cache[typid] = typio + + if typio is None: + # Lookup the type information for the typid as it's not cached. + ## + ti = self.lookup_type_info(typid) + if ti is not None: + typnamespace, typname, typtype, typlen, typelem, typrelid, \ + ae_typid, ae_hasbin_input, ae_hasbin_output = ti + self.typinfo[typid] = ( + typnamespace, typname, typrelid, int(typelem) if ae_typid else None + ) + if typrelid: + # Row type + # + # The attribute name map, + # column I/O, + # column type Oids + # are needed to build the packing pair. + attmap = {} + cio = [] + typids = [] + attnames = [] + i = 0 + for x in self.lookup_composite_type_info(typrelid): + attmap[x[1]] = i + attnames.append(x[1]) + if x[2]: + # This is a domain + fieldtypid = self.lookup_domain_basetype(x[0]) + else: + fieldtypid = x[0] + typids.append(x[0]) + te = self.resolve( + fieldtypid, list(from_resolution_of) + [typid] + ) + cio.append((te[0] or self.encode, te[1] or self.decode)) + i += 1 + self._cache[typid] = typio = self.record_io_factory( + cio, typids, attmap, list( + map(self.sql_type_from_oid, typids) + ), attnames, + typrelid, + quote_ident(typnamespace) + '.' 
+ \ + quote_ident(typname), + ) + elif ae_typid is not None: + # resolve the element type and I/O pair + te = self.resolve( + int(typelem), + from_resolution_of = list(from_resolution_of) + [typid] + ) or (None, None) + typio = self.array_io_factory( + te[0] or self.encode, + te[1] or self.decode, + typelem, + ae_hasbin_input, + ae_hasbin_output + ) + self._cache[typid] = typio + else: + typio = None + if typtype == b'd': + basetype = self.lookup_domain_basetype(typid) + typio = self.resolve( + basetype, + from_resolution_of = list(from_resolution_of) + [typid] + ) + elif typtype == b'p' and typnamespace == 'pg_catalog' and typname == 'record': + # anonymous record type + typio = self.anon_record_io_factory() + + if not typio: + typio = self.strio + + self._cache[typid] = typio + else: + # Throw warning about type without entry in pg_type? + typio = self.strio + return typio + + def identify(self, **identity_mappings): + """ + Explicitly designate the I/O handler for the specified type. + + Primarily used in cases involving UDTs. + """ + # get them ordered; we process separately, then recombine. + id = list(identity_mappings.items()) + ios = [pg_types_io.resolve(x[0]) for x in id] + oids = list(self.database.sys.regtypes([x[1] for x in id])) + + self._cache.update([ + (oid, io if io.__class__ is tuple else io(oid, self)) + for oid, io in zip(oids, ios) + ]) + + def array_parts(self, array, ArrayType = pg_types.Array): + if array.__class__ is not ArrayType: + # Assume the data is a nested list. + array = ArrayType(array) + return ( + array.elements(), + array.dimensions, + array.lowerbounds + ) + + def array_from_parts(self, parts, ArrayType = pg_types.Array): + elements, dimensions, lowerbounds = parts + return ArrayType.from_elements( + elements, + lowerbounds = lowerbounds, + upperbounds = [x + lb - 1 for x, lb in zip(dimensions, lowerbounds)] + ) + + ## + # array_io_factory - build I/O pair for ARRAYs + ## + def array_io_factory( + self, + pack_element, unpack_element, + typoid, # array element id + hasbin_input, hasbin_output, + array_pack = io_lib.array_pack, + array_unpack = io_lib.array_unpack, + ): + packed_typoid = io_lib.ulong_pack(typoid) + if hasbin_input: + def pack_an_array(data, get_parts = self.array_parts): + elements, dimensions, lowerbounds = get_parts(data) + return array_pack(( + 0, # unused flags + typoid, dimensions, lowerbounds, + (x if x is None else pack_element(x) for x in elements), + )) + else: + # signals string formatting + pack_an_array = None + + if hasbin_output: + def unpack_an_array(data, array_from_parts = self.array_from_parts): + flags, typoid, dims, lbs, elements = array_unpack(data) + return array_from_parts(((x if x is None else unpack_element(x) for x in elements), dims, lbs)) + else: + # signals string formatting + unpack_an_array = None + + return (pack_an_array, unpack_an_array, pg_types.Array) + + def RowTypeFactory(self, attribute_map = {}, _Row = pg_types.Row.from_sequence, composite_relid = None): + return partial(_Row, attribute_map) + + ## + # record_io_factory - Build an I/O pair for RECORDs + ## + def record_io_factory(self, + column_io, typids, attmap, typnames, attnames, composite_relid, composite_name, + get0 = get0, + get1 = get1, + fmt_errmsg = "failed to {0} attribute {1}, {2}::{3}, of composite {4} from wire data".format + ): + # column_io: sequence (pack,unpack) tuples corresponding to the columns. + # typids: sequence of type Oids; index must correspond to the composite's. + # attmap: mapping of column name to index number. 
+ # typnames: sequence of sql type names in order. + # attnames: sequence of attribute names in order. + # composite_relid: oid of the composite relation. + # composite_name: the name of the composite type. + + fpack = tuple(map(get0, column_io)) + funpack = tuple(map(get1, column_io)) + row_constructor = self.RowTypeFactory(attribute_map = attmap, composite_relid = composite_relid) + + def raise_pack_tuple_error(cause, procs, tup, itemnum): + data = repr(tup[itemnum]) + if len(data) > 80: + # Be sure not to fill screen with noise. + data = data[:75] + ' ...' + self.raise_client_error(element.ClientError(( + (b'C', '--cIO',), + (b'S', 'ERROR',), + (b'M', fmt_errmsg('pack', itemnum, attnames[itemnum], typnames[itemnum], composite_name),), + (b'W', data,), + (b'P', str(itemnum),) + )), cause = cause) + + def raise_unpack_tuple_error(cause, procs, tup, itemnum): + data = repr(tup[itemnum]) + if len(data) > 80: + # Be sure not to fill screen with noise. + data = data[:75] + ' ...' + self.raise_client_error(element.ClientError(( + (b'C', '--cIO',), + (b'S', 'ERROR',), + (b'M', fmt_errmsg('unpack', itemnum, attnames[itemnum], typnames[itemnum], composite_name),), + (b'W', data,), + (b'P', str(itemnum),), + )), cause = cause) + + def unpack_a_record(data, + unpack = io_lib.record_unpack, + process_tuple = process_tuple, + row_constructor = row_constructor + ): + data = tuple([x[1] for x in unpack(data)]) + return row_constructor(process_tuple(funpack, data, raise_unpack_tuple_error)) + + sorted_atts = sorted(attmap.items(), key = get1) + def pack_a_record(data, + pack = io_lib.record_pack, + process_tuple = process_tuple, + ): + if isinstance(data, dict): + data = [data.get(k) for k,_ in sorted_atts] + return pack( + tuple(zip( + typids, + process_tuple(fpack, tuple(data), raise_pack_tuple_error) + )) + ) + return (pack_a_record, unpack_a_record, tuple) + + def anon_record_io_factory(self): + def raise_unpack_tuple_error(cause, procs, tup, itemnum): + data = repr(tup[itemnum]) + if len(data) > 80: + # Be sure not to fill screen with noise. + data = data[:75] + ' ...' + self.raise_client_error(element.ClientError(( + (b'C', '--cIO',), + (b'S', 'ERROR',), + (b'M', 'Could not unpack element {} from anonymous record'.format(itemnum)), + (b'W', data,), + (b'P', str(itemnum),) + )), cause = cause) + + def _unpack_record(data, unpack = io_lib.record_unpack, process_tuple = process_tuple): + record = list(unpack(data)) + coloids = tuple(x[0] for x in record) + colio = map(self.resolve, coloids) + column_unpack = tuple(c[1] or self.decode for c in colio) + + data = tuple(x[1] for x in record) + + return process_tuple(column_unpack, data, raise_unpack_tuple_error) + + return (None, _unpack_record) + + def raise_client_error(self, error_message, cause = None, creator = None): + m = { + notice_field_to_name[k] : v + for k, v in error_message.items() + # don't include unknown messages in this list. 
+ if k in notice_field_to_name + } + c = m.pop('code') + ms = m.pop('message') + client_error = self.lookup_exception(c) + client_error = client_error(ms, code = c, details = m, source = 'CLIENT', creator = creator or self.database) + client_error.database = self.database + if cause is not None: + raise client_error from cause + else: + raise client_error + + def lookup_exception(self, code, errorlookup = pg_exc.ErrorLookup,): + return errorlookup(code) + + def lookup_warning(self, code, warninglookup = pg_exc.WarningLookup,): + return warninglookup(code) + + def raise_server_error(self, error_message, cause = None, creator = None): + m = dict(self.decode_notice(error_message)) + c = m.pop('code') + ms = m.pop('message') + server_error = self.lookup_exception(c) + server_error = server_error(ms, code = c, details = m, source = 'SERVER', creator = creator or self.database) + server_error.database = self.database + if cause is not None: + raise server_error from cause + else: + raise server_error + + def raise_error(self, error_message, ClientError = element.ClientError, **kw): + if 'creator' not in kw: + kw['creator'] = getattr(self.database, '_controller', self.database) or self.database + + if error_message.__class__ is ClientError: + self.raise_client_error(error_message, **kw) + else: + self.raise_server_error(error_message, **kw) + + ## + # Used by decode_notice() + def _decode_failsafe(self, data): + decode = self._decode + i = iter(data) + for x in i: + try: + # prematurely optimized for your viewing displeasure. + v = x[1] + yield (x[0], decode(v)[0]) + for x in i: + v = x[1] + yield (x[0], decode(v)[0]) + except UnicodeDecodeError: + # Fallback to the bytes representation. + # This should be sufficiently informative in most cases, + # and in the cases where it isn't, an element traceback should + # ultimately yield the pertinent information + yield (x[0], repr(x[1])[2:-1]) + + def decode_notice(self, notice): + notice = self._decode_failsafe(notice.items()) + return { + notice_field_to_name[k] : v + for k, v in notice + # don't include unknown messages in this list. + if k in notice_field_to_name + } + + def emit_server_message(self, message, creator = None, + MessageType = pg_msg.Message + ): + fields = self.decode_notice(message) + m = fields.pop('message') + c = fields.pop('code') + + if fields['severity'].upper() == 'WARNING': + MessageType = self.lookup_warning(c) + + message = MessageType(m, code = c, details = fields, + creator = creator, source = 'SERVER') + message.database = self.database + message.emit() + return message + + def emit_client_message(self, message, creator = None, + MessageType = pg_msg.Message + ): + fields = { + notice_field_to_name[k] : v + for k, v in message.items() + # don't include unknown messages in this list. + if k in notice_field_to_name + } + m = fields.pop('message') + c = fields.pop('code') + + if fields['severity'].upper() == 'WARNING': + MessageType = self.lookup_warning(c) + + message = MessageType(m, code = c, details = fields, + creator = creator, source = 'CLIENT') + message.database = self.database + message.emit() + return message + + def emit_message(self, message, ClientNotice = element.ClientNotice, **kw): + if message.__class__ is ClientNotice: + return self.emit_client_message(message, **kw) + else: + return self.emit_server_message(message, **kw) + +## +# This class manages all the functionality used to get +# rows from a PostgreSQL portal/cursor. 
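+#
+# The concrete strategies are provided by the subclasses defined below:
+# Cursor for scrollable, SQL-declared cursors; FetchAll, SingleXactFetch and
+# SingleXactCopy for results consumed within a single protocol transaction;
+# and the MultiXact* classes for chunked, on-demand streaming.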
+class Output(object): + _output = None + _output_io = None + _output_formats = None + _output_attmap = None + + closed = False + cursor_id = None + statement = None + parameters = None + + _complete_message = None + + @abstractmethod + def _init(self): + """ + Bind a cursor based on the configured parameters. + """ + # The local initialization for the specific cursor. + + def __init__(self, cursor_id, wref = weakref.ref, ID = ID): + self.cursor_id = cursor_id + if self.statement is not None: + stmt = self.statement + self._output = stmt._output + self._output_io = stmt._output_io + self._row_constructor = stmt._row_constructor + self._output_formats = stmt._output_formats or () + self._output_attmap = stmt._output_attmap + + self._pq_cursor_id = self.database.typio.encode(cursor_id) + # If the cursor's id was generated, it should be garbage collected. + if cursor_id == ID(self): + self.database.pq.register_cursor(self, self._pq_cursor_id) + self._quoted_cursor_id = '"' + cursor_id.replace('"', '""') + '"' + self._init() + + def __iter__(self): + return self + + def close(self): + if self.closed is False: + self.database.pq.trash_cursor(self._pq_cursor_id) + self.closed = True + + def _ins(self, *args): + return xact.Instruction(*args, asynchook = self.database._receive_async) + + def _pq_xp_describe(self): + return (element.DescribePortal(self._pq_cursor_id),) + + def _pq_xp_bind(self): + return ( + element.Bind( + self._pq_cursor_id, + self.statement._pq_statement_id, + self.statement._input_formats, + self.statement._pq_parameters(self.parameters), + self._output_formats, + ), + ) + + def _pq_xp_fetchall(self): + return ( + element.Bind( + b'', + self.statement._pq_statement_id, + self.statement._input_formats, + self.statement._pq_parameters(self.parameters), + self._output_formats, + ), + element.Execute(b'', 0xFFFFFFFF), + ) + + def _pq_xp_declare(self): + return ( + element.Parse(b'', self.database.typio.encode( + declare_statement_string( + str(self._quoted_cursor_id), + str(self.statement.string) + ) + ), () + ), + element.Bind( + b'', b'', self.statement._input_formats, + self.statement._pq_parameters(self.parameters), () + ), + element.Execute(b'', 1), + ) + + def _pq_xp_execute(self, quantity): + return ( + element.Execute(self._pq_cursor_id, quantity), + ) + + def _pq_xp_fetch(self, direction, quantity): + ## + # It's an SQL declared cursor, manually construct the fetch commands. + qstr = "FETCH " + ("FORWARD " if direction else "BACKWARD ") + if quantity is None: + qstr = qstr + "ALL IN " + self._quoted_cursor_id + else: + qstr = qstr \ + + str(quantity) + " IN " + self._quoted_cursor_id + return ( + element.Parse(b'', self.database.typio.encode(qstr), ()), + element.Bind(b'', b'', (), (), self._output_formats), + # The "limit" is defined in the fetch query. 
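+ # The 0xFFFFFFFF row count simply asks Execute for everything the portal yields.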
+ element.Execute(b'', 0xFFFFFFFF), + ) + + def _pq_xp_move(self, position, whence): + return ( + element.Parse(b'', + b'MOVE ' + whence + b' ' + position + b' IN ' + \ + self.database.typio.encode(self._quoted_cursor_id), + () + ), + element.Bind(b'', b'', (), (), ()), + element.Execute(b'', 1), + ) + + def _process_copy_chunk(self, x): + if x: + if x[0].__class__ is not bytes or x[-1].__class__ is not bytes: + return [ + y for y in x if y.__class__ is bytes + ] + return x + + # Process the element.Tuple message in x for column() + def _process_tuple_chunk_Column(self, x, range = range): + unpack = self._output_io[0] + # get the raw data for the first column + l = [y[0] for y in x] + # iterate over the range to keep track + # of which item we're processing. + r = range(len(l)) + try: + return [unpack(l[i]) for i in r] + except Exception: + cause = sys.exc_info()[1] + try: + i = next(r) + except StopIteration: + i = len(l) + self._raise_column_tuple_error(cause, self._output_io, (l[i],), 0) + + # Process the element.Tuple message in x for rows() + def _process_tuple_chunk_Row(self, x, + proc = process_chunk, + ): + rc = self._row_constructor + return [ + rc(y) + for y in proc(self._output_io, x, self._raise_column_tuple_error) + ] + + # Process the elemnt.Tuple messages in `x` for chunks() + def _process_tuple_chunk(self, x, proc = process_chunk): + return proc(self._output_io, x, self._raise_column_tuple_error) + + def _raise_column_tuple_error(self, cause, procs, tup, itemnum): + # For column processing. + # The element traceback will include the full list of parameters. + data = repr(tup[itemnum]) + if len(data) > 80: + # Be sure not to fill screen with noise. + data = data[:75] + ' ...' + + em = element.ClientError(( + (b'S', 'ERROR'), + (b'C', "--CIO"), + (b'M', "failed to unpack column %r, %s::%s, from wire data" %( + itemnum, + self.column_names[itemnum], + self.database.typio.sql_type_from_oid( + self.statement.pg_column_types[itemnum] + ) or '', + ) + ), + (b'D', data), + (b'H', "Try casting the column to 'text'."), + (b'P', str(itemnum)), + )) + self.database.typio.raise_client_error(em, creator = self, cause = cause) + + @property + def state(self): + if self.closed: + return 'closed' + else: + return 'open' + + @property + def column_names(self): + if self._output is not None: + return list(self.database.typio.decodes(self._output.keys())) + # `None` if _output does not exist; not row data + + @property + def column_types(self): + if self._output is not None: + return [self.database.typio.type_from_oid(x[3]) for x in self._output] + # `None` if _output does not exist; not row data + + @property + def pg_column_types(self): + if self._output is not None: + return [x[3] for x in self._output] + # `None` if _output does not exist; not row data + + @property + def sql_column_types(self): + return [ + self.database.typio.sql_type_from_oid(x) + for x in self.pg_column_types + ] + + def command(self): + """ + The completion message's command identifier. + """ + if self._complete_message is not None: + return self._complete_message.extract_command().decode('ascii') + + def count(self): + """ + The completion message's count number. + """ + if self._complete_message is not None: + return self._complete_message.extract_count() + +class Chunks(Output, pg_api.Chunks): + pass + +## +# FetchAll - A Chunks cursor that gets *all* the records in the cursor. 
+# +# It has added complexity over other variants as in order to stream results, +# chunks have to be removed from the protocol transaction's received messages. +# If this wasn't done, the entire result set would be fully buffered prior +# to processing. +class FetchAll(Chunks): + _e_factors = ('statement', 'parameters',) + def _e_metas(self): + yield ('type', type(self).__name__) + + def __init__(self, statement, parameters): + self.statement = statement + self.parameters = parameters + self.database = statement.database + Output.__init__(self, '') + + def _init(self, + null = element.Null.type, + complete = element.Complete.type, + bindcomplete = element.BindComplete.type, + parsecomplete = element.ParseComplete.type, + ): + expect = self._expect + self._xact = self._ins( + self._pq_xp_fetchall() + (element.SynchronizeMessage,) + ) + self.database._pq_push(self._xact, self) + + # Get more messages until the first Tuple is seen. + STEP = self.database._pq_step + while self._xact.state != xact.Complete: + STEP() + for x in self._xact.messages_received(): + if x.__class__ is tuple or expect == x.type: + # No need to step anymore once this is seen. + return + elif x.type == null: + # The protocol transaction is going to be complete.. + self.database._pq_complete() + self._xact = None + return + elif x.type == complete: + self._complete_message = x + self.database._pq_complete() + # If this was a select/copy cursor, + # the data messages would have caused an earlier + # return. It's empty. + self._xact = None + return + elif x.type in (bindcomplete, parsecomplete): + # Noise. + pass + else: + # This should have been caught by the protocol transaction. + # "Can't happen". + self.database._pq_complete() + if self._xact.fatal is None: + self._xact.fatal = False + self._xact.error_message = element.ClientError(( + (b'S', 'ERROR'), + (b'C', "--000"), + (b'M', "unexpected message type " + repr(x.type)) + )) + self.database.typio.raise_client_error(self._xact.error_message, creator = self) + return + + def __next__(self, + data_types = (tuple,bytes), + complete = element.Complete.type, + ): + x = self._xact + # self._xact = None; means that the cursor has been exhausted. + if x is None: + raise StopIteration + + # Finish the protocol transaction. + STEP = self.database._pq_step + while x.state is not xact.Complete and not x.completed: + STEP() + + # fatal is None == no error + # fatal is True == dead connection + # fatal is False == dead transaction + if x.fatal is not None: + self.database.typio.raise_error(x.error_message, creator = getattr(self, '_controller', self) or self) + + # no messages to process? + if not x.completed: + # Transaction has been cleaned out of completed? iterator is done. + self._xact = None + self.close() + raise StopIteration + + # Get the chunk to be processed. + chunk = [ + y for y in x.completed[0][1] + if y.__class__ in data_types + ] + r = self._process_chunk(chunk) + + # Scan for _complete_message. + # Arguably, this can fail, but it would be a case + # where multiple sync messages were issued. Something that's + # not naturally occurring. + for y in x.completed[0][1][-3:]: + if getattr(y, 'type', None) == complete: + self._complete_message = y + + # Remove it, it's been processed. 
+ del x.completed[0] + return r + +class SingleXactCopy(FetchAll): + _expect = element.CopyToBegin.type + _process_chunk = FetchAll._process_copy_chunk + +class SingleXactFetch(FetchAll): + _expect = element.Tuple.type + +class MultiXactStream(Chunks): + chunksize = 1024 * 4 + # only tuple streams + _process_chunk = Output._process_tuple_chunk + + def _e_metas(self): + yield ('chunksize', self.chunksize) + yield ('type', self.__class__.__name__) + + def __init__(self, statement, parameters, cursor_id): + self.statement = statement + self.parameters = parameters + self.database = statement.database + Output.__init__(self, cursor_id or ID(self)) + + @abstractmethod + def _bind(self): + """ + Generate the commands needed to bind the cursor. + """ + + @abstractmethod + def _fetch(self): + """ + Generate the commands needed to bind the cursor. + """ + + def _init(self): + self._command = self._fetch() + self._xact = self._ins(self._bind() + self._command) + self.database._pq_push(self._xact, self) + + def __next__(self, tuple_type = tuple): + x = self._xact + if x is None: + raise StopIteration + + if self.database.pq.xact is x: + self.database._pq_complete() + + # get all the element.Tuple messages + chunk = [ + y for y in x.messages_received() if y.__class__ is tuple_type + ] + if len(chunk) == self.chunksize: + # there may be more, dispatch the request for the next chunk + self._xact = self._ins(self._command) + self.database._pq_push(self._xact, self) + else: + # it's done. + self._xact = None + self.close() + if not chunk: + # chunk is empty, it's done *right* now. + raise StopIteration + chunk = self._process_chunk(chunk) + return chunk + +## +# The cursor is streamed to the client on demand *inside* +# a single SQL transaction block. +class MultiXactInsideBlock(MultiXactStream): + _bind = MultiXactStream._pq_xp_bind + def _fetch(self): + ## + # Use the extended protocol's execute to fetch more. + return self._pq_xp_execute(self.chunksize) + \ + (element.SynchronizeMessage,) + +## +# The cursor is streamed to the client on demand *outside* of +# a single SQL transaction block. [DECLARE ... WITH HOLD] +class MultiXactOutsideBlock(MultiXactStream): + _bind = MultiXactStream._pq_xp_declare + + def _fetch(self): + ## + # Use the extended protocol's execute to fetch more *against* + # an SQL FETCH statement yielding the data in the proper format. + # + # MultiXactOutsideBlock uses DECLARE to create the cursor WITH HOLD. + # When this is done, the cursor is configured to use StringFormat with + # all columns. It's necessary to use FETCH to adjust the formatting. + return self._pq_xp_fetch(True, self.chunksize) + \ + (element.SynchronizeMessage,) + +## +# Cursor is used to manage scrollable cursors. 
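+#
+# Rough usage sketch (illustrative only; `db` is assumed to be an established
+# connection and the query is merely an example):
+#
+#   ps = db.prepare("SELECT i FROM generate_series(1, 100) AS g(i)")
+#   c = ps.declare()                  # scrollable cursor, DECLARE'd WITH HOLD
+#   first_ten = c.read(10)            # read forward
+#   previous = c.read(5, 'BACKWARD')  # read against the configured direction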
+class Cursor(Output, pg_api.Cursor): + _process_tuple = Output._process_tuple_chunk_Row + def _e_metas(self): + yield ('direction', 'FORWARD' if self.direction else 'BACKWORD') + yield ('type', 'Cursor') + + def clone(self): + return type(self)(self.statement, self.parameters, self.database, None) + + def __init__(self, statement, parameters, database, cursor_id): + self.database = database or statement.database + self.statement = statement + self.parameters = parameters + self.__dict__['direction'] = True + if self.statement is None: + self._e_factors = ('database', 'cursor_id') + Output.__init__(self, cursor_id or ID(self)) + + def get_direction(self): + return self.__dict__['direction'] + def set_direction(self, value): + self.__dict__['direction'] = direction_to_bool(value) + direction = property( + fget = get_direction, + fset = set_direction, + ) + del get_direction, set_direction + + def _which_way(self, direction): + if direction is not None: + direction = direction_to_bool(direction) + # -1 * -1 = 1, -1 * 1 = -1, 1 * 1 = 1 + return not ((not self.direction) ^ (not direction)) + else: + return self.direction + + def _init(self, + tupledesc = element.TupleDescriptor.type, + ): + """ + Based on the cursor parameters and the current transaction state, + select a cursor strategy for managing the response from the server. + """ + if self.statement is not None: + x = self._ins(self._pq_xp_declare() + (element.SynchronizeMessage,)) + self.database._pq_push(x, self) + self.database._pq_complete() + else: + x = self._ins(self._pq_xp_describe() + (element.SynchronizeMessage,)) + self.database._pq_push(x, self) + self.database._pq_complete() + for m in x.messages_received(): + if m.type == tupledesc: + typio = self.database.typio + self._output = m + self._output_attmap = typio.attribute_map(self._output) + self._row_constructor = typio.RowTypeFactory(self._output_attmap) + # tuple output + self._output_io = typio.resolve_descriptor( + self._output, 1 # (input, output)[1] + ) + self._output_formats = [ + element.StringFormat + if x is None + else element.BinaryFormat + for x in self._output_io + ] + self._output_io = tuple([ + x or typio.decode for x in self._output_io + ]) + + def __next__(self): + result = self._fetch(self.direction, 1) + if not result: + raise StopIteration + else: + return result[0] + + def read(self, quantity = None, direction = None): + if quantity == 0: + return [] + dir = self._which_way(direction) + return self._fetch(dir, quantity) + + def _fetch(self, direction, quantity): + x = self._ins( + self._pq_xp_fetch(direction, quantity) + \ + (element.SynchronizeMessage,) + ) + self.database._pq_push(x, self) + self.database._pq_complete() + return self._process_tuple(( + y for y in x.messages_received() if y.__class__ is tuple + )) + + def seek(self, offset, whence = 'ABSOLUTE'): + rwhence = self._seek_whence_map.get(whence, whence) + if rwhence is None or rwhence.upper() not in \ + self._seek_whence_map.values(): + raise TypeError( + "unknown whence parameter, %r" %(whence,) + ) + rwhence = rwhence.upper() + + if offset == 'ALL': + if rwhence not in ('BACKWARD', 'FORWARD'): + rwhence = 'BACKWARD' if self.direction is False else 'FORWARD' + else: + if offset < 0 and rwhence == 'BACKWARD': + offset = -offset + rwhence = 'FORWARD' + + if self.direction is False: + if offset == 'ALL' and rwhence != 'FORWARD': + rwhence = 'BACKWARD' + else: + if rwhence == 'RELATIVE': + offset = -offset + elif rwhence == 'ABSOLUTE': + rwhence = 'FROM_END' + else: + rwhence = 'ABSOLUTE' + + 
if rwhence in ('RELATIVE', 'BACKWARD', 'FORWARD'): + if offset == 'ALL': + cmd = self._pq_xp_move( + str(offset).encode('ascii'), str(rwhence).encode('ascii') + ) + else: + if offset < 0: + cmd = self._pq_xp_move( + str(-offset).encode('ascii'), b'BACKWARD' + ) + else: + cmd = self._pq_xp_move( + str(offset).encode('ascii'), str(rwhence).encode('ascii') + ) + elif rwhence == 'ABSOLUTE': + cmd = self._pq_xp_move(str(offset).encode('ascii'), b'ABSOLUTE') + else: + # move to last record, then consume it to put the position at + # the very end of the cursor. + cmd = self._pq_xp_move(b'', b'LAST') + \ + self._pq_xp_move(b'', b'NEXT') + \ + self._pq_xp_move(str(offset).encode('ascii'), b'BACKWARD') + + x = self._ins(cmd + (element.SynchronizeMessage,),) + self.database._pq_push(x, self) + self.database._pq_complete() + + count = None + complete = element.Complete.type + for cm in x.messages_received(): + if getattr(cm, 'type', None) == complete: + count = cm.extract_count() + break + + # XXX: Raise if count is None? + return count + +class SingleExecution(pg_api.Execution): + database = None + def __init__(self, database): + self._prepare = database.prepare + + def load_rows(self, query, *parameters): + return self._prepare(query).load_rows(*parameters) + + def load_chunks(self, query, *parameters): + return self._prepare(query).load_chunks(*parameters) + + def __call__(self, query, *parameters): + return self._prepare(query)(*parameters) + + def declare(self, query, *parameters): + return self._prepare(query).declare(*parameters) + + def rows(self, query, *parameters): + return self._prepare(query).rows(*parameters) + + def chunks(self, query, *parameters): + return self._prepare(query).chunks(*parameters) + + def column(self, query, *parameters): + return self._prepare(query).column(*parameters) + + def first(self, query, *parameters): + return self._prepare(query).first(*parameters) + +class Statement(pg_api.Statement): + string = None + database = None + statement_id = None + _input = None + _output = None + _output_io = None + _output_formats = None + _output_attmap = None + + def _e_metas(self): + yield (None, '[' + self.state + ']') + if hasattr(self._xact, 'error_message'): + # be very careful not to trigger an exception. + # even in the cases of effective protocol errors, + # it is important not to bomb out. + pos = self._xact.error_message.get(b'P') + if pos is not None and pos.isdigit(): + try: + pos = int(pos) + # get the statement source + q = str(self.string) + # normalize position.. + pos = len('\n'.join(q[:pos].splitlines())) + # normalize newlines + q = '\n'.join(q.splitlines()) + line_no = q.count('\n', 0, pos) + 1 + # replace tabs with spaces because there is no way to identify + # the tab size of the final display. (ie, marker will be wrong) + q = q.replace('\t', ' ') + # grab the relevant part of the query string. + # the full source will be printed elsewhere. 
+ # beginning of string or the newline before the position + bov = q.rfind('\n', 0, pos) + 1 + # end of string or the newline after the position + eov = q.find('\n', pos) + if eov == -1: + eov = len(q) + view = q[bov:eov] + # position relative to the beginning of the view + pos = pos-bov + # analyze lines prior to position + dlines = view.splitlines() + marker = ((pos-1) * ' ') + '^' + ( + ' [line %d, character %d] ' %(line_no, pos) + ) + # insert marker + dlines.append(marker) + yield ('LINE', os.linesep.join(dlines)) + except: + import traceback + yield ('LINE', traceback.format_exc(chain=False)) + spt = self.sql_parameter_types + if spt is not None: + yield ('sql_parameter_types', spt) + cn = self.column_names + ct = self.sql_column_types + if cn is not None: + if ct is not None: + yield ( + 'results', + '(' + ', '.join([ + '{!r} {!r}'.format(n, t) for n,t in zip(cn,ct) + ]) + ')' + ) + else: + yield ('sql_column_names', cn) + elif ct is not None: + yield ('sql_column_types', ct) + + def clone(self): + ps = self.__class__(self.database, None, self.string) + ps._init() + ps._fini() + return ps + + def __init__(self, + database, statement_id, string, + wref = weakref.ref + ): + self.database = database + self.string = string + self.statement_id = statement_id or ID(self) + self._xact = None + self.closed = None + self._pq_statement_id = database.typio._encode(self.statement_id)[0] + + if not statement_id: + # Register statement on a connection to close it automatically on db end + database.pq.register_statement(self, self._pq_statement_id) + + def __repr__(self): + return '<{mod}.{name}[{ci}] {state}>'.format( + mod = self.__class__.__module__, + name = self.__class__.__name__, + ci = self.database.connector._pq_iri, + state = self.state, + ) + + def _pq_parameters(self, parameters, proc = process_tuple): + return proc( + self._input_io, parameters, + self._raise_parameter_tuple_error + ) + + ## + # process_tuple failed(exception). The parameters could not be packed. + # This function is called with the given information in the context + # of the original exception(to allow chaining). + def _raise_parameter_tuple_error(self, cause, procs, tup, itemnum): + # Find the SQL type name. This should *not* hit the server. + typ = self.database.typio.sql_type_from_oid( + self.pg_parameter_types[itemnum] + ) or '' + + # Representation of the bad parameter. + bad_data = repr(tup[itemnum]) + if len(bad_data) > 80: + # Be sure not to fill screen with noise. + bad_data = bad_data[:75] + ' ...' + + em = element.ClientError(( + (b'S', 'ERROR'), + (b'C', '--PIO'), + (b'M', "could not pack parameter %s::%s for transfer" %( + ('$' + str(itemnum + 1)), typ, + ) + ), + (b'D', bad_data), + (b'H', "Try casting the parameter to 'text', then to the target type."), + (b'P', str(itemnum)) + )) + self.database.typio.raise_client_error(em, creator = self, cause = cause) + + ## + # Similar to the parameter variant. + def _raise_column_tuple_error(self, cause, procs, tup, itemnum): + # Find the SQL type name. This should *not* hit the server. + typ = self.database.typio.sql_type_from_oid( + self.pg_column_types[itemnum] + ) or '' + + # Representation of the bad column. + data = repr(tup[itemnum]) + if len(data) > 80: + # Be sure not to fill screen with noise. + data = data[:75] + ' ...' 
+ + em = element.ClientError(( + (b'S', 'ERROR'), + (b'C', '--CIO'), + (b'M', "could not unpack column %r, %s::%s, from wire data" %( + itemnum, self.column_names[itemnum], typ + ) + ), + (b'D', data), + (b'H', "Try casting the column to 'text'."), + (b'P', str(itemnum)), + )) + self.database.typio.raise_client_error(em, creator = self, cause = cause) + + @property + def state(self) -> str: + if self.closed: + if self._xact is not None: + if self.string is not None: + return 'parsing' + else: + return 'describing' + return 'closed' + return 'prepared' + + @property + def column_names(self): + if self.closed is None: + self._fini() + if self._output is not None: + return list(self.database.typio.decodes(self._output.keys())) + + @property + def parameter_types(self): + if self.closed is None: + self._fini() + if self._input is not None: + return [self.database.typio.type_from_oid(x) for x in self._input] + + @property + def column_types(self): + if self.closed is None: + self._fini() + if self._output is not None: + return [ + self.database.typio.type_from_oid(x[3]) for x in self._output + ] + + @property + def pg_parameter_types(self): + if self.closed is None: + self._fini() + return self._input + + @property + def pg_column_types(self): + if self.closed is None: + self._fini() + if self._output is not None: + return [x[3] for x in self._output] + + @property + def sql_column_types(self): + if self.closed is None: + self._fini() + if self._output is not None: + return [ + self.database.typio.sql_type_from_oid(x) + for x in self.pg_column_types + ] + + @property + def sql_parameter_types(self): + if self.closed is None: + self._fini() + if self._input is not None: + return [ + self.database.typio.sql_type_from_oid(x) + for x in self.pg_parameter_types + ] + + def close(self): + if self.closed is False: + self.database.pq.trash_statement(self._pq_statement_id) + self.closed = True + + def _init(self): + """ + Push initialization messages to the server, but don't wait for + the return as there may be things that can be done while waiting + for the return. Use the _fini() to complete. + """ + if self.string is not None: + q = self.database.typio._encode(str(self.string))[0] + cmd = [ + element.CloseStatement(self._pq_statement_id), + element.Parse(self._pq_statement_id, q, ()), + ] + else: + cmd = [] + cmd.extend( + ( + element.DescribeStatement(self._pq_statement_id), + element.SynchronizeMessage, + ) + ) + self._xact = xact.Instruction(cmd, asynchook = self.database._receive_async) + self.database._pq_push(self._xact, self) + + def _fini(self, strfmt = element.StringFormat, binfmt = element.BinaryFormat): + """ + Complete initialization that the _init() method started. + """ + # assume that the transaction has been primed. + if self._xact is None: + raise RuntimeError("_fini called prior to _init; invalid state") + if self._xact is self.database.pq.xact: + try: + self.database._pq_complete() + except Exception: + self.closed = True + raise + + (*head, argtypes, tupdesc, last) = self._xact.messages_received() + + typio = self.database.typio + if tupdesc is None or tupdesc is element.NoDataMessage: + # Not typed output. 
+ self._output = None + self._output_attmap = None + self._output_io = None + self._output_formats = None + self._row_constructor = None + else: + self._output = tupdesc + self._output_attmap = dict( + typio.attribute_map(tupdesc) + ) + self._row_constructor = self.database.typio.RowTypeFactory(self._output_attmap) + # tuple output + self._output_io = typio.resolve_descriptor(tupdesc, 1) + self._output_formats = [ + strfmt if x is None else binfmt + for x in self._output_io + ] + self._output_io = tuple([ + x or typio.decode for x in self._output_io + ]) + + self._input = argtypes + packs = [] + formats = [] + for x in argtypes: + pack = (typio.resolve(x) or (None,None))[0] + packs.append(pack or typio.encode) + formats.append( + strfmt if x is None else binfmt + ) + self._input_io = tuple(packs) + self._input_formats = formats + self.closed = False + self._xact = None + + def __call__(self, *parameters): + if self._input is not None: + if len(parameters) != len(self._input): + raise TypeError("statement requires %d parameters, given %d" %( + len(self._input), len(parameters) + )) + ## + # get em' all! + if self._output is None: + # might be a copy. + c = SingleXactCopy(self, parameters) + else: + c = SingleXactFetch(self, parameters) + c._process_chunk = c._process_tuple_chunk_Row + + # iff output is None, it's not a tuple returning query. + # however, if it's a copy, detect that fact by SingleXactCopy's + # immediate return after finding the copy begin message(no complete). + if self._output is None: + cmd = c.command() + if cmd is not None: + return (cmd, c.count()) + # Returns rows, accumulate in a list. + r = [] + for x in c: + r.extend(x) + return r + + def declare(self, *parameters): + if self.closed is None: + self._fini() + if self._input is not None: + if len(parameters) != len(self._input): + raise TypeError("statement requires %d parameters, given %d" %( + len(self._input), len(parameters) + )) + return Cursor(self, parameters, self.database, None) + + def rows(self, *parameters, **kw): + chunks = self.chunks(*parameters, **kw) + if chunks._output_io: + chunks._process_chunk = chunks._process_tuple_chunk_Row + return chain.from_iterable(chunks) + __iter__ = rows + + def chunks(self, *parameters): + if self.closed is None: + self._fini() + if self._input is not None: + if len(parameters) != len(self._input): + raise TypeError("statement requires %d parameters, given %d" %( + len(self._input), len(parameters) + )) + + if self._output is None: + # It's *probably* a COPY. + return SingleXactCopy(self, parameters) + if self.database.pq.state == b'I': + # Currently, *not* in a Transaction block, so + # DECLARE the statement WITH HOLD in order to allow + # access across transactions. + if self.string is not None: + return MultiXactOutsideBlock(self, parameters, None) + else: + ## + # Statement source unknown, so it can't be DECLARE'd. + # This happens when statement_from_id is used. + return SingleXactFetch(self, parameters) + else: + # Likely, the best possible case. It gets to use Execute messages. + return MultiXactInsideBlock(self, parameters, None) + + def column(self, *parameters, **kw): + chunks = self.chunks(*parameters, **kw) + chunks._process_chunk = chunks._process_tuple_chunk_Column + return chain.from_iterable(chunks) + + def first(self, *parameters): + if self.closed is None: + # Not fully initialized; assume interrupted. + self._fini() + if self._input is not None: + # Use a regular TypeError. 
+ if len(parameters) != len(self._input): + raise TypeError("statement requires %d parameters, given %d" %( + len(self._input), len(parameters) + )) + + # Parameters? Build em'. + db = self.database + + if self._input_io: + params = process_tuple( + self._input_io, parameters, + self._raise_parameter_tuple_error + ) + else: + params = () + + # Run the statement + x = xact.Instruction(( + element.Bind( + b'', + self._pq_statement_id, + self._input_formats, + params, + self._output_formats or (), + ), + # Get all + element.Execute(b'', 0xFFFFFFFF), + element.ClosePortal(b''), + element.SynchronizeMessage + ), + asynchook = db._receive_async + ) + # Push and complete protocol transaction. + db._pq_push(x, self) + db._pq_complete() + + if self._output_io: + ## + # It returned rows, look for the first tuple. + tuple_type = element.Tuple.type + for xt in x.messages_received(): + if xt.__class__ is tuple: + break + else: + return None + + if len(self._output_io) > 1: + # Multiple columns, return a Row. + return self._row_constructor( + process_tuple( + self._output_io, xt, + self._raise_column_tuple_error + ) + ) + else: + # Single column output. + if xt[0] is None: + return None + io = self._output_io[0] or self.database.typio.decode + return io(xt[0]) + else: + ## + # It doesn't return rows, so return a count. + ## + # This loop searches through the received messages + # for the Complete message which contains the count. + complete = element.Complete.type + for cm in x.messages_received(): + # Use getattr because COPY doesn't produce + # element.Message instances. + if getattr(cm, 'type', None) == complete: + break + else: + # Probably a Null command. + return None + + count = cm.extract_count() + if count is None: + command = cm.extract_command() + if command is not None: + return command.decode('ascii') + return count + + def _load_copy_chunks(self, chunks, *parameters): + """ + Given an chunks of COPY lines, execute the COPY ... FROM STDIN + statement and send the copy lines produced by the iterable to + the remote end. + """ + x = xact.Instruction(( + element.Bind( + b'', + self._pq_statement_id, + (), (), (), + ), + element.Execute(b'', 1), + element.SynchronizeMessage, + ), + asynchook = self.database._receive_async + ) + self.database._pq_push(x, self) + + # localize + step = self.database._pq_step + + # Get the COPY started. + while x.state is not xact.Complete: + step() + if hasattr(x, 'CopyFailSequence') and x.messages is x.CopyFailSequence: + # The protocol transaction has noticed that its a COPY. + break + else: + # Oh, it's not a COPY at all. + x.fatal = x.fatal or False + x.error_message = element.ClientError(( + (b'S', 'ERROR'), + # OperationError + (b'C', '--OPE'), + (b'M', "_load_copy_chunks() used on a non-COPY FROM STDIN query"), + )) + self.database.typio.raise_client_error(x.error_message, creator = self) + + for chunk in chunks: + x.messages = list(chunk) + while x.messages is not x.CopyFailSequence: + # Continue stepping until the transaction + # sets the CopyFailSequence again. That's + # the signal that the transaction has sent + # all the previously set messages. 
+ step() + x.messages = x.CopyDoneSequence + self.database._pq_complete() + self.database.pq.synchronize() + + def _load_tuple_chunks(self, chunks, tuple=tuple): + pte = self._raise_parameter_tuple_error + last = (element.SynchronizeMessage,) + + Bind = element.Bind + Instruction = xact.Instruction + Execute = element.Execute + + try: + for chunk in chunks: + bindings = [ + ( + Bind( + b'', + self._pq_statement_id, + self._input_formats, + process_tuple( + self._input_io, tuple(t), pte + ), + (), + ), + Execute(b'', 1), + ) + for t in chunk + ] + bindings.append(last) + self.database._pq_push( + Instruction( + chain.from_iterable(bindings), + asynchook = self.database._receive_async + ), + self + ) + self.database._pq_complete() + except: + ## + # In cases where row packing errors or occur, + # synchronize, finishing any pending transaction, + # and raise the error. + ## + # If the data sent to the remote end is invalid, + # _complete will raise the exception and the current + # exception being marked as the cause, so there should + # be no [exception] information loss. + ## + self.database.pq.synchronize() + raise + + def load_chunks(self, chunks, *parameters): + """ + Execute the query for each row-parameter set in `iterable`. + + In cases of ``COPY ... FROM STDIN``, iterable must be an iterable of + sequences of `bytes`. + """ + if self.closed is None: + self._fini() + if not self._input or parameters: + return self._load_copy_chunks(chunks) + else: + return self._load_tuple_chunks(chunks) + + def load_rows(self, rows, chunksize = 256): + return self.load_chunks(chunk(rows, chunksize)) +PreparedStatement = Statement + +class StoredProcedure(pg_api.StoredProcedure): + _e_factors = ('database', 'procedure_id') + procedure_id = None + + def _e_metas(self): + yield ('oid', self.oid) + + def __repr__(self): + return '<%s:%s>' %( + self.procedure_id, self.statement.string + ) + + def __call__(self, *args, **kw): + if kw: + input = [] + argiter = iter(args) + try: + word_idx = [(kw[k], self._input_attmap[k]) for k in kw] + except KeyError as k: + raise TypeError("%s got unexpected keyword argument %r" %( + self.name, k.message + ) + ) + word_idx.sort(key = get1) + current_word = word_idx.pop(0) + for x in range(argc): + if x == current_word[1]: + input.append(current_word[0]) + current_word = word_idx.pop(0) + else: + input.append(argiter.next()) + else: + input = args + + if self.srf is True: + if self.composite is True: + return self.statement.rows(*input) + else: + # A generator expression is very appropriate here + # as SRFs returning large number of rows would require + # substantial amounts of memory. + return map(get0, self.statement.rows(*input)) + else: + if self.composite is True: + return self.statement(*input)[0] + else: + return self.statement(*input)[0][0] + + def __init__(self, ident, database, description = ()): + # Lookup pg_proc on database. + if isinstance(ident, int): + proctup = database.sys.lookup_procedure_oid(int(ident)) + else: + proctup = database.sys.lookup_procedure_rp(str(ident)) + if proctup is None: + raise LookupError("no function with identifier %s" %(str(ident),)) + + self.procedure_id = ident + self.oid = proctup[0] + self.name = proctup["proname"] + + self._input_attmap = {} + argnames = proctup.get('proargnames') or () + for x in range(len(argnames)): + an = argnames[x] + if an is not None: + self._input_attmap[an] = x + + proargs = proctup['proargtypes'] + for x in proargs: + # get metadata filled out. 
+ database.typio.resolve(x) + + self.statement = database.prepare( + "SELECT * FROM %s(%s) AS func%s" %( + proctup['_proid'], + # ($1::type, $2::type, ... $n::type) + ', '.join([ + '$%d::%s' %(x + 1, database.typio.sql_type_from_oid(proargs[x])) + for x in range(0, len(proargs)) + ]), + # Description for anonymous record returns + (description and \ + '(' + ','.join(description) + ')' or '') + ) + ) + self.srf = bool(proctup.get("proretset")) + self.composite = proctup["composite"] + +class SettingsCM(object): + def __init__(self, database, settings_to_set): + self.database = database + self.settings_to_set = settings_to_set + + def __enter__(self): + if hasattr(self, 'stored_settings'): + raise RuntimeError("cannot re-use setting CMs") + self.stored_settings = self.database.settings.getset( + self.settings_to_set.keys() + ) + self.database.settings.update(self.settings_to_set) + + def __exit__(self, typ, val, tb): + self.database.settings.update(self.stored_settings) + +class Settings(pg_api.Settings): + _e_factors = ('database',) + + def __init__(self, database): + self.database = database + self.cache = {} + + def _e_metas(self): + yield (None, str(len(self.cache))) + + def _clear_cache(self): + self.cache.clear() + + def __getitem__(self, i): + v = self.cache.get(i) + if v is None: + r = self.database.sys.setting_get(i) + + if r: + v = r[0][0] + else: + raise KeyError(i) + return v + + def __setitem__(self, i, v): + cv = self.cache.get(i) + if cv == v: + return + setas = self.database.sys.setting_set(i, v) + self.cache[i] = setas + + def __delitem__(self, k): + self.database.execute( + 'RESET "' + k.replace('"', '""') + '"' + ) + self.cache.pop(k, None) + + def __len__(self): + return self.database.sys.setting_len() + + def __call__(self, **settings): + return SettingsCM(self.database, settings) + + def path(): + def fget(self): + return pg_str.split_ident(self["search_path"]) + def fset(self, value): + self['search_path'] = ','.join([ + '"%s"' %(x.replace('"', '""'),) for x in value + ]) + def fdel(self): + if self.database.connector.path is not None: + self.path = self.database.connector.path + else: + self.database.execute("RESET search_path") + doc = 'structured search_path interface' + return locals() + path = property(**path()) + + def get(self, k, alt = None): + if k in self.cache: + return self.cache[k] + + db = self.database + r = self.database.sys.setting_get(k) + if r: + v = r[0][0] + self.cache[k] = v + else: + v = alt + return v + + def getset(self, keys): + setmap = {} + rkeys = [] + for k in keys: + v = self.cache.get(k) + if v is not None: + setmap[k] = v + else: + rkeys.append(k) + + if rkeys: + r = self.database.sys.setting_mget(rkeys) + self.cache.update(r) + setmap.update(r) + rem = set(rkeys) - set([x['name'] for x in r]) + if rem: + raise KeyError(rem) + return setmap + + def keys(self): + return map(get0, self.database.sys.setting_keys()) + __iter__ = keys + + def values(self): + return map(get0, self.database.sys.setting_values()) + + def items(self): + return self.database.sys.setting_items() + + def update(self, d): + kvl = [list(x) for x in dict(d).items()] + self.cache.update(self.database.sys.setting_update(kvl)) + + def _notify(self, msg): + subs = getattr(self, '_subscriptions', {}) + d = self.database.typio._decode + key = d(msg.name)[0] + val = d(msg.value)[0] + for x in subs.get(key, ()): + x(self.database, key, val) + if None in subs: + for x in subs[None]: + x(self.database, key, val) + self.cache[key] = val + + def subscribe(self, key, callback): + 
""" + Subscribe to changes of the setting using the callback. When the setting + is changed, the callback will be invoked with the connection, the key, + and the new value. If the old value is locally cached, its value will + still be available for inspection, but there is no guarantee. + If `None` is passed as the key, the callback will be called whenever any + setting is remotely changed. + + >>> def watch(connection, key, newval): + ... + >>> db.settings.subscribe('TimeZone', watch) + """ + subs = self._subscriptions = getattr(self, '_subscriptions', {}) + callbacks = subs.setdefault(key, []) + if callback not in callbacks: + callbacks.append(callback) + + def unsubscribe(self, key, callback): + """ + Stop listening for changes to a setting. The setting name(`key`), and + the callback used to subscribe must be given again for successful + termination of the subscription. + + >>> db.settings.unsubscribe('TimeZone', watch) + """ + subs = getattr(self, '_subscriptions', {}) + callbacks = subs.get(key, ()) + if callback in callbacks: + callbacks.remove(callback) + +class Transaction(pg_api.Transaction): + database = None + + mode = None + isolation = None + + _e_factors = ('database', 'isolation', 'mode') + + def _e_metas(self): + yield (None, self.state) + + def __init__(self, database, isolation = None, mode = None): + self.database = database + self.isolation = isolation + self.mode = mode + self.state = 'initialized' + self.type = None + + def __enter__(self): + self.start() + return self + + def __exit__(self, typ, value, tb): + if typ is None: + # No exception, but in a failed transaction? + if self.database.pq.state == b'E': + if not self.database.closed: + self.rollback() + # pg_exc.InFailedTransactionError + em = element.ClientError(( + (b'S', 'ERROR'), + (b'C', '25P02'), + (b'M', 'invalid transaction block exit detected'), + (b'H', "Database was in an error-state, but no exception was raised.") + )) + self.database.typio.raise_client_error(em, creator = self) + else: + # No exception, and no error state. Everything is good. + try: + self.commit() + # If an error occurs, clean up the transaction state + # and raise as needed. + except pg_exc.ActiveTransactionError as err: + if not self.database.closed: + # adjust the state so rollback will do the right thing and abort. + self.state = 'open' + self.rollback() + raise + elif issubclass(typ, Exception): + # There's an exception, so only rollback if the connection + # exists. If the rollback() was called here, it would just + # contribute noise to the error. 
+ if not self.database.closed: + self.rollback() + + @staticmethod + def _start_xact_string(isolation = None, mode = None): + q = 'START TRANSACTION' + if isolation is not None: + if ';' in isolation: + raise ValueError("invalid transaction isolation " + repr(mode)) + q += ' ISOLATION LEVEL ' + isolation + if mode is not None: + if ';' in mode: + raise ValueError("invalid transaction mode " + repr(isolation)) + q += ' ' + mode + return q + ';' + + @staticmethod + def _savepoint_xact_string(id): + return 'SAVEPOINT "xact(' + id.replace('"', '""') + ')";' + + def start(self): + if self.state == 'open': + return + if self.state != 'initialized': + em = element.ClientError(( + (b'S', 'ERROR'), + (b'C', '--OPE'), + (b'M', "transactions cannot be restarted"), + (b'H', 'Create a new transaction object instead of re-using an old one.') + )) + self.database.typio.raise_client_error(em, creator = self) + + if self.database.pq.state == b'I': + self.type = 'block' + q = self._start_xact_string( + isolation = self.isolation, + mode = self.mode, + ) + else: + self.type = 'savepoint' + if (self.isolation, self.mode) != (None,None): + em = element.ClientError(( + (b'S', 'ERROR'), + (b'C', '--OPE'), + (b'M', "configured transaction used inside a transaction block"), + (b'H', 'A transaction block was already started.'), + )) + self.database.typio.raise_client_error(em, creator = self) + q = self._savepoint_xact_string(hex(id(self))) + self.database.execute(q) + self.state = 'open' + begin = start + + @staticmethod + def _release_string(id): + # Release ""; + return 'RELEASE "xact(' + id.replace('"', '""') + ')";' + + def commit(self): + if self.state == 'committed': + return + if self.state != 'open': + em = element.ClientError(( + (b'S', 'ERROR'), + (b'C', '--OPE'), + (b'M', "commit attempted on transaction with unexpected state, " + repr(self.state)), + )) + self.database.typio.raise_client_error(em, creator = self) + + if self.type == 'block': + q = 'COMMIT' + else: + q = self._release_string(hex(id(self))) + self.database.execute(q) + self.state = 'committed' + + @staticmethod + def _rollback_to_string(id, fmt = 'ROLLBACK TO "xact({0})"; RELEASE "xact({0})";'.format): + return fmt(id.replace('"', '""')) + + def rollback(self): + if self.state == 'aborted': + return + if self.state not in ('prepared', 'open'): + em = element.ClientError(( + (b'S', 'ERROR'), + (b'C', '--OPE'), + (b'M', "ABORT attempted on transaction with unexpected state, " + repr(self.state)), + )) + self.database.typio.raise_client_error(em, creator = self) + + if self.type == 'block': + q = 'ABORT;' + elif self.type == 'savepoint': + q = self._rollback_to_string(hex(id(self))) + else: + raise RuntimeError("unknown transaction type " + repr(self.type)) + self.database.execute(q) + self.state = 'aborted' + abort = rollback + +class Connection(pg_api.Connection): + connector = None + + type = None + version_info = None + version = None + + security = None + backend_id = None + client_address = None + client_port = None + + # Replaced with instances on connection instantiation. + settings = Settings + + def _e_metas(self): + yield (None, '[' + self.state + ']') + if self.client_address is not None: + yield ('client_address', self.client_address) + if self.client_port is not None: + yield ('client_port', self.client_port) + if self.version is not None: + yield ('version', self.version) + att = getattr(self, 'failures', None) + if att: + count = 0 + for x in att: + # Format each failure without their traceback. 
+ errstr = ''.join(format_exception(type(x.error), x.error, None)) + factinfo = str(x.socket_factory) + if hasattr(x, 'ssl_negotiation'): + if x.ssl_negotiation is True: + factinfo = 'SSL ' + factinfo + else: + factinfo = 'NOSSL ' + factinfo + yield ( + 'failures[' + str(count) + ']', + factinfo + os.linesep + errstr + ) + count += 1 + + def __repr__(self): + return '<%s.%s[%s] %s>' %( + type(self).__module__, + type(self).__name__, + self.connector._pq_iri, + self.closed and 'closed' or '%s' %(self.pq.state,) + ) + + def __exit__(self, type, value, tb): + # Don't bother closing unless it's a normal exception. + if type is None or issubclass(type, Exception): + self.close() + + def interrupt(self, timeout = None): + self.pq.interrupt(timeout = timeout) + + def execute(self, query : str) -> None: + q = xact.Instruction(( + element.Query(self.typio._encode(query)[0]), + ), + asynchook = self._receive_async + ) + self._pq_push(q, self) + self._pq_complete() + + def do(self, language : str, source : str, + qlit = pg_str.quote_literal, + qid = pg_str.quote_ident, + ) -> None: + sql = "DO " + qlit(source) + " LANGUAGE " + qid(language) + ";" + self.execute(sql) + + # Alias transaction as xact. xact is the original term, but support + # the full word for identifier consistency with asyncpg. + def transaction(self, isolation = None, mode = None) -> Transaction: + x = Transaction(self, isolation = isolation, mode = mode) + return x + xact=transaction + + def prepare(self, + sql_statement_string : str, + statement_id = None, + Class = Statement + ) -> Statement: + ps = Class(self, statement_id, sql_statement_string) + ps._init() + + # Complete protocol transaction to maintain point of origin in error cases. + ps._fini() + return ps + + @property + def query(self, Class = SingleExecution): + return Class(self) + + def statement_from_id(self, statement_id : str) -> Statement: + ps = Statement(self, statement_id, None) + ps._init() + ps._fini() + return ps + + def proc(self, proc_id : (str, int)) -> StoredProcedure: + sp = StoredProcedure(proc_id, self) + return sp + + def cursor_from_id(self, cursor_id : str) -> Cursor: + c = Cursor(None, None, self, cursor_id) + c._init() + return c + + @property + def closed(self) -> bool: + if getattr(self, 'pq', None) is None: + return True + if hasattr(self.pq, 'socket') and self.pq.xact is not None: + return self.pq.xact.fatal is True + return False + + def close(self, getattr = getattr): + # Write out the disconnect message if the socket is around. + # If the connection is known to be lost, don't bother. It will + # generate an extra exception. + if getattr(self, 'pq', None) is None or getattr(self.pq, 'socket', None) is None: + # No action to take. + return + + x = getattr(self.pq, 'xact', None) + if x is not None and x.fatal is not True: + # finish the existing pq transaction iff it's not Closing. + self.pq.complete() + + if self.pq.xact is None: + # It completed the existing transaction. + self.pq.push(xact.Closing()) + self.pq.complete() + if self.pq.socket: + self.pq.complete() + + # Close the socket if there is one. 
+ if self.pq.socket: + self.pq.socket.close() + self.pq.socket = None + + @property + def state(self) -> str: + if not hasattr(self, 'pq'): + return 'initialized' + if hasattr(self, 'failures'): + return 'failed' + if self.closed: + return 'closed' + if isinstance(self.pq.xact, xact.Negotiation): + return 'negotiating' + if self.pq.xact is None: + if self.pq.state == b'E': + return 'failed block' + return 'idle' + (' in block' if self.pq.state != b'I' else '') + else: + return 'busy' + + def reset(self): + """ + restore original settings, reset the transaction, drop temporary + objects. + """ + self.execute("ABORT; RESET ALL;") + + def __enter__(self): + self.connect() + return self + + def connect(self): + """ + Establish the connection to the server. + """ + if self.closed is False: + # already connected? just return. + return + + if hasattr(self, 'pq'): + # It's closed, *but* there's a PQ connection.. + x = self.pq.xact + self.typio.raise_error(x.error_message, cause = getattr(x, 'exception', None), creator = self) + + # It's closed. + try: + self._establish() + except Exception: + # Close it up on failure. + self.close() + raise + + def _establish(self): + # guts of connect() + self.pq = None + # if any exception occurs past this point, the connection + # object will not be usable. + timeout = self.connector.connect_timeout + sslmode = self.connector.sslmode or 'prefer' + failures = [] + exc = None + try: + # get the list of sockets to try + socket_factories = self.connector.socket_factory_sequence() + except Exception as e: + socket_factories = () + exc = e + + # When ssl is None: SSL negotiation will not occur. + # When ssl is True: SSL negotiation will occur *and* it must succeed. + # When ssl is False: SSL negotiation will occur but it may fail(NOSSL). + if sslmode == 'allow': + # without ssl, then with. :) + socket_factories = interlace( + zip(repeat(None, len(socket_factories)), socket_factories), + zip(repeat(True, len(socket_factories)), socket_factories) + ) + elif sslmode == 'prefer': + # with ssl, then without. [maybe] :) + socket_factories = interlace( + zip(repeat(False, len(socket_factories)), socket_factories), + zip(repeat(None, len(socket_factories)), socket_factories) + ) + # prefer is special, because it *may* be possible to + # skip the subsequent "without" in situations where SSL is off. + elif sslmode == 'require': + socket_factories = zip(repeat(True, len(socket_factories)), socket_factories) + elif sslmode == 'disable': + # None = Do Not Attempt SSL negotiation. + socket_factories = zip(repeat(None, len(socket_factories)), socket_factories) + else: + raise ValueError("invalid sslmode: " + repr(sslmode)) + + # can_skip is used when 'prefer' or 'allow' is the sslmode. + # if the ssl negotiation returns 'N' (nossl), then + # ssl "failed", but the socket is still usable for nossl. + # in these cases, can_skip is set to True so that the + # subsequent non-ssl attempt is skipped if it failed with the 'N' response. + can_skip = False + startup = self.connector._startup_parameters + password = self.connector._password + Connection3 = client.Connection + for (ssl, sf) in socket_factories: + if can_skip is True: + # the last attempt failed and knows this attempt will fail too. + can_skip = False + continue + pq = Connection3(sf, startup, password = password,) + if hasattr(self, 'tracer'): + pq.tracer = self.tracer + + # Grab the negotiation transaction before + # connecting as it will be needed later if successful. 
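+			# Illustrative note: with two factories and sslmode='prefer', the
+			# (ssl, sf) pairs seen by this loop are roughly
+			#   (False, sf1), (None, sf1), (False, sf2), (None, sf2)
+			# assuming interlace() alternates its inputs; that is, each factory
+			# is tried with SSL first and then, if necessary, without it.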
+ neg = pq.xact + pq.connect(ssl = ssl, timeout = timeout) + + didssl = getattr(pq, 'ssl_negotiation', -1) + + # It successfully connected if pq.xact is None; + # The startup/negotiation xact completed. + if pq.xact is None: + self.pq = pq + if hasattr(self.pq.socket, 'fileno'): + self.fileno = self.pq.socket.fileno + self.security = 'ssl' if didssl is True else None + showoption_type = element.ShowOption.type + for x in neg.asyncs: + if x.type == showoption_type: + self._receive_async(x) + # success! + break + elif pq.socket is not None: + # In this case, an application/protocol error occurred. + # Close out the sockets ourselves. + pq.socket.close() + + # Identify whether or not we can skip the attempt. + # Whether or not we can skip depends entirely on the SSL parameter. + if sslmode == 'prefer' and ssl is False and didssl is False: + # In this case, the server doesn't support SSL or it's + # turned off. Therefore, the "without_ssl" attempt need + # *not* be ran because it has already been noted to be + # a failure. + can_skip = True + elif hasattr(pq.xact, 'exception'): + # If a Python exception occurred, chances are that it is + # going to fail again iff it is going to hit the same host. + if sslmode == 'prefer' and ssl is False: + # when 'prefer', the first attempt + # is marked with ssl is "False" + can_skip = True + elif sslmode == 'allow' and ssl is None: + # when 'allow', the first attempt + # is marked with dossl is "None" + can_skip = True + + try: + self.typio.raise_error(pq.xact.error_message) + except Exception as error: + pq.error = error + # Otherwise, infinite recursion in the element traceback. + error.creator = None + # The tracebacks of the specific failures aren't particularly useful.. + error.__traceback__ = None + if getattr(pq.xact, 'exception', None) is not None: + pq.error.__cause__ = pq.xact.exception + + failures.append(pq) + else: + # No servers available. (see the break-statement in the for-loop) + self.failures = failures or () + # it's over. + self.typio.raise_client_error(could_not_connect, creator = self, cause = exc) + ## + # connected, now initialize connection information. + self.backend_id = self.pq.backend_id + + sv = self.settings.cache.get("server_version", "0.0") + self.version_info = pg_version.normalize(pg_version.split(sv)) + # manual binding + self.sys = pg_lib.Binding(self, pg_lib.sys) + + vi = self.version_info[:2] + if vi <= (8,1): + sd = self.sys.startup_data_only_version() + elif vi >= (9,2): + sd = self.sys.startup_data_92() + else: + sd = self.sys.startup_data() + # connection info + self.version, self.backend_start, \ + self.client_address, self.client_port = sd + + # First word from the version string. + self.type = self.version.split()[0] + + ## + # Set standard_conforming_strings + scstr = self.settings.get('standard_conforming_strings') + if scstr is None or vi == (8,1): + # There used to be a warning emitted here. + # It was noisy, and had little added value + # over a nice WARNING at the top of the driver documentation. + pass + elif scstr.lower() not in ('on','true','yes'): + self.settings['standard_conforming_strings'] = 'on' + + super().connect() + + def _pq_push(self, xact, controller = None): + x = self.pq.xact + if x is not None: + self.pq.complete() + if x.fatal is not None: + self.typio.raise_error(x.error_message) + if controller is not None: + self._controller = controller + self.pq.push(xact) + + # Complete the current protocol transaction. 
+ def _pq_complete(self): + pq = self.pq + x = pq.xact + if x is not None: + # There is a running transaction, finish it. + pq.complete() + # Raise an error *iff* one occurred. + if x.fatal is not None: + self.typio.raise_error(x.error_message, cause = getattr(x, 'exception', None)) + del self._controller + + # Process the next message. + def _pq_step(self, complete_state = globals()['xact'].Complete): + pq = self.pq + x = pq.xact + if x is not None: + pq.step() + # If the protocol transaction was completed by + # the last step, raise the error *iff* one occurred. + if x.state is complete_state: + if x.fatal is not None: + self.typio.raise_error(x.error_message, cause = getattr(x, 'exception', None)) + del self._controller + + def _receive_async(self, + msg, controller = None, + showoption = element.ShowOption.type, + notice = element.Notice.type, + notify = element.Notify.type, + ): + c = controller or getattr(self, '_controller', self) + typ = msg.type + if typ == showoption: + if msg.name == b'client_encoding': + self.typio.set_encoding(msg.value.decode('ascii')) + self.settings._notify(msg) + elif typ == notice: + m = self.typio.emit_message(msg, creator = c) + elif typ == notify: + self._notifies.append(msg) + else: + self.typio.emit_client_message( + element.ClientNotice(( + (b'C', '-1000'), + (b'S', 'WARNING'), + (b'M', 'cannot process unrecognized asynchronous message'), + (b'D', repr(msg)), + )), + creator = c + ) + + def clone(self, *args, **kw): + c = self.__class__(self.connector, *args, **kw) + c.connect() + return c + + def notify(self, *channels, **channel_and_payload): + notifies = "" + if channels: + notifies += ';'.join(( + 'NOTIFY "' + x.replace('"', '""') + '"' # str() case + if x.__class__ is not tuple else ( + # tuple() case + 'NOTIFY "' + x[0].replace('"', '""') + """",'""" + \ + x[1].replace("'", "''") + "'" + ) + for x in channels + )) + notifies += ';' + if channel_and_payload: + notifies += ';'.join(( + 'NOTIFY "' + channel.replace('"', '""') + """",'""" + \ + payload.replace("'", "''") + "'" + for channel, payload in channel_and_payload.items() + )) + notifies += ';' + return self.execute(notifies) + + def listening_channels(self): + if self.version_info[:2] > (8,4): + return self.sys.listening_channels() + else: + return self.sys.listening_relations() + + def listen(self, *channels, len = len): + qstr = '' + for x in channels: + # XXX: hardcoded identifier length? + if len(x) > 63: + raise ValueError("channel name too long: " + x) + qstr += '; LISTEN ' + x.replace('"', '""') + return self.execute(qstr) + + def unlisten(self, *channels, len = len): + qstr = '' + for x in channels: + # XXX: hardcoded identifier length? + if len(x) > 63: + raise ValueError("channel name too long: " + x) + qstr += '; UNLISTEN ' + x.replace('"', '""') + return self.execute(qstr) + + def iternotifies(self, timeout = None): + nm = NotificationManager(self, timeout = timeout) + for x in nm: + if x is None: + yield None + else: + for y in x[1]: + yield y + + def __init__(self, connector, *args, **kw): + """ + Create a connection based on the given connector. + """ + self.connector = connector + # raw notify messages + self._notifies = [] + self.fileno = -1 + self.typio = self.connector.driver.typio(self) + self.typio.set_encoding('ascii') + self.settings = Settings(self) +# class Connection + +class Connector(pg_api.Connector): + """ + All arguments to Connector are keywords. At the very least, user, + and socket, may be provided. 
If socket, unix, or process is not + provided, host and port must be. + """ + @property + def _pq_iri(self): + return pg_iri.serialize( + { + k : v for k,v in self.__dict__.items() + if v is not None and not k.startswith('_') and k not in ( + 'driver', 'category' + ) + }, + obscure_password = True + ) + + def _e_metas(self): + yield (None, '[' + self.__class__.__name__ + '] ' + self._pq_iri) + + def __repr__(self): + keywords = (',' + os.linesep + ' ').join([ + '%s = %r' %(k, getattr(self, k, None)) for k in self.__dict__ + if not k.startswith('_') and getattr(self, k, None) is not None + ]) + return '{mod}.{name}({keywords})'.format( + mod = type(self).__module__, + name = type(self).__name__, + keywords = os.linesep + ' ' + keywords if keywords else '' + ) + + @abstractmethod + def socket_factory_sequence(self): + """ + Generate a list of callables that will be used to attempt to make the + connection to the server. It is assumed that each factory will produce + an object with a socket interface that is ready for reading and writing + data. + + The callables in the sequence must take a timeout parameter. + """ + + def __init__(self, + connect_timeout : int = None, + server_encoding = None, + sslmode : ('allow', 'prefer', 'require', 'disable') = None, + sslcrtfile = None, + sslkeyfile = None, + sslrootcrtfile = None, + sslrootcrlfile = None, + driver = None, + **kw + ): + super().__init__(**kw) + self.driver = driver + + self.server_encoding = server_encoding + self.connect_timeout = connect_timeout + self.sslmode = sslmode + self.sslkeyfile = sslkeyfile + self.sslcrtfile = sslcrtfile + self.sslrootcrtfile = sslrootcrtfile + self.sslrootcrlfile = sslrootcrlfile + + if self.sslrootcrlfile is not None: + pg_exc.IgnoredClientParameterWarning( + "certificate revocation lists are *not* checked", + creator = self, + ).emit() + + # Startup message parameters. + tnkw = { + 'client_min_messages' : 'WARNING', + } + if self.settings: + s = dict(self.settings) + if 'search_path' in self.settings: + sp = s.get('search_path') + if sp is None: + self.settings.pop('search_path') + elif not isinstance(sp, str): + s['search_path'] = ','.join( + pg_str.quote_ident(x) for x in sp + ) + tnkw.update(s) + + tnkw['user'] = self.user + if self.database is not None: + tnkw['database'] = self.database + + se = self.server_encoding or 'utf-8' + ## + # Attempt to accommodate for literal treatment of startup data. + ## + self._startup_parameters = tuple([ + # All keys go in utf-8. However, ascii would probably be good enough. + ( + k.encode('utf-8'), + # If it's a str(), encode in the hinted server_encoding. + # Otherwise, convert the object(int, float, bool, etc) into a string + # and treat it as utf-8. + v.encode(se) if type(v) is str else str(v).encode('utf-8') + ) + for k, v in tnkw.items() + ]) + self._password = (self.password or '').encode(se) + self._socket_secure = { + 'keyfile' : self.sslkeyfile, + 'certfile' : self.sslcrtfile, + 'ca_certs' : self.sslrootcrtfile, + } +# class Connector + +class SocketConnector(Connector): + """ + Abstract connector for using `socket` and `ssl`. + """ + @abstractmethod + def socket_factory_sequence(self): + """ + Return a sequence of `SocketFactory`s for a connection to use to connect + to the target host. 
+ """ + + def create_socket_factory(self, **params): + return SocketFactory(**params) + +class IPConnector(SocketConnector): + def socket_factory_sequence(self): + return self._socketcreators + + def socket_factory_params(self, host, port, ipv, **kw): + if ipv != self.ipv: + raise TypeError("'ipv' keyword must be '%d'" % self.ipv) + if host is None: + raise TypeError("'host' is a required keyword and cannot be 'None'") + if port is None: + raise TypeError("'port' is a required keyword and cannot be 'None'") + + return {'socket_create': (self.address_family, socket.SOCK_STREAM), + 'socket_connect': (host, int(port))} + + def __init__(self, host, port, ipv, **kw): + params = self.socket_factory_params(host, port, ipv, **kw) + self.host, self.port = params['socket_connect'] + # constant socket connector + self._socketcreator = self.create_socket_factory(**params) + self._socketcreators = (self._socketcreator,) + super().__init__(**kw) + +class IP4(IPConnector): + """ + Connector for establishing IPv4 connections. + """ + ipv = 4 + address_family = socket.AF_INET + + def __init__(self, + host : str = None, + port : int = None, + ipv = 4, + **kw + ): + super().__init__(host, port, ipv, **kw) + +class IP6(IPConnector): + """ + Connector for establishing IPv6 connections. + """ + ipv = 6 + address_family = socket.AF_INET6 + + def __init__(self, + host : str = None, + port : int = None, + ipv = 6, + **kw + ): + super().__init__(host, port, ipv, **kw) + +class Unix(SocketConnector): + """ + Connector for establishing unix domain socket connections. + """ + def socket_factory_sequence(self): + return self._socketcreators + + def socket_factory_params(self, unix): + if unix is None: + raise TypeError("'unix' is a required keyword and cannot be 'None'") + + return {'socket_create': (socket.AF_UNIX, socket.SOCK_STREAM), + 'socket_connect': unix} + + def __init__(self, unix = None, **kw): + params = self.socket_factory_params(unix) + self.unix = params['socket_connect'] + # constant socket connector + self._socketcreator = self.create_socket_factory(**params) + self._socketcreators = (self._socketcreator,) + super().__init__(**kw) + +class Host(SocketConnector): + """ + Connector for establishing hostname based connections. + + This connector exercises socket.getaddrinfo. + """ + def socket_factory_sequence(self): + """ + Return a list of `SocketCreator`s based on the results of + `socket.getaddrinfo`. 
+ """ + return [ + # (AF, socktype, proto), (IP, Port) + self.create_socket_factory(**(self.socket_factory_params(x[0:3], x[4][:2], + self._socket_secure))) + for x in socket.getaddrinfo( + self.host, self.port, self._address_family, socket.SOCK_STREAM + ) + ] + + def socket_factory_params(self, socktype, address, sslparams): + return {'socket_create': socktype, + 'socket_connect': address, + 'socket_secure': sslparams} + + def __init__(self, + host : str = None, + port : (str, int) = None, + ipv : int = None, + address_family = None, + **kw + ): + if host is None: + raise TypeError("'host' is a required keyword") + if port is None: + raise TypeError("'port' is a required keyword") + + if address_family is not None and ipv is not None: + raise TypeError("'ipv' and 'address_family' on mutually exclusive") + + if ipv is None: + self._address_family = address_family or socket.AF_UNSPEC + elif ipv == 4: + self._address_family = socket.AF_INET + elif ipv == 6: + self._address_family = socket.AF_INET6 + else: + raise TypeError("unknown IP version selected: 'ipv' = " + repr(ipv)) + self.host = host + self.port = port + super().__init__(**kw) + +class Driver(pg_api.Driver): + def _e_metas(self): + yield (None, type(self).__module__ + '.' + type(self).__name__) + + def ip4(self, **kw): + return IP4(driver = self, **kw) + + def ip6(self, **kw): + return IP6(driver = self, **kw) + + def host(self, **kw): + return Host(driver = self, **kw) + + def unix(self, **kw): + return Unix(driver = self, **kw) + + def fit(self, + unix = None, + host = None, + port = None, + **kw + ) -> Connector: + """ + Create the appropriate `postgresql.api.Connector` based on the parameters. + """ + if unix is not None: + if host is not None: + raise TypeError("'unix' and 'host' keywords are exclusive") + if port is not None: + raise TypeError("'unix' and 'port' keywords are exclusive") + return self.unix(unix = unix, **kw) + else: + if host is None or port is None: + raise TypeError("'host' and 'port', or 'unix' must be supplied") + + # If it's an IP address, IP4 or IP6 should be selected. + if ':' in host: + # There's a ':' in host, good chance that it's IPv6. + try: + socket.inet_pton(socket.AF_INET6, host) + return self.ip6(host = host, port = port, **kw) + except (socket.error, NameError): + pass + + # Not IPv6, maybe IPv4. + try: + socket.inet_aton(host) + # It's IP4 + return self.ip4(host = host, port = port, **kw) + except socket.error: + pass + + # neither host, nor port are None, probably a hostname. + return self.host(host = host, port = port, **kw) + + def connect(self, **kw) -> Connection: + """ + Create an established Connection instance from a temporary Connector + built using the given keywords. 
+ + For information on acceptable keywords, see: + + `postgresql.documentation.driver`:Connection Keywords + """ + c = self.fit(**kw)() + c.connect() + return c + + def __init__(self, connection = Connection, typio = TypeIO): + self.connection = connection + self.typio = typio diff --git a/py_opengauss/encodings/__init__.py b/py_opengauss/encodings/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1b95b6af47b2115c479b3ec3e4d27334c4f4543 --- /dev/null +++ b/py_opengauss/encodings/__init__.py @@ -0,0 +1,3 @@ +## +# .encodings +## diff --git a/py_opengauss/encodings/aliases.py b/py_opengauss/encodings/aliases.py new file mode 100644 index 0000000000000000000000000000000000000000..7f86830fe7fc0437bd35be21b483c64079e40dda --- /dev/null +++ b/py_opengauss/encodings/aliases.py @@ -0,0 +1,58 @@ +## +# .encodings.aliases +## +""" +Module for mapping PostgreSQL encoding names to Python encoding names. + +These are **not** installed in Python's aliases. Rather, `get_python_name` +should be used directly. + +URLs of interest: + * http://docs.python.org/library/codecs.html + * http://git.postgresql.org/gitweb?p=postgresql.git;a=blob;f=src/backend/utils/mb/encnames.c +""" + +## +#: Dictionary of Postgres encoding names to Python encoding names. +#: This mapping only contains those encoding names that do not intersect. +postgres_to_python = { + 'unicode' : 'utf_8', + 'sql_ascii' : 'ascii', + 'euc_jp' : 'eucjp', + 'euc_cn' : 'euccn', + 'euc_kr' : 'euckr', + 'shift_jis_2004' : 'euc_jis_2004', + 'sjis' : 'shift_jis', + 'alt' : 'cp866', # IBM866 + 'abc' : 'cp1258', + 'vscii' : 'cp1258', + 'koi8r' : 'koi8_r', + 'koi8u' : 'koi8_u', + 'tcvn' : 'cp1258', + 'tcvn5712' : 'cp1258', +# 'euc_tw' : None, # N/A +# 'mule_internal' : None, # N/A +} + +def get_python_name(encname): + """ + Lookup the name in the `postgres_to_python` dictionary. If no match is + found, check for a 'win' or 'windows-' name and convert that to a 'cp###' + name. + + Returns `None` if there is no alias for `encname`. + + The win[0-9]+ and windows-[0-9]+ entries are handled functionally. 
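+
+	For example (these values follow from the mapping and the win/windows rule)::
+
+		>>> get_python_name('sql_ascii')
+		'ascii'
+		>>> get_python_name('win1251')
+		'cp1251'
+		>>> get_python_name('windows-1252')
+		'cp1252'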
+ """ + # check the dictionary first + localname = postgres_to_python.get(encname) + if localname is not None: + return localname + # no explicit mapping, check for functional transformation + if encname.startswith('win'): + # handle win#### and windows-#### + # remove the trailing CP number + bare = encname.rstrip('0123456789') + if bare.strip('_-') in ('win', 'windows'): + return 'cp' + encname[len(bare):] + return encname diff --git a/py_opengauss/encodings/bytea.py b/py_opengauss/encodings/bytea.py new file mode 100644 index 0000000000000000000000000000000000000000..3c406dfa4e1f3a2bcf8eb1e7c1c2798b34f7f053 --- /dev/null +++ b/py_opengauss/encodings/bytea.py @@ -0,0 +1,69 @@ +## +# .encodings.bytea +## +'PostgreSQL bytea encoding and decoding functions' +import codecs +import struct +import sys + +ord_to_seq = { + i : \ + "\\" + oct(i)[2:].rjust(3, '0') \ + if not (32 < i < 126) else r'\\' \ + if i == 92 else chr(i) + for i in range(256) +} + +if sys.version_info[:2] >= (3, 3): + # Subscripting memory in 3.3 returns byte as an integer, not as a bytestring + def decode(data): + return ''.join(map(ord_to_seq.__getitem__, (data[x] for x in range(len(data))))) +else: + def decode(data): + return ''.join(map(ord_to_seq.__getitem__, (data[x][0] for x in range(len(data))))) + +def encode(data): + diter = ((data[i] for i in range(len(data)))) + output = [] + next = diter.__next__ + for x in diter: + if x == "\\": + try: + y = next() + except StopIteration: + raise ValueError("incomplete backslash sequence") + if y == "\\": + # It's a backslash, so let x(\) be appended. + x = ord(x) + elif y.isdigit(): + try: + os = ''.join((y, next(), next())) + except StopIteration: + # requires three digits + raise ValueError("incomplete backslash sequence") + try: + x = int(os, base = 8) + except ValueError: + raise ValueError("invalid bytea octal sequence '%s'" %(os,)) + else: + raise ValueError("invalid backslash follow '%s'" %(y,)) + else: + x = ord(x) + output.append(x) + return struct.pack(str(len(output)) + 'B', *output) + +class Codec(codecs.Codec): + 'bytea codec' + def encode(data, errors = 'strict'): + return (encode(data), len(data)) + encode = staticmethod(encode) + + def decode(data, errors = 'strict'): + return (decode(data), len(data)) + decode = staticmethod(decode) + +class StreamWriter(Codec, codecs.StreamWriter): pass +class StreamReader(Codec, codecs.StreamReader): pass + +bytea_codec = (Codec.encode, Codec.decode, StreamReader, StreamWriter) +codecs.register(lambda x: x == 'bytea' and bytea_codec or None) diff --git a/py_opengauss/exceptions.py b/py_opengauss/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..5d71d01c0020ae754c22b8d749504dc12c2d916f --- /dev/null +++ b/py_opengauss/exceptions.py @@ -0,0 +1,750 @@ +## +# .exceptions - Exception hierarchy for PostgreSQL database ERRORs. +## +""" +PostgreSQL exceptions and warnings with associated state codes. + +The primary entry points of this module is the `ErrorLookup` function and the +`WarningLookup` function. Given an SQL state code, they give back the most +appropriate Error or Warning subclass. + +For more information on error codes see: + http://www.postgresql.org/docs/current/static/errcodes-appendix.html + +This module is executable via -m: python -m postgresql.exceptions. 
+It provides a convenient way to look up the exception object mapped to by the +given error code:: + + $ python -m postgresql.exceptions XX000 + postgresql.exceptions.InternalError [XX000] + +If the exact error code is not found, it will try to find the error class's +exception(The first two characters of the error code make up the class +identity):: + + $ python -m postgresql.exceptions XX400 + postgresql.exceptions.InternalError [XX000] + +If that fails, it will return `postgresql.exceptions.Error` +""" +import sys +import os +from functools import partial +from operator import attrgetter +from .message import Message +from . import sys as pg_sys + +PythonException = Exception +class Exception(Exception): + """ + Base PostgreSQL exception class. + """ + pass + +class LoadError(Exception): + """ + Failed to load a library. + """ + +class Disconnection(Exception): + """ + Exception identifying errors that result in disconnection. + """ + +class Warning(Message): + code = '01000' + _e_label = property(attrgetter('__class__.__name__')) + +class DriverWarning(Warning): + code = '01-00' + source = 'CLIENT' +class IgnoredClientParameterWarning(DriverWarning): + 'Warn the user of a valid, but ignored parameter.' + code = '01-CP' +class TypeConversionWarning(DriverWarning): + 'Report a potential issue with a conversion.' + code = '01-TP' + +class DeprecationWarning(Warning): + code = '01P01' +class DynamicResultSetsReturnedWarning(Warning): + code = '0100C' +class ImplicitZeroBitPaddingWarning(Warning): + code = '01008' +class NullValueEliminatedInSetFunctionWarning(Warning): + code = '01003' +class PrivilegeNotGrantedWarning(Warning): + code = '01007' +class PrivilegeNotRevokedWarning(Warning): + code = '01006' +class StringDataRightTruncationWarning(Warning): + code = '01004' + +class NoDataWarning(Warning): + code = '02000' +class NoMoreSetsReturned(NoDataWarning): + code = '02001' + +class Error(Message, Exception): + """ + A PostgreSQL Error. + """ + _e_label = 'ERROR' + code = '' + + def __str__(self): + """ + Call .sys.errformat(self). + """ + return pg_sys.errformat(self) + + @property + def fatal(self): + f = self.details.get('severity') + return None if f is None else f in ('PANIC', 'FATAL') + +class DriverError(Error): + """ + Errors originating in the driver's implementation. + """ + source = 'CLIENT' + code = '--000' +class AuthenticationMethodError(DriverError, Disconnection): + """ + Server requested an authentication method that is not supported by the + driver. + """ + code = '--AUT' +class InsecurityError(DriverError, Disconnection): + """ + Error signifying a secure channel to a server cannot be established. + """ + code = '--SEC' +class ConnectTimeoutError(DriverError, Disconnection): + """ + Client was unable to esablish a connection in the given time. + """ + code = '--TOE' + +class TypeIOError(DriverError): + """ + Driver failed to pack or unpack a tuple. + """ + code = '--TIO' +class ParameterError(TypeIOError): + code = '--PIO' +class ColumnError(TypeIOError): + code = '--CIO' +class CompositeError(TypeIOError): + code = '--cIO' + +class OperationError(DriverError): + """ + An invalid operation on an interface element. + """ + code = '--OPE' + +class TransactionError(Error): + pass + +class SQLNotYetCompleteError(Error): + code = '03000' + +class ConnectionError(Error, Disconnection): + code = '08000' +class ConnectionDoesNotExistError(ConnectionError): + """ + The connection is closed or was never connected. 
+ """ + code = '08003' +class ConnectionFailureError(ConnectionError): + """ + Raised when a connection is dropped. + """ + code = '08006' + +class ClientCannotConnectError(ConnectionError): + """ + Client was unable to establish a connection to the server. + """ + code = '08001' + +class ConnectionRejectionError(ConnectionError): + code = '08004' +class TransactionResolutionUnknownError(ConnectionError): + code = '08007' +class ProtocolError(ConnectionError): + code = '08P01' + +class TriggeredActionError(Error): + code = '09000' + +class FeatureError(Error): + """ + "Unsupported feature. + """ + code = '0A000' + +class TransactionInitiationError(TransactionError): + code = '0B000' + +class LocatorError(Error): + code = '0F000' +class LocatorSpecificationError(LocatorError): + code = '0F001' + +class GrantorError(Error): + code = '0L000' +class GrantorOperationError(GrantorError): + code = '0LP01' + +class RoleSpecificationError(Error): + code = '0P000' + +class CaseNotFoundError(Error): + code = '20000' + +class CardinalityError(Error): + """ + Wrong number of rows returned. + """ + code = '21000' + +class TriggeredDataChangeViolation(Error): + code = '27000' + +class AuthenticationSpecificationError(Error, Disconnection): + code = '28000' + +class DPDSEError(Error): + """ + Dependent Privilege Descriptors Still Exist. + """ + code = '2B000' +class DPDSEObjectError(DPDSEError): + code = '2BP01' + +class SREError(Error): + """ + SQL Routine Exception. + """ + code = '2F000' +class FunctionExecutedNoReturnStatementError(SREError): + code = '2F005' +class DataModificationProhibitedError(SREError): + code = '2F002' +class StatementProhibitedError(SREError): + code = '2F003' +class ReadingDataProhibitedError(SREError): + code = '2F004' + +class EREError(Error): + """ + External Routine Exception. + """ + code = '38000' +class ContainingSQLNotPermittedError(EREError): + code = '38001' +class ModifyingSQLDataNotPermittedError(EREError): + code = '38002' +class ProhibitedSQLStatementError(EREError): + code = '38003' +class ReadingSQLDataNotPermittedError(EREError): + code = '38004' + +class ERIEError(Error): + """ + External Routine Invocation Exception. + """ + code = '39000' +class InvalidSQLState(ERIEError): + code = '39001' +class NullValueNotAllowed(ERIEError): + code = '39004' +class TriggerProtocolError(ERIEError): + code = '39P01' +class SRFProtocolError(ERIEError): + code = '39P02' + +class TRError(TransactionError): + """ + Transaction Rollback. + """ + code = '40000' +class DeadlockError(TRError): + code = '40P01' +class IntegrityConstraintViolationError(TRError): + code = '40002' +class SerializationError(TRError): + code = '40001' +class StatementCompletionUnknownError(TRError): + code = '40003' + + +class ITSError(TransactionError): + """ + Invalid Transaction State. + """ + code = '25000' +class ActiveTransactionError(ITSError): + code = '25001' +class BranchAlreadyActiveError(ITSError): + code = '25002' +class BadAccessModeForBranchError(ITSError): + code = '25003' +class BadIsolationForBranchError(ITSError): + code = '25004' +class NoActiveTransactionForBranchError(ITSError): + code = '25005' +class ReadOnlyTransactionError(ITSError): + """ + Occurs when an alteration occurs in a read-only transaction. + """ + code = '25006' +class SchemaAndDataStatementsError(ITSError): + """ + Mixed schema and data statements not allowed. + """ + code = '25007' +class InconsistentCursorIsolationError(ITSError): + """ + The held cursor requires the same isolation. 
+ """ + code = '25008' + +class NoActiveTransactionError(ITSError): + code = '25P01' +class InFailedTransactionError(ITSError): + """ + Occurs when an action occurs in a failed transaction. + """ + code = '25P02' + + +class SavepointError(TransactionError): + """ + Classification error designating errors that relate to savepoints. + """ + code = '3B000' +class InvalidSavepointSpecificationError(SavepointError): + code = '3B001' + +class TransactionTerminationError(TransactionError): + code = '2D000' + +class IRError(Error): + """ + Insufficient Resource Error. + """ + code = '53000' +class MemoryError(IRError, MemoryError): + code = '53200' +class DiskFullError(IRError): + code = '53100' +class TooManyConnectionsError(IRError): + code = '53300' + +class PLEError(OverflowError): + """ + Program Limit Exceeded + """ + code = '54000' +class ComplexityOverflowError(PLEError): + code = '54001' +class ColumnOverflowError(PLEError): + code = '54011' +class ArgumentOverflowError(PLEError): + code = '54023' + +class ONIPSError(Error): + """ + Object Not In Prerequisite State. + """ + code = '55000' +class ObjectInUseError(ONIPSError): + code = '55006' +class ImmutableRuntimeParameterError(ONIPSError): + code = '55P02' +class UnavailableLockError(ONIPSError): + code = '55P03' + + +class SEARVError(Error): + """ + Syntax Error or Access Rule Violation. + """ + code = '42000' + +class SEARVNameError(SEARVError): + code = '42602' +class NameTooLongError(SEARVError): + code = '42622' +class ReservedNameError(SEARVError): + code = '42939' + +class ForeignKeyCreationError(SEARVError): + code = '42830' + +class InsufficientPrivilegeError(SEARVError): + code = '42501' +class GroupingError(SEARVError): + code = '42803' + +class RecursionError(SEARVError): + code = '42P19' +class WindowError(SEARVError): + code = '42P20' + +class SyntaxError(SEARVError): + code = '42601' + +class TypeError(SEARVError): + pass +class CoercionError(TypeError): + code = '42846' +class TypeMismatchError(TypeError): + code = '42804' +class IndeterminateTypeError(TypeError): + code = '42P18' +class WrongObjectTypeError(TypeError): + code = '42809' + +class UndefinedError(SEARVError): + pass +class UndefinedColumnError(UndefinedError): + code = '42703' +class UndefinedFunctionError(UndefinedError): + code = '42883' +class UndefinedTableError(UndefinedError): + code = '42P01' +class UndefinedParameterError(UndefinedError): + code = '42P02' +class UndefinedObjectError(UndefinedError): + code = '42704' + +class DuplicateError(SEARVError): + pass +class DuplicateColumnError(DuplicateError): + code = '42701' +class DuplicateCursorError(DuplicateError): + code = '42P03' +class DuplicateDatabaseError(DuplicateError): + code = '42P04' +class DuplicateFunctionError(DuplicateError): + code = '42723' +class DuplicatePreparedStatementError(DuplicateError): + code = '42P05' +class DuplicateSchemaError(DuplicateError): + code = '42P06' +class DuplicateTableError(DuplicateError): + code = '42P07' +class DuplicateAliasError(DuplicateError): + code = '42712' +class DuplicateObjectError(DuplicateError): + code = '42710' + +class AmbiguityError(SEARVError): + pass +class AmbiguousColumnError(AmbiguityError): + code = '42702' +class AmbiguousFunctionError(AmbiguityError): + code = '42725' +class AmbiguousParameterError(AmbiguityError): + code = '42P08' +class AmbiguousAliasError(AmbiguityError): + code = '42P09' + +class ColumnReferenceError(SEARVError): + code = '42P10' + +class DefinitionError(SEARVError): + pass +class 
ColumnDefinitionError(DefinitionError): + code = '42611' +class CursorDefinitionError(DefinitionError): + code = '42P11' +class DatabaseDefinitionError(DefinitionError): + code = '42P12' +class FunctionDefinitionError(DefinitionError): + code = '42P13' +class PreparedStatementDefinitionError(DefinitionError): + code = '42P14' +class SchemaDefinitionError(DefinitionError): + code = '42P15' +class TableDefinitionError(DefinitionError): + code = '42P16' +class ObjectDefinitionError(DefinitionError): + code = '42P17' + + +class CursorStateError(Error): + code = '24000' + +class WithCheckOptionError(Error): + code = '44000' + +class NameError(Error): + pass +class CatalogNameError(NameError): + code = '3D000' +class CursorNameError(NameError): + code = '34000' +class StatementNameError(NameError): + code = '26000' +class SchemaNameError(NameError): + code = '3F000' + +class ICVError(Error): + """ + Integrity Contraint Violation. + """ + code = '23000' +class RestrictError(ICVError): + code = '23001' +class NotNullError(ICVError): + code = '23502' +class ForeignKeyError(ICVError): + code = '23503' +class UniqueError(ICVError): + code = '23505' +class CheckError(ICVError): + code = '23514' + + +class DataError(Error): + code = '22000' + +class StringRightTruncationError(DataError): + code = '22001' +class StringDataLengthError(DataError): + code = '22026' +class ZeroLengthString(DataError): + code = '2200F' + +class EncodingError(DataError): + code = '22021' +class ArrayElementError(DataError): + code = '2202E' +class SpecificTypeMismatch(DataError): + code = '2200G' + +class NullValueNotAllowedError(DataError): + code = '22004' +class NullValueNoIndicatorParameter(DataError): + code = '22002' + +class ZeroDivisionError(DataError): + code = '22012' +class FloatingPointError(DataError): + code = '22P01' +class AssignmentError(DataError): + code = '22005' +class IndicatorOverflowError(DataError): + code = '22022' +class BadCopyError(DataError): + code = '22P04' + +class TextRepresentationError(DataError): + code = '22P02' +class BinaryRepresentationError(DataError): + code = '22P03' +class UntranslatableCharacterError(DataError): + code = '22P05' +class NonstandardUseOfEscapeCharacterError(DataError): + code = '22P06' + +class NotXMLError(DataError): + code = '2200L' +class XMLDocumentError(DataError): + code = '2200M' +class XMLContentError(DataError): + code = '2200N' +class XMLCommentError(DataError): + code = '2200S' +class XMLProcessingInstructionError(DataError): + code = '2200T' + +class DateTimeFormatError(DataError): + code = '22007' +class TimeZoneDisplacementValueError(DataError): + code = '22009' +class DateTimeFieldOverflowError(DataError): + code = '22008' +class IntervalFieldOverflowError(DataError): + code = '22015' + +class LogArgumentError(DataError): + code = '2201E' +class PowerFunctionArgumentError(DataError): + code = '2201F' +class WidthBucketFunctionArgumentError(DataError): + code = '2201G' +class CastCharacterValueError(DataError): + code = '22018' + +class EscapeCharacterError(DataError): + code = '22019' +class EscapeOctetError(DataError): + code = '2200D' +class EscapeSequenceError(DataError): + code = '22025' +class EscapeCharacterConflictError(DataError): + code = '2200B' +class EscapeCharacterError(DataError): + """ + Invalid escape character. 
+ """ + code = '2200C' + +class SubstringError(DataError): + code = '22011' +class TrimError(DataError): + code = '22027' +class IndicatorParameterValueError(DataError): + code = '22010' + +class LimitValueError(DataError): + code = '2201W' + pg_code = '22020' +class OffsetValueError(DataError): + code = '2201X' + +class ParameterValueError(DataError): + code = '22023' +class RegularExpressionError(DataError): + code = '2201B' +class NumericRangeError(DataError): + code = '22003' +class UnterminatedCStringError(DataError): + code = '22024' + + +class InternalError(Error): + code = 'XX000' +class DataCorruptedError(InternalError): + code = 'XX001' +class IndexCorruptedError(InternalError): + code = 'XX002' + +class SIOError(Error): + """ + System I/O. + """ + code = '58000' +class UndefinedFileError(SIOError): + code = '58P01' +class DuplicateFileError(SIOError): + code = '58P02' + +class CFError(Error): + """ + Configuration File Error. + """ + code = 'F0000' +class LockFileExistsError(CFError): + code = 'F0001' + +class OIError(Error): + """ + Operator Intervention. + """ + code = '57000' +class QueryCanceledError(OIError): + code = '57014' +class AdminShutdownError(OIError, Disconnection): + code = '57P01' +class CrashShutdownError(OIError, Disconnection): + code = '57P02' +class ServerNotReadyError(OIError, Disconnection): + """ + Thrown when a connection is established to a server that is still starting up. + """ + code = '57P03' + +class PLPGSQLError(Error): + """ + Error raised by a PL/PgSQL procedural function. + """ + code = 'P0000' +class PLPGSQLRaiseError(PLPGSQLError): + """ + Error raised by a PL/PgSQL RAISE statement. + """ + code = 'P0001' +class PLPGSQLNoDataFoundError(PLPGSQLError): + code = 'P0002' +class PLPGSQLTooManyRowsError(PLPGSQLError): + code = 'P0003' + + +# Setup mapping to provide code based exception lookup. +code_to_error = {} +code_to_warning = {} +def map_errors_and_warnings( + objs, + error_container = code_to_error, + warning_container = code_to_warning, +): + """ + Construct the code-to-error and code-to-warning associations. + """ + for obj in objs: + if not issubclass(type(obj), (type(Warning), type(Error))): + # It's not object of interest. + continue + code = getattr(obj, 'code', None) + if code is None: + # It has no code attribute, or the code was set to None. + # If it's code is None, we don't map it as it's a "container". + continue + + if issubclass(obj, Error): + base = Error + container = error_container + elif issubclass(obj, Warning): + base = Warning + container = warning_container + else: + continue + + cur_obj = container.get(code) + if cur_obj is None or issubclass(cur_obj, obj): + # There is no object yet, or the object at the code + # is not the most general class. + # The latter condition comes into play when + # there are sub-Class types that share the Class code + # with the most general type. (See TypeError) + container[code] = obj + if hasattr(obj, 'pg_code'): + # If there's a PostgreSQL version of the code, + # map it as well for older servers. 
+ container[obj.pg_code] = obj + +def code_lookup( + default, + container, + code +): + obj = container.get(code) + if obj is None: + obj = container.get(code[:2] + "000", default) + return obj + +map_errors_and_warnings(sys.modules[__name__].__dict__.values()) +ErrorLookup = partial(code_lookup, Error, code_to_error) +WarningLookup = partial(code_lookup, Warning, code_to_warning) + +if __name__ == '__main__': + for x in sys.argv[1:]: + if x.startswith('01'): + e = WarningLookup(x) + else: + e = ErrorLookup(x) + sys.stdout.write('postgresql.exceptions.%s [%s]%s%s' %( + e.__name__, e.code, os.linesep, ( + e.__doc__ is not None and os.linesep.join([ + ' ' + x for x in (e.__doc__).split('\n') + ]) + os.linesep or '' + ) + ) + ) diff --git a/py_opengauss/installation.py b/py_opengauss/installation.py new file mode 100644 index 0000000000000000000000000000000000000000..046bff4417cdda15093d836c9a96e9e763971064 --- /dev/null +++ b/py_opengauss/installation.py @@ -0,0 +1,263 @@ +## +# .installation +## +""" +Collect and access PostgreSQL installation information. +""" +import sys +import os +import os.path +import subprocess +import errno +from itertools import cycle, chain +from operator import itemgetter +from .python.os import find_executable, close_fds, platform_exe +from . import versionstring +from . import api as pg_api +from . import string as pg_str + +# Get the output from the given command. +# Variable arguments are transformed into "long options", '--' + x +def get_command_output(exe, *args, encoding='utf-8', timeout=8): + pa = list(exe) + [ + '--' + x.strip() for x in args if x is not None + ] + p = subprocess.Popen(pa, + close_fds = close_fds, + stdout = subprocess.PIPE, + stderr = None, + stdin = None, + shell = False + ) + + try: + stdout, stderr = p.communicate(timeout=timeout) + except subprocess.TimeoutExpired: + p.kill() + stdout, stderr = p.communicate(timeout=2) + + if p.returncode != 0: + return None + + return stdout.decode(encoding) + +def pg_config_dictionary(*pg_config_path, encoding='utf-8', timeout=8): + """ + Create a dictionary of the information available in the given + pg_config_path. This provides a one-shot solution to fetching information + from the pg_config binary. Returns a dictionary object. + """ + default_output = get_command_output(pg_config_path, encoding=encoding, timeout=timeout) + if default_output is not None: + d = {} + for x in default_output.splitlines(): + if not x or x.isspace() or x.find('=') == -1: + continue + k, v = x.split('=', 1) + # keep it semi-consistent with instance + d[k.lower().strip()] = v.strip() + return d + + # Support for 8.0 pg_config and earlier. + # This requires three invocations of pg_config: + # First --help, to get the -- options available, + # Second, all the -- options except version. + # Third, --version as it appears to be exclusive in some cases. + opt = [] + for l in get_command_output(pg_config_path, 'help', encoding=encoding, timeout=timeout).splitlines(): + dash_pos = l.find('--') + if dash_pos == -1: + continue + sp_pos = l.find(' ', dash_pos) + # the dashes are added by the call command + opt.append(l[dash_pos+2:sp_pos]) + if 'help' in opt: + opt.remove('help') + if 'version' in opt: + opt.remove('version') + + d=dict(zip(opt, get_command_output(pg_config_path, *opt, encoding=encoding, timeout=timeout).splitlines())) + d['version'] = get_command_output(pg_config_path, 'version', encoding=encoding, timeout=timeout).strip() + return d + +## +# Build a key-value pair list of the configure options. 
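+# (For example, illustratively, the configure string
+#  "'--prefix=/usr/local' '--with-openssl'" yields the pairs
+#  ('prefix', '/usr/local') and ('with_openssl', True).)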
+# If the item is quoted, mind the quotes. +def parse_configure_options(confopt, quotes = '\'"', dash_and_quotes = '-\'"'): + # This is not a robust solution, but it will usually work. + # Chances are that there is a quote at the beginning of this string. + # However, in the windows pg_config.exe, this appears to be absent. + if confopt[0:1] in quotes: + # quote at the beginning. assume it's used consistently. + quote = confopt[0:1] + elif confopt[-1:] in quotes: + # quote at the end? + quote = confopt[-1] + else: + # fallback to something. :( + quote = "'" + ## + # This is using the wrong kind of split, but the pg_config + # output has been consistent enough for this to work. + parts = pg_str.split_using(confopt, quote, sep = ' ') + qq = quote * 2 + for x in parts: + if qq in x: + # singularize the quotes + x = x.replace(qq, quote) + # remove the quotes around '--' from option. + # if it splits once, the '1' index will + # be `True`, indicating that the flag was given, but + # was not given a value. + kv = x.strip(dash_and_quotes).split('=', 1) + [True] + key = kv[0].replace('-','_') + # Ignore empty keys. + if key: + yield (key, kv[1]) + +def default_pg_config(execname = 'pg_config', envkey = 'PGINSTALLATION'): + """ + Get the default `pg_config` executable on the system. + + If 'PGINSTALLATION' is in the environment, use it. + Otherwise, look through the system's PATH environment. + """ + pg_config_path = os.environ.get(envkey) + if pg_config_path: + # Trust PGINSTALLATION. + return platform_exe(pg_config_path) + return find_executable(execname) + +class Installation(pg_api.Installation): + """ + Class providing a Python interface to PostgreSQL installation information. + """ + version = None + version_info = None + type = None + configure_options = None + #: The pg_config information dictionary. + info = None + + pg_executables = ( + 'pg_config', + 'psql', + 'initdb', + 'pg_resetxlog', + 'pg_controldata', + 'clusterdb', + 'pg_ctl', + 'pg_dump', + 'pg_dumpall', + 'postgres', + 'postmaster', + 'reindexdb', + 'vacuumdb', + 'ipcclean', + 'createdb', + 'ecpg', + 'createuser', + 'createlang', + 'droplang', + 'dropuser', + 'pg_restore', + ) + + pg_libraries = ( + 'libpq', + 'libecpg', + 'libpgtypes', + 'libecpg_compat', + ) + + pg_directories = ( + 'bindir', + 'docdir', + 'includedir', + 'pkgincludedir', + 'includedir_server', + 'libdir', + 'pkglibdir', + 'localedir', + 'mandir', + 'sharedir', + 'sysconfdir', + ) + + def _e_metas(self): + l = list(self.configure_options.items()) + l.sort(key = itemgetter(0)) + yield ('version', self.version) + if l: + yield ('configure_options', + (os.linesep).join(( + k if v is True else k + '=' + v + for k,v in l + )) + ) + + def __repr__(self, format = "{mod}.{name}({info!r})".format): + return format( + mod = type(self).__module__, + name = type(self).__name__, + info = self.info + ) + + def __init__(self, info : dict): + """ + Initialize the Installation using the given information dictionary. + """ + self.info = info + self.version = self.info["version"] + self.type, vs = self.version.split() + self.version_info = versionstring.normalize(versionstring.split(vs)) + self.configure_options = dict( + parse_configure_options(self.info.get('configure', '')) + ) + # collect the paths in a dictionary first + self.paths = dict() + + exists = os.path.exists + join = os.path.join + for k in self.pg_directories: + self.paths[k] = self.info.get(k) + + # find all the PG executables that exist for the installation. 
+ bindir_path = self.info.get('bindir') + if bindir_path is None: + self.paths.update(zip(self.pg_executables, cycle((None,)))) + else: + for k in self.pg_executables: + path = platform_exe(join(bindir_path, k)) + if exists(path): + self.paths[k] = path + else: + self.paths[k] = None + self.__dict__.update(self.paths) + + @property + def ssl(self): + """ + Whether the installation was compiled with SSL support. + """ + return 'with_openssl' in self.configure_options + +def default(typ = Installation): + """ + Get the default Installation. + + Uses default_pg_config() to identify the executable. + """ + path = default_pg_config() + if path is None: + return None + return typ(pg_config_dictionary(path)) + +if __name__ == '__main__': + if sys.argv[1:]: + d = pg_config_dictionary(sys.argv[1]) + i = Installation(d) + else: + i = default() + from .python.element import format_element + print(format_element(i)) diff --git a/py_opengauss/iri.py b/py_opengauss/iri.py new file mode 100644 index 0000000000000000000000000000000000000000..150b22c1b5f9b859ec28d5df4a61f92355713fcc --- /dev/null +++ b/py_opengauss/iri.py @@ -0,0 +1,203 @@ +## +# .iri +## +""" +Parse and serialize PQ IRIs. + +PQ IRIs take the form:: + + pq://user:pass@host:port/database?setting=value&setting2=value2 + +IPv6 is supported via the standard representation:: + + pq://[::1]:5432/database + +Driver Parameters: + + pq://user@host/?[driver_param]=value&[other_param]=value?server_setting=val +""" +from .resolved import riparse as ri +from .string import split_ident + +from operator import itemgetter +get0 = itemgetter(0) +del itemgetter + +import re +escape_path_re = re.compile('[%s]' %(re.escape(ri.unescaped + ','),)) + +def structure(d, fieldproc = ri.unescape): + """ + Create a clientparams dictionary from a parsed RI. + """ + scheme = d.get('scheme', 'pq').lower() + if scheme not in {'pq', 'postgres', 'postgresql', 'og', 'opengauss'}: + raise ValueError("PQ-IRI scheme is not 'pq', 'postgres', 'postgresql', 'og' or 'opengauss'") + if scheme in {'og', 'opengauss'}: + # recover opengauss scheme to pq + d['scheme'] = 'pq' + + cpd = { + k : fieldproc(v) for k, v in d.items() + if k not in ('path', 'fragment', 'query', 'host', 'scheme') + } + + path = d.get('path') + frag = d.get('fragment') + query = d.get('query') + host = d.get('host') + + if host is not None: + if host.startswith('[') and host.endswith(']'): + host = host[1:-1] + if host.startswith('unix:'): + cpd['unix'] = host[len('unix:'):].replace(':','/') + else: + cpd['host'] = host + else: + cpd['host'] = fieldproc(host) + + if path: + # Only state the database field's existence if the first path is non-empty. + if path[0]: + cpd['database'] = path[0] + path = path[1:] + if path: + cpd['path'] = path + + settings = {} + if query: + if hasattr(query, 'items'): + qiter = query.items() + else: + qiter = query + for k, v in qiter: + if k.startswith('[') and k.endswith(']'): + k = k[1:-1] + if k != 'settings' and k not in cpd: + cpd[fieldproc(k)] = fieldproc(v) + elif k: + settings[fieldproc(k)] = fieldproc(v) + # else: ignore empty query keys + + if frag: + settings['search_path'] = [ + fieldproc(x) for x in frag.split(',') + ] + + if settings: + cpd['settings'] = settings + + return cpd + +def construct_path(x, re = escape_path_re): + """ + Join a path sequence using ',' and escaping ',' in the pieces. 
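+
+	Illustrative sketch, assuming ri.re_pct_encode percent-encodes the
+	matched character: construct_path(['public', 'a,b']) -> 'public,a%2Cb'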
+ """ + return ','.join((re.sub(ri.re_pct_encode, y) for y in x)) + +def construct(x, obscure_password = False): + """ + Construct a RI dictionary from a clientparams dictionary. + """ + # the rather exhaustive settings choreography is due to + # a desire to allow the search_path to be appended in the fragment + settings = x.get('settings') + no_path_settings = None + search_path = None + if settings: + if isinstance(settings, dict): + siter = settings.items() + search_path = settings.get('search_path') + else: + siter = list(settings) + search_path = [(k,v) for k,v in siter if k == 'search_path'] + search_path.append((None,None)) + search_path = search_path[-1][1] + no_path_settings = [(k,v) for k,v in siter if k != 'search_path'] + if not no_path_settings: + no_path_settings = None + + # It could be a string search_path, split if it is. + if search_path is not None and isinstance(search_path, str): + search_path = split_ident(search_path, sep = ',') + + port = None + if 'unix' in x: + host = '[unix:' + x['unix'].replace('/',':') + ']' + # ignore port.. it's a mis-config. + elif 'host' in x: + host = x['host'] + if ':' in host: + host = '[' + host + ']' + port = x.get('port') + else: + host = None + port = x.get('port') + + path = [] + if 'database' in x: + path.append(x['database']) + if 'path' in x: + path.extend(x['path'] or ()) + + password = x.get('password') + if obscure_password and password is not None: + password = '***' + driver_params = list({ + '[' + k + ']' : str(v) for k,v in x.items() + if k not in ( + 'user', 'password', 'port', 'database', 'ssl', + 'path', 'host', 'unix', 'ipv','settings' + ) + }.items()) + driver_params.sort(key=get0) + + return ( + 'pqs' if x.get('ssl', False) is True else 'pq', + # netloc: user:pass@host[:port] + ri.unsplit_netloc(( + x.get('user'), + password, + host, + None if 'port' not in x else str(x['port']) + )), + None if not path else '/'.join([ + ri.escape_path_re.sub(path_comp, '/') + for path_comp in path + ]), + (ri.construct_query(driver_params) if driver_params else None) + if no_path_settings is None else ( + ri.construct_query( + driver_params + no_path_settings + ) + ), + None if search_path is None else construct_path(search_path), + ) + +def parse(s, fieldproc = ri.unescape): + """ + Parse a Postgres IRI into a dictionary object. + """ + return structure( + # In ri.parse, don't unescape the parsed values as our sub-structure + # uses the escape mechanism in IRIs to specify literal separator + # characters. + ri.parse(s, fieldproc = str), + fieldproc = fieldproc + ) + +def serialize(x, obscure_password = False): + """ + Return a Postgres IRI from a dictionary object. + """ + return ri.unsplit(construct(x, obscure_password = obscure_password)) + +if __name__ == '__main__': + import sys + for x in sys.argv[1:]: + print("{src} -> {parsed!r} -> {serial}".format( + src = x, + parsed = parse(x), + serial = serialize(parse(x)) + )) diff --git a/py_opengauss/lib/__init__.py b/py_opengauss/lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1307a3b9c82073ee23f42a1294fd1c7035ea0995 --- /dev/null +++ b/py_opengauss/lib/__init__.py @@ -0,0 +1,483 @@ +## +# .lib - libraries; manage SQL outside of Python. +## +""" +PostgreSQL statement and object libraries. + +The purpose of a library is provide a means to manage a mapping of symbols +to database operations or objects. These operations can be simple statements, +procedures, or something more complex. 
+ +Libraries are intended to allow the programmer to isolate and manage SQL outside +of a system's code-flow. It provides a means to construct the basic Python +interfaces to a PostgreSQL-based application. +""" +import io +import os.path +from types import ModuleType +from abc import abstractmethod, abstractproperty +from ..python.element import Element, ElementSet +from .. import api as pg_api +from .. import sys as pg_sys +from .. import exceptions as pg_exc +from ..python.itertools import find +from itertools import chain + +try: + libdir = os.path.abspath(os.path.dirname(__file__)) +except NameError: + pass +else: + if os.path.exists(libdir): + pg_sys.libpath.insert(0, libdir) + del libdir + +__all__ = [ + 'Library', + 'SymbolCollection', + 'ILF', + 'Symbol', + 'Binding', + 'BoundSymbol', + 'find_libsql', + 'load', +] + +class Symbol(Element): + """ + An annotated SQL statement string. + + The annotations describe how the statement should be used. + """ + __slots__ = ( + 'library', + 'source', + 'name', + 'method', + 'type', + 'parameters', + ) + _e_label = 'SYMBOL' + _e_factors = ('library', 'source',) + + # The statement execution methods; symbols allow this to be specified + # in order for a default method to be selected. + execution_methods = { + 'first', + 'rows', + 'chunks', + 'declare', + 'load_chunks', + 'load_rows', + 'column', + } + + def _e_metas(self): + yield (None, self.name) + + def __init__(self, + library, source, + name = None, + method = None, + type = None, + parameters = None, + reference = False, + ): + self.library = library + self.source = source + self.name = name + if method in (None, '', 'all'): + method = None + elif method not in self.execution_methods: + raise ValueError("unknown execution method: " + repr(method)) + self.method = method + self.type = type + self.parameters = parameters + self.reference = reference + + def __str__(self): + """ + Provide the source of the query's symbol. + """ + # Explicitly run str() on source as it is expected that a + # given symbol's source may be generated. + return str(self.source) + +class Library(Element): + """ + A library is mapping of symbol names to `postgresql.lib.Symbol` instances. + """ + _e_label = 'LIBRARY' + _e_factors = () + + @abstractproperty + def address(self) -> str: + """ + A string indicating the source of the symbols. + """ + + @abstractproperty + def name(self) -> str: + """ + The name to bind the library as. Should be an identifier. + """ + + @abstractproperty + def preload(self) -> {str,}: + """ + A set of symbols that should prepared when the library is bound. + """ + + @abstractmethod + def symbols(self) -> [str]: + """ + Iterable of symbol names provides by the library. + """ + + @abstractmethod + def get_symbol(self, name) -> (Symbol, [Symbol]): + """ + Return the symbol with the given name. + """ + +class SymbolCollection(Library): + """ + Explicitly composed library. (Symbols passed into __init__) + """ + preload = None + symtypes = ( + 'static', + 'preload', + 'const', + 'proc', + 'transient', + ) + + def __init__(self, symbols, preface = None): + """ + Given an iterable of (symtype, symexe, doc, sql) tuples, create a + symbol collection. 
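+
+		(As consumed below, each item is a pair of the form
+		(name, (isref, symtype, symexe, doc, sql)).)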
+ """ + self.preface = preface + self._address = None + self._name = None + s = self.symbolsd = {} + self.preload = set() + for name, (isref, typ, exe, doc, query) in symbols: + if typ and typ not in self.symtypes: + raise ValueError( + "symbol %r has an invalid type: %r" %(name, typ) + ) + if typ == 'preload': + self.preload.add(name) + typ = None + elif typ == 'proc': + pass + SYM = Symbol(self, query, + name = name, + method = exe, + type = typ, + reference = isref + ) + s[name] = SYM + +class ILF(SymbolCollection): + 'INI Library Format' + def _e_metas(self): + yield (None, self._address or 'ILF') + + def __repr__(self): + return self.__class__.__module__ + '.' + self.__class__.__name__ + '.open(' + repr(self.address) + ')' + + @property + def name(self): + return self._name + + @property + def address(self): + return self._address + + def get_symbol(self, name): + return self.symbolsd.get(name) + + def symbols(self): + return self.symbolsd.keys() + + @classmethod + def from_lines(typ, lines): + """ + Create an anonymous ILF library from a sequence of lines. + """ + prev = '' + curid = None + curblock = [] + blocks = [] + for line in lines: + l = line.strip() + if l.startswith('[') and l.endswith(']'): + blocks.append((curid, curblock)) + curid = line + curblock = [] + elif line.startswith('*[') and ']' in line: + ref, rest = line.split(']', 1) + # strip the leading '*[' + ref = ref[2:] + # dereferencing will take place later. + curblock.append((ref, rest)) + else: + curblock.append(line) + blocks.append((curid, curblock)) + preface = ''.join(blocks.pop(0)[1]) + syms = [] + for symdesc, block in blocks: + # symbol name + # symbol type + # how to execute symbol + name, styp, exe, *_ = (tuple( + symdesc.strip().strip('[]').split(':') + ) + (None, None)) + doc = '' + endofcomment = 0 + # resolve any symbol references; only one per line. + block = [ + x if x.__class__ is not tuple else ( + find(reversed(syms), lambda y: y[0] == x[0])[1][-1] + x[1] + ) + for x in block + ] + for x in block: + if x.startswith('-- '): + doc += x[3:] + else: + break + endofcomment += 1 + query = ''.join(block[endofcomment:]) + if styp == 'proc': + query = query.strip() + if name.startswith('&'): + name = name[1:] + isref = True + else: + isref = False + syms.append((name, (isref, styp, exe, doc, query))) + return typ(syms, preface = preface) + + @classmethod + def open(typ, filepath, *args, **kw): + """ + Create a named ILF library from a file path. + """ + with io.open(filepath, *args, **kw) as fp: + r = typ.from_lines(fp) + r._address = os.path.abspath(filepath) + bn = os.path.basename(filepath) + if bn.startswith('lib') and bn.endswith('.sql'): + r._name = bn[3:-4] or None + return r + +class BoundSymbol(object): + """ + A symbol bound to a database(connection). + """ + def __init__(self, symbol, database): + if symbol.type == 'proc': + proc = database.proc(symbol) + self.method = proc.__call__ + self.object = proc + else: + ps = database.prepare(symbol) + m = symbol.method + if m is None: + self.method = ps.__call__ + else: + self.method = getattr(ps, m) + self.object = ps + + def __call__(self, *args, **kw): + return self.method(*args, **kw) + +class BoundReference(object): + """ + A symbol bound to a database whose results make up the source of a symbol + that will be created upon the execution of this symbol. + + A reference to a symbol. 
+ """ + + def __init__(self, symbol, database): + self.symbol = symbol + self.database = database + self.method = database.prepare(symbol).chunks + + def __call__(self, *args, **kw): + chunks = chain.from_iterable(self.method(*args, **kw)) + # Join columns with a space, and rows with a newline. + src = '\n'.join([' '.join(row) for row in chunks]) + return BoundSymbol( + Symbol( + self.symbol.library, src, + name = self.symbol.name, + method = self.symbol.method, + type = self.symbol.type, + parameters = self.symbol.parameters, + reference = False, + ), + self.database, + ) + +class Binding(object): + """ + Library bound to a database(connection). + """ + def __init__(self, database, library): + self.__dict__.update({ + '__database__' : database, + '__symbol_library__' : library, + '__symbol_cache__' : {}, + }) + for x in library.preload: + # cache all preloaded symbols. + getattr(self, x) + + def __repr__(self): + return '' %( + self.__symbol_library__.name, + self.__database__ + ) + + def __dir__(self): + return dir(super()) + list(self.__symbol_library__.symbols()) + + def __getattr__(self, name): + """ + Return a BoundSymbol against the Binding's database with the + symbol named ``name`` in the Binding's library. + """ + d = self.__dict__ + s = d['__symbol_cache__'] + db = d['__database__'] + lib = d['__symbol_library__'] + + bs = s.get(name) + if bs is None: + # No symbol cached with that name. + # Everything is crammed in here because + # we do *not* want methods on this object. + # The namespace is primarily reserved for symbols. + sym = lib.get_symbol(name) + if sym is None: + raise AttributeError( + "symbol %r does not exist in library %r" %( + name, lib.address + ) + ) + if sym.reference: + # Reference. + bs = BoundReference(sym, db) + if sym.type == 'const': + # Constant Reference means a BoundSymbol. + bs = bs() + if sym.type != 'transient': + s[name] = bs + else: + if not isinstance(sym, Symbol): + # subjective symbol... + sym = sym(db) + if not isinstance(sym, Symbol): + raise TypeError( + "callable symbol, %r, did not produce " \ + "Symbol instance" %(name,) + ) + if sym.type == 'const': + r = BoundSymbol(sym, db)() + if sym.method in ('chunks', 'rows', 'column'): + # resolve the iterator + r = list(r) + bs = s[name] = r + else: + bs = BoundSymbol(sym, db) + if sym.type != 'transient': + s[name] = bs + return bs + +class Category(pg_api.Category): + """ + Library-based Category. + """ + _e_factors = ('libraries',) + def _e_metas(self): + yield ('aliases', {k.name: v for k, v in self.aliases.items()}) + + def __init__(self, *libs, **named_libs): + sl = set(libs) + nl = set(named_libs.values()) + self._direct = sl + self.libraries = ElementSet(sl | nl) + self.aliases = {} + # lib -> [alias-1, alias-2, ..., alias-n] + for k, v in named_libs.items(): + d = self.aliases.setdefault(v, []) + d.append(k) + + def __call__(self, database): + for l in self.libraries: + names = list(self.aliases.get(l, ())) + if l in self._direct: + names.append(l.name) + B = Binding(database, l) + for n in names: + if hasattr(database, n): + raise AttributeError("attribute already exists: " + name) + setattr(database, n, B) + +def find_libsql(libname, paths, prefix = 'lib', suffix = '.sql'): + """ + Given the base library name, `libname`, look for a file named + "" in each directory(`paths`). + All finds will be yielded out. 
+ """ + lib = prefix + libname + suffix + for p in paths: + p = os.path.join(p, lib) + if os.path.exists(p): + yield p + +def load(libref): + """ + Given a reference to a symbol library, instantiate the Library instance. + + Currently this function accepts: + + * `str` objects as absolute paths or relative to sys.libpath. + * Module objects. + """ + if isinstance(libref, ModuleType): + if hasattr(libref, '__lib'): + lib = getattr(libref, '__lib') + else: + lib = ModuleLibrary(libref) + setattr(libref, '__lib', lib) + elif isinstance(libref, str): + try: + if os.path.sep in libref: + # sep in libref? it's treated as a path. + lib = ILF.open(libref) + else: + # first one wins. + for x in find_libsql(libref, pg_sys.libpath): + break + else: + raise pg_exc.LoadError("library %r not in postgresql.sys.libpath" % (libref,)) + lib = ILF.open(x) + except pg_exc.LoadError: + raise + except Exception: + # any exception is a load error. + raise pg_exc.LoadError("failed load ILF, " + repr(libref)) + else: + raise TypeError("load takes a module or str, given " + type(libref).__name__) + return lib + +sys = load('sys') + +__docformat__ = 'reStructuredText' diff --git a/py_opengauss/lib/libsys.sql b/py_opengauss/lib/libsys.sql new file mode 100644 index 0000000000000000000000000000000000000000..e0bcc0b2850867f89121dc671f015521141429e4 --- /dev/null +++ b/py_opengauss/lib/libsys.sql @@ -0,0 +1,341 @@ +## +# libsys.sql - SQL to support driver features +## +-- Queries for dealing with the PostgreSQL catalogs for supporting the driver. + +[lookup_type::first] +SELECT + ns.nspname as namespace, + bt.typname, + bt.typtype, + bt.typlen, + bt.typelem, + bt.typrelid, + ae.oid AS ae_typid, + ae.typreceive::oid != 0 AS ae_hasbin_input, + ae.typsend::oid != 0 AS ae_hasbin_output +FROM pg_catalog.pg_type bt + LEFT JOIN pg_type ae + ON ( + bt.typlen = -1 AND + bt.typelem != 0 AND + bt.typelem = ae.oid + ) + LEFT JOIN pg_catalog.pg_namespace ns + ON (ns.oid = bt.typnamespace) +WHERE bt.oid = $1 + +[lookup_composite] +-- Get the type Oid and name of the attributes in `attnum` order. 
+SELECT + CAST(atttypid AS oid) AS atttypid, + CAST(attname AS text) AS attname, + tt.typtype = 'd' AS is_domain +FROM + pg_catalog.pg_type t LEFT JOIN pg_catalog.pg_attribute a + ON (t.typrelid = a.attrelid) + LEFT JOIN pg_type tt ON (a.atttypid = tt.oid) +WHERE + attrelid = $1 AND NOT attisdropped AND attnum > 0 +ORDER BY attnum ASC + +[lookup_basetype_recursive] +SELECT + (CASE WHEN tt.typtype = 'd' THEN + (WITH RECURSIVE typehierarchy(typid, depth) AS ( + SELECT + t2.typbasetype, + 0 + FROM + pg_type t2 + WHERE + t2.oid = tt.oid + UNION ALL + SELECT + t2.typbasetype, + th.depth + 1 + FROM + pg_type t2, + typehierarchy th + WHERE + th.typid = t2.oid + AND t2.typbasetype != 0 + ) SELECT typid FROM typehierarchy ORDER BY depth DESC LIMIT 1) + + ELSE NULL + END) AS basetypid +FROM + pg_catalog.pg_type tt +WHERE + tt.oid = $1 + +[lookup_basetype] +SELECT + tt.typbasetype +FROM + pg_catalog.pg_type tt +WHERE + tt.oid = $1 + +[lookup_procedures] +SELECT + pg_proc.oid, + pg_proc.*, + pg_proc.oid::regproc AS _proid, + pg_proc.oid::regprocedure as procedure_id, + COALESCE(string_to_array(trim(replace(textin(oidvectorout(proargtypes)), ',', ' '), '{}'), ' ')::oid[], '{}'::oid[]) + AS proargtypes, + (pg_type.oid = 'record'::regtype or pg_type.typtype = 'c') AS composite +FROM + pg_catalog.pg_proc LEFT JOIN pg_catalog.pg_type ON ( + pg_proc.prorettype = pg_type.oid + ) + +[lookup_procedure_oid::first] +*[lookup_procedures] + WHERE pg_proc.oid = $1 + +[lookup_procedure_rp::first] +*[lookup_procedures] + WHERE pg_proc.oid = regprocedurein($1) + +[lookup_prepared_xacts::first] +SELECT + COALESCE(ARRAY( + SELECT + gid::text + FROM + pg_catalog.pg_prepared_xacts + WHERE + database = current_database() + AND ( + owner = $1::text + OR ( + (SELECT rolsuper FROM pg_roles WHERE rolname = $1::text) + ) + ) + ORDER BY prepared ASC + ), ('{}'::text[])) + +[regtypes::column] +SELECT pg_catalog.regtypein(pg_catalog.textout(($1::text[])[i]))::oid AS typoid +FROM pg_catalog.generate_series(1, array_upper($1::text[], 1)) AS g(i) + +[xact_is_prepared::first] +SELECT TRUE FROM pg_catalog.pg_prepared_xacts WHERE gid::text = $1 + +[get_statement_source::first] +SELECT statement FROM pg_catalog.pg_prepared_statements WHERE name = $1 + +[setting_get] +SELECT setting FROM pg_catalog.pg_settings WHERE name = $1 + +[setting_set::first] +SELECT pg_catalog.set_config($1, $2, false) + +[setting_len::first] +SELECT count(*) FROM pg_catalog.pg_settings + +[setting_item] +SELECT name, setting FROM pg_catalog.pg_settings WHERE name = $1 + +[setting_mget] +SELECT name, setting FROM pg_catalog.pg_settings WHERE name = ANY ($1) + +[setting_keys] +SELECT name FROM pg_catalog.pg_settings ORDER BY name + +[setting_values] +SELECT setting FROM pg_catalog.pg_settings ORDER BY name + +[setting_items] +SELECT name, setting FROM pg_catalog.pg_settings ORDER BY name + +[setting_update] +SELECT + ($1::text[][])[i][1] AS key, + pg_catalog.set_config(($1::text[][])[i][1], $1[i][2], false) AS value +FROM + pg_catalog.generate_series(1, array_upper(($1::text[][]), 1)) g(i) + +[startup_data:transient:first] +-- 8.2 and greater +SELECT + pg_catalog.version()::text AS version, + backend_start::text, + client_addr::text, + client_port::int +FROM pg_catalog.pg_stat_activity WHERE procpid = pg_catalog.pg_backend_pid() +UNION ALL SELECT + pg_catalog.version()::text AS version, + NULL::text AS backend_start, + NULL::text AS client_addr, + NULL::int AS client_port +LIMIT 1; + +[startup_data_92:transient:first] +-- 9.2 and greater +SELECT + 
pg_catalog.version()::text AS version, + backend_start::text, + client_addr::text, + client_port::int +FROM pg_catalog.pg_stat_activity WHERE pid = pg_catalog.pg_backend_pid() +UNION ALL SELECT + pg_catalog.version()::text AS version, + NULL::text AS backend_start, + NULL::text AS client_addr, + NULL::int AS client_port +LIMIT 1; + +[startup_data_no_start:transient:first] +-- 8.1 only, but is unused as often the backend's activity row is not +-- immediately present. +SELECT + pg_catalog.version()::text AS version, + NULL::text AS backend_start, + client_addr::text, + client_port::int +FROM pg_catalog.pg_stat_activity WHERE procpid = pg_catalog.pg_backend_pid(); + +[startup_data_only_version:transient:first] +-- In 8.0, there's nothing there. +SELECT + pg_catalog.version()::text AS version, + NULL::text AS backend_start, + NULL::text AS client_addr, + NULL::int AS client_port; + +[terminate_backends:transient:column] +-- Terminate all except mine. +SELECT + procpid, pg_catalog.pg_terminate_backend(procpid) +FROM + pg_catalog.pg_stat_activity +WHERE + procpid != pg_catalog.pg_backend_pid() + +[terminate_backends_92:transient:column] +-- Terminate all except mine. 9.2 and later +SELECT + pid, pg_catalog.pg_terminate_backend(pid) +FROM + pg_catalog.pg_stat_activity +WHERE + pid != pg_catalog.pg_backend_pid() + +[cancel_backends:transient:column] +-- Cancel all except mine. +SELECT + procpid, pg_catalog.pg_cancel_backend(procpid) +FROM + pg_catalog.pg_stat_activity +WHERE + procpid != pg_catalog.pg_backend_pid() + +[cancel_backends_92:transient:column] +-- Cancel all except mine. 9.2 and later +SELECT + pid, pg_catalog.pg_cancel_backend(pid) +FROM + pg_catalog.pg_stat_activity +WHERE + pid != pg_catalog.pg_backend_pid() + +[sizeof_db:transient:first] +SELECT pg_catalog.pg_database_size(current_database())::bigint + +[sizeof_cluster:transient:first] +SELECT SUM(pg_catalog.pg_database_size(datname))::bigint FROM pg_database + +[sizeof_relation::first] +SELECT pg_catalog.pg_relation_size($1::text)::bigint + +[pg_reload_conf:transient:] +SELECT pg_reload_conf() + +[languages:transient:column] +SELECT lanname FROM pg_catalog.pg_language + +[listening_channels:transient:column] +SELECT channel FROM pg_catalog.pg_listening_channels() AS x(channel) + +[listening_relations:transient:column] +-- listening_relations: old version of listening_channels. 
+SELECT relname as channel FROM pg_catalog.pg_listener +WHERE listenerpid = pg_catalog.pg_backend_pid(); + +[notify::first] +-- 9.0 and greater +SELECT + COUNT(pg_catalog.pg_notify(($1::text[])[i][1], $1[i][2]) IS NULL) +FROM + pg_catalog.generate_series(1, array_upper($1, 1)) AS g(i) + +[release_advisory_shared] +SELECT + CASE WHEN ($2::int8[])[i] IS NULL + THEN + pg_catalog.pg_advisory_unlock_shared(($1::int4[])[i][1], $1[i][2]) + ELSE + pg_catalog.pg_advisory_unlock_shared($2[i]) + END AS released +FROM + pg_catalog.generate_series(1, COALESCE(array_upper($2::int8[], 1), array_upper($1::int4[], 1))) AS g(i) + +[acquire_advisory_shared] +SELECT COUNT(( + CASE WHEN ($2::int8[])[i] IS NULL + THEN + pg_catalog.pg_advisory_lock_shared(($1::int4[])[i][1], $1[i][2]) + ELSE + pg_catalog.pg_advisory_lock_shared($2[i]) + END +) IS NULL) AS acquired +FROM + pg_catalog.generate_series(1, COALESCE(array_upper($2::int8[], 1), array_upper($1::int4[], 1))) AS g(i) + +[try_advisory_shared] +SELECT + CASE WHEN ($2::int8[])[i] IS NULL + THEN + pg_catalog.pg_try_advisory_lock_shared(($1::int4[])[i][1], $1[i][2]) + ELSE + pg_catalog.pg_try_advisory_lock_shared($2[i]) + END AS acquired +FROM + pg_catalog.generate_series(1, COALESCE(array_upper($2::int8[], 1), array_upper($1::int4[], 1))) AS g(i) + +[release_advisory_exclusive] +SELECT + CASE WHEN ($2::int8[])[i] IS NULL + THEN + pg_catalog.pg_advisory_unlock(($1::int4[])[i][1], $1[i][2]) + ELSE + pg_catalog.pg_advisory_unlock($2[i]) + END AS released +FROM + pg_catalog.generate_series(1, COALESCE(array_upper($2::int8[], 1), array_upper($1::int4[], 1))) AS g(i) + +[acquire_advisory_exclusive] +SELECT COUNT(( + CASE WHEN ($2::int8[])[i] IS NULL + THEN + pg_catalog.pg_advisory_lock(($1::int4[])[i][1], $1[i][2]) + ELSE + pg_catalog.pg_advisory_lock($2[i]) + END +) IS NULL) AS acquired -- Guaranteed to be acquired once complete. +FROM + pg_catalog.generate_series(1, COALESCE(array_upper($2::int8[], 1), array_upper($1::int4[], 1))) AS g(i) + +[try_advisory_exclusive] +SELECT + CASE WHEN ($2::int8[])[i] IS NULL + THEN + pg_catalog.pg_try_advisory_lock(($1::int4[])[i][1], $1[i][2]) + ELSE + pg_catalog.pg_try_advisory_lock($2[i]) + END AS acquired +FROM + pg_catalog.generate_series(1, COALESCE(array_upper($2::int8[], 1), array_upper($1::int4[], 1))) AS g(i) diff --git a/py_opengauss/message.py b/py_opengauss/message.py new file mode 100644 index 0000000000000000000000000000000000000000..7bbe77d9ba2d736564c116dca1fbe585fa8f34ff --- /dev/null +++ b/py_opengauss/message.py @@ -0,0 +1,144 @@ +## +# .message - PostgreSQL message representation +## +from operator import itemgetter +from .python.element import prime_factor +# Final msghook called exists at .sys.msghook +from . import sys as pg_sys + +from .api import Message +class Message(Message): + """ + A message emitted by PostgreSQL. This element is universal, so + `postgresql.api.Message` is a complete implementation for representing a + message. Any interface should produce these objects. 
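+
+	Instances carry the `message` text, the `code`, a `details` mapping, the
+	`source` ('SERVER' or 'CLIENT'), and the `creator` that emitted the
+	message.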
+ """ + _e_label = property(lambda x: getattr(x, 'details').get('severity', 'MESSAGE')) + _e_factors = ('creator',) + + def _e_metas(self, get0 = itemgetter(0)): + yield (None, self.message) + if self.code and self.code != "00000": + yield ('CODE', self.code) + locstr = self.location_string + if locstr: + yield ('LOCATION', locstr + ' from ' + self.source) + else: + yield ('LOCATION', self.source) + for k, v in sorted(self.details.items(), key = get0): + if k not in self.standard_detail_coverage: + yield (k.upper(), str(v)) + + source = 'SERVER' + code = '00000' + message = None + details = None + + severities = ( + 'DEBUG', + 'INFO', + 'NOTICE', + 'WARNING', + 'ERROR', + 'FATAL', + 'PANIC', + ) + sources = ( + 'SERVER', + 'CLIENT', + ) + + def isconsistent(self, other): + """ + Return `True` if the all the fields of the message in `self` are + equivalent to the fields in `other`. + """ + if not isinstance(other, self.__class__): + return False + # creator is contextual information + return ( + self.code == other.code and \ + self.message == other.message and \ + self.details == other.details and \ + self.source == other.source + ) + + def __init__(self, + message, + code = None, + details = {}, + source = None, + creator = None, + ): + self.message = message + self.details = details + self.creator = creator + if code is not None and self.code != code: + self.code = code + if source is not None and self.source != source: + self.source = source + + def __repr__(self): + return "{mod}.{typname}({message!r}{code}{details}{source}{creator})".format( + mod = self.__module__, + typname = self.__class__.__name__, + message = self.message, + code = ( + "" if self.code == type(self).code + else ", code = " + repr(self.code) + ), + details = ( + "" if not self.details + else ", details = " + repr(self.details) + ), + source = ( + "" if self.source is None + else ", source = " + repr(self.source) + ), + creator = ( + "" if self.creator is None + else ", creator = " + repr(self.creator) + ) + ) + + @property + def location_string(self): + """ + A single line representation of the 'file', 'line', and 'function' keys + in the `details` dictionary. + """ + details = self.details + loc = [ + details.get(k, '?') for k in ('file', 'line', 'function') + ] + return ( + "" if loc == ['?', '?', '?'] + else "File {0!r}, "\ + "line {1!s}, in {2!s}".format(*loc) + ) + + # keys to filter in .details + standard_detail_coverage = frozenset(['message', 'severity', 'file', 'function', 'line',]) + + def emit(self, starting_point = None): + """ + Take the given message object and hand it to all the primary + factors(creator) with a msghook callable. + """ + if starting_point is not None: + f = starting_point + else: + f = self.creator + + while f is not None: + if getattr(f, 'msghook', None) is not None: + if f.msghook(self): + # the trap returned a nonzero value, + # so don't continue raising. (like with's __exit__) + return f + f = prime_factor(f) + if f: + f = f[1] + # if the next primary factor is without a raise or does not exist, + # send the message to postgresql.sys.msghook + pg_sys.msghook(self) diff --git a/py_opengauss/notifyman.py b/py_opengauss/notifyman.py new file mode 100644 index 0000000000000000000000000000000000000000..d72d142126686ea03dca21c49b448cc5ca2eab39 --- /dev/null +++ b/py_opengauss/notifyman.py @@ -0,0 +1,227 @@ +## +# .notifyman - Receive and manage NOTIFY events. 
+## +""" +Notification Management Tools + +Primarily this module houses the `NotificationManager` class which provides an +iterator for a NOTIFY event loop against a set of connections. + + >>> import py_opengauss + >>> db = py_opengauss.open(...) + >>> from py_opengauss.notifyman import NotificationManager + >>> nm = NotificationManager(db, timeout = 10) # idle events every 10 seconds + >>> for x in nm: + ... if x is None: + ... # idle event + ... ... + ... db, notifies = x + ... for channel, payload, pid in notifies: + ... ... +""" +from time import time +from select import select +from itertools import chain + +class NotificationManager(object): + """ + A class for managing the asynchronous notifications received by a + set of connections. + + Instances provide the iterator for an event loop that responds to NOTIFYs + received by the connections being watched. There is no thread safety, so + when a connection is being managed, it should not be used concurrently in + other threads while being managed. + """ + __slots__ = ( + 'connections', + 'garbage', + 'incoming', + 'timeout', + '_last_time', + '_pulled', + ) + + def __init__(self, *connections, timeout = None): + self.settimeout(timeout) + self.connections = set(connections) + # Connections that failed. + self.garbage = set() + # Used to store NOTIFYs consumed from the connections + self.incoming = None + self._last_time = None + # connection -> sequence of NOTIFYs + self._pulled = dict() + + # Check the wire *and* wait for new messages. + def _wait_on_wires(self, time = time, select = select): + if self.timeout == 0: + # We're polling. + max_duration = 0 + else: + # If timeout is None, we don't issue idle events, but + # we still cycle in case the timeout is changed. + if self._last_time is not None: + max_duration = (self.timeout or 10) - (time() - self._last_time) + if max_duration < 0: + max_duration = 0 + else: + self._last_time = time() + max_duration = self.timeout or 10 + + # Connections already marked as "bad" should not be checked. + check = self.connections - self.garbage + for db in check: + if db.closed: + self.connections.remove(db) + self.garbage.add(db) + check = self.connections - self.garbage + + r, w, x = select(check, (), check, max_duration) + # Make sure the connection's _notifies get filled. + for db in r: + # Collect any pending events. + try: + # Even if db is in a failed transaction, this + # 'null' command will succeed. + # (only connection failures blow up) + db.execute('') + except Exception: + # failed to collect notifies; put in exception list. + # It is very unlikely that this is *not* a FATAL error. + x.append(db) + self.trash(x) + + def trash(self, connections): + """ + Remove the given connections from the set of good connections, and add + them to the `garbage` set. + + This method can be overridden by subclasses to take a callback approach + to connection failures. + """ + # Identify the bad connections. + self.garbage.update(connections) + self.connections.difference_update(connections) + + def queue(self, db, notifies): + """ + Queue the notifies for the specified connection. + + This method can be overridden by subclasses to take a callback approach + to notification management. + """ + l = self._pulled.setdefault(db, list()) + l.extend(notifies) + + # Check the connection's _notifies list; just scan everything. 
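+	# Each NOTIFY is queued as a decoded (channel, payload, pid) triple.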
+ def _pull_from_connections(self): + for db in self.connections: + if not db._notifies: + # nothing queued up, look at the next connection + continue + # Pull notifies into the NotificationManager + decode = db.typio.decode + notifies = [ + (decode(x.channel), decode(x.payload), x.pid) + for x in db._notifies + ] + self.queue(db, notifies) + del db._notifies[:len(notifies)] + + # "Append" the pulled NOTIFYs to the 'incoming' iterator. + def _queue_next(self): + new_seqs = [] + for db in self._pulled: + decode = db.typio.decode + new_seqs.append((db, self._pulled[db])) + + if new_seqs: + if self.incoming: + # Already have incoming; not an expected condition, + # but let's compensate. + self.incoming, self._pulled = chain(self.incoming, iter(new_seqs)), {} + else: + self.incoming, self._pulled = iter(new_seqs), {} + elif self.incoming is None: + # Use this to trigger the StopIteration case of zero-timeout. + self.incoming, self._pulled = iter(()), {} + + def _timedout(self, time = time): + # Idles are guaranteed to occur, but make sure that + # __next__ has a chance to check the connections and the wires. + now = time() + if self._last_time is None: + self._last_time = now + elif self.timeout and now >= (self._last_time + self.timeout): + # Set last_time to None in case the timeout is so low + # that this condition keeps NOTIFYs from being seen. + self._last_time = None + # Signal timeout. + return True + else: + # toggle back to None. + self._last_time = None + return False + + def settimeout(self, seconds): + """ + Set the maximum duration, in seconds, for waiting for NOTIFYs on the + set of managed connections. The given `seconds` argument can be a number + or `None`. + + A timeout of `None` means no timeout, and "idle" events will never + occur. + + A timeout of `0` means to never wait for NOTIFYs. This has the effect of + a StopIteration being raised by `__next__` when there are no more + Notifications available for any of the connections in the set. "Idle" + events will never occur in this situation as well. + + A timeout greater than zero means to emit `None` as "idle" events into + the loop at the specified interval. Idle events are guaranteed to occur. + """ + if seconds is not None and seconds < 0: + raise ValueError("cannot set timeout less than zero") + self.timeout = seconds + + def gettimeout(self): + """ + Get the timeout assigned by `settimeout`. + """ + return self.timeout + + def __iter__(self): + return self + + def __next__(self, time = time): + checked_wire = True + # Loop until NOTIFY received or timeout. + while True: + if self.incoming is not None: + try: + return next(self.incoming) + except StopIteration: + # Nothing more in this incoming. + self.incoming = None + # Allow a zero timeout to be used to indicate + # that there are no NOTIFYs to be read. + # This can be used to poll a set of + # connections instead of listening. + if self.timeout == 0 or not self.connections: + raise + + # timeout happened? yield the "idle" event. + # This check **must** happen after .incoming is checked. + # Never emit idle when there are real events. + if self._timedout(): + return None + + if not checked_wire and self.connections: + # Nothing queued up, check connections if any. 
+ self._wait_on_wires() + checked_wire = True + else: + checked_wire = False + self._pull_from_connections() + self._queue_next() diff --git a/py_opengauss/pgpassfile.py b/py_opengauss/pgpassfile.py new file mode 100644 index 0000000000000000000000000000000000000000..ee0ae73fda4d13ba9cebc382a5db54a55d30be5b --- /dev/null +++ b/py_opengauss/pgpassfile.py @@ -0,0 +1,70 @@ +## +# .pgpassfile - parse and lookup passwords in a pgpassfile +## +""" +Parse pgpass files and subsequently lookup a password. +""" +import os.path + +def split(line, len = len): + line = line.strip() + if not line: + return None + r = [] + continuation = False + for x in line.split(':'): + if continuation: + # The colon was preceded by a backslash, it's part + # of the last field. Substitute the trailing backslash + # with the colon and append the next value. + r[-1] = r[-1][:-1] + ':' + x.replace('\\\\', '\\') + continuation = False + else: + # Even number of backslashes preceded the split. + # Normal field. + r.append(x.replace('\\\\', '\\')) + # Determine if the next field is a continuation of this one. + if (len(x) - len(x.rstrip('\\'))) % 2 == 1: + continuation = True + if len(r) != 5: + # Too few or too many fields. + return None + return r + +def parse(data): + """ + Produce a list of [(word, (host,port,dbname,user))] from a pgpass file object. + """ + return [ + (x[-1], x[0:4]) for x in [split(line) for line in data] if x + ] + +def lookup_password(words, uhpd): + """ + lookup_password(words, (user, host, port, database)) -> password + + Where 'words' is the output from pgpass.parse() + """ + user, host, port, database = uhpd + for word, (w_host, w_port, w_database, w_user) in words: + if (w_user == '*' or w_user == user) and \ + (w_host == '*' or w_host == host) and \ + (w_port == '*' or w_port == port) and \ + (w_database == '*' or w_database == database): + return word + +def lookup_password_file(path, t): + """ + Like lookup_password, but takes a file path. + """ + with open(path) as f: + return lookup_password(parse(f), t) + +def lookup_pgpass(d, passfile, exists = os.path.exists): + # If the password file exists, lookup the password + # using the config's criteria. + if exists(passfile): + return lookup_password_file(passfile, ( + str(d['user']), str(d['host']), str(d['port']), + str(d.get('database', d['user'])) + )) diff --git a/py_opengauss/port/__init__.py b/py_opengauss/port/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9ffb22f055b610b8ad2052600ded0fc213916ab0 --- /dev/null +++ b/py_opengauss/port/__init__.py @@ -0,0 +1,12 @@ +## +# .port +## +""" +Platform specific modules. + +The subject of each module should be the feature and the target platform. +This is done to keep modules small and descriptive. + +These modules are for internal use only. +""" +__docformat__ = 'reStructuredText' diff --git a/py_opengauss/port/_optimized/README b/py_opengauss/port/_optimized/README new file mode 100644 index 0000000000000000000000000000000000000000..4374c88e4efddeb37e61a72bf622c4d94a59731c --- /dev/null +++ b/py_opengauss/port/_optimized/README @@ -0,0 +1 @@ +This is the C ports of the more performance critical parts of py-postgresql. 
diff --git a/py_opengauss/port/_optimized/buffer.c b/py_opengauss/port/_optimized/buffer.c new file mode 100644 index 0000000000000000000000000000000000000000..7f46d2d1b29e8ee242f400aacbef29d7c6febe29 --- /dev/null +++ b/py_opengauss/port/_optimized/buffer.c @@ -0,0 +1,626 @@ +/* + * .port.optimized.pq_message_buffer - PQ message stream + */ +/* + * PQ messages normally take the form {type, (size), data} + */ +#define include_buffer_types \ + mTYPE(pq_message_stream) + +struct p_list +{ + PyObject *data; /* PyBytes pushed onto the buffer */ + struct p_list *next; +}; + +struct p_place +{ + struct p_list *list; + uint32_t offset; +}; + +struct p_buffer +{ + PyObject_HEAD + + struct p_place position; + struct p_list *last; /* for quick appends */ +}; + +/* + * Free the list until the given stop + */ +static void +pl_truncate(struct p_list *pl, struct p_list *stop) +{ + while (pl != stop) + { + struct p_list *next = pl->next; + Py_DECREF(pl->data); + free(pl); + pl = next; + } +} + +/* + * Reset the buffer + */ +static void +pb_truncate(struct p_buffer *pb) +{ + struct p_list *pl = pb->position.list; + + pb->position.offset = 0; + pb->position.list = NULL; + pb->last = NULL; + + pl_truncate(pl, NULL); +} + +/* + * p_truncate - truncate the buffer + */ +static PyObject * +p_truncate(PyObject *self) +{ + pb_truncate((struct p_buffer *) self); + Py_INCREF(Py_None); + return(Py_None); +} + + +static void +p_dealloc(PyObject *self) +{ + struct p_buffer *pb = ((struct p_buffer *) self); + pb_truncate(pb); + self->ob_type->tp_free(self); +} + +static PyObject * +p_new(PyTypeObject *subtype, PyObject *args, PyObject *kw) +{ + static char *kwlist[] = {NULL}; + struct p_buffer *pb; + PyObject *rob; + + if (!PyArg_ParseTupleAndKeywords(args, kw, "", kwlist)) + return(NULL); + + rob = subtype->tp_alloc(subtype, 0); + pb = ((struct p_buffer *) rob); + pb->last = pb->position.list = NULL; + pb->position.offset = 0; + return(rob); +} + +/* + * p_at_least - whether the position has at least given number of bytes. + */ +static char +p_at_least(struct p_place *p, uint32_t amount) +{ + uint32_t current = 0; + struct p_list *pl; + + pl = p->list; + if (pl) + current += PyBytes_GET_SIZE(pl->data) - p->offset; + + if (current >= amount) + return((char) 1); + + if (pl) + { + for (pl = pl->next; pl != NULL; pl = pl->next) + { + current += PyBytes_GET_SIZE(pl->data); + if (current >= amount) + return((char) 1); + } + } + + return((char) 0); +} + +static uint32_t +p_seek(struct p_place *p, uint32_t amount) +{ + uint32_t amount_left = amount; + Py_ssize_t chunk_size; + + /* Can't seek after the end. */ + if (!p->list || p->offset == PyBytes_GET_SIZE(p->list->data)) + return(0); + + chunk_size = PyBytes_GET_SIZE(p->list->data) - p->offset; + + while (amount_left > 0) + { + /* + * The current list item has the position. + * Set the offset and break out. 
+ */ + if (amount_left < chunk_size) + { + p->offset += amount_left; + amount_left = 0; + break; + } + + amount_left -= chunk_size; + p->list = p->list->next; + p->offset = 0; + if (p->list == NULL) + break; + + chunk_size = PyBytes_GET_SIZE(p->list->data); + } + + return(amount - amount_left); +} + +static uint32_t +p_memcpy(char *dst, struct p_place *p, uint32_t amount) +{ + struct p_list *pl = p->list; + uint32_t offset = p->offset; + uint32_t amount_left = amount; + char *src; + Py_ssize_t chunk_size; + + /* Nothing to read */ + if (pl == NULL) + return(0); + + src = (PyBytes_AS_STRING(pl->data) + offset); + chunk_size = PyBytes_GET_SIZE(pl->data) - offset; + + while (amount_left > 0) + { + uint32_t this_read = + chunk_size < amount_left ? chunk_size : amount_left; + + memcpy(dst, src, this_read); + dst = dst + this_read; + amount_left = amount_left - this_read; + + pl = pl->next; + if (pl == NULL) + break; + + src = PyBytes_AS_STRING(pl->data); + chunk_size = PyBytes_GET_SIZE(pl->data); + } + + return(amount - amount_left); +} + +static Py_ssize_t +p_length(PyObject *self) +{ + char header[5]; + long msg_count = 0; + uint32_t msg_length; + uint32_t copy_amount = 0; + struct p_buffer *pb; + struct p_place p; + + pb = ((struct p_buffer *) self); + p.list = pb->position.list; + p.offset = pb->position.offset; + + while (p.list != NULL) + { + copy_amount = p_memcpy(header, &p, 5); + if (copy_amount < 5) + break; + p_seek(&p, copy_amount); + + memcpy(&msg_length, header + 1, 4); + msg_length = local_ntohl(msg_length); + if (msg_length < 4) + { + PyErr_Format(PyExc_ValueError, + "invalid message size '%d'", msg_length); + return(-1); + } + msg_length -= 4; + + if (p_seek(&p, msg_length) < msg_length) + break; + + ++msg_count; + } + + return(msg_count); +} + +static PySequenceMethods pq_ms_as_sequence = { + (lenfunc) p_length, 0 +}; + + +/* + * Build a tuple from the given place. + */ +static PyObject * +p_build_tuple(struct p_place *p) +{ + char header[5]; + uint32_t msg_length; + PyObject *tuple; + PyObject *mt, *md; + + char *body = NULL; + uint32_t copy_amount = 0; + + copy_amount = p_memcpy(header, p, 5); + if (copy_amount < 5) + return(NULL); + p_seek(p, copy_amount); + + memcpy(&msg_length, header + 1, 4); + msg_length = local_ntohl(msg_length); + if (msg_length < 4) + { + PyErr_Format(PyExc_ValueError, + "invalid message size '%d'", msg_length); + return(NULL); + } + msg_length -= 4; + + if (!p_at_least(p, msg_length)) + return(NULL); + + /* + * Copy out the message body if we need to. + */ + if (msg_length > 0) + { + body = malloc(msg_length); + if (body == NULL) + { + PyErr_SetString(PyExc_MemoryError, + "could not allocate memory for message data"); + return(NULL); + } + copy_amount = p_memcpy(body, p, msg_length); + + if (copy_amount != msg_length) + { + free(body); + return(NULL); + } + + p_seek(p, copy_amount); + } + + mt = PyTuple_GET_ITEM(message_types, (int) header[0]); + if (mt == NULL) + { + /* + * With message_types, this is nearly a can't happen. 
+ */ + if (body != NULL) free(body); + return(NULL); + } + Py_INCREF(mt); + + md = PyBytes_FromStringAndSize(body, (Py_ssize_t) msg_length); + if (body != NULL) + free(body); + if (md == NULL) + { + Py_DECREF(mt); + return(NULL); + } + + + tuple = PyTuple_New(2); + if (tuple == NULL) + { + Py_DECREF(mt); + Py_DECREF(md); + } + else + { + PyTuple_SET_ITEM(tuple, 0, mt); + PyTuple_SET_ITEM(tuple, 1, md); + } + + return(tuple); +} + +static PyObject * +p_write(PyObject *self, PyObject *data) +{ + struct p_buffer *pb; + + if (!PyBytes_Check(data)) + { + PyErr_SetString(PyExc_TypeError, + "pq buffer.write() method requires a bytes object"); + return(NULL); + } + pb = ((struct p_buffer *) self); + + if (PyBytes_GET_SIZE(data) > 0) + { + struct p_list *pl; + + pl = malloc(sizeof(struct p_list)); + if (pl == NULL) + { + PyErr_SetString(PyExc_MemoryError, + "could not allocate memory for pq message stream data"); + return(NULL); + } + + pl->data = data; + Py_INCREF(data); + pl->next = NULL; + + if (pb->last == NULL) + { + /* + * First and last. + */ + pb->position.list = pb->last = pl; + } + else + { + pb->last->next = pl; + pb->last = pl; + } + } + + Py_INCREF(Py_None); + return(Py_None); +} + +static PyObject * +p_next(PyObject *self) +{ + struct p_buffer *pb = ((struct p_buffer *) self); + struct p_place p; + PyObject *rob; + + p.offset = pb->position.offset; + p.list = pb->position.list; + + rob = p_build_tuple(&p); + if (rob != NULL) + { + pl_truncate(pb->position.list, p.list); + pb->position.list = p.list; + pb->position.offset = p.offset; + if (p.list == NULL) + pb->last = NULL; + } + return(rob); +} + +static PyObject * +p_read(PyObject *self, PyObject *args) +{ + int cur_msg, msg_count = -1, msg_in = 0; + struct p_place p; + struct p_buffer *pb; + PyObject *rob = NULL; + + if (!PyArg_ParseTuple(args, "|i", &msg_count)) + return(NULL); + + pb = (struct p_buffer *) self; + p.list = pb->position.list; + p.offset = pb->position.offset; + + msg_in = p_length(self); + msg_count = msg_count < msg_in && msg_count != -1 ? msg_count : msg_in; + + rob = PyTuple_New(msg_count); + for (cur_msg = 0; cur_msg < msg_count; ++cur_msg) + { + PyObject *msg_tup = NULL; + msg_tup = p_build_tuple(&p); + if (msg_tup == NULL) + { + if (PyErr_Occurred()) + { + Py_DECREF(rob); + return(NULL); + } + break; + } + + PyTuple_SET_ITEM(rob, cur_msg, msg_tup); + } + + pl_truncate(pb->position.list, p.list); + pb->position.list = p.list; + pb->position.offset = p.offset; + if (p.list == NULL) + pb->last = NULL; + + return(rob); +} + +static PyObject * +p_has_message(PyObject *self) +{ + char header[5]; + uint32_t msg_length; + uint32_t copy_amount = 0; + struct p_buffer *pb; + struct p_place p; + PyObject *rob; + + pb = ((struct p_buffer *) self); + p.list = pb->position.list; + p.offset = pb->position.offset; + + copy_amount = p_memcpy(header, &p, 5); + if (copy_amount < 5) + { + Py_INCREF(Py_False); + return(Py_False); + } + p_seek(&p, copy_amount); + memcpy(&msg_length, header + 1, 4); + + msg_length = local_ntohl(msg_length); + if (msg_length < 4) + { + PyErr_Format(PyExc_ValueError, + "invalid message size '%d'", msg_length); + return(NULL); + } + msg_length -= 4; + + rob = p_at_least(&p, msg_length) ? 
Py_True : Py_False; + Py_INCREF(rob); + return(rob); +} + +static PyObject * +p_next_message(PyObject *self) +{ + struct p_buffer *pb = ((struct p_buffer *) self); + struct p_place p; + PyObject *rob; + + p.offset = pb->position.offset; + p.list = pb->position.list; + + rob = p_build_tuple(&p); + if (rob == NULL) + { + if (!PyErr_Occurred()) + { + rob = Py_None; + Py_INCREF(rob); + } + } + else + { + pl_truncate(pb->position.list, p.list); + pb->position.list = p.list; + pb->position.offset = p.offset; + if (p.list == NULL) + pb->last = NULL; + } + + return(rob); +} + +/* + * p_getvalue - get the unconsumed data in the buffer + * + * Normally used in conjunction with truncate to transfer + * control of the wire to another state machine. + */ +static PyObject * +p_getvalue(PyObject *self) +{ + struct p_buffer *pb = ((struct p_buffer *) self); + struct p_list *l; + uint32_t initial_offset; + PyObject *rob; + + /* + * Don't include data from already read() messages. + */ + initial_offset = pb->position.offset; + + l = pb->position.list; + if (l == NULL) + { + /* + * Empty list. + */ + return(PyBytes_FromString("")); + } + + /* + * Get the first chunk. + */ + rob = PyBytes_FromStringAndSize( + (PyBytes_AS_STRING(l->data) + initial_offset), + PyBytes_GET_SIZE(l->data) - initial_offset + ); + if (rob == NULL) + return(NULL); + + l = l->next; + while (l != NULL) + { + PyBytes_Concat(&rob, l->data); + if (rob == NULL) + break; + + l = l->next; + } + + return(rob); +} + +static PyMethodDef p_methods[] = { + {"write", p_write, METH_O, + PyDoc_STR("write the string to the buffer"),}, + {"read", p_read, METH_VARARGS, + PyDoc_STR("read the number of messages from the buffer")}, + {"truncate", (PyCFunction) p_truncate, METH_NOARGS, + PyDoc_STR("remove the contents of the buffer"),}, + {"has_message", (PyCFunction) p_has_message, METH_NOARGS, + PyDoc_STR("whether the buffer has a message ready"),}, + {"next_message", (PyCFunction) p_next_message, METH_NOARGS, + PyDoc_STR("get and remove the next message--None if none."),}, + {"getvalue", (PyCFunction) p_getvalue, METH_NOARGS, + PyDoc_STR("get the unprocessed data in the buffer")}, + {NULL} +}; + +PyTypeObject pq_message_stream_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "postgresql.port.optimized.pq_message_stream", /* tp_name */ + sizeof(struct p_buffer), /* tp_basicsize */ + 0, /* tp_itemsize */ + p_dealloc, /* tp_dealloc */ + NULL, /* tp_print */ + NULL, /* tp_getattr */ + NULL, /* tp_setattr */ + NULL, /* tp_compare */ + NULL, /* tp_repr */ + NULL, /* tp_as_number */ + &pq_ms_as_sequence, /* tp_as_sequence */ + NULL, /* tp_as_mapping */ + NULL, /* tp_hash */ + NULL, /* tp_call */ + NULL, /* tp_str */ + NULL, /* tp_getattro */ + NULL, /* tp_setattro */ + NULL, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + PyDoc_STR( + "Buffer data on write, return messages on read" + ), /* tp_doc */ + NULL, /* tp_traverse */ + NULL, /* tp_clear */ + NULL, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + NULL, /* tp_iter */ + p_next, /* tp_iternext */ + p_methods, /* tp_methods */ + NULL, /* tp_members */ + NULL, /* tp_getset */ + NULL, /* tp_base */ + NULL, /* tp_dict */ + NULL, /* tp_descr_get */ + NULL, /* tp_descr_set */ + 0, /* tp_dictoffset */ + NULL, /* tp_init */ + NULL, /* tp_alloc */ + p_new, /* tp_new */ + NULL, /* tp_free */ +}; diff --git a/py_opengauss/port/_optimized/element3.c b/py_opengauss/port/_optimized/element3.c new file mode 100644 index 0000000000000000000000000000000000000000..7ea2b82739b9fae3d4b99f0f013dc35dd9e2d237 --- 
/dev/null +++ b/py_opengauss/port/_optimized/element3.c @@ -0,0 +1,690 @@ +/* + * .port.optimized - .protocol.element3 optimizations + */ +#define include_element3_functions \ + mFUNC(cat_messages, METH_O, "cat the serialized form of the messages in the given list") \ + mFUNC(parse_tuple_message, METH_O, "parse the given tuple data into a tuple of raw data") \ + mFUNC(pack_tuple_data, METH_O, "serialize the give tuple message[tuple of bytes()]") \ + mFUNC(consume_tuple_messages, METH_O, "create a list of parsed tuple data tuples") \ + +/* + * Given a tuple of bytes and None objects, join them into a + * a single bytes object with sizes. + */ +static PyObject * +_pack_tuple_data(PyObject *tup) +{ + PyObject *rob; + Py_ssize_t natts; + Py_ssize_t catt; + + char *buf = NULL; + char *bufpos = NULL; + Py_ssize_t bufsize = 0; + + if (!PyTuple_Check(tup)) + { + PyErr_Format( + PyExc_TypeError, + "pack_tuple_data requires a tuple, given %s", + PyObject_TypeName(tup) + ); + return(NULL); + } + natts = PyTuple_GET_SIZE(tup); + if (natts == 0) + return(PyBytes_FromString("")); + + /* discover buffer size and valid att types */ + for (catt = 0; catt < natts; ++catt) + { + PyObject *ob; + ob = PyTuple_GET_ITEM(tup, catt); + + if (ob == Py_None) + { + bufsize = bufsize + 4; + } + else if (PyBytes_CheckExact(ob)) + { + bufsize = bufsize + PyBytes_GET_SIZE(ob) + 4; + } + else + { + PyErr_Format( + PyExc_TypeError, + "cannot serialize attribute %d, expected bytes() or None, got %s", + (int) catt, PyObject_TypeName(ob) + ); + return(NULL); + } + } + + buf = malloc(bufsize); + if (buf == NULL) + { + PyErr_Format( + PyExc_MemoryError, + "failed to allocate %d bytes of memory for packing tuple data", + bufsize + ); + return(NULL); + } + bufpos = buf; + + for (catt = 0; catt < natts; ++catt) + { + PyObject *ob; + ob = PyTuple_GET_ITEM(tup, catt); + if (ob == Py_None) + { + uint32_t attsize = 0xFFFFFFFFL; /* Indicates NULL */ + Py_MEMCPY(bufpos, &attsize, 4); + bufpos = bufpos + 4; + } + else + { + Py_ssize_t size = PyBytes_GET_SIZE(ob); + uint32_t msg_size; + if (size > 0xFFFFFFFE) + { + PyErr_Format(PyExc_OverflowError, + "data size of %d is greater than attribute capacity", + catt + ); + } + msg_size = local_ntohl((uint32_t) size); + Py_MEMCPY(bufpos, &msg_size, 4); + bufpos = bufpos + 4; + Py_MEMCPY(bufpos, PyBytes_AS_STRING(ob), PyBytes_GET_SIZE(ob)); + bufpos = bufpos + PyBytes_GET_SIZE(ob); + } + } + + rob = PyBytes_FromStringAndSize(buf, bufsize); + free(buf); + return(rob); +} + +/* + * dst must be of PyTuple_Type with at least natts items slots. + */ +static int +_unpack_tuple_data(PyObject *dst, register uint16_t natts, register const char *data, Py_ssize_t data_len) +{ + static const unsigned char null_sequence[4] = {0xFF, 0xFF, 0xFF, 0xFF}; + register PyObject *ob; + register uint16_t cnatt = 0; + register uint32_t attsize; + register const char *next; + register const char *eod = data + data_len; + char attsize_buf[4]; + + while (cnatt < natts) + { + /* + * Need enough data for the attribute size. + */ + next = data + 4; + if (next > eod) + { + PyErr_Format(PyExc_ValueError, + "not enough data available for attribute %d's size header: " + "needed %d bytes, but only %lu remain at position %lu", + cnatt, 4, eod - data, data_len - (eod - data) + ); + return(-1); + } + + Py_MEMCPY(attsize_buf, data, 4); + data = next; + if ((*((uint32_t *) attsize_buf)) == (*((uint32_t *) null_sequence))) + { + /* + * NULL. 
+ */ + Py_INCREF(Py_None); + PyTuple_SET_ITEM(dst, cnatt, Py_None); + } + else + { + attsize = local_ntohl(*((uint32_t *) attsize_buf)); + + next = data + attsize; + if (next > eod || next < data) + { + /* + * Increment caused wrap... + */ + PyErr_Format(PyExc_ValueError, + "attribute %d has invalid size %lu", + cnatt, attsize + ); + return(-1); + } + + ob = PyBytes_FromStringAndSize(data, attsize); + if (ob == NULL) + { + /* + * Probably an OOM error. + */ + return(-1); + } + PyTuple_SET_ITEM(dst, cnatt, ob); + data = next; + } + + cnatt++; + } + + if (data != eod) + { + PyErr_Format(PyExc_ValueError, + "invalid tuple(D) message, %lu remaining " + "bytes after processing %d attributes", + (unsigned long) (eod - data), cnatt + ); + return(-1); + } + + return(0); +} + +static PyObject * +parse_tuple_message(PyObject *self, PyObject *arg) +{ + PyObject *rob; + const char *data; + Py_ssize_t dlen = 0; + uint16_t natts = 0; + + if (PyObject_AsReadBuffer(arg, (const void **) &data, &dlen)) + return(NULL); + + if (dlen < 2) + { + PyErr_Format(PyExc_ValueError, + "invalid tuple message: %d bytes is too small", dlen); + return(NULL); + } + Py_MEMCPY(&natts, data, 2); + natts = local_ntohs(natts); + + rob = PyTuple_New(natts); + if (rob == NULL) + return(NULL); + + if (_unpack_tuple_data(rob, natts, data+2, dlen-2) < 0) + { + Py_DECREF(rob); + return(NULL); + } + + return(rob); +} + +static PyObject * +consume_tuple_messages(PyObject *self, PyObject *list) +{ + Py_ssize_t i; + PyObject *rob; /* builtins.list */ + + if (!PyTuple_Check(list)) + { + PyErr_SetString(PyExc_TypeError, + "consume_tuple_messages requires a tuple"); + return(NULL); + } + rob = PyList_New(PyTuple_GET_SIZE(list)); + if (rob == NULL) + return(NULL); + + for (i = 0; i < PyTuple_GET_SIZE(list); ++i) + { + register PyObject *data; + PyObject *msg, *typ, *ptm; + + msg = PyTuple_GET_ITEM(list, i); + if (!PyTuple_CheckExact(msg) || PyTuple_GET_SIZE(msg) != 2) + { + Py_DECREF(rob); + PyErr_SetString(PyExc_TypeError, + "consume_tuple_messages requires tuples items to be tuples (pairs)"); + return(NULL); + } + + typ = PyTuple_GET_ITEM(msg, 0); + if (!PyBytes_CheckExact(typ) || PyBytes_GET_SIZE(typ) != 1) + { + Py_DECREF(rob); + PyErr_SetString(PyExc_TypeError, + "consume_tuple_messages requires pairs to consist of bytes"); + return(NULL); + } + + /* + * End of tuple messages. + */ + if (*(PyBytes_AS_STRING(typ)) != 'D') + break; + + data = PyTuple_GET_ITEM(msg, 1); + ptm = parse_tuple_message(NULL, data); + if (ptm == NULL) + { + Py_DECREF(rob); + return(NULL); + } + PyList_SET_ITEM(rob, i, ptm); + } + + if (i < PyTuple_GET_SIZE(list)) + { + PyObject *newrob; + newrob = PyList_GetSlice(rob, 0, i); + Py_DECREF(rob); + rob = newrob; + } + + return(rob); +} + +static PyObject * +pack_tuple_data(PyObject *self, PyObject *tup) +{ + return(_pack_tuple_data(tup)); +} + +/* + * Check for overflow before incrementing the buffer size for cat_messages. 
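+ *
+ * INCSIZET(XVAR, AMT) grows the size_t XVAR by AMT and jumps to the local
+ * `fail` label when the unsigned addition wraps around (detected by the
+ * sum becoming smaller than the original value).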
+ */ +#define INCSIZET(XVAR, AMT) do { \ + size_t _amt_ = AMT; \ + size_t _newsize_ = XVAR + _amt_; \ + if (_newsize_ >= XVAR) XVAR = _newsize_; else { \ + PyErr_Format(PyExc_OverflowError, \ + "buffer size overflowed, was %zd bytes, but could not add %d more", XVAR, _amt_); \ + goto fail; } \ +} while(0) + +#define INCMSGSIZE(XVAR, AMT) do { \ + uint32_t _amt_ = AMT; \ + uint32_t _newsize_ = XVAR + _amt_; \ + if (_newsize_ >= XVAR) XVAR = _newsize_; else { \ + PyErr_Format(PyExc_OverflowError, \ + "message size too large, was %u bytes, but could not add %u more", XVAR, _amt_); \ + goto fail; } \ +} while(0) + +/* + * cat_messages - cat the serialized form of the messages in the given list + * + * This offers a fast way to construct the final bytes() object to be sent to + * the wire. It avoids re-creating bytes() objects by calculating the serialized + * size of contiguous, homogenous messages, allocating or extending the buffer + * to accommodate for the needed size, and finally, copying the data into the + * newly available space. + */ +static PyObject * +cat_messages(PyObject *self, PyObject *messages_in) +{ + const static char null_attribute[4] = {0xff,0xff,0xff,0xff}; + PyObject *msgs = NULL; + Py_ssize_t nmsgs = 0; + Py_ssize_t cmsg = 0; + + /* + * Buffer holding the messages' serialized form. + */ + char *buf = NULL; + char *nbuf = NULL; + size_t bufsize = 0; + size_t bufpos = 0; + + /* + * Get a List object for faster rescanning when dealing with copy data. + */ + msgs = PyObject_CallFunctionObjArgs((PyObject *) &PyList_Type, messages_in, NULL); + if (msgs == NULL) + return(NULL); + + nmsgs = PyList_GET_SIZE(msgs); + + while (cmsg < nmsgs) + { + PyObject *ob; + ob = PyList_GET_ITEM(msgs, cmsg); + + /* + * Choose the path, lots of copy data or more singles to serialize? + */ + if (PyBytes_CheckExact(ob)) + { + Py_ssize_t eofc = cmsg; + size_t xsize = 0; + /* find the last of the copy data (eofc) */ + do + { + ++eofc; + /* increase in size to allocate for the adjacent copy messages */ + INCSIZET(xsize, PyBytes_GET_SIZE(ob)); + if (eofc >= nmsgs) + break; /* end of messages in the list? */ + + /* Grab the next message. */ + ob = PyList_GET_ITEM(msgs, eofc); + } while(PyBytes_CheckExact(ob)); + + /* + * Either the end of the list or `ob` is not a data object meaning + * that it's the end of the copy data. + */ + + /* realloc the buf for the new copy data */ + INCSIZET(xsize, (5 * (eofc - cmsg))); + INCSIZET(bufsize, xsize); + nbuf = realloc(buf, bufsize); + if (nbuf == NULL) + { + PyErr_Format( + PyExc_MemoryError, + "failed to allocate %lu bytes of memory for out-going messages", + (unsigned long) bufsize + ); + goto fail; + } + else + { + buf = nbuf; + nbuf = NULL; + } + + /* + * Make the final pass through the copy lines memcpy'ing the data from + * the bytes() objects. + */ + while (cmsg < eofc) + { + uint32_t msg_length = 0; + char *localbuf = buf + bufpos + 1; + buf[bufpos] = 'd'; /* COPY data message type */ + + ob = PyList_GET_ITEM(msgs, cmsg); + INCMSGSIZE(msg_length, (uint32_t) PyBytes_GET_SIZE(ob) + 4); + + INCSIZET(bufpos, 1 + msg_length); + msg_length = local_ntohl(msg_length); + Py_MEMCPY(localbuf, &msg_length, 4); + Py_MEMCPY(localbuf + 4, PyBytes_AS_STRING(ob), PyBytes_GET_SIZE(ob)); + ++cmsg; + } + } + else if (PyTuple_CheckExact(ob)) + { + /* + * Handle 'D' tuple data from a raw Python tuple. 
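+ *
+ * Wire layout of each DataRow ('D') message: one type byte, a four byte
+ * message length, a two byte attribute count, then for each attribute a
+ * four byte length (0xFFFFFFFF for NULL) followed by the attribute's bytes.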
+ */ + Py_ssize_t eofc = cmsg; + size_t xsize = 0; + + /* find the last of the tuple data (eofc) */ + do + { + Py_ssize_t current_item, nitems; + + nitems = PyTuple_GET_SIZE(ob); + if (nitems > 0xFFFF) + { + PyErr_SetString(PyExc_OverflowError, + "too many attributes in tuple message"); + goto fail; + } + + /* + * The items take *at least* 4 bytes each. + * (The attribute count is considered later) + */ + INCSIZET(xsize, (nitems * 4)); + + for (current_item = 0; current_item < nitems; ++current_item) + { + PyObject *att = PyTuple_GET_ITEM(ob, current_item); + + /* + * Attributes *must* be bytes() or None. + */ + if (PyBytes_CheckExact(att)) + INCSIZET(xsize, PyBytes_GET_SIZE(att)); + else if (att != Py_None) + { + PyErr_Format(PyExc_TypeError, + "cannot serialize tuple message attribute of type '%s'", + Py_TYPE(att)->tp_name); + goto fail; + } + /* + * else it's Py_None and the size will be included later. + */ + } + + ++eofc; + if (eofc >= nmsgs) + break; /* end of messages in the list? */ + + /* Grab the next message. */ + ob = PyList_GET_ITEM(msgs, eofc); + } while(PyTuple_CheckExact(ob)); + + /* + * Either the end of the list or `ob` is not a data object meaning + * that it's the end of the copy data. + */ + + /* + * realloc the buf for the new tuple data + * + * Each D message consumes at least 1 + 4 + 2 bytes: + * 1 for the message type + * 4 for the message size + * 2 for the attribute count + */ + INCSIZET(xsize, (7 * (eofc - cmsg))); + INCSIZET(bufsize, xsize); + nbuf = realloc(buf, bufsize); + if (nbuf == NULL) + { + PyErr_Format( + PyExc_MemoryError, + "failed to allocate %zd bytes of memory for out-going messages", + bufsize + ); + goto fail; + } + else + { + buf = nbuf; + nbuf = NULL; + } + + /* + * Make the final pass through the tuple data memcpy'ing the data from + * the bytes() objects. + * + * No type checks are done here as they should have been done while + * gathering the sizes for the realloc(). + */ + while (cmsg < eofc) + { + Py_ssize_t current_item, nitems; + uint32_t msg_length, out_msg_len; + uint16_t natts; + char *localbuf = (buf + bufpos) + 5; /* skipping the header for now */ + buf[bufpos] = 'D'; /* Tuple data message type */ + + ob = PyList_GET_ITEM(msgs, cmsg); + nitems = PyTuple_GET_SIZE(ob); + + /* + * 4 bytes for the message length, + * 2 bytes for the attribute count and + * 4 bytes for each item in 'ob'. + */ + msg_length = 4 + 2 + (nitems * 4); + + /* + * Set number of attributes. + */ + natts = local_ntohs((uint16_t) nitems); + Py_MEMCPY(localbuf, &natts, 2); + localbuf = localbuf + 2; + + for (current_item = 0; current_item < nitems; ++current_item) + { + PyObject *att = PyTuple_GET_ITEM(ob, current_item); + + if (att == Py_None) + { + Py_MEMCPY(localbuf, &null_attribute, 4); + localbuf = localbuf + 4; + } + else + { + Py_ssize_t attsize = PyBytes_GET_SIZE(att); + uint32_t n_attsize; + + n_attsize = local_ntohl((uint32_t) attsize); + + Py_MEMCPY(localbuf, &n_attsize, 4); + localbuf = localbuf + 4; + Py_MEMCPY(localbuf, PyBytes_AS_STRING(att), attsize); + localbuf = localbuf + attsize; + + INCSIZET(msg_length, attsize); + } + } + + /* + * Summed up the message size while copying the attributes. + */ + out_msg_len = local_ntohl(msg_length); + Py_MEMCPY(buf + bufpos + 1, &out_msg_len, 4); + + /* + * Filled in the data while summing the message size, so + * adjust the buffer position for the next message. 
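+ * (msg_length does not count the type byte, hence the extra 1 below.)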
+ */ + INCSIZET(bufpos, 1 + msg_length); + ++cmsg; + } + } + else + { + PyObject *serialized; + PyObject *msg_type; + int msg_type_size; + uint32_t msg_length; + + /* + * Call the serialize() method on the element object. + * Do this instead of the normal bytes() method to avoid + * the type and size packing overhead. + */ + serialized = PyObject_CallMethodObjArgs(ob, serialize_strob, NULL); + if (serialized == NULL) + goto fail; + if (!PyBytes_CheckExact(serialized)) + { + PyErr_Format( + PyExc_TypeError, + "%s.serialize() returned object of type %s, expected bytes", + PyObject_TypeName(ob), + PyObject_TypeName(serialized) + ); + goto fail; + } + + msg_type = PyObject_GetAttr(ob, msgtype_strob); + if (msg_type == NULL) + { + Py_DECREF(serialized); + goto fail; + } + if (!PyBytes_CheckExact(msg_type)) + { + Py_DECREF(serialized); + Py_DECREF(msg_type); + PyErr_Format( + PyExc_TypeError, + "message's 'type' attribute was %s, expected bytes", + PyObject_TypeName(ob) + ); + goto fail; + } + /* + * Some elements have empty message types--Startup for instance. + * It is important to get the actual size rather than assuming one. + */ + msg_type_size = PyBytes_GET_SIZE(msg_type); + + /* realloc the buf for the new copy data */ + INCSIZET(bufsize, 4 + msg_type_size); + INCSIZET(bufsize, PyBytes_GET_SIZE(serialized)); + nbuf = realloc(buf, bufsize); + if (nbuf == NULL) + { + Py_DECREF(serialized); + Py_DECREF(msg_type); + PyErr_Format( + PyExc_MemoryError, + "failed to allocate %d bytes of memory for out-going messages", + bufsize + ); + goto fail; + } + else + { + buf = nbuf; + nbuf = NULL; + } + + /* + * All necessary information acquired, so fill in the message's data. + */ + buf[bufpos] = *(PyBytes_AS_STRING(msg_type)); + msg_length = PyBytes_GET_SIZE(serialized); + INCMSGSIZE(msg_length, 4); + msg_length = local_ntohl(msg_length); + Py_MEMCPY(buf + bufpos + msg_type_size, &msg_length, 4); + Py_MEMCPY( + buf + bufpos + 4 + msg_type_size, + PyBytes_AS_STRING(serialized), + PyBytes_GET_SIZE(serialized) + ); + bufpos = bufsize; + + Py_DECREF(serialized); + Py_DECREF(msg_type); + ++cmsg; + } + } + + Py_DECREF(msgs); + if (buf == NULL) + /* no messages, no data */ + return(PyBytes_FromString("")); + else + { + PyObject *rob; + rob = PyBytes_FromStringAndSize(buf, bufsize); + free(buf); + + return(rob); + } +fail: + /* pyerr is expected to be set */ + Py_DECREF(msgs); + if (buf != NULL) + free(buf); + return(NULL); +} diff --git a/py_opengauss/port/_optimized/functools.c b/py_opengauss/port/_optimized/functools.c new file mode 100644 index 0000000000000000000000000000000000000000..9a0deea074eded966b0e69d70c75d980d01ac85d --- /dev/null +++ b/py_opengauss/port/_optimized/functools.c @@ -0,0 +1,337 @@ +/* + * .port.optimized - functools.c + * + *//* + * optimizations for postgresql.python package modules. + */ +/* + * process the tuple with the associated callables while + * calling the third object in cases of failure to generalize the exception. 
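+ *
+ * Roughly, process_tuple(procs, row, fail) yields
+ *   tuple(p(v) if v is not None else None for p, v in zip(procs, row))
+ * and calls fail(cause, procs, row, index) when procs[index] raises.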
+ */ +#define include_functools_functions \ + mFUNC(rsetattr, METH_VARARGS, "rsetattr(attr, val, ob) set the attribute to the value *and* return `ob`.") \ + mFUNC(compose, METH_VARARGS, "given a sequence of callables, and an argument for the first call, compose the result.") \ + mFUNC(process_tuple, METH_VARARGS, "process the items in the second argument with the corresponding items in the first argument.") \ + mFUNC(process_chunk, METH_VARARGS, "process the items of the chunk given as the second argument with the corresponding items in the first argument.") + +static PyObject * +_process_tuple(PyObject *procs, PyObject *tup, PyObject *fail) +{ + PyObject *rob; + Py_ssize_t len, i; + + if (!PyTuple_CheckExact(procs)) + { + PyErr_SetString( + PyExc_TypeError, + "process_tuple requires an exact tuple as its first argument" + ); + return(NULL); + } + + if (!PyTuple_Check(tup)) + { + PyErr_SetString( + PyExc_TypeError, + "process_tuple requires a tuple as its second argument" + ); + return(NULL); + } + + len = PyTuple_GET_SIZE(tup); + + if (len != PyTuple_GET_SIZE(procs)) + { + PyErr_Format( + PyExc_TypeError, + "inconsistent items, %d processors and %d items in row", + len, + PyTuple_GET_SIZE(procs) + ); + return(NULL); + } + /* types check out; consistent sizes */ + rob = PyTuple_New(len); + + for (i = 0; i < len; ++i) + { + PyObject *p, *o, *ot, *r; + /* p = processor, + * o = source object, + * ot = o's tuple (temp for application to p), + * r = transformed * output + */ + + /* + * If it's Py_None, that means it's NULL. No processing necessary. + */ + o = PyTuple_GET_ITEM(tup, i); + if (o == Py_None) + { + Py_INCREF(Py_None); + PyTuple_SET_ITEM(rob, i, Py_None); + /* mmmm, cake! */ + continue; + } + + p = PyTuple_GET_ITEM(procs, i); + /* + * Temp tuple for applying *args to p. + */ + ot = PyTuple_New(1); + PyTuple_SET_ITEM(ot, 0, o); + Py_INCREF(o); + + r = PyObject_CallObject(p, ot); + Py_DECREF(ot); + if (r != NULL) + { + /* good, set it and move on. */ + PyTuple_SET_ITEM(rob, i, r); + } + else + { + /* + * Exception caused by >>> p(*ot) + * + * In this case, the failure callback needs to be called + * in order to properly generalize the failure. There are numerous, + * and (sometimes) inconsistent reasons why a tuple cannot be + * processed and therefore a generalized exception raised in the + * context of the original is *very* useful. + */ + Py_DECREF(rob); + rob = NULL; + + /* + * Don't trap BaseException's. + */ + if (PyErr_ExceptionMatches(PyExc_Exception)) + { + PyObject *cause, *failargs, *failedat; + PyObject *exc, *tb; + + /* Store exception to set context after handler. 
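+					   Only the exception value, bound to `cause`, is kept and later
+					   handed to the failure callback; the type and traceback
+					   references are released after normalization.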
*/ + PyErr_Fetch(&exc, &cause, &tb); + PyErr_NormalizeException(&exc, &cause, &tb); + Py_XDECREF(exc); + Py_XDECREF(tb); + + failedat = PyLong_FromSsize_t(i); + if (failedat != NULL) + { + failargs = PyTuple_New(4); + if (failargs != NULL) + { + /* args for the exception "generalizer" */ + PyTuple_SET_ITEM(failargs, 0, cause); + PyTuple_SET_ITEM(failargs, 1, procs); + Py_INCREF(procs); + PyTuple_SET_ITEM(failargs, 2, tup); + Py_INCREF(tup); + PyTuple_SET_ITEM(failargs, 3, failedat); + + r = PyObject_CallObject(fail, failargs); + Py_DECREF(failargs); + if (r != NULL) + { + PyErr_SetString(PyExc_RuntimeError, + "process_tuple exception handler failed to raise" + ); + Py_DECREF(r); + } + } + else + { + Py_DECREF(failedat); + } + } + } + + /* + * Break out of loop to return(NULL); + */ + break; + } + } + + return(rob); +} + +/* + * process the tuple with the associated callables while + * calling the third object in cases of failure to generalize the exception. + */ +static PyObject * +process_tuple(PyObject *self, PyObject *args) +{ + PyObject *tup, *procs, *fail; + + if (!PyArg_ParseTuple(args, "OOO", &procs, &tup, &fail)) + return(NULL); + + return(_process_tuple(procs, tup, fail)); +} + +static PyObject * +_process_chunk_new_list(PyObject *procs, PyObject *tupc, PyObject *fail) +{ + PyObject *rob; + Py_ssize_t i, len; + + /* + * Turn the iterable into a new list. + */ + rob = PyObject_CallFunctionObjArgs((PyObject *) &PyList_Type, tupc, NULL); + if (rob == NULL) + return(NULL); + len = PyList_GET_SIZE(rob); + + for (i = 0; i < len; ++i) + { + PyObject *tup, *r; + /* + * If it's Py_None, that means it's NULL. No processing necessary. + */ + tup = PyList_GetItem(rob, i); /* borrowed ref from list */ + r = _process_tuple(procs, tup, fail); + if (r == NULL) + { + /* process_tuple failed. assume PyErr_Occurred() */ + Py_DECREF(rob); + return(NULL); + } + PyList_SetItem(rob, i, r); + } + + return(rob); +} + +static PyObject * +_process_chunk_from_list(PyObject *procs, PyObject *tupc, PyObject *fail) +{ + PyObject *rob; + Py_ssize_t i, len; + + len = PyList_GET_SIZE(tupc); + rob = PyList_New(len); + if (rob == NULL) + return(NULL); + + for (i = 0; i < len; ++i) + { + PyObject *tup, *r; + /* + * If it's Py_None, that means it's NULL. No processing necessary. + */ + tup = PyList_GET_ITEM(tupc, i); + r = _process_tuple(procs, tup, fail); + if (r == NULL) + { + Py_DECREF(rob); + return(NULL); + } + PyList_SET_ITEM(rob, i, r); + } + + return(rob); +} + +/* + * process the chunk of tuples with the associated callables while + * calling the third object in cases of failure to generalize the exception. + */ +static PyObject * +process_chunk(PyObject *self, PyObject *args) +{ + PyObject *tupc, *procs, *fail; + + if (!PyArg_ParseTuple(args, "OOO", &procs, &tupc, &fail)) + return(NULL); + + if (PyList_Check(tupc)) + { + return(_process_chunk_from_list(procs, tupc, fail)); + } + else + { + return(_process_chunk_new_list(procs, tupc, fail)); + } +} +static PyObject * +rsetattr(PyObject *self, PyObject *args) +{ + PyObject *ob, *attr, *val; + + if (!PyArg_ParseTuple(args, "OOO", &attr, &val, &ob)) + return(NULL); + + if (PyObject_SetAttr(ob, attr, val) < 0) + return(NULL); + + Py_INCREF(ob); + return(ob); +} + +/* + * Override the functools.Composition __call__. 
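+ *
+ * compose((f, g, h), x) evaluates h(g(f(x))): the result of each callable
+ * is passed as the sole argument to the next one in the sequence.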
+ */ +static PyObject * +compose(PyObject *self, PyObject *args) +{ + Py_ssize_t i, len; + PyObject *rob, *argt, *seq, *x; + + if (!PyArg_ParseTuple(args, "OO", &seq, &rob)) + return(NULL); + + Py_INCREF(rob); + if (PyObject_IsInstance(seq, (PyObject *) &PyTuple_Type)) + { + len = PyTuple_GET_SIZE(seq); + for (i = 0; i < len; ++i) + { + x = PyTuple_GET_ITEM(seq, i); + argt = PyTuple_New(1); + PyTuple_SET_ITEM(argt, 0, rob); + rob = PyObject_CallObject(x, argt); + Py_DECREF(argt); + if (rob == NULL) + break; + } + } + else if (PyObject_IsInstance(seq, (PyObject *) &PyList_Type)) + { + len = PyList_GET_SIZE(seq); + for (i = 0; i < len; ++i) + { + x = PyList_GET_ITEM(seq, i); + argt = PyTuple_New(1); + PyTuple_SET_ITEM(argt, 0, rob); + rob = PyObject_CallObject(x, argt); + Py_DECREF(argt); + if (rob == NULL) + break; + } + } + else + { + /* + * Arbitrary sequence. + */ + len = PySequence_Length(seq); + for (i = 0; i < len; ++i) + { + x = PySequence_GetItem(seq, i); + argt = PyTuple_New(1); + PyTuple_SET_ITEM(argt, 0, rob); + rob = PyObject_CallObject(x, argt); + Py_DECREF(x); + Py_DECREF(argt); + if (rob == NULL) + break; + } + } + + return(rob); +} diff --git a/py_opengauss/port/_optimized/module.c b/py_opengauss/port/_optimized/module.c new file mode 100644 index 0000000000000000000000000000000000000000..33f6875931cf5ab470e2bb5d925f94edd9701a0f --- /dev/null +++ b/py_opengauss/port/_optimized/module.c @@ -0,0 +1,151 @@ +/* + * module.c - optimizations for various parts of py-postgresql + * + * This module.c file ties together other classified C source. + * Each filename describing the part of the protocol package that it + * covers. It merely uses CPP includes to bring them into this + * file and then uses some CPP macros to expand the definitions + * in each file. + */ +#include +#include +/* + * If Python didn't find it, it won't include it. + * However, it's quite necessary. + */ +#ifndef HAVE_STDINT_H +#include +#endif + +#define USHORT_MAX ((1<<16)-1) +#define SHORT_MAX ((1<<15)-1) +#define SHORT_MIN (-(1<<15)) + +#define PyObject_TypeName(ob) \ + (((PyTypeObject *) (ob->ob_type))->tp_name) + +/* + * buffer.c needs the message_types object from .protocol.message_types. + * Initialized in PyInit_optimized. + */ +static PyObject *message_types = NULL; +static PyObject *serialize_strob = NULL; +static PyObject *msgtype_strob = NULL; + +static int32_t (*local_ntohl)(int32_t) = NULL; +static short (*local_ntohs)(short) = NULL; + +/* + * optimized module contents + */ +#include "structlib.c" +#include "functools.c" +#include "buffer.c" +#include "wirestate.c" +#include "element3.c" + + +/* cpp abuse, read up on X-Macros if you don't understand */ +#define mFUNC(name, typ, doc) \ + {#name, (PyCFunction) name, typ, PyDoc_STR(doc)}, +static PyMethodDef optimized_methods[] = { + include_element3_functions + include_structlib_functions + include_functools_functions + {NULL} +}; +#undef mFUNC + +static struct PyModuleDef optimized_module = { + PyModuleDef_HEAD_INIT, + "optimized", /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. 
*/ + optimized_methods, +}; + +PyMODINIT_FUNC +PyInit_optimized(void) +{ + PyObject *mod; + PyObject *msgtypes; + PyObject *fromlist, *fromstr; + long l; + + /* make some constants */ + if (serialize_strob == NULL) + { + serialize_strob = PyUnicode_FromString("serialize"); + if (serialize_strob == NULL) + return(NULL); + } + if (msgtype_strob == NULL) + { + msgtype_strob = PyUnicode_FromString("type"); + if (msgtype_strob == NULL) + return(NULL); + } + + mod = PyModule_Create(&optimized_module); + if (mod == NULL) + return(NULL); + +/* cpp abuse; ready types */ +#define mTYPE(name) \ + if (PyType_Ready(&name##_Type) < 0) \ + goto cleanup; \ + if (PyModule_AddObject(mod, #name, \ + (PyObject *) &name##_Type) < 0) \ + goto cleanup; + + /* buffer.c */ + include_buffer_types + /* wirestate.c */ + include_wirestate_types +#undef mTYPE + + l = 1; + if (((char *) &l)[0] == 1) + { + /* little */ + local_ntohl = swap_int4; + local_ntohs = swap_short; + } + else + { + /* big */ + local_ntohl = return_int4; + local_ntohs = return_short; + } + + /* + * Get the message_types tuple to type "instantiation". + */ + fromlist = PyList_New(1); + fromstr = PyUnicode_FromString("message_types"); + PyList_SetItem(fromlist, 0, fromstr); + msgtypes = PyImport_ImportModuleLevel( + "protocol.message_types", + PyModule_GetDict(mod), + PyModule_GetDict(mod), + fromlist, 2 + ); + Py_DECREF(fromlist); + if (msgtypes == NULL) + goto cleanup; + message_types = PyObject_GetAttrString(msgtypes, "message_types"); + Py_DECREF(msgtypes); + + if (!PyObject_IsInstance(message_types, (PyObject *) (&PyTuple_Type))) + { + PyErr_SetString(PyExc_RuntimeError, + "local protocol.message_types.message_types is not a tuple object"); + goto cleanup; + } + + return(mod); +cleanup: + Py_DECREF(mod); + return(NULL); +} diff --git a/py_opengauss/port/_optimized/structlib.c b/py_opengauss/port/_optimized/structlib.c new file mode 100644 index 0000000000000000000000000000000000000000..21ae1458b370d638e6c16ebdd3851a1730e045d2 --- /dev/null +++ b/py_opengauss/port/_optimized/structlib.c @@ -0,0 +1,599 @@ +/* + * .port.optimized - pack and unpack int2, int4, and int8. + */ + +/* + * Define the swap functionality for those endians. 
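+ *
+ * swap2/swap4/swap8 reverse 2, 4, and 8 byte values in place; they back the
+ * swap_* pack/unpack functions and local_ntohs/local_ntohl on little-endian
+ * hosts.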
+ */ +#define swap2(CP) do{register char c; \ + c=CP[1];CP[1]=CP[0];CP[0]=c;\ +}while(0) +#define swap4(P) do{register char c; \ + c=P[3];P[3]=P[0];P[0]=c;\ + c=P[2];P[2]=P[1];P[1]=c;\ +}while(0) +#define swap8(P) do{register char c; \ + c=P[7];P[7]=P[0];P[0]=c;\ + c=P[6];P[6]=P[1];P[1]=c;\ + c=P[5];P[5]=P[2];P[2]=c;\ + c=P[4];P[4]=P[3];P[3]=c;\ +}while(0) + +#define long_funcs \ + mFUNC(int2_pack, METH_O, "PyInt to serialized, int2") \ + mFUNC(int2_unpack, METH_O, "PyInt from serialized, int2") \ + mFUNC(int4_pack, METH_O, "PyInt to serialized, int4") \ + mFUNC(int4_unpack, METH_O, "PyInt from serialized, int4") \ + mFUNC(swap_int2_pack, METH_O, "PyInt to swapped serialized, int2") \ + mFUNC(swap_int2_unpack, METH_O, "PyInt from swapped serialized, int2") \ + mFUNC(swap_int4_pack, METH_O, "PyInt to swapped serialized, int4") \ + mFUNC(swap_int4_unpack, METH_O, "PyInt from swapped serialized, int4") \ + mFUNC(uint2_pack, METH_O, "PyInt to serialized, uint2") \ + mFUNC(uint2_unpack, METH_O, "PyInt from serialized, uint2") \ + mFUNC(uint4_pack, METH_O, "PyInt to serialized, uint4") \ + mFUNC(uint4_unpack, METH_O, "PyInt from serialized, uint4") \ + mFUNC(swap_uint2_pack, METH_O, "PyInt to swapped serialized, uint2") \ + mFUNC(swap_uint2_unpack, METH_O, "PyInt from swapped serialized, uint2") \ + mFUNC(swap_uint4_pack, METH_O, "PyInt to swapped serialized, uint4") \ + mFUNC(swap_uint4_unpack, METH_O, "PyInt from swapped serialized, uint4") \ + +#ifdef HAVE_LONG_LONG +#if SIZEOF_LONG_LONG == 8 +/* + * If the configuration is not consistent with the expectations, + * just use the slower struct.Struct versions. + */ +#define longlong_funcs \ + mFUNC(int8_pack, METH_O, "PyInt to serialized, int8") \ + mFUNC(int8_unpack, METH_O, "PyInt from serialized, int8") \ + mFUNC(swap_int8_pack, METH_O, "PyInt to swapped serialized, int8") \ + mFUNC(swap_int8_unpack, METH_O, "PyInt from swapped serialized, int8") \ + mFUNC(uint8_pack, METH_O, "PyInt to serialized, uint8") \ + mFUNC(uint8_unpack, METH_O, "PyInt from serialized, uint8") \ + mFUNC(swap_uint8_pack, METH_O, "PyInt to swapped serialized, uint8") \ + mFUNC(swap_uint8_unpack, METH_O, "PyInt from swapped serialized, uint8") \ + +#define include_structlib_functions \ + long_funcs \ + longlong_funcs + +#if 0 + Currently not used, so exclude. 
+ +static PY_LONG_LONG +return_long_long(PY_LONG_LONG i) +{ + return(i); +} + +static PY_LONG_LONG +swap_long_long(PY_LONG_LONG i) +{ + swap8(((char *) &i)); + return(i); +} +#endif + +#endif +#endif + +#ifndef include_structlib_functions +#define include_structlib_functions \ + long_funcs +#endif + +static short +swap_short(short s) +{ + swap2(((char *) &s)); + return(s); +} + +static short +return_short(short s) +{ + return(s); +} + +static int32_t +swap_int4(int32_t i) +{ + swap4(((char *) &i)); + return(i); +} + +static int32_t +return_int4(int32_t i) +{ + return(i); +} + +static PyObject * +int2_pack(PyObject *self, PyObject *arg) +{ + long l; + short s; + + l = PyLong_AsLong(arg); + if (PyErr_Occurred()) + return(NULL); + + if (l > SHORT_MAX || l < SHORT_MIN) + { + PyErr_Format(PyExc_OverflowError, + "long '%d' overflows int2", l + ); + return(NULL); + } + + s = (short) l; + return(PyBytes_FromStringAndSize((const char *) &s, 2)); +} + +static PyObject * +swap_int2_pack(PyObject *self, PyObject *arg) +{ + long l; + short s; + + l = PyLong_AsLong(arg); + if (PyErr_Occurred()) + return(NULL); + if (l > SHORT_MAX || l < SHORT_MIN) + { + PyErr_SetString(PyExc_OverflowError, "long too big or small for int2"); + return(NULL); + } + + s = (short) l; + swap2(((char *) &s)); + return(PyBytes_FromStringAndSize((const char *) &s, 2)); +} + +static PyObject * +int2_unpack(PyObject *self, PyObject *arg) +{ + char *c; + short *i; + long l; + Py_ssize_t len; + PyObject *rob; + + if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) + return(NULL); + + if (len < 2) + { + PyErr_SetString(PyExc_ValueError, "not enough data for int2_unpack"); + return(NULL); + } + + i = (short *) c; + l = (long) *i; + rob = PyLong_FromLong(l); + return(rob); +} + +static PyObject * +swap_int2_unpack(PyObject *self, PyObject *arg) +{ + char *c; + short s; + long l; + Py_ssize_t len; + PyObject *rob; + + if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) + return(NULL); + + if (len < 2) + { + PyErr_SetString(PyExc_ValueError, "not enough data for swap_int2_unpack"); + return(NULL); + } + + s = *((short *) c); + swap2(((char *) &s)); + l = (long) s; + rob = PyLong_FromLong(l); + return(rob); +} + +static PyObject * +int4_pack(PyObject *self, PyObject *arg) +{ + long l; + int32_t i; + + l = PyLong_AsLong(arg); + if (PyErr_Occurred()) + return(NULL); + if (!(l <= (long) 0x7FFFFFFFL && l >= (long) (-0x80000000L))) + { + PyErr_Format(PyExc_OverflowError, + "long '%ld' overflows int4", l + ); + return(NULL); + } + i = (int32_t) l; + return(PyBytes_FromStringAndSize((const char *) &i, 4)); +} + +static PyObject * +swap_int4_pack(PyObject *self, PyObject *arg) +{ + long l; + int32_t i; + + l = PyLong_AsLong(arg); + if (PyErr_Occurred()) + return(NULL); + if (!(l <= (long) 0x7FFFFFFFL && l >= (long) (-0x80000000L))) + { + PyErr_Format(PyExc_OverflowError, + "long '%ld' overflows int4", l + ); + return(NULL); + } + i = (int32_t) l; + swap4(((char *) &i)); + return(PyBytes_FromStringAndSize((const char *) &i, 4)); +} + +static PyObject * +int4_unpack(PyObject *self, PyObject *arg) +{ + char *c; + int32_t i; + Py_ssize_t len; + + if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) + return(NULL); + + if (len < 4) + { + PyErr_SetString(PyExc_ValueError, "not enough data for int4_unpack"); + return(NULL); + } + i = *((int32_t *) c); + + return(PyLong_FromLong((long) i)); +} + +static PyObject * +swap_int4_unpack(PyObject *self, PyObject *arg) +{ + char *c; + int32_t i; + Py_ssize_t len; + + if 
(PyObject_AsReadBuffer(arg, (const void **) &c, &len)) + return(NULL); + + if (len < 4) + { + PyErr_SetString(PyExc_ValueError, "not enough data for swap_int4_unpack"); + return(NULL); + } + + i = *((int32_t *) c); + swap4(((char *) &i)); + return(PyLong_FromLong((long) i)); +} + +static PyObject * +uint2_pack(PyObject *self, PyObject *arg) +{ + long l; + unsigned short s; + + l = PyLong_AsLong(arg); + if (PyErr_Occurred()) + return(NULL); + + if (l > USHORT_MAX || l < 0) + { + PyErr_Format(PyExc_OverflowError, + "long '%ld' overflows uint2", l + ); + return(NULL); + } + + s = (unsigned short) l; + return(PyBytes_FromStringAndSize((const char *) &s, 2)); +} + +static PyObject * +swap_uint2_pack(PyObject *self, PyObject *arg) +{ + long l; + unsigned short s; + + l = PyLong_AsLong(arg); + if (PyErr_Occurred()) + return(NULL); + + if (l > USHORT_MAX || l < 0) + { + PyErr_Format(PyExc_OverflowError, + "long '%ld' overflows uint2", l + ); + return(NULL); + } + + s = (unsigned short) l; + swap2(((char *) &s)); + return(PyBytes_FromStringAndSize((const char *) &s, 2)); +} + +static PyObject * +uint2_unpack(PyObject *self, PyObject *arg) +{ + char *c; + unsigned short *i; + long l; + Py_ssize_t len; + PyObject *rob; + + if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) + return(NULL); + + if (len < 2) + { + PyErr_SetString(PyExc_ValueError, "not enough data for uint2_unpack"); + return(NULL); + } + + i = (unsigned short *) c; + l = (long) *i; + rob = PyLong_FromLong(l); + return(rob); +} + +static PyObject * +swap_uint2_unpack(PyObject *self, PyObject *arg) +{ + char *c; + unsigned short s; + long l; + Py_ssize_t len; + PyObject *rob; + + if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) + return(NULL); + if (len < 2) + { + PyErr_SetString(PyExc_ValueError, "not enough data for swap_uint2_unpack"); + return(NULL); + } + + s = *((short *) c); + swap2(((char *) &s)); + l = (long) s; + rob = PyLong_FromLong(l); + return(rob); +} + +static PyObject * +uint4_pack(PyObject *self, PyObject *arg) +{ + uint32_t i; + unsigned long l; + + l = PyLong_AsUnsignedLong(arg); + if (PyErr_Occurred()) + return(NULL); + if (l > 0xFFFFFFFFL) + { + PyErr_Format(PyExc_OverflowError, + "long '%lu' overflows uint4", l + ); + return(NULL); + } + + i = (uint32_t) l; + return(PyBytes_FromStringAndSize((const char *) &i, 4)); +} + +static PyObject * +swap_uint4_pack(PyObject *self, PyObject *arg) +{ + uint32_t i; + unsigned long l; + + l = PyLong_AsUnsignedLong(arg); + if (PyErr_Occurred()) + return(NULL); + if (l > 0xFFFFFFFFL) + { + PyErr_Format(PyExc_OverflowError, + "long '%lu' overflows uint4", l + ); + return(NULL); + } + + i = (uint32_t) l; + swap4(((char *) &i)); + return(PyBytes_FromStringAndSize((const char *) &i, 4)); +} + +static PyObject * +uint4_unpack(PyObject *self, PyObject *arg) +{ + char *c; + uint32_t i; + Py_ssize_t len; + + if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) + return(NULL); + + if (len < 4) + { + PyErr_SetString(PyExc_ValueError, "not enough data for uint4_unpack"); + return(NULL); + } + + i = *((uint32_t *) c); + return(PyLong_FromUnsignedLong((unsigned long) i)); +} + +static PyObject * +swap_uint4_unpack(PyObject *self, PyObject *arg) +{ + char *c; + uint32_t i; + Py_ssize_t len; + + if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) + return(NULL); + if (len < 4) + { + PyErr_SetString(PyExc_ValueError, + "not enough data for swap_uint4_unpack"); + return(NULL); + } + + i = *((uint32_t *) c); + swap4(((char *) &i)); + + 
return(PyLong_FromUnsignedLong((unsigned long) i)); +} + +#ifdef longlong_funcs +/* + * int8 and "uint8" I/O + */ +static PyObject * +int8_pack(PyObject *self, PyObject *arg) +{ + PY_LONG_LONG l; + + l = PyLong_AsLongLong(arg); + if (l == (PY_LONG_LONG) -1 && PyErr_Occurred()) + return(NULL); + + return(PyBytes_FromStringAndSize((const char *) &l, 8)); +} + +static PyObject * +swap_int8_pack(PyObject *self, PyObject *arg) +{ + PY_LONG_LONG l; + + l = PyLong_AsLongLong(arg); + if (l == (PY_LONG_LONG) -1 && PyErr_Occurred()) + return(NULL); + + swap8(((char *) &l)); + return(PyBytes_FromStringAndSize((const char *) &l, 8)); +} + +static PyObject * +uint8_pack(PyObject *self, PyObject *arg) +{ + unsigned PY_LONG_LONG l; + + l = PyLong_AsUnsignedLongLong(arg); + if (l == (unsigned PY_LONG_LONG) -1 && PyErr_Occurred()) + return(NULL); + + return(PyBytes_FromStringAndSize((const char *) &l, 8)); +} + +static PyObject * +swap_uint8_pack(PyObject *self, PyObject *arg) +{ + unsigned PY_LONG_LONG l; + + l = PyLong_AsUnsignedLongLong(arg); + if (l == (unsigned PY_LONG_LONG) -1 && PyErr_Occurred()) + return(NULL); + + swap8(((char *) &l)); + return(PyBytes_FromStringAndSize((const char *) &l, 8)); +} + +static PyObject * +uint8_unpack(PyObject *self, PyObject *arg) +{ + char *c; + Py_ssize_t len; + unsigned PY_LONG_LONG i; + + if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) + return(NULL); + if (len < 8) + { + PyErr_SetString(PyExc_ValueError, "not enough data for uint8_unpack"); + return(NULL); + } + + i = *((unsigned PY_LONG_LONG *) c); + return(PyLong_FromUnsignedLongLong(i)); +} +static PyObject * +swap_uint8_unpack(PyObject *self, PyObject *arg) +{ + char *c; + Py_ssize_t len; + unsigned PY_LONG_LONG i; + + if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) + return(NULL); + if (len < 8) + { + PyErr_SetString(PyExc_ValueError, + "not enough data for swap_uint8_unpack"); + return(NULL); + } + + i = *((unsigned PY_LONG_LONG *) c); + swap8(((char *) &i)); + return(PyLong_FromUnsignedLongLong(i)); +} + +static PyObject * +int8_unpack(PyObject *self, PyObject *arg) +{ + char *c; + Py_ssize_t len; + PY_LONG_LONG i; + + if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) + return(NULL); + if (len < 8) + { + PyErr_SetString(PyExc_ValueError, + "not enough data for int8_unpack"); + return(NULL); + } + + i = *((PY_LONG_LONG *) c); + return(PyLong_FromLongLong((PY_LONG_LONG) i)); +} + +static PyObject * +swap_int8_unpack(PyObject *self, PyObject *arg) +{ + char *c; + Py_ssize_t len; + PY_LONG_LONG i; + + c = PyBytes_AsString(arg); + if (PyErr_Occurred()) + return(NULL); + + if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) + return(NULL); + if (len < 8) + { + PyErr_SetString(PyExc_ValueError, + "not enough data for swap_int8_unpack"); + return(NULL); + } + + i = *((PY_LONG_LONG *) c); + swap8(((char *) &i)); + return(PyLong_FromLongLong(i)); +} +#endif /* longlong_funcs */ diff --git a/py_opengauss/port/_optimized/wirestate.c b/py_opengauss/port/_optimized/wirestate.c new file mode 100644 index 0000000000000000000000000000000000000000..7947e5cbdec75a1feef44b0a21280d52626b3361 --- /dev/null +++ b/py_opengauss/port/_optimized/wirestate.c @@ -0,0 +1,286 @@ +/* + * .port.optimized.WireState - PQ wire state for COPY. + */ +#define include_wirestate_types \ + mTYPE(WireState) + +struct wirestate +{ + PyObject_HEAD + char size_fragment[4]; /* the header fragment; continuation specifies bytes read so far. 
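+	                            Only the first `continuation` bytes are valid;
+	                            a continuation of -1 means no fragment is pending.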
*/ + PyObject *final_view; /* Py_None unless we reach an unknown message */ + Py_ssize_t remaining_bytes; /* Bytes remaining in message */ + short continuation; /* >= 0 when continuing a fragment */ +}; + +static void +ws_dealloc(PyObject *self) +{ + struct wirestate *ws = ((struct wirestate *) self); + Py_XDECREF(ws->final_view); + Py_TYPE(self)->tp_free(self); +} + +static PyObject * +ws_new(PyTypeObject *subtype, PyObject *args, PyObject *kw) +{ + static char *kwlist[] = {"condition", NULL}; + struct wirestate *ws; + PyObject *rob; + + if (!PyArg_ParseTupleAndKeywords(args, kw, "|O", kwlist, &rob)) + return(NULL); + + rob = subtype->tp_alloc(subtype, 0); + ws = ((struct wirestate *) rob); + + ws->continuation = -1; + ws->remaining_bytes = 0; + ws->final_view = NULL; + + return(rob); +} + +#define CONDITION(MSGTYPE) (MSGTYPE != 'd') + +static PyObject * +ws_update(PyObject *self, PyObject *view) +{ + struct wirestate *ws; + uint32_t remaining_bytes, nmessages = 0; + unsigned char *buf, msgtype; + char size_fragment[4]; + short continuation; + Py_ssize_t position = 0, len; + PyObject *rob, *final_view = NULL; + + if (PyObject_AsReadBuffer(view, (const void **) &buf, &len)) + return(NULL); + + if (len == 0) + { + /* + * Nothing changed. + */ + return(PyLong_FromUnsignedLong(0)); + } + + ws = (struct wirestate *) self; + + if (ws->final_view) + { + PyErr_SetString(PyExc_RuntimeError, "wire state has been terminated"); + return(NULL); + } + + remaining_bytes = ws->remaining_bytes; + continuation = ws->continuation; + + if (continuation >= 0) + { + short sf_len = continuation, added; + /* + * Continuation of message header. + */ + added = 4 - sf_len; + /* + * If the buffer's length does not provide, limit to len. + */ + if (len < added) + added = len; + + Py_MEMCPY(size_fragment, ws->size_fragment, 4); + Py_MEMCPY(size_fragment + sf_len, buf, added); + + continuation = continuation + added; + if (continuation == 4) + { + /* + * Completed the size part of the header. + */ + Py_MEMCPY(&remaining_bytes, size_fragment, 4); + remaining_bytes = (local_ntohl((int32_t) remaining_bytes)); + if (remaining_bytes < 4) + goto invalid_message_header; + + remaining_bytes = remaining_bytes - sf_len; + if (remaining_bytes == 0) + ++nmessages; + continuation = -1; + } + else + { + /* + * Consumed more of the header, but more is still needed. + * Jump past the main loop. + */ + goto return_nmessages; + } + } + + do + { + if (remaining_bytes > 0) + { + position = position + remaining_bytes; + if (position > len) + { + remaining_bytes = position - len; + position = len; + } + else + { + remaining_bytes = 0; + ++nmessages; + } + } + + /* + * Done with view. + */ + if (position >= len) + break; + + /* + * Validate message type. + */ + msgtype = *(buf + position); + if (CONDITION(msgtype)) + { + final_view = PySequence_GetSlice(view, position, len); + break; + } + + /* + * Have enough for a complete header? + */ + if (len - position < 5) + { + /* + * Start a continuation. Message type has been verified. + */ + continuation = (len - position) - 1; + Py_MEMCPY(size_fragment, buf + position + 1, (Py_ssize_t) continuation); + break; + } + + /* + * +1 to include the message type. 
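+		 * (The length field counts itself but not the type byte, so a whole
+		 * message occupies length + 1 bytes starting at `position`.)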
+ */ + Py_MEMCPY(&remaining_bytes, buf + position + 1, 4); + remaining_bytes = local_ntohl((int32_t) remaining_bytes) + 1; + if (remaining_bytes < 5) + goto invalid_message_header; + } while(1); + +return_nmessages: + rob = PyLong_FromUnsignedLong(nmessages); + if (rob == NULL) + { + Py_XDECREF(final_view); + return(NULL); + } + + /* Commit new state */ + ws->remaining_bytes = remaining_bytes; + ws->final_view = final_view; + ws->continuation = continuation; + Py_MEMCPY(ws->size_fragment, size_fragment, 4); + return(rob); + +invalid_message_header: + PyErr_SetString(PyExc_ValueError, "message header contained an invalid size"); + return(NULL); +} + +static PyMethodDef ws_methods[] = { + {"update", ws_update, METH_O, + PyDoc_STR("update the state of the wire using the given buffer object"),}, + {NULL} +}; + +PyObject * +ws_size_fragment(PyObject *self, void *closure) +{ + struct wirestate *ws; + ws = (struct wirestate *) self; + + return(PyBytes_FromStringAndSize(ws->size_fragment, + ws->continuation <= 0 ? 0 : ws->continuation)); +} + +PyObject * +ws_remaining_bytes(PyObject *self, void *closure) +{ + struct wirestate *ws; + ws = (struct wirestate *) self; + return(PyLong_FromLong( + ws->continuation == -1 ? ws->remaining_bytes : -1 + )); +} + +PyObject * +ws_final_view(PyObject *self, void *closure) +{ + struct wirestate *ws; + PyObject *rob; + + ws = (struct wirestate *) self; + rob = ws->final_view ? ws->final_view : Py_None; + + Py_INCREF(rob); + return(rob); +} + +static PyGetSetDef ws_getset[] = { + {"size_fragment", ws_size_fragment, NULL, + PyDoc_STR("The data acculumated for the continuation."), NULL,}, + {"remaining_bytes", ws_remaining_bytes, NULL, + PyDoc_STR("Number bytes necessary to complete the current message."), NULL,}, + {"final_view", ws_final_view, NULL, + PyDoc_STR("A memoryview of the data that triggered the CONDITION()."), NULL,}, + {NULL} +}; + +PyTypeObject WireState_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "postgresql.port.optimized.WireState", /* tp_name */ + sizeof(struct wirestate), /* tp_basicsize */ + 0, /* tp_itemsize */ + ws_dealloc, /* tp_dealloc */ + NULL, /* tp_print */ + NULL, /* tp_getattr */ + NULL, /* tp_setattr */ + NULL, /* tp_compare */ + NULL, /* tp_repr */ + NULL, /* tp_as_number */ + NULL, /* tp_as_sequence */ + NULL, /* tp_as_mapping */ + NULL, /* tp_hash */ + NULL, /* tp_call */ + NULL, /* tp_str */ + NULL, /* tp_getattro */ + NULL, /* tp_setattro */ + NULL, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + PyDoc_STR("Track the state of the wire."), + /* tp_doc */ + NULL, /* tp_traverse */ + NULL, /* tp_clear */ + NULL, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + NULL, /* tp_iter */ + NULL, /* tp_iternext */ + ws_methods, /* tp_methods */ + NULL, /* tp_members */ + ws_getset, /* tp_getset */ + NULL, /* tp_base */ + NULL, /* tp_dict */ + NULL, /* tp_descr_get */ + NULL, /* tp_descr_set */ + 0, /* tp_dictoffset */ + NULL, /* tp_init */ + NULL, /* tp_alloc */ + ws_new, /* tp_new */ + NULL, /* tp_free */ +}; diff --git a/py_opengauss/port/signal1_msw.py b/py_opengauss/port/signal1_msw.py new file mode 100644 index 0000000000000000000000000000000000000000..683a3a2ac9081bba9c8c8454e8be2d36a28a6f1b --- /dev/null +++ b/py_opengauss/port/signal1_msw.py @@ -0,0 +1,76 @@ +## +# .port.signal1_msw +## +""" +Support for PG signals on Windows platforms. + +This implementation supports all known versions of PostgreSQL. 
(2010) + +CallNamedPipe: + http://msdn.microsoft.com/en-us/library/aa365144%28VS.85%29.aspx +""" +import errno +from ctypes import windll, wintypes, pointer + +# CallNamedPipe from kernel32. +CallNamedPipeA = windll.kernel32.CallNamedPipeA +CallNamedPipeA.restype = wintypes.BOOL +CallNamedPipeA.argtypes = ( + wintypes.LPCSTR, # in namedpipename + wintypes.LPVOID, # in inbuffer (for signal number) + wintypes.DWORD, # in inbuffersize (always 1) + wintypes.LPVOID, # OutBuffer (signal return validation) + wintypes.DWORD, # in OutBufferSize (always 1) + wintypes.LPVOID, # out bytes read, really LPDWORD. + wintypes.DWORD, # in timeout +) + +from signal import SIGTERM, SIGINT, SIG_DFL +# SYNC: Values taken from the port/win32.h file. +SIG_DFL=0 +SIGHUP=1 +SIGQUIT=3 +SIGTRAP=5 +SIGABRT=22 # /* Set to match W32 value -- not UNIX value */ +SIGKILL=9 +SIGPIPE=13 +SIGALRM=14 +SIGSTOP=17 +SIGTSTP=18 +SIGCONT=19 +SIGCHLD=20 +SIGTTIN=21 +SIGTTOU=22 # /* Same as SIGABRT -- no problem, I hope */ +SIGWINCH=28 +SIGUSR1=30 +SIGUSR2=31 + +# SYNC: port.h +PG_SIGNAL_COUNT = 32 + +# In the situation of another variant, another module should be constructed. +def kill(pid : int, signal : int, timeout = 1000, dword1 = wintypes.DWORD(1)): + """ + Re-implementation of pg_kill for win32 using ctypes. + """ + if pid <= 0: + raise OSError(errno.EINVAL, "process group not supported") + if signal < 0 or signal >= PG_SIGNAL_COUNT: + raise OSError(errno.EINVAL, "unsupported signal number") + inbuffer = pointer(wintypes.BYTE(signal)) + outbuffer = pointer(wintypes.BYTE(0)) + outbytes = pointer(wintypes.DWORD(0)) + pidpipe = br'\\.\pipe\pgsignal_' + str(pid).encode('ascii') + timeout = wintypes.DWORD(timeout) + r = CallNamedPipeA( + pidpipe, inbuffer, dword1, outbuffer, dword1, outbytes, timeout + ) + if r: + if outbuffer.contents.value == signal: + if outbytes.contents.value == 1: + # success + return + # Don't bother emulating the other failure cases/abstractions. + # CallNamedPipeA should raise a WindowsError on those failures. + raise OSError(errno.ESRCH, "unexpected output from CallNamedPipeA") +__docformat__ = 'reStructuredText' diff --git a/py_opengauss/project.py b/py_opengauss/project.py new file mode 100644 index 0000000000000000000000000000000000000000..40f00ca7c0b3d90d92bb9e812b9723f067f1096f --- /dev/null +++ b/py_opengauss/project.py @@ -0,0 +1,12 @@ +""" +Project information. +""" + +name = 'py-opengauss' +identity = 'http://github.com/vimiix/py-opengauss' + +meaculpa = 'Python+openGauss' +abstract = 'Driver and tools library for openGauss' + +version_info = (1, 3, 1) # dev based on py-postgresql version 1.3.0 +version = '.'.join(map(str, version_info)) diff --git a/py_opengauss/protocol/__init__.py b/py_opengauss/protocol/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3550870a9a1583a814e526f3b1a7c1bbb935f8d3 --- /dev/null +++ b/py_opengauss/protocol/__init__.py @@ -0,0 +1,6 @@ +## +# .protocol +## +""" +PQ protocol facilities +""" diff --git a/py_opengauss/protocol/buffer.py b/py_opengauss/protocol/buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..e491928c02d666f4a5bac02211ccdd6709eba070 --- /dev/null +++ b/py_opengauss/protocol/buffer.py @@ -0,0 +1,16 @@ +## +# .protocol.buffer +## +""" +This is an abstraction module that provides the working buffer implementation. +If a C compiler is not available on the system that built the package, the slower +`postgresql.protocol.pbuffer` module can be used in +`postgresql.port.optimized.buffer`'s absence. 
+ +This provides a convenient place to import the necessary module without +concerning the local code with the details. +""" +try: + from ..port.optimized import pq_message_stream +except ImportError: + from .pbuffer import pq_message_stream diff --git a/py_opengauss/protocol/client3.py b/py_opengauss/protocol/client3.py new file mode 100644 index 0000000000000000000000000000000000000000..f7e21750b9030b770547d9dbf7a67fe3264c11e0 --- /dev/null +++ b/py_opengauss/protocol/client3.py @@ -0,0 +1,558 @@ +## +# .protocol.client3 +## +""" +Protocol version 3.0 client and tools. +""" +import os +import weakref +from .buffer import pq_message_stream +from . import element3 as element +from . import xact3 as xact + +__all__ = ('Connection',) + +client_detected_protocol_error = element.ClientError(( + (b'S', 'FATAL'), + (b'C', '08P01'), + (b'M', "wire-data caused exception in protocol transaction"), + (b'H', "Protocol error detected."), +)) + +client_connect_timeout = element.ClientError(( + (b'S', 'FATAL'), + (b'C', '--TOE'), + (b'M', "connect timed out"), +)) + +not_pq_error = element.ClientError(( + # ProtocolError + (b'S', 'FATAL'), + (b'C', '08P01'), + (b'M', 'server did not support SSL negotiation'), + (b'H', 'The server is probably not PostgreSQL.'), +)) + +no_ssl_error = element.ClientError(( + (b'S', 'FATAL'), + # InsecurityError + (b'C', '--SEC'), + (b'M', 'SSL was required, and the server could not accommodate'), +)) + +# Details in __context__ +ssl_failed_error = element.ClientError(( + (b'S', 'FATAL'), + # InsecurityError + (b'C', '--SEC'), + (b'M', 'SSL negotiation caused exception'), +)) + +# failed to complete the connection, but no error set. +# indicates a programmer error. +partial_connection_error = element.ClientError(( + (b'S', 'FATAL'), + (b'C', '--XXX'), + (b'M', "failed to complete negotiation"), + (b'H', "Negotiation failed to completed, but no " \ + "error was attributed on the connection."), +)) + +eof_error = element.ClientError(( + (b'S', 'FATAL'), + (b'C', '08006'), + (b'M', 'unexpected EOF from server'), + (b'D', "Zero-length read from the connection's socket."), +)) + +class Connection(object): + """ + A PQv3 connection. + + Operations are designed to not raise exceptions. The user of the + connection must check for failures. This is done to encourage users + to use their own Exception hierarchy. + """ + _tracer = None + def tracer(): + def fget(self): + return self._tracer + def fset(self, value): + self._tracer = value + self.write_messages = self.traced_write_messages + self.read_messages = self.traced_read_messages + def fdel(self): + del self._tracer + self.write_messages = self.standard_write_messages + self.read_messages = self.standard_read_messages + doc = 'Callable object to pass protocol trace strings to. '\ + '(Normally a write method.)' + return locals() + tracer = property(**tracer()) + + def synchronize(self): + """ + Explicitly send a Synchronize message to the backend. + Useful for forcing the completion of lazily processed transactions. + + NOTE: This will not cause trash to be taken out. + """ + if self.xact is not None: + self.complete() + x = xact.Instruction((element.SynchronizeMessage,)) + self.xact = x + self.complete() + + def interrupt(self, timeout = None): + cq = element.CancelRequest(self.backend_id, self.key).bytes() + s = self.socket_factory(timeout = timeout) + try: + s.sendall(cq) + finally: + s.close() + + def connect(self, ssl = None, timeout = None): + """ + Establish the connection to the server. 
+ + If `ssl` is None, the socket will not be secured. + If `ssl` is True, the socket will be secured, but it will + close the connection and return if SSL is not available. + If `ssl` is False, the socket will attempt to be secured, but + will continue even in the event of a server that does not + support SSL. + + `timeout` will be passed directly to the configured `socket_factory`. + """ + if hasattr(self, 'socket'): + # If there's a socket attribute it normally means + # that the connection has already been connected. + # Successfully or not; doesn't matter. + return + + # The existence of the socket attribute indicates an attempt was made. + self.socket = None + try: + self.socket = self.socket_factory(timeout = timeout) + except ( + self.socket_factory.timeout_exception, + self.socket_factory.fatal_exception + ) as err: + self.xact.state = xact.Complete + self.xact.fatal = True + self.xact.exception = err + if self.socket_factory.timed_out(err): + self.xact.error_message = client_connect_timeout + else: + errmsg = self.socket_factory.fatal_exception_message(err) + # It's an error that occurred during socket creation/connection. + # Even if there isn't a known fatal message, + # identify it as fatal and set an ambiguous message. + self.xact.error_message = element.ClientError(( + (b'S', 'FATAL'), + # ConnectionRejectionError + (b'C', '08004'), + (b'M', errmsg or "could not connect"), + )) + return + + if ssl is not None: + # if ssl is True, ssl is *required* + # if ssl is False, ssl will be tried, but not required + # if ssl is None, no SSL negotiation will happen + self.ssl_negotiation = supported = self.negotiate_ssl() + + # b'S' or b'N' was *not* received. + if supported is None: + # probably not PQv3.. + self.xact.fatal = True + self.xact.error_message = not_pq_error + self.xact.state = xact.Complete + return + + # b'N' was received, but ssl is required. + if not supported and ssl is True: + # ssl is required.. + self.xact.fatal = True + self.xact.error_message = no_ssl_error + self.xact.state = xact.Complete + return + + if supported: + # Make an SSL connection. + try: + self.socket = self.socket_factory.secure(self.socket) + except Exception as err: + # Any exception marks a failure. + self.xact.exception = err + self.xact.fatal = True + self.xact.state = xact.Complete + self.xact.error_message = ssl_failed_error + return + # time to negotiate + negxact = self.xact + self.complete() + if negxact.state is xact.Complete and negxact.fatal is None: + self.key = negxact.killinfo.key + self.backend_id = negxact.killinfo.pid + elif not hasattr(self.xact, 'error_message'): + # if it's not complete, something strange happened. + # make sure to clean up... + self.xact.fatal = True + self.xact.state = xact.Complete + self.xact.error_message = partial_connection_error + + def negotiate_ssl(self) -> (bool, None): + """ + Negotiate SSL + + If SSL is available--received b'S'--return True. + If SSL is unavailable--received b'N'--return False. + Otherwise, return None. Indicates non-PQv3 endpoint. + """ + r = element.NegotiateSSLMessage.bytes() + while r: + r = r[self.socket.send(r):] + status = self.socket.recv(1) + if status == b'S': + return True + elif status == b'N': + return False + # probably not postgresql. + return None + + def read_into(self, Complete = xact.Complete): + """ + read data from the wire and write it into the message buffer. 
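+
+		Returns True once a complete message is available in the buffer;
+		returns False when the read failed fatally (EOF or a fatal socket
+		error), in which case the current transaction is marked complete
+		and fatal.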
+ """ + BUFFER_HAS_MSG = self.message_buffer.has_message + BUFFER_WRITE_MSG = self.message_buffer.write + RECV_DATA = self.socket.recv + RECV_BYTES = self.recvsize + XACT = self.xact + while not BUFFER_HAS_MSG(): + if self.read_data is not None: + BUFFER_WRITE_MSG(self.read_data) + self.read_data = None + # If the read_data satisfied a message, + # no more data should be read. + continue + try: + self.read_data = RECV_DATA(RECV_BYTES) + except self.socket_factory.fatal_exception as e: + msg = self.socket_factory.fatal_exception_message(e) + if msg is not None: + XACT.state = Complete + XACT.fatal = True + XACT.exception = e + XACT.error_message = element.ClientError(( + (b'S', 'FATAL'), + (b'C', '08006'), + (b'M', msg), + )) + return False + else: + # It's probably a non-fatal error, + # timeout or try again.. + raise + + ## + # nothing read from a blocking socket? it's over. + if self.read_data == b'': + XACT.state = Complete + XACT.fatal = True + XACT.error_message = eof_error + return False + + # Got data. Put it in the buffer and clear read_data. + self.read_data = BUFFER_WRITE_MSG(self.read_data) + return True + + def standard_read_messages(self): + """ + Read more messages into self.read when self.read is empty. + """ + r = True + if not self.read: + # get more data from the wire and + # write it into the message buffer. + r = self.read_into() + self.read = self.message_buffer.read() + return r + read_messages = standard_read_messages + + def send_message_data(self): + """ + send all `message_data`. + + If an exception occurs, it will check if the exception + is fatal or not. + """ + SEND_DATA = self.socket.send + try: + while self.message_data: + # Send data while there is data to send. + self.message_data = self.message_data[ + SEND_DATA(self.message_data): + ] + except self.socket_factory.fatal_exception as e: + msg = self.socket_factory.fatal_exception_message(e) + if msg is not None: + # it's fatal. + self.xact.state = xact.Complete + self.xact.fatal = True + self.xact.exception = e + self.xact.error_message = element.ClientError(( + (b'S', 'FATAL'), + (b'C', '08006'), + (b'M', msg), + )) + return False + else: + # It wasn't fatal, so just raise + raise + return True + + def standard_write_messages(self, messages, + cat_messages = element.cat_messages + ): + """ + Protocol message writer. + """ + if self.writing is not self.written: + self.message_data += cat_messages(self.writing) + self.written = self.writing + + if messages is not self.writing: + self.writing = messages + self.message_data += cat_messages(self.writing) + self.written = self.writing + return self.send_message_data() + write_messages = standard_write_messages + + def traced_write_messages(self, messages): + """ + `message_writer` used when tracing. + """ + for msg in messages: + t = getattr(msg, 'type', None) + if t is not None: + data_out = msg.bytes() + self._tracer('↑ {type}({lend}): {data}{nl}'.format( + type = repr(t)[2:-1], + lend = len(data_out), + data = repr(data_out), + nl = os.linesep + )) + else: + # It's not a message instance, so assume raw data. + self._tracer('↑__(%d): %r%s' %( + len(msg), msg, os.linesep + )) + return self.standard_write_messages(messages) + + def traced_read_messages(self): + """ + `message_reader` used when tracing. + """ + r = self.standard_read_messages() + for msg in self.read: + self._tracer('↓ %r(%d): %r%s' %( + msg[0], len(msg[1]), msg[1], os.linesep) + ) + return r + + def take_out_trash(self): + """ + close cursors and statements slated for closure. 
+ """ + xm = [] + cursors = 0 + for x in self.garbage_cursors: + xm.append(element.ClosePortal(x)) + cursors += 1 + statements = 0 + for x in self.garbage_statements: + xm.append(element.CloseStatement(x)) + statements += 1 + xm.append(element.SynchronizeMessage) + x = xact.Instruction(xm) + self.xact = x + del self.garbage_cursors[:cursors] + del self.garbage_statements[:statements] + self.complete() + + def push(self, x): + """ + setup the given transaction to be processed. + """ + # Push any queued closures onto the transaction or a new transaction. + if x.state is xact.Complete: + # It's already complete. + return + if self.xact is not None: + self.complete() + if self.xact is None: + if self.garbage_statements or self.garbage_cursors: + # This *has* to be done before a new transaction + # is pushed. + self.take_out_trash() + if self.xact is None: + # set it as the current transaction and begin + self.xact = x + # start it up + self.step() + + def step(self): + """ + Make a single transition on the transaction. + + This should be used during COPY TO STDOUT or large result sets + to stream information out. + """ + x = self.xact + try: + dir, op = x.state + if dir is xact.Sending: + self.write_messages(x.messages) + # The "op" callable will either switch the state, or + # set the 'messages' attribute with a new sequence + # of message objects for more writing. + op() + elif dir is xact.Receiving: + self.read_messages() + self.read = self.read[op(self.read):] + self.state = getattr(x, 'last_ready', self.state) + else: + raise RuntimeError( + "unexpected PQ transaction state: " + repr(dir) + ) + except self.socket_factory.try_again_exception as e: + # Unlike _complete, this catches at the outermost level + # as there is no loop here for more transitioning. + if self.socket_factory.try_again(e): + # Can't read or write, ATM? Consider it a transition. :( + return + else: + raise + if x.state is xact.Complete and \ + getattr(self.xact, 'fatal', None) is not True: + # only remove the transaction if it's *not* fatal + self.xact = None + + def complete(self): + """ + Complete the current transaction. + """ + # Continue to transition until all transactions have been + # completed, or an exception occurs that does not signal retry. + x = self.xact + R = xact.Receiving + S = xact.Sending + C = xact.Complete + READ_MORE = self.read_messages + WRITE_MESSAGES = self.write_messages + while x.state is not C: + try: + while x.state[0] is R: + if READ_MORE(): + self.read = self.read[x.state[1](self.read):] + # push() always takes one step, so it is likely that + # the transaction is done sending out data by the time + # complete() is called. + while x.state[0] is S: + if WRITE_MESSAGES(x.messages): + x.state[1]() + # Multiple calls to get() without signaling + # completion *should* yield the same set over + # and over again. + except self.socket_factory.try_again_exception as e: + if not self.socket_factory.try_again(e): + raise + except Exception as proto_exc: + # If an exception is raised here, it's a protocol or a programming error. + # XXX: It may be useful to have this closer to the actual + # message so that a more informative message can be given. 
+ x.fatal = True + x.state = xact.Complete + x.exception = proto_exc + x.error_message = client_detected_protocol_error + self.state = b'' + return + self.state = getattr(x, 'last_ready', self.state) + if getattr(x, 'fatal', None) is not True: + # only remove the transaction if it's *not* fatal + self.xact = None + + def register_cursor(self, cursor, pq_cursor_id): + trash = self.trash_cursor + self.cursors[pq_cursor_id] = weakref.ref(cursor, lambda ref: trash(pq_cursor_id)) + + def trash_cursor(self, pq_cursor_id): + try: + del self.cursors[pq_cursor_id] + except KeyError: + pass + self.garbage_cursors.append(pq_cursor_id) + + def register_statement(self, statement, pq_statement_id): + trash = self.trash_statement + self.statements[pq_statement_id] = weakref.ref(statement, lambda ref: trash(pq_statement_id)) + + def trash_statement(self, pq_statement_id): + try: + del self.statements[pq_statement_id] + except KeyError: + pass + self.garbage_statements.append(pq_statement_id) + + def __str__(self): + if hasattr(self, 'ssl_negotiation'): + if self.ssl_negotiation is True: + ssl = 'SSL' + elif self.ssl_negotiation is False: + ssl = 'NOSSL after SSL' + else: + ssl = 'NOSSL' + + excstr = ''.join(self.exception_string(type(self.exception), self.exception)) + return str(self.socket_factory) \ + + ' -> (' + ssl + ')' \ + + os.linesep + excstr.strip() + + def __init__(self, socket_factory, startup, password = b'',): + """ + Create a connection. + + This does not establish the connection, it only initializes it. + """ + self.key = None + self.backend_id = None + + self.socket_factory = socket_factory + self.xact = xact.Negotiation( + element.Startup(startup), password + ) + + self.cursors = {} + self.statements = {} + + self.garbage_statements = [] + self.garbage_cursors = [] + + self.message_buffer = pq_message_stream() + self.recvsize = 8192 + + self.read = () + # bytes received. + self.read_data = None + + # serialized message data to be written + self.message_data = b'' + # messages to be written. + self.writing = None + # messages that have already been transformed into bytes. + # (used to detect whether messages have already been written) + self.written = None + + self.state = 'INITIALIZED' diff --git a/py_opengauss/protocol/element3.py b/py_opengauss/protocol/element3.py new file mode 100644 index 0000000000000000000000000000000000000000..326b75bc5fdf94af8d56c775a7c4b5b574928c53 --- /dev/null +++ b/py_opengauss/protocol/element3.py @@ -0,0 +1,985 @@ +## +# .protocol.element3 +## +""" +PQ version 3.0 elements. 
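+
+Each message class provides serialize()/parse() for its payload and bytes()
+for the full wire form. A small, illustrative doctest (the byte values follow
+directly from Message.bytes)::
+
+	>>> from py_opengauss.protocol import element3 as element
+	>>> element.Query(b'SELECT 1').bytes()
+	b'Q\x00\x00\x00\rSELECT 1\x00'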
+""" +import sys +import os +import pprint +from struct import unpack, Struct +from .message_types import message_types +from ..python.structlib import ushort_pack, ushort_unpack, ulong_pack, ulong_unpack + +try: + from ..port.optimized import parse_tuple_message, pack_tuple_data +except ImportError: + def pack_tuple_data(atts, + none = None, + ulong_pack = ulong_pack, + blen = bytes.__len__ + ): + return b''.join([ + b'\xff\xff\xff\xff' + if x is none + else (ulong_pack(blen(x)) + x) + for x in atts + ]) + +try: + from ..port.optimized import cat_messages +except ImportError: + from ..python.structlib import lH_pack, long_pack + # Special case tuple()'s + def _pack_tuple(t, + blen = bytes.__len__, + tlen = tuple.__len__, + pack_head = lH_pack, + ulong_pack = ulong_pack, + ptd = pack_tuple_data, + ): + # NOTE: duplicated from above + r = b''.join([ + b'\xff\xff\xff\xff' + if x is None + else (ulong_pack(blen(x)) + x) + for x in t + ]) + return pack_head((blen(r) + 6, tlen(t))) + r + + def cat_messages(messages, + lpack = long_pack, + blen = bytes.__len__, + tuple = tuple, + pack_tuple = _pack_tuple + ): + return b''.join([ + (x.bytes() if x.__class__ is not bytes else ( + b'd' + lpack(blen(x) + 4) + x + )) if x.__class__ is not tuple else ( + b'D' + pack_tuple(x) + ) for x in messages + ]) + del _pack_tuple, lH_pack, long_pack + +StringFormat = b'\x00\x00' +BinaryFormat = b'\x00\x01' + +class Message(object): + bytes_struct = Struct("!cL") + __slots__ = () + def __repr__(self): + return '%s.%s(%s)' %( + type(self).__module__, + type(self).__name__, + ', '.join([repr(getattr(self, x)) for x in self.__slots__]) + ) + + def __eq__(self, ob): + return isinstance(ob, type(self)) and self.type == ob.type and \ + not False in ( + getattr(self, x) == getattr(ob, x) + for x in self.__slots__ + ) + + def bytes(self): + data = self.serialize() + return self.bytes_struct.pack(self.type, len(data) + 4) + data + + @classmethod + def parse(typ, data): + return typ(data) + +class StringMessage(Message): + """ + A message based on a single string component. + """ + type = b'' + __slots__ = ('data',) + + def __repr__(self): + return '%s.%s(%s)' %( + type(self).__module__, + type(self).__name__, + repr(self.data), + ) + + def __getitem__(self, i): + return self.data.__getitem__(i) + + def __init__(self, data): + self.data = data + + def serialize(self): + return bytes(self.data) + b'\x00' + + @classmethod + def parse(typ, data): + if not data.endswith(b'\x00'): + raise ValueError("string message not NUL-terminated") + return typ(data[:-1]) + +class TupleMessage(tuple, Message): + """ + A message who's data is based on a tuple structure. + """ + type = b'' + __slots__ = () + + def __repr__(self): + return '%s.%s(%s)' %( + type(self).__module__, + type(self).__name__, + tuple.__repr__(self) + ) + +class Void(Message): + """ + An absolutely empty message. When serialized, it always yields an empty string. 
+ """ + type = b'' + __slots__ = () + + def bytes(self): + return b'' + + def serialize(self): + return b'' + + def __new__(typ, *args, **kw): + return VoidMessage +VoidMessage = Message.__new__(Void) + +def dict_message_repr(self): + return '%s.%s(**%s)' %( + type(self).__module__, + type(self).__name__, + pprint.pformat(dict(self)) + ) + +class WireMessage(Message): + def __init__(self, typ_data): + self.type = message_types[typ_data[0][0]] + self.data = typ_data[1] + + def serialize(self): + return self[1] + + @classmethod + def parse(typ, data): + if ulong_unpack(data[1:5]) != len(data) - 1: + raise ValueError( + "invalid wire message where data is %d bytes and " \ + "internal size stamp is %d bytes" %( + len(data), ulong_unpack(data[1:5]) + 1 + ) + ) + return typ((data[0:1], data[5:])) + +class EmptyMessage(Message): + """ + An abstract message that is always empty. + """ + __slots__ = () + type = b'' + + def __new__(typ): + return typ.SingleInstance + + def serialize(self): + return b'' + + @classmethod + def parse(typ, data): + if data != b'': + raise ValueError("empty message(%r) had data" %(typ.type,)) + return typ.SingleInstance + +class Notify(Message): + """ + Asynchronous notification message. + """ + type = message_types[b'A'[0]] + __slots__ = ('pid', 'channel', 'payload',) + + def __init__(self, pid, channel, payload = b''): + self.pid = pid + self.channel = channel + self.payload = payload + + def serialize(self): + return ulong_pack(self.pid) + \ + self.channel + b'\x00' + \ + self.payload + b'\x00' + + @classmethod + def parse(typ, data): + pid = ulong_unpack(data) + channel, payload, _ = data[4:].split(b'\x00', 2) + return typ(pid, channel, payload) + +class ShowOption(Message): + """ + GUC variable information from backend + """ + type = message_types[b'S'[0]] + __slots__ = ('name', 'value') + + def __init__(self, name, value): + self.name = name + self.value = value + + def serialize(self): + return self.name + b'\x00' + self.value + b'\x00' + + @classmethod + def parse(typ, data): + return typ(*(data.split(b'\x00', 2)[0:2])) + +class Complete(StringMessage): + """ + Command completion message. + """ + type = message_types[b'C'[0]] + __slots__ = () + + @classmethod + def parse(typ, data): + return typ(data.rstrip(b'\x00')) + + def extract_count(self): + """ + Extract the last set of digits as an integer. + """ + # Find the last sequence of digits. + # If there are no fields consisting only of digits, there is no count. + for x in reversed(self.data.split()): + if x.isdigit(): + return int(x) + return None + + def extract_command(self): + """ + Strip all the *surrounding* digits and spaces from the command tag, + and return that string. + """ + return self.data.strip(b'\c\n\t 0123456789') or None + +class Null(EmptyMessage): + """ + Null command. + """ + type = message_types[b'I'[0]] + __slots__ = () +NullMessage = Message.__new__(Null) +Null.SingleInstance = NullMessage + +class NoData(EmptyMessage): + """ + Null command. + """ + type = message_types[b'n'[0]] + __slots__ = () +NoDataMessage = Message.__new__(NoData) +NoData.SingleInstance = NoDataMessage + +class ParseComplete(EmptyMessage): + """ + Parse reaction. + """ + type = message_types[b'1'[0]] + __slots__ = () +ParseCompleteMessage = Message.__new__(ParseComplete) +ParseComplete.SingleInstance = ParseCompleteMessage + +class BindComplete(EmptyMessage): + """ + Bind reaction. 
+ """ + type = message_types[b'2'[0]] + __slots__ = () +BindCompleteMessage = Message.__new__(BindComplete) +BindComplete.SingleInstance = BindCompleteMessage + +class CloseComplete(EmptyMessage): + """ + Close statement or Portal. + """ + type = message_types[b'3'[0]] + __slots__ = () +CloseCompleteMessage = Message.__new__(CloseComplete) +CloseComplete.SingleInstance = CloseCompleteMessage + +class Suspension(EmptyMessage): + """ + Portal was suspended, more tuples for reading. + """ + type = message_types[b's'[0]] + __slots__ = () +SuspensionMessage = Message.__new__(Suspension) +Suspension.SingleInstance = SuspensionMessage + +class Ready(Message): + """ + Ready for new query message. + """ + type = message_types[b'Z'[0]] + possible_states = ( + message_types[b'I'[0]], + message_types[b'E'[0]], + message_types[b'T'[0]], + ) + __slots__ = ('xact_state',) + + def __init__(self, data): + if data not in self.possible_states: + raise ValueError("invalid state for Ready message: " + repr(data)) + self.xact_state = data + + def serialize(self): + return self.xact_state + +class Notice(Message, dict): + """ + Notification message. + + Used by PQ to emit INFO, NOTICE, and WARNING messages among other + severities. + """ + type = message_types[b'N'[0]] + __slots__ = () + __repr__ = dict_message_repr + + def serialize(self): + return b'\x00'.join([ + k + v for k, v in self.items() + if k and v is not None + ]) + b'\x00' + + @classmethod + def parse(typ, data, msgtypes = message_types): + return typ([ + (msgtypes[x[0]], x[1:]) + # "if x" reduce empty fields + for x in data.split(b'\x00') if x + ]) + +class ClientNotice(Notice): + __slots__ = () + + def serialize(self): + raise RuntimeError("cannot serialize ClientNotice") + + @classmethod + def parse(self): + raise RuntimeError("cannot parse ClientNotice") + +class Error(Notice): + """ + Error information message. + """ + type = message_types[b'E'[0]] + __slots__ = () + +class ClientError(Error): + __slots__ = () + + def serialize(self): + raise RuntimeError("cannot serialize ClientError") + + @classmethod + def parse(self): + raise RuntimeError("cannot serialize ClientError") + +class FunctionResult(Message): + """ + Function result value. + """ + type = message_types[b'V'[0]] + __slots__ = ('result',) + + def __init__(self, datum): + self.result = datum + + def serialize(self): + return self.result is None and b'\xff\xff\xff\xff' or \ + ulong_pack(len(self.result)) + self.result + + @classmethod + def parse(typ, data): + if data == b'\xff\xff\xff\xff': + return typ(None) + size = ulong_unpack(data[0:4]) + data = data[4:] + if size != len(data): + raise ValueError( + "data length(%d) is not equal to the specified message size(%d)" %( + len(data), size + ) + ) + return typ(data) + +class AttributeTypes(TupleMessage): + """ + Tuple attribute types. + """ + type = message_types[b't'[0]] + __slots__ = () + + def serialize(self): + return ushort_pack(len(self)) + b''.join([ulong_pack(x) for x in self]) + + @classmethod + def parse(typ, data): + ac = ushort_unpack(data[0:2]) + args = data[2:] + if len(args) != ac * 4: + raise ValueError("invalid argument type data size") + return typ(unpack('!%dL'%(ac,), args)) + +class TupleDescriptor(TupleMessage): + """ + Tuple structure description. 
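+
+	Each element is an attribute description of the form:
+	(name, relationId, columnNumber, typeId, typlen, typmod, format).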
+ """ + type = message_types[b'T'[0]] + struct = Struct("!LhLhlh") + __slots__ = () + + def keys(self): + return [x[0] for x in self] + + def serialize(self): + return ushort_pack(len(self)) + b''.join([ + x[0] + b'\x00' + self.struct.pack(*x[1:]) + for x in self + ]) + + @classmethod + def parse(typ, data): + ac = ushort_unpack(data[0:2]) + atts = [] + data = data[2:] + ca = 0 + while ca < ac: + # End Of Attribute Name + eoan = data.index(b'\x00') + name = data[0:eoan] + data = data[eoan+1:] + # name, relationId, columnNumber, typeId, typlen, typmod, format + atts.append((name,) + typ.struct.unpack(data[0:18])) + data = data[18:] + ca += 1 + return typ(atts) + +class Tuple(TupleMessage): + """ + Tuple Data. + """ + type = message_types[b'D'[0]] + __slots__ = () + + def serialize(self): + return ushort_pack(len(self)) + pack_tuple_data(self) + + @classmethod + def parse(typ, data, + T = tuple, ulong_unpack = ulong_unpack, + len = len + ): + natts = ushort_unpack(data[0:2]) + atts = [] + offset = 2 + add = atts.append + + while natts > 0: + alo = offset + offset += 4 + size = data[alo:offset] + if size == b'\xff\xff\xff\xff': + att = None + else: + al = ulong_unpack(size) + ao = offset + offset = ao + al + att = data[ao:offset] + add(att) + natts -= 1 + return T(atts) + try: + parse = parse_tuple_message + except NameError: + # This is an override when port.optimized is available. + pass + +class KillInformation(Message): + """ + Backend cancellation information. + """ + type = message_types[b'K'[0]] + struct = Struct("!LL") + __slots__ = ('pid', 'key') + + def __init__(self, pid, key): + self.pid = pid + self.key = key + + def serialize(self): + return self.struct.pack(self.pid, self.key) + + @classmethod + def parse(typ, data): + return typ(*typ.struct.unpack(data)) + +class CancelRequest(KillInformation): + """ + Abort the query in the specified backend. + """ + type = b'' + from .version import CancelRequestCode as version + packed_version = version.bytes() + __slots__ = ('pid', 'key') + + def serialize(self): + return self.packed_version + self.struct.pack( + self.pid, self.key + ) + + def bytes(self): + data = self.serialize() + return ulong_pack(len(data) + 4) + self.serialize() + + @classmethod + def parse(typ, data): + if data[0:4] != typ.packed_version: + raise ValueError("invalid cancel query code") + return typ(*typ.struct.unpack(data[4:])) + +class NegotiateSSL(Message): + """ + Discover backend's SSL support. + """ + type = b'' + from .version import NegotiateSSLCode as version + packed_version = version.bytes() + __slots__ = () + + def __new__(typ): + return NegotiateSSLMessage + + def bytes(self): + data = self.serialize() + return ulong_pack(len(data) + 4) + data + + def serialize(self): + return self.packed_version + + @classmethod + def parse(typ, data): + if data != typ.packed_version: + raise ValueError("invalid SSL Negotiation code") + return NegotiateSSLMessage +NegotiateSSLMessage = Message.__new__(NegotiateSSL) + +class Startup(Message, dict): + """ + Initiate a connection using the given keywords. 
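+
+	An illustrative construction (keys and values are bytes)::
+
+		Startup({b'user': b'og_user', b'database': b'postgres'})
+
+	serialize() emits the packed protocol version (V3_51 for openGauss)
+	followed by the NUL-terminated keyword/value pairs; bytes() additionally
+	prepends the total message length.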
+ """ + type = b'' + from py_opengauss.protocol.version import V3_51 as version + packed_version = version.bytes() + __slots__ = () + __repr__ = dict_message_repr + + def serialize(self): + return self.packed_version + b''.join([ + k + b'\x00' + v + b'\x00' + for k, v in self.items() + if v is not None + ]) + b'\x00' + + def bytes(self): + data = self.serialize() + return ulong_pack(len(data) + 4) + data + + @classmethod + def parse(typ, data): + if data[0:4] != typ.packed_version: + raise ValueError("invalid version code {1}".format(repr(data[0:4]))) + kw = dict() + key = None + for value in data[4:].split(b'\x00')[:-2]: + if key is None: + key = value + continue + kw[key] = value + key = None + return typ(kw) + +AuthRequest_OK = 0 +AuthRequest_Cleartext = 3 +AuthRequest_Password = AuthRequest_Cleartext +AuthRequest_Crypt = 4 +AuthRequest_MD5 = 5 +AuthRequest_SHA256 = 10 # implementation for opengauss + +# Unsupported by pg_protocol. +AuthRequest_KRB4 = 1 +AuthRequest_KRB5 = 2 +AuthRequest_SCMC = 6 +AuthRequest_SSPI = 9 +AuthRequest_GSS = 7 +AuthRequest_GSSContinue = 8 + +AuthNameMap = { + AuthRequest_Password : 'Cleartext', + AuthRequest_Crypt : 'Crypt', + AuthRequest_MD5 : 'MD5', + + AuthRequest_KRB4 : 'Kerberos4', + AuthRequest_KRB5 : 'Kerberos5', + AuthRequest_SCMC : 'SCM Credential', + AuthRequest_SSPI : 'SSPI', + AuthRequest_GSS : 'GSS', + AuthRequest_GSSContinue : 'GSSContinue', + AuthRequest_SHA256: 'SHA256', +} + +class Authentication(Message): + """ + Authentication(request, salt) + """ + type = message_types[b'R'[0]] + __slots__ = ('request', 'salt') + + def __init__(self, request, salt): + self.request = request + self.salt = salt + + def serialize(self): + return ulong_pack(self.request) + self.salt + + @classmethod + def parse(typ, data): + return typ(ulong_unpack(data[0:4]), data[4:]) + +class Password(StringMessage): + """ + Password supplement. + """ + type = message_types[b'p'[0]] + __slots__ = ('data',) + +class Disconnect(EmptyMessage): + """ + Connection closed message. + """ + type = message_types[b'X'[0]] + __slots__ = () +DisconnectMessage = Message.__new__(Disconnect) +Disconnect.SingleInstance = DisconnectMessage + +class Flush(EmptyMessage): + """ + Flush message. + """ + type = message_types[b'H'[0]] + __slots__ = () +FlushMessage = Message.__new__(Flush) +Flush.SingleInstance = FlushMessage + +class Synchronize(EmptyMessage): + """ + Synchronize. + """ + type = message_types[b'S'[0]] + __slots__ = () +SynchronizeMessage = Message.__new__(Synchronize) +Synchronize.SingleInstance = SynchronizeMessage + +class Query(StringMessage): + """ + Execute the query with the given arguments. + """ + type = message_types[b'Q'[0]] + __slots__ = ('data',) + +class Parse(Message): + """ + Parse a query with the specified argument types. 
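+
+	Parse(name, statement, argtypes), where argtypes is a sequence of
+	parameter type OIDs. For example (the OID shown is illustrative)::
+
+		Parse(b'stmt_1', b'SELECT $1', (25,))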
+ """ + type = message_types[b'P'[0]] + __slots__ = ('name', 'statement', 'argtypes') + + def __init__(self, name, statement, argtypes): + self.name = name + self.statement = statement + self.argtypes = argtypes + + @classmethod + def parse(typ, data): + name, statement, args = data.split(b'\x00', 2) + ac = ushort_unpack(args[0:2]) + args = args[2:] + if len(args) != ac * 4: + raise ValueError("invalid argument type data") + at = unpack('!%dL'%(ac,), args) + return typ(name, statement, at) + + def serialize(self): + ac = ushort_pack(len(self.argtypes)) + return self.name + b'\x00' + self.statement + b'\x00' + ac + b''.join([ + ulong_pack(x) for x in self.argtypes + ]) + +class Bind(Message): + """ + Bind a parsed statement with the given arguments to a Portal + + Bind( + name, # Portal/Cursor identifier + statement, # Prepared Statement name/identifier + aformats, # Argument formats; Sequence of BinaryFormat or StringFormat. + arguments, # Argument data; Sequence of None or argument data(str). + rformats, # Result formats; Sequence of BinaryFormat or StringFormat. + ) + """ + type = message_types[b'B'[0]] + __slots__ = ('name', 'statement', 'aformats', 'arguments', 'rformats') + + def __init__(self, name, statement, aformats, arguments, rformats): + self.name = name + self.statement = statement + self.aformats = aformats + self.arguments = arguments + self.rformats = rformats + + def serialize(self, len = len): + args = self.arguments + ac = ushort_pack(len(args)) + ad = pack_tuple_data(tuple(args)) + return \ + self.name + b'\x00' + self.statement + b'\x00' + \ + ac + b''.join(self.aformats) + ac + ad + \ + ushort_pack(len(self.rformats)) + b''.join(self.rformats) + + @classmethod + def parse(typ, message_data): + name, statement, data = message_data.split(b'\x00', 2) + ac = ushort_unpack(data[:2]) + offset = 2 + (2 * ac) + aformats = unpack(("2s" * ac), data[2:offset]) + + natts = ushort_unpack(data[offset:offset+2]) + args = list() + offset += 2 + + while natts > 0: + alo = offset + offset += 4 + size = data[alo:offset] + if size == b'\xff\xff\xff\xff': + att = None + else: + al = ulong_unpack(size) + ao = offset + offset = ao + al + att = data[ao:offset] + args.append(att) + natts -= 1 + + rfc = ushort_unpack(data[offset:offset+2]) + ao = offset + 2 + offset = ao + (2 * rfc) + rformats = unpack(("2s" * rfc), data[ao:offset]) + + return typ(name, statement, aformats, args, rformats) + +class Execute(Message): + """ + Fetch results from the specified Portal. + """ + type = message_types[b'E'[0]] + __slots__ = ('name', 'max') + + def __init__(self, name, max = 0): + self.name = name + self.max = max + + def serialize(self): + return self.name + b'\x00' + ulong_pack(self.max) + + @classmethod + def parse(typ, data): + name, max = data.split(b'\x00', 1) + return typ(name, ulong_unpack(max)) + +class Describe(StringMessage): + """ + Request a description of a Portal or Prepared Statement. 
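+
+	Use the DescribeStatement or DescribePortal subclass, which supplies the
+	required subtype byte; for example::
+
+		DescribeStatement(b'stmt_1')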
+ """ + type = message_types[b'D'[0]] + __slots__ = ('data',) + + def serialize(self): + return self.subtype + self.data + b'\x00' + + @classmethod + def parse(typ, data): + if data[0:1] != typ.subtype: + raise ValueError( + "invalid Describe message subtype, %r; expected %r" %( + typ.subtype, data[0:1] + ) + ) + return super().parse(data[1:]) + +class DescribeStatement(Describe): + subtype = message_types[b'S'[0]] + __slots__ = ('data',) + +class DescribePortal(Describe): + subtype = message_types[b'P'[0]] + __slots__ = ('data',) + +class Close(StringMessage): + """ + Generic Close + """ + type = message_types[b'C'[0]] + __slots__ = () + + def serialize(self): + return self.subtype + self.data + b'\x00' + + @classmethod + def parse(typ, data): + if data[0:1] != typ.subtype: + raise ValueError( + "invalid Close message subtype, %r; expected %r" %( + typ.subtype, data[0:1] + ) + ) + return super().parse(data[1:]) + +class CloseStatement(Close): + """ + Close the specified Statement + """ + subtype = message_types[b'S'[0]] + __slots__ = () + +class ClosePortal(Close): + """ + Close the specified Portal + """ + subtype = message_types[b'P'[0]] + __slots__ = () + +class Function(Message): + """ + Execute the specified function with the given arguments + """ + type = message_types[b'F'[0]] + __slots__ = ('oid', 'aformats', 'arguments', 'rformat') + + def __init__(self, oid, aformats, args, rformat): + self.oid = oid + self.aformats = aformats + self.arguments = args + self.rformat = rformat + + def serialize(self): + ac = ushort_pack(len(self.arguments)) + return ulong_pack(self.oid) + \ + ac + b''.join(self.aformats) + \ + ac + pack_tuple_data(tuple(self.arguments)) + self.rformat + + @classmethod + def parse(typ, data): + oid = ulong_unpack(data[0:4]) + + ac = ushort_unpack(data[4:6]) + offset = 6 + (2 * ac) + aformats = unpack(("2s" * ac), data[6:offset]) + + natts = ushort_unpack(data[offset:offset+2]) + args = list() + offset += 2 + + while natts > 0: + alo = offset + offset += 4 + size = data[alo:offset] + if size == b'\xff\xff\xff\xff': + att = None + else: + al = ulong_unpack(size) + ao = offset + offset = ao + al + att = data[ao:offset] + args.append(att) + natts -= 1 + + return typ(oid, aformats, args, data[offset:]) + +class CopyBegin(Message): + type = None + struct = Struct("!BH") + __slots__ = ('format', 'formats') + + def __init__(self, format, formats): + self.format = format + self.formats = formats + + def serialize(self): + return self.struct.pack(self.format, len(self.formats)) + b''.join([ + ushort_pack(x) for x in self.formats + ]) + + @classmethod + def parse(typ, data): + format, natts = typ.struct.unpack(data[:3]) + formats_str = data[3:] + if len(formats_str) != natts * 2: + raise ValueError("number of formats and data do not match up") + return typ(format, [ + ushort_unpack(formats_str[x:x+2]) for x in range(0, natts * 2, 2) + ]) + +class CopyToBegin(CopyBegin): + """ + Begin copying to. + """ + type = message_types[b'H'[0]] + __slots__ = ('format', 'formats') + +class CopyFromBegin(CopyBegin): + """ + Begin copying from. 
+ """ + type = message_types[b'G'[0]] + __slots__ = ('format', 'formats') + +class CopyData(Message): + type = message_types[b'd'[0]] + __slots__ = ('data',) + + def __init__(self, data): + self.data = bytes(data) + + def serialize(self): + return self.data + + @classmethod + def parse(typ, data): + return typ(data) + +class CopyFail(StringMessage): + type = message_types[b'f'[0]] + __slots__ = ('data',) + +class CopyDone(EmptyMessage): + type = message_types[b'c'[0]] + __slots__ = ('data',) +CopyDoneMessage = Message.__new__(CopyDone) +CopyDone.SingleInstance = CopyDoneMessage diff --git a/py_opengauss/protocol/message_types.py b/py_opengauss/protocol/message_types.py new file mode 100644 index 0000000000000000000000000000000000000000..24a7646dc05b04e2133f7c37917706adcd467078 --- /dev/null +++ b/py_opengauss/protocol/message_types.py @@ -0,0 +1,15 @@ +## +# .protocol.message_types +## +""" +Data module providing a sequence of bytes objects whose value corresponds to its +index in the sequence. + +This provides resource for buffer objects to use common message type objects. + +WARNING: It's tempting to use the 'is' operator and in some circumstances that +may be okay. However, it's possible (sys.modules.clear()) for the extension +modules' copy of this to become inconsistent with what protocol.element3 and +protocol.xact3 are using, so it's important to **not** use 'is'. +""" +message_types = tuple([bytes((x,)) for x in range(256)]) diff --git a/py_opengauss/protocol/pbuffer.py b/py_opengauss/protocol/pbuffer.py new file mode 100644 index 0000000000000000000000000000000000000000..d41a79e06cc879d3d67c917b009cd55b301b7b1f --- /dev/null +++ b/py_opengauss/protocol/pbuffer.py @@ -0,0 +1,182 @@ +## +# .protocol.pbuffer +## +""" +Pure Python message buffer implementation. + +Given data read from the wire, buffer the data until a complete message has been +received. +""" +__all__ = ['pq_message_stream'] + +from io import BytesIO +import struct +from .message_types import message_types + +xl_unpack = struct.Struct('!xL').unpack_from + +class pq_message_stream(object): + """ + Provide a message stream from a data stream. + """ + + _block = 512 + _limit = _block * 4 + def __init__(self): + self._strio = BytesIO() + self._start = 0 + + def truncate(self): + """ + Remove all data in the buffer. + """ + + self._strio.truncate(0) + self._start = 0 + + def _rtruncate(self, amt = None): + """ + [internal] remove the given amount of data. + """ + + strio = self._strio + if amt is None: + amt = self._strio.tell() + strio.seek(0, 2) + size = strio.tell() + # if the total size is equal to the amt, + # then the whole thing is going to be truncated. + if size == amt: + strio.truncate(0) + return + + copyto_pos = 0 + copyfrom_pos = amt + while True: + strio.seek(copyfrom_pos) + data = strio.read(self._block) + # Next copyfrom + copyfrom_pos = strio.tell() + strio.seek(copyto_pos) + strio.write(data) + if len(data) != self._block: + break + # Next copyto + copyto_pos = strio.tell() + + strio.truncate(size - amt) + + def has_message(self, xl_unpack = xl_unpack, len = len): + """ + Whether the buffer has a message available. + """ + + strio = self._strio + strio.seek(self._start) + header = strio.read(5) + if len(header) < 5: + return False + length, = xl_unpack(header) + if length < 4: + raise ValueError("invalid message size '%d'" %(length,)) + strio.seek(0, 2) + return (strio.tell() - self._start) >= length + 1 + + def __len__(self, xl_unpack = xl_unpack, len = len): + """ + Number of messages in buffer. 
+ """ + + count = 0 + rpos = self._start + strio = self._strio + strio.seek(self._start) + while True: + # get the message metadata + header = strio.read(5) + rpos += 5 + if len(header) < 5: + # not enough data for another message + break + # unpack the length from the header + length, = xl_unpack(header) + rpos += length - 4 + + if length < 4: + raise ValueError("invalid message size '%d'" %(length,)) + strio.seek(length - 4 - 1, 1) + + if len(strio.read(1)) != 1: + break + count += 1 + return count + + def _get_message(self, + mtypes = message_types, + len = len, + xl_unpack = xl_unpack, + ): + strio = self._strio + header = strio.read(5) + if len(header) < 5: + return + length, = xl_unpack(header) + typ = mtypes[header[0]] + + if length < 4: + raise ValueError("invalid message size '%d'" %(length,)) + length -= 4 + body = strio.read(length) + if len(body) < length: + # Not enough data for message. + return + return (typ, body) + + def next_message(self): + if self._start > self._limit: + self._rtruncate(self._start) + self._start = 0 + + self._strio.seek(self._start) + msg = self._get_message() + if msg is not None: + self._start = self._strio.tell() + return msg + + def __next__(self): + if self._start > self._limit: + self._rtruncate(self._start) + self._start = 0 + + self._strio.seek(self._start) + msg = self._get_message() + if msg is None: + raise StopIteration + self._start = self._strio.tell() + return msg + + def read(self, num = 0xFFFFFFFF, len = len): + if self._start > self._limit: + self._rtruncate(self._start) + self._start = 0 + + new_start = self._start + self._strio.seek(new_start) + l = [] + while len(l) < num: + msg = self._get_message() + if msg is None: + break + l.append(msg) + new_start += (5 + len(msg[1])) + self._start = new_start + return l + + def write(self, data): + # Always append data; it's a stream, damnit.. + self._strio.seek(0, 2) + self._strio.write(data) + + def getvalue(self): + self._strio.seek(self._start) + return self._strio.read() diff --git a/py_opengauss/protocol/version.py b/py_opengauss/protocol/version.py new file mode 100644 index 0000000000000000000000000000000000000000..deb5b568a24b97b38e1ff5c1bf471157eabd5626 --- /dev/null +++ b/py_opengauss/protocol/version.py @@ -0,0 +1,48 @@ +## +# .protocol.version +## +""" +PQ version class used by startup messages. +""" +from struct import Struct +version_struct = Struct('!HH') + +class Version(tuple): + """ + Version((major, minor)) -> Version + + Version serializer and parser. + """ + major = property(fget = lambda s: s[0]) + minor = property(fget = lambda s: s[1]) + + def __new__(subtype, major_minor): + (major, minor) = major_minor + major = int(major) + minor = int(minor) + # If it can't be packed like this, it's not a valid version. 
+ try: + version_struct.pack(major, minor) + except Exception as e: + raise ValueError("unpackable major and minor") from e + + return tuple.__new__(subtype, (major, minor)) + + def __int__(self): + return (self[0] << 16) | self[1] + + def bytes(self): + return version_struct.pack(self[0], self[1]) + + def __repr__(self): + return '%d.%d' %(self[0], self[1]) + + def parse(self, data): + return self(version_struct.unpack(data)) + parse = classmethod(parse) + +CancelRequestCode = Version((1234, 5678)) +NegotiateSSLCode = Version((1234, 5679)) +V2_0 = Version((2, 0)) +V3_0 = Version((3, 0)) +V3_51 = Version((3, 51)) # add for openGauss diff --git a/py_opengauss/protocol/xact3.py b/py_opengauss/protocol/xact3.py new file mode 100644 index 0000000000000000000000000000000000000000..c1ca11a45ce42e1bd7a395849043ed4a94459f06 --- /dev/null +++ b/py_opengauss/protocol/xact3.py @@ -0,0 +1,718 @@ +## +# .protocol.xact3 - protocol state machine +## +""" +PQ version 3.0 client transactions. +""" +import sys +import os +import pprint +from abc import ABCMeta, abstractmethod +from itertools import chain +from operator import itemgetter +get0 = itemgetter(0) +get1 = itemgetter(1) + +from ..python.functools import Composition as compose +from . import element3 as element + +from hashlib import md5 +from ..resolved.crypt import crypt +from ..resolved.opengauss import sha256_pw + +try: + from ..port.optimized import consume_tuple_messages +except ImportError: + pass + +Receiving = True +Sending = False +Complete = (None, None) + +AsynchronousMap = { + element.Notice.type : element.Notice.parse, + element.Notify.type : element.Notify.parse, + element.ShowOption.type : element.ShowOption.parse, +} + +def return_arg(x): + return x + +message_expectation = \ + "expected message of types {expected}, " \ + "but received {received} instead".format + +class Transaction(object, metaclass = ABCMeta): + """ + If the fatal attribute is not None, an error occurred, and the + `error_message` attribute should be set to a element3.Error instance. + """ + fatal = None + + @abstractmethod + def messages_received(self): + """ + Return an iterable to the messages received that have been processed. + """ + +class Closing(Transaction): + """ + Send the disconnect message and mark the connection as closed. + """ + error_message = element.ClientError(( + (b'S', 'FATAL'), + # pg_exc.ConnectionDoesNotExistError.code + (b'C', '08003'), + (b'M', 'operation on closed connection'), + (b'H', "A new connection needs to be "\ + "created in order to query the server."), + )) + + def messages_received(self): + return () + + def sent(self): + """ + Empty messages and mark complete. + """ + self.messages = () + self.fatal = True + self.state = Complete + + def __init__(self): + self.messages = (element.DisconnectMessage,) + self.state = (Sending, self.sent) + +class Negotiation(Transaction): + """ + Negotiation is a protocol transaction used to manage the initial stage of a + connection to PostgreSQL. + + This transaction revolves around the `state_machine` method which is a + generator that takes individual messages and progresses the state of the + connection negotiation. This was chosen over the route taken by + `Transaction`, seen later, as it's not terribly performance intensive and + there are many conditions which make a generator ideal for managing the + state. 
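+
+	In brief, the expected flow is: send the Startup message, answer any
+	Authentication request (cleartext, crypt, MD5, or SHA256 for openGauss),
+	then collect KillInformation and the final Ready message.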
+ """ + state = None + + def __init__(self, startup_message, password): + self.startup_message = startup_message + self.password = password + self.received = [()] + self.asyncs = [] + self.authtype = None + self.killinfo = None + self.authok = None + self.last_ready = None + self.machine = self.state_machine() + self.messages = next(self.machine) + self.state = (Sending, self.sent) + + def __repr__(self): + s = type(self).__module__ + "." + type(self).__name__ + s += pprint.pformat((self.startup_message, self.password)).lstrip() + return s + + def messages_received(self): + return self.processed + + def sent(self): + """ + Empty messages and switch state to receiving. + + This is called by the user after the `messages` have been sent to the + remote end. That is, this merely finalizes the "Sending" state. + """ + self.messages = () + self.state = (Receiving, self.put_messages) + + def put_messages(self, messages): + # Record everything received. + out_messages = () + if messages is not self.received[-1]: + self.received.append(messages) + else: + raise RuntimeError("negotiation was interrupted") + + # if an Error message was found, complete and leave. + count = 0 + try: + for x in messages: + count += 1 + if x[0] == element.Error.type: + if self.fatal is None: + self.error_message = element.Error.parse(x[1]) + self.fatal = True + self.state = Complete + return count + elif x[0] in AsynchronousMap: + self.asyncs.append( + AsynchronousMap[x[0]](x[1]) + ) + else: + out_messages = self.machine.send(x) + if out_messages: + break + except StopIteration: + # generator is complete, negotiation is complete.. + self.state = Complete + return count + + if out_messages: + self.messages = out_messages + self.state = (Sending, self.sent) + return count + + def unsupported_auth_request(self, req): + self.fatal = True + self.error_message = element.ClientError(( + (b'S', "FATAL"), + (b'C', "--AUT"), + (b'M', "unsupported authentication request %r(%d)" %( + element.AuthNameMap.get(req, ''), req, + )), + (b'H', "'postgresql.protocol' only supports: SHA256(for opengauss), MD5, crypt, plaintext, and trust."), + )) + self.state = Complete + + def state_machine(self): + """ + Generator keeping the state of the connection negotiation process. + """ + x = (yield (self.startup_message,)) + if x[0] != element.Authentication.type: + self.fatal = True + self.error_message = element.ClientError(( + (b'S', 'FATAL'), + (b'C', '08P01'), + (b'M', message_expectation( + expected = element.Authentication.type, + received = x[0], + )), + )) + return + + self.authtype = element.Authentication.parse(x[1]) + + req = self.authtype.request + if req != element.AuthRequest_OK: + if req == element.AuthRequest_Cleartext: + pw = self.password + elif req == element.AuthRequest_Crypt: + pw = crypt(self.password, self.authtype.salt) + elif req == element.AuthRequest_MD5: + pw = md5(self.password + self.startup_message[b'user']).hexdigest().encode('ascii') + pw = b'md5' + md5(pw + self.authtype.salt).hexdigest().encode('ascii') + elif req == element.AuthRequest_SHA256: + pw = sha256_pw(self.startup_message[b'user'], self.password, self.authtype.salt) + else: + ## + # Not going to work. Sorry :( + # The many authentication types supported by PostgreSQL are not + # easy to implement, especially when implementations for the + # type don't exist for Python. 
+ self.unsupported_auth_request(req) + return + x = (yield (element.Password(pw),)) + + self.authok = element.Authentication.parse(x[1]) + if self.authok.request != element.AuthRequest_OK: + self.fatal = True + self.error_message = element.ClientError(( + (b'S', 'FATAL'), + (b'C', "08P01"), + (b'M', "expected OK from the authentication " \ + "message, but received %s(%s) instead" %( + repr(element.AuthNameMap.get( + self.authok.request, '' + )), + repr(self.authok.request), + ), + ) + )) + return + else: + self.authok = self.authtype + + # Done authenticating, pick up the killinfo and the ready message. + x = (yield None) + if x[0] != element.KillInformation.type: + self.fatal = True + self.error_message = element.ClientError(( + (b'S', 'FATAL'), + (b'C', '08P01'), + (b'M', message_expectation( + expected = element.KillInformation.type, + received = repr(x[0]), + )), + )) + return + self.killinfo = element.KillInformation.parse(x[1]) + + x = (yield None) + if x[0] != element.Ready.type: + self.fatal = True + self.error_message = element.ClientError(( + (b'S', "FATAL"), + (b'C', "08P01"), + (b'M', message_expectation( + expected = repr(element.Ready.type), + received = repr(x[0]), + )) + )) + return + self.last_ready = element.Ready.parse(x[1]) + +class Instruction(Transaction): + """ + Manage the state of a sequence of request messages to be sent to the server. + It provides the messages to be sent and takes the response messages for order + and integrity validation: + + Instruction([.element3.Message(), ..]) + + A message must be one of: + + * `.element3.Query` + * `.element3.Function` + * `.element3.Parse` + * `.element3.Bind` + * `.element3.Describe` + * `.element3.Close` + * `.element3.Execute` + * `.element3.Synchronize` + * `.element3.Flush` + """ + state = None + CopyFailMessage = element.CopyFail(b"invalid termination") + + # The hook is the dictionary that provides the path for the + # current working message. The received messages ultimately come + # through here and get parsed using the associated callable. + # Messages that complete a command are paired with None. + hook = { + element.Query.type : ( + # 0: Start. + { + element.TupleDescriptor.type : (element.TupleDescriptor.parse, 3), + element.Null.type : (element.Null.parse, 0), + element.Complete.type : (element.Complete.parse, 0), + element.CopyToBegin.type : (element.CopyToBegin.parse, 2), + element.CopyFromBegin.type : (element.CopyFromBegin.parse, 1), + element.Ready.type : (element.Ready.parse, None), + }, + # 1: Complete. + { + element.Complete.type : (element.Complete.parse, 0), + }, + # 2: Copy Data. + # CopyData until CopyDone. + # Complete comes next. + { + element.CopyData.type : (return_arg, 2), + element.CopyDone.type : (element.CopyDone.parse, 1), + }, + # 3: Row Data. + { + element.Tuple.type : (element.Tuple.parse, 3), + element.Complete.type : (element.Complete.parse, 0), + element.Ready.type : (element.Ready.parse, None), + }, + ), + + element.Function.type : ( + {element.FunctionResult.type : (element.FunctionResult.parse, 1)}, + {element.Ready.type : (element.Ready.parse, None)}, + ), + + # Extended Protocol + element.Parse.type : ( + {element.ParseComplete.type : (element.ParseComplete.parse, None)}, + ), + + element.Bind.type : ( + {element.BindComplete.type : (element.BindComplete.parse, None)}, + ), + + element.Describe.type : ( + # Still needs the descriptor. 
+ { + element.AttributeTypes.type : (element.AttributeTypes.parse, 1), + element.TupleDescriptor.type : ( + element.TupleDescriptor.parse, None + ), + }, + # NoData or TupleDescriptor + { + element.NoData.type : (element.NoData.parse, None), + element.TupleDescriptor.type : ( + element.TupleDescriptor.parse, None + ), + }, + ), + + element.Close.type : ( + {element.CloseComplete.type : (element.CloseComplete.parse, None)}, + ), + + element.Execute.type : ( + # 0: Start. + { + element.Tuple.type : (element.Tuple.parse, 1), + element.CopyToBegin.type : (element.CopyToBegin.parse, 2), + element.CopyFromBegin.type : (element.CopyFromBegin.parse, 3), + element.Null.type : (element.Null.parse, None), + element.Complete.type : (element.Complete.parse, None), + }, + # 1: Row Data. + { + element.Tuple.type : (element.Tuple.parse, 1), + element.Suspension.type : (element.Suspension.parse, None), + element.Complete.type : (element.Complete.parse, None), + }, + # 2: Copy Data. + { + element.CopyData.type : (return_arg, 2), + element.CopyDone.type : (element.CopyDone.parse, 3), + }, + # 3: Complete. + { + element.Complete.type : (element.Complete.parse, None), + }, + ), + + element.Synchronize.type : ( + {element.Ready.type : (element.Ready.parse, None)}, + ), + + element.Flush.type : None, + } + + initial_state = ( + (), # last messages, + (0, 0), # request position, response position + (0, 0), # last request position, last response position + ) + + def __init__(self, commands, asynchook = return_arg): + """ + Initialize an `Instruction` instance using the given commands. + + Commands are `postgresql.protocol.element3.Message` instances: + + * `.element3.Query` + * `.element3.Function` + * `.element3.Parse` + * `.element3.Bind` + * `.element3.Describe` + * `.element3.Close` + * `.element3.Execute` + * `.element3.Synchronize` + * `.element3.Flush` + """ + # Commands are accessed by index. + self.commands = tuple(commands) + self.asynchook = asynchook + self.completed = [] + self.last = self.initial_state + self.messages = list(self.commands) + self.state = (Sending, self.standard_sent) + self.fatal = None + + for cmd in self.commands: + if cmd.type not in self.hook: + raise TypeError( + "unknown message type for PQ 3.0 protocol", cmd.type + ) + + def __repr__(self, format = '{mod}.{name}({nl}{args})'.format): + return format( + mod = type(self).__module__, + name = type(self).__name__, + nl = os.linesep, + args = pprint.pformat(self.commands) + ) + + def messages_received(self): + """ + Received and validate messages. + """ + return chain.from_iterable(map(get1, self.completed)) + + def reverse(self, + chaining = chain.from_iterable, + map = map, + transform = compose((get1, reversed)), + reversed = reversed + ): + """ + A iterator that producing the completed messages in reverse + order. Last in, first out. + """ + return chaining(map(transform, reversed(self.completed))) + + def standard_put(self, messages, + SWITCH_TYPES = element.Execute.type + element.Query.type, + ERROR_TYPE = element.Error.type, + READY_TYPE = element.Ready.type, + ERROR_PARSE = element.Error.parse, + len = len, + ): + """ + Attempt to forward the state of the transaction using the given + messages. "put" messages into the transaction for processing. + + If an invalid command is initialized on the `Transaction` object, an + `IndexError` will be thrown. + """ + COMMANDS = self.commands + NCOMMANDS = len(COMMANDS) + HOOK = self.hook + # We processed it, but apparently something went wrong, + # so go ahead and reprocess it. 
+ if messages is self.last[0]: + offset, current_step = self.last[1] + # don't clear the asyncs. they have already been process by the hook. + else: + offset, current_step = self.last[2] + # it's a new set, so we can clear the asyncs record. + self._asyncs = [] + cmd = COMMANDS[offset] + paths = HOOK[cmd.type] + processed = [] + count = 0 + + for x in messages: + count += 1 + # For the current message, get the path for the message + # and whether it signals the end of the current command + path, next_step = paths[current_step].get(x[0], (None, None)) + + if path is None: + # No path for message type, could be a protocol error. + if x[0] == ERROR_TYPE: + em = ERROR_PARSE(x[1]) + # Is it fatal? + self.fatal = fatal = em[b'S'].upper() != b'ERROR' + self.error_message = em + if fatal is True: + # Can't sync up if the session is closed. + self.state = Complete + return count + # Error occurred, so sync up with backend if + # the current command is not 'Q' or 'F' as they + # imply a sync message. + if cmd.type not in ( + element.Function.type, element.Query.type + ): + # Adjust the offset forward until the Sync message is found. + for offset in range(offset, NCOMMANDS): + if COMMANDS[offset] is element.SynchronizeMessage: + break + else: + ## + # It's done. + self.state = Complete + return count + ## + # Not quite done, the state(Ready) message still + # needs to be received. + cmd = COMMANDS[offset] + paths = HOOK[cmd.type] + # On a new command, setup the new step. + current_step = 0 + continue + elif x[0] in AsynchronousMap: + if x not in self._asyncs: + msg = AsynchronousMap[x[0]](x[1]) + try: + self.asynchook(msg) + except Exception as err: + # exception thrown by async message handler? + # notify the user, but continue... + sys.excepthook(*sys.exc_info()) + # it's been processed, so don't process it again. + self._asyncs.append(x) + else: + ## + # Procotol violation. + self.fatal = True + self.error_message = element.ClientError(( + (b'S', 'FATAL'), + (b'C', '08P01'), + (b'M', message_expectation( + expected = tuple(paths[current_step].keys()), + received = x[0] + )), + )) + self.state = Complete + return count + else: + # Process a valid message. + r = path(x[1]) + processed.append(r) + + if next_step is not None: + current_step = next_step + else: + current_step = 0 + if r.type == READY_TYPE: + self.last_ready = r.xact_state + # Done with the current command. Increment the offset, and + # try to process the new command with the remaining data. + paths = None + while paths is None: + # Increment the offset past any commands + # whose hook is None (FlushMessage) + offset += 1 + # If the offset is the length, + # the transaction is complete. + if offset == NCOMMANDS: + # Done with transaction. + break + cmd = COMMANDS[offset] + paths = HOOK[cmd.type] + else: + # More commands to process in this transaction. + continue + # The while loop was broken offset == len(self.commands) + # So, that's all there is to this transaction. + break + + # Push the messages onto the completed list if they + # have not been put there already. + if not self.completed or self.completed[-1][0] != id(messages): + self.completed.append((id(messages), processed)) + + # Store the state for the next transition. + self.last = (messages, self.last[2], (offset, current_step),) + + if offset == NCOMMANDS: + # transaction complete. + self.state = Complete + elif cmd.type in SWITCH_TYPES and processed: + # Check the context to identify if the state should be + # switched to an optimized processor. 
+ last = processed[-1] + if last.__class__ is bytes: + # Fast path for COPY data, 'd' messages. + self.state = (Receiving, self.put_copydata) + elif last.__class__ is tuple: + # Fast path for Tuples, 'D' messages. + self.state = (Receiving, self.put_tupledata) + elif last.type == element.CopyFromBegin.type: + # In this case, the commands that were sent past + # message starting the COPY, need to be re-issued + # once the COPY is complete. PG cleared its buffer. + self.CopyFailSequence = (self.CopyFailMessage,) + \ + self.commands[offset+1:] + self.CopyDoneSequence = (element.CopyDoneMessage,) + \ + self.commands[offset+1:] + self.state = (Sending, self.sent_from_stdin) + elif last.type == element.CopyToBegin.type: + # Should be seeing COPY data soon. + self.state = (Receiving, self.put_copydata) + return count + + def put_copydata(self, messages): + """ + In the context of a copy, `put_copydata` is used as a fast path for + storing `element.CopyData` messages. When a non-`element.CopyData.type` + message is received, it reverts the ``state`` attribute back to + `standard_put` to process the message-sequence. + """ + copydata = element.CopyData.type + # "Fail" quickly if the last message is not copy data. + if messages[-1][0] != copydata: + self.state = (Receiving, self.standard_put) + return self.standard_put(messages) + + lines = [x[1] for x in messages if x[0] == copydata] + if len(lines) != len(messages): + self.state = (Receiving, self.standard_put) + return self.standard_put(messages) + + if not self.completed or self.completed[-1][0] != id(messages): + self.completed.append((id(messages), lines)) + self.last = (messages, self.last[2], self.last[2],) + return len(messages) + + try: + def put_tupledata(self, messages, + consume = consume_tuple_messages, + ): + tuplemessages = consume(messages) + if not tuplemessages: + # bad handler switch? + self.state = (Receiving, self.standard_put) + return self.standard_put(messages) + + if not self.completed or self.completed[-1][0] != id(messages): + self.completed.append(((id(messages), tuplemessages))) + self.last = (messages, self.last[2], self.last[2],) + return len(tuplemessages) + except NameError: + ## + # No consume_tuple_messages function. + def put_tupledata(self, messages, + p = element.Tuple.parse, + t = element.Tuple.type, + ): + """ + Fast path used when inside an Execute command. As soon as tuple + data is seen. + """ + # Fallback to `standard_put` quickly if the last + # message is not tuple data. + if messages[-1][0] is not t: + self.state = (Receiving, self.standard_put) + return self.standard_put(messages) + + tuplemessages = [p(x[1]) for x in messages if x[0] == t] + if len(tuplemessages) != len(messages): + self.state = (Receiving, self.standard_put) + return self.standard_put(messages) + + if not self.completed or self.completed[-1][0] != id(messages): + self.completed.append(((id(messages), tuplemessages))) + self.last = (messages, self.last[2], self.last[2],) + return len(messages) + + def standard_sent(self): + """ + Empty messages and switch state to receiving. + + This is called by the user after the `messages` have been sent to the + remote end. That is, this merely finalizes the "Sending" state. + """ + self.messages = () + self.state = (Receiving, self.standard_put) + sent = standard_sent + + def sent_from_stdin(self): + """ + The state method for sending copy data. + + After each call to `sent_from_stdin`, the `messages` attribute is set + to a `CopyFailSequence`. 
This sequence of messages assures that the + COPY will be properly terminated. + + If new copy data is not provided, or `messages` is *not* set to + `CopyDoneSequence`, the transaction will instruct the remote end to + cause the COPY to fail. + """ + if self.messages is self.CopyDoneSequence or \ + self.messages is self.CopyFailSequence: + # If the last sent `messages` is CopyDone or CopyFail, finish out the + # transaction. + ## + self.messages = () + self.state = (Receiving, self.standard_put) + else: + ## + # Initialize to CopyFail, if the messages attribute is not + # set properly before each invocation, the transaction is + # being misused and will be terminated. + self.messages = self.CopyFailSequence diff --git a/py_opengauss/python/__init__.py b/py_opengauss/python/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9cd00be5a4aec739f737efb16047087fba8682 --- /dev/null +++ b/py_opengauss/python/__init__.py @@ -0,0 +1,5 @@ +""" +Python tools package. + +Various extensions to the standard library. +""" diff --git a/py_opengauss/python/command.py b/py_opengauss/python/command.py new file mode 100644 index 0000000000000000000000000000000000000000..18685c22349410f9a9e64759b60dd38641e1f0be --- /dev/null +++ b/py_opengauss/python/command.py @@ -0,0 +1,635 @@ +## +# .python.command - Python command emulation module. +## +""" +Create and Execute Python Commands +================================== + +The purpose of this module is to simplify the creation of a Python command +interface. Normally, one would want to do this if there is a *common* need +for a certain Python environment that may be, at least, partially initialized +via command line options. A notable case would be a Python environment with a +database connection whose connection parameters came from the command line. That +is, Python + command line driven configuration. + +The module also provides an extended interactive console that provides backslash +commands for editing and executing temporary files. Use ``python -m +pythoncommand`` to try it out. + +Simple usage:: + + import sys + import os + import optparse + import pythoncommand as pycmd + + op = optparse.OptionParser( + "%prog [options] [script] [script arguments]", + version = '1.0', + ) + op.disable_interspersed_args() + + # Basically, the standard -m and -c. (Some additional ones for fun) + op.add_options(pycmd.default_optparse_options) + + co, ca = op.parse_args(args[1:]) + + # This initializes an execution instance which gathers all the information + # about the code to be ran when ``pyexe`` is called. + pyexe = pycmd.Execution(ca, + context = getattr(co, 'python_context', ()), + loader = getattr(co, 'python_main', None), + ) + + # And run it. Any exceptions will be printed via print_exception. 
+ rv = pyexe() + sys.exit(rv) +""" +import os +import sys +import re +import code +import types +import optparse +import subprocess +import contextlib + +from gettext import gettext as _ +from traceback import print_exception + +from pkgutil import get_loader as module_loader + +class single_loader(object): + """ + used for "loading" string modules(think -c) + """ + def __init__(self, source): + self.source = source + + def get_filename(self, fullpath): + if fullpath == self.source: + return '' + + def get_code(self, fullpath): + if fullpath == self.source: + return compile(self.source, '', 'exec') + + def get_source(self, fullpath): + if fullpath == self.source: + return self.source + +class file_loader(object): + """ + used for "loading" scripts + """ + def __init__(self, filepath, fileobj = None): + self.filepath = filepath + if fileobj is not None: + self._source = fileobj.read() + + def get_filename(self, fullpath): + if fullpath == self.filepath: + return self.filepath + + def get_source(self, fullpath): + if fullpath == self.filepath: + return self._read() + + def _read(self): + if hasattr(self, '_source'): + return self._source + f = open(self.filepath) + try: + return f.read() + finally: + f.close() + + def get_code(self, fullpath): + if fullpath != self.filepath: + return + return compile(self._read(), self.filepath, 'exec') + +def extract_filepath(x): + if x.startswith('file://'): + return x[7:] + return None + +def extract_module(x): + if x.startswith('module:'): + return x[7:] + return None + +module_loader_descriptor = ( + 'Python module', module_loader, extract_module +) +file_loader_descriptor = ( + 'Python script', file_loader, extract_filepath +) +single_loader_descriptor = ( + 'Python command', single_loader, lambda x: x +) + +_directory = ( + module_loader_descriptor, + file_loader_descriptor, +) +directory = list(_directory) + +def find_loader(ident, dir = directory): + for x in dir: + xid = x[2](ident) + if xid is not None: + return x + +## +# optparse options +## + +def append_context(option, opt_str, value, parser): + """ + Add some context to the execution of the Python code using + loader module's directory list of loader descriptions. + + If no loader can be found, assume it's a Python command. + """ + pc = getattr(parser.values, option.dest, None) or [] + if not pc: + setattr(parser.values, option.dest, pc) + ldesc = find_loader(value) + if ldesc is None: + ldesc = single_loader_descriptor + pc.append((value, ldesc)) + +def set_python_main(option, opt_str, value, parser): + """ + Set the main Python code; after contexts are initialized, main is ran. 
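+
+	When interspersed arguments are disabled, a '--' terminator is pushed back
+	onto the remaining arguments so that they are handed to the script instead
+	of being parsed as further options.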
+ """ + main = (value, option.python_loader) + setattr(parser.values, option.dest, main) + # only terminate parsing if not interspersing arguments + if not parser.allow_interspersed_args: + parser.rargs.insert(0, '--') + +context = optparse.make_option( + '-C', '--context', + help = _('Python context code to run[file://,module:,]'), + dest = 'python_context', + action = 'callback', + callback = append_context, + type = 'str' +) + +module = optparse.make_option( + '-m', + help = _('Python module to run as script(__main__)'), + dest = 'python_main', + action = 'callback', + callback = set_python_main, + type = 'str' +) +module.python_loader = module_loader_descriptor + +command = optparse.make_option( + '-c', + help = _('Python expression to run(__main__)'), + dest = 'python_main', + action = 'callback', + callback = set_python_main, + type = 'str' +) +command.python_loader = single_loader_descriptor + +default_optparse_options = [ + context, module, command, +] + +class ExtendedConsole(code.InteractiveConsole): + """ + Console subclass providing some convenient backslash commands. + """ + def __init__(self, *args, **kw): + import tempfile + self.mktemp = tempfile.mktemp + import shlex + self.split = shlex.split + code.InteractiveConsole.__init__(self, *args, **kw) + + self.bsc_map = {} + self.temp_files = {} + self.past_buffers = [] + + self.register_backslash(r'\?', self.showhelp, "Show this help message.") + self.register_backslash(r'\set', self.bs_set, + "Configure environment variables. \set without arguments to show all") + self.register_backslash(r'\E', self.bs_E, + "Edit a file or a temporary script.") + self.register_backslash(r'\i', self.bs_i, + "Execute a Python script within the interpreter's context.") + self.register_backslash(r'\e', self.bs_e, + "Edit and Execute the file directly in the context.") + self.register_backslash(r'\x', self.bs_x, + "Execute the Python command within this process.") + + def interact(self, *args, **kw): + self.showhelp(None, None) + return super().interact(*args,**kw) + + def showtraceback(self): + e, v, tb = sys.exc_info() + sys.last_type, sys.last_value, sys.last_traceback = e, v, tb + print_exception(e, v, tb.tb_next or tb) + + def register_backslash(self, bscmd, meth, doc): + self.bsc_map[bscmd] = (meth, doc) + + def execslash(self, line): + """ + If push() gets a line that starts with a backslash, execute + the command that the backslash sequence corresponds to. + """ + cmd = line.split(None, 1) + cmd.append('') + bsc = self.bsc_map.get(cmd[0]) + if bsc is None: + self.write("ERROR: unknown backslash command: %s%s"%(cmd, os.linesep)) + else: + return bsc[0](cmd[0], cmd[1]) + + def showhelp(self, cmd, arg): + i = list(self.bsc_map.items()) + i.sort(key = lambda x: x[0]) + helplines = os.linesep.join([ + ' %s%s%s' %( + x[0], ' ' * (8 - len(x[0])), x[1][1] + ) for x in i + ]) + self.write("Backslash Commands:%s%s%s" %( + os.linesep*2, helplines, os.linesep*2 + )) + + def bs_set(self, cmd, arg): + """ + Set a value in the interpreter's environment. + """ + if arg: + for x in self.split(arg): + if '=' in x: + k, v = x.split('=', 1) + os.environ[k] = v + self.write("%s=%s%s" %(k, v, os.linesep)) + elif x: + self.write("%s=%s%s" %(x, os.environ.get(x, ''), os.linesep)) + else: + for k,v in os.environ.items(): + self.write("%s=%s%s" %(k, v, os.linesep)) + + def resolve_path(self, path, dont_create = False): + """ + Get the path of the given string; if the path is not + absolute and does not contain path separators, identify + it as a temporary file. 
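+
+		For example, a bare name such as 'scratch' maps to a generated temporary
+		file that is reused on subsequent references, while a path like
+		'/tmp/script.py' is returned verbatim.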
+ """ + if not os.path.isabs(path) and not os.path.sep in path: + # clean it up to avoid typos + path = path.strip().lower() + tmppath = self.temp_files.get(path) + if tmppath is None: + if dont_create is False: + tmppath = self.mktemp( + suffix = '.py', + prefix = '_console_%s_' %(path,) + ) + self.temp_files[path] = tmppath + else: + return path + return tmppath + return path + + def execfile(self, filepath): + src = open(filepath) + try: + try: + co = compile(src.read(), filepath, 'exec') + except SyntaxError: + co = None + print_exception(*sys.exc_info()) + finally: + src.close() + if co is not None: + try: + exec(co, self.locals, self.locals) + except: + e, v, tb = sys.exc_info() + print_exception(e, v, tb.tb_next or tb) + + def editfiles(self, filepaths): + sp = list(filepaths) + # ;) + sp.insert(0, os.environ.get('EDITOR', 'vi')) + return subprocess.call(sp) + + def bs_i(self, cmd, arg): + 'execute the files' + for x in self.split(arg) or ('',): + p = self.resolve_path(x, dont_create = True) + self.execfile(p) + + def bs_E(self, cmd, arg): + 'edit the files, but *only* edit them' + self.editfiles([self.resolve_path(x) for x in self.split(arg) or ('',)]) + + def bs_e(self, cmd, arg): + 'edit *and* execute the files' + filepaths = [self.resolve_path(x) for x in self.split(arg) or ('',)] + self.editfiles(filepaths) + for x in filepaths: + self.execfile(x) + + def bs_x(self, cmd, arg): + rv = -1 + if len(cmd) > 1: + a = self.split(arg) + a.insert(0, '\\x') + try: + rv = command(argv = a) + except SystemExit as se: + rv = se.code + self.write("[Return Value: %d]%s" %(rv, os.linesep)) + + def push(self, line): + # Has to be a ps1 context. + if not self.buffer and line.startswith('\\'): + try: + self.execslash(line) + except: + # print the exception, but don't raise. + e, v, tb = sys.exc_info() + print_exception(e, v, tb.tb_next or tb) + else: + return code.InteractiveConsole.push(self, line) + +@contextlib.contextmanager +def postmortem(funcpath): + if not funcpath: + yield None + else: + pm = funcpath.split('.') + attr = pm.pop(-1) + modpath = '.'.join(pm) + try: + m = __import__(modpath, fromlist = modpath) + pmobject = getattr(m, attr, None) + except ValueError: + pmobject = None + + sys.stderr.write( + "%sERROR: no object at %r for postmortem%s"%( + os.linesep, funcpath, os.linesep + ) + ) + try: + yield None + except: + try: + sys.last_type, sys.last_value, sys.last_traceback = sys.exc_info() + pmobject() + except: + sys.stderr.write( + "[Exception raised by Postmortem]" + os.linesep + ) + print_exception(*sys.exc_info()) + raise + +class Execution(object): + """ + Given argv and context make an execution instance that, when called, will + execute the configured Python code. + + This class provides the ability to identify what the main part of the + execution of the configured Python code. For instance, shall it execute a + console, the file that the first argument points to, a -m option module + appended to the python_context option value, or the code given within -c? + """ + def __init__(self, + args, context = (), + main = None, + loader = None, + stdin = sys.stdin + ): + """ + args + The arguments passed to the script; usually sys.argv after being + processed by optparse(ca). + context + A list of loader descriptors that will be used to establish the + context of __main__ module. + main + Overload to explicitly state what main is. None will cause the + class to attempt to fill in the attribute using 'args' and other + system objects like sys.stdin. 
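+		loader
+			A (path, loader descriptor) pair, normally produced by the optparse
+			callbacks, used to resolve main when it is not stated explicitly.
+		stdin
+			Consulted last; when it is not attached to a tty, its contents are
+			used as the main script.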
+ """ + self.args = args + self.context = context and list(context) or () + + if main is not None: + self.main = main + elif loader is not None: + # Main explicitly stated, resolve the path and the loader + path, ldesc = loader + ltitle, rloader, xpath = ldesc + l = rloader(path) + if l is None: + raise ImportError( + "%s %r does not exist or cannot be read" %( + ltitle, path + ) + ) + self.main = (path, l) + # If there are args, but no main, run the first arg. + elif args: + fp = self.args[0] + f = open(fp) + try: + l = file_loader(fp, fileobj = f) + finally: + f.close() + self.main = (self.args[0], l) + self.args = self.args[1:] + # There is no main, no loader, and no args. + # If stdin is not a tty, use stdin as the main file. + elif not stdin.isatty(): + l = file_loader('', fileobj = stdin) + self.main = ('', l) + # tty and no "main". + else: + # console + self.main = (None, None) + self.reset_module__main__() + + def reset_module__main__(self): + mod = types.ModuleType('__main__') + mod.__builtins__ = __builtins__ + mod.__package__ = None + self.module__main__ = mod + path = getattr(self.main[1], 'fullname', None) + if path is not None: + mod.__package__ = '.'.join(path.split('.')[:-1]) + + def _call(self, + console = ExtendedConsole, + context = None + ): + """ + Initialize the context and run main in the given locals + (Note: tramples on sys.argv, __main__ in sys.modules) + (Use __call__ instead) + """ + sys.modules['__main__'] = self.module__main__ + md = self.module__main__.__dict__ + + # Establish execution context in the locals; + # iterate over all the loaders in self.context and + for path, ldesc in self.context: + ltitle, loader, xpath = ldesc + rpath = xpath(path) + li = loader(rpath) + if li is None: + sys.stderr.write( + "%s %r does not exist or cannot be read%s" %( + ltitle, rpath, os.linesep + ) + ) + return 1 + try: + code = li.get_code(rpath) + except: + print_exception(*sys.exc_info()) + return 1 + self.module__main__.__file__ = getattr( + li, 'get_filename', lambda x: x + )(rpath) + self.module__main__.__loader__ = li + try: + exec(code, md, md) + except: + e, v, tb = sys.exc_info() + print_exception(e, v, tb.tb_next or tb) + return 1 + + if self.main == (None, None): + # It's interactive. + sys.argv = self.args or [''] + + # Use readline if available + try: + import readline + except ImportError: + pass + + ic = console(locals = md) + try: + ic.interact() + except SystemExit as e: + return e.code + return 0 + else: + # It's ultimately a code object. + path, loader = self.main + self.module__main__.__file__ = getattr( + loader, 'get_filename', lambda x: x + )(path) + sys.argv = list(self.args) + sys.argv.insert(0, self.module__main__.__file__) + try: + code = loader.get_code(path) + except: + print_exception(*sys.exc_info()) + return 1 + + rv = 0 + exe_exception = False + try: + if context is not None: + with context: + try: + exec(code, md, md) + except: + exe_exception = True + raise + else: + try: + exec(code, md, md) + except: + exe_exception = True + raise + + except SystemExit as e: + # Assume it's an exe_exception as anything ran in `context` + # shouldn't cause an exception. 
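+			# Record the exit code and expose the exception through the
+			# sys.last_* attributes so a postmortem hook can inspect it.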
+ rv = e.code + e, v, tb = sys.exc_info() + sys.last_type = e + sys.last_value = v + sys.last_traceback = (tb.tb_next or tb) + except: + if exe_exception is False: + raise + rv = 1 + e, v, tb = sys.exc_info() + print_exception(e, v, tb.tb_next or tb) + sys.last_type = e + sys.last_value = v + sys.last_traceback = (tb.tb_next or tb) + + return rv + + def __call__(self, *args, **kw): + storage = ( + sys.modules.get('__context__'), + sys.modules.get('__main__'), + sys.argv, + os.environ.copy(), + ) + try: + return self._call(*args, **kw) + finally: + sys.modules['__context__'], \ + sys.modules['__main__'], \ + sys.argv, os.environ = storage + + def get_main_source(self): + """ + Get the execution's "__main__" source. Useful for configuring + environmental options derived from "magic" lines. + """ + path, loader = self.main + if path is not None: + return loader.get_source(path) + +def command_execution(argv = sys.argv): + 'create an execution using the given argv' + # The pwd should be in the path for python commands. + # setuptools' console_scripts appear to strip this out. + if '' not in sys.path: + sys.path.insert(0, '') + + op = optparse.OptionParser( + "%prog [options] [script] [script arguments]", + version = '1.0', + ) + op.disable_interspersed_args() + op.add_options(default_optparse_options) + co, ca = op.parse_args(argv[1:]) + + return Execution(ca, + context = getattr(co, 'python_context', ()), + loader = getattr(co, 'python_main', None), + ) + +def command(argv = sys.argv): + return command_execution(argv = argv)( + context = postmortem(os.environ.get('PYTHON_POSTMORTEM')) + ) + +if __name__ == '__main__': + sys.exit(command()) diff --git a/py_opengauss/python/datetime.py b/py_opengauss/python/datetime.py new file mode 100644 index 0000000000000000000000000000000000000000..5844e33fc65a8fe5fc0fb13c35dd63109b638f8d --- /dev/null +++ b/py_opengauss/python/datetime.py @@ -0,0 +1,42 @@ +## +# python.datetime - parts needed to use stdlib.datetime +## +import datetime + +## +# stdlib.datetime representation of PostgreSQL 'infinity' and '-infinity'. 
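+# datetime cannot represent an actual infinity, so the extreme values allowed
+# by MINYEAR and MAXYEAR serve as stand-ins for the PostgreSQL sentinels.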
+infinity_datetime = datetime.datetime(datetime.MAXYEAR, 12, 31, 23, 59, 59, 999999) +negative_infinity_datetime = datetime.datetime(datetime.MINYEAR, 1, 1, 0, 0, 0, 0) + +infinity_date = datetime.date(datetime.MAXYEAR, 12, 31) +negative_infinity_date = datetime.date(datetime.MINYEAR, 1, 1) + +class FixedOffset(datetime.tzinfo): + def __init__(self, offset_in_seconds, tzname = None): + self._tzname = tzname + self._offset = offset_in_seconds + self._offset_in_mins = offset_in_seconds // 60 + self._td_offset = datetime.timedelta(0, self._offset_in_mins * 60) + self._dst = datetime.timedelta(0) + + def utcoffset(self, offset_from): + return self._td_offset + + def tzname(self, dt): + return self._tzname + + def dst(self, arg): + return self._dst + + def __repr__(self): + return "{path}.{name}({off}{tzname})".format( + path = type(self).__module__, + name = type(self).__name__, + off = repr(self._td_offset.days * 24 * 60 * 60 + self._td_offset.seconds), + tzname = ( + ", tzname = {tzname!r}".format(tzname = self._tzname) \ + if self._tzname is not None else "" + ) + ) + +UTC = FixedOffset(0, tzname = 'UTC') diff --git a/py_opengauss/python/decorlib.py b/py_opengauss/python/decorlib.py new file mode 100644 index 0000000000000000000000000000000000000000..24bcc36737e29b4e425260833fef83f3cc424b9f --- /dev/null +++ b/py_opengauss/python/decorlib.py @@ -0,0 +1,44 @@ +## +# .python.decorlib +## +""" +common decorators +""" +import os +import types + +def propertydoc(ap): + """ + Helper function for extracting an `abstractproperty`'s real documentation. + """ + doc = "" + rstr = "" + if ap.fget: + ret = ap.fget.__annotations__.get('return') + if ret is not None: + rstr = " -> " + repr(ret) + if ap.fget.__doc__: + doc += os.linesep*2 + "GET::" + (os.linesep + ' '*4) + (os.linesep + ' '*4).join( + [x.strip() for x in ap.fget.__doc__.strip().split(os.linesep)] + ) + if ap.fset and ap.fset.__doc__: + doc += os.linesep*2 + "SET::" + (os.linesep + ' '*4) + (os.linesep + ' '*4).join( + [x.strip() for x in ap.fset.__doc__.strip().split(os.linesep)] + ) + if ap.fdel and ap.fdel.__doc__: + doc += os.linesep*2 + "DELETE::" + (os.linesep + ' '*4) + (os.linesep + ' '*4).join( + [x.strip() for x in ap.fdel.__doc__.strip().split(os.linesep)] + ) + ap.__doc__ = "" if not doc else ( + "Abstract Property" + rstr + doc + ) + return ap + +class method(object): + __slots__ = ('callable',) + def __init__(self, callable): + self.callable = callable + def __get__(self, val, typ): + if val is None: + return self.callable + return types.MethodType(self.callable, val) diff --git a/py_opengauss/python/doc.py b/py_opengauss/python/doc.py new file mode 100644 index 0000000000000000000000000000000000000000..6a42bcb84a1cf7a8c49788f9d2a5886010e525d7 --- /dev/null +++ b/py_opengauss/python/doc.py @@ -0,0 +1,18 @@ +## +# .python.doc +## +""" +Documentation Tools. +""" +from operator import attrgetter + +class Doc(object): + """ + Simple object that sets the __doc__ attribute to the first parameter and + initializes __annotations__ using keyword arguments. 
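+
+	A minimal illustration (values are hypothetical)::
+
+		d = Doc("frobnicate the widget", target = str)
+		str(d)             # -> 'frobnicate the widget'
+		d.__annotations__  # -> {'target': <class 'str'>}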
+ """ + def __init__(self, doc, **annotations): + self.__doc__ = str(doc) + self.__annotations__ = annotations + + __str__ = attrgetter('__doc__') diff --git a/py_opengauss/python/element.py b/py_opengauss/python/element.py new file mode 100644 index 0000000000000000000000000000000000000000..4257a44ed42dd57ec32e0d1c3a4b55aeca8213ca --- /dev/null +++ b/py_opengauss/python/element.py @@ -0,0 +1,215 @@ +## +# .python.element +## +import os +from abc import ABCMeta, abstractproperty, abstractmethod +from .string import indent +from .decorlib import propertydoc + +class RecursiveFactor(Exception): + """ + Raised when a factor is ultimately composed of itself. + """ + pass + +class Element(object, metaclass = ABCMeta): + """ + The purpose of an element is to provide a general mechanism for specifying + the factors that composed an object. Factors are designated using an + ordered set of strings referencing those significant attributes on the object. + + Factors are important for PG-API as it provides the foundation for + collecting the information about the state of the interface that ultimately + led up to an error. + + Traceback: + ... + postgresql.exceptions.*: + CODE: XX000 + CURSOR: + parameters: (p1, p2, ...) + STATEMENT: + ... + string: + + SYMBOL: get_types + LIBRARY: catalog + ... + CONNECTION: + + CONNECTOR: [Host] + IRI: pq://user@localhost:5432/database + DRIVER: postgresql.driver.pq3 + """ + + @propertydoc + @abstractproperty + def _e_label(self) -> str: + """ + Single-word string describing the kind of element. + + For instance, `postgresql.api.Statement`'s _e_label is 'STATEMENT'. + + Usually, this is set directly on the class itself, and is a shorter + version of the class's name. + """ + + @propertydoc + @abstractproperty + def _e_factors(self) -> (): + """ + The attribute names of the objects that contributed to the creation of + this object. + + The ordering is significant. The first factor is the prime factor. + """ + + @abstractmethod + def _e_metas(self) -> [(str, object)]: + """ + Return an iterable to key-value pairs that provide useful descriptive + information about an attribute. + + Factors on metas are not checked. They are expected to be primitives. + + If there are no metas, the str() of the object will be used to represent + it. + """ + +class ElementSet(Element, set): + """ + An ElementSet is a set of Elements that can be used as an individual factor. + + In situations where a single factor is composed of multiple elements where + each has no significance over the other, this Element can be used represent + that fact. + + Importantly, it provides the set metadata so that the appropriate information + will be produced in element tracebacks. + """ + _e_label = 'SET' + _e_factors = () + __slots__ = () + + def _e_metas(self): + yield (None, len(self)) + for x in self: + yield (None, '--') + yield (None, format_element(x)) + +def prime_factor(obj): + """ + Get the primary factor on the `obj`, returns None if none. + """ + f = getattr(obj, '_e_factors', None) + if f: + return f[0], getattr(obj, f[0], None) + +def prime_factors(obj): + """ + Yield out the sequence of primary factors of the given object. 
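+
+	Raises `RecursiveFactor` if a factor in the chain refers back to an object
+	that has already been visited.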
+ """ + visited = set((obj,)) + ef = getattr(obj, '_e_factors', None) + if not ef: + return + fn = ef[0] + e = getattr(obj, fn, None) + if e in visited: + raise RecursiveFactor(obj, e) + visited.add(e) + yield fn, e + + while e is not None: + ef = getattr(obj, '_e_factors', None) + fn = ef[0] + e = getattr(e, fn, None) + if e in visited: + raise RecursiveFactor(obj, e) + visited.add(e) + yield fn, e + +def format_element(obj, coverage = ()): + """ + Format the given element with its factors and metadata into a readable string. + """ + # if it's not an Element, all there is to return is str(obj) + if obj in coverage: + raise RecursiveFactor(coverage) + coverage = coverage + (obj,) + + if not isinstance(obj, Element): + if obj is None: + return 'None' + return str(obj) + + # The description of `obj` is built first. + + # formal element, get metas first. + nolead = False + metas = [] + for key, val in obj._e_metas(): + m = "" + if val is None: + sval = 'None' + else: + sval = str(val) + + pre = ' ' + if key is not None: + m += key + ':' + if (len(sval) > 70 or os.linesep in sval): + pre = os.linesep + sval = indent(sval) + else: + # if the key is None, it is intended to be inlined. + nolead = True + pre = '' + m += pre + sval.rstrip() + metas.append(m) + + factors = [] + for att in obj._e_factors[1:]: + m = "" + f = getattr(obj, att) + # if the object has a label, use the label + m += att + ':' + sval = format_element(f, coverage = coverage) + if len(sval) > 70 or os.linesep in sval: + m += os.linesep + indent(sval) + else: + m += ' ' + sval + factors.append(m) + + mtxt = os.linesep.join(metas) + ftxt = os.linesep.join(factors) + if mtxt: + mtxt = indent(mtxt) + if ftxt: + ftxt = indent(ftxt) + s = mtxt + ftxt + if nolead is True: + # metas started with a `None` key. + s = ' ' + s.lstrip() + else: + s = os.linesep + s + s = obj._e_label + ':' + s.rstrip() + + # and resolve the next prime + pf = prime_factor(obj) + if pf is not None: + factor_name, prime = pf + factor = format_element(prime, coverage = coverage) + if getattr(prime, '_e_label', None) is not None: + # if the factor has a label, then it will be + # included in the format_element output, and + # thus factor_name is not needed. + factor_name = '' + else: + factor_name += ':' + if len(factor) > 70 or os.linesep in factor: + factor = os.linesep + indent(factor) + else: + factor_name += ' ' + s += os.linesep + factor_name + factor + return s diff --git a/py_opengauss/python/functools.py b/py_opengauss/python/functools.py new file mode 100644 index 0000000000000000000000000000000000000000..02cf71f98b47a081fee386c310a13e01e0c23878 --- /dev/null +++ b/py_opengauss/python/functools.py @@ -0,0 +1,69 @@ +## +# python.functools +## +import sys +from .decorlib import method + +def rsetattr(attr, val, ob): + """ + setattr() and return `ob`. Different order used to allow easier partial + usage. + """ + setattr(ob, attr, val) + return ob + +try: + from ..port.optimized import rsetattr +except ImportError: + pass + +class Composition(tuple): + def __call__(self, r): + for x in self: + r = x(r) + return r + + try: + from ..port.optimized import compose + __call__ = method(compose) + del compose + except ImportError: + pass + +try: + # C implementation of the tuple processors. 
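+	# Prefer the C implementation; the pure Python definitions below are the
+	# fallback when the extension module is not available.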
+ from ..port.optimized import process_tuple, process_chunk +except ImportError: + def process_tuple(procs, tup, exception_handler, len = len, tuple = tuple, cause = None): + """ + Call each item in `procs` with the corresponding + item in `tup` returning the result as `type`. + + If an item in `tup` is `None`, don't process it. + + If a give transformation failes, call the given exception_handler. + """ + i = len(procs) + if len(tup) != i: + raise TypeError( + "inconsistent items, %d processors and %d items in row" %( + i, len(tup) + ) + ) + r = [None] * i + try: + for i in range(i): + ob = tup[i] + if ob is None: + continue + r[i] = procs[i](ob) + except Exception: + cause = sys.exc_info()[1] + + if cause is not None: + exception_handler(cause, procs, tup, i) + raise RuntimeError("process_tuple exception handler failed to raise") + return tuple(r) + + def process_chunk(procs, tupc, fail, process_tuple = process_tuple): + return [process_tuple(procs, x, fail) for x in tupc] diff --git a/py_opengauss/python/itertools.py b/py_opengauss/python/itertools.py new file mode 100644 index 0000000000000000000000000000000000000000..08fcdb5df3e566bf6b8f12a908e4056b34e08f7c --- /dev/null +++ b/py_opengauss/python/itertools.py @@ -0,0 +1,48 @@ +## +# .python.itertools +## +""" +itertools extensions +""" +import collections.abc +from itertools import cycle, islice + +def interlace(*iters, next = next) -> collections.abc.Iterable: + """ + interlace(i1, i2, ..., in) -> ( + i1-0, i2-0, ..., in-0, + i1-1, i2-1, ..., in-1, + . + . + . + i1-n, i2-n, ..., in-n, + ) + """ + return map(next, cycle([iter(x) for x in iters])) + +def chunk(iterable, chunksize = 256): + """ + Given an iterable, return an iterable producing chunks of the objects + produced by the given iterable. + + chunks([o1,o2,o3,o4], chunksize = 2) -> [ + [o1,o2], + [o3,o4], + ] + """ + iterable = iter(iterable) + last = () + lastsize = chunksize + while lastsize == chunksize: + last = list(islice(iterable, chunksize)) + lastsize = len(last) + yield last + +def find(iterable, selector): + """ + Return the first item in the `iterable` that causes the `selector` to return + `True`. + """ + for x in iterable: + if selector(x): + return x diff --git a/py_opengauss/python/msw.py b/py_opengauss/python/msw.py new file mode 100644 index 0000000000000000000000000000000000000000..0da02cf23c65d82b2d78e6fa7eb18e8cff3d1ebb --- /dev/null +++ b/py_opengauss/python/msw.py @@ -0,0 +1,17 @@ +## +# .python.msw +## +""" +Additional Microsoft Windows tools. +""" + +# for Popen(), not supported on windows +close_fds = False + +def platform_exe(name): + """ + Append '.exe' if it's not already there. + """ + if name.endswith('.exe'): + return name + return name + '.exe' diff --git a/py_opengauss/python/os.py b/py_opengauss/python/os.py new file mode 100644 index 0000000000000000000000000000000000000000..6cdb7cc38295245118062e80aeaef77955a3ec88 --- /dev/null +++ b/py_opengauss/python/os.py @@ -0,0 +1,34 @@ +## +# .python.os +## +""" +General OS abstractions and information. +""" +import sys +import os + +#: By default, close the FDs on subprocess.Popen(). +close_fds = True + +#: By default, there is no modification for executable references. +platform_exe = str + +def find_file(basename, paths, + join = os.path.join, exists = os.path.exists, +): + """ + Find the file in the given paths. Return the first path + that exists. 
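+
+	Returns None when the file is not present in any of the given paths.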
+ """ + for x in paths: + path = join(x, basename) + if exists(path): + return path + +if sys.platform in ('win32','win64'): + # replace variants for windows + from .msw import close_fds, platform_exe + +def find_executable(basename, pathsep = os.pathsep, platexe = platform_exe): + paths = os.environ.get('PATH', '').split(pathsep) + return find_file(platexe(basename), paths) diff --git a/py_opengauss/python/socket.py b/py_opengauss/python/socket.py new file mode 100644 index 0000000000000000000000000000000000000000..6cdffdca2b678d042efd33cda9050ead7f25bd9a --- /dev/null +++ b/py_opengauss/python/socket.py @@ -0,0 +1,100 @@ +## +# .python.socket - additional tools for working with sockets +## +import sys +import os +import random +import socket +import errno +import ssl + +__all__ = ['find_available_port', 'SocketFactory'] + +class SocketFactory(object): + """ + Object used to create a socket and connect it. + + This is, more or less, a specialized partial() for socket creation. + + Additionally, it provides methods and attributes for abstracting + exception management on socket operation. + """ + + timeout_exception = socket.timeout + fatal_exception = socket.error + try_again_exception = socket.error + + def timed_out(self, err) -> bool: + return err.__class__ is self.timeout_exception + + @staticmethod + def try_again(err, codes = (errno.EAGAIN, errno.EINTR, errno.EWOULDBLOCK, errno.ETIMEDOUT)) -> bool: + """ + Does the error indicate that the operation should be tried again? + + More importantly, the connection is *not* dead. + """ + errno = getattr(err, 'errno', None) + if errno is None: + return False + return errno in codes + + @classmethod + def fatal_exception_message(typ, err) -> (str, None): + """ + If the exception was fatal to the connection, + what message should be given to the user? + """ + if typ.try_again(err): + return None + return getattr(err, 'strerror', '') + + def secure(self, socket : socket.socket) -> ssl.SSLSocket: + """ + Secure a socket with SSL. + """ + if self.socket_secure is not None: + return ssl.wrap_socket(socket, **self.socket_secure) + else: + return ssl.wrap_socket(socket) + + def __call__(self, timeout = None): + s = socket.socket(*self.socket_create) + try: + s.settimeout(float(timeout) if timeout is not None else None) + s.connect(self.socket_connect) + s.settimeout(None) + except Exception: + s.close() + raise + return s + + def __init__(self, + socket_create, + socket_connect, + socket_secure = None, + ): + self.socket_create = socket_create + self.socket_connect = socket_connect + self.socket_secure = socket_secure + + def __str__(self): + return 'socket' + repr(self.socket_connect) + +def find_available_port( + interface = 'localhost', + address_family = socket.AF_INET, +): + """ + Find an available port on the given interface for the given address family. 
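+
+	The port is discovered by binding to an ephemeral port (port zero) and
+	reading back the number assigned by the operating system.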
+ """ + + port = None + s = socket.socket(address_family, socket.SOCK_STREAM,) + try: + s.bind(('localhost', 0)) + port = s.getsockname()[1] + finally: + s.close() + + return port diff --git a/py_opengauss/python/string.py b/py_opengauss/python/string.py new file mode 100644 index 0000000000000000000000000000000000000000..304f25d5754acaac11ac570a6fb27015b7e3e645 --- /dev/null +++ b/py_opengauss/python/string.py @@ -0,0 +1,11 @@ +## +# .python.string +## +import os + +def indent(s, level = 2, char = ' '): + ind = char * level + r = "" + for x in s.splitlines(): + r += ((ind + x).rstrip() + os.linesep) + return r diff --git a/py_opengauss/python/structlib.py b/py_opengauss/python/structlib.py new file mode 100644 index 0000000000000000000000000000000000000000..7560accaed4e8743ae88f427a552b7391641a2d5 --- /dev/null +++ b/py_opengauss/python/structlib.py @@ -0,0 +1,108 @@ +## +# .python.structlib - module for extracting serialized data +## +import struct +from .functools import Composition as compose + +null_sequence = b'\xff\xff\xff\xff' + +# Always to and from network order. +# Create a pair, (pack, unpack) for the given `struct` format.' +def mk_pack(x): + s = struct.Struct('!' + x) + if len(x) > 1: + def pack(y, p = s.pack): + return p(*y) + return (pack, s.unpack_from) + else: + def unpack(y, p = s.unpack_from): + return p(y)[0] + return (s.pack, unpack) + +byte_pack, byte_unpack = lambda x: bytes((x,)), lambda x: x[0] +double_pack, double_unpack = mk_pack("d") +float_pack, float_unpack = mk_pack("f") +dd_pack, dd_unpack = mk_pack("dd") +ddd_pack, ddd_unpack = mk_pack("ddd") +dddd_pack, dddd_unpack = mk_pack("dddd") +LH_pack, LH_unpack = mk_pack("LH") +lH_pack, lH_unpack = mk_pack("lH") +llL_pack, llL_unpack = mk_pack("llL") +qll_pack, qll_unpack = mk_pack("qll") +dll_pack, dll_unpack = mk_pack("dll") + +dl_pack, dl_unpack = mk_pack("dl") +ql_pack, ql_unpack = mk_pack("ql") + +hhhh_pack, hhhh_unpack = mk_pack("hhhh") + +longlong_pack, longlong_unpack = mk_pack("q") +ulonglong_pack, ulonglong_unpack = mk_pack("Q") + +# Optimizations for int2, int4, and int8. 
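+# The serialized form is network (big-endian) order, so on little-endian hosts
+# the byte-swapping variants from the C extension are preferred; big-endian
+# hosts use the direct variants. Pure struct-based fallbacks are defined on
+# ImportError.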
+try: + from ..port import optimized as opt + from sys import byteorder as bo + if bo == 'little': + short_unpack = opt.swap_int2_unpack + short_pack = opt.swap_int2_pack + ushort_unpack = opt.swap_uint2_unpack + ushort_pack = opt.swap_uint2_pack + long_unpack = opt.swap_int4_unpack + long_pack = opt.swap_int4_pack + ulong_unpack = opt.swap_uint4_unpack + ulong_pack = opt.swap_uint4_pack + + if hasattr(opt, 'uint8_pack'): + longlong_unpack = opt.swap_int8_unpack + longlong_pack = opt.swap_int8_pack + ulonglong_unpack = opt.swap_uint8_unpack + ulonglong_pack = opt.swap_uint8_pack + elif bo == 'big': + short_unpack = opt.int2_unpack + short_pack = opt.int2_pack + ushort_unpack = opt.uint2_unpack + ushort_pack = opt.uint2_pack + long_unpack = opt.int4_unpack + long_pack = opt.int4_pack + ulong_unpack = opt.uint4_unpack + ulong_pack = opt.uint4_pack + + if hasattr(opt, 'uint8_pack'): + longlong_unpack = opt.int8_unpack + longlong_pack = opt.int8_pack + ulonglong_unpack = opt.uint8_unpack + ulonglong_pack = opt.uint8_pack + del bo, opt +except ImportError: + short_pack, short_unpack = mk_pack("h") + ushort_pack, ushort_unpack = mk_pack("H") + long_pack, long_unpack = mk_pack("l") + ulong_pack, ulong_unpack = mk_pack("L") + +def split_sized_data( + data, + ulong_unpack = ulong_unpack, + null_field = 0xFFFFFFFF, + len = len, + errmsg = "insufficient data in field {0}, required {1} bytes, {2} remaining".format +): + """ + Given serialized record data, return a tuple of tuples of type Oids and + attributes. + """ + v = memoryview(data) + f = 1 + while v: + l = ulong_unpack(v) + if l == null_field: + v = v[4:] + yield None + continue + l += 4 + d = v[4:l].tobytes() + if len(d) < l-4: + raise ValueError(errmsg(f, l - 4, len(d))) + v = v[l:] + f += 1 + yield d diff --git a/py_opengauss/release/__init__.py b/py_opengauss/release/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71af809309df5e4e7756e1c12449727d4cc27a75 --- /dev/null +++ b/py_opengauss/release/__init__.py @@ -0,0 +1,6 @@ +## +# .release +## +""" +Release management code and project/release meta-data. +""" diff --git a/py_opengauss/release/distutils.py b/py_opengauss/release/distutils.py new file mode 100644 index 0000000000000000000000000000000000000000..836b67f8f6e349d8274852a1420aac55746d62c1 --- /dev/null +++ b/py_opengauss/release/distutils.py @@ -0,0 +1,205 @@ +## +# .release.distutils - distutils data +## +""" +Python distutils data provisions module. + +For sub-packagers, the `prefixed_packages` and `prefixed_extensions` functions +should be of particular interest. If the distribution including ``py-postgresql`` +uses the standard layout, chances are that `prefixed_extensions` and +`prefixed_packages` will supply the appropriate information by default as they +use `default_prefix` which is derived from the module's `__package__`. +""" +import sys +import os +from ..project import version, name, identity as url +try: + from setuptools import Extension, Command +except ImportError as e: + from distutils.core import Extension, Command + +LONG_DESCRIPTION = """ +This package is based on py-postgresql upgrades to work with openGauss. + +Forked Repo: http://github.com/vimiix/py-opengauss + +.. warning:: + In v1.3, `postgresql.driver.dbapi20.connect` will now raise `ClientCannotConnectError` directly. + Exception traps around connect should still function, but the `__context__` attribute + on the error instance will be `None` in the usual failure case as it is no longer + incorrectly chained. 
Trapping `ClientCannotConnectError` ahead of `Error` should + allow both cases to co-exist in the event that data is being extracted from + the `ClientCannotConnectError`. + +py-postgresql is a set of Python modules providing interfaces to various parts +of PostgreSQL. Primarily, it provides a pure-Python driver with some C optimizations for +querying a PostgreSQL database. + +http://github.com/python-postgres/fe + +Features: + + * Prepared Statement driven interfaces. + * Cluster tools for creating and controlling a cluster. + * Support for most PostgreSQL types: composites, arrays, numeric, lots more. + * COPY support. + +Sample PG-API Code:: + + >>> import postgresql + >>> db = postgresql.open('pq://user:password@host:port/database') + >>> db.execute("CREATE TABLE emp (emp_first_name text, emp_last_name text, emp_salary numeric)") + >>> make_emp = db.prepare("INSERT INTO emp VALUES ($1, $2, $3)") + >>> make_emp("John", "Doe", "75,322") + >>> with db.xact(): + ... make_emp("Jane", "Doe", "75,322") + ... make_emp("Edward", "Johnson", "82,744") + ... + +There is a DB-API 2.0 module as well:: + + postgresql.driver.dbapi20 + +However, PG-API is recommended as it provides greater utility. + +Once installed, try out the ``pg_python`` console script:: + + $ python3 -m postgresql.bin.pg_python -h localhost -p port -U theuser -d database_name + +If a successful connection is made to the remote host, it will provide a Python +console with the database connection bound to the `db` name. +""" + +CLASSIFIERS = [ + 'Development Status :: 5 - Production/Stable', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: BSD License', + 'License :: OSI Approved :: MIT License', + 'License :: OSI Approved :: Attribution Assurance License', + 'License :: OSI Approved :: Python Software Foundation License', + 'Natural Language :: English', + 'Operating System :: OS Independent', + 'Programming Language :: Python', + 'Programming Language :: Python :: 3', + 'Topic :: Database', +] + +subpackages = [ + 'bin', + 'encodings', + 'lib', + 'protocol', + 'driver', + 'test', + 'documentation', + 'python', + 'port', + 'release', + # Modules imported from other packages. + 'resolved', + 'types', + 'types.io', +] +extensions_data = { + 'port.optimized' : { + 'sources' : [os.path.join('port', '_optimized', 'module.c')], + }, +} + +subpackage_data = { + 'lib' : ['*.sql'], + 'documentation' : ['*.txt'] +} + +try: + # :) + if __package__ is not None: + default_prefix = __package__.split('.')[:-1] + else: + default_prefix = __name__.split('.')[:-2] +except NameError: + default_prefix = ['postgresql'] + +def prefixed_extensions( + prefix = default_prefix, + extensions_data = extensions_data, +) -> [Extension]: + """ + Generator producing the `distutils` `Extension` objects. + """ + pkg_prefix = '.'.join(prefix) + '.' + path_prefix = os.path.sep.join(prefix) + for mod, data in extensions_data.items(): + yield Extension( + pkg_prefix + mod, + [os.path.join(path_prefix, src) for src in data['sources']], + libraries = data.get('libraries', ()), + optional = True, + ) + +def prefixed_packages( + prefix = default_prefix, + packages = subpackages, +): + """ + Generator producing the standard `package` list prefixed with `prefix`. + """ + prefix = '.'.join(prefix) + yield prefix + prefix = prefix + '.' + for pkg in packages: + yield prefix + pkg + +def prefixed_package_data( + prefix = default_prefix, + package_data = subpackage_data, +): + """ + Generator producing the standard `package` list prefixed with `prefix`. 
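+
+	Yields (package name, data file patterns) pairs suitable for the
+	``package_data`` setup keyword.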
+ """ + prefix = '.'.join(prefix) + prefix = prefix + '.' + for pkg, data in package_data.items(): + yield prefix + pkg, data + +def standard_setup_keywords(build_extensions = True, prefix = default_prefix): + """ + Used by the py-postgresql distribution. + """ + d = { + 'name' : name, + 'version' : version, + 'description' : 'Opengauss driver and tools library.', + 'long_description' : LONG_DESCRIPTION, + 'long_description_content_type' : 'text/x-rst', + 'author' : 'James William Pye', + 'author_email' : 'james.pye@gmail.com', + 'maintainer' : 'James William Pye', + 'maintainer_email' : 'james.pye@gmail.com', + 'url' : url, + 'classifiers' : CLASSIFIERS, + 'packages' : list(prefixed_packages(prefix = prefix)), + 'package_data' : dict(prefixed_package_data(prefix = prefix)), + 'cmdclass': dict(test=TestCommand), + 'python_requires': '>=3.3', + } + if build_extensions: + d['ext_modules'] = list(prefixed_extensions(prefix = prefix)) + return d + +class TestCommand(Command): + description = "run tests" + + # List of option tuples: long name, short name (None if no short + # name), and help string. + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + import unittest + unittest.main(module='postgresql.test.testall', argv=('setup.py',)) diff --git a/py_opengauss/resolved/__init__.py b/py_opengauss/resolved/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..51809b908d8e0dec20ede75289b0742f28d4fc7e --- /dev/null +++ b/py_opengauss/resolved/__init__.py @@ -0,0 +1,3 @@ +""" +Modules and packages resolved to avoid user dependency resolution. +""" diff --git a/py_opengauss/resolved/crypt.py b/py_opengauss/resolved/crypt.py new file mode 100644 index 0000000000000000000000000000000000000000..5524cda13684c254c826011c0c2131ff64fc7965 --- /dev/null +++ b/py_opengauss/resolved/crypt.py @@ -0,0 +1,619 @@ +# fcrypt.py + +"""Unix crypt(3) password hash algorithm. + +This is a port to Python of the standard Unix password crypt function. +It's a single self-contained source file that works with any version +of Python from version 1.5 or higher. The code is based on Eric +Young's optimised crypt in C. + +Python fcrypt is intended for users whose Python installation has not +had the crypt module enabled, or whose C library doesn't include the +crypt function. See the documentation for the Python crypt module for +more information: + + http://www.python.org/doc/current/lib/module-crypt.html + +An alternative Python crypt module that uses the MD5 algorithm and is +more secure than fcrypt is available from michal j wallace at: + + http://www.sabren.net/code/python/crypt/index.php3 + +The crypt() function is a one-way hash function, intended to hide a +password such that the only way to find out the original password is +to guess values until you get a match. If you need to encrypt and +decrypt data, this is not the module for you. + +There are at least two packages providing Python cryptography support: +M2Crypto at , and amkCrypto at +. + +Functions: + + crypt() -- return hashed password +""" + +__author__ = 'Carey Evans ' +__version__ = '1.3.1' +__date__ = '21 February 2004' +__credits__ = '''michal j wallace for inspiring me to write this. 
+Eric Young for the C code this module was copied from.''' + +__all__ = ['crypt'] + + +# Copyright (C) 2000, 2001, 2004 Carey Evans +# +# Permission to use, copy, modify, and distribute this software and +# its documentation for any purpose and without fee is hereby granted, +# provided that the above copyright notice appear in all copies and +# that both that copyright notice and this permission notice appear in +# supporting documentation. +# +# CAREY EVANS DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, +# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO +# EVENT SHALL CAREY EVANS BE LIABLE FOR ANY SPECIAL, INDIRECT OR +# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF +# USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR +# OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +# PERFORMANCE OF THIS SOFTWARE. + +# Based on C code by Eric Young (eay@mincom.oz.au), which has the +# following copyright. Especially note condition 3, which imposes +# extra restrictions on top of the standard Python license used above. +# +# The fcrypt.c source is available from: +# ftp://ftp.psy.uq.oz.au/pub/Crypto/DES/ + +# ----- BEGIN fcrypt.c LICENSE ----- +# +# This library is free for commercial and non-commercial use as long as +# the following conditions are aheared to. The following conditions +# apply to all code found in this distribution, be it the RC4, RSA, +# lhash, DES, etc., code; not just the SSL code. The SSL documentation +# included with this distribution is covered by the same copyright terms +# except that the holder is Tim Hudson (tjh@mincom.oz.au). +# +# Copyright remains Eric Young's, and as such any Copyright notices in +# the code are not to be removed. +# If this package is used in a product, Eric Young should be given attribution +# as the author of the parts of the library used. +# This can be in the form of a textual message at program startup or +# in documentation (online or textual) provided with the package. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# 1. Redistributions of source code must retain the copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# 3. All advertising materials mentioning features or use of this software +# must display the following acknowledgement: +# "This product includes cryptographic software written by +# Eric Young (eay@mincom.oz.au)" +# The word 'cryptographic' can be left out if the rouines from the library +# being used are not cryptographic related :-). +# 4. If you include any Windows specific code (or a derivative thereof) from +# the apps directory (application code) you must include an acknowledgement: +# "This product includes software written by Tim Hudson (tjh@mincom.oz.au)" +# +# THIS SOFTWARE IS PROVIDED BY ERIC YOUNG ``AS IS'' AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +# SUCH DAMAGE. +# +# The licence and distribution terms for any publically available version or +# derivative of this code cannot be changed. i.e. this code cannot simply be +# copied and put under another distribution licence +# [including the GNU Public Licence.] +# +# ----- END fcrypt.c LICENSE ----- + + +import string, struct + + +_ITERATIONS = 16 + +_SPtrans = ( + # nibble 0 + ( 0x00820200, 0x00020000, 0x80800000, 0x80820200, + 0x00800000, 0x80020200, 0x80020000, 0x80800000, + 0x80020200, 0x00820200, 0x00820000, 0x80000200, + 0x80800200, 0x00800000, 0x00000000, 0x80020000, + 0x00020000, 0x80000000, 0x00800200, 0x00020200, + 0x80820200, 0x00820000, 0x80000200, 0x00800200, + 0x80000000, 0x00000200, 0x00020200, 0x80820000, + 0x00000200, 0x80800200, 0x80820000, 0x00000000, + 0x00000000, 0x80820200, 0x00800200, 0x80020000, + 0x00820200, 0x00020000, 0x80000200, 0x00800200, + 0x80820000, 0x00000200, 0x00020200, 0x80800000, + 0x80020200, 0x80000000, 0x80800000, 0x00820000, + 0x80820200, 0x00020200, 0x00820000, 0x80800200, + 0x00800000, 0x80000200, 0x80020000, 0x00000000, + 0x00020000, 0x00800000, 0x80800200, 0x00820200, + 0x80000000, 0x80820000, 0x00000200, 0x80020200 ), + + # nibble 1 + ( 0x10042004, 0x00000000, 0x00042000, 0x10040000, + 0x10000004, 0x00002004, 0x10002000, 0x00042000, + 0x00002000, 0x10040004, 0x00000004, 0x10002000, + 0x00040004, 0x10042000, 0x10040000, 0x00000004, + 0x00040000, 0x10002004, 0x10040004, 0x00002000, + 0x00042004, 0x10000000, 0x00000000, 0x00040004, + 0x10002004, 0x00042004, 0x10042000, 0x10000004, + 0x10000000, 0x00040000, 0x00002004, 0x10042004, + 0x00040004, 0x10042000, 0x10002000, 0x00042004, + 0x10042004, 0x00040004, 0x10000004, 0x00000000, + 0x10000000, 0x00002004, 0x00040000, 0x10040004, + 0x00002000, 0x10000000, 0x00042004, 0x10002004, + 0x10042000, 0x00002000, 0x00000000, 0x10000004, + 0x00000004, 0x10042004, 0x00042000, 0x10040000, + 0x10040004, 0x00040000, 0x00002004, 0x10002000, + 0x10002004, 0x00000004, 0x10040000, 0x00042000 ), + + # nibble 2 + ( 0x41000000, 0x01010040, 0x00000040, 0x41000040, + 0x40010000, 0x01000000, 0x41000040, 0x00010040, + 0x01000040, 0x00010000, 0x01010000, 0x40000000, + 0x41010040, 0x40000040, 0x40000000, 0x41010000, + 0x00000000, 0x40010000, 0x01010040, 0x00000040, + 0x40000040, 0x41010040, 0x00010000, 0x41000000, + 0x41010000, 0x01000040, 0x40010040, 0x01010000, + 0x00010040, 0x00000000, 0x01000000, 0x40010040, + 0x01010040, 0x00000040, 0x40000000, 0x00010000, + 0x40000040, 0x40010000, 0x01010000, 0x41000040, + 0x00000000, 0x01010040, 0x00010040, 0x41010000, + 0x40010000, 0x01000000, 0x41010040, 0x40000000, + 0x40010040, 0x41000000, 0x01000000, 0x41010040, + 0x00010000, 0x01000040, 0x41000040, 0x00010040, + 0x01000040, 0x00000000, 0x41010000, 0x40000040, + 0x41000000, 0x40010040, 0x00000040, 0x01010000 ), + + # nibble 3 + ( 0x00100402, 0x04000400, 0x00000002, 0x04100402, + 0x00000000, 0x04100000, 0x04000402, 0x00100002, + 0x04100400, 0x04000002, 0x04000000, 0x00000402, + 0x04000002, 0x00100402, 0x00100000, 0x04000000, + 
0x04100002, 0x00100400, 0x00000400, 0x00000002, + 0x00100400, 0x04000402, 0x04100000, 0x00000400, + 0x00000402, 0x00000000, 0x00100002, 0x04100400, + 0x04000400, 0x04100002, 0x04100402, 0x00100000, + 0x04100002, 0x00000402, 0x00100000, 0x04000002, + 0x00100400, 0x04000400, 0x00000002, 0x04100000, + 0x04000402, 0x00000000, 0x00000400, 0x00100002, + 0x00000000, 0x04100002, 0x04100400, 0x00000400, + 0x04000000, 0x04100402, 0x00100402, 0x00100000, + 0x04100402, 0x00000002, 0x04000400, 0x00100402, + 0x00100002, 0x00100400, 0x04100000, 0x04000402, + 0x00000402, 0x04000000, 0x04000002, 0x04100400 ), + + # nibble 4 + ( 0x02000000, 0x00004000, 0x00000100, 0x02004108, + 0x02004008, 0x02000100, 0x00004108, 0x02004000, + 0x00004000, 0x00000008, 0x02000008, 0x00004100, + 0x02000108, 0x02004008, 0x02004100, 0x00000000, + 0x00004100, 0x02000000, 0x00004008, 0x00000108, + 0x02000100, 0x00004108, 0x00000000, 0x02000008, + 0x00000008, 0x02000108, 0x02004108, 0x00004008, + 0x02004000, 0x00000100, 0x00000108, 0x02004100, + 0x02004100, 0x02000108, 0x00004008, 0x02004000, + 0x00004000, 0x00000008, 0x02000008, 0x02000100, + 0x02000000, 0x00004100, 0x02004108, 0x00000000, + 0x00004108, 0x02000000, 0x00000100, 0x00004008, + 0x02000108, 0x00000100, 0x00000000, 0x02004108, + 0x02004008, 0x02004100, 0x00000108, 0x00004000, + 0x00004100, 0x02004008, 0x02000100, 0x00000108, + 0x00000008, 0x00004108, 0x02004000, 0x02000008 ), + + # nibble 5 + ( 0x20000010, 0x00080010, 0x00000000, 0x20080800, + 0x00080010, 0x00000800, 0x20000810, 0x00080000, + 0x00000810, 0x20080810, 0x00080800, 0x20000000, + 0x20000800, 0x20000010, 0x20080000, 0x00080810, + 0x00080000, 0x20000810, 0x20080010, 0x00000000, + 0x00000800, 0x00000010, 0x20080800, 0x20080010, + 0x20080810, 0x20080000, 0x20000000, 0x00000810, + 0x00000010, 0x00080800, 0x00080810, 0x20000800, + 0x00000810, 0x20000000, 0x20000800, 0x00080810, + 0x20080800, 0x00080010, 0x00000000, 0x20000800, + 0x20000000, 0x00000800, 0x20080010, 0x00080000, + 0x00080010, 0x20080810, 0x00080800, 0x00000010, + 0x20080810, 0x00080800, 0x00080000, 0x20000810, + 0x20000010, 0x20080000, 0x00080810, 0x00000000, + 0x00000800, 0x20000010, 0x20000810, 0x20080800, + 0x20080000, 0x00000810, 0x00000010, 0x20080010 ), + + # nibble 6 + ( 0x00001000, 0x00000080, 0x00400080, 0x00400001, + 0x00401081, 0x00001001, 0x00001080, 0x00000000, + 0x00400000, 0x00400081, 0x00000081, 0x00401000, + 0x00000001, 0x00401080, 0x00401000, 0x00000081, + 0x00400081, 0x00001000, 0x00001001, 0x00401081, + 0x00000000, 0x00400080, 0x00400001, 0x00001080, + 0x00401001, 0x00001081, 0x00401080, 0x00000001, + 0x00001081, 0x00401001, 0x00000080, 0x00400000, + 0x00001081, 0x00401000, 0x00401001, 0x00000081, + 0x00001000, 0x00000080, 0x00400000, 0x00401001, + 0x00400081, 0x00001081, 0x00001080, 0x00000000, + 0x00000080, 0x00400001, 0x00000001, 0x00400080, + 0x00000000, 0x00400081, 0x00400080, 0x00001080, + 0x00000081, 0x00001000, 0x00401081, 0x00400000, + 0x00401080, 0x00000001, 0x00001001, 0x00401081, + 0x00400001, 0x00401080, 0x00401000, 0x00001001 ), + + # nibble 7 + ( 0x08200020, 0x08208000, 0x00008020, 0x00000000, + 0x08008000, 0x00200020, 0x08200000, 0x08208020, + 0x00000020, 0x08000000, 0x00208000, 0x00008020, + 0x00208020, 0x08008020, 0x08000020, 0x08200000, + 0x00008000, 0x00208020, 0x00200020, 0x08008000, + 0x08208020, 0x08000020, 0x00000000, 0x00208000, + 0x08000000, 0x00200000, 0x08008020, 0x08200020, + 0x00200000, 0x00008000, 0x08208000, 0x00000020, + 0x00200000, 0x00008000, 0x08000020, 0x08208020, + 0x00008020, 0x08000000, 
0x00000000, 0x00208000, + 0x08200020, 0x08008020, 0x08008000, 0x00200020, + 0x08208000, 0x00000020, 0x00200020, 0x08008000, + 0x08208020, 0x00200000, 0x08200000, 0x08000020, + 0x00208000, 0x00008020, 0x08008020, 0x08200000, + 0x00000020, 0x08208000, 0x00208020, 0x00000000, + 0x08000000, 0x08200020, 0x00008000, 0x00208020 ), +) + +_skb = ( + # for C bits (numbered as per FIPS 46) 1 2 3 4 5 6 + ( 0x00000000, 0x00000010, 0x20000000, 0x20000010, + 0x00010000, 0x00010010, 0x20010000, 0x20010010, + 0x00000800, 0x00000810, 0x20000800, 0x20000810, + 0x00010800, 0x00010810, 0x20010800, 0x20010810, + 0x00000020, 0x00000030, 0x20000020, 0x20000030, + 0x00010020, 0x00010030, 0x20010020, 0x20010030, + 0x00000820, 0x00000830, 0x20000820, 0x20000830, + 0x00010820, 0x00010830, 0x20010820, 0x20010830, + 0x00080000, 0x00080010, 0x20080000, 0x20080010, + 0x00090000, 0x00090010, 0x20090000, 0x20090010, + 0x00080800, 0x00080810, 0x20080800, 0x20080810, + 0x00090800, 0x00090810, 0x20090800, 0x20090810, + 0x00080020, 0x00080030, 0x20080020, 0x20080030, + 0x00090020, 0x00090030, 0x20090020, 0x20090030, + 0x00080820, 0x00080830, 0x20080820, 0x20080830, + 0x00090820, 0x00090830, 0x20090820, 0x20090830 ), + + # for C bits (numbered as per FIPS 46) 7 8 10 11 12 13 + ( 0x00000000, 0x02000000, 0x00002000, 0x02002000, + 0x00200000, 0x02200000, 0x00202000, 0x02202000, + 0x00000004, 0x02000004, 0x00002004, 0x02002004, + 0x00200004, 0x02200004, 0x00202004, 0x02202004, + 0x00000400, 0x02000400, 0x00002400, 0x02002400, + 0x00200400, 0x02200400, 0x00202400, 0x02202400, + 0x00000404, 0x02000404, 0x00002404, 0x02002404, + 0x00200404, 0x02200404, 0x00202404, 0x02202404, + 0x10000000, 0x12000000, 0x10002000, 0x12002000, + 0x10200000, 0x12200000, 0x10202000, 0x12202000, + 0x10000004, 0x12000004, 0x10002004, 0x12002004, + 0x10200004, 0x12200004, 0x10202004, 0x12202004, + 0x10000400, 0x12000400, 0x10002400, 0x12002400, + 0x10200400, 0x12200400, 0x10202400, 0x12202400, + 0x10000404, 0x12000404, 0x10002404, 0x12002404, + 0x10200404, 0x12200404, 0x10202404, 0x12202404 ), + + # for C bits (numbered as per FIPS 46) 14 15 16 17 19 20 + ( 0x00000000, 0x00000001, 0x00040000, 0x00040001, + 0x01000000, 0x01000001, 0x01040000, 0x01040001, + 0x00000002, 0x00000003, 0x00040002, 0x00040003, + 0x01000002, 0x01000003, 0x01040002, 0x01040003, + 0x00000200, 0x00000201, 0x00040200, 0x00040201, + 0x01000200, 0x01000201, 0x01040200, 0x01040201, + 0x00000202, 0x00000203, 0x00040202, 0x00040203, + 0x01000202, 0x01000203, 0x01040202, 0x01040203, + 0x08000000, 0x08000001, 0x08040000, 0x08040001, + 0x09000000, 0x09000001, 0x09040000, 0x09040001, + 0x08000002, 0x08000003, 0x08040002, 0x08040003, + 0x09000002, 0x09000003, 0x09040002, 0x09040003, + 0x08000200, 0x08000201, 0x08040200, 0x08040201, + 0x09000200, 0x09000201, 0x09040200, 0x09040201, + 0x08000202, 0x08000203, 0x08040202, 0x08040203, + 0x09000202, 0x09000203, 0x09040202, 0x09040203 ), + + # for C bits (numbered as per FIPS 46) 21 23 24 26 27 28 + ( 0x00000000, 0x00100000, 0x00000100, 0x00100100, + 0x00000008, 0x00100008, 0x00000108, 0x00100108, + 0x00001000, 0x00101000, 0x00001100, 0x00101100, + 0x00001008, 0x00101008, 0x00001108, 0x00101108, + 0x04000000, 0x04100000, 0x04000100, 0x04100100, + 0x04000008, 0x04100008, 0x04000108, 0x04100108, + 0x04001000, 0x04101000, 0x04001100, 0x04101100, + 0x04001008, 0x04101008, 0x04001108, 0x04101108, + 0x00020000, 0x00120000, 0x00020100, 0x00120100, + 0x00020008, 0x00120008, 0x00020108, 0x00120108, + 0x00021000, 0x00121000, 0x00021100, 0x00121100, + 0x00021008, 
0x00121008, 0x00021108, 0x00121108, + 0x04020000, 0x04120000, 0x04020100, 0x04120100, + 0x04020008, 0x04120008, 0x04020108, 0x04120108, + 0x04021000, 0x04121000, 0x04021100, 0x04121100, + 0x04021008, 0x04121008, 0x04021108, 0x04121108 ), + + # for D bits (numbered as per FIPS 46) 1 2 3 4 5 6 + ( 0x00000000, 0x10000000, 0x00010000, 0x10010000, + 0x00000004, 0x10000004, 0x00010004, 0x10010004, + 0x20000000, 0x30000000, 0x20010000, 0x30010000, + 0x20000004, 0x30000004, 0x20010004, 0x30010004, + 0x00100000, 0x10100000, 0x00110000, 0x10110000, + 0x00100004, 0x10100004, 0x00110004, 0x10110004, + 0x20100000, 0x30100000, 0x20110000, 0x30110000, + 0x20100004, 0x30100004, 0x20110004, 0x30110004, + 0x00001000, 0x10001000, 0x00011000, 0x10011000, + 0x00001004, 0x10001004, 0x00011004, 0x10011004, + 0x20001000, 0x30001000, 0x20011000, 0x30011000, + 0x20001004, 0x30001004, 0x20011004, 0x30011004, + 0x00101000, 0x10101000, 0x00111000, 0x10111000, + 0x00101004, 0x10101004, 0x00111004, 0x10111004, + 0x20101000, 0x30101000, 0x20111000, 0x30111000, + 0x20101004, 0x30101004, 0x20111004, 0x30111004 ), + + # for D bits (numbered as per FIPS 46) 8 9 11 12 13 14 + ( 0x00000000, 0x08000000, 0x00000008, 0x08000008, + 0x00000400, 0x08000400, 0x00000408, 0x08000408, + 0x00020000, 0x08020000, 0x00020008, 0x08020008, + 0x00020400, 0x08020400, 0x00020408, 0x08020408, + 0x00000001, 0x08000001, 0x00000009, 0x08000009, + 0x00000401, 0x08000401, 0x00000409, 0x08000409, + 0x00020001, 0x08020001, 0x00020009, 0x08020009, + 0x00020401, 0x08020401, 0x00020409, 0x08020409, + 0x02000000, 0x0A000000, 0x02000008, 0x0A000008, + 0x02000400, 0x0A000400, 0x02000408, 0x0A000408, + 0x02020000, 0x0A020000, 0x02020008, 0x0A020008, + 0x02020400, 0x0A020400, 0x02020408, 0x0A020408, + 0x02000001, 0x0A000001, 0x02000009, 0x0A000009, + 0x02000401, 0x0A000401, 0x02000409, 0x0A000409, + 0x02020001, 0x0A020001, 0x02020009, 0x0A020009, + 0x02020401, 0x0A020401, 0x02020409, 0x0A020409 ), + + # for D bits (numbered as per FIPS 46) 16 17 18 19 20 21 + ( 0x00000000, 0x00000100, 0x00080000, 0x00080100, + 0x01000000, 0x01000100, 0x01080000, 0x01080100, + 0x00000010, 0x00000110, 0x00080010, 0x00080110, + 0x01000010, 0x01000110, 0x01080010, 0x01080110, + 0x00200000, 0x00200100, 0x00280000, 0x00280100, + 0x01200000, 0x01200100, 0x01280000, 0x01280100, + 0x00200010, 0x00200110, 0x00280010, 0x00280110, + 0x01200010, 0x01200110, 0x01280010, 0x01280110, + 0x00000200, 0x00000300, 0x00080200, 0x00080300, + 0x01000200, 0x01000300, 0x01080200, 0x01080300, + 0x00000210, 0x00000310, 0x00080210, 0x00080310, + 0x01000210, 0x01000310, 0x01080210, 0x01080310, + 0x00200200, 0x00200300, 0x00280200, 0x00280300, + 0x01200200, 0x01200300, 0x01280200, 0x01280300, + 0x00200210, 0x00200310, 0x00280210, 0x00280310, + 0x01200210, 0x01200310, 0x01280210, 0x01280310 ), + + # for D bits (numbered as per FIPS 46) 22 23 24 25 27 28 + ( 0x00000000, 0x04000000, 0x00040000, 0x04040000, + 0x00000002, 0x04000002, 0x00040002, 0x04040002, + 0x00002000, 0x04002000, 0x00042000, 0x04042000, + 0x00002002, 0x04002002, 0x00042002, 0x04042002, + 0x00000020, 0x04000020, 0x00040020, 0x04040020, + 0x00000022, 0x04000022, 0x00040022, 0x04040022, + 0x00002020, 0x04002020, 0x00042020, 0x04042020, + 0x00002022, 0x04002022, 0x00042022, 0x04042022, + 0x00000800, 0x04000800, 0x00040800, 0x04040800, + 0x00000802, 0x04000802, 0x00040802, 0x04040802, + 0x00002800, 0x04002800, 0x00042800, 0x04042800, + 0x00002802, 0x04002802, 0x00042802, 0x04042802, + 0x00000820, 0x04000820, 0x00040820, 0x04040820, + 0x00000822, 
0x04000822, 0x00040822, 0x04040822, + 0x00002820, 0x04002820, 0x00042820, 0x04042820, + 0x00002822, 0x04002822, 0x00042822, 0x04042822 ) +) + +_shifts2 = (0,0,1,1,1,1,1,1,0,1,1,1,1,1,1,0) + +_con_salt = ( + 0xD2,0xD3,0xD4,0xD5,0xD6,0xD7,0xD8,0xD9, + 0xDA,0xDB,0xDC,0xDD,0xDE,0xDF,0xE0,0xE1, + 0xE2,0xE3,0xE4,0xE5,0xE6,0xE7,0xE8,0xE9, + 0xEA,0xEB,0xEC,0xED,0xEE,0xEF,0xF0,0xF1, + 0xF2,0xF3,0xF4,0xF5,0xF6,0xF7,0xF8,0xF9, + 0xFA,0xFB,0xFC,0xFD,0xFE,0xFF,0x00,0x01, + 0x02,0x03,0x04,0x05,0x06,0x07,0x08,0x09, + 0x0A,0x0B,0x05,0x06,0x07,0x08,0x09,0x0A, + 0x0B,0x0C,0x0D,0x0E,0x0F,0x10,0x11,0x12, + 0x13,0x14,0x15,0x16,0x17,0x18,0x19,0x1A, + 0x1B,0x1C,0x1D,0x1E,0x1F,0x20,0x21,0x22, + 0x23,0x24,0x25,0x20,0x21,0x22,0x23,0x24, + 0x25,0x26,0x27,0x28,0x29,0x2A,0x2B,0x2C, + 0x2D,0x2E,0x2F,0x30,0x31,0x32,0x33,0x34, + 0x35,0x36,0x37,0x38,0x39,0x3A,0x3B,0x3C, + 0x3D,0x3E,0x3F,0x40,0x41,0x42,0x43,0x44 +) + +_cov_2char = b'./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' + + +def _HPERM_OP(a): + """Clever bit manipulation.""" + t = ((a << 18) ^ a) & 0xcccc0000 + return a ^ t ^ ((t >> 18) & 0x3fff) + +def _PERM_OP(a,b,n,m): + """Cleverer bit manipulation.""" + t = ((a >> n) ^ b) & m + b = b ^ t + a = a ^ (t << n) + return a,b + +ii_struct = struct.Struct('> 16) | ((c >> 4) & 0x0f000000)) + c = c & 0x0fffffff + + # Copy globals into local variables for loop. + shifts2 = _shifts2 + skbc0, skbc1, skbc2, skbc3, skbd0, skbd1, skbd2, skbd3 = _skb + + k = [0] * (_ITERATIONS * 2) + + for i in range(_ITERATIONS): + # Only operates on top 28 bits. + if shifts2[i]: + c = (c >> 2) | (c << 26) + d = (d >> 2) | (d << 26) + else: + c = (c >> 1) | (c << 27) + d = (d >> 1) | (d << 27) + c = c & 0x0fffffff + d = d & 0x0fffffff + + s = ( skbc0[ c & 0x3f ] | + skbc1[((c>> 6) & 0x03) | ((c>> 7) & 0x3c)] | + skbc2[((c>>13) & 0x0f) | ((c>>14) & 0x30)] | + skbc3[((c>>20) & 0x01) | + ((c>>21) & 0x06) | ((c>>22) & 0x38)] ) + + t = ( skbd0[ d & 0x3f ] | + skbd1[((d>> 7) & 0x03) | ((d>> 8) & 0x3c)] | + skbd2[((d>>15) & 0x3f) ] | + skbd3[((d>>21) & 0x0f) | ((d>>22) & 0x30)] ) + + k[2*i] = ((t << 16) | (s & 0x0000ffff)) & 0xffffffff + s = (s >> 16) | (t & 0xffff0000) + + # Top bit of s may be 1. + s = (s << 4) | ((s >> 28) & 0x0f) + k[2*i + 1] = s & 0xffffffff + + return k + + +def _body(ks, E0, E1): + """Use the key schedule ks and salt E0, E1 to create the password hash.""" + + # Copy global variable into locals for loop. + SP0, SP1, SP2, SP3, SP4, SP5, SP6, SP7 = _SPtrans + + inner = range(0, _ITERATIONS*2, 2) + l = r = 0 + for j in range(25): + l,r = r,l + for i in inner: + t = r ^ ((r >> 16) & 0xffff) + u = t & E0 + t = t & E1 + u = u ^ (u << 16) ^ r ^ ks[i] + t = t ^ (t << 16) ^ r ^ ks[i+1] + t = ((t >> 4) & 0x0fffffff) | (t << 28) + + l,r = r,(SP1[(t ) & 0x3f] ^ SP3[(t>> 8) & 0x3f] ^ + SP5[(t>>16) & 0x3f] ^ SP7[(t>>24) & 0x3f] ^ + SP0[(u ) & 0x3f] ^ SP2[(u>> 8) & 0x3f] ^ + SP4[(u>>16) & 0x3f] ^ SP6[(u>>24) & 0x3f] ^ l) + + l = ((l >> 1) & 0x7fffffff) | ((l & 0x1) << 31) + r = ((r >> 1) & 0x7fffffff) | ((r & 0x1) << 31) + + r,l = _PERM_OP(r, l, 1, 0x55555555) + l,r = _PERM_OP(l, r, 8, 0x00ff00ff) + r,l = _PERM_OP(r, l, 2, 0x33333333) + l,r = _PERM_OP(l, r, 16, 0x0000ffff) + r,l = _PERM_OP(r, l, 4, 0x0f0f0f0f) + + return l,r + + +def crypt(password, salt): + """Generate an encrypted hash from the passed password. If the password +is longer than eight characters, only the first eight will be used. 
+ +The first two characters of the salt are used to modify the encryption +algorithm used to generate the hash in one of 4096 different ways. +The characters for the salt should be upper- and lower-case letters A +to Z, digits 0 to 9, '.' and '/'. + +The returned hash begins with the two characters of the salt, and +should be passed as the salt to verify the password. + +Example: + + >>> from fcrypt import crypt + >>> password = b'AlOtBsOl' + >>> salt = b'cE' + >>> hash = crypt(password, salt) + >>> hash + b'cEpWz5IUCShqM' + >>> crypt(password, hash) == hash + True + >>> crypt(b'IaLaIoK', hash) == hash + False + +In practice, you would read the password using something like the +getpass module, and generate the salt randomly: + + >>> import random, string + >>> saltchars = string.ascii_letters + string.digits + './' + >>> salt = (random.choice(saltchars) + random.choice(saltchars)).encode('ascii') + +Note that other ASCII characters are accepted in the salt, but the +results may not be the same as other versions of crypt. In +particular, '_', '$1' and '$2' do not select alternative hash +algorithms such as the extended passwords, MD5 crypt and Blowfish +crypt supported by the OpenBSD C library. +""" + + # Extract the salt. + if len(salt) == 0: + salt = b'AA' + elif len(salt) == 1: + salt = salt + b'A' + Eswap0 = _con_salt[salt[0] & 0x7f] + Eswap1 = _con_salt[salt[1] & 0x7f] << 4 + + # Generate the key and use it to apply the encryption. + ks = _set_key((password + b'\x00\x00\x00\x00\x00\x00\x00\x00')[:8]) + o1, o2 = _body(ks, Eswap0, Eswap1) + + # Extract 24-bit subsets of result with bytes reversed. + t1 = (o1 << 16 & 0xff0000) | (o1 & 0xff00) | (o1 >> 16 & 0xff) + t2 = (o1 >> 8 & 0xff0000) | (o2 << 8 & 0xff00) | (o2 >> 8 & 0xff) + t3 = (o2 & 0xff0000) | (o2 >> 16 & 0xff00) + # Extract 6-bit subsets. + r = [ t1 >> 18 & 0x3f, t1 >> 12 & 0x3f, t1 >> 6 & 0x3f, t1 & 0x3f, + t2 >> 18 & 0x3f, t2 >> 12 & 0x3f, t2 >> 6 & 0x3f, t2 & 0x3f, + t3 >> 18 & 0x3f, t3 >> 12 & 0x3f, t3 >> 6 & 0x3f ] + # Convert to characters.
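+	# Each 6-bit value indexes the 64-character alphabet in _cov_2char
+	# ('.', '/', digits, upper- and lower-case letters); prefixed with the
+	# two salt bytes, this yields the traditional 13-character crypt result.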
+ for i in range(len(r)): + r[i] = _cov_2char[r[i]:r[i]+1] + return salt[:2] + b''.join(r) + +def _test(): + """Run doctest on fcrypt module.""" + import doctest, fcrypt + return doctest.testmod(fcrypt) + +if __name__ == '__main__': + _test() diff --git a/py_opengauss/resolved/opengauss.py b/py_opengauss/resolved/opengauss.py new file mode 100644 index 0000000000000000000000000000000000000000..87f72cd5fbd2dd799a576e77bd7dd90ca02ea513 --- /dev/null +++ b/py_opengauss/resolved/opengauss.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- + +import hmac +from hashlib import md5, pbkdf2_hmac, sha256 +from ..python.structlib import ulong_unpack + +PLAIN_PASSWORD = 0 +MD5_PASSWORD = 1 +SHA256_PASSWORD = 2 + + +def sha256_pw(user, password, salt): + password_stored_method, salt = ulong_unpack(salt[:4]), salt[4:] + if password_stored_method in (PLAIN_PASSWORD, SHA256_PASSWORD): + random64_code_str, salt = salt[:64].decode(), salt[64:] + token_str, salt = salt[:8].decode(), salt[8:] + iteration = ulong_unpack(salt[:4]) + return rfc5802_algorithm(password, random64_code_str, token_str, "", iteration) + elif password_stored_method == 1: + # MD5 + pw = md5(password + user).hexdigest().encode('ascii') + return b'md5' + md5(pw + salt[:4]).hexdigest().encode('ascii') + else: + raise Exception("pq: the password-stored method is not supported, must be plain, md5 or sha256.") + + +def rfc5802_algorithm(password, random64_code_str, token_str, server_signature, server_iteration): + k = generate_k_from_pbkdf2(password, random64_code_str, server_iteration) + server_key = get_key_from_hmac(k, b'Sever Key') + client_key = get_key_from_hmac(k, b'Client Key') + stored_key = get_sha256(client_key) + token_bytes = hex_string_to_bytes(token_str) + client_signature = get_key_from_hmac(server_key, token_bytes) + if server_signature != "" and server_signature != bytes_to_hex_string(client_signature): + return b'' + hmac_result = get_key_from_hmac(stored_key, token_bytes) + h = XOR_between_password(hmac_result, client_key, len(client_key)) + return bytes_to_hex(h) + + +def XOR_between_password(password1, password2, length): + arr = bytearray() + for i in range(length): + arr.append(password1[i] ^ password2[i]) + return bytes(arr) + + +def bytes_to_hex(bytes_array): + lookup = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'] + res = "" + for i in range(len(bytes_array)): + c = bytes_array[i] & 0xFF + j = c >> 4 + res += (lookup[j]) + j = (c & 0xF) + res += (lookup[j]) + return res.encode('ascii') + + +def get_key_from_hmac(k, key): + h = hmac.new(k, digestmod=sha256) + h.update(key) + return h.digest() + + +def get_sha256(key): + h = sha256() + h.update(key) + return h.digest() + + +def generate_k_from_pbkdf2(password, random64_code_str, iterations): + random32_code = hex_string_to_bytes(random64_code_str) + return pbkdf2_hmac('sha1', password, random32_code, iterations, dklen=32) + + +def hex_string_to_bytes(s): + if not s: + return b'' + + arr = bytearray() + s = s.upper() + bytes_len = int(len(s) / 2) + for i in range(bytes_len): + pos = i * 2 + arr.append(ctoi(s[pos]) << 4 | ctoi(s[pos + 1])) + return bytes(arr) + + +def bytes_to_hex_string(bs): + s = "" + for b in bs: + v = b & 0xFF + hv = "%x" % v + if len(hv) < 2: + s += hv + s += "0" + else: + s += hv + return s + + +def ctoi(c): + return "0123456789ABCDEF".index(c) diff --git a/py_opengauss/resolved/riparse.py b/py_opengauss/resolved/riparse.py new file mode 100644 index 
0000000000000000000000000000000000000000..f91a261897e35fd5437366565357ddcb7eacc7b5 --- /dev/null +++ b/py_opengauss/resolved/riparse.py @@ -0,0 +1,390 @@ +""" +Split, unsplit, parse, serialize, construct and structure resource indicators. + +Resource indicators take the form:: + + [scheme:[//]][user[:pass]@]host[:port][/[path[/path]*][?param-n1=value[&param-n=value-n]*][#fragment]] + +It might be a URL, URI, or IRI. It tries not to care. +Notably, it only percent-encodes chr(0-33) as some RIs support character values +greater than 127. Usually, it's best to make a second pass on the string in +order to target a specific format, URI or IRI. + +If a specific format is being targeted, URL, URI, or the URI representation of an +IRI, a second pass *must* be made on the string. +# Future versions may include subsequent transformation routines for targeting. + +Overview +-------- + +Where ``x`` is a text RI (i.e., ``http://foo.com/path``):: + + unsplit(split(x)) == x + serialize(parse(x)) == x + parse(x) == structure(split(x)) + construct(parse(x)) == split(x) + + +Substructure +------------ + +In some cases, an RI may have additional structure that needs to be extracted. +To do this, the ``fieldproc`` keyword is used on `split_netloc`, `structure`, +and `parse` functions. + +The ``fieldproc`` keyword is a callable that takes a single argument and returns +the processed field. By default, ``fieldproc`` is the `unescape` function which +will decode percent escapes. This is not desirable when substructure exists +within an RI's component as it can create ambiguity about a token when a +percent encoded variant is decoded. +""" +import re + +pct_encode = '%%%0.2X'.__mod__ +unescaped = '%' + ''.join([chr(x) for x in range(0, 33)]) + +percent_escapes_re = re.compile('(%[0-9a-fA-F]{2,2})+') +escape_re = re.compile('[%s]' %(re.escape(unescaped),)) +escape_user_re = re.compile('[%s]' %(re.escape(unescaped + ':@/?#'),)) +escape_password_re = re.compile('[%s]' %(re.escape(unescaped + '@/?#'),)) +escape_host_re = re.compile('[%s]' %(re.escape(unescaped + '/?#'),)) +escape_port_re = re.compile('[%s]' %(re.escape(unescaped + '/?#'),)) +escape_path_re = re.compile('[%s]' %(re.escape(unescaped + '/?#'),)) +escape_query_key_re = re.compile('[%s]' %(re.escape(unescaped + '&=#'),)) +escape_query_value_re = re.compile('[%s]' %(re.escape(unescaped + '&#'),)) + +percent_escapes = {} +for x in range(256): + k = '%0.2X'.__mod__(x) + percent_escapes[k] = x + percent_escapes[k.lower()] = x + percent_escapes[k[0].lower() + k[1]] = x + percent_escapes[k[0] + k[1].lower()] = x + +scheme_chars = '-.+0123456789' +del x + +def unescape(x, mkval = chr): + """ + Substitute percent escapes with literal characters.
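+
+	For example, ``unescape('a%2Fb')`` gives ``'a/b'`` and ``unescape('%41%42')``
+	gives ``'AB'``.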
+ """ + nstr = type(x)('') + if isinstance(x, str): + mkval = chr + pos = 0 + end = len(x) + while pos != end: + newpos = x.find('%', pos) + if newpos == -1: + nstr += x[pos:] + break + else: + nstr += x[pos:newpos] + + val = percent_escapes.get(x[newpos+1:newpos+3]) + if val is not None: + nstr += mkval(val) + pos = newpos + 3 + else: + nstr += '%' + pos = newpos + 1 + return nstr + +def re_pct_encode(m): + return pct_encode(ord(m.group(0))) + +indexes = { + 'scheme' : 0, + 'netloc' : 1, + 'path' : 2, + 'query' : 3, + 'fragment' : 4 +} + +def split(s): + """ + Split an IRI into its base components based on the markers:: + + ://, /, ?, # + + Return a 5-tuple: (scheme, netloc, path, query, fragment) + """ + scheme = None + netloc = None + path = None + query = None + fragment = None + + end = len(s) + pos = 0 + + # Non-iauthority RI's should be special cased by the user. + scheme_pos = s.find('://') + if scheme_pos != -1: + pos = scheme_pos + 3 + scheme = s[:scheme_pos] + for x in scheme: + if not (x in scheme_chars) and \ + not ('A' <= x <= 'Z') and not ('a' <= x <= 'z'): + pos = 0 + scheme = None + break + + end_of_netloc = end + + path_pos = s.find('/', pos) + if path_pos == -1: + path_pos = None + else: + end_of_netloc = path_pos + + query_pos = s.find('?', pos) + if query_pos == -1: + query_pos = None + elif path_pos is None or query_pos < path_pos: + path_pos = None + end_of_netloc = query_pos + + fragment_pos = s.find('#', pos) + if fragment_pos == -1: + fragment_pos = None + else: + if query_pos is not None and fragment_pos < query_pos: + query_pos = None + if path_pos is not None and fragment_pos < path_pos: + path_pos = None + end_of_netloc = fragment_pos + if query_pos is None and path_pos is None: + end_of_netloc = fragment_pos + + if end_of_netloc != pos: + netloc = s[pos:end_of_netloc] + + if path_pos is not None: + path = s[path_pos+1:query_pos or fragment_pos or end] + + if query_pos is not None: + query = s[query_pos+1:fragment_pos or end] + + if fragment_pos is not None: + fragment = s[fragment_pos+1:end] + + return (scheme, netloc, path, query, fragment) + +def unsplit_path(p, re = escape_path_re): + """ + Join a list of paths(strings) on "/" *after* escaping them. + """ + if not p: + return None + return '/'.join([re.sub(re_pct_encode, x) for x in p]) + +def split_path(p, fieldproc = unescape): + """ + Return a list of unescaped strings split on "/". + + Set `fieldproc` to `str` if the components' percent escapes should not be + decoded. + """ + if p is None: + return [] + return [fieldproc(x) for x in p.split('/')] + +def unsplit(t): + """ + Make a RI from a split RI(5-tuple). + """ + s = '' + if t[0] is not None: + s += t[0] + s += '://' + if t[1] is not None: + s += t[1] + if t[2] is not None: + s += '/' + s += t[2] + if t[3] is not None: + s += '?' + s += t[3] + if t[4] is not None: + s += '#' + s += t[4] + return s + +def split_netloc(netloc, fieldproc = unescape): + """ + Split a net location into a 4-tuple, (user, password, host, port). + + Set `fieldproc` to `str` if the components' percent escapes should not be + decoded. 
+ """ + pos = netloc.find('@') + if pos == -1: + # No user information + pos = 0 + user = None + password = None + else: + s = netloc[:pos] + userpw = s.split(':', 1) + if len(userpw) == 2: + user, password = userpw + user = fieldproc(user) + password = fieldproc(password) + else: + user = fieldproc(userpw[0]) + password = None + pos += 1 + + if pos >= len(netloc): + return (user, password, None, None) + + pos_chr = netloc[pos] + if pos_chr == '[': + # IPvN addr + next_pos = netloc.find(']', pos) + if next_pos == -1: + # unterminated IPvN block + next_pos = len(netloc) - 1 + addr = netloc[pos:next_pos+1] + pos = next_pos + 1 + next_pos = netloc.find(':', pos) + if next_pos == -1: + port = None + else: + port = fieldproc(netloc[next_pos+1:]) + else: + next_pos = netloc.find(':', pos) + if next_pos == -1: + addr = fieldproc(netloc[pos:]) + port = None + else: + addr = fieldproc(netloc[pos:next_pos]) + port = fieldproc(netloc[next_pos+1:]) + + return (user, password, addr, port) + +def unsplit_netloc(t): + """ + Create a netloc fragment from the given tuple(user,password,host,port). + """ + if t[0] is None and t[2] is None: + return None + s = '' + if t[0] is not None: + s += escape_user_re.sub(re_pct_encode, t[0]) + if t[1] is not None: + s += ':' + s += escape_password_re.sub(re_pct_encode, t[1]) + s += '@' + + if t[2] is not None: + s += escape_host_re.sub(re_pct_encode, t[2]) + if t[3] is not None: + s += ':' + s += escape_port_re.sub(re_pct_encode, t[3]) + + return s + +def structure(t, fieldproc = unescape): + """ + Create a dictionary from a split RI(5-tuple). + + Set `fieldproc` to `str` if the components' percent escapes should not be + decoded. + """ + d = {} + + if t[0] is not None: + d['scheme'] = t[0] + + if t[1] is not None: + uphp = split_netloc(t[1], fieldproc = fieldproc) + if uphp[0] is not None: + d['user'] = uphp[0] + if uphp[1] is not None: + d['password'] = uphp[1] + if uphp[2] is not None: + d['host'] = uphp[2] + if uphp[3] is not None: + d['port'] = uphp[3] + + if t[2] is not None: + if t[2]: + d['path'] = list(map(fieldproc, t[2].split('/'))) + else: + d['path'] = [] + + if t[3] is not None: + if t[3]: + d['query'] = [tuple((list(map(fieldproc, x.split('=', 1))) + [None])[:2]) for x in t[3].split('&')] + else: + # no characters followed the '?' + d['query'] = [] + + if t[4] is not None: + d['fragment'] = fieldproc(t[4]) + return d + +def construct_query(x, + key_re = escape_query_key_re, + value_re = escape_query_value_re, +): + 'Given a sequence of (key, value) pairs, construct' + return '&'.join([ + v is not None and \ + '%s=%s' %( + key_re.sub(re_pct_encode, k), + value_re.sub(re_pct_encode, v), + ) or \ + key_re.sub(re_pct_encode, k) + for k, v in x + ]) + +def construct(x): + """ + Construct a RI tuple(5-tuple) from a dictionary object. + """ + p = x.get('path') + if p is not None: + p = '/'.join([escape_path_re.sub(re_pct_encode, y) for y in p]) + q = x.get('query') + if q is not None: + q = construct_query(q) + f = x.get('fragment') + if f is not None: + f = escape_re.sub(re_pct_encode, f) + + u = x.get('user') + pw = x.get('password') + h = x.get('host') + port = x.get('port') + + return ( + x.get('scheme'), + # netloc: [user[:pass]@]host[:port] + unsplit_netloc(( + x.get('user'), + x.get('password'), + x.get('host'), + x.get('port'), + )), + p, q, f + ) + +def parse(s, fieldproc = unescape): + """ + Parse an RI into a dictionary object. Synonym for ``structure(split(x))``. 
+ + Set `fieldproc` to `str` if the components' percent escapes should not be + decoded. + """ + return structure(split(s), fieldproc = fieldproc) + +def serialize(x): + """ + Return an RI from a dictionary object. Synonym for ``unsplit(construct(x))``. + """ + return unsplit(construct(x)) + +__docformat__ = 'reStructuredText' diff --git a/py_opengauss/string.py b/py_opengauss/string.py new file mode 100644 index 0000000000000000000000000000000000000000..53799d3797d8747fa9d91603b17ab5b77a516f55 --- /dev/null +++ b/py_opengauss/string.py @@ -0,0 +1,270 @@ +## +# .string +## +""" +String split and join operations for dealing with literals and identifiers. + +Notably, the functions in this module are intended to be used for simple +use-cases. It attempts to stay away from "real" parsing and simply provides +functions for common needs, like the ability to identify unquoted portions of a +query string so that logic or transformations can be applied to only unquoted +portions. Scanning for statement terminators, or safely interpolating +identifiers. + +All functions deal with strict quoting rules. +""" +import re + +def escape_literal(text): + """ + Replace every instance of ' with ''. + """ + return text.replace("'", "''") + +def quote_literal(text): + """ + Escape the literal and wrap it in [single] quotations. + """ + return "'" + text.replace("'", "''") + "'" + +def escape_ident(text): + """ + Replace every instance of " with "". + """ + return text.replace('"', '""') + +def needs_quoting(text): + return not (text and not text[0].isdecimal() and text.replace('_', 'a').isalnum()) + +def quote_ident(text): + """ + Replace every instance of '"' with '""' *and* place '"' on each end. + """ + return '"' + text.replace('"', '""') + '"' + +def quote_ident_if_needed(text): + """ + If needed, replace every instance of '"' with '""' *and* place '"' on each end. + Otherwise, just return the text. + """ + return quote_ident(text) if needs_quoting(text) else text + +quote_re = re.compile(r"""(?xu) + E'(?:''|\\.|[^'])*(?:'|$) (?# Backslash escapes E'str') +| '(?:''|[^'])*(?:'|$) (?# Regular literals 'str') +| "(?:""|[^"])*(?:"|$) (?# Identifiers "str") +| (\$(?:[^0-9$]\w*)?\$).*?(?:\1|$) (?# Dollar quotes $$str$$) +""") + +def split(text): + """ + split the string up by into non-quoted and quoted portions. Zero and even + numbered indexes are unquoted portions, while odd indexes are quoted + portions. + + Unquoted portions are regular strings, whereas quoted portions are + pair-tuples specifying the quotation mechanism and the content thereof. + + >>> list(split("select $$foobar$$")) + ['select ', ('$$', 'foobar'), ''] + + If the split ends on a quoted section, it means the string's quote was not + terminated. Subsequently, there will be an even number of objects in the + list. + + Quotation errors are detected, but never raised. Rather it's up to the user + to identify the best course of action for the given split. + """ + lastend = 0 + re = quote_re + scan = re.scanner(text) + match = scan.search() + while match is not None: + # text preceding the quotation + yield text[lastend:match.start()] + # the dollar quote, if any + dq = match.groups()[0] + if dq is not None: + endoff = len(dq) + quote = dq + end = quote + else: + endoff = 1 + q = text[match.start()] + if q == 'E': + quote = "E'" + end = "'" + else: + end = quote = q + + # If the end is not the expected quote, it consumed + # the end. 
Be sure to check that the match's end - end offset + # is *not* the start, ie an empty quotation at the end of the string. + if text[match.end()-endoff:match.end()] != end \ + or match.end() - endoff == match.start(): + yield (quote, text[match.start()+len(quote):]) + break + else: + yield (quote, text[match.start()+len(quote):match.end()-endoff]) + + lastend = match.end() + match = scan.search() + else: + # balanced quotes, yield the rest + yield text[lastend:] + +def unsplit(splitted_iter): + """ + catenate a split string. This is needed to handle the special + cases created by pg.string.split(). (Run-away quotations, primarily) + """ + s = '' + quoted = False + i = iter(splitted_iter) + endq = '' + for x in i: + s += endq + x + try: + q, qtext = next(i) + s += q + qtext + if q == "E'": + endq = "'" + else: + endq = q + except StopIteration: + break + return s + +def split_using(text, quote, sep = '.', maxsplit = -1): + """ + split the string on the seperator ignoring the separator in quoted areas. + + This is only useful for simple quoted strings. Dollar quotes, and backslash + escapes are not supported. + """ + escape = quote * 2 + esclen = len(escape) + offset = 0 + tl = len(text) + end = tl + # Fast path: No quotes? Do a simple split. + if quote not in text: + return text.split(sep, maxsplit) + l = [] + + while len(l) != maxsplit: + # Look for the separator first + nextsep = text.find(sep, offset) + if nextsep == -1: + # it's over. there are no more seps + break + else: + # There's a sep ahead, but is there a quoted section before it? + nextquote = text.find(quote, offset, nextsep) + while nextquote != -1: + # Yep, there's a quote before the sep; + # need to eat the escaped portion. + nextquote = text.find(quote, nextquote + 1,) + while nextquote != -1: + if text.find(escape, nextquote, nextquote+esclen) != nextquote: + # Not an escape, so it's the end. + break + # Look for another quote past the escape quote. + nextquote = text.find(quote, nextquote + 2) + else: + # the sep was located in the escape, and + # the escape consumed the rest of the string. + nextsep = -1 + break + + nextsep = text.find(sep, nextquote + 1) + if nextsep == -1: + # it's over. there are no more seps + # [likely they were consumed by the escape] + break + nextquote = text.find(quote, nextquote + 1, nextsep) + if nextsep == -1: + break + + l.append(text[offset:nextsep]) + offset = nextsep + 1 + l.append(text[offset:]) + return l + +def split_ident(text, sep = ',', quote = '"', maxsplit = -1): + """ + Split a series of identifiers using the specified separator. + """ + nr = [] + for x in split_using(text, quote, sep = sep, maxsplit = maxsplit): + x = x.strip() + if x.startswith('"'): + if not x.endswith('"'): + raise ValueError( + "unterminated identifier quotation", x + ) + else: + nr.append(x[1:-1].replace('""', '"')) + elif needs_quoting(x): + raise ValueError( + "non-ident characters in unquoted identifier", x + ) + else: + # postgres implies a lower, so to stay consistent + # with it on qname joins, lower the unquoted identifier now. + nr.append(x.lower()) + return nr + +def split_qname(text, maxsplit = -1): + """ + Call to .split_ident() with a '.' sep parameter. + """ + return split_ident(text, maxsplit = maxsplit, sep = '.') + +def qname(*args): + """ + Quote the identifiers and join them using '.'. 
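+
+	For example, ``qname('public', 'my table')`` returns ``'"public"."my table"'``.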
+ """ + return '.'.join([quote_ident(x) for x in args]) + +def qname_if_needed(*args): + return '.'.join([quote_ident_if_needed(x) for x in args]) + +def split_sql(sql, sep = ';'): + """ + Given SQL, safely split using the given separator. + Notably, this yields fully split text. This should be used instead of + split_sql_str() when quoted sections need be still be isolated. + + >>> list(split_sql('select $$1$$ AS "foo;"; select 2;')) + [['select ', ('$$', '1'), ' AS ', ('"', 'foo;'), ''], (' select 2',), ['']] + """ + i = iter(split(sql)) + cur = [] + for part in i: + sections = part.split(sep) + + if len(sections) < 2: + cur.append(part) + else: + cur.append(sections[0]) + yield cur + for x in sections[1:-1]: + yield (x,) + cur = [sections[-1]] + try: + cur.append(next(i)) + except StopIteration: + break + if cur: + yield cur + +def split_sql_str(sql, sep = ';'): + """ + Identical to split_sql but yields unsplit text. + + >>> list(split_sql_str('select $$1$$ AS "foo;"; select 2;')) + ['select $$1$$ AS "foo;"', ' select 2', ''] + """ + for x in split_sql(sql, sep = sep): + yield unsplit(x) diff --git a/py_opengauss/sys.py b/py_opengauss/sys.py new file mode 100644 index 0000000000000000000000000000000000000000..a8471317e6229ce19810d673767d039c386500dd --- /dev/null +++ b/py_opengauss/sys.py @@ -0,0 +1,99 @@ +## +# .sys +## +""" +py-postgresql system functions and data. + +Data +---- + + ``libpath`` + The local file system paths that contain query libraries. + +Overridable Functions +--------------------- + + excformat + Information that makes up an exception's displayed "body". + Effectively, the implementation of `postgresql.exception.Error.__str__` + + msghook + Display a message. +""" +import sys +import os +import traceback +from .python.element import format_element +from .python.string import indent + +libpath = [] + +def default_errformat(val): + """ + Built-in error formatter. Do not change. + """ + it = val._e_metas() + if val.creator is not None: + # Protect against element traceback failures. + try: + after = os.linesep + format_element(val.creator) + except Exception: + after = 'Element Traceback of %r caused exception:%s' %( + type(val.creator).__name__, + os.linesep + ) + after += indent(traceback.format_exc()) + after = os.linesep + indent(after).rstrip() + else: + after = '' + return next(it)[1] \ + + os.linesep + ' ' \ + + (os.linesep + ' ').join( + k + ': ' + v for k, v in it + ) + after + +def default_msghook(msg, format_message = format_element): + """ + Built-in message hook. DON'T TOUCH! + """ + if sys.stderr and not sys.stderr.closed: + try: + sys.stderr.write(format_message(msg) + os.linesep) + except Exception: + try: + sys.excepthook(*sys.exc_info()) + except Exception: + # gasp. + pass + +def errformat(*args, **kw): + """ + Raised Database Error formatter pointing to default_excformat. + + Override if you like. All postgresql.exceptions.Error's are formatted using + this function. + """ + return default_errformat(*args, **kw) + +def msghook(*args, **kw): + """ + Message hook pointing to default_msghook. + + Override if you like. All untrapped messages raised by + driver connections come here to be printed to stderr. + """ + return default_msghook(*args, **kw) + +def reset_errformat(with_func = errformat): + """ + Restore the original excformat function. + """ + global errformat + errformat = with_func + +def reset_msghook(with_func = msghook): + """ + Restore the original msghook function. 
+ """ + global msghook + msghook = with_func diff --git a/py_opengauss/temporal.py b/py_opengauss/temporal.py new file mode 100644 index 0000000000000000000000000000000000000000..d5b891db95ff849273b1175e554a37a20ccecaab --- /dev/null +++ b/py_opengauss/temporal.py @@ -0,0 +1,275 @@ +## +# .temporal - manage the temporary cluster +## +""" +Temporary PostgreSQL cluster for the process. +""" +import os +import atexit +import builtins +from collections import deque +from .cluster import Cluster, ClusterError +from . import installation +from .python.socket import find_available_port + +class Temporal(object): + """ + Manages a temporary cluster for the duration of the process. + + Instances of this class reference a distinct cluster. These clusters are + transient; they will only exist until the process exits. + + Usage:: + + >>> from py_opengauss.temporal import pg_tmp + >>> with pg_tmp: + ... ps = db.prepare('SELECT 1') + ... assert ps.first() == 1 + + Or `pg_tmp` can decorate a method or function. + """ + + format_sandbox_id = staticmethod(('sandbox{0}_{1}').format) + cluster_dirname = staticmethod(('pg_tmp_{0}_{1}').format) + cluster = None + + _init_pid_ = None + _local_id_ = 0 + builtins_keys = { + 'connector', + 'db', + 'do', + 'xact', + 'proc', + 'settings', + 'prepare', + 'sqlexec', + 'newdb', + } + + def __init__(self): + self.builtins_stack = deque() + self.sandbox_id = 0 + # identifier for keeping temporary instances unique. + self.__class__._local_id_ = self.local_id = (self.__class__._local_id_ + 1) + + def __call__(self, callable): + def in_pg_temporal_context(*args, **kw): + with self: + return callable(*args, **kw) + n = getattr(callable, '__name__', None) + if n: + in_pg_temporal_context.__name__ = n + return in_pg_temporal_context + + def destroy(self): + # Don't destroy if it's not the initializing process. + if os.getpid() == self._init_pid_: + # Kill all the open connections. + try: + c = cluster.connection(user = 'test', database = 'template1',) + with c: + if c.version_info[:2] <= (9,1): + c.sys.terminate_backends() + else: + c.sys.terminate_backends_92() + except Exception: + # Doesn't matter much if it fails. + pass + cluster = self.cluster + self.cluster = None + self._init_pid_ = None + if cluster is not None: + cluster.stop() + cluster.wait_until_stopped(timeout = 5) + cluster.drop() + + def init(self, + installation_factory = installation.default, + inshint = { + 'hint' : "Try setting the PGINSTALLATION " \ + "environment variable to the `pg_config` path" + } + ): + if self.cluster is not None or 'PGTEST' in os.environ: + return + ## + # Hasn't been created yet, but doesn't matter. + # On exit, obliterate the cluster directory. + self._init_pid_ = os.getpid() + atexit.register(self.destroy) + + # [$HOME|.]/.pg_tmpdb_{pid} + self.cluster_path = os.path.join( + os.environ.get('HOME', os.getcwd()), + self.cluster_dirname(self._init_pid_, self.local_id) + ) + self.logfile = os.path.join(self.cluster_path, 'logfile') + installation = installation_factory() + if installation is None: + raise ClusterError( + 'could not find the default pg_config', details = inshint + ) + + vi = installation.version_info + cluster = Cluster(installation, self.cluster_path,) + + # If it exists already, destroy it. + if cluster.initialized(): + cluster.drop() + cluster.encoding = 'utf-8' + cluster.init( + user = 'test', # Consistent username. + encoding = cluster.encoding, + logfile = None, + ) + + try: + self.cluster_port = find_available_port() + except: + # Rely on chain. 
+ raise ClusterError( + 'could not find a port for the test cluster on localhost', + creator = cluster + ) + + if vi[:2] > (9,6): + # Default changed in 10.x + cluster.settings['max_wal_senders'] = '0' + + cluster.settings.update(dict( + port = str(self.cluster_port), + max_connections = '20', + shared_buffers = '200', + listen_addresses = 'localhost', + log_destination = 'stderr', + log_min_messages = 'FATAL', + max_prepared_transactions = '10', + )) + + if installation.version_info[:2] < (9, 3): + cluster.settings.update(dict( + unix_socket_directory = cluster.data_directory, + )) + else: + cluster.settings.update(dict( + unix_socket_directories = cluster.data_directory, + )) + + # Start the database cluster. + with open(self.logfile, 'w') as lfo: + cluster.start(logfile = lfo) + cluster.wait_until_started() + + # Initialize template1 and the test user database. + c = cluster.connection(user = 'test', database = 'template1',) + with c: + c.execute('create database test') + self.cluster = cluster + + def push(self): + if 'PGTEST' in os.environ: + from . import open as pg_open + c = pg_open(os.environ['PGTEST']) # Ignoring PGINSTALLATION. + else: + c = self.cluster.connection(user = 'test') + c.connect() + + extras = [] + sbid = self.format_sandbox_id(os.getpid(), self.sandbox_id + 1) + + def new_pg_tmp_connection(l = extras, clone = c.clone, sbid = sbid): + # Used to create a new connection that will be closed + # when the context stack is popped along with 'db'. + l.append(clone()) + l[-1].settings['search_path'] = str(sbid) + ',' + l[-1].settings['search_path'] + return l[-1] + + # The new builtins. + local_builtins = { + 'db' : c, + 'prepare' : c.prepare, + 'xact' : c.xact, + 'sqlexec' : c.execute, + 'do' : c.do, + 'settings' : c.settings, + 'proc' : c.proc, + 'connector' : c.connector, + 'new' : new_pg_tmp_connection, + } + if not self.builtins_stack: + # Store any of those set or not set. + current = { + k : builtins.__dict__[k] for k in self.builtins_keys + if k in builtins.__dict__ + } + self.builtins_stack.append((current, [])) + + # Store and push. + self.builtins_stack.append((local_builtins, extras)) + builtins.__dict__.update(local_builtins) + self.sandbox_id += 1 + + def pop(self, exc, drop_schema = ('DROP SCHEMA {0} CASCADE').format): + local_builtins, extras = self.builtins_stack.pop() + self.sandbox_id -= 1 + + # restore __builtins__ + if len(self.builtins_stack) > 1: + builtins.__dict__.update(self.builtins_stack[-1][0]) + else: + previous = self.builtins_stack.popleft() + for x in self.builtins_keys: + if x in previous: + builtins.__dict__[x] = previous[x] + else: + # Wasn't set before. + builtins.__dict__.pop(x, None) + + # close popped connection, but only if we're not in an interrupt. + # However, temporal will always terminate all backends atexit. + if exc is None or isinstance(exc, Exception): + # Interrupt then close. Just in case something is lingering. + for xdb in [local_builtins['db']] + list(extras): + if xdb.closed is False: + # In order for a clean close of the connection, + # interrupt before closing. It is still + # possible for the close to block, but less likely. + xdb.interrupt() + xdb.close() + + # Interrupted and closed all the other connections at this level; + # now remove the sandbox schema. + xdb = local_builtins['db'] + with xdb.clone() as c: + # Use a new connection so that the state of + # the context connection will not have to be + # contended with. 
+ c.execute(drop_schema(self.format_sandbox_id(os.getpid(), self.sandbox_id + 1))) + else: + # interrupt exception; avoid waiting for close + pass + + def _init_c(self, cxn): + cxn.connect() + sb = self.format_sandbox_id(os.getpid(), self.sandbox_id) + cxn.execute('CREATE SCHEMA ' + sb) + cxn.settings['search_path'] = ','.join((sb, cxn.settings['search_path'])) + + def __enter__(self): + if self.cluster is None: + self.init() + + self.push() + try: + self._init_c(builtins.db) + except Exception as e: + # failed to initialize sandbox schema; pop it. + self.pop(e) + raise + + def __exit__(self, exc, val, tb): + self.pop(val) + +#: The process' temporary cluster or connection source. +pg_tmp = Temporal() diff --git a/py_opengauss/test/__init__.py b/py_opengauss/test/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/py_opengauss/test/cursor_integrity.py b/py_opengauss/test/cursor_integrity.py new file mode 100644 index 0000000000000000000000000000000000000000..14a07acf96d0cba409d6a333bbf65367b937f39c --- /dev/null +++ b/py_opengauss/test/cursor_integrity.py @@ -0,0 +1,116 @@ +## +# .test.cursor_integrity +## +import os +import unittest +import random +import itertools + +iot = '_dst' + +getq = "SELECT i FROM generate_series(0, %d) AS g(i)" +copy = "COPY (%s) TO STDOUT" + +def random_read(curs, remaining_rows): + """ + Read from one of the three methods using a random amount if sized. + - 50% chance of curs.read(random()) + - 40% chance of next() + - 10% chance of read() # no count + """ + if random.random() > 0.5: + rrows = random.randrange(0, remaining_rows) + return curs.read(rrows), rrows + elif random.random() < 0.1: + return curs.read(), -1 + else: + try: + return [next(curs)], 1 + except StopIteration: + return [], 1 + +def random_select_get(limit): + return prepare(getq %(limit - 1,)) + +def random_copy_get(limit): + return prepare(copy %(getq %(limit - 1,),)) + +class test_integrity(unittest.TestCase): + """ + test the integrity of the get and put interfaces on queries + and result handles. 
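+
+	Rows are consumed in randomly sized chunks through read() and next(), and
+	the concatenation of everything read is compared against the generated
+	series.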
+ """ + def test_select(self): + total = 0 + while total < 10000: + limit = random.randrange(500000) + read = 0 + total += limit + p = random_select_get(limit)() + last = ([(-1,)], 1) + completed = [last[0]] + while True: + next = random_read(p, (limit - read) or 10) + thisread = len(next[0]) + read += thisread + completed.append(next[0]) + if thisread: + self.failUnlessEqual( + last[0][-1][0], next[0][0][0] - 1, + "first row(-1) of next failed to match the last row of the previous" + ) + last = next + elif next[1] != 0: + # done + break + self.failUnlessEqual(read, limit) + self.failUnlessEqual(list(range(-1, limit)), [ + x[0] for x in itertools.chain(*completed) + ]) + + def test_insert(self): + pass + + if 'db' in dir(__builtins__) and pg.version_info >= (8,2,0): + def test_copy_out(self): + total = 0 + while total < 10000000: + limit = random.randrange(500000) + read = 0 + total += limit + p = random_copy_get(limit)() + last = ([-1], 1) + completed = [last[0]] + while True: + next = random_read(p, (limit - read) or 10) + next = ([int(x) for x in next[0]], next[1]) + thisread = len(next[0]) + read += thisread + completed.append(next[0]) + if thisread: + self.failUnlessEqual( + last[0][-1], next[0][0] - 1, + "first row(-1) of next failed to match the last row of the previous" + ) + last = next + elif next[1] != 0: + # done + break + self.failUnlessEqual(read, limit) + self.failUnlessEqual( + list(range(-1, limit)), + list(itertools.chain(*completed)) + ) + + def test_copy_in(self): + pass + +def main(): + global copyin, loadin + execute("CREATE TEMP TABLE _dst (i bigint)") + copyin = prepare("COPY _dst FROM STDIN") + loadin = prepare("INSERT INTO _dst VALUES ($1)") + unittest.main() + +if __name__ == '__main__': + main() diff --git a/py_opengauss/test/perf_copy_io.py b/py_opengauss/test/perf_copy_io.py new file mode 100644 index 0000000000000000000000000000000000000000..5029b2f7dea43e4d172551652aa71b6da32d0450 --- /dev/null +++ b/py_opengauss/test/perf_copy_io.py @@ -0,0 +1,76 @@ +## +# test.perf_copy_io - Copy I/O: To and From performance +## +import os, sys, random, time + +if __name__ == '__main__': + with open('/usr/share/dict/words', mode='brU') as wordfile: + Words = wordfile.readlines() +else: + Words = [b'/usr/share/dict/words', b'is', b'read', b'in', b'__main__'] +wordcount = len(Words) +random.seed() + +def getWord(): + "extract a random word from ``Words``" + return Words[random.randrange(0, wordcount)].strip() + +def testSpeed(tuples = 50000 * 3): + sqlexec("CREATE TEMP TABLE _copy " + "(i int, t text, mt text, ts text, ty text, tx text);") + try: + Q = prepare("COPY _copy FROM STDIN") + size = 0 + def incsize(data): + 'count of bytes' + nonlocal size + size += len(data) + return data + sys.stderr.write("preparing data(%d tuples)...\n" %(tuples,)) + + # Use an LC to avoid the Python overhead involved with a GE + data = [incsize(b'\t'.join(( + str(x).encode('ascii'), getWord(), getWord(), + getWord(), getWord(), getWord() + )))+b'\n' for x in range(tuples)] + + sys.stderr.write("starting copy...\n") + start = time.time() + copied_in = Q.load_rows(data) + duration = time.time() - start + sys.stderr.write( + "COPY FROM STDIN Summary,\n " \ + "copied tuples: %d\n " \ + "copied bytes: %d\n " \ + "duration: %f\n " \ + "average tuple size(bytes): %f\n " \ + "average KB per second: %f\n " \ + "average tuples per second: %f\n" %( + tuples, size, duration, + size / tuples, + size / 1024 / duration, + tuples / duration, + ) + ) + Q = prepare("COPY _copy TO STDOUT") + start = time.time() 
+ c = 0 + for rows in Q.chunks(): + c += len(rows) + duration = time.time() - start + sys.stderr.write( + "COPY TO STDOUT Summary,\n " \ + "copied tuples: %d\n " \ + "duration: %f\n " \ + "average KB per second: %f\n " \ + "average tuples per second: %f\n " %( + c, duration, + size / 1024 / duration, + tuples / duration, + ) + ) + finally: + sqlexec("DROP TABLE _copy") + +if __name__ == '__main__': + testSpeed() diff --git a/py_opengauss/test/perf_query_io.py b/py_opengauss/test/perf_query_io.py new file mode 100644 index 0000000000000000000000000000000000000000..e1ca333b9c74cb515392d9fc73b0eb2a429ffca9 --- /dev/null +++ b/py_opengauss/test/perf_query_io.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python +## +# .test.perf_query_io +## +# Statement I/O: Mass insert and select performance +## +import os +import time +import sys +import decimal +import datetime + +def insertSamples(count, insert_records): + recs = [ + ( + -3, 123, 0xfffffea023, + decimal.Decimal("90900023123.40031"), + decimal.Decimal("432.40031"), + 'some_óäæ_thing', 'varying', 'æ', + datetime.datetime(1982, 5, 18, 12, 0, 0, 100232) + ) + for x in range(count) + ] + gen = time.time() + insert_records.load_rows(recs) + fin = time.time() + xacttime = fin - gen + ats = count / xacttime + sys.stderr.write( + "INSERT Summary,\n " \ + "inserted tuples: %d\n " \ + "total time: %f\n " \ + "average tuples per second: %f\n\n" %( + count, xacttime, ats, + ) + ) + +def timeTupleRead(ps): + loops = 0 + tuples = 0 + genesis = time.time() + for x in ps.chunks(): + loops += 1 + tuples += len(x) + finalis = time.time() + looptime = finalis - genesis + ats = tuples / looptime + sys.stderr.write( + "SELECT Summary,\n " \ + "looped: {looped}\n " \ + "looptime: {looptime}\n " \ + "tuples: {ntuples}\n " \ + "average tuples per second: {tps}\n ".format( + looped = loops, + looptime = looptime, + ntuples = tuples, + tps = ats + ) + ) + +def main(count): + sqlexec('CREATE TEMP TABLE samples ' + '(i2 int2, i4 int4, i8 int8, n numeric, n2 numeric, t text, v varchar, c char(2), ts timestamp)') + insert_records = prepare( + "INSERT INTO samples VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + ) + select_records = prepare("SELECT * FROM samples") + try: + insertSamples(count, insert_records) + timeTupleRead(select_records) + finally: + sqlexec("DROP TABLE samples") + +def command(args): + main(int((args + [25000])[1])) + +if __name__ == '__main__': + command(sys.argv) diff --git a/py_opengauss/test/support.py b/py_opengauss/test/support.py new file mode 100644 index 0000000000000000000000000000000000000000..35f16535e190df80afd8eb9e545e4de330f917a6 --- /dev/null +++ b/py_opengauss/test/support.py @@ -0,0 +1,24 @@ +## +# .test.support +## +""" +Executable module used by test_* modules to mimic a command. +""" +import sys + +def pg_config(*args): + data = """FOO=BaR +FEH=YEAH +version=NAY +""" + sys.stdout.write(data) + +if __name__ == '__main__': + if sys.argv[1:]: + cmd = sys.argv[1] + if cmd in globals(): + cmd = globals()[cmd] + cmd(sys.argv[2:]) + sys.exit(0) + sys.stderr.write("no valid entry point referenced") + sys.exit(1) diff --git a/py_opengauss/test/test_alock.py b/py_opengauss/test/test_alock.py new file mode 100644 index 0000000000000000000000000000000000000000..1a8b95b02810284a820d260481a6028f58d0e5ce --- /dev/null +++ b/py_opengauss/test/test_alock.py @@ -0,0 +1,138 @@ +## +# .test.test_alock - test .alock +## +import unittest +import threading +import time +from ..temporal import pg_tmp +from .. 
import alock + +n_alocks = "select count(*) FROM pg_locks WHERE locktype = 'advisory'" + +class test_alock(unittest.TestCase): + @pg_tmp + def testALockWait(self): + # sadly, this is primarily used to exercise the code paths.. + ad = prepare(n_alocks).first + self.assertEqual(ad(), 0) + state = [False, False, False] + alt = new() + first = alock.ExclusiveLock(db, (0,0)) + second = alock.ExclusiveLock(db, 1) + def concurrent_lock(): + try: + with alock.ExclusiveLock(alt, 1): + with alock.ExclusiveLock(alt, (0,0)): + # start it + state[0] = True + while not state[1]: + pass + time.sleep(0.01) + while not state[2]: + time.sleep(0.01) + except Exception: + # Avoid dead lock in cases where advisory is not available. + state[0] = state[1] = state[2] = True + t = threading.Thread(target = concurrent_lock) + t.start() + while not state[0]: + time.sleep(0.01) + self.assertEqual(ad(), 2) + state[1] = True + with first: + self.assertEqual(ad(), 2) + state[2] = True + with second: + self.assertEqual(ad(), 2) + t.join(timeout = 1) + + @pg_tmp + def testALockNoWait(self): + alt = new() + ad = prepare(n_alocks).first + self.assertEqual(ad(), 0) + with alock.ExclusiveLock(db, (0,0)): + l=alock.ExclusiveLock(alt, (0,0)) + # should fail to acquire + self.assertEqual(l.acquire(blocking=False), False) + # no alocks should exist now + self.assertEqual(ad(), 0) + + @pg_tmp + def testALock(self): + ad = prepare(n_alocks).first + self.assertEqual(ad(), 0) + # test a variety.. + lockids = [ + (1,4), + -32532, 0, 2, + (7, -1232), + 4, 5, 232142423, + (18,7), + 2, (1,4) + ] + alt = new() + xal1 = alock.ExclusiveLock(db, *lockids) + xal2 = alock.ExclusiveLock(db, *lockids) + sal1 = alock.ShareLock(db, *lockids) + with sal1: + with xal1, xal2: + self.assertTrue(ad() > 0) + for x in lockids: + xl = alock.ExclusiveLock(alt, x) + self.assertEqual(xl.acquire(blocking=False), False) + # main has exclusives on these, so this should fail. + xl = alock.ShareLock(alt, *lockids) + self.assertEqual(xl.acquire(blocking=False), False) + for x in lockids: + # sal1 still holds + xl = alock.ExclusiveLock(alt, x) + self.assertEqual(xl.acquire(blocking=False), False) + # sal1 still holds, but we want a share lock too. + xl = alock.ShareLock(alt, x) + self.assertEqual(xl.acquire(blocking=False), True) + xl.release() + # no alocks should exist now + self.assertEqual(ad(), 0) + + @pg_tmp + def testPartialALock(self): + # Validates that release is properly cleaning up + ad = prepare(n_alocks).first + self.assertEqual(ad(), 0) + held = (0,-1234) + wanted = [0, 324, -1232948, 7, held, 1, (2,4), (834,1)] + alt = new() + with alock.ExclusiveLock(db, held): + l=alock.ExclusiveLock(alt, *wanted) + # should fail to acquire, db has held + self.assertEqual(l.acquire(blocking=False), False) + # No alocks should exist now. + # This *MUST* occur prior to alt being closed. + # Otherwise, we won't be testing for the recovery + # of a failed non-blocking acquire(). + self.assertEqual(ad(), 0) + + @pg_tmp + def testALockParameterErrors(self): + self.assertRaises(TypeError, alock.ALock) + l = alock.ExclusiveLock(db) + self.assertRaises(RuntimeError, l.release) + + @pg_tmp + def testALockOnClosed(self): + ad = prepare(n_alocks).first + self.assertEqual(ad(), 0) + held = (0,-1234) + alt = new() + # __exit__ should only touch the count. 
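+		# Closing `alt` ends its backend session, and the server releases the
+		# session's advisory locks with it, so the advisory lock count returns
+		# to zero.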
+ with alock.ExclusiveLock(alt, held) as l: + self.assertEqual(ad(), 1) + self.assertEqual(l.locked(), True) + alt.close() + time.sleep(0.005) + self.assertEqual(ad(), 0) + self.assertEqual(l.locked(), False) + +if __name__ == '__main__': + unittest.main() diff --git a/py_opengauss/test/test_bytea_codec.py b/py_opengauss/test/test_bytea_codec.py new file mode 100644 index 0000000000000000000000000000000000000000..4de01727e3426ca9de57b1c7ca9654d3dca13ac7 --- /dev/null +++ b/py_opengauss/test/test_bytea_codec.py @@ -0,0 +1,45 @@ +## +# .test.test_bytea_codec +## +import unittest +import struct +from ..encodings import bytea + +byte = struct.Struct('B') + +class test_bytea_codec(unittest.TestCase): + def testDecoding(self): + for x in range(255): + c = byte.pack(x) + b = c.decode('bytea') + # normalize into octal escapes + if c == b'\\' and b == "\\\\": + b = "\\" + oct(b'\\'[0])[2:] + elif not b.startswith("\\"): + b = "\\" + oct(ord(b))[2:] + if int(b[1:], 8) != x: + self.fail( + "bytea encoding failed at %d; encoded %r to %r" %(x, c, b,) + ) + + def testEncoding(self): + self.assertEqual('bytea'.encode('bytea'), b'bytea') + self.assertEqual('\\\\'.encode('bytea'), b'\\') + self.assertRaises(ValueError, '\\'.encode, 'bytea') + self.assertRaises(ValueError, 'foo\\'.encode, 'bytea') + self.assertRaises(ValueError, r'foo\0'.encode, 'bytea') + self.assertRaises(ValueError, r'foo\00'.encode, 'bytea') + self.assertRaises(ValueError, r'\f'.encode, 'bytea') + self.assertRaises(ValueError, r'\800'.encode, 'bytea') + self.assertRaises(ValueError, r'\7f0'.encode, 'bytea') + for x in range(255): + seq = ('\\' + oct(x)[2:].lstrip('0').rjust(3, '0')) + dx = ord(seq.encode('bytea')) + if dx != x: + self.fail( + "generated sequence failed to map back; current is %d, " \ + "rendered %r, transformed to %d" %(x, seq, dx) + ) + +if __name__ == '__main__': + unittest.main() diff --git a/py_opengauss/test/test_cluster.py b/py_opengauss/test/test_cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..027b5fdab41b3545f1f99137f2ca4f8ddec7afaf --- /dev/null +++ b/py_opengauss/test/test_cluster.py @@ -0,0 +1,95 @@ +## +# .test.test_cluster +## +import sys +import os +import time +import unittest +import tempfile +from .. import installation +from ..cluster import Cluster, ClusterStartupError + +default_installation = installation.default() + +class test_cluster(unittest.TestCase): + def setUp(self): + self.cluster = Cluster(default_installation, 'test_cluster',) + + def tearDown(self): + if self.cluster.installation is not None: + self.cluster.drop() + self.cluster = None + + def start_cluster(self, logfile = None): + self.cluster.start(logfile = logfile) + self.cluster.wait_until_started(timeout = 10) + + def init(self, *args, **kw): + self.cluster.init(*args, **kw) + + vi = self.cluster.installation.version_info[:2] + if vi >= (9, 3): + usd = 'unix_socket_directories' + else: + usd = 'unix_socket_directory' + + if vi > (9, 6): + self.cluster.settings['max_wal_senders'] = '0' + + self.cluster.settings.update({ + 'max_connections' : '8', + 'listen_addresses' : 'localhost', + 'port' : '6543', + usd : self.cluster.data_directory, + }) + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def testSilentMode(self): + self.init() + self.cluster.settings['silent_mode'] = 'on' + # if it fails to start(ClusterError), silent_mode is not working properly. 
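+		# silent_mode was removed in PostgreSQL 9.2, so a startup failure is
+		# expected (and tolerated) on 9.2 and later, as well as on Windows.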
+ try: + self.start_cluster(logfile = sys.stdout) + except ClusterStartupError: + # silent_mode is not supported on windows by PG. + if sys.platform in ('win32','win64'): + pass + elif self.cluster.installation.version_info[:2] >= (9, 2): + pass + else: + raise + else: + if sys.platform in ('win32','win64'): + self.fail("silent_mode unexpectedly supported on windows") + elif self.cluster.installation.version_info[:2] >= (9, 2): + self.fail("silent_mode unexpectedly supported on PostgreSQL >=9.2") + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def testSuperPassword(self): + self.init( + user = 'test', + password = 'secret', + logfile = sys.stdout, + ) + self.start_cluster() + c = self.cluster.connection( + user='test', + password='secret', + database='template1', + ) + with c: + self.assertEqual(c.prepare('select 1').first(), 1) + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def testNoParameters(self): + """ + Simple init and drop. + """ + self.init() + self.start_cluster() + +if __name__ == '__main__': + from types import ModuleType + this = ModuleType("this") + this.__dict__.update(globals()) + unittest.main(this) diff --git a/py_opengauss/test/test_configfile.py b/py_opengauss/test/test_configfile.py new file mode 100644 index 0000000000000000000000000000000000000000..85d30b66317f93e5521397811e9b3853f5bdfa5e --- /dev/null +++ b/py_opengauss/test/test_configfile.py @@ -0,0 +1,255 @@ +## +# .test.test_configfile +## +import os +import unittest +from io import StringIO +from .. import configfile + +sample_config_Aroma = \ +""" +## +# A sample config file. +## +# This provides a good = test for alter_config. + +#shared_buffers = 4500 +search_path = window,$user,public +shared_buffers = 2500 + +port = 5234 +listen_addresses = 'localhost' +listen_addresses = '*' +""" + +## +# Wining cases are alteration cases that provide +# source and expectations from an alteration. +# +# The first string is the source, the second the +# alterations to make, the and the third, the expectation. +## +winning_cases = [ + ( + # Two top contenders; the first should be altered, second commented. 
+ "foo = bar"+os.linesep+"foo = bar", + {'foo' : 'newbar'}, + "foo = 'newbar'"+os.linesep+"#foo = bar" + ), + ( + # Two top contenders, first one stays commented + "#foo = bar"+os.linesep+"foo = bar", + {'foo' : 'newbar'}, + "#foo = bar"+os.linesep+"foo = 'newbar'" + ), + ( + # Two top contenders, second one stays commented + "foo = bar"+os.linesep+"#foo = bar", + {'foo' : 'newbar'}, + "foo = 'newbar'"+os.linesep+"#foo = bar" + ), + ( + # Two candidates + "foo = bar"+os.linesep+"foo = none", + {'foo' : 'bar'}, + "foo = 'bar'"+os.linesep+"#foo = none" + ), + ( + # Two candidates, winner should be the first, second gets comment + "#foo = none"+os.linesep+"foo = bar", + {'foo' : 'none'}, + "foo = 'none'"+os.linesep+"#foo = bar" + ), + ( + # Two commented candidates + "#foo = none"+os.linesep+"#foo = some", + {'foo' : 'bar'}, + "foo = 'bar'"+os.linesep+"#foo = some" + ), + ( + # Two commented candidates, the latter a top contender + "#foo = none"+os.linesep+"#foo = bar", + {'foo' : 'bar'}, + "#foo = none"+os.linesep+"foo = 'bar'" + ), + ( + # Replace empty value + "foo = "+os.linesep, + {'foo' : 'feh'}, + "foo = 'feh'" + ), + ( + # Comment value + "foo = bar", + {'foo' : None}, + "#foo = bar" + ), + ( + # Commenting after value + "foo = val this should be commented", + {'foo' : 'newval'}, + "foo = 'newval' #this should be commented" + ), + ( + # Commenting after value + "#foo = val this should be commented", + {'foo' : 'newval'}, + "foo = 'newval' #this should be commented" + ), + ( + # Commenting after quoted value + "#foo = 'val'foo this should be commented", + {'foo' : 'newval'}, + "foo = 'newval' #this should be commented" + ), + ( + # Adjacent post-value comment + "#foo = 'val'#foo this should be commented", + {'foo' : 'newval'}, + "foo = 'newval'#foo this should be commented" + ), + ( + # New setting in empty string + "", + {'bar' : 'newvar'}, + "bar = 'newvar'", + ), + ( + # New setting + "foo = 'bar'", + {'bar' : 'newvar'}, + "foo = 'bar'"+os.linesep+"bar = 'newvar'", + ), + ( + # New setting with quote escape + "foo = 'bar'", + {'bar' : "new'var"}, + "foo = 'bar'"+os.linesep+"bar = 'new''var'", + ), +] + +class test_configfile(unittest.TestCase): + def parseNone(self, line): + sl = configfile.parse_line(line) + if sl is not None: + self.fail( + "With line %r, parsed out to %r, %r, and %r, %r, " \ + "but expected None to be returned by parse function." %( + line, line[sl[0]], sl[0], line[sl[0]], sl[0] + ) + ) + + def parseExpect(self, line, key, val): + line = line %(key, val) + sl = configfile.parse_line(line) + if sl is None: + self.fail( + "expecting %r and %r from line %r, " \ + "but got None(syntax error) instead." %( + key, val, line + ) + ) + k, v = sl + if line[k] != key: + self.fail( + "expecting key %r for line %r, " \ + "but got %r from %r instead." %( + key, line, line[k], k + ) + ) + if line[v] != val: + self.fail( + "expecting value %r for line %r, " \ + "but got %r from %r instead." 
%( + val, line, line[v], v + ) + ) + + def testParser(self): + self.parseExpect("#%s = %s", 'foo', 'none') + self.parseExpect("#%s=%s"+os.linesep, 'foo', 'bar') + self.parseExpect(" #%s=%s"+os.linesep, 'foo', 'bar') + self.parseExpect('%s =%s'+os.linesep, 'foo', 'bar') + self.parseExpect(' %s=%s '+os.linesep, 'foo', 'Bar') + self.parseExpect(' %s = %s '+os.linesep, 'foo', 'Bar') + self.parseExpect('# %s = %s '+os.linesep, 'foo', 'Bar') + self.parseExpect('\t # %s = %s '+os.linesep, 'foo', 'Bar') + self.parseExpect(' # %s = %s '+os.linesep, 'foo', 'Bar') + self.parseExpect(" # %s = %s"+os.linesep, 'foo', "' Bar '") + self.parseExpect("%s = %s# comment"+os.linesep, 'foo', '') + self.parseExpect(" # %s = %s # A # comment"+os.linesep, 'foo', "' B''a#r '") + # No equality or equality in complex comment + self.parseNone(' #i # foo = Bar '+os.linesep) + self.parseNone('#bar') + self.parseNone('bar') + + def testConfigRead(self): + sample = "foo = bar"+os.linesep+"# A comment, yes."+os.linesep+" bar = foo # yet?"+os.linesep + d = configfile.read_config(sample.split(os.linesep)) + self.assertTrue(d['foo'] == 'bar') + self.assertTrue(d['bar'] == 'foo') + + def testConfigWriteRead(self): + strio = StringIO() + d = { + '' : "'foo bar'" + } + configfile.write_config(d, strio.write) + strio.seek(0) + + def testWinningCases(self): + i = 0 + for before, alters, after in winning_cases: + befg = (x + os.linesep for x in before.split(os.linesep)) + became = ''.join(configfile.alter_config(alters, befg)) + self.assertTrue( + became.strip() == after, + 'On %d, before, %r, did not become after, %r; got %r using %r' %( + i, before, after, became, alters + ) + ) + i += 1 + + def testSimpleConfigAlter(self): + # Simple set and uncomment and set test. + strio = StringIO() + strio.write("foo = bar"+os.linesep+" # bleh = unset"+os.linesep+" # grr = 'oh yeah''s'") + strio.seek(0) + lines = configfile.alter_config({'foo' : 'yes', 'bleh' : 'feh'}, strio) + d = configfile.read_config(lines) + self.assertTrue(d['foo'] == 'yes') + self.assertTrue(d['bleh'] == 'feh') + self.assertTrue(''.join(lines).count('bleh') == 1) + + def testAroma(self): + lines = configfile.alter_config({ + 'shared_buffers' : '800', + 'port' : None + }, (x + os.linesep for x in sample_config_Aroma.split('\n')) + ) + d = configfile.read_config(lines) + self.assertTrue(d['shared_buffers'] == '800') + self.assertTrue(d.get('port') is None) + + nlines = configfile.alter_config({'port' : '1'}, lines) + d2 = configfile.read_config(nlines) + self.assertTrue(d2.get('port') == '1') + self.assertTrue( + nlines[:4] == lines[:4] + ) + + def testSelection(self): + # Sanity + red = configfile.read_config(['foo = bar'+os.linesep, 'bar = foo']) + self.assertTrue(len(red.keys()) == 2) + + # Test a simple selector + red = configfile.read_config(['foo = bar'+os.linesep, 'bar = foo'], + selector = lambda x: x == 'bar') + rkeys = list(red.keys()) + self.assertTrue(len(rkeys) == 1) + self.assertTrue(rkeys[0] == 'bar') + self.assertTrue(red['bar'] == 'foo') + +if __name__ == '__main__': + unittest.main() diff --git a/py_opengauss/test/test_connect.py b/py_opengauss/test/test_connect.py new file mode 100644 index 0000000000000000000000000000000000000000..ed587956683360be69c139de495eb4b7ba75c73f --- /dev/null +++ b/py_opengauss/test/test_connect.py @@ -0,0 +1,549 @@ +## +# .test.test_connect +## +import sys +import os +import unittest +import atexit +import socket +import errno + +from ..python.socket import find_available_port + +from .. import installation +from .. 
import cluster as pg_cluster +from .. import exceptions as pg_exc + +from ..driver import dbapi20 as dbapi20 +from .. import driver as pg_driver +from .. import open as pg_open + +default_installation = installation.default() + +def check_for_ipv6(): + result = False + if socket.has_ipv6: + try: + socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + result = True + except socket.error as e: + errs = [errno.EAFNOSUPPORT] + WSAEAFNOSUPPORT = getattr(errno, 'WSAEAFNOSUPPORT', None) + if WSAEAFNOSUPPORT is not None: + errs.append(WSAEAFNOSUPPORT) + if e.errno not in errs: + raise + return result + + +msw = sys.platform in ('win32', 'win64') + +# win32 binaries don't appear to be built with ipv6 +has_ipv6 = check_for_ipv6() and not msw + +has_unix_sock = not msw + + +class TestCaseWithCluster(unittest.TestCase): + """ + postgresql.driver *interface* tests. + """ + installation = default_installation + + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self.cluster_path = \ + 'pypg_test_' \ + + str(os.getpid()) + getattr(self, 'cluster_path_suffix', '') + + self.cluster = pg_cluster.Cluster( + self.installation, + self.cluster_path, + ) + + @property + def disable_replication(self): + """ + Whether replication settings should be disabled. + """ + return self.installation.version_info[:2] > (9, 6) + + def configure_cluster(self): + self.cluster_port = find_available_port() + if self.cluster_port is None: + pg_exc.ClusterError( + 'failed to find a port for the test cluster on localhost', + creator = self.cluster + ).raise_exception() + + listen_addresses = '127.0.0.1' + if has_ipv6: + listen_addresses += ',::1' + + self.cluster.settings.update(dict( + port = str(self.cluster_port), + max_connections = '6', + shared_buffers = '24', + listen_addresses = listen_addresses, + log_destination = 'stderr', + log_min_messages = 'FATAL', + )) + + if self.disable_replication: + self.cluster.settings.update({ + 'max_wal_senders': '0', + }) + + if self.cluster.installation.version_info[:2] < (9, 3): + self.cluster.settings.update(dict( + unix_socket_directory = self.cluster.data_directory, + )) + else: + self.cluster.settings.update(dict( + unix_socket_directories = self.cluster.data_directory, + )) + + # 8.4 turns prepared transactions off by default. + if self.cluster.installation.version_info >= (8,1): + self.cluster.settings.update(dict( + max_prepared_transactions = '3', + )) + + def initialize_database(self): + c = self.cluster.connection( + user = 'test', + database = 'template1', + ) + with c: + if c.prepare( + "select true from pg_catalog.pg_database " \ + "where datname = 'test'" + ).first() is None: + c.execute('create database test') + + def connection(self, *args, **kw): + return self.cluster.connection(*args, user = 'test', **kw) + + def drop_cluster(self): + if self.cluster.initialized(): + self.cluster.drop() + + def run(self, *args, **kw): + if 'PGINSTALLATION' not in os.environ: + # Expect tests to show skipped. + return super().run(*args, **kw) + + # From prior test run? 
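+		# Drop any leftover data directory so init() starts from a clean state.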
+ if self.cluster.initialized(): + self.cluster.drop() + + self.cluster.encoding = 'utf-8' + self.cluster.init( + user = 'test', + encoding = self.cluster.encoding, + logfile = None, + ) + sys.stderr.write('*') + + atexit.register(self.drop_cluster) + self.configure_cluster() + self.cluster.start(logfile = sys.stdout) + self.cluster.wait_until_started() + self.initialize_database() + + if not self.cluster.running(): + self.cluster.start() + self.cluster.wait_until_started() + + db = self.connection() + with db: + self.db = db + return super().run(*args, **kw) + self.db = None + +class test_connect(TestCaseWithCluster): + """ + postgresql.driver connection tests + """ + ip6 = '::1' + ip4 = '127.0.0.1' + host = 'localhost' + params = {} + cluster_path_suffix = '_test_connect' + + mk_common_users = """ + CREATE USER md5 WITH ENCRYPTED PASSWORD 'md5_password'; + CREATE USER password WITH ENCRYPTED PASSWORD 'password_password'; + CREATE USER trusted; + """ + + mk_crypt_user = """ + -- crypt doesn't work with encrypted passwords: + -- http://www.postgresql.org/docs/8.2/interactive/auth-methods.html#AUTH-PASSWORD + CREATE USER crypt WITH UNENCRYPTED PASSWORD 'crypt_password'; + """ + + def __init__(self, *args, **kw): + super().__init__(*args,**kw) + + @property + def check_crypt_user(self): + return (self.cluster.installation.version_info < (8,4)) + + def configure_cluster(self): + super().configure_cluster() + self.cluster.settings['log_min_messages'] = 'log' + + # Configure the hba file with the supported methods. + with open(self.cluster.hba_file, 'w') as hba: + hosts = ['0.0.0.0/0',] + if has_ipv6: + hosts.append('0::0/0') + + methods = ['md5', 'password'] + (['crypt'] if self.check_crypt_user else []) + for h in hosts: + for m in methods: + # user and method are the same name. + hba.writelines(['host test {m} {h} {m}\n'.format( + h = h, + m = m + )]) + + # trusted + hba.writelines(["local all all trust\n"]) + hba.writelines(["host test trusted 0.0.0.0/0 trust\n"]) + if has_ipv6: + hba.writelines(["host test trusted 0::0/0 trust\n"]) + # admin lines + hba.writelines(["host all test 0.0.0.0/0 trust\n"]) + if has_ipv6: + hba.writelines(["host all test 0::0/0 trust\n"]) + + def initialize_database(self): + super().initialize_database() + + with self.cluster.connection(user = 'test') as db: + db.execute(self.mk_common_users) + if self.check_crypt_user: + db.execute(self.mk_crypt_user) + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def test_pg_open_SQL_ASCII(self): + # postgresql.open + host, port = self.cluster.address() + # test simple locators.. 
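+		# locator form: pq://user:password@host:port/database?setting=value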
+ with pg_open( + 'pq://' + 'md5:' + 'md5_password@' + host + ':' + str(port) \ + + '/test?client_encoding=SQL_ASCII' + ) as db: + self.assertEqual(db.prepare('select 1')(), [(1,)]) + self.assertEqual(db.settings['client_encoding'], 'SQL_ASCII') + self.assertTrue(db.closed) + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def test_pg_open_keywords(self): + host, port = self.cluster.address() + # straight test, no IRI + with pg_open( + user = 'md5', + password = 'md5_password', + host = host, + port = port, + database = 'test' + ) as db: + self.assertEqual(db.prepare('select 1')(), [(1,)]) + self.assertTrue(db.closed) + # composite test + with pg_open( + "pq://md5:md5_password@", + host = host, + port = port, + database = 'test' + ) as db: + self.assertEqual(db.prepare('select 1')(), [(1,)]) + # override test + with pg_open( + "pq://md5:foobar@", + password = 'md5_password', + host = host, + port = port, + database = 'test' + ) as db: + self.assertEqual(db.prepare('select 1')(), [(1,)]) + # and, one with some settings + with pg_open( + "pq://md5:foobar@?search_path=ieeee", + password = 'md5_password', + host = host, + port = port, + database = 'test', + settings = {'search_path' : 'public'} + ) as db: + self.assertEqual(db.prepare('select 1')(), [(1,)]) + self.assertEqual(db.settings['search_path'], 'public') + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def test_pg_open(self): + # postgresql.open + host, port = self.cluster.address() + # test simple locators.. + with pg_open( + 'pq://' + 'md5:' + 'md5_password@' + host + ':' + str(port) \ + + '/test' + ) as db: + self.assertEqual(db.prepare('select 1')(), [(1,)]) + self.assertTrue(db.closed) + + with pg_open( + 'pq://' + 'password:' + 'password_password@' + host + ':' + str(port) \ + + '/test' + ) as db: + self.assertEqual(db.prepare('select 1')(), [(1,)]) + self.assertTrue(db.closed) + + with pg_open( + 'pq://' + 'trusted@' + host + ':' + str(port) + '/test' + ) as db: + self.assertEqual(db.prepare('select 1')(), [(1,)]) + self.assertTrue(db.closed) + + # test environment collection + pgenv = ('PGUSER', 'PGPORT', 'PGHOST', 'PGSERVICE', 'PGPASSWORD', 'PGDATABASE') + stored = list(map(os.environ.get, pgenv)) + try: + os.environ.pop('PGSERVICE', None) + os.environ['PGUSER'] = 'md5' + os.environ['PGPASSWORD'] = 'md5_password' + os.environ['PGHOST'] = host + os.environ['PGPORT'] = str(port) + os.environ['PGDATABASE'] = 'test' + # No arguments, the environment provided everything. 
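+			# PGUSER, PGPASSWORD, PGHOST, PGPORT and PGDATABASE set above supply the parameters.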
+ with pg_open() as db: + self.assertEqual(db.prepare('select 1')(), [(1,)]) + self.assertEqual(db.prepare('select current_user').first(), 'md5') + self.assertTrue(db.closed) + finally: + i = 0 + for x in stored: + env = pgenv[i] + if x is None: + os.environ.pop(env, None) + else: + os.environ[env] = x + + oldservice = os.environ.get('PGSERVICE') + oldsysconfdir = os.environ.get('PGSYSCONFDIR') + try: + with open('pg_service.conf', 'w') as sf: + sf.write(''' +[myserv] +user = password +password = password_password +host = {host} +port = {port} +dbname = test +search_path = public +'''.format(host = host, port = port)) + sf.flush() + try: + os.environ['PGSERVICE'] = 'myserv' + os.environ['PGSYSCONFDIR'] = os.getcwd() + with pg_open() as db: + self.assertEqual(db.prepare('select 1')(), [(1,)]) + self.assertEqual(db.prepare('select current_user').first(), 'password') + self.assertEqual(db.settings['search_path'], 'public') + finally: + if oldservice is None: + os.environ.pop('PGSERVICE', None) + else: + os.environ['PGSERVICE'] = oldservice + if oldsysconfdir is None: + os.environ.pop('PGSYSCONFDIR', None) + else: + os.environ['PGSYSCONFDIR'] = oldsysconfdir + finally: + if os.path.exists('pg_service.conf'): + os.remove('pg_service.conf') + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def test_dbapi_connect(self): + host, port = self.cluster.address() + MD5 = dbapi20.connect( + user = 'md5', + database = 'test', + password = 'md5_password', + host = host, port = port, + **self.params + ) + self.assertEqual(MD5.cursor().execute('select 1').fetchone()[0], 1) + MD5.close() + self.assertRaises(pg_exc.ConnectionDoesNotExistError, + MD5.cursor().execute, 'select 1' + ) + + if self.check_crypt_user: + CRYPT = dbapi20.connect( + user = 'crypt', + database = 'test', + password = 'crypt_password', + host = host, port = port, + **self.params + ) + self.assertEqual(CRYPT.cursor().execute('select 1').fetchone()[0], 1) + CRYPT.close() + self.assertRaises(pg_exc.ConnectionDoesNotExistError, + CRYPT.cursor().execute, 'select 1' + ) + + PASSWORD = dbapi20.connect( + user = 'password', + database = 'test', + password = 'password_password', + host = host, port = port, + **self.params + ) + self.assertEqual(PASSWORD.cursor().execute('select 1').fetchone()[0], 1) + PASSWORD.close() + self.assertRaises(pg_exc.ConnectionDoesNotExistError, + PASSWORD.cursor().execute, 'select 1' + ) + + TRUST = dbapi20.connect( + user = 'trusted', + database = 'test', + password = '', + host = host, port = port, + **self.params + ) + self.assertEqual(TRUST.cursor().execute('select 1').fetchone()[0], 1) + TRUST.close() + self.assertRaises(pg_exc.ConnectionDoesNotExistError, + TRUST.cursor().execute, 'select 1' + ) + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def test_dbapi_connect_failure(self): + host, port = self.cluster.address() + badlogin = (lambda: dbapi20.connect( + user = '--', + database = '--', + password = '...', + host = host, port = port, + **self.params + )) + self.assertRaises(pg_exc.ClientCannotConnectError, badlogin) + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def test_IP4_connect(self): + C = pg_driver.default.ip4( + user = 'test', + host = '127.0.0.1', + database = 'test', + port = self.cluster.address()[1], + **self.params + ) + with C() as c: + self.assertEqual(c.prepare('select 1').first(), 1) + + @unittest.skipIf(default_installation is None, "no installation 
provided by environment") + @unittest.skipIf(not has_ipv6, "platform may not support IPv6") + def test_IP6_connect(self): + C = pg_driver.default.ip6( + user = 'test', + host = '::1', + database = 'test', + port = self.cluster.address()[1], + **self.params + ) + with C() as c: + self.assertEqual(c.prepare('select 1').first(), 1) + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def test_Host_connect(self): + C = pg_driver.default.host( + user = 'test', + host = 'localhost', + database = 'test', + port = self.cluster.address()[1], + **self.params + ) + with C() as c: + self.assertEqual(c.prepare('select 1').first(), 1) + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def test_md5_connect(self): + c = self.cluster.connection( + user = 'md5', + password = 'md5_password', + database = 'test', + **self.params + ) + with c: + self.assertEqual(c.prepare('select current_user').first(), 'md5') + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def test_crypt_connect(self): + if self.check_crypt_user: + c = self.cluster.connection( + user = 'crypt', + password = 'crypt_password', + database = 'test', + **self.params + ) + with c: + self.assertEqual(c.prepare('select current_user').first(), 'crypt') + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def test_password_connect(self): + c = self.cluster.connection( + user = 'password', + password = 'password_password', + database = 'test', + ) + with c: + self.assertEqual(c.prepare('select current_user').first(), 'password') + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def test_trusted_connect(self): + c = self.cluster.connection( + user = 'trusted', + password = '', + database = 'test', + **self.params + ) + with c: + self.assertEqual(c.prepare('select current_user').first(), 'trusted') + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def test_Unix_connect(self): + if not has_unix_sock: + return + unix_domain_socket = os.path.join( + self.cluster.data_directory, + '.s.PGSQL.' + self.cluster.settings['port'] + ) + C = pg_driver.default.unix( + user = 'test', + unix = unix_domain_socket, + ) + with C() as c: + self.assertEqual(c.prepare('select 1').first(), 1) + self.assertEqual(c.client_address, None) + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + def test_pg_open_unix(self): + if not has_unix_sock: + return + unix_domain_socket = os.path.join( + self.cluster.data_directory, + '.s.PGSQL.' + self.cluster.settings['port'] + ) + with pg_open(unix = unix_domain_socket, user = 'test') as c: + self.assertEqual(c.prepare('select 1').first(), 1) + self.assertEqual(c.client_address, None) + with pg_open('pq://test@[unix:' + unix_domain_socket.replace('/',':') + ']') as c: + self.assertEqual(c.prepare('select 1').first(), 1) + self.assertEqual(c.client_address, None) + +if __name__ == '__main__': + unittest.main() diff --git a/py_opengauss/test/test_copyman.py b/py_opengauss/test/test_copyman.py new file mode 100644 index 0000000000000000000000000000000000000000..95d0b35a1536e47bcd47e33537c99b6853d8a937 --- /dev/null +++ b/py_opengauss/test/test_copyman.py @@ -0,0 +1,619 @@ +## +# .test.test_copyman - test .copyman +## +import unittest +from itertools import islice +from .. 
import copyman +from ..temporal import pg_tmp +# The asyncs, and alternative termination. +from ..protocol.element3 import Notice, Notify, Error, cat_messages +from .. import exceptions as pg_exc + +# state manager can handle empty data messages, right? =) +emptysource = """ +CREATE TEMP TABLE emptysource (); +-- 10 +INSERT INTO emptysource DEFAULT VALUES; +INSERT INTO emptysource DEFAULT VALUES; +INSERT INTO emptysource DEFAULT VALUES; +INSERT INTO emptysource DEFAULT VALUES; +INSERT INTO emptysource DEFAULT VALUES; +INSERT INTO emptysource DEFAULT VALUES; +INSERT INTO emptysource DEFAULT VALUES; +INSERT INTO emptysource DEFAULT VALUES; +INSERT INTO emptysource DEFAULT VALUES; +INSERT INTO emptysource DEFAULT VALUES; +""" +emptydst = "CREATE TEMP TABLE empty ();" + +# The usual subjects. +stdrowcount = 10000 +stdsource = """ +CREATE TEMP TABLE source (i int, t text); +INSERT INTO source + SELECT i, i::text AS t + FROM generate_series(1, {0}) AS g(i); +""".format(stdrowcount) +stditer = [ + b'\t'.join((x, x)) + b'\n' + for x in ( + str(i).encode('ascii') for i in range(1, 10001) + ) +] +stditer_tuples = [ + (x, str(x)) for x in range(1, 10001) +] + +stddst = "CREATE TEMP TABLE destination (i int, t text)" +srcsql = "COPY source TO STDOUT" +dstsql = "COPY destination FROM STDIN" +binary_srcsql = "COPY source TO STDOUT WITH BINARY" +binary_dstsql = "COPY destination FROM STDIN WITH BINARY" +dstcount = "SELECT COUNT(*) FROM destination" +grabdst = "SELECT * FROM destination ORDER BY i ASC" +grabsrc = "SELECT * FROM source ORDER BY i ASC" + +## +# This subclass is used to append some arbitrary data +# after the initial data. This is used to exercise async/notice support. +class Injector(copyman.StatementProducer): + def __init__(self, appended_messages, *args, **kw): + super().__init__(*args, **kw) + self._appended_messages = appended_messages + + def confiscate(self): + pq = self.statement.database.pq + mb = pq.message_buffer + b = mb.getvalue() + mb.truncate() + mb.write(cat_messages(self._appended_messages)) + mb.write(b) + return super().confiscate() + +class test_copyman(unittest.TestCase): + def testNull(self): + # Test some of the basic machinery. 
+ sp = copyman.NullProducer() + sr = copyman.NullReceiver() + copyman.CopyManager(sp, sr).run() + self.assertEqual(sp.total_messages, 0) + self.assertEqual(sp.total_bytes, 0) + + @pg_tmp + def testNullProducer(self): + sqlexec(stddst) + np = copyman.NullProducer() + sr = copyman.StatementReceiver(prepare(dstsql)) + copyman.CopyManager(np, sr).run() + self.assertEqual(np.total_messages, 0) + self.assertEqual(np.total_bytes, 0) + self.assertEqual(prepare(dstcount).first(), 0) + self.assertEqual(prepare(grabdst)(), []) + + @pg_tmp + def testNullReceiver(self): + sqlexec(stdsource) + sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 128) + sr = copyman.NullReceiver() + with copyman.CopyManager(sp, sr) as copy: + for x in copy: + pass + self.assertEqual(sp.total_messages, stdrowcount) + self.assertEqual(sp.total_bytes > 0, True) + + def testIteratorToCall(self): + tmp = iter(stditer) + # segment stditer into chunks consisting of twenty rows each + sp = copyman.IteratorProducer([ + list(islice(tmp, 20)) for x in range(len(stditer) // 20) + ]) + dest = [] + sr = copyman.CallReceiver(dest.extend) + recomputed_bytes = 0 + recomputed_messages = 0 + with copyman.CopyManager(sp, sr) as copy: + for msg, bytes in copy: + recomputed_messages += msg + recomputed_bytes += bytes + self.assertEqual(stdrowcount, recomputed_messages) + self.assertEqual(recomputed_bytes, sp.total_bytes) + self.assertEqual(len(dest), stdrowcount) + self.assertEqual(dest, stditer) + + @pg_tmp + def testDirectStatements(self): + sqlexec(stdsource) + dst = new() + dst.execute(stddst) + sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 512) + sr = copyman.StatementReceiver(dst.prepare(dstsql)) + with copyman.CopyManager(sp, sr) as copy: + for x in copy: + pass + self.assertEqual(dst.prepare(dstcount).first(), stdrowcount) + self.assertEqual(dst.prepare(grabdst)(), prepare(grabsrc)()) + + @pg_tmp + def testIteratorProducer(self): + sqlexec(stddst) + sp = copyman.IteratorProducer([stditer]) + sr = copyman.StatementReceiver(prepare(dstsql)) + recomputed_bytes = 0 + recomputed_messages = 0 + with copyman.CopyManager(sp, sr) as copy: + for msg, bytes in copy: + recomputed_messages += msg + recomputed_bytes += bytes + self.assertEqual(stdrowcount, recomputed_messages) + self.assertEqual(recomputed_bytes, sp.total_bytes) + self.assertEqual(prepare(dstcount).first(), stdrowcount) + self.assertEqual(prepare(grabdst)(), stditer_tuples) + + def multiple_destinations(self, count = 3, binary = False, buffer_size = 129): + if binary: + src = binary_srcsql + dst = binary_dstsql + # accommodate for the binary header. 
+ count_offset = 1 + else: + src = srcsql + dst = dstsql + count_offset = 0 + sqlexec(stdsource) + dests = [new() for x in range(count)] + receivers = [] + for x in dests: + x.execute(stddst) + receivers.append(copyman.StatementReceiver(x.prepare(dst))) + sp = copyman.StatementProducer(prepare(src), buffer_size = buffer_size) + recomputed_bytes = 0 + recomputed_messages = 0 + with copyman.CopyManager(sp, *receivers) as copy: + for msg, bytes in copy: + recomputed_messages += msg + recomputed_bytes += bytes + src_snap = prepare(grabsrc)() + for x in dests: + self.assertEqual(x.prepare(dstcount).first(), stdrowcount) + self.assertEqual(x.prepare(grabdst)(), src_snap) + self.assertEqual(stdrowcount + count_offset, recomputed_messages) + self.assertEqual(recomputed_bytes, sp.total_bytes) + + @pg_tmp + def testMultipleStatements(self): + self.multiple_destinations() + + @pg_tmp + def testMultipleStatementsBinary(self): + self.multiple_destinations(binary = True) + + @pg_tmp + def testMultipleStatementsSmallBuffer(self): + self.multiple_destinations(buffer_size = 11) + + @pg_tmp + def testNotices(self): + # Inject a Notices directly into the stream to emulate + # cases of asynchronous messages received during COPY. + notices = [ + Notice(( + (b'S', b'NOTICE'), + (b'C', b'00000'), + (b'M', b'It\'s a beautiful day.'), + )), + Notice(( + (b'S', b'WARNING'), + (b'C', b'01X1X1'), + (b'M', b'FAILURE IS CERTAIN'), + )) + ] + sqlexec(stdsource) + dst = new() + dst.execute(stddst) + # hook for notices.. + rmessages = [] + def hook(msg): + rmessages.append(msg) + # suppress + return True + stmt = prepare(srcsql) + stmt.msghook = hook + sp = Injector(notices, stmt, buffer_size = 133) + sr = copyman.StatementReceiver(dst.prepare(dstsql)) + seen_in_loop = 0 + with copyman.CopyManager(sp, sr) as copy: + for x in copy: + if rmessages: + # Should get hooked before the COPY is over. + seen_in_loop += 1 + self.assertTrue(seen_in_loop > 0) + self.assertEqual(dst.prepare(dstcount).first(), stdrowcount) + self.assertEqual(dst.prepare(grabdst)(), prepare(grabsrc)()) + # The injector adds then everytime the wire data is confiscated + # from the protocol connection. + notice, warning = rmessages[:2] + self.assertEqual(notice.code, "00000") + self.assertEqual(warning.code, "01X1X1") + self.assertEqual(warning.details['severity'], "WARNING") + self.assertEqual(notice.message, "It's a beautiful day.") + self.assertEqual(warning.message, "FAILURE IS CERTAIN") + self.assertEqual(notice.details['severity'], "NOTICE") + + @pg_tmp + def testAsyncNotify(self): + # Inject a NOTIFY directly into the stream to emulate + # cases of asynchronous messages received during COPY. + notify = [Notify(1234, b'channel', b'payload')] + sqlexec(stdsource) + dst = new() + dst.execute(stddst) + sp = Injector(notify, prepare(srcsql), buffer_size = 32) + sr = copyman.StatementReceiver(dst.prepare(dstsql)) + seen_in_loop = 0 + r = [] + with copyman.CopyManager(sp, sr) as copy: + for x in copy: + r += list(db.iternotifies(0)) + # Got the injected NOTIFY's, right? + self.assertTrue(r) + # it may have happened multiple times, so adjust accordingly. 
+ self.assertEqual(r, [('channel', 'payload', 1234)]*len(r)) + + @pg_tmp + def testUnfinishedCopy(self): + sqlexec(stdsource) + dst = new() + dst.execute(stddst) + sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 32) + sr = copyman.StatementReceiver(dst.prepare(dstsql)) + try: + with copyman.CopyManager(sp, sr) as copy: + for x in copy: + break + self.fail("did not raise CopyFail") + except copyman.CopyFail: + pass + + @pg_tmp + def testRaiseInCopy(self): + sqlexec(stdsource) + dst = new() + dst.execute(stddst) + sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 128) + sr = copyman.StatementReceiver(dst.prepare(dstsql)) + i = 0 + class ThisError(Exception): + pass + try: + with copyman.CopyManager(sp, sr) as copy: + for x in copy: + # Note, the state of the receiver has changed. + # We may not be on a message boundary, so this test + # exercises cases where an interrupt occurs where + # re-alignment *may* need to occur. + raise ThisError() + except copyman.CopyFail as cf: + # It's a copy failure, but due to ThisError. + self.assertTrue(isinstance(cf.__context__, ThisError)) + else: + self.fail("didn't raise CopyFail") + # Connections should be usable. + self.assertEqual(prepare('select 1').first(), 1) + self.assertEqual(dst.prepare('select 1').first(), 1) + + @pg_tmp + def testRaiseInCopyOnEnter(self): + sqlexec(stdsource) + dst = new() + dst.execute(stddst) + sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 128) + sr = copyman.StatementReceiver(dst.prepare(dstsql)) + i = 0 + class ThatError(Exception): + pass + try: + with copyman.CopyManager(sp, sr) as copy: + raise ThatError() + except copyman.CopyFail as cf: + # yeah; error on incomplete COPY + self.assertTrue(isinstance(cf.__context__, ThatError)) + else: + self.fail("didn't raise CopyFail") + + @pg_tmp + def testCopyWithFailure(self): + sqlexec(stdsource) + dst = new() + dst2 = new() + dst.execute(stddst) + dst2.execute(stddst) + sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 128) + sr1 = copyman.StatementReceiver(dst.prepare(dstsql)) + sr2 = copyman.StatementReceiver(dst2.prepare(dstsql)) + done = False + with copyman.CopyManager(sp, sr1, sr2) as copy: + while True: + try: + for x in copy: + if not done: + done = True + dst2.pq.socket.close() + else: + # Done with copy. 
+ break + except copyman.ReceiverFault as cf: + if sr2 not in cf.faults: + raise + self.assertTrue(done) + self.assertRaises(Exception, dst2.execute, 'select 1') + self.assertEqual(dst.prepare(dstcount).first(), stdrowcount) + self.assertEqual(dst.prepare(grabdst)(), prepare(grabsrc)()) + + @pg_tmp + def testEmptyRows(self): + sqlexec(emptysource) + dst = new() + dst.execute(emptydst) + sp = copyman.StatementProducer(prepare("COPY emptysource TO STDOUT"), buffer_size = 127) + sr = copyman.StatementReceiver(dst.prepare("COPY empty FROM STDIN")) + m = 0 + b = 0 + with copyman.CopyManager(sp, sr) as copy: + for x in copy: + nmsg, nbytes = x + m += nmsg + b += nbytes + self.assertEqual(m, 10) + self.assertEqual(prepare("SELECT COUNT(*) FROM emptysource").first(), 10) + self.assertEqual(dst.prepare("SELECT COUNT(*) FROM empty").first(), 10) + self.assertEqual(sr.count(), 10) + self.assertEqual(sp.count(), 10) + + @pg_tmp + def testCopyOne(self): + from io import BytesIO + b = BytesIO() + copyman.transfer( + prepare('COPY (SELECT 1) TO STDOUT'), + copyman.CallReceiver(b.writelines) + ) + b.seek(0) + self.assertEqual(b.read(), b'1\n') + + @pg_tmp + def testCopyNone(self): + from io import BytesIO + b = BytesIO() + copyman.transfer( + prepare('COPY (SELECT 1 LIMIT 0) TO STDOUT'), + copyman.CallReceiver(b.writelines) + ) + b.seek(0) + self.assertEqual(b.read(), b'') + + @pg_tmp + def testNoReceivers(self): + sqlexec(stdsource) + dst = new() + dst.execute(stddst) + sp = copyman.StatementProducer(prepare(srcsql)) + sr1 = copyman.StatementReceiver(dst.prepare(dstsql)) + done = False + try: + with copyman.CopyManager(sp, sr1) as copy: + while not done: + try: + for x in copy: + if not done: + done = True + dst.pq.socket.close() + else: + self.fail("failed to detect dead socket") + except copyman.ReceiverFault as cf: + self.assertTrue(sr1 in cf.faults) + # Don't reconcile. Let the manager drop the receiver. + except copyman.CopyFail: + self.assertTrue(not bool(copy.receivers)) + # Success. + else: + self.fail("did not raise expected error") + # Let the exception cause a failure. + self.assertTrue(done) + + @pg_tmp + def testReconciliation(self): + # cm.reconcile() test. + sqlexec(stdsource) + dst = new() + dst.execute(stddst) + sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 201) + sr = copyman.StatementReceiver(dst.prepare(dstsql)) + + original_call = sr.send + class RecoverableError(Exception): + pass + def failed_write(*args): + sr.send = original_call + raise RecoverableError() + sr.send = failed_write + + done = False + recomputed_messages = 0 + recomputed_bytes = 0 + with copyman.CopyManager(sp, sr) as copy: + while copy.receivers: + try: + for nmsg, nbytes in copy: + recomputed_messages += nmsg + recomputed_bytes += nbytes + else: + # Done with COPY, break out of while copy.receivers. + break + except copyman.ReceiverFault as cf: + if isinstance(cf.faults[sr], RecoverableError): + if done is True: + self.fail("failed_write was called twice?") + done = True + self.assertEqual(len(copy.receivers), 0) + copy.reconcile(sr) + self.assertEqual(len(copy.receivers), 1) + + self.assertEqual(done, True) + + # Connections should be usable. 
+ self.assertEqual(prepare('select 1').first(), 1) + self.assertEqual(dst.prepare('select 1').first(), 1) + # validate completion + self.assertEqual(stdrowcount, recomputed_messages) + self.assertEqual(recomputed_bytes, sp.total_bytes) + self.assertEqual(dst.prepare(dstcount).first(), stdrowcount) + + @pg_tmp + def testDroppedConnection(self): + # no cm.reconcile() test. + sqlexec(stdsource) + dst = new() + dst2 = new() + dst2.execute(stddst) + dst.execute(stddst) + sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 201) + sr1 = copyman.StatementReceiver(dst.prepare(dstsql)) + sr2 = copyman.StatementReceiver(dst2.prepare(dstsql)) + + class TheCause(Exception): + pass + def failed_write(*args): + raise TheCause() + sr2.send = failed_write + + done = False + recomputed_messages = 0 + recomputed_bytes = 0 + with copyman.CopyManager(sp, sr1, sr2) as copy: + while copy.receivers: + try: + for nmsg, nbytes in copy: + recomputed_messages += nmsg + recomputed_bytes += nbytes + else: + # Done with COPY, break out of while copy.receivers. + break + except copyman.ReceiverFault as cf: + self.assertTrue(isinstance(cf.faults[sr2], TheCause)) + if done is True: + self.fail("failed_write was called twice?") + done = True + self.assertEqual(len(copy.receivers), 1) + dst2.pq.socket.close() + # We don't reconcile, so the manager only has one target now. + + self.assertEqual(done, True) + # May not be aligned; really, we're expecting the connection to + # have died. + self.assertRaises(Exception, dst2.execute, "SELECT 1") + + # Connections should be usable. + self.assertEqual(prepare('select 1').first(), 1) + self.assertEqual(dst.prepare('select 1').first(), 1) + # validate completion + self.assertEqual(stdrowcount, recomputed_messages) + self.assertEqual(recomputed_bytes, sp.total_bytes) + self.assertEqual(dst.prepare(dstcount).first(), stdrowcount) + self.assertEqual(sp.count(), stdrowcount) + self.assertEqual(sp.command(), "COPY") + + @pg_tmp + def testProducerFailure(self): + sqlexec(stdsource) + dst = new() + dst.execute(stddst) + sp = copyman.StatementProducer(prepare(srcsql)) + sr = copyman.StatementReceiver(dst.prepare(dstsql)) + done = False + try: + with copyman.CopyManager(sp, sr) as copy: + try: + for x in copy: + if not done: + done = True + db.pq.socket.close() + except copyman.ProducerFault as pf: + self.assertTrue(pf.__context__ is not None) + self.fail('expected CopyManager to raise CopyFail') + except copyman.CopyFail as cf: + # Expecting to see CopyFail + self.assertTrue(True) + self.assertTrue(isinstance(cf.producer_fault, pg_exc.ConnectionFailureError)) + self.assertTrue(done) + self.assertRaises(Exception, sqlexec, 'select 1') + self.assertEqual(dst.prepare(dstcount).first(), 0) + +from ..copyman import WireState + +class test_WireState(unittest.TestCase): + def testNormal(self): + WS=WireState() + messages = WS.update(memoryview(b'd\x00\x00\x00\x04')) + self.assertEqual(messages, 1) + self.assertEqual(WS.remaining_bytes, 0) + self.assertEqual(WS.size_fragment, b'') + self.assertEqual(WS.final_view, None) + + def testIncomplete(self): + WS=WireState() + messages = WS.update(memoryview(b'd\x00\x00\x00\x05')) + self.assertEqual(messages, 0) + self.assertEqual(WS.remaining_bytes, 1) + self.assertEqual(WS.size_fragment, b'') + self.assertEqual(WS.final_view, None) + messages = WS.update(memoryview(b'x')) + self.assertEqual(messages, 1) + self.assertEqual(WS.remaining_bytes, 0) + self.assertEqual(WS.size_fragment, b'') + self.assertEqual(WS.final_view, None) + + def 
testIncompleteHeader_0size(self): + WS=WireState() + messages = WS.update(memoryview(b'd')) + self.assertEqual(messages, 0) + self.assertEqual(WS.remaining_bytes, -1) + self.assertEqual(WS.size_fragment, b'') + self.assertEqual(WS.final_view, None) + messages = WS.update(b'\x00\x00\x00\x04') + self.assertEqual(messages, 1) + + def testIncompleteHeader_1size(self): + WS=WireState() + messages = WS.update(memoryview(b'd\x00')) + self.assertEqual(messages, 0) + self.assertEqual(WS.size_fragment, b'\x00') + self.assertEqual(WS.final_view, None) + self.assertEqual(WS.remaining_bytes, -1) + messages = WS.update(memoryview(b'\x00\x00\x04')) + self.assertEqual(messages, 1) + self.assertEqual(WS.remaining_bytes, 0) + + def testIncompleteHeader_2size(self): + WS=WireState() + messages = WS.update(memoryview(b'd\x00\x00')) + self.assertEqual(messages, 0) + self.assertEqual(WS.remaining_bytes, -1) + self.assertEqual(WS.size_fragment, b'\x00\x00') + self.assertEqual(WS.final_view, None) + messages = WS.update(b'\x00\x04') + self.assertEqual(messages, 1) + self.assertEqual(WS.remaining_bytes, 0) + + def testIncompleteHeader_3size(self): + WS=WireState() + messages = WS.update(memoryview(b'd\x00\x00\x00')) + self.assertEqual(messages, 0) + self.assertEqual(WS.remaining_bytes, -1) + self.assertEqual(WS.size_fragment, b'\x00\x00\x00') + self.assertEqual(WS.final_view, None) + messages = WS.update(memoryview(b'\x04')) + self.assertEqual(messages, 1) + self.assertEqual(WS.remaining_bytes, 0) + +if __name__ == '__main__': + unittest.main() diff --git a/py_opengauss/test/test_dbapi20.py b/py_opengauss/test/test_dbapi20.py new file mode 100644 index 0000000000000000000000000000000000000000..00bb49b30fb31fb0cac337db7f19e755b09da935 --- /dev/null +++ b/py_opengauss/test/test_dbapi20.py @@ -0,0 +1,875 @@ +## +# .test.test_dbapi20 - test .driver.dbapi20 +## +import unittest +import time +from ..temporal import pg_tmp + +## +# Various Adjustments for .driver.dbapi20 +# +# Log: dbapi20.py +# Revision 1.10 2003/10/09 03:14:14 zenzen +# Add test for DB API 2.0 optional extension, where database exceptions +# are exposed as attributes on the Connection object. +# +# Revision 1.9 2003/08/13 01:16:36 zenzen +# Minor tweak from Stefan Fleiter +# +# Revision 1.8 2003/04/10 00:13:25 zenzen +# Changes, as per suggestions by M.-A. Lemburg +# - Add a table prefix, to ensure namespace collisions can always be avoided +# +# Revision 1.7 2003/02/26 23:33:37 zenzen +# Break out DDL into helper functions, as per request by David Rushby +# +# Revision 1.6 2003/02/21 03:04:33 zenzen +# Stuff from Henrik Ekelund: +# added test_None +# added test_nextset & hooks +# +# Revision 1.5 2003/02/17 22:08:43 zenzen +# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize +# defaults to 1 & generic cursor.callproc test added +# +# Revision 1.4 2003/02/15 00:16:33 zenzen +# Changes, as per suggestions and bug reports by M.-A. Lemburg, +# Matthew T. 
Kromer, Federico Di Gregorio and Daniel Dittmar +# - Class renamed +# - Now a subclass of TestCase, to avoid requiring the driver stub +# to use multiple inheritance +# - Reversed the polarity of buggy test in test_description +# - Test exception heirarchy correctly +# - self.populate is now self._populate(), so if a driver stub +# overrides self.ddl1 this change propogates +# - VARCHAR columns now have a width, which will hopefully make the +# DDL even more portible (this will be reversed if it causes more problems) +# - cursor.rowcount being checked after various execute and fetchXXX methods +# - Check for fetchall and fetchmany returning empty lists after results +# are exhausted (already checking for empty lists if select retrieved +# nothing +# - Fix bugs in test_setoutputsize_basic and test_setinputsizes +# +class test_dbapi20(unittest.TestCase): + """ + Test a database self.driver for DB API 2.0 compatibility. + This implementation tests Gadfly, but the TestCase + is structured so that other self.drivers can subclass this + test case to ensure compiliance with the DB-API. It is + expected that this TestCase may be expanded in the future + if ambiguities or edge conditions are discovered. + + The 'Optional Extensions' are not yet being tested. + + self.drivers should subclass this test, overriding setUp, tearDown, + self.driver, connect_args and connect_kw_args. Class specification + should be as follows: + + import dbapi20 + class mytest(dbapi20.DatabaseAPI20Test): + [...] + + __rcs_id__ = 'Id: dbapi20.py,v 1.10 2003/10/09 03:14:14 zenzen Exp' + __version__ = 'Revision: 1.10' + __author__ = 'Stuart Bishop ' + """ + + import py_opengauss.driver.dbapi20 as driver + table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables + + booze_name = table_prefix + 'booze' + ddl1 = 'create temp table %s (name varchar(20))' % booze_name + ddl2 = 'create temp table %sbarflys (name varchar(20))' % table_prefix + xddl1 = 'drop table %sbooze' % table_prefix + xddl2 = 'drop table %sbarflys' % table_prefix + + lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase + + # Some drivers may need to override these helpers, for example adding + # a 'commit' after the execute. + def executeDDL1(self,cursor): + cursor.execute(self.ddl1) + + def executeDDL2(self,cursor): + cursor.execute(self.ddl2) + + def setUp(self): + pg_tmp.init() + pg_tmp.push() + pg_tmp._init_c(db) + + def tearDown(self): + pg_tmp.pop(None) + + def _connect(self): + c = db.clone() + c.__class__ = self.driver.Connection + c._xact = c.xact() + c._xact.start() + c._dbapi_connected_flag = True + return c + + def test_connect(self): + con = self._connect() + con.close() + + def test_apilevel(self): + try: + # Must exist + apilevel = self.driver.apilevel + # Must equal 2.0 + self.assertEqual(apilevel,'2.0') + except AttributeError: + self.fail("Driver doesn't define apilevel") + + def test_threadsafety(self): + try: + # Must exist + threadsafety = self.driver.threadsafety + # Must be a valid value + self.assertTrue(threadsafety in (0,1,2,3)) + except AttributeError: + self.fail("Driver doesn't define threadsafety") + + def test_paramstyle(self): + try: + # Must exist + paramstyle = self.driver.paramstyle + # Must be a valid value + self.assertTrue(paramstyle in ( + 'qmark','numeric','named','format','pyformat' + )) + except AttributeError: + self.fail("Driver doesn't define paramstyle") + + def test_Exceptions(self): + # Make sure required exceptions exist, and are in the + # defined heirarchy. 
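+		# PEP 249 requires each of these to be a subclass of Error.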
+ self.assertTrue(issubclass(self.driver.InterfaceError,self.driver.Error)) + self.assertTrue(issubclass(self.driver.DatabaseError,self.driver.Error)) + self.assertTrue(issubclass(self.driver.OperationalError,self.driver.Error)) + self.assertTrue(issubclass(self.driver.IntegrityError,self.driver.Error)) + self.assertTrue(issubclass(self.driver.InternalError,self.driver.Error)) + self.assertTrue(issubclass(self.driver.ProgrammingError,self.driver.Error)) + self.assertTrue(issubclass(self.driver.NotSupportedError,self.driver.Error)) + + def test_ExceptionsAsConnectionAttributes(self): + # OPTIONAL EXTENSION + # Test for the optional DB API 2.0 extension, where the exceptions + # are exposed as attributes on the Connection object + # I figure this optional extension will be implemented by any + # driver author who is using this test suite, so it is enabled + # by default. + con = self._connect() + try: + drv = self.driver + self.assertTrue(con.Warning is drv.Warning) + self.assertTrue(con.Error is drv.Error) + self.assertTrue(con.InterfaceError is drv.InterfaceError) + self.assertTrue(con.DatabaseError is drv.DatabaseError) + self.assertTrue(con.OperationalError is drv.OperationalError) + self.assertTrue(con.IntegrityError is drv.IntegrityError) + self.assertTrue(con.InternalError is drv.InternalError) + self.assertTrue(con.ProgrammingError is drv.ProgrammingError) + self.assertTrue(con.NotSupportedError is drv.NotSupportedError) + finally: + con.close() + + def test_commit(self): + con = self._connect() + try: + # Commit must work, even if it doesn't do anything + con.commit() + finally: + con.close() + + def test_rollback(self): + con = self._connect() + # If rollback is defined, it should either work or throw + # the documented exception + try: + if hasattr(con,'rollback'): + try: + con.rollback() + except self.driver.NotSupportedError: + pass + finally: + con.close() + + def test_cursor(self): + con = self._connect() + try: + cur = con.cursor() + finally: + con.close() + + def test_cursor_isolation(self): + con = self._connect() + try: + # Make sure cursors created from the same connection have + # the documented transaction isolation level + cur1 = con.cursor() + cur2 = con.cursor() + self.executeDDL1(cur1) + cur1.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + cur2.execute("select name from %sbooze" % self.table_prefix) + booze = cur2.fetchall() + self.assertEqual(len(booze),1) + self.assertEqual(len(booze[0]),1) + self.assertEqual(booze[0][0],'Victoria Bitter') + finally: + con.close() + + def test_description(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + self.assertEqual(cur.description,None, + 'cursor.description should be none after executing a ' + 'statement that can return no rows (such as DDL)' + ) + cur.execute('select name from %sbooze' % self.table_prefix) + self.assertEqual(len(cur.description),1, + 'cursor.description describes too many columns' + ) + self.assertEqual(len(cur.description[0]),7, + 'cursor.description[x] tuples must have 7 elements' + ) + self.assertEqual(cur.description[0][0].lower(),'name', + 'cursor.description[x][0] must return column name' + ) + self.assertEqual(cur.description[0][1],self.driver.STRING, + 'cursor.description[x][1] must return column type. 
Got %r' + % cur.description[0][1] + ) + + # Make sure self.description gets reset + self.executeDDL2(cur) + self.assertEqual(cur.description,None, + 'cursor.description not being set to None when executing ' + 'no-result statements (eg. DDL)' + ) + finally: + con.close() + + def test_rowcount(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + self.assertEqual(cur.rowcount, -1, + 'cursor.rowcount should be -1 after executing no-result ' + 'statements' + ) + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + self.assertEqual(cur.rowcount, 1, + 'cursor.rowcount should == number or rows inserted, or ' + 'set to -1 after executing an insert statement' + ) + cur.execute("insert into %sbooze select 'Victoria Bitter' WHERE FALSE" % ( + self.table_prefix + )) + self.assertEqual(cur.rowcount, 0) + cur.execute("insert into %sbooze select 'First' UNION ALL select 'second'" % ( + self.table_prefix + )) + self.assertEqual(cur.rowcount, 2) + + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertEqual(cur.rowcount, -1, + 'cursor.rowcount should == number of rows returned, or ' + 'set to -1 after executing a select statement' + ) + self.executeDDL2(cur) + self.assertEqual(cur.rowcount, -1, + 'cursor.rowcount not being reset to -1 after executing ' + 'no-result statements' + ) + finally: + con.close() + + lower_func = 'lower' + def test_callproc(self): + con = self._connect() + try: + cur = con.cursor() + if self.lower_func and hasattr(cur,'callproc'): + r = cur.callproc(self.lower_func,('FOO',)) + self.assertEqual(len(r),1) + self.assertEqual(r[0],'FOO') + r = cur.fetchall() + self.assertEqual(len(r),1,'callproc produced no result set') + self.assertEqual(len(r[0]),1, + 'callproc produced invalid result set' + ) + self.assertEqual(r[0][0],'foo', + 'callproc produced invalid results' + ) + finally: + con.close() + + def test_close(self): + con = self._connect() + try: + cur = con.cursor() + finally: + con.close() + + # cursor.execute should raise an Error if called after connection + # closed + self.assertRaises(self.driver.Error,self.executeDDL1,cur) + + # connection.commit should raise an Error if called after connection' + # closed.' + self.assertRaises(self.driver.Error,con.commit) + + # connection.close should raise an Error if called more than once + self.assertRaises(self.driver.Error,con.close) + + def test_cursor_close(self): + con = self._connect() + try: + cur = con.cursor() + cur.close() + # cursor.execute should raise an Error if called after cursor.close + # closed + self.assertRaises(self.driver.Error,self.executeDDL1,cur) + # cursor.executemany should raise an Error if called after connection' + # closed.' 
+ self.assertRaises(self.driver.Error,cur.executemany,'foo', []) + + self.assertRaises(self.driver.Error,cur.callproc,'generate_series', [1, 10]) + + # cursor.close should raise an Error if called more than once + self.assertRaises(self.driver.Error,cur.close) + finally: + con.close() + + def test_execute(self): + con = self._connect() + try: + cur = con.cursor() + self._paraminsert(cur) + finally: + con.close() + + def test_format_execute(self): + self.driver.paramstyle = 'format' + try: + self.test_execute() + finally: + self.driver.paramstyle = 'pyformat' + + def _paraminsert(self,cur): + self.executeDDL1(cur) + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + self.assertTrue(cur.rowcount in (-1,1)) + + if self.driver.paramstyle == 'qmark': + cur.execute( + 'insert into %sbooze values (?)' % self.table_prefix, + ("Cooper's",) + ) + elif self.driver.paramstyle == 'numeric': + cur.execute( + 'insert into %sbooze values (:1)' % self.table_prefix, + ("Cooper's",) + ) + elif self.driver.paramstyle == 'named': + cur.execute( + 'insert into %sbooze values (:beer)' % self.table_prefix, + {'beer':"Cooper's"} + ) + elif self.driver.paramstyle == 'format': + cur.execute( + 'insert into %sbooze values (%%s)' % self.table_prefix, + ("Cooper's",) + ) + elif self.driver.paramstyle == 'pyformat': + cur.execute( + 'insert into %sbooze values (%%(beer)s)' % self.table_prefix, + {'beer':"Cooper's"} + ) + else: + self.fail('Invalid paramstyle') + self.assertTrue(cur.rowcount in (-1,1)) + + cur.execute('select name from %sbooze' % self.table_prefix) + res = cur.fetchall() + self.assertEqual(len(res),2,'cursor.fetchall returned too few rows') + beers = [res[0][0],res[1][0]] + beers.sort() + self.assertEqual(beers[0],"Cooper's", + 'cursor.fetchall retrieved incorrect data, or data inserted ' + 'incorrectly' + ) + self.assertEqual(beers[1],"Victoria Bitter", + 'cursor.fetchall retrieved incorrect data, or data inserted ' + 'incorrectly' + ) + + def test_executemany(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + largs = [ ("Cooper's",) , ("Boag's",) ] + margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ] + if self.driver.paramstyle == 'qmark': + cur.executemany( + 'insert into %sbooze values (?)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'numeric': + cur.executemany( + 'insert into %sbooze values (:1)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'named': + cur.executemany( + 'insert into %sbooze values (:beer)' % self.table_prefix, + margs + ) + elif self.driver.paramstyle == 'format': + cur.executemany( + 'insert into %sbooze values (%%s)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'pyformat': + cur.executemany( + 'insert into %sbooze values (%%(beer)s)' % ( + self.table_prefix + ), + margs + ) + else: + self.fail('Unknown paramstyle') + self.assertTrue(cur.rowcount in (-1,2), + 'insert using cursor.executemany set cursor.rowcount to ' + 'incorrect value %r' % cur.rowcount + ) + cur.execute('select name from %sbooze' % self.table_prefix) + res = cur.fetchall() + self.assertEqual(len(res),2, + 'cursor.fetchall retrieved incorrect number of rows' + ) + beers = [res[0][0],res[1][0]] + beers.sort() + self.assertEqual(beers[0],"Boag's",'incorrect data retrieved') + self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved') + finally: + con.close() + + def test_format_executemany(self): + self.driver.paramstyle = 'format' + try: + self.test_executemany() + 
finally:
+ self.driver.paramstyle = 'pyformat'
+
+ def test_fetchone(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+
+ # cursor.fetchone should raise an Error if called before
+ # executing a select-type query
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ # cursor.fetchone should raise an Error if called after
+ # executing a query that cannot return rows
+ self.executeDDL1(cur)
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ self.assertEqual(cur.fetchone(),None,
+ 'cursor.fetchone should return None if a query retrieves '
+ 'no rows'
+ )
+ self.assertTrue(cur.rowcount in (-1,0))
+
+ # cursor.fetchone should raise an Error if called after
+ # executing a query that cannot return rows
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchone()
+ self.assertEqual(len(r),1,
+ 'cursor.fetchone should have retrieved a single row'
+ )
+ self.assertEqual(r[0],'Victoria Bitter',
+ 'cursor.fetchone retrieved incorrect data'
+ )
+ self.assertEqual(cur.fetchone(),None,
+ 'cursor.fetchone should return None if no more rows available'
+ )
+ self.assertTrue(cur.rowcount in (-1,1))
+ finally:
+ con.close()
+
+ samples = [
+ 'Carlton Cold',
+ 'Carlton Draft',
+ 'Mountain Goat',
+ 'Redback',
+ 'Victoria Bitter',
+ 'XXXX'
+ ]
+
+ def _populate(self):
+ '''
+ Return a list of sql commands to set up the DB for the fetch tests.
+ '''
+ populate = [
+ "insert into %sbooze values ('%s')" % (self.table_prefix,s)
+ for s in self.samples
+ ]
+ return populate
+
+ def test_fetchmany(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+
+ # cursor.fetchmany should raise an Error if called without
+ # issuing a query
+ self.assertRaises(self.driver.Error,cur.fetchmany,4)
+
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchmany()
+ self.assertEqual(len(r),1,
+ 'cursor.fetchmany retrieved incorrect number of rows, '
+ 'default of arraysize is one.'
+ )
+ cur.arraysize=10
+ r = cur.fetchmany(3) # Should get 3 rows
+ self.assertEqual(len(r),3,
+ 'cursor.fetchmany retrieved incorrect number of rows'
+ )
+ r = cur.fetchmany(4) # Should get 2 more
+ self.assertEqual(len(r),2,
+ 'cursor.fetchmany retrieved incorrect number of rows'
+ )
+ r = cur.fetchmany(4) # Should be an empty sequence
+ self.assertEqual(len(r),0,
+ 'cursor.fetchmany should return an empty sequence after '
+ 'results are exhausted'
+ )
+ self.assertTrue(cur.rowcount in (-1,6))
+
+ # Same as above, using cursor.arraysize
+ cur.arraysize=4
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchmany() # Should get 4 rows
+ self.assertEqual(len(r),4,
+ 'cursor.arraysize not being honoured by fetchmany'
+ )
+ r = cur.fetchmany() # Should get 2 more
+ self.assertEqual(len(r),2)
+ r = cur.fetchmany() # Should be an empty sequence
+ self.assertEqual(len(r),0)
+ self.assertTrue(cur.rowcount in (-1,6))
+
+ cur.arraysize=6
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows = cur.fetchmany() # Should get all rows
+ self.assertTrue(cur.rowcount in (-1,6))
+ self.assertEqual(len(rows),6)
+ rows = [r[0] for r in rows]
+ rows.sort()
+
+ # Make sure we get the right data back out
+ for i in range(0,6):
+ self.assertEqual(rows[i],self.samples[i],
+ 'incorrect data retrieved by cursor.fetchmany'
+ )
+
+ rows = cur.fetchmany() # Should return an empty list
+ self.assertEqual(len(rows),0,
+ 'cursor.fetchmany should return an empty sequence if '
+ 'called after the whole result set has been fetched'
+ )
+ self.assertTrue(cur.rowcount in (-1,6))
+
+ self.executeDDL2(cur)
+ cur.execute('select name from %sbarflys' % self.table_prefix)
+ r = cur.fetchmany() # Should get empty sequence
+ self.assertEqual(len(r),0,
+ 'cursor.fetchmany should return an empty sequence if '
+ 'query retrieved no rows'
+ )
+ self.assertTrue(cur.rowcount in (-1,0))
+
+ finally:
+ con.close()
+
+ def test_fetchall(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ # cursor.fetchall should raise an Error if called
+ # without executing a query that may return rows (such
+ # as a select)
+ self.assertRaises(self.driver.Error, cur.fetchall)
+
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ # cursor.fetchall should raise an Error if called
+ # after executing a statement that cannot return rows
+ self.assertRaises(self.driver.Error,cur.fetchall)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows = cur.fetchall()
+ self.assertTrue(cur.rowcount in (-1,len(self.samples)))
+ self.assertEqual(len(rows),len(self.samples),
+ 'cursor.fetchall did not retrieve all rows'
+ )
+ rows = [r[0] for r in rows]
+ rows.sort()
+ for i in range(0,len(self.samples)):
+ self.assertEqual(rows[i],self.samples[i],
+ 'cursor.fetchall retrieved incorrect rows'
+ )
+ rows = cur.fetchall()
+ self.assertEqual(
+ len(rows),0,
+ 'cursor.fetchall should return an empty list if called '
+ 'after the whole result set has been fetched'
+ )
+ self.assertTrue(cur.rowcount in (-1,len(self.samples)))
+
+ self.executeDDL2(cur)
+ cur.execute('select name from %sbarflys' % self.table_prefix)
+ rows = cur.fetchall()
+ self.assertTrue(cur.rowcount in (-1,0))
+ self.assertEqual(len(rows),0,
+ 'cursor.fetchall should return an empty list if '
+ 'a select query returns no rows'
+ )
+ finally:
+ con.close()
+
+ def test_mixedfetch(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows1 = cur.fetchone()
+ rows23 = cur.fetchmany(2)
+ rows4 = cur.fetchone()
+ rows56 = cur.fetchall()
+ self.assertTrue(cur.rowcount in (-1,6))
+ self.assertEqual(len(rows23),2,
+ 'fetchmany returned incorrect number of rows'
+ )
+ self.assertEqual(len(rows56),2,
+ 'fetchall returned incorrect number of rows'
+ )
+
+ rows = [rows1[0]]
+ rows.extend([rows23[0][0],rows23[1][0]])
+ rows.append(rows4[0])
+ rows.extend([rows56[0][0],rows56[1][0]])
+ rows.sort()
+ for i in range(0,len(self.samples)):
+ self.assertEqual(rows[i],self.samples[i],
+ 'incorrect data retrieved or inserted'
+ )
+ finally:
+ con.close()
+
+ def help_nextset_setUp(self,cur):
+ '''
+ Issue the queries that give the cursor multiple result sets
+ for the nextset() test: the names in booze and the row count.
+ '''
+ cur.execute('select name from ' + self.booze_name)
+ cur.execute('select count(*) from ' + self.booze_name)
+
+ def help_nextset_tearDown(self,cur):
+ 'If cleaning up is needed after nextSetTest'
+ pass
+
+ def test_nextset(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ if not hasattr(cur,'nextset'):
+ return
+
+ try:
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ self.help_nextset_setUp(cur)
+
+ numberofrows=cur.fetchone()
+ assert numberofrows[0]== len(self.samples)
+ assert cur.nextset()
+ names=cur.fetchall()
+ assert len(names) == len(self.samples)
+ s=cur.nextset()
+ assert s == None,'No more return sets, should return None'
+ finally:
+ self.help_nextset_tearDown(cur)
+ finally:
+ con.close()
+
+ def test_arraysize(self):
+ # Not much here - rest of the tests for this are in test_fetchmany
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.assertTrue(hasattr(cur,'arraysize'),
+ 'cursor.arraysize must be defined'
+ )
+ finally:
+ con.close()
+
+ def test_setinputsizes(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ cur.setinputsizes( (25,) )
+ self._paraminsert(cur) # Make sure cursor still works
+ finally:
+ con.close()
+
+ def test_setoutputsize_basic(self):
+ # Basic test is to make sure setoutputsize doesn't blow up
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ cur.setoutputsize(1000)
+ cur.setoutputsize(2000,0)
+ self._paraminsert(cur) # Make sure the cursor still works
+ finally:
+ con.close()
+
+ def test_setoutputsize(self):
+ # Real test for setoutputsize is driver dependent
+ pass
+
+ def test_autocommit(self):
+ con = self._connect()
+ con2 = self._connect()
+ try:
+ con.autocommit = True
+ # autocommit mode on; commit/rollback are inappropriate and should raise.
+ self.assertRaises( + con.InterfaceError, + con.commit + ) + self.assertRaises( + con.InterfaceError, + con.rollback + ) + c = con.cursor() + c.execute("create table some_committed_table(i int)") + # if this fails, autocommit had no effect on `con` + con2.cursor().execute("drop table some_committed_table") + con2.commit() + finally: + con.close() + con2.close() + + def test_None(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + cur.execute('insert into %sbooze values (NULL)' % self.table_prefix) + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchall() + self.assertEqual(len(r),1) + self.assertEqual(len(r[0]),1) + self.assertEqual(r[0][0],None,'NULL value not returned as None') + finally: + con.close() + + def test_Date(self): + d1 = self.driver.Date(2002,12,25) + d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0))) + # Can we assume this? API doesn't specify, but it seems implied + self.assertEqual(str(d1),str(d2)) + + def test_Time(self): + t1 = self.driver.Time(13,45,30) + t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0))) + # Can we assume this? API doesn't specify, but it seems implied + self.assertEqual(str(t1),str(t2)) + + def test_Timestamp(self): + t1 = self.driver.Timestamp(2002,12,25,13,45,30) + t2 = self.driver.TimestampFromTicks( + time.mktime((2002,12,25,13,45,30,0,0,0)) + ) + # Can we assume this? API doesn't specify, but it seems implied + #self.assertEqual(str(t1),str(t2)) + + def test_Binary(self): + b = self.driver.Binary(b'Something') + b = self.driver.Binary(b'') + + def test_STRING(self): + self.assertTrue(hasattr(self.driver,'STRING'), + 'module.STRING must be defined' + ) + + def test_BINARY(self): + self.assertTrue(hasattr(self.driver,'BINARY'), + 'module.BINARY must be defined.' + ) + + def test_NUMBER(self): + self.assertTrue(hasattr(self.driver,'NUMBER'), + 'module.NUMBER must be defined.' + ) + + def test_DATETIME(self): + self.assertTrue(hasattr(self.driver,'DATETIME'), + 'module.DATETIME must be defined.' + ) + + def test_ROWID(self): + self.assertTrue(hasattr(self.driver,'ROWID'), + 'module.ROWID must be defined.' + ) + + def test_placeholder_escape(self): + con = self._connect() + try: + c = con.cursor() + c.execute("SELECT 100 %% %s", (99,)) + self.assertEqual(1, c.fetchone()[0]) + c.execute("SELECT 100 %% %(foo)s", {'foo': 99}) + self.assertEqual(1, c.fetchone()[0]) + finally: + con.close() + +if __name__ == '__main__': + unittest.main() diff --git a/py_opengauss/test/test_driver.py b/py_opengauss/test/test_driver.py new file mode 100644 index 0000000000000000000000000000000000000000..df413314b0417c0d1065b5f340e56a65102c6410 --- /dev/null +++ b/py_opengauss/test/test_driver.py @@ -0,0 +1,1855 @@ +## +# .test.test_driver +## +import sys +import unittest +import gc +import threading +import time +import datetime +import decimal +import uuid +from itertools import chain, islice +from operator import itemgetter + +from ..python.datetime import FixedOffset, \ + negative_infinity_datetime, infinity_datetime, \ + negative_infinity_date, infinity_date +from .. import types as pg_types +from ..types.io.stdlib_xml_etree import etree +from .. 
import exceptions as pg_exc +from ..types.bitwise import Bit, Varbit +from ..temporal import pg_tmp + +type_samples = [ + ('smallint', ( + ((1 << 16) // 2) - 1, - ((1 << 16) // 2), + -1, 0, 1, + ), + ), + ('int', ( + ((1 << 32) // 2) - 1, - ((1 << 32) // 2), + -1, 0, 1, + ), + ), + ('bigint', ( + ((1 << 64) // 2) - 1, - ((1 << 64) // 2), + -1, 0, 1, + ), + ), + ('numeric', ( + -(2**64), + 2**64, + -(2**128), + 2**128, + -1, 0, 1, + decimal.Decimal("0.00000000000000"), + decimal.Decimal("1.00000000000000"), + decimal.Decimal("-1.00000000000000"), + decimal.Decimal("-2.00000000000000"), + decimal.Decimal("1000000000000000.00000000000000"), + decimal.Decimal("-0.00000000000000"), + decimal.Decimal(1234), + decimal.Decimal(-1234), + decimal.Decimal("1234000000.00088883231"), + decimal.Decimal(str(1234.00088883231)), + decimal.Decimal("3123.23111"), + decimal.Decimal("-3123000000.23111"), + decimal.Decimal("3123.2311100000"), + decimal.Decimal("-03123.0023111"), + decimal.Decimal("3123.23111"), + decimal.Decimal("3123.23111"), + decimal.Decimal("10000.23111"), + decimal.Decimal("100000.23111"), + decimal.Decimal("1000000.23111"), + decimal.Decimal("10000000.23111"), + decimal.Decimal("100000000.23111"), + decimal.Decimal("1000000000.23111"), + decimal.Decimal("1000000000.3111"), + decimal.Decimal("1000000000.111"), + decimal.Decimal("1000000000.11"), + decimal.Decimal("100000000.0"), + decimal.Decimal("10000000.0"), + decimal.Decimal("1000000.0"), + decimal.Decimal("100000.0"), + decimal.Decimal("10000.0"), + decimal.Decimal("1000.0"), + decimal.Decimal("100.0"), + decimal.Decimal("100"), + decimal.Decimal("100.1"), + decimal.Decimal("100.12"), + decimal.Decimal("100.123"), + decimal.Decimal("100.1234"), + decimal.Decimal("100.12345"), + decimal.Decimal("100.123456"), + decimal.Decimal("100.1234567"), + decimal.Decimal("100.12345679"), + decimal.Decimal("100.123456790"), + decimal.Decimal("100.123456790000000000000000"), + decimal.Decimal("1.0"), + decimal.Decimal("0.0"), + decimal.Decimal("-1.0"), + decimal.Decimal("1.0E-1000"), + decimal.Decimal("1.0E1000"), + decimal.Decimal("1.0E10000"), + decimal.Decimal("1.0E-10000"), + decimal.Decimal("1.0E15000"), + decimal.Decimal("1.0E-15000"), + decimal.Decimal("1.0E-16382"), + decimal.Decimal("1.0E32767"), + decimal.Decimal("0.000000000000000000000000001"), + decimal.Decimal("0.000000000000010000000000001"), + decimal.Decimal("0.00000000000000000000000001"), + decimal.Decimal("0.00000000100000000000000001"), + decimal.Decimal("0.0000000000000000000000001"), + decimal.Decimal("0.000000000000000000000001"), + decimal.Decimal("0.00000000000000000000001"), + decimal.Decimal("0.0000000000000000000001"), + decimal.Decimal("0.000000000000000000001"), + decimal.Decimal("0.00000000000000000001"), + decimal.Decimal("0.0000000000000000001"), + decimal.Decimal("0.000000000000000001"), + decimal.Decimal("0.00000000000000001"), + decimal.Decimal("0.0000000000000001"), + decimal.Decimal("0.000000000000001"), + decimal.Decimal("0.00000000000001"), + decimal.Decimal("0.0000000000001"), + decimal.Decimal("0.000000000001"), + decimal.Decimal("0.00000000001"), + decimal.Decimal("0.0000000001"), + decimal.Decimal("0.000000001"), + decimal.Decimal("0.00000001"), + decimal.Decimal("0.0000001"), + decimal.Decimal("0.000001"), + decimal.Decimal("0.00001"), + decimal.Decimal("0.0001"), + decimal.Decimal("0.001"), + decimal.Decimal("0.01"), + decimal.Decimal("0.1"), + # these require some weight transfer + ), + ), + ('bytea', ( + bytes(range(256)), + bytes(range(255, -1, 
-1)), + b'\x00\x00', + b'foo', + ), + ), + ('smallint[]', ( + [123,321,-123,-321], + [], + ), + ), + ('int[]', [ + [123,321,-123,-321], + [[1],[2]], + [], + ], + ), + ('bigint[]', [ + [ + 0, + 1, + -1, + 0xFFFFFFFFFFFF, + -0xFFFFFFFFFFFF, + ((1 << 64) // 2) - 1, + - ((1 << 64) // 2), + ], + [], + ], + ), + ('varchar[]', [ + ["foo", "bar",], + ["foo", "bar",], + [], + ], + ), + ('timestamp', [ + datetime.datetime(3000,5,20,5,30,10), + datetime.datetime(2000,1,1,5,25,10), + datetime.datetime(500,1,1,5,25,10), + datetime.datetime(250,1,1,5,25,10), + infinity_datetime, + negative_infinity_datetime, + ], + ), + ('date', [ + datetime.date(3000,5,20), + datetime.date(2000,1,1), + datetime.date(500,1,1), + datetime.date(1,1,1), + ], + ), + ('time', [ + datetime.time(12,15,20), + datetime.time(0,1,1), + datetime.time(23,59,59), + ], + ), + ('timestamptz', [ + # It's converted to UTC. When it comes back out, it will be in UTC + # again. The datetime comparison will take the tzinfo into account. + datetime.datetime(1990,5,12,10,10,0, tzinfo=FixedOffset(4000)), + datetime.datetime(1982,5,18,10,10,0, tzinfo=FixedOffset(6000)), + datetime.datetime(1950,1,1,10,10,0, tzinfo=FixedOffset(7000)), + datetime.datetime(1800,1,1,10,10,0, tzinfo=FixedOffset(2000)), + datetime.datetime(2400,1,1,10,10,0, tzinfo=FixedOffset(2000)), + infinity_datetime, + negative_infinity_datetime, + ], + ), + ('timetz', [ + # timetz retains the offset + datetime.time(10,10,0, tzinfo=FixedOffset(4000)), + datetime.time(10,10,0, tzinfo=FixedOffset(6000)), + datetime.time(10,10,0, tzinfo=FixedOffset(7000)), + datetime.time(10,10,0, tzinfo=FixedOffset(2000)), + datetime.time(22,30,0, tzinfo=FixedOffset(0)), + ], + ), + ('interval', [ + # no months :( + datetime.timedelta(40, 10, 1234), + datetime.timedelta(0, 0, 4321), + datetime.timedelta(0, 0), + datetime.timedelta(-100, 0), + datetime.timedelta(-100, -400), + ], + ), + ('point', [ + (10, 1234), + (-1, -1), + (0, 0), + (1, 1), + (-100, 0), + (-100, -400), + (-100.02314, -400.930425), + (0xFFFF, 1.3124243), + ], + ), + ('lseg', [ + ((0,0),(0,0)), + ((10,5),(18,293)), + ((55,5),(10,293)), + ((-1,-1),(-1,-1)), + ((-100,0.00231),(50,45.42132)), + ((0.123,0.00231),(50,45.42132)), + ], + ), + ('circle', [ + ((0,0),0), + ((0,0),1), + ((0,0),1.0011), + ((1,1),1.0011), + ((-1,-1),1.0011), + ((1,-1),1.0011), + ((-1,1),1.0011), + ], + ), + ('box', [ + ((0,0),(0,0)), + ((-1,-1),(-1,-1)), + ((1,1),(-1,-1)), + ((10,1),(-1,-1)), + ((100.2312,45.1232),(-123.023,-1423.82342)), + ], + ), + ('bit', [ + Bit('1'), + Bit('0'), + None, + ], + ), + ('varbit', [ + Varbit('1'), + Varbit('0'), + Varbit('10'), + Varbit('11'), + Varbit('00'), + Varbit('001'), + Varbit('101'), + Varbit('111'), + Varbit('0010'), + Varbit('1010'), + Varbit('1010'), + Varbit('01010101011111011010110101010101111'), + Varbit('010111101111'), + ], + ), + ('macaddr[]', [ + ['00:00:00:00:00:00', 'ff:ff:ff:ff:ff:ff'], + ['00:00:00:00:00:01', '00:00:00:00:00:00', 'ff:ff:ff:ff:ff:ff'], + ['00:00:00:00:00:01', '00:00:00:00:00:00', 'ff:ff:ff:ff:ff:ff', '10:00:00:00:00:00'], + ], + ), + ('jsonb', [ + '{"foo": "bar", "spam": ["ham"]}' + ]) +] + +try: + import ipaddress + + type_samples.extend([ + ('inet', [ + ipaddress.IPv4Address('255.255.255.255'), + ipaddress.IPv4Address('127.0.0.1'), + ipaddress.IPv4Address('10.0.0.1'), + ipaddress.IPv4Address('0.0.0.0'), + ipaddress.IPv6Address('::1'), + ipaddress.IPv6Address('ffff' + ':ffff'*7), + ipaddress.IPv6Address('fe80::1'), + ipaddress.IPv6Address('fe80::1'), + ipaddress.IPv6Address('::'), # 0::0 + 
], + ), + ('cidr', [ + ipaddress.IPv4Network('255.255.255.255/32'), + ipaddress.IPv4Network('127.0.0.0/8'), + ipaddress.IPv4Network('127.1.0.0/16'), + ipaddress.IPv4Network('10.0.0.0/32'), + ipaddress.IPv4Network('0.0.0.0/0'), + ipaddress.IPv6Network('ffff' + ':ffff'*7 + '/128'), + ipaddress.IPv6Network('::1/128'), + ipaddress.IPv6Network('fe80::1/128'), + ipaddress.IPv6Network('fe80::/64'), + ipaddress.IPv6Network('fe80::/16'), + ipaddress.IPv6Network('::/0'), + ], + ), + ('inet[]', [ + [ipaddress.IPv4Address('127.0.0.1'), ipaddress.IPv6Address('::1')], + [ipaddress.IPv4Address('10.0.0.1'), ipaddress.IPv6Address('fe80::1')], + ], + ), + ('cidr[]', [ + [ipaddress.IPv4Network('127.0.0.0/8'), ipaddress.IPv6Network('::/0')], + [ipaddress.IPv4Network('10.0.0.0/16'), ipaddress.IPv6Network('fe80::/64')], + [ipaddress.IPv4Network('10.102.0.0/16'), ipaddress.IPv6Network('fe80::/64')], + ], + ), + ]) +except ImportError: + pass + +class test_driver(unittest.TestCase): + @pg_tmp + def testInterrupt(self): + def pg_sleep(l): + try: + db.execute("SELECT pg_sleep(5)") + except Exception: + l.append(sys.exc_info()) + else: + l.append(None) + return + rl = [] + t = threading.Thread(target = pg_sleep, args = (rl,)) + t.start() + time.sleep(0.2) + while t.is_alive(): + db.interrupt() + time.sleep(0.1) + + def raise_exc(l): + if l[0] is not None: + e, v, tb = rl[0] + raise v + self.assertRaises(pg_exc.QueryCanceledError, raise_exc, rl) + + @pg_tmp + def testClones(self): + db.execute('create table _can_clone_see_this (i int);') + try: + with db.clone() as db2: + self.assertEqual(db2.prepare('select 1').first(), 1) + self.assertEqual(db2.prepare( + "select count(*) FROM information_schema.tables " \ + "where table_name = '_can_clone_see_this'" + ).first(), 1 + ) + finally: + db.execute('drop table _can_clone_see_this') + + # check already open + db3 = db.clone() + try: + self.assertEqual(db3.prepare('select 1').first(), 1) + finally: + db3.close() + + ps = db.prepare('select 1') + ps2 = ps.clone() + self.assertEqual(ps2.first(), ps.first()) + ps2.close() + c = ps.declare() + c2 = c.clone() + self.assertEqual(c.read(), c2.read()) + + @pg_tmp + def testItsClosed(self): + ps = db.prepare("SELECT 1") + # If scroll is False it will pre-fetch, and no error will be thrown. + c = ps.declare() + # + c.close() + self.assertRaises(pg_exc.CursorNameError, c.read) + self.assertEqual(ps.first(), 1) + # + ps.close() + self.assertRaises(pg_exc.StatementNameError, ps.first) + # + db.close() + self.assertRaises( + pg_exc.ConnectionDoesNotExistError, + db.execute, "foo" + ) + # No errors, it's already closed. + ps.close() + c.close() + db.close() + + @pg_tmp + def testGarbage(self): + ps = db.prepare('select 1') + sid = ps.statement_id + ci = ps.chunks() + ci_id = ci.cursor_id + c = ps.declare() + cid = c.cursor_id + # make sure there are no remaining xact references.. 
+ db._pq_complete() + # ci and c both hold references to ps, so they must + # be removed before we can observe the effects __del__ + del c + gc.collect() + self.assertTrue(db.typio.encode(cid) in db.pq.garbage_cursors) + del ci + gc.collect() + self.assertTrue(db.typio.encode(ci_id) in db.pq.garbage_cursors) + del ps + gc.collect() + self.assertTrue(db.typio.encode(sid) in db.pq.garbage_statements) + + @pg_tmp + def testStatementCall(self): + ps = db.prepare("SELECT 1") + r = ps() + self.assertTrue(isinstance(r, list)) + self.assertEqual(ps(), [(1,)]) + ps = db.prepare("SELECT 1, 2") + self.assertEqual(ps(), [(1,2)]) + ps = db.prepare("SELECT 1, 2 UNION ALL SELECT 3, 4") + self.assertEqual(ps(), [(1,2),(3,4)]) + + @pg_tmp + def testStatementFirstDML(self): + cmd = prepare("CREATE TEMP TABLE first (i int)").first() + self.assertEqual(cmd, 'CREATE TABLE') + fins = db.prepare("INSERT INTO first VALUES (123)").first + fupd = db.prepare("UPDATE first SET i = 321 WHERE i = 123").first + fdel = db.prepare("DELETE FROM first").first + self.assertEqual(fins(), 1) + self.assertEqual(fdel(), 1) + self.assertEqual(fins(), 1) + self.assertEqual(fupd(), 1) + self.assertEqual(fins(), 1) + self.assertEqual(fins(), 1) + self.assertEqual(fupd(), 2) + self.assertEqual(fdel(), 3) + self.assertEqual(fdel(), 0) + + @pg_tmp + def testStatementRowsPersistence(self): + # validate that rows' cursor will persist beyond a transaction. + ps = db.prepare("SELECT i FROM generate_series($1::int, $2::int) AS g(i)") + # create the iterator inside the transaction + rows = ps.rows(0, 10000-1) + ps(0,1) + # validate the first half. + self.assertEqual( + list(islice(map(itemgetter(0), rows), 5000)), + list(range(5000)) + ) + ps(0,1) + # and the second half. + self.assertEqual( + list(map(itemgetter(0), rows)), + list(range(5000, 10000)) + ) + + @pg_tmp + def testStatementParameters(self): + # too few and takes one + ps = db.prepare("select $1::integer") + self.assertRaises(TypeError, ps) + + # too many and takes one + self.assertRaises(TypeError, ps, 1, 2) + + # too many and takes none + ps = db.prepare("select 1") + self.assertRaises(TypeError, ps, 1) + + # too many and takes some + ps = db.prepare("select $1::int, $2::text") + self.assertRaises(TypeError, ps, 1, "foo", "bar") + + @pg_tmp + def testStatementAndCursorMetadata(self): + ps = db.prepare("SELECT $1::integer AS my_int_column") + self.assertEqual(tuple(ps.column_names), ('my_int_column',)) + self.assertEqual(tuple(ps.sql_column_types), ('INTEGER',)) + self.assertEqual(tuple(ps.sql_parameter_types), ('INTEGER',)) + self.assertEqual(tuple(ps.pg_parameter_types), (pg_types.INT4OID,)) + self.assertEqual(tuple(ps.parameter_types), (int,)) + self.assertEqual(tuple(ps.column_types), (int,)) + c = ps.declare(15) + self.assertEqual(tuple(c.column_names), ('my_int_column',)) + self.assertEqual(tuple(c.sql_column_types), ('INTEGER',)) + self.assertEqual(tuple(c.column_types), (int,)) + + ps = db.prepare("SELECT $1::text AS my_text_column") + self.assertEqual(tuple(ps.column_names), ('my_text_column',)) + self.assertEqual(tuple(ps.sql_column_types), ('pg_catalog.text',)) + self.assertEqual(tuple(ps.sql_parameter_types), ('pg_catalog.text',)) + self.assertEqual(tuple(ps.pg_parameter_types), (pg_types.TEXTOID,)) + self.assertEqual(tuple(ps.column_types), (str,)) + self.assertEqual(tuple(ps.parameter_types), (str,)) + c = ps.declare('textdata') + self.assertEqual(tuple(c.column_names), ('my_text_column',)) + self.assertEqual(tuple(c.sql_column_types), ('pg_catalog.text',)) + 
self.assertEqual(tuple(c.pg_column_types), (pg_types.TEXTOID,)) + self.assertEqual(tuple(c.column_types), (str,)) + + ps = db.prepare("SELECT $1::text AS my_column1, $2::varchar AS my_column2") + self.assertEqual(tuple(ps.column_names), ('my_column1','my_column2')) + self.assertEqual(tuple(ps.sql_column_types), ('pg_catalog.text', 'CHARACTER VARYING')) + self.assertEqual(tuple(ps.sql_parameter_types), ('pg_catalog.text', 'CHARACTER VARYING')) + self.assertEqual(tuple(ps.pg_parameter_types), (pg_types.TEXTOID, pg_types.VARCHAROID)) + self.assertEqual(tuple(ps.pg_column_types), (pg_types.TEXTOID, pg_types.VARCHAROID)) + self.assertEqual(tuple(ps.parameter_types), (str,str)) + self.assertEqual(tuple(ps.column_types), (str,str)) + c = ps.declare('textdata', 'varchardata') + self.assertEqual(tuple(c.column_names), ('my_column1','my_column2')) + self.assertEqual(tuple(c.sql_column_types), ('pg_catalog.text', 'CHARACTER VARYING')) + self.assertEqual(tuple(c.pg_column_types), (pg_types.TEXTOID, pg_types.VARCHAROID)) + self.assertEqual(tuple(c.column_types), (str,str)) + + # Should be pg_temp or sandbox. + schema = db.settings['search_path'].split(',')[0] + typpath = '"%s"."myudt"' %(schema,) + + db.execute("CREATE TYPE myudt AS (i int)") + myudt_oid = db.prepare("select oid from pg_type WHERE typname='myudt'").first() + ps = db.prepare("SELECT $1::text AS my_column1, $2::varchar AS my_column2, $3::myudt AS my_column3") + self.assertEqual(tuple(ps.column_names), ('my_column1','my_column2', 'my_column3')) + self.assertEqual(tuple(ps.sql_column_types), ('pg_catalog.text', 'CHARACTER VARYING', typpath)) + self.assertEqual(tuple(ps.sql_parameter_types), ('pg_catalog.text', 'CHARACTER VARYING', typpath)) + self.assertEqual(tuple(ps.pg_column_types), ( + pg_types.TEXTOID, pg_types.VARCHAROID, myudt_oid) + ) + self.assertEqual(tuple(ps.pg_parameter_types), ( + pg_types.TEXTOID, pg_types.VARCHAROID, myudt_oid) + ) + self.assertEqual(tuple(ps.parameter_types), (str,str,tuple)) + self.assertEqual(tuple(ps.column_types), (str,str,tuple)) + c = ps.declare('textdata', 'varchardata', (123,)) + self.assertEqual(tuple(c.column_names), ('my_column1','my_column2', 'my_column3')) + self.assertEqual(tuple(c.sql_column_types), ('pg_catalog.text', 'CHARACTER VARYING', typpath)) + self.assertEqual(tuple(c.pg_column_types), ( + pg_types.TEXTOID, pg_types.VARCHAROID, myudt_oid + )) + self.assertEqual(tuple(c.column_types), (str,str,tuple)) + + @pg_tmp + def testRowInterface(self): + data = (1, '0', decimal.Decimal('0.00'), datetime.datetime(1982,5,18,12,30,0)) + ps = db.prepare( + "SELECT 1::int2 AS col0, " \ + "'0'::text AS col1, 0.00::numeric as col2, " \ + "'1982-05-18 12:30:00'::timestamp as col3;" + ) + row = ps.first() + self.assertEqual(tuple(row), data) + + self.assertTrue(1 in row) + self.assertTrue('0' in row) + self.assertTrue(decimal.Decimal('0.00') in row) + self.assertTrue(datetime.datetime(1982,5,18,12,30,0) in row) + + self.assertEqual( + tuple(row.column_names), + tuple(['col' + str(i) for i in range(4)]) + ) + self.assertEqual( + (row["col0"], row["col1"], row["col2"], row["col3"]), + (row[0], row[1], row[2], row[3]), + ) + self.assertEqual( + (row["col0"], row["col1"], row["col2"], row["col3"]), + (row[0], row[1], row[2], row[3]), + ) + keys = list(row.keys()) + cnames = list(ps.column_names) + cnames.sort() + keys.sort() + self.assertEqual(keys, cnames) + self.assertEqual(list(row.values()), list(data)) + self.assertEqual(list(row.items()), list(zip(ps.column_names, data))) + + row_d = dict(row) + for x 
in ps.column_names: + self.assertEqual(row_d[x], row[x]) + for x in row_d.keys(): + self.assertEqual(row.get(x), row_d[x]) + + row_t = tuple(row) + self.assertEqual(row_t, row) + + # transform + crow = row.transform(col0 = str) + self.assertEqual(type(crow[0]), str) + crow = row.transform(str) + self.assertEqual(type(crow[0]), str) + crow = row.transform(str, int) + self.assertEqual(type(crow[0]), str) + self.assertEqual(type(crow[1]), int) + # None = no transformation + crow = row.transform(None, int) + self.assertEqual(type(crow[0]), int) + self.assertEqual(type(crow[1]), int) + # and a combination + crow = row.transform(str, col1 = int, col3 = str) + self.assertEqual(type(crow[0]), str) + self.assertEqual(type(crow[1]), int) + self.assertEqual(type(crow[3]), str) + + for i in range(4): + self.assertEqual(i, row.index_from_key('col' + str(i))) + self.assertEqual('col' + str(i), row.key_from_index(i)) + + def column_test(self): + g_i = db.prepare('SELECT i FROM generate_series(1,10) as g(i)').column + # ignore the second column. + g_ii = db.prepare('SELECT i, i+10 as i2 FROM generate_series(1,10) as g(i)').column + self.assertEqual(tuple(g_i()), tuple(g_ii())) + self.assertEqual(tuple(g_i()), (1,2,3,4,5,6,7,8,9,10)) + + @pg_tmp + def testColumn(self): + self.column_test() + + @pg_tmp + def testColumnInXact(self): + with db.xact(): + self.column_test() + + @pg_tmp + def testStatementFromId(self): + db.execute("PREPARE foo AS SELECT 1 AS colname;") + ps = db.statement_from_id('foo') + self.assertEqual(ps.first(), 1) + self.assertEqual(ps(), [(1,)]) + self.assertEqual(list(ps), [(1,)]) + self.assertEqual(tuple(ps.column_names), ('colname',)) + + @pg_tmp + def testCursorFromId(self): + db.execute("DECLARE foo CURSOR WITH HOLD FOR SELECT 1") + c = db.cursor_from_id('foo') + self.assertEqual(c.read(), [(1,)]) + db.execute( + "DECLARE bar SCROLL CURSOR WITH HOLD FOR SELECT i FROM generate_series(0, 99) AS g(i)" + ) + c = db.cursor_from_id('bar') + c.seek(50) + self.assertEqual([x for x, in c.read(10)], list(range(50,60))) + c.seek(0,2) + self.assertEqual(c.read(), []) + c.seek(0) + self.assertEqual([x for x, in c.read()], list(range(100))) + + @pg_tmp + def testCopyToSTDOUT(self): + with db.xact(): + db.execute("CREATE TABLE foo (i int)") + foo = db.prepare('insert into foo values ($1)') + foo.load_rows(((x,) for x in range(500))) + + copy_foo = db.prepare('copy foo to stdout') + foo_content = set(copy_foo) + expected = set((str(i).encode('ascii') + b'\n' for i in range(500))) + self.assertEqual(expected, foo_content) + self.assertEqual(expected, set(copy_foo())) + self.assertEqual(expected, set(chain.from_iterable(copy_foo.chunks()))) + self.assertEqual(expected, set(copy_foo.rows())) + db.execute("DROP TABLE foo") + + @pg_tmp + def testCopyFromSTDIN(self): + with db.xact(): + db.execute("CREATE TABLE foo (i int)") + foo = db.prepare('copy foo from stdin') + foo.load_rows((str(i).encode('ascii') + b'\n' for i in range(200))) + foo_content = list(( + x for (x,) in db.prepare('select * from foo order by 1 ASC') + )) + self.assertEqual(foo_content, list(range(200))) + db.execute("DROP TABLE foo") + + @pg_tmp + def testCopyInvalidTermination(self): + class DontTrapThis(BaseException): + pass + def EvilGenerator(): + raise DontTrapThis() + yield None + sqlexec("CREATE TABLE foo (i int)") + foo = prepare('copy foo from stdin') + try: + foo.load_chunks([EvilGenerator()]) + self.fail("didn't raise the BaseException subclass") + except DontTrapThis: + pass + try: + db._pq_complete() + except Exception: 
+ pass + self.assertEqual(prepare('select 1').first(), 1) + + @pg_tmp + def testLookupProcByName(self): + db.execute( + "CREATE OR REPLACE FUNCTION public.foo() RETURNS INT LANGUAGE SQL AS 'SELECT 1'" + ) + db.settings['search_path'] = 'public' + f = db.proc('foo()') + f2 = db.proc('public.foo()') + self.assertTrue(f.oid == f2.oid, + "function lookup incongruence(%r != %r)" %(f, f2) + ) + + @pg_tmp + def testLookupProcById(self): + gsoid = db.prepare( + "select oid from pg_proc where proname = 'generate_series' limit 1" + ).first() + gs = db.proc(gsoid) + self.assertEqual(list(gs(1, 100)), list(range(1, 101))) + + def execute_proc(self): + ver = db.proc("version()") + ver() + db.execute( + "CREATE OR REPLACE FUNCTION ifoo(int) RETURNS int LANGUAGE SQL AS 'select $1'" + ) + ifoo = db.proc('ifoo(int)') + self.assertEqual(ifoo(1), 1) + self.assertEqual(ifoo(None), None) + db.execute( + "CREATE OR REPLACE FUNCTION ifoo(varchar) RETURNS text LANGUAGE SQL AS 'select $1'" + ) + ifoo = db.proc('ifoo(varchar)') + self.assertEqual(ifoo('1'), '1') + self.assertEqual(ifoo(None), None) + db.execute( + "CREATE OR REPLACE FUNCTION ifoo(varchar,int) RETURNS text LANGUAGE SQL AS 'select ($1::int + $2)::varchar'" + ) + ifoo = db.proc('ifoo(varchar,int)') + self.assertEqual(ifoo('1',1), '2') + self.assertEqual(ifoo(None,1), None) + self.assertEqual(ifoo('1',None), None) + self.assertEqual(ifoo('2',2), '4') + + @pg_tmp + def testProcExecution(self): + self.execute_proc() + + @pg_tmp + def testProcExecutionInXact(self): + with db.xact(): + self.execute_proc() + + @pg_tmp + def testProcExecutionInSubXact(self): + with db.xact(), db.xact(): + self.execute_proc() + + @pg_tmp + def testNULL(self): + # Directly commpare (SELECT NULL) is None + self.assertTrue( + db.prepare("SELECT NULL")()[0][0] is None, + "SELECT NULL did not return None" + ) + # Indirectly compare (select NULL) is None + self.assertTrue( + db.prepare("SELECT $1::text")(None)[0][0] is None, + "[SELECT $1::text](None) did not return None" + ) + + @pg_tmp + def testBool(self): + fst, snd = db.prepare("SELECT true, false").first() + self.assertTrue(fst is True) + self.assertTrue(snd is False) + + def select(self): + #self.assertEqual( + # db.prepare('')().command(), + # None, + # 'Empty statement has command?' + #) + # Test SELECT 1. + s1 = db.prepare("SELECT 1 as name") + p = s1() + tup = p[0] + self.assertTrue(tup[0] == 1) + + for tup in s1: + self.assertEqual(tup[0], 1) + + for tup in s1: + self.assertEqual(tup["name"], 1) + + @pg_tmp + def testSelect(self): + self.select() + + @pg_tmp + def testSelectInXact(self): + with db.xact(): + self.select() + + @pg_tmp + def testTransactionAlias(self): + self.assertEqual(db.transaction, db.xact) + + try: + with db.transaction(): + db.execute("CREATE TABLE t (i int);") + raise Exception('some failure') + except: + pass + else: + self.fail("expected exception was not raised") + + try: + db.query("select * from t") + except: + # No table. 
+ pass + else: + self.fail("transaction abort had no effect") + + def cursor_read(self): + ps = db.prepare("SELECT i FROM generate_series(0, (2^8)::int - 1) AS g(i)") + c = ps.declare() + self.assertEqual(c.read(0), []) + self.assertEqual(c.read(0), []) + self.assertEqual(c.read(1), [(0,)]) + self.assertEqual(c.read(1), [(1,)]) + self.assertEqual(c.read(2), [(2,), (3,)]) + self.assertEqual(c.read(2), [(4,), (5,)]) + self.assertEqual(c.read(3), [(6,), (7,), (8,)]) + self.assertEqual(c.read(4), [(9,), (10,), (11,), (12,)]) + self.assertEqual(c.read(4), [(13,), (14,), (15,), (16,)]) + self.assertEqual(c.read(5), [(17,), (18,), (19,), (20,), (21,)]) + self.assertEqual(c.read(0), []) + self.assertEqual(c.read(6), [(22,),(23,),(24,),(25,),(26,),(27,)]) + r = [-1] + i = 4 + v = 28 + maxv = 2**8 + while r: + i = i * 2 + r = [x for x, in c.read(i)] + top = min(maxv, v + i) + self.assertEqual(r, list(range(v, top))) + v = top + + @pg_tmp + def testCursorRead(self): + self.cursor_read() + + @pg_tmp + def testCursorIter(self): + ps = db.prepare("SELECT i FROM generate_series(0, 10) AS g(i)") + c = ps.declare() + self.assertEqual(next(iter(c)), (0,)) + self.assertEqual(next(iter(c)), (1,)) + self.assertEqual(next(iter(c)), (2,)) + + @pg_tmp + def testCursorReadInXact(self): + with db.xact(): + self.cursor_read() + + @pg_tmp + def testScroll(self, direction = True): + # Use a large row-set. + imin = 0 + imax = 2**16 + if direction: + ps = db.prepare("SELECT i FROM generate_series(0, (2^16)::int) AS g(i)") + else: + ps = db.prepare("SELECT i FROM generate_series((2^16)::int, 0, -1) AS g(i)") + c = ps.declare() + c.direction = direction + if not direction: + c.seek(0) + + self.assertEqual([x for x, in c.read(10)], list(range(10))) + # bit strange to me, but i've watched the fetch backwards -jwp 2009 + self.assertEqual([x for x, in c.read(10, 'BACKWARD')], list(range(8, -1, -1))) + c.seek(0, 2) + self.assertEqual([x for x, in c.read(10, 'BACKWARD')], list(range(imax, imax-10, -1))) + + # move to end + c.seek(0, 2) + self.assertEqual([x for x, in c.read(100, 'BACKWARD')], list(range(imax, imax-100, -1))) + # move backwards, relative + c.seek(-100, 1) + self.assertEqual([x for x, in c.read(100, 'BACKWARD')], list(range(imax-200, imax-300, -1))) + + # move abs, again + c.seek(14000) + self.assertEqual([x for x, in c.read(100)], list(range(14000, 14100))) + # move forwards, relative + c.seek(100, 1) + self.assertEqual([x for x, in c.read(100)], list(range(14200, 14300))) + # move abs, again + c.seek(24000) + self.assertEqual([x for x, in c.read(200)], list(range(24000, 24200))) + # move to end and then back some + c.seek(20, 2) + self.assertEqual([x for x, in c.read(200, 'BACKWARD')], list(range(imax-20, imax-20-200, -1))) + + c.seek(0, 2) + c.seek(-10, 1) + r1 = c.read(10) + c.seek(10, 2) + self.assertEqual(r1, c.read(10)) + + @pg_tmp + def testSeek(self): + ps = db.prepare("SELECT i FROM generate_series(0, (2^6)::int - 1) AS g(i)") + c = ps.declare() + + self.assertEqual(c.seek(4, 'FORWARD'), 4) + self.assertEqual([x for x, in c.read(10)], list(range(4, 14))) + + self.assertEqual(c.seek(2, 'BACKWARD'), 2) + self.assertEqual([x for x, in c.read(10)], list(range(12, 22))) + + self.assertEqual(c.seek(-5, 'BACKWARD'), 5) + self.assertEqual([x for x, in c.read(10)], list(range(27, 37))) + + self.assertEqual(c.seek('ALL'), 27) + + def testScrollBackwards(self): + # testScroll again, but backwards this time. 
+ self.testScroll(direction = False) + + @pg_tmp + def testWithHold(self): + with db.xact(): + ps = db.prepare("SELECT 1") + c = ps.declare() + cid = c.cursor_id + self.assertEqual(c.read()[0][0], 1) + # make sure it's not cheating + self.assertEqual(c.cursor_id, cid) + # check grabs beyond the default chunksize. + with db.xact(): + ps = db.prepare("SELECT i FROM generate_series(0, 99) as g(i)") + c = ps.declare() + cid = c.cursor_id + self.assertEqual([x for x, in c.read()], list(range(100))) + # make sure it's not cheating + self.assertEqual(c.cursor_id, cid) + + def load_rows(self): + gs = db.prepare("SELECT i FROM generate_series(1, 10000) AS g(i)") + self.assertEqual( + list((x[0] for x in gs.rows())), + list(range(1, 10001)) + ) + # exercise ``for x in chunks: dst.load_rows(x)`` + with new() as db2: + db2.execute( + """ + CREATE TABLE chunking AS + SELECT i::text AS t, i::int AS i + FROM generate_series(1, 10000) g(i); + """ + ) + read = db.prepare('select * FROM chunking').rows() + write = db2.prepare('insert into chunking values ($1, $2)').load_rows + with db2.xact(): + write(read) + del read, write + + self.assertEqual( + db.prepare('select count(*) FROM chunking').first(), + 20000 + ) + self.assertEqual( + db.prepare('select count(DISTINCT i) FROM chunking').first(), + 10000 + ) + db.execute('DROP TABLE chunking') + + @pg_tmp + def testLoadRows(self): + self.load_rows() + + @pg_tmp + def testLoadRowsInXact(self): + with db.xact(): + self.load_rows() + + def load_chunks(self): + gs = db.prepare("SELECT i FROM generate_series(1, 10000) AS g(i)") + self.assertEqual( + list((x[0] for x in chain.from_iterable(gs.chunks()))), + list(range(1, 10001)) + ) + # exercise ``for x in chunks: dst.load_chunks(x)`` + with new() as db2: + db2.execute( + """ + CREATE TABLE chunking AS + SELECT i::text AS t, i::int AS i + FROM generate_series(1, 10000) g(i); + """ + ) + read = db.prepare('select * FROM chunking').chunks() + write = db2.prepare('insert into chunking values ($1, $2)').load_chunks + with db2.xact(): + write(read) + del read, write + + self.assertEqual( + db.prepare('select count(*) FROM chunking').first(), + 20000 + ) + self.assertEqual( + db.prepare('select count(DISTINCT i) FROM chunking').first(), + 10000 + ) + db.execute('DROP TABLE chunking') + + @pg_tmp + def testLoadChunks(self): + self.load_chunks() + + @pg_tmp + def testLoadChunkInXact(self): + with db.xact(): + self.load_chunks() + + @pg_tmp + def testSimpleDML(self): + db.execute("CREATE TEMP TABLE emp(emp_name text, emp_age int)") + try: + mkemp = db.prepare("INSERT INTO emp VALUES ($1, $2)") + del_all_emp = db.prepare("DELETE FROM emp") + command, count = mkemp('john', 35) + self.assertEqual(command, 'INSERT') + self.assertEqual(count, 1) + command, count = mkemp('jane', 31) + self.assertEqual(command, 'INSERT') + self.assertEqual(count, 1) + command, count = del_all_emp() + self.assertEqual(command, 'DELETE') + self.assertEqual(count, 2) + finally: + db.execute("DROP TABLE emp") + + def dml(self): + db.execute("CREATE TEMP TABLE t(i int)") + try: + insert_t = db.prepare("INSERT INTO t VALUES ($1)") + delete_t = db.prepare("DELETE FROM t WHERE i = $1") + delete_all_t = db.prepare("DELETE FROM t") + update_t = db.prepare("UPDATE t SET i = $2 WHERE i = $1") + self.assertEqual(insert_t(1)[1], 1) + self.assertEqual(delete_t(1)[1], 1) + self.assertEqual(insert_t(2)[1], 1) + self.assertEqual(insert_t(2)[1], 1) + self.assertEqual(delete_t(2)[1], 2) + + self.assertEqual(insert_t(3)[1], 1) + self.assertEqual(insert_t(3)[1], 1) + 
self.assertEqual(insert_t(3)[1], 1) + self.assertEqual(delete_all_t()[1], 3) + + self.assertEqual(update_t(1, 2)[1], 0) + self.assertEqual(insert_t(1)[1], 1) + self.assertEqual(update_t(1, 2)[1], 1) + self.assertEqual(delete_t(1)[1], 0) + self.assertEqual(delete_t(2)[1], 1) + finally: + db.execute("DROP TABLE t") + + @pg_tmp + def testDML(self): + self.dml() + + @pg_tmp + def testDMLInXact(self): + with db.xact(): + self.dml() + + def batch_dml(self): + db.execute("CREATE TEMP TABLE t(i int)") + try: + insert_t = db.prepare("INSERT INTO t VALUES ($1)") + delete_t = db.prepare("DELETE FROM t WHERE i = $1") + delete_all_t = db.prepare("DELETE FROM t") + update_t = db.prepare("UPDATE t SET i = $2 WHERE i = $1") + mset = ( + (2,), (2,), (3,), (4,), (5,), + ) + insert_t.load_rows(mset) + content = db.prepare("SELECT * FROM t ORDER BY 1 ASC") + self.assertEqual(mset, tuple(content())) + finally: + db.execute("DROP TABLE t") + + @pg_tmp + def testBatchDML(self): + self.batch_dml() + + @pg_tmp + def testBatchDMLInXact(self): + with db.xact(): + self.batch_dml() + + @pg_tmp + def testTypes(self): + 'test basic object I/O--input must equal output' + for (typname, sample_data) in type_samples: + pb = db.prepare( + "SELECT $1::" + typname + ) + for sample in sample_data: + rsample = list(pb.rows(sample))[0][0] + if isinstance(rsample, pg_types.Array): + rsample = rsample.nest() + self.assertTrue( + rsample == sample, + "failed to return %s object data as-is; gave %r, received %r" %( + typname, sample, rsample + ) + ) + + @pg_tmp + def testDomainSupport(self): + 'test domain type I/O' + + db.execute('CREATE DOMAIN int_t AS int') + db.execute('CREATE DOMAIN int_t_2 AS int_t') + db.execute('CREATE TYPE tt AS (a int_t, b int_t_2)') + + samples = { + 'int_t': [10], + 'int_t_2': [11], + 'tt': [(12, 13)] + } + + for (typname, sample_data) in samples.items(): + pb = db.prepare( + "SELECT $1::" + typname + ) + for sample in sample_data: + rsample = list(pb.rows(sample))[0][0] + if isinstance(rsample, pg_types.Array): + rsample = rsample.nest() + self.assertTrue( + rsample == sample, + "failed to return %s object data as-is; gave %r, received %r" %( + typname, sample, rsample + ) + ) + + @pg_tmp + def testAnonymousRecord(self): + 'test anonymous record unpacking' + + db.execute('CREATE TYPE tar_t AS (a int, b int)') + + tests = { + "SELECT (1::int, '2'::text, '2012-01-01 18:00 UTC'::timestamptz)": + (1, '2', datetime.datetime(2012, 1, 1, 18, 0, tzinfo=FixedOffset(0))), + + "SELECT (1::int, '2'::text, (3::int, '4'::text))": + (1, '2', (3, '4')), + + "SELECT (i::int, (i + 1, i + 2)::tar_t) FROM generate_series(1, 10) as i": + (1, (2, 3)), + + "SELECT (1::int, ARRAY[(2, 3), (3, 4)])": + (1, pg_types.Array([(2, 3), (3, 4)])) + } + + for qry, expected in tests.items(): + pb = db.prepare(qry) + result = next(iter(pb.rows()))[0] + self.assertEqual(result, expected) + + def check_xml(self): + try: + xml = db.prepare('select $1::xml') + textxml = db.prepare('select $1::text::xml') + r = textxml.first('') + except (pg_exc.FeatureError, pg_exc.UndefinedObjectError): + # XML is not available. 
+ return + foo = etree.XML('') + bar = etree.XML('') + if hasattr(etree, 'tostringlist'): + # 3.2 + def tostr(x): + return etree.tostring(x, encoding='utf-8') + else: + # 3.1 compat + tostr = etree.tostring + self.assertEqual(tostr(xml.first(foo)), tostr(foo)) + self.assertEqual(tostr(xml.first(bar)), tostr(bar)) + self.assertEqual(tostr(textxml.first('')), tostr(foo)) + self.assertEqual(tostr(textxml.first('')), tostr(foo)) + self.assertEqual(tostr(xml.first(etree.XML(''))), tostr(foo)) + self.assertEqual(tostr(textxml.first('')), tostr(foo)) + # test fragments + self.assertEqual( + tuple( + tostr(x) for x in xml.first('') + ), (tostr(foo), tostr(bar)) + ) + self.assertEqual( + tuple( + tostr(x) for x in textxml.first('') + ), + (tostr(foo), tostr(bar)) + ) + # mixed text and etree. + self.assertEqual( + tuple( + tostr(x) for x in xml.first(( + '', bar, + )) + ), + (tostr(foo), tostr(bar)) + ) + self.assertEqual( + tuple( + tostr(x) for x in xml.first(( + '', bar, + )) + ), + (tostr(foo), tostr(bar)) + ) + + @pg_tmp + def testXML(self): + self.check_xml() + + @pg_tmp + def testXML_ascii(self): + # check a non-utf8 encoding (3.2 and up) + db.settings['client_encoding'] = 'sql_ascii' + self.check_xml() + + @pg_tmp + def testXML_utf8(self): + # in 3.2 we always serialize at utf-8, so check that + # that path is being ran by forcing the client_encoding to utf8. + db.settings['client_encoding'] = 'utf8' + self.check_xml() + + @pg_tmp + def testUUID(self): + # doesn't exist in all versions supported by py-postgresql. + has_uuid = db.prepare( + "select true from pg_type where lower(typname) = 'uuid'").first() + if has_uuid: + ps = db.prepare('select $1::uuid').first + x = uuid.uuid1() + self.assertEqual(ps(x), x) + + def _infinity_test(self, typname, inf, neg): + ps = db.prepare('SELECT $1::' + typname).first + val = ps('infinity') + self.assertEqual(val, inf) + val = ps('-infinity') + self.assertEqual(val, neg) + val = ps(inf) + self.assertEqual(val, inf) + val = ps(neg) + self.assertEqual(val, neg) + ps = db.prepare('SELECT $1::' + typname + '::text').first + self.assertEqual(ps('infinity'), 'infinity') + self.assertEqual(ps('-infinity'), '-infinity') + + @pg_tmp + def testInfinity_stdlib_datetime(self): + self._infinity_test("timestamptz", infinity_datetime, negative_infinity_datetime) + self._infinity_test("timestamp", infinity_datetime, negative_infinity_datetime) + + @pg_tmp + def testInfinity_stdlib_date(self): + try: + db.prepare("SELECT 'infinity'::date")() + self._infinity_test('date', infinity_date, negative_infinity_date) + except: + pass + + @pg_tmp + def testTypeIOError(self): + original = dict(db.typio._cache) + ps = db.prepare('SELECT $1::numeric') + self.assertRaises(pg_exc.ParameterError, ps, 'foo') + try: + db.execute('CREATE type test_tuple_error AS (n numeric);') + ps = db.prepare('SELECT $1::test_tuple_error AS the_column') + self.assertRaises(pg_exc.ParameterError, ps, ('foo',)) + try: + ps(('foo',)) + except pg_exc.ParameterError as err: + # 'foo' is not a valid Decimal. + # Expecting a double TupleError here, one from the composite pack + # and one from the row pack. + self.assertTrue(isinstance(err.__cause__, pg_exc.CompositeError)) + self.assertEqual(int(err.details['position']), 0) + # attribute number that the failure occurred on + self.assertEqual(int(err.__cause__.details['position']), 0) + else: + self.fail("failed to raise TupleError") + + # testing tuple error reception is a bit more difficult. 
+ # to do this, we need to immitate failure as we can't rely that any + # causable failure will always exist. + class ThisError(Exception): + pass + def raise_ThisError(arg): + raise ThisError(arg) + pack, unpack, typ = db.typio.resolve(pg_types.NUMERICOID) + # remove any existing knowledge about "test_tuple_error" + db.typio._cache = original + db.typio._cache[pg_types.NUMERICOID] = (pack, raise_ThisError, typ) + # Now, numeric_unpack will always raise "ThisError". + ps = db.prepare('SELECT $1::numeric as col') + self.assertRaises( + pg_exc.ColumnError, ps, decimal.Decimal("101") + ) + try: + ps(decimal.Decimal("101")) + except pg_exc.ColumnError as err: + self.assertTrue(isinstance(err.__cause__, ThisError)) + # might be too inquisitive.... + self.assertEqual(int(err.details['position']), 0) + self.assertTrue('NUMERIC' in err.message) + self.assertTrue('col' in err.message) + else: + self.fail("failed to raise TupleError from reception") + ps = db.prepare('SELECT $1::test_tuple_error AS tte') + try: + ps((decimal.Decimal("101"),)) + except pg_exc.ColumnError as err: + self.assertTrue(isinstance(err.__cause__, pg_exc.CompositeError)) + self.assertTrue(isinstance(err.__cause__.__cause__, ThisError)) + # might be too inquisitive.... + self.assertEqual(int(err.details['position']), 0) + self.assertEqual(int(err.__cause__.details['position']), 0) + self.assertTrue('test_tuple_error' in err.message) + else: + self.fail("failed to raise TupleError from reception") + finally: + db.execute('drop type test_tuple_error;') + + @pg_tmp + def testSyntaxError(self): + try: + db.prepare("SELEKT 1")() + except pg_exc.SyntaxError: + return + self.fail("SyntaxError was not raised") + + @pg_tmp + def testSchemaNameError(self): + try: + db.prepare("CREATE TABLE sdkfldasjfdskljZknvson.foo()")() + except pg_exc.SchemaNameError: + return + self.fail("SchemaNameError was not raised") + + @pg_tmp + def testUndefinedTableError(self): + try: + db.prepare("SELECT * FROM public.lkansdkvsndlvksdvnlsdkvnsdlvk")() + except pg_exc.UndefinedTableError: + return + self.fail("UndefinedTableError was not raised") + + @pg_tmp + def testUndefinedColumnError(self): + try: + db.prepare("SELECT x____ysldvndsnkv FROM information_schema.tables")() + except pg_exc.UndefinedColumnError: + return + self.fail("UndefinedColumnError was not raised") + + @pg_tmp + def testSEARVError_avgInWhere(self): + try: + db.prepare("SELECT 1 WHERE avg(1) = 1")() + except pg_exc.SEARVError: + return + self.fail("SEARVError was not raised") + + @pg_tmp + def testSEARVError_groupByAgg(self): + try: + db.prepare("SELECT 1 GROUP BY avg(1)")() + except pg_exc.SEARVError: + return + self.fail("SEARVError was not raised") + + @pg_tmp + def testTypeMismatchError(self): + try: + db.prepare("SELECT 1 WHERE 1")() + except pg_exc.TypeMismatchError: + return + self.fail("TypeMismatchError was not raised") + + @pg_tmp + def testUndefinedObjectError(self): + try: + self.assertRaises( + pg_exc.UndefinedObjectError, + db.prepare, "CREATE TABLE lksvdnvsdlksnv(i intt___t)" + ) + except: + # newer versions throw the exception on execution + self.assertRaises( + pg_exc.UndefinedObjectError, + db.prepare("CREATE TABLE lksvdnvsdlksnv(i intt___t)") + ) + + @pg_tmp + def testZeroDivisionError(self): + self.assertRaises( + pg_exc.ZeroDivisionError, + db.prepare("SELECT 1/i FROM (select 0 as i) AS g(i)").first, + ) + + @pg_tmp + def testTransactionCommit(self): + with db.xact(): + db.execute("CREATE TEMP TABLE withfoo(i int)") + db.prepare("SELECT * FROM withfoo") + + 
db.execute("DROP TABLE withfoo") + self.assertRaises( + pg_exc.UndefinedTableError, + db.execute, "SELECT * FROM withfoo" + ) + + @pg_tmp + def testTransactionAbort(self): + class SomeError(Exception): + pass + try: + with db.xact(): + db.execute("CREATE TABLE withfoo (i int)") + raise SomeError + except SomeError: + pass + self.assertRaises( + pg_exc.UndefinedTableError, + db.execute, "SELECT * FROM withfoo" + ) + + @pg_tmp + def testSerializeable(self): + with new() as db2: + db2.execute("create table some_darn_table (i int);") + try: + with db.xact(isolation = 'serializable'): + db.execute('insert into some_darn_table values (123);') + # db2 is in autocommit.. + db2.execute('insert into some_darn_table values (321);') + self.assertNotEqual( + list(db.prepare('select * from some_darn_table')), + list(db2.prepare('select * from some_darn_table')), + ) + finally: + # cleanup + db2.execute("drop table some_darn_table;") + + @pg_tmp + def testReadOnly(self): + class something(Exception): + pass + try: + with db.xact(mode = 'read only'): + self.assertRaises( + pg_exc.ReadOnlyTransactionError, + db.execute, + "create table ieeee(i int)" + ) + raise something("yeah, it raised.") + self.fail("should have been passed by exception") + except something: + pass + + @pg_tmp + def testFailedTransactionBlock(self): + try: + with db.xact(): + try: + db.execute("selekt 1;") + except pg_exc.SyntaxError: + pass + self.fail("__exit__ didn't identify failed transaction") + except pg_exc.InFailedTransactionError as err: + self.assertEqual(err.source, 'CLIENT') + + @pg_tmp + def testFailedSubtransactionBlock(self): + with db.xact(): + try: + with db.xact(): + try: + db.execute("selekt 1;") + except pg_exc.SyntaxError: + pass + self.fail("__exit__ didn't identify failed transaction") + except pg_exc.InFailedTransactionError as err: + # driver should have released/aborted instead + self.assertEqual(err.source, 'CLIENT') + + @pg_tmp + def testSuccessfulSubtransactionBlock(self): + with db.xact(): + with db.xact(): + db.execute("create temp table subxact_sx1(i int);") + with db.xact(): + db.execute("create temp table subxact_sx2(i int);") + # And, because I'm paranoid. + # The following block is used to make sure + # that savepoints are actually being set. + try: + with db.xact(): + db.execute("selekt 1") + except pg_exc.SyntaxError: + # Just in case the xact() aren't doing anything. + pass + with db.xact(): + db.execute("create temp table subxact_sx3(i int);") + # if it can't drop these tables, it didn't manage the subxacts + # properly. + db.execute("drop table subxact_sx1") + db.execute("drop table subxact_sx2") + db.execute("drop table subxact_sx3") + + @pg_tmp + def testReleasedSavepoint(self): + # validate that the rolled back savepoint is released as well. 
+ x = None
+ with db.xact():
+ try:
+ with db.xact():
+ try:
+ with db.xact() as x:
+ db.execute("selekt 1")
+ except pg_exc.SyntaxError:
+ db.execute('RELEASE "xact(' + hex(id(x)) + ')"')
+ except pg_exc.InvalidSavepointSpecificationError as e:
+ pass
+ else:
+ self.fail("InvalidSavepointSpecificationError not raised")
+
+ @pg_tmp
+ def testCloseInTransactionBlock(self):
+ try:
+ with db.xact():
+ db.close()
+ self.fail("transaction __exit__ didn't identify cause ConnectionDoesNotExistError")
+ except pg_exc.ConnectionDoesNotExistError:
+ pass
+
+ @pg_tmp
+ def testCloseInSubTransactionBlock(self):
+ try:
+ with db.xact():
+ with db.xact():
+ db.close()
+ self.fail("transaction __exit__ didn't identify cause ConnectionDoesNotExistError")
+ self.fail("transaction __exit__ didn't identify cause ConnectionDoesNotExistError")
+ except pg_exc.ConnectionDoesNotExistError:
+ pass
+
+ @pg_tmp
+ def testSettingsCM(self):
+ orig = db.settings['search_path']
+ with db.settings(search_path='public'):
+ self.assertEqual(db.settings['search_path'], 'public')
+ self.assertEqual(db.settings['search_path'], orig)
+
+ @pg_tmp
+ def testSettingsReset(self):
+ # <3 search_path
+ del db.settings['search_path']
+ cur = db.settings['search_path']
+ db.settings['search_path'] = 'pg_catalog'
+ del db.settings['search_path']
+ self.assertEqual(db.settings['search_path'], cur)
+
+ @pg_tmp
+ def testSettingsCount(self):
+ self.assertEqual(
+ len(db.settings), db.prepare('select count(*) from pg_settings').first()
+ )
+
+ @pg_tmp
+ def testSettingsGet(self):
+ self.assertEqual(
+ db.settings['search_path'], db.settings.get('search_path')
+ )
+ self.assertEqual(None, db.settings.get(' $*0293 vksnd'))
+
+ @pg_tmp
+ def testSettingsGetSet(self):
+ sub = db.settings.getset(
+ ('search_path', 'default_statistics_target')
+ )
+ self.assertEqual(db.settings['search_path'], sub['search_path'])
+ self.assertEqual(db.settings['default_statistics_target'], sub['default_statistics_target'])
+
+ @pg_tmp
+ def testSettings(self):
+ d = dict(db.settings)
+ d = dict(db.settings.items())
+ k = list(db.settings.keys())
+ v = list(db.settings.values())
+ self.assertEqual(len(k), len(d))
+ self.assertEqual(len(k), len(v))
+ for x in k:
+ self.assertTrue(d[x] in v)
+ all = list(db.settings.getset(k).items())
+ all.sort(key=itemgetter(0))
+ dall = list(d.items())
+ dall.sort(key=itemgetter(0))
+ self.assertEqual(dall, all)
+
+ @pg_tmp
+ def testDo(self):
+ # plpgsql is expected to be available.
+ if db.version_info[:2] < (8,5): + return + if 'plpgsql' not in db.sys.languages(): + db.execute("CREATE LANGUAGE plpgsql") + db.do('plpgsql', "BEGIN CREATE TEMP TABLE do_tmp_table(i int, t text); END",) + self.assertEqual(len(db.prepare("SELECT * FROM do_tmp_table")()), 0) + db.do('plpgsql', "BEGIN INSERT INTO do_tmp_table VALUES (100, 'foo'); END") + self.assertEqual(len(db.prepare("SELECT * FROM do_tmp_table")()), 1) + + @pg_tmp + def testListeningChannels(self): + db.listen('foo', 'bar') + self.assertEqual(set(db.listening_channels()), {'foo','bar'}) + db.unlisten('bar') + db.listen('foo', 'bar') + self.assertEqual(set(db.listening_channels()), {'foo','bar'}) + db.unlisten('foo', 'bar') + self.assertEqual(set(db.listening_channels()), set()) + + @pg_tmp + def testNotify(self): + db.listen('foo', 'bar') + db.listen('foo', 'bar') + db.notify('foo') + db.execute('') + self.assertEqual(db._notifies[0].channel, b'foo') + self.assertEqual(db._notifies[0].pid, db.backend_id) + self.assertEqual(db._notifies[0].payload, b'') + del db._notifies[0] + db.notify('bar') + db.execute('') + self.assertEqual(db._notifies[0].channel, b'bar') + self.assertEqual(db._notifies[0].pid, db.backend_id) + self.assertEqual(db._notifies[0].payload, b'') + del db._notifies[0] + db.unlisten('foo') + db.notify('foo') + db.execute('') + self.assertEqual(db._notifies, []) + # Invoke an error to show that listen() is all or none. + self.assertRaises(Exception, db.listen, 'doesntexist', 'x'*64) + self.assertTrue('doesntexist' not in db.listening_channels()) + + @pg_tmp + def testPayloads(self): + if db.version_info[:2] >= (9,0): + db.listen('foo') + db.notify(foo = 'bar') + self.assertEqual(('foo', 'bar', db.backend_id), list(db.iternotifies(0))[0]) + db.notify(('foo', 'barred')) + self.assertEqual(('foo', 'barred', db.backend_id), list(db.iternotifies(0))[0]) + # mixed + db.notify(('foo', 'barred'), 'foo', ('foo', 'bleh'), foo = 'kw') + self.assertEqual([ + ('foo', 'barred', db.backend_id), + ('foo', '', db.backend_id), + ('foo', 'bleh', db.backend_id), + # Keywords are appened. 
+ ('foo', 'kw', db.backend_id), + ], list(db.iternotifies(0)) + ) + # multiple keywords + expect = [ + ('foo', 'meh', db.backend_id), + ('bar', 'foo', db.backend_id), + ] + rexpect = list(reversed(expect)) + db.listen('bar') + db.notify(foo = 'meh', bar = 'foo') + self.assertTrue(list(db.iternotifies(0)) in [expect, rexpect]) + + @pg_tmp + def testMessageHook(self): + create = db.prepare('CREATE TEMP TABLE msghook (i INT)') + reindex = db.prepare('REINDEX TABLE msghook') + drop = db.prepare('DROP TABLE msghook') + parts = [ + reindex, + db, + db.connector, + db.connector.driver, + ] + notices = [] + def add(x): + notices.append(x) + # inhibit + return True + with db.xact(): + db.settings['client_min_messages'] = 'NOTICE' + # test an installed msghook at each level + for x in parts: + x.msghook = add + create() + reindex() + del x.msghook + drop() + self.assertEqual(len(notices), len(parts)) + last = None + for x in notices: + if last is None: + last = x + continue + self.assertTrue(x.isconsistent(last)) + last = x + + @pg_tmp + def testRowTypeFactory(self): + from ..types.namedtuple import NamedTupleFactory + db.typio.RowTypeFactory = NamedTupleFactory + ps = prepare('select 1 as foo, 2 as bar') + first_results = ps.first() + self.assertEqual(first_results.foo, 1) + self.assertEqual(first_results.bar, 2) + + call_results = ps()[0] + self.assertEqual(call_results.foo, 1) + self.assertEqual(call_results.bar, 2) + + declare_results = ps.declare().read(1)[0] + self.assertEqual(declare_results.foo, 1) + self.assertEqual(declare_results.bar, 2) + + sqlexec('create type rtf AS (foo int, bar int)') + ps = prepare('select ROW(1, 2)::rtf') + composite_results = ps.first() + self.assertEqual(composite_results.foo, 1) + self.assertEqual(composite_results.bar, 2) + + @pg_tmp + def testNamedTuples(self): + from ..types.namedtuple import namedtuples + ps = namedtuples(prepare('select 1 as foo, 2 as bar, $1::text as param')) + r = list(ps("hello"))[0] + self.assertEqual(r[0], 1) + self.assertEqual(r.foo, 1) + self.assertEqual(r[1], 2) + self.assertEqual(r.bar, 2) + self.assertEqual(r[2], "hello") + self.assertEqual(r.param, "hello") + + @pg_tmp + def testBadFD(self): + db.pq.socket.close() + # bad fd now. 
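+		# With the socket gone, the next execution should surface
+		# ConnectionFailureError, which is asserted to be a Disconnection subclass.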
+ self.assertRaises( + pg_exc.ConnectionFailureError, + sqlexec, "SELECT 1" + ) + self.assertTrue(issubclass(pg_exc.ConnectionFailureError, pg_exc.Disconnection)) + + @pg_tmp + def testAdminTerminated(self): + with new() as killer: + if killer.version_info[:2] <= (9,1): + killer.sys.terminate_backends() + else: + killer.sys.terminate_backends_92() + + self.assertRaises( + pg_exc.AdminShutdownError, + sqlexec, "SELECT 1", + ) + self.assertTrue(issubclass(pg_exc.AdminShutdownError, pg_exc.Disconnection)) + + @pg_tmp + def testQuery(self): + self.assertEqual(db.query('select 1'), [(1,)]) + self.assertEqual(db.query.first('select 1'), 1) + self.assertEqual(next(db.query.column('select 1')), 1) + self.assertEqual(next(db.query.rows('select 1')), (1,)) + self.assertEqual(db.query.declare('select 1').read(), [(1,)]) + + self.assertEqual(db.query('select $1::int', 1), [(1,)]) + self.assertEqual(db.query.first('select $1::int', 1), 1) + self.assertEqual(next(db.query.column('select $1::int', 1)), 1) + self.assertEqual(next(db.query.rows('select $1::int', 1)), (1,)) + self.assertEqual(db.query.declare('select $1::int', 1).read(), [(1,)]) + + self.assertEqual(db.query.load_rows('select $1::int', [[1]]), None) + self.assertEqual(db.query.load_chunks('select $1::int', [[[1]]]), None) + +class test_typio(unittest.TestCase): + @pg_tmp + def testIdentify(self): + # It just exercises the code path. + db.typio.identify(contrib_hstore = 'pg_catalog.text') + + @pg_tmp + def testArrayNulls(self): + try: + sqlexec('SELECT ARRAY[1,NULL]::int[]') + except Exception: + # unsupported here + return + inta = prepare('select $1::int[]').first + texta = prepare('select $1::text[]').first + self.assertEqual(inta([1,2,None]), [1,2,None]) + self.assertEqual(texta(["foo",None,"bar"]), ["foo",None,"bar"]) + +if __name__ == '__main__': + unittest.main() diff --git a/py_opengauss/test/test_exceptions.py b/py_opengauss/test/test_exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..64debba2fdac60a716f3f4c19718e3e6d2e39be6 --- /dev/null +++ b/py_opengauss/test/test_exceptions.py @@ -0,0 +1,56 @@ +## +# .test.test_exceptions +## +import unittest +import py_opengauss.exceptions as pg_exc + +class test_exceptions(unittest.TestCase): + def test_pg_code_lookup(self): + # in 8.4, pg started using the SQL defined error code for limits + # Users *will* get whatever code PG sends, but it's important + # that they have some way to abstract it. many-to-one map ftw. + self.assertEqual( + pg_exc.ErrorLookup('22020'), pg_exc.LimitValueError + ) + + def test_error_lookup(self): + # An error code that doesn't exist yields pg_exc.Error + self.assertEqual( + pg_exc.ErrorLookup('00000'), pg_exc.Error + ) + + self.assertEqual( + pg_exc.ErrorLookup('XX000'), pg_exc.InternalError + ) + # check class fallback + self.assertEqual( + pg_exc.ErrorLookup('XX444'), pg_exc.InternalError + ) + + # SEARV is a very large class, so there are many + # sub-"codeclass" exceptions used to group the many + # SEARV errors. 
Make sure looking up 42000 actually + # gives the SEARVError + self.assertEqual( + pg_exc.ErrorLookup('42000'), pg_exc.SEARVError + ) + self.assertEqual( + pg_exc.ErrorLookup('08P01'), pg_exc.ProtocolError + ) + + def test_warning_lookup(self): + self.assertEqual( + pg_exc.WarningLookup('01000'), pg_exc.Warning + ) + self.assertEqual( + pg_exc.WarningLookup('02000'), pg_exc.NoDataWarning + ) + self.assertEqual( + pg_exc.WarningLookup('01P01'), pg_exc.DeprecationWarning + ) + self.assertEqual( + pg_exc.WarningLookup('01888'), pg_exc.Warning + ) + +if __name__ == '__main__': + unittest.main() diff --git a/py_opengauss/test/test_installation.py b/py_opengauss/test/test_installation.py new file mode 100644 index 0000000000000000000000000000000000000000..ff92072a53751c1cea3d8d9b6ea22f9938aca36b --- /dev/null +++ b/py_opengauss/test/test_installation.py @@ -0,0 +1,83 @@ +## +# .test.test_installation +## +import sys +import os +import unittest +from .. import installation as ins + +class test_installation(unittest.TestCase): + """ + Most of this is exercised by TestCaseWithCluster, but do some + explicit checks up front to help find any specific issues that + do not naturally occur. + """ + def test_parse_configure_options(self): + # Check expectations. + self.assertEqual( + list(ins.parse_configure_options("")), [], + ) + self.assertEqual( + list(ins.parse_configure_options(" ")), [], + ) + self.assertEqual( + list(ins.parse_configure_options("--foo --bar")), + [('foo',True),('bar',True)] + ) + self.assertEqual( + list(ins.parse_configure_options("'--foo' '--bar'")), + [('foo',True),('bar',True)] + ) + self.assertEqual( + list(ins.parse_configure_options("'--foo=A properly isolated string' '--bar'")), + [('foo','A properly isolated string'),('bar',True)] + ) + # hope they don't ever use backslash escapes. + # This is pretty dirty, but it doesn't seem well defined anyways. + self.assertEqual( + list(ins.parse_configure_options("'--foo=A ''properly'' isolated string' '--bar'")), + [('foo',"A 'properly' isolated string"),('bar',True)] + ) + # handle some simple variations, but it's + self.assertEqual( + list(ins.parse_configure_options("'--foo' \"--bar\"")), + [('foo',True),('bar',True)] + ) + # Show the failure. 
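+		# Double-quoted values containing spaces are not split correctly by the
+		# parser, so the comparison below is expected to raise AssertionError.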
+ try: + self.assertEqual( + list(ins.parse_configure_options("'--foo' \"--bar=/A dir/file\"")), + [('foo',True),('bar','/A dir/file')] + ) + except AssertionError: + pass + else: + self.fail("did not detect induced failure") + + def test_minimum(self): + 'version info' + # Installation only "needs" the version information + i = ins.Installation({'version' : 'PostgreSQL 2.2.3'}) + self.assertEqual( + i.version, 'PostgreSQL 2.2.3' + ) + self.assertEqual( + i.version_info, (2,2,3,'final',0) + ) + self.assertEqual(i.postgres, None) + self.assertEqual(i.postmaster, None) + + def test_exec(self): + # check the executable + i = ins.pg_config_dictionary( + sys.executable, '-m', __package__ + '.support', 'pg_config') + # automatically lowers the key + self.assertEqual(i['foo'], 'BaR') + self.assertEqual(i['feh'], 'YEAH') + self.assertEqual(i['version'], 'NAY') + +if __name__ == '__main__': + from types import ModuleType + this = ModuleType("this") + this.__dict__.update(globals()) + unittest.main(this) diff --git a/py_opengauss/test/test_iri.py b/py_opengauss/test/test_iri.py new file mode 100644 index 0000000000000000000000000000000000000000..0a6f423ca47e001c087b08c2ac006955572793b0 --- /dev/null +++ b/py_opengauss/test/test_iri.py @@ -0,0 +1,176 @@ +## +# .test.test_iri +## +import unittest +import py_opengauss.iri as pg_iri + +value_errors = ( + # Invalid scheme. + 'http://user@host/index.html', +) + +iri_samples = ( + 'host/dbname/path?param=val#frag', + '#frag', + '?param=val', + '?param=val#frag', + 'user@', + ':pass@', + 'u:p@h', + 'u:p@h:1', + 'postgres://host/database', + 'pq://user:password@host:port/database?setting=value#public,private', + 'pq://fæm.com:123/õéf/á?param=val', + 'pq://l»»@fæm.com:123/õéf/á?param=val', + 'pq://fæᎱᏋm.com/õéf/á?param=val', + 'pq://fæᎱᏋm.com/õéf/á?param=val&[setting]=value', +) + +sample_structured_parameters = [ + { + 'host' : 'hostname', + 'port' : '1234', + 'database' : 'foo_db', + }, + { + 'user' : 'username', + 'database' : 'database_name', + 'settings' : {'foo':'bar','feh':'bl%,23'}, + }, + { + 'user' : 'username', + 'database' : 'database_name', + }, + { + 'database' : 'database_name', + }, + { + 'user' : 'user_name', + }, + { + 'host' : 'hostname', + }, + { + 'user' : 'username', + 'password' : 'pass', + 'host' : '', + 'port' : '4321', + 'database' : 'database_name', + 'path' : ['path'], + }, + { + 'user' : 'user', + 'password' : 'secret', + 'host' : '', + 'port' : 'ssh', + 'database' : 'database_name', + 'settings' : { + 'set1' : 'val1', + 'set2' : 'val2', + }, + }, + { + 'user' : 'user', + 'password' : 'secret', + 'host' : '', + 'port' : 'ssh', + 'database' : 'database_name', + 'settings' : { + 'set1' : 'val1', + 'set2' : 'val2', + }, + 'connect_timeout' : '10', + 'sslmode' : 'prefer', + }, +] + +class test_iri(unittest.TestCase): + def testAlternateSchemes(self): + field = pg_iri.parse("postgres://host")['host'] + self.assertEqual(field, 'host') + + field = pg_iri.parse("postgresql://host")['host'] + self.assertEqual(field, 'host') + + try: + pg_iri.parse("reject://host") + except ValueError: + pass + else: + self.fail("unacceptable IRI scheme not rejected") + + def testIP6Hosts(self): + """ + Validate that IPv6 hosts are properly extracted. 
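+		The bracketed forms ([::1], [::1]:1234) must yield the bare address.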
+ """ + s = [ + ('pq://[::1]/db', '::1'), + ('pq://[::1]:1234/db', '::1'), + ('pq://[1:2:3::1]/db', '1:2:3::1'), + ('pq://[1:2:3::1]:1234/db', '1:2:3::1'), + ('pq://[]:1234/db', ''), + ('pq://[]/db', ''), + ] + for i, h in s: + p = pg_iri.parse(i) + self.assertEqual(p['host'], h) + + def testPresentPasswordObscure(self): + """ + Password is present in IRI, and obscure it. + """ + s = 'pq://user:pass@host:port/dbname' + o = 'pq://user:***@host:port/dbname' + p = pg_iri.parse(s) + ps = pg_iri.serialize(p, obscure_password = True) + self.assertEqual(ps, o) + + def testPresentPasswordObscure(self): + """ + Password is *not* present in IRI, and do nothing. + """ + s = 'pq://user@host:port/dbname' + o = 'pq://user@host:port/dbname' + p = pg_iri.parse(s) + ps = pg_iri.serialize(p, obscure_password = True) + self.assertEqual(ps, o) + + def testValueErrors(self): + for x in value_errors: + self.assertRaises(ValueError, + pg_iri.parse, x + ) + + def testParseSerialize(self): + scheme = 'pq://' + for x in iri_samples: + px = pg_iri.parse(x) + spx = pg_iri.serialize(px) + pspx = pg_iri.parse(spx) + self.assertTrue( + pspx == px, + "parse-serialize incongruity, %r -> %r -> %r : %r != %r" %( + x, px, spx, pspx, px + ) + ) + spspx = pg_iri.serialize(pspx) + self.assertTrue( + spx == spspx, + "parse-serialize incongruity, %r -> %r -> %r -> %r : %r != %r" %( + x, px, spx, pspx, spspx, spx + ) + ) + + def testSerializeParse(self): + for x in sample_structured_parameters: + xs = pg_iri.serialize(x) + uxs = pg_iri.parse(xs) + self.assertTrue( + x == uxs, + "serialize-parse incongruity, %r -> %r -> %r" %( + x, xs, uxs, + ) + ) + +if __name__ == '__main__': + unittest.main() diff --git a/py_opengauss/test/test_lib.py b/py_opengauss/test/test_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..67e5f6a258545df9066f7b244f192693b3225151 --- /dev/null +++ b/py_opengauss/test/test_lib.py @@ -0,0 +1,170 @@ +## +# .test.test_lib - test the .lib package +## +import sys +import os +import unittest +import tempfile + +from .. import exceptions as pg_exc +from .. import lib as pg_lib +from .. import sys as pg_sys +from ..temporal import pg_tmp + +ilf = """ +preface + +[sym] +select 1 +[sym_ref] +*[sym] +[sym_ref_trail] +*[sym] WHERE FALSE +[sym_first::first] +select 1 + + +[sym_rows::rows] +select 1 + +[sym_chunks::chunks] +select 1 + +[sym_declare::declare] +select 1 + +[sym_const:const:first] +select 1 +[sym_const_rows:const:rows] +select 1 +[sym_const_chunks:const:chunks] +select 1 +[sym_const_column:const:column] +select 1 +[sym_const_ddl:const:] +create temp table sym_const_dll (i int); + +[sym_preload:preload:first] +select 1 + +[sym_proc:proc] +test_ilf_proc(int) + +[sym_srf_proc:proc] +test_ilf_srf_proc(int) + +[&sym_reference] +SELECT 'SELECT 1'; + +[&sym_reference_params] +SELECT 'SELECT ' || $1::text; + +[&sym_reference_first::first] +SELECT 'SELECT 1::int4'; + +[&sym_reference_const:const:first] +SELECT 'SELECT 1::int4'; + +[&sym_reference_proc:proc] +SELECT 'test_ilf_proc(int)'::text +""" + +class test_lib(unittest.TestCase): + # NOTE: Module libraries are implicitly tested + # in postgresql.test.test_driver; much functionality + # depends on the `sys` library. 
+ def _testILF(self, lib): + self.assertTrue('preface' in lib.preface) + db.execute("CREATE OR REPLACE FUNCTION test_ilf_proc(int) RETURNS int language sql as 'select $1';") + db.execute("CREATE OR REPLACE FUNCTION test_ilf_srf_proc(int) RETURNS SETOF int language sql as 'select $1';") + b = pg_lib.Binding(db, lib) + self.assertEqual(b.sym_ref(), [(1,)]) + self.assertEqual(b.sym_ref_trail(), []) + self.assertEqual(b.sym(), [(1,)]) + self.assertEqual(b.sym_first(), 1) + self.assertEqual(list(b.sym_rows()), [(1,)]) + self.assertEqual([list(x) for x in b.sym_chunks()], [[(1,)]]) + c = b.sym_declare() + self.assertEqual(c.read(), [(1,)]) + c.seek(0) + self.assertEqual(c.read(), [(1,)]) + self.assertEqual(b.sym_const, 1) + self.assertEqual(b.sym_const_column, [1]) + self.assertEqual(b.sym_const_rows, [(1,)]) + self.assertEqual(b.sym_const_chunks, [[(1,)]]) + self.assertEqual(b.sym_const_ddl, ('CREATE TABLE', None)) + self.assertEqual(b.sym_preload(), 1) + # now stored procs + self.assertEqual(b.sym_proc(2,), 2) + self.assertEqual(list(b.sym_srf_proc(2,)), [2]) + self.assertRaises(AttributeError, getattr, b, 'LIES') + # reference symbols + self.assertEqual(b.sym_reference()(), [(1,)]) + self.assertEqual(b.sym_reference_params('1::int')(), [(1,)]) + self.assertEqual(b.sym_reference_params("'foo'::text")(), [('foo',)]) + self.assertEqual(b.sym_reference_first()(), 1) + self.assertEqual(b.sym_reference_const(), 1) + self.assertEqual(b.sym_reference_proc()(2,), 2) + + @pg_tmp + def testILF_from_lines(self): + lib = pg_lib.ILF.from_lines([l + '\n' for l in ilf.splitlines()]) + self._testILF(lib) + + @pg_tmp + def testILF_from_file(self): + f = tempfile.NamedTemporaryFile( + delete = False, mode = 'w', encoding = 'utf-8' + ) + n = f.name + try: + f.write(ilf) + f.flush() + f.seek(0) + lib = pg_lib.ILF.open(n, encoding = 'utf-8') + self._testILF(lib) + f.close() + finally: + # so annoying... + os.unlink(n) + + @pg_tmp + def testLoad(self): + # gotta test it in the cwd... + pid = os.getpid() + frag = 'temp' + str(pid) + fn = 'lib' + frag + '.sql' + try: + with open(fn, 'w') as f: + f.write("[foo]\nSELECT 1") + pg_sys.libpath.insert(0, os.path.curdir) + l = pg_lib.load(frag) + b = pg_lib.Binding(db, l) + self.assertEqual(b.foo(), [(1,)]) + finally: + os.remove(fn) + + @pg_tmp + def testCategory(self): + lib = pg_lib.ILF.from_lines([l + '\n' for l in ilf.splitlines()]) + # XXX: evil, careful.. + lib._name = 'name' + c = pg_lib.Category(lib) + c(db) + self.assertEqual(db.name.sym_first(), 1) + c = pg_lib.Category(renamed = lib) + c(db) + self.assertEqual(db.renamed.sym_first(), 1) + + @pg_tmp + def testCategoryAliases(self): + lib = pg_lib.ILF.from_lines([l + '\n' for l in ilf.splitlines()]) + # XXX: evil, careful.. 
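+		# Assigning the private _name gives the anonymous ILF a symbol name so
+		# the Category can attach it to the connection as db.name.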
+ lib._name = 'name' + c = pg_lib.Category(lib, renamed = lib) + c(db) + self.assertEqual(db.name.sym_first(), 1) + self.assertEqual(db.renamed.sym_first(), 1) + +if __name__ == '__main__': + unittest.main() diff --git a/py_opengauss/test/test_notifyman.py b/py_opengauss/test/test_notifyman.py new file mode 100644 index 0000000000000000000000000000000000000000..67b94c862792dc750a52033d0b887585aaa3af1d --- /dev/null +++ b/py_opengauss/test/test_notifyman.py @@ -0,0 +1,137 @@ +## +# .test.test_notifyman - test .notifyman +## +import unittest +import threading +import time +from ..temporal import pg_tmp +from ..notifyman import NotificationManager + +class test_notifyman(unittest.TestCase): + @pg_tmp + def testNotificationManager(self): + # signals each other + alt = new() + with alt: + nm = NotificationManager(db, alt) + db.listen('foo') + alt.listen('bar') + # notify the other. + alt.notify('foo') + db.notify('bar') + # we can separate these here because there's no timeout + for ndb, notifies in nm: + for n in notifies: + if ndb is db: + self.assertEqual(n[0], 'foo') + self.assertEqual(n[1], '') + self.assertEqual(n[2], alt.backend_id) + nm.connections.discard(db) + elif ndb is alt: + self.assertEqual(n[0], 'bar') + self.assertEqual(n[1], '') + self.assertEqual(n[2], db.backend_id) + nm.connections.discard(alt) + else: + self.fail("unknown connection received notify..") + + @pg_tmp + def testNotificationManagerTimeout(self): + nm = NotificationManager(db, timeout = 0.1) + db.listen('foo') + count = 0 + for event in nm: + if event is None: + # do this a few times, then break out of the loop + db.notify('foo') + continue + ndb, notifies = event + self.assertEqual(ndb, db) + for n in notifies: + self.assertEqual(n[0], 'foo') + self.assertEqual(n[1], '') + self.assertEqual(n[2], db.backend_id) + count = count + 1 + if count > 3: + break + + @pg_tmp + def testNotificationManagerZeroTimeout(self): + # Zero-timeout means raise StopIteration when + # there are no notifications to emit. + # It checks the wire, but does *not* wait for data. + nm = NotificationManager(db, timeout = 0) + db.listen('foo') + self.assertEqual(list(nm), []) + db.notify('foo') + time.sleep(0.01) + self.assertEqual(list(nm), [('foo','',db.backend_id)]) # bit of a race + + @pg_tmp + def test_iternotifies(self): + # db.iternotifies() simplification of NotificationManager + alt = new() + alt.listen('foo') + alt.listen('close') + def get_notices(db, l): + with db: + for x in db.iternotifies(): + if x[0] == 'close': + break + l.append(x) + rl = [] + t = threading.Thread(target = get_notices, args = (alt, rl,)) + t.start() + db.notify('foo') + while not rl: + time.sleep(0.05) + channel, payload, pid = rl.pop(0) + self.assertEqual(channel, 'foo') + self.assertEqual(payload, '') + self.assertEqual(pid, db.backend_id) + db.notify('close') + + @pg_tmp + def testNotificationManagerZeroTimeout(self): + # Zero-timeout means raise StopIteration when + # there are no notifications to emit. + # It checks the wire, but does *not* wait for data. + db.listen('foo') + self.assertEqual(list(db.iternotifies(0)), []) + db.notify('foo') + time.sleep(0.01) + self.assertEqual(list(db.iternotifies(0)), [('foo','', db.backend_id)]) # bit of a race + + @pg_tmp + def testNotificationManagerOnClosed(self): + # When the connection goes away, the NM iterator + # should raise a Stop. 
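+		# First case: close the connection from inside the loop and expect the
+		# iterator to stop; the second case below closes while idle on a timeout.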
+ db = new() + db.listen('foo') + db.notify('foo') + for n in db.iternotifies(): + db.close() + self.assertEqual(db.closed, True) + del db + # closer, after an idle + db = new() + db.listen('foo') + for n in db.iternotifies(0.2): + if n is None: + # In the loop, notify, and expect to + # get the notification even though the + # connection was closed. + db.notify('foo') + db.execute('') + db.close() + hit = False + else: + hit = True + # hit should get set two times. + # once on the first idle, and once on the event + # received after the close. + self.assertEqual(db.closed, True) + self.assertEqual(hit, True) + +if __name__ == '__main__': + unittest.main() diff --git a/py_opengauss/test/test_optimized.py b/py_opengauss/test/test_optimized.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a5d4b109af816e9f66e0c541b00f5c43c68774 --- /dev/null +++ b/py_opengauss/test/test_optimized.py @@ -0,0 +1,250 @@ +## +# test.test_optimized +## +import unittest +import struct +import sys +from ..port import optimized +from ..python.itertools import interlace + +def pack_tuple(*data, + packH = struct.Struct("!H").pack, + packL = struct.Struct("!L").pack +): + return packH(len(data)) + b''.join(( + packL(len(x)) + x if x is not None else b'\xff\xff\xff\xff' + for x in data + )) + +tuplemessages = ( + (b'D', pack_tuple(b'foo', b'bar')), + (b'D', pack_tuple(b'foo', None, b'bar')), + (b'N', b'fee'), + (b'D', pack_tuple(b'foo', None, b'bar')), + (b'D', pack_tuple(b'foo', b'bar')), +) + +class test_optimized(unittest.TestCase): + def test_consume_tuple_messages(self): + ctm = optimized.consume_tuple_messages + # expecting a tuple of pairs. + self.assertRaises(TypeError, ctm, []) + self.assertEqual(ctm(()), []) + # Make sure that the slicing is working. + self.assertEqual(ctm(tuplemessages), [ + (b'foo', b'bar'), + (b'foo', None, b'bar'), + ]) + # Not really checking consume here, but we are validating that + # it's properly propagating exceptions. 
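+		# Both D messages below carry malformed length fields, so the ValueError
+		# raised by the tuple parser should propagate out of consume.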
+ self.assertRaises(ValueError, ctm, ((b'D', b'\xff\xff\xff\xfefoo'),)) + self.assertRaises(ValueError, ctm, ((b'D', b'\x00\x00\x00\x04foo'),)) + + def test_parse_tuple_message(self): + ptm = optimized.parse_tuple_message + self.assertRaises(TypeError, ptm, "stringzor") + self.assertRaises(TypeError, ptm, 123) + self.assertRaises(ValueError, ptm, b'') + self.assertRaises(ValueError, ptm, b'0') + + notenoughdata = struct.pack('!H', 2) + self.assertRaises(ValueError, ptm, notenoughdata) + + wraparound = struct.pack('!HL', 2, 10) + (b'0' * 10) + struct.pack('!L', 0xFFFFFFFE) + self.assertRaises(ValueError, ptm, wraparound) + + oneatt_notenough = struct.pack('!HL', 2, 10) + (b'0' * 10) + struct.pack('!L', 15) + self.assertRaises(ValueError, ptm, oneatt_notenough) + + toomuchdata = struct.pack('!HL', 1, 3) + (b'0' * 10) + self.assertRaises(ValueError, ptm, toomuchdata) + + class faketup(tuple): + def __new__(subtype, geeze): + r = tuple.__new__(subtype, ()) + r.foo = geeze + return r + zerodata = struct.pack('!H', 0) + r = ptm(zerodata) + self.assertRaises(AttributeError, getattr, r, 'foo') + self.assertRaises(AttributeError, setattr, r, 'foo', 'bar') + self.assertEqual(len(r), 0) + + def test_process_tuple(self): + def funpass(procs, tup, col): + pass + pt = optimized.process_tuple + # tuple() requirements + self.assertRaises(TypeError, pt, "foo", "bar", funpass) + self.assertRaises(TypeError, pt, (), "bar", funpass) + self.assertRaises(TypeError, pt, "foo", (), funpass) + self.assertRaises(TypeError, pt, (), ("foo",), funpass) + + def test_pack_tuple_data(self): + pit = optimized.pack_tuple_data + self.assertEqual(pit((None,)), b'\xff\xff\xff\xff') + self.assertEqual(pit((None,)*2), b'\xff\xff\xff\xff'*2) + self.assertEqual(pit((None,)*3), b'\xff\xff\xff\xff'*3) + self.assertEqual(pit((None,b'foo')), b'\xff\xff\xff\xff\x00\x00\x00\x03foo') + self.assertEqual(pit((None,b'')), b'\xff\xff\xff\xff\x00\x00\x00\x00') + self.assertEqual(pit((None,b'',b'bar')), b'\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x03bar') + self.assertRaises(TypeError, pit, 1) + self.assertRaises(TypeError, pit, (1,)) + self.assertRaises(TypeError, pit, ("",)) + + def test_int2(self): + d = b'\x00\x01' + rd = b'\x01\x00' + s = optimized.swap_int2_unpack(d) + n = optimized.int2_unpack(d) + sd = optimized.swap_int2_pack(1) + nd = optimized.int2_pack(1) + if sys.byteorder == 'little': + self.assertEqual(1, s) + self.assertEqual(256, n) + self.assertEqual(d, sd) + self.assertEqual(rd, nd) + else: + self.assertEqual(1, n) + self.assertEqual(256, s) + self.assertEqual(d, nd) + self.assertEqual(rd, sd) + self.assertRaises(OverflowError, optimized.swap_int2_pack, 2**15) + self.assertRaises(OverflowError, optimized.int2_pack, 2**15) + self.assertRaises(OverflowError, optimized.swap_int2_pack, (-2**15)-1) + self.assertRaises(OverflowError, optimized.int2_pack, (-2**15)-1) + + def test_int4(self): + d = b'\x00\x00\x00\x01' + rd = b'\x01\x00\x00\x00' + s = optimized.swap_int4_unpack(d) + n = optimized.int4_unpack(d) + sd = optimized.swap_int4_pack(1) + nd = optimized.int4_pack(1) + if sys.byteorder == 'little': + self.assertEqual(1, s) + self.assertEqual(16777216, n) + self.assertEqual(d, sd) + self.assertEqual(rd, nd) + else: + self.assertEqual(1, n) + self.assertEqual(16777216, s) + self.assertEqual(d, nd) + self.assertEqual(rd, sd) + self.assertRaises(OverflowError, optimized.swap_int4_pack, 2**31) + self.assertRaises(OverflowError, optimized.int4_pack, 2**31) + self.assertRaises(OverflowError, optimized.swap_int4_pack, (-2**31)-1) 
+ self.assertRaises(OverflowError, optimized.int4_pack, (-2**31)-1) + + def test_int8(self): + d = b'\x00\x00\x00\x00\x00\x00\x00\x01' + rd = b'\x01\x00\x00\x00\x00\x00\x00\x00' + s = optimized.swap_int8_unpack(d) + n = optimized.int8_unpack(d) + sd = optimized.swap_int8_pack(1) + nd = optimized.int8_pack(1) + if sys.byteorder == 'little': + self.assertEqual(0x1, s) + self.assertEqual(0x100000000000000, n) + self.assertEqual(d, sd) + self.assertEqual(rd, nd) + else: + self.assertEqual(0x1, n) + self.assertEqual(0x100000000000000, s) + self.assertEqual(d, nd) + self.assertEqual(rd, sd) + self.assertEqual(optimized.swap_int8_pack(-1), b'\xFF\xFF\xFF\xFF'*2) + self.assertEqual(optimized.int8_pack(-1), b'\xFF\xFF\xFF\xFF'*2) + self.assertRaises(OverflowError, optimized.swap_int8_pack, 2**63) + self.assertRaises(OverflowError, optimized.int8_pack, 2**63) + self.assertRaises(OverflowError, optimized.swap_int8_pack, (-2**63)-1) + self.assertRaises(OverflowError, optimized.int8_pack, (-2**63)-1) + # edge I/O + int8_max = ((2**63) - 1) + int8_min = (-(2**63)) + swap_max = optimized.swap_int8_pack(int8_max) + max = optimized.int8_pack(int8_max) + swap_min = optimized.swap_int8_pack(int8_min) + min = optimized.int8_pack(int8_min) + self.assertEqual(optimized.swap_int8_unpack(swap_max), int8_max) + self.assertEqual(optimized.int8_unpack(max), int8_max) + self.assertEqual(optimized.swap_int8_unpack(swap_min), int8_min) + self.assertEqual(optimized.int8_unpack(min), int8_min) + + def test_uint2(self): + d = b'\x00\x01' + rd = b'\x01\x00' + s = optimized.swap_uint2_unpack(d) + n = optimized.uint2_unpack(d) + sd = optimized.swap_uint2_pack(1) + nd = optimized.uint2_pack(1) + if sys.byteorder == 'little': + self.assertEqual(1, s) + self.assertEqual(256, n) + self.assertEqual(d, sd) + self.assertEqual(rd, nd) + else: + self.assertEqual(1, n) + self.assertEqual(256, s) + self.assertEqual(d, nd) + self.assertEqual(rd, sd) + self.assertRaises(OverflowError, optimized.swap_uint2_pack, -1) + self.assertRaises(OverflowError, optimized.uint2_pack, -1) + self.assertRaises(OverflowError, optimized.swap_uint2_pack, 2**16) + self.assertRaises(OverflowError, optimized.uint2_pack, 2**16) + self.assertEqual(optimized.uint2_pack(2**16-1), b'\xFF\xFF') + self.assertEqual(optimized.swap_uint2_pack(2**16-1), b'\xFF\xFF') + + def test_uint4(self): + d = b'\x00\x00\x00\x01' + rd = b'\x01\x00\x00\x00' + s = optimized.swap_uint4_unpack(d) + n = optimized.uint4_unpack(d) + sd = optimized.swap_uint4_pack(1) + nd = optimized.uint4_pack(1) + if sys.byteorder == 'little': + self.assertEqual(1, s) + self.assertEqual(16777216, n) + self.assertEqual(d, sd) + self.assertEqual(rd, nd) + else: + self.assertEqual(1, n) + self.assertEqual(16777216, s) + self.assertEqual(d, nd) + self.assertEqual(rd, sd) + self.assertRaises(OverflowError, optimized.swap_uint4_pack, -1) + self.assertRaises(OverflowError, optimized.uint4_pack, -1) + self.assertRaises(OverflowError, optimized.swap_uint4_pack, 2**32) + self.assertRaises(OverflowError, optimized.uint4_pack, 2**32) + self.assertEqual(optimized.uint4_pack(2**32-1), b'\xFF\xFF\xFF\xFF') + self.assertEqual(optimized.swap_uint4_pack(2**32-1), b'\xFF\xFF\xFF\xFF') + + def test_uint8(self): + d = b'\x00\x00\x00\x00\x00\x00\x00\x01' + rd = b'\x01\x00\x00\x00\x00\x00\x00\x00' + s = optimized.swap_uint8_unpack(d) + n = optimized.uint8_unpack(d) + sd = optimized.swap_uint8_pack(1) + nd = optimized.uint8_pack(1) + if sys.byteorder == 'little': + self.assertEqual(0x1, s) + self.assertEqual(0x100000000000000, 
n) + self.assertEqual(d, sd) + self.assertEqual(rd, nd) + else: + self.assertEqual(0x1, n) + self.assertEqual(0x100000000000000, s) + self.assertEqual(d, nd) + self.assertEqual(rd, sd) + self.assertRaises(OverflowError, optimized.swap_uint8_pack, -1) + self.assertRaises(OverflowError, optimized.uint8_pack, -1) + self.assertRaises(OverflowError, optimized.swap_uint8_pack, 2**64) + self.assertRaises(OverflowError, optimized.uint8_pack, 2**64) + self.assertEqual(optimized.uint8_pack((2**64)-1), b'\xFF\xFF\xFF\xFF'*2) + self.assertEqual(optimized.swap_uint8_pack((2**64)-1), b'\xFF\xFF\xFF\xFF'*2) + +if __name__ == '__main__': + from types import ModuleType + this = ModuleType("this") + this.__dict__.update(globals()) + unittest.main(this) diff --git a/py_opengauss/test/test_pgpassfile.py b/py_opengauss/test/test_pgpassfile.py new file mode 100644 index 0000000000000000000000000000000000000000..577c98100535b470a0eb8d7399f06bc26f6b78ee --- /dev/null +++ b/py_opengauss/test/test_pgpassfile.py @@ -0,0 +1,72 @@ +## +# .test.test_pgpassfile +## +import unittest +from .. import pgpassfile as client_pgpass +from io import StringIO + +passfile_sample = """ +# host:1111:dbname:user:password1 +host:1111:dbname:user:password1 +*:1111:dbname:user:password2 +*:*:dbname:user:password3 + +# Comment + +*:*:*:user:password4 +*:*:*:usern:password4.5 +*:*:*:*:password5 +""" + +passfile_sample_map = { + ('user', 'host', '1111', 'dbname') : 'password1', + ('user', 'host', '1111', 'dbname') : 'password1', + ('user', 'foo', '1111', 'dbname') : 'password2', + ('user', 'foo', '4321', 'dbname') : 'password3', + ('user', 'foo', '4321', 'db,name') : 'password4', + + ('uuser', 'foo', '4321', 'db,name') : 'password5', + ('usern', 'foo', '4321', 'db,name') : 'password4.5', + ('foo', 'bar', '19231', 'somedbn') : 'password5', +} + +difficult_passfile_sample = r""" +host\\:1111:db\:name:u\\ser:word1 +*:1111:\:dbname\::\\user\\:pass\:word2 +foohost:1111:\:dbname\::\\user\\:pass\:word3 +""" + +difficult_passfile_sample_map = { + ('u\\ser','host\\','1111','db:name') : 'word1', + ('\\user\\','somehost','1111',':dbname:') : 'pass:word2', + ('\\user\\','someotherhost','1111',':dbname:') : 'pass:word2', + # More specific, but comes after '*' + ('\\user\\','foohost','1111',':dbname:') : 'pass:word2', + ('','','','') : None, +} + +class test_pgpass(unittest.TestCase): + def runTest(self): + sample1 = client_pgpass.parse(StringIO(passfile_sample)) + sample2 = client_pgpass.parse(StringIO(difficult_passfile_sample)) + + for k, pw in passfile_sample_map.items(): + lpw = client_pgpass.lookup_password(sample1, k) + self.assertEqual(lpw, pw, + "password lookup incongruity, expecting %r got %r with %r" + " in \n%s" %( + pw, lpw, k, passfile_sample + ) + ) + + for k, pw in difficult_passfile_sample_map.items(): + lpw = client_pgpass.lookup_password(sample2, k) + self.assertEqual(lpw, pw, + "password lookup incongruity, expecting %r got %r with %r" + " in \n%s" %( + pw, lpw, k, difficult_passfile_sample + ) + ) + +if __name__ == '__main__': + unittest.main() diff --git a/py_opengauss/test/test_protocol.py b/py_opengauss/test/test_protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..1a2164774a637c4749a338100daf82a15d47ad20 --- /dev/null +++ b/py_opengauss/test/test_protocol.py @@ -0,0 +1,691 @@ +## +# .test.test_protocol +## +import sys +import unittest +import struct +import decimal +import socket +import time +from threading import Thread + +from ..protocol import element3 as e3 +from ..protocol import xact3 as x3 
+from ..protocol import client3 as c3 +from ..protocol import buffer as pq_buf +from ..python.socket import find_available_port, SocketFactory + +def pair(msg): + return (msg.type, msg.serialize()) +def pairs(*msgseq): + return list(map(pair, msgseq)) + +long = struct.Struct("!L") +packl = long.pack +unpackl = long.unpack + +class test_buffer(unittest.TestCase): + def setUp(self): + self.buffer = pq_buf.pq_message_stream() + + def testMultiByteMessage(self): + b = self.buffer + b.write(b's') + self.assertTrue(b.next_message() is None) + b.write(b'\x00\x00') + self.assertTrue(b.next_message() is None) + b.write(b'\x00\x10') + self.assertTrue(b.next_message() is None) + data = b'twelve_chars' + b.write(data) + self.assertEqual(b.next_message(), (b's', data)) + + def testSingleByteMessage(self): + b = self.buffer + b.write(b's') + self.assertTrue(b.next_message() is None) + b.write(b'\x00') + self.assertTrue(b.next_message() is None) + b.write(b'\x00\x00\x05') + self.assertTrue(b.next_message() is None) + b.write(b'b') + self.assertEqual(b.next_message(), (b's', b'b')) + + def testEmptyMessage(self): + b = self.buffer + b.write(b'x') + self.assertTrue(b.next_message() is None) + b.write(b'\x00\x00\x00') + self.assertTrue(b.next_message() is None) + b.write(b'\x04') + self.assertEqual(b.next_message(), (b'x', b'')) + + def testInvalidLength(self): + b = self.buffer + b.write(b'y\x00\x00\x00\x03') + self.assertRaises(ValueError, b.next_message,) + + def testRemainder(self): + b = self.buffer + b.write(b'r\x00\x00\x00\x05Aremainder') + self.assertEqual(b.next_message(), (b'r', b'A')) + + def testLarge(self): + b = self.buffer + factor = 1024 + r = 10000 + b.write(b'X' + packl(factor * r + 4)) + segment = b'\x00' * factor + for x in range(r-1): + b.write(segment) + b.write(segment) + msg = b.next_message() + self.assertTrue(msg is not None) + self.assertEqual(msg[0], b'X') + + def test_getvalue(self): + # Make sure that getvalue() only applies to messages + # that have not been read. + b = self.buffer + # It should be empty. + self.assertEqual(b.getvalue(), b'') + d = b'F' + packl(28) + b.write(d) + self.assertEqual(b.getvalue(), d) + d1 = b'01'*12 # 24 + b.write(d1) + self.assertEqual(b.getvalue(), d + d1) + out = b.read()[0] + self.assertEqual(out, (b'F', d1)) + nd = b'N' + b.write(nd) + self.assertEqual(b.getvalue(), nd) + b.write(packl(4)) + self.assertEqual(list(b.read()), [(b'N', b'')]) + self.assertEqual(b.getvalue(), b'') + # partial; read one message to exercise + # that the appropriate fragment of the first + # chunk in the buffer is picked up. 
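+		# Two complete messages are written in one chunk; each read(1) should
+		# consume exactly one message and getvalue() should report the remainder.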
+ first_body = (b'1234' * 3) + first = b'v' + packl(len(first_body) + 4) + first_body + second_body = (b'4321' * 5) + second = b'z' + packl(len(second_body) + 4) + second_body + b.write(first + second) + self.assertEqual(b.getvalue(), first + second) + self.assertEqual(list(b.read(1)), [(b'v', first_body)]) + self.assertEqual(b.getvalue(), second) + self.assertEqual(list(b.read(1)), [(b'z', second_body)]) + # now, with a third full message in the next chunk + third_body = (b'9876' * 10) + third = b'3' + packl(len(third_body) + 4) + third_body + b.write(first + second) + b.write(third) + self.assertEqual(b.getvalue(), first + second + third) + self.assertEqual(list(b.read(1)), [(b'v', first_body)]) + self.assertEqual(b.getvalue(), second + third) + self.assertEqual(list(b.read(1)), [(b'z', second_body)]) + self.assertEqual(b.getvalue(), third) + self.assertEqual(list(b.read(1)), [(b'3', third_body)]) + self.assertEqual(b.getvalue(), b'') + +## +# element3 tests +## + +message_samples = [ + e3.VoidMessage, + e3.Startup([ + (b'user', b'jwp'), + (b'database', b'template1'), + (b'options', b'-f'), + ]), + e3.Notice(( + (b'S', b'FATAL'), + (b'M', b'a descriptive message'), + (b'C', b'FIVEC'), + (b'D', b'bleh'), + (b'H', b'dont spit into the fan'), + )), + e3.Notify(123, b'wood_table'), + e3.KillInformation(19320, 589483), + e3.ShowOption(b'foo', b'bar'), + e3.Authentication(4, b'salt'), + e3.Complete(b'SELECT'), + e3.Ready(b'I'), + e3.CancelRequest(4123, 14252), + e3.NegotiateSSL(), + e3.Password(b'ckr4t'), + e3.AttributeTypes(()), + e3.AttributeTypes( + (123,) * 1 + ), + e3.AttributeTypes( + (123,0) * 1 + ), + e3.AttributeTypes( + (123,0) * 2 + ), + e3.AttributeTypes( + (123,0) * 4 + ), + e3.TupleDescriptor(()), + e3.TupleDescriptor(( + (b'name', 123, 1, 1, 0, 0, 1,), + )), + e3.TupleDescriptor(( + (b'name', 123, 1, 2, 0, 0, 1,), + ) * 2), + e3.TupleDescriptor(( + (b'name', 123, 1, 2, 1, 0, 1,), + ) * 3), + e3.TupleDescriptor(( + (b'name', 123, 1, 1, 0, 0, 1,), + ) * 1000), + e3.Tuple([]), + e3.Tuple([b'foo',]), + e3.Tuple([None]), + e3.Tuple([b'foo',b'bar']), + e3.Tuple([None, None]), + e3.Tuple([None, b'foo', None]), + e3.Tuple([b'bar', None, b'foo', None, b'bleh']), + e3.Tuple([b'foo', b'bar'] * 100), + e3.Tuple([None] * 100), + e3.Query(b'select * from u'), + e3.Parse(b'statement_id', b'query', (123, 0)), + e3.Parse(b'statement_id', b'query', (123,)), + e3.Parse(b'statement_id', b'query', ()), + e3.Bind(b'portal_id', b'statement_id', + (b'tt',b'\x00\x00'), + [b'data',None], (b'ff',b'xx')), + e3.Bind(b'portal_id', b'statement_id', (b'tt',), [None], (b'xx',)), + e3.Bind(b'portal_id', b'statement_id', (b'ff',), [b'data'], ()), + e3.Bind(b'portal_id', b'statement_id', (), [], (b'xx',)), + e3.Bind(b'portal_id', b'statement_id', (), [], ()), + e3.Execute(b'portal_id', 500), + e3.Execute(b'portal_id', 0), + e3.DescribeStatement(b'statement_id'), + e3.DescribePortal(b'portal_id'), + e3.CloseStatement(b'statement_id'), + e3.ClosePortal(b'portal_id'), + e3.Function(123, (), [], b'xx'), + e3.Function(321, (b'tt',), [b'foo'], b'xx'), + e3.Function(321, (b'tt',), [None], b'xx'), + e3.Function(321, (b'aa', b'aa'), [None,b'a' * 200], b'xx'), + e3.FunctionResult(b''), + e3.FunctionResult(b'foobar'), + e3.FunctionResult(None), + e3.CopyToBegin(123, [321,123]), + e3.CopyToBegin(0, [10,]), + e3.CopyToBegin(123, []), + e3.CopyFromBegin(123, [321,123]), + e3.CopyFromBegin(0, [10]), + e3.CopyFromBegin(123, []), + e3.CopyData(b''), + e3.CopyData(b'foo'), + e3.CopyData(b'a' * 2048), + e3.CopyFail(b''), + 
e3.CopyFail(b'iiieeeeee!'), +] + +class test_element3(unittest.TestCase): + def test_cat_messages(self): + # The optimized implementation will identify adjacent copy data, and + # take a more efficient route; so rigorously test the switch between the + # two modes. + self.assertEqual(e3.cat_messages([]), b'') + self.assertEqual(e3.cat_messages([b'foo']), b'd\x00\x00\x00\x07foo') + self.assertEqual(e3.cat_messages([b'foo', b'foo']), 2*b'd\x00\x00\x00\x07foo') + # copy, other, copy + self.assertEqual(e3.cat_messages([b'foo', e3.SynchronizeMessage, b'foo']), + b'd\x00\x00\x00\x07foo' + e3.SynchronizeMessage.bytes() + b'd\x00\x00\x00\x07foo') + # copy, other, copy*1000 + self.assertEqual(e3.cat_messages(1000*[b'foo', e3.SynchronizeMessage, b'foo']), + 1000*(b'd\x00\x00\x00\x07foo' + e3.SynchronizeMessage.bytes() + b'd\x00\x00\x00\x07foo')) + # other, copy, copy*1000 + self.assertEqual(e3.cat_messages(1000*[e3.SynchronizeMessage, b'foo', b'foo']), + 1000*(e3.SynchronizeMessage.bytes() + 2*b'd\x00\x00\x00\x07foo')) + pack_head = struct.Struct("!lH").pack + # tuple + self.assertEqual(e3.cat_messages([(b'foo',),]), + b'D' + pack_head(7 + 4 + 2, 1) + b'\x00\x00\x00\x03foo') + # tuple(foo,\N) + self.assertEqual(e3.cat_messages([(b'foo',None,),]), + b'D' + pack_head(7 + 4 + 4 + 2, 2) + b'\x00\x00\x00\x03foo\xFF\xFF\xFF\xFF') + # tuple(foo,\N,bar) + self.assertEqual(e3.cat_messages([(b'foo',None,b'bar'),]), + b'D' + pack_head(7 + 7 + 4 + 4 + 2, 3) + \ + b'\x00\x00\x00\x03foo\xFF\xFF\xFF\xFF\x00\x00\x00\x03bar') + # too many attributes + self.assertRaises((OverflowError, struct.error), + e3.cat_messages, [(None,) * 0x10000]) + + class ThisEx(Exception): + pass + class ThatEx(Exception): + pass + class Bad(e3.Message): + def serialize(self): + raise ThisEx('foo') + self.assertRaises(ThisEx, e3.cat_messages, [Bad()]) + class NoType(e3.Message): + def serialize(self): + return b'' + self.assertRaises(AttributeError, e3.cat_messages, [NoType()]) + class BadType(e3.Message): + type = 123 + def serialize(self): + return b'' + self.assertRaises((TypeError,struct.error), e3.cat_messages, [BadType()]) + + + def testSerializeParseConsistency(self): + for msg in message_samples: + smsg = msg.serialize() + self.assertEqual(msg, msg.parse(smsg)) + + def testEmptyMessages(self): + for x in e3.__dict__.values(): + if isinstance(x, e3.EmptyMessage): + xtype = type(x) + self.assertTrue(x is xtype()) + + def testUnknownNoticeFields(self): + N = e3.Notice.parse(b'\x00\x00Z\x00Xklsvdnvldsvkndvlsn\x00Pfoobar\x00Mmessage\x00') + E = e3.Error.parse(b'Z\x00Xklsvdnvldsvkndvlsn\x00Pfoobar\x00Mmessage\x00\x00') + self.assertEqual(N[b'M'], b'message') + self.assertEqual(E[b'M'], b'message') + self.assertEqual(N[b'P'], b'foobar') + self.assertEqual(E[b'P'], b'foobar') + self.assertEqual(len(N), 4) + self.assertEqual(len(E), 4) + + def testCompleteExtracts(self): + x = e3.Complete(b'FOO BAR 1321') + self.assertEqual(x.extract_command(), b'FOO BAR') + self.assertEqual(x.extract_count(), 1321) + x = e3.Complete(b' CREATE TABLE 13210 ') + self.assertEqual(x.extract_command(), b'CREATE TABLE') + self.assertEqual(x.extract_count(), 13210) + x = e3.Complete(b' CREATE TABLE \t713210 ') + self.assertEqual(x.extract_command(), b'CREATE TABLE') + self.assertEqual(x.extract_count(), 713210) + x = e3.Complete(b' CREATE TABLE 0 \t13210 ') + self.assertEqual(x.extract_command(), b'CREATE TABLE') + self.assertEqual(x.extract_count(), 13210) + x = e3.Complete(b' 0 \t13210 ') + self.assertEqual(x.extract_command(), None) + 
self.assertEqual(x.extract_count(), 13210) + +## +# .protocol.xact3 tests +## + +xact_samples = [ + # Simple contrived exchange. + ( + ( + e3.Query(b"COMPLETE"), + ), ( + e3.Complete(b'COMPLETE'), + e3.Ready(b'I'), + ) + ), + ( + ( + e3.Query(b"ROW DATA"), + ), ( + e3.TupleDescriptor(( + (b'foo', 1, 1, 1, 1, 1, 1), + (b'bar', 1, 2, 1, 1, 1, 1), + )), + e3.Tuple((b'lame', b'lame')), + e3.Complete(b'COMPLETE'), + e3.Ready(b'I'), + ) + ), + ( + ( + e3.Query(b"ROW DATA"), + ), ( + e3.TupleDescriptor(( + (b'foo', 1, 1, 1, 1, 1, 1), + (b'bar', 1, 2, 1, 1, 1, 1), + )), + e3.Tuple((b'lame', b'lame')), + e3.Tuple((b'lame', b'lame')), + e3.Tuple((b'lame', b'lame')), + e3.Tuple((b'lame', b'lame')), + e3.Ready(b'I'), + ) + ), + ( + ( + e3.Query(b"NULL"), + ), ( + e3.Null(), + e3.Ready(b'I'), + ) + ), + ( + ( + e3.Query(b"COPY TO"), + ), ( + e3.CopyToBegin(1, [1,2]), + e3.CopyData(b'row1'), + e3.CopyData(b'row2'), + e3.CopyDone(), + e3.Complete(b'COPY TO'), + e3.Ready(b'I'), + ) + ), + ( + ( + e3.Function(1, [b''], [b''], 1), + ), ( + e3.FunctionResult(b'foo'), + e3.Ready(b'I'), + ) + ), + ( + ( + e3.Parse(b"NAME", b"SQL", ()), + ), ( + e3.ParseComplete(), + ) + ), + ( + ( + e3.Bind(b"NAME", b"STATEMENT_ID", (), (), ()), + ), ( + e3.BindComplete(), + ) + ), + ( + ( + e3.Parse(b"NAME", b"SQL", ()), + e3.Bind(b"NAME", b"STATEMENT_ID", (), (), ()), + ), ( + e3.ParseComplete(), + e3.BindComplete(), + ) + ), + ( + ( + e3.Describe(b"STATEMENT_ID"), + ), ( + e3.AttributeTypes(()), + e3.NoData(), + ) + ), + ( + ( + e3.Describe(b"STATEMENT_ID"), + ), ( + e3.AttributeTypes(()), + e3.TupleDescriptor(()), + ) + ), + ( + ( + e3.CloseStatement(b"foo"), + ), ( + e3.CloseComplete(), + ), + ), + ( + ( + e3.ClosePortal(b"foo"), + ), ( + e3.CloseComplete(), + ), + ), + ( + ( + e3.Synchronize(), + ), ( + e3.Ready(b'I'), + ), + ), +] + +class test_xact3(unittest.TestCase): + def testTransactionSamplesAll(self): + for xcmd, xres in xact_samples: + x = x3.Instruction(xcmd) + r = tuple([(y.type, y.serialize()) for y in xres]) + x.state[1]() + self.assertEqual(x.messages, ()) + x.state[1](r) + self.assertEqual(x.state, x3.Complete) + rec = [] + for y in x.completed: + for z in y[1]: + if type(z) is type(b''): + z = e3.CopyData(z) + rec.append(z) + self.assertEqual(xres, tuple(rec)) + + def testClosing(self): + c = x3.Closing() + self.assertEqual(c.messages, (e3.DisconnectMessage,)) + c.state[1]() + self.assertEqual(c.fatal, True) + self.assertEqual(c.error_message.__class__, e3.ClientError) + self.assertEqual(c.error_message[b'C'], '08003') + + def testNegotiation(self): + # simple successful run + n = x3.Negotiation({}, b'') + n.state[1]() + n.state[1]( + pairs( + e3.Notice(((b'M', b"foobar"),)), + e3.Authentication(e3.AuthRequest_OK, b''), + e3.KillInformation(0,0), + e3.ShowOption(b'name', b'val'), + e3.Ready(b'I'), + ) + ) + self.assertEqual(n.state, x3.Complete) + self.assertEqual(n.last_ready.xact_state, b'I') + # no killinfo.. should cause protocol error... + n = x3.Negotiation({}, b'') + n.state[1]() + n.state[1]( + pairs( + e3.Notice(((b'M', b"foobar"),)), + e3.Authentication(e3.AuthRequest_OK, b''), + e3.ShowOption(b'name', b'val'), + e3.Ready(b'I'), + ) + ) + self.assertEqual(n.state, x3.Complete) + self.assertEqual(n.last_ready, None) + self.assertEqual(n.error_message[b'C'], '08P01') + # killinfo twice.. must cause protocol error... 
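+		# A second KillInformation during startup is unexpected and should be
+		# reported as a protocol violation (SQLSTATE 08P01).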
+ n = x3.Negotiation({}, b'') + n.state[1]() + n.state[1]( + pairs( + e3.Notice(((b'M', b"foobar"),)), + e3.Authentication(e3.AuthRequest_OK, b''), + e3.ShowOption(b'name', b'val'), + e3.KillInformation(0,0), + e3.KillInformation(0,0), + e3.Ready(b'I'), + ) + ) + self.assertEqual(n.state, x3.Complete) + self.assertEqual(n.last_ready, None) + self.assertEqual(n.error_message[b'C'], '08P01') + # start with ready message.. + n = x3.Negotiation({}, b'') + n.state[1]() + n.state[1]( + pairs( + e3.Notice(((b'M', b"foobar"),)), + e3.Ready(b'I'), + e3.Authentication(e3.AuthRequest_OK, b''), + e3.ShowOption(b'name', b'val'), + ) + ) + self.assertEqual(n.state, x3.Complete) + self.assertEqual(n.last_ready, None) + self.assertEqual(n.error_message[b'C'], '08P01') + # unsupported authreq + n = x3.Negotiation({}, b'') + n.state[1]() + n.state[1]( + pairs( + e3.Authentication(255, b''), + ) + ) + self.assertEqual(n.state, x3.Complete) + self.assertEqual(n.last_ready, None) + self.assertEqual(n.error_message[b'C'], '--AUT') + + def testInstructionAsynchook(self): + l = [] + def hook(data): + l.append(data) + x = x3.Instruction([ + e3.Query(b"NOTHING") + ], asynchook = hook) + a1 = e3.Notice(((b'M', b"m1"),)) + a2 = e3.Notify(0, b'relation', b'parameter') + a3 = e3.ShowOption(b'optname', b'optval') + # "send" the query message + x.state[1]() + # "receive" the tuple + x.state[1]([(a1.type, a1.serialize()),]) + a2l = [(a2.type, a2.serialize()),] + x.state[1](a2l) + # validate that the hook is not fed twice because + # it's the exact same message set. (later assertion will validate) + x.state[1](a2l) + x.state[1]([(a3.type, a3.serialize()),]) + # we only care about validating that l got everything. + self.assertEqual([a1,a2,a3], l) + self.assertEqual(x.state[0], x3.Receiving) + # validate that the asynchook exception is trapped. + class Nee(Exception): + pass + def ehook(msg): + raise Nee("this should **not** be part of the summary") + x = x3.Instruction([ + e3.Query(b"NOTHING") + ], asynchook = ehook) + a1 = e3.Notice(((b'M', b"m1"),)) + x.state[1]() + import sys + v = None + def exchook(typ, val, tb): + nonlocal v + v = val + seh = sys.excepthook + sys.excepthook = exchook + # we only care about validating that the exchook got called. 
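+		# Deliver the Notice; the failing asynchook must be routed through
+		# sys.excepthook rather than interrupting the transaction.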
+ x.state[1]([(a1.type, a1.serialize())]) + sys.excepthook = seh + self.assertTrue(isinstance(v, Nee)) + +class test_client3(unittest.TestCase): + def test_timeout(self): + portnum = find_available_port() + servsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + with servsock: + servsock.bind(('localhost', portnum)) + pc = c3.Connection( + SocketFactory( + (socket.AF_INET, socket.SOCK_STREAM), + ('localhost', portnum) + ), + {} + ) + pc.connect(timeout = 1) + try: + self.assertEqual(pc.xact.fatal, True) + self.assertEqual(pc.xact.__class__, x3.Negotiation) + finally: + if pc.socket is not None: + pc.socket.close() + + def test_SSL_failure(self): + portnum = find_available_port() + servsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + with servsock: + servsock.bind(('localhost', portnum)) + pc = c3.Connection( + SocketFactory( + (socket.AF_INET, socket.SOCK_STREAM), + ('localhost', portnum) + ), + {} + ) + exc = None + servsock.listen(1) + def client_thread(): + pc.connect(ssl = True) + client = Thread(target = client_thread) + try: + client.start() + c, addr = servsock.accept() + with c: + c.send(b'S') + c.sendall(b'0000000000000000000000') + c.recv(1024) + c.close() + client.join() + finally: + if pc.socket is not None: + pc.socket.close() + + self.assertEqual(pc.xact.fatal, True) + self.assertEqual(pc.xact.__class__, x3.Negotiation) + self.assertEqual(pc.xact.error_message.__class__, e3.ClientError) + self.assertTrue(hasattr(pc.xact, 'exception')) + + def test_bad_negotiation(self): + portnum = find_available_port() + servsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + servsock.bind(('localhost', portnum)) + pc = c3.Connection( + SocketFactory( + (socket.AF_INET, socket.SOCK_STREAM), + ('localhost', portnum) + ), + {} + ) + exc = None + servsock.listen(1) + def client_thread(): + pc.connect() + client = Thread(target = client_thread) + try: + client.start() + c, addr = servsock.accept() + try: + c.recv(1024) + finally: + c.close() + time.sleep(0.25) + client.join() + servsock.close() + self.assertEqual(pc.xact.fatal, True) + self.assertEqual(pc.xact.__class__, x3.Negotiation) + self.assertEqual(pc.xact.error_message.__class__, e3.ClientError) + self.assertEqual(pc.xact.error_message[b'C'], '08006') + finally: + servsock.close() + if pc.socket is not None: + pc.socket.close() + +if __name__ == '__main__': + from types import ModuleType + this = ModuleType("this") + this.__dict__.update(globals()) + try: + unittest.main(this) + finally: + import gc + gc.collect() diff --git a/py_opengauss/test/test_python.py b/py_opengauss/test/test_python.py new file mode 100644 index 0000000000000000000000000000000000000000..33bb39c22a6952c26cd9ed88430f4b6926149b97 --- /dev/null +++ b/py_opengauss/test/test_python.py @@ -0,0 +1,180 @@ +## +# .test.test_python +## +import unittest +import socket +import errno +import struct +from itertools import chain +from operator import methodcaller +from contextlib import contextmanager + +from ..python.itertools import interlace +from ..python.structlib import split_sized_data +from ..python import functools +from ..python import itertools +from ..python.socket import find_available_port +from ..python import element + +class Ele(element.Element): + _e_label = property( + lambda x: getattr(x, 'label', 'ELEMENT') + ) + _e_factors = ('ancestor', 'secondary') + secondary = None + + def __init__(self, s = None): + self.ancestor = s + + def __str__(self): + return 'STRDATA' + + def _e_metas(self): + yield ('first', getattr(self, 'first', 
'firstv')) + yield ('second', getattr(self, 'second', 'secondv')) + +class test_element(unittest.TestCase): + def test_primary_factor(self): + x = Ele() + # no factors + self.assertEqual(element.prime_factor(object()), None) + self.assertEqual(element.prime_factor(x), ('ancestor', None)) + y = Ele(x) + self.assertEqual(element.prime_factor(y), ('ancestor', x)) + + def test_primary_factors(self): + x = Ele() + x.ancestor = x + self.assertRaises( + element.RecursiveFactor, list, element.prime_factors(x) + ) + y = Ele(x) + x.ancestor = y + self.assertRaises( + element.RecursiveFactor, list, element.prime_factors(y) + ) + self.assertRaises( + element.RecursiveFactor, list, element.prime_factors(x) + ) + x.ancestor = None + z = Ele(y) + self.assertEqual(list(element.prime_factors(z)), [ + ('ancestor', y), + ('ancestor', x), + ('ancestor', None), + ]) + + def test_format_element(self): + # Considering that this is subject to change, frequently, + # I/O equality tests are inappropriate. + # Rather, a hierarchy will be defined, and the existence + # of certain pieces of information in the string will be validated. + x = Ele() + y = Ele() + z = Ele() + alt1 = Ele() + alt2 = Ele() + alt1.first = 'alt1-first' + alt1.second = 'alt1-second' + alt2.first = 'alt2-first' + alt2.second = 'alt2-second' + altprime = Ele() + altprime.first = 'alt2-ancestor' + alt2.ancestor = altprime + z.ancestor = y + y.ancestor = x + z.secondary = alt1 + y.secondary = alt2 + x.first = 'unique1' + y.first = 'unique2' + x.second = 'unique3' + z.second = 'unique4' + y.label = 'DIFF' + data = element.format_element(z) + self.assertTrue(x.first in data) + self.assertTrue(y.first in data) + self.assertTrue(x.second in data) + self.assertTrue(z.second in data) + self.assertTrue('DIFF' in data) + self.assertTrue('alt1-first' in data) + self.assertTrue('alt2-first' in data) + self.assertTrue('alt1-second' in data) + self.assertTrue('alt2-second' in data) + self.assertTrue('alt2-ancestor' in data) + x.ancestor = z + self.assertRaises(element.RecursiveFactor, element.format_element, z) + +class test_itertools(unittest.TestCase): + def testInterlace(self): + i1 = range(0, 100, 4) + i2 = range(1, 100, 4) + i3 = range(2, 100, 4) + i4 = range(3, 100, 4) + self.assertEqual( + list(itertools.interlace(i1, i2, i3, i4)), + list(range(100)) + ) + +class test_functools(unittest.TestCase): + def testComposition(self): + compose = functools.Composition + simple = compose((int, str)) + self.assertEqual("100", simple("100")) + timesfour_fourtimes = compose((methodcaller('__mul__', 4),)*4) + self.assertEqual(4*(4*4*4*4), timesfour_fourtimes(4)) + nothing = compose(()) + self.assertEqual(nothing("100"), "100") + self.assertEqual(nothing(100), 100) + self.assertEqual(nothing(None), None) + + def testRSetAttr(self): + class anob(object): + pass + ob = anob() + self.assertRaises(AttributeError, getattr, ob, 'foo') + rob = functools.rsetattr('foo', 'bar', ob) + self.assertTrue(rob is ob) + self.assertTrue(rob.foo is ob.foo) + self.assertTrue(rob.foo == 'bar') + +class test_socket(unittest.TestCase): + def testFindAvailable(self): + # Host sanity check; this is likely fragile. 
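+		# find_available_port() should hand back a port with no listener, so a
+		# connection attempt must fail with ECONNREFUSED.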
+ for i in range(4): + portnum = find_available_port() + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + s.connect(('localhost', portnum)) + except socket.error as err: + self.assertEqual(err.errno, errno.ECONNREFUSED) + else: + self.fail("got a connection to an available port: " + str(portnum)) + finally: + s.close() + +def join_sized_data(*data, + packL = struct.Struct("!L").pack, + getlen = lambda x: len(x) if x is not None else 0xFFFFFFFF +): + return b''.join(interlace(map(packL, map(getlen, data)), (x if x is not None else b'' for x in data))) + +class test_structlib(unittest.TestCase): + def testSizedSplit(self): + sample = [ + (b'foo', b'bar'), + (b'foo', None, b'bar'), + (b'foo', None, b'bar'), + (b'foo', b'bar'), + (), + (None,None,None), + (b'x', None,None,None, b'yz'), + ] + packed_sample = [join_sized_data(*x) for x in sample] + self.assertRaises(ValueError, split_sized_data(b'\xFF\xFF\xFF\x01foo').__next__) + self.assertEqual(sample, [tuple(split_sized_data(x)) for x in packed_sample]) + +if __name__ == '__main__': + from types import ModuleType + this = ModuleType("this") + this.__dict__.update(globals()) + unittest.main(this) diff --git a/py_opengauss/test/test_ssl_connect.py b/py_opengauss/test/test_ssl_connect.py new file mode 100644 index 0000000000000000000000000000000000000000..fc72c86c7d961ab5b36d8bbbd0539f0bd2e20698 --- /dev/null +++ b/py_opengauss/test/test_ssl_connect.py @@ -0,0 +1,294 @@ +## +# .test.test_ssl_connect +## +import sys +import os +import unittest + +from .. import exceptions as pg_exc +from .. import driver as pg_driver +from ..driver import dbapi20 +from . import test_connect + +default_installation = test_connect.default_installation + +has_ssl = False +if default_installation is not None: + has_ssl = default_installation.ssl + +server_key = """ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQCy8veVaqL6MZVT8o0j98ggZYfibGwSN4XGC4rfineA2QZhi8t+ +zrzfOS10vLXKtgiIpevHeQbDlrqFDPUDowozurg+jfro2L1jzQjZPdgqOUs+YjKh +EO0Ya7NORO7ZgBx8WveXq30k4l8DK41jvpxRyBb9aqNWG4cB7fJqVTwZrwIDAQAB +AoGAJ74URGfheEVoz7MPq4xNMvy5mAzSV51jJV/M4OakscYBR8q/UBNkGQNe2A1N +Jo8VCBwpaCy11txz4jbFd6BPFFykgXleuRvMxoTv1qV0dZZ0X0ESNEAnjoHtjin/ +25mxsZTR6ucejHqXD9qE9NvFQ+wLv6Xo5rgDpx0onvgLA3kCQQDn4GeMkCfPZCve +lDUK+TpJnLYupyElZiidoFMITlFo5WoWNJror2W42A5TD9sZ23pGSxw7ypiWIF4f +ukGT5ZSzAkEAxZDwUUhgtoJIK7E9sCJM4AvcjDxGjslbUI/SmQTT+aTNCAmcIRrl +kq3WMkPjxi/QFEdkIpPsV9Kc94oQ/8b9FQJBAKHxRQCTsWoTsNvbsIwAcif1Lfu5 +N9oR1i34SeVUJWFYUFY/2SzHSwjkxGRYf5I4idZMIOTVYun+ox4PjDtJrScCQEQ4 +RiNrIKok1pLvwuNdFLqQnfl2ns6TTQrGfuwDtMaRV5Mc7mKoDPnXOQ1mT/KRdAJs +nHEsLwIsYbNAY5pOtfkCQDOy2Ffe7Z1YzFZXCTzpcq4mvMOPEUqlIX6hACNJGhgt +1EpruPwqR2PYDOIC4sXCaSogL8YyjI+Jlhm5kEJ4GaU= +-----END RSA PRIVATE KEY----- +""" + +server_crt = """ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: + a1:02:62:34:22:0d:45:6a + Signature Algorithm: md5WithRSAEncryption + Issuer: C=US, ST=Arizona, L=Nowhere, O=ACME Inc, OU=Test Division, CN=test.python.projects.postgresql.org + Validity + Not Before: Feb 18 15:52:20 2009 GMT + Not After : Mar 20 15:52:20 2009 GMT + Subject: C=US, ST=Arizona, L=Nowhere, O=ACME Inc, OU=Test Division, CN=test.python.projects.postgresql.org + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + RSA Public Key: (1024 bit) + Modulus (1024 bit): + 00:b2:f2:f7:95:6a:a2:fa:31:95:53:f2:8d:23:f7: + c8:20:65:87:e2:6c:6c:12:37:85:c6:0b:8a:df:8a: + 77:80:d9:06:61:8b:cb:7e:ce:bc:df:39:2d:74:bc: + b5:ca:b6:08:88:a5:eb:c7:79:06:c3:96:ba:85:0c: + f5:03:a3:0a:33:ba:b8:3e:8d:fa:e8:d8:bd:63:cd: + 
08:d9:3d:d8:2a:39:4b:3e:62:32:a1:10:ed:18:6b: + b3:4e:44:ee:d9:80:1c:7c:5a:f7:97:ab:7d:24:e2: + 5f:03:2b:8d:63:be:9c:51:c8:16:fd:6a:a3:56:1b: + 87:01:ed:f2:6a:55:3c:19:af + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Subject Key Identifier: + 4B:2F:4F:1A:43:75:43:DC:26:59:89:48:56:73:BB:D0:AA:95:E8:60 + X509v3 Authority Key Identifier: + keyid:4B:2F:4F:1A:43:75:43:DC:26:59:89:48:56:73:BB:D0:AA:95:E8:60 + DirName:/C=US/ST=Arizona/L=Nowhere/O=ACME Inc/OU=Test Division/CN=test.python.projects.postgresql.org + serial:A1:02:62:34:22:0D:45:6A + + X509v3 Basic Constraints: + CA:TRUE + Signature Algorithm: md5WithRSAEncryption + 24:ee:20:0f:b5:86:08:d6:3c:8f:d4:8d:16:fd:ac:e8:49:77: + 86:74:7d:b8:f3:15:51:1d:d8:65:17:5e:a8:58:aa:b0:f6:68: + 45:cb:77:9d:9f:21:81:e3:5e:86:1c:64:31:39:b6:29:5f:f1: + ec:b1:33:45:1f:0c:54:16:26:11:af:e2:23:1b:a6:03:46:9b: + 0e:63:ce:2c:02:41:26:93:bc:6f:6e:08:7e:95:b7:7a:f9:3a: + 5a:bd:47:4c:92:ce:ea:09:75:de:3d:bb:30:51:a0:c5:f1:5d: + 33:5f:c0:37:75:53:4e:6c:b4:3b:b1:a5:1b:fd:59:19:07:18: + 22:6a +-----BEGIN CERTIFICATE----- +MIIDhzCCAvCgAwIBAgIJAKECYjQiDUVqMA0GCSqGSIb3DQEBBAUAMIGKMQswCQYD +VQQGEwJVUzEQMA4GA1UECBMHQXJpem9uYTEQMA4GA1UEBxMHTm93aGVyZTERMA8G +A1UEChMIQUNNRSBJbmMxFjAUBgNVBAsTDVRlc3QgRGl2aXNpb24xLDAqBgNVBAMT +I3Rlc3QucHl0aG9uLnByb2plY3RzLnBvc3RncmVzcWwub3JnMB4XDTA5MDIxODE1 +NTIyMFoXDTA5MDMyMDE1NTIyMFowgYoxCzAJBgNVBAYTAlVTMRAwDgYDVQQIEwdB +cml6b25hMRAwDgYDVQQHEwdOb3doZXJlMREwDwYDVQQKEwhBQ01FIEluYzEWMBQG +A1UECxMNVGVzdCBEaXZpc2lvbjEsMCoGA1UEAxMjdGVzdC5weXRob24ucHJvamVj +dHMucG9zdGdyZXNxbC5vcmcwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBALLy +95VqovoxlVPyjSP3yCBlh+JsbBI3hcYLit+Kd4DZBmGLy37OvN85LXS8tcq2CIil +68d5BsOWuoUM9QOjCjO6uD6N+ujYvWPNCNk92Co5Sz5iMqEQ7Rhrs05E7tmAHHxa +95erfSTiXwMrjWO+nFHIFv1qo1YbhwHt8mpVPBmvAgMBAAGjgfIwge8wHQYDVR0O +BBYEFEsvTxpDdUPcJlmJSFZzu9CqlehgMIG/BgNVHSMEgbcwgbSAFEsvTxpDdUPc +JlmJSFZzu9CqlehgoYGQpIGNMIGKMQswCQYDVQQGEwJVUzEQMA4GA1UECBMHQXJp +em9uYTEQMA4GA1UEBxMHTm93aGVyZTERMA8GA1UEChMIQUNNRSBJbmMxFjAUBgNV +BAsTDVRlc3QgRGl2aXNpb24xLDAqBgNVBAMTI3Rlc3QucHl0aG9uLnByb2plY3Rz +LnBvc3RncmVzcWwub3JnggkAoQJiNCINRWowDAYDVR0TBAUwAwEB/zANBgkqhkiG +9w0BAQQFAAOBgQAk7iAPtYYI1jyP1I0W/azoSXeGdH248xVRHdhlF16oWKqw9mhF +y3ednyGB416GHGQxObYpX/HssTNFHwxUFiYRr+IjG6YDRpsOY84sAkEmk7xvbgh+ +lbd6+TpavUdMks7qCXXePbswUaDF8V0zX8A3dVNObLQ7saUb/VkZBxgiag== +-----END CERTIFICATE----- +""" + +class test_ssl_connect(test_connect.test_connect): + """ + Run test_connect, but with SSL. 
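+
+	Exercises the 'require', 'disable', 'prefer', and 'allow' sslmode settings
+	against the hostssl/hostnossl pg_hba entries (the sslonly and nossl users).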
+ """ + params = {'sslmode' : 'require'} + cluster_path_suffix = '_test_ssl_connect' + + def configure_cluster(self): + if not has_ssl: + return + + super().configure_cluster() + self.cluster.settings['ssl'] = 'on' + with open(self.cluster.hba_file, 'a') as hba: + hba.writelines([ + # nossl user + "\n", + "hostnossl test nossl 0::0/0 trust\n", + "hostnossl test nossl 0.0.0.0/0 trust\n", + # ssl-only user + "hostssl test sslonly 0.0.0.0/0 trust\n", + "hostssl test sslonly 0::0/0 trust\n", + ]) + key_file = os.path.join(self.cluster.data_directory, 'server.key') + crt_file = os.path.join(self.cluster.data_directory, 'server.crt') + with open(key_file, 'w') as key: + key.write(server_key) + with open(crt_file, 'w') as crt: + crt.write(server_crt) + os.chmod(key_file, 0o700) + os.chmod(crt_file, 0o700) + + def initialize_database(self): + if not has_ssl: + return + + super().initialize_database() + with self.cluster.connection(user = 'test') as db: + db.execute( + """ +CREATE USER nossl; +CREATE USER sslonly; + """ + ) + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + @unittest.skipIf(not has_ssl, "could not detect installation tls") + def test_ssl_mode_require(self): + host, port = self.cluster.address() + params = dict(self.params) + params['sslmode'] = 'require' + try: + pg_driver.connect( + user = 'nossl', + database = 'test', + host = host, + port = port, + **params + ) + self.fail("successful connection to nossl user when sslmode = 'require'") + except pg_exc.ClientCannotConnectError as err: + for pq in err.database.failures: + x = pq.error + dossl = pq.ssl_negotiation + if isinstance(x, pg_exc.AuthenticationSpecificationError) and dossl is True: + break + else: + # let it show as a failure. + raise + with pg_driver.connect( + host = host, + port = port, + user = 'sslonly', + database = 'test', + **params + ) as c: + self.assertEqual(c.prepare('select 1').first(), 1) + self.assertEqual(c.security, 'ssl') + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + @unittest.skipIf(not has_ssl, "could not detect installation tls") + def test_ssl_mode_disable(self): + host, port = self.cluster.address() + params = dict(self.params) + params['sslmode'] = 'disable' + try: + pg_driver.connect( + user = 'sslonly', + database = 'test', + host = host, + port = port, + **params + ) + self.fail("successful connection to sslonly user with sslmode = 'disable'") + except pg_exc.ClientCannotConnectError as err: + for pq in err.database.failures: + x = pq.error + if isinstance(x, pg_exc.AuthenticationSpecificationError) and not hasattr(pq, 'ssl_negotiation'): + # looking for an authspec error... + break + else: + # let it show as a failure. 
+ raise + + with pg_driver.connect( + host = host, + port = port, + user = 'nossl', + database = 'test', + **params + ) as c: + self.assertEqual(c.prepare('select 1').first(), 1) + self.assertEqual(c.security, None) + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + @unittest.skipIf(not has_ssl, "could not detect installation tls") + def test_ssl_mode_prefer(self): + host, port = self.cluster.address() + params = dict(self.params) + params['sslmode'] = 'prefer' + with pg_driver.connect( + user = 'sslonly', + host = host, + port = port, + database = 'test', + **params + ) as c: + self.assertEqual(c.prepare('select 1').first(), 1) + self.assertEqual(c.security, 'ssl') + + with pg_driver.connect( + user = 'test', + host = host, + port = port, + database = 'test', + **params + ) as c: + self.assertEqual(c.security, 'ssl') + + with pg_driver.connect( + user = 'nossl', + host = host, + port = port, + database = 'test', + **params + ) as c: + self.assertEqual(c.prepare('select 1').first(), 1) + self.assertEqual(c.security, None) + + @unittest.skipIf(default_installation is None, "no installation provided by environment") + @unittest.skipIf(not has_ssl, "could not detect installation tls") + def test_ssl_mode_allow(self): + host, port = self.cluster.address() + params = dict(self.params) + params['sslmode'] = 'allow' + + # nossl user (hostnossl) + with pg_driver.connect( + user = 'nossl', + database = 'test', + host = host, + port = port, + **params + ) as c: + self.assertEqual(c.prepare('select 1').first(), 1) + self.assertEqual(c.security, None) + + # test user (host) + with pg_driver.connect( + user = 'test', + host = host, + port = port, + database = 'test', + **params + ) as c: + self.assertEqual(c.security, None) + + # sslonly user (hostssl) + with pg_driver.connect( + user = 'sslonly', + host = host, + port = port, + database = 'test', + **params + ) as c: + self.assertEqual(c.prepare('select 1').first(), 1) + self.assertEqual(c.security, 'ssl') + +if __name__ == '__main__': + unittest.main() diff --git a/py_opengauss/test/test_string.py b/py_opengauss/test/test_string.py new file mode 100644 index 0000000000000000000000000000000000000000..5d5089d18d075d2ecad7a65eaea711bf44f19807 --- /dev/null +++ b/py_opengauss/test/test_string.py @@ -0,0 +1,308 @@ +## +# .test.test_string +## +import sys +import os +import unittest +from .. import string as pg_str + +# strange possibility, split, normalized +split_qname_samples = [ + ('base', ['base'], 'base'), + ('bASe', ['base'], 'base'), + ('"base"', ['base'], 'base'), + ('"base "', ['base '], '"base "'), + ('" base"', [' base'], '" base"'), + ('" base"""', [' base"'], '" base"""'), + ('""" base"""', ['" base"'], '""" base"""'), + ('".base"', ['.base'], '".base"'), + ('".base."', ['.base.'], '".base."'), + ('schema.base', ['schema', 'base'], 'schema.base'), + ('"schema".base', ['schema', 'base'], 'schema.base'), + ('schema."base"', ['schema', 'base'], 'schema.base'), + ('"schema.base"', ['schema.base'], '"schema.base"'), + ('schEmÅ."base"', ['schemå', 'base'], 'schemå.base'), + ('scheMa."base"', ['schema', 'base'], 'schema.base'), + ('sche_ma.base', ['sche_ma', 'base'], 'sche_ma.base'), + ('_schema.base', ['_schema', 'base'], '_schema.base'), + ('a000.b111', ['a000', 'b111'], 'a000.b111'), + ('" schema"."base"', [' schema', 'base'], '" schema".base'), + ('" schema"."ba se"', [' schema', 'ba se'], '" schema"."ba se"'), + ('" ""schema"."ba""se"', [' "schema', 'ba"se'], '" ""schema"."ba""se"'), + ('" schema" . 
"ba se"', [' schema', 'ba se'], '" schema"."ba se"'), + (' " schema" . "ba se" ', [' schema', 'ba se'], '" schema"."ba se"'), + (' ". schema." . "ba se" ', ['. schema.', 'ba se'], '". schema."."ba se"'), + ('CAT . ". schema." . "ba se" ', ['cat', '. schema.', 'ba se'], + 'cat.". schema."."ba se"'), + ('"cat" . ". schema." . "ba se" ', ['cat', '. schema.', 'ba se'], + 'cat.". schema."."ba se"'), + ('"""cat" . ". schema." . "ba se" ', ['"cat', '. schema.', 'ba se'], + '"""cat".". schema."."ba se"'), + ('"""cÅt" . ". schema." . "ba se" ', ['"cÅt', '. schema.', 'ba se'], + '"""cÅt".". schema."."ba se"'), +] + +split_samples = [ + ('', ['']), + ('one-to-one', ['one-to-one']), + ('"one-to-one"', [ + '', + ('"', 'one-to-one'), + '' + ]), + ('$$one-to-one$$', [ + '', + ('$$', 'one-to-one'), + '' + ]), + ("E'one-to-one'", [ + '', + ("E'", 'one-to-one'), + '' + ]), + ("E'on''e-to-one'", [ + '', + ("E'", "on''e-to-one"), + '' + ]), + ("E'on''e-to-\\'one'", [ + '', + ("E'", "on''e-to-\\'one"), + '' + ]), + ("'one\\'-to-one'", [ + '', + ("'", "one\\"), + "-to-one", + ("'", ''), + ]), + + ('"foo"""', [ + '', + ('"', 'foo""'), + '', + ]), + + ('"""foo"', [ + '', + ('"', '""foo'), + '', + ]), + + ("'''foo'", [ + '', + ("'", "''foo"), + '', + ]), + ("'foo'''", [ + '', + ("'", "foo''"), + '', + ]), + ("E'foo\\''", [ + '', + ("E'", "foo\\'"), + '', + ]), + (r"E'foo\\' '", [ + '', + ("E'", r"foo\\"), + ' ', + ("'", ''), + ]), + (r"E'foo\\'' '", [ + '', + ("E'", r"foo\\'' "), + '', + ]), + + ('select \'foo\' as "one"', [ + 'select ', + ("'", 'foo'), + ' as ', + ('"', 'one'), + '' + ]), + ('select $$foo$$ as "one"', [ + 'select ', + ("$$", 'foo'), + ' as ', + ('"', 'one'), + '' + ]), + ('select $b$foo$b$ as "one"', [ + 'select ', + ("$b$", 'foo'), + ' as ', + ('"', 'one'), + '' + ]), + ('select $b$', [ + 'select ', + ('$b$', ''), + ]), + + ('select $1', [ + 'select $1', + ]), + + ('select $1$', [ + 'select $1$', + ]), +] + +split_sql_samples = [ + ('select 1; select 1', [ + ['select 1'], + [' select 1'] + ]), + ('select \'one\' as "text"; select 1', [ + ['select ', ("'", 'one'), ' as ', ('"', 'text'), ''], + [' select 1'] + ]), + ('select \'one\' as "text"; select 1', [ + ['select ', ("'", 'one'), ' as ', ('"', 'text'), ''], + [' select 1'] + ]), + ('select \'one;\' as ";text;"; select 1; foo', [ + ['select ', ("'", 'one;'), ' as ', ('"', ';text;'), ''], + (' select 1',), + [' foo'], + ]), + ('select \'one;\' as ";text;"; select $$;$$; foo', [ + ['select ', ("'", 'one;'), ' as ', ('"', ';text;'), ''], + [' select ', ('$$', ';'), ''], + [' foo'], + ]), + ('select \'one;\' as ";text;"; select $$;$$; foo;\';b\'\'ar\'', [ + ['select ', ("'", 'one;'), ' as ', ('"', ';text;'), ''], + [' select ', ('$$', ';'), ''], + (' foo',), + ['', ("'", ";b''ar"), ''], + ]), +] + +class test_strings(unittest.TestCase): + def test_split(self): + for unsplit, split in split_samples: + xsplit = list(pg_str.split(unsplit)) + self.assertEqual(xsplit, split) + self.assertEqual(pg_str.unsplit(xsplit), unsplit) + + def test_split_sql(self): + for unsplit, split in split_sql_samples: + xsplit = list(pg_str.split_sql(unsplit)) + self.assertEqual(xsplit, split) + self.assertEqual(';'.join([pg_str.unsplit(x) for x in xsplit]), unsplit) + + def test_qname(self): + "indirectly tests split_using" + for unsplit, split, norm in split_qname_samples: + xsplit = pg_str.split_qname(unsplit) + self.assertEqual(xsplit, split) + self.assertEqual(pg_str.qname_if_needed(*split), norm) + + self.assertRaises( + ValueError, + pg_str.split_qname, '"foo' + ) + 
self.assertRaises( + ValueError, + pg_str.split_qname, 'foo"' + ) + self.assertRaises( + ValueError, + pg_str.split_qname, 'bar.foo"' + ) + self.assertRaises( + ValueError, + pg_str.split_qname, 'bar".foo"' + ) + self.assertRaises( + ValueError, + pg_str.split_qname, '0bar.foo' + ) + self.assertRaises( + ValueError, + pg_str.split_qname, 'bar.fo@' + ) + + def test_quotes(self): + self.assertEqual( + pg_str.quote_literal("""foo'bar"""), + """'foo''bar'""" + ) + self.assertEqual( + pg_str.quote_literal("""\\foo'bar\\"""), + """'\\foo''bar\\'""" + ) + self.assertEqual( + pg_str.quote_ident_if_needed("foo"), + "foo" + ) + self.assertEqual( + pg_str.quote_ident_if_needed("0foo"), + '"0foo"' + ) + self.assertEqual( + pg_str.quote_ident_if_needed("foo0"), + 'foo0' + ) + self.assertEqual( + pg_str.quote_ident_if_needed("_"), + '_' + ) + self.assertEqual( + pg_str.quote_ident_if_needed("_9"), + '_9' + ) + self.assertEqual( + pg_str.quote_ident_if_needed('''\\foo'bar\\'''), + '''"\\foo'bar\\"''' + ) + self.assertEqual( + pg_str.quote_ident("spam"), + '"spam"' + ) + self.assertEqual( + pg_str.qname("spam", "ham"), + '"spam"."ham"' + ) + self.assertEqual( + pg_str.escape_ident('"'), + '""', + ) + self.assertEqual( + pg_str.escape_ident('""'), + '""""', + ) + chars = ''.join([ + chr(x) for x in range(10000) + if chr(x) != '"' + ]) + self.assertEqual( + pg_str.escape_ident(chars), + chars, + ) + chars = ''.join([ + chr(x) for x in range(10000) + if chr(x) != "'" + ]) + self.assertEqual( + pg_str.escape_literal(chars), + chars, + ) + chars = ''.join([ + chr(x) for x in range(10000) + if chr(x) not in "\\'" + ]) + self.assertEqual( + pg_str.escape_literal(chars), + chars, + ) + +if __name__ == '__main__': + from types import ModuleType + this = ModuleType("this") + this.__dict__.update(globals()) + unittest.main(this) diff --git a/py_opengauss/test/test_types.py b/py_opengauss/test/test_types.py new file mode 100644 index 0000000000000000000000000000000000000000..c552b9d4b1ba9c9ae4866770d820a09890132fd9 --- /dev/null +++ b/py_opengauss/test/test_types.py @@ -0,0 +1,617 @@ +## +# .test.test_types - test type representations and I/O +## +import unittest +import struct +from ..python.functools import process_tuple +from .. 
import types as pg_types +from ..types.io import lib as typlib +from ..types.io import builtins +from ..types.io.contrib_hstore import hstore_factory +from ..types import Array + +class fake_typio(object): + @staticmethod + def encode(x): + return x.encode('utf-8') + @staticmethod + def decode(x): + return x.decode('utf-8') +hstore_pack, hstore_unpack = hstore_factory(0, fake_typio) + +# this must pack to that, and +# that must unpack to this +expectation_samples = { + ('bool', lambda x: builtins.bool_pack(x), lambda x: builtins.bool_unpack(x)) : [ + (True, b'\x01'), + (False, b'\x00'), + ], + + ('int2', builtins.int2_pack, builtins.int2_unpack) : [ + (0, b'\x00\x00'), + (1, b'\x00\x01'), + (2, b'\x00\x02'), + (0x0f, b'\x00\x0f'), + (0xf00, b'\x0f\x00'), + (0x7fff, b'\x7f\xff'), + (-0x8000, b'\x80\x00'), + (-1, b'\xff\xff'), + (-2, b'\xff\xfe'), + (-3, b'\xff\xfd'), + ], + + ('int4', builtins.int4_pack, builtins.int4_unpack) : [ + (0, b'\x00\x00\x00\x00'), + (1, b'\x00\x00\x00\x01'), + (2, b'\x00\x00\x00\x02'), + (0x0f, b'\x00\x00\x00\x0f'), + (0x7fff, b'\x00\x00\x7f\xff'), + (-0x8000, b'\xff\xff\x80\x00'), + (0x7fffffff, b'\x7f\xff\xff\xff'), + (-0x80000000, b'\x80\x00\x00\x00'), + (-1, b'\xff\xff\xff\xff'), + (-2, b'\xff\xff\xff\xfe'), + (-3, b'\xff\xff\xff\xfd'), + ], + + ('int8', builtins.int8_pack, builtins.int8_unpack) : [ + (0, b'\x00\x00\x00\x00\x00\x00\x00\x00'), + (1, b'\x00\x00\x00\x00\x00\x00\x00\x01'), + (2, b'\x00\x00\x00\x00\x00\x00\x00\x02'), + (0x0f, b'\x00\x00\x00\x00\x00\x00\x00\x0f'), + (0x7fffffff, b'\x00\x00\x00\x00\x7f\xff\xff\xff'), + (0x80000000, b'\x00\x00\x00\x00\x80\x00\x00\x00'), + (-0x80000000, b'\xff\xff\xff\xff\x80\x00\x00\x00'), + (-1, b'\xff\xff\xff\xff\xff\xff\xff\xff'), + (-2, b'\xff\xff\xff\xff\xff\xff\xff\xfe'), + (-3, b'\xff\xff\xff\xff\xff\xff\xff\xfd'), + ], + + ('numeric', typlib.numeric_pack, typlib.numeric_unpack) : [ + (((0,0,0,0),[]), b'\x00'*2*4), + (((0,0,0,0),[1]), b'\x00'*2*4 + b'\x00\x01'), + (((1,0,0,0),[1]), b'\x00\x01' + b'\x00'*2*3 + b'\x00\x01'), + (((1,1,1,1),[1]), b'\x00\x01'*4 + b'\x00\x01'), + (((1,1,1,1),[1,2]), b'\x00\x01'*4 + b'\x00\x01\x00\x02'), + (((1,1,1,1),[1,2,3]), b'\x00\x01'*4 + b'\x00\x01\x00\x02\x00\x03'), + ], + + ('varbit', typlib.varbit_pack, typlib.varbit_unpack) : [ + ((0, b'\x00'), b'\x00\x00\x00\x00\x00'), + ((1, b'\x01'), b'\x00\x00\x00\x01\x01'), + ((1, b'\x00'), b'\x00\x00\x00\x01\x00'), + ((2, b'\x00'), b'\x00\x00\x00\x02\x00'), + ((3, b'\x00'), b'\x00\x00\x00\x03\x00'), + ((9, b'\x00\x00'), b'\x00\x00\x00\x09\x00\x00'), + # More data than necessary, we allow this. + # Let the user do the necessary check if the cost is worth the benefit. 
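+		# (Nine bits only requires two bytes; the extra trailing byte is carried through.)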
+ ((9, b'\x00\x00\x00'), b'\x00\x00\x00\x09\x00\x00\x00'), + ], + + # idk why + ('bytea', builtins.bytea_pack, builtins.bytea_unpack) : [ + (b'foo', b'foo'), + (b'bar', b'bar'), + (b'\x00', b'\x00'), + (b'\x01', b'\x01'), + ], + + ('char', builtins.char_pack, builtins.char_unpack) : [ + (b'a', b'a'), + (b'b', b'b'), + (b'\x00', b'\x00'), + ], + + ('point', typlib.point_pack, typlib.point_unpack) : [ + ((1.0, 1.0), b'?\xf0\x00\x00\x00\x00\x00\x00?\xf0\x00\x00\x00\x00\x00\x00'), + ((2.0, 2.0), b'@\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00'), + ((-1.0, -1.0), + b'\xbf\xf0\x00\x00\x00\x00\x00\x00\xbf\xf0\x00\x00\x00\x00\x00\x00'), + ], + + ('circle', typlib.circle_pack, typlib.circle_unpack) : [ + ((1.0, 1.0, 1.0), + b'?\xf0\x00\x00\x00\x00\x00\x00?\xf0\x00\x00' \ + b'\x00\x00\x00\x00?\xf0\x00\x00\x00\x00\x00\x00'), + ((2.0, 2.0, 2.0), + b'@\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00' \ + b'\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00'), + ], + + ('record', typlib.record_pack, typlib.record_unpack) : [ + ([], b'\x00\x00\x00\x00'), + ([(0,b'foo')], b'\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03foo'), + ([(0,None)], b'\x00\x00\x00\x01\x00\x00\x00\x00\xff\xff\xff\xff'), + ([(15,None)], b'\x00\x00\x00\x01\x00\x00\x00\x0f\xff\xff\xff\xff'), + ([(0xffffffff,None)], b'\x00\x00\x00\x01\xff\xff\xff\xff\xff\xff\xff\xff'), + ([(0,None), (1,b'some')], + b'\x00\x00\x00\x02\x00\x00\x00\x00\xff\xff\xff\xff' \ + b'\x00\x00\x00\x01\x00\x00\x00\x04some'), + ], + + ('array', typlib.array_pack, typlib.array_unpack) : [ + ([0, 0xf, (1,), (0,), (b'foo',)], + b'\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x0f\x00\x00\x00\x01' \ + b'\x00\x00\x00\x00\x00\x00\x00\x03foo' + ), + ([0, 0xf, (1,), (0,), (None,)], + b'\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x0f\x00\x00\x00\x01' \ + b'\x00\x00\x00\x00\xff\xff\xff\xff' + ) + ], + + ('hstore', hstore_pack, hstore_unpack) : [ + ({}, b'\x00\x00\x00\x00'), + ({'b' : None}, b'\x00\x00\x00\x01\x00\x00\x00\x01b\xff\xff\xff\xff'), + ({'b' : 'k'}, b'\x00\x00\x00\x01\x00\x00\x00\x01b\x00\x00\x00\x01k'), + ({'foo' : 'bar'}, b'\x00\x00\x00\x01\x00\x00\x00\x03foo\x00\x00\x00\x03bar'), + ({'foo' : None}, b'\x00\x00\x00\x01\x00\x00\x00\x03foo\xff\xff\xff\xff'), + ], +} +expectation_samples[('box', typlib.box_pack, typlib.box_unpack)] = \ + expectation_samples[('lseg', typlib.lseg_pack, typlib.lseg_unpack)] = [ + ((1.0, 1.0, 1.0, 1.0), + b'?\xf0\x00\x00\x00\x00\x00\x00?\xf0' \ + b'\x00\x00\x00\x00\x00\x00?\xf0\x00\x00' \ + b'\x00\x00\x00\x00?\xf0\x00\x00\x00\x00\x00\x00'), + ((2.0, 2.0, 1.0, 1.0), + b'@\x00\x00\x00\x00\x00\x00\x00@\x00\x00' \ + b'\x00\x00\x00\x00\x00?\xf0\x00\x00\x00\x00' \ + b'\x00\x00?\xf0\x00\x00\x00\x00\x00\x00'), + ((-1.0, -1.0, 1.0, 1.0), + b'\xbf\xf0\x00\x00\x00\x00\x00\x00\xbf\xf0' \ + b'\x00\x00\x00\x00\x00\x00?\xf0\x00\x00\x00' \ + b'\x00\x00\x00?\xf0\x00\x00\x00\x00\x00\x00'), + ] + +expectation_samples[('oid', typlib.oid_pack, typlib.oid_unpack)] = \ + expectation_samples[('cid', typlib.cid_pack, typlib.cid_unpack)] = \ + expectation_samples[('xid', typlib.xid_pack, typlib.xid_unpack)] = [ + (0, b'\x00\x00\x00\x00'), + (1, b'\x00\x00\x00\x01'), + (2, b'\x00\x00\x00\x02'), + (0xf, b'\x00\x00\x00\x0f'), + (0xffffffff, b'\xff\xff\xff\xff'), + (0x7fffffff, b'\x7f\xff\xff\xff'), + ] + +# this must pack and then unpack back into this +consistency_samples = { + ('bool', lambda x: builtins.bool_pack(x), lambda x: builtins.bool_unpack(x)) : [True, False], + + ('record', typlib.record_pack, typlib.record_unpack) : [ + [], + [(0,b'foo')], + 
[(0,None)], + [(15,None)], + [(0xffffffff,None)], + [(0,None), (1,b'some')], + [(0,None), (1,b'some'), (0xffff, b"something_else\x00")], + [(0,None), (1,b"s\x00me"), (0xffff, b"\x00something_else\x00")], + ], + + ('array', typlib.array_pack, typlib.array_unpack) : [ + [0, 0xf, (), (), ()], + [0, 0xf, (0,), (0,), ()], + [0, 0xf, (1,), (0,), (b'foo',)], + [0, 0xf, (1,), (0,), (None,)], + [0, 0xf, (2,), (0,), (None,None)], + [0, 0xf, (2,), (0,), (b'foo',None)], + [0, 0xff, (2,), (0,), (None,b'foo',)], + [0, 0xffffffff, (3,), (0,), (None,b'foo',None)], + [1, 0xffffffff, (3,), (0,), (None,b'foo',None)], + [1, 0xffffffff, (3, 1), (0, 0), (None,b'foo',None)], + [1, 0xffffffff, (3, 2), (0, 0), (None,b'one',b'foo',b'two',None,b'three')], + ], + + # Just some random data; it's just an integer, so nothing fancy. + ('date', typlib.date_pack, typlib.date_unpack) : [ + 123, + 321, + 0x7FFFFFF, + -0x8000000, + ], + + ('point', typlib.point_pack, typlib.point_unpack) : [ + (0, 0), + (2, 2), + (-1, -1), + (-1.5, -1.2), + (1.5, 1.2), + ], + + ('circle', typlib.circle_pack, typlib.circle_unpack) : [ + (0, 0, 0), + (2, 2, 2), + (-1, -1, -1), + (-1.5, -1.2, -1.8), + ], + + ('tid', typlib.tid_pack, typlib.tid_unpack) : [ + (0, 0), + (1, 1), + (0xffffffff, 0xffff), + (0, 0xffff), + (0xffffffff, 0), + (0xffffffff // 2, 0xffff // 2), + ], +} +__ = { + ('cidr', typlib.net_pack, typlib.net_unpack) : [ + (0, 0, b"\x00\x00\x00\x00"), + (2, 0, b"\x00" * 4), + (2, 0, b"\xFF" * 4), + (2, 32, b"\xFF" * 4), + (3, 0, b"\x00\x00" * 16), + ], + + ('inet', typlib.net_pack, typlib.net_unpack) : [ + (2, 32, b"\x00\x00\x00\x00"), + (2, 16, b"\x7f\x00\x00\x01"), + (2, 8, b"\xff\x00\xff\x01"), + (3, 128, b"\x7f\x00" * 16), + (3, 64, b"\xff\xff" * 16), + (3, 32, b"\x00\x00" * 16), + ], +} + +consistency_samples[('time', typlib.time_pack, typlib.time_unpack)] = \ +consistency_samples[('time64', typlib.time64_pack, typlib.time64_unpack)] = [ + (0, 0), + (123, 123), + (0xFFFFFFFF, 999999), +] + +# months, days, (seconds, microseconds) +consistency_samples[('interval', typlib.interval_pack, typlib.interval_unpack)] = [ + (0, 0, (0, 0)), + (1, 0, (0, 0)), + (0, 1, (0, 0)), + (1, 1, (0, 0)), + (0, 0, (0, 10000)), + (0, 0, (1, 0)), + (0, 0, (1, 10000)), + (1, 1, (1, 10000)), + (100, 50, (1423, 29313)) +] + +consistency_samples[('timetz', typlib.timetz_pack, typlib.timetz_unpack)] = \ +consistency_samples[('timetz', typlib.timetz64_pack, typlib.timetz64_unpack)] = \ + [ + ((0, 0), 0), + ((123, 123), 123), + ((0xFFFFFFFF, 999999), -123), + ] + +consistency_samples[('oid', typlib.oid_pack, typlib.oid_unpack)] = \ + consistency_samples[('cid', typlib.cid_pack, typlib.cid_unpack)] = \ + consistency_samples[('xid', typlib.xid_pack, typlib.xid_unpack)] = [ + 0, 0xffffffff, 0xffffffff // 2, 123, 321, 1, 2, 3 +] + +consistency_samples[('lseg', typlib.lseg_pack, typlib.lseg_unpack)] = \ + consistency_samples[('box', typlib.box_pack, typlib.box_unpack)] = [ + (1,2,3,4), + (4,3,2,1), + (0,0,0,0), + (-1,-1,-1,-1), + (-1.2,-1.5,-2.0,4.0) +] + +consistency_samples[('path', typlib.path_pack, typlib.path_unpack)] = \ + consistency_samples[('polygon', typlib.polygon_pack, typlib.polygon_unpack)] = [ + (1,2,3,4), + (4,3,2,1), + (0,0,0,0), + (-1,-1,-1,-1), + (-1.2,-1.5,-2.0,4.0), +] + +from types import GeneratorType +def resolve(ob): + 'make sure generators get "tuplified"' + if type(ob) not in (list, tuple, GeneratorType): + return ob + return [resolve(x) for x in ob] + +def testExpectIO(self, samples): + for id, sample in samples.items(): + name, pack, 
unpack = id + + for (sample_unpacked, sample_packed) in sample: + pack_trial = pack(sample_unpacked) + self.assertTrue( + pack_trial == sample_packed, + "%s sample: unpacked sample, %r, did not match " \ + "%r when packed, rather, %r" %( + name, sample_unpacked, + sample_packed, pack_trial + ) + ) + + sample_unpacked = resolve(sample_unpacked) + unpack_trial = resolve(unpack(sample_packed)) + self.assertTrue( + unpack_trial == sample_unpacked, + "%s sample: packed sample, %r, did not match " \ + "%r when unpacked, rather, %r" %( + name, sample_packed, + sample_unpacked, unpack_trial + ) + ) + +class test_io(unittest.TestCase): + def test_process_tuple(self): + def funpass(cause, procs, tup, col): + pass + self.assertEqual(tuple(process_tuple((),(), funpass)), ()) + self.assertEqual(tuple(process_tuple((int,),("100",), funpass)), (100,)) + self.assertEqual(tuple(process_tuple((int,int),("100","200"), funpass)), (100,200)) + self.assertEqual(tuple(process_tuple((int,int),(None,"200"), funpass)), (None,200)) + self.assertEqual(tuple(process_tuple((int,int,int),(None,None,"200"), funpass)), (None,None,200)) + # The exception handler must raise. + self.assertRaises(RuntimeError, process_tuple, (int,), ("foo",), funpass) + + class ThisError(Exception): + pass + data = [] + def funraise(cause, procs, tup, col): + data.append((procs, tup, col)) + raise ThisError from cause + self.assertRaises(ThisError, process_tuple, (int,), ("foo",), funraise) + self.assertEqual(data[0], ((int,), ("foo",), 0)) + del data[0] + self.assertRaises(ThisError, process_tuple, (int,int), ("100","bar"), funraise) + self.assertEqual(data[0], ((int,int), ("100","bar"), 1)) + + def testExpectations(self): + 'IO tests where the pre-made expected serialized form is compared' + testExpectIO(self, expectation_samples) + + def testConsistency(self): + 'IO tests where the unpacked source is compared to re-unpacked result' + for id, sample in consistency_samples.items(): + name, pack, unpack = id + if pack is not None: + for x in sample: + packed = pack(x) + unpacked = resolve(unpack(packed)) + x = resolve(x) + self.assertTrue(x == unpacked, + "inconsistency with %s, %r -> %r -> %r" %( + name, x, packed, unpacked + ) + ) + + ## + # Further hstore tests. + def test_hstore(self): + # Can't do some tests with the consistency checks + # because we are not using ordered dictionaries. 
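+		# Packing works from ordered pair lists, so the serialized forms are checked
+		# exactly; unpacked results are compared against dict(src) instead.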
+ self.assertRaises((ValueError, struct.error), hstore_unpack, b'\x00\x00\x00\x00foo') + self.assertRaises(ValueError, hstore_unpack, b'\x00\x00\x00\x01') + self.assertRaises(ValueError, hstore_unpack, b'\x00\x00\x00\x02\x00\x00\x00\x01G\x00\x00\x00\x01G') + sample = [ + ([('foo','bar'),('k',None),('zero','heroes')], + b'\x00\x00\x00\x03\x00\x00\x00\x03foo' + \ + b'\x00\x00\x00\x03bar\x00\x00\x00\x01k\xFF\xFF\xFF\xFF' + \ + b'\x00\x00\x00\x04zero\x00\x00\x00\x06heroes'), + ([('foo',None),('k',None),('zero',None)], + b'\x00\x00\x00\x03\x00\x00\x00\x03foo' + \ + b'\xff\xff\xff\xff\x00\x00\x00\x01k\xFF\xFF\xFF\xFF' + \ + b'\x00\x00\x00\x04zero\xFF\xFF\xFF\xFF'), + ([], b'\x00\x00\x00\x00'), + ] + for x in sample: + src, serialized = x + self.assertEqual(hstore_pack(src), serialized) + self.assertEqual(hstore_unpack(serialized), dict(src)) + +# Make some slices; used by testSlicing +slice_samples = [ + slice(0, None, x+1) for x in range(10) +] + [ + slice(x, None, 1) for x in range(10) +] + [ + slice(None, x, 1) for x in range(10) +] + [ + slice(None, -x, 70) for x in range(10) +] + [ + slice(x+1, x, -1) for x in range(10) +] + [ + slice(x+4, x, -2) for x in range(10) +] + +class test_Array(unittest.TestCase): + def emptyArray(self, a): + self.assertEqual(len(a), 0) + self.assertEqual(list(a.elements()), []) + self.assertEqual(a.dimensions, ()) + self.assertEqual(a.lowerbounds, ()) + self.assertEqual(a.upperbounds, ()) + self.assertRaises(IndexError, a.__getitem__, 0) + + def testArrayInstantiation(self): + a = Array([]) + self.emptyArray(a) + # exercise default upper/lower + a = Array((1,2,3,)) + self.assertEqual((a[0],a[1],a[2]), (1,2,3,)) + # Python interface, Python semantics. + self.assertRaises(IndexError, a.__getitem__, 3) + self.assertEqual(a.dimensions, (3,)) + self.assertEqual(a.lowerbounds, (1,)) + self.assertEqual(a.upperbounds, (3,)) + + def testNestedArrayInstantiation(self): + a = Array(([1,2],[3,4])) + # Python interface, Python semantics. 
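+		# len(a) is 2 (two sub-arrays along the first axis), so index 3 is out of range.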
+ self.assertRaises(IndexError, a.__getitem__, 3) + self.assertEqual(a.dimensions, (2,2,)) + self.assertEqual(a.lowerbounds, (1,1)) + self.assertEqual(a.upperbounds, (2,2)) + self.assertEqual(list(a.elements()), [1,2,3,4]) + self.assertEqual(list(a), + [ + Array([1, 2]), + Array([3, 4]), + ] + ) + + a = Array(([[1],[2]],[[3],[4]])) + self.assertRaises(IndexError, a.__getitem__, 3) + self.assertEqual(a.dimensions, (2,2,1)) + self.assertEqual(a.lowerbounds, (1,1,1)) + self.assertEqual(a.upperbounds, (2,2,1)) + self.assertEqual(list(a), + [ + Array([[1], [2]]), + Array([[3], [4]]), + ] + ) + + self.assertRaises(ValueError, Array, [ + [1], [2,3] + ]) + self.assertRaises(ValueError, Array, [ + [1], [] + ]) + self.assertRaises(ValueError, Array, [ + [[1]], + [[],2] + ]) + self.assertRaises(ValueError, Array, [ + [[[[[1,2,3]]]]], + [[[[[1,2,3]]]]], + [[[[[1,2,3]]]]], + [[[[[2,2]]]]], + ]) + + def testSlicing(self): + elements = [1,2,3,4,5,6,7,8] + d1 = Array([1,2,3,4,5,6,7,8]) + for x in slice_samples: + self.assertEqual( + d1[x], Array(elements[x]) + ) + elements = [[1,2],[3,4],[5,6],[7,8]] + d2 = Array(elements) + for x in slice_samples: + self.assertEqual( + d2[x], Array(elements[x]) + ) + elements = [ + [[[1,2],[3,4]]], + [[[5,6],[791,8]]], + [[[1,2],[333,4]]], + [[[1,2],[3,4]]], + [[[5,10],[7,8]]], + [[[0,6],[7,8]]], + [[[1,2],[3,4]]], + [[[5,6],[7,8]]], + ] + d3 = Array(elements) + for x in slice_samples: + self.assertEqual( + d3[x], Array(elements[x]) + ) + + def testFromElements(self): + a = Array.from_elements(()) + self.emptyArray(a) + + # exercise default upper/lower + a = Array.from_elements((1,2,3,)) + self.assertEqual((a[0],a[1],a[2]), (1,2,3,)) + # Python interface, Python semantics. + self.assertRaises(IndexError, a.__getitem__, 3) + self.assertEqual(a.dimensions, (3,)) + self.assertEqual(a.lowerbounds, (1,)) + self.assertEqual(a.upperbounds, (3,)) + + # exercise default upper/lower + a = Array.from_elements([3,2,1], lowerbounds = (2,), upperbounds = (4,)) + self.assertEqual(a.dimensions, (3,)) + self.assertEqual(a.lowerbounds, (2,)) + self.assertEqual(a.upperbounds, (4,)) + + def testEmptyDimension(self): + self.assertRaises(ValueError, + Array, [[]] + ) + self.assertRaises(ValueError, + Array, [[2],[]] + ) + self.assertRaises(ValueError, + Array, [[],[],[]] + ) + self.assertRaises(ValueError, + Array, [[2],[3],[]] + ) + + def testExcessive(self): + # lowerbounds too high for upperbounds + self.assertRaises(ValueError, + Array.from_elements, [1], lowerbounds = (2,), upperbounds = (1,) + ) + + def testNegatives(self): + a = Array.from_elements([0], lowerbounds = (-1,), upperbounds = (-1,)) + self.assertEqual(a[0], 0) + self.assertEqual(a[-1], 0) + # upperbounds at zero + a = Array.from_elements([1,2], lowerbounds = (-1,), upperbounds = (0,)) + self.assertEqual(a[0], 1) + self.assertEqual(a[1], 2) + self.assertEqual(a[-2], 1) + self.assertEqual(a[-1], 2) + + def testGetElement(self): + a = Array([1,2,3,4]) + self.assertEqual(a.get_element((0,)), 1) + self.assertEqual(a.get_element((1,)), 2) + self.assertEqual(a.get_element((2,)), 3) + self.assertEqual(a.get_element((3,)), 4) + self.assertEqual(a.get_element((-1,)), 4) + self.assertEqual(a.get_element((-2,)), 3) + self.assertEqual(a.get_element((-3,)), 2) + self.assertEqual(a.get_element((-4,)), 1) + self.assertRaises(IndexError, a.get_element, (4,)) + a = Array([[1,2],[3,4]]) + self.assertEqual(a.get_element((0,0)), 1) + self.assertEqual(a.get_element((0,1,)), 2) + self.assertEqual(a.get_element((1,0,)), 3) + 
self.assertEqual(a.get_element((1,1,)), 4) + self.assertEqual(a.get_element((-1,-1)), 4) + self.assertEqual(a.get_element((-1,-2,)), 3) + self.assertEqual(a.get_element((-2,-1,)), 2) + self.assertEqual(a.get_element((-2,-2,)), 1) + self.assertRaises(IndexError, a.get_element, (2,0)) + self.assertRaises(IndexError, a.get_element, (1,2)) + self.assertRaises(IndexError, a.get_element, (0,2)) + + def testSQLGetElement(self): + a = Array([1,2,3,4]) + self.assertEqual(a.sql_get_element((1,)), 1) + self.assertEqual(a.sql_get_element((2,)), 2) + self.assertEqual(a.sql_get_element((3,)), 3) + self.assertEqual(a.sql_get_element((4,)), 4) + self.assertEqual(a.sql_get_element((0,)), None) + self.assertEqual(a.sql_get_element((5,)), None) + self.assertEqual(a.sql_get_element((-1,)), None) + self.assertEqual(a.sql_get_element((-2,)), None) + self.assertEqual(a.sql_get_element((-3,)), None) + self.assertEqual(a.sql_get_element((-4,)), None) + a = Array([[1,2],[3,4]]) + self.assertEqual(a.sql_get_element((1,1)), 1) + self.assertEqual(a.sql_get_element((1,2,)), 2) + self.assertEqual(a.sql_get_element((2,1,)), 3) + self.assertEqual(a.sql_get_element((2,2,)), 4) + self.assertEqual(a.sql_get_element((3,1)), None) + self.assertEqual(a.sql_get_element((1,3)), None) + +if __name__ == '__main__': + from types import ModuleType + this = ModuleType("this") + this.__dict__.update(globals()) + unittest.main(this) diff --git a/py_opengauss/test/testall.py b/py_opengauss/test/testall.py new file mode 100644 index 0000000000000000000000000000000000000000..b32ccaa1581565edce6f114e01462a524591de15 --- /dev/null +++ b/py_opengauss/test/testall.py @@ -0,0 +1,38 @@ +## +# .test.testall +## +import unittest +from sys import stderr + +from ..installation import default + +from .test_exceptions import * +from .test_bytea_codec import * +from .test_iri import * +from .test_protocol import * +from .test_configfile import * +from .test_pgpassfile import * +from .test_python import * + +from .test_installation import * +from .test_cluster import * + +# Expects PGINSTALLATION to be set. Tests may be skipped. +from .test_connect import * +from .test_ssl_connect import * + +try: + from .test_optimized import * +except ImportError: + stderr.write("NOTICE: port.optimized could not be imported\n") + +from .test_driver import * +from .test_alock import * +from .test_notifyman import * +from .test_copyman import * +from .test_lib import * +from .test_dbapi20 import * +from .test_types import * + +if __name__ == '__main__': + unittest.main() diff --git a/py_opengauss/types/__init__.py b/py_opengauss/types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..690e19f3db742c0b44e4ad377fa055ba1ba402ef --- /dev/null +++ b/py_opengauss/types/__init__.py @@ -0,0 +1,643 @@ +## +# types. - Package for I/O and PostgreSQL specific types. +## +""" +PostgreSQL types and identifiers. +""" +# XXX: Would be nicer to generate these from a header file... 
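+# The OID constants below mirror the built-in type OID assignments of
+# PostgreSQL's pg_type catalog.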
+InvalidOid = 0 + +RECORDOID = 2249 +BOOLOID = 16 +BITOID = 1560 +VARBITOID = 1562 +ACLITEMOID = 1033 + +CHAROID = 18 +NAMEOID = 19 +TEXTOID = 25 +BYTEAOID = 17 +BPCHAROID = 1042 +VARCHAROID = 1043 +CSTRINGOID = 2275 +UNKNOWNOID = 705 +REFCURSOROID = 1790 +UUIDOID = 2950 + +TSVECTOROID = 3614 +GTSVECTOROID = 3642 +TSQUERYOID = 3615 +REGCONFIGOID = 3734 +REGDICTIONARYOID = 3769 + +JSONOID = 114 +JSONBOID = 3802 +XMLOID = 142 + +MACADDROID = 829 +INETOID = 869 +CIDROID = 650 + +TYPEOID = 71 +PROCOID = 81 +CLASSOID = 83 +ATTRIBUTEOID = 75 + +DATEOID = 1082 +TIMEOID = 1083 +TIMESTAMPOID = 1114 +TIMESTAMPTZOID = 1184 +INTERVALOID = 1186 +TIMETZOID = 1266 +ABSTIMEOID = 702 +RELTIMEOID = 703 +TINTERVALOID = 704 + +INT8OID = 20 +INT2OID = 21 +INT4OID = 23 +OIDOID = 26 +TIDOID = 27 +XIDOID = 28 +CIDOID = 29 +CASHOID = 790 +FLOAT4OID = 700 +FLOAT8OID = 701 +NUMERICOID = 1700 + +POINTOID = 600 +LINEOID = 628 +LSEGOID = 601 +PATHOID = 602 +BOXOID = 603 +POLYGONOID = 604 +CIRCLEOID = 718 + +OIDVECTOROID = 30 +INT2VECTOROID = 22 +INT4ARRAYOID = 1007 + +REGPROCOID = 24 +REGPROCEDUREOID = 2202 +REGOPEROID = 2203 +REGOPERATOROID = 2204 +REGCLASSOID = 2205 +REGTYPEOID = 2206 +REGTYPEARRAYOID = 2211 + +TRIGGEROID = 2279 +LANGUAGE_HANDLEROID = 2280 +INTERNALOID = 2281 +OPAQUEOID = 2282 +VOIDOID = 2278 +ANYARRAYOID = 2277 +ANYELEMENTOID = 2283 +ANYOID = 2276 +ANYNONARRAYOID = 2776 +ANYENUMOID = 3500 + +#: Mapping of type Oid to SQL type name. +oid_to_sql_name = { + BPCHAROID : 'CHARACTER', + VARCHAROID : 'CHARACTER VARYING', + # *OID : 'CHARACTER LARGE OBJECT', + + # SELECT X'0F' -> bit. XXX: Does bytea have any play here? + #BITOID : 'BINARY', + #BYTEAOID : 'BINARY VARYING', + # *OID : 'BINARY LARGE OBJECT', + + BOOLOID : 'BOOLEAN', + +# exact numeric types + INT2OID : 'SMALLINT', + INT4OID : 'INTEGER', + INT8OID : 'BIGINT', + NUMERICOID : 'NUMERIC', + +# approximate numeric types + FLOAT4OID : 'REAL', + FLOAT8OID : 'DOUBLE PRECISION', + +# datetime types + TIMEOID : 'TIME WITHOUT TIME ZONE', + TIMETZOID : 'TIME WITH TIME ZONE', + TIMESTAMPOID : 'TIMESTAMP WITHOUT TIME ZONE', + TIMESTAMPTZOID : 'TIMESTAMP WITH TIME ZONE', + DATEOID : 'DATE', + +# interval types + INTERVALOID : 'INTERVAL', + + XMLOID : 'XML', +} + +#: Mapping of type Oid to name. 
+oid_to_name = { + RECORDOID : 'record', + BOOLOID : 'bool', + BITOID : 'bit', + VARBITOID : 'varbit', + ACLITEMOID : 'aclitem', + + CHAROID : 'char', + NAMEOID : 'name', + TEXTOID : 'text', + BYTEAOID : 'bytea', + BPCHAROID : 'bpchar', + VARCHAROID : 'varchar', + CSTRINGOID : 'cstring', + UNKNOWNOID : 'unknown', + REFCURSOROID : 'refcursor', + UUIDOID : 'uuid', + + TSVECTOROID : 'tsvector', + GTSVECTOROID : 'gtsvector', + TSQUERYOID : 'tsquery', + REGCONFIGOID : 'regconfig', + REGDICTIONARYOID : 'regdictionary', + + XMLOID : 'xml', + JSONOID : 'json', + JSONBOID : 'jsonb', + + MACADDROID : 'macaddr', + INETOID : 'inet', + CIDROID : 'cidr', + + TYPEOID : 'type', + PROCOID : 'proc', + CLASSOID : 'class', + ATTRIBUTEOID : 'attribute', + + DATEOID : 'date', + TIMEOID : 'time', + TIMESTAMPOID : 'timestamp', + TIMESTAMPTZOID : 'timestamptz', + INTERVALOID : 'interval', + TIMETZOID : 'timetz', + ABSTIMEOID : 'abstime', + RELTIMEOID : 'reltime', + TINTERVALOID : 'tinterval', + + INT8OID : 'int8', + INT2OID : 'int2', + INT4OID : 'int4', + OIDOID : 'oid', + TIDOID : 'tid', + XIDOID : 'xid', + CIDOID : 'cid', + CASHOID : 'cash', + FLOAT4OID : 'float4', + FLOAT8OID : 'float8', + NUMERICOID : 'numeric', + + POINTOID : 'point', + LINEOID : 'line', + LSEGOID : 'lseg', + PATHOID : 'path', + BOXOID : 'box', + POLYGONOID : 'polygon', + CIRCLEOID : 'circle', + + OIDVECTOROID : 'oidvector', + INT2VECTOROID : 'int2vector', + INT4ARRAYOID : 'int4array', + + REGPROCOID : 'regproc', + REGPROCEDUREOID : 'regprocedure', + REGOPEROID : 'regoper', + REGOPERATOROID : 'regoperator', + REGCLASSOID : 'regclass', + REGTYPEOID : 'regtype', + REGTYPEARRAYOID : 'regtypearray', + + TRIGGEROID : 'trigger', + LANGUAGE_HANDLEROID : 'language_handler', + INTERNALOID : 'internal', + OPAQUEOID : 'opaque', + VOIDOID : 'void', + ANYARRAYOID : 'anyarray', + ANYELEMENTOID : 'anyelement', + ANYOID : 'any', + ANYNONARRAYOID : 'anynonarray', + ANYENUMOID : 'anyenum', +} + +name_to_oid = dict( + [(v,k) for k,v in oid_to_name.items()] +) + +class Array(object): + """ + Type used to mimic PostgreSQL arrays. While there are many semantic + differences, the primary one is that the elements contained by an Array + instance are not strongly typed. The purpose of this class is to provide + some consistency with PostgreSQL with respect to the structure of an Array. + + The structure consists of three parts: + + * The elements of the array. + * The lower boundaries. + * The upper boundaries. + + There is also a `dimensions` property, but it is derived from the + `lowerbounds` and `upperbounds` to yield a normalized description of the + ARRAY's structure. + + The Python interfaces, such as __getitem__, are *not* subjected to the + semantics of the lower and upper bounds. Rather, the normalized dimensions + provide the primary influence for these interfaces. So, unlike SQL + indirection, getting an index that does *not* exist will raise a Python + `IndexError`. 
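+
+	A small illustrative sketch of those semantics (expected results shown as
+	comments)::
+
+		a = Array([[1, 2], [3, 4]])
+		a.dimensions               # (2, 2)
+		a.lowerbounds              # (1, 1)
+		a.upperbounds              # (2, 2)
+		a[1][0]                    # 3 -- Python-style, zero-based access
+		a.sql_get_element((2, 1))  # 3 -- SQL-style, lowerbound-based access
+		a[5]                       # raises IndexError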
+ """ + # return an iterator over the absolute elements of a nested sequence + @classmethod + def unroll_nest(typ, hier, dimensions, depth = 0): + dsize = dimensions and dimensions[depth] or 0 + if len(hier) != dsize: + raise ValueError("list size not consistent with dimensions at depth " + str(depth)) + r = [] + ndepth = depth + 1 + if ndepth == len(dimensions): + # at the bottom + r = hier + else: + # go deeper + for x in hier: + r.extend(typ.unroll_nest(x, dimensions, ndepth)) + return r + + # Detect the dimensions of a nested sequence + @staticmethod + def detect_dimensions(hier, len = len): + # if the list is empty, it's a zero-dimension array. + if hier: + yield len(hier) + hier = hier[0] + depth = 1 + while hier.__class__ is list: + depth += 1 + l = len(hier) + if l < 1: + raise ValueError("axis {0} is empty".format(depth)) + yield l + hier = hier[0] + + @classmethod + def from_elements(typ, + elements, + lowerbounds = None, + upperbounds = None, + len = len, + ): + """ + Instantiate an Array from the given elements, lowerbounds, and upperbounds. + + The given elements are bound to the array which provides them with the + structure defined by the lower boundaries and the upper boundaries. + + A `ValueError` will be raised in the following situations: + + * The number of elements given are inconsistent with the number of elements + described by the upper and lower bounds. + * The lower bounds at a given axis exceeds the upper bounds at a given + axis. + * The number of lower bounds is inconsistent with the number of upper + bounds. + """ + # resolve iterable + elements = list(elements) + nelements = len(elements) + + # If ndims is zero, lowerbounds will be () + if lowerbounds is None: + if upperbounds: + lowerbounds = (1,) * len(upperbounds) + elif nelements == 0: + # special for empty ARRAY; no dimensions. + lowerbounds = () + else: + # one dimension. + lowerbounds = (1,) + else: + lowerbounds = tuple(lowerbounds) + + if upperbounds is not None: + upperbounds = tuple(upperbounds) + dimensions = [] + # upperbounds were given, so check. + if upperbounds: + elcount = 1 + for lb, ub in zip(lowerbounds, upperbounds): + x = ub - lb + 1 + if x < 1: + # special case empty ARRAYs + if nelements == 0: + upperbounds = () + lowerbounds = () + dimensions = () + elcount = 0 + break + raise ValueError("lowerbounds exceeds upperbounds") + # physical dimensions. + dimensions.append(x) + elcount = x * elcount + else: + elcount = 0 + if nelements != elcount: + raise ValueError("element count inconsistent with boundaries") + dimensions = tuple(dimensions) + else: + # fill in default + if nelements == 0: + upperbounds = () + dimensions = () + else: + upperbounds = (nelements,) + dimensions = (nelements,) + + # consistency.. + if len(lowerbounds) != len(upperbounds): + raise ValueError("number of lowerbounds inconsistent with upperbounds") + + rob = super().__new__(typ) + rob._elements = elements + rob.lowerbounds = lowerbounds + rob.upperbounds = upperbounds + rob.dimensions = dimensions + rob.ndims = len(dimensions) + rob._weight = len(rob._elements) // (dimensions and dimensions[0] or 1) + return rob + + # Method used to create an Array() from nested lists. + @classmethod + def from_nest(typ, nest): + dims = tuple(typ.detect_dimensions(nest)) + return typ.from_elements( + list(typ.unroll_nest(nest, dims)), + upperbounds = dims, + # lowerbounds is implied to (1,)*len(upper) + ) + + def __new__(typ, nested_elements): + """ + Create an types.Array() using the given nested lists. 
The boundaries of + the array are detected by traversing the first items of the nested + lists:: + + Array([[1,2,4],[3,4,8]]) + + Lists are used to define the boundaries so that tuples may be used to + represent any complex elements. The above array will the `lowerbounds` + ``(1,1)``, and the `upperbounds` ``(2,3)``. + """ + if nested_elements.__class__ is Array: + return nested_elements + return typ.from_nest(list(nested_elements)) + + def __getnewargs__(self): + return (self.nest(),) + + def elements(self): + """ + Returns an iterator to the elements of the Array. The elements are + produced in physical order. + """ + return iter(self._elements) + + def nest(self, seqtype = list): + """ + Transform the array into a nested list. + + The `seqtype` keyword can be used to override the type used to represent + the elements of a given axis. + """ + if self.ndims < 2: + return seqtype(self._elements) + else: + rl = [] + for x in self: + rl.append(x.nest(seqtype = seqtype)) + return seqtype(rl) + + def get_element(self, address, + idxerr = "index {0} at axis {1} is out of range {2}".format + ): + """ + Get an element in the array using the given axis sequence. + + >>> a=Array([[1,2],[3,4]]) + >>> a.get_element((0,0)) == 1 + True + >>> a.get_element((1,1)) == 4 + True + + This is similar to getting items in a nested list:: + + >>> l=[[1,2],[3,4]] + >>> l[0][0] == 1 + True + """ + if not self.dimensions: + raise IndexError("array is empty") + if len(address) != len(self.dimensions): + raise ValueError("given axis sequence is inconsistent with number of dimensions") + + # normalize axis specification (-N + DIM), check for IndexErrors, and + # resolve the element's position. + cur = 0 + nelements = len(self._elements) + for n, a, dim in zip(range(len(address)), address, self.dimensions): + if a < 0: + a = a + dim + if a < 0: + raise IndexError(idxerr(a, n, dim)) + else: + if a >= dim: + raise IndexError(idxerr(a, n, dim)) + nelements = nelements // dim + cur += (a * nelements) + return self._elements[cur] + + def sql_get_element(self, address): + """ + Like `get_element`, but with SQL indirection semantics. Notably, returns + `None` on IndexError. + """ + try: + a = [a - lb for (a, lb) in zip(address, self.lowerbounds)] + # get_element accepts negatives, so check the converted sequence. + for x in a: + if x < 0: + return None + return self.get_element(a) + except IndexError: + return None + + def __repr__(self): + return '%s.%s(%r)' %( + type(self).__module__, + type(self).__name__, + self.nest() + ) + + def __len__(self): + return self.dimensions and self.dimensions[0] or 0 + + def __eq__(self, ob): + return list(self) == ob + + def __ne__(self, ob): + return list(self) != ob + + def __gt__(self, ob): + return list(self) > ob + + def __lt__(self, ob): + return list(self) < ob + + def __le__(self, ob): + return list(self) <= ob + + def __ge__(self, ob): + return list(self) >= ob + + def __getitem__(self, item): + if self.ndims < 2: + # Array with 1dim is more or less a list. + return self._elements[item] + if isinstance(item, slice): + # get a sub-array slice + l = len(self) + n = 0 + r = [] + # for each offset in the slice, get the elements and add them + # to the new elements list used to build the new Array(). 
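+			# self._weight is the number of elements in a single item of the first axis.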
+ for x in range(*(item.indices(l))): + n = n + 1 + r.extend( + self._elements[slice(self._weight*x,self._weight*(x+1))] + ) + if n: + return self.__class__.from_elements(r, + lowerbounds = (1,) + self.lowerbounds[1:], + upperbounds = (n,) + self.upperbounds[1:], + ) + else: + # Empty + return self.__class__.from_elements(()) + else: + # get a sub-array + l = len(self) + if item > l: + raise IndexError("index {0} is out of range".format(l)) + return self.__class__.from_elements( + self._elements[self._weight*item:self._weight*(item+1)], + lowerbounds = self.lowerbounds[1:], + upperbounds = self.upperbounds[1:], + ) + + def __iter__(self): + if self.ndims < 2: + # Special case empty and single dimensional ARRAYs + return self.elements() + return (self[x] for x in range(len(self))) + +from operator import itemgetter +get0 = itemgetter(0) +get1 = itemgetter(1) +del itemgetter + +class Row(tuple): + "Name addressable items tuple; mapping and sequence" + @classmethod + def from_mapping(typ, keymap, map, get1 = get1): + iter = [ + map.get(k) for k,_ in sorted(keymap.items(), key = get1) + ] + r = typ(iter) + r.keymap = keymap + return r + + @classmethod + def from_sequence(typ, keymap, seq): + r = typ(seq) + r.keymap = keymap + return r + + def __getitem__(self, i, gi = tuple.__getitem__): + if isinstance(i, (int, slice)): + return gi(self, i) + idx = self.keymap[i] + return gi(self, idx) + + def get(self, i, gi = tuple.__getitem__, len = len): + if type(i) is int: + l = len(self) + if -l < i < l: + return gi(self, i) + else: + idx = self.keymap.get(i) + if idx is not None: + return gi(self, idx) + return None + + def keys(self): + return self.keymap.keys() + + def values(self): + return iter(self) + + def items(self): + return zip(iter(self.column_names), iter(self)) + + def index_from_key(self, key): + return self.keymap.get(key) + + def key_from_index(self, index): + for k,v in self.keymap.items(): + if v == index: + return k + return None + + @property + def column_names(self, get0 = get0, get1 = get1): + l=list(self.keymap.items()) + l.sort(key=get1) + return tuple(map(get0, l)) + + def transform(self, *args, **kw): + """ + Make a new Row after processing the values with the callables associated + with the values either by index, \*args, or my column name, \*\*kw. 
+ + >>> r=Row.from_sequence({'col1':0,'col2':1}, (1,'two')) + >>> r.transform(str) + ('1','two') + >>> r.transform(col2 = str.upper) + (1,'TWO') + >>> r.transform(str, col2 = str.upper) + ('1','TWO') + + Combine with methodcaller and map to transform lots of rows: + + >>> rowseq = [r] + >>> xf = operator.methodcaller('transform', col2 = str.upper) + >>> list(map(xf, rowseq)) + [(1,'TWO')] + + """ + r = list(self) + i = 0 + for x in args: + if x is not None: + r[i] = x(tuple.__getitem__(self, i)) + i = i + 1 + for k,v in kw.items(): + if v is not None: + i = self.index_from_key(k) + if i is None: + raise KeyError("row has no such key, " + repr(k)) + r[i] = v(self[k]) + return type(self).from_sequence(self.keymap, r) diff --git a/py_opengauss/types/bitwise.py b/py_opengauss/types/bitwise.py new file mode 100644 index 0000000000000000000000000000000000000000..018b762519a6d80529239e2d94e77c8a110ebfdf --- /dev/null +++ b/py_opengauss/types/bitwise.py @@ -0,0 +1,103 @@ +class Varbit(object): + __slots__ = ('data', 'bits') + + def from_bits(subtype, bits, data): + if bits == 1: + return (data[0] & (1 << 7)) and OneBit or ZeroBit + else: + rob = object.__new__(subtype) + rob.bits = bits + rob.data = data + return rob + from_bits = classmethod(from_bits) + + def __new__(typ, data): + if isinstance(data, Varbit): + return data + if isinstance(data, bytes): + return typ.from_bits(len(data) * 8, data) + # str(), eg '00101100' + bits = len(data) + nbytes, remain = divmod(bits, 8) + bdata = [bytes((int(data[x:x+8], 2),)) for x in range(0, bits - remain, 8)] + if remain != 0: + bdata.append(bytes((int(data[nbytes*8:].ljust(8,'0'), 2),))) + return typ.from_bits(bits, b''.join(bdata)) + + def __str__(self): + if self.bits: + # cut off the remainder from the bits + blocks = [bin(x)[2:].rjust(8, '0') for x in self.data] + blocks[-1] = blocks[-1][0:(self.bits % 8) or 8] + return ''.join(blocks) + else: + return '' + + def __repr__(self): + return '%s.%s(%r)' %( + type(self).__module__, + type(self).__name__, + str(self) + ) + + def __eq__(self, ob): + if not isinstance(ob, type(self)): + ob = type(self)(ob) + return ob.bits == self.bits and ob.data == self.data + + def __len__(self): + return self.bits + + def __add__(self, ob): + return Varbit(str(self) + str(ob)) + + def __mul__(self, ob): + return Varbit(str(self) * ob) + + def getbit(self, bitoffset): + if bitoffset < 0: + idx = self.bits + bitoffset + else: + idx = bitoffset + if not 0 <= idx < self.bits: + raise IndexError("bit index %d out of range" %(bitoffset,)) + + byte, bitofbyte = divmod(idx, 8) + if ord(self.data[byte]) & (1 << (7 - bitofbyte)): + return OneBit + else: + return ZeroBit + + def __getitem__(self, item): + if isinstance(item, slice): + return type(self)(str(self)[item]) + else: + return self.getbit(item) + + def __nonzero__(self): + for x in self.data: + if x != 0: + return True + return False + +class Bit(Varbit): + def __new__(subtype, ob): + if ob is ZeroBit or ob is False or ob == '0': + return ZeroBit + elif ob is OneBit or ob is True or ob == '1': + return OneBit + + raise ValueError('unknown bit value %r, 0 or 1' %(ob,)) + + def __nonzero__(self): + return self is OneBit + + def __str__(self): + return self is OneBit and '1' or '0' + +ZeroBit = object.__new__(Bit) +ZeroBit.data = b'\x00' +ZeroBit.bits = 1 +OneBit = object.__new__(Bit) +OneBit.data = b'\x80' +OneBit.bits = 1 diff --git a/py_opengauss/types/geometry.py b/py_opengauss/types/geometry.py new file mode 100644 index 
0000000000000000000000000000000000000000..ba996e57450aedb1738c22bb193632c32b353ef6 --- /dev/null +++ b/py_opengauss/types/geometry.py @@ -0,0 +1,191 @@ +import math +from operator import itemgetter +get0 = itemgetter(0) +get1 = itemgetter(1) +# Geometric types + +class Point(tuple): + """ + A point; a pair of floating point numbers. + """ + __slots__ = () + x = property(fget = lambda s: s[0]) + y = property(fget = lambda s: s[1]) + + def __new__(subtype, pair): + return tuple.__new__(subtype, (float(pair[0]), float(pair[1]))) + + def __repr__(self): + return '%s.%s(%s)' %( + type(self).__module__, + type(self).__name__, + tuple.__repr__(self), + ) + + def __str__(self): + return tuple.__repr__(self) + + def __add__(self, ob): + wx, wy = ob + return type(self)((self[0] + wx, self[1] + wy)) + + def __sub__(self, ob): + wx, wy = ob + return type(self)((self[0] - wx, self[1] - wy)) + + def __mul__(self, ob): + wx, wy = ob + rx = (self[0] * wx) - (self[1] * wy) + ry = (self[0] * wy) + (self[1] * wx) + return type(self)((rx, ry)) + + def __div__(self, ob): + sx, sy = self + wx, wy = ob + div = (wx * wx) + (wy * wy) + rx = ((sx * wx) + (sy * wy)) / div + ry = ((wx * sy) + (wy * sx)) / div + return type(self)((rx, ry)) + + def distance(self, ob, sqrt = math.sqrt): + wx, wy = ob + dx = self[0] - float(wx) + dy = self[1] - float(wy) + return sqrt(dx**2 + dy**2) + +class Lseg(tuple): + __slots__ = () + one = property(fget = lambda s: s[0]) + two = property(fget = lambda s: s[1]) + + length = property(fget = lambda s: s[0].distance(s[1])) + vertical = property(fget = lambda s: s[0][0] == s[1][0]) + horizontal = property(fget = lambda s: s[0][1] == s[1][1]) + slope = property( + fget = lambda s: (s[1][1] - s[0][1]) / (s[1][0] - s[0][0]) + ) + center = property( + fget = lambda s: Point(( + (s[0][0] + s[1][0]) / 2.0, + (s[0][1] + s[1][1]) / 2.0, + )) + ) + + def __new__(subtype, pair): + p1, p2 = pair + return tuple.__new__(subtype, (Point(p1), Point(p2))) + + def __repr__(self): + # Avoid the point representation + return '%s.%s(%s, %s)' %( + type(self).__module__, + type(self).__name__, + tuple.__repr__(self[0]), + tuple.__repr__(self[1]), + ) + + def __str__(self): + return '[(%s,%s),(%s,%s)]' %( + self[0][0], + self[0][1], + self[1][0], + self[1][1], + ) + + def parallel(self, ob): + return self.slope == type(self)(ob).slope + + def perpendicular(self, ob): + return (self.slope / type(self)(ob).slope) == -1.0 + +class Box(tuple): + """ + A pair of points. One specifying the top-right point of the box; the other + specifying the bottom-left. `high` being top-right; `low` being bottom-left. + + http://www.postgresql.org/docs/current/static/datatype-geometric.html + + >>> Box(( (0,0), (-2, -2) )) + postgresql.types.geometry.Box(((0.0, 0.0), (-2.0, -2.0))) + + It will also relocate values to enforce the high-low expectation: + + >>> t.box(((-4,0),(-2,-3))) + postgresql.types.geometry.Box(((-2.0, 0.0), (-4.0, -3.0))) + + :: + + (-2, 0) `high` + | + | + (-4,-3) -------+-x + `low` y + + This happens because ``-4`` is less than ``-2``; therefore the ``-4`` + belongs on the low point. This is consistent with what PostgreSQL does + with its ``box`` type. 
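+
+	The ``center`` property is the midpoint of the two corners; for example,
+	``Box(((0, 0), (2, 2))).center`` is ``Point((1.0, 1.0))``.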
+ """ + __slots__ = () + high = property(fget = get0, doc = "high point of the box") + low = property(fget = get1, doc = "low point of the box") + center = property( + fget = lambda s: Point(( + (s[0][0] + s[1][0]) / 2.0, + (s[0][1] + s[1][1]) / 2.0 + )), + doc = "center of the box as a point" + ) + + def __new__(subtype, hl): + if isinstance(hl, Box): + return hl + one, two = hl + if one[0] > two[0]: + hx = one[0] + lx = two[0] + else: + hx = two[0] + lx = one[0] + if one[1] > two[1]: + hy = one[1] + ly = two[1] + else: + hy = two[1] + ly = one[1] + return tuple.__new__(subtype, (Point((hx, hy)), Point((lx, ly)))) + + def __repr__(self): + return '%s.%s((%s, %s))' %( + type(self).__module__, + type(self).__name__, + tuple.__repr__(self[0]), + tuple.__repr__(self[1]), + ) + + def __str__(self): + return '%s,%s' %(self[0], self[1]) + +class Circle(tuple): + """ + Type for PostgreSQL circles. + """ + __slots__ = () + center = property(fget = get0, doc = "center of the circle (point)") + radius = property(fget = get1, doc = "radius of the circle (radius >= 0)") + + def __new__(subtype, pair): + center, radius = pair + if radius < 0: + raise ValueError("radius is subzero") + return tuple.__new__(subtype, (Point(center), float(radius))) + + def __repr__(self): + return '%s.%s((%s, %s))' %( + type(self).__module__, + type(self).__name__, + tuple.__repr__(self[0]), + repr(self[1]) + ) + + def __str__(self): + return '<%s,%s>' %(self[0], self[1]) diff --git a/py_opengauss/types/io/__init__.py b/py_opengauss/types/io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93542be40185712fe4311165b5e19c0700a1df65 --- /dev/null +++ b/py_opengauss/types/io/__init__.py @@ -0,0 +1,113 @@ +## +# .types.io - I/O routines for packing and unpacking data +## +""" +PostgreSQL type I/O routines--packing and unpacking functions. + +This package manages the modules providing I/O routines. + +The name of the function describes what type the function is intended to be used +on. Normally, the fucntions return a structured form of the serialized data to +be used as a parameter to the creation of a higher level instance. In +particular, most of the functions that deal with time return a pair for +representing the relative offset: (seconds, microseconds). For times, this +provides an abstraction for quad-word based times used by some configurations of +PostgreSQL. +""" +import sys +from itertools import cycle, chain +from ... 
import types as pg_types + +io_modules = { + 'builtins' : ( + pg_types.BOOLOID, + pg_types.CHAROID, + pg_types.BYTEAOID, + + pg_types.INT2OID, + pg_types.INT4OID, + pg_types.INT8OID, + + pg_types.FLOAT4OID, + pg_types.FLOAT8OID, + pg_types.ABSTIMEOID, + ), + + 'pg_bitwise': ( + pg_types.BITOID, + pg_types.VARBITOID, + ), + + 'pg_network': ( + pg_types.MACADDROID, + pg_types.INETOID, + pg_types.CIDROID, + ), + + 'pg_system': ( + pg_types.OIDOID, + pg_types.XIDOID, + pg_types.CIDOID, + pg_types.TIDOID, + ), + + 'pg_geometry': ( + pg_types.POINTOID, + pg_types.LSEGOID, + pg_types.BOXOID, + pg_types.CIRCLEOID, + ), + + 'stdlib_datetime' : ( + pg_types.DATEOID, + pg_types.INTERVALOID, + pg_types.TIMEOID, + pg_types.TIMETZOID, + pg_types.TIMESTAMPOID, + pg_types.TIMESTAMPTZOID + ), + + 'stdlib_decimal' : ( + pg_types.NUMERICOID, + ), + + 'stdlib_uuid' : ( + pg_types.UUIDOID, + ), + + 'stdlib_xml_etree' : ( + pg_types.XMLOID, + ), + + 'stdlib_jsonb' : ( + pg_types.JSONBOID, + ), + + # Must be db.typio.identify(contrib_hstore = 'hstore')'d + 'contrib_hstore' : ( + 'contrib_hstore', + ), +} + +# OID -> module name +module_io = dict( + chain.from_iterable(( + zip(x[1], cycle((x[0],))) for x in io_modules.items() + )) +) + +if sys.version_info[:2] < (3,3): + def load(relmod): + return __import__(__name__ + '.' + relmod, fromlist = True, level = 1) +else: + def load(relmod): + return __import__(relmod, globals = globals(), locals = locals(), fromlist = [''], level = 1) + +def resolve(oid): + io = module_io.get(oid) + if io is None: + return None + if io.__class__ is str: + module_io.update(load(io).oid_to_io) + io = module_io[oid] + return io diff --git a/py_opengauss/types/io/builtins.py b/py_opengauss/types/io/builtins.py new file mode 100644 index 0000000000000000000000000000000000000000..0b180dfb54ac7b51eb0b97a641a60cd60a1ef123 --- /dev/null +++ b/py_opengauss/types/io/builtins.py @@ -0,0 +1,54 @@ +from .. import \ + INT2OID, INT4OID, INT8OID, \ + BOOLOID, BYTEAOID, CHAROID, \ + ABSTIMEOID, FLOAT4OID, FLOAT8OID, \ + TEXTOID, BPCHAROID, NAMEOID, VARCHAROID +from . import lib + +bool_pack = {True:b'\x01', False:b'\x00'}.__getitem__ +bool_unpack = {b'\x01':True, b'\x00':False}.__getitem__ + +int2_pack, int2_unpack = lib.short_pack, lib.short_unpack +int4_pack, int4_unpack = lib.long_pack, lib.long_unpack +int8_pack, int8_unpack = lib.longlong_pack, lib.longlong_unpack + +bytea_pack = bytes +bytea_unpack = bytes +char_pack = bytes +char_unpack = bytes + +oid_to_io = { + BOOLOID : (bool_pack, bool_unpack, bool), + + BYTEAOID : (bytea_pack, bytea_unpack, bytes), + CHAROID : (char_pack, char_unpack, bytes), + + INT2OID : (int2_pack, int2_unpack, int), + INT4OID : (int4_pack, int4_unpack, int), + INT8OID : (int8_pack, int8_unpack, int), + + ABSTIMEOID : (lib.long_pack, lib.long_unpack, int), + FLOAT4OID : (lib.float_pack, lib.float_unpack, float), + FLOAT8OID : (lib.double_pack, lib.double_unpack, float), +} + +# Python Representations of PostgreSQL Types +oid_to_type = { + BOOLOID: bool, + + VARCHAROID: str, + TEXTOID: str, + BPCHAROID: str, + NAMEOID: str, + + # This is *not* bpchar, the SQL CHARACTER type. 
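+	# ("char" is the single-byte internal type, hence bytes; SQL CHARACTER(n)
+	# is BPCHAROID and maps to str above.)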
+ CHAROID: bytes, + BYTEAOID: bytes, + + INT2OID: int, + INT4OID: int, + INT8OID: int, + + FLOAT4OID: float, + FLOAT8OID: float, +} diff --git a/py_opengauss/types/io/contrib_hstore.py b/py_opengauss/types/io/contrib_hstore.py new file mode 100644 index 0000000000000000000000000000000000000000..927986323a9562782d50528e8839c2975b475a51 --- /dev/null +++ b/py_opengauss/types/io/contrib_hstore.py @@ -0,0 +1,48 @@ +## +# .types.io.contrib_hstore - I/O routines for binary hstore +## +from ...python.structlib import split_sized_data, ulong_pack, ulong_unpack +from ...python.itertools import chunk + +## +# Build the hstore I/O pair for a given typio. +# It primarily needs typio for decode and encode. +def hstore_factory(oid, typio, + unpack_err = "expected {0} items in hstore, but found {1}".format +): + def pack_hstore(x, + encode = typio.encode, + len = len, + ): + if hasattr(x, 'items'): + x = x.items() + encoded = [ + (encode(k), encode(v)) if v is not None else (encode(k), None) + for k,v in x + ] + return ulong_pack(len(encoded)) + b''.join( + ulong_pack(len(k)) + k + b'\xFF\xFF\xFF\xFF' + if v is None else ulong_pack(len(k)) + k + ulong_pack(len(v)) + v + for k,v in encoded + ) + + def unpack_hstore(x, + decode = typio.decode, + split = split_sized_data, + len = len + ): + view = memoryview(x)[4:] + n = ulong_unpack(x) + r = { + decode(y[0]) : (decode(y[1]) if y[1] is not None else None) + for y in chunk(split(view), 2) if y + } + if len(r) != n: + raise ValueError(unpack_err(n, len(r))) + return r + + return (pack_hstore, unpack_hstore) + +oid_to_io = { + 'contrib_hstore' : hstore_factory, +} diff --git a/py_opengauss/types/io/lib.py b/py_opengauss/types/io/lib.py new file mode 100644 index 0000000000000000000000000000000000000000..f9638c71adf4a40bb5b93514a0eebbb8869bdfac --- /dev/null +++ b/py_opengauss/types/io/lib.py @@ -0,0 +1,476 @@ +import struct +from math import floor +from ...python.functools import Composition as compose +from ...python.itertools import interlace +from ...python.structlib import \ + short_pack, short_unpack, \ + ulong_pack, ulong_unpack, \ + long_pack, long_unpack, \ + double_pack, double_unpack, \ + longlong_pack, longlong_unpack, \ + float_pack, float_unpack, \ + LH_pack, LH_unpack, \ + dl_pack, dl_unpack, \ + dll_pack, dll_unpack, \ + ql_pack, ql_unpack, \ + qll_pack, qll_unpack, \ + llL_pack, llL_unpack, \ + dd_pack, dd_unpack, \ + ddd_pack, ddd_unpack, \ + dddd_pack, dddd_unpack, \ + hhhh_pack, hhhh_unpack + +oid_pack = cid_pack = xid_pack = ulong_pack +oid_unpack = cid_unpack = xid_unpack = ulong_unpack +tid_pack, tid_unpack = LH_pack, LH_unpack + +# geometry types +point_pack, point_unpack = dd_pack, dd_unpack +circle_pack, circle_unpack = ddd_pack, ddd_unpack +lseg_pack = box_pack = dddd_pack +lseg_unpack = box_unpack = dddd_unpack + +null_sequence = b'\xff\xff\xff\xff' +string_format = b'\x00\x00' +binary_format = b'\x00\x01' + +def numeric_pack(data, hhhh_pack = hhhh_pack, pack = struct.pack, len = len): + return hhhh_pack(data[0]) + pack("!%dh"%(len(data[1]),), *data[1]) + +def numeric_unpack(data, hhhh_unpack = hhhh_unpack, unpack = struct.unpack, len = len): + return (hhhh_unpack(data[:8]), unpack("!8x%dh"%((len(data)-8) // 2,), data)) + +def path_pack(data, pack = struct.pack, len = len): + """ + Given a sequence of point data, pack it into a path's serialized form. + + [px1, py1, px2, py2, ...] + + Must be an even number of numbers. 
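+
+	The serialized form, per the pack() call below, is an int32 equal to
+	len(data) followed by that many float8 values.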
+ """ + return pack("!l%dd" %(len(data),), len(data), *data) + +def path_unpack(data, long_unpack = long_unpack, unpack = struct.unpack): + """ + Unpack a path's serialized form into a sequence of point data: + + [px1, py1, px2, py2, ...] + + Should be an even number of numbers. + """ + return unpack("!4x%dd" %(long_unpack(data[:4]),), data) +polygon_pack, polygon_unpack = path_pack, path_unpack + +## +# Binary representations of infinity for datetimes. +time_infinity = b'\x7f\xf0\x00\x00\x00\x00\x00\x00' +time_negative_infinity = b'\xff\xf0\x00\x00\x00\x00\x00\x00' +time64_infinity = b'\x7f\xff\xff\xff\xff\xff\xff\xff' +time64_negative_infinity = b'\x80\x00\x00\x00\x00\x00\x00\x00' +date_infinity = b'\x7f\xff\xff\xff' +date_negative_infinity = b'\x80\x00\x00\x00' + +# time types +date_pack, date_unpack = long_pack, long_unpack + +def mktimetuple(ts, floor = floor): + 'make a pair of (seconds, microseconds) out of the given double' + seconds = floor(ts) + return (int(seconds), int(1000000 * (ts - seconds))) + +def mktimetuple64(ts, divmod = divmod): + 'make a pair of (seconds, microseconds) out of the given long' + return divmod(ts, 1000000) + +def mktime(seconds_ms, float = float): + 'make a double out of the pair of (seconds, microseconds)' + return float(seconds_ms[0]) + (seconds_ms[1] / 1000000.0) + +def mktime64(seconds_ms): + 'make an integer out of the pair of (seconds, microseconds)' + return seconds_ms[0] * 1000000 + seconds_ms[1] + +# takes a pair, (seconds, microseconds) +time_pack = compose((mktime, double_pack)) +time_unpack = compose((double_unpack, mktimetuple)) + +def interval_pack(m_d_timetup, mktime = mktime, dll_pack = dll_pack): + """ + Given a triple, (month, day, (seconds, microseconds)), serialize it for + transport. + """ + (month, day, timetup) = m_d_timetup + return dll_pack((mktime(timetup), day, month)) + +def interval_unpack(data, dll_unpack = dll_unpack, mktimetuple = mktimetuple): + """ + Given a serialized interval, '{month}{day}{time}', yield the triple: + + (month, day, (seconds, microseconds)) + """ + tim, day, month = dll_unpack(data) + return (month, day, mktimetuple(tim)) + +def interval_noday_pack(month_day_timetup, dl_pack = dl_pack, mktime = mktime): + """ + Given a triple, (month, day, (seconds, microseconds)), return the serialized + form that does not have an individual day component. + + There is no day component, so if day is non-zero, it will be converted to + seconds and subsequently added to the seconds. + """ + (month, day, timetup) = month_day_timetup + if day: + timetup = (timetup[0] + (day * 24 * 60 * 60), timetup[1]) + return dl_pack((mktime(timetup), month)) + +def interval_noday_unpack(data, dl_unpack = dl_unpack, mktimetuple = mktimetuple): + """ + Given a serialized interval without a day component, return the triple: + + (month, day(always zero), (seconds, microseconds)) + """ + tim, month = dl_unpack(data) + return (month, 0, mktimetuple(tim)) + +def time64_pack(data, mktime64 = mktime64, longlong_pack = longlong_pack): + return longlong_pack(mktime64(data)) +def time64_unpack(data, longlong_unpack = longlong_unpack, mktimetuple64 = mktimetuple64): + return mktimetuple64(longlong_unpack(data)) + +def interval64_pack(m_d_timetup, qll_pack = qll_pack, mktime64 = mktime64): + """ + Given a triple, (month, day, (seconds, microseconds)), return the serialized + data using a quad-word for the (seconds, microseconds) tuple. 
+ """ + (month, day, timetup) = m_d_timetup + return qll_pack((mktime64(timetup), day, month)) + +def interval64_unpack(data, qll_unpack = qll_unpack, mktimetuple = mktimetuple): + """ + Unpack an interval containing a quad-word into a triple: + + (month, day, (seconds, microseconds)) + """ + tim, day, month = qll_unpack(data) + return (month, day, mktimetuple64(tim)) + +def interval64_noday_pack(m_d_timetup, ql_pack = ql_pack, mktime64 = mktime64): + """ + Pack an interval without a day component and using a quad-word for second + representation. + + There is no day component, so if day is non-zero, it will be converted to + seconds and subsequently added to the seconds. + """ + (month, day, timetup) = m_d_timetup + if day: + timetup = (timetup[0] + (day * 24 * 60 * 60), timetup[1]) + return ql_pack((mktime64(timetup), month)) + +def interval64_noday_unpack(data, ql_unpack = ql_unpack, mktimetuple64 = mktimetuple64): + """ + Unpack a ``noday`` quad-word based interval. Returns a triple: + + (month, day(always zero), (seconds, microseconds)) + """ + tim, month = ql_unpack(data) + return (month, 0, mktimetuple64(tim)) + +def timetz_pack(timetup_tz, dl_pack = dl_pack, mktime = mktime): + """ + Pack a time; offset from beginning of the day and timezone offset. + + Given a pair, ((seconds, microseconds), timezone_offset), pack it into its + serialized form: "!dl". + """ + (timetup, tz_offset) = timetup_tz + return dl_pack((mktime(timetup), tz_offset)) + +def timetz_unpack(data, dl_unpack = dl_unpack, mktimetuple = mktimetuple): + """ + Given serialized time data, unpack it into a pair: + + ((seconds, microseconds), timezone_offset). + """ + ts, tz = dl_unpack(data) + return (mktimetuple(ts), tz) + +def timetz64_pack(timetup_tz, ql_pack = ql_pack, mktime64 = mktime64): + """ + Pack a time; offset from beginning of the day and timezone offset. + + Given a pair, ((seconds, microseconds), timezone_offset), pack it into its + serialized form using a long long: "!ql". + """ + (timetup, tz_offset) = timetup_tz + return ql_pack((mktime64(timetup), tz_offset)) + +def timetz64_unpack(data, ql_unpack = ql_unpack, mktimetuple64 = mktimetuple64): + """ + Given "long long" serialized time data, "ql", unpack it into a pair: + + ((seconds, microseconds), timezone_offset) + """ + ts, tz = ql_unpack(data) + return (mktimetuple64(ts), tz) + +# oidvectors are 128 bytes, so pack the number of Oids in self +# and justify that to 128 by padding with \x00. +def oidvector_pack(seq, pack = struct.pack): + """ + Given a sequence of Oids, pack them into the serialized form. + + An oidvector is a type used by the PostgreSQL catalog. + """ + return pack("!%dL"%(len(seq),), *seq).ljust(128, '\x00') + +def oidvector_unpack(data, unpack = struct.unpack): + """ + Given a serialized oidvector(32 longs), unpack it into a list of unsigned integers. + + An int2vector is a type used by the PostgreSQL catalog. + """ + return unpack("!32L", data) + +def int2vector_pack(seq, pack = struct.pack): + """ + Given a sequence of integers, pack them into the serialized form. + + An int2vector is a type used by the PostgreSQL catalog. + """ + return pack("!%dh"%(len(seq),), *seq).ljust(64, '\x00') + +def int2vector_unpack(data, unpack = struct.unpack): + """ + Given a serialized int2vector, unpack it into a list of integers. + + An int2vector is a type used by the PostgreSQL catalog. + """ + return unpack("!32h", data) + +def varbit_pack(bits_data, long_pack = long_pack): + r""" + Given a pair, serialize the varbit. 
+ + # (number of bits, data) + >>> varbit_pack((1, '\x00')) + b'\x00\x00\x00\x01\x00' + """ + return long_pack(bits_data[0]) + bits_data[1] + +def varbit_unpack(data, long_unpack = long_unpack): + """ + Given ``varbit`` data, unpack it into a pair: + + (bits, data) + + Where bits are the total number of bits in data (bytes). + """ + return long_unpack(data[0:4]), data[4:] + +def net_pack(triple, + # Map PGSQL src/include/utils/inet.h to IP version number. + fmap = { + 4: 2, + 6: 3, + }, + len = len, +): + """ + net_pack((family, mask, data)) + + Pack Postgres' inet/cidr data structure. + """ + family, mask, data = triple + return bytes((fmap[family], mask or 0, 0 if mask is None else 1, len(data))) + data + +def net_unpack(data, + # Map IP version number to PGSQL src/include/utils/inet.h. + fmap = { + 2: 4, + 3: 6, + } +): + """ + net_unpack(data) + + Unpack Postgres' inet/cidr data structure. + """ + family, mask, is_cidr, size = data[:4] + return (fmap[family], mask, data[4:]) + +def macaddr_pack(data, bytes = bytes): + """ + Pack a MAC address + + Format found in PGSQL src/backend/utils/adt/mac.c, and PGSQL Manual types + """ + # Accept all possible PGSQL Macaddr formats as in manual + # Oh for sscanf() as we could just copy PGSQL C in src/util/adt/mac.c + colon_parts = data.split(':') + dash_parts = data.split('-') + dot_parts = data.split('.') + if len(colon_parts) == 6: + mac_parts = colon_parts + elif len(dash_parts) == 6: + mac_parts = dash_parts + elif len(colon_parts) == 2: + mac_parts = [colon_parts[0][:2], colon_parts[0][2:4], colon_parts[0][4:], + colon_parts[1][:2], colon_parts[1][2:4], colon_parts[1][4:]] + elif len(dash_parts) == 2: + mac_parts = [dash_parts[0][:2], dash_parts[0][2:4], dash_parts[0][4:], + dash_parts[1][:2], dash_parts[1][2:4], dash_parts[1][4:]] + elif len(dot_parts) == 3: + mac_parts = [dot_parts[0][:2], dot_parts[0][2:], dot_parts[1][:2], + dot_parts[1][2:], dot_parts[2][:2], dot_parts[2][2:]] + elif len(colon_parts) == 1: + mac_parts = [data[:2], data[2:4], data[4:6], data[6:8], data[8:10], data[10:]] + else: + raise ValueError('data string cannot be parsed to bytes') + if len(mac_parts) != 6 and len(mac_parts[-1]) != 2: + raise ValueError('data string cannot be parsed to bytes') + return bytes([int(p, 16) for p in mac_parts]) + +def macaddr_unpack(data): + """ + Unpack a MAC address + + Format found in PGSQL src/backend/utils/adt/mac.c + """ + # This is easy, just go for standard macaddr format, + # just like PGSQL in src/util/adt/mac.c macaddr_out() + if len(data) != 6: + raise ValueError('macaddr has incorrect length') + return ("%02x:%02x:%02x:%02x:%02x:%02x" % tuple(data)) + +def record_unpack(data, + long_unpack = long_unpack, + oid_unpack = oid_unpack, + null_sequence = null_sequence, + len = len): + """ + Given serialized record data, return a tuple of tuples of type Oids and + attributes. 
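+
+	Note: the (type_oid, attribute_data) pairs are yielded lazily; this
+	function is a generator rather than returning a materialized tuple.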
+ """ + columns = long_unpack(data) + offset = 4 + + for x in range(columns): + typid = oid_unpack(data[offset:offset+4]) + offset += 4 + + if data[offset:offset+4] == null_sequence: + att = None + offset += 4 + else: + size = long_unpack(data[offset:offset+4]) + offset += 4 + att = data[offset:offset + size] + if size < -1 or len(att) != size: + raise ValueError("insufficient data left in message") + offset += size + yield (typid, att) + + if len(data) - offset != 0: + raise ValueError("extra data, %d octets, at end of record" %(len(data),)) + +def record_pack(seq, + long_pack = long_pack, + oid_pack = oid_pack, + null_sequence = null_sequence): + """ + pack a record given an iterable of (type_oid, data) pairs. + """ + return long_pack(len(seq)) + b''.join([ + # typid + (null_seq or data) + oid_pack(x) + (y is None and null_sequence or (long_pack(len(y)) + y)) + for x, y in seq + ]) + +def elements_pack(elements, + null_sequence = null_sequence, + long_pack = long_pack, len = len +): + """ + Pack the elements for containment within a serialized array. + + This is used by array_pack. + """ + for x in elements: + if x is None: + yield null_sequence + else: + yield long_pack(len(x)) + yield x + +def array_pack(array_data, + llL_pack = llL_pack, + len = len, + long_pack = long_pack, + interlace = interlace +): + """ + Pack a raw array. A raw array consists of flags, type oid, sequence of lower + and upper bounds, and an iterable of already serialized element data: + + (0, element type oid, (lower bounds, upper bounds, ...), iterable of element_data) + + The lower bounds and upper bounds specifies boundaries of the dimension. So the length + of the boundaries sequence is two times the number of dimensions that the array has. + + array_pack((flags, type_id, dims, lowers, element_data)) + + The format of ``lower_upper_bounds`` is a sequence of lower bounds and upper + bounds. First lower then upper inlined within the sequence: + + [lower, upper, lower, upper] + + The above array `dlb` has two dimensions. The lower and upper bounds of the + first dimension is defined by the first two elements in the sequence. The + second dimension is then defined by the last two elements in the sequence. + """ + (flags, typid, dims, lbs, elements) = array_data + return llL_pack((len(dims), flags, typid)) + \ + b''.join(map(long_pack, interlace(dims, lbs))) + \ + b''.join(elements_pack(elements)) + +def elements_unpack(data, offset, + long_unpack = long_unpack, + null_sequence = null_sequence): + """ + Unpack the serialized elements of an array into a list. + + This is used by array_unpack. + """ + data_len = len(data) + while offset < data_len: + lend = data[offset:offset+4] + offset += 4 + if lend == null_sequence: + yield None + else: + sizeof_el = long_unpack(lend) + yield data[offset:offset+sizeof_el] + offset += sizeof_el + +def array_unpack(data, + llL_unpack = llL_unpack, + unpack = struct.unpack_from, + long_unpack = long_unpack +): + """ + Given a serialized array, unpack it into a tuple: + + (flags, typid, (dims, lower bounds, ...), [elements]) + """ + ndim, flags, typid = llL_unpack(data) + if ndim < 0: + raise ValueError("invalid number of dimensions: %d" %(ndim,)) + # "ndim" number of pairs of longs + end = (4 * 2 * ndim) + 12 + # Dimensions and lower bounds; split the two early. 
+ #dlb = unpack("!%dl"%(2 * ndim,), data, 12) + dims = [long_unpack(data[x:x+4]) for x in range(12, end, 8)] + lbs = [long_unpack(data[x:x+4]) for x in range(16, end, 8)] + return (flags, typid, dims, lbs, elements_unpack(data, end)) diff --git a/py_opengauss/types/io/pg_bitwise.py b/py_opengauss/types/io/pg_bitwise.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb36e965b8a7ef486a2f5e2a0bbcd7b67c1e52c --- /dev/null +++ b/py_opengauss/types/io/pg_bitwise.py @@ -0,0 +1,19 @@ +from .. import BITOID, VARBITOID +from ..bitwise import Varbit, Bit +from . import lib + +def varbit_pack(x, pack = lib.varbit_pack): + return pack((x.bits, x.data)) + +def varbit_unpack(x, unpack = lib.varbit_unpack): + return Varbit.from_bits(*unpack(x)) + +oid_to_io = { + BITOID : (varbit_pack, varbit_unpack, Bit), + VARBITOID : (varbit_pack, varbit_unpack, Varbit), +} + +oid_to_type = { + BITOID : Bit, + VARBITOID : Varbit, +} diff --git a/py_opengauss/types/io/pg_geometry.py b/py_opengauss/types/io/pg_geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..56d7759f26db452002f8a959ece2d2c244072508 --- /dev/null +++ b/py_opengauss/types/io/pg_geometry.py @@ -0,0 +1,43 @@ +from .. import POINTOID, BOXOID, LSEGOID, CIRCLEOID +from ..geometry import Point, Box, Lseg, Circle +from ...python.functools import Composition as compose +from . import lib + +oid_to_type = { + POINTOID: Point, + BOXOID: Box, + LSEGOID: Lseg, + CIRCLEOID: Circle, +} + +# Make a pair of pairs out of a sequence of four objects +def two_pair(x): + return ((x[0], x[1]), (x[2], x[3])) + +point_pack = lib.point_pack +point_unpack = compose((lib.point_unpack, Point)) + +def box_pack(x): + return lib.box_pack((x[0][0], x[0][1], x[1][0], x[1][1])) +box_unpack = compose((lib.box_unpack, two_pair, Box,)) + +def lseg_pack(x, pack = lib.lseg_pack): + return pack((x[0][0], x[0][1], x[1][0], x[1][1])) +lseg_unpack = compose((lib.lseg_unpack, two_pair, Lseg)) + +def circle_pack(x): + return lib.circle_pack((x[0][0], x[0][1], x[1])) +def circle_unpack(x, unpack = lib.circle_unpack, Circle = Circle): + x = unpack(x) + return Circle(((x[0], x[1]), x[2])) + +# Map type oids to a (pack, unpack) pair. +oid_to_io = { + POINTOID : (point_pack, point_unpack, Point), + BOXOID : (box_pack, box_unpack, Box), + LSEGOID : (lseg_pack, lseg_unpack, Lseg), + CIRCLEOID : (circle_pack, circle_unpack, Circle), + #PATHOID : (path_pack, path_unpack), + #POLYGONOID : (polygon_pack, polygon_unpack), + #LINEOID : (line_pack, line_unpack), +} diff --git a/py_opengauss/types/io/pg_network.py b/py_opengauss/types/io/pg_network.py new file mode 100644 index 0000000000000000000000000000000000000000..a74d5e16dca532a918353357a51748c61d792cff --- /dev/null +++ b/py_opengauss/types/io/pg_network.py @@ -0,0 +1,31 @@ +from .. import INETOID, CIDROID, MACADDROID +from . 
import lib +import ipaddress + +oid_to_type = { + MACADDROID : str, + INETOID: ipaddress._IPAddressBase, + CIDROID: ipaddress._BaseNetwork, +} + +def inet_pack(ob, pack = lib.net_pack, Constructor = ipaddress.ip_address): + a = Constructor(ob) + return pack((a.version, None, a.packed)) + +def cidr_pack(ob, pack = lib.net_pack, Constructor = ipaddress.ip_network): + a = Constructor(ob) + return pack((a.version, a.prefixlen, a.network_address.packed)) + +def inet_unpack(data, unpack = lib.net_unpack, Constructor = ipaddress.ip_address): + version, mask, data = unpack(data) + return Constructor(data) + +def cidr_unpack(data, unpack = lib.net_unpack, Constructor = ipaddress.ip_network): + version, mask, data = unpack(data) + return Constructor(data).supernet(new_prefix=mask) + +oid_to_io = { + MACADDROID : (lib.macaddr_pack, lib.macaddr_unpack, str), + CIDROID : (cidr_pack, cidr_unpack, str), + INETOID : (inet_pack, inet_unpack, str), +} diff --git a/py_opengauss/types/io/pg_system.py b/py_opengauss/types/io/pg_system.py new file mode 100644 index 0000000000000000000000000000000000000000..70ce964619652bc64b344076f51b46ee2c0a5813 --- /dev/null +++ b/py_opengauss/types/io/pg_system.py @@ -0,0 +1,10 @@ +from ...types import OIDOID, XIDOID, CIDOID, TIDOID +from . import lib + +oid_to_io = { + OIDOID : (lib.oid_pack, lib.oid_unpack), + XIDOID : (lib.xid_pack, lib.xid_unpack), + CIDOID : (lib.cid_pack, lib.cid_unpack), + TIDOID : (lib.tid_pack, lib.tid_unpack), + #ACLITEMOID : (aclitem_pack, aclitem_unpack), +} diff --git a/py_opengauss/types/io/stdlib_datetime.py b/py_opengauss/types/io/stdlib_datetime.py new file mode 100644 index 0000000000000000000000000000000000000000..995e126c50514275c11a5e83f7c9eaa1aa63e240 --- /dev/null +++ b/py_opengauss/types/io/stdlib_datetime.py @@ -0,0 +1,304 @@ +## +# stdlib_datetime - support for the stdlib's datetime. +# +# I/O routines for date, time, timetz, timestamp, timestamptz, and interval. +# Supported by the datetime module. +## +import datetime +import warnings +from functools import partial +from operator import methodcaller, add + +from ...python.datetime import UTC, FixedOffset, \ + infinity_date, infinity_datetime, \ + negative_infinity_date, negative_infinity_datetime +from ...python.functools import Composition as compose +from ...exceptions import TypeConversionWarning + +from .. import \ + DATEOID, INTERVALOID, \ + TIMEOID, TIMETZOID, \ + TIMESTAMPOID, TIMESTAMPTZOID + +from . import lib + +oid_to_type = { + DATEOID: datetime.date, + TIMESTAMPOID: datetime.datetime, + TIMESTAMPTZOID: datetime.datetime, + TIMEOID: datetime.time, + TIMETZOID: datetime.time, + + # XXX: datetime.timedelta doesn't support months. + INTERVALOID: datetime.timedelta, +} + +seconds_in_day = 24 * 60 * 60 +seconds_in_hour = 60 * 60 + +pg_epoch_datetime = datetime.datetime(2000, 1, 1) +pg_epoch_datetime_utc = pg_epoch_datetime.replace(tzinfo = UTC) +pg_epoch_date = pg_epoch_datetime.date() +pg_date_offset = pg_epoch_date.toordinal() + +## Difference between PostgreSQL epoch and Unix epoch. +## Used to convert a PostgreSQL ordinal to an ordinal usable by datetime +pg_time_days = (pg_date_offset - datetime.date(1970, 1, 1).toordinal()) + +## +# Constants used to special case infinity and -infinity. 
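+# For example, date_pack(infinity_date) and date_pack('infinity') both
+# short-circuit to lib.date_infinity via date_pack_constants below; ordinary
+# dates take the ordinal-offset path in date_pack instead.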
+time64_pack_constants = { + infinity_datetime: lib.time64_infinity, + negative_infinity_datetime: lib.time64_negative_infinity, + 'infinity': lib.time64_infinity, + '-infinity': lib.time64_negative_infinity, +} +time_pack_constants = { + infinity_datetime: lib.time_infinity, + negative_infinity_datetime: lib.time_negative_infinity, + 'infinity': lib.time_infinity, + '-infinity': lib.time_negative_infinity, +} +date_pack_constants = { + infinity_date: lib.date_infinity, + negative_infinity_date: lib.date_negative_infinity, + 'infinity': lib.date_infinity, + '-infinity': lib.date_negative_infinity, +} +time64_unpack_constants = { + lib.time64_infinity: infinity_datetime, + lib.time64_negative_infinity: negative_infinity_datetime, +} +time_unpack_constants = { + lib.time_infinity: infinity_datetime, + lib.time_negative_infinity: negative_infinity_datetime, +} +date_unpack_constants = { + lib.date_infinity: infinity_date, + lib.date_negative_infinity: negative_infinity_date, +} + +def date_pack(x, + pack = lib.date_pack, + offset = pg_date_offset, + get = date_pack_constants.get, +): + return get(x) or pack(x.toordinal() - offset) + +def date_unpack(x, + unpack = lib.date_unpack, + offset = pg_date_offset, + from_ord = datetime.date.fromordinal, + get = date_unpack_constants.get, +): + return get(x) or from_ord(unpack(x) + pg_date_offset) + +def timestamp_pack(x, + seconds_in_day = seconds_in_day, + pg_epoch_datetime = pg_epoch_datetime, +): + """ + Create a (seconds, microseconds) pair from a `datetime.datetime` instance. + """ + x = (x - pg_epoch_datetime) + return ((x.days * seconds_in_day) + x.seconds, x.microseconds) + +def timestamp_unpack(seconds, + timedelta = datetime.timedelta, + relative_to = pg_epoch_datetime.__add__, +): + """ + Create a `datetime.datetime` instance from a (seconds, microseconds) pair. + """ + return relative_to(timedelta(0, *seconds)) + +def timestamptz_pack(x, + seconds_in_day = seconds_in_day, + pg_epoch_datetime_utc = pg_epoch_datetime_utc, + UTC = UTC, +): + """ + Create a (seconds, microseconds) pair from a `datetime.datetime` instance. + """ + x = (x.astimezone(UTC) - pg_epoch_datetime_utc) + return ((x.days * seconds_in_day) + x.seconds, x.microseconds) + +def timestamptz_unpack(seconds, + timedelta = datetime.timedelta, + relative_to = pg_epoch_datetime_utc.__add__, +): + """ + Create a `datetime.datetime` instance from a (seconds, microseconds) pair. + """ + return relative_to(timedelta(0, *seconds)) + +def time_pack(x, seconds_in_hour = seconds_in_hour): + """ + Create a (seconds, microseconds) pair from a `datetime.time` instance. + """ + return ( + (x.hour * seconds_in_hour) + (x.minute * 60) + x.second, + x.microsecond + ) + +def time_unpack(seconds_ms, time = datetime.time, divmod = divmod): + """ + Create a `datetime.time` instance from a (seconds, microseconds) pair. + Seconds being offset from epoch. + """ + seconds, ms = seconds_ms + minutes, sec = divmod(seconds, 60) + hours, min = divmod(minutes, 60) + return time(hours, min, sec, ms) + +def interval_pack(x): + """ + Create a (months, days, (seconds, microseconds)) tuple from a + `datetime.timedelta` instance. + """ + return (0, x.days, (x.seconds, x.microseconds)) + +def interval_unpack(mds, timedelta = datetime.timedelta): + """ + Given a (months, days, (seconds, microseconds)) tuple, create a + `datetime.timedelta` instance. + """ + months, days, seconds_ms = mds + if months != 0: + # XXX: Should this raise an exception? 
+ w = TypeConversionWarning( + "datetime.timedelta cannot represent relative intervals", + details = { + 'hint': 'An interval was unpacked with a non-zero "month" field.' + }, + source = 'DRIVER' + ) + warnings.warn(w) + return timedelta( + days = days + (months * 30), + seconds = seconds_ms[0], microseconds = seconds_ms[1] + ) + +def timetz_pack(x, + time_pack = time_pack, +): + """ + Create a ((seconds, microseconds), timezone) tuple from a `datetime.time` + instance. + """ + td = x.tzinfo.utcoffset(x) + seconds = (td.days * seconds_in_day + td.seconds) + return (time_pack(x), seconds) + +def timetz_unpack(tstz, + time_unpack = time_unpack, + FixedOffset = FixedOffset, +): + """ + Create a `datetime.time` instance from a ((seconds, microseconds), timezone) + tuple. + """ + t = time_unpack(tstz[0]) + return t.replace(tzinfo = FixedOffset(tstz[1])) + +FloatTimes = False +IntTimes = True +NoDay = True +WithDay = False + +# Used to handle the special cases: infinity and -infinity. +def proc_when_not_in(proc, dict): + def _proc(x, get=dict.get): + return get(x) or proc(x) + return _proc + +id_to_io = { + (FloatTimes, TIMEOID) : ( + compose((time_pack, lib.time_pack)), + compose((lib.time_unpack, time_unpack)), + datetime.time + ), + (FloatTimes, TIMETZOID) : ( + compose((timetz_pack, lib.timetz_pack)), + compose((lib.timetz_unpack, timetz_unpack)), + datetime.time + ), + (FloatTimes, TIMESTAMPOID) : ( + proc_when_not_in(compose((timestamp_pack, lib.time_pack)), time_pack_constants), + proc_when_not_in(compose((lib.time_unpack, timestamp_unpack)), time_unpack_constants), + datetime.datetime + ), + (FloatTimes, TIMESTAMPTZOID) : ( + proc_when_not_in(compose((timestamptz_pack, lib.time_pack)), time_pack_constants), + proc_when_not_in(compose((lib.time_unpack, timestamptz_unpack)), time_unpack_constants), + datetime.datetime + ), + (FloatTimes, WithDay, INTERVALOID): ( + compose((interval_pack, lib.interval_pack)), + compose((lib.interval_unpack, interval_unpack)), + datetime.timedelta + ), + (FloatTimes, NoDay, INTERVALOID): ( + compose((interval_pack, lib.interval_noday_pack)), + compose((lib.interval_noday_unpack, interval_unpack)), + datetime.timedelta + ), + + (IntTimes, TIMEOID) : ( + compose((time_pack, lib.time64_pack)), + compose((lib.time64_unpack, time_unpack)), + datetime.time + ), + (IntTimes, TIMETZOID) : ( + compose((timetz_pack, lib.timetz64_pack)), + compose((lib.timetz64_unpack, timetz_unpack)), + datetime.time + ), + (IntTimes, TIMESTAMPOID) : ( + proc_when_not_in(compose((timestamp_pack, lib.time64_pack)), time64_pack_constants), + proc_when_not_in(compose((lib.time64_unpack, timestamp_unpack)), time64_unpack_constants), + datetime.datetime + ), + (IntTimes, TIMESTAMPTZOID) : ( + proc_when_not_in(compose((timestamptz_pack, lib.time64_pack)), time64_pack_constants), + proc_when_not_in(compose((lib.time64_unpack, timestamptz_unpack)), time64_unpack_constants), + datetime.datetime + ), + (IntTimes, WithDay, INTERVALOID) : ( + compose((interval_pack, lib.interval64_pack)), + compose((lib.interval64_unpack, interval_unpack)), + datetime.timedelta + ), + (IntTimes, NoDay, INTERVALOID) : ( + compose((interval_pack, lib.interval64_noday_pack)), + compose((lib.interval64_noday_unpack, interval_unpack)), + datetime.timedelta + ), +} + +## +# Identify whether it's IntTimes or FloatTimes +def time_type(typio): + idt = typio.database.settings.get('integer_datetimes', None) + if idt is None: + # assume its absence means its on after 9.0 + return bool(typio.database.version_info >= (9,0)) + 
elif idt.__class__ is bool: + return idt + else: + return (idt.lower() in ('on', 'true', 't', True)) + +def select_format(oid, typio, get = id_to_io.__getitem__): + return get((time_type(typio), oid)) + +def select_day_format(oid, typio, get = id_to_io.__getitem__): + return get((time_type(typio), typio.database.version_info[:2] <= (8,0), oid)) + +oid_to_io = { + DATEOID : (date_pack, date_unpack, datetime.date,), + TIMEOID : select_format, + TIMETZOID : select_format, + TIMESTAMPOID : select_format, + TIMESTAMPTZOID : select_format, + INTERVALOID : select_day_format, +} diff --git a/py_opengauss/types/io/stdlib_decimal.py b/py_opengauss/types/io/stdlib_decimal.py new file mode 100644 index 0000000000000000000000000000000000000000..3deee388df78609caf774c07d4f5672a4e1554c7 --- /dev/null +++ b/py_opengauss/types/io/stdlib_decimal.py @@ -0,0 +1,165 @@ +## +# types.io.stdlib_decimal +# +# I/O routines for transforming NUMERIC to and from decimal.Decimal. +## +from decimal import Decimal +from operator import itemgetter, mul +# You know it's gonna get serious :) +from itertools import chain, starmap, repeat, groupby, cycle, islice +from ...types import NUMERICOID +from . import lib + +oid_to_type = { + NUMERICOID: Decimal, +} + +## +# numeric is represented using: +# 1. ndigits, the number of *numeric* digits. +# 2. weight, the *numeric* digits "left" of the decimal point +# 3. sign, negativity. see `numeric_signs` below +# 4. dscale, *display* precision. used to identify exponent. +# +# NOTE: A numeric digit is actually four digits in the representation. +# +# Python's Decimal consists of: +# 1. sign, negativity. +# 2. digits, sequence of int()'s +# 3. exponent, digits that fall to the right of the decimal point +numeric_negative = 16384 + +def numeric_pack(x, + numeric_digit_length : "number of decimal digits in a numeric digit" = 4, + get0 = itemgetter(0), + get1 = itemgetter(1), + Decimal = Decimal, + pack = lib.numeric_pack +): + if not isinstance(x, Decimal): + x = Decimal(x) + x = x.as_tuple() + if x.exponent == 'F': + raise ValueError("numeric does not support infinite values") + + # normalize trailing zeros (truncate em') + # this is important in order to get the weight and padding correct + # and to avoid packing superfluous data which will make pg angry. + trailing_zeros = 0 + weight = 0 + if x.exponent < 0: + # only attempt to truncate if there are digits after the point, + ## + for i in range(-1, max(-len(x.digits), x.exponent)-1, -1): + if x.digits[i] != 0: + break + trailing_zeros += 1 + # truncate trailing zeros right of the decimal point + # this *is* the case as exponent < 0. + if trailing_zeros: + digits = x.digits[:-trailing_zeros] + else: + digits = x.digits + # the entire exponent is just trailing zeros(zero-weight). + rdigits = -(x.exponent + trailing_zeros) + ldigits = len(digits) - rdigits + rpad = rdigits % numeric_digit_length + if rpad: + rpad = numeric_digit_length - rpad + else: + # Need the weight to be divisible by four, + # so append zeros onto digits until it is. + r = (x.exponent % numeric_digit_length) + if x.exponent and r: + digits = x.digits + ((0,) * r) + weight = (x.exponent - r) + else: + digits = x.digits + weight = x.exponent + # The exponent is not evenly divisible by four, so + # the weight can't simple be x.exponent as it doesn't + # match the size of the numeric digit. + ldigits = len(digits) + # no fractional quantity. 
+ rdigits = 0 + rpad = 0 + + lpad = ldigits % numeric_digit_length + if lpad: + lpad = numeric_digit_length - lpad + weight += (ldigits + lpad) + + digit_groups = map( + get1, + groupby( + zip( + # group by NUMERIC digit size, + # every four digits make up a NUMERIC digit + cycle((0,) * numeric_digit_length + (1,) * numeric_digit_length), + + # multiply each digit appropriately + # for the eventual sum() into a NUMERIC digit + starmap( + mul, + zip( + # pad with leading zeros to make + # the cardinality of the digit sequence + # to be evenly divisible by four, + # the NUMERIC digit size. + chain( + repeat(0, lpad), + digits, + repeat(0, rpad), + ), + cycle([10**x for x in range(numeric_digit_length-1, -1, -1)]), + ) + ), + ), + get0, + ), + ) + return pack(( + ( + (ldigits + rdigits + lpad + rpad) // numeric_digit_length, # ndigits + (weight // numeric_digit_length) - 1, # NUMERIC weight + numeric_negative if x.sign == 1 else x.sign, # sign + - x.exponent if x.exponent < 0 else 0, # dscale + ), + list(map(sum, ([get1(y) for y in x] for x in digit_groups))), + )) + +def numeric_convert_digits(d, str = str, int = int): + i = iter(d) + try: + for x in str(next(i)): + # no leading zeros + yield int(x) + # leading digit should not include zeros + for y in i: + for x in str(y).rjust(4, '0'): + yield int(x) + except StopIteration: + # Python 3.5+ does not like generators raising StopIteration + return + +numeric_signs = { + numeric_negative : 1, +} + +def numeric_unpack(x, unpack = lib.numeric_unpack): + header, digits = unpack(x) + npad = (header[3] - ((header[0] - (header[1] + 1)) * 4)) + return Decimal(( + numeric_signs.get(header[2], header[2]), + tuple(chain( + numeric_convert_digits(digits), + (0,) * npad + ) if npad >= 0 else list( + numeric_convert_digits(digits) + )[:npad]), + -header[3] + )) + +oid_to_io = { + NUMERICOID : (numeric_pack, numeric_unpack, Decimal), +} diff --git a/py_opengauss/types/io/stdlib_jsonb.py b/py_opengauss/types/io/stdlib_jsonb.py new file mode 100644 index 0000000000000000000000000000000000000000..08223569d7b9a0912d7342253bcbee39aecb738d --- /dev/null +++ b/py_opengauss/types/io/stdlib_jsonb.py @@ -0,0 +1,24 @@ +from ...types import JSONBOID + + +def jsonb_pack(x, typeio): + jsonb = typeio.encode(x) + return b'\x01' + jsonb + + +def jsonb_unpack(x, typeio): + if x[0] != 1: + raise ValueError('unexpected JSONB format version: {!r}'.format(x[0])) + return typeio.decode(x[1:]) + + +def _jsonb_io_factory(oid, typeio): + _pack = lambda x: jsonb_pack(x, typeio) + _unpack = lambda x: jsonb_unpack(x, typeio) + + return (_pack, _unpack, str) + + +oid_to_io = { + JSONBOID: _jsonb_io_factory +} diff --git a/py_opengauss/types/io/stdlib_uuid.py b/py_opengauss/types/io/stdlib_uuid.py new file mode 100644 index 0000000000000000000000000000000000000000..3fc33b912bb881cf8ed1ec1c8dddd4b9fad35681 --- /dev/null +++ b/py_opengauss/types/io/stdlib_uuid.py @@ -0,0 +1,14 @@ +import uuid +from ...types import UUIDOID + +def uuid_pack(x, UUID = uuid.UUID, bytes = bytes): + if isinstance(x, UUID): + return bytes(x.bytes) + return bytes(UUID(x).bytes) + +def uuid_unpack(x, UUID = uuid.UUID): + return UUID(bytes=x) + +oid_to_io = { + UUIDOID : (uuid_pack, uuid_unpack), +} diff --git a/py_opengauss/types/io/stdlib_xml_etree.py b/py_opengauss/types/io/stdlib_xml_etree.py new file mode 100644 index 0000000000000000000000000000000000000000..52d56f7f449cf82e2c85b88310f44ac8156e7303 --- /dev/null +++ b/py_opengauss/types/io/stdlib_xml_etree.py @@ -0,0 +1,83 @@ +## +# types.io.stdlib_xml_etree 
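+# Packs str, tuple, and ElementTree values into XML text; unpacks XML text
+# into an Element, or a tuple of elements when the data is a multi-rooted
+# content fragment (see xml_unpack below).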
+## +try: + import xml.etree.cElementTree as etree +except ImportError: + import xml.etree.ElementTree as etree +from .. import XMLOID +from ...python.functools import Composition as compose + +oid_to_type = { + XMLOID: etree.ElementTree, +} + +def xml_unpack(xmldata, XML = etree.XML): + try: + return XML(xmldata) + except Exception: + # try it again, but return the sequence of children. + return tuple(XML('' + xmldata + '')) + +if not hasattr(etree, 'tostringlist'): + # Python 3.1 support. + def xml_pack(xml, tostr = etree.tostring, et = etree.ElementTree, + str = str, isinstance = isinstance, tuple = tuple + ): + if isinstance(xml, str): + # If it's a string, encode and return. + return xml + elif isinstance(xml, tuple): + # If it's a tuple, encode and return the joined items. + # We do not accept lists here--emphasizing lists being used for ARRAY + # bounds. + return ''.join((x if isinstance(x, str) else tostr(x) for x in xml)) + return tostr(xml) + + def xml_io_factory(typoid, typio, c = compose): + return ( + c((xml_pack, typio.encode)), + c((typio.decode, xml_unpack)), + etree.ElementTree, + ) +else: + # New etree tostring API. + def xml_pack(xml, encoding, encoder, + tostr = etree.tostring, et = etree.ElementTree, + str = str, isinstance = isinstance, tuple = tuple, + ): + if isinstance(xml, bytes): + return xml + if isinstance(xml, str): + # If it's a string, encode and return. + return encoder(xml) + elif isinstance(xml, tuple): + # If it's a tuple, encode and return the joined items. + # We do not accept lists here--emphasizing lists being used for ARRAY + # bounds. + ## + # 3.2 + # XXX: tostring doesn't include declaration with utf-8? + x = b''.join( + x.encode('utf-8') if isinstance(x, str) else + tostr(x, encoding = "utf-8") + for x in xml + ) + else: + ## + # 3.2 + # XXX: tostring doesn't include declaration with utf-8? + x = tostr(xml, encoding = "utf-8") + if encoding in ('utf8','utf-8'): + return x + else: + return encoder(x.decode('utf-8')) + + def xml_io_factory(typoid, typio, c = compose): + def local_xml_pack(x, encoder = typio.encode, typio = typio, xml_pack = xml_pack): + return xml_pack(x, typio.encoding, encoder) + return (local_xml_pack, c((typio.decode, xml_unpack)), etree.ElementTree,) + +oid_to_io = { + XMLOID : xml_io_factory +} diff --git a/py_opengauss/types/namedtuple.py b/py_opengauss/types/namedtuple.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ac61b8b5fefd5bc6f9934265bcc11227ae681a --- /dev/null +++ b/py_opengauss/types/namedtuple.py @@ -0,0 +1,68 @@ +## +# .types.namedtuple - return rows as namedtuples +## +""" +Factories for namedtuple row representation. +""" +from collections import namedtuple + +#: Global namedtuple type cache. +cache = {} + +# Build and cache the namedtuple's produced. +def _factory(colnames : [str], namedtuple = namedtuple) -> tuple: + global cache + # Provide some normalization. + # Anything beyond this can just get renamed. + colnames = tuple([ + x.replace(' ', '_') for x in colnames + ]) + try: + return cache[colnames] + except KeyError: + NT = namedtuple('row', colnames, rename = True) + cache[colnames] = NT + return NT + +def NamedTupleFactory(attribute_map, composite_relid = None): + """ + Alternative db.typio.RowFactory for producing namedtuple's instead of + postgresql.types.Row() instances. + + To install:: + + >>> from py_opengauss.types.namedtuple import NamedTupleFactory + >>> import py_opengauss + >>> db = py_opengauss.open(...) 
+ >>> db.typio.RowTypeFactory(NamedTupleFactory) + + And **all** Rows produced by that connection will be namedtuple()'s. + This includes composites. + """ + colnames = list(attribute_map.items()) + colnames.sort(key = lambda x: x[1]) + return lambda y: _factory((x[0] for x in colnames))(*y) + +from itertools import chain, starmap + +def namedtuples(stmt, from_iter = chain.from_iterable, map = starmap): + """ + Alternative to the .rows() execution method. + + Use:: + + >>> from py_opengauss.types.namedtuple import namedtuples + >>> ps = namedtuples(db.prepare(...)) + >>> for nt in ps(...): + ... nt.a_column_name + + This effectively selects the execution method to be used with the statement. + """ + NT = _factory(stmt.column_names) + # build the execution "method" + chunks = stmt.chunks + def rows_as_namedtuples(*args, **kw): + return map(NT, from_iter(chunks(*args, **kw))) # starmap + return rows_as_namedtuples + +del chain, starmap diff --git a/py_opengauss/versionstring.py b/py_opengauss/versionstring.py new file mode 100644 index 0000000000000000000000000000000000000000..58b518bf1d3a6a2535c4445cfbe70cefa2a30899 --- /dev/null +++ b/py_opengauss/versionstring.py @@ -0,0 +1,119 @@ +## +# .versionstring +## +""" +PostgreSQL version string parsing. + +>>> postgresql.versionstring.split('8.0.1') +(8, 0, 1, None, None) +""" + +def split(vstr: str) -> tuple: + """ + Split a PostgreSQL version string into a tuple. + (major, minor, patch, ..., state_class, state_level) + """ + # bug: eg. version is 13.2 (Debian 13.2-1.pgdg100+1), so split first. + v = vstr.strip().split()[0].split('.') + + # Get rid of the numbers around the state_class (beta,a,dev,alpha, etc) + state_class = v[-1].strip('0123456789') + if state_class: + last_version, state_level = v[-1].split(state_class) + if not state_level: + state_level = None + else: + state_level = int(state_level) + vlist = [int(x or '0') for x in v[:-1]] + if last_version: + vlist.append(int(last_version)) + vlist += [None] * (3 - len(vlist)) + vlist += [state_class, state_level] + else: + state_level = None + state_class = None + vlist = [int(x or '0') for x in v] + # pad the difference with `None` objects, and +2 for the state_*. + vlist += [None] * ((3 - len(vlist)) + 2) + return tuple(vlist) + +def unsplit(vtup: tuple) -> str: + """ + Join a version tuple back into the original version string. + """ + svtup = [str(x) for x in vtup[:-2] if x is not None] + state_class, state_level = vtup[-2:] + return '.'.join(svtup) + ('' if state_class is None else state_class + str(state_level)) + +def normalize(split_version: tuple) -> tuple: + """ + Given a tuple produced by `split`, normalize the `None` objects into int(0) + or 'final' if it's the ``state_class``. 
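+
+	A sketch of the expected result, given `split` above:
+
+	>>> normalize(split('8.0'))
+	(8, 0, 0, 'final', 0)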
+ """ + (*head, state_class, state_level) = split_version + mmp = [x if x is not None else 0 for x in head] + return tuple(mmp + [state_class or 'final', state_level or 0]) + +default_state_class_priority = [ + 'dev', + 'a', + 'alpha', + 'b', + 'beta', + 'rc', + 'final', + None, +] + +python = repr + +def xml(self): + return '\n' + \ + ' ' + str(self[0]) + '\n' + \ + ' ' + str(self[1]) + '\n' + \ + ' ' + str(self[2]) + '\n' + \ + ' ' + str(self[-2]) + '\n' + \ + ' ' + str(self[-1]) + '\n' + \ + '' + +def sh(self): + return """PG_VERSION_MAJOR=%s +PG_VERSION_MINOR=%s +PG_VERSION_PATCH=%s +PG_VERSION_STATE=%s +PG_VERSION_LEVEL=%s""" %( + str(self[0]), + str(self[1]), + str(self[2]), + str(self[-2]), + str(self[-1]), + ) + +if __name__ == '__main__': + import sys + import os + from optparse import OptionParser + op = OptionParser() + op.add_option('-f', '--format', + type='choice', + dest='format', + help='format of output information', + choices=('sh', 'xml', 'python'), + default='sh', + ) + op.add_option('-n', '--normalize', + action='store_true', + dest='normalize', + help='replace missing values with defaults', + default=False, + ) + op.set_usage(op.get_usage().strip() + ' "version to parse"') + co, ca = op.parse_args() + if len(ca) != 1: + op.error('requires exactly one argument, the version') + else: + v = split(ca[0]) + if co.normalize: + v = normalize(v) + sys.stdout.write(getattr(sys.modules[__name__], co.format)(v)) + sys.stdout.write(os.linesep) diff --git a/readthedocs.yml b/readthedocs.yml new file mode 100644 index 0000000000000000000000000000000000000000..d75e54ac240aada66677ac0bf6cdcaff2b3b6f7c --- /dev/null +++ b/readthedocs.yml @@ -0,0 +1,5 @@ +build: + image: latest + +python: + version: 3.7 diff --git a/setup.py b/setup.py new file mode 100755 index 0000000000000000000000000000000000000000..fa89b177b64d860a06d0ce0f69fb607a2be05c55 --- /dev/null +++ b/setup.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +## +# setup.py - .release.distutils +## +import sys +import os + +if sys.version_info[:2] < (3,3): + sys.stderr.write( + "ERROR: py-postgresql is for Python 3.3 and greater." + os.linesep + ) + sys.stderr.write( + "HINT: setup.py was ran using Python " + \ + '.'.join([str(x) for x in sys.version_info[:3]]) + + ': ' + sys.executable + os.linesep + ) + sys.exit(1) + +# distutils data is kept in `postgresql.release.distutils` +sys.path.insert(0, '') + +sys.dont_write_bytecode = True +import py_opengauss.release.distutils as dist +defaults = dist.standard_setup_keywords() +sys.dont_write_bytecode = False + +if __name__ == '__main__': + try: + from setuptools import setup + except ImportError as e: + from distutils.core import setup + setup(**defaults)