Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/WinRT.Interop.Generator/Extensions/CilInstructionExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Diagnostics.CodeAnalysis;
using AsmResolver.DotNet;
using AsmResolver.DotNet.Code.Cil;
using AsmResolver.DotNet.Signatures;
using AsmResolver.PE.DotNet.Cil;
using AsmResolver.PE.DotNet.Metadata.Tables;
using static AsmResolver.PE.DotNet.Cil.CilOpCodes;

namespace WindowsRuntime.InteropGenerator;
Expand Down Expand Up @@ -69,5 +73,33 @@ public static CilInstruction CreateLdloc(CilLocalVariable local, CilMethodBody m
int i => new CilInstruction(Ldloc, i)
};
}

/// <summary>
/// Create a new instruction storing a value indirectly to a target location.
/// </summary>
/// <param name="type">The type of value to store.</param>
/// <param name="module">The <see cref="ModuleDefinition"/> in use.</param>
/// <returns>The instruction.</returns>
[SuppressMessage("Style", "IDE0072", Justification = "We use 'stobj' for all other possible types.")]
public static CilInstruction CreateStind(TypeSignature type, ModuleDefinition module)
{
return type.ElementType switch
{
ElementType.Boolean => new CilInstruction(Stind_I1),
ElementType.Char => new CilInstruction(Stind_I2),
ElementType.I1 => new CilInstruction(Stind_I1),
ElementType.U1 => new CilInstruction(Stind_I1),
ElementType.I2 => new CilInstruction(Stind_I2),
ElementType.U2 => new CilInstruction(Stind_I2),
ElementType.I4 => new CilInstruction(Stind_I4),
ElementType.U4 => new CilInstruction(Stind_I4),
ElementType.I8 => new CilInstruction(Stind_I8),
ElementType.U8 => new CilInstruction(Stind_I8),
ElementType.R4 => new CilInstruction(Stind_R4),
ElementType.R8 => new CilInstruction(Stind_R8),
ElementType.ValueType when type.Resolve() is { IsClass: true, IsEnum: true } => new CilInstruction(Stind_I4),
_ => new CilInstruction(Stobj, type.Import(module).ToTypeDefOrRef()),
};
}
}
}
51 changes: 41 additions & 10 deletions src/WinRT.Interop.Generator/Extensions/WindowsRuntimeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using AsmResolver;
using AsmResolver.DotNet;
Expand Down Expand Up @@ -49,27 +50,27 @@ public bool TryGetGuidAttribute(InteropReferences interopReferences, out Guid ii
extension(ITypeDescriptor type)
{
/// <summary>
/// Checks whether an <see cref="ITypeDescriptor"/> is some <see cref="System.Guid"/> type.
/// Checks whether an <see cref="ITypeDescriptor"/> is some <see cref="Guid"/> type.
/// </summary>
/// <returns>Whether the type is some <see cref="System.Guid"/> type.</returns>
/// <returns>Whether the type is some <see cref="Guid"/> type.</returns>
public bool IsTypeOfGuid(InteropReferences interopReferences)
{
return SignatureComparer.IgnoreVersion.Equals(type, interopReferences.Guid);
}

/// <summary>
/// Checks whether an <see cref="ITypeDescriptor"/> is some <see cref="System.Type"/> type.
/// Checks whether an <see cref="ITypeDescriptor"/> is some <see cref="Type"/> type.
/// </summary>
/// <returns>Whether the type is some <see cref="System.Type"/> type.</returns>
/// <returns>Whether the type is some <see cref="Type"/> type.</returns>
public bool IsTypeOfType(InteropReferences interopReferences)
{
return SignatureComparer.IgnoreVersion.Equals(type, interopReferences.Type);
}

/// <summary>
/// Checks whether an <see cref="ITypeDescriptor"/> is some <see cref="System.Exception"/> type.
/// Checks whether an <see cref="ITypeDescriptor"/> is some <see cref="Exception"/> type.
/// </summary>
/// <returns>Whether the type is some <see cref="System.Exception"/> type.</returns>
/// <returns>Whether the type is some <see cref="Exception"/> type.</returns>
public bool IsTypeOfException(InteropReferences interopReferences)
{
return SignatureComparer.IgnoreVersion.Equals(type, interopReferences.Exception);
Expand Down Expand Up @@ -675,20 +676,50 @@ public bool IsConstructedKeyValuePairType(InteropReferences interopReferences)
}

/// <summary>
/// Checks whether a <see cref="TypeSignature"/> is some <see cref="System.Nullable{T}"/> type.
/// Checks whether a <see cref="TypeSignature"/> is some <see cref="Nullable{T}"/> type.
/// </summary>
/// <param name="interopReferences">The <see cref="InteropReferences"/> instance to use.</param>
/// <returns>Whether the type is some <see cref="System.Nullable{T}"/> type.</returns>
/// <returns>Whether the type is some <see cref="Nullable{T}"/> type.</returns>
public bool IsConstructedNullableValueType(InteropReferences interopReferences)
{
return SignatureComparer.IgnoreVersion.Equals((signature as GenericInstanceTypeSignature)?.GenericType, interopReferences.Nullable1);
}

/// <summary>
/// Checks whether a <see cref="TypeSignature"/> is some <see cref="System.Span{T}"/> or <see cref="System.ReadOnlySpan{T}"/> type.
/// Tries to extract the underlying type from a constructed <see cref="Nullable{T}"/> type.
/// </summary>
/// <param name="interopReferences">The <see cref="InteropReferences"/> instance to use.</param>
/// <returns>Whether the type is some <see cref="System.Span{T}"/> or <see cref="System.ReadOnlySpan{T}"/> type.</returns>
/// <param name="underlyingType">The underlying nullable type, if the input type is a constructed <see cref="Nullable{T}"/> type.</param>
/// <returns>Whether <paramref name="underlyingType"/> was successfully retrieved.</returns>
public bool TryGetNullableUnderlyingType(InteropReferences interopReferences, [NotNullWhen(true)] out TypeSignature? underlyingType)
{
// First check that we have some constructed generic value type.
Copy link

Copilot AI Dec 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pattern match checks IsValueType: true when extracting the underlying type from Nullable<T>. While Nullable<T> is indeed a value type, this constraint might be overly restrictive or unclear in intent. Consider adding a comment explaining why this check is necessary, or if it's primarily a sanity check to ensure we're dealing with value type generics.

Suggested change
// First check that we have some constructed generic value type.
// First check that we have some constructed generic value type.
// The 'IsValueType: true' constraint is a sanity check to ensure we're dealing with a
// value-type generic (such as Nullable<T>) and not an arbitrary reference-type generic.

Copilot uses AI. Check for mistakes.
// We also check that we have a single type argument to narrow down.
if (signature is not GenericInstanceTypeSignature { IsValueType: true, TypeArguments: [TypeSignature typeArgument] } genericSignature)
{
underlyingType = null;

return false;
}

// Check that we actually have a constructed 'Nullable<T>' type
if (!SignatureComparer.IgnoreVersion.Equals(genericSignature.GenericType, interopReferences.Nullable1))
{
underlyingType = null;

return false;
}

underlyingType = typeArgument;

return true;
}

/// <summary>
/// Checks whether a <see cref="TypeSignature"/> is some <see cref="Span{T}"/> or <see cref="ReadOnlySpan{T}"/> type.
/// </summary>
/// <param name="interopReferences">The <see cref="InteropReferences"/> instance to use.</param>
/// <returns>Whether the type is some <see cref="Span{T}"/> or <see cref="ReadOnlySpan{T}"/> type.</returns>
public bool IsConstructedSpanOrReadOnlySpanType(InteropReferences interopReferences)
{
if (signature is not GenericInstanceTypeSignature genericSignature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
using WindowsRuntime.InteropGenerator.Errors;
using WindowsRuntime.InteropGenerator.Generation;
using WindowsRuntime.InteropGenerator.References;
using WindowsRuntime.InteropGenerator.Resolvers;
using static AsmResolver.PE.DotNet.Cil.CilOpCodes;

namespace WindowsRuntime.InteropGenerator.Factories;

/// <inheritdoc cref="InteropMethodRewriteFactory"/>
internal partial class InteropMethodRewriteFactory
/// <summary>
/// A factory to rewrite interop method definitons, and add marshalling code as needed.
/// </summary>
internal static partial class InteropMethodRewriteFactory
{
/// <summary>
/// Contains the logic for marshalling managed parameters (i.e. parameters that are passed to managed methods).
Expand Down Expand Up @@ -73,49 +76,26 @@ public static void RewriteMethod(
{
body.Instructions.ReferenceReplaceRange(marker, CilInstruction.CreateLdarg(parameterIndex));
}
else if (parameterType.IsConstructedKeyValuePairType(interopReferences))
{
// If the type is some constructed 'KeyValuePair<,>' type, we use the generated marshaller
body.Instructions.ReferenceReplaceRange(marker, [
CilInstruction.CreateLdarg(parameterIndex),
new CilInstruction(Call, emitState.LookupTypeDefinition(parameterType, "Marshaller").GetMethod("ConvertToManaged"))]);
}
else if (parameterType.IsConstructedNullableValueType(interopReferences))
{
TypeSignature underlyingType = ((GenericInstanceTypeSignature)parameterType).TypeArguments[0];

// For 'Nullable<T>' return types, we need the marshaller for the instantiated 'T' type (same as for return values)
ITypeDefOrRef marshallerType = GetValueTypeMarshallerType(underlyingType, interopReferences, emitState);

// Get the right reference to the unboxing marshalling method to call
IMethodDefOrRef marshallerMethod = marshallerType.GetMethodDefOrRef(
name: "UnboxToManaged"u8,
signature: MethodSignature.CreateStatic(
returnType: parameterType,
parameterTypes: [module.CorLibTypeFactory.Void.MakePointerType()]));
InteropMarshallerType marshallerType = InteropMarshallerTypeResolver.GetMarshallerType(parameterType, interopReferences, emitState);

// Emit code similar to 'KeyValuePair<,>' above, to marshal the resulting 'Nullable<T>' value
// For 'Nullable<T>' parameters (i.e. we have an 'IReference<T>' interface pointer), we unbox the underlying type
body.Instructions.ReferenceReplaceRange(marker, [
CilInstruction.CreateLdarg(parameterIndex),
new CilInstruction(Call, marshallerMethod.Import(module))]);
new CilInstruction(Call, marshallerType.UnboxToManaged().Import(module))]);
}
else
{
// The last case handles all other value types. It doesn't matter if they possibly hold some unmanaged
// resources, as they're only being used as parameters. That means the caller is responsible for disposal.
ITypeDefOrRef marshallerType = GetValueTypeMarshallerType(parameterType, interopReferences, emitState);

// Get the reference to 'ConvertToManaged' to produce the resulting value to return
IMethodDefOrRef marshallerMethod = marshallerType.GetMethodDefOrRef(
name: "ConvertToManaged"u8,
signature: MethodSignature.CreateStatic(
returnType: parameterType,
parameterTypes: [parameterType.GetAbiType(interopReferences)]));
// This case can also handle 'KeyValuePair<,>' instantiations, which are just marshalled normally too.
InteropMarshallerType marshallerType = InteropMarshallerTypeResolver.GetMarshallerType(parameterType, interopReferences, emitState);

// We can directly call the marshaller and return it, no 'try/finally' complexity is needed
body.Instructions.ReferenceReplaceRange(marker, [
CilInstruction.CreateLdarg(parameterIndex),
new CilInstruction(Call, marshallerMethod.Import(module))]);
new CilInstruction(Call, marshallerType.ConvertToManaged().Import(module))]);
}
}
else if (parameterType.IsTypeOfString())
Expand All @@ -125,29 +105,15 @@ public static void RewriteMethod(
CilInstruction.CreateLdarg(parameterIndex),
new CilInstruction(Call, interopReferences.HStringMarshallerConvertToManaged.Import(module))]);
}
else if (parameterType is GenericInstanceTypeSignature)
{
// This case (constructed interfaces or delegates) is effectively identical to marshalling 'KeyValuePair<,>' values
body.Instructions.ReferenceReplaceRange(marker, [
CilInstruction.CreateLdarg(parameterIndex),
new CilInstruction(Call, emitState.LookupTypeDefinition(parameterType, "Marshaller").GetMethod("ConvertToManaged"))]);
}
else
{
// Get the marshaller type for all other reference types
ITypeDefOrRef marshallerType = GetReferenceTypeMarshallerType(parameterType, interopReferences, emitState);

// Get the marshalling method, with the parameter type always just being 'void*' here too
IMethodDefOrRef marshallerMethod = marshallerType.GetMethodDefOrRef(
name: "ConvertToManaged"u8,
signature: MethodSignature.CreateStatic(
returnType: parameterType,
parameterTypes: [module.CorLibTypeFactory.Void.MakePointerType()]));
// Get the marshaller type for all other reference types (including generics)
InteropMarshallerType marshallerType = InteropMarshallerTypeResolver.GetMarshallerType(parameterType, interopReferences, emitState);

// Marshal the value and release the original interface pointer
// Marshal the value normally (the caller will own the native resource)
body.Instructions.ReferenceReplaceRange(marker, [
CilInstruction.CreateLdarg(parameterIndex),
new CilInstruction(Call, marshallerMethod.Import(module))]);
new CilInstruction(Call, marshallerType.ConvertToManaged().Import(module))]);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using WindowsRuntime.InteropGenerator.Errors;
using WindowsRuntime.InteropGenerator.Generation;
using WindowsRuntime.InteropGenerator.References;
using WindowsRuntime.InteropGenerator.Resolvers;
using static AsmResolver.PE.DotNet.Cil.CilOpCodes;

#pragma warning disable CS1573
Expand Down Expand Up @@ -100,17 +101,7 @@ public static void RewriteMethod(
}
else if (parameterType.IsConstructedNullableValueType(interopReferences))
{
TypeSignature underlyingType = ((GenericInstanceTypeSignature)parameterType).TypeArguments[0];

// For 'Nullable<T>' return types, we need the marshaller for the instantiated 'T' type (same as for return values)
ITypeDefOrRef marshallerType = GetValueTypeMarshallerType(underlyingType, interopReferences, emitState);

// Get the right reference to the unboxing marshalling method to call
IMethodDefOrRef marshallerMethod = marshallerType.GetMethodDefOrRef(
name: "BoxToUnmanaged"u8,
signature: MethodSignature.CreateStatic(
returnType: interopReferences.WindowsRuntimeObjectReferenceValue.ToValueTypeSignature(),
parameterTypes: [parameterType]));
InteropMarshallerType marshallerType = InteropMarshallerTypeResolver.GetMarshallerType(parameterType, interopReferences, emitState);

RewriteBody(
parameterType: parameterType,
Expand All @@ -119,29 +110,15 @@ public static void RewriteMethod(
loadMarker: loadMarker,
finallyMarker: finallyMarker,
parameterIndex: parameterIndex,
marshallerMethod: marshallerMethod,
marshallerMethod: marshallerType.BoxToUnmanaged(),
disposeMethod: null,
interopReferences: interopReferences,
module: module);
}
else
{
// The last case handles all other value types, which need explicit disposal for their ABI values
ITypeDefOrRef marshallerType = GetValueTypeMarshallerType(parameterType, interopReferences, emitState);

// Get the reference to 'ConvertToUnmanaged' to produce the resulting value to pass as argument
IMethodDefOrRef marshallerMethod = marshallerType.GetMethodDefOrRef(
name: "ConvertToUnmanaged"u8,
signature: MethodSignature.CreateStatic(
returnType: parameterType.GetAbiType(interopReferences),
parameterTypes: [parameterType]));

// Get the reference to 'Dispose' method to call on the ABI value
IMethodDefOrRef disposeMethod = marshallerType.GetMethodDefOrRef(
name: "Dispose"u8,
signature: MethodSignature.CreateStatic(
returnType: interopReferences.CorLibTypeFactory.Void,
parameterTypes: [parameterType.GetAbiType(interopReferences)]));
InteropMarshallerType marshallerType = InteropMarshallerTypeResolver.GetMarshallerType(parameterType, interopReferences, emitState);

RewriteBody(
parameterType: parameterType,
Expand All @@ -150,8 +127,8 @@ public static void RewriteMethod(
loadMarker: loadMarker,
finallyMarker: finallyMarker,
parameterIndex: parameterIndex,
marshallerMethod: marshallerMethod,
disposeMethod: disposeMethod,
marshallerMethod: marshallerType.ConvertToUnmanaged(),
disposeMethod: marshallerType.Dispose(),
interopReferences: interopReferences,
module: module);
}
Expand Down Expand Up @@ -189,14 +166,7 @@ public static void RewriteMethod(
else
{
// Get the marshaller for all other types (doesn't matter if constructed generics or not)
ITypeDefOrRef marshallerType = GetReferenceTypeMarshallerType(parameterType, interopReferences, emitState);

// Get the reference to 'ConvertToUnmanaged' to produce the resulting value to pass as argument
IMethodDefOrRef marshallerMethod = marshallerType.GetMethodDefOrRef(
name: "ConvertToUnmanaged"u8,
signature: MethodSignature.CreateStatic(
returnType: interopReferences.WindowsRuntimeObjectReferenceValue.ToValueTypeSignature(),
parameterTypes: [parameterType]));
InteropMarshallerType marshallerType = InteropMarshallerTypeResolver.GetMarshallerType(parameterType, interopReferences, emitState);

RewriteBody(
parameterType: parameterType,
Expand All @@ -205,7 +175,7 @@ public static void RewriteMethod(
loadMarker: loadMarker,
finallyMarker: finallyMarker,
parameterIndex: parameterIndex,
marshallerMethod: marshallerMethod,
marshallerMethod: marshallerType.ConvertToUnmanaged(),
disposeMethod: null,
interopReferences: interopReferences,
module: module);
Expand Down
Loading
Loading