package tests import ( "bytes" "encoding/json" "errors" "fmt" "io" "io/ioutil" "mime/multipart" "net/http" "net/http/httptest" "net/url" "os" "wechat-conf/utils" "strconv" "strings" "time" "github.com/astaxie/beego" ) // RequestMock 模拟请求 type RequestMock struct { Headers map[string]string Host string } // NewRequestMock new inst func NewRequestMock() *RequestMock { return &RequestMock{ Headers: make(map[string]string), } } // SetHeaders 设置自定义请求头 func (t *RequestMock) SetHeaders(headers map[string]string) *RequestMock { if headers != nil { for k, v := range headers { t.Headers[k] = v } } return t } // AddWechatUA 添加微信UA func (t *RequestMock) AddWechatUA() *RequestMock { t.Headers["User-Agent"] = "Mozilla/5.0 (iPhone; CPU iPhone OS 8_0 like Mac OS X) AppleWebKit/600.1.4 (KHTML, like Gecko) Mobile/12A365 MicroMessenger/5.4.1 NetType/WIFI" return t } // AddToken 添加Token func (t *RequestMock) AddToken(dt map[string]interface{}) *RequestMock { token, _ := utils.CreateToken(dt) t.Headers[utils.TokenHeader] = utils.TokenSchema + " " + token return t } // SetHost 设置 host 包含 port func (t *RequestMock) SetHost(host string) *RequestMock { t.Host = host return t } // Request 请求 // 暂时只支持 params 是 url.Values 及 map[string]interface{} 两种类型 // url.Values 将被理解为 content-type="application/x-www-form-urlencoded" // map[string]interface{} 将被理解为 content-type="multipart/form-data" func (t *RequestMock) Request(meth, addr string, params interface{}, result interface{}) (code int, body []byte, err error) { var r *http.Request defer t.SetHeaders(nil) code = 0 body = nil err = nil if t.Host != "" { addr = t.Host + addr } else { addr = "http://127.0.0.1:8080" + addr } switch meth { // get 请求, 只支持 params 为 url.Values 的参数 case http.MethodGet: if params != nil { if dt, ok := params.(url.Values); ok { searchStr := dt.Encode() if strings.Index(addr, "?") > -1 { searchStr = "&" + searchStr } else { searchStr = "?" + searchStr } r, _ = http.NewRequest(meth, addr+searchStr, nil) } } else { r, _ = http.NewRequest(meth, addr, nil) } // post, put, delete case http.MethodPost, http.MethodPut, http.MethodDelete: rb, ct, e := GetBody(params) if e != nil { err = e return } r, _ = http.NewRequest(meth, addr, rb) if ct != "" { r.Header.Set("Content-Type", ct) } // 其他的本系统没有使用 default: err = errors.New("不支持的请求类型") return } // 自定义的 header 头 if t.Headers != nil { for k, v := range t.Headers { r.Header.Set(k, v) } } w := httptest.NewRecorder() beego.BeeApp.Handlers.ServeHTTP(w, r) code = w.Code body, _ = ioutil.ReadAll(w.Result().Body) beego.Trace("testing", meth+" - "+addr) if result != nil { err = json.Unmarshal(body, result) return } return } func GetBody(params interface{}) (body io.Reader, contentType string, err error) { body = nil contentType = "" err = nil if params == nil { return } t := fmt.Sprintf("%T", params) switch t { // 被会解析为 application/x-www-form-urlencoded case "url.Values": data, _ := params.(url.Values) body = strings.NewReader(data.Encode()) contentType = "application/x-www-form-urlencoded" return // 被会解析为 multipart/form-data case "map[string]interface {}": var b bytes.Buffer var fw io.Writer w := multipart.NewWriter(&b) data, _ := params.(map[string]interface{}) // 遍历字段 for k, v := range data { switch x := v.(type) { // 文件 case *os.File: if fw, err = w.CreateFormFile(k, x.Name()); err != nil { return } if _, err = io.Copy(fw, x); err != nil { return } // 字符串 case string: err = w.WriteField(k, x) if err != nil { return } // 整数, 暂不支持 int 各种原始类型, 比如 int32, int64 等 case int: dt := strconv.Itoa(x) err = w.WriteField(k, dt) if err != nil { return } // 小数, 暂时不支持其他类型的浮点数 case float64: dt := strconv.FormatFloat(x, 'f', -1, 64) err = w.WriteField(k, dt) if err != nil { return } // 时间 case time.Time: dt := x.Format("2006-01-02 15:04:05") err = w.WriteField(k, dt) if err != nil { return } // 其他 default: err = fmt.Errorf("暂时不支持 key "+k+" 对应的数据类型 %T", x) return } } body = bytes.NewReader(b.Bytes()) contentType = w.FormDataContentType() return default: err = fmt.Errorf("暂时不支持的参数类型 " + t) return } }