Skip to content

Commit

Permalink
Support custom unwrapped unions (#469)
Browse files Browse the repository at this point in the history
Users can now specify a custom projection function to control
value unwrapping. This simplifies a variety of use-cases which
would previously require a type hook and logical type.
  • Loading branch information
joscha authored Sep 29, 2024
1 parent 384b656 commit aeada8c
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 54 deletions.
136 changes: 83 additions & 53 deletions lib/types.js
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class Type {
wrapUnions = 'auto';
} else if (typeof wrapUnions == 'string') {
wrapUnions = wrapUnions.toLowerCase();
} else if (typeof wrapUnions === 'function') {
wrapUnions = 'auto';
}
switch (wrapUnions) {
case 'always':
Expand Down Expand Up @@ -196,11 +198,20 @@ class Type {
let types = schema.map((obj) => {
return Type.forSchema(obj, opts);
});
let projectionFn;
if (!UnionType) {
UnionType = isAmbiguous(types) ? WrappedUnionType : UnwrappedUnionType;
if (typeof opts.wrapUnions === 'function') {
// we have a projection function
projectionFn = opts.wrapUnions(types);
UnionType = typeof projectionFn !== 'undefined'
? UnwrappedUnionType
: WrappedUnionType;
} else {
UnionType = isAmbiguous(types) ? WrappedUnionType : UnwrappedUnionType;
}
}
LOGICAL_TYPE = logicalType;
type = new UnionType(types, opts);
type = new UnionType(types, opts, projectionFn);
} else { // New type definition.
type = (function (typeName) {
let Type = TYPES[typeName];
Expand Down Expand Up @@ -341,10 +352,10 @@ class Type {
return branchTypes[name];
}), opts);
} catch (err) {
opts.wrapUnions = wrapUnions;
throw err;
} finally {
opts.wrapUnions = wrapUnions;
}
opts.wrapUnions = wrapUnions;
return unionType;
}

Expand Down Expand Up @@ -1226,6 +1237,60 @@ UnionType.prototype._branchConstructor = function () {
throw new Error('unions cannot be directly wrapped');
};


function generateProjectionIndexer(projectionFn) {
return (val) => {
const index = projectionFn(val);
if (typeof index !== 'number') {
throw new Error(`Projected index '${index}' is not valid`);
}
return index;
};
}

/**
 * Build the default value-to-branch-index function for an unwrapped union.
 *
 * Branches which are `abstract` or `logical` types are collected in a list
 * and checked against each value at lookup time. Every other branch is
 * registered under its type bucket, and a bucket may hold at most one branch.
 *
 * @param types The union's branch types, in branch order.
 * @returns A function returning the index of the branch matching a value, or
 * `undefined` when nothing matches.
 * @throws Error when two non-dynamic branches fall in the same bucket (while
 * building the indexer), or when more than one branch matches a value (while
 * looking it up).
 */
function generateDefaultIndexer(types) {
  const dynamicBranches = [];
  const bucketIndices = {};

  types.forEach((type, index) => {
    if (Type.isType(type, 'abstract', 'logical')) {
      dynamicBranches.push({index, type});
      return;
    }
    const bucket = getTypeBucket(type);
    if (bucketIndices[bucket] !== undefined) {
      // This is a free function, so there is no union instance bound to
      // `this` here; report the offending branch types instead.
      // NOTE(review): assumes `j` can serialize an array of types - confirm.
      throw new Error(`ambiguous unwrapped union: ${j(types)}`);
    }
    bucketIndices[bucket] = index;
  });

  // Run the value through every dynamic branch, starting from the index (if
  // any) already found via the buckets.
  const getBranchIndex = (any, index) => {
    for (const branch of dynamicBranches) {
      if (!branch.type._check(any)) {
        continue;
      }
      if (index !== undefined) {
        // More than one branch matches the value so we aren't guaranteed to
        // infer the correct type. We throw rather than corrupt data. This can
        // be fixed by "tightening" the logical types.
        throw new Error('ambiguous conversion');
      }
      index = branch.index;
    }
    return index;
  };

  return (val) => {
    const index = bucketIndices[getValueBucket(val)];
    // Slower path when dynamic branches exist: all of them must be checked.
    return dynamicBranches.length ? getBranchIndex(val, index) : index;
  };
}

/**
* "Natural" union type.
*
Expand All @@ -1246,54 +1311,17 @@ UnionType.prototype._branchConstructor = function () {
* + `map`, `record`
*/
class UnwrappedUnionType extends UnionType {
constructor (schema, opts) {
constructor (schema, opts, /* @private parameter */ _projectionFn) {
super(schema, opts);

this._dynamicBranches = null;
this._bucketIndices = {};
this.types.forEach(function (type, index) {
if (Type.isType(type, 'abstract', 'logical')) {
if (!this._dynamicBranches) {
this._dynamicBranches = [];
}
this._dynamicBranches.push({index, type});
} else {
let bucket = getTypeBucket(type);
if (this._bucketIndices[bucket] !== undefined) {
throw new Error(`ambiguous unwrapped union: ${j(this)}`);
}
this._bucketIndices[bucket] = index;
}
}, this);

Object.freeze(this);
}

_getIndex (val) {
let index = this._bucketIndices[getValueBucket(val)];
if (this._dynamicBranches) {
// Slower path, we must run the value through all branches.
index = this._getBranchIndex(val, index);
if (!_projectionFn && opts && typeof opts.wrapUnions === 'function') {
_projectionFn = opts.wrapUnions(this.types);
}
return index;
}
this._getIndex = _projectionFn
? generateProjectionIndexer(_projectionFn)
: generateDefaultIndexer(this.types);

_getBranchIndex (any, index) {
let logicalBranches = this._dynamicBranches;
for (let i = 0, l = logicalBranches.length; i < l; i++) {
let branch = logicalBranches[i];
if (branch.type._check(any)) {
if (index === undefined) {
index = branch.index;
} else {
// More than one branch matches the value so we aren't guaranteed to
// infer the correct type. We throw rather than corrupt data. This can
// be fixed by "tightening" the logical types.
throw new Error('ambiguous conversion');
}
}
}
return index;
Object.freeze(this);
}

_check (val, flags, hook, path) {
Expand Down Expand Up @@ -1355,16 +1383,18 @@ class UnwrappedUnionType extends UnionType {
// Using the `coerceBuffers` option can cause corruption and erroneous
// failures with unwrapped unions (in rare cases when the union also
// contains a record which matches a buffer's JSON representation).
if (isJsonBuffer(val) && this._bucketIndices.buffer !== undefined) {
index = this._bucketIndices.buffer;
} else {
index = this._getIndex(val);
if (isJsonBuffer(val)) {
let bufIndex = this.types.findIndex(t => getTypeBucket(t) === 'buffer');
if (bufIndex !== -1) {
index = bufIndex;
}
}
index ??= this._getIndex(val);
break;
case 2:
// Decoding from JSON, we must unwrap the value.
if (val === null) {
index = this._bucketIndices['null'];
index = this._getIndex(null);
} else if (typeof val === 'object') {
let keys = Object.keys(val);
if (keys.length === 1) {
Expand Down
51 changes: 51 additions & 0 deletions test/test_types.js
Original file line number Diff line number Diff line change
Expand Up @@ -3505,6 +3505,57 @@ suite('types', () => {
assert(Type.isType(t.field('unwrapped').type, 'union:unwrapped'));
});

test('union projection', () => {
  // Builds a record schema with a single string field.
  const singleFieldRecord = (name, fieldName) => ({
    type: 'record',
    name,
    fields: [
      { type: 'string', name: fieldName }
    ],
  });
  const animalTypes = [
    singleFieldRecord('Dog', 'bark'),
    singleFieldRecord('Cat', 'meow'),
  ];

  // Picks the record name from whichever distinguishing field is present.
  const nameOf = (animal) => {
    if ('bark' in animal) {
      return 'Dog';
    }
    if ('meow' in animal) {
      return 'Cat';
    }
    throw new Error('Unknown animal');
  };

  let callsToWrapUnions = 0;
  const wrapUnions = (types) => {
    callsToWrapUnions += 1;
    assert.deepEqual(types.map((t) => t.name), ['Dog', 'Cat']);
    return (animal) => {
      const wanted = nameOf(animal);
      return types.findIndex((type) => type.name === wanted);
    };
  };

  // Ambiguous, but we have a projection function
  const Animal = Type.forSchema(animalTypes, { wrapUnions });
  Animal.toBuffer({ meow: '🐈' });
  // The factory must have been invoked exactly once for the single union.
  assert.equal(callsToWrapUnions, 1);
  assert.throws(() => Animal.toBuffer({ snap: '🐊' }), /Unknown animal/);
});

test('union projection with fallback', () => {
  // A projection factory which yields `undefined` should leave the union
  // wrapped.
  const declineProjection = () => undefined;
  const schema = {
    type: 'record',
    fields: [
      {name: 'wrapped', type: ['int', 'double' ]}, // Ambiguous.
    ]
  };
  const recordType = Type.forSchema(schema, {wrapUnions: declineProjection});
  const unionType = recordType.field('wrapped').type;
  assert(Type.isType(unionType, 'union:wrapped'));
});

test('invalid wrap unions option', () => {
assert.throws(() => {
Type.forSchema('string', {wrapUnions: 'FOO'});
Expand Down
17 changes: 16 additions & 1 deletion types/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,21 @@ interface EncoderOptions {
syncMarker: Buffer;
}

/**
 * A projection function that is used when unwrapping unions.
 *
 * This function is called at schema parsing time on each union with its
 * branches' types.
 *
 * If it returns a function, the union will use an unwrapped representation
 * and that function will be called each time a value's branch needs to be
 * inferred; it should return the branch's index. The index must be a number
 * between 0 and length-1 of the passed types.
 *
 * Otherwise (if it returns `undefined`), the union will be wrapped.
 */
type BranchProjection = (types: ReadonlyArray<Type>) =>
| ((val: unknown) => number)
| undefined;

interface ForSchemaOptions {
assertLogicalTypes: boolean;
logicalTypes: { [type: string]: new (schema: Schema, opts?: any) => types.LogicalType; };
Expand All @@ -103,7 +118,7 @@ interface ForSchemaOptions {
omitRecordMethods: boolean;
registry: { [name: string]: Type };
typeHook: (schema: Schema | string, opts: ForSchemaOptions) => Type | undefined;
wrapUnions: boolean | 'auto' | 'always' | 'never';
wrapUnions: BranchProjection | boolean | 'auto' | 'always' | 'never';
}

interface TypeOptions extends ForSchemaOptions {
Expand Down

0 comments on commit aeada8c

Please sign in to comment.