
GAIL and Pretraining (#2118)

Based on the new reward signals architecture, add BC pretrainer and GAIL for PPO. Main changes:

- A new GAILRewardSignal and GAILModel for GAIL/VAIL
- A BCModule component (not a reward signal) to do pretraining during RL
- Documentation for both of these
- Change to Demo Loader that lets you load multiple demo files in a folder
- Example Demo files for all of our tested sample environments (for future regression testing)
/develop-generalizationTraining-TrainerController
GitHub, 5 years ago
Current commit
9c50abcf
44 files changed, with 15,563 insertions and 155 deletions
  1. 125  docs/Training-Imitation-Learning.md
  2. 70  docs/Training-PPO.md
  3. 99  docs/Training-RewardSignals.md
  4. 7  ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py
  5. 2  ml-agents/mlagents/trainers/components/reward_signals/reward_signal_factory.py
  6. 37  ml-agents/mlagents/trainers/demo_loader.py
  7. 14  ml-agents/mlagents/trainers/ppo/policy.py
  8. 4  ml-agents/mlagents/trainers/ppo/trainer.py
  9. 49  ml-agents/mlagents/trainers/tests/mock_brain.py
  10. 13  ml-agents/mlagents/trainers/tests/test_demo_loader.py
  11. 154  ml-agents/mlagents/trainers/tests/test_reward_signals.py
  12. 92  docs/Training-BehavioralCloning.md
  13. 80  docs/images/mlagents-ImitationAndRL.png
  14. 158  ml-agents/mlagents/trainers/tests/test_bcmodule.py
  15. 1001  ml-agents/mlagents/trainers/tests/testdcvis.demo
  16. 442  demos/Expert3DBall.demo
  17. 1001  demos/Expert3DBallHard.demo
  18. 1001  demos/ExpertBanana.demo
  19. 171  demos/ExpertBasic.demo
  20. 198  demos/ExpertBouncer.demo
  21. 1001  demos/ExpertCrawlerSta.demo
  22. 1001  demos/ExpertGrid.demo
  23. 1001  demos/ExpertHallway.demo
  24. 1001  demos/ExpertPush.demo
  25. 1001  demos/ExpertPyramid.demo
  26. 1001  demos/ExpertReacher.demo
  27. 1001  demos/ExpertSoccerGoal.demo
  28. 1001  demos/ExpertSoccerStri.demo
  29. 1001  demos/ExpertTennis.demo
  30. 1001  demos/ExpertWalker.demo
  31. 1  ml-agents/mlagents/trainers/components/bc/__init__.py
  32. 101  ml-agents/mlagents/trainers/components/bc/model.py
  33. 172  ml-agents/mlagents/trainers/components/bc/module.py
  34. 1  ml-agents/mlagents/trainers/components/reward_signals/gail/__init__.py
  35. 265  ml-agents/mlagents/trainers/components/reward_signals/gail/model.py
  36. 270  ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py
  37. 60  ml-agents/mlagents/trainers/tests/test_demo_dir/test.demo
  38. 60  ml-agents/mlagents/trainers/tests/test_demo_dir/test2.demo
  39. 60  ml-agents/mlagents/trainers/tests/test_demo_dir/test3.demo

125
docs/Training-Imitation-Learning.md


Imitation Learning uses pairs of observations and actions from
a demonstration to learn a policy. [Video Link](https://youtu.be/kpb8ZkMBFYs).
Imitation learning can also be used to help reinforcement learning. Especially in
environments with sparse (i.e., infrequent or rare) rewards, the agent may never see
the reward and thus not learn from it. Curiosity helps the agent explore, but in some cases
it is easier to just show the agent how to achieve the reward. In these cases,
imitation learning can dramatically reduce the time it takes to solve the environment.
For instance, on the [Pyramids environment](Learning-Environment-Examples.md#pyramids),
just 6 episodes of demonstrations can reduce the number of training steps required by more than a factor of four.
<p align="center">
<img src="images/mlagents-ImitationAndRL.png"
alt="Using Demonstrations with Reinforcement Learning"
width="350" border="0" />
</p>
ML-Agents provides several ways to learn from demonstrations. For most situations,
[GAIL](Training-RewardSignals.md#the-gail-reward-signal) is the preferred approach.
* To train using GAIL (Generative Adversarial Imitation Learning) you can add the
[GAIL reward signal](Training-RewardSignals.md#the-gail-reward-signal). GAIL can be
used with or without environment rewards, and works well when there are a limited
number of demonstrations.
* To help bootstrap reinforcement learning, you can enable
[pretraining](Training-PPO.md#optional-pretraining-using-demonstrations)
on the PPO trainer, in addition to using a small GAIL reward signal (see the
example configuration after this list).
* To train an agent to exactly mimic demonstrations, you can use the
[Behavioral Cloning](Training-BehavioralCloning.md) trainer. Behavioral Cloning can be
used offline and online (in-editor), and learns very quickly. However, it is usually
ineffective in more complex environments without a large number of demonstrations.
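For example, a minimal sketch of the second approach above (pretraining plus a small GAIL
reward signal) as an entry in the trainer config. The Brain name `PyramidsLearning` and all
values here are illustrative assumptions; each option is documented in the pages linked above.

```
PyramidsLearning:
    trainer: ppo
    reward_signals:
        extrinsic:
            strength: 1.0
            gamma: 0.99
        gail:
            strength: 0.01
            gamma: 0.99
            demo_path: demos/ExpertPyramid.demo
    pretraining:
        demo_path: demos/ExpertPyramid.demo
        strength: 0.5
        steps: 10000
```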
## Recording Demonstrations
It is possible to record demonstrations of agent behavior from the Unity Editor,

alt="BC Teacher Helper"
width="375" border="10" />
</p>
## Training with Behavioral Cloning
There are a variety of possible imitation learning algorithms which can
be used; the simplest of them is Behavioral Cloning. It works by collecting
demonstrations from a teacher, and then simply using them to directly learn a
policy, in the same way that supervised learning works for image classification
or other traditional Machine Learning tasks.
### Offline Training
With offline behavioral cloning, we can use demonstrations (`.demo` files)
generated using the `Demonstration Recorder` as the dataset used to train a behavior.
1. Choose an agent you would like to train to imitate a set of demonstrations.
2. Record a set of demonstrations using the `Demonstration Recorder` (see above).
For illustrative purposes we will refer to this file as `AgentRecording.demo`.
3. Build the scene, assigning the agent a Learning Brain, and set the Brain to
Control in the Broadcast Hub. For more information on Brains, see
[here](Learning-Environment-Design-Brains.md).
4. Open the `config/offline_bc_config.yaml` file.
5. Modify the `demo_path` parameter in the file to reference the path to the
demonstration file recorded in step 2. In our case this is:
`./UnitySDK/Assets/Demonstrations/AgentRecording.demo`
6. Launch `mlagents-learn`, providing `./config/offline_bc_config.yaml`
as the config parameter, and include the `--run-id` and `--train` flags as usual.
Provide your environment as the `--env` parameter if it has been compiled
as standalone, or omit it to train in the Editor.
7. (Optional) Observe training performance using TensorBoard.
This will use the demonstration file to train a neural network driven agent
to directly imitate the actions provided in the demonstration. The environment
will launch and be used for evaluating the agent's performance during training.
### Online Training
It is also possible to provide demonstrations in realtime during training,
without pre-recording a demonstration file. The steps to do this are as follows:
1. First create two Brains, one which will be the "Teacher," and the other which
will be the "Student." We will assume that the names of the Brain
Assets are "Teacher" and "Student" respectively.
2. The "Teacher" Brain must be a **Player Brain**. You must properly
configure the inputs to map to the corresponding actions.
3. The "Student" Brain must be a **Learning Brain**.
4. The Brain Parameters of both the "Teacher" and "Student" Brains must be
compatible with the agent.
5. Drag both the "Teacher" and "Student" Brain into the Academy's `Broadcast Hub`
and check the `Control` checkbox on the "Student" Brain.
6. Link the Brains to the desired Agents (one Agent as the teacher and at least
one Agent as a student).
7. In `config/online_bc_config.yaml`, add an entry for the "Student" Brain. Set
the `trainer` parameter of this entry to `online_bc`, and the
`brain_to_imitate` parameter to the name of the teacher Brain: "Teacher".
Additionally, set `batches_per_epoch`, which controls how much training is done
per update. Increase the `max_steps` option if you'd like to keep training
the Agents for a longer period of time.
8. Launch the training process with `mlagents-learn config/online_bc_config.yaml
--train --slow`, and press the :arrow_forward: button in Unity when the
message _"Start training by pressing the Play button in the Unity Editor"_ is
displayed on the screen
9. From the Unity window, control the Agent with the Teacher Brain by providing
"teacher demonstrations" of the behavior you would like to see.
10. Watch as the Agent(s) with the student Brain attached begin to behave
similarly to the demonstrations.
11. Once the Student Agents are exhibiting the desired behavior, end the training
process with `CTRL+C` from the command line.
12. Move the resulting `*.nn` file into the `TFModels` subdirectory of the
Assets folder (or a subdirectory within Assets of your choosing), and use it
with a `Learning` Brain.
**BC Teacher Helper**
We provide a convenience utility, the `BC Teacher Helper` component, that you can add
to the Teacher Agent.
<p align="center">
<img src="images/bc_teacher_helper.png"
alt="BC Teacher Helper"
width="375" border="10" />
</p>
This utility enables you to use keyboard shortcuts to do the following:
1. Start and stop recording experiences. This is useful in case you'd like to
interact with the game _but not have the agents learn from these
interactions_. The default command to toggle this is to press `R` on the
keyboard.
2. Reset the training buffer. This enables you to instruct the agents to forget
their buffer of recent experiences. This is useful if you'd like to get them
to quickly learn a new behavior. The default command to reset the buffer is
to press `C` on the keyboard.

70
docs/Training-PPO.md


presented to an agent, see [Training with Curriculum
Learning](Training-Curriculum-Learning.md).
For information about imitation learning, which uses a different training
algorithm, see
For information about imitation learning from demonstrations, see
[Training with Imitation Learning](Training-Imitation-Learning.md).
## Best Practices when training with PPO

the agent will need to remember in order to successfully complete the task.
Typical Range: `64` - `512`
## (Optional) Pretraining Using Demonstrations
In some cases, you might want to bootstrap the agent's policy using behavior recorded
from a player. This can help guide the agent towards the reward. Pretraining adds
training operations that mimic a demonstration rather than attempting to maximize reward.
It is essentially equivalent to running [behavioral cloning](./Training-BehavioralCloning.md)
in-line with PPO.
To use pretraining, add a `pretraining` section to the trainer_config. For instance:
```
pretraining:
    demo_path: ./demos/ExpertPyramid.demo
    strength: 0.5
    steps: 10000
```
Below are the available hyperparameters for pretraining.
### Strength
`strength` corresponds to the learning rate of the imitation relative to the learning
rate of PPO, and roughly corresponds to how strongly we allow the behavioral cloning
to influence the policy.
Typical Range: `0.1` - `0.5`
### Demo Path
`demo_path` is the path to your `.demo` file or directory of `.demo` files.
See the [imitation learning guide](Training-Imitation-Learning.md) for more on `.demo` files.
### Steps
During pretraining, it is often desirable to stop using demonstrations after the agent has
"seen" rewards, and allow it to optimize past the available demonstrations and/or generalize
outside of the provided demonstrations. `steps` corresponds to the training steps over which
pretraining is active. The learning rate of the pretrainer will anneal over the steps. Set
the steps to 0 for constant imitation over the entire training run.
### (Optional) Batch Size
`batch_size` is the number of demonstration experiences used for one iteration of a gradient
descent update. If not specified, it will default to the `batch_size` defined for PPO.
Typical Range (Continuous): `512` - `5120`
Typical Range (Discrete): `32` - `512`
### (Optional) Number of Epochs
`num_epoch` is the number of passes through the experience buffer during
gradient descent. If not specified, it will default to the number of epochs set for PPO.
Typical Range: `3` - `10`
### (Optional) Samples Per Update
`samples_per_update` is the maximum number of samples
to use during each imitation update. You may want to lower this if your demonstration
dataset is very large to avoid overfitting the policy on demonstrations. Set to 0
to train over all of the demonstrations at each update step.
Default Value: `0` (all)
Typical Range: Approximately equal to PPO's `buffer_size`
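Putting the optional settings together, a fuller `pretraining` section might look like the
sketch below. All values are illustrative and should be tuned for your environment.

```
pretraining:
    demo_path: ./demos/ExpertPyramid.demo
    strength: 0.5
    steps: 10000
    batch_size: 512
    num_epoch: 3
    samples_per_update: 0
```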
## Training Statistics

99
docs/Training-RewardSignals.md


observation, but not so small that it cannot learn to differentiate between
demonstrated and actual behavior.
Default Value: 64
Default Value: `64`
Typical Range: `64` - `256`
#### Learning Rate

Default Value: `3e-4`
### The GAIL Reward Signal
GAIL, or [Generative Adversarial Imitation Learning](https://arxiv.org/abs/1606.03476), is an
imitation learning algorithm that uses an adversarial approach, in a similar vein to GANs
(Generative Adversarial Networks). In this framework, a second neural network, the
discriminator, is taught to distinguish whether an observation/action is from a demonstration, or
produced by the agent. This discriminator can then examine a new observation/action and provide it a
reward based on how close it believes this new observation/action is to the provided demonstrations.
At each training step, the agent tries to learn how to maximize this reward. Then, the
discriminator is trained to better distinguish between demonstrations and agent state/actions.
In this way, while the agent gets better and better at mimicking the demonstrations, the
discriminator keeps getting stricter and stricter and the agent must try harder to "fool" it.
This approach, when compared to [Behavioral Cloning](Training-BehavioralCloning.md), requires
far fewer demonstrations to be provided. After all, we are still learning a policy that happens
to be similar to the demonstrations, rather than directly copying their behavior. It
is also especially effective when combined with an Extrinsic signal, but can also be used
independently to purely learn from demonstration.
Using GAIL requires recorded demonstrations from your Unity environment. See the
[imitation learning guide](Training-Imitation-Learning.md) to learn more about recording demonstrations.
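As a sketch, a `gail` entry under `reward_signals` in the trainer config might look like the
following. The values are illustrative assumptions; each hyperparameter is described below.

```
reward_signals:
    gail:
        strength: 0.01
        gamma: 0.9
        demo_path: demos/ExpertPyramid.demo
        encoding_size: 64
        use_actions: false
        use_vail: false
```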
#### Strength
`strength` is the factor by which to multiply the raw reward. Note that when using GAIL
with an Extrinsic Signal, this value should be set lower if your demonstrations are
suboptimal (e.g. from a human), so that a trained agent will focus on receiving extrinsic
rewards instead of exactly copying the demonstrations. Keep the strength below about 0.1 in those cases.
Typical Range: `0.01` - `1.0`
#### Gamma
`gamma` corresponds to the discount factor for future rewards.
Typical Range: `0.8` - `0.9`
#### Demo Path
`demo_path` is the path to your `.demo` file or directory of `.demo` files. See the
[imitation learning guide](Training-Imitation-Learning.md).
#### Encoding Size
`encoding_size` corresponds to the size of the hidden layer used by the discriminator.
This value should be small enough to encourage the discriminator to compress the original
observation, but not so small that it cannot learn to differentiate between
demonstrated and actual behavior. Dramatically increasing this size will also negatively affect
training times.
Default Value: `64`
Typical Range: `64` - `256`
#### Learning Rate
`learning_rate` is the learning rate used to update the discriminator.
This should typically be decreased if training is unstable and the GAIL loss fluctuates.
Default Value: `3e-4`
Typical Range: `1e-5` - `1e-3`
#### Use Actions
`use_actions` determines whether the discriminator should discriminate based on both
observations and actions, or just observations. Set to `True` if you want the agent to
mimic the actions from the demonstrations, and `False` if you'd rather have the agent
visit the same states as in the demonstrations but with possibly different actions.
Setting to `False` is more likely to be stable, especially with imperfect demonstrations,
but may learn slower.
Default Value: `false`
#### (Optional) Samples Per Update
`samples_per_update` is the maximum number of samples to use during each discriminator update. You may
want to lower this if your buffer size is very large to avoid overfitting the discriminator on current data.
If set to 0, we will use the minimum of buffer size and the number of demonstration samples.
Default Value: `0`
Typical Range: Approximately equal to [`buffer_size`](Training-PPO.md)
#### (Optional) Variational Discriminator Bottleneck
`use_vail` enables a [variational bottleneck](https://arxiv.org/abs/1810.00821) within the
GAIL discriminator. This forces the discriminator to learn a more general representation
and reduces its tendency to be "too good" at discriminating, making learning more stable.
However, it does increase training time. Enable this if you notice your imitation learning is
unstable, or unable to learn the task at hand.
Default Value: `false`

7
ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py


"""
Creates the Curiosity reward generator
:param policy: The Learning Policy
:param encoding_size: The size of the Curiosity encoding
:param signal_strength: The scaling parameter for the reward. The scaled reward will be the unscaled
:param strength: The scaling parameter for the reward. The scaled reward will be the unscaled
:param gamma: The time discounting factor used for this reward.
:param encoding_size: The size of the hidden encoding layer for the ICM
:param learning_rate: The learning rate for the ICM.
:param num_epoch: The number of epochs to train over the training buffer for the ICM.
"""
super().__init__(policy, strength, gamma)
self.model = CuriosityModel(

2
ml-agents/mlagents/trainers/components/reward_signals/reward_signal_factory.py


from mlagents.trainers.components.reward_signals.extrinsic.signal import (
ExtrinsicRewardSignal,
)
from mlagents.trainers.components.reward_signals.gail.signal import GAILRewardSignal
from mlagents.trainers.components.reward_signals.curiosity.signal import (
CuriosityRewardSignal,
)

NAME_TO_CLASS: Dict[str, Type[RewardSignal]] = {
"extrinsic": ExtrinsicRewardSignal,
"curiosity": CuriosityRewardSignal,
"gail": GAILRewardSignal,
}

37
ml-agents/mlagents/trainers/demo_loader.py


import pathlib
import logging
import os
from typing import List, Tuple
from mlagents.envs.communicator_objects import *
from mlagents.envs.communicator_objects import (
AgentInfoProto,
BrainParametersProto,
DemonstrationMetaProto,
)
from google.protobuf.internal.decoder import _DecodeVarint32 # type: ignore

def make_demo_buffer(brain_infos, brain_params, sequence_length):
def make_demo_buffer(
brain_infos: List[BrainInfo], brain_params: BrainParameters, sequence_length: int
) -> Buffer:
# Create and populate buffer using experiences
demo_buffer = Buffer()
for idx, experience in enumerate(brain_infos):

return demo_buffer
def demo_to_buffer(file_path, sequence_length):
def demo_to_buffer(
file_path: str, sequence_length: int
) -> Tuple[BrainParameters, Buffer]:
"""
Loads demonstration file and uses it to fill training buffer.
:param file_path: Location of demonstration file (.demo).

return brain_params, demo_buffer
def load_demonstration(file_path):
def load_demonstration(file_path: str) -> Tuple[BrainParameters, List[BrainInfo], int]:
"""
Loads and parses a demonstration file.
:param file_path: Location of demonstration file (.demo).

all_files = os.listdir(file_path)
for _file in all_files:
if _file.endswith(".demo"):
file_paths.append(_file)
file_paths.append(os.path.join(file_path, _file))
if not all_files:
raise ValueError("There are no '.demo' files in the provided directory.")
file_extension = pathlib.Path(file_path).suffix
if file_extension != ".demo":
raise ValueError(
"The file is not a '.demo' file. Please provide a file with the "
"correct extension."
)
file_extension = pathlib.Path(file_path).suffix
if file_extension != ".demo":
raise ValueError(
"The file is not a '.demo' file. Please provide a file with the "
"correct extension."
)
total_expected = 0
total_expected = 0
total_expected = meta_data_proto.number_steps
total_expected += meta_data_proto.number_steps
pos = INITIAL_POS
if obs_decoded == 1:
brain_param_proto = BrainParametersProto()

14
ml-agents/mlagents/trainers/ppo/policy.py


from mlagents.trainers.components.reward_signals.reward_signal_factory import (
create_reward_signal,
)
from mlagents.trainers.components.bc.module import BCModule
logger = logging.getLogger("mlagents.trainers")

self.reward_signals[reward_signal] = create_reward_signal(
self, reward_signal, config
)
# Create pretrainer if needed
if "pretraining" in trainer_params:
BCModule.check_config(trainer_params["pretraining"])
self.bc_module = BCModule(
self,
policy_learning_rate=trainer_params["learning_rate"],
default_batch_size=trainer_params["batch_size"],
default_num_epoch=trainer_params["num_epoch"],
**trainer_params["pretraining"],
)
else:
self.bc_module = None
if load:
self._load_graph()

4
ml-agents/mlagents/trainers/ppo/trainer.py


)
for stat, val in update_stats.items():
self.stats[stat].append(val)
if self.policy.bc_module:
update_stats = self.policy.bc_module.update()
for stat, val in update_stats.items():
self.stats[stat].append(val)
self.training_buffer.reset_update_buffer()
self.trainer_metrics.end_policy_update()

49
ml-agents/mlagents/trainers/tests/mock_brain.py


import pytest
import numpy as np
from mlagents.trainers.buffer import Buffer
def create_mock_brainparams(
number_visual_observations=0,

mock_env.return_value.brain_names = ["MockBrain"]
mock_env.return_value.reset.return_value = {"MockBrain": mock_braininfo}
mock_env.return_value.step.return_value = {"MockBrain": mock_braininfo}
def simulate_rollout(env, policy, buffer_init_samples):
brain_info_list = []
for i in range(buffer_init_samples):
brain_info_list.append(env.step()[env.brain_names[0]])
buffer = create_buffer(brain_info_list, policy.brain, policy.sequence_length)
return buffer
def create_buffer(brain_infos, brain_params, sequence_length):
buffer = Buffer()
# Make a buffer
for idx, experience in enumerate(brain_infos):
if idx > len(brain_infos) - 2:
break
current_brain_info = brain_infos[idx]
next_brain_info = brain_infos[idx + 1]
buffer[0].last_brain_info = current_brain_info
buffer[0]["done"].append(next_brain_info.local_done[0])
buffer[0]["rewards"].append(next_brain_info.rewards[0])
for i in range(brain_params.number_visual_observations):
buffer[0]["visual_obs%d" % i].append(
current_brain_info.visual_observations[i][0]
)
buffer[0]["next_visual_obs%d" % i].append(
current_brain_info.visual_observations[i][0]
)
if brain_params.vector_observation_space_size > 0:
buffer[0]["vector_obs"].append(current_brain_info.vector_observations[0])
buffer[0]["next_vector_in"].append(
current_brain_info.vector_observations[0]
)
buffer[0]["actions"].append(next_brain_info.previous_vector_actions[0])
buffer[0]["prev_action"].append(current_brain_info.previous_vector_actions[0])
buffer[0]["masks"].append(1.0)
buffer[0]["advantages"].append(1.0)
buffer[0]["action_probs"].append(np.ones(buffer[0]["actions"][0].shape))
buffer[0]["actions_pre"].append(np.ones(buffer[0]["actions"][0].shape))
buffer[0]["random_normal_epsilon"].append(
np.ones(buffer[0]["actions"][0].shape)
)
buffer[0]["action_mask"].append(np.ones(buffer[0]["actions"][0].shape))
buffer[0]["memory"].append(np.ones(8))
buffer.append_update_buffer(0, batch_size=None, training_length=sequence_length)
return buffer

13
ml-agents/mlagents/trainers/tests/test_demo_loader.py


demo_buffer = make_demo_buffer(brain_infos, brain_parameters, 1)
assert len(demo_buffer.update_buffer["actions"]) == total_expected - 1
def test_load_demo_dir():
path_prefix = os.path.dirname(os.path.abspath(__file__))
brain_parameters, brain_infos, total_expected = load_demonstration(
path_prefix + "/test_demo_dir"
)
assert brain_parameters.brain_name == "Ball3DBrain"
assert brain_parameters.vector_observation_space_size == 8
assert len(brain_infos) == total_expected
demo_buffer = make_demo_buffer(brain_infos, brain_parameters, 1)
assert len(demo_buffer.update_buffer["actions"]) == total_expected - 1

154
ml-agents/mlagents/trainers/tests/test_reward_signals.py


from mlagents.trainers.ppo.models import PPOModel
from mlagents.trainers.ppo.trainer import discount_rewards
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.demo_loader import make_demo_buffer
from mlagents.envs import UnityEnvironment
from mlagents.envs.mock_communicator import MockCommunicator

@pytest.fixture
def gail_dummy_config():
return {
"gail": {
"strength": 0.1,
"gamma": 0.9,
"encoding_size": 128,
"demo_path": os.path.dirname(os.path.abspath(__file__)) + "/test.demo",
}
}
@pytest.fixture
VECTOR_ACTION_SPACE = [2]
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [2]
BUFFER_INIT_SAMPLES = 20
NUM_AGENTS = 12
def create_ppo_policy_mock(

if not use_visual:
mock_brain = mb.create_mock_brainparams(
vector_action_space_type="discrete" if use_discrete else "continuous",
vector_action_space_size=[2],
vector_observation_space_size=8,
vector_action_space_size=DISCRETE_ACTION_SPACE
if use_discrete
else VECTOR_ACTION_SPACE,
vector_observation_space_size=VECTOR_OBS_SPACE,
num_agents=12,
num_vector_observations=8,
num_vector_acts=2,
num_agents=NUM_AGENTS,
num_vector_observations=VECTOR_OBS_SPACE,
num_vector_acts=sum(
DISCRETE_ACTION_SPACE if use_discrete else VECTOR_ACTION_SPACE
),
vector_action_space_size=[2],
vector_action_space_size=DISCRETE_ACTION_SPACE
if use_discrete
else VECTOR_ACTION_SPACE,
num_agents=12,
num_agents=NUM_AGENTS,
num_vector_acts=2,
num_vector_acts=sum(
DISCRETE_ACTION_SPACE if use_discrete else VECTOR_ACTION_SPACE
),
discrete=use_discrete,
)
mb.setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)

return env, policy
@mock.patch("mlagents.envs.UnityEnvironment")
def test_curiosity_cc_evaluate(mock_env, dummy_config, curiosity_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, curiosity_dummy_config, False, False, False
)
def reward_signal_eval(env, policy, reward_signal_name):
scaled_reward, unscaled_reward = policy.reward_signals["curiosity"].evaluate(
# Test evaluate
rsig_result = policy.reward_signals[reward_signal_name].evaluate(
assert scaled_reward.shape == (12,)
assert unscaled_reward.shape == (12,)
assert rsig_result.scaled_reward.shape == (NUM_AGENTS,)
assert rsig_result.unscaled_reward.shape == (NUM_AGENTS,)
def reward_signal_update(env, policy, reward_signal_name):
buffer = mb.simulate_rollout(env, policy, BUFFER_INIT_SAMPLES)
out = policy.reward_signals[reward_signal_name].update(buffer.update_buffer, 2)
assert type(out) is dict
def test_curiosity_dc_evaluate(mock_env, dummy_config, curiosity_dummy_config):
def test_gail_cc(mock_env, dummy_config, gail_dummy_config):
mock_env, dummy_config, curiosity_dummy_config, False, True, False
mock_env, dummy_config, gail_dummy_config, False, False, False
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]
next_brain_info = env.step()[env.brain_names[0]]
scaled_reward, unscaled_reward = policy.reward_signals["curiosity"].evaluate(
brain_info, next_brain_info
reward_signal_eval(env, policy, "gail")
reward_signal_update(env, policy, "gail")
@mock.patch("mlagents.envs.UnityEnvironment")
def test_gail_dc(mock_env, dummy_config, gail_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, gail_dummy_config, False, True, False
assert scaled_reward.shape == (12,)
assert unscaled_reward.shape == (12,)
reward_signal_eval(env, policy, "gail")
reward_signal_update(env, policy, "gail")
@mock.patch("mlagents.envs.UnityEnvironment")
def test_gail_visual(mock_env, dummy_config, gail_dummy_config):
gail_dummy_config["gail"]["demo_path"] = (
os.path.dirname(os.path.abspath(__file__)) + "/testdcvis.demo"
)
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, gail_dummy_config, False, True, True
)
reward_signal_eval(env, policy, "gail")
reward_signal_update(env, policy, "gail")
@mock.patch("mlagents.envs.UnityEnvironment")
def test_gail_rnn(mock_env, dummy_config, gail_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, gail_dummy_config, True, False, False
)
reward_signal_eval(env, policy, "gail")
reward_signal_update(env, policy, "gail")
@mock.patch("mlagents.envs.UnityEnvironment")
def test_curiosity_cc(mock_env, dummy_config, curiosity_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, curiosity_dummy_config, False, False, False
)
reward_signal_eval(env, policy, "curiosity")
reward_signal_update(env, policy, "curiosity")
@mock.patch("mlagents.envs.UnityEnvironment")
def test_curiosity_dc(mock_env, dummy_config, curiosity_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, curiosity_dummy_config, False, True, False
)
reward_signal_eval(env, policy, "curiosity")
reward_signal_update(env, policy, "curiosity")
def test_curiosity_visual_evaluate(mock_env, dummy_config, curiosity_dummy_config):
def test_curiosity_visual(mock_env, dummy_config, curiosity_dummy_config):
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]
next_brain_info = env.step()[env.brain_names[0]]
scaled_reward, unscaled_reward = policy.reward_signals["curiosity"].evaluate(
brain_info, next_brain_info
)
assert scaled_reward.shape == (12,)
assert unscaled_reward.shape == (12,)
reward_signal_eval(env, policy, "curiosity")
reward_signal_update(env, policy, "curiosity")
def test_curiosity_rnn_evaluate(mock_env, dummy_config, curiosity_dummy_config):
def test_curiosity_rnn(mock_env, dummy_config, curiosity_dummy_config):
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]
next_brain_info = env.step()[env.brain_names[0]]
scaled_reward, unscaled_reward = policy.reward_signals["curiosity"].evaluate(
brain_info, next_brain_info
reward_signal_eval(env, policy, "curiosity")
reward_signal_update(env, policy, "curiosity")
@mock.patch("mlagents.envs.UnityEnvironment")
def test_extrinsic(mock_env, dummy_config, curiosity_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, curiosity_dummy_config, False, False, False
assert scaled_reward.shape == (12,)
assert unscaled_reward.shape == (12,)
reward_signal_eval(env, policy, "extrinsic")
reward_signal_update(env, policy, "extrinsic")
if __name__ == "__main__":

92
docs/Training-BehavioralCloning.md


# Training with Behavioral Cloning
There are a variety of possible imitation learning algorithms which can
be used; the simplest of them is Behavioral Cloning. It works by collecting
demonstrations from a teacher, and then simply using them to directly learn a
policy, in the same way that supervised learning works for image classification
or other traditional Machine Learning tasks.
## Offline Training
With offline behavioral cloning, we can use demonstrations (`.demo` files)
generated using the `Demonstration Recorder` as the dataset used to train a behavior.
1. Choose an agent you would like to train to imitate a set of demonstrations.
2. Record a set of demonstrations using the `Demonstration Recorder` (see [here](Training-Imitation-Learning.md)).
For illustrative purposes we will refer to this file as `AgentRecording.demo`.
3. Build the scene, assigning the agent a Learning Brain, and set the Brain to
Control in the Broadcast Hub. For more information on Brains, see
[here](Learning-Environment-Design-Brains.md).
4. Open the `config/offline_bc_config.yaml` file.
5. Modify the `demo_path` parameter in the file to reference the path to the
demonstration file recorded in step 2 (see the excerpt after these steps). In our case this is:
`./UnitySDK/Assets/Demonstrations/AgentRecording.demo`
6. Launch `mlagents-learn`, providing `./config/offline_bc_config.yaml`
as the config parameter, and include the `--run-id` and `--train` flags as usual.
Provide your environment as the `--env` parameter if it has been compiled
as standalone, or omit it to train in the Editor.
7. (Optional) Observe training performance using TensorBoard.
This will use the demonstration file to train a neural network driven agent
to directly imitate the actions provided in the demonstration. The environment
will launch and be used for evaluating the agent's performance during training.
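After step 5, the relevant portion of `config/offline_bc_config.yaml` might look like the
excerpt below. Only `demo_path` comes from the steps above; the `default` key and the
`trainer: offline_bc` line are assumptions about how the shipped file is laid out, and all
other fields should be left as they ship.

```
default:
    trainer: offline_bc    # assumed to already be set in the shipped config
    demo_path: ./UnitySDK/Assets/Demonstrations/AgentRecording.demo
```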
## Online Training
It is also possible to provide demonstrations in realtime during training,
without pre-recording a demonstration file. The steps to do this are as follows:
1. First create two Brains, one which will be the "Teacher," and the other which
will be the "Student." We will assume that the names of the Brain
Assets are "Teacher" and "Student" respectively.
2. The "Teacher" Brain must be a **Player Brain**. You must properly
configure the inputs to map to the corresponding actions.
3. The "Student" Brain must be a **Learning Brain**.
4. The Brain Parameters of both the "Teacher" and "Student" Brains must be
compatible with the agent.
5. Drag both the "Teacher" and "Student" Brain into the Academy's `Broadcast Hub`
and check the `Control` checkbox on the "Student" Brain.
6. Link the Brains to the desired Agents (one Agent as the teacher and at least
one Agent as a student).
7. In `config/online_bc_config.yaml`, add an entry for the "Student" Brain (see the
example entry after this list). Set the `trainer` parameter of this entry to `online_bc`, and the
`brain_to_imitate` parameter to the name of the teacher Brain: "Teacher".
Additionally, set `batches_per_epoch`, which controls how much training is done
per update. Increase the `max_steps` option if you'd like to keep training
the Agents for a longer period of time.
8. Launch the training process with `mlagents-learn config/online_bc_config.yaml
--train --slow`, and press the :arrow_forward: button in Unity when the
message _"Start training by pressing the Play button in the Unity Editor"_ is
displayed on the screen
9. From the Unity window, control the Agent with the Teacher Brain by providing
"teacher demonstrations" of the behavior you would like to see.
10. Watch as the Agent(s) with the student Brain attached begin to behave
similarly to the demonstrations.
11. Once the Student Agents are exhibiting the desired behavior, end the training
process with `CTRL+C` from the command line.
12. Move the resulting `*.nn` file into the `TFModels` subdirectory of the
Assets folder (or a subdirectory within Assets of your choosing), and use it
with a `Learning` Brain.
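A sketch of what the "Student" entry from step 7 might look like. The `trainer`,
`brain_to_imitate`, and `batches_per_epoch` parameters are the ones named above; the numeric
values are illustrative assumptions.

```
Student:
    trainer: online_bc
    brain_to_imitate: Teacher
    batches_per_epoch: 5
    max_steps: 10000
```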
**BC Teacher Helper**
We provide a convenience utility, the `BC Teacher Helper` component, that you can add
to the Teacher Agent.
<p align="center">
<img src="images/bc_teacher_helper.png"
alt="BC Teacher Helper"
width="375" border="10" />
</p>
This utility enables you to use keyboard shortcuts to do the following:
1. Start and stop recording experiences. This is useful in case you'd like to
interact with the game _but not have the agents learn from these
interactions_. The default command to toggle this is to press `R` on the
keyboard.
2. Reset the training buffer. This enables you to instruct the agents to forget
their buffer of recent experiences. This is useful if you'd like to get them
to quickly learn a new behavior. The default command to reset the buffer is
to press `C` on the keyboard.

80
docs/images/mlagents-ImitationAndRL.png

Width: 600  |  Height: 371  |  Size: 23 KiB

158
ml-agents/mlagents/trainers/tests/test_bcmodule.py


import unittest.mock as mock
import pytest
import mlagents.trainers.tests.mock_brain as mb
import numpy as np
import yaml
import os
from mlagents.trainers.ppo.policy import PPOPolicy
@pytest.fixture
def dummy_config():
return yaml.safe_load(
"""
trainer: ppo
batch_size: 32
beta: 5.0e-3
buffer_size: 512
epsilon: 0.2
hidden_units: 128
lambd: 0.95
learning_rate: 3.0e-4
max_steps: 5.0e4
normalize: true
num_epoch: 5
num_layers: 2
time_horizon: 64
sequence_length: 64
summary_freq: 1000
use_recurrent: false
memory_size: 8
pretraining:
demo_path: ./demos/ExpertPyramid.demo
strength: 1.0
steps: 10000000
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
"""
)
def create_mock_3dball_brain():
mock_brain = mb.create_mock_brainparams(
vector_action_space_type="continuous",
vector_action_space_size=[2],
vector_observation_space_size=8,
)
return mock_brain
def create_mock_banana_brain():
mock_brain = mb.create_mock_brainparams(
number_visual_observations=1,
vector_action_space_type="discrete",
vector_action_space_size=[3, 3, 3, 2],
vector_observation_space_size=0,
)
return mock_brain
def create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, use_rnn, demo_file
):
mock_braininfo = mb.create_mock_braininfo(num_agents=12, num_vector_observations=8)
mb.setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)
env = mock_env()
trainer_parameters = dummy_config
model_path = env.brain_names[0]
trainer_parameters["model_path"] = model_path
trainer_parameters["keep_checkpoints"] = 3
trainer_parameters["use_recurrent"] = use_rnn
trainer_parameters["pretraining"]["demo_path"] = (
os.path.dirname(os.path.abspath(__file__)) + "/" + demo_file
)
policy = PPOPolicy(0, mock_brain, trainer_parameters, False, False)
return env, policy
# Test default values
@mock.patch("mlagents.envs.UnityEnvironment")
def test_bcmodule_defaults(mock_env, dummy_config):
# See if default values match
mock_brain = create_mock_3dball_brain()
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, False, "test.demo"
)
assert policy.bc_module.num_epoch == dummy_config["num_epoch"]
assert policy.bc_module.batch_size == dummy_config["batch_size"]
env.close()
# Assign strange values and see if it overrides properly
dummy_config["pretraining"]["num_epoch"] = 100
dummy_config["pretraining"]["batch_size"] = 10000
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, False, "test.demo"
)
assert policy.bc_module.num_epoch == 100
assert policy.bc_module.batch_size == 10000
env.close()
# Test with continuous control env and vector actions
@mock.patch("mlagents.envs.UnityEnvironment")
def test_bcmodule_update(mock_env, dummy_config):
mock_brain = create_mock_3dball_brain()
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, False, "test.demo"
)
stats = policy.bc_module.update()
for _, item in stats.items():
assert isinstance(item, np.float32)
env.close()
# Test with RNN
@mock.patch("mlagents.envs.UnityEnvironment")
def test_bcmodule_rnn_update(mock_env, dummy_config):
mock_brain = create_mock_3dball_brain()
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, True, "test.demo"
)
stats = policy.bc_module.update()
for _, item in stats.items():
assert isinstance(item, np.float32)
env.close()
# Test with discrete control and visual observations
@mock.patch("mlagents.envs.UnityEnvironment")
def test_bcmodule_dc_visual_update(mock_env, dummy_config):
mock_brain = create_mock_banana_brain()
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, False, "testdcvis.demo"
)
stats = policy.bc_module.update()
for _, item in stats.items():
assert isinstance(item, np.float32)
env.close()
# Test with discrete control, visual observations and RNN
@mock.patch("mlagents.envs.UnityEnvironment")
def test_bcmodule_rnn_dc_update(mock_env, dummy_config):
mock_brain = create_mock_banana_brain()
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, True, "testdcvis.demo"
)
stats = policy.bc_module.update()
for _, item in stats.items():
assert isinstance(item, np.float32)
env.close()
if __name__ == "__main__":
pytest.main()

1001
ml-agents/mlagents/trainers/tests/testdcvis.demo
File diff too large to display.

442
demos/Expert3DBall.demo


Binary demonstration file; contents not shown.

1001
demos/Expert3DBallHard.demo
File diff too large to display.

1001
demos/ExpertBanana.demo
File diff too large to display.

171
demos/ExpertBasic.demo


Binary demonstration file; contents not shown.