Skip to content

Commit f3fe8cb

Browse files
committed
WIP: Loser tree
1 parent dbe42b4 commit f3fe8cb

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed

src/tree.jl

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
module t
2+
3+
# TODO: The tree needs to remove Nothing's to be efficient.
4+
# * Do not pad to power of two.
5+
# * Every time an iterator is exhausted, re-build the tree.
6+
7+
struct LoserTree{T, I, S}
8+
# The loser tree. Index 1 is winner, rest is the loser tree.
9+
# Except node 1, node x' parents is 2x - 1 and 2x.
10+
# Padded with `nothing` to be a power-of-two long
11+
12+
# TODO: The vector lengths never changes, and it would be nice if they were
13+
# constant propagated. So, use Memory{T} when it's supported by Julia LTS
14+
tree::Vector{Union{Tuple{Int, T}, Nothing}}
15+
iters::Vector{I}
16+
states::Vector{S}
17+
end
18+
19+
function nextpow_2(x)
20+
B = 8 * sizeof(x)
21+
s = B - leading_zeros(x - 1)
22+
return one(x) << (s & (B - 1))
23+
end
24+
25+
function LoserTree(it)
26+
iters = collect(it)
27+
I = eltype(iters)
28+
T = eltype(I)
29+
tree = fill!(Vector{Union{Tuple{Int, T}, Nothing}}(undef, nextpow_2(length(iters))), nothing)
30+
states = init_tree!(tree, iters)
31+
return LoserTree{T, I, eltype(states)}(tree, iters, states)
32+
end
33+
34+
function init_tree!(tree::Vector{Union{Nothing, Tuple{Int, T}}}, iters::Vector)::Vector where {T}
35+
isempty(iters) && return Vector{Union{}}
36+
(val, states) = _init_tree(tree, 2, nothing, iters)
37+
tree[1] = val
38+
return if states === nothing
39+
Vector{Union{}}(undef, length(iters))
40+
else
41+
states
42+
end
43+
end
44+
45+
function _init_tree(tree, i, states, iters)::Tuple{Union{Nothing, Tuple{Int, Any}}, Union{Nothing, Vector}}
46+
if i > length(tree)
47+
iter_i = i - length(tree)
48+
return if iter_i > length(iters)
49+
(nothing, states)
50+
else
51+
iter = iters[iter_i]
52+
itval = iterate(iter)
53+
if itval === nothing
54+
(nothing, states)
55+
else
56+
(val, state) = itval
57+
if states === nothing
58+
states = Vector{typeof(state)}(undef, length(iters))
59+
end
60+
states[iter_i] = state
61+
((iter_i, val), states)
62+
end
63+
end
64+
else
65+
(x1, states) = _init_tree(tree, 2i - 1, states, iters)
66+
(x2, states) = _init_tree(tree, 2i, states, iters)
67+
(winner, loser) = get_winner_loser(x1, x2)
68+
tree[i] = loser
69+
return (winner, states)
70+
end
71+
end
72+
73+
function get_winner_loser(a::Union{Nothing, Tuple{Int, Any}}, b::Union{Nothing, Tuple{Int, Any}})
74+
return if a === nothing
75+
(b, a)
76+
elseif b === nothing
77+
(a, b)
78+
elseif last(a) < last(b)
79+
(a, b)
80+
else
81+
(b, a)
82+
end
83+
end
84+
85+
@inline function bubble_up!(tree::Vector{Union{Nothing, Tuple{Int, T}}}, winner::Union{Nothing, Tuple{Int, T}}, pi::Int) where {T}
86+
n_iters = trailing_zeros(length(tree))
87+
i = 0
88+
@inbounds while i < n_iters
89+
parent = tree[pi]
90+
(winner, loser) = get_winner_loser(winner, parent)
91+
tree[pi] = loser
92+
pi = (pi + 1) >>> 1
93+
i += 1
94+
end
95+
@inbounds tree[1] = winner
96+
return nothing
97+
end
98+
99+
Base.IteratorSize(::Type{<:LoserTree}) = Base.SizeUnknown()
100+
Base.eltype(::Type{<:LoserTree{T}}) where {T} = Tuple{Int, T}
101+
102+
function Base.iterate(x::LoserTree, _state::Nothing = nothing)
103+
result = @inbounds x.tree[1]
104+
result === nothing && return result
105+
(i, element) = result
106+
state = @inbounds x.states[i]
107+
itval = iterate(@inbounds(x.iters[i]), state)
108+
new_val = if itval === nothing
109+
nothing
110+
else
111+
(new_element, new_state) = itval
112+
@inbounds x.states[i] = new_state
113+
(i, new_element)
114+
end
115+
pi = (i + nextpow_2(length(x.iters)) + 1) >>> 1
116+
bubble_up!(x.tree, new_val, pi)
117+
return ((i, element), nothing)
118+
end
119+
120+
end # module

0 commit comments

Comments
 (0)