This repository has been archived by the owner on May 4, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathknn_decision_boundary.m
executable file
·95 lines (81 loc) · 2.37 KB
/
knn_decision_boundary.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
function [train_accu, test_data_labels, train_data_labels] = knn_decision_boundary(train_data, train_label, new_data, K)
%KNN_DECISION_BOUNDARY  k-nearest-neighbour classification on standardised features.
%
%   [train_accu, test_data_labels, train_data_labels] = ...
%       KNN_DECISION_BOUNDARY(train_data, train_label, new_data, K)
%
%   Inputs
%     train_data  - training samples, one per row.
%     train_label - numeric label for each training row (numel must equal
%                   the number of rows of train_data).
%     new_data    - query samples to classify, one per row, with the same
%                   columns as train_data.
%     K           - number of nearest neighbours that vote. Every sample
%                   whose distance equals the K-th smallest distance also
%                   votes, so more than K labels may be used. K is capped at
%                   the number of available neighbours.
%
%   Outputs
%     train_accu        - fraction of training rows whose leave-one-out
%                         prediction equals train_label.
%     test_data_labels  - 1-by-size(new_data,1) predicted labels.
%     train_data_labels - 1-by-numel(train_label) leave-one-out predictions:
%                         each training row is classified by all the others.
%
%   Each column is standardised with the training mean and standard
%   deviation (columns with zero deviation are left unscaled). Distances are
%   Euclidean in the standardised space. The vote is MODE of the neighbour
%   labels, which resolves a tied vote toward the smallest label.

    train_data_size = numel(train_label);
    new_data_size = size(new_data, 1);

    % Standardise with training statistics only. DIM=1 is explicit so a
    % single-row train_data is not averaged across its columns.
    mean_train_data = mean(train_data, 1);
    std_train_data = std(train_data, 0, 1);
    std_train_data(std_train_data == 0) = 1;   % avoid 0/0 -> NaN for constant columns
    standardise_train_data = bsxfun(@rdivide, bsxfun(@minus, train_data, mean_train_data), std_train_data);
    standardise_new_data = bsxfun(@rdivide, bsxfun(@minus, new_data, mean_train_data), std_train_data);

    % Classify the query points against every training row.
    test_data_labels = zeros(1, new_data_size);
    all_train = 1:train_data_size;
    for i = 1:new_data_size
        distances = row_distances(standardise_train_data, standardise_new_data(i, :));
        test_data_labels(i) = knn_vote(distances, all_train, train_label, K);
    end

    % Leave-one-out predictions on the training set. OTHERS holds the true
    % training indices, so neighbour positions map back to the right labels.
    train_data_labels = zeros(1, train_data_size);
    for i = 1:train_data_size
        others = [1:i-1, i+1:train_data_size];
        distances = row_distances(standardise_train_data(others, :), standardise_train_data(i, :));
        train_data_labels(i) = knn_vote(distances, others, train_label, K);
    end

    % There is one prediction per training row, so divide by the row count.
    train_accu = mean(train_data_labels(:) == train_label(:));
end

function d = row_distances(points, sample)
%ROW_DISTANCES  Euclidean distance from SAMPLE (one row) to each row of POINTS, as a column vector.
    d = sqrt(sum(bsxfun(@minus, points, sample) .^ 2, 2));
end

function label = knn_vote(distances, candidates, labels, K)
%KNN_VOTE  MODE of the labels of the K closest candidates, extended over distance ties.
%   distances(p) is the distance to training row candidates(p); labels is
%   indexed by training row. Returns NaN when there are no candidates.
    n = numel(distances);
    if n == 0
        label = NaN;
        return;
    end
    [sorted_d, order] = sort(distances(:), 'ascend');
    k = min(K, n);
    % Include every candidate tied with the K-th smallest distance.
    while k >= 1 && k < n && sorted_d(k + 1) == sorted_d(k)
        k = k + 1;
    end
    label = mode(labels(candidates(order(1:k))));
end