This commit is contained in:
@@ -45,11 +45,6 @@ class SSLMode(enum.IntEnum):
|
||||
return getattr(cls, sslmode.replace('-', '_'))
|
||||
|
||||
|
||||
class SSLNegotiation(compat.StrEnum):
|
||||
postgres = "postgres"
|
||||
direct = "direct"
|
||||
|
||||
|
||||
_ConnectionParameters = collections.namedtuple(
|
||||
'ConnectionParameters',
|
||||
[
|
||||
@@ -58,11 +53,9 @@ _ConnectionParameters = collections.namedtuple(
|
||||
'database',
|
||||
'ssl',
|
||||
'sslmode',
|
||||
'ssl_negotiation',
|
||||
'direct_tls',
|
||||
'server_settings',
|
||||
'target_session_attrs',
|
||||
'krbsrvname',
|
||||
'gsslib',
|
||||
])
|
||||
|
||||
|
||||
@@ -268,13 +261,12 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
|
||||
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
|
||||
password, passfile, database, ssl,
|
||||
direct_tls, server_settings,
|
||||
target_session_attrs, krbsrvname, gsslib):
|
||||
target_session_attrs):
|
||||
# `auth_hosts` is the version of host information for the purposes
|
||||
# of reading the pgpass file.
|
||||
auth_hosts = None
|
||||
sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
|
||||
ssl_min_protocol_version = ssl_max_protocol_version = None
|
||||
sslnegotiation = None
|
||||
|
||||
if dsn:
|
||||
parsed = urllib.parse.urlparse(dsn)
|
||||
@@ -368,9 +360,6 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
|
||||
if 'sslrootcert' in query:
|
||||
sslrootcert = query.pop('sslrootcert')
|
||||
|
||||
if 'sslnegotiation' in query:
|
||||
sslnegotiation = query.pop('sslnegotiation')
|
||||
|
||||
if 'sslcrl' in query:
|
||||
sslcrl = query.pop('sslcrl')
|
||||
|
||||
@@ -394,16 +383,6 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
|
||||
if target_session_attrs is None:
|
||||
target_session_attrs = dsn_target_session_attrs
|
||||
|
||||
if 'krbsrvname' in query:
|
||||
val = query.pop('krbsrvname')
|
||||
if krbsrvname is None:
|
||||
krbsrvname = val
|
||||
|
||||
if 'gsslib' in query:
|
||||
val = query.pop('gsslib')
|
||||
if gsslib is None:
|
||||
gsslib = val
|
||||
|
||||
if query:
|
||||
if server_settings is None:
|
||||
server_settings = query
|
||||
@@ -512,36 +491,13 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
|
||||
if ssl is None and have_tcp_addrs:
|
||||
ssl = 'prefer'
|
||||
|
||||
if direct_tls is not None:
|
||||
sslneg = (
|
||||
SSLNegotiation.direct if direct_tls else SSLNegotiation.postgres
|
||||
)
|
||||
else:
|
||||
if sslnegotiation is None:
|
||||
sslnegotiation = os.environ.get("PGSSLNEGOTIATION")
|
||||
|
||||
if sslnegotiation is not None:
|
||||
try:
|
||||
sslneg = SSLNegotiation(sslnegotiation)
|
||||
except ValueError:
|
||||
modes = ', '.join(
|
||||
m.name.replace('_', '-')
|
||||
for m in SSLNegotiation
|
||||
)
|
||||
raise exceptions.ClientConfigurationError(
|
||||
f'`sslnegotiation` parameter must be one of: {modes}'
|
||||
) from None
|
||||
else:
|
||||
sslneg = SSLNegotiation.postgres
|
||||
|
||||
if isinstance(ssl, (str, SSLMode)):
|
||||
try:
|
||||
sslmode = SSLMode.parse(ssl)
|
||||
except AttributeError:
|
||||
modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
|
||||
raise exceptions.ClientConfigurationError(
|
||||
'`sslmode` parameter must be one of: {}'.format(modes)
|
||||
) from None
|
||||
'`sslmode` parameter must be one of: {}'.format(modes))
|
||||
|
||||
# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
|
||||
if sslmode < SSLMode.allow:
|
||||
@@ -694,24 +650,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
|
||||
)
|
||||
) from None
|
||||
|
||||
if krbsrvname is None:
|
||||
krbsrvname = os.getenv('PGKRBSRVNAME')
|
||||
|
||||
if gsslib is None:
|
||||
gsslib = os.getenv('PGGSSLIB')
|
||||
if gsslib is None:
|
||||
gsslib = 'sspi' if _system == 'Windows' else 'gssapi'
|
||||
if gsslib not in {'gssapi', 'sspi'}:
|
||||
raise exceptions.ClientConfigurationError(
|
||||
"gsslib parameter must be either 'gssapi' or 'sspi'"
|
||||
", got {!r}".format(gsslib))
|
||||
|
||||
params = _ConnectionParameters(
|
||||
user=user, password=password, database=database, ssl=ssl,
|
||||
sslmode=sslmode, ssl_negotiation=sslneg,
|
||||
sslmode=sslmode, direct_tls=direct_tls,
|
||||
server_settings=server_settings,
|
||||
target_session_attrs=target_session_attrs,
|
||||
krbsrvname=krbsrvname, gsslib=gsslib)
|
||||
target_session_attrs=target_session_attrs)
|
||||
|
||||
return addrs, params
|
||||
|
||||
@@ -722,7 +665,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
|
||||
max_cached_statement_lifetime,
|
||||
max_cacheable_statement_size,
|
||||
ssl, direct_tls, server_settings,
|
||||
target_session_attrs, krbsrvname, gsslib):
|
||||
target_session_attrs):
|
||||
local_vars = locals()
|
||||
for var_name in {'max_cacheable_statement_size',
|
||||
'max_cached_statement_lifetime',
|
||||
@@ -751,8 +694,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
|
||||
password=password, passfile=passfile, ssl=ssl,
|
||||
direct_tls=direct_tls, database=database,
|
||||
server_settings=server_settings,
|
||||
target_session_attrs=target_session_attrs,
|
||||
krbsrvname=krbsrvname, gsslib=gsslib)
|
||||
target_session_attrs=target_session_attrs)
|
||||
|
||||
config = _ClientConfiguration(
|
||||
command_timeout=command_timeout,
|
||||
@@ -914,9 +856,9 @@ async def __connect_addr(
|
||||
# UNIX socket
|
||||
connector = loop.create_unix_connection(proto_factory, addr)
|
||||
|
||||
elif params.ssl and params.ssl_negotiation is SSLNegotiation.direct:
|
||||
# if ssl and ssl_negotiation is `direct`, skip STARTTLS and perform
|
||||
# direct SSL connection
|
||||
elif params.ssl and params.direct_tls:
|
||||
# if ssl and direct_tls are given, skip STARTTLS and perform direct
|
||||
# SSL connection
|
||||
connector = loop.create_connection(
|
||||
proto_factory, *addr, ssl=params.ssl
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user