【问题标题】:Rewrite TensorFlow 1.x to 2.x version(将 TensorFlow 1.x 代码重写为 2.x 版本)
【发布时间】:2021-09-27 21:08:06
【问题描述】:

我需要将 Tensorflow 1.x 代码重写为 2.x 版本。于是,我将注释过的代码改写如下(不同的activation和initializer都是我自己修改的):

def model(X, nact):
    # Build the policy/value network on top of the observation tensor X and
    # return (masked policy logits, value prediction).
    #
    # X    -- observation tensor; presumably (batch, nh, nw, channels) as fed
    #         by the caller -- TODO confirm against the call site.
    # nact -- number of policy outputs (units of the policy head).
    #
    # NOTE(review): `orthogonal` is defined outside this snippet; it looks like
    # an initializer factory taking a gain/scale -- confirm.
    # NOTE(review): the commented-out TF1 calls cast X to float32 before the
    # first conv; the Keras version passes X directly.

    # TF1 original: h = conv(tf.cast(X, tf.float32), nf=32, rf=8, stride=1, init_scale=np.sqrt(2))
    h = tf.keras.layers.Conv2D(filters=32,
                               kernel_size=8,
                               activation='relu',
                               kernel_initializer=orthogonal(np.sqrt(2)))(X)
    # TF1 original: h2 = conv(h, nf=64, rf=4, stride=1, init_scale=np.sqrt(2))
    h2 = tf.keras.layers.Conv2D(filters=64,
                                kernel_size=4,
                                activation='relu',
                                kernel_initializer=orthogonal(np.sqrt(2)))(h)
    # The layers between h2 and h4 were elided by the author; h4 is produced there.
    . . .
    # Policy head: linear activation, i.e. raw logits.
    # TF1 original: pi = fc(h4, nact, act=lambda x: x)
    pi = tf.keras.layers.Dense(units=nact,
                               activation='linear',
                               kernel_initializer=orthogonal(np.sqrt(2)))(h4)
    # Value head: a single tanh unit.
    # TF1 original: vf = fc(h4, 1, act=lambda x: tf.tanh(x))
    vf = tf.keras.layers.Dense(units=1,
                               activation='tanh',
                               kernel_initializer=orthogonal(np.sqrt(2)))(h4)

    # filter out non-valid actions from pi
    # Take the max of X over axis 1 and flatten it to one row of nact entries
    # per sample; this is used as the validity mask below.
    valid = tf.reduce_max(tf.cast(X, tf.float32), axis=1)
    valid_flat = tf.reshape(valid, [-1, nact])
    # Where the mask is 1 the added term is 0 (logit unchanged); where it is 0
    # the term is -1e32, pushing that logit far below the others.
    pi_fil = pi + (valid_flat - tf.ones(tf.shape(valid_flat))) * 1e32

    # vf[:, 0] drops the trailing size-1 axis of the value head output.
    return pi_fil, vf[:, 0]

此外,我还有以下几个函数:

def build_model(args):
    """Create the TF1 graph: input placeholders, network outputs and total loss.

    Reads ``max_clause``, ``max_var``, ``n_stack`` and ``l2_coeff`` from
    ``args``. Returns the tuple ``(X, Y, Z, p, v, params, loss)`` where
    ``X``/``Y``/``Z`` are the observation, policy-target and value-target
    placeholders, ``p``/``v`` are the outputs of ``model``, ``params`` is the
    list of trainable variables and ``loss`` is the combined objective.
    """
    rows = args.max_clause
    cols = args.max_var
    n_channels = 2
    n_actions = n_channels * cols
    obs_shape = (None, rows, cols, n_channels * args.n_stack)

    # Graph inputs.
    X = tf.placeholder(tf.float32, obs_shape)
    Y = tf.placeholder(tf.float32, (None, n_actions))
    Z = tf.placeholder(tf.float32, None)

    p, v = model(X, n_actions)
    params = tf.trainable_variables()

    with tf.name_scope("loss"):
        # Policy term: softmax cross-entropy averaged over the batch.
        per_sample_xent = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=p)
        cross_entropy = tf.reduce_mean(per_sample_xent)
        # Value term: mean squared error against the targets.
        value_loss = tf.losses.mean_squared_error(labels=Z, predictions=v)
        # Regularisation term: sum of L2 losses over every trainable variable.
        l2_terms = [tf.nn.l2_loss(variable) for variable in params]
        lossL2 = tf.add_n(l2_terms)
        loss = cross_entropy + value_loss + args.l2_coeff * lossL2

    return X, Y, Z, p, v, params, loss

def self_play(args, status_track):
    # Build the graph, then open a session and restore parameters from the
    # directory reported by status_track.get_model_dir().
    # Only the inputs X, outputs p/v and the parameter list are kept; the
    # target placeholders and the loss returned by build_model are discarded.
    X, _, _, p, v, params, _ = build_model(args)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        model_dir = status_track.get_model_dir()
        # NOTE(review): `load` is defined outside this snippet; presumably it
        # returns op(s) assigning saved values to `params` -- confirm.
        sess.run(load(params, os.path.join(args.save_dir, model_dir)))
        # The remainder of the function was elided by the author.
        . . .

def super_train(args, status_track):
    # Build the graph plus an Adam training op on the loss, then open a
    # session and restore parameters from the directory reported by
    # status_track.get_sl_starter().
    X, Y, Z, _, _, params, loss = build_model(args)
    with tf.name_scope("train"):
        # Adam with learning rate 1e-3 minimising the combined loss.
        train_step = tf.train.AdamOptimizer(1e-3).minimize(loss)

    with tf.Session() as sess:
        # Initialise all variables before the saved parameters are loaded over them.
        sess.run(tf.global_variables_initializer())
        model_dir = status_track.get_sl_starter()
        # NOTE(review): `load` is defined outside this snippet; presumably it
        # returns op(s) assigning saved values to `params` -- confirm.
        sess.run(load(params, os.path.join(args.save_dir, model_dir)))
        # The remainder of the function was elided by the author.
        . . .

如何在 TensorFlow 2.x 中重写这两个函数,即类似 Keras 的样式?

【问题讨论】:

    标签: tensorflow keras neural-network tensorflow2.0 tf.keras


    【解决方案1】:

    与 TensorFlow 2.x 兼容的代码片段(code snippet):

    def build_model(args):
        """Build the graph-mode model and loss using the TF2 ``compat.v1`` API.

        Reads ``max_clause``, ``max_var``, ``n_stack`` and ``l2_coeff`` from
        ``args``.

        Returns:
            Tuple ``(X, Y, Z, p, v, params, loss)``: the observation,
            policy-target and value-target placeholders, the two outputs of
            ``model``, the list of trainable variables, and the combined loss.

        Note:
            ``tf.compat.v1.placeholder`` is not compatible with eager
            execution, so ``tf.compat.v1.disable_eager_execution()`` must be
            called once at program start-up before this function is used.
        """
        nh = args.max_clause
        nw = args.max_var
        nc = 2
        nact = nc * nw
        ob_shape = (None, nh, nw, nc * args.n_stack)
        X = tf.compat.v1.placeholder(tf.float32, ob_shape)
        Y = tf.compat.v1.placeholder(tf.float32, (None, nact))
        Z = tf.compat.v1.placeholder(tf.float32, None)

        p, v = model(X, nact)
        params = tf.compat.v1.trainable_variables()
        with tf.name_scope("loss"):
            cross_entropy = tf.reduce_mean(
                tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=p))
            # The direct port of TF1's tf.losses.mean_squared_error is the
            # compat.v1 loss: it accepts labels=/predictions= and reduces to a
            # scalar. tf.keras.metrics.mean_squared_error takes (y_true, y_pred)
            # and would raise TypeError on these keyword names.
            value_loss = tf.compat.v1.losses.mean_squared_error(labels=Z, predictions=v)
            lossL2 = tf.add_n([tf.nn.l2_loss(vv) for vv in params])
            loss = cross_entropy + value_loss + args.l2_coeff * lossL2

        return X, Y, Z, p, v, params, loss
    

    【讨论】:

      猜你喜欢
      • 2020-10-28
      • 2016-08-18
      • 2020-08-24
      • 2021-09-29
      • 2021-09-08
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多