|
|
|
|
|
|
# It matches "mlagents" and "mlagents_envs", accessible as group "package" |
|
|
|
# and optionally matches the version, e.g. "==1.2.3" |
|
|
|
PIP_INSTALL_PATTERN = re.compile( |
|
|
|
r"(python -m )?pip3* install (?P<package>mlagents(_envs)?)(==[0-9]\.[0-9]\.[0-9](\.dev[0-9]+)?)?" |
|
|
|
r"(python -m )?pip3* install (?P<package>mlagents(_envs)?)(==[0-9]+\.[0-9]+\.[0-9]+(\.dev[0-9]+)?)?" |
|
|
|
) |
|
|
|
TRAINER_INIT_FILE = "ml-agents/mlagents/trainers/__init__.py" |
|
|
|
|
|
|
|
|
|
|
("python -m pip install mlagents", True), |
|
|
|
("python -m pip install mlagents==1.2.3", True), |
|
|
|
("python -m pip install mlagents_envs==1.2.3", True), |
|
|
|
("python -m pip install mlagents==11.222.3333", True), |
|
|
|
("python -m pip install mlagents_envs==11.222.3333", True), |
|
|
|
]: |
|
|
|
assert bool(PIP_INSTALL_PATTERN.search(s)) is expected |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_pip_install_line(line, package_verion): |
|
|
|
match = PIP_INSTALL_PATTERN.search(line) |
|
|
|
package_name = match.group("package") |
|
|
|
replacement_version = f"python -m pip install {package_name}=={package_verion}" |
|
|
|
updated = PIP_INSTALL_PATTERN.sub(replacement_version, line) |
|
|
|
return updated |
|
|
|
if match is not None: # if there is a pip install line |
|
|
|
package_name = match.group("package") |
|
|
|
replacement_version = f"python -m pip install {package_name}=={package_verion}" |
|
|
|
updated = PIP_INSTALL_PATTERN.sub(replacement_version, line) |
|
|
|
return updated |
|
|
|
else: # Don't do anything |
|
|
|
return line |
|
|
|
|
|
|
|
|
|
|
|
def git_ls_files() -> List[str]: |
|
|
|