Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion Extension/src/LanguageServer/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ import { Location, TextEdit, WorkspaceEdit } from './commonTypes';
import * as configs from './configurations';
import { DataBinding } from './dataBinding';
import { cachedEditorConfigSettings, getEditorConfigSettings } from './editorConfig';
import { CppSourceStr, clients, configPrefix, updateLanguageConfigurations, usesCrashHandler, watchForCrashes } from './extension';
import { CppSourceStr, SnippetEntry, clients, configPrefix, updateLanguageConfigurations, usesCrashHandler, watchForCrashes } from './extension';
import { LocalizeStringParams, getLocaleId, getLocalizedString } from './localization';
import { PersistentFolderState, PersistentWorkspaceState } from './persistentState';
import { RequestCancelled, ServerCancelled, createProtocolFilter } from './protocolFilter';
Expand Down Expand Up @@ -554,6 +554,15 @@ export interface ProjectContextResult {
fileContext: FileContextResult;
}

export interface CompletionContextsResult {
context: SnippetEntry[];
}

export interface CompletionContextParams {
file: string;
caretOffset: number;
}

// Requests
const PreInitializationRequest: RequestType<void, string, void> = new RequestType<void, string, void>('cpptools/preinitialize');
const InitializationRequest: RequestType<CppInitializationParams, void, void> = new RequestType<CppInitializationParams, void, void>('cpptools/initialize');
Expand All @@ -575,6 +584,7 @@ const ChangeCppPropertiesRequest: RequestType<CppPropertiesParams, void, void> =
const IncludesRequest: RequestType<GetIncludesParams, GetIncludesResult, void> = new RequestType<GetIncludesParams, GetIncludesResult, void>('cpptools/getIncludes');
const CppContextRequest: RequestType<TextDocumentIdentifier, ChatContextResult, void> = new RequestType<TextDocumentIdentifier, ChatContextResult, void>('cpptools/getChatContext');
const ProjectContextRequest: RequestType<TextDocumentIdentifier, ProjectContextResult, void> = new RequestType<TextDocumentIdentifier, ProjectContextResult, void>('cpptools/getProjectContext');
const CompletionContextRequest: RequestType<CompletionContextParams, CompletionContextsResult, void> = new RequestType<CompletionContextParams, CompletionContextsResult, void>('cpptools/getCompletionContext');

// Notifications to the server
const DidOpenNotification: NotificationType<DidOpenTextDocumentParams> = new NotificationType<DidOpenTextDocumentParams>('textDocument/didOpen');
Expand Down Expand Up @@ -807,6 +817,7 @@ export interface Client {
getIncludes(maxDepth: number, token: vscode.CancellationToken): Promise<GetIncludesResult>;
getChatContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise<ChatContextResult>;
getProjectContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise<ProjectContextResult>;
getCompletionContext(fileName: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise<CompletionContextsResult>;
}

export function createClient(workspaceFolder?: vscode.WorkspaceFolder): Client {
Expand Down Expand Up @@ -2249,6 +2260,12 @@ export class DefaultClient implements Client {
() => this.languageClient.sendRequest(ProjectContextRequest, params, token), token);
}

public async getCompletionContext(file: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise<CompletionContextsResult> {
await withCancellation(this.ready, token);
return DefaultClient.withLspCancellationHandling(
() => this.languageClient.sendRequest(CompletionContextRequest, { file: file.toString(), caretOffset }, token), token);
}

/**
* a Promise that can be awaited to know when it's ok to proceed.
*
Expand Down Expand Up @@ -4154,4 +4171,5 @@ class NullClient implements Client {
getIncludes(maxDepth: number, token: vscode.CancellationToken): Promise<GetIncludesResult> { return Promise.resolve({} as GetIncludesResult); }
getChatContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise<ChatContextResult> { return Promise.resolve({} as ChatContextResult); }
getProjectContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise<ProjectContextResult> { return Promise.resolve({} as ProjectContextResult); }
getCompletionContext(file: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise<CompletionContextsResult> { return Promise.resolve({} as CompletionContextsResult); }
}
154 changes: 154 additions & 0 deletions Extension/src/LanguageServer/copilotCompletionContextProvider.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/* --------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All Rights Reserved.
* See 'LICENSE' in the project root for license information.
* ------------------------------------------------------------------------------------------ */
import * as vscode from 'vscode';
import { DocumentSelector } from 'vscode-languageserver-protocol';
import { getOutputChannelLogger, Logger } from '../logger';
import * as telemetry from '../telemetry';
import { getCopilotApi } from "./copilotProviders";
import { clients } from './extension';
import { CodeSnippet, CompletionContext, ContextProviderApiV1 } from './tmp/contextProviderV1';

// An ever growing cache of completion context snippets. //?? TODO Evict old entries.
const completionContextCache: Map<string, CodeSnippet[]> = new Map<string, CodeSnippet[]>();
const cppDocumentSelector: DocumentSelector = [{ language: 'cpp' }, { language: 'c' }];

class DefaultValueFallback extends Error {
static readonly DefaultValue = "DefaultValue";
constructor() { super(DefaultValueFallback.DefaultValue); }
}

class CancellationError extends Error {
static readonly Cancelled = "Cancelled";
constructor() { super(CancellationError.Cancelled); }
}

let completionContextCancellation = new vscode.CancellationTokenSource();

// Mutually exclusive values for the kind of snippets. They either come from the cache,
// are computed, or the computation is taking too long and no cache is present. In the latter
// case, the cache is computed anyway while unblocking the execution flow returning undefined.
enum SnippetsKind {
Computed = 'computed',
CacheHit = 'cacheHit',
CacheMiss = 'cacheMiss'
}

// Get the default value if the timeout expires, but throws an exception if the token is cancelled.
async function waitForCompletionWithTimeoutAndCancellation<T>(promise: Promise<T>, defaultValue: T | undefined,
timeout: number, token: vscode.CancellationToken): Promise<[T | undefined, SnippetsKind]> {
const defaultValuePromise = new Promise<T>((resolve, reject) => setTimeout(() => {
if (token.isCancellationRequested) {
reject('DefaultValuePromise was cancelled');
} else {
reject(new DefaultValueFallback());
}
}, timeout));
const cancellationPromise = new Promise<T>((_, reject) => {
token.onCancellationRequested(() => {
reject(new CancellationError());
});
});
let snippetsOrNothing: T | undefined;
try {
snippetsOrNothing = await Promise.race([promise, cancellationPromise, defaultValuePromise]);
} catch (e) {
if (e instanceof DefaultValueFallback) {
return [defaultValue, defaultValue !== undefined ? SnippetsKind.CacheHit : SnippetsKind.CacheMiss];
}

// Rethrow the error for cancellation cases.
throw e;
}

return [snippetsOrNothing, SnippetsKind.Computed];
}

// Get the completion context with a timeout and a cancellation token.
// The cancellationToken indicates that the value should not be returned nor cached.
async function getCompletionContextWithCancellation(documentUri: string, caretOffset: number,
startTime: number, out: Logger, token: vscode.CancellationToken): Promise<CodeSnippet[]> {
try {
const activeEditor: vscode.TextEditor | undefined = vscode.window.activeTextEditor;
if (!activeEditor ||
activeEditor.document.uri.toString() !== vscode.Uri.parse(documentUri).toString()) {
return [];
}

const snippets = await clients.ActiveClient.getCompletionContext(activeEditor.document.uri, caretOffset, token);

const codeSnippets = snippets.context.map((item) => {
if (token.isCancellationRequested) {
throw new CancellationError();
}
return {
importance: item.importance, uri: item.uri, value: item.text
};
});

completionContextCache.set(documentUri, codeSnippets);
const duration: number = Date.now() - startTime;
out.appendLine(`Copilot: getCompletionContextWithCancellation(): Cached in [ms]: ${duration}`);
// //?? TODO Add telemetry for elapsed time.

return codeSnippets;
} catch (e) {
const err = e as Error;
out.appendLine(`Copilot: getCompletionContextWithCancellation(): Error: '${err?.message}', stack '${err?.stack}`);

// //?? TODO Add telemetry for failure.
return [];
}
}

const timeBudgetFactor: number = 0.5;
const cppToolsResolver = {
async resolve(context: CompletionContext, copilotAborts: vscode.CancellationToken): Promise<CodeSnippet[]> {
const startTime = Date.now();
const out: Logger = getOutputChannelLogger();
let snippetsKind: SnippetsKind = SnippetsKind.Computed;
try {
completionContextCancellation.cancel();
completionContextCancellation = new vscode.CancellationTokenSource();
const docUri = context.documentContext.uri;
const cachedValue: CodeSnippet[] | undefined = completionContextCache.get(docUri.toString());
const snippetsPromise = getCompletionContextWithCancellation(docUri,
context.documentContext.offset, startTime, out, completionContextCancellation.token);
const [codeSnippets, kind] = await waitForCompletionWithTimeoutAndCancellation(
snippetsPromise, cachedValue, context.timeBudget * timeBudgetFactor, copilotAborts);
snippetsKind = kind;
// //?? TODO Add telemetry for Computed vs Cached.

return codeSnippets ?? [];
} catch (e: any) {
if (e instanceof CancellationError) {
out.appendLine(`Copilot: getCompletionContext(): cancelled!`);
}
// //?? TODO Add telemetry for failure.
} finally {
const duration: number = Date.now() - startTime;
out.appendLine(`Copilot: getCompletionContext(): snippets retrieval (${snippetsKind.toString()}) elapsed time (ms): ${duration}`);
// //?? TODO Add telemetry for elapsed time.
}

return [];
}
};

export async function registerCopilotContextProvider(): Promise<void> {
try {
const isCustomSnippetProviderApiEnabled = await telemetry.isExperimentEnabled("CppToolsCustomSnippetsApi");
if (isCustomSnippetProviderApiEnabled) {
const contextAPI = (await getCopilotApi() as any).getContextProviderAPI('v1') as ContextProviderApiV1;
contextAPI.registerContextProvider({
id: 'cppTools',
selector: cppDocumentSelector,
resolver: cppToolsResolver
});
}
} catch {
console.warn("Failed to register the Copilot Context Provider.");
// //?? TODO Add telemetry for failure.
}
}
11 changes: 11 additions & 0 deletions Extension/src/LanguageServer/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import * as telemetry from '../telemetry';
import { Client, DefaultClient, DoxygenCodeActionCommandArguments, openFileVersions } from './client';
import { ClientCollection } from './clientCollection';
import { CodeActionDiagnosticInfo, CodeAnalysisDiagnosticIdentifiersAndUri, codeAnalysisAllFixes, codeAnalysisCodeToFixes, codeAnalysisFileToCodeActions } from './codeAnalysis';
import { registerCopilotContextProvider } from './copilotCompletionContextProvider';
import { registerRelatedFilesProvider } from './copilotProviders';
import { CppBuildTaskProvider } from './cppBuildTaskProvider';
import { getCustomConfigProviders } from './customProviders';
Expand All @@ -34,6 +35,14 @@ import { CppSettings } from './settings';
import { LanguageStatusUI, getUI } from './ui';
import { makeLspRange, rangeEquals, showInstallCompilerWalkthrough } from './utils';

export interface SnippetEntry {
uri: string;
text: string;
startLine: number;
endLine: number;
importance: number;
}

nls.config({ messageFormat: nls.MessageFormat.bundle, bundleFormat: nls.BundleFormat.standalone })();
const localize: nls.LocalizeFunc = nls.loadMessageBundle();
export const CppSourceStr: string = "C/C++";
Expand Down Expand Up @@ -264,6 +273,8 @@ export async function activate(): Promise<void> {
}

await registerRelatedFilesProvider();

await registerCopilotContextProvider();
}

export function updateLanguageConfigurations(): void {
Expand Down
Loading