Skip to content

Commit

Permalink
Replace self.assertEqual with assert statement
Browse files Browse the repository at this point in the history
Works better with mypy. `assert`s work equally well
since we're using pytest now.
  • Loading branch information
WyattBlue committed Sep 23, 2024
1 parent 68d6a99 commit 985f9bf
Show file tree
Hide file tree
Showing 11 changed files with 205 additions and 221 deletions.
25 changes: 12 additions & 13 deletions tests/test_audiofifo.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,17 @@ def test_pts_simple(self) -> None:
assert oframe.pts == 0
assert oframe.time_base == iframe.time_base

self.assertEqual(fifo.samples_written, 1024)
self.assertEqual(fifo.samples_read, 512)
self.assertEqual(fifo.pts_per_sample, 1.0)
assert fifo.samples_written == 1024
assert fifo.samples_read == 512
assert fifo.pts_per_sample == 1.0

iframe.pts = 1024
fifo.write(iframe)
oframe = fifo.read(512)
assert oframe is not None

self.assertEqual(oframe.pts, 512)
self.assertEqual(oframe.time_base, iframe.time_base)
assert oframe.pts == 512
assert oframe.time_base == iframe.time_base

iframe.pts = 9999 # Wrong!
self.assertRaises(ValueError, fifo.write, iframe)
Expand All @@ -88,8 +88,8 @@ def test_pts_complex(self) -> None:

oframe = fifo.read_many(1024)[-1]

self.assertEqual(oframe.pts, 2048)
self.assertEqual(fifo.pts_per_sample, 2.0)
assert oframe.pts == 2048
assert fifo.pts_per_sample == 2.0

def test_missing_sample_rate(self) -> None:
fifo = av.AudioFifo()
Expand All @@ -103,9 +103,9 @@ def test_missing_sample_rate(self) -> None:
oframe = fifo.read(512)

assert oframe is not None
self.assertIsNone(oframe.pts)
self.assertEqual(oframe.sample_rate, 0)
self.assertEqual(oframe.time_base, iframe.time_base)
assert oframe.pts is None
assert oframe.sample_rate == 0
assert oframe.time_base == iframe.time_base

def test_missing_time_base(self) -> None:
fifo = av.AudioFifo()
Expand All @@ -119,6 +119,5 @@ def test_missing_time_base(self) -> None:
oframe = fifo.read(512)

assert oframe is not None
self.assertIsNone(oframe.pts)
self.assertIsNone(oframe.time_base)
self.assertEqual(oframe.sample_rate, iframe.sample_rate)
assert oframe.pts is None and oframe.time_base is None
assert oframe.sample_rate == iframe.sample_rate
26 changes: 13 additions & 13 deletions tests/test_bitstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def test_filter_chomp(self) -> None:
ctx = BitStreamFilterContext("chomp")

src_packets: tuple[Packet, None] = (Packet(b"\x0012345\0\0\0"), None)
self.assertEqual(bytes(src_packets[0]), b"\x0012345\0\0\0")
assert bytes(src_packets[0]) == b"\x0012345\0\0\0"

result_packets = []
for p in src_packets:
result_packets.extend(ctx.filter(p))

self.assertEqual(len(result_packets), 1)
self.assertEqual(bytes(result_packets[0]), b"\x0012345")
assert len(result_packets) == 1
assert bytes(result_packets[0]) == b"\x0012345"

def test_filter_setts(self) -> None:
ctx = BitStreamFilterContext("setts=pts=N")
Expand All @@ -48,9 +48,9 @@ def test_filter_setts(self) -> None:
for p in src_packets:
result_packets.extend(ctx.filter(p))

self.assertEqual(len(result_packets), 2)
self.assertEqual(result_packets[0].pts, 0)
self.assertEqual(result_packets[1].pts, 1)
assert len(result_packets) == 2
assert result_packets[0].pts == 0
assert result_packets[1].pts == 1

def test_filter_h264_mp4toannexb(self) -> None:
with av.open(fate_suite("h264/interlaced_crop.mp4"), "r") as container:
Expand All @@ -62,22 +62,22 @@ def test_filter_h264_mp4toannexb(self) -> None:
self.assertFalse(is_annexb(p))
res_packets.extend(ctx.filter(p))

self.assertEqual(len(res_packets), stream.frames)
assert len(res_packets) == stream.frames

for p in res_packets:
self.assertTrue(is_annexb(p))
assert is_annexb(p)

def test_filter_output_parameters(self) -> None:
with av.open(fate_suite("h264/interlaced_crop.mp4"), "r") as container:
stream = container.streams.video[0]

self.assertFalse(is_annexb(stream.codec_context.extradata))
assert not is_annexb(stream.codec_context.extradata)
ctx = BitStreamFilterContext("h264_mp4toannexb", stream)
self.assertFalse(is_annexb(stream.codec_context.extradata))
assert not is_annexb(stream.codec_context.extradata)
del ctx

_ = BitStreamFilterContext("h264_mp4toannexb", stream, out_stream=stream)
self.assertTrue(is_annexb(stream.codec_context.extradata))
assert is_annexb(stream.codec_context.extradata)

def test_filter_flush(self) -> None:
with av.open(fate_suite("h264/interlaced_crop.mp4"), "r") as container:
Expand All @@ -87,7 +87,7 @@ def test_filter_flush(self) -> None:
res_packets = []
for p in container.demux(stream):
res_packets.extend(ctx.filter(p))
self.assertEqual(len(res_packets), stream.frames)
assert len(res_packets) == stream.frames

container.seek(0)
# Without flushing, we expect to get an error: "A non-NULL packet sent after an EOF."
Expand All @@ -100,4 +100,4 @@ def test_filter_flush(self) -> None:
for p in container.demux(stream):
res_packets.extend(ctx.filter(p))

self.assertEqual(len(res_packets), stream.frames * 2)
assert len(res_packets) == stream.frames * 2
62 changes: 31 additions & 31 deletions tests/test_codec_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,52 +38,52 @@ def iter_raw_frames(path, packet_sizes, ctx):
class TestCodecContext(TestCase):
def test_skip_frame_default(self):
ctx = Codec("png", "w").create()
self.assertEqual(ctx.skip_frame.name, "DEFAULT")
assert ctx.skip_frame.name == "DEFAULT"

def test_codec_delay(self):
with av.open(fate_suite("mkv/codec_delay_opus.mkv")) as container:
self.assertEqual(container.streams.audio[0].codec_context.delay, 312)
assert container.streams.audio[0].codec_context.delay == 312
with av.open(fate_suite("h264/interlaced_crop.mp4")) as container:
self.assertEqual(container.streams.video[0].codec_context.delay, 0)
assert container.streams.video[0].codec_context.delay == 0

def test_codec_tag(self):
ctx = Codec("mpeg4", "w").create()
self.assertEqual(ctx.codec_tag, "\x00\x00\x00\x00")
assert ctx.codec_tag == "\x00\x00\x00\x00"
ctx.codec_tag = "xvid"
self.assertEqual(ctx.codec_tag, "xvid")
assert ctx.codec_tag == "xvid"

# wrong length
with self.assertRaises(ValueError) as cm:
ctx.codec_tag = "bob"
self.assertEqual(str(cm.exception), "Codec tag should be a 4 character string.")
assert str(cm.exception) == "Codec tag should be a 4 character string."

# wrong type
with self.assertRaises(ValueError) as cm:
ctx.codec_tag = 123
self.assertEqual(str(cm.exception), "Codec tag should be a 4 character string.")
assert str(cm.exception) == "Codec tag should be a 4 character string."

with av.open(fate_suite("h264/interlaced_crop.mp4")) as container:
self.assertEqual(container.streams[0].codec_tag, "avc1")
assert container.streams[0].codec_tag == "avc1"

def test_decoder_extradata(self):
ctx = av.codec.Codec("h264", "r").create()
self.assertEqual(ctx.extradata, None)
self.assertEqual(ctx.extradata_size, 0)
assert ctx.extradata is None
assert ctx.extradata_size == 0

ctx.extradata = b"123"
self.assertEqual(ctx.extradata, b"123")
self.assertEqual(ctx.extradata_size, 3)
assert ctx.extradata == b"123"
assert ctx.extradata_size == 3

ctx.extradata = b"54321"
self.assertEqual(ctx.extradata, b"54321")
self.assertEqual(ctx.extradata_size, 5)
assert ctx.extradata == b"54321"
assert ctx.extradata_size == 5

ctx.extradata = None
self.assertEqual(ctx.extradata, None)
self.assertEqual(ctx.extradata_size, 0)
assert ctx.extradata is None
assert ctx.extradata_size == 0

def test_decoder_gop_size(self):
ctx = av.codec.Codec("h264", "r").create()
def test_decoder_gop_size(self) -> None:
ctx = av.codec.Codec("h264", "r").create("video")

with self.assertRaises(RuntimeError):
ctx.gop_size
Expand All @@ -99,7 +99,7 @@ def test_decoder_timebase(self) -> None:

def test_encoder_extradata(self):
ctx = av.codec.Codec("h264", "w").create()
self.assertEqual(ctx.extradata, None)
assert ctx.extradata is None
self.assertEqual(ctx.extradata_size, 0)

ctx.extradata = b"123"
Expand Down Expand Up @@ -170,7 +170,7 @@ def _assert_parse(self, codec_name, path):

parsed_source = b"".join(bytes(p) for p in packets)
self.assertEqual(len(parsed_source), len(full_source))
self.assertEqual(full_source, parsed_source)
assert full_source == parsed_source


class TestEncoding(TestCase):
Expand Down Expand Up @@ -214,7 +214,7 @@ def image_sequence_encode(self, codec_name: str) -> None:
new_frame = frame.reformat(width, height, pix_fmt)
new_packets = ctx.encode(new_frame)

self.assertEqual(len(new_packets), 1)
assert len(new_packets) == 1
new_packet = new_packets[0]

path = self.sandboxed(
Expand All @@ -240,9 +240,9 @@ def image_sequence_encode(self, codec_name: str) -> None:
packet = Packet(size)
size = f.readinto(packet)
frame = ctx.decode(packet)[0]
self.assertEqual(frame.width, width)
self.assertEqual(frame.height, height)
self.assertEqual(frame.format.name, pix_fmt)
assert frame.width == width
assert frame.height == height
assert frame.format.name == pix_fmt

def test_encoding_h264(self):
self.video_encoding("h264", {"crf": "19"})
Expand Down Expand Up @@ -333,13 +333,13 @@ def video_encoding(self, codec_name, options={}, codec_tag=None):
decoded_frame_count = 0
for frame in iter_raw_frames(path, packet_sizes, ctx):
decoded_frame_count += 1
self.assertEqual(frame.width, width)
self.assertEqual(frame.height, height)
self.assertEqual(frame.format.name, pix_fmt)
assert frame.width == width
assert frame.height == height
assert frame.format.name == pix_fmt
if frame.key_frame:
keyframe_indices.append(decoded_frame_count)

self.assertEqual(frame_count, decoded_frame_count)
assert frame_count == decoded_frame_count

self.assertIsInstance(
all(keyframe_index for keyframe_index in keyframe_indices), int
Expand All @@ -352,7 +352,7 @@ def video_encoding(self, codec_name, options={}, codec_tag=None):
):
raise SkipTest()
for i in decoded_gop_sizes:
self.assertEqual(i, gop_size)
assert i == gop_size

final_gop_size = decoded_frame_count - max(keyframe_indices)
self.assertLessEqual(final_gop_size, gop_size)
Expand Down Expand Up @@ -433,5 +433,5 @@ def _audio_encoding(

for frame in iter_raw_frames(path, packet_sizes, ctx):
result_samples += frame.samples
self.assertEqual(frame.sample_rate, sample_rate)
self.assertEqual(frame.layout.nb_channels, 2)
assert frame.sample_rate == sample_rate
assert frame.layout.nb_channels == 2
36 changes: 17 additions & 19 deletions tests/test_colorspace.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,38 @@
import av
from av.video.reformatter import ColorRange, Colorspace

from .common import TestCase, fate_suite
from .common import fate_suite


class TestColorSpace(TestCase):
class TestColorSpace:
def test_penguin_joke(self) -> None:
container = av.open(
fate_suite("amv/MTV_high_res_320x240_sample_Penguin_Joke_MTV_from_WMV.amv")
)
stream = container.streams.video[0]

self.assertEqual(stream.codec_context.color_range, 2)
self.assertEqual(stream.codec_context.color_range, ColorRange.JPEG)
assert stream.codec_context.color_range == 2
assert stream.codec_context.color_range == ColorRange.JPEG

self.assertEqual(stream.codec_context.color_primaries, 2)
self.assertEqual(stream.codec_context.color_trc, 2)
assert stream.codec_context.color_primaries == 2
assert stream.codec_context.color_trc == 2

self.assertEqual(stream.codec_context.colorspace, 5)
self.assertEqual(stream.codec_context.colorspace, Colorspace.ITU601)
assert stream.codec_context.colorspace == 5
assert stream.codec_context.colorspace == Colorspace.ITU601

for packet in container.demux(stream):
for frame in packet.decode():
assert isinstance(frame, av.VideoFrame)
self.assertEqual(frame.color_range, ColorRange.JPEG) # a.k.a "pc"
self.assertEqual(frame.colorspace, Colorspace.ITU601)
return
for frame in container.decode(stream):
assert frame.color_range == ColorRange.JPEG # a.k.a "pc"
assert frame.colorspace == Colorspace.ITU601
return

def test_sky_timelapse(self) -> None:
container = av.open(
av.datasets.curated("pexels/time-lapse-video-of-night-sky-857195.mp4")
)
stream = container.streams.video[0]

self.assertEqual(stream.codec_context.color_range, 1)
self.assertEqual(stream.codec_context.color_range, ColorRange.MPEG)
self.assertEqual(stream.codec_context.color_primaries, 1)
self.assertEqual(stream.codec_context.color_trc, 1)
self.assertEqual(stream.codec_context.colorspace, 1)
assert stream.codec_context.color_range == 1
assert stream.codec_context.color_range == ColorRange.MPEG
assert stream.codec_context.color_primaries == 1
assert stream.codec_context.color_trc == 1
assert stream.codec_context.colorspace == 1
Loading

0 comments on commit 985f9bf

Please sign in to comment.