问题描述
我正在尝试更多地了解如何使用数据类,并创建了一些将输入多臂强盗模型的类:
mix.js(['resources/js/app.js','resources/js/functions.js'],'public/js')
.css(['resources/css/default.css','resources/css/app.css'],'public/css');
它的作用是:
我的问题是,如果我启动 BanditParameters,例如
@dataclass
class DefaultBanditParameters:
"""
The set default parameters for the Multi-Arm Bandit
@param model: name of model
@last_training_time: Metadata for last recent train time
@param arms: a dict of all the trained arms
@n: Total step count (data samples trained on)
@mean_reward: the mean reward of the model,updates each step
"""
model: str = "Default Bandit"
training_time: str = None
arms: Dict[str,Dict[str,str]] = field(default_factory=dict)
n: int = 0
mean_reward: float = 0
total_arms: int = field(init=False)
def __post_init__(self):
self.total_arms = len(self.arms)
def get_params(self):
return asdict(self)
@dataclass
class BanditParameters(DefaultBanditParameters):
"""
A class object that retrieves the most recent model parameters or default if can't find any
and returns class object.
"""
@classmethod
def update(cls,bucket,session,input_dir):
"""
Retrieves latest updates from S3 directory and updates model parameters if any,otherwise uses default
@param bucket: S3 bucket
@param session: AWS session object
@param input_dir: S3 key for model parameters
@return: class object containing model parameters
"""
s3 = session.resource('s3')
try:
key = [obj.key for obj in s3.Bucket(bucket).objects.filter(Prefix=input_dir)][-1]
file_content = s3.Object(bucket,key).get()['Body'].read().decode('utf-8')
logger.info("Latest parameters found in directory.")
params = json.loads(file_content)
params_dict = {
'model': params["model"],'training_time': params["timestamp"],'arms': params["arms"],'n': params["total_events"],'mean_reward': params["mean_reward"]
}
return cls(**params_dict)
except:
logger.warning("No parameters found in directory. Using default...")
return cls()
@classmethod
def get_default_params(cls):
"""Returns the default model parameters"""
return DefaultBanditParameters()
然后运行 params = BanditParameters(model='my_model')
如果失败,它将返回默认参数而不是更新的参数(本例中为“my_model”)。
我对数据类还是个新手,所以可能我做的不对。
示例:
params.update(bucket,'my_param_path')
返回 params = BanditParameters(model='my_model')
params
但是当我运行失败的更新时,它返回默认值:
BanditParameters(model='my_model',training_time=None,arms={},n=0,mean_reward=0,total_arms=0)
返回 params = params.update(bucket,"bad_path")
params
[编辑]: 我找到了一种方法,但不确定这是否是最佳方法。它涉及删除更新的类方法:
BanditParameters(model='Default Bandit',total_arms=0)
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)