diff --git a/packages/common/src/union/union.ts b/packages/common/src/union/union.ts new file mode 100644 index 00000000..22a5345c --- /dev/null +++ b/packages/common/src/union/union.ts @@ -0,0 +1,45 @@ +import { Provable, Struct, Field, InferProvable, Unconstrained } from "o1js"; + +import { padArray } from "../utils"; + +export function createQualifiedUnion< + T extends (Provable & { name: string })[], +>(provables: T) { + const maxLength = Math.max( + ...provables.map((provable) => provable.sizeInFields()) + ); + + const typeMap = Object.fromEntries( + provables.map(({ name }, index) => [name, index]) + ); + + class ProvableUnion extends Struct({ + array: Provable.Array(Field, maxLength), + type: Field, + }) { + public static from( + provable: Type, + value: InferProvable + ) { + const fields = provable.toFields(value); + const fullFields = padArray(fields, maxLength, () => Field(0)); + return new ProvableUnion({ + array: fullFields, + type: Field(typeMap[provable.name]), + }); + } + + public into(provable: Type): InferProvable { + const size = provable.sizeInFields(); + const fields = this.array.slice(0, size); + + this.array.slice(size).forEach((field) => field.assertEquals(0)); + + this.type.assertEquals(typeMap[provable.name]); + + return provable.fromFields(fields, []); + } + } + + return ProvableUnion; +} diff --git a/packages/common/test/union/union.test.ts b/packages/common/test/union/union.test.ts new file mode 100644 index 00000000..4c3c5085 --- /dev/null +++ b/packages/common/test/union/union.test.ts @@ -0,0 +1,16 @@ +import { Field, UInt64 } from "o1js"; + +import { createQualifiedUnion } from "../../src/union/union"; + +describe("union", () => { + it("should serialize correctly", () => { + const provable = createQualifiedUnion([Field, UInt64]); + const p = provable.from(UInt64, UInt64.from(1)); + // const p2 = provable.from(UInt32, UInt32.from(1)); + + const uint = p.into(UInt64); + const x = uint.add(UInt64.from(2)).toString(); + + expect(x).toBe("3"); + }); +});