diff --git a/src/Bolero/Router.fs b/src/Bolero/Router.fs index c68efb26..cd78a086 100644 --- a/src/Bolero/Router.fs +++ b/src/Bolero/Router.fs @@ -212,11 +212,24 @@ module private RouterImpl = (fun l -> listRevAndUnbox.Invoke(null, [|l|])) (fun l -> listLengthAndBox.Invoke(null, [|l|]) :?> _) + [] + type ParameterModifier = + | Basic + //| Optional + | Rest of (seq -> obj) * (obj -> seq) + + interface IEquatable with + member this.Equals(that) = + match this, that with + | Basic, Basic + | Rest _, Rest _ -> true + | _ -> false + type UnionParserSegment = | Constant of string /// A UnionParserSegment can be common among multiple union cases. /// fieldIndex lists these cases, and for each of them, its total number of fields and the index of the field for this segment. - | Parameter of fieldIndex: list * fieldType: Type * fieldSegment: Segment + | Parameter of fieldIndex: list * fieldType: Type * fieldSegment: Segment * ParameterModifier type UnionParser = { @@ -238,7 +251,50 @@ module private RouterImpl = let isConstantFragment (s: string) = not (s.Contains("{")) - let fragmentParameterRE = Regex(@"^\{([a-zA-Z0-9_]+)\}$", RegexOptions.Compiled) + type Unboxer = + static member List<'T> (items: seq) : list<'T> = + [ for x in items -> unbox<'T> x ] + + static member Array<'T> (items: seq) : 'T[] = + [| for x in items -> unbox<'T> x |] + + type Decons = + static member List<'T> (l: list<'T>) : seq = + Seq.cast l + + static member Array<'T> (l: 'T[]) : seq = + Seq.cast l + + let restModifierFor (ty: Type) = + if ty = typeof then + ty, Rest( + Seq.cast >> String.concat "/" >> box, + fun s -> (unbox s).Split('/') |> Seq.cast + ) + elif ty.IsArray && ty.GetArrayRank() = 1 then + let elt = ty.GetElementType() + let unboxer = typeof.GetMethod("Array", FLAGS_STATIC).MakeGenericMethod([|elt|]) + let decons = typeof.GetMethod("Array", FLAGS_STATIC).MakeGenericMethod([|elt|]) + elt, Rest( + (fun x -> unboxer.Invoke(null, [|x|])), + (fun x -> decons.Invoke(null, [|x|]) :?> _) + ) + elif ty.IsGenericType then + let tdef = ty.GetGenericTypeDefinition() + if tdef = typedefof> then + let targs = ty.GetGenericArguments() + let unboxer = typeof.GetMethod("List", FLAGS_STATIC).MakeGenericMethod(targs) + let decons = typeof.GetMethod("List", FLAGS_STATIC).MakeGenericMethod(targs) + targs.[0], Rest( + (fun x -> unboxer.Invoke(null, [|x|])), + (fun x -> decons.Invoke(null, [|x|]) :?> _) + ) + else + failwithf "Invalid type for *rest parameter: %A" ty + else + failwithf "Invalid type for *rest parameter: %A" ty + + let fragmentParameterRE = Regex(@"^\{([?*]?)([a-zA-Z0-9_]+)\}$", RegexOptions.Compiled) let parseEndPointCase getSegment (case: UnionCaseInfo) = let fields = case.GetFields() @@ -246,7 +302,7 @@ module private RouterImpl = fields |> Array.mapi (fun i p -> let ty = p.PropertyType - Parameter([case, fields.Length, i], ty, getSegment ty)) + Parameter([case, fields.Length, i], ty, getSegment ty, Basic)) |> List.ofSeq match parseEndPointCasePath case with // EndPoint "/" @@ -264,13 +320,18 @@ module private RouterImpl = else let m = fragmentParameterRE.Match(frag) if m.Success then - let fieldName = m.Groups.[1].Value + let fieldName = m.Groups.[2].Value match fields |> Array.tryFindIndex (fun p -> p.Name = fieldName) with | Some i -> let p = fields.[i] if unboundFields.Remove(fieldName) then let ty = p.PropertyType - Parameter([case, fields.Length, i], ty, getSegment ty) + let eltTy, modifier = + match m.Groups.[1].Value with + | "" -> ty, Basic + | "*" -> restModifierFor ty + | s -> failwithf "Invalid parameter modifier: %s" s + Parameter([case, fields.Length, i], ty, getSegment eltTy, modifier) else failwithf "Union case %s.%s has endpoint definition with duplicate field %s" case.DeclaringType.FullName case.Name fieldName @@ -299,15 +360,17 @@ module private RouterImpl = | true, x -> x | false, _ -> [] constants.[s] <- (case, rest) :: existing - | Parameter(n, ty, seg) :: rest -> + | Parameter(n, ty, seg, modif) :: rest -> match parameter with - | Some (n', ty', seg, ps) -> + | Some (n', ty', seg, ps, modif') -> if ty <> ty' then + failwithf "[1] Union %s has cases with conflicting endpoint definitions" case.DeclaringType.FullName + elif modif <> modif' then failwithf "[2] Union %s has cases with conflicting endpoint definitions" case.DeclaringType.FullName else - parameter <- Some (n @ n', ty, seg, (case, rest) :: ps) + parameter <- Some (n @ n', ty, seg, (case, rest) :: ps, modif) | None -> - parameter <- Some (n, ty, seg, [case, rest]) + parameter <- Some (n, ty, seg, [case, rest], modif) | [] -> match final with | Some _ -> @@ -325,10 +388,10 @@ module private RouterImpl = } match parameter with | None -> () - | Some (n, ty, seg, cases) -> + | Some (n, ty, seg, cases, modif) -> let tails, final = mergeEndPointCaseFragments cases yield { - head = Parameter(n, ty, seg) + head = Parameter(n, ty, seg, modif) tails = tails finalize = final } @@ -346,7 +409,7 @@ module private RouterImpl = run p.tails p.finalize rest | Constant _, _ -> None - | Parameter(n, _, seg), l -> + | Parameter(n, _, seg, Basic), l -> match seg.parse l with | None -> None | Some (o, rest) -> @@ -360,11 +423,31 @@ module private RouterImpl = a a.[i] <- o run p.tails p.finalize rest + | Parameter(n, _, seg, Rest(restBuild, _)), l -> + let restValues = ResizeArray() + let rec parse l = + match seg.parse l, l with + | None, [] -> + for (case, fieldCount, i) in n do + let a = + match d.TryGetValue(case) with + | true, a -> a + | false, _ -> + let a = Array.zeroCreate fieldCount + d.[case] <- a + a + a.[i] <- restBuild restValues + run p.tails p.finalize [] + | None, _::_ -> None + | Some (o, rest), _ -> + restValues.Add(o) + parse rest + parse l ) |> Option.orElseWith (fun () -> - final |> Option.map (fun (ty, ctor) -> + final |> Option.map (fun (case, ctor) -> let args = - match d.TryGetValue(ty) with + match d.TryGetValue(case) with | true, args -> args | false, _ -> [||] ctor args, l @@ -402,9 +485,12 @@ module private RouterImpl = let vals = dector o path |> List.collect (function | Constant s -> [s] - | Parameter(n, _, seg) -> + | Parameter(n, _, seg, Basic) -> let (_, _, i) = n |> List.find (fun (case', _, _) -> case' = case) seg.write vals.[i] + | Parameter(n, _, seg, Rest(_, decons)) -> + let (_, _, i) = n |> List.find (fun (case', _, _) -> case' = case) + [ for x in decons vals.[i] do yield! seg.write x ] ) let unionSegment (getSegment: Type -> Segment) (ty: Type) : Segment =