From ebec3e1f7d28401958c0c99d3bd623d4c474ba35 Mon Sep 17 00:00:00 2001 From: Shreyansh Sancheti Date: Mon, 6 Apr 2026 11:33:50 +0530 Subject: [PATCH] hcs: add function variables for compute API calls and unit tests Replace direct vmcompute.HcsXxx() calls in system.go and process.go with package-level function variables (hcsXxx) defined in zsyscall_compute.go. The variables default to the vmcompute functions. Tests swap individual variables to intercept specific operations. When upstream migrated System.Start to the operation-based computecore API (HcsCreateOperation + HcsStartComputeSystem + HcsWaitForOperationResult + HcsCloseOperation), the same wrapper pattern is extended to those four calls so the Start path remains testable without standing up real HCS. Tests are in package hcs (internal) so they access the function vars directly - no wrapper layer needed. 12 tests covering: Start (operation API): success, already-closed handle, CreateOperation failure, StartComputeSystem failure, wait returns ErrUnexpectedContainerExit (VM crash at boot), wait returns ErrUnexpectedProcessAbort (HCS service disconnect), wait returns ERROR_TIMEOUT. Pause: system exit during pending Pause. waitBackground: normal vs unexpected exit classification. Wait: multi-goroutine fan-out. Callback: late notification after unregistration is a no-op. Signed-off-by: Shreyansh Sancheti --- internal/hcs/export_test.go | 37 ++++ internal/hcs/process.go | 23 +-- internal/hcs/system.go | 40 ++-- internal/hcs/system_test.go | 318 +++++++++++++++++++++++++++++++ internal/hcs/zsyscall_compute.go | 63 ++++++ 5 files changed, 450 insertions(+), 31 deletions(-) create mode 100644 internal/hcs/export_test.go create mode 100644 internal/hcs/system_test.go create mode 100644 internal/hcs/zsyscall_compute.go diff --git a/internal/hcs/export_test.go b/internal/hcs/export_test.go new file mode 100644 index 0000000000..4d3de6b00b --- /dev/null +++ b/internal/hcs/export_test.go @@ -0,0 +1,37 @@ +//go:build windows + +package hcs + +import ( + "context" + + "github.com/Microsoft/hcsshim/internal/vmcompute" +) + +// fireNotificationForTest simulates an HCS notification callback for a given +// callback number. Used to drive async completion paths in tests. +func fireNotificationForTest(callbackNumber uintptr, notification hcsNotification, result error) { + callbackMapLock.RLock() + ctx := callbackMap[callbackNumber] + callbackMapLock.RUnlock() + if ctx == nil { + return + } + if ch, ok := ctx.channels[notification]; ok { + ch <- result + } +} + +func newTestSystemWithHandle(id string, handle uintptr) *System { + s := newSystem(id) + s.handle = vmcompute.HcsSystem(handle) + return s +} + +func registerCallbackForTest(s *System) error { + return s.registerCallback(context.Background()) +} + +func startWaitBackgroundForTest(s *System) { + go s.waitBackground() +} diff --git a/internal/hcs/process.go b/internal/hcs/process.go index fef2bf546c..8481f7150d 100644 --- a/internal/hcs/process.go +++ b/internal/hcs/process.go @@ -106,7 +106,7 @@ func (process *Process) Signal(ctx context.Context, options interface{}) (bool, return false, err } - resultJSON, err := vmcompute.HcsSignalProcess(ctx, process.handle, string(optionsb)) + resultJSON, err := hcsSignalProcess(ctx, process.handle, string(optionsb)) events := processHcsResult(ctx, resultJSON) delivered, err := process.processSignalResult(ctx, err) if err != nil { @@ -171,7 +171,7 @@ func (process *Process) Kill(ctx context.Context) (bool, error) { } defer newProcessHandle.Close() - resultJSON, err := vmcompute.HcsTerminateProcess(ctx, newProcessHandle.handle) + resultJSON, err := hcsTerminateProcess(ctx, newProcessHandle.handle) if err != nil { // We still need to check these two cases, as processes may still be killed by an // external actor (human operator, OOM, random script etc). @@ -234,7 +234,7 @@ func (process *Process) waitBackground() { // Make sure we didn't race with Close() here if process.handle != 0 { - propertiesJSON, resultJSON, err = vmcompute.HcsGetProcessProperties(ctx, process.handle) + propertiesJSON, resultJSON, err = hcsGetProcessProperties(ctx, process.handle) events := processHcsResult(ctx, resultJSON) if err != nil { err = makeProcessError(process, operation, err, events) @@ -303,7 +303,7 @@ func (process *Process) ResizeConsole(ctx context.Context, width, height uint16) return err } - resultJSON, err := vmcompute.HcsModifyProcess(ctx, process.handle, string(modifyRequestb)) + resultJSON, err := hcsModifyProcess(ctx, process.handle, string(modifyRequestb)) events := processHcsResult(ctx, resultJSON) if err != nil { return makeProcessError(process, operation, err, events) @@ -352,7 +352,7 @@ func (process *Process) StdioLegacy() (_ io.WriteCloser, _ io.ReadCloser, _ io.R return stdin, stdout, stderr, nil } - processInfo, resultJSON, err := vmcompute.HcsGetProcessInfo(ctx, process.handle) + processInfo, resultJSON, err := hcsGetProcessInfo(ctx, process.handle) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, nil, nil, makeProcessError(process, operation, err, events) @@ -406,7 +406,7 @@ func (process *Process) CloseStdin(ctx context.Context) (err error) { return err } - resultJSON, err := vmcompute.HcsModifyProcess(ctx, process.handle, string(modifyRequestb)) + resultJSON, err := hcsModifyProcess(ctx, process.handle, string(modifyRequestb)) events := processHcsResult(ctx, resultJSON) if err != nil { return makeProcessError(process, operation, err, events) @@ -509,7 +509,7 @@ func (process *Process) Close() (err error) { return makeProcessError(process, operation, err, nil) } - if err = vmcompute.HcsCloseProcess(ctx, process.handle); err != nil { + if err = hcsCloseProcess(ctx, process.handle); err != nil { return makeProcessError(process, operation, err, nil) } @@ -536,7 +536,7 @@ func (process *Process) registerCallback(ctx context.Context) error { callbackMap[callbackNumber] = callbackContext callbackMapLock.Unlock() - callbackHandle, err := vmcompute.HcsRegisterProcessCallback(ctx, process.handle, notificationWatcherCallback, callbackNumber) + callbackHandle, err := hcsRegisterProcessCallback(ctx, process.handle, notificationWatcherCallback, callbackNumber) if err != nil { return err } @@ -563,9 +563,10 @@ func (process *Process) unregisterCallback(ctx context.Context) error { return nil } - // vmcompute.HcsUnregisterProcessCallback has its own synchronization to - // wait for all callbacks to complete. We must NOT hold the callbackMapLock. - err := vmcompute.HcsUnregisterProcessCallback(ctx, handle) + // The underlying HCS API (HcsUnregisterProcessCallback) has its own + // synchronization to wait for all in-flight callbacks to complete. + // We must NOT hold the callbackMapLock during this call. + err := hcsUnregisterProcessCallback(ctx, handle) if err != nil { return err } diff --git a/internal/hcs/system.go b/internal/hcs/system.go index 3d9fcce1bd..3bd74d9aea 100644 --- a/internal/hcs/system.go +++ b/internal/hcs/system.go @@ -91,7 +91,7 @@ func CreateComputeSystem(ctx context.Context, id string, hcsDocumentInterface in resultJSON string createError error ) - computeSystem.handle, resultJSON, createError = vmcompute.HcsCreateComputeSystem(ctx, id, hcsDocument, identity) + computeSystem.handle, resultJSON, createError = hcsCreateComputeSystem(ctx, id, hcsDocument, identity) if createError == nil || IsPending(createError) { defer func() { if err != nil { @@ -128,7 +128,7 @@ func OpenComputeSystem(ctx context.Context, id string) (*System, error) { operation := "hcs::OpenComputeSystem" computeSystem := newSystem(id) - handle, resultJSON, err := vmcompute.HcsOpenComputeSystem(ctx, id) + handle, resultJSON, err := hcsOpenComputeSystem(ctx, id) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, makeSystemError(computeSystem, operation, err, events) @@ -185,7 +185,7 @@ func GetComputeSystems(ctx context.Context, q schema1.ComputeSystemQuery) ([]sch return nil, err } - computeSystemsJSON, resultJSON, err := vmcompute.HcsEnumerateComputeSystems(ctx, string(queryb)) + computeSystemsJSON, resultJSON, err := hcsEnumerateComputeSystems(ctx, string(queryb)) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, &HcsError{Op: operation, Err: err, Events: events} @@ -211,11 +211,11 @@ func (computeSystem *System) Start(ctx context.Context) error { return makeSystemError(computeSystem, "hcs::System::Start", ErrAlreadyClosed, nil) } - op, err := computecore.HcsCreateOperation(ctx, 0, 0) + op, err := hcsCreateOperation(ctx, 0, 0) if err != nil { return makeSystemError(computeSystem, "hcs::System::Start", err, nil) } - defer computecore.HcsCloseOperation(ctx, op) + defer hcsCloseOperation(ctx, op) return computeSystem.start(ctx, op, "") } @@ -236,7 +236,7 @@ func (computeSystem *System) start(ctx context.Context, op computecore.HcsOperat defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes(trace.StringAttribute("cid", computeSystem.id)) - if err := computecore.HcsStartComputeSystem( + if err := hcsStartComputeSystem( ctx, computecore.HcsSystem(computeSystem.handle), op, @@ -245,7 +245,7 @@ func (computeSystem *System) start(ctx context.Context, op computecore.HcsOperat return makeSystemError(computeSystem, operation, err, nil) } - if _, err := computecore.HcsWaitForOperationResult(ctx, op, 0xFFFFFFFF); err != nil { + if _, err := hcsWaitForOperationResult(ctx, op, 0xFFFFFFFF); err != nil { return makeSystemError(computeSystem, operation, err, nil) } @@ -269,7 +269,7 @@ func (computeSystem *System) Shutdown(ctx context.Context) error { return nil } - resultJSON, err := vmcompute.HcsShutdownComputeSystem(ctx, computeSystem.handle, "") + resultJSON, err := hcsShutdownComputeSystem(ctx, computeSystem.handle, "") events := processHcsResult(ctx, resultJSON) if err != nil && !errors.Is(err, ErrVmcomputeAlreadyStopped) && @@ -291,7 +291,7 @@ func (computeSystem *System) Terminate(ctx context.Context) error { return nil } - resultJSON, err := vmcompute.HcsTerminateComputeSystem(ctx, computeSystem.handle, "") + resultJSON, err := hcsTerminateComputeSystem(ctx, computeSystem.handle, "") events := processHcsResult(ctx, resultJSON) if err != nil && !errors.Is(err, ErrVmcomputeAlreadyStopped) && @@ -394,7 +394,7 @@ func (computeSystem *System) Properties(ctx context.Context, types ...schema1.Pr return nil, makeSystemError(computeSystem, operation, err, nil) } - propertiesJSON, resultJSON, err := vmcompute.HcsGetComputeSystemProperties(ctx, computeSystem.handle, string(queryBytes)) + propertiesJSON, resultJSON, err := hcsGetComputeSystemProperties(ctx, computeSystem.handle, string(queryBytes)) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, makeSystemError(computeSystem, operation, err, events) @@ -535,7 +535,7 @@ func (computeSystem *System) hcsPropertiesV2Query(ctx context.Context, types []h return nil, makeSystemError(computeSystem, operation, err, nil) } - propertiesJSON, resultJSON, err := vmcompute.HcsGetComputeSystemProperties(ctx, computeSystem.handle, string(queryBytes)) + propertiesJSON, resultJSON, err := hcsGetComputeSystemProperties(ctx, computeSystem.handle, string(queryBytes)) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, makeSystemError(computeSystem, operation, err, events) @@ -672,7 +672,7 @@ func (computeSystem *System) Pause(ctx context.Context) (err error) { return makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } - resultJSON, err := vmcompute.HcsPauseComputeSystem(ctx, computeSystem.handle, "") + resultJSON, err := hcsPauseComputeSystem(ctx, computeSystem.handle, "") events, err := processAsyncHcsResult(ctx, err, resultJSON, computeSystem.callbackNumber, hcsNotificationSystemPauseCompleted, &timeout.SystemPause) if err != nil { @@ -700,7 +700,7 @@ func (computeSystem *System) Resume(ctx context.Context) (err error) { return makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } - resultJSON, err := vmcompute.HcsResumeComputeSystem(ctx, computeSystem.handle, "") + resultJSON, err := hcsResumeComputeSystem(ctx, computeSystem.handle, "") events, err := processAsyncHcsResult(ctx, err, resultJSON, computeSystem.callbackNumber, hcsNotificationSystemResumeCompleted, &timeout.SystemResume) if err != nil { @@ -733,7 +733,7 @@ func (computeSystem *System) Save(ctx context.Context, options interface{}) (err return makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } - result, err := vmcompute.HcsSaveComputeSystem(ctx, computeSystem.handle, string(saveOptions)) + result, err := hcsSaveComputeSystem(ctx, computeSystem.handle, string(saveOptions)) events, err := processAsyncHcsResult(ctx, err, result, computeSystem.callbackNumber, hcsNotificationSystemSaveCompleted, &timeout.SystemSave) if err != nil { @@ -757,7 +757,7 @@ func (computeSystem *System) createProcess(ctx context.Context, operation string } configuration := string(configurationb) - processInfo, processHandle, resultJSON, err := vmcompute.HcsCreateProcess(ctx, computeSystem.handle, configuration) + processInfo, processHandle, resultJSON, err := hcsCreateProcess(ctx, computeSystem.handle, configuration) events := processHcsResult(ctx, resultJSON) if err != nil { if v2, ok := c.(*hcsschema.ProcessParameters); ok { @@ -813,7 +813,7 @@ func (computeSystem *System) OpenProcess(ctx context.Context, pid int) (*Process return nil, makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } - processHandle, resultJSON, err := vmcompute.HcsOpenProcess(ctx, computeSystem.handle, uint32(pid)) + processHandle, resultJSON, err := hcsOpenProcess(ctx, computeSystem.handle, uint32(pid)) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, makeSystemError(computeSystem, operation, err, events) @@ -856,7 +856,7 @@ func (computeSystem *System) CloseCtx(ctx context.Context) (err error) { return makeSystemError(computeSystem, operation, err, nil) } - err = vmcompute.HcsCloseComputeSystem(ctx, computeSystem.handle) + err = hcsCloseComputeSystem(ctx, computeSystem.handle) if err != nil { return makeSystemError(computeSystem, operation, err, nil) } @@ -885,7 +885,7 @@ func (computeSystem *System) registerCallback(ctx context.Context) error { callbackMap[callbackNumber] = callbackContext callbackMapLock.Unlock() - callbackHandle, err := vmcompute.HcsRegisterComputeSystemCallback(ctx, computeSystem.handle, + callbackHandle, err := hcsRegisterComputeSystemCallback(ctx, computeSystem.handle, notificationWatcherCallback, callbackNumber) if err != nil { return err @@ -915,7 +915,7 @@ func (computeSystem *System) unregisterCallback(ctx context.Context) error { // hcsUnregisterComputeSystemCallback has its own synchronization // to wait for all callbacks to complete. We must NOT hold the callbackMapLock. - err := vmcompute.HcsUnregisterComputeSystemCallback(ctx, handle) + err := hcsUnregisterComputeSystemCallback(ctx, handle) if err != nil { return err } @@ -948,7 +948,7 @@ func (computeSystem *System) Modify(ctx context.Context, config interface{}) err } requestJSON := string(requestBytes) - resultJSON, err := vmcompute.HcsModifyComputeSystem(ctx, computeSystem.handle, requestJSON) + resultJSON, err := hcsModifyComputeSystem(ctx, computeSystem.handle, requestJSON) events := processHcsResult(ctx, resultJSON) if err != nil { return makeSystemError(computeSystem, operation, err, events) diff --git a/internal/hcs/system_test.go b/internal/hcs/system_test.go new file mode 100644 index 0000000000..18670980b8 --- /dev/null +++ b/internal/hcs/system_test.go @@ -0,0 +1,318 @@ +//go:build windows + +package hcs + +import ( + "context" + "errors" + "syscall" + "testing" + "time" + + "github.com/Microsoft/hcsshim/internal/computecore" + "github.com/Microsoft/hcsshim/internal/vmcompute" +) + +// swapFunc replaces *target with fn for the duration of t, restoring the original on cleanup. +func swapFunc[T any](t *testing.T, target *T, fn T) { + t.Helper() + orig := *target + *target = fn + t.Cleanup(func() { *target = orig }) +} + +// setupCallback registers a fake callback on sys so the notification channels exist. +// Returns the callback number for firing test notifications. +func setupCallback(t *testing.T, sys *System) uintptr { + t.Helper() + swapFunc(t, &hcsRegisterComputeSystemCallback, func(_ context.Context, _ vmcompute.HcsSystem, _ uintptr, _ uintptr) (vmcompute.HcsCallback, error) { + return vmcompute.HcsCallback(99), nil + }) + if err := registerCallbackForTest(sys); err != nil { + t.Fatalf("registerCallback: %v", err) + } + return sys.callbackNumber +} + +// startMocks holds fake implementations of the four computecore calls used by Start. +// Tests build a startMocks, then call install(t) to swap the package-level vars +// for the duration of the test. +type startMocks struct { + createOp computecore.HcsOperation + createErr error + createCalls int + + closeCalls int + + startErr error + startCalls int + + waitResult string + waitErr error + waitCalls int +} + +func (m *startMocks) install(t *testing.T) { + t.Helper() + swapFunc(t, &hcsCreateOperation, func(_ context.Context, _ uintptr, _ uintptr) (computecore.HcsOperation, error) { + m.createCalls++ + return m.createOp, m.createErr + }) + swapFunc(t, &hcsCloseOperation, func(_ context.Context, _ computecore.HcsOperation) { + m.closeCalls++ + }) + swapFunc(t, &hcsStartComputeSystem, func(_ context.Context, _ computecore.HcsSystem, _ computecore.HcsOperation, _ string) error { + m.startCalls++ + return m.startErr + }) + swapFunc(t, &hcsWaitForOperationResult, func(_ context.Context, _ computecore.HcsOperation, _ uint32) (string, error) { + m.waitCalls++ + return m.waitResult, m.waitErr + }) +} + +// TestStart_Success verifies the happy path: CreateOperation, StartComputeSystem, +// and WaitForOperationResult all succeed; CloseOperation runs via defer. +func TestStart_Success(t *testing.T) { + sys := newTestSystemWithHandle("start-success", 42) + m := &startMocks{createOp: 99} + m.install(t) + + if err := sys.Start(context.Background()); err != nil { + t.Fatalf("Start: %v", err) + } + if m.createCalls != 1 || m.startCalls != 1 || m.waitCalls != 1 || m.closeCalls != 1 { + t.Errorf("call counts: create=%d start=%d wait=%d close=%d (want 1 each)", + m.createCalls, m.startCalls, m.waitCalls, m.closeCalls) + } + if sys.startTime.IsZero() { + t.Error("startTime should be set after successful Start") + } +} + +// TestStart_AlreadyClosed verifies that Start on a system whose handle has been +// cleared returns ErrAlreadyClosed without invoking any HCS API. +func TestStart_AlreadyClosed(t *testing.T) { + sys := newTestSystemWithHandle("start-closed", 0) + m := &startMocks{} + m.install(t) + + err := sys.Start(context.Background()) + if !errors.Is(err, ErrAlreadyClosed) { + t.Fatalf("expected ErrAlreadyClosed, got: %v", err) + } + if m.createCalls != 0 || m.startCalls != 0 || m.waitCalls != 0 { + t.Errorf("no HCS calls expected for closed handle: create=%d start=%d wait=%d", + m.createCalls, m.startCalls, m.waitCalls) + } +} + +// TestStart_CreateOperationFails verifies the failure surfaces and Close is not +// called (no operation handle was acquired). +func TestStart_CreateOperationFails(t *testing.T) { + sys := newTestSystemWithHandle("start-create-fail", 42) + m := &startMocks{createErr: errors.New("create-op-failed")} + m.install(t) + + err := sys.Start(context.Background()) + if err == nil || !errors.Is(err, m.createErr) { + t.Fatalf("expected create error to surface, got: %v", err) + } + if m.startCalls != 0 || m.waitCalls != 0 || m.closeCalls != 0 { + t.Errorf("no further calls expected after create failure: start=%d wait=%d close=%d", + m.startCalls, m.waitCalls, m.closeCalls) + } +} + +// TestStart_StartCallFails verifies that a failure from HcsStartComputeSystem +// surfaces, Wait is skipped, but Close is still invoked via defer. +func TestStart_StartCallFails(t *testing.T) { + sys := newTestSystemWithHandle("start-call-fail", 42) + m := &startMocks{createOp: 7, startErr: errors.New("start-call-failed")} + m.install(t) + + err := sys.Start(context.Background()) + if err == nil || !errors.Is(err, m.startErr) { + t.Fatalf("expected start-call error, got: %v", err) + } + if m.waitCalls != 0 { + t.Errorf("wait must not be called when start fails (got %d)", m.waitCalls) + } + if m.closeCalls != 1 { + t.Errorf("close must run via defer even on start failure (got %d)", m.closeCalls) + } + if !sys.startTime.IsZero() { + t.Error("startTime must remain zero after failed Start") + } +} + +// TestStart_WaitReturnsContainerExit simulates the VM exiting during boot: +// HcsWaitForOperationResult returns ErrUnexpectedContainerExit, which Start +// must surface unchanged. +func TestStart_WaitReturnsContainerExit(t *testing.T) { + sys := newTestSystemWithHandle("start-vm-exit", 42) + m := &startMocks{createOp: 1, waitErr: ErrUnexpectedContainerExit} + m.install(t) + + err := sys.Start(context.Background()) + if !errors.Is(err, ErrUnexpectedContainerExit) { + t.Fatalf("expected ErrUnexpectedContainerExit, got: %v", err) + } + if m.closeCalls != 1 { + t.Errorf("close must run via defer (got %d)", m.closeCalls) + } +} + +// TestStart_WaitReturnsServiceAbort simulates the HCS service disconnecting +// during boot: HcsWaitForOperationResult returns ErrUnexpectedProcessAbort. +func TestStart_WaitReturnsServiceAbort(t *testing.T) { + sys := newTestSystemWithHandle("start-svc-abort", 42) + m := &startMocks{createOp: 1, waitErr: ErrUnexpectedProcessAbort} + m.install(t) + + err := sys.Start(context.Background()) + if !errors.Is(err, ErrUnexpectedProcessAbort) { + t.Fatalf("expected ErrUnexpectedProcessAbort, got: %v", err) + } + if m.closeCalls != 1 { + t.Errorf("close must run via defer (got %d)", m.closeCalls) + } +} + +// TestStart_WaitReturnsTimeout simulates HCS reporting a timeout from the +// operation wait. Start must surface the syscall.Errno from HCS. +func TestStart_WaitReturnsTimeout(t *testing.T) { + sys := newTestSystemWithHandle("start-timeout", 42) + waitErr := syscall.Errno(0x000005B4) // ERROR_TIMEOUT + m := &startMocks{createOp: 1, waitErr: waitErr} + m.install(t) + + err := sys.Start(context.Background()) + if !errors.Is(err, waitErr) { + t.Fatalf("expected wait timeout to surface, got: %v", err) + } + if m.closeCalls != 1 { + t.Errorf("close must run via defer (got %d)", m.closeCalls) + } +} + +// TestPause_SystemExitedDuringPending verifies that if the VM exits while Pause +// is waiting for PauseCompleted, the caller gets ErrUnexpectedContainerExit. +func TestPause_SystemExitedDuringPending(t *testing.T) { + sys := newTestSystemWithHandle("pause-exit", 42) + cbNum := setupCallback(t, sys) + + swapFunc(t, &hcsPauseComputeSystem, func(_ context.Context, _ vmcompute.HcsSystem, _ string) (string, error) { + return "", syscall.Errno(0xC0370103) + }) + + go func() { + time.Sleep(50 * time.Millisecond) + fireNotificationForTest(cbNum, hcsNotificationSystemExited, nil) + }() + + err := sys.Pause(context.Background()) + if !errors.Is(err, ErrUnexpectedContainerExit) { + t.Fatalf("expected ErrUnexpectedContainerExit, got: %v", err) + } +} + +// TestWaitBackground_NormalExit verifies that when SystemExited fires with nil +// error, Wait returns nil and exitError stays nil. This is the clean shutdown path. +func TestWaitBackground_NormalExit(t *testing.T) { + sys := newTestSystemWithHandle("wait-normal", 42) + cbNum := setupCallback(t, sys) + startWaitBackgroundForTest(sys) + + time.Sleep(20 * time.Millisecond) + fireNotificationForTest(cbNum, hcsNotificationSystemExited, nil) + + if err := sys.Wait(); err != nil { + t.Fatalf("expected nil from Wait, got: %v", err) + } + if sys.exitError != nil { + t.Fatalf("expected nil exitError, got: %v", sys.exitError) + } +} + +// TestWaitBackground_UnexpectedExit verifies that when SystemExited fires with +// ErrVmcomputeUnexpectedExit, waitError is nil but exitError captures the crash. +// This distinction matters: Wait() callers get nil (the system did stop), but +// ExitError() reveals it was abnormal. +func TestWaitBackground_UnexpectedExit(t *testing.T) { + sys := newTestSystemWithHandle("wait-unexpected", 42) + cbNum := setupCallback(t, sys) + startWaitBackgroundForTest(sys) + + time.Sleep(20 * time.Millisecond) + fireNotificationForTest(cbNum, hcsNotificationSystemExited, syscall.Errno(0xC0370106)) + + if err := sys.Wait(); err != nil { + t.Fatalf("expected nil from Wait, got: %v", err) + } + if sys.exitError == nil { + t.Fatal("expected non-nil exitError after unexpected exit") + } + if !errors.Is(sys.exitError, syscall.Errno(0xC0370106)) { + t.Fatalf("exitError should wrap ErrVmcomputeUnexpectedExit, got: %v", sys.exitError) + } +} + +// TestWait_MultipleGoroutines verifies that multiple goroutines blocked on Wait +// all unblock when the system exits. This tests the channel fan-out via waitBlock. +func TestWait_MultipleGoroutines(t *testing.T) { + sys := newTestSystemWithHandle("wait-fanout", 42) + cbNum := setupCallback(t, sys) + startWaitBackgroundForTest(sys) + + const numWaiters = 5 + results := make(chan error, numWaiters) + for i := 0; i < numWaiters; i++ { + go func() { results <- sys.Wait() }() + } + + time.Sleep(50 * time.Millisecond) + fireNotificationForTest(cbNum, hcsNotificationSystemExited, nil) + + for i := 0; i < numWaiters; i++ { + select { + case err := <-results: + if err != nil { + t.Errorf("waiter %d: expected nil, got: %v", i, err) + } + case <-time.After(2 * time.Second): + t.Fatalf("waiter %d timed out", i) + } + } +} + +// TestCallback_LateNotificationAfterUnregister verifies that firing a +// notification after unregisterCallback has cleaned up does not panic. +// The callbackMap entry is deleted and channels are closed; a late fire +// should be a no-op. +func TestCallback_LateNotificationAfterUnregister(t *testing.T) { + sys := newTestSystemWithHandle("late-callback", 42) + cbNum := setupCallback(t, sys) + + if _, ok := callbackMap[cbNum]; !ok { + t.Fatal("callback should exist after registration") + } + + swapFunc(t, &hcsUnregisterComputeSystemCallback, func(_ context.Context, _ vmcompute.HcsCallback) error { + return nil + }) + if err := sys.unregisterCallback(context.Background()); err != nil { + t.Fatalf("unregisterCallback: %v", err) + } + + callbackMapLock.RLock() + _, exists := callbackMap[cbNum] + callbackMapLock.RUnlock() + if exists { + t.Fatal("callback should not exist after unregistration") + } + + // Late fires — must not panic. + fireNotificationForTest(cbNum, hcsNotificationSystemExited, nil) + fireNotificationForTest(cbNum, hcsNotificationSystemStartCompleted, nil) +} diff --git a/internal/hcs/zsyscall_compute.go b/internal/hcs/zsyscall_compute.go new file mode 100644 index 0000000000..acbf52cbeb --- /dev/null +++ b/internal/hcs/zsyscall_compute.go @@ -0,0 +1,63 @@ +//go:build windows + +package hcs + +import ( + "github.com/Microsoft/hcsshim/internal/computecore" + "github.com/Microsoft/hcsshim/internal/vmcompute" +) + +// Function variables for HCS compute system and process API calls. +// Production code calls these directly; tests swap them to intercept. + +// --- Compute System Lifecycle --- + +var hcsCreateComputeSystem = vmcompute.HcsCreateComputeSystem +var hcsOpenComputeSystem = vmcompute.HcsOpenComputeSystem +var hcsCloseComputeSystem = vmcompute.HcsCloseComputeSystem +var hcsShutdownComputeSystem = vmcompute.HcsShutdownComputeSystem +var hcsTerminateComputeSystem = vmcompute.HcsTerminateComputeSystem +var hcsPauseComputeSystem = vmcompute.HcsPauseComputeSystem +var hcsResumeComputeSystem = vmcompute.HcsResumeComputeSystem +var hcsSaveComputeSystem = vmcompute.HcsSaveComputeSystem + +// --- Compute System Operations --- + +var hcsGetComputeSystemProperties = vmcompute.HcsGetComputeSystemProperties +var hcsModifyComputeSystem = vmcompute.HcsModifyComputeSystem +var hcsEnumerateComputeSystems = vmcompute.HcsEnumerateComputeSystems + +// --- Compute System Callbacks --- + +var hcsRegisterComputeSystemCallback = vmcompute.HcsRegisterComputeSystemCallback +var hcsUnregisterComputeSystemCallback = vmcompute.HcsUnregisterComputeSystemCallback + +// --- Computecore Operation API (used by Start) --- +// +// HcsStartComputeSystem migrated from vmcompute to the operation-based +// computecore API. The wrappers below preserve testability by allowing tests +// to substitute fake implementations of each call in the Start path. + +var hcsCreateOperation = computecore.HcsCreateOperation +var hcsCloseOperation = computecore.HcsCloseOperation +var hcsStartComputeSystem = computecore.HcsStartComputeSystem +var hcsWaitForOperationResult = computecore.HcsWaitForOperationResult + +// --- Process Lifecycle --- + +var hcsCreateProcess = vmcompute.HcsCreateProcess +var hcsOpenProcess = vmcompute.HcsOpenProcess +var hcsCloseProcess = vmcompute.HcsCloseProcess +var hcsTerminateProcess = vmcompute.HcsTerminateProcess + +// --- Process Operations --- + +var hcsSignalProcess = vmcompute.HcsSignalProcess +var hcsGetProcessInfo = vmcompute.HcsGetProcessInfo +var hcsGetProcessProperties = vmcompute.HcsGetProcessProperties +var hcsModifyProcess = vmcompute.HcsModifyProcess + +// --- Process Callbacks --- + +var hcsRegisterProcessCallback = vmcompute.HcsRegisterProcessCallback +var hcsUnregisterProcessCallback = vmcompute.HcsUnregisterProcessCallback