jax does not work with the multiprocessing "fork" strategy. #1805

Closed
dchatterjee172 opened this issue Dec 3, 2019 · 4 comments
Labels
question Questions for the JAX team

Comments

dchatterjee172 commented Dec 3, 2019

import numpy as np
from functools import partial
import jax.numpy as jnp
from jax import jit, random
from collections import namedtuple
from multiprocessing import Process, Queue
from pickle import dumps


class Brain(namedtuple("Brain", ("w1", "b1", "w2", "b2"))):
    def __sub__(self, other):
        return Brain(
            w1=self.w1 - other.w1,
            b1=self.b1 - other.b1,
            w2=self.w2 - other.w2,
            b2=self.b2 - other.b2,
        )

    def __mul__(self, scalar):
        return Brain(
            w1=self.w1 * scalar,
            b1=self.b1 * scalar,
            w2=self.w2 * scalar,
            b2=self.b2 * scalar,
        )

    __rmul__ = __mul__


def get_brain(
    input_size: int, hidden_size: int, output_size: int, max_memory: int, seed: int
):
    key = random.PRNGKey(seed)
    w1 = random.truncated_normal(
        key, lower=0, upper=0.1, shape=(input_size, hidden_size)
    )
    w2 = random.truncated_normal(
        key, lower=0, upper=0.1, shape=(hidden_size, output_size)
    )
    b1 = jnp.zeros(shape=(hidden_size,))
    b2 = jnp.zeros(shape=(output_size,))
    return Brain(w1=w1, b1=b1, w2=w2, b2=b2)


@jit
def forward(brain: Brain, data: np.ndarray):
    o1 = jnp.matmul(data, brain.w1) + brain.b1
    a1 = jnp.tanh(o1)
    o2 = jnp.matmul(a1, brain.w2) + brain.b2
    a2 = o2 - jnp.expand_dims(jnp.log(jnp.exp(o2).sum(axis=1)), 1)
    return a2


def worker(queue):
    import jax.numpy as jnp
    from jax import grad, jit

    @jit
    def forward(brain: Brain, data: np.ndarray):
        o1 = jnp.matmul(data, brain.w1) + brain.b1
        a1 = jnp.tanh(o1)
        o2 = jnp.matmul(a1, brain.w2) + brain.b2
        a2 = o2 - jnp.expand_dims(jnp.log(jnp.exp(o2).sum(axis=1)), 1)
        return a2

    @jit
    def loss(brain: Brain, data: np.ndarray, labels: np.ndarray):
        pred = forward(brain, data)
        loss = jnp.mean(-(labels * pred).sum(1))
        return loss

    @jit
    def grad_loss(brain: Brain, data: np.ndarray, labels: np.ndarray):
        return partial(grad(loss), data=data, labels=labels)(brain)

    @jit
    def sgd(brain: Brain, data: np.ndarray, labels: np.ndarray, learning_rate: float):
        g = grad_loss(brain, data, labels)
        brain = brain - g * learning_rate
        return brain

    while True:
        brain, data, labels, epoch, learning_rate = queue.get()
        for i in range(epoch):
            brain = sgd(brain, data, labels, learning_rate)
        break  # under multiprocessing, control flow never gets here; nothing is returned from forward


if __name__ == "__main__":
    brain = get_brain(100, 200, 9, 1000, 1)
    data = np.random.normal(size=(1000, 100))
    labels = np.random.uniform(0, 1, size=(1000, 9))
    queue = Queue(10)
    workers = []
    for i in range(2):
        p = Process(target=worker, args=(queue,))
        p.start()  # does not work
        workers.append(p)
    for i in range(10):
        queue.put((brain, data, labels, 1000, 1))
    worker(queue)  # works

@dchatterjee172
Author

The same behavior occurs without the jit decorator too.

@dchatterjee172 dchatterjee172 changed the title jax jit not working in multiprocessing jax not working in multiprocessing Dec 3, 2019
Collaborator

hawkinsp commented Dec 3, 2019

My strong guess: you're using the fork strategy in multiprocessing. That won't work with JAX, because JAX is internally multithreaded.

Can you try the spawn or forkserver strategies, described here?
https://docs.python.org/3/library/multiprocessing.html
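
As a minimal sketch of what switching strategies could look like, the snippet below requests the "spawn" start method through `multiprocessing.get_context` and creates the queues and processes from that context. The names `make_context`, `main`, `tasks`, and `results` are hypothetical and not part of the original script, and the worker body is only a stand-in: in the script above, the jax imports and jitted functions would go inside `worker` instead.

```python
import multiprocessing as mp


def worker(tasks, results):
    # Stand-in for the JAX work (assumption: in the real script, the jax
    # imports and @jit functions live here, so each child process sets up
    # JAX itself instead of inheriting a forked copy of the parent's state).
    values = tasks.get()
    results.put(sum(values))


def make_context(method="spawn"):
    # get_context returns an object with the same API as the multiprocessing
    # module, but bound to the requested start method.
    return mp.get_context(method)


def main():
    ctx = make_context("spawn")
    # Queues and processes are created from the same context.
    tasks, results = ctx.Queue(), ctx.Queue()
    procs = [ctx.Process(target=worker, args=(tasks, results)) for _ in range(2)]
    for p in procs:
        p.start()
    for _ in procs:
        tasks.put([1.0, 2.0, 3.0])
    # A timeout keeps the parent from blocking forever if a child fails.
    totals = [results.get(timeout=60) for _ in procs]
    for p in procs:
        p.join()
    print(totals)


if __name__ == "__main__":
    # The guard matters with "spawn": children re-import this module.
    main()
```

With "spawn", everything passed to a child (here, the queue items) has to be picklable, and anything defined only under the `__main__` guard is not visible inside the worker.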

@dchatterjee172
Author

Works!

@mattjj mattjj added the question Questions for the JAX team label Dec 3, 2019
Collaborator

mattjj commented Dec 3, 2019

Thanks for the question!

@mattjj mattjj closed this as completed Dec 3, 2019
@hawkinsp hawkinsp changed the title jax not working in multiprocessing jax does not work with the multiprocessing "fork" strategy. Dec 3, 2019
jack89roberts added a commit to alan-turing-institute/AIrsenal that referenced this issue Jul 18, 2021
- Hangs if running multithreaded, possibly because bpl-next uses jax, which itself uses multiprocessing. See jax-ml/jax#1805, but changing multiprocessing in AIrsenal gives errors about the sqlalchemy session not being pickle-able.

- bpl-next predictions occasionally have nan values