方法一

第一阶段的 Map:对每一数据项 $(i,j,v)$,若来自矩阵 A,则输出 $(j,(A,i,v))$;若来自矩阵 B,则输出 $(i,(B,j,v))$。这样矩阵 A 的第 k 列和矩阵 B 的第 k 行(下标相同)会被同一个 reduce 节点处理。在 Reduce 端,将来自 A 和 B 的数据分别存储在数组 listA 和 listB 中:对来自 A 的数据 $(j,(A,i,v))$,令 listA[i] = v;对来自 B 的数据 $(i,(B,j,v))$,令 listB[j] = v。然后将 listA 中的每一项乘以 listB 中的每一项并输出,即对于 listA[i] 和 listB[j],输出 $((i,j),listA[i]*listB[j])$。在第二阶段只需要将第一阶段输出中 key 相同的数据求和即可。

MatrixMultiplication1.java

  1 package com.lagou.mining.hdfs;
  2 
  3 import java.io.IOException;
  4 import java.util.HashMap;
  5 import java.util.Iterator;
  6 import java.util.Map;
  7 import java.util.Map.Entry;
  8 
  9 import org.apache.hadoop.conf.Configuration;
 10 import org.apache.hadoop.fs.FileSystem;
 11 import org.apache.hadoop.fs.Path;
 12 import org.apache.hadoop.io.DoubleWritable;
 13 import org.apache.hadoop.io.IntWritable;
 14 import org.apache.hadoop.io.LongWritable;
 15 import org.apache.hadoop.io.Text;
 16 import org.apache.hadoop.mapreduce.Job;
 17 import org.apache.hadoop.mapreduce.Mapper;
 18 import org.apache.hadoop.mapreduce.Reducer;
 19 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
 20 import org.apache.hadoop.mapreduce.lib.input.FileSplit;
 21 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
 22 
 23 
 24 /**
 25  * 矩阵相乘。
 26  * 
 27  * @Author:orisun
 28  * @Since:2015-6-26
 29  * @Version:1.0
 30  */
 31 public class MatrixMultiplication1 {
 32 
 33     private static final String MATRIXFILE1 = "A";
 34     private static final String MATRIXFILE2 = "B";
 35 
 36     public static class EleEmitMapper extends
 37             Mapper<IntWritable, Text, IntWritable, Text> {
 38 
 39         private String matrixFile1 = null;
 40         private String matrixFile2 = null;
 41 
 42         @Override
 43         protected void setup(Context context) {
 44             matrixFile1 = context.getConfiguration().get(MATRIXFILE1);
 45             matrixFile2 = context.getConfiguration().get(MATRIXFILE2);
 46         }
 47 
 48         @Override
 49         protected void map(IntWritable key, Text value, Context context)
 50                 throws IOException, InterruptedException {
 51             IntWritable i = key;
 52             String filename = ((FileSplit) context.getInputSplit()).getPath()
 53                     .getName();
 54             String[] arr = value.toString().split("\\s+");
 55             // 发出去N个互不相同的key
 56             if (filename.equals(matrixFile1)) {
 57                 for (int j = 0; j < arr.length; j++) {
 58                     double v = Double.parseDouble(arr[j]);
 59                     //元素为0时不发出,对于稀疏矩阵这样效率会比较高
 60                     if (v != 0) {
 61                         context.write(new IntWritable(j), new Text(MATRIXFILE1
 62                                 + "\t" + i.toString() + "\t" + v));
 63                     }
 64                 }
 65             } else if (filename.equals(matrixFile2)) {
 66                 for (int j = 0; j < arr.length; j++) {
 67                     double v = Double.parseDouble(arr[j]);
 68                     //元素为0时不发出,对于稀疏矩阵这样效率会比较高
 69                     if (v != 0) {
 70                         context.write(i, new Text(MATRIXFILE2 + "\t" + j + "\t"
 71                                 + v));
 72                     }
 73                 }
 74             }
 75         }
 76 
 77         @Override
 78         protected void cleanup(Context context) {
 79         }
 80     }
 81 
 82     public static class MultiplicationReducer extends
 83             Reducer<IntWritable, Text, Text, DoubleWritable> {
 84 
 85         @Override
 86         protected void reduce(IntWritable key, Iterable<Text> value,
 87                 Context context) throws IOException, InterruptedException {
 88             Map<Integer, Double> listA = new HashMap<Integer, Double>();
 89             Map<Integer, Double> listB = new HashMap<Integer, Double>();
 90             Iterator<Text> itr = value.iterator();
 91             while (itr.hasNext()) {
 92                 String[] arr = itr.next().toString().split("\\s+");
 93                 String matrixTag = arr[0];
 94                 int pos = Integer.parseInt(arr[1]);
 95                 double v = Double.parseDouble(arr[2]);
 96                 if (MATRIXFILE1.equals(matrixTag)) {
 97                     listA.put(pos, v);
 98                 } else if (MATRIXFILE2.equals(matrixTag)) {
 99                     listB.put(pos, v);
100                 }
101             }
102             // 在此需要进行N*N次的乘法
103             for (Entry<Integer, Double> entryA : listA.entrySet()) {
104                 int posA = entryA.getKey();
105                 double valA = entryA.getValue();
106                 for (Entry<Integer, Double> entryB : listB.entrySet()) {
107                     int posB = entryB.getKey();
108                     double valB = entryB.getValue();
109                     double production = valA * valB;
110                     context.write(new Text(posA + "\t" + posB),
111                             new DoubleWritable(production));
112                 }
113             }
114         }
115     }
116 
117     public static class SumMapper extends
118             Mapper<LongWritable, Text, Text, DoubleWritable> {
119 
120         @Override
121         protected void map(LongWritable key, Text value, Context context)
122                 throws IOException, InterruptedException {
123             String[] arr = value.toString().split("\\s+");
124             if (arr.length == 3) {
125                 context.write(new Text(arr[0] + "\t" + arr[1]),
126                         new DoubleWritable(Double.parseDouble(arr[2])));
127             }
128         }
129     }
130 
131     public static class SumCombiner extends
132             Reducer<Text, DoubleWritable, Text, DoubleWritable> {
133 
134         @Override
135         protected void reduce(Text key, Iterable<DoubleWritable> value,
136                 Context context) throws IOException, InterruptedException {
137             double sum = 0;
138             Iterator<DoubleWritable> itr = value.iterator();
139             while (itr.hasNext()) {
140                 sum += itr.next().get();
141             }
142             context.write(key, new DoubleWritable(sum));
143         }
144     }
145 
146     public static class SumReducer extends
147             Reducer<Text, DoubleWritable, Text, DoubleWritable> {
148 
149         @Override
150         protected void reduce(Text key, Iterable<DoubleWritable> value,
151                 Context context) throws IOException, InterruptedException {
152             double sum = 0;
153             Iterator<DoubleWritable> itr = value.iterator();
154             while (itr.hasNext()) {
155                 sum += itr.next().get();
156             }
157             context.write(key, new DoubleWritable(sum));
158         }
159     }
160 
161     /**
162      * matrix1 * matrix2 = product<br>
163      * matrixFile1:输入文件,m行q列。<br>
164      * matrixFile2:输入文件,q行n列。<br>
165      * productFile:输出文件,m行n列。<br>
166      * 各列用空白符分隔。
167      */
168     public static void main(String[] args) throws IOException,
169             ClassNotFoundException, InterruptedException {
170         if (args.length < 3) {
171             System.err
172                     .println("please input 3 cmd args: matrixFile1 matrixFile2 productFile");
173             System.exit(1);
174         }
175         String matrixFile1 = args[0];
176         String matrixFile2 = args[1];
177         String productFile = args[2];
178 
179         Configuration conf = new Configuration();
180         FileSystem fs = FileSystem.get(conf);
181         Path inFile1 = new Path(matrixFile1);
182         Path inFile2 = new Path(matrixFile2);
183         conf.set(MATRIXFILE1, inFile1.getName());
184         conf.set(MATRIXFILE2, inFile2.getName());
185         Path midFile = new Path(inFile1.getParent().toUri().getPath()
186                 + "/product_tmp");
187         Path outFile = new Path(productFile);
188         if (!fs.exists(inFile2) || !fs.exists(inFile1)) {
189             System.err.println("input matrix file does not exists");
190             System.exit(1);
191         }
192         if (fs.exists(midFile)) {
193             fs.delete(midFile, true);
194         }
195         if (fs.exists(outFile)) {
196             fs.delete(outFile, true);
197         }
198 
199         {
200             Job productionJob1 = Job.getInstance(conf);
201             productionJob1.setJobName("MatrixMultiplication1_step1");
202             productionJob1.setJarByClass(MatrixMultiplication1.class);
203 
204             FileInputFormat.addInputPath(productionJob1, inFile1);
205             FileInputFormat.addInputPath(productionJob1, inFile2);
206             productionJob1.setInputFormatClass(MatrixInputFormat.class);
207             productionJob1.setMapperClass(EleEmitMapper.class);
208             productionJob1.setMapOutputKeyClass(IntWritable.class);
209             productionJob1.setMapOutputValueClass(Text.class);
210 
211             FileOutputFormat.setOutputPath(productionJob1, midFile);
212             productionJob1.setReducerClass(MultiplicationReducer.class);
213             productionJob1.setNumReduceTasks(12);
214             productionJob1.setOutputKeyClass(Text.class);
215             productionJob1.setOutputValueClass(DoubleWritable.class);
216 
217             productionJob1.waitForCompletion(true);
218         }
219 
220         {
221             Job productionJob2 = Job.getInstance(conf);
222             productionJob2.setJobName("MatrixMultiplication1_step2");
223             productionJob2.setJarByClass(MatrixMultiplication1.class);
224 
225             FileInputFormat.setInputPaths(productionJob2, midFile);
226             productionJob2.setMapperClass(SumMapper.class);
227             productionJob2.setMapOutputKeyClass(Text.class);
228             productionJob2.setMapOutputValueClass(DoubleWritable.class);
229 
230             FileOutputFormat.setOutputPath(productionJob2, outFile);
231             productionJob2.setCombinerClass(SumCombiner.class);
232             productionJob2.setReducerClass(SumReducer.class);
233             productionJob2.setNumReduceTasks(1);
234             productionJob2.setOutputKeyClass(Text.class);
235             productionJob2.setOutputValueClass(DoubleWritable.class);
236 
237             productionJob2.waitForCompletion(true);
238             
239         }
240         fs.delete(midFile, true);
241         System.exit(0);
242     }
243 }
View Code

相关文章: