updated .env file and working now

This commit is contained in:
abhishek
2026-05-26 11:18:35 +00:00
parent 0235ccbaa2
commit 8e936240ee
6335 changed files with 819584 additions and 3 deletions

View File

@@ -0,0 +1,110 @@
"""
psycopg -- PostgreSQL database adapter for Python
"""
# Copyright (C) 2020 The Psycopg Team
import logging
from . import pq # noqa: F401 import early to stabilize side effects
from . import types
from . import postgres
from ._tpc import Xid
from .copy import Copy, AsyncCopy
from ._enums import IsolationLevel
from .cursor import Cursor
from .errors import Warning, Error, InterfaceError, DatabaseError
from .errors import DataError, OperationalError, IntegrityError
from .errors import InternalError, ProgrammingError, NotSupportedError
from ._column import Column
from ._pipeline import Pipeline, AsyncPipeline
from .connection import BaseConnection, Connection, Notify
from .transaction import Rollback, Transaction, AsyncTransaction
from .cursor_async import AsyncCursor
from .server_cursor import AsyncServerCursor, ServerCursor
from .client_cursor import AsyncClientCursor, ClientCursor
from ._connection_info import ConnectionInfo
from .connection_async import AsyncConnection
from . import dbapi20
from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING
from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks
from .dbapi20 import Timestamp, TimestampFromTicks
from .version import __version__ as __version__ # noqa: F401
# Keep the "psycopg" logger quiet unless a level was already configured on it.
logger = logging.getLogger("psycopg")
if logger.level == logging.NOTSET:
    logger.setLevel(logging.WARNING)

# Module-level attributes required for DBAPI compliance.
apilevel = "2.0"
threadsafety = 2
paramstyle = "pyformat"
connect = Connection.connect

# Global adapters map, exposed by the package. The registration order below
# matters:
# 1. the default PostgreSQL adapters;
# 2. the DBAPI ones, after the default ones, because these can deal with the
#    bytea oid better;
# 3. the arrays, which must come after all the types have been registered.
adapters = postgres.adapters
postgres.register_default_adapters(adapters)
dbapi20.register_dbapi20_adapters(adapters)
types.array.register_all_arrays(adapters)
# Note: defining the exported methods helps both Sphinx in documenting that
# this is the canonical place to obtain them and should be used by MyPy too,
# so that function signatures are consistent with the documentation.
__all__ = [
"AsyncClientCursor",
"AsyncConnection",
"AsyncCopy",
"AsyncCursor",
"AsyncPipeline",
"AsyncServerCursor",
"AsyncTransaction",
"BaseConnection",
"ClientCursor",
"Column",
"Connection",
"ConnectionInfo",
"Copy",
"Cursor",
"IsolationLevel",
"Notify",
"Pipeline",
"Rollback",
"ServerCursor",
"Transaction",
"Xid",
# DBAPI exports
"connect",
"apilevel",
"threadsafety",
"paramstyle",
"Warning",
"Error",
"InterfaceError",
"DatabaseError",
"DataError",
"OperationalError",
"IntegrityError",
"InternalError",
"ProgrammingError",
"NotSupportedError",
# DBAPI type constructors and singletons
"Binary",
"Date",
"DateFromTicks",
"Time",
"TimeFromTicks",
"Timestamp",
"TimestampFromTicks",
"BINARY",
"DATETIME",
"NUMBER",
"ROWID",
"STRING",
]

View File

@@ -0,0 +1,296 @@
"""
Mapping from types/oids to Dumpers/Loaders
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from typing import cast, TYPE_CHECKING
from . import pq
from . import errors as e
from .abc import Dumper, Loader
from ._enums import PyFormat as PyFormat
from ._cmodule import _psycopg
from ._typeinfo import TypesRegistry
if TYPE_CHECKING:
from .connection import BaseConnection
RV = TypeVar("RV")
class AdaptersMap:
    r"""
    Establish how types should be converted between Python and PostgreSQL in
    an `~psycopg.abc.AdaptContext`.

    `!AdaptersMap` maps Python types to `~psycopg.adapt.Dumper` classes to
    define how Python types are converted to PostgreSQL, and maps OIDs to
    `~psycopg.adapt.Loader` classes to establish how query results are
    converted to Python.

    Every `!AdaptContext` object has an underlying `!AdaptersMap` defining how
    types are converted in that context, exposed as the
    `~psycopg.abc.AdaptContext.adapters` attribute: changing such map allows
    to customise adaptation in a context without changing separated contexts.

    When a context is created from another context (for instance when a
    `~psycopg.Cursor` is created from a `~psycopg.Connection`), the parent's
    `!adapters` are used as template for the child's `!adapters`, so that every
    cursor created from the same connection use the connection's types
    configuration, but separate connections have independent mappings.

    Once created, `!AdaptersMap` are independent. This means that objects
    already created are not affected if a wider scope (e.g. the global one) is
    changed.

    The connections adapters are initialised using a global `!AdaptersMap`
    template, exposed as `psycopg.adapters`: changing such mapping allows to
    customise the type mapping for every connections created afterwards.

    The object can start empty or copy from another object of the same class.
    Copies are copy-on-write: if the maps are updated make a copy. This way
    extending e.g. global map by a connection or a connection map from a cursor
    is cheap: a copy is only made on customisation.
    """

    __module__ = "psycopg.adapt"

    types: TypesRegistry

    _dumpers: Dict[PyFormat, Dict[Union[type, str], Type[Dumper]]]
    _dumpers_by_oid: List[Dict[int, Type[Dumper]]]
    _loaders: List[Dict[int, Type[Loader]]]

    # Record if a dumper or loader has an optimised version.
    _optimised: Dict[type, type] = {}

    def __init__(
        self,
        template: Optional["AdaptersMap"] = None,
        types: Optional[TypesRegistry] = None,
    ):
        """
        Create an empty map or one sharing the content of `!template`.

        :param template: The map to use as a starting point, if any.
        :param types: The types registry to use when no `!template` is given.
        """
        if template:
            # Only the outer containers are copied: the inner maps are shared
            # with the template. Both objects are flagged as not owning them,
            # so whichever is customised first makes its own copy.
            self._dumpers = template._dumpers.copy()
            self._own_dumpers = _dumpers_shared.copy()
            template._own_dumpers = _dumpers_shared.copy()

            self._dumpers_by_oid = template._dumpers_by_oid[:]
            self._own_dumpers_by_oid = [False, False]
            template._own_dumpers_by_oid = [False, False]

            self._loaders = template._loaders[:]
            self._own_loaders = [False, False]
            template._own_loaders = [False, False]

            self.types = TypesRegistry(template.types)

        else:
            self._dumpers = {fmt: {} for fmt in PyFormat}
            self._own_dumpers = _dumpers_owned.copy()

            self._dumpers_by_oid = [{}, {}]
            self._own_dumpers_by_oid = [True, True]

            self._loaders = [{}, {}]
            self._own_loaders = [True, True]

            self.types = types or TypesRegistry()

    # implement the AdaptContext protocol too
    @property
    def adapters(self) -> "AdaptersMap":
        """The map itself."""
        return self

    @property
    def connection(self) -> Optional["BaseConnection[Any]"]:
        """Always `!None`: the map is not bound to a connection."""
        return None

    def register_dumper(
        self, cls: Union[type, str, None], dumper: Type[Dumper]
    ) -> None:
        """
        Configure the context to use `!dumper` to convert objects of type `!cls`.

        If two dumpers with different `~Dumper.format` are registered for the
        same type, the last one registered will be chosen when the query
        doesn't specify a format (i.e. when the value is used with a ``%s``
        "`~PyFormat.AUTO`" placeholder).

        :param cls: The type to manage.
        :param dumper: The dumper to register for `!cls`.

        If `!cls` is specified as string it will be lazy-loaded, so that it
        will be possible to register it without importing it before. In this
        case it should be the fully qualified name of the object (e.g.
        ``"uuid.UUID"``).

        If `!cls` is None, only use the dumper when looking up using
        `get_dumper_by_oid()`, which happens when we know the Postgres type to
        adapt to, but not the Python type that will be adapted (e.g. in COPY
        after using `~psycopg.Copy.set_types()`).

        Raise `!TypeError` if `!cls` is not a class, a string, or `!None`.
        """
        if not (cls is None or isinstance(cls, (str, type))):
            raise TypeError(
                f"dumpers should be registered on classes, got {cls} instead"
            )

        if _psycopg:
            dumper = self._get_optimised(dumper)

        # Register the dumper both as its format and as auto
        # so that the last dumper registered is used in auto (%s) format
        if cls:
            for fmt in (PyFormat.from_pq(dumper.format), PyFormat.AUTO):
                if not self._own_dumpers[fmt]:
                    self._dumpers[fmt] = self._dumpers[fmt].copy()
                    self._own_dumpers[fmt] = True

                self._dumpers[fmt][cls] = dumper

        # Register the dumper by oid, if the oid of the dumper is fixed
        if dumper.oid:
            if not self._own_dumpers_by_oid[dumper.format]:
                self._dumpers_by_oid[dumper.format] = self._dumpers_by_oid[
                    dumper.format
                ].copy()
                self._own_dumpers_by_oid[dumper.format] = True

            self._dumpers_by_oid[dumper.format][dumper.oid] = dumper

    def register_loader(self, oid: Union[int, str], loader: Type["Loader"]) -> None:
        """
        Configure the context to use `!loader` to convert data of oid `!oid`.

        :param oid: The PostgreSQL OID or type name to manage.
        :param loader: The loader to register for `!oid`.

        If `oid` is specified as string, it refers to a type name, which is
        looked up in the `types` registry.

        Raise `!TypeError` if `!oid` is neither an integer nor a string.
        """
        if isinstance(oid, str):
            oid = self.types[oid].oid

        if not isinstance(oid, int):
            raise TypeError(f"loaders should be registered on oid, got {oid} instead")

        if _psycopg:
            loader = self._get_optimised(loader)

        fmt = loader.format
        if not self._own_loaders[fmt]:
            self._loaders[fmt] = self._loaders[fmt].copy()
            self._own_loaders[fmt] = True

        self._loaders[fmt][oid] = loader

    def get_dumper(self, cls: type, format: PyFormat) -> Type["Dumper"]:
        """
        Return the dumper class for the given type and format.

        Raise `~psycopg.ProgrammingError` if a class is not available.

        :param cls: The class to adapt.
        :param format: The format to dump to. If `~psycopg.adapt.PyFormat.AUTO`,
            use the last one of the dumpers registered on `!cls`.
        """
        try:
            # Fast path: the class has a known dumper.
            return self._dumpers[format][cls]
        except KeyError:
            if format not in self._dumpers:
                raise ValueError(f"bad dumper format: {format}")

            # If the KeyError was caused by cls missing from dmap, let's
            # look for different cases.
            dmap = self._dumpers[format]

        # Look for the right class, including looking at superclasses
        for scls in cls.__mro__:
            if scls in dmap:
                return dmap[scls]

            # If the adapter is not found, look for its name as a string
            fqn = scls.__module__ + "." + scls.__qualname__
            if fqn in dmap:
                # Replace the class name with the class itself
                d = dmap[scls] = dmap.pop(fqn)
                return d

        format = PyFormat(format)
        raise e.ProgrammingError(
            f"cannot adapt type {cls.__name__!r} using placeholder '%{format.value}'"
            f" (format: {format.name})"
        )

    def get_dumper_by_oid(self, oid: int, format: pq.Format) -> Type["Dumper"]:
        """
        Return the dumper class for the given oid and format.

        Raise `~psycopg.ProgrammingError` if a class is not available.
        Raise `!ValueError` if `!format` is not a valid format.

        :param oid: The oid of the type to dump to.
        :param format: The format to dump to.
        """
        try:
            dmap = self._dumpers_by_oid[format]
        except (KeyError, IndexError):
            # The maps are stored in a list indexed by format, so an
            # out-of-range format fails with IndexError rather than KeyError.
            raise ValueError(f"bad dumper format: {format}")

        try:
            return dmap[oid]
        except KeyError:
            info = self.types.get(oid)
            if info:
                msg = (
                    f"cannot find a dumper for type {info.name} (oid {oid})"
                    f" format {pq.Format(format).name}"
                )
            else:
                msg = (
                    f"cannot find a dumper for unknown type with oid {oid}"
                    f" format {pq.Format(format).name}"
                )
            raise e.ProgrammingError(msg)

    def get_loader(self, oid: int, format: pq.Format) -> Optional[Type["Loader"]]:
        """
        Return the loader class for the given oid and format.

        Return `!None` if not found.

        :param oid: The oid of the type to load.
        :param format: The format to load from.
        """
        return self._loaders[format].get(oid)

    @classmethod
    def _get_optimised(self, cls: Type[RV]) -> Type[RV]:
        """Return the optimised version of a Dumper or Loader class.

        Return the input class itself if there is no optimised version.

        The result, positive or negative, is remembered in `!_optimised`.
        """
        try:
            return self._optimised[cls]
        except KeyError:
            pass

        # Check if the class comes from psycopg.types and there is a class
        # with the same name in psycopg_c._psycopg.
        from psycopg import types

        if cls.__module__.startswith(types.__name__):
            new = cast(Type[RV], getattr(_psycopg, cls.__name__, None))
            if new:
                self._optimised[cls] = new
                return new

        self._optimised[cls] = cls
        return cls
# Micro-optimization: AdaptersMap.__init__ copies these prebuilt flag maps,
# which is faster than creating new dicts every time.
_dumpers_owned = {fmt: True for fmt in PyFormat}
_dumpers_shared = {fmt: False for fmt in PyFormat}

View File

@@ -0,0 +1,24 @@
"""
Simplify access to the _psycopg module
"""
# Copyright (C) 2021 The Psycopg Team
from typing import Optional
from . import pq
__version__: Optional[str] = None
# Note: "c" must be the first attempt so that mypy associates the variable
# with the right module interface. It will not result Optional, but hey.
if pq.__impl__ == "c":
from psycopg_c import _psycopg as _psycopg
from psycopg_c import __version__ as __version__ # noqa: F401
elif pq.__impl__ == "binary":
from psycopg_binary import _psycopg as _psycopg # type: ignore
from psycopg_binary import __version__ as __version__ # type: ignore # noqa: F401
elif pq.__impl__ == "python":
_psycopg = None # type: ignore
else:
raise ImportError(f"can't find _psycopg optimised module in {pq.__impl__!r}")

View File

@@ -0,0 +1,142 @@
"""
The Column object in Cursor.description
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Any, NamedTuple, Optional, Sequence, TYPE_CHECKING
from operator import attrgetter
if TYPE_CHECKING:
from .cursor import BaseCursor
class ColumnData(NamedTuple):
    """Raw per-column values that `Column` reads from a result."""

    # Value of ``ftype()``; `Column.type_code` reports it as the column OID.
    ftype: int
    # Value of ``fmod()``; `Column` derives display size, precision and scale
    # from it.
    fmod: int
    # Value of ``fsize()``; `Column.internal_size` reports negative values as
    # None.
    fsize: int
class Column(Sequence[Any]):
    """
    Description of one column of a result.

    The object is a 7-item sequence whose items are, in order: `name`,
    `type_code`, `display_size`, `internal_size`, `precision`, `scale`,
    `null_ok`. Each item is also available as an attribute.
    """

    __module__ = "psycopg"

    def __init__(self, cursor: "BaseCursor[Any, Any]", index: int):
        res = cursor.pgresult
        assert res

        raw_name = res.fname(index)
        # COPY_OUT results have columns but no name
        self._name = (
            raw_name.decode(cursor._encoding) if raw_name else f"column_{index + 1}"
        )

        self._data = ColumnData(
            ftype=res.ftype(index),
            fmod=res.fmod(index),
            fsize=res.fsize(index),
        )
        self._type = cursor.adapters.types.get(self._data.ftype)

    # One getter per sequence item, in the order the items are exposed.
    _attrs = tuple(
        map(
            attrgetter,
            """
            name type_code display_size internal_size precision scale null_ok
            """.split(),
        )
    )

    def __repr__(self) -> str:
        shown_type = self._type_display()
        return f"<Column {self.name!r}, type: {shown_type} (oid: {self.type_code})>"

    def __len__(self) -> int:
        return 7

    def _type_display(self) -> str:
        """Return the type as a string, e.g. ``name(modifier, scale)[]``."""
        known = self._type
        text = known.name if known else str(self.type_code)

        # The precision is preferred as modifier; fall back to the display size
        modifier = self.precision
        if modifier is None:
            modifier = self.display_size

        if modifier:
            text += f"({modifier}"
            if self.scale:
                text += f", {self.scale}"
            text += ")"

        if known and self.type_code == known.array_oid:
            text += "[]"

        return text

    def __getitem__(self, index: Any) -> Any:
        if not isinstance(index, slice):
            return self._attrs[index](self)

        return tuple(getter(self) for getter in self._attrs[index])

    @property
    def name(self) -> str:
        """The name of the column."""
        return self._name

    @property
    def type_code(self) -> int:
        """The numeric OID of the column."""
        return self._data.ftype

    @property
    def display_size(self) -> Optional[int]:
        """The field size, for :sql:`varchar(n)`, None otherwise."""
        known = self._type
        if known and known.name in ("varchar", "char"):
            fmod = self._data.fmod
            if fmod >= 0:
                return fmod - 4

        return None

    @property
    def internal_size(self) -> Optional[int]:
        """The internal field size for fixed-size types, None otherwise."""
        size = self._data.fsize
        if size >= 0:
            return size
        return None

    @property
    def precision(self) -> Optional[int]:
        """The number of digits for fixed precision types."""
        if not self._type:
            return None

        type_name = self._type.name
        fmod = self._data.fmod
        if fmod >= 0:
            if type_name == "numeric":
                return fmod >> 16
            if type_name in (
                "time",
                "timetz",
                "timestamp",
                "timestamptz",
                "interval",
            ):
                return fmod & 0xFFFF

        return None

    @property
    def scale(self) -> Optional[int]:
        """The number of digits after the decimal point if available."""
        if not (self._type and self._type.name == "numeric"):
            return None

        packed = self._data.fmod - 4
        return packed & 0xFFFF if packed >= 0 else None

    @property
    def null_ok(self) -> Optional[bool]:
        """Always `!None`"""
        return None

View File

@@ -0,0 +1,73 @@
"""
compatibility functions for different Python versions
"""
# Copyright (C) 2021 The Psycopg Team
import sys
import asyncio
from typing import Any, Awaitable, Generator, Optional, Sequence, Union, TypeVar
# NOTE: TypeAlias cannot be exported by this module, as pyright special-cases it.
# For this reason it must be imported directly from typing_extensions where used.
# See https://github.com/microsoft/pyright/issues/4197
from typing_extensions import TypeAlias
if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol
T = TypeVar("T")
FutureT: TypeAlias = Union["asyncio.Future[T]", Generator[Any, None, T], Awaitable[T]]
if sys.version_info >= (3, 8):
create_task = asyncio.create_task
from math import prod
else:
def create_task(
    coro: FutureT[T], name: Optional[str] = None
) -> "asyncio.Future[T]":
    """Schedule `!coro` as a task on Python versions before 3.8.

    `!name` is accepted for signature compatibility but is not passed on.
    """
    task = asyncio.create_task(coro)
    return task
from functools import reduce
def prod(seq: Sequence[int]) -> int:
    """Return the product of the integers in `!seq` (1 if it is empty)."""
    result = 1
    for factor in seq:
        result = int.__mul__(result, factor)
    return result
if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
from functools import cache
from collections import Counter, deque as Deque
else:
from typing import Counter, Deque
from functools import lru_cache
from backports.zoneinfo import ZoneInfo
cache = lru_cache(maxsize=None)
if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard
if sys.version_info >= (3, 11):
from typing import LiteralString, Self
else:
from typing_extensions import LiteralString, Self
__all__ = [
"Counter",
"Deque",
"LiteralString",
"Protocol",
"Self",
"TypeGuard",
"ZoneInfo",
"cache",
"create_task",
"prod",
]

View File

@@ -0,0 +1,174 @@
"""
Objects to return information about a PostgreSQL connection.
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
from pathlib import Path
from datetime import tzinfo
from . import pq
from ._tz import get_tzinfo
from ._encodings import pgconn_encoding
from .conninfo import make_conninfo
class ConnectionInfo:
    """Allow access to information about the connection."""

    __module__ = "psycopg"

    def __init__(self, pgconn: pq.abc.PGconn):
        self.pgconn = pgconn

    @property
    def vendor(self) -> str:
        """A string representing the database vendor connected to."""
        return "PostgreSQL"

    @property
    def host(self) -> str:
        """The server host name of the active connection. See :pq:`PQhost()`."""
        return self._get_pgconn_attr("host")

    @property
    def hostaddr(self) -> str:
        """The server IP address of the connection. See :pq:`PQhostaddr()`."""
        return self._get_pgconn_attr("hostaddr")

    @property
    def port(self) -> int:
        """The port of the active connection. See :pq:`PQport()`."""
        port_text = self._get_pgconn_attr("port")
        return int(port_text)

    @property
    def dbname(self) -> str:
        """The database name of the connection. See :pq:`PQdb()`."""
        return self._get_pgconn_attr("db")

    @property
    def user(self) -> str:
        """The user name of the connection. See :pq:`PQuser()`."""
        return self._get_pgconn_attr("user")

    @property
    def password(self) -> str:
        """The password of the connection. See :pq:`PQpass()`."""
        return self._get_pgconn_attr("password")

    @property
    def options(self) -> str:
        """
        The command-line options passed in the connection request.
        See :pq:`PQoptions`.
        """
        return self._get_pgconn_attr("options")

    def get_parameters(self) -> dict[str, str]:
        """Return the connection parameters values.

        Return all the parameters set to a non-default value, which might come
        either from the connection string and parameters passed to
        `~Connection.connect()` or from environment variables. The password
        is never returned (you can read it using the `password` attribute).
        """
        codec = self.encoding

        # Collect the known defaults, so that they can be left out of the
        # result.
        known_defaults = {}
        for opt in pq.Conninfo.get_defaults():
            if opt.compiled is not None:
                known_defaults[opt.keyword] = opt.compiled

        # Not reported by the libpq among the defaults, so it is added here.
        known_defaults.setdefault(b"channel_binding", b"prefer")
        known_defaults[b"passfile"] = str(Path.home() / ".pgpass").encode()

        params = {}
        for opt in self.pgconn.info:
            if opt.val is None or opt.keyword == b"password":
                continue
            if opt.val == known_defaults.get(opt.keyword):
                continue
            params[opt.keyword.decode(codec)] = opt.val.decode(codec)

        return params

    @property
    def dsn(self) -> str:
        """Return the connection string to connect to the database.

        The string contains all the parameters set to a non-default value,
        which might come either from the connection string and parameters
        passed to `~Connection.connect()` or from environment variables. The
        password is never returned (you can read it using the `password`
        attribute).
        """
        non_default = self.get_parameters()
        return make_conninfo(**non_default)

    @property
    def status(self) -> pq.ConnStatus:
        """The status of the connection. See :pq:`PQstatus()`."""
        raw_status = self.pgconn.status
        return pq.ConnStatus(raw_status)

    @property
    def transaction_status(self) -> pq.TransactionStatus:
        """
        The current in-transaction status of the session.
        See :pq:`PQtransactionStatus()`.
        """
        raw_status = self.pgconn.transaction_status
        return pq.TransactionStatus(raw_status)

    @property
    def pipeline_status(self) -> pq.PipelineStatus:
        """
        The current pipeline status of the client.
        See :pq:`PQpipelineStatus()`.
        """
        raw_status = self.pgconn.pipeline_status
        return pq.PipelineStatus(raw_status)

    def parameter_status(self, param_name: str) -> str | None:
        """
        Return a parameter setting of the connection.

        Return `None` if the parameter is unknown.
        """
        raw = self.pgconn.parameter_status(param_name.encode(self.encoding))
        if raw is None:
            return None
        return raw.decode(self.encoding)

    @property
    def server_version(self) -> int:
        """
        An integer representing the server version. See :pq:`PQserverVersion()`.
        """
        return self.pgconn.server_version

    @property
    def backend_pid(self) -> int:
        """
        The process ID (PID) of the backend process handling this connection.
        See :pq:`PQbackendPID()`.
        """
        return self.pgconn.backend_pid

    @property
    def error_message(self) -> str:
        """
        The error message most recently generated by an operation on the connection.
        See :pq:`PQerrorMessage()`.
        """
        return self._get_pgconn_attr("error_message")

    @property
    def timezone(self) -> tzinfo:
        """The Python timezone info of the connection's timezone."""
        return get_tzinfo(self.pgconn)

    @property
    def encoding(self) -> str:
        """The Python codec name of the connection's client encoding."""
        return pgconn_encoding(self.pgconn)

    def _get_pgconn_attr(self, name: str) -> str:
        """Read the bytes attribute `!name` of the pgconn and decode it."""
        raw: bytes = getattr(self.pgconn, name)
        return raw.decode(self.encoding)

View File

@@ -0,0 +1,90 @@
"""
Separate connection attempts from a connection string.
"""
# Copyright (C) 2024 The Psycopg Team
from __future__ import annotations
import socket
import logging
from random import shuffle
from . import errors as e
from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def
from ._conninfo_utils import split_attempts
logger = logging.getLogger("psycopg")
def conninfo_attempts(params: ConnDict) -> list[ConnDict]:
    """Split a set of connection params on the single attempts to perform.

    A connection param can perform more than one attempt if more than one
    ``host`` is provided.

    Also perform the resolution of the hostname into hostaddr. Because a host
    can resolve to more than one address, this can lead to yield more attempts
    too. Raise `OperationalError` if no host could be resolved.

    Because the libpq async function doesn't honour the timeout, we need to
    reimplement the repeated attempts.

    :param params: The connection parameters to split.
    :return: The list of attempts to make, shuffled if the
        ``load_balance_hosts`` parameter is ``random``.
    """
    last_exc = None
    attempts = []
    for attempt in split_attempts(params):
        try:
            attempts.extend(_resolve_hostnames(attempt))
        except OSError as ex:
            # A host failing to resolve is not fatal as long as at least one
            # other attempt can be made.
            logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex))
            last_exc = ex

    if not attempts:
        assert last_exc
        # We couldn't resolve anything. Chain the last resolution error, so
        # that its traceback is not lost.
        raise e.OperationalError(str(last_exc)) from last_exc

    if get_param(params, "load_balance_hosts") == "random":
        shuffle(attempts)

    return attempts
def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
    """
    Perform DNS lookup of the host and return a list of connection attempts.

    The lookup is only performed if a ``host`` param is present and no
    ``hostaddr`` is; otherwise the parameters are returned unchanged.

    :param params: The input parameters, for instance as returned by
        `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
        a single entry for host, hostaddr because it is designed to further
        process the input of split_attempts().

    :return: A list of attempts to make (to include the case of a hostname
        resolving to more than one IP).
    """
    host = get_param(params, "host")

    # Local path, or no host to resolve
    nothing_to_resolve = not host or host.startswith("/") or host[1:2] == ":"
    if nothing_to_resolve:
        return [params]

    # Already resolved
    if get_param(params, "hostaddr"):
        return [params]

    # If the host is already an ip address don't try to resolve it
    if is_ip_address(host):
        return [{**params, "hostaddr": host}]

    port = get_param(params, "port")
    if not port:
        # Fall back on the compiled-in default, then on the literal "5432"
        port_def = get_param_def("port")
        compiled = port_def.compiled if port_def else None
        port = compiled or "5432"

    addrinfos = socket.getaddrinfo(
        host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
    )

    # One attempt per address found: the address is the first item of the
    # sockaddr, which is the last item of each getaddrinfo() entry.
    attempts = []
    for entry in addrinfos:
        sockaddr = entry[4]
        attempts.append({**params, "hostaddr": sockaddr[0]})
    return attempts

View File

@@ -0,0 +1,92 @@
"""
Separate connection attempts from a connection string.
"""
# Copyright (C) 2024 The Psycopg Team
from __future__ import annotations
import socket
import asyncio
import logging
from random import shuffle
from . import errors as e
from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def
from ._conninfo_utils import split_attempts
logger = logging.getLogger("psycopg")
async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]:
    """Split a set of connection params on the single attempts to perform.

    A connection param can perform more than one attempt if more than one
    ``host`` is provided.

    Also perform async resolution of the hostname into hostaddr. Because a host
    can resolve to more than one address, this can lead to yield more attempts
    too. Raise `OperationalError` if no host could be resolved.

    Because the libpq async function doesn't honour the timeout, we need to
    reimplement the repeated attempts.

    :param params: The connection parameters to split.
    :return: The list of attempts to make, shuffled if the
        ``load_balance_hosts`` parameter is ``random``.
    """
    last_exc = None
    attempts = []
    for attempt in split_attempts(params):
        try:
            attempts.extend(await _resolve_hostnames(attempt))
        except OSError as ex:
            # A host failing to resolve is not fatal as long as at least one
            # other attempt can be made.
            logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex))
            last_exc = ex

    if not attempts:
        assert last_exc
        # We couldn't resolve anything. Chain the last resolution error, so
        # that its traceback is not lost.
        raise e.OperationalError(str(last_exc)) from last_exc

    if get_param(params, "load_balance_hosts") == "random":
        shuffle(attempts)

    return attempts
async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
    """
    Perform async DNS lookup of the host and return a list of connection attempts.

    The lookup is only performed if a ``host`` param is present and no
    ``hostaddr`` is; otherwise the parameters are returned unchanged.

    :param params: The input parameters, for instance as returned by
        `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
        a single entry for host, hostaddr because it is designed to further
        process the input of split_attempts().

    :return: A list of attempts to make (to include the case of a hostname
        resolving to more than one IP).
    """
    host = get_param(params, "host")

    # Local path, or no host to resolve
    nothing_to_resolve = not host or host.startswith("/") or host[1:2] == ":"
    if nothing_to_resolve:
        return [params]

    # Already resolved
    if get_param(params, "hostaddr"):
        return [params]

    # If the host is already an ip address don't try to resolve it
    if is_ip_address(host):
        return [{**params, "hostaddr": host}]

    port = get_param(params, "port")
    if not port:
        # Fall back on the compiled-in default, then on the literal "5432"
        port_def = get_param_def("port")
        compiled = port_def.compiled if port_def else None
        port = compiled or "5432"

    # Resolve through the running event loop
    loop = asyncio.get_running_loop()
    addrinfos = await loop.getaddrinfo(
        host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
    )

    # One attempt per address found: the address is the first item of the
    # sockaddr, which is the last item of each getaddrinfo() entry.
    attempts = []
    for entry in addrinfos:
        sockaddr = entry[4]
        attempts.append({**params, "hostaddr": sockaddr[0]})
    return attempts

View File

@@ -0,0 +1,127 @@
"""
Internal utilities to manipulate connection strings
"""
# Copyright (C) 2024 The Psycopg Team
from __future__ import annotations
import os
from typing import Any
from functools import lru_cache
from ipaddress import ip_address
from dataclasses import dataclass
from typing_extensions import TypeAlias
from . import pq
from . import errors as e
ConnDict: TypeAlias = "dict[str, Any]"
def split_attempts(params: ConnDict) -> list[ConnDict]:
    """
    Split connection parameters with a sequence of hosts into separate attempts.

    The ``host``, ``hostaddr`` and ``port`` parameters (taken from `!params`
    or from the environment, see `get_param()`) are split on commas. A single
    port applies to every host.

    :param params: The connection parameters to split.
    :return: A list with one parameters dict per host to try. If there is at
        most one host, the list contains `!params` itself, unchanged.
    :raise OperationalError: if the number of hosts, hostaddrs or ports
        don't match.
    """

    def split_val(key: str) -> list[str]:
        val = get_param(params, key)
        return val.split(",") if val else []

    hosts = split_val("host")
    hostaddrs = split_val("hostaddr")
    ports = split_val("port")

    if hosts and hostaddrs and len(hosts) != len(hostaddrs):
        raise e.OperationalError(
            f"could not match {len(hosts)} host names"
            f" with {len(hostaddrs)} hostaddr values"
        )

    nhosts = max(len(hosts), len(hostaddrs))

    if 1 < len(ports) != nhosts:
        # Report nhosts, which is what the ports were compared with: hosts
        # may be empty if only hostaddr was specified.
        raise e.OperationalError(
            f"could not match {len(ports)} port numbers to {nhosts} hosts"
        )

    # A single attempt to make. Don't mangle the conninfo string.
    if nhosts <= 1:
        return [params]

    if len(ports) == 1:
        ports *= nhosts

    # Now all lists are either empty or have the same length
    rv = []
    for i in range(nhosts):
        attempt = params.copy()
        if hosts:
            attempt["host"] = hosts[i]
        if hostaddrs:
            attempt["hostaddr"] = hostaddrs[i]
        if ports:
            attempt["port"] = ports[i]
        rv.append(attempt)

    return rv
def get_param(params: ConnDict, name: str) -> str | None:
    """
    Return a value from a connection string.

    A value found in `!params` is returned converted to string. Otherwise
    the value is looked up in the environment variable associated with the
    parameter, if any. Return `!None` if the value is found in neither.
    """
    if name not in params:
        # TODO: check if in service
        paramdef = get_param_def(name)
        return os.environ.get(paramdef.envvar) if paramdef else None

    return str(params[name])
@dataclass
class ParamDef:
    """
    Information about defaults and env vars for connection params
    """

    # Name of the connection parameter
    keyword: str
    # Name of the environment variable providing the parameter's value;
    # get_param_def() stores an empty string when there is none
    envvar: str
    # Compiled-in default value, None if there isn't one
    compiled: str | None
def get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None:
    """
    Return the ParamDef of a connection string parameter.

    Return `!None` if the parameter is not known.

    The definitions are read from `pq.Conninfo.get_defaults()` only when
    `!_cache` is empty. The mutable default is intentional: it is the dict
    where the definitions are kept across calls.
    """
    if not _cache:
        for opt in pq.Conninfo.get_defaults():
            name = opt.keyword.decode()
            _cache[name] = ParamDef(
                keyword=name,
                envvar=opt.envvar.decode() if opt.envvar else "",
                compiled=None if opt.compiled is None else opt.compiled.decode(),
            )

    return _cache.get(keyword)
@lru_cache()
def is_ip_address(s: str) -> bool:
    """Return True if the string represent a valid ip address.

    The string is valid if `ipaddress.ip_address()` can parse it. Results
    are cached.
    """
    try:
        ip_address(s)
    except ValueError:
        return False
    else:
        return True

View File

@@ -0,0 +1,246 @@
# type: ignore # dnspython is currently optional and mypy fails if missing
"""
DNS query support
"""
# Copyright (C) 2021 The Psycopg Team
import os
import re
import warnings
from random import randint
from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Sequence
from typing import TYPE_CHECKING
from collections import defaultdict
try:
from dns.resolver import Resolver, Cache
from dns.asyncresolver import Resolver as AsyncResolver
from dns.exception import DNSException
except ImportError:
raise ImportError(
"the module psycopg._dns requires the package 'dnspython' installed"
)
from . import errors as e
from . import conninfo
if TYPE_CHECKING:
from dns.rdtypes.IN.SRV import SRV
resolver = Resolver()
resolver.cache = Cache()
async_resolver = AsyncResolver()
async_resolver.cache = Cache()
async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
    """
    Perform async DNS lookup of the hosts and return a new params dict.

    .. deprecated:: 3.1
        The use of this function is not necessary anymore, because
        `psycopg.AsyncConnection.connect()` performs non-blocking name
        resolution automatically.

    :param params: The connection parameters to resolve.
    :return: A copy of `!params` with ``host``, ``hostaddr``, ``port``
        replaced by the comma-separated values of all the attempts found.
    """
    warnings.warn(
        "from psycopg 3.1, resolve_hostaddr_async() is not needed anymore",
        DeprecationWarning,
        # Attribute the warning to the caller, not to this function.
        stacklevel=2,
    )
    hosts: list[str] = []
    hostaddrs: list[str] = []
    ports: list[str] = []

    # Collect the values of each attempt, skipping the missing ones
    for attempt in await conninfo.conninfo_attempts_async(params):
        if attempt.get("host") is not None:
            hosts.append(attempt["host"])
        if attempt.get("hostaddr") is not None:
            hostaddrs.append(attempt["hostaddr"])
        if attempt.get("port") is not None:
            ports.append(str(attempt["port"]))

    # Only overwrite the params for which some value was found
    out = params.copy()
    shosts = ",".join(hosts)
    if shosts:
        out["host"] = shosts
    shostaddrs = ",".join(hostaddrs)
    if shostaddrs:
        out["hostaddr"] = shostaddrs
    sports = ",".join(ports)
    if ports:
        out["port"] = sports

    return out
def resolve_srv(params: Dict[str, Any]) -> Dict[str, Any]:
    """Apply SRV DNS lookup as defined in :RFC:`2782`."""
    srv_resolver = Rfc2782Resolver()
    return srv_resolver.resolve(params)
async def resolve_srv_async(params: Dict[str, Any]) -> Dict[str, Any]:
    """Async equivalent of `resolve_srv()`."""
    srv_resolver = Rfc2782Resolver()
    resolved = await srv_resolver.resolve_async(params)
    return resolved
class HostPort(NamedTuple):
    """A host/port pair handled by `Rfc2782Resolver`."""

    host: str
    port: str
    # If true, the resolver performs the SRV lookup on this entry instead of
    # using it as it is.
    totry: bool = False
    # NOTE(review): presumably the target part of an SRV name matched by the
    # resolver's regular expression -- confirm against Rfc2782Resolver.
    target: Optional[str] = None
class Rfc2782Resolver:
    """Implement SRV RR Resolution as per RFC 2782

    The sync and async entry points only differ in the DNS query they issue;
    every other step is shared between the two paths.
    """

    re_srv_rr = re.compile(r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)")

    def resolve(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Update the parameters host and port after SRV lookup."""
        attempts = self._get_attempts(params)
        if not attempts:
            return params

        found = []
        for attempt in attempts:
            if not attempt.totry:
                found.append(attempt)
                continue
            found.extend(self._resolve_srv(attempt))

        return self._return_params(params, found)

    async def resolve_async(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Update the parameters host and port after SRV lookup."""
        attempts = self._get_attempts(params)
        if not attempts:
            return params

        found = []
        for attempt in attempts:
            if not attempt.totry:
                found.append(attempt)
                continue
            found.extend(await self._resolve_srv_async(attempt))

        return self._return_params(params, found)

    def _get_attempts(self, params: Dict[str, Any]) -> "List[HostPort]":
        """
        Return the hosts to connect to, each flagged if it needs a SRV lookup.

        Return an empty list if no lookup is requested.
        """
        # An explicit hostaddr disables any resolution.
        if params.get("hostaddr", os.environ.get("PGHOSTADDR", "")):
            return []

        host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
        hosts_in = host_arg.split(",")
        port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
        ports_in = port_arg.split(",")

        # A single port applies to every host.
        if len(ports_in) == 1:
            ports_in = ports_in * len(hosts_in)

        if len(ports_in) != len(hosts_in):
            # OperationalError rather than ProgrammingError: it is what is
            # raised when the libpq fails to connect in the same case.
            raise e.OperationalError(
                f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
            )

        attempts = []
        any_srv = False
        for host, port in zip(hosts_in, ports_in):
            match = self.re_srv_rr.match(host)
            if not (match or port.lower() == "srv"):
                attempts.append(HostPort(host=host, port=port))
                continue

            any_srv = True
            target = match.group("target") if match else None
            attempts.append(HostPort(host=host, port=port, totry=True, target=target))

        return attempts if any_srv else []

    def _resolve_srv(self, hp: "HostPort") -> "List[HostPort]":
        """Query the SRV records for `!hp` using the sync resolver."""
        try:
            answer = resolver.resolve(hp.host, "SRV")
        except DNSException:
            answer = ()
        return self._get_solved_entries(hp, answer)

    async def _resolve_srv_async(self, hp: "HostPort") -> "List[HostPort]":
        """Query the SRV records for `!hp` using the async resolver."""
        try:
            answer = await async_resolver.resolve(hp.host, "SRV")
        except DNSException:
            answer = ()
        return self._get_solved_entries(hp, answer)

    def _get_solved_entries(
        self, hp: "HostPort", entries: "Sequence[SRV]"
    ) -> "List[HostPort]":
        """Convert the SRV answer for `!hp` into the hosts to connect to."""
        if entries:
            # If there is precisely one SRV RR, and its Target is "." (the
            # root domain), abort.
            if len(entries) == 1 and str(entries[0].target) == ".":
                return []

            return [
                HostPort(host=str(entry.target).rstrip("."), port=str(entry.port))
                for entry in self.sort_rfc2782(entries)
            ]

        # No SRV entry found. Delegate the libpq a QNAME=target lookup
        if hp.target and hp.port.lower() != "srv":
            return [HostPort(host=hp.target, port=hp.port)]
        return []

    def _return_params(
        self, params: Dict[str, Any], hps: "List[HostPort]"
    ) -> Dict[str, Any]:
        """Return a copy of `!params` with host and port taken from `!hps`."""
        if not hps:
            # Nothing found, we ended up with an empty list
            raise e.OperationalError("no host found after SRV RR lookup")

        hosts = ",".join(hp.host for hp in hps)
        ports = ",".join(str(hp.port) for hp in hps)

        out = params.copy()
        out["host"] = hosts
        out["port"] = ports
        return out

    def sort_rfc2782(self, ans: "Sequence[SRV]") -> "List[SRV]":
        """
        Implement the priority/weight ordering defined in RFC 2782.
        """
        # Group the entries by priority
        by_priority: DefaultDict[int, "List[SRV]"] = defaultdict(list)
        for entry in ans:
            by_priority[entry.priority].append(entry)

        ordered: "List[SRV]" = []
        for priority in sorted(by_priority):
            group = by_priority[priority]
            if len(group) == 1:
                ordered.append(group[0])
                continue

            # Weighted random extraction, until the group is exhausted
            group.sort(key=lambda ent: ent.weight)
            total_weight = sum(ent.weight for ent in group)
            while group:
                r = randint(0, total_weight)
                csum = 0
                for i, ent in enumerate(group):
                    csum += ent.weight
                    if csum >= r:
                        break
                ordered.append(ent)
                total_weight -= ent.weight
                del group[i]

        return ordered

View File

@@ -0,0 +1,170 @@
"""
Mappings between PostgreSQL and Python encodings.
"""
# Copyright (C) 2020 The Psycopg Team
import re
import string
import codecs
from typing import Any, Dict, Optional, TYPE_CHECKING
from .pq._enums import ConnStatus
from .errors import NotSupportedError
from ._compat import cache
if TYPE_CHECKING:
from .pq.abc import PGconn
from .connection import BaseConnection
# Local alias, compared against `pgconn.status` in pgconn_encoding().
OK = ConnStatus.OK

# Map the PostgreSQL encoding names to the Python codec names.
_py_codecs = {
    "BIG5": "big5",
    "EUC_CN": "gb2312",
    "EUC_JIS_2004": "euc_jis_2004",
    "EUC_JP": "euc_jp",
    "EUC_KR": "euc_kr",
    # "EUC_TW": not available in Python
    "GB18030": "gb18030",
    "GBK": "gbk",
    "ISO_8859_5": "iso8859-5",
    "ISO_8859_6": "iso8859-6",
    "ISO_8859_7": "iso8859-7",
    "ISO_8859_8": "iso8859-8",
    "JOHAB": "johab",
    "KOI8R": "koi8-r",
    "KOI8U": "koi8-u",
    "LATIN1": "iso8859-1",
    "LATIN10": "iso8859-16",
    "LATIN2": "iso8859-2",
    "LATIN3": "iso8859-3",
    "LATIN4": "iso8859-4",
    "LATIN5": "iso8859-9",
    "LATIN6": "iso8859-10",
    "LATIN7": "iso8859-13",
    "LATIN8": "iso8859-14",
    "LATIN9": "iso8859-15",
    # "MULE_INTERNAL": not available in Python
    "SHIFT_JIS_2004": "shift_jis_2004",
    "SJIS": "shift_jis",
    # this actually means no encoding, see PostgreSQL docs
    # it is special-cased by the text loader.
    "SQL_ASCII": "ascii",
    "UHC": "cp949",
    "UTF8": "utf-8",
    "WIN1250": "cp1250",
    "WIN1251": "cp1251",
    "WIN1252": "cp1252",
    "WIN1253": "cp1253",
    "WIN1254": "cp1254",
    "WIN1255": "cp1255",
    "WIN1256": "cp1256",
    "WIN1257": "cp1257",
    "WIN1258": "cp1258",
    "WIN866": "cp866",
    "WIN874": "cp874",
}

# The same mapping, keyed by the PostgreSQL name as bytes; used by pg2pyenc().
py_codecs: Dict[bytes, str] = {}
py_codecs.update((k.encode(), v) for k, v in _py_codecs.items())

# Add an alias without underscore, for lenient lookups
py_codecs.update(
    (k.replace("_", "").encode(), v) for k, v in _py_codecs.items() if "_" in k
)

# Reverse mapping, from Python codec name to PostgreSQL name as bytes; used
# by py2pgenc().
pg_codecs = {v: k.encode() for k, v in _py_codecs.items()}
def conn_encoding(conn: "Optional[BaseConnection[Any]]") -> str:
    """
    Return the Python encoding name of a psycopg connection.

    Fall back on utf8 if there is no connection, if the connection is closed,
    or if it doesn't report a client encoding.
    """
    if conn and not conn.closed:
        pgenc = conn.pgconn.parameter_status(b"client_encoding") or b"UTF8"
        return pg2pyenc(pgenc)

    return "utf-8"
def pgconn_encoding(pgconn: "PGconn") -> str:
    """
    Return the Python encoding name of a libpq connection.

    Fall back on utf8 if the connection is not in OK status or if it doesn't
    report a client encoding.
    """
    if pgconn.status == OK:
        pgenc = pgconn.parameter_status(b"client_encoding") or b"UTF8"
        return pg2pyenc(pgenc)

    return "utf-8"
def conninfo_encoding(conninfo: str) -> str:
    """
    Return the Python encoding name passed in a conninfo string. Default to utf8.

    The value is likely to be typed by the user rather than normalised by the
    server, so the lookup is lenient (case-insensitive, noise chars ignored)
    and an unsupported encoding falls back on the default.
    """
    from .conninfo import conninfo_to_dict

    pgenc = conninfo_to_dict(conninfo).get("client_encoding")
    if not pgenc:
        return "utf-8"

    try:
        return pg2pyenc(pgenc.encode())
    except NotSupportedError:
        return "utf-8"
@cache
def py2pgenc(name: str) -> bytes:
    """Convert a Python encoding name to PostgreSQL encoding name.

    Raise LookupError if the Python encoding is unknown.
    """
    # Normalise the name through the codecs registry before the lookup.
    codec = codecs.lookup(name)
    return pg_codecs[codec.name]
@cache
def pg2pyenc(name: bytes) -> str:
    """Convert a PostgreSQL encoding name to Python encoding name.

    Raise NotSupportedError if the PostgreSQL encoding is not supported by
    Python.
    """
    # Lenient lookup: ignore dashes, underscores and case.
    key = name.replace(b"-", b"").replace(b"_", b"").upper()
    try:
        return py_codecs[key]
    except KeyError:
        sname = name.decode("utf8", "replace")
        raise NotSupportedError(f"codec not available in Python: {sname!r}")
def _as_python_identifier(s: str, prefix: str = "f") -> str:
"""
Reduce a string to a valid Python identifier.
Replace all non-valid chars with '_' and prefix the value with `!prefix` if
the first letter is an '_'.
"""
if not s.isidentifier():
if s[0] in "1234567890":
s = prefix + s
if not s.isidentifier():
s = _re_clean.sub("_", s)
# namedtuple fields cannot start with underscore. So...
if s[0] == "_":
s = prefix + s
return s
_re_clean = re.compile(
f"[^{string.ascii_lowercase}{string.ascii_uppercase}{string.digits}_]"
)

View File

@@ -0,0 +1,79 @@
"""
Enum values for psycopg
These values are defined by us and are not necessarily dependent on
libpq-defined enums.
"""
# Copyright (C) 2020 The Psycopg Team
from enum import Enum, IntEnum
from selectors import EVENT_READ, EVENT_WRITE
from . import pq
class Wait(IntEnum):
    """Wait states, with values taken from the `selectors` event masks."""

    R = EVENT_READ
    W = EVENT_WRITE
    # Both the events, as the bitwise OR of the two masks.
    RW = R | W
class Ready(IntEnum):
    """Ready states, with values taken from the `selectors` event masks."""

    R = EVENT_READ
    W = EVENT_WRITE
    # Both the events, as the bitwise OR of the two masks.
    RW = R | W
class PyFormat(str, Enum):
    """
    Enum representing the format wanted for a query argument.

    The value `AUTO` allows psycopg to choose the best format for a certain
    parameter.
    """

    # Present the class as part of the `psycopg.adapt` module.
    __module__ = "psycopg.adapt"

    AUTO = "s"
    """Automatically chosen (``%s`` placeholder)."""
    TEXT = "t"
    """Text parameter (``%t`` placeholder)."""
    BINARY = "b"
    """Binary parameter (``%b`` placeholder)."""

    @classmethod
    def from_pq(cls, fmt: pq.Format) -> "PyFormat":
        """Return the `PyFormat` matching the libpq format `!fmt`."""
        return _pg2py[fmt]

    @classmethod
    def as_pq(cls, fmt: "PyFormat") -> pq.Format:
        """Return the libpq format matching `!fmt`.

        Raise `KeyError` if `!fmt` is `AUTO`, which has no entry in the
        conversion map.
        """
        return _py2pg[fmt]
class IsolationLevel(IntEnum):
    """
    Enum representing the isolation level for a transaction.

    Being an `~enum.IntEnum`, the members compare equal to their integer
    values (1 to 4).
    """

    # Present the class as part of the top-level `psycopg` module.
    __module__ = "psycopg"

    READ_UNCOMMITTED = 1
    """:sql:`READ UNCOMMITTED` isolation level."""
    READ_COMMITTED = 2
    """:sql:`READ COMMITTED` isolation level."""
    REPEATABLE_READ = 3
    """:sql:`REPEATABLE READ` isolation level."""
    SERIALIZABLE = 4
    """:sql:`SERIALIZABLE` isolation level."""
# Conversion maps used by PyFormat.as_pq() and PyFormat.from_pq().
# Note that PyFormat.AUTO has no entry in either map.
_py2pg = {
    PyFormat.TEXT: pq.Format.TEXT,
    PyFormat.BINARY: pq.Format.BINARY,
}

_pg2py = {
    pq.Format.TEXT: PyFormat.TEXT,
    pq.Format.BINARY: PyFormat.BINARY,
}

View File

@@ -0,0 +1,295 @@
"""
commands pipeline management
"""
# Copyright (C) 2021 The Psycopg Team
import logging
from types import TracebackType
from typing import Any, List, Optional, Union, Tuple, Type, TYPE_CHECKING
from typing_extensions import TypeAlias
from . import pq
from . import errors as e
from .abc import PipelineCommand, PQGen
from ._compat import Deque, Self
from .pq.misc import connection_summary
from ._encodings import pgconn_encoding
from ._preparing import Key, Prepare
from .generators import pipeline_communicate, fetch_many, send
if TYPE_CHECKING:
from .pq.abc import PGresult
from .cursor import BaseCursor
from .connection import BaseConnection, Connection
from .connection_async import AsyncConnection
# An item of the pipeline result queue: None for a command whose result is
# only checked for errors (such as a sync), otherwise the cursor receiving
# the results, paired with the optional prepared statement info to validate.
PendingResult: TypeAlias = Union[
    None, Tuple["BaseCursor[Any, Any]", Optional[Tuple[Key, Prepare, bytes]]]
]

# Local aliases for the libpq enum values used in this module.
FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
BAD = pq.ConnStatus.BAD
ACTIVE = pq.TransactionStatus.ACTIVE

logger = logging.getLogger("psycopg")
class BasePipeline:
    """
    Shared implementation of the pipeline mode handling.

    The object keeps a queue of commands to send and a queue of the results
    expected back. The methods communicating with the server are generators,
    to be run by the `!wait()` method of the connection.
    """

    # Commands queued to be passed to `pipeline_communicate()`.
    command_queue: Deque[PipelineCommand]
    # One item per result expected; see `PendingResult`.
    result_queue: Deque[PendingResult]
    # Cached result of `is_supported()`; None until the first call.
    _is_supported: Optional[bool] = None

    def __init__(self, conn: "BaseConnection[Any]") -> None:
        self._conn = conn
        self.pgconn = conn.pgconn
        self.command_queue = Deque[PipelineCommand]()
        self.result_queue = Deque[PendingResult]()
        # Nesting level: incremented on enter, decremented on exit.
        self.level = 0

    def __repr__(self) -> str:
        cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        info = connection_summary(self._conn.pgconn)
        return f"<{cls} {info} at 0x{id(self):x}>"

    @property
    def status(self) -> pq.PipelineStatus:
        """The pipeline status reported by the libpq connection."""
        return pq.PipelineStatus(self.pgconn.pipeline_status)

    @classmethod
    def is_supported(cls) -> bool:
        """Return `!True` if the psycopg libpq wrapper supports pipeline mode."""
        # The result is stored on BasePipeline, so it is shared by subclasses.
        if BasePipeline._is_supported is None:
            BasePipeline._is_supported = not cls._not_supported_reason()
        return BasePipeline._is_supported

    @classmethod
    def _not_supported_reason(cls) -> str:
        """Return the reason why the pipeline mode is not supported.

        Return an empty string if pipeline mode is supported.
        """
        # Support only depends on the libpq functions available in the pq
        # wrapper, not on the database version.
        if pq.version() < 140000:
            return (
                f"libpq too old {pq.version()};"
                " v14 or greater required for pipeline mode"
            )

        if pq.__build_version__ < 140000:
            return (
                f"libpq too old: module built for {pq.__build_version__};"
                " v14 or greater required for pipeline mode"
            )

        return ""

    def _enter_gen(self) -> PQGen[None]:
        """Generator to enter the pipeline, or a further nesting level of it.

        Raise `NotSupportedError` if pipeline mode is not supported.
        """
        if not self.is_supported():
            raise e.NotSupportedError(
                f"pipeline mode not supported: {self._not_supported_reason()}"
            )
        if self.level == 0:
            self.pgconn.enter_pipeline_mode()
        elif self.command_queue or self.pgconn.transaction_status == ACTIVE:
            # Nested pipeline case.
            #  Transaction might be ACTIVE when the pipeline uses an "implicit
            #  transaction", typically in autocommit mode. But when entering a
            #  Psycopg transaction(), we expect the IDLE state. By sync()-ing,
            #  we make sure all previous commands are completed and the
            #  transaction gets back to IDLE.
            yield from self._sync_gen()
        self.level += 1

    def _exit(self, exc: Optional[BaseException]) -> None:
        """Leave a nesting level, exiting pipeline mode on the outermost one.

        `!exc` is the exception the pipeline block is terminating with, if
        any: in that case a failure to exit pipeline mode is only logged.
        """
        self.level -= 1
        # Don't try to exit pipeline mode on a connection in BAD status.
        if self.level == 0 and self.pgconn.status != BAD:
            try:
                self.pgconn.exit_pipeline_mode()
            except e.OperationalError as exc2:
                # Notice that this error might be pretty irrecoverable. It
                # happens on COPY, for instance: even if sync succeeds, exiting
                # fails with "cannot exit pipeline mode with uncollected results"
                if exc:
                    logger.warning("error ignored exiting %r: %s", self, exc2)
                else:
                    raise exc2.with_traceback(None)

    def _sync_gen(self) -> PQGen[None]:
        """Generator to enqueue a sync, send the commands, fetch the results."""
        self._enqueue_sync()
        yield from self._communicate_gen()
        yield from self._fetch_gen(flush=False)

    def _exit_gen(self) -> PQGen[None]:
        """
        Exit current pipeline by sending a Sync and fetch back all remaining results.
        """
        try:
            self._enqueue_sync()
            yield from self._communicate_gen()
        finally:
            # Fetch the pending results even if the communication failed.
            yield from self._fetch_gen(flush=True)

    def _communicate_gen(self) -> PQGen[None]:
        """Communicate with pipeline to send commands and possibly fetch
        results, which are then processed.
        """
        fetched = yield from pipeline_communicate(self.pgconn, self.command_queue)
        exception = None
        for results in fetched:
            queued = self.result_queue.popleft()
            try:
                self._process_results(queued, results)
            except e.Error as exc:
                # Keep processing the other results; only the first error
                # is raised at the end.
                if exception is None:
                    exception = exc
        if exception is not None:
            raise exception

    def _fetch_gen(self, *, flush: bool) -> PQGen[None]:
        """Fetch available results from the connection and process them with
        pipeline queued items.

        If 'flush' is True, a PQsendFlushRequest() is issued in order to make
        sure results can be fetched. Otherwise, the caller may emit a
        PQpipelineSync() call to ensure the output buffer gets flushed before
        fetching.
        """
        if not self.result_queue:
            return

        if flush:
            self.pgconn.send_flush_request()
            yield from send(self.pgconn)

        exception = None
        while self.result_queue:
            results = yield from fetch_many(self.pgconn)
            if not results:
                # No more results to fetch, but there may still be pending
                # commands.
                break
            queued = self.result_queue.popleft()
            try:
                self._process_results(queued, results)
            except e.Error as exc:
                # As in _communicate_gen(): remember only the first error.
                if exception is None:
                    exception = exc
        if exception is not None:
            raise exception

    def _process_results(
        self, queued: PendingResult, results: List["PGresult"]
    ) -> None:
        """Process a results set fetched from the current pipeline.

        This matches 'results' with its respective element in the pipeline
        queue. For commands (None value in the pipeline queue), results are
        checked directly. For prepare statement creation requests, update the
        cache. Otherwise, results are attached to their respective cursor.
        """
        if queued is None:
            # Exactly one result is expected here: unpacking fails otherwise.
            (result,) = results
            if result.status == FATAL_ERROR:
                raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
            elif result.status == PIPELINE_ABORTED:
                raise e.PipelineAborted("pipeline aborted")
        else:
            cursor, prepinfo = queued
            if prepinfo:
                key, prep, name = prepinfo
                # Update the prepare state of the query.
                cursor._conn._prepared.validate(key, prep, name, results)
            cursor._set_results_from_pipeline(results)

    def _enqueue_sync(self) -> None:
        """Enqueue a PQpipelineSync() command."""
        self.command_queue.append(self.pgconn.pipeline_sync)
        # None in the result queue: the result is only checked for errors.
        self.result_queue.append(None)
class Pipeline(BasePipeline):
    """Handler for connection in pipeline mode."""

    __module__ = "psycopg"
    _conn: "Connection[Any]"

    def __init__(self, conn: "Connection[Any]") -> None:
        super().__init__(conn)

    def sync(self) -> None:
        """Sync the pipeline: send the pending commands, then receive and
        process all the results available.
        """
        try:
            with self._conn.lock:
                self._conn.wait(self._sync_gen())
        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)

    def __enter__(self) -> Self:
        with self._conn.lock:
            self._conn.wait(self._enter_gen())
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        try:
            with self._conn.lock:
                self._conn.wait(self._exit_gen())
        except Exception as exc2:
            if not exc_val:
                raise exc2.with_traceback(None)
            # An exception was already raised in the block: report that one
            # and only log the error happened on termination.
            logger.warning("error ignored terminating %r: %s", self, exc2)
        finally:
            self._exit(exc_val)
class AsyncPipeline(BasePipeline):
    """Handler for async connection in pipeline mode."""

    __module__ = "psycopg"
    _conn: "AsyncConnection[Any]"

    def __init__(self, conn: "AsyncConnection[Any]") -> None:
        super().__init__(conn)

    async def sync(self) -> None:
        """Sync the pipeline: send the pending commands, then receive and
        process all the results available.
        """
        try:
            async with self._conn.lock:
                await self._conn.wait(self._sync_gen())
        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)

    async def __aenter__(self) -> Self:
        async with self._conn.lock:
            await self._conn.wait(self._enter_gen())
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        try:
            async with self._conn.lock:
                await self._conn.wait(self._exit_gen())
        except Exception as exc2:
            if not exc_val:
                raise exc2.with_traceback(None)
            # An exception was already raised in the block: report that one
            # and only log the error happened on termination.
            logger.warning("error ignored terminating %r: %s", self, exc2)
        finally:
            self._exit(exc_val)

View File

@@ -0,0 +1,194 @@
"""
Support for prepared statements
"""
# Copyright (C) 2020 The Psycopg Team
from enum import IntEnum, auto
from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING
from collections import OrderedDict
from typing_extensions import TypeAlias
from . import pq
from ._compat import Deque
from ._queries import PostgresQuery
if TYPE_CHECKING:
from .pq.abc import PGresult
# Key identifying a query in the PrepareManager maps: the pair
# (query, types), as returned by PrepareManager.key().
Key: TypeAlias = Tuple[bytes, Tuple[int, ...]]

# Local aliases for the result statuses checked in this module.
COMMAND_OK = pq.ExecStatus.COMMAND_OK
TUPLES_OK = pq.ExecStatus.TUPLES_OK
class Prepare(IntEnum):
    """The preparation state of a query, as returned by `PrepareManager.get()`."""

    # The query is not to be prepared.
    NO = 1
    # The query was already prepared.
    YES = 2
    # The query should be prepared now.
    SHOULD = 3
class PrepareManager:
    """Keep track of the queries seen to decide which ones to prepare."""

    # Number of times a query is executed before it is prepared.
    prepare_threshold: Optional[int] = 5

    # Maximum number of prepared statements on the connection.
    prepared_max: int = 100

    def __init__(self) -> None:
        # (query, types) -> number of times the query was seen.
        self._counts: OrderedDict[Key, int] = OrderedDict()

        # (query, types) -> name of the statement, if prepared.
        self._names: OrderedDict[Key, bytes] = OrderedDict()

        # Progressive number used to name the prepared statements.
        self._prepared_idx = 0

        # Commands to run to align the server state to ours.
        self._maint_commands = Deque[bytes]()

    @staticmethod
    def key(query: "PostgresQuery") -> "Key":
        """Return the key identifying `!query` in the internal maps."""
        return (query.query, query.types)

    def get(
        self, query: "PostgresQuery", prepare: Optional[bool] = None
    ) -> "Tuple[Prepare, bytes]":
        """
        Check if a query is prepared, tell back whether to prepare it.
        """
        # Preparation disabled, either for this query or altogether
        if prepare is False or self.prepare_threshold is None:
            return Prepare.NO, b""

        key = self.key(query)
        name = self._names.get(key)
        if name:
            # Already prepared in this session
            return Prepare.YES, name

        count = self._counts.get(key, 0)
        if not (count >= self.prepare_threshold or prepare):
            # Not seen enough times yet
            return Prepare.NO, b""

        # Seen enough times, or preparation explicitly requested
        name = f"_pg3_{self._prepared_idx}".encode()
        self._prepared_idx += 1
        return Prepare.SHOULD, name

    def _should_discard(
        self, prep: "Prepare", results: Sequence["PGresult"]
    ) -> bool:
        """Check if we need to discard our entire state: it should happen on
        rollback or on dropping objects, because the same object may get
        recreated and postgres would fail internal lookups.
        """
        if not (self._names or prep == Prepare.SHOULD):
            return False

        for result in results:
            if result.status != COMMAND_OK:
                continue
            cmdstat = result.command_status
            if cmdstat and (cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"):
                return self.clear()

        return False

    @staticmethod
    def _check_results(results: Sequence["PGresult"]) -> bool:
        """Return False if 'results' are invalid for prepared statement cache."""
        # A multiple statement cannot be prepared
        if len(results) != 1:
            return False

        # Failed queries and other weird results are not prepared either
        status = results[0].status
        return status == COMMAND_OK or status == TUPLES_OK

    def _rotate(self) -> None:
        """Evict an old value from the cache.

        If it was prepared, deallocate it. Do it only once: if the cache was
        resized, deallocate gradually.
        """
        if len(self._counts) > self.prepared_max:
            self._counts.popitem(last=False)

        if len(self._names) > self.prepared_max:
            _, name = self._names.popitem(last=False)
            self._maint_commands.append(b"DEALLOCATE " + name)

    def maybe_add_to_cache(
        self, query: "PostgresQuery", prep: "Prepare", name: bytes
    ) -> "Optional[Key]":
        """Handle 'query' for possible addition to the cache.

        If a new entry has been added, return its key. Return None otherwise
        (meaning the query is already in cache or cache is not enabled).
        """
        # Nothing to do if prepared statements are disabled
        if self.prepare_threshold is None:
            return None

        key = self.key(query)
        to_prepare = prep is Prepare.SHOULD

        if key in self._counts:
            if to_prepare:
                # Promote the entry from counted to prepared
                del self._counts[key]
                self._names[key] = name
            else:
                self._counts[key] += 1
                self._counts.move_to_end(key)
            return None

        if key in self._names:
            self._names.move_to_end(key)
            return None

        if to_prepare:
            self._names[key] = name
        else:
            self._counts[key] = 1
        return key

    def validate(
        self,
        key: "Key",
        prep: "Prepare",
        name: bytes,
        results: Sequence["PGresult"],
    ) -> None:
        """Validate cached entry with 'key' by checking query 'results'.

        Possibly record a command to perform maintenance on database side.
        """
        if self._should_discard(prep, results):
            return

        if self._check_results(results):
            self._rotate()
            return

        # Invalid results: forget about the query
        self._names.pop(key, None)
        self._counts.pop(key, None)

    def clear(self) -> bool:
        """Clear the cache of the maintenance commands.

        Clear the internal state and prepare a command to clear the state of
        the server.

        Return True if there were prepared statements to forget.
        """
        self._counts.clear()
        if not self._names:
            return False

        self._names.clear()
        self._maint_commands.clear()
        self._maint_commands.append(b"DEALLOCATE ALL")
        return True

    def get_maintenance_commands(self) -> Iterator[bytes]:
        """
        Iterate over the commands needed to align the server state to our state
        """
        while self._maint_commands:
            yield self._maint_commands.popleft()

View File

@@ -0,0 +1,415 @@
"""
Utility module to manipulate queries
"""
# Copyright (C) 2020 The Psycopg Team
import re
from typing import Any, Callable, Dict, List, Mapping, Match, NamedTuple, Optional
from typing import Sequence, Tuple, Union, TYPE_CHECKING
from functools import lru_cache
from typing_extensions import TypeAlias
from . import pq
from . import errors as e
from .sql import Composable
from .abc import Buffer, Query, Params
from ._enums import PyFormat
from ._encodings import conn_encoding
if TYPE_CHECKING:
from .abc import Transformer
# Limits above which the query conversion is not cached: the maximum length
# of the query, in bytes, and the maximum number of parameters.
MAX_CACHED_STATEMENT_LENGTH = 4096
MAX_CACHED_STATEMENT_PARAMS = 50
class QueryPart(NamedTuple):
    """A fragment of a query split on its placeholders."""

    # The query text preceding the placeholder.
    pre: bytes
    # The placeholder: its position if positional, its name if named.
    item: Union[int, str]
    # The format requested by the placeholder.
    format: PyFormat
class PostgresQuery:
    """
    Helper to convert a Python query and parameters into Postgres format.
    """

    __slots__ = (
        "query params types formats _tx _want_formats _parts _encoding _order"
    ).split()

    def __init__(self, transformer: "Transformer"):
        self._tx = transformer

        self.params: Optional[Sequence[Optional[Buffer]]] = None
        # A tuple, so that it can be used as key e.g. in prepared stmts
        self.types: Tuple[int, ...] = ()

        # The formats requested by the user and the ones to really pass Postgres
        self._want_formats: Optional[List[PyFormat]] = None
        self.formats: Optional[Sequence[pq.Format]] = None

        self._encoding = conn_encoding(transformer.connection)
        self._parts: List[QueryPart]
        self.query = b""
        self._order: Optional[List[str]] = None

    def convert(self, query: "Query", vars: "Optional[Params]") -> None:
        """
        Set up the query and parameters to convert.

        The results of this function can be obtained accessing the object
        attributes (`query`, `params`, `types`, `formats`).
        """
        if isinstance(query, str):
            bquery = query.encode(self._encoding)
        elif isinstance(query, Composable):
            bquery = query.as_bytes(self._tx)
        else:
            bquery = query

        if vars is None:
            # No parameters: the query is passed as it is
            self.query = bquery
            self._want_formats = self._order = None
        else:
            # Avoid caching queries extremely long or with a huge number of
            # parameters. They are usually generated by ORMs and have poor
            # cacheability (e.g. INSERT ... VALUES (...), (...) with varying
            # numbers of tuples).
            # see https://github.com/psycopg/psycopg/discussions/628
            cacheable = (
                len(bquery) <= MAX_CACHED_STATEMENT_LENGTH
                and len(vars) <= MAX_CACHED_STATEMENT_PARAMS
            )
            f: _Query2Pg = _query2pg if cacheable else _query2pg_nocache

            (self.query, self._want_formats, self._order, self._parts) = f(
                bquery, self._encoding
            )

        self.dump(vars)

    def dump(self, vars: "Optional[Params]") -> None:
        """
        Process a new set of variables on the query processed by `convert()`.

        This method updates `params` and `types`.
        """
        if vars is None:
            self.params = None
            self.types = ()
            self.formats = None
            return

        params = _validate_and_reorder_params(self._parts, vars, self._order)
        assert self._want_formats is not None
        self.params = self._tx.dump_sequence(params, self._want_formats)
        self.types = self._tx.types or ()
        self.formats = self._tx.formats
# The type of the _query2pg() and _query2pg_nocache() functions: they take
# the query and the encoding and return (query, formats, order, parts).
_Query2Pg: TypeAlias = Callable[
    [bytes, str], Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]
]
def _query2pg_nocache(
    query: bytes, encoding: str
) -> "Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]":
    """
    Convert Python query and params into something Postgres understands.

    - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres
      format (``$1``, ``$2``)
    - placeholders can be %s, %t, or %b (auto, text or binary)
    - return ``query`` (bytes), ``formats`` (list of formats) ``order``
      (sequence of names used in the query, in the position they appear)
      ``parts`` (splits of queries and placeholders).
    """
    parts = _split_query(query, encoding)
    order: Optional[List[str]] = None
    chunks: List[bytes] = []
    formats = []

    first_item = parts[0].item
    if isinstance(first_item, int):
        # Positional placeholders: numbered after their position
        for part in parts[:-1]:
            assert isinstance(part.item, int)
            chunks.append(part.pre)
            chunks.append(b"$%d" % (part.item + 1))
            formats.append(part.format)

    elif isinstance(first_item, str):
        # Named placeholders: a repeated name reuses the same number
        seen: "Dict[str, Tuple[bytes, PyFormat]]" = {}
        order = []
        for part in parts[:-1]:
            assert isinstance(part.item, str)
            chunks.append(part.pre)

            known = seen.get(part.item)
            if known is None:
                ph = b"$%d" % (len(seen) + 1)
                seen[part.item] = (ph, part.format)
                order.append(part.item)
                chunks.append(ph)
                formats.append(part.format)
                continue

            ph, fmt = known
            if fmt != part.format:
                raise e.ProgrammingError(
                    f"placeholder '{part.item}' cannot have different formats"
                )
            chunks.append(ph)

    # The last part has no placeholder, only the trailing query fragment
    chunks.append(parts[-1].pre)

    return b"".join(chunks), formats, order, parts


# Note: the cache size is 128 items, but someone has reported throwing ~12k
# queries (of type `INSERT ... VALUES (...), (...)` with a varying amount of
# records), and the resulting cache size is >100Mb. So, we avoid caching
# large queries or queries with a large number of params. See
# https://github.com/sqlalchemy/sqlalchemy/discussions/10270
_query2pg = lru_cache()(_query2pg_nocache)
class PostgresClientQuery(PostgresQuery):
    """
    PostgresQuery subclass merging query and arguments client-side.
    """

    __slots__ = ("template",)

    def convert(self, query: Query, vars: Optional[Params]) -> None:
        """
        Set up the query and parameters to convert.

        The results of this function can be obtained accessing the object
        attributes (`query`, `params`, `types`, `formats`).
        """
        if isinstance(query, str):
            bquery = query.encode(self._encoding)
        elif isinstance(query, Composable):
            bquery = query.as_bytes(self._tx)
        else:
            bquery = query

        if vars is None:
            # No parameters: the query is passed as it is
            self.query = bquery
            self._order = None
        else:
            # Only cache the conversion of reasonably small queries
            cacheable = (
                len(bquery) <= MAX_CACHED_STATEMENT_LENGTH
                and len(vars) <= MAX_CACHED_STATEMENT_PARAMS
            )
            f: _Query2PgClient = (
                _query2pg_client if cacheable else _query2pg_client_nocache
            )

            (self.template, self._order, self._parts) = f(bquery, self._encoding)

        self.dump(vars)

    def dump(self, vars: Optional[Params]) -> None:
        """
        Process a new set of variables on the query processed by `convert()`.

        This method updates `params` and `types`.
        """
        if vars is None:
            self.params = None
            return

        params = _validate_and_reorder_params(self._parts, vars, self._order)
        # Convert each parameter to a literal and merge them to the template
        self.params = tuple(
            b"NULL" if p is None else self._tx.as_literal(p) for p in params
        )
        self.query = self.template % self.params
# The type of the _query2pg_client() and _query2pg_client_nocache()
# functions: they take the query and the encoding and return
# (template, order, parts).
_Query2PgClient: TypeAlias = Callable[
    [bytes, str], Tuple[bytes, Optional[List[str]], List[QueryPart]]
]
def _query2pg_client_nocache(
    query: bytes, encoding: str
) -> "Tuple[bytes, Optional[List[str]], List[QueryPart]]":
    """
    Convert Python query and params into a template to perform client-side binding
    """
    parts = _split_query(query, encoding, collapse_double_percent=False)
    order: Optional[List[str]] = None
    chunks: List[bytes] = []

    first_item = parts[0].item
    if isinstance(first_item, int):
        for part in parts[:-1]:
            assert isinstance(part.item, int)
            chunks.append(part.pre)
            chunks.append(b"%s")

    elif isinstance(first_item, str):
        # Every occurrence of a name gets its own '%s' in the template, so a
        # repeated name is repeated in 'order' too.
        order = []
        for part in parts[:-1]:
            assert isinstance(part.item, str)
            chunks.append(part.pre)
            chunks.append(b"%s")
            order.append(part.item)

    # The last part has no placeholder, only the trailing query fragment
    chunks.append(parts[-1].pre)

    return b"".join(chunks), order, parts


_query2pg_client = lru_cache()(_query2pg_client_nocache)
def _validate_and_reorder_params(
parts: List[QueryPart], vars: Params, order: Optional[List[str]]
) -> Sequence[Any]:
"""
Verify the compatibility between a query and a set of params.
"""
# Try concrete types, then abstract types
t = type(vars)
if t is list or t is tuple:
sequence = True
elif t is dict:
sequence = False
elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
sequence = True
elif isinstance(vars, Mapping):
sequence = False
else:
raise TypeError(
"query parameters should be a sequence or a mapping,"
f" got {type(vars).__name__}"
)
if sequence:
if len(vars) != len(parts) - 1:
raise e.ProgrammingError(
f"the query has {len(parts) - 1} placeholders but"
f" {len(vars)} parameters were passed"
)
if vars and not isinstance(parts[0].item, int):
raise TypeError("named placeholders require a mapping of parameters")
return vars # type: ignore[return-value]
else:
if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
raise TypeError(
"positional placeholders (%s) require a sequence of parameters"
)
try:
return [vars[item] for item in order or ()] # type: ignore[call-overload]
except KeyError:
raise e.ProgrammingError(
"query parameter missing:"
f" {', '.join(sorted(i for i in order or () if i not in vars))}"
)
# Match a '%' followed either by a name in parentheses plus a further char,
# or by any single char. The name, if any, is returned in group 1. The match
# is deliberately loose: the validation happens in _split_query().
_re_placeholder = re.compile(
    rb"""(?x)
% # a literal %
(?:
(?:
\( ([^)]+) \) # or a name in (braces)
. # followed by a format
)
|
(?:.) # or any char, really
)
"""
)
def _split_query(
    query: bytes, encoding: str = "ascii", collapse_double_percent: bool = True
) -> List[QueryPart]:
    """
    Split `!query` on its placeholders and return the resulting parts.

    Every part but the last holds the text preceding a placeholder, the
    placeholder position or name, and the format requested. The last part
    holds the text following the last placeholder, with item 0 and
    `!PyFormat.AUTO`.

    ``%%`` is not a placeholder: it is merged into the surrounding text, as a
    single ``%`` if `!collapse_double_percent` is true, unchanged otherwise.

    :param query: the query to split.
    :param encoding: the encoding used to decode placeholder names and the
        query excerpts reported in error messages.

    Raise `!ProgrammingError` on incomplete or unsupported placeholders, or if
    positional and named placeholders are mixed.
    """
    # Pairs [(fragment, match)]: each fragment is the text before the match;
    # the trailing fragment is paired with None.
    chunks: List[Tuple[bytes, Optional[Match[bytes]]]] = []
    pos = 0
    for found in _re_placeholder.finditer(query):
        chunks.append((query[pos : found.start()], found))
        pos = found.end()
    chunks.append((query[pos:], None))

    rv: List[QueryPart] = []
    # Text ending with a '%%' which must be glued in front of the next fragment
    carry = b""
    phtype = None
    for pre, m in chunks:
        pre = carry + pre
        carry = b""

        if m is None:
            # last part
            rv.append(QueryPart(pre, 0, PyFormat.AUTO))
            break

        ph = m.group(0)
        if ph == b"%%":
            # unescape '%%' to '%' if necessary, then merge with what follows
            carry = pre + (b"%" if collapse_double_percent else ph)
            continue

        if ph == b"%(":
            raise e.ProgrammingError(
                "incomplete placeholder:"
                f" '{query[m.start():].split()[0].decode(encoding)}'"
            )
        if ph == b"% ":
            # explicit message for a typical error
            raise e.ProgrammingError(
                "incomplete placeholder: '%'; if you want to use '%' as an"
                " operator you can double it up, i.e. use '%%'"
            )
        if ph[-1:] not in b"sbt":
            raise e.ProgrammingError(
                "only '%s', '%b', '%t' are allowed as placeholders, got"
                f" '{ph.decode(encoding)}'"
            )

        # Name for named placeholders, otherwise the position, which is the
        # number of placeholders collected so far.
        name = m.group(1)
        item: Union[int, str] = name.decode(encoding) if name else len(rv)

        if phtype is None:
            phtype = type(item)
        elif phtype is not type(item):
            raise e.ProgrammingError(
                "positional and named placeholders cannot be mixed"
            )

        rv.append(QueryPart(pre, item, _ph_to_fmt[ph[-1:]]))

    return rv
# Map the last byte of a placeholder ('%s', '%t', '%b', also in their named
# form) to the PyFormat requested for the parameter.
_ph_to_fmt = {
b"s": PyFormat.AUTO,
b"t": PyFormat.TEXT,
b"b": PyFormat.BINARY,
}

View File

@@ -0,0 +1,56 @@
"""
Utility functions to deal with binary structs.
"""
# Copyright (C) 2020 The Psycopg Team
import struct
from typing import Callable, cast, Optional, Tuple
from typing_extensions import TypeAlias
from .abc import Buffer
from . import errors as e
from ._compat import Protocol
# Signatures of the helpers defined below: packers take a number and return
# bytes, unpackers take a buffer and return a 1-item tuple holding the number.
PackInt: TypeAlias = Callable[[int], bytes]
UnpackInt: TypeAlias = Callable[[Buffer], Tuple[int]]
PackFloat: TypeAlias = Callable[[float], bytes]
UnpackFloat: TypeAlias = Callable[[Buffer], Tuple[float]]
class UnpackLen(Protocol):
    """Signature of a callable reading a length from `!data` at `!start`."""

    def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]:
        ...
# Packers. All the formats start with "!", which in the `struct` module selects
# network (big-endian) byte order with standard sizes. The `cast()` calls only
# give the bound methods a precise signature for the type checker.
pack_int2 = cast(PackInt, struct.Struct("!h").pack)
pack_uint2 = cast(PackInt, struct.Struct("!H").pack)
pack_int4 = cast(PackInt, struct.Struct("!i").pack)
pack_uint4 = cast(PackInt, struct.Struct("!I").pack)
pack_int8 = cast(PackInt, struct.Struct("!q").pack)
pack_float4 = cast(PackFloat, struct.Struct("!f").pack)
pack_float8 = cast(PackFloat, struct.Struct("!d").pack)
# Unpackers, mirroring the packers above. They return a 1-item tuple.
unpack_int2 = cast(UnpackInt, struct.Struct("!h").unpack)
unpack_uint2 = cast(UnpackInt, struct.Struct("!H").unpack)
unpack_int4 = cast(UnpackInt, struct.Struct("!i").unpack)
unpack_uint4 = cast(UnpackInt, struct.Struct("!I").unpack)
unpack_int8 = cast(UnpackInt, struct.Struct("!q").unpack)
unpack_float4 = cast(UnpackFloat, struct.Struct("!f").unpack)
unpack_float8 = cast(UnpackFloat, struct.Struct("!d").unpack)
# Lengths are signed 4-byte integers. Unpacking uses `unpack_from`, so the
# length can be read at an offset inside a larger buffer.
_struct_len = struct.Struct("!i")
pack_len = cast(Callable[[int], bytes], _struct_len.pack)
unpack_len = cast(UnpackLen, _struct_len.unpack_from)
def pack_float4_bug_304(x: float) -> bytes:
    """
    Replacement for `!pack_float4` on Python builds affected by bug #304.

    Never return: always raise `!InterfaceError`, whatever the value of `!x`.
    """
    message = (
        "cannot dump Float4: Python affected by bug #304. Note that the psycopg-c"
        " and psycopg-binary packages are not affected by this issue."
        " See https://github.com/psycopg/psycopg/issues/304"
    )
    raise e.InterfaceError(message)


# If issue #304 is detected, raise an error instead of dumping wrong data.
# The check packs 1.0 and compares it with its known 4-byte representation.
if struct.pack("!f", 1.0) != b"\x3f\x80\x00\x00":
    pack_float4 = pack_float4_bug_304

View File

@@ -0,0 +1,116 @@
"""
psycopg two-phase commit support
"""
# Copyright (C) 2021 The Psycopg Team
import re
import datetime as dt
from base64 import b64encode, b64decode
from typing import Optional, Union
from dataclasses import dataclass, replace
# Shape of the string produced by `Xid._as_tid()`: "<digits>_<text>_<text>",
# where neither text contains an underscore. Groups: format_id, gtrid, bqual
# (the latter two still encoded).
_re_xid = re.compile(r"^(\d+)_([^_]*)_([^_]*)$")
@dataclass(frozen=True)
class Xid:
    """A two-phase commit transaction identifier.

    The object can also be unpacked as a 3-item tuple (`format_id`, `gtrid`,
    `bqual`).

    """

    format_id: Optional[int]
    gtrid: str
    bqual: Optional[str]
    prepared: Optional[dt.datetime] = None
    owner: Optional[str] = None
    database: Optional[str] = None

    @classmethod
    def from_string(cls, s: str) -> "Xid":
        """Try to parse an XA triple from the string.

        This may fail for several reasons. In such case return an unparsed Xid.
        """
        try:
            return cls._parse_string(s)
        except Exception:
            # Deliberate best effort: any string which is not in our own
            # format is kept verbatim as the gtrid of an unparsed Xid.
            return Xid(None, s, None)

    def __str__(self) -> str:
        return self._as_tid()

    def __len__(self) -> int:
        # Together with __getitem__, allow unpacking as a 3-item tuple.
        return 3

    def __getitem__(self, index: int) -> Union[int, str, None]:
        return (self.format_id, self.gtrid, self.bqual)[index]

    @classmethod
    def _parse_string(cls, s: str) -> "Xid":
        """Parse a string in the format produced by `_as_tid()`.

        Raise an exception if the string is not in such format or if its
        components fail to decode or to validate.
        """
        m = _re_xid.match(s)
        if not m:
            raise ValueError("bad Xid format")

        format_id = int(m.group(1))
        gtrid = b64decode(m.group(2)).decode()
        bqual = b64decode(m.group(3)).decode()
        return cls.from_parts(format_id, gtrid, bqual)

    @classmethod
    def from_parts(
        cls, format_id: Optional[int], gtrid: str, bqual: Optional[str]
    ) -> "Xid":
        """Create an Xid from its components, validating them.

        Either `!format_id` and `!bqual` are both specified, describing an XA
        triple, or they are both `!None`, describing an unparsed Xid made of
        the sole `!gtrid`.

        Raise `!TypeError` if only one of `!format_id` and `!bqual` is
        specified, `!ValueError` if a component of an XA triple is out of
        range.
        """
        if format_id is not None:
            if bqual is None:
                raise TypeError("if format_id is specified, bqual must be too")
            if not 0 <= format_id < 0x80000000:
                raise ValueError("format_id must be a non-negative 32-bit integer")
            if len(bqual) > 64:
                raise ValueError("bqual must be not longer than 64 chars")
            if len(gtrid) > 64:
                raise ValueError("gtrid must be not longer than 64 chars")

        elif bqual is not None:
            # Without a format_id only the (None, gtrid, None) form is valid.
            raise TypeError("if format_id is None, bqual must be None too")

        return Xid(format_id, gtrid, bqual)

    def _as_tid(self) -> str:
        """
        Return the PostgreSQL transaction_id for this XA xid.

        PostgreSQL wants just a string, while the DBAPI supports the XA
        standard and thus a triple. We use the same conversion algorithm
        implemented by JDBC in order to allow some form of interoperation.

        see also: the pgjdbc implementation
          http://cvs.pgfoundry.org/cgi-bin/cvsweb.cgi/jdbc/pgjdbc/org/
            postgresql/xa/RecoveredXid.java?rev=1.2
        """
        if self.format_id is None or self.bqual is None:
            # Unparsed xid: return the gtrid.
            return self.gtrid

        # XA xid: mash together the components.
        egtrid = b64encode(self.gtrid.encode()).decode()
        ebqual = b64encode(self.bqual.encode()).decode()

        return f"{self.format_id}_{egtrid}_{ebqual}"

    @classmethod
    def _get_recover_query(cls) -> str:
        """Return the query listing the prepared transactions on the server."""
        return "SELECT gid, prepared, owner, database FROM pg_prepared_xacts"

    @classmethod
    def _from_record(
        cls, gid: str, prepared: dt.datetime, owner: str, database: str
    ) -> "Xid":
        """Create an Xid from a record returned by the recover query."""
        xid = Xid.from_string(gid)
        return replace(xid, prepared=prepared, owner=owner, database=database)


Xid.__module__ = "psycopg"

View File

@@ -0,0 +1,354 @@
"""
Helper object to transform values between Python and PostgreSQL
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import DefaultDict, TYPE_CHECKING
from collections import defaultdict
from typing_extensions import TypeAlias
from . import pq
from . import postgres
from . import errors as e
from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey, NoneType
from .rows import Row, RowMaker
from .postgres import INVALID_OID, TEXT_OID
from ._encodings import pgconn_encoding
if TYPE_CHECKING:
from .abc import Dumper, Loader
from .adapt import AdaptersMap
from .pq.abc import PGresult
from .connection import BaseConnection
# Types of the caches kept by the Transformer: dumpers by key, dumpers by oid,
# loaders by oid.
DumperCache: TypeAlias = Dict[DumperKey, "Dumper"]
OidDumperCache: TypeAlias = Dict[int, "Dumper"]
LoaderCache: TypeAlias = Dict[int, "Loader"]
# Module-level shortcuts for constants used while dumping.
TEXT = pq.Format.TEXT
PY_TEXT = PyFormat.TEXT
class Transformer(AdaptContext):
    """
    An object that can adapt efficiently between Python and PostgreSQL.

    The life cycle of the object is the query, so it is assumed that attributes
    such as the server version or the connection encoding will not change. The
    object keeps its own state, so adapting several values of the same type
    can be optimised.

    """

    __module__ = "psycopg.adapt"

    __slots__ = [
        "types",
        "formats",
        "_conn",
        "_adapters",
        "_pgresult",
        "_dumpers",
        "_loaders",
        "_encoding",
        "_none_oid",
        "_oid_dumpers",
        "_oid_types",
        "_row_dumpers",
        "_row_loaders",
    ]

    types: Optional[Tuple[int, ...]]
    formats: Optional[List[pq.Format]]

    _adapters: "AdaptersMap"
    _pgresult: Optional["PGresult"]
    _none_oid: int

    def __init__(self, context: Optional[AdaptContext] = None):
        self._pgresult = None
        self.types = None
        self.formats = None

        # WARNING: don't store context, or you'll create a loop with the Cursor
        if not context:
            self._adapters = postgres.adapters
            self._conn = None
        else:
            self._adapters = context.adapters
            self._conn = context.connection

        # mapping fmt, class -> Dumper instance
        self._dumpers: DefaultDict[PyFormat, DumperCache]
        self._dumpers = defaultdict(dict)

        # mapping fmt, oid -> Dumper instance
        # Not often used, so create it only if needed.
        self._oid_dumpers: Optional[Tuple[OidDumperCache, OidDumperCache]]
        self._oid_dumpers = None

        # mapping fmt, oid -> Loader instance
        self._loaders: Tuple[LoaderCache, LoaderCache] = ({}, {})

        self._row_dumpers: Optional[List["Dumper"]] = None

        # sequence of load functions from value to python
        # the length of the result columns
        self._row_loaders: List[LoadFunc] = []

        # mapping oid -> type sql representation
        self._oid_types: Dict[int, bytes] = {}

        # Empty until the `encoding` property is read for the first time.
        self._encoding = ""

    @classmethod
    def from_context(cls, context: Optional[AdaptContext]) -> "Transformer":
        """
        Return a Transformer from an AdaptContext.

        If the context is a Transformer instance, just return it.
        """
        return context if isinstance(context, Transformer) else cls(context)

    @property
    def connection(self) -> Optional["BaseConnection[Any]"]:
        return self._conn

    @property
    def encoding(self) -> str:
        # Looked up on first access, then cached for the object lifetime.
        enc = self._encoding
        if not enc:
            conn = self.connection
            if conn:
                enc = pgconn_encoding(conn.pgconn)
            else:
                enc = "utf-8"
            self._encoding = enc
        return enc

    @property
    def adapters(self) -> "AdaptersMap":
        return self._adapters

    @property
    def pgresult(self) -> Optional["PGresult"]:
        return self._pgresult

    def set_pgresult(
        self,
        result: Optional["PGresult"],
        *,
        set_loaders: bool = True,
        format: Optional[pq.Format] = None,
    ) -> None:
        """
        Set the result whose rows will be loaded.

        Unless `!set_loaders` is false, also choose a loader for each column
        of the result, using `!format` if specified, otherwise the format of
        the first column.
        """
        self._pgresult = result

        if not result:
            self._nfields = self._ntuples = 0
            if set_loaders:
                self._row_loaders = []
            return

        self._ntuples = result.ntuples
        self._nfields = ncols = result.nfields

        if not set_loaders:
            return

        if ncols:
            fmt: pq.Format
            fmt = format if format is not None else result.fformat(0)  # type: ignore
            loaders = [
                self.get_loader(result.ftype(col), fmt).load for col in range(ncols)
            ]
        else:
            loaders = []

        self._row_loaders = loaders

    def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
        """Choose the dumpers to use in `dump_sequence()` by their oids."""
        dumpers = []
        for oid in types:
            dumpers.append(self.get_dumper_by_oid(oid, format))
        self._row_dumpers = dumpers
        self.types = tuple(types)
        self.formats = [format] * len(types)

    def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
        """Choose the loaders to use to load records by their oids."""
        loaders = []
        for oid in types:
            loaders.append(self.get_loader(oid, format).load)
        self._row_loaders = loaders

    def dump_sequence(
        self, params: Sequence[Any], formats: Sequence[PyFormat]
    ) -> Sequence[Optional[Buffer]]:
        """
        Dump a sequence of parameters; `!None` values are left as `!None`.

        Unless dumpers were chosen by `set_dumper_types()`, also set the
        `types` and `formats` attributes according to the dumpers found.
        """
        nparams = len(params)
        out: List[Optional[Buffer]] = [None] * nparams

        # If we have dumpers, it means set_dumper_types had been called, in
        # which case self.types and self.formats are set to sequences of the
        # right size.
        row_dumpers = self._row_dumpers
        if row_dumpers:
            for pos, value in enumerate(params):
                if value is not None:
                    out[pos] = row_dumpers[pos].dump(value)
            return out

        # Values for the None parameters; the others are overwritten below.
        types = [self._get_none_oid()] * nparams
        pqformats = [TEXT] * nparams

        for pos, value in enumerate(params):
            if value is None:
                continue
            dumper = self.get_dumper(value, formats[pos])
            out[pos] = dumper.dump(value)
            types[pos] = dumper.oid
            pqformats[pos] = dumper.format

        self.types = tuple(types)
        self.formats = pqformats

        return out

    def as_literal(self, obj: Any) -> bytes:
        """Return `!obj` as a quoted literal, with a type cast if needed."""
        dumper = self.get_dumper(obj, PY_TEXT)
        literal = dumper.quote(obj)
        oid = dumper.oid

        # If the result is quoted, and the oid not unknown or text,
        # add an explicit type cast.
        # Check the last char because the first one might be 'E'.
        if oid and literal and literal[-1] == b"'"[0] and oid != TEXT_OID:
            try:
                type_sql = self._oid_types[oid]
            except KeyError:
                info = self.adapters.types.get(oid)
                if not info:
                    type_sql = b""
                else:
                    # builtin: prefer "timestamptz" to "timestamp with time zone"
                    sql_name = info.name if oid < 8192 else info.regtype
                    type_sql = sql_name.encode(self.encoding)
                    if oid == info.array_oid:
                        type_sql += b"[]"
                # Remember the result, the empty one too.
                self._oid_types[oid] = type_sql

            if type_sql:
                literal = b"%s::%s" % (literal, type_sql)

        if not isinstance(literal, bytes):
            literal = bytes(literal)
        return literal

    def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper":
        """
        Return a Dumper instance to dump `!obj`.
        """
        # Normally, the type of the object dictates how to dump it
        cls_key = type(obj)

        # Reuse an existing Dumper class for objects of the same type
        cache = self._dumpers[format]
        try:
            dumper = cache[cls_key]
        except KeyError:
            # If it's the first time we see this type, look for a dumper
            # configured for it.
            try:
                dumper_cls = self.adapters.get_dumper(cls_key, format)
            except e.ProgrammingError as ex:
                raise ex from None
            dumper = dumper_cls(cls_key, self)
            cache[cls_key] = dumper

        # Check if the dumper requires an upgrade to handle this specific value
        upgraded_key = dumper.get_key(obj, format)
        if upgraded_key is cls_key:
            return dumper

        # If it does, ask the dumper to create its own upgraded version
        try:
            return cache[upgraded_key]
        except KeyError:
            upgraded = dumper.upgrade(obj, format)
            cache[upgraded_key] = upgraded
            return upgraded

    def _get_none_oid(self) -> int:
        """Return the oid to use for `!None` parameters, caching it."""
        try:
            return self._none_oid
        except AttributeError:
            # The slot is not set yet: compute the value below.
            pass

        try:
            oid = self._adapters.get_dumper(NoneType, PY_TEXT).oid
        except KeyError:
            raise e.InterfaceError("None dumper not found")

        self._none_oid = oid
        return oid

    def get_dumper_by_oid(self, oid: int, format: pq.Format) -> "Dumper":
        """
        Return a Dumper to dump an object to the type with given oid.
        """
        caches = self._oid_dumpers
        if not caches:
            caches = self._oid_dumpers = ({}, {})

        # Reuse an existing Dumper class for objects of the same type
        cache = caches[format]
        try:
            return cache[oid]
        except KeyError:
            # If it's the first time we see this type, look for a dumper
            # configured for it.
            dumper_cls = self.adapters.get_dumper_by_oid(oid, format)
            dumper = dumper_cls(NoneType, self)
            cache[oid] = dumper
            return dumper

    def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]:
        """Load the rows from `!row0` (included) to `!row1` (excluded)."""
        res = self._pgresult
        if not res:
            raise e.InterfaceError("result not set")

        ntuples = self._ntuples
        if row0 < 0 or row0 > ntuples or row1 < 0 or row1 > ntuples:
            raise e.InterfaceError(
                f"rows must be included between 0 and {ntuples}"
            )

        nfields = self._nfields
        loaders = self._row_loaders
        records = []
        for row in range(row0, row1):
            # NULL values are left as None
            record: List[Any] = [None] * nfields
            for col in range(nfields):
                val = res.get_value(row, col)
                if val is not None:
                    record[col] = loaders[col](val)
            records.append(make_row(record))

        return records

    def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]:
        """Load a single row; return `!None` if no result or out of range."""
        res = self._pgresult
        if not res:
            return None

        if row < 0 or row >= self._ntuples:
            return None

        nfields = self._nfields
        loaders = self._row_loaders
        # NULL values are left as None
        record: List[Any] = [None] * nfields
        for col in range(nfields):
            val = res.get_value(row, col)
            if val is not None:
                record[col] = loaders[col](val)

        return make_row(record)

    def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
        """Load a sequence of values using the loaders currently set."""
        loaders = self._row_loaders
        if len(loaders) != len(record):
            raise e.ProgrammingError(
                f"cannot load sequence of {len(record)} items:"
                f" {len(loaders)} loaders registered"
            )

        return tuple(
            None if val is None else load(val) for load, val in zip(loaders, record)
        )

    def get_loader(self, oid: int, format: pq.Format) -> "Loader":
        """
        Return a Loader instance for the type with given oid.

        Fall back on the loader registered for `!INVALID_OID` if no loader
        is registered for `!oid`.
        """
        cache = self._loaders[format]
        loader = cache.get(oid)
        if loader is not None:
            return loader

        loader_cls = self._adapters.get_loader(oid, format)
        if not loader_cls:
            loader_cls = self._adapters.get_loader(INVALID_OID, format)
            if not loader_cls:
                raise e.InterfaceError("unknown oid loader not found")

        loader = loader_cls(oid, self)
        cache[oid] = loader
        return loader

View File

@@ -0,0 +1,494 @@
"""
Information about PostgreSQL types
These types allow to read information from the system catalog and provide
information to the adapters if needed.
"""
# Copyright (C) 2020 The Psycopg Team
from enum import Enum
from typing import Any, Dict, Iterator, Optional, overload
from typing import Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
from typing_extensions import TypeAlias
from . import errors as e
from .abc import AdaptContext, Query
from .rows import dict_row
from ._encodings import conn_encoding
if TYPE_CHECKING:
from .connection import BaseConnection, Connection
from .connection_async import AsyncConnection
from .sql import Identifier, SQL
# Type variable allowing classmethods to return an instance of the TypeInfo
# subclass they are called on.
T = TypeVar("T", bound="TypeInfo")
# Keys of a TypesRegistry: a type name, an oid, or a (class, oid) pair.
RegistryKey: TypeAlias = Union[str, int, Tuple[type, int]]
class TypeInfo:
    """
    Hold information about a PostgreSQL base type.
    """

    __module__ = "psycopg.types"

    def __init__(
        self,
        name: str,
        oid: int,
        array_oid: int,
        *,
        regtype: str = "",
        delimiter: str = ",",
    ):
        self.name = name
        self.oid = oid
        self.array_oid = array_oid
        # Without an explicit regtype, the type name is used in its place.
        self.regtype = regtype if regtype else name
        self.delimiter = delimiter

    def __repr__(self) -> str:
        cls_name = self.__class__.__qualname__
        return (
            f"<{cls_name}: {self.name}"
            f" (oid: {self.oid}, array oid: {self.array_oid})>"
        )

    @overload
    @classmethod
    def fetch(
        cls: Type[T], conn: "Connection[Any]", name: Union[str, "Identifier"]
    ) -> Optional[T]: ...

    @overload
    @classmethod
    async def fetch(
        cls: Type[T], conn: "AsyncConnection[Any]", name: Union[str, "Identifier"]
    ) -> Optional[T]: ...

    @classmethod
    def fetch(
        cls: Type[T], conn: "BaseConnection[Any]", name: Union[str, "Identifier"]
    ) -> Any:
        """Query a system catalog to read information about a type."""
        from .sql import Composable
        from .connection import Connection
        from .connection_async import AsyncConnection

        if isinstance(name, Composable):
            name = name.as_string(conn)

        # With a sync connection return the result, with an async one return
        # the coroutine to await.
        if isinstance(conn, Connection):
            return cls._fetch(conn, name)
        if isinstance(conn, AsyncConnection):
            return cls._fetch_async(conn, name)

        raise TypeError(
            f"expected Connection or AsyncConnection, got {type(conn).__name__}"
        )

    @classmethod
    def _fetch(cls: Type[T], conn: "Connection[Any]", name: str) -> Optional[T]:
        """Sync implementation of `fetch()`."""
        # This might result in a nested transaction. What we want is to leave
        # the function with the connection in the state we found (either idle
        # or intrans)
        params = {"name": name}
        try:
            with conn.transaction():
                if conn_encoding(conn) == "ascii":
                    conn.execute("set local client_encoding to utf8")
                with conn.cursor(row_factory=dict_row) as cur:
                    cur.execute(cls._get_info_query(conn), params)
                    records = cur.fetchall()
        except e.UndefinedObject:
            return None

        return cls._from_records(name, records)

    @classmethod
    async def _fetch_async(
        cls: Type[T], conn: "AsyncConnection[Any]", name: str
    ) -> Optional[T]:
        """Async implementation of `fetch()`."""
        params = {"name": name}
        try:
            async with conn.transaction():
                if conn_encoding(conn) == "ascii":
                    await conn.execute("set local client_encoding to utf8")
                async with conn.cursor(row_factory=dict_row) as cur:
                    await cur.execute(cls._get_info_query(conn), params)
                    records = await cur.fetchall()
        except e.UndefinedObject:
            return None

        return cls._from_records(name, records)

    @classmethod
    def _from_records(
        cls: Type[T], name: str, recs: Sequence[Dict[str, Any]]
    ) -> Optional[T]:
        """
        Create an instance from the records returned by the info query.

        Return `!None` if there is no record, raise `!ProgrammingError` if
        there is more than one.
        """
        if len(recs) == 1:
            # The columns returned by the query are the __init__ arguments.
            return cls(**recs[0])
        if not recs:
            return None
        raise e.ProgrammingError(f"found {len(recs)} different types named {name}")

    def register(self, context: Optional[AdaptContext] = None) -> None:
        """
        Register the type information, globally or in the specified `!context`.
        """
        if not context:
            from . import postgres

            types = postgres.types
        else:
            types = context.adapters.types

        types.add(self)

        if self.array_oid:
            from .types.array import register_array

            register_array(self, context)

    @classmethod
    def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
        """Return the query used by `fetch()` to read the type details."""
        from .sql import SQL

        template = SQL(
            """\
SELECT
typname AS name, oid, typarray AS array_oid,
oid::regtype::text AS regtype, typdelim AS delimiter
FROM pg_type t
WHERE t.oid = {regtype}
ORDER BY t.oid
"""
        )
        return template.format(regtype=cls._to_regtype(conn))

    @classmethod
    def _has_to_regtype_function(cls, conn: "BaseConnection[Any]") -> bool:
        """Return `!True` if the server provides the `to_regtype()` function."""
        # to_regtype() introduced in PostgreSQL 9.4 and CockroachDB 22.2
        info = conn.info
        vendor = info.vendor
        if vendor == "PostgreSQL":
            min_version = 90400
        elif vendor == "CockroachDB":
            min_version = 220200
        else:
            return False
        return info.server_version >= min_version

    @classmethod
    def _to_regtype(cls, conn: "BaseConnection[Any]") -> "SQL":
        """Return the SQL snippet converting the ``name`` parameter to an oid."""
        # `to_regtype()` returns the type oid or NULL, unlike the :: operator,
        # which returns the type or raises an exception, which requires
        # a transaction rollback and leaves traces in the server logs.
        from .sql import SQL

        snippet = (
            "to_regtype(%(name)s)"
            if cls._has_to_regtype_function(conn)
            else "%(name)s::regtype"
        )
        return SQL(snippet)

    def _added(self, registry: "TypesRegistry") -> None:
        """Method called by the `!registry` when the object is added there."""
        pass
class RangeInfo(TypeInfo):
    """Manage information about a range type."""

    __module__ = "psycopg.types.range"

    def __init__(
        self,
        name: str,
        oid: int,
        array_oid: int,
        *,
        regtype: str = "",
        subtype_oid: int,
    ):
        super().__init__(name, oid, array_oid, regtype=regtype)
        # Oid of the type of the range elements
        self.subtype_oid = subtype_oid

    @classmethod
    def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
        """Return the query reading the range details, subtype included."""
        from .sql import SQL

        template = SQL(
            """\
SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
t.oid::regtype::text AS regtype,
r.rngsubtype AS subtype_oid
FROM pg_type t
JOIN pg_range r ON t.oid = r.rngtypid
WHERE t.oid = {regtype}
"""
        )
        return template.format(regtype=cls._to_regtype(conn))

    def _added(self, registry: "TypesRegistry") -> None:
        # Map ranges subtypes to info
        key = (RangeInfo, self.subtype_oid)
        registry._registry[key] = self
class MultirangeInfo(TypeInfo):
    """Manage information about a multirange type."""

    __module__ = "psycopg.types.multirange"

    def __init__(
        self,
        name: str,
        oid: int,
        array_oid: int,
        *,
        regtype: str = "",
        range_oid: int,
        subtype_oid: int,
    ):
        super().__init__(name, oid, array_oid, regtype=regtype)
        # Oids of the range type and of the type of the range elements
        self.range_oid = range_oid
        self.subtype_oid = subtype_oid

    @classmethod
    def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
        """
        Return the query reading the multirange details.

        Raise `!NotSupportedError` if the server version is older than 14.
        """
        from .sql import SQL

        if conn.info.server_version < 140000:
            raise e.NotSupportedError(
                "multirange types are only available from PostgreSQL 14"
            )

        template = SQL(
            """\
SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
t.oid::regtype::text AS regtype,
r.rngtypid AS range_oid, r.rngsubtype AS subtype_oid
FROM pg_type t
JOIN pg_range r ON t.oid = r.rngmultitypid
WHERE t.oid = {regtype}
"""
        )
        return template.format(regtype=cls._to_regtype(conn))

    def _added(self, registry: "TypesRegistry") -> None:
        # Map multiranges ranges and subtypes to info
        for related_oid in (self.range_oid, self.subtype_oid):
            registry._registry[MultirangeInfo, related_oid] = self
class CompositeInfo(TypeInfo):
    """Manage information about a composite type."""

    __module__ = "psycopg.types.composite"

    def __init__(
        self,
        name: str,
        oid: int,
        array_oid: int,
        *,
        regtype: str = "",
        field_names: Sequence[str],
        field_types: Sequence[int],
    ):
        super().__init__(name, oid, array_oid, regtype=regtype)
        # Names and type oids of the fields of the composite
        self.field_names = field_names
        self.field_types = field_types
        # Will be set by register() if the `factory` is a type
        self.python_type: Optional[type] = None

    @classmethod
    def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
        """Return the query reading the composite details, fields included."""
        from .sql import SQL

        # Note: the doubled braces are literal braces for SQL.format().
        template = SQL(
            """\
SELECT
t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
t.oid::regtype::text AS regtype,
coalesce(a.fnames, '{{}}') AS field_names,
coalesce(a.ftypes, '{{}}') AS field_types
FROM pg_type t
LEFT JOIN (
SELECT
attrelid,
array_agg(attname) AS fnames,
array_agg(atttypid) AS ftypes
FROM (
SELECT a.attrelid, a.attname, a.atttypid
FROM pg_attribute a
JOIN pg_type t ON t.typrelid = a.attrelid
WHERE t.oid = {regtype}
AND a.attnum > 0
AND NOT a.attisdropped
ORDER BY a.attnum
) x
GROUP BY attrelid
) a ON a.attrelid = t.typrelid
WHERE t.oid = {regtype}
"""
        )
        return template.format(regtype=cls._to_regtype(conn))
class EnumInfo(TypeInfo):
    """Manage information about an enum type."""

    __module__ = "psycopg.types.enum"

    def __init__(
        self,
        name: str,
        oid: int,
        array_oid: int,
        labels: Sequence[str],
    ):
        super().__init__(name, oid, array_oid)
        # Labels of the enum, as returned by the info query
        self.labels = labels
        # Will be set by register_enum()
        self.enum: Optional[Type[Enum]] = None

    @classmethod
    def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
        """Return the query reading the enum details, labels included."""
        from .sql import SQL

        template = SQL(
            """\
SELECT name, oid, array_oid, array_agg(label) AS labels
FROM (
SELECT
t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
e.enumlabel AS label
FROM pg_type t
LEFT JOIN pg_enum e
ON e.enumtypid = t.oid
WHERE t.oid = {regtype}
ORDER BY e.enumsortorder
) x
GROUP BY name, oid, array_oid
"""
        )
        return template.format(regtype=cls._to_regtype(conn))
class TypesRegistry:
    """
    Container for the information about types in a database.
    """

    __module__ = "psycopg.types"

    def __init__(self, template: Optional["TypesRegistry"] = None):
        self._registry: Dict[RegistryKey, TypeInfo]

        if not template:
            self.clear()
            return

        # Make a shallow copy: it will become a proper copy if the registry
        # is edited. The template must copy before writing too, because its
        # state is now shared.
        self._registry = template._registry
        self._own_state = False
        template._own_state = False

    def clear(self) -> None:
        """Empty the registry."""
        self._registry = {}
        self._own_state = True

    def add(self, info: TypeInfo) -> None:
        """
        Add a type information to the registry.

        The object can then be found by oid, array oid, name, and by regtype
        if no other object is already registered under that key.
        """
        self._ensure_own_state()
        registry = self._registry

        if info.oid:
            registry[info.oid] = info
        if info.array_oid:
            registry[info.array_oid] = info
        registry[info.name] = info

        regtype = info.regtype
        if regtype and regtype not in registry:
            registry[regtype] = info

        # Allow info to customise further their relation with the registry
        info._added(self)

    def __iter__(self) -> Iterator[TypeInfo]:
        # An object is registered under several keys: yield it only once.
        yielded = set()
        for info in self._registry.values():
            ident = id(info)
            if ident in yielded:
                continue
            yielded.add(ident)
            yield info

    @overload
    def __getitem__(self, key: Union[str, int]) -> TypeInfo: ...

    @overload
    def __getitem__(self, key: Tuple[Type[T], int]) -> T: ...

    def __getitem__(self, key: RegistryKey) -> TypeInfo:
        """
        Return info about a type, specified by name or oid

        :param key: the name or oid of the type to look for.

        Raise KeyError if not found.
        """
        if isinstance(key, str):
            # An array name is looked up by the name of its item type
            if key.endswith("[]"):
                key = key[:-2]
        elif not isinstance(key, (int, tuple)):
            raise TypeError(f"the key must be an oid or a name, got {type(key)}")

        try:
            return self._registry[key]
        except KeyError:
            raise KeyError(f"couldn't find the type {key!r} in the types registry")

    @overload
    def get(self, key: Union[str, int]) -> Optional[TypeInfo]: ...

    @overload
    def get(self, key: Tuple[Type[T], int]) -> Optional[T]: ...

    def get(self, key: RegistryKey) -> Optional[TypeInfo]:
        """
        Return info about a type, specified by name or oid

        :param key: the name or oid of the type to look for.

        Unlike `__getitem__`, return None if not found.
        """
        try:
            info = self[key]
        except KeyError:
            return None
        return info

    def get_oid(self, name: str) -> int:
        """
        Return the oid of a PostgreSQL type by name.

        :param name: the name of the type to look for.

        Return the array oid if the type ends with "``[]``"

        Raise KeyError if the name is unknown.
        """
        info = self[name]
        return info.array_oid if name.endswith("[]") else info.oid

    def get_by_subtype(self, cls: Type[T], subtype: Union[int, str]) -> Optional[T]:
        """
        Return info about a `TypeInfo` subclass by its element name or oid.

        :param cls: the subtype of `!TypeInfo` to look for. Currently
            supported are `~psycopg.types.range.RangeInfo` and
            `~psycopg.types.multirange.MultirangeInfo`.
        :param subtype: The name or OID of the subtype of the element to look for.
        :return: The `!TypeInfo` object of class `!cls` whose subtype is
            `!subtype`. `!None` if the element or its range are not found.
        """
        try:
            element = self[subtype]
        except KeyError:
            return None
        return self.get((cls, element.oid))

    def _ensure_own_state(self) -> None:
        # Time to write! so, copy.
        if self._own_state:
            return
        self._registry = self._registry.copy()
        self._own_state = True

View File

@@ -0,0 +1,44 @@
"""
Timezone utility functions.
"""
# Copyright (C) 2020 The Psycopg Team
import logging
from typing import Dict, Optional, Union
from datetime import timezone, tzinfo
from .pq.abc import PGconn
from ._compat import ZoneInfo
logger = logging.getLogger("psycopg")
# Cache of the tzinfo objects by the TimeZone value reported by the server.
# The None key stands for "value not available" and maps to UTC. New entries
# are added by get_tzinfo().
_timezones: Dict[Union[None, bytes], tzinfo] = {
None: timezone.utc,
b"UTC": timezone.utc,
}
def get_tzinfo(pgconn: Optional[PGconn]) -> tzinfo:
    """Return the Python timezone info of the connection's timezone.

    Return UTC, logging a warning, if the timezone reported by the server
    cannot be converted to a Python object. Results are cached by name,
    fallbacks included.
    """
    if pgconn:
        tzname = pgconn.parameter_status(b"TimeZone")
    else:
        tzname = None

    try:
        return _timezones[tzname]
    except KeyError:
        sname = "UTC" if not tzname else tzname.decode()

        # UTC unless the zone can be created
        zone: tzinfo = timezone.utc
        try:
            zone = ZoneInfo(sname)
        except (KeyError, OSError):
            logger.warning("unknown PostgreSQL timezone: %r; will use UTC", sname)
        except Exception as ex:
            logger.warning(
                "error handling PostgreSQL timezone: %r; will use UTC (%s - %s)",
                sname,
                type(ex).__name__,
                ex,
            )

        _timezones[tzname] = zone
        return zone

View File

@@ -0,0 +1,137 @@
"""
Wrappers for numeric types.
"""
# Copyright (C) 2020 The Psycopg Team
# Wrappers to force numbers to be cast as specific PostgreSQL types
# These types are implemented here but exposed by `psycopg.types.numeric`.
# They are defined here to avoid a circular import.
# Value assigned to `__module__` in every wrapper class below, so that they
# present themselves as part of the module exposing them.
_MODULE = "psycopg.types.numeric"
class Int2(int):
    """
    Force dumping a Python `!int` as a PostgreSQL :sql:`smallint/int2`.
    """

    __module__ = _MODULE
    __slots__ = ()

    def __new__(cls, arg: int) -> "Int2":
        instance = super().__new__(cls, arg)
        return instance

    def __str__(self) -> str:
        # Plain number: str() would otherwise use the __repr__ defined below.
        plain = super().__repr__()
        return plain

    def __repr__(self) -> str:
        wrapper = self.__class__.__name__
        return "{}({})".format(wrapper, super().__repr__())
class Int4(int):
    """
    Force dumping a Python `!int` as a PostgreSQL :sql:`integer/int4`.
    """

    __module__ = _MODULE
    __slots__ = ()

    def __new__(cls, arg: int) -> "Int4":
        instance = super().__new__(cls, arg)
        return instance

    def __str__(self) -> str:
        # Plain number: str() would otherwise use the __repr__ defined below.
        plain = super().__repr__()
        return plain

    def __repr__(self) -> str:
        wrapper = self.__class__.__name__
        return "{}({})".format(wrapper, super().__repr__())
class Int8(int):
    """
    Force dumping a Python `!int` as a PostgreSQL :sql:`bigint/int8`.
    """

    __module__ = _MODULE
    __slots__ = ()

    def __new__(cls, arg: int) -> "Int8":
        instance = super().__new__(cls, arg)
        return instance

    def __str__(self) -> str:
        # Plain number: str() would otherwise use the __repr__ defined below.
        plain = super().__repr__()
        return plain

    def __repr__(self) -> str:
        wrapper = self.__class__.__name__
        return "{}({})".format(wrapper, super().__repr__())
class IntNumeric(int):
    """
    Force dumping a Python `!int` as a PostgreSQL :sql:`numeric/decimal`.
    """

    __module__ = _MODULE
    __slots__ = ()

    def __new__(cls, arg: int) -> "IntNumeric":
        instance = super().__new__(cls, arg)
        return instance

    def __str__(self) -> str:
        # Plain number: str() would otherwise use the __repr__ defined below.
        plain = super().__repr__()
        return plain

    def __repr__(self) -> str:
        wrapper = self.__class__.__name__
        return "{}({})".format(wrapper, super().__repr__())
class Float4(float):
    """
    Force dumping a Python `!float` as a PostgreSQL :sql:`float4/real`.
    """

    __module__ = _MODULE
    __slots__ = ()

    def __new__(cls, arg: float) -> "Float4":
        instance = super().__new__(cls, arg)
        return instance

    def __str__(self) -> str:
        # Plain number: str() would otherwise use the __repr__ defined below.
        plain = super().__repr__()
        return plain

    def __repr__(self) -> str:
        wrapper = self.__class__.__name__
        return "{}({})".format(wrapper, super().__repr__())
class Float8(float):
    """
    Force dumping a Python `!float` as a PostgreSQL :sql:`float8/double precision`.
    """

    __module__ = _MODULE
    __slots__ = ()

    def __new__(cls, arg: float) -> "Float8":
        instance = super().__new__(cls, arg)
        return instance

    def __str__(self) -> str:
        # Plain number: str() would otherwise use the __repr__ defined below.
        plain = super().__repr__()
        return plain

    def __repr__(self) -> str:
        wrapper = self.__class__.__name__
        return "{}({})".format(wrapper, super().__repr__())
class Oid(int):
    """
    Force dumping a Python `!int` as a PostgreSQL :sql:`oid`.
    """

    __module__ = _MODULE
    __slots__ = ()

    def __new__(cls, arg: int) -> "Oid":
        instance = super().__new__(cls, arg)
        return instance

    def __str__(self) -> str:
        # Plain number: str() would otherwise use the __repr__ defined below.
        plain = super().__repr__()
        return plain

    def __repr__(self) -> str:
        wrapper = self.__class__.__name__
        return "{}({})".format(wrapper, super().__repr__())

View File

@@ -0,0 +1,248 @@
"""
Protocol objects representing different implementations of the same classes.
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Any, Callable, Generator, Mapping
from typing import List, Optional, Sequence, Tuple, TypeVar, Union
from typing import TYPE_CHECKING
from typing_extensions import TypeAlias
from . import pq
from ._enums import PyFormat as PyFormat
from ._compat import Protocol, LiteralString
if TYPE_CHECKING:
from . import sql
from .rows import Row, RowMaker
from .pq.abc import PGresult
from .waiting import Wait, Ready
from .connection import BaseConnection
from ._adapters_map import AdaptersMap
# The type of `None`, obtained at runtime.
NoneType: type = type(None)
# An object implementing the buffer protocol
Buffer: TypeAlias = Union[bytes, bytearray, memoryview]
# A query: a literal string, bytes, or a `sql.SQL` / `sql.Composed` object.
Query: TypeAlias = Union[LiteralString, bytes, "sql.SQL", "sql.Composed"]
# Query parameters: either a sequence or a mapping keyed by string.
Params: TypeAlias = Union[Sequence[Any], Mapping[str, Any]]
# Type variable bound to `BaseConnection` (forward reference).
ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]")
# A callable taking no argument and returning nothing.
PipelineCommand: TypeAlias = Callable[[], None]
# A dumper key: a type, or an arbitrarily nested tuple of dumper keys.
DumperKey: TypeAlias = Union[type, Tuple["DumperKey", ...]]
# Waiting protocol types
# Type variable for the final result returned by a waiting generator.
RV = TypeVar("RV")
PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], "Ready", RV]
"""Generator for processes where the connection file number can change.
This can happen in connection and reset, but not in normal querying.
"""
PQGen: TypeAlias = Generator["Wait", "Ready", RV]
"""Generator for processes where the connection file number won't change.
"""
class WaitFunc(Protocol):
    """
    Callable waiting on the connection which generated a `PQGen`.

    The call returns the final result of the generator.
    """

    def __call__(
        self, gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
    ) -> RV: ...
# Adaptation types
# A dump function takes any object and returns a buffer.
DumpFunc: TypeAlias = Callable[[Any], Buffer]
# A load function takes a buffer and returns any object.
LoadFunc: TypeAlias = Callable[[Buffer], Any]
class AdaptContext(Protocol):
    """
    Describe the way types are adapted.

    `~psycopg.Connection`, `~psycopg.Cursor`, `~psycopg.adapt.Transformer`
    and `~psycopg.adapt.AdaptersMap` are all examples of `~AdaptContext`.

    Because this is a `~typing.Protocol`, an object can implement
    `!AdaptContext` without inheriting explicitly from this class.
    """

    @property
    def adapters(self) -> "AdaptersMap":
        """The adapters configuration used by this object."""
        ...

    @property
    def connection(self) -> Optional["BaseConnection[Any]"]:
        """The connection this object uses, if there is one available.

        :rtype: `~psycopg.Connection` or `~psycopg.AsyncConnection` or `!None`
        """
        ...
class Dumper(Protocol):
    """
    Interface converting Python objects of type `!cls` to PostgreSQL data.
    """

    format: pq.Format
    """
    The format produced by the `dump()` method of this class: either
    `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`.

    This is a class attribute.
    """

    oid: int
    """The oid to pass to the server, if known; 0 otherwise (class attribute)."""

    def __init__(self, cls: type, context: Optional[AdaptContext] = None): ...

    def dump(self, obj: Any) -> Buffer:
        """Return the PostgreSQL representation of the object `!obj`.

        :param obj: the object to convert.
        """
        ...

    def quote(self, obj: Any) -> Buffer:
        """Return the escaped representation of the object `!obj`.

        :param obj: the object to convert.
        """
        ...

    def get_key(self, obj: Any, format: PyFormat) -> DumperKey:
        """Return an alternative key to upgrade the dumper to represent `!obj`.

        :param obj: The object to convert
        :param format: The format to convert to

        Usually the type of an object is enough to decide how to dump it to
        the database: a Python `~datetime.date`, for instance, simply maps to
        a PostgreSQL :sql:`date`.

        Sometimes the type alone is not sufficient. For example:

        - A Python `~datetime.datetime` may be a :sql:`timestamptz` or a
          :sql:`timestamp`, depending on whether it has a `!tzinfo` or not.

        - A Python int may be stored as several Postgres types: int2, int4,
          int8, numeric. Choosing a type too small may cause an overflow;
          choosing a type too large, PostgreSQL may not want to cast it to a
          smaller type.

        - Python lists should be dumped according to the type of their
          content, to obtain e.g. an array of strings, an array of ints (and
          which size of int?...)

        A dumper can handle these cases by implementing `!get_key()` and
        returning a new class, or sequence of classes, which can be used to
        identify the same dumper again. When the mechanism is not needed, the
        method should return the same `!cls` object passed in the constructor.

        A dumper implementing `get_key()` should implement `upgrade()` too.
        """
        ...

    def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
        """Return a new dumper to manage `!obj`.

        :param obj: The object to convert
        :param format: The format to convert to

        When `get_key()` tells `Transformer.get_dumper()` that this Dumper
        class cannot handle `!obj` itself, `!upgrade()` is invoked. It should
        return a new `Dumper` instance, which will be reused for every object
        for which `!get_key()` returns the same result.
        """
        ...
class Loader(Protocol):
    """
    Interface converting PostgreSQL values with type OID `!oid` to Python.
    """

    format: pq.Format
    """
    The format that the `load()` method of this class can convert: either
    `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`.

    This is a class attribute.
    """

    def __init__(self, oid: int, context: Optional[AdaptContext] = None): ...

    def load(self, data: Buffer) -> Any:
        """
        Return a Python object built from the data returned by the database.

        :param data: the data to convert.
        """
        ...
class Transformer(Protocol):
    """
    Structural interface of the objects dumping parameters and loading rows.

    The concrete implementations live elsewhere; this protocol only lists the
    attributes and methods they must expose.
    """

    types: Optional[Tuple[int, ...]]
    formats: Optional[List[pq.Format]]

    def __init__(self, context: Optional[AdaptContext] = None): ...

    @classmethod
    def from_context(cls, context: Optional[AdaptContext]) -> "Transformer": ...

    # Read-only state

    @property
    def connection(self) -> Optional["BaseConnection[Any]"]: ...

    @property
    def encoding(self) -> str: ...

    @property
    def adapters(self) -> "AdaptersMap": ...

    @property
    def pgresult(self) -> Optional["PGresult"]: ...

    # Configuration

    def set_pgresult(
        self,
        result: Optional["PGresult"],
        *,
        set_loaders: bool = True,
        format: Optional[pq.Format] = None
    ) -> None: ...

    def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: ...

    def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: ...

    # Python -> PostgreSQL

    def dump_sequence(
        self, params: Sequence[Any], formats: Sequence[PyFormat]
    ) -> Sequence[Optional[Buffer]]: ...

    def as_literal(self, obj: Any) -> bytes: ...

    def get_dumper(self, obj: Any, format: PyFormat) -> Dumper: ...

    # PostgreSQL -> Python

    def load_rows(
        self, row0: int, row1: int, make_row: "RowMaker[Row]"
    ) -> List["Row"]: ...

    def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]: ...

    def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]: ...

    def get_loader(self, oid: int, format: pq.Format) -> Loader: ...

View File

@@ -0,0 +1,161 @@
"""
Entry point into the adaptation system.
"""
# Copyright (C) 2020 The Psycopg Team
from abc import ABC, abstractmethod
from typing import Any, Optional, Type, TYPE_CHECKING
from . import pq, abc
from . import _adapters_map
from ._enums import PyFormat as PyFormat
from ._cmodule import _psycopg
if TYPE_CHECKING:
from .connection import BaseConnection
AdaptersMap = _adapters_map.AdaptersMap
Buffer = abc.Buffer
ORD_BS = ord("\\")
class Dumper(abc.Dumper, ABC):
    """
    Base class converting Python objects of type `!cls` to PostgreSQL data.
    """

    oid: int = 0
    """The oid to pass to the server, if known."""

    format: pq.Format = pq.Format.TEXT
    """The format of the data dumped."""

    def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
        self.cls = cls
        # The connection is taken from the context, when there is one.
        conn: Optional["BaseConnection[Any]"] = None
        if context:
            conn = context.connection
        self.connection: Optional["BaseConnection[Any]"] = conn

    def __repr__(self) -> str:
        klass = type(self)
        return (
            f"<{klass.__module__}.{klass.__qualname__}"
            f" (oid={self.oid}) at 0x{id(self):x}>"
        )

    @abstractmethod
    def dump(self, obj: Any) -> Buffer: ...

    def quote(self, obj: Any) -> Buffer:
        """
        Return the `dump()` value quoted and sanitised, so that it can be
        merged into a SQL string.

        The default implementation is fine for most types: subclasses will
        rarely need to override this method.
        """
        dumped = self.dump(obj)

        if self.connection:
            # With a connection the libpq does both escaping and quoting.
            return pq.Escaping(self.connection.pgconn).escape_literal(dumped)

        # No connection available: this usually happens via
        # psycopg.sql.quote() or 'Composible.as_string(None)', most often to
        # generate a SQL file to be consumed elsewhere.
        #
        # Here the libpq only escapes the quotes, it doesn't add them, and it
        # may or may not escape the backslashes (see below).
        escaper = pq.Escaping()
        escaped = escaper.escape_string(dumped)

        # `b"\\" in memoryview` doesn't work: look for the ascii value.
        if ORD_BS not in escaped:
            # Without backslashes the result doesn't depend on
            # standard_conforming_strings, so it is already correct.
            return b"'" + escaped + b"'"

        # PQescapeString follows the last standard_conforming_strings value
        # seen on any connection, so backslashes may have been escaped or
        # not.
        #
        # E'\\' is valid everywhere, whereas E'\' is an error. On the other
        # hand, with scs off, '\\' raises a warning and '\' is an error.
        #
        # So probe the libpq behaviour and, if it left the backslash alone,
        # double it here. The race condition is knowingly ignored.
        quoted: bytes = b" E'" + escaped + b"'"
        if escaper.escape_string(b"\\") == b"\\":
            quoted = quoted.replace(b"\\", b"\\\\")
        return quoted

    def get_key(self, obj: Any, format: PyFormat) -> abc.DumperKey:
        """
        Implementation of the `~psycopg.abc.Dumper.get_key()` member of the
        `~psycopg.abc.Dumper` protocol. Look at its definition for details.

        Return the `!cls` received by the constructor. A subclass which needs
        to choose the PostgreSQL type according to the *value* of the object
        dumped (not only according to its type) should override this method.
        """
        return self.cls

    def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
        """
        Implementation of the `~psycopg.abc.Dumper.upgrade()` member of the
        `~psycopg.abc.Dumper` protocol. Look at its definition for details.

        Return `!self`. A subclass implementing `get_key()` should probably
        override `!upgrade()` as well.
        """
        return self
class Loader(abc.Loader, ABC):
    """
    Base class converting PostgreSQL values with type OID `!oid` to Python.
    """

    format: pq.Format = pq.Format.TEXT
    """The format of the data loaded."""

    def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
        self.oid = oid
        # The connection is taken from the context, when there is one.
        conn: Optional["BaseConnection[Any]"] = None
        if context:
            conn = context.connection
        self.connection: Optional["BaseConnection[Any]"] = conn

    @abstractmethod
    def load(self, data: Buffer) -> Any:
        """Return the Python object represented by a PostgreSQL value."""
        ...
# The Transformer class exposed by this module: the fast object provided by
# `_psycopg` when that module is available, else the `_transform` one.
Transformer: Type["abc.Transformer"]

if not _psycopg:
    from . import _transform

    Transformer = _transform.Transformer
else:
    Transformer = _psycopg.Transformer
class RecursiveDumper(Dumper):
    """Dumper owning a transformer, useful to dump recursive types."""

    def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
        super().__init__(cls, context)
        # Transformer obtained from the same context received by the dumper.
        self._tx = Transformer.from_context(context)
class RecursiveLoader(Loader):
    """Loader owning a transformer, useful to load recursive types."""

    def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
        super().__init__(oid, context)
        # Transformer obtained from the same context received by the loader.
        self._tx = Transformer.from_context(context)

View File

@@ -0,0 +1,95 @@
"""
psycopg client-side binding cursors
"""
# Copyright (C) 2022 The Psycopg Team
from typing import Optional, Tuple, TYPE_CHECKING
from functools import partial
from ._queries import PostgresQuery, PostgresClientQuery
from . import pq
from . import adapt
from . import errors as e
from .abc import ConnectionType, Query, Params
from .rows import Row
from .cursor import BaseCursor, Cursor
from ._preparing import Prepare
from .cursor_async import AsyncCursor
if TYPE_CHECKING:
from typing import Any # noqa: F401
from .connection import Connection # noqa: F401
from .connection_async import AsyncConnection # noqa: F401
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY
class ClientCursorMixin(BaseCursor[ConnectionType, Row]):
    """
    Mixin for cursors sending queries with no separate parameters.

    Queries are converted by `PostgresClientQuery` and are never prepared.
    """

    def mogrify(self, query: Query, params: Optional[Params] = None) -> str:
        """
        Return the query and parameters merged.

        Parameters are adapted and merged to the query the same way that
        `!execute()` would do.
        """
        # Start from a fresh transformer, then decode the converted query
        # using the transformer encoding.
        self._tx = adapt.Transformer(self)
        converted = self._convert_query(query, params)
        return converted.query.decode(self._tx.encoding)

    def _execute_send(
        self,
        query: PostgresQuery,
        *,
        force_extended: bool = False,
        binary: Optional[bool] = None,
    ) -> None:
        """
        Send `!query` to the server, with no parameters attached.

        Raise `~psycopg.NotSupportedError` if binary results are requested,
        either explicitly or through the cursor format.
        """
        if binary is None:
            fmt = self.format
        elif binary:
            fmt = BINARY
        else:
            fmt = TEXT

        if fmt == BINARY:
            raise e.NotSupportedError(
                "client-side cursors don't support binary results"
            )

        self._query = query

        pipeline = self._conn._pipeline
        if pipeline:
            # In pipeline mode always use PQsendQueryParams - see #314
            # Multiple statements in the same query are not allowed anyway.
            command = partial(self._pgconn.send_query_params, query.query, None)
            pipeline.command_queue.append(command)
        elif force_extended:
            self._pgconn.send_query_params(query.query, None)
        else:
            # If we can, let's use simple query protocol,
            # as it can execute more than one statement in a single query.
            self._pgconn.send_query(query.query)

    def _convert_query(
        self, query: Query, params: Optional[Params] = None
    ) -> PostgresQuery:
        """Return a `PostgresClientQuery` converted from query and params."""
        converted = PostgresClientQuery(self._tx)
        converted.convert(query, params)
        return converted

    def _get_prepared(
        self, pgq: PostgresQuery, prepare: Optional[bool] = None
    ) -> Tuple[Prepare, bytes]:
        """Always answer `Prepare.NO` with an empty statement name."""
        return Prepare.NO, b""
class ClientCursor(ClientCursorMixin["Connection[Any]", Row], Cursor[Row]):
    # Client-side binding cursor built on `Cursor`; exposed as part of the
    # "psycopg" top-level module.
    __module__ = "psycopg"
class AsyncClientCursor(
    ClientCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row]
):
    # Client-side binding cursor built on `AsyncCursor`; exposed as part of
    # the "psycopg" top-level module.
    __module__ = "psycopg"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,425 @@
"""
psycopg async connection objects
"""
# Copyright (C) 2020 The Psycopg Team
import sys
import asyncio
import logging
from types import TracebackType
from typing import Any, AsyncGenerator, AsyncIterator, List, Optional
from typing import Type, Union, cast, overload, TYPE_CHECKING
from contextlib import asynccontextmanager
from . import pq
from . import errors as e
from . import waiting
from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
from ._tpc import Xid
from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
from ._compat import Self
from .conninfo import ConnDict, make_conninfo, conninfo_to_dict
from .conninfo import conninfo_attempts_async, timeout_from_conninfo
from ._pipeline import AsyncPipeline
from ._encodings import pgconn_encoding
from .connection import BaseConnection, CursorRow, Notify
from .generators import notifies
from .transaction import AsyncTransaction
from .cursor_async import AsyncCursor
from .server_cursor import AsyncServerCursor
if TYPE_CHECKING:
from .pq.abc import PGconn
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY
IDLE = pq.TransactionStatus.IDLE
INTRANS = pq.TransactionStatus.INTRANS
logger = logging.getLogger("psycopg")
class AsyncConnection(BaseConnection[Row]):
    """
    Asynchronous wrapper for a connection to the database.
    """

    __module__ = "psycopg"

    cursor_factory: Type[AsyncCursor[Row]]
    server_cursor_factory: Type[AsyncServerCursor[Row]]
    row_factory: AsyncRowFactory[Row]
    _pipeline: Optional[AsyncPipeline]

    def __init__(
        self,
        pgconn: "PGconn",
        row_factory: AsyncRowFactory[Row] = cast(AsyncRowFactory[Row], tuple_row),
    ):
        super().__init__(pgconn)
        self.row_factory = row_factory
        # Lock taken by the operations which drive the connection.
        self.lock = asyncio.Lock()
        self.cursor_factory = AsyncCursor
        self.server_cursor_factory = AsyncServerCursor

    @overload
    @classmethod
    async def connect(
        cls,
        conninfo: str = "",
        *,
        autocommit: bool = False,
        prepare_threshold: Optional[int] = 5,
        row_factory: AsyncRowFactory[Row],
        cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
        context: Optional[AdaptContext] = None,
        **kwargs: Union[None, int, str],
    ) -> "AsyncConnection[Row]":
        # TODO: returned type should be Self. See #308.
        # Unfortunately we cannot use Self[Row] as Self is not parametric.
        # https://peps.python.org/pep-0673/#use-in-generic-classes
        ...

    @overload
    @classmethod
    async def connect(
        cls,
        conninfo: str = "",
        *,
        autocommit: bool = False,
        prepare_threshold: Optional[int] = 5,
        cursor_factory: Optional[Type[AsyncCursor[Any]]] = None,
        context: Optional[AdaptContext] = None,
        **kwargs: Union[None, int, str],
    ) -> "AsyncConnection[TupleRow]": ...

    @classmethod  # type: ignore[misc] # https://github.com/python/mypy/issues/11004
    async def connect(
        cls,
        conninfo: str = "",
        *,
        autocommit: bool = False,
        prepare_threshold: Optional[int] = 5,
        context: Optional[AdaptContext] = None,
        row_factory: Optional[AsyncRowFactory[Row]] = None,
        cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
        **kwargs: Any,
    ) -> Self:
        """
        Connect to a database server and return a new connection.

        Every connection attempt computed from the parameters is tried in
        order, until one succeeds. If they all fail, the error of the last
        attempt is raised, with no traceback attached.

        :raise InterfaceError: on Windows, if the running loop is a
            `!ProactorEventLoop`.
        :raise OperationalError: if there is no connection attempt to make.
        """
        if sys.platform == "win32":
            loop = asyncio.get_running_loop()
            if isinstance(loop, asyncio.ProactorEventLoop):
                raise e.InterfaceError(
                    "Psycopg cannot use the 'ProactorEventLoop' to run in async"
                    " mode. Please use a compatible event loop, for instance by"
                    " setting 'asyncio.set_event_loop_policy"
                    "(WindowsSelectorEventLoopPolicy())'"
                )

        params = await cls._get_connection_params(conninfo, **kwargs)
        timeout = timeout_from_conninfo(params)
        rv = None
        # Bound before the loop, so that it is defined even if there are no
        # attempts at all.
        last_ex: Optional[BaseException] = None
        attempts = await conninfo_attempts_async(params)
        for attempt in attempts:
            try:
                conninfo = make_conninfo(**attempt)
                rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
                break
            except e._NO_TRACEBACK as ex:
                # Only log the failure if there are other attempts to make.
                if len(attempts) > 1:
                    logger.debug(
                        "connection attempt failed on host: %r, port: %r,"
                        " hostaddr: %r: %s",
                        attempt.get("host"),
                        attempt.get("port"),
                        attempt.get("hostaddr"),
                        str(ex),
                    )
                last_ex = ex

        if not rv:
            if last_ex is None:
                # No attempt was made: report it explicitly instead of
                # failing with an unrelated error.
                raise e.OperationalError("no connection attempt to make")
            raise last_ex.with_traceback(None)

        rv._autocommit = bool(autocommit)
        if row_factory:
            rv.row_factory = row_factory
        if cursor_factory:
            rv.cursor_factory = cursor_factory
        if context:
            rv._adapters = AdaptersMap(context.adapters)
        rv.prepare_threshold = prepare_threshold
        return rv

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """
        Commit, or roll back if the block raised, then close the connection.

        Nothing is done if the connection is already closed.
        """
        if self.closed:
            return

        if exc_type:
            # try to rollback, but if there are problems (connection in a bad
            # state) just warn without clobbering the exception bubbling up.
            try:
                await self.rollback()
            except Exception as exc2:
                logger.warning(
                    "error ignored in rollback on %s: %s",
                    self,
                    exc2,
                )
        else:
            await self.commit()

        # Close the connection only if it doesn't belong to a pool.
        if not getattr(self, "_pool", None):
            await self.close()

    @classmethod
    async def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
        """Manipulate connection parameters before connecting."""
        return conninfo_to_dict(conninfo, **kwargs)

    async def close(self) -> None:
        """Mark the connection closed and finish the underlying `!pgconn`.

        Calling it on a closed connection is a no-op.
        """
        if self.closed:
            return
        self._closed = True

        # TODO: maybe send a cancel on close, if the connection is ACTIVE?

        self.pgconn.finish()

    @overload
    def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]: ...

    @overload
    def cursor(
        self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow]
    ) -> AsyncCursor[CursorRow]: ...

    @overload
    def cursor(
        self,
        name: str,
        *,
        binary: bool = False,
        scrollable: Optional[bool] = None,
        withhold: bool = False,
    ) -> AsyncServerCursor[Row]: ...

    @overload
    def cursor(
        self,
        name: str,
        *,
        binary: bool = False,
        row_factory: AsyncRowFactory[CursorRow],
        scrollable: Optional[bool] = None,
        withhold: bool = False,
    ) -> AsyncServerCursor[CursorRow]: ...

    def cursor(
        self,
        name: str = "",
        *,
        binary: bool = False,
        row_factory: Optional[AsyncRowFactory[Any]] = None,
        scrollable: Optional[bool] = None,
        withhold: bool = False,
    ) -> Union[AsyncCursor[Any], AsyncServerCursor[Any]]:
        """
        Return a new `AsyncCursor` to send commands and queries to the connection.

        If `!name` is specified the cursor is created by
        `!server_cursor_factory`, otherwise by `!cursor_factory`. If
        `!row_factory` is not specified the connection one is used.
        """
        self._check_connection_ok()

        if not row_factory:
            row_factory = self.row_factory

        cur: Union[AsyncCursor[Any], AsyncServerCursor[Any]]
        if name:
            cur = self.server_cursor_factory(
                self,
                name=name,
                row_factory=row_factory,
                scrollable=scrollable,
                withhold=withhold,
            )
        else:
            cur = self.cursor_factory(self, row_factory=row_factory)

        if binary:
            cur.format = BINARY

        return cur

    async def execute(
        self,
        query: Query,
        params: Optional[Params] = None,
        *,
        prepare: Optional[bool] = None,
        binary: bool = False,
    ) -> AsyncCursor[Row]:
        """Execute a query on a new cursor and return the cursor.

        The errors in `!_NO_TRACEBACK` are re-raised without traceback.
        """
        try:
            cur = self.cursor()
            if binary:
                cur.format = BINARY

            return await cur.execute(query, params, prepare=prepare)

        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)

    async def commit(self) -> None:
        """Run the commit generator while holding the connection lock."""
        async with self.lock:
            await self.wait(self._commit_gen())

    async def rollback(self) -> None:
        """Run the rollback generator while holding the connection lock."""
        async with self.lock:
            await self.wait(self._rollback_gen())

    @asynccontextmanager
    async def transaction(
        self,
        savepoint_name: Optional[str] = None,
        force_rollback: bool = False,
    ) -> AsyncIterator[AsyncTransaction]:
        """
        Start a context block with a new transaction or nested transaction.

        :rtype: AsyncTransaction
        """
        tx = AsyncTransaction(self, savepoint_name, force_rollback)
        if self._pipeline:
            # A pipeline is active: enter a pipeline block both around and
            # inside the transaction block.
            async with self.pipeline(), tx, self.pipeline():
                yield tx
        else:
            async with tx:
                yield tx

    async def notifies(self) -> AsyncGenerator[Notify, None]:
        """Yield `Notify` objects as notifications are received, forever."""
        while True:
            # The lock is only held while waiting for the notifications, not
            # while they are yielded to the caller.
            async with self.lock:
                try:
                    ns = await self.wait(notifies(self.pgconn))
                except e._NO_TRACEBACK as ex:
                    raise ex.with_traceback(None)
            enc = pgconn_encoding(self.pgconn)
            for pgn in ns:
                n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
                yield n

    @asynccontextmanager
    async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
        """Context manager to switch the connection into pipeline mode."""
        async with self.lock:
            self._check_connection_ok()

            pipeline = self._pipeline
            if pipeline is None:
                # WARNING: reference loop, broken ahead.
                pipeline = self._pipeline = AsyncPipeline(self)

        try:
            async with pipeline:
                yield pipeline
        finally:
            # Leaving the outermost block: break the reference loop.
            if pipeline.level == 0:
                async with self.lock:
                    assert pipeline is self._pipeline
                    self._pipeline = None

    async def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
        """
        Consume the generator `!gen` on the connection socket and return its
        result.

        If the wait is interrupted by `!CancelledError` or
        `!KeyboardInterrupt`, try to cancel the running query before
        re-raising the exception.
        """
        try:
            return await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
        except (asyncio.CancelledError, KeyboardInterrupt):
            # On Ctrl-C, try to cancel the query in the server, otherwise
            # the connection will remain stuck in ACTIVE state.
            self._try_cancel(self.pgconn)
            try:
                await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
            except e.QueryCanceled:
                pass  # as expected
            raise

    @classmethod
    async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
        """Consume a connection generator and return its result."""
        return await waiting.wait_conn_async(gen, timeout)

    def _set_autocommit(self, value: bool) -> None:
        self._no_set_async("autocommit")

    async def set_autocommit(self, value: bool) -> None:
        """Async version of the `~Connection.autocommit` setter."""
        async with self.lock:
            await self.wait(self._set_autocommit_gen(value))

    def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
        self._no_set_async("isolation_level")

    async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
        """Async version of the `~Connection.isolation_level` setter."""
        async with self.lock:
            await self.wait(self._set_isolation_level_gen(value))

    def _set_read_only(self, value: Optional[bool]) -> None:
        self._no_set_async("read_only")

    async def set_read_only(self, value: Optional[bool]) -> None:
        """Async version of the `~Connection.read_only` setter."""
        async with self.lock:
            await self.wait(self._set_read_only_gen(value))

    def _set_deferrable(self, value: Optional[bool]) -> None:
        self._no_set_async("deferrable")

    async def set_deferrable(self, value: Optional[bool]) -> None:
        """Async version of the `~Connection.deferrable` setter."""
        async with self.lock:
            await self.wait(self._set_deferrable_gen(value))

    def _no_set_async(self, attribute: str) -> None:
        """Raise `!AttributeError` pointing to the `!set_*()` async method."""
        raise AttributeError(
            f"'the {attribute!r} property is read-only on async connections:"
            f" please use 'await .set_{attribute}()' instead."
        )

    async def tpc_begin(self, xid: Union[Xid, str]) -> None:
        async with self.lock:
            await self.wait(self._tpc_begin_gen(xid))

    async def tpc_prepare(self) -> None:
        try:
            async with self.lock:
                await self.wait(self._tpc_prepare_gen())
        except e.ObjectNotInPrerequisiteState as ex:
            # Report the condition as a "not supported" error instead.
            raise e.NotSupportedError(str(ex)) from None

    async def tpc_commit(self, xid: Union[Xid, str, None] = None) -> None:
        async with self.lock:
            await self.wait(self._tpc_finish_gen("commit", xid))

    async def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None:
        async with self.lock:
            await self.wait(self._tpc_finish_gen("rollback", xid))

    async def tpc_recover(self) -> List[Xid]:
        """Return the list of `Xid` fetched by the recover query."""
        self._check_tpc()
        status = self.info.transaction_status
        async with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
            await cur.execute(Xid._get_recover_query())
            res = await cur.fetchall()

        # If the query itself started a transaction, roll it back.
        if status == IDLE and self.info.transaction_status == INTRANS:
            await self.rollback()

        return res

View File

@@ -0,0 +1,154 @@
"""
Functions to manipulate conninfo strings
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
import re
from typing import Any
from . import pq
from . import errors as e
from . import _conninfo_utils
from . import _conninfo_attempts
from . import _conninfo_attempts_async
# re-exports
ConnDict = _conninfo_utils.ConnDict
conninfo_attempts = _conninfo_attempts.conninfo_attempts
conninfo_attempts_async = _conninfo_attempts_async.conninfo_attempts_async
# Default timeout, in seconds, for a connection attempt.
# Arbitrary timeout, what is applied by the libpq on my computer.
# Your mileage won't vary.
_DEFAULT_CONNECT_TIMEOUT = 130
def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
    """
    Build a single conninfo string from a string and keyword parameters.

    :param conninfo: A `connection string`__ as accepted by PostgreSQL.
    :param kwargs: Parameters overriding the ones specified in `!conninfo`.
    :return: A connection string valid for PostgreSQL, with the `!kwargs`
        parameters merged.

    Raise `~psycopg.ProgrammingError` if the input doesn't make a valid
    conninfo string.

    .. __: https://www.postgresql.org/docs/current/libpq-connect.html
           #LIBPQ-CONNSTRING
    """
    if not (conninfo or kwargs):
        return ""

    if not kwargs:
        # Nothing to merge: leave the conninfo untouched, only validate it.
        # Return exactly a `str`, not a subtype, to avoid making Liskov sad.
        _parse_conninfo(conninfo)
        return str(conninfo)

    # The keyword parameters win over the conninfo string, unless they are
    # None, in which case they are ignored.
    params = {name: value for name, value in kwargs.items() if value is not None}
    if conninfo:
        merged = conninfo_to_dict(conninfo)
        merged.update(params)
        params = merged

    conninfo = " ".join(
        f"{name}={_param_escape(str(value))}" for name, value in params.items()
    )

    # Make sure that what we return is valid.
    _parse_conninfo(conninfo)
    return conninfo
def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict:
    """
    Return the parameters in the `!conninfo` string as a dictionary.

    :param conninfo: A `connection string`__ as accepted by PostgreSQL.
    :param kwargs: Parameters overriding the ones specified in `!conninfo`.
    :return: Dictionary with the parameters parsed from `!conninfo` and
        `!kwargs`.

    Raise `~psycopg.ProgrammingError` if `!conninfo` is not a valid connection
    string.

    .. __: https://www.postgresql.org/docs/current/libpq-connect.html
           #LIBPQ-CONNSTRING
    """
    parsed = _parse_conninfo(conninfo)
    # Only keep the options which have a value.
    params = {
        opt.keyword.decode(): opt.val.decode() for opt in parsed if opt.val is not None
    }
    # Keyword arguments override the parsed values, unless they are None.
    params.update((name, value) for name, value in kwargs.items() if value is not None)
    return params
def _parse_conninfo(conninfo: str) -> list[pq.ConninfoOption]:
    """
    Check that `!conninfo` is a valid connection string.

    On success return what pq.Conninfo.parse() returned; if the string is
    not valid raise ProgrammingError.
    """
    encoded = conninfo.encode()
    try:
        options = pq.Conninfo.parse(encoded)
    except e.OperationalError as ex:
        raise e.ProgrammingError(str(ex)) from None
    return options
re_escape = re.compile(r"([\\'])")
re_space = re.compile(r"\s")
def _param_escape(s: str) -> str:
"""
Apply the escaping rule required by PQconnectdb
"""
if not s:
return "''"
s = re_escape.sub(r"\\\1", s)
if re_space.search(s):
s = "'" + s + "'"
return s
def timeout_from_conninfo(params: ConnDict) -> int:
    """
    Return the timeout in seconds from the connection parameters.

    :param params: The connection parameters, where `!connect_timeout` is
        looked up.
    :return: The timeout to use: never less than 2 seconds, and the default
        timeout if the parameter is missing or is not positive.

    Raise `~psycopg.ProgrammingError` if the value cannot be converted to
    an integer.
    """
    # Follow the libpq convention:
    #
    # - 0 or less means no timeout (but we will use a default to simulate
    #   the socket timeout)
    # - at least 2 seconds.
    #
    # See connectDBComplete in fe-connect.c
    value: str | int | None = _conninfo_utils.get_param(params, "connect_timeout")
    if value is None:
        value = _DEFAULT_CONNECT_TIMEOUT
    try:
        timeout = int(value)
    except ValueError:
        # Hide the ValueError context, the same way _parse_conninfo() does:
        # the message already contains the offending value.
        raise e.ProgrammingError(
            f"bad value for connect_timeout: {value!r}"
        ) from None

    if timeout <= 0:
        # The sync connect function will stop on the default socket timeout
        # Because in async connection mode we need to enforce the timeout
        # ourselves, we need a finite value.
        timeout = _DEFAULT_CONNECT_TIMEOUT
    elif timeout < 2:
        # Enforce a 2s min
        timeout = 2

    return timeout

View File

@@ -0,0 +1,912 @@
"""
psycopg copy support
"""
# Copyright (C) 2020 The Psycopg Team
import re
import queue
import struct
import asyncio
import threading
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match, IO
from typing import Optional, Sequence, Tuple, Type, Union, TYPE_CHECKING
from . import pq
from . import adapt
from . import errors as e
from .abc import Buffer, ConnectionType, PQGen, Transformer
from ._compat import create_task, Self
from .pq.misc import connection_summary
from ._cmodule import _psycopg
from ._encodings import pgconn_encoding
from .generators import copy_from, copy_to, copy_end
if TYPE_CHECKING:
from .cursor import BaseCursor, Cursor
from .cursor_async import AsyncCursor
from .connection import Connection # noqa: F401
from .connection_async import AsyncConnection # noqa: F401
PY_TEXT = adapt.PyFormat.TEXT
PY_BINARY = adapt.PyFormat.BINARY
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY
COPY_IN = pq.ExecStatus.COPY_IN
COPY_OUT = pq.ExecStatus.COPY_OUT
ACTIVE = pq.TransactionStatus.ACTIVE
# Size of data to accumulate before sending it down the network. We fill a
# buffer this size field by field, and when it passes the threshold size
# we ship it, so it may end up being bigger than this.
BUFFER_SIZE = 32 * 1024
# Maximum data size we want to queue to send to the libpq copy. Sending a
# buffer too big to be handled can cause an infinite loop in the libpq
# (#255) so we want to split it in more digestible chunks.
MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
# Note: making this buffer too large, e.g.
# MAX_BUFFER_SIZE = 1024 * 1024
# makes operations *way* slower! Probably triggering some quadraticity
# in the libpq memory management and data sending.
# Max size of the write queue of buffers. More than that copy will block
# Each buffer should be around BUFFER_SIZE size.
QUEUE_SIZE = 1024
class BaseCopy(Generic[ConnectionType]):
"""
Base implementation for the copy user interface.
Two subclasses expose real methods with the sync/async differences.
The difference between the text and binary format is managed by two
different `Formatter` subclasses.
Writing (the I/O part) is implemented in the subclasses by a `Writer` or
`AsyncWriter` instance. Normally writing implies sending copy data to a
database, but a different writer might be chosen, e.g. to stream data into
a file for later use.
"""
formatter: "Formatter"
def __init__(
    self,
    cursor: "BaseCursor[ConnectionType, Any]",
    *,
    binary: Optional[bool] = None,
):
    """
    Set up the copy object from the state of `!cursor`.

    :param cursor: the cursor the copy operates on.
    :param binary: whether to use the binary format; if `!None`, use what
        the cursor result says, if there is a result.

    Raise `~psycopg.ProgrammingError` if the cursor has a result whose
    status is neither COPY_IN nor COPY_OUT.
    """
    self.cursor = cursor
    self.connection = cursor.connection
    self._pgconn = self.connection.pgconn

    result = cursor.pgresult
    if not result:
        # No result to inspect: behave as in a COPY_IN.
        self._direction = COPY_IN
    else:
        self._direction = result.status
        if self._direction != COPY_IN and self._direction != COPY_OUT:
            raise e.ProgrammingError(
                "the cursor should have performed a COPY operation;"
                f" its status is {pq.ExecStatus(self._direction).name} instead"
            )

    if binary is None:
        binary = bool(result and result.binary_tuples)

    # Use the cursor transformer if it has one, else create a new one.
    tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor)
    if not binary:
        self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
    else:
        self.formatter = BinaryFormatter(tx)

    self._finished = False
def __repr__(self) -> str:
    """Return the class full name, the connection summary and the object id."""
    klass = self.__class__
    cls = f"{klass.__module__}.{klass.__qualname__}"
    info = connection_summary(self._pgconn)
    return f"<{cls} {info} at 0x{id(self):x}>"
def _enter(self) -> None:
if self._finished:
raise TypeError("copy blocks can be used only once")
def set_types(self, types: Sequence[Union[int, str]]) -> None:
"""
Set the types expected in a COPY operation.
The types must be specified as a sequence of oid or PostgreSQL type
names (e.g. ``int4``, ``timestamptz[]``).
This operation overcomes the lack of metadata returned by PostgreSQL
when a COPY operation begins:
- On :sql:`COPY TO`, `!set_types()` allows to specify what types the
operation returns. If `!set_types()` is not used, the data will be
returned as unparsed strings or bytes instead of Python objects.
- On :sql:`COPY FROM`, `!set_types()` allows to choose what type the
database expects. This is especially useful in binary copy, because
PostgreSQL will apply no cast rule.
"""
registry = self.cursor.adapters.types
oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types]
if self._direction == COPY_IN:
self.formatter.transformer.set_dumper_types(oids, self.formatter.format)
else:
self.formatter.transformer.set_loader_types(oids, self.formatter.format)
# High level copy protocol generators (state change of the Copy object)
def _read_gen(self) -> PQGen[Buffer]:
if self._finished:
return memoryview(b"")
res = yield from copy_from(self._pgconn)
if isinstance(res, memoryview):
return res
# res is the final PGresult
self._finished = True
# This result is a COMMAND_OK which has info about the number of rows
# returned, but not about the columns, which is instead an information
# that was received on the COPY_OUT result at the beginning of COPY.
# So, don't replace the results in the cursor, just update the rowcount.
nrows = res.command_tuples
self.cursor._rowcount = nrows if nrows is not None else -1
return memoryview(b"")
def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]:
data = yield from self._read_gen()
if not data:
return None
row = self.formatter.parse_row(data)
if row is None:
# Get the final result to finish the copy operation
yield from self._read_gen()
self._finished = True
return None
return row
def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
if not exc:
return
if self._pgconn.transaction_status != ACTIVE:
# The server has already finished to send copy data. The connection
# is already in a good state.
return
# Throw a cancel to the server, then consume the rest of the copy data
# (which might or might not have been already transferred entirely to
# the client, so we won't necessary see the exception associated with
# canceling).
self.connection.cancel()
try:
while (yield from self._read_gen()):
pass
except e.QueryCanceled:
pass
class Copy(BaseCopy["Connection[Any]"]):
    """Manage a :sql:`COPY` operation.

    :param cursor: the cursor where the operation is performed.
    :param binary: if `!True`, write binary format.
    :param writer: the object to write to destination. If not specified, write
        to the `!cursor` connection.

    There is no need to choose `!binary` if the cursor has executed a
    :sql:`COPY` operation, because the operation result describes the format
    too. The parameter is useful when a `!Copy` object is created manually and
    no operation is performed on the cursor, such as when using
    `~psycopg.copy.FileWriter` as `!writer`.
    """

    __module__ = "psycopg"

    writer: "Writer"

    def __init__(
        self,
        cursor: "Cursor[Any]",
        *,
        binary: Optional[bool] = None,
        writer: Optional["Writer"] = None,
    ):
        super().__init__(cursor, binary=binary)
        # Default destination: the connection of the cursor.
        self.writer = writer or LibpqWriter(cursor)
        self._write = self.writer.write

    def __enter__(self) -> Self:
        self._enter()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        self.finish(exc_val)

    # End user sync interface

    def __iter__(self) -> Iterator[Buffer]:
        """Implement block-by-block iteration on :sql:`COPY TO`."""
        block = self.read()
        while block:
            yield block
            block = self.read()

    def read(self) -> Buffer:
        """
        Read an unparsed row after a :sql:`COPY TO` operation.

        Return an empty string when the data is finished.
        """
        gen = self._read_gen()
        return self.connection.wait(gen)

    def rows(self) -> Iterator[Tuple[Any, ...]]:
        """
        Iterate on the result of a :sql:`COPY TO` operation record by record.

        Note that the records returned will be tuples of unparsed strings or
        bytes, unless data types are specified using `set_types()`.
        """
        row = self.read_row()
        while row is not None:
            yield row
            row = self.read_row()

    def read_row(self) -> Optional[Tuple[Any, ...]]:
        """
        Read a parsed row of data from a table after a :sql:`COPY TO` operation.

        Return `!None` when the data is finished.

        Note that the records returned will be tuples of unparsed strings or
        bytes, unless data types are specified using `set_types()`.
        """
        gen = self._read_row_gen()
        return self.connection.wait(gen)

    def write(self, buffer: Union[Buffer, str]) -> None:
        """
        Write a block of data to a table after a :sql:`COPY FROM` operation.

        If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In
        text mode it can be either `!bytes` or `!str`.
        """
        payload = self.formatter.write(buffer)
        if not payload:
            return
        self._write(payload)

    def write_row(self, row: Sequence[Any]) -> None:
        """Write a record to a table after a :sql:`COPY FROM` operation."""
        # The formatter buffers rows: it returns data only when it's time
        # to flush it.
        payload = self.formatter.write_row(row)
        if not payload:
            return
        self._write(payload)

    def finish(self, exc: Optional[BaseException]) -> None:
        """Terminate the copy operation and free the resources allocated.

        You shouldn't need to call this function yourself: it is usually called
        by exit. It is available if, despite what is documented, you end up
        using the `Copy` object outside a block.
        """
        if self._direction != COPY_IN:
            self.connection.wait(self._end_copy_out_gen(exc))
            return

        # Flush whatever the formatter still holds, then close the writer.
        tail = self.formatter.end()
        if tail:
            self._write(tail)
        self.writer.finish(exc)
        self._finished = True
class Writer(ABC):
    """
    Interface of an object writing copy data to some destination.
    """

    @abstractmethod
    def write(self, data: Buffer) -> None:
        """
        Send a chunk of copy data to the destination.
        """
        ...

    def finish(self, exc: Optional[BaseException] = None) -> None:
        """
        Hook invoked once no more data will be written.

        ``exc`` is the error which interrupted the operations, if any.
        The default implementation does nothing.
        """
        return None
class LibpqWriter(Writer):
    """
    A `Writer` sending copy data to a Postgres database.
    """

    def __init__(self, cursor: "Cursor[Any]"):
        self.cursor = cursor
        self.connection = cursor.connection
        self._pgconn = self.connection.pgconn

    def write(self, data: Buffer) -> None:
        size = len(data)
        if size <= MAX_BUFFER_SIZE:
            # Most used path: we don't need to split the buffer in smaller
            # bits, so don't make a copy.
            self.connection.wait(copy_to(self._pgconn, data))
            return

        # Copy a buffer too large in chunks to avoid causing a memory
        # error in the libpq, which may cause an infinite loop (#255).
        for start in range(0, size, MAX_BUFFER_SIZE):
            chunk = data[start : start + MAX_BUFFER_SIZE]
            self.connection.wait(copy_to(self._pgconn, chunk))

    def finish(self, exc: Optional[BaseException] = None) -> None:
        # On error, tell the server why the copy is being interrupted.
        bmsg: Optional[bytes] = None
        if exc:
            msg = f"error from Python: {type(exc).__qualname__} - {exc}"
            bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")

        try:
            res = self.connection.wait(copy_end(self._pgconn, bmsg))
        except e.QueryCanceled:
            # The QueryCanceled is expected if we sent an exception message to
            # pgconn.put_copy_end(). The Python exception that generated that
            # cancelling is more important, so don't clobber it.
            if bmsg:
                return
            raise

        self.cursor._results = [res]
class QueuedLibpqWriter(LibpqWriter):
    """
    A writer using a buffer to queue data to write to a Postgres database.

    `write()` only enqueues the data: this way the main thread can be
    CPU-bound formatting messages, while a worker thread can be IO-bound
    waiting to write on the connection.
    """

    def __init__(self, cursor: "Cursor[Any]"):
        super().__init__(cursor)

        self._queue: queue.Queue[Buffer] = queue.Queue(maxsize=QUEUE_SIZE)
        self._worker: Optional[threading.Thread] = None
        self._worker_error: Optional[BaseException] = None

    def worker(self) -> None:
        """Push data to the server when available from the copy queue.

        Terminate reading when the queue receives a false-y value, or in case
        of error.

        The function is designed to be run in a separate thread.
        """
        try:
            item = self._queue.get(block=True, timeout=24 * 60 * 60)
            while item:
                self.connection.wait(copy_to(self._pgconn, item))
                item = self._queue.get(block=True, timeout=24 * 60 * 60)
        except BaseException as ex:
            # Propagate the error to the main thread.
            self._worker_error = ex

    def write(self, data: Buffer) -> None:
        if not self._worker:
            # warning: reference loop (the thread target is a bound method
            # of this object), broken by finish()
            self._worker = threading.Thread(target=self.worker, daemon=True)
            self._worker.start()

        # If the worker thread raised an exception, re-raise it to the caller.
        if self._worker_error:
            raise self._worker_error

        size = len(data)
        if size <= MAX_BUFFER_SIZE:
            # Most used path: we don't need to split the buffer in smaller
            # bits, so don't make a copy.
            self._queue.put(data)
            return

        # Copy a buffer too large in chunks to avoid causing a memory
        # error in the libpq, which may cause an infinite loop (#255).
        for start in range(0, size, MAX_BUFFER_SIZE):
            self._queue.put(data[start : start + MAX_BUFFER_SIZE])

    def finish(self, exc: Optional[BaseException] = None) -> None:
        # An empty buffer is the signal for the worker to terminate.
        self._queue.put(b"")

        worker = self._worker
        if worker:
            worker.join()
            self._worker = None  # break the loop

        # Check if the worker thread raised any exception before terminating.
        if self._worker_error:
            raise self._worker_error

        super().finish(exc)
class FileWriter(Writer):
    """
    A `Writer` storing copy data into a file-like object.

    :param file: the file where to write copy data. It must be open for writing
        in binary mode.
    """

    def __init__(self, file: IO[bytes]):
        self.file = file

    def write(self, data: Buffer) -> None:
        # Hand the data over to the file as it is.
        target = self.file
        target.write(data)
class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
    """Manage an asynchronous :sql:`COPY` operation."""

    __module__ = "psycopg"

    writer: "AsyncWriter"

    def __init__(
        self,
        cursor: "AsyncCursor[Any]",
        *,
        binary: Optional[bool] = None,
        writer: Optional["AsyncWriter"] = None,
    ):
        super().__init__(cursor, binary=binary)
        # Default destination: the connection of the cursor.
        self.writer = writer or AsyncLibpqWriter(cursor)
        self._write = self.writer.write

    async def __aenter__(self) -> Self:
        self._enter()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        await self.finish(exc_val)

    async def __aiter__(self) -> AsyncIterator[Buffer]:
        # Block-by-block iteration: an empty block marks the end of data.
        block = await self.read()
        while block:
            yield block
            block = await self.read()

    async def read(self) -> Buffer:
        gen = self._read_gen()
        return await self.connection.wait(gen)

    async def rows(self) -> AsyncIterator[Tuple[Any, ...]]:
        # Record-by-record iteration: None marks the end of data.
        row = await self.read_row()
        while row is not None:
            yield row
            row = await self.read_row()

    async def read_row(self) -> Optional[Tuple[Any, ...]]:
        gen = self._read_row_gen()
        return await self.connection.wait(gen)

    async def write(self, buffer: Union[Buffer, str]) -> None:
        payload = self.formatter.write(buffer)
        if not payload:
            return
        await self._write(payload)

    async def write_row(self, row: Sequence[Any]) -> None:
        # The formatter buffers rows: it returns data only when it's time
        # to flush it.
        payload = self.formatter.write_row(row)
        if not payload:
            return
        await self._write(payload)

    async def finish(self, exc: Optional[BaseException]) -> None:
        if self._direction != COPY_IN:
            await self.connection.wait(self._end_copy_out_gen(exc))
            return

        # Flush whatever the formatter still holds, then close the writer.
        tail = self.formatter.end()
        if tail:
            await self._write(tail)
        await self.writer.finish(exc)
        self._finished = True
class AsyncWriter(ABC):
    """
    Interface of an object writing copy data to some destination
    (for async connections).
    """

    @abstractmethod
    async def write(self, data: Buffer) -> None:
        """Send a chunk of copy data to the destination."""
        ...

    async def finish(self, exc: Optional[BaseException] = None) -> None:
        """Hook awaited once no more data will be written; no-op by default."""
        return None
class AsyncLibpqWriter(AsyncWriter):
    """
    An `AsyncWriter` sending copy data to a Postgres database.
    """

    def __init__(self, cursor: "AsyncCursor[Any]"):
        self.cursor = cursor
        self.connection = cursor.connection
        self._pgconn = self.connection.pgconn

    async def write(self, data: Buffer) -> None:
        size = len(data)
        if size <= MAX_BUFFER_SIZE:
            # Most used path: we don't need to split the buffer in smaller
            # bits, so don't make a copy.
            await self.connection.wait(copy_to(self._pgconn, data))
            return

        # Copy a buffer too large in chunks to avoid causing a memory
        # error in the libpq, which may cause an infinite loop (#255).
        for start in range(0, size, MAX_BUFFER_SIZE):
            chunk = data[start : start + MAX_BUFFER_SIZE]
            await self.connection.wait(copy_to(self._pgconn, chunk))

    async def finish(self, exc: Optional[BaseException] = None) -> None:
        # On error, tell the server why the copy is being interrupted.
        bmsg: Optional[bytes] = None
        if exc:
            msg = f"error from Python: {type(exc).__qualname__} - {exc}"
            bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")

        try:
            res = await self.connection.wait(copy_end(self._pgconn, bmsg))
        except e.QueryCanceled:
            # The QueryCanceled is expected if we sent an exception message to
            # pgconn.put_copy_end(). The Python exception that generated that
            # cancelling is more important, so don't clobber it.
            if bmsg:
                return
            raise

        self.cursor._results = [res]
class AsyncQueuedLibpqWriter(AsyncLibpqWriter):
    """
    An `AsyncWriter` using a buffer to queue data to write.

    `write()` only enqueues the data: this way the caller can be CPU-bound
    formatting messages, while a worker task can be IO-bound waiting to write
    on the connection.
    """

    def __init__(self, cursor: "AsyncCursor[Any]"):
        super().__init__(cursor)

        self._queue: asyncio.Queue[Buffer] = asyncio.Queue(maxsize=QUEUE_SIZE)
        self._worker: Optional[asyncio.Future[None]] = None

    async def worker(self) -> None:
        """Push data to the server when available from the copy queue.

        Terminate reading when the queue receives a false-y value.

        The function is designed to be run in a separate task.
        """
        item = await self._queue.get()
        while item:
            await self.connection.wait(copy_to(self._pgconn, item))
            item = await self._queue.get()

    async def write(self, data: Buffer) -> None:
        # Start the worker task lazily, on the first write.
        if not self._worker:
            self._worker = create_task(self.worker())

        size = len(data)
        if size <= MAX_BUFFER_SIZE:
            # Most used path: we don't need to split the buffer in smaller
            # bits, so don't make a copy.
            await self._queue.put(data)
            return

        # Copy a buffer too large in chunks to avoid causing a memory
        # error in the libpq, which may cause an infinite loop (#255).
        for start in range(0, size, MAX_BUFFER_SIZE):
            await self._queue.put(data[start : start + MAX_BUFFER_SIZE])

    async def finish(self, exc: Optional[BaseException] = None) -> None:
        # An empty buffer is the signal for the worker to terminate.
        await self._queue.put(b"")

        worker = self._worker
        if worker:
            await asyncio.gather(worker)
            self._worker = None  # break reference loops if any

        await super().finish(exc)
class Formatter(ABC):
    """
    Interface of an object which understands a copy format (text, binary).
    """

    format: pq.Format

    def __init__(self, transformer: Transformer):
        # Rows formatted by write_row() pile up here until they are flushed.
        self._write_buffer = bytearray()
        # true if the user is using write_row()
        self._row_mode = False
        self.transformer = transformer

    @abstractmethod
    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
        """Parse a chunk of copy data into a row, if it contains one."""
        ...

    @abstractmethod
    def write(self, buffer: Union[Buffer, str]) -> Buffer:
        """Return the data to send for a block passed by the user."""
        ...

    @abstractmethod
    def write_row(self, row: Sequence[Any]) -> Buffer:
        """Format a row; return the data to send now, if any."""
        ...

    @abstractmethod
    def end(self) -> Buffer:
        """Return the data left to send when the copy terminates."""
        ...
class TextFormatter(Formatter):
    """
    `Formatter` implementation for the text copy format.

    :param transformer: the object used to dump/load the row values.
    :param encoding: the Python codec used to encode `!str` blocks passed to
        `write()`.
    """

    format = TEXT

    def __init__(self, transformer: Transformer, encoding: str = "utf-8"):
        super().__init__(transformer)
        self._encoding = encoding

    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
        """Parse a block of copy data into a row; `!None` if `!data` is empty."""
        if data:
            return parse_row_text(data, self.transformer)
        else:
            return None

    def write(self, buffer: Union[Buffer, str]) -> Buffer:
        """Return the bytes to send for a block passed by the user."""
        # Note: unlike the binary format, the text format has no signature,
        # so there is no "signature sent" state to track here.
        return self._ensure_bytes(buffer)

    def write_row(self, row: Sequence[Any]) -> Buffer:
        """
        Format `!row` into the internal buffer.

        Return the buffer content once it has grown over `!BUFFER_SIZE`,
        otherwise return an empty bytes string (nothing to send yet).
        """
        # Note down that we are writing in row mode.
        self._row_mode = True

        format_row_text(row, self.transformer, self._write_buffer)
        if len(self._write_buffer) > BUFFER_SIZE:
            buffer, self._write_buffer = self._write_buffer, bytearray()
            return buffer
        else:
            return b""

    def end(self) -> Buffer:
        """Return whatever is left in the internal buffer, resetting it."""
        buffer, self._write_buffer = self._write_buffer, bytearray()
        return buffer

    def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
        """Encode `!str` data using the formatter encoding; pass through the rest."""
        if isinstance(data, str):
            return data.encode(self._encoding)
        else:
            # Assume, for simplicity, that the user is not passing stupid
            # things to the write function. If that's the case, things
            # will fail downstream.
            return data
class BinaryFormatter(Formatter):
    """`Formatter` implementation for the binary copy format."""

    format = BINARY

    def __init__(self, transformer: Transformer):
        super().__init__(transformer)
        # Tracks whether the file signature was already read or written.
        self._signature_sent = False

    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
        siglen = len(_binary_signature)
        if self._signature_sent:
            if data == _binary_trailer:
                return None
        else:
            # The first block received must start with the signature.
            if data[:siglen] != _binary_signature:
                raise e.DataError(
                    "binary copy doesn't start with the expected signature"
                )
            self._signature_sent = True
            data = data[siglen:]

        return parse_row_binary(data, self.transformer)

    def write(self, buffer: Union[Buffer, str]) -> Buffer:
        # In block mode we assume that the signature is included in the data.
        payload = self._ensure_bytes(buffer)
        self._signature_sent = True
        return payload

    def write_row(self, row: Sequence[Any]) -> Buffer:
        # Note down that we are writing in row mode: it means we will have
        # to take care of the end-of-copy marker too
        self._row_mode = True

        # The signature goes in front of the first row.
        if not self._signature_sent:
            self._write_buffer += _binary_signature
            self._signature_sent = True

        format_row_binary(row, self.transformer, self._write_buffer)

        # Flush only once enough data has piled up.
        if len(self._write_buffer) <= BUFFER_SIZE:
            return b""
        flushed, self._write_buffer = self._write_buffer, bytearray()
        return flushed

    def end(self) -> Buffer:
        if not self._signature_sent:
            # If we have sent no data we need to send the signature
            # and the trailer
            self._write_buffer += _binary_signature
            self._write_buffer += _binary_trailer
        elif self._row_mode:
            # if we have sent data already, we have sent the signature
            # too (either with the first row, or we assume that in
            # block mode the signature is included).
            # Write the trailer only if we are sending rows (with the
            # assumption that who is copying binary data is sending the
            # whole format).
            self._write_buffer += _binary_trailer

        flushed, self._write_buffer = self._write_buffer, bytearray()
        return flushed

    def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
        if not isinstance(data, str):
            # Assume, for simplicity, that the user is not passing stupid
            # things to the write function. If that's the case, things
            # will fail downstream.
            return data
        raise TypeError("cannot copy str data in binary mode: use bytes instead")
def _format_row_text(
    row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
) -> bytearray:
    """Append to `!out` the text copy representation of `!row`; return it.

    A new buffer is created if `!out` is not specified.
    """
    if out is None:
        out = bytearray()

    # An empty row is just a line terminator.
    if not row:
        out += b"\n"
        return out

    for value in row:
        if value is None:
            out += rb"\N"
        else:
            dumped = tx.get_dumper(value, PY_TEXT).dump(value)
            # Escape the chars which have a special meaning in the format.
            out += _dump_re.sub(_dump_sub, dumped)
        out += b"\t"

    # Turn the separator after the last field into the line terminator.
    out[-1:] = b"\n"
    return out
def _format_row_binary(
    row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
) -> bytearray:
    """Append to `!out` the binary copy representation of `!row`; return it.

    A new buffer is created if `!out` is not specified.
    """
    if out is None:
        out = bytearray()

    nfields = len(row)
    # The tuple starts with the number of fields...
    out += _pack_int2(nfields)
    # ...followed by every field, prefixed by its length; NULL values are
    # represented by a special marker with no data.
    for dumped in tx.dump_sequence(row, [PY_BINARY] * nfields):
        if dumped is None:
            out += _binary_null
            continue
        out += _pack_int4(len(dumped))
        out += dumped

    return out
def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
    """Parse a line of text copy data into a tuple of loaded values."""
    raw = data if isinstance(data, bytes) else bytes(data)

    fields = raw.split(b"\t")
    fields[-1] = fields[-1][:-1]  # drop \n

    # \N is the NULL marker; every other field has its escapes reverted.
    row = []
    for field in fields:
        if field == b"\\N":
            row.append(None)
        else:
            row.append(_load_re.sub(_load_sub, field))

    return tx.load_sequence(row)
def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
    """Parse a tuple of binary copy data into a tuple of loaded values."""
    # The tuple starts with a 2 bytes field count.
    (nfields,) = _unpack_int2(data, 0)
    offset = 2

    row: List[Optional[Buffer]] = []
    for _ in range(nfields):
        # Every field is prefixed by a 4 bytes length: a negative length
        # stands for a NULL, which carries no data.
        (length,) = _unpack_int4(data, offset)
        offset += 4
        if length < 0:
            row.append(None)
            continue
        row.append(data[offset : offset + length])
        offset += length

    return tx.load_sequence(row)
# Packers/unpackers for the integers of the binary copy format: network
# byte order (big-endian), signed, 2 bytes ("h") and 4 bytes ("i").
_pack_int2 = struct.Struct("!h").pack
_pack_int4 = struct.Struct("!i").pack
_unpack_int2 = struct.Struct("!h").unpack_from
_unpack_int4 = struct.Struct("!i").unpack_from
# Header written before the first row of a binary copy.
_binary_signature = (
    b"PGCOPY\n\xff\r\n\0"  # Signature
    b"\x00\x00\x00\x00"  # flags
    b"\x00\x00\x00\x00"  # extra length
)
# Marker closing a binary copy: it is the int2 value -1.
_binary_trailer = b"\xff\xff"
# Field length used for NULL values: it is the int4 value -1.
_binary_null = b"\xff\xff\xff\xff"
# Bytes which must be escaped when dumping a value in text copy format
_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]")

# Map from every special byte to its escaped representation
_dump_repl = {
    b"\b": b"\\b",
    b"\t": b"\\t",
    b"\n": b"\\n",
    b"\v": b"\\v",
    b"\f": b"\\f",
    b"\r": b"\\r",
    b"\\": b"\\\\",
}


def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes:
    """Regexp substitution callback: escape the special byte matched."""
    special = m.group(0)
    return __map[special]


# Escape sequences to revert when loading a value in text copy format
_load_re = re.compile(b"\\\\[btnvfr\\\\]")

# Map from every escape sequence back to the byte it represents
_load_repl = dict(zip(_dump_repl.values(), _dump_repl.keys()))


def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes:
    """Regexp substitution callback: unescape the sequence matched."""
    escaped = m.group(0)
    return __map[escaped]
# Use the pure Python implementations, unless the fast versions provided
# by the optimised module are available.
if not _psycopg:
    format_row_text = _format_row_text
    format_row_binary = _format_row_binary
    parse_row_text = _parse_row_text
    parse_row_binary = _parse_row_binary
else:
    format_row_text = _psycopg.format_row_text
    format_row_binary = _psycopg.format_row_binary
    parse_row_text = _psycopg.parse_row_text
    parse_row_binary = _psycopg.parse_row_binary

View File

@@ -0,0 +1,19 @@
"""
CockroachDB support package.
"""
# Copyright (C) 2022 The Psycopg Team
from . import _types
from .connection import CrdbConnection, AsyncCrdbConnection, CrdbConnectionInfo
# Package-level alias of the adapters map defined in the _types module.
adapters = _types.adapters  # exposed by the package
# Package-level shortcut for CrdbConnection.connect().
connect = CrdbConnection.connect
# Populate the adapters map when the package is imported.
_types.register_crdb_adapters(adapters)
__all__ = [
    "AsyncCrdbConnection",
    "CrdbConnection",
    "CrdbConnectionInfo",
]

View File

@@ -0,0 +1,163 @@
"""
Types configuration specific for CockroachDB.
"""
# Copyright (C) 2022 The Psycopg Team
from enum import Enum
from .._typeinfo import TypeInfo, TypesRegistry
from ..abc import AdaptContext, NoneType
from ..postgres import TEXT_OID
from .._adapters_map import AdaptersMap
from ..types.enum import EnumDumper, EnumBinaryDumper
from ..types.none import NoneDumper
# Registry of the types known for CockroachDB (populated below).
types = TypesRegistry()
# Global adapters map with CockroachDB types configuration
adapters = AdaptersMap(types=types)
class CrdbEnumDumper(EnumDumper):
    """Variant of `!EnumDumper` declaring the text oid for the dumped value."""

    oid = TEXT_OID
class CrdbEnumBinaryDumper(EnumBinaryDumper):
    """Variant of `!EnumBinaryDumper` declaring the text oid for the dumped value."""

    oid = TEXT_OID
class CrdbNoneDumper(NoneDumper):
    """Variant of `!NoneDumper` declaring the text oid for the dumped value."""

    oid = TEXT_OID
def register_postgres_adapters(context: AdaptContext) -> None:
    """Register on `!context` the adapters shared with PostgreSQL.

    They are the same adapters used by PostgreSQL, or a good starting point
    for customization.
    """
    from ..types import array, bool, composite, datetime
    from ..types import numeric, string, uuid

    # Registration order is the same the modules are listed in.
    for module in (array, bool, composite, datetime, numeric, string, uuid):
        module.register_default_adapters(context)
def register_crdb_adapters(context: AdaptContext) -> None:
    """Register on `!context` all the adapters used to talk to CockroachDB.

    :param context: the object whose adapters map will be configured (the
        package passes its global `!adapters` map at import time).
    """
    from .. import dbapi20
    from ..types import array

    register_postgres_adapters(context)

    # String must come after enum and none to map text oid -> string dumper:
    # the enum and none dumpers declare the text oid too, and, as the
    # ordering requirement implies, the dumper registered last for an oid is
    # the one found when looking up by that oid.
    register_crdb_none_adapters(context)
    register_crdb_enum_adapters(context)
    register_crdb_string_adapters(context)
    register_crdb_json_adapters(context)
    register_crdb_net_adapters(context)

    # Configure the context received, not the module-global map: the two
    # are the same object when called at package import, but callers passing
    # a different context expect that one to be fully configured.
    dbapi20.register_dbapi20_adapters(context)
    array.register_all_arrays(context)
def register_crdb_string_adapters(context: AdaptContext) -> None:
    """Register on `!context` the dumpers for `!str` objects."""
    from ..types import string

    # Dump strings with text oid instead of unknown.
    # Unlike PostgreSQL, CRDB seems able to cast text to most types.
    for dumper in (string.StrDumper, string.StrBinaryDumper):
        context.adapters.register_dumper(str, dumper)
def register_crdb_enum_adapters(context: AdaptContext) -> None:
    """Register on `!context` the dumpers for `!Enum` objects."""
    # The binary dumper is registered first, the text one last.
    for dumper in (CrdbEnumBinaryDumper, CrdbEnumDumper):
        context.adapters.register_dumper(Enum, dumper)
def register_crdb_json_adapters(context: AdaptContext) -> None:
    """Register on `!context` the dumpers and loaders for JSON data."""
    from ..types import json

    adapters = context.adapters

    # CRDB doesn't have json/jsonb: both names map to the jsonb oid
    for wrapper in (json.Json, json.Jsonb):
        for dumper in (json.JsonbBinaryDumper, json.JsonbDumper):
            adapters.register_dumper(wrapper, dumper)

    # Text loaders first, binary loaders last.
    loaders = (
        ("json", json.JsonLoader),
        ("jsonb", json.JsonbLoader),
        ("json", json.JsonBinaryLoader),
        ("jsonb", json.JsonbBinaryLoader),
    )
    for typename, loader in loaders:
        adapters.register_loader(typename, loader)
def register_crdb_net_adapters(context: AdaptContext) -> None:
    """Register on `!context` the dumpers and loaders for network addresses."""
    from ..types import net

    adapters = context.adapters

    # Python classes are specified by name. The registration order is the
    # one of the entries: text dumpers first, binary dumpers last.
    dumpers = (
        ("ipaddress.IPv4Address", net.InterfaceDumper),
        ("ipaddress.IPv6Address", net.InterfaceDumper),
        ("ipaddress.IPv4Interface", net.InterfaceDumper),
        ("ipaddress.IPv6Interface", net.InterfaceDumper),
        ("ipaddress.IPv4Address", net.AddressBinaryDumper),
        ("ipaddress.IPv6Address", net.AddressBinaryDumper),
        ("ipaddress.IPv4Interface", net.InterfaceBinaryDumper),
        ("ipaddress.IPv6Interface", net.InterfaceBinaryDumper),
        (None, net.InetBinaryDumper),
    )
    for target, dumper in dumpers:
        adapters.register_dumper(target, dumper)

    for loader in (net.InetLoader, net.InetBinaryLoader):
        adapters.register_loader("inet", loader)
def register_crdb_none_adapters(context: AdaptContext) -> None:
    """Register on `!context` the dumper for `!None`."""
    adapters = context.adapters
    adapters.register_dumper(NoneType, CrdbNoneDumper)
# Populate the registry: a few entries maintained by hand, followed by the
# entries generated from the database catalog.
for t in (
    TypeInfo("json", 3802, 3807, regtype="jsonb"),  # Alias json -> jsonb.
    TypeInfo("int8", 20, 1016, regtype="integer"),  # Alias integer -> int8
    TypeInfo('"char"', 18, 1002),  # special case, not generated
    # autogenerated: start
    # Generated from CockroachDB 22.1.0
    TypeInfo("bit", 1560, 1561),
    TypeInfo("bool", 16, 1000, regtype="boolean"),
    TypeInfo("bpchar", 1042, 1014, regtype="character"),
    TypeInfo("bytea", 17, 1001),
    TypeInfo("date", 1082, 1182),
    TypeInfo("float4", 700, 1021, regtype="real"),
    TypeInfo("float8", 701, 1022, regtype="double precision"),
    TypeInfo("inet", 869, 1041),
    TypeInfo("int2", 21, 1005, regtype="smallint"),
    TypeInfo("int2vector", 22, 1006),
    TypeInfo("int4", 23, 1007),
    TypeInfo("int8", 20, 1016, regtype="bigint"),
    TypeInfo("interval", 1186, 1187),
    TypeInfo("jsonb", 3802, 3807),
    TypeInfo("name", 19, 1003),
    TypeInfo("numeric", 1700, 1231),
    TypeInfo("oid", 26, 1028),
    TypeInfo("oidvector", 30, 1013),
    TypeInfo("record", 2249, 2287),
    TypeInfo("regclass", 2205, 2210),
    TypeInfo("regnamespace", 4089, 4090),
    TypeInfo("regproc", 24, 1008),
    TypeInfo("regprocedure", 2202, 2207),
    TypeInfo("regrole", 4096, 4097),
    TypeInfo("regtype", 2206, 2211),
    TypeInfo("text", 25, 1009),
    TypeInfo("time", 1083, 1183, regtype="time without time zone"),
    TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"),
    TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"),
    TypeInfo("timetz", 1266, 1270, regtype="time with time zone"),
    TypeInfo("unknown", 705, 0),
    TypeInfo("uuid", 2950, 2951),
    TypeInfo("varbit", 1562, 1563, regtype="bit varying"),
    TypeInfo("varchar", 1043, 1015, regtype="character varying"),
    # autogenerated: end
):
    types.add(t)

View File

@@ -0,0 +1,180 @@
"""
CockroachDB-specific connections.
"""
# Copyright (C) 2022 The Psycopg Team
import re
from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING
from .. import errors as e
from ..abc import AdaptContext
from ..rows import Row, RowFactory, AsyncRowFactory, TupleRow
from .._compat import Self
from ..connection import Connection
from .._adapters_map import AdaptersMap
from .._connection_info import ConnectionInfo
from ..connection_async import AsyncConnection
from ._types import adapters
if TYPE_CHECKING:
from ..pq.abc import PGconn
from ..cursor import Cursor
from ..cursor_async import AsyncCursor
class _CrdbConnectionMixin:
    """Behaviour shared by the sync and async CockroachDB connections."""

    _adapters: Optional[AdaptersMap]
    pgconn: "PGconn"

    @classmethod
    def is_crdb(
        cls, conn: Union[Connection[Any], AsyncConnection[Any], "PGconn"]
    ) -> bool:
        """
        Return `!True` if the server connected to `!conn` is CockroachDB.
        """
        # Accept both a connection object and a libpq connection.
        pgconn = (
            conn.pgconn if isinstance(conn, (Connection, AsyncConnection)) else conn
        )
        version = pgconn.parameter_status(b"crdb_version")
        return bool(version)

    @property
    def adapters(self) -> AdaptersMap:
        """The adapters map of the connection, created on first access."""
        current = self._adapters
        if not current:
            # By default, use CockroachDB adapters map
            current = self._adapters = AdaptersMap(adapters)
        return current

    @property
    def info(self) -> "CrdbConnectionInfo":
        """A `CrdbConnectionInfo` object wrapping the libpq connection."""
        return CrdbConnectionInfo(self.pgconn)

    def _check_tpc(self) -> None:
        """Raise `!NotSupportedError` if the server is CockroachDB."""
        if not self.is_crdb(self.pgconn):
            return
        raise e.NotSupportedError("CockroachDB doesn't support prepared statements")
class CrdbConnection(_CrdbConnectionMixin, Connection[Row]):
    """
    Wrapper for a connection to a CockroachDB database.
    """

    __module__ = "psycopg.crdb"

    # TODO: this method shouldn't require re-definition if the base class
    # implements a generic self.
    # https://github.com/psycopg/psycopg/issues/308

    # Typing overload: a row factory is specified, and it defines the type
    # of the rows returned.
    @overload
    @classmethod
    def connect(
        cls,
        conninfo: str = "",
        *,
        autocommit: bool = False,
        row_factory: RowFactory[Row],
        prepare_threshold: Optional[int] = 5,
        cursor_factory: "Optional[Type[Cursor[Row]]]" = None,
        context: Optional[AdaptContext] = None,
        **kwargs: Union[None, int, str],
    ) -> "CrdbConnection[Row]": ...

    # Typing overload: no row factory specified, the rows are tuples.
    @overload
    @classmethod
    def connect(
        cls,
        conninfo: str = "",
        *,
        autocommit: bool = False,
        prepare_threshold: Optional[int] = 5,
        cursor_factory: "Optional[Type[Cursor[Any]]]" = None,
        context: Optional[AdaptContext] = None,
        **kwargs: Union[None, int, str],
    ) -> "CrdbConnection[TupleRow]": ...

    @classmethod
    def connect(cls, conninfo: str = "", **kwargs: Any) -> Self:
        """
        Connect to a database server and return a new `CrdbConnection` instance.
        """
        # All the work is done by the base class implementation.
        return super().connect(conninfo, **kwargs)  # type: ignore[return-value]
class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]):
    """
    Wrapper for an async connection to a CockroachDB database.
    """

    __module__ = "psycopg.crdb"

    # TODO: this method shouldn't require re-definition if the base class
    # implements a generic self.
    # https://github.com/psycopg/psycopg/issues/308

    # Typing overload: a row factory is specified, and it defines the type
    # of the rows returned.
    @overload
    @classmethod
    async def connect(
        cls,
        conninfo: str = "",
        *,
        autocommit: bool = False,
        prepare_threshold: Optional[int] = 5,
        row_factory: AsyncRowFactory[Row],
        cursor_factory: "Optional[Type[AsyncCursor[Row]]]" = None,
        context: Optional[AdaptContext] = None,
        **kwargs: Union[None, int, str],
    ) -> "AsyncCrdbConnection[Row]": ...

    # Typing overload: no row factory specified, the rows are tuples.
    @overload
    @classmethod
    async def connect(
        cls,
        conninfo: str = "",
        *,
        autocommit: bool = False,
        prepare_threshold: Optional[int] = 5,
        cursor_factory: "Optional[Type[AsyncCursor[Any]]]" = None,
        context: Optional[AdaptContext] = None,
        **kwargs: Union[None, int, str],
    ) -> "AsyncCrdbConnection[TupleRow]": ...

    @classmethod
    async def connect(cls, conninfo: str = "", **kwargs: Any) -> Self:
        """
        Connect to a database server and return a new `AsyncCrdbConnection`
        instance.
        """
        # All the work is done by the base class implementation.
        return await super().connect(conninfo, **kwargs)  # type: ignore[no-any-return]
class CrdbConnectionInfo(ConnectionInfo):
    """
    `~psycopg.ConnectionInfo` subclass to get info about a CockroachDB database.
    """

    __module__ = "psycopg.crdb"

    @property
    def vendor(self) -> str:
        """The name of the database vendor: always "CockroachDB"."""
        return "CockroachDB"

    @property
    def server_version(self) -> int:
        """
        Return the CockroachDB server version connected.

        Return a number in the PostgreSQL format (e.g. 21.2.10 -> 210210).

        :raises InternalError: if the server didn't report the
            ``crdb_version`` parameter.
        :raises InterfaceError: if the version reported can't be parsed.
        """
        sver = self.parameter_status("crdb_version")
        if not sver:
            raise e.InternalError("'crdb_version' parameter status not set")

        ver = self.parse_crdb_version(sver)
        if ver is None:
            raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}")

        return ver

    @classmethod
    def parse_crdb_version(cls, sver: str) -> Optional[int]:
        """
        Parse a CockroachDB version string into a number.

        The version is searched in `!sver` as a ``vX.Y.Z`` token and it is
        returned as ``X * 10000 + Y * 100 + Z``.

        :param sver: the string to parse.
        :return: the version number, or `!None` if no version is found.
        """
        m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
        if not m:
            return None

        major, minor, patch = (int(part) for part in m.groups())
        return major * 10000 + minor * 100 + patch

View File

@@ -0,0 +1,924 @@
"""
psycopg cursor objects
"""
# Copyright (C) 2020 The Psycopg Team
from functools import partial
from types import TracebackType
from typing import Any, Generic, Iterable, Iterator, List
from typing import Optional, NoReturn, Sequence, Tuple, Type
from typing import overload, TYPE_CHECKING
from warnings import warn
from contextlib import contextmanager
from . import pq
from . import adapt
from . import errors as e
from .abc import ConnectionType, Query, Params, PQGen
from .copy import Copy, Writer as CopyWriter
from .rows import Row, RowMaker, RowFactory
from ._column import Column
from ._compat import Self
from .pq.misc import connection_summary
from ._queries import PostgresQuery, PostgresClientQuery
from ._pipeline import Pipeline
from ._encodings import pgconn_encoding
from ._preparing import Prepare
from .generators import execute, fetch, send
if TYPE_CHECKING:
from .abc import Transformer
from .pq.abc import PGconn, PGresult
from .connection import Connection
# Module-level short names for the `pq` enum members used in this module.

# Result/parameter formats
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY
# Result statuses
EMPTY_QUERY = pq.ExecStatus.EMPTY_QUERY
COMMAND_OK = pq.ExecStatus.COMMAND_OK
TUPLES_OK = pq.ExecStatus.TUPLES_OK
COPY_OUT = pq.ExecStatus.COPY_OUT
COPY_IN = pq.ExecStatus.COPY_IN
COPY_BOTH = pq.ExecStatus.COPY_BOTH
FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
# Transaction status checked by stream() to decide whether to cancel
ACTIVE = pq.TransactionStatus.ACTIVE
class BaseCursor(Generic[ConnectionType, Row]):
    """
    State and protocol logic shared by the sync and async cursors.

    The database operations are written as generators (the ``*_gen``
    methods); concrete cursors drive them through their connection's
    `!wait()` and must implement `_make_row_maker()`.
    """

    # The names are obtained by splitting on whitespace, so the layout of
    # the string is not significant.
    __slots__ = """
        _conn format _adapters arraysize _closed _results pgresult _pos
        _iresult _rowcount _query _tx _last_query _row_factory _make_row
        _pgconn _execmany_returning
        __weakref__
        """.split()

    ExecStatus = pq.ExecStatus

    _tx: "Transformer"
    _make_row: RowMaker[Row]
    _pgconn: "PGconn"

    def __init__(self, connection: ConnectionType):
        """Bind the cursor to `!connection` and initialise its state."""
        self._conn = connection
        self.format = TEXT
        self._pgconn = connection.pgconn
        # Cursor-level adapters map, layered on the connection's one.
        self._adapters = adapt.AdaptersMap(connection.adapters)
        self.arraysize = 1
        self._closed = False
        self._last_query: Optional[Query] = None
        self._reset()

    def _reset(self, reset_query: bool = True) -> None:
        """
        Forget the results of the previous operation.

        If `!reset_query` is false the last `!_query` is left untouched.
        """
        self._results: List["PGresult"] = []
        self.pgresult: Optional["PGresult"] = None
        self._pos = 0
        self._iresult = 0
        self._rowcount = -1
        self._query: Optional[PostgresQuery]
        # None if executemany() not executing, True/False according to returning state
        self._execmany_returning: Optional[bool] = None
        if reset_query:
            self._query = None

    def __repr__(self) -> str:
        name = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        summary = connection_summary(self._pgconn)
        if self._closed:
            state = "closed"
        elif self.pgresult:
            state = pq.ExecStatus(self.pgresult.status).name
        else:
            state = "no result"
        return f"<{name} [{state}] {summary} at 0x{id(self):x}>"

    @property
    def connection(self) -> ConnectionType:
        """The connection this cursor is using."""
        return self._conn

    @property
    def adapters(self) -> adapt.AdaptersMap:
        """The adapters map of this cursor."""
        return self._adapters

    @property
    def closed(self) -> bool:
        """`True` if the cursor is closed."""
        return self._closed

    @property
    def description(self) -> Optional[List[Column]]:
        """
        A list of `Column` objects describing the current resultset.

        `!None` if the current resultset didn't return tuples.
        """
        res = self.pgresult
        if not res:
            return None

        # Columns are returned if there are fields, but also if there are
        # none and the status says we got tuples (e.g. for "SELECT ;").
        if res.nfields or res.status == TUPLES_OK or res.status == SINGLE_TUPLE:
            return [Column(self, i) for i in range(res.nfields)]

        return None

    @property
    def rowcount(self) -> int:
        """Number of records affected by the precedent operation."""
        return self._rowcount

    @property
    def rownumber(self) -> Optional[int]:
        """Index of the next row to fetch in the current result.

        `!None` if there is no result to fetch.
        """
        res = self.pgresult
        if res and res.status == TUPLES_OK:
            return self._pos
        return None

    def setinputsizes(self, sizes: Sequence[Any]) -> None:
        """Do nothing: the arguments are ignored."""
        pass

    def setoutputsize(self, size: Any, column: Optional[int] = None) -> None:
        """Do nothing: the arguments are ignored."""
        pass

    def nextset(self) -> Optional[bool]:
        """
        Move to the result set of the next query executed through `executemany()`
        or to the next result set if `execute()` returned more than one.

        Return `!True` if a new result is available, which will be the one
        methods `!fetch*()` will operate on.
        """
        # Warn whoever calls nextset() after a sequence of execute() in
        # pipeline mode: accumulating the execute() results in the cursor is
        # an unintended difference w.r.t. non-pipeline mode.
        if self._execmany_returning is None and self._conn._pipeline:
            warn(
                "using nextset() in pipeline mode for several execute() is"
                " deprecated and will be dropped in 3.2; please use different"
                " cursors to receive more than one result",
                DeprecationWarning,
            )

        following = self._iresult + 1
        if following >= len(self._results):
            return None

        self._select_current_result(following)
        return True

    @property
    def statusmessage(self) -> Optional[str]:
        """
        The command status tag from the last SQL command executed.

        `!None` if the cursor doesn't have a result available.
        """
        res = self.pgresult
        if not res:
            return None
        tag = res.command_status
        return tag.decode() if tag else None

    def _make_row_maker(self) -> RowMaker[Row]:
        """Return the function used to build rows; subclasses must override."""
        raise NotImplementedError

    #
    # Generators for the high level operations on the cursor
    #
    # Like for sync/async connections, these are implemented as generators
    # so that different concurrency strategies (threads,asyncio) can use their
    # own way of waiting (or better, `connection.wait()`).
    #

    def _execute_gen(
        self,
        query: Query,
        params: Optional[Params] = None,
        *,
        prepare: Optional[bool] = None,
        binary: Optional[bool] = None,
    ) -> PQGen[None]:
        """Generator implementing `Cursor.execute()`."""
        yield from self._start_query(query)
        pgq = self._convert_query(query, params)
        results = yield from self._maybe_prepare_gen(
            pgq, prepare=prepare, binary=binary
        )
        if self._conn._pipeline:
            # In pipeline mode the results are not available here.
            yield from self._conn._pipeline._communicate_gen()
        else:
            assert results is not None
            self._check_results(results)
            self._results = results
            self._select_current_result(0)

        self._last_query = query

        for cmd in self._conn._prepared.get_maintenance_commands():
            yield from self._conn._exec_command(cmd)

    def _executemany_gen_pipeline(
        self, query: Query, params_seq: Iterable[Params], returning: bool
    ) -> PQGen[None]:
        """
        Generator implementing `Cursor.executemany()` with pipelines available.
        """
        pipeline = self._conn._pipeline
        assert pipeline

        yield from self._start_query(query)
        if not returning:
            self._rowcount = 0

        assert self._execmany_returning is None
        self._execmany_returning = returning

        first = True
        for params in params_seq:
            if first:
                # Convert the query only once, then just dump new params.
                pgq = self._convert_query(query, params)
                self._query = pgq
                first = False
            else:
                pgq.dump(params)

            yield from self._maybe_prepare_gen(pgq, prepare=True)
            yield from pipeline._communicate_gen()

        self._last_query = query

        if returning:
            yield from pipeline._fetch_gen(flush=True)

        for cmd in self._conn._prepared.get_maintenance_commands():
            yield from self._conn._exec_command(cmd)

    def _executemany_gen_no_pipeline(
        self, query: Query, params_seq: Iterable[Params], returning: bool
    ) -> PQGen[None]:
        """
        Generator implementing `Cursor.executemany()` with pipelines not available.
        """
        yield from self._start_query(query)
        if not returning:
            self._rowcount = 0

        first = True
        for params in params_seq:
            if first:
                # Convert the query only once, then just dump new params.
                pgq = self._convert_query(query, params)
                self._query = pgq
                first = False
            else:
                pgq.dump(params)

            results = yield from self._maybe_prepare_gen(pgq, prepare=True)
            assert results is not None
            self._check_results(results)
            if returning:
                self._results.extend(results)
            else:
                # In non-returning case, set rowcount to the cumulated number
                # of rows of executed queries.
                for res in results:
                    self._rowcount += res.command_tuples or 0

        if self._results:
            self._select_current_result(0)

        self._last_query = query

        for cmd in self._conn._prepared.get_maintenance_commands():
            yield from self._conn._exec_command(cmd)

    def _maybe_prepare_gen(
        self,
        pgq: PostgresQuery,
        *,
        prepare: Optional[bool] = None,
        binary: Optional[bool] = None,
    ) -> PQGen[Optional[List["PGresult"]]]:
        """
        Generator sending a query, preparing it first if needed.

        Return the list of results, or `!None` in pipeline mode, where the
        request is only queued.
        """
        # Check if the query is prepared or needs preparing
        prep, name = self._get_prepared(pgq, prepare)
        if prep is Prepare.NO:
            # The query must be executed without preparing
            self._execute_send(pgq, binary=binary)
        else:
            # If the query is not already prepared, prepare it.
            if prep is Prepare.SHOULD:
                self._send_prepare(name, pgq)
                if not self._conn._pipeline:
                    (result,) = yield from execute(self._pgconn)
                    if result.status == FATAL_ERROR:
                        raise e.error_from_result(result, encoding=self._encoding)
            # Then execute it.
            self._send_query_prepared(name, pgq, binary=binary)

        # Update the prepare state of the query.
        # If an operation requires to flush our prepared statements cache,
        # it will be added to the maintenance commands to execute later.
        key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name)

        if self._conn._pipeline:
            queued = (key, prep, name) if key is not None else None
            self._conn._pipeline.result_queue.append((self, queued))
            return None

        # run the query
        results = yield from execute(self._pgconn)

        if key is not None:
            self._conn._prepared.validate(key, prep, name, results)

        return results

    def _get_prepared(
        self, pgq: PostgresQuery, prepare: Optional[bool] = None
    ) -> Tuple[Prepare, bytes]:
        """Ask the connection's prepare manager how to run `!pgq`."""
        return self._conn._prepared.get(pgq, prepare)

    def _stream_send_gen(
        self,
        query: Query,
        params: Optional[Params] = None,
        *,
        binary: Optional[bool] = None,
    ) -> PQGen[None]:
        """Generator to send the query for `Cursor.stream()`."""
        yield from self._start_query(query)
        pgq = self._convert_query(query, params)
        self._execute_send(pgq, binary=binary, force_extended=True)
        self._pgconn.set_single_row_mode()
        self._last_query = query
        yield from send(self._pgconn)

    def _stream_fetchone_gen(self, first: bool) -> PQGen[Optional["PGresult"]]:
        """
        Generator fetching the next single-row result for `Cursor.stream()`.

        Return the result holding the row, or `!None` when no more rows are
        available. `!first` must be true on the first call for a query: the
        loaders and the row maker are set up only then.
        """
        res = yield from fetch(self._pgconn)
        if res is None:
            return None

        status = res.status
        if status == SINGLE_TUPLE:
            self.pgresult = res
            self._tx.set_pgresult(res, set_loaders=first)
            if first:
                self._make_row = self._make_row_maker()
            return res

        if status == TUPLES_OK or status == COMMAND_OK:
            # End of single row results
            while res:
                res = yield from fetch(self._pgconn)
            if status != TUPLES_OK:
                raise e.ProgrammingError(
                    "the operation in stream() didn't produce a result"
                )
            return None

        # Errors, unexpected values
        return self._raise_for_result(res)

    def _start_query(self, query: Optional[Query] = None) -> PQGen[None]:
        """Generator to start the processing of a query.

        It is implemented as generator because it may send additional queries,
        such as `begin`.
        """
        if self.closed:
            raise e.InterfaceError("the cursor is closed")

        self._reset()
        # A new transformer is created unless the very same query object is
        # being executed again.
        if not self._last_query or (self._last_query is not query):
            self._last_query = None
            self._tx = adapt.Transformer(self)
        yield from self._conn._start_query()

    def _start_copy_gen(
        self, statement: Query, params: Optional[Params] = None
    ) -> PQGen[None]:
        """Generator implementing sending a command for `Cursor.copy()."""
        # The connection gets in an unrecoverable state if we attempt COPY in
        # pipeline mode. Forbid it explicitly.
        if self._conn._pipeline:
            raise e.NotSupportedError("COPY cannot be used in pipeline mode")

        yield from self._start_query()

        # Merge the params client-side
        if params:
            client_query = PostgresClientQuery(self._tx)
            client_query.convert(statement, params)
            statement = client_query.query

        query = self._convert_query(statement)

        self._execute_send(query, binary=False)
        results = yield from execute(self._pgconn)
        if len(results) != 1:
            raise e.ProgrammingError("COPY cannot be mixed with other operations")

        self._check_copy_result(results[0])
        self._results = results
        self._select_current_result(0)

    def _execute_send(
        self,
        query: PostgresQuery,
        *,
        force_extended: bool = False,
        binary: Optional[bool] = None,
    ) -> None:
        """
        Implement part of execute() before waiting common to sync and async.

        This is not a generator, but a normal non-blocking function.
        """
        if binary is None:
            fmt = self.format
        elif binary:
            fmt = BINARY
        else:
            fmt = TEXT

        self._query = query

        if self._conn._pipeline:
            # In pipeline mode always use PQsendQueryParams - see #314
            # Multiple statements in the same query are not allowed anyway.
            self._conn._pipeline.command_queue.append(
                partial(
                    self._pgconn.send_query_params,
                    query.query,
                    query.params,
                    param_formats=query.formats,
                    param_types=query.types,
                    result_format=fmt,
                )
            )
        elif force_extended or query.params or fmt == BINARY:
            self._pgconn.send_query_params(
                query.query,
                query.params,
                param_formats=query.formats,
                param_types=query.types,
                result_format=fmt,
            )
        else:
            # If we can, let's use simple query protocol,
            # as it can execute more than one statement in a single query.
            self._pgconn.send_query(query.query)

    def _convert_query(
        self, query: Query, params: Optional[Params] = None
    ) -> PostgresQuery:
        """Return a `!PostgresQuery` built from `!query` and `!params`."""
        pgq = PostgresQuery(self._tx)
        pgq.convert(query, params)
        return pgq

    def _check_results(self, results: List["PGresult"]) -> None:
        """
        Verify that the results of a query are valid.

        Verify that the query returned at least one result and that they all
        represent a valid result from the database.
        """
        if not results:
            raise e.InternalError("got no result from the query")

        for res in results:
            status = res.status
            valid = (
                status == TUPLES_OK or status == COMMAND_OK or status == EMPTY_QUERY
            )
            if not valid:
                self._raise_for_result(res)

    def _raise_for_result(self, result: "PGresult") -> NoReturn:
        """
        Raise an appropriate error message for an unexpected database result
        """
        status = result.status
        if status == FATAL_ERROR:
            raise e.error_from_result(result, encoding=self._encoding)

        if status == PIPELINE_ABORTED:
            raise e.PipelineAborted("pipeline aborted")

        if status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
            raise e.ProgrammingError(
                "COPY cannot be used with this method; use copy() instead"
            )

        raise e.InternalError(
            "unexpected result status from query:" f" {pq.ExecStatus(status).name}"
        )

    def _select_current_result(
        self, i: int, format: Optional[pq.Format] = None
    ) -> None:
        """
        Select one of the results in the cursor as the active one.
        """
        self._iresult = i
        res = self.pgresult = self._results[i]

        # Note: the only reason to override format is to correctly set
        # binary loaders on server-side cursors, because send_describe_portal
        # only returns a text result.
        self._tx.set_pgresult(res, format=format)

        self._pos = 0

        if res.status == TUPLES_OK:
            self._rowcount = res.ntuples

        # COPY_OUT has never info about nrows. We need such result for the
        # columns in order to return a `description`, but not overwrite the
        # cursor rowcount (which was set by the Copy object).
        elif res.status != COPY_OUT:
            nrows = res.command_tuples
            self._rowcount = nrows if nrows is not None else -1

        self._make_row = self._make_row_maker()

    def _set_results_from_pipeline(self, results: List["PGresult"]) -> None:
        """Store on the cursor a batch of results received from a pipeline."""
        self._check_results(results)
        first_batch = not self._results

        # Results are accumulated when received from execute() (state is
        # None) or from an executemany() called with returning=True.
        if self._execmany_returning is None or self._execmany_returning:
            self._results.extend(results)
            if first_batch:
                self._select_current_result(0)
        else:
            # In non-returning case, set rowcount to the cumulated number of
            # rows of executed queries.
            for res in results:
                self._rowcount += res.command_tuples or 0

    def _send_prepare(self, name: bytes, query: PostgresQuery) -> None:
        """Send, or queue in pipeline mode, the request to prepare `!query`."""
        if not self._conn._pipeline:
            self._pgconn.send_prepare(name, query.query, param_types=query.types)
            return

        self._conn._pipeline.command_queue.append(
            partial(
                self._pgconn.send_prepare,
                name,
                query.query,
                param_types=query.types,
            )
        )
        self._conn._pipeline.result_queue.append(None)

    def _send_query_prepared(
        self, name: bytes, pgq: PostgresQuery, *, binary: Optional[bool] = None
    ) -> None:
        """Send, or queue in pipeline mode, the execution of statement `!name`."""
        if binary is None:
            fmt = self.format
        elif binary:
            fmt = BINARY
        else:
            fmt = TEXT

        if self._conn._pipeline:
            self._conn._pipeline.command_queue.append(
                partial(
                    self._pgconn.send_query_prepared,
                    name,
                    pgq.params,
                    param_formats=pgq.formats,
                    result_format=fmt,
                )
            )
        else:
            self._pgconn.send_query_prepared(
                name, pgq.params, param_formats=pgq.formats, result_format=fmt
            )

    def _check_result_for_fetch(self) -> None:
        """Raise an exception if rows can't be fetched from the current result."""
        if self.closed:
            raise e.InterfaceError("the cursor is closed")

        res = self.pgresult
        if not res:
            raise e.ProgrammingError("no result available")

        status = res.status
        if status == TUPLES_OK:
            return

        if status == FATAL_ERROR:
            raise e.error_from_result(res, encoding=self._encoding)

        if status == PIPELINE_ABORTED:
            raise e.PipelineAborted("pipeline aborted")

        raise e.ProgrammingError("the last operation didn't produce a result")

    def _check_copy_result(self, result: "PGresult") -> None:
        """
        Check that the value returned in a copy() operation is a legit COPY.
        """
        status = result.status
        if status == COPY_IN or status == COPY_OUT:
            return

        if status == FATAL_ERROR:
            raise e.error_from_result(result, encoding=self._encoding)

        raise e.ProgrammingError(
            "copy() should be used only with COPY ... TO STDOUT or COPY ..."
            f" FROM STDIN statements, got {pq.ExecStatus(status).name}"
        )

    def _scroll(self, value: int, mode: str) -> None:
        """
        Change the position in the current result.

        Raise `!ValueError` on unknown `!mode`, `!IndexError` if the new
        position would be outside the result; the position is unchanged then.
        """
        self._check_result_for_fetch()
        assert self.pgresult
        if mode == "relative":
            target = self._pos + value
        elif mode == "absolute":
            target = value
        else:
            raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")

        if target < 0 or target >= self.pgresult.ntuples:
            raise IndexError("position out of bound")

        self._pos = target

    def _close(self) -> None:
        """Non-blocking part of closing. Common to sync/async."""
        # Don't reset the query because it may be useful to investigate after
        # an error.
        self._reset(reset_query=False)
        self._closed = True

    @property
    def _encoding(self) -> str:
        """The Python encoding name obtained from the underlying connection."""
        return pgconn_encoding(self._pgconn)
class Cursor(BaseCursor["Connection[Any]", Row]):
    """
    Cursor for a synchronous connection.

    Every operation runs the corresponding `BaseCursor` generator through
    the connection's `!wait()`, holding the connection lock.
    """

    __module__ = "psycopg"
    __slots__ = ()

    @overload
    def __init__(self, connection: "Connection[Row]"): ...

    @overload
    def __init__(
        self, connection: "Connection[Any]", *, row_factory: RowFactory[Row]
    ): ...

    def __init__(
        self,
        connection: "Connection[Any]",
        *,
        row_factory: Optional[RowFactory[Row]] = None,
    ):
        """
        Create a cursor on `!connection`.

        If `!row_factory` is not given the connection's one is used.
        """
        super().__init__(connection)
        self._row_factory = row_factory or connection.row_factory

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # The cursor is closed whether or not the block raised.
        self.close()

    def close(self) -> None:
        """
        Close the current cursor and free associated resources.
        """
        self._close()

    @property
    def row_factory(self) -> RowFactory[Row]:
        """Writable attribute to control how result rows are formed."""
        return self._row_factory

    @row_factory.setter
    def row_factory(self, row_factory: RowFactory[Row]) -> None:
        self._row_factory = row_factory
        # If a result is already loaded, rebuild the row maker right away.
        if self.pgresult:
            self._make_row = row_factory(self)

    def _make_row_maker(self) -> RowMaker[Row]:
        """Build the row maker by calling the current row factory."""
        return self._row_factory(self)

    def execute(
        self,
        query: Query,
        params: Optional[Params] = None,
        *,
        prepare: Optional[bool] = None,
        binary: Optional[bool] = None,
    ) -> Self:
        """
        Execute a query or command to the database.
        """
        try:
            gen = self._execute_gen(query, params, prepare=prepare, binary=binary)
            with self._conn.lock:
                self._conn.wait(gen)
        except e._NO_TRACEBACK as ex:
            # Re-raise the error dropping its traceback.
            raise ex.with_traceback(None)
        return self

    def executemany(
        self,
        query: Query,
        params_seq: Iterable[Params],
        *,
        returning: bool = False,
    ) -> None:
        """
        Execute the same command with a sequence of input data.
        """
        try:
            if not Pipeline.is_supported():
                with self._conn.lock:
                    self._conn.wait(
                        self._executemany_gen_no_pipeline(query, params_seq, returning)
                    )
            else:
                # If there is already a pipeline, ride it, in order to avoid
                # sending unnecessary Sync.
                with self._conn.lock:
                    active = self._conn._pipeline
                    if active:
                        self._conn.wait(
                            self._executemany_gen_pipeline(query, params_seq, returning)
                        )
                # Otherwise, make a new one
                if not active:
                    with self._conn.pipeline(), self._conn.lock:
                        self._conn.wait(
                            self._executemany_gen_pipeline(query, params_seq, returning)
                        )
        except e._NO_TRACEBACK as ex:
            # Re-raise the error dropping its traceback.
            raise ex.with_traceback(None)

    def stream(
        self,
        query: Query,
        params: Optional[Params] = None,
        *,
        binary: Optional[bool] = None,
    ) -> Iterator[Row]:
        """
        Iterate row-by-row on a result from the database.
        """
        if self._pgconn.pipeline_status:
            raise e.ProgrammingError("stream() cannot be used in pipeline mode")

        with self._conn.lock:
            try:
                self._conn.wait(self._stream_send_gen(query, params, binary=binary))
                first = True
                while self._conn.wait(self._stream_fetchone_gen(first)):
                    # We know that, if we got a result, it has a single row.
                    rec: Row = self._tx.load_row(0, self._make_row)  # type: ignore
                    yield rec
                    first = False

            except e._NO_TRACEBACK as ex:
                raise ex.with_traceback(None)

            finally:
                if self._pgconn.transaction_status == ACTIVE:
                    # Try to cancel the query, then consume the results
                    # already received.
                    self._conn.cancel()
                    try:
                        while self._conn.wait(self._stream_fetchone_gen(first=False)):
                            pass
                    except Exception:
                        # Best effort: errors while draining are ignored.
                        pass

                    # Try to get out of ACTIVE state. Just do a single attempt, which
                    # should work to recover from an error or query cancelled.
                    try:
                        self._conn.wait(self._stream_fetchone_gen(first=False))
                    except Exception:
                        pass

    def fetchone(self) -> Optional[Row]:
        """
        Return the next record from the current recordset.

        Return `!None` if the recordset is finished.

        :rtype: Optional[Row], with Row defined by `row_factory`
        """
        self._fetch_pipeline()
        self._check_result_for_fetch()
        row = self._tx.load_row(self._pos, self._make_row)
        if row is None:
            return row
        self._pos += 1
        return row

    def fetchmany(self, size: int = 0) -> List[Row]:
        """
        Return the next `!size` records from the current recordset.

        `!size` default to `!self.arraysize` if not specified.

        :rtype: Sequence[Row], with Row defined by `row_factory`
        """
        self._fetch_pipeline()
        self._check_result_for_fetch()
        assert self.pgresult

        size = size or self.arraysize
        # Don't read past the end of the result.
        stop = min(self._pos + size, self.pgresult.ntuples)
        rows = self._tx.load_rows(self._pos, stop, self._make_row)
        self._pos += len(rows)
        return rows

    def fetchall(self) -> List[Row]:
        """
        Return all the remaining records from the current recordset.

        :rtype: Sequence[Row], with Row defined by `row_factory`
        """
        self._fetch_pipeline()
        self._check_result_for_fetch()
        assert self.pgresult
        rows = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
        self._pos = self.pgresult.ntuples
        return rows

    def __iter__(self) -> Iterator[Row]:
        """Iterate on the remaining rows, advancing the cursor position."""
        self._fetch_pipeline()
        self._check_result_for_fetch()

        while True:
            row = self._tx.load_row(self._pos, self._make_row)
            if row is None:
                break
            self._pos += 1
            yield row

    def scroll(self, value: int, mode: str = "relative") -> None:
        """
        Move the cursor in the result set to a new position according to mode.

        If `!mode` is ``'relative'`` (default), `!value` is taken as offset to
        the current position in the result set; if set to ``'absolute'``,
        `!value` states an absolute target position.

        Raise `!IndexError` in case a scroll operation would leave the result
        set. In this case the position will not change.
        """
        self._fetch_pipeline()
        self._scroll(value, mode)

    @contextmanager
    def copy(
        self,
        statement: Query,
        params: Optional[Params] = None,
        *,
        writer: Optional[CopyWriter] = None,
    ) -> Iterator[Copy]:
        """
        Initiate a :sql:`COPY` operation and return an object to manage it.

        :rtype: Copy
        """
        try:
            with self._conn.lock:
                self._conn.wait(self._start_copy_gen(statement, params))

            with Copy(self, writer=writer) as copy:
                yield copy
        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)

        # If a fresher result has been set on the cursor by the Copy object,
        # read its properties (especially rowcount).
        self._select_current_result(0)

    def _fetch_pipeline(self) -> None:
        """
        Run the pipeline's fetch generator if results may still be pending.

        It only happens in pipeline mode, when the cursor has no result and
        it is not running a non-returning `executemany()`.
        """
        if (
            self._execmany_returning is not False
            and not self.pgresult
            and self._conn._pipeline
        ):
            with self._conn.lock:
                self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))

View File

@@ -0,0 +1,246 @@
"""
psycopg async cursor objects
"""
# Copyright (C) 2020 The Psycopg Team
from types import TracebackType
from typing import Any, AsyncIterator, Iterable, List
from typing import Optional, Type, TYPE_CHECKING, overload
from contextlib import asynccontextmanager
from . import pq
from . import errors as e
from .abc import Query, Params
from .copy import AsyncCopy, AsyncWriter as AsyncCopyWriter
from .rows import Row, RowMaker, AsyncRowFactory
from .cursor import BaseCursor
from ._compat import Self
from ._pipeline import Pipeline
if TYPE_CHECKING:
from .connection_async import AsyncConnection
# Short name for the transaction status checked by AsyncCursor.stream()
ACTIVE = pq.TransactionStatus.ACTIVE
class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
    """
    Cursor for an asynchronous connection.

    Every operation awaits the corresponding `BaseCursor` generator through
    the connection's `!wait()`, holding the connection lock.
    """

    __module__ = "psycopg"
    __slots__ = ()

    @overload
    def __init__(self, connection: "AsyncConnection[Row]"): ...

    @overload
    def __init__(
        self, connection: "AsyncConnection[Any]", *, row_factory: AsyncRowFactory[Row]
    ): ...

    def __init__(
        self,
        connection: "AsyncConnection[Any]",
        *,
        row_factory: Optional[AsyncRowFactory[Row]] = None,
    ):
        """
        Create a cursor on `!connection`.

        If `!row_factory` is not given the connection's one is used.
        """
        super().__init__(connection)
        self._row_factory = row_factory or connection.row_factory

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # The cursor is closed whether or not the block raised.
        await self.close()

    async def close(self) -> None:
        """Close the current cursor and reset its results."""
        self._close()

    @property
    def row_factory(self) -> AsyncRowFactory[Row]:
        """Writable attribute to control how result rows are formed."""
        return self._row_factory

    @row_factory.setter
    def row_factory(self, row_factory: AsyncRowFactory[Row]) -> None:
        self._row_factory = row_factory
        # If a result is already loaded, rebuild the row maker right away.
        if self.pgresult:
            self._make_row = row_factory(self)

    def _make_row_maker(self) -> RowMaker[Row]:
        """Build the row maker by calling the current row factory."""
        return self._row_factory(self)

    async def execute(
        self,
        query: Query,
        params: Optional[Params] = None,
        *,
        prepare: Optional[bool] = None,
        binary: Optional[bool] = None,
    ) -> Self:
        """Execute a query or command to the database; return the cursor."""
        try:
            gen = self._execute_gen(query, params, prepare=prepare, binary=binary)
            async with self._conn.lock:
                await self._conn.wait(gen)
        except e._NO_TRACEBACK as ex:
            # Re-raise the error dropping its traceback.
            raise ex.with_traceback(None)
        return self

    async def executemany(
        self,
        query: Query,
        params_seq: Iterable[Params],
        *,
        returning: bool = False,
    ) -> None:
        """Execute the same command with a sequence of input data."""
        try:
            if not Pipeline.is_supported():
                async with self._conn.lock:
                    await self._conn.wait(
                        self._executemany_gen_no_pipeline(query, params_seq, returning)
                    )
            else:
                # If there is already a pipeline, ride it, in order to avoid
                # sending unnecessary Sync.
                async with self._conn.lock:
                    active = self._conn._pipeline
                    if active:
                        await self._conn.wait(
                            self._executemany_gen_pipeline(query, params_seq, returning)
                        )
                # Otherwise, make a new one
                if not active:
                    async with self._conn.pipeline(), self._conn.lock:
                        await self._conn.wait(
                            self._executemany_gen_pipeline(query, params_seq, returning)
                        )
        except e._NO_TRACEBACK as ex:
            # Re-raise the error dropping its traceback.
            raise ex.with_traceback(None)

    async def stream(
        self,
        query: Query,
        params: Optional[Params] = None,
        *,
        binary: Optional[bool] = None,
    ) -> AsyncIterator[Row]:
        """Iterate row-by-row on a result from the database."""
        if self._pgconn.pipeline_status:
            raise e.ProgrammingError("stream() cannot be used in pipeline mode")

        async with self._conn.lock:
            try:
                await self._conn.wait(
                    self._stream_send_gen(query, params, binary=binary)
                )
                first = True
                while await self._conn.wait(self._stream_fetchone_gen(first)):
                    # We know that, if we got a result, it has a single row.
                    rec: Row = self._tx.load_row(0, self._make_row)  # type: ignore
                    yield rec
                    first = False

            except e._NO_TRACEBACK as ex:
                raise ex.with_traceback(None)

            finally:
                if self._pgconn.transaction_status == ACTIVE:
                    # Try to cancel the query, then consume the results
                    # already received.
                    self._conn.cancel()
                    try:
                        while await self._conn.wait(
                            self._stream_fetchone_gen(first=False)
                        ):
                            pass
                    except Exception:
                        # Best effort: errors while draining are ignored.
                        pass

                    # Try to get out of ACTIVE state. Just do a single attempt, which
                    # should work to recover from an error or query cancelled.
                    try:
                        await self._conn.wait(self._stream_fetchone_gen(first=False))
                    except Exception:
                        pass

    async def fetchone(self) -> Optional[Row]:
        """
        Return the next record from the current recordset.

        Return `!None` if the recordset is finished.
        """
        await self._fetch_pipeline()
        self._check_result_for_fetch()
        row = self._tx.load_row(self._pos, self._make_row)
        if row is None:
            return row
        self._pos += 1
        return row

    async def fetchmany(self, size: int = 0) -> List[Row]:
        """
        Return the next `!size` records from the current recordset.

        `!size` default to `!self.arraysize` if not specified.
        """
        await self._fetch_pipeline()
        self._check_result_for_fetch()
        assert self.pgresult

        size = size or self.arraysize
        # Don't read past the end of the result.
        stop = min(self._pos + size, self.pgresult.ntuples)
        rows = self._tx.load_rows(self._pos, stop, self._make_row)
        self._pos += len(rows)
        return rows

    async def fetchall(self) -> List[Row]:
        """Return all the remaining records from the current recordset."""
        await self._fetch_pipeline()
        self._check_result_for_fetch()
        assert self.pgresult
        rows = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
        self._pos = self.pgresult.ntuples
        return rows

    async def __aiter__(self) -> AsyncIterator[Row]:
        """Iterate on the remaining rows, advancing the cursor position."""
        await self._fetch_pipeline()
        self._check_result_for_fetch()

        while True:
            row = self._tx.load_row(self._pos, self._make_row)
            if row is None:
                break
            self._pos += 1
            yield row

    async def scroll(self, value: int, mode: str = "relative") -> None:
        """
        Move the cursor in the result set to a new position according to mode.

        `!mode` can be ``'relative'`` (default) or ``'absolute'``.
        """
        await self._fetch_pipeline()
        self._scroll(value, mode)

    @asynccontextmanager
    async def copy(
        self,
        statement: Query,
        params: Optional[Params] = None,
        *,
        writer: Optional[AsyncCopyWriter] = None,
    ) -> AsyncIterator[AsyncCopy]:
        """
        Initiate a :sql:`COPY` operation and return an object to manage it.

        :rtype: AsyncCopy
        """
        try:
            async with self._conn.lock:
                await self._conn.wait(self._start_copy_gen(statement, params))

            async with AsyncCopy(self, writer=writer) as copy:
                yield copy
        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)

        # Select the first result again once the copy is finished.
        self._select_current_result(0)

    async def _fetch_pipeline(self) -> None:
        """
        Run the pipeline's fetch generator if results may still be pending.

        It only happens in pipeline mode, when the cursor has no result and
        it is not running a non-returning `executemany()`.
        """
        if (
            self._execmany_returning is not False
            and not self.pgresult
            and self._conn._pipeline
        ):
            async with self._conn.lock:
                await self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))

View File

@@ -0,0 +1,112 @@
"""
Compatibility objects with DBAPI 2.0
"""
# Copyright (C) 2020 The Psycopg Team
import time
import datetime as dt
from math import floor
from typing import Any, Sequence, Union
from . import postgres
from .abc import AdaptContext, Buffer
from .types.string import BytesDumper, BytesBinaryDumper
class DBAPITypeObject:
    """
    Object comparing equal to any of the OIDs of a group of PostgreSQL types.

    Comparison is only defined against `!int`; other operands yield
    `!NotImplemented`.
    """

    def __init__(self, name: str, type_names: Sequence[str]):
        """Store `!name` and look up the OIDs of `!type_names`."""
        self.name = name
        self.values = tuple(postgres.types[n].oid for n in type_names)

    def __repr__(self) -> str:
        return f"psycopg.{self.name}"

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, int):
            return NotImplemented
        return other in self.values

    def __ne__(self, other: Any) -> bool:
        if not isinstance(other, int):
            return NotImplemented
        return other not in self.values
# Module-level type objects, each one comparing equal to the OIDs of the
# PostgreSQL types listed.
BINARY = DBAPITypeObject("BINARY", ("bytea",))
DATETIME = DBAPITypeObject(
"DATETIME", "timestamp timestamptz date time timetz interval".split()
)
NUMBER = DBAPITypeObject("NUMBER", "int2 int4 int8 float4 float8 numeric".split())
ROWID = DBAPITypeObject("ROWID", ("oid",))
STRING = DBAPITypeObject("STRING", "text varchar bpchar".split())
class Binary:
    """Wrapper around an object, stored as the `!obj` attribute."""

    def __init__(self, obj: Any):
        self.obj = obj

    def __repr__(self) -> str:
        inner = repr(self.obj)
        length = len(inner)
        # Abbreviate long representations, reporting their full length.
        if length > 40:
            inner = f"{inner[:35]} ... ({length} byteschars)"
        return f"{self.__class__.__name__}({inner})"
class BinaryBinaryDumper(BytesBinaryDumper):
    """`!BytesBinaryDumper` which also accepts `Binary` wrappers."""

    def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
        # Unwrap a Binary, then defer to the parent dumper.
        payload = obj.obj if isinstance(obj, Binary) else obj
        return super().dump(payload)
class BinaryTextDumper(BytesDumper):
    """`!BytesDumper` which also accepts `Binary` wrappers."""

    def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
        # Unwrap a Binary, then defer to the parent dumper.
        payload = obj.obj if isinstance(obj, Binary) else obj
        return super().dump(payload)
def Date(year: int, month: int, day: int) -> dt.date:
    """Return a `!datetime.date` for the given year, month and day."""
    return dt.date(year=year, month=month, day=day)
def DateFromTicks(ticks: float) -> dt.date:
    """Return the date part of `TimestampFromTicks(ticks)`."""
    stamp = TimestampFromTicks(ticks)
    return stamp.date()
def Time(hour: int, minute: int, second: int) -> dt.time:
    """Return a `!datetime.time` for the given hour, minute and second."""
    return dt.time(hour=hour, minute=minute, second=second)
def TimeFromTicks(ticks: float) -> dt.time:
    """Return the time part of `TimestampFromTicks(ticks)`."""
    stamp = TimestampFromTicks(ticks)
    return stamp.time()
def Timestamp(
    year: int, month: int, day: int, hour: int, minute: int, second: int
) -> dt.datetime:
    """Return a `!datetime.datetime` built from the given components."""
    return dt.datetime(
        year=year, month=month, day=day, hour=hour, minute=minute, second=second
    )
def TimestampFromTicks(ticks: float) -> dt.datetime:
    """
    Return a timezone-aware `!datetime.datetime` from seconds since the epoch.

    The result carries a fixed-offset timezone, taken from the UTC offset
    that `!time.localtime()` reports for that instant.
    """
    secs = floor(ticks)
    frac = ticks - secs
    t = time.localtime(secs)
    tzinfo = dt.timezone(dt.timedelta(seconds=t.tm_gmtoff))
    rv = dt.datetime(*t[:6], tzinfo=tzinfo)
    # Add the fraction as a timedelta: it may round to a whole second (e.g.
    # ticks=0.9999999), which is out of range as a `microsecond` argument.
    return rv + dt.timedelta(microseconds=round(frac * 1_000_000))
def register_dbapi20_adapters(context: AdaptContext) -> None:
    """
    Register the `Binary` dumpers on the adapters map of `!context`.

    The dumpers are registered first for the `Binary` class, then with
    `!None` as class.
    """
    adapters = context.adapters
    # The second round (None) makes them also the default dumpers when
    # dumping by bytea oid.
    for cls in (Binary, None):
        for dumper in (BinaryTextDumper, BinaryBinaryDumper):
            adapters.register_dumper(cls, dumper)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,333 @@
"""
Generators implementing communication protocols with the libpq
Certain operations (connection, querying) interleave libpq calls with
waiting for the socket to be ready. This module contains the code to execute
these operations, yielding a polling state whenever it is necessary to wait. The
functions in the `waiting` module are the ones that wait, more or less
cooperatively, for the socket to be ready and make these generators continue.
All these generators yield pairs (fileno, `Wait`) whenever an operation would
block. The generator can be restarted sending the appropriate `Ready` state
when the file descriptor is ready.
"""
# Copyright (C) 2020 The Psycopg Team
import logging
from typing import List, Optional, Union
from . import pq
from . import errors as e
from .abc import Buffer, PipelineCommand, PQGen, PQGenConn
from .pq.abc import PGconn, PGresult
from .waiting import Wait, Ready
from ._compat import Deque
from ._cmodule import _psycopg
from ._encodings import pgconn_encoding, conninfo_encoding
# Module-level aliases for the enum members used by the generators in this
# module (presumably to avoid repeating the attribute lookups on every use).

# Connection status values
OK = pq.ConnStatus.OK
BAD = pq.ConnStatus.BAD
# Results of polling a connection attempt
POLL_OK = pq.PollingStatus.OK
POLL_READING = pq.PollingStatus.READING
POLL_WRITING = pq.PollingStatus.WRITING
POLL_FAILED = pq.PollingStatus.FAILED
# Result status values
COMMAND_OK = pq.ExecStatus.COMMAND_OK
COPY_OUT = pq.ExecStatus.COPY_OUT
COPY_IN = pq.ExecStatus.COPY_IN
COPY_BOTH = pq.ExecStatus.COPY_BOTH
PIPELINE_SYNC = pq.ExecStatus.PIPELINE_SYNC
# States yielded by the generators to ask for the socket to be waited on
WAIT_R = Wait.R
WAIT_W = Wait.W
WAIT_RW = Wait.RW
# States sent back into the generators once the socket is ready
READY_R = Ready.R
READY_W = Ready.W
READY_RW = Ready.RW
logger = logging.getLogger(__name__)
def _connect(conninfo: str) -> PQGenConn[PGconn]:
    """
    Generator to create a database connection without blocking.

    Start the connection and keep polling it, yielding a ``(socket, Wait)``
    pair whenever the socket must become readable or writable before polling
    again. Return the connection, set to nonblocking mode.

    Raise `OperationalError` if the connection is bad or the attempt failed,
    `InternalError` on a polling status not handled here.
    """
    conn = pq.PGconn.connect_start(conninfo.encode())
    while True:
        if conn.status == BAD:
            enc = conninfo_encoding(conninfo)
            reason = pq.error_message(conn, encoding=enc)
            raise e.OperationalError(f"connection is bad: {reason}", pgconn=conn)

        status = conn.connect_poll()
        if status == POLL_OK:
            conn.nonblocking = 1
            return conn

        if status == POLL_READING:
            yield conn.socket, WAIT_R
        elif status == POLL_WRITING:
            yield conn.socket, WAIT_W
        elif status == POLL_FAILED:
            enc = conninfo_encoding(conninfo)
            # Read the error message before the connection is finished.
            reason = pq.error_message(conn, encoding=enc)
            raise e.OperationalError(
                f"connection failed: {reason}",
                pgconn=e.finish_pgconn(conn),
            )
        else:
            raise e.InternalError(
                f"unexpected poll status: {status}", pgconn=e.finish_pgconn(conn)
            )
def _execute(pgconn: PGconn) -> PQGen[List[PGresult]]:
    """
    Generator sending a query and returning results without blocking.

    The query must have already been sent using `pgconn.send_query()` or
    similar. Flush the query and then return the result using nonblocking
    functions.

    Return the list of results returned by the database (whether success
    or error).
    """
    # First push the query to the server, then collect everything it returns.
    yield from _send(pgconn)
    return (yield from _fetch_many(pgconn))
def _send(pgconn: PGconn) -> PQGen[None]:
    """
    Generator to send a query to the server without blocking.

    The query must have already been sent using `pgconn.send_query()` or
    similar. Keep flushing the connection, yielding a read/write wait state
    every time the flush is not complete.

    After this generator has finished you may want to cycle using `fetch()`
    to retrieve the results available.
    """
    # flush() returns 0 once there is nothing left to send.
    while pgconn.flush() != 0:
        ready = yield WAIT_RW
        if ready & READY_R:
            # This call may read notifies: they will be saved in the
            # PGconn buffer and passed to Python later, in `fetch()`.
            pgconn.consume_input()
def _fetch_many(pgconn: PGconn) -> PQGen[List[PGresult]]:
    """
    Generator retrieving results from the database without blocking.

    The query must have already been sent to the server, so pgconn.flush() has
    already returned 0.

    Return the list of results returned by the database (whether success
    or error).
    """
    collected: List[PGresult] = []
    while True:
        res = yield from _fetch(pgconn)
        if not res:
            return collected

        collected.append(res)
        status = res.status
        if status in (COPY_IN, COPY_OUT, COPY_BOTH):
            # After entering copy mode the libpq will create a phony result
            # for every request so let's break the endless loop.
            return collected

        if status == PIPELINE_SYNC:
            # PIPELINE_SYNC is not followed by a NULL, but we return it alone
            # similarly to other result sets.
            assert len(collected) == 1, collected
            return collected
def _fetch(pgconn: PGconn) -> PQGen[Optional[PGresult]]:
    """
    Generator retrieving a single result from the database without blocking.

    The query must have already been sent to the server, so pgconn.flush() has
    already returned 0.

    Return a result from the database (whether success or error).
    """
    if pgconn.is_busy():
        # Wait for the socket to be readable and consume the input, until
        # the connection is no longer busy.
        while True:
            yield WAIT_R
            pgconn.consume_input()
            if not pgconn.is_busy():
                break

    _consume_notifies(pgconn)

    return pgconn.get_result()
def _pipeline_communicate(
    pgconn: PGconn, commands: Deque[PipelineCommand]
) -> PQGen[List[List[PGresult]]]:
    """Generator to send queries from a connection in pipeline mode while also
    receiving results.

    Every time the socket is writable, flush the connection and call the next
    item popped from the left of `commands`; stop when the socket is writable
    and `commands` is empty.

    Return a list of results, including single PIPELINE_SYNC elements.

    Raise `NotSupportedError` if a result with a COPY status is received.
    """
    results = []

    while True:
        # Wait for the socket to be ready for reading, writing, or both.
        ready = yield WAIT_RW

        if ready & READY_R:
            pgconn.consume_input()
            _consume_notifies(pgconn)

            # Results accumulated since the last None returned by
            # get_result(); a None moves them, as a group, into `results`.
            res: List[PGresult] = []
            while not pgconn.is_busy():
                r = pgconn.get_result()
                if r is None:
                    if not res:
                        # A None with nothing accumulated: stop reading
                        # for this round.
                        break
                    results.append(res)
                    res = []
                else:
                    status = r.status
                    if status == PIPELINE_SYNC:
                        # A sync result is stored alone, in its own group.
                        assert not res
                        results.append([r])
                    elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
                        # This shouldn't happen, but insisting hard enough, it will.
                        # For instance, in test_executemany_badquery(), with the COPY
                        # statement and the AsyncClientCursor, which disables
                        # prepared statements).
                        # Bail out from the resulting infinite loop.
                        raise e.NotSupportedError(
                            "COPY cannot be used in pipeline mode"
                        )
                    else:
                        res.append(r)

        if ready & READY_W:
            pgconn.flush()
            if not commands:
                # Nothing left to send: the exchange is over.
                break
            # Pop the next queued command and call it.
            commands.popleft()()

    return results
def _consume_notifies(pgconn: PGconn) -> None:
    """Drain the pending notifications, passing each one to the handler.

    Notifications are read and dropped if `pgconn.notify_handler` is not set.
    """
    notify = pgconn.notifies()
    while notify:
        if pgconn.notify_handler:
            pgconn.notify_handler(notify)
        notify = pgconn.notifies()
def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]:
    """Generator waiting for the socket to be readable, then returning the
    list of notifications received.
    """
    yield WAIT_R
    pgconn.consume_input()

    received = []
    notify = pgconn.notifies()
    while notify:
        received.append(notify)
        notify = pgconn.notifies()

    return received
def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]:
    """Generator reading the next piece of data of a COPY operation.

    Return the data if some is available, otherwise the final result of the
    COPY. Raise an exception if the COPY didn't end successfully or if more
    than one result is found.
    """
    nbytes, data = pgconn.get_copy_data(1)
    while nbytes == 0:
        # would block
        yield WAIT_R
        pgconn.consume_input()
        nbytes, data = pgconn.get_copy_data(1)

    if nbytes > 0:
        # some data
        return data

    # Retrieve the final result of copy
    results = yield from _fetch_many(pgconn)
    if len(results) > 1:
        # TODO: too brutal? Copy worked.
        raise e.ProgrammingError("you cannot mix COPY with other operations")

    final = results[0]
    if final.status == COMMAND_OK:
        return final

    encoding = pgconn_encoding(pgconn)
    raise e.error_from_result(final, encoding=encoding)
def copy_to(pgconn: PGconn, buffer: Buffer) -> PQGen[None]:
    """Generator enqueuing *buffer* as COPY data, waiting for the socket to be
    writable every time the data cannot be queued.
    """
    # Retry enqueuing data until successful.
    #
    # WARNING! This can cause an infinite loop if the buffer is too large. (see
    # ticket #255). We avoid it in the Copy object by splitting a large buffer
    # into smaller ones. We prefer to do it there instead of here in order to
    # do it upstream the queue decoupling the writer task from the producer one.
    while True:
        if pgconn.put_copy_data(buffer) != 0:
            return
        yield WAIT_W
def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
    """Generator terminating a COPY operation and returning its final result.

    Raise an exception if the COPY didn't end successfully.
    """
    # Retry enqueuing end copy message until successful
    while True:
        if pgconn.put_copy_end(error) != 0:
            break
        yield WAIT_W

    # Repeat until the message is flushed to the server
    flushed = False
    while not flushed:
        yield WAIT_W
        flushed = pgconn.flush() == 0

    # Retrieve the final result of copy: exactly one result is expected.
    results = yield from _fetch_many(pgconn)
    (final,) = results
    if final.status == COMMAND_OK:
        return final

    encoding = pgconn_encoding(pgconn)
    raise e.error_from_result(final, encoding=encoding)
# Expose the public names of the generators: the pure Python ones defined
# above, unless the fast versions from the C module are available.
if not _psycopg:
    connect = _connect
    execute = _execute
    send = _send
    fetch_many = _fetch_many
    fetch = _fetch
    pipeline_communicate = _pipeline_communicate

else:
    connect = _psycopg.connect
    execute = _psycopg.execute
    send = _psycopg.send
    fetch_many = _psycopg.fetch_many
    fetch = _psycopg.fetch
    pipeline_communicate = _psycopg.pipeline_communicate

View File

@@ -0,0 +1,124 @@
"""
Types configuration specific to PostgreSQL.
"""
# Copyright (C) 2020 The Psycopg Team
from ._typeinfo import TypeInfo, RangeInfo, MultirangeInfo, TypesRegistry
from .abc import AdaptContext
from ._adapters_map import AdaptersMap
# Global objects with PostgreSQL builtins and globally registered user types.
types = TypesRegistry()
# Global adapter maps with PostgreSQL types configuration
adapters = AdaptersMap(types=types)
# Use tools/update_oids.py to update this data.
for t in [
TypeInfo('"char"', 18, 1002),
# autogenerated: start
# Generated from PostgreSQL 16.0
TypeInfo("aclitem", 1033, 1034),
TypeInfo("bit", 1560, 1561),
TypeInfo("bool", 16, 1000, regtype="boolean"),
TypeInfo("box", 603, 1020, delimiter=";"),
TypeInfo("bpchar", 1042, 1014, regtype="character"),
TypeInfo("bytea", 17, 1001),
TypeInfo("cid", 29, 1012),
TypeInfo("cidr", 650, 651),
TypeInfo("circle", 718, 719),
TypeInfo("date", 1082, 1182),
TypeInfo("float4", 700, 1021, regtype="real"),
TypeInfo("float8", 701, 1022, regtype="double precision"),
TypeInfo("gtsvector", 3642, 3644),
TypeInfo("inet", 869, 1041),
TypeInfo("int2", 21, 1005, regtype="smallint"),
TypeInfo("int2vector", 22, 1006),
TypeInfo("int4", 23, 1007, regtype="integer"),
TypeInfo("int8", 20, 1016, regtype="bigint"),
TypeInfo("interval", 1186, 1187),
TypeInfo("json", 114, 199),
TypeInfo("jsonb", 3802, 3807),
TypeInfo("jsonpath", 4072, 4073),
TypeInfo("line", 628, 629),
TypeInfo("lseg", 601, 1018),
TypeInfo("macaddr", 829, 1040),
TypeInfo("macaddr8", 774, 775),
TypeInfo("money", 790, 791),
TypeInfo("name", 19, 1003),
TypeInfo("numeric", 1700, 1231),
TypeInfo("oid", 26, 1028),
TypeInfo("oidvector", 30, 1013),
TypeInfo("path", 602, 1019),
TypeInfo("pg_lsn", 3220, 3221),
TypeInfo("point", 600, 1017),
TypeInfo("polygon", 604, 1027),
TypeInfo("record", 2249, 2287),
TypeInfo("refcursor", 1790, 2201),
TypeInfo("regclass", 2205, 2210),
TypeInfo("regcollation", 4191, 4192),
TypeInfo("regconfig", 3734, 3735),
TypeInfo("regdictionary", 3769, 3770),
TypeInfo("regnamespace", 4089, 4090),
TypeInfo("regoper", 2203, 2208),
TypeInfo("regoperator", 2204, 2209),
TypeInfo("regproc", 24, 1008),
TypeInfo("regprocedure", 2202, 2207),
TypeInfo("regrole", 4096, 4097),
TypeInfo("regtype", 2206, 2211),
TypeInfo("text", 25, 1009),
TypeInfo("tid", 27, 1010),
TypeInfo("time", 1083, 1183, regtype="time without time zone"),
TypeInfo("timestamp", 1114, 1115, regtype="timestamp without time zone"),
TypeInfo("timestamptz", 1184, 1185, regtype="timestamp with time zone"),
TypeInfo("timetz", 1266, 1270, regtype="time with time zone"),
TypeInfo("tsquery", 3615, 3645),
TypeInfo("tsvector", 3614, 3643),
TypeInfo("txid_snapshot", 2970, 2949),
TypeInfo("uuid", 2950, 2951),
TypeInfo("varbit", 1562, 1563, regtype="bit varying"),
TypeInfo("varchar", 1043, 1015, regtype="character varying"),
TypeInfo("xid", 28, 1011),
TypeInfo("xid8", 5069, 271),
TypeInfo("xml", 142, 143),
RangeInfo("daterange", 3912, 3913, subtype_oid=1082),
RangeInfo("int4range", 3904, 3905, subtype_oid=23),
RangeInfo("int8range", 3926, 3927, subtype_oid=20),
RangeInfo("numrange", 3906, 3907, subtype_oid=1700),
RangeInfo("tsrange", 3908, 3909, subtype_oid=1114),
RangeInfo("tstzrange", 3910, 3911, subtype_oid=1184),
MultirangeInfo("datemultirange", 4535, 6155, range_oid=3912, subtype_oid=1082),
MultirangeInfo("int4multirange", 4451, 6150, range_oid=3904, subtype_oid=23),
MultirangeInfo("int8multirange", 4536, 6157, range_oid=3926, subtype_oid=20),
MultirangeInfo("nummultirange", 4532, 6151, range_oid=3906, subtype_oid=1700),
MultirangeInfo("tsmultirange", 4533, 6152, range_oid=3908, subtype_oid=1114),
MultirangeInfo("tstzmultirange", 4534, 6153, range_oid=3910, subtype_oid=1184),
# autogenerated: end
]:
types.add(t)
# A few oids used a bit everywhere
INVALID_OID = 0
TEXT_OID = types["text"].oid
TEXT_ARRAY_OID = types["text"].array_oid
def register_default_adapters(context: AdaptContext) -> None:
    """Register on *context* the default adapters of every `types` module."""
    # NOTE(review): imported here rather than at module level, presumably to
    # avoid an import cycle with the `types` subpackage -- confirm.
    from .types import array, bool, composite, datetime, enum, json, multirange
    from .types import net, none, numeric, range, string, uuid

    # The registration order is the same as the import order above.
    modules = (
        array,
        bool,
        composite,
        datetime,
        enum,
        json,
        multirange,
        net,
        none,
        numeric,
        range,
        string,
        uuid,
    )
    for module in modules:
        module.register_default_adapters(context)

View File

@@ -0,0 +1,133 @@
"""
psycopg libpq wrapper
This package exposes the libpq functionalities as Python objects and functions.
The real implementation (the binding to the C library) is
implementation-dependent but all the implementations share the same interface.
"""
# Copyright (C) 2020 The Psycopg Team
import os
import logging
from typing import Callable, List, Type
from . import abc
from .misc import ConninfoOption, PGnotify, PGresAttDesc
from .misc import error_message
from ._enums import ConnStatus, DiagnosticField, ExecStatus, Format, Trace
from ._enums import Ping, PipelineStatus, PollingStatus, TransactionStatus
logger = logging.getLogger(__name__)
__impl__: str
"""The currently loaded implementation of the `!psycopg.pq` package.
Possible values include ``python``, ``c``, ``binary``.
"""
__build_version__: int
"""The libpq version the C package was built with.
A number in the same format of `~psycopg.ConnectionInfo.server_version`
representing the libpq used to build the speedup module (``c``, ``binary``) if
available.
Certain features might not be available if the built version is too old.
"""
version: Callable[[], int]
PGconn: Type[abc.PGconn]
PGresult: Type[abc.PGresult]
Conninfo: Type[abc.Conninfo]
Escaping: Type[abc.Escaping]
PGcancel: Type[abc.PGcancel]
def import_from_libpq() -> None:
    """
    Import pq objects implementation from the best libpq wrapper available.

    If an implementation is requested (in the ``PSYCOPG_IMPL`` environment
    variable) try to import only it, otherwise try to import the best
    implementation available, in the order ``c``, ``binary``, ``python``.

    Raise `ImportError` if the requested implementation is unknown or cannot
    be imported, or if no implementation can be imported at all.
    """
    # import these names into the module on success as side effect
    global __impl__, version, __build_version__
    global PGconn, PGresult, Conninfo, Escaping, PGcancel

    impl = os.environ.get("PSYCOPG_IMPL", "").lower()
    module = None
    attempts: List[str] = []

    def handle_error(name: str, e: Exception) -> None:
        if impl:
            # An implementation explicitly requested must not fail silently.
            msg = f"couldn't import requested psycopg '{name}' implementation: {e}"
            raise ImportError(msg) from e

        # No preference expressed: take note and go on with the next one.
        msg = f"couldn't import psycopg '{name}' implementation: {e}"
        logger.debug(msg)
        attempts.append(msg)

    # The best implementation: fast but requires the system libpq installed
    if impl in ("", "c"):
        try:
            from psycopg_c import pq as module  # type: ignore
        except Exception as e:
            handle_error("c", e)

    # Second best implementation: fast and stand-alone
    if not module and impl in ("", "binary"):
        try:
            from psycopg_binary import pq as module  # type: ignore
        except Exception as e:
            handle_error("binary", e)

    # Pure Python implementation, slow and requires the system libpq installed.
    if not module and impl in ("", "python"):
        try:
            from . import pq_ctypes as module  # type: ignore[assignment]
        except Exception as e:
            handle_error("python", e)

    if not module:
        if impl:
            raise ImportError(f"requested psycopg implementation '{impl}' unknown")

        sattempts = "\n".join(f"- {attempt}" for attempt in attempts)
        raise ImportError(
            f"""\
no pq wrapper available.
Attempts made:
{sattempts}"""
        )

    __impl__ = module.__impl__
    version = module.version
    PGconn = module.PGconn
    PGresult = module.PGresult
    Conninfo = module.Conninfo
    Escaping = module.Escaping
    PGcancel = module.PGcancel
    __build_version__ = module.__build_version__
import_from_libpq()
__all__ = (
"ConnStatus",
"PipelineStatus",
"PollingStatus",
"TransactionStatus",
"ExecStatus",
"Ping",
"DiagnosticField",
"Format",
"Trace",
"PGconn",
"PGnotify",
"Conninfo",
"PGresAttDesc",
"error_message",
"ConninfoOption",
"version",
)

View File

@@ -0,0 +1,106 @@
"""
libpq debugging tools
These functionalities are exposed here for convenience, but are not part of
the public interface and are subject to change at any moment.
Suggested usage::
import logging
import psycopg
from psycopg import pq
from psycopg.pq._debug import PGconnDebug
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger("psycopg.debug")
logger.setLevel(logging.INFO)
assert pq.__impl__ == "python"
pq.PGconn = PGconnDebug
with psycopg.connect("") as conn:
conn.pgconn.trace(2)
conn.pgconn.set_trace_flags(
pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
...
"""
# Copyright (C) 2022 The Psycopg Team
import inspect
import logging
from typing import Any, Callable, TypeVar, TYPE_CHECKING
from functools import wraps
from .._compat import Self
from . import PGconn
from .misc import connection_summary
if TYPE_CHECKING:
from . import abc
Func = TypeVar("Func", bound=Callable[..., Any])
logger = logging.getLogger("psycopg.debug")
class PGconnDebug:
"""Wrapper for a PQconn logging all its access."""
_pgconn: "abc.PGconn"
def __init__(self, pgconn: "abc.PGconn"):
super().__setattr__("_pgconn", pgconn)
def __repr__(self) -> str:
cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
info = connection_summary(self._pgconn)
return f"<{cls} {info} at 0x{id(self):x}>"
def __getattr__(self, attr: str) -> Any:
value = getattr(self._pgconn, attr)
if callable(value):
return debugging(value)
else:
logger.info("PGconn.%s -> %s", attr, value)
return value
def __setattr__(self, attr: str, value: Any) -> None:
setattr(self._pgconn, attr, value)
logger.info("PGconn.%s <- %s", attr, value)
@classmethod
def connect(cls, conninfo: bytes) -> Self:
return cls(debugging(PGconn.connect)(conninfo))
@classmethod
def connect_start(cls, conninfo: bytes) -> Self:
return cls(debugging(PGconn.connect_start)(conninfo))
@classmethod
def ping(self, conninfo: bytes) -> int:
return debugging(PGconn.ping)(conninfo)
def debugging(f: Func) -> Func:
    """Wrap a function in order to log its arguments and return value on call."""

    @wraps(f)
    def debugging_(*args: Any, **kwargs: Any) -> Any:
        # Positional arguments first, then the keyword ones as name=value.
        shown = [f"{arg!r}" for arg in args]
        shown.extend(f"{k}={v!r}" for k, v in kwargs.items())
        logger.info("PGconn.%s(%s)", f.__name__, ", ".join(shown))
        rv = f(*args, **kwargs)
        # Display the return value only if the function is declared to return
        # something else than None.
        ra = inspect.signature(f).return_annotation
        if not (ra is None and rv is None):
            logger.info(" <- %r", rv)
        return rv

    return debugging_  # type: ignore

View File

@@ -0,0 +1,249 @@
"""
libpq enum definitions for psycopg
"""
# Copyright (C) 2020 The Psycopg Team
from enum import IntEnum, IntFlag, auto
class ConnStatus(IntEnum):
    """
    Current status of the connection.
    """

    __module__ = "psycopg.pq"

    # Values are spelled out explicitly; they are consecutive from 0.
    OK = 0
    """The connection is in a working state."""
    BAD = 1
    """The connection is closed."""

    STARTED = 2
    MADE = 3
    AWAITING_RESPONSE = 4
    AUTH_OK = 5
    SETENV = 6
    SSL_STARTUP = 7
    NEEDED = 8
    CHECK_WRITABLE = 9
    CONSUME = 10
    GSS_STARTUP = 11
    CHECK_TARGET = 12
    CHECK_STANDBY = 13
class PollingStatus(IntEnum):
    """
    The status of the socket during a connection.

    If ``READING`` or ``WRITING`` you may select before polling again.
    """

    __module__ = "psycopg.pq"

    # Values are spelled out explicitly; they are consecutive from 0.
    FAILED = 0
    """Connection attempt failed."""
    READING = 1
    """Will have to wait before reading new data."""
    WRITING = 2
    """Will have to wait before writing new data."""
    OK = 3
    """Connection completed."""

    ACTIVE = 4
class ExecStatus(IntEnum):
    """
    The status of a command.
    """

    __module__ = "psycopg.pq"

    # Values are spelled out explicitly; they are consecutive from 0.
    EMPTY_QUERY = 0
    """The string sent to the server was empty."""

    COMMAND_OK = 1
    """Successful completion of a command returning no data."""

    TUPLES_OK = 2
    """
    Successful completion of a command returning data (such as a SELECT or SHOW).
    """

    COPY_OUT = 3
    """Copy Out (from server) data transfer started."""

    COPY_IN = 4
    """Copy In (to server) data transfer started."""

    BAD_RESPONSE = 5
    """The server's response was not understood."""

    NONFATAL_ERROR = 6
    """A nonfatal error (a notice or warning) occurred."""

    FATAL_ERROR = 7
    """A fatal error occurred."""

    COPY_BOTH = 8
    """
    Copy In/Out (to and from server) data transfer started.
    This feature is currently used only for streaming replication, so this
    status should not occur in ordinary applications.
    """

    SINGLE_TUPLE = 9
    """
    The PGresult contains a single result tuple from the current command.
    This status occurs only when single-row mode has been selected for the
    query.
    """

    PIPELINE_SYNC = 10
    """
    The PGresult represents a synchronization point in pipeline mode,
    requested by PQpipelineSync.
    This status occurs only when pipeline mode has been selected.
    """

    PIPELINE_ABORTED = 11
    """
    The PGresult represents a pipeline that has received an error from the server.
    PQgetResult must be called repeatedly, and each time it will return this
    status code until the end of the current pipeline, at which point it will
    return PGRES_PIPELINE_SYNC and normal processing can resume.
    """
class TransactionStatus(IntEnum):
    """
    The transaction status of a connection.
    """

    __module__ = "psycopg.pq"

    # Values are spelled out explicitly; they are consecutive from 0.
    IDLE = 0
    """Connection ready, no transaction active."""

    ACTIVE = 1
    """A command is in progress."""

    INTRANS = 2
    """Connection idle in an open transaction."""

    INERROR = 3
    """An error happened in the current transaction."""

    UNKNOWN = 4
    """Unknown connection state, broken connection."""
class Ping(IntEnum):
    """Response from a ping attempt."""

    __module__ = "psycopg.pq"

    # Values are spelled out explicitly; they are consecutive from 0.
    OK = 0
    """
    The server is running and appears to be accepting connections.
    """

    REJECT = 1
    """
    The server is running but is in a state that disallows connections.
    """

    NO_RESPONSE = 2
    """
    The server could not be contacted.
    """

    NO_ATTEMPT = 3
    """
    No attempt was made to contact the server.
    """
class PipelineStatus(IntEnum):
    """Pipeline mode status of the libpq connection."""

    __module__ = "psycopg.pq"

    # Values are spelled out explicitly; they are consecutive from 0.
    OFF = 0
    """
    The libpq connection is *not* in pipeline mode.
    """

    ON = 1
    """
    The libpq connection is in pipeline mode.
    """

    ABORTED = 2
    """
    The libpq connection is in pipeline mode and an error occurred while
    processing the current pipeline. The aborted flag is cleared when
    PQgetResult returns a result of type PGRES_PIPELINE_SYNC.
    """
class DiagnosticField(IntEnum):
    """
    Fields in an error report.

    The value of each member is the code point of the single character
    identifying the field.
    """

    __module__ = "psycopg.pq"

    # from postgres_ext.h
    # Note that the codes are case sensitive: e.g. "P" and "p", "S" and "s"
    # identify different fields.
    SEVERITY = ord("S")
    SEVERITY_NONLOCALIZED = ord("V")
    SQLSTATE = ord("C")
    MESSAGE_PRIMARY = ord("M")
    MESSAGE_DETAIL = ord("D")
    MESSAGE_HINT = ord("H")
    STATEMENT_POSITION = ord("P")
    INTERNAL_POSITION = ord("p")
    INTERNAL_QUERY = ord("q")
    CONTEXT = ord("W")
    SCHEMA_NAME = ord("s")
    TABLE_NAME = ord("t")
    COLUMN_NAME = ord("c")
    DATATYPE_NAME = ord("d")
    CONSTRAINT_NAME = ord("n")
    SOURCE_FILE = ord("F")
    SOURCE_LINE = ord("L")
    SOURCE_FUNCTION = ord("R")
class Format(IntEnum):
    """
    Enum representing the format of a query argument or return value.

    These values are only the ones managed by the libpq. `~psycopg` may also
    support automatically-chosen values: see `psycopg.adapt.PyFormat`.
    """

    # Make the class appear as defined in the public `psycopg.pq` package.
    __module__ = "psycopg.pq"

    TEXT = 0
    """Text parameter."""
    BINARY = 1
    """Binary parameter."""
class Trace(IntFlag):
    """
    Enum to control tracing of the client/server communication.

    The members are bit flags and can be combined using the ``|`` operator.
    """

    # Make the class appear as defined in the public `psycopg.pq` package.
    __module__ = "psycopg.pq"

    SUPPRESS_TIMESTAMPS = 1
    """Do not include timestamps in messages."""

    REGRESS_MODE = 2
    """Redact some fields, e.g. OIDs, from messages."""

View File

@@ -0,0 +1,807 @@
"""
libpq access using ctypes
"""
# Copyright (C) 2020 The Psycopg Team
import sys
import ctypes
import ctypes.util
from ctypes import Structure, CFUNCTYPE, POINTER
from ctypes import c_char, c_char_p, c_int, c_size_t, c_ubyte, c_uint, c_void_p
from typing import List, Optional, Tuple
from .misc import find_libpq_full_path
from ..errors import NotSupportedError
libname = find_libpq_full_path()
if not libname:
raise ImportError("libpq library not found")
pq = ctypes.cdll.LoadLibrary(libname)
class FILE(Structure):
    """Placeholder structure with no fields, only used to define `FILE_ptr`."""
FILE_ptr = POINTER(FILE)
if sys.platform == "linux":
libcname = ctypes.util.find_library("c")
if not libcname:
# Likely this is a system using musl libc, see the following bug:
# https://github.com/python/cpython/issues/65821
libcname = "libc.so"
libc = ctypes.cdll.LoadLibrary(libcname)
fdopen = libc.fdopen
fdopen.argtypes = (c_int, c_char_p)
fdopen.restype = FILE_ptr
# Get the libpq version to define what functions are available.
PQlibVersion = pq.PQlibVersion
PQlibVersion.argtypes = []
PQlibVersion.restype = c_int
libpq_version = PQlibVersion()
# libpq data types
Oid = c_uint
class PGconn_struct(Structure):
    """Opaque structure (no field declared), only used via `PGconn_ptr`."""

    _fields_: List[Tuple[str, type]] = []
class PGresult_struct(Structure):
    """Opaque structure (no field declared), only used via `PGresult_ptr`."""

    _fields_: List[Tuple[str, type]] = []
class PQconninfoOption_struct(Structure):
    """Layout of a connection option record.

    A pointer to this structure is what `PQconndefaults()`, `PQconninfo()`
    and `PQconninfoParse()` are declared to return.
    """

    # All the fields are C strings, except `dispsize`, which is an int.
    _fields_ = [
        ("keyword", c_char_p),
        ("envvar", c_char_p),
        ("compiled", c_char_p),
        ("val", c_char_p),
        ("label", c_char_p),
        ("dispchar", c_char_p),
        ("dispsize", c_int),
    ]
class PGnotify_struct(Structure):
    """Layout of a notification record, used via `PGnotify_ptr`."""

    # NOTE(review): presumably the channel name, the pid of the notifying
    # backend and the payload string -- confirm against the libpq docs.
    _fields_ = [
        ("relname", c_char_p),
        ("be_pid", c_int),
        ("extra", c_char_p),
    ]
class PGcancel_struct(Structure):
    """Opaque structure (no field declared), only used via `PGcancel_ptr`."""

    _fields_: List[Tuple[str, type]] = []
class PGresAttDesc_struct(Structure):
    """Layout of the description of a result column, used via `PGresAttDesc_ptr`."""

    # `tableid` and `typid` are Oid values; the other numeric fields are ints.
    _fields_ = [
        ("name", c_char_p),
        ("tableid", Oid),
        ("columnid", c_int),
        ("format", c_int),
        ("typid", Oid),
        ("typlen", c_int),
        ("atttypmod", c_int),
    ]
PGconn_ptr = POINTER(PGconn_struct)
PGresult_ptr = POINTER(PGresult_struct)
PQconninfoOption_ptr = POINTER(PQconninfoOption_struct)
PGnotify_ptr = POINTER(PGnotify_struct)
PGcancel_ptr = POINTER(PGcancel_struct)
PGresAttDesc_ptr = POINTER(PGresAttDesc_struct)
# Function definitions as explained in PostgreSQL 12 documentation
# 33.1. Database Connection Control Functions
# PQconnectdbParams: doesn't seem useful, won't wrap for now
PQconnectdb = pq.PQconnectdb
PQconnectdb.argtypes = [c_char_p]
PQconnectdb.restype = PGconn_ptr
# PQsetdbLogin: not useful
# PQsetdb: not useful
# PQconnectStartParams: not useful
PQconnectStart = pq.PQconnectStart
PQconnectStart.argtypes = [c_char_p]
PQconnectStart.restype = PGconn_ptr
PQconnectPoll = pq.PQconnectPoll
PQconnectPoll.argtypes = [PGconn_ptr]
PQconnectPoll.restype = c_int
PQconndefaults = pq.PQconndefaults
PQconndefaults.argtypes = []
PQconndefaults.restype = PQconninfoOption_ptr
PQconninfoFree = pq.PQconninfoFree
PQconninfoFree.argtypes = [PQconninfoOption_ptr]
PQconninfoFree.restype = None
PQconninfo = pq.PQconninfo
PQconninfo.argtypes = [PGconn_ptr]
PQconninfo.restype = PQconninfoOption_ptr
PQconninfoParse = pq.PQconninfoParse
PQconninfoParse.argtypes = [c_char_p, POINTER(c_char_p)]
PQconninfoParse.restype = PQconninfoOption_ptr
PQfinish = pq.PQfinish
PQfinish.argtypes = [PGconn_ptr]
PQfinish.restype = None
PQreset = pq.PQreset
PQreset.argtypes = [PGconn_ptr]
PQreset.restype = None
PQresetStart = pq.PQresetStart
PQresetStart.argtypes = [PGconn_ptr]
PQresetStart.restype = c_int
PQresetPoll = pq.PQresetPoll
PQresetPoll.argtypes = [PGconn_ptr]
PQresetPoll.restype = c_int
PQping = pq.PQping
PQping.argtypes = [c_char_p]
PQping.restype = c_int
# 33.2. Connection Status Functions
PQdb = pq.PQdb
PQdb.argtypes = [PGconn_ptr]
PQdb.restype = c_char_p
PQuser = pq.PQuser
PQuser.argtypes = [PGconn_ptr]
PQuser.restype = c_char_p
PQpass = pq.PQpass
PQpass.argtypes = [PGconn_ptr]
PQpass.restype = c_char_p
PQhost = pq.PQhost
PQhost.argtypes = [PGconn_ptr]
PQhost.restype = c_char_p
_PQhostaddr = None
if libpq_version >= 120000:
_PQhostaddr = pq.PQhostaddr
_PQhostaddr.argtypes = [PGconn_ptr]
_PQhostaddr.restype = c_char_p
def PQhostaddr(pgconn: PGconn_struct) -> bytes:
    """Call the libpq `PQhostaddr()` function, if the libpq provides it.

    Raise `NotSupportedError` if the function was not loaded from the libpq.
    """
    if _PQhostaddr:
        return _PQhostaddr(pgconn)

    raise NotSupportedError(
        "PQhostaddr requires libpq from PostgreSQL 12,"
        f" {libpq_version} available instead"
    )
PQport = pq.PQport
PQport.argtypes = [PGconn_ptr]
PQport.restype = c_char_p
PQtty = pq.PQtty
PQtty.argtypes = [PGconn_ptr]
PQtty.restype = c_char_p
PQoptions = pq.PQoptions
PQoptions.argtypes = [PGconn_ptr]
PQoptions.restype = c_char_p
PQstatus = pq.PQstatus
PQstatus.argtypes = [PGconn_ptr]
PQstatus.restype = c_int
PQtransactionStatus = pq.PQtransactionStatus
PQtransactionStatus.argtypes = [PGconn_ptr]
PQtransactionStatus.restype = c_int
PQparameterStatus = pq.PQparameterStatus
PQparameterStatus.argtypes = [PGconn_ptr, c_char_p]
PQparameterStatus.restype = c_char_p
PQprotocolVersion = pq.PQprotocolVersion
PQprotocolVersion.argtypes = [PGconn_ptr]
PQprotocolVersion.restype = c_int
PQserverVersion = pq.PQserverVersion
PQserverVersion.argtypes = [PGconn_ptr]
PQserverVersion.restype = c_int
PQerrorMessage = pq.PQerrorMessage
PQerrorMessage.argtypes = [PGconn_ptr]
PQerrorMessage.restype = c_char_p
PQsocket = pq.PQsocket
PQsocket.argtypes = [PGconn_ptr]
PQsocket.restype = c_int
PQbackendPID = pq.PQbackendPID
PQbackendPID.argtypes = [PGconn_ptr]
PQbackendPID.restype = c_int
PQconnectionNeedsPassword = pq.PQconnectionNeedsPassword
PQconnectionNeedsPassword.argtypes = [PGconn_ptr]
PQconnectionNeedsPassword.restype = c_int
PQconnectionUsedPassword = pq.PQconnectionUsedPassword
PQconnectionUsedPassword.argtypes = [PGconn_ptr]
PQconnectionUsedPassword.restype = c_int
PQsslInUse = pq.PQsslInUse
PQsslInUse.argtypes = [PGconn_ptr]
PQsslInUse.restype = c_int
# TODO: PQsslAttribute, PQsslAttributeNames, PQsslStruct, PQgetssl
# 33.3. Command Execution Functions
PQexec = pq.PQexec
PQexec.argtypes = [PGconn_ptr, c_char_p]
PQexec.restype = PGresult_ptr
PQexecParams = pq.PQexecParams
PQexecParams.argtypes = [
PGconn_ptr,
c_char_p,
c_int,
POINTER(Oid),
POINTER(c_char_p),
POINTER(c_int),
POINTER(c_int),
c_int,
]
PQexecParams.restype = PGresult_ptr
PQprepare = pq.PQprepare
PQprepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)]
PQprepare.restype = PGresult_ptr
PQexecPrepared = pq.PQexecPrepared
PQexecPrepared.argtypes = [
PGconn_ptr,
c_char_p,
c_int,
POINTER(c_char_p),
POINTER(c_int),
POINTER(c_int),
c_int,
]
PQexecPrepared.restype = PGresult_ptr
PQdescribePrepared = pq.PQdescribePrepared
PQdescribePrepared.argtypes = [PGconn_ptr, c_char_p]
PQdescribePrepared.restype = PGresult_ptr
PQdescribePortal = pq.PQdescribePortal
PQdescribePortal.argtypes = [PGconn_ptr, c_char_p]
PQdescribePortal.restype = PGresult_ptr
PQresultStatus = pq.PQresultStatus
PQresultStatus.argtypes = [PGresult_ptr]
PQresultStatus.restype = c_int
# PQresStatus: not needed, we have pretty enums
PQresultErrorMessage = pq.PQresultErrorMessage
PQresultErrorMessage.argtypes = [PGresult_ptr]
PQresultErrorMessage.restype = c_char_p
# TODO: PQresultVerboseErrorMessage
PQresultErrorField = pq.PQresultErrorField
PQresultErrorField.argtypes = [PGresult_ptr, c_int]
PQresultErrorField.restype = c_char_p
PQclear = pq.PQclear
PQclear.argtypes = [PGresult_ptr]
PQclear.restype = None
# 33.3.2. Retrieving Query Result Information
# All the functions of this section receive a result pointer as first
# argument, optionally followed by one or two integer indexes.
PQntuples = pq.PQntuples
PQntuples.argtypes = [PGresult_ptr]
PQntuples.restype = c_int
PQnfields = pq.PQnfields
PQnfields.argtypes = [PGresult_ptr]
PQnfields.restype = c_int
PQfname = pq.PQfname
PQfname.argtypes = [PGresult_ptr, c_int]
PQfname.restype = c_char_p
# PQfnumber: useless and hard to use
PQftable = pq.PQftable
PQftable.argtypes = [PGresult_ptr, c_int]
PQftable.restype = Oid
PQftablecol = pq.PQftablecol
PQftablecol.argtypes = [PGresult_ptr, c_int]
PQftablecol.restype = c_int
PQfformat = pq.PQfformat
PQfformat.argtypes = [PGresult_ptr, c_int]
PQfformat.restype = c_int
PQftype = pq.PQftype
PQftype.argtypes = [PGresult_ptr, c_int]
PQftype.restype = Oid
PQfmod = pq.PQfmod
PQfmod.argtypes = [PGresult_ptr, c_int]
PQfmod.restype = c_int
PQfsize = pq.PQfsize
PQfsize.argtypes = [PGresult_ptr, c_int]
PQfsize.restype = c_int
PQbinaryTuples = pq.PQbinaryTuples
PQbinaryTuples.argtypes = [PGresult_ptr]
PQbinaryTuples.restype = c_int
PQgetvalue = pq.PQgetvalue
PQgetvalue.argtypes = [PGresult_ptr, c_int, c_int]
# Declared as a raw char pointer rather than c_char_p because the value is
# not a null-terminated string.
PQgetvalue.restype = POINTER(c_char) # not a null-terminated string
PQgetisnull = pq.PQgetisnull
PQgetisnull.argtypes = [PGresult_ptr, c_int, c_int]
PQgetisnull.restype = c_int
PQgetlength = pq.PQgetlength
PQgetlength.argtypes = [PGresult_ptr, c_int, c_int]
PQgetlength.restype = c_int
PQnparams = pq.PQnparams
PQnparams.argtypes = [PGresult_ptr]
PQnparams.restype = c_int
PQparamtype = pq.PQparamtype
PQparamtype.argtypes = [PGresult_ptr, c_int]
PQparamtype.restype = Oid
# PQprint: pretty useless
# 33.3.3. Retrieving Other Result Information
PQcmdStatus = pq.PQcmdStatus
PQcmdStatus.argtypes = [PGresult_ptr]
PQcmdStatus.restype = c_char_p
PQcmdTuples = pq.PQcmdTuples
PQcmdTuples.argtypes = [PGresult_ptr]
PQcmdTuples.restype = c_char_p
PQoidValue = pq.PQoidValue
PQoidValue.argtypes = [PGresult_ptr]
PQoidValue.restype = Oid
# 33.3.4. Escaping Strings for Inclusion in SQL Commands
PQescapeLiteral = pq.PQescapeLiteral
PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t]
# NOTE(review): returned as a raw char pointer instead of c_char_p, presumably
# so that the caller can release the buffer with PQfreemem - confirm in caller.
PQescapeLiteral.restype = POINTER(c_char)
PQescapeIdentifier = pq.PQescapeIdentifier
PQescapeIdentifier.argtypes = [PGconn_ptr, c_char_p, c_size_t]
PQescapeIdentifier.restype = POINTER(c_char)
PQescapeStringConn = pq.PQescapeStringConn
# The argument types are deliberately left undeclared (see the TODO): only
# the return type is set for this function.
# TODO: raises "wrong type" error
# PQescapeStringConn.argtypes = [
# PGconn_ptr, c_char_p, c_char_p, c_size_t, POINTER(c_int)
# ]
PQescapeStringConn.restype = c_size_t
PQescapeString = pq.PQescapeString
# Same as above: argtypes undeclared, only restype is set.
# TODO: raises "wrong type" error
# PQescapeString.argtypes = [c_char_p, c_char_p, c_size_t]
PQescapeString.restype = c_size_t
PQescapeByteaConn = pq.PQescapeByteaConn
PQescapeByteaConn.argtypes = [
PGconn_ptr,
POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
c_size_t,
POINTER(c_size_t),
]
PQescapeByteaConn.restype = POINTER(c_ubyte)
PQescapeBytea = pq.PQescapeBytea
PQescapeBytea.argtypes = [
POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
c_size_t,
POINTER(c_size_t),
]
PQescapeBytea.restype = POINTER(c_ubyte)
PQunescapeBytea = pq.PQunescapeBytea
PQunescapeBytea.argtypes = [
POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
POINTER(c_size_t),
]
PQunescapeBytea.restype = POINTER(c_ubyte)
# 33.4. Asynchronous Command Processing
# Every function here takes the connection pointer first; all of them are
# declared as returning a C int, except PQgetResult which returns a result
# pointer.
PQsendQuery = pq.PQsendQuery
PQsendQuery.argtypes = [PGconn_ptr, c_char_p]
PQsendQuery.restype = c_int
PQsendQueryParams = pq.PQsendQueryParams
PQsendQueryParams.argtypes = [
PGconn_ptr,
c_char_p,
c_int,
POINTER(Oid),
POINTER(c_char_p),
POINTER(c_int),
POINTER(c_int),
c_int,
]
PQsendQueryParams.restype = c_int
PQsendPrepare = pq.PQsendPrepare
PQsendPrepare.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_int, POINTER(Oid)]
PQsendPrepare.restype = c_int
PQsendQueryPrepared = pq.PQsendQueryPrepared
# Same layout as PQsendQueryParams above, minus the POINTER(Oid) argument
PQsendQueryPrepared.argtypes = [
PGconn_ptr,
c_char_p,
c_int,
POINTER(c_char_p),
POINTER(c_int),
POINTER(c_int),
c_int,
]
PQsendQueryPrepared.restype = c_int
PQsendDescribePrepared = pq.PQsendDescribePrepared
PQsendDescribePrepared.argtypes = [PGconn_ptr, c_char_p]
PQsendDescribePrepared.restype = c_int
PQsendDescribePortal = pq.PQsendDescribePortal
PQsendDescribePortal.argtypes = [PGconn_ptr, c_char_p]
PQsendDescribePortal.restype = c_int
PQgetResult = pq.PQgetResult
PQgetResult.argtypes = [PGconn_ptr]
PQgetResult.restype = PGresult_ptr
PQconsumeInput = pq.PQconsumeInput
PQconsumeInput.argtypes = [PGconn_ptr]
PQconsumeInput.restype = c_int
PQisBusy = pq.PQisBusy
PQisBusy.argtypes = [PGconn_ptr]
PQisBusy.restype = c_int
PQsetnonblocking = pq.PQsetnonblocking
PQsetnonblocking.argtypes = [PGconn_ptr, c_int]
PQsetnonblocking.restype = c_int
PQisnonblocking = pq.PQisnonblocking
PQisnonblocking.argtypes = [PGconn_ptr]
PQisnonblocking.restype = c_int
PQflush = pq.PQflush
PQflush.argtypes = [PGconn_ptr]
PQflush.restype = c_int
# 33.5. Retrieving Query Results Row-by-Row
PQsetSingleRowMode = pq.PQsetSingleRowMode
PQsetSingleRowMode.argtypes = [PGconn_ptr]
PQsetSingleRowMode.restype = c_int
# 33.6. Canceling Queries in Progress
# A cancel pointer is obtained from a connection pointer (PQgetCancel) and is
# the argument type accepted by PQfreeCancel.
PQgetCancel = pq.PQgetCancel
PQgetCancel.argtypes = [PGconn_ptr]
PQgetCancel.restype = PGcancel_ptr
PQfreeCancel = pq.PQfreeCancel
PQfreeCancel.argtypes = [PGcancel_ptr]
PQfreeCancel.restype = None
PQcancel = pq.PQcancel
# The argument types are deliberately left undeclared (see the TODO): only
# the return type is set for this function.
# TODO: raises "wrong type" error
# PQcancel.argtypes = [PGcancel_ptr, POINTER(c_char), c_int]
PQcancel.restype = c_int
# 33.8. Asynchronous Notification
PQnotifies = pq.PQnotifies
PQnotifies.argtypes = [PGconn_ptr]
PQnotifies.restype = PGnotify_ptr
# 33.9. Functions Associated with the COPY Command
PQputCopyData = pq.PQputCopyData
PQputCopyData.argtypes = [PGconn_ptr, c_char_p, c_int]
PQputCopyData.restype = c_int
PQputCopyEnd = pq.PQputCopyEnd
PQputCopyEnd.argtypes = [PGconn_ptr, c_char_p]
PQputCopyEnd.restype = c_int
PQgetCopyData = pq.PQgetCopyData
# The second argument is a pointer to a char pointer
PQgetCopyData.argtypes = [PGconn_ptr, POINTER(c_char_p), c_int]
PQgetCopyData.restype = c_int
# 33.10. Control Functions

PQtrace = pq.PQtrace
PQtrace.argtypes = [PGconn_ptr, FILE_ptr]
PQtrace.restype = None

# The raw binding stays None unless the loaded libpq is recent enough; the
# public wrapper below turns that into a NotSupportedError.
_PQsetTraceFlags = None

if libpq_version >= 140000:
    _PQsetTraceFlags = pq.PQsetTraceFlags
    _PQsetTraceFlags.argtypes = [PGconn_ptr, c_int]
    _PQsetTraceFlags.restype = None


def PQsetTraceFlags(pgconn: PGconn_struct, flags: int) -> None:
    """
    Forward the call to the libpq function of the same name, if available.

    Raise `NotSupportedError` if the function was not bound, which happens
    when `libpq_version` is lower than 140000.
    """
    if _PQsetTraceFlags:
        _PQsetTraceFlags(pgconn, flags)
        return

    raise NotSupportedError(
        "PQsetTraceFlags requires libpq from PostgreSQL 14,"
        f" {libpq_version} available instead"
    )


PQuntrace = pq.PQuntrace
PQuntrace.argtypes = [PGconn_ptr]
PQuntrace.restype = None
# 33.11. Miscellaneous Functions

PQfreemem = pq.PQfreemem
PQfreemem.argtypes = [c_void_p]
PQfreemem.restype = None

# Always bind the private name, like the other version-dependent bindings of
# this module do: without it, on a libpq older than 10 the wrapper below would
# fail with NameError instead of raising the intended NotSupportedError.
_PQencryptPasswordConn = None

if libpq_version >= 100000:
    _PQencryptPasswordConn = pq.PQencryptPasswordConn
    _PQencryptPasswordConn.argtypes = [
        PGconn_ptr,
        c_char_p,
        c_char_p,
        c_char_p,
    ]
    _PQencryptPasswordConn.restype = POINTER(c_char)


def PQencryptPasswordConn(
    pgconn: PGconn_struct, passwd: bytes, user: bytes, algorithm: bytes
) -> Optional[bytes]:
    """
    Forward the call to the libpq function of the same name, if available.

    :param pgconn: the connection to pass to libpq.
    :param passwd: the password, as bytes.
    :param user: the user name, as bytes.
    :param algorithm: the algorithm name, as bytes.
    :return: whatever the underlying libpq function returns.
    :raise NotSupportedError: if the function was not bound, which happens
        when `libpq_version` is lower than 100000.
    """
    if not _PQencryptPasswordConn:
        raise NotSupportedError(
            "PQencryptPasswordConn requires libpq from PostgreSQL 10,"
            f" {libpq_version} available instead"
        )

    return _PQencryptPasswordConn(pgconn, passwd, user, algorithm)
# Functions creating or altering a result pointer
PQmakeEmptyPGresult = pq.PQmakeEmptyPGresult
PQmakeEmptyPGresult.argtypes = [PGconn_ptr, c_int]
PQmakeEmptyPGresult.restype = PGresult_ptr
PQsetResultAttrs = pq.PQsetResultAttrs
PQsetResultAttrs.argtypes = [PGresult_ptr, c_int, PGresAttDesc_ptr]
PQsetResultAttrs.restype = c_int
# 33.12. Notice Processing
# Prototype of the callback: it returns nothing and receives a void pointer
# and a result pointer.
PQnoticeReceiver = CFUNCTYPE(None, c_void_p, PGresult_ptr)
PQsetNoticeReceiver = pq.PQsetNoticeReceiver
PQsetNoticeReceiver.argtypes = [PGconn_ptr, PQnoticeReceiver, c_void_p]
# The function returns a callback of the same prototype
PQsetNoticeReceiver.restype = PQnoticeReceiver
# 34.5 Pipeline Mode

# The raw bindings stay None unless the loaded libpq is recent enough; the
# public wrappers below turn a missing binding into a NotSupportedError.
_PQpipelineStatus = None
_PQenterPipelineMode = None
_PQexitPipelineMode = None
_PQpipelineSync = None
_PQsendFlushRequest = None

if libpq_version >= 140000:
    _PQpipelineStatus = pq.PQpipelineStatus
    _PQpipelineStatus.argtypes = [PGconn_ptr]
    _PQpipelineStatus.restype = c_int

    _PQenterPipelineMode = pq.PQenterPipelineMode
    _PQenterPipelineMode.argtypes = [PGconn_ptr]
    _PQenterPipelineMode.restype = c_int

    _PQexitPipelineMode = pq.PQexitPipelineMode
    _PQexitPipelineMode.argtypes = [PGconn_ptr]
    _PQexitPipelineMode.restype = c_int

    _PQpipelineSync = pq.PQpipelineSync
    _PQpipelineSync.argtypes = [PGconn_ptr]
    _PQpipelineSync.restype = c_int

    _PQsendFlushRequest = pq.PQsendFlushRequest
    _PQsendFlushRequest.argtypes = [PGconn_ptr]
    _PQsendFlushRequest.restype = c_int


def PQpipelineStatus(pgconn: PGconn_struct) -> int:
    """Call libpq's PQpipelineStatus, or raise NotSupportedError if unbound."""
    if _PQpipelineStatus:
        return _PQpipelineStatus(pgconn)

    raise NotSupportedError(
        "PQpipelineStatus requires libpq from PostgreSQL 14,"
        f" {libpq_version} available instead"
    )


def PQenterPipelineMode(pgconn: PGconn_struct) -> int:
    """Call libpq's PQenterPipelineMode, or raise NotSupportedError if unbound."""
    if _PQenterPipelineMode:
        return _PQenterPipelineMode(pgconn)

    raise NotSupportedError(
        "PQenterPipelineMode requires libpq from PostgreSQL 14,"
        f" {libpq_version} available instead"
    )


def PQexitPipelineMode(pgconn: PGconn_struct) -> int:
    """Call libpq's PQexitPipelineMode, or raise NotSupportedError if unbound."""
    if _PQexitPipelineMode:
        return _PQexitPipelineMode(pgconn)

    raise NotSupportedError(
        "PQexitPipelineMode requires libpq from PostgreSQL 14,"
        f" {libpq_version} available instead"
    )


def PQpipelineSync(pgconn: PGconn_struct) -> int:
    """Call libpq's PQpipelineSync, or raise NotSupportedError if unbound."""
    if _PQpipelineSync:
        return _PQpipelineSync(pgconn)

    raise NotSupportedError(
        "PQpipelineSync requires libpq from PostgreSQL 14,"
        f" {libpq_version} available instead"
    )


def PQsendFlushRequest(pgconn: PGconn_struct) -> int:
    """Call libpq's PQsendFlushRequest, or raise NotSupportedError if unbound."""
    if _PQsendFlushRequest:
        return _PQsendFlushRequest(pgconn)

    raise NotSupportedError(
        "PQsendFlushRequest requires libpq from PostgreSQL 14,"
        f" {libpq_version} available instead"
    )
# 33.18. SSL Support
# Takes two C ints and is declared as returning nothing
PQinitOpenSSL = pq.PQinitOpenSSL
PQinitOpenSSL.argtypes = [c_int, c_int]
PQinitOpenSSL.restype = None
def generate_stub() -> None:
    """
    Regenerate the autogenerated section of the type stub of this module.

    The stub is the file whose name is this module's file name followed by
    "i". The lines between the two ``# autogenerated: start`` / ``end``
    markers are replaced by one signature for every ctypes function found in
    the module globals, except the ones for which a ``def`` already appears
    in the stub before the start marker (the hand-written signatures).

    :raise ValueError: if the stub doesn't contain exactly two marker lines
        (the failure comes from the unpacking of the markers positions).
    :raise AssertionError: if a function uses a ctypes type that the
        generator doesn't know how to represent.
    """
    import re
    from ctypes import _CFuncPtr  # type: ignore

    def type2str(fname, narg, t):
        """
        Return the annotation to use in the stub for the ctypes type `t`.

        `fname` is the function name, only used in the error message; `narg`
        is the argument position, or None if `t` is the return type.
        """
        if t is None:
            return "None"

        elif t is c_void_p:
            return "Any"

        elif t is c_int or t is c_uint or t is c_size_t:
            return "int"

        elif t is c_char_p or t.__name__ == "LP_c_char":
            # Strings are `bytes` as arguments, optional as return values
            if narg is not None:
                return "bytes"
            else:
                return "Optional[bytes]"

        elif t.__name__ in (
            "LP_PGconn_struct",
            "LP_PGresult_struct",
            "LP_PGcancel_struct",
        ):
            # Drop the "LP_" prefix to obtain the name of the struct class.
            # Struct pointers are optional as arguments, not as return values.
            if narg is not None:
                return f"Optional[{t.__name__[3:]}]"
            else:
                return t.__name__[3:]

        elif t.__name__ in ("LP_PQconninfoOption_struct",):
            return f"Sequence[{t.__name__[3:]}]"

        elif t.__name__ in (
            "LP_c_ubyte",
            "LP_c_char_p",
            "LP_c_int",
            "LP_c_uint",
            "LP_c_ulong",
            "LP_FILE",
        ):
            return f"_Pointer[{t.__name__[3:]}]"

        else:
            # Raise explicitly instead of using `assert False`: an assert is
            # stripped by `python -O`, in which case the function would return
            # None and a bogus "None" annotation would be written to the stub.
            raise AssertionError(f"can't deal with {t} in {fname}")

    fn = __file__ + "i"
    with open(fn) as f:
        lines = f.read().splitlines()

    # Positions of the two marker lines delimiting the section to replace
    istart, iend = (
        i
        for i, line in enumerate(lines)
        if re.match(r"\s*#\s*autogenerated:\s+(start|end)", line)
    )

    # Functions with a hand-written signature before the start marker
    known = {
        line[4:].split("(", 1)[0] for line in lines[:istart] if line.startswith("def ")
    }

    signatures = []

    for name, obj in globals().items():
        if name in known:
            continue
        if not isinstance(obj, _CFuncPtr):
            continue

        params = []
        for i, t in enumerate(obj.argtypes):
            params.append(f"arg{i + 1}: {type2str(name, i, t)}")

        resname = type2str(name, None, obj.restype)

        signatures.append(f"def {name}({', '.join(params)}) -> {resname}: ...")

    # Replace what is between the markers, leaving the markers in place
    lines[istart + 1 : iend] = signatures

    with open(fn, "w") as f:
        f.write("\n".join(lines))
        f.write("\n")


if __name__ == "__main__":
    generate_stub()

View File

@@ -0,0 +1,216 @@
"""
types stub for ctypes functions
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Any, Callable, Optional, Sequence
from ctypes import Array, pointer, _Pointer
from ctypes import c_char, c_char_p, c_int, c_ubyte, c_uint, c_ulong
# Placeholder classes: no attribute is declared on them, they only give the
# signatures below a name to refer to.
class FILE: ...
def fdopen(fd: int, mode: bytes) -> _Pointer[FILE]: ... # type: ignore[type-var]
# Oids are typed as unsigned integers
Oid = c_uint
class PGconn_struct: ...
class PGresult_struct: ...
class PGcancel_struct: ...
class PQconninfoOption_struct:
    # Attributes exposed to the type checker: six bytes values and an int.
    keyword: bytes
    envvar: bytes
    compiled: bytes
    val: bytes
    label: bytes
    dispchar: bytes
    dispsize: int
class PGnotify_struct:
    # Attributes exposed to the type checker.
    be_pid: int
    relname: bytes
    extra: bytes
class PGresAttDesc_struct:
    # Attributes exposed to the type checker: a bytes name followed by six
    # integer values.
    name: bytes
    tableid: int
    columnid: int
    format: int
    typid: int
    typlen: int
    atttypmod: int
# Hand-written signatures. The stub generator skips every function having a
# `def` in this part of the file, so these are never overwritten.
def PQhostaddr(arg1: Optional[PGconn_struct]) -> bytes: ...
def PQerrorMessage(arg1: Optional[PGconn_struct]) -> bytes: ...
def PQresultErrorMessage(arg1: Optional[PGresult_struct]) -> bytes: ...
# Array arguments are typed as optional ctypes arrays here
def PQexecPrepared(
arg1: Optional[PGconn_struct],
arg2: bytes,
arg3: int,
arg4: Optional[Array[c_char_p]],
arg5: Optional[Array[c_int]],
arg6: Optional[Array[c_int]],
arg7: int,
) -> PGresult_struct: ...
def PQprepare(
arg1: Optional[PGconn_struct],
arg2: bytes,
arg3: bytes,
arg4: int,
arg5: Optional[Array[c_uint]],
) -> PGresult_struct: ...
# Returns a pointer to chars rather than bytes
def PQgetvalue(
arg1: Optional[PGresult_struct], arg2: int, arg3: int
) -> _Pointer[c_char]: ...
def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ...
# NOTE(review): the implementation leaves `argtypes` undeclared for
# PQescapeStringConn, PQescapeString and PQcancel, which is presumably why
# their signatures are written by hand - confirm.
def PQescapeStringConn(
arg1: Optional[PGconn_struct],
arg2: c_char_p,
arg3: bytes,
arg4: int,
arg5: _Pointer[c_int],
) -> int: ...
def PQescapeString(arg1: c_char_p, arg2: bytes, arg3: int) -> int: ...
def PQsendPrepare(
arg1: Optional[PGconn_struct],
arg2: bytes,
arg3: bytes,
arg4: int,
arg5: Optional[Array[c_uint]],
) -> int: ...
def PQsendQueryPrepared(
arg1: Optional[PGconn_struct],
arg2: bytes,
arg3: int,
arg4: Optional[Array[c_char_p]],
arg5: Optional[Array[c_int]],
arg6: Optional[Array[c_int]],
arg7: int,
) -> int: ...
def PQcancel(arg1: Optional[PGcancel_struct], arg2: c_char_p, arg3: int) -> int: ...
# NOTE(review): the callback is typed as taking one argument and returning a
# result, whereas the implementation declares the callback prototype as
# CFUNCTYPE(None, c_void_p, PGresult_ptr), i.e. two arguments and no return
# value. The annotation looks inaccurate - confirm before relying on it.
def PQsetNoticeReceiver(
arg1: PGconn_struct, arg2: Callable[[Any], PGresult_struct], arg3: Any
) -> Callable[[Any], PGresult_struct]: ...
# TODO: Ignoring type as getting an error on mypy/ctypes:
# Type argument "psycopg.pq._pq_ctypes.PGnotify_struct" of "pointer" must be
# a subtype of "ctypes._CData"
def PQnotifies(
arg1: Optional[PGconn_struct],
) -> Optional[_Pointer[PGnotify_struct]]: ... # type: ignore
# Unlike in PQputCopyData, the bytes argument is optional here
def PQputCopyEnd(arg1: Optional[PGconn_struct], arg2: Optional[bytes]) -> int: ...
# Arg 2 is a _Pointer, reported as _CArgObject by mypy
def PQgetCopyData(arg1: Optional[PGconn_struct], arg2: Any, arg3: int) -> int: ...
def PQsetResultAttrs(
arg1: Optional[PGresult_struct],
arg2: int,
arg3: Array[PGresAttDesc_struct], # type: ignore
) -> int: ...
def PQtrace(
arg1: Optional[PGconn_struct],
arg2: _Pointer[FILE], # type: ignore[type-var]
) -> None: ...
# The functions below are implemented as Python wrappers raising
# NotSupportedError when the libpq function is not available; the raw ctypes
# functions they call have an underscore-prefixed name.
def PQencryptPasswordConn(
arg1: Optional[PGconn_struct],
arg2: bytes,
arg3: bytes,
arg4: Optional[bytes],
) -> bytes: ...
def PQpipelineStatus(pgconn: Optional[PGconn_struct]) -> int: ...
def PQenterPipelineMode(pgconn: Optional[PGconn_struct]) -> int: ...
def PQexitPipelineMode(pgconn: Optional[PGconn_struct]) -> int: ...
def PQpipelineSync(pgconn: Optional[PGconn_struct]) -> int: ...
def PQsendFlushRequest(pgconn: Optional[PGconn_struct]) -> int: ...
# Everything between the two marker comments below is rewritten by the stub
# generator of the implementation module: don't edit it by hand, and don't
# add other comments looking like the markers.
# fmt: off
# autogenerated: start
def PQlibVersion() -> int: ...
def PQconnectdb(arg1: bytes) -> PGconn_struct: ...
def PQconnectStart(arg1: bytes) -> PGconn_struct: ...
def PQconnectPoll(arg1: Optional[PGconn_struct]) -> int: ...
def PQconndefaults() -> Sequence[PQconninfoOption_struct]: ...
def PQconninfoFree(arg1: Sequence[PQconninfoOption_struct]) -> None: ...
def PQconninfo(arg1: Optional[PGconn_struct]) -> Sequence[PQconninfoOption_struct]: ...
def PQconninfoParse(arg1: bytes, arg2: _Pointer[c_char_p]) -> Sequence[PQconninfoOption_struct]: ...
def PQfinish(arg1: Optional[PGconn_struct]) -> None: ...
def PQreset(arg1: Optional[PGconn_struct]) -> None: ...
def PQresetStart(arg1: Optional[PGconn_struct]) -> int: ...
def PQresetPoll(arg1: Optional[PGconn_struct]) -> int: ...
def PQping(arg1: bytes) -> int: ...
def PQdb(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
def PQuser(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
def PQpass(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
def PQhost(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
def _PQhostaddr(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
def PQport(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
def PQtty(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
def PQoptions(arg1: Optional[PGconn_struct]) -> Optional[bytes]: ...
def PQstatus(arg1: Optional[PGconn_struct]) -> int: ...
def PQtransactionStatus(arg1: Optional[PGconn_struct]) -> int: ...
def PQparameterStatus(arg1: Optional[PGconn_struct], arg2: bytes) -> Optional[bytes]: ...
def PQprotocolVersion(arg1: Optional[PGconn_struct]) -> int: ...
def PQserverVersion(arg1: Optional[PGconn_struct]) -> int: ...
def PQsocket(arg1: Optional[PGconn_struct]) -> int: ...
def PQbackendPID(arg1: Optional[PGconn_struct]) -> int: ...
def PQconnectionNeedsPassword(arg1: Optional[PGconn_struct]) -> int: ...
def PQconnectionUsedPassword(arg1: Optional[PGconn_struct]) -> int: ...
def PQsslInUse(arg1: Optional[PGconn_struct]) -> int: ...
def PQexec(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
def PQexecParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_uint], arg5: _Pointer[c_char_p], arg6: _Pointer[c_int], arg7: _Pointer[c_int], arg8: int) -> PGresult_struct: ...
def PQdescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
def PQdescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
def PQresultStatus(arg1: Optional[PGresult_struct]) -> int: ...
def PQresultErrorField(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ...
def PQclear(arg1: Optional[PGresult_struct]) -> None: ...
def PQntuples(arg1: Optional[PGresult_struct]) -> int: ...
def PQnfields(arg1: Optional[PGresult_struct]) -> int: ...
def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> Optional[bytes]: ...
def PQftable(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
def PQftablecol(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
def PQfformat(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
def PQftype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
def PQfmod(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
def PQfsize(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
def PQbinaryTuples(arg1: Optional[PGresult_struct]) -> int: ...
def PQgetisnull(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
def PQgetlength(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
def PQnparams(arg1: Optional[PGresult_struct]) -> int: ...
def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ...
def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ...
def PQescapeLiteral(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ...
def PQescapeIdentifier(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ...
def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ...
def PQescapeBytea(arg1: bytes, arg2: int, arg3: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ...
def PQunescapeBytea(arg1: bytes, arg2: _Pointer[c_ulong]) -> _Pointer[c_ubyte]: ...
def PQsendQuery(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ...
def PQsendQueryParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: _Pointer[c_uint], arg5: _Pointer[c_char_p], arg6: _Pointer[c_int], arg7: _Pointer[c_int], arg8: int) -> int: ...
def PQsendDescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ...
def PQsendDescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ...
def PQgetResult(arg1: Optional[PGconn_struct]) -> PGresult_struct: ...
def PQconsumeInput(arg1: Optional[PGconn_struct]) -> int: ...
def PQisBusy(arg1: Optional[PGconn_struct]) -> int: ...
def PQsetnonblocking(arg1: Optional[PGconn_struct], arg2: int) -> int: ...
def PQisnonblocking(arg1: Optional[PGconn_struct]) -> int: ...
def PQflush(arg1: Optional[PGconn_struct]) -> int: ...
def PQsetSingleRowMode(arg1: Optional[PGconn_struct]) -> int: ...
def PQgetCancel(arg1: Optional[PGconn_struct]) -> PGcancel_struct: ...
def PQfreeCancel(arg1: Optional[PGcancel_struct]) -> None: ...
def PQputCopyData(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> int: ...
def PQsetTraceFlags(arg1: Optional[PGconn_struct], arg2: int) -> None: ...
def PQuntrace(arg1: Optional[PGconn_struct]) -> None: ...
def PQfreemem(arg1: Any) -> None: ...
def _PQencryptPasswordConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: bytes, arg4: bytes) -> Optional[bytes]: ...
def PQmakeEmptyPGresult(arg1: Optional[PGconn_struct], arg2: int) -> PGresult_struct: ...
def _PQpipelineStatus(arg1: Optional[PGconn_struct]) -> int: ...
def _PQenterPipelineMode(arg1: Optional[PGconn_struct]) -> int: ...
def _PQexitPipelineMode(arg1: Optional[PGconn_struct]) -> int: ...
def _PQpipelineSync(arg1: Optional[PGconn_struct]) -> int: ...
def _PQsendFlushRequest(arg1: Optional[PGconn_struct]) -> int: ...
def PQinitOpenSSL(arg1: int, arg2: int) -> None: ...
# autogenerated: end
# fmt: on
# vim: set syntax=python:

View File

@@ -0,0 +1,291 @@
"""
Protocol objects to represent objects exposed by different pq implementations.
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Any, Callable, List, Optional, Sequence, Tuple
from typing import Union, TYPE_CHECKING
from typing_extensions import TypeAlias
from ._enums import Format, Trace
from .._compat import Protocol
if TYPE_CHECKING:
from .misc import PGnotify, ConninfoOption, PGresAttDesc
# An object implementing the buffer protocol (ish)
# Only the three concrete types listed in the Union are accepted.
Buffer: TypeAlias = Union[bytes, bytearray, memoryview]
class PGconn(Protocol):
    """
    Protocol describing the connection object exposed by a pq implementation.

    Only the interface is declared here: every implementation must provide
    the attributes, properties and methods listed, with compatible signatures.
    """

    # Optional callbacks, receiving respectively a result and a notification
    notice_handler: Optional[Callable[["PGresult"], None]]
    notify_handler: Optional[Callable[["PGnotify"], None]]

    @classmethod
    def connect(cls, conninfo: bytes) -> "PGconn": ...

    @classmethod
    def connect_start(cls, conninfo: bytes) -> "PGconn": ...

    def connect_poll(self) -> int: ...

    def finish(self) -> None: ...

    @property
    def info(self) -> List["ConninfoOption"]: ...

    def reset(self) -> None: ...

    def reset_start(self) -> None: ...

    def reset_poll(self) -> int: ...

    # The first parameter of a classmethod is the class: name it `cls`, as in
    # `connect()` and `connect_start()` above.
    @classmethod
    def ping(cls, conninfo: bytes) -> int: ...

    @property
    def db(self) -> bytes: ...

    @property
    def user(self) -> bytes: ...

    @property
    def password(self) -> bytes: ...

    @property
    def host(self) -> bytes: ...

    @property
    def hostaddr(self) -> bytes: ...

    @property
    def port(self) -> bytes: ...

    @property
    def tty(self) -> bytes: ...

    @property
    def options(self) -> bytes: ...

    @property
    def status(self) -> int: ...

    @property
    def transaction_status(self) -> int: ...

    def parameter_status(self, name: bytes) -> Optional[bytes]: ...

    @property
    def error_message(self) -> bytes: ...

    @property
    def server_version(self) -> int: ...

    @property
    def socket(self) -> int: ...

    @property
    def backend_pid(self) -> int: ...

    @property
    def needs_password(self) -> bool: ...

    @property
    def used_password(self) -> bool: ...

    @property
    def ssl_in_use(self) -> bool: ...

    def exec_(self, command: bytes) -> "PGresult": ...

    def send_query(self, command: bytes) -> None: ...

    def exec_params(
        self,
        command: bytes,
        param_values: Optional[Sequence[Optional[Buffer]]],
        param_types: Optional[Sequence[int]] = None,
        param_formats: Optional[Sequence[int]] = None,
        result_format: int = Format.TEXT,
    ) -> "PGresult": ...

    def send_query_params(
        self,
        command: bytes,
        param_values: Optional[Sequence[Optional[Buffer]]],
        param_types: Optional[Sequence[int]] = None,
        param_formats: Optional[Sequence[int]] = None,
        result_format: int = Format.TEXT,
    ) -> None: ...

    def send_prepare(
        self,
        name: bytes,
        command: bytes,
        param_types: Optional[Sequence[int]] = None,
    ) -> None: ...

    def send_query_prepared(
        self,
        name: bytes,
        param_values: Optional[Sequence[Optional[Buffer]]],
        param_formats: Optional[Sequence[int]] = None,
        result_format: int = Format.TEXT,
    ) -> None: ...

    def prepare(
        self,
        name: bytes,
        command: bytes,
        param_types: Optional[Sequence[int]] = None,
    ) -> "PGresult": ...

    # NOTE(review): unlike the other methods taking parameters, the items of
    # `param_values` are not Optional here and the `result_format` default is
    # a literal 0 instead of `Format.TEXT`. Left untouched, as the
    # implementations must stay compatible with this signature - confirm
    # whether the difference is intentional.
    def exec_prepared(
        self,
        name: bytes,
        param_values: Optional[Sequence[Buffer]],
        param_formats: Optional[Sequence[int]] = None,
        result_format: int = 0,
    ) -> "PGresult": ...

    def describe_prepared(self, name: bytes) -> "PGresult": ...

    def send_describe_prepared(self, name: bytes) -> None: ...

    def describe_portal(self, name: bytes) -> "PGresult": ...

    def send_describe_portal(self, name: bytes) -> None: ...

    def get_result(self) -> Optional["PGresult"]: ...

    def consume_input(self) -> None: ...

    def is_busy(self) -> int: ...

    @property
    def nonblocking(self) -> int: ...

    @nonblocking.setter
    def nonblocking(self, arg: int) -> None: ...

    def flush(self) -> int: ...

    def set_single_row_mode(self) -> None: ...

    def get_cancel(self) -> "PGcancel": ...

    def notifies(self) -> Optional["PGnotify"]: ...

    def put_copy_data(self, buffer: Buffer) -> int: ...

    def put_copy_end(self, error: Optional[bytes] = None) -> int: ...

    def get_copy_data(self, async_: int) -> Tuple[int, memoryview]: ...

    def trace(self, fileno: int) -> None: ...

    def set_trace_flags(self, flags: Trace) -> None: ...

    def untrace(self) -> None: ...

    def encrypt_password(
        self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None
    ) -> bytes: ...

    def make_empty_result(self, exec_status: int) -> "PGresult": ...

    @property
    def pipeline_status(self) -> int: ...

    def enter_pipeline_mode(self) -> None: ...

    def exit_pipeline_mode(self) -> None: ...

    def pipeline_sync(self) -> None: ...

    def send_flush_request(self) -> None: ...
class PGresult(Protocol):
    """
    Protocol describing the result object exposed by a pq implementation.
    """

    def clear(self) -> None: ...

    @property
    def status(self) -> int: ...

    @property
    def error_message(self) -> bytes: ...

    def error_field(self, fieldcode: int) -> Optional[bytes]: ...

    @property
    def ntuples(self) -> int: ...

    @property
    def nfields(self) -> int: ...

    # Accessors taking a column number

    def fname(self, column_number: int) -> Optional[bytes]: ...

    def ftable(self, column_number: int) -> int: ...

    def ftablecol(self, column_number: int) -> int: ...

    def fformat(self, column_number: int) -> int: ...

    def ftype(self, column_number: int) -> int: ...

    def fmod(self, column_number: int) -> int: ...

    def fsize(self, column_number: int) -> int: ...

    @property
    def binary_tuples(self) -> int: ...

    def get_value(self, row_number: int, column_number: int) -> Optional[bytes]: ...

    @property
    def nparams(self) -> int: ...

    def param_type(self, param_number: int) -> int: ...

    @property
    def command_status(self) -> Optional[bytes]: ...

    @property
    def command_tuples(self) -> Optional[int]: ...

    @property
    def oid_value(self) -> int: ...

    def set_attributes(self, descriptions: List["PGresAttDesc"]) -> None: ...
class PGcancel(Protocol):
    """
    Protocol describing the cancel object exposed by a pq implementation.
    """

    def free(self) -> None: ...

    def cancel(self) -> None: ...
class Conninfo(Protocol):
    """
    Protocol describing the conninfo helper exposed by a pq implementation.

    All the methods are class methods returning a list of `ConninfoOption`.
    """

    @classmethod
    def get_defaults(cls) -> List["ConninfoOption"]: ...

    @classmethod
    def parse(cls, conninfo: bytes) -> List["ConninfoOption"]: ...

    @classmethod
    def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]: ...
class Escaping(Protocol):
    """
    Protocol describing the escaping helper exposed by a pq implementation.

    The object is created with an optional connection; every method takes a
    buffer and returns bytes.
    """

    def __init__(self, conn: Optional[PGconn] = None): ...

    def escape_literal(self, data: Buffer) -> bytes: ...

    def escape_identifier(self, data: Buffer) -> bytes: ...

    def escape_string(self, data: Buffer) -> bytes: ...

    def escape_bytea(self, data: Buffer) -> bytes: ...

    def unescape_bytea(self, data: Buffer) -> bytes: ...

View File

@@ -0,0 +1,146 @@
"""
Various functionalities to make easier to work with the libpq.
"""
# Copyright (C) 2020 The Psycopg Team
import os
import sys
import logging
import ctypes.util
from typing import cast, NamedTuple, Optional, Union
from .abc import PGconn, PGresult
from ._enums import ConnStatus, TransactionStatus, PipelineStatus
from .._compat import cache
from .._encodings import pgconn_encoding
# Logger used by the helpers of this module
logger = logging.getLogger("psycopg.pq")
# Shorthand for the status that the functions below compare `pgconn.status` to
OK = ConnStatus.OK
class PGnotify(NamedTuple):
    """Named tuple carrying the fields of a notification."""

    relname: bytes
    be_pid: int
    extra: bytes
class ConninfoOption(NamedTuple):
    """Named tuple carrying the fields of a connection option.

    The `envvar`, `compiled` and `val` fields may be `None`.
    """

    keyword: bytes
    envvar: Optional[bytes]
    compiled: Optional[bytes]
    val: Optional[bytes]
    label: bytes
    dispchar: bytes
    dispsize: int
class PGresAttDesc(NamedTuple):
    """Named tuple carrying the description of a result attribute."""

    name: bytes
    tableid: int
    columnid: int
    format: int
    typid: int
    typlen: int
    atttypmod: int
@cache
def find_libpq_full_path() -> Optional[str]:
    """
    Return the libpq library as located by `ctypes.util.find_library()`.

    The library name searched depends on `sys.platform`. On macOS only, if
    the search fails, fall back on the directory reported by
    ``pg_config --libdir``, provided that it contains a ``libpq.dylib`` file.

    Return `None` if the library is not found. The result is cached.
    """
    platform = sys.platform

    if platform == "win32":
        return ctypes.util.find_library("libpq.dll")

    if platform != "darwin":
        return ctypes.util.find_library("pq")

    found = ctypes.util.find_library("libpq.dylib")
    if found:
        return found

    # (hopefully) temporary hack: libpq not in a standard place
    # https://github.com/orgs/Homebrew/discussions/3595
    # If pg_config is available and agrees, let's use its indications.
    try:
        import subprocess as sp

        libdir = sp.check_output(["pg_config", "--libdir"]).strip().decode()
        found = os.path.join(libdir, "libpq.dylib")
        if not os.path.exists(found):
            found = None
    except Exception as ex:
        # Best effort: a failure here only means that the library is not found
        logger.debug("couldn't use pg_config to find libpq: %s", ex)

    return found
def error_message(obj: Union[PGconn, PGresult], encoding: str = "utf8") -> str:
    """
    Return the error message of a `PGconn` or of a `PGresult` as a `!str`.

    pq data is usually `!bytes`: here the message is decoded using the
    connection encoding, if `!obj` is a connection in `!OK` status, otherwise
    using the `!encoding` parameter. Decoding errors never raise exceptions:
    the offending bytes are replaced.

    Whatever comes before the first colon of the message (the severity) is
    dropped, together with the surrounding whitespaces. If no message is
    left, return a placeholder string.

    Raise `!TypeError` if `!obj` is neither a `PGconn` nor a `PGresult`.
    """
    bmsg: bytes

    if hasattr(obj, "error_field"):
        # obj is a PGresult
        bmsg = cast(PGresult, obj).error_message

    elif hasattr(obj, "error_message"):
        # obj is a PGconn: the connection encoding is only known if it's ok
        if obj.status == OK:
            encoding = pgconn_encoding(obj)
        bmsg = obj.error_message

    else:
        raise TypeError(f"PGconn or PGresult expected, got {type(obj).__name__}")

    # strip severity and whitespaces
    if bmsg:
        bmsg = bmsg.split(b":", 1)[-1].strip()

    # The message may have become empty after the stripping
    if not bmsg:
        return "no details available"

    return bmsg.decode(encoding, "replace")
def connection_summary(pgconn: PGconn) -> str:
    """
    Describe a connection in a short string, useful for `__repr__`.

    The string has the form ``[STATUS] (key=value ...)``. If the connection
    status is not `!OK` only the ``[STATUS]`` part is returned, carrying the
    name of the connection status.
    """
    if pgconn.status != OK:
        return f"[{ConnStatus(pgconn.status).name}]"

    # Put together the [STATUS]
    status = TransactionStatus(pgconn.transaction_status).name
    if pgconn.pipeline_status:
        status += f", pipeline={PipelineStatus(pgconn.pipeline_status).name}"

    # Put together the (CONNECTION), leaving out the details which are not
    # informative: a host starting with "/", the port 5432, a user with the
    # same name as the database.
    details = []
    if not pgconn.host.startswith(b"/"):
        details.append(f"host={pgconn.host.decode()}")
    if pgconn.port != b"5432":
        details.append(f"port={pgconn.port.decode()}")
    if pgconn.user != pgconn.db:
        details.append(f"user={pgconn.user.decode()}")
    details.append(f"database={pgconn.db.decode()}")

    # The database is always included, so the list is never empty here
    return f"[{status}] ({' '.join(details)})"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,251 @@
"""
psycopg row factories
"""
# Copyright (C) 2021 The Psycopg Team
import functools
from typing import Any, Callable, Dict, List, Optional, NamedTuple, NoReturn
from typing import TYPE_CHECKING, Sequence, Tuple, Type, TypeVar
from collections import namedtuple
from typing_extensions import TypeAlias
from . import pq
from . import errors as e
from ._compat import Protocol
from ._encodings import _as_python_identifier
if TYPE_CHECKING:
from .cursor import BaseCursor, Cursor
from .cursor_async import AsyncCursor
from psycopg.pq.abc import PGresult
# Module-level shorthands for some of the `pq.ExecStatus` values
COMMAND_OK = pq.ExecStatus.COMMAND_OK
TUPLES_OK = pq.ExecStatus.TUPLES_OK
SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
# Type of the object returned by the callable passed to a row factory
# generator such as `class_row()` or `args_row()`
T = TypeVar("T", covariant=True)
# Row factories
# Type of the object that a `RowMaker` returns for each record
Row = TypeVar("Row", covariant=True)
class RowMaker(Protocol[Row]):
    """
    Protocol of a callable converting a sequence of values into an object.

    The values received are the ones returned by a database query, already
    converted to the proper Python types. The object returned is what the
    program wants to receive for each record: a plain tuple by default (see
    `tuple_row()`), but it can be an object of any type.

    `!RowMaker` functions are typically what a `RowFactory` returns.
    """

    def __call__(self, __values: Sequence[Any]) -> Row: ...
class RowFactory(Protocol[Row]):
    """
    Protocol of a callable receiving a `~psycopg.Cursor` and returning a
    `RowMaker`.

    A `!RowFactory` is typically called when a `!Cursor` receives a result:
    this gives it the chance to look at the cursor state (for example at its
    `~psycopg.Cursor.description` attribute) in order to help the `!RowMaker`
    to build a complete object.

    As an example, the `dict_row()` `!RowFactory` takes the dictionary keys
    from the names of the columns, and returns a `!RowMaker` function which
    creates a dictionary out of the values of each record.
    """

    def __call__(self, __cursor: "Cursor[Any]") -> RowMaker[Row]: ...
class AsyncRowFactory(Protocol[Row]):
    """
    The same as `RowFactory`, but receiving an async cursor as argument.
    """

    def __call__(self, __cursor: "AsyncCursor[Any]") -> RowMaker[Row]: ...
class BaseRowFactory(Protocol[Row]):
    """
    The same as `RowFactory`, but receiving a cursor of either type (sync or
    async) as argument.
    """

    def __call__(self, __cursor: "BaseCursor[Any, Any]") -> RowMaker[Row]: ...
# The two aliases below are followed by a bare string describing them
TupleRow: TypeAlias = Tuple[Any, ...]
"""
An alias for the type returned by `tuple_row()` (i.e. a tuple of any content).
"""
DictRow: TypeAlias = Dict[str, Any]
"""
An alias for the type returned by `dict_row()`
A `!DictRow` is a dictionary with keys as string and any value returned by the
database.
"""
def tuple_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[TupleRow]":
    r"""Row factory returning every row as a plain tuple.

    It is the factory in use if no `!row_factory` parameter is passed to
    `~psycopg.Connection.connect()` or to `~psycopg.Connection.cursor()`.
    """
    # Implementation detail: the object returned must be the `tuple` type
    # itself and not an equivalent callable, as the C code fast-paths on it.
    return tuple
def dict_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[DictRow]":
    """Row factory returning every row as a dictionary.

    The keys of the dictionary are the names of the columns returned. If no
    name is available for the cursor, return `!no_result` instead of a
    dictionary maker.
    """
    columns = _get_names(cursor)
    if columns is None:
        return no_result

    def dict_row_(values: Sequence[Any]) -> Dict[str, Any]:
        # Pair each column name with the value in the same position
        return {column: value for column, value in zip(columns, values)}

    return dict_row_
def namedtuple_row(
    cursor: "BaseCursor[Any, Any]",
) -> "RowMaker[NamedTuple]":
    """Row factory building a `~collections.namedtuple` out of every row.

    The fields of the named tuple are named after the columns of the result,
    with some mangling to deal with names that are not valid identifiers.

    :param cursor: The cursor whose result provides the column names.
    :return: The ``_make`` constructor of the named tuple class, or
        `no_result` if the cursor has no result returning columns.
    """
    result = cursor.pgresult
    ncols = _get_nfields(result) if result else None
    if ncols is None:
        return no_result

    encoding = cursor._encoding
    # The names are passed undecoded: `_make_nt()` decodes them and is cached
    # on its arguments.
    raw_names = [result.fname(idx) for idx in range(ncols)]
    row_type = _make_nt(encoding, *raw_names)
    return row_type._make
@functools.lru_cache(512)
def _make_nt(enc: str, *names: bytes) -> Type[NamedTuple]:
    """Return the named tuple class, called ``Row``, for a set of columns.

    The result is cached (up to 512 entries) on the encoding and the names.

    :param enc: The codec used to decode the column names.
    :param names: The column names, as bytes.
    """
    # Decode each name and pass it through `_as_python_identifier()` before
    # using it as field name.
    fields = [_as_python_identifier(raw.decode(enc)) for raw in names]
    return namedtuple("Row", fields)  # type: ignore[return-value]
def class_row(cls: Type[T]) -> BaseRowFactory[T]:
    r"""Generate a row factory to represent rows as instances of the class `!cls`.

    Every row is built calling `!cls` with the column names as keywords.

    :param cls: The class to return for each row. It must support the fields
        returned by the query as keyword arguments.
    :rtype: `!Callable[[Cursor],` `RowMaker`\[~T]]
    """

    def class_row_(cursor: "BaseCursor[Any, Any]") -> "RowMaker[T]":
        columns = _get_names(cursor)
        if columns is None:
            # No result on the cursor: no row can be built.
            return no_result

        def class_row__(values: Sequence[Any]) -> T:
            fields = dict(zip(columns, values))
            return cls(**fields)

        return class_row__

    return class_row_
def args_row(func: Callable[..., T]) -> BaseRowFactory[T]:
    """Generate a row factory calling `!func` with positional parameters for every row.

    The cursor is not inspected: the values of each record are passed to
    `!func` in the order they are received.

    :param func: The function to call for each row. It must support the fields
        returned by the query as positional arguments.
    """

    def args_row_(cur: "BaseCursor[Any, T]") -> "RowMaker[T]":
        def args_row__(values: Sequence[Any]) -> T:
            row = func(*values)
            return row

        return args_row__

    return args_row_
def kwargs_row(func: Callable[..., T]) -> BaseRowFactory[T]:
    """Generate a row factory calling `!func` with keyword parameters for every row.

    The keywords used are the column names of the result held by the cursor.

    :param func: The function to call for each row. It must support the fields
        returned by the query as keyword arguments.
    """

    def kwargs_row_(cursor: "BaseCursor[Any, T]") -> "RowMaker[T]":
        columns = _get_names(cursor)
        if columns is None:
            # No result on the cursor: no row can be built.
            return no_result

        def kwargs_row__(values: Sequence[Any]) -> T:
            fields = dict(zip(columns, values))
            return func(**fields)

        return kwargs_row__

    return kwargs_row_
def no_result(values: Sequence[Any]) -> NoReturn:
    """A `RowMaker` that always fails.

    It can be used as return value for a `RowFactory` called with no result.
    Note that the `!RowFactory` *will* be called with no result, but the
    resulting `!RowMaker` never should.

    :param values: The values of the row; ignored.
    :raises InterfaceError: Always.
    """
    raise e.InterfaceError("the cursor doesn't have a result")
def _get_names(cursor: "BaseCursor[Any, Any]") -> Optional[List[str]]:
res = cursor.pgresult
if not res:
return None
nfields = _get_nfields(res)
if nfields is None:
return None
enc = cursor._encoding
return [
res.fname(i).decode(enc) for i in range(nfields) # type: ignore[union-attr]
]
def _get_nfields(res: "PGresult") -> Optional[int]:
    """
    Return the number of columns in a result, if it returns tuples else None

    Take into account the special case of results with zero columns: a
    result whose status says it carries tuples reports its column count even
    if that is 0.
    """
    count = res.nfields
    status = res.status

    if status == TUPLES_OK or status == SINGLE_TUPLE:
        return count

    # "describe" in named cursors: the status is COMMAND_OK but the result
    # has columns.
    if status == COMMAND_OK and count:
        return count

    return None

View File

@@ -0,0 +1,473 @@
"""
psycopg server-side cursor objects.
"""
# Copyright (C) 2020 The Psycopg Team
from typing import Any, AsyncIterator, List, Iterable, Iterator
from typing import Optional, TYPE_CHECKING, overload
from warnings import warn
from . import pq
from . import sql
from . import errors as e
from .abc import ConnectionType, Query, Params, PQGen
from .rows import Row, RowFactory, AsyncRowFactory
from .cursor import BaseCursor, Cursor
from ._compat import Self
from .generators import execute
from .cursor_async import AsyncCursor
if TYPE_CHECKING:
from .connection import Connection
from .connection_async import AsyncConnection
# Initial value of the `itersize` attribute of the server-side cursors.
DEFAULT_ITERSIZE = 100

# Module-level shortcuts for the `pq` enum members used below.
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY

COMMAND_OK = pq.ExecStatus.COMMAND_OK
TUPLES_OK = pq.ExecStatus.TUPLES_OK

IDLE = pq.TransactionStatus.IDLE
INTRANS = pq.TransactionStatus.INTRANS
class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
    """Mixin to add ServerCursor behaviour and implementation to a BaseCursor.

    The operations talking to the server are implemented as generators, which
    the concrete cursor classes run through the ``wait()`` method of their
    connection.
    """

    __slots__ = [
        "_name",
        "_scrollable",
        "_withhold",
        "_described",
        "itersize",
        "_format",
    ]

    def __init__(
        self,
        name: str,
        scrollable: Optional[bool],
        withhold: bool,
    ):
        """Set up the state specific to a server-side cursor.

        :param name: The name of the cursor on the server.
        :param scrollable: Whether to declare the cursor SCROLL / NO SCROLL;
            `!None` to emit neither clause.
        :param withhold: Whether to declare the cursor WITH HOLD.
        """
        self._name = name
        self._scrollable = scrollable
        self._withhold = withhold
        # Becomes True once the shape of the cursor has been obtained with
        # `_describe_gen()`.
        self._described = False
        self.itersize: int = DEFAULT_ITERSIZE
        self._format = TEXT

    def __repr__(self) -> str:
        # The cursor name goes right after the first word of the base repr.
        words = super().__repr__().split(None, 1)
        words.insert(1, repr(self._name))
        return " ".join(words)

    @property
    def name(self) -> str:
        """The name of the cursor."""
        return self._name

    @property
    def scrollable(self) -> Optional[bool]:
        """
        Whether the cursor is scrollable or not.

        If `!None` leave the choice to the server. Use `!True` if you want to
        use `scroll()` on the cursor.
        """
        return self._scrollable

    @property
    def withhold(self) -> bool:
        """
        If the cursor can be used after the creating transaction has committed.
        """
        return self._withhold

    @property
    def rownumber(self) -> Optional[int]:
        """Index of the next row to fetch in the current result.

        `!None` if there is no result to fetch.
        """
        result = self.pgresult
        if not result:
            return None

        # command_status is empty if the result comes from
        # describe_portal, which means that we have just executed the DECLARE,
        # so we can assume we are at the first row.
        if result.status == TUPLES_OK or result.command_status == b"":
            return self._pos

        return None

    def _declare_gen(
        self,
        query: Query,
        params: Optional[Params] = None,
        binary: Optional[bool] = None,
    ) -> PQGen[None]:
        """Generator implementing `ServerCursor.execute()`.

        :param query: The query the cursor is declared for.
        :param params: The parameters to merge into the query, if any.
        :param binary: Whether to fetch in binary format; if `!None` use the
            cursor's `!format`.
        """
        declare = self._make_declare_statement(query)

        # If the cursor is being reused, the previous one must be closed.
        if self._described:
            yield from self._close_gen()
            self._described = False

        yield from self._start_query(declare)
        pgq = self._convert_query(declare, params)
        self._execute_send(pgq, force_extended=True)
        results = yield from execute(self._conn.pgconn)
        last = results[-1]
        if last.status != COMMAND_OK:
            self._raise_for_result(last)

        # Set the format, which will be used by describe and fetch operations
        if binary is None:
            self._format = self.format
        elif binary:
            self._format = BINARY
        else:
            self._format = TEXT

        # The above result only returned COMMAND_OK. Get the cursor shape
        yield from self._describe_gen()

    def _describe_gen(self) -> PQGen[None]:
        """Generator asking the server to describe the cursor's portal.

        The results received become the cursor results and the cursor is
        marked as described.
        """
        portal = self._name.encode(self._encoding)
        self._pgconn.send_describe_portal(portal)
        results = yield from execute(self._pgconn)
        self._check_results(results)
        self._results = results
        self._select_current_result(0, format=self._format)
        self._described = True

    def _close_gen(self) -> PQGen[None]:
        """Generator closing the cursor on the server, if it makes sense."""
        status = self._conn.pgconn.transaction_status

        # if the connection is not in a sane state, don't even try
        if status not in (IDLE, INTRANS):
            return

        # If we are IDLE, a WITHOUT HOLD cursor will surely have gone already.
        if status == IDLE and not self._withhold:
            return

        # if we didn't declare the cursor ourselves we still have to close it
        # but we must make sure it exists.
        if not self._described:
            lookup = sql.SQL(
                "SELECT 1 FROM pg_catalog.pg_cursors WHERE name = {}"
            ).format(sql.Literal(self._name))
            found = yield from self._conn._exec_command(lookup)
            # pipeline mode otherwise, unsupported here.
            assert found is not None
            if found.ntuples == 0:
                return

        close = sql.SQL("CLOSE {}").format(sql.Identifier(self._name))
        yield from self._conn._exec_command(close)

    def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Row]]:
        """Generator fetching rows forward from the cursor.

        :param num: The number of rows to fetch; `!None` to fetch them all.
        :return: The rows fetched, built with the cursor's row maker.
        :raises InterfaceError: If the cursor is closed.
        """
        if self.closed:
            raise e.InterfaceError("the cursor is closed")

        # If we are stealing the cursor, make sure we know its shape
        if not self._described:
            yield from self._start_query()
            yield from self._describe_gen()

        howmany = sql.SQL("ALL") if num is None else sql.Literal(num)
        fetch = sql.SQL("FETCH FORWARD {} FROM {}").format(
            howmany,
            sql.Identifier(self._name),
        )
        result = yield from self._conn._exec_command(
            fetch, result_format=self._format
        )
        # pipeline mode otherwise, unsupported here.
        assert result is not None

        self.pgresult = result
        self._tx.set_pgresult(result, set_loaders=False)
        return self._tx.load_rows(0, result.ntuples, self._make_row)

    def _scroll_gen(self, value: int, mode: str) -> PQGen[None]:
        """Generator moving the cursor position on the server.

        :param value: The offset, or the position, to move to.
        :param mode: Either ``relative`` or ``absolute``.
        :raises ValueError: If `!mode` is not one of the supported values.
        """
        if mode != "relative" and mode != "absolute":
            raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")

        direction = " ABSOLUTE" if mode == "absolute" else ""
        move = sql.SQL("MOVE{} {} FROM {}").format(
            sql.SQL(direction),
            sql.Literal(value),
            sql.Identifier(self._name),
        )
        yield from self._conn._exec_command(move)

    def _make_declare_statement(self, query: Query) -> sql.Composed:
        """Wrap `!query` into the DECLARE statement creating the cursor."""
        if isinstance(query, bytes):
            query = query.decode(self._encoding)
        if not isinstance(query, sql.Composable):
            query = sql.SQL(query)

        parts = [sql.SQL("DECLARE"), sql.Identifier(self._name)]

        # The scroll clause is only emitted if a choice was made.
        if self._scrollable is not None:
            parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL"))

        parts.append(sql.SQL("CURSOR"))

        if self._withhold:
            parts.append(sql.SQL("WITH HOLD"))

        parts += [sql.SQL("FOR"), query]
        return sql.SQL(" ").join(parts)
class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
    """Blocking server-side cursor.

    The generators implemented by `ServerCursorMixin` are run through the
    connection's ``wait()``, holding the connection lock.
    """

    __module__ = "psycopg"
    __slots__ = ()

    @overload
    def __init__(
        self,
        connection: "Connection[Row]",
        name: str,
        *,
        scrollable: Optional[bool] = None,
        withhold: bool = False,
    ): ...

    @overload
    def __init__(
        self,
        connection: "Connection[Any]",
        name: str,
        *,
        row_factory: RowFactory[Row],
        scrollable: Optional[bool] = None,
        withhold: bool = False,
    ): ...

    def __init__(
        self,
        connection: "Connection[Any]",
        name: str,
        *,
        row_factory: Optional[RowFactory[Row]] = None,
        scrollable: Optional[bool] = None,
        withhold: bool = False,
    ):
        """Initialise both the client cursor and the server-side state.

        If no `!row_factory` is specified the connection's one is used.
        """
        factory = row_factory or connection.row_factory
        Cursor.__init__(self, connection, row_factory=factory)
        ServerCursorMixin.__init__(self, name, scrollable, withhold)

    def __del__(self) -> None:
        # Only warn: nothing is sent to the server from the finalizer.
        if self.closed:
            return

        warn(
            f"the server-side cursor {self} was deleted while still open."
            " Please use 'with' or '.close()' to close the cursor properly",
            ResourceWarning,
        )

    def close(self) -> None:
        """
        Close the current cursor and free associated resources.
        """
        with self._conn.lock:
            if self.closed:
                return

            # Run the server-side close only if the connection is still open.
            if not self._conn.closed:
                self._conn.wait(self._close_gen())

            super().close()

    def execute(
        self,
        query: Query,
        params: Optional[Params] = None,
        *,
        binary: Optional[bool] = None,
        **kwargs: Any,
    ) -> Self:
        """
        Open a cursor to execute a query to the database.

        :raises TypeError: If unknown keyword arguments are passed.
        :raises NotSupportedError: If the connection is in pipeline mode.
        """
        if kwargs:
            unknown = next(iter(kwargs))
            raise TypeError(f"keyword not supported: {unknown}")

        if self._pgconn.pipeline_status:
            raise e.NotSupportedError(
                "server-side cursors not supported in pipeline mode"
            )

        try:
            with self._conn.lock:
                self._conn.wait(self._declare_gen(query, params, binary))
        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)

        return self

    def executemany(
        self,
        query: Query,
        params_seq: Iterable[Params],
        *,
        returning: bool = True,
    ) -> None:
        """Method not implemented for server-side cursors."""
        raise e.NotSupportedError("executemany not supported on server-side cursors")

    def fetchone(self) -> Optional[Row]:
        """Fetch the next row from the server; `!None` if there is none."""
        with self._conn.lock:
            rows = self._conn.wait(self._fetch_gen(1))

        if not rows:
            return None

        self._pos += 1
        return rows[0]

    def fetchmany(self, size: int = 0) -> List[Row]:
        """Fetch up to `!size` rows from the server.

        If `!size` is 0 (the default) use `!arraysize` instead.
        """
        howmany = size if size else self.arraysize
        with self._conn.lock:
            rows = self._conn.wait(self._fetch_gen(howmany))

        self._pos += len(rows)
        return rows

    def fetchall(self) -> List[Row]:
        """Fetch all the remaining rows from the server."""
        with self._conn.lock:
            rows = self._conn.wait(self._fetch_gen(None))

        self._pos += len(rows)
        return rows

    def __iter__(self) -> Iterator[Row]:
        # Fetch in batches of `itersize` rows; the lock is only held during
        # the fetch, not while the rows are consumed.
        while True:
            with self._conn.lock:
                batch = self._conn.wait(self._fetch_gen(self.itersize))

            for row in batch:
                self._pos += 1
                yield row

            # A short batch means the cursor is exhausted.
            if len(batch) < self.itersize:
                break

    def scroll(self, value: int, mode: str = "relative") -> None:
        """Move the cursor on the server and track the new position.

        :param value: The offset (``relative``) or position (``absolute``).
        :param mode: Either ``relative`` or ``absolute``.
        """
        with self._conn.lock:
            self._conn.wait(self._scroll_gen(value, mode))

        # Postgres doesn't have a reliable way to report a cursor out of bound
        self._pos = self._pos + value if mode == "relative" else value
class AsyncServerCursor(
    ServerCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row]
):
    """Asynchronous server-side cursor.

    The generators implemented by `ServerCursorMixin` are awaited through the
    connection's ``wait()``, holding the connection lock.
    """

    __module__ = "psycopg"
    __slots__ = ()

    @overload
    def __init__(
        self,
        connection: "AsyncConnection[Row]",
        name: str,
        *,
        scrollable: Optional[bool] = None,
        withhold: bool = False,
    ): ...

    @overload
    def __init__(
        self,
        connection: "AsyncConnection[Any]",
        name: str,
        *,
        row_factory: AsyncRowFactory[Row],
        scrollable: Optional[bool] = None,
        withhold: bool = False,
    ): ...

    def __init__(
        self,
        connection: "AsyncConnection[Any]",
        name: str,
        *,
        row_factory: Optional[AsyncRowFactory[Row]] = None,
        scrollable: Optional[bool] = None,
        withhold: bool = False,
    ):
        """Initialise both the client cursor and the server-side state.

        If no `!row_factory` is specified the connection's one is used.
        """
        AsyncCursor.__init__(
            self, connection, row_factory=row_factory or connection.row_factory
        )
        ServerCursorMixin.__init__(self, name, scrollable, withhold)

    def __del__(self) -> None:
        # Only warn: nothing can be awaited from the finalizer.
        if not self.closed:
            warn(
                f"the server-side cursor {self} was deleted while still open."
                " Please use 'with' or '.close()' to close the cursor properly",
                ResourceWarning,
            )

    async def close(self) -> None:
        """
        Close the current cursor and free associated resources.
        """
        async with self._conn.lock:
            if self.closed:
                return
            # Run the server-side close only if the connection is still open.
            if not self._conn.closed:
                await self._conn.wait(self._close_gen())
            await super().close()

    async def execute(
        self,
        query: Query,
        params: Optional[Params] = None,
        *,
        binary: Optional[bool] = None,
        **kwargs: Any,
    ) -> Self:
        """
        Open a cursor to execute a query to the database.

        :raises TypeError: If unknown keyword arguments are passed.
        :raises NotSupportedError: If the connection is in pipeline mode.
        """
        if kwargs:
            raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
        if self._pgconn.pipeline_status:
            raise e.NotSupportedError(
                "server-side cursors not supported in pipeline mode"
            )

        try:
            async with self._conn.lock:
                await self._conn.wait(self._declare_gen(query, params, binary))
        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)

        return self

    async def executemany(
        self,
        query: Query,
        params_seq: Iterable[Params],
        *,
        returning: bool = True,
    ) -> None:
        """Method not implemented for server-side cursors."""
        raise e.NotSupportedError("executemany not supported on server-side cursors")

    async def fetchone(self) -> Optional[Row]:
        """Fetch the next row from the server; `!None` if there is none."""
        async with self._conn.lock:
            recs = await self._conn.wait(self._fetch_gen(1))
        if recs:
            self._pos += 1
            return recs[0]
        else:
            return None

    async def fetchmany(self, size: int = 0) -> List[Row]:
        """Fetch up to `!size` rows from the server.

        If `!size` is 0 (the default) use `!arraysize` instead.
        """
        if not size:
            size = self.arraysize
        async with self._conn.lock:
            recs = await self._conn.wait(self._fetch_gen(size))
        self._pos += len(recs)
        return recs

    async def fetchall(self) -> List[Row]:
        """Fetch all the remaining rows from the server."""
        async with self._conn.lock:
            recs = await self._conn.wait(self._fetch_gen(None))
        self._pos += len(recs)
        return recs

    async def __aiter__(self) -> AsyncIterator[Row]:
        # Fetch in batches of `itersize` rows; the lock is only held during
        # the fetch, not while the rows are consumed.
        while True:
            async with self._conn.lock:
                recs = await self._conn.wait(self._fetch_gen(self.itersize))
            for rec in recs:
                self._pos += 1
                yield rec
            # A short batch means the cursor is exhausted.
            if len(recs) < self.itersize:
                break

    async def scroll(self, value: int, mode: str = "relative") -> None:
        """Move the cursor on the server and track the new position.

        :param value: The offset (``relative``) or position (``absolute``).
        :param mode: Either ``relative`` or ``absolute``.
        :raises ValueError: If `!mode` is not one of the supported values.
        """
        async with self._conn.lock:
            await self._conn.wait(self._scroll_gen(value, mode))

        # Keep the client-side position in sync with the server, as the
        # blocking `ServerCursor.scroll()` does: without this `rownumber`
        # would be stale after a scroll.
        # Postgres doesn't have a reliable way to report a cursor out of bound
        if mode == "relative":
            self._pos += value
        else:
            self._pos = value

View File

@@ -0,0 +1,467 @@
"""
SQL composition utility module
"""
# Copyright (C) 2020 The Psycopg Team
import codecs
import string
from abc import ABC, abstractmethod
from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union
from .pq import Escaping
from .abc import AdaptContext
from .adapt import Transformer, PyFormat
from ._compat import LiteralString
from ._encodings import conn_encoding
def quote(obj: Any, context: Optional[AdaptContext] = None) -> str:
    """
    Adapt a Python object to a quoted SQL string.

    Use this function only if you absolutely want to convert a Python string to
    an SQL quoted literal to use e.g. to generate batch SQL and you won't have
    a connection available when you will need to use it.

    This function is relatively inefficient, because it doesn't cache the
    adaptation rules. If you pass a `!context` you can adapt the adaptation
    rules used, otherwise only global rules are used.

    :param obj: The object to convert.
    :param context: The context whose adaptation rules should be used, if any.
    :return: The object as an SQL literal.
    """
    literal = Literal(obj)
    return literal.as_string(context)
class Composable(ABC):
    """
    Abstract base class for objects that can be used to compose an SQL string.

    `!Composable` objects can be passed directly to
    `~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`,
    `~psycopg.Cursor.copy()` in place of the query string.

    `!Composable` objects can be joined using the ``+`` operator: the result
    will be a `Composed` instance containing the objects joined. The operator
    ``*`` is also supported with an integer argument: the result is a
    `!Composed` instance containing the left argument repeated as many times as
    requested.

    Two objects compare equal if they have exactly the same type and wrap
    equal values.
    """

    def __init__(self, obj: Any):
        # The wrapped value; its meaning depends on the subclass.
        self._obj = obj

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._obj!r})"

    @abstractmethod
    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        """
        Return the value of the object as bytes.

        :param context: the context to evaluate the object into.
        :type context: `connection` or `cursor`

        The method is automatically invoked by `~psycopg.Cursor.execute()`,
        `~psycopg.Cursor.executemany()`, `~psycopg.Cursor.copy()` if a
        `!Composable` is passed instead of the query string.
        """
        raise NotImplementedError

    def as_string(self, context: Optional[AdaptContext]) -> str:
        """
        Return the value of the object as string.

        :param context: the context to evaluate the string into.
        :type context: `connection` or `cursor`

        The string is obtained decoding the result of `as_bytes()` with the
        encoding of the context's connection.
        """
        connection = context.connection if context else None
        encoding = conn_encoding(connection)
        data = self.as_bytes(context)

        if not isinstance(data, bytes):
            # buffer object
            return codecs.lookup(encoding).decode(data)[0]

        return data.decode(encoding)

    def __add__(self, other: "Composable") -> "Composed":
        if not isinstance(other, Composable):
            return NotImplemented

        # Wrap the right operand too unless it is already a sequence.
        rhs = other if isinstance(other, Composed) else Composed([other])
        return Composed([self]) + rhs

    def __mul__(self, n: int) -> "Composed":
        repeated = [self] * n
        return Composed(repeated)

    def __eq__(self, other: Any) -> bool:
        # Exact type match: instances of a subclass never compare equal.
        return type(self) is type(other) and self._obj == other._obj

    def __ne__(self, other: Any) -> bool:
        return not self.__eq__(other)
class Composed(Composable):
    """
    A `Composable` object made of a sequence of `!Composable`.

    The object is usually created using `!Composable` operators and methods.
    However it is possible to create a `!Composed` directly specifying a
    sequence of objects as arguments: if they are not `!Composable` they will
    be wrapped in a `Literal`.

    Example::

        >>> comp = sql.Composed(
        ...     [sql.SQL("INSERT INTO "), sql.Identifier("table")])
        >>> print(comp.as_string(conn))
        INSERT INTO "table"

    `!Composed` objects are iterable (so they can be used in `SQL.join` for
    instance).
    """

    _obj: List[Composable]

    def __init__(self, seq: Sequence[Any]):
        # Anything which is not a Composable is treated as a value to quote.
        items = [
            item if isinstance(item, Composable) else Literal(item) for item in seq
        ]
        super().__init__(items)

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        chunks = (item.as_bytes(context) for item in self._obj)
        return b"".join(chunks)

    def __iter__(self) -> Iterator[Composable]:
        return iter(self._obj)

    def __add__(self, other: Composable) -> "Composed":
        if not isinstance(other, Composable):
            return NotImplemented

        # Concatenating another sequence flattens it into the result.
        extra = other._obj if isinstance(other, Composed) else [other]
        return Composed(self._obj + extra)

    def join(self, joiner: Union["SQL", LiteralString]) -> "Composed":
        """
        Return a new `!Composed` interposing the `!joiner` with the `!Composed` items.

        The `!joiner` must be a `SQL` or a string which will be interpreted as
        an `SQL`.

        Example::

            >>> fields = sql.Identifier('foo') + sql.Identifier('bar') # a Composed
            >>> print(fields.join(', ').as_string(conn))
            "foo", "bar"

        :raises TypeError: If `!joiner` is neither a string nor an `SQL`.
        """
        if isinstance(joiner, SQL):
            separator = joiner
        elif isinstance(joiner, str):
            separator = SQL(joiner)
        else:
            raise TypeError(
                "Composed.join() argument must be strings or SQL,"
                f" got {joiner!r} instead"
            )

        return separator.join(self._obj)
class SQL(Composable):
    """
    A `Composable` representing a snippet of SQL statement.

    `!SQL` exposes `join()` and `format()` methods useful to create a template
    where to merge variable parts of a query (for instance field or table
    names).

    The `!obj` string doesn't undergo any form of escaping, so it is not
    suitable to represent variable identifiers or values: you should only use
    it to pass constant strings representing templates or snippets of SQL
    statements; use other objects such as `Identifier` or `Literal` to
    represent variable parts.

    Example::

        >>> query = sql.SQL("SELECT {0} FROM {1}").format(
        ...    sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
        ...    sql.Identifier('table'))
        >>> print(query.as_string(conn))
        SELECT "foo", "bar" FROM "table"
    """

    _obj: LiteralString
    # Shared parser used by `format()` to split the template.
    _formatter = string.Formatter()

    def __init__(self, obj: LiteralString):
        # Init the base class first, so that the object has a working repr
        # even if the check below fails.
        super().__init__(obj)
        if not isinstance(obj, str):
            raise TypeError(f"SQL values must be strings, got {obj!r} instead")

    def as_string(self, context: Optional[AdaptContext]) -> str:
        # The snippet is already a string: no encoding round trip is needed.
        return self._obj

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        # Without a context fall back to utf-8.
        encoding = conn_encoding(context.connection) if context else "utf-8"
        return self._obj.encode(encoding)

    def format(self, *args: Any, **kwargs: Any) -> Composed:
        """
        Merge `Composable` objects into a template.

        :param args: parameters to replace to numbered (``{0}``, ``{1}``) or
            auto-numbered (``{}``) placeholders
        :param kwargs: parameters to replace to named (``{name}``) placeholders
        :return: the union of the `!SQL` string with placeholders replaced
        :rtype: `Composed`
        :raises ValueError: if the template uses a format specification or a
            conversion, or mixes numbered and auto-numbered placeholders

        The method is similar to the Python `str.format()` method: the string
        template supports auto-numbered (``{}``), numbered (``{0}``,
        ``{1}``...), and named placeholders (``{name}``), with positional
        arguments replacing the numbered placeholders and keywords replacing
        the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``)
        are not supported.

        If a `!Composable` objects is passed to the template it will be merged
        according to its `as_string()` method. If any other Python object is
        passed, it will be wrapped in a `Literal` object and so escaped
        according to SQL rules.

        Example::

            >>> print(sql.SQL("SELECT * FROM {} WHERE {} = %s")
            ...     .format(sql.Identifier('people'), sql.Identifier('id'))
            ...     .as_string(conn))
            SELECT * FROM "people" WHERE "id" = %s

            >>> print(sql.SQL("SELECT * FROM {tbl} WHERE name = {name}")
            ...     .format(tbl=sql.Identifier('people'), name="O'Rourke")
            ...     .as_string(conn))
            SELECT * FROM "people" WHERE name = 'O''Rourke'
        """
        pieces: List[Composable] = []
        # Index of the next auto-numbered argument; None once a numbered
        # placeholder has been used.
        next_auto: Optional[int] = 0

        # TODO: declaring the type of the loop variable is probably not the
        # right way to whitelist it; to be revisited if mypy complains too.
        literal: LiteralString
        for literal, field, spec, conv in self._formatter.parse(self._obj):
            if spec:
                raise ValueError("no format specification supported by SQL")
            if conv:
                raise ValueError("no format conversion supported by SQL")

            if literal:
                pieces.append(SQL(literal))

            # Trailing text, with no placeholder after it.
            if field is None:
                continue

            if field.isdigit():
                if next_auto:
                    raise ValueError(
                        "cannot switch from automatic field numbering to manual"
                    )
                pieces.append(args[int(field)])
                next_auto = None

            elif field:
                pieces.append(kwargs[field])

            else:
                if next_auto is None:
                    raise ValueError(
                        "cannot switch from manual field numbering to automatic"
                    )
                pieces.append(args[next_auto])
                next_auto += 1

        return Composed(pieces)

    def join(self, seq: Iterable[Composable]) -> Composed:
        """
        Join a sequence of `Composable`.

        :param seq: the elements to join.
        :type seq: iterable of `!Composable`

        Use the `!SQL` object's string to separate the elements in `!seq`.
        Note that `Composed` objects are iterable too, so they can be used as
        argument for this method.

        Example::

            >>> snip = sql.SQL(', ').join(
            ...     sql.Identifier(n) for n in ['foo', 'bar', 'baz'])
            >>> print(snip.as_string(conn))
            "foo", "bar", "baz"
        """
        joined = []
        for position, item in enumerate(seq):
            # The separator goes before every item but the first.
            if position:
                joined.append(self)
            joined.append(item)

        return Composed(joined)
class Identifier(Composable):
    """
    A `Composable` representing an SQL identifier or a dot-separated sequence.

    Identifiers usually represent names of database objects, such as tables or
    fields. PostgreSQL identifiers follow `different rules`__ than SQL string
    literals for escaping (e.g. they use double quotes instead of single).

    .. __: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS

    Example::

        >>> t1 = sql.Identifier("foo")
        >>> t2 = sql.Identifier("ba'r")
        >>> t3 = sql.Identifier('ba"z')
        >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn))
        "foo", "ba'r", "ba""z"

    Multiple strings can be passed to the object to represent a qualified name,
    i.e. a dot-separated sequence of identifiers.

    Example::

        >>> query = sql.SQL("SELECT {} FROM {}").format(
        ...     sql.Identifier("table", "field"),
        ...     sql.Identifier("schema", "table"))
        >>> print(query.as_string(conn))
        SELECT "table"."field" FROM "schema"."table"
    """

    _obj: Sequence[str]

    def __init__(self, *strings: str):
        # init super() now to make the __repr__ not explode in case of error
        super().__init__(strings)

        if not strings:
            raise TypeError("Identifier cannot be empty")

        for part in strings:
            if isinstance(part, str):
                continue
            raise TypeError(
                f"SQL identifier parts must be strings, got {part!r} instead"
            )

    def __repr__(self) -> str:
        args = ", ".join(repr(part) for part in self._obj)
        return f"{type(self).__name__}({args})"

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        # Escaping an identifier requires a connection.
        connection = context.connection if context else None
        if not connection:
            raise ValueError("a connection is necessary for Identifier")

        escaper = Escaping(connection.pgconn)
        encoding = conn_encoding(connection)
        # Every part is escaped on its own, then the parts are dot-joined.
        escaped = [
            escaper.escape_identifier(part.encode(encoding)) for part in self._obj
        ]
        return b".".join(escaped)
class Literal(Composable):
    """
    A `Composable` representing an SQL value to include in a query.

    Usually you will want to include placeholders in the query and pass values
    as `~cursor.execute()` arguments. If however you really really need to
    include a literal value in the query you can use this object.

    The string returned by `!as_string()` follows the normal :ref:`adaptation
    rules <types-adaptation>` for Python objects.

    Example::

        >>> s1 = sql.Literal("fo'o")
        >>> s2 = sql.Literal(42)
        >>> s3 = sql.Literal(date(2000, 1, 1))
        >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn))
        'fo''o', 42, '2000-01-01'::date
    """

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        # Delegate the conversion to a transformer built from the context.
        return Transformer.from_context(context).as_literal(self._obj)
class Placeholder(Composable):
    """A `Composable` representing a placeholder for query parameters.

    If the name is specified, generate a named placeholder (e.g. ``%(name)s``,
    ``%(name)b``), otherwise generate a positional placeholder (e.g. ``%s``,
    ``%b``).

    The object is useful to generate SQL queries with a variable number of
    arguments.

    Examples::

        >>> names = ['foo', 'bar', 'baz']

        >>> q1 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
        ...     sql.SQL(', ').join(map(sql.Identifier, names)),
        ...     sql.SQL(', ').join(sql.Placeholder() * len(names)))
        >>> print(q1.as_string(conn))
        INSERT INTO my_table ("foo", "bar", "baz") VALUES (%s, %s, %s)

        >>> q2 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
        ...     sql.SQL(', ').join(map(sql.Identifier, names)),
        ...     sql.SQL(', ').join(map(sql.Placeholder, names)))
        >>> print(q2.as_string(conn))
        INSERT INTO my_table ("foo", "bar", "baz") VALUES (%(foo)s, %(bar)s, %(baz)s)
    """

    def __init__(self, name: str = "", format: Union[str, PyFormat] = PyFormat.AUTO):
        # Init the base class first, so that the object has a working repr
        # even if one of the checks below fails.
        super().__init__(name)

        if not isinstance(name, str):
            raise TypeError(f"expected string as name, got {name!r}")
        if ")" in name:
            # The name would close the ``%(...)`` placeholder early.
            raise ValueError(f"invalid name: {name!r}")

        # Plain strings are converted into the matching enum member.
        if type(format) is str:
            format = PyFormat(format)
        if not isinstance(format, PyFormat):
            raise TypeError(
                f"expected PyFormat as format, got {type(format).__name__!r}"
            )

        self._format: PyFormat = format

    def __repr__(self) -> str:
        # Only show the arguments different from the defaults.
        args = []
        if self._obj:
            args.append(repr(self._obj))
        if self._format is not PyFormat.AUTO:
            args.append(f"format={self._format.name}")

        return f"{type(self).__name__}({', '.join(args)})"

    def as_string(self, context: Optional[AdaptContext]) -> str:
        code = self._format.value
        if self._obj:
            return f"%({self._obj}){code}"
        return f"%{code}"

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        connection = context.connection if context else None
        encoding = conn_encoding(connection)
        text = self.as_string(context)
        return text.encode(encoding)
# Literals: ready-made `SQL` snippets for the NULL and DEFAULT keywords.
NULL = SQL("NULL")
DEFAULT = SQL("DEFAULT")

View File

@@ -0,0 +1,288 @@
"""
Transaction context managers returned by Connection.transaction()
"""
# Copyright (C) 2020 The Psycopg Team
import logging
from types import TracebackType
from typing import Generic, Iterator, Optional, Type, Union, TYPE_CHECKING
from . import pq
from . import sql
from . import errors as e
from .abc import ConnectionType, PQGen
from ._compat import Self
from .pq.misc import connection_summary
if TYPE_CHECKING:
from typing import Any
from .connection import Connection
from .connection_async import AsyncConnection
# Module-level shortcuts for the `pq` enum members.
IDLE = pq.TransactionStatus.IDLE

OK = pq.ConnStatus.OK

# Logger named after this module.
logger = logging.getLogger(__name__)
class Rollback(Exception):
    """
    Exit the current `Transaction` context immediately and rollback any changes
    made within this context.

    If a transaction context is specified in the constructor, rollback
    enclosing transactions contexts up to and including the one specified.
    """

    __module__ = "psycopg"

    def __init__(
        self,
        transaction: Union["Transaction", "AsyncTransaction", None] = None,
    ):
        # The outermost transaction block to roll back; None means only the
        # block the exception is raised in.
        self.transaction = transaction

    def __repr__(self) -> str:
        name = type(self).__qualname__
        return f"{name}({self.transaction!r})"
class OutOfOrderTransactionNesting(e.ProgrammingError):
    """Out-of-order transaction nesting detected

    Raised on exit of a transaction block if the connection's transaction
    counter doesn't match the position recorded when the block was entered.
    """
class BaseTransaction(Generic[ConnectionType]):
def __init__(
    self,
    connection: ConnectionType,
    savepoint_name: Optional[str] = None,
    force_rollback: bool = False,
):
    """Initialise the state of a transaction block not yet entered.

    :param connection: The connection the transaction operates on.
    :param savepoint_name: The name of the savepoint to use, if any.
    :param force_rollback: If true, roll back on exit even if no exception
        was raised in the block.
    """
    self._conn = connection
    self.pgconn = connection.pgconn
    # An empty string stands for "no savepoint name".
    self._savepoint_name = savepoint_name if savepoint_name else ""
    self.force_rollback = force_rollback

    # Life cycle flags: a block can be entered and exited only once.
    self._entered = False
    self._exited = False

    # Both set for real by `_push_savepoint()`, when the block is entered.
    self._outer_transaction = False
    self._stack_index = -1
@property
def savepoint_name(self) -> Optional[str]:
    """
    The name of the savepoint; `!None` if handling the main transaction.
    """
    # Yes, it may change on __enter__. No, I don't care, because the
    # un-entered state is outside the public interface.
    # NOTE(review): the constructor stores "" when no name is given, so the
    # value returned here looks like it is an empty string rather than None
    # in that case -- confirm against the docstring above.
    return self._savepoint_name
def __repr__(self) -> str:
    """Describe the block: class, savepoint, life cycle state, connection."""
    klass = self.__class__
    cls = f"{klass.__module__}.{klass.__qualname__}"
    info = connection_summary(self.pgconn)

    if self._entered:
        status = "terminated" if self._exited else "active"
    else:
        status = "inactive"

    # The savepoint name is only shown if there is one.
    name = self.savepoint_name
    sp = f"{name!r} " if name else ""
    return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>"
def _enter_gen(self) -> PQGen[None]:
    """Generator implementing the entry into the transaction block.

    :raises TypeError: If the block was already entered.
    """
    if self._entered:
        raise TypeError("transaction blocks can be used only once")
    self._entered = True

    # Register the block on the connection before building the commands:
    # `_get_enter_commands()` reads the state set by `_push_savepoint()`.
    self._push_savepoint()
    for command in self._get_enter_commands():
        yield from self._conn._exec_command(command)
def _exit_gen(
    self,
    exc_type: Optional[Type[BaseException]],
    exc_val: Optional[BaseException],
    exc_tb: Optional[TracebackType],
) -> PQGen[bool]:
    """Generator implementing the exit from the transaction block.

    Commit if the block terminated without exception and no rollback was
    forced, roll back otherwise.

    :return: `!True` if the exception raised in the block must be swallowed.
    :raises OutOfOrderTransactionNesting: If the blocks are exited in the
        wrong order.
    """
    if exc_val or self.force_rollback:
        # try to rollback, but if there are problems (connection in a bad
        # state) just warn without clobbering the exception bubbling up.
        try:
            return (yield from self._rollback_gen(exc_val))
        except OutOfOrderTransactionNesting:
            # Clobber an exception happened in the block with the exception
            # caused by out-of-order transaction detected, so make the
            # behaviour consistent with _commit_gen and to make sure the
            # user fixes this condition, which is unrelated from
            # operational error that might arise in the block.
            raise
        except Exception as exc2:
            logger.warning("error ignored in rollback of %s: %s", self, exc2)
            return False

    yield from self._commit_gen()
    return False
def _commit_gen(self) -> PQGen[None]:
    """Generator committing the transaction block.

    :raises Exception: The error returned by `_pop_savepoint()`, if any.
    """
    # Pop the block from the connection and mark it as exited before
    # anything else, so the state is updated even if the check fails.
    ex = self._pop_savepoint("commit")
    self._exited = True
    if ex:
        raise ex

    for command in self._get_commit_commands():
        yield from self._conn._exec_command(command)
def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
    """Generator rolling back the transaction block.

    :param exc_val: The exception which caused the rollback, if any.
    :return: `!True` if `!exc_val` is a `Rollback` addressed to this block
        (or to no block in particular), so the exception must be swallowed;
        `!False` otherwise.
    :raises Exception: The error returned by `_pop_savepoint()`, if any.
    """
    if isinstance(exc_val, Rollback):
        # Pass the connection as a logging argument so that its repr is only
        # built if the debug record is actually emitted.
        logger.debug("%s: Explicit rollback from: ", self._conn, exc_info=True)

    # Pop the block from the connection and mark it as exited before
    # anything else, so the state is updated even if the check fails.
    ex = self._pop_savepoint("rollback")
    self._exited = True
    if ex:
        raise ex

    for command in self._get_rollback_commands():
        yield from self._conn._exec_command(command)

    if isinstance(exc_val, Rollback):
        if not exc_val.transaction or exc_val.transaction is self:
            return True  # Swallow the exception

    return False
def _get_enter_commands(self) -> Iterator[bytes]:
    """
    Yield the commands to execute to enter the block.

    For an outer transaction yield the command returned by the connection's
    `_get_tx_start_command()`; if the block has a savepoint name also yield
    the matching ``SAVEPOINT`` statement.
    """
    if self._outer_transaction:
        yield self._conn._get_tx_start_command()

    if self._savepoint_name:
        name = sql.Identifier(self._savepoint_name)
        yield sql.SQL("SAVEPOINT {}").format(name).as_bytes(self._conn)
def _get_commit_commands(self) -> Iterator[bytes]:
    """
    Yield the commands to execute to commit the block.

    A named block which is not the outer transaction releases its
    savepoint; the outer transaction issues a ``COMMIT``.
    """
    if self._savepoint_name and not self._outer_transaction:
        name = sql.Identifier(self._savepoint_name)
        yield sql.SQL("RELEASE {}").format(name).as_bytes(self._conn)

    if self._outer_transaction:
        # The outer transaction must be the last one left on the stack.
        assert not self._conn._num_transactions
        yield b"COMMIT"
def _get_rollback_commands(self) -> Iterator[bytes]:
    """
    Yield the commands to execute to roll back the block.

    A named block which is not the outer transaction rolls back to its
    savepoint and then releases it; the outer transaction issues a
    ``ROLLBACK``. In both cases the connection's prepared statements cache
    is cleared and, if `clear()` returns a true value, its maintenance
    commands are yielded too.
    """
    if self._savepoint_name and not self._outer_transaction:
        name = sql.Identifier(self._savepoint_name)
        yield sql.SQL("ROLLBACK TO {n}").format(n=name).as_bytes(self._conn)
        yield sql.SQL("RELEASE {n}").format(n=name).as_bytes(self._conn)

    if self._outer_transaction:
        # The outer transaction must be the last one left on the stack.
        assert not self._conn._num_transactions
        yield b"ROLLBACK"

    # Also clear the prepared statements cache.
    prepared = self._conn._prepared
    if prepared.clear():
        yield from prepared.get_maintenance_commands()
def _push_savepoint(self) -> None:
    """
    Push the transaction on the connection transactions stack.

    Also set the internal state of the object and verify consistency.
    """
    conn = self._conn
    self._outer_transaction = self.pgconn.transaction_status == IDLE

    if not self._outer_transaction:
        # inner transaction: it always has a name; generate one from the
        # nesting level if the user didn't provide it
        if not self._savepoint_name:
            self._savepoint_name = f"_pg3_{conn._num_transactions + 1}"
    else:
        # outer transaction: if no name it's only a begin, else
        # there will be an additional savepoint
        assert not conn._num_transactions

    # Remember the level this block was opened at: `_pop_savepoint()`
    # compares it with the connection counter on exit.
    self._stack_index = conn._num_transactions
    conn._num_transactions += 1
def _pop_savepoint(self, action: str) -> Optional[Exception]:
"""
Pop the transaction from the connection transactions stack.
Also verify the state consistency.
"""
self._conn._num_transactions -= 1
if self._conn._num_transactions == self._stack_index:
return None
return OutOfOrderTransactionNesting(
f"transaction {action} at the wrong nesting level: {self}"
)
class Transaction(BaseTransaction["Connection[Any]"]):
    """
    Returned by `Connection.transaction()` to handle a transaction block.
    """

    __module__ = "psycopg"

    @property
    def connection(self) -> "Connection[Any]":
        """The connection the object is managing."""
        return self._conn

    def __enter__(self) -> Self:
        """Enter the block while holding the connection lock."""
        conn = self._conn
        with conn.lock:
            conn.wait(self._enter_gen())
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        """
        Exit the block while holding the connection lock.

        If the underlying connection status is not `OK` nothing is executed
        and `False` is returned; otherwise return the result of
        `_exit_gen()`.
        """
        if self.pgconn.status != OK:
            return False

        conn = self._conn
        with conn.lock:
            return conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
    """
    Returned by `AsyncConnection.transaction()` to handle a transaction block.
    """

    __module__ = "psycopg"

    @property
    def connection(self) -> "AsyncConnection[Any]":
        """The connection the object is managing."""
        return self._conn

    async def __aenter__(self) -> Self:
        """Enter the block while holding the connection lock."""
        conn = self._conn
        async with conn.lock:
            await conn.wait(self._enter_gen())
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        """
        Exit the block while holding the connection lock.

        If the underlying connection status is not `OK` nothing is executed
        and `False` is returned; otherwise return the result of
        `_exit_gen()`.
        """
        if self.pgconn.status != OK:
            return False

        conn = self._conn
        async with conn.lock:
            return await conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))

Some files were not shown because too many files have changed in this diff Show More