Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions experimental/ssh/cmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ the SSH server and handling the connection proxy.
var liteswap string
var skipSettingsCheck bool
var environmentVersion int
var noConfig bool
var multiplex bool
var noStart bool

cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)")
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects")
Expand Down Expand Up @@ -71,6 +74,12 @@ the SSH server and handling the connection proxy.
cmd.Flags().IntVar(&environmentVersion, "environment-version", defaultEnvironmentVersion, "Environment version for serverless compute")
cmd.Flags().MarkHidden("environment-version")

cmd.Flags().BoolVar(&noConfig, "no-config", false, "Do not write SSH config entry (disables scp/rsync support)")
cmd.Flags().BoolVar(&multiplex, "multiplex", false, "Enable SSH connection multiplexing (ControlMaster) for faster scp/rsync")

cmd.Flags().BoolVar(&noStart, "no-start", false, "Only connect to an existing SSH server, do not start a new one")
cmd.Flags().MarkHidden("no-start")

cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
// CLI in the proxy mode is executed by the ssh client and can't prompt for input
if proxyMode {
Expand Down Expand Up @@ -109,6 +118,9 @@ the SSH server and handling the connection proxy.
Liteswap: liteswap,
SkipSettingsCheck: skipSettingsCheck,
EnvironmentVersion: environmentVersion,
SkipConfigWrite: noConfig,
Multiplex: multiplex,
NoServerStart: noStart,
AdditionalArgs: args,
}
if err := opts.Validate(); err != nil {
Expand Down
3 changes: 3 additions & 0 deletions experimental/ssh/cmd/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ an SSH host configuration to your SSH config file.
var sshConfigPath string
var shutdownDelay time.Duration
var autoStartCluster bool
var multiplex bool

cmd.Flags().StringVar(&hostName, "name", "", "Host name to use in SSH config")
cmd.MarkFlagRequired("name")
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster when establishing the ssh connection")
cmd.Flags().StringVar(&sshConfigPath, "ssh-config", "", "Path to SSH config file (default ~/.ssh/config)")
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "SSH server will terminate after this delay if there are no active connections")
cmd.Flags().BoolVar(&multiplex, "multiplex", false, "Enable SSH connection multiplexing (ControlMaster) for faster scp/rsync")

cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
// We want to avoid the situation where the setup command works because it pulls the auth config from a bundle,
Expand All @@ -53,6 +55,7 @@ an SSH host configuration to your SSH config file.
SSHConfigPath: sshConfigPath,
ShutdownDelay: shutdownDelay,
Profile: wsClient.Config.Profile,
Multiplex: multiplex,
}
clientOpts := client.ClientOptions{
ClusterID: setupOpts.ClusterID,
Expand Down
137 changes: 127 additions & 10 deletions experimental/ssh/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"os/signal"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"syscall"
Expand Down Expand Up @@ -99,6 +100,13 @@
SkipSettingsCheck bool
// Environment version for serverless compute.
EnvironmentVersion int
// If true, skip writing the SSH config entry in terminal mode.
SkipConfigWrite bool
// If true, enable SSH ControlMaster multiplexing for connection reuse.
Multiplex bool
// If true, do not attempt to start the SSH server — only connect to an existing one.
// Used in ProxyCommand for scp/rsync where the server should already be running.
NoServerStart bool
}

func (o *ClientOptions) Validate() error {
Expand Down Expand Up @@ -203,18 +211,27 @@
proxyCommand += " --environment-version=" + strconv.Itoa(o.EnvironmentVersion)
}

if o.NoServerStart {
proxyCommand += " --no-start"
}

return proxyCommand, nil
}

func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOptions) error {
// In proxy mode, the CLI runs as a ProxyCommand subprocess of ssh/scp/rsync.
// Suppress all user-facing output so it doesn't interfere with the parent tool.
if opts.ProxyMode {
ctx = cmdio.MockDiscard(ctx)
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()

sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt, syscall.SIGHUP, syscall.SIGTERM)
go func() {
<-sigCh
cmdio.LogString(ctx, "Received termination signal, cleaning up...")
cancel()
}()

Expand All @@ -223,6 +240,11 @@
return errors.New("either --cluster or --name must be provided")
}

// Fast path for scp/rsync: only connect to an existing server, don't start a new one.
if opts.ProxyMode && opts.NoServerStart && opts.ServerMetadata == "" {
return runProxyWithLivenessCheck(ctx, client, opts)
}

if !opts.ProxyMode {
cmdio.LogString(ctx, fmt.Sprintf("Connecting to %s...", sessionID))
if opts.IsServerlessMode() && opts.Accelerator == "" {
Expand Down Expand Up @@ -350,10 +372,24 @@
}

if opts.ProxyMode {
return runSSHProxy(ctx, client, serverPort, clusterID, opts)
err := runSSHProxy(ctx, client, serverPort, clusterID, opts)
// context.Canceled is the normal exit path when the SSH client (scp/rsync) disconnects.
if errors.Is(err, context.Canceled) {
return nil
}
return err
} else if opts.IDE != "" {
return runIDE(ctx, client, userName, keyPath, serverPort, clusterID, opts)
} else {
hostName := opts.SessionIdentifier()
if !opts.SkipConfigWrite {
if err := writeSSHConfigForConnect(ctx, hostName, userName, keyPath, opts); err != nil {
// Non-fatal: log and continue with the SSH session
log.Warnf(ctx, "Failed to write SSH config entry: %v", err)
} else {
printSSHToolHints(ctx, hostName)
}
}
log.Infof(ctx, "Additional SSH arguments: %v", opts.AdditionalArgs)
return spawnSSHClient(ctx, userName, keyPath, serverPort, clusterID, opts)
}
Expand All @@ -377,31 +413,44 @@
return fmt.Errorf("failed to get SSH config path: %w", err)
}

err = ensureSSHConfigEntry(ctx, configPath, connectionName, userName, keyPath, serverPort, clusterID, opts)
err = ensureSSHConfigEntry(ctx, configPath, connectionName, userName, keyPath, opts)
if err != nil {
return fmt.Errorf("failed to ensure SSH config entry: %w", err)
}

return vscode.LaunchIDE(ctx, opts.IDE, connectionName, userName, currentUser.UserName)
}

func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error {
func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, opts ClientOptions) error {
// Ensure the Include directive exists in the main SSH config
err := sshconfig.EnsureIncludeDirective(ctx, configPath)
if err != nil {
return err
}

// Generate ProxyCommand with server metadata
optsWithMetadata := opts
optsWithMetadata.ServerMetadata = FormatMetadata(userName, serverPort, clusterID)

proxyCommand, err := optsWithMetadata.ToProxyCommand()
// Generate ProxyCommand without metadata so the config is resilient to server restarts.
// The inline SSH invocation passes metadata separately for fast first-connection.
proxyCommand, err := opts.ToProxyCommand()
if err != nil {
return fmt.Errorf("failed to generate ProxyCommand: %w", err)
}

hostConfig := sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand)
configOpts := sshconfig.HostConfigOptions{
HostName: hostName,
UserName: userName,
IdentityFile: keyPath,
ProxyCommand: proxyCommand,
}

if opts.Multiplex {
controlPath, cpErr := controlSocketPath(ctx)
if cpErr != nil {
return cpErr
}
configOpts.ControlPath = controlPath
}

hostConfig := sshconfig.GenerateHostConfig(configOpts)

_, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, true)
if err != nil {
Expand All @@ -412,6 +461,41 @@
return nil
}

// writeSSHConfigForConnect registers hostName in the user's SSH configuration so
// that other SSH-based tools (scp, rsync, sftp) can reach the same session by name.
//
// The main SSH config path is resolved first. When opts.Multiplex is set,
// sshconfig.EnsureSocketsDir is called before the entry is written (presumably it
// creates the directory holding the multiplexing sockets; confirm in sshconfig).
// An error from any step is returned to the caller unchanged.
func writeSSHConfigForConnect(ctx context.Context, hostName, userName, keyPath string, opts ClientOptions) error {
	mainConfigPath, err := sshconfig.GetMainConfigPath(ctx)
	if err != nil {
		return err
	}

	if opts.Multiplex {
		err = sshconfig.EnsureSocketsDir(ctx)
		if err != nil {
			return err
		}
	}

	// The ProxyCommand stored in the config is meant for scp/rsync: it should only
	// attach to an existing server, never start a new one. opts was received by
	// value, so the caller's options are left untouched.
	entryOpts := opts
	entryOpts.NoServerStart = true
	return ensureSSHConfigEntry(ctx, mainConfigPath, hostName, userName, keyPath, entryOpts)
}

// controlSocketPath builds the ControlPath pattern used for SSH multiplexing:
// the directory reported by sshconfig.GetSocketsDir joined with the literal
// "%h" token, rendered with forward slashes. The token is left as-is for ssh
// to expand; this function does not substitute it.
func controlSocketPath(ctx context.Context) (string, error) {
	dir, err := sshconfig.GetSocketsDir(ctx)
	if err != nil {
		return "", err
	}

	pattern := filepath.Join(dir, "%h")
	// Use "/" separators regardless of the host OS path separator.
	return filepath.ToSlash(pattern), nil
}

// printSSHToolHints tells the user that an SSH config entry for hostName has been
// written and shows example scp, rsync and sftp invocations that use it.
// Each line is emitted through cmdio.LogString in order.
func printSSHToolHints(ctx context.Context, hostName string) {
	hints := []string{
		fmt.Sprintf("SSH config written for '%s'. You can now use SSH tools in another terminal:", hostName),
		fmt.Sprintf(" scp %s:remote-file local-file", hostName),
		fmt.Sprintf(" rsync -avz %s:remote-dir/ local-dir/", hostName),
		" sftp " + hostName,
	}
	for _, line := range hints {
		cmdio.LogString(ctx, line)
	}
}

// getServerMetadata retrieves the server metadata from the workspace and validates it via Driver Proxy.
// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless).
// For dedicated clusters, clusterID should be the same as sessionID.
Expand Down Expand Up @@ -580,6 +664,16 @@
if opts.UserKnownHostsFile != "" {
sshArgs = append(sshArgs, "-o", "UserKnownHostsFile="+opts.UserKnownHostsFile)
}
if opts.Multiplex && runtime.GOOS != "windows" {
cp, cpErr := controlSocketPath(ctx)
if cpErr == nil {
sshArgs = append(sshArgs,
"-o", "ControlMaster=auto",
"-o", "ControlPath="+cp,
"-o", "ControlPersist=10m",
)
}
}
sshArgs = append(sshArgs, hostName)
sshArgs = append(sshArgs, opts.AdditionalArgs...)

Expand All @@ -593,6 +687,29 @@
return sshCmd.Run()
}

// runProxyWithLivenessCheck checks if the SSH server is still alive before
// connecting. Used by scp/rsync ProxyCommands to fail fast when the session is gone.
func runProxyWithLivenessCheck(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOptions) error {
sessionID := opts.SessionIdentifier()
version := build.GetInfo().Version
clusterID := opts.ClusterID

serverPort, _, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap)
if err != nil {
reconnectCmd := fmt.Sprintf("databricks ssh connect --cluster=%s", sessionID)

Check failure on line 699 in experimental/ssh/internal/client/client.go

View workflow job for this annotation

GitHub Actions / lint

string-format: fmt.Sprintf can be replaced with string concatenation (perfsprint)
if opts.IsServerlessMode() {
reconnectCmd = fmt.Sprintf("databricks ssh connect --name=%s", sessionID)

Check failure on line 701 in experimental/ssh/internal/client/client.go

View workflow job for this annotation

GitHub Actions / lint

string-format: fmt.Sprintf can be replaced with string concatenation (perfsprint)
}
return fmt.Errorf("SSH session is no longer active. Start a new one with:\n %s", reconnectCmd)
}

err = runSSHProxy(ctx, client, serverPort, effectiveClusterID, opts)
if errors.Is(err, context.Canceled) {
return nil
}
return err
}

func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, serverPort int, clusterID string, opts ClientOptions) error {
createConn := func(ctx context.Context, connID string) (*websocket.Conn, error) {
return createWebsocketConnection(ctx, client, connID, clusterID, serverPort, opts.Liteswap)
Expand Down
5 changes: 5 additions & 0 deletions experimental/ssh/internal/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,11 @@ func TestToProxyCommand(t *testing.T) {
opts: client.ClientOptions{ClusterID: "abc-123", EnvironmentVersion: 4},
want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --environment-version=4",
},
{
name: "with no-start",
opts: client.ClientOptions{ClusterID: "abc-123", NoServerStart: true},
want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --no-start",
},
}

for _, tt := range tests {
Expand Down
27 changes: 25 additions & 2 deletions experimental/ssh/internal/setup/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"path/filepath"
"time"

"github.com/databricks/cli/experimental/ssh/internal/keys"
Expand All @@ -30,6 +31,8 @@ type SetupOptions struct {
Profile string
// Proxy command to use for the SSH connection
ProxyCommand string
// If true, enable SSH ControlMaster multiplexing for connection reuse by scp/rsync/sftp.
Multiplex bool
}

func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClient, clusterID string) error {
Expand All @@ -49,8 +52,22 @@ func generateHostConfig(ctx context.Context, opts SetupOptions) (string, error)
return "", fmt.Errorf("failed to get local keys folder: %w", err)
}

hostConfig := sshconfig.GenerateHostConfig(opts.HostName, "root", identityFilePath, opts.ProxyCommand)
return hostConfig, nil
configOpts := sshconfig.HostConfigOptions{
HostName: opts.HostName,
UserName: "root",
IdentityFile: identityFilePath,
ProxyCommand: opts.ProxyCommand,
}

if opts.Multiplex {
socketsDir, sockErr := sshconfig.GetSocketsDir(ctx)
if sockErr != nil {
return "", sockErr
}
configOpts.ControlPath = filepath.ToSlash(filepath.Join(socketsDir, "%h"))
}

return sshconfig.GenerateHostConfig(configOpts), nil
}

func clusterSelectionPrompt(ctx context.Context, client *databricks.WorkspaceClient) (string, error) {
Expand Down Expand Up @@ -100,6 +117,12 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp
return err
}

if opts.Multiplex {
if err := sshconfig.EnsureSocketsDir(ctx); err != nil {
return err
}
}

hostConfig, err := generateHostConfig(ctx, opts)
if err != nil {
return err
Expand Down
Loading
Loading