Commit 1222efb

Merge branch 'main' into yeet

RunDevelopment authored Oct 16, 2023
2 parents bbdd5dc + cfd0e14
Showing 8 changed files with 63 additions and 476 deletions.
10 changes: 6 additions & 4 deletions backend/src/nodes/impl/pytorch/architecture/RRDB.py
@@ -71,13 +71,15 @@ def __init__(

         self.state = self.new_to_old_arch(self.state)

-        self.key_arr = list(self.state.keys())
+        highest_weight_num = max(
+            int(re.search(r"model.(\d+)", k).group(1)) for k in self.state
+        )

-        self.in_nc: int = self.state[self.key_arr[0]].shape[1]
-        self.out_nc: int = self.state[self.key_arr[-1]].shape[0]
+        self.in_nc: int = self.state["model.0.weight"].shape[1]
+        self.out_nc: int = self.state[f"model.{highest_weight_num}.bias"].shape[0]

         self.scale: int = self.get_scale()
-        self.num_filters: int = self.state[self.key_arr[0]].shape[0]
+        self.num_filters: int = self.state["model.0.weight"].shape[0]

         c2x2 = False
         if self.state["model.0.weight"].shape[-2] == 2:
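The change above stops relying on dict insertion order (key_arr[0] / key_arr[-1]) and instead names the keys it needs, deriving the highest layer index from the key names themselves. A minimal standalone sketch of that parsing, using a made-up toy state dict:

import re

# Toy state dict: only the key naming scheme matters here (shapes are fake).
state = {
    "model.0.weight": (64, 3, 3, 3),
    "model.6.weight": (48, 64, 3, 3),
    "model.10.bias": (3,),
}

# Highest "model.<N>" index, independent of the order in which keys appear.
highest_weight_num = max(
    int(re.search(r"model\.(\d+)", k).group(1)) for k in state
)
print(highest_weight_num)                  # 10
print(f"model.{highest_weight_num}.bias")  # the key used to read out_nc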
13 changes: 8 additions & 5 deletions backend/src/nodes/impl/pytorch/architecture/SRVGG.py
@@ -36,7 +36,10 @@ def __init__(
         if "params" in self.state:
             self.state = self.state["params"]

-        self.key_arr = list(self.state.keys())
+        self.weight_keys = [key for key in self.state.keys() if "weight" in key]
+        self.highest_num = max(
+            [int(key.split(".")[1]) for key in self.weight_keys if "body" in key]
+        )

         self.in_nc = self.get_in_nc()
         self.num_feat = self.get_num_feats()
@@ -81,16 +84,16 @@ def __init__(
         self.load_state_dict(self.state, strict=False)

     def get_num_conv(self) -> int:
-        return (int(self.key_arr[-1].split(".")[1]) - 2) // 2
+        return (self.highest_num - 2) // 2

     def get_num_feats(self) -> int:
-        return self.state[self.key_arr[0]].shape[0]
+        return self.state[self.weight_keys[0]].shape[0]

     def get_in_nc(self) -> int:
-        return self.state[self.key_arr[0]].shape[1]
+        return self.state[self.weight_keys[0]].shape[1]

     def get_scale(self) -> int:
-        self.pixelshuffle_shape = self.state[self.key_arr[-1]].shape[0]
+        self.pixelshuffle_shape = self.state[f"body.{self.highest_num}.bias"].shape[0]
         # Assume out_nc is the same as in_nc
         # I can't think of a better way to do that
         self.out_nc = self.in_nc
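The SRVGG keys follow a body.<N>.weight / body.<N>.bias scheme, so the same idea works with a simple split instead of a regex. A small sketch with illustrative key names:

# Hypothetical SRVGG-style keys; values are omitted because only names matter.
state = {
    "body.0.weight": None,
    "body.0.bias": None,
    "body.10.weight": None,
    "body.10.bias": None,
}

weight_keys = [key for key in state if "weight" in key]
highest_num = max(int(key.split(".")[1]) for key in weight_keys if "body" in key)

print(highest_num)             # 10
print((highest_num - 2) // 2)  # num_conv, as computed by get_num_conv()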
2 changes: 1 addition & 1 deletion backend/src/nodes/properties/inputs/file_inputs.py
@@ -102,7 +102,7 @@ def PthFileInput(primary_input: bool = False) -> FileInput:
         input_type_name="PthFile",
         label="Model",
         file_kind="pth",
-        filetypes=[".pt", ".pth", ".ckpt"],
+        filetypes=[".pt", ".pth", ".ckpt", ".safetensors"],
         primary_input=primary_input,
     )

6 changes: 6 additions & 0 deletions backend/src/packages/chaiNNer_pytorch/__init__.py
@@ -86,6 +86,12 @@ def get_pytorch():
             version="0.6.1",
             size_estimate=42.2 * KB,
         ),
+        Dependency(
+            display_name="safetensors",
+            pypi_name="safetensors",
+            version="0.4.0",
+            size_estimate=1 * MB,
+        ),
     ],
     icon="PyTorch",
     color="#DD6B20",
4 changes: 4 additions & 0 deletions backend/src/packages/chaiNNer_pytorch/pytorch/io/load_model.py

@@ -4,6 +4,7 @@
 from typing import Tuple

 import torch
+from safetensors.torch import load_file
 from sanic.log import logger

 from nodes.impl.pytorch.model_loading import load_state_dict
@@ -98,6 +99,9 @@ def load_model_node(path: str) -> Tuple[PyTorchModel, str, str]:
         if "state_dict" in checkpoint:
             checkpoint = checkpoint["state_dict"]
         state_dict = parse_ckpt_state_dict(checkpoint)
+    elif extension == ".safetensors":
+        state_dict = load_file(path, device=str(pytorch_device))
+        logger.info(state_dict.keys())
     else:
         raise ValueError(
             f"Unsupported model file extension {extension}. Please try a supported model type."
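For context, safetensors.torch.load_file returns a plain dict mapping tensor names to tensors, which is why it can feed the same load_state_dict path as the other formats. A minimal sketch (the file path is hypothetical):

from safetensors.torch import load_file

# Load onto CPU; pass e.g. "cuda:0" to materialize tensors on a GPU directly.
state_dict = load_file("4x_model.safetensors", device="cpu")  # hypothetical file
print(list(state_dict.keys())[:5])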
31 changes: 27 additions & 4 deletions backend/src/packages/chaiNNer_pytorch/pytorch/io/save_model.py
@@ -1,16 +1,23 @@
 from __future__ import annotations

 import os
+from enum import Enum

 import torch
+from safetensors.torch import save_file
 from sanic.log import logger

 from nodes.impl.pytorch.types import PyTorchModel
-from nodes.properties.inputs import DirectoryInput, ModelInput, TextInput
+from nodes.properties.inputs import DirectoryInput, EnumInput, ModelInput, TextInput

 from .. import io_group


+class WeightFormat(Enum):
+    PTH = "pth"
+    ST = "safetensors"
+
+
 @io_group.register(
     schema_id="chainner:pytorch:save_model",
     name="Save Model",
@@ -23,12 +30,28 @@
         ModelInput(),
         DirectoryInput(has_handle=True),
         TextInput("Model Name"),
+        EnumInput(
+            WeightFormat,
+            "Weight Format",
+            default=WeightFormat.PTH,
+            option_labels={
+                WeightFormat.PTH: "PyTorch (.pth)",
+                WeightFormat.ST: "SafeTensors (.safetensors)",
+            },
+        ),
     ],
     outputs=[],
     side_effects=True,
 )
-def save_model_node(model: PyTorchModel, directory: str, name: str) -> None:
-    full_file = f"{name}.pth"
+def save_model_node(
+    model: PyTorchModel, directory: str, name: str, weight_format: WeightFormat
+) -> None:
+    full_file = f"{name}.{weight_format.value}"
     full_path = os.path.join(directory, full_file)
     logger.debug(f"Writing model to path: {full_path}")
-    torch.save(model.state, full_path)
+    if weight_format == WeightFormat.PTH:
+        torch.save(model.state_dict(), full_path)
+    elif weight_format == WeightFormat.ST:
+        save_file(model.state_dict(), full_path)
+    else:
+        raise ValueError(f"Unknown weight format: {weight_format}")
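As a standalone illustration of the two branches, here is the same state dict written in both formats (the dict below is a stand-in, not chaiNNer's PyTorchModel):

import torch
from safetensors.torch import save_file

# Stand-in state dict; in the node this comes from the model being saved.
state_dict = {"model.0.weight": torch.zeros(64, 3, 3, 3)}

torch.save(state_dict, "model.pth")         # classic pickle-based .pth
save_file(state_dict, "model.safetensors")  # safetensors: no pickle on load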
23 changes: 11 additions & 12 deletions src/common/migrations.ts
@@ -12,6 +12,7 @@ import {
     NodeData,
     OutputId,
     SchemaId,
+    Size,
 } from './common-types';
 import { log } from './log';
 import { legacyMigrations } from './migrations-legacy';
@@ -1358,20 +1359,18 @@ const writeOutputFrame: ModernMigration = (data) => {
 };

 const separateNodeWidthAndInputHeight: ModernMigration = (data) => {
+    const hasInputSize = (
+        nodeData: Mutable<ReadonlyNodeData>
+    ): nodeData is Mutable<ReadonlyNodeData> & {
+        inputSize?: Record<InputId, Size>;
+    } => 'inputSize' in nodeData;
+
     data.nodes.forEach((node) => {
         let maxWidth = 0;
-        // eslint-disable-next-line @typescript-eslint/ban-ts-comment
-        // @ts-ignore
-        if (node.data.inputSize) {
-            // eslint-disable-next-line @typescript-eslint/ban-ts-comment
-            // @ts-ignore
-            const inputSize = node.data.inputSize as Record<
-                InputId,
-                { height: number; width: number }
-            >;
-            if (!node.data.inputHeight) {
-                node.data.inputHeight = {};
-            }
+        if (hasInputSize(node.data)) {
+            const inputSize = node.data.inputSize!;
+            delete node.data.inputSize;
+            node.data.inputHeight ??= {};
             for (const [inputId, { width, height }] of Object.entries(inputSize)) {
                 maxWidth = Math.max(maxWidth, width);
                 // eslint-disable-next-line @typescript-eslint/ban-ts-comment
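This migration swaps a pair of @ts-ignore casts for a user-defined type guard, which narrows node.data only inside the if branch. Python's typing module offers the same technique via TypeGuard; a rough parallel sketch with illustrative names:

from typing import Any, TypedDict, TypeGuard

class WithInputSize(TypedDict):
    inputSize: dict[int, dict[str, int]]

def has_input_size(data: dict[str, Any]) -> TypeGuard[WithInputSize]:
    # Like the TS guard: a runtime check that also narrows the static type.
    return "inputSize" in data

node_data: dict[str, Any] = {"inputSize": {0: {"width": 240, "height": 80}}}
if has_input_size(node_data):
    # Type checkers now know "inputSize" is present.
    print(node_data["inputSize"])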