Skip to content

Commit

Permalink
Add outputDataType to argmin/argmax (#730)
Browse files Browse the repository at this point in the history
* Add outputDataType to argmin/argmax

* fix output value type

* add validation step
  • Loading branch information
philloooo authored Jul 24, 2024
1 parent 3c38c41 commit 4a2b8ca
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions index.bs
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,7 @@ Return the index location of the minimum or maximum values of all the input valu
<script type=idl>
dictionary MLArgMinMaxOptions {
boolean keepDimensions = false;
MLOperandDataType outputDataType = "int32";
};

partial interface MLGraphBuilder {
Expand All @@ -1463,6 +1464,10 @@ partial interface MLGraphBuilder {
: <dfn>keepDimensions</dfn>
::
If true, retains reduced dimensions with [=list/size=] 1.

: <dfn>outputDataType</dfn>
::
An {{MLOperandDataType}}. The output data type.
</dl>

<div dfn-for="MLGraphBuilder/argMin(input, axis, options), MLGraphBuilder/argMax(input, axis, options)" dfn-type=argument>
Expand All @@ -1471,7 +1476,7 @@ partial interface MLGraphBuilder {
- <dfn>axis</dfn>: The dimension to reduce. The value must be in the range [0, N-1] where N is the [=MLOperand/rank=] of the input tensor.
- <dfn>options</dfn>: an optional {{MLArgMinMaxOptions}}. The optional parameters of the operation.

**Returns:** an {{MLOperand}}. The N-D tensor of the reduced shape. The values must be of type {{MLOperandDataType/"int64"}} in the range [0, N-1] where N is the size of the input dimension specified by axis.
**Returns:** an {{MLOperand}}. The N-D tensor of the reduced shape. The values must be of type |options|.{{MLArgMinMaxOptions/outputDataType}} in the range [0, N-1] where N is the size of the input dimension specified by axis.
</div>

<details open algorithm>
Expand All @@ -1481,9 +1486,10 @@ partial interface MLGraphBuilder {
1. [=Assert=]: |op| is one of "argMin", "argMax".
1. If [=this=].{{MLGraphBuilder/[[hasBuilt]]}} is true, then [=exception/throw=] an "{{InvalidStateError}}" {{DOMException}}.
1. If [=MLGraphBuilder/validating operand=] with [=this=] and |input| returns false, then [=exception/throw=] a {{TypeError}}.
1. If |input|'s [=MLOperand/shape=][|axis|] is greater than |options|.{{MLArgMinMaxOptions/outputDataType}}'s maximum value, [=exception/throw=] a {{TypeError}}.
1. Let |outputShape| be the result of [=MLGraphBuilder/calculating reduction output sizes=] given |input|'s [=MLOperand/shape=], « |axis| », and |options|.{{MLArgMinMaxOptions/keepDimensions}}. If that returns failure, then [=exception/throw=] a {{TypeError}}.
1. Let |desc| be a new {{MLOperandDescriptor}}.
1. Set |desc|.{{MLOperandDescriptor/dataType}} to {{MLOperandDataType/"int64"}}.
1. Set |desc|.{{MLOperandDescriptor/dataType}} to |options|.{{MLArgMinMaxOptions/outputDataType}}.
1. Set |desc|.{{MLOperandDescriptor/dimensions}} to |outputShape|.
1. *Make graph connections:*
1. Let |operator| be an [=operator=] for the |op| operation, given |options|.
Expand Down

0 comments on commit 4a2b8ca

Please sign in to comment.