go語言中的errgroup

乾飯人小羽 發佈 2023-05-26T12:42:13.235744+00:00

1、串行執行假如我們需要查詢一個課件列表,其中有課件的信息,還有課件創建者的信息,和課件的縮略圖信息。但是此時我們已經對服務做了拆分,假設有課件服務和用戶服務還有文件服務。

1、串行執行

假如我們需要查詢一個課件列表,其中有課件的信息,還有課件創建者的信息,和課件的縮略圖信息。但是此時我們已經對服務做了拆分,假設有課件服務用戶服務還有文件服務

我們通常的做法是,當我們查詢課件列表時,我們首先調用課件服務,比如查詢10條課件記錄,然後獲取到課件的創建人ID,課件的縮略圖ID;再通過這些創建人ID去用戶服務查詢用戶信息,通過縮略圖ID去文件服務查詢文件信息;然後再寫到這10條課件記錄中返回給前端。

像下面這樣:

package main

import (
        "fmt"
        "time"
)

type Courseware struct {
        Id         int64
        Name       string
        Code       string
        CreateId   int64
        CreateName string
        CoverId   int64
        CoverPath string
}

type User struct {
        Id   int64
        Name string
}

type File struct {
        Id   int64
        Path string
}

var coursewares []Courseware
var users map[int64]User
var files map[int64]File
var err error

func main() {
        // 查詢課件
        coursewares, err = CoursewareList()
        if err != nil {
                fmt.Println("獲取課件錯誤")
                return
        }

        // 獲取用戶ID、文件ID
        userIds := make([]int64, 0)
        fileIds := make([]int64, 0)
        for _, courseware := range coursewares {
                userIds = append(userIds, courseware.CreateId)
                fileIds = append(fileIds, courseware.CoverId)
        }

        // 批量獲取用戶信息
        users, err = UserMap(userIds)
        if err != nil {
                fmt.Println("獲取用戶錯誤")
                return
        }

        // 批量獲取文件信息
        files, err = FileMap(fileIds)
        if err != nil {
                fmt.Println("獲取文件錯誤")
                return
        }

        // 填充
        for i, courseware := range coursewares {
                if user, ok := users[courseware.CreateId]; ok {
                        coursewares[i].CreateName = user.Name
                }

                if file, ok := files[courseware.CoverId]; ok {
                        coursewares[i].CoverPath = file.Path
                }
        }
        fmt.Println(coursewares)
}

func UserMap(ids []int64) (map[int64]User, error) {
        time.Sleep(3 * time.Second) // 模擬資料庫請求
        return map[int64]User{
                1: {Id: 1, Name: "liu"},
                2: {Id: 2, Name: "kang"},
        }, nil
}

func FileMap(ids []int64) (map[int64]File, Error) {
        time.Sleep(3 * time.Second) // 模擬資料庫請求
        return map[int64]File{
                1: {Id: 1, Path: "/a/b/c.jpg"},
                2: {Id: 2, Path: "/a/b/c/d.jpg"},
        }, nil
}

func CoursewareList() ([]Courseware, error) {
        time.Sleep(3 * time.Second)
        return []Courseware{
                {Id: 1, Name: "課件1", Code: "CW1", CreateId: 1, CreateName: "", CoverId: 1, CoverPath: ""},
                {Id: 2, Name: "課件2", Code: "CW2", CreateId: 2, CreateName: "", CoverId: 2, CoverPath: ""},
        }, nil
}

2、並發執行

但我們獲取課件之後,填充用戶信息和文件信息是可以並行執行的,我們可以修改獲取用戶和文件的代碼,把他們放到協程裡面,這樣就可以並行執行了:

...

        // 此處放到協程里
        go func() {
                // 批量獲取用戶信息
                users, err = UserMap(userIds)
                if err != nil {
                        fmt.Println("獲取用戶錯誤")
                        return
                }
        }()

        // 此處放到協程里
        go func() {
                // 批量獲取文件信息
                files, err = FileMap(fileIds)
                if err != nil {
                        fmt.Println("獲取文件錯誤")
                        return
                }
        }()

        ...

但是當你執行的時候你會發現這樣是有問題的,因為下面的填充數據的代碼有可能會在這兩個協程執行完成之前去執行。也就是說最終的數據有可能沒有填充用戶信息和文件信息。那怎麼辦呢?這是我們就可以使用golang的waitgroup了,主要作用就是協程的編排。

我們可以等2個協程都執行完成再去走下面的填充邏輯

我們繼續修改代碼成下面的樣子

...

// 初始化一個sync.WaitGroup
var wg sync.WaitGroup

func main() {
        // 查詢課件
        ...
        // 獲取用戶ID、文件ID
        ...

        // 此處放到協程里
        wg.Add(1) // 計數器+1
        go func() {
                defer wg.Done() // 計數器-1
                // 批量獲取用戶信息
                users, err = UserMap(userIds)
                if err != nil {
                        fmt.Println("獲取用戶錯誤")
                        return
                }
        }()

        // 此處放到協程里
        wg.Add(1) // 計數器+1
        go func() {
                defer wg.Done() // 計數器-1
                // 批量獲取文件信息
                files, err = FileMap(fileIds)
                if err != nil {
                        fmt.Println("獲取文件錯誤")
                        return
                }
        }()

  // 阻塞等待計數器小於等於0
        wg.Wait()

        // 填充
        for i, courseware := range coursewares {
                if user, ok := users[courseware.CreateId]; ok {
                        coursewares[i].CreateName = user.Name
                }

                if file, ok := files[courseware.CoverId]; ok {
                        coursewares[i].CoverPath = file.Path
                }
        }
        fmt.Println(coursewares)
}

...

我們初始化一個sync.WaitGroup,調用wg.Add(1)給計數器加一,調用wg.Done()計數器減一,wg.Wait()阻塞等待直到計數器小於等於0,結束阻塞,繼續往下執行。

3、errgroup

但是我們現在又有這樣的需求,我們希望如果獲取用戶或者獲取文件有任何一方報錯了,直接拋錯,不再組裝數據。

我們可以像下面這樣寫

...

var goErr error
var wg sync.WaitGroup

...

func main() {
        ...

        // 此處放到協程里
        wg.Add(1)
        go func() {
                defer wg.Done()
                // 批量獲取用戶信息
                users, err = UserMap(userIds)
                if err != nil {
                        goErr = err
                        fmt.Println("獲取用戶錯誤:", err)
                        return
                }
        }()

        // 此處放到協程里
        wg.Add(1)
        go func() {
                defer wg.Done()
                // 批量獲取文件信息
                files, err = FileMap(fileIds)
                if err != nil {
                        goErr = err
                        fmt.Println("獲取文件錯誤:", err)
                        return
                }
        }()

        wg.Wait()

        if goErr != nil {
                fmt.Println("goroutine err:", err)
                return
        }

        ...
}

...

把錯誤放在goErr中,結束阻塞後判斷協程調用是否拋錯。

那golang裡面有沒有類似這樣的實現呢?答案是有的,那就是errgroup。其實和我們上面的方法差不多,但是errgroup包做了一層結構體的封裝,也不需要在每個協程裡面判斷error傳給errGo了。

下面是errgroup的實現

package main

import (
        "errors"
        "fmt"
        "golang.org/x/sync/errgroup"
        "time"
)

type Courseware struct {
        Id         int64
        Name       string
        Code       string
        CreateId   int64
        CreateName string
        CoverId   int64
        CoverPath string
}

type User struct {
        Id   int64
        Name string
}

type File struct {
        Id   int64
        Path string
}

var coursewares []Courseware
var users map[int64]User
var files map[int64]file
var err error
// 定義一個errgroup
var eg errgroup.Group

func main() {
        // 查詢課件
        coursewares, err = CoursewareList()
        if err != nil {
                fmt.Println("獲取課件錯誤:", err)
                return
        }

        // 獲取用戶ID、文件ID
        userIds := make([]int64, 0)
        fileIds := make([]int64, 0)
        for _, courseware := range coursewares {
                userIds = append(userIds, courseware.CreateId)
                fileIds = append(fileIds, courseware.CoverId)
        }


        // 此處放到協程里
        eg.Go(func() error {
                // 批量獲取用戶信息
                users, err = UserMap(userIds)
                if err != nil {
                        fmt.Println("獲取用戶錯誤:", err)
                        return err
                }
                return nil
        })

        // 此處放到協程里
        eg.Go(func() error {
                // 批量獲取文件信息
                files, err = FileMap(fileIds)
                if err != nil {
                        fmt.Println("獲取文件錯誤:", err)
                        return err
                }
                return nil
        })

  // 判斷group中是否有報錯
        if goErr := eg.Wait(); goErr != nil {
                fmt.Println("goroutine err:", err)
                return
        }

        // 填充
        for i, courseware := range coursewares {
                if user, ok := users[courseware.CreateId]; ok {
                        coursewares[i].CreateName = user.Name
                }

                if file, ok := files[courseware.CoverId]; ok {
                        coursewares[i].CoverPath = file.Path
                }
        }
        fmt.Println(coursewares)
}

func UserMap(ids []int64) (map[int64]User, error) {
        time.Sleep(3 * time.Second)
        return map[int64]User{
                1: {Id: 1, Name: "liu"},
                2: {Id: 2, Name: "kang"},
        }, errors.New("sql err")
}

func FileMap(ids []int64) (map[int64]File, error) {
        time.Sleep(3 * time.Second)
        return map[int64]File{
                1: {Id: 1, Path: "/a/b/c.jpg"},
                2: {Id: 2, Path: "/a/b/c/d.jpg"},
        }, nil
}

func CoursewareList() ([]Courseware, error) {
        time.Sleep(3 * time.Second)
        return []Courseware{
                {Id: 1, Name: "課件1", Code: "CW1", CreateId: 1, CreateName: "", CoverId: 1, CoverPath: ""},
                {Id: 2, Name: "課件2", Code: "CW2", CreateId: 2, CreateName: "", CoverId: 2, CoverPath: ""},
        }, nil
}

當然,errgroup中也有針對上下文的errgroup.WithContext函數,如果我們想控制請求接口的時間,用這個是最合適不過的。如果請求超時會返回一個關閉上下文的報錯,像下面這樣

package main

import (
        "context"
        "fmt"
        "golang.org/x/sync/errgroup"
        "time"
)

type Courseware struct {
        Id         int64
        Name       string
        Code       string
        CreateId   int64
        CreateName string
        CoverId    int64
        CoverPath  string
}

type User struct {
        Id   int64
        Name string
}

type File struct {
        Id   int64
        Path string
}

var coursewares []Courseware
var users map[int64]User
var files map[int64]File
var err error

func main() {
        // 查詢課件
        ...

        // 獲取用戶ID、文件ID
        ...

  // 定義一個帶超時時間的上下文,1秒鐘超時
        ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second)
        defer cancelFunc()
  // 定義一個帶上下文的errgroup,使用上面帶有超時時間的上下文
        eg, ctx := errgroup.WithContext(ctx)
        // 此處放到協程里
        eg.Go(func() error {
                // 批量獲取用戶信息
                users, err = UserMap(ctx, userIds)
                if err != nil {
                        fmt.Println("獲取用戶錯誤:", err)
                        return err
                }
                return nil
        })

        // 此處放到協程里
        eg.Go(func() error {
                // 批量獲取文件信息
                files, err = FileMap(ctx, fileIds)
                if err != nil {
                        fmt.Println("獲取文件錯誤:", err)
                        return err
                }
                return nil
        })

        if goErr := eg.Wait(); goErr != nil {
                fmt.Println("goroutine err:", err)
                return
        }

        // 填充
        for i, courseware := range coursewares {
                if user, ok := users[courseware.CreateId]; ok {
                        coursewares[i].CreateName = user.Name
                }

                if file, ok := files[courseware.CoverId]; ok {
                        coursewares[i].CoverPath = file.Path
                }
        }
        fmt.Println(coursewares)
}

func UserMap(ctx context.Context, ids []int64) (map[int64]User, error) {
        result := make(chan map[int64]User)
        go func() {
                time.Sleep(2 * time.Second) // 假裝請求超過1秒鐘
                result <- map[int64]User{
                        1: {Id: 1, Name: "liu"},
                        2: {Id: 2, Name: "kang"},
                }
        }()

        select {
        case <-ctx.Done(): // 如果上下文結束直接返回錯誤信息
                return nil, ctx.Err()
        case res := <-result: // 返回正確結果
                return res, nil
        }
}

func FileMap(ctx context.Context, ids []int64) (map[int64]File, error) {
        return map[int64]File{
                1: {Id: 1, Path: "/a/b/c.jpg"},
                2: {Id: 2, Path: "/a/b/c/d.jpg"},
        }, nil
}

func CoursewareList() ([]Courseware, error) {
        time.Sleep(3 * time.Second)
        return []Courseware{
                {Id: 1, Name: "課件1", Code: "CW1", CreateId: 1, CreateName: "", CoverId: 1, CoverPath: ""},
                {Id: 2, Name: "課件2", Code: "CW2", CreateId: 2, CreateName: "", CoverId: 2, CoverPath: ""},
        }, nil
}

執行上面的代碼:

go run waitgroup.go
獲取用戶錯誤: context deadline exceeded
goroutine err: context deadline exceeded
關鍵字: