diff --git a/gen/DocumentFormat.OpenXml.Generator.Models/Models/SchemaType.cs b/gen/DocumentFormat.OpenXml.Generator.Models/Models/SchemaType.cs index 19c2042ab..4b705cfa1 100644 --- a/gen/DocumentFormat.OpenXml.Generator.Models/Models/SchemaType.cs +++ b/gen/DocumentFormat.OpenXml.Generator.Models/Models/SchemaType.cs @@ -11,6 +11,8 @@ public class SchemaType public bool IsPart => !string.IsNullOrEmpty(Part); + public bool IsRootElement => IsPart || BaseClass == "OpenXmlPartRootElement" || BaseClass == "OpenXmlPartRootElement"; + public string Part { get; set; } = null!; public ParticleOrderType CompositeType { get; set; } diff --git a/gen/DocumentFormat.OpenXml.Generator/FactoryGenerator.cs b/gen/DocumentFormat.OpenXml.Generator/FactoryGenerator.cs new file mode 100644 index 000000000..46aaca76e --- /dev/null +++ b/gen/DocumentFormat.OpenXml.Generator/FactoryGenerator.cs @@ -0,0 +1,135 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using DocumentFormat.OpenXml.Generator.Editor; +using DocumentFormat.OpenXml.Generator.Generators.Parts; +using Microsoft.CodeAnalysis; +using System.CodeDom.Compiler; +using System.Text; + +namespace DocumentFormat.OpenXml.Generator; + +[Generator] +public class FactoryGenerator : IIncrementalGenerator +{ + public void Initialize(IncrementalGeneratorInitializationContext context) + { + var openXml = context.GetOpenXmlGeneratorContext() + .GetOpenXmlServices(); + + context.RegisterSourceOutput(openXml, (context, openXml) => + { + GeneratePartFactory(context, openXml); + GenerateRootActivator(context, openXml); + }); + } + + private static void GeneratePartFactory(SourceProductionContext context, OpenXmlGeneratorServices openXml) + { + using var sw = new StringWriter(); + using var writer = new IndentedTextWriter(sw); + + writer.WriteFileHeader(); + + writer.WriteLine("using DocumentFormat.OpenXml;"); + writer.WriteLine("using DocumentFormat.OpenXml.Packaging;"); + writer.WriteLine(); + writer.WriteLine("namespace DocumentFormat.OpenXml.Features;"); + + writer.WriteLine("internal partial class TypedPartFactory : IPartFactory"); + + using (writer.AddBlock()) + { + writer.WriteLine("public T? Create() where T : OpenXmlPart"); + + using (writer.AddBlock()) + { + foreach (var part in openXml.Context.Parts) + { + writer.Write("if (typeof(T) == typeof("); + writer.Write(part.Name); + writer.WriteLine("))"); + + using (writer.AddBlock()) + { + writer.Write("return (T)(object)new "); + writer.Write(part.Name); + writer.WriteLine("();"); + } + + writer.WriteLine(); + } + + writer.WriteLine("return default;"); + } + } + + context.AddSource("TypedPartFactory", sw.ToString()); + } + + private static void GenerateRootActivator(SourceProductionContext context, OpenXmlGeneratorServices openXml) + { + using var sw = new StringWriter(); + using var writer = new IndentedTextWriter(sw); + + writer.WriteFileHeader(); + + writer.WriteLine("using System;"); + writer.WriteLine("using System.Collections.Generic;"); + writer.WriteLine("using DocumentFormat.OpenXml;"); + writer.WriteLine("using DocumentFormat.OpenXml.Packaging;"); + writer.WriteLine("using DocumentFormat.OpenXml.Framework.Metadata;"); + writer.WriteLine(); + writer.WriteLine("namespace DocumentFormat.OpenXml.Features;"); + + writer.WriteLine("internal partial class TypedRootElementFactory : IRootElementFactory"); + + using (writer.AddBlock()) + { + writer.WriteLine("public static IEnumerable GetAllRootElements()"); + + using (writer.AddBlock()) + { + foreach (var model in openXml.Context.Namespaces) + { + foreach (var type in model.Types) + { + if (type.IsRootElement) + { + var className = openXml.FindClassName(type.Name, fullyQualified: true); + + writer.Write("yield return new ElementFactory(typeof("); + writer.Write(className); + writer.Write("), new("); + writer.WriteString(openXml.GetNamespaceInfo(type.Name.QName.Prefix).Uri); + writer.Write(", "); + writer.WriteString(type.Name.QName.Name); + writer.Write("), () => new "); + writer.Write(className); + writer.WriteLine("());"); + } + } + } + } + } + + context.AddSource("TypedRootFactory", sw.ToString()); + } + + private static void WritePartFiles(SourceProductionContext context, OpenXmlGeneratorServices openXml) + { + var sb = new StringBuilder(); + var sw = new StringWriter(sb); + var writer = new IndentedTextWriter(sw); + + foreach (var part in openXml.Context.Parts) + { + sb.Clear(); + + writer.WriteFileHeader(); + writer.WritePart(openXml, part); + + context.AddSource(part.Name, sb.ToString()); + } + } +} diff --git a/gen/DocumentFormat.OpenXml.Generator/PartGenerator.cs b/gen/DocumentFormat.OpenXml.Generator/PartGenerator.cs index cf5fffebd..359324e86 100644 --- a/gen/DocumentFormat.OpenXml.Generator/PartGenerator.cs +++ b/gen/DocumentFormat.OpenXml.Generator/PartGenerator.cs @@ -16,8 +16,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { var openXml = context.GetOpenXmlGeneratorContext().GetOpenXmlServices(); var options = context.GetOpenXmlOptions().Select(static (o, _) => o.GenerateParts); + var parts = openXml.Combine(options); - context.RegisterSourceOutput(openXml.Combine(options), (context, data) => + context.RegisterSourceOutput(parts, (context, data) => { if (!data.Right) { @@ -25,19 +26,25 @@ public void Initialize(IncrementalGeneratorInitializationContext context) } var openXml = data.Left; - var sb = new StringBuilder(); - var sw = new StringWriter(sb); - var writer = new IndentedTextWriter(sw); - foreach (var part in openXml.Context.Parts) - { - sb.Clear(); + WritePartFiles(context, openXml); + }); + } - writer.WriteFileHeader(); - writer.WritePart(openXml, part); + private static void WritePartFiles(SourceProductionContext context, OpenXmlGeneratorServices openXml) + { + var sb = new StringBuilder(); + var sw = new StringWriter(sb); + var writer = new IndentedTextWriter(sw); - context.AddSource(part.Name, sb.ToString()); - } - }); + foreach (var part in openXml.Context.Parts) + { + sb.Clear(); + + writer.WriteFileHeader(); + writer.WritePart(openXml, part); + + context.AddSource(part.Name, sb.ToString()); + } } } diff --git a/src/DocumentFormat.OpenXml/Features/ClassActivator{T}.cs b/src/DocumentFormat.OpenXml/Features/ClassActivator{T}.cs deleted file mode 100644 index f86781900..000000000 --- a/src/DocumentFormat.OpenXml/Features/ClassActivator{T}.cs +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using System; -using System.Diagnostics; -using System.Linq.Expressions; -using System.Reflection; - -#if NETSTANDARD1_3 -using System.Linq; -#endif - -namespace DocumentFormat.OpenXml.Features -{ - internal static class ClassActivator - { - public static Func CreateActivator(Type type) - { -#if DEBUG - Debug.Assert(typeof(T).GetTypeInfo().IsAssignableFrom(type.GetTypeInfo())); -#endif - -#if NETSTANDARD1_3 - var constructor = type.GetTypeInfo().DeclaredConstructors.FirstOrDefault(c => !c.GetParameters().Any() && !c.IsStatic); -#else - var constructor = type.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.HasThis, Type.EmptyTypes, null); -#endif - - if (constructor is null) - { - throw new ArgumentOutOfRangeException(nameof(type)); - } - - return Expression.Lambda>(Expression.New(constructor)).Compile(); - } - } -} diff --git a/src/DocumentFormat.OpenXml/Features/IPartFactory.cs b/src/DocumentFormat.OpenXml/Features/IPartFactory.cs index 8a25c1412..ac2542a94 100644 --- a/src/DocumentFormat.OpenXml/Features/IPartFactory.cs +++ b/src/DocumentFormat.OpenXml/Features/IPartFactory.cs @@ -7,5 +7,5 @@ namespace DocumentFormat.OpenXml.Features; internal interface IPartFactory { - T Create() where T : OpenXmlPart; + T? Create() where T : OpenXmlPart; } diff --git a/src/DocumentFormat.OpenXml/Features/ReflectionBasedRootElementFactory.cs b/src/DocumentFormat.OpenXml/Features/ReflectionBasedRootElementFactory.cs deleted file mode 100644 index c94f974d4..000000000 --- a/src/DocumentFormat.OpenXml/Features/ReflectionBasedRootElementFactory.cs +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using DocumentFormat.OpenXml.Framework; -using DocumentFormat.OpenXml.Framework.Metadata; -using System; -using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; -using System.Reflection; - -namespace DocumentFormat.OpenXml.Features -{ - internal class ReflectionBasedRootElementFactory : IRootElementFactory - { - private readonly Lazy _factoryCollection; - - public ReflectionBasedRootElementFactory() - : this(typeof(ReflectionBasedRootElementFactory).GetTypeInfo().Assembly, ClassActivator.CreateActivator) - { - } - - public ReflectionBasedRootElementFactory(Assembly assembly, Func> activatorFactory) - { - _factoryCollection = new Lazy(() => CreatePartLookup(assembly, activatorFactory), isThreadSafe: true); - } - - private static ElementFactoryCollection CreatePartLookup(Assembly assembly, Func> activatorFactory) - { - List? lookup = null; - - foreach (var child in GetAllRootElements(assembly)) - { - if (lookup is null) - { - lookup = new List(); - } - - var key = ElementFactory.Create(child, activatorFactory(child)); - - lookup.Add(key); - } - - if (lookup is null) - { - return ElementFactoryCollection.Empty; - } - - return new ElementFactoryCollection(lookup); - } - - private static IEnumerable GetAllRootElements(Assembly assembly) - { - var types = assembly.GetTypes(); - - foreach (var elementType in types) - { - if (!elementType.GetTypeInfo().IsAbstract && typeof(OpenXmlPartRootElement).GetTypeInfo().IsAssignableFrom(elementType.GetTypeInfo())) - { - yield return elementType; - } - } - } - - public ElementFactoryCollection Collection => _factoryCollection.Value; - - public bool TryCreate(in OpenXmlQualifiedName qname, [NotNullWhen(true)] out OpenXmlElement? element) - { - element = _factoryCollection.Value.Create(in qname); - return element is not null; - } - } -} diff --git a/src/DocumentFormat.OpenXml/Features/ReflectionPartFactory.cs b/src/DocumentFormat.OpenXml/Features/ReflectionPartFactory.cs deleted file mode 100644 index b4d7cbede..000000000 --- a/src/DocumentFormat.OpenXml/Features/ReflectionPartFactory.cs +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using DocumentFormat.OpenXml.Packaging; -using System; - -namespace DocumentFormat.OpenXml.Features; - -internal class ReflectionPartFactory : IPartFactory -{ - public T Create() where T : OpenXmlPart => Utility.Activator(); - - private static class Utility - { - public static readonly Func Activator = ClassActivator.CreateActivator(typeof(T)); - } -} diff --git a/src/DocumentFormat.OpenXml/Features/TypedFeatures.cs b/src/DocumentFormat.OpenXml/Features/TypedFeatures.cs index c6cba294a..8b3ad51a5 100644 --- a/src/DocumentFormat.OpenXml/Features/TypedFeatures.cs +++ b/src/DocumentFormat.OpenXml/Features/TypedFeatures.cs @@ -29,10 +29,10 @@ public static IFeatureCollection Shared public int Revision => 0; - [KnownFeature(typeof(IRootElementFactory), typeof(ReflectionBasedRootElementFactory))] + [KnownFeature(typeof(IRootElementFactory), typeof(TypedRootElementFactory))] [KnownFeature(typeof(IPartMetadataFeature), typeof(CachedPartMetadataProvider))] [KnownFeature(typeof(IOpenXmlNamespaceResolver), typeof(OpenXmlNamespaceResolver))] - [KnownFeature(typeof(IPartFactory), typeof(ReflectionPartFactory))] + [KnownFeature(typeof(IPartFactory), typeof(TypedPartFactory))] [DelegatedFeature(nameof(FeatureCollection.Default), typeof(FeatureCollection))] [ThreadSafe] public partial T? Get(); diff --git a/src/DocumentFormat.OpenXml/Features/TypedRootElementFactory.cs b/src/DocumentFormat.OpenXml/Features/TypedRootElementFactory.cs new file mode 100644 index 000000000..962ce8f86 --- /dev/null +++ b/src/DocumentFormat.OpenXml/Features/TypedRootElementFactory.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using DocumentFormat.OpenXml.Framework; +using DocumentFormat.OpenXml.Framework.Metadata; +using System; +using System.Diagnostics.CodeAnalysis; + +namespace DocumentFormat.OpenXml.Features; + +internal partial class TypedRootElementFactory : IRootElementFactory +{ + private readonly Lazy _collection; + + public TypedRootElementFactory() + { + _collection = new Lazy(() => new ElementFactoryCollection(GetAllRootElements())); + } + + public ElementFactoryCollection Collection => _collection.Value; + + public bool TryCreate(in OpenXmlQualifiedName qname, [NotNullWhen(true)] out OpenXmlElement? element) + { + element = _collection.Value.Create(in qname); + return element is not null; + } +} diff --git a/src/DocumentFormat.OpenXml/Packaging/OpenXmlPartContainer.cs b/src/DocumentFormat.OpenXml/Packaging/OpenXmlPartContainer.cs index 0d8db09dc..6124ca017 100644 --- a/src/DocumentFormat.OpenXml/Packaging/OpenXmlPartContainer.cs +++ b/src/DocumentFormat.OpenXml/Packaging/OpenXmlPartContainer.cs @@ -1077,9 +1077,13 @@ internal T AddNewPartInternal() { ThrowIfObjectDisposed(); - // use reflection to create the instance. As the default constructor of part is not "public" var part = Features.GetRequired().Create(); + if (part is null) + { + throw new OpenXmlPackageException(ExceptionMessages.AddedPartIsNotAllowed); + } + try { InitPart(part, part.ContentType); @@ -1145,6 +1149,11 @@ internal T AddNewPartInternal(string? contentType, string? id) throw new ArgumentOutOfRangeException(nameof(T), ExceptionMessages.ExtendedPartNotAllowed); } + if (part is null) + { + throw new OpenXmlPackageException(ExceptionMessages.AddedPartIsNotAllowed); + } + if (contentType is not null && part.IsContentTypeFixed && !string.Equals(contentType, part.ContentType, StringComparison.Ordinal)) { throw new ArgumentOutOfRangeException(nameof(contentType), ExceptionMessages.ErrorContentType); diff --git a/test/DocumentFormat.OpenXml.Packaging.Tests/PartConstraintCollectionTests.cs b/test/DocumentFormat.OpenXml.Packaging.Tests/PartConstraintCollectionTests.cs index 859334a75..969e7c815 100644 --- a/test/DocumentFormat.OpenXml.Packaging.Tests/PartConstraintCollectionTests.cs +++ b/test/DocumentFormat.OpenXml.Packaging.Tests/PartConstraintCollectionTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using DocumentFormat.OpenXml.Features; using DocumentFormat.OpenXml.Framework; using System; using System.IO; @@ -85,6 +86,8 @@ public void PartsAreInherited() using (var m = new MemoryStream()) using (var doc = SpreadsheetDocument.Create(m, SpreadsheetDocumentType.Workbook, true)) { + doc.Features.Set(new CustomFactory(doc.Features.Get())); + var wb = doc.AddWorkbookPart(); // Adding new worksheet part using custom worksheetpart derived class @@ -94,11 +97,29 @@ public void PartsAreInherited() } } -#pragma warning disable CA1812 private sealed class PsWorksheetPart : WorksheetPart { } -#pragma warning restore CA1812 + + private class CustomFactory : IPartFactory + { + private readonly IPartFactory _other; + + public CustomFactory(IPartFactory other) + { + _other = other; + } + + public T Create() where T : OpenXmlPart + { + if (typeof(T) == typeof(PsWorksheetPart)) + { + return (T)(object)new PsWorksheetPart(); + } + + return _other.Create(); + } + } [RelationshipType(Relationship)] private class ConstraintTest1