Fix numpy 2.0 compatibility (np.NINF removed)

This commit is contained in:
allegroai 2024-06-18 16:52:50 +03:00
parent cf4134ee02
commit 0cf4d6a6ce
2 changed files with 14 additions and 2 deletions

View File

@ -412,7 +412,13 @@ class Logger(object):
reporter_table = table.fillna(str(np.nan)) reporter_table = table.fillna(str(np.nan))
replace("NaN", np.nan, math.nan if six.PY3 else float("nan")) replace("NaN", np.nan, math.nan if six.PY3 else float("nan"))
replace("Inf", np.inf, math.inf if six.PY3 else float("inf")) replace("Inf", np.inf, math.inf if six.PY3 else float("inf"))
replace("-Inf", -np.inf, np.NINF, -math.inf if six.PY3 else -float("inf")) minus_inf = [-np.inf, -math.inf if six.PY3 else -float("inf")]
try:
minus_inf.append(np.NINF)
except AttributeError:
# NINF has been removed in numpy>2.0
pass
replace("-Inf", *minus_inf)
# noinspection PyProtectedMember # noinspection PyProtectedMember
return self._task._reporter.report_table( return self._task._reporter.report_table(
title=title, title=title,

View File

@ -647,7 +647,13 @@ class BaseModel(object):
reporter_table = table.fillna(str(np.nan)) reporter_table = table.fillna(str(np.nan))
replace("NaN", np.nan, math.nan if six.PY3 else float("nan")) replace("NaN", np.nan, math.nan if six.PY3 else float("nan"))
replace("Inf", np.inf, math.inf if six.PY3 else float("inf")) replace("Inf", np.inf, math.inf if six.PY3 else float("inf"))
replace("-Inf", -np.inf, np.NINF, -math.inf if six.PY3 else -float("inf")) minus_inf = [-np.inf, -math.inf if six.PY3 else -float("inf")]
try:
minus_inf.append(np.NINF)
except AttributeError:
# NINF has been removed in numpy>2.0
pass
replace("-Inf", *minus_inf)
self._init_reporter() self._init_reporter()
return self._reporter.report_table( return self._reporter.report_table(
title=title, title=title,