From b2d19fb0987e381a1418d26a2ef07628ede32c64 Mon Sep 17 00:00:00 2001 From: Yosuke Shimizu Date: Fri, 24 Apr 2026 13:42:33 +0900 Subject: [PATCH] Add service-name check and regress test --- src/internal.c | 20 ++++++- tests/unit.c | 127 +++++++++++++++++++++++++++++++++++++++++++++ wolfssh/internal.h | 2 + 3 files changed, 147 insertions(+), 2 deletions(-) diff --git a/src/internal.c b/src/internal.c index d4ddb71be..b6eb72f45 100644 --- a/src/internal.c +++ b/src/internal.c @@ -8367,6 +8367,7 @@ static int DoUserAuthRequest(WOLFSSH* ssh, word32 begin; int ret = WS_SUCCESS; byte authNameId; + byte serviceValid = 1; WS_UserAuthData authData; WLOG(WS_LOG_DEBUG, "Entering DoUserAuthRequest()"); @@ -8401,10 +8402,19 @@ static int DoUserAuthRequest(WOLFSSH* ssh, authData.serviceName = buf + begin; begin += authData.serviceNameSz; - ret = GetSize(&authData.authNameSz, buf, len, &begin); + if (NameToId((const char*)authData.serviceName, authData.serviceNameSz) + != ID_SERVICE_CONNECTION) { + WLOG(WS_LOG_DEBUG, "DUAR: Invalid service name"); + serviceValid = 0; + ret = SendUserAuthFailure(ssh, 0); + *idx = len; + } + else { + ret = GetSize(&authData.authNameSz, buf, len, &begin); + } } - if (ret == WS_SUCCESS) { + if (ret == WS_SUCCESS && serviceValid) { authData.authName = buf + begin; begin += authData.authNameSz; authNameId = NameToId((char*)authData.authName, authData.authNameSz); @@ -10773,6 +10783,12 @@ int wolfSSH_TestChannelPutData(WOLFSSH_CHANNEL* channel, byte* data, { return ChannelPutData(channel, data, dataSz); } + +int wolfSSH_TestDoUserAuthRequest(WOLFSSH* ssh, byte* buf, word32 len, + word32* idx) +{ + return DoUserAuthRequest(ssh, buf, len, idx); +} #endif diff --git a/tests/unit.c b/tests/unit.c index 47c31d3c1..53a64adb2 100644 --- a/tests/unit.c +++ b/tests/unit.c @@ -918,6 +918,128 @@ static int test_DoChannelRequest(void) return result; } +/* Capture buffer for the service-name unit test. Separate from the channel- + * request capture so the two tests can run independently in any order. */ +static byte s_authSvcCapture[256]; +static word32 s_authSvcCaptureSz = 0; + +static int CaptureIoSendAuthSvc(WOLFSSH* ssh, void* buf, word32 sz, void* ctx) +{ + (void)ssh; (void)ctx; + s_authSvcCaptureSz = (sz < (word32)sizeof(s_authSvcCapture)) + ? sz : (word32)sizeof(s_authSvcCapture); + WMEMCPY(s_authSvcCapture, buf, s_authSvcCaptureSz); + return (int)sz; +} + +/* Verify DoUserAuthRequest rejects non-"ssh-connection" service names per + * RFC 4252 Section 5. For each case we assert: + * 1. ret == WS_SUCCESS (connection stays open for retry) + * 2. SSH_MSG_USERAUTH_FAILURE is actually sent (captured at packet byte 5) + * 3. *idx == len (entire payload consumed; buffer stays aligned) + * + * For invalid-service cases the auth-method field is intentionally omitted + * from the payload. DoUserAuthRequest must short-circuit at the service-name + * check and still satisfy all three assertions — proving it never tries to + * parse the missing auth-method field. If the short-circuit were absent, + * GetSize() for authNameSz would hit end-of-buffer and return WS_BUFFER_E, + * failing assertion 1. + * + * For the valid-service case, auth method "xyz-unknown" (always unsupported + * regardless of compile-time options) is included. The function reaches + * auth-method dispatch, falls to the unknown-method else-branch, and sends + * USERAUTH_FAILURE via that normal path. */ +static int test_DoUserAuthRequest_serviceName(void) +{ + WOLFSSH_CTX* ctx = NULL; + WOLFSSH* ssh = NULL; + int result = 0; + struct { + const char* svcName; + word32 svcNameSz; + const char* authMethod; /* NULL = omit field (proves short-circuit) */ + word32 authMethodSz; + int expectRet; + const char* label; + } cases[] = { + /* valid service: auth dispatch fires, fails on unknown method */ + { "ssh-connection", 14, "xyz-unknown", 11, WS_SUCCESS, + "valid svc unknown auth" }, + /* invalid service: short-circuit, auth-method field absent */ + { "ssh-agent", 9, NULL, 0, WS_SUCCESS, + "invalid ssh-agent svc" }, + { "bad", 3, NULL, 0, WS_SUCCESS, + "invalid bad svc" }, + }; + int i; + + ctx = wolfSSH_CTX_new(WOLFSSH_ENDPOINT_SERVER, NULL); + if (ctx == NULL) return -500; + wolfSSH_SetIOSend(ctx, CaptureIoSendAuthSvc); + + ssh = wolfSSH_new(ctx); + if (ssh == NULL) { wolfSSH_CTX_free(ctx); return -501; } + + for (i = 0; i < (int)(sizeof(cases)/sizeof(cases[0])); i++) { + byte buf[128]; + word32 len = 0, idx = 0; + word32 snsz = cases[i].svcNameSz; + int ret; + + s_authSvcCaptureSz = 0; + WMEMSET(s_authSvcCapture, 0, sizeof(s_authSvcCapture)); + + /* username: "user" */ + buf[len++] = 0; buf[len++] = 0; buf[len++] = 0; buf[len++] = 4; + WMEMCPY(buf + len, "user", 4); len += 4; + + /* service name */ + buf[len++] = (byte)(snsz >> 24); buf[len++] = (byte)(snsz >> 16); + buf[len++] = (byte)(snsz >> 8); buf[len++] = (byte)snsz; + WMEMCPY(buf + len, cases[i].svcName, snsz); len += snsz; + + /* auth method: omit for invalid-service cases to prove short-circuit */ + if (cases[i].authMethod != NULL) { + word32 amsz = cases[i].authMethodSz; + buf[len++] = (byte)(amsz >> 24); buf[len++] = (byte)(amsz >> 16); + buf[len++] = (byte)(amsz >> 8); buf[len++] = (byte)amsz; + WMEMCPY(buf + len, cases[i].authMethod, amsz); len += amsz; + } + + ret = wolfSSH_TestDoUserAuthRequest(ssh, buf, len, &idx); + + if (ret != cases[i].expectRet) { + printf("DoUserAuthRequest_svcName[%s]: ret=%d expected=%d\n", + cases[i].label, ret, cases[i].expectRet); + result = -502 - i; + break; + } + + /* MSGID_USERAUTH_FAILURE must be in the captured packet. */ + if (s_authSvcCaptureSz <= 5 || + s_authSvcCapture[5] != MSGID_USERAUTH_FAILURE) { + printf("DoUserAuthRequest_svcName[%s]: USERAUTH_FAILURE not sent " + "(capSz=%u byte5=0x%02x)\n", cases[i].label, + s_authSvcCaptureSz, + s_authSvcCaptureSz > 5 ? s_authSvcCapture[5] : 0); + result = -520 - i; + break; + } + + /* All cases must consume the entire payload. */ + if (idx != len) { + printf("DoUserAuthRequest_svcName[%s]: idx=%u expected len=%u\n", + cases[i].label, idx, len); + result = -510 - i; + break; + } + } + + wolfSSH_free(ssh); + wolfSSH_CTX_free(ctx); + return result; +} + #if !defined(WOLFSSH_NO_RSA) /* 2048-bit RSA private key (PKCS#1 DER). @@ -1210,6 +1332,11 @@ int wolfSSH_UnitTest(int argc, char** argv) unitResult = test_ChannelPutData(); printf("ChannelPutData: %s\n", (unitResult == 0 ? "SUCCESS" : "FAILED")); testResult = testResult || unitResult; + + unitResult = test_DoUserAuthRequest_serviceName(); + printf("DoUserAuthRequest_serviceName: %s\n", + (unitResult == 0 ? "SUCCESS" : "FAILED")); + testResult = testResult || unitResult; #endif #ifdef WOLFSSH_KEYGEN diff --git a/wolfssh/internal.h b/wolfssh/internal.h index 5616b5556..9c5ab5f98 100644 --- a/wolfssh/internal.h +++ b/wolfssh/internal.h @@ -1338,6 +1338,8 @@ enum WS_MessageIdLimits { WOLFSSH_API int wolfSSH_TestDoKexDhInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx); WOLFSSH_API int wolfSSH_TestChannelPutData(WOLFSSH_CHANNEL*, byte*, word32); + WOLFSSH_API int wolfSSH_TestDoUserAuthRequest(WOLFSSH* ssh, byte* buf, + word32 len, word32* idx); #ifndef WOLFSSH_NO_DH_GEX_SHA256 WOLFSSH_API int wolfSSH_TestDoKexDhGexRequest(WOLFSSH* ssh, byte* buf, word32 len, word32* idx);