BUG: Transition matrices with incorrect shapes do not throw error #63

Open
buddejul opened this issue Mar 22, 2024 · 3 comments
Labels: bug

buddejul commented Mar 22, 2024

I came across unexpected behavior in stochastic next functions that take the _period argument.

For N_PERIODS = 3 and a binary state, I thought we needed a 2x2 matrix (one row for each period transition). But the following does not throw an error:

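import jax.numpy as jnp
import lcm


@lcm.mark.stochastic
def next_dummy_state(_period):
    pass


# One row only, although N_PERIODS = 3 implies two period transitions
DUMMY_STATE_TRANSITION = jnp.array(
    [[0.5, 0.5]]
)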

I checked that the state is actually included in the model: providing an empty transition matrix raises an indexing error, as expected.

Providing a matrix that's "too large" (e.g., of shape (4, 2)) also doesn't result in an error.
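A minimal pure-JAX sketch of the silent behavior, assuming (without having checked LCM's internals) that the probabilities of a period-dependent stochastic state are looked up as row `_period` of the transition matrix. `lookup_probs` is a hypothetical helper; the asymmetric rows make it visible which row was actually read.

```python
import jax
import jax.numpy as jnp

N_PERIODS = 3  # two period transitions: 0 -> 1 and 1 -> 2


@jax.jit
def lookup_probs(transition, period):
    # Hypothetical lookup: one row of next-state probabilities per period.
    # `period` is traced under jit, so an out-of-bounds row index cannot
    # raise; JAX clamps it to the last valid row instead.
    return transition[period]


too_small = jnp.array([[0.3, 0.7]])  # shape (1, 2): one row missing
too_large = jnp.array([[0.3, 0.7], [0.1, 0.9], [0.2, 0.8], [0.6, 0.4]])  # (4, 2)

for period in range(N_PERIODS - 1):
    # Both lookups succeed: period 1 silently reuses row 0 of `too_small`,
    # and rows 2 and 3 of `too_large` are never read.
    print(period, lookup_probs(too_small, period), lookup_probs(too_large, period))
```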

buddejul added the bug label Mar 22, 2024
@timmens
Member

timmens commented Mar 22, 2024

Thank you for opening the issue!

This seems to be related to JAX's out-of-bounds indexing. The following code runs without error:

import jax.numpy as jnp

jnp.arange(10)[11]  # index 11 is out of bounds for an array of length 10

> Array(9, dtype=int32)

For reference, see: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#out-of-bounds-indexing

However, I still have to check whether this points to a more serious problem. In any case, we have to throw an informative error on the LCM side.
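JAX lets the caller choose the out-of-bounds semantics of a gather through the `mode` argument of `.at[...].get()`, as described on the Sharp Bits page linked above. The sketch below, which assumes a row lookup by period and is only an illustration rather than what LCM does today, contrasts default clamping with `mode="fill"`, where a missing row comes back as NaNs that are easy to detect downstream. Both function names are hypothetical.

```python
import jax
import jax.numpy as jnp

transition = jnp.array([[0.3, 0.7]])  # one row, but two periods need one each


@jax.jit
def lookup_probs_clamped(transition, period):
    # Default indexing: an out-of-bounds row index is clamped to the last row.
    return transition[period]


@jax.jit
def lookup_probs_filled(transition, period):
    # mode="fill": an out-of-bounds row is returned filled with NaN instead.
    return transition.at[period].get(mode="fill", fill_value=jnp.nan)


print(lookup_probs_clamped(transition, 1))  # silently reuses row 0
print(lookup_probs_filled(transition, 1))   # NaNs flag the missing row
```

Filling with NaN does not raise by itself, so it complements rather than replaces an explicit check before solving.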

@hmgaudecker
Member

Great detective work. Sounds like we'll need to add bounds checks up front.
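A minimal sketch of such an up-front check, assuming (as in the original report) that a period-dependent transition matrix for a state with `n_states` values needs one row per period transition, i.e. shape `(n_periods - 1, n_states)`. The helper `check_transition_shape` and its signature are hypothetical, not part of LCM. Because it only inspects static shapes, it can run in plain Python before any JAX tracing.

```python
import jax.numpy as jnp


def check_transition_shape(transition, n_periods, n_states, name="transition"):
    """Raise an informative error if `transition` has an unexpected shape.

    Hypothetical helper: assumes one row per period transition and one
    column per value of the stochastic state.
    """
    expected = (n_periods - 1, n_states)
    actual = tuple(jnp.shape(transition))
    if actual != expected:
        raise ValueError(
            f"{name} has shape {actual}, but with n_periods={n_periods} and "
            f"{n_states} state values it must have shape {expected} "
            "(one row per period transition, one column per state value)."
        )


# Passes: two period transitions, binary state.
check_transition_shape(jnp.array([[0.3, 0.7], [0.1, 0.9]]), n_periods=3, n_states=2)

# Fails with an informative message instead of silently clamping.
try:
    check_transition_shape(jnp.array([[0.5, 0.5]]), n_periods=3, n_states=2)
except ValueError as err:
    print(err)
```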

Also, a good example of why you want to test with asymmetric inputs. I wonder whether things still work with [0.4, 0.6], or whether you'd see a violation of the add-to-one constraints.
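A companion check for the add-to-one constraint could look like the following sketch (again a hypothetical helper, not LCM's API). It requires all entries to be non-negative and every row to sum to one up to a floating-point tolerance, which would reject rows such as [0.3, 0.7, 0.9] or [0.3].

```python
import jax.numpy as jnp


def check_rows_are_probabilities(transition, atol=1e-6, name="transition"):
    """Raise if any entry is negative or any row does not sum to one.

    Hypothetical helper for illustration; `atol` absorbs float32 rounding.
    """
    transition = jnp.asarray(transition)
    if bool(jnp.any(transition < 0)):
        raise ValueError(f"{name} contains negative probabilities.")
    row_sums = transition.sum(axis=-1)
    if not bool(jnp.allclose(row_sums, 1.0, atol=atol)):
        raise ValueError(
            f"Rows of {name} must sum to one, but the row sums are {row_sums}."
        )


# Asymmetric rows that are valid distributions pass.
check_rows_are_probabilities(jnp.array([[0.4, 0.6], [0.1, 0.9]]))

try:
    check_rows_are_probabilities(jnp.array([[0.3, 0.7, 0.9]]))
except ValueError as err:
    print(err)
```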

@buddejul
Author

buddejul commented Mar 22, 2024

Re asymmetry: I originally noticed this for a state in our replication, and that row had asymmetric entries (but I would need to confirm). So I guess the indexing happens at the level of rows rather than individual entries (at least in this example, with just the period as argument).

FWIW, I also tried different combinations of required rows (i.e., periods) and provided rows, but this didn't seem to have an effect, consistent with the above.


Edit:

Dummy state transition: [[0.5 0.5]] 
Solved 
Dummy state transition: [[0.3 0.7]] 
Solved 
Dummy state transition: [[0.5]] 
Solved 
Dummy state transition: [[0.3]] 
Solved 
Dummy state transition: [[0.3 0.7]
 [0.1 0.9]] 
Solved 
Dummy state transition: [] 
Failed with error: mul got incompatible shapes for broadcasting: (2, 2, 2, 2), (0, 2, 2, 2). 
Dummy state transition: [[0.3 0.6 0.1]] 
Failed with error: mul got incompatible shapes for broadcasting: (2, 2, 2, 2), (3, 2, 2, 2). 
Dummy state transition: [[0.3 0.7 0.9]] 
Failed with error: mul got incompatible shapes for broadcasting: (2, 2, 2, 2), (3, 2, 2, 2). 
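The log fits the row-level explanation: matrices with two columns (or a single column) solve regardless of their number of rows, while those with 0 or 3 columns fail. One plausible reading, assumed here rather than confirmed against the LCM code, is that the probabilities are multiplied against an array whose leading axis has one entry per value of the binary state, so a length-1 or length-2 vector broadcasts but a length-3 one does not. The sketch reproduces that pattern with a made-up stand-in array; `values` and `weight_by_probs` are hypothetical.

```python
import jax.numpy as jnp

# Hypothetical stand-in for an intermediate array with one leading entry per
# value of the binary state; the shape mimics the error messages above.
values = jnp.ones((2, 2, 2, 2))


def weight_by_probs(probs):
    # Put the next-state probabilities on the leading axis and multiply.
    return values * probs.reshape(-1, 1, 1, 1)


for row in ([0.3, 0.7], [0.3], [0.3, 0.6, 0.1]):
    try:
        print(row, "->", weight_by_probs(jnp.array(row)).shape)
    except (TypeError, ValueError) as err:
        print(row, "->", type(err).__name__, err)
```

Note that the single-column case broadcasts and solves even though [0.3] is not a valid distribution, which is another reason to validate the add-to-one constraint up front.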
