Skip to content

Commit 0841e5c

Browse files
committed
Use float32 for levels
1 parent 600e3b5 commit 0841e5c

File tree

2 files changed

+13
-16
lines changed

2 files changed

+13
-16
lines changed

auto_editor/analyze.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def obj_tag(tag: str, tb: Fraction, obj: dict[str, Any]) -> str:
8888
return key
8989

9090

91-
def iter_audio(src, tb: Fraction, stream: int = 0) -> Iterator[float]:
91+
def iter_audio(src, tb: Fraction, stream: int = 0) -> Iterator[np.float32]:
9292
fifo = AudioFifo()
9393
try:
9494
container = av.open(src.path, "r")
@@ -117,13 +117,13 @@ def iter_audio(src, tb: Fraction, stream: int = 0) -> Iterator[float]:
117117
audio_chunk = fifo.read(current_size)
118118
assert audio_chunk is not None
119119
arr = audio_chunk.to_ndarray().flatten()
120-
yield float(np.max(np.abs(arr)))
120+
yield np.max(np.abs(arr))
121121

122122
finally:
123123
container.close()
124124

125125

126-
def iter_motion(src, tb, stream: int, blur: int, width: int) -> Iterator[float]:
126+
def iter_motion(src, tb, stream: int, blur: int, width: int) -> Iterator[np.float32]:
127127
container = av.open(src.path, "r")
128128

129129
video = container.streams.video[stream]
@@ -155,11 +155,11 @@ def iter_motion(src, tb, stream: int, blur: int, width: int) -> Iterator[float]:
155155

156156
current_frame = frame.to_ndarray()
157157
if prev_frame is None:
158-
value = 0.0
158+
value = np.float32(0.0)
159159
else:
160160
# Use `int16` to avoid underflow with `uint8` datatype
161161
diff = np.abs(prev_frame.astype(np.int16) - current_frame.astype(np.int16))
162-
value = np.count_nonzero(diff) / total_pixels
162+
value = np.float32(np.count_nonzero(diff) / total_pixels)
163163

164164
for _ in range(index - prev_index):
165165
yield value
@@ -237,7 +237,7 @@ def cache(self, tag: str, obj: dict[str, Any], arr: np.ndarray) -> np.ndarray:
237237

238238
return arr
239239

240-
def audio(self, stream: int) -> NDArray[np.float64]:
240+
def audio(self, stream: int) -> NDArray[np.float32]:
241241
if stream >= len(self.src.audios):
242242
raise LevelError(f"audio: audio stream '{stream}' does not exist.")
243243

@@ -256,12 +256,12 @@ def audio(self, stream: int) -> NDArray[np.float64]:
256256
bar = self.bar
257257
bar.start(inaccurate_dur, "Analyzing audio volume")
258258

259-
result = np.zeros((inaccurate_dur), dtype=np.float64)
259+
result = np.zeros((inaccurate_dur), dtype=np.float32)
260260
index = 0
261261
for value in iter_audio(self.src, self.tb, stream):
262262
if index > len(result) - 1:
263263
result = np.concatenate(
264-
(result, np.zeros((len(result)), dtype=np.float64))
264+
(result, np.zeros((len(result)), dtype=np.float32))
265265
)
266266
result[index] = value
267267
bar.tick(index)
@@ -270,7 +270,7 @@ def audio(self, stream: int) -> NDArray[np.float64]:
270270
bar.end()
271271
return self.cache("audio", {"stream": stream}, result[:index])
272272

273-
def motion(self, stream: int, blur: int, width: int) -> NDArray[np.float64]:
273+
def motion(self, stream: int, blur: int, width: int) -> NDArray[np.float32]:
274274
if stream >= len(self.src.videos):
275275
raise LevelError(f"motion: video stream '{stream}' does not exist.")
276276

@@ -289,12 +289,12 @@ def motion(self, stream: int, blur: int, width: int) -> NDArray[np.float64]:
289289
bar = self.bar
290290
bar.start(inaccurate_dur, "Analyzing motion")
291291

292-
result = np.zeros((inaccurate_dur), dtype=np.float64)
292+
result = np.zeros((inaccurate_dur), dtype=np.float32)
293293
index = 0
294294
for value in iter_motion(self.src, self.tb, stream, blur, width):
295295
if index > len(result) - 1:
296296
result = np.concatenate(
297-
(result, np.zeros((len(result)), dtype=np.float64))
297+
(result, np.zeros((len(result)), dtype=np.float32))
298298
)
299299
result[index] = value
300300
bar.tick(index)

auto_editor/subcommands/levels.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,11 @@ def print_arr(arr: NDArray) -> None:
7272
print("")
7373

7474

75-
def print_arr_gen(arr: Iterator[int | float]) -> None:
75+
def print_arr_gen(arr: Iterator[float | np.float32]) -> None:
7676
print("")
7777
print("@start")
7878
for a in arr:
79-
if isinstance(a, float):
80-
print(f"{a:.20f}")
81-
else:
82-
print(a)
79+
print(f"{a:.20f}")
8380
print("")
8481

8582

0 commit comments

Comments
 (0)