@@ -1323,14 +1323,14 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
13231323 (pt1.argInfos.init, typeTree(interpolateWildcards(pt1.argInfos.last.hiBound)))
13241324 case RefinedType (parent, nme.apply, mt @ MethodTpe (_, formals, restpe))
13251325 if (defn.isNonRefinedFunction(parent) || defn.isErasedFunctionType(parent)) && formals.length == defaultArity =>
1326- (formals, untpd.DependentTypeTree ( (_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
1326+ (formals, untpd.InLambdaTypeTree (isResult = true , (_, syms) => restpe.substParams(mt, syms.map(_.termRef))))
13271327 case pt1 @ SAMType (mt @ MethodTpe (_, formals, _)) if ! SAMType .isParamDependentRec(mt) =>
13281328 val restpe = mt.resultType match
13291329 case mt : MethodType => mt.toFunctionType(isJava = pt1.classSymbol.is(JavaDefined ))
13301330 case tp => tp
13311331 (formals,
13321332 if (mt.isResultDependent)
1333- untpd.DependentTypeTree ( (_, syms) => restpe.substParams(mt, syms.map(_.termRef)))
1333+ untpd.InLambdaTypeTree (isResult = true , (_, syms) => restpe.substParams(mt, syms.map(_.termRef)))
13341334 else
13351335 typeTree(restpe))
13361336 case _ =>
@@ -1641,13 +1641,34 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16411641 val untpd .PolyFunction (tparams : List [untpd.TypeDef ] @ unchecked, fun) = tree : @ unchecked
16421642 val untpd .Function (vparams : List [untpd.ValDef ] @ unchecked, body) = fun : @ unchecked
16431643
1644+ // If the expected type is a polymorphic function with the same number of
1645+ // type and value parameters, then infer the types of value parameters from the expected type.
1646+ val inferredVParams = pt match
1647+ case RefinedType (parent, nme.apply, poly @ PolyType (_, mt : MethodType ))
1648+ if (parent.typeSymbol eq defn.PolyFunctionClass )
1649+ && tparams.lengthCompare(poly.paramNames) == 0
1650+ && vparams.lengthCompare(mt.paramNames) == 0
1651+ =>
1652+ vparams.zipWithConserve(mt.paramInfos): (vparam, formal) =>
1653+ // Unlike in typedFunctionValue, `formal` cannot be a TypeBounds since
1654+ // it must be a valid method parameter type.
1655+ if vparam.tpt.isEmpty && isFullyDefined(formal, ForceDegree .failBottom) then
1656+ cpy.ValDef (vparam)(tpt = new untpd.InLambdaTypeTree (isResult = false , (tsyms, vsyms) =>
1657+ // We don't need to substitute `mt` by `vsyms` because we currently disallow
1658+ // dependencies between value parameters of a closure.
1659+ formal.substParams(poly, tsyms.map(_.typeRef)))
1660+ )
1661+ else vparam
1662+ case _ =>
1663+ vparams
1664+
16441665 val resultTpt = pt.dealias match
16451666 case RefinedType (parent, nme.apply, poly @ PolyType (_, mt : MethodType )) if parent.classSymbol eq defn.PolyFunctionClass =>
1646- untpd.DependentTypeTree ( (tsyms, vsyms) =>
1667+ untpd.InLambdaTypeTree (isResult = true , (tsyms, vsyms) =>
16471668 mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
16481669 case _ => untpd.TypeTree ()
16491670
1650- val desugared = desugar.makeClosure(tparams, vparams , body, resultTpt, tree.span)
1671+ val desugared = desugar.makeClosure(tparams, inferredVParams , body, resultTpt, tree.span)
16511672 typed(desugared, pt)
16521673 end typedPolyFunctionValue
16531674
@@ -2098,6 +2119,18 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
20982119 case _ =>
20992120 completeTypeTree(InferredTypeTree (), pt, tree)
21002121
2122+ def typedInLambdaTypeTree (tree : untpd.InLambdaTypeTree , pt : Type )(using Context ): Tree =
2123+ val tp =
2124+ if tree.isResult then pt // See InLambdaTypeTree logic in Namer#valOrDefDefSig.
2125+ else
2126+ val lambdaCtx = ctx.outersIterator.dropWhile(_.owner.name ne nme.ANON_FUN ).next()
2127+ // A lambda has at most one type parameter list followed by exactly one term parameter list.
2128+ // Parameters are entered in order in the scope of the lambda.
2129+ val (tsyms : List [TypeSymbol @ unchecked], vsyms : List [TermSymbol @ unchecked]) =
2130+ lambdaCtx.scope.toList.partition(_.isType): @ unchecked
2131+ tree.tpFun(tsyms, vsyms)
2132+ completeTypeTree(InferredTypeTree (), tp, tree)
2133+
21012134 def typedSingletonTypeTree (tree : untpd.SingletonTypeTree )(using Context ): SingletonTypeTree = {
21022135 val ref1 = typedExpr(tree.ref, SingletonTypeProto )
21032136 checkStable(ref1.tpe, tree.srcPos, " singleton type" )
@@ -3109,7 +3142,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
31093142 case tree : untpd.TypedSplice => typedTypedSplice(tree)
31103143 case tree : untpd.UnApply => typedUnApply(tree, pt)
31113144 case tree : untpd.Tuple => typedTuple(tree, pt)
3112- case tree : untpd.DependentTypeTree => completeTypeTree(untpd. InferredTypeTree () , pt, tree )
3145+ case tree : untpd.InLambdaTypeTree => typedInLambdaTypeTree(tree , pt)
31133146 case tree : untpd.InfixOp => typedInfixOp(tree, pt)
31143147 case tree : untpd.ParsedTry => typedTry(tree, pt)
31153148 case tree @ untpd.PostfixOp (qual, Ident (nme.WILDCARD )) => typedAsFunction(tree, pt)
0 commit comments