Spawn, torch.multiprocessing

Hello all

We have developed a multilingual TTS service, and at test time we run several DL models. These models have no dependencies on each other, so they can run in parallel, which should lower the runtime and improve performance.
We run them on a GPU, but I ran into several problems.

Here is a simplified version of the code:

import numpy as np
import torch
import torch.multiprocessing as mp

def train1():
    print("\nx")
    q5 = np.random.randint(2,size=(4,2))
    q5_targ = torch.tensor(q5).to(torch.device("cuda"))

def train2():
    print("\ny")
    g5 = np.random.randint(2,size=(4,2))
    g5_targ = torch.tensor(g5).to(torch.device("cuda"))

if __name__ == '__main__':
    p1 = mp.Process(target=train1, args=())
    p2 = mp.Process(target=train2, args=())

    p1.start()
    p2.start()
    p1.join()
    p2.join()

When I run the code above, I get this output:

x
y
.
.
.
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
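This error is raised when CUDA has already been initialized in the parent process (for example, in an earlier notebook cell) and the children are then created with 'fork', which is the default start method on Linux. The sketch below requests 'spawn' through a context object instead of changing the global setting. It uses the standard-library multiprocessing module, which torch.multiprocessing wraps with the same API. The helper name `start_workers` is hypothetical.

```python
import multiprocessing  # in the real service: import torch.multiprocessing as mp

# get_context() returns an object with the same API as the module
# (Process, Queue, Pool, ...) but bound to a single start method. Unlike
# set_start_method(), it changes no global state, so re-running a notebook
# cell does not raise "RuntimeError: context has already been set".
ctx = multiprocessing.get_context("spawn")


def start_workers(targets):
    """Start one 'spawn' process per target function and wait for all of them.

    Returns the exit codes. A non-zero code means the child failed, which is
    easy to miss when the child's traceback is not displayed.
    """
    procs = [ctx.Process(target=t) for t in targets]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    return [p.exitcode for p in procs]
```

In a .py script, this would be called as `start_workers([train1, train2])` inside the `if __name__ == '__main__':` block.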

And when I call

from torch.multiprocessing import set_start_method
set_start_method('spawn')

before that code, I see no output at all!
(It seems the target functions are never even entered.)
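A commonly given explanation for a silent failure like this in a notebook such as Colab goes as follows. With 'spawn', each child is a fresh interpreter that re-imports its target function by module and name. Functions defined in a notebook cell live in the notebook's `__main__`, which the child cannot import, so the child exits with an error. That traceback, along with any `print` output from the child, goes to the kernel process's terminal or log rather than to the cell. Assuming that is the cause, the usual workaround is to move the worker code into an importable module, as sketched below. The module `tts_workers`, its `square` function (a stand-in for running one model) and `load_worker_module` are hypothetical. The standard-library multiprocessing is used in place of torch.multiprocessing.

```python
import importlib
import multiprocessing  # the real service would use torch.multiprocessing
import sys
import tempfile
from pathlib import Path

# Hypothetical worker module. In Colab you could create it with the
# %%writefile magic; here it is written from Python so the sketch is
# self-contained.
WORKER_SOURCE = '''
def square(x):
    # Placeholder for "load one TTS model and run it on the GPU".
    return x * x
'''


def load_worker_module(directory):
    """Write the worker code to <directory>/tts_workers.py and import it."""
    Path(directory, "tts_workers.py").write_text(WORKER_SOURCE)
    if directory not in sys.path:
        # Spawned children receive the parent's sys.path, so they can
        # import tts_workers by name when unpickling the target.
        sys.path.insert(0, directory)
    importlib.invalidate_caches()  # make sure the new file is seen
    return importlib.import_module("tts_workers")


if __name__ == "__main__":
    workers = load_worker_module(tempfile.mkdtemp())
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        # Works because square lives in an importable module,
        # not in the notebook's __main__.
        print(pool.map(workers.square, [3, 4]))  # [9, 16]
```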

I am on Colab now, using a Tesla T4.

Do you know anything about this, @muellerzr?

I've never touched multiprocessing, so I have no idea.


I tried running the code below as a .py file on our own Linux server, and it worked fine!
So maybe the problem is specific to .ipynb notebooks.

import torch
import numpy as np
import torch.multiprocessing as mp
from torch.multiprocessing import set_start_method

def train1():
    print("\nx")
    q5 = np.random.randint(2,size=(4,2))
    q5_targ = torch.tensor(q5).to(torch.device("cuda"))

def train2():
    print("\ny")
    g5 = np.random.randint(2,size=(4,2))
    g5_targ = torch.tensor(g5).to(torch.device("cuda"))

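if __name__ == '__main__':
    # Choose 'spawn' once, before any process is created: children created
    # with the default 'fork' cannot use CUDA once the parent has initialized it.
    set_start_method('spawn')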
    p1 = mp.Process(target=train1, args=())
    p2 = mp.Process(target=train2, args=())

    p1.start()
    p2.start()

    p1.join()
    p2.join()
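In the real use case, the parent needs the outputs of several independent models back. The working .py pattern above can be extended to collect those return values. The sketch below assumes that `concurrent.futures.ProcessPoolExecutor` with a 'spawn' context is acceptable in place of bare `Process` objects. The functions `run_model_a`, `run_model_b` and `run_all` are hypothetical stand-ins for the TTS models.

```python
import multiprocessing  # the real service would use torch.multiprocessing
from concurrent.futures import ProcessPoolExecutor


# Hypothetical stand-ins for two independent TTS models. Assumption: in the
# real service each would move its inputs to torch.device("cuda") and return
# plain CPU data (e.g. tensor.cpu().numpy()) to the parent.
def run_model_a(text: str) -> str:
    return text.upper()


def run_model_b(text: str) -> int:
    return len(text)


def run_all(text: str) -> dict:
    """Run both models in separate 'spawn' processes and collect their outputs."""
    ctx = multiprocessing.get_context("spawn")
    with ProcessPoolExecutor(max_workers=2, mp_context=ctx) as pool:
        futures = {
            "a": pool.submit(run_model_a, text),
            "b": pool.submit(run_model_b, text),
        }
        # result() re-raises any exception from the child in the parent,
        # so a failing worker cannot go unnoticed.
        return {name: f.result() for name, f in futures.items()}


if __name__ == "__main__":
    print(run_all("hello"))  # {'a': 'HELLO', 'b': 5}
```

As in the script above, the worker functions are defined at module level and the pool is created only under the `__main__` guard. That way, spawned children can re-import the file without starting processes of their own.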