[Question title]: TypeError: multiple values for argument 'weight_decay'
[Posted]: 2021-08-23 09:09:13
[Question description]:

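I'm using the AdamW optimizer with a cosine-decay learning-rate schedule that includes a warmup phase. I wrote the custom scheduler from scratch and used the AdamW optimizer provided by the TensorFlow Addons library.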

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import wandb

class CosineScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self,
                learning_rate_base,
                total_steps,
                warmup_learning_rate=0.0,
                warmup_steps=0):
        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
    
    def __call__(self, step):
        # Cosine decay from learning_rate_base (at warmup_steps) down to 0 (at total_steps)
        learning_rate = 0.5 * self.learning_rate_base * (1 + tf.cos(
            np.pi *
            (tf.cast(step, tf.float32) - self.warmup_steps) / float(self.total_steps - self.warmup_steps)))
        if self.warmup_steps > 0:
            # Linear warmup from warmup_learning_rate up to learning_rate_base
            slope = (self.learning_rate_base - self.warmup_learning_rate) / self.warmup_steps
            warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
            learning_rate = tf.where(step < self.warmup_steps, warmup_rate, learning_rate)
        # The learning rate is 0 once training goes past total_steps
        lr = tf.where(step > self.total_steps, 0.0, learning_rate, name='learning_rate')
        wandb.log({"lr": lr})
        return lr

learning_rate = CosineScheduler(learning_rate_base=0.001, 
                                total_steps=23000, 
                                warmup_learning_rate=0.0, 
                                warmup_steps=1660)
loss_func = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
optimizer = tfa.optimizers.AdamW(learning_rate,weight_decay=0.1)
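To sanity-check the schedule math independently of TensorFlow, here is a minimal plain-Python sketch that mirrors CosineScheduler.__call__ for a single integer step. The function name cosine_warmup_lr is hypothetical, and the tensor ops (tf.cos, tf.where) are replaced by the math module and ordinary if statements; the hyperparameters are the ones used above.

```python
import math

def cosine_warmup_lr(step, learning_rate_base, total_steps,
                     warmup_learning_rate=0.0, warmup_steps=0):
    """Hypothetical plain-Python mirror of CosineScheduler.__call__ for one step."""
    if step > total_steps:
        # Past the end of training the schedule is clamped to zero.
        return 0.0
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warmup from warmup_learning_rate up to learning_rate_base.
        slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
        return slope * step + warmup_learning_rate
    # Cosine decay from learning_rate_base (at warmup_steps) to 0 (at total_steps).
    progress = (step - warmup_steps) / float(total_steps - warmup_steps)
    return 0.5 * learning_rate_base * (1 + math.cos(math.pi * progress))

params = dict(learning_rate_base=0.001, total_steps=23000,
              warmup_learning_rate=0.0, warmup_steps=1660)

for step in (0, 830, 1660, 12330, 23000, 23001):
    print(step, cosine_warmup_lr(step, **params))
```

With these settings the rate climbs linearly from 0 to 0.001 over the first 1660 steps, is back to half the peak at step 12330 (the midpoint of the decay phase), and reaches 0 at step 23000.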

I get the following error, which says that weight_decay received multiple values:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-6f9fd0a9c1cb> in <module>
      1 loss_func = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
----> 2 optimizer = tfa.optimizers.AdamW(learning_rate,weight_decay=0.1)

/opt/conda/lib/python3.7/site-packages/typeguard/__init__.py in wrapper(*args, **kwargs)
    923 
    924     def wrapper(*args, **kwargs):
--> 925         memo = _CallMemo(python_func, _localns, args=args, kwargs=kwargs)
    926         check_argument_types(memo)
    927         retval = func(*args, **kwargs)

/opt/conda/lib/python3.7/site-packages/typeguard/__init__.py in __init__(self, func, frame_locals, args, kwargs, forward_refs_policy)
    126 
    127         if args is not None and kwargs is not None:
--> 128             self.arguments = signature.bind(*args, **kwargs).arguments
    129         else:
    130             assert frame_locals is not None, 'frame must be specified if args or kwargs is None'

/opt/conda/lib/python3.7/inspect.py in bind(*args, **kwargs)
   3013         if the passed arguments can not be bound.
   3014         """
-> 3015         return args[0]._bind(args[1:], kwargs)
   3016 
   3017     def bind_partial(*args, **kwargs):

/opt/conda/lib/python3.7/inspect.py in _bind(self, args, kwargs, partial)
   2954                         raise TypeError(
   2955                             'multiple values for argument {arg!r}'.format(
-> 2956                                 arg=param.name)) from None
   2957 
   2958                     arguments[param.name] = arg_val

TypeError: multiple values for argument 'weight_decay'

What is causing this problem, and how can I fix it?

[Question discussion]:

    Tags: python tensorflow machine-learning deep-learning


    [Solution 1]:

    The problem is that weight_decay is the first positional argument of tfa.optimizers.AdamW. In

    optimizer = tfa.optimizers.AdamW(learning_rate,weight_decay=0.1)
    

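    you pass one positional argument and, in addition, the keyword argument weight_decay. The positional argument is bound to the first parameter, which is weight_decay, so weight_decay receives two values and the call fails. According to the documentation, learning_rate is the second positional argument (albeit an optional one), not the first.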
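The failure can be reproduced without TensorFlow. The sketch below uses a hypothetical stand-in function, adamw_like, that copies only the parameter order of tfa.optimizers.AdamW (weight_decay first, learning_rate second); the string schedule is a placeholder for the CosineScheduler instance. inspect.Signature.bind is the call that raises in the traceback (typeguard uses it to check argument types before the real constructor runs), and a plain call fails for the same reason.

```python
import inspect

# Hypothetical stand-in: only the parameter ORDER mirrors tfa.optimizers.AdamW.
def adamw_like(weight_decay, learning_rate=0.001):
    return {"weight_decay": weight_decay, "learning_rate": learning_rate}

schedule = "<CosineScheduler instance>"  # placeholder for the schedule object

def capture_type_error(fn, *args, **kwargs):
    """Call fn and return the TypeError message, or None if no TypeError is raised."""
    try:
        fn(*args, **kwargs)
    except TypeError as exc:
        return str(exc)
    return None

# Same shape as the failing call: one positional value plus weight_decay=...
sig = inspect.signature(adamw_like)
bind_error = capture_type_error(sig.bind, schedule, weight_decay=0.1)
call_error = capture_type_error(adamw_like, schedule, weight_decay=0.1)

print("Signature.bind:", bind_error)
print("plain call:    ", call_error)
```

In both cases the positional value is assigned to weight_decay first, and the explicit weight_decay=0.1 then collides with it.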

    Instead, write any one of the following:

    optimizer = tfa.optimizers.AdamW(0.1, learning_rate)

    optimizer = tfa.optimizers.AdamW(weight_decay=0.1, learning_rate=learning_rate)

    optimizer = tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=0.1)
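As a quick check that the corrected calls assign each value to the intended parameter, this sketch binds all three forms against a hypothetical stand-in, adamw_like, which mirrors only the parameter order of tfa.optimizers.AdamW (it is not the real optimizer). The string schedule stands in for the CosineScheduler instance.

```python
import inspect

# Hypothetical stand-in: only the parameter ORDER mirrors tfa.optimizers.AdamW.
def adamw_like(weight_decay, learning_rate=0.001):
    return {"weight_decay": weight_decay, "learning_rate": learning_rate}

schedule = "<CosineScheduler instance>"  # placeholder for the schedule object
sig = inspect.signature(adamw_like)

fixed_calls = [
    sig.bind(0.1, schedule),                             # positional, in signature order
    sig.bind(weight_decay=0.1, learning_rate=schedule),  # both as keywords
    sig.bind(learning_rate=schedule, weight_decay=0.1),  # keyword order does not matter
]

for bound in fixed_calls:
    print(dict(bound.arguments))
```

All three bind weight_decay to 0.1 and learning_rate to the schedule; the keyword forms are the more robust choice because they do not depend on the parameter order.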
    

    [Discussion]:
