方法一
第一阶段的 Map:对每一个数据项 $(i,j,v)$,若来自矩阵 A,则输出 $(j,(A,i,v))$;若来自矩阵 B,则输出 $(i,(B,j,v))$。这样矩阵 A 的第 k 列和矩阵 B 的第 k 行(二者的 key 都是 k)会被同一个 reduce 节点处理。在 Reduce 端,将来自 A 和 B 的数据分别存储在数组 listA 和 listB 中:对来自 A 的数据 $(j,(A,i,v))$,令 listA[i] = v;对来自 B 的数据 $(i,(B,j,v))$,令 listB[j] = v。然后将 listA 中的每一项与 listB 中的每一项相乘并输出:对于 listA[i] 和 listB[j],输出 $((i,j),listA[i]*listB[j])$。在第二阶段只需要将第一阶段输出中 key 相同的数据求和即可。
MatrixMultiplication1.java
package com.lagou.mining.hdfs;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

/**
 * Two-stage MapReduce matrix multiplication.
 *
 * <p>Stage 1 groups the elements of matrix A by column index and the elements
 * of matrix B by row index, so that a reducer call for key {@code k} sees the
 * k-th column of A and the k-th row of B and emits every pairwise product
 * keyed by its target cell. Stage 2 sums the partial products per cell.
 *
 * @author orisun
 * @since 2015-6-26
 * @version 1.0
 */
public class MatrixMultiplication1 {

    // These two constants serve a double purpose: they are the configuration
    // keys under which the input file names are stored, and they are the tags
    // written into the stage-1 map output to mark which matrix a value is from.
    private static final String MATRIXFILE1 = "A";
    private static final String MATRIXFILE2 = "B";

    /**
     * Stage-1 mapper: turns one matrix row into per-element records.
     *
     * <p>The input key is used as the row index of the row held in the value
     * (presumably supplied by {@code MatrixInputFormat}, which is not part of
     * this file -- confirm there whether it is zero-based). Column indices are
     * zero-based positions within the whitespace-separated row.
     */
    public static class EleEmitMapper extends
            Mapper<IntWritable, Text, IntWritable, Text> {

        private String matrixFile1 = null;
        private String matrixFile2 = null;

        /** Reads the base names of the two input files from the job configuration. */
        @Override
        protected void setup(Context context) {
            matrixFile1 = context.getConfiguration().get(MATRIXFILE1);
            matrixFile2 = context.getConfiguration().get(MATRIXFILE2);
        }

        /**
         * Emits {@code (j, "A\ti\tv")} for every non-zero element of a row of A
         * and {@code (i, "B\tj\tv")} for every non-zero element of a row of B.
         *
         * @param key     row index {@code i} of the current row
         * @param value   the row's elements separated by whitespace
         * @param context used to read the split's file name and to emit records
         * @throws NumberFormatException if an element is not a valid double
         */
        @Override
        protected void map(IntWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            String filename = ((FileSplit) context.getInputSplit()).getPath()
                    .getName();
            boolean fromA = filename.equals(matrixFile1);
            if (!fromA && !filename.equals(matrixFile2)) {
                // Split belongs to neither configured matrix file.
                return;
            }

            // Trim first: split("\\s+") yields an empty leading token for a
            // row that starts with whitespace, and a single empty token for a
            // blank row; either would make Double.parseDouble throw.
            String row = value.toString().trim();
            if (row.isEmpty()) {
                return;
            }
            String[] arr = row.split("\\s+");

            for (int j = 0; j < arr.length; j++) {
                double v = Double.parseDouble(arr[j]);
                // Zero elements are not emitted; for sparse matrices this
                // cuts the shuffle volume considerably.
                if (v == 0) {
                    continue;
                }
                if (fromA) {
                    // Keyed by column so it meets the matching row of B.
                    context.write(new IntWritable(j), new Text(MATRIXFILE1
                            + "\t" + key.toString() + "\t" + v));
                } else {
                    // Keyed by row so it meets the matching column of A.
                    context.write(key, new Text(MATRIXFILE2 + "\t" + j + "\t"
                            + v));
                }
            }
        }

        /** Nothing to release. */
        @Override
        protected void cleanup(Context context) {
        }
    }

    /**
     * Stage-1 reducer: for one shared index, multiplies every received element
     * of A with every received element of B and emits the partial product
     * keyed by {@code "rowOfA\tcolumnOfB"}.
     */
    public static class MultiplicationReducer extends
            Reducer<IntWritable, Text, Text, DoubleWritable> {

        /**
         * @param key     the shared index (column of A / row of B)
         * @param value   records of the form {@code tag\tposition\tvalue}
         * @param context used to emit {@code ("i\tj", product)} records
         */
        @Override
        protected void reduce(IntWritable key, Iterable<Text> value,
                Context context) throws IOException, InterruptedException {
            Map<Integer, Double> listA = new HashMap<Integer, Double>();
            Map<Integer, Double> listB = new HashMap<Integer, Double>();
            for (Text record : value) {
                String[] arr = record.toString().split("\\s+");
                String matrixTag = arr[0];
                int pos = Integer.parseInt(arr[1]);
                double v = Double.parseDouble(arr[2]);
                if (MATRIXFILE1.equals(matrixTag)) {
                    listA.put(pos, v);
                } else if (MATRIXFILE2.equals(matrixTag)) {
                    listB.put(pos, v);
                }
            }

            // Cross product: |listA| * |listB| multiplications for this key.
            for (Entry<Integer, Double> entryA : listA.entrySet()) {
                int posA = entryA.getKey();
                double valA = entryA.getValue();
                for (Entry<Integer, Double> entryB : listB.entrySet()) {
                    int posB = entryB.getKey();
                    double valB = entryB.getValue();
                    context.write(new Text(posA + "\t" + posB),
                            new DoubleWritable(valA * valB));
                }
            }
        }
    }

    /**
     * Stage-2 mapper: parses a stage-1 output line {@code i j product} back
     * into {@code ("i\tj", product)}. Lines that do not split into exactly
     * three whitespace-separated fields are ignored.
     */
    public static class SumMapper extends
            Mapper<LongWritable, Text, Text, DoubleWritable> {

        @Override
        protected void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            String[] arr = value.toString().split("\\s+");
            if (arr.length == 3) {
                context.write(new Text(arr[0] + "\t" + arr[1]),
                        new DoubleWritable(Double.parseDouble(arr[2])));
            }
        }
    }

    /** Stage-2 combiner: pre-sums the partial products of one cell. */
    public static class SumCombiner extends
            Reducer<Text, DoubleWritable, Text, DoubleWritable> {

        @Override
        protected void reduce(Text key, Iterable<DoubleWritable> value,
                Context context) throws IOException, InterruptedException {
            context.write(key, new DoubleWritable(sum(value)));
        }
    }

    /** Stage-2 reducer: sums the partial products of one cell. */
    public static class SumReducer extends
            Reducer<Text, DoubleWritable, Text, DoubleWritable> {

        @Override
        protected void reduce(Text key, Iterable<DoubleWritable> value,
                Context context) throws IOException, InterruptedException {
            context.write(key, new DoubleWritable(sum(value)));
        }
    }

    /** Adds up all values; shared by {@link SumCombiner} and {@link SumReducer}. */
    private static double sum(Iterable<DoubleWritable> values) {
        double total = 0;
        for (DoubleWritable v : values) {
            total += v.get();
        }
        return total;
    }

    /**
     * Computes {@code matrix1 * matrix2 = product}.
     *
     * <ul>
     * <li>{@code args[0]} matrixFile1: input file, m rows by q columns.</li>
     * <li>{@code args[1]} matrixFile2: input file, q rows by n columns.</li>
     * <li>{@code args[2]} productFile: output path, m rows by n columns.</li>
     * </ul>
     * Columns are separated by whitespace. The two input files must have
     * different base names because the mapper tells them apart by name only.
     *
     * <p>Exits with status 0 only if both jobs succeed, 1 otherwise. The
     * intermediate directory {@code product_tmp} next to matrixFile1 is
     * removed whether or not the jobs succeed.
     */
    public static void main(String[] args) throws IOException,
            ClassNotFoundException, InterruptedException {
        if (args.length < 3) {
            System.err
                    .println("please input 3 cmd args: matrixFile1 matrixFile2 productFile");
            System.exit(1);
        }
        String matrixFile1 = args[0];
        String matrixFile2 = args[1];
        String productFile = args[2];

        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.get(conf);
        Path inFile1 = new Path(matrixFile1);
        Path inFile2 = new Path(matrixFile2);
        if (inFile1.getName().equals(inFile2.getName())) {
            // EleEmitMapper would tag every record as coming from A and the
            // product would silently come out empty.
            System.err.println("input matrix files must have different file names: "
                    + inFile1.getName());
            System.exit(1);
        }
        conf.set(MATRIXFILE1, inFile1.getName());
        conf.set(MATRIXFILE2, inFile2.getName());
        Path midFile = new Path(inFile1.getParent().toUri().getPath()
                + "/product_tmp");
        Path outFile = new Path(productFile);
        if (!fs.exists(inFile2) || !fs.exists(inFile1)) {
            System.err.println("input matrix file does not exists");
            System.exit(1);
        }
        if (fs.exists(midFile)) {
            fs.delete(midFile, true);
        }
        if (fs.exists(outFile)) {
            fs.delete(outFile, true);
        }

        boolean success = false;
        try {
            // Step 2 consumes step 1's output, so it must not run if step 1 failed.
            success = runProductJob(conf, inFile1, inFile2, midFile)
                    && runSumJob(conf, midFile, outFile);
        } finally {
            fs.delete(midFile, true);
        }
        // Kept outside the try block: System.exit would skip the finally clause.
        System.exit(success ? 0 : 1);
    }

    /** Runs stage 1 (element-wise partial products); returns whether the job succeeded. */
    private static boolean runProductJob(Configuration conf, Path inFile1,
            Path inFile2, Path midFile) throws IOException,
            ClassNotFoundException, InterruptedException {
        Job productionJob1 = Job.getInstance(conf);
        productionJob1.setJobName("MatrixMultiplication1_step1");
        productionJob1.setJarByClass(MatrixMultiplication1.class);

        FileInputFormat.addInputPath(productionJob1, inFile1);
        FileInputFormat.addInputPath(productionJob1, inFile2);
        productionJob1.setInputFormatClass(MatrixInputFormat.class);
        productionJob1.setMapperClass(EleEmitMapper.class);
        productionJob1.setMapOutputKeyClass(IntWritable.class);
        productionJob1.setMapOutputValueClass(Text.class);

        FileOutputFormat.setOutputPath(productionJob1, midFile);
        productionJob1.setReducerClass(MultiplicationReducer.class);
        productionJob1.setNumReduceTasks(12);
        productionJob1.setOutputKeyClass(Text.class);
        productionJob1.setOutputValueClass(DoubleWritable.class);

        return productionJob1.waitForCompletion(true);
    }

    /** Runs stage 2 (per-cell summation); returns whether the job succeeded. */
    private static boolean runSumJob(Configuration conf, Path midFile,
            Path outFile) throws IOException, ClassNotFoundException,
            InterruptedException {
        Job productionJob2 = Job.getInstance(conf);
        productionJob2.setJobName("MatrixMultiplication1_step2");
        productionJob2.setJarByClass(MatrixMultiplication1.class);

        FileInputFormat.setInputPaths(productionJob2, midFile);
        productionJob2.setMapperClass(SumMapper.class);
        productionJob2.setMapOutputKeyClass(Text.class);
        productionJob2.setMapOutputValueClass(DoubleWritable.class);

        FileOutputFormat.setOutputPath(productionJob2, outFile);
        productionJob2.setCombinerClass(SumCombiner.class);
        productionJob2.setReducerClass(SumReducer.class);
        productionJob2.setNumReduceTasks(1);
        productionJob2.setOutputKeyClass(Text.class);
        productionJob2.setOutputValueClass(DoubleWritable.class);

        return productionJob2.waitForCompletion(true);
    }
}