package tests

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"wechat-conf/utils"
	"strconv"
	"strings"
	"time"

	"github.com/astaxie/beego"
)

// RequestMock builds simulated HTTP requests for tests; its Request method
// dispatches them in-process to beego.BeeApp.Handlers via an
// httptest.ResponseRecorder.
//
// Headers holds custom request headers that Request applies to every request
// it sends; they are not cleared between requests.
// Host is the scheme+host(+port) prefix that Request prepends to each request
// path; when empty, "http://127.0.0.1:8080" is used.
type RequestMock struct {
	Headers map[string]string
	Host    string
}

// NewRequestMock returns a RequestMock whose Headers map is allocated and
// empty and whose Host is unset.
func NewRequestMock() *RequestMock {
	m := new(RequestMock)
	m.Headers = map[string]string{}
	return m
}

// SetHeaders merges headers into the mock's custom header set, overwriting
// any existing entry with the same key, and returns the receiver for
// chaining. A nil or empty map is a no-op. Headers is allocated on demand, so
// this is safe on a zero-value RequestMock.
func (t *RequestMock) SetHeaders(headers map[string]string) *RequestMock {
	if len(headers) == 0 {
		return t
	}

	// A zero-value RequestMock has a nil map; writing to it would panic.
	if t.Headers == nil {
		t.Headers = make(map[string]string, len(headers))
	}
	for k, v := range headers {
		t.Headers[k] = v
	}

	return t
}

// AddWechatUA sets the User-Agent header to a fixed iPhone WeChat
// (MicroMessenger/5.4.1) user agent and returns the receiver for chaining.
// Headers is allocated on demand, so this is safe on a zero-value RequestMock.
func (t *RequestMock) AddWechatUA() *RequestMock {
	if t.Headers == nil {
		t.Headers = make(map[string]string)
	}
	t.Headers["User-Agent"] = "Mozilla/5.0 (iPhone; CPU iPhone OS 8_0 like Mac OS X) AppleWebKit/600.1.4 (KHTML, like Gecko) Mobile/12A365 MicroMessenger/5.4.1 NetType/WIFI"
	return t
}

// AddToken creates a token from dt with utils.CreateToken and stores it in
// the utils.TokenHeader header as utils.TokenSchema + " " + token. It returns
// the receiver for chaining. Headers is allocated on demand, so this is safe
// on a zero-value RequestMock.
func (t *RequestMock) AddToken(dt map[string]interface{}) *RequestMock {
	// NOTE(review): the second return value of utils.CreateToken is discarded
	// (presumably an error). If it signals failure, the header is still set
	// with whatever token was returned; confirm and consider surfacing it.
	token, _ := utils.CreateToken(dt)

	if t.Headers == nil {
		t.Headers = make(map[string]string)
	}
	t.Headers[utils.TokenHeader] = utils.TokenSchema + " " + token
	return t
}

// SetHost sets the address prefix (scheme, host and port, for example
// "http://127.0.0.1:8080") that Request prepends to each request path, and
// returns the receiver so calls can be chained.
func (t *RequestMock) SetHost(host string) *RequestMock {
	t.Host = host

	return t
}

// Request sends a simulated request to beego.BeeApp.Handlers in-process,
// using an httptest.ResponseRecorder, and returns the recorded response.
//
// meth must be GET, POST, PUT or DELETE; any other method returns an error.
// addr is the request path (optionally with a query string). It is appended
// to t.Host, or to "http://127.0.0.1:8080" when t.Host is empty.
//
// params is interpreted per method:
//   - GET: nil, or url.Values appended to addr as a query string. Any other
//     type returns an error.
//   - POST/PUT/DELETE: encoded by GetBody. url.Values is sent as
//     "application/x-www-form-urlencoded" and map[string]interface{} as
//     "multipart/form-data".
//
// Every entry in t.Headers is set on the request; the headers are kept on
// the mock for later requests. If result is non-nil, the response body is
// JSON-decoded into it and any decode error is returned.
//
// It returns the response status code, the raw response body, and an error.
func (t *RequestMock) Request(meth, addr string, params interface{}, result interface{}) (code int, body []byte, err error) {
	host := t.Host
	if host == "" {
		host = "http://127.0.0.1:8080"
	}
	addr = host + addr

	var r *http.Request

	switch meth {
	case http.MethodGet:
		target := addr
		if params != nil {
			query, ok := params.(url.Values)
			if !ok {
				// Previously this left r nil and panicked when the request was served.
				return 0, nil, fmt.Errorf("GET 请求只支持 url.Values 类型的参数, 实际为 %T", params)
			}
			sep := "?"
			if strings.Contains(target, "?") {
				sep = "&"
			}
			target += sep + query.Encode()
		}
		if r, err = http.NewRequest(meth, target, nil); err != nil {
			return 0, nil, err
		}

	case http.MethodPost, http.MethodPut, http.MethodDelete:
		rb, ct, e := GetBody(params)
		if e != nil {
			return 0, nil, e
		}
		if r, err = http.NewRequest(meth, addr, rb); err != nil {
			return 0, nil, err
		}
		if ct != "" {
			r.Header.Set("Content-Type", ct)
		}

	// Other methods are not used by this system.
	default:
		return 0, nil, errors.New("不支持的请求类型")
	}

	// Custom headers; ranging over a nil map is a no-op.
	for k, v := range t.Headers {
		r.Header.Set(k, v)
	}

	w := httptest.NewRecorder()
	beego.BeeApp.Handlers.ServeHTTP(w, r)

	code = w.Code
	if body, err = ioutil.ReadAll(w.Result().Body); err != nil {
		return code, nil, err
	}

	beego.Trace("testing", meth+" - "+addr)

	if result != nil {
		err = json.Unmarshal(body, result)
	}
	return code, body, err
}

// GetBody encodes params as an HTTP request body and returns it together with
// the matching Content-Type.
//
// Supported params:
//   - nil: no body; body is nil and contentType is "".
//   - url.Values: URL-encoded as "application/x-www-form-urlencoded".
//   - map[string]interface{}: encoded as "multipart/form-data". Each value may
//     be *os.File (sent as a file part named after the map key, copied from
//     the file's current offset), string, bool, any signed or unsigned integer
//     type, float32, float64, or time.Time (formatted as
//     "2006-01-02 15:04:05"). Parts follow Go's map iteration order, which is
//     not fixed.
//
// Any other params type, or an unsupported value type inside the map,
// returns an error with a nil body and empty contentType.
func GetBody(params interface{}) (body io.Reader, contentType string, err error) {
	switch data := params.(type) {
	case nil:
		return nil, "", nil

	case url.Values:
		return strings.NewReader(data.Encode()), "application/x-www-form-urlencoded", nil

	case map[string]interface{}:
		var buf bytes.Buffer
		mw := multipart.NewWriter(&buf)

		for k, v := range data {
			// Files are streamed as file parts; every other value becomes a text field.
			if f, ok := v.(*os.File); ok {
				fw, e := mw.CreateFormFile(k, f.Name())
				if e != nil {
					return nil, "", e
				}
				if _, e = io.Copy(fw, f); e != nil {
					return nil, "", e
				}
				continue
			}

			var val string
			switch x := v.(type) {
			case string:
				val = x
			case bool:
				val = strconv.FormatBool(x)
			case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
				// fmt prints every integer kind in plain base 10.
				val = fmt.Sprint(x)
			case float32:
				val = strconv.FormatFloat(float64(x), 'f', -1, 32)
			case float64:
				val = strconv.FormatFloat(x, 'f', -1, 64)
			case time.Time:
				val = x.Format("2006-01-02 15:04:05")
			default:
				// Pass k as an argument, not as part of the format string, so a
				// '%' in the key cannot corrupt the message.
				return nil, "", fmt.Errorf("暂时不支持 key %s 对应的数据类型 %T", k, v)
			}

			if e := mw.WriteField(k, val); e != nil {
				return nil, "", e
			}
		}

		// Close writes the terminating boundary. Without it the receiver
		// sees a truncated multipart body and cannot parse the form.
		if e := mw.Close(); e != nil {
			return nil, "", e
		}
		return bytes.NewReader(buf.Bytes()), mw.FormDataContentType(), nil

	default:
		return nil, "", fmt.Errorf("暂时不支持的参数类型 %T", params)
	}
}