diff --git a/src/internal.c b/src/internal.c index 238d4c5..e1f03d4 100644 --- a/src/internal.c +++ b/src/internal.c @@ -447,8 +447,13 @@ WOLFSSH* SshInit(WOLFSSH* ssh, WOLFSSH_CTX* ctx) ssh->ctx = ctx; ssh->error = WS_SUCCESS; +#ifdef USE_WINDOWS_API + ssh->rfd = INVALID_SOCKET; + ssh->wfd = INVALID_SOCKET; +#else ssh->rfd = -1; /* set to invalid */ ssh->wfd = -1; /* set to invalid */ +#endif ssh->ioReadCtx = &ssh->rfd; /* prevent invalid access if not correctly */ ssh->ioWriteCtx = &ssh->wfd; /* set */ ssh->highwaterMark = ctx->highwaterMark; diff --git a/src/io.c b/src/io.c index 6f4900a..24ef4f5 100644 --- a/src/io.c +++ b/src/io.c @@ -1,4 +1,4 @@ -/* io.c +/* io.c * * Copyright (C) 2014-2016 wolfSSL Inc. * @@ -264,10 +264,8 @@ void* wolfSSH_GetIOWriteCtx(WOLFSSH* ssh) #endif -/* Translates return codes returned from - * send() and recv() if need be. - */ -static INLINE int TranslateReturnCode(int old, int sd) +/* Translates return codes returned from send() and recv() if need be. */ +static INLINE int TranslateReturnCode(int old, WS_SOCKET_T sd) { (void)sd; @@ -301,7 +299,7 @@ static INLINE int TranslateReturnCode(int old, int sd) static INLINE int LastError(void) { -#ifdef USE_WINDOWS_API +#ifdef USE_WINDOWS_API return WSAGetLastError(); #elif defined(EBSNET) return xn_getlasterror(); @@ -317,7 +315,7 @@ int wsEmbedRecv(WOLFSSH* ssh, void* data, word32 sz, void* ctx) { int recvd; int err; - int sd = *(int*)ctx; + WS_SOCKET_T sd = *(WS_SOCKET_T*)ctx; char* buf = (char*)data; #ifdef WOLFSSH_TEST_BLOCK @@ -375,7 +373,7 @@ int wsEmbedRecv(WOLFSSH* ssh, void* data, word32 sz, void* ctx) */ int wsEmbedSend(WOLFSSH* ssh, void* data, word32 sz, void* ctx) { - int sd = *(int*)ctx; + WS_SOCKET_T sd = *(WS_SOCKET_T*)ctx; int sent; int err; char* buf = (char*)data; diff --git a/src/ssh.c b/src/ssh.c index 3fca55e..d76d9a8 100644 --- a/src/ssh.c +++ b/src/ssh.c @@ -139,7 +139,7 @@ void wolfSSH_free(WOLFSSH* ssh) } -int wolfSSH_set_fd(WOLFSSH* ssh, int fd) +int wolfSSH_set_fd(WOLFSSH* ssh, WS_SOCKET_T fd) { WLOG(WS_LOG_DEBUG, "Entering wolfSSH_set_fd()"); @@ -156,14 +156,18 @@ int wolfSSH_set_fd(WOLFSSH* ssh, int fd) } -int wolfSSH_get_fd(const WOLFSSH* ssh) +WS_SOCKET_T wolfSSH_get_fd(const WOLFSSH* ssh) { WLOG(WS_LOG_DEBUG, "Entering wolfSSH_get_fd()"); if (ssh) return ssh->rfd; +#ifdef USE_WINDOWS_API + return INVALID_SOCKET; +#else return WS_BAD_ARGUMENT; +#endif } diff --git a/tests/api.c b/tests/api.c index 4d07080..d5db677 100644 --- a/tests/api.c +++ b/tests/api.c @@ -127,6 +127,29 @@ static void test_client_wolfSSH_new(void) } +static void test_wolfSSH_set_fd(void) +{ + WOLFSSH_CTX* ctx; + WOLFSSH* ssh; + WS_SOCKET_T fd = 23, check; + + AssertNotNull(ctx = wolfSSH_CTX_new(WOLFSSH_ENDPOINT_CLIENT, NULL)); + AssertNotNull(ssh = wolfSSH_new(ctx)); + + AssertIntNE(WS_SUCCESS, wolfSSH_set_fd(NULL, fd)); + check = wolfSSH_get_fd(NULL); + AssertFalse(WS_SUCCESS == check); + + AssertIntEQ(WS_SUCCESS, wolfSSH_set_fd(ssh, fd)); + check = wolfSSH_get_fd(ssh); + AssertTrue(fd == check); + AssertTrue(0 != check); + + wolfSSH_free(ssh); + wolfSSH_CTX_free(ctx); +} + + static void test_wolfSSH_SetUsername(void) { #ifndef WOLFSSH_NO_CLIENT @@ -310,6 +333,7 @@ int main(void) test_wolfSSH_CTX_new(); test_server_wolfSSH_new(); test_client_wolfSSH_new(); + test_wolfSSH_set_fd(); test_wolfSSH_SetUsername(); test_wolfSSH_ConvertConsole(); diff --git a/wolfssh/internal.h b/wolfssh/internal.h index a6608f0..14283b6 100644 --- a/wolfssh/internal.h +++ b/wolfssh/internal.h @@ -283,8 +283,8 @@ struct WS_SFTP_RENAME_STATE; struct WOLFSSH { WOLFSSH_CTX* ctx; /* owner context */ int error; - int rfd; - int wfd; + WS_SOCKET_T rfd; + WS_SOCKET_T wfd; void* ioReadCtx; /* I/O Read Context handle */ void* ioWriteCtx; /* I/O Write Context handle */ int rflags; /* optional read flags */ diff --git a/wolfssh/port.h b/wolfssh/port.h index 7998c29..593a666 100644 --- a/wolfssh/port.h +++ b/wolfssh/port.h @@ -607,6 +607,13 @@ extern "C" { #endif +#if defined(USE_WINDOWS_API) + #define WS_SOCKET_T SOCKET +#else + #define WS_SOCKET_T int +#endif + + #if !defined(NO_TERMIOS) && defined(WOLFSSH_TERM) #if !defined(USE_WINDOWS_API) && !defined(MICROCHIP_PIC32) #include diff --git a/wolfssh/ssh.h b/wolfssh/ssh.h index 709267a..1b3a228 100644 --- a/wolfssh/ssh.h +++ b/wolfssh/ssh.h @@ -64,8 +64,8 @@ WOLFSSH_API void wolfSSH_free(WOLFSSH*); WOLFSSH_API int wolfSSH_worker(WOLFSSH*, word32*); -WOLFSSH_API int wolfSSH_set_fd(WOLFSSH*, int); -WOLFSSH_API int wolfSSH_get_fd(const WOLFSSH*); +WOLFSSH_API int wolfSSH_set_fd(WOLFSSH*, WS_SOCKET_T); +WOLFSSH_API WS_SOCKET_T wolfSSH_get_fd(const WOLFSSH*); /* data high water mark functions */ WOLFSSH_API int wolfSSH_SetHighwater(WOLFSSH*, word32);