这是因为它实现了__call__方法model(input)相当于调用了model.call(input)

下面这个是__call__的源码

  def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        if len(self._backward_hooks) > 0:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in self._backward_hooks.values():
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

相关文章:

  • 1970-01-01
  • 2021-05-17
  • 2022-12-23
  • 2021-05-28
  • 2022-02-19
  • 2022-12-23
  • 2022-01-01
猜你喜欢
  • 1970-01-01
  • 2022-12-23
  • 2022-12-23
  • 2022-12-23
  • 2022-12-23
  • 2021-04-30
  • 2021-12-03
相关资源
相似解决方案