diff --git a/cyassl/internal.h b/cyassl/internal.h index 0921740a4..cc109738f 100644 --- a/cyassl/internal.h +++ b/cyassl/internal.h @@ -1264,6 +1264,7 @@ enum HandShakeType { client_hello = 1, server_hello = 2, hello_verify_request = 3, /* DTLS addition */ + session_ticket = 4, certificate = 11, server_key_exchange = 12, certificate_request = 13, diff --git a/cyassl/sniffer_error.h b/cyassl/sniffer_error.h index ecd2e77a2..7b7a50496 100644 --- a/cyassl/sniffer_error.h +++ b/cyassl/sniffer_error.h @@ -90,9 +90,14 @@ #define OUT_OF_ORDER_STR 57 #define OVERLAP_DUPLICATE_STR 58 #define OVERLAP_REASSEMBLY_BEGIN_STR 59 - #define OVERLAP_REASSEMBLY_END_STR 60 + #define MISSED_CLIENT_HELLO_STR 61 +#define GOT_HELLO_REQUEST_STR 62 +#define GOT_SESSION_TICKET_STR 63 +#define BAD_INPUT_STR 64 +#define BAD_DECRYPT_TYPE 65 +#define BAD_FINISHED_MSG 66 /* !!!! also add to msgTable in sniffer.c and .rc file !!!! */ diff --git a/cyassl/sniffer_error.rc b/cyassl/sniffer_error.rc index 89b12e2e9..ad722f4e0 100644 --- a/cyassl/sniffer_error.rc +++ b/cyassl/sniffer_error.rc @@ -74,5 +74,11 @@ STRINGTABLE 60, "Received an Overlap Reassembly End Duplicate Packet" 61, "Missed the Client Hello Entirely" + 62, "Got Hello Request msg" + 63, "Got Session Ticket msg" + 64, "Bad Input" + 65, "Bad Decrypt Type" + + 66, "Bad Finished Message Processing" } diff --git a/src/sniffer.c b/src/sniffer.c index e139fd80a..3b5a35649 100644 --- a/src/sniffer.c +++ b/src/sniffer.c @@ -45,6 +45,17 @@ #include #include + +#ifndef min + +static INLINE word32 min(word32 a, word32 b) +{ + return a > b ? b : a; +} + +#endif + + /* Misc constants */ enum { MAX_SERVER_ADDRESS = 128, /* maximum server address length */ @@ -61,6 +72,9 @@ enum { PSEUDO_HDR_SZ = 12, /* TCP Pseudo Header size in bytes */ FATAL_ERROR_STATE = 1, /* SnifferSession fatal error state */ SNIFFER_TIMEOUT = 900, /* Cache unclosed Sessions for 15 minutes */ + TICKET_HINT_LEN = 4, /* Session Ticket Hint length */ + EXT_TYPE_SZ = 2, /* Extension length */ + TICKET_EXT_ID = 0x23 /* Session Ticket Extension ID */ }; @@ -196,6 +210,13 @@ static const char* const msgTable[] = /* 61 */ "Missed the Client Hello Entirely", + "Got Hello Request msg", + "Got Session Ticket msg", + "Bad Input", + "Bad Decrypt Type", + + /* 66 */ + "Bad Finished Message Processing" }; @@ -248,7 +269,7 @@ typedef struct Flags { byte cached; /* have we cached this session yet */ byte clientHello; /* processed client hello yet, for SSLv2 */ byte finCount; /* get both FINs before removing */ - byte fatalError; /* fatal error state */ + byte fatalError; /* fatal error state */ } Flags; @@ -279,7 +300,8 @@ typedef struct SnifferSession { time_t bornOn; /* born on ticks */ PacketBuffer* cliReassemblyList; /* client out of order packets */ PacketBuffer* srvReassemblyList; /* server out of order packets */ - struct SnifferSession* next; /* for hash table list */ + struct SnifferSession* next; /* for hash table list */ + byte* ticketID; /* mac ID of session ticket */ } SnifferSession; @@ -347,6 +369,8 @@ static void FreeSnifferSession(SnifferSession* session) FreePacketList(session->cliReassemblyList); FreePacketList(session->srvReassemblyList); + + free(session->ticketID); } free(session); } @@ -442,6 +466,7 @@ static void InitSession(SnifferSession* session) session->cliReassemblyList = 0; session->srvReassemblyList = 0; session->next = 0; + session->ticketID = 0; InitFlags(&session->flags); InitFinCapture(&session->finCaputre); @@ -1067,6 +1092,39 @@ static int ProcessClientKeyExchange(const byte* input, int* sslBytes, } +/* Process Session Ticket */ +static int ProcessSessionTicket(const byte* input, int* sslBytes, + SnifferSession* session, char* error) +{ + word16 len; + + /* make sure can read through hint and len */ + if (TICKET_HINT_LEN + LENGTH_SZ > *sslBytes) { + SetError(BAD_INPUT_STR, error, session, FATAL_ERROR_STATE); + return -1; + } + + input += TICKET_HINT_LEN; /* skip over hint */ + *sslBytes -= TICKET_HINT_LEN; + + len = (input[0] << 8) | input[1]; + input += LENGTH_SZ; + *sslBytes -= LENGTH_SZ; + + /* make sure can read through ticket */ + if (len > *sslBytes || len < ID_LEN) { + SetError(BAD_INPUT_STR, error, session, FATAL_ERROR_STATE); + return -1; + } + + /* store session with macID as sessionID */ + session->sslServer->options.haveSessionId = 1; + XMEMCPY(session->sslServer->arrays.sessionID, input + len - ID_LEN, ID_LEN); + + return 0; +} + + /* Process Server Hello */ static int ProcessServerHello(const byte* input, int* sslBytes, SnifferSession* session, char* error) @@ -1074,6 +1132,7 @@ static int ProcessServerHello(const byte* input, int* sslBytes, ProtocolVersion pv; byte b; int toRead = sizeof(ProtocolVersion) + RAN_LEN + ENUM_LEN; + int doResume = 0; /* make sure we didn't miss ClientHello */ if (session->flags.clientHello == 0) { @@ -1107,22 +1166,34 @@ static int ProcessServerHello(const byte* input, int* sslBytes, SetError(SERVER_HELLO_INPUT_STR, error, session, FATAL_ERROR_STATE); return -1; } - XMEMCPY(session->sslServer->arrays.sessionID, input, ID_LEN); + if (b) { + XMEMCPY(session->sslServer->arrays.sessionID, input, ID_LEN); + session->sslServer->options.haveSessionId = 1; + } input += b; *sslBytes -= b; - if (b) - session->sslServer->options.haveSessionId = 1; (void)*input++; /* eat first byte, always 0 */ b = *input++; session->sslServer->options.cipherSuite = b; session->sslClient->options.cipherSuite = b; *sslBytes -= SUITE_LEN; - + if (session->sslServer->options.haveSessionId && - XMEMCMP(session->sslServer->arrays.sessionID, - session->sslClient->arrays.sessionID, ID_LEN) == 0) { - /* resuming */ + XMEMCMP(session->sslServer->arrays.sessionID, + session->sslClient->arrays.sessionID, ID_LEN) == 0) + doResume = 1; + else if (session->sslClient->options.haveSessionId == 0 && + session->sslServer->options.haveSessionId == 0 && + session->ticketID) + doResume = 1; + + if (session->ticketID && doResume) { + /* use ticketID to retrieve from session */ + XMEMCPY(session->sslServer->arrays.sessionID, session->ticketID,ID_LEN); + } + + if (doResume ) { SSL_SESSION* resume = GetSession(session->sslServer, session->sslServer->arrays.masterSecret); if (resume == NULL) { @@ -1173,8 +1244,9 @@ static int ProcessServerHello(const byte* input, int* sslBytes, static int ProcessClientHello(const byte* input, int* sslBytes, SnifferSession* session, char* error) { - byte sessionLen; - int toRead = sizeof(ProtocolVersion) + RAN_LEN + ENUM_LEN; + byte bLen; + word16 len; + int toRead = sizeof(ProtocolVersion) + RAN_LEN + ENUM_LEN; session->flags.clientHello = 1; /* don't process again */ @@ -1195,14 +1267,16 @@ static int ProcessClientHello(const byte* input, int* sslBytes, *sslBytes -= RAN_LEN; /* store session in case trying to resume */ - sessionLen = *input++; - if (sessionLen) { + bLen = *input++; + *sslBytes -= ENUM_LEN; + if (bLen) { if (ID_LEN > *sslBytes) { SetError(CLIENT_HELLO_INPUT_STR, error, session, FATAL_ERROR_STATE); return -1; } Trace(CLIENT_RESUME_TRY_STR); XMEMCPY(session->sslClient->arrays.sessionID, input, ID_LEN); + session->sslClient->options.haveSessionId = 1; } #ifdef SHOW_SECRETS { @@ -1213,11 +1287,142 @@ static int ProcessClientHello(const byte* input, int* sslBytes, printf("\n"); } #endif + + input += bLen; + *sslBytes -= bLen; + + /* skip cipher suites */ + /* make sure can read len */ + if (SUITE_LEN > *sslBytes) { + SetError(CLIENT_HELLO_INPUT_STR, error, session, FATAL_ERROR_STATE); + return -1; + } + len = (input[0] << 8) | input[1]; + input += SUITE_LEN; + *sslBytes -= SUITE_LEN; + /* make sure can read suites + comp len */ + if (len + ENUM_LEN > *sslBytes) { + SetError(CLIENT_HELLO_INPUT_STR, error, session, FATAL_ERROR_STATE); + return -1; + } + input += len; + *sslBytes -= len; + + /* skip compression */ + bLen = *input++; + *sslBytes -= ENUM_LEN; + /* make sure can read len */ + if (bLen > *sslBytes) { + SetError(CLIENT_HELLO_INPUT_STR, error, session, FATAL_ERROR_STATE); + return -1; + } + input += bLen; + *sslBytes -= bLen; + + if (*sslBytes == 0) { + /* no extensions */ + return 0; + } + /* skip extensions until session ticket */ + /* make sure can read len */ + if (SUITE_LEN > *sslBytes) { + SetError(CLIENT_HELLO_INPUT_STR, error, session, FATAL_ERROR_STATE); + return -1; + } + len = (input[0] << 8) | input[1]; + input += SUITE_LEN; + *sslBytes -= SUITE_LEN; + /* make sure can read through all extensions */ + if (len > *sslBytes) { + SetError(CLIENT_HELLO_INPUT_STR, error, session, FATAL_ERROR_STATE); + return -1; + } + + while (len > EXT_TYPE_SZ + LENGTH_SZ) { + byte extType[EXT_TYPE_SZ]; + word16 extLen; + + extType[0] = input[0]; + extType[1] = input[1]; + input += EXT_TYPE_SZ; + *sslBytes -= EXT_TYPE_SZ; + + extLen = (input[0] << 8) | input[1]; + input += LENGTH_SZ; + *sslBytes -= LENGTH_SZ; + + /* make sure can read through individual extension */ + if (extLen > *sslBytes) { + SetError(CLIENT_HELLO_INPUT_STR, error, session, FATAL_ERROR_STATE); + return -1; + } + + if (extType[0] == 0x00 && extType[1] == TICKET_EXT_ID) { + + /* make sure can read through ticket if there is a non blank one */ + if (extLen && extLen < ID_LEN) { + SetError(CLIENT_HELLO_INPUT_STR, error, session, + FATAL_ERROR_STATE); + return -1; + } + + if (extLen) { + if (session->ticketID == 0) { + session->ticketID = (byte*)malloc(ID_LEN); + if (session->ticketID == 0) { + SetError(MEMORY_STR, error, session, + FATAL_ERROR_STATE); + return -1; + } + } + XMEMCPY(session->ticketID, input + extLen - ID_LEN, ID_LEN); + } + } + + input += extLen; + *sslBytes -= extLen; + len -= extLen + EXT_TYPE_SZ + LENGTH_SZ; + } + return 0; } +/* Process Finished */ +static int ProcessFinished(const byte* input, int* sslBytes, + SnifferSession* session, char* error) +{ + SSL* ssl; + word32 inOutIdx = 0; + int ret; + + if (session->flags.side == SERVER_END) + ssl = session->sslServer; + else + ssl = session->sslClient; + ret = DoFinished(ssl, input, &inOutIdx, SNIFF); + *sslBytes -= (int)inOutIdx; + + if (ret < 0) { + SetError(BAD_FINISHED_MSG, error, session, FATAL_ERROR_STATE); + return ret; + } + + if (ret == 0 && session->flags.cached == 0) { + if (session->sslServer->options.haveSessionId) { + CYASSL_SESSION* sess = GetSession(session->sslServer, NULL); + if (sess == NULL) + AddSession(session->sslServer); /* don't re add */ + session->flags.cached = 1; + } + } + + + return ret; +} + + /* Process HandShake input */ static int DoHandShake(const byte* input, int* sslBytes, SnifferSession* session, char* error) @@ -1245,6 +1450,13 @@ static int DoHandShake(const byte* input, int* sslBytes, case hello_verify_request: Trace(GOT_HELLO_VERIFY_STR); break; + case hello_request: + Trace(GOT_HELLO_REQUEST_STR); + break; + case session_ticket: + Trace(GOT_SESSION_TICKET_STR); + ret = ProcessSessionTicket(input, sslBytes, session, error); + break; case server_hello: Trace(GOT_SERVER_HELLO_STR); ret = ProcessServerHello(input, sslBytes, session, error); @@ -1263,22 +1475,7 @@ static int DoHandShake(const byte* input, int* sslBytes, break; case finished: Trace(GOT_FINISHED_STR); - { - SSL* ssl; - word32 inOutIdx = 0; - - if (session->flags.side == SERVER_END) - ssl = session->sslServer; - else - ssl = session->sslClient; - ret = DoFinished(ssl, input, &inOutIdx, SNIFF); - - if (ret == 0 && session->flags.cached == 0) { - session->sslServer->options.haveSessionId = 1; - AddSession(session->sslServer); - session->flags.cached = 1; - } - } + ret = ProcessFinished(input, sslBytes, session, error); break; case client_hello: Trace(GOT_CLIENT_HELLO_STR); @@ -1333,6 +1530,10 @@ static void Decrypt(SSL* ssl, byte* output, const byte* input, word32 sz) RabbitProcess(&ssl->decrypt.rabbit, output, input, sz); break; #endif + + default: + Trace(BAD_DECRYPT_TYPE); + break; } } @@ -1648,16 +1849,6 @@ static int CheckSession(IpInfo* ipInfo, TcpInfo* tcpInfo, int sslBytes, } -#ifndef min - -static INLINE word32 min(word32 a, word32 b) -{ - return a > b ? b : a; -} - -#endif - - /* Create a Packet Buffer from *begin - end, adjust new *begin and bytesLeft */ static PacketBuffer* CreateBuffer(word32* begin, word32 end, const byte* data, int* bytesLeft) diff --git a/sslSniffer/sslSnifferTest/snifftest.c b/sslSniffer/sslSnifferTest/snifftest.c index a1434e1bf..0b844b2d2 100755 --- a/sslSniffer/sslSnifferTest/snifftest.c +++ b/sslSniffer/sslSnifferTest/snifftest.c @@ -49,6 +49,7 @@ int main() #include /* pcap stuff */ #include /* printf */ #include /* EXIT_SUCCESS */ +#include /* strcmp */ #include /* signal */ #include @@ -71,7 +72,7 @@ pcap_if_t *alldevs; static void sig_handler(const int sig) { - printf("SIGINT handled.\n"); + printf("SIGINT handled = %d.\n", sig); if (pcap) pcap_close(pcap); pcap_freealldevs(alldevs); @@ -82,7 +83,7 @@ static void sig_handler(const int sig) } -void err_sys(const char* msg) +static void err_sys(const char* msg) { fprintf(stderr, "%s\n", msg); exit(EXIT_FAILURE); @@ -96,7 +97,7 @@ void err_sys(const char* msg) #endif -char* iptos(unsigned int addr) +static char* iptos(unsigned int addr) { static char output[32]; byte *p = (byte*)&addr; @@ -112,11 +113,12 @@ int main(int argc, char** argv) int ret; int inum; int port; + int saveFile = 0; int i = 0; + int frame = ETHER_IF_FRAME_LEN; char err[PCAP_ERRBUF_SIZE]; char filter[32]; - char loopback = 0; - char *server = NULL; + const char *server = NULL; struct bpf_program fp; pcap_if_t *d; pcap_addr_t *a; @@ -124,7 +126,7 @@ int main(int argc, char** argv) signal(SIGINT, sig_handler); #ifndef _WIN32 - ssl_InitSniffer(); + ssl_InitSniffer(); /* dll load on Windows */ #endif ssl_Trace("./tracefile.txt", err); @@ -159,9 +161,6 @@ int main(int argc, char** argv) if (pcap == NULL) printf("pcap_create failed %s\n", err); - if (d->flags & PCAP_IF_LOOPBACK) - loopback = 1; - /* get an IPv4 address */ for (a = d->addresses; a; a = a->next) { switch(a->addr->sa_family) @@ -171,6 +170,9 @@ int main(int argc, char** argv) iptos(((struct sockaddr_in *)a->addr)->sin_addr.s_addr); printf("server = %s\n", server); break; + + default: + break; } } if (server == NULL) @@ -208,6 +210,7 @@ int main(int argc, char** argv) FILETYPE_PEM, NULL, err); } else if (argc >= 3) { + saveFile = 1; pcap = pcap_open_offline(argv[1], err); if (pcap == NULL) { printf("pcap_open_offline failed %s\n", err); @@ -238,6 +241,9 @@ int main(int argc, char** argv) if (ret != 0) err_sys(err); + if (pcap_datalink(pcap) == 0) + frame = LOCAL_IF_FRAME_LEN; + while (1) { struct pcap_pkthdr header; const unsigned char* packet = pcap_next(pcap, &header); @@ -246,9 +252,6 @@ int main(int argc, char** argv) byte data[65535]; if (header.caplen > 40) { /* min ip(20) + min tcp(20) */ - int frame = ETHER_IF_FRAME_LEN; - if (loopback) - frame = LOCAL_IF_FRAME_LEN; packet += frame; header.caplen -= frame; } @@ -263,6 +266,8 @@ int main(int argc, char** argv) printf("SSL App Data:%s\n", data); } } + else if (saveFile) + break; /* we're done reading file */ } return EXIT_SUCCESS;