diff --git a/.changeset/dull-parrots-burn.md b/.changeset/dull-parrots-burn.md new file mode 100644 index 0000000000..2c04c21bb1 --- /dev/null +++ b/.changeset/dull-parrots-burn.md @@ -0,0 +1,7 @@ +--- +"@hyperdx/common-utils": patch +"@hyperdx/api": patch +"@hyperdx/app": patch +--- + +Add better types for AI features, Fix bug that could cause page crash when generating graphs diff --git a/packages/api/src/controllers/ai.ts b/packages/api/src/controllers/ai.ts index e8c5d5dab6..fab4603926 100644 --- a/packages/api/src/controllers/ai.ts +++ b/packages/api/src/controllers/ai.ts @@ -1,9 +1,26 @@ import { createAnthropic } from '@ai-sdk/anthropic'; +import { ClickhouseClient } from '@hyperdx/common-utils/dist/clickhouse/node'; +import { + getMetadata, + TableMetadata, +} from '@hyperdx/common-utils/dist/core/metadata'; +import { + AILineTableResponse, + AssistantLineTableConfigSchema, + ChartConfigWithDateRange, +} from '@hyperdx/common-utils/dist/types'; import type { LanguageModel } from 'ai'; +import * as chrono from 'chrono-node'; +import ms from 'ms'; +import z from 'zod'; import * as config from '@/config'; +import { ISource } from '@/models/source'; +import { Api500Error } from '@/utils/errors'; import logger from '@/utils/logger'; +import { getConnectionById } from './connection'; + /** * Get configured AI model for use in the application. * Currently supports Anthropic (with both direct API and Azure AI endpoints). 
@@ -55,6 +72,305 @@ export function getAIModel(): LanguageModel {
   }
 }
 
+/**
+ * Gather the table context given to the AI assistant for a source.
+ *
+ * Loads the table metadata and all fields of the source table, orders the
+ * fields (primary-key fields first, then LowCardinality fields, then the
+ * rest) and fetches key values for the first 30 fields over the last 60
+ * minutes.
+ *
+ * @param source Source whose connection and table are inspected.
+ * @returns The ordered fields, the same fields with their merged `key`, and
+ *   the result of `metadata.getKeyValues`.
+ * @throws Api500Error when the connection referenced by the source is missing.
+ */
+export async function getAIMetadata(source: ISource) {
+  const connectionId = source.connection.toString();
+
+  const connection = await getConnectionById(
+    source.team.toString(),
+    connectionId,
+  );
+
+  if (connection == null) {
+    throw new Api500Error('Invalid connection');
+  }
+
+  const clickhouseClient = new ClickhouseClient({
+    host: connection.host,
+    username: connection.username,
+    password: connection.password,
+  });
+
+  const metadata = getMetadata(clickhouseClient);
+
+  const databaseName = source.from.databaseName;
+  const tableName = source.from.tableName;
+
+  const tableMetadata = await metadata.getTableMetadata({
+    databaseName,
+    tableName,
+    connectionId,
+  });
+
+  const allFields = await metadata.getAllFields({
+    databaseName,
+    tableName,
+    connectionId,
+  });
+
+  // TODO: Dedup with DBSearchPageFilters.tsx logic
+  // TODO: Support JSON
+  // Rank every field: 0 = primary key, 1 = LowCardinality, 2 = everything else.
+  // Sorting by rank difference keeps the comparator consistent (equally ranked
+  // fields compare as 0), so the stable sort preserves their relative order.
+  const fieldRank = (field: (typeof allFields)[number]): number => {
+    if (isFieldPrimary(tableMetadata, mergePath(field.path, []))) {
+      return 0;
+    }
+    return field.type.includes('LowCardinality') ? 1 : 2;
+  };
+  allFields.sort((a, b) => fieldRank(a) - fieldRank(b));
+
+  const allFieldsWithKeys = allFields.map(f => ({
+    ...f,
+    key: mergePath(f.path),
+  }));
+  const keysToFetch = allFieldsWithKeys.slice(0, 30);
+  const cc: ChartConfigWithDateRange = {
+    select: '',
+    from: {
+      databaseName,
+      tableName,
+    },
+    connection: connectionId,
+    where: '',
+    groupBy: '',
+    timestampValueExpression: source.timestampValueExpression,
+    dateRange: [new Date(Date.now() - ms('60m')), new Date()],
+  };
+  const keyValues = await metadata.getKeyValues({
+    chartConfig: cc,
+    keys: keysToFetch.map(f => f.key),
+    source,
+  });
+
+  return {
+    allFields,
+    allFieldsWithKeys,
+    keyValues,
+  };
+}
+
+/**
+ * Convert chrono parsed components to a Date, moving the result back one year
+ * when the year was inferred and the date would otherwise be in the future.
+ *
+ * @param parsed Components from a chrono parse result, if any.
+ * @returns The normalized Date, or `null` when `parsed` is missing.
+ */
+function normalizeParsedDate(parsed?: chrono.ParsedComponents): Date | null {
+  if (!parsed) {
+    return null;
+  }
+
+  if (parsed.isCertain('year')) {
+    return parsed.date();
+  }
+
+  const now = new Date();
+  if (
+    !(
+      parsed.isCertain('hour') ||
+      parsed.isCertain('minute') ||
+      parsed.isCertain('second') ||
+      parsed.isCertain('millisecond')
+    )
+  ) {
+    // If all of the time components have been inferred, set the time components of now
+    // to match the parsed time components. This ensures that the comparison later on uses
+    // the same point in time when only worrying about dates.
+    now.setHours(parsed.get('hour') || 0);
+    now.setMinutes(parsed.get('minute') || 0);
+    now.setSeconds(parsed.get('second') || 0);
+    now.setMilliseconds(parsed.get('millisecond') || 0);
+  }
+
+  const parsedDate = parsed.date();
+  if (parsedDate > now) {
+    parsedDate.setFullYear(parsedDate.getFullYear() - 1);
+  }
+  return parsedDate;
+}
+
+/**
+ * Parse a free-form time range string with chrono.
+ *
+ * @param str Text such as "Past 1h".
+ * @param isUTC When true, chrono parses with `timezone: 0`.
+ * @returns `[start, end]` ordered from older to newer; `[null, null]` when
+ *   nothing could be parsed. A missing end defaults to the current time.
+ */
+export function parseTimeRangeInput(
+  str: string,
+  isUTC: boolean = false,
+): [Date | null, Date | null] {
+  const parsedTimeResults = chrono.parse(str, isUTC ? { timezone: 0 } : {});
+  if (parsedTimeResults.length === 0) {
+    return [null, null];
+  }
+
+  const parsedTimeResult =
+    parsedTimeResults.length === 1
+      ? parsedTimeResults[0]
+      : parsedTimeResults[1];
+  const start = normalizeParsedDate(parsedTimeResult.start);
+  const end = normalizeParsedDate(parsedTimeResult.end) || new Date();
+  if (end && start && end < start) {
+    // For date range strings that omit years, the chrono parser will infer the year
+    // using the current year. This can cause the start date to be in the future, and
+    // returned as the end date instead of the start date. After normalizing the dates,
+    // we then need to swap the order to maintain a range from older to newer.
+    return [end, start];
+  } else {
+    return [start, end];
+  }
+}
+
+export const LIVE_TAIL_TIME_QUERY = 'Live Tail';
+
+export const RELATIVE_TIME_OPTIONS: ([string, string] | 'divider')[] = [
+  // ['Last 15 seconds', '15s'],
+  // ['Last 30 seconds', '30s'],
+  // 'divider',
+  ['Last 1 minute', '1m'],
+  ['Last 5 minutes', '5m'],
+  ['Last 15 minutes', '15m'],
+  ['Last 30 minutes', '30m'],
+  ['Last 45 minutes', '45m'],
+  'divider',
+  ['Last 1 hour', '1h'],
+  ['Last 3 hours', '3h'],
+  ['Last 6 hours', '6h'],
+  ['Last 12 hours', '12h'],
+  'divider',
+  ['Last 1 days', '1d'],
+  ['Last 2 days', '2d'],
+  ['Last 7 days', '7d'],
+  ['Last 14 days', '14d'],
+  ['Last 30 days', '30d'],
+];
+
+export const DURATION_OPTIONS = [
+  '30s',
+  '1m',
+  '5m',
+  '15m',
+  '30m',
+  '1h',
+  '3h',
+  '6h',
+  '12h',
+];
+
+export const DURATIONS: Record<
+  string,
+  { seconds?: number; minutes?: number; hours?: number }
+> = {
+  '30s': { seconds: 30 },
+  '1m': { minutes: 1 },
+  '5m': { minutes: 5 },
+  '15m': { minutes: 15 },
+  '30m': { minutes: 30 },
+  '1h': { hours: 1 },
+  '3h': { hours: 3 },
+  '6h': { hours: 6 },
+  '12h': { hours: 12 },
+};
+
+/** Parse the first date chrono finds in `input`; `null` when there is none. */
+export const dateParser = (input?: string) => {
+  if (!input) {
+    return null;
+  }
+  const parsed = chrono.casual.parse(input)[0];
+  return normalizeParsedDate(parsed?.start);
+};
+
+// TODO: Dedup from DBSearchPageFilters
+function isFieldPrimary(tableMetadata: TableMetadata | undefined, key: string) {
+  return tableMetadata?.primary_key?.includes(key);
+}
+
+// TODO: Dedup w/ app/src/utils.ts
+// Build the column expression for a field path: paths under a JSON column use
+// dotted, backtick-quoted segments; other nested paths use ['key'] accessors.
+export const mergePath = (path: string[], jsonColumns: string[] = []) => {
+  const [key, ...rest] = path;
+  if (rest.length === 0) {
+    return key;
+  }
+  return jsonColumns.includes(key)
+    ? `${key}.${rest
+        .map(v =>
+          v
+            .split('.')
+            .map(v => (v.startsWith('`') && v.endsWith('`') ? v : `\`${v}\``))
+            .join('.'),
+        )
+        .join('.')}`
+    : `${key}['${rest.join("']['")}']`;
+};
+
+/**
+ * Turn the assistant's structured answer into a chart config for `source`.
+ *
+ * @param resObject Object produced against `AssistantLineTableConfigSchema`.
+ * @param source Source the chart queries.
+ * @returns Line/table chart config; `dateRange` holds ISO 8601 strings and
+ *   falls back to the last hour when the time range cannot be parsed.
+ */
+export function getChartConfigFromResolvedConfig(
+  resObject: z.infer<typeof AssistantLineTableConfigSchema>,
+  source: ISource,
+): AILineTableResponse {
+  const parsedTimeRange = parseTimeRangeInput(resObject.timeRange);
+  // TODO: More robust recovery logic
+  const dateRange: [Date, Date] = [
+    parsedTimeRange[0] ?? new Date(Date.now() - ms('1h')),
+    parsedTimeRange[1] ?? new Date(),
+  ];
+
+  return {
+    displayType: resObject.displayType,
+    select: resObject.select.map(s => ({
+      aggFn: s.aggregationFunction,
+      valueExpression: s.property,
+      ...(s.condition
+        ? {
+            aggCondition: s.condition,
+            aggConditionLanguage: 'sql',
+          }
+        : {}),
+    })),
+    from: {
+      tableName: source.from.tableName,
+      databaseName: source.from.databaseName,
+    },
+    source: source.id,
+    connection: source.connection.toString(),
+    groupBy: resObject.groupBy,
+    timestampValueExpression: source.timestampValueExpression,
+    // Serialize as ISO 8601: unlike Date#toString() it keeps milliseconds and
+    // does not depend on the server's time zone.
+    dateRange: [dateRange[0].toISOString(), dateRange[1].toISOString()],
+    markdown: resObject.markdown,
+    granularity: 'auto',
+    whereLanguage: 'lucene',
+  };
+}
+
 /**
  * Configure Anthropic model.
  * Supports both direct Anthropic API and Azure AI Anthropic endpoints.
diff --git a/packages/api/src/routers/api/ai.ts b/packages/api/src/routers/api/ai.ts index a6d7f2f93e..c3e257d75a 100644 --- a/packages/api/src/routers/api/ai.ts +++ b/packages/api/src/routers/api/ai.ts @@ -1,171 +1,25 @@ -import { ClickhouseClient } from '@hyperdx/common-utils/dist/clickhouse/node'; import { - getMetadata, - TableMetadata, -} from '@hyperdx/common-utils/dist/core/metadata'; -import { - AggregateFunctionSchema, - ChartConfigWithDateRange, - DisplayType, + AssistantLineTableConfigSchema, SourceKind, } from '@hyperdx/common-utils/dist/types'; -import { generateObject } from 'ai'; -import * as chrono from 'chrono-node'; +import { APICallError, generateObject } from 'ai'; import express from 'express'; -import ms from 'ms'; import { z } from 'zod'; import { validateRequest } from 'zod-express-middleware'; -import * as config from '@/config'; -import { getAIModel } from '@/controllers/ai'; -import { getConnectionById } from '@/controllers/connection'; +import { + getAIMetadata, + getAIModel, + getChartConfigFromResolvedConfig, +} from '@/controllers/ai'; import { getSource } from '@/controllers/sources'; import { getNonNullUserWithTeam } from '@/middleware/auth'; +import { Api404Error, Api500Error } from '@/utils/errors'; import logger from '@/utils/logger'; import { objectIdSchema } from '@/utils/zod'; const router = express.Router(); -function normalizeParsedDate(parsed?: chrono.ParsedComponents): Date | null { - if (!parsed) { - return null; - } - - if (parsed.isCertain('year')) { - return parsed.date(); - } - - const now = new Date(); - if ( - !( - parsed.isCertain('hour') || - parsed.isCertain('minute') || - parsed.isCertain('second') || - parsed.isCertain('millisecond') - ) - ) { - // If all of the time components have been inferred, set the time components of now - // to match the parsed time components. This ensures that the comparison later on uses - // the same point in time when only worrying about dates. 
- now.setHours(parsed.get('hour') || 0); - now.setMinutes(parsed.get('minute') || 0); - now.setSeconds(parsed.get('second') || 0); - now.setMilliseconds(parsed.get('millisecond') || 0); - } - - const parsedDate = parsed.date(); - if (parsedDate > now) { - parsedDate.setFullYear(parsedDate.getFullYear() - 1); - } - return parsedDate; -} - -export function parseTimeRangeInput( - str: string, - isUTC: boolean = false, -): [Date | null, Date | null] { - const parsedTimeResults = chrono.parse(str, isUTC ? { timezone: 0 } : {}); - if (parsedTimeResults.length === 0) { - return [null, null]; - } - - const parsedTimeResult = - parsedTimeResults.length === 1 - ? parsedTimeResults[0] - : parsedTimeResults[1]; - const start = normalizeParsedDate(parsedTimeResult.start); - const end = normalizeParsedDate(parsedTimeResult.end) || new Date(); - if (end && start && end < start) { - // For date range strings that omit years, the chrono parser will infer the year - // using the current year. This can cause the start date to be in the future, and - // returned as the end date instead of the start date. After normalizing the dates, - // we then need to swap the order to maintain a range from older to newer. 
- return [end, start]; - } else { - return [start, end]; - } -} - -export const LIVE_TAIL_TIME_QUERY = 'Live Tail'; - -export const RELATIVE_TIME_OPTIONS: ([string, string] | 'divider')[] = [ - // ['Last 15 seconds', '15s'], - // ['Last 30 seconds', '30s'], - // 'divider', - ['Last 1 minute', '1m'], - ['Last 5 minutes', '5m'], - ['Last 15 minutes', '15m'], - ['Last 30 minutes', '30m'], - ['Last 45 minutes', '45m'], - 'divider', - ['Last 1 hour', '1h'], - ['Last 3 hours', '3h'], - ['Last 6 hours', '6h'], - ['Last 12 hours', '12h'], - 'divider', - ['Last 1 days', '1d'], - ['Last 2 days', '2d'], - ['Last 7 days', '7d'], - ['Last 14 days', '14d'], - ['Last 30 days', '30d'], -]; - -export const DURATION_OPTIONS = [ - '30s', - '1m', - '5m', - '15m', - '30m', - '1h', - '3h', - '6h', - '12h', -]; - -export const DURATIONS: Record = { - '30s': { seconds: 30 }, - '1m': { minutes: 1 }, - '5m': { minutes: 5 }, - '15m': { minutes: 15 }, - '30m': { minutes: 30 }, - '1h': { hours: 1 }, - '3h': { hours: 3 }, - '6h': { hours: 6 }, - '12h': { hours: 12 }, -}; - -export const dateParser = (input?: string) => { - if (!input) { - return null; - } - const parsed = chrono.casual.parse(input)[0]; - return normalizeParsedDate(parsed?.start); -}; - -// TODO: Dedup from DBSearchPageFilters -function isFieldPrimary(tableMetadata: TableMetadata | undefined, key: string) { - return tableMetadata?.primary_key?.includes(key); -} - -// TODO: Dedup w/ app/src/utils.ts -// Date formatting -export const mergePath = (path: string[], jsonColumns: string[] = []) => { - const [key, ...rest] = path; - if (rest.length === 0) { - return key; - } - return jsonColumns.includes(key) - ? `${key}.${rest - .map(v => - v - .split('.') - .map(v => (v.startsWith('`') && v.endsWith('`') ? 
v : `\`${v}\``)) - .join('.'), - ) - .join('.')}` - : `${key}['${rest.join("']['")}']`; -}; - router.post( '/assistant', validateRequest({ @@ -176,6 +30,8 @@ router.post( }), async (req, res, next) => { try { + const model = getAIModel(); + const { teamId } = getNonNullUserWithTeam(req); const { text, sourceId } = req.body; @@ -183,97 +39,11 @@ router.post( const source = await getSource(teamId.toString(), sourceId); if (source == null) { - logger.error({ sourceId, teamId }, 'invalid source id'); - return res.status(400).json({ - error: 'Invalid source', - }); - } - - const connectionId = source.connection.toString(); - - const connection = await getConnectionById( - teamId.toString(), - connectionId, - ); - - if (connection == null) { - logger.error({ - message: 'invalid connection id', - connectionId, - teamId, - }); - return res.status(400).json({ - error: 'Invalid connection', - }); + logger.error({ message: 'invalid source id', sourceId, teamId }); + throw new Api404Error('Invalid source'); } - const clickhouseClient = new ClickhouseClient({ - host: connection.host, - username: connection.username, - password: connection.password, - }); - const metadata = getMetadata(clickhouseClient); - - const databaseName = source.from.databaseName; - const tableName = source.from.tableName; - - const tableMetadata = await metadata.getTableMetadata({ - databaseName, - tableName, - connectionId, - }); - - const allFields = await metadata.getAllFields({ - databaseName, - tableName, - connectionId, - }); - - // TODO: Dedup with DBSearchPageFilters.tsx logic - allFields.sort((a, b) => { - // Prioritize primary keys - // TODO: Support JSON - const aPath = mergePath(a.path, []); - const bPath = mergePath(b.path, []); - if (isFieldPrimary(tableMetadata, aPath)) { - return -1; // TODO: Check sort order - } else if (isFieldPrimary(tableMetadata, bPath)) { - return 1; - } - - //First show low cardinality fields - const isLowCardinality = (type: string) => - 
type.includes('LowCardinality'); - return isLowCardinality(a.type) && !isLowCardinality(b.type) ? -1 : 1; - }); - - const allFieldsWithKeys = allFields.map(f => { - return { - ...f, - key: mergePath(f.path), - }; - }); - const keysToFetch = allFieldsWithKeys.slice(0, 30); - const cc: ChartConfigWithDateRange = { - select: '', - from: { - databaseName, - tableName, - }, - connection: connectionId, - where: '', - groupBy: '', - timestampValueExpression: source.timestampValueExpression, - dateRange: [new Date(Date.now() - ms('60m')), new Date()], - }; - const keyValues = await metadata.getKeyValues({ - chartConfig: cc, - keys: keysToFetch.map(f => f.key), - source, - }); - - // Get configured AI model - const model = getAIModel(); + const { allFieldsWithKeys, keyValues } = await getAIMetadata(source); const prompt = `You are an AI assistant that helps users create chart configurations for an observability platform called HyperDX. @@ -310,80 +80,26 @@ ${JSON.stringify(allFieldsWithKeys.slice(0, 200).map(f => ({ field: f.key, type: logger.info(prompt); - const result = await generateObject({ - model, - schema: z.object({ - displayType: z.enum([DisplayType.Line, DisplayType.Table]), - select: z - .array( - // @ts-ignore - z.object({ - // TODO: Change percentile to fixed functions - aggregationFunction: AggregateFunctionSchema.describe( - 'SQL-like function to aggregate the property by', - ), - property: z - .string() - .describe( - 'Property or column to be aggregated (ex. Duration)', - ), - condition: z - .string() - .optional() - .describe( - "SQL filter condition to filter on ex. 
`SeverityText = 'error'`", - ), - }), - ) - .describe('Array of data series or columns to chart for the user'), - groupBy: z - .string() - .optional() - .describe('Group by column or properties for the chart'), - timeRange: z - .string() - .default('Past 1h') - .describe( - 'Time range of data to query for like "Past 1h", "Past 24h"', - ), - }), - prompt, - }); - - const resObject = result.object; - const parsedTimeRange = parseTimeRangeInput(resObject.timeRange); - // TODO: More robust recovery logic - const dateRange: [Date, Date] = [ - parsedTimeRange[0] ?? new Date(Date.now() - ms('1h')), - parsedTimeRange[1] ?? new Date(), - ]; + try { + const result = await generateObject({ + model, + schema: AssistantLineTableConfigSchema, + experimental_telemetry: { isEnabled: true }, + prompt, + }); - const chartConfig: ChartConfigWithDateRange & { source: string } = { - displayType: resObject.displayType, - select: resObject.select.map(s => ({ - aggFn: s.aggregationFunction, - valueExpression: s.property, - ...(s.condition - ? 
{ - aggCondition: s.condition, - aggConditionLanguage: 'sql', - } - : {}), - })), - from: { - tableName: source.from.tableName, - databaseName: source.from.databaseName, - }, - source: sourceId, - connection: connectionId, - where: '', - groupBy: resObject.groupBy, - timestampValueExpression: source.timestampValueExpression, - dateRange, - granularity: 'auto', - }; + const chartConfig = getChartConfigFromResolvedConfig( + result.object, + source, + ); - return res.json(chartConfig); + return res.json(chartConfig); + } catch (err) { + if (err instanceof APICallError) { + throw new Api500Error(`AI Provider Error: ${err.message}`); + } + throw err; + } } catch (e) { next(e); } diff --git a/packages/app/src/DBChartPage.tsx b/packages/app/src/DBChartPage.tsx index 6e530870c9..496eff2bde 100644 --- a/packages/app/src/DBChartPage.tsx +++ b/packages/app/src/DBChartPage.tsx @@ -77,11 +77,9 @@ function AIAssistant({ }, { onSuccess(data) { - setConfig(data); + setConfig({ ...data, where: '' }); onTimeRangeSelect( - // @ts-ignore TODO: fix these types new Date(data.dateRange[0]), - // @ts-ignore TODO: fix these types new Date(data.dateRange[1]), ); diff --git a/packages/app/src/hooks/ai.ts b/packages/app/src/hooks/ai.ts index 2e6dffb073..177a1aba3c 100644 --- a/packages/app/src/hooks/ai.ts +++ b/packages/app/src/hooks/ai.ts @@ -1,4 +1,7 @@ -import type { SavedChartConfig } from '@hyperdx/common-utils/dist/types'; +import type { + AILineTableResponse, + SavedChartConfig, +} from '@hyperdx/common-utils/dist/types'; import { useMutation } from '@tanstack/react-query'; import { hdxServer } from '@/api'; @@ -9,11 +12,11 @@ type AssistantInput = { }; export function useChartAssistant() { - return useMutation({ + return useMutation({ mutationFn: async ({ sourceId, text }: AssistantInput) => hdxServer('ai/assistant', { method: 'POST', json: { sourceId, text }, - }).json(), + }).json(), }); } diff --git a/packages/common-utils/src/types.ts b/packages/common-utils/src/types.ts index 
891a407ff2..8e817f92cd 100644
--- a/packages/common-utils/src/types.ts
+++ b/packages/common-utils/src/types.ts
@@ -491,7 +491,7 @@ export type ChartConfigWithOptDateRange = Omit<
 
 export const SavedChartConfigSchema = z
   .object({
-    name: z.string(),
+    name: z.string().optional(),
     source: z.string(),
     alert: z.union([
       AlertBaseSchema.optional(),
@@ -833,3 +833,109 @@ type TSourceWithoutDefaults = FlattenUnion>;
 export type TSource = TSourceWithoutDefaults & {
   timestampValueExpression: string;
 };
+
+/**
+ * Shape the AI model is asked to generate for a line or table chart. The
+ * `.describe()` texts are part of the schema handed to the model.
+ */
+export const AssistantLineTableConfigSchema = z.object({
+  displayType: z.enum([DisplayType.Line, DisplayType.Table]),
+  markdown: z.string().optional(),
+  select: z
+    .array(
+      z.object({
+        // TODO: Change percentile to fixed functions
+        aggregationFunction: AggregateFunctionSchema.describe(
+          'SQL-like function to aggregate the property by',
+        ),
+        property: z
+          .string()
+          .describe('Property or column to be aggregated (ex. Duration)'),
+        condition: z
+          .string()
+          .optional()
+          .describe(
+            "SQL filter condition to filter on ex. `SeverityText = 'error'`",
+          ),
+      }),
+    )
+    .describe('Array of data series or columns to chart for the user'),
+  groupBy: z
+    .string()
+    .optional()
+    .describe('Group by column or properties for the chart'),
+  timeRange: z
+    .string()
+    .default('Past 1h')
+    .describe('Time range of data to query for like "Past 1h", "Past 24h"'),
+});
+
+// Base fields common to all three shapes
+const AIBaseSchema = z.object({
+  from: SelectSQLStatementSchema.shape.from,
+  source: z.string(),
+  connection: z.string(),
+  where: SearchConditionSchema.optional(),
+  whereLanguage: SearchConditionLanguageSchema,
+  timestampValueExpression: z.string(),
+  dateRange: z.tuple([z.string(), z.string()]), // keep as string tuple (ISO recommended)
+  name: z.string().optional(),
+  markdown: z.string().optional(),
+});
+
+// SEARCH
+const AISearchQuerySchema = z
+  .object({
+    displayType: z.literal(DisplayType.Search),
+    select: z.string(),
+    groupBy: z.string().optional(),
+    limit: z
+      .object({
+        limit: z.number().int().positive(),
+      })
+      .optional(),
+  })
+  .merge(
+    AIBaseSchema.required({
+      where: true,
+    }),
+  );
+
+// TABLE
+const AITableQuerySchema = z
+  .object({
+    displayType: z.literal(DisplayType.Table),
+    // Reuse DerivedColumnSchema so aggFn/valueExpression/conditions are validated consistently
+    select: z.array(DerivedColumnSchema).min(1),
+    groupBy: z.string().optional(),
+    granularity: z.union([SQLIntervalSchema, z.literal('auto')]).optional(),
+    limit: LimitSchema.optional(),
+  })
+  .merge(AIBaseSchema);
+
+// LINE
+const AILineQuerySchema = z
+  .object({
+    displayType: z.literal(DisplayType.Line),
+    select: z.array(DerivedColumnSchema).min(1),
+    groupBy: z.string().optional(),
+    granularity: z.union([SQLIntervalSchema, z.literal('auto')]).optional(),
+    limit: LimitSchema.optional(),
+  })
+  .merge(AIBaseSchema);
+
+/** Chart config for line and table charts, as inferred from the two schemas above. */
+export type AILineTableResponse =
+  | z.infer<typeof AILineQuerySchema>
+  | z.infer<typeof AITableQuerySchema>;
+
+// Union that covers all 3 objects
+export const AssistantResponseConfig = z.discriminatedUnion('displayType', [
+  AISearchQuerySchema,
+  AITableQuerySchema,
+  AILineQuerySchema,
+]);
+
+export type AssistantResponseConfigSchema = z.infer<
+  typeof AssistantResponseConfig
+>;