比较提交
合并到: unity-tech-cn:main
unity-tech-cn:/main
unity-tech-cn:/develop-generalizationTraining-TrainerController
unity-tech-cn:/tag-0.2.0
unity-tech-cn:/tag-0.2.1
unity-tech-cn:/tag-0.2.1a
unity-tech-cn:/tag-0.2.1c
unity-tech-cn:/tag-0.2.1d
unity-tech-cn:/hotfix-v0.9.2a
unity-tech-cn:/develop-gpu-test
unity-tech-cn:/0.10.1
unity-tech-cn:/develop-pyinstaller
unity-tech-cn:/develop-horovod
unity-tech-cn:/PhysXArticulations20201
unity-tech-cn:/importdocfix
unity-tech-cn:/develop-resizetexture
unity-tech-cn:/hh-develop-walljump_bugfixes
unity-tech-cn:/develop-walljump-fix-sac
unity-tech-cn:/hh-develop-walljump_rnd
unity-tech-cn:/tag-0.11.0.dev0
unity-tech-cn:/develop-pytorch
unity-tech-cn:/tag-0.11.0.dev2
unity-tech-cn:/develop-newnormalization
unity-tech-cn:/tag-0.11.0.dev3
unity-tech-cn:/develop
unity-tech-cn:/release-0.12.0
unity-tech-cn:/tag-0.12.0-dev
unity-tech-cn:/tag-0.12.0.dev0
unity-tech-cn:/tag-0.12.1
unity-tech-cn:/2D-explorations
unity-tech-cn:/asymm-envs
unity-tech-cn:/tag-0.12.1.dev0
unity-tech-cn:/2D-exploration-raycast
unity-tech-cn:/tag-0.12.1.dev1
unity-tech-cn:/release-0.13.0
unity-tech-cn:/release-0.13.1
unity-tech-cn:/plugin-proof-of-concept
unity-tech-cn:/release-0.14.0
unity-tech-cn:/hotfix-bump-version-master
unity-tech-cn:/soccer-fives
unity-tech-cn:/release-0.14.1
unity-tech-cn:/bug-failed-api-check
unity-tech-cn:/test-recurrent-gail
unity-tech-cn:/hh-add-icons
unity-tech-cn:/release-0.15.0
unity-tech-cn:/release-0.15.1
unity-tech-cn:/hh-develop-all-posed-characters
unity-tech-cn:/internal-policy-ghost
unity-tech-cn:/distributed-training
unity-tech-cn:/hh-develop-improve_tennis
unity-tech-cn:/test-tf-ver
unity-tech-cn:/release_1_branch
unity-tech-cn:/tennis-time-horizon
unity-tech-cn:/whitepaper-experiments
unity-tech-cn:/r2v-yamato-linux
unity-tech-cn:/docs-update
unity-tech-cn:/release_2_branch
unity-tech-cn:/exp-mede
unity-tech-cn:/sensitivity
unity-tech-cn:/release_2_verified_load_fix
unity-tech-cn:/test-sampler
unity-tech-cn:/release_2_verified
unity-tech-cn:/hh-develop-ragdoll-testing
unity-tech-cn:/origin-develop-taggedobservations
unity-tech-cn:/MLA-1734-demo-provider
unity-tech-cn:/sampler-refactor-copy
unity-tech-cn:/PhysXArticulations20201Package
unity-tech-cn:/tag-com.unity.ml-agents_1.0.8
unity-tech-cn:/release_3_branch
unity-tech-cn:/github-actions
unity-tech-cn:/release_3_distributed
unity-tech-cn:/fix-batch-tennis
unity-tech-cn:/distributed-ppo-sac
unity-tech-cn:/gridworld-custom-obs
unity-tech-cn:/hw20-segmentation
unity-tech-cn:/hh-develop-gamedev-demo
unity-tech-cn:/active-variablespeed
unity-tech-cn:/release_4_branch
unity-tech-cn:/fix-env-step-loop
unity-tech-cn:/release_5_branch
unity-tech-cn:/fix-walker
unity-tech-cn:/release_6_branch
unity-tech-cn:/hh-32-observation-crawler
unity-tech-cn:/trainer-plugin
unity-tech-cn:/hh-develop-max-steps-demo-recorder
unity-tech-cn:/hh-develop-loco-walker-variable-speed
unity-tech-cn:/exp-0002
unity-tech-cn:/experiment-less-max-step
unity-tech-cn:/hh-develop-hallway-wall-mesh-fix
unity-tech-cn:/release_7_branch
unity-tech-cn:/exp-vince
unity-tech-cn:/hh-develop-gridsensor-tests
unity-tech-cn:/tag-release_8_test0
unity-tech-cn:/tag-release_8_test1
unity-tech-cn:/release_8_branch
unity-tech-cn:/docfix-end-episode
unity-tech-cn:/release_9_branch
unity-tech-cn:/hybrid-action-rewardsignals
unity-tech-cn:/MLA-462-yamato-win
unity-tech-cn:/exp-alternate-atten
unity-tech-cn:/hh-develop-fps_game_project
unity-tech-cn:/fix-conflict-base-env
unity-tech-cn:/release_10_branch
unity-tech-cn:/exp-bullet-hell-trainer
unity-tech-cn:/ai-summit-exp
unity-tech-cn:/comms-grad
unity-tech-cn:/walljump-pushblock
unity-tech-cn:/goal-conditioning
unity-tech-cn:/release_11_branch
unity-tech-cn:/hh-develop-water-balloon-fight
unity-tech-cn:/gc-hyper
unity-tech-cn:/layernorm
unity-tech-cn:/yamato-linux-debug-venv
unity-tech-cn:/soccer-comms
unity-tech-cn:/hh-develop-pushblockcollab
unity-tech-cn:/release_12_branch
unity-tech-cn:/fix-get-step-sp-curr
unity-tech-cn:/continuous-comms
unity-tech-cn:/no-comms
unity-tech-cn:/hh-develop-zombiepushblock
unity-tech-cn:/hypernetwork
unity-tech-cn:/revert-4859-develop-update-readme
unity-tech-cn:/sequencer-env-attention
unity-tech-cn:/hh-develop-variableobs
unity-tech-cn:/exp-tanh
unity-tech-cn:/reward-dist
unity-tech-cn:/exp-weight-decay
unity-tech-cn:/exp-robot
unity-tech-cn:/bullet-hell-barracuda-test-1.3.1
unity-tech-cn:/release_13_branch
unity-tech-cn:/release_14_branch
unity-tech-cn:/exp-clipped-gaussian-entropy
unity-tech-cn:/tic-tac-toe
unity-tech-cn:/hh-develop-dodgeball
unity-tech-cn:/repro-vis-obs-perf
unity-tech-cn:/v2-staging-rebase
unity-tech-cn:/release_15_branch
unity-tech-cn:/release_15_removeendepisode
unity-tech-cn:/release_16_branch
unity-tech-cn:/release_16_fix_gridsensor
unity-tech-cn:/ai-hw-2021
unity-tech-cn:/check-for-ModelOverriders
unity-tech-cn:/fix-grid-obs-shape-init
unity-tech-cn:/fix-gym-needs-reset
unity-tech-cn:/fix-resume-imi
unity-tech-cn:/release_17_branch
unity-tech-cn:/release_17_branch_gpu_test
unity-tech-cn:/colab-links
unity-tech-cn:/exp-continuous-div
unity-tech-cn:/release_17_branch_gpu_2
unity-tech-cn:/exp-diverse-behavior
unity-tech-cn:/grid-onehot-extra-dim-empty
unity-tech-cn:/2.0-verified
unity-tech-cn:/faster-entropy-coeficient-convergence
unity-tech-cn:/pre-r18-update-changelog
unity-tech-cn:/release_18_branch
unity-tech-cn:/main/tracking
unity-tech-cn:/main/reward-providers
unity-tech-cn:/main/project-upgrade
unity-tech-cn:/main/limitation-docs
unity-tech-cn:/develop/nomaxstep-test
unity-tech-cn:/develop/tf2.0
unity-tech-cn:/develop/tanhsquash
unity-tech-cn:/develop/magic-string
unity-tech-cn:/develop/trainerinterface
unity-tech-cn:/develop/separatevalue
unity-tech-cn:/develop/nopreviousactions
unity-tech-cn:/develop/reenablerepeatactions
unity-tech-cn:/develop/0memories
unity-tech-cn:/develop/fixmemoryleak
unity-tech-cn:/develop/reducewalljump
unity-tech-cn:/develop/removeactionholder-onehot
unity-tech-cn:/develop/canonicalize-quaternions
unity-tech-cn:/develop/self-playassym
unity-tech-cn:/develop/demo-load-seek
unity-tech-cn:/develop/progress-bar
unity-tech-cn:/develop/sac-apex
unity-tech-cn:/develop/cubewars
unity-tech-cn:/develop/add-fire
unity-tech-cn:/develop/gym-wrapper
unity-tech-cn:/develop/mm-docs-main-readme
unity-tech-cn:/develop/mm-docs-overview
unity-tech-cn:/develop/no-threading
unity-tech-cn:/develop/dockerfile
unity-tech-cn:/develop/model-store
unity-tech-cn:/develop/checkout-conversion-rebase
unity-tech-cn:/develop/model-transfer
unity-tech-cn:/develop/bisim-review
unity-tech-cn:/develop/taggedobservations
unity-tech-cn:/develop/transfer-bisim
unity-tech-cn:/develop/bisim-sac-transfer
unity-tech-cn:/develop/basketball
unity-tech-cn:/develop/torchmodules
unity-tech-cn:/develop/fixmarkdown
unity-tech-cn:/develop/shortenstrikervsgoalie
unity-tech-cn:/develop/shortengoalie
unity-tech-cn:/develop/torch-save-rp
unity-tech-cn:/develop/torch-to-np
unity-tech-cn:/develop/torch-omp-no-thread
unity-tech-cn:/develop/actionmodel-csharp
unity-tech-cn:/develop/torch-extra
unity-tech-cn:/develop/restructure-torch-networks
unity-tech-cn:/develop/jit
unity-tech-cn:/develop/adjust-cpu-settings-experiment
unity-tech-cn:/develop/torch-sac-threading
unity-tech-cn:/develop/wb
unity-tech-cn:/develop/amrl
unity-tech-cn:/develop/memorydump
unity-tech-cn:/develop/permutepytorch
unity-tech-cn:/develop/sac-targetq
unity-tech-cn:/develop/actions-out
unity-tech-cn:/develop/reshapeonnxmemories
unity-tech-cn:/develop/crawlergail
unity-tech-cn:/develop/debugtorchfood
unity-tech-cn:/develop/hybrid-actions
unity-tech-cn:/develop/bullet-hell
unity-tech-cn:/develop/action-spec-gym
unity-tech-cn:/develop/battlefoodcollector
unity-tech-cn:/develop/use-action-buffers
unity-tech-cn:/develop/hardswish
unity-tech-cn:/develop/leakyrelu
unity-tech-cn:/develop/torch-clip-scale
unity-tech-cn:/develop/contentropy
unity-tech-cn:/develop/manch
unity-tech-cn:/develop/torchcrawlerdebug
unity-tech-cn:/develop/fix-nan
unity-tech-cn:/develop/multitype-buffer
unity-tech-cn:/develop/windows-delay
unity-tech-cn:/develop/torch-tanh
unity-tech-cn:/develop/gail-norm
unity-tech-cn:/develop/multiprocess
unity-tech-cn:/develop/unified-obs
unity-tech-cn:/develop/rm-rf-new-models
unity-tech-cn:/develop/skipcritic
unity-tech-cn:/develop/centralizedcritic
unity-tech-cn:/develop/dodgeball-tests
unity-tech-cn:/develop/cc-teammanager
unity-tech-cn:/develop/weight-decay
unity-tech-cn:/develop/singular-embeddings
unity-tech-cn:/develop/zombieteammanager
unity-tech-cn:/develop/superpush
unity-tech-cn:/develop/teammanager
unity-tech-cn:/develop/zombie-exp
unity-tech-cn:/develop/update-readme
unity-tech-cn:/develop/readme-fix
unity-tech-cn:/develop/coma-noact
unity-tech-cn:/develop/coma-withq
unity-tech-cn:/develop/coma2
unity-tech-cn:/develop/action-slice
unity-tech-cn:/develop/gru
unity-tech-cn:/develop/critic-op-lstm-currentmem
unity-tech-cn:/develop/decaygail
unity-tech-cn:/develop/gail-srl-hack
unity-tech-cn:/develop/rear-pad
unity-tech-cn:/develop/mm-copyright-dates
unity-tech-cn:/develop/dodgeball-raycasts
unity-tech-cn:/develop/collab-envs-exp-ervin
unity-tech-cn:/develop/pushcollabonly
unity-tech-cn:/develop/sample-curation
unity-tech-cn:/develop/soccer-groupman
unity-tech-cn:/develop/input-actuator-tanks
unity-tech-cn:/develop/validate-release-fix
unity-tech-cn:/develop/new-console-log
unity-tech-cn:/develop/lex-walker-model
unity-tech-cn:/develop/lstm-burnin
unity-tech-cn:/develop/grid-vaiable-names
unity-tech-cn:/develop/fix-attn-embedding
unity-tech-cn:/develop/api-documentation-update-some-fixes
unity-tech-cn:/develop/update-grpc
unity-tech-cn:/develop/grid-rootref-debug
unity-tech-cn:/develop/pbcollab-rays
unity-tech-cn:/develop/2.0-verified-pre
unity-tech-cn:/develop/parameterizedenvs
unity-tech-cn:/develop/custom-ray-sensor
unity-tech-cn:/develop/mm-add-v2blog
unity-tech-cn:/develop/custom-raycast
unity-tech-cn:/develop/area-manager
unity-tech-cn:/develop/remove-unecessary-lr
unity-tech-cn:/develop/use-base-env-in-learn
unity-tech-cn:/soccer-fives/multiagent
unity-tech-cn:/develop/cubewars/splashdamage
unity-tech-cn:/develop/add-fire/exp
unity-tech-cn:/develop/add-fire/jit
unity-tech-cn:/develop/add-fire/speedtest
unity-tech-cn:/develop/add-fire/bc
unity-tech-cn:/develop/add-fire/ckpt-2
unity-tech-cn:/develop/add-fire/normalize-context
unity-tech-cn:/develop/add-fire/components-dir
unity-tech-cn:/develop/add-fire/halfentropy
unity-tech-cn:/develop/add-fire/memoryclass
unity-tech-cn:/develop/add-fire/categoricaldist
unity-tech-cn:/develop/add-fire/mm
unity-tech-cn:/develop/add-fire/sac-lst
unity-tech-cn:/develop/add-fire/mm3
unity-tech-cn:/develop/add-fire/continuous
unity-tech-cn:/develop/add-fire/ghost
unity-tech-cn:/develop/add-fire/policy-tests
unity-tech-cn:/develop/add-fire/export-discrete
unity-tech-cn:/develop/add-fire/test-simple-rl-fix-resnet
unity-tech-cn:/develop/add-fire/remove-currdoc
unity-tech-cn:/develop/add-fire/clean2
unity-tech-cn:/develop/add-fire/doc-cleanups
unity-tech-cn:/develop/add-fire/changelog
unity-tech-cn:/develop/add-fire/mm2
unity-tech-cn:/develop/model-transfer/add-physics
unity-tech-cn:/develop/model-transfer/train
unity-tech-cn:/develop/jit/experiments
unity-tech-cn:/exp-vince/sep30-2020
unity-tech-cn:/hh-develop-gridsensor-tests/static
unity-tech-cn:/develop/hybrid-actions/distlist
unity-tech-cn:/develop/bullet-hell/buffer
unity-tech-cn:/goal-conditioning/new
unity-tech-cn:/goal-conditioning/sensors-2
unity-tech-cn:/goal-conditioning/sensors-3-pytest-fix
unity-tech-cn:/goal-conditioning/grid-world
unity-tech-cn:/soccer-comms/disc
unity-tech-cn:/develop/centralizedcritic/counterfact
unity-tech-cn:/develop/centralizedcritic/mm
unity-tech-cn:/develop/centralizedcritic/nonego
unity-tech-cn:/develop/zombieteammanager/disableagent
unity-tech-cn:/develop/zombieteammanager/killfirst
unity-tech-cn:/develop/superpush/int
unity-tech-cn:/develop/superpush/branch-cleanup
unity-tech-cn:/develop/teammanager/int
unity-tech-cn:/develop/teammanager/cubewar-nocycle
unity-tech-cn:/develop/teammanager/cubewars
unity-tech-cn:/develop/superpush/int/hunter
unity-tech-cn:/goal-conditioning/new/allo-crawler
unity-tech-cn:/develop/coma2/clip
unity-tech-cn:/develop/coma2/singlenetwork
unity-tech-cn:/develop/coma2/samenet
unity-tech-cn:/develop/coma2/fixgroup
unity-tech-cn:/develop/coma2/samenet/sum
unity-tech-cn:/hh-develop-dodgeball/goy-input
unity-tech-cn:/develop/soccer-groupman/mod
unity-tech-cn:/develop/soccer-groupman/mod/hunter
unity-tech-cn:/develop/soccer-groupman/mod/hunter/cine
unity-tech-cn:/ai-hw-2021/tensor-applier
拉取从: unity-tech-cn:develop/add-fire/mm
unity-tech-cn:/main
unity-tech-cn:/develop-generalizationTraining-TrainerController
unity-tech-cn:/tag-0.2.0
unity-tech-cn:/tag-0.2.1
unity-tech-cn:/tag-0.2.1a
unity-tech-cn:/tag-0.2.1c
unity-tech-cn:/tag-0.2.1d
unity-tech-cn:/hotfix-v0.9.2a
unity-tech-cn:/develop-gpu-test
unity-tech-cn:/0.10.1
unity-tech-cn:/develop-pyinstaller
unity-tech-cn:/develop-horovod
unity-tech-cn:/PhysXArticulations20201
unity-tech-cn:/importdocfix
unity-tech-cn:/develop-resizetexture
unity-tech-cn:/hh-develop-walljump_bugfixes
unity-tech-cn:/develop-walljump-fix-sac
unity-tech-cn:/hh-develop-walljump_rnd
unity-tech-cn:/tag-0.11.0.dev0
unity-tech-cn:/develop-pytorch
unity-tech-cn:/tag-0.11.0.dev2
unity-tech-cn:/develop-newnormalization
unity-tech-cn:/tag-0.11.0.dev3
unity-tech-cn:/develop
unity-tech-cn:/release-0.12.0
unity-tech-cn:/tag-0.12.0-dev
unity-tech-cn:/tag-0.12.0.dev0
unity-tech-cn:/tag-0.12.1
unity-tech-cn:/2D-explorations
unity-tech-cn:/asymm-envs
unity-tech-cn:/tag-0.12.1.dev0
unity-tech-cn:/2D-exploration-raycast
unity-tech-cn:/tag-0.12.1.dev1
unity-tech-cn:/release-0.13.0
unity-tech-cn:/release-0.13.1
unity-tech-cn:/plugin-proof-of-concept
unity-tech-cn:/release-0.14.0
unity-tech-cn:/hotfix-bump-version-master
unity-tech-cn:/soccer-fives
unity-tech-cn:/release-0.14.1
unity-tech-cn:/bug-failed-api-check
unity-tech-cn:/test-recurrent-gail
unity-tech-cn:/hh-add-icons
unity-tech-cn:/release-0.15.0
unity-tech-cn:/release-0.15.1
unity-tech-cn:/hh-develop-all-posed-characters
unity-tech-cn:/internal-policy-ghost
unity-tech-cn:/distributed-training
unity-tech-cn:/hh-develop-improve_tennis
unity-tech-cn:/test-tf-ver
unity-tech-cn:/release_1_branch
unity-tech-cn:/tennis-time-horizon
unity-tech-cn:/whitepaper-experiments
unity-tech-cn:/r2v-yamato-linux
unity-tech-cn:/docs-update
unity-tech-cn:/release_2_branch
unity-tech-cn:/exp-mede
unity-tech-cn:/sensitivity
unity-tech-cn:/release_2_verified_load_fix
unity-tech-cn:/test-sampler
unity-tech-cn:/release_2_verified
unity-tech-cn:/hh-develop-ragdoll-testing
unity-tech-cn:/origin-develop-taggedobservations
unity-tech-cn:/MLA-1734-demo-provider
unity-tech-cn:/sampler-refactor-copy
unity-tech-cn:/PhysXArticulations20201Package
unity-tech-cn:/tag-com.unity.ml-agents_1.0.8
unity-tech-cn:/release_3_branch
unity-tech-cn:/github-actions
unity-tech-cn:/release_3_distributed
unity-tech-cn:/fix-batch-tennis
unity-tech-cn:/distributed-ppo-sac
unity-tech-cn:/gridworld-custom-obs
unity-tech-cn:/hw20-segmentation
unity-tech-cn:/hh-develop-gamedev-demo
unity-tech-cn:/active-variablespeed
unity-tech-cn:/release_4_branch
unity-tech-cn:/fix-env-step-loop
unity-tech-cn:/release_5_branch
unity-tech-cn:/fix-walker
unity-tech-cn:/release_6_branch
unity-tech-cn:/hh-32-observation-crawler
unity-tech-cn:/trainer-plugin
unity-tech-cn:/hh-develop-max-steps-demo-recorder
unity-tech-cn:/hh-develop-loco-walker-variable-speed
unity-tech-cn:/exp-0002
unity-tech-cn:/experiment-less-max-step
unity-tech-cn:/hh-develop-hallway-wall-mesh-fix
unity-tech-cn:/release_7_branch
unity-tech-cn:/exp-vince
unity-tech-cn:/hh-develop-gridsensor-tests
unity-tech-cn:/tag-release_8_test0
unity-tech-cn:/tag-release_8_test1
unity-tech-cn:/release_8_branch
unity-tech-cn:/docfix-end-episode
unity-tech-cn:/release_9_branch
unity-tech-cn:/hybrid-action-rewardsignals
unity-tech-cn:/MLA-462-yamato-win
unity-tech-cn:/exp-alternate-atten
unity-tech-cn:/hh-develop-fps_game_project
unity-tech-cn:/fix-conflict-base-env
unity-tech-cn:/release_10_branch
unity-tech-cn:/exp-bullet-hell-trainer
unity-tech-cn:/ai-summit-exp
unity-tech-cn:/comms-grad
unity-tech-cn:/walljump-pushblock
unity-tech-cn:/goal-conditioning
unity-tech-cn:/release_11_branch
unity-tech-cn:/hh-develop-water-balloon-fight
unity-tech-cn:/gc-hyper
unity-tech-cn:/layernorm
unity-tech-cn:/yamato-linux-debug-venv
unity-tech-cn:/soccer-comms
unity-tech-cn:/hh-develop-pushblockcollab
unity-tech-cn:/release_12_branch
unity-tech-cn:/fix-get-step-sp-curr
unity-tech-cn:/continuous-comms
unity-tech-cn:/no-comms
unity-tech-cn:/hh-develop-zombiepushblock
unity-tech-cn:/hypernetwork
unity-tech-cn:/revert-4859-develop-update-readme
unity-tech-cn:/sequencer-env-attention
unity-tech-cn:/hh-develop-variableobs
unity-tech-cn:/exp-tanh
unity-tech-cn:/reward-dist
unity-tech-cn:/exp-weight-decay
unity-tech-cn:/exp-robot
unity-tech-cn:/bullet-hell-barracuda-test-1.3.1
unity-tech-cn:/release_13_branch
unity-tech-cn:/release_14_branch
unity-tech-cn:/exp-clipped-gaussian-entropy
unity-tech-cn:/tic-tac-toe
unity-tech-cn:/hh-develop-dodgeball
unity-tech-cn:/repro-vis-obs-perf
unity-tech-cn:/v2-staging-rebase
unity-tech-cn:/release_15_branch
unity-tech-cn:/release_15_removeendepisode
unity-tech-cn:/release_16_branch
unity-tech-cn:/release_16_fix_gridsensor
unity-tech-cn:/ai-hw-2021
unity-tech-cn:/check-for-ModelOverriders
unity-tech-cn:/fix-grid-obs-shape-init
unity-tech-cn:/fix-gym-needs-reset
unity-tech-cn:/fix-resume-imi
unity-tech-cn:/release_17_branch
unity-tech-cn:/release_17_branch_gpu_test
unity-tech-cn:/colab-links
unity-tech-cn:/exp-continuous-div
unity-tech-cn:/release_17_branch_gpu_2
unity-tech-cn:/exp-diverse-behavior
unity-tech-cn:/grid-onehot-extra-dim-empty
unity-tech-cn:/2.0-verified
unity-tech-cn:/faster-entropy-coeficient-convergence
unity-tech-cn:/pre-r18-update-changelog
unity-tech-cn:/release_18_branch
unity-tech-cn:/main/tracking
unity-tech-cn:/main/reward-providers
unity-tech-cn:/main/project-upgrade
unity-tech-cn:/main/limitation-docs
unity-tech-cn:/develop/nomaxstep-test
unity-tech-cn:/develop/tf2.0
unity-tech-cn:/develop/tanhsquash
unity-tech-cn:/develop/magic-string
unity-tech-cn:/develop/trainerinterface
unity-tech-cn:/develop/separatevalue
unity-tech-cn:/develop/nopreviousactions
unity-tech-cn:/develop/reenablerepeatactions
unity-tech-cn:/develop/0memories
unity-tech-cn:/develop/fixmemoryleak
unity-tech-cn:/develop/reducewalljump
unity-tech-cn:/develop/removeactionholder-onehot
unity-tech-cn:/develop/canonicalize-quaternions
unity-tech-cn:/develop/self-playassym
unity-tech-cn:/develop/demo-load-seek
unity-tech-cn:/develop/progress-bar
unity-tech-cn:/develop/sac-apex
unity-tech-cn:/develop/cubewars
unity-tech-cn:/develop/add-fire
unity-tech-cn:/develop/gym-wrapper
unity-tech-cn:/develop/mm-docs-main-readme
unity-tech-cn:/develop/mm-docs-overview
unity-tech-cn:/develop/no-threading
unity-tech-cn:/develop/dockerfile
unity-tech-cn:/develop/model-store
unity-tech-cn:/develop/checkout-conversion-rebase
unity-tech-cn:/develop/model-transfer
unity-tech-cn:/develop/bisim-review
unity-tech-cn:/develop/taggedobservations
unity-tech-cn:/develop/transfer-bisim
unity-tech-cn:/develop/bisim-sac-transfer
unity-tech-cn:/develop/basketball
unity-tech-cn:/develop/torchmodules
unity-tech-cn:/develop/fixmarkdown
unity-tech-cn:/develop/shortenstrikervsgoalie
unity-tech-cn:/develop/shortengoalie
unity-tech-cn:/develop/torch-save-rp
unity-tech-cn:/develop/torch-to-np
unity-tech-cn:/develop/torch-omp-no-thread
unity-tech-cn:/develop/actionmodel-csharp
unity-tech-cn:/develop/torch-extra
unity-tech-cn:/develop/restructure-torch-networks
unity-tech-cn:/develop/jit
unity-tech-cn:/develop/adjust-cpu-settings-experiment
unity-tech-cn:/develop/torch-sac-threading
unity-tech-cn:/develop/wb
unity-tech-cn:/develop/amrl
unity-tech-cn:/develop/memorydump
unity-tech-cn:/develop/permutepytorch
unity-tech-cn:/develop/sac-targetq
unity-tech-cn:/develop/actions-out
unity-tech-cn:/develop/reshapeonnxmemories
unity-tech-cn:/develop/crawlergail
unity-tech-cn:/develop/debugtorchfood
unity-tech-cn:/develop/hybrid-actions
unity-tech-cn:/develop/bullet-hell
unity-tech-cn:/develop/action-spec-gym
unity-tech-cn:/develop/battlefoodcollector
unity-tech-cn:/develop/use-action-buffers
unity-tech-cn:/develop/hardswish
unity-tech-cn:/develop/leakyrelu
unity-tech-cn:/develop/torch-clip-scale
unity-tech-cn:/develop/contentropy
unity-tech-cn:/develop/manch
unity-tech-cn:/develop/torchcrawlerdebug
unity-tech-cn:/develop/fix-nan
unity-tech-cn:/develop/multitype-buffer
unity-tech-cn:/develop/windows-delay
unity-tech-cn:/develop/torch-tanh
unity-tech-cn:/develop/gail-norm
unity-tech-cn:/develop/multiprocess
unity-tech-cn:/develop/unified-obs
unity-tech-cn:/develop/rm-rf-new-models
unity-tech-cn:/develop/skipcritic
unity-tech-cn:/develop/centralizedcritic
unity-tech-cn:/develop/dodgeball-tests
unity-tech-cn:/develop/cc-teammanager
unity-tech-cn:/develop/weight-decay
unity-tech-cn:/develop/singular-embeddings
unity-tech-cn:/develop/zombieteammanager
unity-tech-cn:/develop/superpush
unity-tech-cn:/develop/teammanager
unity-tech-cn:/develop/zombie-exp
unity-tech-cn:/develop/update-readme
unity-tech-cn:/develop/readme-fix
unity-tech-cn:/develop/coma-noact
unity-tech-cn:/develop/coma-withq
unity-tech-cn:/develop/coma2
unity-tech-cn:/develop/action-slice
unity-tech-cn:/develop/gru
unity-tech-cn:/develop/critic-op-lstm-currentmem
unity-tech-cn:/develop/decaygail
unity-tech-cn:/develop/gail-srl-hack
unity-tech-cn:/develop/rear-pad
unity-tech-cn:/develop/mm-copyright-dates
unity-tech-cn:/develop/dodgeball-raycasts
unity-tech-cn:/develop/collab-envs-exp-ervin
unity-tech-cn:/develop/pushcollabonly
unity-tech-cn:/develop/sample-curation
unity-tech-cn:/develop/soccer-groupman
unity-tech-cn:/develop/input-actuator-tanks
unity-tech-cn:/develop/validate-release-fix
unity-tech-cn:/develop/new-console-log
unity-tech-cn:/develop/lex-walker-model
unity-tech-cn:/develop/lstm-burnin
unity-tech-cn:/develop/grid-vaiable-names
unity-tech-cn:/develop/fix-attn-embedding
unity-tech-cn:/develop/api-documentation-update-some-fixes
unity-tech-cn:/develop/update-grpc
unity-tech-cn:/develop/grid-rootref-debug
unity-tech-cn:/develop/pbcollab-rays
unity-tech-cn:/develop/2.0-verified-pre
unity-tech-cn:/develop/parameterizedenvs
unity-tech-cn:/develop/custom-ray-sensor
unity-tech-cn:/develop/mm-add-v2blog
unity-tech-cn:/develop/custom-raycast
unity-tech-cn:/develop/area-manager
unity-tech-cn:/develop/remove-unecessary-lr
unity-tech-cn:/develop/use-base-env-in-learn
unity-tech-cn:/soccer-fives/multiagent
unity-tech-cn:/develop/cubewars/splashdamage
unity-tech-cn:/develop/add-fire/exp
unity-tech-cn:/develop/add-fire/jit
unity-tech-cn:/develop/add-fire/speedtest
unity-tech-cn:/develop/add-fire/bc
unity-tech-cn:/develop/add-fire/ckpt-2
unity-tech-cn:/develop/add-fire/normalize-context
unity-tech-cn:/develop/add-fire/components-dir
unity-tech-cn:/develop/add-fire/halfentropy
unity-tech-cn:/develop/add-fire/memoryclass
unity-tech-cn:/develop/add-fire/categoricaldist
unity-tech-cn:/develop/add-fire/mm
unity-tech-cn:/develop/add-fire/sac-lst
unity-tech-cn:/develop/add-fire/mm3
unity-tech-cn:/develop/add-fire/continuous
unity-tech-cn:/develop/add-fire/ghost
unity-tech-cn:/develop/add-fire/policy-tests
unity-tech-cn:/develop/add-fire/export-discrete
unity-tech-cn:/develop/add-fire/test-simple-rl-fix-resnet
unity-tech-cn:/develop/add-fire/remove-currdoc
unity-tech-cn:/develop/add-fire/clean2
unity-tech-cn:/develop/add-fire/doc-cleanups
unity-tech-cn:/develop/add-fire/changelog
unity-tech-cn:/develop/add-fire/mm2
unity-tech-cn:/develop/model-transfer/add-physics
unity-tech-cn:/develop/model-transfer/train
unity-tech-cn:/develop/jit/experiments
unity-tech-cn:/exp-vince/sep30-2020
unity-tech-cn:/hh-develop-gridsensor-tests/static
unity-tech-cn:/develop/hybrid-actions/distlist
unity-tech-cn:/develop/bullet-hell/buffer
unity-tech-cn:/goal-conditioning/new
unity-tech-cn:/goal-conditioning/sensors-2
unity-tech-cn:/goal-conditioning/sensors-3-pytest-fix
unity-tech-cn:/goal-conditioning/grid-world
unity-tech-cn:/soccer-comms/disc
unity-tech-cn:/develop/centralizedcritic/counterfact
unity-tech-cn:/develop/centralizedcritic/mm
unity-tech-cn:/develop/centralizedcritic/nonego
unity-tech-cn:/develop/zombieteammanager/disableagent
unity-tech-cn:/develop/zombieteammanager/killfirst
unity-tech-cn:/develop/superpush/int
unity-tech-cn:/develop/superpush/branch-cleanup
unity-tech-cn:/develop/teammanager/int
unity-tech-cn:/develop/teammanager/cubewar-nocycle
unity-tech-cn:/develop/teammanager/cubewars
unity-tech-cn:/develop/superpush/int/hunter
unity-tech-cn:/goal-conditioning/new/allo-crawler
unity-tech-cn:/develop/coma2/clip
unity-tech-cn:/develop/coma2/singlenetwork
unity-tech-cn:/develop/coma2/samenet
unity-tech-cn:/develop/coma2/fixgroup
unity-tech-cn:/develop/coma2/samenet/sum
unity-tech-cn:/hh-develop-dodgeball/goy-input
unity-tech-cn:/develop/soccer-groupman/mod
unity-tech-cn:/develop/soccer-groupman/mod/hunter
unity-tech-cn:/develop/soccer-groupman/mod/hunter/cine
unity-tech-cn:/ai-hw-2021/tensor-applier
此合并请求有变更与目标分支冲突。
/test_requirements.txt
/docs/Training-ML-Agents.md
/ml-agents/mlagents/trainers/cli_utils.py
/ml-agents/mlagents/trainers/settings.py
/ml-agents/mlagents/trainers/learn.py
/ml-agents/mlagents/trainers/ppo/trainer.py
/ml-agents/mlagents/trainers/sac/trainer.py
/ml-agents/mlagents/trainers/trainer/rl_trainer.py
/ml-agents/mlagents/trainers/tests/test_rl_trainer.py
/ml-agents/mlagents/trainers/buffer.py
/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
/ml-agents/mlagents/trainers/policy/torch_policy.py
/ml-agents/mlagents/trainers/ppo/optimizer_torch.py
/ml-agents/mlagents/trainers/sac/optimizer_torch.py
/ml-agents/mlagents/trainers/torch
/ml-agents/mlagents/trainers/tests/torch
/.circleci/config.yml
/ml-agents/mlagents/trainers/ppo/optimizer_tf.py
/ml-agents/mlagents/trainers/tests/test_ppo.py
/ml-agents/mlagents/trainers/tests/test_reward_signals.py
/com.unity.ml-agents.extensions/Tests/Editor/Sensors
/ml-agents/mlagents/trainers/ppo/optimizer_tf.py
/utils/validate_release_links.py
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/com.unity.ml-agents.extensions/Runtime/Sensors
/ml-agents/mlagents/trainers/tf/distributions.py
/ml-agents/mlagents/trainers/tf/models.py
/ml-agents/mlagents/trainers/tf/tensorflow_to_barracuda.py
/com.unity.ml-agents.extensions/Tests/Editor/Sensors
/com.unity.ml-agents.extensions/Tests/Editor/Sensors
/com.unity.ml-agents.extensions/Tests/Editor/Sensors
/com.unity.ml-agents.extensions/Tests/Editor/Sensors
1 次代码提交
作者 | SHA1 | 备注 | 提交日期 |
---|---|---|---|
Ervin Teng | dc937d5c | Merge branch 'master' into develop-add-fire-mm | 4 年前 |
共有 110 个文件被更改,包括 6442 次插入 和 267 次删除
-
2.circleci/config.yml
-
3test_requirements.txt
-
0config/ppo/WalkerDynamic.yaml
-
0config/sac/3DBall.yaml
-
0config/sac/3DBallHard.yaml
-
0Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerDynamic.nn
-
0Project/Assets/ML-Agents/Examples/Walker/TFModels/WalkerStatic.nn
-
5Project/ProjectSettings/EditorBuildSettings.asset
-
2Project/ProjectSettings/UnityConnectSettings.asset
-
79docs/Training-ML-Agents.md
-
4ml-agents-envs/mlagents_envs/communicator_objects/unity_to_external_pb2_grpc.py
-
2ml-agents/mlagents/trainers/buffer.py
-
7ml-agents/mlagents/trainers/cli_utils.py
-
2ml-agents/mlagents/trainers/ppo/optimizer_tf.py
-
78ml-agents/mlagents/trainers/ppo/trainer.py
-
118ml-agents/mlagents/trainers/sac/trainer.py
-
35ml-agents/mlagents/trainers/learn.py
-
21ml-agents/mlagents/trainers/settings.py
-
2ml-agents/mlagents/trainers/tests/test_ppo.py
-
2ml-agents/mlagents/trainers/tests/test_reward_signals.py
-
5ml-agents/mlagents/trainers/tests/test_rl_trainer.py
-
64ml-agents/mlagents/trainers/trainer/rl_trainer.py
-
2com.unity.ml-agents/Runtime/SensorHelper.cs.meta
-
2Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/OrientationCubeController.cs.meta
-
4ml-agents/mlagents/trainers/tf/models.py
-
248experiment_torch.py
-
2Project/Assets/csc.rsp
-
111ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
-
289ml-agents/mlagents/trainers/policy/torch_policy.py
-
182ml-agents/mlagents/trainers/ppo/optimizer_torch.py
-
483ml-agents/mlagents/trainers/sac/optimizer_torch.py
-
36ml-agents/mlagents/trainers/tests/test_models.py
-
132utils/validate_release_links.py
-
11com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs.meta
-
49com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
-
11com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs.meta
-
78com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
-
142com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
-
152com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
-
358com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
-
141com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
-
66com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
-
140com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
-
187com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
-
119com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
-
136com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
-
0ml-agents/mlagents/trainers/tf/__init__.py
-
0ml-agents/mlagents/trainers/torch/__init__.py
-
485ml-agents/mlagents/trainers/torch/networks.py
-
0ml-agents/mlagents/trainers/torch/components/__init__.py
-
15ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py
-
72ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py
-
15ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py
-
43ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py
-
225ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
-
254ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
-
23ml-agents/mlagents/trainers/torch/decoders.py
-
206ml-agents/mlagents/trainers/torch/distributions.py
-
280ml-agents/mlagents/trainers/torch/encoders.py
-
48ml-agents/mlagents/trainers/torch/layers.py
-
286ml-agents/mlagents/trainers/torch/utils.py
-
31ml-agents/mlagents/trainers/tests/torch/test_decoders.py
-
141ml-agents/mlagents/trainers/tests/torch/test_distributions.py
-
110ml-agents/mlagents/trainers/tests/torch/test_encoders.py
-
200ml-agents/mlagents/trainers/tests/torch/test_utils.py
-
111ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py
-
56ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py
-
138ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py
-
32ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py
-
20ml-agents/mlagents/trainers/tests/torch/test_layers.py
-
214ml-agents/mlagents/trainers/tests/torch/test_networks.py
-
11Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/OrientationCubeController.cs.meta
-
2Project/Assets/csc.rsp
-
11com.unity.ml-agents/Runtime/SensorHelper.cs.meta
-
36ml-agents/mlagents/trainers/tests/test_models.py
-
132utils/validate_release_links.py
-
0/ml-agents/mlagents/trainers/ppo/optimizer_tf.py
-
0/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs.meta
-
0/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs.meta
-
0/com.unity.ml-agents/Runtime/SensorHelper.cs.meta
-
0/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/OrientationCubeController.cs.meta
-
0/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs.meta
-
0/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs.meta
-
0/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs.meta
-
0/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs
-
0/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta
-
0/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs
-
0/com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta
-
0/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs
|
|||
import json |
|||
import os |
|||
import torch |
|||
from mlagents.tf_utils import tf |
|||
import argparse |
|||
from mlagents.trainers.learn import run_cli, parse_command_line |
|||
from mlagents.trainers.settings import TestingConfiguration |
|||
from mlagents.trainers.stats import StatsReporter |
|||
from mlagents_envs.timers import _thread_timer_stacks |
|||
|
|||
|
|||
def run_experiment( |
|||
name: str, |
|||
steps: int, |
|||
use_torch: bool, |
|||
algo: str, |
|||
num_torch_threads: int, |
|||
use_gpu: bool, |
|||
num_envs: int = 1, |
|||
config_name=None, |
|||
): |
|||
TestingConfiguration.env_name = name |
|||
TestingConfiguration.max_steps = steps |
|||
TestingConfiguration.use_torch = use_torch |
|||
TestingConfiguration.device = "cuda:0" if use_gpu else "cpu" |
|||
if use_gpu: |
|||
tf.device("/GPU:0") |
|||
else: |
|||
tf.device("/device:CPU:0") |
|||
if not torch.cuda.is_available() and use_gpu: |
|||
return ( |
|||
name, |
|||
str(steps), |
|||
str(use_torch), |
|||
algo, |
|||
str(num_torch_threads), |
|||
str(num_envs), |
|||
str(use_gpu), |
|||
"na", |
|||
"na", |
|||
"na", |
|||
"na", |
|||
"na", |
|||
"na", |
|||
"na", |
|||
) |
|||
if config_name is None: |
|||
config_name = name |
|||
run_options = parse_command_line( |
|||
[f"config/{algo}/{config_name}.yaml", "--num-envs", f"{num_envs}"] |
|||
) |
|||
run_options.checkpoint_settings.run_id = ( |
|||
f"{name}_test_" + str(steps) + "_" + ("torch" if use_torch else "tf") |
|||
) |
|||
run_options.checkpoint_settings.force = True |
|||
# run_options.env_settings.num_envs = num_envs |
|||
for trainer_settings in run_options.behaviors.values(): |
|||
trainer_settings.threaded = False |
|||
timers_path = os.path.join( |
|||
"results", run_options.checkpoint_settings.run_id, "run_logs", "timers.json" |
|||
) |
|||
if use_torch: |
|||
torch.set_num_threads(num_torch_threads) |
|||
run_cli(run_options) |
|||
StatsReporter.writers.clear() |
|||
StatsReporter.stats_dict.clear() |
|||
_thread_timer_stacks.clear() |
|||
with open(timers_path) as timers_json_file: |
|||
timers_json = json.load(timers_json_file) |
|||
total = timers_json["total"] |
|||
tc_advance = timers_json["children"]["TrainerController.start_learning"][ |
|||
"children" |
|||
]["TrainerController.advance"] |
|||
evaluate = timers_json["children"]["TrainerController.start_learning"][ |
|||
"children" |
|||
]["TrainerController.advance"]["children"]["env_step"]["children"][ |
|||
"SubprocessEnvManager._take_step" |
|||
][ |
|||
"children" |
|||
] |
|||
update = timers_json["children"]["TrainerController.start_learning"][ |
|||
"children" |
|||
]["TrainerController.advance"]["children"]["trainer_advance"]["children"][ |
|||
"_update_policy" |
|||
][ |
|||
"children" |
|||
] |
|||
tc_advance_total = tc_advance["total"] |
|||
tc_advance_count = tc_advance["count"] |
|||
if use_torch: |
|||
if algo == "ppo": |
|||
update_total = update["TorchPPOOptimizer.update"]["total"] |
|||
update_count = update["TorchPPOOptimizer.update"]["count"] |
|||
else: |
|||
update_total = update["SACTrainer._update_policy"]["total"] |
|||
update_count = update["SACTrainer._update_policy"]["count"] |
|||
evaluate_total = evaluate["TorchPolicy.evaluate"]["total"] |
|||
evaluate_count = evaluate["TorchPolicy.evaluate"]["count"] |
|||
else: |
|||
if algo == "ppo": |
|||
update_total = update["PPOOptimizer.update"]["total"] |
|||
update_count = update["PPOOptimizer.update"]["count"] |
|||
else: |
|||
update_total = update["SACTrainer._update_policy"]["total"] |
|||
update_count = update["SACTrainer._update_policy"]["count"] |
|||
evaluate_total = evaluate["NNPolicy.evaluate"]["total"] |
|||
evaluate_count = evaluate["NNPolicy.evaluate"]["count"] |
|||
# todo: do total / count |
|||
return ( |
|||
name, |
|||
str(steps), |
|||
str(use_torch), |
|||
algo, |
|||
str(num_torch_threads), |
|||
str(num_envs), |
|||
str(use_gpu), |
|||
str(total), |
|||
str(tc_advance_total), |
|||
str(tc_advance_count), |
|||
str(update_total), |
|||
str(update_count), |
|||
str(evaluate_total), |
|||
str(evaluate_count), |
|||
) |
|||
|
|||
|
|||
def main(): |
|||
parser = argparse.ArgumentParser() |
|||
parser.add_argument("--steps", default=25000, type=int, help="The number of steps") |
|||
parser.add_argument("--num-envs", default=1, type=int, help="The number of envs") |
|||
parser.add_argument( |
|||
"--gpu", default=False, action="store_true", help="If true, will use the GPU" |
|||
) |
|||
parser.add_argument( |
|||
"--threads", |
|||
default=False, |
|||
action="store_true", |
|||
help="If true, will try both 1 and 8 threads for torch", |
|||
) |
|||
parser.add_argument( |
|||
"--ball", |
|||
default=False, |
|||
action="store_true", |
|||
help="If true, will only do 3dball", |
|||
) |
|||
parser.add_argument( |
|||
"--sac", |
|||
default=False, |
|||
action="store_true", |
|||
help="If true, will run sac instead of ppo", |
|||
) |
|||
args = parser.parse_args() |
|||
|
|||
if args.gpu: |
|||
os.environ["CUDA_VISIBLE_DEVICES"] = "0" |
|||
else: |
|||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" |
|||
|
|||
algo = "ppo" |
|||
if args.sac: |
|||
algo = "sac" |
|||
|
|||
envs_config_tuples = [ |
|||
("3DBall", "3DBall"), |
|||
("GridWorld", "GridWorld"), |
|||
("PushBlock", "PushBlock"), |
|||
("CrawlerStaticTarget", "CrawlerStatic"), |
|||
] |
|||
if algo == "ppo": |
|||
envs_config_tuples += [ |
|||
("Hallway", "Hallway"), |
|||
("VisualHallway", "VisualHallway"), |
|||
] |
|||
if args.ball: |
|||
envs_config_tuples = [("3DBall", "3DBall")] |
|||
|
|||
labels = ( |
|||
"name", |
|||
"steps", |
|||
"use_torch", |
|||
"algorithm", |
|||
"num_torch_threads", |
|||
"num_envs", |
|||
"use_gpu", |
|||
"total", |
|||
"tc_advance_total", |
|||
"tc_advance_count", |
|||
"update_total", |
|||
"update_count", |
|||
"evaluate_total", |
|||
"evaluate_count", |
|||
) |
|||
|
|||
results = [] |
|||
results.append(labels) |
|||
f = open( |
|||
f"result_data_steps_{args.steps}_algo_{algo}_envs_{args.num_envs}_gpu_{args.gpu}_thread_{args.threads}.txt", |
|||
"w", |
|||
) |
|||
f.write(" ".join(labels) + "\n") |
|||
|
|||
for env_config in envs_config_tuples: |
|||
data = run_experiment( |
|||
name=env_config[0], |
|||
steps=args.steps, |
|||
use_torch=True, |
|||
algo=algo, |
|||
num_torch_threads=1, |
|||
use_gpu=args.gpu, |
|||
num_envs=args.num_envs, |
|||
config_name=env_config[1], |
|||
) |
|||
results.append(data) |
|||
f.write(" ".join(data) + "\n") |
|||
|
|||
if args.threads: |
|||
data = run_experiment( |
|||
name=env_config[0], |
|||
steps=args.steps, |
|||
use_torch=True, |
|||
algo=algo, |
|||
num_torch_threads=8, |
|||
use_gpu=args.gpu, |
|||
num_envs=args.num_envs, |
|||
config_name=env_config[1], |
|||
) |
|||
results.append(data) |
|||
f.write(" ".join(data) + "\n") |
|||
|
|||
data = run_experiment( |
|||
name=env_config[0], |
|||
steps=args.steps, |
|||
use_torch=False, |
|||
algo=algo, |
|||
num_torch_threads=1, |
|||
use_gpu=args.gpu, |
|||
num_envs=args.num_envs, |
|||
config_name=env_config[1], |
|||
) |
|||
results.append(data) |
|||
f.write(" ".join(data) + "\n") |
|||
for r in results: |
|||
print(*r) |
|||
f.close() |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
main() |
|
|||
-warnaserror+ |
|||
-warnaserror-:618 |
|
|||
from typing import Dict, Optional, Tuple, List |
|||
import torch |
|||
import numpy as np |
|||
from mlagents_envs.base_env import DecisionSteps |
|||
|
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents.trainers.components.bc.module import BCModule |
|||
from mlagents.trainers.torch.components.reward_providers import create_reward_provider |
|||
|
|||
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|||
from mlagents.trainers.optimizer import Optimizer |
|||
from mlagents.trainers.settings import TrainerSettings |
|||
from mlagents.trainers.trajectory import SplitObservations |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
|
|||
|
|||
class TorchOptimizer(Optimizer): # pylint: disable=W0223 |
|||
def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): |
|||
super().__init__() |
|||
self.policy = policy |
|||
self.trainer_settings = trainer_settings |
|||
self.update_dict: Dict[str, torch.Tensor] = {} |
|||
self.value_heads: Dict[str, torch.Tensor] = {} |
|||
self.memory_in: torch.Tensor = None |
|||
self.memory_out: torch.Tensor = None |
|||
self.m_size: int = 0 |
|||
self.global_step = torch.tensor(0) |
|||
self.bc_module: Optional[BCModule] = None |
|||
self.create_reward_signals(trainer_settings.reward_signals) |
|||
|
|||
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: |
|||
pass |
|||
|
|||
def create_reward_signals(self, reward_signal_configs): |
|||
""" |
|||
Create reward signals |
|||
:param reward_signal_configs: Reward signal config. |
|||
""" |
|||
for reward_signal, settings in reward_signal_configs.items(): |
|||
# Name reward signals by string in case we have duplicates later |
|||
self.reward_signals[reward_signal.value] = create_reward_provider( |
|||
reward_signal, self.policy.behavior_spec, settings |
|||
) |
|||
|
|||
def get_value_estimates( |
|||
self, decision_requests: DecisionSteps, idx: int, done: bool |
|||
) -> Dict[str, float]: |
|||
""" |
|||
Generates value estimates for bootstrapping. |
|||
:param decision_requests: |
|||
:param idx: Index in BrainInfo of agent. |
|||
:param done: Whether or not this is the last element of the episode, |
|||
in which case the value estimate will be 0. |
|||
:return: The value estimate dictionary with key being the name of the reward signal |
|||
and the value the corresponding value estimate. |
|||
""" |
|||
vec_vis_obs = SplitObservations.from_observations(decision_requests.obs) |
|||
|
|||
value_estimates = self.policy.actor_critic.critic_pass( |
|||
np.expand_dims(vec_vis_obs.vector_observations[idx], 0), |
|||
np.expand_dims(vec_vis_obs.visual_observations[idx], 0), |
|||
) |
|||
|
|||
value_estimates = {k: float(v) for k, v in value_estimates.items()} |
|||
|
|||
# If we're done, reassign all of the value estimates that need terminal states. |
|||
if done: |
|||
for k in value_estimates: |
|||
if not self.reward_signals[k].ignore_done: |
|||
value_estimates[k] = 0.0 |
|||
|
|||
return value_estimates |
|||
|
|||
def get_trajectory_value_estimates( |
|||
self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool |
|||
) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]: |
|||
vector_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])] |
|||
if self.policy.use_vis_obs: |
|||
visual_obs = [] |
|||
for idx, _ in enumerate( |
|||
self.policy.actor_critic.network_body.visual_encoders |
|||
): |
|||
visual_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx]) |
|||
visual_obs.append(visual_ob) |
|||
else: |
|||
visual_obs = [] |
|||
|
|||
memory = torch.zeros([1, len(vector_obs[0]), self.policy.m_size]) |
|||
|
|||
next_obs = np.concatenate(next_obs, axis=-1) |
|||
next_obs = [ModelUtils.list_to_tensor(next_obs).unsqueeze(0)] |
|||
next_memory = torch.zeros([1, 1, self.policy.m_size]) |
|||
|
|||
value_estimates = self.policy.actor_critic.critic_pass( |
|||
vector_obs, visual_obs, memory |
|||
) |
|||
|
|||
next_value_estimate = self.policy.actor_critic.critic_pass( |
|||
next_obs, next_obs, next_memory |
|||
) |
|||
|
|||
for name, estimate in value_estimates.items(): |
|||
value_estimates[name] = estimate.detach().cpu().numpy() |
|||
next_value_estimate[name] = next_value_estimate[name].detach().cpu().numpy() |
|||
|
|||
if done: |
|||
for k in next_value_estimate: |
|||
if not self.reward_signals[k].ignore_done: |
|||
next_value_estimate[k] = 0.0 |
|||
|
|||
return value_estimates, next_value_estimate |
|
|||
from typing import Any, Dict, List |
|||
import numpy as np |
|||
import torch |
|||
|
|||
import os |
|||
from torch import onnx |
|||
from mlagents.model_serialization import SerializationSettings |
|||
from mlagents.trainers.action_info import ActionInfo |
|||
from mlagents.trainers.behavior_id_utils import get_global_agent_id |
|||
|
|||
from mlagents.trainers.policy import Policy |
|||
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec |
|||
from mlagents_envs.timers import timed |
|||
|
|||
from mlagents.trainers.settings import TrainerSettings, TestingConfiguration |
|||
from mlagents.trainers.trajectory import SplitObservations |
|||
from mlagents.trainers.torch.networks import SharedActorCritic, SeparateActorCritic |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
|
|||
EPSILON = 1e-7 # Small value to avoid divide by zero |
|||
|
|||
|
|||
class TorchPolicy(Policy): |
|||
def __init__( |
|||
self, |
|||
seed: int, |
|||
behavior_spec: BehaviorSpec, |
|||
trainer_settings: TrainerSettings, |
|||
model_path: str, |
|||
load: bool = False, |
|||
tanh_squash: bool = False, |
|||
reparameterize: bool = False, |
|||
separate_critic: bool = True, |
|||
condition_sigma_on_obs: bool = True, |
|||
): |
|||
""" |
|||
Policy that uses a multilayer perceptron to map the observations to actions. Could |
|||
also use a CNN to encode visual input prior to the MLP. Supports discrete and |
|||
continuous action spaces, as well as recurrent networks. |
|||
:param seed: Random seed. |
|||
:param brain: Assigned BrainParameters object. |
|||
:param trainer_settings: Defined training parameters. |
|||
:param load: Whether a pre-trained model will be loaded or a new one created. |
|||
:param tanh_squash: Whether to use a tanh function on the continuous output, |
|||
or a clipped output. |
|||
:param reparameterize: Whether we are using the resampling trick to update the policy |
|||
in continuous output. |
|||
""" |
|||
super().__init__( |
|||
seed, |
|||
behavior_spec, |
|||
trainer_settings, |
|||
model_path, |
|||
load, |
|||
tanh_squash, |
|||
reparameterize, |
|||
condition_sigma_on_obs, |
|||
) |
|||
self.global_step = 0 |
|||
self.grads = None |
|||
if TestingConfiguration.device != "cpu": |
|||
torch.set_default_tensor_type(torch.cuda.FloatTensor) |
|||
else: |
|||
torch.set_default_tensor_type(torch.FloatTensor) |
|||
|
|||
reward_signal_configs = trainer_settings.reward_signals |
|||
reward_signal_names = [key.value for key, _ in reward_signal_configs.items()] |
|||
|
|||
self.stats_name_to_update_name = { |
|||
"Losses/Value Loss": "value_loss", |
|||
"Losses/Policy Loss": "policy_loss", |
|||
} |
|||
if separate_critic: |
|||
ac_class = SeparateActorCritic |
|||
else: |
|||
ac_class = SharedActorCritic |
|||
self.actor_critic = ac_class( |
|||
observation_shapes=self.behavior_spec.observation_shapes, |
|||
network_settings=trainer_settings.network_settings, |
|||
act_type=behavior_spec.action_type, |
|||
act_size=self.act_size, |
|||
stream_names=reward_signal_names, |
|||
conditional_sigma=self.condition_sigma_on_obs, |
|||
tanh_squash=tanh_squash, |
|||
) |
|||
|
|||
self.actor_critic.to(TestingConfiguration.device) |
|||
|
|||
def split_decision_step(self, decision_requests): |
|||
vec_vis_obs = SplitObservations.from_observations(decision_requests.obs) |
|||
mask = None |
|||
if not self.use_continuous_act: |
|||
mask = torch.ones([len(decision_requests), np.sum(self.act_size)]) |
|||
if decision_requests.action_mask is not None: |
|||
mask = torch.as_tensor( |
|||
1 - np.concatenate(decision_requests.action_mask, axis=1) |
|||
) |
|||
return vec_vis_obs.vector_observations, vec_vis_obs.visual_observations, mask |
|||
|
|||
def update_normalization(self, vector_obs: np.ndarray) -> None: |
|||
""" |
|||
If this policy normalizes vector observations, this will update the norm values in the graph. |
|||
:param vector_obs: The vector observations to add to the running estimate of the distribution. |
|||
""" |
|||
vector_obs = [torch.as_tensor(vector_obs)] |
|||
if self.use_vec_obs and self.normalize: |
|||
self.actor_critic.update_normalization(vector_obs) |
|||
|
|||
@timed |
|||
def sample_actions( |
|||
self, |
|||
vec_obs, |
|||
vis_obs, |
|||
masks=None, |
|||
memories=None, |
|||
seq_len=1, |
|||
all_log_probs=False, |
|||
): |
|||
""" |
|||
:param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action. |
|||
""" |
|||
dists, value_heads, memories = self.actor_critic.get_dist_and_value( |
|||
vec_obs, vis_obs, masks, memories, seq_len |
|||
) |
|||
action_list = self.actor_critic.sample_action(dists) |
|||
log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy( |
|||
action_list, dists |
|||
) |
|||
actions = torch.stack(action_list, dim=-1) |
|||
if self.use_continuous_act: |
|||
actions = actions[:, :, 0] |
|||
else: |
|||
actions = actions[:, 0, :] |
|||
|
|||
return ( |
|||
actions, |
|||
all_logs if all_log_probs else log_probs, |
|||
entropies, |
|||
value_heads, |
|||
memories, |
|||
) |
|||
|
|||
def evaluate_actions( |
|||
self, vec_obs, vis_obs, actions, masks=None, memories=None, seq_len=1 |
|||
): |
|||
dists, value_heads, _ = self.actor_critic.get_dist_and_value( |
|||
vec_obs, vis_obs, masks, memories, seq_len |
|||
) |
|||
if len(actions.shape) <= 2: |
|||
actions = actions.unsqueeze(-1) |
|||
action_list = [actions[..., i] for i in range(actions.shape[2])] |
|||
log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_list, dists) |
|||
|
|||
return log_probs, entropies, value_heads |
|||
|
|||
@timed |
|||
def evaluate( |
|||
self, decision_requests: DecisionSteps, global_agent_ids: List[str] |
|||
) -> Dict[str, Any]: |
|||
""" |
|||
Evaluates policy for the agent experiences provided. |
|||
:param global_agent_ids: |
|||
:param decision_requests: DecisionStep object containing inputs. |
|||
:return: Outputs from network as defined by self.inference_dict. |
|||
""" |
|||
vec_obs, vis_obs, masks = self.split_decision_step(decision_requests) |
|||
vec_obs = [torch.as_tensor(vec_obs)] |
|||
vis_obs = [torch.as_tensor(vis_ob) for vis_ob in vis_obs] |
|||
memories = torch.as_tensor(self.retrieve_memories(global_agent_ids)).unsqueeze( |
|||
0 |
|||
) |
|||
|
|||
run_out = {} |
|||
with torch.no_grad(): |
|||
action, log_probs, entropy, value_heads, memories = self.sample_actions( |
|||
vec_obs, vis_obs, masks=masks, memories=memories |
|||
) |
|||
run_out["action"] = action.detach().cpu().numpy() |
|||
run_out["pre_action"] = action.detach().cpu().numpy() |
|||
# Todo - make pre_action difference |
|||
run_out["log_probs"] = log_probs.detach().cpu().numpy() |
|||
run_out["entropy"] = entropy.detach().cpu().numpy() |
|||
run_out["value_heads"] = { |
|||
name: t.detach().cpu().numpy() for name, t in value_heads.items() |
|||
} |
|||
run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0) |
|||
run_out["learning_rate"] = 0.0 |
|||
if self.use_recurrent: |
|||
run_out["memories"] = memories.detach().cpu().numpy() |
|||
return run_out |
|||
|
|||
def get_action( |
|||
self, decision_requests: DecisionSteps, worker_id: int = 0 |
|||
) -> ActionInfo: |
|||
""" |
|||
Decides actions given observations information, and takes them in environment. |
|||
:param worker_id: |
|||
:param decision_requests: A dictionary of brain names and BrainInfo from environment. |
|||
:return: an ActionInfo containing action, memories, values and an object |
|||
to be passed to add experiences |
|||
""" |
|||
if len(decision_requests) == 0: |
|||
return ActionInfo.empty() |
|||
|
|||
global_agent_ids = [ |
|||
get_global_agent_id(worker_id, int(agent_id)) |
|||
for agent_id in decision_requests.agent_id |
|||
] # For 1-D array, the iterator order is correct. |
|||
|
|||
run_out = self.evaluate( |
|||
decision_requests, global_agent_ids |
|||
) # pylint: disable=assignment-from-no-return |
|||
self.save_memories(global_agent_ids, run_out.get("memory_out")) |
|||
return ActionInfo( |
|||
action=run_out.get("action"), |
|||
value=run_out.get("value"), |
|||
outputs=run_out, |
|||
agent_ids=list(decision_requests.agent_id), |
|||
) |
|||
|
|||
def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None: |
|||
""" |
|||
Checkpoints the policy on disk. |
|||
|
|||
:param checkpoint_path: filepath to write the checkpoint |
|||
:param settings: SerializationSettings for exporting the model. |
|||
""" |
|||
if not os.path.exists(self.model_path): |
|||
os.makedirs(self.model_path) |
|||
torch.save(self.actor_critic.state_dict(), f"{checkpoint_path}.pt") |
|||
|
|||
def save(self, output_filepath: str, settings: SerializationSettings) -> None: |
|||
self.export_model(self.global_step) |
|||
|
|||
def load_model(self, step=0): # TODO: this doesn't work |
|||
load_path = self.model_path + "/model-" + str(step) + ".pt" |
|||
self.actor_critic.load_state_dict(torch.load(load_path)) |
|||
|
|||
def export_model(self, step=0): |
|||
fake_vec_obs = [torch.zeros([1] + [self.vec_obs_size])] |
|||
fake_vis_obs = [torch.zeros([1] + [84, 84, 3])] |
|||
fake_masks = torch.ones([1] + self.actor_critic.act_size) |
|||
# fake_memories = torch.zeros([1] + [self.m_size]) |
|||
export_path = "./model-" + str(step) + ".onnx" |
|||
output_names = ["action", "action_probs"] |
|||
input_names = ["vector_observation", "action_mask"] |
|||
dynamic_axes = {"vector_observation": [0], "action": [0], "action_probs": [0]} |
|||
onnx.export( |
|||
self.actor_critic, |
|||
(fake_vec_obs, fake_vis_obs, fake_masks), |
|||
export_path, |
|||
verbose=True, |
|||
opset_version=12, |
|||
input_names=input_names, |
|||
output_names=output_names, |
|||
dynamic_axes=dynamic_axes, |
|||
) |
|||
|
|||
@property |
|||
def use_vis_obs(self): |
|||
return self.vis_obs_size > 0 |
|||
|
|||
@property |
|||
def use_vec_obs(self): |
|||
return self.vec_obs_size > 0 |
|||
|
|||
def get_current_step(self): |
|||
""" |
|||
Gets current model step. |
|||
:return: current model step. |
|||
""" |
|||
step = self.global_step |
|||
return step |
|||
|
|||
def increment_step(self, n_steps): |
|||
""" |
|||
Increments model step. |
|||
""" |
|||
self.global_step += n_steps |
|||
return self.get_current_step() |
|||
|
|||
def load_weights(self, values: List[np.ndarray]) -> None: |
|||
pass |
|||
|
|||
def init_load_weights(self) -> None: |
|||
pass |
|||
|
|||
def get_weights(self) -> List[np.ndarray]: |
|||
return [] |
|
|||
from typing import Dict, cast |
|||
import torch |
|||
|
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
|
|||
from mlagents_envs.timers import timed |
|||
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|||
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer |
|||
from mlagents.trainers.settings import TrainerSettings, PPOSettings |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
|
|||
|
|||
class TorchPPOOptimizer(TorchOptimizer): |
|||
def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): |
|||
""" |
|||
Takes a Policy and a Dict of trainer parameters and creates an Optimizer around the policy. |
|||
The PPO optimizer has a value estimator and a loss function. |
|||
:param policy: A TFPolicy object that will be updated by this PPO Optimizer. |
|||
:param trainer_params: Trainer parameters dictionary that specifies the |
|||
properties of the trainer. |
|||
""" |
|||
# Create the graph here to give more granular control of the TF graph to the Optimizer. |
|||
|
|||
super().__init__(policy, trainer_settings) |
|||
params = list(self.policy.actor_critic.parameters()) |
|||
self.hyperparameters: PPOSettings = cast( |
|||
PPOSettings, trainer_settings.hyperparameters |
|||
) |
|||
self.decay_learning_rate = ModelUtils.DecayedValue( |
|||
self.hyperparameters.learning_rate_schedule, |
|||
self.hyperparameters.learning_rate, |
|||
1e-10, |
|||
self.trainer_settings.max_steps, |
|||
) |
|||
self.decay_epsilon = ModelUtils.DecayedValue( |
|||
self.hyperparameters.learning_rate_schedule, |
|||
self.hyperparameters.epsilon, |
|||
0.1, |
|||
self.trainer_settings.max_steps, |
|||
) |
|||
self.decay_beta = ModelUtils.DecayedValue( |
|||
self.hyperparameters.learning_rate_schedule, |
|||
self.hyperparameters.beta, |
|||
1e-5, |
|||
self.trainer_settings.max_steps, |
|||
) |
|||
|
|||
self.optimizer = torch.optim.Adam( |
|||
params, lr=self.trainer_settings.hyperparameters.learning_rate |
|||
) |
|||
self.stats_name_to_update_name = { |
|||
"Losses/Value Loss": "value_loss", |
|||
"Losses/Policy Loss": "policy_loss", |
|||
} |
|||
|
|||
self.stream_names = list(self.reward_signals.keys()) |
|||
|
|||
def ppo_value_loss( |
|||
self, |
|||
values: Dict[str, torch.Tensor], |
|||
old_values: Dict[str, torch.Tensor], |
|||
returns: Dict[str, torch.Tensor], |
|||
epsilon: float, |
|||
) -> torch.Tensor: |
|||
""" |
|||
Creates training-specific Tensorflow ops for PPO models. |
|||
:param returns: |
|||
:param old_values: |
|||
:param values: |
|||
""" |
|||
value_losses = [] |
|||
for name, head in values.items(): |
|||
old_val_tensor = old_values[name] |
|||
returns_tensor = returns[name] |
|||
clipped_value_estimate = old_val_tensor + torch.clamp( |
|||
head - old_val_tensor, -1 * epsilon, epsilon |
|||
) |
|||
v_opt_a = (returns_tensor - head) ** 2 |
|||
v_opt_b = (returns_tensor - clipped_value_estimate) ** 2 |
|||
value_loss = torch.mean(torch.max(v_opt_a, v_opt_b)) |
|||
value_losses.append(value_loss) |
|||
value_loss = torch.mean(torch.stack(value_losses)) |
|||
return value_loss |
|||
|
|||
def ppo_policy_loss(self, advantages, log_probs, old_log_probs, masks): |
|||
""" |
|||
Creates training-specific Tensorflow ops for PPO models. |
|||
:param masks: |
|||
:param advantages: |
|||
:param log_probs: Current policy probabilities |
|||
:param old_log_probs: Past policy probabilities |
|||
""" |
|||
advantage = advantages.unsqueeze(-1) |
|||
|
|||
decay_epsilon = self.hyperparameters.epsilon |
|||
|
|||
r_theta = torch.exp(log_probs - old_log_probs) |
|||
p_opt_a = r_theta * advantage |
|||
p_opt_b = ( |
|||
torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage |
|||
) |
|||
policy_loss = -torch.mean(torch.min(p_opt_a, p_opt_b)) |
|||
return policy_loss |
|||
|
|||
@timed |
|||
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: |
|||
""" |
|||
Performs update on model. |
|||
:param batch: Batch of experiences. |
|||
:param num_sequences: Number of sequences to process. |
|||
:return: Results of update. |
|||
""" |
|||
# Get decayed parameters |
|||
decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step()) |
|||
decay_eps = self.decay_epsilon.get_value(self.policy.get_current_step()) |
|||
decay_bet = self.decay_beta.get_value(self.policy.get_current_step()) |
|||
returns = {} |
|||
old_values = {} |
|||
for name in self.reward_signals: |
|||
old_values[name] = ModelUtils.list_to_tensor( |
|||
batch[f"{name}_value_estimates"] |
|||
) |
|||
returns[name] = ModelUtils.list_to_tensor(batch[f"{name}_returns"]) |
|||
|
|||
vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])] |
|||
act_masks = ModelUtils.list_to_tensor(batch["action_mask"]) |
|||
if self.policy.use_continuous_act: |
|||
actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1) |
|||
else: |
|||
actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long) |
|||
|
|||
memories = [ |
|||
ModelUtils.list_to_tensor(batch["memory"][i]) |
|||
for i in range(0, len(batch["memory"]), self.policy.sequence_length) |
|||
] |
|||
if len(memories) > 0: |
|||
memories = torch.stack(memories).unsqueeze(0) |
|||
|
|||
if self.policy.use_vis_obs: |
|||
vis_obs = [] |
|||
for idx, _ in enumerate( |
|||
self.policy.actor_critic.network_body.visual_encoders |
|||
): |
|||
vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx]) |
|||
vis_obs.append(vis_ob) |
|||
else: |
|||
vis_obs = [] |
|||
log_probs, entropy, values = self.policy.evaluate_actions( |
|||
vec_obs, |
|||
vis_obs, |
|||
masks=act_masks, |
|||
actions=actions, |
|||
memories=memories, |
|||
seq_len=self.policy.sequence_length, |
|||
) |
|||
value_loss = self.ppo_value_loss(values, old_values, returns, decay_eps) |
|||
policy_loss = self.ppo_policy_loss( |
|||
ModelUtils.list_to_tensor(batch["advantages"]), |
|||
log_probs, |
|||
ModelUtils.list_to_tensor(batch["action_probs"]), |
|||
ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32), |
|||
) |
|||
loss = policy_loss + 0.5 * value_loss - decay_bet * torch.mean(entropy) |
|||
|
|||
# Set optimizer learning rate |
|||
ModelUtils.update_learning_rate(self.optimizer, decay_lr) |
|||
self.optimizer.zero_grad() |
|||
loss.backward() |
|||
|
|||
self.optimizer.step() |
|||
update_stats = { |
|||
"Losses/Policy Loss": abs(policy_loss.detach().cpu().numpy()), |
|||
"Losses/Value Loss": value_loss.detach().cpu().numpy(), |
|||
"Policy/Learning Rate": decay_lr, |
|||
"Policy/Epsilon": decay_eps, |
|||
"Policy/Beta": decay_bet, |
|||
} |
|||
|
|||
for reward_provider in self.reward_signals.values(): |
|||
update_stats.update(reward_provider.update(batch)) |
|||
|
|||
return update_stats |
|
|||
import numpy as np |
|||
from typing import Dict, List, Mapping, cast, Tuple |
|||
import torch |
|||
from torch import nn |
|||
|
|||
from mlagents_envs.logging_util import get_logger |
|||
from mlagents_envs.base_env import ActionType |
|||
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer |
|||
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|||
from mlagents.trainers.settings import NetworkSettings |
|||
from mlagents.trainers.torch.networks import ValueNetwork |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents_envs.timers import timed |
|||
from mlagents.trainers.exception import UnityTrainerException |
|||
from mlagents.trainers.settings import TrainerSettings, SACSettings |
|||
|
|||
EPSILON = 1e-6 # Small value to avoid divide by zero |
|||
|
|||
logger = get_logger(__name__) |
|||
|
|||
|
|||
class TorchSACOptimizer(TorchOptimizer): |
|||
class PolicyValueNetwork(nn.Module): |
|||
def __init__( |
|||
self, |
|||
stream_names: List[str], |
|||
observation_shapes: List[Tuple[int, ...]], |
|||
network_settings: NetworkSettings, |
|||
act_type: ActionType, |
|||
act_size: List[int], |
|||
): |
|||
super().__init__() |
|||
if act_type == ActionType.CONTINUOUS: |
|||
num_value_outs = 1 |
|||
num_action_ins = sum(act_size) |
|||
else: |
|||
num_value_outs = sum(act_size) |
|||
num_action_ins = 0 |
|||
self.q1_network = ValueNetwork( |
|||
stream_names, |
|||
observation_shapes, |
|||
network_settings, |
|||
num_action_ins, |
|||
num_value_outs, |
|||
) |
|||
self.q2_network = ValueNetwork( |
|||
stream_names, |
|||
observation_shapes, |
|||
network_settings, |
|||
num_action_ins, |
|||
num_value_outs, |
|||
) |
|||
|
|||
def forward( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
actions: torch.Tensor = None, |
|||
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: |
|||
q1_out, _ = self.q1_network(vec_inputs, vis_inputs, actions=actions) |
|||
q2_out, _ = self.q2_network(vec_inputs, vis_inputs, actions=actions) |
|||
return q1_out, q2_out |
|||
|
|||
def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): |
|||
super().__init__(policy, trainer_params) |
|||
hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters) |
|||
self.tau = hyperparameters.tau |
|||
self.init_entcoef = hyperparameters.init_entcoef |
|||
|
|||
self.policy = policy |
|||
self.act_size = policy.act_size |
|||
policy_network_settings = policy.network_settings |
|||
|
|||
self.tau = hyperparameters.tau |
|||
self.burn_in_ratio = 0.0 |
|||
|
|||
# Non-exposed SAC parameters |
|||
self.discrete_target_entropy_scale = 0.2 # Roughly equal to e-greedy 0.05 |
|||
self.continuous_target_entropy_scale = 1.0 |
|||
|
|||
self.stream_names = list(self.reward_signals.keys()) |
|||
# Use to reduce "survivor bonus" when using Curiosity or GAIL. |
|||
self.gammas = [_val.gamma for _val in trainer_params.reward_signals.values()] |
|||
self.use_dones_in_backup = { |
|||
name: int(not self.reward_signals[name].ignore_done) |
|||
for name in self.stream_names |
|||
} |
|||
|
|||
self.value_network = TorchSACOptimizer.PolicyValueNetwork( |
|||
self.stream_names, |
|||
self.policy.behavior_spec.observation_shapes, |
|||
policy_network_settings, |
|||
self.policy.behavior_spec.action_type, |
|||
self.act_size, |
|||
) |
|||
self.target_network = ValueNetwork( |
|||
self.stream_names, |
|||
self.policy.behavior_spec.observation_shapes, |
|||
policy_network_settings, |
|||
) |
|||
self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0) |
|||
|
|||
self._log_ent_coef = torch.nn.Parameter( |
|||
torch.log(torch.as_tensor([self.init_entcoef] * len(self.act_size))), |
|||
requires_grad=True, |
|||
) |
|||
if self.policy.use_continuous_act: |
|||
self.target_entropy = torch.as_tensor( |
|||
-1 |
|||
* self.continuous_target_entropy_scale |
|||
* np.prod(self.act_size[0]).astype(np.float32) |
|||
) |
|||
else: |
|||
self.target_entropy = [ |
|||
self.discrete_target_entropy_scale * np.log(i).astype(np.float32) |
|||
for i in self.act_size |
|||
] |
|||
|
|||
policy_params = list(self.policy.actor_critic.network_body.parameters()) + list( |
|||
self.policy.actor_critic.distribution.parameters() |
|||
) |
|||
value_params = list(self.value_network.parameters()) + list( |
|||
self.policy.actor_critic.critic.parameters() |
|||
) |
|||
|
|||
logger.debug("value_vars") |
|||
for param in value_params: |
|||
logger.debug(param.shape) |
|||
logger.debug("policy_vars") |
|||
for param in policy_params: |
|||
logger.debug(param.shape) |
|||
|
|||
self.decay_learning_rate = ModelUtils.DecayedValue( |
|||
hyperparameters.learning_rate_schedule, |
|||
hyperparameters.learning_rate, |
|||
1e-10, |
|||
self.trainer_settings.max_steps, |
|||
) |
|||
self.policy_optimizer = torch.optim.Adam( |
|||
policy_params, lr=hyperparameters.learning_rate |
|||
) |
|||
self.value_optimizer = torch.optim.Adam( |
|||
value_params, lr=hyperparameters.learning_rate |
|||
) |
|||
self.entropy_optimizer = torch.optim.Adam( |
|||
[self._log_ent_coef], lr=hyperparameters.learning_rate |
|||
) |
|||
|
|||
def sac_q_loss( |
|||
self, |
|||
q1_out: Dict[str, torch.Tensor], |
|||
q2_out: Dict[str, torch.Tensor], |
|||
target_values: Dict[str, torch.Tensor], |
|||
dones: torch.Tensor, |
|||
rewards: Dict[str, torch.Tensor], |
|||
loss_masks: torch.Tensor, |
|||
) -> Tuple[torch.Tensor, torch.Tensor]: |
|||
q1_losses = [] |
|||
q2_losses = [] |
|||
# Multiple q losses per stream |
|||
for i, name in enumerate(q1_out.keys()): |
|||
q1_stream = q1_out[name].squeeze() |
|||
q2_stream = q2_out[name].squeeze() |
|||
with torch.no_grad(): |
|||
q_backup = rewards[name] + ( |
|||
(1.0 - self.use_dones_in_backup[name] * dones) |
|||
* self.gammas[i] |
|||
* target_values[name] |
|||
) |
|||
_q1_loss = 0.5 * torch.mean( |
|||
loss_masks * torch.nn.functional.mse_loss(q_backup, q1_stream) |
|||
) |
|||
_q2_loss = 0.5 * torch.mean( |
|||
loss_masks * torch.nn.functional.mse_loss(q_backup, q2_stream) |
|||
) |
|||
|
|||
q1_losses.append(_q1_loss) |
|||
q2_losses.append(_q2_loss) |
|||
q1_loss = torch.mean(torch.stack(q1_losses)) |
|||
q2_loss = torch.mean(torch.stack(q2_losses)) |
|||
return q1_loss, q2_loss |
|||
|
|||
def soft_update(self, source: nn.Module, target: nn.Module, tau: float) -> None: |
|||
for source_param, target_param in zip(source.parameters(), target.parameters()): |
|||
target_param.data.copy_( |
|||
target_param.data * (1.0 - tau) + source_param.data * tau |
|||
) |
|||
|
|||
def sac_value_loss( |
|||
self, |
|||
log_probs: torch.Tensor, |
|||
values: Dict[str, torch.Tensor], |
|||
q1p_out: Dict[str, torch.Tensor], |
|||
q2p_out: Dict[str, torch.Tensor], |
|||
loss_masks: torch.Tensor, |
|||
discrete: bool, |
|||
) -> torch.Tensor: |
|||
min_policy_qs = {} |
|||
with torch.no_grad(): |
|||
_ent_coef = torch.exp(self._log_ent_coef) |
|||
for name in values.keys(): |
|||
if not discrete: |
|||
min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name]) |
|||
else: |
|||
action_probs = log_probs.exp() |
|||
_branched_q1p = ModelUtils.break_into_branches( |
|||
q1p_out[name] * action_probs, self.act_size |
|||
) |
|||
_branched_q2p = ModelUtils.break_into_branches( |
|||
q2p_out[name] * action_probs, self.act_size |
|||
) |
|||
_q1p_mean = torch.mean( |
|||
torch.stack( |
|||
[torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q1p] |
|||
), |
|||
dim=0, |
|||
) |
|||
_q2p_mean = torch.mean( |
|||
torch.stack( |
|||
[torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q2p] |
|||
), |
|||
dim=0, |
|||
) |
|||
|
|||
min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean) |
|||
|
|||
value_losses = [] |
|||
if not discrete: |
|||
for name in values.keys(): |
|||
with torch.no_grad(): |
|||
v_backup = min_policy_qs[name] - torch.sum( |
|||
_ent_coef * log_probs, dim=1 |
|||
) |
|||
# print(log_probs, v_backup, _ent_coef, loss_masks) |
|||
value_loss = 0.5 * torch.mean( |
|||
loss_masks * torch.nn.functional.mse_loss(values[name], v_backup) |
|||
) |
|||
value_losses.append(value_loss) |
|||
else: |
|||
branched_per_action_ent = ModelUtils.break_into_branches( |
|||
log_probs * log_probs.exp(), self.act_size |
|||
) |
|||
# We have to do entropy bonus per action branch |
|||
branched_ent_bonus = torch.stack( |
|||
[ |
|||
torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True) |
|||
for i, _lp in enumerate(branched_per_action_ent) |
|||
] |
|||
) |
|||
for name in values.keys(): |
|||
with torch.no_grad(): |
|||
v_backup = min_policy_qs[name] - torch.mean( |
|||
branched_ent_bonus, axis=0 |
|||
) |
|||
value_loss = 0.5 * torch.mean( |
|||
loss_masks |
|||
* torch.nn.functional.mse_loss(values[name], v_backup.squeeze()) |
|||
) |
|||
value_losses.append(value_loss) |
|||
value_loss = torch.mean(torch.stack(value_losses)) |
|||
if torch.isinf(value_loss).any() or torch.isnan(value_loss).any(): |
|||
raise UnityTrainerException("Inf found") |
|||
return value_loss |
|||
|
|||
def sac_policy_loss( |
|||
self, |
|||
log_probs: torch.Tensor, |
|||
q1p_outs: Dict[str, torch.Tensor], |
|||
loss_masks: torch.Tensor, |
|||
discrete: bool, |
|||
) -> torch.Tensor: |
|||
_ent_coef = torch.exp(self._log_ent_coef) |
|||
mean_q1 = torch.mean(torch.stack(list(q1p_outs.values())), axis=0) |
|||
if not discrete: |
|||
mean_q1 = mean_q1.unsqueeze(1) |
|||
batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1) |
|||
policy_loss = torch.mean(loss_masks * batch_policy_loss) |
|||
else: |
|||
action_probs = log_probs.exp() |
|||
branched_per_action_ent = ModelUtils.break_into_branches( |
|||
log_probs * action_probs, self.act_size |
|||
) |
|||
branched_q_term = ModelUtils.break_into_branches( |
|||
mean_q1 * action_probs, self.act_size |
|||
) |
|||
branched_policy_loss = torch.stack( |
|||
[ |
|||
torch.sum(_ent_coef[i] * _lp - _qt, dim=1, keepdim=True) |
|||
for i, (_lp, _qt) in enumerate( |
|||
zip(branched_per_action_ent, branched_q_term) |
|||
) |
|||
] |
|||
) |
|||
batch_policy_loss = torch.squeeze(branched_policy_loss) |
|||
policy_loss = torch.mean(loss_masks * batch_policy_loss) |
|||
return policy_loss |
|||
|
|||
def sac_entropy_loss( |
|||
self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool |
|||
) -> torch.Tensor: |
|||
if not discrete: |
|||
with torch.no_grad(): |
|||
target_current_diff = torch.sum(log_probs + self.target_entropy, dim=1) |
|||
entropy_loss = -torch.mean( |
|||
self._log_ent_coef * loss_masks * target_current_diff |
|||
) |
|||
else: |
|||
with torch.no_grad(): |
|||
branched_per_action_ent = ModelUtils.break_into_branches( |
|||
log_probs * log_probs.exp(), self.act_size |
|||
) |
|||
target_current_diff_branched = torch.stack( |
|||
[ |
|||
torch.sum(_lp, axis=1, keepdim=True) + _te |
|||
for _lp, _te in zip( |
|||
branched_per_action_ent, self.target_entropy |
|||
) |
|||
], |
|||
axis=1, |
|||
) |
|||
target_current_diff = torch.squeeze( |
|||
target_current_diff_branched, axis=2 |
|||
) |
|||
entropy_loss = -torch.mean( |
|||
loss_masks |
|||
* torch.mean(self._log_ent_coef * target_current_diff, axis=1) |
|||
) |
|||
|
|||
return entropy_loss |
|||
|
|||
def _condense_q_streams( |
|||
self, q_output: Dict[str, torch.Tensor], discrete_actions: torch.Tensor |
|||
) -> Dict[str, torch.Tensor]: |
|||
condensed_q_output = {} |
|||
onehot_actions = ModelUtils.actions_to_onehot(discrete_actions, self.act_size) |
|||
for key, item in q_output.items(): |
|||
branched_q = ModelUtils.break_into_branches(item, self.act_size) |
|||
only_action_qs = torch.stack( |
|||
[ |
|||
torch.sum(_act * _q, dim=1, keepdim=True) |
|||
for _act, _q in zip(onehot_actions, branched_q) |
|||
] |
|||
) |
|||
|
|||
condensed_q_output[key] = torch.mean(only_action_qs, dim=0) |
|||
return condensed_q_output |
|||
|
|||
@timed |
|||
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: |
|||
""" |
|||
Updates model using buffer. |
|||
:param num_sequences: Number of trajectories in batch. |
|||
:param batch: Experience mini-batch. |
|||
:param update_target: Whether or not to update target value network |
|||
:param reward_signal_batches: Minibatches to use for updating the reward signals, |
|||
indexed by name. If none, don't update the reward signals. |
|||
:return: Output from update process. |
|||
""" |
|||
rewards = {} |
|||
for name in self.reward_signals: |
|||
rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"]) |
|||
|
|||
vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])] |
|||
next_vec_obs = [ModelUtils.list_to_tensor(batch["next_vector_in"])] |
|||
act_masks = ModelUtils.list_to_tensor(batch["action_mask"]) |
|||
if self.policy.use_continuous_act: |
|||
actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1) |
|||
else: |
|||
actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long) |
|||
|
|||
memories = [ |
|||
ModelUtils.list_to_tensor(batch["memory"][i]) |
|||
for i in range(0, len(batch["memory"]), self.policy.sequence_length) |
|||
] |
|||
if len(memories) > 0: |
|||
memories = torch.stack(memories).unsqueeze(0) |
|||
vis_obs: List[torch.Tensor] = [] |
|||
next_vis_obs: List[torch.Tensor] = [] |
|||
if self.policy.use_vis_obs: |
|||
vis_obs = [] |
|||
for idx, _ in enumerate( |
|||
self.policy.actor_critic.network_body.visual_encoders |
|||
): |
|||
vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx]) |
|||
vis_obs.append(vis_ob) |
|||
next_vis_ob = ModelUtils.list_to_tensor( |
|||
batch["next_visual_obs%d" % idx] |
|||
) |
|||
next_vis_obs.append(next_vis_ob) |
|||
|
|||
# Copy normalizers from policy |
|||
self.value_network.q1_network.network_body.copy_normalization( |
|||
self.policy.actor_critic.network_body |
|||
) |
|||
self.value_network.q2_network.network_body.copy_normalization( |
|||
self.policy.actor_critic.network_body |
|||
) |
|||
self.target_network.network_body.copy_normalization( |
|||
self.policy.actor_critic.network_body |
|||
) |
|||
( |
|||
sampled_actions, |
|||
log_probs, |
|||
entropies, |
|||
sampled_values, |
|||
_, |
|||
) = self.policy.sample_actions( |
|||
vec_obs, |
|||
vis_obs, |
|||
masks=act_masks, |
|||
memories=memories, |
|||
seq_len=self.policy.sequence_length, |
|||
all_log_probs=not self.policy.use_continuous_act, |
|||
) |
|||
if self.policy.use_continuous_act: |
|||
squeezed_actions = actions.squeeze(-1) |
|||
q1p_out, q2p_out = self.value_network(vec_obs, vis_obs, sampled_actions) |
|||
q1_out, q2_out = self.value_network(vec_obs, vis_obs, squeezed_actions) |
|||
q1_stream, q2_stream = q1_out, q2_out |
|||
else: |
|||
with torch.no_grad(): |
|||
q1p_out, q2p_out = self.value_network(vec_obs, vis_obs) |
|||
q1_out, q2_out = self.value_network(vec_obs, vis_obs) |
|||
q1_stream = self._condense_q_streams(q1_out, actions) |
|||
q2_stream = self._condense_q_streams(q2_out, actions) |
|||
|
|||
with torch.no_grad(): |
|||
target_values, _ = self.target_network(next_vec_obs, next_vis_obs) |
|||
masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32) |
|||
use_discrete = not self.policy.use_continuous_act |
|||
dones = ModelUtils.list_to_tensor(batch["done"]) |
|||
|
|||
q1_loss, q2_loss = self.sac_q_loss( |
|||
q1_stream, q2_stream, target_values, dones, rewards, masks |
|||
) |
|||
value_loss = self.sac_value_loss( |
|||
log_probs, sampled_values, q1p_out, q2p_out, masks, use_discrete |
|||
) |
|||
policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks, use_discrete) |
|||
entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete) |
|||
|
|||
total_value_loss = q1_loss + q2_loss + value_loss |
|||
|
|||
decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step()) |
|||
ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr) |
|||
self.policy_optimizer.zero_grad() |
|||
policy_loss.backward() |
|||
self.policy_optimizer.step() |
|||
|
|||
ModelUtils.update_learning_rate(self.value_optimizer, decay_lr) |
|||
self.value_optimizer.zero_grad() |
|||
total_value_loss.backward() |
|||
self.value_optimizer.step() |
|||
|
|||
ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr) |
|||
self.entropy_optimizer.zero_grad() |
|||
entropy_loss.backward() |
|||
self.entropy_optimizer.step() |
|||
|
|||
# Update target network |
|||
self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau) |
|||
update_stats = { |
|||
"Losses/Policy Loss": abs(policy_loss.detach().cpu().numpy()), |
|||
"Losses/Value Loss": value_loss.detach().cpu().numpy(), |
|||
"Losses/Q1 Loss": q1_loss.detach().cpu().numpy(), |
|||
"Losses/Q2 Loss": q2_loss.detach().cpu().numpy(), |
|||
"Policy/Entropy Coeff": torch.exp(self._log_ent_coef) |
|||
.detach() |
|||
.cpu() |
|||
.numpy(), |
|||
"Policy/Learning Rate": decay_lr, |
|||
} |
|||
|
|||
for signal in self.reward_signals.values(): |
|||
signal.update(batch) |
|||
|
|||
return update_stats |
|||
|
|||
def update_reward_signals( |
|||
self, reward_signal_minibatches: Mapping[str, AgentBuffer], num_sequences: int |
|||
) -> Dict[str, float]: |
|||
return {} |
|
|||
import pytest |
|||
|
|||
from mlagents.trainers.tf.models import ModelUtils |
|||
from mlagents.tf_utils import tf |
|||
from mlagents_envs.base_env import BehaviorSpec, ActionType |
|||
|
|||
|
|||
def create_behavior_spec(num_visual, num_vector, vector_size): |
|||
behavior_spec = BehaviorSpec( |
|||
[(84, 84, 3)] * int(num_visual) + [(vector_size,)] * int(num_vector), |
|||
ActionType.DISCRETE, |
|||
(1,), |
|||
) |
|||
return behavior_spec |
|||
|
|||
|
|||
@pytest.mark.parametrize("num_visual", [1, 2, 4]) |
|||
@pytest.mark.parametrize("num_vector", [1, 2, 4]) |
|||
def test_create_input_placeholders(num_vector, num_visual): |
|||
vec_size = 8 |
|||
name_prefix = "test123" |
|||
bspec = create_behavior_spec(num_visual, num_vector, vec_size) |
|||
vec_in, vis_in = ModelUtils.create_input_placeholders( |
|||
bspec.observation_shapes, name_prefix=name_prefix |
|||
) |
|||
|
|||
assert isinstance(vis_in, list) |
|||
assert len(vis_in) == num_visual |
|||
assert isinstance(vec_in, tf.Tensor) |
|||
assert vec_in.get_shape().as_list()[1] == num_vector * 8 |
|||
|
|||
# Check names contain prefix and vis shapes are correct |
|||
for _vis in vis_in: |
|||
assert _vis.get_shape().as_list() == [None, 84, 84, 3] |
|||
assert _vis.name.startswith(name_prefix) |
|||
assert vec_in.name.startswith(name_prefix) |
|
|||
#!/usr/bin/env python3 |
|||
|
|||
import ast |
|||
import sys |
|||
import os |
|||
import re |
|||
import subprocess |
|||
from typing import List, Optional, Pattern |
|||
|
|||
RELEASE_PATTERN = re.compile(r"release_[0-9]+(_docs)*") |
|||
TRAINER_INIT_FILE = "ml-agents/mlagents/trainers/__init__.py" |
|||
|
|||
# Filename -> regex list to allow specific lines. |
|||
# To allow everything in the file, use None for the value |
|||
ALLOW_LIST = { |
|||
# Previous release table |
|||
"README.md": re.compile(r"\*\*Release [0-9]+\*\*"), |
|||
"docs/Versioning.md": None, |
|||
"com.unity.ml-agents/CHANGELOG.md": None, |
|||
"utils/make_readme_table.py": None, |
|||
"utils/validate_release_links.py": None, |
|||
} |
|||
|
|||
|
|||
def test_pattern(): |
|||
# Just some sanity check that the regex works as expected. |
|||
assert RELEASE_PATTERN.search( |
|||
"https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/Food.md" |
|||
) |
|||
assert RELEASE_PATTERN.search( |
|||
"https://github.com/Unity-Technologies/ml-agents/blob/release_4/Foo.md" |
|||
) |
|||
assert RELEASE_PATTERN.search( |
|||
"git clone --branch release_4 https://github.com/Unity-Technologies/ml-agents.git" |
|||
) |
|||
assert RELEASE_PATTERN.search( |
|||
"https://github.com/Unity-Technologies/ml-agents/blob/release_123_docs/Foo.md" |
|||
) |
|||
assert RELEASE_PATTERN.search( |
|||
"https://github.com/Unity-Technologies/ml-agents/blob/release_123/Foo.md" |
|||
) |
|||
assert not RELEASE_PATTERN.search( |
|||
"https://github.com/Unity-Technologies/ml-agents/blob/latest_release/docs/Foo.md" |
|||
) |
|||
print("tests OK!") |
|||
|
|||
|
|||
def git_ls_files() -> List[str]: |
|||
""" |
|||
Run "git ls-files" and return a list with one entry per line. |
|||
This returns the list of all files tracked by git. |
|||
""" |
|||
return subprocess.check_output(["git", "ls-files"], universal_newlines=True).split( |
|||
"\n" |
|||
) |
|||
|
|||
|
|||
def get_release_tag() -> Optional[str]: |
|||
""" |
|||
Returns the release tag for the mlagents python package. |
|||
This will be None on the master branch. |
|||
:return: |
|||
""" |
|||
with open(TRAINER_INIT_FILE) as f: |
|||
for line in f: |
|||
if "__release_tag__" in line: |
|||
lhs, equals_string, rhs = line.strip().partition(" = ") |
|||
# Evaluate the right hand side of the expression |
|||
return ast.literal_eval(rhs) |
|||
# If we couldn't find the release tag, raise an exception |
|||
# (since we can't return None here) |
|||
raise RuntimeError("Can't determine release tag") |
|||
|
|||
|
|||
def check_file(filename: str, global_allow_pattern: Pattern) -> List[str]: |
|||
""" |
|||
Validate a single file and return any offending lines. |
|||
""" |
|||
bad_lines = [] |
|||
with open(filename) as f: |
|||
for line in f: |
|||
if not RELEASE_PATTERN.search(line): |
|||
continue |
|||
|
|||
if global_allow_pattern.search(line): |
|||
continue |
|||
|
|||
if filename in ALLOW_LIST: |
|||
if ALLOW_LIST[filename] is None or ALLOW_LIST[filename].search(line): |
|||
continue |
|||
|
|||
bad_lines.append(f"{filename}: {line.strip()}") |
|||
return bad_lines |
|||
|
|||
|
|||
def check_all_files(allow_pattern: Pattern) -> List[str]: |
|||
""" |
|||
Validate all files tracked by git. |
|||
:param allow_pattern: |
|||
""" |
|||
bad_lines = [] |
|||
file_types = {".py", ".md", ".cs"} |
|||
for file_name in git_ls_files(): |
|||
if "localized" in file_name or os.path.splitext(file_name)[1] not in file_types: |
|||
continue |
|||
bad_lines += check_file(file_name, allow_pattern) |
|||
return bad_lines |
|||
|
|||
|
|||
def main(): |
|||
release_tag = get_release_tag() |
|||
if not release_tag: |
|||
print("Release tag is None, exiting") |
|||
sys.exit(0) |
|||
|
|||
print(f"Release tag: {release_tag}") |
|||
allow_pattern = re.compile(f"{release_tag}(_docs)*") |
|||
bad_lines = check_all_files(allow_pattern) |
|||
if bad_lines: |
|||
print( |
|||
f"Found lines referring to previous release. Either update the files, or add an exclusion to {__file__}" |
|||
) |
|||
for line in bad_lines: |
|||
print(line) |
|||
|
|||
sys.exit(1 if bad_lines else 0) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
if "--test" in sys.argv: |
|||
test_pattern() |
|||
main() |
|
|||
fileFormatVersion: 2 |
|||
guid: 11fe037a02b4a483cb9342c3454232cd |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
#if UNITY_2020_1_OR_NEWER
|
|||
using UnityEngine; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
public class ArticulationBodySensorComponent : SensorComponent |
|||
{ |
|||
public ArticulationBody RootBody; |
|||
|
|||
[SerializeField] |
|||
public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default(); |
|||
public string sensorName; |
|||
|
|||
/// <summary>
|
|||
/// Creates a PhysicsBodySensor.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
public override ISensor CreateSensor() |
|||
{ |
|||
return new PhysicsBodySensor(RootBody, Settings, sensorName); |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public override int[] GetObservationShape() |
|||
{ |
|||
if (RootBody == null) |
|||
{ |
|||
return new[] { 0 }; |
|||
} |
|||
|
|||
// TODO static method in PhysicsBodySensor?
|
|||
// TODO only update PoseExtractor when body changes?
|
|||
var poseExtractor = new ArticulationBodyPoseExtractor(RootBody); |
|||
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings); |
|||
var numJointObservations = 0; |
|||
// Start from i=1 to ignore the root
|
|||
for (var i = 1; i < poseExtractor.Bodies.Length; i++) |
|||
{ |
|||
numJointObservations += ArticulationBodyJointExtractor.NumObservations( |
|||
poseExtractor.Bodies[i], Settings |
|||
); |
|||
} |
|||
return new[] { numPoseObservations + numJointObservations }; |
|||
} |
|||
} |
|||
|
|||
} |
|||
#endif // UNITY_2020_1_OR_NEWER
|
|
|||
fileFormatVersion: 2 |
|||
guid: fcb7a51f0d5f8404db7b85bd35ecc1fb |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
#if UNITY_2020_1_OR_NEWER
|
|||
|
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
/// <summary>
|
|||
/// Utility class to track a hierarchy of ArticulationBodies.
|
|||
/// </summary>
|
|||
public class ArticulationBodyPoseExtractor : PoseExtractor |
|||
{ |
|||
ArticulationBody[] m_Bodies; |
|||
|
|||
public ArticulationBodyPoseExtractor(ArticulationBody rootBody) |
|||
{ |
|||
if (rootBody == null) |
|||
{ |
|||
return; |
|||
} |
|||
|
|||
if (!rootBody.isRoot) |
|||
{ |
|||
Debug.Log("Must pass ArticulationBody.isRoot"); |
|||
return; |
|||
} |
|||
|
|||
var bodies = rootBody.GetComponentsInChildren <ArticulationBody>(); |
|||
if (bodies[0] != rootBody) |
|||
{ |
|||
Debug.Log("Expected root body at index 0"); |
|||
return; |
|||
} |
|||
|
|||
var numBodies = bodies.Length; |
|||
m_Bodies = bodies; |
|||
int[] parentIndices = new int[numBodies]; |
|||
parentIndices[0] = -1; |
|||
|
|||
var bodyToIndex = new Dictionary<ArticulationBody, int>(); |
|||
for (var i = 0; i < numBodies; i++) |
|||
{ |
|||
bodyToIndex[m_Bodies[i]] = i; |
|||
} |
|||
|
|||
for (var i = 1; i < numBodies; i++) |
|||
{ |
|||
var currentArticBody = m_Bodies[i]; |
|||
// Component.GetComponentInParent will consider the provided object as well.
|
|||
// So start looking from the parent.
|
|||
var currentGameObject = currentArticBody.gameObject; |
|||
var parentGameObject = currentGameObject.transform.parent; |
|||
var parentArticBody = parentGameObject.GetComponentInParent<ArticulationBody>(); |
|||
parentIndices[i] = bodyToIndex[parentArticBody]; |
|||
} |
|||
|
|||
Setup(parentIndices); |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
protected internal override Vector3 GetLinearVelocityAt(int index) |
|||
{ |
|||
return m_Bodies[index].velocity; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
protected internal override Pose GetPoseAt(int index) |
|||
{ |
|||
var body = m_Bodies[index]; |
|||
var go = body.gameObject; |
|||
var t = go.transform; |
|||
return new Pose { rotation = t.rotation, position = t.position }; |
|||
} |
|||
|
|||
internal ArticulationBody[] Bodies => m_Bodies; |
|||
} |
|||
} |
|||
#endif // UNITY_2020_1_OR_NEWER
|
|
|||
using UnityEngine; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
/// <summary>
|
|||
/// ISensor implementation that generates observations for a group of Rigidbodies or ArticulationBodies.
|
|||
/// </summary>
|
|||
public class PhysicsBodySensor : ISensor |
|||
{ |
|||
int[] m_Shape; |
|||
string m_SensorName; |
|||
|
|||
PoseExtractor m_PoseExtractor; |
|||
IJointExtractor[] m_JointExtractors; |
|||
PhysicsSensorSettings m_Settings; |
|||
|
|||
/// <summary>
|
|||
/// Construct a new PhysicsBodySensor
|
|||
/// </summary>
|
|||
/// <param name="rootBody">The root Rigidbody. This has no Joints on it (but other Joints may connect to it).</param>
|
|||
/// <param name="rootGameObject">Optional GameObject used to find Rigidbodies in the hierarchy.</param>
|
|||
/// <param name="virtualRoot">Optional GameObject used to determine the root of the poses,
|
|||
/// <param name="settings"></param>
|
|||
/// <param name="sensorName"></param>
|
|||
public PhysicsBodySensor( |
|||
Rigidbody rootBody, |
|||
GameObject rootGameObject, |
|||
GameObject virtualRoot, |
|||
PhysicsSensorSettings settings, |
|||
string sensorName=null |
|||
) |
|||
{ |
|||
var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject, virtualRoot); |
|||
m_PoseExtractor = poseExtractor; |
|||
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName; |
|||
m_Settings = settings; |
|||
|
|||
var numJointExtractorObservations = 0; |
|||
var rigidBodies = poseExtractor.Bodies; |
|||
if (rigidBodies != null) |
|||
{ |
|||
m_JointExtractors = new IJointExtractor[rigidBodies.Length - 1]; // skip the root
|
|||
for (var i = 1; i < rigidBodies.Length; i++) |
|||
{ |
|||
var jointExtractor = new RigidBodyJointExtractor(rigidBodies[i]); |
|||
numJointExtractorObservations += jointExtractor.NumObservations(settings); |
|||
m_JointExtractors[i - 1] = jointExtractor; |
|||
} |
|||
} |
|||
else |
|||
{ |
|||
m_JointExtractors = new IJointExtractor[0]; |
|||
} |
|||
|
|||
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings); |
|||
m_Shape = new[] { numTransformObservations + numJointExtractorObservations }; |
|||
} |
|||
|
|||
#if UNITY_2020_1_OR_NEWER
|
|||
public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null) |
|||
{ |
|||
var poseExtractor = new ArticulationBodyPoseExtractor(rootBody); |
|||
m_PoseExtractor = poseExtractor; |
|||
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootBody?.name}" : sensorName; |
|||
m_Settings = settings; |
|||
|
|||
var numJointExtractorObservations = 0; |
|||
var articBodies = poseExtractor.Bodies; |
|||
if (articBodies != null) |
|||
{ |
|||
m_JointExtractors = new IJointExtractor[articBodies.Length - 1]; // skip the root
|
|||
for (var i = 1; i < articBodies.Length; i++) |
|||
{ |
|||
var jointExtractor = new ArticulationBodyJointExtractor(articBodies[i]); |
|||
numJointExtractorObservations += jointExtractor.NumObservations(settings); |
|||
m_JointExtractors[i - 1] = jointExtractor; |
|||
} |
|||
} |
|||
else |
|||
{ |
|||
m_JointExtractors = new IJointExtractor[0]; |
|||
} |
|||
|
|||
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings); |
|||
m_Shape = new[] { numTransformObservations + numJointExtractorObservations }; |
|||
} |
|||
#endif
|
|||
|
|||
/// <inheritdoc/>
|
|||
public int[] GetObservationShape() |
|||
{ |
|||
return m_Shape; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public int Write(ObservationWriter writer) |
|||
{ |
|||
var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor); |
|||
foreach (var jointExtractor in m_JointExtractors) |
|||
{ |
|||
numWritten += jointExtractor.Write(m_Settings, writer, numWritten); |
|||
} |
|||
return numWritten; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public byte[] GetCompressedObservation() |
|||
{ |
|||
return null; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public void Update() |
|||
{ |
|||
if (m_Settings.UseModelSpace) |
|||
{ |
|||
m_PoseExtractor.UpdateModelSpacePoses(); |
|||
} |
|||
|
|||
if (m_Settings.UseLocalSpace) |
|||
{ |
|||
m_PoseExtractor.UpdateLocalSpacePoses(); |
|||
} |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public void Reset() {} |
|||
|
|||
/// <inheritdoc/>
|
|||
public SensorCompressionType GetCompressionType() |
|||
{ |
|||
return SensorCompressionType.None; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public string GetName() |
|||
{ |
|||
return m_SensorName; |
|||
} |
|||
} |
|||
} |
|
|||
using System; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
/// <summary>
|
|||
/// Settings that define the observations generated for physics-based sensors.
|
|||
/// </summary>
|
|||
[Serializable] |
|||
public struct PhysicsSensorSettings |
|||
{ |
|||
/// <summary>
|
|||
/// Whether to use model space (relative to the root body) translations as observations.
|
|||
/// </summary>
|
|||
public bool UseModelSpaceTranslations; |
|||
|
|||
/// <summary>
|
|||
/// Whether to use model space (relative to the root body) rotations as observations.
|
|||
/// </summary>
|
|||
public bool UseModelSpaceRotations; |
|||
|
|||
/// <summary>
|
|||
/// Whether to use local space (relative to the parent body) translations as observations.
|
|||
/// </summary>
|
|||
public bool UseLocalSpaceTranslations; |
|||
|
|||
/// <summary>
|
|||
/// Whether to use local space (relative to the parent body) translations as observations.
|
|||
/// </summary>
|
|||
public bool UseLocalSpaceRotations; |
|||
|
|||
/// <summary>
|
|||
/// Whether to use model space (relative to the root body) linear velocities as observations.
|
|||
/// </summary>
|
|||
public bool UseModelSpaceLinearVelocity; |
|||
|
|||
/// <summary>
|
|||
/// Whether to use local space (relative to the parent body) linear velocities as observations.
|
|||
/// </summary>
|
|||
public bool UseLocalSpaceLinearVelocity; |
|||
|
|||
/// <summary>
|
|||
/// Whether to use joint-specific positions and angles as observations.
|
|||
/// </summary>
|
|||
public bool UseJointPositionsAndAngles; |
|||
|
|||
/// <summary>
|
|||
/// Whether to use the joint forces and torques that are applied by the solver as observations.
|
|||
/// </summary>
|
|||
public bool UseJointForces; |
|||
|
|||
/// <summary>
|
|||
/// Creates a PhysicsSensorSettings with reasonable default values.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
public static PhysicsSensorSettings Default() |
|||
{ |
|||
return new PhysicsSensorSettings |
|||
{ |
|||
UseModelSpaceTranslations = true, |
|||
UseModelSpaceRotations = true, |
|||
}; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Whether any model space observations are being used.
|
|||
/// </summary>
|
|||
public bool UseModelSpace |
|||
{ |
|||
get { return UseModelSpaceTranslations || UseModelSpaceRotations || UseModelSpaceLinearVelocity; } |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Whether any local space observations are being used.
|
|||
/// </summary>
|
|||
public bool UseLocalSpace |
|||
{ |
|||
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; } |
|||
} |
|||
} |
|||
|
|||
internal static class ObservationWriterPhysicsExtensions |
|||
{ |
|||
/// <summary>
|
|||
/// Utility method for writing a PoseExtractor to an ObservationWriter.
|
|||
/// </summary>
|
|||
/// <param name="writer"></param>
|
|||
/// <param name="settings"></param>
|
|||
/// <param name="poseExtractor"></param>
|
|||
/// <param name="baseOffset">The offset into the ObservationWriter to start writing at.</param>
|
|||
/// <returns>The number of observations written.</returns>
|
|||
public static int WritePoses(this ObservationWriter writer, PhysicsSensorSettings settings, PoseExtractor poseExtractor, int baseOffset = 0) |
|||
{ |
|||
var offset = baseOffset; |
|||
if (settings.UseModelSpace) |
|||
{ |
|||
foreach (var pose in poseExtractor.GetEnabledModelSpacePoses()) |
|||
{ |
|||
if (settings.UseModelSpaceTranslations) |
|||
{ |
|||
writer.Add(pose.position, offset); |
|||
offset += 3; |
|||
} |
|||
|
|||
if (settings.UseModelSpaceRotations) |
|||
{ |
|||
writer.Add(pose.rotation, offset); |
|||
offset += 4; |
|||
} |
|||
} |
|||
|
|||
foreach(var vel in poseExtractor.GetEnabledModelSpaceVelocities()) |
|||
{ |
|||
if (settings.UseModelSpaceLinearVelocity) |
|||
{ |
|||
writer.Add(vel, offset); |
|||
offset += 3; |
|||
} |
|||
} |
|||
} |
|||
|
|||
if (settings.UseLocalSpace) |
|||
{ |
|||
foreach (var pose in poseExtractor.GetEnabledLocalSpacePoses()) |
|||
{ |
|||
if (settings.UseLocalSpaceTranslations) |
|||
{ |
|||
writer.Add(pose.position, offset); |
|||
offset += 3; |
|||
} |
|||
|
|||
if (settings.UseLocalSpaceRotations) |
|||
{ |
|||
writer.Add(pose.rotation, offset); |
|||
offset += 4; |
|||
} |
|||
} |
|||
|
|||
foreach(var vel in poseExtractor.GetEnabledLocalSpaceVelocities()) |
|||
{ |
|||
if (settings.UseLocalSpaceLinearVelocity) |
|||
{ |
|||
writer.Add(vel, offset); |
|||
offset += 3; |
|||
} |
|||
} |
|||
} |
|||
|
|||
return offset - baseOffset; |
|||
} |
|||
} |
|||
} |
|
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
/// <summary>
|
|||
/// Abstract class for managing the transforms of a hierarchy of objects.
|
|||
/// This could be GameObjects or Monobehaviours in the scene graph, but this is
|
|||
/// not a requirement; for example, the objects could be rigid bodies whose hierarchy
|
|||
/// is defined by Joint configurations.
|
|||
///
|
|||
/// Poses are either considered in model space, which is relative to a root body,
|
|||
/// or in local space, which is relative to their parent.
|
|||
/// </summary>
|
|||
public abstract class PoseExtractor |
|||
{ |
|||
int[] m_ParentIndices; |
|||
Pose[] m_ModelSpacePoses; |
|||
Pose[] m_LocalSpacePoses; |
|||
|
|||
Vector3[] m_ModelSpaceLinearVelocities; |
|||
Vector3[] m_LocalSpaceLinearVelocities; |
|||
|
|||
bool[] m_PoseEnabled; |
|||
|
|||
|
|||
/// <summary>
|
|||
/// Read iterator for the enabled model space transforms.
|
|||
/// </summary>
|
|||
public IEnumerable<Pose> GetEnabledModelSpacePoses() |
|||
{ |
|||
if (m_ModelSpacePoses == null) |
|||
{ |
|||
yield break; |
|||
} |
|||
|
|||
for (var i = 0; i < m_ModelSpacePoses.Length; i++) |
|||
{ |
|||
if (m_PoseEnabled[i]) |
|||
{ |
|||
yield return m_ModelSpacePoses[i]; |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Read iterator for the enabled local space transforms.
|
|||
/// </summary>
|
|||
public IEnumerable<Pose> GetEnabledLocalSpacePoses() |
|||
{ |
|||
if (m_LocalSpacePoses == null) |
|||
{ |
|||
yield break; |
|||
} |
|||
|
|||
for (var i = 0; i < m_LocalSpacePoses.Length; i++) |
|||
{ |
|||
if (m_PoseEnabled[i]) |
|||
{ |
|||
yield return m_LocalSpacePoses[i]; |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Read iterator for the enabled model space linear velocities.
|
|||
/// </summary>
|
|||
public IEnumerable<Vector3> GetEnabledModelSpaceVelocities() |
|||
{ |
|||
if (m_ModelSpaceLinearVelocities == null) |
|||
{ |
|||
yield break; |
|||
} |
|||
|
|||
for (var i = 0; i < m_ModelSpaceLinearVelocities.Length; i++) |
|||
{ |
|||
if (m_PoseEnabled[i]) |
|||
{ |
|||
yield return m_ModelSpaceLinearVelocities[i]; |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Read iterator for the enabled local space linear velocities.
|
|||
/// </summary>
|
|||
public IEnumerable<Vector3> GetEnabledLocalSpaceVelocities() |
|||
{ |
|||
if (m_LocalSpaceLinearVelocities == null) |
|||
{ |
|||
yield break; |
|||
} |
|||
|
|||
for (var i = 0; i < m_LocalSpaceLinearVelocities.Length; i++) |
|||
{ |
|||
if (m_PoseEnabled[i]) |
|||
{ |
|||
yield return m_LocalSpaceLinearVelocities[i]; |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Number of enabled poses in the hierarchy (read-only).
|
|||
/// </summary>
|
|||
public int NumEnabledPoses |
|||
{ |
|||
get |
|||
{ |
|||
if (m_PoseEnabled == null) |
|||
{ |
|||
return 0; |
|||
} |
|||
|
|||
var numEnabled = 0; |
|||
for (var i = 0; i < m_PoseEnabled.Length; i++) |
|||
{ |
|||
numEnabled += m_PoseEnabled[i] ? 1 : 0; |
|||
} |
|||
|
|||
return numEnabled; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Number of total poses in the hierarchy (read-only).
|
|||
/// </summary>
|
|||
public int NumPoses |
|||
{ |
|||
get { return m_ModelSpacePoses?.Length ?? 0; } |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Get the parent index of the body at the specified index.
|
|||
/// </summary>
|
|||
/// <param name="index"></param>
|
|||
/// <returns></returns>
|
|||
public int GetParentIndex(int index) |
|||
{ |
|||
if (m_ParentIndices == null) |
|||
{ |
|||
return -1; |
|||
} |
|||
|
|||
return m_ParentIndices[index]; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Set whether the pose at the given index is enabled or disabled for observations.
|
|||
/// </summary>
|
|||
/// <param name="index"></param>
|
|||
/// <param name="val"></param>
|
|||
public void SetPoseEnabled(int index, bool val) |
|||
{ |
|||
m_PoseEnabled[index] = val; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Initialize with the mapping of parent indices.
|
|||
/// The 0th element is assumed to be -1, indicating that it's the root.
|
|||
/// </summary>
|
|||
/// <param name="parentIndices"></param>
|
|||
protected void Setup(int[] parentIndices) |
|||
{ |
|||
#if DEBUG
|
|||
if (parentIndices[0] != -1) |
|||
{ |
|||
throw new UnityAgentsException($"Expected parentIndices[0] to be -1, got {parentIndices[0]}"); |
|||
} |
|||
#endif
|
|||
m_ParentIndices = parentIndices; |
|||
var numPoses = parentIndices.Length; |
|||
m_ModelSpacePoses = new Pose[numPoses]; |
|||
m_LocalSpacePoses = new Pose[numPoses]; |
|||
|
|||
m_ModelSpaceLinearVelocities = new Vector3[numPoses]; |
|||
m_LocalSpaceLinearVelocities = new Vector3[numPoses]; |
|||
|
|||
m_PoseEnabled = new bool[numPoses]; |
|||
// All poses are enabled by default. Generally we'll want to disable the root though.
|
|||
for (var i = 0; i < numPoses; i++) |
|||
{ |
|||
m_PoseEnabled[i] = true; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Return the world space Pose of the i'th object.
|
|||
/// </summary>
|
|||
/// <param name="index"></param>
|
|||
/// <returns></returns>
|
|||
protected internal abstract Pose GetPoseAt(int index); |
|||
|
|||
/// <summary>
|
|||
/// Return the world space linear velocity of the i'th object.
|
|||
/// </summary>
|
|||
/// <param name="index"></param>
|
|||
/// <returns></returns>
|
|||
protected internal abstract Vector3 GetLinearVelocityAt(int index); |
|||
|
|||
|
|||
/// <summary>
|
|||
/// Update the internal model space transform storage based on the underlying system.
|
|||
/// </summary>
|
|||
public void UpdateModelSpacePoses() |
|||
{ |
|||
using (TimerStack.Instance.Scoped("UpdateModelSpacePoses")) |
|||
{ |
|||
if (m_ModelSpacePoses == null) |
|||
{ |
|||
return; |
|||
} |
|||
|
|||
var rootWorldTransform = GetPoseAt(0); |
|||
var worldToModel = rootWorldTransform.Inverse(); |
|||
var rootLinearVel = GetLinearVelocityAt(0); |
|||
|
|||
for (var i = 0; i < m_ModelSpacePoses.Length; i++) |
|||
{ |
|||
var currentWorldSpacePose = GetPoseAt(i); |
|||
var currentModelSpacePose = worldToModel.Multiply(currentWorldSpacePose); |
|||
m_ModelSpacePoses[i] = currentModelSpacePose; |
|||
|
|||
var currentBodyLinearVel = GetLinearVelocityAt(i); |
|||
var relativeVelocity = currentBodyLinearVel - rootLinearVel; |
|||
m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity; |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Update the internal model space transform storage based on the underlying system.
|
|||
/// </summary>
|
|||
public void UpdateLocalSpacePoses() |
|||
{ |
|||
using (TimerStack.Instance.Scoped("UpdateLocalSpacePoses")) |
|||
{ |
|||
if (m_LocalSpacePoses == null) |
|||
{ |
|||
return; |
|||
} |
|||
|
|||
for (var i = 0; i < m_LocalSpacePoses.Length; i++) |
|||
{ |
|||
if (m_ParentIndices[i] != -1) |
|||
{ |
|||
var parentTransform = GetPoseAt(m_ParentIndices[i]); |
|||
// This is slightly inefficient, since for a body with multiple children, we'll end up inverting
|
|||
// the transform multiple times. Might be able to trade space for perf here.
|
|||
var invParent = parentTransform.Inverse(); |
|||
var currentTransform = GetPoseAt(i); |
|||
m_LocalSpacePoses[i] = invParent.Multiply(currentTransform); |
|||
|
|||
var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]); |
|||
var currentLinearVel = GetLinearVelocityAt(i); |
|||
m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel); |
|||
} |
|||
else |
|||
{ |
|||
m_LocalSpacePoses[i] = Pose.identity; |
|||
m_LocalSpaceLinearVelocities[i] = Vector3.zero; |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Compute the number of floats needed to represent the poses for the given PhysicsSensorSettings.
|
|||
/// </summary>
|
|||
/// <param name="settings"></param>
|
|||
/// <returns></returns>
|
|||
public int GetNumPoseObservations(PhysicsSensorSettings settings) |
|||
{ |
|||
int obsPerPose = 0; |
|||
obsPerPose += settings.UseModelSpaceTranslations ? 3 : 0; |
|||
obsPerPose += settings.UseModelSpaceRotations ? 4 : 0; |
|||
obsPerPose += settings.UseLocalSpaceTranslations ? 3 : 0; |
|||
obsPerPose += settings.UseLocalSpaceRotations ? 4 : 0; |
|||
|
|||
obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0; |
|||
obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0; |
|||
|
|||
return NumEnabledPoses * obsPerPose; |
|||
} |
|||
|
|||
internal void DrawModelSpace(Vector3 offset) |
|||
{ |
|||
UpdateLocalSpacePoses(); |
|||
UpdateModelSpacePoses(); |
|||
|
|||
var pose = m_ModelSpacePoses; |
|||
var localPose = m_LocalSpacePoses; |
|||
for (var i = 0; i < pose.Length; i++) |
|||
{ |
|||
var current = pose[i]; |
|||
if (m_ParentIndices[i] == -1) |
|||
{ |
|||
continue; |
|||
} |
|||
|
|||
var parent = pose[m_ParentIndices[i]]; |
|||
Debug.DrawLine(current.position + offset, parent.position + offset, Color.cyan); |
|||
var localUp = localPose[i].rotation * Vector3.up; |
|||
var localFwd = localPose[i].rotation * Vector3.forward; |
|||
var localRight = localPose[i].rotation * Vector3.right; |
|||
Debug.DrawLine(current.position+offset, current.position+offset+.1f*localUp, Color.red); |
|||
Debug.DrawLine(current.position+offset, current.position+offset+.1f*localFwd, Color.green); |
|||
Debug.DrawLine(current.position+offset, current.position+offset+.1f*localRight, Color.blue); |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Extension methods for the Pose struct, in order to improve the readability of some math.
|
|||
/// </summary>
|
|||
public static class PoseExtensions |
|||
{ |
|||
/// <summary>
|
|||
/// Compute the inverse of a Pose. For any Pose P,
|
|||
/// P.Inverse() * P
|
|||
/// will equal the identity pose (within tolerance).
|
|||
/// </summary>
|
|||
/// <param name="pose"></param>
|
|||
/// <returns></returns>
|
|||
public static Pose Inverse(this Pose pose) |
|||
{ |
|||
var rotationInverse = Quaternion.Inverse(pose.rotation); |
|||
var translationInverse = -(rotationInverse * pose.position); |
|||
return new Pose { rotation = rotationInverse, position = translationInverse }; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// This is equivalent to Pose.GetTransformedBy(), but keeps the order more intuitive.
|
|||
/// </summary>
|
|||
/// <param name="pose"></param>
|
|||
/// <param name="rhs"></param>
|
|||
/// <returns></returns>
|
|||
public static Pose Multiply(this Pose pose, Pose rhs) |
|||
{ |
|||
return rhs.GetTransformedBy(pose); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Transform the vector by the pose. Conceptually this is equivalent to treating the Pose
|
|||
/// as a 4x4 matrix and multiplying the augmented vector.
|
|||
/// See https://en.wikipedia.org/wiki/Affine_transformation#Augmented_matrix for more details.
|
|||
/// </summary>
|
|||
/// <param name="pose"></param>
|
|||
/// <param name="rhs"></param>
|
|||
/// <returns></returns>
|
|||
public static Vector3 Multiply(this Pose pose, Vector3 rhs) |
|||
{ |
|||
return pose.rotation * rhs + pose.position; |
|||
} |
|||
|
|||
// TODO optimize inv(A)*B?
|
|||
} |
|||
} |
|
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
|
|||
/// <summary>
|
|||
/// Utility class to track a hierarchy of RigidBodies. These are assumed to have a root node,
|
|||
/// and child nodes are connect to their parents via Joints.
|
|||
/// </summary>
|
|||
public class RigidBodyPoseExtractor : PoseExtractor |
|||
{ |
|||
Rigidbody[] m_Bodies; |
|||
|
|||
/// <summary>
|
|||
/// Optional game object used to determine the root of the poses, separate from the actual Rigidbodies
|
|||
/// in the hierarchy. For locomotion
|
|||
/// </summary>
|
|||
GameObject m_VirtualRoot; |
|||
|
|||
/// <summary>
|
|||
/// Initialize given a root RigidBody.
|
|||
/// </summary>
|
|||
/// <param name="rootBody">The root Rigidbody. This has no Joints on it (but other Joints may connect to it).</param>
|
|||
/// <param name="rootGameObject">Optional GameObject used to find Rigidbodies in the hierarchy.</param>
|
|||
/// <param name="virtualRoot">Optional GameObject used to determine the root of the poses,
|
|||
/// separate from the actual Rigidbodies in the hierarchy. For locomotion tasks, with ragdolls, this provides
|
|||
/// a stabilized refernece frame, which can improve learning.</param>
|
|||
public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null, GameObject virtualRoot = null) |
|||
{ |
|||
if (rootBody == null) |
|||
{ |
|||
return; |
|||
} |
|||
|
|||
Rigidbody[] rbs; |
|||
if (rootGameObject == null) |
|||
{ |
|||
rbs = rootBody.GetComponentsInChildren<Rigidbody>(); |
|||
} |
|||
else |
|||
{ |
|||
rbs = rootGameObject.GetComponentsInChildren<Rigidbody>(); |
|||
} |
|||
|
|||
if (rbs == null || rbs.Length == 0) |
|||
{ |
|||
Debug.Log("No rigid bodies found!"); |
|||
return; |
|||
} |
|||
|
|||
if (rbs[0] != rootBody) |
|||
{ |
|||
Debug.Log("Expected root body at index 0"); |
|||
return; |
|||
} |
|||
|
|||
// Adjust the array if we have a virtual root.
|
|||
// This will be at index 0, and the "real" root will be parented to it.
|
|||
if (virtualRoot != null) |
|||
{ |
|||
var extendedRbs = new Rigidbody[rbs.Length + 1]; |
|||
for (var i = 0; i < rbs.Length; i++) |
|||
{ |
|||
extendedRbs[i + 1] = rbs[i]; |
|||
} |
|||
|
|||
rbs = extendedRbs; |
|||
} |
|||
|
|||
var bodyToIndex = new Dictionary<Rigidbody, int>(rbs.Length); |
|||
var parentIndices = new int[rbs.Length]; |
|||
parentIndices[0] = -1; |
|||
|
|||
for (var i = 0; i < rbs.Length; i++) |
|||
{ |
|||
if(rbs[i] != null) |
|||
{ |
|||
bodyToIndex[rbs[i]] = i; |
|||
} |
|||
} |
|||
|
|||
var joints = rootBody.GetComponentsInChildren <Joint>(); |
|||
|
|||
|
|||
foreach (var j in joints) |
|||
{ |
|||
var parent = j.connectedBody; |
|||
var child = j.GetComponent<Rigidbody>(); |
|||
|
|||
var parentIndex = bodyToIndex[parent]; |
|||
var childIndex = bodyToIndex[child]; |
|||
parentIndices[childIndex] = parentIndex; |
|||
} |
|||
|
|||
if (virtualRoot != null) |
|||
{ |
|||
// Make sure the original root treats the virtual root as its parent.
|
|||
parentIndices[1] = 0; |
|||
m_VirtualRoot = virtualRoot; |
|||
} |
|||
|
|||
m_Bodies = rbs; |
|||
Setup(parentIndices); |
|||
|
|||
// By default, ignore the root
|
|||
SetPoseEnabled(0, false); |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
protected internal override Vector3 GetLinearVelocityAt(int index) |
|||
{ |
|||
if (index == 0 && m_VirtualRoot != null) |
|||
{ |
|||
// No velocity on the virtual root
|
|||
return Vector3.zero; |
|||
} |
|||
return m_Bodies[index].velocity; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
protected internal override Pose GetPoseAt(int index) |
|||
{ |
|||
if (index == 0 && m_VirtualRoot != null) |
|||
{ |
|||
// Use the GameObject's world transform
|
|||
return new Pose |
|||
{ |
|||
rotation = m_VirtualRoot.transform.rotation, |
|||
position = m_VirtualRoot.transform.position |
|||
}; |
|||
} |
|||
|
|||
var body = m_Bodies[index]; |
|||
return new Pose { rotation = body.rotation, position = body.position }; |
|||
} |
|||
|
|||
internal Rigidbody[] Bodies => m_Bodies; |
|||
} |
|||
|
|||
} |
|
|||
using UnityEngine; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
/// <summary>
|
|||
/// Editor component that creates a PhysicsBodySensor for the Agent.
|
|||
/// </summary>
|
|||
public class RigidBodySensorComponent : SensorComponent |
|||
{ |
|||
/// <summary>
|
|||
/// The root Rigidbody of the system.
|
|||
/// </summary>
|
|||
public Rigidbody RootBody; |
|||
|
|||
/// <summary>
|
|||
/// Optional GameObject used to determine the root of the poses.
|
|||
/// </summary>
|
|||
public GameObject VirtualRoot; |
|||
|
|||
/// <summary>
|
|||
/// Settings defining what types of observations will be generated.
|
|||
/// </summary>
|
|||
[SerializeField] |
|||
public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default(); |
|||
|
|||
/// <summary>
|
|||
/// Optional sensor name. This must be unique for each Agent.
|
|||
/// </summary>
|
|||
public string sensorName; |
|||
|
|||
/// <summary>
|
|||
/// Creates a PhysicsBodySensor.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
public override ISensor CreateSensor() |
|||
{ |
|||
return new PhysicsBodySensor(RootBody, gameObject, VirtualRoot, Settings, sensorName); |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public override int[] GetObservationShape() |
|||
{ |
|||
if (RootBody == null) |
|||
{ |
|||
return new[] { 0 }; |
|||
} |
|||
|
|||
// TODO static method in PhysicsBodySensor?
|
|||
// TODO only update PoseExtractor when body changes?
|
|||
var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject, VirtualRoot); |
|||
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings); |
|||
|
|||
var numJointObservations = 0; |
|||
// Start from i=1 to ignore the root
|
|||
for (var i = 1; i < poseExtractor.Bodies.Length; i++) |
|||
{ |
|||
var body = poseExtractor.Bodies[i]; |
|||
var joint = body?.GetComponent<Joint>(); |
|||
numJointObservations += RigidBodyJointExtractor.NumObservations(body, joint, Settings); |
|||
} |
|||
return new[] { numPoseObservations + numJointObservations }; |
|||
} |
|||
} |
|||
|
|||
} |
|
|||
#if UNITY_2020_1_OR_NEWER
|
|||
using UnityEngine; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Extensions.Sensors; |
|||
|
|||
|
|||
namespace Unity.MLAgents.Extensions.Tests.Sensors |
|||
{ |
|||
|
|||
public class ArticulationBodySensorTests |
|||
{ |
|||
[Test] |
|||
public void TestNullRootBody() |
|||
{ |
|||
var gameObj = new GameObject(); |
|||
|
|||
var sensorComponent = gameObj.AddComponent<ArticulationBodySensorComponent>(); |
|||
var sensor = sensorComponent.CreateSensor(); |
|||
SensorTestHelper.CompareObservation(sensor, new float[0]); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestSingleBody() |
|||
{ |
|||
var gameObj = new GameObject(); |
|||
var articulationBody = gameObj.AddComponent<ArticulationBody>(); |
|||
var sensorComponent = gameObj.AddComponent<ArticulationBodySensorComponent>(); |
|||
sensorComponent.RootBody = articulationBody; |
|||
sensorComponent.Settings = new PhysicsSensorSettings |
|||
{ |
|||
UseModelSpaceLinearVelocity = true, |
|||
UseLocalSpaceTranslations = true, |
|||
UseLocalSpaceRotations = true |
|||
}; |
|||
|
|||
var sensor = sensorComponent.CreateSensor(); |
|||
sensor.Update(); |
|||
var expected = new[] |
|||
{ |
|||
0f, 0f, 0f, // ModelSpaceLinearVelocity
|
|||
0f, 0f, 0f, // LocalSpaceTranslations
|
|||
0f, 0f, 0f, 1f // LocalSpaceRotations
|
|||
}; |
|||
SensorTestHelper.CompareObservation(sensor, expected); |
|||
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestBodiesWithJoint() |
|||
{ |
|||
var rootObj = new GameObject(); |
|||
var rootArticBody = rootObj.AddComponent<ArticulationBody>(); |
|||
|
|||
var middleGamObj = new GameObject(); |
|||
var middleArticBody = middleGamObj.AddComponent<ArticulationBody>(); |
|||
middleArticBody.AddForce(new Vector3(0f, 1f, 0f)); |
|||
middleGamObj.transform.SetParent(rootObj.transform); |
|||
middleGamObj.transform.localPosition = new Vector3(13.37f, 0f, 0f); |
|||
middleArticBody.jointType = ArticulationJointType.RevoluteJoint; |
|||
|
|||
var leafGameObj = new GameObject(); |
|||
var leafArticBody = leafGameObj.AddComponent<ArticulationBody>(); |
|||
leafGameObj.transform.SetParent(middleGamObj.transform); |
|||
leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f); |
|||
leafArticBody.jointType = ArticulationJointType.PrismaticJoint; |
|||
leafArticBody.linearLockZ = ArticulationDofLock.LimitedMotion; |
|||
leafArticBody.zDrive = new ArticulationDrive |
|||
{ |
|||
lowerLimit = -3, |
|||
upperLimit = 1 |
|||
}; |
|||
|
|||
|
|||
#if UNITY_2020_2_OR_NEWER
|
|||
// ArticulationBody.velocity is read-only in 2020.1
|
|||
rootArticBody.velocity = new Vector3(1f, 0f, 0f); |
|||
middleArticBody.velocity = new Vector3(0f, 1f, 0f); |
|||
leafArticBody.velocity = new Vector3(0f, 0f, 1f); |
|||
#endif
|
|||
|
|||
var sensorComponent = rootObj.AddComponent<ArticulationBodySensorComponent>(); |
|||
sensorComponent.RootBody = rootArticBody; |
|||
sensorComponent.Settings = new PhysicsSensorSettings |
|||
{ |
|||
UseModelSpaceTranslations = true, |
|||
UseLocalSpaceTranslations = true, |
|||
#if UNITY_2020_2_OR_NEWER
|
|||
UseLocalSpaceLinearVelocity = true |
|||
#endif
|
|||
}; |
|||
|
|||
var sensor = sensorComponent.CreateSensor(); |
|||
sensor.Update(); |
|||
var expected = new[] |
|||
{ |
|||
// Model space
|
|||
0f, 0f, 0f, // Root pos
|
|||
13.37f, 0f, 0f, // Middle pos
|
|||
leafGameObj.transform.position.x, 0f, 0f, // Leaf pos
|
|||
|
|||
// Local space
|
|||
0f, 0f, 0f, // Root pos
|
|||
13.37f, 0f, 0f, // Attached pos
|
|||
4.2f, 0f, 0f, // Leaf pos
|
|||
|
|||
#if UNITY_2020_2_OR_NEWER
|
|||
0f, 0f, 0f, // Root vel
|
|||
-1f, 1f, 0f, // Attached vel
|
|||
0f, -1f, 1f // Leaf vel
|
|||
#endif
|
|||
}; |
|||
SensorTestHelper.CompareObservation(sensor, expected); |
|||
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); |
|||
|
|||
// Update the settings to only process joint observations
|
|||
sensorComponent.Settings = new PhysicsSensorSettings |
|||
{ |
|||
UseJointForces = true, |
|||
UseJointPositionsAndAngles = true, |
|||
}; |
|||
|
|||
sensor = sensorComponent.CreateSensor(); |
|||
sensor.Update(); |
|||
|
|||
expected = new[] |
|||
{ |
|||
// revolute
|
|||
0f, 1f, // joint1.position (sin and cos)
|
|||
0f, // joint1.force
|
|||
|
|||
// prismatic
|
|||
0.5f, // joint2.position (interpolate between limits)
|
|||
0f, // joint2.force
|
|||
}; |
|||
SensorTestHelper.CompareObservation(sensor, expected); |
|||
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); |
|||
} |
|||
} |
|||
} |
|||
#endif // #if UNITY_2020_1_OR_NEWER
|
|
|||
using UnityEngine; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Extensions.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Tests.Sensors |
|||
{ |
|||
public class PoseExtractorTests |
|||
{ |
|||
class UselessPoseExtractor : PoseExtractor |
|||
{ |
|||
protected internal override Pose GetPoseAt(int index) |
|||
{ |
|||
return Pose.identity; |
|||
} |
|||
|
|||
protected internal override Vector3 GetLinearVelocityAt(int index) |
|||
{ |
|||
return Vector3.zero; |
|||
} |
|||
|
|||
public void Init(int[] parentIndices) |
|||
{ |
|||
Setup(parentIndices); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestEmptyExtractor() |
|||
{ |
|||
var poseExtractor = new UselessPoseExtractor(); |
|||
|
|||
// These should be no-ops
|
|||
poseExtractor.UpdateLocalSpacePoses(); |
|||
poseExtractor.UpdateModelSpacePoses(); |
|||
|
|||
Assert.AreEqual(0, poseExtractor.NumPoses); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestSimpleExtractor() |
|||
{ |
|||
var poseExtractor = new UselessPoseExtractor(); |
|||
var parentIndices = new[] { -1, 0 }; |
|||
poseExtractor.Init(parentIndices); |
|||
Assert.AreEqual(2, poseExtractor.NumPoses); |
|||
} |
|||
|
|||
|
|||
/// <summary>
|
|||
/// A simple "chain" hierarchy, where each object is parented to the one before it.
|
|||
/// 0 <- 1 <- 2 <- ...
|
|||
/// </summary>
|
|||
class ChainPoseExtractor : PoseExtractor |
|||
{ |
|||
public Vector3 offset; |
|||
public ChainPoseExtractor(int size) |
|||
{ |
|||
var parents = new int[size]; |
|||
for (var i = 0; i < size; i++) |
|||
{ |
|||
parents[i] = i - 1; |
|||
} |
|||
Setup(parents); |
|||
} |
|||
|
|||
protected internal override Pose GetPoseAt(int index) |
|||
{ |
|||
var rotation = Quaternion.identity; |
|||
var translation = offset + new Vector3(index, index, index); |
|||
return new Pose |
|||
{ |
|||
rotation = rotation, |
|||
position = translation |
|||
}; |
|||
} |
|||
|
|||
protected internal override Vector3 GetLinearVelocityAt(int index) |
|||
{ |
|||
return Vector3.zero; |
|||
} |
|||
|
|||
} |
|||
|
|||
[Test] |
|||
public void TestChain() |
|||
{ |
|||
var size = 4; |
|||
var chain = new ChainPoseExtractor(size); |
|||
chain.offset = new Vector3(.5f, .75f, .333f); |
|||
|
|||
chain.UpdateModelSpacePoses(); |
|||
chain.UpdateLocalSpacePoses(); |
|||
|
|||
|
|||
var modelPoseIndex = 0; |
|||
foreach (var modelSpace in chain.GetEnabledModelSpacePoses()) |
|||
{ |
|||
if (modelPoseIndex == 0) |
|||
{ |
|||
// Root transforms are currently always the identity.
|
|||
Assert.IsTrue(modelSpace == Pose.identity); |
|||
} |
|||
else |
|||
{ |
|||
var expectedModelTranslation = new Vector3(modelPoseIndex, modelPoseIndex, modelPoseIndex); |
|||
Assert.IsTrue(expectedModelTranslation == modelSpace.position); |
|||
|
|||
} |
|||
modelPoseIndex++; |
|||
} |
|||
Assert.AreEqual(size, modelPoseIndex); |
|||
|
|||
var localPoseIndex = 0; |
|||
foreach (var localSpace in chain.GetEnabledLocalSpacePoses()) |
|||
{ |
|||
if (localPoseIndex == 0) |
|||
{ |
|||
// Root transforms are currently always the identity.
|
|||
Assert.IsTrue(localSpace == Pose.identity); |
|||
} |
|||
else |
|||
{ |
|||
var expectedLocalTranslation = new Vector3(1, 1, 1); |
|||
Assert.IsTrue(expectedLocalTranslation == localSpace.position, $"{expectedLocalTranslation} != {localSpace.position}"); |
|||
} |
|||
|
|||
localPoseIndex++; |
|||
} |
|||
Assert.AreEqual(size, localPoseIndex); |
|||
} |
|||
|
|||
class BadPoseExtractor : PoseExtractor |
|||
{ |
|||
public BadPoseExtractor() |
|||
{ |
|||
var size = 2; |
|||
var parents = new int[size]; |
|||
// Parents are intentionally invalid - expect -1 at root
|
|||
for (var i = 0; i < size; i++) |
|||
{ |
|||
parents[i] = i; |
|||
} |
|||
Setup(parents); |
|||
} |
|||
|
|||
protected internal override Pose GetPoseAt(int index) |
|||
{ |
|||
return Pose.identity; |
|||
} |
|||
|
|||
protected internal override Vector3 GetLinearVelocityAt(int index) |
|||
{ |
|||
return Vector3.zero; |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestExpectedRoot() |
|||
{ |
|||
Assert.Throws<UnityAgentsException>(() => |
|||
{ |
|||
var bad = new BadPoseExtractor(); |
|||
}); |
|||
} |
|||
} |
|||
|
|||
public class PoseExtensionTests |
|||
{ |
|||
[Test] |
|||
public void TestInverse() |
|||
{ |
|||
Pose t = new Pose |
|||
{ |
|||
rotation = Quaternion.AngleAxis(23.0f, new Vector3(1, 1, 1).normalized), |
|||
position = new Vector3(-1.0f, 2.0f, 3.0f) |
|||
}; |
|||
|
|||
var inverseT = t.Inverse(); |
|||
var product = inverseT.Multiply(t); |
|||
Assert.IsTrue(Vector3.zero == product.position); |
|||
Assert.IsTrue(Quaternion.identity == product.rotation); |
|||
|
|||
Assert.IsTrue(Pose.identity == product); |
|||
} |
|||
|
|||
} |
|||
} |
|
|||
using UnityEngine; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Extensions.Sensors; |
|||
using UnityEditor; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Tests.Sensors |
|||
{ |
|||
public class RigidBodyPoseExtractorTests |
|||
{ |
|||
[TearDown] |
|||
public void RemoveGameObjects() |
|||
{ |
|||
var objects = GameObject.FindObjectsOfType<GameObject>(); |
|||
foreach (var o in objects) |
|||
{ |
|||
UnityEngine.Object.DestroyImmediate(o); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestNullRoot() |
|||
{ |
|||
var poseExtractor = new RigidBodyPoseExtractor(null); |
|||
// These should be no-ops
|
|||
poseExtractor.UpdateLocalSpacePoses(); |
|||
poseExtractor.UpdateModelSpacePoses(); |
|||
|
|||
Assert.AreEqual(0, poseExtractor.NumPoses); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestSingleBody() |
|||
{ |
|||
var go = new GameObject(); |
|||
var rootRb = go.AddComponent<Rigidbody>(); |
|||
var poseExtractor = new RigidBodyPoseExtractor(rootRb); |
|||
Assert.AreEqual(1, poseExtractor.NumPoses); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestTwoBodies() |
|||
{ |
|||
// * rootObj
|
|||
// - rb1
|
|||
// * go2
|
|||
// - rb2
|
|||
// - joint
|
|||
var rootObj = new GameObject(); |
|||
var rb1 = rootObj.AddComponent<Rigidbody>(); |
|||
|
|||
var go2 = new GameObject(); |
|||
var rb2 = go2.AddComponent<Rigidbody>(); |
|||
go2.transform.SetParent(rootObj.transform); |
|||
|
|||
var joint = go2.AddComponent<ConfigurableJoint>(); |
|||
joint.connectedBody = rb1; |
|||
|
|||
var poseExtractor = new RigidBodyPoseExtractor(rb1); |
|||
Assert.AreEqual(2, poseExtractor.NumPoses); |
|||
|
|||
rb1.position = new Vector3(1, 0, 0); |
|||
rb1.rotation = Quaternion.Euler(0, 13.37f, 0); |
|||
rb1.velocity = new Vector3(2, 0, 0); |
|||
|
|||
Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(0).position); |
|||
Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(0).rotation); |
|||
Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(0)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestTwoBodiesVirtualRoot() |
|||
{ |
|||
// * virtualRoot
|
|||
// * rootObj
|
|||
// - rb1
|
|||
// * go2
|
|||
// - rb2
|
|||
// - joint
|
|||
var virtualRoot = new GameObject("I am vroot"); |
|||
|
|||
var rootObj = new GameObject(); |
|||
var rb1 = rootObj.AddComponent<Rigidbody>(); |
|||
|
|||
var go2 = new GameObject(); |
|||
var rb2 = go2.AddComponent<Rigidbody>(); |
|||
go2.transform.SetParent(rootObj.transform); |
|||
|
|||
var joint = go2.AddComponent<ConfigurableJoint>(); |
|||
joint.connectedBody = rb1; |
|||
|
|||
var poseExtractor = new RigidBodyPoseExtractor(rb1, null, virtualRoot); |
|||
Assert.AreEqual(3, poseExtractor.NumPoses); |
|||
|
|||
// "body" 0 has no parent
|
|||
Assert.AreEqual(-1, poseExtractor.GetParentIndex(0)); |
|||
|
|||
// body 1 has parent 0
|
|||
Assert.AreEqual(0, poseExtractor.GetParentIndex(1)); |
|||
|
|||
var virtualRootPos = new Vector3(0,2,0); |
|||
var virtualRootRot = Quaternion.Euler(0, 42, 0); |
|||
virtualRoot.transform.position = virtualRootPos; |
|||
virtualRoot.transform.rotation = virtualRootRot; |
|||
|
|||
Assert.AreEqual(virtualRootPos, poseExtractor.GetPoseAt(0).position); |
|||
Assert.IsTrue(virtualRootRot == poseExtractor.GetPoseAt(0).rotation); |
|||
Assert.AreEqual(Vector3.zero, poseExtractor.GetLinearVelocityAt(0)); |
|||
|
|||
// Same as above test, but using index 1
|
|||
rb1.position = new Vector3(1, 0, 0); |
|||
rb1.rotation = Quaternion.Euler(0, 13.37f, 0); |
|||
rb1.velocity = new Vector3(2, 0, 0); |
|||
|
|||
Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(1).position); |
|||
Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(1).rotation); |
|||
Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(1)); |
|||
} |
|||
} |
|||
} |
|
|||
using UnityEngine; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Sensors; |
|||
using Unity.MLAgents.Extensions.Sensors; |
|||
|
|||
|
|||
namespace Unity.MLAgents.Extensions.Tests.Sensors |
|||
{ |
|||
|
|||
public static class SensorTestHelper |
|||
{ |
|||
public static void CompareObservation(ISensor sensor, float[] expected) |
|||
{ |
|||
string errorMessage; |
|||
bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage); |
|||
Assert.IsTrue(isOK, errorMessage); |
|||
} |
|||
} |
|||
|
|||
public class RigidBodySensorTests |
|||
{ |
|||
[Test] |
|||
public void TestNullRootBody() |
|||
{ |
|||
var gameObj = new GameObject(); |
|||
|
|||
var sensorComponent = gameObj.AddComponent<RigidBodySensorComponent>(); |
|||
var sensor = sensorComponent.CreateSensor(); |
|||
SensorTestHelper.CompareObservation(sensor, new float[0]); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestSingleRigidbody() |
|||
{ |
|||
var gameObj = new GameObject(); |
|||
var rootRb = gameObj.AddComponent<Rigidbody>(); |
|||
var sensorComponent = gameObj.AddComponent<RigidBodySensorComponent>(); |
|||
sensorComponent.RootBody = rootRb; |
|||
sensorComponent.Settings = new PhysicsSensorSettings |
|||
{ |
|||
UseModelSpaceLinearVelocity = true, |
|||
UseLocalSpaceTranslations = true, |
|||
UseLocalSpaceRotations = true |
|||
}; |
|||
|
|||
var sensor = sensorComponent.CreateSensor(); |
|||
sensor.Update(); |
|||
|
|||
// The root body is ignored since it always generates identity values
|
|||
// and there are no other bodies to generate observations.
|
|||
var expected = new float[0]; |
|||
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); |
|||
SensorTestHelper.CompareObservation(sensor, expected); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestBodiesWithJoint() |
|||
{ |
|||
var rootObj = new GameObject(); |
|||
var rootRb = rootObj.AddComponent<Rigidbody>(); |
|||
rootRb.velocity = new Vector3(1f, 0f, 0f); |
|||
|
|||
var middleGamObj = new GameObject(); |
|||
var middleRb = middleGamObj.AddComponent<Rigidbody>(); |
|||
middleRb.velocity = new Vector3(0f, 1f, 0f); |
|||
middleGamObj.transform.SetParent(rootObj.transform); |
|||
middleGamObj.transform.localPosition = new Vector3(13.37f, 0f, 0f); |
|||
var joint = middleGamObj.AddComponent<ConfigurableJoint>(); |
|||
joint.connectedBody = rootRb; |
|||
|
|||
var leafGameObj = new GameObject(); |
|||
var leafRb = leafGameObj.AddComponent<Rigidbody>(); |
|||
leafRb.velocity = new Vector3(0f, 0f, 1f); |
|||
leafGameObj.transform.SetParent(middleGamObj.transform); |
|||
leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f); |
|||
var joint2 = leafGameObj.AddComponent<ConfigurableJoint>(); |
|||
joint2.connectedBody = middleRb; |
|||
|
|||
var virtualRoot = new GameObject(); |
|||
|
|||
var sensorComponent = rootObj.AddComponent<RigidBodySensorComponent>(); |
|||
sensorComponent.RootBody = rootRb; |
|||
sensorComponent.Settings = new PhysicsSensorSettings |
|||
{ |
|||
UseModelSpaceTranslations = true, |
|||
UseLocalSpaceTranslations = true, |
|||
UseLocalSpaceLinearVelocity = true |
|||
}; |
|||
sensorComponent.VirtualRoot = virtualRoot; |
|||
|
|||
var sensor = sensorComponent.CreateSensor(); |
|||
sensor.Update(); |
|||
|
|||
// Note that the VirtualRoot is ignored from the observations
|
|||
var expected = new[] |
|||
{ |
|||
// Model space
|
|||
0f, 0f, 0f, // Root pos
|
|||
13.37f, 0f, 0f, // Middle pos
|
|||
leafGameObj.transform.position.x, 0f, 0f, // Leaf pos
|
|||
|
|||
// Local space
|
|||
0f, 0f, 0f, // Root pos
|
|||
13.37f, 0f, 0f, // Attached pos
|
|||
4.2f, 0f, 0f, // Leaf pos
|
|||
|
|||
1f, 0f, 0f, // Root vel (relative to virtual root)
|
|||
-1f, 1f, 0f, // Attached vel
|
|||
0f, -1f, 1f // Leaf vel
|
|||
}; |
|||
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); |
|||
SensorTestHelper.CompareObservation(sensor, expected); |
|||
|
|||
// Update the settings to only process joint observations
|
|||
sensorComponent.Settings = new PhysicsSensorSettings |
|||
{ |
|||
UseJointPositionsAndAngles = true, |
|||
UseJointForces = true, |
|||
}; |
|||
|
|||
sensor = sensorComponent.CreateSensor(); |
|||
sensor.Update(); |
|||
|
|||
expected = new[] |
|||
{ |
|||
0f, 0f, 0f, // joint1.force
|
|||
0f, 0f, 0f, // joint1.torque
|
|||
0f, 0f, 0f, // joint2.force
|
|||
0f, 0f, 0f, // joint2.torque
|
|||
}; |
|||
SensorTestHelper.CompareObservation(sensor, expected); |
|||
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]); |
|||
|
|||
} |
|||
} |
|||
} |
|
|||
from typing import Callable, List, Dict, Tuple, Optional |
|||
import attr |
|||
import abc |
|||
|
|||
import torch |
|||
from torch import nn |
|||
|
|||
from mlagents_envs.base_env import ActionType |
|||
from mlagents.trainers.torch.distributions import ( |
|||
GaussianDistribution, |
|||
MultiCategoricalDistribution, |
|||
DistInstance, |
|||
) |
|||
from mlagents.trainers.settings import NetworkSettings |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
from mlagents.trainers.torch.decoders import ValueHeads |
|||
|
|||
ActivationFunction = Callable[[torch.Tensor], torch.Tensor] |
|||
EncoderFunction = Callable[ |
|||
[torch.Tensor, int, ActivationFunction, int, str, bool], torch.Tensor |
|||
] |
|||
|
|||
EPSILON = 1e-7 |
|||
|
|||
|
|||
class NetworkBody(nn.Module): |
|||
def __init__( |
|||
self, |
|||
observation_shapes: List[Tuple[int, ...]], |
|||
network_settings: NetworkSettings, |
|||
encoded_act_size: int = 0, |
|||
): |
|||
super().__init__() |
|||
self.normalize = network_settings.normalize |
|||
self.use_lstm = network_settings.memory is not None |
|||
self.h_size = network_settings.hidden_units |
|||
self.m_size = ( |
|||
network_settings.memory.memory_size |
|||
if network_settings.memory is not None |
|||
else 0 |
|||
) |
|||
|
|||
self.visual_encoders, self.vector_encoders = ModelUtils.create_encoders( |
|||
observation_shapes, |
|||
self.h_size, |
|||
network_settings.num_layers, |
|||
network_settings.vis_encode_type, |
|||
unnormalized_inputs=encoded_act_size, |
|||
normalize=self.normalize, |
|||
) |
|||
|
|||
if self.use_lstm: |
|||
self.lstm = nn.LSTM(self.h_size, self.m_size // 2, 1) |
|||
else: |
|||
self.lstm = None |
|||
|
|||
def update_normalization(self, vec_inputs: List[torch.Tensor]) -> None: |
|||
for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders): |
|||
vec_enc.update_normalization(vec_input) |
|||
|
|||
def copy_normalization(self, other_network: "NetworkBody") -> None: |
|||
if self.normalize: |
|||
for n1, n2 in zip(self.vector_encoders, other_network.vector_encoders): |
|||
n1.copy_normalization(n2) |
|||
|
|||
def forward( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
actions: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
sequence_length: int = 1, |
|||
) -> Tuple[torch.Tensor, torch.Tensor]: |
|||
vec_encodes = [] |
|||
for idx, encoder in enumerate(self.vector_encoders): |
|||
vec_input = vec_inputs[idx] |
|||
if actions is not None: |
|||
hidden = encoder(vec_input, actions) |
|||
else: |
|||
hidden = encoder(vec_input) |
|||
vec_encodes.append(hidden) |
|||
|
|||
vis_encodes = [] |
|||
for idx, encoder in enumerate(self.visual_encoders): |
|||
vis_input = vis_inputs[idx] |
|||
vis_input = vis_input.permute([0, 3, 1, 2]) |
|||
hidden = encoder(vis_input) |
|||
vis_encodes.append(hidden) |
|||
|
|||
if len(vec_encodes) > 0 and len(vis_encodes) > 0: |
|||
vec_encodes_tensor = torch.stack(vec_encodes, dim=-1).sum(dim=-1) |
|||
vis_encodes_tensor = torch.stack(vis_encodes, dim=-1).sum(dim=-1) |
|||
encoding = torch.stack( |
|||
[vec_encodes_tensor, vis_encodes_tensor], dim=-1 |
|||
).sum(dim=-1) |
|||
elif len(vec_encodes) > 0: |
|||
encoding = torch.stack(vec_encodes, dim=-1).sum(dim=-1) |
|||
elif len(vis_encodes) > 0: |
|||
encoding = torch.stack(vis_encodes, dim=-1).sum(dim=-1) |
|||
else: |
|||
raise Exception("No valid inputs to network.") |
|||
|
|||
if self.use_lstm: |
|||
encoding = encoding.view([sequence_length, -1, self.h_size]) |
|||
memories = torch.split(memories, self.m_size // 2, dim=-1) |
|||
encoding, memories = self.lstm( |
|||
encoding.contiguous(), |
|||
(memories[0].contiguous(), memories[1].contiguous()), |
|||
) |
|||
encoding = encoding.view([-1, self.m_size // 2]) |
|||
memories = torch.cat(memories, dim=-1) |
|||
return encoding, memories |
|||
|
|||
|
|||
class ValueNetwork(nn.Module): |
|||
def __init__( |
|||
self, |
|||
stream_names: List[str], |
|||
observation_shapes: List[Tuple[int, ...]], |
|||
network_settings: NetworkSettings, |
|||
encoded_act_size: int = 0, |
|||
outputs_per_stream: int = 1, |
|||
): |
|||
|
|||
# This is not a typo, we want to call __init__ of nn.Module |
|||
nn.Module.__init__(self) |
|||
self.network_body = NetworkBody( |
|||
observation_shapes, network_settings, encoded_act_size=encoded_act_size |
|||
) |
|||
if network_settings.memory is not None: |
|||
encoding_size = network_settings.memory.memory_size // 2 |
|||
else: |
|||
encoding_size = network_settings.hidden_units |
|||
self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream) |
|||
|
|||
def forward( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
actions: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
sequence_length: int = 1, |
|||
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: |
|||
encoding, memories = self.network_body( |
|||
vec_inputs, vis_inputs, actions, memories, sequence_length |
|||
) |
|||
output = self.value_heads(encoding) |
|||
return output, memories |
|||
|
|||
|
|||
class Actor(abc.ABC): |
|||
@abc.abstractmethod |
|||
def update_normalization(self, vector_obs: List[torch.Tensor]) -> None: |
|||
""" |
|||
Updates normalization of Actor based on the provided List of vector obs. |
|||
:param vector_obs: A List of vector obs as tensors. |
|||
""" |
|||
pass |
|||
|
|||
@abc.abstractmethod |
|||
def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]: |
|||
""" |
|||
Takes a List of Distribution iinstances and samples an action from each. |
|||
""" |
|||
pass |
|||
|
|||
@abc.abstractmethod |
|||
def get_dists( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
masks: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
sequence_length: int = 1, |
|||
) -> Tuple[List[DistInstance], Optional[torch.Tensor]]: |
|||
""" |
|||
Returns distributions from this Actor, from which actions can be sampled. |
|||
If memory is enabled, return the memories as well. |
|||
:param vec_inputs: A List of vector inputs as tensors. |
|||
:param vis_inputs: A List of visual inputs as tensors. |
|||
:param masks: If using discrete actions, a Tensor of action masks. |
|||
:param memories: If using memory, a Tensor of initial memories. |
|||
:param sequence_length: If using memory, the sequence length. |
|||
:return: A Tuple of a List of action distribution instances, and memories. |
|||
Memories will be None if not using memory. |
|||
""" |
|||
pass |
|||
|
|||
@abc.abstractmethod |
|||
def forward( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
masks: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
sequence_length: int = 1, |
|||
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]: |
|||
""" |
|||
Forward pass of the Actor for inference. This is required for export to ONNX, and |
|||
the inputs and outputs of this method should not be changed without a respective change |
|||
in the ONNX export code. |
|||
""" |
|||
pass |
|||
|
|||
|
|||
class ActorCritic(Actor): |
|||
@abc.abstractmethod |
|||
def critic_pass( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
memories: Optional[torch.Tensor] = None, |
|||
) -> Dict[str, torch.Tensor]: |
|||
""" |
|||
Get value outputs for the given obs. |
|||
:param vec_inputs: List of vector inputs as tensors. |
|||
:param vis_inputs: List of visual inputs as tensors. |
|||
:param memories: Tensor of memories, if using memory. Otherwise, None. |
|||
:returns: Dict of reward stream to output tensor for values. |
|||
""" |
|||
pass |
|||
|
|||
@abc.abstractmethod |
|||
def get_dist_and_value( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
masks: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
sequence_length: int = 1, |
|||
) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]: |
|||
""" |
|||
Returns distributions, from which actions can be sampled, and value estimates. |
|||
If memory is enabled, return the memories as well. |
|||
:param vec_inputs: A List of vector inputs as tensors. |
|||
:param vis_inputs: A List of visual inputs as tensors. |
|||
:param masks: If using discrete actions, a Tensor of action masks. |
|||
:param memories: If using memory, a Tensor of initial memories. |
|||
:param sequence_length: If using memory, the sequence length. |
|||
:return: A Tuple of a List of action distribution instances, a Dict of reward signal |
|||
name to value estimate, and memories. Memories will be None if not using memory. |
|||
""" |
|||
pass |
|||
|
|||
|
|||
class SimpleActor(nn.Module, Actor): |
|||
def __init__( |
|||
self, |
|||
observation_shapes: List[Tuple[int, ...]], |
|||
network_settings: NetworkSettings, |
|||
act_type: ActionType, |
|||
act_size: List[int], |
|||
conditional_sigma: bool = False, |
|||
tanh_squash: bool = False, |
|||
): |
|||
super().__init__() |
|||
self.act_type = act_type |
|||
self.act_size = act_size |
|||
self.version_number = torch.nn.Parameter(torch.Tensor([2.0])) |
|||
self.memory_size = torch.nn.Parameter(torch.Tensor([0])) |
|||
self.is_continuous_int = torch.nn.Parameter( |
|||
torch.Tensor([int(act_type == ActionType.CONTINUOUS)]) |
|||
) |
|||
self.act_size_vector = torch.nn.Parameter(torch.Tensor(act_size)) |
|||
self.network_body = NetworkBody(observation_shapes, network_settings) |
|||
if network_settings.memory is not None: |
|||
self.encoding_size = network_settings.memory.memory_size // 2 |
|||
else: |
|||
self.encoding_size = network_settings.hidden_units |
|||
if self.act_type == ActionType.CONTINUOUS: |
|||
self.distribution = GaussianDistribution( |
|||
self.encoding_size, |
|||
act_size[0], |
|||
conditional_sigma=conditional_sigma, |
|||
tanh_squash=tanh_squash, |
|||
) |
|||
else: |
|||
self.distribution = MultiCategoricalDistribution( |
|||
self.encoding_size, act_size |
|||
) |
|||
|
|||
def update_normalization(self, vector_obs: List[torch.Tensor]) -> None: |
|||
self.network_body.update_normalization(vector_obs) |
|||
|
|||
def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]: |
|||
actions = [] |
|||
for action_dist in dists: |
|||
action = action_dist.sample() |
|||
actions.append(action) |
|||
return actions |
|||
|
|||
def get_dists( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
masks: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
sequence_length: int = 1, |
|||
) -> Tuple[List[DistInstance], Optional[torch.Tensor]]: |
|||
encoding, memories = self.network_body( |
|||
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length |
|||
) |
|||
if self.act_type == ActionType.CONTINUOUS: |
|||
dists = self.distribution(encoding) |
|||
else: |
|||
dists = self.distribution(encoding, masks) |
|||
|
|||
return dists, memories |
|||
|
|||
def forward( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
masks: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
sequence_length: int = 1, |
|||
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]: |
|||
""" |
|||
Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs. |
|||
""" |
|||
dists, _ = self.get_dists( |
|||
vec_inputs, vis_inputs, masks, memories, sequence_length |
|||
) |
|||
action_list = self.sample_action(dists) |
|||
sampled_actions = torch.stack(action_list, dim=-1) |
|||
return ( |
|||
sampled_actions, |
|||
dists[0].pdf(sampled_actions), |
|||
self.version_number, |
|||
self.memory_size, |
|||
self.is_continuous_int, |
|||
self.act_size_vector, |
|||
) |
|||
|
|||
|
|||
class SharedActorCritic(SimpleActor, ActorCritic): |
|||
def __init__( |
|||
self, |
|||
observation_shapes: List[Tuple[int, ...]], |
|||
network_settings: NetworkSettings, |
|||
act_type: ActionType, |
|||
act_size: List[int], |
|||
stream_names: List[str], |
|||
conditional_sigma: bool = False, |
|||
tanh_squash: bool = False, |
|||
): |
|||
super().__init__( |
|||
observation_shapes, |
|||
network_settings, |
|||
act_type, |
|||
act_size, |
|||
conditional_sigma, |
|||
tanh_squash, |
|||
) |
|||
self.stream_names = stream_names |
|||
self.value_heads = ValueHeads(stream_names, self.encoding_size) |
|||
|
|||
def critic_pass( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
memories: Optional[torch.Tensor] = None, |
|||
) -> Dict[str, torch.Tensor]: |
|||
encoding, _ = self.network_body(vec_inputs, vis_inputs, memories=memories) |
|||
return self.value_heads(encoding) |
|||
|
|||
def get_dist_and_value( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
masks: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
sequence_length: int = 1, |
|||
) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]: |
|||
encoding, memories = self.network_body( |
|||
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length |
|||
) |
|||
if self.act_type == ActionType.CONTINUOUS: |
|||
dists = self.distribution(encoding) |
|||
else: |
|||
dists = self.distribution(encoding, masks=masks) |
|||
|
|||
value_outputs = self.value_heads(encoding) |
|||
return dists, value_outputs, memories |
|||
|
|||
|
|||
class SeparateActorCritic(SimpleActor, ActorCritic): |
|||
def __init__( |
|||
self, |
|||
observation_shapes: List[Tuple[int, ...]], |
|||
network_settings: NetworkSettings, |
|||
act_type: ActionType, |
|||
act_size: List[int], |
|||
stream_names: List[str], |
|||
conditional_sigma: bool = False, |
|||
tanh_squash: bool = False, |
|||
): |
|||
# Give the Actor only half the memories. Note we previously validate |
|||
# that memory_size must be a multiple of 4. |
|||
self.use_lstm = network_settings.memory is not None |
|||
if network_settings.memory is not None: |
|||
self.half_mem_size = network_settings.memory.memory_size // 2 |
|||
new_memory_settings = attr.evolve( |
|||
network_settings.memory, memory_size=self.half_mem_size |
|||
) |
|||
use_network_settings = attr.evolve( |
|||
network_settings, memory=new_memory_settings |
|||
) |
|||
else: |
|||
use_network_settings = network_settings |
|||
self.half_mem_size = 0 |
|||
super().__init__( |
|||
observation_shapes, |
|||
use_network_settings, |
|||
act_type, |
|||
act_size, |
|||
conditional_sigma, |
|||
tanh_squash, |
|||
) |
|||
self.stream_names = stream_names |
|||
self.critic = ValueNetwork( |
|||
stream_names, observation_shapes, use_network_settings |
|||
) |
|||
|
|||
def critic_pass( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
memories: Optional[torch.Tensor] = None, |
|||
) -> Dict[str, torch.Tensor]: |
|||
if self.use_lstm: |
|||
# Use only the back half of memories for critic |
|||
_, critic_mem = torch.split(memories, self.half_mem_size, -1) |
|||
else: |
|||
critic_mem = None |
|||
value_outputs, _memories = self.critic( |
|||
vec_inputs, vis_inputs, memories=critic_mem |
|||
) |
|||
return value_outputs |
|||
|
|||
def get_dist_and_value( |
|||
self, |
|||
vec_inputs: List[torch.Tensor], |
|||
vis_inputs: List[torch.Tensor], |
|||
masks: Optional[torch.Tensor] = None, |
|||
memories: Optional[torch.Tensor] = None, |
|||
sequence_length: int = 1, |
|||
) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]: |
|||
if self.use_lstm: |
|||
# Use only the back half of memories for critic and actor |
|||
actor_mem, critic_mem = torch.split(memories, self.half_mem_size, dim=-1) |
|||
else: |
|||
critic_mem = None |
|||
actor_mem = None |
|||
dists, actor_mem_outs = self.get_dists( |
|||
vec_inputs, |
|||
vis_inputs, |
|||
memories=actor_mem, |
|||
sequence_length=sequence_length, |
|||
masks=masks, |
|||
) |
|||
value_outputs, critic_mem_outs = self.critic( |
|||
vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length |
|||
) |
|||
if self.use_lstm: |
|||
mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=1) |
|||
else: |
|||
mem_out = None |
|||
return dists, value_outputs, mem_out |
|||
|
|||
|
|||
class GlobalSteps(nn.Module): |
|||
def __init__(self): |
|||
super().__init__() |
|||
self.global_step = torch.Tensor([0]) |
|||
|
|||
def increment(self, value): |
|||
self.global_step += value |
|||
|
|||
|
|||
class LearningRate(nn.Module): |
|||
def __init__(self, lr): |
|||
# Todo: add learning rate decay |
|||
super().__init__() |
|||
self.learning_rate = torch.Tensor([lr]) |
|
|||
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import ( # noqa F401 |
|||
BaseRewardProvider, |
|||
) |
|||
from mlagents.trainers.torch.components.reward_providers.extrinsic_reward_provider import ( # noqa F401 |
|||
ExtrinsicRewardProvider, |
|||
) |
|||
from mlagents.trainers.torch.components.reward_providers.curiosity_reward_provider import ( # noqa F401 |
|||
CuriosityRewardProvider, |
|||
) |
|||
from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import ( # noqa F401 |
|||
GAILRewardProvider, |
|||
) |
|||
from mlagents.trainers.torch.components.reward_providers.reward_provider_factory import ( # noqa F401 |
|||
create_reward_provider, |
|||
) |
|
|||
import numpy as np |
|||
from abc import ABC, abstractmethod |
|||
from typing import Dict |
|||
|
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents.trainers.settings import RewardSignalSettings |
|||
from mlagents_envs.base_env import BehaviorSpec |
|||
|
|||
|
|||
class BaseRewardProvider(ABC): |
|||
def __init__(self, specs: BehaviorSpec, settings: RewardSignalSettings) -> None: |
|||
self._policy_specs = specs |
|||
self._gamma = settings.gamma |
|||
self._strength = settings.strength |
|||
self._ignore_done = False |
|||
|
|||
@property |
|||
def gamma(self) -> float: |
|||
""" |
|||
The discount factor for the reward signal |
|||
""" |
|||
return self._gamma |
|||
|
|||
@property |
|||
def strength(self) -> float: |
|||
""" |
|||
The strength multiplier of the reward provider |
|||
""" |
|||
return self._strength |
|||
|
|||
@property |
|||
def name(self) -> str: |
|||
""" |
|||
The name of the reward provider. Is used for reporting and identification |
|||
""" |
|||
class_name = self.__class__.__name__ |
|||
return class_name.replace("RewardProvider", "") |
|||
|
|||
@property |
|||
def ignore_done(self) -> bool: |
|||
""" |
|||
If true, when the agent is done, the rewards of the next episode must be |
|||
used to calculate the return of the current episode. |
|||
Is used to mitigate the positive bias in rewards with no natural end. |
|||
""" |
|||
return self._ignore_done |
|||
|
|||
@abstractmethod |
|||
def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray: |
|||
""" |
|||
Evaluates the reward for the data present in the Dict mini_batch. Use this when evaluating a reward |
|||
function drawn straight from a Buffer. |
|||
:param mini_batch: A Dict of numpy arrays (the format used by our Buffer) |
|||
when drawing from the update buffer. |
|||
:return: a np.ndarray of rewards generated by the reward provider |
|||
""" |
|||
raise NotImplementedError( |
|||
"The reward provider's evaluate method has not been implemented " |
|||
) |
|||
|
|||
@abstractmethod |
|||
def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]: |
|||
""" |
|||
Update the reward for the data present in the Dict mini_batch. Use this when updating a reward |
|||
function drawn straight from a Buffer. |
|||
:param mini_batch: A Dict of numpy arrays (the format used by our Buffer) |
|||
when drawing from the update buffer. |
|||
:return: A dictionary from string to stats values |
|||
""" |
|||
raise NotImplementedError( |
|||
"The reward provider's update method has not been implemented " |
|||
) |
|
|||
import numpy as np |
|||
from typing import Dict |
|||
|
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import ( |
|||
BaseRewardProvider, |
|||
) |
|||
|
|||
|
|||
class ExtrinsicRewardProvider(BaseRewardProvider): |
|||
def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray: |
|||
return np.array(mini_batch["environment_rewards"], dtype=np.float32) |
|||
|
|||
def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]: |
|||
return {} |
|
|||
from typing import Dict, Type |
|||
from mlagents.trainers.exception import UnityTrainerException |
|||
|
|||
from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType |
|||
|
|||
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import ( |
|||
BaseRewardProvider, |
|||
) |
|||
from mlagents.trainers.torch.components.reward_providers.extrinsic_reward_provider import ( |
|||
ExtrinsicRewardProvider, |
|||
) |
|||
from mlagents.trainers.torch.components.reward_providers.curiosity_reward_provider import ( |
|||
CuriosityRewardProvider, |
|||
) |
|||
from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import ( |
|||
GAILRewardProvider, |
|||
) |
|||
|
|||
from mlagents_envs.base_env import BehaviorSpec |
|||
|
|||
NAME_TO_CLASS: Dict[RewardSignalType, Type[BaseRewardProvider]] = { |
|||
RewardSignalType.EXTRINSIC: ExtrinsicRewardProvider, |
|||
RewardSignalType.CURIOSITY: CuriosityRewardProvider, |
|||
RewardSignalType.GAIL: GAILRewardProvider, |
|||
} |
|||
|
|||
|
|||
def create_reward_provider( |
|||
name: RewardSignalType, specs: BehaviorSpec, settings: RewardSignalSettings |
|||
) -> BaseRewardProvider: |
|||
""" |
|||
Creates a reward provider class based on the name and config entry provided as a dict. |
|||
:param name: The name of the reward signal |
|||
:param specs: The BehaviorSpecs of the policy |
|||
:param settings: The RewardSignalSettings for that reward signal |
|||
:return: The reward signal class instantiated |
|||
""" |
|||
rcls = NAME_TO_CLASS.get(name) |
|||
if not rcls: |
|||
raise UnityTrainerException(f"Unknown reward signal type {name}") |
|||
|
|||
class_inst = rcls(specs, settings) |
|||
return class_inst |
|
|||
import numpy as np |
|||
from typing import Dict |
|||
import torch |
|||
|
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import ( |
|||
BaseRewardProvider, |
|||
) |
|||
from mlagents.trainers.settings import CuriositySettings |
|||
|
|||
from mlagents_envs.base_env import BehaviorSpec |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
from mlagents.trainers.torch.networks import NetworkBody |
|||
from mlagents.trainers.torch.layers import linear_layer, Swish |
|||
from mlagents.trainers.settings import NetworkSettings, EncoderType |
|||
|
|||
|
|||
class CuriosityRewardProvider(BaseRewardProvider): |
|||
beta = 0.2 # Forward vs Inverse loss weight |
|||
loss_multiplier = 10.0 # Loss multiplier |
|||
|
|||
def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None: |
|||
super().__init__(specs, settings) |
|||
self._ignore_done = True |
|||
self._network = CuriosityNetwork(specs, settings) |
|||
self.optimizer = torch.optim.Adam( |
|||
self._network.parameters(), lr=settings.learning_rate |
|||
) |
|||
self._has_updated_once = False |
|||
|
|||
def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray: |
|||
with torch.no_grad(): |
|||
rewards = self._network.compute_reward(mini_batch).detach().cpu().numpy() |
|||
rewards = np.minimum(rewards, 1.0 / self.strength) |
|||
return rewards * self._has_updated_once |
|||
|
|||
def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]: |
|||
self._has_updated_once = True |
|||
forward_loss = self._network.compute_forward_loss(mini_batch) |
|||
inverse_loss = self._network.compute_inverse_loss(mini_batch) |
|||
|
|||
loss = self.loss_multiplier * ( |
|||
self.beta * forward_loss + (1.0 - self.beta) * inverse_loss |
|||
) |
|||
self.optimizer.zero_grad() |
|||
loss.backward() |
|||
self.optimizer.step() |
|||
return { |
|||
"Losses/Curiosity Forward Loss": forward_loss.detach().cpu().numpy(), |
|||
"Losses/Curiosity Inverse Loss": inverse_loss.detach().cpu().numpy(), |
|||
} |
|||
|
|||
|
|||
class CuriosityNetwork(torch.nn.Module): |
|||
EPSILON = 1e-10 |
|||
|
|||
def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None: |
|||
super().__init__() |
|||
self._policy_specs = specs |
|||
state_encoder_settings = NetworkSettings( |
|||
normalize=False, |
|||
hidden_units=settings.encoding_size, |
|||
num_layers=2, |
|||
vis_encode_type=EncoderType.SIMPLE, |
|||
memory=None, |
|||
) |
|||
self._state_encoder = NetworkBody( |
|||
specs.observation_shapes, state_encoder_settings |
|||
) |
|||
|
|||
self._action_flattener = ModelUtils.ActionFlattener(specs) |
|||
|
|||
self.inverse_model_action_predition = torch.nn.Sequential( |
|||
linear_layer(2 * settings.encoding_size, 256), |
|||
Swish(), |
|||
linear_layer(256, self._action_flattener.flattened_size), |
|||
) |
|||
|
|||
self.forward_model_next_state_prediction = torch.nn.Sequential( |
|||
linear_layer( |
|||
settings.encoding_size + self._action_flattener.flattened_size, 256 |
|||
), |
|||
Swish(), |
|||
linear_layer(256, settings.encoding_size), |
|||
) |
|||
|
|||
def get_current_state(self, mini_batch: AgentBuffer) -> torch.Tensor: |
|||
""" |
|||
Extracts the current state embedding from a mini_batch. |
|||
""" |
|||
n_vis = len(self._state_encoder.visual_encoders) |
|||
hidden, _ = self._state_encoder.forward( |
|||
vec_inputs=[ |
|||
ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float) |
|||
], |
|||
vis_inputs=[ |
|||
ModelUtils.list_to_tensor( |
|||
mini_batch["visual_obs%d" % i], dtype=torch.float |
|||
) |
|||
for i in range(n_vis) |
|||
], |
|||
) |
|||
return hidden |
|||
|
|||
def get_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor: |
|||
""" |
|||
Extracts the next state embedding from a mini_batch. |
|||
""" |
|||
n_vis = len(self._state_encoder.visual_encoders) |
|||
hidden, _ = self._state_encoder.forward( |
|||
vec_inputs=[ |
|||
ModelUtils.list_to_tensor( |
|||
mini_batch["next_vector_in"], dtype=torch.float |
|||
) |
|||
], |
|||
vis_inputs=[ |
|||
ModelUtils.list_to_tensor( |
|||
mini_batch["next_visual_obs%d" % i], dtype=torch.float |
|||
) |
|||
for i in range(n_vis) |
|||
], |
|||
) |
|||
return hidden |
|||
|
|||
def predict_action(self, mini_batch: AgentBuffer) -> torch.Tensor: |
|||
""" |
|||
In the continuous case, returns the predicted action. |
|||
In the discrete case, returns the logits. |
|||
""" |
|||
inverse_model_input = torch.cat( |
|||
(self.get_current_state(mini_batch), self.get_next_state(mini_batch)), dim=1 |
|||
) |
|||
hidden = self.inverse_model_action_predition(inverse_model_input) |
|||
if self._policy_specs.is_action_continuous(): |
|||
return hidden |
|||
else: |
|||
branches = ModelUtils.break_into_branches( |
|||
hidden, self._policy_specs.discrete_action_branches |
|||
) |
|||
branches = [torch.softmax(b, dim=1) for b in branches] |
|||
return torch.cat(branches, dim=1) |
|||
|
|||
def predict_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor: |
|||
""" |
|||
Uses the current state embedding and the action of the mini_batch to predict |
|||
the next state embedding. |
|||
""" |
|||
if self._policy_specs.is_action_continuous(): |
|||
action = ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float) |
|||
else: |
|||
action = torch.cat( |
|||
ModelUtils.actions_to_onehot( |
|||
ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long), |
|||
self._policy_specs.discrete_action_branches, |
|||
), |
|||
dim=1, |
|||
) |
|||
forward_model_input = torch.cat( |
|||
(self.get_current_state(mini_batch), action), dim=1 |
|||
) |
|||
|
|||
return self.forward_model_next_state_prediction(forward_model_input) |
|||
|
|||
def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor: |
|||
""" |
|||
Computes the inverse loss for a mini_batch. Corresponds to the error on the |
|||
action prediction (given the current and next state). |
|||
""" |
|||
predicted_action = self.predict_action(mini_batch) |
|||
if self._policy_specs.is_action_continuous(): |
|||
sq_difference = ( |
|||
ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float) |
|||
- predicted_action |
|||
) ** 2 |
|||
sq_difference = torch.sum(sq_difference, dim=1) |
|||
return torch.mean( |
|||
ModelUtils.dynamic_partition( |
|||
sq_difference, |
|||
ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float), |
|||
2, |
|||
)[1] |
|||
) |
|||
else: |
|||
true_action = torch.cat( |
|||
ModelUtils.actions_to_onehot( |
|||
ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long), |
|||
self._policy_specs.discrete_action_branches, |
|||
), |
|||
dim=1, |
|||
) |
|||
cross_entropy = torch.sum( |
|||
-torch.log(predicted_action + self.EPSILON) * true_action, dim=1 |
|||
) |
|||
return torch.mean( |
|||
ModelUtils.dynamic_partition( |
|||
cross_entropy, |
|||
ModelUtils.list_to_tensor( |
|||
mini_batch["masks"], dtype=torch.float |
|||
), # use masks not action_masks |
|||
2, |
|||
)[1] |
|||
) |
|||
|
|||
def compute_reward(self, mini_batch: AgentBuffer) -> torch.Tensor: |
|||
""" |
|||
Calculates the curiosity reward for the mini_batch. Corresponds to the error |
|||
between the predicted and actual next state. |
|||
""" |
|||
predicted_next_state = self.predict_next_state(mini_batch) |
|||
target = self.get_next_state(mini_batch) |
|||
sq_difference = 0.5 * (target - predicted_next_state) ** 2 |
|||
sq_difference = torch.sum(sq_difference, dim=1) |
|||
return sq_difference |
|||
|
|||
def compute_forward_loss(self, mini_batch: AgentBuffer) -> torch.Tensor: |
|||
""" |
|||
Computes the loss for the next state prediction |
|||
""" |
|||
return torch.mean( |
|||
ModelUtils.dynamic_partition( |
|||
self.compute_reward(mini_batch), |
|||
ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float), |
|||
2, |
|||
)[1] |
|||
) |
|
|||
from typing import Optional, Dict |
|||
import numpy as np |
|||
import torch |
|||
|
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import ( |
|||
BaseRewardProvider, |
|||
) |
|||
from mlagents.trainers.settings import GAILSettings |
|||
from mlagents_envs.base_env import BehaviorSpec |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
from mlagents.trainers.torch.networks import NetworkBody |
|||
from mlagents.trainers.torch.layers import linear_layer, Swish, Initialization |
|||
from mlagents.trainers.settings import NetworkSettings, EncoderType |
|||
from mlagents.trainers.demo_loader import demo_to_buffer |
|||
|
|||
|
|||
class GAILRewardProvider(BaseRewardProvider): |
|||
def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None: |
|||
super().__init__(specs, settings) |
|||
self._ignore_done = True |
|||
self._discriminator_network = DiscriminatorNetwork(specs, settings) |
|||
_, self._demo_buffer = demo_to_buffer( |
|||
settings.demo_path, 1, specs |
|||
) # This is supposed to be the sequence length but we do not have access here |
|||
params = list(self._discriminator_network.parameters()) |
|||
self.optimizer = torch.optim.Adam(params, lr=settings.learning_rate) |
|||
|
|||
def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray: |
|||
with torch.no_grad(): |
|||
estimates, _ = self._discriminator_network.compute_estimate( |
|||
mini_batch, use_vail_noise=False |
|||
) |
|||
return ( |
|||
-torch.log( |
|||
1.0 |
|||
- estimates.squeeze(dim=1) |
|||
* (1.0 - self._discriminator_network.EPSILON) |
|||
) |
|||
.detach() |
|||
.cpu() |
|||
.numpy() |
|||
) |
|||
|
|||
def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]: |
|||
expert_batch = self._demo_buffer.sample_mini_batch( |
|||
mini_batch.num_experiences, 1 |
|||
) |
|||
loss, policy_mean_estimate, expert_mean_estimate, kl_loss = self._discriminator_network.compute_loss( |
|||
mini_batch, expert_batch |
|||
) |
|||
self.optimizer.zero_grad() |
|||
loss.backward() |
|||
self.optimizer.step() |
|||
stats_dict = { |
|||
"Losses/GAIL Discriminator Loss": loss.detach().cpu().numpy(), |
|||
"Policy/GAIL Policy Estimate": policy_mean_estimate.detach().cpu().numpy(), |
|||
"Policy/GAIL Expert Estimate": expert_mean_estimate.detach().cpu().numpy(), |
|||
} |
|||
if self._discriminator_network.use_vail: |
|||
stats_dict["Policy/GAIL Beta"] = ( |
|||
self._discriminator_network.beta.detach().cpu().numpy() |
|||
) |
|||
stats_dict["Losses/GAIL KL Loss"] = kl_loss.detach().cpu().numpy() |
|||
return stats_dict |
|||
|
|||
|
|||
class DiscriminatorNetwork(torch.nn.Module): |
|||
gradient_penalty_weight = 10.0 |
|||
z_size = 128 |
|||
alpha = 0.0005 |
|||
mutual_information = 0.5 |
|||
EPSILON = 1e-7 |
|||
initial_beta = 0.0 |
|||
|
|||
def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None: |
|||
super().__init__() |
|||
self._policy_specs = specs |
|||
self.use_vail = settings.use_vail |
|||
self._settings = settings |
|||
|
|||
state_encoder_settings = NetworkSettings( |
|||
normalize=False, |
|||
hidden_units=settings.encoding_size, |
|||
num_layers=2, |
|||
vis_encode_type=EncoderType.SIMPLE, |
|||
memory=None, |
|||
) |
|||
self._state_encoder = NetworkBody( |
|||
specs.observation_shapes, state_encoder_settings |
|||
) |
|||
|
|||
self._action_flattener = ModelUtils.ActionFlattener(specs) |
|||
|
|||
encoder_input_size = settings.encoding_size |
|||
if settings.use_actions: |
|||
encoder_input_size += ( |
|||
self._action_flattener.flattened_size + 1 |
|||
) # + 1 is for done |
|||
|
|||
self.encoder = torch.nn.Sequential( |
|||
linear_layer(encoder_input_size, settings.encoding_size), |
|||
Swish(), |
|||
linear_layer(settings.encoding_size, settings.encoding_size), |
|||
Swish(), |
|||
) |
|||
|
|||
estimator_input_size = settings.encoding_size |
|||
if settings.use_vail: |
|||
estimator_input_size = self.z_size |
|||
self.z_sigma = torch.nn.Parameter( |
|||
torch.ones((self.z_size), dtype=torch.float), requires_grad=True |
|||
) |
|||
self.z_mu_layer = linear_layer( |
|||
settings.encoding_size, |
|||
self.z_size, |
|||
kernel_init=Initialization.KaimingHeNormal, |
|||
kernel_gain=0.1, |
|||
) |
|||
self.beta = torch.nn.Parameter( |
|||
torch.tensor(self.initial_beta, dtype=torch.float), requires_grad=False |
|||
) |
|||
|
|||
self.estimator = torch.nn.Sequential( |
|||
linear_layer(estimator_input_size, 1), torch.nn.Sigmoid() |
|||
) |
|||
|
|||
def get_action_input(self, mini_batch: AgentBuffer) -> torch.Tensor: |
|||
""" |
|||
Creates the action Tensor. In continuous case, corresponds to the action. In |
|||
the discrete case, corresponds to the concatenation of one hot action Tensors. |
|||
""" |
|||
return self._action_flattener.forward( |
|||
torch.as_tensor(mini_batch["actions"], dtype=torch.float) |
|||
) |
|||
|
|||
def get_state_encoding(self, mini_batch: AgentBuffer) -> torch.Tensor: |
|||
""" |
|||
Creates the observation input. |
|||
""" |
|||
n_vis = len(self._state_encoder.visual_encoders) |
|||
hidden, _ = self._state_encoder.forward( |
|||
vec_inputs=[torch.as_tensor(mini_batch["vector_obs"], dtype=torch.float)], |
|||
vis_inputs=[ |
|||
torch.as_tensor(mini_batch["visual_obs%d" % i], dtype=torch.float) |
|||
for i in range(n_vis) |
|||
], |
|||
) |
|||
return hidden |
|||
|
|||
def compute_estimate( |
|||
self, mini_batch: AgentBuffer, use_vail_noise: bool = False |
|||
) -> torch.Tensor: |
|||
""" |
|||
Given a mini_batch, computes the estimate (How much the discriminator believes |
|||
the data was sampled from the demonstration data). |
|||
:param mini_batch: The AgentBuffer of data |
|||
:param use_vail_noise: Only when using VAIL : If true, will sample the code, if |
|||
false, will return the mean of the code. |
|||
""" |
|||
encoder_input = self.get_state_encoding(mini_batch) |
|||
if self._settings.use_actions: |
|||
actions = self.get_action_input(mini_batch) |
|||
dones = torch.as_tensor(mini_batch["done"], dtype=torch.float) |
|||
encoder_input = torch.cat([encoder_input, actions, dones], dim=1) |
|||
hidden = self.encoder(encoder_input) |
|||
z_mu: Optional[torch.Tensor] = None |
|||
if self._settings.use_vail: |
|||
z_mu = self.z_mu_layer(hidden) |
|||
hidden = torch.normal(z_mu, self.z_sigma * use_vail_noise) |
|||
estimate = self.estimator(hidden) |
|||
return estimate, z_mu |
|||
|
|||
def compute_loss( |
|||
self, policy_batch: AgentBuffer, expert_batch: AgentBuffer |
|||
) -> torch.Tensor: |
|||
""" |
|||
Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator. |
|||
""" |
|||
policy_estimate, policy_mu = self.compute_estimate( |
|||
policy_batch, use_vail_noise=True |
|||
) |
|||
expert_estimate, expert_mu = self.compute_estimate( |
|||
expert_batch, use_vail_noise=True |
|||
) |
|||
loss = -( |
|||
torch.log(expert_estimate * (1 - self.EPSILON)) |
|||
+ torch.log(1.0 - policy_estimate * (1 - self.EPSILON)) |
|||
).mean() |
|||
kl_loss: Optional[torch.Tensor] = None |
|||
if self._settings.use_vail: |
|||
# KL divergence loss (encourage latent representation to be normal) |
|||
kl_loss = torch.mean( |
|||
-torch.sum( |
|||
1 |
|||
+ (self.z_sigma ** 2).log() |
|||
- 0.5 * expert_mu ** 2 |
|||
- 0.5 * policy_mu ** 2 |
|||
- (self.z_sigma ** 2), |
|||
dim=1, |
|||
) |
|||
) |
|||
vail_loss = self.beta * (kl_loss - self.mutual_information) |
|||
with torch.no_grad(): |
|||
self.beta.data = torch.max( |
|||
self.beta + self.alpha * (kl_loss - self.mutual_information), |
|||
torch.tensor(0.0), |
|||
) |
|||
loss += vail_loss |
|||
if self.gradient_penalty_weight > 0.0: |
|||
loss += self.gradient_penalty_weight * self.compute_gradient_magnitude( |
|||
policy_batch, expert_batch |
|||
) |
|||
return loss, torch.mean(policy_estimate), torch.mean(expert_estimate), kl_loss |
|||
|
|||
def compute_gradient_magnitude( |
|||
self, policy_batch: AgentBuffer, expert_batch: AgentBuffer |
|||
) -> torch.Tensor: |
|||
""" |
|||
Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability esp. |
|||
for off-policy. Compute gradients w.r.t randomly interpolated input. |
|||
""" |
|||
policy_obs = self.get_state_encoding(policy_batch) |
|||
expert_obs = self.get_state_encoding(expert_batch) |
|||
obs_epsilon = torch.rand(policy_obs.shape) |
|||
encoder_input = obs_epsilon * policy_obs + (1 - obs_epsilon) * expert_obs |
|||
if self._settings.use_actions: |
|||
policy_action = self.get_action_input(policy_batch) |
|||
expert_action = self.get_action_input(policy_batch) |
|||
action_epsilon = torch.rand(policy_action.shape) |
|||
policy_dones = torch.as_tensor(policy_batch["done"], dtype=torch.float) |
|||
expert_dones = torch.as_tensor(expert_batch["done"], dtype=torch.float) |
|||
dones_epsilon = torch.rand(policy_dones.shape) |
|||
encoder_input = torch.cat( |
|||
[ |
|||
encoder_input, |
|||
action_epsilon * policy_action |
|||
+ (1 - action_epsilon) * expert_action, |
|||
dones_epsilon * policy_dones + (1 - dones_epsilon) * expert_dones, |
|||
], |
|||
dim=1, |
|||
) |
|||
hidden = self.encoder(encoder_input) |
|||
if self._settings.use_vail: |
|||
use_vail_noise = True |
|||
z_mu = self.z_mu_layer(hidden) |
|||
hidden = torch.normal(z_mu, self.z_sigma * use_vail_noise) |
|||
hidden = self.estimator(hidden) |
|||
estimate = torch.mean(torch.sum(hidden, dim=1)) |
|||
gradient = torch.autograd.grad(estimate, encoder_input)[0] |
|||
# Norm's gradient could be NaN at 0. Use our own safe_norm |
|||
safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt() |
|||
gradient_mag = torch.mean((safe_norm - 1) ** 2) |
|||
return gradient_mag |
|
|||
from typing import List, Dict |
|||
|
|||
import torch |
|||
from torch import nn |
|||
from mlagents.trainers.torch.layers import linear_layer |
|||
|
|||
|
|||
class ValueHeads(nn.Module): |
|||
def __init__(self, stream_names: List[str], input_size: int, output_size: int = 1): |
|||
super().__init__() |
|||
self.stream_names = stream_names |
|||
_value_heads = {} |
|||
|
|||
for name in stream_names: |
|||
value = linear_layer(input_size, output_size) |
|||
_value_heads[name] = value |
|||
self.value_heads = nn.ModuleDict(_value_heads) |
|||
|
|||
def forward(self, hidden: torch.Tensor) -> Dict[str, torch.Tensor]: |
|||
value_outputs = {} |
|||
for stream_name, head in self.value_heads.items(): |
|||
value_outputs[stream_name] = head(hidden).squeeze(-1) |
|||
return value_outputs |
|
|||
import abc |
|||
from typing import List |
|||
import torch |
|||
from torch import nn |
|||
import numpy as np |
|||
import math |
|||
from mlagents.trainers.torch.layers import linear_layer, Initialization |
|||
|
|||
EPSILON = 1e-7 # Small value to avoid divide by zero |
|||
|
|||
|
|||
class DistInstance(nn.Module, abc.ABC): |
|||
@abc.abstractmethod |
|||
def sample(self) -> torch.Tensor: |
|||
""" |
|||
Return a sample from this distribution. |
|||
""" |
|||
pass |
|||
|
|||
@abc.abstractmethod |
|||
def log_prob(self, value: torch.Tensor) -> torch.Tensor: |
|||
""" |
|||
Returns the log probabilities of a particular value. |
|||
:param value: A value sampled from the distribution. |
|||
:returns: Log probabilities of the given value. |
|||
""" |
|||
pass |
|||
|
|||
@abc.abstractmethod |
|||
def entropy(self) -> torch.Tensor: |
|||
""" |
|||
Returns the entropy of this distribution. |
|||
""" |
|||
pass |
|||
|
|||
|
|||
class DiscreteDistInstance(DistInstance): |
|||
@abc.abstractmethod |
|||
def all_log_prob(self) -> torch.Tensor: |
|||
""" |
|||
Returns the log probabilities of all actions represented by this distribution. |
|||
""" |
|||
pass |
|||
|
|||
|
|||
class GaussianDistInstance(DistInstance): |
|||
def __init__(self, mean, std): |
|||
super().__init__() |
|||
self.mean = mean |
|||
self.std = std |
|||
|
|||
def sample(self): |
|||
sample = self.mean + torch.randn_like(self.mean) * self.std |
|||
return sample |
|||
|
|||
def log_prob(self, value): |
|||
var = self.std ** 2 |
|||
log_scale = torch.log(self.std + EPSILON) |
|||
return ( |
|||
-((value - self.mean) ** 2) / (2 * var + EPSILON) |
|||
- log_scale |
|||
- math.log(math.sqrt(2 * math.pi)) |
|||
) |
|||
|
|||
def pdf(self, value): |
|||
log_prob = self.log_prob(value) |
|||
return torch.exp(log_prob) |
|||
|
|||
def entropy(self): |
|||
return 0.5 * torch.log(2 * math.pi * math.e * self.std + EPSILON) |
|||
|
|||
|
|||
class TanhGaussianDistInstance(GaussianDistInstance): |
|||
def __init__(self, mean, std): |
|||
super().__init__(mean, std) |
|||
self.transform = torch.distributions.transforms.TanhTransform(cache_size=1) |
|||
|
|||
def sample(self): |
|||
unsquashed_sample = super().sample() |
|||
squashed = self.transform(unsquashed_sample) |
|||
return squashed |
|||
|
|||
def _inverse_tanh(self, value): |
|||
capped_value = torch.clamp(value, -1 + EPSILON, 1 - EPSILON) |
|||
return 0.5 * torch.log((1 + capped_value) / (1 - capped_value) + EPSILON) |
|||
|
|||
def log_prob(self, value): |
|||
unsquashed = self.transform.inv(value) |
|||
return super().log_prob(unsquashed) - self.transform.log_abs_det_jacobian( |
|||
unsquashed, value |
|||
) |
|||
|
|||
|
|||
class CategoricalDistInstance(DiscreteDistInstance): |
|||
def __init__(self, logits): |
|||
super().__init__() |
|||
self.logits = logits |
|||
self.probs = torch.softmax(self.logits, dim=-1) |
|||
|
|||
def sample(self): |
|||
return torch.multinomial(self.probs, 1) |
|||
|
|||
def pdf(self, value): |
|||
# This function is equivalent to torch.diag(self.probs.T[value.flatten().long()]), |
|||
# but torch.diag is not supported by ONNX export. |
|||
idx = torch.arange(start=0, end=len(value)).unsqueeze(-1) |
|||
return torch.gather( |
|||
self.probs.permute(1, 0)[value.flatten().long()], -1, idx |
|||
).squeeze(-1) |
|||
|
|||
def log_prob(self, value): |
|||
return torch.log(self.pdf(value)) |
|||
|
|||
def all_log_prob(self): |
|||
return torch.log(self.probs) |
|||
|
|||
def entropy(self): |
|||
return -torch.sum(self.probs * torch.log(self.probs), dim=-1) |
|||
|
|||
|
|||
class GaussianDistribution(nn.Module): |
|||
def __init__( |
|||
self, |
|||
hidden_size: int, |
|||
num_outputs: int, |
|||
conditional_sigma: bool = False, |
|||
tanh_squash: bool = False, |
|||
): |
|||
super().__init__() |
|||
self.conditional_sigma = conditional_sigma |
|||
self.mu = linear_layer( |
|||
hidden_size, |
|||
num_outputs, |
|||
kernel_init=Initialization.KaimingHeNormal, |
|||
kernel_gain=0.1, |
|||
bias_init=Initialization.Zero, |
|||
) |
|||
self.tanh_squash = tanh_squash |
|||
if conditional_sigma: |
|||
self.log_sigma = linear_layer( |
|||
hidden_size, |
|||
num_outputs, |
|||
kernel_init=Initialization.KaimingHeNormal, |
|||
kernel_gain=0.1, |
|||
bias_init=Initialization.Zero, |
|||
) |
|||
else: |
|||
self.log_sigma = nn.Parameter( |
|||
torch.zeros(1, num_outputs, requires_grad=True) |
|||
) |
|||
|
|||
def forward(self, inputs: torch.Tensor) -> List[DistInstance]: |
|||
mu = self.mu(inputs) |
|||
if self.conditional_sigma: |
|||
log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2) |
|||
else: |
|||
log_sigma = self.log_sigma |
|||
if self.tanh_squash: |
|||
return [TanhGaussianDistInstance(mu, torch.exp(log_sigma))] |
|||
else: |
|||
return [GaussianDistInstance(mu, torch.exp(log_sigma))] |
|||
|
|||
|
|||
class MultiCategoricalDistribution(nn.Module): |
|||
def __init__(self, hidden_size: int, act_sizes: List[int]): |
|||
super().__init__() |
|||
self.act_sizes = act_sizes |
|||
self.branches = self._create_policy_branches(hidden_size) |
|||
|
|||
def _create_policy_branches(self, hidden_size: int) -> nn.ModuleList: |
|||
branches = [] |
|||
for size in self.act_sizes: |
|||
branch_output_layer = linear_layer( |
|||
hidden_size, |
|||
size, |
|||
kernel_init=Initialization.KaimingHeNormal, |
|||
kernel_gain=0.1, |
|||
bias_init=Initialization.Zero, |
|||
) |
|||
branches.append(branch_output_layer) |
|||
return nn.ModuleList(branches) |
|||
|
|||
def _mask_branch(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: |
|||
raw_probs = torch.nn.functional.softmax(logits, dim=-1) * mask |
|||
normalized_probs = raw_probs / torch.sum(raw_probs, dim=-1).unsqueeze(-1) |
|||
normalized_logits = torch.log(normalized_probs + EPSILON) |
|||
return normalized_logits |
|||
|
|||
def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]: |
|||
split_masks = [] |
|||
for idx, _ in enumerate(self.act_sizes): |
|||
start = int(np.sum(self.act_sizes[:idx])) |
|||
end = int(np.sum(self.act_sizes[: idx + 1])) |
|||
split_masks.append(masks[:, start:end]) |
|||
return split_masks |
|||
|
|||
def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> List[DistInstance]: |
|||
# Todo - Support multiple branches in mask code |
|||
branch_distributions = [] |
|||
masks = self._split_masks(masks) |
|||
for idx, branch in enumerate(self.branches): |
|||
logits = branch(inputs) |
|||
norm_logits = self._mask_branch(logits, masks[idx]) |
|||
distribution = CategoricalDistInstance(norm_logits) |
|||
branch_distributions.append(distribution) |
|||
return branch_distributions |
|
|||
from typing import Tuple, Optional |
|||
|
|||
from mlagents.trainers.exception import UnityTrainerException |
|||
from mlagents.trainers.torch.layers import linear_layer, Initialization, Swish |
|||
|
|||
import torch |
|||
from torch import nn |
|||
|
|||
|
|||
class Normalizer(nn.Module): |
|||
def __init__(self, vec_obs_size: int): |
|||
super().__init__() |
|||
self.register_buffer("normalization_steps", torch.tensor(1)) |
|||
self.register_buffer("running_mean", torch.zeros(vec_obs_size)) |
|||
self.register_buffer("running_variance", torch.ones(vec_obs_size)) |
|||
|
|||
def forward(self, inputs: torch.Tensor) -> torch.Tensor: |
|||
normalized_state = torch.clamp( |
|||
(inputs - self.running_mean) |
|||
/ torch.sqrt(self.running_variance / self.normalization_steps), |
|||
-5, |
|||
5, |
|||
) |
|||
return normalized_state |
|||
|
|||
def update(self, vector_input: torch.Tensor) -> None: |
|||
steps_increment = vector_input.size()[0] |
|||
total_new_steps = self.normalization_steps + steps_increment |
|||
|
|||
input_to_old_mean = vector_input - self.running_mean |
|||
new_mean = self.running_mean + (input_to_old_mean / total_new_steps).sum(0) |
|||
|
|||
input_to_new_mean = vector_input - new_mean |
|||
new_variance = self.running_variance + ( |
|||
input_to_new_mean * input_to_old_mean |
|||
).sum(0) |
|||
# Update in-place |
|||
self.running_mean.data.copy_(new_mean.data) |
|||
self.running_variance.data.copy_(new_variance.data) |
|||
self.normalization_steps.data.copy_(total_new_steps.data) |
|||
|
|||
def copy_from(self, other_normalizer: "Normalizer") -> None: |
|||
self.normalization_steps.data.copy_(other_normalizer.normalization_steps.data) |
|||
self.running_mean.data.copy_(other_normalizer.running_mean.data) |
|||
self.running_variance.copy_(other_normalizer.running_variance.data) |
|||
|
|||
|
|||
def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1): |
|||
from math import floor |
|||
|
|||
if type(kernel_size) is not tuple: |
|||
kernel_size = (kernel_size, kernel_size) |
|||
h = floor( |
|||
((h_w[0] + (2 * pad) - (dilation * (kernel_size[0] - 1)) - 1) / stride) + 1 |
|||
) |
|||
w = floor( |
|||
((h_w[1] + (2 * pad) - (dilation * (kernel_size[1] - 1)) - 1) / stride) + 1 |
|||
) |
|||
return h, w |
|||
|
|||
|
|||
def pool_out_shape(h_w: Tuple[int, int], kernel_size: int) -> Tuple[int, int]: |
|||
height = (h_w[0] - kernel_size) // 2 + 1 |
|||
width = (h_w[1] - kernel_size) // 2 + 1 |
|||
return height, width |
|||
|
|||
|
|||
class VectorEncoder(nn.Module): |
|||
def __init__( |
|||
self, |
|||
input_size: int, |
|||
hidden_size: int, |
|||
num_layers: int, |
|||
normalize: bool = False, |
|||
): |
|||
self.normalizer: Optional[Normalizer] = None |
|||
super().__init__() |
|||
self.layers = [ |
|||
linear_layer( |
|||
input_size, |
|||
hidden_size, |
|||
kernel_init=Initialization.KaimingHeNormal, |
|||
kernel_gain=1.0, |
|||
) |
|||
] |
|||
self.layers.append(Swish()) |
|||
if normalize: |
|||
self.normalizer = Normalizer(input_size) |
|||
|
|||
for _ in range(num_layers - 1): |
|||
self.layers.append( |
|||
linear_layer( |
|||
hidden_size, |
|||
hidden_size, |
|||
kernel_init=Initialization.KaimingHeNormal, |
|||
kernel_gain=1.0, |
|||
) |
|||
) |
|||
self.layers.append(Swish()) |
|||
self.seq_layers = nn.Sequential(*self.layers) |
|||
|
|||
def forward(self, inputs: torch.Tensor) -> None: |
|||
if self.normalizer is not None: |
|||
inputs = self.normalizer(inputs) |
|||
return self.seq_layers(inputs) |
|||
|
|||
def copy_normalization(self, other_encoder: "VectorEncoder") -> None: |
|||
if self.normalizer is not None and other_encoder.normalizer is not None: |
|||
self.normalizer.copy_from(other_encoder.normalizer) |
|||
|
|||
def update_normalization(self, inputs: torch.Tensor) -> None: |
|||
if self.normalizer is not None: |
|||
self.normalizer.update(inputs) |
|||
|
|||
|
|||
class VectorAndUnnormalizedInputEncoder(VectorEncoder): |
|||
""" |
|||
Encoder for concatenated vector input (can be normalized) and unnormalized vector input. |
|||
This is used for passing inputs to the network that should not be normalized, such as |
|||
actions in the case of a Q function or task parameterizations. It will result in an encoder with |
|||
this structure: |
|||
____________ ____________ ____________ |
|||
| Vector | | Normalize | | Fully | |
|||
| | --> | | --> | Connected | ___________ |
|||
|____________| |____________| | | | Output | |
|||
____________ | | --> | | |
|||
|Unnormalized| | | |___________| |
|||
| Input | ---------------------> | | |
|||
|____________| |____________| |
|||
""" |
|||
|
|||
def __init__( |
|||
self, |
|||
input_size: int, |
|||
hidden_size: int, |
|||
unnormalized_input_size: int, |
|||
num_layers: int, |
|||
normalize: bool = False, |
|||
): |
|||
super().__init__( |
|||
input_size + unnormalized_input_size, |
|||
hidden_size, |
|||
num_layers, |
|||
normalize=False, |
|||
) |
|||
if normalize: |
|||
self.normalizer = Normalizer(input_size) |
|||
else: |
|||
self.normalizer = None |
|||
|
|||
def forward( # pylint: disable=W0221 |
|||
self, inputs: torch.Tensor, unnormalized_inputs: Optional[torch.Tensor] = None |
|||
) -> None: |
|||
if unnormalized_inputs is None: |
|||
raise UnityTrainerException( |
|||
"Attempted to call an VectorAndUnnormalizedInputEncoder without an unnormalized input." |
|||
) # Fix mypy errors about method parameters. |
|||
if self.normalizer is not None: |
|||
inputs = self.normalizer(inputs) |
|||
return self.seq_layers(torch.cat([inputs, unnormalized_inputs], dim=-1)) |
|||
|
|||
|
|||
class SimpleVisualEncoder(nn.Module): |
|||
def __init__( |
|||
self, height: int, width: int, initial_channels: int, output_size: int |
|||
): |
|||
super().__init__() |
|||
self.h_size = output_size |
|||
conv_1_hw = conv_output_shape((height, width), 8, 4) |
|||
conv_2_hw = conv_output_shape(conv_1_hw, 4, 2) |
|||
self.final_flat = conv_2_hw[0] * conv_2_hw[1] * 32 |
|||
|
|||
self.conv_layers = nn.Sequential( |
|||
nn.Conv2d(initial_channels, 16, [8, 8], [4, 4]), |
|||
nn.LeakyReLU(), |
|||
nn.Conv2d(16, 32, [4, 4], [2, 2]), |
|||
nn.LeakyReLU(), |
|||
) |
|||
self.dense = nn.Sequential( |
|||
linear_layer( |
|||
self.final_flat, |
|||
self.h_size, |
|||
kernel_init=Initialization.KaimingHeNormal, |
|||
kernel_gain=1.0, |
|||
), |
|||
nn.LeakyReLU(), |
|||
) |
|||
|
|||
def forward(self, visual_obs: torch.Tensor) -> None: |
|||
hidden = self.conv_layers(visual_obs) |
|||
hidden = torch.reshape(hidden, (-1, self.final_flat)) |
|||
hidden = self.dense(hidden) |
|||
return hidden |
|||
|
|||
|
|||
class NatureVisualEncoder(nn.Module): |
|||
def __init__(self, height, width, initial_channels, output_size): |
|||
super().__init__() |
|||
self.h_size = output_size |
|||
conv_1_hw = conv_output_shape((height, width), 8, 4) |
|||
conv_2_hw = conv_output_shape(conv_1_hw, 4, 2) |
|||
conv_3_hw = conv_output_shape(conv_2_hw, 3, 1) |
|||
self.final_flat = conv_3_hw[0] * conv_3_hw[1] * 64 |
|||
|
|||
self.conv_layers = nn.Sequential( |
|||
nn.Conv2d(initial_channels, 32, [8, 8], [4, 4]), |
|||
nn.LeakyReLU(), |
|||
nn.Conv2d(32, 64, [4, 4], [2, 2]), |
|||
nn.LeakyReLU(), |
|||
nn.Conv2d(64, 64, [3, 3], [1, 1]), |
|||
nn.LeakyReLU(), |
|||
) |
|||
self.dense = nn.Sequential( |
|||
linear_layer( |
|||
self.final_flat, |
|||
self.h_size, |
|||
kernel_init=Initialization.KaimingHeNormal, |
|||
kernel_gain=1.0, |
|||
), |
|||
nn.LeakyReLU(), |
|||
) |
|||
|
|||
def forward(self, visual_obs: torch.Tensor) -> None: |
|||
hidden = self.conv_layers(visual_obs) |
|||
hidden = hidden.view([-1, self.final_flat]) |
|||
hidden = self.dense(hidden) |
|||
return hidden |
|||
|
|||
|
|||
class ResNetVisualEncoder(nn.Module): |
|||
def __init__(self, height, width, initial_channels, final_hidden): |
|||
super().__init__() |
|||
n_channels = [16, 32, 32] # channel for each stack |
|||
n_blocks = 2 # number of residual blocks |
|||
self.layers = [] |
|||
last_channel = initial_channels |
|||
for _, channel in enumerate(n_channels): |
|||
self.layers.append( |
|||
nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1) |
|||
) |
|||
self.layers.append(nn.MaxPool2d([3, 3], [2, 2])) |
|||
height, width = pool_out_shape((height, width), 3) |
|||
for _ in range(n_blocks): |
|||
self.layers.append(self.make_block(channel)) |
|||
last_channel = channel |
|||
self.layers.append(Swish()) |
|||
self.dense = linear_layer( |
|||
n_channels[-1] * height * width, |
|||
final_hidden, |
|||
kernel_init=Initialization.KaimingHeNormal, |
|||
kernel_gain=1.0, |
|||
) |
|||
|
|||
@staticmethod |
|||
def make_block(channel): |
|||
block_layers = [ |
|||
Swish(), |
|||
nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1), |
|||
Swish(), |
|||
nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1), |
|||
] |
|||
return block_layers |
|||
|
|||
@staticmethod |
|||
def forward_block(input_hidden, block_layers): |
|||
hidden = input_hidden |
|||
for layer in block_layers: |
|||
hidden = layer(hidden) |
|||
return hidden + input_hidden |
|||
|
|||
def forward(self, visual_obs): |
|||
batch_size = visual_obs.shape[0] |
|||
hidden = visual_obs |
|||
for layer in self.layers: |
|||
if isinstance(layer, nn.Module): |
|||
hidden = layer(hidden) |
|||
elif isinstance(layer, list): |
|||
hidden = self.forward_block(hidden, layer) |
|||
before_out = hidden.view(batch_size, -1) |
|||
return torch.relu(self.dense(before_out)) |
|
|||
import torch |
|||
from enum import Enum |
|||
|
|||
|
|||
class Swish(torch.nn.Module): |
|||
def forward(self, data: torch.Tensor) -> torch.Tensor: |
|||
return torch.mul(data, torch.sigmoid(data)) |
|||
|
|||
|
|||
class Initialization(Enum): |
|||
Zero = 0 |
|||
XavierGlorotNormal = 1 |
|||
XavierGlorotUniform = 2 |
|||
KaimingHeNormal = 3 # also known as Variance scaling |
|||
KaimingHeUniform = 4 |
|||
|
|||
|
|||
_init_methods = { |
|||
Initialization.Zero: torch.zero_, |
|||
Initialization.XavierGlorotNormal: torch.nn.init.xavier_normal_, |
|||
Initialization.XavierGlorotUniform: torch.nn.init.xavier_uniform_, |
|||
Initialization.KaimingHeNormal: torch.nn.init.kaiming_normal_, |
|||
Initialization.KaimingHeUniform: torch.nn.init.kaiming_uniform_, |
|||
} |
|||
|
|||
|
|||
def linear_layer( |
|||
input_size: int, |
|||
output_size: int, |
|||
kernel_init: Initialization = Initialization.XavierGlorotUniform, |
|||
kernel_gain: float = 1.0, |
|||
bias_init: Initialization = Initialization.Zero, |
|||
) -> torch.nn.Module: |
|||
""" |
|||
Creates a torch.nn.Linear module and initializes its weights. |
|||
:param input_size: The size of the input tensor |
|||
:param output_size: The size of the output tensor |
|||
:param kernel_init: The Initialization to use for the weights of the layer |
|||
:param kernel_gain: The multiplier for the weights of the kernel. Note that in |
|||
TensorFlow, calling variance_scaling with scale 0.01 is equivalent to calling |
|||
KaimingHeNormal with kernel_gain of 0.1 |
|||
:param bias_init: The Initialization to use for the weights of the bias layer |
|||
""" |
|||
layer = torch.nn.Linear(input_size, output_size) |
|||
_init_methods[kernel_init](layer.weight.data) |
|||
layer.weight.data *= kernel_gain |
|||
_init_methods[bias_init](layer.bias.data) |
|||
return layer |
|
|||
from typing import List, Optional, Tuple |
|||
import torch |
|||
import numpy as np |
|||
from torch import nn |
|||
|
|||
from mlagents.trainers.torch.encoders import ( |
|||
SimpleVisualEncoder, |
|||
ResNetVisualEncoder, |
|||
NatureVisualEncoder, |
|||
VectorEncoder, |
|||
VectorAndUnnormalizedInputEncoder, |
|||
) |
|||
from mlagents.trainers.settings import EncoderType, ScheduleType |
|||
from mlagents.trainers.exception import UnityTrainerException |
|||
from mlagents_envs.base_env import BehaviorSpec |
|||
from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance |
|||
|
|||
|
|||
class ModelUtils: |
|||
# Minimum supported side for each encoder type. If refactoring an encoder, please |
|||
# adjust these also. |
|||
MIN_RESOLUTION_FOR_ENCODER = { |
|||
EncoderType.SIMPLE: 20, |
|||
EncoderType.NATURE_CNN: 36, |
|||
EncoderType.RESNET: 15, |
|||
} |
|||
|
|||
class ActionFlattener: |
|||
def __init__(self, behavior_spec: BehaviorSpec): |
|||
self._specs = behavior_spec |
|||
|
|||
@property |
|||
def flattened_size(self) -> int: |
|||
if self._specs.is_action_continuous(): |
|||
return self._specs.action_size |
|||
else: |
|||
return sum(self._specs.discrete_action_branches) |
|||
|
|||
def forward(self, action: torch.Tensor) -> torch.Tensor: |
|||
if self._specs.is_action_continuous(): |
|||
return action |
|||
else: |
|||
return torch.cat( |
|||
ModelUtils.actions_to_onehot( |
|||
torch.as_tensor(action, dtype=torch.long), |
|||
self._specs.discrete_action_branches, |
|||
), |
|||
dim=1, |
|||
) |
|||
|
|||
@staticmethod |
|||
def update_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None: |
|||
""" |
|||
Apply a learning rate to a torch optimizer. |
|||
:param optim: Optimizer |
|||
:param lr: Learning rate |
|||
""" |
|||
for param_group in optim.param_groups: |
|||
param_group["lr"] = lr |
|||
|
|||
class DecayedValue: |
|||
def __init__( |
|||
self, |
|||
schedule: ScheduleType, |
|||
initial_value: float, |
|||
min_value: float, |
|||
max_step: int, |
|||
): |
|||
""" |
|||
Object that represnets value of a parameter that should be decayed, assuming it is a function of |
|||
global_step. |
|||
:param schedule: Type of learning rate schedule. |
|||
:param initial_value: Initial value before decay. |
|||
:param min_value: Decay value to this value by max_step. |
|||
:param max_step: The final step count where the return value should equal min_value. |
|||
:param global_step: The current step count. |
|||
:return: The value. |
|||
""" |
|||
self.schedule = schedule |
|||
self.initial_value = initial_value |
|||
self.min_value = min_value |
|||
self.max_step = max_step |
|||
|
|||
def get_value(self, global_step: int) -> float: |
|||
""" |
|||
Get the value at a given global step. |
|||
:param global_step: Step count. |
|||
:returns: Decayed value at this global step. |
|||
""" |
|||
if self.schedule == ScheduleType.CONSTANT: |
|||
return self.initial_value |
|||
elif self.schedule == ScheduleType.LINEAR: |
|||
return ModelUtils.polynomial_decay( |
|||
self.initial_value, self.min_value, self.max_step, global_step |
|||
) |
|||
else: |
|||
raise UnityTrainerException(f"The schedule {self.schedule} is invalid.") |
|||
|
|||
@staticmethod |
|||
def polynomial_decay( |
|||
initial_value: float, |
|||
min_value: float, |
|||
max_step: int, |
|||
global_step: int, |
|||
power: float = 1.0, |
|||
) -> float: |
|||
""" |
|||
Get a decayed value based on a polynomial schedule, with respect to the current global step. |
|||
:param initial_value: Initial value before decay. |
|||
:param min_value: Decay value to this value by max_step. |
|||
:param max_step: The final step count where the return value should equal min_value. |
|||
:param global_step: The current step count. |
|||
:param power: Power of polynomial decay. 1.0 (default) is a linear decay. |
|||
:return: The current decayed value. |
|||
""" |
|||
global_step = min(global_step, max_step) |
|||
decayed_value = (initial_value - min_value) * ( |
|||
1 - float(global_step) / max_step |
|||
) ** (power) + min_value |
|||
return decayed_value |
|||
|
|||
@staticmethod |
|||
def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module: |
|||
ENCODER_FUNCTION_BY_TYPE = { |
|||
EncoderType.SIMPLE: SimpleVisualEncoder, |
|||
EncoderType.NATURE_CNN: NatureVisualEncoder, |
|||
EncoderType.RESNET: ResNetVisualEncoder, |
|||
} |
|||
return ENCODER_FUNCTION_BY_TYPE.get(encoder_type) |
|||
|
|||
@staticmethod |
|||
def _check_resolution_for_encoder( |
|||
height: int, width: int, vis_encoder_type: EncoderType |
|||
) -> None: |
|||
min_res = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[vis_encoder_type] |
|||
if height < min_res or width < min_res: |
|||
raise UnityTrainerException( |
|||
f"Visual observation resolution ({width}x{height}) is too small for" |
|||
f"the provided EncoderType ({vis_encoder_type.value}). The min dimension is {min_res}" |
|||
) |
|||
|
|||
@staticmethod |
|||
def create_encoders( |
|||
observation_shapes: List[Tuple[int, ...]], |
|||
h_size: int, |
|||
num_layers: int, |
|||
vis_encode_type: EncoderType, |
|||
unnormalized_inputs: int = 0, |
|||
normalize: bool = False, |
|||
) -> Tuple[nn.ModuleList, nn.ModuleList]: |
|||
""" |
|||
Creates visual and vector encoders, along with their normalizers. |
|||
:param observation_shapes: List of Tuples that represent the action dimensions. |
|||
:param action_size: Number of additional un-normalized inputs to each vector encoder. Used for |
|||
conditioining network on other values (e.g. actions for a Q function) |
|||
:param h_size: Number of hidden units per layer. |
|||
:param num_layers: Depth of MLP per encoder. |
|||
:param vis_encode_type: Type of visual encoder to use. |
|||
:param unnormalized_inputs: Vector inputs that should not be normalized, and added to the vector |
|||
obs. |
|||
:param normalize: Normalize all vector inputs. |
|||
:return: Tuple of visual encoders and vector encoders each as a list. |
|||
""" |
|||
visual_encoders: List[nn.Module] = [] |
|||
vector_encoders: List[nn.Module] = [] |
|||
|
|||
visual_encoder_class = ModelUtils.get_encoder_for_type(vis_encode_type) |
|||
vector_size = 0 |
|||
for i, dimension in enumerate(observation_shapes): |
|||
if len(dimension) == 3: |
|||
ModelUtils._check_resolution_for_encoder( |
|||
dimension[0], dimension[1], vis_encode_type |
|||
) |
|||
visual_encoders.append( |
|||
visual_encoder_class( |
|||
dimension[0], dimension[1], dimension[2], h_size |
|||
) |
|||
) |
|||
elif len(dimension) == 1: |
|||
vector_size += dimension[0] |
|||
else: |
|||
raise UnityTrainerException( |
|||
f"Unsupported shape of {dimension} for observation {i}" |
|||
) |
|||
if vector_size + unnormalized_inputs > 0: |
|||
if unnormalized_inputs > 0: |
|||
vector_encoders.append( |
|||
VectorAndUnnormalizedInputEncoder( |
|||
vector_size, h_size, unnormalized_inputs, num_layers, normalize |
|||
) |
|||
) |
|||
else: |
|||
vector_encoders.append( |
|||
VectorEncoder(vector_size, h_size, num_layers, normalize) |
|||
) |
|||
return nn.ModuleList(visual_encoders), nn.ModuleList(vector_encoders) |
|||
|
|||
@staticmethod |
|||
def list_to_tensor( |
|||
ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = None |
|||
) -> torch.Tensor: |
|||
""" |
|||
Converts a list of numpy arrays into a tensor. MUCH faster than |
|||
calling as_tensor on the list directly. |
|||
""" |
|||
return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype) |
|||
|
|||
@staticmethod |
|||
def break_into_branches( |
|||
concatenated_logits: torch.Tensor, action_size: List[int] |
|||
) -> List[torch.Tensor]: |
|||
""" |
|||
Takes a concatenated set of logits that represent multiple discrete action branches |
|||
and breaks it up into one Tensor per branch. |
|||
:param concatenated_logits: Tensor that represents the concatenated action branches |
|||
:param action_size: List of ints containing the number of possible actions for each branch. |
|||
:return: A List of Tensors containing one tensor per branch. |
|||
""" |
|||
action_idx = [0] + list(np.cumsum(action_size)) |
|||
branched_logits = [ |
|||
concatenated_logits[:, action_idx[i] : action_idx[i + 1]] |
|||
for i in range(len(action_size)) |
|||
] |
|||
return branched_logits |
|||
|
|||
@staticmethod |
|||
def actions_to_onehot( |
|||
discrete_actions: torch.Tensor, action_size: List[int] |
|||
) -> List[torch.Tensor]: |
|||
""" |
|||
Takes a tensor of discrete actions and turns it into a List of onehot encoding for each |
|||
action. |
|||
:param discrete_actions: Actions in integer form. |
|||
:param action_size: List of branch sizes. Should be of same size as discrete_actions' |
|||
last dimension. |
|||
:return: List of one-hot tensors, one representing each branch. |
|||
""" |
|||
onehot_branches = [ |
|||
torch.nn.functional.one_hot(_act.T, action_size[i]).float() |
|||
for i, _act in enumerate(discrete_actions.long().T) |
|||
] |
|||
return onehot_branches |
|||
|
|||
@staticmethod |
|||
def dynamic_partition( |
|||
data: torch.Tensor, partitions: torch.Tensor, num_partitions: int |
|||
) -> List[torch.Tensor]: |
|||
""" |
|||
Torch implementation of dynamic_partition : |
|||
https://www.tensorflow.org/api_docs/python/tf/dynamic_partition |
|||
Splits the data Tensor input into num_partitions Tensors according to the indices in |
|||
partitions. |
|||
:param data: The Tensor data that will be split into partitions. |
|||
:param partitions: An indices tensor that determines in which partition each element |
|||
of data will be in. |
|||
:param num_partitions: The number of partitions to output. Corresponds to the |
|||
maximum possible index in the partitions argument. |
|||
:return: A list of Tensor partitions (Their indices correspond to their partition index). |
|||
""" |
|||
res: List[torch.Tensor] = [] |
|||
for i in range(num_partitions): |
|||
res += [data[(partitions == i).nonzero().squeeze(1)]] |
|||
return res |
|||
|
|||
@staticmethod |
|||
def get_probs_and_entropy( |
|||
action_list: List[torch.Tensor], dists: List[DistInstance] |
|||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: |
|||
log_probs_list = [] |
|||
all_probs_list = [] |
|||
entropies_list = [] |
|||
for action, action_dist in zip(action_list, dists): |
|||
log_prob = action_dist.log_prob(action) |
|||
log_probs_list.append(log_prob) |
|||
entropies_list.append(action_dist.entropy()) |
|||
if isinstance(action_dist, DiscreteDistInstance): |
|||
all_probs_list.append(action_dist.all_log_prob()) |
|||
log_probs = torch.stack(log_probs_list, dim=-1) |
|||
entropies = torch.stack(entropies_list, dim=-1) |
|||
if not all_probs_list: |
|||
log_probs = log_probs.squeeze(-1) |
|||
entropies = entropies.squeeze(-1) |
|||
all_probs = None |
|||
else: |
|||
all_probs = torch.cat(all_probs_list, dim=-1) |
|||
return log_probs, entropies, all_probs |
|
|||
import pytest |
|||
import torch |
|||
|
|||
from mlagents.trainers.torch.decoders import ValueHeads |
|||
|
|||
|
|||
def test_valueheads(): |
|||
stream_names = [f"reward_signal_{num}" for num in range(5)] |
|||
input_size = 5 |
|||
batch_size = 4 |
|||
|
|||
# Test default 1 value per head |
|||
value_heads = ValueHeads(stream_names, input_size) |
|||
input_data = torch.ones((batch_size, input_size)) |
|||
value_out = value_heads(input_data) # Note: mean value will be removed shortly |
|||
|
|||
for stream_name in stream_names: |
|||
assert value_out[stream_name].shape == (batch_size,) |
|||
|
|||
# Test that inputting the wrong size input will throw an error |
|||
with pytest.raises(Exception): |
|||
value_out = value_heads(torch.ones((batch_size, input_size + 2))) |
|||
|
|||
# Test multiple values per head (e.g. discrete Q function) |
|||
output_size = 4 |
|||
value_heads = ValueHeads(stream_names, input_size, output_size) |
|||
input_data = torch.ones((batch_size, input_size)) |
|||
value_out = value_heads(input_data) |
|||
|
|||
for stream_name in stream_names: |
|||
assert value_out[stream_name].shape == (batch_size, output_size) |
|
|||
import pytest |
|||
import torch |
|||
|
|||
from mlagents.trainers.torch.distributions import ( |
|||
GaussianDistribution, |
|||
MultiCategoricalDistribution, |
|||
GaussianDistInstance, |
|||
TanhGaussianDistInstance, |
|||
CategoricalDistInstance, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize("tanh_squash", [True, False]) |
|||
@pytest.mark.parametrize("conditional_sigma", [True, False]) |
|||
def test_gaussian_distribution(conditional_sigma, tanh_squash): |
|||
torch.manual_seed(0) |
|||
hidden_size = 16 |
|||
act_size = 4 |
|||
sample_embedding = torch.ones((1, 16)) |
|||
gauss_dist = GaussianDistribution( |
|||
hidden_size, |
|||
act_size, |
|||
conditional_sigma=conditional_sigma, |
|||
tanh_squash=tanh_squash, |
|||
) |
|||
|
|||
# Make sure backprop works |
|||
force_action = torch.zeros((1, act_size)) |
|||
optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3) |
|||
|
|||
for _ in range(50): |
|||
dist_inst = gauss_dist(sample_embedding)[0] |
|||
if tanh_squash: |
|||
assert isinstance(dist_inst, TanhGaussianDistInstance) |
|||
else: |
|||
assert isinstance(dist_inst, GaussianDistInstance) |
|||
log_prob = dist_inst.log_prob(force_action) |
|||
loss = torch.nn.functional.mse_loss(log_prob, -2 * torch.ones(log_prob.shape)) |
|||
optimizer.zero_grad() |
|||
loss.backward() |
|||
optimizer.step() |
|||
for prob in log_prob.flatten(): |
|||
assert prob == pytest.approx(-2, abs=0.1) |
|||
|
|||
|
|||
def test_multi_categorical_distribution(): |
|||
torch.manual_seed(0) |
|||
hidden_size = 16 |
|||
act_size = [3, 3, 4] |
|||
sample_embedding = torch.ones((1, 16)) |
|||
gauss_dist = MultiCategoricalDistribution(hidden_size, act_size) |
|||
|
|||
# Make sure backprop works |
|||
optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3) |
|||
|
|||
def create_test_prob(size: int) -> torch.Tensor: |
|||
test_prob = torch.tensor( |
|||
[[1.0 - 0.01 * (size - 1)] + [0.01] * (size - 1)] |
|||
) # High prob for first action |
|||
return test_prob.log() |
|||
|
|||
for _ in range(100): |
|||
dist_insts = gauss_dist(sample_embedding, masks=torch.ones((1, sum(act_size)))) |
|||
loss = 0 |
|||
for i, dist_inst in enumerate(dist_insts): |
|||
assert isinstance(dist_inst, CategoricalDistInstance) |
|||
log_prob = dist_inst.all_log_prob() |
|||
test_log_prob = create_test_prob(act_size[i]) |
|||
# Force log_probs to match the high probability for the first action generated by |
|||
# create_test_prob |
|||
loss += torch.nn.functional.mse_loss(log_prob, test_log_prob) |
|||
optimizer.zero_grad() |
|||
loss.backward() |
|||
optimizer.step() |
|||
for dist_inst, size in zip(dist_insts, act_size): |
|||
# Check that the log probs are close to the fake ones that we generated. |
|||
test_log_probs = create_test_prob(size) |
|||
for _prob, _test_prob in zip( |
|||
dist_inst.all_log_prob().flatten().tolist(), |
|||
test_log_probs.flatten().tolist(), |
|||
): |
|||
assert _prob == pytest.approx(_test_prob, abs=0.1) |
|||
|
|||
# Test masks |
|||
masks = [] |
|||
for branch in act_size: |
|||
masks += [0] * (branch - 1) + [1] |
|||
masks = torch.tensor([masks]) |
|||
dist_insts = gauss_dist(sample_embedding, masks=masks) |
|||
for dist_inst in dist_insts: |
|||
log_prob = dist_inst.all_log_prob() |
|||
assert log_prob.flatten()[-1] == pytest.approx(0, abs=0.001) |
|||
|
|||
|
|||
def test_gaussian_dist_instance(): |
|||
torch.manual_seed(0) |
|||
act_size = 4 |
|||
dist_instance = GaussianDistInstance( |
|||
torch.zeros(1, act_size), torch.ones(1, act_size) |
|||
) |
|||
action = dist_instance.sample() |
|||
assert action.shape == (1, act_size) |
|||
for log_prob in dist_instance.log_prob(torch.zeros((1, act_size))).flatten(): |
|||
# Log prob of standard normal at 0 |
|||
assert log_prob == pytest.approx(-0.919, abs=0.01) |
|||
|
|||
for ent in dist_instance.entropy().flatten(): |
|||
# entropy of standard normal at 0, based on 1/2 + ln(sqrt(2pi)sigma) |
|||
assert ent == pytest.approx(1.42, abs=0.01) |
|||
|
|||
|
|||
def test_tanh_gaussian_dist_instance(): |
|||
torch.manual_seed(0) |
|||
act_size = 4 |
|||
dist_instance = TanhGaussianDistInstance( |
|||
torch.zeros(1, act_size), torch.ones(1, act_size) |
|||
) |
|||
for _ in range(10): |
|||
action = dist_instance.sample() |
|||
assert action.shape == (1, act_size) |
|||
assert torch.max(action) < 1.0 and torch.min(action) > -1.0 |
|||
|
|||
|
|||
def test_categorical_dist_instance(): |
|||
torch.manual_seed(0) |
|||
act_size = 4 |
|||
test_prob = torch.tensor( |
|||
[[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)] |
|||
) # High prob for first action |
|||
dist_instance = CategoricalDistInstance(test_prob) |
|||
|
|||
for _ in range(10): |
|||
action = dist_instance.sample() |
|||
assert action.shape == (1, 1) |
|||
assert action < act_size |
|||
|
|||
# Make sure the first action as higher probability than the others. |
|||
prob_first_action = dist_instance.log_prob(torch.tensor([0])) |
|||
|
|||
for i in range(1, act_size): |
|||
assert dist_instance.log_prob(torch.tensor([i])) < prob_first_action |
|
|||
import torch |
|||
from unittest import mock |
|||
import pytest |
|||
|
|||
from mlagents.trainers.torch.encoders import ( |
|||
VectorEncoder, |
|||
VectorAndUnnormalizedInputEncoder, |
|||
Normalizer, |
|||
SimpleVisualEncoder, |
|||
ResNetVisualEncoder, |
|||
NatureVisualEncoder, |
|||
) |
|||
|
|||
|
|||
# This test will also reveal issues with states not being saved in the state_dict. |
|||
def compare_models(module_1, module_2): |
|||
is_same = True |
|||
for key_item_1, key_item_2 in zip( |
|||
module_1.state_dict().items(), module_2.state_dict().items() |
|||
): |
|||
# Compare tensors in state_dict and not the keys. |
|||
is_same = torch.equal(key_item_1[1], key_item_2[1]) and is_same |
|||
return is_same |
|||
|
|||
|
|||
def test_normalizer(): |
|||
input_size = 2 |
|||
norm = Normalizer(input_size) |
|||
|
|||
# These three inputs should mean to 0.5, and variance 2 |
|||
# with the steps starting at 1 |
|||
vec_input1 = torch.tensor([[1, 1]]) |
|||
vec_input2 = torch.tensor([[1, 1]]) |
|||
vec_input3 = torch.tensor([[0, 0]]) |
|||
norm.update(vec_input1) |
|||
norm.update(vec_input2) |
|||
norm.update(vec_input3) |
|||
|
|||
# Test normalization |
|||
for val in norm(vec_input1)[0]: |
|||
assert val == pytest.approx(0.707, abs=0.001) |
|||
|
|||
# Test copy normalization |
|||
norm2 = Normalizer(input_size) |
|||
assert not compare_models(norm, norm2) |
|||
norm2.copy_from(norm) |
|||
assert compare_models(norm, norm2) |
|||
for val in norm2(vec_input1)[0]: |
|||
assert val == pytest.approx(0.707, abs=0.001) |
|||
|
|||
|
|||
@mock.patch("mlagents.trainers.torch.encoders.Normalizer") |
|||
def test_vector_encoder(mock_normalizer): |
|||
mock_normalizer_inst = mock.Mock() |
|||
mock_normalizer.return_value = mock_normalizer_inst |
|||
input_size = 64 |
|||
hidden_size = 128 |
|||
num_layers = 3 |
|||
normalize = False |
|||
vector_encoder = VectorEncoder(input_size, hidden_size, num_layers, normalize) |
|||
output = vector_encoder(torch.ones((1, input_size))) |
|||
assert output.shape == (1, hidden_size) |
|||
|
|||
normalize = True |
|||
vector_encoder = VectorEncoder(input_size, hidden_size, num_layers, normalize) |
|||
new_vec = torch.ones((1, input_size)) |
|||
vector_encoder.update_normalization(new_vec) |
|||
|
|||
mock_normalizer.assert_called_with(input_size) |
|||
mock_normalizer_inst.update.assert_called_with(new_vec) |
|||
|
|||
vector_encoder2 = VectorEncoder(input_size, hidden_size, num_layers, normalize) |
|||
vector_encoder.copy_normalization(vector_encoder2) |
|||
mock_normalizer_inst.copy_from.assert_called_with(mock_normalizer_inst) |
|||
|
|||
|
|||
@mock.patch("mlagents.trainers.torch.encoders.Normalizer") |
|||
def test_vector_and_unnormalized_encoder(mock_normalizer): |
|||
mock_normalizer_inst = mock.Mock() |
|||
mock_normalizer.return_value = mock_normalizer_inst |
|||
input_size = 64 |
|||
unnormalized_size = 32 |
|||
hidden_size = 128 |
|||
num_layers = 3 |
|||
normalize = True |
|||
mock_normalizer_inst.return_value = torch.ones((1, input_size)) |
|||
vector_encoder = VectorAndUnnormalizedInputEncoder( |
|||
input_size, hidden_size, unnormalized_size, num_layers, normalize |
|||
) |
|||
# Make sure normalizer is only called on input_size |
|||
mock_normalizer.assert_called_with(input_size) |
|||
normal_input = torch.ones((1, input_size)) |
|||
|
|||
unnormalized_input = torch.ones((1, 32)) |
|||
output = vector_encoder(normal_input, unnormalized_input) |
|||
mock_normalizer_inst.assert_called_with(normal_input) |
|||
assert output.shape == (1, hidden_size) |
|||
|
|||
|
|||
@pytest.mark.parametrize("image_size", [(36, 36, 3), (84, 84, 4), (256, 256, 5)]) |
|||
@pytest.mark.parametrize( |
|||
"vis_class", [SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder] |
|||
) |
|||
def test_visual_encoder(vis_class, image_size): |
|||
num_outputs = 128 |
|||
enc = vis_class(image_size[0], image_size[1], image_size[2], num_outputs) |
|||
# Note: NCHW not NHWC |
|||
sample_input = torch.ones((1, image_size[2], image_size[0], image_size[1])) |
|||
encoding = enc(sample_input) |
|||
assert encoding.shape == (1, num_outputs) |
|
|||
import pytest |
|||
import torch |
|||
import numpy as np |
|||
|
|||
from mlagents.trainers.settings import EncoderType, ScheduleType |
|||
from mlagents.trainers.torch.utils import ModelUtils |
|||
from mlagents.trainers.exception import UnityTrainerException |
|||
from mlagents.trainers.torch.encoders import ( |
|||
VectorEncoder, |
|||
VectorAndUnnormalizedInputEncoder, |
|||
) |
|||
from mlagents.trainers.torch.distributions import ( |
|||
CategoricalDistInstance, |
|||
GaussianDistInstance, |
|||
) |
|||
|
|||
|
|||
def test_min_visual_size(): |
|||
# Make sure each EncoderType has an entry in MIS_RESOLUTION_FOR_ENCODER |
|||
assert set(ModelUtils.MIN_RESOLUTION_FOR_ENCODER.keys()) == set(EncoderType) |
|||
|
|||
for encoder_type in EncoderType: |
|||
good_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type] |
|||
vis_input = torch.ones((1, 3, good_size, good_size)) |
|||
ModelUtils._check_resolution_for_encoder(good_size, good_size, encoder_type) |
|||
enc_func = ModelUtils.get_encoder_for_type(encoder_type) |
|||
enc = enc_func(good_size, good_size, 3, 1) |
|||
enc.forward(vis_input) |
|||
|
|||
# Anything under the min size should raise an exception. If not, decrease the min size! |
|||
with pytest.raises(Exception): |
|||
bad_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type] - 1 |
|||
vis_input = torch.ones((1, 3, bad_size, bad_size)) |
|||
|
|||
with pytest.raises(UnityTrainerException): |
|||
# Make sure we'd hit a friendly error during model setup time. |
|||
ModelUtils._check_resolution_for_encoder( |
|||
bad_size, bad_size, encoder_type |
|||
) |
|||
|
|||
enc = enc_func(bad_size, bad_size, 3, 1) |
|||
enc.forward(vis_input) |
|||
|
|||
|
|||
@pytest.mark.parametrize("unnormalized_inputs", [0, 1]) |
|||
@pytest.mark.parametrize("num_visual", [0, 1, 2]) |
|||
@pytest.mark.parametrize("num_vector", [0, 1, 2]) |
|||
@pytest.mark.parametrize("normalize", [True, False]) |
|||
@pytest.mark.parametrize("encoder_type", [EncoderType.SIMPLE, EncoderType.NATURE_CNN]) |
|||
def test_create_encoders( |
|||
encoder_type, normalize, num_vector, num_visual, unnormalized_inputs |
|||
): |
|||
vec_obs_shape = (5,) |
|||
vis_obs_shape = (84, 84, 3) |
|||
obs_shapes = [] |
|||
for _ in range(num_vector): |
|||
obs_shapes.append(vec_obs_shape) |
|||
for _ in range(num_visual): |
|||
obs_shapes.append(vis_obs_shape) |
|||
h_size = 128 |
|||
num_layers = 3 |
|||
unnormalized_inputs = 1 |
|||
vis_enc, vec_enc = ModelUtils.create_encoders( |
|||
obs_shapes, h_size, num_layers, encoder_type, unnormalized_inputs, normalize |
|||
) |
|||
vec_enc = list(vec_enc) |
|||
vis_enc = list(vis_enc) |
|||
assert len(vec_enc) == ( |
|||
1 if unnormalized_inputs + num_vector > 0 else 0 |
|||
) # There's always at most one vector encoder. |
|||
assert len(vis_enc) == num_visual |
|||
|
|||
if unnormalized_inputs > 0: |
|||
assert isinstance(vec_enc[0], VectorAndUnnormalizedInputEncoder) |
|||
elif num_vector > 0: |
|||
assert isinstance(vec_enc[0], VectorEncoder) |
|||
|
|||
for enc in vis_enc: |
|||
assert isinstance(enc, ModelUtils.get_encoder_for_type(encoder_type)) |
|||
|
|||
|
|||
def test_decayed_value(): |
|||
test_steps = [0, 4, 9] |
|||
# Test constant decay |
|||
param = ModelUtils.DecayedValue(ScheduleType.CONSTANT, 1.0, 0.2, test_steps[-1]) |
|||
for _step in test_steps: |
|||
_param = param.get_value(_step) |
|||
assert _param == 1.0 |
|||
|
|||
test_results = [1.0, 0.6444, 0.2] |
|||
# Test linear decay |
|||
param = ModelUtils.DecayedValue(ScheduleType.LINEAR, 1.0, 0.2, test_steps[-1]) |
|||
for _step, _result in zip(test_steps, test_results): |
|||
_param = param.get_value(_step) |
|||
assert _param == pytest.approx(_result, abs=0.01) |
|||
|
|||
# Test invalid |
|||
with pytest.raises(UnityTrainerException): |
|||
ModelUtils.DecayedValue( |
|||
"SomeOtherSchedule", 1.0, 0.2, test_steps[-1] |
|||
).get_value(0) |
|||
|
|||
|
|||
def test_polynomial_decay(): |
|||
test_steps = [0, 4, 9] |
|||
test_results = [1.0, 0.7, 0.2] |
|||
for _step, _result in zip(test_steps, test_results): |
|||
decayed = ModelUtils.polynomial_decay( |
|||
1.0, 0.2, test_steps[-1], _step, power=0.8 |
|||
) |
|||
assert decayed == pytest.approx(_result, abs=0.01) |
|||
|
|||
|
|||
def test_list_to_tensor(): |
|||
# Test converting pure list |
|||
unconverted_list = [[1, 2], [1, 3], [1, 4]] |
|||
tensor = ModelUtils.list_to_tensor(unconverted_list) |
|||
# Should be equivalent to torch.tensor conversion |
|||
assert torch.equal(tensor, torch.tensor(unconverted_list)) |
|||
|
|||
# Test converting pure numpy array |
|||
np_list = np.asarray(unconverted_list) |
|||
tensor = ModelUtils.list_to_tensor(np_list) |
|||
# Should be equivalent to torch.tensor conversion |
|||
assert torch.equal(tensor, torch.tensor(unconverted_list)) |
|||
|
|||
# Test converting list of numpy arrays |
|||
list_of_np = [np.asarray(_el) for _el in unconverted_list] |
|||
tensor = ModelUtils.list_to_tensor(list_of_np) |
|||
# Should be equivalent to torch.tensor conversion |
|||
assert torch.equal(tensor, torch.tensor(unconverted_list)) |
|||
|
|||
|
|||
def test_break_into_branches(): |
|||
# Test normal multi-branch case |
|||
all_actions = torch.tensor([[1, 2, 3, 4, 5, 6]]) |
|||
action_size = [2, 1, 3] |
|||
broken_actions = ModelUtils.break_into_branches(all_actions, action_size) |
|||
assert len(action_size) == len(broken_actions) |
|||
for i, _action in enumerate(broken_actions): |
|||
assert _action.shape == (1, action_size[i]) |
|||
|
|||
# Test 1-branch case |
|||
action_size = [6] |
|||
broken_actions = ModelUtils.break_into_branches(all_actions, action_size) |
|||
assert len(broken_actions) == 1 |
|||
assert broken_actions[0].shape == (1, 6) |
|||
|
|||
|
|||
def test_actions_to_onehot(): |
|||
all_actions = torch.tensor([[1, 0, 2], [1, 0, 2]]) |
|||
action_size = [2, 1, 3] |
|||
oh_actions = ModelUtils.actions_to_onehot(all_actions, action_size) |
|||
expected_result = [ |
|||
torch.tensor([[0, 1], [0, 1]], dtype=torch.float), |
|||
torch.tensor([[1], [1]], dtype=torch.float), |
|||
torch.tensor([[0, 0, 1], [0, 0, 1]], dtype=torch.float), |
|||
] |
|||
for res, exp in zip(oh_actions, expected_result): |
|||
assert torch.equal(res, exp) |
|||
|
|||
|
|||
def test_get_probs_and_entropy(): |
|||
# Test continuous |
|||
# Add two dists to the list. This isn't done in the code but we'd like to support it. |
|||
dist_list = [ |
|||
GaussianDistInstance(torch.zeros((1, 2)), torch.ones((1, 2))), |
|||
GaussianDistInstance(torch.zeros((1, 2)), torch.ones((1, 2))), |
|||
] |
|||
action_list = [torch.zeros((1, 2)), torch.zeros((1, 2))] |
|||
log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy( |
|||
action_list, dist_list |
|||
) |
|||
assert log_probs.shape == (1, 2, 2) |
|||
assert entropies.shape == (1, 2, 2) |
|||
assert all_probs is None |
|||
|
|||
for log_prob in log_probs.flatten(): |
|||
# Log prob of standard normal at 0 |
|||
assert log_prob == pytest.approx(-0.919, abs=0.01) |
|||
|
|||
for ent in entropies.flatten(): |
|||
# entropy of standard normal at 0 |
|||
assert ent == pytest.approx(1.42, abs=0.01) |
|||
|
|||
# Test continuous |
|||
# Add two dists to the list. |
|||
act_size = 2 |
|||
test_prob = torch.tensor( |
|||
[[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)] |
|||
) # High prob for first action |
|||
dist_list = [CategoricalDistInstance(test_prob), CategoricalDistInstance(test_prob)] |
|||
action_list = [torch.tensor([0]), torch.tensor([1])] |
|||
log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy( |
|||
action_list, dist_list |
|||
) |
|||
assert all_probs.shape == (1, len(dist_list * act_size)) |
|||
assert entropies.shape == (1, len(dist_list)) |
|||
# Make sure the first action has high probability than the others. |
|||
assert log_probs.flatten()[0] > log_probs.flatten()[1] |
|
|||
import numpy as np |
|||
import pytest |
|||
import torch |
|||
from mlagents.trainers.torch.components.reward_providers import ( |
|||
CuriosityRewardProvider, |
|||
create_reward_provider, |
|||
) |
|||
from mlagents_envs.base_env import BehaviorSpec, ActionType |
|||
from mlagents.trainers.settings import CuriositySettings, RewardSignalType |
|||
from mlagents.trainers.tests.torch.test_reward_providers.utils import ( |
|||
create_agent_buffer, |
|||
) |
|||
|
|||
SEED = [42] |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)), |
|||
], |
|||
) |
|||
def test_construction(behavior_spec: BehaviorSpec) -> None: |
|||
curiosity_settings = CuriositySettings(32, 0.01) |
|||
curiosity_settings.strength = 0.1 |
|||
curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings) |
|||
assert curiosity_rp.strength == 0.1 |
|||
assert curiosity_rp.name == "Curiosity" |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,), (64, 66, 1)], ActionType.DISCRETE, (2, 3)), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)), |
|||
], |
|||
) |
|||
def test_factory(behavior_spec: BehaviorSpec) -> None: |
|||
curiosity_settings = CuriositySettings(32, 0.01) |
|||
curiosity_rp = create_reward_provider( |
|||
RewardSignalType.CURIOSITY, behavior_spec, curiosity_settings |
|||
) |
|||
assert curiosity_rp.name == "Curiosity" |
|||
|
|||
|
|||
@pytest.mark.parametrize("seed", SEED) |
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)), |
|||
], |
|||
) |
|||
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None: |
|||
np.random.seed(seed) |
|||
torch.manual_seed(seed) |
|||
curiosity_settings = CuriositySettings(32, 0.01) |
|||
curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings) |
|||
buffer = create_agent_buffer(behavior_spec, 5) |
|||
curiosity_rp.update(buffer) |
|||
reward_old = curiosity_rp.evaluate(buffer)[0] |
|||
for _ in range(10): |
|||
curiosity_rp.update(buffer) |
|||
reward_new = curiosity_rp.evaluate(buffer)[0] |
|||
assert reward_new < reward_old |
|||
reward_old = reward_new |
|||
|
|||
|
|||
@pytest.mark.parametrize("seed", SEED) |
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", [BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5)] |
|||
) |
|||
def test_continuous_action_prediction(behavior_spec: BehaviorSpec, seed: int) -> None: |
|||
np.random.seed(seed) |
|||
torch.manual_seed(seed) |
|||
curiosity_settings = CuriositySettings(32, 0.1) |
|||
curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings) |
|||
buffer = create_agent_buffer(behavior_spec, 5) |
|||
for _ in range(200): |
|||
curiosity_rp.update(buffer) |
|||
prediction = curiosity_rp._network.predict_action(buffer)[0].detach() |
|||
target = buffer["actions"][0] |
|||
error = float(torch.mean((prediction - target) ** 2)) |
|||
assert error < 0.001 |
|||
|
|||
|
|||
@pytest.mark.parametrize("seed", SEED) |
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(10,), (64, 66, 3)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)), |
|||
], |
|||
) |
|||
def test_next_state_prediction(behavior_spec: BehaviorSpec, seed: int) -> None: |
|||
np.random.seed(seed) |
|||
torch.manual_seed(seed) |
|||
curiosity_settings = CuriositySettings(32, 0.1) |
|||
curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings) |
|||
buffer = create_agent_buffer(behavior_spec, 5) |
|||
for _ in range(100): |
|||
curiosity_rp.update(buffer) |
|||
prediction = curiosity_rp._network.predict_next_state(buffer)[0] |
|||
target = curiosity_rp._network.get_next_state(buffer)[0] |
|||
error = float(torch.mean((prediction - target) ** 2).detach()) |
|||
assert error < 0.001 |
|
|||
import pytest |
|||
from mlagents.trainers.torch.components.reward_providers import ( |
|||
ExtrinsicRewardProvider, |
|||
create_reward_provider, |
|||
) |
|||
from mlagents_envs.base_env import BehaviorSpec, ActionType |
|||
from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType |
|||
from mlagents.trainers.tests.torch.test_reward_providers.utils import ( |
|||
create_agent_buffer, |
|||
) |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)), |
|||
], |
|||
) |
|||
def test_construction(behavior_spec: BehaviorSpec) -> None: |
|||
settings = RewardSignalSettings() |
|||
settings.gamma = 0.2 |
|||
extrinsic_rp = ExtrinsicRewardProvider(behavior_spec, settings) |
|||
assert extrinsic_rp.gamma == 0.2 |
|||
assert extrinsic_rp.name == "Extrinsic" |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)), |
|||
], |
|||
) |
|||
def test_factory(behavior_spec: BehaviorSpec) -> None: |
|||
settings = RewardSignalSettings() |
|||
extrinsic_rp = create_reward_provider( |
|||
RewardSignalType.EXTRINSIC, behavior_spec, settings |
|||
) |
|||
assert extrinsic_rp.name == "Extrinsic" |
|||
|
|||
|
|||
@pytest.mark.parametrize("reward", [2.0, 3.0, 4.0]) |
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)), |
|||
], |
|||
) |
|||
def test_reward(behavior_spec: BehaviorSpec, reward: float) -> None: |
|||
buffer = create_agent_buffer(behavior_spec, 1000, reward) |
|||
settings = RewardSignalSettings() |
|||
extrinsic_rp = ExtrinsicRewardProvider(behavior_spec, settings) |
|||
generated_rewards = extrinsic_rp.evaluate(buffer) |
|||
assert (generated_rewards == reward).all() |
|
|||
from typing import Any |
|||
import numpy as np |
|||
import pytest |
|||
from unittest.mock import patch |
|||
import torch |
|||
import os |
|||
from mlagents.trainers.torch.components.reward_providers import ( |
|||
GAILRewardProvider, |
|||
create_reward_provider, |
|||
) |
|||
from mlagents_envs.base_env import BehaviorSpec, ActionType |
|||
from mlagents.trainers.settings import GAILSettings, RewardSignalType |
|||
from mlagents.trainers.tests.torch.test_reward_providers.utils import ( |
|||
create_agent_buffer, |
|||
) |
|||
from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import ( |
|||
DiscriminatorNetwork, |
|||
) |
|||
|
|||
CONTINUOUS_PATH = ( |
|||
os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir) |
|||
+ "/test.demo" |
|||
) |
|||
DISCRETE_PATH = ( |
|||
os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir) |
|||
+ "/testdcvis.demo" |
|||
) |
|||
SEED = [42] |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", [BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2)] |
|||
) |
|||
def test_construction(behavior_spec: BehaviorSpec) -> None: |
|||
gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH) |
|||
gail_rp = GAILRewardProvider(behavior_spec, gail_settings) |
|||
assert gail_rp.name == "GAIL" |
|||
|
|||
|
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", [BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2)] |
|||
) |
|||
def test_factory(behavior_spec: BehaviorSpec) -> None: |
|||
gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH) |
|||
gail_rp = create_reward_provider( |
|||
RewardSignalType.GAIL, behavior_spec, gail_settings |
|||
) |
|||
assert gail_rp.name == "GAIL" |
|||
|
|||
|
|||
@pytest.mark.parametrize("seed", SEED) |
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(8,), (24, 26, 1)], ActionType.CONTINUOUS, 2), |
|||
BehaviorSpec([(50,)], ActionType.DISCRETE, (2, 3, 3, 3)), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (20,)), |
|||
], |
|||
) |
|||
@pytest.mark.parametrize("use_actions", [False, True]) |
|||
@patch( |
|||
"mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer" |
|||
) |
|||
def test_reward_decreases( |
|||
demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int |
|||
) -> None: |
|||
np.random.seed(seed) |
|||
torch.manual_seed(seed) |
|||
buffer_expert = create_agent_buffer(behavior_spec, 1000) |
|||
buffer_policy = create_agent_buffer(behavior_spec, 1000) |
|||
demo_to_buffer.return_value = None, buffer_expert |
|||
gail_settings = GAILSettings( |
|||
demo_path="", learning_rate=0.05, use_vail=False, use_actions=use_actions |
|||
) |
|||
gail_rp = create_reward_provider( |
|||
RewardSignalType.GAIL, behavior_spec, gail_settings |
|||
) |
|||
|
|||
init_reward_expert = gail_rp.evaluate(buffer_expert)[0] |
|||
init_reward_policy = gail_rp.evaluate(buffer_policy)[0] |
|||
|
|||
for _ in range(10): |
|||
gail_rp.update(buffer_policy) |
|||
reward_expert = gail_rp.evaluate(buffer_expert)[0] |
|||
reward_policy = gail_rp.evaluate(buffer_policy)[0] |
|||
assert reward_expert >= 0 # GAIL / VAIL reward always positive |
|||
assert reward_policy >= 0 |
|||
reward_expert = gail_rp.evaluate(buffer_expert)[0] |
|||
reward_policy = gail_rp.evaluate(buffer_policy)[0] |
|||
assert reward_expert > reward_policy # Expert reward greater than non-expert reward |
|||
assert ( |
|||
reward_expert > init_reward_expert |
|||
) # Expert reward getting better as network trains |
|||
assert ( |
|||
reward_policy < init_reward_policy |
|||
) # Non-expert reward getting worse as network trains |
|||
|
|||
|
|||
@pytest.mark.parametrize("seed", SEED) |
|||
@pytest.mark.parametrize( |
|||
"behavior_spec", |
|||
[ |
|||
BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3, 3, 3)), |
|||
BehaviorSpec([(10,)], ActionType.DISCRETE, (20,)), |
|||
], |
|||
) |
|||
@pytest.mark.parametrize("use_actions", [False, True]) |
|||
@patch( |
|||
"mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer" |
|||
) |
|||
def test_reward_decreases_vail( |
|||
demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int |
|||
) -> None: |
|||
np.random.seed(seed) |
|||
torch.manual_seed(seed) |
|||
buffer_expert = create_agent_buffer(behavior_spec, 1000) |
|||
buffer_policy = create_agent_buffer(behavior_spec, 1000) |
|||
demo_to_buffer.return_value = None, buffer_expert |
|||
gail_settings = GAILSettings( |
|||
demo_path="", learning_rate=0.005, use_vail=True, use_actions=use_actions |
|||
) |
|||
DiscriminatorNetwork.initial_beta = 0.0 |
|||
# we must set the initial value of beta to 0 for testing |
|||
# If we do not, the kl-loss will dominate early and will block the estimator |
|||
gail_rp = create_reward_provider( |
|||
RewardSignalType.GAIL, behavior_spec, gail_settings |
|||
) |
|||
|
|||
for _ in range(100): |
|||
gail_rp.update(buffer_policy) |
|||
reward_expert = gail_rp.evaluate(buffer_expert)[0] |
|||
reward_policy = gail_rp.evaluate(buffer_policy)[0] |
|||
assert reward_expert >= 0 # GAIL / VAIL reward always positive |
|||
assert reward_policy >= 0 |
|||
reward_expert = gail_rp.evaluate(buffer_expert)[0] |
|||
reward_policy = gail_rp.evaluate(buffer_policy)[0] |
|||
assert reward_expert > reward_policy # Expert reward greater than non-expert reward |
|
|||
import numpy as np |
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents_envs.base_env import BehaviorSpec |
|||
from mlagents.trainers.trajectory import SplitObservations |
|||
|
|||
|
|||
def create_agent_buffer( |
|||
behavior_spec: BehaviorSpec, number: int, reward: float = 0.0 |
|||
) -> AgentBuffer: |
|||
buffer = AgentBuffer() |
|||
curr_observations = [ |
|||
np.random.normal(size=shape) for shape in behavior_spec.observation_shapes |
|||
] |
|||
next_observations = [ |
|||
np.random.normal(size=shape) for shape in behavior_spec.observation_shapes |
|||
] |
|||
action = behavior_spec.create_random_action(1)[0, :] |
|||
for _ in range(number): |
|||
curr_split_obs = SplitObservations.from_observations(curr_observations) |
|||
next_split_obs = SplitObservations.from_observations(next_observations) |
|||
for i, _ in enumerate(curr_split_obs.visual_observations): |
|||
buffer["visual_obs%d" % i].append(curr_split_obs.visual_observations[i]) |
|||
buffer["next_visual_obs%d" % i].append( |
|||
next_split_obs.visual_observations[i] |
|||
) |
|||
buffer["vector_obs"].append(curr_split_obs.vector_observations) |
|||
buffer["next_vector_in"].append(next_split_obs.vector_observations) |
|||
buffer["actions"].append(action) |
|||
buffer["done"].append(np.zeros(1, dtype=np.float32)) |
|||
buffer["reward"].append(np.ones(1, dtype=np.float32) * reward) |
|||
buffer["masks"].append(np.ones(1, dtype=np.float32)) |
|||
return buffer |
|
|||
import torch |
|||
|
|||
from mlagents.trainers.torch.layers import Swish, linear_layer, Initialization |
|||
|
|||
|
|||
def test_swish(): |
|||
layer = Swish() |
|||
input_tensor = torch.Tensor([[1, 2, 3], [4, 5, 6]]) |
|||
target_tensor = torch.mul(input_tensor, torch.sigmoid(input_tensor)) |
|||
assert torch.all(torch.eq(layer(input_tensor), target_tensor)) |
|||
|
|||
|
|||
def test_initialization_layer(): |
|||
torch.manual_seed(0) |
|||
# Test Zero |
|||
layer = linear_layer( |
|||
3, 4, kernel_init=Initialization.Zero, bias_init=Initialization.Zero |
|||
) |
|||
assert torch.all(torch.eq(layer.weight.data, torch.zeros_like(layer.weight.data))) |
|||
assert torch.all(torch.eq(layer.bias.data, torch.zeros_like(layer.bias.data))) |
|
|||
import pytest |
|||
|
|||
import torch |
|||
from mlagents.trainers.torch.networks import ( |
|||
NetworkBody, |
|||
ValueNetwork, |
|||
SimpleActor, |
|||
SharedActorCritic, |
|||
SeparateActorCritic, |
|||
) |
|||
from mlagents.trainers.settings import NetworkSettings |
|||
from mlagents_envs.base_env import ActionType |
|||
from mlagents.trainers.torch.distributions import ( |
|||
GaussianDistInstance, |
|||
CategoricalDistInstance, |
|||
) |
|||
|
|||
|
|||
def test_networkbody_vector(): |
|||
torch.manual_seed(0) |
|||
obs_size = 4 |
|||
network_settings = NetworkSettings() |
|||
obs_shapes = [(obs_size,)] |
|||
|
|||
networkbody = NetworkBody(obs_shapes, network_settings, encoded_act_size=2) |
|||
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3) |
|||
sample_obs = 0.1 * torch.ones((1, obs_size)) |
|||
sample_act = 0.1 * torch.ones((1, 2)) |
|||
|
|||
for _ in range(300): |
|||
encoded, _ = networkbody([sample_obs], [], sample_act) |
|||
assert encoded.shape == (1, network_settings.hidden_units) |
|||
# Try to force output to 1 |
|||
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape)) |
|||
optimizer.zero_grad() |
|||
loss.backward() |
|||
optimizer.step() |
|||
# In the last step, values should be close to 1 |
|||
for _enc in encoded.flatten(): |
|||
assert _enc == pytest.approx(1.0, abs=0.1) |
|||
|
|||
|
|||
def test_networkbody_lstm(): |
|||
torch.manual_seed(0) |
|||
obs_size = 4 |
|||
seq_len = 16 |
|||
network_settings = NetworkSettings( |
|||
memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=4) |
|||
) |
|||
obs_shapes = [(obs_size,)] |
|||
|
|||
networkbody = NetworkBody(obs_shapes, network_settings) |
|||
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3) |
|||
sample_obs = torch.ones((1, seq_len, obs_size)) |
|||
|
|||
for _ in range(100): |
|||
encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 4)) |
|||
# Try to force output to 1 |
|||
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape)) |
|||
optimizer.zero_grad() |
|||
loss.backward() |
|||
optimizer.step() |
|||
# In the last step, values should be close to 1 |
|||
for _enc in encoded.flatten(): |
|||
assert _enc == pytest.approx(1.0, abs=0.1) |
|||
|
|||
|
|||
def test_networkbody_visual(): |
|||
torch.manual_seed(0) |
|||
vec_obs_size = 4 |
|||
obs_size = (84, 84, 3) |
|||
network_settings = NetworkSettings() |
|||
obs_shapes = [(vec_obs_size,), obs_size] |
|||
torch.random.manual_seed(0) |
|||
|
|||
networkbody = NetworkBody(obs_shapes, network_settings) |
|||
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3) |
|||
sample_obs = torch.ones((1, 84, 84, 3)) |
|||
sample_vec_obs = torch.ones((1, vec_obs_size)) |
|||
|
|||
for _ in range(150): |
|||
encoded, _ = networkbody([sample_vec_obs], [sample_obs]) |
|||
assert encoded.shape == (1, network_settings.hidden_units) |
|||
# Try to force output to 1 |
|||
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape)) |
|||
optimizer.zero_grad() |
|||
loss.backward() |
|||
optimizer.step() |
|||
# In the last step, values should be close to 1 |
|||
for _enc in encoded.flatten(): |
|||
assert _enc == pytest.approx(1.0, abs=0.1) |
|||
|
|||
|
|||
def test_valuenetwork(): |
|||
torch.manual_seed(0) |
|||
obs_size = 4 |
|||
num_outputs = 2 |
|||
network_settings = NetworkSettings() |
|||
obs_shapes = [(obs_size,)] |
|||
|
|||
stream_names = [f"stream_name{n}" for n in range(4)] |
|||
value_net = ValueNetwork( |
|||
stream_names, obs_shapes, network_settings, outputs_per_stream=num_outputs |
|||
) |
|||
optimizer = torch.optim.Adam(value_net.parameters(), lr=3e-3) |
|||
|
|||
for _ in range(50): |
|||
sample_obs = torch.ones((1, obs_size)) |
|||
values, _ = value_net([sample_obs], []) |
|||
loss = 0 |
|||
for s_name in stream_names: |
|||
assert values[s_name].shape == (1, num_outputs) |
|||
# Try to force output to 1 |
|||
loss += torch.nn.functional.mse_loss( |
|||
values[s_name], torch.ones((1, num_outputs)) |
|||
) |
|||
|
|||
optimizer.zero_grad() |
|||
loss.backward() |
|||
optimizer.step() |
|||
# In the last step, values should be close to 1 |
|||
for value in values.values(): |
|||
for _out in value: |
|||
assert _out[0] == pytest.approx(1.0, abs=0.1) |
|||
|
|||
|
|||
@pytest.mark.parametrize("action_type", [ActionType.DISCRETE, ActionType.CONTINUOUS]) |
|||
def test_simple_actor(action_type): |
|||
obs_size = 4 |
|||
network_settings = NetworkSettings() |
|||
obs_shapes = [(obs_size,)] |
|||
act_size = [2] |
|||
masks = None if action_type == ActionType.CONTINUOUS else torch.ones((1, 1)) |
|||
actor = SimpleActor(obs_shapes, network_settings, action_type, act_size) |
|||
# Test get_dist |
|||
sample_obs = torch.ones((1, obs_size)) |
|||
dists, _ = actor.get_dists([sample_obs], [], masks=masks) |
|||
for dist in dists: |
|||
if action_type == ActionType.CONTINUOUS: |
|||
assert isinstance(dist, GaussianDistInstance) |
|||
else: |
|||
assert isinstance(dist, CategoricalDistInstance) |
|||
|
|||
# Test sample_actions |
|||
actions = actor.sample_action(dists) |
|||
for act in actions: |
|||
if action_type == ActionType.CONTINUOUS: |
|||
assert act.shape == (1, act_size[0]) |
|||
else: |
|||
assert act.shape == (1, 1) |
|||
|
|||
# Test forward |
|||
actions, probs, ver_num, mem_size, is_cont, act_size_vec = actor.forward( |
|||
[sample_obs], [], masks=masks |
|||
) |
|||
for act in actions: |
|||
if action_type == ActionType.CONTINUOUS: |
|||
assert act.shape == ( |
|||
act_size[0], |
|||
1, |
|||
) # This is different from above for ONNX export |
|||
else: |
|||
assert act.shape == (1, 1) |
|||
|
|||
# TODO: Once export works properly. fix the shapes here. |
|||
assert mem_size == 0 |
|||
assert is_cont == int(action_type == ActionType.CONTINUOUS) |
|||
assert act_size_vec == torch.tensor(act_size) |
|||
|
|||
|
|||
@pytest.mark.parametrize("ac_type", [SharedActorCritic, SeparateActorCritic]) |
|||
@pytest.mark.parametrize("lstm", [True, False]) |
|||
def test_actor_critic(ac_type, lstm): |
|||
obs_size = 4 |
|||
network_settings = NetworkSettings( |
|||
memory=NetworkSettings.MemorySettings() if lstm else None |
|||
) |
|||
obs_shapes = [(obs_size,)] |
|||
act_size = [2] |
|||
stream_names = [f"stream_name{n}" for n in range(4)] |
|||
actor = ac_type( |
|||
obs_shapes, network_settings, ActionType.CONTINUOUS, act_size, stream_names |
|||
) |
|||
if lstm: |
|||
sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size)) |
|||
memories = torch.ones( |
|||
( |
|||
1, |
|||
network_settings.memory.sequence_length, |
|||
network_settings.memory.memory_size, |
|||
) |
|||
) |
|||
else: |
|||
sample_obs = torch.ones((1, obs_size)) |
|||
memories = torch.tensor([]) |
|||
# memories isn't always set to None, the network should be able to |
|||
# deal with that. |
|||
# Test critic pass |
|||
value_out = actor.critic_pass([sample_obs], [], memories=memories) |
|||
for stream in stream_names: |
|||
if lstm: |
|||
assert value_out[stream].shape == (network_settings.memory.sequence_length,) |
|||
else: |
|||
assert value_out[stream].shape == (1,) |
|||
|
|||
# Test get_dist_and_value |
|||
dists, value_out, _ = actor.get_dist_and_value([sample_obs], [], memories=memories) |
|||
for dist in dists: |
|||
assert isinstance(dist, GaussianDistInstance) |
|||
for stream in stream_names: |
|||
if lstm: |
|||
assert value_out[stream].shape == (network_settings.memory.sequence_length,) |
|||
else: |
|||
assert value_out[stream].shape == (1,) |
|
|||
fileFormatVersion: 2 |
|||
guid: 771e78c5e980e440e8cd19716b55075f |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
-warnaserror+ |
|||
-warnaserror-:618 |
|
|||
fileFormatVersion: 2 |
|||
guid: 7c1189c0af42c46f7b533350d49ad3e7 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
import pytest |
|||
|
|||
from mlagents.trainers.tf.models import ModelUtils |
|||
from mlagents.tf_utils import tf |
|||
from mlagents_envs.base_env import BehaviorSpec, ActionType |
|||
|
|||
|
|||
def create_behavior_spec(num_visual, num_vector, vector_size): |
|||
behavior_spec = BehaviorSpec( |
|||
[(84, 84, 3)] * int(num_visual) + [(vector_size,)] * int(num_vector), |
|||
ActionType.DISCRETE, |
|||
(1,), |
|||
) |
|||
return behavior_spec |
|||
|
|||
|
|||
@pytest.mark.parametrize("num_visual", [1, 2, 4]) |
|||
@pytest.mark.parametrize("num_vector", [1, 2, 4]) |
|||
def test_create_input_placeholders(num_vector, num_visual): |
|||
vec_size = 8 |
|||
name_prefix = "test123" |
|||
bspec = create_behavior_spec(num_visual, num_vector, vec_size) |
|||
vec_in, vis_in = ModelUtils.create_input_placeholders( |
|||
bspec.observation_shapes, name_prefix=name_prefix |
|||
) |
|||
|
|||
assert isinstance(vis_in, list) |
|||
assert len(vis_in) == num_visual |
|||
assert isinstance(vec_in, tf.Tensor) |
|||
assert vec_in.get_shape().as_list()[1] == num_vector * 8 |
|||
|
|||
# Check names contain prefix and vis shapes are correct |
|||
for _vis in vis_in: |
|||
assert _vis.get_shape().as_list() == [None, 84, 84, 3] |
|||
assert _vis.name.startswith(name_prefix) |
|||
assert vec_in.name.startswith(name_prefix) |
|
|||
#!/usr/bin/env python3 |
|||
|
|||
import ast |
|||
import sys |
|||
import os |
|||
import re |
|||
import subprocess |
|||
from typing import List, Optional, Pattern |
|||
|
|||
RELEASE_PATTERN = re.compile(r"release_[0-9]+(_docs)*") |
|||
TRAINER_INIT_FILE = "ml-agents/mlagents/trainers/__init__.py" |
|||
|
|||
# Filename -> regex list to allow specific lines. |
|||
# To allow everything in the file, use None for the value |
|||
ALLOW_LIST = { |
|||
# Previous release table |
|||
"README.md": re.compile(r"\*\*Release [0-9]+\*\*"), |
|||
"docs/Versioning.md": None, |
|||
"com.unity.ml-agents/CHANGELOG.md": None, |
|||
"utils/make_readme_table.py": None, |
|||
"utils/validate_release_links.py": None, |
|||
} |
|||
|
|||
|
|||
def test_pattern(): |
|||
# Just some sanity check that the regex works as expected. |
|||
assert RELEASE_PATTERN.search( |
|||
"https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/Food.md" |
|||
) |
|||
assert RELEASE_PATTERN.search( |
|||
"https://github.com/Unity-Technologies/ml-agents/blob/release_4/Foo.md" |
|||
) |
|||
assert RELEASE_PATTERN.search( |
|||
"git clone --branch release_4 https://github.com/Unity-Technologies/ml-agents.git" |
|||
) |
|||
assert RELEASE_PATTERN.search( |
|||
"https://github.com/Unity-Technologies/ml-agents/blob/release_123_docs/Foo.md" |
|||
) |
|||
assert RELEASE_PATTERN.search( |
|||
"https://github.com/Unity-Technologies/ml-agents/blob/release_123/Foo.md" |
|||
) |
|||
assert not RELEASE_PATTERN.search( |
|||
"https://github.com/Unity-Technologies/ml-agents/blob/latest_release/docs/Foo.md" |
|||
) |
|||
print("tests OK!") |
|||
|
|||
|
|||
def git_ls_files() -> List[str]: |
|||
""" |
|||
Run "git ls-files" and return a list with one entry per line. |
|||
This returns the list of all files tracked by git. |
|||
""" |
|||
return subprocess.check_output(["git", "ls-files"], universal_newlines=True).split( |
|||
"\n" |
|||
) |
|||
|
|||
|
|||
def get_release_tag() -> Optional[str]: |
|||
""" |
|||
Returns the release tag for the mlagents python package. |
|||
This will be None on the master branch. |
|||
:return: |
|||
""" |
|||
with open(TRAINER_INIT_FILE) as f: |
|||
for line in f: |
|||
if "__release_tag__" in line: |
|||
lhs, equals_string, rhs = line.strip().partition(" = ") |
|||
# Evaluate the right hand side of the expression |
|||
return ast.literal_eval(rhs) |
|||
# If we couldn't find the release tag, raise an exception |
|||
# (since we can't return None here) |
|||
raise RuntimeError("Can't determine release tag") |
|||
|
|||
|
|||
def check_file(filename: str, global_allow_pattern: Pattern) -> List[str]: |
|||
""" |
|||
Validate a single file and return any offending lines. |
|||
""" |
|||
bad_lines = [] |
|||
with open(filename) as f: |
|||
for line in f: |
|||
if not RELEASE_PATTERN.search(line): |
|||
continue |
|||
|
|||
if global_allow_pattern.search(line): |
|||
continue |
|||
|
|||
if filename in ALLOW_LIST: |
|||
if ALLOW_LIST[filename] is None or ALLOW_LIST[filename].search(line): |
|||
continue |
|||
|
|||
bad_lines.append(f"{filename}: {line.strip()}") |
|||
return bad_lines |
|||
|
|||
|
|||
def check_all_files(allow_pattern: Pattern) -> List[str]: |
|||
""" |
|||
Validate all files tracked by git. |
|||
:param allow_pattern: |
|||
""" |
|||
bad_lines = [] |
|||
file_types = {".py", ".md", ".cs"} |
|||
for file_name in git_ls_files(): |
|||
if "localized" in file_name or os.path.splitext(file_name)[1] not in file_types: |
|||
continue |
|||
bad_lines += check_file(file_name, allow_pattern) |
|||
return bad_lines |
|||
|
|||
|
|||
def main(): |
|||
release_tag = get_release_tag() |
|||
if not release_tag: |
|||
print("Release tag is None, exiting") |
|||
sys.exit(0) |
|||
|
|||
print(f"Release tag: {release_tag}") |
|||
allow_pattern = re.compile(f"{release_tag}(_docs)*") |
|||
bad_lines = check_all_files(allow_pattern) |
|||
if bad_lines: |
|||
print( |
|||
f"Found lines referring to previous release. Either update the files, or add an exclusion to {__file__}" |
|||
) |
|||
for line in bad_lines: |
|||
print(line) |
|||
|
|||
sys.exit(1 if bad_lines else 0) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
if "--test" in sys.argv: |
|||
test_pattern() |
|||
main() |
部分文件因为文件数量过多而无法显示
撰写
预览
正在加载...
取消
保存
Reference in new issue