You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Dribble/Run_Utils.py

96 lines
4.4 KiB
Python

def main():
    """Interactive launcher: let the user pick a util or a gym and run it.

    Creates the ``Script`` support object, lists the ``.py`` files found in the
    ``scripts/utils`` and ``scripts/gyms`` folders located next to this file,
    and then loops on a selection menu:

    - a util is instantiated with the ``Script`` object and its ``execute()``
      method is called;
    - a gym opens a Train / Test / Retrain sub-menu that drives the module's
      ``Train`` class.

    The outer menu loop has no exit condition, so this function only ends
    through an exception (the entry point expects ``KeyboardInterrupt``).

    All imports are deliberately local to this function so that merely
    importing this module (e.g. from a child process) does not trigger them.
    """
    from scripts.commons.Script import Script
    # Initialize: load config file, parse arguments, build cpp modules
    # (warns the user about inconsistencies before choosing a test script)
    script = Script()

    # Allows using local version of StableBaselines3 (e.g. https://github.com/m-abr/Adaptive-Symmetry-Learning)
    # place the 'stable-baselines3' folder in the parent directory of this project
    import sys
    from os.path import abspath, dirname, isfile, join, realpath
    sys.path.insert(0, join(dirname(dirname(abspath(__file__))), "stable-baselines3"))

    from scripts.commons.UI import UI
    from os import getcwd, listdir
    from importlib import import_module

    _cwd = realpath(join(getcwd(), dirname(__file__)))
    gyms_path = _cwd + "/scripts/gyms/"
    utils_path = _cwd + "/scripts/utils/"
    exclusions = ["__init__.py"]

    def list_scripts(folder):
        # Module names (file name without ".py") of the regular files in `folder`
        return [
            f[:-3]
            for f in listdir(folder)
            if isfile(join(folder, f)) and f.endswith(".py") and f not in exclusions
        ]

    # "Server" is always listed first among the utils, the rest alphabetically
    utils = sorted(list_scripts(utils_path), key=lambda x: (x != "Server", x))
    gyms = sorted(list_scripts(gyms_path))

    while True:
        _, col_idx, col = UI.print_table(
            [utils, gyms],
            ["Demos & Tests & Utils", "Gyms"],
            cols_per_title=[2, 1],
            numbering=[True] * 2,
            prompt='Choose script (ctrl+c to exit): ',
        )

        is_gym = False
        if col == 0:
            chosen = ("scripts.utils.", utils[col_idx])
        elif col == 1:
            chosen = ("scripts.gyms.", gyms[col_idx])
            is_gym = True

        cls_name = chosen[1]  # the class is expected to have the same name as its file
        mod = import_module(chosen[0] + chosen[1])

        # An imported script should not automatically execute the main code because:
        #     - Multiprocessing methods, such as 'forkserver' and 'spawn', will execute the main code in every child
        #     - The script can only be called once unless it is reloaded
        if not is_gym:
            # Utils receive a script support object with user-defined arguments and useful methods
            # Each util must implement an execute() method, which may or may not return
            from world.commons.Draw import Draw
            from agent.Base_Agent import Base_Agent

            obj = getattr(mod, cls_name)(script)
            try:
                obj.execute()  # Util may return normally or through KeyboardInterrupt
            except KeyboardInterrupt:
                print("\nctrl+c pressed, returning...\n")
            finally:
                # Release resources even when the util fails with an unexpected
                # exception, so nothing is left behind while the error propagates
                Draw.clear_all()            # clear all drawings
                Base_Agent.terminate_all()  # close all server sockets + monitor socket
                script.players = []         # clear list of players created through batch commands
                del obj

        else:
            # Gyms must implement a class Train() which is initialized with user-defined arguments and implements:
            #     train() - method to run the optimization and save a new model
            #     test(folder_dir, folder_name, model_file) - method to load an existing model and test it
            from scripts.commons.Train_Base import Train_Base

            print("\nBefore using GYMS, make sure all server parameters are set correctly")
            print("(sync mode should be 'On', real time should be 'Off', cheats should be 'On', ...)")
            print("To change these parameters go to the previous menu, and select Server\n")
            print("Also, GYMS start their own servers, so don't run any server manually")

            while True:
                try:
                    idx = UI.print_table(
                        [["Train", "Test", "Retrain"]],
                        numbering=[True],
                        prompt='Choose option (ctrl+c to return): ',
                    )[0]
                except KeyboardInterrupt:
                    print()
                    break  # back to the script selection menu

                if idx == 0:
                    # Train from scratch: no previous model information
                    mod.Train(script).train(dict())
                else:
                    # Test (idx == 1) and Retrain (any other index) both need an existing model
                    model_info = Train_Base.prompt_user_for_model()
                    if model_info is not None and idx == 1:
                        mod.Train(script).test(model_info)
                    elif model_info is not None:
                        mod.Train(script).train(model_info)
# allow child processes to bypass this file
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nctrl+c pressed, exiting...")
        # Raise SystemExit directly instead of calling exit(): that helper is
        # injected by the `site` module for interactive use and is not
        # guaranteed to exist (e.g. under `python -S`). Exit status stays 0.
        raise SystemExit(0)