This repository has been archived by the owner on Apr 24, 2024. It is now read-only.

How will we interface with different molecular dynamics (MD) codes? #42

Open
agoscinski opened this issue Mar 16, 2023 · 5 comments
Labels
software design Discusses the design route to implement functionalities

Comments

@agoscinski
Collaborator

agoscinski commented Mar 16, 2023

We can write a machine learning potential (MLP) class in Python that can be used within a thin driver/class in ase and i-pi, like:

import ase

# Module is torch.nn.Module or our NumpyModule, depending on the backend
class MlPotential(Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.output = None

    # ase and i-pi both support ase.Atoms; if other types are needed, a conversion can be added
    def compute(self, frame: ase.Atoms):
        self.output = self.model.forward(frame)

    def get_energy(self):
        energies = self.output.block().values
        # do some additional stuff to have coherent units
        return energies

    def get_forces(self):
        ...

# in ase.calculators ...
from ase.calculators.calculator import Calculator

class MlPotentialCalculator(Calculator):
    ...  # loading MlPotential

# in i-pi drivers
class MlPotentialDriver(Dummy_driver):
    ...  # loading MlPotential
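To make the "one potential class, thin per-engine adapters" idea concrete, here is a minimal, dependency-free sketch. Everything in it is hypothetical: `ToyPairModel` is a Lennard-Jones stand-in for a trained model, positions are plain tuples instead of `ase.Atoms`, and `AseLikeCalculator` / `IpiLikeDriver` only mimic the general shape of the real ase `Calculator` and i-pi driver interfaces, which take different arguments.

```python
from typing import List, Sequence, Tuple

Vec = Tuple[float, float, float]


class ToyPairModel:
    """Hypothetical stand-in for a trained model: a Lennard-Jones pair potential."""

    def __init__(self, epsilon: float = 1.0, sigma: float = 1.0):
        self.epsilon = epsilon
        self.sigma = sigma

    def forward(self, positions: Sequence[Vec]) -> Tuple[float, List[Vec]]:
        n = len(positions)
        energy = 0.0
        forces = [[0.0, 0.0, 0.0] for _ in range(n)]
        for i in range(n):
            for j in range(i + 1, n):
                d = [positions[i][k] - positions[j][k] for k in range(3)]
                r2 = sum(c * c for c in d)
                sr6 = (self.sigma ** 2 / r2) ** 3
                energy += 4.0 * self.epsilon * (sr6 * sr6 - sr6)
                # (-dE/dr) / r, so that the force on atom i is coeff * (r_i - r_j)
                coeff = 24.0 * self.epsilon * (2.0 * sr6 * sr6 - sr6) / r2
                for k in range(3):
                    forces[i][k] += coeff * d[k]
                    forces[j][k] -= coeff * d[k]
        return energy, [(f[0], f[1], f[2]) for f in forces]


class MlPotential:
    """Engine-agnostic core: run the model once, then expose energy and forces."""

    def __init__(self, model):
        self.model = model
        self._energy = None
        self._forces = None

    def compute(self, positions: Sequence[Vec]) -> None:
        self._energy, self._forces = self.model.forward(positions)

    def get_energy(self) -> float:
        return self._energy

    def get_forces(self) -> List[Vec]:
        return self._forces


class AseLikeCalculator:
    """Hypothetical thin adapter; the real ase Calculator API is richer."""

    def __init__(self, potential: MlPotential):
        self.potential = potential

    def get_potential_energy(self, positions: Sequence[Vec]) -> float:
        self.potential.compute(positions)
        return self.potential.get_energy()

    def get_forces(self, positions: Sequence[Vec]) -> List[Vec]:
        self.potential.compute(positions)
        return self.potential.get_forces()


class IpiLikeDriver:
    """Hypothetical thin adapter returning (energy, forces) from a single call."""

    def __init__(self, potential: MlPotential):
        self.potential = potential

    def __call__(self, positions: Sequence[Vec]) -> Tuple[float, List[Vec]]:
        self.potential.compute(positions)
        return self.potential.get_energy(), self.potential.get_forces()
```

The design point is that all the model evaluation lives in `MlPotential.compute`; each adapter only translates its engine's calling convention, so supporting another Python-driven MD code means writing another few-line adapter.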

It was not clear to me how we would use this with MD codes that do not support Python (like LAMMPS). For models that use torch, the way to go is clearly TorchScript. But what about numpy-based models? How should we handle the interfaces for these?
Given our resources and the practicality of the other approaches I explored (see below), it seems that for numpy models the only reasonable approach is to also export them as TorchScript. We would still give users the option to use numpy exclusively for training and running their models in Python. But when exporting a model to run in an MD library whose drivers only support low-level code, we convert it to a TorchScript-compatible format. We will also offer a way to export the model keeping the numpy arrays, but that exported model will only work with MD codes that have Python drivers, such as i-pi and ase.
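The export rule described above can be summarized as a small decision function. This is a hypothetical sketch: the `ExportFormat` names, the `choose_export_format` function and the `PYTHON_DRIVER_CODES` set are assumptions made for illustration (based only on the ase / i-pi / LAMMPS examples in this issue), not an existing API.

```python
from enum import Enum


class ExportFormat(Enum):
    TORCHSCRIPT = "torchscript"  # loadable by engines that embed TorchScript (e.g. from C++)
    NUMPY = "numpy"  # keeps numpy arrays; only usable from Python drivers


# Hypothetical: MD codes assumed to drive the model from Python
PYTHON_DRIVER_CODES = {"ase", "i-pi"}


def choose_export_format(md_code: str, keep_numpy: bool = False) -> ExportFormat:
    """Pick an export format for a (possibly numpy-trained) model."""
    has_python_driver = md_code.lower() in PYTHON_DRIVER_CODES
    if keep_numpy:
        if not has_python_driver:
            raise ValueError(
                f"{md_code} has no Python driver; export the model as TorchScript instead"
            )
        return ExportFormat.NUMPY
    # Default path: convert to a TorchScript-compatible model, usable by both kinds of driver
    return ExportFormat.TORCHSCRIPT
```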

Long-range

For long-range interactions, the approach above works for MD codes with Python drivers. To get MPI support from LAMMPS, however, it seems to me that we need a different approach in kspace, which means a different class and interface.

Alternatives to TorchScript? Using JAX

It seems to me that JAX does not yet have the infrastructure that TorchScript offers with its JIT-compiled custom operators. The only resource I found on loading a JAX JIT-compiled function in C++ is a GitHub issue (jax-ml/jax#5337 (comment)), and it looks like a lot of work for just one simple function.

Alternatives to TorchScript? Running Python code from C/C++

I just skimmed through this guide https://stackoverflow.com/a/1056057, which links to https://www.linuxjournal.com/article/8497, but it looked like it would just open the door to many more low-level issues.

Link collection on how MLP codes interface with MD codes:

agoscinski added the software design label on Mar 16, 2023
agoscinski added a commit that referenced this issue on Mar 18, 2023
Because we want to allow people to transfer their NumpyModule
to a TorchModule so they can get access to TorchScript (see issue #42),
we need to change how we do the inheritance. Before, it was

Module (reference to torch.nn.Module or our BaseModule)
  --> CustomModule (e.g. Ridge)

But that means that, when loading the library, there is just one class CustomModule,
which inherits from either torch.nn.Module or BaseModule depending on whether torch
is available on the machine. With a single inheritance chain, it is hard to switch between
the classes, and changing the base class is very hacky, so this is not a
good approach. Therefore we create both classes when torch is present
(note: BaseModule was renamed to NumpyModule):

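try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

def factory_custom_module(base):
    class _CustomModule(base):
        ...
    # change name ...
    return _CustomModule

# NumpyModule is our numpy-based base class (formerly BaseModule)
CustomNumpyModule = factory_custom_module(NumpyModule)

if HAS_TORCH:
    # the torch-based twin can only be built when torch is importable
    CustomTorchModule = factory_custom_module(torch.nn.Module)
    CustomModule = CustomTorchModule
else:
    CustomModule = CustomNumpyModule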
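A self-contained sketch of why the factory helps: the same class body is built on two different bases, so a numpy-trained model can be converted into its torch-based twin. To keep it runnable without torch, `FakeTorchModule` stands in for `torch.nn.Module`; `factory_linear_module`, `LinearNumpyModule`, `LinearTorchModule` and `to_torch_module` are hypothetical names, not the library's API.

```python
from typing import List, Sequence


class NumpyModule:
    """Hypothetical stand-in for the library's NumpyModule base class."""

    def __init__(self):
        self.training = True


class FakeTorchModule:
    """Hypothetical stand-in for torch.nn.Module, so the sketch runs without torch."""

    def __init__(self):
        self.training = True


def factory_linear_module(base, name):
    """Build one model class on top of `base`, sharing the same body for both backends."""

    class _LinearModule(base):
        def __init__(self, weights: Sequence[float]):
            super().__init__()
            self.weights: List[float] = list(weights)

        def forward(self, x: Sequence[float]) -> float:
            return sum(w * xi for w, xi in zip(self.weights, x))

    # the "change name ..." step: give each generated class a readable name
    _LinearModule.__name__ = name
    _LinearModule.__qualname__ = name
    return _LinearModule


LinearNumpyModule = factory_linear_module(NumpyModule, "LinearNumpyModule")
LinearTorchModule = factory_linear_module(FakeTorchModule, "LinearTorchModule")


def to_torch_module(module):
    """Hypothetical conversion: copy the parameters into the torch-based twin class."""
    if not isinstance(module, LinearNumpyModule):
        raise TypeError("expected a LinearNumpyModule")
    return LinearTorchModule(module.weights)
```

Because both classes exist at the same time, conversion is just copying parameters; with a single class whose base depends on `HAS_TORCH`, there would be no numpy-based class left to convert from once torch is installed.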
@ipcamit

ipcamit commented Nov 25, 2023

Hi, I am just curious: were you able to decide on a way forward? I have written a TorchScript-based KIM-API model driver that can run most ML models from LAMMPS, ASE, DL_POLY, etc. But as a next step I wanted to use an ONNX- or OpenXLA-based backend for more universal inference. The problem with ONNX was that, until recently, I wasn't able to get the model to differentiate (https://stackoverflow.com/questions/70177709/differentiate-onnx-objects). And I am still reading up on the XLA/StableHLO route. So I was wondering whether you have more notes on it.

@Luthaf
Collaborator

Luthaf commented Nov 27, 2023

Hi Amit!

We are going with the TorchScript-only route for now; the initial implementation is in metatensor/metatensor#405. I was planning on contacting you once this initial implementation was merged to see how to integrate it with your work.

We have two reasons for using only TorchScript: (a) we have a couple of well-optimized TorchScript extensions that we want to use; and (b) we are using custom TorchScript classes extensively. While the TorchScript extensions could be ported to ONNX/XLA with some work, I don't think ONNX or XLA offer support for custom classes.

I don't see much of an advantage for us using ONNX. XLA could be interesting for TPU support and better optimization, but this might not be worth the amount of work to get there. What are your motivations for using them? Supporting models written with jax & friends?

@ipcamit

ipcamit commented Nov 27, 2023

Hi Guillaume! I was not aware that you are working on this!
Metatensor looks interesting; perhaps I can look into KIM also having a metatensor-based model driver. I don't think it will be too difficult, except for the neighbor lists, but it is worth exploring.

We had a minimal ONNX model earlier, but gave up on it quickly, as ONNX is mostly focused on inference for conventional image/language-like models. Other than that, there were several restrictions, such as the floor function always demanding single-precision floats.

My main interest in XLA stems from the desire for higher-performance GPU support and for JAX support, as I would like KLIFF (our training framework) to support JAX layers. One benefit of JAX is better ahead-of-time compilation support, owing to its more restrictive nature.

That is why I was wondering if you had made any progress, as the internet is a bit scarce on examples of running serialized HLO modules from C++.

@Luthaf
Collaborator

Luthaf commented Nov 29, 2023

We are also interested in XLA/HLO, but we will only look at it once the rest of the code is working and usable! Agreed that documentation is generally very scarce for these technologies: I had to resort to reading the source code many times just to write a JAX extension/XLA custom call.

Please let us know if you find anything on this front, and I'll look into it as well!

EDIT: reading very quickly through https://github.com/openxla/stablehlo, it seems that one needs a StableHLO runtime to execute the module. I'd guess XLA can provide this, so you should link against XLA and use it to load & execute the HLO model.

@ipcamit

ipcamit commented Nov 29, 2023

Sure. I am planning to take a dive into it during the holidays and will share my findings afterwards. I did some digging from within your link, and it seems like this might give me enough hints to start something.
My approach will be to have a similar framework to our Torch ML model driver for KIM:

  1. A C++ class to parse a StableHLO module and keep it in memory
  2. The simulator will query kim-api for forces
  3. (I am a bit unclear about this step) The XLA backend will execute the HLO program with the inputs to compute the forces

One thing I am uncertain about is the dimensions: it seems like StableHLO supports fixed input dimensions, although it does support batched inputs as well. In any case, perhaps I will have some more info by January!

I will update you on the findings as well.
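If the compiled StableHLO/XLA program does need static input shapes, one common workaround (an assumption here, not something settled in this thread) is to pad the per-atom inputs to a fixed capacity, carry a mask so padded entries contribute nothing, and round capacities up to a few buckets so that only a handful of shapes ever need compiling. A minimal stdlib sketch, with a harmonic pair energy as a hypothetical stand-in for the model; `bucket_capacity`, `pad_to_capacity` and `masked_harmonic_energy` are illustrative names:

```python
import math
from typing import List, Sequence, Tuple

Vec = Tuple[float, float, float]


def bucket_capacity(n_atoms: int) -> int:
    """Round up to the next power of two, so only a few fixed shapes get compiled."""
    if n_atoms < 1:
        raise ValueError("need at least one atom")
    return 1 << (n_atoms - 1).bit_length()


def pad_to_capacity(positions: Sequence[Vec], capacity: int) -> Tuple[List[Vec], List[float]]:
    """Pad positions to a fixed length and return a mask (1.0 = real atom, 0.0 = padding)."""
    n = len(positions)
    if n > capacity:
        raise ValueError(f"{n} atoms do not fit in capacity {capacity}")
    padded = list(positions) + [(0.0, 0.0, 0.0)] * (capacity - n)
    mask = [1.0] * n + [0.0] * (capacity - n)
    return padded, mask


def masked_harmonic_energy(padded: Sequence[Vec], mask: Sequence[float],
                           k: float = 1.0, r0: float = 1.0) -> float:
    """Fixed-shape pair energy: every pair is evaluated, masked pairs contribute zero."""
    energy = 0.0
    m = len(padded)
    for i in range(m):
        for j in range(i + 1, m):
            r = math.dist(padded[i], padded[j])
            energy += mask[i] * mask[j] * 0.5 * k * (r - r0) ** 2
    return energy
```

A harmonic term is used because it stays finite when two padded atoms sit at the same point; with a singular potential such as Lennard-Jones, array frameworks would evaluate the masked term as `0 * inf = nan`, so the mask has to be applied before the singular operation (or padding atoms placed far apart).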
