go
复制代码
// worker.go
package mr
import (
"fmt"
"log"
"net/rpc"
"hash/fnv"
"strconv"
"strings"
"os"
"io/ioutil"
"sort"
"bufio"
)
// KeyValue is a single key/value pair. Map functions return a slice of them.
type KeyValue struct {
	Key   string
	Value string
}

// ByKey implements sort.Interface so that a []KeyValue can be ordered by Key
// (ascending) with sort.Sort.
type ByKey []KeyValue

// Len returns the number of pairs.
func (kvs ByKey) Len() int {
	return len(kvs)
}

// Swap exchanges the pairs at positions i and j.
func (kvs ByKey) Swap(i, j int) {
	kvs[j], kvs[i] = kvs[i], kvs[j]
}

// Less reports whether the key at position i sorts before the key at j.
func (kvs ByKey) Less(i, j int) bool {
	left, right := kvs[i].Key, kvs[j].Key
	return left < right
}
// ihash maps key to a non-negative int using 32-bit FNV-1a.
// Use ihash(key) % NReduce to choose the reduce task number for each
// KeyValue emitted by Map.
func ihash(key string) int {
	hasher := fnv.New32a()
	hasher.Write([]byte(key))
	sum := hasher.Sum32()
	// Clear the top bit so the result is non-negative even where int is 32 bits.
	return int(sum & 0x7fffffff)
}
// getStatus asks the coordinator for the current job status via the
// Coordinator.Status RPC. It returns reply.Status, or -1 if the RPC fails.
func getStatus() int {
	var args StatusArgs
	var reply StatusReply
	if !call("Coordinator.Status", &args, &reply) {
		return -1
	}
	return reply.Status
}
// createOrReplaceFile leaves an empty file at filename. Any existing file at
// that path is removed first, so the result is a freshly created file rather
// than a truncated one. The new file is closed before returning; callers
// reopen it for writing.
//
// Errors are wrapped with %w, so callers can inspect the underlying cause
// with errors.Is / errors.As.
func createOrReplaceFile(filename string) error {
	// Remove unconditionally and tolerate "not exist" instead of Stat-then-Remove,
	// which avoids a race between the existence check and the removal.
	if err := os.Remove(filename); err != nil && !os.IsNotExist(err) {
		return fmt.Errorf("removing existing file %q: %w", filename, err)
	}
	file, err := os.Create(filename)
	if err != nil {
		return fmt.Errorf("creating file %q: %w", filename, err)
	}
	// Report Close failures instead of dropping them in a defer.
	if err := file.Close(); err != nil {
		return fmt.Errorf("closing file %q: %w", filename, err)
	}
	return nil
}
// getMapIntermediate reads inputFile, runs mapf(inputFile, contents), and
// returns a copy of the emitted pairs sorted by key.
//
// The process exits via log.Fatalf if the file cannot be opened or read.
// (Previously an open error was overwritten without being checked.)
func getMapIntermediate(inputFile string, mapf func(string, string) []KeyValue) []KeyValue {
	content, err := ioutil.ReadFile(inputFile)
	if err != nil {
		log.Fatalf("cannot read %v: %v", inputFile, err)
	}
	kva := mapf(inputFile, string(content))
	// Sort a private copy so a slice that mapf still holds is not reordered.
	intermediate := make([]KeyValue, len(kva))
	copy(intermediate, kva)
	sort.Sort(ByKey(intermediate))
	return intermediate
}
// writeMap writes one mapper's output to nReduce intermediate files in the
// current directory, named "<mapperId>-<reduceId>". Each key is routed to
// partition ihash(key) % nReduce. Each run of equal adjacent keys is written
// as a single line "key v1 v2 ...".
//
// intermediate is expected to be sorted by key (getMapIntermediate sorts it).
// If it is not, a key may appear on several lines; getReduceIntermediate
// merges such lines when reading.
// NOTE(review): keys/values containing whitespace do not round-trip through
// this space-separated format.
//
// Any failure to create, open, write, or close a file panics, so a partially
// written map task never reaches the MapDone report in getMap.
func writeMap(nReduce int, mapperId int, intermediate []KeyValue) {
	files := make([]*os.File, nReduce)
	writers := make([]*bufio.Writer, nReduce)
	for r := 0; r < nReduce; r++ {
		name := strconv.Itoa(mapperId) + "-" + strconv.Itoa(r)
		if err := createOrReplaceFile(name); err != nil {
			fmt.Printf("os.Create failed: %s\n", name)
			panic(err)
		}
		file, err := os.OpenFile(name, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0666)
		if err != nil {
			fmt.Printf("os.OpenFile failed: %s\n", name)
			panic(err)
		}
		files[r] = file
		// Buffer writes: one syscall per key would be needlessly slow.
		writers[r] = bufio.NewWriter(file)
	}

	for i := 0; i < len(intermediate); {
		j := i + 1
		for j < len(intermediate) && intermediate[j].Key == intermediate[i].Key {
			j++
		}
		values := make([]string, 0, j-i)
		for k := i; k < j; k++ {
			values = append(values, intermediate[k].Value)
		}
		w := writers[ihash(intermediate[i].Key)%nReduce]
		// bufio.Writer errors are sticky and surface at Flush below.
		fmt.Fprintf(w, "%v %v\n", intermediate[i].Key, strings.Join(values, " "))
		i = j
	}

	for r, w := range writers {
		if err := w.Flush(); err != nil {
			fmt.Printf("write failed: %s\n", files[r].Name())
			panic(err)
		}
		if err := files[r].Close(); err != nil {
			fmt.Printf("close failed: %s\n", files[r].Name())
			panic(err)
		}
	}
}
// getMap requests a map task from the coordinator via Coordinator.Map. It
// runs mapf over the task's input file, writes the partitioned intermediate
// files, and reports completion via Coordinator.MapDone.
//
// It returns without doing anything if the RPC fails or the reply carries
// MapperId == -1 (presumably "no map task available" — confirm against the
// coordinator).
func getMap(mapf func(string, string) []KeyValue) {
	var task MapReply
	if ok := call("Coordinator.Map", &MapArgs{}, &task); !ok || task.MapperId == -1 {
		return
	}

	kvs := getMapIntermediate(task.Inputfile, mapf)
	writeMap(task.NReduce, task.MapperId, kvs)

	done := MapDoneArgs{MapperId: task.MapperId}
	call("Coordinator.MapDone", &done, &MapDoneReply{})
}
// getReduceIntermediate gathers every value emitted for reduce partition
// reducerId by mappers 0..nMap-1. It applies reducef once per distinct key
// and returns the reduced pairs sorted by key.
//
// Input files are named "<mapperId>-<reducerId>" and contain lines of the
// form "key v1 v2 ..." (the format writeMap produces). A file that cannot be
// opened or fully read is logged and skipped; values read before the error
// are kept.
func getReduceIntermediate(reducerId int, nMap int, reducef func(string, []string) string) []KeyValue {
	intermediate := make(map[string][]string)
	for mapperId := 0; mapperId < nMap; mapperId++ {
		filePath := strconv.Itoa(mapperId) + "-" + strconv.Itoa(reducerId)
		if err := readReduceInput(filePath, intermediate); err != nil {
			log.Println("Error reading intermediate file:", err)
		}
	}

	toWrite := make([]KeyValue, 0, len(intermediate))
	for key, values := range intermediate {
		toWrite = append(toWrite, KeyValue{Key: key, Value: reducef(key, values)})
	}
	// Map iteration order is random; sort for deterministic output.
	sort.Sort(ByKey(toWrite))
	return toWrite
}

// readReduceInput parses one intermediate file and appends each line's values
// to groups under that line's first whitespace-separated field. The file is
// closed before returning. This is a separate function so each file's defer
// runs per file rather than at the end of the caller's loop.
func readReduceInput(filePath string, groups map[string][]string) error {
	file, err := os.Open(filePath)
	if err != nil {
		return err
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	// One line holds every value for a key, so it can exceed bufio.Scanner's
	// default 64 KiB token limit (which would silently stop the scan).
	// Allow lines up to 64 MiB.
	scanner.Buffer(make([]byte, 0, 64*1024), 64*1024*1024)
	for scanner.Scan() {
		fields := strings.Fields(scanner.Text())
		if len(fields) == 0 {
			continue
		}
		key := fields[0]
		groups[key] = append(groups[key], fields[1:]...)
	}
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("reading %s: %w", filePath, err)
	}
	return nil
}
// writeReduce writes the reduced pairs for partition reducerId to
// "mr-out-<reducerId>" in the current directory, one "key value" line per
// pair, replacing any existing file.
//
// It panics if the file cannot be created, opened, written, or closed. As a
// result, getReduce never reports ReduceDone for an incomplete output file.
// Previously these errors were ignored.
// NOTE(review): the file is written in place. Writing to a temporary file and
// renaming it would keep readers from ever seeing a partial output.
func writeReduce(reducerId int, toWrite []KeyValue) {
	filePath := "mr-out-" + strconv.Itoa(reducerId)
	if err := createOrReplaceFile(filePath); err != nil {
		fmt.Printf("os.Create failed: %s\n", filePath)
		panic(err)
	}
	ofile, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0666)
	if err != nil {
		fmt.Printf("os.OpenFile failed: %s\n", filePath)
		panic(err)
	}

	w := bufio.NewWriter(ofile)
	for _, kv := range toWrite {
		// bufio.Writer errors are sticky and surface at Flush below.
		fmt.Fprintf(w, "%v %v\n", kv.Key, kv.Value)
	}
	if err := w.Flush(); err != nil {
		ofile.Close()
		fmt.Printf("write failed: %s\n", filePath)
		panic(err)
	}
	if err := ofile.Close(); err != nil {
		fmt.Printf("close failed: %s\n", filePath)
		panic(err)
	}
}
func getReduce(reducef func(string, []string) string) {
args := ReduceArgs{}
reply := ReduceReply{}
ok := call("Coordinator.Reduce", &args, &reply)
if ok {
if reply.ReducerId == -1 {
return
}
// 获取 Reduce func 处理结果
toWrite := getReduceIntermediate(reply.ReducerId, reply.NMap, reducef)
// 将 Reduce 结果写入文件
writeReduce(reply.ReducerId, toWrite)
rdargs := ReduceDoneArgs{}
rdreply := ReduceDoneReply{}
rdargs.ReducerId = reply.ReducerId
call("Coordinator.ReduceDone", &rdargs, &rdreply)
}
}
// Worker is called by main/mrworker.go. It repeatedly polls the coordinator
// for the job status:
//   - status 0 runs one map task,
//   - status 1 runs one reduce task,
//   - any other value ends the loop. This includes -1, which getStatus
//     returns when the status RPC fails.
//
// To exercise the RPC plumbing instead, call CallExample() here.
func Worker(mapf func(string, string) []KeyValue,
	reducef func(string, []string) string) {
	for {
		switch getStatus() {
		case 0:
			getMap(mapf)
		case 1:
			getReduce(reducef)
		default:
			return
		}
	}
}
// CallExample shows how to make an RPC call to the coordinator. It sends
// ExampleArgs{X: 99} to Coordinator.Example (the Example method of struct
// Coordinator) and prints reply.Y, which should be 100. If the call fails it
// prints "call failed!". The argument and reply types are defined in rpc.go.
func CallExample() {
	args := ExampleArgs{X: 99}
	var reply ExampleReply

	if ok := call("Coordinator.Example", &args, &reply); !ok {
		fmt.Printf("call failed!\n")
		return
	}
	fmt.Printf("reply.Y %v\n", reply.Y)
}
// call sends the RPC named rpcname to the coordinator over the unix socket
// returned by coordinatorSock and waits for the response.
//
// It returns true on success. If the call fails, it prints the error and
// returns false. Failing to dial the coordinator is fatal: the process exits
// via log.Fatal.
func call(rpcname string, args interface{}, reply interface{}) bool {
	// For TCP instead: rpc.DialHTTP("tcp", "127.0.0.1"+":1234")
	client, dialErr := rpc.DialHTTP("unix", coordinatorSock())
	if dialErr != nil {
		log.Fatal("dialing:", dialErr)
	}
	defer client.Close()

	if callErr := client.Call(rpcname, args, reply); callErr != nil {
		fmt.Println(callErr)
		return false
	}
	return true
}