본문 바로가기
machine learning

[TENSORFLOW] 텐서플로우에서의 쓰레드

by 유주원 2016. 5. 20.

TENSORFLOW는 파이썬 기반 언어라고 생각해서 그런지 멀티 쓰레드가 지원 안되는지 알았는데, 다행히도 멀티 쓰레드가 가능한 것 같다.


TENSORFLOW에서의 멀티쓰레드는 동일한 Session 객체를 사용할 수 있게 해주고 병령로 ops를 동작하게 해준다. 

TENSORFLOW에서는 보다 원할한 멀티쓰레드 동작을 위해 두 개의 class를 제공하는데 tf.Coordinator와 tf.QueueRunner이다. 두 개의 클래스는 함께 사용되도록 디자인 되었다. 


Coordinator는 멀티쓰레드가 함께 종료될 수 있도록 도와주고, 예외처리를 할 수 있도록 제공하고 있다. QueueRunner는 동일한 큐안에서 tensor가 동작할 수 있도록 쓰레드를 생성하는데 도움을 준다.


Coordinator


- should_stop() : 만약 쓰레드가 멈춰야 한다면 True를 리턴한다.

- request_stop(<exception>) : 쓰레드를 멈추도록 요청한다.

- join(<list of threads>) : 명시된 쓰레드가 멈출때까지 기다린다.


멀티 쓰레드를 하기 위해서는 처음에 우선 Coordinator 객체를 생성한다. 그런 다음 coordinator를 사용할 쓰레드들을 생성한다. 쓰레드들은 should_stop() 함수가 True를 리턴하기 전까지는 계속 동작할 것이다.


아래 코드를 보면 좀 더 이해하기가 쉽다.


def MyLoop(coord):

    while not coord.should_stop():

        ... do something...

        if ...some condition...:

            coord.request_stop()


coord = Coordinator()

threads = [threading.Thread(target=MyLoop, args=(coord,)) for i in xrange(10)]


for t in threads: t.start()

coord.join(threads)


coord 객체를 가지고 MyLoop 로직을 실행하는 쓰레드 10개를 생성하고, 쓰레드를 시작한다. 쓰레드가 끝나기 전에 프로그램이 종료되는 것을 막기 위해서 마지막에 coord.join을 걸어서 waiting을 해준다.


MyLoop 함수를 보면 coord.should_stop()이 True가 아니면 루프를 계속 실행 한다. 루프 안에서 어떤 경우가 발생해서 coord.request_stop이 발생되면 coord.should_stop이 True가 되고 쓰레드가 종료가 된다.




QueueRunner


QueueRunner 클래스는 enqueue 시 operation이 동작할 수 있는 쓰레드들을 생성한다. 이렇게 생성된 쓰레드들은 coordinator를 통해 중지 될 수 있으며, coordinator에서 어떠한 예외가 발생되었을 시에는 queue를 자동으로 중지시킬 수도 있다.


example = ... ops ...

queue = tf.RandomShuffleQueue(...)

enqueue_op = queue.enqueue(example)

inputs = queue.dequeue_many(batch_size)

train_op = ... 


일단 queue를 생성하고 queue에 들어갈 operation을 만든다. 또한 queue로부터 dequeue를 실행하는 train_op를 만든다.


qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)

sess = tf.Session()

coord = tf.train.Coordinator()

enqueue_threads = qr.create_threads(sess, coord=coord, start=True)


for step in xrange(1000000):

    if coord.should_stop():

        break

    sess.run(train_op)

coord.request_stop()

coord.join(threads)


QueueRunner를 통해 enqueue에서 4개의 쓰레드가 병렬로 ops를 진행 할 것이다.

그 후 Session을 만들고 쓰레드 관리를 위해 Coordinator도 만든다. QueueRunner 쓰레드와 Coordinator를 연결하고 session run을 돌리면 train_op는 병렬로 4개의 쓰레드가 동작하게 된다. 

마지막에는 request_stop을 통해 쓰레드를 종료하고, join을 통해 다 종료될 때까지 waiting 한 후 프로그램을 종료한다.


Handling Exceptions


QueueRunner시 발생할 수 있는 예외 처리에 대한 코드를 살펴보자.


try:

    for step in xrange(1000000):

        if coord.should_stop():

            break

        sess.run(train_op)

except Exception, e:

    coord.request_stop(e)


coord.request_stop()

coord.join(threads)