Commit cd69f462 authored by unknown's avatar unknown

Merge shellback.(none):/home/msvensson/mysql/yassl_import/my50-yassl_import

into  shellback.(none):/home/msvensson/mysql/yassl_import/mysql-5.0-maint


extra/yassl/src/ssl.cpp:
  Auto merged
parents 89d106c1 eb6ab467
yaSSL Release notes, version 1.4.0 (08/13/06) yaSSL Release notes, version 1.5.0 (11/09/06)
This release of yaSSL contains bug fixes, portability enhancements,
and full TLS 1.1 support. Use the functions:
SSL_METHOD *TLSv1_1_server_method(void);
SSL_METHOD *TLSv1_1_client_method(void);
or the SSLv23 versions (even though yaSSL doesn't support SSL 2.0 the v23
means to pick the highest of SSL 3.0, TLS 1.0, or TLS 1.1.
See normal build instructions below under 1.0.6.
See libcurl build instructions below under 1.3.0.
****************yaSSL Release notes, version 1.4.5 (10/15/06)
This release of yaSSL contains bug fixes, portability enhancements,
zlib compression support, removal of assembly instructions at runtime if
not supported, and initial TLS 1.1 support.
Compression Notes: yaSSL uses zlib for compression and the compression
should only be used if yaSSL is at both ends because the implementation
details aren't yet standard. If you'd like to turn compression on use
the SSL_set_compression() function on the client before calling
SSL_connect(). If both the client and server were built with zlib support
then the connection will use compression. If the client isn't built with
support then SSL_set_compression() will return an error (-1).
To build yaSSL with zlib support on Unix simply have zlib support on your
system and configure will find it if it's in the standard locations. If
it's somewhere else use the option ./configure --with-zlib=DIR. If you'd
like to disable compression support in yaSSL use ./configure --without-zlib.
To build yaSSL with zlib support on Windows:
1) download zlib from http://www.zlib.net/
2) follow the instructions in zlib from projects/visualc6/README.txt
for how to add the zlib project into the yaSSL workspace noting that
you'll need to add configuration support for "Win32 Debug" and
"Win32 Release" in note 3 under "To use:".
3) define HAVE_LIBZ when building yaSSL
See normal build instructions below under 1.0.6.
See libcurl build instructions below under 1.3.0.
********************yaSSL Release notes, version 1.4.0 (08/13/06)
This release of yaSSL contains bug fixes, portability enhancements, This release of yaSSL contains bug fixes, portability enhancements,
...@@ -122,18 +174,6 @@ Choose (Re)Build All from the project workspace ...@@ -122,18 +174,6 @@ Choose (Re)Build All from the project workspace
run Debug\testsuite.exe from yaSSL-Home\testsuite to test the build run Debug\testsuite.exe from yaSSL-Home\testsuite to test the build
--To enable ia32 assembly for TaoCrypt ciphers and message digests
On MSVC this is always on
On GCC **, use ./configure --enable-ia32-asm
** This isn't on by default because of the use of intel syntax and the
problem that olders versions of gas have with some addressing statements.
If you enable this and get assemler errors during compilation or can't
pass the TaoCrypt tests, please send todd@yassl.com a message and disable
this option in the meantime.
***************** yaSSL Release notes, version 1.0.5 ***************** yaSSL Release notes, version 1.0.5
......
...@@ -5,6 +5,35 @@ ...@@ -5,6 +5,35 @@
//#define TEST_RESUME //#define TEST_RESUME
void ClientError(SSL_CTX* ctx, SSL* ssl, SOCKET_T& sockfd, const char* msg)
{
SSL_CTX_free(ctx);
SSL_free(ssl);
tcp_close(sockfd);
err_sys(msg);
}
#ifdef NON_BLOCKING
void NonBlockingSSL_Connect(SSL* ssl, SSL_CTX* ctx, SOCKET_T& sockfd)
{
int ret = SSL_connect(ssl);
while (ret =! SSL_SUCCESS && SSL_get_error(ssl, 0) ==
SSL_ERROR_WANT_READ) {
printf("... client would block\n");
#ifdef _WIN32
Sleep(1000);
#else
sleep(1);
#endif
ret = SSL_connect(ssl);
}
if (ret != SSL_SUCCESS)
ClientError(ctx, ssl, sockfd, "SSL_connect failed");
}
#endif
void client_test(void* args) void client_test(void* args)
{ {
#ifdef _WIN32 #ifdef _WIN32
...@@ -18,6 +47,9 @@ void client_test(void* args) ...@@ -18,6 +47,9 @@ void client_test(void* args)
set_args(argc, argv, *static_cast<func_args*>(args)); set_args(argc, argv, *static_cast<func_args*>(args));
tcp_connect(sockfd); tcp_connect(sockfd);
#ifdef NON_BLOCKING
tcp_set_nonblocking(sockfd);
#endif
SSL_METHOD* method = TLSv1_client_method(); SSL_METHOD* method = TLSv1_client_method();
SSL_CTX* ctx = SSL_CTX_new(method); SSL_CTX* ctx = SSL_CTX_new(method);
...@@ -27,13 +59,13 @@ void client_test(void* args) ...@@ -27,13 +59,13 @@ void client_test(void* args)
SSL_set_fd(ssl, sockfd); SSL_set_fd(ssl, sockfd);
#ifdef NON_BLOCKING
NonBlockingSSL_Connect(ssl, ctx, sockfd);
#else
if (SSL_connect(ssl) != SSL_SUCCESS) if (SSL_connect(ssl) != SSL_SUCCESS)
{ ClientError(ctx, ssl, sockfd, "SSL_connect failed");
SSL_CTX_free(ctx); #endif
SSL_free(ssl);
tcp_close(sockfd);
err_sys("SSL_connect failed");
}
showPeer(ssl); showPeer(ssl);
const char* cipher = 0; const char* cipher = 0;
...@@ -49,16 +81,14 @@ void client_test(void* args) ...@@ -49,16 +81,14 @@ void client_test(void* args)
char msg[] = "hello yassl!"; char msg[] = "hello yassl!";
if (SSL_write(ssl, msg, sizeof(msg)) != sizeof(msg)) if (SSL_write(ssl, msg, sizeof(msg)) != sizeof(msg))
{ ClientError(ctx, ssl, sockfd, "SSL_write failed");
SSL_CTX_free(ctx);
SSL_free(ssl);
tcp_close(sockfd);
err_sys("SSL_write failed");
}
char reply[1024]; char reply[1024];
reply[SSL_read(ssl, reply, sizeof(reply))] = 0; int input = SSL_read(ssl, reply, sizeof(reply));
if (input > 0) {
reply[input] = 0;
printf("Server response: %s\n", reply); printf("Server response: %s\n", reply);
}
#ifdef TEST_RESUME #ifdef TEST_RESUME
SSL_SESSION* session = SSL_get_session(ssl); SSL_SESSION* session = SSL_get_session(ssl);
...@@ -75,24 +105,17 @@ void client_test(void* args) ...@@ -75,24 +105,17 @@ void client_test(void* args)
SSL_set_session(sslResume, session); SSL_set_session(sslResume, session);
if (SSL_connect(sslResume) != SSL_SUCCESS) if (SSL_connect(sslResume) != SSL_SUCCESS)
{ ClientError(ctx, sslResume, sockfd, "SSL_resume failed");
SSL_CTX_free(ctx);
SSL_free(ssl);
tcp_close(sockfd);
err_sys("SSL resume failed");
}
showPeer(sslResume); showPeer(sslResume);
if (SSL_write(sslResume, msg, sizeof(msg)) != sizeof(msg)) if (SSL_write(sslResume, msg, sizeof(msg)) != sizeof(msg))
{ ClientError(ctx, sslResume, sockfd, "SSL_write failed");
SSL_CTX_free(ctx);
SSL_free(ssl);
tcp_close(sockfd);
err_sys("SSL_write failed");
}
reply[SSL_read(sslResume, reply, sizeof(reply))] = 0; input = SSL_read(sslResume, reply, sizeof(reply));
if (input > 0) {
reply[input] = 0;
printf("Server response: %s\n", reply); printf("Server response: %s\n", reply);
}
SSL_shutdown(sslResume); SSL_shutdown(sslResume);
SSL_free(sslResume); SSL_free(sslResume);
......
...@@ -3,6 +3,15 @@ ...@@ -3,6 +3,15 @@
#include "../../testsuite/test.hpp" #include "../../testsuite/test.hpp"
void EchoClientError(SSL_CTX* ctx, SSL* ssl, SOCKET_T& sockfd, const char* msg)
{
SSL_CTX_free(ctx);
SSL_free(ssl);
tcp_close(sockfd);
err_sys(msg);
}
void echoclient_test(void* args) void echoclient_test(void* args)
{ {
#ifdef _WIN32 #ifdef _WIN32
...@@ -35,7 +44,7 @@ void echoclient_test(void* args) ...@@ -35,7 +44,7 @@ void echoclient_test(void* args)
tcp_connect(sockfd); tcp_connect(sockfd);
SSL_METHOD* method = TLSv1_client_method(); SSL_METHOD* method = SSLv23_client_method();
SSL_CTX* ctx = SSL_CTX_new(method); SSL_CTX* ctx = SSL_CTX_new(method);
set_certs(ctx); set_certs(ctx);
SSL* ssl = SSL_new(ctx); SSL* ssl = SSL_new(ctx);
...@@ -43,12 +52,7 @@ void echoclient_test(void* args) ...@@ -43,12 +52,7 @@ void echoclient_test(void* args)
SSL_set_fd(ssl, sockfd); SSL_set_fd(ssl, sockfd);
if (SSL_connect(ssl) != SSL_SUCCESS) if (SSL_connect(ssl) != SSL_SUCCESS)
{ EchoClientError(ctx, ssl, sockfd, "SSL_connect failed");
SSL_CTX_free(ctx);
SSL_free(ssl);
tcp_close(sockfd);
err_sys("SSL_connect failed");
}
char send[1024]; char send[1024];
char reply[1024]; char reply[1024];
...@@ -57,12 +61,7 @@ void echoclient_test(void* args) ...@@ -57,12 +61,7 @@ void echoclient_test(void* args)
int sendSz = strlen(send) + 1; int sendSz = strlen(send) + 1;
if (SSL_write(ssl, send, sendSz) != sendSz) if (SSL_write(ssl, send, sendSz) != sendSz)
{ EchoClientError(ctx, ssl, sockfd, "SSL_write failed");
SSL_CTX_free(ctx);
SSL_free(ssl);
tcp_close(sockfd);
err_sys("SSL_write failed");
}
if (strncmp(send, "quit", 4) == 0) { if (strncmp(send, "quit", 4) == 0) {
fputs("sending server shutdown command: quit!\n", fout); fputs("sending server shutdown command: quit!\n", fout);
......
...@@ -56,7 +56,7 @@ THREAD_RETURN YASSL_API echoserver_test(void* args) ...@@ -56,7 +56,7 @@ THREAD_RETURN YASSL_API echoserver_test(void* args)
tcp_listen(sockfd); tcp_listen(sockfd);
SSL_METHOD* method = TLSv1_server_method(); SSL_METHOD* method = SSLv23_server_method();
SSL_CTX* ctx = SSL_CTX_new(method); SSL_CTX* ctx = SSL_CTX_new(method);
set_serverCerts(ctx); set_serverCerts(ctx);
...@@ -87,8 +87,12 @@ THREAD_RETURN YASSL_API echoserver_test(void* args) ...@@ -87,8 +87,12 @@ THREAD_RETURN YASSL_API echoserver_test(void* args)
SSL* ssl = SSL_new(ctx); SSL* ssl = SSL_new(ctx);
SSL_set_fd(ssl, clientfd); SSL_set_fd(ssl, clientfd);
if (SSL_accept(ssl) != SSL_SUCCESS) if (SSL_accept(ssl) != SSL_SUCCESS) {
EchoError(ctx, ssl, sockfd, clientfd, "SSL_accept failed"); printf("SSL_accept failed\n");
SSL_free(ssl);
tcp_close(clientfd);
continue;
}
char command[1024]; char command[1024];
int echoSz(0); int echoSz(0);
...@@ -130,6 +134,7 @@ THREAD_RETURN YASSL_API echoserver_test(void* args) ...@@ -130,6 +134,7 @@ THREAD_RETURN YASSL_API echoserver_test(void* args)
if (SSL_write(ssl, command, echoSz) != echoSz) if (SSL_write(ssl, command, echoSz) != echoSz)
EchoError(ctx, ssl, sockfd, clientfd, "SSL_write failed"); EchoError(ctx, ssl, sockfd, clientfd, "SSL_write failed");
} }
SSL_shutdown(ssl);
SSL_free(ssl); SSL_free(ssl);
tcp_close(clientfd); tcp_close(clientfd);
} }
......
...@@ -13,6 +13,26 @@ void ServerError(SSL_CTX* ctx, SSL* ssl, SOCKET_T& sockfd, const char* msg) ...@@ -13,6 +13,26 @@ void ServerError(SSL_CTX* ctx, SSL* ssl, SOCKET_T& sockfd, const char* msg)
} }
#ifdef NON_BLOCKING
void NonBlockingSSL_Accept(SSL* ssl, SSL_CTX* ctx, SOCKET_T& clientfd)
{
int ret = SSL_accept(ssl);
while (ret != SSL_SUCCESS && SSL_get_error(ssl, 0) ==
SSL_ERROR_WANT_READ) {
printf("... server would block\n");
#ifdef _WIN32
Sleep(1000);
#else
sleep(1);
#endif
ret = SSL_accept(ssl);
}
if (ret != SSL_SUCCESS)
ServerError(ctx, ssl, clientfd, "SSL_accept failed");
}
#endif
THREAD_RETURN YASSL_API server_test(void* args) THREAD_RETURN YASSL_API server_test(void* args)
{ {
#ifdef _WIN32 #ifdef _WIN32
...@@ -33,7 +53,7 @@ THREAD_RETURN YASSL_API server_test(void* args) ...@@ -33,7 +53,7 @@ THREAD_RETURN YASSL_API server_test(void* args)
SSL_METHOD* method = TLSv1_server_method(); SSL_METHOD* method = TLSv1_server_method();
SSL_CTX* ctx = SSL_CTX_new(method); SSL_CTX* ctx = SSL_CTX_new(method);
//SSL_CTX_set_cipher_list(ctx, "RC4-SHA"); //SSL_CTX_set_cipher_list(ctx, "RC4-SHA:RC4-MD5");
SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, 0); SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, 0);
set_serverCerts(ctx); set_serverCerts(ctx);
DH* dh = set_tmpDH(ctx); DH* dh = set_tmpDH(ctx);
...@@ -41,15 +61,22 @@ THREAD_RETURN YASSL_API server_test(void* args) ...@@ -41,15 +61,22 @@ THREAD_RETURN YASSL_API server_test(void* args)
SSL* ssl = SSL_new(ctx); SSL* ssl = SSL_new(ctx);
SSL_set_fd(ssl, clientfd); SSL_set_fd(ssl, clientfd);
#ifdef NON_BLOCKING
NonBlockingSSL_Accept(ssl, ctx, clientfd);
#else
if (SSL_accept(ssl) != SSL_SUCCESS) if (SSL_accept(ssl) != SSL_SUCCESS)
ServerError(ctx, ssl, clientfd, "SSL_accept failed"); ServerError(ctx, ssl, clientfd, "SSL_accept failed");
#endif
showPeer(ssl); showPeer(ssl);
printf("Using Cipher Suite: %s\n", SSL_get_cipher(ssl)); printf("Using Cipher Suite: %s\n", SSL_get_cipher(ssl));
char command[1024]; char command[1024];
command[SSL_read(ssl, command, sizeof(command))] = 0; int input = SSL_read(ssl, command, sizeof(command));
if (input > 0) {
command[input] = 0;
printf("First client command: %s\n", command); printf("First client command: %s\n", command);
}
char msg[] = "I hear you, fa shizzle!"; char msg[] = "I hear you, fa shizzle!";
if (SSL_write(ssl, msg, sizeof(msg)) != sizeof(msg)) if (SSL_write(ssl, msg, sizeof(msg)) != sizeof(msg))
...@@ -57,6 +84,7 @@ THREAD_RETURN YASSL_API server_test(void* args) ...@@ -57,6 +84,7 @@ THREAD_RETURN YASSL_API server_test(void* args)
DH_free(dh); DH_free(dh);
SSL_CTX_free(ctx); SSL_CTX_free(ctx);
SSL_shutdown(ssl);
SSL_free(ssl); SSL_free(ssl);
tcp_close(clientfd); tcp_close(clientfd);
...@@ -82,3 +110,4 @@ THREAD_RETURN YASSL_API server_test(void* args) ...@@ -82,3 +110,4 @@ THREAD_RETURN YASSL_API server_test(void* args)
} }
#endif // NO_MAIN_DRIVER #endif // NO_MAIN_DRIVER
...@@ -42,12 +42,7 @@ ...@@ -42,12 +42,7 @@
namespace STL = STL_NAMESPACE; namespace STL = STL_NAMESPACE;
// VC60 workaround: it doesn't allow typename in some places
#if defined(_MSC_VER) && (_MSC_VER < 1300)
#define CPP_TYPENAME
#else
#define CPP_TYPENAME typename
#endif
namespace yaSSL { namespace yaSSL {
......
...@@ -52,6 +52,7 @@ ...@@ -52,6 +52,7 @@
#define SSL_set_session yaSSL_set_session #define SSL_set_session yaSSL_set_session
#define SSL_get_session yaSSL_get_session #define SSL_get_session yaSSL_get_session
#define SSL_SESSION_set_timeout yaSSL_SESSION_set_timeout #define SSL_SESSION_set_timeout yaSSL_SESSION_set_timeout
#define SSL_CTX_set_session_cache_mode yaSSL_CTX_set_session_cache_mode
#define SSL_get_peer_certificate yaSSL_get_peer_certificate #define SSL_get_peer_certificate yaSSL_get_peer_certificate
#define SSL_get_verify_result yaSSL_get_verify_result #define SSL_get_verify_result yaSSL_get_verify_result
#define SSL_CTX_set_verify yaSSL_CTX_set_verify #define SSL_CTX_set_verify yaSSL_CTX_set_verify
...@@ -98,6 +99,8 @@ ...@@ -98,6 +99,8 @@
#define SSLv3_client_method yaSSLv3_client_method #define SSLv3_client_method yaSSLv3_client_method
#define TLSv1_server_method yaTLSv1_server_method #define TLSv1_server_method yaTLSv1_server_method
#define TLSv1_client_method yaTLSv1_client_method #define TLSv1_client_method yaTLSv1_client_method
#define TLSv1_1_server_method yaTLSv1_1_server_method
#define TLSv1_1_client_method yaTLSv1_1_client_method
#define SSLv23_server_method yaSSLv23_server_method #define SSLv23_server_method yaSSLv23_server_method
#define SSL_CTX_use_certificate_file yaSSL_CTX_use_certificate_file #define SSL_CTX_use_certificate_file yaSSL_CTX_use_certificate_file
#define SSL_CTX_use_PrivateKey_file yaSSL_CTX_use_PrivateKey_file #define SSL_CTX_use_PrivateKey_file yaSSL_CTX_use_PrivateKey_file
...@@ -159,3 +162,4 @@ ...@@ -159,3 +162,4 @@
#define MD5_Init yaMD5_Init #define MD5_Init yaMD5_Init
#define MD5_Update yaMD5_Update #define MD5_Update yaMD5_Update
#define MD5_Final yaMD5_Final #define MD5_Final yaMD5_Final
#define SSL_set_compression yaSSL_set_compression
...@@ -41,7 +41,7 @@ ...@@ -41,7 +41,7 @@
#include "rsa.h" #include "rsa.h"
#define YASSL_VERSION "1.4.3" #define YASSL_VERSION "1.5.0"
#if defined(__cplusplus) #if defined(__cplusplus)
...@@ -228,6 +228,7 @@ void SSL_load_error_strings(void); ...@@ -228,6 +228,7 @@ void SSL_load_error_strings(void);
int SSL_set_session(SSL *ssl, SSL_SESSION *session); int SSL_set_session(SSL *ssl, SSL_SESSION *session);
SSL_SESSION* SSL_get_session(SSL* ssl); SSL_SESSION* SSL_get_session(SSL* ssl);
long SSL_SESSION_set_timeout(SSL_SESSION*, long); long SSL_SESSION_set_timeout(SSL_SESSION*, long);
long SSL_CTX_set_session_cache_mode(SSL_CTX* ctx, long mode);
X509* SSL_get_peer_certificate(SSL*); X509* SSL_get_peer_certificate(SSL*);
long SSL_get_verify_result(SSL*); long SSL_get_verify_result(SSL*);
...@@ -361,6 +362,8 @@ SSL_METHOD *SSLv3_server_method(void); ...@@ -361,6 +362,8 @@ SSL_METHOD *SSLv3_server_method(void);
SSL_METHOD *SSLv3_client_method(void); SSL_METHOD *SSLv3_client_method(void);
SSL_METHOD *TLSv1_server_method(void); SSL_METHOD *TLSv1_server_method(void);
SSL_METHOD *TLSv1_client_method(void); SSL_METHOD *TLSv1_client_method(void);
SSL_METHOD *TLSv1_1_server_method(void);
SSL_METHOD *TLSv1_1_client_method(void);
SSL_METHOD *SSLv23_server_method(void); SSL_METHOD *SSLv23_server_method(void);
int SSL_CTX_use_certificate_file(SSL_CTX*, const char*, int); int SSL_CTX_use_certificate_file(SSL_CTX*, const char*, int);
...@@ -531,6 +534,10 @@ void MD5_Final(unsigned char*, MD5_CTX*); ...@@ -531,6 +534,10 @@ void MD5_Final(unsigned char*, MD5_CTX*);
#define SSL_DEFAULT_CIPHER_LIST "" /* default all */ #define SSL_DEFAULT_CIPHER_LIST "" /* default all */
/* yaSSL adds */
int SSL_set_compression(SSL*); /* turn on yaSSL zlib compression */
#if defined(__cplusplus) && !defined(YASSL_MYSQL_COMPATIBLE) #if defined(__cplusplus) && !defined(YASSL_MYSQL_COMPATIBLE)
......
...@@ -70,8 +70,8 @@ typedef unsigned char byte; ...@@ -70,8 +70,8 @@ typedef unsigned char byte;
// Wraps Windows Sockets and BSD Sockets // Wraps Windows Sockets and BSD Sockets
class Socket { class Socket {
socket_t socket_; // underlying socket descriptor socket_t socket_; // underlying socket descriptor
bool wouldBlock_; // for non-blocking data bool wouldBlock_; // if non-blocking data, for last read
bool blocking_; // is option set bool nonBlocking_; // is option set
public: public:
explicit Socket(socket_t s = INVALID_SOCKET); explicit Socket(socket_t s = INVALID_SOCKET);
~Socket(); ~Socket();
...@@ -85,7 +85,7 @@ public: ...@@ -85,7 +85,7 @@ public:
bool wait(); bool wait();
bool WouldBlock() const; bool WouldBlock() const;
bool IsBlocking() const; bool IsNonBlocking() const;
void closeSocket(); void closeSocket();
void shutDown(int how = SD_SEND); void shutDown(int how = SD_SEND);
......
...@@ -56,7 +56,10 @@ enum YasslError { ...@@ -56,7 +56,10 @@ enum YasslError {
receive_error = 114, receive_error = 114,
certificate_error = 115, certificate_error = 115,
privateKey_error = 116, privateKey_error = 116,
badVersion_error = 117 badVersion_error = 117,
compress_error = 118,
decompress_error = 119,
pms_version_error = 120
// !!!! add error message to .cpp !!!! // !!!! add error message to .cpp !!!!
......
...@@ -132,7 +132,6 @@ class Data : public Message { ...@@ -132,7 +132,6 @@ class Data : public Message {
public: public:
Data(); Data();
Data(uint16 len, opaque* b); Data(uint16 len, opaque* b);
Data(uint16 len, const opaque* w);
friend output_buffer& operator<<(output_buffer&, const Data&); friend output_buffer& operator<<(output_buffer&, const Data&);
...@@ -141,9 +140,9 @@ public: ...@@ -141,9 +140,9 @@ public:
ContentType get_type() const; ContentType get_type() const;
uint16 get_length() const; uint16 get_length() const;
const opaque* get_buffer() const;
void set_length(uint16 l); void set_length(uint16 l);
opaque* set_buffer(); opaque* set_buffer();
void SetData(uint16, const opaque*);
void Process(input_buffer&, SSL&); void Process(input_buffer&, SSL&);
private: private:
Data(const Data&); // hide copy Data(const Data&); // hide copy
...@@ -232,11 +231,11 @@ public: ...@@ -232,11 +231,11 @@ public:
void Process(input_buffer&, SSL&); void Process(input_buffer&, SSL&);
const opaque* get_random() const; const opaque* get_random() const;
friend void buildClientHello(SSL&, ClientHello&, CompressionMethod); friend void buildClientHello(SSL&, ClientHello&);
friend void ProcessOldClientHello(input_buffer& input, SSL& ssl); friend void ProcessOldClientHello(input_buffer& input, SSL& ssl);
ClientHello(); ClientHello();
explicit ClientHello(ProtocolVersion pv); ClientHello(ProtocolVersion pv, bool useCompression);
private: private:
ClientHello(const ClientHello&); // hide copy ClientHello(const ClientHello&); // hide copy
ClientHello& operator=(const ClientHello&); // and assign ClientHello& operator=(const ClientHello&); // and assign
...@@ -253,7 +252,7 @@ class ServerHello : public HandShakeBase { ...@@ -253,7 +252,7 @@ class ServerHello : public HandShakeBase {
opaque cipher_suite_[SUITE_LEN]; opaque cipher_suite_[SUITE_LEN];
CompressionMethod compression_method_; CompressionMethod compression_method_;
public: public:
explicit ServerHello(ProtocolVersion pv); ServerHello(ProtocolVersion pv, bool useCompression);
ServerHello(); ServerHello();
friend input_buffer& operator>>(input_buffer&, ServerHello&); friend input_buffer& operator>>(input_buffer&, ServerHello&);
...@@ -629,8 +628,11 @@ struct Connection { ...@@ -629,8 +628,11 @@ struct Connection {
bool send_server_key_; // server key exchange? bool send_server_key_; // server key exchange?
bool master_clean_; // master secret clean? bool master_clean_; // master secret clean?
bool TLS_; // TLSv1 or greater bool TLS_; // TLSv1 or greater
bool TLSv1_1_; // TLSv1.1 or greater
bool sessionID_Set_; // do we have a session bool sessionID_Set_; // do we have a session
ProtocolVersion version_; bool compression_; // zlib compression?
ProtocolVersion version_; // negotiated version
ProtocolVersion chVersion_; // client hello version
RandomPool& random_; RandomPool& random_;
Connection(ProtocolVersion v, RandomPool& ran); Connection(ProtocolVersion v, RandomPool& ran);
...@@ -640,6 +642,7 @@ struct Connection { ...@@ -640,6 +642,7 @@ struct Connection {
void CleanPreMaster(); void CleanPreMaster();
void CleanMaster(); void CleanMaster();
void TurnOffTLS(); void TurnOffTLS();
void TurnOffTLS1_1();
private: private:
Connection(const Connection&); // hide copy Connection(const Connection&); // hide copy
Connection& operator=(const Connection&); // and assign Connection& operator=(const Connection&); // and assign
......
...@@ -431,6 +431,7 @@ private: ...@@ -431,6 +431,7 @@ private:
DH_Parms dhParms_; DH_Parms dhParms_;
pem_password_cb passwordCb_; pem_password_cb passwordCb_;
void* userData_; void* userData_;
bool sessionCacheOff_;
Stats stats_; Stats stats_;
Mutex mutex_; // for Stats Mutex mutex_; // for Stats
public: public:
...@@ -445,6 +446,7 @@ public: ...@@ -445,6 +446,7 @@ public:
const Stats& GetStats() const; const Stats& GetStats() const;
pem_password_cb GetPasswordCb() const; pem_password_cb GetPasswordCb() const;
void* GetUserData() const; void* GetUserData() const;
bool GetSessionCacheOff() const;
void setVerifyPeer(); void setVerifyPeer();
void setVerifyNone(); void setVerifyNone();
...@@ -453,6 +455,7 @@ public: ...@@ -453,6 +455,7 @@ public:
bool SetDH(const DH&); bool SetDH(const DH&);
void SetPasswordCb(pem_password_cb cb); void SetPasswordCb(pem_password_cb cb);
void SetUserData(void*); void SetUserData(void*);
void SetSessionCacheOff();
void IncrementStats(StatsField); void IncrementStats(StatsField);
void AddCA(x509* ca); void AddCA(x509* ca);
...@@ -600,6 +603,7 @@ public: ...@@ -600,6 +603,7 @@ public:
const Socket& getSocket() const; const Socket& getSocket() const;
YasslError GetError() const; YasslError GetError() const;
bool GetMultiProtocol() const; bool GetMultiProtocol() const;
bool CompressionOn() const;
Crypto& useCrypto(); Crypto& useCrypto();
Security& useSecurity(); Security& useSecurity();
...@@ -617,9 +621,12 @@ public: ...@@ -617,9 +621,12 @@ public:
void set_preMaster(const opaque*, uint); void set_preMaster(const opaque*, uint);
void set_masterSecret(const opaque*); void set_masterSecret(const opaque*);
void SetError(YasslError); void SetError(YasslError);
int SetCompression();
void UnSetCompression();
// helpers // helpers
bool isTLS() const; bool isTLS() const;
bool isTLSv1_1() const;
void order_error(); void order_error();
void makeMasterSecret(); void makeMasterSecret();
void makeTLSMasterSecret(); void makeTLSMasterSecret();
...@@ -653,6 +660,10 @@ private: ...@@ -653,6 +660,10 @@ private:
}; };
// compression
int Compress(const byte*, int, input_buffer&);
int DeCompress(input_buffer&, int, input_buffer&);
// conversion functions // conversion functions
void c32to24(uint32, uint24&); void c32to24(uint32, uint24&);
......
...@@ -211,6 +211,7 @@ const int FINISHED_LABEL_SZ = 15; // TLS finished lable length ...@@ -211,6 +211,7 @@ const int FINISHED_LABEL_SZ = 15; // TLS finished lable length
const int SEED_LEN = RAN_LEN * 2; // TLS seed, client + server random const int SEED_LEN = RAN_LEN * 2; // TLS seed, client + server random
const int DEFAULT_TIMEOUT = 500; // Default Session timeout in seconds const int DEFAULT_TIMEOUT = 500; // Default Session timeout in seconds
const int MAX_RECORD_SIZE = 16384; // 2^14, max size by standard const int MAX_RECORD_SIZE = 16384; // 2^14, max size by standard
const int COMPRESS_EXTRA = 1024; // extra compression possible addition
typedef uint8 Cipher; // first byte is always 0x00 for SSLv3 & TLS typedef uint8 Cipher; // first byte is always 0x00 for SSLv3 & TLS
...@@ -222,7 +223,7 @@ typedef opaque* DistinguishedName; ...@@ -222,7 +223,7 @@ typedef opaque* DistinguishedName;
typedef bool IsExportable; typedef bool IsExportable;
enum CompressionMethod { no_compression = 0 }; enum CompressionMethod { no_compression = 0, zlib = 221 };
enum CipherType { stream, block }; enum CipherType { stream, block };
......
...@@ -40,9 +40,11 @@ namespace yaSSL { ...@@ -40,9 +40,11 @@ namespace yaSSL {
// Build a client hello message from cipher suites and compression method // Build a client hello message from cipher suites and compression method
void buildClientHello(SSL& ssl, ClientHello& hello, void buildClientHello(SSL& ssl, ClientHello& hello)
CompressionMethod compression = no_compression)
{ {
// store for pre master secret
ssl.useSecurity().use_connection().chVersion_ = hello.client_version_;
ssl.getCrypto().get_random().Fill(hello.random_, RAN_LEN); ssl.getCrypto().get_random().Fill(hello.random_, RAN_LEN);
if (ssl.getSecurity().get_resuming()) { if (ssl.getSecurity().get_resuming()) {
hello.id_len_ = ID_LEN; hello.id_len_ = ID_LEN;
...@@ -55,7 +57,6 @@ void buildClientHello(SSL& ssl, ClientHello& hello, ...@@ -55,7 +57,6 @@ void buildClientHello(SSL& ssl, ClientHello& hello,
memcpy(hello.cipher_suites_, ssl.getSecurity().get_parms().suites_, memcpy(hello.cipher_suites_, ssl.getSecurity().get_parms().suites_,
hello.suite_len_); hello.suite_len_);
hello.comp_len_ = 1; hello.comp_len_ = 1;
hello.compression_methods_ = compression;
hello.set_length(sizeof(ProtocolVersion) + hello.set_length(sizeof(ProtocolVersion) +
RAN_LEN + RAN_LEN +
...@@ -83,7 +84,7 @@ void buildServerHello(SSL& ssl, ServerHello& hello) ...@@ -83,7 +84,7 @@ void buildServerHello(SSL& ssl, ServerHello& hello)
hello.cipher_suite_[0] = ssl.getSecurity().get_parms().suite_[0]; hello.cipher_suite_[0] = ssl.getSecurity().get_parms().suite_[0];
hello.cipher_suite_[1] = ssl.getSecurity().get_parms().suite_[1]; hello.cipher_suite_[1] = ssl.getSecurity().get_parms().suite_[1];
hello.compression_method_ = no_compression; hello.compression_method_ = hello.compression_method_;
hello.set_length(sizeof(ProtocolVersion) + RAN_LEN + ID_LEN + hello.set_length(sizeof(ProtocolVersion) + RAN_LEN + ID_LEN +
sizeof(hello.id_len_) + SUITE_LEN + SIZEOF_ENUM); sizeof(hello.id_len_) + SUITE_LEN + SIZEOF_ENUM);
...@@ -151,12 +152,18 @@ void buildHeaders(SSL& ssl, HandShakeHeader& hsHeader, ...@@ -151,12 +152,18 @@ void buildHeaders(SSL& ssl, HandShakeHeader& hsHeader,
// add handshake from buffer into md5 and sha hashes, exclude record header // add handshake from buffer into md5 and sha hashes, exclude record header
void hashHandShake(SSL& ssl, const output_buffer& output) void hashHandShake(SSL& ssl, const output_buffer& output, bool removeIV = false)
{ {
uint sz = output.get_size() - RECORD_HEADER; uint sz = output.get_size() - RECORD_HEADER;
const opaque* buffer = output.get_buffer() + RECORD_HEADER; const opaque* buffer = output.get_buffer() + RECORD_HEADER;
if (removeIV) { // TLSv1_1 IV
uint blockSz = ssl.getCrypto().get_cipher().get_blockSize();
sz -= blockSz;
buffer += blockSz;
}
ssl.useHashes().use_MD5().update(buffer, sz); ssl.useHashes().use_MD5().update(buffer, sz);
ssl.useHashes().use_SHA().update(buffer, sz); ssl.useHashes().use_SHA().update(buffer, sz);
} }
...@@ -229,6 +236,18 @@ void decrypt_message(SSL& ssl, input_buffer& input, uint sz) ...@@ -229,6 +236,18 @@ void decrypt_message(SSL& ssl, input_buffer& input, uint sz)
ssl.useCrypto().use_cipher().decrypt(plain.get_buffer(), cipher, sz); ssl.useCrypto().use_cipher().decrypt(plain.get_buffer(), cipher, sz);
memcpy(cipher, plain.get_buffer(), sz); memcpy(cipher, plain.get_buffer(), sz);
ssl.useSecurity().use_parms().encrypt_size_ = sz; ssl.useSecurity().use_parms().encrypt_size_ = sz;
if (ssl.isTLSv1_1()) // IV
input.set_current(input.get_current() +
ssl.getCrypto().get_cipher().get_blockSize());
}
// output operator for input_buffer
output_buffer& operator<<(output_buffer& output, const input_buffer& input)
{
output.write(input.get_buffer(), input.get_size());
return output;
} }
...@@ -239,9 +258,12 @@ void cipherFinished(SSL& ssl, Finished& fin, output_buffer& output) ...@@ -239,9 +258,12 @@ void cipherFinished(SSL& ssl, Finished& fin, output_buffer& output)
uint finishedSz = ssl.isTLS() ? TLS_FINISHED_SZ : FINISHED_SZ; uint finishedSz = ssl.isTLS() ? TLS_FINISHED_SZ : FINISHED_SZ;
uint sz = RECORD_HEADER + HANDSHAKE_HEADER + finishedSz + digestSz; uint sz = RECORD_HEADER + HANDSHAKE_HEADER + finishedSz + digestSz;
uint pad = 0; uint pad = 0;
uint blockSz = ssl.getCrypto().get_cipher().get_blockSize();
if (ssl.getSecurity().get_parms().cipher_type_ == block) { if (ssl.getSecurity().get_parms().cipher_type_ == block) {
if (ssl.isTLSv1_1())
sz += blockSz; // IV
sz += 1; // pad byte sz += 1; // pad byte
uint blockSz = ssl.getCrypto().get_cipher().get_blockSize();
pad = (sz - RECORD_HEADER) % blockSz; pad = (sz - RECORD_HEADER) % blockSz;
pad = blockSz - pad; pad = blockSz - pad;
sz += pad; sz += pad;
...@@ -252,14 +274,21 @@ void cipherFinished(SSL& ssl, Finished& fin, output_buffer& output) ...@@ -252,14 +274,21 @@ void cipherFinished(SSL& ssl, Finished& fin, output_buffer& output)
buildHeaders(ssl, hsHeader, rlHeader, fin); buildHeaders(ssl, hsHeader, rlHeader, fin);
rlHeader.length_ = sz - RECORD_HEADER; // record header includes mac rlHeader.length_ = sz - RECORD_HEADER; // record header includes mac
// and pad, hanshake doesn't // and pad, hanshake doesn't
input_buffer iv;
if (ssl.isTLSv1_1() && ssl.getSecurity().get_parms().cipher_type_== block){
iv.allocate(blockSz);
ssl.getCrypto().get_random().Fill(iv.get_buffer(), blockSz);
iv.add_size(blockSz);
}
uint ivSz = iv.get_size();
output.allocate(sz); output.allocate(sz);
output << rlHeader << hsHeader << fin; output << rlHeader << iv << hsHeader << fin;
hashHandShake(ssl, output); hashHandShake(ssl, output, ssl.isTLSv1_1() ? true : false);
opaque digest[SHA_LEN]; // max size opaque digest[SHA_LEN]; // max size
if (ssl.isTLS()) if (ssl.isTLS())
TLS_hmac(ssl, digest, output.get_buffer() + RECORD_HEADER, TLS_hmac(ssl, digest, output.get_buffer() + RECORD_HEADER + ivSz,
output.get_size() - RECORD_HEADER, handshake); output.get_size() - RECORD_HEADER - ivSz, handshake);
else else
hmac(ssl, digest, output.get_buffer() + RECORD_HEADER, hmac(ssl, digest, output.get_buffer() + RECORD_HEADER,
output.get_size() - RECORD_HEADER, handshake); output.get_size() - RECORD_HEADER, handshake);
...@@ -282,9 +311,12 @@ void buildMessage(SSL& ssl, output_buffer& output, const Message& msg) ...@@ -282,9 +311,12 @@ void buildMessage(SSL& ssl, output_buffer& output, const Message& msg)
uint digestSz = ssl.getCrypto().get_digest().get_digestSize(); uint digestSz = ssl.getCrypto().get_digest().get_digestSize();
uint sz = RECORD_HEADER + msg.get_length() + digestSz; uint sz = RECORD_HEADER + msg.get_length() + digestSz;
uint pad = 0; uint pad = 0;
uint blockSz = ssl.getCrypto().get_cipher().get_blockSize();
if (ssl.getSecurity().get_parms().cipher_type_ == block) { if (ssl.getSecurity().get_parms().cipher_type_ == block) {
if (ssl.isTLSv1_1()) // IV
sz += blockSz;
sz += 1; // pad byte sz += 1; // pad byte
uint blockSz = ssl.getCrypto().get_cipher().get_blockSize();
pad = (sz - RECORD_HEADER) % blockSz; pad = (sz - RECORD_HEADER) % blockSz;
pad = blockSz - pad; pad = blockSz - pad;
sz += pad; sz += pad;
...@@ -294,13 +326,21 @@ void buildMessage(SSL& ssl, output_buffer& output, const Message& msg) ...@@ -294,13 +326,21 @@ void buildMessage(SSL& ssl, output_buffer& output, const Message& msg)
buildHeader(ssl, rlHeader, msg); buildHeader(ssl, rlHeader, msg);
rlHeader.length_ = sz - RECORD_HEADER; // record header includes mac rlHeader.length_ = sz - RECORD_HEADER; // record header includes mac
// and pad, hanshake doesn't // and pad, hanshake doesn't
input_buffer iv;
if (ssl.isTLSv1_1() && ssl.getSecurity().get_parms().cipher_type_== block){
iv.allocate(blockSz);
ssl.getCrypto().get_random().Fill(iv.get_buffer(), blockSz);
iv.add_size(blockSz);
}
uint ivSz = iv.get_size();
output.allocate(sz); output.allocate(sz);
output << rlHeader << msg; output << rlHeader << iv << msg;
opaque digest[SHA_LEN]; // max size opaque digest[SHA_LEN]; // max size
if (ssl.isTLS()) if (ssl.isTLS())
TLS_hmac(ssl, digest, output.get_buffer() + RECORD_HEADER, TLS_hmac(ssl, digest, output.get_buffer() + RECORD_HEADER + ivSz,
output.get_size() - RECORD_HEADER, msg.get_type()); output.get_size() - RECORD_HEADER - ivSz, msg.get_type());
else else
hmac(ssl, digest, output.get_buffer() + RECORD_HEADER, hmac(ssl, digest, output.get_buffer() + RECORD_HEADER,
output.get_size() - RECORD_HEADER, msg.get_type()); output.get_size() - RECORD_HEADER, msg.get_type());
...@@ -456,6 +496,10 @@ void buildSHA_CertVerify(SSL& ssl, byte* digest) ...@@ -456,6 +496,10 @@ void buildSHA_CertVerify(SSL& ssl, byte* digest)
// some clients still send sslv2 client hello // some clients still send sslv2 client hello
void ProcessOldClientHello(input_buffer& input, SSL& ssl) void ProcessOldClientHello(input_buffer& input, SSL& ssl)
{ {
if (input.get_remaining() < 2) {
ssl.SetError(bad_input);
return;
}
byte b0 = input[AUTO]; byte b0 = input[AUTO];
byte b1 = input[AUTO]; byte b1 = input[AUTO];
...@@ -721,6 +765,7 @@ int DoProcessReply(SSL& ssl) ...@@ -721,6 +765,7 @@ int DoProcessReply(SSL& ssl)
// each message in record, can be more than 1 if not encrypted // each message in record, can be more than 1 if not encrypted
if (ssl.getSecurity().get_parms().pending_ == false) // cipher on if (ssl.getSecurity().get_parms().pending_ == false) // cipher on
decrypt_message(ssl, buffer, hdr.length_); decrypt_message(ssl, buffer, hdr.length_);
mySTL::auto_ptr<Message> msg(mf.CreateObject(hdr.type_)); mySTL::auto_ptr<Message> msg(mf.CreateObject(hdr.type_));
if (!msg.get()) { if (!msg.get()) {
ssl.SetError(factory_error); ssl.SetError(factory_error);
...@@ -744,13 +789,13 @@ void processReply(SSL& ssl) ...@@ -744,13 +789,13 @@ void processReply(SSL& ssl)
if (DoProcessReply(ssl)) if (DoProcessReply(ssl))
// didn't complete process // didn't complete process
if (!ssl.getSocket().IsBlocking()) { if (!ssl.getSocket().IsNonBlocking()) {
// keep trying now // keep trying now, blocking ok
while (!ssl.GetError()) while (!ssl.GetError())
if (DoProcessReply(ssl) == 0) break; if (DoProcessReply(ssl) == 0) break;
} }
else else
// user will have try again later // user will have try again later, non blocking
ssl.SetError(YasslError(SSL_ERROR_WANT_READ)); ssl.SetError(YasslError(SSL_ERROR_WANT_READ));
} }
...@@ -761,7 +806,8 @@ void sendClientHello(SSL& ssl) ...@@ -761,7 +806,8 @@ void sendClientHello(SSL& ssl)
ssl.verifyState(serverNull); ssl.verifyState(serverNull);
if (ssl.GetError()) return; if (ssl.GetError()) return;
ClientHello ch(ssl.getSecurity().get_connection().version_); ClientHello ch(ssl.getSecurity().get_connection().version_,
ssl.getSecurity().get_connection().compression_);
RecordLayerHeader rlHeader; RecordLayerHeader rlHeader;
HandShakeHeader hsHeader; HandShakeHeader hsHeader;
output_buffer out; output_buffer out;
...@@ -859,6 +905,7 @@ void sendFinished(SSL& ssl, ConnectionEnd side, BufferOutput buffer) ...@@ -859,6 +905,7 @@ void sendFinished(SSL& ssl, ConnectionEnd side, BufferOutput buffer)
buildFinished(ssl, ssl.useHashes().use_verify(), client); // client buildFinished(ssl, ssl.useHashes().use_verify(), client); // client
} }
else { else {
if (!ssl.getSecurity().GetContext()->GetSessionCacheOff())
GetSessions().add(ssl); // store session GetSessions().add(ssl); // store session
if (side == client_end) if (side == client_end)
buildFinished(ssl, ssl.useHashes().use_verify(), server); // server buildFinished(ssl, ssl.useHashes().use_verify(), server); // server
...@@ -885,7 +932,20 @@ int sendData(SSL& ssl, const void* buffer, int sz) ...@@ -885,7 +932,20 @@ int sendData(SSL& ssl, const void* buffer, int sz)
for (;;) { for (;;) {
int len = min(sz - sent, MAX_RECORD_SIZE); int len = min(sz - sent, MAX_RECORD_SIZE);
output_buffer out; output_buffer out;
const Data data(len, static_cast<const opaque*>(buffer) + sent); input_buffer tmp;
Data data;
if (ssl.CompressionOn()) {
if (Compress(static_cast<const opaque*>(buffer) + sent, len,
tmp) == -1) {
ssl.SetError(compress_error);
return -1;
}
data.SetData(tmp.get_size(), tmp.get_buffer());
}
else
data.SetData(len, static_cast<const opaque*>(buffer) + sent);
buildMessage(ssl, out, data); buildMessage(ssl, out, data);
ssl.Send(out.get_buffer(), out.get_size()); ssl.Send(out.get_buffer(), out.get_size());
...@@ -947,7 +1007,8 @@ void sendServerHello(SSL& ssl, BufferOutput buffer) ...@@ -947,7 +1007,8 @@ void sendServerHello(SSL& ssl, BufferOutput buffer)
ssl.verifyState(clientHelloComplete); ssl.verifyState(clientHelloComplete);
if (ssl.GetError()) return; if (ssl.GetError()) return;
ServerHello sh(ssl.getSecurity().get_connection().version_); ServerHello sh(ssl.getSecurity().get_connection().version_,
ssl.getSecurity().get_connection().compression_);
RecordLayerHeader rlHeader; RecordLayerHeader rlHeader;
HandShakeHeader hsHeader; HandShakeHeader hsHeader;
mySTL::auto_ptr<output_buffer> out(NEW_YS output_buffer); mySTL::auto_ptr<output_buffer> out(NEW_YS output_buffer);
......
REM quick and dirty build file for testing different MSDEVs REM quick and dirty build file for testing different MSDEVs
setlocal setlocal
set myFLAGS= /I../include /I../mySTL /I../taocrypt/include /W3 /c /ZI set myFLAGS= /I../include /I../taocrypt/mySTL /I../taocrypt/include /W3 /c /ZI
cl %myFLAGS% buffer.cpp cl %myFLAGS% buffer.cpp
cl %myFLAGS% cert_wrapper.cpp cl %myFLAGS% cert_wrapper.cpp
......
...@@ -63,7 +63,7 @@ namespace yaSSL { ...@@ -63,7 +63,7 @@ namespace yaSSL {
Socket::Socket(socket_t s) Socket::Socket(socket_t s)
: socket_(s), wouldBlock_(false), blocking_(false) : socket_(s), wouldBlock_(false), nonBlocking_(false)
{} {}
...@@ -148,8 +148,8 @@ uint Socket::receive(byte* buf, unsigned int sz, int flags) ...@@ -148,8 +148,8 @@ uint Socket::receive(byte* buf, unsigned int sz, int flags)
if (recvd == -1) { if (recvd == -1) {
if (get_lastError() == SOCKET_EWOULDBLOCK || if (get_lastError() == SOCKET_EWOULDBLOCK ||
get_lastError() == SOCKET_EAGAIN) { get_lastError() == SOCKET_EAGAIN) {
wouldBlock_ = true; wouldBlock_ = true; // would have blocked this time only
blocking_ = true; // socket can block, only way to tell for win32 nonBlocking_ = true; // socket nonblocking, win32 only way to tell
return 0; return 0;
} }
} }
...@@ -191,9 +191,9 @@ bool Socket::WouldBlock() const ...@@ -191,9 +191,9 @@ bool Socket::WouldBlock() const
} }
bool Socket::IsBlocking() const bool Socket::IsNonBlocking() const
{ {
return blocking_; return nonBlocking_;
} }
......
...@@ -184,10 +184,22 @@ SSL_METHOD* TLSv1_client_method() ...@@ -184,10 +184,22 @@ SSL_METHOD* TLSv1_client_method()
} }
SSL_METHOD* TLSv1_1_server_method()
{
return NEW_YS SSL_METHOD(server_end, ProtocolVersion(3,2));
}
SSL_METHOD* TLSv1_1_client_method()
{
return NEW_YS SSL_METHOD(client_end, ProtocolVersion(3,2));
}
SSL_METHOD* SSLv23_server_method() SSL_METHOD* SSLv23_server_method()
{ {
// compatibility only, no version 2 support, but does SSL 3 and TLS 1 // compatibility only, no version 2 support, but does SSL 3 and TLS 1
return NEW_YS SSL_METHOD(server_end, ProtocolVersion(3,1), true); return NEW_YS SSL_METHOD(server_end, ProtocolVersion(3,2), true);
} }
...@@ -196,7 +208,7 @@ SSL_METHOD* SSLv23_client_method() ...@@ -196,7 +208,7 @@ SSL_METHOD* SSLv23_client_method()
// compatibility only, no version 2 support, but does SSL 3 and TLS 1 // compatibility only, no version 2 support, but does SSL 3 and TLS 1
// though it sends TLS1 hello not SSLv2 so SSLv3 only servers will decline // though it sends TLS1 hello not SSLv2 so SSLv3 only servers will decline
// TODO: maybe add support to send SSLv2 hello ??? // TODO: maybe add support to send SSLv2 hello ???
return NEW_YS SSL_METHOD(client_end, ProtocolVersion(3,1), true); return NEW_YS SSL_METHOD(client_end, ProtocolVersion(3,2), true);
} }
...@@ -407,7 +419,6 @@ int SSL_shutdown(SSL* ssl) ...@@ -407,7 +419,6 @@ int SSL_shutdown(SSL* ssl)
Alert alert(warning, close_notify); Alert alert(warning, close_notify);
sendAlert(*ssl, alert); sendAlert(*ssl, alert);
ssl->useLog().ShowTCP(ssl->getSocket().get_fd(), true); ssl->useLog().ShowTCP(ssl->getSocket().get_fd(), true);
ssl->useSocket().closeSocket();
GetErrors().Remove(); GetErrors().Remove();
...@@ -415,8 +426,21 @@ int SSL_shutdown(SSL* ssl) ...@@ -415,8 +426,21 @@ int SSL_shutdown(SSL* ssl)
} }
/* on by default but allow user to turn off */
long SSL_CTX_set_session_cache_mode(SSL_CTX* ctx, long mode)
{
if (mode == SSL_SESS_CACHE_OFF)
ctx->SetSessionCacheOff();
return SSL_SUCCESS;
}
SSL_SESSION* SSL_get_session(SSL* ssl) SSL_SESSION* SSL_get_session(SSL* ssl)
{ {
if (ssl->getSecurity().GetContext()->GetSessionCacheOff())
return 0;
return GetSessions().lookup( return GetSessions().lookup(
ssl->getSecurity().get_connection().sessionID_); ssl->getSecurity().get_connection().sessionID_);
} }
...@@ -424,6 +448,9 @@ SSL_SESSION* SSL_get_session(SSL* ssl) ...@@ -424,6 +448,9 @@ SSL_SESSION* SSL_get_session(SSL* ssl)
int SSL_set_session(SSL* ssl, SSL_SESSION* session) int SSL_set_session(SSL* ssl, SSL_SESSION* session)
{ {
if (ssl->getSecurity().GetContext()->GetSessionCacheOff())
return SSL_FAILURE;
ssl->set_session(session); ssl->set_session(session);
return SSL_SUCCESS; return SSL_SUCCESS;
} }
...@@ -512,6 +539,19 @@ int SSL_get_error(SSL* ssl, int /*previous*/) ...@@ -512,6 +539,19 @@ int SSL_get_error(SSL* ssl, int /*previous*/)
} }
/* turn on yaSSL zlib compression
returns 0 for success, else error (not built in)
only need to turn on for client, becuase server on by default if built in
but calling for server will tell you whether it's available or not
*/
int SSL_set_compression(SSL* ssl)
{
return ssl->SetCompression();
}
X509* SSL_get_peer_certificate(SSL* ssl) X509* SSL_get_peer_certificate(SSL* ssl)
{ {
return ssl->getCrypto().get_certManager().get_peerX509(); return ssl->getCrypto().get_certManager().get_peerX509();
...@@ -1359,6 +1399,56 @@ int SSL_pending(SSL* ssl) ...@@ -1359,6 +1399,56 @@ int SSL_pending(SSL* ssl)
} }
void SSL_CTX_set_default_passwd_cb(SSL_CTX* ctx, pem_password_cb cb)
{
ctx->SetPasswordCb(cb);
}
int SSLeay_add_ssl_algorithms() // compatibility only
{
return 1;
}
void ERR_remove_state(unsigned long)
{
GetErrors().Remove();
}
int ERR_GET_REASON(int l)
{
return l & 0xfff;
}
unsigned long err_helper(bool peek = false)
{
int ysError = GetErrors().Lookup(peek);
// translate cert error for libcurl, it uses OpenSSL hex code
switch (ysError) {
case TaoCrypt::SIG_OTHER_E:
return CERTFICATE_ERROR;
break;
default :
return 0;
}
}
unsigned long ERR_peek_error()
{
return err_helper(true);
}
unsigned long ERR_get_error()
{
return err_helper();
}
// functions for stunnel // functions for stunnel
...@@ -1477,13 +1567,6 @@ int SSL_pending(SSL* ssl) ...@@ -1477,13 +1567,6 @@ int SSL_pending(SSL* ssl)
} }
long SSL_CTX_set_session_cache_mode(SSL_CTX*, long)
{
// TDOD:
return SSL_SUCCESS;
}
long SSL_CTX_set_timeout(SSL_CTX*, long) long SSL_CTX_set_timeout(SSL_CTX*, long)
{ {
// TDOD: // TDOD:
...@@ -1498,12 +1581,6 @@ int SSL_pending(SSL* ssl) ...@@ -1498,12 +1581,6 @@ int SSL_pending(SSL* ssl)
} }
void SSL_CTX_set_default_passwd_cb(SSL_CTX* ctx, pem_password_cb cb)
{
ctx->SetPasswordCb(cb);
}
int SSL_CTX_use_RSAPrivateKey_file(SSL_CTX*, const char*, int) int SSL_CTX_use_RSAPrivateKey_file(SSL_CTX*, const char*, int)
{ {
// TDOD: // TDOD:
...@@ -1555,49 +1632,6 @@ int SSL_pending(SSL* ssl) ...@@ -1555,49 +1632,6 @@ int SSL_pending(SSL* ssl)
} }
int SSLeay_add_ssl_algorithms() // compatibility only
{
return 1;
}
void ERR_remove_state(unsigned long)
{
GetErrors().Remove();
}
int ERR_GET_REASON(int l)
{
return l & 0xfff;
}
unsigned long err_helper(bool peek = false)
{
int ysError = GetErrors().Lookup(peek);
// translate cert error for libcurl, it uses OpenSSL hex code
switch (ysError) {
case TaoCrypt::SIG_OTHER_E:
return CERTFICATE_ERROR;
break;
default :
return 0;
}
}
unsigned long ERR_peek_error()
{
return err_helper(true);
}
unsigned long ERR_get_error()
{
return err_helper();
}
// end stunnel needs // end stunnel needs
......
...@@ -133,6 +133,18 @@ void SetErrorString(YasslError error, char* buffer) ...@@ -133,6 +133,18 @@ void SetErrorString(YasslError error, char* buffer)
strncpy(buffer, "protocl version mismatch", max); strncpy(buffer, "protocl version mismatch", max);
break; break;
case compress_error :
strncpy(buffer, "compression error", max);
break;
case decompress_error :
strncpy(buffer, "decompression error", max);
break;
case pms_version_error :
strncpy(buffer, "bad PreMasterSecret version error", max);
break;
// openssl errors // openssl errors
case SSL_ERROR_WANT_READ : case SSL_ERROR_WANT_READ :
strncpy(buffer, "the read operation would block", max); strncpy(buffer, "the read operation would block", max);
......
...@@ -87,7 +87,7 @@ void EncryptedPreMasterSecret::build(SSL& ssl) ...@@ -87,7 +87,7 @@ void EncryptedPreMasterSecret::build(SSL& ssl)
opaque tmp[SECRET_LEN]; opaque tmp[SECRET_LEN];
memset(tmp, 0, sizeof(tmp)); memset(tmp, 0, sizeof(tmp));
ssl.getCrypto().get_random().Fill(tmp, SECRET_LEN); ssl.getCrypto().get_random().Fill(tmp, SECRET_LEN);
ProtocolVersion pv = ssl.getSecurity().get_connection().version_; ProtocolVersion pv = ssl.getSecurity().get_connection().chVersion_;
tmp[0] = pv.major_; tmp[0] = pv.major_;
tmp[1] = pv.minor_; tmp[1] = pv.minor_;
ssl.set_preMaster(tmp, SECRET_LEN); ssl.set_preMaster(tmp, SECRET_LEN);
...@@ -233,6 +233,10 @@ void EncryptedPreMasterSecret::read(SSL& ssl, input_buffer& input) ...@@ -233,6 +233,10 @@ void EncryptedPreMasterSecret::read(SSL& ssl, input_buffer& input)
rsa.decrypt(preMasterSecret, secret_, length_, rsa.decrypt(preMasterSecret, secret_, length_,
ssl.getCrypto().get_random()); ssl.getCrypto().get_random());
ProtocolVersion pv = ssl.getSecurity().get_connection().chVersion_;
if (pv.major_ != preMasterSecret[0] || pv.minor_ != preMasterSecret[1])
ssl.SetError(pms_version_error); // continue deriving for timing attack
ssl.set_preMaster(preMasterSecret, SECRET_LEN); ssl.set_preMaster(preMasterSecret, SECRET_LEN);
ssl.makeMasterSecret(); ssl.makeMasterSecret();
} }
...@@ -437,6 +441,7 @@ Parameters::Parameters(ConnectionEnd ce, const Ciphers& ciphers, ...@@ -437,6 +441,7 @@ Parameters::Parameters(ConnectionEnd ce, const Ciphers& ciphers,
ProtocolVersion pv, bool haveDH) : entity_(ce) ProtocolVersion pv, bool haveDH) : entity_(ce)
{ {
pending_ = true; // suite not set yet pending_ = true; // suite not set yet
strncpy(cipher_name_, "NONE", 5);
if (ciphers.setSuites_) { // use user set list if (ciphers.setSuites_) { // use user set list
suites_size_ = ciphers.suiteSz_; suites_size_ = ciphers.suiteSz_;
...@@ -445,6 +450,7 @@ Parameters::Parameters(ConnectionEnd ce, const Ciphers& ciphers, ...@@ -445,6 +450,7 @@ Parameters::Parameters(ConnectionEnd ce, const Ciphers& ciphers,
} }
else else
SetSuites(pv, ce == server_end && !haveDH); // defaults SetSuites(pv, ce == server_end && !haveDH); // defaults
} }
...@@ -613,14 +619,18 @@ output_buffer& operator<<(output_buffer& output, const HandShakeHeader& hdr) ...@@ -613,14 +619,18 @@ output_buffer& operator<<(output_buffer& output, const HandShakeHeader& hdr)
void HandShakeHeader::Process(input_buffer& input, SSL& ssl) void HandShakeHeader::Process(input_buffer& input, SSL& ssl)
{ {
ssl.verifyState(*this); ssl.verifyState(*this);
if (ssl.GetError()) return;
const HandShakeFactory& hsf = ssl.getFactory().getHandShake(); const HandShakeFactory& hsf = ssl.getFactory().getHandShake();
mySTL::auto_ptr<HandShakeBase> hs(hsf.CreateObject(type_)); mySTL::auto_ptr<HandShakeBase> hs(hsf.CreateObject(type_));
if (!hs.get()) { if (!hs.get()) {
ssl.SetError(factory_error); ssl.SetError(factory_error);
return; return;
} }
hashHandShake(ssl, input, c24to32(length_));
uint len = c24to32(length_);
hashHandShake(ssl, input, len);
hs->set_length(len);
input >> *hs; input >> *hs;
hs->Process(input, ssl); hs->Process(input, ssl);
} }
...@@ -849,11 +859,17 @@ void Alert::Process(input_buffer& input, SSL& ssl) ...@@ -849,11 +859,17 @@ void Alert::Process(input_buffer& input, SSL& ssl)
opaque mac[SHA_LEN]; opaque mac[SHA_LEN];
input.read(mac, digestSz); input.read(mac, digestSz);
if (ssl.getSecurity().get_parms().cipher_type_ == block) {
int ivExtra = 0;
opaque fill; opaque fill;
int padSz = ssl.getSecurity().get_parms().encrypt_size_ - aSz -
digestSz; if (ssl.isTLSv1_1())
ivExtra = ssl.getCrypto().get_cipher().get_blockSize();
int padSz = ssl.getSecurity().get_parms().encrypt_size_ - ivExtra -
aSz - digestSz;
for (int i = 0; i < padSz; i++) for (int i = 0; i < padSz; i++)
fill = input[AUTO]; fill = input[AUTO];
}
// verify // verify
if (memcmp(mac, verify, digestSz)) { if (memcmp(mac, verify, digestSz)) {
...@@ -879,9 +895,13 @@ Data::Data(uint16 len, opaque* b) ...@@ -879,9 +895,13 @@ Data::Data(uint16 len, opaque* b)
{} {}
Data::Data(uint16 len, const opaque* w) void Data::SetData(uint16 len, const opaque* buffer)
: length_(len), buffer_(0), write_buffer_(w) {
{} assert(write_buffer_ == 0);
length_ = len;
write_buffer_ = buffer;
}
input_buffer& Data::set(input_buffer& in) input_buffer& Data::set(input_buffer& in)
{ {
...@@ -907,17 +927,12 @@ uint16 Data::get_length() const ...@@ -907,17 +927,12 @@ uint16 Data::get_length() const
} }
const opaque* Data::get_buffer() const
{
return write_buffer_;
}
void Data::set_length(uint16 l) void Data::set_length(uint16 l)
{ {
length_ = l; length_ = l;
} }
opaque* Data::set_buffer() opaque* Data::set_buffer()
{ {
return buffer_; return buffer_;
...@@ -937,27 +952,42 @@ void Data::Process(input_buffer& input, SSL& ssl) ...@@ -937,27 +952,42 @@ void Data::Process(input_buffer& input, SSL& ssl)
{ {
int msgSz = ssl.getSecurity().get_parms().encrypt_size_; int msgSz = ssl.getSecurity().get_parms().encrypt_size_;
int pad = 0, padByte = 0; int pad = 0, padByte = 0;
int ivExtra = 0;
if (ssl.getSecurity().get_parms().cipher_type_ == block) { if (ssl.getSecurity().get_parms().cipher_type_ == block) {
pad = *(input.get_buffer() + input.get_current() + msgSz - 1); if (ssl.isTLSv1_1()) // IV
ivExtra = ssl.getCrypto().get_cipher().get_blockSize();
pad = *(input.get_buffer() + input.get_current() + msgSz -ivExtra - 1);
padByte = 1; padByte = 1;
} }
int digestSz = ssl.getCrypto().get_digest().get_digestSize(); int digestSz = ssl.getCrypto().get_digest().get_digestSize();
int dataSz = msgSz - digestSz - pad - padByte; int dataSz = msgSz - ivExtra - digestSz - pad - padByte;
opaque verify[SHA_LEN]; opaque verify[SHA_LEN];
const byte* rawData = input.get_buffer() + input.get_current();
// read data // read data
if (dataSz) { if (dataSz) { // could be compressed
if (ssl.CompressionOn()) {
input_buffer tmp;
if (DeCompress(input, dataSz, tmp) == -1) {
ssl.SetError(decompress_error);
return;
}
ssl.addData(NEW_YS input_buffer(tmp.get_size(),
tmp.get_buffer(), tmp.get_size()));
}
else {
input_buffer* data; input_buffer* data;
ssl.addData(data = NEW_YS input_buffer(dataSz)); ssl.addData(data = NEW_YS input_buffer(dataSz));
input.read(data->get_buffer(), dataSz); input.read(data->get_buffer(), dataSz);
data->add_size(dataSz); data->add_size(dataSz);
}
if (ssl.isTLS()) if (ssl.isTLS())
TLS_hmac(ssl, verify, data->get_buffer(), dataSz, application_data, TLS_hmac(ssl, verify, rawData, dataSz, application_data, true);
true);
else else
hmac(ssl, verify, data->get_buffer(), dataSz, application_data, hmac(ssl, verify, rawData, dataSz, application_data, true);
true);
} }
// read mac and fill // read mac and fill
...@@ -1220,6 +1250,13 @@ void ServerHello::Process(input_buffer&, SSL& ssl) ...@@ -1220,6 +1250,13 @@ void ServerHello::Process(input_buffer&, SSL& ssl)
if (ssl.isTLS() && server_version_.minor_ < 1) if (ssl.isTLS() && server_version_.minor_ < 1)
// downgrade to SSLv3 // downgrade to SSLv3
ssl.useSecurity().use_connection().TurnOffTLS(); ssl.useSecurity().use_connection().TurnOffTLS();
else if (ssl.isTLSv1_1() && server_version_.minor_ == 1)
// downdrage to TLSv1
ssl.useSecurity().use_connection().TurnOffTLS1_1();
}
else if (ssl.isTLSv1_1() && server_version_.minor_ < 2) {
ssl.SetError(badVersion_error);
return;
} }
else if (ssl.isTLS() && server_version_.minor_ < 1) { else if (ssl.isTLS() && server_version_.minor_ < 1) {
ssl.SetError(badVersion_error); ssl.SetError(badVersion_error);
...@@ -1252,6 +1289,10 @@ void ServerHello::Process(input_buffer&, SSL& ssl) ...@@ -1252,6 +1289,10 @@ void ServerHello::Process(input_buffer&, SSL& ssl)
ssl.useSecurity().set_resuming(false); ssl.useSecurity().set_resuming(false);
ssl.useLog().Trace("server denied resumption"); ssl.useLog().Trace("server denied resumption");
} }
if (ssl.CompressionOn() && !compression_method_)
ssl.UnSetCompression(); // server isn't supporting yaSSL zlib request
ssl.useStates().useClient() = serverHelloComplete; ssl.useStates().useClient() = serverHelloComplete;
} }
...@@ -1263,8 +1304,9 @@ ServerHello::ServerHello() ...@@ -1263,8 +1304,9 @@ ServerHello::ServerHello()
} }
ServerHello::ServerHello(ProtocolVersion pv) ServerHello::ServerHello(ProtocolVersion pv, bool useCompression)
: server_version_(pv) : server_version_(pv),
compression_method_(useCompression ? zlib : no_compression)
{ {
memset(random_, 0, RAN_LEN); memset(random_, 0, RAN_LEN);
memset(session_id_, 0, ID_LEN); memset(session_id_, 0, ID_LEN);
...@@ -1341,6 +1383,8 @@ opaque* ClientKeyBase::get_clientKey() const ...@@ -1341,6 +1383,8 @@ opaque* ClientKeyBase::get_clientKey() const
// input operator for Client Hello // input operator for Client Hello
input_buffer& operator>>(input_buffer& input, ClientHello& hello) input_buffer& operator>>(input_buffer& input, ClientHello& hello)
{ {
uint begin = input.get_current(); // could have extensions at end
// Protocol // Protocol
hello.client_version_.major_ = input[AUTO]; hello.client_version_.major_ = input[AUTO];
hello.client_version_.minor_ = input[AUTO]; hello.client_version_.minor_ = input[AUTO];
...@@ -1361,8 +1405,19 @@ input_buffer& operator>>(input_buffer& input, ClientHello& hello) ...@@ -1361,8 +1405,19 @@ input_buffer& operator>>(input_buffer& input, ClientHello& hello)
// Compression // Compression
hello.comp_len_ = input[AUTO]; hello.comp_len_ = input[AUTO];
while (hello.comp_len_--) // ignore for now hello.compression_methods_ = no_compression;
hello.compression_methods_ = CompressionMethod(input[AUTO]); while (hello.comp_len_--) {
CompressionMethod cm = CompressionMethod(input[AUTO]);
if (cm == zlib)
hello.compression_methods_ = zlib;
}
uint read = input.get_current() - begin;
uint expected = hello.get_length();
// ignore client hello extensions for now
if (read < expected)
input.set_current(input.get_current() + expected - read);
return input; return input;
} }
...@@ -1400,6 +1455,13 @@ output_buffer& operator<<(output_buffer& output, const ClientHello& hello) ...@@ -1400,6 +1455,13 @@ output_buffer& operator<<(output_buffer& output, const ClientHello& hello)
// Client Hello processing handler // Client Hello processing handler
void ClientHello::Process(input_buffer&, SSL& ssl) void ClientHello::Process(input_buffer&, SSL& ssl)
{ {
// store version for pre master secret
ssl.useSecurity().use_connection().chVersion_ = client_version_;
if (client_version_.major_ != 3) {
ssl.SetError(badVersion_error);
return;
}
if (ssl.GetMultiProtocol()) { // SSLv23 support if (ssl.GetMultiProtocol()) { // SSLv23 support
if (ssl.isTLS() && client_version_.minor_ < 1) { if (ssl.isTLS() && client_version_.minor_ < 1) {
// downgrade to SSLv3 // downgrade to SSLv3
...@@ -1407,20 +1469,29 @@ void ClientHello::Process(input_buffer&, SSL& ssl) ...@@ -1407,20 +1469,29 @@ void ClientHello::Process(input_buffer&, SSL& ssl)
ProtocolVersion pv = ssl.getSecurity().get_connection().version_; ProtocolVersion pv = ssl.getSecurity().get_connection().version_;
ssl.useSecurity().use_parms().SetSuites(pv); // reset w/ SSL suites ssl.useSecurity().use_parms().SetSuites(pv); // reset w/ SSL suites
} }
else if (ssl.isTLSv1_1() && client_version_.minor_ == 1)
// downgrade to TLSv1, but use same suites
ssl.useSecurity().use_connection().TurnOffTLS1_1();
}
else if (ssl.isTLSv1_1() && client_version_.minor_ < 2) {
ssl.SetError(badVersion_error);
return;
} }
else if (ssl.isTLS() && client_version_.minor_ < 1) { else if (ssl.isTLS() && client_version_.minor_ < 1) {
ssl.SetError(badVersion_error); ssl.SetError(badVersion_error);
return; return;
} }
else if (!ssl.isTLS() && (client_version_.major_ == 3 && else if (!ssl.isTLS() && client_version_.minor_ >= 1) {
client_version_.minor_ >= 1)) {
ssl.SetError(badVersion_error); ssl.SetError(badVersion_error);
return; return;
} }
ssl.set_random(random_, client_end); ssl.set_random(random_, client_end);
while (id_len_) { // trying to resume while (id_len_) { // trying to resume
SSL_SESSION* session = GetSessions().lookup(session_id_); SSL_SESSION* session = 0;
if (!ssl.getSecurity().GetContext()->GetSessionCacheOff())
session = GetSessions().lookup(session_id_);
if (!session) { if (!session) {
ssl.useLog().Trace("session lookup failed"); ssl.useLog().Trace("session lookup failed");
break; break;
...@@ -1444,6 +1515,9 @@ void ClientHello::Process(input_buffer&, SSL& ssl) ...@@ -1444,6 +1515,9 @@ void ClientHello::Process(input_buffer&, SSL& ssl)
ssl.matchSuite(cipher_suites_, suite_len_); ssl.matchSuite(cipher_suites_, suite_len_);
ssl.set_pending(ssl.getSecurity().get_parms().suite_[1]); ssl.set_pending(ssl.getSecurity().get_parms().suite_[1]);
if (compression_methods_ == zlib)
ssl.SetCompression();
ssl.useStates().useServer() = clientHelloComplete; ssl.useStates().useServer() = clientHelloComplete;
} }
...@@ -1478,8 +1552,9 @@ ClientHello::ClientHello() ...@@ -1478,8 +1552,9 @@ ClientHello::ClientHello()
} }
ClientHello::ClientHello(ProtocolVersion pv) ClientHello::ClientHello(ProtocolVersion pv, bool useCompression)
: client_version_(pv) : client_version_(pv),
compression_methods_(useCompression ? zlib : no_compression)
{ {
memset(random_, 0, RAN_LEN); memset(random_, 0, RAN_LEN);
} }
...@@ -1943,8 +2018,13 @@ void Finished::Process(input_buffer& input, SSL& ssl) ...@@ -1943,8 +2018,13 @@ void Finished::Process(input_buffer& input, SSL& ssl)
int digestSz = ssl.getCrypto().get_digest().get_digestSize(); int digestSz = ssl.getCrypto().get_digest().get_digestSize();
input.read(mac, digestSz); input.read(mac, digestSz);
uint ivExtra = 0;
if (ssl.getSecurity().get_parms().cipher_type_ == block)
if (ssl.isTLSv1_1())
ivExtra = ssl.getCrypto().get_cipher().get_blockSize();
opaque fill; opaque fill;
int padSz = ssl.getSecurity().get_parms().encrypt_size_ - int padSz = ssl.getSecurity().get_parms().encrypt_size_ - ivExtra -
HANDSHAKE_HEADER - finishedSz - digestSz; HANDSHAKE_HEADER - finishedSz - digestSz;
for (int i = 0; i < padSz; i++) for (int i = 0; i < padSz; i++)
fill = input[AUTO]; fill = input[AUTO];
...@@ -2018,7 +2098,9 @@ void clean(volatile opaque* p, uint sz, RandomPool& ran) ...@@ -2018,7 +2098,9 @@ void clean(volatile opaque* p, uint sz, RandomPool& ran)
Connection::Connection(ProtocolVersion v, RandomPool& ran) Connection::Connection(ProtocolVersion v, RandomPool& ran)
: pre_master_secret_(0), sequence_number_(0), peer_sequence_number_(0), : pre_master_secret_(0), sequence_number_(0), peer_sequence_number_(0),
pre_secret_len_(0), send_server_key_(false), master_clean_(false), pre_secret_len_(0), send_server_key_(false), master_clean_(false),
TLS_(v.major_ >= 3 && v.minor_ >= 1), version_(v), random_(ran) TLS_(v.major_ >= 3 && v.minor_ >= 1),
TLSv1_1_(v.major_ >= 3 && v.minor_ >= 2), compression_(false),
version_(v), random_(ran)
{ {
memset(sessionID_, 0, sizeof(sessionID_)); memset(sessionID_, 0, sizeof(sessionID_));
} }
...@@ -2043,6 +2125,13 @@ void Connection::TurnOffTLS() ...@@ -2043,6 +2125,13 @@ void Connection::TurnOffTLS()
} }
void Connection::TurnOffTLS1_1()
{
TLSv1_1_ = false;
version_.minor_ = 1;
}
// wipeout master secret // wipeout master secret
void Connection::CleanMaster() void Connection::CleanMaster()
{ {
......
...@@ -38,6 +38,11 @@ ...@@ -38,6 +38,11 @@
#endif #endif
#ifdef HAVE_LIBZ
#include "zlib.h"
#endif
#ifdef YASSL_PURE_C #ifdef YASSL_PURE_C
void* operator new(size_t sz, yaSSL::new_t) void* operator new(size_t sz, yaSSL::new_t)
...@@ -727,6 +732,32 @@ void SSL::set_preMaster(const opaque* pre, uint sz) ...@@ -727,6 +732,32 @@ void SSL::set_preMaster(const opaque* pre, uint sz)
} }
// set yaSSL zlib type compression
int SSL::SetCompression()
{
#ifdef HAVE_LIBZ
secure_.use_connection().compression_ = true;
return 0;
#else
return -1; // not built in
#endif
}
// unset yaSSL zlib type compression
void SSL::UnSetCompression()
{
secure_.use_connection().compression_ = false;
}
// is yaSSL zlib compression on
bool SSL::CompressionOn() const
{
return secure_.get_connection().compression_;
}
// store master secret // store master secret
void SSL::set_masterSecret(const opaque* sec) void SSL::set_masterSecret(const opaque* sec)
{ {
...@@ -1109,6 +1140,11 @@ void SSL::verifyState(const RecordLayerHeader& rlHeader) ...@@ -1109,6 +1140,11 @@ void SSL::verifyState(const RecordLayerHeader& rlHeader)
{ {
if (GetError()) return; if (GetError()) return;
if (rlHeader.version_.major_ != 3 || rlHeader.version_.minor_ > 2) {
SetError(badVersion_error);
return;
}
if (states_.getRecord() == recordNotReady || if (states_.getRecord() == recordNotReady ||
(rlHeader.type_ == application_data && // data and handshake (rlHeader.type_ == application_data && // data and handshake
states_.getHandShake() != handShakeReady) ) // isn't complete yet states_.getHandShake() != handShakeReady) ) // isn't complete yet
...@@ -1247,6 +1283,9 @@ void SSL::matchSuite(const opaque* peer, uint length) ...@@ -1247,6 +1283,9 @@ void SSL::matchSuite(const opaque* peer, uint length)
void SSL::set_session(SSL_SESSION* s) void SSL::set_session(SSL_SESSION* s)
{ {
if (getSecurity().GetContext()->GetSessionCacheOff())
return;
if (s && GetSessions().lookup(s->GetID(), &secure_.use_resume())) { if (s && GetSessions().lookup(s->GetID(), &secure_.use_resume())) {
secure_.set_resuming(true); secure_.set_resuming(true);
crypto_.use_certManager().setPeerX509(s->GetPeerX509()); crypto_.use_certManager().setPeerX509(s->GetPeerX509());
...@@ -1344,6 +1383,12 @@ bool SSL::isTLS() const ...@@ -1344,6 +1383,12 @@ bool SSL::isTLS() const
} }
bool SSL::isTLSv1_1() const
{
return secure_.get_connection().TLSv1_1_;
}
void SSL::addData(input_buffer* data) void SSL::addData(input_buffer* data)
{ {
buffers_.useData().push_back(data); buffers_.useData().push_back(data);
...@@ -1703,7 +1748,7 @@ bool SSL_METHOD::multipleProtocol() const ...@@ -1703,7 +1748,7 @@ bool SSL_METHOD::multipleProtocol() const
SSL_CTX::SSL_CTX(SSL_METHOD* meth) SSL_CTX::SSL_CTX(SSL_METHOD* meth)
: method_(meth), certificate_(0), privateKey_(0), passwordCb_(0), : method_(meth), certificate_(0), privateKey_(0), passwordCb_(0),
userData_(0) userData_(0), sessionCacheOff_(false)
{} {}
...@@ -1784,12 +1829,24 @@ void* SSL_CTX::GetUserData() const ...@@ -1784,12 +1829,24 @@ void* SSL_CTX::GetUserData() const
} }
bool SSL_CTX::GetSessionCacheOff() const
{
return sessionCacheOff_;
}
void SSL_CTX::SetUserData(void* data) void SSL_CTX::SetUserData(void* data)
{ {
userData_ = data; userData_ = data;
} }
void SSL_CTX::SetSessionCacheOff()
{
sessionCacheOff_ = true;
}
void SSL_CTX::setVerifyPeer() void SSL_CTX::setVerifyPeer()
{ {
method_->setVerifyPeer(); method_->setVerifyPeer();
...@@ -2312,9 +2369,110 @@ ASN1_STRING* StringHolder::GetString() ...@@ -2312,9 +2369,110 @@ ASN1_STRING* StringHolder::GetString()
} }
#ifdef HAVE_LIBZ
void* myAlloc(void* /* opaque */, unsigned int item, unsigned int size)
{
return NEW_YS unsigned char[item * size];
}
void myFree(void* /* opaque */, void* memory)
{
unsigned char* ptr = static_cast<unsigned char*>(memory);
yaSSL::ysArrayDelete(ptr);
}
// put size in front of compressed data
int Compress(const byte* in, int sz, input_buffer& buffer)
{
byte tmp[LENGTH_SZ];
z_stream c_stream; /* compression stream */
buffer.allocate(sz + sizeof(uint16) + COMPRESS_EXTRA);
c_stream.zalloc = myAlloc;
c_stream.zfree = myFree;
c_stream.opaque = (voidpf)0;
c_stream.next_in = const_cast<byte*>(in);
c_stream.avail_in = sz;
c_stream.next_out = buffer.get_buffer() + sizeof(tmp);
c_stream.avail_out = buffer.get_capacity() - sizeof(tmp);
if (deflateInit(&c_stream, 8) != Z_OK) return -1;
int err = deflate(&c_stream, Z_FINISH);
deflateEnd(&c_stream);
if (err != Z_OK && err != Z_STREAM_END) return -1;
c16toa(sz, tmp);
memcpy(buffer.get_buffer(), tmp, sizeof(tmp));
buffer.add_size(c_stream.total_out + sizeof(tmp));
return 0;
}
// get uncompressed size in front
int DeCompress(input_buffer& in, int sz, input_buffer& out)
{
byte tmp[LENGTH_SZ];
in.read(tmp, sizeof(tmp));
uint16 len;
ato16(tmp, len);
out.allocate(len);
z_stream d_stream; /* decompression stream */
d_stream.zalloc = myAlloc;
d_stream.zfree = myFree;
d_stream.opaque = (voidpf)0;
d_stream.next_in = in.get_buffer() + in.get_current();
d_stream.avail_in = sz - sizeof(tmp);
d_stream.next_out = out.get_buffer();
d_stream.avail_out = out.get_capacity();
if (inflateInit(&d_stream) != Z_OK) return -1;
int err = inflate(&d_stream, Z_FINISH);
inflateEnd(&d_stream);
if (err != Z_OK && err != Z_STREAM_END) return -1;
out.add_size(d_stream.total_out);
in.set_current(in.get_current() + sz - sizeof(tmp));
return 0;
}
#else // LIBZ
// these versions should never get called
int Compress(const byte* in, int sz, input_buffer& buffer)
{
assert(0);
return -1;
}
int DeCompress(input_buffer& in, int sz, input_buffer& out)
{
assert(0);
return -1;
}
#endif // LIBZ
} // namespace } // namespace
extern "C" void yaSSL_CleanUp() extern "C" void yaSSL_CleanUp()
{ {
TaoCrypt::CleanUp(); TaoCrypt::CleanUp();
......
REM quick and dirty build file for testing different MSDEVs REM quick and dirty build file for testing different MSDEVs
setlocal setlocal
set myFLAGS= /I../include /I../../mySTL /c /W3 /G6 /O2 set myFLAGS= /I../include /I../mySTL /c /W3 /G6 /O2
cl %myFLAGS% benchmark.cpp cl %myFLAGS% benchmark.cpp
......
...@@ -34,6 +34,12 @@ ...@@ -34,6 +34,12 @@
#include "modes.hpp" #include "modes.hpp"
#if defined(TAOCRYPT_X86ASM_AVAILABLE) && defined(TAO_ASM)
#define DO_AES_ASM
#endif
namespace TaoCrypt { namespace TaoCrypt {
...@@ -46,15 +52,14 @@ public: ...@@ -46,15 +52,14 @@ public:
enum { BLOCK_SIZE = AES_BLOCK_SIZE }; enum { BLOCK_SIZE = AES_BLOCK_SIZE };
AES(CipherDir DIR, Mode MODE) AES(CipherDir DIR, Mode MODE)
: Mode_BASE(BLOCK_SIZE), dir_(DIR), mode_(MODE) {} : Mode_BASE(BLOCK_SIZE, DIR, MODE) {}
#ifdef DO_AES_ASM
void Process(byte*, const byte*, word32); void Process(byte*, const byte*, word32);
#endif
void SetKey(const byte* key, word32 sz, CipherDir fake = ENCRYPTION); void SetKey(const byte* key, word32 sz, CipherDir fake = ENCRYPTION);
void SetIV(const byte* iv) { memcpy(r_, iv, BLOCK_SIZE); } void SetIV(const byte* iv) { memcpy(r_, iv, BLOCK_SIZE); }
private: private:
CipherDir dir_;
Mode mode_;
static const word32 rcon_[]; static const word32 rcon_[];
word32 rounds_; word32 rounds_;
......
...@@ -75,7 +75,8 @@ public: ...@@ -75,7 +75,8 @@ public:
typedef Integer Element; typedef Integer Element;
AbstractRing() : AbstractGroup() {m_mg.m_pRing = this;} AbstractRing() : AbstractGroup() {m_mg.m_pRing = this;}
AbstractRing(const AbstractRing &source) {m_mg.m_pRing = this;} AbstractRing(const AbstractRing &source) : AbstractGroup()
{m_mg.m_pRing = this;}
AbstractRing& operator=(const AbstractRing &source) {return *this;} AbstractRing& operator=(const AbstractRing &source) {return *this;}
virtual bool IsUnit(const Element &a) const =0; virtual bool IsUnit(const Element &a) const =0;
......
...@@ -46,7 +46,6 @@ public: ...@@ -46,7 +46,6 @@ public:
ARC4() {} ARC4() {}
void Process(byte*, const byte*, word32); void Process(byte*, const byte*, word32);
void AsmProcess(byte*, const byte*, word32);
void SetKey(const byte*, word32); void SetKey(const byte*, word32);
private: private:
byte x_; byte x_;
...@@ -55,6 +54,8 @@ private: ...@@ -55,6 +54,8 @@ private:
ARC4(const ARC4&); // hide copy ARC4(const ARC4&); // hide copy
const ARC4 operator=(const ARC4&); // and assign const ARC4 operator=(const ARC4&); // and assign
void AsmProcess(byte*, const byte*, word32);
}; };
} // namespace } // namespace
......
...@@ -34,7 +34,11 @@ ...@@ -34,7 +34,11 @@
#include "misc.hpp" #include "misc.hpp"
#include "block.hpp" #include "block.hpp"
#include "error.hpp" #include "error.hpp"
#include STL_LIST_FILE #ifdef USE_SYS_STL
#include <list>
#else
#include "list.hpp"
#endif
namespace STL = STL_NAMESPACE; namespace STL = STL_NAMESPACE;
......
...@@ -34,7 +34,12 @@ ...@@ -34,7 +34,12 @@
#include "misc.hpp" #include "misc.hpp"
#include <string.h> // memcpy #include <string.h> // memcpy
#include <stddef.h> // ptrdiff_t #include <stddef.h> // ptrdiff_t
#include STL_ALGORITHM_FILE
#ifdef USE_SYS_STL
#include <algorithm>
#else
#include "algorithm.hpp"
#endif
namespace STL = STL_NAMESPACE; namespace STL = STL_NAMESPACE;
......
...@@ -32,12 +32,21 @@ ...@@ -32,12 +32,21 @@
#include "misc.hpp" #include "misc.hpp"
#include "modes.hpp" #include "modes.hpp"
#include STL_ALGORITHM_FILE #ifdef USE_SYS_STL
#include <algorithm>
#else
#include "algorithm.hpp"
#endif
namespace STL = STL_NAMESPACE; namespace STL = STL_NAMESPACE;
#if defined(TAOCRYPT_X86ASM_AVAILABLE) && defined(TAO_ASM)
#define DO_BLOWFISH_ASM
#endif
namespace TaoCrypt { namespace TaoCrypt {
enum { BLOWFISH_BLOCK_SIZE = 8 }; enum { BLOWFISH_BLOCK_SIZE = 8 };
...@@ -49,15 +58,14 @@ public: ...@@ -49,15 +58,14 @@ public:
enum { BLOCK_SIZE = BLOWFISH_BLOCK_SIZE, ROUNDS = 16 }; enum { BLOCK_SIZE = BLOWFISH_BLOCK_SIZE, ROUNDS = 16 };
Blowfish(CipherDir DIR, Mode MODE) Blowfish(CipherDir DIR, Mode MODE)
: Mode_BASE(BLOCK_SIZE), dir_(DIR), mode_(MODE) {} : Mode_BASE(BLOCK_SIZE, DIR, MODE) {}
#ifdef DO_BLOWFISH_ASM
void Process(byte*, const byte*, word32); void Process(byte*, const byte*, word32);
#endif
void SetKey(const byte* key, word32 sz, CipherDir fake = ENCRYPTION); void SetKey(const byte* key, word32 sz, CipherDir fake = ENCRYPTION);
void SetIV(const byte* iv) { memcpy(r_, iv, BLOCK_SIZE); } void SetIV(const byte* iv) { memcpy(r_, iv, BLOCK_SIZE); }
private: private:
CipherDir dir_;
Mode mode_;
static const word32 p_init_[ROUNDS + 2]; static const word32 p_init_[ROUNDS + 2];
static const word32 s_init_[4 * 256]; static const word32 s_init_[4 * 256];
......
...@@ -34,6 +34,12 @@ ...@@ -34,6 +34,12 @@
#include "misc.hpp" #include "misc.hpp"
#include "modes.hpp" #include "modes.hpp"
#if defined(TAOCRYPT_X86ASM_AVAILABLE) && defined(TAO_ASM)
#define DO_DES_ASM
#endif
namespace TaoCrypt { namespace TaoCrypt {
...@@ -53,13 +59,9 @@ protected: ...@@ -53,13 +59,9 @@ protected:
class DES : public Mode_BASE, public BasicDES { class DES : public Mode_BASE, public BasicDES {
public: public:
DES(CipherDir DIR, Mode MODE) DES(CipherDir DIR, Mode MODE)
: Mode_BASE(DES_BLOCK_SIZE), dir_(DIR), mode_(MODE) {} : Mode_BASE(DES_BLOCK_SIZE, DIR, MODE) {}
void Process(byte*, const byte*, word32);
private: private:
CipherDir dir_;
Mode mode_;
void ProcessAndXorBlock(const byte*, const byte*, byte*) const; void ProcessAndXorBlock(const byte*, const byte*, byte*) const;
DES(const DES&); // hide copy DES(const DES&); // hide copy
...@@ -71,14 +73,10 @@ private: ...@@ -71,14 +73,10 @@ private:
class DES_EDE2 : public Mode_BASE { class DES_EDE2 : public Mode_BASE {
public: public:
DES_EDE2(CipherDir DIR, Mode MODE) DES_EDE2(CipherDir DIR, Mode MODE)
: Mode_BASE(DES_BLOCK_SIZE), dir_(DIR), mode_(MODE) {} : Mode_BASE(DES_BLOCK_SIZE, DIR, MODE) {}
void SetKey(const byte*, word32, CipherDir dir); void SetKey(const byte*, word32, CipherDir dir);
void Process(byte*, const byte*, word32);
private: private:
CipherDir dir_;
Mode mode_;
BasicDES des1_; BasicDES des1_;
BasicDES des2_; BasicDES des2_;
...@@ -94,15 +92,14 @@ private: ...@@ -94,15 +92,14 @@ private:
class DES_EDE3 : public Mode_BASE { class DES_EDE3 : public Mode_BASE {
public: public:
DES_EDE3(CipherDir DIR, Mode MODE) DES_EDE3(CipherDir DIR, Mode MODE)
: Mode_BASE(DES_BLOCK_SIZE), dir_(DIR), mode_(MODE) {} : Mode_BASE(DES_BLOCK_SIZE, DIR, MODE) {}
void SetKey(const byte*, word32, CipherDir dir); void SetKey(const byte*, word32, CipherDir dir);
void SetIV(const byte* iv) { memcpy(r_, iv, DES_BLOCK_SIZE); } void SetIV(const byte* iv) { memcpy(r_, iv, DES_BLOCK_SIZE); }
#ifdef DO_DES_ASM
void Process(byte*, const byte*, word32); void Process(byte*, const byte*, word32);
#endif
private: private:
CipherDir dir_;
Mode mode_;
BasicDES des1_; BasicDES des1_;
BasicDES des2_; BasicDES des2_;
BasicDES des3_; BasicDES des3_;
......
...@@ -45,7 +45,11 @@ ...@@ -45,7 +45,11 @@
#include "random.hpp" #include "random.hpp"
#include "file.hpp" #include "file.hpp"
#include <string.h> #include <string.h>
#include STL_ALGORITHM_FILE #ifdef USE_SYS_STL
#include <algorithm>
#else
#include "algorithm.hpp"
#endif
#ifdef TAOCRYPT_X86ASM_AVAILABLE #ifdef TAOCRYPT_X86ASM_AVAILABLE
...@@ -67,7 +71,8 @@ ...@@ -67,7 +71,8 @@
#endif #endif
// SSE2 intrinsics work in GCC 3.3 or later // SSE2 intrinsics work in GCC 3.3 or later
#if defined(__SSE2__) && (__GNUC_MAJOR__ > 3 || __GNUC_MINOR__ > 2) #if defined(__SSE2__) && (__GNUC__ == 4 || __GNUC_MAJOR__ > 3 || \
__GNUC_MINOR__ > 2)
#define SSE2_INTRINSICS_AVAILABLE #define SSE2_INTRINSICS_AVAILABLE
#endif #endif
...@@ -106,7 +111,6 @@ namespace TaoCrypt { ...@@ -106,7 +111,6 @@ namespace TaoCrypt {
#endif #endif
}; };
template class TAOCRYPT_DLL AlignedAllocator<word>;
typedef Block<word, AlignedAllocator<word> > AlignedWordBlock; typedef Block<word, AlignedAllocator<word> > AlignedWordBlock;
#else #else
typedef WordBlock AlignedWordBlock; typedef WordBlock AlignedWordBlock;
......
...@@ -31,6 +31,11 @@ ...@@ -31,6 +31,11 @@
#include "hash.hpp" #include "hash.hpp"
#if defined(TAOCRYPT_X86ASM_AVAILABLE) && defined(TAO_ASM)
#define DO_MD5_ASM
#endif
namespace TaoCrypt { namespace TaoCrypt {
...@@ -49,7 +54,9 @@ public: ...@@ -49,7 +54,9 @@ public:
MD5(const MD5&); MD5(const MD5&);
MD5& operator= (const MD5&); MD5& operator= (const MD5&);
#ifdef DO_MD5_ASM
void Update(const byte*, word32); void Update(const byte*, word32);
#endif
void Init(); void Init();
void Swap(MD5&); void Swap(MD5&);
......
...@@ -151,6 +151,17 @@ void CleanUp(); ...@@ -151,6 +151,17 @@ void CleanUp();
#endif #endif
#ifdef TAOCRYPT_X86ASM_AVAILABLE
bool HaveCpuId();
bool IsPentium();
void CpuId(word32 input, word32 *output);
extern bool isMMX;
#endif
// Turn on ia32 ASM for Ciphers and Message Digests // Turn on ia32 ASM for Ciphers and Message Digests
// Seperate define since these are more complex, use member offsets // Seperate define since these are more complex, use member offsets
// and user may want to turn off while leaving Big Integer optos on // and user may want to turn off while leaving Big Integer optos on
...@@ -200,17 +211,9 @@ void CleanUp(); ...@@ -200,17 +211,9 @@ void CleanUp();
#ifdef USE_SYS_STL #ifdef USE_SYS_STL
// use system STL // use system STL
#define STL_VECTOR_FILE <vector>
#define STL_LIST_FILE <list>
#define STL_ALGORITHM_FILE <algorithm>
#define STL_MEMORY_FILE <memory>
#define STL_NAMESPACE std #define STL_NAMESPACE std
#else #else
// use mySTL // use mySTL
#define STL_VECTOR_FILE "vector.hpp"
#define STL_LIST_FILE "list.hpp"
#define STL_ALGORITHM_FILE "algorithm.hpp"
#define STL_MEMORY_FILE "memory.hpp"
#define STL_NAMESPACE mySTL #define STL_NAMESPACE mySTL
#endif #endif
......
...@@ -38,6 +38,7 @@ namespace TaoCrypt { ...@@ -38,6 +38,7 @@ namespace TaoCrypt {
enum Mode { ECB, CBC }; enum Mode { ECB, CBC };
// BlockCipher abstraction // BlockCipher abstraction
template<CipherDir DIR, class T, Mode MODE> template<CipherDir DIR, class T, Mode MODE>
class BlockCipher { class BlockCipher {
...@@ -63,14 +64,16 @@ class Mode_BASE : public virtual_base { ...@@ -63,14 +64,16 @@ class Mode_BASE : public virtual_base {
public: public:
enum { MaxBlockSz = 16 }; enum { MaxBlockSz = 16 };
explicit Mode_BASE(int sz) explicit Mode_BASE(int sz, CipherDir dir, Mode mode)
: blockSz_(sz), reg_(reinterpret_cast<byte*>(r_)), : blockSz_(sz), reg_(reinterpret_cast<byte*>(r_)),
tmp_(reinterpret_cast<byte*>(t_)) tmp_(reinterpret_cast<byte*>(t_)), dir_(dir), mode_(mode)
{ {
assert(sz <= MaxBlockSz); assert(sz <= MaxBlockSz);
} }
virtual ~Mode_BASE() {} virtual ~Mode_BASE() {}
virtual void Process(byte*, const byte*, word32);
void SetIV(const byte* iv) { memcpy(reg_, iv, blockSz_); } void SetIV(const byte* iv) { memcpy(reg_, iv, blockSz_); }
protected: protected:
int blockSz_; int blockSz_;
...@@ -80,6 +83,9 @@ protected: ...@@ -80,6 +83,9 @@ protected:
word32 r_[MaxBlockSz / sizeof(word32)]; // align reg_ on word32 word32 r_[MaxBlockSz / sizeof(word32)]; // align reg_ on word32
word32 t_[MaxBlockSz / sizeof(word32)]; // align tmp_ on word32 word32 t_[MaxBlockSz / sizeof(word32)]; // align tmp_ on word32
CipherDir dir_;
Mode mode_;
void ECB_Process(byte*, const byte*, word32); void ECB_Process(byte*, const byte*, word32);
void CBC_Encrypt(byte*, const byte*, word32); void CBC_Encrypt(byte*, const byte*, word32);
void CBC_Decrypt(byte*, const byte*, word32); void CBC_Decrypt(byte*, const byte*, word32);
...@@ -92,6 +98,18 @@ private: ...@@ -92,6 +98,18 @@ private:
}; };
inline void Mode_BASE::Process(byte* out, const byte* in, word32 sz)
{
if (mode_ == ECB)
ECB_Process(out, in, sz);
else if (mode_ == CBC)
if (dir_ == ENCRYPTION)
CBC_Encrypt(out, in, sz);
else
CBC_Decrypt(out, in, sz);
}
// ECB Process blocks // ECB Process blocks
inline void Mode_BASE::ECB_Process(byte* out, const byte* in, word32 sz) inline void Mode_BASE::ECB_Process(byte* out, const byte* in, word32 sz)
{ {
......
...@@ -31,6 +31,11 @@ ...@@ -31,6 +31,11 @@
#include "hash.hpp" #include "hash.hpp"
#if defined(TAOCRYPT_X86ASM_AVAILABLE) && defined(TAO_ASM)
#define DO_RIPEMD_ASM
#endif
namespace TaoCrypt { namespace TaoCrypt {
...@@ -49,7 +54,9 @@ public: ...@@ -49,7 +54,9 @@ public:
RIPEMD160(const RIPEMD160&); RIPEMD160(const RIPEMD160&);
RIPEMD160& operator= (const RIPEMD160&); RIPEMD160& operator= (const RIPEMD160&);
#ifdef DO_RIPEMD_ASM
void Update(const byte*, word32); void Update(const byte*, word32);
#endif
void Init(); void Init();
void Swap(RIPEMD160&); void Swap(RIPEMD160&);
private: private:
......
...@@ -239,7 +239,8 @@ bool RSA_Encryptor<Pad>::SSL_Verify(const byte* message, word32 sz, ...@@ -239,7 +239,8 @@ bool RSA_Encryptor<Pad>::SSL_Verify(const byte* message, word32 sz,
const byte* sig) const byte* sig)
{ {
ByteBlock plain(PK_Lengths(key_.GetModulus()).FixedMaxPlaintextLength()); ByteBlock plain(PK_Lengths(key_.GetModulus()).FixedMaxPlaintextLength());
SSL_Decrypt(key_, sig, plain.get_buffer()); if (SSL_Decrypt(key_, sig, plain.get_buffer()) != sz)
return false; // not right justified or bad padding
if ( (memcmp(plain.get_buffer(), message, sz)) == 0) if ( (memcmp(plain.get_buffer(), message, sz)) == 0)
return true; return true;
......
...@@ -31,6 +31,11 @@ ...@@ -31,6 +31,11 @@
#include "hash.hpp" #include "hash.hpp"
#if defined(TAOCRYPT_X86ASM_AVAILABLE) && defined(TAO_ASM)
#define DO_SHA_ASM
#endif
namespace TaoCrypt { namespace TaoCrypt {
...@@ -46,7 +51,9 @@ public: ...@@ -46,7 +51,9 @@ public:
word32 getDigestSize() const { return DIGEST_SIZE; } word32 getDigestSize() const { return DIGEST_SIZE; }
word32 getPadSize() const { return PAD_SIZE; } word32 getPadSize() const { return PAD_SIZE; }
#ifdef DO_SHA_ASM
void Update(const byte* data, word32 len); void Update(const byte* data, word32 len);
#endif
void Init(); void Init();
SHA(const SHA&); SHA(const SHA&);
......
...@@ -32,12 +32,20 @@ ...@@ -32,12 +32,20 @@
#include "misc.hpp" #include "misc.hpp"
#include "modes.hpp" #include "modes.hpp"
#include STL_ALGORITHM_FILE #ifdef USE_SYS_STL
#include <algorithm>
#else
#include "algorithm.hpp"
#endif
namespace STL = STL_NAMESPACE; namespace STL = STL_NAMESPACE;
#if defined(TAOCRYPT_X86ASM_AVAILABLE) && defined(TAO_ASM)
#define DO_TWOFISH_ASM
#endif
namespace TaoCrypt { namespace TaoCrypt {
enum { TWOFISH_BLOCK_SIZE = 16 }; enum { TWOFISH_BLOCK_SIZE = 16 };
...@@ -49,15 +57,14 @@ public: ...@@ -49,15 +57,14 @@ public:
enum { BLOCK_SIZE = TWOFISH_BLOCK_SIZE }; enum { BLOCK_SIZE = TWOFISH_BLOCK_SIZE };
Twofish(CipherDir DIR, Mode MODE) Twofish(CipherDir DIR, Mode MODE)
: Mode_BASE(BLOCK_SIZE), dir_(DIR), mode_(MODE) {} : Mode_BASE(BLOCK_SIZE, DIR, MODE) {}
#ifdef DO_TWOFISH_ASM
void Process(byte*, const byte*, word32); void Process(byte*, const byte*, word32);
#endif
void SetKey(const byte* key, word32 sz, CipherDir fake = ENCRYPTION); void SetKey(const byte* key, word32 sz, CipherDir fake = ENCRYPTION);
void SetIV(const byte* iv) { memcpy(r_, iv, BLOCK_SIZE); } void SetIV(const byte* iv) { memcpy(r_, iv, BLOCK_SIZE); }
private: private:
CipherDir dir_;
Mode mode_;
static const byte q_[2][256]; static const byte q_[2][256];
static const word32 mds_[4][256]; static const word32 mds_[4][256];
......
...@@ -34,33 +34,19 @@ ...@@ -34,33 +34,19 @@
#include "aes.hpp" #include "aes.hpp"
#if defined(TAOCRYPT_X86ASM_AVAILABLE) && defined(TAO_ASM)
#define DO_AES_ASM
#endif
namespace TaoCrypt { namespace TaoCrypt {
#if !defined(DO_AES_ASM) #if defined(DO_AES_ASM)
// Generic Version
void AES::Process(byte* out, const byte* in, word32 sz)
{
if (mode_ == ECB)
ECB_Process(out, in, sz);
else if (mode_ == CBC)
if (dir_ == ENCRYPTION)
CBC_Encrypt(out, in, sz);
else
CBC_Decrypt(out, in, sz);
}
#else
// ia32 optimized version // ia32 optimized version
void AES::Process(byte* out, const byte* in, word32 sz) void AES::Process(byte* out, const byte* in, word32 sz)
{ {
if (!isMMX) {
Mode_BASE::Process(out, in, sz);
return;
}
word32 blocks = sz / BLOCK_SIZE; word32 blocks = sz / BLOCK_SIZE;
if (mode_ == ECB) if (mode_ == ECB)
......
...@@ -29,7 +29,11 @@ ...@@ -29,7 +29,11 @@
#include "runtime.hpp" #include "runtime.hpp"
#include "algebra.hpp" #include "algebra.hpp"
#include STL_VECTOR_FILE #ifdef USE_SYS_STL
#include <vector>
#else
#include "vector.hpp"
#endif
namespace STL = STL_NAMESPACE; namespace STL = STL_NAMESPACE;
......
...@@ -80,12 +80,18 @@ inline unsigned int MakeByte(word32& x, word32& y, byte* s) ...@@ -80,12 +80,18 @@ inline unsigned int MakeByte(word32& x, word32& y, byte* s)
} // namespace } // namespace
#ifndef DO_ARC4_ASM
void ARC4::Process(byte* out, const byte* in, word32 length) void ARC4::Process(byte* out, const byte* in, word32 length)
{ {
if (length == 0) return; if (length == 0) return;
#ifdef DO_ARC4_ASM
if (isMMX) {
AsmProcess(out, in, length);
return;
}
#endif
byte *const s = state_; byte *const s = state_;
word32 x = x_; word32 x = x_;
word32 y = y_; word32 y = y_;
...@@ -100,13 +106,16 @@ void ARC4::Process(byte* out, const byte* in, word32 length) ...@@ -100,13 +106,16 @@ void ARC4::Process(byte* out, const byte* in, word32 length)
y_ = y; y_ = y;
} }
#else // DO_ARC4_ASM
#ifdef DO_ARC4_ASM
#ifdef _MSC_VER #ifdef _MSC_VER
__declspec(naked) __declspec(naked)
#else
__attribute__ ((noinline))
#endif #endif
void ARC4::Process(byte* out, const byte* in, word32 length) void ARC4::AsmProcess(byte* out, const byte* in, word32 length)
{ {
#ifdef __GNUC__ #ifdef __GNUC__
#define AS1(x) asm(#x); #define AS1(x) asm(#x);
......
...@@ -37,34 +37,21 @@ ...@@ -37,34 +37,21 @@
#if defined(TAOCRYPT_X86ASM_AVAILABLE) && defined(TAO_ASM)
#define DO_BLOWFISH_ASM
#endif
namespace TaoCrypt { namespace TaoCrypt {
#if !defined(DO_BLOWFISH_ASM) #if defined(DO_BLOWFISH_ASM)
// Generic Version
void Blowfish::Process(byte* out, const byte* in, word32 sz)
{
if (mode_ == ECB)
ECB_Process(out, in, sz);
else if (mode_ == CBC)
if (dir_ == ENCRYPTION)
CBC_Encrypt(out, in, sz);
else
CBC_Decrypt(out, in, sz);
}
#else
// ia32 optimized version // ia32 optimized version
void Blowfish::Process(byte* out, const byte* in, word32 sz) void Blowfish::Process(byte* out, const byte* in, word32 sz)
{ {
if (!isMMX) {
Mode_BASE::Process(out, in, sz);
return;
}
word32 blocks = sz / BLOCK_SIZE; word32 blocks = sz / BLOCK_SIZE;
if (mode_ == ECB) if (mode_ == ECB)
......
...@@ -34,16 +34,16 @@ ...@@ -34,16 +34,16 @@
#include "runtime.hpp" #include "runtime.hpp"
#include "des.hpp" #include "des.hpp"
#include STL_ALGORITHM_FILE #ifdef USE_SYS_STL
#include <algorithm>
#else
#include "algorithm.hpp"
#endif
namespace STL = STL_NAMESPACE; namespace STL = STL_NAMESPACE;
#if defined(TAOCRYPT_X86ASM_AVAILABLE) && defined(TAO_ASM)
#define DO_DES_ASM
#endif
namespace TaoCrypt { namespace TaoCrypt {
...@@ -357,18 +357,6 @@ void BasicDES::RawProcessBlock(word32& lIn, word32& rIn) const ...@@ -357,18 +357,6 @@ void BasicDES::RawProcessBlock(word32& lIn, word32& rIn) const
} }
void DES::Process(byte* out, const byte* in, word32 sz)
{
if (mode_ == ECB)
ECB_Process(out, in, sz);
else if (mode_ == CBC)
if (dir_ == ENCRYPTION)
CBC_Encrypt(out, in, sz);
else
CBC_Decrypt(out, in, sz);
}
typedef BlockGetAndPut<word32, BigEndian> Block; typedef BlockGetAndPut<word32, BigEndian> Block;
...@@ -386,17 +374,6 @@ void DES::ProcessAndXorBlock(const byte* in, const byte* xOr, byte* out) const ...@@ -386,17 +374,6 @@ void DES::ProcessAndXorBlock(const byte* in, const byte* xOr, byte* out) const
} }
void DES_EDE2::Process(byte* out, const byte* in, word32 sz)
{
if (mode_ == ECB)
ECB_Process(out, in, sz);
else if (mode_ == CBC)
if (dir_ == ENCRYPTION)
CBC_Encrypt(out, in, sz);
else
CBC_Decrypt(out, in, sz);
}
void DES_EDE2::SetKey(const byte* key, word32 sz, CipherDir dir) void DES_EDE2::SetKey(const byte* key, word32 sz, CipherDir dir)
{ {
des1_.SetKey(key, sz, dir); des1_.SetKey(key, sz, dir);
...@@ -429,25 +406,16 @@ void DES_EDE3::SetKey(const byte* key, word32 sz, CipherDir dir) ...@@ -429,25 +406,16 @@ void DES_EDE3::SetKey(const byte* key, word32 sz, CipherDir dir)
#if !defined(DO_DES_ASM) #if defined(DO_DES_ASM)
// Generic Version
void DES_EDE3::Process(byte* out, const byte* in, word32 sz)
{
if (mode_ == ECB)
ECB_Process(out, in, sz);
else if (mode_ == CBC)
if (dir_ == ENCRYPTION)
CBC_Encrypt(out, in, sz);
else
CBC_Decrypt(out, in, sz);
}
#else
// ia32 optimized version // ia32 optimized version
void DES_EDE3::Process(byte* out, const byte* in, word32 sz) void DES_EDE3::Process(byte* out, const byte* in, word32 sz)
{ {
if (!isMMX) {
Mode_BASE::Process(out, in, sz);
return;
}
word32 blocks = sz / DES_BLOCK_SIZE; word32 blocks = sz / DES_BLOCK_SIZE;
if (mode_ == CBC) if (mode_ == CBC)
......
...@@ -55,12 +55,15 @@ extern "C" word myUMULH(word, word); ...@@ -55,12 +55,15 @@ extern "C" word myUMULH(word, word);
#pragma intrinsic (myUMULH) #pragma intrinsic (myUMULH)
#endif #endif
#ifdef __GNUC__
#include <signal.h>
#include <setjmp.h>
#endif
#ifdef SSE2_INTRINSICS_AVAILABLE #ifdef SSE2_INTRINSICS_AVAILABLE
#ifdef __GNUC__ #ifdef __GNUC__
#include <xmmintrin.h> #include <xmmintrin.h>
#include <signal.h>
#include <setjmp.h>
#ifdef TAOCRYPT_MEMALIGN_AVAILABLE #ifdef TAOCRYPT_MEMALIGN_AVAILABLE
#include <malloc.h> #include <malloc.h>
#else #else
...@@ -1015,44 +1018,20 @@ void Portable::Multiply8Bottom(word *R, const word *A, const word *B) ...@@ -1015,44 +1018,20 @@ void Portable::Multiply8Bottom(word *R, const word *A, const word *B)
// ************** x86 feature detection *************** // ************** x86 feature detection ***************
static bool s_sse2Enabled = true;
static void CpuId(word32 input, word32 *output)
{
#ifdef __GNUC__
__asm__
(
// save ebx in case -fPIC is being used
"push %%ebx; cpuid; mov %%ebx, %%edi; pop %%ebx"
: "=a" (output[0]), "=D" (output[1]), "=c" (output[2]), "=d"(output[3])
: "a" (input)
);
#else
__asm
{
mov eax, input
cpuid
mov edi, output
mov [edi], eax
mov [edi+4], ebx
mov [edi+8], ecx
mov [edi+12], edx
}
#endif
}
#ifdef SSE2_INTRINSICS_AVAILABLE #ifdef SSE2_INTRINSICS_AVAILABLE
#ifndef _MSC_VER #ifndef _MSC_VER
static jmp_buf s_env; static jmp_buf s_env;
static void SigIllHandler(int) static void SigIllHandler(int)
{ {
longjmp(s_env, 1); longjmp(s_env, 1);
} }
#endif #endif
static bool HasSSE2() static bool HasSSE2()
{ {
if (!s_sse2Enabled) if (!IsPentium())
return false; return false;
word32 cpuid[4]; word32 cpuid[4];
...@@ -1081,23 +1060,22 @@ static bool HasSSE2() ...@@ -1081,23 +1060,22 @@ static bool HasSSE2()
if (setjmp(s_env)) if (setjmp(s_env))
result = false; result = false;
else else
__asm __volatile ("xorps %xmm0, %xmm0"); __asm __volatile ("xorpd %xmm0, %xmm0");
signal(SIGILL, oldHandler); signal(SIGILL, oldHandler);
return result; return result;
#endif #endif
} }
#endif #endif // SSE2_INTRINSICS_AVAILABLE
static bool IsP4() static bool IsP4()
{ {
word32 cpuid[4]; if (!IsPentium())
CpuId(0, cpuid);
STL::swap(cpuid[2], cpuid[3]);
if (memcmp(cpuid+1, "GenuineIntel", 12) != 0)
return false; return false;
word32 cpuid[4];
CpuId(1, cpuid); CpuId(1, cpuid);
return ((cpuid[0] >> 8) & 0xf) == 0xf; return ((cpuid[0] >> 8) & 0xf) == 0xf;
} }
...@@ -1147,7 +1125,12 @@ static PMul s_pMul4, s_pMul8, s_pMul8B; ...@@ -1147,7 +1125,12 @@ static PMul s_pMul4, s_pMul8, s_pMul8B;
static void SetPentiumFunctionPointers() static void SetPentiumFunctionPointers()
{ {
if (IsP4()) if (!IsPentium())
{
s_pAdd = &Portable::Add;
s_pSub = &Portable::Subtract;
}
else if (IsP4())
{ {
s_pAdd = &P4Optimized::Add; s_pAdd = &P4Optimized::Add;
s_pSub = &P4Optimized::Subtract; s_pSub = &P4Optimized::Subtract;
...@@ -1159,7 +1142,13 @@ static void SetPentiumFunctionPointers() ...@@ -1159,7 +1142,13 @@ static void SetPentiumFunctionPointers()
} }
#ifdef SSE2_INTRINSICS_AVAILABLE #ifdef SSE2_INTRINSICS_AVAILABLE
if (HasSSE2()) if (!IsPentium())
{
s_pMul4 = &Portable::Multiply4;
s_pMul8 = &Portable::Multiply8;
s_pMul8B = &Portable::Multiply8Bottom;
}
else if (HasSSE2())
{ {
s_pMul4 = &P4Optimized::Multiply4; s_pMul4 = &P4Optimized::Multiply4;
s_pMul8 = &P4Optimized::Multiply8; s_pMul8 = &P4Optimized::Multiply8;
...@@ -1177,11 +1166,6 @@ static void SetPentiumFunctionPointers() ...@@ -1177,11 +1166,6 @@ static void SetPentiumFunctionPointers()
static const char s_RunAtStartupSetPentiumFunctionPointers = static const char s_RunAtStartupSetPentiumFunctionPointers =
(SetPentiumFunctionPointers(), 0); (SetPentiumFunctionPointers(), 0);
void DisableSSE2()
{
s_sse2Enabled = false;
SetPentiumFunctionPointers();
}
class LowLevel : public PentiumOptimized class LowLevel : public PentiumOptimized
{ {
...@@ -3984,6 +3968,9 @@ Integer CRT(const Integer &xp, const Integer &p, const Integer &xq, ...@@ -3984,6 +3968,9 @@ Integer CRT(const Integer &xp, const Integer &p, const Integer &xq,
template hword DivideThreeWordsByTwo<hword, Word>(hword*, hword, hword, Word*); template hword DivideThreeWordsByTwo<hword, Word>(hword*, hword, hword, Word*);
#endif #endif
template word DivideThreeWordsByTwo<word, DWord>(word*, word, word, DWord*); template word DivideThreeWordsByTwo<word, DWord>(word*, word, word, DWord*);
#ifdef SSE2_INTRINSICS_AVAILABLE
template class AlignedAllocator<word>;
#endif
#endif #endif
......
REM quick and dirty build file for testing different MSDEVs REM quick and dirty build file for testing different MSDEVs
setlocal setlocal
set myFLAGS= /I../include /I../../mySTL /c /W3 /G6 /O2 set myFLAGS= /I../include /I../mySTL /c /W3 /G6 /O2
cl %myFLAGS% aes.cpp cl %myFLAGS% aes.cpp
cl %myFLAGS% aestables.cpp cl %myFLAGS% aestables.cpp
...@@ -21,6 +21,7 @@ cl %myFLAGS% file.cpp ...@@ -21,6 +21,7 @@ cl %myFLAGS% file.cpp
cl %myFLAGS% hash.cpp cl %myFLAGS% hash.cpp
cl %myFLAGS% integer.cpp cl %myFLAGS% integer.cpp
cl %myFLAGS% md2.cpp cl %myFLAGS% md2.cpp
cl %myFLAGS% md4.cpp
cl %myFLAGS% md5.cpp cl %myFLAGS% md5.cpp
cl %myFLAGS% misc.cpp cl %myFLAGS% misc.cpp
...@@ -33,5 +34,5 @@ cl %myFLAGS% template_instnt.cpp ...@@ -33,5 +34,5 @@ cl %myFLAGS% template_instnt.cpp
cl %myFLAGS% tftables.cpp cl %myFLAGS% tftables.cpp
cl %myFLAGS% twofish.cpp cl %myFLAGS% twofish.cpp
link.exe -lib /out:taocrypt.lib aes.obj aestables.obj algebra.obj arc4.obj asn.obj bftables.obj blowfish.obj coding.obj des.obj dh.obj dsa.obj file.obj hash.obj integer.obj md2.obj md5.obj misc.obj random.obj ripemd.obj rsa.obj sha.obj template_instnt.obj tftables.obj twofish.obj link.exe -lib /out:taocrypt.lib aes.obj aestables.obj algebra.obj arc4.obj asn.obj bftables.obj blowfish.obj coding.obj des.obj dh.obj dsa.obj file.obj hash.obj integer.obj md2.obj md4.obj md5.obj misc.obj random.obj ripemd.obj rsa.obj sha.obj template_instnt.obj tftables.obj twofish.obj
...@@ -28,7 +28,11 @@ ...@@ -28,7 +28,11 @@
#include "runtime.hpp" #include "runtime.hpp"
#include "md4.hpp" #include "md4.hpp"
#include STL_ALGORITHM_FILE #ifdef USE_SYS_STL
#include <algorithm>
#else
#include "algorithm.hpp"
#endif
namespace STL = STL_NAMESPACE; namespace STL = STL_NAMESPACE;
......
...@@ -28,15 +28,16 @@ ...@@ -28,15 +28,16 @@
#include "runtime.hpp" #include "runtime.hpp"
#include "md5.hpp" #include "md5.hpp"
#include STL_ALGORITHM_FILE #ifdef USE_SYS_STL
#include <algorithm>
#else
#include "algorithm.hpp"
#endif
namespace STL = STL_NAMESPACE; namespace STL = STL_NAMESPACE;
#if defined(TAOCRYPT_X86ASM_AVAILABLE) && defined(TAO_ASM)
#define DO_MD5_ASM
#endif
namespace TaoCrypt { namespace TaoCrypt {
...@@ -84,10 +85,17 @@ void MD5::Swap(MD5& other) ...@@ -84,10 +85,17 @@ void MD5::Swap(MD5& other)
} }
// Update digest with data of size len, do in blocks #ifdef DO_MD5_ASM
// Update digest with data of size len
void MD5::Update(const byte* data, word32 len) void MD5::Update(const byte* data, word32 len)
{ {
byte* local = (byte*)buffer_; if (!isMMX) {
HASHwithTransform::Update(data, len);
return;
}
byte* local = reinterpret_cast<byte*>(buffer_);
// remove buffered data if possible // remove buffered data if possible
if (buffLen_) { if (buffLen_) {
...@@ -99,27 +107,14 @@ void MD5::Update(const byte* data, word32 len) ...@@ -99,27 +107,14 @@ void MD5::Update(const byte* data, word32 len)
len -= add; len -= add;
if (buffLen_ == BLOCK_SIZE) { if (buffLen_ == BLOCK_SIZE) {
ByteReverseIf(local, local, BLOCK_SIZE, LittleEndianOrder);
Transform(); Transform();
AddLength(BLOCK_SIZE); AddLength(BLOCK_SIZE);
buffLen_ = 0; buffLen_ = 0;
} }
} }
// do block size transforms or all at once for asm // at once for asm
if (buffLen_ == 0) { if (buffLen_ == 0) {
#ifndef DO_MD5_ASM
while (len >= BLOCK_SIZE) {
memcpy(&local[0], data, BLOCK_SIZE);
data += BLOCK_SIZE;
len -= BLOCK_SIZE;
ByteReverseIf(local, local, BLOCK_SIZE, LittleEndianOrder);
Transform();
AddLength(BLOCK_SIZE);
}
#else
word32 times = len / BLOCK_SIZE; word32 times = len / BLOCK_SIZE;
if (times) { if (times) {
AsmTransform(data, times); AsmTransform(data, times);
...@@ -128,7 +123,6 @@ void MD5::Update(const byte* data, word32 len) ...@@ -128,7 +123,6 @@ void MD5::Update(const byte* data, word32 len)
len -= add; len -= add;
data += add; data += add;
} }
#endif
} }
// cache any data left // cache any data left
...@@ -139,7 +133,6 @@ void MD5::Update(const byte* data, word32 len) ...@@ -139,7 +133,6 @@ void MD5::Update(const byte* data, word32 len)
} }
#ifdef DO_MD5_ASM
/* /*
......
...@@ -30,6 +30,20 @@ ...@@ -30,6 +30,20 @@
#include "misc.hpp" #include "misc.hpp"
#ifdef __GNUC__
#include <signal.h>
#include <setjmp.h>
#endif
#ifdef USE_SYS_STL
#include <algorithm>
#else
#include "algorithm.hpp"
#endif
namespace STL = STL_NAMESPACE;
#ifdef YASSL_PURE_C #ifdef YASSL_PURE_C
void* operator new(size_t sz, TaoCrypt::new_t) void* operator new(size_t sz, TaoCrypt::new_t)
...@@ -156,5 +170,129 @@ unsigned long Crop(unsigned long value, unsigned int size) ...@@ -156,5 +170,129 @@ unsigned long Crop(unsigned long value, unsigned int size)
} }
#ifdef TAOCRYPT_X86ASM_AVAILABLE
#ifndef _MSC_VER
static jmp_buf s_env;
static void SigIllHandler(int)
{
longjmp(s_env, 1);
}
#endif
bool HaveCpuId()
{
#ifdef _MSC_VER
__try
{
__asm
{
mov eax, 0
cpuid
}
}
__except (1)
{
return false;
}
return true;
#else
typedef void (*SigHandler)(int);
SigHandler oldHandler = signal(SIGILL, SigIllHandler);
if (oldHandler == SIG_ERR)
return false;
bool result = true;
if (setjmp(s_env))
result = false;
else
__asm__ __volatile
(
// save ebx in case -fPIC is being used
"push %%ebx; mov $0, %%eax; cpuid; pop %%ebx"
:
:
: "%eax", "%ecx", "%edx"
);
signal(SIGILL, oldHandler);
return result;
#endif
}
void CpuId(word32 input, word32 *output)
{
#ifdef __GNUC__
__asm__
(
// save ebx in case -fPIC is being used
"push %%ebx; cpuid; mov %%ebx, %%edi; pop %%ebx"
: "=a" (output[0]), "=D" (output[1]), "=c" (output[2]), "=d"(output[3])
: "a" (input)
);
#else
__asm
{
mov eax, input
cpuid
mov edi, output
mov [edi], eax
mov [edi+4], ebx
mov [edi+8], ecx
mov [edi+12], edx
}
#endif
}
bool IsPentium()
{
if (!HaveCpuId())
return false;
word32 cpuid[4];
CpuId(0, cpuid);
STL::swap(cpuid[2], cpuid[3]);
if (memcmp(cpuid+1, "GenuineIntel", 12) != 0)
return false;
CpuId(1, cpuid);
byte family = ((cpuid[0] >> 8) & 0xf);
if (family < 5)
return false;
return true;
}
static bool IsMmx()
{
if (!IsPentium())
return false;
word32 cpuid[4];
CpuId(1, cpuid);
if ((cpuid[3] & (1 << 23)) == 0)
return false;
return true;
}
bool isMMX = IsMmx();
#endif // TAOCRYPT_X86ASM_AVAILABLE
} // namespace } // namespace
...@@ -50,8 +50,11 @@ namespace TaoCrypt { ...@@ -50,8 +50,11 @@ namespace TaoCrypt {
RandomNumberGenerator::RandomNumberGenerator() RandomNumberGenerator::RandomNumberGenerator()
{ {
byte key[32]; byte key[32];
byte junk[256];
seed_.GenerateSeed(key, sizeof(key)); seed_.GenerateSeed(key, sizeof(key));
cipher_.SetKey(key, sizeof(key)); cipher_.SetKey(key, sizeof(key));
GenerateBlock(junk, sizeof(junk)); // rid initial state
} }
......
...@@ -28,15 +28,16 @@ ...@@ -28,15 +28,16 @@
#include "runtime.hpp" #include "runtime.hpp"
#include "ripemd.hpp" #include "ripemd.hpp"
#include STL_ALGORITHM_FILE #ifdef USE_SYS_STL
#include <algorithm>
#else
#include "algorithm.hpp"
#endif
namespace STL = STL_NAMESPACE; namespace STL = STL_NAMESPACE;
#if defined(TAOCRYPT_X86ASM_AVAILABLE) && defined(TAO_ASM)
#define DO_RIPEMD_ASM
#endif
namespace TaoCrypt { namespace TaoCrypt {
...@@ -86,10 +87,17 @@ void RIPEMD160::Swap(RIPEMD160& other) ...@@ -86,10 +87,17 @@ void RIPEMD160::Swap(RIPEMD160& other)
} }
// Update digest with data of size len, do in blocks #ifdef DO_RIPEMD_ASM
// Update digest with data of size len
void RIPEMD160::Update(const byte* data, word32 len) void RIPEMD160::Update(const byte* data, word32 len)
{ {
byte* local = (byte*)buffer_; if (!isMMX) {
HASHwithTransform::Update(data, len);
return;
}
byte* local = reinterpret_cast<byte*>(buffer_);
// remove buffered data if possible // remove buffered data if possible
if (buffLen_) { if (buffLen_) {
...@@ -101,27 +109,14 @@ void RIPEMD160::Update(const byte* data, word32 len) ...@@ -101,27 +109,14 @@ void RIPEMD160::Update(const byte* data, word32 len)
len -= add; len -= add;
if (buffLen_ == BLOCK_SIZE) { if (buffLen_ == BLOCK_SIZE) {
ByteReverseIf(local, local, BLOCK_SIZE, LittleEndianOrder);
Transform(); Transform();
AddLength(BLOCK_SIZE); AddLength(BLOCK_SIZE);
buffLen_ = 0; buffLen_ = 0;
} }
} }
// do block size transforms or all at once for asm // all at once for asm
if (buffLen_ == 0) { if (buffLen_ == 0) {
#ifndef DO_RIPEMD_ASM
while (len >= BLOCK_SIZE) {
memcpy(&local[0], data, BLOCK_SIZE);
data += BLOCK_SIZE;
len -= BLOCK_SIZE;
ByteReverseIf(local, local, BLOCK_SIZE, LittleEndianOrder);
Transform();
AddLength(BLOCK_SIZE);
}
#else
word32 times = len / BLOCK_SIZE; word32 times = len / BLOCK_SIZE;
if (times) { if (times) {
AsmTransform(data, times); AsmTransform(data, times);
...@@ -130,7 +125,6 @@ void RIPEMD160::Update(const byte* data, word32 len) ...@@ -130,7 +125,6 @@ void RIPEMD160::Update(const byte* data, word32 len)
len -= add; len -= add;
data += add; data += add;
} }
#endif
} }
// cache any data left // cache any data left
...@@ -140,6 +134,8 @@ void RIPEMD160::Update(const byte* data, word32 len) ...@@ -140,6 +134,8 @@ void RIPEMD160::Update(const byte* data, word32 len)
} }
} }
#endif // DO_RIPEMD_ASM
// for all // for all
#define F(x, y, z) (x ^ y ^ z) #define F(x, y, z) (x ^ y ^ z)
......
...@@ -28,16 +28,16 @@ ...@@ -28,16 +28,16 @@
#include "runtime.hpp" #include "runtime.hpp"
#include <string.h> #include <string.h>
#include "sha.hpp" #include "sha.hpp"
#include STL_ALGORITHM_FILE #ifdef USE_SYS_STL
#include <algorithm>
#else
#include "algorithm.hpp"
#endif
namespace STL = STL_NAMESPACE; namespace STL = STL_NAMESPACE;
#if defined(TAOCRYPT_X86ASM_AVAILABLE) && defined(TAO_ASM)
#define DO_SHA_ASM
#endif
namespace TaoCrypt { namespace TaoCrypt {
...@@ -108,10 +108,18 @@ void SHA::Swap(SHA& other) ...@@ -108,10 +108,18 @@ void SHA::Swap(SHA& other)
} }
// Update digest with data of size len, do in blocks
#ifdef DO_SHA_ASM
// Update digest with data of size len
void SHA::Update(const byte* data, word32 len) void SHA::Update(const byte* data, word32 len)
{ {
byte* local = (byte*)buffer_; if (!isMMX) {
HASHwithTransform::Update(data, len);
return;
}
byte* local = reinterpret_cast<byte*>(buffer_);
// remove buffered data if possible // remove buffered data if possible
if (buffLen_) { if (buffLen_) {
...@@ -123,27 +131,15 @@ void SHA::Update(const byte* data, word32 len) ...@@ -123,27 +131,15 @@ void SHA::Update(const byte* data, word32 len)
len -= add; len -= add;
if (buffLen_ == BLOCK_SIZE) { if (buffLen_ == BLOCK_SIZE) {
ByteReverseIf(local, local, BLOCK_SIZE, BigEndianOrder); ByteReverse(local, local, BLOCK_SIZE);
Transform(); Transform();
AddLength(BLOCK_SIZE); AddLength(BLOCK_SIZE);
buffLen_ = 0; buffLen_ = 0;
} }
} }
// do block size transforms or all at once for asm // all at once for asm
if (buffLen_ == 0) { if (buffLen_ == 0) {
#ifndef DO_SHA_ASM
while (len >= BLOCK_SIZE) {
memcpy(&local[0], data, BLOCK_SIZE);
data += BLOCK_SIZE;
len -= BLOCK_SIZE;
ByteReverseIf(local, local, BLOCK_SIZE, BigEndianOrder);
Transform();
AddLength(BLOCK_SIZE);
}
#else
word32 times = len / BLOCK_SIZE; word32 times = len / BLOCK_SIZE;
if (times) { if (times) {
AsmTransform(data, times); AsmTransform(data, times);
...@@ -152,7 +148,6 @@ void SHA::Update(const byte* data, word32 len) ...@@ -152,7 +148,6 @@ void SHA::Update(const byte* data, word32 len)
len -= add; len -= add;
data += add; data += add;
} }
#endif
} }
// cache any data left // cache any data left
...@@ -162,6 +157,8 @@ void SHA::Update(const byte* data, word32 len) ...@@ -162,6 +157,8 @@ void SHA::Update(const byte* data, word32 len)
} }
} }
#endif // DO_SHA_ASM
void SHA::Transform() void SHA::Transform()
{ {
......
...@@ -35,33 +35,20 @@ ...@@ -35,33 +35,20 @@
#include "twofish.hpp" #include "twofish.hpp"
#if defined(TAOCRYPT_X86ASM_AVAILABLE) && defined(TAO_ASM)
#define DO_TWOFISH_ASM
#endif
namespace TaoCrypt { namespace TaoCrypt {
#if !defined(DO_TWOFISH_ASM) #if defined(DO_TWOFISH_ASM)
// Generic Version
void Twofish::Process(byte* out, const byte* in, word32 sz)
{
if (mode_ == ECB)
ECB_Process(out, in, sz);
else if (mode_ == CBC)
if (dir_ == ENCRYPTION)
CBC_Encrypt(out, in, sz);
else
CBC_Decrypt(out, in, sz);
}
#else
// ia32 optimized version // ia32 optimized version
void Twofish::Process(byte* out, const byte* in, word32 sz) void Twofish::Process(byte* out, const byte* in, word32 sz)
{ {
if (!isMMX) {
Mode_BASE::Process(out, in, sz);
return;
}
word32 blocks = sz / BLOCK_SIZE; word32 blocks = sz / BLOCK_SIZE;
if (mode_ == ECB) if (mode_ == ECB)
......
REM quick and dirty build file for testing different MSDEVs REM quick and dirty build file for testing different MSDEVs
setlocal setlocal
set myFLAGS= /I../include /I../../mySTL /c /W3 /G6 /O2 set myFLAGS= /I../include /I../mySTL /c /W3 /G6 /O2
cl %myFLAGS% test.cpp cl %myFLAGS% test.cpp
......
...@@ -247,6 +247,8 @@ void taocrypt_test(void* args) ...@@ -247,6 +247,8 @@ void taocrypt_test(void* args)
args.argv = argv; args.argv = argv;
taocrypt_test(&args); taocrypt_test(&args);
TaoCrypt::CleanUp();
return args.return_code; return args.return_code;
} }
......
REM quick and dirty build file for testing different MSDEVs REM quick and dirty build file for testing different MSDEVs
setlocal setlocal
set myFLAGS= /I../include /I../taocrypt/include /I../mySTL /c /W3 /G6 /O2 /MT /D"WIN32" /D"NO_MAIN_DRIVER" set myFLAGS= /I../include /I../taocrypt/include /I../taocrypt/mySTL /c /W3 /G6 /O2 /MT /D"WIN32" /D"NO_MAIN_DRIVER"
cl %myFLAGS% testsuite.cpp cl %myFLAGS% testsuite.cpp
cl %myFLAGS% ../examples/client/client.cpp cl %myFLAGS% ../examples/client/client.cpp
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <stdlib.h> #include <stdlib.h>
#include <assert.h> #include <assert.h>
//#define NON_BLOCKING // test server and client example (not echos)
#ifdef _WIN32 #ifdef _WIN32
#include <winsock2.h> #include <winsock2.h>
#include <process.h> #include <process.h>
...@@ -23,16 +25,17 @@ ...@@ -23,16 +25,17 @@
#include <sys/types.h> #include <sys/types.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <pthread.h> #include <pthread.h>
#ifdef NON_BLOCKING
#include <fcntl.h>
#endif
#define SOCKET_T int #define SOCKET_T int
#endif /* _WIN32 */ #endif /* _WIN32 */
#if !defined(_SOCKLEN_T) && defined(_WIN32) #if !defined(_SOCKLEN_T) && \
(defined(_WIN32) || defined(__NETWARE__) || defined(__APPLE__))
typedef int socklen_t; typedef int socklen_t;
#endif #endif
#if !defined(_SOCKLEN_T) && defined(__NETWARE__)
typedef size_t socklen_t;
#endif
// Check type of third arg to accept // Check type of third arg to accept
...@@ -262,6 +265,20 @@ inline void set_args(int& argc, char**& argv, func_args& args) ...@@ -262,6 +265,20 @@ inline void set_args(int& argc, char**& argv, func_args& args)
} }
inline void tcp_set_nonblocking(SOCKET_T& sockfd)
{
#ifdef NON_BLOCKING
#ifdef _WIN32
unsigned long blocking = 1;
int ret = ioctlsocket(sockfd, FIONBIO, &blocking);
#else
int flags = fcntl(sockfd, F_GETFL, 0);
int ret = fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
#endif
#endif
}
inline void tcp_socket(SOCKET_T& sockfd, sockaddr_in& addr) inline void tcp_socket(SOCKET_T& sockfd, sockaddr_in& addr)
{ {
sockfd = socket(AF_INET, SOCK_STREAM, 0); sockfd = socket(AF_INET, SOCK_STREAM, 0);
...@@ -289,8 +306,7 @@ inline void tcp_connect(SOCKET_T& sockfd) ...@@ -289,8 +306,7 @@ inline void tcp_connect(SOCKET_T& sockfd)
sockaddr_in addr; sockaddr_in addr;
tcp_socket(sockfd, addr); tcp_socket(sockfd, addr);
if (connect(sockfd, (const sockaddr*)&addr, sizeof(addr)) != 0) if (connect(sockfd, (const sockaddr*)&addr, sizeof(addr)) != 0) {
{
tcp_close(sockfd); tcp_close(sockfd);
err_sys("tcp connect failed"); err_sys("tcp connect failed");
} }
...@@ -302,19 +318,18 @@ inline void tcp_listen(SOCKET_T& sockfd) ...@@ -302,19 +318,18 @@ inline void tcp_listen(SOCKET_T& sockfd)
sockaddr_in addr; sockaddr_in addr;
tcp_socket(sockfd, addr); tcp_socket(sockfd, addr);
if (bind(sockfd, (const sockaddr*)&addr, sizeof(addr)) != 0) if (bind(sockfd, (const sockaddr*)&addr, sizeof(addr)) != 0) {
{
tcp_close(sockfd); tcp_close(sockfd);
err_sys("tcp bind failed"); err_sys("tcp bind failed");
} }
if (listen(sockfd, 3) != 0) if (listen(sockfd, 3) != 0) {
{
tcp_close(sockfd); tcp_close(sockfd);
err_sys("tcp listen failed"); err_sys("tcp listen failed");
} }
} }
inline void tcp_accept(SOCKET_T& sockfd, SOCKET_T& clientfd, func_args& args) inline void tcp_accept(SOCKET_T& sockfd, SOCKET_T& clientfd, func_args& args)
{ {
tcp_listen(sockfd); tcp_listen(sockfd);
...@@ -333,11 +348,14 @@ inline void tcp_accept(SOCKET_T& sockfd, SOCKET_T& clientfd, func_args& args) ...@@ -333,11 +348,14 @@ inline void tcp_accept(SOCKET_T& sockfd, SOCKET_T& clientfd, func_args& args)
clientfd = accept(sockfd, (sockaddr*)&client, (ACCEPT_THIRD_T)&client_len); clientfd = accept(sockfd, (sockaddr*)&client, (ACCEPT_THIRD_T)&client_len);
if (clientfd == -1) if (clientfd == -1) {
{
tcp_close(sockfd); tcp_close(sockfd);
err_sys("tcp accept failed"); err_sys("tcp accept failed");
} }
#ifdef NON_BLOCKING
tcp_set_nonblocking(clientfd);
#endif
} }
...@@ -363,25 +381,30 @@ inline void showPeer(SSL* ssl) ...@@ -363,25 +381,30 @@ inline void showPeer(SSL* ssl)
inline DH* set_tmpDH(SSL_CTX* ctx) inline DH* set_tmpDH(SSL_CTX* ctx)
{ {
static unsigned char dh512_p[] = static unsigned char dh1024_p[] =
{ {
0xDA,0x58,0x3C,0x16,0xD9,0x85,0x22,0x89,0xD0,0xE4,0xAF,0x75, 0xE6, 0x96, 0x9D, 0x3D, 0x49, 0x5B, 0xE3, 0x2C, 0x7C, 0xF1, 0x80, 0xC3,
0x6F,0x4C,0xCA,0x92,0xDD,0x4B,0xE5,0x33,0xB8,0x04,0xFB,0x0F, 0xBD, 0xD4, 0x79, 0x8E, 0x91, 0xB7, 0x81, 0x82, 0x51, 0xBB, 0x05, 0x5E,
0xED,0x94,0xEF,0x9C,0x8A,0x44,0x03,0xED,0x57,0x46,0x50,0xD3, 0x2A, 0x20, 0x64, 0x90, 0x4A, 0x79, 0xA7, 0x70, 0xFA, 0x15, 0xA2, 0x59,
0x69,0x99,0xDB,0x29,0xD7,0x76,0x27,0x6B,0xA2,0xD3,0xD4,0x12, 0xCB, 0xD5, 0x23, 0xA6, 0xA6, 0xEF, 0x09, 0xC4, 0x30, 0x48, 0xD5, 0xA2,
0xE2,0x18,0xF4,0xDD,0x1E,0x08,0x4C,0xF6,0xD8,0x00,0x3E,0x7C, 0x2F, 0x97, 0x1F, 0x3C, 0x20, 0x12, 0x9B, 0x48, 0x00, 0x0E, 0x6E, 0xDD,
0x47,0x74,0xE8,0x33, 0x06, 0x1C, 0xBC, 0x05, 0x3E, 0x37, 0x1D, 0x79, 0x4E, 0x53, 0x27, 0xDF,
0x61, 0x1E, 0xBB, 0xBE, 0x1B, 0xAC, 0x9B, 0x5C, 0x60, 0x44, 0xCF, 0x02,
0x3D, 0x76, 0xE0, 0x5E, 0xEA, 0x9B, 0xAD, 0x99, 0x1B, 0x13, 0xA6, 0x3C,
0x97, 0x4E, 0x9E, 0xF1, 0x83, 0x9E, 0xB5, 0xDB, 0x12, 0x51, 0x36, 0xF7,
0x26, 0x2E, 0x56, 0xA8, 0x87, 0x15, 0x38, 0xDF, 0xD8, 0x23, 0xC6, 0x50,
0x50, 0x85, 0xE2, 0x1F, 0x0D, 0xD5, 0xC8, 0x6B,
}; };
static unsigned char dh512_g[] = static unsigned char dh1024_g[] =
{ {
0x02, 0x02,
}; };
DH* dh; DH* dh;
if ( (dh = DH_new()) ) { if ( (dh = DH_new()) ) {
dh->p = BN_bin2bn(dh512_p, sizeof(dh512_p), 0); dh->p = BN_bin2bn(dh1024_p, sizeof(dh1024_p), 0);
dh->g = BN_bin2bn(dh512_g, sizeof(dh512_g), 0); dh->g = BN_bin2bn(dh1024_g, sizeof(dh1024_g), 0);
} }
if (!dh->p || !dh->g) { if (!dh->p || !dh->g) {
DH_free(dh); DH_free(dh);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment