From c98f9d60e4b19d57c3d1e110dde22043b5214305 Mon Sep 17 00:00:00 2001 From: Anton Nekipelov <226657+anton-107@users.noreply.github.com> Date: Thu, 16 Apr 2026 18:07:11 +0200 Subject: [PATCH 1/3] Enable scp/rsync/sftp support for `ssh connect` The `ssh connect` command now automatically writes an SSH config entry so that standard SSH tools can connect using the same hostname. This also suppresses noisy output in proxy mode and handles context cancellation gracefully when tools like scp disconnect. - Auto-write SSH config from `connect` (opt out with `--no-config`) - Add `--multiplex` flag for SSH ControlMaster connection reuse - Suppress cmdio output in proxy mode (fixes "Uploading binaries..." noise) - Return nil for context.Canceled in proxy mode (fixes "Error: context canceled") - Remove stale metadata from persisted SSH config (IDE mode fix) - Refactor GenerateHostConfig to use HostConfigOptions struct Co-authored-by: Isaac --- experimental/ssh/cmd/connect.go | 7 ++ experimental/ssh/cmd/setup.go | 3 + experimental/ssh/internal/client/client.go | 100 ++++++++++++++++-- experimental/ssh/internal/setup/setup.go | 27 ++++- experimental/ssh/internal/setup/setup_test.go | 38 +++++++ .../ssh/internal/sshconfig/sshconfig.go | 65 ++++++++++-- .../ssh/internal/sshconfig/sshconfig_test.go | 65 ++++++++++++ 7 files changed, 283 insertions(+), 22 deletions(-) diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index b19043d803..5f7b3bce8f 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -36,6 +36,8 @@ the SSH server and handling the connection proxy. var liteswap string var skipSettingsCheck bool var environmentVersion int + var noConfig bool + var multiplex bool cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)") cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects") @@ -71,6 +73,9 @@ the SSH server and handling the connection proxy. cmd.Flags().IntVar(&environmentVersion, "environment-version", defaultEnvironmentVersion, "Environment version for serverless compute") cmd.Flags().MarkHidden("environment-version") + cmd.Flags().BoolVar(&noConfig, "no-config", false, "Do not write SSH config entry (disables scp/rsync support)") + cmd.Flags().BoolVar(&multiplex, "multiplex", false, "Enable SSH connection multiplexing (ControlMaster) for faster scp/rsync") + cmd.PreRunE = func(cmd *cobra.Command, args []string) error { // CLI in the proxy mode is executed by the ssh client and can't prompt for input if proxyMode { @@ -109,6 +114,8 @@ the SSH server and handling the connection proxy. Liteswap: liteswap, SkipSettingsCheck: skipSettingsCheck, EnvironmentVersion: environmentVersion, + SkipConfigWrite: noConfig, + Multiplex: multiplex, AdditionalArgs: args, } if err := opts.Validate(); err != nil { diff --git a/experimental/ssh/cmd/setup.go b/experimental/ssh/cmd/setup.go index 81b7863666..504911e8c6 100644 --- a/experimental/ssh/cmd/setup.go +++ b/experimental/ssh/cmd/setup.go @@ -28,6 +28,7 @@ an SSH host configuration to your SSH config file. var sshConfigPath string var shutdownDelay time.Duration var autoStartCluster bool + var multiplex bool cmd.Flags().StringVar(&hostName, "name", "", "Host name to use in SSH config") cmd.MarkFlagRequired("name") @@ -35,6 +36,7 @@ an SSH host configuration to your SSH config file. cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster when establishing the ssh connection") cmd.Flags().StringVar(&sshConfigPath, "ssh-config", "", "Path to SSH config file (default ~/.ssh/config)") cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "SSH server will terminate after this delay if there are no active connections") + cmd.Flags().BoolVar(&multiplex, "multiplex", false, "Enable SSH connection multiplexing (ControlMaster) for faster scp/rsync") cmd.PreRunE = func(cmd *cobra.Command, args []string) error { // We want to avoid the situation where the setup command works because it pulls the auth config from a bundle, @@ -53,6 +55,7 @@ an SSH host configuration to your SSH config file. SSHConfigPath: sshConfigPath, ShutdownDelay: shutdownDelay, Profile: wsClient.Config.Profile, + Multiplex: multiplex, } clientOpts := client.ClientOptions{ ClusterID: setupOpts.ClusterID, diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index cd6d73f51e..00107bcd40 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -15,6 +15,7 @@ import ( "os/signal" "path/filepath" "regexp" + "runtime" "strconv" "strings" "syscall" @@ -99,6 +100,10 @@ type ClientOptions struct { SkipSettingsCheck bool // Environment version for serverless compute. EnvironmentVersion int + // If true, skip writing the SSH config entry in terminal mode. + SkipConfigWrite bool + // If true, enable SSH ControlMaster multiplexing for connection reuse. + Multiplex bool } func (o *ClientOptions) Validate() error { @@ -207,6 +212,12 @@ func (o *ClientOptions) ToProxyCommand() (string, error) { } func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOptions) error { + // In proxy mode, the CLI runs as a ProxyCommand subprocess of ssh/scp/rsync. + // Suppress all user-facing output so it doesn't interfere with the parent tool. + if opts.ProxyMode { + ctx = cmdio.MockDiscard(ctx) + } + ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -214,7 +225,6 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt signal.Notify(sigCh, os.Interrupt, syscall.SIGHUP, syscall.SIGTERM) go func() { <-sigCh - cmdio.LogString(ctx, "Received termination signal, cleaning up...") cancel() }() @@ -350,10 +360,24 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt } if opts.ProxyMode { - return runSSHProxy(ctx, client, serverPort, clusterID, opts) + err := runSSHProxy(ctx, client, serverPort, clusterID, opts) + // context.Canceled is the normal exit path when the SSH client (scp/rsync) disconnects. + if errors.Is(err, context.Canceled) { + return nil + } + return err } else if opts.IDE != "" { return runIDE(ctx, client, userName, keyPath, serverPort, clusterID, opts) } else { + hostName := opts.SessionIdentifier() + if !opts.SkipConfigWrite { + if err := writeSSHConfigForConnect(ctx, hostName, userName, keyPath, opts); err != nil { + // Non-fatal: log and continue with the SSH session + log.Warnf(ctx, "Failed to write SSH config entry: %v", err) + } else { + printSSHToolHints(ctx, hostName) + } + } log.Infof(ctx, "Additional SSH arguments: %v", opts.AdditionalArgs) return spawnSSHClient(ctx, userName, keyPath, serverPort, clusterID, opts) } @@ -377,7 +401,7 @@ func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, k return fmt.Errorf("failed to get SSH config path: %w", err) } - err = ensureSSHConfigEntry(ctx, configPath, connectionName, userName, keyPath, serverPort, clusterID, opts) + err = ensureSSHConfigEntry(ctx, configPath, connectionName, userName, keyPath, opts) if err != nil { return fmt.Errorf("failed to ensure SSH config entry: %w", err) } @@ -385,23 +409,36 @@ func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, k return vscode.LaunchIDE(ctx, opts.IDE, connectionName, userName, currentUser.UserName) } -func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error { +func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, opts ClientOptions) error { // Ensure the Include directive exists in the main SSH config err := sshconfig.EnsureIncludeDirective(ctx, configPath) if err != nil { return err } - // Generate ProxyCommand with server metadata - optsWithMetadata := opts - optsWithMetadata.ServerMetadata = FormatMetadata(userName, serverPort, clusterID) - - proxyCommand, err := optsWithMetadata.ToProxyCommand() + // Generate ProxyCommand without metadata so the config is resilient to server restarts. + // The inline SSH invocation passes metadata separately for fast first-connection. + proxyCommand, err := opts.ToProxyCommand() if err != nil { return fmt.Errorf("failed to generate ProxyCommand: %w", err) } - hostConfig := sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand) + configOpts := sshconfig.HostConfigOptions{ + HostName: hostName, + UserName: userName, + IdentityFile: keyPath, + ProxyCommand: proxyCommand, + } + + if opts.Multiplex { + controlPath, cpErr := controlSocketPath(ctx) + if cpErr != nil { + return cpErr + } + configOpts.ControlPath = controlPath + } + + hostConfig := sshconfig.GenerateHostConfig(configOpts) _, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, true) if err != nil { @@ -412,6 +449,39 @@ func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, k return nil } +// writeSSHConfigForConnect writes an SSH config entry so that SSH-based tools +// (scp, rsync, sftp) can connect using the same hostname. +func writeSSHConfigForConnect(ctx context.Context, hostName, userName, keyPath string, opts ClientOptions) error { + configPath, err := sshconfig.GetMainConfigPath(ctx) + if err != nil { + return err + } + + if opts.Multiplex { + if err := sshconfig.EnsureSocketsDir(ctx); err != nil { + return err + } + } + + return ensureSSHConfigEntry(ctx, configPath, hostName, userName, keyPath, opts) +} + +// controlSocketPath returns the ControlPath pattern for SSH multiplexing. +func controlSocketPath(ctx context.Context) (string, error) { + socketsDir, err := sshconfig.GetSocketsDir(ctx) + if err != nil { + return "", err + } + return filepath.ToSlash(filepath.Join(socketsDir, "%h")), nil +} + +func printSSHToolHints(ctx context.Context, hostName string) { + cmdio.LogString(ctx, fmt.Sprintf("SSH config written for '%s'. You can now use SSH tools in another terminal:", hostName)) + cmdio.LogString(ctx, fmt.Sprintf(" scp %s:remote-file local-file", hostName)) + cmdio.LogString(ctx, fmt.Sprintf(" rsync -avz %s:remote-dir/ local-dir/", hostName)) + cmdio.LogString(ctx, fmt.Sprintf(" sftp %s", hostName)) +} + // getServerMetadata retrieves the server metadata from the workspace and validates it via Driver Proxy. // sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). // For dedicated clusters, clusterID should be the same as sessionID. @@ -580,6 +650,16 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server if opts.UserKnownHostsFile != "" { sshArgs = append(sshArgs, "-o", "UserKnownHostsFile="+opts.UserKnownHostsFile) } + if opts.Multiplex && runtime.GOOS != "windows" { + cp, cpErr := controlSocketPath(ctx) + if cpErr == nil { + sshArgs = append(sshArgs, + "-o", "ControlMaster=auto", + "-o", "ControlPath="+cp, + "-o", "ControlPersist=10m", + ) + } + } sshArgs = append(sshArgs, hostName) sshArgs = append(sshArgs, opts.AdditionalArgs...) diff --git a/experimental/ssh/internal/setup/setup.go b/experimental/ssh/internal/setup/setup.go index d96510631f..563f801080 100644 --- a/experimental/ssh/internal/setup/setup.go +++ b/experimental/ssh/internal/setup/setup.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "path/filepath" "time" "github.com/databricks/cli/experimental/ssh/internal/keys" @@ -30,6 +31,8 @@ type SetupOptions struct { Profile string // Proxy command to use for the SSH connection ProxyCommand string + // If true, enable SSH ControlMaster multiplexing for connection reuse by scp/rsync/sftp. + Multiplex bool } func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClient, clusterID string) error { @@ -49,8 +52,22 @@ func generateHostConfig(ctx context.Context, opts SetupOptions) (string, error) return "", fmt.Errorf("failed to get local keys folder: %w", err) } - hostConfig := sshconfig.GenerateHostConfig(opts.HostName, "root", identityFilePath, opts.ProxyCommand) - return hostConfig, nil + configOpts := sshconfig.HostConfigOptions{ + HostName: opts.HostName, + UserName: "root", + IdentityFile: identityFilePath, + ProxyCommand: opts.ProxyCommand, + } + + if opts.Multiplex { + socketsDir, sockErr := sshconfig.GetSocketsDir(ctx) + if sockErr != nil { + return "", sockErr + } + configOpts.ControlPath = filepath.ToSlash(filepath.Join(socketsDir, "%h")) + } + + return sshconfig.GenerateHostConfig(configOpts), nil } func clusterSelectionPrompt(ctx context.Context, client *databricks.WorkspaceClient) (string, error) { @@ -100,6 +117,12 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp return err } + if opts.Multiplex { + if err := sshconfig.EnsureSocketsDir(ctx); err != nil { + return err + } + } + hostConfig, err := generateHostConfig(ctx, opts) if err != nil { return err diff --git a/experimental/ssh/internal/setup/setup_test.go b/experimental/ssh/internal/setup/setup_test.go index 77f38cb09d..10629c8fe2 100644 --- a/experimental/ssh/internal/setup/setup_test.go +++ b/experimental/ssh/internal/setup/setup_test.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "runtime" "testing" "time" @@ -201,6 +202,43 @@ func TestGenerateHostConfig_PathEscaping(t *testing.T) { assert.Contains(t, result, fmt.Sprintf(`IdentityFile %q`, expectedPath)) } +func TestGenerateHostConfig_WithMultiplex(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + clientOpts := client.ClientOptions{ + ClusterID: "cluster-123", + AutoStartCluster: true, + ShutdownDelay: 30 * time.Second, + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + + opts := SetupOptions{ + HostName: "test-host", + ClusterID: "cluster-123", + SSHKeysDir: tmpDir, + ShutdownDelay: 30 * time.Second, + ProxyCommand: proxyCommand, + Multiplex: true, + } + + result, err := generateHostConfig(t.Context(), opts) + require.NoError(t, err) + + assert.Contains(t, result, "Host test-host") + assert.Contains(t, result, "--cluster=cluster-123") + + if runtime.GOOS == "windows" { + assert.NotContains(t, result, "ControlMaster") + } else { + assert.Contains(t, result, "ControlMaster auto") + assert.Contains(t, result, "ControlPath") + assert.Contains(t, result, "ControlPersist 10m") + } +} + func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) { ctx := cmdio.MockDiscard(t.Context()) tmpDir := t.TempDir() diff --git a/experimental/ssh/internal/sshconfig/sshconfig.go b/experimental/ssh/internal/sshconfig/sshconfig.go index ad8ca0ee2a..3f1aa4412c 100644 --- a/experimental/ssh/internal/sshconfig/sshconfig.go +++ b/experimental/ssh/internal/sshconfig/sshconfig.go @@ -7,6 +7,7 @@ import ( "io/fs" "os" "path/filepath" + "runtime" "strings" "github.com/databricks/cli/experimental/ssh/internal/fileutil" @@ -17,6 +18,9 @@ import ( const ( // configDirName is the directory name for Databricks SSH tunnel configs, relative to the user's home directory. configDirName = ".databricks/ssh-tunnel-configs" + + // socketsDirName is the directory name for SSH ControlMaster sockets, relative to the user's home directory. + socketsDirName = ".databricks/ssh-sockets" ) func GetConfigDir(ctx context.Context) (string, error) { @@ -201,14 +205,55 @@ func PromptRecreateConfig(ctx context.Context, hostName string) (bool, error) { return response, nil } -func GenerateHostConfig(hostName, userName, identityFile, proxyCommand string) string { - return fmt.Sprintf(` -Host %s - User %s - ConnectTimeout 360 - StrictHostKeyChecking accept-new - IdentitiesOnly yes - IdentityFile %q - ProxyCommand %s -`, hostName, userName, identityFile, proxyCommand) +// GetSocketsDir returns the directory for SSH ControlMaster sockets. +func GetSocketsDir(ctx context.Context) (string, error) { + homeDir, err := env.UserHomeDir(ctx) + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + return filepath.Join(homeDir, socketsDirName), nil +} + +// EnsureSocketsDir creates the ControlMaster sockets directory if it does not exist. +func EnsureSocketsDir(ctx context.Context) error { + socketsDir, err := GetSocketsDir(ctx) + if err != nil { + return err + } + err = os.MkdirAll(socketsDir, 0o700) + if err != nil { + return fmt.Errorf("failed to create SSH sockets directory: %w", err) + } + return nil +} + +// HostConfigOptions contains the parameters for generating an SSH host config entry. +type HostConfigOptions struct { + HostName string + UserName string + IdentityFile string + ProxyCommand string + // ControlPath enables SSH ControlMaster multiplexing when non-empty. + // Ignored on Windows where ControlMaster is not supported. + ControlPath string +} + +// GenerateHostConfig generates an SSH host config entry from the given options. +func GenerateHostConfig(opts HostConfigOptions) string { + var b strings.Builder + fmt.Fprintf(&b, "\nHost %s\n", opts.HostName) + fmt.Fprintf(&b, " User %s\n", opts.UserName) + b.WriteString(" ConnectTimeout 360\n") + b.WriteString(" StrictHostKeyChecking accept-new\n") + b.WriteString(" IdentitiesOnly yes\n") + fmt.Fprintf(&b, " IdentityFile %q\n", opts.IdentityFile) + fmt.Fprintf(&b, " ProxyCommand %s\n", opts.ProxyCommand) + + if opts.ControlPath != "" && runtime.GOOS != "windows" { + b.WriteString(" ControlMaster auto\n") + fmt.Fprintf(&b, " ControlPath %s\n", opts.ControlPath) + b.WriteString(" ControlPersist 10m\n") + } + + return b.String() } diff --git a/experimental/ssh/internal/sshconfig/sshconfig_test.go b/experimental/ssh/internal/sshconfig/sshconfig_test.go index 6c453910cd..6a3ac936bd 100644 --- a/experimental/ssh/internal/sshconfig/sshconfig_test.go +++ b/experimental/ssh/internal/sshconfig/sshconfig_test.go @@ -3,6 +3,7 @@ package sshconfig import ( "os" "path/filepath" + "runtime" "testing" "github.com/databricks/cli/libs/cmdio" @@ -394,3 +395,67 @@ func TestCreateOrUpdateHostConfig_ExistingConfigWithRecreate(t *testing.T) { assert.NoError(t, err) assert.Equal(t, newConfig, string(content)) } + +func TestGetSocketsDir(t *testing.T) { + dir, err := GetSocketsDir(t.Context()) + assert.NoError(t, err) + assert.Contains(t, dir, filepath.Join(".databricks", "ssh-sockets")) +} + +func TestEnsureSocketsDir(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + err := EnsureSocketsDir(t.Context()) + require.NoError(t, err) + + socketsDir := filepath.Join(tmpDir, socketsDirName) + info, err := os.Stat(socketsDir) + require.NoError(t, err) + assert.True(t, info.IsDir()) +} + +func TestGenerateHostConfig_Basic(t *testing.T) { + config := GenerateHostConfig(HostConfigOptions{ + HostName: "my-cluster", + UserName: "root", + IdentityFile: "/home/user/.databricks/ssh-tunnel-keys/abc-123", + ProxyCommand: `"/usr/local/bin/databricks" ssh connect --proxy --cluster=abc-123`, + }) + + assert.Contains(t, config, "Host my-cluster") + assert.Contains(t, config, "User root") + assert.Contains(t, config, "ConnectTimeout 360") + assert.Contains(t, config, "StrictHostKeyChecking accept-new") + assert.Contains(t, config, "IdentitiesOnly yes") + assert.Contains(t, config, `IdentityFile "/home/user/.databricks/ssh-tunnel-keys/abc-123"`) + assert.Contains(t, config, `ProxyCommand "/usr/local/bin/databricks" ssh connect --proxy --cluster=abc-123`) + assert.NotContains(t, config, "ControlMaster") + assert.NotContains(t, config, "ControlPath") + assert.NotContains(t, config, "ControlPersist") +} + +func TestGenerateHostConfig_WithControlMaster(t *testing.T) { + config := GenerateHostConfig(HostConfigOptions{ + HostName: "my-cluster", + UserName: "root", + IdentityFile: "/home/user/.databricks/ssh-tunnel-keys/abc-123", + ProxyCommand: `"/usr/local/bin/databricks" ssh connect --proxy --cluster=abc-123`, + ControlPath: "~/.databricks/ssh-sockets/%h", + }) + + assert.Contains(t, config, "Host my-cluster") + assert.Contains(t, config, "User root") + assert.Contains(t, config, `ProxyCommand "/usr/local/bin/databricks" ssh connect --proxy --cluster=abc-123`) + + if runtime.GOOS == "windows" { + assert.NotContains(t, config, "ControlMaster") + assert.NotContains(t, config, "ControlPath") + assert.NotContains(t, config, "ControlPersist") + } else { + assert.Contains(t, config, "ControlMaster auto") + assert.Contains(t, config, "ControlPath ~/.databricks/ssh-sockets/%h") + assert.Contains(t, config, "ControlPersist 10m") + } +} From 4616522b9055b546b1e3b1ed1369a5ad4c7abd32 Mon Sep 17 00:00:00 2001 From: Anton Nekipelov <226657+anton-107@users.noreply.github.com> Date: Thu, 16 Apr 2026 18:12:27 +0200 Subject: [PATCH 2/3] Fix lint: simplify fmt.Sprintf to string concatenation Co-authored-by: Isaac --- experimental/ssh/internal/client/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 00107bcd40..11b5b61d56 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -479,7 +479,7 @@ func printSSHToolHints(ctx context.Context, hostName string) { cmdio.LogString(ctx, fmt.Sprintf("SSH config written for '%s'. You can now use SSH tools in another terminal:", hostName)) cmdio.LogString(ctx, fmt.Sprintf(" scp %s:remote-file local-file", hostName)) cmdio.LogString(ctx, fmt.Sprintf(" rsync -avz %s:remote-dir/ local-dir/", hostName)) - cmdio.LogString(ctx, fmt.Sprintf(" sftp %s", hostName)) + cmdio.LogString(ctx, " sftp "+hostName) } // getServerMetadata retrieves the server metadata from the workspace and validates it via Driver Proxy. From 8b17ba87b7f6cad6d962da11dca43c4f9aba6295 Mon Sep 17 00:00:00 2001 From: Anton Nekipelov <226657+anton-107@users.noreply.github.com> Date: Mon, 20 Apr 2026 09:26:43 +0200 Subject: [PATCH 3/3] Add --no-start flag for fast-fail liveness check in scp/rsync proxy mode When scp/rsync invokes the ProxyCommand, check if the SSH server is still alive before attempting the full setup. If the session has ended, fail immediately with a clear error message instead of uploading binaries and trying to start a new server. Co-authored-by: Isaac --- experimental/ssh/cmd/connect.go | 5 +++ experimental/ssh/internal/client/client.go | 37 +++++++++++++++++++ .../ssh/internal/client/client_test.go | 5 +++ 3 files changed, 47 insertions(+) diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index 5f7b3bce8f..0b0eee8dc6 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -38,6 +38,7 @@ the SSH server and handling the connection proxy. var environmentVersion int var noConfig bool var multiplex bool + var noStart bool cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)") cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects") @@ -76,6 +77,9 @@ the SSH server and handling the connection proxy. cmd.Flags().BoolVar(&noConfig, "no-config", false, "Do not write SSH config entry (disables scp/rsync support)") cmd.Flags().BoolVar(&multiplex, "multiplex", false, "Enable SSH connection multiplexing (ControlMaster) for faster scp/rsync") + cmd.Flags().BoolVar(&noStart, "no-start", false, "Only connect to an existing SSH server, do not start a new one") + cmd.Flags().MarkHidden("no-start") + cmd.PreRunE = func(cmd *cobra.Command, args []string) error { // CLI in the proxy mode is executed by the ssh client and can't prompt for input if proxyMode { @@ -116,6 +120,7 @@ the SSH server and handling the connection proxy. EnvironmentVersion: environmentVersion, SkipConfigWrite: noConfig, Multiplex: multiplex, + NoServerStart: noStart, AdditionalArgs: args, } if err := opts.Validate(); err != nil { diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 11b5b61d56..0c96fd4a8b 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -104,6 +104,9 @@ type ClientOptions struct { SkipConfigWrite bool // If true, enable SSH ControlMaster multiplexing for connection reuse. Multiplex bool + // If true, do not attempt to start the SSH server — only connect to an existing one. + // Used in ProxyCommand for scp/rsync where the server should already be running. + NoServerStart bool } func (o *ClientOptions) Validate() error { @@ -208,6 +211,10 @@ func (o *ClientOptions) ToProxyCommand() (string, error) { proxyCommand += " --environment-version=" + strconv.Itoa(o.EnvironmentVersion) } + if o.NoServerStart { + proxyCommand += " --no-start" + } + return proxyCommand, nil } @@ -233,6 +240,11 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt return errors.New("either --cluster or --name must be provided") } + // Fast path for scp/rsync: only connect to an existing server, don't start a new one. + if opts.ProxyMode && opts.NoServerStart && opts.ServerMetadata == "" { + return runProxyWithLivenessCheck(ctx, client, opts) + } + if !opts.ProxyMode { cmdio.LogString(ctx, fmt.Sprintf("Connecting to %s...", sessionID)) if opts.IsServerlessMode() && opts.Accelerator == "" { @@ -463,6 +475,8 @@ func writeSSHConfigForConnect(ctx context.Context, hostName, userName, keyPath s } } + // The scp/rsync ProxyCommand should not start a new server — only connect to an existing one. + opts.NoServerStart = true return ensureSSHConfigEntry(ctx, configPath, hostName, userName, keyPath, opts) } @@ -673,6 +687,29 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server return sshCmd.Run() } +// runProxyWithLivenessCheck checks if the SSH server is still alive before +// connecting. Used by scp/rsync ProxyCommands to fail fast when the session is gone. +func runProxyWithLivenessCheck(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOptions) error { + sessionID := opts.SessionIdentifier() + version := build.GetInfo().Version + clusterID := opts.ClusterID + + serverPort, _, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap) + if err != nil { + reconnectCmd := fmt.Sprintf("databricks ssh connect --cluster=%s", sessionID) + if opts.IsServerlessMode() { + reconnectCmd = fmt.Sprintf("databricks ssh connect --name=%s", sessionID) + } + return fmt.Errorf("SSH session is no longer active. Start a new one with:\n %s", reconnectCmd) + } + + err = runSSHProxy(ctx, client, serverPort, effectiveClusterID, opts) + if errors.Is(err, context.Canceled) { + return nil + } + return err +} + func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, serverPort int, clusterID string, opts ClientOptions) error { createConn := func(ctx context.Context, connID string) (*websocket.Conn, error) { return createWebsocketConnection(ctx, client, connID, clusterID, serverPort, opts.Liteswap) diff --git a/experimental/ssh/internal/client/client_test.go b/experimental/ssh/internal/client/client_test.go index ef9e6fb53b..33f33d45c7 100644 --- a/experimental/ssh/internal/client/client_test.go +++ b/experimental/ssh/internal/client/client_test.go @@ -224,6 +224,11 @@ func TestToProxyCommand(t *testing.T) { opts: client.ClientOptions{ClusterID: "abc-123", EnvironmentVersion: 4}, want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --environment-version=4", }, + { + name: "with no-start", + opts: client.ClientOptions{ClusterID: "abc-123", NoServerStart: true}, + want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --no-start", + }, } for _, tt := range tests {