diff --git a/internal/hcs/export_test.go b/internal/hcs/export_test.go new file mode 100644 index 0000000000..4d3de6b00b --- /dev/null +++ b/internal/hcs/export_test.go @@ -0,0 +1,37 @@ +//go:build windows + +package hcs + +import ( + "context" + + "github.com/Microsoft/hcsshim/internal/vmcompute" +) + +// fireNotificationForTest simulates an HCS notification callback for a given +// callback number. Used to drive async completion paths in tests. +func fireNotificationForTest(callbackNumber uintptr, notification hcsNotification, result error) { + callbackMapLock.RLock() + ctx := callbackMap[callbackNumber] + callbackMapLock.RUnlock() + if ctx == nil { + return + } + if ch, ok := ctx.channels[notification]; ok { + ch <- result + } +} + +func newTestSystemWithHandle(id string, handle uintptr) *System { + s := newSystem(id) + s.handle = vmcompute.HcsSystem(handle) + return s +} + +func registerCallbackForTest(s *System) error { + return s.registerCallback(context.Background()) +} + +func startWaitBackgroundForTest(s *System) { + go s.waitBackground() +} diff --git a/internal/hcs/process.go b/internal/hcs/process.go index fef2bf546c..8481f7150d 100644 --- a/internal/hcs/process.go +++ b/internal/hcs/process.go @@ -106,7 +106,7 @@ func (process *Process) Signal(ctx context.Context, options interface{}) (bool, return false, err } - resultJSON, err := vmcompute.HcsSignalProcess(ctx, process.handle, string(optionsb)) + resultJSON, err := hcsSignalProcess(ctx, process.handle, string(optionsb)) events := processHcsResult(ctx, resultJSON) delivered, err := process.processSignalResult(ctx, err) if err != nil { @@ -171,7 +171,7 @@ func (process *Process) Kill(ctx context.Context) (bool, error) { } defer newProcessHandle.Close() - resultJSON, err := vmcompute.HcsTerminateProcess(ctx, newProcessHandle.handle) + resultJSON, err := hcsTerminateProcess(ctx, newProcessHandle.handle) if err != nil { // We still need to check these two cases, as processes may still be killed by an // external actor (human operator, OOM, random script etc). @@ -234,7 +234,7 @@ func (process *Process) waitBackground() { // Make sure we didn't race with Close() here if process.handle != 0 { - propertiesJSON, resultJSON, err = vmcompute.HcsGetProcessProperties(ctx, process.handle) + propertiesJSON, resultJSON, err = hcsGetProcessProperties(ctx, process.handle) events := processHcsResult(ctx, resultJSON) if err != nil { err = makeProcessError(process, operation, err, events) @@ -303,7 +303,7 @@ func (process *Process) ResizeConsole(ctx context.Context, width, height uint16) return err } - resultJSON, err := vmcompute.HcsModifyProcess(ctx, process.handle, string(modifyRequestb)) + resultJSON, err := hcsModifyProcess(ctx, process.handle, string(modifyRequestb)) events := processHcsResult(ctx, resultJSON) if err != nil { return makeProcessError(process, operation, err, events) @@ -352,7 +352,7 @@ func (process *Process) StdioLegacy() (_ io.WriteCloser, _ io.ReadCloser, _ io.R return stdin, stdout, stderr, nil } - processInfo, resultJSON, err := vmcompute.HcsGetProcessInfo(ctx, process.handle) + processInfo, resultJSON, err := hcsGetProcessInfo(ctx, process.handle) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, nil, nil, makeProcessError(process, operation, err, events) @@ -406,7 +406,7 @@ func (process *Process) CloseStdin(ctx context.Context) (err error) { return err } - resultJSON, err := vmcompute.HcsModifyProcess(ctx, process.handle, string(modifyRequestb)) + resultJSON, err := hcsModifyProcess(ctx, process.handle, string(modifyRequestb)) events := processHcsResult(ctx, resultJSON) if err != nil { return makeProcessError(process, operation, err, events) @@ -509,7 +509,7 @@ func (process *Process) Close() (err error) { return makeProcessError(process, operation, err, nil) } - if err = vmcompute.HcsCloseProcess(ctx, process.handle); err != nil { + if err = hcsCloseProcess(ctx, process.handle); err != nil { return makeProcessError(process, operation, err, nil) } @@ -536,7 +536,7 @@ func (process *Process) registerCallback(ctx context.Context) error { callbackMap[callbackNumber] = callbackContext callbackMapLock.Unlock() - callbackHandle, err := vmcompute.HcsRegisterProcessCallback(ctx, process.handle, notificationWatcherCallback, callbackNumber) + callbackHandle, err := hcsRegisterProcessCallback(ctx, process.handle, notificationWatcherCallback, callbackNumber) if err != nil { return err } @@ -563,9 +563,10 @@ func (process *Process) unregisterCallback(ctx context.Context) error { return nil } - // vmcompute.HcsUnregisterProcessCallback has its own synchronization to - // wait for all callbacks to complete. We must NOT hold the callbackMapLock. - err := vmcompute.HcsUnregisterProcessCallback(ctx, handle) + // The underlying HCS API (HcsUnregisterProcessCallback) has its own + // synchronization to wait for all in-flight callbacks to complete. + // We must NOT hold the callbackMapLock during this call. + err := hcsUnregisterProcessCallback(ctx, handle) if err != nil { return err } diff --git a/internal/hcs/system.go b/internal/hcs/system.go index 3d9fcce1bd..3bd74d9aea 100644 --- a/internal/hcs/system.go +++ b/internal/hcs/system.go @@ -91,7 +91,7 @@ func CreateComputeSystem(ctx context.Context, id string, hcsDocumentInterface in resultJSON string createError error ) - computeSystem.handle, resultJSON, createError = vmcompute.HcsCreateComputeSystem(ctx, id, hcsDocument, identity) + computeSystem.handle, resultJSON, createError = hcsCreateComputeSystem(ctx, id, hcsDocument, identity) if createError == nil || IsPending(createError) { defer func() { if err != nil { @@ -128,7 +128,7 @@ func OpenComputeSystem(ctx context.Context, id string) (*System, error) { operation := "hcs::OpenComputeSystem" computeSystem := newSystem(id) - handle, resultJSON, err := vmcompute.HcsOpenComputeSystem(ctx, id) + handle, resultJSON, err := hcsOpenComputeSystem(ctx, id) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, makeSystemError(computeSystem, operation, err, events) @@ -185,7 +185,7 @@ func GetComputeSystems(ctx context.Context, q schema1.ComputeSystemQuery) ([]sch return nil, err } - computeSystemsJSON, resultJSON, err := vmcompute.HcsEnumerateComputeSystems(ctx, string(queryb)) + computeSystemsJSON, resultJSON, err := hcsEnumerateComputeSystems(ctx, string(queryb)) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, &HcsError{Op: operation, Err: err, Events: events} @@ -211,11 +211,11 @@ func (computeSystem *System) Start(ctx context.Context) error { return makeSystemError(computeSystem, "hcs::System::Start", ErrAlreadyClosed, nil) } - op, err := computecore.HcsCreateOperation(ctx, 0, 0) + op, err := hcsCreateOperation(ctx, 0, 0) if err != nil { return makeSystemError(computeSystem, "hcs::System::Start", err, nil) } - defer computecore.HcsCloseOperation(ctx, op) + defer hcsCloseOperation(ctx, op) return computeSystem.start(ctx, op, "") } @@ -236,7 +236,7 @@ func (computeSystem *System) start(ctx context.Context, op computecore.HcsOperat defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes(trace.StringAttribute("cid", computeSystem.id)) - if err := computecore.HcsStartComputeSystem( + if err := hcsStartComputeSystem( ctx, computecore.HcsSystem(computeSystem.handle), op, @@ -245,7 +245,7 @@ func (computeSystem *System) start(ctx context.Context, op computecore.HcsOperat return makeSystemError(computeSystem, operation, err, nil) } - if _, err := computecore.HcsWaitForOperationResult(ctx, op, 0xFFFFFFFF); err != nil { + if _, err := hcsWaitForOperationResult(ctx, op, 0xFFFFFFFF); err != nil { return makeSystemError(computeSystem, operation, err, nil) } @@ -269,7 +269,7 @@ func (computeSystem *System) Shutdown(ctx context.Context) error { return nil } - resultJSON, err := vmcompute.HcsShutdownComputeSystem(ctx, computeSystem.handle, "") + resultJSON, err := hcsShutdownComputeSystem(ctx, computeSystem.handle, "") events := processHcsResult(ctx, resultJSON) if err != nil && !errors.Is(err, ErrVmcomputeAlreadyStopped) && @@ -291,7 +291,7 @@ func (computeSystem *System) Terminate(ctx context.Context) error { return nil } - resultJSON, err := vmcompute.HcsTerminateComputeSystem(ctx, computeSystem.handle, "") + resultJSON, err := hcsTerminateComputeSystem(ctx, computeSystem.handle, "") events := processHcsResult(ctx, resultJSON) if err != nil && !errors.Is(err, ErrVmcomputeAlreadyStopped) && @@ -394,7 +394,7 @@ func (computeSystem *System) Properties(ctx context.Context, types ...schema1.Pr return nil, makeSystemError(computeSystem, operation, err, nil) } - propertiesJSON, resultJSON, err := vmcompute.HcsGetComputeSystemProperties(ctx, computeSystem.handle, string(queryBytes)) + propertiesJSON, resultJSON, err := hcsGetComputeSystemProperties(ctx, computeSystem.handle, string(queryBytes)) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, makeSystemError(computeSystem, operation, err, events) @@ -535,7 +535,7 @@ func (computeSystem *System) hcsPropertiesV2Query(ctx context.Context, types []h return nil, makeSystemError(computeSystem, operation, err, nil) } - propertiesJSON, resultJSON, err := vmcompute.HcsGetComputeSystemProperties(ctx, computeSystem.handle, string(queryBytes)) + propertiesJSON, resultJSON, err := hcsGetComputeSystemProperties(ctx, computeSystem.handle, string(queryBytes)) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, makeSystemError(computeSystem, operation, err, events) @@ -672,7 +672,7 @@ func (computeSystem *System) Pause(ctx context.Context) (err error) { return makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } - resultJSON, err := vmcompute.HcsPauseComputeSystem(ctx, computeSystem.handle, "") + resultJSON, err := hcsPauseComputeSystem(ctx, computeSystem.handle, "") events, err := processAsyncHcsResult(ctx, err, resultJSON, computeSystem.callbackNumber, hcsNotificationSystemPauseCompleted, &timeout.SystemPause) if err != nil { @@ -700,7 +700,7 @@ func (computeSystem *System) Resume(ctx context.Context) (err error) { return makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } - resultJSON, err := vmcompute.HcsResumeComputeSystem(ctx, computeSystem.handle, "") + resultJSON, err := hcsResumeComputeSystem(ctx, computeSystem.handle, "") events, err := processAsyncHcsResult(ctx, err, resultJSON, computeSystem.callbackNumber, hcsNotificationSystemResumeCompleted, &timeout.SystemResume) if err != nil { @@ -733,7 +733,7 @@ func (computeSystem *System) Save(ctx context.Context, options interface{}) (err return makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } - result, err := vmcompute.HcsSaveComputeSystem(ctx, computeSystem.handle, string(saveOptions)) + result, err := hcsSaveComputeSystem(ctx, computeSystem.handle, string(saveOptions)) events, err := processAsyncHcsResult(ctx, err, result, computeSystem.callbackNumber, hcsNotificationSystemSaveCompleted, &timeout.SystemSave) if err != nil { @@ -757,7 +757,7 @@ func (computeSystem *System) createProcess(ctx context.Context, operation string } configuration := string(configurationb) - processInfo, processHandle, resultJSON, err := vmcompute.HcsCreateProcess(ctx, computeSystem.handle, configuration) + processInfo, processHandle, resultJSON, err := hcsCreateProcess(ctx, computeSystem.handle, configuration) events := processHcsResult(ctx, resultJSON) if err != nil { if v2, ok := c.(*hcsschema.ProcessParameters); ok { @@ -813,7 +813,7 @@ func (computeSystem *System) OpenProcess(ctx context.Context, pid int) (*Process return nil, makeSystemError(computeSystem, operation, ErrAlreadyClosed, nil) } - processHandle, resultJSON, err := vmcompute.HcsOpenProcess(ctx, computeSystem.handle, uint32(pid)) + processHandle, resultJSON, err := hcsOpenProcess(ctx, computeSystem.handle, uint32(pid)) events := processHcsResult(ctx, resultJSON) if err != nil { return nil, makeSystemError(computeSystem, operation, err, events) @@ -856,7 +856,7 @@ func (computeSystem *System) CloseCtx(ctx context.Context) (err error) { return makeSystemError(computeSystem, operation, err, nil) } - err = vmcompute.HcsCloseComputeSystem(ctx, computeSystem.handle) + err = hcsCloseComputeSystem(ctx, computeSystem.handle) if err != nil { return makeSystemError(computeSystem, operation, err, nil) } @@ -885,7 +885,7 @@ func (computeSystem *System) registerCallback(ctx context.Context) error { callbackMap[callbackNumber] = callbackContext callbackMapLock.Unlock() - callbackHandle, err := vmcompute.HcsRegisterComputeSystemCallback(ctx, computeSystem.handle, + callbackHandle, err := hcsRegisterComputeSystemCallback(ctx, computeSystem.handle, notificationWatcherCallback, callbackNumber) if err != nil { return err @@ -915,7 +915,7 @@ func (computeSystem *System) unregisterCallback(ctx context.Context) error { // hcsUnregisterComputeSystemCallback has its own synchronization // to wait for all callbacks to complete. We must NOT hold the callbackMapLock. - err := vmcompute.HcsUnregisterComputeSystemCallback(ctx, handle) + err := hcsUnregisterComputeSystemCallback(ctx, handle) if err != nil { return err } @@ -948,7 +948,7 @@ func (computeSystem *System) Modify(ctx context.Context, config interface{}) err } requestJSON := string(requestBytes) - resultJSON, err := vmcompute.HcsModifyComputeSystem(ctx, computeSystem.handle, requestJSON) + resultJSON, err := hcsModifyComputeSystem(ctx, computeSystem.handle, requestJSON) events := processHcsResult(ctx, resultJSON) if err != nil { return makeSystemError(computeSystem, operation, err, events) diff --git a/internal/hcs/system_test.go b/internal/hcs/system_test.go new file mode 100644 index 0000000000..18670980b8 --- /dev/null +++ b/internal/hcs/system_test.go @@ -0,0 +1,318 @@ +//go:build windows + +package hcs + +import ( + "context" + "errors" + "syscall" + "testing" + "time" + + "github.com/Microsoft/hcsshim/internal/computecore" + "github.com/Microsoft/hcsshim/internal/vmcompute" +) + +// swapFunc replaces *target with fn for the duration of t, restoring the original on cleanup. +func swapFunc[T any](t *testing.T, target *T, fn T) { + t.Helper() + orig := *target + *target = fn + t.Cleanup(func() { *target = orig }) +} + +// setupCallback registers a fake callback on sys so the notification channels exist. +// Returns the callback number for firing test notifications. +func setupCallback(t *testing.T, sys *System) uintptr { + t.Helper() + swapFunc(t, &hcsRegisterComputeSystemCallback, func(_ context.Context, _ vmcompute.HcsSystem, _ uintptr, _ uintptr) (vmcompute.HcsCallback, error) { + return vmcompute.HcsCallback(99), nil + }) + if err := registerCallbackForTest(sys); err != nil { + t.Fatalf("registerCallback: %v", err) + } + return sys.callbackNumber +} + +// startMocks holds fake implementations of the four computecore calls used by Start. +// Tests build a startMocks, then call install(t) to swap the package-level vars +// for the duration of the test. +type startMocks struct { + createOp computecore.HcsOperation + createErr error + createCalls int + + closeCalls int + + startErr error + startCalls int + + waitResult string + waitErr error + waitCalls int +} + +func (m *startMocks) install(t *testing.T) { + t.Helper() + swapFunc(t, &hcsCreateOperation, func(_ context.Context, _ uintptr, _ uintptr) (computecore.HcsOperation, error) { + m.createCalls++ + return m.createOp, m.createErr + }) + swapFunc(t, &hcsCloseOperation, func(_ context.Context, _ computecore.HcsOperation) { + m.closeCalls++ + }) + swapFunc(t, &hcsStartComputeSystem, func(_ context.Context, _ computecore.HcsSystem, _ computecore.HcsOperation, _ string) error { + m.startCalls++ + return m.startErr + }) + swapFunc(t, &hcsWaitForOperationResult, func(_ context.Context, _ computecore.HcsOperation, _ uint32) (string, error) { + m.waitCalls++ + return m.waitResult, m.waitErr + }) +} + +// TestStart_Success verifies the happy path: CreateOperation, StartComputeSystem, +// and WaitForOperationResult all succeed; CloseOperation runs via defer. +func TestStart_Success(t *testing.T) { + sys := newTestSystemWithHandle("start-success", 42) + m := &startMocks{createOp: 99} + m.install(t) + + if err := sys.Start(context.Background()); err != nil { + t.Fatalf("Start: %v", err) + } + if m.createCalls != 1 || m.startCalls != 1 || m.waitCalls != 1 || m.closeCalls != 1 { + t.Errorf("call counts: create=%d start=%d wait=%d close=%d (want 1 each)", + m.createCalls, m.startCalls, m.waitCalls, m.closeCalls) + } + if sys.startTime.IsZero() { + t.Error("startTime should be set after successful Start") + } +} + +// TestStart_AlreadyClosed verifies that Start on a system whose handle has been +// cleared returns ErrAlreadyClosed without invoking any HCS API. +func TestStart_AlreadyClosed(t *testing.T) { + sys := newTestSystemWithHandle("start-closed", 0) + m := &startMocks{} + m.install(t) + + err := sys.Start(context.Background()) + if !errors.Is(err, ErrAlreadyClosed) { + t.Fatalf("expected ErrAlreadyClosed, got: %v", err) + } + if m.createCalls != 0 || m.startCalls != 0 || m.waitCalls != 0 { + t.Errorf("no HCS calls expected for closed handle: create=%d start=%d wait=%d", + m.createCalls, m.startCalls, m.waitCalls) + } +} + +// TestStart_CreateOperationFails verifies the failure surfaces and Close is not +// called (no operation handle was acquired). +func TestStart_CreateOperationFails(t *testing.T) { + sys := newTestSystemWithHandle("start-create-fail", 42) + m := &startMocks{createErr: errors.New("create-op-failed")} + m.install(t) + + err := sys.Start(context.Background()) + if err == nil || !errors.Is(err, m.createErr) { + t.Fatalf("expected create error to surface, got: %v", err) + } + if m.startCalls != 0 || m.waitCalls != 0 || m.closeCalls != 0 { + t.Errorf("no further calls expected after create failure: start=%d wait=%d close=%d", + m.startCalls, m.waitCalls, m.closeCalls) + } +} + +// TestStart_StartCallFails verifies that a failure from HcsStartComputeSystem +// surfaces, Wait is skipped, but Close is still invoked via defer. +func TestStart_StartCallFails(t *testing.T) { + sys := newTestSystemWithHandle("start-call-fail", 42) + m := &startMocks{createOp: 7, startErr: errors.New("start-call-failed")} + m.install(t) + + err := sys.Start(context.Background()) + if err == nil || !errors.Is(err, m.startErr) { + t.Fatalf("expected start-call error, got: %v", err) + } + if m.waitCalls != 0 { + t.Errorf("wait must not be called when start fails (got %d)", m.waitCalls) + } + if m.closeCalls != 1 { + t.Errorf("close must run via defer even on start failure (got %d)", m.closeCalls) + } + if !sys.startTime.IsZero() { + t.Error("startTime must remain zero after failed Start") + } +} + +// TestStart_WaitReturnsContainerExit simulates the VM exiting during boot: +// HcsWaitForOperationResult returns ErrUnexpectedContainerExit, which Start +// must surface unchanged. +func TestStart_WaitReturnsContainerExit(t *testing.T) { + sys := newTestSystemWithHandle("start-vm-exit", 42) + m := &startMocks{createOp: 1, waitErr: ErrUnexpectedContainerExit} + m.install(t) + + err := sys.Start(context.Background()) + if !errors.Is(err, ErrUnexpectedContainerExit) { + t.Fatalf("expected ErrUnexpectedContainerExit, got: %v", err) + } + if m.closeCalls != 1 { + t.Errorf("close must run via defer (got %d)", m.closeCalls) + } +} + +// TestStart_WaitReturnsServiceAbort simulates the HCS service disconnecting +// during boot: HcsWaitForOperationResult returns ErrUnexpectedProcessAbort. +func TestStart_WaitReturnsServiceAbort(t *testing.T) { + sys := newTestSystemWithHandle("start-svc-abort", 42) + m := &startMocks{createOp: 1, waitErr: ErrUnexpectedProcessAbort} + m.install(t) + + err := sys.Start(context.Background()) + if !errors.Is(err, ErrUnexpectedProcessAbort) { + t.Fatalf("expected ErrUnexpectedProcessAbort, got: %v", err) + } + if m.closeCalls != 1 { + t.Errorf("close must run via defer (got %d)", m.closeCalls) + } +} + +// TestStart_WaitReturnsTimeout simulates HCS reporting a timeout from the +// operation wait. Start must surface the syscall.Errno from HCS. +func TestStart_WaitReturnsTimeout(t *testing.T) { + sys := newTestSystemWithHandle("start-timeout", 42) + waitErr := syscall.Errno(0x000005B4) // ERROR_TIMEOUT + m := &startMocks{createOp: 1, waitErr: waitErr} + m.install(t) + + err := sys.Start(context.Background()) + if !errors.Is(err, waitErr) { + t.Fatalf("expected wait timeout to surface, got: %v", err) + } + if m.closeCalls != 1 { + t.Errorf("close must run via defer (got %d)", m.closeCalls) + } +} + +// TestPause_SystemExitedDuringPending verifies that if the VM exits while Pause +// is waiting for PauseCompleted, the caller gets ErrUnexpectedContainerExit. +func TestPause_SystemExitedDuringPending(t *testing.T) { + sys := newTestSystemWithHandle("pause-exit", 42) + cbNum := setupCallback(t, sys) + + swapFunc(t, &hcsPauseComputeSystem, func(_ context.Context, _ vmcompute.HcsSystem, _ string) (string, error) { + return "", syscall.Errno(0xC0370103) + }) + + go func() { + time.Sleep(50 * time.Millisecond) + fireNotificationForTest(cbNum, hcsNotificationSystemExited, nil) + }() + + err := sys.Pause(context.Background()) + if !errors.Is(err, ErrUnexpectedContainerExit) { + t.Fatalf("expected ErrUnexpectedContainerExit, got: %v", err) + } +} + +// TestWaitBackground_NormalExit verifies that when SystemExited fires with nil +// error, Wait returns nil and exitError stays nil. This is the clean shutdown path. +func TestWaitBackground_NormalExit(t *testing.T) { + sys := newTestSystemWithHandle("wait-normal", 42) + cbNum := setupCallback(t, sys) + startWaitBackgroundForTest(sys) + + time.Sleep(20 * time.Millisecond) + fireNotificationForTest(cbNum, hcsNotificationSystemExited, nil) + + if err := sys.Wait(); err != nil { + t.Fatalf("expected nil from Wait, got: %v", err) + } + if sys.exitError != nil { + t.Fatalf("expected nil exitError, got: %v", sys.exitError) + } +} + +// TestWaitBackground_UnexpectedExit verifies that when SystemExited fires with +// ErrVmcomputeUnexpectedExit, waitError is nil but exitError captures the crash. +// This distinction matters: Wait() callers get nil (the system did stop), but +// ExitError() reveals it was abnormal. +func TestWaitBackground_UnexpectedExit(t *testing.T) { + sys := newTestSystemWithHandle("wait-unexpected", 42) + cbNum := setupCallback(t, sys) + startWaitBackgroundForTest(sys) + + time.Sleep(20 * time.Millisecond) + fireNotificationForTest(cbNum, hcsNotificationSystemExited, syscall.Errno(0xC0370106)) + + if err := sys.Wait(); err != nil { + t.Fatalf("expected nil from Wait, got: %v", err) + } + if sys.exitError == nil { + t.Fatal("expected non-nil exitError after unexpected exit") + } + if !errors.Is(sys.exitError, syscall.Errno(0xC0370106)) { + t.Fatalf("exitError should wrap ErrVmcomputeUnexpectedExit, got: %v", sys.exitError) + } +} + +// TestWait_MultipleGoroutines verifies that multiple goroutines blocked on Wait +// all unblock when the system exits. This tests the channel fan-out via waitBlock. +func TestWait_MultipleGoroutines(t *testing.T) { + sys := newTestSystemWithHandle("wait-fanout", 42) + cbNum := setupCallback(t, sys) + startWaitBackgroundForTest(sys) + + const numWaiters = 5 + results := make(chan error, numWaiters) + for i := 0; i < numWaiters; i++ { + go func() { results <- sys.Wait() }() + } + + time.Sleep(50 * time.Millisecond) + fireNotificationForTest(cbNum, hcsNotificationSystemExited, nil) + + for i := 0; i < numWaiters; i++ { + select { + case err := <-results: + if err != nil { + t.Errorf("waiter %d: expected nil, got: %v", i, err) + } + case <-time.After(2 * time.Second): + t.Fatalf("waiter %d timed out", i) + } + } +} + +// TestCallback_LateNotificationAfterUnregister verifies that firing a +// notification after unregisterCallback has cleaned up does not panic. +// The callbackMap entry is deleted and channels are closed; a late fire +// should be a no-op. +func TestCallback_LateNotificationAfterUnregister(t *testing.T) { + sys := newTestSystemWithHandle("late-callback", 42) + cbNum := setupCallback(t, sys) + + if _, ok := callbackMap[cbNum]; !ok { + t.Fatal("callback should exist after registration") + } + + swapFunc(t, &hcsUnregisterComputeSystemCallback, func(_ context.Context, _ vmcompute.HcsCallback) error { + return nil + }) + if err := sys.unregisterCallback(context.Background()); err != nil { + t.Fatalf("unregisterCallback: %v", err) + } + + callbackMapLock.RLock() + _, exists := callbackMap[cbNum] + callbackMapLock.RUnlock() + if exists { + t.Fatal("callback should not exist after unregistration") + } + + // Late fires — must not panic. + fireNotificationForTest(cbNum, hcsNotificationSystemExited, nil) + fireNotificationForTest(cbNum, hcsNotificationSystemStartCompleted, nil) +} diff --git a/internal/hcs/zsyscall_compute.go b/internal/hcs/zsyscall_compute.go new file mode 100644 index 0000000000..acbf52cbeb --- /dev/null +++ b/internal/hcs/zsyscall_compute.go @@ -0,0 +1,63 @@ +//go:build windows + +package hcs + +import ( + "github.com/Microsoft/hcsshim/internal/computecore" + "github.com/Microsoft/hcsshim/internal/vmcompute" +) + +// Function variables for HCS compute system and process API calls. +// Production code calls these directly; tests swap them to intercept. + +// --- Compute System Lifecycle --- + +var hcsCreateComputeSystem = vmcompute.HcsCreateComputeSystem +var hcsOpenComputeSystem = vmcompute.HcsOpenComputeSystem +var hcsCloseComputeSystem = vmcompute.HcsCloseComputeSystem +var hcsShutdownComputeSystem = vmcompute.HcsShutdownComputeSystem +var hcsTerminateComputeSystem = vmcompute.HcsTerminateComputeSystem +var hcsPauseComputeSystem = vmcompute.HcsPauseComputeSystem +var hcsResumeComputeSystem = vmcompute.HcsResumeComputeSystem +var hcsSaveComputeSystem = vmcompute.HcsSaveComputeSystem + +// --- Compute System Operations --- + +var hcsGetComputeSystemProperties = vmcompute.HcsGetComputeSystemProperties +var hcsModifyComputeSystem = vmcompute.HcsModifyComputeSystem +var hcsEnumerateComputeSystems = vmcompute.HcsEnumerateComputeSystems + +// --- Compute System Callbacks --- + +var hcsRegisterComputeSystemCallback = vmcompute.HcsRegisterComputeSystemCallback +var hcsUnregisterComputeSystemCallback = vmcompute.HcsUnregisterComputeSystemCallback + +// --- Computecore Operation API (used by Start) --- +// +// HcsStartComputeSystem migrated from vmcompute to the operation-based +// computecore API. The wrappers below preserve testability by allowing tests +// to substitute fake implementations of each call in the Start path. + +var hcsCreateOperation = computecore.HcsCreateOperation +var hcsCloseOperation = computecore.HcsCloseOperation +var hcsStartComputeSystem = computecore.HcsStartComputeSystem +var hcsWaitForOperationResult = computecore.HcsWaitForOperationResult + +// --- Process Lifecycle --- + +var hcsCreateProcess = vmcompute.HcsCreateProcess +var hcsOpenProcess = vmcompute.HcsOpenProcess +var hcsCloseProcess = vmcompute.HcsCloseProcess +var hcsTerminateProcess = vmcompute.HcsTerminateProcess + +// --- Process Operations --- + +var hcsSignalProcess = vmcompute.HcsSignalProcess +var hcsGetProcessInfo = vmcompute.HcsGetProcessInfo +var hcsGetProcessProperties = vmcompute.HcsGetProcessProperties +var hcsModifyProcess = vmcompute.HcsModifyProcess + +// --- Process Callbacks --- + +var hcsRegisterProcessCallback = vmcompute.HcsRegisterProcessCallback +var hcsUnregisterProcessCallback = vmcompute.HcsUnregisterProcessCallback