The tanh Activation Function: Introduction and C++/PyTorch Implementations


    There are many kinds of activation functions used in deep neural networks; this post introduces tanh. Its formula is as follows (see the Wikipedia entry on activation functions: https://en.wikipedia.org/wiki/Activation_function):
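
tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x)) = (e^(2x) - 1) / (e^(2x) + 1)

tanh'(x) = 1 - tanh(x)^2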

    tanh, also known as the hyperbolic tangent, is zero-centered, which addresses sigmoid's non-zero-centered output problem. Its output range is (-1, 1), and it is nonlinear. However, it still cannot completely solve the vanishing gradient problem, since its derivative approaches zero as |x| grows.

    The C++ implementation is as follows:

#include <cmath>
#include <cstdio>
#include <vector>

// element-wise tanh: dst[i] = (e^x - e^-x) / (e^x + e^-x)
template<typename _Tp>
int activation_function_tanh(const _Tp* src, _Tp* dst, int length)
{
	for (int i = 0; i < length; ++i) {
		_Tp ep = std::exp(src[i]);
		_Tp em = std::exp(-src[i]);

		dst[i] = (ep - em) / (ep + em);
	}

	return 0;
}

// derivative of tanh; src is expected to already hold tanh(x), so dst[i] = 1 - tanh(x)^2
template<typename _Tp>
int activation_function_tanh_derivative(const _Tp* src, _Tp* dst, int length)
{
	for (int i = 0; i < length; ++i) {
		dst[i] = (_Tp)1. - src[i] * src[i];
	}

	return 0;
}

int test_activation_function()
{
	std::vector<float> src{ 1.1f, -2.2f, 3.3f, 0.4f, -0.5f, -1.6f };
	int length = static_cast<int>(src.size());
	std::vector<float> dst(length);

	fprintf(stderr, "source vector: \n");
	fbc::print_matrix(src);
	fprintf(stderr, "calculate activation function:\n");

	fprintf(stderr, "type: tanh result: \n");
	fbc::activation_function_tanh(src.data(), dst.data(), length);
	fbc::print_matrix(dst);
	fprintf(stderr, "type: tanh derivative result: \n");
	// the derivative is computed in place from the tanh values already stored in dst
	fbc::activation_function_tanh_derivative(dst.data(), dst.data(), length);
	fbc::print_matrix(dst);

	return 0;
}

    The execution result is as follows:

    The Python and PyTorch implementations are as follows:

import numpy as np
import torch

data = [1.1, -2.2, 3.3, 0.4, -0.5, -1.6]

# numpy impl
def tanh(x):
	lists = list()
	for i in range(len(x)):
		lists.append((np.exp(x[i]) - np.exp(-x[i])) / (np.exp(x[i]) + np.exp(-x[i])))
	return lists

# derivative of tanh: 1 - tanh(x)^2
def tanh_derivative(x):
	return 1 - np.power(tanh(x), 2)

output = [round(value, 4) for value in tanh(data)] # keep 4 decimal places via round
print("numpy tanh:", output)
print("numpy tanh derivative:", [round(value, 4) for value in tanh_derivative(data)])
print("numpy tanh derivative2:", [round(1. - value*value, 4) for value in tanh(data)])

# call pytorch interface
input = torch.FloatTensor(data)
m = torch.nn.Tanh()
output2 = m(input)
print("pytorch tanh:", output2)
print("pytorch tanh derivative:", 1. - output2*output2)

    The execution result is as follows:

    From the execution results above, the C++, Python, and PyTorch implementations produce exactly the same results.

GitHub:

    https://github.com/fengbingchun/NN_Test
    https://github.com/fengbingchun/PyTorch_Test
