From 1c49d451a3fe93ab6da93f5b172ad640cc267d79 Mon Sep 17 00:00:00 2001 From: luca cappa Date: Tue, 3 Dec 2024 17:50:52 -0800 Subject: [PATCH] code snippet provider --- Extension/src/LanguageServer/client.ts | 20 ++- .../copilotCompletionContextProvider.ts | 154 ++++++++++++++++++ Extension/src/LanguageServer/extension.ts | 11 ++ 3 files changed, 184 insertions(+), 1 deletion(-) create mode 100644 Extension/src/LanguageServer/copilotCompletionContextProvider.ts diff --git a/Extension/src/LanguageServer/client.ts b/Extension/src/LanguageServer/client.ts index ba353a858..bb2af899c 100644 --- a/Extension/src/LanguageServer/client.ts +++ b/Extension/src/LanguageServer/client.ts @@ -55,7 +55,7 @@ import { Location, TextEdit, WorkspaceEdit } from './commonTypes'; import * as configs from './configurations'; import { DataBinding } from './dataBinding'; import { cachedEditorConfigSettings, getEditorConfigSettings } from './editorConfig'; -import { CppSourceStr, clients, configPrefix, updateLanguageConfigurations, usesCrashHandler, watchForCrashes } from './extension'; +import { CppSourceStr, SnippetEntry, clients, configPrefix, updateLanguageConfigurations, usesCrashHandler, watchForCrashes } from './extension'; import { LocalizeStringParams, getLocaleId, getLocalizedString } from './localization'; import { PersistentFolderState, PersistentWorkspaceState } from './persistentState'; import { RequestCancelled, ServerCancelled, createProtocolFilter } from './protocolFilter'; @@ -554,6 +554,15 @@ export interface ProjectContextResult { fileContext: FileContextResult; } +export interface CompletionContextsResult { + context: SnippetEntry[]; +} + +export interface CompletionContextParams { + file: string; + caretOffset: number; +} + // Requests const PreInitializationRequest: RequestType = new RequestType('cpptools/preinitialize'); const InitializationRequest: RequestType = new RequestType('cpptools/initialize'); @@ -575,6 +584,7 @@ const ChangeCppPropertiesRequest: RequestType = const IncludesRequest: RequestType = new RequestType('cpptools/getIncludes'); const CppContextRequest: RequestType = new RequestType('cpptools/getChatContext'); const ProjectContextRequest: RequestType = new RequestType('cpptools/getProjectContext'); +const CompletionContextRequest: RequestType = new RequestType('cpptools/getCompletionContext'); // Notifications to the server const DidOpenNotification: NotificationType = new NotificationType('textDocument/didOpen'); @@ -807,6 +817,7 @@ export interface Client { getIncludes(maxDepth: number, token: vscode.CancellationToken): Promise; getChatContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise; getProjectContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise; + getCompletionContext(fileName: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise; } export function createClient(workspaceFolder?: vscode.WorkspaceFolder): Client { @@ -2249,6 +2260,12 @@ export class DefaultClient implements Client { () => this.languageClient.sendRequest(ProjectContextRequest, params, token), token); } + public async getCompletionContext(file: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise { + await withCancellation(this.ready, token); + return DefaultClient.withLspCancellationHandling( + () => this.languageClient.sendRequest(CompletionContextRequest, { file: file.toString(), caretOffset }, token), token); + } + /** * a Promise that can be awaited to know when it's ok to proceed. * @@ -4154,4 +4171,5 @@ class NullClient implements Client { getIncludes(maxDepth: number, token: vscode.CancellationToken): Promise { return Promise.resolve({} as GetIncludesResult); } getChatContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise { return Promise.resolve({} as ChatContextResult); } getProjectContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise { return Promise.resolve({} as ProjectContextResult); } + getCompletionContext(file: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise { return Promise.resolve({} as CompletionContextsResult); } } diff --git a/Extension/src/LanguageServer/copilotCompletionContextProvider.ts b/Extension/src/LanguageServer/copilotCompletionContextProvider.ts new file mode 100644 index 000000000..6ff27241a --- /dev/null +++ b/Extension/src/LanguageServer/copilotCompletionContextProvider.ts @@ -0,0 +1,154 @@ +/* -------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All Rights Reserved. + * See 'LICENSE' in the project root for license information. + * ------------------------------------------------------------------------------------------ */ +import * as vscode from 'vscode'; +import { DocumentSelector } from 'vscode-languageserver-protocol'; +import { getOutputChannelLogger, Logger } from '../logger'; +import * as telemetry from '../telemetry'; +import { getCopilotApi } from "./copilotProviders"; +import { clients } from './extension'; +import { CodeSnippet, CompletionContext, ContextProviderApiV1 } from './tmp/contextProviderV1'; + +// An ever growing cache of completion context snippets. //?? TODO Evict old entries. +const completionContextCache: Map = new Map(); +const cppDocumentSelector: DocumentSelector = [{ language: 'cpp' }, { language: 'c' }]; + +class DefaultValueFallback extends Error { + static readonly DefaultValue = "DefaultValue"; + constructor() { super(DefaultValueFallback.DefaultValue); } +} + +class CancellationError extends Error { + static readonly Cancelled = "Cancelled"; + constructor() { super(CancellationError.Cancelled); } +} + +let completionContextCancellation = new vscode.CancellationTokenSource(); + +// Mutually exclusive values for the kind of snippets. They either come from the cache, +// are computed, or the computation is taking too long and no cache is present. In the latter +// case, the cache is computed anyway while unblocking the execution flow returning undefined. +enum SnippetsKind { + Computed = 'computed', + CacheHit = 'cacheHit', + CacheMiss = 'cacheMiss' +} + +// Get the default value if the timeout expires, but throws an exception if the token is cancelled. +async function waitForCompletionWithTimeoutAndCancellation(promise: Promise, defaultValue: T | undefined, + timeout: number, token: vscode.CancellationToken): Promise<[T | undefined, SnippetsKind]> { + const defaultValuePromise = new Promise((resolve, reject) => setTimeout(() => { + if (token.isCancellationRequested) { + reject('DefaultValuePromise was cancelled'); + } else { + reject(new DefaultValueFallback()); + } + }, timeout)); + const cancellationPromise = new Promise((_, reject) => { + token.onCancellationRequested(() => { + reject(new CancellationError()); + }); + }); + let snippetsOrNothing: T | undefined; + try { + snippetsOrNothing = await Promise.race([promise, cancellationPromise, defaultValuePromise]); + } catch (e) { + if (e instanceof DefaultValueFallback) { + return [defaultValue, defaultValue !== undefined ? SnippetsKind.CacheHit : SnippetsKind.CacheMiss]; + } + + // Rethrow the error for cancellation cases. + throw e; + } + + return [snippetsOrNothing, SnippetsKind.Computed]; +} + +// Get the completion context with a timeout and a cancellation token. +// The cancellationToken indicates that the value should not be returned nor cached. +async function getCompletionContextWithCancellation(documentUri: string, caretOffset: number, + startTime: number, out: Logger, token: vscode.CancellationToken): Promise { + try { + const activeEditor: vscode.TextEditor | undefined = vscode.window.activeTextEditor; + if (!activeEditor || + activeEditor.document.uri.toString() !== vscode.Uri.parse(documentUri).toString()) { + return []; + } + + const snippets = await clients.ActiveClient.getCompletionContext(activeEditor.document.uri, caretOffset, token); + + const codeSnippets = snippets.context.map((item) => { + if (token.isCancellationRequested) { + throw new CancellationError(); + } + return { + importance: item.importance, uri: item.uri, value: item.text + }; + }); + + completionContextCache.set(documentUri, codeSnippets); + const duration: number = Date.now() - startTime; + out.appendLine(`Copilot: getCompletionContextWithCancellation(): Cached in [ms]: ${duration}`); + // //?? TODO Add telemetry for elapsed time. + + return codeSnippets; + } catch (e) { + const err = e as Error; + out.appendLine(`Copilot: getCompletionContextWithCancellation(): Error: '${err?.message}', stack '${err?.stack}`); + + // //?? TODO Add telemetry for failure. + return []; + } +} + +const timeBudgetFactor: number = 0.5; +const cppToolsResolver = { + async resolve(context: CompletionContext, copilotAborts: vscode.CancellationToken): Promise { + const startTime = Date.now(); + const out: Logger = getOutputChannelLogger(); + let snippetsKind: SnippetsKind = SnippetsKind.Computed; + try { + completionContextCancellation.cancel(); + completionContextCancellation = new vscode.CancellationTokenSource(); + const docUri = context.documentContext.uri; + const cachedValue: CodeSnippet[] | undefined = completionContextCache.get(docUri.toString()); + const snippetsPromise = getCompletionContextWithCancellation(docUri, + context.documentContext.offset, startTime, out, completionContextCancellation.token); + const [codeSnippets, kind] = await waitForCompletionWithTimeoutAndCancellation( + snippetsPromise, cachedValue, context.timeBudget * timeBudgetFactor, copilotAborts); + snippetsKind = kind; + // //?? TODO Add telemetry for Computed vs Cached. + + return codeSnippets ?? []; + } catch (e: any) { + if (e instanceof CancellationError) { + out.appendLine(`Copilot: getCompletionContext(): cancelled!`); + } + // //?? TODO Add telemetry for failure. + } finally { + const duration: number = Date.now() - startTime; + out.appendLine(`Copilot: getCompletionContext(): snippets retrieval (${snippetsKind.toString()}) elapsed time (ms): ${duration}`); + // //?? TODO Add telemetry for elapsed time. + } + + return []; + } +}; + +export async function registerCopilotContextProvider(): Promise { + try { + const isCustomSnippetProviderApiEnabled = await telemetry.isExperimentEnabled("CppToolsCustomSnippetsApi"); + if (isCustomSnippetProviderApiEnabled) { + const contextAPI = (await getCopilotApi() as any).getContextProviderAPI('v1') as ContextProviderApiV1; + contextAPI.registerContextProvider({ + id: 'cppTools', + selector: cppDocumentSelector, + resolver: cppToolsResolver + }); + } + } catch { + console.warn("Failed to register the Copilot Context Provider."); + // //?? TODO Add telemetry for failure. + } +} diff --git a/Extension/src/LanguageServer/extension.ts b/Extension/src/LanguageServer/extension.ts index 8bc64f82f..2be4e9842 100644 --- a/Extension/src/LanguageServer/extension.ts +++ b/Extension/src/LanguageServer/extension.ts @@ -23,6 +23,7 @@ import * as telemetry from '../telemetry'; import { Client, DefaultClient, DoxygenCodeActionCommandArguments, openFileVersions } from './client'; import { ClientCollection } from './clientCollection'; import { CodeActionDiagnosticInfo, CodeAnalysisDiagnosticIdentifiersAndUri, codeAnalysisAllFixes, codeAnalysisCodeToFixes, codeAnalysisFileToCodeActions } from './codeAnalysis'; +import { registerCopilotContextProvider } from './copilotCompletionContextProvider'; import { registerRelatedFilesProvider } from './copilotProviders'; import { CppBuildTaskProvider } from './cppBuildTaskProvider'; import { getCustomConfigProviders } from './customProviders'; @@ -34,6 +35,14 @@ import { CppSettings } from './settings'; import { LanguageStatusUI, getUI } from './ui'; import { makeLspRange, rangeEquals, showInstallCompilerWalkthrough } from './utils'; +export interface SnippetEntry { + uri: string; + text: string; + startLine: number; + endLine: number; + importance: number; +} + nls.config({ messageFormat: nls.MessageFormat.bundle, bundleFormat: nls.BundleFormat.standalone })(); const localize: nls.LocalizeFunc = nls.loadMessageBundle(); export const CppSourceStr: string = "C/C++"; @@ -264,6 +273,8 @@ export async function activate(): Promise { } await registerRelatedFilesProvider(); + + await registerCopilotContextProvider(); } export function updateLanguageConfigurations(): void {