/**
 * Copyright (c) 2022 Yansen Zhang
 * wxcomponent is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
**/

package encrypt

import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"encoding/base64"
	"errors"
)

// MsgEncode encrypts data with AES in CBC mode after applying PKCS#7 padding.
//
// data is the plaintext to encrypt; it is never modified. encodingAESKey is
// base64-decoded after a trailing "=" is appended (presumably the unpadded
// 43-character EncodingAESKey form — TODO confirm), and the decoded key must
// be a valid AES key length (16, 24 or 32 bytes). iv must be exactly one AES
// block (16 bytes) long.
//
// The result is raw ciphertext. If a string is needed, encode it with
// fmt.Sprintf("%x", ...) or base64.
func MsgEncode(data []byte, encodingAESKey, iv string) ([]byte, error) {
	if encodingAESKey == "" {
		return nil, errors.New("加密 EncodingAESKey 为空")
	}

	if iv == "" {
		return nil, errors.New("加密 IV 为空")
	}

	if len(data) == 0 {
		return nil, errors.New("待加密 消息 为空")
	}

	aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
	if err != nil {
		return nil, err
	}

	block, err := aes.NewCipher(aesKey)
	if err != nil {
		return nil, err
	}

	blockSize := block.BlockSize()
	// cipher.NewCBCEncrypter panics on a wrong-length IV; report it as an error instead.
	if len(iv) != blockSize {
		return nil, errors.New("加密 IV 长度必须等于 AES 块大小")
	}

	// Pad a private copy so that the in-place encryption below can never write
	// into the caller's backing array (appending padding may reuse spare capacity).
	plain := make([]byte, len(data), len(data)+blockSize)
	copy(plain, data)
	dist := pkcs7(plain, blockSize)

	// Encrypt the whole padded buffer in place. The source must be the padded
	// data: CryptBlocks only accepts a whole number of blocks, and the padding
	// must be encrypted too.
	mode := cipher.NewCBCEncrypter(block, []byte(iv))
	mode.CryptBlocks(dist, dist)

	return dist, nil
}

// MsgDecode is the inverse of MsgEncode: it decrypts AES-CBC ciphertext and
// strips the PKCS#7 padding.
//
// data is the ciphertext and must be a non-empty multiple of the AES block
// size (16 bytes). encodingAESKey is base64-decoded after a trailing "=" is
// appended, exactly as in MsgEncode. iv must be exactly 16 bytes long.
//
// Malformed input (wrong IV length, truncated ciphertext, impossible padding
// length) is reported as an error rather than a panic.
//
// NOTE(review): CBC ciphertext is not authenticated here; presumably integrity
// is verified by the caller (e.g. a message signature) — confirm.
//
// To get a string from the result, use fmt.Sprintf("%s", ...).
func MsgDecode(data []byte, encodingAESKey, iv string) ([]byte, error) {
	if encodingAESKey == "" {
		return nil, errors.New("解密 EncodingAESKey 为空")
	}

	if iv == "" {
		return nil, errors.New("解密 IV 为空")
	}

	if len(data) == 0 {
		return nil, errors.New("待解密 消息 为空")
	}

	aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
	if err != nil {
		return nil, err
	}

	block, err := aes.NewCipher(aesKey)
	if err != nil {
		return nil, err
	}

	blockSize := block.BlockSize()
	// cipher.NewCBCDecrypter panics on a wrong-length IV.
	if len(iv) != blockSize {
		return nil, errors.New("解密 IV 长度必须等于 AES 块大小")
	}
	// CryptBlocks panics unless the input is a whole number of blocks.
	if len(data)%blockSize != 0 {
		return nil, errors.New("待解密 消息 长度不是 AES 块大小的整数倍")
	}

	dist := make([]byte, len(data))

	mode := cipher.NewCBCDecrypter(block, []byte(iv))
	mode.CryptBlocks(dist, data)

	// Reject a padding length that cannot be stripped: zero, or longer than
	// the plaintext. The bound is deliberately not capped at blockSize, so data
	// padded to a larger padding block still decodes as before.
	pad := int(dist[len(dist)-1])
	if pad == 0 || pad > len(dist) {
		return nil, errors.New("解密后 PKCS7 填充无效")
	}

	return unpkcs7(dist), nil
}

// pkcs7 returns a new slice holding data followed by PKCS#7 padding up to the
// next multiple of blockSize. Each padding byte equals the padding length.
// When len(data) is already a multiple of blockSize, a full block of padding
// is added.
//
// data is never modified, and the result never shares its backing array.
// blockSize is assumed to be in 1..255, since the padding length is stored in
// a single byte.
func pkcs7(data []byte, blockSize int) []byte {
	// Always in 1..blockSize, so a full block is added for aligned input.
	padLen := blockSize - len(data)%blockSize

	out := make([]byte, len(data), len(data)+padLen)
	copy(out, data)
	return append(out, bytes.Repeat([]byte{byte(padLen)}, padLen)...)
}

// unpkcs7 removes PKCS#7 padding: the last byte gives the number of trailing
// bytes to drop.
//
// It returns nil instead of panicking when data is empty, or when the padding
// length is zero or exceeds len(data). The padding bytes themselves are not
// checked, and the length is not capped at a block size. On success the
// result shares data's backing array.
func unpkcs7(data []byte) []byte {
	length := len(data)
	if length == 0 {
		return nil
	}

	unpadding := int(data[length-1])
	if unpadding == 0 || unpadding > length {
		return nil
	}
	return data[:length-unpadding]
}