Creating a price model using k-Nearest Neighbours + Genetic Algorithm
Chapter 8 of Programming Collective Intelligence (PCI) explains the usage and implementation of the k-Nearest Neighbours (k-NN) algorithm.
Simply put:
k-NN is a classification algorithm that looks at the k nearest neighbours of an item to determine which class it belongs to. In this chapter it is used for numeric prediction instead: the estimated price is the average of the neighbours' prices. To find those neighbours, the algorithm uses a distance/similarity score function — in this example, Euclidean distance.
PCI takes it a little further to improve accuracy in some scenarios. It uses a weighted average of the neighbours. It then uses either simulated annealing or a genetic algorithm to find the best scaling factor for each input variable, building on Optimization techniques – Simulated Annealing & Genetic Algorithms. (As with all the previous chapters, the code is in my GitHub repository.)
The distance score function looks like this (slightly different from the similarity score used earlier, which was inverted to return 1 for identical inputs):
package net.briandupreez.pci.chapter8; import java.util.HashMap; import java.util.List; import java.util.Map; public class EuclideanDistanceScore { /** * Determine distance between list of points. * * @param list1 first list * @param list2 second list * @return distance between the two lists between 0 and 1... 0 being identical. */ public static double distanceList(final List<Double> list1, final List<Double> list2) { if (list1.size() != list2.size()) { throw new RuntimeException("Same number of values required."); } double sumOfAllSquares = 0; for (int i = 0; i < list1.size(); i++) { sumOfAllSquares += Math.pow(list2.get(i) - list1.get(i), 2); } return Math.sqrt(sumOfAllSquares); } }
I updated the simulated annealing and genetic algorithm code, as I had originally implemented them using ints… (Lesson learnt: when doing anything to do with ML or AI, stick to doubles.)
package net.briandupreez.pci.chapter8; import org.javatuples.Pair; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Arrays; import java.util.Date; import java.util.List; import java.util.Map; import java.util.Random; import java.util.SortedMap; import java.util.TreeMap; /** * Created with IntelliJ IDEA. * User: bdupreez * Date: 2013/07/05 * Time: 9:08 PM */ public class Optimization { public List<Pair<Integer, Integer>> createDomain() { final List<Pair<Integer, Integer>> domain = new ArrayList<>(4); for (int i = 0; i < 4; i++) { final Pair<Integer, Integer> pair = new Pair<>(0, 10); domain.add(pair); } return domain; } /** * Simulated Annealing * * @param domain list of tuples with min and max * @return (global minimum) */ public Double[] simulatedAnnealing(final List<Pair<Integer, Integer>> domain, final double startingTemp, final double cool, final int step) { double temp = startingTemp; //create random Double[] sol = new Double[domain.size()]; Random random = new Random(); for (int r = 0; r < domain.size(); r++) { sol[r] = Double.valueOf(random.nextInt(19)); } while (temp > 0.1) { //pick a random indices int i = random.nextInt(domain.size() - 1); //pick a directions + or - int direction = random.nextInt(step) % 2 == 0 ? -(random.nextInt(step)) : random.nextInt(1); Double[] cloneSolr = sol.clone(); cloneSolr[i] += direction; if (cloneSolr[i] < domain.get(i).getValue0()) { cloneSolr[i] = Double.valueOf(domain.get(i).getValue0()); } else if (cloneSolr[i] > domain.get(i).getValue1()) { cloneSolr[i] = Double.valueOf(domain.get(i).getValue1()); } //calc current and new cost double currentCost = scheduleCost(sol); double newCost = scheduleCost(cloneSolr); System.out.println("Current: " + currentCost + " New: " + newCost); double probability = Math.pow(Math.E, -(newCost - currentCost) / temp); // Is it better, or does it make the probability cutoff? 
if (newCost < currentCost || Math.random() < probability) { sol = cloneSolr; } temp = temp * cool; } return sol; } public double scheduleCost(Double[] sol) { NumPredict numPredict = new NumPredict(); final List<Map<String,List<Double>>> rescale = numPredict.rescale(numPredict.createWineSet2(), Arrays.asList(sol)); return numPredict.crossValidate(rescale,0.1,100); } public Double[] geneticAlgorithm(final List<Pair<Integer, Integer>> domain, final int populationSize, final int step, final double elite, final int maxIter, final double mutProb) { List<Double[]> pop = createPopulation(domain.size(), populationSize); final int topElite = new Double(elite * populationSize).intValue(); final SortedMap<Double, Double[]> scores = new TreeMap<>(); for (int i = 0; i < maxIter; i++) { for (final Double[] run : pop) { scores.put(scheduleCost(run), run); } pop = determineElite(topElite, scores); while (pop.size() < populationSize) { final Random random = new Random(); if (Math.random() < mutProb) { final int ran = random.nextInt(topElite); pop.add(mutate(domain, pop.get(ran), step)); } else { final int ran1 = random.nextInt(topElite); final int ran2 = random.nextInt(topElite); pop.add(crossover(pop.get(ran1), pop.get(ran2), domain.size())); } } System.out.println(scores); } return scores.entrySet().iterator().next().getValue(); } /** * Grab the elites * * @param topElite how many * @param scores sorted on score * @return best ones */ private List<Double[]> determineElite(int topElite, SortedMap<Double, Double[]> scores) { Double toKey = null; int index = 0; for (final Double key : scores.keySet()) { if (index++ == topElite) { toKey = key; break; } } scores = scores.headMap(toKey); return new ArrayList<>(scores.values()); } /** * Create a population * * @param arraySize the array size * @param popSize the population size * @return a random population */ private List<Double[]> createPopulation(final int arraySize, final int popSize) { final List<Double[]> returnList = new 
ArrayList<>(); for (int i = 0; i < popSize; i++) { Double[] sol = new Double[arraySize]; Random random = new Random(); for (int r = 0; r < arraySize; r++) { sol[r] = Double.valueOf(random.nextInt(8)); } returnList.add(sol); } return returnList; } /** * Mutate a value. * * @param domain the domain * @param vec the data to be mutated * @param step the step * @return mutated array */ private Double[] mutate(final List<Pair<Integer, Integer>> domain, final Double[] vec, final int step) { final Random random = new Random(); int i = random.nextInt(domain.size() - 1); Double[] retArr = vec.clone(); if (Math.random() < 0.5 && (vec[1] - step) > domain.get(i).getValue0()) { retArr[i] -= step; } else if (vec[i] + step < domain.get(i).getValue1()) { retArr[i] += step; } return vec; } /** * Cross over parts of each array * * @param arr1 array 1 * @param arr2 array 2 * @param max max value * @return new array */ private Double[] crossover(final Double[] arr1, final Double[] arr2, final int max) { final Random random = new Random(); int i = random.nextInt(max); return concatArrays(Arrays.copyOfRange(arr1, 0, i), Arrays.copyOfRange(arr2, i, arr2.length)); } /** * Concat 2 arrays * * @param first first * @param second second * @return new combined array */ private Double[] concatArrays(final Double[] first, final Double[] second) { Double[] result = Arrays.copyOf(first, first.length + second.length); System.arraycopy(second, 0, result, first.length, second.length); return result; } }
Finally, putting it all together, here is my Java implementation of the PCI example:
package net.briandupreez.pci.chapter8; import org.javatuples.Pair; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Random; /** * NumPredict. * User: bdupreez * Date: 2013/08/12 * Time: 8:29 AM */ public class NumPredict { /** * Determine the wine price. * * @param rating rating * @param age age * @return the price */ public double winePrice(final double rating, final double age) { final double peakAge = rating - 50; //Calculate the price based on rating double price = rating / 2; if (age > peakAge) { //goes bad in 10 years price *= 5 - (age - peakAge) / 2; } else { //increases as it reaches its peak price *= 5 * ((age + 1)) / peakAge; } if (price < 0) { price = 0.0; } return price; } /** * Data Generator * * @return data */ @SuppressWarnings("unchecked") public List<Map<String, List<Double>>> createWineSet1() { final List<Map<String, List<Double>>> wineList = new ArrayList<>(); for (int i = 0; i < 300; i++) { double rating = Math.random() * 50 + 50; double age = Math.random() * 50; double price = winePrice(rating, age); price *= (Math.random() * 0.2 + 0.9); final Map<String, List<Double>> map = new HashMap<>(); final List<Double> input = new LinkedList<>(); input.add(rating); input.add(age); map.put("input", input); final List<Double> result = new ArrayList(); result.add(price); map.put("result", result); wineList.add(map); } return wineList; } /** * Data Generator * * @return data */ @SuppressWarnings("unchecked") public List<Map<String, List<Double>>> createWineSet2() { final List<Map<String, List<Double>>> wineList = new ArrayList<>(); for (int i = 0; i < 300; i++) { double rating = Math.random() * 50 + 50; double age = Math.random() * 50; final Random random = new Random(); double aisle = (double) random.nextInt(20); double[] sizes = new double[]{375.0, 750.0, 1500.0}; double bottleSize = sizes[random.nextInt(3)]; double price = 
winePrice(rating, age); price *= (bottleSize / 750); price *= (Math.random() * 0.2 + 0.9); final Map<String, List<Double>> map = new HashMap<>(); final List<Double> input = new LinkedList<>(); input.add(rating); input.add(age); input.add(aisle); input.add(bottleSize); map.put("input", input); final List<Double> result = new ArrayList(); result.add(price); map.put("result", result); wineList.add(map); } return wineList; } /** * Rescale * * @param data data * @param scale the scales * @return scaled data */ public List<Map<String, List<Double>>> rescale(final List<Map<String, List<Double>>> data, final List<Double> scale) { final List<Map<String, List<Double>>> scaledData = new ArrayList<>(); for (final Map<String, List<Double>> dataItem : data) { final List<Double> scaledList = new LinkedList<>(); for (int i = 0; i < scale.size(); i++) { scaledList.add(scale.get(i) * dataItem.get("input").get(i)); } dataItem.put("input", scaledList); scaledData.add(dataItem); } return scaledData; } /** * Determine all the distances from a list * * @param data all the data * @param vec1 one list * @return all the distances */ public List<Pair<Double, Integer>> determineDistances(final List<Map<String, List<Double>>> data, final List<Double> vec1) { final List<Pair<Double, Integer>> distances = new ArrayList<>(); int i = 1; for (final Map<String, List<Double>> map : data) { final List<Double> vec2 = map.get("input"); distances.add(new Pair(EuclideanDistanceScore.distanceList(vec1, vec2), i++)); } Collections.sort(distances); return distances; } /** * Use kNN to estimate a new price * * @param data all the data * @param vec1 new fields to price * @param k the amount of neighbours * @return the estimated price */ public double knnEstimate(final List<Map<String, List<Double>>> data, final List<Double> vec1, final int k) { final List<Pair<Double, Integer>> distances = determineDistances(data, vec1); double avg = 0.0; for (int i = 0; i <= k; i++) { int idx = distances.get(i).getValue1(); 
avg += data.get(idx - 1).get("result").get(0); } avg = avg / k; return avg; } /** * KNN using a weighted average of the neighbours * * @param data the dataset * @param vec1 the data to price * @param k number of neighbours * @return the weighted price */ public double weightedKnn(final List<Map<String, List<Double>>> data, final List<Double> vec1, final int k) { final List<Pair<Double, Integer>> distances = determineDistances(data, vec1); double avg = 0.0; double totalWeight = 0.0; for (int i = 0; i <= k; i++) { double dist = distances.get(i).getValue0(); int idx = distances.get(i).getValue1(); double weight = guassianWeight(dist, 5.0); avg += weight * data.get(idx - 1).get("result").get(0); totalWeight += weight; } if (totalWeight == 0.0) { return 0.0; } avg = avg / totalWeight; return avg; } /** * Gaussian Weight function, smoother weight curve that doesnt go to 0 * * @param distance the distance * @param sigma sigma * @return weighted value */ public double guassianWeight(final double distance, final double sigma) { double alteredDistance = -(Math.pow(distance, 2)); double sigmaSize = (2 * Math.pow(sigma, 2)); return Math.pow(Math.E, (alteredDistance / sigmaSize)); } /** * Split the data for cross validation. 
* * @param data the data to split * @param testPercent % of data to use for the tests * @return a tuple 0 - training, 1 - test */ @SuppressWarnings("unchecked") public Pair<List, List> divideData(final List<Map<String, List<Double>>> data, final double testPercent) { final List trainingList = new ArrayList(); final List testList = new ArrayList(); for (final Map<String, List<Double>> dataItem : data) { if (Math.random() < testPercent) { testList.add(dataItem); } else { trainingList.add(dataItem); } } return new Pair(trainingList, testList); } /** * Test result and squares the differences to make it more obvious * * @param trainingSet the training set * @param testSet the test set * @return the error */ @SuppressWarnings("unchecked") public double testAlgorithm(final List trainingSet, final List testSet) { double error = 0.0; final List<Map<String, List<Double>>> typedSet = (List<Map<String, List<Double>>>) testSet; for (final Map<String, List<Double>> testData : typedSet) { double guess = weightedKnn(trainingSet, testData.get("input"), 3); error += Math.pow((testData.get("result").get(0) - guess), 2); } return error / testSet.size(); } /** * This runs iterations of the test, and returns an averaged score * * @param data the data * @param testPercent % test * @param trials number of iterations * @return result */ public double crossValidate(final List<Map<String, List<Double>>> data, final double testPercent, final int trials) { double error = 0.0; for (int i = 0; i < trials; i++) { final Pair<List, List> trainingPair = divideData(data, testPercent); error += testAlgorithm(trainingPair.getValue0(), trainingPair.getValue1()); } return error / trials; } /** * Gives the probability that an item is in a price range between 0 and 1 * Adds up the neighbours weightd and divides it by the total * * @param data the data * @param vec1 the input * @param k the number of neighbours * @param low low amount of range * @param high the high amount * @return probability between 0 
and 1 */ public double probabilityGuess(final List<Map<String, List<Double>>> data, final List<Double> vec1, final int k, final double low, final double high) { final List<Pair<Double, Integer>> distances = determineDistances(data, vec1); double neighbourWeights = 0.0; double totalWeights = 0.0; for (int i = 0; i < k; i++) { double dist = distances.get(i).getValue0(); int index = distances.get(i).getValue1(); double weight = guassianWeight(dist, 5); final List<Double> result = data.get(index).get("result"); double v = result.get(0); //check if the point is in the range. if (v >= low && v <= high) { neighbourWeights += weight; } totalWeights += weight; } if (totalWeights == 0) { return 0; } return neighbourWeights / totalWeights; } }
While reading up some more on k-NN I also stumbled upon the following blog posts:
- The first describes some of the difficulties around using k-NN: k-Nearest Neighbors – dangerously simple
- The second gives a great overview of k-NN: A detailed introduction to k-NN algorithm