From 790958ef7ae2da49aca88aaf7c9ef96174826510 Mon Sep 17 00:00:00 2001 From: Renaud Rohlinger Date: Tue, 21 Jan 2025 18:10:24 +0900 Subject: [PATCH 1/4] TSL: Add matrix operations support --- src/nodes/math/OperatorNode.js | 103 +++++++++++++++++++++++++++------ 1 file changed, 85 insertions(+), 18 deletions(-) diff --git a/src/nodes/math/OperatorNode.js b/src/nodes/math/OperatorNode.js index c009eed09761a6..f5fc05ca0e5b86 100644 --- a/src/nodes/math/OperatorNode.js +++ b/src/nodes/math/OperatorNode.js @@ -118,23 +118,39 @@ class OperatorNode extends TempNode { } else { - if ( typeA === 'float' && builder.isMatrix( typeB ) ) { + // Handle matrix operations + if ( builder.isMatrix( typeA ) ) { - return typeB; + if ( typeB === 'float' ) { + + return typeA; // matrix * scalar = matrix + + } else if ( builder.isVector( typeB ) ) { - } else if ( builder.isMatrix( typeA ) && builder.isVector( typeB ) ) { + return builder.getVectorFromMatrix( typeA ); // matrix * vector - // matrix x vector + } else if ( builder.isMatrix( typeB ) ) { + + return typeA; // matrix * matrix + + } - return builder.getVectorFromMatrix( typeA ); + } else if ( builder.isMatrix( typeB ) ) { - } else if ( builder.isVector( typeA ) && builder.isMatrix( typeB ) ) { + if ( typeA === 'float' ) { - // vector x matrix + return typeB; // scalar * matrix = matrix - return builder.getVectorFromMatrix( typeB ); + } else if ( builder.isVector( typeA ) ) { + + return builder.getVectorFromMatrix( typeB ); // vector * matrix + + } + + } - } else if ( builder.getTypeLength( typeB ) > builder.getTypeLength( typeA ) ) { + // Handle non-matrix cases + if ( builder.getTypeLength( typeB ) > builder.getTypeLength( typeA ) ) { // anytype x anytype: use the greater length vector @@ -182,17 +198,43 @@ class OperatorNode extends TempNode { typeA = type; typeB = builder.changeComponentType( typeB, 'uint' ); - } else if ( builder.isMatrix( typeA ) && builder.isVector( typeB ) ) { + } else if ( builder.isMatrix( typeA ) ) { + + if ( typeB === 'float' ) { + + // Keep matrix type for typeA, but ensure typeB stays float + typeB = 'float'; + + } else if ( builder.isVector( typeB ) ) { + + // matrix x vector + typeB = builder.getVectorFromMatrix( typeA ); + + } else if ( builder.isMatrix( typeB ) ) { + // matrix x matrix - keep both types + } else { + + typeA = typeB = type; + + } + + } else if ( builder.isMatrix( typeB ) ) { - // matrix x vector + if ( typeA === 'float' ) { - typeB = builder.getVectorFromMatrix( typeA ); + // Keep matrix type for typeB, but ensure typeA stays float + typeA = 'float'; - } else if ( builder.isVector( typeA ) && builder.isMatrix( typeB ) ) { + } else if ( builder.isVector( typeA ) ) { - // vector x matrix + // vector x matrix + typeA = builder.getVectorFromMatrix( typeB ); - typeA = builder.getVectorFromMatrix( typeB ); + } else { + + typeA = typeB = type; + + } } else { @@ -274,7 +316,20 @@ class OperatorNode extends TempNode { } else { - return builder.format( `( ${ a } ${ op } ${ b } )`, type, output ); + // Handle matrix operations + if ( builder.isMatrix( typeA ) && typeB === 'float' ) { + + return builder.format( `(${b} ${op} ${a})`, type, output ); + + } else if ( typeA === 'float' && builder.isMatrix( typeB ) ) { + + return builder.format( `(${a} ${op} ${b})`, type, output ); + + } else { + + return builder.format( `(${a} ${op} ${b})`, type, output ); + + } } @@ -282,11 +337,23 @@ class OperatorNode extends TempNode { if ( fnOpSnippet ) { - return builder.format( `${ fnOpSnippet }( ${ a }, ${ b } )`, type, output ); + return builder.format( `${fnOpSnippet}(${a}, ${b})`, type, output ); } else { - return builder.format( `${ a } ${ op } ${ b }`, type, output ); + if ( builder.isMatrix( typeA ) && typeB === 'float' ) { + + return builder.format( `${b} ${op} ${a}`, type, output ); + + } else if ( typeA === 'float' && builder.isMatrix( typeB ) ) { + + return builder.format( `${a} ${op} ${b}`, type, output ); + + } else { + + return builder.format( `${a} ${op} ${b}`, type, output ); + + } } From 2d1cbf8b5f388f25604b4f47677664adb8fec9ea Mon Sep 17 00:00:00 2001 From: Renaud Rohlinger Date: Tue, 21 Jan 2025 18:17:58 +0900 Subject: [PATCH 2/4] fix lint --- src/nodes/math/OperatorNode.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/nodes/math/OperatorNode.js b/src/nodes/math/OperatorNode.js index f5fc05ca0e5b86..cb2d48d274b7c9 100644 --- a/src/nodes/math/OperatorNode.js +++ b/src/nodes/math/OperatorNode.js @@ -337,7 +337,7 @@ class OperatorNode extends TempNode { if ( fnOpSnippet ) { - return builder.format( `${fnOpSnippet}(${a}, ${b})`, type, output ); + return builder.format( `${ fnOpSnippet }( ${ a }, ${ b } )`, type, output ); } else { @@ -347,11 +347,11 @@ class OperatorNode extends TempNode { } else if ( typeA === 'float' && builder.isMatrix( typeB ) ) { - return builder.format( `${a} ${op} ${b}`, type, output ); + return builder.format( `${ a } ${ op } ${ b }`, type, output ); } else { - return builder.format( `${a} ${op} ${b}`, type, output ); + return builder.format( `${ a } ${ op } ${ b }`, type, output ); } From 8f8b1c5fb4778c732a8f741e8af4b3ff758ed6ee Mon Sep 17 00:00:00 2001 From: Renaud Rohlinger Date: Tue, 21 Jan 2025 18:22:44 +0900 Subject: [PATCH 3/4] fix lint --- src/nodes/math/OperatorNode.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/nodes/math/OperatorNode.js b/src/nodes/math/OperatorNode.js index cb2d48d274b7c9..2a3e94a24c9da2 100644 --- a/src/nodes/math/OperatorNode.js +++ b/src/nodes/math/OperatorNode.js @@ -319,15 +319,15 @@ class OperatorNode extends TempNode { // Handle matrix operations if ( builder.isMatrix( typeA ) && typeB === 'float' ) { - return builder.format( `(${b} ${op} ${a})`, type, output ); + return builder.format( `( ${ b } ${ op } ${ a } )`, type, output ); } else if ( typeA === 'float' && builder.isMatrix( typeB ) ) { - return builder.format( `(${a} ${op} ${b})`, type, output ); + return builder.format( `${ a } ${ op } ${ b }`, type, output ); } else { - return builder.format( `(${a} ${op} ${b})`, type, output ); + return builder.format( `( ${ a } ${ op } ${ b } )`, type, output ); } From 555e729cd73820d5a6fcdbcf1b30b294b5112294 Mon Sep 17 00:00:00 2001 From: Renaud Rohlinger Date: Tue, 21 Jan 2025 18:24:02 +0900 Subject: [PATCH 4/4] fix CI --- src/nodes/math/OperatorNode.js | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/nodes/math/OperatorNode.js b/src/nodes/math/OperatorNode.js index 2a3e94a24c9da2..be13905f8137fb 100644 --- a/src/nodes/math/OperatorNode.js +++ b/src/nodes/math/OperatorNode.js @@ -343,11 +343,7 @@ class OperatorNode extends TempNode { if ( builder.isMatrix( typeA ) && typeB === 'float' ) { - return builder.format( `${b} ${op} ${a}`, type, output ); - - } else if ( typeA === 'float' && builder.isMatrix( typeB ) ) { - - return builder.format( `${ a } ${ op } ${ b }`, type, output ); + return builder.format( `${ b } ${ op } ${ a }`, type, output ); } else {