diff --git a/src/internal.c b/src/internal.c index e1a042e4d..a8a9ddb97 100644 --- a/src/internal.c +++ b/src/internal.c @@ -13349,10 +13349,34 @@ int SendUserAuthKeyboardRequest(WOLFSSH* ssh, WS_UserAuthData* authData) } if (ret == WS_SUCCESS) { - ret = ssh->ctx->keyboardAuthCb(&authData->sf.keyboard, - ssh->keyboardAuthCtx); + /* Set responseCount to 0 to indicate this is a prompt setup call */ + authData->sf.keyboard.responseCount = 0; + + /* First try using userAuthCb if it's set */ + if (ssh->ctx->userAuthCb != NULL) { + WLOG(WS_LOG_DEBUG, "SUAKR: Calling userAuthCb for prompt setup"); + ret = ssh->ctx->userAuthCb(WOLFSSH_USERAUTH_KEYBOARD, + authData, ssh->userAuthCtx); + + /* If userAuthCb doesn't return SUCCESS_ANOTHER, fall back to keyboardAuthCb */ + if (ret != WOLFSSH_USERAUTH_SUCCESS_ANOTHER) { + WLOG(WS_LOG_DEBUG, "SUAKR: userAuthCb didn't return SUCCESS_ANOTHER, falling back"); + ret = ssh->ctx->keyboardAuthCb(&authData->sf.keyboard, + ssh->keyboardAuthCtx); + } + else { + WLOG(WS_LOG_DEBUG, "SUAKR: userAuthCb returned SUCCESS_ANOTHER, proceeding"); + ret = WS_SUCCESS; + } + } + else { + /* Fall back to keyboardAuthCb if userAuthCb is not set */ + ret = ssh->ctx->keyboardAuthCb(&authData->sf.keyboard, + ssh->keyboardAuthCtx); + } } + /* Only check for NULL pointers if we actually have prompts */ if (authData->sf.keyboard.promptCount > 0 && (authData->sf.keyboard.prompts == NULL || authData->sf.keyboard.promptLengths == NULL || diff --git a/tests/auth.c b/tests/auth.c index 147db8292..f10ccf684 100644 --- a/tests/auth.c +++ b/tests/auth.c @@ -156,6 +156,7 @@ word32 kbResponseCount; byte kbMultiRound = 0; byte currentRound = 0; byte unbalanced = 0; +byte useUserAuthCb = 0; /* Flag to test userAuthCb for keyboard-interactive */ WS_UserAuthData_Keyboard promptData; @@ -223,38 +224,73 @@ static int load_key(byte isEcc, byte* buf, word32 bufSz) static int serverUserAuth(byte authType, WS_UserAuthData* authData, void* ctx) { (void) ctx; - if (authType != WOLFSSH_USERAUTH_KEYBOARD) { - return WOLFSSH_USERAUTH_FAILURE; - } - - if (authData->sf.keyboard.responseCount != kbResponseCount) { - return WOLFSSH_USERAUTH_FAILURE; - } - - for (word32 resp = 0; resp < kbResponseCount; resp++) { - if (authData->sf.keyboard.responseLengths[resp] != - kbResponseLengths[resp]) { + + /* Handle keyboard-interactive auth */ + if (authType == WOLFSSH_USERAUTH_KEYBOARD) { + /* If responseCount is 0, this is a prompt setup call */ + if (authData->sf.keyboard.responseCount == 0) { + /* Set up prompts - only copy the necessary fields, not the entire structure */ + authData->sf.keyboard.promptCount = promptData.promptCount; + authData->sf.keyboard.promptName = promptData.promptName; + authData->sf.keyboard.promptNameSz = promptData.promptNameSz; + authData->sf.keyboard.promptInstruction = promptData.promptInstruction; + authData->sf.keyboard.promptInstructionSz = promptData.promptInstructionSz; + authData->sf.keyboard.promptLanguage = promptData.promptLanguage; + authData->sf.keyboard.promptLanguageSz = promptData.promptLanguageSz; + authData->sf.keyboard.prompts = promptData.prompts; + authData->sf.keyboard.promptLengths = promptData.promptLengths; + authData->sf.keyboard.promptEcho = promptData.promptEcho; + + /* Return SUCCESS_ANOTHER to proceed with sending prompts */ + if (useUserAuthCb) { + return WOLFSSH_USERAUTH_SUCCESS_ANOTHER; + } + /* When not testing userAuthCb, return FAILURE to fall back to keyboardAuthCb */ return WOLFSSH_USERAUTH_FAILURE; - } - if (WSTRCMP((const char*)authData->sf.keyboard.responses[resp], - (const char*)kbResponses[resp]) != 0) { + + /* Validate responses */ + if (authData->sf.keyboard.responseCount != kbResponseCount) { return WOLFSSH_USERAUTH_FAILURE; } + + for (word32 resp = 0; resp < kbResponseCount; resp++) { + if (authData->sf.keyboard.responseLengths[resp] != + kbResponseLengths[resp]) { + return WOLFSSH_USERAUTH_FAILURE; + + } + if (WSTRCMP((const char*)authData->sf.keyboard.responses[resp], + (const char*)kbResponses[resp]) != 0) { + return WOLFSSH_USERAUTH_FAILURE; + } + } + if (kbMultiRound && currentRound == 0) { + currentRound++; + kbResponses[0] = (byte*)testText2; + kbResponseLengths[0] = 8; + return WOLFSSH_USERAUTH_SUCCESS_ANOTHER; + } + return WOLFSSH_USERAUTH_SUCCESS; } - if (kbMultiRound && currentRound == 0) { - currentRound++; - kbResponses[0] = (byte*)testText2; - kbResponseLengths[0] = 8; - return WOLFSSH_USERAUTH_SUCCESS_ANOTHER; - } - return WOLFSSH_USERAUTH_SUCCESS; + + return WOLFSSH_USERAUTH_FAILURE; } static int serverKeyboardCallback(WS_UserAuthData_Keyboard *kbAuth, void *ctx) { (void) ctx; - WMEMCPY(kbAuth, &promptData, sizeof(WS_UserAuthData_Keyboard)); + /* Copy individual fields instead of the entire structure to avoid memory issues */ + kbAuth->promptCount = promptData.promptCount; + kbAuth->promptName = promptData.promptName; + kbAuth->promptNameSz = promptData.promptNameSz; + kbAuth->promptInstruction = promptData.promptInstruction; + kbAuth->promptInstructionSz = promptData.promptInstructionSz; + kbAuth->promptLanguage = promptData.promptLanguage; + kbAuth->promptLanguageSz = promptData.promptLanguageSz; + kbAuth->prompts = promptData.prompts; + kbAuth->promptLengths = promptData.promptLengths; + kbAuth->promptEcho = promptData.promptEcho; return WS_SUCCESS; } @@ -332,7 +368,12 @@ static THREAD_RETURN WOLFSSH_THREAD server_thread(void* args) } wolfSSH_SetUserAuth(ctx, serverUserAuth); - wolfSSH_SetKeyboardAuthPrompts(ctx, serverKeyboardCallback); + + /* Only set keyboard auth callback when not testing userAuthCb */ + if (!useUserAuthCb) { + wolfSSH_SetKeyboardAuthPrompts(ctx, serverKeyboardCallback); + } + ssh = wolfSSH_new(ctx); if (ssh == NULL) { ES_ERROR("Couldn't allocate SSH data.\n"); @@ -394,16 +435,24 @@ static int keyboardUserAuth(byte authType, WS_UserAuthData* authData, void* ctx) if (authType == WOLFSSH_USERAUTH_KEYBOARD) { AssertIntEQ(kbResponseCount, authData->sf.keyboard.promptCount); - for (word32 prompt = 0; prompt < kbResponseCount; prompt++) { - AssertStrEQ("Password: ", authData->sf.keyboard.prompts[prompt]); + + /* Only check prompts if there are any */ + if (kbResponseCount > 0) { + for (word32 prompt = 0; prompt < kbResponseCount; prompt++) { + AssertStrEQ("Password: ", authData->sf.keyboard.prompts[prompt]); + } } authData->sf.keyboard.responseCount = kbResponseCount; if (unbalanced) { authData->sf.keyboard.responseCount++; } - authData->sf.keyboard.responseLengths = kbResponseLengths; - authData->sf.keyboard.responses = (byte**)kbResponses; + + /* Only set response pointers if there are responses */ + if (kbResponseCount > 0) { + authData->sf.keyboard.responseLengths = kbResponseLengths; + authData->sf.keyboard.responses = (byte**)kbResponses; + } ret = WS_SUCCESS; } return ret; @@ -574,6 +623,34 @@ static void test_unbalanced_client_KeyboardInteractive(void) test_client(); unbalanced = 0; } + +static void test_userAuthCb_KeyboardInteractive(void) +{ + printf("Testing keyboard-interactive auth via userAuthCb\n"); + kbResponses[0] = (byte*)testText1; + kbResponseLengths[0] = 4; + kbResponseCount = 1; + useUserAuthCb = 1; + + test_client(); + useUserAuthCb = 0; +} + +static void test_userAuthCb_multi_round_KeyboardInteractive(void) +{ + printf("Testing multiple prompt rounds via userAuthCb\n"); + kbResponses[0] = (byte*)testText1; + kbResponseLengths[0] = 4; + kbResponseCount = 1; + kbMultiRound = 1; + useUserAuthCb = 1; + + test_client(); + AssertIntEQ(currentRound, 1); + currentRound = 0; + kbMultiRound = 0; + useUserAuthCb = 0; +} #endif /* WOLFSSH_TEST_BLOCK */ int wolfSSH_AuthTest(int argc, char** argv) @@ -603,6 +680,8 @@ int wolfSSH_AuthTest(int argc, char** argv) test_multi_prompt_KeyboardInteractive(); test_multi_round_KeyboardInteractive(); test_unbalanced_client_KeyboardInteractive(); + test_userAuthCb_KeyboardInteractive(); + test_userAuthCb_multi_round_KeyboardInteractive(); AssertIntEQ(wolfSSH_Cleanup(), WS_SUCCESS); diff --git a/wolfssh/ssh.h b/wolfssh/ssh.h index 9a49b4403..c0ead9078 100644 --- a/wolfssh/ssh.h +++ b/wolfssh/ssh.h @@ -360,6 +360,13 @@ typedef struct WS_UserAuthData { } sf; } WS_UserAuthData; +/* User Authentication callback + * For keyboard-interactive authentication: + * - When responseCount is 0, the callback is being called to set up prompts + * Return WOLFSSH_USERAUTH_SUCCESS_ANOTHER to proceed with sending prompts + * - When responseCount > 0, the callback is being called to validate responses + * Return WOLFSSH_USERAUTH_SUCCESS_ANOTHER to request more prompts + */ typedef int (*WS_CallbackUserAuth)(byte, WS_UserAuthData*, void*); WOLFSSH_API void wolfSSH_SetUserAuth(WOLFSSH_CTX*, WS_CallbackUserAuth); typedef int (*WS_CallbackUserAuthTypes)(WOLFSSH* ssh, void* ctx);