Now, we will implement the k-NN algorithm in Python to find Mary's temperature preference. At the end of this section, we will also implement the visualization of the data produced in the previous section, that is, Mary and her temperature preferences. The full, runnable code, with the input files, can be found in the source code provided with this book. The most important parts have been extracted and presented here:
# source_code/1/mary_and_temperature_preferences/knn_to_data.py
# Applies the knn algorithm to the input data.
# The input text file is assumed to be of the format with one line per
# every data entry consisting of the temperature in degrees Celsius,
# wind speed and then the classification cold/warm.
import sys
sys.path.append('..')
sys.path.append('../../common')
import knn # noqa
import common # noqa
# Program start
# Positional command-line arguments:
#   1: input file, e.g. "mary_and_temperature_preferences.data"
#   2: output file, e.g. "mary_and_temperature_preferences_completed.data"
#   3: k, the number of nearest neighbors to consult
#   4-7: x_from, x_to, y_from, y_to - the integer rectangle to classify
input_file, output_file = sys.argv[1], sys.argv[2]
k, x_from, x_to, y_from, y_to = (int(sys.argv[pos]) for pos in range(3, 8))

# Load the known points, classify every point of the rectangle and
# store the completed data set.
data = common.load_3row_data_to_dic(input_file)
new_data = knn.knn_to_2d_data(data, x_from, x_to, y_from, y_to, k)
common.save_3row_data_from_dic(output_file, new_data)
# source_code/common/common.py
# ***Library with common routines and functions***
def dic_inc(dic, key):
    """Increment the counter stored under *key* in *dic* by one.

    A key that is missing, or whose stored value is None, starts at 1.
    The dictionary is modified in place and nothing is returned.

    None is counted like any other key. An earlier
    ``if key is None: pass`` guard had no effect and was removed, so
    callers that want to ignore None must filter it out themselves.
    """
    current = dic.get(key)
    dic[key] = 1 if current is None else current + 1
# source_code/1/knn.py
# ***Library implementing knn algorithm***
def info_reset(info):
    """Clear the neighborhood statistics kept in *info*, in place.

    After the call ``info['nbhd_count']`` is 0 and
    ``info['class_count']`` is a new empty dictionary; any other
    entries of *info* are left untouched.
    """
    info.update(nbhd_count=0, class_count={})
# Find the class of a neighbor with the coordinates x,y.
# If the class is known count that neighbor.
def info_add(info, data, x, y):
    """Record the neighbor at the coordinates (x, y) in *info*.

    The class of the neighbor is looked up in *data* under the key
    ``(x, y)``; a missing entry yields None. The looked-up value is
    always passed to ``common.dic_inc`` for ``info['class_count']``,
    while ``info['nbhd_count']`` grows only when the class is known.
    """
    neighbor_class = data.get((x, y))
    common.dic_inc(info['class_count'], neighbor_class)
    info['nbhd_count'] += 0 if neighbor_class is None else 1
# Apply knn algorithm to the 2d data using the k-nearest neighbors with
# the Manhattan distance.
# The dictionary data comes in the form with keys being 2d coordinates
# and the values being the class.
# x,y are integer coordinates for the 2d data with the range
# [x_from,x_to] x [y_from,y_to].
def knn_to_2d_data(data, x_from, x_to, y_from, y_to, k):
    """Classify every integer point of a rectangle with k-NN.

    For each point of [x_from, x_to] x [y_from, y_to] the neighbors are
    visited in rings of growing Manhattan distance until at least k
    neighbors with a known class have been counted. A whole ring is
    always counted, so more than k neighbors may take part in the vote.

    Args:
        data: dictionary mapping (x, y) integer coordinates to a class.
        x_from, x_to: inclusive range of the x coordinates.
        y_from, y_to: inclusive range of the y coordinates.
        k: number of known neighbors required before voting.

    Returns:
        A dictionary mapping every (x, y) of the rectangle to the class
        with the highest count among the counted neighbors. On a tie the
        class counted first wins. The value is None when no known
        neighbor lies within the searched distance.
    """
    new_data = {}
    info = {}
    # The largest Manhattan distance between two points of the rectangle
    # is the distance between opposite corners. The search has to be
    # allowed to reach it (hence the "+ 1" in the range below);
    # otherwise a known point in the opposite corner is never found and
    # a 1x1 rectangle is not searched at all.
    max_dist = (x_to - x_from) + (y_to - y_from)
    # Go through every point in an integer coordinate system.
    for y in range(y_from, y_to + 1):
        for x in range(x_from, x_to + 1):
            info_reset(info)
            # Count the number of neighbors for each class group for
            # every distance dist starting at 0 until at least k
            # neighbors with known classes are found.
            for dist in range(max_dist + 1):
                # Count all neighbors that are distanced dist from
                # the point [x,y].
                if dist == 0:
                    info_add(info, data, x, y)
                else:
                    # Two opposite edges of the diamond, including all
                    # four of its corners.
                    for i in range(dist + 1):
                        info_add(info, data, x - i, y + dist - i)
                        info_add(info, data, x + dist - i, y - i)
                    # The remaining two edges without their corners, so
                    # that no neighbor is counted twice.
                    for i in range(1, dist):
                        info_add(info, data, x + i, y + dist - i)
                        info_add(info, data, x - dist + i, y - i)
                # There could be more than k-closest neighbors if the
                # distance of more of them is the same from the point
                # [x,y]. But immediately when we have at least k of
                # them, we break from the loop.
                if info['nbhd_count'] >= k:
                    break
            class_max_count = None
            # Choose the class with the highest count of the neighbors
            # from among the k-closest neighbors. Neighbors with an
            # unknown class (None) do not take part in the vote.
            for group, count in info['class_count'].items():
                if group is not None and (
                        class_max_count is None
                        or count > info['class_count'][class_max_count]):
                    class_max_count = group
            new_data[x, y] = class_max_count
    return new_data
Input:
The preceding program will use the following file as the source of the input data. The file contains the table with the known data about Mary's temperature preferences:
# source_code/1/mary_and_temperature_preferences/
mary_and_temperature_preferences.data
10 0 cold
25 0 warm
15 5 cold
20 3 warm
18 7 cold
20 10 cold
22 5 warm
24 6 warm
Output:
We run the preceding implementation on the mary_and_temperature_preferences.data input file by using the k-NN algorithm for k=1 neighbors. The algorithm classifies all the points with integer coordinates in the rectangle with a size of (30-5=25) by (10-0=10), hence, with a size of (25+1) * (10+1) = 286 integer points (adding one to count points on boundaries). Using the wc command, we find out that the output file contains exactly 286 lines—one data item per point. Using the head command, we display the first 10 lines from the output file:
$ python knn_to_data.py mary_and_temperature_preferences.data mary_and_temperature_preferences_completed.data 1 5 30 0 10
$ wc -l mary_and_temperature_preferences_completed.data
286 mary_and_temperature_preferences_completed.data
$ head -10 mary_and_temperature_preferences_completed.data
7 3 cold
6 9 cold
12 1 cold
16 6 cold
16 9 cold
14 4 cold
13 4 cold
19 4 warm
18 4 cold
15 1 cold
Visualization:
For the visualization depicted earlier in this chapter, the matplotlib library was used. A data file is loaded, and then displayed in a scatter diagram:
# source_code/common/common.py
# returns a dictionary of 3 lists: 1st with x coordinates,
# 2nd with y coordinates, 3rd with colors with numeric values
def get_x_y_colors(data):
    """Split rows of (x, y, color) into three parallel lists.

    Args:
        data: iterable of rows; items 0, 1 and 2 of every row are read
            as the x coordinate, the y coordinate and the color.

    Returns:
        A dictionary with the keys 'x', 'y' and 'colors', each holding
        a list with one entry per row, in the order of the rows.
    """
    xs, ys, colors = [], [], []
    # A single pass keeps the function usable with any iterable, not
    # only with sequences that support len() and indexing.
    for row in data:
        xs.append(row[0])
        ys.append(row[1])
        colors.append(row[2])
    return {'x': xs, 'y': ys, 'colors': colors}
# source_code/1/mary_and_temperature_preferences/
# mary_and_temperature_preferences_draw_graph.py
import sys
sys.path.append('../../common') # noqa
import common
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib
matplotlib.style.use('ggplot')
data_file_name = 'mary_and_temperature_preferences_completed.data'
temp_from = 5
temp_to = 30
wind_from = 0
wind_to = 10
# Color used to draw each class; any other class is drawn in gray.
# The color is written back into the 4-character 'perception' field,
# so every color name used here must be at most 4 characters long.
class_colors = {'cold': 'blue', 'warm': 'red'}
# Passing the file name lets numpy open and close the file itself.
# The class is read as text ('U4'): a bytes field ('S4') never compares
# equal to 'cold'/'warm' on Python 3, which would turn every point gray.
# ndmin=1 keeps a data file with a single line a 1-element array.
data = np.loadtxt(data_file_name,
                  dtype={'names': ('temperature', 'wind', 'perception'),
                         'formats': ('i4', 'i4', 'U4')},
                  ndmin=1)
# Convert the classes to the colors to be displayed in a diagram.
perceptions = data['perception']
for i, perception in enumerate(perceptions):
    perceptions[i] = class_colors.get(str(perception), 'gray')
# Convert the array into the format ready for drawing functions.
data_processed = common.get_x_y_colors(data)
# Draw the graph.
plt.title('Mary and temperature preferences')
plt.xlabel('temperature in C')
plt.ylabel('wind speed in kmph')
plt.axis([temp_from, temp_to, wind_from, wind_to])
# Add legends to the graph.
blue_patch = mpatches.Patch(color='blue', label='cold')
red_patch = mpatches.Patch(color='red', label='warm')
plt.legend(handles=[blue_patch, red_patch])
plt.scatter(data_processed['x'], data_processed['y'],
            c=data_processed['colors'], s=[1400] * len(data))
plt.show()