Skip to content

Commit

Permalink
adding model utilities
Browse files Browse the repository at this point in the history
	new file:   ../../docker-compose.yml
	new file:   iq_models/efficientnet/efficientnet1d.py
	new file:   iq_models/xcit/xcit1d.py
	new file:   model_utils/general_layers.py
	new file:   model_utils/layer_tools.py
	new file:   model_utils/model_utils_1d/conversions_to_1d.py
	new file:   model_utils/model_utils_1d/iq_sampling.py
	new file:   model_utils/model_utils_1d/layers_1d.py
	new file:   model_utils/simple_models.py
  • Loading branch information
pvallance committed Jun 15, 2024
1 parent 4546d43 commit 0dde0d9
Show file tree
Hide file tree
Showing 9 changed files with 644 additions and 0 deletions.
23 changes: 23 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: torch_sig_container_${PROJECT_NAME}
services:
torchsig_service:
build: .
image: torchsig:v0.5.0
container_name: torchsig_${PROJECT_NAME}
stdin_open: true
tty: true
volumes:
- ./:/workspace/code
ports:
- '${JUP_PORT}:${JUP_PORT}'
environment:
- NVIDIA_VISIBLE_DEVICES=all
- NVIDIA_DRIVER_CAPABILITIES=all
command: jupyter lab --allow-root --ip=0.0.0.0 --no-browser --port ${JUP_PORT} --NotebookApp.token=''
shm_size: 512GB
deploy:
resources:
reservations:
devices:
- capabilities: [gpu]
driver: nvidia
45 changes: 45 additions & 0 deletions torchsig/models/iq_models/efficientnet/efficientnet1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import timm
from torch.nn import Linear

from torchsig.models.model_utils.model_utils_1d.conversions_to_1d import convert_2d_model_to_1d

__all__ = ["EfficientNet1d"]

def EfficientNet1d(
input_channels: int,
n_features: int,
efficientnet_version: str = "b0",
drop_path_rate: float = 0.2,
drop_rate: float = 0.3,
):
"""Constructs and returns a 1d version of the EfficientNet model described in
`"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.
Args:
input_channels (int):
Number of 1d input channels; e.g., common practice is to split complex number time-series data into 2 channels, representing the real and imaginary parts respectively
n_features (int):
Number of output features; should be the number of classes when used directly for classification
efficientnet_version (str):
Specifies the version of efficientnet to use. See the timm efficientnet documentation for details. Examples are 'b0', 'b1', and 'b4'
drop_path_rate (float):
Drop path rate for training
drop_rate (float):
Dropout rate for training
"""
mdl = convert_2d_model_to_1d(
timm.create_model(
"efficientnet_" + efficientnet_version,
in_chans=input_channels,
drop_path_rate=drop_path_rate,
drop_rate=drop_rate,
)
)
mdl.classifier = Linear(mdl.classifier.in_features, n_features)
return mdl
83 changes: 83 additions & 0 deletions torchsig/models/iq_models/xcit/xcit1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import timm
from torch import cat
from torch.nn import Module, Conv1d, Linear

from torchsig.models.model_utils.model_utils_1d.iq_sampling import ConvDownSampler, Chunker

__all__ = ["XCiT1d"]

class XCiT1d(Module):
"""A 1d implementation of the XCiT architecture from
`"XCiT: Cross-Covariance Image Transformers" <https://arxiv.org/pdf/2106.09681.pdf>`_.
Args:
input_channels (int):
Number of 1d input channels; e.g., common practice is to split complex number time-series data into 2 channels, representing the real and imaginary parts respectively
n_features (int):
Number of output features; should be the number of classes when used directly for classification
xcit_version (str):
Specifies the version of efficientnet to use. See the timm xcit documentation for details. Examples are 'nano_12_p16_224', and 'xcit_tiny_12_p16_224'
drop_path_rate (float):
Drop path rate for training
drop_rate (float):
Dropout rate for training
ds_method (str):
Specifies the downsampling method to use in the model. Currently convolutional downsampling and chunking are supported, using string arguments 'downsample' and 'chunk' respectively
ds_rate (int):
Specifies the downsampling rate; e.g., ds_rate=2 will downsample the imput by a factor of 2
"""
def __init__(self,
input_channels: int,
n_features: int,
xcit_version: str = "nano_12_p16_224",
drop_path_rate: float = 0.0,
drop_rate: float = 0.3,
ds_method: str = "downsample",
ds_rate: int = 2):

super().__init__()
self.backbone = timm.create_model(
"xcit_" + xcit_version,
num_classes=n_features,
in_chans=input_channels,
drop_path_rate=drop_path_rate,
drop_rate=drop_rate,
)

W = self.backbone.num_features
self.grouper = Conv1d(W, n_features, 1)
if ds_method == "downsample":
self.backbone.patch_embed = ConvDownSampler(input_channels, W, ds_rate)
elif ds_method == "chunk":
self.backbone.patch_embed = Chunker(input_channels, W, ds_rate)
else:
raise ValueError(ds_method + " is not a supported downsampling method; currently 'downsample', and 'chunk' are supported")

self.backbone.head = Linear(self.backbone.head.in_features, n_features)

def forward(self, x):
mdl = self.backbone
B = x.shape[0]
x = self.backbone.patch_embed(x)

Hp, Wp = x.shape[-1], 1
pos_encoding = mdl.pos_embed(B, Hp, Wp).reshape(B, -1, Hp).permute(0, 2, 1).half()
x = x.transpose(1, 2) + pos_encoding
for blk in mdl.blocks:
x = blk(x, Hp, Wp)
cls_tokens = mdl.cls_token.expand(B, -1, -1)
x = cat((cls_tokens, x), dim=1)
for blk in mdl.cls_attn_blocks:
x = blk(x)
x = mdl.norm(x)
x = self.grouper(x.transpose(1, 2)[:, :, :1]).squeeze()
if x.dim() == 2:
x = x.unsqueeze(0)
return x
82 changes: 82 additions & 0 deletions torchsig/models/model_utils/general_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from torch import mean
from torch.nn import Module, LSTM

class DebugPrintLayer(Module):
"""
A layer for debugging pytorch models; prints out the shape and data type of the input tensor at runtime
returns he input tensor unchanged
"""
def __init__(self):
super().__init__()

def forward(self, x):
print(x.shape, x.dtype)
return x

class ScalingLayer(Module):
"""
A layer that given input tensor x outputs scale_val * x
used to linearly scale inputs by a fixed value
"""
def __init__(self, scale_val):
super().__init__()
self.scale_val = scale_val

def forward(self, x):
return self.scale_val * x

class DropChannel(Module):
"""
A layer that drops the last color channel of an image [must be in channel-first form]
"""
def __init__(self):
super().__init__()

def forward(self, x):
return x[:,:-1,:,:]

class LSTMImageReader(Module):
"""
TODO add some real documentation here
"""
def __init__(self, input_width, lstm_width, img_shape, num_layers=2):
super().__init__()
self.img_shape = img_shape
self.img_height = img_shape[0]
self.img_width = img_shape[1]
self.input_width = input_width
self.lstm_width = lstm_width
self.lstm_model = LSTM(self.input_width,self.lstm_width,num_layers,True,True,0,False,self.img_height)

def forward(self, x):
output, (h,c) = self.lstm_model(x.transpose(1,2))
img_tensor = output.transpose(1,2)[:,:self.img_height,:self.img_width] #take only the last img_height entries in the outut sequence
return img_tensor.reshape([x.size(0),1,self.img_height,self.img_width])

class Reshape(Module):
"""
A layer that reshapes the input tensor to a tensor of the given shape
if keep_batch_dim is True (defaults to True), the batch dimension is excluded from the reshape operation; otherwise it is included
"""
def __init__(self, shape, keep_batch_dim=True):
super(Reshape, self).__init__()
self.shape = shape
self.keep_batch_dim = keep_batch_dim

def forward(self, x):
if self.keep_batch_dim:
batch_dim = x.size(0)
shape = [batch_dim] + list(self.shape)
return x.view(shape)
return x.view(self.shape)

class Mean(Module):
"""
A layer which returns the mean(s) along the dimension specified by dim of the input tensor
"""
def __init__(self, dim):
super(Mean, self).__init__()
self.dim = dim

def forward(self, x):
return mean(x,self.dim)
111 changes: 111 additions & 0 deletions torchsig/models/model_utils/layer_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
def get_layer_list(model):
"""
returns a list of all layers in the input model, including layers in any nested models therein
layers are listed in forward-pass order
"""
arr = []
final_arr = []
try:
arr = [m for m in model.modules()]
if len(arr) > 1:
for module in arr[1:]:
final_arr += get_module_list(module)
return final_arr
else:
return arr
except:
raise(NotImplementedError("expected module list to be populated, but no '_modules' field was found"))

def replace_layer(old_layer, new_layer, model):
"""
search through model until old_layer is found, and replace it with new layer;
returns True is old_layer was found; False otherwise
"""
try:
modules = model._modules
for k in modules.keys():
if modules[k] == old_layer:
modules[k] = new_layer
return True
else:
if replace_layer(old_layer, new_layer, modules[k]):
return True
return False
except:
raise(NotImplementedError("expected module list to be populated, but no '_modules' field was found"))

def is_same_type(layer1, layer2):
"""
returns True if layer1 and layer2 are of the same type; false otherwise
if a class is input as layer2 [e.g., is_same_type(my_conv_layer, Conv2d) ], the type defined by the class is used
if a string is input as layer2, the string is matched to the name of the class of layer1
"""
if type(layer2) == type:
return type(layer1) == layer2
elif type(layer2) == str:
return type(layer1).__name__ == layer2
else:
return type(layer1) == type(layer2)

def same_type_fn(layer1):
"""
curried version of is_same_type; returns a function f such than f(layer2) <-> is_same_type(layer1, layer2)
"""
return lambda x: is_same_type(x, layer1)


def replace_layers_on_condition(model, condition_fn, new_layer_factory_fn):
"""
search through model finding all layers L such that conditional_fn(L), and replace them with new_layer_factory_fn(L)
returns true if at least one layer was replaced; false otherwise
"""
has_replaced = False
try:
modules = model._modules
for k in modules.keys():
if condition_fn(modules[k]):
modules[k] = new_layer_factory_fn(modules[k])
has_replaced = True
else:
has_replaced = replace_layers_on_condition(modules[k], condition_fn, new_layer_factory_fn) or has_replaced
return has_replaced
except:
raise(NotImplementedError("expected module list to be populated, but no '_modules' field was found"))

def replace_layers_on_conditions(model, condition_factory_pairs):
"""
search through model finding all layers L such that for some ordered pair [conditional_fn, new_layer_factory_fn] in condition_factory_pairs,
conditional_fn(L), and replace them with new_layer_factory_fn(L)
layers will only be replaced once, so the first conditional for which a layer returns true will be last conditional to which it is compared
returns true if at least one layer was replaced; false otherwise
"""
has_replaced = False
try:
modules = model._modules
for k in modules.keys():
for (condition_fn, new_layer_factory_fn) in condition_factory_pairs:
if condition_fn(modules[k]):
modules[k] = new_layer_factory_fn(modules[k])
has_replaced = True
break
else:
has_replaced = replace_layers_on_conditions(modules[k], condition_factory_pairs) or has_replaced
return has_replaced
except:
raise(NotImplementedError("expected module list to be populated, but no '_modules' field was found"))

def replace_layers_of_type(model, layer_type, new_layer_factory_fn):
"""
search through model finding all layers L of type layer_type and replace with new_layer_factory_fn(L)
returns true if at least one layer was replaced; false otherwise
"""
return replace_layers_on_condition(model, lambda x: is_same_type(x,layer_type), new_layer_factory_fn)

def replace_layers_of_types(model, type_factory_pairs):
"""
search through model finding all layers L such that for some ordered pair [layer_type, new_layer_factory_fn] in type_factory_pairs,
L is of type layer_type, and replace with new_layer_factory_fn(L)
returns true if at least one layer was replaced; false otherwise
"""
condition_factory_pairs = [(same_type_fn(layer_type), new_layer_factory_fn) for (layer_type, new_layer_factory_fn) in type_factory_pairs]
return replace_layers_on_conditions(model, condition_factory_pairs)
Loading

0 comments on commit 0dde0d9

Please sign in to comment.