diff --git a/packages/runtime/spec/field-mask-utils.spec.ts b/packages/runtime/spec/field-mask-utils.spec.ts new file mode 100644 index 00000000..0a275d0e --- /dev/null +++ b/packages/runtime/spec/field-mask-utils.spec.ts @@ -0,0 +1,130 @@ +import {fieldMaskUtils} from "../src/field-mask-utils"; + +const { + canonicalForm, + from, + intersect, + union, +} = fieldMaskUtils; + +/** + * Other FieldMaskUtils test can be found in the test-generated package. + * These tests operater purely on plain strings and don't require any MessageTypes. + */ +describe('FieldMaskUtils', function () { + // https://github.com/protocolbuffers/protobuf/blob/c9d2bd2fc781/python/google/protobuf/internal/well_known_types_test.py#L463 + describe('canonicalForm()', function () { + it('sorts the paths', () => { + const mask = from('baz.quz,bar,foo'); + const outMask = from(''); + canonicalForm(mask, outMask); + expect(outMask.paths).toEqual(['bar', 'baz.quz', 'foo']); + }); + + it('deduplicates the paths', () => { + const mask = from('foo,bar,foo'); + const outMask = canonicalForm(mask); + expect(outMask).toEqual(from('bar,foo')); + }); + + it('removes sub-paths or other paths', () => { + const mask = from('foo.b1,bar.b1,foo.b2,bar'); + const outMask = canonicalForm(mask); + expect(outMask).toEqual(from('bar,foo.b1,foo.b2')); + }); + + it('handles more deeply nested cases', () => { + let mask = from([ + 'foo.bar.baz1', + 'foo.bar.baz2.quz', + 'foo.bar.baz2', + ]); + let outMask = canonicalForm(mask); + expect(outMask).toEqual(from('foo.bar.baz1,foo.bar.baz2')); + + mask = from([ + 'foo.bar.baz1', + 'foo.bar.baz2', + 'foo.bar.baz2.quz', + ]); + outMask = canonicalForm(mask); + expect(outMask).toEqual(from('foo.bar.baz1,foo.bar.baz2')); + + mask = from([ + 'foo.bar.baz1', + 'foo.bar.baz2', + 'foo.bar.baz2.quz', + 'foo.bar', + ]); + outMask = canonicalForm(mask); + expect(outMask).toEqual(from('foo.bar')); + + mask = from([ + 'foo.bar.baz1', + 'foo.bar.baz2', + 'foo.bar.baz2.quz', + 'foo', + ]); + outMask = canonicalForm(mask); + expect(outMask).toEqual(from('foo')); + }); + }); + + // https://github.com/protocolbuffers/protobuf/blob/c9d2bd2fc781/python/google/protobuf/internal/well_known_types_test.py#L499 + describe('union()', function () { + it('handles no overlap', () => { + const expected = from('bar,baz,foo,quz'); + const outMask = from(''); + const mask1 = from('foo,baz'); + const mask2 = from('bar,quz'); + expect(union(mask1, mask2)).toEqual(expected); + expect(outMask).not.toEqual(expected); + expect(union(mask1, mask2, outMask)).toEqual(expected); + expect(outMask).toEqual(expected); + }); + it('handles overlap with duplicate paths', () => { + const mask1 = from('foo,baz.bb'); + const mask2 = from('baz.bb,quz'); + expect(union(mask1, mask2)).toEqual(from('baz.bb,foo,quz')); + }); + it('handles overlap with paths covering some other paths', () => { + const mask1 = from('foo.bar.baz,quz'); + const mask2 = from('foo.bar,bar'); + expect(union(mask1, mask2)).toEqual(from('bar,foo.bar,quz')); + }); + }); + + // https://github.com/protocolbuffers/protobuf/blob/c9d2bd2fc781/python/google/protobuf/internal/well_known_types_test.py#L521 + describe('intersect()', function () { + it('handles no overlap', () => { + const mask1 = from('foo,baz'); + const mask2 = from('bar,quz'); + expect(intersect(mask1, mask2)).toEqual(from('')); + }); + it('handles overlap with duplicate paths', () => { + const mask1 = from('foo,baz.bb'); + const mask2 = from('baz.bb,quz'); + expect(intersect(mask1, mask2)).toEqual(from('baz.bb')); + }); + it('handles overlap with paths covering some other paths', () => { + const expected = from('foo.bar.baz'); + const mask1 = from('foo.bar.baz,quz'); + const mask2 = from('foo.bar,bar'); + expect(intersect(mask1, mask2)).toEqual(expected); + expect(intersect(mask2, mask1)).toEqual(expected); + }); + it('handles intersect "" with ""', () => { + const mask1 = from(''); + const mask2 = from(''); + mask1.paths.push(''); + mask2.paths.push(''); + expect(mask1.paths).toEqual(['']); + expect(intersect(mask1, mask2)).toEqual(from('')); + }); + it('handles overlap with unsorted fields', () => { + const mask1 = from('baz.bb,foo'); + const mask2 = from('quz,baz.bb'); + expect(intersect(mask1, mask2)).toEqual(from('baz.bb')); + }); + }); +}); diff --git a/packages/runtime/spec/reflection-merge-partial.spec.ts b/packages/runtime/spec/reflection-merge-partial.spec.ts index e7ab153b..c79210b7 100644 --- a/packages/runtime/spec/reflection-merge-partial.spec.ts +++ b/packages/runtime/spec/reflection-merge-partial.spec.ts @@ -2,6 +2,7 @@ import { IMessageType, MessageInfo, MessageType, + MergeOptions, normalizeFieldInfo, reflectionCreate, reflectionMergePartial, @@ -105,12 +106,12 @@ describe('reflectionMergePartial()', () => { describe('and source field empty', () => { const source: object = {child: undefined}; it('does not touch target', () => { - const target: any = {child: 123}; + const target: any = {child: 123, children: []}; reflectionMergePartial(messageInfo, target, source); expect(target.child).toBe(123); }); it('does not call child handler', () => { - reflectionMergePartial(messageInfo, {}, source); + reflectionMergePartial(messageInfo, {children: []}, source); expect(childHandler.create).not.toHaveBeenCalled(); expect(childHandler.mergePartial).not.toHaveBeenCalled(); }); @@ -119,12 +120,12 @@ describe('reflectionMergePartial()', () => { describe('and source field null', () => { const source: object = {child: null}; it('does not touch target', () => { - const target: any = {child: 123}; + const target: any = {child: 123, children: []}; reflectionMergePartial(messageInfo, target, source); expect(target.child).toBe(123); }); it('does not call child handler', () => { - reflectionMergePartial(messageInfo, {}, source); + reflectionMergePartial(messageInfo, {children: []}, source); expect(childHandler.create).not.toHaveBeenCalled(); expect(childHandler.mergePartial).not.toHaveBeenCalled(); }); @@ -133,13 +134,13 @@ describe('reflectionMergePartial()', () => { describe('and target field empty', () => { it('calls child handler´s create()', () => { const source = {child: {other_msg_fake_field: true}}; - const target = {child: undefined}; + const target = {child: undefined, children: []}; reflectionMergePartial(messageInfo, target, source); expect(childHandler.create).toHaveBeenCalled(); expect(childHandler.create).toHaveBeenCalledWith(source.child); }); it('uses child handler´s create()', () => { - const target: any = {}; + const target: any = {children: []}; const source = {child: {}}; reflectionMergePartial(messageInfo, target, source); expect(target.child).toEqual(handlerCreateReturn); @@ -149,10 +150,10 @@ describe('reflectionMergePartial()', () => { describe('and target field non-empty', () => { it('calls child handler´s mergePartial()', () => { const source = {child: {other_msg_fake_field: true}}; - const target = {child: {other_msg_fake_field: false}}; + const target = {child: {other_msg_fake_field: false}, children: []}; reflectionMergePartial(messageInfo, target, source); expect(childHandler.mergePartial).toHaveBeenCalled(); - expect(childHandler.mergePartial).toHaveBeenCalledWith({other_msg_fake_field: false}, {other_msg_fake_field: true}); + expect(childHandler.mergePartial).toHaveBeenCalledWith({other_msg_fake_field: false}, {other_msg_fake_field: true}, MergeOptions.defaults); }); }); diff --git a/packages/runtime/src/field-mask-utils.ts b/packages/runtime/src/field-mask-utils.ts new file mode 100644 index 00000000..6a6bbb0a --- /dev/null +++ b/packages/runtime/src/field-mask-utils.ts @@ -0,0 +1,336 @@ +import {MergeOptions} from "./merge-options"; +import {mergeFromFieldValue} from './reflection-merge-partial'; +import type {IMessageType, PartialMessage} from "./message-type-contract"; +import type {MessageInfo} from "./reflection-info"; +import type {UnknownMessage, UnknownOneofGroup} from "./unknown-types"; + +export interface FieldMaskLike { + paths: string[]; +} + +type ReadonlyPaths = ReadonlyArray; + +interface ReadonlyFieldMaskLike { + readonly paths: ReadonlyPaths; +} + +type FieldMaskOrPaths = ReadonlyFieldMaskLike | ReadonlyPaths | string; + +type FieldMaskTreeNode = Map; + +function isPaths(mask: ReadonlyFieldMaskLike | ReadonlyPaths): mask is ReadonlyPaths { + return Array.isArray(mask); +} + +function getPaths(mask: FieldMaskOrPaths): string[] { + const paths = typeof mask === 'string' + ? mask.split(',') + : isPaths(mask) + ? mask + : mask.paths; + return paths.filter(Boolean); +} + +/** Returns FieldMask from a: string, string array, or FieldMask */ +export function fieldMaskFrom(mask: FieldMaskOrPaths, target?: T): T; +export function fieldMaskFrom(mask: FieldMaskOrPaths, target: FieldMaskLike = {paths: []}): FieldMaskLike { + target.paths.length = 0; + target.paths.push(...getPaths(mask)); + return target; +} + +export function fieldMaskFromFieldNumbers(messageType: MessageInfo, fiNumbers: ReadonlyArray, target?: T): T; +export function fieldMaskFromFieldNumbers(messageType: MessageInfo, fiNumbers: ReadonlyArray, target: FieldMaskLike = {paths: []}): FieldMaskLike { + target.paths.length = 0; + const {fields, typeName} = messageType; + for (const no of fiNumbers) { + const field = fields.find((fi) => fi.no === no) + if (!field) + throw new TypeError(`Cannot find field number ${no} in message type ${typeName}.`); + target.paths.push(field.name); + } + return target; +} + +/** Checks whether the FieldMask is valid for MessageInfo */ +export function fieldMaskIsValid(messageType: MessageInfo, mask: FieldMaskOrPaths): boolean { + for (const path of getPaths(mask)) + if (!isValidPath(messageType, path)) + return false; + return true; +} + +/** Gets all direct fields of MessageInfo to FieldMask. */ +export function fieldMaskFromMessageType({fields}: MessageInfo, target?: T): T; +export function fieldMaskFromMessageType({fields}: MessageInfo, target: FieldMaskLike = {paths: []}): FieldMaskLike { + target.paths.length = 0; + for (const field of fields) + target.paths.push(field.name); + return target; +} + +/** + * Converts a FieldMask to the canonical form. + * + * Removes paths that are covered by another path. For example, + * "foo.bar" is covered by "foo" and will be removed if "foo" + * is also in the FieldMask. Then sorts all paths in alphabetical order. + */ +export function fieldMaskCanonicalForm(mask: FieldMaskLike, target?: T): T { + return new FieldMaskTree(mask).toFieldMask(target); +} + +/** Merges mask1 and mask2 into a target FieldMask */ +export function fieldMaskUnion(mask1: FieldMaskOrPaths, mask2: FieldMaskOrPaths, target?: T): T { + return new FieldMaskTree(mask1).mergeFromFieldMask(mask2).toFieldMask(target); +} + +/** Intersects mask1 and mask2 into a target FieldMask */ +export function fieldMaskIntersect(mask1: FieldMaskOrPaths, mask2: FieldMaskOrPaths, target?: T): T { + const tree = new FieldMaskTree(mask1); + const intersection = new FieldMaskTree(); + for (const path of getPaths(mask2)) + tree.intersectPath(path, intersection); + return intersection.toFieldMask(target); +} + +/** + * Merges fields specified in FieldMask from source to target. + * Note that this merge behavior is different from protobuf-ts's + * MessageType.mergePartial() or reflectionMergePartial(). + * By default it follows the canonical protobuf behavior of + * appending repeated fields instead of replacing them. + */ +export function fieldMaskMergeMessage< + T extends object = object, + M extends MessageInfo = MessageInfo +>( + mask: FieldMaskOrPaths, + messageType: M extends IMessageType ? M extends IMessageType ? M : never : MessageInfo, + target: T, + source: PartialMessage, + mergeOptions?: MergeOptions +): T { + new FieldMaskTree(mask).mergeMessage(messageType, target, source, mergeOptions); + return target; +} + +export const fieldMaskUtils = { + /** + * Converts a FieldMask to the canonical form. + * + * Removes paths that are covered by another path. For example, + * "foo.bar" is covered by "foo" and will be removed if "foo" + * is also in the FieldMask. Then sorts all paths in alphabetical order. + */ + canonicalForm: fieldMaskCanonicalForm, + /** Returns FieldMask from a: string, string array, or FieldMask */ + from: fieldMaskFrom, + fromFieldNumbers: fieldMaskFromFieldNumbers, + /** Gets all direct fields of MessageInfo to FieldMask. */ + fromMessageType: fieldMaskFromMessageType, + /** Intersects mask1 and mask2 into a target FieldMask */ + intersect: fieldMaskIntersect, + /** Checks whether the FieldMask is valid for MessageInfo */ + isValid: fieldMaskIsValid, + /** + * Merges fields specified in FieldMask from source to target. + * Note that this merge behavior is different from protobuf-ts's + * MessageType.mergePartial() or reflectionMergePartial(). + * By default it follows the canonical protobuf behavior of + * appending repeated fields instead of replacing them. + */ + mergeMessage: fieldMaskMergeMessage, + /** Merges mask1 and mask2 into a target FieldMask */ + union: fieldMaskUnion, +}; + +/** + * Represents a FieldMask in a tree structure. Each leaf + * node in this tree represent a field path in the FieldMask. + * For example, given a FieldMask `"foo.bar,foo.baz,bar.baz"`, + * the FieldMaskTree will be: + * ``` + * [root] -+- foo -+- bar + * | | + * | +- baz + * | + * +- bar --- baz + * ``` + */ +class FieldMaskTree { + private root: FieldMaskTreeNode = new Map(); + + constructor(fieldMask?: FieldMaskOrPaths) { + if (fieldMask) + this.mergeFromFieldMask(fieldMask); + } + + /** Merges a FieldMask to the tree. */ + mergeFromFieldMask(fieldMask: FieldMaskOrPaths): this { + for (const path of getPaths(fieldMask).sort()) + this.addPath(path); + return this; + } + + /** + * Adds a field path into the tree. + * + * If the field path to add is a sub-path of an existing field path + * in the tree (i.e., a leaf node), it means the tree already matches + * the given path so nothing will be added to the tree. If the path + * matches an existing non-leaf node in the tree, that non-leaf node + * will be turned into a leaf node with all its children removed because + * the path matches all the node's children. Otherwise, a new path will + * be added. + */ + addPath(path: string): this { + let node = this.root; + for (const part of path.split('.')) { + let nextNode = node.get(part); + if (!nextNode) + node.set(part, nextNode = new Map()); + else if (!nextNode.size) + // Pre-existing empty node implies we already have this entire tree. + return this; + node = nextNode; + } + // Remove any sub-trees we might have had. + node.clear(); + return this; + } + + /** Converts the tree to a FieldMask. */ + toFieldMask(mask?: T): T; + toFieldMask(mask: FieldMaskLike = {paths: []}): FieldMaskLike { + mask.paths.length = 0; + addFieldPaths(this.root, '', mask); + return mask; + } + + /** Calculates the intersection part of a field path with this tree. */ + intersectPath(path: string, intersection: FieldMaskTree): this { + let node = this.root; + for (const part of path.split('.')) { + let nextNode = node.get(part); + if (!nextNode) + return this; + else if (!nextNode.size) { + intersection.addPath(path); + return this; + } + node = nextNode; + } + intersection.addLeafNodes(path, node); + return this; + } + + /** Adds leaf nodes begin with prefix to this tree. */ + addLeafNodes(prefix: string, node: FieldMaskTreeNode): this { + if (!node.size) + this.addPath(prefix); + for (const [name, nextNode] of node) + this.addLeafNodes(`${prefix}.${name}`, nextNode); + return this; + } + + /** Merge all fields specified by this tree from source to target. */ + mergeMessage( + messageType: MessageInfo, + target: T, + source: PartialMessage, + mergeOptions: MergeOptions = {}, + ): this { + mergeMessageIntoTree( + this.root, + messageType, + target, + source, + // Use the canonical protobuf merge options + // (append instead of replace for repeated) + {repeated: MergeOptions.Repeated.APPEND, ...mergeOptions} + ) + return this; + } +} + + + +/** Merge all fields specified by a sub-tree from source to target. */ +function mergeMessageIntoTree( + node: FieldMaskTreeNode, + messageType: MessageInfo, + target: T, + source: PartialMessage, + maybeMergeOptions?: MergeOptions, +) { + const + {typeName, fields} = messageType, + mergeOptions = MergeOptions.withDefaults(maybeMergeOptions); + + for (const [childName, child] of node) { + const field = fields.find((fi) => fi.name === childName); + if (!field) + throw new TypeError(`Cannot find field ${childName} in message type ${typeName}.`); + + let name = field.localName, + src: UnknownMessage | UnknownOneofGroup = source as UnknownMessage, + out: UnknownMessage | UnknownOneofGroup = target as UnknownMessage; + + if (field.oneof) { + let sourceGroup = (source as UnknownMessage)[field.oneof] as UnknownOneofGroup | undefined, + targetGroup = (target as UnknownMessage)[field.oneof] as UnknownOneofGroup; + if (sourceGroup?.oneofKind !== name) + continue; + delete targetGroup[targetGroup.oneofKind!]; + targetGroup.oneofKind = name; + src = sourceGroup; + out = targetGroup; + } + + let fieldValue = src[name]; + + if (child.size) { + // Sub-paths are only allowed for singular message fields. + if (field.repeat || field.kind !== 'message') + throw new TypeError(`Field ${childName} in message ${typeName} ` + + `is not a singular message field and cannot have sub-fields.`); + if (fieldValue != undefined) { + const T = field.T(); + mergeMessageIntoTree( + child, + T, + (out[name] || (out[name] = T.create())) as UnknownMessage, + fieldValue as UnknownMessage, + mergeOptions + ); + } + continue; + } + + mergeFromFieldValue(field, fieldValue, out, mergeOptions); + } +} + +/** Checks whether the path is valid for MessageInfo */ +function isValidPath({fields}: MessageInfo, path: string): boolean { + const parts = path.split('.'); + let last = parts.pop(); + for (const name of parts) { + const field = fields.find((fi) => fi.name === name); + if (!field || field.repeat || field.kind !== 'message') + return false; + fields = field.T().fields; + } + return fields.some((fi) => fi.name === last); +} + +/** Adds the field paths descended from node to FieldMask. */ +function addFieldPaths(node: FieldMaskTreeNode, prefix: string, mask: FieldMaskLike): void { + if (!node.size && prefix) + mask.paths.push(prefix); + else + for (const name of Array.from(node.keys()).sort()) + addFieldPaths(node.get(name)!, prefix ? `${prefix}.${name}` : name, mask); + +} diff --git a/packages/runtime/src/index.ts b/packages/runtime/src/index.ts index c872e193..a8fc57ae 100644 --- a/packages/runtime/src/index.ts +++ b/packages/runtime/src/index.ts @@ -36,6 +36,9 @@ export { JsonReadOptions, JsonWriteOptions, JsonWriteStringOptions, jsonReadOptions, jsonWriteOptions, mergeJsonOptions } from './json-format-contract'; +// Merge options, types, and defaults +export {MergeOptions} from './merge-options'; + // Message type contract export {IMessageType, PartialMessage, MESSAGE_TYPE} from './message-type-contract'; @@ -78,6 +81,20 @@ export {ReflectionJsonReader} from './reflection-json-reader'; export {ReflectionJsonWriter} from './reflection-json-writer'; export {containsMessageType, MessageTypeContainer} from './reflection-contains-message-type'; +// FieldMask utils +export { + FieldMaskLike, + fieldMaskUtils, + fieldMaskCanonicalForm, + fieldMaskFrom, + fieldMaskFromFieldNumbers, + fieldMaskFromMessageType, + fieldMaskIntersect, + fieldMaskIsValid, + fieldMaskMergeMessage, + fieldMaskUnion, +} from './field-mask-utils'; + // Oneof helpers export {isOneofGroup, setOneofValue, getOneofValue, clearOneofValue, getSelectedOneofValue} from './oneof'; diff --git a/packages/runtime/src/merge-options.ts b/packages/runtime/src/merge-options.ts new file mode 100644 index 00000000..d9376554 --- /dev/null +++ b/packages/runtime/src/merge-options.ts @@ -0,0 +1,173 @@ +const MERGE_OPTIONS_DEFAULTED = Symbol.for("protobuf-ts/merge-options-defaulted"); + +export interface MergeOptions { + /** + * Merge options for map fields + * @default MergeOptions.Map.SHALLOW + */ + map?: MergeOptions.Map; + /** + * Merge options for repeated fields + * @default MergeOptions.Repeated.REPLACE + */ + repeated?: MergeOptions.Repeated; + /** + * Merge options for singular message fields + * @default MergeOptions.ReplaceMessages.NEVER + */ + replaceMessages?: MergeOptions.ReplaceMessages; + /** + * The MergeOptions have already been defaulted. + * Avoids excess comparisons, GC. + */ + [MERGE_OPTIONS_DEFAULTED]?: true; +} + +export namespace MergeOptions { + export type NonNullable = { + [k in keyof MergeOptions]-?: MergeOptions[k]; + } + + /** Merge options for map fields */ + export enum Map { + /** + * Source will replace target + * ```ts + * target = { foo: { a: 9, b: true }, bar: { a: 8, b: true } } + * source = { foo: { b: false }, baz: { a: 1 } } + * result = { foo: { b: false }, baz: { a: 1 } } + * ``` + */ + REPLACE = 1, + /** + * Source will overwrite target by key + * (canonical protobuf behavior, default for protobuf-ts) + * ```ts + * target = { foo: { a: 9, b: true }, bar: { a: 8, b: true } } + * source = { foo: { b: false }, baz: { a: 1 } } + * result = { foo: { b: false }, bar: { a: 8, b: true }, baz: { a: 1 } } + * ``` + */ + SHALLOW = 2, + /** + * Source will recursively merge + * ```ts + * target = { foo: { a: 9, b: true }, bar: { a: 8, b: true } } + * source = { foo: { b: false }, baz: { a: 1 } } + * result = { foo: { a: 9, b: false }, bar: { a: 8, b: true }, baz: { a: 1 } } + * ``` + */ + DEEP = 3, + } + + /** Merge options for repeated fields */ + export enum Repeated { + /** + * Source will append to target + * (canonical protobuf behavior) + * ```ts + * target = [{ a: 9, b: true }, { a: 8, b: true } ] + * source = [ { a: 1 }] + * result = [{ a: 9, b: true }, { a: 8, b: true }, { a: 1 }] + * ``` + */ + APPEND = 0, + /** + * Source will replace target + * (default for protobuf-ts) + * ```ts + * target = [{ a: 9, b: true }, { a: 8, b: true }] + * source = [{ a: 1 } ] + * result = [{ a: 1 } ] + * ``` + */ + REPLACE = 1, + /** + * Source will overwrite target by index + * ```ts + * target = [{ a: 9, b: true }, { a: 8, b: true }] + * source = [{ a: 1 } ] + * result = [{ a: 1 }, { a: 8, b: true }] + * ``` + */ + SHALLOW = 2, + /** + * Source will deeply merge target by index + * ```ts + * target = [{ a: 9, b: true }, { a: 8, b: true }] + * source = [{ a: 1 } ] + * result = [{ a: 1, b: true }, { a: 8, b: true }] + * ``` + */ + DEEP = 3, + } + + /** Merge options for singular message fields */ + export enum ReplaceMessages { + /** + * If singlular message field in source is set, it will be + * merged into target singular message field. + * (canonical protobuf behavior, default for protobuf-ts) + * ```ts + * // when set in source + * target = { msg: { a: 9, b: true }, str: "A" } + * source = { msg: { b: false }, str: "B" } + * result = { msg: { a: 9 b: false }, str: "B" } + * + * // when unset in source + * target = { msg: { a: 9, b: true }, str: "A" } + * source = { str: "B" } + * result = { msg: { a: 9, b: true }, str: "B" } + * ``` + */ + NEVER = 0, + /** + * If singlular message field in source is set, it will replace + * target singular message field. + * ```ts + * // when set in source + * target = { msg: { a: 9, b: true }, str: "A" } + * source = { msg: { b: false }, str: "B" } + * result = { msg: { b: false }, str: "B" } + * + * // when unset in source + * target = { msg: { a: 9, b: true }, str: "A" } + * source = { str: "B" } + * result = { msg: { a: 9, b: true }, str: "B" } + * ``` + */ + IF_SET = 1, + /** + * Source singular message field will replace target singular + * message field even if source field is unset. + * ```ts + * // when set in source + * target = { msg: { a: 9, b: true }, str: "A" } + * source = { msg: { b: false }, str: "B" } + * result = { msg: { b: false }, str: "B" } + * + * // when unset in source + * target = { msg: { a: 9, b: true }, str: "A" } + * source = { str: "B" } + * result = { str: "B" } + * ``` + */ + ALWAYS = 2, + } + + export const defaults: MergeOptions.NonNullable = { + map: MergeOptions.Map.SHALLOW, + repeated: MergeOptions.Repeated.REPLACE, + replaceMessages: MergeOptions.ReplaceMessages.NEVER, + [MERGE_OPTIONS_DEFAULTED]: true, + } + + export function withDefaults(o?: MergeOptions): MergeOptions.NonNullable { + return !o ? MergeOptions.defaults : o[MERGE_OPTIONS_DEFAULTED] ? o as MergeOptions.NonNullable : { + map: o.map ?? MergeOptions.defaults.map, + repeated: o.repeated ?? MergeOptions.defaults.repeated, + replaceMessages: o.replaceMessages ?? MergeOptions.defaults.replaceMessages, + [MERGE_OPTIONS_DEFAULTED]: true + }; + } +} diff --git a/packages/runtime/src/message-type-contract.ts b/packages/runtime/src/message-type-contract.ts index d4590101..78c817d3 100644 --- a/packages/runtime/src/message-type-contract.ts +++ b/packages/runtime/src/message-type-contract.ts @@ -2,6 +2,7 @@ import type {FieldInfo, MessageInfo} from "./reflection-info"; import type {BinaryReadOptions, BinaryWriteOptions, IBinaryReader, IBinaryWriter} from "./binary-format-contract"; import type {JsonValue} from "./json-typings"; import type {JsonReadOptions, JsonWriteOptions, JsonWriteStringOptions} from "./json-format-contract"; +import type {MergeOptions} from "./merge-options"; /** * The symbol used as a key on message objects to store the message type. @@ -129,23 +130,23 @@ export interface IMessageType extends MessageInfo { * If a singular scalar or enum field is present in the source, it * replaces the field in the target. * - * If a singular message field is present in the source, it is merged - * with the target field by calling mergePartial() of the responsible - * message type. + * By default if a singular message field is present in the source, + * it is merged with the target field by calling mergePartial() of + * the responsible message type. * - * If a repeated field is present in the source, its values replace - * all values in the target array, removing extraneous values. - * Repeated message fields are copied, not merged. + * By default if a repeated field is present in the source, its values + * replace all values in the target array, removing extraneous values. + * By default repeated message fields are copied, not merged. * - * If a map field is present in the source, entries are added to the - * target map, replacing entries with the same key. Entries that only - * exist in the target remain. Entries with message values are copied, - * not merged. + * By default if a map field is present in the source, entries are added + * to the target map, replacing entries with the same key. Entries that + * only exist in the target remain. By default, entries with message + * values are copied, not merged. * - * Note that this function differs from protobuf merge semantics, - * which appends repeated fields. + * Note that this function's defaults differs from protobuf merge + * semantics, which appends repeated fields. */ - mergePartial(target: T, source: PartialMessage): void; + mergePartial(target: T, source: PartialMessage, mergeOptions?: MergeOptions): void; /** diff --git a/packages/runtime/src/message-type.ts b/packages/runtime/src/message-type.ts index eb3cae49..fd98f043 100644 --- a/packages/runtime/src/message-type.ts +++ b/packages/runtime/src/message-type.ts @@ -1,3 +1,4 @@ +import type {MergeOptions} from "./merge-options"; import type {IMessageType, PartialMessage} from "./message-type-contract"; import type {FieldInfo, PartialFieldInfo} from "./reflection-info"; import {normalizeFieldInfo} from "./reflection-info"; @@ -143,8 +144,8 @@ export class MessageType implements IMessageType { /** * Copy partial data into the target message. */ - mergePartial(target: T, source: PartialMessage): void { - reflectionMergePartial(this, target, source); + mergePartial(target: T, source: PartialMessage, mergeOptions?: MergeOptions): void { + reflectionMergePartial(this, target, source, mergeOptions); } diff --git a/packages/runtime/src/reflection-merge-partial.ts b/packages/runtime/src/reflection-merge-partial.ts index d5cbccf5..351e51c0 100644 --- a/packages/runtime/src/reflection-merge-partial.ts +++ b/packages/runtime/src/reflection-merge-partial.ts @@ -1,100 +1,121 @@ -import type {MessageInfo} from "./reflection-info"; +import {MergeOptions} from "./merge-options"; +import type {FieldInfo} from "./reflection-info"; import type {PartialMessage} from "./message-type-contract"; import type {UnknownMessage, UnknownOneofGroup} from "./unknown-types"; - /** * Copy partial data into the target message. * * If a singular scalar or enum field is present in the source, it * replaces the field in the target. * - * If a singular message field is present in the source, it is merged - * with the target field by calling mergePartial() of the responsible - * message type. + * By default if a singular message field is present in the source, + * it is merged with the target field by calling mergePartial() of + * the responsible message type. * - * If a repeated field is present in the source, its values replace - * all values in the target array, removing extraneous values. - * Repeated message fields are copied, not merged. + * By default if a repeated field is present in the source, its values + * replace all values in the target array, removing extraneous values. + * By default repeated message fields are copied, not merged. * - * If a map field is present in the source, entries are added to the - * target map, replacing entries with the same key. Entries that only - * exist in the target remain. Entries with message values are copied, - * not merged. + * By default if a map field is present in the source, entries are added + * to the target map, replacing entries with the same key. Entries that + * only exist in the target remain. By default, entries with message + * values are copied, not merged. * - * Note that this function differs from protobuf merge semantics, - * which appends repeated fields. + * Note that this function's defaults differs from protobuf merge + * semantics, which appends repeated fields. */ -export function reflectionMergePartial(info: MessageInfo, target: T, source: PartialMessage) { - - let - fieldValue: UnknownMessage[string], // the field value we are working with - input = source as Partial, - output: UnknownMessage | UnknownOneofGroup; // where we want our field value to go - +export function reflectionMergePartial(info: { + /** Simple information for each message field in `T` */ + readonly fields: readonly FieldInfo[] +}, target: T, source: PartialMessage, maybeMergeOptions?: MergeOptions) { + const mergeOptions = MergeOptions.withDefaults(maybeMergeOptions); for (let field of info.fields) { let name = field.localName; - - if (field.oneof) { - const group = input[field.oneof] as UnknownOneofGroup | undefined; // this is the oneof`s group in the source - if (group?.oneofKind == undefined) { // the user is free to omit - continue; // we skip this field, and all other members too - } - fieldValue = group[name]; // our value comes from the the oneof group of the source - output = (target as UnknownMessage)[field.oneof] as UnknownOneofGroup; // and our output is the oneof group of the target - output.oneofKind = group.oneofKind; // always update discriminator - if (fieldValue == undefined) { - delete output[name]; // remove any existing value - continue; // skip further work on field - } - } else { - fieldValue = input[name]; // we are using the source directly - output = target as UnknownMessage; // we want our field value to go directly into the target - if (fieldValue == undefined) { - continue; // skip further work on field, existing value is used as is + if (!field.oneof) + mergeFromFieldValue(field, (source as UnknownMessage)[name], target as UnknownMessage, mergeOptions); + else { + let sourceGroup = (source as UnknownMessage)[field.oneof] as UnknownOneofGroup | undefined, + targetGroup = (target as UnknownMessage)[field.oneof] as UnknownOneofGroup; + if (sourceGroup?.oneofKind === name) { + delete targetGroup[targetGroup.oneofKind!]; + targetGroup.oneofKind = name; + mergeFromFieldValue(field, sourceGroup[name], targetGroup, mergeOptions); } } + } +} - if (field.repeat) - (output[name] as any[]).length = (fieldValue as any[]).length; // resize target array to match source array - - // now we just work with `fieldValue` and `output` to merge the value +export function mergeFromFieldValue( + field: FieldInfo, + fieldValue: UnknownMessage[string], + output: UnknownMessage, + mergeOptions: MergeOptions.NonNullable, +): void { + const name = field.localName; + if (field.repeat) { + if (!fieldValue) + return; + let outArr = output[name] as any[]; + if (mergeOptions.repeated === MergeOptions.Repeated.REPLACE) + outArr.length = 0; + let srcArr = fieldValue as any[], + lo = outArr.length, + hi = lo + srcArr.length; + switch (field.kind) { case "scalar": case "enum": - if (field.repeat) - for (let i = 0; i < (fieldValue as any[]).length; i++) - (output[name] as any[])[i] = (fieldValue as any[])[i]; // not a reference type - else - output[name] = fieldValue; // not a reference type - break; - + for (let i = lo; i < hi; i++) + outArr[i] = srcArr[i - lo]; // elements are not reference types + return; case "message": let T = field.T(); - if (field.repeat) - for (let i = 0; i < (fieldValue as any[]).length; i++) - (output[name] as any[])[i] = T.create((fieldValue as any[])[i]); - else if (output[name] === undefined) - output[name] = T.create(fieldValue as PartialMessage); // nothing to merge with - else - T.mergePartial(output[name], fieldValue as PartialMessage); - break; - - case "map": - // Map and repeated fields are simply overwritten, not appended or merged - switch (field.V.kind) { - case "scalar": - case "enum": - Object.assign(output[name], fieldValue); // elements are not reference types - break; - case "message": - let T = field.V.T(); - for (let k of Object.keys(fieldValue as any)) - (output[name] as any)[k] = T.create((fieldValue as any)[k]); - break; - } - break; - + for (let i = lo; i < hi; i++) + if (mergeOptions.repeated === MergeOptions.Repeated.DEEP && outArr[i]) + T.mergePartial(outArr[i], srcArr[i - lo], mergeOptions); + else + outArr[i] = T.create(srcArr[i - lo]); + return; } + return; + } + + // Only deal with non-repeated values + switch (field.kind) { + case "scalar": + case "enum": + if (fieldValue != undefined) + output[name] = fieldValue; // not a reference type + return; + case "message": + if (!fieldValue) { + if (mergeOptions.replaceMessages === MergeOptions.ReplaceMessages.ALWAYS) + delete output[name]; + } else if (mergeOptions.replaceMessages || !output[name]) + output[name] = field.T().create(fieldValue as PartialMessage); + else + field.T().mergePartial(output[name], fieldValue as PartialMessage, mergeOptions); + return; + case "map": + if (!fieldValue) + return; + let outMap = (mergeOptions.map === MergeOptions.Map.REPLACE + ? output[name] = {} + : output[name]) as any; + switch (field.V.kind) { + case "scalar": + case "enum": + Object.assign(outMap, fieldValue); // elements are not reference types + return; + case "message": + let T = field.V.T(); + for (let [k, v] of Object.entries(fieldValue as any)) + if (mergeOptions.map === MergeOptions.Map.DEEP && outMap[k]) + T.mergePartial(outMap[k], v as PartialMessage, mergeOptions); + else + outMap[k] = T.create(v as UnknownMessage); + return; + } } } diff --git a/packages/test-generated/spec/field-mask-utils.spec.ts b/packages/test-generated/spec/field-mask-utils.spec.ts new file mode 100644 index 00000000..dde6af9e --- /dev/null +++ b/packages/test-generated/spec/field-mask-utils.spec.ts @@ -0,0 +1,238 @@ +import {FieldMask} from "../ts-out/google/protobuf/field_mask"; +import { + NestedTestAllTypes, + TestAllTypes, + TestOneof2, +} from "../ts-out/google/protobuf/unittest"; +import {TestRecursiveMapMessage} from "../ts-out/google/protobuf/map_unittest"; +import {makeInt64Value} from "./support/helpers" +import {fieldMaskUtils, MergeOptions} from "@protobuf-ts/runtime"; +import {join} from "path"; +import {readFileSync} from "fs"; + +const { + from, + fromFieldNumbers, + fromMessageType, + isValid, + mergeMessage, +} = fieldMaskUtils; + +/** + * Other FieldMaskUtils tests can be found in the runtime package. + * These tests utilize MessageTypes that are generated from protobuf/src/google/protobuf and so must remain here. + */ +describe('FieldMaskUtils', function () { + // https://github.com/protocolbuffers/protobuf/blob/c9d2bd2fc781/python/google/protobuf/internal/well_known_types_test.py#L423 + describe('fromMessageType()', function () { + it('returns a FieldMask of all top-level fields in MessageType', () => { + const {fields} = TestAllTypes; + const mask = fromMessageType(TestAllTypes); + expect(fields.length).toEqual(mask.paths.length); + expect(isValid(TestAllTypes, mask)).toBeTrue(); + for (const fi of fields) + expect(mask.paths).toContain(fi.name); + }); + }); + + // https://github.com/protocolbuffers/protobuf/blob/c9d2bd2fc781/java/util/src/test/java/com/google/protobuf/util/FieldMaskUtilTest.java#L126 + describe('fromFieldNumbers()', function () { + it('returns a FieldMask for a MessageType given an array of field numbers in the message', () => { + let mask = fromFieldNumbers(TestAllTypes, []); + expect(mask.paths).toEqual([]); + mask = fromFieldNumbers(TestAllTypes, [1]); + expect(mask.paths).toEqual(['optional_int32']); + mask = fromFieldNumbers(TestAllTypes, [1, 2]); + expect(mask.paths).toEqual(['optional_int32', 'optional_int64']); + expect(() => { + mask = fromFieldNumbers(TestAllTypes, [1000]); + }).toThrowError(/cannot find field number 1000/i); + }); + }); + + // https://github.com/protocolbuffers/protobuf/blob/c9d2bd2fc781/java/util/src/test/java/com/google/protobuf/util/FieldMaskUtilTest.java#L43 + describe('isValid()', function () { + ([ + [true, 'payload'], + [false, 'nonexist'], + [true, 'payload.optional_int32'], + [true, 'payload.repeated_int32'], + [true, 'payload.optional_nested_message'], + [true, 'payload.repeated_nested_message'], + [true, from('payload')], + [false, from('nonexist')], + [false, from('payload,nonexist')], + [true, 'payload.optional_nested_message.bb'], + [false, 'payload.repeated_nested_message.bb', 'Repeated fields cannot have sub-paths.'], + [false, 'payload.optional_int32.bb', 'Non-message fields cannot have sub-paths.'], + ] as any[]).forEach(([expected, input, failMessage = 'fail']) => { + it(`returns ${String(expected)} for ${JSON.stringify(input)}`, () => { + expect(isValid(NestedTestAllTypes, input)).toBe(expected, failMessage); + }) + }); + }); + + // https://github.com/protocolbuffers/protobuf/blob/c9d2bd2fc781/python/google/protobuf/internal/well_known_types_test.py#L556 + describe('mergeMessage()', function () { + const goldenFilePath = join(__dirname, '../../proto/google/protobuf/testdata/golden_message_oneof_implemented'); + const goldenMessageBinary = new Uint8Array(readFileSync(goldenFilePath)); + it('merges correctly with just one field set', () => { + const src = TestAllTypes.fromBinary(goldenMessageBinary); + TestAllTypes.fields.forEach((fi) => { + if (fi.oneof) + return; + const fiName = fi.name as keyof TestAllTypes; + const fiLocalName = fi.localName as keyof TestAllTypes; + expect(src[fiLocalName]).withContext(`golden message should have set "${fiName}" (field no. ${fi.no})`).toBeDefined(); + const dst = TestAllTypes.create(); + const mask = from(fiName); + mergeMessage(mask, TestAllTypes, dst, src); + const expected = TestAllTypes.create({ [fiLocalName]: src[fiLocalName] }); + expect(dst).withContext(`"${fiName}" (field no. ${fi.no})`).toEqual(expected); + }); + }); + + const src = NestedTestAllTypes.create({ + child: { payload: { optionalInt32: 1234 }, + child: { payload: { optionalInt32: 5678 } }, + }, + }); + + it('merges nested fields', () => { + let dst = NestedTestAllTypes.create(); + let mask = from('child.payload'); + mergeMessage(mask, NestedTestAllTypes, dst, src); + expect(dst.child?.payload?.optionalInt32).toBe(1234); + expect(dst.child?.child?.payload?.optionalInt32).toBeUndefined(); + + mask = from('child.child.payload'); + mergeMessage(mask, NestedTestAllTypes, dst, src); + expect(dst.child?.payload?.optionalInt32).toBe(1234); + expect(dst.child?.child?.payload?.optionalInt32).toBe(5678); + + dst = NestedTestAllTypes.create(); + mergeMessage(mask, NestedTestAllTypes, dst, src); + expect(dst.child?.payload?.optionalInt32).toBeUndefined(); + expect(dst.child?.child?.payload?.optionalInt32).toBe(5678); + + dst = NestedTestAllTypes.create(); + mask = from('child'); + mergeMessage(mask, NestedTestAllTypes, dst, src); + expect(dst.child?.payload?.optionalInt32).toBe(1234); + expect(dst.child?.child?.payload?.optionalInt32).toBe(5678); + }); + + it('(by default) merges message fields. Change the behavior to replace message fields.', () => { + const int64Value = makeInt64Value(4321).value; + const dst = NestedTestAllTypes.create({ + child: { payload: { optionalInt64: int64Value } }, + }); + const mask = from('child.payload'); + // (by default) merges message fields. + mergeMessage(mask, NestedTestAllTypes, dst, src); + expect(dst.child?.payload?.optionalInt32).toBe(1234); + expect(dst.child?.payload?.optionalInt64).toBe(int64Value); + // Change the behavior to replace message fields. + mergeMessage(mask, NestedTestAllTypes, dst, src, { + replaceMessages: MergeOptions.ReplaceMessages.ALWAYS + }); + expect(dst.child?.payload?.optionalInt32).toBe(1234); + expect(dst.child?.payload?.optionalInt64).toBeUndefined(); + }); + + it('(by default) will keep fields if missing in source. But they are cleared when replacing message fields.', () => { + let dst = NestedTestAllTypes.create({ + payload: { optionalInt32: 1234 }, + }); + const mask = from('payload'); + mergeMessage(mask, NestedTestAllTypes, dst, src); + expect(dst.payload?.optionalInt32).toBe(1234); + dst = NestedTestAllTypes.create({ + payload: { optionalInt32: 1234 } + }); + mergeMessage(mask, NestedTestAllTypes, dst, src, { + replaceMessages: MergeOptions.ReplaceMessages.ALWAYS + }); + expect(dst.payload).toBeUndefined(); + }); + + it('(by default) will append repeated fields. Change the behavior to replace repeated fields.', () => { + let src = NestedTestAllTypes.create({ + payload: { repeatedInt32: [1234] }, + }); + let dst = NestedTestAllTypes.create({ + payload: { repeatedInt32: [5678] }, + }); + const mask = FieldMask.fromJson('payload.repeatedInt32'); + // (by default) it will append repeated fields. + mergeMessage(mask, NestedTestAllTypes, dst, src); + expect(dst.payload?.repeatedInt32).toEqual([5678, 1234]); + dst = NestedTestAllTypes.create({ + payload: { repeatedInt32: [5678] }, + }); + // Change the behavior to replace repeated fields. + mergeMessage(mask, NestedTestAllTypes, dst, src, { + repeated: MergeOptions.Repeated.REPLACE + }); + expect(dst.payload?.repeatedInt32).toEqual([1234]); + }); + + it('merges oneof fields', () => { + const mask = FieldMask.fromJson('fooMessage,fooLazyMessage.quxInt'); + let dst = TestOneof2.fromJson({ fooMessage: { quxInt: '1' } }); + + // src does not have any of the foo oneof fields set, so no change to dst + let src = TestOneof2.create(); + mergeMessage(mask, TestOneof2, dst, src); + expect(TestOneof2.toJson(dst)).toEqual({ fooMessage: { quxInt: '1' } }); + + // the oneof foo field which is set in src is not part of the mask, so no change to dst + src = TestOneof2.fromJson({ fooInt: 1 }); + mergeMessage(mask, TestOneof2, dst, src); + expect(TestOneof2.toJson(dst)).toEqual({ fooMessage: { quxInt: '1' } }); + + // the oneof foo field which is set in src is part of the mask, but only partially + src = TestOneof2.fromJson({ fooLazyMessage: { corgeInt: [1], quxInt: '1' } }); + mergeMessage(mask, TestOneof2, dst, src); + expect(TestOneof2.toJson(dst)).toEqual({ fooLazyMessage: { quxInt: '1' } }); + }); + + it('merges map fields', () => { + const emptyMap = TestRecursiveMapMessage.create(); + const srcLevel2 = TestRecursiveMapMessage.create({ a: { + ['src level 2']: TestRecursiveMapMessage.clone(emptyMap), + }}); + const src = TestRecursiveMapMessage.create({ a: { + ['common key']: TestRecursiveMapMessage.clone(srcLevel2), + ['src level 1']: TestRecursiveMapMessage.clone(srcLevel2), + }}); + + const dstLevel2 = TestRecursiveMapMessage.create({ a: { + ['dst level 2']: TestRecursiveMapMessage.clone(emptyMap), + }}); + const dst = TestRecursiveMapMessage.create({ a: { + ['common key']: TestRecursiveMapMessage.clone(dstLevel2), + ['dst level 1']: TestRecursiveMapMessage.clone(emptyMap), + }}); + + const mask = from('a'); + mergeMessage(mask, TestRecursiveMapMessage, dst, src); + + // map from dst is replaced with map from src. + expect(dst.a['common key']).toEqual(srcLevel2); + expect(dst.a['src level 1']).toEqual(srcLevel2); + expect(dst.a['dst level 1']).toEqual(emptyMap); + }); + + it('throws for bad merge paths (repeated fields)', () => { + const src = TestAllTypes.fromBinary(goldenMessageBinary); + const dst = TestAllTypes.create(); + const mask = FieldMask.fromJson('optionalInt32.field'); + expect(() => mergeMessage(mask, TestAllTypes, dst, src)).toThrowError( + 'Field optional_int32 in message protobuf_unittest.TestAllTypes ' + + 'is not a singular message field and cannot have sub-fields.' + ); + }); + }); + +});