flat: convert steel-browser from submodule to regular folder

This commit is contained in:
2026-05-06 19:19:17 +10:00
parent 1efe7c189d
commit 865d8d0d66
32 changed files with 13160 additions and 1 deletions

View File

@@ -0,0 +1,103 @@
import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
import { resolveSessionMode, type SteelSessionMode } from "./session-mode.js";
import { SteelClient } from "./steel-client.js";
import { clickTool } from "./tools/click.js";
import { computerTool } from "./tools/computer.js";
import { extractTool } from "./tools/extract.js";
import { findElementsTool } from "./tools/find-elements.js";
import { fillFormTool } from "./tools/fill-form.js";
import { getTitleTool, getUrlTool, goBackTool } from "./tools/navigation.js";
import { navigateTool } from "./tools/navigate.js";
import { pdfTool } from "./tools/pdf.js";
import { scrapeTool } from "./tools/scrape.js";
import { screenshotTool } from "./tools/screenshot.js";
import { scrollTool } from "./tools/scroll.js";
import { pinSessionTool, releaseSessionTool } from "./tools/session-control.js";
import { typeTool } from "./tools/type.js";
import { waitTool } from "./tools/wait.js";
/**
 * Extension entry point: registers every Steel browser tool with the host and
 * subscribes session cleanup to the host's lifecycle events.
 *
 * @param pi - Extension API used to register tools and listen for lifecycle events.
 */
export default function steelExtension(pi: ExtensionAPI): void {
  const steelClient = new SteelClient();
  const defaultSessionMode = resolveSessionMode();
  let sessionMode = defaultSessionMode;
  // Holds the in-flight cleanup so overlapping lifecycle events share a single close.
  let closingSessions: Promise<void> | null = null;

  const closeSessions = async (reason: string) => {
    if (closingSessions === null) {
      const runCleanup = async (): Promise<void> => {
        try {
          await steelClient.closeAllSessions();
        } catch (error: unknown) {
          // Cleanup failures should not break the main agent response path.
          console.warn(`[steel] session cleanup failed (${reason})`, error);
        } finally {
          closingSessions = null;
        }
      };
      closingSessions = runCleanup();
    }
    await closingSessions;
  };

  // Builds a handler that closes sessions only while the given mode is active.
  const closeWhenMode =
    (mode: SteelSessionMode, reason: string) => async (): Promise<void> => {
      if (sessionMode === mode) {
        await closeSessions(reason);
      }
    };

  // Builds a handler that closes sessions regardless of the active mode.
  const closeAlways = (reason: string) => async (): Promise<void> => {
    await closeSessions(reason);
  };

  const sessionController = {
    getDefaultSessionMode: () => defaultSessionMode,
    getSessionMode: () => sessionMode,
    setSessionMode: (mode: SteelSessionMode) => {
      sessionMode = mode;
    },
    closeSessions,
  };

  const tools = [
    navigateTool(steelClient),
    scrapeTool(steelClient),
    screenshotTool(steelClient),
    pdfTool(steelClient),
    clickTool(steelClient),
    computerTool(steelClient),
    findElementsTool(steelClient),
    typeTool(steelClient),
    fillFormTool(steelClient),
    waitTool(steelClient),
    extractTool(steelClient),
    scrollTool(steelClient),
    goBackTool(steelClient),
    getUrlTool(steelClient),
    getTitleTool(steelClient),
    pinSessionTool(steelClient, sessionController),
    releaseSessionTool(steelClient, sessionController),
  ];
  tools.forEach((tool) => {
    pi.registerTool(tool);
  });

  pi.on("turn_end", closeWhenMode("turn", "turn_end"));
  pi.on("agent_end", closeWhenMode("agent", "agent_end"));

  // Defensive cleanup for interactive session switches/forks.
  pi.on("session_before_switch", closeAlways("session_before_switch"));
  pi.on("session_shutdown", closeAlways("session_shutdown"));

  // `onShutdown` is not part of the ExtensionAPI type used here, so it is probed optionally.
  const shutdownApi = pi as ExtensionAPI & {
    onShutdown?: (handler: () => Promise<void> | void) => void;
  };
  shutdownApi.onShutdown?.(closeAlways("onShutdown"));
}

View File

@@ -0,0 +1,18 @@
/** Accepted values for the STEEL_SESSION_MODE setting. */
export type SteelSessionMode = "turn" | "agent" | "session";

/**
 * Reads STEEL_SESSION_MODE from the environment (trimmed, case-insensitive).
 *
 * @returns The configured mode, or "agent" when the variable is unset, blank,
 *          or holds an unsupported value (the latter also logs a warning).
 */
export function resolveSessionMode(): SteelSessionMode {
  const requested = process.env.STEEL_SESSION_MODE?.trim().toLowerCase();
  switch (requested) {
    case "turn":
    case "agent":
    case "session":
      return requested;
    case undefined:
    case "":
      return "agent";
    default:
      console.warn(
        `[steel] unsupported STEEL_SESSION_MODE="${requested}", falling back to "agent"`
      );
      return "agent";
  }
}

View File

@@ -0,0 +1,686 @@
import fs from "node:fs";
import os from "node:os";
import path from "node:path";
import Steel from "steel-sdk";
import type {
CaptchaSolveResponse,
CaptchaStatusResponse,
} from "steel-sdk/resources/sessions";
import { chromium, type Browser, type BrowserContext, type Page } from "playwright-core";
import { toolError } from "./tools/tool-runtime.js";
// Local aliases for Steel SDK and Playwright types. The Playwright option types are
// derived from the `Page` method signatures via `Parameters<...>` instead of being
// redeclared here.
type SessionCreateOptions = Steel.SessionCreateParams;
// Resolved value of `sessions.create(...)`.
type SessionMetadata = Awaited<ReturnType<Steel["sessions"]["create"]>>;
type SessionGotoOptions = Parameters<Page["goto"]>[1];
type SessionWaitForSelectorOptions = Parameters<Page["waitForSelector"]>[1];
type SessionClickOptions = Parameters<Page["click"]>[1];
// `Page.type(selector, text, options)`: options are the third parameter.
type SessionTypeOptions = Parameters<Page["type"]>[2];
type SessionScreenshotOptions = Parameters<Page["screenshot"]>[0];
type SessionPdfOptions = Parameters<Page["pdf"]>[0];
type SessionComputerParams = Steel.SessionComputerParams;
type SessionComputerResponse = Steel.SessionComputerResponse;
// Shape of the parsed `config.json` read by readSteelConfigFile. Field values are
// `unknown` because the file contents are not validated at parse time; `null` stands
// for a missing, unreadable, or non-object file.
type SteelConfigFile = {
apiKey?: unknown;
browser?: {
apiUrl?: unknown;
} | null;
} | null;
// Result of resolveSteelRuntimeConfig.
type ResolvedSteelRuntimeConfig = {
// `null` when no API key was found in the overrides, environment, or config file.
apiKey: string | null;
// Normalized base URL; undefined when no base URL was configured.
baseURL?: string;
// True whenever a base URL was resolved from any source.
baseURLOverridden: boolean;
// Base URL used to build session viewer links; may be undefined for custom hosts.
viewerBaseURL?: string;
};
/**
 * Handle for a connected Steel session as handed to the tools: session identifiers
 * and URLs, the underlying Playwright `page`, and a set of convenience methods.
 * In SteelClient.buildLiveSession the page-style methods delegate to `page`, while
 * `computer`, `captchasStatus` and `captchasSolve` call the Steel SDK with this
 * session's id.
 */
export interface LiveSteelSession {
id: string;
// Empty string when no viewer URL could be resolved.
sessionViewerUrl: string;
// Empty string when the session metadata carries no debug URL.
debugUrl: string;
page: Page;
goto: (url: string, options?: SessionGotoOptions) => Promise<unknown>;
goBack: (options?: Parameters<Page["goBack"]>[0]) => Promise<unknown>;
// Alias of `goBack` (both call `page.goBack`).
back: (options?: Parameters<Page["goBack"]>[0]) => Promise<unknown>;
url: () => string;
title: () => Promise<string>;
waitForSelector: (
selector: string,
options?: SessionWaitForSelectorOptions
) => Promise<unknown>;
click: (selector: string, options?: SessionClickOptions) => Promise<unknown>;
fill: (selector: string, text: string) => Promise<unknown>;
type: (
selector: string,
text: string,
options?: SessionTypeOptions
) => Promise<unknown>;
evaluate: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
locator: (selector: string) => ReturnType<Page["locator"]>;
content: () => Promise<string>;
screenshot: (options?: SessionScreenshotOptions) => Promise<unknown>;
pdf: (options?: SessionPdfOptions) => Promise<unknown>;
computer: (body: SessionComputerParams) => Promise<SessionComputerResponse>;
captchasStatus: () => Promise<CaptchaStatusResponse>;
captchasSolve: () => Promise<CaptchaSolveResponse>;
}
// Internal bookkeeping record SteelClient keeps per created session: the SDK metadata,
// the Playwright objects obtained over CDP, and the LiveSteelSession handed to tools.
type TrackedSession = {
metadata: SessionMetadata;
browser: Browser;
context: BrowserContext;
page: Page;
liveSession: LiveSteelSession;
};
/** Optional settings accepted by the SteelClient constructor. */
export interface SteelClientOptions {
// Takes precedence over the constructor's positional `apiKey` argument.
apiKey?: string | null;
baseURL?: string;
// Only positive, finite values are used; otherwise STEEL_SESSION_TIMEOUT_MS or the
// built-in default applies.
sessionTimeoutMs?: number;
// Merged over the options derived from environment variables.
sessionCreateOptions?: Partial<SessionCreateOptions>;
}
/** Per-refresh proxy overrides accepted by SteelClient.refreshSession. */
export interface SessionRefreshOptions {
// When set, replaces the configured `useProxy` value for the new session.
useProxy?: boolean;
// `null` removes the configured proxy URL; a non-blank string replaces it.
proxyUrl?: string | null;
}
// Lower-case tokens accepted as boolean environment values by parseBooleanEnv.
const TRUE_ENV_VALUES = new Set(["1", "true", "yes", "on"]);
const FALSE_ENV_VALUES = new Set(["0", "false", "no", "off"]);
// NOTE(review): DEFAULT_STEEL_BASE_URL is not referenced anywhere in the visible code — confirm it is still needed.
const DEFAULT_STEEL_BASE_URL = "https://api.steel.dev";
// Viewer base URL returned by resolveViewerBaseURL for the hosted service.
const DEFAULT_STEEL_APP_URL = "https://app.steel.dev";
/**
 * Picks the directory that holds the Steel config file.
 *
 * @param input - Candidate directory (typically an environment variable value).
 * @returns The trimmed input when non-blank, otherwise `<home>/.config/steel`.
 */
function normalizeConfigDir(input: string | undefined): string {
  const explicitDir = input?.trim();
  return explicitDir ? explicitDir : path.join(os.homedir(), ".config", "steel");
}
/**
 * Loads `config.json` from the Steel config directory (STEEL_CONFIG_DIR or the default).
 *
 * @returns The parsed object, or `null` when the file cannot be read, is not valid
 *          JSON, or does not contain an object.
 */
function readSteelConfigFile(): SteelConfigFile {
  const configDir = normalizeConfigDir(process.env.STEEL_CONFIG_DIR);
  const configPath = path.join(configDir, "config.json");
  let parsed: SteelConfigFile;
  try {
    parsed = JSON.parse(fs.readFileSync(configPath, "utf-8")) as SteelConfigFile;
  } catch {
    // A missing or malformed config file is treated as "no config".
    return null;
  }
  return parsed && typeof parsed === "object" ? parsed : null;
}
/**
 * Coerces an arbitrary value to a trimmed, non-empty string.
 *
 * @param value - Any value; only strings are considered.
 * @returns The trimmed string, or `undefined` for non-strings and blank strings.
 */
function normalizeOptionalString(value: unknown): string | undefined {
  const trimmed = typeof value === "string" ? value.trim() : "";
  return trimmed === "" ? undefined : trimmed;
}
/**
 * Normalizes a user-supplied Steel API base URL for use with the SDK.
 *
 * Trailing slashes are removed, the URL must parse and use http or https, and a path
 * of exactly `/v1` is dropped.
 *
 * @param rawUrl - Base URL as configured by the user.
 * @returns The normalized URL without a trailing slash.
 * @throws A `toolError` tagged "SteelClient initialization" when the URL is empty,
 *         unparseable, or uses a protocol other than http/https.
 */
function normalizeSdkBaseURL(rawUrl: string): string {
  const trimmed = rawUrl.trim().replace(/\/+$/, "");
  if (!trimmed) {
    // Reachable with input such as "///". Reported through toolError like every
    // other validation failure in this function.
    throw toolError(
      "SteelClient initialization",
      "Steel base URL must not be empty."
    );
  }

  let parsed: URL;
  try {
    parsed = new URL(trimmed);
  } catch (error: unknown) {
    throw toolError(
      "SteelClient initialization",
      `Invalid Steel base URL: ${error instanceof Error ? error.message : "invalid URL"}`
    );
  }

  if (!["http:", "https:"].includes(parsed.protocol)) {
    throw toolError(
      "SteelClient initialization",
      "Steel base URL must use http or https."
    );
  }

  // A bare "/v1" path is removed; any other path is kept as given.
  const pathname = parsed.pathname.replace(/\/+$/, "");
  if (pathname === "/v1") {
    parsed.pathname = "";
  }
  return parsed.toString().replace(/\/+$/, "");
}
/**
 * Chooses the base URL used to build session viewer links.
 *
 * @param baseURL - Normalized API base URL, if one was configured.
 * @param overridden - Whether a base URL was configured at all.
 * @returns The default Steel app URL when no base URL is configured or when the
 *          configured host is `api.steel.dev` / a `.steel.dev` subdomain; otherwise
 *          `undefined` (including when the URL cannot be parsed).
 */
function resolveViewerBaseURL(baseURL: string | undefined, overridden: boolean): string | undefined {
  if (!overridden || !baseURL) {
    return DEFAULT_STEEL_APP_URL;
  }

  let hostname: string;
  try {
    hostname = new URL(baseURL).hostname.toLowerCase();
  } catch {
    return undefined;
  }

  const isHostedSteel = hostname === "api.steel.dev" || hostname.endsWith(".steel.dev");
  return isHostedSteel ? DEFAULT_STEEL_APP_URL : undefined;
}
/**
 * Resolves the API key and base URL from explicit overrides, environment variables,
 * and the config file.
 *
 * API key precedence: override, STEEL_API_KEY, config file `apiKey`.
 * Base URL precedence: override, STEEL_BASE_URL, STEEL_BROWSER_API_URL,
 * STEEL_LOCAL_API_URL, config file `browser.apiUrl`, STEEL_API_URL.
 *
 * @param apiKeyOverride - API key passed by the caller, if any.
 * @param baseURLOverride - Base URL passed by the caller, if any.
 * @returns The resolved runtime configuration.
 * @throws A `toolError` when neither an API key nor a base URL is available, or when
 *         the base URL is invalid.
 */
function resolveSteelRuntimeConfig(
  apiKeyOverride?: string | null,
  baseURLOverride?: string
): ResolvedSteelRuntimeConfig {
  const fileConfig = readSteelConfigFile();
  const fromEnv = (name: string): string | undefined =>
    normalizeOptionalString(process.env[name]);

  const apiKey =
    normalizeOptionalString(apiKeyOverride ?? undefined) ??
    fromEnv("STEEL_API_KEY") ??
    normalizeOptionalString(fileConfig?.apiKey) ??
    null;

  const baseURLCandidate =
    normalizeOptionalString(baseURLOverride) ??
    fromEnv("STEEL_BASE_URL") ??
    fromEnv("STEEL_BROWSER_API_URL") ??
    fromEnv("STEEL_LOCAL_API_URL") ??
    normalizeOptionalString(fileConfig?.browser?.apiUrl) ??
    fromEnv("STEEL_API_URL");

  const baseURL = baseURLCandidate ? normalizeSdkBaseURL(baseURLCandidate) : undefined;
  const baseURLOverridden = baseURL !== undefined;

  // An API key is only mandatory when no base URL was configured.
  if (!apiKey && !baseURLOverridden) {
    throw toolError(
      "SteelClient initialization",
      "STEEL_API_KEY is required. Set it in the environment, run `steel login`, or configure a custom Steel base URL for self-hosted usage."
    );
  }

  return {
    apiKey,
    baseURL,
    baseURLOverridden,
    viewerBaseURL: resolveViewerBaseURL(baseURL, baseURLOverridden),
  };
}
/**
 * Returns the first non-blank string found under any of the given keys.
 *
 * @param session - Loosely typed session object.
 * @param keys - Property names to try, in priority order.
 * @returns The trimmed value of the first matching property, or `undefined`.
 */
function getSessionFieldString(
  session: Record<string, unknown>,
  keys: readonly string[]
): string | undefined {
  const match = keys
    .map((key) => session[key])
    .find(
      (candidate): candidate is string =>
        typeof candidate === "string" && candidate.trim() !== ""
    );
  return match?.trim();
}
/**
 * Extracts the session identifier, checking `id` first and then `sessionId`.
 *
 * @param session - Loosely typed session object.
 * @returns The trimmed identifier, or `undefined` when neither field holds a non-blank string.
 */
export function resolveSessionId(session: Record<string, unknown>): string | undefined {
  const idFields: readonly string[] = ["id", "sessionId"];
  return getSessionFieldString(session, idFields);
}
/**
 * Extracts the connect URL from a session object, trying the known field names in
 * priority order.
 *
 * @param session - Loosely typed session object.
 * @returns The trimmed URL from the first populated field, or `undefined`.
 */
export function resolveSessionConnectURL(session: Record<string, unknown>): string | undefined {
  const connectFields: readonly string[] = [
    "websocketUrl",
    "wsUrl",
    "connectUrl",
    "cdpUrl",
    "browserWSEndpoint",
    "wsEndpoint",
  ];
  return getSessionFieldString(session, connectFields);
}
/**
 * Builds the URL used to connect to a session.
 *
 * When the session carries no connect URL, a `wss://connect.steel.dev` URL is built
 * from the API key and session id (both required). Otherwise the advertised URL is
 * used, with `apiKey` and `sessionId` query parameters added when they are missing.
 *
 * @param session - Loosely typed session object.
 * @param apiKey - API key to embed, if any.
 * @returns The connect URL, or `undefined` when one cannot be built.
 */
export function buildSessionConnectURL(
  session: Record<string, unknown>,
  apiKey?: string | null
): string | undefined {
  const sessionId = resolveSessionId(session);
  const advertisedURL = resolveSessionConnectURL(session);

  if (!advertisedURL) {
    return sessionId && apiKey
      ? `wss://connect.steel.dev?apiKey=${encodeURIComponent(apiKey)}&sessionId=${encodeURIComponent(sessionId)}`
      : undefined;
  }

  let parsed: URL | undefined;
  try {
    parsed = new URL(advertisedURL);
  } catch {
    parsed = undefined;
  }

  if (parsed) {
    const query = parsed.searchParams;
    if (apiKey && !query.get("apiKey")) {
      query.set("apiKey", apiKey);
    }
    if (sessionId && !query.get("sessionId")) {
      query.set("sessionId", sessionId);
    }
    return parsed.toString();
  }

  // The advertised value is not a parseable URL: append missing parameters textually.
  const missing = new URLSearchParams();
  if (apiKey && !/(?:[?&])apiKey=/.test(advertisedURL)) {
    missing.set("apiKey", apiKey);
  }
  if (sessionId && !/(?:[?&])sessionId=/.test(advertisedURL)) {
    missing.set("sessionId", sessionId);
  }
  const extraQuery = missing.toString();
  if (extraQuery === "") {
    return advertisedURL;
  }
  const joiner = advertisedURL.includes("?") ? "&" : "?";
  return `${advertisedURL}${joiner}${extraQuery}`;
}
/**
 * Determines the URL at which a session can be viewed.
 *
 * @param session - Loosely typed session object.
 * @param viewerBaseURL - Base URL used when the session advertises no viewer URL.
 * @returns The first populated viewer-style field of the session, else
 *          `<viewerBaseURL>/sessions/<id>`, else `undefined`.
 */
export function resolveSessionViewerURL(
  session: Record<string, unknown>,
  viewerBaseURL?: string
): string | undefined {
  const viewerFields: readonly string[] = [
    "sessionViewerUrl",
    "viewerUrl",
    "liveViewUrl",
    "debugUrl",
  ];
  const advertised = getSessionFieldString(session, viewerFields);
  if (advertised) {
    return advertised;
  }

  const sessionId = resolveSessionId(session);
  if (sessionId && viewerBaseURL) {
    const root = viewerBaseURL.replace(/\/+$/, "");
    return `${root}/sessions/${sessionId}`;
  }
  return undefined;
}
/**
 * Produces the session summary attached to tool results.
 *
 * @param session - Object exposing the session id and, optionally, its viewer URL.
 * @returns `sessionId` plus `sessionViewerUrl`, which is an empty string when the
 *          viewer URL is missing or not a string.
 */
export function sessionDetails(session: {
  id: string;
  sessionViewerUrl?: string | null;
}) {
  const { id, sessionViewerUrl: viewerUrl } = session;
  return {
    sessionId: id,
    sessionViewerUrl: typeof viewerUrl === "string" ? viewerUrl : "",
  };
}
/**
 * Reads a boolean from an environment variable (trimmed, case-insensitive).
 *
 * @param name - Environment variable name.
 * @returns `true`/`false` for a recognised token, `undefined` when unset or blank.
 * @throws A `toolError` when the value is set but is not a recognised token.
 */
function parseBooleanEnv(name: string): boolean | undefined {
  const token = process.env[name]?.trim().toLowerCase();
  if (!token) {
    return undefined;
  }

  const recognisedTrue = TRUE_ENV_VALUES.has(token);
  if (recognisedTrue || FALSE_ENV_VALUES.has(token)) {
    return recognisedTrue;
  }

  throw toolError(
    "SteelClient initialization",
    `${name} must be a boolean value (one of: ${[...TRUE_ENV_VALUES, ...FALSE_ENV_VALUES].join(", ")}).`
  );
}
/**
 * Reads and validates a proxy URL from an environment variable.
 *
 * @param name - Environment variable name.
 * @returns The URL in its serialized form, or `undefined` when unset or blank.
 * @throws A `toolError` when the value is not a valid http/https URL.
 */
function parseProxyUrlEnv(name: string): string | undefined {
  const candidate = process.env[name]?.trim();
  if (!candidate) {
    return undefined;
  }

  try {
    const url = new URL(candidate);
    const supported = url.protocol === "http:" || url.protocol === "https:";
    if (!supported) {
      // Caught just below so parse and protocol failures share one error shape.
      throw new Error("proxy URL protocol must be http or https");
    }
    return url.toString();
  } catch (error: unknown) {
    throw toolError(
      "SteelClient initialization",
      `${name} is invalid: ${error instanceof Error ? error.message : "invalid URL"}`
    );
  }
}
/**
 * Reads a trimmed string from an environment variable.
 *
 * @param name - Environment variable name.
 * @returns The trimmed value, or `undefined` when unset or blank.
 */
function parseStringEnv(name: string): string | undefined {
  const value = process.env[name]?.trim();
  return value ? value : undefined;
}
/**
 * Collects session-creation options from the STEEL_* environment variables.
 *
 * Only variables that are set contribute a property. STEEL_SESSION_CREDENTIALS adds
 * an empty `credentials` object when it parses to true.
 *
 * @returns The options found in the environment (possibly empty).
 * @throws A `toolError` from the parsers when a variable holds an invalid value.
 */
function resolveSessionCreateOptionsFromEnv(): Partial<SessionCreateOptions> {
  // Every variable is parsed up front, in this order, so validation errors surface
  // before any option is assembled.
  const solveCaptcha = parseBooleanEnv("STEEL_SOLVE_CAPTCHA");
  const useProxy = parseBooleanEnv("STEEL_USE_PROXY");
  const proxyUrl = parseProxyUrlEnv("STEEL_PROXY_URL");
  const headless = parseBooleanEnv("STEEL_SESSION_HEADLESS");
  const persistProfile = parseBooleanEnv("STEEL_SESSION_PERSIST_PROFILE");
  const useCredentials = parseBooleanEnv("STEEL_SESSION_CREDENTIALS");
  const region = parseStringEnv("STEEL_SESSION_REGION");
  const profileId = parseStringEnv("STEEL_SESSION_PROFILE_ID");
  const namespace = parseStringEnv("STEEL_SESSION_NAMESPACE");

  return {
    ...(solveCaptcha !== undefined ? { solveCaptcha } : {}),
    ...(useProxy !== undefined ? { useProxy } : {}),
    ...(proxyUrl !== undefined ? { proxyUrl } : {}),
    ...(headless !== undefined ? { headless } : {}),
    ...(persistProfile !== undefined ? { persistProfile } : {}),
    ...(useCredentials ? { credentials: {} } : {}),
    ...(region !== undefined ? { region } : {}),
    ...(profileId !== undefined ? { profileId } : {}),
    ...(namespace !== undefined ? { namespace } : {}),
  };
}
/**
 * Creates Steel sessions through the SDK, attaches Playwright to them over CDP, and
 * tracks them so they can be closed and released later.
 *
 * One session is treated as "current"; every created session is also kept in a map
 * keyed by session id.
 */
export class SteelClient {
  /** Timeout used when neither the options nor STEEL_SESSION_TIMEOUT_MS supply one (30 minutes). */
  private static readonly DEFAULT_SESSION_TIMEOUT_MS = 30 * 60 * 1000;
  private readonly client: Steel;
  private readonly apiKey: string | null;
  private readonly sessionTimeoutMs: number;
  private readonly sessionCreateOptions: Partial<SessionCreateOptions>;
  private readonly viewerBaseURL?: string;
  private currentSession: TrackedSession | null = null;
  private readonly sessions = new Map<string, TrackedSession>();
  /** In-flight creation, shared by concurrent `getOrCreateSession` callers. */
  private creatingSession: Promise<TrackedSession> | null = null;

  /**
   * @param apiKey - API key; `options.apiKey` takes precedence when provided.
   * @param options - Base URL, timeout, and session-creation overrides.
   * @throws A `toolError` when the runtime configuration or an environment variable is invalid.
   */
  constructor(apiKey?: string, options: SteelClientOptions = {}) {
    const runtimeConfig = resolveSteelRuntimeConfig(
      options.apiKey ?? apiKey,
      options.baseURL
    );

    const configuredTimeout =
      options.sessionTimeoutMs === undefined
        ? undefined
        : Number(options.sessionTimeoutMs);
    const fallbackTimeout = Number.parseInt(
      process.env.STEEL_SESSION_TIMEOUT_MS || "",
      10
    );

    // Non-positive or non-finite values are ignored at both levels.
    const normalizedConfiguredTimeout =
      typeof configuredTimeout === "number" &&
      Number.isFinite(configuredTimeout) &&
      configuredTimeout > 0
        ? configuredTimeout
        : undefined;
    const normalizedFallbackTimeout =
      Number.isFinite(fallbackTimeout) && fallbackTimeout > 0
        ? fallbackTimeout
        : undefined;
    const resolvedTimeout =
      normalizedConfiguredTimeout ??
      normalizedFallbackTimeout ??
      SteelClient.DEFAULT_SESSION_TIMEOUT_MS;

    this.client = new Steel({
      steelAPIKey: runtimeConfig.apiKey,
      baseURL: runtimeConfig.baseURL,
    });
    this.apiKey = runtimeConfig.apiKey;
    this.viewerBaseURL = runtimeConfig.viewerBaseURL;
    this.sessionTimeoutMs = resolvedTimeout;
    // Explicit options win over values derived from the environment.
    this.sessionCreateOptions = {
      ...resolveSessionCreateOptionsFromEnv(),
      ...(options.sessionCreateOptions ?? {}),
    };
  }

  /**
   * Returns the current session, creating one if none exists. Concurrent callers
   * share a single in-flight creation.
   */
  async getOrCreateSession(): Promise<LiveSteelSession> {
    if (this.currentSession) {
      return this.currentSession.liveSession;
    }
    if (!this.creatingSession) {
      this.creatingSession = this.createSession();
    }
    const tracked = await this.creatingSession;
    return tracked.liveSession;
  }

  /** Id of the current session, or `null` when there is none. */
  getCurrentSessionId(): string | null {
    return this.currentSession?.metadata.id ?? null;
  }

  /** Whether a current session exists. */
  hasActiveSession(): boolean {
    return this.currentSession !== null;
  }

  /** Whether the configured session options request a proxy (URL or `useProxy`). */
  isProxyConfigured(): boolean {
    const { useProxy, proxyUrl } = this.sessionCreateOptions;
    if (typeof proxyUrl === "string" && proxyUrl.trim().length > 0) {
      return true;
    }
    if (typeof useProxy === "boolean") {
      return useProxy;
    }
    // Any non-boolean, defined `useProxy` value counts as configured.
    return useProxy !== undefined;
  }

  /**
   * Closes the current session (if any) and creates a new one, applying the given
   * proxy overrides on top of the configured options.
   */
  async refreshSession(options: SessionRefreshOptions = {}): Promise<LiveSteelSession> {
    const currentSessionId = this.currentSession?.metadata.id;
    if (currentSessionId) {
      await this.closeSession(currentSessionId);
    }
    this.creatingSession = this.createSession(
      this.resolveSessionCreateOptions(options)
    );
    const tracked = await this.creatingSession;
    return tracked.liveSession;
  }

  /**
   * Stops tracking a session, then closes its browser connection and releases it.
   * Failures of either step are ignored.
   *
   * @param sessionId - Session to close; defaults to the current session.
   */
  async closeSession(sessionId?: string): Promise<void> {
    const targetSessionId = sessionId ?? this.currentSession?.metadata.id;
    if (!targetSessionId) {
      return;
    }
    const tracked = this.sessions.get(targetSessionId);
    this.sessions.delete(targetSessionId);
    if (this.currentSession?.metadata.id === targetSessionId) {
      this.currentSession = null;
    }
    if (!tracked) {
      return;
    }
    await Promise.allSettled([
      tracked.browser.close(),
      this.client.sessions.release(targetSessionId),
    ]);
  }

  /**
   * Closes every tracked session: browser connections first, then a release per
   * session. When every individual release fails, falls back to `releaseAll`, whose
   * failure propagates to the caller.
   */
  async closeAllSessions(): Promise<void> {
    const trackedSessions = [...this.sessions.values()];
    const sessionIds = trackedSessions.map((tracked) => tracked.metadata.id);
    this.sessions.clear();
    this.currentSession = null;
    this.creatingSession = null;
    if (sessionIds.length === 0) {
      return;
    }
    await Promise.allSettled(
      trackedSessions.map((tracked) => tracked.browser.close())
    );
    const releaseResult = await Promise.allSettled(
      sessionIds.map((sessionId) => this.client.sessions.release(sessionId))
    );
    const allRejected = releaseResult.every((entry) => entry.status === "rejected");
    if (allRejected) {
      await this.client.sessions.releaseAll();
    }
  }

  /** Merges per-refresh proxy overrides into the configured session options. */
  private resolveSessionCreateOptions(
    options: SessionRefreshOptions = {}
  ): Partial<SessionCreateOptions> {
    const merged: Partial<SessionCreateOptions> = {
      ...this.sessionCreateOptions,
    };
    if (options.useProxy !== undefined) {
      merged.useProxy = options.useProxy;
      // Turning the proxy off without naming a URL also drops the configured URL.
      if (options.useProxy === false && options.proxyUrl === undefined) {
        delete merged.proxyUrl;
      }
    }
    if (options.proxyUrl === null) {
      delete merged.proxyUrl;
    } else if (typeof options.proxyUrl === "string" && options.proxyUrl.trim()) {
      merged.proxyUrl = options.proxyUrl.trim();
    }
    return merged;
  }

  /**
   * Creates a Steel session, connects Playwright to it over CDP, and registers it as
   * the current session.
   *
   * If any step after the session has been created fails, the partially set-up
   * session is rolled back (browser connection closed, session released) so it is not
   * left running until its timeout.
   *
   * @throws A `toolError` tagged "SteelClient session creation" wrapping the failure.
   */
  private async createSession(
    createOptions: Partial<SessionCreateOptions> = this.sessionCreateOptions
  ): Promise<TrackedSession> {
    let createdSessionId: string | undefined;
    let connectedBrowser: Browser | undefined;
    try {
      const session = await this.client.sessions.create({
        ...createOptions,
        timeout: this.sessionTimeoutMs,
        blockAds: true,
      });
      createdSessionId = session.id;

      const websocketUrl = buildSessionConnectURL(
        session as unknown as Record<string, unknown>,
        this.apiKey
      );
      if (!websocketUrl) {
        throw new Error("Steel session did not include a connect URL.");
      }

      const browser = await chromium.connectOverCDP(websocketUrl);
      connectedBrowser = browser;
      // Reuse the context/page exposed by the connection when there is one.
      const context = browser.contexts()[0] ?? (await browser.newContext());
      const page = context.pages()[0] ?? (await context.newPage());
      const liveSession = this.buildLiveSession(session, page);
      const tracked: TrackedSession = {
        metadata: session,
        browser,
        context,
        page,
        liveSession,
      };
      this.sessions.set(session.id, tracked);
      this.currentSession = tracked;
      return tracked;
    } catch (error: unknown) {
      // Best-effort rollback; rollback failures must not mask the original error.
      const rollback: Promise<unknown>[] = [];
      if (connectedBrowser) {
        rollback.push(connectedBrowser.close());
      }
      if (createdSessionId) {
        rollback.push(this.client.sessions.release(createdSessionId));
      }
      await Promise.allSettled(rollback);
      throw toolError("SteelClient session creation", error);
    } finally {
      this.creatingSession = null;
    }
  }

  /**
   * Wraps a session and its page in the `LiveSteelSession` facade: page-style methods
   * delegate to `page`; `computer` and the captcha methods call the Steel SDK.
   */
  private buildLiveSession(
    session: SessionMetadata,
    page: Page
  ): LiveSteelSession {
    const sessionId =
      resolveSessionId(session as unknown as Record<string, unknown>) ?? session.id;
    return {
      id: sessionId,
      sessionViewerUrl:
        resolveSessionViewerURL(
          session as unknown as Record<string, unknown>,
          this.viewerBaseURL
        ) ?? "",
      debugUrl: session.debugUrl || "",
      page,
      goto: (url, options) => page.goto(url, options),
      goBack: (options) => page.goBack(options),
      back: (options) => page.goBack(options),
      url: () => page.url(),
      title: () => page.title(),
      waitForSelector: (selector, options) =>
        options
          ? page.waitForSelector(selector, options)
          : page.waitForSelector(selector),
      click: (selector, options) => page.click(selector, options),
      fill: (selector, text) => page.fill(selector, text),
      type: (selector, text, options) => page.type(selector, text, options),
      evaluate: <T>(fn: (...args: any[]) => T, ...args: any[]) =>
        page.evaluate(fn, ...args),
      locator: (selector: string) => page.locator(selector),
      content: () => page.content(),
      screenshot: (options) => page.screenshot(options),
      pdf: (options) => page.pdf(options),
      computer: (body) => this.client.sessions.computer(sessionId, body),
      captchasStatus: () => this.client.sessions.captchas.status(sessionId),
      captchasSolve: () => this.client.sessions.captchas.solve(sessionId),
    };
  }
}

View File

@@ -0,0 +1,321 @@
import {
emitProgress,
isAbortError,
sleepWithSignal,
throwIfAborted,
type ToolProgressUpdater,
} from "./tool-runtime.js";
// Environment variables that tune captcha recovery.
const CAPTCHA_WAIT_MS_ENV = "STEEL_CAPTCHA_WAIT_MS";
const CAPTCHA_MAX_RETRIES_ENV = "STEEL_CAPTCHA_MAX_RETRIES";
const CAPTCHA_POLL_INTERVAL_MS_ENV = "STEEL_CAPTCHA_POLL_INTERVAL_MS";
// Defaults used when the corresponding variable is unset or invalid.
const DEFAULT_CAPTCHA_WAIT_MS = 45_000;
const DEFAULT_CAPTCHA_MAX_RETRIES = 1;
const DEFAULT_CAPTCHA_POLL_INTERVAL_MS = 1_500;
// Bounds the configured values are clamped to.
const MIN_CAPTCHA_WAIT_MS = 1_000;
const MAX_CAPTCHA_WAIT_MS = 180_000;
const MIN_CAPTCHA_POLL_INTERVAL_MS = 250;
const MAX_CAPTCHA_POLL_INTERVAL_MS = 10_000;
const MAX_CAPTCHA_RETRIES = 3;
// One element of the array returned by a session's `captchasStatus()`, reduced to
// the two fields hasActiveCaptcha inspects.
type CaptchaStatusEntry = {
isSolvingCaptcha?: boolean;
// Treated as active work when it is a non-empty array.
tasks?: unknown;
};
/**
 * Minimal session shape needed for captcha recovery. Both captcha methods are
 * optional; recovery checks for their presence before calling them.
 */
export type CaptchaAwareSession = {
id: string;
captchasStatus?: () => Promise<unknown>;
captchasSolve?: () => Promise<unknown>;
};
/** Counters describing what captcha recovery did during one runWithCaptchaRecovery call. */
export type CaptchaRecoverySummary = {
// True once a retriable error started a recovery step.
triggered: boolean;
// Number of recovery steps started.
retries: number;
// Number of `captchasSolve` calls attempted.
solveAttempts: number;
// Number of `captchasStatus` calls that returned successfully.
statusChecks: number;
// Set by the recovery step when it finishes without seeing the captcha state clear.
waitTimedOut: boolean;
};
// Arguments accepted by runWithCaptchaRecovery.
type CaptchaRecoveryOptions<T> = {
session: CaptchaAwareSession;
// Passed through to emitProgress with every progress message.
context: string;
// Human-readable description of the action, used in progress messages.
actionLabel: string;
onUpdate: ToolProgressUpdater;
// The action to run and, on a retriable failure, run again.
operation: () => Promise<T>;
signal?: AbortSignal;
// Decides whether an error warrants recovery; defaults to isCaptchaInterferenceError.
shouldRetry?: (error: unknown) => boolean;
};
/**
 * Parses a strictly positive base-10 integer.
 *
 * @param raw - Text to parse (surrounding whitespace is ignored).
 * @returns The parsed value, or `null` when the input is missing, blank, not a
 *          number, or not greater than zero.
 */
function parsePositiveInt(raw: string | undefined): number | null {
  const text = raw?.trim();
  if (!text) {
    return null;
  }
  const parsed = Number.parseInt(text, 10);
  return Number.isFinite(parsed) && parsed > 0 ? parsed : null;
}
/**
 * Resolves how long a recovery step may wait for a captcha to clear.
 *
 * @returns The configured wait clamped to the allowed range, or the default when the
 *          environment variable is unset or invalid.
 */
function resolveCaptchaWaitMs(): number {
  const configured = parsePositiveInt(process.env[CAPTCHA_WAIT_MS_ENV]);
  return configured === null
    ? DEFAULT_CAPTCHA_WAIT_MS
    : Math.max(MIN_CAPTCHA_WAIT_MS, Math.min(configured, MAX_CAPTCHA_WAIT_MS));
}
/**
 * Resolves how many captcha-recovery retries are allowed.
 *
 * A value of 0 is valid and disables retries. Unset, blank, non-numeric, or negative
 * values fall back to the default; larger values are capped at the maximum.
 *
 * @returns The retry budget, between 0 and the maximum inclusive.
 */
function resolveCaptchaMaxRetries(): number {
  const raw = process.env[CAPTCHA_MAX_RETRIES_ENV]?.trim();
  if (!raw) {
    return DEFAULT_CAPTCHA_MAX_RETRIES;
  }
  // Parsed here rather than with parsePositiveInt: that helper rejects 0, which made
  // it impossible to turn retries off through the environment.
  const parsed = Number.parseInt(raw, 10);
  if (!Number.isFinite(parsed) || parsed < 0) {
    return DEFAULT_CAPTCHA_MAX_RETRIES;
  }
  return Math.min(parsed, MAX_CAPTCHA_RETRIES);
}
/**
 * Resolves the delay between captcha status polls.
 *
 * @returns The configured interval clamped to the allowed range, or the default when
 *          the environment variable is unset or invalid.
 */
function resolveCaptchaPollIntervalMs(): number {
  const configured = parsePositiveInt(process.env[CAPTCHA_POLL_INTERVAL_MS_ENV]);
  if (configured === null) {
    return DEFAULT_CAPTCHA_POLL_INTERVAL_MS;
  }
  const capped = Math.min(configured, MAX_CAPTCHA_POLL_INTERVAL_MS);
  return Math.max(MIN_CAPTCHA_POLL_INTERVAL_MS, capped);
}
/**
 * Produces lower-cased text for an arbitrary thrown value.
 *
 * @param error - Anything that was thrown.
 * @returns The lower-cased `message` for Error instances, the lower-cased string for
 *          strings, and the lower-cased `String(...)` form otherwise (empty for
 *          null/undefined).
 */
function normalizeErrorText(error: unknown): string {
  let text: string;
  if (error instanceof Error) {
    text = error.message;
  } else if (typeof error === "string") {
    text = error;
  } else {
    text = String(error ?? "");
  }
  return text.toLowerCase();
}
/**
 * Heuristically decides whether an error looks like it was caused by a captcha.
 *
 * @param error - Anything that was thrown.
 * @returns True when the lower-cased error text contains one of the known markers.
 */
export function isCaptchaInterferenceError(error: unknown): boolean {
  const markers = [
    "captcha",
    "hcaptcha",
    "recaptcha",
    "intercepts pointer events",
  ];
  const message = normalizeErrorText(error);
  return markers.some((marker) => message.includes(marker));
}
/**
 * Narrows an untyped captcha status response to a list of entry objects.
 *
 * @param value - Raw value returned by the status call.
 * @returns The non-null object elements of `value` when it is an array, else an empty array.
 */
function normalizeCaptchaStatusEntries(value: unknown): CaptchaStatusEntry[] {
  return Array.isArray(value)
    ? value.filter(
        (entry): entry is CaptchaStatusEntry =>
          entry !== null && typeof entry === "object"
      )
    : [];
}
/**
 * Reports whether any status entry indicates captcha work in progress.
 *
 * @param entries - Normalized status entries.
 * @returns True when an entry has a truthy `isSolvingCaptcha` or a non-empty `tasks` array.
 */
function hasActiveCaptcha(entries: CaptchaStatusEntry[]): boolean {
  return entries.some((entry) => {
    const solving = Boolean(entry.isSolvingCaptcha);
    const hasTasks = Array.isArray(entry.tasks) && entry.tasks.length > 0;
    return solving || hasTasks;
  });
}
/**
 * Reads the session's captcha status once.
 *
 * @param session - Session to query.
 * @param summary - Recovery counters; `statusChecks` is incremented after a successful read.
 * @param signal - Abort signal checked before the read.
 * @returns The normalized status entries, or an empty array when the session does not
 *          expose `captchasStatus`.
 */
async function tryReadCaptchaStatus(
  session: CaptchaAwareSession,
  summary: CaptchaRecoverySummary,
  signal: AbortSignal | undefined
): Promise<CaptchaStatusEntry[]> {
  throwIfAborted(signal);
  if (typeof session.captchasStatus === "function") {
    const rawStatus = await session.captchasStatus();
    summary.statusChecks += 1;
    return normalizeCaptchaStatusEntries(rawStatus);
  }
  return [];
}
/**
 * Runs one captcha recovery step: read the status, request a solve, then poll the
 * status until no captcha is active, the wait budget is used up, or polling fails.
 *
 * Failures of the status and solve calls are reported through `onUpdate` and do not
 * abort the step; abort errors are always rethrown.
 *
 * @param session - Session whose captcha endpoints are used, when it exposes them.
 * @param context - Context passed to every progress message.
 * @param actionLabel - Description of the action that will be retried afterwards.
 * @param onUpdate - Progress sink.
 * @param summary - Recovery counters, updated in place. `waitTimedOut` is set only
 *                  when the wait budget ran out, not when polling failed.
 * @param signal - Abort signal checked between the individual stages.
 */
async function runCaptchaRecoveryStep(
  session: CaptchaAwareSession,
  context: string,
  actionLabel: string,
  onUpdate: ToolProgressUpdater,
  summary: CaptchaRecoverySummary,
  signal: AbortSignal | undefined
): Promise<void> {
  throwIfAborted(signal);
  const waitMs = resolveCaptchaWaitMs();
  const pollIntervalMs = resolveCaptchaPollIntervalMs();
  // The budget starts here, so the initial status read and the solve call count against it.
  const deadline = Date.now() + waitMs;
  const describe = (error: unknown): string =>
    error instanceof Error ? error.message : "unknown error";

  let statusEntries: CaptchaStatusEntry[] = [];
  try {
    statusEntries = await tryReadCaptchaStatus(session, summary, signal);
  } catch (error: unknown) {
    if (isAbortError(error)) {
      throw error;
    }
    await emitProgress(
      onUpdate,
      context,
      `Captcha status check failed: ${describe(error)}`
    );
  }

  if (statusEntries.length > 0) {
    await emitProgress(
      onUpdate,
      context,
      `Captcha status detected for ${statusEntries.length} page(s)`
    );
  } else {
    await emitProgress(
      onUpdate,
      context,
      "No explicit captcha status returned; attempting solve anyway"
    );
  }

  if (typeof session.captchasSolve === "function") {
    throwIfAborted(signal);
    summary.solveAttempts += 1;
    try {
      const solveResult = await session.captchasSolve();
      const message =
        typeof solveResult === "object" &&
        solveResult !== null &&
        "message" in solveResult &&
        typeof (solveResult as { message?: unknown }).message === "string"
          ? (solveResult as { message: string }).message
          : "captcha solve requested";
      await emitProgress(onUpdate, context, `Captcha solve call: ${message}`);
    } catch (error: unknown) {
      if (isAbortError(error)) {
        throw error;
      }
      await emitProgress(
        onUpdate,
        context,
        `Captcha solve call failed: ${describe(error)}`
      );
    }
  } else {
    await emitProgress(
      onUpdate,
      context,
      "Session does not expose captchas.solve; proceeding with retry"
    );
  }

  // Distinguishes "polling broke" from "budget exhausted" once the loop ends.
  let pollingFailed = false;
  while (Date.now() < deadline && typeof session.captchasStatus === "function") {
    throwIfAborted(signal);
    await sleepWithSignal(pollIntervalMs, signal);
    try {
      statusEntries = await tryReadCaptchaStatus(session, summary, signal);
    } catch (error: unknown) {
      if (isAbortError(error)) {
        throw error;
      }
      await emitProgress(
        onUpdate,
        context,
        `Captcha status polling failed: ${describe(error)}`
      );
      pollingFailed = true;
      break;
    }
    if (!hasActiveCaptcha(statusEntries)) {
      await emitProgress(onUpdate, context, "Captcha state cleared; retrying action");
      return;
    }
  }

  // A polling failure has already been reported above; it is not a timeout.
  if (!pollingFailed && typeof session.captchasStatus === "function") {
    summary.waitTimedOut = true;
    await emitProgress(
      onUpdate,
      context,
      `Captcha wait reached ${waitMs}ms; retrying ${actionLabel}`
    );
  }
}
/**
 * Runs `operation`, retrying it after a captcha recovery step when it fails with a
 * retriable error, up to the configured retry budget.
 *
 * The operation's own result is discarded; callers receive the recovery summary.
 *
 * @param options - Session, progress sink, operation, and retry policy.
 * @returns Counters describing the recovery work that was done.
 * @throws The operation's error when it is an abort error, is not retriable, or the
 *         retry budget is used up; also whatever the abort checks or the recovery
 *         step throw.
 */
export async function runWithCaptchaRecovery<T>(
  options: CaptchaRecoveryOptions<T>
): Promise<CaptchaRecoverySummary> {
  const { session, context, actionLabel, onUpdate, operation, signal } = options;
  const shouldRetry = options.shouldRetry ?? isCaptchaInterferenceError;
  const maxRetries = resolveCaptchaMaxRetries();
  const summary: CaptchaRecoverySummary = {
    triggered: false,
    retries: 0,
    solveAttempts: 0,
    statusChecks: 0,
    waitTimedOut: false,
  };

  for (let attempt = 0; ; attempt += 1) {
    throwIfAborted(signal);

    let failure: unknown;
    try {
      await operation();
      return summary;
    } catch (error: unknown) {
      failure = error;
    }

    if (isAbortError(failure)) {
      throw failure;
    }
    throwIfAborted(signal);
    if (!shouldRetry(failure) || attempt >= maxRetries) {
      throw failure;
    }

    summary.triggered = true;
    summary.retries += 1;
    await emitProgress(
      onUpdate,
      context,
      `Captcha-related blocker detected while trying to ${actionLabel}`
    );
    await runCaptchaRecoveryStep(
      session,
      context,
      actionLabel,
      onUpdate,
      summary,
      signal
    );
  }
}

View File

@@ -0,0 +1,340 @@
import type { ExtensionContext, ToolDefinition } from "@mariozechner/pi-coding-agent";
import { Type } from "@sinclair/typebox";
import { sessionDetails, type SteelClient } from "../steel-client.js";
import { runWithCaptchaRecovery, type CaptchaRecoverySummary } from "./captcha-guard.js";
import {
emitProgress,
throwIfAborted,
withAbortSignal,
withToolError,
type ToolProgressUpdater,
} from "./tool-runtime.js";
import {
MAX_TOOL_TIMEOUT_MS,
resolveToolTimeoutMs,
} from "./tool-settings.js";
// Element states this tool waits for.
type WaitState = "attached" | "visible";
// Structural view of a session as used by this tool. Every capability is optional and
// may live either directly on the session or on its `page`; the helpers below probe
// for whichever is present.
type SessionLike = {
id: string;
sessionViewerUrl?: string | null;
captchasStatus?: () => Promise<unknown>;
captchasSolve?: () => Promise<unknown>;
waitForSelector?: (
selector: string,
options?: { state?: WaitState; timeout?: number }
) => Promise<unknown>;
click?: (selector: string, options?: { timeout?: number }) => Promise<unknown>;
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
locator?: (selector: string) => {
waitFor?: (options?: { state?: WaitState; timeout?: number }) => Promise<unknown>;
isVisible?: () => Promise<boolean>;
isEnabled?: () => Promise<boolean>;
click?: (options?: { timeout?: number }) => Promise<unknown>;
};
// Page-level counterparts of the session-level methods above.
page?: {
waitForSelector?: (
selector: string,
options?: { state?: WaitState; timeout?: number }
) => Promise<unknown>;
click?: (selector: string, options?: { timeout?: number }) => Promise<unknown>;
locator?: (selector: string) => {
waitFor?: (options?: { state?: WaitState; timeout?: number }) => Promise<unknown>;
isVisible?: () => Promise<boolean>;
isEnabled?: () => Promise<boolean>;
click?: (options?: { timeout?: number }) => Promise<unknown>;
};
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
};
};
/**
 * Copies the five known recovery counters into a fresh object, dropping anything else
 * the summary might carry.
 *
 * @param summary - Recovery summary to copy from.
 * @returns A new object holding only the known counters.
 */
function compactCaptchaRecovery(summary: CaptchaRecoverySummary) {
  const { triggered, retries, solveAttempts, statusChecks, waitTimedOut } = summary;
  return { triggered, retries, solveAttempts, statusChecks, waitTimedOut };
}
/**
 * Trims a selector and rejects blank input.
 *
 * @param selector - Selector as supplied by the caller.
 * @returns The trimmed selector.
 * @throws Error when the selector is empty after trimming.
 */
function normalizeSelector(selector: string): string {
  const cleaned = selector.trim();
  if (cleaned.length === 0) {
    throw new Error("Selector cannot be empty.");
  }
  return cleaned;
}
/**
 * Resolves the effective timeout for this tool by delegating to the shared tool settings.
 *
 * @param timeoutMs - Caller-supplied timeout, if any.
 * @returns The timeout produced by `resolveToolTimeoutMs`.
 */
function normalizeTimeout(timeoutMs?: number): number {
  const effectiveTimeout = resolveToolTimeoutMs(timeoutMs);
  return effectiveTimeout;
}
/**
 * Obtains a locator for the selector, preferring the session's own `locator` method
 * and falling back to the page's.
 *
 * @param session - Session to query.
 * @param selector - Selector to locate.
 * @returns The locator, or `undefined` when neither the session nor its page offers one.
 */
function getLocator(
  session: SessionLike,
  selector: string
):
  | {
      waitFor?: (options?: { state?: WaitState; timeout?: number }) => Promise<unknown>;
      isVisible?: () => Promise<boolean>;
      isEnabled?: () => Promise<boolean>;
      click?: (options?: { timeout?: number }) => Promise<unknown>;
    }
  | undefined {
  if (typeof session.locator !== "function") {
    const page = session.page;
    return typeof page?.locator === "function" ? page.locator(selector) : undefined;
  }
  return session.locator(selector);
}
/**
 * Decides whether a selector is safe to use with the `document.querySelector` fallback.
 *
 * @param selector - Selector to inspect.
 * @returns False for blank selectors and for selectors containing any of the listed
 *          engine-specific markers; true otherwise.
 */
function supportsCssSelectorFallback(selector: string): boolean {
  const nonCssMarkers = [
    ">>",
    "text=",
    "xpath=",
    "nth=",
    ":has-text(",
    ":text(",
    ":contains(",
  ];
  const candidate = selector.trim();
  return candidate !== "" && !nonCssMarkers.some((marker) => candidate.includes(marker));
}
/**
 * Waits until the element matched by `selector` reaches the "visible" state.
 *
 * The first available mechanism is used, in this order: the locator's
 * `waitFor`, the session's `waitForSelector`, then the page's
 * `waitForSelector`. When none exists the function returns without waiting.
 *
 * @param session Session providing the wait primitives.
 * @param selector Selector of the element to wait for.
 * @param timeoutMs Timeout in milliseconds forwarded to the wait primitive.
 * @param signal Optional abort signal; checked up front and raced against the wait.
 */
async function waitForTarget(
  session: SessionLike,
  selector: string,
  timeoutMs: number,
  signal: AbortSignal | undefined
): Promise<void> {
  throwIfAborted(signal);
  const locator = getLocator(session, selector);
  let pendingWait: Promise<unknown> | undefined;
  if (locator?.waitFor) {
    pendingWait = locator.waitFor({ state: "visible", timeout: timeoutMs });
  } else if (typeof session.waitForSelector === "function") {
    pendingWait = session.waitForSelector(selector, { state: "visible", timeout: timeoutMs });
  } else if (typeof session.page?.waitForSelector === "function") {
    pendingWait = session.page.waitForSelector(selector, { state: "visible", timeout: timeoutMs });
  }
  if (pendingWait) {
    await withAbortSignal(pendingWait, signal);
  }
}
/**
 * Verifies that the element matched by `selector` can currently be clicked.
 *
 * With a locator available, its `isVisible` / `isEnabled` probes (when
 * present) are consulted. Otherwise, for selectors accepted by
 * `supportsCssSelectorFallback`, the element is inspected in the page through
 * `evaluate`. When no inspection mechanism exists the check is skipped.
 *
 * @param session Session used to inspect the page.
 * @param selector Selector of the element to inspect.
 * @param signal Optional abort signal; checked up front and raced against each probe.
 * @throws Error when the element is missing, hidden, disabled, or not clickable.
 */
async function ensureClickable(
  session: SessionLike,
  selector: string,
  signal: AbortSignal | undefined
): Promise<void> {
  throwIfAborted(signal);
  const locator = getLocator(session, selector);
  if (locator) {
    if (typeof locator.isVisible === "function") {
      const visible = await withAbortSignal(locator.isVisible(), signal);
      if (!visible) {
        throw new Error(`Element is not visible: ${selector}`);
      }
    }
    if (typeof locator.isEnabled === "function") {
      const enabled = await withAbortSignal(locator.isEnabled(), signal);
      if (!enabled) {
        throw new Error(`Element is disabled and cannot be clicked: ${selector}`);
      }
    }
    return;
  }
  if (!supportsCssSelectorFallback(selector)) {
    return;
  }
  // Keep `evaluate` attached to the object that owns it: invoking it as a
  // detached function would drop `this`, which method-based implementations need.
  const evaluateOwner = typeof session.evaluate === "function" ? session : session.page;
  if (!evaluateOwner || typeof evaluateOwner.evaluate !== "function") {
    return;
  }
  const result = await withAbortSignal(
    evaluateOwner.evaluate(
      // Runs inside the page; it must not reference anything from this module.
      (input: { selector: string }) => {
        const element = document.querySelector(input.selector) as HTMLElement | null;
        if (!element) {
          return { found: false, clickable: false, disabled: false };
        }
        const style = getComputedStyle(element);
        const rect = element.getBoundingClientRect();
        const visible =
          rect.width > 0 &&
          rect.height > 0 &&
          style.display !== "none" &&
          style.visibility !== "hidden" &&
          Number.parseFloat(style.opacity) > 0;
        const disabled =
          (element as HTMLInputElement).disabled === true ||
          element.getAttribute("aria-disabled") === "true";
        const clickable = visible && !disabled && style.pointerEvents !== "none";
        return { found: true, clickable, disabled };
      },
      { selector }
    ),
    signal
  );
  if (!result || typeof result !== "object") {
    return;
  }
  const outcome = result as Record<string, unknown>;
  if (!outcome.found) {
    throw new Error(`No element matched selector: ${selector}`);
  }
  if (outcome.disabled) {
    throw new Error(`Element is disabled and cannot be clicked: ${selector}`);
  }
  if (!outcome.clickable) {
    throw new Error(`Element is not clickable: ${selector}`);
  }
}
/**
 * Clicks the element matched by `selector` using the first mechanism the
 * session offers: the locator's `click`, the session's `click`, the page's
 * `click`, and finally an in-page `element.click()` through `evaluate` for
 * selectors accepted by `supportsCssSelectorFallback`.
 *
 * @param session Session used to perform the click.
 * @param selector Selector of the element to click.
 * @param timeoutMs Timeout in milliseconds forwarded to the click primitive.
 * @param signal Optional abort signal; checked up front and raced against the click.
 * @throws Error when the in-page fallback finds no matching element, or when
 *   the session offers no usable click mechanism at all.
 */
async function invokeClick(
  session: SessionLike,
  selector: string,
  timeoutMs: number,
  signal: AbortSignal | undefined
): Promise<void> {
  throwIfAborted(signal);
  const locator = getLocator(session, selector);
  if (locator?.click) {
    await withAbortSignal(locator.click({ timeout: timeoutMs }), signal);
    return;
  }
  if (typeof session.click === "function") {
    await withAbortSignal(session.click(selector, { timeout: timeoutMs }), signal);
    return;
  }
  if (typeof session.page?.click === "function") {
    await withAbortSignal(
      session.page.click(selector, { timeout: timeoutMs }),
      signal
    );
    return;
  }
  // Keep `evaluate` attached to the object that owns it: invoking it as a
  // detached function would drop `this`, which method-based implementations need.
  const evaluateOwner = typeof session.evaluate === "function" ? session : session.page;
  if (
    evaluateOwner &&
    typeof evaluateOwner.evaluate === "function" &&
    supportsCssSelectorFallback(selector)
  ) {
    const clicked = await withAbortSignal(
      evaluateOwner.evaluate(
        // Runs inside the page; it must not reference anything from this module.
        (input: { selector: string }) => {
          const element = document.querySelector(input.selector) as HTMLElement | null;
          if (!element) {
            return false;
          }
          element.click();
          return true;
        },
        { selector }
      ),
      signal
    );
    if (clicked) {
      return;
    }
    // The fallback ran but matched nothing: report that, not a missing capability.
    throw new Error(`No element matched selector: ${selector}`);
  }
  throw new Error("Session does not support click operations.");
}
/**
 * Builds the `steel_click` tool definition.
 *
 * On execution the tool normalizes its input, obtains a session from
 * `client`, and runs wait → clickability check → click as a single operation
 * under `runWithCaptchaRecovery`, emitting progress updates along the way.
 *
 * @param client Steel client used to obtain the session.
 * @returns Tool definition exposing `selector` and an optional `timeout`.
 */
export function clickTool(client: SteelClient): ToolDefinition<any, any> {
  const toolName = "steel_click";
  const parameters = Type.Object(
    {
      selector: Type.String({ description: "CSS selector of the element to click" }),
      timeout: Type.Optional(
        Type.Integer({
          minimum: 100,
          maximum: MAX_TOOL_TIMEOUT_MS,
          description: "Maximum milliseconds to wait for the element",
        })
      ),
    }
  );
  return {
    name: toolName,
    label: "Click",
    description: "Click an element in the page",
    parameters,
    async execute(
      _toolCallId: string,
      params: { selector: string; timeout?: number },
      signal: AbortSignal | undefined,
      onUpdate: ToolProgressUpdater,
      _ctx: ExtensionContext
    ): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
      return withToolError(toolName, async () => {
        throwIfAborted(signal);
        const selector = normalizeSelector(params.selector);
        const timeoutMs = normalizeTimeout(params.timeout);
        await emitProgress(onUpdate, toolName, `Preparing click for ${selector}`);
        const pendingSession = client.getOrCreateSession();
        const session = (await withAbortSignal(pendingSession, signal)) as SessionLike;
        await emitProgress(onUpdate, toolName, "Running click sequence");

        // One attempt of the click sequence; abort is re-checked between steps.
        const attemptClick = async (): Promise<void> => {
          throwIfAborted(signal);
          await waitForTarget(session, selector, timeoutMs, signal);
          throwIfAborted(signal);
          await ensureClickable(session, selector, signal);
          throwIfAborted(signal);
          await invokeClick(session, selector, timeoutMs, signal);
        };

        const captchaRecovery = await runWithCaptchaRecovery({
          session,
          context: toolName,
          actionLabel: `click ${selector}`,
          onUpdate,
          signal,
          operation: attemptClick,
        });
        await emitProgress(onUpdate, toolName, "Click succeeded");
        return {
          content: [{ type: "text", text: `Clicked element ${selector}` }],
          details: {
            ...sessionDetails(session),
            selector,
            timeoutMs,
            clicked: true,
            captchaRecovery: compactCaptchaRecovery(captchaRecovery),
          },
        };
      }, signal);
    },
  };
}

View File

@@ -0,0 +1,456 @@
import { randomUUID } from "node:crypto";
import { promises as fs } from "node:fs";
import path from "node:path";
import type { ExtensionContext, ToolDefinition } from "@mariozechner/pi-coding-agent";
import { Type } from "@sinclair/typebox";
import type Steel from "steel-sdk";
import { sessionDetails, type SteelClient } from "../steel-client.js";
import {
emitProgress,
throwIfAborted,
withAbortSignal,
withToolError,
type ToolProgressUpdater,
} from "./tool-runtime.js";
// Request/response shapes for the computer endpoint, taken from the Steel SDK.
type SessionComputerParams = Steel.SessionComputerParams;
type SessionComputerResponse = Steel.SessionComputerResponse;
// The `action` discriminator of the SDK request type.
type ComputerAction = SessionComputerParams["action"];
// Minimal view of a session as used by this tool; `computer` is optional
// because the tool checks for its presence at runtime before calling it.
type SessionLike = {
id: string;
sessionViewerUrl?: string | null;
computer?: (body: SessionComputerParams) => Promise<SessionComputerResponse>;
};
// Flat parameter bag accepted by the tool. Only `action` is always required;
// which of the other fields are read depends on the action (see buildActionRequest).
type ComputerToolParams = {
action: ComputerAction;
screenshot?: boolean;
hold_keys?: string[];
// Expected to be an [x, y] pair; validated by normalizeCoordinatePair.
coordinates?: number[];
button?: "left" | "right" | "middle" | "back" | "forward";
click_type?: "down" | "up" | "click";
num_clicks?: number;
// Drag path: a list of [x, y] pairs.
path?: number[][];
delta_x?: number;
delta_y?: number;
keys?: string[];
duration?: number;
text?: string;
};
// Screenshot directory, relative to the process working directory.
const RELATIVE_SCREENSHOT_DIR = path.join(".artifacts", "screenshots");
// Actions accepted by this tool. Used both to build the parameter schema's
// literal union and to validate the action at runtime.
const SUPPORTED_ACTIONS: readonly ComputerAction[] = [
"move_mouse",
"click_mouse",
"drag_mouse",
"scroll",
"press_key",
"type_text",
"wait",
"take_screenshot",
"get_cursor_position",
];
/**
 * Type guard for finite numbers.
 *
 * `Number.isFinite` performs no coercion, so it is already `false` for every
 * non-number value as well as for `NaN` and the infinities.
 *
 * @param value Value to test.
 * @returns `true` only for finite values of type `number`.
 */
function isFiniteNumber(value: unknown): value is number {
  return Number.isFinite(value);
}
/**
 * Validates that `raw` is an `[x, y]` pair of finite numbers.
 *
 * @param raw Candidate coordinate array.
 * @param fieldName Field name used as the prefix of error messages.
 * @returns A new two-element tuple `[x, y]`.
 * @throws Error when `raw` is not a two-element array or holds a non-finite entry.
 */
function normalizeCoordinatePair(
  raw: number[] | undefined,
  fieldName: string
): [number, number] {
  if (!(Array.isArray(raw) && raw.length === 2)) {
    throw new Error(`${fieldName} must be [x, y].`);
  }
  const isFinitePoint = (component: unknown): component is number =>
    typeof component === "number" && Number.isFinite(component);
  const [x, y] = raw;
  if (isFinitePoint(x) && isFinitePoint(y)) {
    return [x, y];
  }
  throw new Error(`${fieldName} must contain finite numbers.`);
}
/**
 * Validates a required list of key names, trimming each entry and dropping
 * entries that are blank after trimming.
 *
 * @param raw Candidate key list.
 * @param fieldName Field name used as the prefix of error messages.
 * @returns The trimmed, non-blank keys in their original order.
 * @throws Error when `raw` is not a non-empty array, or when no entry survives trimming.
 */
function normalizeKeyList(raw: string[] | undefined, fieldName: string): string[] {
  if (!Array.isArray(raw) || raw.length === 0) {
    throw new Error(`${fieldName} must contain at least one key.`);
  }
  const keys = raw.flatMap((entry) => {
    const key = entry.trim();
    return key.length > 0 ? [key] : [];
  });
  if (keys.length === 0) {
    throw new Error(`${fieldName} must contain at least one non-empty key.`);
  }
  return keys;
}
/**
 * Normalizes the optional `hold_keys` list: entries are trimmed and blank
 * ones dropped.
 *
 * @param raw Candidate list, or `undefined` when the caller omitted it.
 * @returns The cleaned keys, or `undefined` when omitted or when nothing
 *   remains after cleaning.
 * @throws Error when `raw` is defined but not an array.
 */
function normalizeOptionalHoldKeys(raw: string[] | undefined): string[] | undefined {
  if (raw === undefined) {
    return undefined;
  }
  if (!Array.isArray(raw)) {
    throw new Error("hold_keys must be an array of key names.");
  }
  const keys = raw.flatMap((entry) => {
    const key = entry.trim();
    return key.length > 0 ? [key] : [];
  });
  return keys.length === 0 ? undefined : keys;
}
/**
 * Validates an optional duration.
 *
 * @param raw Candidate duration, or `undefined` when omitted.
 * @param fieldName Field name used as the prefix of the error message.
 * @returns `raw` unchanged, or `undefined` when it was omitted.
 * @throws Error when `raw` is defined but is not a finite number greater than zero.
 */
function normalizeDuration(
  raw: number | undefined,
  fieldName: string
): number | undefined {
  if (raw === undefined) {
    return undefined;
  }
  const isPositiveFinite = typeof raw === "number" && Number.isFinite(raw) && raw > 0;
  if (isPositiveFinite) {
    return raw;
  }
  throw new Error(`${fieldName} must be a positive number.`);
}
/**
 * Trims an action name and checks it against `SUPPORTED_ACTIONS`.
 *
 * @param action Raw action name.
 * @returns The matching supported action.
 * @throws Error naming the original input and listing the supported actions
 *   when there is no match.
 */
function normalizeAction(action: string): ComputerAction {
  const candidate = action.trim();
  const matched = SUPPORTED_ACTIONS.find((supported) => supported === candidate);
  if (matched === undefined) {
    throw new Error(
      `Unsupported action "${action}". Supported actions: ${SUPPORTED_ACTIONS.join(", ")}.`
    );
  }
  return matched;
}
/**
 * Translates the flat tool parameters into the SDK request body for the
 * requested action, validating the fields that action needs.
 *
 * Optional fields (`screenshot`, `hold_keys`, ...) are only added to the body
 * when they were supplied, so absent values never appear as explicit keys.
 *
 * @param params Raw tool parameters.
 * @returns The request body for `session.computer`.
 * @throws Error when the action is unsupported or a field required by the
 *   action is missing or malformed.
 */
function buildActionRequest(params: ComputerToolParams): SessionComputerParams {
const action = normalizeAction(params.action);
const screenshot = params.screenshot;
const holdKeys = normalizeOptionalHoldKeys(params.hold_keys);
switch (action) {
case "move_mouse": {
// Coordinates are mandatory for a move.
return {
action,
coordinates: normalizeCoordinatePair(params.coordinates, "coordinates"),
...(screenshot === undefined ? {} : { screenshot }),
...(holdKeys ? { hold_keys: holdKeys } : {}),
};
}
case "click_mouse": {
const button = params.button;
if (!button) {
throw new Error("button is required for click_mouse.");
}
const body: Extract<SessionComputerParams, { action: "click_mouse" }> = {
action,
button,
...(screenshot === undefined ? {} : { screenshot }),
...(holdKeys ? { hold_keys: holdKeys } : {}),
};
// Coordinates, click type and click count are optional for a click.
if (params.coordinates !== undefined) {
body.coordinates = normalizeCoordinatePair(params.coordinates, "coordinates");
}
if (params.click_type !== undefined) {
body.click_type = params.click_type;
}
if (params.num_clicks !== undefined) {
if (!Number.isInteger(params.num_clicks) || params.num_clicks <= 0) {
throw new Error("num_clicks must be a positive integer.");
}
body.num_clicks = params.num_clicks;
}
return body;
}
case "drag_mouse": {
if (!Array.isArray(params.path) || params.path.length < 2) {
throw new Error("path must contain at least two [x, y] coordinates.");
}
// Each point is validated individually so the error names the offending index.
const pathPairs = params.path.map((entry, index) =>
normalizeCoordinatePair(entry, `path[${index}]`)
);
return {
action,
path: pathPairs,
...(screenshot === undefined ? {} : { screenshot }),
...(holdKeys ? { hold_keys: holdKeys } : {}),
};
}
case "scroll": {
// At least one axis must be given; each supplied axis must be finite.
const hasDeltaX = params.delta_x !== undefined;
const hasDeltaY = params.delta_y !== undefined;
if (!hasDeltaX && !hasDeltaY) {
throw new Error("scroll requires delta_x, delta_y, or both.");
}
if (hasDeltaX && !isFiniteNumber(params.delta_x)) {
throw new Error("delta_x must be a finite number.");
}
if (hasDeltaY && !isFiniteNumber(params.delta_y)) {
throw new Error("delta_y must be a finite number.");
}
const body: Extract<SessionComputerParams, { action: "scroll" }> = {
action,
...(screenshot === undefined ? {} : { screenshot }),
...(holdKeys ? { hold_keys: holdKeys } : {}),
};
if (params.coordinates !== undefined) {
body.coordinates = normalizeCoordinatePair(params.coordinates, "coordinates");
}
if (hasDeltaX) {
body.delta_x = params.delta_x;
}
if (hasDeltaY) {
body.delta_y = params.delta_y;
}
return body;
}
case "press_key": {
// Note: hold_keys is not forwarded for press_key.
const duration = normalizeDuration(params.duration, "duration");
return {
action,
keys: normalizeKeyList(params.keys, "keys"),
...(duration === undefined ? {} : { duration }),
...(screenshot === undefined ? {} : { screenshot }),
};
}
case "type_text": {
// Any string is accepted, including the empty string.
if (typeof params.text !== "string") {
throw new Error("text is required for type_text.");
}
return {
action,
text: params.text,
...(screenshot === undefined ? {} : { screenshot }),
...(holdKeys ? { hold_keys: holdKeys } : {}),
};
}
case "wait": {
// Unlike press_key, wait requires a duration.
const duration = normalizeDuration(params.duration, "duration");
if (duration === undefined) {
throw new Error("duration is required for wait.");
}
return {
action,
duration,
...(screenshot === undefined ? {} : { screenshot }),
};
}
// These two actions take no additional fields.
case "take_screenshot":
return { action };
case "get_cursor_position":
return { action };
default:
throw new Error(`Unsupported action "${action}".`);
}
}
/**
 * Returns the absolute screenshot directory, resolved against the current
 * working directory at call time.
 */
function screenshotDirectory(): string {
return path.resolve(process.cwd(), RELATIVE_SCREENSHOT_DIR);
}
/**
 * Converts an artifact path into the form reported to the caller.
 *
 * Paths inside the current working directory are shown relative to it; any
 * other path is reduced to its base name.
 *
 * @param filePath Path of the artifact on disk.
 * @returns The cwd-relative path, or the base name when `filePath` is the
 *   working directory itself or lies outside it.
 */
function toArtifactDisplayPath(filePath: string): string {
  const relativePath = path.relative(process.cwd(), filePath);
  // A plain `startsWith("..")` would also match in-cwd entries such as
  // "..shots/a.png", so only an exact ".." segment counts as escaping.
  // `path.relative` can also return an absolute path (e.g. another Windows drive).
  const escapesCwd =
    relativePath === ".." ||
    relativePath.startsWith(`..${path.sep}`) ||
    path.isAbsolute(relativePath);
  if (!relativePath || escapesCwd) {
    return path.basename(filePath);
  }
  return relativePath;
}
/**
 * Ensures the screenshot directory exists and returns a fresh file path in it.
 *
 * The file name combines the current timestamp with the first eight
 * characters of a random UUID.
 *
 * @returns Absolute path of a `.png` file that has not been written yet.
 */
async function createScreenshotPath(): Promise<string> {
  const targetDir = screenshotDirectory();
  await fs.mkdir(targetDir, { recursive: true });
  const uniqueSuffix = randomUUID().slice(0, 8);
  const fileName = `steel-computer-${Date.now()}-${uniqueSuffix}.png`;
  return path.join(targetDir, fileName);
}
/**
 * Decodes a base64 image payload, accepting either bare base64 text or a
 * `data:` URL.
 *
 * @param raw Base64 text, optionally prefixed with a `data:...,` header.
 * @returns The decoded bytes.
 * @throws Error when the payload is blank, is a `data:` URL without a comma
 *   separator, or decodes to zero bytes.
 */
function decodeBase64Png(raw: string): Buffer {
  const text = raw.trim();
  if (!text) {
    throw new Error("empty base64_image payload.");
  }
  let payload = text;
  if (text.startsWith("data:")) {
    const separatorIndex = text.indexOf(",");
    // Without a comma there is no payload section; decoding the header itself
    // would silently yield garbage bytes because base64 decoding is lenient.
    if (separatorIndex === -1) {
      throw new Error("invalid base64_image payload.");
    }
    payload = text.slice(separatorIndex + 1);
  }
  const decoded = Buffer.from(payload, "base64");
  if (decoded.length === 0) {
    throw new Error("invalid base64_image payload.");
  }
  return decoded;
}
/**
 * Decodes a base64 screenshot, writes it to a new file in the screenshot
 * directory, and describes the written file.
 *
 * @param base64Image Base64 payload, as accepted by `decodeBase64Png`.
 * @returns Artifact descriptor: display path, file name, MIME type, byte size
 *   and artifact type.
 */
async function persistScreenshotArtifact(base64Image: string) {
  const imageBytes = decodeBase64Png(base64Image);
  const absolutePath = await createScreenshotPath();
  await fs.writeFile(absolutePath, imageBytes);
  const shownPath = toArtifactDisplayPath(absolutePath);
  const fileName = path.basename(shownPath);
  return {
    path: shownPath,
    fileName,
    mimeType: "image/png",
    sizeBytes: imageBytes.length,
    type: "image",
  };
}
/**
 * Builds the `steel_computer` tool definition, which forwards low-level
 * mouse, keyboard, scroll, wait and screenshot actions to `session.computer`.
 *
 * The request is validated before a session is acquired, so malformed input
 * fails without allocating a session. When the response carries a
 * `base64_image`, it is written to disk and described in the result details.
 *
 * @param client Steel client used to obtain the session.
 * @returns Tool definition for the computer action tool.
 */
export function computerTool(client: SteelClient): ToolDefinition<any, any> {
  return {
    name: "steel_computer",
    label: "Computer Action",
    description: "Execute low-level Steel computer actions (mouse, keyboard, scroll, screenshot)",
    parameters: Type.Object({
      action: Type.Union(
        SUPPORTED_ACTIONS.map((value) => Type.Literal(value)),
        { description: "Computer action type to execute" }
      ),
      screenshot: Type.Optional(
        Type.Boolean({
          description: "Request screenshot output after the action (supported by most actions)",
        })
      ),
      hold_keys: Type.Optional(
        Type.Array(Type.String(), {
          description: "Modifier keys to hold while performing supported actions",
        })
      ),
      coordinates: Type.Optional(
        Type.Array(Type.Number(), {
          minItems: 2,
          maxItems: 2,
          description: "Target coordinates as [x, y]",
        })
      ),
      button: Type.Optional(
        Type.Union(
          [
            Type.Literal("left"),
            Type.Literal("right"),
            Type.Literal("middle"),
            Type.Literal("back"),
            Type.Literal("forward"),
          ],
          { description: "Mouse button for click_mouse" }
        )
      ),
      click_type: Type.Optional(
        Type.Union(
          [Type.Literal("click"), Type.Literal("down"), Type.Literal("up")],
          { description: "Click type for click_mouse" }
        )
      ),
      num_clicks: Type.Optional(
        Type.Integer({
          minimum: 1,
          description: "Number of clicks for click_mouse",
        })
      ),
      path: Type.Optional(
        Type.Array(
          Type.Array(Type.Number(), { minItems: 2, maxItems: 2 }),
          {
            minItems: 2,
            description: "Drag path as array of [x, y] points for drag_mouse",
          }
        )
      ),
      delta_x: Type.Optional(
        Type.Number({ description: "Horizontal scroll amount for scroll" })
      ),
      delta_y: Type.Optional(
        Type.Number({ description: "Vertical scroll amount for scroll" })
      ),
      keys: Type.Optional(
        Type.Array(Type.String(), {
          minItems: 1,
          description: "Keys for press_key",
        })
      ),
      duration: Type.Optional(
        Type.Number({
          exclusiveMinimum: 0,
          description: "Duration in seconds for wait/press_key",
        })
      ),
      text: Type.Optional(
        Type.String({ description: "Text for type_text action" })
      ),
    }),
    async execute(
      _toolCallId: string,
      params: ComputerToolParams,
      signal: AbortSignal | undefined,
      onUpdate: ToolProgressUpdater,
      _ctx: ExtensionContext
    ): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
      return withToolError("steel_computer", async () => {
        throwIfAborted(signal);
        await emitProgress(onUpdate, "steel_computer", `Preparing action ${params.action}`);
        // Validate and build the request first: a malformed call should fail
        // here rather than after a session has been acquired for it.
        const requestBody = buildActionRequest(params);
        const session = (await withAbortSignal(
          client.getOrCreateSession(),
          signal
        )) as SessionLike;
        if (typeof session.computer !== "function") {
          throw new Error(
            "Current Steel client does not expose sessions.computer(). Upgrade steel-sdk to a newer version."
          );
        }
        await emitProgress(onUpdate, "steel_computer", `Dispatching ${requestBody.action}`);
        const response = await withAbortSignal(
          session.computer(requestBody),
          signal
        );
        if (response.error) {
          throw new Error(response.error);
        }
        let artifact:
          | {
              path: string;
              fileName: string;
              mimeType: string;
              sizeBytes: number;
              type: string;
            }
          | undefined;
        // Only persist when the response actually carries image data.
        if (typeof response.base64_image === "string" && response.base64_image.trim()) {
          await emitProgress(onUpdate, "steel_computer", "Persisting screenshot artifact");
          artifact = await persistScreenshotArtifact(response.base64_image);
        }
        // Append any non-blank textual output to the completion message.
        const outputParts = [response.output, response.system]
          .filter((item): item is string => typeof item === "string" && item.trim().length > 0)
          .map((item) => item.trim());
        const outputSuffix = outputParts.length > 0 ? ` ${outputParts.join(" ")}` : "";
        return {
          content: [
            {
              type: "text",
              text: `Computer action ${requestBody.action} completed.${outputSuffix}`,
            },
          ],
          details: {
            ...sessionDetails(session),
            action: requestBody.action,
            request: requestBody,
            output: response.output ?? null,
            system: response.system ?? null,
            hasScreenshot: Boolean(artifact),
            ...(artifact
              ? {
                  filePath: artifact.path,
                  fileName: artifact.fileName,
                  artifact,
                }
              : {}),
          },
        };
      }, signal);
    },
  };
}

View File

@@ -0,0 +1,621 @@
import type { ExtensionContext, ToolDefinition } from "@mariozechner/pi-coding-agent";
import { Type } from "@sinclair/typebox";
import { sessionDetails as baseSessionDetails, type SteelClient } from "../steel-client.js";
import {
emitProgress,
throwIfAborted,
withAbortSignal,
withToolError,
type ToolProgressUpdater,
} from "./tool-runtime.js";
// Minimal view of a session as used by the extraction tool. `url` may be a
// plain string or a (possibly async) getter; `evaluate` may live on the
// session itself or on its nested page.
type SessionLike = {
id: string;
sessionViewerUrl?: string | null;
url?: (() => Promise<string> | string) | string;
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
page?: {
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
};
};
// Schema node kinds accepted by the extraction contract.
type SchemaType = "object" | "array" | "string" | "number" | "integer" | "boolean" | "null";
// Leaf kinds, i.e. everything except the two container kinds.
type PrimitiveSchemaType = Exclude<SchemaType, "object" | "array">;
// Normalized schema node produced by normalizeSchema.
type ExtractionSchema = {
type: SchemaType;
// Child schemas keyed by property name (filled for object nodes).
properties: Record<string, ExtractionSchema>;
// Property names that must be present in the extracted object.
required: string[];
// Element schema of an array node.
items?: ExtractionSchema;
// Selector passed to querySelectorAll during extraction.
selector?: string;
// Attribute to read from the matched element instead of its value/text.
attribute?: string;
// When false, validation reports properties that are not declared.
additionalProperties: boolean;
};
// Set form of SchemaType, used to validate the `type` field of incoming schemas.
const ALLOWED_TYPES = new Set<SchemaType>([
"object",
"array",
"string",
"number",
"integer",
"boolean",
"null",
]);
/**
 * Asserts that `input` is a non-null, non-array object and returns it typed
 * as a record.
 *
 * @param input Value to check.
 * @param path Schema path used in the error message.
 * @returns `input` itself, typed as a string-keyed record.
 * @throws Error when `input` is a primitive, `null`, or an array.
 */
function asPlainObject(input: unknown, path: string): Record<string, unknown> {
  const isRecord = typeof input === "object" && input !== null && !Array.isArray(input);
  if (isRecord) {
    return input as Record<string, unknown>;
  }
  throw new Error(`Schema at ${path} must be an object.`);
}
/**
 * Asserts that `value` is a boolean.
 *
 * @param value Value to check.
 * @param path Schema path used in the error message.
 * @returns `value` unchanged.
 * @throws Error for any non-boolean value; no coercion is performed.
 */
function normalizeBoolean(value: unknown, path: string): boolean {
  if (typeof value !== "boolean") {
    throw new Error(`Schema at ${path} must define a boolean value.`);
  }
  return value;
}
/**
 * Normalizes an optional string field.
 *
 * @param value Raw field value.
 * @param path Schema path used in error messages.
 * @returns `undefined` when the field is absent, otherwise the trimmed string.
 * @throws Error when the value is present but not a string, or is blank after trimming.
 */
function normalizeString(value: unknown, path: string): string | undefined {
  switch (typeof value) {
    case "undefined":
      return undefined;
    case "string": {
      const text = value.trim();
      if (text.length === 0) {
        throw new Error(`Schema at ${path} must not be empty.`);
      }
      return text;
    }
    default:
      throw new Error(`Schema at ${path} must define a string value.`);
  }
}
/**
 * Normalizes the `required` list of an object schema.
 *
 * @param value Raw `required` value from the schema.
 * @param properties Normalized properties of the same object schema.
 * @param path Schema path used in error messages.
 * @returns The required property names, de-duplicated, in first-seen order;
 *   an empty list when `required` is absent.
 * @throws Error when `value` is not an array of strings, or when it names a
 *   property that is not declared in `properties`.
 */
function normalizeRequired(
  value: unknown,
  properties: Record<string, ExtractionSchema>,
  path: string
): string[] {
  if (value === undefined) {
    return [];
  }
  if (!Array.isArray(value) || !value.every((entry) => typeof entry === "string")) {
    throw new Error(`Schema at ${path} must use an array of strings for required fields.`);
  }
  const names: string[] = value;
  // Extraction only ever produces declared properties, so a required key that
  // is not declared could never be satisfied; reject it here with a clear message.
  const undeclared = names.filter(
    (name) => !Object.prototype.hasOwnProperty.call(properties, name)
  );
  if (undeclared.length > 0) {
    throw new Error(
      `Schema at ${path} lists required fields that are not declared in properties: ${undeclared.join(", ")}.`
    );
  }
  // De-duplicate so a repeated name cannot produce repeated validation errors.
  return [...new Set(names)];
}
/**
 * Determines the type of a schema node.
 *
 * An explicit `type` must be one of `ALLOWED_TYPES`. When `type` is absent it
 * is inferred: an own `properties` key means "object", otherwise an own
 * `items` key means "array".
 *
 * @param rawType Raw `type` value from the schema node.
 * @param rawSchema The schema node, consulted for inference.
 * @param path Schema path used in error messages.
 * @returns The explicit or inferred schema type.
 * @throws Error when the type is unsupported or cannot be inferred.
 */
function normalizeSchemaType(
  rawType: unknown,
  rawSchema: Record<string, unknown>,
  path: string
): SchemaType {
  if (rawType !== undefined) {
    const supported = typeof rawType === "string" && ALLOWED_TYPES.has(rawType as SchemaType);
    if (!supported) {
      throw new Error(`Schema at ${path} has unsupported type "${String(rawType)}".`);
    }
    return rawType as SchemaType;
  }
  const declares = (key: string): boolean =>
    Object.prototype.hasOwnProperty.call(rawSchema, key);
  if (declares("properties")) {
    return "object";
  }
  if (declares("items")) {
    return "array";
  }
  throw new Error(
    `Schema at ${path} must define a type or include "properties"/"items" to infer object/array shape.`
  );
}
/**
 * Normalizes the `properties` map of an object schema by normalizing each
 * property schema in turn.
 *
 * @param rawValue Raw `properties` value from the schema.
 * @param path Schema path of the `properties` map; each child path appends `.name`.
 * @returns Normalized child schemas keyed by property name; an empty map when
 *   `properties` is absent.
 * @throws Error when `rawValue` is present but not a plain object, or when a
 *   child schema is invalid.
 */
function normalizeProperties(rawValue: unknown, path: string): Record<string, ExtractionSchema> {
  if (rawValue === undefined) {
    return {};
  }
  const isRecord =
    typeof rawValue === "object" && rawValue !== null && !Array.isArray(rawValue);
  if (!isRecord) {
    throw new Error(`Schema at ${path} must use an object for properties.`);
  }
  const normalized: Record<string, ExtractionSchema> = {};
  Object.entries(rawValue as Record<string, unknown>).forEach(([name, propertySchema]) => {
    normalized[name] = normalizeSchema(propertySchema, `${path}.${name}`);
  });
  return normalized;
}
/**
 * Validates a raw schema node and converts it, recursively, into an
 * `ExtractionSchema`.
 *
 * `selector` is normalized for every node kind because extraction reads it on
 * object nodes (to pick the base element) and on array nodes (to pick the item
 * elements), not only on primitives. `attribute` applies to primitives only.
 *
 * @param rawSchema Raw schema node supplied by the caller.
 * @param path Schema path used in error messages.
 * @returns The normalized schema node.
 * @throws Error when the node, or any nested node, is malformed.
 */
function normalizeSchema(rawSchema: unknown, path: string): ExtractionSchema {
  const schemaObject = asPlainObject(rawSchema, path);
  const type = normalizeSchemaType(schemaObject.type, schemaObject, path);
  const schema: ExtractionSchema = {
    type,
    properties: {},
    required: [],
    additionalProperties: true,
  };
  const selector = normalizeString(schemaObject.selector, `${path}.selector`);
  if (selector !== undefined) {
    schema.selector = selector;
  }
  if (type === "object") {
    const properties = normalizeProperties(schemaObject.properties, `${path}.properties`);
    schema.properties = properties;
    schema.required = normalizeRequired(
      schemaObject.required,
      properties,
      `${path}.required`
    );
  } else if (type === "array") {
    // A missing `items` fails inside the recursive call, which rejects non-objects.
    schema.items = normalizeSchema(schemaObject.items, `${path}.items`);
  } else {
    const attribute = normalizeString(schemaObject.attribute, `${path}.attribute`);
    if (attribute !== undefined) {
      schema.attribute = attribute;
    }
  }
  schema.additionalProperties = normalizeBoolean(
    schemaObject.additionalProperties ?? true,
    `${path}.additionalProperties`
  );
  return schema;
}
/**
 * Returns a copy of `schema` in which every object node, at any depth,
 * has `additionalProperties` set to `false`.
 *
 * The input is not mutated. Array and primitive nodes keep their own
 * `additionalProperties` value.
 *
 * @param schema Normalized schema to tighten.
 * @returns A new schema tree with strict object nodes.
 */
function enforceStrictMode(schema: ExtractionSchema): ExtractionSchema {
  switch (schema.type) {
    case "object": {
      const strictProperties: Record<string, ExtractionSchema> = {};
      Object.entries(schema.properties).forEach(([key, propertySchema]) => {
        strictProperties[key] = enforceStrictMode(propertySchema);
      });
      return { ...schema, additionalProperties: false, properties: strictProperties };
    }
    case "array": {
      const strictItems = schema.items ? enforceStrictMode(schema.items) : undefined;
      return { ...schema, items: strictItems };
    }
    default:
      return { ...schema };
  }
}
/**
 * Reads the session's current URL.
 *
 * Sources are tried in order: a non-blank string `url`, a `url()` function,
 * then a `getCurrentUrl()` function. Functions are invoked with the session
 * as `this`. A blank or non-string result is reported as "unknown".
 *
 * @param session Session to read the URL from.
 * @returns A promise resolving to the URL, or to "unknown" when none is available.
 */
function readSessionUrl(session: SessionLike): Promise<string> {
  const orUnknown = (value: unknown): string =>
    typeof value === "string" && value.trim() ? value : "unknown";

  const direct = session.url;
  if (typeof direct === "string" && direct.trim()) {
    return Promise.resolve(direct);
  }
  if (typeof direct === "function") {
    return Promise.resolve(direct.call(session)).then(orUnknown);
  }
  const getter = (session as { getCurrentUrl?: () => Promise<string> | string }).getCurrentUrl;
  if (typeof getter === "function") {
    return Promise.resolve(getter.call(session)).then(orUnknown);
  }
  return Promise.resolve("unknown");
}
/**
 * Extends the shared session details with the page URL and the scope
 * selector used for this extraction.
 *
 * @param session Session the extraction ran against.
 * @param url URL reported for the session.
 * @param scopeSelector Scope selector in effect, or `null` when unscoped.
 * @returns The base details plus `url` and `scopeSelector`.
 */
function sessionDetails(session: SessionLike, url: string, scopeSelector: string | null) {
  const base = baseSessionDetails(session);
  return { ...base, url, scopeSelector };
}
/**
 * Assembles the textual prompt describing the extraction contract.
 *
 * @param summary Schema summary, one entry per line.
 * @param instructions Optional caller guidance; omitted from the prompt when falsy.
 * @returns Header line, optional "Instructions:" line, then the summary,
 *   joined with newlines.
 */
function buildPrompt(summary: string, instructions: string | undefined): string {
  const sections = ["Extract structured JSON from the page following this schema contract."];
  if (instructions) {
    sections.push(`Instructions: ${instructions}`);
  }
  sections.push(summary);
  return sections.join("\n");
}
/**
 * Renders a schema as one human-readable line per node, depth-first.
 *
 * Required properties carry a " (required)" marker on their own line only;
 * the marker is kept out of the path handed to descendants, so nested paths
 * stay clean (`result.user.name`, not `result.user (required).name`).
 *
 * @param schema Normalized schema to describe.
 * @param path Display path of the root node.
 * @returns Summary lines in depth-first order.
 */
function summarizeSchema(schema: ExtractionSchema, path: string): string[] {
  const describe = (node: ExtractionSchema, nodePath: string, marker: string): string[] => {
    if (node.type === "object") {
      const requiredSet = new Set(node.required);
      const lines = [`${nodePath}${marker}: object`];
      for (const [key, propertySchema] of Object.entries(node.properties)) {
        const childMarker = requiredSet.has(key) ? " (required)" : "";
        lines.push(...describe(propertySchema, `${nodePath}.${key}`, childMarker));
      }
      return lines;
    }
    if (node.type === "array") {
      const itemLines = node.items ? describe(node.items, `${nodePath}[]`, "") : [];
      return [`${nodePath}${marker}: array`, ...itemLines];
    }
    const selectorPart = node.selector ? ` selector=${node.selector}` : "";
    const attributePart = node.attribute ? ` attr=${node.attribute}` : "";
    return [`${nodePath}${marker}: ${node.type}${selectorPart}${attributePart}`];
  };
  return describe(schema, path, "");
}
/**
 * Formats a property name as a path segment: bracket-and-quote form when the
 * name contains a dot, dot form otherwise.
 *
 * @param name Property name.
 * @returns `["name"]` or `.name`.
 */
function toPathPart(name: string): string {
  if (name.indexOf(".") !== -1) {
    return `["${name}"]`;
  }
  return `.${name}`;
}
/**
 * Appends a validation error formatted as `path: message` to `errors`.
 *
 * @param errors Accumulator, mutated in place.
 * @param path Path of the offending value.
 * @param message Description of the problem.
 */
function pushError(errors: string[], path: string, message: string): void {
  const entry = [path, message].join(": ");
  errors.push(entry);
}
/**
 * Recursively checks an extracted value against its schema, collecting every
 * mismatch instead of stopping at the first one.
 *
 * For objects the checks run in this order: undeclared properties (only when
 * `additionalProperties` is false), missing required properties, then each
 * declared property that is present.
 *
 * @param value Extracted value to validate.
 * @param schema Schema node the value must satisfy.
 * @param path Display path of `value`, used as the prefix of error entries.
 * @param errors Accumulator that receives `path: message` entries.
 */
function validateExtraction(value: unknown, schema: ExtractionSchema, path: string, errors: string[]): void {
  const owns = (target: object, key: string): boolean =>
    Object.prototype.hasOwnProperty.call(target, key);

  switch (schema.type) {
    case "object": {
      const isRecord = typeof value === "object" && value !== null && !Array.isArray(value);
      if (!isRecord) {
        pushError(errors, path, "expected object");
        return;
      }
      const record = value as Record<string, unknown>;
      if (!schema.additionalProperties) {
        for (const key of Object.keys(record)) {
          if (!owns(schema.properties, key)) {
            pushError(errors, `${path}${toPathPart(key)}`, "unexpected property");
          }
        }
      }
      for (const requiredKey of schema.required) {
        if (!owns(record, requiredKey)) {
          pushError(errors, `${path}${toPathPart(requiredKey)}`, "missing required value");
        }
      }
      for (const [key, childSchema] of Object.entries(schema.properties)) {
        // Absent optional properties are fine; absent required ones were reported above.
        if (owns(record, key)) {
          validateExtraction(record[key], childSchema, `${path}${toPathPart(key)}`, errors);
        }
      }
      return;
    }
    case "array": {
      if (!Array.isArray(value)) {
        pushError(errors, path, "expected array");
        return;
      }
      const itemSchema = schema.items;
      if (!itemSchema) {
        return;
      }
      for (const [index, item] of value.entries()) {
        validateExtraction(item, itemSchema, `${path}[${index}]`, errors);
      }
      return;
    }
    case "string": {
      if (typeof value !== "string") {
        pushError(errors, path, "expected string");
      }
      return;
    }
    case "number":
    case "integer": {
      if (typeof value !== "number" || !Number.isFinite(value)) {
        pushError(errors, path, "expected finite number");
      } else if (schema.type === "integer" && !Number.isInteger(value)) {
        pushError(errors, path, "expected integer");
      }
      return;
    }
    case "boolean": {
      if (typeof value !== "boolean") {
        pushError(errors, path, "expected boolean");
      }
      return;
    }
    default: {
      // Remaining kind is "null": only the literal null value is accepted.
      if (value !== null) {
        pushError(errors, path, "expected null");
      }
    }
  }
}
/**
 * Replaces every non-breaking space (U+00A0) with a regular space and trims
 * the result.
 *
 * @param raw Text to clean; non-string input yields an empty string.
 * @returns The cleaned text.
 */
function trimAndNormalizeText(raw: string | null | undefined): string {
  return typeof raw === "string" ? raw.split("\u00a0").join(" ").trim() : "";
}
/**
 * Runs the schema-driven extraction inside the page via the session's
 * `evaluate`.
 *
 * The callback passed to `evaluate` executes in the page, so it is fully
 * self-contained and receives the schema and scope selector as its argument.
 *
 * @param session Session exposing `evaluate`, directly or through `page`.
 * @param schema Normalized schema describing what to extract.
 * @param scopeSelector Optional selector of the container to extract from;
 *   the whole document is used when it is `null` or matches nothing.
 * @returns The extracted value, shaped by `schema`; `undefined` when the root
 *   node cannot be resolved.
 * @throws Error when the session offers no `evaluate` function.
 */
async function extractWithBrowser(
  session: SessionLike,
  schema: ExtractionSchema,
  scopeSelector: string | null
): Promise<unknown> {
  // Keep `evaluate` attached to the object that owns it: invoking it as a
  // detached function would drop `this`, which method-based implementations need.
  const evaluateOwner = typeof session.evaluate === "function" ? session : session.page;
  if (!evaluateOwner || typeof evaluateOwner.evaluate !== "function") {
    throw new Error("Session does not support DOM-based extraction.");
  }
  return evaluateOwner.evaluate(
    (input: { schema: ExtractionSchema; scopeSelector: string | null }): unknown => {
      const cleanText = (value: string | null): string => {
        if (typeof value !== "string") {
          return "";
        }
        return value.replace(/\u00a0/g, " ").trim();
      };
      // Falls back to the whole document when the scope is absent or unmatched.
      const resolveScope = (scope: string | null): ParentNode => {
        if (!scope) {
          return document;
        }
        const root = document.querySelector(scope);
        if (!root) {
          return document;
        }
        return root;
      };
      // Converts raw text to the requested primitive. Unparseable numbers come
      // back as NaN, which the caller turns into "no value".
      const coercePrimitive = (
        source: string,
        schemaType: "string" | "number" | "integer" | "boolean" | "null"
      ): string | number | boolean | null => {
        const normalized = cleanText(source);
        if (schemaType === "string") {
          return normalized;
        }
        if (schemaType === "boolean") {
          const value = normalized.toLowerCase();
          if (["true", "1", "yes", "on"].includes(value)) {
            return true;
          }
          if (["false", "0", "no", "off"].includes(value)) {
            return false;
          }
          // Any other text counts as true when non-empty.
          return Boolean(normalized);
        }
        if (schemaType === "number" || schemaType === "integer") {
          // Strip currency symbols, thousands separators and similar decoration.
          const sanitized = normalized.replace(/[^0-9.-]/g, "");
          const parsed = Number.parseFloat(sanitized);
          if (!Number.isFinite(parsed)) {
            return Number.NaN;
          }
          if (schemaType === "integer") {
            return Number.isInteger(parsed) ? parsed : Number.NaN;
          }
          return parsed;
        }
        return null;
      };
      // Without a selector the context node itself is the only candidate.
      const findBySelector = (ctx: ParentNode, selector: string | undefined): ParentNode[] => {
        if (!selector) {
          return [ctx];
        }
        if (!("querySelectorAll" in ctx)) {
          return [];
        }
        return Array.from(ctx.querySelectorAll(selector)) as ParentNode[];
      };
      // Maps a non-finite number (failed numeric coercion) to "no value".
      const dropNonFinite = (
        casted: string | number | boolean | null
      ): string | number | boolean | null | undefined =>
        typeof casted === "number" && !Number.isFinite(casted) ? undefined : casted;
      const readPrimitiveValue = (
        ctx: ParentNode,
        targetSchema: ExtractionSchema
      ): string | number | boolean | null | undefined => {
        const selector = targetSchema.selector;
        const attr = targetSchema.attribute;
        const candidates = findBySelector(ctx, selector);
        if (!candidates[0] || !(candidates[0] instanceof Element)) {
          return undefined;
        }
        const element = candidates[0] as Element & { value?: unknown };
        const primitiveType = targetSchema.type as PrimitiveSchemaType;
        if (attr) {
          const attributeValue = element.getAttribute(attr);
          if (attributeValue === null) {
            return undefined;
          }
          return dropNonFinite(coercePrimitive(attributeValue, primitiveType));
        }
        // Form controls expose their content through `value`, not text content.
        if (element instanceof HTMLInputElement || element instanceof HTMLTextAreaElement) {
          return dropNonFinite(
            coercePrimitive(String((element as HTMLInputElement).value ?? ""), primitiveType)
          );
        }
        return dropNonFinite(coercePrimitive(element.textContent ?? "", primitiveType));
      };
      const extract = (ctx: ParentNode, currentSchema: ExtractionSchema): unknown => {
        if (currentSchema.type === "object") {
          const base = currentSchema.selector ? findBySelector(ctx, currentSchema.selector)[0] : ctx;
          if (!base || (!(base instanceof Element) && base !== document)) {
            return undefined;
          }
          const result: Record<string, unknown> = {};
          for (const [key, childSchema] of Object.entries(currentSchema.properties)) {
            const childValue = extract(base, childSchema);
            // Properties without a value are left out rather than set to undefined.
            if (childValue !== undefined) {
              result[key] = childValue;
            }
          }
          return result;
        }
        if (currentSchema.type === "array") {
          if (!currentSchema.items) {
            return [];
          }
          // An array node without a selector has no elements to iterate.
          const nodes = currentSchema.selector ? findBySelector(ctx, currentSchema.selector) : [];
          if (nodes.length === 0) {
            return [];
          }
          const extracted: unknown[] = [];
          for (const node of nodes) {
            if (node instanceof Element) {
              extracted.push(extract(node, currentSchema.items));
            }
          }
          return extracted;
        }
        return readPrimitiveValue(ctx, currentSchema);
      };
      const root = resolveScope(input.scopeSelector);
      return extract(root, input.schema);
    },
    { schema, scopeSelector }
  );
}
/**
 * Builds the `steel_extract` tool definition.
 *
 * On execution the tool normalizes the supplied schema (made strict unless
 * `strict` is false), runs the DOM extraction in the session's page, validates
 * the result against the schema, and returns it as pretty-printed JSON.
 *
 * @param client Steel client used to obtain the session.
 * @returns Tool definition for structured extraction.
 */
export function extractTool(client: SteelClient): ToolDefinition<any, any> {
  const toolName = "steel_extract";
  const parameters = Type.Object({
    schema: Type.Object({}, { additionalProperties: true, description: "JSON-Schema-like extraction contract." }),
    instructions: Type.Optional(
      Type.String({ description: "Optional extraction guidance used to disambiguate field selection." })
    ),
    scopeSelector: Type.Optional(
      Type.String({ description: "Optional CSS selector that scopes extraction to a container." })
    ),
    strict: Type.Optional(
      Type.Boolean({ description: "Reject properties not defined in schema (default true)." })
    ),
  });
  return {
    name: toolName,
    label: "Extract",
    description: "Extract structured values from page content using a JSON Schema contract",
    parameters,
    async execute(
      _toolCallId: string,
      params: {
        schema: Record<string, unknown>;
        instructions?: string;
        scopeSelector?: string;
        strict?: boolean;
      },
      signal: AbortSignal | undefined,
      onUpdate: ToolProgressUpdater,
      _ctx: ExtensionContext
    ): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
      return withToolError(toolName, async () => {
        throwIfAborted(signal);
        const scopeSelector = normalizeString(params.scopeSelector, "scopeSelector") ?? null;
        const strict = params.strict ?? true;
        await emitProgress(onUpdate, toolName, "Preparing structured extraction");

        const baseSchema = normalizeSchema(params.schema, "schema");
        const activeSchema = strict ? enforceStrictMode(baseSchema) : baseSchema;
        const schemaSummary = summarizeSchema(activeSchema, "result").join("\n");
        const prompt = buildPrompt(schemaSummary, params.instructions);

        const pendingSession = client.getOrCreateSession();
        const session = (await withAbortSignal(pendingSession, signal)) as SessionLike;
        throwIfAborted(signal);
        const url = await readSessionUrl(session);
        const promptLineCount = prompt.split("\n").length;
        await emitProgress(onUpdate, toolName, `Preparing prompt with ${promptLineCount} lines`);

        const extracted = await withAbortSignal(
          extractWithBrowser(session, activeSchema, scopeSelector),
          signal
        );

        // Collect every mismatch so the caller sees the full list at once.
        const validationErrors: string[] = [];
        validateExtraction(extracted, activeSchema, "result", validationErrors);
        if (validationErrors.length > 0) {
          const bulletList = validationErrors.map((error) => `- ${error}`).join("\n");
          throw new Error(`Extraction result does not match requested schema:\n${bulletList}`);
        }
        await emitProgress(onUpdate, toolName, "Extraction validated");
        return {
          content: [{
            type: "text",
            text: JSON.stringify(extracted, null, 2),
          }],
          details: {
            ...sessionDetails(session, url, scopeSelector),
            schemaEnforced: strict,
            prompt,
          },
        };
      }, signal);
    },
  };
}

View File

@@ -0,0 +1,304 @@
import type { ExtensionContext, ToolDefinition } from "@mariozechner/pi-coding-agent";
import { Type } from "@sinclair/typebox";
import { sessionDetails, type SteelClient } from "../steel-client.js";
import { runWithCaptchaRecovery, type CaptchaRecoverySummary } from "./captcha-guard.js";
import {
emitProgress,
isAbortError,
throwIfAborted,
withAbortSignal,
withToolError,
type ToolProgressUpdater,
} from "./tool-runtime.js";
import {
MAX_TOOL_TIMEOUT_MS,
resolveToolTimeoutMs,
} from "./tool-settings.js";
/**
 * Structural view of the browser session as used by the fill-form tool.
 *
 * Apart from `id`, every member is optional: the helpers in this file probe
 * each capability with `typeof ... === "function"` before using it, checking the
 * session itself first and then the nested `page` object.
 */
type SessionLike = {
id: string;
sessionViewerUrl?: string | null;
// Captcha hooks are not called directly in this file; presumably consumed by
// runWithCaptchaRecovery in captcha-guard -- confirm there.
captchasStatus?: () => Promise<unknown>;
captchasSolve?: () => Promise<unknown>;
waitForSelector?: (
selector: string,
options?: { state?: "attached" | "visible"; timeout?: number }
) => Promise<unknown>;
fill?: (selector: string, text: string) => Promise<unknown>;
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
locator?: (selector: string) => {
fill?: (text: string) => Promise<unknown>;
waitFor?: (options?: { state?: "attached" | "visible"; timeout?: number }) => Promise<unknown>;
};
// Same capability set, nested one level down; used as the fallback when the
// session does not expose the method itself.
page?: {
waitForSelector?: (
selector: string,
options?: { state?: "attached" | "visible"; timeout?: number }
) => Promise<unknown>;
fill?: (selector: string, text: string) => Promise<unknown>;
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
locator?: (selector: string) => {
fill?: (text: string) => Promise<unknown>;
waitFor?: (options?: { state?: "attached" | "visible"; timeout?: number }) => Promise<unknown>;
};
};
};
/** One field to fill: the CSS selector of the input and the text to put into it. */
type FieldInput = {
selector: string;
value: string;
};
/**
 * Per-field outcome reported in the tool's `details.results`.
 *
 * Only the length of the value is recorded, never the value itself.
 */
type FieldResult = {
selector: string;
status: "success" | "error";
// Error message for a failed field; left unset on success.
reason?: string;
valueLength: number;
// Set only for fields that were filled successfully.
captchaRecovery?: {
triggered: boolean;
retries: number;
solveAttempts: number;
statusChecks: number;
waitTimedOut: boolean;
};
};
/**
 * Reduces a captcha recovery summary to the five counters/flags that are
 * reported per field, dropping any other properties the summary carries.
 *
 * @param summary Summary returned by the captcha recovery wrapper.
 * @returns A fresh object with `triggered`, `retries`, `solveAttempts`,
 * `statusChecks` and `waitTimedOut`.
 */
function compactCaptchaRecovery(summary: CaptchaRecoverySummary) {
  const { triggered, retries, solveAttempts, statusChecks, waitTimedOut } = summary;
  return { triggered, retries, solveAttempts, statusChecks, waitTimedOut };
}
/**
 * Trims a selector and rejects it when nothing is left.
 *
 * @param selector Raw selector text.
 * @returns The selector without surrounding whitespace.
 * @throws Error when the selector is empty or whitespace-only.
 */
function normalizeSelector(selector: string): string {
  const cleaned = selector.trim();
  if (cleaned.length === 0) {
    throw new Error("Selector cannot be empty.");
  }
  return cleaned;
}
/**
 * Resolves the per-field timeout by delegating to `resolveToolTimeoutMs`.
 *
 * @param timeoutMs Optional caller-supplied timeout in milliseconds.
 * @returns Whatever `resolveToolTimeoutMs` yields for that input.
 */
function normalizeTimeout(timeoutMs?: number): number {
  const resolved = resolveToolTimeoutMs(timeoutMs);
  return resolved;
}
/**
 * Normalization hook for field values. Currently the identity: the value is
 * passed through untouched (no trimming), so whitespace is preserved.
 *
 * @param raw Value supplied for a field.
 * @returns The same string.
 */
function normalizeValue(raw: string): string {
  const value = raw;
  return value;
}
/**
 * Converts untyped tool input into a list of well-formed field entries.
 *
 * Anything that is not an array yields an empty list. Entries that are not
 * objects, or whose `selector`/`value` are not both strings, are dropped
 * silently. Surviving selectors go through `normalizeSelector`, so an empty
 * selector makes the whole call throw.
 *
 * @param input Raw `fields` parameter.
 * @returns The accepted entries in their original order.
 */
function asArray(input: unknown): FieldInput[] {
  if (!Array.isArray(input)) {
    return [];
  }
  const accepted: FieldInput[] = [];
  for (const candidate of input) {
    if (typeof candidate !== "object" || candidate === null) {
      continue;
    }
    const { selector, value } = candidate as Partial<FieldInput>;
    if (typeof selector !== "string" || typeof value !== "string") {
      continue;
    }
    accepted.push({
      selector: normalizeSelector(selector),
      value: normalizeValue(value),
    });
  }
  return accepted;
}
/**
 * Makes sure the element addressed by `selector` is available before filling.
 *
 * Strategy, in order:
 * 1. `session.waitForSelector` (waits for the element to be visible),
 * 2. `session.page.waitForSelector`,
 * 3. a one-shot `evaluate` that only checks that the selector matches an
 *    element (no waiting and no visibility check),
 * 4. if none of these capabilities exist, the check is skipped.
 *
 * @param session Browser session (or session-like wrapper).
 * @param selector CSS selector of the field.
 * @param timeoutMs Timeout handed to `waitForSelector`.
 * @throws Error when the `evaluate` fallback finds no matching element; errors
 * from `waitForSelector` propagate unchanged.
 */
async function ensureField(session: SessionLike, selector: string, timeoutMs: number): Promise<void> {
  if (typeof session.waitForSelector === "function") {
    await session.waitForSelector(selector, { state: "visible", timeout: timeoutMs });
    return;
  }
  if (typeof session.page?.waitForSelector === "function") {
    await session.page.waitForSelector(selector, { state: "visible", timeout: timeoutMs });
    return;
  }

  // Remember which object `evaluate` came from and invoke it as a method of
  // that object. Calling the extracted function bare would drop `this`, which
  // breaks implementations that are ordinary class methods.
  const evaluateOwner = session.evaluate != null ? session : session.page;
  const evaluate = evaluateOwner?.evaluate;
  if (typeof evaluate !== "function") {
    return;
  }

  const found = await evaluate.call(
    evaluateOwner,
    (rawSelector: string) => Boolean(document.querySelector(rawSelector)),
    selector
  );
  if (!found) {
    throw new Error(`No element matched selector: ${selector}`);
  }
}
/**
 * Writes `value` into the field addressed by `selector`, using the first
 * capability the session offers:
 * 1. `session.fill`,
 * 2. `session.page.fill`,
 * 3. `locator(selector).fill` (session first, then page),
 * 4. an `evaluate` fallback that assigns `element.value` and dispatches
 *    bubbling `input` and `change` events.
 *
 * @param session Browser session (or session-like wrapper).
 * @param selector CSS selector of the field.
 * @param value Text to set.
 * @throws Error when no capability is available, or when the `evaluate`
 * fallback cannot find the element.
 */
async function fill(session: SessionLike, selector: string, value: string): Promise<void> {
  if (typeof session.fill === "function") {
    await session.fill(selector, value);
    return;
  }
  if (typeof session.page?.fill === "function") {
    await session.page.fill(selector, value);
    return;
  }

  const locator =
    typeof session.locator === "function"
      ? session.locator(selector)
      : session.page?.locator?.(selector);
  const locatorFill = locator?.fill;
  if (typeof locatorFill === "function") {
    await locatorFill.call(locator, value);
    return;
  }

  // Keep the owner alongside the function so `evaluate` is invoked as a method,
  // mirroring the `.call(locator, ...)` above. A bare call would lose `this`.
  const evaluateOwner = session.evaluate != null ? session : session.page;
  const evaluate = evaluateOwner?.evaluate;
  if (typeof evaluate !== "function") {
    throw new Error("Session does not support setting input values.");
  }

  const ok = await evaluate.call(
    evaluateOwner,
    (input: { selector: string; value: string }) => {
      const element = document.querySelector(input.selector) as HTMLInputElement | HTMLTextAreaElement | null;
      if (!element) {
        return false;
      }
      element.value = input.value;
      // Notify listeners the same way a user edit would.
      element.dispatchEvent(new Event("input", { bubbles: true }));
      element.dispatchEvent(new Event("change", { bubbles: true }));
      return true;
    },
    { selector, value }
  );
  if (!ok) {
    throw new Error(`Could not set value for selector: ${selector}`);
  }
}
/**
 * Creates the `steel_fill_form` tool, which fills several input fields in one
 * call.
 *
 * Fields are processed sequentially. A failure on one field is recorded in
 * that field's result and does not stop the remaining fields; only an abort
 * error is rethrown immediately. The call as a whole fails when no usable
 * field was supplied or when not a single field could be filled.
 *
 * @param client Steel client used to obtain the browser session.
 * @returns The tool definition registered under the name `steel_fill_form`.
 */
export function fillFormTool(client: SteelClient): ToolDefinition<any, any> {
return {
name: "steel_fill_form",
label: "Fill Form",
description: "Fill multiple input fields in a single tool call",
parameters: Type.Object({
fields: Type.Array(
Type.Object({
selector: Type.String({ description: "CSS selector for the field" }),
value: Type.String({ description: "Value for the field" }),
})
),
timeout: Type.Optional(
Type.Integer({
minimum: 100,
maximum: MAX_TOOL_TIMEOUT_MS,
description: "Maximum milliseconds to wait for each field",
})
),
}),
async execute(
_toolCallId: string,
params: { fields: unknown; timeout?: number },
signal: AbortSignal | undefined,
onUpdate: ToolProgressUpdater,
_ctx: ExtensionContext
): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
return withToolError("steel_fill_form", async () => {
throwIfAborted(signal);
// asArray drops malformed entries silently, so `fields.length` may be
// smaller than the number of entries the caller sent.
const fields = asArray(params.fields);
if (!fields.length) {
throw new Error("At least one field with selector and value is required.");
}
const timeoutMs = normalizeTimeout(params.timeout);
await emitProgress(onUpdate, "steel_fill_form", `Preparing ${fields.length} field(s)`);
const session = (await withAbortSignal(
client.getOrCreateSession(),
signal
)) as SessionLike;
const results: FieldResult[] = [];
let successCount = 0;
for (let index = 0; index < fields.length; index += 1) {
throwIfAborted(signal);
const entry = fields[index];
// Start pessimistic; flipped to "success" only after the fill completes.
const result: FieldResult = {
selector: entry.selector,
status: "error",
valueLength: entry.value.length,
};
await emitProgress(onUpdate, "steel_fill_form", `Processing ${index + 1}/${fields.length}: ${entry.selector}`);
try {
// The wait-then-fill pair is handed to runWithCaptchaRecovery as a single
// operation (retry behaviour is defined in captcha-guard, not visible here).
const captchaRecovery = await runWithCaptchaRecovery({
session,
context: "steel_fill_form",
actionLabel: `fill ${entry.selector}`,
onUpdate,
signal,
operation: async () => {
throwIfAborted(signal);
await withAbortSignal(
ensureField(session, entry.selector, timeoutMs),
signal
);
throwIfAborted(signal);
await withAbortSignal(fill(session, entry.selector, entry.value), signal);
},
});
result.status = "success";
result.captchaRecovery = compactCaptchaRecovery(captchaRecovery);
successCount += 1;
await emitProgress(onUpdate, "steel_fill_form", `Filled ${entry.selector}`);
} catch (error) {
// Aborts cancel the whole tool call; any other failure is recorded for
// this field and the loop continues with the next one.
if (isAbortError(error)) {
throw error;
}
result.reason = error instanceof Error ? error.message : "Unknown error";
}
results.push(result);
}
// Partial success is reported as success; total failure is an error.
if (successCount === 0) {
throw new Error("No form fields were filled successfully.");
}
await emitProgress(onUpdate, "steel_fill_form", `Filled ${successCount}/${fields.length} field(s).`);
return {
content: [
{
type: "text",
text:
successCount === fields.length
? `Filled ${fields.length} form field(s).`
: `Filled ${successCount}/${fields.length} form fields. Some fields failed.`,
},
],
details: {
...sessionDetails(session),
timeoutMs,
total: fields.length,
successCount,
results,
},
};
}, signal);
},
};
}

View File

@@ -0,0 +1,316 @@
import type { ExtensionContext, ToolDefinition } from "@mariozechner/pi-coding-agent";
import { Type } from "@sinclair/typebox";
import { sessionDetails as baseSessionDetails, type SteelClient } from "../steel-client.js";
import {
emitProgress,
throwIfAborted,
withAbortSignal,
withToolError,
type ToolProgressUpdater,
} from "./tool-runtime.js";
import {
blankPageError,
isBlankPageUrl,
readSessionUrl,
} from "./session-state.js";
/**
 * Structural view of the browser session as used by the find-elements tool.
 * Only `id` is required; `evaluate` is looked up on the session first and then
 * on the nested `page`.
 */
type SessionLike = {
id: string;
sessionViewerUrl?: string | null;
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
page?: {
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
};
// URL accessors are not read directly in this file; presumably consumed by
// readSessionUrl from session-state -- confirm there.
url?: (() => Promise<string> | string) | string;
getCurrentUrl?: () => Promise<string> | string;
};
/** One element reported by the discovery scan. */
type Candidate = {
// Selector suggestion built in the page; may be a `text=...` expression
// rather than plain CSS when no attribute-based selector is available.
selector: string;
// Whitespace-collapsed text content, cut to 200 characters.
text: string;
// Lower-cased tag name.
tag: string;
// Value of the `role` attribute, or null when absent.
role: string | null;
clickable: boolean;
// Always true in returned results: invisible elements are filtered out.
visible: boolean;
};
/** Upper bound on the number of candidates a single call may return. */
const MAX_RESULT_LIMIT = 25;

/**
 * Turns the optional `limit` parameter into a usable result cap.
 *
 * - `undefined` selects the default of 10.
 * - Fractional values are truncated toward zero.
 * - The result is capped at `MAX_RESULT_LIMIT`.
 *
 * @param rawLimit Caller-supplied limit, if any.
 * @returns An integer between 1 and `MAX_RESULT_LIMIT`.
 * @throws Error when the value is not finite or truncates to less than 1.
 */
function normalizeLimit(rawLimit?: number): number {
  if (rawLimit === undefined) {
    return 10;
  }
  const parsed = Number(rawLimit);
  // Validate the truncated value: checking `parsed > 0` alone lets inputs such
  // as 0.5 through, which then truncate to a limit of 0.
  const whole = Math.trunc(parsed);
  if (!Number.isFinite(parsed) || whole < 1) {
    throw new Error("limit must be a positive integer.");
  }
  return Math.min(MAX_RESULT_LIMIT, whole);
}
/**
 * Normalizes an optional text filter.
 *
 * @param value Raw filter text, if supplied.
 * @returns The trimmed text, or null when the value is missing or blank.
 */
function normalizeOptionalString(value?: string): string | null {
  if (value === undefined) {
    return null;
  }
  const trimmed = value.trim();
  return trimmed.length > 0 ? trimmed : null;
}
/**
 * Extends the shared session details with the current page URL.
 *
 * @param session Session to describe.
 * @param url URL to report alongside the base details.
 * @returns The base details with `url` added (overriding any base `url`).
 */
function sessionDetails(session: SessionLike, url: string) {
  const base = baseSessionDetails(session);
  return { ...base, url };
}
/**
 * Scans the current page for elements matching the given filters and returns
 * selector suggestions for them.
 *
 * The scan runs inside the page through the session's `evaluate` capability
 * (session first, then `session.page`). Elements are visited in document
 * order; an element is reported when it passes the tag, role, text-query and
 * clickable filters and is visible. At most `input.limit` candidates are
 * returned.
 *
 * @param session Browser session (or session-like wrapper).
 * @param input Filters: `query` is matched case-insensitively against text,
 * `aria-label` and `title`; `tag` and `role` must match exactly
 * (case-insensitively); `clickableOnly` keeps only likely interactive elements.
 * @returns The matching candidates, or an empty array when `evaluate` returns
 * something that is not an array.
 * @throws Error when the session offers no `evaluate` capability.
 */
async function discoverElements(
  session: SessionLike,
  input: {
    query: string | null;
    tag: string | null;
    role: string | null;
    limit: number;
    clickableOnly: boolean;
  }
): Promise<Candidate[]> {
  // Keep the owner together with the function so `evaluate` is invoked as a
  // method; calling the extracted function bare would lose `this`.
  const evaluateOwner = session.evaluate != null ? session : session.page;
  const evaluate = evaluateOwner?.evaluate;
  if (typeof evaluate !== "function") {
    throw new Error("Session does not support element discovery.");
  }

  // Everything inside this callback executes in the page, so it must be
  // self-contained and may only use the `params` argument.
  const results = await evaluate.call(
    evaluateOwner,
    (params: {
      query: string | null;
      tag: string | null;
      role: string | null;
      limit: number;
      clickableOnly: boolean;
    }) => {
      const toLower = (value: string | null | undefined): string =>
        String(value || "").toLowerCase();
      const normalize = (value: string | null | undefined): string =>
        String(value || "").replace(/\s+/g, " ").trim();
      // Prefer the native CSS.escape; otherwise escape only backslashes and
      // double quotes.
      const cssEscape = (value: string): string => {
        if ((window as unknown as { CSS?: { escape?: (v: string) => string } }).CSS?.escape) {
          return (window as unknown as { CSS: { escape: (v: string) => string } }).CSS.escape(value);
        }
        return value.replace(/\\/g, "\\\\").replace(/"/g, '\\"');
      };
      const isVisible = (element: Element): boolean => {
        const style = window.getComputedStyle(element);
        const rect = element.getBoundingClientRect();
        return (
          rect.width > 0 &&
          rect.height > 0 &&
          style.visibility !== "hidden" &&
          style.display !== "none" &&
          Number.parseFloat(style.opacity) > 0
        );
      };
      const isClickable = (element: Element): boolean => {
        const tag = element.tagName.toLowerCase();
        const role = element.getAttribute("role");
        if (["a", "button", "summary", "select"].includes(tag)) {
          return true;
        }
        if (tag === "input") {
          const input = element as HTMLInputElement;
          return input.type !== "hidden";
        }
        if (role === "button" || role === "link" || role === "menuitem") {
          return true;
        }
        if ((element as HTMLElement).onclick) {
          return true;
        }
        if (element.getAttribute("tabindex") !== null) {
          return true;
        }
        return false;
      };
      // Selector preference: unique id, data-testid, name, aria-label, href
      // (anchors only), then a text expression, then the bare tag name.
      const buildSelector = (element: Element): string => {
        const tag = element.tagName.toLowerCase();
        const id = element.getAttribute("id");
        if (id && document.querySelectorAll(`#${cssEscape(id)}`).length === 1) {
          return `#${cssEscape(id)}`;
        }
        const testId = element.getAttribute("data-testid");
        if (testId) {
          return `${tag}[data-testid="${cssEscape(testId)}"]`;
        }
        const name = element.getAttribute("name");
        if (name) {
          return `${tag}[name="${cssEscape(name)}"]`;
        }
        const ariaLabel = element.getAttribute("aria-label");
        if (ariaLabel) {
          return `${tag}[aria-label="${cssEscape(ariaLabel)}"]`;
        }
        if (tag === "a") {
          const href = element.getAttribute("href");
          if (href) {
            return `a[href="${cssEscape(href)}"]`;
          }
        }
        const text = normalize(element.textContent);
        if (text) {
          return `text=${text.slice(0, 80)}`;
        }
        return tag;
      };

      const queryLower = toLower(params.query);
      const tagLower = toLower(params.tag);
      const roleLower = toLower(params.role);
      const source = Array.from(document.querySelectorAll("*"));
      const candidates = source
        .map((element) => {
          // Cheap attribute filters run first so the text, computed-style and
          // layout work is only done for elements that can still match. The
          // checks are side-effect free, so the order does not affect results.
          const tag = element.tagName.toLowerCase();
          if (tagLower && tag !== tagLower) {
            return null;
          }
          const role = element.getAttribute("role");
          if (roleLower && toLower(role) !== roleLower) {
            return null;
          }
          const text = normalize(element.textContent);
          if (queryLower) {
            const searchBlob = toLower(
              `${text} ${element.getAttribute("aria-label") || ""} ${element.getAttribute("title") || ""}`
            );
            if (!searchBlob.includes(queryLower)) {
              return null;
            }
          }
          const clickable = isClickable(element);
          if (params.clickableOnly && !clickable) {
            return null;
          }
          const visible = isVisible(element);
          if (!visible) {
            return null;
          }
          return {
            selector: buildSelector(element),
            text: text.slice(0, 200),
            tag,
            role,
            clickable,
            visible,
          };
        })
        .filter((item) => Boolean(item)) as Candidate[];
      return candidates.slice(0, params.limit);
    },
    input
  );

  if (!Array.isArray(results)) {
    return [];
  }
  return results as Candidate[];
}
/**
 * Creates the `steel_find_elements` tool, which lists likely interactive
 * elements on the current page together with selector suggestions.
 *
 * The `execute` handler normalizes the filters, refuses to run on a blank
 * page, runs the in-page discovery and returns the candidates as
 * pretty-printed JSON text.
 *
 * @param client Steel client used to obtain the browser session.
 * @returns The tool definition registered under the name `steel_find_elements`.
 */
export function findElementsTool(client: SteelClient): ToolDefinition<any, any> {
return {
name: "steel_find_elements",
label: "Find Elements",
description: "Discover likely interactive elements and selector candidates",
parameters: Type.Object({
query: Type.Optional(
Type.String({ description: "Optional text query to filter by visible label/text" })
),
tag: Type.Optional(
Type.String({ description: "Optional exact tag name filter (e.g. button, a, input)" })
),
role: Type.Optional(
Type.String({ description: "Optional exact ARIA role filter (e.g. button, link)" })
),
limit: Type.Optional(
Type.Integer({
minimum: 1,
maximum: MAX_RESULT_LIMIT,
description: "Max number of candidates to return",
})
),
clickableOnly: Type.Optional(
Type.Boolean({ description: "When true, include only likely interactive elements" })
),
}),
async execute(
_toolCallId: string,
params: {
query?: string;
tag?: string;
role?: string;
limit?: number;
clickableOnly?: boolean;
},
signal: AbortSignal | undefined,
onUpdate: ToolProgressUpdater,
_ctx: ExtensionContext
): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
return withToolError("steel_find_elements", async () => {
throwIfAborted(signal);
// Blank or missing filters become null, which disables that filter.
const query = normalizeOptionalString(params.query);
const tag = normalizeOptionalString(params.tag);
const role = normalizeOptionalString(params.role);
const limit = normalizeLimit(params.limit);
// Interactive-only is the default; callers must pass false to see everything.
const clickableOnly = params.clickableOnly ?? true;
await emitProgress(onUpdate, "steel_find_elements", "Discovering page elements");
const session = (await withAbortSignal(
client.getOrCreateSession(),
signal
)) as SessionLike;
throwIfAborted(signal);
const url = await readSessionUrl(session);
// Discovery is refused while the session still shows a blank page.
if (isBlankPageUrl(url)) {
throw blankPageError("discover page elements");
}
const candidates = await withAbortSignal(
discoverElements(session, {
query,
tag,
role,
limit,
clickableOnly,
}),
signal
);
await emitProgress(
onUpdate,
"steel_find_elements",
`Found ${candidates.length} candidate element(s)`
);
return {
content: [{ type: "text", text: JSON.stringify(candidates, null, 2) }],
details: {
...sessionDetails(session, url),
query,
tag,
role,
limit,
clickableOnly,
count: candidates.length,
},
};
}, signal);
},
};
}

View File

@@ -0,0 +1,335 @@
import type { ExtensionContext, ToolDefinition } from "@mariozechner/pi-coding-agent";
import { Type } from "@sinclair/typebox";
import { sessionDetails, type SteelClient } from "../steel-client.js";
import {
emitProgress,
throwIfAborted,
withAbortSignal,
withToolError,
type ToolProgressUpdater,
} from "./tool-runtime.js";
/** Page lifecycle events a navigation can wait for. */
type WaitUntil = "load" | "domcontentloaded" | "networkidle";
/** Structural view of the browser session as used by the navigate tool. */
type SessionLike = {
id: string;
sessionViewerUrl?: string | null;
// May be synchronous or asynchronous; callers wrap the result in Promise.resolve.
goto?: (
url: string,
options?: { waitUntil?: WaitUntil }
) => Promise<unknown> | unknown;
};
// Values accepted for the `waitUntil` parameter.
const ALLOWED_WAIT_UNTIL: readonly WaitUntil[] = ["load", "domcontentloaded", "networkidle"];
// Used when the caller gives no (or an unrecognized) waitUntil value.
const DEFAULT_WAIT_UNTIL: WaitUntil = "networkidle";
// Wait modes tried, in this order, after the preferred one.
const FALLBACK_WAIT_UNTILS: readonly WaitUntil[] = ["domcontentloaded", "load"];
// Extra attempts per wait mode when STEEL_NAVIGATE_RETRY_COUNT is unset or invalid.
const DEFAULT_NAVIGATION_RETRIES = 1;
// Context label passed to emitProgress for every navigation progress message.
const NAVIGATE_CONTEXT = "steel_navigate";
/** Options forwarded to the client's `refreshSession` when recreating a session. */
type SessionRefreshOptions = {
useProxy?: boolean;
proxyUrl?: string | null;
};
/**
 * Optional client capabilities probed at runtime; the client is cast to this
 * shape and each member is checked with `typeof` before use.
 */
type SessionRefreshClient = {
refreshSession?: (options?: SessionRefreshOptions) => Promise<SessionLike>;
isProxyConfigured?: () => boolean;
};
/**
 * Picks the wait mode for a navigation.
 *
 * @param waitUntil Requested wait mode, if any.
 * @returns The requested mode when it is one of the allowed values, otherwise
 * the default mode.
 */
function resolveWaitUntil(waitUntil?: string): WaitUntil {
  const isAllowed =
    waitUntil !== undefined && ALLOWED_WAIT_UNTIL.includes(waitUntil as WaitUntil);
  return isAllowed ? (waitUntil as WaitUntil) : DEFAULT_WAIT_UNTIL;
}
/**
 * Turns user-supplied URL text into a canonical http(s) URL string.
 *
 * - Protocol-relative input (`//host/path`) gets `https:` prepended.
 * - Input that already carries a scheme is kept as is; `host:port` is
 *   recognized and not mistaken for a scheme.
 * - Anything else gets `https://` prepended.
 *
 * @param rawUrl URL text as supplied by the caller.
 * @returns The serialized URL.
 * @throws Error when the input is blank, cannot be parsed, or uses a protocol
 * other than http/https (reported with an `Invalid URL: ` prefix).
 */
function normalizeUrl(rawUrl: string): string {
  const candidate = rawUrl.trim();
  if (candidate.length === 0) {
    throw new Error("URL cannot be empty.");
  }

  let withScheme: string;
  if (candidate.startsWith("//")) {
    withScheme = `https:${candidate}`;
  } else if (/^[a-zA-Z][a-zA-Z\d+\-.]*:\/\//.test(candidate)) {
    // scheme://...
    withScheme = candidate;
  } else if (
    /^[a-zA-Z][a-zA-Z\d+\-.]*:/.test(candidate) &&
    !/^[^/\s:]+:\d+(?:[/?#]|$)/.test(candidate)
  ) {
    // scheme:... that is not merely host:port
    withScheme = candidate;
  } else {
    withScheme = `https://${candidate}`;
  }

  try {
    const parsed = new URL(withScheme);
    const isHttp = parsed.protocol === "http:" || parsed.protocol === "https:";
    if (!isHttp) {
      throw new Error("Only http and https URLs are supported.");
    }
    return parsed.toString();
  } catch (error) {
    // Both parser failures and the protocol rejection above end up here.
    const detail = error instanceof Error ? error.message : "invalid URL";
    throw new Error(`Invalid URL: ${String(detail)}`);
  }
}
/**
 * Parses the retry-count setting.
 *
 * @param raw Raw setting text (e.g. an environment variable), if set.
 * @returns The parsed base-10 integer capped at 3, or the default retry count
 * when the text is missing, blank, not a number, or negative.
 */
function normalizeRetryCount(raw: string | undefined): number {
  const text = raw === undefined ? "" : raw.trim();
  if (text === "") {
    return DEFAULT_NAVIGATION_RETRIES;
  }
  const parsed = Number.parseInt(text, 10);
  const usable = Number.isFinite(parsed) && parsed >= 0;
  return usable ? Math.min(parsed, 3) : DEFAULT_NAVIGATION_RETRIES;
}
/**
 * Heuristic: does the error text mention a timeout?
 *
 * @param error Thrown value; an `Error`'s message is inspected, anything else
 * is stringified (falsy values count as empty text).
 * @returns True when the text contains "timeout", "time out", "timed out" or
 * "timedout", ignoring case.
 */
function isTimeoutError(error: unknown): boolean {
  const text = error instanceof Error ? error.message : error || "";
  return /timed? ?out|timeout/i.test(String(text));
}
/**
 * Heuristic: does the error text look like a network-level failure?
 *
 * @param error Thrown value; an `Error`'s message is inspected, anything else
 * is stringified (falsy values count as empty text).
 * @returns True when the text contains one of the known network markers
 * (`ERR_`, `ECONN`, `ENOTFOUND`, `EAI_AGAIN`, `DNS`, `network`), ignoring case.
 */
function isNetworkError(error: unknown): boolean {
  const text = error instanceof Error ? error.message : error || "";
  return /ERR_|ECONN|ENOTFOUND|EAI_AGAIN|DNS|network/i.test(String(text));
}
/**
 * Heuristic: does the error text report a failed tunnel connection?
 *
 * @param error Thrown value; an `Error`'s message is inspected, anything else
 * is stringified (falsy values count as empty text).
 * @returns True when the text contains `TUNNEL_CONNECTION_FAILED` (with or
 * without the `ERR_` prefix), ignoring case.
 */
function isTunnelConnectionError(error: unknown): boolean {
  const text = error instanceof Error ? error.message : error || "";
  return /ERR_TUNNEL_CONNECTION_FAILED|TUNNEL_CONNECTION_FAILED/i.test(String(text));
}
/**
 * Builds the ordered list of wait modes to try for a navigation.
 *
 * @param preferred Wait mode to try first.
 * @returns `preferred` followed by the fallback modes, with duplicates removed
 * while keeping first-occurrence order.
 */
function buildWaitStrategy(preferred: WaitUntil): WaitUntil[] {
  // A Set keeps insertion order, so this drops repeats without reordering.
  const unique = new Set<WaitUntil>([preferred, ...FALLBACK_WAIT_UNTILS]);
  return Array.from(unique);
}
/**
 * Navigates the session to `targetUrl`, retrying on network errors and
 * stepping through progressively different wait modes.
 *
 * The outer loop walks the wait strategy (preferred mode first, then the
 * fallbacks); the inner loop makes up to `retryCount + 1` attempts per mode,
 * where `retryCount` comes from the STEEL_NAVIGATE_RETRY_COUNT environment
 * variable.
 *
 * @param session Session to navigate; must expose `goto`.
 * @param options Target URL, preferred wait mode, progress callback and abort signal.
 * @returns The wait mode with which navigation finally succeeded.
 * @throws The last navigation error once every mode is exhausted, or an Error
 * when the session has no `goto`.
 */
async function navigateWithRecovery(
session: SessionLike,
options: {
targetUrl: string;
waitUntil: WaitUntil;
onUpdate: ToolProgressUpdater;
signal: AbortSignal | undefined;
}
): Promise<WaitUntil> {
const { targetUrl, waitUntil, onUpdate, signal } = options;
throwIfAborted(signal);
if (!session.goto) {
throw new Error("Session does not support navigation.");
}
const retryCount = normalizeRetryCount(process.env.STEEL_NAVIGATE_RETRY_COUNT);
const waitStrategy = buildWaitStrategy(waitUntil);
let lastError: unknown = null;
for (let waitIndex = 0; waitIndex < waitStrategy.length; waitIndex += 1) {
throwIfAborted(signal);
const waitMode = waitStrategy[waitIndex];
for (let attempt = 0; attempt <= retryCount; attempt += 1) {
throwIfAborted(signal);
try {
await emitProgress(
onUpdate,
NAVIGATE_CONTEXT,
`Navigating with ${waitMode} (attempt ${attempt + 1}/${retryCount + 1})`
);
// goto may be sync or async, hence the Promise.resolve wrapper.
await withAbortSignal(
Promise.resolve(session.goto(targetUrl, { waitUntil: waitMode })),
signal
);
return waitMode;
} catch (error: unknown) {
// If the failure coincided with an abort, surface the abort instead.
throwIfAborted(signal);
lastError = error;
// Network-looking errors are retried with the SAME wait mode while
// attempts remain.
const canRetryNetwork = attempt < retryCount && isNetworkError(error);
if (canRetryNetwork) {
await emitProgress(
onUpdate,
NAVIGATE_CONTEXT,
`Network issue detected; retrying ${waitMode}`
);
continue;
}
if (
waitIndex < waitStrategy.length - 1 &&
isTimeoutError(error)
) {
await emitProgress(
onUpdate,
NAVIGATE_CONTEXT,
`Timeout on ${waitMode}; falling back to ${waitStrategy[waitIndex + 1]}`
);
}
// NOTE(review): this `break` only leaves the attempt loop, so the outer
// loop moves on to the next wait mode after ANY non-retried error, not
// just after timeouts; only the progress message above is timeout-specific.
// Confirm that falling back on non-timeout errors is intended.
break;
}
}
}
throw lastError instanceof Error
? lastError
: new Error("Navigation failed");
}
/**
 * Asks the client for a fresh browser session, when the client supports it.
 *
 * @param client Steel client; `refreshSession` is an optional capability that
 * is probed at runtime.
 * @param options Optional overrides forwarded to `refreshSession`.
 * @returns The new session, or null when the client has no `refreshSession`.
 */
async function refreshNavigationSession(
  client: SteelClient,
  options?: SessionRefreshOptions
): Promise<SessionLike | null> {
  const refreshable = client as unknown as SessionRefreshClient;
  if (typeof refreshable.refreshSession !== "function") {
    return null;
  }
  // Invoke through the client so `this` stays bound; calling an extracted
  // method reference would run it with `this === undefined`.
  return refreshable.refreshSession(options);
}
/**
 * Decides whether a retry with the proxy disabled is worth attempting.
 *
 * @param client Steel client; `isProxyConfigured` is an optional capability
 * that is probed at runtime.
 * @returns The client's answer, or false when it cannot report proxy state.
 */
function shouldTryNoProxyFallback(client: SteelClient): boolean {
  const proxyAware = client as unknown as SessionRefreshClient;
  if (typeof proxyAware.isProxyConfigured !== "function") {
    return false;
  }
  // Invoke through the client so `this` stays bound for method implementations.
  return proxyAware.isProxyConfigured();
}
/**
 * Creates the `steel_navigate` tool, which loads a URL in the browser session.
 *
 * Navigation is attempted in up to three stages; the later stages run only
 * when the previous one failed with a tunnel-connection error:
 * 1. the current session,
 * 2. a freshly recreated session,
 * 3. a recreated session with the proxy disabled (only when the client
 *    reports that a proxy is configured).
 *
 * @param client Steel client used to obtain and refresh the browser session.
 * @returns The tool definition registered under the name `steel_navigate`.
 */
export function navigateTool(client: SteelClient): ToolDefinition<any, any> {
return {
name: "steel_navigate",
label: "Navigate",
description: "Navigate to a URL in the browser",
parameters: Type.Object({
url: Type.String({ description: "The URL to navigate to" }),
waitUntil: Type.Optional(
Type.Union([
Type.Literal("load"),
Type.Literal("domcontentloaded"),
Type.Literal("networkidle"),
], { description: "When to consider navigation complete" })
),
}),
async execute(
_toolCallId: string,
params: { url: string; waitUntil?: WaitUntil },
signal: AbortSignal | undefined,
onUpdate: ToolProgressUpdater,
_ctx: ExtensionContext
): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
return withToolError("steel_navigate", async () => {
throwIfAborted(signal);
const targetUrl = normalizeUrl(params.url);
const waitUntil = resolveWaitUntil(params.waitUntil);
await emitProgress(onUpdate, NAVIGATE_CONTEXT, `Preparing navigation to ${targetUrl}`);
await emitProgress(onUpdate, NAVIGATE_CONTEXT, `Waiting for browser session`);
// `session` is reassigned when a recovery stage swaps in a new session.
let session = (await withAbortSignal(
client.getOrCreateSession(),
signal
)) as SessionLike;
let usedWaitUntil: WaitUntil;
// Records which stage finally succeeded; reported in `details.tunnelRecovery`.
let recoveryMode: "none" | "fresh_session" | "no_proxy" = "none";
try {
// Stage 1: navigate with the existing session.
usedWaitUntil = await navigateWithRecovery(session, {
targetUrl,
waitUntil,
onUpdate,
signal,
});
} catch (error: unknown) {
throwIfAborted(signal);
// Only tunnel failures trigger session recovery; everything else propagates.
if (!isTunnelConnectionError(error)) {
throw error;
}
await emitProgress(
onUpdate,
NAVIGATE_CONTEXT,
"Tunnel connection failed; recreating browser session and retrying"
);
const freshSession = await withAbortSignal(
refreshNavigationSession(client),
signal
);
// null means the client cannot refresh sessions: give up with the original error.
if (!freshSession) {
throw error;
}
session = freshSession;
try {
// Stage 2: same navigation on the fresh session.
usedWaitUntil = await navigateWithRecovery(session, {
targetUrl,
waitUntil,
onUpdate,
signal,
});
recoveryMode = "fresh_session";
} catch (freshError: unknown) {
throwIfAborted(signal);
if (
!isTunnelConnectionError(freshError) ||
!shouldTryNoProxyFallback(client)
) {
throw freshError;
}
await emitProgress(
onUpdate,
NAVIGATE_CONTEXT,
"Tunnel failure persisted; retrying once with proxy disabled"
);
const noProxySession = await withAbortSignal(
refreshNavigationSession(client, {
useProxy: false,
proxyUrl: null,
}),
signal
);
if (!noProxySession) {
throw freshError;
}
session = noProxySession;
// Stage 3: last attempt; a failure here propagates to the caller.
usedWaitUntil = await navigateWithRecovery(session, {
targetUrl,
waitUntil,
onUpdate,
signal,
});
recoveryMode = "no_proxy";
}
}
await emitProgress(onUpdate, NAVIGATE_CONTEXT, `Navigation complete to ${targetUrl}`);
return {
content: [{
type: "text",
text: `Successfully navigated to ${targetUrl}`,
}],
details: {
...sessionDetails(session),
requestedUrl: params.url,
url: targetUrl,
// `waitUntil` is the mode that actually succeeded; `requestedWaitUntil`
// is the resolved request (after defaulting), which may differ.
waitUntil: usedWaitUntil,
requestedWaitUntil: waitUntil,
tunnelRecovery:
recoveryMode === "none"
? null
: {
attempted: true,
mode: recoveryMode,
},
},
};
}, signal);
},
};
}

View File

@@ -0,0 +1,193 @@
import type { ExtensionContext, ToolDefinition } from "@mariozechner/pi-coding-agent";
import { Type } from "@sinclair/typebox";
import { sessionDetails, type SteelClient } from "../steel-client.js";
import {
emitProgress,
throwIfAborted,
withAbortSignal,
withToolError,
type ToolProgressUpdater,
} from "./tool-runtime.js";
import {
blankPageError,
describeBlankPage,
isBlankPageUrl,
readSessionTitle,
readSessionUrl,
} from "./session-state.js";
/**
 * Structural view of the browser session as used by the history/URL/title
 * tools. Only `id` is required; history navigation is looked up as `goBack`
 * first and `back` second.
 */
type SessionLike = {
id: string;
sessionViewerUrl?: string | null;
goBack?: (options?: { waitUntil?: "load" | "domcontentloaded" | "networkidle"; timeout?: number }) => Promise<unknown> | unknown;
back?: (options?: { waitUntil?: "load" | "domcontentloaded" | "networkidle"; timeout?: number }) => Promise<unknown> | unknown;
// URL/title accessors are not read directly in this file; presumably consumed
// by readSessionUrl / readSessionTitle from session-state -- confirm there.
url?: (() => Promise<string> | string) | string;
title?: (() => Promise<string> | string) | string;
getCurrentUrl?: () => Promise<string> | string;
};
// Timeout, in milliseconds, passed to goBack/back.
const GO_BACK_TIMEOUT_MS = 10_000;
/**
 * Heuristic check for timeout failures based on the error text.
 *
 * @param error Thrown value; for an `Error` the message is used, otherwise the
 * value is stringified (falsy values are treated as empty text).
 * @returns True when the text mentions a timeout ("timeout", "time out",
 * "timed out", "timedout"), case-insensitively.
 */
function isTimeoutError(error: unknown): boolean {
  let text: string;
  if (error instanceof Error) {
    text = String(error.message);
  } else {
    text = String(error || "");
  }
  return /timed? ?out|timeout/i.test(text);
}
/**
 * Creates the `steel_go_back` tool, which navigates one step back in the
 * browser history.
 *
 * A timeout raised by the history call is tolerated when the URL read
 * afterwards shows that the page did change to a known, non-blank URL; that
 * case is reported as `timeoutRecovered` in the details.
 *
 * @param client Steel client used to obtain the browser session.
 * @returns The tool definition registered under the name `steel_go_back`.
 */
export function goBackTool(client: SteelClient): ToolDefinition<any, any> {
return {
name: "steel_go_back",
label: "Go Back",
description: "Navigate back in browser history",
parameters: Type.Object({}),
async execute(
_toolCallId: string,
_params: {},
signal: AbortSignal | undefined,
onUpdate: ToolProgressUpdater,
_ctx: ExtensionContext
): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
return withToolError("steel_go_back", async () => {
throwIfAborted(signal);
await emitProgress(onUpdate, "steel_go_back", "Preparing history navigation");
const session = (await withAbortSignal(
client.getOrCreateSession(),
signal
)) as SessionLike;
// Remembered so a timeout can later be told apart from a real failure.
const previousUrl = await readSessionUrl(session);
const goBack = session.goBack ?? session.back;
if (typeof goBack !== "function") {
throw new Error("Session does not support browser history navigation.");
}
await emitProgress(onUpdate, "steel_go_back", "Returning to previous page");
let timeoutRecovered = false;
try {
// Called with the session as `this`; the call may be sync or async.
await withAbortSignal(
Promise.resolve(
goBack.call(session, {
waitUntil: "domcontentloaded",
timeout: GO_BACK_TIMEOUT_MS,
})
),
signal
);
} catch (error: unknown) {
// The URL is re-read for every failure, including non-timeout errors.
const currentUrlAfterFailure = await readSessionUrl(session);
// Treat a timeout as success when the page demonstrably moved to a
// different, known, non-blank URL.
if (
isTimeoutError(error) &&
currentUrlAfterFailure !== "unknown" &&
currentUrlAfterFailure !== previousUrl &&
!isBlankPageUrl(currentUrlAfterFailure)
) {
timeoutRecovered = true;
await emitProgress(
onUpdate,
"steel_go_back",
`History navigation completed after timeout; now at ${currentUrlAfterFailure}`
);
} else {
throw error;
}
}
const currentUrl = await readSessionUrl(session);
await emitProgress(onUpdate, "steel_go_back", `Returned to ${currentUrl}`);
return {
content: [{
type: "text",
text: `Navigated back to ${currentUrl}`,
}],
details: {
...sessionDetails(session),
previousUrl,
url: currentUrl,
timeoutRecovered,
},
};
}, signal);
},
};
}
/**
 * Creates the `steel_get_url` tool, which reports the URL of the current page.
 *
 * On a blank page the text is produced by `describeBlankPage` instead of the
 * plain "Current URL" line, and `details.isFreshSession` is true.
 *
 * @param client Steel client used to obtain the browser session.
 * @returns The tool definition registered under the name `steel_get_url`.
 */
export function getUrlTool(client: SteelClient): ToolDefinition<any, any> {
  return {
    name: "steel_get_url",
    label: "Get URL",
    description: "Get current page URL",
    parameters: Type.Object({}),
    async execute(
      _toolCallId: string,
      _params: {},
      signal: AbortSignal | undefined,
      onUpdate: ToolProgressUpdater,
      _ctx: ExtensionContext
    ): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
      return withToolError("steel_get_url", async () => {
        throwIfAborted(signal);
        await emitProgress(onUpdate, "steel_get_url", "Reading current URL");

        const pendingSession = client.getOrCreateSession();
        const session = (await withAbortSignal(pendingSession, signal)) as SessionLike;
        const currentUrl = await readSessionUrl(session);

        // A blank URL means nothing has been loaded in this session yet.
        const onBlankPage = isBlankPageUrl(currentUrl);
        let message: string;
        if (onBlankPage) {
          message = describeBlankPage(currentUrl);
        } else {
          message = `Current URL: ${currentUrl}`;
        }

        return {
          content: [{ type: "text", text: message }],
          details: {
            ...sessionDetails(session),
            url: currentUrl,
            isFreshSession: onBlankPage,
          },
        };
      }, signal);
    },
  };
}
/**
 * Creates the `steel_get_title` tool, which reports the title of the current
 * page. The tool refuses to run while the session is still on a blank page.
 *
 * @param client Steel client used to obtain the browser session.
 * @returns The tool definition registered under the name `steel_get_title`.
 */
export function getTitleTool(client: SteelClient): ToolDefinition<any, any> {
  return {
    name: "steel_get_title",
    label: "Get Title",
    description: "Get current page title",
    parameters: Type.Object({}),
    async execute(
      _toolCallId: string,
      _params: {},
      signal: AbortSignal | undefined,
      onUpdate: ToolProgressUpdater,
      _ctx: ExtensionContext
    ): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
      return withToolError("steel_get_title", async () => {
        throwIfAborted(signal);
        await emitProgress(onUpdate, "steel_get_title", "Reading current page title");

        const pendingSession = client.getOrCreateSession();
        const session = (await withAbortSignal(pendingSession, signal)) as SessionLike;

        // The URL is read first so a blank page can be rejected before the title lookup.
        const currentUrl = await readSessionUrl(session);
        if (isBlankPageUrl(currentUrl)) {
          throw blankPageError("read the page title");
        }

        const pageTitle = await readSessionTitle(session);
        return {
          content: [{ type: "text", text: `Current title: ${pageTitle}` }],
          details: {
            ...sessionDetails(session),
            url: currentUrl,
            title: pageTitle,
          },
        };
      }, signal);
    },
  };
}

View File

@@ -0,0 +1,237 @@
import { promises as fs } from "node:fs";
import path from "node:path";
import { randomUUID } from "node:crypto";
import type { ExtensionContext, ToolDefinition } from "@mariozechner/pi-coding-agent";
import { Type } from "@sinclair/typebox";
import { sessionDetails as baseSessionDetails, type SteelClient } from "../steel-client.js";
import {
emitProgress,
throwIfAborted,
withAbortSignal,
withToolError,
type ToolProgressUpdater,
} from "./tool-runtime.js";
/**
 * Structural view of the browser session as used by the PDF tool. Only `id`
 * is required; `pdf` is looked up on the session first and then on the nested
 * `page`.
 */
type SessionLike = {
id: string;
sessionViewerUrl?: string | null;
pdf?: (options?: {
path?: string;
printBackground?: boolean;
preferCSSPageSize?: boolean;
}) => Promise<unknown>;
page?: {
pdf?: (options?: {
path?: string;
printBackground?: boolean;
preferCSSPageSize?: boolean;
}) => Promise<unknown>;
};
// Either the URL itself or a (possibly async) function returning it.
url?: (() => Promise<string> | string) | string;
};
// Directory for generated PDFs, relative to the process working directory.
const RELATIVE_PDF_DIR = path.join(".artifacts", "pdfs");
// Rendering options merged into every pdf() call made by generatePdf.
const DEFAULT_PDF_OPTIONS = {
printBackground: true,
preferCSSPageSize: true,
};
/**
 * Extends the shared session details with the page URL.
 *
 * @param session Session to describe.
 * @param url URL to report alongside the base details.
 * @returns The base details with `url` added (overriding any base `url`).
 */
function sessionDetails(session: SessionLike, url: string) {
  const base = baseSessionDetails(session);
  return { ...base, url };
}
/**
 * Resolves the absolute directory in which PDF artifacts are stored.
 *
 * @returns `RELATIVE_PDF_DIR` resolved against the current working directory.
 */
function artifactDirectory(): string {
  const workingDir = process.cwd();
  return path.resolve(workingDir, RELATIVE_PDF_DIR);
}
/**
 * Produces a short path for display purposes.
 *
 * @param filePath Absolute path of the artifact.
 * @returns The path relative to the current working directory when the file
 * lives inside it; otherwise just the file name, so locations outside the
 * working directory are not revealed.
 */
function toArtifactDisplayPath(filePath: string): string {
  const relativePath = path.relative(process.cwd(), filePath);
  // "Outside cwd" means the relative path climbs via a real `..` segment, or
  // is still absolute (path.relative returns an absolute path when the two
  // locations are on different Windows drives). A plain prefix test on ".."
  // would also catch ordinary names such as "..cache/report.pdf".
  const escapesCwd =
    relativePath === ".." ||
    relativePath.startsWith(`..${path.sep}`) ||
    path.isAbsolute(relativePath);
  if (!relativePath || escapesCwd) {
    return path.basename(filePath);
  }
  return relativePath;
}
/**
 * Creates the artifact directory if needed and returns a fresh file path in it.
 *
 * The file name combines the current timestamp with the first eight
 * characters of a random UUID.
 *
 * @returns Absolute path of the PDF to be written (the file itself is not created).
 */
async function makeArtifactPath(): Promise<string> {
  const targetDir = artifactDirectory();
  await fs.mkdir(targetDir, { recursive: true });
  const suffix = randomUUID().slice(0, 8);
  const fileName = `steel-pdf-${Date.now()}-${suffix}.pdf`;
  return path.join(targetDir, fileName);
}
/**
 * Checks whether a path can be accessed.
 *
 * @param filePath Path to probe.
 * @returns True when `fs.access` succeeds, false when it rejects for any reason.
 */
async function fileExists(filePath: string): Promise<boolean> {
  let accessible = true;
  try {
    await fs.access(filePath);
  } catch {
    accessible = false;
  }
  return accessible;
}
/**
 * Returns the value when it is binary data, otherwise null.
 *
 * Node's `Buffer` is a subclass of `Uint8Array`, so the single `instanceof`
 * check covers both; a separate `Buffer` test after it could never be reached.
 *
 * @param value Arbitrary value, e.g. the result of a pdf() call.
 * @returns The same object when it is a `Uint8Array` (including `Buffer`),
 * otherwise null.
 */
function isBinaryLike(value: unknown): Buffer | Uint8Array | null {
  return value instanceof Uint8Array ? value : null;
}
/**
 * Writes `payload` to `filePath` when it is binary data.
 * Non-binary payloads are ignored without error.
 */
async function writeBinaryArtifact(filePath: string, payload: unknown): Promise<void> {
  const bytes = isBinaryLike(payload);
  if (bytes !== null) {
    await fs.writeFile(filePath, Buffer.from(bytes));
  }
}
/**
 * Best-effort lookup of the session's current URL.
 *
 * Tries, in order: a string `url` property, a callable `url` member, and a
 * `getCurrentUrl` method. Blank or non-string results are skipped.
 *
 * @returns The first non-blank URL found, or the literal "unknown".
 */
async function readSessionUrl(session: SessionLike): Promise<string> {
  const isUsable = (candidate: unknown): candidate is string =>
    typeof candidate === "string" && candidate.trim().length > 0;

  const urlMember = session.url;
  if (isUsable(urlMember)) {
    return urlMember;
  }
  if (typeof urlMember === "function") {
    const resolved = await urlMember.call(session);
    if (isUsable(resolved)) {
      return resolved;
    }
  }

  const fallbackGetter = (session as { getCurrentUrl?: () => Promise<string> | string }).getCurrentUrl;
  if (typeof fallbackGetter === "function") {
    const resolved = await fallbackGetter.call(session);
    if (isUsable(resolved)) {
      return resolved;
    }
  }
  return "unknown";
}
/**
 * Invokes the session's PDF export, asking it to write to `filePath`.
 *
 * The export is looked up on the session itself first and on `session.page`
 * second, and is always called through its owning object so the method keeps
 * its receiver.
 *
 * @param session Session to export from.
 * @param filePath Destination path passed to the export as `path`.
 * @param overrides Optional rendering flags; unspecified flags fall back to
 *   DEFAULT_PDF_OPTIONS.
 * @returns Whatever the underlying export resolves with.
 * @throws Error when neither the session nor its page exposes a `pdf` function.
 */
async function generatePdf(
  session: SessionLike,
  filePath: string,
  overrides: { printBackground?: boolean; preferCSSPageSize?: boolean } = {}
): Promise<unknown> {
  const pdfCall = session.pdf ?? session.page?.pdf;
  if (typeof pdfCall !== "function") {
    throw new Error("Session does not support PDF generation.");
  }
  const options = {
    printBackground: overrides.printBackground ?? DEFAULT_PDF_OPTIONS.printBackground,
    preferCSSPageSize: overrides.preferCSSPageSize ?? DEFAULT_PDF_OPTIONS.preferCSSPageSize,
    path: filePath,
  };
  if (pdfCall === session.pdf) {
    return session.pdf?.(options);
  }
  return session.page?.pdf?.(options);
}

/**
 * Builds the `steel_pdf` tool, which renders the current page to a PDF file
 * under the artifact directory and reports where it was saved.
 *
 * @param client Steel client used to obtain (or create) the browser session.
 */
export function pdfTool(client: SteelClient): ToolDefinition<any, any> {
  return {
    name: "steel_pdf",
    label: "PDF",
    description: "Capture the current page as a PDF artifact",
    parameters: Type.Object({
      printBackground: Type.Optional(
        Type.Boolean({
          description: "Whether to include page background graphics in the PDF",
        })
      ),
      preferCSSPageSize: Type.Optional(
        Type.Boolean({
          description: "Whether to use page-defined CSS size when available",
        })
      ),
    }),
    async execute(
      _toolCallId: string,
      params: {
        printBackground?: boolean;
        preferCSSPageSize?: boolean;
      },
      signal: AbortSignal | undefined,
      onUpdate: ToolProgressUpdater,
      _ctx: ExtensionContext
    ): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
      return withToolError("steel_pdf", async () => {
        throwIfAborted(signal);
        await emitProgress(onUpdate, "steel_pdf", "Preparing PDF artifact path");
        const session = (await withAbortSignal(
          client.getOrCreateSession(),
          signal
        )) as SessionLike;
        throwIfAborted(signal);
        const url = await readSessionUrl(session);
        const targetPath = await makeArtifactPath();

        // Resolved rendering flags; also echoed back in the result details.
        const options = {
          printBackground: params.printBackground ?? DEFAULT_PDF_OPTIONS.printBackground,
          preferCSSPageSize: params.preferCSSPageSize ?? DEFAULT_PDF_OPTIONS.preferCSSPageSize,
        };

        await emitProgress(onUpdate, "steel_pdf", "Generating PDF now");
        const pdfResult = await generatePdf(session, targetPath, options);

        await emitProgress(onUpdate, "steel_pdf", `Writing PDF to ${targetPath}`);
        // If the export returned the bytes, persist them to the target path.
        await writeBinaryArtifact(targetPath, pdfResult);
        if (!(await fileExists(targetPath))) {
          throw new Error("PDF artifact was not written to disk.");
        }

        const stats = await fs.stat(targetPath);
        const fileName = path.basename(targetPath);
        const displayPath = toArtifactDisplayPath(targetPath);
        return {
          content: [{
            type: "text",
            text: `PDF saved: ${displayPath}`,
          }],
          details: {
            ...sessionDetails(session, url),
            filePath: displayPath,
            absoluteFilePath: targetPath,
            artifact: {
              type: "pdf",
              mimeType: "application/pdf",
              path: displayPath,
              fileName,
              sizeBytes: stats.size,
              createdAt: new Date().toISOString(),
            },
            options,
          },
        };
      }, signal);
    },
  };
}

View File

@@ -0,0 +1,408 @@
import type { ExtensionContext, ToolDefinition } from "@mariozechner/pi-coding-agent";
import { Type } from "@sinclair/typebox";
import { sessionDetails as baseSessionDetails, type SteelClient } from "../steel-client.js";
import {
emitProgress,
throwIfAborted,
withAbortSignal,
withToolError,
type ToolProgressUpdater,
} from "./tool-runtime.js";
import {
blankPageError,
isBlankPageUrl,
readSessionUrl,
} from "./session-state.js";
// Output formats the scrape tool can produce.
type ScrapeFormat = "html" | "markdown" | "text";
/**
 * Structural view of the session object as consumed by the scrape tool.
 * `content` and `evaluate` may live on the session itself or on `session.page`;
 * both locations are probed at runtime.
 */
type SessionLike = {
id: string;
sessionViewerUrl?: string | null;
content?: () => Promise<unknown>;
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
page?: {
content?: () => Promise<unknown>;
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
};
url?: (() => Promise<string> | string) | string;
getCurrentUrl?: () => Promise<string> | string;
};
// Whitelist used to validate a requested format.
const ALLOWED_FORMATS: readonly ScrapeFormat[] = ["html", "markdown", "text"];
// Format used when none (or an unrecognized one) is requested.
const DEFAULT_FORMAT: ScrapeFormat = "text";
// Output length limit when neither the parameter nor the env override is set.
const DEFAULT_MAX_CHARS = 12_000;
// Inclusive bounds for the maxChars setting.
const MIN_MAX_CHARS = 1;
const MAX_MAX_CHARS = 200_000;
/** Returns `rawFormat` when it is one of the allowed formats, else the default format. */
function resolveFormat(rawFormat?: string): ScrapeFormat {
  const recognized = ALLOWED_FORMATS.find((candidate) => candidate === rawFormat);
  return recognized ?? DEFAULT_FORMAT;
}
/**
 * Reads the STEEL_SCRAPE_MAX_CHARS environment override.
 *
 * @returns The value truncated to an integer and capped at MAX_MAX_CHARS, or
 *   null when the variable is unset, not a finite number, or below
 *   MIN_MAX_CHARS after truncation.
 */
function readMaxCharsFromEnv(): number | null {
  const raw = process.env.STEEL_SCRAPE_MAX_CHARS;
  if (!raw) {
    return null;
  }
  const parsed = Number(raw);
  if (!Number.isFinite(parsed)) {
    return null;
  }
  // Validate after truncation: a fractional value such as "0.5" is positive
  // but truncates to 0, which would otherwise yield a zero-character limit.
  const whole = Math.trunc(parsed);
  if (whole < MIN_MAX_CHARS) {
    return null;
  }
  return Math.min(MAX_MAX_CHARS, whole);
}
/**
 * Determines the effective output length limit.
 *
 * An explicit value must be finite and at least MIN_MAX_CHARS; it is truncated
 * to an integer and capped at MAX_MAX_CHARS. Without one, the environment
 * override is used, then the built-in default.
 *
 * @throws Error when an explicit value is not finite or is below the minimum.
 */
function resolveMaxChars(rawMaxChars?: number): number {
  if (rawMaxChars !== undefined) {
    const numeric = Number(rawMaxChars);
    const acceptable = Number.isFinite(numeric) && numeric >= MIN_MAX_CHARS;
    if (!acceptable) {
      throw new Error(`maxChars must be an integer >= ${MIN_MAX_CHARS}.`);
    }
    return Math.min(MAX_MAX_CHARS, Math.trunc(numeric));
  }
  return readMaxCharsFromEnv() ?? DEFAULT_MAX_CHARS;
}
/**
 * Trims an optional CSS selector.
 *
 * @returns undefined when no selector was given, otherwise the trimmed selector.
 * @throws Error when a selector was given but is blank after trimming.
 */
function normalizeSelector(selector?: string): string | undefined {
  if (selector !== undefined) {
    const stripped = selector.trim();
    if (stripped.length === 0) {
      throw new Error("selector cannot be empty.");
    }
    return stripped;
  }
  return undefined;
}
/**
 * Combines the shared session metadata with scrape-specific fields.
 * A missing selector is reported as null.
 */
function sessionDetails(session: SessionLike, url: string, format: ScrapeFormat, selector: string | undefined) {
  const shared = baseSessionDetails(session);
  const reportedSelector = selector ?? null;
  return { ...shared, url, format, selector: reportedSelector };
}
/**
 * Regex-based HTML-to-text conversion used when DOM extraction is unavailable.
 * Drops script and style elements, turns every remaining tag into a line
 * break, normalizes non-breaking spaces, and collapses excess blank lines.
 */
function extractFallbackText(rawHtml: string): string {
  const withoutScripts = rawHtml.replace(/<script\b[^<]*(?:(?!<\/script>)<[^<]*)*<\/script>/gi, "");
  const withoutStyles = withoutScripts.replace(/<style\b[^<]*(?:(?!<\/style>)<[^<]*)*<\/style>/gi, "");
  const tagsAsBreaks = withoutStyles.replace(/<[^>]*>/g, "\n");
  const plainSpaces = tagsAsBreaks.replace(/\u00a0/g, " ");
  const trimmedLineEnds = plainSpaces.replace(/\s+\n/g, "\n");
  const collapsed = trimmedLineEnds.replace(/\n{3,}/g, "\n\n");
  return collapsed.trim();
}
/** Normalizes non-breaking spaces, collapses runs of 3+ newlines to two, and trims. */
function cleanInnerText(raw: string): string {
  const plainSpaces = raw.replace(/\u00a0/g, " ");
  const collapsed = plainSpaces.replace(/\r?\n{3,}/g, "\n\n");
  return collapsed.trim();
}
/**
 * Limits `raw` to at most `maxChars` characters.
 *
 * When truncation is needed, the kept head is followed by a
 * "[truncated N chars]" marker, where N is the number of characters of `raw`
 * actually dropped (the marker's own length is taken out of the head, so N is
 * computed from the final head length). When `maxChars` is too small to hold
 * the marker plus at least one character, the text is hard-cut at `maxChars`
 * without a marker so the limit is never exceeded.
 *
 * @param raw Text to limit.
 * @param maxChars Maximum length of the returned text.
 * @returns The limited text, whether truncation occurred, and the original length.
 */
function truncateContent(raw: string, maxChars: number): {
  text: string;
  truncated: boolean;
  originalLength: number;
} {
  const originalLength = raw.length;
  if (originalLength <= maxChars) {
    return {
      text: raw,
      truncated: false,
      originalLength,
    };
  }

  const buildMarker = (omitted: number): string => `\n\n[truncated ${omitted} chars]`;

  // The marker's width depends on the omitted count, which depends on how much
  // head is kept, which depends on the marker's width. Iterate until the width
  // is stable; it can only grow by a digit a handful of times.
  let marker = buildMarker(originalLength - maxChars);
  let headLength = Math.max(0, maxChars - marker.length);
  for (;;) {
    const refined = buildMarker(originalLength - headLength);
    const sameWidth = refined.length === marker.length;
    marker = refined;
    if (sameWidth) {
      break;
    }
    headLength = Math.max(0, maxChars - marker.length);
  }

  if (headLength === 0) {
    // No room for marker plus content: prefer real content within the limit.
    return {
      text: raw.slice(0, Math.max(0, maxChars)),
      truncated: true,
      originalLength,
    };
  }

  return {
    text: `${raw.slice(0, headLength)}${marker}`,
    truncated: true,
    originalLength,
  };
}
/**
 * Extracts page content by running a function inside the browser.
 *
 * The in-page function resolves the root element (the whole document, or the
 * first match of `selector`) and serializes it as HTML, inner text, or a
 * lightweight Markdown rendering.
 *
 * `evaluate` is invoked through the object that owns it (the session or its
 * page) so the method keeps its receiver.
 *
 * @param session Session whose `evaluate` (or `page.evaluate`) is used.
 * @param format Desired output format.
 * @param selector Optional CSS selector scoping the extraction.
 * @returns The extracted content.
 * @throws Error when the session cannot evaluate scripts, when no root element
 *   is found, or when the browser returns a non-string payload.
 */
async function extractWithBrowserEvaluate(
  session: SessionLike,
  format: ScrapeFormat,
  selector: string | undefined
): Promise<string> {
  const owner = typeof session.evaluate === "function" ? session : session.page;
  if (!owner || typeof owner.evaluate !== "function") {
    throw new Error("Session does not support DOM extraction.");
  }

  // Everything inside this callback executes in the page, not in Node; it must
  // be self-contained and may only use the serializable `input` argument.
  const payload: unknown = await owner.evaluate(
    (input: { selector: string | null; format: ScrapeFormat }): string | null => {
      const root = input.selector
        ? document.querySelector(input.selector)
        : document.documentElement;
      if (!root) {
        // Signals "nothing matched" to the Node side.
        return null;
      }

      const clean = (value: string): string =>
        value
          .replace(/\n{3,}/g, "\n\n")
          .replace(/\s+\n/g, "\n")
          .trim();

      const baseText = (): string => {
        const text = (root as HTMLElement).innerText || root.textContent || "";
        return text.replace(/\u00a0/g, " ").replace(/\r?\n{3,}/g, "\n\n").trim();
      };

      const markdownFromNode = (node: Node, depth = 0): string => {
        if (node.nodeType === Node.TEXT_NODE) {
          return (node.textContent || "").replace(/\u00a0/g, " ");
        }
        if (node.nodeType !== Node.ELEMENT_NODE) {
          return "";
        }
        const element = node as Element;
        const tag = element.tagName.toLowerCase();
        const pad = " ".repeat(depth);
        const childText = Array.from(element.childNodes)
          .map((child) => markdownFromNode(child, depth + 1))
          .join("");
        switch (tag) {
          case "h1":
            return `\n# ${clean(childText)}\n\n`;
          case "h2":
            return `\n## ${clean(childText)}\n\n`;
          case "h3":
            return `\n### ${clean(childText)}\n\n`;
          case "h4":
            return `\n#### ${clean(childText)}\n\n`;
          case "h5":
            return `\n##### ${clean(childText)}\n\n`;
          case "h6":
            return `\n###### ${clean(childText)}\n\n`;
          case "p":
          case "article":
          case "section":
            return `${clean(childText)}\n\n`;
          case "blockquote":
            // Every quoted line, including the first, needs the "> " prefix.
            return `\n> ${clean(childText).replace(/\n/g, "\n> ")}\n\n`;
          case "pre":
            return `\n\`\`\`\n${(element.textContent || "").replace(/\n+$/, "")}\n\`\`\`\n\n`;
          case "code":
            return `\`${clean(childText)}\``;
          case "strong":
          case "b":
            return `**${clean(childText)}**`;
          case "em":
          case "i":
            return `*${clean(childText)}*`;
          case "a": {
            const href = (element as HTMLAnchorElement).getAttribute("href") || "";
            return `[${clean(childText)}](${href})`;
          }
          case "img": {
            const src = (element as HTMLImageElement).getAttribute("src") || "";
            const alt = (element as HTMLImageElement).getAttribute("alt") || "";
            return `![${alt}](${src})`;
          }
          case "ul":
            return (
              Array.from(element.children)
                .filter((item) => item.tagName.toLowerCase() === "li")
                .map((item) => `${pad}- ${clean(markdownFromNode(item).trim())}`)
                .join("\n") + "\n\n"
            );
          case "ol":
            return (
              Array.from(element.children)
                .filter((item) => item.tagName.toLowerCase() === "li")
                .map((item, index) => `${pad}${index + 1}. ${clean(markdownFromNode(item).trim())}`)
                .join("\n") + "\n\n"
            );
          case "li":
            return childText.trim();
          case "div":
          case "main":
          case "header":
          case "footer":
          case "nav":
          case "aside":
            return `${clean(childText)}\n`;
          case "br":
            return "\n";
          default:
            return childText;
        }
      };

      if (input.format === "html") {
        return (root as HTMLElement).outerHTML;
      }
      if (input.format === "text") {
        return baseText();
      }
      if (input.format === "markdown") {
        return clean(markdownFromNode(root).trim());
      }
      return clean(root.textContent || "");
    },
    { selector: selector ?? null, format }
  );

  if (payload === null) {
    throw new Error(selector
      ? `No element matched selector: ${selector}`
      : "Could not extract page HTML from the browser.");
  }
  if (typeof payload !== "string") {
    throw new Error("Scrape operation returned an unexpected payload.");
  }
  return payload;
}
/**
 * Obtains page content in the requested format, choosing the best available
 * mechanism on the session.
 *
 * 1. Whole-page HTML is read via `content()` when the session offers it.
 * 2. Otherwise the content is extracted inside the browser.
 * 3. For whole-page text only, a failed browser extraction falls back to
 *    stripping tags from `content()`.
 *
 * The fallback in step 3 cannot honor a selector, so when a selector was
 * requested the extraction error is propagated instead of silently returning
 * text for the entire page.
 *
 * @throws The browser-extraction error for html/markdown or selector-scoped
 *   requests; a generic Error when no mechanism is available.
 */
async function scrapeContent(
  session: SessionLike,
  format: ScrapeFormat,
  selector: string | undefined
): Promise<string> {
  const wantsWholePageHtml = !selector && format === "html";
  if (wantsWholePageHtml && typeof session.content === "function") {
    const pageHtml = await session.content();
    if (typeof pageHtml === "string") {
      return pageHtml;
    }
  }
  if (wantsWholePageHtml && typeof session.page?.content === "function") {
    const pageHtml = await session.page.content();
    if (typeof pageHtml === "string") {
      return pageHtml;
    }
  }

  try {
    return await extractWithBrowserEvaluate(session, format, selector);
  } catch (error: unknown) {
    // Only whole-page text has a meaningful degraded path.
    if (format !== "text" || selector) {
      throw error;
    }
  }

  let maybeHtml: unknown;
  if (typeof session.content === "function") {
    maybeHtml = await session.content();
  } else if (typeof session.page?.content === "function") {
    maybeHtml = await session.page.content();
  }
  if (typeof maybeHtml === "string") {
    return extractFallbackText(maybeHtml);
  }
  throw new Error("Session does not support scrape content extraction.");
}
/**
 * Builds the `steel_scrape` tool, which returns the current page's content as
 * text, Markdown, or HTML, optionally scoped to a selector and limited in length.
 *
 * @param client Steel client used to obtain (or create) the browser session.
 */
export function scrapeTool(client: SteelClient): ToolDefinition<any, any> {
  const toolName = "steel_scrape";

  const formatSchema = Type.Union(
    [Type.Literal("html"), Type.Literal("markdown"), Type.Literal("text")],
    { description: "Output format. Prefer text for concise reading, markdown to preserve headings/lists/links, and html only when raw DOM markup is specifically needed." }
  );
  const selectorSchema = Type.String({ description: "Optional CSS selector to scope extraction to a specific element before converting to the requested output format" });
  const maxCharsSchema = Type.Integer({
    minimum: MIN_MAX_CHARS,
    maximum: MAX_MAX_CHARS,
    description: `Maximum characters to return after conversion to text/markdown/html (default: ${DEFAULT_MAX_CHARS}, env override: STEEL_SCRAPE_MAX_CHARS)`,
  });

  return {
    name: toolName,
    label: "Scrape",
    description: "Extract readable current page content. Use text by default for answering questions, markdown when structure matters, and html only for DOM/debugging cases.",
    parameters: Type.Object({
      format: Type.Optional(formatSchema),
      selector: Type.Optional(selectorSchema),
      maxChars: Type.Optional(maxCharsSchema),
    }),
    async execute(
      _toolCallId: string,
      params: { format?: ScrapeFormat; selector?: string; maxChars?: number },
      signal: AbortSignal | undefined,
      onUpdate: ToolProgressUpdater,
      _ctx: ExtensionContext
    ): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
      return withToolError(toolName, async () => {
        throwIfAborted(signal);

        // Validate inputs before touching the browser.
        const format = resolveFormat(params.format);
        const selector = normalizeSelector(params.selector);
        const maxChars = resolveMaxChars(params.maxChars);

        const scopeLabel = selector ? ` (selector ${selector})` : " (full page)";
        await emitProgress(onUpdate, toolName, `Preparing ${format} scrape for${scopeLabel}`);

        const pendingSession = client.getOrCreateSession();
        const session = (await withAbortSignal(pendingSession, signal)) as SessionLike;
        throwIfAborted(signal);

        const url = await readSessionUrl(session);
        if (isBlankPageUrl(url)) {
          throw blankPageError("scrape page content");
        }

        await emitProgress(onUpdate, toolName, "Running extraction");
        const extracted = await withAbortSignal(
          scrapeContent(session, format, selector),
          signal
        );

        // Text output gets one more whitespace normalization pass.
        const normalized = format === "text" ? cleanInnerText(extracted) : extracted;
        const limited = truncateContent(normalized, maxChars);
        if (limited.truncated) {
          await emitProgress(onUpdate, toolName, `Scrape output truncated to ${maxChars} chars`);
        }
        await emitProgress(onUpdate, toolName, "Scrape complete");

        return {
          content: [{ type: "text", text: limited.text }],
          details: {
            ...sessionDetails(session, url, format, selector),
            maxChars,
            contentLength: limited.text.length,
            originalContentLength: limited.originalLength,
            truncated: limited.truncated,
          },
        };
      }, signal);
    },
  };
}

View File

@@ -0,0 +1,373 @@
import { promises as fs } from "node:fs";
import path from "node:path";
import { randomUUID } from "node:crypto";
import type { ExtensionContext, ToolDefinition } from "@mariozechner/pi-coding-agent";
import { Type } from "@sinclair/typebox";
import { sessionDetails as baseSessionDetails, type SteelClient } from "../steel-client.js";
import {
emitProgress,
throwIfAborted,
withAbortSignal,
withToolError,
type ToolProgressUpdater,
} from "./tool-runtime.js";
import {
MAX_TOOL_TIMEOUT_MS,
resolveToolTimeoutMs,
} from "./tool-settings.js";
/**
 * Structural view of the session object as consumed by the screenshot tool.
 * Each capability may live on the session itself or on `session.page`; both
 * locations are probed at runtime, so everything is optional.
 */
type SessionLike = {
id: string;
sessionViewerUrl?: string | null;
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
waitForSelector?: (
selector: string,
options?: { state?: "attached" | "visible"; timeout?: number }
) => Promise<unknown>;
screenshot?: (options?: Record<string, unknown>) => Promise<unknown>;
locator?: (selector: string) => {
screenshot?: (options?: Record<string, unknown>) => Promise<unknown>;
};
page?: {
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
waitForSelector?: (
selector: string,
options?: { state?: "attached" | "visible"; timeout?: number }
) => Promise<unknown>;
screenshot?: (options?: Record<string, unknown>) => Promise<unknown>;
locator?: (selector: string) => {
screenshot?: (options?: Record<string, unknown>) => Promise<unknown>;
};
};
// Current URL: either a plain string or a (possibly async) getter.
url?: (() => Promise<string> | string) | string;
};
// Rectangle passed as the `clip` option when capturing a single element.
type ClipRect = {
x: number;
y: number;
width: number;
height: number;
};
// NOTE(review): not referenced anywhere in the visible part of this file.
const DEFAULT_FULL_PAGE = false;
// Directory for captured screenshots, relative to the process working directory.
const RELATIVE_SCREENSHOT_DIR = path.join(".artifacts", "screenshots");
/**
 * Combines the shared session metadata with screenshot-specific fields.
 * A missing selector is reported as null.
 */
function sessionDetails(session: SessionLike, url: string, selector: string | undefined, fullPage: boolean) {
  const shared = baseSessionDetails(session);
  const reportedSelector = selector ?? null;
  return { ...shared, url, selector: reportedSelector, fullPage };
}
/**
 * Trims an optional CSS selector.
 *
 * @returns undefined when no selector was given, otherwise the trimmed selector.
 * @throws Error when a selector was given but is blank after trimming.
 */
function normalizeSelector(selector?: string): string | undefined {
  if (selector !== undefined) {
    const stripped = selector.trim();
    if (stripped.length === 0) {
      throw new Error("selector cannot be empty.");
    }
    return stripped;
  }
  return undefined;
}
/** Delegates timeout resolution to the shared tool setting helper. */
function resolveTimeoutMs(rawTimeout?: number): number {
  const effectiveTimeout = resolveToolTimeoutMs(rawTimeout);
  return effectiveTimeout;
}
/** Treats anything other than an explicit `true` as "viewport only". */
function normalizeFullPage(fullPage?: boolean): boolean {
  if (fullPage === true) {
    return true;
  }
  return false;
}
/**
 * Best-effort lookup of the session's current URL.
 *
 * Tries, in order: a string `url` property, a callable `url` member, and a
 * `getCurrentUrl` method. Blank or non-string results are skipped.
 *
 * @returns The first non-blank URL found, or the literal "unknown".
 */
async function readSessionUrl(session: SessionLike): Promise<string> {
  const hasText = (candidate: unknown): candidate is string =>
    typeof candidate === "string" && candidate.trim().length > 0;

  const urlMember = session.url;
  if (hasText(urlMember)) {
    return urlMember;
  }
  if (typeof urlMember === "function") {
    const fromCall = await urlMember.call(session);
    if (hasText(fromCall)) {
      return fromCall;
    }
  }

  const currentUrlGetter = (session as { getCurrentUrl?: () => Promise<string> | string }).getCurrentUrl;
  if (typeof currentUrlGetter === "function") {
    const fromGetter = await currentUrlGetter.call(session);
    if (hasText(fromGetter)) {
      return fromGetter;
    }
  }
  return "unknown";
}
/** Resolves to true when `filePath` is accessible, false when `fs.access` rejects. */
async function fileExists(filePath: string): Promise<boolean> {
  return fs.access(filePath).then(
    () => true,
    () => false
  );
}
/** Absolute path of the screenshot artifact directory under the current working directory. */
function artifactDirectory(): string {
  const workingDir = process.cwd();
  return path.resolve(workingDir, RELATIVE_SCREENSHOT_DIR);
}
/**
 * Produces a short, display-friendly path for an artifact.
 *
 * @param filePath Absolute path of the artifact on disk.
 * @returns The path relative to the current working directory when the file
 *   lives inside it; otherwise just the file name.
 */
function toArtifactDisplayPath(filePath: string): string {
  const relativePath = path.relative(process.cwd(), filePath);
  // Only a leading ".." *segment* means the file is outside cwd; a name that
  // merely starts with two dots (e.g. "..cache/a.png") is still inside it.
  // path.relative can also return an absolute path (different Windows drive).
  const escapesCwd =
    relativePath === ".." ||
    relativePath.startsWith(`..${path.sep}`) ||
    path.isAbsolute(relativePath);
  if (!relativePath || escapesCwd) {
    return path.basename(filePath);
  }
  return relativePath;
}
/**
 * Ensures the screenshot artifact directory exists and returns a fresh file
 * path in it. The name combines the current timestamp with eight characters
 * of a random UUID.
 */
async function makeArtifactPath(): Promise<string> {
  const outputDir = artifactDirectory();
  await fs.mkdir(outputDir, { recursive: true });
  const uniqueSuffix = randomUUID().slice(0, 8);
  const fileName = `steel-screenshot-${Date.now()}-${uniqueSuffix}.png`;
  return path.join(outputDir, fileName);
}
/**
 * Returns a function that waits for `selector` to become visible, using the
 * session's own `waitForSelector` when present, else the page's. When neither
 * exists the returned function resolves immediately.
 */
async function getWaitForSelector(session: SessionLike): Promise<
  (selector: string, timeoutMs: number) => Promise<void>
> {
  const waitOnSession = async (selector: string, timeoutMs: number): Promise<void> => {
    await session.waitForSelector?.(selector, { state: "visible", timeout: timeoutMs });
  };
  const waitOnPage = async (selector: string, timeoutMs: number): Promise<void> => {
    await session.page?.waitForSelector?.(selector, { state: "visible", timeout: timeoutMs });
  };
  const skipWaiting = async (): Promise<void> => {};

  if (typeof session.waitForSelector === "function") {
    return waitOnSession;
  }
  if (typeof session.page?.waitForSelector === "function") {
    return waitOnPage;
  }
  return skipWaiting;
}
/**
 * Returns a screenshot function bound to the session (preferred) or its page,
 * or undefined when neither exposes `screenshot`.
 */
function getSessionScreenshot(
  session: SessionLike
): ((options: Record<string, unknown>) => Promise<unknown>) | undefined {
  const viaSession = async (options: Record<string, unknown>): Promise<unknown> =>
    session.screenshot?.(options);
  const viaPage = async (options: Record<string, unknown>): Promise<unknown> =>
    session.page?.screenshot?.(options);

  if (typeof session.screenshot === "function") {
    return viaSession;
  }
  return typeof session.page?.screenshot === "function" ? viaPage : undefined;
}
/**
 * Creates a locator for `selector` from the session (preferred) or its page.
 * Returns undefined when neither exposes `locator`.
 */
function getSessionLocator(
  session: SessionLike,
  selector: string
): { screenshot?: (options: Record<string, unknown>) => Promise<unknown> } | undefined {
  return typeof session.locator === "function"
    ? session.locator(selector)
    : typeof session.page?.locator === "function"
      ? session.page.locator(selector)
      : undefined;
}
/**
 * Captures a screenshot of the element matching `selector`.
 *
 * Waits for the element to be visible (when the session supports waiting),
 * then prefers a locator-level screenshot. Without locator support it measures
 * the element inside the browser and takes a clipped screenshot instead.
 *
 * `evaluate` is invoked through the object that owns it (the session or its
 * page) so the method keeps its receiver.
 *
 * @param session Session to capture from.
 * @param selector CSS selector of the element to capture.
 * @param targetPath File path passed to the screenshot call as `path`.
 * @param timeoutMs Timeout for the visibility wait.
 * @returns The screenshot call's result, or undefined when the session offers
 *   no mechanism for element capture.
 * @throws Error when the element is not found or has no width/height.
 */
async function captureWithSelector(
  session: SessionLike,
  selector: string,
  targetPath: string,
  timeoutMs: number
): Promise<unknown> {
  const waitForSelector = await getWaitForSelector(session);
  await waitForSelector(selector, timeoutMs);

  const locator = getSessionLocator(session, selector);
  if (locator?.screenshot) {
    return locator.screenshot({ path: targetPath });
  }

  const evaluateOwner = typeof session.evaluate === "function" ? session : session.page;
  if (!evaluateOwner || typeof evaluateOwner.evaluate !== "function") {
    return undefined;
  }

  // Runs inside the page: measure the element's bounding box.
  const clip = await evaluateOwner.evaluate((rawSelector: string): ClipRect | null => {
    const element = document.querySelector(rawSelector) as HTMLElement | null;
    if (!element) {
      return null;
    }
    const bounds = element.getBoundingClientRect();
    if (!bounds.width || !bounds.height) {
      return null;
    }
    return {
      x: Math.max(0, Math.floor(bounds.left)),
      y: Math.max(0, Math.floor(bounds.top)),
      width: Math.max(1, Math.ceil(bounds.width)),
      height: Math.max(1, Math.ceil(bounds.height)),
    };
  }, selector);
  if (!clip) {
    throw new Error(`No element matched selector: ${selector}`);
  }

  const screenshot = getSessionScreenshot(session);
  if (!screenshot) {
    return undefined;
  }
  return screenshot({
    path: targetPath,
    clip,
  });
}
/**
 * Captures the page (viewport or full page, per `fullPage`) to `targetPath`.
 *
 * @throws Error when neither the session nor its page can take screenshots.
 */
async function captureFullPage(
  session: SessionLike,
  targetPath: string,
  fullPage: boolean
): Promise<unknown> {
  const takeScreenshot = getSessionScreenshot(session);
  if (takeScreenshot === undefined) {
    throw new Error("Session does not support screenshot capture.");
  }
  const captureOptions = { path: targetPath, fullPage };
  return takeScreenshot(captureOptions);
}
/**
 * Returns `value` when it is raw binary data, otherwise null.
 *
 * Node's Buffer is a subclass of Uint8Array, so a single `instanceof` check
 * covers both; the former separate Buffer branch could never be reached.
 */
function isBinaryLike(value: unknown): Buffer | Uint8Array | null {
  return value instanceof Uint8Array ? value : null;
}
/**
 * Writes `value` to `targetPath` when it is binary data.
 * Non-binary values are ignored without error.
 */
async function persistScreenshotBuffer(
  targetPath: string,
  value: unknown
): Promise<void> {
  const bytes = isBinaryLike(value);
  if (bytes !== null) {
    await fs.writeFile(targetPath, Buffer.from(bytes));
  }
}
/**
 * Guarantees the screenshot exists at `targetPath`.
 *
 * If the capture call already produced the file nothing is done; otherwise the
 * returned bytes (if any) are written.
 *
 * @throws Error when the file still does not exist afterwards.
 */
async function writeArtifact(targetPath: string, sessionResult: unknown): Promise<void> {
  const alreadyOnDisk = await fileExists(targetPath);
  if (alreadyOnDisk) {
    return;
  }
  await persistScreenshotBuffer(targetPath, sessionResult);
  const writtenNow = await fileExists(targetPath);
  if (!writtenNow) {
    throw new Error(`Screenshot not written to expected path: ${targetPath}`);
  }
}
/**
 * Builds the `steel_screenshot` tool, which saves a PNG of the viewport, the
 * full page, or a single element under the artifact directory.
 *
 * @param client Steel client used to obtain (or create) the browser session.
 */
export function screenshotTool(client: SteelClient): ToolDefinition<any, any> {
  const toolName = "steel_screenshot";

  const fullPageSchema = Type.Boolean({ description: "Capture full page screenshot instead of viewport-only" });
  const selectorSchema = Type.String({
    description: "Optional CSS selector to capture a single element instead of full page",
  });
  const timeoutSchema = Type.Integer({
    minimum: 100,
    maximum: MAX_TOOL_TIMEOUT_MS,
    description: "Timeout for waiting on selector when selector mode is used",
  });

  return {
    name: toolName,
    label: "Screenshot",
    description: "Capture a screenshot of the current page",
    parameters: Type.Object({
      fullPage: Type.Optional(fullPageSchema),
      selector: Type.Optional(selectorSchema),
      timeout: Type.Optional(timeoutSchema),
    }),
    async execute(
      _toolCallId: string,
      params: { fullPage?: boolean; selector?: string; timeout?: number },
      signal: AbortSignal | undefined,
      onUpdate: ToolProgressUpdater,
      _ctx: ExtensionContext
    ): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
      return withToolError(toolName, async () => {
        throwIfAborted(signal);

        // Validate inputs before touching the browser.
        const selector = normalizeSelector(params.selector);
        const fullPage = normalizeFullPage(params.fullPage);
        const timeoutMs = resolveTimeoutMs(params.timeout);

        const captureLabel = selector ? ` element ${selector}` : " visible page";
        await emitProgress(onUpdate, toolName, `Preparing capture for${captureLabel}`);

        const pendingSession = client.getOrCreateSession();
        const session = (await withAbortSignal(pendingSession, signal)) as SessionLike;
        throwIfAborted(signal);

        const url = await readSessionUrl(session);
        const targetPath = await makeArtifactPath();

        let captured: unknown;
        if (!selector) {
          const modeMessage = fullPage ? "Capturing full-page screenshot" : "Capturing viewport screenshot";
          await emitProgress(onUpdate, toolName, modeMessage);
          captured = await captureFullPage(session, targetPath, fullPage);
        } else {
          await emitProgress(onUpdate, toolName, `Capturing element ${selector}`);
          captured = await captureWithSelector(session, selector, targetPath, timeoutMs);
          // A falsy result with no file on disk means no capture path existed.
          if (!captured && !(await fileExists(targetPath))) {
            throw new Error("Session does not support selector-based screenshot capture.");
          }
        }

        await emitProgress(onUpdate, toolName, `Persisting image to ${targetPath}`);
        await writeArtifact(targetPath, captured);

        const displayPath = toArtifactDisplayPath(targetPath);
        let summary: string;
        if (selector) {
          summary = `Captured screenshot of ${selector}`;
        } else if (fullPage) {
          summary = "Captured full-page screenshot";
        } else {
          summary = "Captured viewport screenshot";
        }

        return {
          content: [{ type: "text", text: summary }],
          details: {
            ...sessionDetails(session, url, selector, fullPage),
            filePath: displayPath,
            timeoutMs,
          },
        };
      }, signal);
    },
  };
}

View File

@@ -0,0 +1,311 @@
import type { ExtensionContext, ToolDefinition } from "@mariozechner/pi-coding-agent";
import { Type } from "@sinclair/typebox";
import { sessionDetails, type SteelClient } from "../steel-client.js";
import {
emitProgress,
throwIfAborted,
withAbortSignal,
withToolError,
type ToolProgressUpdater,
} from "./tool-runtime.js";
// Directions the scroll tool supports.
type ScrollDirection = "up" | "down";
/**
 * Structural view of the session object as consumed by the scroll tool.
 * `evaluate` may live on the session itself or on `session.page`.
 */
type SessionLike = {
id: string;
sessionViewerUrl?: string | null;
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
page?: {
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
};
};
/** Measurements reported by the in-page scroll routine. */
type ScrollResult = {
// Scroll offset before and after the scroll call.
before: number;
after: number;
// Largest reachable scroll offset (content height minus viewport height, floored at 0).
maxScrollY: number;
// Signed distance from `before` to the clamped target offset.
effectiveAmount: number;
viewportHeight: number;
contentHeight: number;
// Whether the window or an inner scrollable element was scrolled.
targetType: "page" | "container";
// Descriptive selector for the scrolled container; null for the page.
targetSelector: string | null;
};
// Pixel distance used when no amount is given.
const DEFAULT_SCROLL_AMOUNT = 800;
// Bounds the requested amount is clamped to.
const MIN_SCROLL_AMOUNT = 50;
const MAX_SCROLL_AMOUNT = 5000;
/** Maps the requested direction to "up" or "down"; anything but "up" means "down". */
function resolveDirection(rawDirection?: string): ScrollDirection {
  return rawDirection === "up" ? "up" : "down";
}
/**
 * Resolves the scroll distance in pixels.
 *
 * A missing amount yields the default. An explicit amount must be a positive
 * finite number; it is truncated to an integer and clamped to the allowed range.
 *
 * @throws Error when the amount is not a positive finite number.
 */
function normalizeAmount(rawAmount?: number): number {
  if (rawAmount === undefined) {
    return DEFAULT_SCROLL_AMOUNT;
  }
  const numeric = Number(rawAmount);
  const isPositive = Number.isFinite(numeric) && numeric > 0;
  if (!isPositive) {
    throw new Error("amount must be a positive number of pixels.");
  }
  const whole = Math.trunc(numeric);
  const capped = Math.min(whole, MAX_SCROLL_AMOUNT);
  return Math.max(MIN_SCROLL_AMOUNT, capped);
}
/**
 * Returns an `evaluate` function that calls through the session (preferred)
 * or its page, keeping the method attached to its owner.
 *
 * @throws Error when neither exposes `evaluate`.
 */
function getSessionEvaluate(session: SessionLike): ((fn: (...args: any[]) => unknown, ...args: any[]) => Promise<unknown>) {
  const viaSession = async (fn: (...args: any[]) => unknown, ...args: any[]): Promise<unknown> =>
    session.evaluate?.(fn, ...args);
  const viaPage = async (fn: (...args: any[]) => unknown, ...args: any[]): Promise<unknown> =>
    session.page?.evaluate?.(fn, ...args);

  if (typeof session.evaluate === "function") {
    return viaSession;
  }
  if (typeof session.page?.evaluate === "function") {
    return viaPage;
  }
  throw new Error("Session does not support DOM evaluation.");
}
/**
 * Scrolls the page, or a scrollable container, by running a routine inside the browser.
 *
 * Target selection, in order:
 * 1. With `selector`: the nearest visible, scrollable element starting at the
 *    matched element and walking up through its ancestors.
 * 2. Otherwise (or when step 1 finds nothing): the window itself.
 * 3. If the window did not move and its content fits the viewport: the visible
 *    scrollable element with the largest score (overflow distance times height).
 *
 * @param session Session providing `evaluate` (directly or via its page).
 * @param direction "down" scrolls by +amount, "up" by -amount.
 * @param amount Distance in pixels.
 * @param selector Optional selector of an element inside the desired scroll target.
 * @returns Measurements taken before and after scrolling.
 * @throws Error from getSessionEvaluate when the session cannot evaluate scripts.
 */
async function performScroll(
session: SessionLike,
direction: ScrollDirection,
amount: number,
selector?: string
): Promise<ScrollResult> {
const evaluate = getSessionEvaluate(session);
return evaluate(
// Everything in this callback executes in the page; it only sees `input`.
(input: {
amount: number;
direction: ScrollDirection;
selector: string | null;
}) => {
// Builds a short descriptive selector for reporting which element was scrolled.
// Attribute values are interpolated as-is (not escaped).
const toSelector = (element: Element): string | null => {
const tag = element.tagName.toLowerCase();
const id = element.getAttribute("id");
if (id) {
return `#${id}`;
}
const testId = element.getAttribute("data-testid");
if (testId) {
return `${tag}[data-testid="${testId}"]`;
}
const name = element.getAttribute("name");
if (name) {
return `${tag}[name="${name}"]`;
}
const role = element.getAttribute("role");
if (role) {
return `${tag}[role="${role}"]`;
}
return tag;
};
// True when overflow-y permits scrolling and content exceeds the box by more than 4px.
const isScrollable = (element: Element): boolean => {
const htmlElement = element as HTMLElement;
const style = window.getComputedStyle(htmlElement);
const overflowY = style.overflowY;
const canOverflow = overflowY === "auto" || overflowY === "scroll" || overflowY === "overlay";
return canOverflow && htmlElement.scrollHeight > htmlElement.clientHeight + 4;
};
// True when the element has a non-empty box and is not hidden via CSS.
const isVisible = (element: Element): boolean => {
const htmlElement = element as HTMLElement;
const style = window.getComputedStyle(htmlElement);
const rect = htmlElement.getBoundingClientRect();
return (
rect.width > 0 &&
rect.height > 0 &&
style.visibility !== "hidden" &&
style.display !== "none" &&
Number.parseFloat(style.opacity) > 0
);
};
// Walks from `element` upward (including itself) to the first visible scrollable element.
const findScrollableAncestor = (element: Element | null): Element | null => {
let current = element;
while (current) {
if (isScrollable(current) && isVisible(current)) {
return current;
}
current = current.parentElement;
}
return null;
};
// Scans every element and picks the visible scrollable one with the highest score.
const findBestScrollableContainer = (): Element | null => {
const elements = Array.from(document.querySelectorAll("*"));
let best: Element | null = null;
let bestScore = -1;
for (const element of elements) {
if (!isScrollable(element) || !isVisible(element)) {
continue;
}
const htmlElement = element as HTMLElement;
// Score = overflow distance * visible height.
const score = (htmlElement.scrollHeight - htmlElement.clientHeight) * Math.max(1, htmlElement.clientHeight);
if (score > bestScore) {
best = element;
bestScore = score;
}
}
return best;
};
const signedAmount = input.direction === "down" ? input.amount : -input.amount;
// Scrolls one element, clamping the target offset to [0, maxScrollY].
const scrollElement = (element: HTMLElement, targetSelector: string | null): ScrollResult => {
const before = Number(element.scrollTop || 0);
const viewportHeight = Math.max(0, element.clientHeight);
const contentHeight = Math.max(0, element.scrollHeight);
const maxScrollY = Math.max(0, contentHeight - viewportHeight);
const target = Math.max(0, Math.min(maxScrollY, before + signedAmount));
element.scrollTo({ top: target, left: element.scrollLeft || 0 });
return {
before,
after: Number(element.scrollTop || 0),
maxScrollY,
effectiveAmount: target - before,
viewportHeight,
contentHeight,
targetType: "container",
targetSelector,
};
};
// Step 1: caller-specified element and its scrollable ancestor.
const explicitTarget = input.selector
? findScrollableAncestor(document.querySelector(input.selector))
: null;
if (explicitTarget) {
return scrollElement(explicitTarget as HTMLElement, toSelector(explicitTarget));
}
// Step 2: scroll the window. Content height is the largest of several
// body/documentElement measurements.
const bodyHeight = document.body?.scrollHeight ?? 0;
const docHeight = document.documentElement?.scrollHeight ?? 0;
const contentHeight = Math.max(bodyHeight, docHeight, document.body?.offsetHeight ?? 0, document.documentElement?.offsetHeight ?? 0);
const viewportHeight = Math.max(window.innerHeight, document.documentElement?.clientHeight ?? 0);
const maxScrollY = Math.max(0, contentHeight - viewportHeight);
const before = Number(window.scrollY || window.pageYOffset || 0);
const target = Math.max(0, Math.min(maxScrollY, before + signedAmount));
window.scrollTo({ top: target, left: window.pageXOffset || window.scrollX || 0 });
const pageResult = {
before,
after: Number(window.scrollY || window.pageYOffset || 0),
maxScrollY,
effectiveAmount: target - before,
viewportHeight,
contentHeight,
targetType: "page" as const,
targetSelector: null,
};
// Keep the page result if the window moved, or if the page is scrollable at
// all (in which case it is simply already at an edge).
if (pageResult.before !== pageResult.after || pageResult.contentHeight > pageResult.viewportHeight) {
return pageResult;
}
// Step 3: the window cannot scroll; try the most prominent inner container.
const fallbackTarget = findBestScrollableContainer();
if (fallbackTarget) {
return scrollElement(fallbackTarget as HTMLElement, toSelector(fallbackTarget));
}
return pageResult;
},
{ amount, direction, selector: selector ?? null }
) as Promise<ScrollResult>;
}
/**
 * Builds the `steel_scroll` tool, which scrolls the page or a scrollable
 * container and reports how far it moved.
 *
 * @param client Steel client used to obtain (or create) the browser session.
 */
export function scrollTool(client: SteelClient): ToolDefinition<any, any> {
  const toolName = "steel_scroll";

  const directionSchema = Type.Union([Type.Literal("up"), Type.Literal("down")], {
    description: "Direction to scroll",
  });
  const amountSchema = Type.Integer({
    minimum: MIN_SCROLL_AMOUNT,
    maximum: MAX_SCROLL_AMOUNT,
    description: "Pixel amount for one scroll action",
  });
  const selectorSchema = Type.String({
    description: "Optional selector for an element inside the scroll target; useful for nested panes like lists, sidebars, or map results",
  });

  return {
    name: toolName,
    label: "Scroll",
    description: "Scroll the current page or a visible scroll container up or down",
    parameters: Type.Object({
      direction: Type.Optional(directionSchema),
      amount: Type.Optional(amountSchema),
      selector: Type.Optional(selectorSchema),
    }),
    async execute(
      _toolCallId: string,
      params: { direction?: ScrollDirection; amount?: number; selector?: string },
      signal: AbortSignal | undefined,
      onUpdate: ToolProgressUpdater,
      _ctx: ExtensionContext
    ): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
      return withToolError(toolName, async () => {
        throwIfAborted(signal);

        const direction = resolveDirection(params.direction);
        const amount = normalizeAmount(params.amount);
        // A blank or non-string selector is treated as "no selector".
        const trimmedSelector = typeof params.selector === "string" ? params.selector.trim() : "";
        const selector = trimmedSelector ? trimmedSelector : undefined;

        const pendingSession = client.getOrCreateSession();
        const session = (await withAbortSignal(pendingSession, signal)) as SessionLike;

        const targetLabel = selector ? ` near ${selector}` : "";
        await emitProgress(onUpdate, toolName, `Preparing scroll ${direction} by ${amount}px${targetLabel}`);

        const outcome = await withAbortSignal(
          performScroll(session, direction, amount, selector),
          signal
        );

        if (outcome.contentHeight <= outcome.viewportHeight) {
          throw new Error("Page is not scrollable: content fits within viewport.");
        }
        if (outcome.before === outcome.after) {
          const edge = direction === "down" ? "bottom" : "top";
          throw new Error(`No scroll movement occurred; already at ${edge}.`);
        }

        const movedBy = Math.abs(outcome.effectiveAmount);
        await emitProgress(onUpdate, toolName, `Scroll movement: ${movedBy}px`);

        return {
          content: [{
            type: "text",
            text: `Scrolled ${direction} by ${movedBy}px.`,
          }],
          details: {
            ...sessionDetails(session),
            direction,
            requestedAmount: amount,
            requestedSelector: selector ?? null,
            effectiveAmount: movedBy,
            before: outcome.before,
            after: outcome.after,
            maxScrollY: outcome.maxScrollY,
            targetType: outcome.targetType,
            targetSelector: outcome.targetSelector,
            bounds: {
              atTop: outcome.after <= 0,
              atBottom: outcome.after >= outcome.maxScrollY,
            },
          },
        };
      }, signal);
    },
  };
}

View File

@@ -0,0 +1,108 @@
import type { ExtensionContext, ToolDefinition } from "@mariozechner/pi-coding-agent";
import { Type } from "@sinclair/typebox";
import type { SteelSessionMode } from "../session-mode.js";
import type { SteelClient } from "../steel-client.js";
import { withToolError, type ToolProgressUpdater } from "./tool-runtime.js";
/**
 * Callbacks the session tools use to inspect and change the runtime Steel
 * session mode and to close open sessions.
 */
export type SteelSessionController = {
// Mode that releaseSessionTool restores after closing sessions.
getDefaultSessionMode: () => SteelSessionMode;
// Session mode currently in effect.
getSessionMode: () => SteelSessionMode;
// Replaces the session mode currently in effect.
setSessionMode: (mode: SteelSessionMode) => void;
// Closes open Steel sessions; `reason` is a caller-supplied label for the closure.
closeSessions: (reason: string) => Promise<void>;
};
/**
 * Formats the confirmation text shown after pinning the session.
 *
 * @param sessionId Id of the current session, or null/empty when none is known.
 * @returns The confirmation sentence, with the session id appended when present.
 */
function buildPinMessage(sessionId: string | null): string {
  const confirmation = "Enabled Steel session persistence for this Pi session.";
  return sessionId ? `${confirmation} Current session: ${sessionId}.` : confirmation;
}
/**
 * Formats the confirmation text shown after releasing the session.
 *
 * @param sessionId Id of the session that was released, or null/empty when there was none.
 * @param nextMode Session mode the runtime was reset to.
 * @returns A sentence describing what was released followed by the new mode.
 */
function buildReleaseMessage(
  sessionId: string | null,
  nextMode: SteelSessionMode
): string {
  const outcome = sessionId
    ? `Released Steel session ${sessionId}.`
    : "No active Steel session to release.";
  return `${outcome} Runtime session mode reset to ${nextMode}.`;
}
/**
 * Builds the `steel_pin_session` tool definition.
 *
 * Executing the tool switches the runtime session mode to "session" and
 * reports the previous mode, the default mode and the current session id.
 *
 * @param client Steel client queried for the current session id and activity.
 * @param controller Accessors for the runtime session mode.
 * @returns The tool definition for `steel_pin_session`.
 */
export function pinSessionTool(
  client: SteelClient,
  controller: SteelSessionController
): ToolDefinition<any, any> {
  const toolName = "steel_pin_session";

  return {
    name: toolName,
    label: "Pin Session",
    description: "Keep the current Steel browser session alive across prompts until explicitly released",
    parameters: Type.Object({}),
    async execute(
      _toolCallId: string,
      _params: {},
      _signal: AbortSignal | undefined,
      _onUpdate: ToolProgressUpdater,
      _ctx: ExtensionContext
    ): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
      return withToolError(toolName, async () => {
        // Capture the mode before overwriting it so it can be reported back.
        const previousMode = controller.getSessionMode();
        controller.setSessionMode("session");

        const sessionId = client.getCurrentSessionId();
        const text = buildPinMessage(sessionId);
        const defaultMode = controller.getDefaultSessionMode();
        const hasActiveSession = client.hasActiveSession();

        return {
          content: [{ type: "text", text }],
          details: {
            previousMode,
            mode: "session",
            defaultMode,
            sessionId,
            hasActiveSession,
          },
        };
      });
    },
  };
}
/**
 * Builds the `steel_release_session` tool definition.
 *
 * Executing the tool closes sessions through the controller and then resets
 * the runtime session mode to the controller's default mode.
 *
 * @param client Steel client queried for the current session id.
 * @param controller Accessors for the session mode plus the session-closing hook.
 * @returns The tool definition for `steel_release_session`.
 */
export function releaseSessionTool(
  client: SteelClient,
  controller: SteelSessionController
): ToolDefinition<any, any> {
  const toolName = "steel_release_session";

  return {
    name: toolName,
    label: "Release Session",
    description: "Close the current Steel browser session immediately and restore the default runtime session mode",
    parameters: Type.Object({}),
    async execute(
      _toolCallId: string,
      _params: {},
      _signal: AbortSignal | undefined,
      _onUpdate: ToolProgressUpdater,
      _ctx: ExtensionContext
    ): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
      return withToolError(toolName, async () => {
        const previousMode = controller.getSessionMode();
        const defaultMode = controller.getDefaultSessionMode();
        // Read the id before closing so the released session can be reported.
        const releasedSessionId = client.getCurrentSessionId();

        await controller.closeSessions(toolName);
        controller.setSessionMode(defaultMode);

        const text = buildReleaseMessage(releasedSessionId, defaultMode);
        return {
          content: [{ type: "text", text }],
          details: {
            previousMode,
            mode: defaultMode,
            defaultMode,
            releasedSessionId,
            hadActiveSession: Boolean(releasedSessionId),
          },
        };
      });
    },
  };
}

View File

@@ -0,0 +1,64 @@
// A session member that is either a plain string or a (possibly async) getter returning one.
type SessionGetter = (() => Promise<string> | string) | string;
/**
 * Minimal structural view of a session needed to read its URL and title.
 * Every member is optional; the readers below fall back to "unknown" when
 * nothing usable is present.
 */
export type SessionStateLike = {
url?: SessionGetter;
title?: SessionGetter;
// Alternative URL accessor consulted by readSessionUrl when `url` yields nothing usable.
getCurrentUrl?: () => Promise<string> | string;
};
/**
 * Reads the current URL from a session-like object.
 *
 * Sources are tried in order: `url` as a string, `url` as a getter, then
 * `getCurrentUrl()`. A source is accepted only when it yields a string that
 * is not blank; the accepted value is returned untrimmed.
 *
 * @param session Object exposing any of the supported URL members.
 * @returns The first usable URL, or "unknown" when no source provides one.
 */
export async function readSessionUrl(session: SessionStateLike): Promise<string> {
  const hasText = (candidate: unknown): candidate is string =>
    typeof candidate === "string" && candidate.trim().length > 0;

  const urlMember = session.url;
  if (typeof urlMember === "function") {
    // Invoke with the session as receiver so method-style getters keep `this`.
    const fromGetter = await urlMember.call(session);
    if (hasText(fromGetter)) {
      return fromGetter;
    }
  } else if (hasText(urlMember)) {
    return urlMember;
  }

  if (typeof session.getCurrentUrl === "function") {
    const fromFallback = await session.getCurrentUrl.call(session);
    if (hasText(fromFallback)) {
      return fromFallback;
    }
  }

  return "unknown";
}
/**
 * Reads the page title from a session-like object.
 *
 * `title` may be a string or a (possibly async) getter. A value is accepted
 * only when it is a string that is not blank; it is returned untrimmed.
 *
 * @param session Object that may expose a `title` member.
 * @returns The title, or "unknown" when none is available.
 */
export async function readSessionTitle(session: SessionStateLike): Promise<string> {
  const titleMember = session.title;
  // Invoke getters with the session as receiver so method-style getters keep `this`.
  const candidate: unknown =
    typeof titleMember === "function" ? await titleMember.call(session) : titleMember;

  if (typeof candidate === "string" && candidate.trim().length > 0) {
    return candidate;
  }
  return "unknown";
}
/**
 * Tells whether a URL denotes an empty document.
 *
 * @param url URL to test; surrounding whitespace and letter case are ignored.
 * @returns True for `about:blank` and `about:srcdoc`, false otherwise.
 */
export function isBlankPageUrl(url: string): boolean {
  const blankUrls = ["about:blank", "about:srcdoc"];
  return blankUrls.includes(url.trim().toLowerCase());
}
/**
 * Returns the standard hint explaining why a page may be blank and how to
 * keep the same browser across prompts.
 *
 * @returns The hint text.
 */
export function freshSessionHint(): string {
  const cause = "This usually means Pi started a fresh Steel session.";
  const remedy =
    "Navigate to a page first, or run Pi with STEEL_SESSION_MODE=session to keep the same browser across prompts.";
  return `${cause} ${remedy}`;
}
/**
 * Creates the error raised when an action is attempted on a blank page.
 *
 * @param action Verb phrase describing the attempted action; inserted after "Cannot".
 * @returns An Error whose message names the action and appends the fresh-session hint.
 */
export function blankPageError(action: string): Error {
  const hint = freshSessionHint();
  const message = `Cannot ${action} because the current page is about:blank. ${hint}`;
  return new Error(message);
}
/**
 * Describes a blank page for status output.
 *
 * @param url URL to echo in the description.
 * @returns A one-line description including the URL and a continuity hint.
 */
export function describeBlankPage(url: string): string {
  const advice =
    "fresh Steel session; navigate first or use STEEL_SESSION_MODE=session for cross-prompt continuity";
  return `Current URL: ${url} (${advice})`;
}

View File

@@ -0,0 +1,246 @@
import type { AgentToolUpdateCallback } from "@mariozechner/pi-coding-agent";
// Buckets an error message can be classified into; "unknown" is the fallback.
export type ToolErrorCategory =
| "validation"
| "timeout"
| "network"
| "tool_execution"
| "unknown";
// Optional progress callback handed to tools; `undefined` means progress updates are dropped.
export type ToolProgressUpdater = AgentToolUpdateCallback<{
context: string;
kind: "progress";
message: string;
}> | undefined;
// `name` assigned to errors created by abortError and checked by isAbortError.
const ABORT_ERROR_NAME = "AbortError";
// Default message for abort errors; withToolError prefixes it with the tool context.
const ABORT_ERROR_MESSAGE = "Tool execution cancelled.";
/**
 * Lowercase substrings used by classifyError to bucket an error message.
 *
 * classifyError walks the categories in the order they are declared here and
 * returns the first one with a matching marker, so earlier categories take
 * precedence (a message containing both "invalid" and "selector" is classified
 * as validation). Matching is plain substring containment on the lowercased
 * message. "unknown" has no markers and is only reached as the fallback.
 */
const TOOL_ERROR_PATTERNS: Record<ToolErrorCategory, readonly string[]> = {
validation: [
"bad request",
"invalid",
"missing",
"required",
"schema",
"format",
"validation",
"unsupported value",
"not allowed",
],
timeout: [
"timed out",
"timeout",
"timed-out",
"deadline",
"time out",
],
network: [
"network",
"connection",
"econn",
"enotfound",
"dns",
"econnreset",
"econnrefused",
"proxy",
"ssl",
"certificate",
],
tool_execution: [
"selector",
"tool",
"navigation",
"screenshot",
"pdf",
"session",
"click",
"extract",
"not supported",
"page",
],
unknown: [],
};
// Short human-readable label per category; placed right after the context in toolErrorMessage.
const TOOL_ERROR_LABELS: Record<ToolErrorCategory, string> = {
validation: "Validation failed",
timeout: "Timed out",
network: "Network issue",
tool_execution: "Tool execution failed",
unknown: "Tool error",
};
// Retry advice per category; appended after "Retry guidance:" in toolErrorMessage.
const TOOL_ERROR_GUIDANCE: Record<ToolErrorCategory, string> = {
validation: "Check required inputs and retry with corrected values.",
timeout:
"Retry with narrower scope or longer timeout values.",
network:
"Retry once connectivity is stable.",
tool_execution:
"Retrying usually succeeds; if selector-based operations fail, refresh page state and try again.",
unknown: "Retry the action and, if it repeats, rerun with simplified inputs.",
};
/**
 * Converts any thrown value into a non-empty, human-readable message.
 *
 * - `Error` instances and strings yield their trimmed text.
 * - `null`, `undefined` and blank text yield "Unknown error".
 * - Everything else is JSON-serialized, falling back to `String(value)` when
 *   serialization throws or produces no text.
 *
 * @param error The caught value.
 * @returns A message string; never `undefined`.
 */
function normalizeErrorMessage(error: unknown): string {
  if (error instanceof Error) {
    return error.message?.trim() || "Unknown error";
  }
  if (typeof error === "string") {
    return error.trim() || "Unknown error";
  }
  if (error === undefined || error === null) {
    return "Unknown error";
  }
  try {
    // JSON.stringify returns undefined (not a string) for functions and
    // symbols; fall back to String() so callers always receive text.
    return JSON.stringify(error) ?? String(error);
  } catch {
    // Reached for values JSON cannot serialize, e.g. BigInt or cyclic objects.
    return String(error);
  }
}
/**
 * Picks the error category whose markers appear in the message.
 *
 * Categories are checked in the declaration order of TOOL_ERROR_PATTERNS and
 * the first match wins. Matching is case-insensitive substring containment.
 *
 * @param message Error message to classify.
 * @returns The first matching category, or "unknown" when none matches.
 */
function classifyError(message: string): ToolErrorCategory {
  const haystack = message.toLowerCase();
  const categories = Object.keys(TOOL_ERROR_PATTERNS) as ToolErrorCategory[];
  const matched = categories.find((category) =>
    TOOL_ERROR_PATTERNS[category].some((marker) => haystack.includes(marker))
  );
  return matched ?? "unknown";
}
/**
 * Builds the user-facing message for a failed tool call:
 * `<context>: <label>. <message> Retry guidance: <guidance>`.
 *
 * @param context Tool name or other prefix identifying where the failure occurred.
 * @param error The caught value; normalized to text before classification.
 * @returns The formatted message.
 */
export function toolErrorMessage(context: string, error: unknown): string {
  const message = normalizeErrorMessage(error);
  const category = classifyError(message);
  const label = TOOL_ERROR_LABELS[category];
  const guidance = TOOL_ERROR_GUIDANCE[category];
  // Many error messages already end with sentence punctuation; only append a
  // period when it is missing so the output never contains "..".
  const sentence = /[.!?]$/.test(message) ? message : `${message}.`;
  return `${context}: ${label}. ${sentence} Retry guidance: ${guidance}`;
}
/**
 * Wraps a caught value in an Error carrying the formatted tool error message.
 *
 * @param context Tool name or other prefix identifying where the failure occurred.
 * @param error The caught value.
 * @returns A new Error whose message comes from toolErrorMessage.
 */
export function toolError(context: string, error: unknown): Error {
  const formatted = toolErrorMessage(context, error);
  return new Error(formatted);
}
/**
 * Creates an Error tagged as an abort so isAbortError recognizes it.
 *
 * @param message Error message; defaults to ABORT_ERROR_MESSAGE.
 * @returns An Error whose `name` is ABORT_ERROR_NAME.
 */
function abortError(message = ABORT_ERROR_MESSAGE): Error {
  return Object.assign(new Error(message), { name: ABORT_ERROR_NAME });
}
/**
 * Tells whether a caught value represents a cancellation.
 *
 * An Error counts as an abort when its `name` equals ABORT_ERROR_NAME or its
 * message contains "cancelled" or "canceled" (case-insensitive). Values that
 * are not Error instances never count.
 *
 * @param error The caught value.
 * @returns True when the value looks like an abort/cancellation error.
 */
export function isAbortError(error: unknown): boolean {
  if (error instanceof Error) {
    if (error.name === ABORT_ERROR_NAME) {
      return true;
    }
    const lowered = error.message.toLowerCase();
    return ["cancelled", "canceled"].some((spelling) => lowered.includes(spelling));
  }
  return false;
}
/**
 * Throws an abort error when the given signal has already been aborted.
 *
 * @param signal Signal to check; a missing signal is never considered aborted.
 * @throws An abort error (see abortError) when `signal.aborted` is set.
 */
export function throwIfAborted(signal: AbortSignal | undefined): void {
  const alreadyAborted = Boolean(signal?.aborted);
  if (alreadyAborted) {
    throw abortError();
  }
}
/**
 * Resolves after `ms` milliseconds, or rejects early when the signal aborts.
 *
 * @param ms Delay in milliseconds.
 * @param signal Optional abort signal; without one the sleep cannot be interrupted.
 * @returns A promise that resolves once the delay elapses.
 * @throws An abort error, synchronously, when the signal is already aborted.
 */
export function sleepWithSignal(
  ms: number,
  signal: AbortSignal | undefined
): Promise<void> {
  if (!signal) {
    return new Promise((resolve) => setTimeout(resolve, ms));
  }

  throwIfAborted(signal);

  return new Promise<void>((resolve, reject) => {
    // Declared before the timer so both callbacks can reference each other.
    const handleAbort = () => {
      clearTimeout(timerId);
      signal.removeEventListener("abort", handleAbort);
      reject(abortError());
    };
    const timerId = setTimeout(() => {
      // Detach the listener once the delay completes normally.
      signal.removeEventListener("abort", handleAbort);
      resolve();
    }, ms);
    signal.addEventListener("abort", handleAbort, { once: true });
  });
}
/**
 * Races a promise against an abort signal.
 *
 * The returned promise settles with the outcome of `promise`, unless the
 * signal aborts first, in which case it rejects with an abort error. The
 * underlying operation is not cancelled; only the wait for it is abandoned.
 *
 * When the signal is already aborted the result is a rejected promise (never
 * a synchronous throw), and the abandoned `promise` is marked as handled so
 * a later rejection cannot surface as an unhandled rejection.
 *
 * @param promise The operation to wait for.
 * @param signal Optional abort signal; without one `promise` is returned as-is.
 * @returns A promise mirroring `promise` that also rejects on abort.
 */
export function withAbortSignal<T>(
  promise: Promise<T>,
  signal: AbortSignal | undefined
): Promise<T> {
  if (!signal) {
    return promise;
  }

  if (signal.aborted) {
    // Nobody will observe `promise` from here on; swallow its eventual
    // rejection so it does not become an unhandledRejection.
    promise.catch(() => undefined);
    return Promise.reject(abortError());
  }

  return new Promise<T>((resolve, reject) => {
    // `once: true` detaches the listener automatically after it fires.
    const onAbort = () => {
      reject(abortError());
    };
    signal.addEventListener("abort", onAbort, { once: true });

    promise.then(
      (value) => {
        signal.removeEventListener("abort", onAbort);
        resolve(value);
      },
      (error: unknown) => {
        signal.removeEventListener("abort", onAbort);
        reject(error);
      }
    );
  });
}
/**
 * Runs a tool operation and normalizes every failure.
 *
 * Failures that look like aborts (or that occur while the signal is aborted)
 * become an abort error prefixed with the context; all other failures are
 * wrapped through toolError.
 *
 * @param context Tool name used as the message prefix.
 * @param operation The async work to run.
 * @param signal Optional abort signal checked before the operation starts.
 * @returns The operation's result.
 * @throws Synchronously, when the signal is already aborted or `operation` throws before returning a promise.
 */
export function withToolError<T>(
  context: string,
  operation: () => Promise<T>,
  signal?: AbortSignal
): Promise<T> {
  // Single translation rule shared by the synchronous and asynchronous paths.
  const translate = (error: unknown): Error =>
    isAbortError(error) || signal?.aborted
      ? abortError(`${context}: ${ABORT_ERROR_MESSAGE}`)
      : toolError(context, error);

  try {
    throwIfAborted(signal);
    return operation().catch((error: unknown) => {
      throw translate(error);
    });
  } catch (error: unknown) {
    throw translate(error);
  }
}
/**
 * Sends a progress update to the agent, if a callback was supplied.
 *
 * The message is trimmed and delivered both as display text
 * (`<context>: <message>`) and as structured details.
 *
 * @param onUpdate Progress callback; when undefined the call is a no-op.
 * @param context Tool name the update belongs to.
 * @param message Progress text.
 */
export function emitProgress(
  onUpdate: ToolProgressUpdater,
  context: string,
  message: string
): void {
  if (onUpdate) {
    const text = message.trim();
    onUpdate({
      content: [{ type: "text", text: `${context}: ${text}` }],
      details: { context, kind: "progress", message: text },
    });
  }
}

View File

@@ -0,0 +1,56 @@
// Timeout used when the environment override is absent or rejected by parsePositiveInt.
const DEFAULT_TOOL_TIMEOUT_MS = 30_000;
// Name of the environment variable that overrides the default timeout.
const TOOL_TIMEOUT_ENV = "STEEL_TOOL_TIMEOUT_MS";
// Lower and upper clamps (milliseconds) applied to resolved timeouts.
export const MIN_TOOL_TIMEOUT_MS = 100;
export const MAX_TOOL_TIMEOUT_MS = 120_000;
// Memoized result of getDefaultToolTimeoutMs; null until first computed.
let cachedDefaultToolTimeoutMs: number | null = null;
/**
 * Parses a strictly positive base-10 integer from text.
 *
 * Only an optional leading "+" followed by digits is accepted (surrounding
 * whitespace is ignored). Inputs such as "1e5", "30s" or "12.5" are rejected
 * instead of being silently truncated to their leading digits.
 *
 * @param raw Text to parse, typically an environment variable value.
 * @returns The parsed integer, or null when the input is missing, blank,
 *   malformed, zero, or not finite.
 */
function parsePositiveInt(raw: string | undefined): number | null {
  const value = raw?.trim();
  // Validate the whole string first: Number.parseInt alone stops at the first
  // non-digit and would accept trailing garbage.
  if (!value || !/^\+?\d+$/.test(value)) {
    return null;
  }
  const parsed = Number.parseInt(value, 10);
  return Number.isFinite(parsed) && parsed > 0 ? parsed : null;
}
/**
 * Returns the default tool timeout in milliseconds.
 *
 * The value is computed once and cached: the environment override is used
 * when it parses as a positive integer (clamped to the allowed range),
 * otherwise the built-in default applies.
 *
 * @returns The cached default timeout.
 */
export function getDefaultToolTimeoutMs(): number {
  if (cachedDefaultToolTimeoutMs === null) {
    const override = parsePositiveInt(process.env[TOOL_TIMEOUT_ENV]);
    cachedDefaultToolTimeoutMs =
      override === null
        ? DEFAULT_TOOL_TIMEOUT_MS
        : Math.max(MIN_TOOL_TIMEOUT_MS, Math.min(override, MAX_TOOL_TIMEOUT_MS));
  }
  return cachedDefaultToolTimeoutMs;
}
/**
 * Resolves the timeout for a single tool call.
 *
 * @param rawTimeout Caller-supplied timeout in milliseconds, or undefined to use the default.
 * @returns The default timeout when none is supplied; otherwise the supplied
 *   value truncated to an integer and clamped to the allowed range.
 * @throws Error when the supplied value is not a finite number greater than zero.
 */
export function resolveToolTimeoutMs(rawTimeout: number | undefined): number {
  if (rawTimeout === undefined) {
    return getDefaultToolTimeoutMs();
  }

  const requested = Number(rawTimeout);
  const isUsable = Number.isFinite(requested) && requested > 0;
  if (!isUsable) {
    throw new Error("timeout must be a positive number in milliseconds.");
  }

  const atLeastMin = Math.max(MIN_TOOL_TIMEOUT_MS, Math.trunc(requested));
  return Math.min(atLeastMin, MAX_TOOL_TIMEOUT_MS);
}

View File

@@ -0,0 +1,269 @@
import type { ExtensionContext, ToolDefinition } from "@mariozechner/pi-coding-agent";
import { Type } from "@sinclair/typebox";
import { sessionDetails, type SteelClient } from "../steel-client.js";
import { runWithCaptchaRecovery, type CaptchaRecoverySummary } from "./captcha-guard.js";
import {
emitProgress,
throwIfAborted,
withAbortSignal,
withToolError,
type ToolProgressUpdater,
} from "./tool-runtime.js";
import {
MAX_TOOL_TIMEOUT_MS,
resolveToolTimeoutMs,
} from "./tool-settings.js";
/**
 * Structural view of the session object as used by the type tool.
 *
 * Apart from `id`, every capability is optional and may live either directly
 * on the session or on a nested `page`; the helpers in this file check for
 * each capability at runtime before using it, consulting the session before
 * its `page`.
 */
type SessionLike = {
id: string;
sessionViewerUrl?: string | null;
captchasStatus?: () => Promise<unknown>;
captchasSolve?: () => Promise<unknown>;
waitForSelector?: (
selector: string,
options?: { state?: "attached" | "visible"; timeout?: number }
) => Promise<unknown>;
fill?: (selector: string, text: string) => Promise<unknown>;
type?: (selector: string, text: string, options?: { delay?: number }) => Promise<unknown>;
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
locator?: (selector: string) => {
fill?: (text: string) => Promise<unknown>;
type?: (text: string, options?: { delay?: number }) => Promise<unknown>;
waitFor?: (options?: { state?: "attached" | "visible"; timeout?: number }) => Promise<unknown>;
};
// Same optional capabilities, exposed on a nested page object.
page?: {
waitForSelector?: (
selector: string,
options?: { state?: "attached" | "visible"; timeout?: number }
) => Promise<unknown>;
fill?: (selector: string, text: string) => Promise<unknown>;
type?: (selector: string, text: string, options?: { delay?: number }) => Promise<unknown>;
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
locator?: (selector: string) => {
fill?: (text: string) => Promise<unknown>;
type?: (text: string, options?: { delay?: number }) => Promise<unknown>;
waitFor?: (options?: { state?: "attached" | "visible"; timeout?: number }) => Promise<unknown>;
};
};
};
// Result of inspecting the target element before typing (produced by ensureField).
type FieldActionState = {
// Whether the selector matched an element in the page.
found: boolean;
// Whether the matched element accepts text input (see ensureField for the exact checks).
editable: boolean;
};
/**
 * Reduces a captcha recovery summary to the counters reported in tool details.
 *
 * @param summary Full summary returned by the captcha recovery helper.
 * @returns A new object holding only the five reported fields.
 */
function compactCaptchaRecovery(summary: CaptchaRecoverySummary) {
  const { triggered, retries, solveAttempts, statusChecks, waitTimedOut } = summary;
  return { triggered, retries, solveAttempts, statusChecks, waitTimedOut };
}
/**
 * Trims a selector and rejects blank input.
 *
 * @param selector Raw selector from the tool parameters.
 * @returns The selector without surrounding whitespace.
 * @throws Error when the selector is empty after trimming.
 */
function normalizeSelector(selector: string): string {
  const cleaned = selector.trim();
  if (cleaned.length === 0) {
    throw new Error("Selector cannot be empty.");
  }
  return cleaned;
}
/**
 * Resolves the effective timeout for the type tool.
 *
 * @param timeoutMs Optional caller-supplied timeout in milliseconds.
 * @returns The timeout produced by resolveToolTimeoutMs.
 */
function normalizeTimeout(timeoutMs?: number): number {
  const effectiveTimeout = resolveToolTimeoutMs(timeoutMs);
  return effectiveTimeout;
}
/**
 * Waits for the target element to become visible (when the session exposes a
 * `waitForSelector`) and then inspects it inside the page.
 *
 * An element is reported editable when it is an `<input>`, a `<textarea>` or
 * content-editable, is not read-only, and is not disabled (via the `disabled`
 * property or `aria-disabled="true"`).
 *
 * @param session Session whose page is inspected.
 * @param selector CSS selector of the field.
 * @param timeoutMs Maximum time to wait for visibility, in milliseconds.
 * @returns Whether the element was found and whether it accepts input. When
 *   the session offers no `evaluate`, the field is assumed found and editable.
 */
async function ensureField(session: SessionLike, selector: string, timeoutMs: number): Promise<FieldActionState> {
  const waitOptions = { state: "visible" as const, timeout: timeoutMs };
  if (typeof session.waitForSelector === "function") {
    await session.waitForSelector(selector, waitOptions);
  } else if (typeof session.page?.waitForSelector === "function") {
    await session.page.waitForSelector(selector, waitOptions);
  }

  // Executed in the browser page; it must not capture variables from this scope.
  const inspectField = (rawSelector: string) => {
    const element = document.querySelector(rawSelector) as HTMLElement | null;
    if (!element) {
      return { found: false, editable: false };
    }
    const tag = element.tagName.toLowerCase();
    const isInputLike =
      tag === "input" ||
      tag === "textarea" ||
      element.isContentEditable;
    const htmlInput = element as HTMLInputElement;
    const editable = isInputLike && htmlInput.readOnly !== true;
    const disabled =
      htmlInput.disabled === true ||
      element.getAttribute("aria-disabled") === "true";
    return { found: true, editable: editable && !disabled };
  };

  // Call `evaluate` as a method on its owner: extracting it into a variable
  // and calling it bare would drop `this`, which class-based sessions and
  // pages rely on.
  if (typeof session.evaluate === "function") {
    return session.evaluate(inspectField, selector);
  }
  if (typeof session.page?.evaluate === "function") {
    return session.page.evaluate(inspectField, selector);
  }
  return { found: true, editable: true };
}
/**
 * Replaces the value of the field matched by `selector`.
 *
 * The first available mechanism is used, in this order: `session.fill`,
 * `session.page.fill`, a locator's `fill`, and finally an in-page script that
 * assigns `value` and dispatches bubbling `input` and `change` events.
 *
 * @param session Session used to reach the page.
 * @param selector CSS selector of the field.
 * @param text Value to set.
 * @throws Error when no mechanism is available, or when the in-page fallback finds no element.
 */
async function setValue(session: SessionLike, selector: string, text: string): Promise<void> {
  if (typeof session.fill === "function") {
    await session.fill(selector, text);
    return;
  }
  if (typeof session.page?.fill === "function") {
    await session.page.fill(selector, text);
    return;
  }

  const locator =
    typeof session.locator === "function"
      ? session.locator(selector)
      : session.page?.locator?.(selector);
  if (typeof locator?.fill === "function") {
    await locator.fill(text);
    return;
  }

  // Executed in the browser page; it must not capture variables from this scope.
  const assignValue = (input: { selector: string; value: string }): boolean => {
    const element = document.querySelector(input.selector) as HTMLInputElement | HTMLTextAreaElement | null;
    if (!element) {
      return false;
    }
    element.focus();
    element.value = input.value;
    // Notify listeners that the value changed programmatically.
    element.dispatchEvent(new Event("input", { bubbles: true }));
    element.dispatchEvent(new Event("change", { bubbles: true }));
    return true;
  };
  const payload = { selector, value: text };

  // Call `evaluate` as a method on its owner so `this` stays bound; a bare
  // call on an extracted reference breaks class-based sessions and pages.
  let applied: boolean;
  if (typeof session.evaluate === "function") {
    applied = await session.evaluate(assignValue, payload);
  } else if (typeof session.page?.evaluate === "function") {
    applied = await session.page.evaluate(assignValue, payload);
  } else {
    throw new Error("Session does not support setting input values.");
  }

  if (!applied) {
    throw new Error(`Element not found: ${selector}`);
  }
}
/**
 * Types text into the field matched by `selector`.
 *
 * The first available mechanism is used, in this order: `session.type`,
 * `session.page.type`, a locator's `type`, and finally setValue.
 *
 * @param session Session used to reach the page.
 * @param selector CSS selector of the field.
 * @param text Text to type.
 */
async function typeValue(session: SessionLike, selector: string, text: string): Promise<void> {
  const page = session.page;

  if (typeof session.type === "function") {
    await session.type(selector, text);
  } else if (typeof page?.type === "function") {
    await page.type(selector, text);
  } else {
    const target =
      typeof session.locator === "function"
        ? session.locator(selector)
        : page?.locator?.(selector);

    if (typeof target?.type === "function") {
      await target.type(text);
    } else {
      // No typing primitive is available; fall back to assigning the value.
      await setValue(session, selector, text);
    }
  }
}
/**
 * Builds the `steel_type` tool definition.
 *
 * Executing the tool validates the selector and timeout, obtains a session,
 * checks that the target field exists and is editable, and then writes the
 * text: through setValue when `clear` is true (the default) and through
 * typeValue otherwise. The input sequence runs under runWithCaptchaRecovery.
 *
 * @param client Steel client used to obtain (or create) the browser session.
 * @returns The tool definition for `steel_type`.
 */
export function typeTool(client: SteelClient): ToolDefinition<any, any> {
  const toolName = "steel_type";

  const parameters = Type.Object(
    {
      selector: Type.String({ description: "CSS selector for the input field" }),
      text: Type.String({ description: "Text to type into the field" }),
      clear: Type.Optional(Type.Boolean({ description: "Whether to clear the field before typing" })),
      timeout: Type.Optional(
        Type.Integer({
          minimum: 100,
          maximum: MAX_TOOL_TIMEOUT_MS,
          description: "Maximum milliseconds to wait for the input",
        })
      ),
    }
  );

  return {
    name: toolName,
    label: "Type",
    description: "Type text into an input element",
    parameters,
    async execute(
      _toolCallId: string,
      params: { selector: string; text: string; clear?: boolean; timeout?: number },
      signal: AbortSignal | undefined,
      onUpdate: ToolProgressUpdater,
      _ctx: ExtensionContext
    ): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
      return withToolError(toolName, async () => {
        throwIfAborted(signal);

        const selector = normalizeSelector(params.selector);
        const timeoutMs = normalizeTimeout(params.timeout);
        const { text } = params;
        const shouldClear = params.clear ?? true;

        await emitProgress(onUpdate, toolName, `Preparing input for ${selector}`);
        const session = (await withAbortSignal(
          client.getOrCreateSession(),
          signal
        )) as SessionLike;
        await emitProgress(onUpdate, toolName, "Running field input sequence");

        // One attempt at the input; handed to the captcha recovery helper as its operation.
        const applyInput = async (): Promise<void> => {
          throwIfAborted(signal);

          const { found, editable } = await withAbortSignal(
            ensureField(session, selector, timeoutMs),
            signal
          );
          if (!found) {
            throw new Error(`No element matched selector: ${selector}`);
          }
          if (!editable) {
            throw new Error(`Element is not editable: ${selector}`);
          }

          await emitProgress(
            onUpdate,
            toolName,
            shouldClear ? "Clearing existing value" : "Typing into field"
          );

          const writeText = shouldClear ? setValue : typeValue;
          await withAbortSignal(writeText(session, selector, text), signal);
        };

        const captchaRecovery = await runWithCaptchaRecovery({
          session,
          context: toolName,
          actionLabel: `type into ${selector}`,
          onUpdate,
          signal,
          operation: applyInput,
        });

        await emitProgress(onUpdate, toolName, `Input applied to ${selector}`);

        return {
          content: [{ type: "text", text: `Typed into ${selector}` }],
          details: {
            ...sessionDetails(session),
            selector,
            timeoutMs,
            clear: shouldClear,
            textLength: text.length,
            captchaRecovery: compactCaptchaRecovery(captchaRecovery),
          },
        };
      }, signal);
    },
  };
}

View File

@@ -0,0 +1,236 @@
import type { ExtensionContext, ToolDefinition } from "@mariozechner/pi-coding-agent";
import { Type } from "@sinclair/typebox";
import { sessionDetails as baseSessionDetails, type SteelClient } from "../steel-client.js";
import {
emitProgress,
sleepWithSignal,
throwIfAborted,
withAbortSignal,
withToolError,
type ToolProgressUpdater,
} from "./tool-runtime.js";
import {
MAX_TOOL_TIMEOUT_MS,
MIN_TOOL_TIMEOUT_MS,
resolveToolTimeoutMs,
} from "./tool-settings.js";
// Element states the wait tool can wait for.
type WaitState = "attached" | "visible";
/**
 * Structural view of the session object as used by the wait tool.
 *
 * Apart from `id`, every capability is optional and may live directly on the
 * session or on a nested `page`; getWaitFunction picks whichever is present.
 */
type SessionLike = {
id: string;
sessionViewerUrl?: string | null;
waitForSelector?: (
selector: string,
options?: { state?: WaitState; timeout?: number }
) => Promise<unknown>;
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
page?: {
waitForSelector?: (
selector: string,
options?: { state?: WaitState; timeout?: number }
) => Promise<unknown>;
evaluate?: <T>(fn: (...args: any[]) => T, ...args: any[]) => Promise<T>;
};
// Current URL, either as a plain string or as a (possibly async) getter.
url?: (() => Promise<string> | string) | string;
};
// Upper bound, in milliseconds, on the pause between polls in getWaitFunction's evaluate-based fallback.
const POLL_DELAY_MS = 100;
/**
 * Extends the shared session details with the page URL.
 *
 * @param session Session to describe.
 * @param url URL to attach to the details.
 * @returns The base session details plus a `url` field.
 */
function sessionDetails(session: SessionLike, url: string) {
  const base = baseSessionDetails(session);
  return { ...base, url };
}
/**
 * Validates and trims the selector parameter of the wait tool.
 *
 * @param rawSelector Selector as received from the tool parameters.
 * @returns The selector without surrounding whitespace.
 * @throws Error when the selector is not a string or is blank.
 */
function normalizeSelector(rawSelector?: string): string {
  if (typeof rawSelector === "string") {
    const cleaned = rawSelector.trim();
    if (cleaned.length > 0) {
      return cleaned;
    }
    throw new Error("selector cannot be empty.");
  }
  throw new Error("selector is required and must be a string.");
}
/**
 * Resolves the effective timeout for the wait tool.
 *
 * @param rawTimeout Optional caller-supplied timeout in milliseconds.
 * @returns The timeout produced by resolveToolTimeoutMs.
 */
function resolveTimeout(rawTimeout?: number): number {
  const effectiveTimeout = resolveToolTimeoutMs(rawTimeout);
  return effectiveTimeout;
}
/**
 * Maps the raw `state` parameter onto a supported wait state.
 *
 * @param rawState State requested by the caller, if any.
 * @returns "attached" when exactly that was requested; "visible" for anything else.
 */
function resolveState(rawState?: string): WaitState {
  return rawState === "attached" ? "attached" : "visible";
}
/**
 * Selects a selector-waiting strategy for the given session.
 *
 * Preference order: `session.waitForSelector`, `session.page.waitForSelector`,
 * then a polling fallback built on `evaluate` that re-checks the element until
 * it matches or the deadline passes.
 *
 * @param session Session to wait on.
 * @returns An async function that resolves once the selector reaches the
 *   requested state and rejects on timeout or abort.
 * @throws Error when the session supports neither `waitForSelector` nor `evaluate`.
 */
function getWaitFunction(session: SessionLike): ((selector: string, state: WaitState, timeoutMs: number, signal: AbortSignal | undefined) => Promise<void>) {
  if (typeof session.waitForSelector === "function") {
    return async (selector, state, timeoutMs, signal) => {
      throwIfAborted(signal);
      await withAbortSignal(
        session.waitForSelector?.(selector, { state, timeout: timeoutMs }) as Promise<unknown>,
        signal
      );
    };
  }
  if (typeof session.page?.waitForSelector === "function") {
    return async (selector, state, timeoutMs, signal) => {
      throwIfAborted(signal);
      await withAbortSignal(
        session.page?.waitForSelector?.(selector, { state, timeout: timeoutMs }) as Promise<unknown>,
        signal
      );
    };
  }

  const useSessionEvaluate = typeof session.evaluate === "function";
  if (!useSessionEvaluate && typeof session.page?.evaluate !== "function") {
    throw new Error("Session does not support selector waiting.");
  }

  // Executed in the browser page; it must not capture variables from this scope.
  const isSelectorMatched = (input: { selector: string; state: WaitState }): boolean => {
    const element = document.querySelector(input.selector);
    if (!element) {
      return false;
    }
    if (input.state === "attached") {
      return true;
    }
    const rect = element.getBoundingClientRect();
    const style = getComputedStyle(element);
    // The numeric opacity check also covers the literal "0" case.
    return (
      rect.width > 0 &&
      rect.height > 0 &&
      style.visibility !== "hidden" &&
      style.display !== "none" &&
      Number.parseFloat(style.opacity) > 0
    );
  };

  // Call `evaluate` as a method on its owner so `this` stays bound; a bare
  // call on an extracted reference breaks class-based sessions and pages.
  const runProbe = (input: { selector: string; state: WaitState }): Promise<boolean> =>
    (useSessionEvaluate
      ? session.evaluate?.(isSelectorMatched, input)
      : session.page?.evaluate?.(isSelectorMatched, input)) as Promise<boolean>;

  return async (selector, state, timeoutMs, signal) => {
    const deadline = Date.now() + timeoutMs;
    while (true) {
      throwIfAborted(signal);
      const isMatched = await withAbortSignal(runProbe({ selector, state }), signal);
      if (isMatched) {
        return;
      }
      if (Date.now() > deadline) {
        throw new Error("selector wait timed out");
      }
      // Pause between 10ms and POLL_DELAY_MS, shrinking as the deadline approaches.
      await sleepWithSignal(Math.min(POLL_DELAY_MS, Math.max(10, deadline - Date.now())), signal);
    }
  };
}
/**
 * Reads the current URL from the session.
 *
 * Sources are tried in order: `url` as a string, `url` as a getter, then a
 * `getCurrentUrl()` method if the object happens to have one. A source is
 * accepted only when it yields a string that is not blank; the accepted value
 * is returned untrimmed.
 *
 * @param session Session to read from.
 * @returns The first usable URL, or "unknown" when no source provides one.
 */
async function readSessionUrl(session: SessionLike): Promise<string> {
  const hasText = (candidate: unknown): candidate is string =>
    typeof candidate === "string" && candidate.trim().length > 0;

  const urlMember = session.url;
  if (typeof urlMember === "function") {
    // Invoke with the session as receiver so method-style getters keep `this`.
    const fromGetter = await urlMember.call(session);
    if (hasText(fromGetter)) {
      return fromGetter;
    }
  } else if (hasText(urlMember)) {
    return urlMember;
  }

  // `getCurrentUrl` is not declared on SessionLike, hence the structural cast.
  const fallbackGetter = (session as { getCurrentUrl?: () => Promise<string> | string }).getCurrentUrl;
  if (typeof fallbackGetter === "function") {
    const fromFallback = await fallbackGetter.call(session);
    if (hasText(fromFallback)) {
      return fromFallback;
    }
  }

  return "unknown";
}
/**
 * Builds the `steel_wait` tool definition.
 *
 * Executing the tool validates its parameters, obtains a session, reads the
 * session URL, and waits for the selector to reach the requested state using
 * the strategy chosen by getWaitFunction. Failures whose message looks like a
 * timeout are rethrown with a message naming the selector and the timeout.
 *
 * @param client Steel client used to obtain (or create) the browser session.
 * @returns The tool definition for `steel_wait`.
 */
export function waitTool(client: SteelClient): ToolDefinition<any, any> {
  const toolName = "steel_wait";

  const parameters = Type.Object({
    selector: Type.String({ description: "CSS selector to wait for" }),
    timeout: Type.Optional(
      Type.Integer({
        minimum: MIN_TOOL_TIMEOUT_MS,
        maximum: MAX_TOOL_TIMEOUT_MS,
        description: "Maximum milliseconds to wait for selector state",
      })
    ),
    state: Type.Optional(
      Type.Union([Type.Literal("attached"), Type.Literal("visible")], {
        description: "Selector state to wait for",
      })
    ),
  });

  return {
    name: toolName,
    label: "Wait",
    description: "Wait for an element state with timeout",
    parameters,
    async execute(
      _toolCallId: string,
      params: { selector?: string; timeout?: number; state?: WaitState },
      signal: AbortSignal | undefined,
      onUpdate: ToolProgressUpdater,
      _ctx: ExtensionContext
    ): Promise<{ content: Array<{ type: "text"; text: string }>; details: object }> {
      return withToolError(toolName, async () => {
        throwIfAborted(signal);

        const selector = normalizeSelector(params.selector);
        const timeoutMs = resolveTimeout(params.timeout);
        const state = resolveState(params.state);

        const session = (await withAbortSignal(client.getOrCreateSession(), signal)) as SessionLike;
        throwIfAborted(signal);

        // The URL is captured before waiting and reported as-is in the details.
        const url = await readSessionUrl(session);
        await emitProgress(onUpdate, toolName, `Waiting for ${selector} with state ${state}`);

        try {
          const waitForSelector = getWaitFunction(session);
          await waitForSelector(selector, state, timeoutMs, signal);
        } catch (error) {
          const isError = error instanceof Error;
          const message = String(isError ? error.message : "");
          // Normalize the various timeout wordings into one consistent message.
          if (/timed? ?out|timeout/i.test(message)) {
            throw new Error(`Timed out waiting for selector "${selector}" after ${timeoutMs}ms.`);
          }
          if (isError) {
            throw error;
          }
          throw new Error(`Failed to wait for selector "${selector}"`);
        }

        await emitProgress(onUpdate, toolName, `Matched ${selector}`);

        return {
          content: [{
            type: "text",
            text: `Selector matched: ${selector}`,
          }],
          details: {
            ...sessionDetails(session, url),
            selector,
            state,
            timeoutMs,
            success: true,
          },
        };
      }, signal);
    },
  };
}