M2_SETI/IA/seti_master-master/code/pyrat_api.py
2023-01-29 16:56:40 +01:00

69 lines
No EOL
2.7 KiB
Python

from loguru import logger
import sys
sys.path.append("pyrat39/")
from analysis_param import AnalysisParam
from sound_pyrat.analyzer_split import analyze_splits
from pyrat import load_img
from sound_pyrat.domains.domain_factory import create_box_maker
import torch
from sound_pyrat.utils_analyzer import analyze, read_inputs, read_model, analyze_adv
import warnings
# Silence two known-noisy numeric warnings for the whole process.
# `message` is a regular expression matched against the start of the warning
# text, so "in *" matches "in" followed by any number of spaces.
# NOTE(review): presumably these are NumPy RuntimeWarnings raised during
# bound propagation — confirm before relying on that.
warnings.filterwarnings("ignore", message="invalid value encountered in *")
warnings.filterwarnings("ignore", message="divide by zero encountered in reciprocal")
def launch_pyrat(model_path, property_path, domains=None, split_timeout=0):
    """Run a PyRAT analysis of a model against a property file.

    Args:
        model_path: path to the model.
        property_path: path to the property file.
        domains: a python list of domains to propagate in the network, from
            "zonotopes" or "poly". ``None`` (the default) means an empty list.
        split_timeout: ``0`` runs a single analysis (``analyze``) and prints
            the output bounds; any other value is stored in ``params.timeout``
            and the analysis with input splitting (``analyze_splits``) is run
            instead. Presumably expressed in seconds -- TODO confirm.

    Returns:
        A ``(result, time)`` tuple read from the analysis result object.
    """
    # Build a fresh list per call instead of sharing one mutable default.
    if domains is None:
        domains = []

    # Remove the loguru handlers so the analysis itself stays quiet.
    logger.remove()

    params = AnalysisParam(
        domains=domains,
        squeeze=True,
        verbose=False,
        check="skip",
        by_layer=True,
    )
    inputs = read_inputs(model_path=model_path, prop_path=property_path, params=params)
    pyrat_model, bounds, to_verify, to_counter = inputs

    # Drop the leading batch dimension of a (1, 3) input shape.
    if pyrat_model.root[0].input[0].input_shape == (1, 3):
        pyrat_model.root[0].input[0].input_shape = (3,)

    # Turn the (lower, upper) pair into a box of the model's dtype.
    bounds = params.box_maker(bounds[0], bounds[1]).to_type(pyrat_model.dtype)

    # NOTE(review): the shape was just rewritten to (3,) above, so this
    # comparison presumably can never be true and the bounds are never
    # reshaped. Kept as-is to preserve behaviour -- confirm the intent.
    if pyrat_model.root[0].input[0].input_shape == (1, 3):
        bounds = bounds.reshape((3,))

    if split_timeout == 0:
        result = analyze(bounds, pyrat_model, to_verify, to_counter, params)
        print("Output bounds:\n", result.single_res.output_bounds)
        print("Result = {}, Time = {:.2f} s".format(result.result, result.time))
    else:
        params.timeout = split_timeout
        result = analyze_splits(bounds, pyrat_model, to_verify, to_counter, params)
        result.print_res()

    return result.result, result.time
def local_robustness(model_path, image, label, pert, domains=None, *, total_labels=10):
    """Run PyRAT's adversarial analysis (``analyze_adv``) on a single image.

    Args:
        model_path: path to the model.
        image: the input image handed to ``analyze_adv``.
        label: the true label of ``image`` (passed as ``true_label``).
        pert: perturbation size (passed as ``epsilon``).
        domains: a python list of domains to propagate in the network.
            ``None`` (the default) means an empty list.
        total_labels: number of output labels of the model; defaults to 10,
            the value previously hard-coded here.

    Returns:
        A ``(result, time)`` tuple read from the analysis result object.
    """
    # Build a fresh list per call instead of sharing one mutable default.
    if domains is None:
        domains = []

    box_maker = create_box_maker("torch", sound=False, dtype=torch.float32)
    params = AnalysisParam(
        total_labels=total_labels,
        epsilon=pert,
        box_maker=box_maker,
        domains=domains,
        by_layer=True,
        force_analysis=True,
        true_label=label,
    )
    pyrat_model = read_model(model_path, params=params)
    res = analyze_adv(image, model_path, pyrat_model, params)
    return res.result, res.time
def read_images(labels_path="fmnist/img_labels.csv", images_dir="fmnist/images", limit=1):
    """Load images and their labels as listed in a ``name,label`` CSV file.

    Args:
        labels_path: CSV file with one ``name,label`` pair per line.
        images_dir: directory containing the image files named in the CSV.
        limit: maximum number of leading CSV lines to read; ``None`` reads
            the whole file. The default of 1 matches the previous behaviour
            of loading only the first entry.

    Returns:
        A list of ``(image, label)`` tuples, where ``image`` is whatever
        ``load_img`` returns and ``label`` is an ``int``.

    Raises:
        ValueError: if a non-blank line does not have exactly two
            comma-separated fields or its label is not an integer.
    """
    from itertools import islice

    # Remove the loguru handlers so image loading stays quiet.
    logger.remove()

    images = []
    with open(labels_path) as f:
        # Stream only the requested prefix instead of reading every line.
        for line in islice(f, limit):
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            name, label = line.split(",")
            image = load_img(f"{images_dir}/{name}", grayscale=False)
            # int() ignores surrounding whitespace, including the newline.
            images.append((image, int(label.strip())))
    return images