比较提交
合并到: unity-tech-cn:main
unity-tech-cn:/main
unity-tech-cn:/develop-generalizationTraining-TrainerController
unity-tech-cn:/tag-0.2.0
unity-tech-cn:/tag-0.2.1
unity-tech-cn:/tag-0.2.1a
unity-tech-cn:/tag-0.2.1c
unity-tech-cn:/tag-0.2.1d
unity-tech-cn:/hotfix-v0.9.2a
unity-tech-cn:/develop-gpu-test
unity-tech-cn:/0.10.1
unity-tech-cn:/develop-pyinstaller
unity-tech-cn:/develop-horovod
unity-tech-cn:/PhysXArticulations20201
unity-tech-cn:/importdocfix
unity-tech-cn:/develop-resizetexture
unity-tech-cn:/hh-develop-walljump_bugfixes
unity-tech-cn:/develop-walljump-fix-sac
unity-tech-cn:/hh-develop-walljump_rnd
unity-tech-cn:/tag-0.11.0.dev0
unity-tech-cn:/develop-pytorch
unity-tech-cn:/tag-0.11.0.dev2
unity-tech-cn:/develop-newnormalization
unity-tech-cn:/tag-0.11.0.dev3
unity-tech-cn:/develop
unity-tech-cn:/release-0.12.0
unity-tech-cn:/tag-0.12.0-dev
unity-tech-cn:/tag-0.12.0.dev0
unity-tech-cn:/tag-0.12.1
unity-tech-cn:/2D-explorations
unity-tech-cn:/asymm-envs
unity-tech-cn:/tag-0.12.1.dev0
unity-tech-cn:/2D-exploration-raycast
unity-tech-cn:/tag-0.12.1.dev1
unity-tech-cn:/release-0.13.0
unity-tech-cn:/release-0.13.1
unity-tech-cn:/plugin-proof-of-concept
unity-tech-cn:/release-0.14.0
unity-tech-cn:/hotfix-bump-version-master
unity-tech-cn:/soccer-fives
unity-tech-cn:/release-0.14.1
unity-tech-cn:/bug-failed-api-check
unity-tech-cn:/test-recurrent-gail
unity-tech-cn:/hh-add-icons
unity-tech-cn:/release-0.15.0
unity-tech-cn:/release-0.15.1
unity-tech-cn:/hh-develop-all-posed-characters
unity-tech-cn:/internal-policy-ghost
unity-tech-cn:/distributed-training
unity-tech-cn:/hh-develop-improve_tennis
unity-tech-cn:/test-tf-ver
unity-tech-cn:/release_1_branch
unity-tech-cn:/tennis-time-horizon
unity-tech-cn:/whitepaper-experiments
unity-tech-cn:/r2v-yamato-linux
unity-tech-cn:/docs-update
unity-tech-cn:/release_2_branch
unity-tech-cn:/exp-mede
unity-tech-cn:/sensitivity
unity-tech-cn:/release_2_verified_load_fix
unity-tech-cn:/test-sampler
unity-tech-cn:/release_2_verified
unity-tech-cn:/hh-develop-ragdoll-testing
unity-tech-cn:/origin-develop-taggedobservations
unity-tech-cn:/MLA-1734-demo-provider
unity-tech-cn:/sampler-refactor-copy
unity-tech-cn:/PhysXArticulations20201Package
unity-tech-cn:/tag-com.unity.ml-agents_1.0.8
unity-tech-cn:/release_3_branch
unity-tech-cn:/github-actions
unity-tech-cn:/release_3_distributed
unity-tech-cn:/fix-batch-tennis
unity-tech-cn:/distributed-ppo-sac
unity-tech-cn:/gridworld-custom-obs
unity-tech-cn:/hw20-segmentation
unity-tech-cn:/hh-develop-gamedev-demo
unity-tech-cn:/active-variablespeed
unity-tech-cn:/release_4_branch
unity-tech-cn:/fix-env-step-loop
unity-tech-cn:/release_5_branch
unity-tech-cn:/fix-walker
unity-tech-cn:/release_6_branch
unity-tech-cn:/hh-32-observation-crawler
unity-tech-cn:/trainer-plugin
unity-tech-cn:/hh-develop-max-steps-demo-recorder
unity-tech-cn:/hh-develop-loco-walker-variable-speed
unity-tech-cn:/exp-0002
unity-tech-cn:/experiment-less-max-step
unity-tech-cn:/hh-develop-hallway-wall-mesh-fix
unity-tech-cn:/release_7_branch
unity-tech-cn:/exp-vince
unity-tech-cn:/hh-develop-gridsensor-tests
unity-tech-cn:/tag-release_8_test0
unity-tech-cn:/tag-release_8_test1
unity-tech-cn:/release_8_branch
unity-tech-cn:/docfix-end-episode
unity-tech-cn:/release_9_branch
unity-tech-cn:/hybrid-action-rewardsignals
unity-tech-cn:/MLA-462-yamato-win
unity-tech-cn:/exp-alternate-atten
unity-tech-cn:/hh-develop-fps_game_project
unity-tech-cn:/fix-conflict-base-env
unity-tech-cn:/release_10_branch
unity-tech-cn:/exp-bullet-hell-trainer
unity-tech-cn:/ai-summit-exp
unity-tech-cn:/comms-grad
unity-tech-cn:/walljump-pushblock
unity-tech-cn:/goal-conditioning
unity-tech-cn:/release_11_branch
unity-tech-cn:/hh-develop-water-balloon-fight
unity-tech-cn:/gc-hyper
unity-tech-cn:/layernorm
unity-tech-cn:/yamato-linux-debug-venv
unity-tech-cn:/soccer-comms
unity-tech-cn:/hh-develop-pushblockcollab
unity-tech-cn:/release_12_branch
unity-tech-cn:/fix-get-step-sp-curr
unity-tech-cn:/continuous-comms
unity-tech-cn:/no-comms
unity-tech-cn:/hh-develop-zombiepushblock
unity-tech-cn:/hypernetwork
unity-tech-cn:/revert-4859-develop-update-readme
unity-tech-cn:/sequencer-env-attention
unity-tech-cn:/hh-develop-variableobs
unity-tech-cn:/exp-tanh
unity-tech-cn:/reward-dist
unity-tech-cn:/exp-weight-decay
unity-tech-cn:/exp-robot
unity-tech-cn:/bullet-hell-barracuda-test-1.3.1
unity-tech-cn:/release_13_branch
unity-tech-cn:/release_14_branch
unity-tech-cn:/exp-clipped-gaussian-entropy
unity-tech-cn:/tic-tac-toe
unity-tech-cn:/hh-develop-dodgeball
unity-tech-cn:/repro-vis-obs-perf
unity-tech-cn:/v2-staging-rebase
unity-tech-cn:/release_15_branch
unity-tech-cn:/release_15_removeendepisode
unity-tech-cn:/release_16_branch
unity-tech-cn:/release_16_fix_gridsensor
unity-tech-cn:/ai-hw-2021
unity-tech-cn:/check-for-ModelOverriders
unity-tech-cn:/fix-grid-obs-shape-init
unity-tech-cn:/fix-gym-needs-reset
unity-tech-cn:/fix-resume-imi
unity-tech-cn:/release_17_branch
unity-tech-cn:/release_17_branch_gpu_test
unity-tech-cn:/colab-links
unity-tech-cn:/exp-continuous-div
unity-tech-cn:/release_17_branch_gpu_2
unity-tech-cn:/exp-diverse-behavior
unity-tech-cn:/grid-onehot-extra-dim-empty
unity-tech-cn:/2.0-verified
unity-tech-cn:/faster-entropy-coeficient-convergence
unity-tech-cn:/pre-r18-update-changelog
unity-tech-cn:/release_18_branch
unity-tech-cn:/main/tracking
unity-tech-cn:/main/reward-providers
unity-tech-cn:/main/project-upgrade
unity-tech-cn:/main/limitation-docs
unity-tech-cn:/develop/nomaxstep-test
unity-tech-cn:/develop/tf2.0
unity-tech-cn:/develop/tanhsquash
unity-tech-cn:/develop/magic-string
unity-tech-cn:/develop/trainerinterface
unity-tech-cn:/develop/separatevalue
unity-tech-cn:/develop/nopreviousactions
unity-tech-cn:/develop/reenablerepeatactions
unity-tech-cn:/develop/0memories
unity-tech-cn:/develop/fixmemoryleak
unity-tech-cn:/develop/reducewalljump
unity-tech-cn:/develop/removeactionholder-onehot
unity-tech-cn:/develop/canonicalize-quaternions
unity-tech-cn:/develop/self-playassym
unity-tech-cn:/develop/demo-load-seek
unity-tech-cn:/develop/progress-bar
unity-tech-cn:/develop/sac-apex
unity-tech-cn:/develop/cubewars
unity-tech-cn:/develop/add-fire
unity-tech-cn:/develop/gym-wrapper
unity-tech-cn:/develop/mm-docs-main-readme
unity-tech-cn:/develop/mm-docs-overview
unity-tech-cn:/develop/no-threading
unity-tech-cn:/develop/dockerfile
unity-tech-cn:/develop/model-store
unity-tech-cn:/develop/checkout-conversion-rebase
unity-tech-cn:/develop/model-transfer
unity-tech-cn:/develop/bisim-review
unity-tech-cn:/develop/taggedobservations
unity-tech-cn:/develop/transfer-bisim
unity-tech-cn:/develop/bisim-sac-transfer
unity-tech-cn:/develop/basketball
unity-tech-cn:/develop/torchmodules
unity-tech-cn:/develop/fixmarkdown
unity-tech-cn:/develop/shortenstrikervsgoalie
unity-tech-cn:/develop/shortengoalie
unity-tech-cn:/develop/torch-save-rp
unity-tech-cn:/develop/torch-to-np
unity-tech-cn:/develop/torch-omp-no-thread
unity-tech-cn:/develop/actionmodel-csharp
unity-tech-cn:/develop/torch-extra
unity-tech-cn:/develop/restructure-torch-networks
unity-tech-cn:/develop/jit
unity-tech-cn:/develop/adjust-cpu-settings-experiment
unity-tech-cn:/develop/torch-sac-threading
unity-tech-cn:/develop/wb
unity-tech-cn:/develop/amrl
unity-tech-cn:/develop/memorydump
unity-tech-cn:/develop/permutepytorch
unity-tech-cn:/develop/sac-targetq
unity-tech-cn:/develop/actions-out
unity-tech-cn:/develop/reshapeonnxmemories
unity-tech-cn:/develop/crawlergail
unity-tech-cn:/develop/debugtorchfood
unity-tech-cn:/develop/hybrid-actions
unity-tech-cn:/develop/bullet-hell
unity-tech-cn:/develop/action-spec-gym
unity-tech-cn:/develop/battlefoodcollector
unity-tech-cn:/develop/use-action-buffers
unity-tech-cn:/develop/hardswish
unity-tech-cn:/develop/leakyrelu
unity-tech-cn:/develop/torch-clip-scale
unity-tech-cn:/develop/contentropy
unity-tech-cn:/develop/manch
unity-tech-cn:/develop/torchcrawlerdebug
unity-tech-cn:/develop/fix-nan
unity-tech-cn:/develop/multitype-buffer
unity-tech-cn:/develop/windows-delay
unity-tech-cn:/develop/torch-tanh
unity-tech-cn:/develop/gail-norm
unity-tech-cn:/develop/multiprocess
unity-tech-cn:/develop/unified-obs
unity-tech-cn:/develop/rm-rf-new-models
unity-tech-cn:/develop/skipcritic
unity-tech-cn:/develop/centralizedcritic
unity-tech-cn:/develop/dodgeball-tests
unity-tech-cn:/develop/cc-teammanager
unity-tech-cn:/develop/weight-decay
unity-tech-cn:/develop/singular-embeddings
unity-tech-cn:/develop/zombieteammanager
unity-tech-cn:/develop/superpush
unity-tech-cn:/develop/teammanager
unity-tech-cn:/develop/zombie-exp
unity-tech-cn:/develop/update-readme
unity-tech-cn:/develop/readme-fix
unity-tech-cn:/develop/coma-noact
unity-tech-cn:/develop/coma-withq
unity-tech-cn:/develop/coma2
unity-tech-cn:/develop/action-slice
unity-tech-cn:/develop/gru
unity-tech-cn:/develop/critic-op-lstm-currentmem
unity-tech-cn:/develop/decaygail
unity-tech-cn:/develop/gail-srl-hack
unity-tech-cn:/develop/rear-pad
unity-tech-cn:/develop/mm-copyright-dates
unity-tech-cn:/develop/dodgeball-raycasts
unity-tech-cn:/develop/collab-envs-exp-ervin
unity-tech-cn:/develop/pushcollabonly
unity-tech-cn:/develop/sample-curation
unity-tech-cn:/develop/soccer-groupman
unity-tech-cn:/develop/input-actuator-tanks
unity-tech-cn:/develop/validate-release-fix
unity-tech-cn:/develop/new-console-log
unity-tech-cn:/develop/lex-walker-model
unity-tech-cn:/develop/lstm-burnin
unity-tech-cn:/develop/grid-vaiable-names
unity-tech-cn:/develop/fix-attn-embedding
unity-tech-cn:/develop/api-documentation-update-some-fixes
unity-tech-cn:/develop/update-grpc
unity-tech-cn:/develop/grid-rootref-debug
unity-tech-cn:/develop/pbcollab-rays
unity-tech-cn:/develop/2.0-verified-pre
unity-tech-cn:/develop/parameterizedenvs
unity-tech-cn:/develop/custom-ray-sensor
unity-tech-cn:/develop/mm-add-v2blog
unity-tech-cn:/develop/custom-raycast
unity-tech-cn:/develop/area-manager
unity-tech-cn:/develop/remove-unecessary-lr
unity-tech-cn:/develop/use-base-env-in-learn
unity-tech-cn:/soccer-fives/multiagent
unity-tech-cn:/develop/cubewars/splashdamage
unity-tech-cn:/develop/add-fire/exp
unity-tech-cn:/develop/add-fire/jit
unity-tech-cn:/develop/add-fire/speedtest
unity-tech-cn:/develop/add-fire/bc
unity-tech-cn:/develop/add-fire/ckpt-2
unity-tech-cn:/develop/add-fire/normalize-context
unity-tech-cn:/develop/add-fire/components-dir
unity-tech-cn:/develop/add-fire/halfentropy
unity-tech-cn:/develop/add-fire/memoryclass
unity-tech-cn:/develop/add-fire/categoricaldist
unity-tech-cn:/develop/add-fire/mm
unity-tech-cn:/develop/add-fire/sac-lst
unity-tech-cn:/develop/add-fire/mm3
unity-tech-cn:/develop/add-fire/continuous
unity-tech-cn:/develop/add-fire/ghost
unity-tech-cn:/develop/add-fire/policy-tests
unity-tech-cn:/develop/add-fire/export-discrete
unity-tech-cn:/develop/add-fire/test-simple-rl-fix-resnet
unity-tech-cn:/develop/add-fire/remove-currdoc
unity-tech-cn:/develop/add-fire/clean2
unity-tech-cn:/develop/add-fire/doc-cleanups
unity-tech-cn:/develop/add-fire/changelog
unity-tech-cn:/develop/add-fire/mm2
unity-tech-cn:/develop/model-transfer/add-physics
unity-tech-cn:/develop/model-transfer/train
unity-tech-cn:/develop/jit/experiments
unity-tech-cn:/exp-vince/sep30-2020
unity-tech-cn:/hh-develop-gridsensor-tests/static
unity-tech-cn:/develop/hybrid-actions/distlist
unity-tech-cn:/develop/bullet-hell/buffer
unity-tech-cn:/goal-conditioning/new
unity-tech-cn:/goal-conditioning/sensors-2
unity-tech-cn:/goal-conditioning/sensors-3-pytest-fix
unity-tech-cn:/goal-conditioning/grid-world
unity-tech-cn:/soccer-comms/disc
unity-tech-cn:/develop/centralizedcritic/counterfact
unity-tech-cn:/develop/centralizedcritic/mm
unity-tech-cn:/develop/centralizedcritic/nonego
unity-tech-cn:/develop/zombieteammanager/disableagent
unity-tech-cn:/develop/zombieteammanager/killfirst
unity-tech-cn:/develop/superpush/int
unity-tech-cn:/develop/superpush/branch-cleanup
unity-tech-cn:/develop/teammanager/int
unity-tech-cn:/develop/teammanager/cubewar-nocycle
unity-tech-cn:/develop/teammanager/cubewars
unity-tech-cn:/develop/superpush/int/hunter
unity-tech-cn:/goal-conditioning/new/allo-crawler
unity-tech-cn:/develop/coma2/clip
unity-tech-cn:/develop/coma2/singlenetwork
unity-tech-cn:/develop/coma2/samenet
unity-tech-cn:/develop/coma2/fixgroup
unity-tech-cn:/develop/coma2/samenet/sum
unity-tech-cn:/hh-develop-dodgeball/goy-input
unity-tech-cn:/develop/soccer-groupman/mod
unity-tech-cn:/develop/soccer-groupman/mod/hunter
unity-tech-cn:/develop/soccer-groupman/mod/hunter/cine
unity-tech-cn:/ai-hw-2021/tensor-applier
拉取从: unity-tech-cn:develop/add-fire/changelog
unity-tech-cn:/main
unity-tech-cn:/develop-generalizationTraining-TrainerController
unity-tech-cn:/tag-0.2.0
unity-tech-cn:/tag-0.2.1
unity-tech-cn:/tag-0.2.1a
unity-tech-cn:/tag-0.2.1c
unity-tech-cn:/tag-0.2.1d
unity-tech-cn:/hotfix-v0.9.2a
unity-tech-cn:/develop-gpu-test
unity-tech-cn:/0.10.1
unity-tech-cn:/develop-pyinstaller
unity-tech-cn:/develop-horovod
unity-tech-cn:/PhysXArticulations20201
unity-tech-cn:/importdocfix
unity-tech-cn:/develop-resizetexture
unity-tech-cn:/hh-develop-walljump_bugfixes
unity-tech-cn:/develop-walljump-fix-sac
unity-tech-cn:/hh-develop-walljump_rnd
unity-tech-cn:/tag-0.11.0.dev0
unity-tech-cn:/develop-pytorch
unity-tech-cn:/tag-0.11.0.dev2
unity-tech-cn:/develop-newnormalization
unity-tech-cn:/tag-0.11.0.dev3
unity-tech-cn:/develop
unity-tech-cn:/release-0.12.0
unity-tech-cn:/tag-0.12.0-dev
unity-tech-cn:/tag-0.12.0.dev0
unity-tech-cn:/tag-0.12.1
unity-tech-cn:/2D-explorations
unity-tech-cn:/asymm-envs
unity-tech-cn:/tag-0.12.1.dev0
unity-tech-cn:/2D-exploration-raycast
unity-tech-cn:/tag-0.12.1.dev1
unity-tech-cn:/release-0.13.0
unity-tech-cn:/release-0.13.1
unity-tech-cn:/plugin-proof-of-concept
unity-tech-cn:/release-0.14.0
unity-tech-cn:/hotfix-bump-version-master
unity-tech-cn:/soccer-fives
unity-tech-cn:/release-0.14.1
unity-tech-cn:/bug-failed-api-check
unity-tech-cn:/test-recurrent-gail
unity-tech-cn:/hh-add-icons
unity-tech-cn:/release-0.15.0
unity-tech-cn:/release-0.15.1
unity-tech-cn:/hh-develop-all-posed-characters
unity-tech-cn:/internal-policy-ghost
unity-tech-cn:/distributed-training
unity-tech-cn:/hh-develop-improve_tennis
unity-tech-cn:/test-tf-ver
unity-tech-cn:/release_1_branch
unity-tech-cn:/tennis-time-horizon
unity-tech-cn:/whitepaper-experiments
unity-tech-cn:/r2v-yamato-linux
unity-tech-cn:/docs-update
unity-tech-cn:/release_2_branch
unity-tech-cn:/exp-mede
unity-tech-cn:/sensitivity
unity-tech-cn:/release_2_verified_load_fix
unity-tech-cn:/test-sampler
unity-tech-cn:/release_2_verified
unity-tech-cn:/hh-develop-ragdoll-testing
unity-tech-cn:/origin-develop-taggedobservations
unity-tech-cn:/MLA-1734-demo-provider
unity-tech-cn:/sampler-refactor-copy
unity-tech-cn:/PhysXArticulations20201Package
unity-tech-cn:/tag-com.unity.ml-agents_1.0.8
unity-tech-cn:/release_3_branch
unity-tech-cn:/github-actions
unity-tech-cn:/release_3_distributed
unity-tech-cn:/fix-batch-tennis
unity-tech-cn:/distributed-ppo-sac
unity-tech-cn:/gridworld-custom-obs
unity-tech-cn:/hw20-segmentation
unity-tech-cn:/hh-develop-gamedev-demo
unity-tech-cn:/active-variablespeed
unity-tech-cn:/release_4_branch
unity-tech-cn:/fix-env-step-loop
unity-tech-cn:/release_5_branch
unity-tech-cn:/fix-walker
unity-tech-cn:/release_6_branch
unity-tech-cn:/hh-32-observation-crawler
unity-tech-cn:/trainer-plugin
unity-tech-cn:/hh-develop-max-steps-demo-recorder
unity-tech-cn:/hh-develop-loco-walker-variable-speed
unity-tech-cn:/exp-0002
unity-tech-cn:/experiment-less-max-step
unity-tech-cn:/hh-develop-hallway-wall-mesh-fix
unity-tech-cn:/release_7_branch
unity-tech-cn:/exp-vince
unity-tech-cn:/hh-develop-gridsensor-tests
unity-tech-cn:/tag-release_8_test0
unity-tech-cn:/tag-release_8_test1
unity-tech-cn:/release_8_branch
unity-tech-cn:/docfix-end-episode
unity-tech-cn:/release_9_branch
unity-tech-cn:/hybrid-action-rewardsignals
unity-tech-cn:/MLA-462-yamato-win
unity-tech-cn:/exp-alternate-atten
unity-tech-cn:/hh-develop-fps_game_project
unity-tech-cn:/fix-conflict-base-env
unity-tech-cn:/release_10_branch
unity-tech-cn:/exp-bullet-hell-trainer
unity-tech-cn:/ai-summit-exp
unity-tech-cn:/comms-grad
unity-tech-cn:/walljump-pushblock
unity-tech-cn:/goal-conditioning
unity-tech-cn:/release_11_branch
unity-tech-cn:/hh-develop-water-balloon-fight
unity-tech-cn:/gc-hyper
unity-tech-cn:/layernorm
unity-tech-cn:/yamato-linux-debug-venv
unity-tech-cn:/soccer-comms
unity-tech-cn:/hh-develop-pushblockcollab
unity-tech-cn:/release_12_branch
unity-tech-cn:/fix-get-step-sp-curr
unity-tech-cn:/continuous-comms
unity-tech-cn:/no-comms
unity-tech-cn:/hh-develop-zombiepushblock
unity-tech-cn:/hypernetwork
unity-tech-cn:/revert-4859-develop-update-readme
unity-tech-cn:/sequencer-env-attention
unity-tech-cn:/hh-develop-variableobs
unity-tech-cn:/exp-tanh
unity-tech-cn:/reward-dist
unity-tech-cn:/exp-weight-decay
unity-tech-cn:/exp-robot
unity-tech-cn:/bullet-hell-barracuda-test-1.3.1
unity-tech-cn:/release_13_branch
unity-tech-cn:/release_14_branch
unity-tech-cn:/exp-clipped-gaussian-entropy
unity-tech-cn:/tic-tac-toe
unity-tech-cn:/hh-develop-dodgeball
unity-tech-cn:/repro-vis-obs-perf
unity-tech-cn:/v2-staging-rebase
unity-tech-cn:/release_15_branch
unity-tech-cn:/release_15_removeendepisode
unity-tech-cn:/release_16_branch
unity-tech-cn:/release_16_fix_gridsensor
unity-tech-cn:/ai-hw-2021
unity-tech-cn:/check-for-ModelOverriders
unity-tech-cn:/fix-grid-obs-shape-init
unity-tech-cn:/fix-gym-needs-reset
unity-tech-cn:/fix-resume-imi
unity-tech-cn:/release_17_branch
unity-tech-cn:/release_17_branch_gpu_test
unity-tech-cn:/colab-links
unity-tech-cn:/exp-continuous-div
unity-tech-cn:/release_17_branch_gpu_2
unity-tech-cn:/exp-diverse-behavior
unity-tech-cn:/grid-onehot-extra-dim-empty
unity-tech-cn:/2.0-verified
unity-tech-cn:/faster-entropy-coeficient-convergence
unity-tech-cn:/pre-r18-update-changelog
unity-tech-cn:/release_18_branch
unity-tech-cn:/main/tracking
unity-tech-cn:/main/reward-providers
unity-tech-cn:/main/project-upgrade
unity-tech-cn:/main/limitation-docs
unity-tech-cn:/develop/nomaxstep-test
unity-tech-cn:/develop/tf2.0
unity-tech-cn:/develop/tanhsquash
unity-tech-cn:/develop/magic-string
unity-tech-cn:/develop/trainerinterface
unity-tech-cn:/develop/separatevalue
unity-tech-cn:/develop/nopreviousactions
unity-tech-cn:/develop/reenablerepeatactions
unity-tech-cn:/develop/0memories
unity-tech-cn:/develop/fixmemoryleak
unity-tech-cn:/develop/reducewalljump
unity-tech-cn:/develop/removeactionholder-onehot
unity-tech-cn:/develop/canonicalize-quaternions
unity-tech-cn:/develop/self-playassym
unity-tech-cn:/develop/demo-load-seek
unity-tech-cn:/develop/progress-bar
unity-tech-cn:/develop/sac-apex
unity-tech-cn:/develop/cubewars
unity-tech-cn:/develop/add-fire
unity-tech-cn:/develop/gym-wrapper
unity-tech-cn:/develop/mm-docs-main-readme
unity-tech-cn:/develop/mm-docs-overview
unity-tech-cn:/develop/no-threading
unity-tech-cn:/develop/dockerfile
unity-tech-cn:/develop/model-store
unity-tech-cn:/develop/checkout-conversion-rebase
unity-tech-cn:/develop/model-transfer
unity-tech-cn:/develop/bisim-review
unity-tech-cn:/develop/taggedobservations
unity-tech-cn:/develop/transfer-bisim
unity-tech-cn:/develop/bisim-sac-transfer
unity-tech-cn:/develop/basketball
unity-tech-cn:/develop/torchmodules
unity-tech-cn:/develop/fixmarkdown
unity-tech-cn:/develop/shortenstrikervsgoalie
unity-tech-cn:/develop/shortengoalie
unity-tech-cn:/develop/torch-save-rp
unity-tech-cn:/develop/torch-to-np
unity-tech-cn:/develop/torch-omp-no-thread
unity-tech-cn:/develop/actionmodel-csharp
unity-tech-cn:/develop/torch-extra
unity-tech-cn:/develop/restructure-torch-networks
unity-tech-cn:/develop/jit
unity-tech-cn:/develop/adjust-cpu-settings-experiment
unity-tech-cn:/develop/torch-sac-threading
unity-tech-cn:/develop/wb
unity-tech-cn:/develop/amrl
unity-tech-cn:/develop/memorydump
unity-tech-cn:/develop/permutepytorch
unity-tech-cn:/develop/sac-targetq
unity-tech-cn:/develop/actions-out
unity-tech-cn:/develop/reshapeonnxmemories
unity-tech-cn:/develop/crawlergail
unity-tech-cn:/develop/debugtorchfood
unity-tech-cn:/develop/hybrid-actions
unity-tech-cn:/develop/bullet-hell
unity-tech-cn:/develop/action-spec-gym
unity-tech-cn:/develop/battlefoodcollector
unity-tech-cn:/develop/use-action-buffers
unity-tech-cn:/develop/hardswish
unity-tech-cn:/develop/leakyrelu
unity-tech-cn:/develop/torch-clip-scale
unity-tech-cn:/develop/contentropy
unity-tech-cn:/develop/manch
unity-tech-cn:/develop/torchcrawlerdebug
unity-tech-cn:/develop/fix-nan
unity-tech-cn:/develop/multitype-buffer
unity-tech-cn:/develop/windows-delay
unity-tech-cn:/develop/torch-tanh
unity-tech-cn:/develop/gail-norm
unity-tech-cn:/develop/multiprocess
unity-tech-cn:/develop/unified-obs
unity-tech-cn:/develop/rm-rf-new-models
unity-tech-cn:/develop/skipcritic
unity-tech-cn:/develop/centralizedcritic
unity-tech-cn:/develop/dodgeball-tests
unity-tech-cn:/develop/cc-teammanager
unity-tech-cn:/develop/weight-decay
unity-tech-cn:/develop/singular-embeddings
unity-tech-cn:/develop/zombieteammanager
unity-tech-cn:/develop/superpush
unity-tech-cn:/develop/teammanager
unity-tech-cn:/develop/zombie-exp
unity-tech-cn:/develop/update-readme
unity-tech-cn:/develop/readme-fix
unity-tech-cn:/develop/coma-noact
unity-tech-cn:/develop/coma-withq
unity-tech-cn:/develop/coma2
unity-tech-cn:/develop/action-slice
unity-tech-cn:/develop/gru
unity-tech-cn:/develop/critic-op-lstm-currentmem
unity-tech-cn:/develop/decaygail
unity-tech-cn:/develop/gail-srl-hack
unity-tech-cn:/develop/rear-pad
unity-tech-cn:/develop/mm-copyright-dates
unity-tech-cn:/develop/dodgeball-raycasts
unity-tech-cn:/develop/collab-envs-exp-ervin
unity-tech-cn:/develop/pushcollabonly
unity-tech-cn:/develop/sample-curation
unity-tech-cn:/develop/soccer-groupman
unity-tech-cn:/develop/input-actuator-tanks
unity-tech-cn:/develop/validate-release-fix
unity-tech-cn:/develop/new-console-log
unity-tech-cn:/develop/lex-walker-model
unity-tech-cn:/develop/lstm-burnin
unity-tech-cn:/develop/grid-vaiable-names
unity-tech-cn:/develop/fix-attn-embedding
unity-tech-cn:/develop/api-documentation-update-some-fixes
unity-tech-cn:/develop/update-grpc
unity-tech-cn:/develop/grid-rootref-debug
unity-tech-cn:/develop/pbcollab-rays
unity-tech-cn:/develop/2.0-verified-pre
unity-tech-cn:/develop/parameterizedenvs
unity-tech-cn:/develop/custom-ray-sensor
unity-tech-cn:/develop/mm-add-v2blog
unity-tech-cn:/develop/custom-raycast
unity-tech-cn:/develop/area-manager
unity-tech-cn:/develop/remove-unecessary-lr
unity-tech-cn:/develop/use-base-env-in-learn
unity-tech-cn:/soccer-fives/multiagent
unity-tech-cn:/develop/cubewars/splashdamage
unity-tech-cn:/develop/add-fire/exp
unity-tech-cn:/develop/add-fire/jit
unity-tech-cn:/develop/add-fire/speedtest
unity-tech-cn:/develop/add-fire/bc
unity-tech-cn:/develop/add-fire/ckpt-2
unity-tech-cn:/develop/add-fire/normalize-context
unity-tech-cn:/develop/add-fire/components-dir
unity-tech-cn:/develop/add-fire/halfentropy
unity-tech-cn:/develop/add-fire/memoryclass
unity-tech-cn:/develop/add-fire/categoricaldist
unity-tech-cn:/develop/add-fire/mm
unity-tech-cn:/develop/add-fire/sac-lst
unity-tech-cn:/develop/add-fire/mm3
unity-tech-cn:/develop/add-fire/continuous
unity-tech-cn:/develop/add-fire/ghost
unity-tech-cn:/develop/add-fire/policy-tests
unity-tech-cn:/develop/add-fire/export-discrete
unity-tech-cn:/develop/add-fire/test-simple-rl-fix-resnet
unity-tech-cn:/develop/add-fire/remove-currdoc
unity-tech-cn:/develop/add-fire/clean2
unity-tech-cn:/develop/add-fire/doc-cleanups
unity-tech-cn:/develop/add-fire/changelog
unity-tech-cn:/develop/add-fire/mm2
unity-tech-cn:/develop/model-transfer/add-physics
unity-tech-cn:/develop/model-transfer/train
unity-tech-cn:/develop/jit/experiments
unity-tech-cn:/exp-vince/sep30-2020
unity-tech-cn:/hh-develop-gridsensor-tests/static
unity-tech-cn:/develop/hybrid-actions/distlist
unity-tech-cn:/develop/bullet-hell/buffer
unity-tech-cn:/goal-conditioning/new
unity-tech-cn:/goal-conditioning/sensors-2
unity-tech-cn:/goal-conditioning/sensors-3-pytest-fix
unity-tech-cn:/goal-conditioning/grid-world
unity-tech-cn:/soccer-comms/disc
unity-tech-cn:/develop/centralizedcritic/counterfact
unity-tech-cn:/develop/centralizedcritic/mm
unity-tech-cn:/develop/centralizedcritic/nonego
unity-tech-cn:/develop/zombieteammanager/disableagent
unity-tech-cn:/develop/zombieteammanager/killfirst
unity-tech-cn:/develop/superpush/int
unity-tech-cn:/develop/superpush/branch-cleanup
unity-tech-cn:/develop/teammanager/int
unity-tech-cn:/develop/teammanager/cubewar-nocycle
unity-tech-cn:/develop/teammanager/cubewars
unity-tech-cn:/develop/superpush/int/hunter
unity-tech-cn:/goal-conditioning/new/allo-crawler
unity-tech-cn:/develop/coma2/clip
unity-tech-cn:/develop/coma2/singlenetwork
unity-tech-cn:/develop/coma2/samenet
unity-tech-cn:/develop/coma2/fixgroup
unity-tech-cn:/develop/coma2/samenet/sum
unity-tech-cn:/hh-develop-dodgeball/goy-input
unity-tech-cn:/develop/soccer-groupman/mod
unity-tech-cn:/develop/soccer-groupman/mod/hunter
unity-tech-cn:/develop/soccer-groupman/mod/hunter/cine
unity-tech-cn:/ai-hw-2021/tensor-applier
此合并请求有变更与目标分支冲突。
/com.unity.ml-agents/CHANGELOG.md
/docs/Learning-Environment-Examples.md
/ml-agents/mlagents/trainers/cli_utils.py
/ml-agents/mlagents/trainers/settings.py
/ml-agents/mlagents/trainers/trainer_controller.py
/ml-agents/mlagents/trainers/ghost/trainer.py
/ml-agents/mlagents/trainers/ppo/trainer.py
/ml-agents/mlagents/trainers/sac/trainer.py
/ml-agents/mlagents/trainers/trainer/trainer.py
/ml-agents/mlagents/trainers/trainer/rl_trainer.py
/ml-agents/mlagents/trainers/tests/torch/test_utils.py
/ml-agents/mlagents/trainers/tests/torch/test_layers.py
/ml-agents/mlagents/trainers/tests/torch/test_networks.py
/ml-agents/mlagents/trainers/tests/test_rl_trainer.py
/ml-agents/mlagents/trainers/buffer.py
/ml-agents/mlagents/trainers/torch/layers.py
/ml-agents/mlagents/trainers/torch/utils.py
/ml-agents/mlagents/trainers/torch/networks.py
/ml-agents/mlagents/trainers/torch/encoders.py
/ml-agents/mlagents/trainers/saver/saver.py
/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
/ml-agents/mlagents/trainers/policy/torch_policy.py
/ml-agents/mlagents/trainers/ppo/optimizer_torch.py
/ml-agents/mlagents/trainers/sac/optimizer_torch.py
/ml-agents/mlagents/trainers/saver/torch_saver.py
/ml-agents/mlagents/trainers/tests/torch/test_reward_providers
/ml-agents/mlagents/trainers/tests/torch/test_bcmodule.py
/ml-agents/mlagents/trainers/tests/torch/test_policy.py
/ml-agents/mlagents/trainers/tests/torch/test_ghost.py
/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
/ml-agents/mlagents/trainers/torch/components
/ml-agents/mlagents/trainers/policy/tf_policy.py
/ml-agents/mlagents/trainers/tests/test_ghost.py
/ml-agents/mlagents/trainers/tests/test_sac.py
/ml-agents/mlagents/trainers/tests/test_simple_rl.py
/com.unity.ml-agents.extensions/Tests/Editor/Sensors
/utils/validate_release_links.py
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/ml-agents/mlagents/trainers/saver
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/tests/torch
/ml-agents/mlagents/trainers/torch
/ml-agents/mlagents/trainers/torch
/ml-agents/mlagents/trainers/torch
/ml-agents/mlagents/trainers/torch
/ml-agents/mlagents/trainers/torch
/ml-agents/mlagents/trainers/torch
/ml-agents/mlagents/trainers/torch
/ml-agents/mlagents/trainers/torch
/ml-agents/mlagents/trainers/saver
/ml-agents/mlagents/trainers/tf/distributions.py
/ml-agents/mlagents/trainers/tf/tensorflow_to_barracuda.py
/ml-agents/mlagents/trainers/tf/models.py
/com.unity.ml-agents.extensions/Tests/Editor/Sensors
/com.unity.ml-agents.extensions/Tests/Editor/Sensors
/com.unity.ml-agents.extensions/Tests/Editor/Sensors
/com.unity.ml-agents.extensions/Tests/Editor/Sensors
1 次代码提交
作者 | SHA1 | 备注 | 提交日期 |
---|---|---|---|
Ervin Teng | 7cd80378 | Update changelog | 4 年前 |
共有 131 个文件被更改,包括 8000 次插入 和 454 次删除
-
0config/sac/3DBall.yaml
-
0config/sac/3DBallHard.yaml
-
2docs/Learning-Environment-Examples.md
-
2ml-agents/mlagents/trainers/buffer.py
-
12ml-agents/mlagents/trainers/ghost/trainer.py
-
7ml-agents/mlagents/trainers/cli_utils.py
-
14ml-agents/mlagents/trainers/settings.py
-
99ml-agents/mlagents/trainers/trainer/rl_trainer.py
-
5ml-agents/mlagents/trainers/trainer/trainer.py
-
2ml-agents/mlagents/trainers/policy/tf_policy.py
-
7ml-agents/mlagents/trainers/trainer_controller.py
-
7ml-agents/mlagents/trainers/tests/test_ghost.py
-
5ml-agents/mlagents/trainers/tests/test_rl_trainer.py
-
3ml-agents/mlagents/trainers/tests/test_sac.py
-
2ml-agents/mlagents/trainers/tests/test_simple_rl.py
-
76ml-agents/mlagents/trainers/ppo/trainer.py
-
120ml-agents/mlagents/trainers/sac/trainer.py
-
4com.unity.ml-agents/CHANGELOG.md
-
2com.unity.ml-agents/Runtime/SensorHelper.cs.meta
-
2Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/OrientationCubeController.cs.meta
-
19ml-agents/mlagents/trainers/tests/torch/test_layers.py
-
17ml-agents/mlagents/trainers/tests/torch/test_networks.py
-
6ml-agents/mlagents/trainers/tests/torch/test_utils.py
-
17ml-agents/mlagents/trainers/torch/encoders.py
-
67ml-agents/mlagents/trainers/torch/layers.py
-
115ml-agents/mlagents/trainers/torch/networks.py
-
4ml-agents/mlagents/trainers/torch/utils.py
-
2Project/Assets/csc.rsp
-
132utils/validate_release_links.py
-
94ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
-
281ml-agents/mlagents/trainers/policy/torch_policy.py
-
36ml-agents/mlagents/trainers/tests/test_models.py
-
113ml-agents/mlagents/trainers/tests/test_saver.py
-
203ml-agents/mlagents/trainers/ppo/optimizer_torch.py
-
561ml-agents/mlagents/trainers/sac/optimizer_torch.py
-
11com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs.meta
-
11com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs.meta
-
152com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
-
107com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
-
47com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
-
122com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
-
473com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
-
209com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
-
116com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
-
140com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
-
136com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
-
249com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
-
186com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
-
0ml-agents/mlagents/trainers/tf/__init__.py
-
221ml-agents/mlagents/trainers/tf/model_serialization.py
-
0ml-agents/mlagents/trainers/saver/__init__.py
-
170ml-agents/mlagents/trainers/saver/tf_saver.py
-
118ml-agents/mlagents/trainers/saver/torch_saver.py
-
111ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py
-
56ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py
-
138ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py
-
32ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py
-
1001ml-agents/mlagents/trainers/tests/torch/test.demo
-
144ml-agents/mlagents/trainers/tests/torch/test_bcmodule.py
-
446ml-agents/mlagents/trainers/tests/torch/testdcvis.demo
-
150ml-agents/mlagents/trainers/tests/torch/test_policy.py
-
177ml-agents/mlagents/trainers/tests/torch/test_ghost.py
-
505ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
-
0ml-agents/mlagents/trainers/torch/__init__.py
-
0ml-agents/mlagents/trainers/torch/components/__init__.py
-
15ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py
-
72ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py
-
15ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py
-
43ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py
-
225ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
-
256ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
-
0ml-agents/mlagents/trainers/torch/components/bc/__init__.py
-
183ml-agents/mlagents/trainers/torch/components/bc/module.py
-
74ml-agents/mlagents/trainers/torch/model_serialization.py
-
11Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/OrientationCubeController.cs.meta
-
2Project/Assets/csc.rsp
-
132utils/validate_release_links.py
-
36ml-agents/mlagents/trainers/tests/test_models.py
-
113ml-agents/mlagents/trainers/tests/test_saver.py
-
11com.unity.ml-agents/Runtime/SensorHelper.cs.meta
-
0/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs.meta
-
0/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs.meta
-
0/com.unity.ml-agents/Runtime/SensorHelper.cs.meta
-
0/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/OrientationCubeController.cs.meta
|
|||
-warnaserror+ |
|||
-warnaserror-:618 |
|
|||
#!/usr/bin/env python3 |
|||
|
|||
import ast |
|||
import sys |
|||
import os |
|||
import re |
|||
import subprocess |
|||
from typing import List, Optional, Pattern |
|||
|
|||
RELEASE_PATTERN = re.compile(r"release_[0-9]+(_docs)*") |
|||
TRAINER_INIT_FILE = "ml-agents/mlagents/trainers/__init__.py" |
|||
|
|||
# Filename -> regex list to allow specific lines. |
|||
# To allow everything in the file, use None for the value |
|||
ALLOW_LIST = { |
|||
# Previous release table |
|||
"README.md": re.compile(r"\*\*Release [0-9]+\*\*"), |
|||
"docs/Versioning.md": None, |
|||
"com.unity.ml-agents/CHANGELOG.md": None, |
|||
"utils/make_readme_table.py": None, |
|||
"utils/validate_release_links.py": None, |
|||
} |
|||
|
|||
|
|||
def test_pattern(): |
|||
# Just some sanity check that the regex works as expected. |
|||
assert RELEASE_PATTERN.search( |
|||
"https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/Food.md" |
|||
) |
|||
assert RELEASE_PATTERN.search( |
|||
"https://github.com/Unity-Technologies/ml-agents/blob/release_4/Foo.md" |
|||
) |
|||
assert RELEASE_PATTERN.search( |
|||
"git clone --branch release_4 https://github.com/Unity-Technologies/ml-agents.git" |
|||
) |
|||
assert RELEASE_PATTERN.search( |
|||
"https://github.com/Unity-Technologies/ml-agents/blob/release_123_docs/Foo.md" |
|||
) |
|||
assert RELEASE_PATTERN.search( |
|||
"https://github.com/Unity-Technologies/ml-agents/blob/release_123/Foo.md" |
|||
) |
|||
assert not RELEASE_PATTERN.search( |
|||
"https://github.com/Unity-Technologies/ml-agents/blob/latest_release/docs/Foo.md" |
|||
) |
|||
print("tests OK!") |
|||
|
|||
|
|||
def git_ls_files() -> List[str]: |
|||
""" |
|||
Run "git ls-files" and return a list with one entry per line. |
|||
This returns the list of all files tracked by git. |
|||
""" |
|||
return subprocess.check_output(["git", "ls-files"], universal_newlines=True).split( |
|||
"\n" |
|||
) |
|||
|
|||
|
|||
def get_release_tag() -> Optional[str]: |
|||
""" |
|||
Returns the release tag for the mlagents python package. |
|||
This will be None on the master branch. |
|||
:return: |
|||
""" |
|||
with open(TRAINER_INIT_FILE) as f: |
|||
for line in f: |
|||
if "__release_tag__" in line: |
|||
lhs, equals_string, rhs = line.strip().partition(" = ") |
|||
# Evaluate the right hand side of the expression |
|||
return ast.literal_eval(rhs) |
|||
# If we couldn't find the release tag, raise an exception |
|||
# (since we can't return None here) |
|||
raise RuntimeError("Can't determine release tag") |
|||
|
|||
|
|||
def check_file(filename: str, global_allow_pattern: Pattern) -> List[str]: |
|||
""" |
|||
Validate a single file and return any offending lines. |
|||
""" |
|||
bad_lines = [] |
|||
with open(filename) as f: |
|||
for line in f: |
|||
if not RELEASE_PATTERN.search(line): |
|||
continue |
|||
|
|||
if global_allow_pattern.search(line): |
|||
continue |
|||
|
|||
if filename in ALLOW_LIST: |
|||
if ALLOW_LIST[filename] is None or ALLOW_LIST[filename].search(line): |
|||
continue |
|||
|
|||
bad_lines.append(f"{filename}: {line.strip()}") |
|||
return bad_lines |
|||
|
|||
|
|||
def check_all_files(allow_pattern: Pattern) -> List[str]: |
|||
""" |
|||
Validate all files tracked by git. |
|||
:param allow_pattern: |
|||
""" |
|||
bad_lines = [] |
|||
file_types = {".py", ".md", ".cs"} |
|||
for file_name in git_ls_files(): |
|||
if "localized" in file_name or os.path.splitext(file_name)[1] not in file_types: |
|||
continue |
|||
bad_lines += check_file(file_name, allow_pattern) |
|||
return bad_lines |
|||
|
|||
|
|||
def main(): |
|||
release_tag = get_release_tag() |
|||
if not release_tag: |
|||
print("Release tag is None, exiting") |
|||
sys.exit(0) |
|||
|
|||
print(f"Release tag: {release_tag}") |
|||
allow_pattern = re.compile(f"{release_tag}(_docs)*") |
|||
bad_lines = check_all_files(allow_pattern) |
|||
if bad_lines: |
|||
print( |
|||
f"Found lines referring to previous release. Either update the files, or add an exclusion to {__file__}" |
|||
) |
|||
for line in bad_lines: |
|||
print(line) |
|||
|
|||
sys.exit(1 if bad_lines else 0) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
if "--test" in sys.argv: |
|||
test_pattern() |
|||
main() |
|
|||
from typing import Dict, Optional, Tuple, List |
|||
import torch |
|||
import numpy as np |
|||
|
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents.trainers.trajectory import SplitObservations |
|||
from mlagents.trainers.torch.components.bc.module import BCModule |
|||
from mlagents.trainers.torch.components.reward_providers import create_reward_provider |
|||
|
|||
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|||
from mlagents.trainers.optimizer import Optimizer |
|||
from mlagents.trainers.settings import TrainerSettings |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
|
|||
|
|||
class TorchOptimizer(Optimizer): # pylint: disable=W0223 |
|||
def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): |
|||
super().__init__() |
|||
self.policy = policy |
|||
self.trainer_settings = trainer_settings |
|||
self.update_dict: Dict[str, torch.Tensor] = {} |
|||
self.value_heads: Dict[str, torch.Tensor] = {} |
|||
self.memory_in: torch.Tensor = None |
|||
self.memory_out: torch.Tensor = None |
|||
self.m_size: int = 0 |
|||
self.global_step = torch.tensor(0) |
|||
self.bc_module: Optional[BCModule] = None |
|||
self.create_reward_signals(trainer_settings.reward_signals) |
|||
if trainer_settings.behavioral_cloning is not None: |
|||
self.bc_module = BCModule( |
|||
self.policy, |
|||
trainer_settings.behavioral_cloning, |
|||
policy_learning_rate=trainer_settings.hyperparameters.learning_rate, |
|||
default_batch_size=trainer_settings.hyperparameters.batch_size, |
|||
default_num_epoch=3, |
|||
) |
|||
|
|||
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: |
|||
pass |
|||
|
|||
def create_reward_signals(self, reward_signal_configs): |
|||
""" |
|||
Create reward signals |
|||
:param reward_signal_configs: Reward signal config. |
|||
""" |
|||
for reward_signal, settings in reward_signal_configs.items(): |
|||
# Name reward signals by string in case we have duplicates later |
|||
self.reward_signals[reward_signal.value] = create_reward_provider( |
|||
reward_signal, self.policy.behavior_spec, settings |
|||
) |
|||
|
|||
def get_trajectory_value_estimates( |
|||
self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool |
|||
) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]: |
|||
vector_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])] |
|||
if self.policy.use_vis_obs: |
|||
visual_obs = [] |
|||
for idx, _ in enumerate( |
|||
self.policy.actor_critic.network_body.visual_encoders |
|||
): |
|||
visual_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx]) |
|||
visual_obs.append(visual_ob) |
|||
else: |
|||
visual_obs = [] |
|||
|
|||
memory = torch.zeros([1, 1, self.policy.m_size]) |
|||
|
|||
vec_vis_obs = SplitObservations.from_observations(next_obs) |
|||
next_vec_obs = [ |
|||
ModelUtils.list_to_tensor(vec_vis_obs.vector_observations).unsqueeze(0) |
|||
] |
|||
next_vis_obs = [ |
|||
ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0) |
|||
for _vis_ob in vec_vis_obs.visual_observations |
|||
] |
|||
|
|||
value_estimates, next_memory = self.policy.actor_critic.critic_pass( |
|||
vector_obs, visual_obs, memory, sequence_length=batch.num_experiences |
|||
) |
|||
|
|||
next_value_estimate, _ = self.policy.actor_critic.critic_pass( |
|||
next_vec_obs, next_vis_obs, next_memory, sequence_length=1 |
|||
) |
|||
|
|||
for name, estimate in value_estimates.items(): |
|||
value_estimates[name] = estimate.detach().cpu().numpy() |
|||
next_value_estimate[name] = next_value_estimate[name].detach().cpu().numpy() |
|||
|
|||
if done: |
|||
for k in next_value_estimate: |
|||
if not self.reward_signals[k].ignore_done: |
|||
next_value_estimate[k] = 0.0 |
|||
|
|||
return value_estimates, next_value_estimate |
|
|||
from typing import Any, Dict, List, Tuple, Optional |
|||
import numpy as np |
|||
import torch |
|||
import copy |
|||
|
|||
from mlagents.trainers.action_info import ActionInfo |
|||
from mlagents.trainers.behavior_id_utils import get_global_agent_id |
|||
from mlagents.trainers.policy import Policy |
|||
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec |
|||
from mlagents_envs.timers import timed |
|||
|
|||
from mlagents.trainers.settings import TrainerSettings |
|||
from mlagents.trainers.trajectory import SplitObservations |
|||
from mlagents.trainers.torch.networks import ( |
|||
SharedActorCritic, |
|||
SeparateActorCritic, |
|||
GlobalSteps, |
|||
) |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
|
|||
EPSILON = 1e-7 # Small value to avoid divide by zero |
|||
|
|||
|
|||
class TorchPolicy(Policy): |
|||
def __init__( |
|||
self, |
|||
seed: int, |
|||
behavior_spec: BehaviorSpec, |
|||
trainer_settings: TrainerSettings, |
|||
tanh_squash: bool = False, |
|||
reparameterize: bool = False, |
|||
separate_critic: bool = True, |
|||
condition_sigma_on_obs: bool = True, |
|||
): |
|||
""" |
|||
Policy that uses a multilayer perceptron to map the observations to actions. Could |
|||
also use a CNN to encode visual input prior to the MLP. Supports discrete and |
|||
continuous action spaces, as well as recurrent networks. |
|||
:param seed: Random seed. |
|||
:param brain: Assigned BrainParameters object. |
|||
:param trainer_settings: Defined training parameters. |
|||
:param load: Whether a pre-trained model will be loaded or a new one created. |
|||
:param tanh_squash: Whether to use a tanh function on the continuous output, |
|||
or a clipped output. |
|||
:param reparameterize: Whether we are using the resampling trick to update the policy |
|||
in continuous output. |
|||
""" |
|||
super().__init__( |
|||
seed, |
|||
behavior_spec, |
|||
trainer_settings, |
|||
tanh_squash, |
|||
reparameterize, |
|||
condition_sigma_on_obs, |
|||
) |
|||
self.global_step = ( |
|||
GlobalSteps() |
|||
) # could be much simpler if TorchPolicy is nn.Module |
|||
self.grads = None |
|||
|
|||
torch.set_default_tensor_type(torch.FloatTensor) |
|||
|
|||
reward_signal_configs = trainer_settings.reward_signals |
|||
reward_signal_names = [key.value for key, _ in reward_signal_configs.items()] |
|||
|
|||
self.stats_name_to_update_name = { |
|||
"Losses/Value Loss": "value_loss", |
|||
"Losses/Policy Loss": "policy_loss", |
|||
} |
|||
if separate_critic: |
|||
ac_class = SeparateActorCritic |
|||
else: |
|||
ac_class = SharedActorCritic |
|||
self.actor_critic = ac_class( |
|||
observation_shapes=self.behavior_spec.observation_shapes, |
|||
network_settings=trainer_settings.network_settings, |
|||
act_type=behavior_spec.action_type, |
|||
act_size=self.act_size, |
|||
stream_names=reward_signal_names, |
|||
conditional_sigma=self.condition_sigma_on_obs, |
|||
tanh_squash=tanh_squash, |
|||
) |
|||
# Save the m_size needed for export |
|||
self._export_m_size = self.m_size |
|||
# m_size needed for training is determined by network, not trainer settings |
|||
self.m_size = self.actor_critic.memory_size |
|||
|
|||
self.actor_critic.to("cpu") |
|||
|
|||
@property |
|||
def export_memory_size(self) -> int: |
|||
""" |
|||
Returns the memory size of the exported ONNX policy. This only includes the memory |
|||
of the Actor and not any auxillary networks. |
|||
""" |
|||
return self._export_m_size |
|||
|
|||
def _split_decision_step( |
|||
self, decision_requests: DecisionSteps |
|||
) -> Tuple[SplitObservations, np.ndarray]: |
|||
vec_vis_obs = SplitObservations.from_observations(decision_requests.obs) |
|||
mask = None |
|||
if not self.use_continuous_act: |
|||
mask = torch.ones([len(decision_requests), np.sum(self.act_size)]) |
|||
if decision_requests.action_mask is not None: |
|||
mask = torch.as_tensor( |
|||
1 - np.concatenate(decision_requests.action_mask, axis=1) |
|||
) |
|||
return vec_vis_obs, mask |
|||
|
|||
def update_normalization(self, vector_obs: np.ndarray) -> None: |
|||
""" |
|||
If this policy normalizes vector observations, this will update the norm values in the graph. |
|||
:param vector_obs: The vector observations to add to the running estimate of the distribution. |
|||
""" |
|||
vector_obs = [torch.as_tensor(vector_obs)] |
|||
if self.use_vec_obs and self.normalize: |
|||
self.actor_critic.update_normalization(vector_obs) |
|||
|
|||
@timed |
|||
def sample_actions( |
|||
self, |
|||
vec_obs: List[torch.Tensor], |
|||
vis_obs: List[torch.Tensor], |
|||
masks: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
seq_len: int = 1, |
|||
all_log_probs: bool = False, |
|||
) -> Tuple[ |
|||
torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor |
|||
]: |
|||
""" |
|||
:param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action. |
|||
""" |
|||
dists, value_heads, memories = self.actor_critic.get_dist_and_value( |
|||
vec_obs, vis_obs, masks, memories, seq_len |
|||
) |
|||
action_list = self.actor_critic.sample_action(dists) |
|||
log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy( |
|||
action_list, dists |
|||
) |
|||
actions = torch.stack(action_list, dim=-1) |
|||
if self.use_continuous_act: |
|||
actions = actions[:, :, 0] |
|||
else: |
|||
actions = actions[:, 0, :] |
|||
|
|||
return ( |
|||
actions, |
|||
all_logs if all_log_probs else log_probs, |
|||
entropies, |
|||
value_heads, |
|||
memories, |
|||
) |
|||
|
|||
def evaluate_actions( |
|||
self, |
|||
vec_obs: torch.Tensor, |
|||
vis_obs: torch.Tensor, |
|||
actions: torch.Tensor, |
|||
masks: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
seq_len: int = 1, |
|||
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: |
|||
dists, value_heads, _ = self.actor_critic.get_dist_and_value( |
|||
vec_obs, vis_obs, masks, memories, seq_len |
|||
) |
|||
action_list = [actions[..., i] for i in range(actions.shape[-1])] |
|||
log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_list, dists) |
|||
|
|||
return log_probs, entropies, value_heads |
|||
|
|||
@timed |
|||
def evaluate( |
|||
self, decision_requests: DecisionSteps, global_agent_ids: List[str] |
|||
) -> Dict[str, Any]: |
|||
""" |
|||
Evaluates policy for the agent experiences provided. |
|||
:param global_agent_ids: |
|||
:param decision_requests: DecisionStep object containing inputs. |
|||
:return: Outputs from network as defined by self.inference_dict. |
|||
""" |
|||
vec_vis_obs, masks = self._split_decision_step(decision_requests) |
|||
vec_obs = [torch.as_tensor(vec_vis_obs.vector_observations)] |
|||
vis_obs = [ |
|||
torch.as_tensor(vis_ob) for vis_ob in vec_vis_obs.visual_observations |
|||
] |
|||
memories = torch.as_tensor(self.retrieve_memories(global_agent_ids)).unsqueeze( |
|||
0 |
|||
) |
|||
|
|||
run_out = {} |
|||
with torch.no_grad(): |
|||
action, log_probs, entropy, value_heads, memories = self.sample_actions( |
|||
vec_obs, vis_obs, masks=masks, memories=memories |
|||
) |
|||
run_out["action"] = action.detach().cpu().numpy() |
|||
run_out["pre_action"] = action.detach().cpu().numpy() |
|||
# Todo - make pre_action difference |
|||
run_out["log_probs"] = log_probs.detach().cpu().numpy() |
|||
run_out["entropy"] = entropy.detach().cpu().numpy() |
|||
run_out["value_heads"] = { |
|||
name: t.detach().cpu().numpy() for name, t in value_heads.items() |
|||
} |
|||
run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0) |
|||
run_out["learning_rate"] = 0.0 |
|||
if self.use_recurrent: |
|||
run_out["memory_out"] = memories.detach().cpu().numpy().squeeze(0) |
|||
return run_out |
|||
|
|||
def get_action( |
|||
self, decision_requests: DecisionSteps, worker_id: int = 0 |
|||
) -> ActionInfo: |
|||
""" |
|||
Decides actions given observations information, and takes them in environment. |
|||
:param worker_id: |
|||
:param decision_requests: A dictionary of brain names and BrainInfo from environment. |
|||
:return: an ActionInfo containing action, memories, values and an object |
|||
to be passed to add experiences |
|||
""" |
|||
if len(decision_requests) == 0: |
|||
return ActionInfo.empty() |
|||
|
|||
global_agent_ids = [ |
|||
get_global_agent_id(worker_id, int(agent_id)) |
|||
for agent_id in decision_requests.agent_id |
|||
] # For 1-D array, the iterator order is correct. |
|||
|
|||
run_out = self.evaluate( |
|||
decision_requests, global_agent_ids |
|||
) # pylint: disable=assignment-from-no-return |
|||
self.save_memories(global_agent_ids, run_out.get("memory_out")) |
|||
return ActionInfo( |
|||
action=run_out.get("action"), |
|||
value=run_out.get("value"), |
|||
outputs=run_out, |
|||
agent_ids=list(decision_requests.agent_id), |
|||
) |
|||
|
|||
@property |
|||
def use_vis_obs(self): |
|||
return self.vis_obs_size > 0 |
|||
|
|||
@property |
|||
def use_vec_obs(self): |
|||
return self.vec_obs_size > 0 |
|||
|
|||
def get_current_step(self): |
|||
""" |
|||
Gets current model step. |
|||
:return: current model step. |
|||
""" |
|||
return self.global_step.current_step |
|||
|
|||
def set_step(self, step: int) -> int: |
|||
""" |
|||
Sets current model step to step without creating additional ops. |
|||
:param step: Step to set the current model step to. |
|||
:return: The step the model was set to. |
|||
""" |
|||
self.global_step.current_step = step |
|||
return step |
|||
|
|||
def increment_step(self, n_steps): |
|||
""" |
|||
Increments model step. |
|||
""" |
|||
self.global_step.increment(n_steps) |
|||
return self.get_current_step() |
|||
|
|||
def load_weights(self, values: List[np.ndarray]) -> None: |
|||
self.actor_critic.load_state_dict(values) |
|||
|
|||
def init_load_weights(self) -> None: |
|||
pass |
|||
|
|||
def get_weights(self) -> List[np.ndarray]: |
|||
return copy.deepcopy(self.actor_critic.state_dict()) |
|||
|
|||
def get_modules(self): |
|||
return {"Policy": self.actor_critic, "global_step": self.global_step} |
|
|||
import pytest |
|||
|
|||
from mlagents.trainers.tf.models import ModelUtils |
|||
from mlagents.tf_utils import tf |
|||
from mlagents_envs.base_env import BehaviorSpec, ActionType |
|||
|
|||
|
|||
def create_behavior_spec(num_visual, num_vector, vector_size): |
|||
behavior_spec = BehaviorSpec( |
|||
[(84, 84, 3)] * int(num_visual) + [(vector_size,)] * int(num_vector), |
|||
ActionType.DISCRETE, |
|||
(1,), |
|||
) |
|||
return behavior_spec |
|||
|
|||
|
|||
@pytest.mark.parametrize("num_visual", [1, 2, 4]) |
|||
@pytest.mark.parametrize("num_vector", [1, 2, 4]) |
|||
def test_create_input_placeholders(num_vector, num_visual): |
|||
vec_size = 8 |
|||
name_prefix = "test123" |
|||
bspec = create_behavior_spec(num_visual, num_vector, vec_size) |
|||
vec_in, vis_in = ModelUtils.create_input_placeholders( |
|||
bspec.observation_shapes, name_prefix=name_prefix |
|||
) |
|||
|
|||
assert isinstance(vis_in, list) |
|||
assert len(vis_in) == num_visual |
|||
assert isinstance(vec_in, tf.Tensor) |
|||
assert vec_in.get_shape().as_list()[1] == num_vector * 8 |
|||
|
|||
# Check names contain prefix and vis shapes are correct |
|||
for _vis in vis_in: |
|||
assert _vis.get_shape().as_list() == [None, 84, 84, 3] |
|||
assert _vis.name.startswith(name_prefix) |
|||
assert vec_in.name.startswith(name_prefix) |
|
|||
import pytest |
|||
from unittest import mock |
|||
import os |
|||
import unittest |
|||
import tempfile |
|||
|
|||
import numpy as np |
|||
from mlagents.tf_utils import tf |
|||
from mlagents.trainers.saver.tf_saver import TFSaver |
|||
from mlagents.trainers import __version__ |
|||
from mlagents.trainers.settings import TrainerSettings |
|||
from mlagents.trainers.policy.tf_policy import TFPolicy |
|||
from mlagents.trainers.tests import mock_brain as mb |
|||
from mlagents.trainers.tests.test_nn_policy import create_policy_mock |
|||
from mlagents.trainers.ppo.optimizer import PPOOptimizer |
|||
|
|||
|
|||
def test_register(tmp_path): |
|||
trainer_params = TrainerSettings() |
|||
saver = TFSaver(trainer_params, tmp_path) |
|||
|
|||
opt = mock.Mock(spec=PPOOptimizer) |
|||
saver.register(opt) |
|||
assert saver.policy is None |
|||
|
|||
trainer_params = TrainerSettings() |
|||
policy = create_policy_mock(trainer_params) |
|||
saver.register(policy) |
|||
assert saver.policy is not None |
|||
|
|||
|
|||
class ModelVersionTest(unittest.TestCase): |
|||
def test_version_compare(self): |
|||
# Test write_stats |
|||
with self.assertLogs("mlagents.trainers", level="WARNING") as cm: |
|||
trainer_params = TrainerSettings() |
|||
mock_path = tempfile.mkdtemp() |
|||
policy = create_policy_mock(trainer_params) |
|||
saver = TFSaver(trainer_params, mock_path) |
|||
saver.register(policy) |
|||
|
|||
saver._check_model_version( |
|||
"0.0.0" |
|||
) # This is not the right version for sure |
|||
# Assert that 1 warning has been thrown with incorrect version |
|||
assert len(cm.output) == 1 |
|||
saver._check_model_version(__version__) # This should be the right version |
|||
# Assert that no additional warnings have been thrown wth correct ver |
|||
assert len(cm.output) == 1 |
|||
|
|||
|
|||
def test_load_save(tmp_path): |
|||
path1 = os.path.join(tmp_path, "runid1") |
|||
path2 = os.path.join(tmp_path, "runid2") |
|||
trainer_params = TrainerSettings() |
|||
policy = create_policy_mock(trainer_params) |
|||
saver = TFSaver(trainer_params, path1) |
|||
saver.register(policy) |
|||
saver.initialize_or_load(policy) |
|||
policy.set_step(2000) |
|||
|
|||
mock_brain_name = "MockBrain" |
|||
saver.save_checkpoint(mock_brain_name, 2000) |
|||
assert len(os.listdir(tmp_path)) > 0 |
|||
|
|||
# Try load from this path |
|||
saver = TFSaver(trainer_params, path1, load=True) |
|||
policy2 = create_policy_mock(trainer_params) |
|||
saver.register(policy2) |
|||
saver.initialize_or_load(policy2) |
|||
_compare_two_policies(policy, policy2) |
|||
assert policy2.get_current_step() == 2000 |
|||
|
|||
# Try initialize from path 1 |
|||
trainer_params.init_path = path1 |
|||
saver = TFSaver(trainer_params, path2) |
|||
policy3 = create_policy_mock(trainer_params) |
|||
saver.register(policy3) |
|||
saver.initialize_or_load(policy3) |
|||
|
|||
_compare_two_policies(policy2, policy3) |
|||
# Assert that the steps are 0. |
|||
assert policy3.get_current_step() == 0 |
|||
|
|||
|
|||
def _compare_two_policies(policy1: TFPolicy, policy2: TFPolicy) -> None: |
|||
""" |
|||
Make sure two policies have the same output for the same input. |
|||
""" |
|||
decision_step, _ = mb.create_steps_from_behavior_spec( |
|||
policy1.behavior_spec, num_agents=1 |
|||
) |
|||
run_out1 = policy1.evaluate(decision_step, list(decision_step.agent_id)) |
|||
run_out2 = policy2.evaluate(decision_step, list(decision_step.agent_id)) |
|||
|
|||
np.testing.assert_array_equal(run_out2["log_probs"], run_out1["log_probs"]) |
|||
|
|||
|
|||
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"]) |
|||
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"]) |
|||
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"]) |
|||
def test_checkpoint_conversion(tmpdir, rnn, visual, discrete): |
|||
tf.reset_default_graph() |
|||
dummy_config = TrainerSettings() |
|||
model_path = os.path.join(tmpdir, "Mock_Brain") |
|||
policy = create_policy_mock( |
|||
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual |
|||
) |
|||
trainer_params = TrainerSettings() |
|||
saver = TFSaver(trainer_params, model_path) |
|||
saver.register(policy) |
|||
saver.save_checkpoint("Mock_Brain", 100) |
|||
assert os.path.isfile(model_path + "/Mock_Brain-100.nn") |
|
|||
from typing import Dict, cast |
|||
import torch |
|||
|
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
|
|||
from mlagents_envs.timers import timed |
|||
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|||
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer |
|||
from mlagents.trainers.settings import TrainerSettings, PPOSettings |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
|
|||
|
|||
class TorchPPOOptimizer(TorchOptimizer): |
|||
def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): |
|||
""" |
|||
Takes a Policy and a Dict of trainer parameters and creates an Optimizer around the policy. |
|||
The PPO optimizer has a value estimator and a loss function. |
|||
:param policy: A TFPolicy object that will be updated by this PPO Optimizer. |
|||
:param trainer_params: Trainer parameters dictionary that specifies the |
|||
properties of the trainer. |
|||
""" |
|||
# Create the graph here to give more granular control of the TF graph to the Optimizer. |
|||
|
|||
super().__init__(policy, trainer_settings) |
|||
params = list(self.policy.actor_critic.parameters()) |
|||
self.hyperparameters: PPOSettings = cast( |
|||
PPOSettings, trainer_settings.hyperparameters |
|||
) |
|||
self.decay_learning_rate = ModelUtils.DecayedValue( |
|||
self.hyperparameters.learning_rate_schedule, |
|||
self.hyperparameters.learning_rate, |
|||
1e-10, |
|||
self.trainer_settings.max_steps, |
|||
) |
|||
self.decay_epsilon = ModelUtils.DecayedValue( |
|||
self.hyperparameters.learning_rate_schedule, |
|||
self.hyperparameters.epsilon, |
|||
0.1, |
|||
self.trainer_settings.max_steps, |
|||
) |
|||
self.decay_beta = ModelUtils.DecayedValue( |
|||
self.hyperparameters.learning_rate_schedule, |
|||
self.hyperparameters.beta, |
|||
1e-5, |
|||
self.trainer_settings.max_steps, |
|||
) |
|||
|
|||
self.optimizer = torch.optim.Adam( |
|||
params, lr=self.trainer_settings.hyperparameters.learning_rate |
|||
) |
|||
self.stats_name_to_update_name = { |
|||
"Losses/Value Loss": "value_loss", |
|||
"Losses/Policy Loss": "policy_loss", |
|||
} |
|||
|
|||
self.stream_names = list(self.reward_signals.keys()) |
|||
|
|||
def ppo_value_loss( |
|||
self, |
|||
values: Dict[str, torch.Tensor], |
|||
old_values: Dict[str, torch.Tensor], |
|||
returns: Dict[str, torch.Tensor], |
|||
epsilon: float, |
|||
loss_masks: torch.Tensor, |
|||
) -> torch.Tensor: |
|||
""" |
|||
Evaluates value loss for PPO. |
|||
:param values: Value output of the current network. |
|||
:param old_values: Value stored with experiences in buffer. |
|||
:param returns: Computed returns. |
|||
:param epsilon: Clipping value for value estimate. |
|||
:param loss_mask: Mask for losses. Used with LSTM to ignore 0'ed out experiences. |
|||
""" |
|||
value_losses = [] |
|||
for name, head in values.items(): |
|||
old_val_tensor = old_values[name] |
|||
returns_tensor = returns[name] |
|||
clipped_value_estimate = old_val_tensor + torch.clamp( |
|||
head - old_val_tensor, -1 * epsilon, epsilon |
|||
) |
|||
v_opt_a = (returns_tensor - head) ** 2 |
|||
v_opt_b = (returns_tensor - clipped_value_estimate) ** 2 |
|||
value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks) |
|||
value_losses.append(value_loss) |
|||
value_loss = torch.mean(torch.stack(value_losses)) |
|||
return value_loss |
|||
|
|||
def ppo_policy_loss( |
|||
self, |
|||
advantages: torch.Tensor, |
|||
log_probs: torch.Tensor, |
|||
old_log_probs: torch.Tensor, |
|||
loss_masks: torch.Tensor, |
|||
) -> torch.Tensor: |
|||
""" |
|||
Evaluate PPO policy loss. |
|||
:param advantages: Computed advantages. |
|||
:param log_probs: Current policy probabilities |
|||
:param old_log_probs: Past policy probabilities |
|||
:param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences. |
|||
""" |
|||
advantage = advantages.unsqueeze(-1) |
|||
|
|||
decay_epsilon = self.hyperparameters.epsilon |
|||
|
|||
r_theta = torch.exp(log_probs - old_log_probs) |
|||
p_opt_a = r_theta * advantage |
|||
p_opt_b = ( |
|||
torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage |
|||
) |
|||
policy_loss = -1 * ModelUtils.masked_mean( |
|||
torch.min(p_opt_a, p_opt_b), loss_masks |
|||
) |
|||
return policy_loss |
|||
|
|||
@timed |
|||
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: |
|||
""" |
|||
Performs update on model. |
|||
:param batch: Batch of experiences. |
|||
:param num_sequences: Number of sequences to process. |
|||
:return: Results of update. |
|||
""" |
|||
# Get decayed parameters |
|||
decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step()) |
|||
decay_eps = self.decay_epsilon.get_value(self.policy.get_current_step()) |
|||
decay_bet = self.decay_beta.get_value(self.policy.get_current_step()) |
|||
returns = {} |
|||
old_values = {} |
|||
for name in self.reward_signals: |
|||
old_values[name] = ModelUtils.list_to_tensor( |
|||
batch[f"{name}_value_estimates"] |
|||
) |
|||
returns[name] = ModelUtils.list_to_tensor(batch[f"{name}_returns"]) |
|||
|
|||
vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])] |
|||
act_masks = ModelUtils.list_to_tensor(batch["action_mask"]) |
|||
if self.policy.use_continuous_act: |
|||
actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1) |
|||
else: |
|||
actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long) |
|||
|
|||
memories = [ |
|||
ModelUtils.list_to_tensor(batch["memory"][i]) |
|||
for i in range(0, len(batch["memory"]), self.policy.sequence_length) |
|||
] |
|||
if len(memories) > 0: |
|||
memories = torch.stack(memories).unsqueeze(0) |
|||
|
|||
if self.policy.use_vis_obs: |
|||
vis_obs = [] |
|||
for idx, _ in enumerate( |
|||
self.policy.actor_critic.network_body.visual_encoders |
|||
): |
|||
vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx]) |
|||
vis_obs.append(vis_ob) |
|||
else: |
|||
vis_obs = [] |
|||
log_probs, entropy, values = self.policy.evaluate_actions( |
|||
vec_obs, |
|||
vis_obs, |
|||
masks=act_masks, |
|||
actions=actions, |
|||
memories=memories, |
|||
seq_len=self.policy.sequence_length, |
|||
) |
|||
loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool) |
|||
value_loss = self.ppo_value_loss( |
|||
values, old_values, returns, decay_eps, loss_masks |
|||
) |
|||
policy_loss = self.ppo_policy_loss( |
|||
ModelUtils.list_to_tensor(batch["advantages"]), |
|||
log_probs, |
|||
ModelUtils.list_to_tensor(batch["action_probs"]), |
|||
loss_masks, |
|||
) |
|||
loss = ( |
|||
policy_loss |
|||
+ 0.5 * value_loss |
|||
- decay_bet * ModelUtils.masked_mean(entropy, loss_masks) |
|||
) |
|||
|
|||
# Set optimizer learning rate |
|||
ModelUtils.update_learning_rate(self.optimizer, decay_lr) |
|||
self.optimizer.zero_grad() |
|||
loss.backward() |
|||
|
|||
self.optimizer.step() |
|||
update_stats = { |
|||
"Losses/Policy Loss": abs(policy_loss.detach().cpu().numpy()), |
|||
"Losses/Value Loss": value_loss.detach().cpu().numpy(), |
|||
"Policy/Learning Rate": decay_lr, |
|||
"Policy/Epsilon": decay_eps, |
|||
"Policy/Beta": decay_bet, |
|||
} |
|||
|
|||
for reward_provider in self.reward_signals.values(): |
|||
update_stats.update(reward_provider.update(batch)) |
|||
|
|||
return update_stats |
|||
|
|||
def get_modules(self): |
|||
return {"Optimizer": self.optimizer} |
|
|||
import numpy as np |
|||
from typing import Dict, List, Mapping, cast, Tuple, Optional |
|||
import torch |
|||
from torch import nn |
|||
import attr |
|||
|
|||
from mlagents_envs.logging_util import get_logger |
|||
from mlagents_envs.base_env import ActionType |
|||
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer |
|||
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|||
from mlagents.trainers.settings import NetworkSettings |
|||
from mlagents.trainers.torch.networks import ValueNetwork |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents_envs.timers import timed |
|||
from mlagents.trainers.exception import UnityTrainerException |
|||
from mlagents.trainers.settings import TrainerSettings, SACSettings |
|||
|
|||
EPSILON = 1e-6 # Small value to avoid divide by zero |
|||
|
|||
logger = get_logger(__name__) |
|||
|
|||
|
|||
class TorchSACOptimizer(TorchOptimizer): |
|||
class PolicyValueNetwork(nn.Module): |
|||
def __init__( |
|||
self, |
|||
stream_names: List[str], |
|||
observation_shapes: List[Tuple[int, ...]], |
|||
network_settings: NetworkSettings, |
|||
act_type: ActionType, |
|||
act_size: List[int], |
|||
): |
|||
super().__init__() |
|||
if act_type == ActionType.CONTINUOUS: |
|||
num_value_outs = 1 |
|||
num_action_ins = sum(act_size) |
|||
else: |
|||
num_value_outs = sum(act_size) |
|||
num_action_ins = 0 |
|||
self.q1_network = ValueNetwork( |
|||
stream_names, |
|||
observation_shapes, |
|||
network_settings, |
|||
num_action_ins, |
|||
num_value_outs, |
|||
) |
|||
self.q2_network = ValueNetwork( |
|||
stream_names, |
|||
observation_shapes, |
|||
network_settings, |
|||
num_action_ins, |
|||
num_value_outs, |
|||
) |
|||
|
|||
def forward( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
actions: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
sequence_length: int = 1, |
|||
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: |
|||
q1_out, _ = self.q1_network( |
|||
vec_inputs, |
|||
vis_inputs, |
|||
actions=actions, |
|||
memories=memories, |
|||
sequence_length=sequence_length, |
|||
) |
|||
q2_out, _ = self.q2_network( |
|||
vec_inputs, |
|||
vis_inputs, |
|||
actions=actions, |
|||
memories=memories, |
|||
sequence_length=sequence_length, |
|||
) |
|||
return q1_out, q2_out |
|||
|
|||
def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): |
|||
super().__init__(policy, trainer_params) |
|||
hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters) |
|||
self.tau = hyperparameters.tau |
|||
self.init_entcoef = hyperparameters.init_entcoef |
|||
|
|||
self.policy = policy |
|||
self.act_size = policy.act_size |
|||
policy_network_settings = policy.network_settings |
|||
|
|||
self.tau = hyperparameters.tau |
|||
self.burn_in_ratio = 0.0 |
|||
|
|||
# Non-exposed SAC parameters |
|||
self.discrete_target_entropy_scale = 0.2 # Roughly equal to e-greedy 0.05 |
|||
self.continuous_target_entropy_scale = 1.0 |
|||
|
|||
self.stream_names = list(self.reward_signals.keys()) |
|||
# Use to reduce "survivor bonus" when using Curiosity or GAIL. |
|||
self.gammas = [_val.gamma for _val in trainer_params.reward_signals.values()] |
|||
self.use_dones_in_backup = { |
|||
name: int(not self.reward_signals[name].ignore_done) |
|||
for name in self.stream_names |
|||
} |
|||
|
|||
# Critics should have 1/2 of the memory of the policy |
|||
critic_memory = policy_network_settings.memory |
|||
if critic_memory is not None: |
|||
critic_memory = attr.evolve( |
|||
critic_memory, memory_size=critic_memory.memory_size // 2 |
|||
) |
|||
value_network_settings = attr.evolve( |
|||
policy_network_settings, memory=critic_memory |
|||
) |
|||
|
|||
self.value_network = TorchSACOptimizer.PolicyValueNetwork( |
|||
self.stream_names, |
|||
self.policy.behavior_spec.observation_shapes, |
|||
value_network_settings, |
|||
self.policy.behavior_spec.action_type, |
|||
self.act_size, |
|||
) |
|||
|
|||
self.target_network = ValueNetwork( |
|||
self.stream_names, |
|||
self.policy.behavior_spec.observation_shapes, |
|||
value_network_settings, |
|||
) |
|||
self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0) |
|||
|
|||
self._log_ent_coef = torch.nn.Parameter( |
|||
torch.log(torch.as_tensor([self.init_entcoef] * len(self.act_size))), |
|||
requires_grad=True, |
|||
) |
|||
if self.policy.use_continuous_act: |
|||
self.target_entropy = torch.as_tensor( |
|||
-1 |
|||
* self.continuous_target_entropy_scale |
|||
* np.prod(self.act_size[0]).astype(np.float32) |
|||
) |
|||
else: |
|||
self.target_entropy = [ |
|||
self.discrete_target_entropy_scale * np.log(i).astype(np.float32) |
|||
for i in self.act_size |
|||
] |
|||
|
|||
policy_params = list(self.policy.actor_critic.network_body.parameters()) + list( |
|||
self.policy.actor_critic.distribution.parameters() |
|||
) |
|||
value_params = list(self.value_network.parameters()) + list( |
|||
self.policy.actor_critic.critic.parameters() |
|||
) |
|||
|
|||
logger.debug("value_vars") |
|||
for param in value_params: |
|||
logger.debug(param.shape) |
|||
logger.debug("policy_vars") |
|||
for param in policy_params: |
|||
logger.debug(param.shape) |
|||
|
|||
self.decay_learning_rate = ModelUtils.DecayedValue( |
|||
hyperparameters.learning_rate_schedule, |
|||
hyperparameters.learning_rate, |
|||
1e-10, |
|||
self.trainer_settings.max_steps, |
|||
) |
|||
self.policy_optimizer = torch.optim.Adam( |
|||
policy_params, lr=hyperparameters.learning_rate |
|||
) |
|||
self.value_optimizer = torch.optim.Adam( |
|||
value_params, lr=hyperparameters.learning_rate |
|||
) |
|||
self.entropy_optimizer = torch.optim.Adam( |
|||
[self._log_ent_coef], lr=hyperparameters.learning_rate |
|||
) |
|||
|
|||
def sac_q_loss( |
|||
self, |
|||
q1_out: Dict[str, torch.Tensor], |
|||
q2_out: Dict[str, torch.Tensor], |
|||
target_values: Dict[str, torch.Tensor], |
|||
dones: torch.Tensor, |
|||
rewards: Dict[str, torch.Tensor], |
|||
loss_masks: torch.Tensor, |
|||
) -> Tuple[torch.Tensor, torch.Tensor]: |
|||
q1_losses = [] |
|||
q2_losses = [] |
|||
# Multiple q losses per stream |
|||
for i, name in enumerate(q1_out.keys()): |
|||
q1_stream = q1_out[name].squeeze() |
|||
q2_stream = q2_out[name].squeeze() |
|||
with torch.no_grad(): |
|||
q_backup = rewards[name] + ( |
|||
(1.0 - self.use_dones_in_backup[name] * dones) |
|||
* self.gammas[i] |
|||
* target_values[name] |
|||
) |
|||
_q1_loss = 0.5 * ModelUtils.masked_mean( |
|||
torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks |
|||
) |
|||
_q2_loss = 0.5 * ModelUtils.masked_mean( |
|||
torch.nn.functional.mse_loss(q_backup, q2_stream), loss_masks |
|||
) |
|||
|
|||
q1_losses.append(_q1_loss) |
|||
q2_losses.append(_q2_loss) |
|||
q1_loss = torch.mean(torch.stack(q1_losses)) |
|||
q2_loss = torch.mean(torch.stack(q2_losses)) |
|||
return q1_loss, q2_loss |
|||
|
|||
def soft_update(self, source: nn.Module, target: nn.Module, tau: float) -> None: |
|||
for source_param, target_param in zip(source.parameters(), target.parameters()): |
|||
target_param.data.copy_( |
|||
target_param.data * (1.0 - tau) + source_param.data * tau |
|||
) |
|||
|
|||
def sac_value_loss( |
|||
self, |
|||
log_probs: torch.Tensor, |
|||
values: Dict[str, torch.Tensor], |
|||
q1p_out: Dict[str, torch.Tensor], |
|||
q2p_out: Dict[str, torch.Tensor], |
|||
loss_masks: torch.Tensor, |
|||
discrete: bool, |
|||
) -> torch.Tensor: |
|||
min_policy_qs = {} |
|||
with torch.no_grad(): |
|||
_ent_coef = torch.exp(self._log_ent_coef) |
|||
for name in values.keys(): |
|||
if not discrete: |
|||
min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name]) |
|||
else: |
|||
action_probs = log_probs.exp() |
|||
_branched_q1p = ModelUtils.break_into_branches( |
|||
q1p_out[name] * action_probs, self.act_size |
|||
) |
|||
_branched_q2p = ModelUtils.break_into_branches( |
|||
q2p_out[name] * action_probs, self.act_size |
|||
) |
|||
_q1p_mean = torch.mean( |
|||
torch.stack( |
|||
[torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q1p] |
|||
), |
|||
dim=0, |
|||
) |
|||
_q2p_mean = torch.mean( |
|||
torch.stack( |
|||
[torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q2p] |
|||
), |
|||
dim=0, |
|||
) |
|||
|
|||
min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean) |
|||
|
|||
value_losses = [] |
|||
if not discrete: |
|||
for name in values.keys(): |
|||
with torch.no_grad(): |
|||
v_backup = min_policy_qs[name] - torch.sum( |
|||
_ent_coef * log_probs, dim=1 |
|||
) |
|||
value_loss = 0.5 * ModelUtils.masked_mean( |
|||
torch.nn.functional.mse_loss(values[name], v_backup), loss_masks |
|||
) |
|||
value_losses.append(value_loss) |
|||
else: |
|||
branched_per_action_ent = ModelUtils.break_into_branches( |
|||
log_probs * log_probs.exp(), self.act_size |
|||
) |
|||
# We have to do entropy bonus per action branch |
|||
branched_ent_bonus = torch.stack( |
|||
[ |
|||
torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True) |
|||
for i, _lp in enumerate(branched_per_action_ent) |
|||
] |
|||
) |
|||
for name in values.keys(): |
|||
with torch.no_grad(): |
|||
v_backup = min_policy_qs[name] - torch.mean( |
|||
branched_ent_bonus, axis=0 |
|||
) |
|||
value_loss = 0.5 * ModelUtils.masked_mean( |
|||
torch.nn.functional.mse_loss(values[name], v_backup.squeeze()), |
|||
loss_masks, |
|||
) |
|||
value_losses.append(value_loss) |
|||
value_loss = torch.mean(torch.stack(value_losses)) |
|||
if torch.isinf(value_loss).any() or torch.isnan(value_loss).any(): |
|||
raise UnityTrainerException("Inf found") |
|||
return value_loss |
|||
|
|||
def sac_policy_loss( |
|||
self, |
|||
log_probs: torch.Tensor, |
|||
q1p_outs: Dict[str, torch.Tensor], |
|||
loss_masks: torch.Tensor, |
|||
discrete: bool, |
|||
) -> torch.Tensor: |
|||
_ent_coef = torch.exp(self._log_ent_coef) |
|||
mean_q1 = torch.mean(torch.stack(list(q1p_outs.values())), axis=0) |
|||
if not discrete: |
|||
mean_q1 = mean_q1.unsqueeze(1) |
|||
batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1) |
|||
policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks) |
|||
else: |
|||
action_probs = log_probs.exp() |
|||
branched_per_action_ent = ModelUtils.break_into_branches( |
|||
log_probs * action_probs, self.act_size |
|||
) |
|||
branched_q_term = ModelUtils.break_into_branches( |
|||
mean_q1 * action_probs, self.act_size |
|||
) |
|||
branched_policy_loss = torch.stack( |
|||
[ |
|||
torch.sum(_ent_coef[i] * _lp - _qt, dim=1, keepdim=True) |
|||
for i, (_lp, _qt) in enumerate( |
|||
zip(branched_per_action_ent, branched_q_term) |
|||
) |
|||
] |
|||
) |
|||
batch_policy_loss = torch.squeeze(branched_policy_loss) |
|||
policy_loss = torch.mean(loss_masks * batch_policy_loss) |
|||
return policy_loss |
|||
|
|||
def sac_entropy_loss( |
|||
self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool |
|||
) -> torch.Tensor: |
|||
if not discrete: |
|||
with torch.no_grad(): |
|||
target_current_diff = torch.sum(log_probs + self.target_entropy, dim=1) |
|||
entropy_loss = -torch.mean( |
|||
self._log_ent_coef * loss_masks * target_current_diff |
|||
) |
|||
else: |
|||
with torch.no_grad(): |
|||
branched_per_action_ent = ModelUtils.break_into_branches( |
|||
log_probs * log_probs.exp(), self.act_size |
|||
) |
|||
target_current_diff_branched = torch.stack( |
|||
[ |
|||
torch.sum(_lp, axis=1, keepdim=True) + _te |
|||
for _lp, _te in zip( |
|||
branched_per_action_ent, self.target_entropy |
|||
) |
|||
], |
|||
axis=1, |
|||
) |
|||
target_current_diff = torch.squeeze( |
|||
target_current_diff_branched, axis=2 |
|||
) |
|||
entropy_loss = -1 * ModelUtils.masked_mean( |
|||
torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks |
|||
) |
|||
|
|||
return entropy_loss |
|||
|
|||
def _condense_q_streams( |
|||
self, q_output: Dict[str, torch.Tensor], discrete_actions: torch.Tensor |
|||
) -> Dict[str, torch.Tensor]: |
|||
condensed_q_output = {} |
|||
onehot_actions = ModelUtils.actions_to_onehot(discrete_actions, self.act_size) |
|||
for key, item in q_output.items(): |
|||
branched_q = ModelUtils.break_into_branches(item, self.act_size) |
|||
only_action_qs = torch.stack( |
|||
[ |
|||
torch.sum(_act * _q, dim=1, keepdim=True) |
|||
for _act, _q in zip(onehot_actions, branched_q) |
|||
] |
|||
) |
|||
|
|||
condensed_q_output[key] = torch.mean(only_action_qs, dim=0) |
|||
return condensed_q_output |
|||
|
|||
@timed |
|||
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: |
|||
""" |
|||
Updates model using buffer. |
|||
:param num_sequences: Number of trajectories in batch. |
|||
:param batch: Experience mini-batch. |
|||
:param update_target: Whether or not to update target value network |
|||
:param reward_signal_batches: Minibatches to use for updating the reward signals, |
|||
indexed by name. If none, don't update the reward signals. |
|||
:return: Output from update process. |
|||
""" |
|||
rewards = {} |
|||
for name in self.reward_signals: |
|||
rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"]) |
|||
|
|||
vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])] |
|||
next_vec_obs = [ModelUtils.list_to_tensor(batch["next_vector_in"])] |
|||
act_masks = ModelUtils.list_to_tensor(batch["action_mask"]) |
|||
if self.policy.use_continuous_act: |
|||
actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1) |
|||
else: |
|||
actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long) |
|||
|
|||
memories_list = [ |
|||
ModelUtils.list_to_tensor(batch["memory"][i]) |
|||
for i in range(0, len(batch["memory"]), self.policy.sequence_length) |
|||
] |
|||
# LSTM shouldn't have sequence length <1, but stop it from going out of the index if true. |
|||
offset = 1 if self.policy.sequence_length > 1 else 0 |
|||
next_memories_list = [ |
|||
ModelUtils.list_to_tensor( |
|||
batch["memory"][i][self.policy.m_size // 2 :] |
|||
) # only pass value part of memory to target network |
|||
for i in range(offset, len(batch["memory"]), self.policy.sequence_length) |
|||
] |
|||
|
|||
if len(memories_list) > 0: |
|||
memories = torch.stack(memories_list).unsqueeze(0) |
|||
next_memories = torch.stack(next_memories_list).unsqueeze(0) |
|||
else: |
|||
memories = None |
|||
next_memories = None |
|||
# Q network memories are 0'ed out, since we don't have them during inference. |
|||
q_memories = ( |
|||
torch.zeros_like(next_memories) if next_memories is not None else None |
|||
) |
|||
|
|||
vis_obs: List[torch.Tensor] = [] |
|||
next_vis_obs: List[torch.Tensor] = [] |
|||
if self.policy.use_vis_obs: |
|||
vis_obs = [] |
|||
for idx, _ in enumerate( |
|||
self.policy.actor_critic.network_body.visual_encoders |
|||
): |
|||
vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx]) |
|||
vis_obs.append(vis_ob) |
|||
next_vis_ob = ModelUtils.list_to_tensor( |
|||
batch["next_visual_obs%d" % idx] |
|||
) |
|||
next_vis_obs.append(next_vis_ob) |
|||
|
|||
# Copy normalizers from policy |
|||
self.value_network.q1_network.network_body.copy_normalization( |
|||
self.policy.actor_critic.network_body |
|||
) |
|||
self.value_network.q2_network.network_body.copy_normalization( |
|||
self.policy.actor_critic.network_body |
|||
) |
|||
self.target_network.network_body.copy_normalization( |
|||
self.policy.actor_critic.network_body |
|||
) |
|||
( |
|||
sampled_actions, |
|||
log_probs, |
|||
entropies, |
|||
sampled_values, |
|||
_, |
|||
) = self.policy.sample_actions( |
|||
vec_obs, |
|||
vis_obs, |
|||
masks=act_masks, |
|||
memories=memories, |
|||
seq_len=self.policy.sequence_length, |
|||
all_log_probs=not self.policy.use_continuous_act, |
|||
) |
|||
if self.policy.use_continuous_act: |
|||
squeezed_actions = actions.squeeze(-1) |
|||
q1p_out, q2p_out = self.value_network( |
|||
vec_obs, |
|||
vis_obs, |
|||
sampled_actions, |
|||
memories=q_memories, |
|||
sequence_length=self.policy.sequence_length, |
|||
) |
|||
q1_out, q2_out = self.value_network( |
|||
vec_obs, |
|||
vis_obs, |
|||
squeezed_actions, |
|||
memories=q_memories, |
|||
sequence_length=self.policy.sequence_length, |
|||
) |
|||
q1_stream, q2_stream = q1_out, q2_out |
|||
else: |
|||
with torch.no_grad(): |
|||
q1p_out, q2p_out = self.value_network( |
|||
vec_obs, |
|||
vis_obs, |
|||
memories=q_memories, |
|||
sequence_length=self.policy.sequence_length, |
|||
) |
|||
q1_out, q2_out = self.value_network( |
|||
vec_obs, |
|||
vis_obs, |
|||
memories=q_memories, |
|||
sequence_length=self.policy.sequence_length, |
|||
) |
|||
q1_stream = self._condense_q_streams(q1_out, actions) |
|||
q2_stream = self._condense_q_streams(q2_out, actions) |
|||
|
|||
with torch.no_grad(): |
|||
target_values, _ = self.target_network( |
|||
next_vec_obs, |
|||
next_vis_obs, |
|||
memories=next_memories, |
|||
sequence_length=self.policy.sequence_length, |
|||
) |
|||
masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool) |
|||
use_discrete = not self.policy.use_continuous_act |
|||
dones = ModelUtils.list_to_tensor(batch["done"]) |
|||
|
|||
q1_loss, q2_loss = self.sac_q_loss( |
|||
q1_stream, q2_stream, target_values, dones, rewards, masks |
|||
) |
|||
value_loss = self.sac_value_loss( |
|||
log_probs, sampled_values, q1p_out, q2p_out, masks, use_discrete |
|||
) |
|||
policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks, use_discrete) |
|||
entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete) |
|||
|
|||
total_value_loss = q1_loss + q2_loss + value_loss |
|||
|
|||
decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step()) |
|||
ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr) |
|||
self.policy_optimizer.zero_grad() |
|||
policy_loss.backward() |
|||
self.policy_optimizer.step() |
|||
|
|||
ModelUtils.update_learning_rate(self.value_optimizer, decay_lr) |
|||
self.value_optimizer.zero_grad() |
|||
total_value_loss.backward() |
|||
self.value_optimizer.step() |
|||
|
|||
ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr) |
|||
self.entropy_optimizer.zero_grad() |
|||
entropy_loss.backward() |
|||
self.entropy_optimizer.step() |
|||
|
|||
# Update target network |
|||
self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau) |
|||
update_stats = { |
|||
"Losses/Policy Loss": abs(policy_loss.detach().cpu().numpy()), |
|||
"Losses/Value Loss": value_loss.detach().cpu().numpy(), |
|||
"Losses/Q1 Loss": q1_loss.detach().cpu().numpy(), |
|||
"Losses/Q2 Loss": q2_loss.detach().cpu().numpy(), |
|||
"Policy/Entropy Coeff": torch.exp(self._log_ent_coef) |
|||
.detach() |
|||
.cpu() |
|||
.numpy(), |
|||
"Policy/Learning Rate": decay_lr, |
|||
} |
|||
|
|||
for signal in self.reward_signals.values(): |
|||
signal.update(batch) |
|||
|
|||
return update_stats |
|||
|
|||
def update_reward_signals( |
|||
self, reward_signal_minibatches: Mapping[str, AgentBuffer], num_sequences: int |
|||
) -> Dict[str, float]: |
|||
return {} |
|||
|
|||
def get_modules(self): |
|||
return { |
|||
"Optimizer:value_network": self.value_network, |
|||
"Optimizer:target_network": self.target_network, |
|||
"Optimizer:policy_optimizer": self.policy_optimizer, |
|||
"Optimizer:value_optimizer": self.value_optimizer, |
|||
"Optimizer:entropy_optimizer": self.entropy_optimizer, |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 11fe037a02b4a483cb9342c3454232cd |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: fcb7a51f0d5f8404db7b85bd35ecc1fb |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
/// <summary>
|
|||
/// Settings that define the observations generated for physics-based sensors.
|
|||
/// </summary>
|
|||
[Serializable] |
|||
public struct PhysicsSensorSettings |
|||
{ |
|||
/// <summary>
|
|||
/// Whether to use model space (relative to the root body) translations as observations.
|
|||
/// </summary>
|
|||
public bool UseModelSpaceTranslations; |
|||
|
|||
/// <summary>
|
|||
/// Whether to use model space (relative to the root body) rotations as observations.
|
|||
/// </summary>
|
|||
public bool UseModelSpaceRotations; |
|||
|
|||
/// <summary>
|
|||
/// Whether to use local space (relative to the parent body) translations as observations.
|
|||
/// </summary>
|
|||
public bool UseLocalSpaceTranslations; |
|||
|
|||
/// <summary>
|
|||
/// Whether to use local space (relative to the parent body) translations as observations.
|
|||
/// </summary>
|
|||
public bool UseLocalSpaceRotations; |
|||
|
|||
/// <summary>
|
|||
/// Whether to use model space (relative to the root body) linear velocities as observations.
|
|||
/// </summary>
|
|||
public bool UseModelSpaceLinearVelocity; |
|||
|
|||
/// <summary>
|
|||
/// Whether to use local space (relative to the parent body) linear velocities as observations.
|
|||
/// </summary>
|
|||
public bool UseLocalSpaceLinearVelocity; |
|||
|
|||
/// <summary>
|
|||
/// Whether to use joint-specific positions and angles as observations.
|
|||
/// </summary>
|
|||
public bool UseJointPositionsAndAngles; |
|||
|
|||
/// <summary>
|
|||
/// Whether to use the joint forces and torques that are applied by the solver as observations.
|
|||
/// </summary>
|
|||
public bool UseJointForces; |
|||
|
|||
/// <summary>
|
|||
/// Creates a PhysicsSensorSettings with reasonable default values.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
public static PhysicsSensorSettings Default() |
|||
{ |
|||
return new PhysicsSensorSettings |
|||
{ |
|||
UseModelSpaceTranslations = true, |
|||
UseModelSpaceRotations = true, |
|||
}; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Whether any model space observations are being used.
|
|||
/// </summary>
|
|||
public bool UseModelSpace |
|||
{ |
|||
get { return UseModelSpaceTranslations || UseModelSpaceRotations || UseModelSpaceLinearVelocity; } |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Whether any local space observations are being used.
|
|||
/// </summary>
|
|||
public bool UseLocalSpace |
|||
{ |
|||
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; } |
|||
} |
|||
} |
|||
|
|||
internal static class ObservationWriterPhysicsExtensions |
|||
{ |
|||
/// <summary>
|
|||
/// Utility method for writing a PoseExtractor to an ObservationWriter.
|
|||
/// </summary>
|
|||
/// <param name="writer"></param>
|
|||
/// <param name="settings"></param>
|
|||
/// <param name="poseExtractor"></param>
|
|||
/// <param name="baseOffset">The offset into the ObservationWriter to start writing at.</param>
|
|||
/// <returns>The number of observations written.</returns>
|
|||
public static int WritePoses(this ObservationWriter writer, PhysicsSensorSettings settings, PoseExtractor poseExtractor, int baseOffset = 0) |
|||
{ |
|||
var offset = baseOffset; |
|||
if (settings.UseModelSpace) |
|||
{ |
|||
foreach (var pose in poseExtractor.GetEnabledModelSpacePoses()) |
|||
{ |
|||
if (settings.UseModelSpaceTranslations) |
|||
{ |
|||
writer.Add(pose.position, offset); |
|||
offset += 3; |
|||
} |
|||
|
|||
if (settings.UseModelSpaceRotations) |
|||
{ |
|||
writer.Add(pose.rotation, offset); |
|||
offset += 4; |
|||
} |
|||
} |
|||
|
|||
foreach(var vel in poseExtractor.GetEnabledModelSpaceVelocities()) |
|||
{ |
|||
if (settings.UseModelSpaceLinearVelocity) |
|||
{ |
|||
writer.Add(vel, offset); |
|||
offset += 3; |
|||
} |
|||
} |
|||
} |
|||
|
|||
if (settings.UseLocalSpace) |
|||
{ |
|||
foreach (var pose in poseExtractor.GetEnabledLocalSpacePoses()) |
|||
{ |
|||
if (settings.UseLocalSpaceTranslations) |
|||
{ |
|||
writer.Add(pose.position, offset); |
|||
offset += 3; |
|||
} |
|||
|
|||
if (settings.UseLocalSpaceRotations) |
|||
{ |
|||
writer.Add(pose.rotation, offset); |
|||
offset += 4; |
|||
} |
|||
} |
|||
|
|||
foreach(var vel in poseExtractor.GetEnabledLocalSpaceVelocities()) |
|||
{ |
|||
if (settings.UseLocalSpaceLinearVelocity) |
|||
{ |
|||
writer.Add(vel, offset); |
|||
offset += 3; |
|||
} |
|||
} |
|||
} |
|||
|
|||
return offset - baseOffset; |
|||
} |
|||
} |
|||
} |
|
|||
#if UNITY_2020_1_OR_NEWER
|
|||
|
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
/// <summary>
|
|||
/// Utility class to track a hierarchy of ArticulationBodies.
|
|||
/// </summary>
|
|||
public class ArticulationBodyPoseExtractor : PoseExtractor |
|||
{ |
|||
ArticulationBody[] m_Bodies; |
|||
|
|||
public ArticulationBodyPoseExtractor(ArticulationBody rootBody) |
|||
{ |
|||
if (rootBody == null) |
|||
{ |
|||
return; |
|||
} |
|||
|
|||
if (!rootBody.isRoot) |
|||
{ |
|||
Debug.Log("Must pass ArticulationBody.isRoot"); |
|||
return; |
|||
} |
|||
|
|||
var bodies = rootBody.GetComponentsInChildren <ArticulationBody>(); |
|||
if (bodies[0] != rootBody) |
|||
{ |
|||
Debug.Log("Expected root body at index 0"); |
|||
return; |
|||
} |
|||
|
|||
var numBodies = bodies.Length; |
|||
m_Bodies = bodies; |
|||
int[] parentIndices = new int[numBodies]; |
|||
parentIndices[0] = -1; |
|||
|
|||
var bodyToIndex = new Dictionary<ArticulationBody, int>(); |
|||
for (var i = 0; i < numBodies; i++) |
|||
{ |
|||
bodyToIndex[m_Bodies[i]] = i; |
|||
} |
|||
|
|||
for (var i = 1; i < numBodies; i++) |
|||
{ |
|||
var currentArticBody = m_Bodies[i]; |
|||
// Component.GetComponentInParent will consider the provided object as well.
|
|||
// So start looking from the parent.
|
|||
var currentGameObject = currentArticBody.gameObject; |
|||
var parentGameObject = currentGameObject.transform.parent; |
|||
var parentArticBody = parentGameObject.GetComponentInParent<ArticulationBody>(); |
|||
parentIndices[i] = bodyToIndex[parentArticBody]; |
|||
} |
|||
|
|||
Setup(parentIndices); |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
protected internal override Vector3 GetLinearVelocityAt(int index) |
|||
{ |
|||
return m_Bodies[index].velocity; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
protected internal override Pose GetPoseAt(int index) |
|||
{ |
|||
var body = m_Bodies[index]; |
|||
var go = body.gameObject; |
|||
var t = go.transform; |
|||
return new Pose { rotation = t.rotation, position = t.position }; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
protected internal override Object GetObjectAt(int index) |
|||
{ |
|||
return m_Bodies[index]; |
|||
} |
|||
|
|||
internal ArticulationBody[] Bodies => m_Bodies; |
|||
|
|||
internal IEnumerable<ArticulationBody> GetEnabledArticulationBodies() |
|||
{ |
|||
if (m_Bodies == null) |
|||
{ |
|||
yield break; |
|||
} |
|||
|
|||
for (var i = 0; i < m_Bodies.Length; i++) |
|||
{ |
|||
var articBody = m_Bodies[i]; |
|||
if (articBody == null) |
|||
{ |
|||
// Ignore a virtual root.
|
|||
continue; |
|||
} |
|||
|
|||
if (IsPoseEnabled(i)) |
|||
{ |
|||
yield return articBody; |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
#endif // UNITY_2020_1_OR_NEWER
|
|
|||
#if UNITY_2020_1_OR_NEWER
|
|||
using UnityEngine; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
public class ArticulationBodySensorComponent : SensorComponent |
|||
{ |
|||
public ArticulationBody RootBody; |
|||
|
|||
[SerializeField] |
|||
public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default(); |
|||
public string sensorName; |
|||
|
|||
/// <summary>
|
|||
/// Creates a PhysicsBodySensor.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
public override ISensor CreateSensor() |
|||
{ |
|||
return new PhysicsBodySensor(RootBody, Settings, sensorName); |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public override int[] GetObservationShape() |
|||
{ |
|||
if (RootBody == null) |
|||
{ |
|||
return new[] { 0 }; |
|||
} |
|||
|
|||
// TODO static method in PhysicsBodySensor?
|
|||
// TODO only update PoseExtractor when body changes?
|
|||
var poseExtractor = new ArticulationBodyPoseExtractor(RootBody); |
|||
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings); |
|||
var numJointObservations = 0; |
|||
|
|||
foreach(var articBody in poseExtractor.GetEnabledArticulationBodies()) |
|||
{ |
|||
numJointObservations += ArticulationBodyJointExtractor.NumObservations(articBody, Settings); |
|||
} |
|||
return new[] { numPoseObservations + numJointObservations }; |
|||
} |
|||
} |
|||
|
|||
} |
|||
#endif // UNITY_2020_1_OR_NEWER
|
|
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
/// <summary>
|
|||
/// ISensor implementation that generates observations for a group of Rigidbodies or ArticulationBodies.
|
|||
/// </summary>
|
|||
public class PhysicsBodySensor : ISensor |
|||
{ |
|||
int[] m_Shape; |
|||
string m_SensorName; |
|||
|
|||
PoseExtractor m_PoseExtractor; |
|||
List<IJointExtractor> m_JointExtractors; |
|||
PhysicsSensorSettings m_Settings; |
|||
|
|||
/// <summary>
|
|||
/// Construct a new PhysicsBodySensor
|
|||
/// </summary>
|
|||
/// <param name="poseExtractor"></param>
|
|||
/// <param name="settings"></param>
|
|||
/// <param name="sensorName"></param>
|
|||
public PhysicsBodySensor( |
|||
RigidBodyPoseExtractor poseExtractor, |
|||
PhysicsSensorSettings settings, |
|||
string sensorName |
|||
) |
|||
{ |
|||
m_PoseExtractor = poseExtractor; |
|||
m_SensorName = sensorName; |
|||
m_Settings = settings; |
|||
|
|||
var numJointExtractorObservations = 0; |
|||
m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses); |
|||
foreach(var rb in poseExtractor.GetEnabledRigidbodies()) |
|||
{ |
|||
var jointExtractor = new RigidBodyJointExtractor(rb); |
|||
numJointExtractorObservations += jointExtractor.NumObservations(settings); |
|||
m_JointExtractors.Add(jointExtractor); |
|||
} |
|||
|
|||
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings); |
|||
m_Shape = new[] { numTransformObservations + numJointExtractorObservations }; |
|||
} |
|||
|
|||
#if UNITY_2020_1_OR_NEWER
|
|||
public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null) |
|||
{ |
|||
var poseExtractor = new ArticulationBodyPoseExtractor(rootBody); |
|||
m_PoseExtractor = poseExtractor; |
|||
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootBody?.name}" : sensorName; |
|||
m_Settings = settings; |
|||
|
|||
var numJointExtractorObservations = 0; |
|||
m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses); |
|||
foreach(var articBody in poseExtractor.GetEnabledArticulationBodies()) |
|||
{ |
|||
var jointExtractor = new ArticulationBodyJointExtractor(articBody); |
|||
numJointExtractorObservations += jointExtractor.NumObservations(settings); |
|||
m_JointExtractors.Add(jointExtractor); |
|||
} |
|||
|
|||
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings); |
|||
m_Shape = new[] { numTransformObservations + numJointExtractorObservations }; |
|||
} |
|||
#endif
|
|||
|
|||
/// <inheritdoc/>
|
|||
public int[] GetObservationShape() |
|||
{ |
|||
return m_Shape; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public int Write(ObservationWriter writer) |
|||
{ |
|||
var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor); |
|||
foreach (var jointExtractor in m_JointExtractors) |
|||
{ |
|||
numWritten += jointExtractor.Write(m_Settings, writer, numWritten); |
|||
} |
|||
return numWritten; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public byte[] GetCompressedObservation() |
|||
{ |
|||
return null; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public void Update() |
|||
{ |
|||
if (m_Settings.UseModelSpace) |
|||
{ |
|||
m_PoseExtractor.UpdateModelSpacePoses(); |
|||
} |
|||
|
|||
if (m_Settings.UseLocalSpace) |
|||
{ |
|||
m_PoseExtractor.UpdateLocalSpacePoses(); |
|||
} |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public void Reset() {} |
|||
|
|||
/// <inheritdoc/>
|
|||
public SensorCompressionType GetCompressionType() |
|||
{ |
|||
return SensorCompressionType.None; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public string GetName() |
|||
{ |
|||
return m_SensorName; |
|||
} |
|||
} |
|||
} |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
using Object = UnityEngine.Object; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
/// <summary>
|
|||
/// Abstract class for managing the transforms of a hierarchy of objects.
|
|||
/// This could be GameObjects or Monobehaviours in the scene graph, but this is
|
|||
/// not a requirement; for example, the objects could be rigid bodies whose hierarchy
|
|||
/// is defined by Joint configurations.
|
|||
///
|
|||
/// Poses are either considered in model space, which is relative to a root body,
|
|||
/// or in local space, which is relative to their parent.
|
|||
/// </summary>
|
|||
public abstract class PoseExtractor |
|||
{ |
|||
int[] m_ParentIndices; |
|||
Pose[] m_ModelSpacePoses; |
|||
Pose[] m_LocalSpacePoses; |
|||
|
|||
Vector3[] m_ModelSpaceLinearVelocities; |
|||
Vector3[] m_LocalSpaceLinearVelocities; |
|||
|
|||
bool[] m_PoseEnabled; |
|||
|
|||
|
|||
/// <summary>
|
|||
/// Read iterator for the enabled model space transforms.
|
|||
/// </summary>
|
|||
public IEnumerable<Pose> GetEnabledModelSpacePoses() |
|||
{ |
|||
if (m_ModelSpacePoses == null) |
|||
{ |
|||
yield break; |
|||
} |
|||
|
|||
for (var i = 0; i < m_ModelSpacePoses.Length; i++) |
|||
{ |
|||
if (m_PoseEnabled[i]) |
|||
{ |
|||
yield return m_ModelSpacePoses[i]; |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Read iterator for the enabled local space transforms.
|
|||
/// </summary>
|
|||
public IEnumerable<Pose> GetEnabledLocalSpacePoses() |
|||
{ |
|||
if (m_LocalSpacePoses == null) |
|||
{ |
|||
yield break; |
|||
} |
|||
|
|||
for (var i = 0; i < m_LocalSpacePoses.Length; i++) |
|||
{ |
|||
if (m_PoseEnabled[i]) |
|||
{ |
|||
yield return m_LocalSpacePoses[i]; |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Read iterator for the enabled model space linear velocities.
|
|||
/// </summary>
|
|||
public IEnumerable<Vector3> GetEnabledModelSpaceVelocities() |
|||
{ |
|||
if (m_ModelSpaceLinearVelocities == null) |
|||
{ |
|||
yield break; |
|||
} |
|||
|
|||
for (var i = 0; i < m_ModelSpaceLinearVelocities.Length; i++) |
|||
{ |
|||
if (m_PoseEnabled[i]) |
|||
{ |
|||
yield return m_ModelSpaceLinearVelocities[i]; |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Read iterator for the enabled local space linear velocities.
|
|||
/// </summary>
|
|||
public IEnumerable<Vector3> GetEnabledLocalSpaceVelocities() |
|||
{ |
|||
if (m_LocalSpaceLinearVelocities == null) |
|||
{ |
|||
yield break; |
|||
} |
|||
|
|||
for (var i = 0; i < m_LocalSpaceLinearVelocities.Length; i++) |
|||
{ |
|||
if (m_PoseEnabled[i]) |
|||
{ |
|||
yield return m_LocalSpaceLinearVelocities[i]; |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Number of enabled poses in the hierarchy (read-only).
|
|||
/// </summary>
|
|||
public int NumEnabledPoses |
|||
{ |
|||
get |
|||
{ |
|||
if (m_PoseEnabled == null) |
|||
{ |
|||
return 0; |
|||
} |
|||
|
|||
var numEnabled = 0; |
|||
for (var i = 0; i < m_PoseEnabled.Length; i++) |
|||
{ |
|||
numEnabled += m_PoseEnabled[i] ? 1 : 0; |
|||
} |
|||
|
|||
return numEnabled; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Number of total poses in the hierarchy (read-only).
|
|||
/// </summary>
|
|||
public int NumPoses |
|||
{ |
|||
get { return m_ModelSpacePoses?.Length ?? 0; } |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Get the parent index of the body at the specified index.
|
|||
/// </summary>
|
|||
/// <param name="index"></param>
|
|||
/// <returns></returns>
|
|||
public int GetParentIndex(int index) |
|||
{ |
|||
if (m_ParentIndices == null) |
|||
{ |
|||
throw new NullReferenceException("No parent indices set"); |
|||
} |
|||
|
|||
return m_ParentIndices[index]; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Set whether the pose at the given index is enabled or disabled for observations.
|
|||
/// </summary>
|
|||
/// <param name="index"></param>
|
|||
/// <param name="val"></param>
|
|||
public void SetPoseEnabled(int index, bool val) |
|||
{ |
|||
m_PoseEnabled[index] = val; |
|||
} |
|||
|
|||
public bool IsPoseEnabled(int index) |
|||
{ |
|||
return m_PoseEnabled[index]; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Initialize with the mapping of parent indices.
|
|||
/// The 0th element is assumed to be -1, indicating that it's the root.
|
|||
/// </summary>
|
|||
/// <param name="parentIndices"></param>
|
|||
protected void Setup(int[] parentIndices) |
|||
{ |
|||
#if DEBUG
|
|||
if (parentIndices[0] != -1) |
|||
{ |
|||
throw new UnityAgentsException($"Expected parentIndices[0] to be -1, got {parentIndices[0]}"); |
|||
} |
|||
#endif
|
|||
m_ParentIndices = parentIndices; |
|||
var numPoses = parentIndices.Length; |
|||
m_ModelSpacePoses = new Pose[numPoses]; |
|||
m_LocalSpacePoses = new Pose[numPoses]; |
|||
|
|||
m_ModelSpaceLinearVelocities = new Vector3[numPoses]; |
|||
m_LocalSpaceLinearVelocities = new Vector3[numPoses]; |
|||
|
|||
m_PoseEnabled = new bool[numPoses]; |
|||
// All poses are enabled by default. Generally we'll want to disable the root though.
|
|||
for (var i = 0; i < numPoses; i++) |
|||
{ |
|||
m_PoseEnabled[i] = true; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Return the world space Pose of the i'th object.
|
|||
/// </summary>
|
|||
/// <param name="index"></param>
|
|||
/// <returns></returns>
|
|||
protected internal abstract Pose GetPoseAt(int index); |
|||
|
|||
/// <summary>
|
|||
/// Return the world space linear velocity of the i'th object.
|
|||
/// </summary>
|
|||
/// <param name="index"></param>
|
|||
/// <returns></returns>
|
|||
protected internal abstract Vector3 GetLinearVelocityAt(int index); |
|||
|
|||
/// <summary>
|
|||
/// Return the underlying object at the given index. This is only
|
|||
/// used for display in the inspector.
|
|||
/// </summary>
|
|||
/// <param name="index"></param>
|
|||
/// <returns></returns>
|
|||
protected internal virtual Object GetObjectAt(int index) |
|||
{ |
|||
return null; |
|||
} |
|||
|
|||
|
|||
/// <summary>
|
|||
/// Update the internal model space transform storage based on the underlying system.
|
|||
/// </summary>
|
|||
public void UpdateModelSpacePoses() |
|||
{ |
|||
using (TimerStack.Instance.Scoped("UpdateModelSpacePoses")) |
|||
{ |
|||
if (m_ModelSpacePoses == null) |
|||
{ |
|||
return; |
|||
} |
|||
|
|||
var rootWorldTransform = GetPoseAt(0); |
|||
var worldToModel = rootWorldTransform.Inverse(); |
|||
var rootLinearVel = GetLinearVelocityAt(0); |
|||
|
|||
for (var i = 0; i < m_ModelSpacePoses.Length; i++) |
|||
{ |
|||
var currentWorldSpacePose = GetPoseAt(i); |
|||
var currentModelSpacePose = worldToModel.Multiply(currentWorldSpacePose); |
|||
m_ModelSpacePoses[i] = currentModelSpacePose; |
|||
|
|||
var currentBodyLinearVel = GetLinearVelocityAt(i); |
|||
var relativeVelocity = currentBodyLinearVel - rootLinearVel; |
|||
m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity; |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Update the internal model space transform storage based on the underlying system.
|
|||
/// </summary>
|
|||
public void UpdateLocalSpacePoses() |
|||
{ |
|||
using (TimerStack.Instance.Scoped("UpdateLocalSpacePoses")) |
|||
{ |
|||
if (m_LocalSpacePoses == null) |
|||
{ |
|||
return; |
|||
} |
|||
|
|||
for (var i = 0; i < m_LocalSpacePoses.Length; i++) |
|||
{ |
|||
if (m_ParentIndices[i] != -1) |
|||
{ |
|||
var parentTransform = GetPoseAt(m_ParentIndices[i]); |
|||
// This is slightly inefficient, since for a body with multiple children, we'll end up inverting
|
|||
// the transform multiple times. Might be able to trade space for perf here.
|
|||
var invParent = parentTransform.Inverse(); |
|||
var currentTransform = GetPoseAt(i); |
|||
m_LocalSpacePoses[i] = invParent.Multiply(currentTransform); |
|||
|
|||
var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]); |
|||
var currentLinearVel = GetLinearVelocityAt(i); |
|||
m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel); |
|||
} |
|||
else |
|||
{ |
|||
m_LocalSpacePoses[i] = Pose.identity; |
|||
m_LocalSpaceLinearVelocities[i] = Vector3.zero; |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Compute the number of floats needed to represent the poses for the given PhysicsSensorSettings.
|
|||
/// </summary>
|
|||
/// <param name="settings"></param>
|
|||
/// <returns></returns>
|
|||
public int GetNumPoseObservations(PhysicsSensorSettings settings) |
|||
{ |
|||
int obsPerPose = 0; |
|||
obsPerPose += settings.UseModelSpaceTranslations ? 3 : 0; |
|||
obsPerPose += settings.UseModelSpaceRotations ? 4 : 0; |
|||
obsPerPose += settings.UseLocalSpaceTranslations ? 3 : 0; |
|||
obsPerPose += settings.UseLocalSpaceRotations ? 4 : 0; |
|||
|
|||
obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0; |
|||
obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0; |
|||
|
|||
return NumEnabledPoses * obsPerPose; |
|||
} |
|||
|
|||
internal void DrawModelSpace(Vector3 offset) |
|||
{ |
|||
UpdateLocalSpacePoses(); |
|||
UpdateModelSpacePoses(); |
|||
|
|||
var pose = m_ModelSpacePoses; |
|||
var localPose = m_LocalSpacePoses; |
|||
for (var i = 0; i < pose.Length; i++) |
|||
{ |
|||
var current = pose[i]; |
|||
if (m_ParentIndices[i] == -1) |
|||
{ |
|||
continue; |
|||
} |
|||
|
|||
var parent = pose[m_ParentIndices[i]]; |
|||
Debug.DrawLine(current.position + offset, parent.position + offset, Color.cyan); |
|||
var localUp = localPose[i].rotation * Vector3.up; |
|||
var localFwd = localPose[i].rotation * Vector3.forward; |
|||
var localRight = localPose[i].rotation * Vector3.right; |
|||
Debug.DrawLine(current.position+offset, current.position+offset+.1f*localUp, Color.red); |
|||
Debug.DrawLine(current.position+offset, current.position+offset+.1f*localFwd, Color.green); |
|||
Debug.DrawLine(current.position+offset, current.position+offset+.1f*localRight, Color.blue); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Simplified representation of the a node in the hierarchy for display.
|
|||
/// </summary>
|
|||
internal struct DisplayNode |
|||
{ |
|||
/// <summary>
|
|||
/// Underlying object in the hierarchy. Pass to EditorGUIUtility.ObjectContent() for display.
|
|||
/// </summary>
|
|||
public Object NodeObject; |
|||
|
|||
/// <summary>
|
|||
/// Whether the poses for the object are enabled.
|
|||
/// </summary>
|
|||
public bool Enabled; |
|||
|
|||
/// <summary>
|
|||
/// Depth in the hierarchy, used for adjusting the indent level.
|
|||
/// </summary>
|
|||
public int Depth; |
|||
|
|||
/// <summary>
|
|||
/// The index of the corresponding object in the PoseExtractor.
|
|||
/// </summary>
|
|||
public int OriginalIndex; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Get a list of display nodes in depth-first order.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
internal IList<DisplayNode> GetDisplayNodes() |
|||
{ |
|||
if (NumPoses == 0) |
|||
{ |
|||
return Array.Empty<DisplayNode>(); |
|||
} |
|||
var nodesOut = new List<DisplayNode>(NumPoses); |
|||
|
|||
// List of children for each node
|
|||
var tree = new Dictionary<int, List<int>>(); |
|||
for (var i = 0; i < NumPoses; i++) |
|||
{ |
|||
var parent = GetParentIndex(i); |
|||
if (i == -1) |
|||
{ |
|||
continue; |
|||
} |
|||
|
|||
if (!tree.ContainsKey(parent)) |
|||
{ |
|||
tree[parent] = new List<int>(); |
|||
} |
|||
tree[parent].Add(i); |
|||
} |
|||
|
|||
// Store (index, depth) in the stack
|
|||
var stack = new Stack<(int, int)>(); |
|||
stack.Push((0, 0)); |
|||
|
|||
while (stack.Count != 0) |
|||
{ |
|||
var (current, depth) = stack.Pop(); |
|||
var obj = GetObjectAt(current); |
|||
|
|||
var node = new DisplayNode |
|||
{ |
|||
NodeObject = obj, |
|||
Enabled = IsPoseEnabled(current), |
|||
OriginalIndex = current, |
|||
Depth = depth |
|||
}; |
|||
nodesOut.Add(node); |
|||
|
|||
// Add children
|
|||
if (tree.ContainsKey(current)) |
|||
{ |
|||
// Push to the stack in reverse order
|
|||
var children = tree[current]; |
|||
for (var childIdx = children.Count-1; childIdx >= 0; childIdx--) |
|||
{ |
|||
stack.Push((children[childIdx], depth+1)); |
|||
} |
|||
} |
|||
|
|||
// Safety check
|
|||
// This shouldn't even happen, but in case we have a cycle in the graph
|
|||
// exit instead of looping forever and eating up all the memory.
|
|||
if (nodesOut.Count > NumPoses) |
|||
{ |
|||
return nodesOut; |
|||
} |
|||
} |
|||
|
|||
return nodesOut; |
|||
} |
|||
|
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Extension methods for the Pose struct, in order to improve the readability of some math.
|
|||
/// </summary>
|
|||
public static class PoseExtensions |
|||
{ |
|||
/// <summary>
|
|||
/// Compute the inverse of a Pose. For any Pose P,
|
|||
/// P.Inverse() * P
|
|||
/// will equal the identity pose (within tolerance).
|
|||
/// </summary>
|
|||
/// <param name="pose"></param>
|
|||
/// <returns></returns>
|
|||
public static Pose Inverse(this Pose pose) |
|||
{ |
|||
var rotationInverse = Quaternion.Inverse(pose.rotation); |
|||
var translationInverse = -(rotationInverse * pose.position); |
|||
return new Pose { rotation = rotationInverse, position = translationInverse }; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// This is equivalent to Pose.GetTransformedBy(), but keeps the order more intuitive.
|
|||
/// </summary>
|
|||
/// <param name="pose"></param>
|
|||
/// <param name="rhs"></param>
|
|||
/// <returns></returns>
|
|||
public static Pose Multiply(this Pose pose, Pose rhs) |
|||
{ |
|||
return rhs.GetTransformedBy(pose); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Transform the vector by the pose. Conceptually this is equivalent to treating the Pose
|
|||
/// as a 4x4 matrix and multiplying the augmented vector.
|
|||
/// See https://en.wikipedia.org/wiki/Affine_transformation#Augmented_matrix for more details.
|
|||
/// </summary>
|
|||
/// <param name="pose"></param>
|
|||
/// <param name="rhs"></param>
|
|||
/// <returns></returns>
|
|||
public static Vector3 Multiply(this Pose pose, Vector3 rhs) |
|||
{ |
|||
return pose.rotation * rhs + pose.position; |
|||
} |
|||
|
|||
// TODO optimize inv(A)*B?
|
|||
} |
|||
} |
|
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
|
|||
/// <summary>
|
|||
/// Utility class to track a hierarchy of RigidBodies. These are assumed to have a root node,
|
|||
/// and child nodes are connect to their parents via Joints.
|
|||
/// </summary>
|
|||
public class RigidBodyPoseExtractor : PoseExtractor |
|||
{ |
|||
Rigidbody[] m_Bodies; |
|||
|
|||
/// <summary>
|
|||
/// Optional game object used to determine the root of the poses, separate from the actual Rigidbodies
|
|||
/// in the hierarchy. For locomotion
|
|||
/// </summary>
|
|||
GameObject m_VirtualRoot; |
|||
|
|||
/// <summary>
|
|||
/// Initialize given a root RigidBody.
|
|||
/// </summary>
|
|||
/// <param name="rootBody">The root Rigidbody. This has no Joints on it (but other Joints may connect to it).</param>
|
|||
/// <param name="rootGameObject">Optional GameObject used to find Rigidbodies in the hierarchy.</param>
|
|||
/// <param name="virtualRoot">Optional GameObject used to determine the root of the poses,
|
|||
/// separate from the actual Rigidbodies in the hierarchy. For locomotion tasks, with ragdolls, this provides
|
|||
/// a stabilized reference frame, which can improve learning.</param>
|
|||
/// <param name="enableBodyPoses">Optional mapping of whether a body's psoe should be enabled or not.</param>
|
|||
public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null, |
|||
GameObject virtualRoot = null, Dictionary<Rigidbody, bool> enableBodyPoses = null) |
|||
{ |
|||
if (rootBody == null) |
|||
{ |
|||
return; |
|||
} |
|||
|
|||
Rigidbody[] rbs; |
|||
Joint[] joints; |
|||
if (rootGameObject == null) |
|||
{ |
|||
rbs = rootBody.GetComponentsInChildren<Rigidbody>(); |
|||
joints = rootBody.GetComponentsInChildren <Joint>(); |
|||
} |
|||
else |
|||
{ |
|||
rbs = rootGameObject.GetComponentsInChildren<Rigidbody>(); |
|||
joints = rootGameObject.GetComponentsInChildren<Joint>(); |
|||
} |
|||
|
|||
if (rbs == null || rbs.Length == 0) |
|||
{ |
|||
Debug.Log("No rigid bodies found!"); |
|||
return; |
|||
} |
|||
|
|||
if (rbs[0] != rootBody) |
|||
{ |
|||
Debug.Log("Expected root body at index 0"); |
|||
return; |
|||
} |
|||
|
|||
// Adjust the array if we have a virtual root.
|
|||
// This will be at index 0, and the "real" root will be parented to it.
|
|||
if (virtualRoot != null) |
|||
{ |
|||
var extendedRbs = new Rigidbody[rbs.Length + 1]; |
|||
for (var i = 0; i < rbs.Length; i++) |
|||
{ |
|||
extendedRbs[i + 1] = rbs[i]; |
|||
} |
|||
|
|||
rbs = extendedRbs; |
|||
} |
|||
|
|||
var bodyToIndex = new Dictionary<Rigidbody, int>(rbs.Length); |
|||
var parentIndices = new int[rbs.Length]; |
|||
parentIndices[0] = -1; |
|||
|
|||
for (var i = 0; i < rbs.Length; i++) |
|||
{ |
|||
if(rbs[i] != null) |
|||
{ |
|||
bodyToIndex[rbs[i]] = i; |
|||
} |
|||
} |
|||
|
|||
foreach (var j in joints) |
|||
{ |
|||
var parent = j.connectedBody; |
|||
var child = j.GetComponent<Rigidbody>(); |
|||
|
|||
var parentIndex = bodyToIndex[parent]; |
|||
var childIndex = bodyToIndex[child]; |
|||
parentIndices[childIndex] = parentIndex; |
|||
} |
|||
|
|||
if (virtualRoot != null) |
|||
{ |
|||
// Make sure the original root treats the virtual root as its parent.
|
|||
parentIndices[1] = 0; |
|||
m_VirtualRoot = virtualRoot; |
|||
} |
|||
|
|||
m_Bodies = rbs; |
|||
Setup(parentIndices); |
|||
|
|||
// By default, ignore the root
|
|||
SetPoseEnabled(0, false); |
|||
|
|||
if (enableBodyPoses != null) |
|||
{ |
|||
foreach (var pair in enableBodyPoses) |
|||
{ |
|||
var rb = pair.Key; |
|||
if (bodyToIndex.TryGetValue(rb, out var index)) |
|||
{ |
|||
SetPoseEnabled(index, pair.Value); |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
protected internal override Vector3 GetLinearVelocityAt(int index) |
|||
{ |
|||
if (index == 0 && m_VirtualRoot != null) |
|||
{ |
|||
// No velocity on the virtual root
|
|||
return Vector3.zero; |
|||
} |
|||
return m_Bodies[index].velocity; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
protected internal override Pose GetPoseAt(int index) |
|||
{ |
|||
if (index == 0 && m_VirtualRoot != null) |
|||
{ |
|||
// Use the GameObject's world transform
|
|||
return new Pose |
|||
{ |
|||
rotation = m_VirtualRoot.transform.rotation, |
|||
position = m_VirtualRoot.transform.position |
|||
}; |
|||
} |
|||
|
|||
var body = m_Bodies[index]; |
|||
return new Pose { rotation = body.rotation, position = body.position }; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
protected internal override Object GetObjectAt(int index) |
|||
{ |
|||
if (index == 0 && m_VirtualRoot != null) |
|||
{ |
|||
return m_VirtualRoot; |
|||
} |
|||
return m_Bodies[index]; |
|||
} |
|||
|
|||
internal Rigidbody[] Bodies => m_Bodies; |
|||
|
|||
/// <summary>
|
|||
/// Get a dictionary indicating which Rigidbodies' poses are enabled or disabled.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
internal Dictionary<Rigidbody, bool> GetBodyPosesEnabled() |
|||
{ |
|||
var bodyPosesEnabled = new Dictionary<Rigidbody, bool>(m_Bodies.Length); |
|||
for (var i = 0; i < m_Bodies.Length; i++) |
|||
{ |
|||
var rb = m_Bodies[i]; |
|||
if (rb == null) |
|||
{ |
|||
continue; // skip virtual root
|
|||
} |
|||
|
|||
bodyPosesEnabled[rb] = IsPoseEnabled(i); |
|||
} |
|||
|
|||
return bodyPosesEnabled; |
|||
} |
|||
|
|||
internal IEnumerable<Rigidbody> GetEnabledRigidbodies() |
|||
{ |
|||
if (m_Bodies == null) |
|||
{ |
|||
yield break; |
|||
} |
|||
|
|||
for (var i = 0; i < m_Bodies.Length; i++) |
|||
{ |
|||
var rb = m_Bodies[i]; |
|||
if (rb == null) |
|||
{ |
|||
// Ignore a virtual root.
|
|||
continue; |
|||
} |
|||
|
|||
if (IsPoseEnabled(i)) |
|||
{ |
|||
yield return rb; |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
} |
|
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
/// <summary>
|
|||
/// Editor component that creates a PhysicsBodySensor for the Agent.
|
|||
/// </summary>
|
|||
public class RigidBodySensorComponent : SensorComponent |
|||
{ |
|||
/// <summary>
|
|||
/// The root Rigidbody of the system.
|
|||
/// </summary>
|
|||
public Rigidbody RootBody; |
|||
|
|||
/// <summary>
|
|||
/// Optional GameObject used to determine the root of the poses.
|
|||
/// </summary>
|
|||
public GameObject VirtualRoot; |
|||
|
|||
/// <summary>
|
|||
/// Settings defining what types of observations will be generated.
|
|||
/// </summary>
|
|||
[SerializeField] |
|||
public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default(); |
|||
|
|||
/// <summary>
|
|||
/// Optional sensor name. This must be unique for each Agent.
|
|||
/// </summary>
|
|||
[SerializeField] |
|||
public string sensorName; |
|||
|
|||
[SerializeField] |
|||
[HideInInspector] |
|||
RigidBodyPoseExtractor m_PoseExtractor; |
|||
|
|||
/// <summary>
|
|||
/// Creates a PhysicsBodySensor.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
public override ISensor CreateSensor() |
|||
{ |
|||
var _sensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{RootBody?.name}" : sensorName; |
|||
return new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName); |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public override int[] GetObservationShape() |
|||
{ |
|||
if (RootBody == null) |
|||
{ |
|||
return new[] { 0 }; |
|||
} |
|||
|
|||
var poseExtractor = GetPoseExtractor(); |
|||
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings); |
|||
|
|||
var numJointObservations = 0; |
|||
foreach(var rb in poseExtractor.GetEnabledRigidbodies()) |
|||
{ |
|||
var joint = rb.GetComponent<Joint>(); |
|||
numJointObservations += RigidBodyJointExtractor.NumObservations(rb, joint, Settings); |
|||
} |
|||
return new[] { numPoseObservations + numJointObservations }; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Get the DisplayNodes of the hierarchy.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
internal IList<PoseExtractor.DisplayNode> GetDisplayNodes() |
|||
{ |
|||
return GetPoseExtractor().GetDisplayNodes(); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Lazy construction of the PoseExtractor.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
RigidBodyPoseExtractor GetPoseExtractor() |
|||
{ |
|||
if (m_PoseExtractor == null) |
|||
{ |
|||
ResetPoseExtractor(); |
|||
} |
|||
|
|||
return m_PoseExtractor; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Reset the pose extractor, trying to keep the enabled state of the corresponding poses the same.
|
|||
/// </summary>
|
|||
internal void ResetPoseExtractor() |
|||
{ |
|||
// Get the current enabled state of each body, so that we can reinitialize with them.
|
|||
Dictionary<Rigidbody, bool> bodyPosesEnabled = null; |
|||
if (m_PoseExtractor != null) |
|||
{ |
|||
bodyPosesEnabled = m_PoseExtractor.GetBodyPosesEnabled(); |
|||
} |
|||
m_PoseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject, VirtualRoot, bodyPosesEnabled); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Toggle the pose at the given index.
|
|||
/// </summary>
|
|||
/// <param name="index"></param>
|
|||
/// <param name="enabled"></param>
|
|||
internal void SetPoseEnabled(int index, bool enabled) |
|||
{ |
|||
GetPoseExtractor().SetPoseEnabled(index, enabled); |
|||
} |
|||
} |
|||
|
|||
} |
|
|||
#if UNITY_2020_1_OR_NEWER
|
|||
using UnityEngine; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Extensions.Sensors; |
|||
|
|||
|
|||
namespace Unity.MLAgents.Extensions.Tests.Sensors |
|||
{ |
|||
|
|||
public class ArticulationBodySensorTests |
|||
{ |
|||
[Test] |
|||
public void TestNullRootBody() |
|||
{ |
|||
var gameObj = new GameObject(); |
|||
|
|||
var sensorComponent = gameObj.AddComponent<ArticulationBodySensorComponent>(); |
|||
var sensor = sensorComponent.CreateSensor(); |
|||
SensorTestHelper.CompareObservation(sensor, new float[0]); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestSingleBody() |
|||
{ |
|||
var gameObj = new GameObject(); |
|||
var articulationBody = gameObj.AddComponent<ArticulationBody>(); |
|||
var sensorComponent = gameObj.AddComponent<ArticulationBodySensorComponent>(); |
|||
sensorComponent.RootBody = articulationBody; |
|||
sensorComponent.Settings = new PhysicsSensorSettings |
|||
{ |
|||
UseModelSpaceLinearVelocity = true, |
|||
UseLocalSpaceTranslations = true, |
|||
UseLocalSpaceRotations = true |
|||
}; |
|||
|
|||
var sensor = sensorComponent.CreateSensor(); |
|||
sensor.Update(); |
|||
var expected = new[] |
|||
{ |
|||
0f, 0f, 0f, // ModelSpaceLinearVelocity
|
|||
0f, 0f, 0f, // LocalSpaceTranslations
|
|||
0f, 0f, 0f, 1f // LocalSpaceRotations
|
|||
}; |
|||
SensorTestHelper.CompareObservation(sensor, expected); |
|||
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestBodiesWithJoint() |
|||
{ |
|||
var rootObj = new GameObject(); |
|||
var rootArticBody = rootObj.AddComponent<ArticulationBody>(); |
|||
|
|||
var middleGamObj = new GameObject(); |
|||
var middleArticBody = middleGamObj.AddComponent<ArticulationBody>(); |
|||
middleArticBody.AddForce(new Vector3(0f, 1f, 0f)); |
|||
middleGamObj.transform.SetParent(rootObj.transform); |
|||
middleGamObj.transform.localPosition = new Vector3(13.37f, 0f, 0f); |
|||
middleArticBody.jointType = ArticulationJointType.RevoluteJoint; |
|||
|
|||
var leafGameObj = new GameObject(); |
|||
var leafArticBody = leafGameObj.AddComponent<ArticulationBody>(); |
|||
leafGameObj.transform.SetParent(middleGamObj.transform); |
|||
leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f); |
|||
leafArticBody.jointType = ArticulationJointType.PrismaticJoint; |
|||
leafArticBody.linearLockZ = ArticulationDofLock.LimitedMotion; |
|||
leafArticBody.zDrive = new ArticulationDrive |
|||
{ |
|||
lowerLimit = -3, |
|||
upperLimit = 1 |
|||
}; |
|||
|
|||
|
|||
#if UNITY_2020_2_OR_NEWER
|
|||
// ArticulationBody.velocity is read-only in 2020.1
|
|||
rootArticBody.velocity = new Vector3(1f, 0f, 0f); |
|||
middleArticBody.velocity = new Vector3(0f, 1f, 0f); |
|||
leafArticBody.velocity = new Vector3(0f, 0f, 1f); |
|||
#endif
|
|||
|
|||
var sensorComponent = rootObj.AddComponent<ArticulationBodySensorComponent>(); |
|||
sensorComponent.RootBody = rootArticBody; |
|||
sensorComponent.Settings = new PhysicsSensorSettings |
|||
{ |
|||
UseModelSpaceTranslations = true, |
|||
UseLocalSpaceTranslations = true, |
|||
#if UNITY_2020_2_OR_NEWER
|
|||
UseLocalSpaceLinearVelocity = true |
|||
#endif
|
|||
}; |
|||
|
|||
var sensor = sensorComponent.CreateSensor(); |
|||
sensor.Update(); |
|||
var expected = new[] |
|||
{ |
|||
// Model space
|
|||
0f, 0f, 0f, // Root pos
|
|||
13.37f, 0f, 0f, // Middle pos
|
|||
leafGameObj.transform.position.x, 0f, 0f, // Leaf pos
|
|||
|
|||
// Local space
|
|||
0f, 0f, 0f, // Root pos
|
|||
13.37f, 0f, 0f, // Attached pos
|
|||
4.2f, 0f, 0f, // Leaf pos
|
|||
|
|||
#if UNITY_2020_2_OR_NEWER
|
|||
0f, 0f, 0f, // Root vel
|
|||
-1f, 1f, 0f, // Attached vel
|
|||
0f, -1f, 1f // Leaf vel
|
|||
#endif
|
|||
}; |
|||
SensorTestHelper.CompareObservation(sensor, expected); |
|||
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); |
|||
|
|||
// Update the settings to only process joint observations
|
|||
sensorComponent.Settings = new PhysicsSensorSettings |
|||
{ |
|||
UseJointForces = true, |
|||
UseJointPositionsAndAngles = true, |
|||
}; |
|||
|
|||
sensor = sensorComponent.CreateSensor(); |
|||
sensor.Update(); |
|||
|
|||
expected = new[] |
|||
{ |
|||
// revolute
|
|||
0f, 1f, // joint1.position (sin and cos)
|
|||
0f, // joint1.force
|
|||
|
|||
// prismatic
|
|||
0.5f, // joint2.position (interpolate between limits)
|
|||
0f, // joint2.force
|
|||
}; |
|||
SensorTestHelper.CompareObservation(sensor, expected); |
|||
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); |
|||
} |
|||
} |
|||
} |
|||
#endif // #if UNITY_2020_1_OR_NEWER
|
|
|||
using UnityEngine; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Sensors; |
|||
using Unity.MLAgents.Extensions.Sensors; |
|||
|
|||
|
|||
namespace Unity.MLAgents.Extensions.Tests.Sensors |
|||
{ |
|||
|
|||
public static class SensorTestHelper |
|||
{ |
|||
public static void CompareObservation(ISensor sensor, float[] expected) |
|||
{ |
|||
string errorMessage; |
|||
bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage); |
|||
Assert.IsTrue(isOK, errorMessage); |
|||
} |
|||
} |
|||
|
|||
public class RigidBodySensorTests |
|||
{ |
|||
[Test] |
|||
public void TestNullRootBody() |
|||
{ |
|||
var gameObj = new GameObject(); |
|||
|
|||
var sensorComponent = gameObj.AddComponent<RigidBodySensorComponent>(); |
|||
var sensor = sensorComponent.CreateSensor(); |
|||
SensorTestHelper.CompareObservation(sensor, new float[0]); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestSingleRigidbody() |
|||
{ |
|||
var gameObj = new GameObject(); |
|||
var rootRb = gameObj.AddComponent<Rigidbody>(); |
|||
var sensorComponent = gameObj.AddComponent<RigidBodySensorComponent>(); |
|||
sensorComponent.RootBody = rootRb; |
|||
sensorComponent.Settings = new PhysicsSensorSettings |
|||
{ |
|||
UseModelSpaceLinearVelocity = true, |
|||
UseLocalSpaceTranslations = true, |
|||
UseLocalSpaceRotations = true |
|||
}; |
|||
|
|||
var sensor = sensorComponent.CreateSensor(); |
|||
sensor.Update(); |
|||
|
|||
// The root body is ignored since it always generates identity values
|
|||
// and there are no other bodies to generate observations.
|
|||
var expected = new float[0]; |
|||
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); |
|||
SensorTestHelper.CompareObservation(sensor, expected); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestBodiesWithJoint() |
|||
{ |
|||
var rootObj = new GameObject(); |
|||
var rootRb = rootObj.AddComponent<Rigidbody>(); |
|||
rootRb.velocity = new Vector3(1f, 0f, 0f); |
|||
|
|||
var middleGamObj = new GameObject(); |
|||
var middleRb = middleGamObj.AddComponent<Rigidbody>(); |
|||
middleRb.velocity = new Vector3(0f, 1f, 0f); |
|||
middleGamObj.transform.SetParent(rootObj.transform); |
|||
middleGamObj.transform.localPosition = new Vector3(13.37f, 0f, 0f); |
|||
var joint = middleGamObj.AddComponent<ConfigurableJoint>(); |
|||
joint.connectedBody = rootRb; |
|||
|
|||
var leafGameObj = new GameObject(); |
|||
var leafRb = leafGameObj.AddComponent<Rigidbody>(); |
|||
leafRb.velocity = new Vector3(0f, 0f, 1f); |
|||
leafGameObj.transform.SetParent(middleGamObj.transform); |
|||
leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f); |
|||
var joint2 = leafGameObj.AddComponent<ConfigurableJoint>(); |
|||
joint2.connectedBody = middleRb; |
|||
|
|||
var virtualRoot = new GameObject(); |
|||
|
|||
var sensorComponent = rootObj.AddComponent<RigidBodySensorComponent>(); |
|||
sensorComponent.RootBody = rootRb; |
|||
sensorComponent.Settings = new PhysicsSensorSettings |
|||
{ |
|||
UseModelSpaceTranslations = true, |
|||
UseLocalSpaceTranslations = true, |
|||
UseLocalSpaceLinearVelocity = true |
|||
}; |
|||
sensorComponent.VirtualRoot = virtualRoot; |
|||
|
|||
var sensor = sensorComponent.CreateSensor(); |
|||
sensor.Update(); |
|||
|
|||
// Note that the VirtualRoot is ignored from the observations
|
|||
var expected = new[] |
|||
{ |
|||
// Model space
|
|||
0f, 0f, 0f, // Root pos
|
|||
13.37f, 0f, 0f, // Middle pos
|
|||
leafGameObj.transform.position.x, 0f, 0f, // Leaf pos
|
|||
|
|||
// Local space
|
|||
0f, 0f, 0f, // Root pos
|
|||
13.37f, 0f, 0f, // Attached pos
|
|||
4.2f, 0f, 0f, // Leaf pos
|
|||
|
|||
1f, 0f, 0f, // Root vel (relative to virtual root)
|
|||
-1f, 1f, 0f, // Attached vel
|
|||
0f, -1f, 1f // Leaf vel
|
|||
}; |
|||
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); |
|||
SensorTestHelper.CompareObservation(sensor, expected); |
|||
|
|||
// Update the settings to only process joint observations
|
|||
sensorComponent.Settings = new PhysicsSensorSettings |
|||
{ |
|||
UseJointPositionsAndAngles = true, |
|||
UseJointForces = true, |
|||
}; |
|||
|
|||
sensor = sensorComponent.CreateSensor(); |
|||
sensor.Update(); |
|||
|
|||
expected = new[] |
|||
{ |
|||
0f, 0f, 0f, // joint1.force
|
|||
0f, 0f, 0f, // joint1.torque
|
|||
0f, 0f, 0f, // joint2.force
|
|||
0f, 0f, 0f, // joint2.torque
|
|||
}; |
|||
SensorTestHelper.CompareObservation(sensor, expected); |
|||
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); |
|||
|
|||
} |
|||
} |
|||
} |
|
|||
using System; |
|||
using UnityEngine; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Extensions.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Tests.Sensors |
|||
{ |
|||
public class PoseExtractorTests |
|||
{ |
|||
|
|||
class BasicPoseExtractor : PoseExtractor |
|||
{ |
|||
protected internal override Pose GetPoseAt(int index) |
|||
{ |
|||
return Pose.identity; |
|||
} |
|||
|
|||
protected internal override Vector3 GetLinearVelocityAt(int index) |
|||
{ |
|||
return Vector3.zero; |
|||
} |
|||
} |
|||
|
|||
class UselessPoseExtractor : BasicPoseExtractor |
|||
{ |
|||
public void Init(int[] parentIndices) |
|||
{ |
|||
Setup(parentIndices); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestEmptyExtractor() |
|||
{ |
|||
var poseExtractor = new UselessPoseExtractor(); |
|||
|
|||
// These should be no-ops
|
|||
poseExtractor.UpdateLocalSpacePoses(); |
|||
poseExtractor.UpdateModelSpacePoses(); |
|||
|
|||
Assert.AreEqual(0, poseExtractor.NumPoses); |
|||
|
|||
// Iterating through poses and velocities should be an empty loop
|
|||
foreach (var pose in poseExtractor.GetEnabledModelSpacePoses()) |
|||
{ |
|||
throw new UnityAgentsException("This shouldn't happen"); |
|||
} |
|||
|
|||
foreach (var pose in poseExtractor.GetEnabledLocalSpacePoses()) |
|||
{ |
|||
throw new UnityAgentsException("This shouldn't happen"); |
|||
} |
|||
|
|||
foreach (var vel in poseExtractor.GetEnabledModelSpaceVelocities()) |
|||
{ |
|||
throw new UnityAgentsException("This shouldn't happen"); |
|||
} |
|||
|
|||
foreach (var vel in poseExtractor.GetEnabledLocalSpaceVelocities()) |
|||
{ |
|||
throw new UnityAgentsException("This shouldn't happen"); |
|||
} |
|||
|
|||
// Getting a parent index should throw an index exception
|
|||
Assert.Throws <NullReferenceException>( |
|||
() => poseExtractor.GetParentIndex(0) |
|||
); |
|||
|
|||
// DisplayNodes should be empty
|
|||
var displayNodes = poseExtractor.GetDisplayNodes(); |
|||
Assert.AreEqual(0, displayNodes.Count); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestSimpleExtractor() |
|||
{ |
|||
var poseExtractor = new UselessPoseExtractor(); |
|||
var parentIndices = new[] { -1, 0 }; |
|||
poseExtractor.Init(parentIndices); |
|||
Assert.AreEqual(2, poseExtractor.NumPoses); |
|||
} |
|||
|
|||
|
|||
/// <summary>
|
|||
/// A simple "chain" hierarchy, where each object is parented to the one before it.
|
|||
/// 0 <- 1 <- 2 <- ...
|
|||
/// </summary>
|
|||
class ChainPoseExtractor : PoseExtractor |
|||
{ |
|||
public Vector3 offset; |
|||
public ChainPoseExtractor(int size) |
|||
{ |
|||
var parents = new int[size]; |
|||
for (var i = 0; i < size; i++) |
|||
{ |
|||
parents[i] = i - 1; |
|||
} |
|||
Setup(parents); |
|||
} |
|||
|
|||
protected internal override Pose GetPoseAt(int index) |
|||
{ |
|||
var rotation = Quaternion.identity; |
|||
var translation = offset + new Vector3(index, index, index); |
|||
return new Pose |
|||
{ |
|||
rotation = rotation, |
|||
position = translation |
|||
}; |
|||
} |
|||
|
|||
protected internal override Vector3 GetLinearVelocityAt(int index) |
|||
{ |
|||
return Vector3.zero; |
|||
} |
|||
|
|||
} |
|||
|
|||
[Test] |
|||
public void TestChain() |
|||
{ |
|||
var size = 4; |
|||
var chain = new ChainPoseExtractor(size); |
|||
chain.offset = new Vector3(.5f, .75f, .333f); |
|||
|
|||
chain.UpdateModelSpacePoses(); |
|||
chain.UpdateLocalSpacePoses(); |
|||
|
|||
|
|||
var modelPoseIndex = 0; |
|||
foreach (var modelSpace in chain.GetEnabledModelSpacePoses()) |
|||
{ |
|||
if (modelPoseIndex == 0) |
|||
{ |
|||
// Root transforms are currently always the identity.
|
|||
Assert.IsTrue(modelSpace == Pose.identity); |
|||
} |
|||
else |
|||
{ |
|||
var expectedModelTranslation = new Vector3(modelPoseIndex, modelPoseIndex, modelPoseIndex); |
|||
Assert.IsTrue(expectedModelTranslation == modelSpace.position); |
|||
|
|||
} |
|||
modelPoseIndex++; |
|||
} |
|||
Assert.AreEqual(size, modelPoseIndex); |
|||
|
|||
var localPoseIndex = 0; |
|||
foreach (var localSpace in chain.GetEnabledLocalSpacePoses()) |
|||
{ |
|||
if (localPoseIndex == 0) |
|||
{ |
|||
// Root transforms are currently always the identity.
|
|||
Assert.IsTrue(localSpace == Pose.identity); |
|||
} |
|||
else |
|||
{ |
|||
var expectedLocalTranslation = new Vector3(1, 1, 1); |
|||
Assert.IsTrue(expectedLocalTranslation == localSpace.position, $"{expectedLocalTranslation} != {localSpace.position}"); |
|||
} |
|||
|
|||
localPoseIndex++; |
|||
} |
|||
Assert.AreEqual(size, localPoseIndex); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestChainDisplayNodes() |
|||
{ |
|||
var size = 4; |
|||
var chain = new ChainPoseExtractor(size); |
|||
|
|||
var displayNodes = chain.GetDisplayNodes(); |
|||
Assert.AreEqual(size, displayNodes.Count); |
|||
|
|||
for (var i = 0; i < size; i++) |
|||
{ |
|||
var displayNode = displayNodes[i]; |
|||
Assert.AreEqual(i, displayNode.OriginalIndex); |
|||
Assert.AreEqual(null, displayNode.NodeObject); |
|||
Assert.AreEqual(i, displayNode.Depth); |
|||
Assert.AreEqual(true, displayNode.Enabled); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestDisplayNodesLoop() |
|||
{ |
|||
// Degenerate case with a loop
|
|||
var poseExtractor = new UselessPoseExtractor(); |
|||
poseExtractor.Init(new[] {-1, 2, 1}); |
|||
|
|||
// This just shouldn't blow up
|
|||
poseExtractor.GetDisplayNodes(); |
|||
|
|||
// Self-loop
|
|||
poseExtractor.Init(new[] {-1, 1}); |
|||
|
|||
// This just shouldn't blow up
|
|||
poseExtractor.GetDisplayNodes(); |
|||
} |
|||
|
|||
class BadPoseExtractor : BasicPoseExtractor |
|||
{ |
|||
public BadPoseExtractor() |
|||
{ |
|||
var size = 2; |
|||
var parents = new int[size]; |
|||
// Parents are intentionally invalid - expect -1 at root
|
|||
for (var i = 0; i < size; i++) |
|||
{ |
|||
parents[i] = i; |
|||
} |
|||
Setup(parents); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestExpectedRoot() |
|||
{ |
|||
Assert.Throws<UnityAgentsException>(() => |
|||
{ |
|||
var bad = new BadPoseExtractor(); |
|||
}); |
|||
} |
|||
|
|||
} |
|||
|
|||
public class PoseExtensionTests |
|||
{ |
|||
[Test] |
|||
public void TestInverse() |
|||
{ |
|||
Pose t = new Pose |
|||
{ |
|||
rotation = Quaternion.AngleAxis(23.0f, new Vector3(1, 1, 1).normalized), |
|||
position = new Vector3(-1.0f, 2.0f, 3.0f) |
|||
}; |
|||
|
|||
var inverseT = t.Inverse(); |
|||
var product = inverseT.Multiply(t); |
|||
Assert.IsTrue(Vector3.zero == product.position); |
|||
Assert.IsTrue(Quaternion.identity == product.rotation); |
|||
|
|||
Assert.IsTrue(Pose.identity == product); |
|||
} |
|||
|
|||
} |
|||
} |
|
|||
using UnityEngine; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Extensions.Sensors; |
|||
using UnityEditor; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Tests.Sensors |
|||
{ |
|||
public class RigidBodyPoseExtractorTests |
|||
{ |
|||
[TearDown] |
|||
public void RemoveGameObjects() |
|||
{ |
|||
var objects = GameObject.FindObjectsOfType<GameObject>(); |
|||
foreach (var o in objects) |
|||
{ |
|||
UnityEngine.Object.DestroyImmediate(o); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestNullRoot() |
|||
{ |
|||
var poseExtractor = new RigidBodyPoseExtractor(null); |
|||
// These should be no-ops
|
|||
poseExtractor.UpdateLocalSpacePoses(); |
|||
poseExtractor.UpdateModelSpacePoses(); |
|||
|
|||
Assert.AreEqual(0, poseExtractor.NumPoses); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestSingleBody() |
|||
{ |
|||
var go = new GameObject(); |
|||
var rootRb = go.AddComponent<Rigidbody>(); |
|||
var poseExtractor = new RigidBodyPoseExtractor(rootRb); |
|||
Assert.AreEqual(1, poseExtractor.NumPoses); |
|||
|
|||
// Also pass the GameObject
|
|||
poseExtractor = new RigidBodyPoseExtractor(rootRb, go); |
|||
Assert.AreEqual(1, poseExtractor.NumPoses); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestNoBodiesFound() |
|||
{ |
|||
// Check that if we can't find any bodies under the game object, we get an empty extractor
|
|||
var gameObj = new GameObject(); |
|||
var rootRb = gameObj.AddComponent<Rigidbody>(); |
|||
var otherGameObj = new GameObject(); |
|||
var poseExtractor = new RigidBodyPoseExtractor(rootRb, otherGameObj); |
|||
Assert.AreEqual(0, poseExtractor.NumPoses); |
|||
|
|||
// Add an RB under the other GameObject. Constructor will find a rigid body, but not the root.
|
|||
var otherRb = otherGameObj.AddComponent<Rigidbody>(); |
|||
poseExtractor = new RigidBodyPoseExtractor(rootRb, otherGameObj); |
|||
Assert.AreEqual(0, poseExtractor.NumPoses); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestTwoBodies() |
|||
{ |
|||
// * rootObj
|
|||
// - rb1
|
|||
// * go2
|
|||
// - rb2
|
|||
// - joint
|
|||
var rootObj = new GameObject(); |
|||
var rb1 = rootObj.AddComponent<Rigidbody>(); |
|||
|
|||
var go2 = new GameObject(); |
|||
var rb2 = go2.AddComponent<Rigidbody>(); |
|||
go2.transform.SetParent(rootObj.transform); |
|||
|
|||
var joint = go2.AddComponent<ConfigurableJoint>(); |
|||
joint.connectedBody = rb1; |
|||
|
|||
var poseExtractor = new RigidBodyPoseExtractor(rb1); |
|||
Assert.AreEqual(2, poseExtractor.NumPoses); |
|||
|
|||
rb1.position = new Vector3(1, 0, 0); |
|||
rb1.rotation = Quaternion.Euler(0, 13.37f, 0); |
|||
rb1.velocity = new Vector3(2, 0, 0); |
|||
|
|||
Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(0).position); |
|||
Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(0).rotation); |
|||
Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(0)); |
|||
|
|||
// Check DisplayNodes gives expected results
|
|||
var displayNodes = poseExtractor.GetDisplayNodes(); |
|||
Assert.AreEqual(2, displayNodes.Count); |
|||
Assert.AreEqual(rb1, displayNodes[0].NodeObject); |
|||
Assert.AreEqual(false, displayNodes[0].Enabled); |
|||
|
|||
Assert.AreEqual(rb2, displayNodes[1].NodeObject); |
|||
Assert.AreEqual(true, displayNodes[1].Enabled); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestTwoBodiesVirtualRoot() |
|||
{ |
|||
// * virtualRoot
|
|||
// * rootObj
|
|||
// - rb1
|
|||
// * go2
|
|||
// - rb2
|
|||
// - joint
|
|||
var virtualRoot = new GameObject("I am vroot"); |
|||
|
|||
var rootObj = new GameObject(); |
|||
var rb1 = rootObj.AddComponent<Rigidbody>(); |
|||
|
|||
var go2 = new GameObject(); |
|||
var rb2 = go2.AddComponent<Rigidbody>(); |
|||
go2.transform.SetParent(rootObj.transform); |
|||
|
|||
var joint = go2.AddComponent<ConfigurableJoint>(); |
|||
joint.connectedBody = rb1; |
|||
|
|||
var poseExtractor = new RigidBodyPoseExtractor(rb1, null, virtualRoot); |
|||
Assert.AreEqual(3, poseExtractor.NumPoses); |
|||
|
|||
// "body" 0 has no parent
|
|||
Assert.AreEqual(-1, poseExtractor.GetParentIndex(0)); |
|||
|
|||
// body 1 has parent 0
|
|||
Assert.AreEqual(0, poseExtractor.GetParentIndex(1)); |
|||
|
|||
var virtualRootPos = new Vector3(0,2,0); |
|||
var virtualRootRot = Quaternion.Euler(0, 42, 0); |
|||
virtualRoot.transform.position = virtualRootPos; |
|||
virtualRoot.transform.rotation = virtualRootRot; |
|||
|
|||
Assert.AreEqual(virtualRootPos, poseExtractor.GetPoseAt(0).position); |
|||
Assert.IsTrue(virtualRootRot == poseExtractor.GetPoseAt(0).rotation); |
|||
Assert.AreEqual(Vector3.zero, poseExtractor.GetLinearVelocityAt(0)); |
|||
|
|||
// Same as above test, but using index 1
|
|||
rb1.position = new Vector3(1, 0, 0); |
|||
rb1.rotation = Quaternion.Euler(0, 13.37f, 0); |
|||
rb1.velocity = new Vector3(2, 0, 0); |
|||
|
|||
Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(1).position); |
|||
Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(1).rotation); |
|||
Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(1)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestBodyPosesEnabledDictionary() |
|||
{ |
|||
// * rootObj
|
|||
// - rb1
|
|||
// * go2
|
|||
// - rb2
|
|||
// - joint
|
|||
var rootObj = new GameObject(); |
|||
var rb1 = rootObj.AddComponent<Rigidbody>(); |
|||
|
|||
var go2 = new GameObject(); |
|||
var rb2 = go2.AddComponent<Rigidbody>(); |
|||
go2.transform.SetParent(rootObj.transform); |
|||
|
|||
var joint = go2.AddComponent<ConfigurableJoint>(); |
|||
joint.connectedBody = rb1; |
|||
|
|||
var poseExtractor = new RigidBodyPoseExtractor(rb1); |
|||
|
|||
// Expect the root body disabled and the attached one enabled.
|
|||
Assert.IsFalse(poseExtractor.IsPoseEnabled(0)); |
|||
Assert.IsTrue(poseExtractor.IsPoseEnabled(1)); |
|||
var bodyPosesEnabled = poseExtractor.GetBodyPosesEnabled(); |
|||
Assert.IsFalse(bodyPosesEnabled[rb1]); |
|||
Assert.IsTrue(bodyPosesEnabled[rb2]); |
|||
|
|||
// Swap the values
|
|||
bodyPosesEnabled[rb1] = true; |
|||
bodyPosesEnabled[rb2] = false; |
|||
|
|||
var poseExtractor2 = new RigidBodyPoseExtractor(rb1, null, null, bodyPosesEnabled); |
|||
Assert.IsTrue(poseExtractor2.IsPoseEnabled(0)); |
|||
Assert.IsFalse(poseExtractor2.IsPoseEnabled(1)); |
|||
|
|||
|
|||
} |
|||
} |
|||
} |
|
|||
from distutils.util import strtobool |
|||
import os |
|||
from typing import Any, List, Set |
|||
from distutils.version import LooseVersion |
|||
|
|||
try: |
|||
from tf2onnx.tfonnx import process_tf_graph, tf_optimize |
|||
from tf2onnx import optimizer |
|||
|
|||
ONNX_EXPORT_ENABLED = True |
|||
except ImportError: |
|||
# Either onnx and tf2onnx not installed, or they're not compatible with the version of tensorflow |
|||
ONNX_EXPORT_ENABLED = False |
|||
pass |
|||
|
|||
from mlagents.tf_utils import tf |
|||
|
|||
from tensorflow.python.platform import gfile |
|||
from tensorflow.python.framework import graph_util |
|||
|
|||
from mlagents_envs.logging_util import get_logger |
|||
from mlagents.trainers.settings import SerializationSettings |
|||
from mlagents.trainers.tf import tensorflow_to_barracuda as tf2bc |
|||
|
|||
if LooseVersion(tf.__version__) < LooseVersion("1.12.0"): |
|||
# ONNX is only tested on 1.12.0 and later |
|||
ONNX_EXPORT_ENABLED = False |
|||
|
|||
logger = get_logger(__name__) |
|||
|
|||
|
|||
POSSIBLE_INPUT_NODES = frozenset( |
|||
[ |
|||
"action_masks", |
|||
"epsilon", |
|||
"prev_action", |
|||
"recurrent_in", |
|||
"sequence_length", |
|||
"vector_observation", |
|||
] |
|||
) |
|||
|
|||
POSSIBLE_OUTPUT_NODES = frozenset( |
|||
["action", "action_probs", "recurrent_out", "value_estimate"] |
|||
) |
|||
|
|||
MODEL_CONSTANTS = frozenset( |
|||
[ |
|||
"action_output_shape", |
|||
"is_continuous_control", |
|||
"memory_size", |
|||
"version_number", |
|||
"trainer_major_version", |
|||
"trainer_minor_version", |
|||
"trainer_patch_version", |
|||
] |
|||
) |
|||
VISUAL_OBSERVATION_PREFIX = "visual_observation_" |
|||
|
|||
|
|||
def export_policy_model( |
|||
model_path: str, |
|||
output_filepath: str, |
|||
behavior_name: str, |
|||
graph: tf.Graph, |
|||
sess: tf.Session, |
|||
) -> None: |
|||
""" |
|||
Exports a TF graph for a Policy to .nn and/or .onnx format for Unity embedding. |
|||
|
|||
:param output_filepath: file path to output the model (without file suffix) |
|||
:param behavior_name: behavior name of the trained model |
|||
:param graph: Tensorflow Graph for the policy |
|||
:param sess: Tensorflow session for the policy |
|||
""" |
|||
frozen_graph_def = _make_frozen_graph(behavior_name, graph, sess) |
|||
if not os.path.exists(output_filepath): |
|||
os.makedirs(output_filepath) |
|||
# Save frozen graph |
|||
frozen_graph_def_path = model_path + "/frozen_graph_def.pb" |
|||
with gfile.GFile(frozen_graph_def_path, "wb") as f: |
|||
f.write(frozen_graph_def.SerializeToString()) |
|||
|
|||
# Convert to barracuda |
|||
if SerializationSettings.convert_to_barracuda: |
|||
tf2bc.convert(frozen_graph_def_path, f"{output_filepath}.nn") |
|||
logger.info(f"Exported {output_filepath}.nn") |
|||
|
|||
# Save to onnx too (if we were able to import it) |
|||
if ONNX_EXPORT_ENABLED: |
|||
if SerializationSettings.convert_to_onnx: |
|||
try: |
|||
onnx_graph = convert_frozen_to_onnx(behavior_name, frozen_graph_def) |
|||
onnx_output_path = f"{output_filepath}.onnx" |
|||
with open(onnx_output_path, "wb") as f: |
|||
f.write(onnx_graph.SerializeToString()) |
|||
logger.info(f"Converting to {onnx_output_path}") |
|||
except Exception: |
|||
# Make conversion errors fatal depending on environment variables (only done during CI) |
|||
if _enforce_onnx_conversion(): |
|||
raise |
|||
logger.exception( |
|||
"Exception trying to save ONNX graph. Please report this error on " |
|||
"https://github.com/Unity-Technologies/ml-agents/issues and " |
|||
"attach a copy of frozen_graph_def.pb" |
|||
) |
|||
|
|||
else: |
|||
if _enforce_onnx_conversion(): |
|||
raise RuntimeError( |
|||
"ONNX conversion enforced, but couldn't import dependencies." |
|||
) |
|||
|
|||
|
|||
def _make_frozen_graph( |
|||
behavior_name: str, graph: tf.Graph, sess: tf.Session |
|||
) -> tf.GraphDef: |
|||
with graph.as_default(): |
|||
target_nodes = ",".join(_process_graph(behavior_name, graph)) |
|||
graph_def = graph.as_graph_def() |
|||
output_graph_def = graph_util.convert_variables_to_constants( |
|||
sess, graph_def, target_nodes.replace(" ", "").split(",") |
|||
) |
|||
return output_graph_def |
|||
|
|||
|
|||
def convert_frozen_to_onnx(behavior_name: str, frozen_graph_def: tf.GraphDef) -> Any: |
|||
# This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py |
|||
|
|||
inputs = _get_input_node_names(frozen_graph_def) |
|||
outputs = _get_output_node_names(frozen_graph_def) |
|||
logger.info(f"onnx export - inputs:{inputs} outputs:{outputs}") |
|||
|
|||
frozen_graph_def = tf_optimize( |
|||
inputs, outputs, frozen_graph_def, fold_constant=True |
|||
) |
|||
|
|||
with tf.Graph().as_default() as tf_graph: |
|||
tf.import_graph_def(frozen_graph_def, name="") |
|||
with tf.Session(graph=tf_graph): |
|||
g = process_tf_graph( |
|||
tf_graph, |
|||
input_names=inputs, |
|||
output_names=outputs, |
|||
opset=SerializationSettings.onnx_opset, |
|||
) |
|||
|
|||
onnx_graph = optimizer.optimize_graph(g) |
|||
model_proto = onnx_graph.make_model(behavior_name) |
|||
|
|||
return model_proto |
|||
|
|||
|
|||
def _get_input_node_names(frozen_graph_def: Any) -> List[str]: |
|||
""" |
|||
Get the list of input node names from the graph. |
|||
Names are suffixed with ":0" |
|||
""" |
|||
node_names = _get_frozen_graph_node_names(frozen_graph_def) |
|||
input_names = node_names & POSSIBLE_INPUT_NODES |
|||
|
|||
# Check visual inputs sequentially, and exit as soon as we don't find one |
|||
vis_index = 0 |
|||
while True: |
|||
vis_node_name = f"{VISUAL_OBSERVATION_PREFIX}{vis_index}" |
|||
if vis_node_name in node_names: |
|||
input_names.add(vis_node_name) |
|||
else: |
|||
break |
|||
vis_index += 1 |
|||
# Append the port |
|||
return [f"{n}:0" for n in input_names] |
|||
|
|||
|
|||
def _get_output_node_names(frozen_graph_def: Any) -> List[str]: |
|||
""" |
|||
Get the list of output node names from the graph. |
|||
Also include constants, so that they will be readable by the |
|||
onnx importer. |
|||
Names are suffixed with ":0" |
|||
""" |
|||
node_names = _get_frozen_graph_node_names(frozen_graph_def) |
|||
output_names = node_names & (POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS) |
|||
# Append the port |
|||
return [f"{n}:0" for n in output_names] |
|||
|
|||
|
|||
def _get_frozen_graph_node_names(frozen_graph_def: Any) -> Set[str]: |
|||
""" |
|||
Get all the node names from the graph. |
|||
""" |
|||
names = set() |
|||
for node in frozen_graph_def.node: |
|||
names.add(node.name) |
|||
return names |
|||
|
|||
|
|||
def _process_graph(behavior_name: str, graph: tf.Graph) -> List[str]: |
|||
""" |
|||
Gets the list of the output nodes present in the graph for inference |
|||
:return: list of node names |
|||
""" |
|||
all_nodes = [x.name for x in graph.as_graph_def().node] |
|||
nodes = [x for x in all_nodes if x in POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS] |
|||
logger.info("List of nodes to export for behavior :" + behavior_name) |
|||
for n in nodes: |
|||
logger.info("\t" + n) |
|||
return nodes |
|||
|
|||
|
|||
def _enforce_onnx_conversion() -> bool: |
|||
env_var_name = "TEST_ENFORCE_ONNX_CONVERSION" |
|||
if env_var_name not in os.environ: |
|||
return False |
|||
|
|||
val = os.environ[env_var_name] |
|||
try: |
|||
# This handles e.g. "false" converting reasonably to False |
|||
return strtobool(val) |
|||
except Exception: |
|||
return False |
|
|||
import os |
|||
import shutil |
|||
from typing import Optional, Union, cast |
|||
from mlagents_envs.exception import UnityPolicyException |
|||
from mlagents_envs.logging_util import get_logger |
|||
from mlagents.tf_utils import tf |
|||
from mlagents.trainers.saver.saver import BaseSaver |
|||
from mlagents.trainers.tf.model_serialization import export_policy_model |
|||
from mlagents.trainers.settings import TrainerSettings, SerializationSettings |
|||
from mlagents.trainers.policy.tf_policy import TFPolicy |
|||
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer |
|||
from mlagents.trainers import __version__ |
|||
|
|||
|
|||
logger = get_logger(__name__) |
|||
|
|||
|
|||
class TFSaver(BaseSaver): |
|||
""" |
|||
Saver class for TensorFlow |
|||
""" |
|||
|
|||
def __init__( |
|||
self, trainer_settings: TrainerSettings, model_path: str, load: bool = False |
|||
): |
|||
super().__init__() |
|||
self.model_path = model_path |
|||
self.initialize_path = trainer_settings.init_path |
|||
self._keep_checkpoints = trainer_settings.keep_checkpoints |
|||
self.load = load |
|||
|
|||
# Currently only support saving one policy. This is the one to be saved. |
|||
self.policy: Optional[TFPolicy] = None |
|||
self.graph = None |
|||
self.sess = None |
|||
self.tf_saver = None |
|||
|
|||
def register(self, module: Union[TFPolicy, TFOptimizer]) -> None: |
|||
if isinstance(module, TFPolicy): |
|||
self._register_policy(module) |
|||
elif isinstance(module, TFOptimizer): |
|||
self._register_optimizer(module) |
|||
else: |
|||
raise UnityPolicyException( |
|||
"Registering Object of unsupported type {} to Saver ".format( |
|||
type(module) |
|||
) |
|||
) |
|||
|
|||
def _register_policy(self, policy: TFPolicy) -> None: |
|||
if self.policy is None: |
|||
self.policy = policy |
|||
self.graph = self.policy.graph |
|||
self.sess = self.policy.sess |
|||
with self.policy.graph.as_default(): |
|||
self.tf_saver = tf.train.Saver(max_to_keep=self._keep_checkpoints) |
|||
|
|||
def save_checkpoint(self, brain_name: str, step: int) -> str: |
|||
checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}") |
|||
# Save the TF checkpoint and graph definition |
|||
if self.graph: |
|||
with self.graph.as_default(): |
|||
if self.tf_saver: |
|||
self.tf_saver.save(self.sess, f"{checkpoint_path}.ckpt") |
|||
tf.train.write_graph( |
|||
self.graph, self.model_path, "raw_graph_def.pb", as_text=False |
|||
) |
|||
# also save the policy so we have optimized model files for each checkpoint |
|||
self.export(checkpoint_path, brain_name) |
|||
return checkpoint_path |
|||
|
|||
def export(self, output_filepath: str, brain_name: str) -> None: |
|||
# save model if there is only one worker or |
|||
# only on worker-0 if there are multiple workers |
|||
if self.policy and self.policy.rank is not None and self.policy.rank != 0: |
|||
return |
|||
export_policy_model( |
|||
self.model_path, output_filepath, brain_name, self.graph, self.sess |
|||
) |
|||
|
|||
def initialize_or_load(self, policy: Optional[TFPolicy] = None) -> None: |
|||
# If there is an initialize path, load from that. Else, load from the set model path. |
|||
# If load is set to True, don't reset steps to 0. Else, do. This allows a user to, |
|||
# e.g., resume from an initialize path. |
|||
if policy is None: |
|||
policy = self.policy |
|||
policy = cast(TFPolicy, policy) |
|||
reset_steps = not self.load |
|||
if self.initialize_path is not None: |
|||
self._load_graph( |
|||
policy, self.initialize_path, reset_global_steps=reset_steps |
|||
) |
|||
elif self.load: |
|||
self._load_graph(policy, self.model_path, reset_global_steps=reset_steps) |
|||
else: |
|||
policy.initialize() |
|||
TFPolicy.broadcast_global_variables(0) |
|||
|
|||
def _load_graph( |
|||
self, policy: TFPolicy, model_path: str, reset_global_steps: bool = False |
|||
) -> None: |
|||
with policy.graph.as_default(): |
|||
logger.info(f"Loading model from {model_path}.") |
|||
ckpt = tf.train.get_checkpoint_state(model_path) |
|||
if ckpt is None: |
|||
raise UnityPolicyException( |
|||
"The model {} could not be loaded. Make " |
|||
"sure you specified the right " |
|||
"--run-id and that the previous run you are loading from had the same " |
|||
"behavior names.".format(model_path) |
|||
) |
|||
if self.tf_saver: |
|||
try: |
|||
self.tf_saver.restore(policy.sess, ckpt.model_checkpoint_path) |
|||
except tf.errors.NotFoundError: |
|||
raise UnityPolicyException( |
|||
"The model {} was found but could not be loaded. Make " |
|||
"sure the model is from the same version of ML-Agents, has the same behavior parameters, " |
|||
"and is using the same trainer configuration as the current run.".format( |
|||
model_path |
|||
) |
|||
) |
|||
self._check_model_version(__version__) |
|||
if reset_global_steps: |
|||
policy.set_step(0) |
|||
logger.info( |
|||
"Starting training from step 0 and saving to {}.".format( |
|||
self.model_path |
|||
) |
|||
) |
|||
else: |
|||
logger.info(f"Resuming training from step {policy.get_current_step()}.") |
|||
|
|||
def _check_model_version(self, version: str) -> None: |
|||
""" |
|||
Checks whether the model being loaded was created with the same version of |
|||
ML-Agents, and throw a warning if not so. |
|||
""" |
|||
if self.policy is not None and self.policy.version_tensors is not None: |
|||
loaded_ver = tuple( |
|||
num.eval(session=self.sess) for num in self.policy.version_tensors |
|||
) |
|||
if loaded_ver != TFPolicy._convert_version_string(version): |
|||
logger.warning( |
|||
f"The model checkpoint you are loading from was saved with ML-Agents version " |
|||
f"{loaded_ver[0]}.{loaded_ver[1]}.{loaded_ver[2]} but your current ML-Agents" |
|||
f"version is {version}. Model may not behave properly." |
|||
) |
|||
|
|||
def copy_final_model(self, source_nn_path: str) -> None: |
|||
""" |
|||
Copy the .nn file at the given source to the destination. |
|||
Also copies the corresponding .onnx file if it exists. |
|||
""" |
|||
final_model_name = os.path.splitext(source_nn_path)[0] |
|||
|
|||
if SerializationSettings.convert_to_barracuda: |
|||
source_path = f"{final_model_name}.nn" |
|||
destination_path = f"{self.model_path}.nn" |
|||
shutil.copyfile(source_path, destination_path) |
|||
logger.info(f"Copied {source_path} to {destination_path}.") |
|||
|
|||
if SerializationSettings.convert_to_onnx: |
|||
try: |
|||
source_path = f"{final_model_name}.onnx" |
|||
destination_path = f"{self.model_path}.onnx" |
|||
shutil.copyfile(source_path, destination_path) |
|||
logger.info(f"Copied {source_path} to {destination_path}.") |
|||
except OSError: |
|||
pass |
|
|||
import os |
|||
import shutil |
|||
import torch |
|||
from typing import Dict, Union, Optional, cast |
|||
from mlagents_envs.exception import UnityPolicyException |
|||
from mlagents_envs.logging_util import get_logger |
|||
from mlagents.trainers.saver.saver import BaseSaver |
|||
from mlagents.trainers.settings import TrainerSettings, SerializationSettings |
|||
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|||
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer |
|||
from mlagents.trainers.torch.model_serialization import ModelSerializer |
|||
|
|||
|
|||
logger = get_logger(__name__) |
|||
|
|||
|
|||
class TorchSaver(BaseSaver): |
|||
""" |
|||
Saver class for PyTorch |
|||
""" |
|||
|
|||
def __init__( |
|||
self, trainer_settings: TrainerSettings, model_path: str, load: bool = False |
|||
): |
|||
super().__init__() |
|||
self.model_path = model_path |
|||
self.initialize_path = trainer_settings.init_path |
|||
self._keep_checkpoints = trainer_settings.keep_checkpoints |
|||
self.load = load |
|||
|
|||
self.policy: Optional[TorchPolicy] = None |
|||
self.exporter: Optional[ModelSerializer] = None |
|||
self.modules: Dict[str, torch.nn.Modules] = {} |
|||
|
|||
def register(self, module: Union[TorchPolicy, TorchOptimizer]) -> None: |
|||
if isinstance(module, TorchPolicy) or isinstance(module, TorchOptimizer): |
|||
self.modules.update(module.get_modules()) # type: ignore |
|||
else: |
|||
raise UnityPolicyException( |
|||
"Registering Object of unsupported type {} to Saver ".format( |
|||
type(module) |
|||
) |
|||
) |
|||
if self.policy is None and isinstance(module, TorchPolicy): |
|||
self.policy = module |
|||
self.exporter = ModelSerializer(self.policy) |
|||
|
|||
def save_checkpoint(self, brain_name: str, step: int) -> str: |
|||
if not os.path.exists(self.model_path): |
|||
os.makedirs(self.model_path) |
|||
checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}") |
|||
state_dict = { |
|||
name: module.state_dict() for name, module in self.modules.items() |
|||
} |
|||
torch.save(state_dict, f"{checkpoint_path}.pt") |
|||
torch.save(state_dict, os.path.join(self.model_path, "checkpoint.pt")) |
|||
self.export(checkpoint_path, brain_name) |
|||
return checkpoint_path |
|||
|
|||
def export(self, output_filepath: str, brain_name: str) -> None: |
|||
if self.exporter is not None: |
|||
self.exporter.export_policy_model(output_filepath) |
|||
|
|||
def initialize_or_load(self, policy: Optional[TorchPolicy] = None) -> None: |
|||
# Initialize/Load registered self.policy by default. |
|||
# If given input argument policy, use the input policy instead. |
|||
# This argument is mainly for initialization of the ghost trainer's fixed policy. |
|||
reset_steps = not self.load |
|||
if self.initialize_path is not None: |
|||
self._load_model( |
|||
self.initialize_path, policy, reset_global_steps=reset_steps |
|||
) |
|||
elif self.load: |
|||
self._load_model(self.model_path, policy, reset_global_steps=reset_steps) |
|||
|
|||
def _load_model( |
|||
self, |
|||
load_path: str, |
|||
policy: Optional[TorchPolicy] = None, |
|||
reset_global_steps: bool = False, |
|||
) -> None: |
|||
model_path = os.path.join(load_path, "checkpoint.pt") |
|||
saved_state_dict = torch.load(model_path) |
|||
if policy is None: |
|||
modules = self.modules |
|||
policy = self.policy |
|||
else: |
|||
modules = policy.get_modules() |
|||
policy = cast(TorchPolicy, policy) |
|||
|
|||
for name, mod in modules.items(): |
|||
mod.load_state_dict(saved_state_dict[name]) |
|||
|
|||
if reset_global_steps: |
|||
policy.set_step(0) |
|||
logger.info( |
|||
"Starting training from step 0 and saving to {}.".format( |
|||
self.model_path |
|||
) |
|||
) |
|||
else: |
|||
logger.info(f"Resuming training from step {policy.get_current_step()}.") |
|||
|
|||
def copy_final_model(self, source_nn_path: str) -> None: |
|||
""" |
|||
Copy the .nn file at the given source to the destination. |
|||
Also copies the corresponding .onnx file if it exists. |
|||
""" |
|||
final_model_name = os.path.splitext(source_nn_path)[0] |
|||
|
|||
if SerializationSettings.convert_to_onnx: |
|||
try: |
|||
source_path = f"{final_model_name}.onnx" |
|||
destination_path = f"{self.model_path}.onnx" |
|||
shutil.copyfile(source_path, destination_path) |
|||
logger.info(f"Copied {source_path} to {destination_path}.") |
|||
except OSError: |
|||
pass |
|
|||
import numpy as np |
|||
import pytest |
|||
import torch |
|||
from mlagents.trainers.torch.components.reward_providers import ( |
|||
CuriosityRewardProvider, |
|||
create_reward_provider, |
|||
) |
|||
from mlagents_envs.base_env import BehaviorSpec, ActionType |
|||
from mlagents.trainers.settings import CuriositySettings, RewardSignalType |
|||
from mlagents.trainers.tests.torch.test_reward_providers.utils import ( |
|||
create_agent_buffer, |
|||
) |
|||
|
|||
SEED = [42] |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)), |
|||
], |
|||
) |
|||
def test_construction(behavior_spec: BehaviorSpec) -> None: |
|||
curiosity_settings = CuriositySettings(32, 0.01) |
|||
curiosity_settings.strength = 0.1 |
|||
curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings) |
|||
assert curiosity_rp.strength == 0.1 |
|||
assert curiosity_rp.name == "Curiosity" |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,), (64, 66, 1)], ActionType.DISCRETE, (2, 3)), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)), |
|||
], |
|||
) |
|||
def test_factory(behavior_spec: BehaviorSpec) -> None: |
|||
curiosity_settings = CuriositySettings(32, 0.01) |
|||
curiosity_rp = create_reward_provider( |
|||
RewardSignalType.CURIOSITY, behavior_spec, curiosity_settings |
|||
) |
|||
assert curiosity_rp.name == "Curiosity" |
|||
|
|||
|
|||
@pytest.mark.parametrize("seed", SEED) |
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)), |
|||
], |
|||
) |
|||
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None: |
|||
np.random.seed(seed) |
|||
torch.manual_seed(seed) |
|||
curiosity_settings = CuriositySettings(32, 0.01) |
|||
curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings) |
|||
buffer = create_agent_buffer(behavior_spec, 5) |
|||
curiosity_rp.update(buffer) |
|||
reward_old = curiosity_rp.evaluate(buffer)[0] |
|||
for _ in range(10): |
|||
curiosity_rp.update(buffer) |
|||
reward_new = curiosity_rp.evaluate(buffer)[0] |
|||
assert reward_new < reward_old |
|||
reward_old = reward_new |
|||
|
|||
|
|||
@pytest.mark.parametrize("seed", SEED) |
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", [BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5)] |
|||
) |
|||
def test_continuous_action_prediction(behavior_spec: BehaviorSpec, seed: int) -> None: |
|||
np.random.seed(seed) |
|||
torch.manual_seed(seed) |
|||
curiosity_settings = CuriositySettings(32, 0.1) |
|||
curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings) |
|||
buffer = create_agent_buffer(behavior_spec, 5) |
|||
for _ in range(200): |
|||
curiosity_rp.update(buffer) |
|||
prediction = curiosity_rp._network.predict_action(buffer)[0].detach() |
|||
target = buffer["actions"][0] |
|||
error = float(torch.mean((prediction - target) ** 2)) |
|||
assert error < 0.001 |
|||
|
|||
|
|||
@pytest.mark.parametrize("seed", SEED) |
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(10,), (64, 66, 3)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)), |
|||
], |
|||
) |
|||
def test_next_state_prediction(behavior_spec: BehaviorSpec, seed: int) -> None: |
|||
np.random.seed(seed) |
|||
torch.manual_seed(seed) |
|||
curiosity_settings = CuriositySettings(32, 0.1) |
|||
curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings) |
|||
buffer = create_agent_buffer(behavior_spec, 5) |
|||
for _ in range(100): |
|||
curiosity_rp.update(buffer) |
|||
prediction = curiosity_rp._network.predict_next_state(buffer)[0] |
|||
target = curiosity_rp._network.get_next_state(buffer)[0] |
|||
error = float(torch.mean((prediction - target) ** 2).detach()) |
|||
assert error < 0.001 |
|
|||
import pytest |
|||
from mlagents.trainers.torch.components.reward_providers import ( |
|||
ExtrinsicRewardProvider, |
|||
create_reward_provider, |
|||
) |
|||
from mlagents_envs.base_env import BehaviorSpec, ActionType |
|||
from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType |
|||
from mlagents.trainers.tests.torch.test_reward_providers.utils import ( |
|||
create_agent_buffer, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)), |
|||
], |
|||
) |
|||
def test_construction(behavior_spec: BehaviorSpec) -> None: |
|||
settings = RewardSignalSettings() |
|||
settings.gamma = 0.2 |
|||
extrinsic_rp = ExtrinsicRewardProvider(behavior_spec, settings) |
|||
assert extrinsic_rp.gamma == 0.2 |
|||
assert extrinsic_rp.name == "Extrinsic" |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)), |
|||
], |
|||
) |
|||
def test_factory(behavior_spec: BehaviorSpec) -> None: |
|||
settings = RewardSignalSettings() |
|||
extrinsic_rp = create_reward_provider( |
|||
RewardSignalType.EXTRINSIC, behavior_spec, settings |
|||
) |
|||
assert extrinsic_rp.name == "Extrinsic" |
|||
|
|||
|
|||
@pytest.mark.parametrize("reward", [2.0, 3.0, 4.0]) |
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)), |
|||
], |
|||
) |
|||
def test_reward(behavior_spec: BehaviorSpec, reward: float) -> None: |
|||
buffer = create_agent_buffer(behavior_spec, 1000, reward) |
|||
settings = RewardSignalSettings() |
|||
extrinsic_rp = ExtrinsicRewardProvider(behavior_spec, settings) |
|||
generated_rewards = extrinsic_rp.evaluate(buffer) |
|||
assert (generated_rewards == reward).all() |
|
|||
from typing import Any |
|||
import numpy as np |
|||
import pytest |
|||
from unittest.mock import patch |
|||
import torch |
|||
import os |
|||
from mlagents.trainers.torch.components.reward_providers import ( |
|||
GAILRewardProvider, |
|||
create_reward_provider, |
|||
) |
|||
from mlagents_envs.base_env import BehaviorSpec, ActionType |
|||
from mlagents.trainers.settings import GAILSettings, RewardSignalType |
|||
from mlagents.trainers.tests.torch.test_reward_providers.utils import ( |
|||
create_agent_buffer, |
|||
) |
|||
from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import ( |
|||
DiscriminatorNetwork, |
|||
) |
|||
|
|||
CONTINUOUS_PATH = ( |
|||
os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir) |
|||
+ "/test.demo" |
|||
) |
|||
DISCRETE_PATH = ( |
|||
os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir) |
|||
+ "/testdcvis.demo" |
|||
) |
|||
SEED = [42] |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", [BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2)] |
|||
) |
|||
def test_construction(behavior_spec: BehaviorSpec) -> None: |
|||
gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH) |
|||
gail_rp = GAILRewardProvider(behavior_spec, gail_settings) |
|||
assert gail_rp.name == "GAIL" |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", [BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2)] |
|||
) |
|||
def test_factory(behavior_spec: BehaviorSpec) -> None: |
|||
gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH) |
|||
gail_rp = create_reward_provider( |
|||
RewardSignalType.GAIL, behavior_spec, gail_settings |
|||
) |
|||
assert gail_rp.name == "GAIL" |
|||
|
|||
|
|||
@pytest.mark.parametrize("seed", SEED) |
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(8,), (24, 26, 1)], ActionType.CONTINUOUS, 2), |
|||
BehaviorSpec([(50,)], ActionType.DISCRETE, (2, 3, 3, 3)), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (20,)), |
|||
], |
|||
) |
|||
@pytest.mark.parametrize("use_actions", [False, True]) |
|||
@patch( |
|||
"mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer" |
|||
) |
|||
def test_reward_decreases( |
|||
demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int |
|||
) -> None: |
|||
np.random.seed(seed) |
|||
torch.manual_seed(seed) |
|||
buffer_expert = create_agent_buffer(behavior_spec, 1000) |
|||
buffer_policy = create_agent_buffer(behavior_spec, 1000) |
|||
demo_to_buffer.return_value = None, buffer_expert |
|||
gail_settings = GAILSettings( |
|||
demo_path="", learning_rate=0.05, use_vail=False, use_actions=use_actions |
|||
) |
|||
gail_rp = create_reward_provider( |
|||
RewardSignalType.GAIL, behavior_spec, gail_settings |
|||
) |
|||
|
|||
init_reward_expert = gail_rp.evaluate(buffer_expert)[0] |
|||
init_reward_policy = gail_rp.evaluate(buffer_policy)[0] |
|||
|
|||
for _ in range(10): |
|||
gail_rp.update(buffer_policy) |
|||
reward_expert = gail_rp.evaluate(buffer_expert)[0] |
|||
reward_policy = gail_rp.evaluate(buffer_policy)[0] |
|||
assert reward_expert >= 0 # GAIL / VAIL reward always positive |
|||
assert reward_policy >= 0 |
|||
reward_expert = gail_rp.evaluate(buffer_expert)[0] |
|||
reward_policy = gail_rp.evaluate(buffer_policy)[0] |
|||
assert reward_expert > reward_policy # Expert reward greater than non-expert reward |
|||
assert ( |
|||
reward_expert > init_reward_expert |
|||
) # Expert reward getting better as network trains |
|||
assert ( |
|||
reward_policy < init_reward_policy |
|||
) # Non-expert reward getting worse as network trains |
|||
|
|||
|
|||
@pytest.mark.parametrize("seed", SEED) |
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3, 3, 3)), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (20,)), |
|||
], |
|||
) |
|||
@pytest.mark.parametrize("use_actions", [False, True]) |
|||
@patch( |
|||
"mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer" |
|||
) |
|||
def test_reward_decreases_vail( |
|||
demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int |
|||
) -> None: |
|||
np.random.seed(seed) |
|||
torch.manual_seed(seed) |
|||
buffer_expert = create_agent_buffer(behavior_spec, 1000) |
|||
buffer_policy = create_agent_buffer(behavior_spec, 1000) |
|||
demo_to_buffer.return_value = None, buffer_expert |
|||
gail_settings = GAILSettings( |
|||
demo_path="", learning_rate=0.005, use_vail=True, use_actions=use_actions |
|||
) |
|||
DiscriminatorNetwork.initial_beta = 0.0 |
|||
# we must set the initial value of beta to 0 for testing |
|||
# If we do not, the kl-loss will dominate early and will block the estimator |
|||
gail_rp = create_reward_provider( |
|||
RewardSignalType.GAIL, behavior_spec, gail_settings |
|||
) |
|||
|
|||
for _ in range(100): |
|||
gail_rp.update(buffer_policy) |
|||
reward_expert = gail_rp.evaluate(buffer_expert)[0] |
|||
reward_policy = gail_rp.evaluate(buffer_policy)[0] |
|||
assert reward_expert >= 0 # GAIL / VAIL reward always positive |
|||
assert reward_policy >= 0 |
|||
reward_expert = gail_rp.evaluate(buffer_expert)[0] |
|||
reward_policy = gail_rp.evaluate(buffer_policy)[0] |
|||
assert reward_expert > reward_policy # Expert reward greater than non-expert reward |
|
|||
import numpy as np |
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents_envs.base_env import BehaviorSpec |
|||
from mlagents.trainers.trajectory import SplitObservations |
|||
|
|||
|
|||
def create_agent_buffer( |
|||
behavior_spec: BehaviorSpec, number: int, reward: float = 0.0 |
|||
) -> AgentBuffer: |
|||
buffer = AgentBuffer() |
|||
curr_observations = [ |
|||
np.random.normal(size=shape) for shape in behavior_spec.observation_shapes |
|||
] |
|||
next_observations = [ |
|||
np.random.normal(size=shape) for shape in behavior_spec.observation_shapes |
|||
] |
|||
action = behavior_spec.create_random_action(1)[0, :] |
|||
for _ in range(number): |
|||
curr_split_obs = SplitObservations.from_observations(curr_observations) |
|||
next_split_obs = SplitObservations.from_observations(next_observations) |
|||
for i, _ in enumerate(curr_split_obs.visual_observations): |
|||
buffer["visual_obs%d" % i].append(curr_split_obs.visual_observations[i]) |
|||
buffer["next_visual_obs%d" % i].append( |
|||
next_split_obs.visual_observations[i] |
|||
) |
|||
buffer["vector_obs"].append(curr_split_obs.vector_observations) |
|||
buffer["next_vector_in"].append(next_split_obs.vector_observations) |
|||
buffer["actions"].append(action) |
|||
buffer["done"].append(np.zeros(1, dtype=np.float32)) |
|||
buffer["reward"].append(np.ones(1, dtype=np.float32) * reward) |
|||
buffer["masks"].append(np.ones(1, dtype=np.float32)) |
|||
return buffer |
1001
ml-agents/mlagents/trainers/tests/torch/test.demo
文件差异内容过多而无法显示
查看文件
文件差异内容过多而无法显示
查看文件
|
|||
from unittest.mock import MagicMock |
|||
import pytest |
|||
import mlagents.trainers.tests.mock_brain as mb |
|||
|
|||
import numpy as np |
|||
import os |
|||
|
|||
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|||
from mlagents.trainers.torch.components.bc.module import BCModule |
|||
from mlagents.trainers.settings import ( |
|||
TrainerSettings, |
|||
BehavioralCloningSettings, |
|||
NetworkSettings, |
|||
) |
|||
|
|||
|
|||
def create_bc_module(mock_behavior_specs, bc_settings, use_rnn, tanhresample): |
|||
# model_path = env.external_brain_names[0] |
|||
trainer_config = TrainerSettings() |
|||
trainer_config.network_settings.memory = ( |
|||
NetworkSettings.MemorySettings() if use_rnn else None |
|||
) |
|||
policy = TorchPolicy( |
|||
0, mock_behavior_specs, trainer_config, tanhresample, tanhresample |
|||
) |
|||
bc_module = BCModule( |
|||
policy, |
|||
settings=bc_settings, |
|||
policy_learning_rate=trainer_config.hyperparameters.learning_rate, |
|||
default_batch_size=trainer_config.hyperparameters.batch_size, |
|||
default_num_epoch=3, |
|||
) |
|||
return bc_module |
|||
|
|||
|
|||
# Test default values |
|||
def test_bcmodule_defaults(): |
|||
# See if default values match |
|||
mock_specs = mb.create_mock_3dball_behavior_specs() |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo" |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, False, False) |
|||
assert bc_module.num_epoch == 3 |
|||
assert bc_module.batch_size == TrainerSettings().hyperparameters.batch_size |
|||
# Assign strange values and see if it overrides properly |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo", |
|||
num_epoch=100, |
|||
batch_size=10000, |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, False, False) |
|||
assert bc_module.num_epoch == 100 |
|||
assert bc_module.batch_size == 10000 |
|||
|
|||
|
|||
# Test with continuous control env and vector actions |
|||
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"]) |
|||
def test_bcmodule_update(is_sac): |
|||
mock_specs = mb.create_mock_3dball_behavior_specs() |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo" |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac) |
|||
stats = bc_module.update() |
|||
for _, item in stats.items(): |
|||
assert isinstance(item, np.float32) |
|||
|
|||
|
|||
# Test with constant pretraining learning rate |
|||
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"]) |
|||
def test_bcmodule_constant_lr_update(is_sac): |
|||
mock_specs = mb.create_mock_3dball_behavior_specs() |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo", |
|||
steps=0, |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac) |
|||
stats = bc_module.update() |
|||
for _, item in stats.items(): |
|||
assert isinstance(item, np.float32) |
|||
old_learning_rate = bc_module.current_lr |
|||
|
|||
_ = bc_module.update() |
|||
assert old_learning_rate == bc_module.current_lr |
|||
|
|||
|
|||
# Test with constant pretraining learning rate |
|||
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"]) |
|||
def test_bcmodule_linear_lr_update(is_sac): |
|||
mock_specs = mb.create_mock_3dball_behavior_specs() |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo", |
|||
steps=100, |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac) |
|||
# Should decay by 10/100 * 0.0003 = 0.00003 |
|||
bc_module.policy.get_current_step = MagicMock(return_value=10) |
|||
old_learning_rate = bc_module.current_lr |
|||
_ = bc_module.update() |
|||
assert old_learning_rate - 0.00003 == pytest.approx(bc_module.current_lr, abs=0.01) |
|||
|
|||
|
|||
# Test with RNN |
|||
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"]) |
|||
def test_bcmodule_rnn_update(is_sac): |
|||
mock_specs = mb.create_mock_3dball_behavior_specs() |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo" |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, True, is_sac) |
|||
stats = bc_module.update() |
|||
for _, item in stats.items(): |
|||
assert isinstance(item, np.float32) |
|||
|
|||
|
|||
# Test with discrete control and visual observations |
|||
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"]) |
|||
def test_bcmodule_dc_visual_update(is_sac): |
|||
mock_specs = mb.create_mock_banana_behavior_specs() |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "testdcvis.demo" |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac) |
|||
stats = bc_module.update() |
|||
for _, item in stats.items(): |
|||
assert isinstance(item, np.float32) |
|||
|
|||
|
|||
# Test with discrete control, visual observations and RNN |
|||
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"]) |
|||
def test_bcmodule_rnn_dc_update(is_sac): |
|||
mock_specs = mb.create_mock_banana_behavior_specs() |
|||
bc_settings = BehavioralCloningSettings( |
|||
demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "testdcvis.demo" |
|||
) |
|||
bc_module = create_bc_module(mock_specs, bc_settings, True, is_sac) |
|||
stats = bc_module.update() |
|||
for _, item in stats.items(): |
|||
assert isinstance(item, np.float32) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
pytest.main() |
|
|||
bcvis& -* * * * :VisualFoodCollectorLearning� |
|||
�P�������j� |
|||
TT��PNG |
|||
|
|||
|