Skip to content

Commit

Permalink
Merge branch 'release/0.3.3'
Browse files Browse the repository at this point in the history
  • Loading branch information
ControlNet committed Sep 24, 2023
2 parents 93ef891 + f4bd553 commit dcb08f7
Show file tree
Hide file tree
Showing 14 changed files with 86 additions and 52 deletions.
24 changes: 16 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -474,12 +474,7 @@ This library provides event bus based reactive tools. The API integrates the Pyt

```python
# useful decorators for default event bus
from tensorneko.util import (
subscribe, # run in the main thread
subscribe_thread, # run in a new thread
subscribe_async, # run async
subscribe_process # run in a new process
)
from tensorneko.util import subscribe
# Event base type
from tensorneko.util import Event

Expand All @@ -488,21 +483,34 @@ class LogEvent(Event):
self.message = message

# the event argument should be annotated correctly
@subscribe
@subscribe # run in the main thread
def log_information(event: LogEvent):
print(event.message)

@subscribe_thread

@subscribe.thread # run in a new thread
def log_information_thread(event: LogEvent):
print(event.message, "in another thread")


@subscribe.coro # run with async
async def log_information_async(event: LogEvent):
print(event.message, "async")


@subscribe.process # run in a new process
def log_information_process(event: LogEvent):
print(event.message, "in a new process")

if __name__ == '__main__':
# emit an event, and then the event handler will be invoked
# The sequential order is not guaranteed
LogEvent("Hello world!")
# one possible output:
# Hello world! in another thread
# Hello world! async
# Hello world!
# Hello world! in a new process
```

### Multiple Dispatch
Expand Down
3 changes: 2 additions & 1 deletion src/tensorneko/neko_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ def log_on_training_step_end(self, output: STEP_OUTPUT) -> None:
for key, value in output.items():
history_item[key] = value
self.logger.log_metrics({key: value}, step=self.trainer.global_step)
self.log(key, value, on_epoch=False, on_step=True, logger=False, sync_dist=self.distributed)
self.log(key, value, on_epoch=False, on_step=True, logger=False, prog_bar=key == "loss",
sync_dist=self.distributed)
self.history.append(history_item)

def on_test_batch_end(
Expand Down
10 changes: 4 additions & 6 deletions src/tensorneko/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from tensorneko_util.util import dispatch, AverageMeter, tensorneko_util_path
from tensorneko_util.util.fp import Seq, AbstractSeq, curry, F, Stream, return_option, Option, Monad, Eval, _, __
from tensorneko_util.util import ref, Timer, Singleton
from tensorneko_util.util.eventbus import Event, EventBus, EventHandler, subscribe, subscribe_async, \
subscribe_process, subscribe_thread
from tensorneko_util.util.eventbus import Event, EventBus, EventHandler, subscribe
from tensorneko_util.util import download_file, WindowMerger
from . import type
from .configuration import Configuration
from .misc import reduce_dict_by, summarize_dict_by, with_printed_shape, is_bad_num, count_parameters, compose, \
generate_inf_seq, listdir, with_printed, ifelse, dict_add, as_list, identity, list_to_dict, circular_pad
generate_inf_seq, listdir, with_printed, ifelse, dict_add, as_list, identity, list_to_dict, circular_pad, \
load_py
from .misc import get_tensorneko_path
from .dispatched_misc import sparse2binary, binary2sparse
from .reproducibility import Seed
Expand Down Expand Up @@ -67,11 +67,9 @@
"EventBus",
"EventHandler",
"subscribe",
"subscribe_async",
"subscribe_process",
"subscribe_thread",
"Singleton",
"circular_pad",
"load_py",
"download_file",
"WindowMerger",
]
3 changes: 2 additions & 1 deletion src/tensorneko/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.nn import Module

from tensorneko_util.util.misc import generate_inf_seq, listdir, with_printed, ifelse, dict_add, as_list, \
identity, list_to_dict, compose, circular_pad
identity, list_to_dict, compose, circular_pad, load_py
from .type import T, A


Expand Down Expand Up @@ -159,3 +159,4 @@ def get_tensorneko_path() -> str:
identity = identity
list_to_dict = list_to_dict
circular_pad = circular_pad
load_py = load_py
5 changes: 3 additions & 2 deletions src/tensorneko_util/io/_default_backends.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ..backend.visual_lib import VisualLib
from ..backend.audio_lib import AudioLib


def _default_image_io_backend():
Expand All @@ -23,7 +24,7 @@ def _default_video_io_backend():


def _default_audio_io_backend():
if VisualLib.pytorch_available():
return VisualLib.PYTORCH
if AudioLib.pytorch_available():
return AudioLib.PYTORCH
else:
raise ValueError("No backend available. Please install Torchaudio.")
14 changes: 11 additions & 3 deletions src/tensorneko_util/io/video/video_writer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pathlib
import warnings
from typing import Optional, Union

import numpy as np
Expand Down Expand Up @@ -36,12 +37,12 @@ def to(cls, path: str, video: VideoData, audio_codec: str = None, channel_first:
@classmethod
@dispatch
def to(cls, path: str, video: T_ARRAY, video_fps: float, audio: T_ARRAY = None,
audio_fps: Optional[int] = None, audio_codec: str = None, channel_first: bool = False,
audio_fps: int = None, audio_codec: str = None, channel_first: bool = False,
backend: VisualLib = None
) -> None:
"""
Write to video file from :class:`~torch.Tensor` or :class:`~numpy.ndarray` with (T, C, H, W).
TODO: Buggy when the argument is too much.
Args:
path (``str``): The path of output file.
video (:class:`~torch.Tensor` | :class:`~numpy.ndarray`): The video tensor or array with (T, C, H, W) for
Expand All @@ -63,13 +64,19 @@ def to(cls, path: str, video: T_ARRAY, video_fps: float, audio: T_ARRAY = None,
if channel_first:
video = rearrange(video, "t c h w -> t h w c")

# if the video is already a uint8 array, give a warning
# to uint8
if isinstance(video, np.ndarray):
if video.dtype == np.uint8:
warnings.warn("The video array is already a uint8 array. The output video may be incorrect.")
video = (video * 255).astype(np.uint8)
else:
try:
import torch
if isinstance(video, torch.Tensor):
if video.dtype == torch.uint8:
warnings.warn("The video tensor is already a uint8 tensor. The output video may be incorrect.")

if backend == VisualLib.PYTORCH:
video = (video * 255).type(torch.IntTensor)
else:
Expand Down Expand Up @@ -101,7 +108,8 @@ def to(cls, path: str, video: T_ARRAY, video_fps: float, audio: T_ARRAY = None,
if not VisualLib.pytorch_available():
raise ValueError("Torchvision is not installed.")
import torchvision
torchvision.io.write_video(path, video, video_fps, audio_fps=audio_fps, audio_array=audio)
torchvision.io.write_video(path, video, video_fps, audio_fps=audio_fps, audio_array=audio,
audio_codec=audio_codec)
elif backend == VisualLib.FFMPEG:
if audio is not None:
raise ValueError("Write audio is not supported in ffmpeg backend.")
Expand Down
4 changes: 2 additions & 2 deletions src/tensorneko_util/preprocess/ffmpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,11 @@ def clip_video(video_path: str, output_path: str, start_time: float, end_time: f

if precise:
args = [
"-ss", str(start_time), "-to", str(end_time), "-i", video_path, "-c", "copy", output_path, *ffmpeg_args
"-i", video_path, "-ss", str(start_time), "-to", str(end_time), "-c", "copy", output_path, *ffmpeg_args
]
else:
args = [
"-i", video_path, "-ss", str(start_time), "-to", str(end_time), "-c", "copy", output_path, *ffmpeg_args
"-ss", str(start_time), "-to", str(end_time), "-i", video_path, "-c", "copy", output_path, *ffmpeg_args
]

return ffmpeg_command(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
8 changes: 3 additions & 5 deletions src/tensorneko_util/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from .dispatcher import dispatch
from .fp import __, F, _, Stream, return_option, Option, Monad, Eval, Seq, AbstractSeq, curry
from .misc import generate_inf_seq, compose, listdir, with_printed, ifelse, dict_add, as_list, identity, list_to_dict, \
get_tensorneko_util_path, circular_pad
get_tensorneko_util_path, circular_pad, load_py
from .dispatched_misc import sparse2binary, binary2sparse
from .ref import ref
from .timer import Timer
from .eventbus import Event, EventBus, EventHandler, subscribe, subscribe_async, subscribe_process, subscribe_thread
from .eventbus import Event, EventBus, EventHandler, subscribe
from .singleton import Singleton
from .downloader import download_file
from .window_merger import WindowMerger
Expand Down Expand Up @@ -46,11 +46,9 @@
"Event",
"EventBus",
"subscribe",
"subscribe_async",
"subscribe_process",
"subscribe_thread",
"Singleton",
"circular_pad",
"load_py",
"download_file",
"WindowMerger",
]
7 changes: 2 additions & 5 deletions src/tensorneko_util/util/eventbus/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from .event import Event
from .decorator import subscribe, subscribe_async, subscribe_process, subscribe_thread
from .decorator import subscribe
from .bus import EventBus, EventHandler

__all__ = [
"Event",
"EventBus",
"EventHandler",
"subscribe",
"subscribe_async",
"subscribe_process",
"subscribe_thread",
"subscribe"
]
21 changes: 13 additions & 8 deletions src/tensorneko_util/util/eventbus/decorator.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
from typing import Callable, Coroutine, Any

from .bus import EventBus
from .event import E


def subscribe(func):
return EventBus.default.subscribe(func)
class SubscribeDecorator:

def __call__(self, func: Callable[[E], None]):
return EventBus.default.subscribe(func)

def subscribe_async(func):
return EventBus.default.subscribe_async(func)
def coro(self, func: Callable[[E], Coroutine[Any, Any, None]]):
return EventBus.default.subscribe_async(func)

def thread(self, func: Callable[[E], None]):
return EventBus.default.subscribe_thread(func)

def subscribe_thread(func):
return EventBus.default.subscribe_thread(func)
def process(self, func: Callable[[E], None]):
return EventBus.default.subscribe_process(func)


def subscribe_process(func):
return EventBus.default.subscribe_process(func)
subscribe = SubscribeDecorator()
17 changes: 17 additions & 0 deletions src/tensorneko_util/util/misc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib.util
import os
from functools import reduce
from os.path import dirname, abspath
Expand Down Expand Up @@ -218,3 +219,19 @@ def circular_pad(x: List, target: int) -> List:
return x + x[:target - len(x)]
else:
return circular_pad(x + x, target)


def load_py(path: str) -> Any:
"""
Load a python file as a module.
Args:
path (``str``): The path of the python file.
Returns:
``Any``: The loaded module.
"""
spec = importlib.util.spec_from_file_location("module.name", path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
14 changes: 7 additions & 7 deletions test/util/eventbus/test_event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from unittest.mock import patch
from itertools import permutations

from tensorneko_util.util import subscribe, subscribe_async, subscribe_thread, Event
from tensorneko_util.util import subscribe, Event


@subscribe
Expand Down Expand Up @@ -35,13 +35,13 @@ def thread_handler_normal(event: ThreadEvent):
print("thread_handler_normal is called, x =", event.x)


@subscribe_thread
@subscribe.thread
def thread_handler_thread(event: ThreadEvent):
time.sleep(random.random())
print("thread_handler_thread is called, x =", event.x)


@subscribe_thread
@subscribe.thread
def thread_handler_thread(event: ThreadEvent):
time.sleep(random.random())
print("thread_handler_thread2 is called, x =", event.x)
Expand All @@ -52,13 +52,13 @@ def __init__(self, x):
self.x = x


@subscribe_async
@subscribe.coro
async def async_handler_async(event: AsyncEvent):
await asyncio.sleep(random.random())
print("async_handler_async is called, x =", event.x)


@subscribe_async
@subscribe.coro
async def async_handler_async(event: AsyncEvent):
await asyncio.sleep(random.random())
print("async_handler_async2 is called, x =", event.x)
Expand All @@ -75,13 +75,13 @@ def mixed_handler_normal(event: MixedEvent):
print("mixed_handler_normal is called, x =", event.x)


@subscribe_thread
@subscribe.thread
def mixed_handler_thread(event: MixedEvent):
time.sleep(random.random())
print("mixed_handler_thread is called, x =", event.x)


@subscribe_async
@subscribe.coro
async def mixed_handler_async(event: MixedEvent):
await asyncio.sleep(random.random())
print("mixed_handler_async is called, x =", event.x)
Expand Down
6 changes: 3 additions & 3 deletions test/util/eventbus/test_handler_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
import unittest

from tensorneko_util.util import subscribe, subscribe_thread, subscribe_async, Event
from tensorneko_util.util import subscribe, Event
from tensorneko_util.util.eventbus.bus import EventHandler


Expand All @@ -29,7 +29,7 @@ def __init__(self, value: int):
self.value = value


@subscribe_thread
@subscribe.thread
class CustomThreadHandler(EventHandler):

def __init__(self):
Expand All @@ -46,7 +46,7 @@ def __init__(self, value: int):
self.value = value


@subscribe_async
@subscribe.coro
class CustomAsyncHandler(EventHandler):

def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.2
0.3.3

0 comments on commit dcb08f7

Please sign in to comment.