Commit 5b6ce5a2 authored by Jeremy Hylton's avatar Jeremy Hylton

Lots of code reorganization with a few small API changes.

Change all the local names that start with SSL to start with PySSL.
The OpenSSL library defines lots of calls that start with "SSL_".  The
calls for Python's SSL objects also started with "SSL_".  This choice
made it really confusing to figure out which calls were to the library
and which calls were local to the file.

Add PySSL_SetError() that sets an exception based on the information
from SSL_get_error().  This function will eventually replace all the
calls that set it with an error message that is based on the name of
the call that failed rather than the reason it failed.  (Example: If
SSL_connect() failed it used to report "SSL_connect error" now it will
offer a specific message about why SSL_connect failed.)

    XXX It might be helpful to augment the error message generated
    below with the name of the SSL function that generated the error.
    I expect it's obvious most of the time.

Remove several unnecessary INCREFs in the module's constructor call.
PyDict_SetItem() and friends do the INCREF for you.
parent 22738b9b
......@@ -282,7 +282,7 @@ static PyObject *PyH_Error;
static PyObject *PyGAI_Error;
#ifdef USE_SSL
static PyObject *SSLErrorObject;
static PyObject *PySSLErrorObject;
#endif /* USE_SSL */
......@@ -498,13 +498,13 @@ typedef struct {
char server[256];
char issuer[256];
} SSLObject;
} PySSLObject;
staticforward PyTypeObject SSL_Type;
staticforward PyObject *SSL_SSLwrite(SSLObject *self, PyObject *args);
staticforward PyObject *SSL_SSLread(SSLObject *self, PyObject *args);
staticforward PyTypeObject PySSL_Type;
staticforward PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args);
staticforward PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args);
#define SSLObject_Check(v) ((v)->ob_type == &SSL_Type)
#define PySSLObject_Check(v) ((v)->ob_type == &PySSL_Type)
#endif /* USE_SSL */
......@@ -2490,19 +2490,88 @@ static char getnameinfo_doc[] =
\n\
Get host and port for a sockaddr.";
/* XXX It might be helpful to augment the error message generated
below with the name of the SSL function that generated the error.
I expect it's obvious most of the time.
*/
#ifdef USE_SSL
static PyObject *
PySSL_SetError(SSL *ssl, int ret)
{
PyObject *v, *n, *s;
char *errstr;
int err;
assert(ret <= 0);
err = SSL_get_error(ssl, ret);
n = PyInt_FromLong(err);
if (n == NULL)
return NULL;
v = PyTuple_New(2);
if (v == NULL) {
Py_DECREF(n);
return NULL;
}
switch (SSL_get_error(ssl, ret)) {
case SSL_ERROR_ZERO_RETURN:
errstr = "TLS/SSL connection has been closed";
break;
case SSL_ERROR_WANT_READ:
errstr = "The operation did not complete (read)";
break;
case SSL_ERROR_WANT_WRITE:
errstr = "The operation did not complete (write)";
break;
case SSL_ERROR_WANT_X509_LOOKUP:
errstr = "The operation did not complete (X509 lookup)";
break;
case SSL_ERROR_SYSCALL:
case SSL_ERROR_SSL:
{
unsigned long e = ERR_get_error();
if (e == 0) {
/* an EOF was observed that violates the protocol */
errstr = "EOF occurred in violation of protocol";
} else if (e == -1) {
/* the underlying BIO reported an I/O error */
Py_DECREF(v);
Py_DECREF(n);
PyErr_SetFromErrno(PyExc_IOError);
return NULL;
} else {
/* XXX Protected by global interpreter lock */
errstr = ERR_error_string(e, NULL);
}
break;
}
default:
errstr = "Invalid error code";
}
s = PyString_FromString(errstr);
if (s == NULL) {
Py_DECREF(v);
Py_DECREF(n);
}
PyTuple_SET_ITEM(v, 0, n);
PyTuple_SET_ITEM(v, 1, s);
PyErr_SetObject(PySSLErrorObject, v);
return NULL;
}
/* This is a C function to be called for new object initialization */
static SSLObject *
newSSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file)
static PySSLObject *
newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file)
{
SSLObject *self;
PySSLObject *self;
char *errstr = NULL;
int ret;
self = PyObject_New(SSLObject, &SSL_Type); /* Create new object */
self = PyObject_New(PySSLObject, &PySSL_Type); /* Create new object */
if (self == NULL){
errstr = "newSSLObject error";
errstr = "newPySSLObject error";
goto fail;
}
memset(self->server, '\0', sizeof(char) * 256);
......@@ -2523,7 +2592,7 @@ newSSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file)
goto fail;
}
if (key_file && cert_file) {
if (key_file) {
if (SSL_CTX_use_PrivateKey_file(self->ctx, key_file,
SSL_FILETYPE_PEM) < 1) {
errstr = "SSL_CTX_use_PrivateKey_file error";
......@@ -2545,8 +2614,9 @@ newSSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file)
/* Actually negotiate SSL connection */
/* XXX If SSL_connect() returns 0, it's also a failure. */
if ((SSL_connect(self->ssl)) == -1) {
errstr = "SSL_connect error";
ret = SSL_connect(self->ssl);
if (ret <= 0) {
PySSL_SetError(self->ssl, ret);
goto fail;
}
self->ssl->debug = 1;
......@@ -2562,7 +2632,7 @@ newSSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file)
return self;
fail:
if (errstr)
PyErr_SetString(SSLErrorObject, errstr);
PyErr_SetString(PySSLErrorObject, errstr);
Py_DECREF(self);
return NULL;
}
......@@ -2571,17 +2641,17 @@ newSSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file)
static PyObject *
PySocket_ssl(PyObject *self, PyObject *args)
{
SSLObject *rv;
PySSLObject *rv;
PySocketSockObject *Sock;
char *key_file;
char *cert_file;
char *key_file = NULL;
char *cert_file = NULL;
if (!PyArg_ParseTuple(args, "O!zz:ssl",
if (!PyArg_ParseTuple(args, "O!|zz:ssl",
&PySocketSock_Type, (PyObject*)&Sock,
&key_file, &cert_file))
return NULL;
rv = newSSLObject(Sock, key_file, cert_file);
rv = newPySSLObject(Sock, key_file, cert_file);
if (rv == NULL)
return NULL;
return (PyObject *)rv;
......@@ -2591,13 +2661,13 @@ static char ssl_doc[] =
"ssl(socket, keyfile, certfile) -> sslobject";
static PyObject *
SSL_server(SSLObject *self, PyObject *args)
PySSL_server(PySSLObject *self, PyObject *args)
{
return PyString_FromString(self->server);
}
static PyObject *
SSL_issuer(SSLObject *self, PyObject *args)
PySSL_issuer(PySSLObject *self, PyObject *args)
{
return PyString_FromString(self->issuer);
}
......@@ -2605,15 +2675,15 @@ SSL_issuer(SSLObject *self, PyObject *args)
/* SSL object methods */
static PyMethodDef SSLMethods[] = {
{"write", (PyCFunction)SSL_SSLwrite, 1},
{"read", (PyCFunction)SSL_SSLread, 1},
{"server", (PyCFunction)SSL_server, 1},
{"issuer", (PyCFunction)SSL_issuer, 1},
static PyMethodDef PySSLMethods[] = {
{"write", (PyCFunction)PySSL_SSLwrite, 1},
{"read", (PyCFunction)PySSL_SSLread, 1},
{"server", (PyCFunction)PySSL_server, 1},
{"issuer", (PyCFunction)PySSL_issuer, 1},
{NULL, NULL}
};
static void SSL_dealloc(SSLObject *self)
static void PySSL_dealloc(PySSLObject *self)
{
if (self->server_cert) /* Possible not to have one? */
X509_free (self->server_cert);
......@@ -2625,21 +2695,21 @@ static void SSL_dealloc(SSLObject *self)
PyObject_Del(self);
}
static PyObject *SSL_getattr(SSLObject *self, char *name)
static PyObject *PySSL_getattr(PySSLObject *self, char *name)
{
return Py_FindMethod(SSLMethods, (PyObject *)self, name);
return Py_FindMethod(PySSLMethods, (PyObject *)self, name);
}
staticforward PyTypeObject SSL_Type = {
staticforward PyTypeObject PySSL_Type = {
PyObject_HEAD_INIT(NULL)
0, /*ob_size*/
"SSL", /*tp_name*/
sizeof(SSLObject), /*tp_basicsize*/
sizeof(PySSLObject), /*tp_basicsize*/
0, /*tp_itemsize*/
/* methods */
(destructor)SSL_dealloc, /*tp_dealloc*/
(destructor)PySSL_dealloc, /*tp_dealloc*/
0, /*tp_print*/
(getattrfunc)SSL_getattr, /*tp_getattr*/
(getattrfunc)PySSL_getattr, /*tp_getattr*/
0, /*tp_setattr*/
0, /*tp_compare*/
0, /*tp_repr*/
......@@ -2651,7 +2721,7 @@ staticforward PyTypeObject SSL_Type = {
static PyObject *SSL_SSLwrite(SSLObject *self, PyObject *args)
static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args)
{
char *data;
size_t len;
......@@ -2663,7 +2733,7 @@ static PyObject *SSL_SSLwrite(SSLObject *self, PyObject *args)
return PyInt_FromLong((long)len);
}
static PyObject *SSL_SSLread(SSLObject *self, PyObject *args)
static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
{
PyObject *buf;
int count = 0;
......@@ -2686,14 +2756,14 @@ static PyObject *SSL_SSLread(SSLObject *self, PyObject *args)
assert(count == 0);
break;
default:
return PyErr_SetFromErrno(SSLErrorObject);
return PyErr_SetFromErrno(PySSLErrorObject);
}
fflush(stderr);
if (count < 0) {
Py_DECREF(buf);
return PyErr_SetFromErrno(SSLErrorObject);
return PyErr_SetFromErrno(PySSLErrorObject);
}
if (count != len && _PyString_Resize(&buf, count) < 0)
......@@ -2904,7 +2974,7 @@ init_socket(void)
#endif /* MS_WINDOWS */
#endif /* RISCOS */
#ifdef USE_SSL
SSL_Type.ob_type = &PyType_Type;
PySSL_Type.ob_type = &PyType_Type;
#endif
m = Py_InitModule3("_socket", PySocket_methods, module_doc);
d = PyModule_GetDict(m);
......@@ -2924,18 +2994,16 @@ init_socket(void)
#ifdef USE_SSL
SSL_load_error_strings();
SSLeay_add_ssl_algorithms();
SSLErrorObject = PyErr_NewException("socket.sslerror", NULL, NULL);
if (SSLErrorObject == NULL)
PySSLErrorObject = PyErr_NewException("socket.sslerror", NULL, NULL);
if (PySSLErrorObject == NULL)
return;
PyDict_SetItemString(d, "sslerror", SSLErrorObject);
Py_INCREF(&SSL_Type);
PyDict_SetItemString(d, "sslerror", PySSLErrorObject);
if (PyDict_SetItemString(d, "SSLType",
(PyObject *)&SSL_Type) != 0)
(PyObject *)&PySSL_Type) != 0)
return;
#endif /* USE_SSL */
PySocketSock_Type.ob_type = &PyType_Type;
PySocketSock_Type.tp_doc = sockettype_doc;
Py_INCREF(&PySocketSock_Type);
if (PyDict_SetItemString(d, "SocketType",
(PyObject *)&PySocketSock_Type) != 0)
return;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment