diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index a4424dc9ba1..a2236a3de36 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -6,14 +6,13 @@
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
 
-from text_generation_server.utils import flash_attn, paged_attention
-from text_generation_server.utils.layers import (
-    FastLayerNorm,
-    PositionRotaryEmbedding,
-    SpeculativeHead,
+from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.utils.flash_attn import attention
+from text_generation_server.layers import (
+    TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelRowLinear,
+    SpeculativeHead,
     get_linear,
 )
 from text_generation_server.layers.layernorm import (