Creating a price model using k-Nearest Neighbours + Genetic Algorithm
Chapter 8 of Programming Collective Intelligence (PCI) explains the usage and implementation of the k-Nearest Neighbours (k-NN) algorithm.
Simply put:
k-NN is a classification algorithm that uses the (k) nearest neighbours of an item to determine what class the item will belong to. To decide which neighbours are nearest, the algorithm uses a distance / similarity score function; in this example, Euclidean distance.
PCI takes it a little further to help with accuracy in some scenarios. This includes the usage of a weighted average of the neighbours, as well as then using either simulated annealing or genetic algorithms to determine the best weights, building on Optimization techniques – Simulated Annealing & Genetic Algorithms (As with all the previous chapters the code is in my github repository).
So the similarity score function looked like this (slightly different to the one used earlier, which was inverted to return 1 if the lists were equal):
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * Euclidean distance between two equally sized lists of numbers.
 */
public class EuclideanDistanceScore {

    /**
     * Determine the Euclidean distance between two lists of points.
     *
     * @param list1 first list
     * @param list2 second list
     * @return the distance between the two lists: 0 when they are identical,
     *         growing without an upper bound as they differ more
     * @throws IllegalArgumentException if the lists have different sizes
     */
    public static double distanceList(final List<Double> list1, final List<Double> list2) {
        if (list1.size() != list2.size()) {
            throw new IllegalArgumentException("Same number of values required.");
        }

        // Walk both lists with iterators: callers pass LinkedLists, for which
        // indexed get(i) would make this loop quadratic.
        final Iterator<Double> first = list1.iterator();
        final Iterator<Double> second = list2.iterator();
        double sumOfAllSquares = 0;
        while (first.hasNext() && second.hasNext()) {
            final double difference = second.next() - first.next();
            sumOfAllSquares += difference * difference;
        }
        return Math.sqrt(sumOfAllSquares);
    }
}
I updated the simulated annealing and genetic algorithm code, as I had originally implemented them using ints… (lesson learnt: when doing anything to do with ML or AI, stick to doubles).
001 002 003 004 005 006 007 008 009 010 011 012 013 014 015 016 017 018 019 020 021 022 023 024 025 026 027 028 029 030 031 032 033 034 035 036 037 038 039 040 041 042 043 044 045 046 047 048 049 050 051 052 053 054 055 056 057 058 059 060 061 062 063 064 065 066 067 068 069 070 071 072 073 074 075 076 077 078 079 080 081 082 083 084 085 086 087 088 089 090 091 092 093 094 095 096 097 098 099 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 | package net.briandupreez.pci.chapter8; import org.javatuples.Pair; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Arrays; import java.util.Date; import java.util.List; import java.util.Map; import java.util.Random; import java.util.SortedMap; import java.util.TreeMap; /** * Created with IntelliJ IDEA. 
* User: bdupreez * Date: 2013/07/05 * Time: 9:08 PM */ public class Optimization { public List<Pair<Integer, Integer>> createDomain() { final List<Pair<Integer, Integer>> domain = new ArrayList<>( 4 ); for ( int i = 0 ; i < 4 ; i++) { final Pair<Integer, Integer> pair = new Pair<>( 0 , 10 ); domain.add(pair); } return domain; } /** * Simulated Annealing * * @param domain list of tuples with min and max * @return (global minimum) */ public Double[] simulatedAnnealing( final List<Pair<Integer, Integer>> domain, final double startingTemp, final double cool, final int step) { double temp = startingTemp; //create random Double[] sol = new Double[domain.size()]; Random random = new Random(); for ( int r = 0 ; r < domain.size(); r++) { sol[r] = Double.valueOf(random.nextInt( 19 )); } while (temp > 0.1 ) { //pick a random indices int i = random.nextInt(domain.size() - 1 ); //pick a directions + or - int direction = random.nextInt(step) % 2 == 0 ? -(random.nextInt(step)) : random.nextInt( 1 ); Double[] cloneSolr = sol.clone(); cloneSolr[i] += direction; if (cloneSolr[i] < domain.get(i).getValue0()) { cloneSolr[i] = Double.valueOf(domain.get(i).getValue0()); } else if (cloneSolr[i] > domain.get(i).getValue1()) { cloneSolr[i] = Double.valueOf(domain.get(i).getValue1()); } //calc current and new cost double currentCost = scheduleCost(sol); double newCost = scheduleCost(cloneSolr); System.out.println( "Current: " + currentCost + " New: " + newCost); double probability = Math.pow(Math.E, -(newCost - currentCost) / temp); // Is it better, or does it make the probability cutoff? 
if (newCost < currentCost || Math.random() < probability) { sol = cloneSolr; } temp = temp * cool; } return sol; } public double scheduleCost(Double[] sol) { NumPredict numPredict = new NumPredict(); final List<Map<String,List<Double>>> rescale = numPredict.rescale(numPredict.createWineSet2(), Arrays.asList(sol)); return numPredict.crossValidate(rescale, 0.1 , 100 ); } public Double[] geneticAlgorithm( final List<Pair<Integer, Integer>> domain, final int populationSize, final int step, final double elite, final int maxIter, final double mutProb) { List<Double[]> pop = createPopulation(domain.size(), populationSize); final int topElite = new Double(elite * populationSize).intValue(); final SortedMap<Double, Double[]> scores = new TreeMap<>(); for ( int i = 0 ; i < maxIter; i++) { for ( final Double[] run : pop) { scores.put(scheduleCost(run), run); } pop = determineElite(topElite, scores); while (pop.size() < populationSize) { final Random random = new Random(); if (Math.random() < mutProb) { final int ran = random.nextInt(topElite); pop.add(mutate(domain, pop.get(ran), step)); } else { final int ran1 = random.nextInt(topElite); final int ran2 = random.nextInt(topElite); pop.add(crossover(pop.get(ran1), pop.get(ran2), domain.size())); } } System.out.println(scores); } return scores.entrySet().iterator().next().getValue(); } /** * Grab the elites * * @param topElite how many * @param scores sorted on score * @return best ones */ private List<Double[]> determineElite( int topElite, SortedMap<Double, Double[]> scores) { Double toKey = null ; int index = 0 ; for ( final Double key : scores.keySet()) { if (index++ == topElite) { toKey = key; break ; } } scores = scores.headMap(toKey); return new ArrayList<>(scores.values()); } /** * Create a population * * @param arraySize the array size * @param popSize the population size * @return a random population */ private List<Double[]> createPopulation( final int arraySize, final int popSize) { final List<Double[]> returnList = 
new ArrayList<>(); for ( int i = 0 ; i < popSize; i++) { Double[] sol = new Double[arraySize]; Random random = new Random(); for ( int r = 0 ; r < arraySize; r++) { sol[r] = Double.valueOf(random.nextInt( 8 )); } returnList.add(sol); } return returnList; } /** * Mutate a value. * * @param domain the domain * @param vec the data to be mutated * @param step the step * @return mutated array */ private Double[] mutate( final List<Pair<Integer, Integer>> domain, final Double[] vec, final int step) { final Random random = new Random(); int i = random.nextInt(domain.size() - 1 ); Double[] retArr = vec.clone(); if (Math.random() < 0.5 && (vec[ 1 ] - step) > domain.get(i).getValue0()) { retArr[i] -= step; } else if (vec[i] + step < domain.get(i).getValue1()) { retArr[i] += step; } return vec; } /** * Cross over parts of each array * * @param arr1 array 1 * @param arr2 array 2 * @param max max value * @return new array */ private Double[] crossover( final Double[] arr1, final Double[] arr2, final int max) { final Random random = new Random(); int i = random.nextInt(max); return concatArrays(Arrays.copyOfRange(arr1, 0 , i), Arrays.copyOfRange(arr2, i, arr2.length)); } /** * Concat 2 arrays * * @param first first * @param second second * @return new combined array */ private Double[] concatArrays( final Double[] first, final Double[] second) { Double[] result = Arrays.copyOf(first, first.length + second.length); System.arraycopy(second, 0 , result, first.length, second.length); return result; } } |
Then finally, putting it all together, my Java implementation of the PCI example:
001 002 003 004 005 006 007 008 009 010 011 012 013 014 015 016 017 018 019 020 021 022 023 024 025 026 027 028 029 030 031 032 033 034 035 036 037 038 039 040 041 042 043 044 045 046 047 048 049 050 051 052 053 054 055 056 057 058 059 060 061 062 063 064 065 066 067 068 069 070 071 072 073 074 075 076 077 078 079 080 081 082 083 084 085 086 087 088 089 090 091 092 093 094 095 096 097 098 099 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 | package net.briandupreez.pci.chapter8; import org.javatuples.Pair; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Random; /** * NumPredict. * User: bdupreez * Date: 2013/08/12 * Time: 8:29 AM */ public class NumPredict { /** * Determine the wine price. 
* * @param rating rating * @param age age * @return the price */ public double winePrice( final double rating, final double age) { final double peakAge = rating - 50 ; //Calculate the price based on rating double price = rating / 2 ; if (age > peakAge) { //goes bad in 10 years price *= 5 - (age - peakAge) / 2 ; } else { //increases as it reaches its peak price *= 5 * ((age + 1 )) / peakAge; } if (price < 0 ) { price = 0.0 ; } return price; } /** * Data Generator * * @return data */ @SuppressWarnings ( "unchecked" ) public List<Map<String, List<Double>>> createWineSet1() { final List<Map<String, List<Double>>> wineList = new ArrayList<>(); for ( int i = 0 ; i < 300 ; i++) { double rating = Math.random() * 50 + 50 ; double age = Math.random() * 50 ; double price = winePrice(rating, age); price *= (Math.random() * 0.2 + 0.9 ); final Map<String, List<Double>> map = new HashMap<>(); final List<Double> input = new LinkedList<>(); input.add(rating); input.add(age); map.put( "input" , input); final List<Double> result = new ArrayList(); result.add(price); map.put( "result" , result); wineList.add(map); } return wineList; } /** * Data Generator * * @return data */ @SuppressWarnings ( "unchecked" ) public List<Map<String, List<Double>>> createWineSet2() { final List<Map<String, List<Double>>> wineList = new ArrayList<>(); for ( int i = 0 ; i < 300 ; i++) { double rating = Math.random() * 50 + 50 ; double age = Math.random() * 50 ; final Random random = new Random(); double aisle = ( double ) random.nextInt( 20 ); double [] sizes = new double []{ 375.0 , 750.0 , 1500.0 }; double bottleSize = sizes[random.nextInt( 3 )]; double price = winePrice(rating, age); price *= (bottleSize / 750 ); price *= (Math.random() * 0.2 + 0.9 ); final Map<String, List<Double>> map = new HashMap<>(); final List<Double> input = new LinkedList<>(); input.add(rating); input.add(age); input.add(aisle); input.add(bottleSize); map.put( "input" , input); final List<Double> result = new ArrayList(); 
result.add(price); map.put( "result" , result); wineList.add(map); } return wineList; } /** * Rescale * * @param data data * @param scale the scales * @return scaled data */ public List<Map<String, List<Double>>> rescale( final List<Map<String, List<Double>>> data, final List<Double> scale) { final List<Map<String, List<Double>>> scaledData = new ArrayList<>(); for ( final Map<String, List<Double>> dataItem : data) { final List<Double> scaledList = new LinkedList<>(); for ( int i = 0 ; i < scale.size(); i++) { scaledList.add(scale.get(i) * dataItem.get( "input" ).get(i)); } dataItem.put( "input" , scaledList); scaledData.add(dataItem); } return scaledData; } /** * Determine all the distances from a list * * @param data all the data * @param vec1 one list * @return all the distances */ public List<Pair<Double, Integer>> determineDistances( final List<Map<String, List<Double>>> data, final List<Double> vec1) { final List<Pair<Double, Integer>> distances = new ArrayList<>(); int i = 1 ; for ( final Map<String, List<Double>> map : data) { final List<Double> vec2 = map.get( "input" ); distances.add( new Pair(EuclideanDistanceScore.distanceList(vec1, vec2), i++)); } Collections.sort(distances); return distances; } /** * Use kNN to estimate a new price * * @param data all the data * @param vec1 new fields to price * @param k the amount of neighbours * @return the estimated price */ public double knnEstimate( final List<Map<String, List<Double>>> data, final List<Double> vec1, final int k) { final List<Pair<Double, Integer>> distances = determineDistances(data, vec1); double avg = 0.0 ; for ( int i = 0 ; i <= k; i++) { int idx = distances.get(i).getValue1(); avg += data.get(idx - 1 ).get( "result" ).get( 0 ); } avg = avg / k; return avg; } /** * KNN using a weighted average of the neighbours * * @param data the dataset * @param vec1 the data to price * @param k number of neighbours * @return the weighted price */ public double weightedKnn( final List<Map<String, 
List<Double>>> data, final List<Double> vec1, final int k) { final List<Pair<Double, Integer>> distances = determineDistances(data, vec1); double avg = 0.0 ; double totalWeight = 0.0 ; for ( int i = 0 ; i <= k; i++) { double dist = distances.get(i).getValue0(); int idx = distances.get(i).getValue1(); double weight = guassianWeight(dist, 5.0 ); avg += weight * data.get(idx - 1 ).get( "result" ).get( 0 ); totalWeight += weight; } if (totalWeight == 0.0 ) { return 0.0 ; } avg = avg / totalWeight; return avg; } /** * Gaussian Weight function, smoother weight curve that doesnt go to 0 * * @param distance the distance * @param sigma sigma * @return weighted value */ public double guassianWeight( final double distance, final double sigma) { double alteredDistance = -(Math.pow(distance, 2 )); double sigmaSize = ( 2 * Math.pow(sigma, 2 )); return Math.pow(Math.E, (alteredDistance / sigmaSize)); } /** * Split the data for cross validation. * * @param data the data to split * @param testPercent % of data to use for the tests * @return a tuple 0 - training, 1 - test */ @SuppressWarnings ( "unchecked" ) public Pair<List, List> divideData( final List<Map<String, List<Double>>> data, final double testPercent) { final List trainingList = new ArrayList(); final List testList = new ArrayList(); for ( final Map<String, List<Double>> dataItem : data) { if (Math.random() < testPercent) { testList.add(dataItem); } else { trainingList.add(dataItem); } } return new Pair(trainingList, testList); } /** * Test result and squares the differences to make it more obvious * * @param trainingSet the training set * @param testSet the test set * @return the error */ @SuppressWarnings ( "unchecked" ) public double testAlgorithm( final List trainingSet, final List testSet) { double error = 0.0 ; final List<Map<String, List<Double>>> typedSet = (List<Map<String, List<Double>>>) testSet; for ( final Map<String, List<Double>> testData : typedSet) { double guess = weightedKnn(trainingSet, testData.get( 
"input" ), 3 ); error += Math.pow((testData.get( "result" ).get( 0 ) - guess), 2 ); } return error / testSet.size(); } /** * This runs iterations of the test, and returns an averaged score * * @param data the data * @param testPercent % test * @param trials number of iterations * @return result */ public double crossValidate( final List<Map<String, List<Double>>> data, final double testPercent, final int trials) { double error = 0.0 ; for ( int i = 0 ; i < trials; i++) { final Pair<List, List> trainingPair = divideData(data, testPercent); error += testAlgorithm(trainingPair.getValue0(), trainingPair.getValue1()); } return error / trials; } /** * Gives the probability that an item is in a price range between 0 and 1 * Adds up the neighbours weightd and divides it by the total * * @param data the data * @param vec1 the input * @param k the number of neighbours * @param low low amount of range * @param high the high amount * @return probability between 0 and 1 */ public double probabilityGuess( final List<Map<String, List<Double>>> data, final List<Double> vec1, final int k, final double low, final double high) { final List<Pair<Double, Integer>> distances = determineDistances(data, vec1); double neighbourWeights = 0.0 ; double totalWeights = 0.0 ; for ( int i = 0 ; i < k; i++) { double dist = distances.get(i).getValue0(); int index = distances.get(i).getValue1(); double weight = guassianWeight(dist, 5 ); final List<Double> result = data.get(index).get( "result" ); double v = result.get( 0 ); //check if the point is in the range. if (v >= low && v <= high) { neighbourWeights += weight; } totalWeights += weight; } if (totalWeights == 0 ) { return 0 ; } return neighbourWeights / totalWeights; } } |
While reading up some more on k-NN I also stumbled upon the following blog posts:
- The first one describes some of the difficulties around using k-NN: k-Nearest Neighbors – dangerously simple
- And then one giving a great overview of k-NN: A detailed introduction to k-NN algorithm