summaryrefslogtreecommitdiff
path: root/slixmpp/xmlstream
diff options
context:
space:
mode:
authorDaniel Roschka <danielroschka@phoenitydawn.de>2022-07-29 12:04:01 +0200
committerDaniel Roschka <danielroschka@phoenitydawn.de>2022-07-31 13:15:25 +0200
commitd43c83800e51c2455b5070a1ccaca56b57fb1575 (patch)
treefdf72a437d1bfeedba9af549ee7ff49b66266d93 /slixmpp/xmlstream
parent1f47acaec13f30832fcfb56fc45843e90ad27673 (diff)
downloadslixmpp-d43c83800e51c2455b5070a1ccaca56b57fb1575.tar.gz
slixmpp-d43c83800e51c2455b5070a1ccaca56b57fb1575.tar.bz2
slixmpp-d43c83800e51c2455b5070a1ccaca56b57fb1575.tar.xz
slixmpp-d43c83800e51c2455b5070a1ccaca56b57fb1575.zip
Use gethostbyname when using aiodns
Slixmpp behaves differently when resolving host names, whether aiodns is used or not. With aiodns only DNS is used, while without `asyncio.loop.getaddrinfo()` is used instead, which utilizes the Name Service Switch (NSS) to resolve host names by other means (hosts-file, mDNS, ...) as well. To unify the behavior, this replaces the use of `aiodns.DNSResolver().query()` with `aiodns.DNSResolver().gethostbyname()`. This makes the behavior resolving host names more consistent between using aiodns or not, as both now honor the NSS configuration and removes the need for the previously existing workaround to resolve localhost.
Diffstat (limited to 'slixmpp/xmlstream')
-rw-r--r--slixmpp/xmlstream/resolver.py36
1 files changed, 20 insertions, 16 deletions
diff --git a/slixmpp/xmlstream/resolver.py b/slixmpp/xmlstream/resolver.py
index e524da3b..3de6629d 100644
--- a/slixmpp/xmlstream/resolver.py
+++ b/slixmpp/xmlstream/resolver.py
@@ -15,7 +15,13 @@ from slixmpp.types import Protocol
log = logging.getLogger(__name__)
-class AnswerProtocol(Protocol):
+class GetHostByNameAnswerProtocol(Protocol):
+ name: str
+ aliases: List[str]
+ addresses: List[str]
+
+
+class QueryAnswerProtocol(Protocol):
host: str
priority: int
weight: int
@@ -23,6 +29,9 @@ class AnswerProtocol(Protocol):
class ResolverProtocol(Protocol):
+ def gethostbyname(self, host: str, socket_family: socket.AddressFamily) -> Future:
+ ...
+
def query(self, query: str, querytype: str) -> Future:
...
@@ -147,11 +156,6 @@ async def resolve(host: str, port: int, *, loop: AbstractEventLoop,
results = []
for host, port in hosts:
- if host == 'localhost':
- if use_ipv6:
- results.append((host, '::1', port))
- results.append((host, '127.0.0.1', port))
-
if use_ipv6:
aaaa = await get_AAAA(host, resolver=resolver,
use_aiodns=use_aiodns, loop=loop)
@@ -201,13 +205,13 @@ async def get_A(host: str, *, loop: AbstractEventLoop,
return []
# Using aiodns:
- future = resolver.query(host, 'A')
+ future = resolver.gethostbyname(host, socket.AF_INET)
try:
- recs = cast(Iterable[AnswerProtocol], await future)
+ recs = cast(GetHostByNameAnswerProtocol, await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
- recs = []
- return [rec.host for rec in recs]
+ return []
+ return [addr for addr in recs.addresses]
async def get_AAAA(host: str, *, loop: AbstractEventLoop,
@@ -249,13 +253,13 @@ async def get_AAAA(host: str, *, loop: AbstractEventLoop,
return []
# Using aiodns:
- future = resolver.query(host, 'AAAA')
+ future = resolver.gethostbyname(host, socket.AF_INET6)
try:
- recs = cast(Iterable[AnswerProtocol], await future)
+ recs = cast(GetHostByNameAnswerProtocol, await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
- recs = []
- return [rec.host for rec in recs]
+ return []
+ return [addr for addr in recs.addresses]
async def get_SRV(host: str, port: int, service: str,
@@ -295,12 +299,12 @@ async def get_SRV(host: str, port: int, service: str,
try:
future = resolver.query('_%s._%s.%s' % (service, proto, host),
'SRV')
- recs = cast(Iterable[AnswerProtocol], await future)
+ recs = cast(Iterable[QueryAnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
return []
- answers: Dict[int, List[AnswerProtocol]] = {}
+ answers: Dict[int, List[QueryAnswerProtocol]] = {}
for rec in recs:
if rec.priority not in answers:
answers[rec.priority] = []