Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions apps/code/src/main/di/container.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import { McpCallbackService } from "../services/mcp-callback/service";
import { McpProxyService } from "../services/mcp-proxy/service";
import { NotificationService } from "../services/notification/service";
import { OAuthService } from "../services/oauth/service";
import { PostHogCodeInternalMcpService } from "../services/posthog-code-internal-mcp/service";
import { PosthogPluginService } from "../services/posthog-plugin/service";
import { ProcessTrackingService } from "../services/process-tracking/service";
import { ProvisioningService } from "../services/provisioning/service";
Expand Down Expand Up @@ -102,6 +103,9 @@ container.bind(MAIN_TOKENS.AgentService).to(AgentService);
container.bind(MAIN_TOKENS.AuthService).to(AuthService);
container.bind(MAIN_TOKENS.AuthProxyService).to(AuthProxyService);
container.bind(MAIN_TOKENS.McpProxyService).to(McpProxyService);
container
.bind(MAIN_TOKENS.PostHogCodeInternalMcpService)
.to(PostHogCodeInternalMcpService);
container.bind(MAIN_TOKENS.ArchiveService).to(ArchiveService);
container.bind(MAIN_TOKENS.SuspensionService).to(SuspensionService);
container.bind(MAIN_TOKENS.AppLifecycleService).to(AppLifecycleService);
Expand Down
3 changes: 3 additions & 0 deletions apps/code/src/main/di/tokens.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ export const MAIN_TOKENS = Object.freeze({
AuthService: Symbol.for("Main.AuthService"),
AuthProxyService: Symbol.for("Main.AuthProxyService"),
McpProxyService: Symbol.for("Main.McpProxyService"),
PostHogCodeInternalMcpService: Symbol.for(
"Main.PostHogCodeInternalMcpService",
),
ArchiveService: Symbol.for("Main.ArchiveService"),
SuspensionService: Symbol.for("Main.SuspensionService"),
AppLifecycleService: Symbol.for("Main.AppLifecycleService"),
Expand Down
6 changes: 6 additions & 0 deletions apps/code/src/main/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
initializePostHog,
trackAppEvent,
} from "./services/posthog-analytics";
import type { PostHogCodeInternalMcpService } from "./services/posthog-code-internal-mcp/service";
import type { PosthogPluginService } from "./services/posthog-plugin/service";
import type { SuspensionService } from "./services/suspension/service";
import type { TaskLinkService } from "./services/task-link/service";
Expand Down Expand Up @@ -52,6 +53,11 @@ async function initializeServices(): Promise<void> {

await authService.initialize();

const internalMcp = container.get<PostHogCodeInternalMcpService>(
MAIN_TOKENS.PostHogCodeInternalMcpService,
);
await internalMcp.start();

// Initialize workspace branch watcher for live branch rename detection
const workspaceService = container.get<WorkspaceService>(
MAIN_TOKENS.WorkspaceService,
Expand Down
7 changes: 7 additions & 0 deletions apps/code/src/main/services/agent/auth-adapter.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ function createDependencies() {
(id: string) => `http://127.0.0.1:9998/${encodeURIComponent(id)}`,
),
},
internalMcp: {
getUrl: vi.fn().mockReturnValue("http://127.0.0.1:9997/mcp"),
getAuthHeader: vi
.fn()
.mockReturnValue({ name: "authorization", value: "Bearer test" }),
},
};
}

Expand All @@ -77,6 +83,7 @@ describe("AgentAuthAdapter", () => {
deps.authService as never,
deps.authProxy as never,
deps.mcpProxy as never,
deps.internalMcp as never,
);
});

Expand Down
16 changes: 16 additions & 0 deletions apps/code/src/main/services/agent/auth-adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { logger } from "../../utils/logger";
import type { AuthService } from "../auth/service";
import type { AuthProxyService } from "../auth-proxy/service";
import type { McpProxyService } from "../mcp-proxy/service";
import type { PostHogCodeInternalMcpService } from "../posthog-code-internal-mcp/service";
import type { Credentials } from "./schemas";

const log = logger.scope("agent-auth-adapter");
Expand Down Expand Up @@ -63,6 +64,8 @@ export class AgentAuthAdapter {
private readonly authProxy: AuthProxyService,
@inject(MAIN_TOKENS.McpProxyService)
private readonly mcpProxy: McpProxyService,
@inject(MAIN_TOKENS.PostHogCodeInternalMcpService)
private readonly internalMcp: PostHogCodeInternalMcpService,
) {}

createPosthogConfig(credentials: Credentials): AgentPosthogConfig {
Expand Down Expand Up @@ -102,6 +105,19 @@ export class AgentAuthAdapter {
],
});

try {
servers.push({
name: "posthog-code-internal",
type: "http",
url: this.internalMcp.getUrl(),
headers: [this.internalMcp.getAuthHeader()],
});
} catch (err) {
// Service should always be running by the time the agent starts a task,
// but don't take down the whole MCP config if it isn't.
log.warn("posthog-code-internal MCP not available", { error: err });
}

const installations = await this.fetchMcpInstallations(credentials);

for (const installation of installations) {
Expand Down
4 changes: 4 additions & 0 deletions apps/code/src/main/services/agent/service.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ function createMockDependencies() {
appDataPath: "/mock/userData",
logsPath: "/mock/logs",
},
internalMcp: {
on: vi.fn(),
},
};
}

Expand Down Expand Up @@ -220,6 +223,7 @@ describe("AgentService", () => {
deps.bundledResources as never,
deps.appMeta as never,
deps.storagePaths as never,
deps.internalMcp as never,
);
});

Expand Down
56 changes: 56 additions & 0 deletions apps/code/src/main/services/agent/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
import {
isMcpToolReadOnly,
isNotification,
POSTHOG_METHODS,
POSTHOG_NOTIFICATIONS,
} from "@posthog/agent";
import type { McpToolApprovals } from "@posthog/agent/adapters/claude/mcp/tool-metadata";
Expand Down Expand Up @@ -50,6 +51,8 @@ import { logger } from "../../utils/logger";
import { TypedEventEmitter } from "../../utils/typed-event-emitter";
import type { FsService } from "../fs/service";
import type { McpAppsService } from "../mcp-apps/service";
import { PostHogCodeInternalMcpEvent } from "../posthog-code-internal-mcp/schemas";
import type { PostHogCodeInternalMcpService } from "../posthog-code-internal-mcp/service";
import type { PosthogPluginService } from "../posthog-plugin/service";
import type { ProcessTrackingService } from "../process-tracking/service";
import type { SleepService } from "../sleep/service";
Expand Down Expand Up @@ -247,6 +250,8 @@ interface ManagedSession {
mcpToolApprovals: McpToolApprovals;
/** Maps tool keys to their installation for backend approval updates */
toolInstallations: McpToolInstallations;
/** Set when an MCP server is installed mid-turn; refresh runs after the turn ends. */
pendingMcpRefresh: boolean;
}

/** Get the agent session ID from a managed session, throwing if not set. */
Expand Down Expand Up @@ -304,6 +309,8 @@ export class AgentService extends TypedEventEmitter<AgentServiceEvents> {
private readonly appMeta: IAppMeta,
@inject(MAIN_TOKENS.StoragePaths)
private readonly storagePaths: IStoragePaths,
@inject(MAIN_TOKENS.PostHogCodeInternalMcpService)
internalMcp: PostHogCodeInternalMcpService,
) {
super();
this.processTracking = processTracking;
Expand All @@ -314,6 +321,9 @@ export class AgentService extends TypedEventEmitter<AgentServiceEvents> {
this.mcpAppsService = mcpAppsService;

powerManager.onResume(() => this.checkIdleDeadlines());
internalMcp.on(PostHogCodeInternalMcpEvent.McpServerInstalled, () => {
void this.refreshAllSessionMcpServers();
});
}

private getClaudeCliPath(): string {
Expand Down Expand Up @@ -395,6 +405,46 @@ export class AgentService extends TypedEventEmitter<AgentServiceEvents> {
this.recordActivity(taskRunId);
}

private async refreshSessionMcpServers(
session: ManagedSession,
): Promise<void> {
try {
const { servers } = await this.agentAuthAdapter.buildMcpServers(
session.config.credentials,
);
await session.clientSideConnection.extMethod(
POSTHOG_METHODS.REFRESH_SESSION,
{ mcpServers: servers },
);
log.info("Refreshed MCP servers for session", {
taskRunId: session.taskRunId,
serverCount: servers.length,
});
} catch (err) {
log.warn("Failed to refresh MCP servers for session", {
taskRunId: session.taskRunId,
err,
});
}
}

private async refreshAllSessionMcpServers(): Promise<void> {
const refreshable: ManagedSession[] = [];
for (const session of this.sessions.values()) {
if (session.promptPending) {
// ACP refresh contract requires no prompt in flight; defer until the
// turn completes (see prompt() finally block).
session.pendingMcpRefresh = true;
log.info("Deferring MCP refresh until current turn ends", {
taskRunId: session.taskRunId,
});
continue;
}
refreshable.push(session);
}
await Promise.all(refreshable.map((s) => this.refreshSessionMcpServers(s)));
}

/**
* Check if any sessions are currently active (i.e. have a prompt pending).
*/
Expand Down Expand Up @@ -797,6 +847,7 @@ When creating pull requests, add the following footer at the end of the PR descr
inFlightMcpToolCalls: new Map(),
mcpToolApprovals: toolApprovals,
toolInstallations,
pendingMcpRefresh: false,
};

this.sessions.set(taskRunId, session);
Expand Down Expand Up @@ -885,6 +936,11 @@ When creating pull requests, add the following footer at the end of the PR descr
this.recordActivity(sessionId);
this.sleepService.release(sessionId);

if (session.pendingMcpRefresh) {
session.pendingMcpRefresh = false;
void this.refreshSessionMcpServers(session);
}

if (!this.hasActiveSessions()) {
this.emit(AgentServiceEvent.SessionsIdle, undefined);
}
Expand Down
19 changes: 19 additions & 0 deletions apps/code/src/main/services/posthog-code-internal-mcp/schemas.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { z } from "zod";

export const customInstructionsChanged = z.object({
customInstructions: z.string(),
});

export type CustomInstructionsChanged = z.infer<
typeof customInstructionsChanged
>;

export const PostHogCodeInternalMcpEvent = {
CustomInstructionsChanged: "custom-instructions-changed",
McpServerInstalled: "mcp-server-installed",
} as const;

export interface PostHogCodeInternalMcpEvents {
[PostHogCodeInternalMcpEvent.CustomInstructionsChanged]: CustomInstructionsChanged;
[PostHogCodeInternalMcpEvent.McpServerInstalled]: Record<never, never>;
}
105 changes: 105 additions & 0 deletions apps/code/src/main/services/posthog-code-internal-mcp/service.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";

// electron-store mkdir's userDataDir at import time, which fails in CI where
// the default mocked path (/mock/userData) isn't writable. The tests below
// don't exercise the store paths, so a no-op mock is safe.
vi.mock("../../utils/store", () => ({
rendererStore: {
has: () => false,
get: () => undefined,
set: () => {},
},
}));

import { PostHogCodeInternalMcpEvent } from "./schemas";
import { PostHogCodeInternalMcpService } from "./service";

interface FakeAuthService {
getValidAccessToken: () => Promise<{ apiHost: string; token: string }>;
getState: () => { projectId: number };
authenticatedFetch: (
fetchImpl: typeof fetch,
url: string,
init?: RequestInit,
) => Promise<Response>;
}

const createFakeAuth = (
fetchImpl: (url: string) => Promise<Response>,
): FakeAuthService => ({
getValidAccessToken: async () => ({
apiHost: "https://example.com",
token: "t",
}),
getState: () => ({ projectId: 1 }),
authenticatedFetch: async (_f, url) => fetchImpl(String(url)),
});

const okJson = (body: unknown): Response =>
new Response(JSON.stringify(body), {
status: 200,
headers: { "Content-Type": "application/json" },
});

describe("PostHogCodeInternalMcpService.pollForOauthCompletion", () => {
beforeEach(() => {
vi.useFakeTimers();
});
afterEach(() => {
vi.useRealTimers();
vi.restoreAllMocks();
});

it("emits McpServerInstalled when pending_oauth flips to false", async () => {
const responses = [
okJson({
results: [
{ id: "abc", name: "linear", pending_oauth: true, is_enabled: true },
],
}),
okJson({
results: [
{ id: "abc", name: "linear", pending_oauth: false, is_enabled: true },
],
}),
];
const auth = createFakeAuth(async () => {
const next = responses.shift();
if (!next) throw new Error("no more responses");
return next;
});
const service = new PostHogCodeInternalMcpService(auth as never);
const handler = vi.fn();
service.on(PostHogCodeInternalMcpEvent.McpServerInstalled, handler);

const poll = (
service as unknown as {
pollForOauthCompletion: (id: string, name: string) => Promise<void>;
}
).pollForOauthCompletion("abc", "linear");

await vi.advanceTimersByTimeAsync(3500);
await vi.advanceTimersByTimeAsync(3500);
await poll;

expect(handler).toHaveBeenCalledOnce();
});

it("stops polling when installation disappears", async () => {
const auth = createFakeAuth(async () => okJson({ results: [] }));
const service = new PostHogCodeInternalMcpService(auth as never);
const handler = vi.fn();
service.on(PostHogCodeInternalMcpEvent.McpServerInstalled, handler);

const poll = (
service as unknown as {
pollForOauthCompletion: (id: string, name: string) => Promise<void>;
}
).pollForOauthCompletion("abc", "linear");

await vi.advanceTimersByTimeAsync(3500);
await poll;

expect(handler).not.toHaveBeenCalled();
});
});
Loading
Loading