From 543108bdcc752fcf8dff23f2b7040e7bf5bb507f Mon Sep 17 00:00:00 2001 From: toddouska Date: Fri, 15 Mar 2013 13:17:05 -0700 Subject: [PATCH] add memory tracker to example client and server if using default memory cbs --- cyassl/test.h | 108 +++++++++++++++++++++++++++++++++++++++ examples/client/client.c | 27 +++++++++- examples/server/server.c | 26 +++++++++- 3 files changed, 158 insertions(+), 3 deletions(-) diff --git a/cyassl/test.h b/cyassl/test.h index 5c0f15f67..73682320f 100644 --- a/cyassl/test.h +++ b/cyassl/test.h @@ -937,5 +937,113 @@ static INLINE int CurrentDir(const char* str) #endif /* USE_WINDOWS_API */ + +#ifdef USE_CYASSL_MEMORY + + typedef struct memoryStats { + size_t totalAllocs; /* number of allocations */ + size_t totalBytes; /* total number of bytes allocated */ + size_t peakBytes; /* concurrent max bytes */ + size_t currentBytes; /* total current bytes in use */ + } memoryStats; + + typedef struct memHint { + size_t thisSize; /* size of this memory */ + void* thisMemory; /* actual memory for user */ + } memHint; + + typedef struct memoryTrack { + union { + memHint hint; + byte alignit[16]; /* make sure we have strong alignment */ + } u; + } memoryTrack; + + #if defined(CYASSL_TRACK_MEMORY) + #define DO_MEM_STATS + static memoryStats ourMemStats; + #endif + + static INLINE void* TrackMalloc(size_t sz) + { + memoryTrack* mt; + + if (sz == 0) + return NULL; + + mt = (memoryTrack*)malloc(sizeof(memoryTrack) + sz); + if (mt == NULL) + return NULL; + + mt->u.hint.thisSize = sz; + mt->u.hint.thisMemory = (byte*)mt + sizeof(memoryTrack); + +#ifdef DO_MEM_STATS + ourMemStats.totalAllocs++; + ourMemStats.totalBytes += sz; + ourMemStats.currentBytes += sz; + if (ourMemStats.currentBytes > ourMemStats.peakBytes) + ourMemStats.peakBytes = ourMemStats.currentBytes; +#endif + + return mt->u.hint.thisMemory; + } + + + static INLINE void TrackFree(void* ptr) + { + memoryTrack* mt; + + if (ptr == NULL) + return; + + mt = (memoryTrack*)((byte*)ptr - sizeof(memoryTrack)); + +#ifdef DO_MEM_STATS + ourMemStats.currentBytes -= mt->u.hint.thisSize; +#endif + + return free(mt); + } + + + static INLINE void* TrackRealloc(void* ptr, size_t sz) + { + void* ret = TrackMalloc(sz); + + if (ret && ptr) + memcpy(ret, ptr, sz); + + if (ret) + TrackFree(ptr); + + return ret; + } + + static INLINE void InitMemoryTracker(void) + { + if (CyaSSL_SetAllocators(TrackMalloc, TrackFree, TrackRealloc) != 0) + err_sys("CyaSSL SetAllocators failed for track memory"); + + #ifdef DO_MEM_STATS + ourMemStats.totalAllocs = 0; + ourMemStats.totalBytes = 0; + ourMemStats.peakBytes = 0; + ourMemStats.currentBytes = 0; + #endif + } + + static INLINE void ShowMemoryTracker(void) + { + #ifdef DO_MEM_STATS + printf("total Allocs = %9ld\n", ourMemStats.totalAllocs); + printf("total Bytes = %9ld\n", ourMemStats.totalBytes); + printf("peak Bytes = %9ld\n", ourMemStats.peakBytes); + printf("current Bytes = %9ld\n", ourMemStats.currentBytes); + #endif + } + +#endif /* USE_CYASSL_MEMORY */ + #endif /* CyaSSL_TEST_H */ diff --git a/examples/client/client.c b/examples/client/client.c index 3a2601292..9411c55a5 100644 --- a/examples/client/client.c +++ b/examples/client/client.c @@ -23,6 +23,11 @@ #include #endif +#if !defined(CYASSL_TRACK_MEMORY) && !defined(NO_MAIN_DRIVER) + /* in case memory tracker wants stats */ + #define CYASSL_TRACK_MEMORY +#endif + #include #include @@ -35,6 +40,7 @@ Timeval timeout; #endif + static void NonBlockingSSL_Connect(CYASSL* ssl) { #ifndef CYASSL_CALLBACKS @@ -97,6 +103,7 @@ static void Usage(void) printf("-A Certificate Authority file, default %s\n", caCert); printf("-b Benchmark connections and print stats\n"); printf("-s Use pre Shared keys\n"); + printf("-t Track CyaSSL memory use\n"); printf("-d Disable peer checks\n"); printf("-g Send server HTTP GET\n"); printf("-u Use UDP DTLS," @@ -139,6 +146,7 @@ void client_test(void* args) int doPeerCheck = 1; int nonBlocking = 0; int resumeSession = 0; + int trackMemory = 0; char* cipherList = NULL; char* verifyCert = (char*)caCert; char* ourCert = (char*)cliCert; @@ -158,7 +166,7 @@ void client_test(void* args) (void)session; (void)sslResume; - while ((ch = mygetopt(argc, argv, "?gdusmNrh:p:v:l:A:c:k:b:")) != -1) { + while ((ch = mygetopt(argc, argv, "?gdusmNrth:p:v:l:A:c:k:b:")) != -1) { switch (ch) { case '?' : Usage(); @@ -180,6 +188,12 @@ void client_test(void* args) usePsk = 1; break; + case 't' : + #ifdef USE_CYASSL_MEMORY + trackMemory = 1; + #endif + break; + case 'm' : matchName = 1; break; @@ -257,6 +271,11 @@ void client_test(void* args) } } +#ifdef USE_CYASSL_MEMORY + if (trackMemory) + InitMemoryTracker(); +#endif + switch (version) { #ifndef NO_OLD_TLS case 0: @@ -563,6 +582,11 @@ void client_test(void* args) CyaSSL_CTX_free(ctx); ((func_args*)args)->return_code = 0; + +#ifdef USE_CYASSL_MEMORY + if (trackMemory) + ShowMemoryTracker(); +#endif /* USE_CYASSL_MEMORY */ } @@ -624,4 +648,3 @@ void client_test(void* args) #endif - diff --git a/examples/server/server.c b/examples/server/server.c index 324fb41a8..d205450b3 100644 --- a/examples/server/server.c +++ b/examples/server/server.c @@ -23,6 +23,11 @@ #include #endif +#if !defined(CYASSL_TRACK_MEMORY) && !defined(NO_MAIN_DRIVER) + /* in case memory tracker wants stats */ + #define CYASSL_TRACK_MEMORY +#endif + #include #include @@ -98,6 +103,7 @@ static void Usage(void) printf("-d Disable client cert check\n"); printf("-b Bind to any interface instead of localhost only\n"); printf("-s Use pre Shared keys\n"); + printf("-t Track CyaSSL memory use\n"); printf("-u Use UDP DTLS," " add -v 2 for DTLSv1 (default), -v 3 for DTLSv1.2\n"); printf("-N Use Non-blocking sockets\n"); @@ -125,6 +131,7 @@ THREAD_RETURN CYASSL_THREAD server_test(void* args) int doDTLS = 0; int useNtruKey = 0; int nonBlocking = 0; + int trackMemory = 0; char* cipherList = NULL; char* verifyCert = (char*)cliCert; char* ourCert = (char*)svrCert; @@ -140,7 +147,7 @@ THREAD_RETURN CYASSL_THREAD server_test(void* args) ourKey = (char*)eccKey; #endif - while ((ch = mygetopt(argc, argv, "?dbsnNup:v:l:A:c:k:")) != -1) { + while ((ch = mygetopt(argc, argv, "?dbstnNup:v:l:A:c:k:")) != -1) { switch (ch) { case '?' : Usage(); @@ -158,6 +165,12 @@ THREAD_RETURN CYASSL_THREAD server_test(void* args) usePsk = 1; break; + case 't' : + #ifdef USE_CYASSL_MEMORY + trackMemory = 1; + #endif + break; + case 'n' : useNtruKey = 1; break; @@ -222,6 +235,11 @@ THREAD_RETURN CYASSL_THREAD server_test(void* args) } } +#ifdef USE_CYASSL_MEMORY + if (trackMemory) + InitMemoryTracker(); +#endif + switch (version) { #ifndef NO_OLD_TLS case 0: @@ -400,6 +418,12 @@ THREAD_RETURN CYASSL_THREAD server_test(void* args) CloseSocket(clientfd); ((func_args*)args)->return_code = 0; + +#ifdef USE_CYASSL_MEMORY + if (trackMemory) + ShowMemoryTracker(); +#endif /* USE_CYASSL_MEMORY */ + return 0; }