diff --git a/src/Microsoft.DotNet.Interactive.PostgreSql.Tests/PostgreSqlConnectionTests.cs b/src/Microsoft.DotNet.Interactive.PostgreSql.Tests/PostgreSqlConnectionTests.cs index 9925dad72f..bc99dc85b0 100644 --- a/src/Microsoft.DotNet.Interactive.PostgreSql.Tests/PostgreSqlConnectionTests.cs +++ b/src/Microsoft.DotNet.Interactive.PostgreSql.Tests/PostgreSqlConnectionTests.cs @@ -11,8 +11,10 @@ using Microsoft.DotNet.Interactive.Events; using Microsoft.DotNet.Interactive.CSharp; using Microsoft.DotNet.Interactive.App; +using Microsoft.DotNet.Interactive.Commands; using FluentAssertions; using Xunit; +using System.Collections.Generic; namespace Microsoft.DotNet.Interactive.PostgreSql.Tests; @@ -82,8 +84,148 @@ public async Task It_returns_error_if_query_is_not_valid() .Contain("column \"not_known_column\" does not exist"); } + [PostgreSqlFact] + public async Task When_variable_does_not_exist_then_an_error_is_returned() + { + var connectionString = PostgreSqlFactAttribute.GetConnectionStringForTests(); + using var kernel = CreateKernel(); + var result = await kernel.SubmitCodeAsync( + $"#!connect postgres --kernel-name adventureworks \"{connectionString}\""); + + result.Events + .Should() + .NotContainErrors(); + + var sqlKernel = kernel.FindKernelByName("sql-adventureworks"); + + result = await sqlKernel.SendAsync(new RequestValue("my_data_result")); + + result.Events.Should() + .ContainSingle() + .Which + .Message + .Should() + .Contain("Value 'my_data_result' not found in kernel sql-adventureworks"); + } + + [PostgreSqlFact] + public async Task It_can_store_result_set_with_a_name() + { + var connectionString = PostgreSqlFactAttribute.GetConnectionStringForTests(); + using var kernel = CreateKernel(); + await kernel.SubmitCodeAsync( + $"#!connect postgres --kernel-name adventureworks \"{connectionString}\""); + + await kernel.SubmitCodeAsync(""" + #!sql-adventureworks --name my_data_result + SELECT * FROM customers LIMIT 10; + """); + + var result = await kernel.SubmitCodeAsync(""" + #!csharp + #!share --from sql-adventureworks my_data_result + my_data_result + """); + + result.Events + .Should() + .ContainSingle() + .Which + .Value + .Should() + .BeAssignableTo>() + .Which.Count() + .Should() + .Be(1); + } + + [PostgreSqlFact] + public async Task Stored_query_results_are_listed_in_ValueInfos() + { + var connectionString = PostgreSqlFactAttribute.GetConnectionStringForTests(); + using var kernel = CreateKernel(); + await kernel.SubmitCodeAsync( + $"#!connect postgres --kernel-name adventureworks \"{connectionString}\""); + + await kernel.SubmitCodeAsync(""" + #!sql-adventureworks --name my_data_result + SELECT * FROM customers LIMIT 10; + """); + + var sqlKernel = kernel.FindKernelByName("sql-adventureworks"); + + var result = await sqlKernel.SendAsync(new RequestValueInfos()); + + var valueInfos = result.Events.Should().ContainSingle() + .Which.ValueInfos; + + valueInfos.Should().Contain(v => v.Name == "my_data_result"); + } + + [PostgreSqlFact] + public async Task Storing_results_does_interfere_with_subsequent_executions() + { + var connectionString = PostgreSqlFactAttribute.GetConnectionStringForTests(); + using var kernel = CreateKernel(); + await kernel.SubmitCodeAsync( + $"#!connect postgres --kernel-name adventureworks \"{connectionString}\""); + + await kernel.SubmitCodeAsync(""" + #!sql-adventureworks --name my_data_result + SELECT * FROM customers LIMIT 10; + """); + + var sqlKernel = kernel.FindKernelByName("sql-adventureworks"); + + var result = await sqlKernel.SendAsync(new RequestValueInfos()); + + var valueInfos = result.Events.Should().ContainSingle() + .Which.ValueInfos; + + valueInfos.Should().Contain(v => v.Name == "my_data_result"); + + result = await kernel.SubmitCodeAsync(""" + #!sql-adventureworks --name my_data_result + SELECT * FROM customers LIMIT 10; + """); + + result.Events.Should().NotContainErrors(); + } + + [PostgreSqlFact] + public async Task It_can_store_multiple_result_set_with_a_name() + { + var connectionString = PostgreSqlFactAttribute.GetConnectionStringForTests(); + using var kernel = CreateKernel(); + await kernel.SubmitCodeAsync( + $"#!connect postgres --kernel-name adventureworks \"{connectionString}\""); + + await kernel.SubmitCodeAsync(""" + #!sql-adventureworks --name my_data_result + SELECT * FROM customers LIMIT 5; + SELECT * FROM customers LIMIT 5; + """); + + var result = await kernel.SubmitCodeAsync(""" + #!csharp + #!share --from sql-adventureworks my_data_result + my_data_result + """); + + result.Events + .Should() + .ContainSingle() + .Which + .Value + .Should() + .BeAssignableTo>() + .Which.Count() + .Should() + .Be(2); + } + public void Dispose() { DataExplorer.ResetToDefault(); } -} \ No newline at end of file +} diff --git a/src/Microsoft.DotNet.Interactive.PostgreSql/PostgreSqlKernel.cs b/src/Microsoft.DotNet.Interactive.PostgreSql/PostgreSqlKernel.cs index ed23cbbbf9..e2dfd21c23 100644 --- a/src/Microsoft.DotNet.Interactive.PostgreSql/PostgreSqlKernel.cs +++ b/src/Microsoft.DotNet.Interactive.PostgreSql/PostgreSqlKernel.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation and contributors. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System; using System.Collections.Generic; using System.Data; using System.Data.Common; @@ -8,7 +9,11 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Html; using Microsoft.DotNet.Interactive.Commands; +using Microsoft.DotNet.Interactive.Directives; +using Microsoft.DotNet.Interactive.Events; +using Microsoft.DotNet.Interactive.Formatting; using Microsoft.DotNet.Interactive.Formatting.TabularData; +using Microsoft.DotNet.Interactive.ValueSharing; using Npgsql; using Enumerable = System.Linq.Enumerable; @@ -16,10 +21,13 @@ namespace Microsoft.DotNet.Interactive.PostgreSql; public class PostgreSqlKernel : Kernel, - IKernelCommandHandler + IKernelCommandHandler, + IKernelCommandHandler, + IKernelCommandHandler { private readonly string _connectionString; private IEnumerable>> _tables; + private readonly Dictionary _resultSets = new(StringComparer.Ordinal); public PostgreSqlKernel(string name, string connectionString) : base(name) { @@ -31,6 +39,16 @@ Query a PostgreSQL database _connectionString = connectionString; } + public override KernelSpecifierDirective KernelSpecifierDirective + { + get + { + var directive = base.KernelSpecifierDirective; + directive.Parameters.Add(new("--name")); + return directive; + } + } + private DbConnection OpenConnection() { return new NpgsqlConnection(_connectionString); @@ -40,24 +58,36 @@ async Task IKernelCommandHandler.HandleAsync( SubmitCode submitCode, KernelInvocationContext context) { - await using var connection = OpenConnection(); - if (connection.State is not ConnectionState.Open) + var results = new List(); + try { - await connection.OpenAsync(); - } + await using var connection = OpenConnection(); + if (connection.State is not ConnectionState.Open) + { + await connection.OpenAsync(); + } - await using var dbCommand = connection.CreateCommand(); + await using var dbCommand = connection.CreateCommand(); - dbCommand.CommandText = submitCode.Code; + dbCommand.CommandText = submitCode.Code; - _tables = Execute(dbCommand); + _tables = Execute(dbCommand); - foreach (var table in _tables) - { - var tabularDataResource = table.ToTabularDataResource(); + foreach (var table in _tables) + { + var tabularDataResource = table.ToTabularDataResource(); - var explorer = DataExplorer.CreateDefault(tabularDataResource); - context.Display(explorer); + var explorer = DataExplorer.CreateDefault(tabularDataResource); + context.Display(explorer); + + results.Add(tabularDataResource); + } + } + finally + { + submitCode.Parameters.TryGetValue("--name", out var queryName); + string name = queryName ?? ""; + _resultSets[name] = results; } } @@ -125,4 +155,61 @@ public static void AddPostgreSqlKernelConnectorToCurrentRoot() "text/html"); } } + + private bool TryGetValue(string name, out T value) + { + if (_resultSets.TryGetValue(name, out var resultSet) && + resultSet is T resultSetT) + { + value = resultSetT; + return true; + } + + value = default; + return false; + } + + Task IKernelCommandHandler.HandleAsync(RequestValue command, KernelInvocationContext context) + { + if (TryGetValue(command.Name, out var value)) + { + context.Publish(new ValueProduced( + value, + command.Name, + new FormattedValue( + command.MimeType, + value.ToDisplayString(command.MimeType)), + command)); + } + else + { + context.Fail(command, message: $"Value '{command.Name}' not found in kernel {Name}"); + } + + return Task.CompletedTask; + } + + Task IKernelCommandHandler.HandleAsync(RequestValueInfos command, KernelInvocationContext context) + { + var valueInfos = CreateKernelValueInfos(_resultSets, command.MimeType).ToArray(); + + context.Publish(new ValueInfosProduced(valueInfos, command)); + + return Task.CompletedTask; + + static IEnumerable CreateKernelValueInfos(IReadOnlyDictionary source, string mimeType) + { + return source.Keys.Select(key => + { + var formattedValues = FormattedValue.CreateSingleFromObject( + source[key], + mimeType); + + return new KernelValueInfo( + key, + formattedValues, + type: typeof(IEnumerable)); + }); + } + } }