Skip to content

kNN #419

@tpoisot

Description

@tpoisot

This isn't optimized at all, but it Just Works™ — I will make a proper PR when I'm not on public Wi-Fi.

"""
    kNN <: Classifier

k-nearest-neighbours classifier.

# Fields
- `k::Int`: number of nearest training observations consulted per prediction (default `3`).
- `X::Ref`: wraps the training feature matrix, one observation per column; holds the
  placeholder `Ref(0)` until `train!` stores real data.
- `y::Ref`: wraps the training labels; holds the placeholder `Ref(0)` until `train!`.
- `d::Function`: distance between two observations; defaults to the Euclidean distance.
"""
Base.@kwdef mutable struct kNN <: Classifier
    # neighbourhood size
    k::Int = 3
    # training data, stored by reference; `Ref(0)` marks a model that was never trained
    X::Ref = Ref(0)
    y::Ref = Ref(0)
    # Euclidean distance: square root of the summed squared coordinate differences
    d::Function = (u, v) -> sqrt(sum((u .- v) .^ 2.0))
end

# Return `1.0 - eps(1.0)`, a value just below 1.
# NOTE(review): presumably SDeMo reads `zero(<classifier type>)` as the default decision
# threshold for the classifier's scores; confirm against SDeMo.
function Base.zero(::Type{kNN})
    unit = one(Float64)
    return unit - eps(unit)
end

"""
    train!(knn::kNN, y::Vector{Bool}, X::Matrix{T}; kwargs...) where {T <: Number}

Store the training data in `knn` and return `knn`. No model is fitted: the observations
(columns of `X`) and their labels `y` are kept and only consulted at prediction time.
Keyword arguments are accepted and ignored.

Throws a `DimensionMismatch` when `length(y)` differs from the number of columns of `X`.
"""
function SDeMo.train!(knn::kNN, y::Vector{Bool}, X::Matrix{T}; kwargs...) where {T <: Number}
    # one label per observation, and observations are the columns of X
    length(y) == size(X, 2) || throw(DimensionMismatch(
        "expected one label per column of X ($(size(X, 2))), got $(length(y)) labels"))
    # the arrays are wrapped, not copied: mutating X or y afterwards also changes the model
    knn.X = Ref(X)
    knn.y = Ref(y)
    return knn
end

"""
    predict(knn::kNN, x::Vector{T}) where {T <: Number}

Return the share of the `knn.k` training observations closest to `x` (under `knn.d`)
whose label is `true`, a value between 0 and 1.

Throws an `ArgumentError` if `knn` has not been trained, or if `knn.k` is not between 1
and the number of training observations.
"""
function StatsAPI.predict(knn::kNN, x::Vector{T}) where {T <: Number}
    Xtrain = knn.X.x
    ytrain = knn.y.x
    # an untrained model still holds the `Ref(0)` placeholder instead of a matrix
    Xtrain isa AbstractMatrix ||
        throw(ArgumentError("the kNN model has not been trained; call train! first"))
    n = size(Xtrain, 2)
    k = knn.k
    # k > n would index past the end; k < 1 would average an empty vector (NaN)
    1 <= k <= n || throw(ArgumentError(
        "k must be between 1 and the number of training observations ($n), got $k"))
    # distance to every training observation (one per column); columns are copied so
    # user-supplied distance functions keep receiving plain `Vector`s
    dist = [knn.d(x, Xtrain[:, i]) for i in axes(Xtrain, 2)]
    # only the k smallest distances need ordering, not the whole vector
    nearest = partialsortperm(dist, 1:k)
    return Statistics.mean(ytrain[nearest])
end

"""
    predict(knn::kNN, X::Matrix{T}) where {T <: Number}

Score every observation (column) of `X` with the single-observation `predict` method and
return the scores as a vector, in column order.
"""
function StatsAPI.predict(knn::kNN, X::Matrix{T}) where {T <: Number}
    # Call the qualified `StatsAPI.predict` that the single-observation method extends, so
    # this does not rely on an unqualified `predict` being imported. `X[:, i]` (a copy,
    # not a view) is needed because that method only accepts `Vector` arguments.
    return [StatsAPI.predict(knn, X[:, i]) for i in axes(X, 2)]
end

Metadata

Metadata

Assignees

No one assigned

    Labels

    🏖️ low effort — Issue that can be tackled without too much problem; 🧠 SDeMo — Changes primarily affecting or taking place in SDeMo

    Projects

    Status

    Needs assessment

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions