Commit a49331c6 authored by Jason Madden's avatar Jason Madden

Merge an updated version of #418. Fixes #418.

parent fd8c77a0
...@@ -7,12 +7,15 @@ from _socket import gaierror ...@@ -7,12 +7,15 @@ from _socket import gaierror
__all__ = ['channel'] __all__ = ['channel']
cdef object basestring cdef object string_types
cdef object text_type
if sys.version_info[0] >= 3: if sys.version_info[0] >= 3:
basestring = (bytes, str) string_types = str,
text_type = str
else: else:
basestring = __builtins__.basestring string_types = __builtins__.basestring,
text_type = __builtins__.unicode
TIMEOUT = 1 TIMEOUT = 1
...@@ -34,12 +37,13 @@ cdef extern from "dnshelper.c": ...@@ -34,12 +37,13 @@ cdef extern from "dnshelper.c":
struct ares_channeldata: struct ares_channeldata:
pass pass
object parse_h_name(hostent*)
object parse_h_aliases(hostent*) object parse_h_aliases(hostent*)
object parse_h_addr_list(hostent*) object parse_h_addr_list(hostent*)
void* create_object_from_hostent(void*) void* create_object_from_hostent(void*)
# this imports _socket lazily # this imports _socket lazily
object PyBytes_FromString(char*) object PyUnicode_FromString(char*)
int PyTuple_Check(object) int PyTuple_Check(object)
int PyArg_ParseTuple(object, char*, ...) except 0 int PyArg_ParseTuple(object, char*, ...) except 0
struct sockaddr_in6: struct sockaddr_in6:
...@@ -205,7 +209,7 @@ cdef void gevent_ares_host_callback(void *arg, int status, int timeouts, hostent ...@@ -205,7 +209,7 @@ cdef void gevent_ares_host_callback(void *arg, int status, int timeouts, hostent
callback(result(None, gaierror(status, strerror(status)))) callback(result(None, gaierror(status, strerror(status))))
else: else:
try: try:
host_result = ares_host_result(host.h_addrtype, (host.h_name, parse_h_aliases(host), parse_h_addr_list(host))) host_result = ares_host_result(host.h_addrtype, (parse_h_name(host), parse_h_aliases(host), parse_h_addr_list(host)))
except: except:
callback(result(None, sys.exc_info()[1])) callback(result(None, sys.exc_info()[1]))
else: else:
...@@ -226,11 +230,11 @@ cdef void gevent_ares_nameinfo_callback(void *arg, int status, int timeouts, cha ...@@ -226,11 +230,11 @@ cdef void gevent_ares_nameinfo_callback(void *arg, int status, int timeouts, cha
callback(result(None, gaierror(status, strerror(status)))) callback(result(None, gaierror(status, strerror(status))))
else: else:
if c_node: if c_node:
node = PyBytes_FromString(c_node) node = PyUnicode_FromString(c_node)
else: else:
node = None node = None
if c_service: if c_service:
service = PyBytes_FromString(c_service) service = PyUnicode_FromString(c_service)
else: else:
service = None service = None
callback(result((node, service))) callback(result((node, service)))
...@@ -312,7 +316,7 @@ cdef public class channel [object PyGeventAresChannelObject, type PyGeventAresCh ...@@ -312,7 +316,7 @@ cdef public class channel [object PyGeventAresChannelObject, type PyGeventAresCh
raise gaierror(cares.ARES_EDESTRUCTION, 'this ares channel has been destroyed') raise gaierror(cares.ARES_EDESTRUCTION, 'this ares channel has been destroyed')
if not servers: if not servers:
servers = [] servers = []
if isinstance(servers, basestring): if isinstance(servers, string_types):
servers = servers.split(',') servers = servers.split(',')
cdef int length = len(servers) cdef int length = len(servers)
cdef int result, index cdef int result, index
...@@ -327,8 +331,8 @@ cdef public class channel [object PyGeventAresChannelObject, type PyGeventAresCh ...@@ -327,8 +331,8 @@ cdef public class channel [object PyGeventAresChannelObject, type PyGeventAresCh
try: try:
index = 0 index = 0
for server in servers: for server in servers:
if isinstance(server, str): if isinstance(server, unicode):
server = server.encode() server = server.encode('ascii')
string = <char*?>server string = <char*?>server
if cares.ares_inet_pton(AF_INET, string, &c_servers[index].addr) > 0: if cares.ares_inet_pton(AF_INET, string, &c_servers[index].addr) > 0:
c_servers[index].family = AF_INET c_servers[index].family = AF_INET
......
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#include "cares_pton.h" #include "cares_pton.h"
#if PY_VERSION_HEX < 0x02060000 #if PY_VERSION_HEX < 0x02060000
#define PyBytes_FromString PyString_FromString #define PyUnicode_FromString PyString_FromString
#elif PY_MAJOR_VERSION < 3
#define PyUnicode_FromString PyBytes_FromString
#endif #endif
...@@ -49,7 +51,7 @@ gevent_append_addr(PyObject* list, int family, void* src, char* tmpbuf, size_t t ...@@ -49,7 +51,7 @@ gevent_append_addr(PyObject* list, int family, void* src, char* tmpbuf, size_t t
int status = -1; int status = -1;
PyObject* tmp; PyObject* tmp;
if (ares_inet_ntop(family, src, tmpbuf, tmpsize)) { if (ares_inet_ntop(family, src, tmpbuf, tmpsize)) {
tmp = PyBytes_FromString(tmpbuf); tmp = PyUnicode_FromString(tmpbuf);
if (tmp) { if (tmp) {
status = PyList_Append(list, tmp); status = PyList_Append(list, tmp);
Py_DECREF(tmp); Py_DECREF(tmp);
...@@ -59,6 +61,13 @@ gevent_append_addr(PyObject* list, int family, void* src, char* tmpbuf, size_t t ...@@ -59,6 +61,13 @@ gevent_append_addr(PyObject* list, int family, void* src, char* tmpbuf, size_t t
} }
static PyObject*
parse_h_name(struct hostent *h)
{
return PyUnicode_FromString(h->h_name);
}
static PyObject* static PyObject*
parse_h_aliases(struct hostent *h) parse_h_aliases(struct hostent *h)
{ {
...@@ -72,7 +81,7 @@ parse_h_aliases(struct hostent *h) ...@@ -72,7 +81,7 @@ parse_h_aliases(struct hostent *h)
for (pch = h->h_aliases; *pch != NULL; pch++) { for (pch = h->h_aliases; *pch != NULL; pch++) {
if (*pch != h->h_name && strcmp(*pch, h->h_name)) { if (*pch != h->h_name && strcmp(*pch, h->h_name)) {
int status; int status;
tmp = PyBytes_FromString(*pch); tmp = PyUnicode_FromString(*pch);
if (tmp == NULL) { if (tmp == NULL) {
break; break;
} }
......
# Copyright (c) 2011 Denis Bilenko. See LICENSE for details. # Copyright (c) 2011 Denis Bilenko. See LICENSE for details.
from __future__ import absolute_import from __future__ import absolute_import
import os import os
import sys
from _socket import getservbyname, getaddrinfo, gaierror, error from _socket import getservbyname, getaddrinfo, gaierror, error
from gevent.hub import Waiter, get_hub, string_types, text_type from gevent.hub import Waiter, get_hub, string_types, text_type, reraise, PY3
from gevent.socket import AF_UNSPEC, AF_INET, AF_INET6, SOCK_STREAM, SOCK_DGRAM, SOCK_RAW, AI_NUMERICHOST, EAI_SERVICE, AI_PASSIVE from gevent.socket import AF_UNSPEC, AF_INET, AF_INET6, SOCK_STREAM, SOCK_DGRAM, SOCK_RAW, AI_NUMERICHOST, EAI_SERVICE, AI_PASSIVE
from gevent.ares import channel, InvalidIP from gevent.ares import channel, InvalidIP
...@@ -19,7 +20,7 @@ class Resolver(object): ...@@ -19,7 +20,7 @@ class Resolver(object):
hub = get_hub() hub = get_hub()
self.hub = hub self.hub = hub
if use_environ: if use_environ:
for key in os.environ.keys(): for key in os.environ:
if key.startswith('GEVENTARES_'): if key.startswith('GEVENTARES_'):
name = key[11:].lower() name = key[11:].lower()
if name: if name:
...@@ -52,10 +53,16 @@ class Resolver(object): ...@@ -52,10 +53,16 @@ class Resolver(object):
return self.gethostbyname_ex(hostname, family)[-1][0] return self.gethostbyname_ex(hostname, family)[-1][0]
def gethostbyname_ex(self, hostname, family=AF_INET): def gethostbyname_ex(self, hostname, family=AF_INET):
if isinstance(hostname, text_type): if PY3:
hostname = hostname.encode('ascii') if isinstance(hostname, str):
elif not isinstance(hostname, str): hostname = hostname.encode('idna')
raise TypeError('Expected string, not %s' % type(hostname).__name__) elif not isinstance(hostname, (bytes, bytearray)):
raise TypeError('Expected es(idna), not %s' % type(hostname).__name__)
else:
if isinstance(hostname, text_type):
hostname = hostname.encode('ascii')
elif not isinstance(hostname, str):
raise TypeError('Expected string, not %s' % type(hostname).__name__)
while True: while True:
ares = self.ares ares = self.ares
...@@ -191,10 +198,16 @@ class Resolver(object): ...@@ -191,10 +198,16 @@ class Resolver(object):
raise raise
def _gethostbyaddr(self, ip_address): def _gethostbyaddr(self, ip_address):
if isinstance(ip_address, text_type): if PY3:
ip_address = ip_address.encode('ascii') if isinstance(ip_address, str):
elif not isinstance(ip_address, str): ip_address = ip_address.encode('idna')
raise TypeError('Expected string, not %s' % type(ip_address).__name__) elif not isinstance(ip_address, (bytes, bytearray)):
raise TypeError('Expected es(idna), not %s' % type(ip_address).__name__)
else:
if isinstance(ip_address, text_type):
ip_address = ip_address.encode('ascii')
elif not isinstance(ip_address, str):
raise TypeError('Expected string, not %s' % type(ip_address).__name__)
waiter = Waiter(self.hub) waiter = Waiter(self.hub)
try: try:
...@@ -205,11 +218,11 @@ class Resolver(object): ...@@ -205,11 +218,11 @@ class Resolver(object):
if not result: if not result:
raise raise
_ip_address = result[0][-1][0] _ip_address = result[0][-1][0]
if isinstance(_ip_address, text_type):
_ip_address = _ip_address.encode('ascii')
if _ip_address == ip_address: if _ip_address == ip_address:
raise raise
waiter.clear() waiter.clear()
if isinstance(_ip_address, text_type):
_ip_address = _ip_address.encode('ascii')
self.ares.gethostbyaddr(waiter, _ip_address) self.ares.gethostbyaddr(waiter, _ip_address)
return waiter.get() return waiter.get()
...@@ -230,10 +243,10 @@ class Resolver(object): ...@@ -230,10 +243,10 @@ class Resolver(object):
raise TypeError('getnameinfo() argument 1 must be a tuple') raise TypeError('getnameinfo() argument 1 must be a tuple')
address = sockaddr[0] address = sockaddr[0]
if isinstance(address, text_type): if not PY3 and isinstance(address, text_type):
address = address.encode('ascii') address = address.encode('ascii')
if not isinstance(address, str): if not isinstance(address, string_types):
raise TypeError('sockaddr[0] must be a string, not %s' % type(address).__name__) raise TypeError('sockaddr[0] must be a string, not %s' % type(address).__name__)
port = sockaddr[1] port = sockaddr[1]
...@@ -243,7 +256,7 @@ class Resolver(object): ...@@ -243,7 +256,7 @@ class Resolver(object):
waiter = Waiter(self.hub) waiter = Waiter(self.hub)
result = self._getaddrinfo(address, str(sockaddr[1]), family=AF_UNSPEC, socktype=SOCK_DGRAM) result = self._getaddrinfo(address, str(sockaddr[1]), family=AF_UNSPEC, socktype=SOCK_DGRAM)
if not result: if not result:
raise reraise(*sys.exc_info())
elif len(result) != 1: elif len(result) != 1:
raise error('sockaddr resolved to multiple addresses') raise error('sockaddr resolved to multiple addresses')
family, socktype, proto, name, address = result[0] family, socktype, proto, name, address = result[0]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment