개발자로 후회없는 삶 살기
[최적화] GPU 풀 구현기 1 본문
서론
Multi Node Single GPU 환경에서 GPU를 효율적으로 사용하기 위해서 GPU 풀을 구현해봅니다. Spring 쓰레드 풀과 같은 개념으로 GPU를 미리 할당 받아서 모아 두어 효율적으로 사용하기 위한 목적으로 볼 수 있습니다.
본론
- 구현 목적
다수의 GPU로 모델 학습을 돌리다보면 GPU를 100% 효율적으로 사용하기 어려울 수 있고, 프로젝트 종료 후에 실제 학습을 돌린 시간을 보면 절반도 안되는 경우가 있습니다. 그 이유는 누군가는 모델을 안 돌리고 있을 수 있고, 누군가는 모델을 돌리고 싶지만 현재 자신의 GPU를 돌리고 있어서 다른 실험을 할 수 없을 수 있습니다.
이렇게 되면 GPU utilization이 떨어집니다. 따라서 6개의 GPU의 사용성을 올리기 위해서 개개인의 GPU가 아닌 공용 GPU를 만들어 이 상황을 해결하고자 합니다.
1. Publisher Agent
Publisher는 6명의 작업을 공통 경로에 저장하고 Task Queue에 Insert하는 역할입니다. 일종의 Task를 한 곳으로 모으는 깔데기 역할입니다.
2. Consumer Agent
Consumer는 Queue에서 작업을 가져와서 모델 학습을 돌리는 GPU 서버로 Queue와 커넥션 연결을 하고 listen()으로 가져올 작업이 있다면 가져와 학습을 진행합니다. 모델 학습을 위해 상시 대기하는 학습 서버입니다.
-> 로컬 환경에서 테스트
이것을 사설망에서 6대의 서버로 구현하기 전에 로컬에서 WSL과 도커로 테스트 해보려 합니다. WSL 우분투 서버가 Publisher 역할을 하고 Redis 도커 서버가 Queue, 나머지 도커 서버들이 Consumer 역할을 할 것입니다.
=> 이슈 발생
이를 구현하면서 발생한 다양한 문제들을 알아보겠습니다.
1. 지연 시간이 수반된다.
import redis
import json
import time
r = redis.Redis(host='localhost', port=6379, db=0)
def publish_single_task():
config = {"key" : "1"}
r.publish('ai_train', json.dumps(config))
p = r.pubsub()
p.subscribe("ai_train")
time.sleep(1) # 1초 동안 대기
message = p.get_message()
if message and message['type'] == 'message':
data = message["data"]
if type(data) is bytes:
print(json.loads(data))
먼저, 테스트를 위해 pub과 sub을 하나의 메서드에서 실행되도록 구현하였습니다. 하지만 pub/sub 메커니즘은 거의 동시에 일어나서 지연시간이 필요하였고 sleep을 추가하였습니다.
2. while 루프가 수반된다.
오직 한번의 sub을 해보기 위해 while문을 제거하고 실행해 보았더니 type이 subscribe인 메세지를 가져왔습니다. Redis에서 채널을 subscribe하면, 첫 번째로 받는 메시지는 해당 채널에 대한 subscribe confirmation이라서 while을 통해 publish된 메시지를 받을 때까지 get_message()를 반복적으로 호출해야 합니다. 하지만 while을 추가하여도 message는 한 번만 가져오고 None이 반환되어 Task를 가져올 수 없었습니다.
3. Redis 서버 상태 확인
docker run --name my-redis -p 6379:6379 -d redis
따라서 '레디스 서버에 데이터가 publish 안 된 것이 아닌가'라는 의문으로 레디스를 테스트 해보기로 하였습니다.
하지만, exec로 접속해서 실행해보면 Redis는 정상적으로 작동하였고 파이썬으로 접속하려고 하면 안 되고 있는 상황인 것을 확인하였습니다. 따라서 코드 상 이슈라고 판정하였습니다.
4. 멀티 스레드 활용
Redis를 조사해보니 pubsub 메커니즘이 별도의 프로세스에서 동작할 것을 기대한다고 하여 pub/sub 메서드를 분리하였습니다.
def run_pubsub():
# publish
threading.Thread(target=publisher, args=('ai_queue', {"lr" : 0.01})).start()
# subscribe
pubsub = conn.pubsub()
pubsub.subscribe(['ai_queue'])
while True: # 이것도 계속 필요한 것 같고 왜냐? publish 한 순간데 sub가 없으면 메세지가 유실된다.
print("waiting message...")
res = pubsub.get_message()
if res is not None:
print(res)
time.sleep(0.5)
def publisher(channel: str, config: dict):
time.sleep(1)
conn.publish(channel=f'{channel}', message=json.dumps(config))
if __name__ == '__main__':
run_pubsub()
멀티 스레드를 한 이유는 레디스 pub/sub은 카프카 메세지 큐와 다르게 publish한 시점에 subscriber가 없으면 데이터가 곧바로 유실됩니다. 따라서 publish가 전부 끝난 후 subscribe를 하면 publish한 메세지가 print 되지 않습니다.
실험 Tasks로 {"lr" : 0.01}을 넣어주었고, 위처럼 멀티스레드와 메세지 분리로 Publisher Agent 테스트를 성공했습니다. publish 되는 건 int, float, byte, str만 가능하므로 dict 타입은 json으로 바꿔줘야 합니다. 앞서 get_message()로 Consumer로 구현한 것이지만, 이제 각각의 Consumer Agent를 구현하고 사설망 내부 통신을 해보겠습니다.
🚨 메커니즘의 문제점
앞서 레디스 pub/sub 메커니즘은 publish 시점에 sub가 안 되면 메세지가 곧바로 유실되는 문제가 있었습니다.
consumer에서 모델 학습을 하는 도중 메세지가 들어오면 sub을 할 수 없어서 데이터 유실
이때 다음과 같은 문제가 발생할 수 있습니다. 따라서 pub/sub은 모델 학습에 사용할 수 없습니다.
✅ 유실되지 않는 List 메커니즘을 사용하자
Redis List를 사용해 메세징 큐를 구현하면 pub 시점에 sub이 일어나지 않아도 메세지를 넣어 둔 후 시간이 지나도 메세지를 가져오는 것이 가능합니다.
import threading
import redis
import time
conn = redis.Redis(host='localhost', port=6379, db=0)
def run_list_queue():
# sender
threading.Thread(target=sender, args=(10,)).start()
# receiver
while True:
print("waiting message...")
res = conn.blpop('test_list', timeout=0)
if res is not None:
print(res)
time.sleep(0.5)
def sender(n: int):
time.sleep(1)
for num in range(n):
time.sleep(1)
conn.rpush('test_list', f'message #{num}')
if __name__ == '__main__':
run_list_queue()
rpush()는 레디스 List에 데이터를 적재하며, 이 데이터는 가져가기 전까지 보관됩니다.
blpop()은 데이터가 있으면 가져가고 데이터가 없으면 대기합니다.
- 컴포넌트 각각 띄우기
1. consumer, publisher 코드 분리
이제 위에서 구성한 퍼블리셔, 컨슈머, 큐를 각각 띄워서 네트워크로 통신해보겠습니다.
pub을 먼저 실행하면 for문 10회 반복으로 인해 약 10초 뒤에 끝납니다. 만약 데이터가 보관된다면 pub이 끝난 후에도 sub을 할 수 있어야 합니다.
sub 결과 넣어놓은 데이터를 가져올 수 있습니다. 또한 데이터가 10개라서, 데이터가 없다면 대기하는 모습입니다.
2. nohup, &로 백그라운드 실행
우분투에서는 nohup, &로 백그라운드 실행을 할 수 있습니다.
3. 모델 학습 코드에 적용
-> 모델 학습 과정
GPU Pool을 사용하기 전 기본 학습 과정은 train 코드에 명령행 인자로 config를 주면서 실행하는 것이었습니다.
따라서, 이 config를 큐의 메세지로 적재하는 퍼블리셔 코드와 큐에서 가져와서 학습하는 컨슈머 코드가 필요합니다.
1) 퍼블리셔가 config를 큐에 입력
conn = redis.Redis(host='localhost', port=6379, db=0)
def run_queue():
# sender
config = {
"arch": {
"type": "CustomModel",
"args": {
"num_classes": 11
}
},
"data_loader": {
"type": "CustomDataLoader",
"args":{
"batch_size": 3,
"shuffle": False,
"num_workers": 2
}
},
"dataset": {
"type": "CustomDataset",
"args":{
"annotation" : "../dataset/train.json",
"data_dir" : "../dataset",
"resize" : 1024
}
},
"trainer": {
"type": "CustomTrainer",
"args":{
"epochs": 2,
"save_path": "./checkpoints/faster_rcnn_torchvision_checkpoints.pth"
}
}
}
threading.Thread(target=sender, args=(config, )).start()
def sender(config: dict):
time.sleep(1)
conn.rpush("config", json.dumps(config))
if(__name__ == "__main__"):
run_queue()
테스트를 위해 config 변수를 만들어 실험합니다. dict 타입의 config를 json으로 레디스 큐에 적재하는 과정입니다.
2) 컨슈머가 메세지를 가져와서 학습
config dict를 메세지 큐에서 읽어서 config.json으로 저장 후, read_json으로 읽기
원래 train.py는 명령행 인자로 config를 받고 저장된 config.json 파일을 load해서 dict로 파싱하기에 config를 활용한 학습 방법으로 위 방법을 생각했습니다. 하지만 메세지를 받아서 다시 저장하는 과정이 stream의 강점을 약화시키는 느낌을 강하게 받았습니다.
-> 어떻게 할까? 🚨
def read_json(cfg_fname):
fname = Path(cfg_fname)
with fname.open("rt") as handle:
return json.load(handle, object_hook=OrderedDict)
def read_dict(cfg: str):
return json.loads(cfg, object_hook=OrderedDict)
기존에 json을 로드하는 것 대신 큐에 config를 문자열로 적재하고 바로 받아서 dict로 변환하는 것이 stream을 제대로 활용하는 것이라고 판단되었습니다.
import subprocess as sp
conn = redis.Redis(host='localhost', port=6379, db=0)
CONFIG_LOCATION: Final = 1
def receiver():
# receiver
while True:
print("waiting message...")
config: bytes = conn.blpop('config', timeout=0)
if config is not None:
byte_config = config[CONFIG_LOCATION]
decoded_config: str = byte_config.decode('utf-8')
command = "python %s/train.py -dc '%s'" % (CWD, decoded_config)
train = sp.run(command, capture_output=True, text=True, shell=True)
print(train)
time.sleep(0.5)
if __name__ == '__main__':
receiver()
컨슈머는 while문으로 계속 큐를 리스닝하고 있고 메세지가 있다면 받아와서 문자열로 바꾸고 명령행 인자로 만듭니다. 이후 subprocess로 파이썬 학습 코드를 실행합니다.
실행 결과 명령어와 명령행 인자가 제대로 들어간 것이 확인되고
학습이 진행되고 있는 것을 확인할 수 있습니다.
참고
'[AI] > [딥러닝 | 이슈해결]' 카테고리의 다른 글
MMLab PART.커스텀 파이프라인 제작 (0) | 2024.01.09 |
---|---|
Augmentation PART.albumentation 활용 (2) | 2023.12.19 |
Pytorch PART.데이터 로더, 폴더 활용(Collate, Sampler 등) (0) | 2023.12.04 |
Pytorch PART.Pytorch 탬플릿 Base 파일 관리 및 활용 (0) | 2023.12.04 |
Pytorch PART.논문을 코드로 구현하는 능력(ResNet) (0) | 2023.11.28 |