This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Merge pull request #46 from hpcaitech/lzm_develop_2

Add comments to Batch Manager

dujiangsu authored May 6, 2022
2 parents f370a39 + f4a8664 commit 238c811
Showing 2 changed files with 114 additions and 37 deletions.
132 changes: 101 additions & 31 deletions energon/server/batch_manager.py
@@ -1,4 +1,9 @@
import torch
"""
------------------------------------------
Class Batch Manager and the function for generating cached cost.
This code modifies the batch wrapping algorithm of Turbo Transformer.
------------------------------------------
"""
import time
from scipy import stats
import numpy as np
@@ -7,28 +12,47 @@
import random
import redis
import os
-from tqdm import tqdm, trange
+from tqdm import trange
import threading
from readerwriterlock import rwlock
import logging


def generate_cached_cost(engine, max_seq_len: int = 1024, max_batch_size: int = 16, step: int = 1,
                         repeat_round: int = 3):
-    def select_top_k(predictions, k=10):
-        predicted_index = random.choice(
-            predictions[0, -1, :].sort(descending=True)[1][:k]).item()
-        return predicted_index
"""
Test the running time for different sequence length and batch size on the current machine.
:param engine: InferenceEngine from energon.engine
:type engine: InferenceEngine
:param max_seq_len: The max sequence length that is measured.
:param max_batch_size: The max batch size that is measured.
:param step: Run time is measured every other 'step' of sequence length
:param repeat_round: We inference current batch 'repeat_round' times and take average.
"""

print("fetching cached cost")
+    def select_top_k(temp_predictions, top_k: int = 10):
+        """
+        Sample one token id from the 'top_k' most probable entries of the 50257-word vocabulary
+        at the last position, according to the probabilities given by temp_predictions.
+        :param temp_predictions: transformer output tensor of size (batch size, sequence length, vocab size)
+            which contains the probabilities for each word in this batch
+        :type temp_predictions: torch.Tensor
+        :param top_k: how many top words to choose from
+        """
+        temp_predicted_index = random.choice(
+            temp_predictions[0, -1, :].sort(descending=True)[1][:top_k]).item()
+        return temp_predicted_index

+    logging.info("fetching cached cost")
    cached_name = "cached_cost_{}_{}_{}_{}.npy".format(max_seq_len, max_batch_size, step, repeat_round)
    if os.path.exists(cached_name):
-        print("loading cached cost from file")
+        logging.info("loading cached cost from file")
        cached_cost = np.load(cached_name).tolist()
    else:
-        print("generating new cached cost")
+        logging.info("generating new cached cost")
        cached_cost = [[0 for i in range(max_batch_size + 1)] for j in range(max_seq_len + 1)]
        input_text = ""
-        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+        tokenizer = GPT2Tokenizer.from_pretrained("./")
        for tmp_len in trange(1, max_seq_len + 1, step):
            input_text += "test "
            for tmp_batch in range(1, max_batch_size + 1):
@@ -39,25 +63,35 @@ def select_top_k(predictions, k=10):
                    output = engine.run(input_token)
                    predictions = output.to_here()
                    predicted_index = select_top_k(predictions, top_k=1)
-                    total_predicted_text = tokenizer.decode(predicted_index)
+                    tokenizer.decode(predicted_index)
                time_cost = (time.time() - start_time) / repeat_round
                cached_cost[tmp_len][tmp_batch] = time_cost
                for k in range(1, step):
                    cached_cost[tmp_len + k][tmp_batch] = time_cost
        np.save(cached_name, np.array(cached_cost))
-    print("cached cost loaded")
+    logging.info("cached cost loaded")
    return cached_cost
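
For reference, a minimal sketch of how the saved table might be consumed later (a hypothetical usage, not part of the commit; with the defaults above the file name is cached_cost_1024_16_1_3.npy):

import numpy as np

cached_cost = np.load("cached_cost_1024_16_1_3.npy").tolist()
# rows are padded sequence lengths, columns are batch sizes, values are seconds
est_seconds = cached_cost[256][8]  # estimated latency of a batch of 8 at length 256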


-class single_request():
+class single_request:
    def __init__(self, input_, time_stamp: float, input_str: str):
+        """
+        Class to store the information related to a single request.
+        :param input_: the output of the GPT2Tokenizer, a dict containing input_ids and attention_mask
+        :param time_stamp: the time stamp when the request was received; it is used as a key to
+            identify the request
+        :param input_str: the input string of the request
+        """
        self.input_ = input_
        self.text = input_str
        self.time_ = time_stamp
        self.seq_len = input_['input_ids'].shape[1]
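
A hypothetical construction of a request (assuming the module's GPT2Tokenizer import and a locally available GPT-2 tokenizer; not part of the commit):

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
text = "hello world"
req = single_request(tokenizer(text, return_tensors="pt"), time.time(), text)
print(req.seq_len)  # number of tokens in the tokenized request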


class Manager:
+    """
+    Base class for batch managers.
+    """
    def __init__(self):
        pass

@@ -66,19 +100,30 @@ def insert_req(self, time_stamp: float, input_ids, input_str: str):


class Batch_Manager(Manager):
+    """
+    This batch manager maintains a queue of requests to be processed. The requests in the
+    queue are wrapped into batches according to their sequence length and the priority
+    calculated with the equation in cal_priority, and are then sent into the inference engine.
+    """
    def __init__(self, engine: InferenceEngine, cached_cost: list, init_mu: int = 512, init_theta: int = 180,
-                 max_batch_size: int = 32, lr: float = 0.01, max_seq_len=1024):
+                 max_batch_size: int = 32, lr: float = 0.01):
+        """
+        :param engine: the InferenceEngine from energon.engine
+        :param cached_cost: the output of the function generate_cached_cost
+        :param init_mu: the initial mean we assume for the incoming sequence lengths
+        :param init_theta: the initial standard deviation we assume for the incoming sequence lengths
+        :param max_batch_size: the maximum number of requests that can be wrapped into one batch
+        :param lr: the learning rate used to update the assumed mean and standard deviation of the
+            normal distribution of sequence lengths
+        """
        super().__init__()
        self.engine = engine
        self.max_batch_size = max_batch_size
        self.lr = lr
        self.mu = init_mu
        self.theta = init_theta
-        self.max_seq_len = max_seq_len
-        # self.normal_weight = self._init_normal_dist_weight()
        self.req_list = []
        self.req_list_lock = rwlock.RWLockFair()
        self.read_lock = self.req_list_lock.gen_rlock()
        self.write_lock = self.req_list_lock.gen_wlock()
        self.cached_cost = cached_cost
        self.tokenizer = GPT2Tokenizer.from_pretrained('/home/lcdjs/hf_gpt2')
@@ -89,13 +134,26 @@ def __init__(self, engine: InferenceEngine, cached_cost: list, init_mu: int = 51
        self.main_thread.start()

    def insert_req(self, time_stamp: float, input_ids, input_str: str):
+        """
+        Build a single_request object from the input and insert it into the request queue.
+        """
        tmp_req = single_request(input_ids, time_stamp, input_str)
        self.write_lock.acquire()
        self.req_list.append(tmp_req)
        self.req_list.sort(key=lambda x: x.seq_len)
        self.write_lock.release()

    def cal_priority(self, batch_list: list, cur_stamp: float):
+        """
+        Given a wrapped batch, calculate its priority in order to decide which batch should be
+        given to the inference engine next.
+        The equation is based on the sequence length, the batch size, and the longest wait time
+        within the batch.
+        We assume that the lengths of the requests follow a normal distribution, so for batches
+        whose length has a higher probability of appearing, we tend to let them wait a little
+        longer for other requests of similar length, in order to increase the batch size.
+        Batches with a larger batch size also gain higher priority.
+        To avoid starvation, an exponential function raises the priority of batches that have
+        waited for a long time.
+        """
        cur_len = batch_list[-1].seq_len
        earliest_timestamp = min([i.time_ for i in batch_list])

@@ -107,18 +165,18 @@ def cal_priority(self, batch_list: list, cur_stamp: float):
        priority = appear_possibility_weight * batch_size * np.exp(wait_time)
        return priority
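
A rough numeric illustration of how the three factors trade off (hypothetical values, assuming the default mu=512 and theta=180; any scaling applied in the collapsed lines above is omitted):

from scipy import stats
import numpy as np

def norm_weight(seq_len, mu=512, theta=180):
    # probability mass the assumed normal distribution puts on this exact length
    return stats.norm(mu, theta).cdf(seq_len) - stats.norm(mu, theta).cdf(seq_len - 1)

# a batch of 8 near-mean requests that has waited 1 s ...
print(norm_weight(510) * 8 * np.exp(1.0))   # ~0.0022 * 8 * 2.72 = ~0.048
# ... outranks a batch of 2 rare-length requests that has waited 2 s
print(norm_weight(1020) * 2 * np.exp(2.0))  # ~4e-5 * 2 * 7.39 = ~6e-4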

-    # def _init_normal_dist_weight(self):
-    #     temp_weight_list = [0]
-    #     for i in range(1, self.max_seq_len):
-    #         temp_weight_list.append(stats.norm(self.mu, self.theta).cdf(i) -
-    #                                 stats.norm(self.mu, self.theta).cdf(i - 1))
-    #     return temp_weight_list

    def cal_norm_weight(self, seq_len):
+        """
+        Approximately estimate the probability of a certain sequence length using the assumed
+        normal distribution.
+        """
        return stats.norm(self.mu, self.theta).cdf(seq_len) - \
               stats.norm(self.mu, self.theta).cdf(seq_len - 1)

    def update_norm(self, batch_: list):
+        """
+        Every time a batch is sent to the inference engine, update the mu and theta of the
+        assumed distribution using the current batch and the preset learning rate.
+        """
        new_mu = np.mean([i.seq_len for i in batch_])
        delta_mu = new_mu - self.mu
        self.mu += self.lr * delta_mu
@@ -129,14 +187,18 @@ def update_norm(self, batch_: list):
        return
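
The theta update is collapsed in this view. A minimal sketch of the full moving-average step, assuming theta is nudged the same way as mu (a sketch, not the committed code):

import numpy as np

def ema_update(mu, theta, batch_lens, lr=0.01):
    # pull the assumed mean toward the mean length of the batch just served
    mu += lr * (np.mean(batch_lens) - mu)
    new_theta = np.std(batch_lens)
    if new_theta > 0:  # a batch of identical lengths carries no spread information
        theta += lr * (new_theta - theta)
    return mu, theta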

    def wrap_batch(self):
+        """
+        Given the sorted request list, use dynamic programming over the cached cost table to
+        find the best way to wrap requests into batches.
+        The algorithm in this function comes from the Turbo Transformer paper.
+        """
        self.write_lock.acquire()
        states = [0]
        start_idx_list = [0]
        for i in range(1, len(self.req_list) + 1):
            j = i - 1
            start_idx = i - 1
            cur_length = self.req_list[i - 1].seq_len
            # print(i, j, cur_length)
            min_cost = self.cached_cost[cur_length][1] + states[j]
            while j > max(0, i - self.max_batch_size):
                tmp_cost = states[j - 1] + \
@@ -169,30 +231,38 @@ def wrap_batch(self):
        return result_batch
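
Most of the DP loop is collapsed above. The following self-contained sketch illustrates the Turbo Transformer style algorithm the docstring describes; the names and the return value (a full partition rather than the single best batch the committed code picks) are illustrative assumptions:

def wrap_batch_sketch(seq_lens, cached_cost, max_batch_size):
    # seq_lens is sorted ascending; states[i] is the minimal total cost of
    # serving the first i requests, where a batch of requests j..i-1 is padded
    # to the length of request i-1 (the longest in the batch).
    n = len(seq_lens)
    states = [0.0] * (n + 1)
    split = [0] * (n + 1)  # split[i]: start index of the last batch in the best plan
    for i in range(1, n + 1):
        pad_len = seq_lens[i - 1]
        best_cost, best_j = float("inf"), i - 1
        for j in range(i - 1, max(-1, i - 1 - max_batch_size), -1):
            cost = states[j] + cached_cost[pad_len][i - j]  # batch covers j..i-1
            if cost < best_cost:
                best_cost, best_j = cost, j
        states[i], split[i] = best_cost, best_j
    batches, i = [], n  # walk back through the split points
    while i > 0:
        batches.append(list(range(split[i], i)))
        i = split[i]
    return batches[::-1]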

    def processing_batch(self):
+        """
+        The background process that continuously calls wrap_batch, puts the batch into the
+        inference engine, and starts new threads that wait for and publish the inference results.
+        """
        while self.running_flag:
            if len(self.req_list) > 0:
                target_batch = self.wrap_batch()
                pad_len = target_batch[-1].seq_len
-                print("A batch with {} requests and length of {} packed".format(len(target_batch), pad_len))
+                logging.info("A batch with {} requests and length of {} packed".format(len(target_batch), pad_len))
                input_text = [i.text for i in target_batch]
                input_ids = self.tokenizer(input_text, padding="longest", return_tensors="pt")
-                print("input_ids shape: {}".format(input_ids['input_ids'].shape))
-                print("attention_mask shape: {}".format(input_ids['attention_mask'].shape))
+                # input_ids = self.tokenizer(input_text, return_tensors="pt")
+                # print(input_ids)
+                # print("input_ids shape: {}".format(input_ids['input_ids'].shape))
+                # print("attention_mask shape: {}".format(input_ids['attention_mask'].shape))
                output = self.engine.run(input_ids)
                pub_thread = threading.Thread(target=self.publish_result, args=(output, target_batch))
                pub_thread.start()

    def publish_result(self, output, target_batch):
+        """
+        A background process that waits for the inference result and uses the Redis publisher
+        to publish the result to the waiting requests.
+        :param output: the rpc reference of the inference result
+        :param target_batch: the input batch
+        """
        def select_top_k(batch_id, predictions, k=10):
            predicted_index = random.choice(
                predictions[batch_id, -1, :].sort(descending=True)[1][:k]).item()
            return predicted_index

        # print("output: {}".format(output))
        predictions = output.to_here()
        # print("predictions: {}".format(predictions), flush=True)
        # decode_list = self.tokenizer.decode(predictions)
        for i in range(len(target_batch)):
            # print(i, predictions.shape, target_batch)
            temp_st = target_batch[i].time_
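The remainder of the loop is collapsed above. A hypothetical sketch of the Redis hand-off the docstring describes (client setup and channel naming are assumptions, not taken from the commit):

import redis

def publish_one(result_text: str, time_stamp: float):
    client = redis.Redis(host="localhost", port=6379)
    # the time stamp identifies the request, so it can double as the channel name
    client.publish(str(time_stamp), result_text)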
19 changes: 13 additions & 6 deletions requirements.txt
@@ -1,10 +1,17 @@
torch>=1.8
-numpy
-tqdm
+numpy~=1.21.2
+tqdm~=4.64.0
psutil
packaging
-fastapi
+fastapi~=0.75.1
uvicorn==0.14
-typer
-redis
-scipy
+typer~=0.4.0
+redis~=4.2.2
+scipy~=1.8.0
+energon~=0.0.1b0
+pytest~=7.1.1
+requests~=2.27.1
+click~=8.1.2
+transformers~=4.18.0
+readerwriterlock~=1.0.9
+setuptools~=58.0.4
