【问题标题】:ML-Agents agent not resetting?ML-Agents 代理未重置?
【发布时间】:2019-12-12 16:55:53
【问题描述】:

我一直在研究一双能够自我平衡的腿。如果他的“腰”低于某个 y 位置值(跌倒/绊倒),则该区域应该重置,并从他的奖励分数中扣分。我对机器学习非常陌生,所以请多包涵!为什么代理跌倒时没有重置?




代理代码(更新):

    using MLAgents;
    using System;
    using System.Collections;
    using System.Collections.Generic;
    using UnityEngine;

    using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// ML-Agents agent driving a pair of legs that should keep its waist above a
/// height threshold. Each decision step it rotates eight joints according to
/// the discrete action vector, rewards or penalises the waist height, and ends
/// the episode once the waist drops to y &lt;= -3.
/// </summary>
public class BalanceAgent : Agent
{
    private BalancingArea area;
    public GameObject waist;
    public GameObject buttR;
    public GameObject buttL;
    public GameObject thighR;
    public GameObject thighL;
    public GameObject legR;
    public GameObject legL;
    public GameObject footR;
    public GameObject footL;

    //public GameObject goal;

    public GameObject[] bodyParts = new GameObject[9];
    public Vector3[] posStart = new Vector3[9];
    public Vector3[] eulerStart = new Vector3[9];

    // Joints driven by vectorAction, in action-index order (index 0 -> buttR ... index 7 -> footL).
    private GameObject[] actuatedJoints;

    /// <summary>
    /// Collects the body parts into <see cref="bodyParts"/> and records each
    /// part's starting world position and euler angles so AgentReset can restore them.
    /// </summary>
    public override void InitializeAgent() {
        base.InitializeAgent();
        area = GetComponentInParent<BalancingArea>();

        bodyParts = new GameObject[]{waist, buttR, buttL, thighR, thighL, legR, legL, footR, footL};
        actuatedJoints = new GameObject[]{buttR, buttL, thighR, thighL, legR, legL, footR, footL};

        // These arrays are public, so a value serialized from the Inspector can
        // replace the field initializers with a different length. Size them from
        // bodyParts so the loop below cannot index out of range.
        posStart = new Vector3[bodyParts.Length];
        eulerStart = new Vector3[bodyParts.Length];

        for (int i = 0; i < bodyParts.Length; i++) {
            posStart[i] = bodyParts[i].transform.position;
            eulerStart[i] = bodyParts[i].transform.eulerAngles;
        }
    }

    /// <summary>
    /// Puts every body part back at its recorded start pose and zeroes the
    /// linear and angular velocity of its Rigidbody.
    /// </summary>
    public override void AgentReset() {
        for (int i = 0; i < bodyParts.Length; i++) {
            GameObject part = bodyParts[i];
            part.transform.position = posStart[i];
            part.transform.eulerAngles = eulerStart[i];

            // Look the Rigidbody up once per part instead of once per property.
            Rigidbody body = part.GetComponent<Rigidbody>();
            body.velocity = Vector3.zero;
            body.angularVelocity = Vector3.zero;
        }
    }

    /// <summary>
    /// Applies one action per joint, then scores the current waist height.
    /// </summary>
    /// <param name="vectorAction">
    /// One entry per actuated joint (eight entries are read). Each entry is
    /// truncated to an int: 1 rotates by -1, 2 rotates by +1, anything else
    /// leaves the joint alone.
    /// </param>
    public override void AgentAction(float[] vectorAction) {
        for (int i = 0; i < actuatedJoints.Length; i++) {
            // The direction is passed as the Y argument of Transform.Rotate.
            actuatedJoints[i].transform.Rotate(0, DirectionFor(vectorAction[i]), 0);
        }

        //buttR = vectorAction[0]; //Right or none
        //if (buttR == 2) buttR = -1f; //Left

        float waistHeight = waist.transform.position.y;

        if (waistHeight > -1) {
            AddReward(.1f);
        }
        else {
            AddReward(-.02f);
        }

        if (waistHeight <= -3) {
            print("reset!");
            AddReward(-.1f);
            Done();
        }
    }

    /// <summary>
    /// Maps a raw action value to a rotation direction: 1 -> -1, 2 -> +1,
    /// every other value (including 0 and 3) -> 0.
    /// </summary>
    private static int DirectionFor(float action) {
        switch ((int)action) {
            case 1:
                return -1;
            case 2:
                return 1;
            default:
                return 0;
        }
    }

    /// <summary>
    /// Observes one local euler angle per body part (x for the two butt joints,
    /// y for the rest) followed by the waist's world position.
    /// </summary>
    public override void CollectObservations() {
        AddVectorObs(waist.transform.localEulerAngles.y);
        AddVectorObs(buttR.transform.localEulerAngles.x);
        AddVectorObs(buttL.transform.localEulerAngles.x);
        AddVectorObs(thighR.transform.localEulerAngles.y);
        AddVectorObs(thighL.transform.localEulerAngles.y);
        AddVectorObs(legR.transform.localEulerAngles.y);
        AddVectorObs(legL.transform.localEulerAngles.y);
        AddVectorObs(footR.transform.localEulerAngles.y);
        AddVectorObs(footL.transform.localEulerAngles.y);
        AddVectorObs(waist.transform.position);
    }
}




区域代码:

using MLAgents;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;

/// <summary>
/// Training area that owns the BalanceAgent instances found beneath it and
/// can move an agent back to the area's origin on the ground plane.
/// </summary>
public class BalancingArea : Area
{
    public List<BalanceAgent> BalanceAgent { get; private set; }
    public BalanceAcademy BalanceAcademy { get; private set; }
    public GameObject area;

    private void Awake()
    {
        // Every agent that lives under this area in the hierarchy.
        BalanceAgent[] agentsInArea = GetComponentsInChildren<BalanceAgent>();
        BalanceAgent = new List<BalanceAgent>(agentsInArea);

        // The academy instance present in the scene.
        BalanceAcademy = FindObjectOfType<BalanceAcademy>();
    }

    private void Start()
    {
    }

    /// <summary>
    /// Places <paramref name="agent"/> at the area's x/z position with y = 0
    /// and clears its rotation.
    /// </summary>
    public void ResetAgentPosition(BalanceAgent agent)
    {
        Vector3 origin = area.transform.position;
        Transform agentTransform = agent.transform;

        agentTransform.position = new Vector3(origin.x, 0, origin.z);
        agentTransform.eulerAngles = Vector3.zero;
    }

    // Update is called once per frame
    void Update()
    {
    }
}




BalanceAcademy 代码:

using MLAgents;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

// Academy subclass for the balancing environment. It declares no members of
// its own; everything it does comes from the Academy base class.
public class BalanceAcademy : Academy
{

}



用于运行训练器的命令:

mlagents-learn config/trainer_config.yaml --run-id=balancetest09 --train

【问题讨论】:

  • 请包含BalanceAcademy的定义
  • @Ruzihm 我看过他们没有在 BalanceAcademy 脚本中添加太多内容的视频。实际上,我不确定我是否将它附加到任何游戏对象上。有必要吗?它应该附加到区域对象吗?无论哪种方式,它现在都已添加到帖子中。谢谢。
  • 感谢您的更新。如果有人回答了您的问题,请不要忘记接受答案!
  • 所有身体部位上都有刚体吗?
  • @Ruzihm 是的,他们这样做了。为什么这很重要?

标签: c# unity3d machine-learning game-physics ml-agent


【解决方案1】:

来自creating a new environment上的文档:

初始化和重置代理

当代理到达目标时,它会将自己标记为完成,其代理重置函数会将目标移动到随机位置。此外,如果代理滚下平台,重置函数会将其放回平台地板上。

要移动目标游戏对象,我们需要一个对其 Transform 的引用(Transform 存储游戏对象在 3D 世界中的位置、方向和缩放)。要获取此引用,请在 RollerAgent 类中添加一个 Transform 类型的公共字段。Unity 中组件的公共字段会显示在检查器(Inspector)窗口中,允许您在 Unity 编辑器中选择用作目标的游戏对象。

要重置代理的速度(以及之后施加力来移动代理),我们需要一个对 Rigidbody 组件的引用。Rigidbody 是 Unity 用于物理模拟的主要元素。(完整说明请参阅 Unity 物理文档。)由于 Rigidbody 组件与我们的代理脚本位于同一个游戏对象上,获取该引用的最佳方法是使用 GameObject.GetComponent<T>(),我们可以在脚本的 Start() 方法中调用它。

到目前为止,我们的 RollerAgent 脚本如下所示:

using System.Collections.Generic;
using UnityEngine;
using MLAgents;

/// <summary>
/// Rolling agent from the tutorial: on reset it recovers from a fall and
/// moves its target to a new random spot.
/// </summary>
public class RollerAgent : Agent
{
    public Transform Target;

    Rigidbody rBody;

    void Start()
    {
        rBody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// If the agent is below y = 0, clears its momentum and puts it back at
    /// (0, 0.5, 0); in every case moves the target to a random x/z in [-4, 4]
    /// at height 0.5.
    /// </summary>
    public override void AgentReset()
    {
        bool fellOff = transform.position.y < 0;
        if (fellOff)
        {
            // The agent fell, so stop it before putting it back.
            rBody.angularVelocity = Vector3.zero;
            rBody.velocity = Vector3.zero;
            transform.position = new Vector3(0, 0.5f, 0);
        }

        // Draw x first, then z, so the random sequence is consumed in the same order.
        float targetX = Random.value * 8 - 4;
        float targetZ = Random.value * 8 - 4;
        Target.position = new Vector3(targetX, 0.5f, targetZ);
    }
}

因此,您应该重写AgentReset 方法,以便重置代理关节的位置。首先,您可以在InitializeAgent 中获取每个关节的旋转和位置,然后在AgentReset 中恢复它们。此外,将刚体的速度和角速度归零。

我在文档或示例中没有看到任何关于在 Update 中调用 Done 的内容,因此可能建议、甚至必须在 AgentAction 中调用它,才能按预期运行。不妨将所有内容移出 Update 并移入 AgentAction。

此外,您可能希望在具有 3 个分量 (xyz) 的特征向量中使用 transform.localEulerAngles,而不是具有 4 个分量 (xyzw) 的 transform.localRotation。否则,您不应省略localRotation 的 w 组件。

总而言之,它可能看起来像这样:

using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// ML-Agents agent driving a pair of legs that should keep its waist above a
/// height threshold. Each decision step it rotates eight joints according to
/// the discrete action vector, rewards or penalises the waist height, and ends
/// the episode once the waist drops to y &lt;= -3.
/// </summary>
public class BalanceAgent : Agent
{
    private BalancingArea area;
    public GameObject waist;
    public GameObject buttR;
    public GameObject buttL;
    public GameObject thighR;
    public GameObject thighL;
    public GameObject legR;
    public GameObject legL;
    public GameObject footR;
    public GameObject footL;
    public GameObject goal;

    private List<GameObject> gameObjectsToReset;
    private List<Rigidbody> rigidbodiesToReset;
    private List<Vector3> initEulers;
    private List<Vector3> initPositions;

    // Joints driven by vectorAction, in action-index order (index 0 -> buttR ... index 7 -> footL).
    private GameObject[] actuatedJoints;

    // Cached so CollectObservations does not call GetComponent on every step.
    private Rigidbody waistBody;

    // private float buttR = 0f;


    /// <summary>
    /// Caches each body part's Rigidbody and records its starting world
    /// position and euler angles so AgentReset can restore them.
    /// </summary>
    public override void InitializeAgent() {
        base.InitializeAgent();
        area = GetComponentInParent<BalancingArea>();

        gameObjectsToReset = new List<GameObject>(new GameObject[]{
                waist, buttR, buttL, thighR, thighL, legR, legL,
                footR, footL});
        actuatedJoints = new GameObject[]{
                buttR, buttL, thighR, thighL, legR, legL, footR, footL};

        rigidbodiesToReset = new List<Rigidbody>(gameObjectsToReset.Count);
        initEulers = new List<Vector3>(gameObjectsToReset.Count);
        initPositions = new List<Vector3>(gameObjectsToReset.Count);

        foreach (GameObject g in gameObjectsToReset) {
            rigidbodiesToReset.Add(g.GetComponent<Rigidbody>());
            initEulers.Add(g.transform.eulerAngles);
            initPositions.Add(g.transform.position);
        }

        // waist is the first entry of gameObjectsToReset.
        waistBody = rigidbodiesToReset[0];
    }

    /// <summary>
    /// Puts every body part back at its recorded start pose and zeroes the
    /// linear and angular velocity of its Rigidbody.
    /// </summary>
    public override void AgentReset() {
        for (int i = 0; i < gameObjectsToReset.Count; i++) {
            Transform t = gameObjectsToReset[i].transform;
            t.position = initPositions[i];
            t.eulerAngles = initEulers[i];

            Rigidbody r = rigidbodiesToReset[i];
            r.velocity = Vector3.zero;
            r.angularVelocity = Vector3.zero;
        }
    }

    /// <summary>
    /// Applies one action per joint, then scores the current waist height.
    /// </summary>
    /// <param name="vectorAction">
    /// One entry per actuated joint (eight entries are read). Each entry is
    /// truncated to an int: 1 rotates by -1, 2 rotates by +1, anything else
    /// leaves the joint alone.
    /// </param>
    public override void AgentAction(float[] vectorAction) {
        for (int i = 0; i < actuatedJoints.Length; i++) {
            // The direction is passed as the Y argument of Transform.Rotate.
            actuatedJoints[i].transform.Rotate(0, DirectionFor(vectorAction[i]), 0);
        }

        //buttR = vectorAction[0]; //Right or none
        //if (buttR == 2) buttR = -1f; //Left

        float waistHeight = waist.transform.position.y;

        if (waistHeight > -1.3) {
            AddReward(.1f);
        }
        else {
            AddReward(-.02f);
        }

        if (waistHeight <= -3) {
            // Record the fall penalty before ending the episode so it is
            // credited to the episode that actually fell.
            AddReward(-.1f);
            Done();
        }
    }

    /// <summary>
    /// Maps a raw action value to a rotation direction: 1 -> -1, 2 -> +1,
    /// every other value (including 0 and 3) -> 0.
    /// </summary>
    private static int DirectionFor(float action) {
        switch ((int)action) {
            case 1:
                return -1;
            case 2:
                return 1;
            default:
                return 0;
        }
    }

    /// <summary>
    /// Observes one local euler angle per body part (x for the two butt joints,
    /// y for the rest), the waist Rigidbody's freezeRotation flag, and the
    /// waist's world position.
    /// </summary>
    public override void CollectObservations() {
        AddVectorObs(waist.transform.localEulerAngles.y);
        AddVectorObs(buttR.transform.localEulerAngles.x);
        AddVectorObs(buttL.transform.localEulerAngles.x);
        AddVectorObs(thighR.transform.localEulerAngles.y);
        AddVectorObs(thighL.transform.localEulerAngles.y);
        AddVectorObs(legR.transform.localEulerAngles.y);
        AddVectorObs(legL.transform.localEulerAngles.y);
        AddVectorObs(footR.transform.localEulerAngles.y);
        AddVectorObs(footL.transform.localEulerAngles.y);

        AddVectorObs(waistBody.freezeRotation);

        AddVectorObs(waist.transform.position);
    }
}

最后,确保将 BalanceAgent 的 Max Step 设置为足够大以查看代理是否会失败,对于初学者来说可能是 500 或 1000。

【讨论】:

  • 所以现在他正在重置(谢谢),但它似乎每半秒发生一次!我希望他重置并开始新的“情节”的唯一一次是当他的腰部在他的 y 位置低于 -3 时。我认为 Done() 方法用于此目的。谢谢。
  • @JadenWilliams 我添加了关于重置物理的部分。注意新增的 List 字段,以及 InitializeAgent 和 AgentReset 中的新部分。
  • @JadenWilliams 我的建议没有改变什么吗?
  • @JadenWilliams 刚刚修正了一个错字,rigidbodysToReset -> rigidbodiesToReset。我不知道如何重现你的场景,所以我只能猜测:半秒正是腰部降到 y=-3 以下所需的时间,这可能是因为上一回合的向下速度没有被重置。
  • 如您所说,Done 重置代理。尝试将代码移出Update 并移入AgentAction
猜你喜欢
  • 2020-04-18
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2011-02-26
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多