diff --git a/.github/workflows/go_tests.yml b/.github/workflows/go_tests.yml new file mode 100644 index 00000000..0cb83747 --- /dev/null +++ b/.github/workflows/go_tests.yml @@ -0,0 +1,68 @@ +name: Test Go SDK + +on: + workflow_call: + secrets: + E2B_API_KEY: + required: false + inputs: + E2B_DOMAIN: + required: false + type: string + E2B_TESTS_TEMPLATE: + required: false + type: string + +permissions: + contents: read + +jobs: + test: + defaults: + run: + working-directory: ./go + name: Go SDK - Build and test + runs-on: ubuntu-22.04 + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Parse .tool-versions + uses: wistia/parse-tool-versions@v2.1.1 + with: + filename: '.tool-versions' + uppercase: 'true' + prefix: 'tool_version_' + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '${{ env.TOOL_VERSION_GOLANG }}' + cache-dependency-path: go/go.sum + + - name: Download dependencies + run: go mod download + + - name: Verify dependencies + run: go mod verify + + - name: Build + run: go build ./... + + - name: Vet + run: go vet ./... + + - name: Run tests + run: go test -race -coverprofile=coverage.out -covermode=atomic ./... + env: + E2B_API_KEY: ${{ secrets.E2B_API_KEY }} + E2B_DOMAIN: ${{ inputs.E2B_DOMAIN }} + E2B_TESTS_TEMPLATE: ${{ inputs.E2B_TESTS_TEMPLATE }} + + - name: Upload coverage report + if: always() + uses: actions/upload-artifact@v4 + with: + name: go-coverage + path: go/coverage.out + if-no-files-found: ignore diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 08beb345..0ee14044 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -59,6 +59,27 @@ jobs: poetry install --with dev pip install ruff=="0.11.12" + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '${{ env.TOOL_VERSION_GOLANG }}' + cache-dependency-path: go/go.sum + + - name: Check Go formatting + working-directory: go + run: | + if [[ -n "$(gofmt -l .)" ]]; then + echo "❌ Go files are not formatted properly:" + gofmt -d . + exit 1 + else + echo "✅ Go files are properly formatted." + fi + + - name: Go vet + working-directory: go + run: go vet ./... + - name: Run linting run: | pnpm run lint diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 7ed194c9..462b328a 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -36,6 +36,14 @@ jobs: with: E2B_DOMAIN: ${{ vars.E2B_DOMAIN }} E2B_TESTS_TEMPLATE: ${{ needs.build-template.outputs.template_id }} + go-sdk: + uses: ./.github/workflows/go_tests.yml + needs: build-template + secrets: + E2B_API_KEY: ${{ secrets.E2B_API_KEY }} + with: + E2B_DOMAIN: ${{ vars.E2B_DOMAIN }} + E2B_TESTS_TEMPLATE: ${{ needs.build-template.outputs.template_id }} performance-tests: uses: ./.github/workflows/performance_tests.yml needs: build-template @@ -46,7 +54,7 @@ jobs: E2B_TESTS_TEMPLATE: ${{ needs.build-template.outputs.template_id }} cleanup-build-template: uses: ./.github/workflows/cleanup_build_template.yml - needs: [build-template, js-sdk, python-sdk, performance-tests] + needs: [build-template, js-sdk, python-sdk, go-sdk, performance-tests] if: always() && !contains(needs.build-template.result, 'failure') && !contains(needs.build-template.result, 'cancelled') secrets: E2B_TESTS_ACCESS_TOKEN: ${{ secrets.E2B_TESTS_ACCESS_TOKEN }} diff --git a/.tool-versions b/.tool-versions index d98d583a..b7594830 100644 --- a/.tool-versions +++ b/.tool-versions @@ -1,2 +1,3 @@ python 3.10 poetry 1.8.5 +golang 1.21.5 diff --git a/go/README.md b/go/README.md new file mode 100644 index 00000000..6eeeb117 --- /dev/null +++ b/go/README.md @@ -0,0 +1,171 @@ +# E2B Code Interpreter — Go SDK + +Go SDK for the [E2B](https://e2b.dev) Code Interpreter. It lets you run +AI-generated code inside secure, isolated E2B sandboxes and get back rich, +structured results (text, HTML, images, chart data, …). + +This package mirrors the features of the official +[Python](../python) and [JavaScript](../js) SDKs. + +## Install + +```bash +go get github.com/e2b-dev/codeinterpreter-go +``` + +> **Go version**: 1.21+ + +## Get your API key + +1. Sign up at [e2b.dev](https://e2b.dev). +2. Grab an API key from the [dashboard](https://e2b.dev/dashboard?tab=keys). +3. Export it: + +```bash +export E2B_API_KEY=e2b_*** +``` + +## Quick start + +```go +package main + +import ( + "context" + "fmt" + "log" + "time" + + codeinterpreter "github.com/e2b-dev/codeinterpreter-go" +) + +func main() { + ctx := context.Background() + + sbx, err := codeinterpreter.Create(ctx, &codeinterpreter.SandboxOpts{ + Timeout: 60 * time.Second, + }) + if err != nil { log.Fatal(err) } + defer sbx.Kill(ctx) + + _, _ = sbx.RunCode(ctx, "x = 1", nil) + + exec, err := sbx.RunCode(ctx, "x += 1; x", nil) + if err != nil { log.Fatal(err) } + + fmt.Println(exec.Text()) // "2" +} +``` + +## Features + +The Go SDK exposes the same surface as the Python / JS SDKs: + +### Sandbox lifecycle + +| Function | Description | +|---|---| +| `Create(ctx, opts)` | Start a new sandbox. | +| `Connect(ctx, id, opts)` | Attach to a running sandbox. | +| `List(ctx, opts)` | List sandboxes for the API key. | +| `Sandbox.Kill(ctx)` | Terminate the sandbox. | +| `Sandbox.SetTimeout(ctx, d)` | Extend the sandbox lifetime. | +| `Sandbox.IsRunning(ctx)` | Health check. | +| `Sandbox.GetInfo(ctx)` | Metadata, start time, etc. | +| `Sandbox.GetHost(port)` | Hostname for a port exposed by the sandbox. | + +### Code execution + +| Function | Description | +|---|---| +| `Sandbox.RunCode(ctx, code, opts)` | Execute code (any supported language). | + +`RunCodeOpts` lets you pass: + +* `Language` — `"python"` / `"javascript"` / `"typescript"` / `"r"` / `"java"` / `"bash"` or any custom kernel id. +* `Context` — a pre-created `*Context` (mutually exclusive with `Language`). +* `Envs` — extra environment variables. +* `Timeout` / `RequestTimeout` — execution / request timeouts. +* `OnStdout`, `OnStderr`, `OnResult`, `OnError` — streaming callbacks. + +### Code contexts (Jupyter kernels) + +| Function | Description | +|---|---| +| `Sandbox.CreateCodeContext(ctx, opts)` | Create a fresh kernel. | +| `Sandbox.ListCodeContexts(ctx)` | List known kernels. | +| `Sandbox.RestartCodeContext(ctx, c)` | Restart a kernel (clears state). | +| `Sandbox.RemoveCodeContext(ctx, c)` | Terminate a kernel. | + +### Result / Execution model + +Every `RunCode` call returns an `*Execution`: + +```go +type Execution struct { + Results []*Result + Logs Logs // stdout / stderr lines + Error *ExecutionError // nil on success + ExecutionCount int +} +``` + +Each `Result` may carry multiple representations of the same value: `Text`, +`HTML`, `Markdown`, `SVG`, `PNG`, `JPEG`, `PDF`, `LaTeX`, `JSON`, `JavaScript`, +`Data`, `Chart`, plus arbitrary `Extra` MIME types. + +### Charts + +`Result.Chart` is a `Chart` interface — type-assert it to inspect the +structured data: + +```go +switch c := result.Chart.(type) { +case *codeinterpreter.LineChart: + for _, series := range c.Points { ... } +case *codeinterpreter.BarChart: + for _, bar := range c.Bars { ... } +case *codeinterpreter.PieChart: + for _, slice := range c.Slices { ... } +case *codeinterpreter.BoxAndWhiskerChart: + ... +case *codeinterpreter.ScatterChart: + ... +case *codeinterpreter.SuperChart: + for _, sub := range c.Charts { ... } +} +``` + +## Streaming output + +```go +exec, err := sbx.RunCode(ctx, "for i in range(5): print(i)", &codeinterpreter.RunCodeOpts{ + OnStdout: func(msg codeinterpreter.OutputMessage) { + fmt.Println(">", msg.Line) + }, + OnResult: func(r *codeinterpreter.Result) { + fmt.Println("got result with formats:", r.Formats()) + }, + OnError: func(e *codeinterpreter.ExecutionError) { + fmt.Println("error:", e.Name, e.Value) + }, +}) +``` + +## Error types + +| Error | When | +|---|---| +| `*AuthenticationError` | Invalid / missing API key. | +| `*NotFoundError` | Resource (sandbox, context) not found. | +| `*TimeoutError` | Request or execution timed out. | +| `*RateLimitError` | Hit E2B rate limit. | +| `*InvalidArgumentError` | Bad arguments (e.g. both `Language` + `Context`). | +| `*SandboxError` | Generic error from the backend. | + +Use Go's type assertion / `errors.As` to discriminate between them. + +## Check docs + +Visit the [E2B documentation](https://e2b.dev/docs) for more details. + diff --git a/go/charts.go b/go/charts.go new file mode 100644 index 00000000..1d5c30c4 --- /dev/null +++ b/go/charts.go @@ -0,0 +1,374 @@ +package codeinterpreter + +import "encoding/json" + +// ChartType represents the kind of chart returned by the server. +type ChartType string + +const ( + ChartTypeLine ChartType = "line" + ChartTypeScatter ChartType = "scatter" + ChartTypeBar ChartType = "bar" + ChartTypePie ChartType = "pie" + ChartTypeBoxAndWhisker ChartType = "box_and_whisker" + ChartTypeSuperChart ChartType = "superchart" + ChartTypeUnknown ChartType = "unknown" +) + +// ScaleType represents an axis scale type (linear, log, etc.) +type ScaleType string + +const ( + ScaleTypeLinear ScaleType = "linear" + ScaleTypeDatetime ScaleType = "datetime" + ScaleTypeCategorical ScaleType = "categorical" + ScaleTypeLog ScaleType = "log" + ScaleTypeSymlog ScaleType = "symlog" + ScaleTypeLogit ScaleType = "logit" + ScaleTypeFunction ScaleType = "function" + ScaleTypeFunctionLog ScaleType = "functionlog" + ScaleTypeAsinh ScaleType = "asinh" + ScaleTypeUnknown ScaleType = "unknown" +) + +// Chart is the common interface implemented by all concrete chart types. +// +// Use a type switch on the concrete types to inspect specialized fields, e.g. +// +// switch c := result.Chart.(type) { +// case *LineChart: ... +// case *BarChart: ... +// } +type Chart interface { + ChartType() ChartType + ChartTitle() string + // ToDict returns the raw JSON representation of the chart. + ToJSON() map[string]interface{} +} + +// BaseChart contains the fields shared by every chart type. +type BaseChart struct { + Type ChartType `json:"type"` + Title string `json:"title"` + Elements []interface{} `json:"elements"` + raw map[string]interface{} `json:"-"` +} + +func (c *BaseChart) ChartType() ChartType { return c.Type } +func (c *BaseChart) ChartTitle() string { return c.Title } +func (c *BaseChart) ToJSON() map[string]interface{} { return c.raw } + +// Chart2D is the base for charts that live on a 2D plane. +type Chart2D struct { + BaseChart + XLabel string `json:"x_label,omitempty"` + YLabel string `json:"y_label,omitempty"` + XUnit string `json:"x_unit,omitempty"` + YUnit string `json:"y_unit,omitempty"` +} + +// PointData is one series in a point based chart (line/scatter). +type PointData struct { + Label string `json:"label"` + Points [][2]interface{} `json:"points"` +} + +// PointChart is the base for line/scatter. +type PointChart struct { + Chart2D + XTicks []interface{} `json:"x_ticks"` + XTickLabels []string `json:"x_tick_labels"` + XScale ScaleType `json:"x_scale"` + YTicks []interface{} `json:"y_ticks"` + YTickLabels []string `json:"y_tick_labels"` + YScale ScaleType `json:"y_scale"` + Points []PointData `json:"-"` +} + +// LineChart represents a line chart. +type LineChart struct { + PointChart +} + +// ScatterChart represents a scatter chart. +type ScatterChart struct { + PointChart +} + +// BarData represents a single bar in a bar chart. +type BarData struct { + Label string `json:"label"` + Value string `json:"value"` + Group string `json:"group"` +} + +// BarChart represents a bar chart. +type BarChart struct { + Chart2D + Bars []BarData `json:"-"` +} + +// PieData represents a slice of a pie chart. +type PieData struct { + Label string `json:"label"` + Angle float64 `json:"angle"` + Radius float64 `json:"radius"` +} + +// PieChart represents a pie chart. +type PieChart struct { + BaseChart + Slices []PieData `json:"-"` +} + +// BoxAndWhiskerData represents one box-and-whisker series. +type BoxAndWhiskerData struct { + Label string `json:"label"` + Min float64 `json:"min"` + FirstQuartile float64 `json:"first_quartile"` + Median float64 `json:"median"` + ThirdQuartile float64 `json:"third_quartile"` + Max float64 `json:"max"` + Outliers []float64 `json:"outliers"` +} + +// BoxAndWhiskerChart represents a box-and-whisker chart. +type BoxAndWhiskerChart struct { + Chart2D + Boxes []BoxAndWhiskerData `json:"-"` +} + +// SuperChart is a composite chart containing multiple sub-charts. +type SuperChart struct { + BaseChart + Charts []Chart `json:"-"` +} + +// UnknownChart is used when the server returns a chart type that the SDK does +// not yet understand; all data is still accessible through ToDict(). +type UnknownChart struct { + BaseChart +} + +// deserializeChart converts the raw JSON payload coming from the server into +// the matching Chart implementation. +func deserializeChart(data map[string]interface{}) Chart { + if data == nil { + return nil + } + + typeStr, _ := data["type"].(string) + ct := ChartType(typeStr) + + base := BaseChart{ + Type: ct, + Title: getString(data, "title"), + raw: data, + } + if el, ok := data["elements"].([]interface{}); ok { + base.Elements = el + } + + switch ct { + case ChartTypeLine: + pc := buildPointChart(base, data) + return &LineChart{PointChart: pc} + case ChartTypeScatter: + pc := buildPointChart(base, data) + return &ScatterChart{PointChart: pc} + case ChartTypeBar: + c2 := Chart2D{ + BaseChart: base, + XLabel: getString(data, "x_label"), + YLabel: getString(data, "y_label"), + XUnit: getString(data, "x_unit"), + YUnit: getString(data, "y_unit"), + } + bc := &BarChart{Chart2D: c2} + bc.Type = ChartTypeBar + if raw, ok := data["elements"].([]interface{}); ok { + for _, item := range raw { + m, ok := item.(map[string]interface{}) + if !ok { + continue + } + bc.Bars = append(bc.Bars, BarData{ + Label: getString(m, "label"), + Value: getString(m, "value"), + Group: getString(m, "group"), + }) + } + } + return bc + case ChartTypePie: + pc := &PieChart{BaseChart: base} + pc.Type = ChartTypePie + if raw, ok := data["elements"].([]interface{}); ok { + for _, item := range raw { + m, ok := item.(map[string]interface{}) + if !ok { + continue + } + pc.Slices = append(pc.Slices, PieData{ + Label: getString(m, "label"), + Angle: getFloat(m, "angle"), + Radius: getFloat(m, "radius"), + }) + } + } + return pc + case ChartTypeBoxAndWhisker: + c2 := Chart2D{ + BaseChart: base, + XLabel: getString(data, "x_label"), + YLabel: getString(data, "y_label"), + XUnit: getString(data, "x_unit"), + YUnit: getString(data, "y_unit"), + } + bwc := &BoxAndWhiskerChart{Chart2D: c2} + bwc.Type = ChartTypeBoxAndWhisker + if raw, ok := data["elements"].([]interface{}); ok { + for _, item := range raw { + m, ok := item.(map[string]interface{}) + if !ok { + continue + } + bwc.Boxes = append(bwc.Boxes, BoxAndWhiskerData{ + Label: getString(m, "label"), + Min: getFloat(m, "min"), + FirstQuartile: getFloat(m, "first_quartile"), + Median: getFloat(m, "median"), + ThirdQuartile: getFloat(m, "third_quartile"), + Max: getFloat(m, "max"), + Outliers: getFloatSlice(m, "outliers"), + }) + } + } + return bwc + case ChartTypeSuperChart: + sc := &SuperChart{BaseChart: base} + sc.Type = ChartTypeSuperChart + if raw, ok := data["elements"].([]interface{}); ok { + for _, item := range raw { + if m, ok := item.(map[string]interface{}); ok { + sc.Charts = append(sc.Charts, deserializeChart(m)) + } + } + } + return sc + default: + base.Type = ChartTypeUnknown + return &UnknownChart{BaseChart: base} + } +} + +func buildPointChart(base BaseChart, data map[string]interface{}) PointChart { + c2 := Chart2D{ + BaseChart: base, + XLabel: getString(data, "x_label"), + YLabel: getString(data, "y_label"), + XUnit: getString(data, "x_unit"), + YUnit: getString(data, "y_unit"), + } + + pc := PointChart{ + Chart2D: c2, + XTicks: getInterfaceSlice(data, "x_ticks"), + XTickLabels: getStringSlice(data, "x_tick_labels"), + XScale: ScaleType(getString(data, "x_scale")), + YTicks: getInterfaceSlice(data, "y_ticks"), + YTickLabels: getStringSlice(data, "y_tick_labels"), + YScale: ScaleType(getString(data, "y_scale")), + } + + if raw, ok := data["elements"].([]interface{}); ok { + for _, item := range raw { + m, ok := item.(map[string]interface{}) + if !ok { + continue + } + pd := PointData{Label: getString(m, "label")} + if pts, ok := m["points"].([]interface{}); ok { + for _, p := range pts { + arr, ok := p.([]interface{}) + if !ok || len(arr) < 2 { + continue + } + pd.Points = append(pd.Points, [2]interface{}{arr[0], arr[1]}) + } + } + pc.Points = append(pc.Points, pd) + } + } + return pc +} + +func getString(m map[string]interface{}, key string) string { + if v, ok := m[key]; ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +func getFloat(m map[string]interface{}, key string) float64 { + if v, ok := m[key]; ok { + switch n := v.(type) { + case float64: + return n + case float32: + return float64(n) + case int: + return float64(n) + case int64: + return float64(n) + case json.Number: + f, _ := n.Float64() + return f + } + } + return 0 +} + +func getFloatSlice(m map[string]interface{}, key string) []float64 { + v, ok := m[key].([]interface{}) + if !ok { + return nil + } + out := make([]float64, 0, len(v)) + for _, item := range v { + switch n := item.(type) { + case float64: + out = append(out, n) + case int: + out = append(out, float64(n)) + case json.Number: + f, _ := n.Float64() + out = append(out, f) + } + } + return out +} + +func getStringSlice(m map[string]interface{}, key string) []string { + v, ok := m[key].([]interface{}) + if !ok { + return nil + } + out := make([]string, 0, len(v)) + for _, item := range v { + if s, ok := item.(string); ok { + out = append(out, s) + } + } + return out +} + +func getInterfaceSlice(m map[string]interface{}, key string) []interface{} { + v, ok := m[key].([]interface{}) + if !ok { + return nil + } + return v +} diff --git a/go/charts_test.go b/go/charts_test.go new file mode 100644 index 00000000..b9dcd862 --- /dev/null +++ b/go/charts_test.go @@ -0,0 +1,389 @@ +package codeinterpreter + +import ( + "encoding/json" + "testing" +) + +func mustUnmarshal(t *testing.T, raw string) map[string]interface{} { + t.Helper() + var m map[string]interface{} + if err := json.Unmarshal([]byte(raw), &m); err != nil { + t.Fatalf("unmarshal: %v", err) + } + return m +} + +func TestDeserializeChart_Bar(t *testing.T) { + raw := `{ + "type": "bar", + "title": "Book Sales by Authors", + "x_label": "Authors", + "y_label": "Number of Books Sold", + "elements": [ + {"label": "Author A", "value": "100", "group": "Books Sold"}, + {"label": "Author B", "value": "200", "group": "Books Sold"}, + {"label": "Author C", "value": "300", "group": "Books Sold"}, + {"label": "Author D", "value": "400", "group": "Books Sold"} + ] + }` + c := deserializeChart(mustUnmarshal(t, raw)) + bc, ok := c.(*BarChart) + if !ok { + t.Fatalf("expected *BarChart, got %T", c) + } + if bc.Type != ChartTypeBar { + t.Errorf("unexpected type: %s", bc.Type) + } + if bc.Title != "Book Sales by Authors" { + t.Errorf("unexpected title: %q", bc.Title) + } + if bc.XLabel != "Authors" || bc.YLabel != "Number of Books Sold" { + t.Errorf("labels not parsed: x=%q y=%q", bc.XLabel, bc.YLabel) + } + if bc.XUnit != "" || bc.YUnit != "" { + t.Errorf("expected empty units, got x=%q y=%q", bc.XUnit, bc.YUnit) + } + if len(bc.Bars) != 4 { + t.Fatalf("expected 4 bars, got %d", len(bc.Bars)) + } + + labels := []string{} + values := []string{} + groups := []string{} + for _, b := range bc.Bars { + labels = append(labels, b.Label) + values = append(values, b.Value) + groups = append(groups, b.Group) + } + want := []string{"Author A", "Author B", "Author C", "Author D"} + for i := range want { + if labels[i] != want[i] { + t.Errorf("label[%d] got %q want %q", i, labels[i], want[i]) + } + } + for i, v := range []string{"100", "200", "300", "400"} { + if values[i] != v { + t.Errorf("value[%d] got %q want %q", i, values[i], v) + } + } + for _, g := range groups { + if g != "Books Sold" { + t.Errorf("unexpected group: %q", g) + } + } +} + +func TestDeserializeChart_PieDetailed(t *testing.T) { + raw := `{ + "type": "pie", + "title": "Will I wake up early tomorrow?", + "elements": [ + {"label": "No", "angle": 324, "radius": 1}, + {"label": "No, in blue", "angle": 36, "radius": 1} + ] + }` + c := deserializeChart(mustUnmarshal(t, raw)) + pc, ok := c.(*PieChart) + if !ok { + t.Fatalf("expected *PieChart, got %T", c) + } + if pc.Title != "Will I wake up early tomorrow?" { + t.Errorf("unexpected title: %q", pc.Title) + } + if len(pc.Slices) != 2 { + t.Fatalf("expected 2 slices, got %d", len(pc.Slices)) + } + if pc.Slices[0].Label != "No" || pc.Slices[0].Angle != 324 || pc.Slices[0].Radius != 1 { + t.Errorf("first slice wrong: %+v", pc.Slices[0]) + } + if pc.Slices[1].Label != "No, in blue" || pc.Slices[1].Angle != 36 { + t.Errorf("second slice wrong: %+v", pc.Slices[1]) + } +} + +func TestDeserializeChart_Scatter(t *testing.T) { + raw := `{ + "type": "scatter", + "title": null, + "x_label": "A", + "y_label": "B", + "x_scale": "linear", + "y_scale": "linear", + "x_ticks": [0.1, 0.2, 0.3], + "y_ticks": [0.5, 0.6, 0.7], + "x_tick_labels": ["0.1", "0.2", "0.3"], + "y_tick_labels": ["0.5", "0.6", "0.7"], + "elements": [ + {"label": "Dataset 1", "points": [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]}, + {"label": "Dataset 2", "points": [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9], [10, 10]]} + ] + }` + + c := deserializeChart(mustUnmarshal(t, raw)) + sc, ok := c.(*ScatterChart) + if !ok { + t.Fatalf("expected *ScatterChart, got %T", c) + } + if sc.Title != "" { + t.Errorf("expected empty title, got %q", sc.Title) + } + if sc.XLabel != "A" || sc.YLabel != "B" { + t.Errorf("labels: %+v", sc) + } + if sc.XScale != ScaleTypeLinear || sc.YScale != ScaleTypeLinear { + t.Errorf("scale wrong: x=%s y=%s", sc.XScale, sc.YScale) + } + if len(sc.Points) != 2 { + t.Fatalf("expected 2 datasets, got %d", len(sc.Points)) + } + if sc.Points[0].Label != "Dataset 1" || len(sc.Points[0].Points) != 5 { + t.Errorf("dataset 1 wrong: %+v", sc.Points[0]) + } + if sc.Points[1].Label != "Dataset 2" || len(sc.Points[1].Points) != 10 { + t.Errorf("dataset 2 wrong: %+v", sc.Points[1]) + } +} + +func TestDeserializeChart_BoxAndWhisker(t *testing.T) { + raw := `{ + "type": "box_and_whisker", + "title": "Exam Scores Distribution", + "x_label": "Class", + "y_label": "Score", + "elements": [ + {"label": "Class A", "min": 78, "first_quartile": 85, "median": 88, "third_quartile": 90, "max": 92, "outliers": []}, + {"label": "Class B", "min": 84, "first_quartile": 84.75, "median": 88, "third_quartile": 90.5, "max": 95, "outliers": [76]}, + {"label": "Class C", "min": 75, "first_quartile": 79, "median": 82, "third_quartile": 86, "max": 88, "outliers": []} + ] + }` + c := deserializeChart(mustUnmarshal(t, raw)) + bw, ok := c.(*BoxAndWhiskerChart) + if !ok { + t.Fatalf("expected *BoxAndWhiskerChart, got %T", c) + } + if bw.Type != ChartTypeBoxAndWhisker { + t.Errorf("unexpected type: %s", bw.Type) + } + if bw.XLabel != "Class" || bw.YLabel != "Score" { + t.Errorf("labels: %+v", bw) + } + if bw.XUnit != "" || bw.YUnit != "" { + t.Errorf("expected empty units") + } + if len(bw.Boxes) != 3 { + t.Fatalf("expected 3 boxes, got %d", len(bw.Boxes)) + } + cases := []struct { + label string + min, fq, median, tq, max float64 + outlierLen int + }{ + {"Class A", 78, 85, 88, 90, 92, 0}, + {"Class B", 84, 84.75, 88, 90.5, 95, 1}, + {"Class C", 75, 79, 82, 86, 88, 0}, + } + for i, c := range cases { + b := bw.Boxes[i] + if b.Label != c.label || b.Min != c.min || b.FirstQuartile != c.fq || + b.Median != c.median || b.ThirdQuartile != c.tq || b.Max != c.max || + len(b.Outliers) != c.outlierLen { + t.Errorf("box[%d] mismatch: %+v (want %+v)", i, b, c) + } + } + if bw.Boxes[1].Outliers[0] != 76 { + t.Errorf("outlier got %v, want 76", bw.Boxes[1].Outliers[0]) + } +} + +func TestDeserializeChart_Super(t *testing.T) { + raw := `{ + "type": "superchart", + "title": "Multiple Charts Example", + "elements": [ + { + "type": "line", + "title": "Sine Wave", + "elements": [ + {"label": "sin", "points": [[0, 0], [1, 0.5]]} + ] + }, + { + "type": "scatter", + "title": "Scatter Plot", + "x_label": "X", + "y_label": "Y", + "elements": [ + {"label": "Dataset 1", "points": [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]]} + ] + } + ] + }` + c := deserializeChart(mustUnmarshal(t, raw)) + sc, ok := c.(*SuperChart) + if !ok { + t.Fatalf("expected *SuperChart, got %T", c) + } + if sc.Type != ChartTypeSuperChart { + t.Errorf("type: %s", sc.Type) + } + if sc.Title != "Multiple Charts Example" { + t.Errorf("title: %q", sc.Title) + } + if len(sc.Charts) != 2 { + t.Fatalf("expected 2 sub-charts, got %d", len(sc.Charts)) + } + lc, ok := sc.Charts[0].(*LineChart) + if !ok { + t.Fatalf("first sub-chart not line: %T", sc.Charts[0]) + } + if lc.Title != "Sine Wave" || lc.XLabel != "" { + t.Errorf("line title/label: %+v", lc) + } + if len(lc.Points) != 1 || len(lc.Points[0].Points) != 2 { + t.Errorf("line points wrong: %+v", lc.Points) + } + sct, ok := sc.Charts[1].(*ScatterChart) + if !ok { + t.Fatalf("second sub-chart not scatter: %T", sc.Charts[1]) + } + if sct.XLabel != "X" || sct.YLabel != "Y" { + t.Errorf("scatter labels: %+v", sct) + } + if len(sct.Points) != 1 || len(sct.Points[0].Points) != 5 { + t.Errorf("scatter points wrong: %+v", sct.Points) + } +} + +func TestDeserializeChart_LogScale(t *testing.T) { + raw := `{ + "type": "line", + "title": "Chart with Log Scale on Y-axis", + "x_label": "X-axis", + "y_label": "Y-axis (log scale)", + "y_unit": "log scale", + "x_scale": "linear", + "y_scale": "log", + "elements": [{"label": "y = e^x", "points": [[1, 2.7], [2, 7.3]]}] + }` + c := deserializeChart(mustUnmarshal(t, raw)) + lc, ok := c.(*LineChart) + if !ok { + t.Fatalf("expected *LineChart, got %T", c) + } + if lc.YScale != ScaleTypeLog || lc.XScale != ScaleTypeLinear { + t.Errorf("scales: x=%s y=%s", lc.XScale, lc.YScale) + } + if lc.YUnit != "log scale" { + t.Errorf("y unit: %q", lc.YUnit) + } +} + +func TestDeserializeChart_DatetimeScale(t *testing.T) { + raw := `{ + "type": "line", + "title": "T", + "x_scale": "datetime", + "y_scale": "linear", + "elements": [] + }` + c := deserializeChart(mustUnmarshal(t, raw)) + lc, ok := c.(*LineChart) + if !ok { + t.Fatalf("expected *LineChart") + } + if lc.XScale != ScaleTypeDatetime { + t.Errorf("want datetime, got %s", lc.XScale) + } +} + +func TestDeserializeChart_CategoricalScale(t *testing.T) { + raw := `{ + "type": "line", + "title": "T", + "x_scale": "linear", + "y_scale": "categorical", + "elements": [] + }` + c := deserializeChart(mustUnmarshal(t, raw)) + lc, ok := c.(*LineChart) + if !ok { + t.Fatalf("expected *LineChart") + } + if lc.YScale != ScaleTypeCategorical { + t.Errorf("want categorical, got %s", lc.YScale) + } +} + +func TestChart_JSONSerialization(t *testing.T) { + raw := `{ + "type": "scatter", + "title": "S", + "elements": [{"label": "Dataset", "points": [[0.1, 0.2]]}] + }` + result := newResultFromRaw(map[string]interface{}{ + "type": "result", + "chart": mustUnmarshal(t, raw), + "is_main_result": true, + }) + if result.Chart == nil { + t.Fatal("expected chart to be set") + } + if result.Chart.ChartType() != ChartTypeScatter { + t.Errorf("type: %s", result.Chart.ChartType()) + } + + exec := NewExecution() + exec.Results = append(exec.Results, result) + exec.Results[0].IsMainResult = true + + serialized, err := exec.ToJSON() + if err != nil { + t.Fatalf("ToJSON: %v", err) + } + if len(serialized) == 0 { + t.Error("expected non-empty serialization") + } + + // ToJSON on the chart must round-trip through json. + dict := result.Chart.ToJSON() + b, err := json.Marshal(dict) + if err != nil { + t.Fatalf("marshal chart dict: %v", err) + } + var back map[string]interface{} + if err := json.Unmarshal(b, &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back["type"] != "scatter" { + t.Errorf("expected scatter in roundtrip, got %v", back["type"]) + } +} + +func TestDeserializeChart_UnknownKeepsElements(t *testing.T) { + raw := `{ + "type": "weird", + "title": "Two Concentric Circles", + "elements": [] + }` + c := deserializeChart(mustUnmarshal(t, raw)) + if c.ChartType() != ChartTypeUnknown { + t.Errorf("expected unknown") + } + if c.ChartTitle() != "Two Concentric Circles" { + t.Errorf("title: %q", c.ChartTitle()) + } + uc, ok := c.(*UnknownChart) + if !ok { + t.Fatalf("expected *UnknownChart, got %T", c) + } + if len(uc.Elements) != 0 { + t.Errorf("expected 0 elements, got %d", len(uc.Elements)) + } +} + +func TestDeserializeChart_NilInput(t *testing.T) { + if deserializeChart(nil) != nil { + t.Error("deserializeChart(nil) should return nil") + } +} diff --git a/go/client.go b/go/client.go new file mode 100644 index 00000000..4cdd282f --- /dev/null +++ b/go/client.go @@ -0,0 +1,153 @@ +package codeinterpreter + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" +) + +// ConnectionConfig holds the configuration needed to talk to the E2B API and +// to an individual sandbox. +type ConnectionConfig struct { + // APIKey is the E2B API key. If empty the E2B_API_KEY environment variable + // is used. + APIKey string + // AccessToken is an optional envd access token used to authenticate direct + // requests to the sandbox. + AccessToken string + // TrafficAccessToken is an optional token used to bypass traffic controls. + TrafficAccessToken string + // Domain is the e2b domain (default: e2b.app). The E2B_DOMAIN env var is + // used when empty. + Domain string + // Debug turns on debug-mode. When true, the SDK talks over http:// and + // uses the sandbox host unchanged (useful for local development). + Debug bool + // RequestTimeout is the default HTTP request timeout. + RequestTimeout time.Duration + // HTTPClient is the underlying client used for all HTTP traffic. If nil a + // sensible default is created. + HTTPClient *http.Client + // Headers lets callers inject additional headers on every request. + Headers map[string]string +} + +func (c *ConnectionConfig) init() { + if c.APIKey == "" { + c.APIKey = os.Getenv("E2B_API_KEY") + } + if c.AccessToken == "" { + c.AccessToken = os.Getenv("E2B_ACCESS_TOKEN") + } + if c.Domain == "" { + if d := os.Getenv("E2B_DOMAIN"); d != "" { + c.Domain = d + } else { + c.Domain = DefaultDomain + } + } + if !c.Debug { + if v := os.Getenv("E2B_DEBUG"); v == "1" || strings.EqualFold(v, "true") { + c.Debug = true + } + } + if c.RequestTimeout == 0 { + c.RequestTimeout = DefaultRequestTimeout * time.Second + } + if c.HTTPClient == nil { + c.HTTPClient = &http.Client{} + } +} + +// APIBase returns the base URL for the E2B management API. +func (c *ConnectionConfig) APIBase() string { + scheme := "https" + if c.Debug { + scheme = "http" + } + return fmt.Sprintf("%s://api.%s", scheme, c.Domain) +} + +// do is a low level helper that performs an HTTP request against the E2B API +// and decodes the JSON response into `out` (if non-nil). It returns a typed +// error on non-2xx responses. +func (c *ConnectionConfig) do(ctx context.Context, method, path string, body interface{}, out interface{}) error { + var reader io.Reader + if body != nil { + b, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("marshal request body: %w", err) + } + reader = bytes.NewReader(b) + } + + u := c.APIBase() + path + req, err := http.NewRequestWithContext(ctx, method, u, reader) + if err != nil { + return err + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + req.Header.Set("Accept", "application/json") + if c.APIKey != "" { + req.Header.Set("X-API-Key", c.APIKey) + } + if c.AccessToken != "" { + req.Header.Set("X-Access-Token", c.AccessToken) + } + for k, v := range c.Headers { + req.Header.Set(k, v) + } + + client := c.HTTPClient + if client.Timeout == 0 && c.RequestTimeout > 0 { + client = &http.Client{ + Timeout: c.RequestTimeout, + Transport: client.Transport, + } + } + + resp, err := client.Do(req) + if err != nil { + var netErr interface{ Timeout() bool } + if errors.As(err, &netErr) && netErr.Timeout() { + return formatRequestTimeoutError() + } + return err + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(resp.Body) + return mapHTTPError(resp.StatusCode, string(body)) + } + if out == nil { + return nil + } + return json.NewDecoder(resp.Body).Decode(out) +} + +// mapHTTPError translates an HTTP status code into the proper SDK error type. +func mapHTTPError(status int, body string) error { + msg := strings.TrimSpace(body) + switch status { + case http.StatusNotFound: + return &NotFoundError{Message: msg} + case http.StatusUnauthorized, http.StatusForbidden: + return &AuthenticationError{Message: msg} + case http.StatusTooManyRequests: + return &RateLimitError{Message: msg} + case http.StatusBadGateway, http.StatusGatewayTimeout: + return &TimeoutError{Message: msg + ": This error is likely due to sandbox timeout. You can modify the sandbox timeout by passing 'Timeout' when starting the sandbox or by calling 'SetTimeout' on the sandbox with the desired timeout."} + default: + return &SandboxError{StatusCode: status, Message: msg} + } +} diff --git a/go/code_interpreter.go b/go/code_interpreter.go new file mode 100644 index 00000000..60a7ef71 --- /dev/null +++ b/go/code_interpreter.go @@ -0,0 +1,380 @@ +package codeinterpreter + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "time" +) + +// RunCodeOpts holds options for Sandbox.RunCode. +// +// Either Language or Context may be supplied — not both. When both are empty +// the default Python context is used. +type RunCodeOpts struct { + // Language to use. Must not be combined with Context. + Language RunCodeLanguage + // Context (pre-created kernel) to run the code in. + Context *Context + // OnStdout is called for every stdout chunk. + OnStdout OnStdoutFunc + // OnStderr is called for every stderr chunk. + OnStderr OnStderrFunc + // OnResult is called for every Result (display call or final result). + OnResult OnResultFunc + // OnError is called when the kernel reports an error for this cell. + OnError OnErrorFunc + // Envs are extra environment variables exposed to the running code. + Envs map[string]string + // Timeout is the maximum execution time for this cell (default: 300s). + // Pass -1 to disable. + Timeout time.Duration + // RequestTimeout is the HTTP-level request timeout. + RequestTimeout time.Duration +} + +// CreateCodeContextOpts holds options for Sandbox.CreateCodeContext. +type CreateCodeContextOpts struct { + // Cwd is the working directory for the context (default /home/user). + Cwd string + // Language of the new context (default python). + Language RunCodeLanguage + // RequestTimeout overrides the default HTTP request timeout. + RequestTimeout time.Duration +} + +// RunCode executes the supplied code in the sandbox and returns the full +// Execution result. +// +// Streaming output is forwarded to the callbacks on opts in real time. +func (s *Sandbox) RunCode(ctx context.Context, code string, opts *RunCodeOpts) (*Execution, error) { + if opts == nil { + opts = &RunCodeOpts{} + } + + if opts.Language != "" && opts.Context != nil { + return nil, &InvalidArgumentError{ + Message: "You can provide context or language, but not both at the same time.", + } + } + + // Build request body + payload := map[string]interface{}{ + "code": code, + } + if opts.Context != nil { + payload["context_id"] = opts.Context.ID + } + if opts.Language != "" { + payload["language"] = string(opts.Language) + } + if len(opts.Envs) > 0 { + payload["env_vars"] = opts.Envs + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("marshal execute payload: %w", err) + } + + // Build request with its own context to honour Timeout/RequestTimeout. + timeout := opts.Timeout + if timeout == 0 { + timeout = DefaultTimeout * time.Second + } + + reqCtx := ctx + var cancel context.CancelFunc + if timeout > 0 { + reqCtx, cancel = context.WithTimeout(ctx, timeout+s.connection.RequestTimeout) + defer cancel() + } + + u := s.jupyterURL() + "/execute" + req, err := http.NewRequestWithContext(reqCtx, "POST", u, bytes.NewReader(body)) + if err != nil { + return nil, err + } + s.addAuthHeaders(req.Header) + + resp, err := s.connection.HTTPClient.Do(req) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return nil, formatExecutionTimeoutError() + } + var ne interface{ Timeout() bool } + if errors.As(err, &ne) && ne.Timeout() { + return nil, formatRequestTimeoutError() + } + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + b, _ := io.ReadAll(resp.Body) + return nil, mapHTTPError(resp.StatusCode, string(b)) + } + + execution := NewExecution() + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 16*1024*1024) + + for scanner.Scan() { + line := scanner.Text() + if line == "" { + continue + } + if err := parseOutputLine(execution, line, opts); err != nil { + return nil, err + } + } + if err := scanner.Err(); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return nil, formatExecutionTimeoutError() + } + return nil, err + } + + return execution, nil +} + +// parseOutputLine decodes a single NDJSON line emitted by /execute and +// dispatches it to the correct handler. +func parseOutputLine(execution *Execution, line string, opts *RunCodeOpts) error { + var msg map[string]interface{} + if err := json.Unmarshal([]byte(line), &msg); err != nil { + // Ignore un-parseable lines — they shouldn't happen but we don't + // want to blow up a long running execution over them. + return nil + } + + t, _ := msg["type"].(string) + switch t { + case "result": + // The "type" / "is_main_result" fields are not relevant as keys on the + // resulting Result object; newResultFromRaw handles the split. + result := newResultFromRaw(msg) + execution.Results = append(execution.Results, result) + if opts.OnResult != nil { + opts.OnResult(result) + } + case "stdout": + text := getString(msg, "text") + execution.Logs.Stdout = append(execution.Logs.Stdout, text) + if opts.OnStdout != nil { + opts.OnStdout(OutputMessage{ + Line: text, + Timestamp: getInt64(msg, "timestamp"), + Error: false, + }) + } + case "stderr": + text := getString(msg, "text") + execution.Logs.Stderr = append(execution.Logs.Stderr, text) + if opts.OnStderr != nil { + opts.OnStderr(OutputMessage{ + Line: text, + Timestamp: getInt64(msg, "timestamp"), + Error: true, + }) + } + case "error": + execution.Error = &ExecutionError{ + Name: getString(msg, "name"), + Value: getString(msg, "value"), + Traceback: getString(msg, "traceback"), + } + if opts.OnError != nil { + opts.OnError(execution.Error) + } + case "number_of_executions": + if v, ok := msg["execution_count"]; ok { + switch n := v.(type) { + case float64: + execution.ExecutionCount = int(n) + case int: + execution.ExecutionCount = n + case json.Number: + i, _ := n.Int64() + execution.ExecutionCount = int(i) + } + } + } + return nil +} + +func getInt64(m map[string]interface{}, key string) int64 { + if v, ok := m[key]; ok { + switch n := v.(type) { + case float64: + return int64(n) + case int: + return int64(n) + case int64: + return n + case json.Number: + i, _ := n.Int64() + return i + } + } + return 0 +} + +// CreateCodeContext creates a fresh kernel in which subsequent code can be run. +func (s *Sandbox) CreateCodeContext(ctx context.Context, opts *CreateCodeContextOpts) (*Context, error) { + if opts == nil { + opts = &CreateCodeContextOpts{} + } + + body := map[string]interface{}{} + if opts.Language != "" { + body["language"] = string(opts.Language) + } + if opts.Cwd != "" { + body["cwd"] = opts.Cwd + } + + return s.doContextRequest(ctx, "POST", "/contexts", body, opts.RequestTimeout) +} + +// ListCodeContexts lists the contexts currently available in the sandbox. +func (s *Sandbox) ListCodeContexts(ctx context.Context) ([]*Context, error) { + b, status, err := s.jupyterRequest(ctx, "GET", "/contexts", nil, 0) + if err != nil { + return nil, err + } + if status >= 400 { + return nil, mapHTTPError(status, string(b)) + } + + var raw []map[string]interface{} + if err := json.Unmarshal(b, &raw); err != nil { + return nil, fmt.Errorf("decode contexts: %w", err) + } + + contexts := make([]*Context, 0, len(raw)) + for _, data := range raw { + contexts = append(contexts, contextFromJSON(data)) + } + return contexts, nil +} + +// RestartCodeContext restarts the given context. The parameter can either be a +// *Context or a context-id string. +func (s *Sandbox) RestartCodeContext(ctx context.Context, c interface{}) error { + id, err := contextID(c) + if err != nil { + return err + } + _, status, err := s.jupyterRequest(ctx, "POST", "/contexts/"+url.PathEscape(id)+"/restart", nil, 0) + if err != nil { + return err + } + if status >= 400 { + return mapHTTPError(status, "") + } + return nil +} + +// RemoveCodeContext removes the context. The parameter can either be a +// *Context or a context-id string. +func (s *Sandbox) RemoveCodeContext(ctx context.Context, c interface{}) error { + id, err := contextID(c) + if err != nil { + return err + } + _, status, err := s.jupyterRequest(ctx, "DELETE", "/contexts/"+url.PathEscape(id), nil, 0) + if err != nil { + return err + } + if status >= 400 { + return mapHTTPError(status, "") + } + return nil +} + +// doContextRequest is a helper for creating a Context via POST /contexts. +func (s *Sandbox) doContextRequest(ctx context.Context, method, path string, body map[string]interface{}, reqTimeout time.Duration) (*Context, error) { + b, status, err := s.jupyterRequest(ctx, method, path, body, reqTimeout) + if err != nil { + return nil, err + } + if status >= 400 { + return nil, mapHTTPError(status, string(b)) + } + var raw map[string]interface{} + if err := json.Unmarshal(b, &raw); err != nil { + return nil, fmt.Errorf("decode context response: %w", err) + } + return contextFromJSON(raw), nil +} + +// jupyterRequest performs an HTTP request against the sandbox jupyter server +// and returns the raw body, status code and any transport-level error. +func (s *Sandbox) jupyterRequest(ctx context.Context, method, path string, body interface{}, reqTimeout time.Duration) ([]byte, int, error) { + var reader io.Reader + if body != nil { + b, err := json.Marshal(body) + if err != nil { + return nil, 0, err + } + reader = bytes.NewReader(b) + } + + if reqTimeout == 0 { + reqTimeout = s.connection.RequestTimeout + } + + reqCtx := ctx + var cancel context.CancelFunc + if reqTimeout > 0 { + reqCtx, cancel = context.WithTimeout(ctx, reqTimeout) + defer cancel() + } + + u := s.jupyterURL() + path + req, err := http.NewRequestWithContext(reqCtx, method, u, reader) + if err != nil { + return nil, 0, err + } + s.addAuthHeaders(req.Header) + + resp, err := s.connection.HTTPClient.Do(req) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return nil, 0, formatRequestTimeoutError() + } + return nil, 0, err + } + defer resp.Body.Close() + + out, err := io.ReadAll(resp.Body) + if err != nil { + return nil, resp.StatusCode, err + } + return out, resp.StatusCode, nil +} + +// contextID returns the string ID from either a *Context or a plain string. +func contextID(c interface{}) (string, error) { + switch v := c.(type) { + case string: + return v, nil + case *Context: + if v == nil { + return "", &InvalidArgumentError{Message: "context is nil"} + } + return v.ID, nil + case Context: + return v.ID, nil + default: + return "", &InvalidArgumentError{Message: fmt.Sprintf("unsupported context type %T", c)} + } +} diff --git a/go/code_interpreter_test.go b/go/code_interpreter_test.go new file mode 100644 index 00000000..44658d11 --- /dev/null +++ b/go/code_interpreter_test.go @@ -0,0 +1,228 @@ +package codeinterpreter + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestDeserializeChart_Line(t *testing.T) { + rawJSON := `{ + "type": "line", + "title": "Sample Line", + "x_label": "Time", + "y_label": "Value", + "x_ticks": [1, 2, 3], + "x_tick_labels": ["a", "b", "c"], + "x_scale": "linear", + "y_ticks": [0.1, 0.2, 0.3], + "y_tick_labels": ["0.1", "0.2", "0.3"], + "y_scale": "log", + "elements": [ + {"label": "series1", "points": [[1, 0.1], [2, 0.2]]} + ] + }` + + var data map[string]interface{} + if err := json.Unmarshal([]byte(rawJSON), &data); err != nil { + t.Fatal(err) + } + + c := deserializeChart(data) + lc, ok := c.(*LineChart) + if !ok { + t.Fatalf("expected *LineChart, got %T", c) + } + if lc.ChartType() != ChartTypeLine { + t.Errorf("expected line, got %s", lc.ChartType()) + } + if lc.Title != "Sample Line" { + t.Errorf("unexpected title: %q", lc.Title) + } + if lc.XLabel != "Time" || lc.YLabel != "Value" { + t.Errorf("labels not parsed") + } + if lc.YScale != ScaleTypeLog { + t.Errorf("expected log y-scale, got %s", lc.YScale) + } + if len(lc.Points) != 1 || lc.Points[0].Label != "series1" || len(lc.Points[0].Points) != 2 { + t.Errorf("points not parsed correctly: %+v", lc.Points) + } +} + +func TestDeserializeChart_Pie(t *testing.T) { + raw := `{ + "type": "pie", + "title": "Pie!", + "elements": [ + {"label": "a", "angle": 1.5, "radius": 1}, + {"label": "b", "angle": 2.0, "radius": 1} + ] + }` + var data map[string]interface{} + _ = json.Unmarshal([]byte(raw), &data) + + c := deserializeChart(data) + pc, ok := c.(*PieChart) + if !ok { + t.Fatalf("expected *PieChart, got %T", c) + } + if len(pc.Slices) != 2 { + t.Errorf("expected 2 slices, got %d", len(pc.Slices)) + } + if pc.Slices[0].Label != "a" || pc.Slices[0].Angle != 1.5 { + t.Errorf("slice mismatch: %+v", pc.Slices[0]) + } +} + +func TestDeserializeChart_Unknown(t *testing.T) { + raw := `{"type": "weird_chart", "title": "?", "elements": []}` + var data map[string]interface{} + _ = json.Unmarshal([]byte(raw), &data) + + c := deserializeChart(data) + if c.ChartType() != ChartTypeUnknown { + t.Errorf("expected unknown, got %s", c.ChartType()) + } +} + +func TestParseOutput_StreamFlow(t *testing.T) { + lines := []string{ + `{"type": "stdout", "text": "hello\n", "timestamp": 1}`, + `{"type": "stderr", "text": "oops\n", "timestamp": 2}`, + `{"type": "result", "text": "42", "is_main_result": true}`, + `{"type": "number_of_executions", "execution_count": 3}`, + } + + execution := NewExecution() + stdoutN := 0 + stderrN := 0 + resultN := 0 + + opts := &RunCodeOpts{ + OnStdout: func(msg OutputMessage) { stdoutN++ }, + OnStderr: func(msg OutputMessage) { stderrN++ }, + OnResult: func(r *Result) { resultN++ }, + } + + for _, line := range lines { + if err := parseOutputLine(execution, line, opts); err != nil { + t.Fatal(err) + } + } + + if stdoutN != 1 || stderrN != 1 || resultN != 1 { + t.Errorf("unexpected callback counts: stdout=%d stderr=%d result=%d", stdoutN, stderrN, resultN) + } + if execution.ExecutionCount != 3 { + t.Errorf("expected execution count 3, got %d", execution.ExecutionCount) + } + if execution.Text() != "42" { + t.Errorf("expected text %q, got %q", "42", execution.Text()) + } + if !strings.Contains(strings.Join(execution.Logs.Stdout, ""), "hello") { + t.Errorf("stdout not captured: %v", execution.Logs.Stdout) + } +} + +func TestParseOutput_Error(t *testing.T) { + execution := NewExecution() + line := `{"type": "error", "name": "NameError", "value": "x is not defined", "traceback": "..."}` + var called bool + opts := &RunCodeOpts{ + OnError: func(err *ExecutionError) { called = true }, + } + if err := parseOutputLine(execution, line, opts); err != nil { + t.Fatal(err) + } + if !called { + t.Error("OnError not called") + } + if execution.Error == nil || execution.Error.Name != "NameError" { + t.Errorf("error not captured: %+v", execution.Error) + } +} + +func TestResult_Formats(t *testing.T) { + raw := map[string]interface{}{ + "type": "result", + "text": "x", + "html": "", + "png": "base64...", + "is_main_result": true, + "custom_mime": "something", + } + r := newResultFromRaw(raw) + formats := r.Formats() + + want := map[string]bool{ + "text": true, "html": true, "png": true, "custom_mime": true, + } + for k := range want { + found := false + for _, f := range formats { + if f == k { + found = true + break + } + } + if !found { + t.Errorf("expected %q in formats, got %v", k, formats) + } + } + + if !r.IsMainResult { + t.Error("expected IsMainResult to be true") + } + if r.Extra["custom_mime"] != "something" { + t.Errorf("custom mime not captured in Extra") + } +} + +func TestContextFromJSON(t *testing.T) { + m := map[string]interface{}{"id": "ctx-1", "language": "python", "cwd": "/tmp"} + c := contextFromJSON(m) + if c.ID != "ctx-1" || c.Language != "python" || c.Cwd != "/tmp" { + t.Errorf("context mismatch: %+v", c) + } +} + +func TestMapHTTPError(t *testing.T) { + err := mapHTTPError(404, "not found") + if _, ok := err.(*NotFoundError); !ok { + t.Errorf("expected NotFoundError, got %T", err) + } + err = mapHTTPError(401, "bad key") + if _, ok := err.(*AuthenticationError); !ok { + t.Errorf("expected AuthenticationError, got %T", err) + } + err = mapHTTPError(429, "slow down") + if _, ok := err.(*RateLimitError); !ok { + t.Errorf("expected RateLimitError, got %T", err) + } + err = mapHTTPError(502, "timeout") + if _, ok := err.(*TimeoutError); !ok { + t.Errorf("expected TimeoutError, got %T", err) + } + err = mapHTTPError(500, "boom") + if _, ok := err.(*SandboxError); !ok { + t.Errorf("expected SandboxError, got %T", err) + } +} + +func TestContextID(t *testing.T) { + id, err := contextID("abc") + if err != nil || id != "abc" { + t.Errorf("unexpected: id=%s err=%v", id, err) + } + + id, err = contextID(&Context{ID: "x"}) + if err != nil || id != "x" { + t.Errorf("unexpected: id=%s err=%v", id, err) + } + + _, err = contextID(123) + if _, ok := err.(*InvalidArgumentError); !ok { + t.Errorf("expected InvalidArgumentError, got %T", err) + } +} diff --git a/go/constants.go b/go/constants.go new file mode 100644 index 00000000..033ff478 --- /dev/null +++ b/go/constants.go @@ -0,0 +1,21 @@ +package codeinterpreter + +// DefaultTemplate is the default sandbox template used for the code interpreter. +const DefaultTemplate = "code-interpreter-v1" + +// JupyterPort is the internal port on which the Jupyter/Code-Interpreter server +// listens inside the sandbox. +const JupyterPort = 49999 + +// DefaultTimeout is the default timeout for code execution in seconds. +const DefaultTimeout = 300 + +// DefaultSandboxTimeout is the default lifetime of a freshly created sandbox in +// seconds (matches e2b defaults). +const DefaultSandboxTimeout = 300 + +// DefaultRequestTimeout is the default HTTP request timeout in seconds. +const DefaultRequestTimeout = 30 + +// DefaultDomain is the default e2b.dev API domain. +const DefaultDomain = "e2b.app" diff --git a/go/errors.go b/go/errors.go new file mode 100644 index 00000000..75ed7592 --- /dev/null +++ b/go/errors.go @@ -0,0 +1,69 @@ +package codeinterpreter + +import "fmt" + +// SandboxError is the generic error returned by the SDK for unexpected +// server responses. +type SandboxError struct { + Message string + StatusCode int +} + +func (e *SandboxError) Error() string { + if e.StatusCode != 0 { + return fmt.Sprintf("sandbox error (%d): %s", e.StatusCode, e.Message) + } + return fmt.Sprintf("sandbox error: %s", e.Message) +} + +// NotFoundError is returned when a resource (context, sandbox, file) is missing. +type NotFoundError struct { + Message string +} + +func (e *NotFoundError) Error() string { return "not found: " + e.Message } + +// TimeoutError is returned when a request or execution times out. +type TimeoutError struct { + Message string +} + +func (e *TimeoutError) Error() string { return "timeout: " + e.Message } + +// InvalidArgumentError is returned when input parameters are invalid +// (e.g. providing both `context` and `language`). +type InvalidArgumentError struct { + Message string +} + +func (e *InvalidArgumentError) Error() string { return "invalid argument: " + e.Message } + +// AuthenticationError is returned when the supplied API key is invalid or +// missing. +type AuthenticationError struct { + Message string +} + +func (e *AuthenticationError) Error() string { return "authentication error: " + e.Message } + +// RateLimitError is returned when the caller has exceeded the API's rate limit. +type RateLimitError struct { + Message string +} + +func (e *RateLimitError) Error() string { return "rate limit: " + e.Message } + +// formatRequestTimeoutError wraps an error with a friendlier timeout message. +func formatRequestTimeoutError() error { + return &TimeoutError{ + Message: "Request timed out — the 'RequestTimeout' option can be used to increase this timeout", + } +} + +// formatExecutionTimeoutError wraps an error with a friendlier timeout message +// for code execution. +func formatExecutionTimeoutError() error { + return &TimeoutError{ + Message: "Execution timed out — the 'Timeout' option can be used to increase this timeout", + } +} diff --git a/go/example/main.go b/go/example/main.go new file mode 100644 index 00000000..6228cc40 --- /dev/null +++ b/go/example/main.go @@ -0,0 +1,60 @@ +// Example command demonstrating typical use of the Go Code Interpreter SDK. +// +// Usage: +// +// E2B_API_KEY=e2b_... go run ./go/example +package main + +import ( + "context" + "fmt" + "log" + "time" + + codeinterpreter "github.com/e2b-dev/codeinterpreter" +) + +var ( + code = ` +import numpy +from PIL import Image + +imarray = numpy.random.rand(16,16,3) * 255 +image = Image.fromarray(imarray.astype('uint8')).convert('RGBA') + +image.save("test.png") +print("Image saved.")` +) + +func main() { + ctx := context.Background() + + sbx, err := codeinterpreter.Create(ctx, &codeinterpreter.SandboxOpts{ + Timeout: 60 * time.Second, + }) + if err != nil { + log.Fatalf("create sandbox: %v", err) + } + fmt.Println("ℹ️ sandbox created", sbx.SandboxID()) + defer func() { + _ = sbx.Kill(ctx) + fmt.Println("🧹 sandbox killed") + }() + + exec, err := sbx.RunCode(ctx, code, &codeinterpreter.RunCodeOpts{ + OnStdout: func(msg codeinterpreter.OutputMessage) { + fmt.Printf("[stdout] %s", msg.Line) + }, + OnStderr: func(msg codeinterpreter.OutputMessage) { + fmt.Printf("[stderr] %s", msg.Line) + }, + }) + if err != nil { + log.Fatalf("run code: %v", err) + } + + fmt.Println("result:", exec.Results[0].PNG) + if exec.Error != nil { + fmt.Println("error:", exec.Error) + } +} diff --git a/go/example_test.go b/go/example_test.go new file mode 100644 index 00000000..4de7f3ff --- /dev/null +++ b/go/example_test.go @@ -0,0 +1,80 @@ +package codeinterpreter_test + +import ( + "context" + "fmt" + "time" + + codeinterpreter "github.com/e2b-dev/codeinterpreter" +) + +// Example_create shows how to create a sandbox and run code. This is a +// documentation-only example; it requires the E2B_API_KEY env var to work +// against real infrastructure, so the helper below guards the call and simply +// returns without running if the key is missing. +func Example_create() { + ctx := context.Background() + + sbx, err := codeinterpreter.Create(ctx, &codeinterpreter.SandboxOpts{ + Timeout: 60 * time.Second, + }) + if err != nil { + // No API key configured — skip the rest. + return + } + defer sbx.Kill(ctx) + + if _, err := sbx.RunCode(ctx, "x = 1", nil); err != nil { + return + } + exec, err := sbx.RunCode(ctx, "x += 1; x", nil) + if err != nil { + return + } + fmt.Println(exec.Text()) +} + +// Example_streaming demonstrates the streaming callbacks. +func Example_streaming() { + ctx := context.Background() + sbx, err := codeinterpreter.Create(ctx, nil) + if err != nil { + return + } + defer sbx.Kill(ctx) + + _, _ = sbx.RunCode(ctx, "print('hello')\nprint('world')", &codeinterpreter.RunCodeOpts{ + OnStdout: func(msg codeinterpreter.OutputMessage) { + fmt.Printf("stdout: %s", msg.Line) + }, + OnStderr: func(msg codeinterpreter.OutputMessage) { + fmt.Printf("stderr: %s", msg.Line) + }, + }) +} + +// Example_codeContext shows how to use multiple contexts (kernels). +func Example_codeContext() { + ctx := context.Background() + sbx, err := codeinterpreter.Create(ctx, nil) + if err != nil { + return + } + defer sbx.Kill(ctx) + + rctx, err := sbx.CreateCodeContext(ctx, &codeinterpreter.CreateCodeContextOpts{ + Language: codeinterpreter.LanguageR, + }) + if err != nil { + return + } + defer sbx.RemoveCodeContext(ctx, rctx) + + exec, err := sbx.RunCode(ctx, `x <- 1; x + 1`, &codeinterpreter.RunCodeOpts{ + Context: rctx, + }) + if err != nil { + return + } + fmt.Println(exec.Text()) +} diff --git a/go/go.mod b/go/go.mod new file mode 100644 index 00000000..04c97f03 --- /dev/null +++ b/go/go.mod @@ -0,0 +1,3 @@ +module github.com/e2b-dev/codeinterpreter + +go 1.21 diff --git a/go/integration_test.go b/go/integration_test.go new file mode 100644 index 00000000..61f50c9d --- /dev/null +++ b/go/integration_test.go @@ -0,0 +1,526 @@ +package codeinterpreter + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" +) + +func TestRunCode_Basic(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assertJupyterRequest(t, r, "POST", "/execute") + assertExecuteBody(t, r, "x = 1; x", "", "") + w.Header().Set("Content-Type", "application/x-ndjson") + _, _ = io.WriteString(w, ndjson( + `{"type": "result", "text": "1", "is_main_result": true}`, + )) + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + exec, err := sbx.RunCode(context.Background(), "x = 1; x", nil) + if err != nil { + t.Fatalf("run code: %v", err) + } + if got := exec.Text(); got != "1" { + t.Errorf("text: got %q, want %q", got, "1") + } + if len(exec.Results) != 1 { + t.Errorf("expected 1 result, got %d", len(exec.Results)) + } +} + +func TestRunCode_SendsContextID(t *testing.T) { + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&calls, 1) + body := readJSON(t, r) + if body["context_id"] != "ctx-42" { + t.Errorf("expected context_id=ctx-42, got %v", body["context_id"]) + } + if n == 1 { + _, _ = io.WriteString(w, `{"type": "result", "text": "", "is_main_result": true}`) + } else { + _, _ = io.WriteString(w, `{"type": "result", "text": "2", "is_main_result": true}`) + } + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + ctx := context.Background() + ctxObj := &Context{ID: "ctx-42", Language: "python"} + + if _, err := sbx.RunCode(ctx, "test_stateful = 1", &RunCodeOpts{Context: ctxObj}); err != nil { + t.Fatal(err) + } + exec, err := sbx.RunCode(ctx, "test_stateful+=1; test_stateful", &RunCodeOpts{Context: ctxObj}) + if err != nil { + t.Fatal(err) + } + if exec.Text() != "2" { + t.Errorf("expected 2, got %q", exec.Text()) + } + if atomic.LoadInt32(&calls) != 2 { + t.Errorf("expected 2 /execute calls, got %d", calls) + } +} + +func TestRunCode_Callbacks(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, ndjson( + `{"type": "stdout", "text": "Hello from e2b\n", "timestamp": 1}`, + `{"type": "stderr", "text": "This is an error message\n", "timestamp": 2}`, + `{"type": "result", "text": "1", "is_main_result": true}`, + )) + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + + var stdout, stderr []OutputMessage + var results []*Result + exec, err := sbx.RunCode(context.Background(), "x = 1;x", &RunCodeOpts{ + OnStdout: func(o OutputMessage) { stdout = append(stdout, o) }, + OnStderr: func(o OutputMessage) { stderr = append(stderr, o) }, + OnResult: func(r *Result) { results = append(results, r) }, + }) + if err != nil { + t.Fatalf("run code: %v", err) + } + if len(stdout) != 1 || stdout[0].Line != "Hello from e2b\n" { + t.Errorf("stdout cb: %+v", stdout) + } + if len(stderr) != 1 || stderr[0].Line != "This is an error message\n" || !stderr[0].Error { + t.Errorf("stderr cb: %+v", stderr) + } + if len(results) != 1 { + t.Errorf("expected 1 result callback, got %d", len(results)) + } + if exec.Logs.Stdout[0] != "Hello from e2b\n" { + t.Errorf("logs stdout: %+v", exec.Logs.Stdout) + } + if exec.Logs.Stderr[0] != "This is an error message\n" { + t.Errorf("logs stderr: %+v", exec.Logs.Stderr) + } +} + +func TestRunCode_ErrorCallback(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, `{"type": "error", "name": "NameError", "value": "name 'xyz' is not defined", "traceback": "..."}`) + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + var errs []*ExecutionError + exec, err := sbx.RunCode(context.Background(), "xyz", &RunCodeOpts{ + OnError: func(e *ExecutionError) { errs = append(errs, e) }, + }) + if err != nil { + t.Fatal(err) + } + if len(errs) != 1 { + t.Fatalf("expected 1 error callback, got %d", len(errs)) + } + if exec.Error == nil || exec.Error.Name != "NameError" { + t.Errorf("error not captured: %+v", exec.Error) + } +} + +func TestRunCode_StreamingResults(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, ndjson( + `{"type": "result", "png": "abc", "text": "
"}`, + `{"type": "result", "text": "final", "is_main_result": true}`, + )) + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + var out []*Result + exec, err := sbx.RunCode(context.Background(), "plot()", &RunCodeOpts{ + OnResult: func(r *Result) { out = append(out, r) }, + }) + if err != nil { + t.Fatal(err) + } + if len(out) != 2 { + t.Errorf("expected 2 streaming results, got %d", len(out)) + } + if exec.Text() != "final" { + t.Errorf("expected final text 'final', got %q", exec.Text()) + } +} + +func TestRunCode_DisplayData(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, `{"type": "result", "text": "
", "png": "base64data", "is_main_result": true}`) + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + exec, err := sbx.RunCode(context.Background(), "plt.show()", nil) + if err != nil { + t.Fatal(err) + } + if len(exec.Results) != 1 { + t.Fatalf("expected 1 result, got %d", len(exec.Results)) + } + r := exec.Results[0] + if r.PNG == "" { + t.Error("expected PNG data") + } + if r.Text == "" { + t.Error("expected text") + } + if len(r.Extra) != 0 { + t.Errorf("expected no extra keys, got %+v", r.Extra) + } +} + +func TestRunCode_DataRepresentation(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, `{"type": "result", "data": {"a": [1, 2, 3]}, "text": "df", "is_main_result": true}`) + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + exec, err := sbx.RunCode(context.Background(), "df", nil) + if err != nil { + t.Fatal(err) + } + r := exec.Results[0] + if r.Data == nil { + t.Fatal("expected Data to be set") + } + arr, ok := r.Data["a"].([]interface{}) + if !ok || len(arr) != 3 { + t.Errorf("unexpected data: %+v", r.Data) + } +} + +func TestRunCode_CustomReprLatexOnly(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, `{"type": "result", "latex": "\\text{X}", "is_main_result": true}`) + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + exec, err := sbx.RunCode(context.Background(), "display()", nil) + if err != nil { + t.Fatal(err) + } + formats := exec.Results[0].Formats() + if len(formats) != 1 || formats[0] != "latex" { + t.Errorf("expected ['latex'], got %+v", formats) + } +} + +func TestRunCode_ExecutionCount(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, ndjson( + `{"type": "result", "text": "/home/user", "is_main_result": true}`, + `{"type": "number_of_executions", "execution_count": 2}`, + )) + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + exec, err := sbx.RunCode(context.Background(), "!pwd", nil) + if err != nil { + t.Fatal(err) + } + if exec.ExecutionCount != 2 { + t.Errorf("expected count 2, got %d", exec.ExecutionCount) + } +} + +func TestRunCode_PassContextAndLanguage(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("server should not have been called") + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + _, err := sbx.RunCode(context.Background(), "console.log('Hello')", &RunCodeOpts{ + Language: LanguageJavaScript, + Context: &Context{ID: "ctx"}, + }) + if err == nil { + t.Fatal("expected error") + } + if _, ok := err.(*InvalidArgumentError); !ok { + t.Errorf("expected *InvalidArgumentError, got %T", err) + } +} + +func TestRunCode_LanguageIsSent(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body := readJSON(t, r) + if body["language"] != "javascript" { + t.Errorf("expected language=javascript, got %v", body["language"]) + } + _, _ = io.WriteString(w, `{"type": "result", "text": "Hello, World!", "is_main_result": true}`) + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + exec, err := sbx.RunCode(context.Background(), "console.log('Hello')", &RunCodeOpts{ + Language: LanguageJavaScript, + }) + if err != nil { + t.Fatal(err) + } + if exec.Text() != "Hello, World!" { + t.Errorf("text: %q", exec.Text()) + } +} + +func TestRunCode_EnvVarsAreSent(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body := readJSON(t, r) + env, ok := body["env_vars"].(map[string]interface{}) + if !ok { + t.Fatalf("env_vars missing: %+v", body) + } + if env["FOO"] != "bar" { + t.Errorf("env FOO: %v", env["FOO"]) + } + _, _ = io.WriteString(w, `{"type": "result", "text": "ok", "is_main_result": true}`) + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + _, err := sbx.RunCode(context.Background(), "x", &RunCodeOpts{ + Envs: map[string]string{"FOO": "bar"}, + }) + if err != nil { + t.Fatal(err) + } +} + +func TestRunCode_BackendErrorPropagates(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = io.WriteString(w, "server error") + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + _, err := sbx.RunCode(context.Background(), "x", nil) + if err == nil { + t.Fatal("expected error") + } + if _, ok := err.(*SandboxError); !ok { + t.Errorf("expected *SandboxError, got %T: %v", err, err) + } +} + +func TestRunCode_502IsTimeout(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + _, err := sbx.RunCode(context.Background(), "x", nil) + if err == nil { + t.Fatal("expected error") + } + if _, ok := err.(*TimeoutError); !ok { + t.Errorf("expected *TimeoutError, got %T", err) + } +} + +func TestCreateCodeContext_DefaultOptions(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assertJupyterRequest(t, r, "POST", "/contexts") + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"id": "ctx-new", "language": "python", "cwd": "/home/user"}`) + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + c, err := sbx.CreateCodeContext(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + if c.ID != "ctx-new" || c.Language != "python" || c.Cwd != "/home/user" { + t.Errorf("unexpected context: %+v", c) + } +} + +func TestCreateCodeContext_WithOptions(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body := readJSON(t, r) + if body["language"] != "python" { + t.Errorf("language: %v", body["language"]) + } + if body["cwd"] != "/root" { + t.Errorf("cwd: %v", body["cwd"]) + } + _, _ = io.WriteString(w, `{"id": "ctx-x", "language": "python", "cwd": "/root"}`) + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + c, err := sbx.CreateCodeContext(context.Background(), &CreateCodeContextOpts{ + Language: LanguagePython, + Cwd: "/root", + }) + if err != nil { + t.Fatal(err) + } + if c.Cwd != "/root" { + t.Errorf("cwd: %q", c.Cwd) + } +} + +func TestListCodeContexts(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assertJupyterRequest(t, r, "GET", "/contexts") + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `[ + {"id": "ctx-1", "language": "python", "cwd": "/home/user"}, + {"id": "ctx-2", "language": "javascript", "cwd": "/home/user"} + ]`) + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + contexts, err := sbx.ListCodeContexts(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(contexts) != 2 { + t.Fatalf("expected 2, got %d", len(contexts)) + } + + // default contexts should include python and javascript + langs := map[string]bool{} + for _, c := range contexts { + langs[c.Language] = true + } + if !langs["python"] || !langs["javascript"] { + t.Errorf("expected python + javascript, got %+v", langs) + } +} + +func TestRestartCodeContext(t *testing.T) { + var path string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path = r.URL.Path + if r.Method != "POST" { + t.Errorf("expected POST, got %s", r.Method) + } + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + if err := sbx.RestartCodeContext(context.Background(), "ctx-xy"); err != nil { + t.Fatal(err) + } + if path != "/contexts/ctx-xy/restart" { + t.Errorf("path: %q", path) + } + + if err := sbx.RestartCodeContext(context.Background(), &Context{ID: "ctx-zz"}); err != nil { + t.Fatal(err) + } + if path != "/contexts/ctx-zz/restart" { + t.Errorf("path with *Context: %q", path) + } +} + +func TestRemoveCodeContext(t *testing.T) { + var called bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + if r.Method != "DELETE" { + t.Errorf("method: %s", r.Method) + } + if !strings.HasSuffix(r.URL.Path, "/contexts/ctx-1") { + t.Errorf("path: %s", r.URL.Path) + } + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + if err := sbx.RemoveCodeContext(context.Background(), "ctx-1"); err != nil { + t.Fatal(err) + } + if !called { + t.Error("expected delete call") + } +} + +func TestCreateCodeContext_NotFound(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = io.WriteString(w, "context gone") + })) + defer srv.Close() + + sbx := newMockSandbox(t, srv) + _, err := sbx.CreateCodeContext(context.Background(), nil) + if err == nil { + t.Fatal("expected error") + } + if _, ok := err.(*NotFoundError); !ok { + t.Errorf("expected *NotFoundError, got %T", err) + } +} + +func TestRestartCodeContext_NilContext(t *testing.T) { + sbx := newMockSandbox(t, httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("should not have been called") + }))) + + err := sbx.RestartCodeContext(context.Background(), (*Context)(nil)) + if err == nil { + t.Fatal("expected error on nil context") + } + if _, ok := err.(*InvalidArgumentError); !ok { + t.Errorf("expected *InvalidArgumentError, got %T", err) + } +} +func assertJupyterRequest(t *testing.T, r *http.Request, method, path string) { + t.Helper() + if r.Method != method { + t.Errorf("method: got %s, want %s", r.Method, method) + } + if r.URL.Path != path { + t.Errorf("path: got %s, want %s", r.URL.Path, path) + } +} + +func assertExecuteBody(t *testing.T, r *http.Request, wantCode, wantCtxID, wantLang string) { + t.Helper() + body := readJSON(t, r) + if body["code"] != wantCode { + t.Errorf("code: got %v, want %q", body["code"], wantCode) + } + if wantCtxID != "" && body["context_id"] != wantCtxID { + t.Errorf("context_id: got %v, want %q", body["context_id"], wantCtxID) + } + if wantLang != "" && body["language"] != wantLang { + t.Errorf("language: got %v, want %q", body["language"], wantLang) + } +} + +func readJSON(t *testing.T, r *http.Request) map[string]interface{} { + t.Helper() + b, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + var m map[string]interface{} + if err := json.Unmarshal(b, &m); err != nil { + t.Fatalf("parse body: %v (raw=%s)", err, string(b)) + } + return m +} diff --git a/go/models.go b/go/models.go new file mode 100644 index 00000000..3c473c9a --- /dev/null +++ b/go/models.go @@ -0,0 +1,272 @@ +package codeinterpreter + +import ( + "encoding/json" + "fmt" +) + +// RunCodeLanguage is a supported language identifier for code execution. +// +// Known values: "python", "javascript", "typescript", "r", "java", "bash". +// Custom strings are allowed for user-installed kernels. +type RunCodeLanguage string + +const ( + LanguagePython RunCodeLanguage = "python" + LanguageJavaScript RunCodeLanguage = "javascript" + LanguageTypeScript RunCodeLanguage = "typescript" + LanguageR RunCodeLanguage = "r" + LanguageJava RunCodeLanguage = "java" + LanguageBash RunCodeLanguage = "bash" +) + +// MIMEType is a string alias for convenience. +type MIMEType = string + +// OutputMessage represents an output message from the sandbox code execution. +type OutputMessage struct { + // Line is the raw output line. + Line string + // Timestamp is the unix epoch in nanoseconds. + Timestamp int64 + // Error indicates whether this output originates from stderr. + Error bool +} + +func (o OutputMessage) String() string { return o.Line } + +// ExecutionError represents an error that occurred during the execution of a cell. +type ExecutionError struct { + // Name of the error (e.g. "NameError"). + Name string `json:"name"` + // Value/message of the error. + Value string `json:"value"` + // Traceback is the raw traceback text. + Traceback string `json:"traceback"` +} + +func (e *ExecutionError) Error() string { + return fmt.Sprintf("%s: %s\n%s", e.Name, e.Value, e.Traceback) +} + +// ToJSON returns the JSON representation of this error. +func (e *ExecutionError) ToJSON() string { + b, _ := json.Marshal(e) + return string(b) +} + +// Logs holds data printed to stdout and stderr during execution. +type Logs struct { + Stdout []string `json:"stdout"` + Stderr []string `json:"stderr"` +} + +// Result represents the data to be displayed as a result of executing a cell in +// a Jupyter notebook. A Result may carry several representations of the same +// underlying data (text, HTML, PNG, SVG, …). +type Result struct { + Text string `json:"text,omitempty"` + HTML string `json:"html,omitempty"` + Markdown string `json:"markdown,omitempty"` + SVG string `json:"svg,omitempty"` + PNG string `json:"png,omitempty"` + JPEG string `json:"jpeg,omitempty"` + PDF string `json:"pdf,omitempty"` + LaTeX string `json:"latex,omitempty"` + JSON map[string]interface{} `json:"json,omitempty"` + JavaScript string `json:"javascript,omitempty"` + Data map[string]interface{} `json:"data,omitempty"` + // Chart is the structured chart data extracted by the server, if any. + Chart Chart `json:"-"` + // IsMainResult indicates whether this is the primary result of the cell + // (as opposed to an intermediate display call). + IsMainResult bool `json:"is_main_result,omitempty"` + // Extra holds any additional representations not covered by the + // standard fields above. + Extra map[string]interface{} `json:"extra,omitempty"` + // Raw holds the full raw JSON payload as returned by the server. + Raw map[string]interface{} `json:"-"` +} + +// Formats returns the list of MIME-like format names available on this result. +func (r *Result) Formats() []string { + formats := []string{} + if r.Text != "" { + formats = append(formats, "text") + } + if r.HTML != "" { + formats = append(formats, "html") + } + if r.Markdown != "" { + formats = append(formats, "markdown") + } + if r.SVG != "" { + formats = append(formats, "svg") + } + if r.PNG != "" { + formats = append(formats, "png") + } + if r.JPEG != "" { + formats = append(formats, "jpeg") + } + if r.PDF != "" { + formats = append(formats, "pdf") + } + if r.LaTeX != "" { + formats = append(formats, "latex") + } + if r.JSON != nil { + formats = append(formats, "json") + } + if r.JavaScript != "" { + formats = append(formats, "javascript") + } + if r.Data != nil { + formats = append(formats, "data") + } + if r.Chart != nil { + formats = append(formats, "chart") + } + for k := range r.Extra { + formats = append(formats, k) + } + return formats +} + +// String returns a short description of the result. If a textual representation +// exists it is returned as-is, otherwise the list of available formats is +// returned. +func (r *Result) String() string { + if r.Text != "" { + return fmt.Sprintf("Result(%s)", r.Text) + } + return fmt.Sprintf("Result(Formats: %v)", r.Formats()) +} + +// newResultFromRaw builds a Result from the raw JSON map coming from the server. +// This mirrors the behaviour of the Python/JS SDKs. +func newResultFromRaw(raw map[string]interface{}) *Result { + r := &Result{Raw: raw} + + // Known keys extraction + knownKeys := map[string]struct{}{ + "type": {}, "is_main_result": {}, + "text": {}, "html": {}, "markdown": {}, "svg": {}, + "png": {}, "jpeg": {}, "pdf": {}, "latex": {}, + "json": {}, "javascript": {}, "data": {}, "chart": {}, + "extra": {}, + } + + r.Text = getString(raw, "text") + r.HTML = getString(raw, "html") + r.Markdown = getString(raw, "markdown") + r.SVG = getString(raw, "svg") + r.PNG = getString(raw, "png") + r.JPEG = getString(raw, "jpeg") + r.PDF = getString(raw, "pdf") + r.LaTeX = getString(raw, "latex") + r.JavaScript = getString(raw, "javascript") + + if v, ok := raw["json"].(map[string]interface{}); ok { + r.JSON = v + } + if v, ok := raw["data"].(map[string]interface{}); ok { + r.Data = v + } + if v, ok := raw["is_main_result"].(bool); ok { + r.IsMainResult = v + } + if c, ok := raw["chart"].(map[string]interface{}); ok { + r.Chart = deserializeChart(c) + } + if extra, ok := raw["extra"].(map[string]interface{}); ok { + r.Extra = extra + } + + // Collect unknown keys into Extra (keeps parity with JS SDK). + for k, v := range raw { + if _, known := knownKeys[k]; known { + continue + } + if r.Extra == nil { + r.Extra = make(map[string]interface{}) + } + r.Extra[k] = v + } + + return r +} + +// Execution represents the result of a cell execution. +type Execution struct { + // Results collects the cell's main result and any intermediate display + // calls (e.g. matplotlib plots). + Results []*Result `json:"results"` + // Logs holds stdout/stderr lines printed during execution. + Logs Logs `json:"logs"` + // Error is set if the execution failed; nil otherwise. + Error *ExecutionError `json:"error,omitempty"` + // ExecutionCount is the execution count (cell index) reported by the kernel. + ExecutionCount int `json:"execution_count,omitempty"` +} + +// NewExecution creates an empty Execution. +func NewExecution() *Execution { + return &Execution{ + Results: []*Result{}, + Logs: Logs{Stdout: []string{}, Stderr: []string{}}, + } +} + +// Text returns the text representation of the main result, if any. +func (e *Execution) Text() string { + for _, r := range e.Results { + if r.IsMainResult { + return r.Text + } + } + return "" +} + +// ToJSON serializes the execution to JSON. +func (e *Execution) ToJSON() (string, error) { + b, err := json.Marshal(e) + if err != nil { + return "", err + } + return string(b), nil +} + +// Context represents a code execution context (a persistent kernel). +type Context struct { + // ID of the context. + ID string `json:"id"` + // Language of the context. + Language string `json:"language"` + // Cwd is the working directory inside the sandbox. + Cwd string `json:"cwd"` +} + +// contextFromJSON decodes a context from a raw JSON map. +func contextFromJSON(data map[string]interface{}) *Context { + return &Context{ + ID: getString(data, "id"), + Language: getString(data, "language"), + Cwd: getString(data, "cwd"), + } +} + +// OnStdoutFunc is a callback invoked for every stdout chunk produced by the +// running code. +type OnStdoutFunc func(msg OutputMessage) + +// OnStderrFunc is a callback invoked for every stderr chunk produced by the +// running code. +type OnStderrFunc func(msg OutputMessage) + +// OnResultFunc is a callback invoked for every Result emitted by the running +// code (display calls as well as the final main result). +type OnResultFunc func(result *Result) + +// OnErrorFunc is a callback invoked when the running code raises an error. +type OnErrorFunc func(err *ExecutionError) diff --git a/go/models_test.go b/go/models_test.go new file mode 100644 index 00000000..a74cebd3 --- /dev/null +++ b/go/models_test.go @@ -0,0 +1,239 @@ +package codeinterpreter + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestExecution_Text_OnlyMainResult(t *testing.T) { + e := NewExecution() + e.Results = append(e.Results, + &Result{Text: "intermediate", IsMainResult: false}, + &Result{Text: "final", IsMainResult: true}, + ) + if e.Text() != "final" { + t.Errorf("got %q, want %q", e.Text(), "final") + } +} + +func TestExecution_Text_NoMain(t *testing.T) { + e := NewExecution() + e.Results = append(e.Results, &Result{Text: "x", IsMainResult: false}) + if e.Text() != "" { + t.Errorf("expected empty, got %q", e.Text()) + } +} + +func TestExecution_ToJSON(t *testing.T) { + e := NewExecution() + e.Results = append(e.Results, &Result{Text: "1", IsMainResult: true}) + e.Logs.Stdout = []string{"hello\n"} + e.ExecutionCount = 5 + s, err := e.ToJSON() + if err != nil { + t.Fatalf("ToJSON: %v", err) + } + var back map[string]interface{} + if err := json.Unmarshal([]byte(s), &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back["execution_count"].(float64) != 5 { + t.Errorf("exec count: %v", back["execution_count"]) + } + results, ok := back["results"].([]interface{}) + if !ok || len(results) != 1 { + t.Errorf("results: %+v", back["results"]) + } +} + +func TestExecutionError_ErrorAndJSON(t *testing.T) { + e := &ExecutionError{Name: "NameError", Value: "x", Traceback: "tb"} + msg := e.Error() + if !strings.Contains(msg, "NameError") || !strings.Contains(msg, "x") || !strings.Contains(msg, "tb") { + t.Errorf("unexpected: %q", msg) + } + + js := e.ToJSON() + var m map[string]interface{} + if err := json.Unmarshal([]byte(js), &m); err != nil { + t.Fatalf("parse json: %v", err) + } + if m["name"] != "NameError" || m["value"] != "x" { + t.Errorf("json: %+v", m) + } +} + +func TestOutputMessage_String(t *testing.T) { + o := OutputMessage{Line: "hello"} + if o.String() != "hello" { + t.Errorf("got %q", o.String()) + } +} + +func TestResult_StringWithText(t *testing.T) { + r := &Result{Text: "abc"} + s := r.String() + if !strings.Contains(s, "abc") { + t.Errorf("result string: %q", s) + } +} + +func TestResult_StringWithoutText(t *testing.T) { + r := &Result{HTML: ""} + s := r.String() + if !strings.Contains(s, "Formats:") { + t.Errorf("result string w/o text: %q", s) + } +} + +func TestNewResultFromRaw_AllKnownFormats(t *testing.T) { + raw := map[string]interface{}{ + "text": "t", + "html": "", + "markdown": "# m", + "svg": "", + "png": "p", + "jpeg": "j", + "pdf": "d", + "latex": "l", + "json": map[string]interface{}{"k": "v"}, + "javascript": "js", + "data": map[string]interface{}{"x": 1}, + "chart": map[string]interface{}{"type": "line", "title": "c", "elements": []interface{}{}}, + "is_main_result": true, + "extra": map[string]interface{}{"foo": "bar"}, + } + r := newResultFromRaw(raw) + + if r.Text != "t" || r.HTML != "" || r.Markdown != "# m" || r.SVG != "" { + t.Errorf("basic fields wrong: %+v", r) + } + if r.PNG != "p" || r.JPEG != "j" || r.PDF != "d" || r.LaTeX != "l" || r.JavaScript != "js" { + t.Errorf("secondary fields wrong: %+v", r) + } + if r.JSON == nil || r.JSON["k"] != "v" { + t.Errorf("json field: %+v", r.JSON) + } + if r.Data == nil || r.Data["x"] != 1 { + t.Errorf("data field: %+v", r.Data) + } + if r.Chart == nil || r.Chart.ChartType() != ChartTypeLine { + t.Errorf("chart missing or wrong type: %+v", r.Chart) + } + if !r.IsMainResult { + t.Error("is_main_result not parsed") + } + if r.Extra == nil || r.Extra["foo"] != "bar" { + t.Errorf("extra not parsed: %+v", r.Extra) + } + + // Formats must contain every key we fed in (including 'chart' and + // the custom 'foo' Extra). + want := []string{"text", "html", "markdown", "svg", "png", "jpeg", "pdf", "latex", "json", "javascript", "data", "chart", "foo"} + formats := r.Formats() + fmtSet := map[string]bool{} + for _, f := range formats { + fmtSet[f] = true + } + for _, w := range want { + if !fmtSet[w] { + t.Errorf("missing format %q in %+v", w, formats) + } + } +} + +func TestNewResultFromRaw_UnknownKeysGoToExtra(t *testing.T) { + raw := map[string]interface{}{ + "text": "t", + "custom/mime": "data-one", + "another_mime": "data-two", + } + r := newResultFromRaw(raw) + if r.Extra["custom/mime"] != "data-one" || r.Extra["another_mime"] != "data-two" { + t.Errorf("extras not captured: %+v", r.Extra) + } +} + +func TestGetFloat_AllTypes(t *testing.T) { + m := map[string]interface{}{ + "a": float64(1.5), + "b": float32(2.5), + "c": int(3), + "d": int64(4), + "e": json.Number("5.5"), + "f": "not-a-number", + } + cases := []struct { + key string + want float64 + }{ + {"a", 1.5}, {"b", 2.5}, {"c", 3}, {"d", 4}, {"e", 5.5}, {"f", 0}, + {"missing", 0}, + } + for _, c := range cases { + if got := getFloat(m, c.key); got != c.want { + t.Errorf("getFloat(%q) = %v, want %v", c.key, got, c.want) + } + } +} + +func TestGetFloatSlice(t *testing.T) { + m := map[string]interface{}{ + "a": []interface{}{float64(1), int(2), json.Number("3.3")}, + "b": "nope", + } + if got := getFloatSlice(m, "a"); len(got) != 3 || got[2] != 3.3 { + t.Errorf("got %+v", got) + } + if got := getFloatSlice(m, "b"); got != nil { + t.Errorf("expected nil for string, got %+v", got) + } + if got := getFloatSlice(m, "missing"); got != nil { + t.Error("expected nil for missing") + } +} + +func TestGetStringSlice(t *testing.T) { + m := map[string]interface{}{ + "a": []interface{}{"x", "y", 3, "z"}, // mixed types — non-strings dropped + } + got := getStringSlice(m, "a") + if len(got) != 3 || got[0] != "x" || got[2] != "z" { + t.Errorf("got %+v", got) + } +} + +func TestGetInt64(t *testing.T) { + m := map[string]interface{}{ + "a": float64(10), + "b": int(20), + "c": int64(30), + "d": json.Number("40"), + } + for _, c := range []struct { + key string + want int64 + }{{"a", 10}, {"b", 20}, {"c", 30}, {"d", 40}, {"missing", 0}} { + if got := getInt64(m, c.key); got != c.want { + t.Errorf("getInt64(%q) = %v, want %v", c.key, got, c.want) + } + } +} + +func TestChartInterface_DelegatesMethods(t *testing.T) { + c := deserializeChart(map[string]interface{}{ + "type": "bar", + "title": "T", + "elements": []interface{}{}, + }) + if c.ChartType() != ChartTypeBar { + t.Errorf("chart type: %s", c.ChartType()) + } + if c.ChartTitle() != "T" { + t.Errorf("chart title: %s", c.ChartTitle()) + } + if c.ToJSON()["type"] != "bar" { + t.Errorf("ToJSON missing type: %+v", c.ToJSON()) + } +} diff --git a/go/sandbox.go b/go/sandbox.go new file mode 100644 index 00000000..e50b6495 --- /dev/null +++ b/go/sandbox.go @@ -0,0 +1,254 @@ +package codeinterpreter + +import ( + "context" + "fmt" + "net/http" + "time" +) + +// SandboxOpts are options for creating or connecting to a Sandbox. +type SandboxOpts struct { + // APIKey to use. Falls back to the E2B_API_KEY env variable. + APIKey string + // AccessToken to use. + AccessToken string + // Domain to use (defaults to e2b.app). + Domain string + // Debug, if true, uses plain http:// against the sandbox. + Debug bool + // RequestTimeout default HTTP request timeout. + RequestTimeout time.Duration + // Timeout is the sandbox lifetime in seconds (not the request timeout). + Timeout time.Duration + // Template id/alias to use when creating a sandbox. Defaults to the + // code-interpreter template. + Template string + // Metadata attached to the sandbox. + Metadata map[string]string + // EnvVars passed to the sandbox at startup. + EnvVars map[string]string + // HTTPClient allows overriding the underlying http.Client. + HTTPClient *http.Client + // Headers are additional headers to send on every request. + Headers map[string]string +} + +// SandboxInfo is the JSON shape returned by the API when listing sandboxes. +type SandboxInfo struct { + SandboxID string `json:"sandboxID"` + ClientID string `json:"clientID"` + TemplateID string `json:"templateID"` + Alias string `json:"alias,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + StartedAt string `json:"startedAt,omitempty"` + EndAt string `json:"endAt,omitempty"` + State string `json:"state,omitempty"` +} + +// Sandbox is a running E2B sandbox with code-interpreter capabilities. +type Sandbox struct { + id string + clientID string + template string + envdPort int + connection *ConnectionConfig +} + +// SandboxID returns the ID of this sandbox. +func (s *Sandbox) SandboxID() string { return s.id } + +// ClientID returns the client id (envd worker) running this sandbox. +func (s *Sandbox) ClientID() string { return s.clientID } + +// fullID returns "-" which is the form used in the +// sandbox hostnames. +func (s *Sandbox) fullID() string { + if s.clientID == "" { + return s.id + } + return fmt.Sprintf("%s-%s", s.id, s.clientID) +} + +// getHost returns the public host for a port exposed by the sandbox. +func (s *Sandbox) getHost(port int) string { + return fmt.Sprintf("%d-%s.%s", port, s.fullID(), s.connection.Domain) +} + +// jupyterURL returns the URL to the internal Jupyter/Code-Interpreter server. +func (s *Sandbox) jupyterURL() string { + scheme := "https" + if s.connection.Debug { + scheme = "http" + } + return fmt.Sprintf("%s://%s", scheme, s.getHost(JupyterPort)) +} + +// Create starts a new sandbox. `opts` may be nil, in which case sensible +// defaults are used (template = code-interpreter-v1). +func Create(ctx context.Context, opts *SandboxOpts) (*Sandbox, error) { + if opts == nil { + opts = &SandboxOpts{} + } + + cfg := &ConnectionConfig{ + APIKey: opts.APIKey, + AccessToken: opts.AccessToken, + Domain: opts.Domain, + Debug: opts.Debug, + RequestTimeout: opts.RequestTimeout, + HTTPClient: opts.HTTPClient, + Headers: opts.Headers, + } + cfg.init() + + if cfg.APIKey == "" { + return nil, &AuthenticationError{Message: "API key is required; set E2B_API_KEY or SandboxOpts.APIKey"} + } + + template := opts.Template + if template == "" { + template = DefaultTemplate + } + + timeoutSec := int(opts.Timeout / time.Second) + if timeoutSec == 0 { + timeoutSec = DefaultSandboxTimeout + } + + body := map[string]interface{}{ + "templateID": template, + "timeout": timeoutSec, + } + if len(opts.Metadata) > 0 { + body["metadata"] = opts.Metadata + } + if len(opts.EnvVars) > 0 { + body["envVars"] = opts.EnvVars + } + + var out struct { + SandboxID string `json:"sandboxID"` + ClientID string `json:"clientID"` + TemplateID string `json:"templateID"` + EnvdPort int `json:"envdPort"` + } + if err := cfg.do(ctx, "POST", "/sandboxes", body, &out); err != nil { + return nil, err + } + + return &Sandbox{ + id: out.SandboxID, + clientID: out.ClientID, + template: out.TemplateID, + envdPort: out.EnvdPort, + connection: cfg, + }, nil +} + +// Connect attaches to an already running sandbox by its ID. The caller must +// supply at least the API key (via opts or the env var). +func Connect(ctx context.Context, sandboxID string, opts *SandboxOpts) (*Sandbox, error) { + if opts == nil { + opts = &SandboxOpts{} + } + cfg := &ConnectionConfig{ + APIKey: opts.APIKey, + AccessToken: opts.AccessToken, + Domain: opts.Domain, + Debug: opts.Debug, + RequestTimeout: opts.RequestTimeout, + HTTPClient: opts.HTTPClient, + Headers: opts.Headers, + } + cfg.init() + + var info SandboxInfo + if err := cfg.do(ctx, "GET", "/sandboxes/"+sandboxID, nil, &info); err != nil { + return nil, err + } + + return &Sandbox{ + id: info.SandboxID, + clientID: info.ClientID, + template: info.TemplateID, + connection: cfg, + }, nil +} + +// Kill terminates the sandbox. +func (s *Sandbox) Kill(ctx context.Context) error { + return s.connection.do(ctx, "DELETE", "/sandboxes/"+s.id, nil, nil) +} + +// SetTimeout updates the remaining lifetime of the sandbox. Pass the desired +// wall-clock time-until-expiration. +func (s *Sandbox) SetTimeout(ctx context.Context, timeout time.Duration) error { + body := map[string]int{ + "timeout": int(timeout / time.Second), + } + return s.connection.do(ctx, "POST", "/sandboxes/"+s.id+"/timeout", body, nil) +} + +// IsRunning checks whether the sandbox is still reachable. +func (s *Sandbox) IsRunning(ctx context.Context) (bool, error) { + err := s.connection.do(ctx, "GET", "/sandboxes/"+s.id, nil, nil) + if err == nil { + return true, nil + } + if _, ok := err.(*NotFoundError); ok { + return false, nil + } + return false, err +} + +// List returns all sandboxes currently running under the configured API key. +func List(ctx context.Context, opts *SandboxOpts) ([]SandboxInfo, error) { + if opts == nil { + opts = &SandboxOpts{} + } + cfg := &ConnectionConfig{ + APIKey: opts.APIKey, + AccessToken: opts.AccessToken, + Domain: opts.Domain, + Debug: opts.Debug, + RequestTimeout: opts.RequestTimeout, + HTTPClient: opts.HTTPClient, + Headers: opts.Headers, + } + cfg.init() + + var out []SandboxInfo + if err := cfg.do(ctx, "GET", "/sandboxes", nil, &out); err != nil { + return nil, err + } + return out, nil +} + +// GetInfo returns information about this sandbox, including metadata and +// start/end times. +func (s *Sandbox) GetInfo(ctx context.Context) (*SandboxInfo, error) { + var info SandboxInfo + if err := s.connection.do(ctx, "GET", "/sandboxes/"+s.id, nil, &info); err != nil { + return nil, err + } + return &info, nil +} + +// GetHost returns a routable hostname for a port exposed by the sandbox. This +// lets callers build URLs to user-exposed services. +func (s *Sandbox) GetHost(port int) string { + return s.getHost(port) +} + +// addAuthHeaders adds authentication headers used by direct-to-sandbox HTTP +// calls (jupyterURL/envd). +func (s *Sandbox) addAuthHeaders(h http.Header) { + h.Set("Content-Type", "application/json") + if s.connection.AccessToken != "" { + h.Set("X-Access-Token", s.connection.AccessToken) + } + if s.connection.TrafficAccessToken != "" { + h.Set("E2B-Traffic-Access-Token", s.connection.TrafficAccessToken) + } +} diff --git a/go/sandbox_test.go b/go/sandbox_test.go new file mode 100644 index 00000000..b9e5075b --- /dev/null +++ b/go/sandbox_test.go @@ -0,0 +1,352 @@ +package codeinterpreter + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" +) + +func TestConnectionConfig_DefaultsFromEnv(t *testing.T) { + t.Setenv("E2B_API_KEY", "env-key") + t.Setenv("E2B_DOMAIN", "example.dev") + t.Setenv("E2B_ACCESS_TOKEN", "env-token") + t.Setenv("E2B_DEBUG", "true") + + c := &ConnectionConfig{} + c.init() + + if c.APIKey != "env-key" { + t.Errorf("api key: %q", c.APIKey) + } + if c.Domain != "example.dev" { + t.Errorf("domain: %q", c.Domain) + } + if c.AccessToken != "env-token" { + t.Errorf("token: %q", c.AccessToken) + } + if !c.Debug { + t.Error("expected debug to be true") + } + if c.RequestTimeout == 0 { + t.Error("expected default request timeout") + } + if c.HTTPClient == nil { + t.Error("http client should have been created") + } +} + +func TestConnectionConfig_ExplicitOverrides(t *testing.T) { + // Clear env to make sure the explicit values win. + _ = os.Unsetenv("E2B_API_KEY") + _ = os.Unsetenv("E2B_DOMAIN") + _ = os.Unsetenv("E2B_ACCESS_TOKEN") + _ = os.Unsetenv("E2B_DEBUG") + + c := &ConnectionConfig{ + APIKey: "my-key", + Domain: "my.dev", + AccessToken: "tok", + RequestTimeout: 10 * time.Second, + } + c.init() + + if c.APIKey != "my-key" || c.Domain != "my.dev" || c.AccessToken != "tok" { + t.Errorf("explicit values not kept: %+v", c) + } + if c.RequestTimeout != 10*time.Second { + t.Errorf("timeout overwritten: %v", c.RequestTimeout) + } +} + +func TestConnectionConfig_APIBase(t *testing.T) { + c := &ConnectionConfig{Domain: "e2b.app"} + if got := c.APIBase(); got != "https://api.e2b.app" { + t.Errorf("api base: %s", got) + } + c.Debug = true + if got := c.APIBase(); got != "http://api.e2b.app" { + t.Errorf("debug api base: %s", got) + } +} + +func TestSandbox_GetHost(t *testing.T) { + sbx := &Sandbox{ + id: "abc", + clientID: "xyz", + connection: &ConnectionConfig{Domain: "e2b.app"}, + } + got := sbx.GetHost(3000) + if got != "3000-abc-xyz.e2b.app" { + t.Errorf("host: %q", got) + } + + sbx.clientID = "" + got = sbx.GetHost(8080) + if got != "8080-abc.e2b.app" { + t.Errorf("host w/o client: %q", got) + } +} + +func TestSandbox_JupyterURL(t *testing.T) { + sbx := &Sandbox{ + id: "sid", + clientID: "cid", + connection: &ConnectionConfig{Domain: "e2b.app"}, + } + if got := sbx.jupyterURL(); got != "https://49999-sid-cid.e2b.app" { + t.Errorf("jupyter url: %q", got) + } + sbx.connection.Debug = true + if got := sbx.jupyterURL(); got != "http://49999-sid-cid.e2b.app" { + t.Errorf("debug jupyter url: %q", got) + } +} + +func TestSandbox_AddAuthHeaders(t *testing.T) { + sbx := &Sandbox{ + connection: &ConnectionConfig{ + AccessToken: "tok", + TrafficAccessToken: "traf", + }, + } + h := http.Header{} + sbx.addAuthHeaders(h) + if h.Get("Content-Type") != "application/json" { + t.Error("content-type missing") + } + if h.Get("X-Access-Token") != "tok" { + t.Errorf("access token: %q", h.Get("X-Access-Token")) + } + if h.Get("E2B-Traffic-Access-Token") != "traf" { + t.Errorf("traffic token: %q", h.Get("E2B-Traffic-Access-Token")) + } +} + +func TestCreate_MissingAPIKey(t *testing.T) { + t.Setenv("E2B_API_KEY", "") + + _, err := Create(context.Background(), &SandboxOpts{}) + if err == nil { + t.Fatal("expected error when API key missing") + } + if _, ok := err.(*AuthenticationError); !ok { + t.Errorf("expected AuthenticationError, got %T: %v", err, err) + } +} + +func TestErrorTypes_Error(t *testing.T) { + cases := []struct { + err error + want string + }{ + {&NotFoundError{Message: "ctx"}, "not found: ctx"}, + {&TimeoutError{Message: "slow"}, "timeout: slow"}, + {&InvalidArgumentError{Message: "bad"}, "invalid argument: bad"}, + {&AuthenticationError{Message: "key"}, "authentication error: key"}, + {&RateLimitError{Message: "rl"}, "rate limit: rl"}, + {&SandboxError{Message: "boom"}, "sandbox error: boom"}, + } + for _, c := range cases { + if got := c.err.Error(); got != c.want { + t.Errorf("Error(): got %q, want %q", got, c.want) + } + } + + se := &SandboxError{Message: "oops", StatusCode: 500} + if !strings.Contains(se.Error(), "500") { + t.Errorf("sandbox error with status should contain code: %s", se.Error()) + } +} + +func TestCreate_ThroughMockAPI(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/sandboxes": + if r.Method != "POST" { + t.Errorf("unexpected method: %s", r.Method) + } + if r.Header.Get("X-API-Key") != "my-key" { + t.Errorf("api key header missing: %q", r.Header.Get("X-API-Key")) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"sandboxID": "sbx-1", "clientID": "c-1", "templateID": "code-interpreter-v1", "envdPort": 49999}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + client := &http.Client{Transport: &rewriteToServerTransport{target: srv.URL}} + + sbx, err := Create(context.Background(), &SandboxOpts{ + APIKey: "my-key", + Domain: "e2b.test", + Debug: true, + HTTPClient: client, + }) + if err != nil { + t.Fatalf("create: %v", err) + } + if sbx.SandboxID() != "sbx-1" || sbx.ClientID() != "c-1" { + t.Errorf("unexpected sandbox: id=%s client=%s", sbx.SandboxID(), sbx.ClientID()) + } + if sbx.template != "code-interpreter-v1" { + t.Errorf("template: %s", sbx.template) + } +} + +func TestKill_ThroughMockAPI(t *testing.T) { + var killed bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "DELETE" && strings.HasPrefix(r.URL.Path, "/sandboxes/") { + killed = true + w.WriteHeader(http.StatusNoContent) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + sbx := &Sandbox{ + id: "abc", + connection: &ConnectionConfig{ + APIKey: "k", + Domain: "e2b.test", + Debug: true, + HTTPClient: &http.Client{Transport: &rewriteToServerTransport{target: srv.URL}}, + }, + } + if err := sbx.Kill(context.Background()); err != nil { + t.Fatalf("kill: %v", err) + } + if !killed { + t.Error("expected handler to have been called") + } +} + +func TestSetTimeout_ThroughMockAPI(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" || r.URL.Path != "/sandboxes/abc/timeout" { + t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path) + } + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + sbx := &Sandbox{ + id: "abc", + connection: &ConnectionConfig{ + APIKey: "k", + Domain: "e2b.test", + Debug: true, + HTTPClient: &http.Client{Transport: &rewriteToServerTransport{target: srv.URL}}, + }, + } + if err := sbx.SetTimeout(context.Background(), 30*time.Second); err != nil { + t.Fatalf("set timeout: %v", err) + } +} + +func TestIsRunning_ReturnsFalseOnNotFound(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + sbx := &Sandbox{ + id: "abc", + connection: &ConnectionConfig{ + APIKey: "k", + Domain: "e2b.test", + Debug: true, + HTTPClient: &http.Client{Transport: &rewriteToServerTransport{target: srv.URL}}, + }, + } + running, err := sbx.IsRunning(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if running { + t.Error("expected sandbox to be reported as not running") + } +} + +func TestIsRunning_ReturnsTrueOnOK(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"sandboxID":"abc"}`)) + })) + defer srv.Close() + + sbx := &Sandbox{ + id: "abc", + connection: &ConnectionConfig{ + APIKey: "k", + Domain: "e2b.test", + Debug: true, + HTTPClient: &http.Client{Transport: &rewriteToServerTransport{target: srv.URL}}, + }, + } + running, err := sbx.IsRunning(context.Background()) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + if !running { + t.Error("expected running=true") + } +} + +func TestList_ThroughMockAPI(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/sandboxes" || r.Method != "GET" { + t.Errorf("unexpected: %s %s", r.Method, r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`[ + {"sandboxID":"sbx-1","clientID":"c-1","templateID":"tpl","state":"running"}, + {"sandboxID":"sbx-2","clientID":"c-2","templateID":"tpl","state":"running"} + ]`)) + })) + defer srv.Close() + + client := &http.Client{Transport: &rewriteToServerTransport{target: srv.URL}} + list, err := List(context.Background(), &SandboxOpts{ + APIKey: "k", + Domain: "e2b.test", + Debug: true, + HTTPClient: client, + }) + if err != nil { + t.Fatalf("list: %v", err) + } + if len(list) != 2 { + t.Fatalf("expected 2, got %d", len(list)) + } + if list[0].SandboxID != "sbx-1" || list[1].SandboxID != "sbx-2" { + t.Errorf("unexpected list: %+v", list) + } +} + +type rewriteToServerTransport struct { + target string +} + +func (r *rewriteToServerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + // parse target + idx := strings.Index(r.target, "://") + if idx == -1 { + return nil, &SandboxError{Message: "invalid target"} + } + scheme := r.target[:idx] + host := r.target[idx+3:] + req.URL.Scheme = scheme + req.URL.Host = host + req.Host = host + return http.DefaultTransport.RoundTrip(req) +} diff --git a/go/testutil_test.go b/go/testutil_test.go new file mode 100644 index 00000000..73772b9a --- /dev/null +++ b/go/testutil_test.go @@ -0,0 +1,67 @@ +package codeinterpreter + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +// newMockSandbox creates a *Sandbox whose jupyterURL() points at the given +// httptest.Server, so tests can exercise RunCode / context handlers without a +// real e2b backend. +func newMockSandbox(t *testing.T, server *httptest.Server) *Sandbox { + t.Helper() + + u, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("parse mock server url: %v", err) + } + + // Use the httptest server's own client (which trusts the test TLS root, + // if any). We inject a RoundTripper that rewrites the host to the test + // server's address regardless of what the SDK builds. + client := server.Client() + origTransport := client.Transport + if origTransport == nil { + origTransport = http.DefaultTransport + } + client.Transport = &rewriteTransport{ + base: origTransport, + targetHost: u.Host, + scheme: u.Scheme, + } + + return &Sandbox{ + id: "sbx-test", + clientID: "client-test", + template: DefaultTemplate, + connection: &ConnectionConfig{ + APIKey: "test-api-key", + Domain: "e2b.test", + Debug: true, + RequestTimeout: 5 * time.Second, + HTTPClient: client, + }, + } +} + +type rewriteTransport struct { + base http.RoundTripper + targetHost string + scheme string +} + +func (r *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.URL.Scheme = r.scheme + req.URL.Host = r.targetHost + req.Host = r.targetHost + return r.base.RoundTrip(req) +} + +func ndjson(lines ...string) string { + return strings.Join(lines, "\n") +}