Commit b7539b30 authored by Jason Madden's avatar Jason Madden

Merge pull request #586 from gevent/python3-socket-type

c-ares update
parents d07e4bb6 a49331c6
......@@ -65,7 +65,7 @@ class socket(object):
# Only defined under Linux
@property
def type(self):
return _socket.socket.type.__get__(self._sock) & ~_socket.SOCK_NONBLOCK
return self._sock.type & ~_socket.SOCK_NONBLOCK
def __enter__(self):
return self
......
......@@ -7,12 +7,15 @@ from _socket import gaierror
__all__ = ['channel']
cdef object basestring
cdef object string_types
cdef object text_type
if sys.version_info[0] >= 3:
basestring = (bytes, str)
string_types = str,
text_type = str
else:
basestring = __builtins__.basestring
string_types = __builtins__.basestring,
text_type = __builtins__.unicode
TIMEOUT = 1
......@@ -34,12 +37,13 @@ cdef extern from "dnshelper.c":
struct ares_channeldata:
pass
object parse_h_name(hostent*)
object parse_h_aliases(hostent*)
object parse_h_addr_list(hostent*)
void* create_object_from_hostent(void*)
# this imports _socket lazily
object PyBytes_FromString(char*)
object PyUnicode_FromString(char*)
int PyTuple_Check(object)
int PyArg_ParseTuple(object, char*, ...) except 0
struct sockaddr_in6:
......@@ -205,7 +209,7 @@ cdef void gevent_ares_host_callback(void *arg, int status, int timeouts, hostent
callback(result(None, gaierror(status, strerror(status))))
else:
try:
host_result = ares_host_result(host.h_addrtype, (host.h_name, parse_h_aliases(host), parse_h_addr_list(host)))
host_result = ares_host_result(host.h_addrtype, (parse_h_name(host), parse_h_aliases(host), parse_h_addr_list(host)))
except:
callback(result(None, sys.exc_info()[1]))
else:
......@@ -226,11 +230,11 @@ cdef void gevent_ares_nameinfo_callback(void *arg, int status, int timeouts, cha
callback(result(None, gaierror(status, strerror(status))))
else:
if c_node:
node = PyBytes_FromString(c_node)
node = PyUnicode_FromString(c_node)
else:
node = None
if c_service:
service = PyBytes_FromString(c_service)
service = PyUnicode_FromString(c_service)
else:
service = None
callback(result((node, service)))
......@@ -312,7 +316,7 @@ cdef public class channel [object PyGeventAresChannelObject, type PyGeventAresCh
raise gaierror(cares.ARES_EDESTRUCTION, 'this ares channel has been destroyed')
if not servers:
servers = []
if isinstance(servers, basestring):
if isinstance(servers, string_types):
servers = servers.split(',')
cdef int length = len(servers)
cdef int result, index
......@@ -327,8 +331,8 @@ cdef public class channel [object PyGeventAresChannelObject, type PyGeventAresCh
try:
index = 0
for server in servers:
if isinstance(server, str):
server = server.encode()
if isinstance(server, unicode):
server = server.encode('ascii')
string = <char*?>server
if cares.ares_inet_pton(AF_INET, string, &c_servers[index].addr) > 0:
c_servers[index].family = AF_INET
......
......@@ -14,7 +14,9 @@
#include "cares_pton.h"
#if PY_VERSION_HEX < 0x02060000
#define PyBytes_FromString PyString_FromString
#define PyUnicode_FromString PyString_FromString
#elif PY_MAJOR_VERSION < 3
#define PyUnicode_FromString PyBytes_FromString
#endif
......@@ -49,7 +51,7 @@ gevent_append_addr(PyObject* list, int family, void* src, char* tmpbuf, size_t t
int status = -1;
PyObject* tmp;
if (ares_inet_ntop(family, src, tmpbuf, tmpsize)) {
tmp = PyBytes_FromString(tmpbuf);
tmp = PyUnicode_FromString(tmpbuf);
if (tmp) {
status = PyList_Append(list, tmp);
Py_DECREF(tmp);
......@@ -59,6 +61,13 @@ gevent_append_addr(PyObject* list, int family, void* src, char* tmpbuf, size_t t
}
static PyObject*
parse_h_name(struct hostent *h)
{
return PyUnicode_FromString(h->h_name);
}
static PyObject*
parse_h_aliases(struct hostent *h)
{
......@@ -72,7 +81,7 @@ parse_h_aliases(struct hostent *h)
for (pch = h->h_aliases; *pch != NULL; pch++) {
if (*pch != h->h_name && strcmp(*pch, h->h_name)) {
int status;
tmp = PyBytes_FromString(*pch);
tmp = PyUnicode_FromString(*pch);
if (tmp == NULL) {
break;
}
......
# Copyright (c) 2011 Denis Bilenko. See LICENSE for details.
from __future__ import absolute_import
import os
import sys
from _socket import getservbyname, getaddrinfo, gaierror, error
from gevent.hub import Waiter, get_hub, string_types, text_type
from gevent.hub import Waiter, get_hub, string_types, text_type, reraise, PY3
from gevent.socket import AF_UNSPEC, AF_INET, AF_INET6, SOCK_STREAM, SOCK_DGRAM, SOCK_RAW, AI_NUMERICHOST, EAI_SERVICE, AI_PASSIVE
from gevent.ares import channel, InvalidIP
......@@ -19,7 +20,7 @@ class Resolver(object):
hub = get_hub()
self.hub = hub
if use_environ:
for key in os.environ.keys():
for key in os.environ:
if key.startswith('GEVENTARES_'):
name = key[11:].lower()
if name:
......@@ -52,10 +53,16 @@ class Resolver(object):
return self.gethostbyname_ex(hostname, family)[-1][0]
def gethostbyname_ex(self, hostname, family=AF_INET):
if isinstance(hostname, text_type):
hostname = hostname.encode('ascii')
elif not isinstance(hostname, str):
raise TypeError('Expected string, not %s' % type(hostname).__name__)
if PY3:
if isinstance(hostname, str):
hostname = hostname.encode('idna')
elif not isinstance(hostname, (bytes, bytearray)):
raise TypeError('Expected es(idna), not %s' % type(hostname).__name__)
else:
if isinstance(hostname, text_type):
hostname = hostname.encode('ascii')
elif not isinstance(hostname, str):
raise TypeError('Expected string, not %s' % type(hostname).__name__)
while True:
ares = self.ares
......@@ -191,10 +198,16 @@ class Resolver(object):
raise
def _gethostbyaddr(self, ip_address):
if isinstance(ip_address, text_type):
ip_address = ip_address.encode('ascii')
elif not isinstance(ip_address, str):
raise TypeError('Expected string, not %s' % type(ip_address).__name__)
if PY3:
if isinstance(ip_address, str):
ip_address = ip_address.encode('idna')
elif not isinstance(ip_address, (bytes, bytearray)):
raise TypeError('Expected es(idna), not %s' % type(ip_address).__name__)
else:
if isinstance(ip_address, text_type):
ip_address = ip_address.encode('ascii')
elif not isinstance(ip_address, str):
raise TypeError('Expected string, not %s' % type(ip_address).__name__)
waiter = Waiter(self.hub)
try:
......@@ -205,11 +218,11 @@ class Resolver(object):
if not result:
raise
_ip_address = result[0][-1][0]
if isinstance(_ip_address, text_type):
_ip_address = _ip_address.encode('ascii')
if _ip_address == ip_address:
raise
waiter.clear()
if isinstance(_ip_address, text_type):
_ip_address = _ip_address.encode('ascii')
self.ares.gethostbyaddr(waiter, _ip_address)
return waiter.get()
......@@ -230,10 +243,10 @@ class Resolver(object):
raise TypeError('getnameinfo() argument 1 must be a tuple')
address = sockaddr[0]
if isinstance(address, text_type):
if not PY3 and isinstance(address, text_type):
address = address.encode('ascii')
if not isinstance(address, str):
if not isinstance(address, string_types):
raise TypeError('sockaddr[0] must be a string, not %s' % type(address).__name__)
port = sockaddr[1]
......@@ -243,7 +256,7 @@ class Resolver(object):
waiter = Waiter(self.hub)
result = self._getaddrinfo(address, str(sockaddr[1]), family=AF_UNSPEC, socktype=SOCK_DGRAM)
if not result:
raise
reraise(*sys.exc_info())
elif len(result) != 1:
raise error('sockaddr resolved to multiple addresses')
family, socktype, proto, name, address = result[0]
......
......@@ -185,6 +185,16 @@ class TestTCP(greentest.TestCase):
fd.close()
acceptor.join()
def test_attributes(self):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
self.assertEqual(socket.AF_INET, s.type)
self.assertEqual(socket.SOCK_DGRAM, s.family)
self.assertEqual(0, s.proto)
if hasattr(socket, 'SOCK_NONBLOCK'):
s.settimeout(1)
self.assertEqual(socket.AF_INET, s.type)
def get_port():
tempsock = socket.socket()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment