本网站(662p.com)打包出售,且带程序代码数据,662p.com域名,程序内核采用TP框架开发,需要联系扣扣:2360248666 /wx:lianweikj
精品域名一口价出售:1y1m.com(350元) ,6b7b.com(400元) , 5k5j.com(380元) , yayj.com(1800元), jiongzhun.com(1000元) , niuzen.com(2800元) , zennei.com(5000元)
需要联系扣扣:2360248666 /wx:lianweikj
浅谈tensorflow与pytorch的相互转换
makebo · 220浏览 · 发布于2022-06-28 +关注

本文主要介绍了简单介绍一下tensorflow与pytorch的相互转换,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

本文以一段代码为例,简单介绍一下tensorflow与pytorch的相互转换(主要是tensorflow转pytorch),可能介绍的没有那么详细,仅供参考。

由于本人只熟悉pytorch,而对tensorflow一知半解,而代码经常遇到tensorflow,而我希望使用pytorch,因此简单介绍一下tensorflow转pytorch,可能存在诸多错误,希望轻喷~

1.变量预定义

在TensorFlow的世界里,变量的定义和初始化是分开的。
tensorflow中一般都是在开头预定义变量,声明其数据类型、形状等,在执行的时候再赋具体的值,如下图所示,而pytorch用到时才会定义,定义和变量初始化是合在一起的。


2.创建变量并初始化

tensorflow中利用tf.Variable创建变量并进行初始化,而pytorch中使用torch.tensor创建变量并进行初始化,如下图所示。

3.语句执行

在TensorFlow的世界里,变量的定义和初始化是分开的,所有关于图变量的赋值和计算都要通过tf.Session的run来进行。

sess.run([G_solver, G_loss_temp, MSE_loss],
             feed_dict = {X: X_mb, M: M_mb, H: H_mb})

而在pytorch中,并不需要通过run进行,赋值完了直接计算即可。

4.tensor

pytorch运算时要创建完的numpy数组转为tensor,如下:

if use_gpu is True:
    X_mb = torch.tensor(X_mb, device="cuda")
    M_mb = torch.tensor(M_mb, device="cuda")
    H_mb = torch.tensor(H_mb, device="cuda")
else:
    X_mb = torch.tensor(X_mb)
    M_mb = torch.tensor(M_mb)
    H_mb = torch.tensor(H_mb)


最后运行完还要将tensor数据类型转换回numpy数组: 

if use_gpu is True:
    imputed_data=imputed_data.cpu().detach().numpy()
else:
    imputed_data=imputed_data.detach().numpy()


而tensorflow中不需要这种操作。 

5.其他函数

在tensorflow中包含诸多函数是pytorch中没有的,但是都可以在其他库中找到类似,具体如下表所示。

tensorflow中函数 pytorch中代替(所在库) 参数区别
tf.sqrt np.sqrt(numpy) 完全相同
tf.random_normal np.random.normal(numpy) tf.random_normal(shape = size, stddev = xavier_stddev)
np.random.normal(size = size, scale = xavier_stddev)
tf.concat torch.cat(torch) inputs = tf.concat(values = [x, m], axis = 1)
inputs = torch.cat(dim=1, tensors=[x, m])
tf.nn.relu F.relu(torch.nn.functional) 完全相同
tf.nn.sigmoid torch.sigmoid(torch) 完全相同
tf.matmul torch.matmul(torch) 完全相同
tf.reduce_mean torch.mean(torch) 完全相同
tf.log torch.log(torch) 完全相同
tf.zeros np.zeros 完全相同
tf.train.AdamOptimizer torch.optim.Adam(torch) optimizer_D = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
optimizer_D = torch.optim.Adam(params=theta_D)

到此这篇关于浅谈tensorflow与pytorch的相互转换的文章就介绍到这了


相关推荐

PHP实现部分字符隐藏

沙雕mars · 1325浏览 · 2019-04-28 09:47:56
Java中ArrayList和LinkedList区别

kenrry1992 · 908浏览 · 2019-05-08 21:14:54
Tomcat 下载及安装配置

manongba · 970浏览 · 2019-05-13 21:03:56
JAVA变量介绍

manongba · 962浏览 · 2019-05-13 21:05:52
什么是SpringBoot

iamitnan · 1086浏览 · 2019-05-14 22:20:36
加载中

0评论

评论
没有最好,只有更好,一切都在路上!
分类专栏
小鸟云服务器
扫码进入手机网页