How will we interface with different molecular dynamics (MD) codes? #42
Because we want to allow people to transfer their NumpyModule to a TorchModule so they can get access to TorchScript (see issue #42), we need to change the way we do the inheritance. Before, it was Module (a reference to torch.nn.Module or our BaseModule) --> CustomModule (e.g. Ridge). But that means that when loading the library there is just one class CustomModule, which inherits from either torch.nn.Module or BaseModule depending on whether torch is available on the machine. With a single inheritance chain it is hard to switch between the classes, and changing the base class at runtime is very hacky, so this is not a good approach. Instead we create both classes when torch is present (note that BaseModule was renamed to NumpyModule):

```python
def factory_custom_module(base):
    class _CustomModule(base):
        ...

    # change name
    ...

    return _CustomModule


CustomNumpyModule = factory_custom_module(NumpyModule)
CustomTorchModule = factory_custom_module(torch.nn.Module)

if HAS_TORCH:
    CustomModule = CustomTorchModule
else:
    CustomModule = CustomNumpyModule
```
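For completeness, here is a small sketch of how the pieces left out above could look; the `try`/`except` import guard and the class-renaming line are assumptions for illustration, not the actual implementation:

```python
# Sketch only: one way HAS_TORCH could be determined and the generated class
# could be given a readable name. These details are assumptions, not the
# project's real code.
try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False


def factory_custom_module(base):
    class _CustomModule(base):
        ...  # model code is written once and shared by both bases

    # rename the generated class so error messages show which base was used
    _CustomModule.__name__ = f"CustomModule[{base.__name__}]"
    return _CustomModule
```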
Hi, I am just curious, were you able to decide on a way forward? I have written a TorchScript-based KIM-API model driver that can run most ML models from LAMMPS, ASE, DL_POLY, etc. But as a next step I wanted to use an ONNX- or OpenXLA-based backend for more universal inference. The problem with ONNX was that I wasn't able to get the model to differentiate until recently (https://stackoverflow.com/questions/70177709/differentiate-onnx-objects), and I am still reading up on the XLA/StableHLO route. So I was wondering if you have more notes on it.
Hi Amit! We are going with the TorchScript-only route for now; the initial implementation is in metatensor/metatensor#405. I was planning on contacting you back once this initial implementation was merged to see how to integrate it with your work. We have two reasons for using only TorchScript: (a) we have a couple of well-optimized TorchScript extensions that we want to use; and (b) we are using custom TorchScript classes extensively. While the TorchScript extensions could be ported to ONNX/XLA with some work, I think they don't offer support for custom classes. I don't see much of an advantage for us in using ONNX. XLA could be interesting for TPU support and better optimization, but this might not be worth the amount of work to get there. What are your motivations for using them? Supporting models written with JAX & friends?
Hi Guillaume! I was not aware that you are working on this! We had a minimal ONNX model earlier, but gave up on it quickly as ONNX is focused on running inference only, on conventional image/language-like models. Other than that there were several restrictions, such as... My major interest in XLA stems from a desire for more performant GPU support, and for JAX support, as I would like KLIFF (our training framework) to support JAX layers. One benefit of JAX is its better ahead-of-time compilation support, owing to its more restrictive nature. That is why I was wondering if you have made any progress, as the internet is a bit scarce on examples of running serialized HLO modules from C++.
We are also interested in XLA/HLO, but we will only look at it once the rest of the code is working and usable! Agreed that in general documentation is very scarce for these technologies; I had to resort to reading the source code many times just to write a JAX extension/XLA custom call. Please let us know if you find anything on this front, and I'll look into it as well! EDIT: reading very quickly through https://github.com/openxla/stablehlo, it seems that one needs a StableHLO runtime to execute the module. I'd guess XLA can provide this, so you should link against XLA and use it to load & execute the HLO model.
Sure, I am planning to take a dive into it during the holidays and will share findings afterwards. I did some digging from within your link, and it seems like this might give me enough hints to start something.

One thing I am uncertain about is the dimensions: it seems like StableHLO supports fixed input dimensions, but it does support batched inputs as well. In any case, I will probably have some more info by January, and will update you on the findings as well.
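Not from the original thread, but a small JAX sketch of the fixed-dimensions point may help: shapes are baked in when a jitted function is lowered to StableHLO, so a batch dimension has to be part of the shape used for lowering. The `energy` function and the shapes below are made up for illustration.

```python
import jax
import jax.numpy as jnp


def energy(positions):
    # toy expression, only here so there is something to lower
    return jnp.sum(positions ** 2)


# Lowering fixes the input shape/dtype: this module only accepts (128, 3) float32.
lowered = jax.jit(energy).lower(jnp.zeros((128, 3), dtype=jnp.float32))
module_text = lowered.as_text()  # MLIR text (StableHLO in recent JAX versions)

# To support batches, include the batch dimension when lowering.
batched_text = jax.jit(jax.vmap(energy)).lower(
    jnp.zeros((16, 128, 3), dtype=jnp.float32)
).as_text()

print(module_text[:200])
```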
We can write a machine learning potential (MLP) class in Python that can be used within a thin driver/class in ase and i-pi (a rough sketch of such a thin ase driver is below).
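This sketch is not from the original issue; it only illustrates the "thin driver" idea with an ase Calculator. The `model` object and its `compute(positions, cell, numbers)` method are hypothetical placeholders for whatever interface the MLP class ends up exposing.

```python
# Hypothetical sketch: a thin ase Calculator wrapping a Python MLP object.
# The `compute` method of `model` is a placeholder, not a real API.
from ase.calculators.calculator import Calculator, all_changes


class MLPCalculator(Calculator):
    implemented_properties = ["energy", "forces"]

    def __init__(self, model, **kwargs):
        super().__init__(**kwargs)
        self.model = model  # the pure-Python (numpy or torch) MLP class

    def calculate(self, atoms=None, properties=["energy"],
                  system_changes=all_changes):
        super().calculate(atoms, properties, system_changes)
        # delegate all the actual work to the MLP class
        energy, forces = self.model.compute(
            positions=self.atoms.get_positions(),
            cell=self.atoms.get_cell(),
            numbers=self.atoms.get_atomic_numbers(),
        )
        self.results = {"energy": energy, "forces": forces}
```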
To me it was not clear how we would use this for any MD code that does not support Python (like LAMMPS). For the models that use torch, the way to go is clearly TorchScript. But I was thinking about the models that are numpy-based: how should we handle the interfaces for these?
Given our resources and the practicability of the other approaches I explored a bit (see below), it seems that for the numpy models the only reasonable approach is to also export them as TorchScript. We nevertheless give users the option to use numpy exclusively for training and for running their model in Python. But when exporting it for running in an MD code whose drivers only support low-level code, we convert to a TorchScript-compatible format. We will also offer a way to export the model keeping the numpy arrays, but that exported model will only work for MD codes with Python drivers, such as i-pi and ase.
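As a rough sketch (not the actual exporter), "converting to a TorchScript-compatible format" could mean copying the trained numpy parameters into a small torch.nn.Module and scripting it. The `ExportedRidge` class and the `weights`/`bias` attributes of the numpy model are assumptions for illustration.

```python
# Hypothetical sketch: wrapping the parameters of a numpy-based model in a
# torch.nn.Module so it can be exported with TorchScript. `numpy_model` with
# `weights`/`bias` numpy arrays is a made-up stand-in for the real model.
import numpy as np
import torch


class ExportedRidge(torch.nn.Module):
    def __init__(self, weights: np.ndarray, bias: np.ndarray):
        super().__init__()
        # convert the trained numpy parameters to torch tensors once, at export
        self.register_buffer("weights", torch.from_numpy(weights))
        self.register_buffer("bias", torch.from_numpy(bias))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return features @ self.weights + self.bias


# exported = torch.jit.script(ExportedRidge(numpy_model.weights, numpy_model.bias))
# exported.save("ridge.pt")  # loadable from C++ MD drivers via libtorch
```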
Long-range
For long-range interactions, the support for MD codes with Python drivers works. To get MPI support from LAMMPS, it seems to me that we need a different approach for the kspace part, which means a different class and interface.
Alternatives to TorchScript? Using JAX
It seems to me that JAX does not yet have the infrastructure that TorchScript offers with its jit-compiled custom operators. I only found one GitHub issue about how to load a JAX jit-compiled function in C++ (jax-ml/jax#5337 (comment)), and it looks like a lot of work for just one simple function.
Alternatives to TorchScript? Running Python code from C/C++
I just skimmed through this guide (https://stackoverflow.com/a/1056057, linking to https://www.linuxjournal.com/article/8497), but it looked like it just opens the door to many more low-level issues.
Link collection on how MLP codes interface with MD codes: